PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: mathmodel
Tester: mexomerf
Editorialist: iceknight1093
DIFFICULTY:
easy
PREREQUISITES:
Inclusion-exclusion principle, combinatorics
PROBLEM:
Given N and K, count the number of permutations P of length N such that P_i + P_{i+1} = K holds for at least one index i (1 \leq i \lt N).
EXPLANATION:
If P_i + P_{i+1} = K, then the pair of elements (P_i, P_{i+1}) should be one of
(1, K-1), (2, K-2), \ldots, (K-1, 1).
In particular, each integer x from 1 to K-1 has a unique partner, K-x, that it can pair with to sum to K (the pair is only usable if K-x \neq x and both values lie in [1, N]).
Since P must contain distinct elements, if P_i + P_{i+1} = K then it’s impossible for P_{i+1} + P_{i+2} = K to hold (since this would mean P_i = P_{i+2}).
So, in any permutation, the pairs of adjacent elements that sum up to K are mutually disjoint.
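The disjointness claim is easy to sanity-check by brute force for small N. The sketch below is not part of the original solution: it enumerates every permutation and verifies that the indices i with P_i + P_{i+1} = K are never consecutive. The helper names `sum_k_indices` and `pairs_are_disjoint` are hypothetical, and indices in the code are 0-based.

```python
from itertools import permutations

def sum_k_indices(p, k):
    """Return the 0-based indices i with p[i] + p[i+1] == k."""
    return [i for i in range(len(p) - 1) if p[i] + p[i + 1] == k]

def pairs_are_disjoint(n, k):
    """Brute force: no two adjacent sum-to-k pairs share an element."""
    for p in permutations(range(1, n + 1)):
        idx = sum_k_indices(p, k)
        # Overlapping pairs would show up as two consecutive indices i, i+1.
        if any(b - a == 1 for a, b in zip(idx, idx[1:])):
            return False
    return True

# Check every N <= 6 and every K that is a sum of two distinct values in [1, N].
for n in range(1, 7):
    for k in range(3, 2 * n):
        assert pairs_are_disjoint(n, k)
```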
Let M be the number of unordered pairs \{a, K-a\} with 1 \le a \lt K-a \le N, i.e. the pairs of distinct values from 1 to N that sum to K (unordered meaning (1, K-1) and (K-1, 1) are considered to be the same).
We want to count permutations that have at least one of these pairs present.
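M also has a closed form: a must satisfy a \ge \max(1, K-N) (so that K-a \le N) and a \le \lfloor (K-1)/2 \rfloor (so that a \lt K-a). The sketch below computes it that way with a hypothetical helper `count_pairs`. The reference codes count the pairs with a loop (or, in the tester's code, with a `min`-based formula), so treat this only as an alternative way to get the same number.

```python
def count_pairs(n, k):
    """M = number of unordered pairs {a, k - a} with 1 <= a < k - a <= n."""
    lo = max(1, k - n)    # partner k - a must not exceed n
    hi = (k - 1) // 2     # a < k - a  <=>  a <= (k - 1) // 2
    return max(0, hi - lo + 1)

print(count_pairs(4, 5))  # 2: the pairs {1, 4} and {2, 3}
```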
Suppose we pick x of these M pairs and require each of them to appear as two adjacent elements of the permutation.
Then,
- There are \binom{M}{x} ways to choose which pairs they are.
- There are 2 choices for the order of each pair, for 2^x choices in total.
- We now have (N - x) objects: the (N - 2x) unpaired elements and the x pairs, each pair glued into a single block. There are (N-x)! ways to arrange them to form a permutation.
Multiplying these together, we denote the resulting count by f(x):
f(x) = \binom{M}{x} \cdot 2^x \cdot (N-x)!
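A direct transcription of f(x) = \binom{M}{x} \cdot 2^x \cdot (N-x)! with exact Python integers is sketched below. There is no modulus, so it is only meant for small inputs, and `f` takes M as a parameter instead of recomputing it. For N = 4, K = 5 we have M = 2 (the pairs \{1, 4\} and \{2, 3\}).

```python
from math import comb, factorial

def f(x, n, m):
    """C(M, x) ways to choose the pairs, 2^x orders inside them,
    and (N - x)! arrangements of the resulting N - x objects."""
    return comb(m, x) * 2**x * factorial(n - x)

# N = 4, K = 5, M = 2
print(f(1, 4, 2), f(2, 4, 2))  # 24 8
```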
Observe that f(x) doesn’t quite count the number of permutations with exactly x pairs that sum to K: since we didn’t impose any restrictions on the position and order of unpaired elements, it’s entirely possible that they formed some pairs too.
In particular, a configuration with y \gt x pairs summing to K is counted multiple times: exactly once for each x-element subset of its y pairs, i.e. \binom{y}{x} times.
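The \binom{y}{x} claim can be checked for small N: summing \binom{y}{x} over all permutations, where y is the number of adjacent pairs summing to K, should reproduce f(x) exactly. The sketch below does this by brute force; the helper names are hypothetical and the enumeration is exponential, so it is limited to N \le 6.

```python
from itertools import permutations
from math import comb, factorial

def f(x, n, m):
    return comb(m, x) * 2**x * factorial(n - x)

def count_pairs(n, k):
    # unordered pairs {a, k - a} with 1 <= a < k - a <= n
    return sum(1 for a in range(1, n + 1) if a < k - a <= n)

def weighted_count(x, n, k):
    """Sum over all permutations of C(y, x), y = #adjacent pairs summing to k."""
    total = 0
    for p in permutations(range(1, n + 1)):
        y = sum(1 for i in range(n - 1) if p[i] + p[i + 1] == k)
        total += comb(y, x)
    return total

for n in range(1, 7):
    for k in range(3, 2 * n):
        m = count_pairs(n, k)
        for x in range(m + 1):
            assert weighted_count(x, n, k) == f(x, n, m)
```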
Let’s use this information to try and count each configuration exactly once.
That can be done as follows:
- First, start with f(1). Now, every configuration with one pair has been counted once, but everything with two pairs has been counted \binom{2}{1} = 2 times.
- So, we subtract f(2), which will ensure that everything with two pairs gets counted exactly once.
- Next, in f(1) - f(2), configurations with three pairs get counted \binom{3}{1} - \binom{3}{2} = 3 - 3 = 0 times, so we must add f(3) to count them once.
- In f(1) - f(2) + f(3), configurations with four pairs have been counted \binom{4}{1} - \binom{4}{2} + \binom{4}{3} = 2 times, so f(4) must be subtracted once to remove the overcounting.
\vdots
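Repeating this argument, we see that the coefficients of the f(x) values alternate between +1 and -1, giving us the final expression
\text{answer} = \sum_{x=1}^{M} (-1)^{x+1} f(x) = \sum_{x=1}^{M} (-1)^{x+1} \binom{M}{x} \cdot 2^x \cdot (N-x)!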
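To make the final formula concrete, here is a small exact-arithmetic sketch of \sum_{x=1}^{M} (-1)^{x+1} \binom{M}{x} \cdot 2^x \cdot (N-x)!, cross-checked against brute-force enumeration for small N. The names `answer_exact` and `answer_brute` are hypothetical; the reference codes apply the same formula modulo 10^9 + 7.

```python
from itertools import permutations
from math import comb, factorial

def count_pairs(n, k):
    return sum(1 for a in range(1, n + 1) if a < k - a <= n)

def answer_exact(n, k):
    """Inclusion-exclusion over the M pairs, with exact integers."""
    m = count_pairs(n, k)
    return sum((-1) ** (x + 1) * comb(m, x) * 2**x * factorial(n - x)
               for x in range(1, m + 1))

def answer_brute(n, k):
    """Count permutations with at least one adjacent pair summing to k."""
    return sum(1 for p in permutations(range(1, n + 1))
               if any(p[i] + p[i + 1] == k for i in range(n - 1)))

print(answer_exact(4, 5), answer_brute(4, 5))  # 16 16
```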
Proof
When we're looking at i, each configuration with exactly i pairs has so far been counted \binom{i}{1} - \binom{i}{2} + \binom{i}{3} - \ldots + (-1)^i \binom{i}{i-1} times: f(x) contributes \binom{i}{x} with sign (-1)^{x+1}, for x = 1, \ldots, i-1.
We look at even and odd i separately.
When i is odd, observe that \binom{i}{j} = \binom{i}{i-j}, and since j and i-j have opposite parity, these two terms have opposite signs in the summation.
Each such pair cancels out, making the overall sum 0.
So, we need to give f(i) a coefficient of +1.
Next, suppose i is even. We can't cancel out terms like before.
Instead, observe that by the binomial theorem,
0 = 0^i = (1 + (-1))^i = \sum_{j=0}^{i} (-1)^j \binom{i}{j} = \binom{i}{0} - \binom{i}{1} + \binom{i}{2} - \ldots + \binom{i}{i}
The right side is almost exactly the value we have: its middle i-1 terms are exactly our summation, but negated.
So, if S is the value of our summation, we have 0 = \binom{i}{0} - S + \binom{i}{i}.
This means S = \binom{i}{0} + \binom{i}{i} = 1 + 1 = 2.
So, when i is even we give f(i) the sign -1, since such configurations are currently counted twice but we want to count them once.
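The conclusion of the proof (S = 0 for odd i, S = 2 for even i) is easy to confirm numerically. The sketch below evaluates S = \sum_{j=1}^{i-1} (-1)^{j+1} \binom{i}{j} directly; it is only a sanity check, not a substitute for the proof.

```python
from math import comb

def times_counted(i):
    """S = C(i,1) - C(i,2) + ... + (-1)^i C(i,i-1): how many times the signed
    terms f(1), ..., f(i-1) have counted a configuration with exactly i pairs."""
    return sum((-1) ** (j + 1) * comb(i, j) for j in range(1, i))

print([times_counted(i) for i in range(1, 9)])  # [0, 2, 0, 2, 0, 2, 0, 2]
```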
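Each f(i) can be computed in \mathcal{O}(1) time once factorials, inverse factorials and powers of 2 are precomputed (or \mathcal{O}(\log N) time if 2^i is computed by fast exponentiation), so this is fast enough.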
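One way to get \mathcal{O}(1) work per term is to precompute factorials, inverse factorials and powers of 2 modulo 10^9 + 7 once, as sketched below. This is an illustrative variant with hypothetical names (`build_tables`, `solve`) and the closed form for M, not a copy of any reference solution; it assumes N never exceeds the table limit.

```python
MOD = 10**9 + 7

def build_tables(limit):
    """Factorials, inverse factorials and powers of two modulo MOD for 0..limit."""
    fact = [1] * (limit + 1)
    pow2 = [1] * (limit + 1)
    for i in range(1, limit + 1):
        fact[i] = fact[i - 1] * i % MOD
        pow2[i] = pow2[i - 1] * 2 % MOD
    inv_fact = [1] * (limit + 1)
    inv_fact[limit] = pow(fact[limit], MOD - 2, MOD)  # Fermat: MOD is prime
    for i in range(limit, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % MOD       # 1/(i-1)! = i / i!
    return fact, inv_fact, pow2

def solve(n, k, fact, inv_fact, pow2):
    lo, hi = max(1, k - n), (k - 1) // 2
    m = max(0, hi - lo + 1)  # number of unordered pairs M
    ans = 0
    for x in range(1, m + 1):
        # C(M, x) * 2^x * (N - x)!, every factor an O(1) table lookup
        term = fact[m] * inv_fact[x] % MOD * inv_fact[m - x] % MOD
        term = term * pow2[x] % MOD * fact[n - x] % MOD
        ans = ans + term if x % 2 == 1 else ans - term
    return ans % MOD
```

For example, `solve(4, 5, *build_tables(10))` returns 16. The tables are built once and shared across all testcases, so each testcase costs \mathcal{O}(M) \subseteq \mathcal{O}(N).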
TIME COMPLEXITY:
\mathcal{O}(N) per testcase, after a one-time precomputation of factorials.
CODE:
Author's code (Python)
m=10**9+7
FACSIZE=10**6
# Fast modular exponentiation: a^b mod m
def power(a,b):
    x=1
    y=a
    while b>0:
        if b&1:
            x=(x*y)%m
        y=(y*y)%m
        b>>=1
    return x%m
def modular_inverse(n): return power(n,m-2)
# f[i] = i! mod m, invfact[i] = (i!)^(-1) mod m
f=[1]*FACSIZE
invfact=[1]*FACSIZE
def cfact():
    for i in range(2,FACSIZE):
        f[i]=f[i-1]*i%m
    invfact[FACSIZE-1]=modular_inverse(f[FACSIZE-1])
    for i in range(FACSIZE-2,-1,-1):
        invfact[i]=invfact[i+1]*(i+1)%m
def comb(n,k):
    if k<0 or n<k:return 0
    return f[n]*invfact[k]%m*invfact[n-k]%m
cfact()
def solve():
    n,k=map(int,input().split())
    ans=0
    # tp = M, the number of unordered pairs (i, k-i) with i < k-i <= n
    tp=0
    for i in range(1,n+1):
        j=k-i
        if i>=j:break
        if 1<=j<=n:tp+=1
    p2=1  # 2^p mod m
    for p in range(1,tp+1):
        p2=p2*2%m
        # p! * (n-2p)! * C(n-p,p) = (n-p)!, so d = C(tp,p) * 2^p * (n-p)! = f(p)
        d=comb(tp,p)*f[p]%m
        d=d*p2%m
        d=d*f[n-2*p]%m
        d=d*comb(n-p,p)%m
        # inclusion-exclusion: add odd p, subtract even p
        if p%2:ans=(ans+d)%m
        else:ans=(ans-d)%m
    print(ans%m)
for _ in range(int(input())):
    solve()
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define md 1000000007
#define N 100001
// modex(a, b) = a^b mod md (recursive fast exponentiation)
int modex(int a, int b){
    if(b == 0){
        return 1;
    }
    int res = modex(a, b / 2);
    res *= res;
    res %= md;
    if(b % 2){
        res *= (a % md);
    }
    return res % md;
}
// mod(a, b) = a / b mod md, using Fermat's little theorem for the inverse
int mod(int a, int b){
    a %= md;
    return (a * modex(b, md - 2)) % md;
}
int pw[N];   // pw[i] = 2^i mod md
int fac[N];  // fac[i] = i! mod md
int ncr(int n, int r){
    if(n < r || r < 0){
        return 0;
    }
    return mod(fac[n], fac[n - r] * fac[r]);
}
int32_t main() {
    pw[0] = 1;
    fac[0] = 1;
    for(int i = 1; i < N; i++){
        pw[i] = pw[i - 1] * 2;
        fac[i] = fac[i - 1] * i;
        pw[i] %= md;
        fac[i] %= md;
    }
    int t;
    cin>>t;
    while(t--){
        int n, k;
        cin>>n>>k;
        // cnt = M, the number of unordered pairs {a, k-a} with 1 <= a < k-a <= n
        int cnt = min(2* n - k + 1, k - 1) / 2;
        int ans = 0;
        for(int i = 1; i <= cnt; i++){
            // (2 * (i % 2) - 1) is +1 for odd i and -1 for even i
            int temp = ncr(cnt, i) * fac[i] * (2 * (i % 2) - 1);
            temp %= md;
            temp *= pw[i];
            temp %= md;
            // fac[i] * fac[n - 2i] * C(n - i, i) = (n - i)!
            temp *= fac[n - 2 * i];
            temp %= md;
            temp *= ncr(n - i, i);
            temp %= md;
            ans += temp;
            ans %= md;
        }
        if(ans < 0){
            ans += md;
        }
        cout<<ans<<"\n";
    }
}
Editorialist's code (PyPy3)
mod = 10**9 + 7
fac = [1]
for i in range(1, 3 * 10**5): fac.append(fac[-1] * i % mod)
ifac = fac[:]
ifac[-1] = pow(ifac[-1], mod-2, mod)
for i in reversed(range(3 * 10**5 - 1)): ifac[i] = ifac[i+1] * (i + 1) % mod
def C(n, r):
    if n < r or r < 0: return 0
    return fac[n] * ifac[r] % mod * ifac[n-r] % mod
for _ in range(int(input())):
    n, k = map(int, input().split())
    pairs = 0
    for i in range(1, n+1):
        j = k - i
        if i < j and 1 <= j <= n: pairs += 1
    ans = 0
    for i in range(1, pairs + 1):
        if n - 2*i < 0: break
        ways = C(pairs, i) * fac[n - i] * pow(2, i, mod) % mod
        if i%2 == 1: ans += ways
        else: ans -= ways
    print(ans % mod)