CNTISFN343 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: satyam_343
Testers: apoorv_me, tabr
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Combinatorics

PROBLEM:

Find the number of arrays A of length N for which there exist two permutations P and Q satisfying A_i = \max(P_i, Q_i) for all 1 \leq i \leq N.

EXPLANATION:

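Any valid array A is the pointwise maximum of two permutations. In particular, A cannot contain any element more than twice. A value v occurs exactly once in P and once in Q, and A_j = v requires P_j = v or Q_j = v, so v can appear in A at no more than those two indices.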

Let’s fix K, the number of elements that occur exactly twice.
Note that 0 \leq K \leq \left\lfloor \frac{N}{2} \right\rfloor.
Further, observe that this means exactly K elements must also not appear at all in A - after all, the K elements that appear twice take up 2K positions, and the remaining N - 2K positions are taken up by elements that occur once each, leaving K elements without a place.

Let x_1 \lt x_2 \lt \ldots \lt x_K be the distinct elements that appear twice,
and y_1 \lt y_2 \lt \ldots \lt y_K be the distinct elements that don’t appear at all.
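As a small illustration, the following Python sketch reads off the x_i and y_i from a candidate array A with entries in 1..N. It also checks the two facts above: no value occurs more than twice, and exactly as many values are missing as are doubled. The function name split_counts is hypothetical and is not taken from either solution.

```python
from collections import Counter


def split_counts(a: list[int]) -> tuple[list[int], list[int]]:
    """Return (xs, ys): values of 1..N occurring twice in a, and values missing from a.

    Assumes a has length N with entries in 1..N (hypothetical helper).
    """
    n = len(a)
    freq = Counter(a)  # missing keys read as 0
    # Any pointwise max of two permutations uses each value at most twice.
    assert all(c <= 2 for c in freq.values()), "some value occurs 3+ times"
    xs = sorted(v for v in range(1, n + 1) if freq[v] == 2)
    ys = sorted(v for v in range(1, n + 1) if freq[v] == 0)
    # K doubled values fill 2K slots and N - 2K singles fill the rest,
    # so exactly K values of 1..N are left unused.
    assert len(xs) == len(ys)
    return xs, ys


print(split_counts([3, 3, 2]))  # ([3], [1])
```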

Claim: An array A in which no element occurs more than twice is valid if and only if y_i \lt x_i for every 1 \leq i \leq K.

Proof

If y_i \lt x_i for every i, it’s quite simple to construct valid permutations P and Q.

  • For every index j such that A_j is not one of the x_i (that is, A_j occurs exactly once in A), set P_j = Q_j = A_j.
  • Next, for each i, suppose x_i appears at indices j_1 and j_2.
    Set P_{j_1} = Q_{j_2} = x_i and Q_{j_1} = P_{j_2} = y_i.

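Each of P and Q contains every singly-occurring element, every x_i and every y_i exactly once, so both are permutations. Moreover, A_j = \max(P_j, Q_j) at every index j. This is trivial at the singly-occurring positions, and it holds at the two positions of x_i because y_i \lt x_i.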
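The construction above can be sketched in Python as follows. It assumes A already satisfies y_i < x_i. The function name construct is a hypothetical helper, and indices are 0-based.

```python
from collections import defaultdict


def construct(a: list[int]) -> tuple[list[int], list[int]]:
    """Build permutations P, Q with max(P[j], Q[j]) == a[j] for every j.

    Sketch of the proof's construction; assumes a has entries in 1..N,
    no value more than twice, and y_i < x_i for every i.
    """
    n = len(a)
    positions = defaultdict(list)  # value -> indices where it occurs in a
    for j, v in enumerate(a):
        positions[v].append(j)
    xs = sorted(v for v in range(1, n + 1) if len(positions[v]) == 2)
    ys = sorted(v for v in range(1, n + 1) if len(positions[v]) == 0)
    p, q = a[:], a[:]  # singly-occurring values: P_j = Q_j = A_j
    for x, y in zip(xs, ys):
        assert y < x, "configuration violates y_i < x_i"
        j1, j2 = positions[x]
        # P keeps x at j1 and Q keeps x at j2; y fills the other two slots.
        p[j2] = y
        q[j1] = y
    return p, q


print(construct([4, 4, 2, 2]))  # ([4, 3, 2, 1], [3, 4, 1, 2])
```

For A = [4, 4, 2, 2] we have x = (2, 4) and y = (1, 3). Value 2 is paired with 1, and value 4 with 3.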

Now, for the converse.
Suppose A is a valid array.

Consider some element x_i, and look at its appearance in the permutation P, say at index j.

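  • x_i appears twice in A, but only once in each of P and Q.
    So its two occurrences in A must be at index j (coming from P) and at its index in Q, which differs from j. Hence \max(P_j, Q_j) = x_i, and since Q_j \neq x_i, we get Q_j \lt P_j.
  • Then, if Q_j is not one of the y_i, it must appear in A. It cannot appear at index j (where A_j = x_i), so it must appear at the index j' where P_{j'} = Q_j.
    As before, this forces Q_{j'} \lt P_{j'}, so Q_j is again paired with a strictly smaller element.
  • This ‘chain’ of strictly decreasing values must terminate. It can only stop at a value that does not appear in A, that is, at one of the y_i.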

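So each x_i starts exactly two chains: one from its position in P, and one from its position in Q. Each chain ends, in the other permutation, at one of the y_i that is strictly smaller than the x_i it started from. Two chains never merge, since the predecessor of any value in a chain is uniquely determined (it is the element sitting opposite that value at its index). Therefore each y_i ends at most two chains: at most one in P and at most one in Q.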
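To see the chains concretely, here is a hedged Python sketch. It starts from two permutations P and Q, forms A = max(P, Q), and follows both chains out of each doubled value, asserting that values strictly decrease along the way. The helper chain_ends is hypothetical. The exhaustive loop at the end only checks N = 4, so it is an illustration rather than a proof.

```python
from collections import Counter
from itertools import permutations


def chain_ends(p, q) -> dict[tuple[int, str], int]:
    """Follow both chains from every value occurring twice in A = max(P, Q).

    Returns {(x, "P"): y, (x, "Q"): y2}. The chain starting at x's index in P
    walks through values of Q and ends at a value missing from A; the chain
    starting at x's index in Q is handled symmetrically.
    """
    a = [max(u, w) for u, w in zip(p, q)]
    freq = Counter(a)
    pos_p = {v: j for j, v in enumerate(p)}
    pos_q = {v: j for j, v in enumerate(q)}
    ends = {}
    for x in sorted(v for v, c in freq.items() if c == 2):
        for own_pos, other, label in ((pos_p, q, "P"), (pos_q, p, "Q")):
            v = other[own_pos[x]]  # element opposite x at x's index
            assert v < x
            while v in freq:  # v still appears in A, so the chain continues
                nxt = other[own_pos[v]]  # element opposite v at v's index
                assert nxt < v  # values strictly decrease along a chain
                v = nxt
            ends[(x, label)] = v  # v is missing from A: one of the y_i
    return ends


# Exhaustive check for N = 4: every chain ends at a value missing from A.
for p in permutations(range(1, 5)):
    for q in permutations(range(1, 5)):
        a = set(map(max, p, q))
        assert all(y not in a for y in chain_ends(p, q).values())
```

For example, P = [3, 2, 1] and Q = [2, 1, 3] give A = [3, 2, 3]. The chain out of x_1 = 3 in P is 3 → 2 → 1, and the chain out of it in Q is 3 → 1.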

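Now, suppose y_i \gt x_i for some i (equality is impossible, because x_i appears in A while y_i does not). The 2i chains starting at x_1, x_2, \ldots, x_i all end at values strictly less than x_i. However, the only such values among the y's are y_1, \ldots, y_{i-1}, and together they can end at most 2(i-1) = 2i-2 chains.
This is a contradiction, so y_i \lt x_i must hold for every i.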
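The claim is easy to sanity-check by brute force for small N. The sketch below uses hypothetical helpers achievable and claim_holds. It runs in exponential time, so it is only suitable for N ≤ 5. It enumerates every pair of permutations and compares the set of reachable arrays against the condition from the claim.

```python
from collections import Counter
from itertools import permutations, product


def achievable(n: int) -> set[tuple[int, ...]]:
    """All arrays max(P, Q) over pairs of permutations P, Q of 1..n."""
    perms = list(permutations(range(1, n + 1)))
    return {tuple(map(max, p, q)) for p in perms for q in perms}


def claim_holds(a: tuple[int, ...]) -> bool:
    """Condition from the claim: no value 3+ times, and y_i < x_i for all i."""
    n = len(a)
    freq = Counter(a)
    if any(c > 2 for c in freq.values()):
        return False
    xs = [v for v in range(1, n + 1) if freq[v] == 2]  # doubled values, sorted
    ys = [v for v in range(1, n + 1) if freq[v] == 0]  # missing values, sorted
    return all(y < x for x, y in zip(xs, ys))


# Compare both sides on every array in {1..n}^n for small n.
for n in range(1, 6):
    reachable = achievable(n)
    for a in product(range(1, n + 1), repeat=n):
        assert (a in reachable) == claim_holds(a), a
```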


All that remains is actually counting such configurations (and the number of ways to rearrange them, once the elements are fixed).
Since the x_i and y_i are related to each other, counting becomes a bit easier when we combine them, rather than deal with them separately.
Specifically,

  1. Choose the 2K distinct elements that will become the x_i and y_i.
  2. Choose 2K positions that they will occupy.
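  3. Choose which of these 2K elements will be the x_i (i.e. appear twice in A), and which will be the y_i.
  4. Count the number of ways to arrange the values within the positions: the K values x_i (each used twice) in the 2K chosen positions, and the remaining N - 2K values (each used once) in the other positions.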

Of these, the first, second, and fourth points are elementary combinatorics - only the third requires more observation.

Recall that we wanted y_i \lt x_i to hold for every i.
In other words, going through the 2K elements in increasing order, every prefix must contain at least as many elements chosen as y's as elements chosen as x's.
This is a common setup with many equivalent views. For example, it is exactly a balanced parentheses sequence of length 2K, where the y_i correspond to opening brackets and the x_i to closing ones.
It’s well-known that such configurations are counted by the Catalan numbers.
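The Catalan count for step 3 can also be checked by brute force. The sketch below uses a hypothetical helper, count_labelings. It tries every way to mark K of the 2K sorted elements as y's, keeps those where the i-th y is smaller than the i-th x, and compares the result with \frac{1}{K+1}\binom{2K}{K} for small K.

```python
from itertools import combinations
from math import comb


def count_labelings(k: int) -> int:
    """Count x/y labelings of 2k sorted elements with y_i < x_i for all i."""
    total = 0
    for y_pos in combinations(range(2 * k), k):  # sorted positions marked as y
        x_pos = [i for i in range(2 * k) if i not in y_pos]
        if all(y < x for y, x in zip(y_pos, x_pos)):
            total += 1
    return total


def catalan(k: int) -> int:
    """K-th Catalan number, C(2k, k) / (k + 1), computed exactly."""
    return comb(2 * k, k) // (k + 1)


for k in range(8):
    assert count_labelings(k) == catalan(k)
```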

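So, for a fixed K, the number of arrays A is:

\binom{N}{2K}^2 \cdot C_K \cdot (N - 2K)! \cdot \frac{(2K)!}{2^K}

where C_K = \frac{1}{K+1} \binom{2K}{K} denotes the K-th Catalan number. The factors map onto the steps as follows:
  • \binom{N}{2K}^2 accounts for steps 1 and 2.
  • C_K accounts for step 3.
  • (N - 2K)! \cdot \frac{(2K)!}{2^K} accounts for step 4. The N - 2K singly-occurring values fill their positions in any order. The K doubled values fill the 2K positions in \frac{(2K)!}{2^K} ways, since swapping the two copies of an x_i gives the same array.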
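Cancelling factorials shows that this term simplifies. The simplified form is exactly the quantity the author's code accumulates: it keeps 2^{-K} in the variable div and multiplies the sum by fact[n] at the end.

\binom{N}{2K}^2 \cdot C_K \cdot (N - 2K)! \cdot \frac{(2K)!}{2^K} = \binom{N}{2K} \cdot \frac{N!}{(2K)!\,(N-2K)!} \cdot (N-2K)! \cdot (2K)! \cdot \frac{C_K}{2^K} = N! \cdot \binom{N}{2K} \cdot \frac{C_K}{2^K}

So the answer is N! \cdot \sum_{K=0}^{\lfloor N/2 \rfloor} \binom{N}{2K} \frac{C_K}{2^K}. For example, N = 2 gives 2! \cdot \left(1 + \frac{1}{2}\right) = 3, matching the three arrays [1, 2], [2, 1] and [2, 2].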

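Each term can be computed in \mathcal{O}(1) time after precomputing factorials, inverse factorials and powers of 2^{-1} modulo 998244353. Summing over all 0 \leq K \leq \left\lfloor \frac{N}{2} \right\rfloor therefore gives a linear solution. (Both codes below instead compute modular inverses with fast exponentiation inside the loop, which adds only a logarithmic factor.)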
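Here is a minimal Python sketch of the linear evaluation, assuming the modulus 998244353 used by both solutions below. Unlike those solutions, it avoids a modular exponentiation per term. It builds C_K from inverse factorials and keeps 2^{-K} as a running product. The name count_arrays is hypothetical. For self-containment it rebuilds the factorial tables on every call, whereas a real solution would precompute them once up to the maximum N.

```python
MOD = 998244353


def count_arrays(n: int) -> int:
    """N! * sum over K of C(N, 2K) * C_K / 2^K, modulo MOD, in O(N) time."""
    fact = [1] * (n + 2)
    for i in range(1, n + 2):
        fact[i] = fact[i - 1] * i % MOD
    inv_fact = [1] * (n + 2)
    inv_fact[n + 1] = pow(fact[n + 1], MOD - 2, MOD)  # Fermat inverse, once
    for i in range(n + 1, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % MOD  # 1/(i-1)! = i / i!
    inv2 = (MOD + 1) // 2  # modular inverse of 2
    total, inv2_pow = 0, 1  # inv2_pow holds 2^{-K}
    for k in range(n // 2 + 1):
        # C(N, 2K) = N! / ((2K)! (N-2K)!)
        binom = fact[n] * inv_fact[2 * k] % MOD * inv_fact[n - 2 * k] % MOD
        # C_K = (2K)! / (K! (K+1)!)
        cat = fact[2 * k] * inv_fact[k] % MOD * inv_fact[k + 1] % MOD
        total = (total + binom * cat % MOD * inv2_pow) % MOD
        inv2_pow = inv2_pow * inv2 % MOD
    return total * fact[n] % MOD


print(count_arrays(4))  # 108
```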

TIME COMPLEXITY:

\mathcal{O}(N) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h> 
using namespace std;
#define ll long long
#define nline "\n"
#define all(x) x.begin(),x.end()
const ll MOD=998244353;
const ll MAX=500500;
vector<ll> fact(MAX+2,1),inv_fact(MAX+2,1);
ll binpow(ll a,ll b,ll MOD){
    ll ans=1;
    a%=MOD;  
    while(b){
        if(b&1)
            ans=(ans*a)%MOD;
        b/=2;
        a=(a*a)%MOD;
    }
    return ans;
}
ll inverse(ll a,ll MOD){
    return binpow(a,MOD-2,MOD);
} 
void precompute(ll MOD){
    for(ll i=2;i<MAX;i++){
        fact[i]=(fact[i-1]*i)%MOD;
    }
    inv_fact[MAX-1]=inverse(fact[MAX-1],MOD);
    for(ll i=MAX-2;i>=0;i--){
        inv_fact[i]=(inv_fact[i+1]*(i+1))%MOD;
    }
}
ll nCr(ll a,ll b,ll MOD){
    if((a<0)||(a<b)||(b<0))
        return 0;   
    ll denom=(inv_fact[b]*inv_fact[a-b])%MOD;
    return (denom*fact[a])%MOD;  
}
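// getv(x): the x-th Catalan number, C(2x, x) / (x + 1) mod MOD
ll getv(ll x){
	ll now=(nCr(2*x,x,MOD)*inverse(x+1,MOD))%MOD;
	return now;
}
void solve(){
	ll n; cin>>n;
	// ans accumulates sum over i of C_i * C(n, 2i) / 2^i; div holds 2^{-i}
	ll ans=0,div=1;
	for(ll i=0;i<=(n/2);i++){
		ll now=(getv(i)*div)%MOD;
		now=(now*nCr(n,2*i,MOD))%MOD;
		ans=(ans+now)%MOD;
		div=(div*inverse(2,MOD))%MOD;
	}
	// C(n,2i)^2 * (n-2i)! * (2i)! = n! * C(n,2i), so multiply by n! once
	ans=(ans*fact[n])%MOD;
	cout<<ans<<nline;
}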
int main()                                                                                 
{         
  ios_base::sync_with_stdio(false);                         
  cin.tie(NULL);                                  
  ll test_cases=1;                 
  cin>>test_cases;
  precompute(MOD);
  while(test_cases--){
      solve();
  }
  cout<<fixed<<setprecision(10);
  cerr<<"Time:"<<1000*((double)clock())/(double)CLOCKS_PER_SEC<<"ms\n"; 
} 

Editorialist's code (Python)
mod = 998244353
maxn = 5 * 10**5
fac = [1] * maxn
for i in range(1, maxn): fac[i] = fac[i-1] * i % mod
inv = fac[:]
inv[-1] = pow(inv[-1], mod-2, mod)
for i in reversed(range(maxn-1)): inv[i] = inv[i+1] * (i+1) % mod

def C(n, r):
    if r < 0 or n < r: return 0
    return fac[n] * inv[r] % mod * inv[n-r] % mod

catalan = [0] * 200005
for i in range(200005): catalan[i] = C(2*i, i) * pow(i+1, mod-2, mod) % mod

for _ in range(int(input())):
    n = int(input())
    ans = 0
    for k in range(n//2 + 1):
        # pow(2, k*(mod-2), mod) equals 2^{-k} modulo the prime mod (Fermat's little theorem)
        ans += C(n, 2*k) * C(n, 2*k) * catalan[k] * fac[n - 2*k] * fac[2*k] * pow(2, k*(mod-2), mod)
        ans %= mod
    print(ans)