PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: satyam_343
Testers: iceknight1093, yash_daga
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
Combinatorics
PROBLEM:
For two binary strings x and y of length N, define F(x, y) as follows:
- For a permutation P of \{1, 2, \ldots, N\}, let C and D be defined as:
  - C_i = x_{P_i}
  - D_i = y_{P_i}
- F(x, y) is the maximum value of \min(\text{LNDS}(C), \text{LNDS}(D)) across all permutations P.
Let S denote the set of all binary strings of length N. Compute \sum_{x \in S} \sum_{y \in S} F(x, y), modulo 998244353.
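To make the definition concrete, here is a small brute-force sketch (our own, not taken from the reference solutions) that evaluates F(x, y) by trying every permutation P. It assumes the strings are given as lists of 0/1 integers and uses 0-indexed positions; the helper names `lnds` and `brute_f` are hypothetical. The running time is exponential in N, so it is only useful for cross-checking tiny cases.

```python
from itertools import permutations


def lnds(bits):
    """Longest non-decreasing subsequence of a 0/1 sequence.

    Such a subsequence is some zeros followed by some ones, so take the best
    split: zeros up to the split plus ones after it.
    """
    ones_total = sum(bits)
    best, zeros, ones = ones_total, 0, 0  # split before the first element
    for b in bits:
        if b == 0:
            zeros += 1
        else:
            ones += 1
        best = max(best, zeros + ones_total - ones)
    return best


def brute_f(x, y):
    """F(x, y) straight from the definition: try every permutation P."""
    n = len(x)
    best = 0
    for p in permutations(range(n)):
        c = [x[i] for i in p]  # C_i = x_{P_i}
        d = [y[i] for i in p]  # D_i = y_{P_i}
        best = max(best, min(lnds(c), lnds(d)))
    return best


print(brute_f([1, 0], [0, 1]))  # 1
```

The LNDS helper runs in linear time because a non-decreasing binary subsequence is always a block of zeros followed by a block of ones.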
EXPLANATION:
Let’s first figure out how to quickly compute F(x, y) for a fixed x and y.
Notice that the values x_i and y_i are essentially ‘tied’ to each other, i.e., they must be moved to the same final position.
This gives us 4 types of positions, depending on which of x_i and y_i are 0/1.
Let c_{00} denote the number of positions such that x_i = y_i = 0. I’ll call these 00-positions.
Similarly define c_{01}, c_{10}, c_{11}.
We need to figure out how to arrange these 4 types optimally to maximize the length of the minimum LNDS.
- It’s always optimal to place 00-positions at the start of the array, since they’ll contribute to the LNDS of both strings.
- For the same reason, it’s optimal to place 11-positions at the end of the array.
- Together, they ensure the minimum LNDS has length at least c_{00} + c_{11}.
- Now we have to deal with the 01- and 10-positions.
After playing around with them a little, one can see that the best we can get is \max(c_{01}, c_{10}) + \frac{\min(c_{01}, c_{10})}{2} (where the division is floor division).
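Putting these observations together, F(x, y) can be computed from the four counts alone. A minimal sketch, assuming x and y are 0/1 lists of equal length (the function name `f_closed_form` is hypothetical):

```python
def f_closed_form(x, y):
    """F(x, y) from the counts of 00-, 01-, 10- and 11-positions.

    00- and 11-positions always contribute to both LNDS values; the 01- and
    10-positions contribute max(c01, c10) + floor(min(c01, c10) / 2).
    """
    c00 = c01 = c10 = c11 = 0
    for a, b in zip(x, y):
        if a == 0 and b == 0:
            c00 += 1
        elif a == 0 and b == 1:
            c01 += 1
        elif a == 1 and b == 0:
            c10 += 1
        else:
            c11 += 1
    return c00 + c11 + max(c01, c10) + min(c01, c10) // 2


print(f_closed_form([1, 0, 1, 1], [0, 1, 0, 1]))  # 3
```

Only the counts matter, not where the positions are, which is what makes the counting in the second half of the editorial possible.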
Proof
Our claim is that if we have x 01-positions and y 10-positions, the longest possible minimum LNDS has length \max(x, y) + \left\lfloor \frac{\min(x, y)}{2} \right\rfloor.
This will be proved in two parts: a lower bound and an upper bound.
Lower bound
First, let’s prove the lower bound by constructing a sequence that has this as its minimum LNDS.
Without loss of generality, let x \geq y, so \max(x, y) = x and \min(x, y) = y.
- First place y/2 occurrences of 10.
- Then place all x occurrences of 01.
- Finally, place all remaining occurrences of 10.
So, the first string looks like \underbrace{111\ldots 11}_{y/2} \ \underbrace{00\ldots 00}_{x}\ \underbrace{11\ldots 11}_{y-y/2} and the second looks like \underbrace{000\ldots 00}_{y/2} \ \underbrace{11\ldots 11}_{x}\ \underbrace{00\ldots 00}_{y-y/2}
The last x + (y - y/2) characters of the first string are non-decreasing, and x + (y - y/2) \geq x + y/2; likewise, the first y/2 + x characters of the second string are non-decreasing.
This gives us the required lower bound: the minimum LNDS has length at least x + y/2.
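The construction can be checked mechanically. The sketch below (hypothetical helper names, assuming x \geq y as in the proof) builds both strings in the stated order and asserts that the smaller of the two LNDS lengths is at least x + y/2 for all small x and y.

```python
def lnds(bits):
    """Longest non-decreasing subsequence of a 0/1 sequence (zeros, then ones)."""
    ones_total = sum(bits)
    best, zeros, ones = ones_total, 0, 0
    for b in bits:
        if b == 0:
            zeros += 1
        else:
            ones += 1
        best = max(best, zeros + ones_total - ones)
    return best


def lower_bound_arrangement(x, y):
    """Strings produced by the construction, assuming x >= y.

    x = number of 01-positions, y = number of 10-positions. Order:
    y//2 copies of 10, then all x copies of 01, then the remaining 10s.
    """
    assert x >= y
    pairs = [(1, 0)] * (y // 2) + [(0, 1)] * x + [(1, 0)] * (y - y // 2)
    first = [a for a, _ in pairs]
    second = [b for _, b in pairs]
    return first, second


for x in range(8):
    for y in range(x + 1):
        first, second = lower_bound_arrangement(x, y)
        assert min(lnds(first), lnds(second)) >= x + y // 2
```

For example, x = y = 2 gives the strings 1001 and 0110, both with LNDS length 3 = x + y/2.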
Upper bound
Now for the upper bound.
Consider a configuration where the first binary string has an LNDS of length \gt x + y/2.
Since the first string contains only x zeros (one per 01-position), the last y/2 + 1 characters of this LNDS must all be 1, and everything after them must be 0.
Further, we can also assume that there are no zeros between these ones (any such zeros don’t contribute to the LNDS anyway, so we can move them to the end, which only helps the second string).
Suppose there are z zeros after these y/2 + 1 positions.
The LNDS of the first string then has length at most x+y-z, since the last z positions can’t contribute.
Now let’s look at the LNDS of the second string.
- If the LNDS doesn’t include any of these y/2+1 positions, then it has length at most x + y/2, because only x + y - (y/2 + 1) \leq x + y/2 positions remain.
- If the LNDS does include one of them, it might as well include them all: our setup guarantees that the values at these positions are 0 in the second string, and there are no ones in between them.
- In particular, the LNDS has length at most y + z (where z is as defined above).
For the minimum to exceed x + y/2, we’d need y + z \gt x + y/2, which gives z \geq x - y/2 (since 2\cdot(y/2) + 1 \geq y under floor division).
In particular, the LNDS of the first string then has length at most x + y - z \leq y + y/2 \leq x + y/2, a contradiction to us assuming it was strictly greater.
So, the minimum LNDS length cannot exceed x + y/2, and our upper bound is proved as well.
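As a sanity check of both bounds together, the following exhaustive search (our own sketch, feasible only for small x and y) tries every placement of the 01-positions among the x + y slots and compares the best achievable minimum LNDS with \max(x, y) + \lfloor \min(x, y)/2 \rfloor. Only the set of slots holding 01 matters, so it enumerates those sets with `itertools.combinations` instead of all (x + y)! permutations; the function names are hypothetical.

```python
from itertools import combinations


def lnds(bits):
    """Longest non-decreasing subsequence of a 0/1 sequence (zeros, then ones)."""
    ones_total = sum(bits)
    best, zeros, ones = ones_total, 0, 0
    for b in bits:
        if b == 0:
            zeros += 1
        else:
            ones += 1
        best = max(best, zeros + ones_total - ones)
    return best


def best_min_lnds(x, y):
    """Best min(LNDS(first), LNDS(second)) over all orders of x 01s and y 10s."""
    n = x + y
    best = 0
    for slots in combinations(range(n), x):
        chosen = set(slots)  # slots holding a 01-position
        first = [0 if i in chosen else 1 for i in range(n)]
        second = [1 if i in chosen else 0 for i in range(n)]
        best = max(best, min(lnds(first), lnds(second)))
    return best


for x in range(7):
    for y in range(7):
        assert best_min_lnds(x, y) == max(x, y) + min(x, y) // 2
```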
With this information in hand, let’s compute the answer.
Notice that the score of a pair of strings really depends on only two things:
- How many positions are equal (x_i = y_i) and how many differ.
- Among the different positions, how many of them are 01 and how many are 10.
Since N\leq 2000, we can afford to iterate over both.
Suppose there are k positions such that x_i = y_i, and m of the remaining N-k positions are 01-positions.
Then,
- The score of such a pair is k + \max(x, y) + \min(x, y)/2, where, as in the proof, x = m and y = N-k-m are the numbers of 01- and 10-positions. Let this be \text{len}.
- The number of such pairs can be computed as follows:
  - Fix the set of k equal positions, which can be done in \binom{N}{k} ways.
  - For each equal position, fix whether it’s 00 or 11. This is 2^k choices.
  - Of the remaining N-k positions, choose which m of them are 01-positions. This is \binom{N-k}{m} choices.
  - Multiply all three values above to obtain the number of pairs of strings, say \text{count}.
- Then, the answer increases by \text{len}\times \text{count}, since \text{count} pairs of strings have an answer of \text{len}.
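The counting formula can be verified against direct enumeration for a tiny N. In this sketch (hypothetical helper names, exact integers via Python's `math.comb`, no modulus), `pair_count(n, k, m)` is \binom{N}{k} 2^k \binom{N-k}{m}, and summing it over all (k, m) must account for all 4^N ordered pairs.

```python
from itertools import product
from math import comb


def pair_count(n, k, m):
    """Ordered pairs (x, y) with k equal positions and m 01-positions."""
    return comb(n, k) * 2**k * comb(n - k, m)


def brute_count(n, k, m):
    """Same quantity by enumerating all 4^n ordered pairs (tiny n only)."""
    total = 0
    for x in product((0, 1), repeat=n):
        for y in product((0, 1), repeat=n):
            equal = sum(a == b for a, b in zip(x, y))
            zero_one = sum(a == 0 and b == 1 for a, b in zip(x, y))
            if equal == k and zero_one == m:
                total += 1
    return total


n = 4
assert sum(pair_count(n, k, m) for k in range(n + 1) for m in range(n - k + 1)) == 4**n
for k in range(n + 1):
    for m in range(n - k + 1):
        assert pair_count(n, k, m) == brute_count(n, k, m)
print(pair_count(4, 2, 1))  # 48
```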
Do this for all k from 0 to N and m from 0 to N-k to obtain the final answer.
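Here is the whole \mathcal{O}(N^2) summation with exact integers, alongside a brute force over all pairs and all permutations for N \leq 4. This is our own sketch for checking the reasoning, not one of the reference solutions; the reduction modulo 998244353 is deliberately left out, and the function names are hypothetical.

```python
from itertools import permutations, product
from math import comb


def total_score(n):
    """Sum of F(x, y) over all ordered pairs, using the O(N^2) double loop."""
    ans = 0
    for k in range(n + 1):  # k = number of positions with x_i == y_i
        for m in range(n - k + 1):  # m = number of 01-positions
            c01, c10 = m, n - k - m
            length = k + max(c01, c10) + min(c01, c10) // 2
            count = comb(n, k) * 2**k * comb(n - k, m)
            ans += length * count
    return ans


def lnds(bits):
    """Longest non-decreasing subsequence of a 0/1 sequence (zeros, then ones)."""
    ones_total = sum(bits)
    best, zeros, ones = ones_total, 0, 0
    for b in bits:
        if b == 0:
            zeros += 1
        else:
            ones += 1
        best = max(best, zeros + ones_total - ones)
    return best


def brute_total(n):
    """Same sum by brute force over all pairs and all permutations (tiny n only)."""
    total = 0
    for x in product((0, 1), repeat=n):
        for y in product((0, 1), repeat=n):
            total += max(
                min(lnds([x[i] for i in p]), lnds([y[i] for i in p]))
                for p in permutations(range(n))
            )
    return total


for n in range(5):
    assert total_score(n) == brute_total(n)
print(total_score(2))  # 30
```

For N = 2, 14 of the 16 pairs score 2 and the two pairs (01, 10) and (10, 01) score 1, giving 30.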
For a fixed k and m, we need to compute a couple of binomial coefficients and a power of 2.
The former can be done by precomputing factorials/inverse factorials (or even precomputing all 2000^2 relevant binomials using Pascal’s formula), while the latter can be done using binary exponentiation.
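A sketch of the modular version, assuming (as all three reference codes do) that the answer is taken modulo 998244353 and that N \leq 2000. Inverse factorials come from Fermat's little theorem via Python's three-argument `pow`, and a small Pascal's-rule table is used to cross-check the factorial-based binomials. The names `precompute`, `ncr`, `pascal` and `solve` are our own.

```python
MOD = 998244353  # modulus used by all three reference solutions


def precompute(limit):
    """Factorials, inverse factorials and powers of two modulo MOD, up to limit."""
    fact = [1] * (limit + 1)
    for i in range(1, limit + 1):
        fact[i] = fact[i - 1] * i % MOD
    inv_fact = [1] * (limit + 1)
    # Fermat's little theorem: a^(MOD-2) is the inverse of a, since MOD is prime.
    inv_fact[limit] = pow(fact[limit], MOD - 2, MOD)
    for i in range(limit, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % MOD  # 1/(i-1)! = i * 1/i!
    pow2 = [1] * (limit + 1)
    for i in range(1, limit + 1):
        pow2[i] = pow2[i - 1] * 2 % MOD
    return fact, inv_fact, pow2


def ncr(a, b, fact, inv_fact):
    """Binomial coefficient C(a, b) modulo MOD (0 outside 0 <= b <= a)."""
    if b < 0 or b > a:
        return 0
    return fact[a] * inv_fact[b] % MOD * inv_fact[a - b] % MOD


def pascal(limit):
    """All C(i, j) for i, j <= limit via Pascal's rule, as in the editorialist's code."""
    table = [[0] * (limit + 1) for _ in range(limit + 1)]
    for i in range(limit + 1):
        table[i][0] = 1
        for j in range(1, i + 1):
            table[i][j] = (table[i - 1][j] + table[i - 1][j - 1]) % MOD
    return table


def solve(n, fact, inv_fact, pow2):
    """Sum of F over all ordered pairs of length-n strings, modulo MOD."""
    ans = 0
    for k in range(n + 1):
        for m in range(n - k + 1):
            c01, c10 = m, n - k - m
            length = k + max(c01, c10) + min(c01, c10) // 2
            count = ncr(n, k, fact, inv_fact) * pow2[k] % MOD * ncr(n - k, m, fact, inv_fact) % MOD
            ans = (ans + length * count) % MOD
    return ans


fact, inv_fact, pow2 = precompute(2000)
small = pascal(50)
assert all(ncr(i, j, fact, inv_fact) == small[i][j] for i in range(51) for j in range(i + 1))
print(solve(2, fact, inv_fact, pow2))  # 30
```

Precomputing powers of two in a table, as done here, avoids the extra \log factor of binary exponentiation inside the double loop.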
TIME COMPLEXITY:
\mathcal{O}(N^2) or \mathcal{O}(N^2 \log{MOD}) per testcase.
CODE:
Setter's code (C++)
//#pragma GCC target ("avx2")
#pragma GCC optimize ("O3")
#pragma GCC optimize ("unroll-loops")
#include <bits/stdc++.h>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;
using namespace std;
#define ll long long
const ll INF_MUL=1e13;
const ll INF_ADD=1e18;
#define pb push_back
#define mp make_pair
#define nline "\n"
#define f first
#define s second
#define pll pair<ll,ll>
#define all(x) x.begin(),x.end()
#define vl vector<ll>
#define vvl vector<vector<ll>>
#define vvvl vector<vector<vector<ll>>>
#ifndef ONLINE_JUDGE
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);
#endif
void _print(ll x){cerr<<x;}
void _print(char x){cerr<<x;}
void _print(string x){cerr<<x;}
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
template<class T,class V> void _print(pair<T,V> p) {cerr<<"{"; _print(p.first);cerr<<","; _print(p.second);cerr<<"}";}
template<class T>void _print(vector<T> v) {cerr<<" [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T>void _print(set<T> v) {cerr<<" [ "; for (T i:v){_print(i); cerr<<" ";}cerr<<"]";}
template<class T>void _print(multiset<T> v) {cerr<< " [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T,class V>void _print(map<T, V> v) {cerr<<" [ "; for(auto i:v) {_print(i);cerr<<" ";} cerr<<"]";}
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
//--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
const ll MOD=998244353;
const ll MAX=500500;
vector<ll> fact(MAX+2,1),inv_fact(MAX+2,1);
ll binpow(ll a,ll b,ll MOD){
    ll ans=1;
    a%=MOD;
    while(b){
        if(b&1)
            ans=(ans*a)%MOD;
        b/=2;
        a=(a*a)%MOD;
    }
    return ans;
}
ll inverse(ll a,ll MOD){
    return binpow(a,MOD-2,MOD);
}
void precompute(ll MOD){
    for(ll i=2;i<MAX;i++){
        fact[i]=(fact[i-1]*i)%MOD;
    }
    inv_fact[MAX-1]=inverse(fact[MAX-1],MOD);
    for(ll i=MAX-2;i>=0;i--){
        inv_fact[i]=(inv_fact[i+1]*(i+1))%MOD;
    }
}
ll nCr(ll a,ll b,ll MOD){
    if((a<0)||(a<b)||(b<0))
        return 0;
    ll denom=(inv_fact[b]*inv_fact[a-b])%MOD;
    return (denom*fact[a])%MOD;
}
vector<ll> power(MAX,1);
void solve(){
    ll n; cin>>n;
    ll ans=(n*power[2*n])%MOD;
    for(ll l=0;l<=n;l++){
        for(ll r=0;l+r<=n;r++){
            ll now=min(l,r)+1;
            now/=2;
            ll ways=nCr(n,l,MOD)*nCr(n-l,r,MOD);
            ways%=MOD;
            ways=(ways*power[n-l-r])%MOD;
            ans=(ans-now*ways)%MOD;
        }
    }
    ans=(ans+MOD)%MOD;
    cout<<ans<<nline;
    return;
}
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
#ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    freopen("error.txt", "w", stderr);
#endif
    ll test_cases=1;
    cin>>test_cases;
    precompute(MOD);
    for(ll i=1;i<MAX;i++){
        power[i]=(power[i-1]*2)%MOD;
    }
    while(test_cases--){
        solve();
    }
    cout<<fixed<<setprecision(10);
    cerr<<"Time:"<<1000*((double)clock())/(double)CLOCKS_PER_SEC<<"ms\n";
}
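The setter's code sums a complementary quantity. Since k + c_{01} + c_{10} = N, the score equals N - \lceil \min(c_{01}, c_{10})/2 \rceil, so the answer is N \cdot 4^N minus \lceil \min(l, r)/2 \rceil (the setter's `now`) times the number of pairs with l positions of one differing type and r of the other. The count \binom{N}{l}\binom{N-l}{r}2^{N-l-r} is the same multinomial as \binom{N}{k}2^k\binom{N-k}{m}. The Python sketch below, our own exact-integer transcription with hypothetical function names, checks that the two forms agree for small N.

```python
from math import comb


def direct(n):
    """Editorial form: sum over (k, m) of len * C(n, k) * 2^k * C(n-k, m)."""
    ans = 0
    for k in range(n + 1):
        for m in range(n - k + 1):
            c01, c10 = m, n - k - m
            length = k + max(c01, c10) + min(c01, c10) // 2
            ans += length * comb(n, k) * 2**k * comb(n - k, m)
    return ans


def complementary(n):
    """Setter-style form: n * 4^n minus ceil(min(l, r) / 2) for every pair."""
    ans = n * 4**n
    for l in range(n + 1):
        for r in range(n - l + 1):
            ways = comb(n, l) * comb(n - l, r) * 2 ** (n - l - r)
            ans -= (min(l, r) + 1) // 2 * ways  # (t + 1) // 2 == ceil(t / 2)
    return ans


for n in range(30):
    assert direct(n) == complementary(n)
```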
Tester's code (C++)
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int mod=998244353, N=2005;
int power(int a, int b, int p) {
    if(a==0)
        return 0;
    int res=1;
    a%=p;
    while(b>0) {
        if(b&1)
            res=(1ll*res*a)%p;
        b>>=1;
        a=(1ll*a*a)%p;
    }
    return res;
}
int fact[N], inv[N], po2[N];
void pre()
{
    fact[0]=inv[0]=po2[0]=1;
    for(int i=1;i<N;i++)
        po2[i]=(po2[i-1]*2)%mod;
    for(int i=1;i<N;i++)
        fact[i]=(fact[i-1]*i)%mod;
    for(int i=1;i<N;i++)
        inv[i]=power(fact[i], mod-2, mod);
}
int nCr(int n, int r)
{
    if(min(n, r)<0 || r>n)
        return 0;
    if(n==r)
        return 1;
    return (((fact[n]*inv[r])%mod)*inv[n-r])%mod;
}
int32_t main() {
    pre();
    int t;
    cin>>t;
    while(t--)
    {
        int n, ans=0;
        cin>>n;
        for(int i=0;i<=n;i++)
        {
            for(int j=0;(j+i)<=n;j++)
            {
                int same=i, x=j, y=n-same-x;
                int count=(nCr(n, same)*po2[i]%mod*nCr(n-same, x))%mod;
                int fxy=(same + max(x, y) + min(x, y)/2);
                ans=(ans + (count*fxy))%mod;
            }
        }
        cout<<ans<<"\n";
    }
}
Editorialist's code (Python)
mod = 998244353
maxN = 2005
C = [ [0 for _ in range(maxN)] for _ in range(maxN)]
for i in range(maxN):
    C[i][0] = 1
    for j in range(1, i+1):
        C[i][j] = (C[i-1][j] + C[i-1][j-1]) % mod
sumn = 0
for i in range(int(input())):
    n = int(input())
    ans = 0
    for same in range(n+1):
        dif = n - same
        # fix which positions are the same: C(n, same) ways
        # fix which values they take: 2^same ways
        # fix number of 01 positions, x: C(dif, x) ways
        # length = same + max(x, dif-x) + min(x, dif-x)/2
        ways = (C[n][same] * pow(2, same, mod)) % mod
        for x in range(dif+1):
            y = dif - x
            ans += (same + max(x, y) + min(x, y)//2)*C[dif][x]*ways
            ans %= mod
    sumn += n
    print(ans)