PERMUTATION2 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: mathmodel
Tester: mexomerf
Editorialist: iceknight1093

DIFFICULTY:

easy

PREREQUISITES:

Inclusion-exclusion principle, combinatorics

PROBLEM:

Given N and K, count the number of permutations P of length N such that P_i + P_{i+1} = K holds for at least one index i (1 \leq i \lt N). The answer is output modulo 10^9 + 7.
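For very small inputs the count can be checked directly by enumerating every permutation. The sketch below uses a hypothetical helper `brute_force(n, k)`, which is not part of any reference solution. It returns the exact count (no modulus) and runs in \mathcal{O}(N! \cdot N) time, so it is only useful as a checker for N up to about 8.

```python
from itertools import permutations

def brute_force(n: int, k: int) -> int:
    """Count permutations of 1..n with at least one adjacent pair summing to k.

    Exponential time: intended only as a checker for very small n.
    """
    count = 0
    for p in permutations(range(1, n + 1)):
        # any() stops at the first adjacent pair whose sum equals k
        if any(p[i] + p[i + 1] == k for i in range(n - 1)):
            count += 1
    return count

print(brute_force(4, 5))  # 16
```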

EXPLANATION:

If P_i + P_{i+1} = K, then the pair of elements (P_i, P_{i+1}) should be one of
(1, K-1), (2, K-2), \ldots, (K-1, 1).

In particular, note that each integer from 1 to K-1 has a unique partner that it can pair with to sum to K.
Since P must contain distinct elements, if P_i + P_{i+1} = K then it’s impossible for P_{i+1} + P_{i+2} = K to hold (since this would mean P_i = P_{i+2}).

So, in any permutation, the pairs of adjacent elements that sum up to K are mutually disjoint.


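Let M be the number of unordered pairs \{a, K-a\} with 1 \leq a \lt K-a \leq N, that is, the pairs summing to K whose elements both lie in [1, N] (unordered meaning (1, K-1) and (K-1, 1) are considered to be the same; a = K/2 is excluded because the elements of a permutation are distinct).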
We want to count permutations that have at least one of these pairs present.
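A minimal sketch of computing M, with hypothetical function names. The first version enumerates the smaller element a directly, as the reference solutions do; the second is an \mathcal{O}(1) closed form using the same expression as the tester's code, clamped at zero. It follows from a ranging over \max(1, K-N), \ldots, \lfloor (K-1)/2 \rfloor.

```python
def count_pairs_loop(n: int, k: int) -> int:
    """M by direct enumeration: values a with 1 <= a < k - a <= n."""
    return sum(1 for a in range(1, n + 1) if a < k - a <= n)

def count_pairs_closed(n: int, k: int) -> int:
    """M in O(1): the smaller element a ranges over max(1, k - n) .. (k - 1) // 2."""
    return max(0, min(2 * n - k + 1, k - 1) // 2)

print(count_pairs_closed(4, 5), count_pairs_closed(6, 7))  # 2 3
```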

Suppose we require some x of these pairs to each appear as two adjacent elements of the permutation.
Then,

  1. There are \binom{M}{x} ways to choose which pairs they are.
  2. There are 2 choices for the order of each pair, for 2^x choices in total.
  3. We now have (N - x) objects - the (N - 2x) unpaired elements and the x pairs. There are (N-x)! ways to arrange them to form a permutation.

This gives us a count of

\binom{M}{x} \cdot (N-x)! \cdot 2^x

Let this value be denoted f(x).
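A direct transcription of f(x), using Python's arbitrary-precision integers instead of modular arithmetic so that it can be compared against exact counts for small inputs. The name `f` and its parameter order (n = N, m = M) are choices made for this sketch.

```python
from math import comb, factorial

def f(n: int, m: int, x: int) -> int:
    """Exact f(x) = C(m, x) * (n - x)! * 2^x.

    n: permutation length, m: number of unordered pairs summing to K,
    x: number of pairs forced to be adjacent (1 <= x <= m).
    """
    return comb(m, x) * factorial(n - x) * 2 ** x

# N = 4, K = 5: pairs {1, 4} and {2, 3}, so M = 2
print(f(4, 2, 1), f(4, 2, 2))  # 24 8
```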

Observe that f(x) doesn’t quite count the number of permutations with exactly x pairs that sum to K: since we didn’t impose any restrictions on the position and order of unpaired elements, it’s entirely possible that they formed some pairs too.
In particular, if a permutation has y \geq x pairs that sum to K, it is counted once for every way of choosing x of those y pairs, so it is counted exactly \binom{y}{x} times.
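This claim can be checked numerically for small N. Summing \binom{y}{x} over every permutation, where y is that permutation's number of adjacent pairs summing to K, should reproduce f(x) exactly. `weighted_count` is a hypothetical helper written for this check; it runs in exponential time.

```python
from itertools import permutations
from math import comb, factorial

def weighted_count(n: int, k: int, x: int) -> int:
    """Sum of C(y, x) over all permutations of 1..n, where y is the number of
    indices i with p[i] + p[i + 1] == k. comb(y, x) is 0 when y < x."""
    total = 0
    for p in permutations(range(1, n + 1)):
        y = sum(1 for i in range(n - 1) if p[i] + p[i + 1] == k)
        total += comb(y, x)
    return total

# N = 5, K = 6 has M = 2 pairs: {1, 5} and {2, 4}
n, k, m = 5, 6, 2
for x in (1, 2):
    f_x = comb(m, x) * factorial(n - x) * 2 ** x  # f(x) from the formula
    print(x, weighted_count(n, k, x), f_x)  # prints "1 96 96", then "2 24 24"
```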

Let’s use this information to try and count each configuration exactly once.
That can be done as follows:

  • First, start with f(1). Now, every configuration with one pair has been counted once, but everything with two pairs has been counted \binom{2}{1} = 2 times.
  • So, we subtract f(2), which will ensure that everything with two pairs gets counted exactly once.
  • Next, in f(1) - f(2), configurations with three pairs get counted \binom{3}{1} - \binom{3}{2} = 3 - 3 = 0 times, so we must add f(3) to count them once.
  • In f(1) - f(2) + f(3), configurations with four pairs have been counted \binom{4}{1} - \binom{4}{2} + \binom{4}{3} = 2 times, so f(4) must be subtracted once to remove the overcounting.
    \vdots
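The bookkeeping in these bullets can be reproduced in a few lines. In this sketch, `times_counted(y, t)` (a name chosen here) is the number of times a permutation with exactly y pairs is counted by the first t terms of the alternating sum; it relies only on `math.comb`.

```python
from math import comb

def times_counted(y: int, t: int) -> int:
    """Number of times a permutation with exactly y pairs summing to K is counted
    by the partial sum f(1) - f(2) + ... + (-1)^(t+1) f(t).

    Each f(x) counts such a permutation C(y, x) times (zero when x > y).
    """
    return sum((-1) ** (x + 1) * comb(y, x) for x in range(1, t + 1))

# The three bullet points: 2, then 0, then 2.
print(times_counted(2, 1), times_counted(3, 2), times_counted(4, 3))  # 2 0 2
# Just before f(i) is added, the count is 0 for odd i and 2 for even i ...
print([times_counted(i, i - 1) for i in range(1, 7)])  # [0, 2, 0, 2, 0, 2]
# ... and once f(i) is included, the permutation is counted exactly once.
print(all(times_counted(y, y) == 1 for y in range(1, 30)))  # True
```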

Repeating this argument, we see that the coefficients of the f(x) values alternate between +1 and -1, giving us the final expression

f(1) - f(2) + f(3) - f(4) + \ldots = \sum_{i=1}^M (-1)^{i+1} f(i)
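Putting the pieces together, here is a minimal exact (non-modular) sketch of this final expression, meant only for checking small cases by hand or against a brute-force enumeration. `count_exact` is a hypothetical name; M is computed with the same loop the reference solutions use.

```python
from math import comb, factorial

def count_exact(n: int, k: int) -> int:
    """Exact answer (no modulus) via sum_{i=1}^{M} (-1)^(i+1) * f(i)."""
    # M: unordered pairs {a, k - a} with 1 <= a < k - a <= n
    m = sum(1 for a in range(1, n + 1) if a < k - a <= n)
    total = 0
    for i in range(1, m + 1):
        f_i = comb(m, i) * factorial(n - i) * 2 ** i
        total += f_i if i % 2 == 1 else -f_i  # odd i added, even i subtracted
    return total

print(count_exact(4, 5), count_exact(6, 7))  # 16 480
```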
Proof

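Consider a permutation with exactly i pairs summing to K. Just before f(i) is added, the partial sum f(1) - f(2) + \ldots + (-1)^i f(i-1) counts it \binom{i}{1} - \binom{i}{2} + \binom{i}{3} - \ldots + (-1)^i \binom{i}{i-1} times, while f(i) itself counts it \binom{i}{i} = 1 time.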

We look at even and odd i separately.

When i is odd, j and i-j have opposite parity for every 1 \leq j \leq i-1, so the terms \binom{i}{j} and \binom{i}{i-j}, which are equal, have opposite coefficients in the summation.
Each such pair cancels out, making the overall sum 0.
So, we need to give f(i) a coefficient of +1.

Next, suppose i is even. We can’t cancel out terms like before.
Instead, observe that by the binomial theorem,

0^i = (1 + (-1))^i = \sum_{j=0}^i (-1)^j \binom{i}{j} = \binom{i}{0} - \binom{i}{1} + \binom{i}{2} - \ldots + \binom{i}{i}

The right side of that equation is almost exactly the value we have: in fact, our sum consists of exactly the middle i-1 terms (j = 1, \ldots, i-1), negated.
So, if S is the value of our summation, we have 0 = \binom{i}{0} - S + \binom{i}{i}.
This means S = \binom{i}{0} + \binom{i}{i} = 1 + 1 = 2.

So, when i is even we give f(i) the sign -1, since configurations are counted twice but we want to count them once.

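With factorials precomputed, each f(i) can be computed in \mathcal{O}(1) time if 2^i is maintained incrementally, or in \mathcal{O}(\log N) time if it is computed by fast exponentiation. Since M \leq N/2, this is fast enough.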
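A minimal modular sketch of the whole solution, assuming the answer is required modulo 10^9 + 7 as in all three reference solutions. Factorials and inverse factorials are precomputed once up to `MAXN`, an assumed bound on N that should be adjusted to the real constraints. Powers of 2 are maintained incrementally, so each test case costs \mathcal{O}(N). The names `MAXN`, `binom` and `solve` are illustrative and do not come from the original solutions.

```python
MOD = 10**9 + 7
MAXN = 3 * 10**5  # assumed upper bound on N; adjust to the real constraints

# factorials and inverse factorials modulo MOD
fact = [1] * (MAXN + 1)
for i in range(1, MAXN + 1):
    fact[i] = fact[i - 1] * i % MOD
inv_fact = [1] * (MAXN + 1)
inv_fact[MAXN] = pow(fact[MAXN], MOD - 2, MOD)  # Fermat's little theorem
for i in range(MAXN, 0, -1):
    inv_fact[i - 1] = inv_fact[i] * i % MOD  # 1/(i-1)! = i * 1/i!

def binom(n: int, r: int) -> int:
    if r < 0 or r > n:
        return 0
    return fact[n] * inv_fact[r] % MOD * inv_fact[n - r] % MOD

def solve(n: int, k: int) -> int:
    m = max(0, min(2 * n - k + 1, k - 1) // 2)  # number of pairs M, in O(1)
    ans = 0
    pow2 = 1
    for i in range(1, m + 1):
        pow2 = pow2 * 2 % MOD  # 2^i, updated incrementally
        f_i = binom(m, i) * fact[n - i] % MOD * pow2 % MOD
        ans = ans + f_i if i % 2 == 1 else ans - f_i
    return ans % MOD  # Python's % maps a negative total into [0, MOD)

print(solve(4, 5), solve(6, 7))  # 16 480
```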

TIME COMPLEXITY:

\mathcal{O}(N) per testcase.

CODE:

Author's code (Python)
m=10**9+7
FACSIZE=10**6
def power(a,b):
    x=1
    y=a
    while b>0:
        if b&1:
            x=(x*y)%m
        y=(y*y)%m
        b>>=1
    return x%m
def modular_inverse(n): return power(n,m-2)

f=[1]*FACSIZE
invfact=[1]*FACSIZE
def cfact():
    for i in range(2,FACSIZE):
        f[i]=f[i-1]*i%m
    invfact[FACSIZE-1]=modular_inverse(f[FACSIZE-1])
    for i in range(FACSIZE-2,-1,-1):
        invfact[i]=invfact[i+1]*(i+1)%m
def comb(n,k):
    if k<0 or n<k:return 0
    return f[n]*invfact[k]%m*invfact[n-k]%m
cfact()

def solve():
    n,k=map(int,input().split())
    ans=0
    tp=0
    for i in range(1,n+1):
        j=k-i
        if i>=j:break
        if 1<=j<=n:tp+=1
    p2=1
    for p in range(1,tp+1):
        p2=p2*2%m
        # f[p]*f[n-2p]*comb(n-p,p) == (n-p)!, so d ends up as comb(tp,p)*(n-p)!*2^p
        d=comb(tp,p)*f[p]%m
        d=d*p2%m
        d=d*f[n-2*p]%m
        d=d*comb(n-p,p)%m
        if p%2:ans=(ans+d)%m
        else:ans=(ans-d)%m
    print(ans%m)
for _ in range(int(input())):
    solve()
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define md 1000000007
#define N 100001
int modex(int a, int b){
    if(b == 0){
        return 1;
    }
    int res = modex(a, b / 2);
    res *= res;
    res %= md;
    if(b % 2){
        res *= (a % md);
    }
    return res % md;
}
int mod(int a, int b){
    a %= md;
    return (a * modex(b, md - 2)) % md;
}
int pw[N];
int fac[N];
int ncr(int n, int r){
    if(n < r || r < 0){
        return 0;
    }
    return mod(fac[n], fac[n - r] * fac[r]);
}
int32_t main() {
    pw[0] = 1;
    fac[0] = 1;
    for(int i = 1; i < N; i++){
        pw[i] = pw[i - 1] * 2;
        fac[i] = fac[i - 1] * i;
        pw[i] %= md;
        fac[i] %= md;
    }
	int t;
	cin>>t;
	while(t--){
	    int n, k;
	    cin>>n>>k;
	    int cnt = min(2* n - k + 1, k - 1) / 2;
	    int ans = 0;
	    for(int i = 1; i <= cnt; i++){
	        int temp = ncr(cnt, i) * fac[i] * (2 * (i % 2) - 1);
	        temp %= md;
	        temp *= pw[i];
	        temp %= md;
	        temp *= fac[n - 2 * i];
	        temp %= md;
	        temp *= ncr(n - i, i);
	        temp %= md;
	        ans += temp;
	        ans %= md;
	    }
	    if(ans < 0){
	        ans += md;
	    }
	    cout<<ans<<"\n";
	}
}

Editorialist's code (PyPy3)
mod = 10**9 + 7
fac = [1]
for i in range(1, 3 * 10**5): fac.append(fac[-1] * i % mod)
ifac = fac[:]
ifac[-1] = pow(ifac[-1], mod-2, mod)
for i in reversed(range(3 * 10**5 - 1)): ifac[i] = ifac[i+1] * (i + 1) % mod
def C(n, r):
    if n < r or r < 0: return 0
    return fac[n] * ifac[r] % mod * ifac[n-r] % mod

for _ in range(int(input())):
    n, k = map(int, input().split())
    
    # M: number of unordered pairs {i, k - i} with 1 <= i < k - i <= n
    pairs = 0
    for i in range(1, n+1):
        j = k - i
        if i < j and 1 <= j <= n: pairs += 1
    
    ans = 0
    for i in range(1, pairs + 1):
        if n - 2*i < 0: break
        
        # f(i) = C(M, i) * (n - i)! * 2^i; odd i are added, even i subtracted
        ways = C(pairs, i) * fac[n - i] * pow(2, i, mod) % mod
        if i%2 == 1: ans += ways
        else: ans -= ways
    print(ans % mod)