PIVOTALREV - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: raysh07
Tester: archit
Editorialist: raysh07

DIFFICULTY:

Easy-Medium

PREREQUISITES:

Basic Combinatorics

PROBLEM:

You can operate on a binary string S as follows: choose an index i with S_i = 1, and then swap S_{i - 1} and S_{i + 1}. Let f(S, T) denote the minimum number of operations needed to convert S into T, or 0 if this is not possible.

Find the sum of f(S, T) over all 2^N possible strings T.

EXPLANATION:

First, let us observe how the operation changes the sequence. We may assume S_{i - 1} \ne S_{i + 1} as otherwise the operation is wasteful.

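Then, the operation changes 011 to 110 or 110 to 011, i.e. it shifts a 0 by 2 places to the left or right, past two 1s.

Let z be the number of 0s, and let c_i (for 1 \le i \le z + 1) denote the number of 1s between the (i - 1)-th and the i-th 0, where c_1 counts the 1s before the first 0 and c_{z + 1} counts the 1s after the last 0. Then, note that an operation replaces c_i, c_{i + 1} by either (c_i - 2, c_{i + 1} + 2) or (c_i + 2, c_{i + 1} - 2), and all such moves are possible as long as all c_i remain non-negative.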
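As a small illustration of this characterization, the sketch below builds the sequence c from a string. The function name gapVector is hypothetical (it is not part of the editorialist's code), and the sketch assumes the run of 1s before the first 0 and the run after the last 0 are both included, so the result has one more entry than the number of 0s.

```cpp
#include <string>
#include <vector>

// Returns the run lengths of 1s: before the first 0, between consecutive 0s,
// and after the last 0. The result has (number of 0s + 1) entries.
std::vector<int> gapVector(const std::string& s) {
    std::vector<int> c(1, 0);
    for (char ch : s) {
        if (ch == '0') c.push_back(0);  // a 0 closes the current run of 1s
        else ++c.back();                // a 1 extends the current run
    }
    return c;
}
```

For example, 0110 has c = (0, 2, 0), and operating on the first 1 gives 1100 with c = (2, 0, 0): two 1s moved from one run to the neighbouring run.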

This gives us a very nice characterization of operations.


Let us try to find when S can be converted to T. Obviously, the number of 0s in both must be the same, since operations don't change it.

Further, define the vector c as above, and d as the same vector but for T.

Then, another necessary (and sufficient) condition is that c_i \mod 2 = d_i \mod 2 for every i. It is necessary because each operation changes two entries by exactly 2, so no parity ever changes.
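The two conditions can be checked directly. The following is a minimal sketch under the assumption that the parity condition is indeed sufficient, as claimed above; the names gapsOf and reachable are made up for this example.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Run lengths of 1s around the 0s of s (one more entry than the number of 0s).
std::vector<int> gapsOf(const std::string& s) {
    std::vector<int> c(1, 0);
    for (char ch : s) {
        if (ch == '0') c.push_back(0);
        else ++c.back();
    }
    return c;
}

// True when T satisfies both conditions with respect to S:
// same number of 0s, and matching parity in every run of 1s.
bool reachable(const std::string& s, const std::string& t) {
    if (s.size() != t.size()) return false;
    std::vector<int> c = gapsOf(s), d = gapsOf(t);
    if (c.size() != d.size()) return false;        // different number of 0s
    for (std::size_t k = 0; k < c.size(); ++k) {
        if ((c[k] - d[k]) % 2 != 0) return false;  // parity mismatch
    }
    return true;
}
```

With this sketch, reachable("0110", "1100") holds, while reachable("0110", "1010") does not, because the first run has length 0 in one string and 1 in the other.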

To compute the minimum cost, we can consider the minimum number of operations we need at each index i. Consider the i-th 0 as a barrier, and divide S and T into 2 parts based on that.

Let x_0 be the number of 1s in S before the i-th 0 and y_0 be the number of 1s in T before the i-th 0. Only operations that move the i-th 0 change x_0, and each of them changes it by exactly 2. So, we need at least |\dfrac{x_0 - y_0}{2}| operations at i to make the prefixes have an equal number of 1s, which we can then permute.

Adding this up over all i gives us a lower bound on the number of operations. In fact, this lower bound is achievable: at every step, we can operate on the first prefix with a deficit/excess in the number of 1s in S.


To sum up, we characterize strings S, T by sequences c, d satisfying the following conditions:

  • \sum c_i = \sum d_i = N - z, where z is the number of 0s in S (and in T)
  • c and d both have length z + 1
  • c_i \mod 2 = d_i \mod 2 for every i

And the number of operations is:

  • \sum_{i = 1}^{z} |\dfrac{\sum_{j = 1}^{i} (c_j - d_j)}{2} |, i.e. the sum of the absolute prefix balances, each divided by 2.
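The cost formula translates into a single pass over the prefix balances. This is only a sketch: the function name prefixBalanceCost is hypothetical, and it assumes that c and d already satisfy the three conditions listed above (equal length, equal sum, matching parities).

```cpp
#include <cstddef>
#include <cstdlib>
#include <vector>

// Sum over every prefix (except the full sequence) of |balance| / 2,
// where balance is the running sum of c[k] - d[k].
// Assumes c and d have equal length, equal sum, and matching parities.
long long prefixBalanceCost(const std::vector<int>& c, const std::vector<int>& d) {
    long long balance = 0, ops = 0;
    for (std::size_t k = 0; k + 1 < c.size(); ++k) {
        balance += c[k] - d[k];
        ops += std::llabs(balance) / 2;  // balance is even under the assumptions
    }
    return ops;
}
```

For c = (0, 1, 2) and d = (2, 1, 0), i.e. S = 01011 and T = 11010, the prefix balances are -2 and -2, giving 2 operations: 01011 to 01110 to 11010.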

For counting over all strings T, we can instead count over all sequences d. Recall the operation cost, which was the sum over halved absolute prefix balances. For each i, we can fix \sum_{j = 1}^{i} d_j = X and then calculate the contribution to the cost for all such sequences. Each such sequence contributes |\dfrac{X - A}{2}|, where A = \sum_{j = 1}^{i} c_j.

But, we also need to find exactly how many sequences/strings there exist with such values, i.e. how many sequences satisfying the conditions on d have \sum_{j = 1}^{i} d_j = X.

This is a fairly standard Stars and Bars argument.

First of all, subtract c_i \mod 2 from each d_i and then divide it by 2 to get a new sequence e_i. Now, the only conditions on e are that e_i \ge 0 and \sum e_i = \frac{N - z - Q}{2} = W (let), where Q = \sum_{j = 1}^{z + 1} (c_j \mod 2), and we got rid of the parity constraint.

Also, instead of \sum_{j = 1}^{i} d_j = X, we want to count the number of e with \sum_{j = 1}^{i} e_j = \frac{X - P}{2} = V (let), where P = \sum_{j = 1}^{i} (c_j \mod 2). Since \sum_{j = 1}^{i} c_j = 2G + P with G = \sum_{j = 1}^{i} \lfloor c_j / 2 \rfloor, the cost contribution of such a sequence is simply |V - G|.

The number of such sequences with Stars and Bars is C(V + i - 1, i - 1) \cdot C(W - V + z - i, z - i): the first i entries sum to V, and the remaining z + 1 - i entries sum to W - V. This uses the fact that there are C(n + r - 1, n - 1) non-negative integer solutions to the equation x_1 + x_2 + ... + x_n = r. Here, C(n, r) = \dfrac{n!}{r! (n - r)!}
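The counting step can be sketched without modular arithmetic for small values. The helper names binom, solutions and countWithPrefix are assumptions made for this example, and the exact 64-bit binomial is for illustration only; a real solution needs factorials modulo the prime.

```cpp
#include <cstdint>

// Exact binomial coefficient for small arguments (no modulus).
std::uint64_t binom(int n, int r) {
    if (n < 0 || r < 0 || r > n) return 0;
    std::uint64_t res = 1;
    for (int k = 1; k <= r; ++k) res = res * (n - r + k) / k;
    return res;
}

// Number of non-negative integer solutions of x_1 + ... + x_n = r, for n >= 1.
std::uint64_t solutions(int n, int r) {
    return binom(n + r - 1, n - 1);
}

// Number of sequences e of length len with total W whose first i entries
// sum to V. Requires 1 <= i < len.
std::uint64_t countWithPrefix(int len, int W, int i, int V) {
    return solutions(i, V) * solutions(len - i, W - V);
}
```

As a quick check, sequences of length 3 with total 2 whose first entry is 1 are (1, 1, 0) and (1, 0, 1), so countWithPrefix(3, 2, 1, 1) is 2. Summing countWithPrefix(3, 2, 1, V) over V = 0, 1, 2 gives 3 + 2 + 1 = 6, which equals solutions(3, 2).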

Finally, the problem can be solved in O(N^2) time by iterating over all i and X, and adding the cost contribution multiplied by the number of matching sequences.
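For very small N, the answer can be cross-checked by brute force. The sketch below is not the intended solution: it runs a breadth-first search from S over all strings reachable by the operation and adds up the distances, relying on unreachable T contributing 0. The name bruteForceSum is hypothetical, and the sketch assumes the chosen index must have both neighbours inside the string.

```cpp
#include <cstddef>
#include <queue>
#include <string>
#include <unordered_map>
#include <utility>

// Sum of f(S, T) over all T, computed by BFS from S (exponential in the
// worst case, so only suitable for tiny inputs).
long long bruteForceSum(const std::string& s) {
    std::unordered_map<std::string, long long> dist;
    std::queue<std::string> q;
    dist[s] = 0;
    q.push(s);
    long long total = 0;
    while (!q.empty()) {
        std::string cur = q.front();
        q.pop();
        long long d = dist[cur];
        total += d;
        for (std::size_t i = 1; i + 1 < cur.size(); ++i) {
            // Only a 1 with different neighbours gives a useful move.
            if (cur[i] != '1' || cur[i - 1] == cur[i + 1]) continue;
            std::string nxt = cur;
            std::swap(nxt[i - 1], nxt[i + 1]);
            if (dist.emplace(nxt, d + 1).second) q.push(nxt);
        }
    }
    return total;
}
```

For S = 0110 the reachable strings are 1100 and 0011, each at distance 1, so the sum is 2, which matches the value obtained from the counting formula.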

TIME COMPLEXITY:

\mathcal{O}(N^2) per testcase.

CODE:

Editorialist's Code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

const int mod = 998244353;

struct mint{
    int x;

    mint (){ x = 0;}
    mint (int32_t xx){ x = xx % mod; if (x < 0) x += mod;}
    mint (long long xx){ x = xx % mod; if (x < 0) x += mod;}

    int val(){
        return x;
    }
    mint &operator++(){
        x++;
        if (x == mod) x = 0;
        return *this;
    }
    mint &operator--(){
        if (x == 0) x = mod;
        x--;
        return *this;
    }
    mint operator++(int32_t){
        mint result = *this;
        ++*this;
        return result;
    }
    
    mint operator--(int32_t){
        mint result = *this;
        --*this;
        return result;
    }
    mint& operator+=(const mint &b){
        x += b.x;
        if (x >= mod) x -= mod;
        return *this;
    }
    mint& operator-=(const mint &b){
        x -= b.x;
        if (x < 0) x += mod;
        return *this;
    }
    mint& operator*=(const mint &b){
        long long z = x;
        z *= b.x;
        z %= mod;
        x = (int)z;
        return *this;
    }
    mint operator+() const {
        return *this;
    }
    mint operator-() const {
        return mint() - *this;
    }
    mint& operator/=(const mint &b){
        return *this = *this * b.inv();
    }
    mint power(long long n) const {
        mint ok = *this, r = 1;
        while (n){
            if (n & 1){
                r *= ok;
            }
            ok *= ok;
            n >>= 1;
        }
        return r;
    }
    mint inv() const {
        return power(mod - 2);
    }
    friend mint operator+(const mint& a, const mint& b){ return mint(a) += b;}
    friend mint operator-(const mint& a, const mint& b){ return mint(a) -= b;}
    friend mint operator*(const mint& a, const mint& b){ return mint(a) *= b;}
    friend mint operator/(const mint& a, const mint& b){ return mint(a) /= b;}
    friend bool operator==(const mint& a, const mint& b){ return a.x == b.x;}
    friend bool operator!=(const mint& a, const mint& b){ return a.x != b.x;}
    mint power(mint a, long long n){
        return a.power(n);
    }
    friend ostream &operator<<(ostream &os, const mint &m) {
        os << m.x;
        return os;
    }
    explicit operator bool() const {
        return x != 0;
    }
};

// Remember to check MOD

struct factorials{
    int n;
    vector <mint> ff, iff;
    
    factorials(int nn){
        n = nn;
        ff.resize(n + 1);
        iff.resize(n + 1);
        
        ff[0] = 1;
        for (int i = 1; i <= n; i++){
            ff[i] = ff[i - 1] * i;
        }
        
        iff[n] = ff[n].inv();
        for (int i = n - 1; i >= 0; i--){
            iff[i] = iff[i + 1] * (i + 1);
        }
    }
    
    mint C(int n, int r){
        if (n == r) return mint(1);
        if (n < 0 || r < 0 || r > n) return mint(0);
        return ff[n] * iff[r] * iff[n - r];
    }
    
    mint P(int n, int r){
        if (n < 0 || r < 0 || r > n) return mint(0);
        return ff[n] * iff[n - r];
    }
    
    mint solutions(int n, int r){
        // Solutions to x1 + x2 + ... + xn = r, xi >= 0 
        return C(n + r - 1, n - 1);
    }
    
    mint catalan(int n){
        return ff[2 * n] * iff[n] * iff[n + 1];
    }
};

const int PRECOMP = 3e6 + 69;
factorials F(PRECOMP);

// REMEMBER To check MOD and PRECOMP

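void Solve() 
{
    int n; cin >> n;
    string s; cin >> s;
    
    // pos holds the positions of all 0s, with sentinels -1 and n at the ends
    vector <int> ones;
    vector <int> pos;
    pos.push_back(-1);
    for (int i = 0; i < n; i++){
        if (s[i] == '0'){
            pos.push_back(i);
        }
    }
    pos.push_back(n);
    
    // ones[i] = number of 1s between consecutive 0s (the sequence c)
    for (int i = 1; i < (int)pos.size(); i++){
        ones.push_back(pos[i] - pos[i - 1] - 1);
    }
    
    // extra = W, the number of pairs of 1s to distribute among the gaps
    int extra = 0;
    for (auto x : ones){
        extra += x / 2;
    }
    
    // distribute extra to all of them
    // what is cost? delta 
    
    mint ans = 0;
    int m = ones.size();
    int got = 0; // pairs of 1s in the current prefix of S
    for (int i = 0; i + 1 < m; i++){
        got += (ones[i] / 2);
        // pre = pairs of 1s placed in the first i + 1 gaps of T
        for (int pre = 0; pre <= extra; pre++){
            mint ways = F.solutions(i + 1, pre) * F.solutions(m - i - 1, extra - pre);
            
            ans += ways * abs(got - pre);
        }
    }
    
    cout << ans << "\n";
}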

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}