COUNTISFUN - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: satyam_343
Testers: iceknight1093, yash_daga
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Combinatorics

PROBLEM:

For two binary strings x and y of length N, define F(x, y) as follows:

  • For a permutation P of \{1, 2, \ldots, N\}, let C and D be defined as:
    • C_i = x_{P_i}
    • D_i = y_{P_i}
  • F(x, y) is the maximum value of \min(\text{LNDS}(C), \text{LNDS}(D)) across all permutations P.

Let S denote the set of all binary strings of length N. Compute

\sum_{x\in S} \sum_{y\in S} F(x, y)

EXPLANATION:

Let’s first figure out how to quickly compute F(x, y) for a fixed x and y.
Notice that the values x_i and y_i are essentially ‘tied’ to each other, i.e., they must be moved to the same final position.
This gives us 4 types of positions, depending on which of x_i and y_i are 0/1.

Let c_{00} denote the number of positions such that x_i = y_i = 0. I’ll call these 00-positions.
Similarly define c_{01}, c_{10}, c_{11}.
We need to figure out how to arrange these 4 types optimally to maximize the length of the minimum LNDS.

  • It’s always optimal to place 00-positions at the start of the array, since they’ll contribute to the LNDS of both strings.
  • For the same reason, it’s optimal to place 11-positions at the end of the array.
    • Together, they ensure the minimum LNDS has length at least c_{00} + c_{11}.
  • Now we have to deal with the 01- and 10-positions.
    Playing around a little with them, it can be seen that the best these positions can contribute is \max(c_{01}, c_{10}) + \left\lfloor \frac{\min(c_{01}, c_{10})}{2} \right\rfloor.
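The observations above can be sanity-checked on tiny inputs. The following is a minimal sketch, not the reference solution: the helper names `lnds_binary`, `f_closed_form` and `f_brute` are made up for illustration, and strings are assumed to be given as lists of 0/1 integers. It compares the closed form against trying every permutation.

```python
from itertools import permutations, product


def lnds_binary(s):
    """Longest non-decreasing subsequence of a 0/1 list.

    For a binary sequence this is the best split point: zeros taken
    from the prefix plus ones taken from the suffix.
    """
    zeros_before = 0
    ones_after = sum(s)
    best = ones_after  # split before the first element
    for v in s:
        if v == 0:
            zeros_before += 1
        else:
            ones_after -= 1
        best = max(best, zeros_before + ones_after)
    return best


def f_closed_form(x, y):
    """F(x, y) from the four position counts (hypothetical helper)."""
    pairs = list(zip(x, y))
    c00 = pairs.count((0, 0))
    c01 = pairs.count((0, 1))
    c10 = pairs.count((1, 0))
    c11 = pairs.count((1, 1))
    return c00 + c11 + max(c01, c10) + min(c01, c10) // 2


def f_brute(x, y):
    """F(x, y) by trying every permutation; only feasible for tiny N."""
    n = len(x)
    return max(
        min(lnds_binary([x[p] for p in perm]), lnds_binary([y[p] for p in perm]))
        for perm in permutations(range(n))
    )


if __name__ == "__main__":
    for n in range(1, 5):
        for x in product((0, 1), repeat=n):
            for y in product((0, 1), repeat=n):
                assert f_closed_form(x, y) == f_brute(x, y)
    print("closed form matches brute force for N <= 4")
```

The brute force is factorial in N, so it is only meant as a check on the claim, not as something to submit.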
Proof

Our claim is that if we have x 01-positions and y 10-positions, the longest possible minimum LNDS has length \max(x, y) + \left\lfloor \frac{\min(x, y)}{2} \right\rfloor.

This will be proved in two parts: a lower bound and an upper bound.

Lower bound

First, let’s prove the lower bound by constructing a sequence that has this as its minimum LNDS.
Without loss of generality, let x \geq y, so \max(x, y) = x and \min(x, y) = y. Below, y/2 denotes floor division.

  • First place y/2 occurrences of 10.
  • Then place all x occurrences of 01.
  • Finally, place all remaining occurrences of 10.

So, the first string looks like \underbrace{111\ldots 11}_{y/2} \ \underbrace{00\ldots 00}_{x}\ \underbrace{11\ldots 11}_{y-y/2} and the second looks like \underbrace{000\ldots 00}_{y/2} \ \underbrace{11\ldots 11}_{x}\ \underbrace{00\ldots 00}_{y-y/2}

The last x + (y - y/2) characters of the first string are non-decreasing, while the first y/2 + x characters of the second string are non-decreasing.
Since y - y/2 \geq y/2, this gives us the required lower bound: the minimum LNDS has length at least x + y/2.
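The construction can be written out directly. This is a small sketch under the same assumption x \geq y; `build_arrangement` and `lnds_binary` are illustrative names, not part of any submitted solution.

```python
def lnds_binary(s):
    """Longest non-decreasing subsequence of a 0/1 list."""
    zeros_before = 0
    ones_after = sum(s)
    best = ones_after
    for v in s:
        if v == 0:
            zeros_before += 1
        else:
            ones_after -= 1
        best = max(best, zeros_before + ones_after)
    return best


def build_arrangement(x, y):
    """Arrange x 01-positions and y 10-positions as in the lower-bound proof.

    Assumes x >= y. Returns the two resulting binary lists.
    """
    half = y // 2
    # y/2 copies of 10, then all copies of 01, then the remaining copies of 10
    pairs = [(1, 0)] * half + [(0, 1)] * x + [(1, 0)] * (y - half)
    first = [a for a, _ in pairs]
    second = [b for _, b in pairs]
    return first, second


if __name__ == "__main__":
    for x in range(8):
        for y in range(x + 1):
            first, second = build_arrangement(x, y)
            assert min(lnds_binary(first), lnds_binary(second)) >= x + y // 2
    print("construction reaches x + y/2 for all tested sizes")
```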

Upper bound

Now for the upper bound.
Consider a configuration where the first binary string has LNDS of length \gt x + y/2.
In particular, this means that the last y/2 + 1 characters of this LNDS must all be 1, and everything after them must be 0.
Further, we can also assume that there are no zeros between the ones (if there were, they don’t contribute to the LNDS anyway; so we can just move them to the end and it helps out the second string more).

Suppose there are z zeros after these y/2 + 1 positions.
The LNDS of the first string then has length at most x+y-z, since the last z positions can’t contribute.

Now let’s look at the LNDS of the second string.

  • If the LNDS doesn’t include any of these y/2+1 positions, then it has length at most x + y - (y/2 + 1) \leq x + y/2 (only that many positions remain), so the minimum cannot exceed x + y/2 in this case.
  • If the LNDS does include one of them, it might as well include them all: our setup guarantees that the values at these positions are 0 in the second string, and there are no ones in between them.
  • In particular, the LNDS has length at most y + z (where z is as defined above).

For the second string to reach the bound we need y + z \geq x + y/2, which gives us z \geq x - y/2 \geq y - y/2 (using x \geq y).
In particular, this means the LNDS of the first string has length \leq x + y - z \leq x + y - (y - y/2) = x + y/2, a contradiction to us assuming it was strictly greater.

So, the minimum LNDS length cannot exceed x + y/2, and our upper bound is proved as well.
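Both bounds together can be checked exhaustively for small counts. The sketch below is only a verification aid with made-up helper names: it tries every distinct ordering of x 01-positions and y 10-positions and confirms that the best minimum LNDS equals \max(x, y) + \min(x, y)/2.

```python
from itertools import combinations


def lnds_binary(s):
    """Longest non-decreasing subsequence of a 0/1 list."""
    zeros_before = 0
    ones_after = sum(s)
    best = ones_after
    for v in s:
        if v == 0:
            zeros_before += 1
        else:
            ones_after -= 1
        best = max(best, zeros_before + ones_after)
    return best


def best_min_lnds(x, y):
    """Best min(LNDS) over all orderings of x 01-positions and y 10-positions."""
    n = x + y
    best = 0
    # An ordering is determined by which slots hold the 01-positions.
    for slots in combinations(range(n), x):
        chosen = set(slots)
        first = [0 if i in chosen else 1 for i in range(n)]
        second = [1 - v for v in first]  # 01 and 10 are complements
        best = max(best, min(lnds_binary(first), lnds_binary(second)))
    return best


if __name__ == "__main__":
    for x in range(6):
        for y in range(6):
            assert best_min_lnds(x, y) == max(x, y) + min(x, y) // 2
    print("claim verified for x, y <= 5")
```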

With this information in hand, let’s compute the answer.
Notice that the score of a pair of strings really depends on only two things:

  • How many pairs of same/different positions there are.
  • Among the different positions, how many of them are 01 and how many are 10.

Since N\leq 2000, we can fix both.
Suppose there are k positions such that x_i = y_i; and m of the remaining N-k positions are 01.
Then,

  • The value of F for such a pair is k + \max(x, y) + \min(x, y)/2, where x = m and y = N-k-m now denote the counts of 01- and 10-positions (not the strings). Let this be \text{len}.
  • The number of such pairs can be computed as follows:
    • Fix the set of k equal positions, which can be done in \binom{N}{k} ways.
    • For each equal position, fix whether it’s 00 or 11. This is 2^k choices.
    • Of the remaining N-k positions, choose which m of them are 01-positions. This is \binom{N-k}{m} choices.
  • Multiply all three values above to obtain the number of pairs of strings, say \text{count}.
  • Then, the answer increases by \text{len}\times \text{count}, since \text{count} pairs of strings have an answer of \text{len}.
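As a side note, \text{len} can be rewritten in a form that the setter's code appears to use (it starts from N \cdot 4^N and subtracts a correction term). The derivation only uses k + x + y = N and t - \lfloor t/2 \rfloor = \lceil t/2 \rceil:

```latex
\begin{aligned}
\text{len} &= k + \max(x, y) + \left\lfloor \frac{\min(x, y)}{2} \right\rfloor \\
           &= k + (x + y) - \min(x, y) + \left\lfloor \frac{\min(x, y)}{2} \right\rfloor \\
           &= N - \left\lceil \frac{\min(x, y)}{2} \right\rceil
\end{aligned}
```

Since there are 4^N pairs in total, the answer equals N \cdot 4^N minus the sum of \lceil \min(x, y)/2 \rceil \times \text{count} over all (k, m).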

Do this for all k from 0 to N and m from 0 to N-k to obtain the final answer.
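To see that the counting is right, the double loop can be compared against direct enumeration of all 4^N pairs for small N. This is a sketch with hypothetical function names; it uses `math.comb` (Python 3.8+) instead of precomputed tables, which is fine at this size but not how one would handle N up to 2000.

```python
from itertools import product
from math import comb

MOD = 998244353


def answer_by_formula(n):
    """Sum of len * count over k equal positions and m 01-positions."""
    total = 0
    for k in range(n + 1):
        for m in range(n - k + 1):
            x, y = m, n - k - m
            length = k + max(x, y) + min(x, y) // 2
            count = comb(n, k) * 2**k * comb(n - k, m)
            total += length * count
    return total % MOD


def answer_by_enumeration(n):
    """Sum of F over every pair of binary strings, using the closed form for F."""
    total = 0
    for xs in product((0, 1), repeat=n):
        for ys in product((0, 1), repeat=n):
            same = sum(a == b for a, b in zip(xs, ys))
            c01 = sum(a == 0 and b == 1 for a, b in zip(xs, ys))
            c10 = n - same - c01
            total += same + max(c01, c10) + min(c01, c10) // 2
    return total % MOD


if __name__ == "__main__":
    for n in range(1, 7):
        assert answer_by_formula(n) == answer_by_enumeration(n)
    print(answer_by_formula(1), answer_by_formula(2))  # prints: 4 30
```

For N = 2, only the two pairs (01, 10) and (10, 01) have F = 1 and the other 14 pairs have F = 2, which gives 30.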

For a fixed k and m, we need to compute a couple of binomial coefficients and a power of 2.
The former can be done by precomputing factorials/inverse factorials (or even precomputing all 2000^2 relevant binomials using Pascal’s formula), while the latter can be done using binary exponentiation.
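A minimal sketch of the factorial / inverse-factorial precomputation, assuming the modulus 998244353 used in the solutions below and a table size of 2000; the names `fact`, `inv_fact` and `binom` are illustrative.

```python
MOD = 998244353
MAX_N = 2000  # assumed upper bound on N

# fact[i] = i! mod MOD
fact = [1] * (MAX_N + 1)
for i in range(1, MAX_N + 1):
    fact[i] = fact[i - 1] * i % MOD

# inv_fact[i] = (i!)^(-1) mod MOD, using Fermat's little theorem once
# and then walking downwards: 1/(i-1)! = (1/i!) * i
inv_fact = [1] * (MAX_N + 1)
inv_fact[MAX_N] = pow(fact[MAX_N], MOD - 2, MOD)
for i in range(MAX_N, 0, -1):
    inv_fact[i - 1] = inv_fact[i] * i % MOD


def binom(n, r):
    """C(n, r) mod MOD in O(1) after the precomputation above."""
    if r < 0 or r > n:
        return 0
    return fact[n] * inv_fact[r] % MOD * inv_fact[n - r] % MOD


# powers of two can be tabulated the same way, avoiding repeated exponentiation
pow2 = [1] * (MAX_N + 1)
for i in range(1, MAX_N + 1):
    pow2[i] = pow2[i - 1] * 2 % MOD
```

Tabulating the powers of 2 as well removes the \log factor from the per-term cost.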

TIME COMPLEXITY:

\mathcal{O}(N^2) or \mathcal{O}(N^2 \log{MOD}) per testcase.

CODE:

Setter's code (C++)
//#pragma GCC target ("avx2")
#pragma GCC optimize ("O3")
#pragma GCC optimize ("unroll-loops")


#include <bits/stdc++.h>   
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;   
using namespace std;  
#define ll long long
const ll INF_MUL=1e13; 
const ll INF_ADD=1e18; 
#define pb push_back                 
#define mp make_pair          
#define nline "\n"                           
#define f first                                          
#define s second                                             
#define pll pair<ll,ll> 
#define all(x) x.begin(),x.end()     
#define vl vector<ll>             
#define vvl vector<vector<ll>>    
#define vvvl vector<vector<vector<ll>>>          
#ifndef ONLINE_JUDGE    
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);    
#endif       
void _print(ll x){cerr<<x;}   
void _print(char x){cerr<<x;}     
void _print(string x){cerr<<x;}    
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());   
template<class T,class V> void _print(pair<T,V> p) {cerr<<"{"; _print(p.first);cerr<<","; _print(p.second);cerr<<"}";}
template<class T>void _print(vector<T> v) {cerr<<" [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T>void _print(set<T> v) {cerr<<" [ "; for (T i:v){_print(i); cerr<<" ";}cerr<<"]";}
template<class T>void _print(multiset<T> v) {cerr<< " [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T,class V>void _print(map<T, V> v) {cerr<<" [ "; for(auto i:v) {_print(i);cerr<<" ";} cerr<<"]";} 
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
//--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
const ll MOD=998244353;           
const ll MAX=500500;
vector<ll> fact(MAX+2,1),inv_fact(MAX+2,1);
ll binpow(ll a,ll b,ll MOD){
    ll ans=1;
    a%=MOD;  
    while(b){
        if(b&1)
            ans=(ans*a)%MOD;
        b/=2;
        a=(a*a)%MOD;  
    }
    return ans;
}
ll inverse(ll a,ll MOD){
    return binpow(a,MOD-2,MOD);
} 
void precompute(ll MOD){
    for(ll i=2;i<MAX;i++){  
        fact[i]=(fact[i-1]*i)%MOD;
    }
    inv_fact[MAX-1]=inverse(fact[MAX-1],MOD);
    for(ll i=MAX-2;i>=0;i--){
        inv_fact[i]=(inv_fact[i+1]*(i+1))%MOD;
    }
}
ll nCr(ll a,ll b,ll MOD){
    if((a<0)||(a<b)||(b<0))
        return 0;   
    ll denom=(inv_fact[b]*inv_fact[a-b])%MOD;
    return (denom*fact[a])%MOD;    
}  
vector<ll> power(MAX,1);
void solve(){
    ll n; cin>>n; 
    ll ans=(n*power[2*n])%MOD;
    for(ll l=0;l<=n;l++){   
        for(ll r=0;l+r<=n;r++){   
            ll now=min(l,r)+1; 
            now/=2;  
            ll ways=nCr(n,l,MOD)*nCr(n-l,r,MOD);  
            ways%=MOD;  
            ways=(ways*power[n-l-r])%MOD;  
            ans=(ans-now*ways)%MOD;  
        }    
    }  
    ans=(ans+MOD)%MOD;   
    cout<<ans<<nline;   
    return;                                                  
}                                                                  
int main()                                                                                                           
{                  
    ios_base::sync_with_stdio(false);                                    
    cin.tie(NULL);               
    #ifndef ONLINE_JUDGE                          
    freopen("input.txt", "r", stdin);                                                    
    freopen("output.txt", "w", stdout);      
    freopen("error.txt", "w", stderr);                          
    #endif                            
    ll test_cases=1;                
    cin>>test_cases;
    precompute(MOD); 
    for(ll i=1;i<MAX;i++){
        power[i]=(power[i-1]*2)%MOD;
    }
    while(test_cases--){  
        solve();
    }  
    cout<<fixed<<setprecision(10);
    cerr<<"Time:"<<1000*((double)clock())/(double)CLOCKS_PER_SEC<<"ms\n";   
}   
Tester's code (C++)
#include <bits/stdc++.h>                   
#define int long long     
using namespace std;
const int mod=998244353, N=2005;

int power(int a, int b, int p) {
    if(a==0)
    return 0;
    int res=1;
    a%=p;
    while(b>0) {
        if(b&1)
        res=(1ll*res*a)%p;
        b>>=1;
        a=(1ll*a*a)%p;
    }
    return res;
}
int fact[N], inv[N], po2[N];
void pre()
{
    fact[0]=inv[0]=po2[0]=1;
    for(int i=1;i<N;i++)
        po2[i]=(po2[i-1]*2)%mod;
    for(int i=1;i<N;i++)
        fact[i]=(fact[i-1]*i)%mod;
    for(int i=1;i<N;i++)
        inv[i]=power(fact[i], mod-2, mod);
}
int nCr(int n, int r)
{
    if(min(n, r)<0 || r>n)
    return 0;
    if(n==r)
    return 1;
    return (((fact[n]*inv[r])%mod)*inv[n-r])%mod;
}

int32_t main() {
    pre();
    int t;
    cin>>t;
    while(t--)
    {
        int n, ans=0;
        cin>>n;
        for(int i=0;i<=n;i++)
        {
            for(int j=0;(j+i)<=n;j++)
            {
                int same=i, x=j, y=n-same-x;
                int count=(nCr(n, same)*po2[i]%mod*nCr(n-same, x))%mod;
                int fxy=(same + max(x, y) + min(x, y)/2);
                ans=(ans + (count*fxy))%mod;
            }
        }
        cout<<ans<<"\n";
    }
}
Editorialist's code (Python)
mod = 998244353
maxN = 2005
C = [ [0 for _ in range(maxN)] for _ in range(maxN)]

for i in range(maxN):
	C[i][0] = 1
	for j in range(1, i+1):
		C[i][j] = (C[i-1][j] + C[i-1][j-1]) % mod

sumn = 0
for i in range(int(input())):
	n = int(input())
	ans = 0
	for same in range(n+1):
		dif = n - same

		# fix which positions are the same: C(n, same) ways
		# fix which values they take: 2^same ways
		# fix number of 01 positions, x: C(dif, x) ways
		# length = same + max(x, dif-x) + min(x, dif-x)/2
		
		ways = (C[n][same] * pow(2, same, mod)) % mod
		for x in range(dif+1):
			y = dif - x
			ans += (same + max(x, y) + min(x, y)//2)*C[dif][x]*ways
			ans %= mod
	sumn += n
	print(ans)
I just precomputed answers with a cubic solution :upside_down_face: Took like 5 minutes but I was thinking about another problem during this time so yeah. Update: oh wait, O(N^2) per test case works? I thought for sure it was not going to pass… Update 2: OMG, I didn’t see the sum of N constraint rip

rip, I didn’t expect anyone to actually do that.

It’s something I brought up while testing, but every cubic solution we had was easily optimized to quadratic (especially given the skill level of someone who could come up with cubic in the first place) so introducing a variable mod and changing the test cases didn’t seem worth it.

Lucky for you, I guess :slight_smile: