BALSUB7 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: raysh07
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Combinatorics

PROBLEM:

For a binary string S and an integer K, define f(S, K) to be the minimum total cost of choosing K non-intersecting substrings of S, where the cost of a substring is the absolute difference between its counts of 0 and 1.

Define g(S) = \sum_{k=1}^{|S|} f(S, k).
Given N, compute the sum of g(S)^N across all binary strings S of length N.

EXPLANATION:

The first step is, of course, understanding how to compute f(S, K) for a fixed S.

We want to choose K non-intersecting substrings.
Because we’re allowed to skip elements (as in, the chosen substrings need not partition S), f(S, K) never exceeds K: no substring in an optimal solution has cost \gt 1.

This is easy to prove: if we have a substring with cost \gt 1, it certainly has length \gt 1 as well.
But then we could keep just the first character of the substring and discard the rest, which would lower the cost of this substring to 1 while maintaining disjoint-ness.

So every chosen substring has cost 0 or 1, and minimizing the total cost is the same as minimizing the number of cost-1 substrings we choose, or equivalently, maximizing the number of cost-0 substrings.

Further, observe that:

  • If we choose a cost 1 substring, it’s optimal for it to be a single character.
    This is because if we have anything longer, we can simply drop all but one character from it.
  • By the same token, if we choose a cost 0 substring, it’s optimal for it to be either 01 or 10.
    The reason is similar: any longer cost-0 substring will have one of 01 or 10 as a substring so we can reduce to that.

Let us now try to maximize the number of cost-0 substrings we choose.

This can, in fact, simply be done greedily!
That is, process the string from left to right. Whenever you encounter a substring of the form 01 or 10 that doesn’t intersect a previously chosen substring, choose it.
The proof of optimality is simple: if S[i, i+1] is the leftmost occurrence of such a substring and we don’t choose it, then the first chosen cost-0 substring must start at an index \ge i+1 (because all characters up to index i are equal, so no 01 or 10 starts before i).
However,

  • If the first chosen substring starts at an index \gt i+1, then we can freely include [i, i+1] into the answer and improve it. So this cannot be the case.
  • If the first chosen substring is [i+1, i+2], we can simply replace it by [i, i+1], which preserves disjointness and keeps the number of chosen substrings unchanged.

So, quite simply, the greedy algorithm is optimal.
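
A minimal sketch of this greedy scan, assuming S is given as a std::string of '0' and '1' characters. The function name maxZeroCostPairs is hypothetical and is not taken from the author's code.

```cpp
#include <cassert>
#include <cstddef>
#include <string>

// Greedy left-to-right scan: take every "01"/"10" that does not
// overlap a pair taken earlier, and return how many were taken (M).
int maxZeroCostPairs(const std::string& s) {
    // Assumption of this sketch: the input is a binary string.
    assert(s.find_first_not_of("01") == std::string::npos);
    int m = 0;
    std::size_t i = 0;
    while (i + 1 < s.size()) {
        if (s[i] != s[i + 1]) {
            ++m;     // S[i..i+1] is 01 or 10: take it
            i += 2;  // and skip past both of its characters
        } else {
            ++i;     // equal neighbours: move on by one character
        }
    }
    return m;
}
```

For example, on 0011 the scan skips the first character, takes the 01 in the middle, and stops, so M = 1.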

Let M denote the maximum number of cost-0 substrings that can be obtained, i.e. the result of this greedy algorithm.
Then, we can see that:

  • If 1 \le K \le M, then f(S, K) = 0 since we can ensure that only cost-0 substrings are chosen.
  • If M \lt K \le N-M, we have f(S, K) = K-M.
    This is because we have (N - 2M) elements that are free after choosing M cost-0 substrings.
    Each such element can form one more substring on its own while increasing the cost by 1.
    This covers up to M + (N - 2M) = N - M substrings.
  • Finally, for N-M \lt K \le N, we have f(S, K) = N-2M + 2\cdot (K - N + M) = 2K-N.
    This is because, for such large K, it’s impossible to keep M cost-0 substrings.
    So, we’re forced to start breaking each of them into two cost-1 substrings, so that each step K takes beyond N-M adds a cost of 2.
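
The three cases translate directly into a small function. This is only a sketch: the name minCost and its signature are hypothetical, and it takes N and M as inputs rather than the string itself, since f(S, K) depends on S only through them.

```cpp
#include <cassert>

// f(S, K) for a string of length n whose greedy pair count is m.
long long minCost(long long n, long long m, long long k) {
    assert(0 <= m && 2 * m <= n && 1 <= k && k <= n);
    if (k <= m) return 0;          // only cost-0 pairs are used
    if (k <= n - m) return k - m;  // m pairs plus (k - m) single characters
    return 2 * k - n;              // some pairs are split into two singles
}
```

For instance, with N = 4 and M = 1 (say S = 0001), the values for K = 1, 2, 3, 4 are 0, 1, 2, 4.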

Finally, we have g(S) to be the sum of the above quantity across all K.
Observe that considering all three cases, this becomes the sum of:

  • 1+2+\ldots + (N-2M), which can be found in \mathcal{O}(1) time if we know M; and
  • (N-2M+2) + (N-2M+4) + \ldots + N, which can also be found in \mathcal{O}(1) time if we know M.

Define sc(M) = (1 + 2 + \ldots + (N-2M)) + ((N-2M+2) + (N-2M+4) + \ldots + N) to be the value of g(S) for a string S with parameter M.
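
As a sanity check, the sketch below evaluates sc(M) in three ways: term by term as written above, through a closed form obtained by summing the two arithmetic series (this simplification is derived here and is not stated in the original text), and through the expression assigned to val2 in the author's code. All function names are hypothetical.

```cpp
#include <cassert>

// sc(M) summed term by term, exactly as the two series are written.
long long scBySeries(long long n, long long m) {
    assert(0 <= m && 2 * m <= n);
    long long total = 0;
    for (long long v = 1; v <= n - 2 * m; ++v) total += v;          // 1 + 2 + ... + (N-2M)
    for (long long j = 1; j <= m; ++j) total += n - 2 * m + 2 * j;  // (N-2M+2) + ... + N
    return total;
}

// Closed form: (N-2M)(N-2M+1)/2 + M(N-M+1).
long long scClosedForm(long long n, long long m) {
    long long a = n - 2 * m;
    return a * (a + 1) / 2 + m * (n - m + 1);
}

// The expression used for val2 in the author's code (with i = M).
long long scAuthorExpr(long long n, long long i) {
    return i * (i + 1) + (n - i) * (n - i + 1) / 2 - i * (i + 1) / 2;
}
```

All three agree; for N = 4 and M = 1 each returns 7, which is 1 + 2 + 4.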

Also define ct(M) to be the number of binary strings of length N that have a parameter of M, i.e. the maximum number of disjoint 0-cost substrings is M.

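If we’re able to compute all the values ct(0), ct(1), \ldots, ct(\lfloor N/2 \rfloor), then the final answer is simply

\sum_{M=0}^{\lfloor N/2 \rfloor} ct(M) \cdot (sc(M) ^ N)

(Larger M cannot occur, since M disjoint substrings of length 2 need 2M \le N characters.)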
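
For very small N, the target sum can also be brute-forced straight from the definitions, which gives reference values to compare any formula against. The sketch below is an assumption-laden helper, not part of the intended solution: bruteF and bruteAnswer are hypothetical names, f is computed by a plain DP over prefixes, and no modulus is applied because the values stay small for N \le 8.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <cstdlib>
#include <string>
#include <vector>

// Returns f where f[k] = f(S, k) for k = 1..|S| (f[0] is unused).
// dp[i][k] = minimum total cost of k disjoint substrings inside S[0..i).
std::vector<long long> bruteF(const std::string& s) {
    const int n = (int)s.size();
    const long long INF = LLONG_MAX / 4;
    std::vector<int> bal(n + 1, 0);  // bal[i] = (#0 - #1) in the prefix of length i
    for (int i = 0; i < n; ++i) bal[i + 1] = bal[i] + (s[i] == '0' ? 1 : -1);
    std::vector<std::vector<long long>> dp(n + 1, std::vector<long long>(n + 1, INF));
    for (int i = 0; i <= n; ++i) dp[i][0] = 0;
    for (int i = 1; i <= n; ++i) {
        for (int k = 1; k <= i; ++k) {
            long long best = dp[i - 1][k];  // character i-1 is left unused
            for (int j = 0; j < i; ++j) {   // or the last substring is S[j..i)
                if (dp[j][k - 1] < INF)
                    best = std::min(best, dp[j][k - 1] + std::abs(bal[i] - bal[j]));
            }
            dp[i][k] = best;
        }
    }
    std::vector<long long> f(n + 1, 0);
    for (int k = 1; k <= n; ++k) f[k] = dp[n][k];
    return f;
}

// Sum of g(S)^N over all binary strings of length n, without any modulus.
unsigned long long bruteAnswer(int n) {
    assert(1 <= n && n <= 8);  // keeps the exact sum inside 64 bits
    unsigned long long total = 0;
    for (int mask = 0; mask < (1 << n); ++mask) {
        std::string s(n, '0');
        for (int b = 0; b < n; ++b)
            if ((mask >> b) & 1) s[b] = '1';
        std::vector<long long> f = bruteF(s);
        unsigned long long g = 0;
        for (int k = 1; k <= n; ++k) g += (unsigned long long)f[k];
        unsigned long long p = 1;
        for (int e = 0; e < n; ++e) p *= g;  // g^n
        total += p;
    }
    return total;
}
```

For N = 2 this gives 26: the strings 00 and 11 have g = 3 and contribute 9 each, while 01 and 10 have g = 2 and contribute 4 each.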

We already know that sc(M) can be computed in constant time once M is known.
So, let’s focus on ct(M).

We want M disjoint substrings, each of them being either 01 or 10.
Let’s arrange all of these substrings in a row first, to obtain a string of length 2M.
Further, for each of them we can independently choose whether it equals 01 or 10, for 2^M options.

Now, we need to separate them by some characters to reach a length of N.
Suppose there are x_i characters between the i-th and (i-1)-th 0-cost substrings.
(In particular, x_1 characters before the first substring, and x_{M+1} characters after the last one.)

Then, note that these separating characters are (almost) fixed uniquely!
Specifically,

  • If the first substring equals 01, then the first x_1 characters must all equal 0.
    If it equals 10 instead, then the first x_1 characters must all equal 1.
    This is because if any of them differed from the first character of that substring, we’d have an earlier occurrence of 10 or 01, which the greedy algorithm that computes M would have picked instead.
  • Similarly, depending on if the second substring equals 01 or 10, the next batch of x_2 characters must all equal 0 or 1, respectively.
  • This applies to all of x_3, x_4, \ldots, x_M as well - our choice of 10 or 01 uniquely determines them all.
  • The only outlier is the last batch of x_{M+1} characters.
    These must all be equal (otherwise we’d have more than M cost-0 substrings), but they can either all be 0 or all be 1.

Using this, we can now count the number of valid strings.
After the initial 2^M choices, we have two options:

  1. x_{M+1} = 0, i.e. there’s no trailing block of equal characters.
    Here, each possible choice of the tuple (x_1, x_2, \ldots, x_M) will give us a unique final string.
    The only constraint is that each x_i must be \ge 0, and their sum must equal N-2M.
    Counting the number of valid choices of (x_i) is exactly what is done by stars-and-bars, and it comes out to be \binom{N-M-1}{M-1}.
  2. x_{M+1} \gt 0, i.e. there’s a non-empty trailing block.
    Here, similarly we want to count the number of configurations of the tuple (x_1, \ldots, x_{M+1}) such that each element is non-negative, the last element is positive, and their sum equals N-2M.
    Once again this can be done by stars-and-bars, to obtain \binom{N-M-1}{M} configurations.
    This needs to be further multiplied by 2 to account for the last block being either all 0’s or all 1’s.
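
For the second count, the substitution behind the binomial can be spelled out as follows. Here y is a helper variable introduced only for this sketch, standing for x_{M+1} - 1.

```latex
x_1 + x_2 + \ldots + x_M + y = N - 2M - 1, \qquad x_i \ge 0, \; y \ge 0
\quad\Longrightarrow\quad
\binom{(N - 2M - 1) + (M + 1) - 1}{(M + 1) - 1} = \binom{N-M-1}{M}
```

The first count works the same way with M non-negative variables summing to N-2M, giving \binom{(N-2M) + M - 1}{M - 1} = \binom{N-M-1}{M-1}.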

So, we have

ct(M) = 2^M \cdot \left(\binom{N-M-1}{M-1} + 2\cdot \binom{N-M-1}{M}\right)
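
The formula can be cross-checked for small N by enumerating all 2^N strings and running the greedy on each. The sketch below does this with exact integers and no modulus; binom, ctFormula, greedyPairs and ctBrute are hypothetical names, and binom uses the convention that a binomial coefficient with an out-of-range lower index is 0 (which is what makes the M = 0 case come out as 2).

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Binomial coefficient with C(n, r) = 0 outside 0 <= r <= n.
long long binom(int n, int r) {
    if (n < 0 || r < 0 || r > n) return 0;
    long long res = 1;
    for (int i = 1; i <= r; ++i) res = res * (n - r + i) / i;  // exact at every step
    return res;
}

// ct(M) evaluated from the closed form.
long long ctFormula(int n, int m) {
    return (1LL << m) * (binom(n - m - 1, m - 1) + 2 * binom(n - m - 1, m));
}

// Greedy count of disjoint 01/10 pairs in s.
int greedyPairs(const std::string& s) {
    int m = 0;
    for (std::size_t i = 0; i + 1 < s.size();) {
        if (s[i] != s[i + 1]) { ++m; i += 2; }
        else { ++i; }
    }
    return m;
}

// ct(0..floor(n/2)) by enumerating every string; only sensible for small n.
std::vector<long long> ctBrute(int n) {
    assert(1 <= n && n <= 20);
    std::vector<long long> cnt(n / 2 + 1, 0);
    for (int mask = 0; mask < (1 << n); ++mask) {
        std::string s(n, '0');
        for (int b = 0; b < n; ++b)
            if ((mask >> b) & 1) s[b] = '1';
        ++cnt[greedyPairs(s)];
    }
    return cnt;
}
```

For N = 4 both approaches give ct(0) = 2, ct(1) = 10 and ct(2) = 4, which sum to 2^4 = 16 as they should.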

Thus, with quick computation of binomial coefficients and modular exponentiation, the overall answer can be computed in \mathcal{O}(N \log N) time, so we’re done.

TIME COMPLEXITY:

\mathcal{O}(N \log N) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

const int mod = 998244353;

struct mint{
    int x;

    mint (){ x = 0;}
    mint (int32_t xx){ x = xx % mod; if (x < 0) x += mod;}
    mint (long long xx){ x = xx % mod; if (x < 0) x += mod;}

    int val(){
        return x;
    }
    mint &operator++(){
        x++;
        if (x == mod) x = 0;
        return *this;
    }
    mint &operator--(){
        if (x == 0) x = mod;
        x--;
        return *this;
    }
    mint operator++(int32_t){
        mint result = *this;
        ++*this;
        return result;
    }
    
    mint operator--(int32_t){
        mint result = *this;
        --*this;
        return result;
    }
    mint& operator+=(const mint &b){
        x += b.x;
        if (x >= mod) x -= mod;
        return *this;
    }
    mint& operator-=(const mint &b){
        x -= b.x;
        if (x < 0) x += mod;
        return *this;
    }
    mint& operator*=(const mint &b){
        long long z = x;
        z *= b.x;
        z %= mod;
        x = (int)z;
        return *this;
    }
    mint operator+() const {
        return *this;
    }
    mint operator-() const {
        return mint() - *this;
    }
    mint operator/=(const mint &b){
        return *this = *this * b.inv();
    }
    mint power(long long n) const {
        mint ok = *this, r = 1;
        while (n){
            if (n & 1){
                r *= ok;
            }
            ok *= ok;
            n >>= 1;
        }
        return r;
    }
    mint inv() const {
        return power(mod - 2);
    }
    friend mint operator+(const mint& a, const mint& b){ return mint(a) += b;}
    friend mint operator-(const mint& a, const mint& b){ return mint(a) -= b;}
    friend mint operator*(const mint& a, const mint& b){ return mint(a) *= b;}
    friend mint operator/(const mint& a, const mint& b){ return mint(a) /= b;}
    friend bool operator==(const mint& a, const mint& b){ return a.x == b.x;}
    friend bool operator!=(const mint& a, const mint& b){ return a.x != b.x;}
    mint power(mint a, long long n){
        return a.power(n);
    }
    friend ostream &operator<<(ostream &os, const mint &m) {
        os << m.x;
        return os;
    }
    explicit operator bool() const {
        return x != 0;
    }
};

// Remember to check MOD

struct factorials{
    int n;
    vector <mint> ff, iff;
    
    factorials(int nn){
        n = nn;
        ff.resize(n + 1);
        iff.resize(n + 1);
        
        ff[0] = 1;
        for (int i = 1; i <= n; i++){
            ff[i] = ff[i - 1] * i;
        }
        
        iff[n] = ff[n].inv();
        for (int i = n - 1; i >= 0; i--){
            iff[i] = iff[i + 1] * (i + 1);
        }
    }
    
    mint C(int n, int r){
        if (n == r) return mint(1);
        if (n < 0 || r < 0 || r > n) return mint(0);
        return ff[n] * iff[r] * iff[n - r];
    }
    
    mint P(int n, int r){
        if (n < 0 || r < 0 || r > n) return mint(0);
        return ff[n] * iff[n - r];
    }
    
    mint solutions(int n, int r){
        // Solutions to x1 + x2 + ... + xn = r, xi >= 0 
        return C(n + r - 1, n - 1);
    }
    
    mint catalan(int n){
        return ff[2 * n] * iff[n] * iff[n + 1];
    }
};

const int PRECOMP = 3e6 + 69;
factorials F(PRECOMP);

// REMEMBER To check MOD and PRECOMP

void Solve() 
{
    int n; cin >> n;
    
    mint ans = 0;
    vector <mint> p2(n + 1, 1);
    for (int i = 1; i <= n; i++){
        p2[i] = p2[i - 1] * 2;
    }
    
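    for (int i = 0; i <= n / 2; i++){
        // i = M, the number of disjoint 01/10 pairs found by the greedy
        mint tot = 0;
        {
            // case 1: no trailing block, the string ends with the last pair
            mint ways = F.solutions(i, n - 2 * i) * p2[i];
            tot += ways;
        }
        {
            // case 2: non-empty trailing block (all 0s or all 1s, hence the extra factor of 2)
            mint ways = F.solutions(i + 1, n - 2 * i - 1) * p2[i + 1];
            tot += ways;
        }
        
        // f(S, k) - f(S, k - 1) is:
        // 0 for the first i values of k
        // 1 for the next (n - 2i) values
        // 2 for the last i values
        // the step at k is counted (n - k + 1) times in g(S), so
        // g(S) = 2 * i * (i + 1) / 2 + 1 * sum(i + 1 to n - i)
        //      = i (i + 1) + (n - i) * (n - i + 1) / 2 - i * (i + 1) / 2
        mint val2 = i * (i + 1) + (n - i) * (n - i + 1) / 2 - i * (i + 1) / 2;
        ans += val2.power(n) * tot;
    }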
    
    cout << ans << "\n";
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}