PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: satyam_343
Tester: yash_daga
Editorialist: iceknight1093
DIFFICULTY:
3097
PREREQUISITES:
Combinatorics, André’s reflection principle
PROBLEM:
You’re given N distinct integers between 1 and 2N as an array P.
Consider all permutations Q of [2N] such that P_i = Q_i for 1 \leq i \leq N.
Find the number of distinct prefix maximum arrays among all such Q.
EXPLANATION:
tl;dr
The answer is
\binom{3N-M}{N} - \binom{3N-M}{N+1},
where M = \max(P_1, P_2, \ldots, P_N).
Most of the difficulty of this problem comes from modelling it correctly.
For now, let’s ignore the prefix of length N we’re given, and just look at prefix maximum arrays as a whole.
Since the prefix maximum array is always going to be sorted, two such arrays will differ if and only if the counts of the elements in them differ.
So, it’s enough for us to consider counts of elements in the prefix maximum array.
Let c_i denote the number of times i occurs in the prefix maximum array. Then, we have the following constraints:
- c_i \geq 0 for every i
- c_1 + c_2 + \ldots + c_i \leq i for every i, because the elements \{1, 2, \ldots, i\} together cannot be prefix maximums more than i times.
- c_1 + c_2 + \ldots + c_{2N} = 2N.
Any assignment of integers to c_i that satisfies these conditions will give us a unique prefix maximum array, and it’s not hard to see that there will always exist a permutation that achieves such an array.
These constraints can be represented as lattice paths!
In particular, we can do the following:
- Start at (0, 0).
- For each 0 \leq i \lt 2N, make one rightward move followed by c_{i+1} upward moves.
So, after i rightward moves, we would’ve made exactly c_1 + c_2 + \ldots + c_i upward moves.
Then,
- c_1 + c_2 + \ldots + c_{2N} = 2N ensures that we will always end up at position (2N, 2N).
- c_1 + c_2 + \ldots + c_i \leq i means that after i rightward moves we have made at most i upward moves, i.e. the path never goes above the line y = x (though it may touch it).
It’s not hard to see that any path from (0, 0) to (2N, 2N) that always stays below the line x = y similarly allows us to assign values to each c_i, so all we need to do is count the number of such paths!
Counting the number of right-up lattice paths that don’t cross the line x = y is a rather well-known problem, and can be done using the reflection principle to subtract ‘bad’ paths from the total number of paths.
In fact, when starting from (0, 0) and ending at (2N, 2N), the answer is exactly the 2N-th Catalan number!
A sketch of the proof for this case (via the reflection principle) can be found in any standard treatment of Catalan numbers.
Let’s return to the original problem.
We are given the first N elements of the permutation, which doesn’t actually change much: in terms of our path, it simply fixes a prefix of our path.
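In particular, let M = \max(P_1, P_2, \ldots, P_N). The first N prefix maxima are all at most M, and M is one of them. So the path so far consists of M rightward moves and N upward moves, and it ends at position (M, N). Since N distinct positive integers have a maximum of at least N, we have M \geq N, so this point does lie on or below the line y = x.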
So, all we need to do is count the number of paths from (M, N) to (2N, 2N) that stay below the line x = y.
By the reflection principle, that works out as follows:
- The total number of paths from (M, N) to (2N, 2N) is \binom{3N-M}{N}.
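- A ‘bad’ path is one that touches the line y = x + 1. Reflect the part of such a path before its first touch across that line; this maps (M, N) to (N-1, M+1). That puts the bad paths in bijection with all paths from (N-1, M+1) to (2N, 2N). There are \binom{3N-M}{N+1} such paths (N+1 rightward moves and 2N-M-1 upward moves).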
This makes the answer simply
\binom{3N-M}{N} - \binom{3N-M}{N+1}.
TIME COMPLEXITY
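\mathcal{O}(N) per test case, after a one-time \mathcal{O}(N_{\max}) precomputation of factorials and inverse factorials modulo 998244353.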
CODE:
Setter's code (C++)
// `optimize` is the pragma name GCC recognizes. The spelling
// `optimization` is an unknown pragma, which GCC ignores (warning only under
// -Wunknown-pragmas), so the intended optimization level never took effect.
// Ofast already includes everything O3 enables.
#pragma GCC optimize("Ofast,unroll-loops")
#include <bits/stdc++.h>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;
using namespace std;
// Short aliases used by the rest of this solution.
// NOTE(review): object-like macros such as `f`, `s`, `mp`, `pb` and `ll`
// replace every later token with exactly that spelling, so no later
// identifier may be named e.g. `s` or `f`.
#define ll long long
// Large sentinel values; neither is referenced elsewhere in this solution.
const ll INF_MUL=1e13;
const ll INF_ADD=1e18;
#define pb push_back
#define mp make_pair
// Plain newline string for output (unlike std::endl it does not flush).
#define nline "\n"
#define f first
#define s second
#define pll pair<ll,ll>
// NOTE(review): the argument is not parenthesised, so all(a+b) would misexpand.
#define all(x) x.begin(),x.end()
#define vl vector<ll>
#define vvl vector<vector<ll>>
#define vvvl vector<vector<vector<ll>>>
// debug(x): when ONLINE_JUDGE is not defined, prints the expression text, a
// space and the value (via the _print overloads) to stderr, then a newline.
// It expands to three separate statements, so wrap it in braces under an `if`.
// When ONLINE_JUDGE is defined it expands to an empty statement.
#ifndef ONLINE_JUDGE
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);
#endif
// Debug printers behind the `debug(x)` macro. Each overload writes its argument
// to std::cerr: scalars as-is, pairs as "{a,b}", and containers as " [ e1 e2 ... ]".
// Non-scalar arguments are taken by const reference and elements are visited
// by reference, so printing never copies the container or its elements.
void _print(long long x){std::cerr<<x;}
void _print(int x){std::cerr<<x;}
void _print(char x){std::cerr<<x;}
void _print(const std::string& x){std::cerr<<x;}

// Mersenne Twister seeded from the steady clock, so the sequence differs between runs.
std::mt19937 rng(std::chrono::steady_clock::now().time_since_epoch().count());

// Forward declarations: each template below calls `_print` on its elements, and
// an unqualified call inside a template only sees overloads declared before it.
// Declaring them all up front lets nested types such as map<int, vector<int>>
// or vector<set<int>> be printed in any combination.
template<class T,class V> void _print(const std::pair<T,V>& p);
template<class T> void _print(const std::vector<T>& v);
template<class T> void _print(const std::set<T>& v);
template<class T> void _print(const std::multiset<T>& v);
template<class T,class V> void _print(const std::map<T,V>& v);

template<class T,class V> void _print(const std::pair<T,V>& p) {
    std::cerr<<"{"; _print(p.first);
    std::cerr<<","; _print(p.second);
    std::cerr<<"}";
}
template<class T> void _print(const std::vector<T>& v) {
    std::cerr<<" [ ";
    for (const auto& i : v) { _print(i); std::cerr<<" "; }
    std::cerr<<"]";
}
template<class T> void _print(const std::set<T>& v) {
    std::cerr<<" [ ";
    for (const auto& i : v) { _print(i); std::cerr<<" "; }
    std::cerr<<"]";
}
template<class T> void _print(const std::multiset<T>& v) {
    std::cerr<< " [ ";
    for (const auto& i : v) { _print(i); std::cerr<<" "; }
    std::cerr<<"]";
}
template<class T,class V> void _print(const std::map<T,V>& v) {
    std::cerr<<" [ ";
    for (const auto& i : v) { _print(i); std::cerr<<" "; }
    std::cerr<<"]";
}
// Order-statistics trees from GNU pb_ds: besides the usual set operations they
// provide find_by_order(k) and order_of_key(x).
using ordered_set = __gnu_pbds::tree<long long, __gnu_pbds::null_type, std::less<long long>,
                                     __gnu_pbds::rb_tree_tag,
                                     __gnu_pbds::tree_order_statistics_node_update>;
// "Multiset" variant: std::less_equal lets equal keys coexist.
// NOTE(review): less_equal is not a strict weak ordering, so no element ever
// compares equivalent to a key. find()/erase() by value therefore cannot locate
// an equal element; erase through find_by_order(order_of_key(x)) instead.
using ordered_multiset = __gnu_pbds::tree<long long, __gnu_pbds::null_type, std::less_equal<long long>,
                                          __gnu_pbds::rb_tree_tag,
                                          __gnu_pbds::tree_order_statistics_node_update>;
// Ordered set of (long long, long long) pairs, ordered lexicographically.
using ordered_pset = __gnu_pbds::tree<std::pair<long long, long long>, __gnu_pbds::null_type,
                                      std::less<std::pair<long long, long long>>,
                                      __gnu_pbds::rb_tree_tag,
                                      __gnu_pbds::tree_order_statistics_node_update>;
//--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
// Prime modulus used for all answers.
constexpr long long MOD = 998244353;
// Size bound for the factorial tables; precompute() fills indices [0, MAX).
constexpr long long MAX = 2000200;
// Factorials and inverse factorials modulo MOD, filled by precompute().
// Every entry starts at 1, which is already correct for indices 0 and 1.
std::vector<long long> fact(MAX + 2, 1);
std::vector<long long> inv_fact(MAX + 2, 1);
// Binary exponentiation: returns a^b reduced modulo MOD (the parameter shadows
// the global constant). b is expected to be non-negative; b == 0 yields 1.
long long binpow(long long a,long long b,long long MOD){
    a %= MOD;
    long long result = 1;
    // Scan b's bits from least significant upward, squaring the base each step.
    for (; b != 0; b /= 2) {
        if (b & 1)
            result = result * a % MOD;
        a = a * a % MOD;
    }
    return result;
}
ll inverse(ll a,ll MOD){
return binpow(a,MOD-2,MOD);
}
void precompute(ll MOD){
for(ll i=2;i<MAX;i++){
fact[i]=(fact[i-1]*i)%MOD;
}
inv_fact[MAX-1]=inverse(fact[MAX-1],MOD);
for(ll i=MAX-2;i>=0;i--){
inv_fact[i]=(inv_fact[i+1]*(i+1))%MOD;
}
}
ll nCr(ll a,ll b,ll MOD){
if((a<0)||(a<b)||(b<0))
return 0;
ll denom=(inv_fact[b]*inv_fact[a-b])%MOD;
return (denom*fact[a])%MOD;
}
// Solves one test case: reads N followed by the N values P_1..P_N, then prints
// C(3N-M, N) - C(3N-M, N+1) (mod MOD), where M = max(P). That is the number of
// lattice paths from (M, N) to (2N, 2N) minus the reflected "bad" ones.
void solve(){
    long long n;
    cin >> n;
    // Values are positive per the statement, so 0 is a safe starting maximum.
    long long mx = 0;
    for (long long idx = 0; idx < n; ++idx) {
        long long value;
        cin >> value;
        if (value > mx)
            mx = value;
    }
    const long long top = 3 * n - mx;
    const long long total = nCr(top, n, MOD);
    const long long bad = nCr(top, n + 1, MOD);
    cout << (total - bad + MOD) % MOD << nline;
}
int main()
{
    // Fast C++ stream I/O: no C stdio synchronisation, cin not tied to cout.
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
#ifndef ONLINE_JUDGE
    // Local runs read input.txt and write output.txt / error.txt.
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    freopen("error.txt", "w", stderr);
#endif
    long long test_cases = 1;
    cin >> test_cases;
    // Build the factorial tables once, before any test case is solved.
    precompute(MOD);
    while (test_cases-- != 0)
        solve();
    // Formatting flags on cout are set after all output; elapsed CPU time goes to stderr.
    cout << fixed << setprecision(10);
    cerr << "Time:" << 1000 * static_cast<double>(clock()) / static_cast<double>(CLOCKS_PER_SEC) << "ms\n";
}
Tester's code (C++)
//clear adj and visited vector declared globally after each test case
//check for long long overflow
//In case of a close MLE, change the language to C++17 or C++14
//Check ans for n=1
// Code-generation pragmas: allow AVX2 instructions, O3 and loop unrolling.
// NOTE(review): AVX2 code may fail with an illegal instruction on judge CPUs without AVX2.
#pragma GCC target ("avx2")
#pragma GCC optimize ("O3")
#pragma GCC optimize ("unroll-loops")
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
// Every later `int` token becomes `long long` (which is why main is declared int32_t).
#define int long long
// Fast I/O: unsync C/C++ streams, untie cin/cout, and print doubles with
// dbl::max_digits10 digits (dbl is the numeric_limits<double> typedef below).
#define IOS std::ios::sync_with_stdio(false); cin.tie(NULL);cout.tie(NULL);cout.precision(dbl::max_digits10);
#define pb push_back
// Modulus as a long long literal.
#define mod 998244353ll
#define lld long double
// Because of the `int` macro these are map/pair of long long.
#define mii map<int, int>
#define pii pair<int, int>
#define ll long long
#define ff first
#define ss second
#define all(x) (x).begin(), (x).end()
// Half-open loop: i runs from x up to y-1.
#define rep(i,x,y) for(int i=x; i<y; i++)
// NOTE(review): any later `fill(...)` call is rewritten by this macro, so std::fill
// cannot be called by that name below; sizeof(a) is only right for true arrays.
#define fill(a,b) memset(a, b, sizeof(a))
#define vi vector<int>
#define setbits(x) __builtin_popcountll(x)
// Prints dp[0..n][0..m] row by row to stdout.
#define print2d(dp,n,m) for(int i=0;i<=n;i++){for(int j=0;j<=m;j++)cout<<dp[i][j]<<" ";cout<<"\n";}
typedef std::numeric_limits< double > dbl;
using namespace __gnu_pbds;
using namespace std;
// Order-statistics set of (64-bit, via the macro) integers.
typedef tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update> indexed_set;
//member functions :
//1. order_of_key(k) : number of elements strictly less than k
//2. find_by_order(k) : iterator to the k-th smallest element (0-indexed)
// Size of the factorial tables (indices 0 .. N-1).
const long long N = 2000005;
// Large sentinel values.
const long long INF = 2000000000000000000;
const long long inf = 2e9 + 5;
// pi, stored as a long double but initialised from a double literal.
long double pi = 3.1415926535897932;
// Least common multiple of a and b. Dividing by the gcd before multiplying means
// the product a*b is never formed. Both arguments 0 divides by zero.
long long lcm(long long a, long long b)
{
    return a / std::__gcd(a, b) * b;
}
// Modular exponentiation a^b mod p by repeated squaring.
// Note: a == 0 returns 0 for every b, including b == 0.
long long power(long long a, long long b, long long p)
{
    if (a == 0)
        return 0;
    a %= p;
    long long res = 1;
    for (; b > 0; b >>= 1) {
        if (b & 1)
            res = res * a % p;
        a = a * a % p;
    }
    return res;
}
// Mersenne Twister seeded from the steady clock, so values differ between runs.
std::mt19937 rng(std::chrono::steady_clock::now().time_since_epoch().count());
// Returns a uniformly distributed integer in the closed range [l, r].
long long getRand(long long l, long long r)
{
    return std::uniform_int_distribution<long long>(l, r)(rng);
}
int fact[N], inv[N];
void pre()
{
fact[0]=inv[0]=1;
rep(i,1,N)
fact[i]=(fact[i-1]*i)%mod;
rep(i,1,N)
inv[i]=power(fact[i], mod-2, mod);
}
// C(n, r) modulo `mod` from the fact[]/inv[] tables (pre() must have run).
// Returns 0 for n < 0, r < 0 or r > n; C(n, n) short-circuits to 1.
int nCr(int n, int r)
{
    if (n < 0 || r < 0 || r > n)
        return 0;
    if (r == n)
        return 1;
    const int partial = fact[n] * inv[r] % mod;
    return partial * inv[n - r] % mod;
}
int32_t main()
{
    IOS;
    pre();
    int t;
    cin >> t;
    while (t--)
    {
        int n;
        cin >> n;
        // M = max(P_1..P_n). Values are at least 1 per the statement, so 0 is a safe start.
        int mx = 0;
        for (int k = 0; k < n; k++)
        {
            int a;
            cin >> a;
            if (a > mx)
                mx = a;
        }
        // All paths from (M, N) to (2N, 2N) minus the reflected bad ones.
        const int total = nCr(3 * n - mx, n);
        const int bad = nCr(3 * n - mx, n + 1);
        cout << (total - bad + mod) % mod << '\n';
    }
}
Editorialist's code (Python)
mod = 998244353
lim = 2 * 10**6 + 5

# fac[i] = i! mod `mod` for 0 <= i < lim.
fac = [1] * lim
for i in range(1, lim):
    fac[i] = fac[i - 1] * i % mod

# invf[i] = (i!)^{-1} mod `mod`: one Fermat inverse for the last entry, then
# walk downward with (i!)^{-1} = ((i+1)!)^{-1} * (i+1). invf[0] stays 1.
invf = [1] * lim
invf[-1] = pow(fac[-1], mod - 2, mod)
for i in range(lim - 2, 0, -1):
    invf[i] = invf[i + 1] * (i + 1) % mod
def C(n, r):
    """Binomial coefficient C(n, r) modulo `mod`; 0 when r < 0 or r > n."""
    if r < 0 or r > n:
        return 0
    return fac[n] * invf[r] % mod * invf[n - r] % mod
# Per test: read N and P, then print C(3N-M, N) - C(3N-M, N+1) mod `mod`, M = max(P).
for _ in range(int(input())):
    n = int(input())
    p = list(map(int, input().split()))
    m = max(p)
    ans = C(3 * n - m, n) - C(3 * n - m, n + 1)
    print(ans % mod)