BINGRIDFORCE - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Authors: raysh_07 and everule1
Tester: mridulahi
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Elementary combinatorics

PROBLEM:

You have N binary strings, each of length M. The i-th is S_i, and S_{i, j} denotes its j-th character.
You can perform the following operation any number of times:

  • Choose a triple (i, j, k) such that 1 \leq i \lt N, 1 \leq j \lt k \leq M, and S_{i+1, j} = S_{i+1, k} = 1.
  • Then, swap S_{i, j} and S_{i, k}.

Find the number of distinct strings S_1 attainable via this process.
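To make the operation concrete, and to have something to check the formula against, here is a brute-force sketch (not part of the intended solution) that explores every reachable grid with a BFS and counts the distinct first rows. Rows are 0-indexed Python strings, and the name `brute_force_count` is hypothetical. It is exponential in the grid size, so it is only meant for tiny inputs.

```python
from collections import deque


def brute_force_count(rows):
    """Count distinct first rows reachable from `rows` (a list of '0'/'1' strings).

    Exhaustive BFS over whole-grid states: exponential, so only for tiny grids.
    """
    n, m = len(rows), len(rows[0])
    start = tuple(rows)
    seen = {start}
    queue = deque([start])
    while queue:
        grid = queue.popleft()
        # Operation (i, j, k), 0-indexed: swap grid[i][j] and grid[i][k]
        # when row i + 1 has a 1 in both columns. The last row is never changed.
        for i in range(n - 1):
            ones = [j for j in range(m) if grid[i + 1][j] == '1']
            for a in range(len(ones)):
                for b in range(a + 1, len(ones)):
                    j, k = ones[a], ones[b]
                    row = list(grid[i])
                    row[j], row[k] = row[k], row[j]
                    nxt = grid[:i] + (''.join(row),) + grid[i + 1:]
                    if nxt not in seen:
                        seen.add(nxt)
                        queue.append(nxt)
    # Only the first row matters for the answer.
    return len({grid[0] for grid in seen})


print(brute_force_count(["11000", "01010", "11101"]))  # 10
```

On the three-row test case discussed in the comments below, it reports 10, matching the accepted solutions.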

EXPLANATION:

Let’s start from the bottom of the grid and work our way upwards.

S_N can’t be changed at all, so there’s only 1 possibility for it.
Next, let’s look at S_{N-1}.
Note that:

  • If S_{N, i} = 0, then S_{N-1, i} can’t be moved, and is fixed.
  • On the other hand, if S_{N, i} = 1, then S_{N-1, i} can be swapped to some other position where S_N contains a 1.
    In fact, if we look at only positions where S_N contains a 1, the values at these positions in S_{N-1} can be freely rearranged since any pair among them can be swapped.
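This observation already settles the case N = 2. A minimal sketch, with a hypothetical helper `count_second_last_row` and 0-indexed columns: the free positions of S_{N-1} are the 1's of S_N, and the number of distinct versions of S_{N-1} is a binomial coefficient over the values sitting at those positions.

```python
from math import comb


def count_second_last_row(row, below):
    """Distinct versions of `row` when the fixed row `below` is the only key.

    Columns where `below` has a '1' can be freely permuted in `row`;
    every other column of `row` is fixed.
    """
    free = [j for j, c in enumerate(below) if c == '1']
    ones = sum(row[j] == '1' for j in free)
    # Choose which free columns hold the 1's.
    return comb(len(free), ones)


print(count_second_last_row("01010", "11101"))  # 4
```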

Next, look at S_{N-2}.
Once again, the elements of S_{N-2} at positions where S_{N-1} has a 1 can be freely moved around.
However, there’s more: the power to rearrange some part of S_{N-1} gives us the ability to choose which set of positions contain 1's in S_{N-1} (to some extent).

In particular, let’s define P_i to be the set of indices that can be freely rearranged in S_i.
If an index is not in P_i, that means it’s fixed.
For instance, P_N is empty (because all the elements of S_N are fixed), and P_{N-1} consists of all those indices j such that S_{N, j} = 1.

Let’s try to compute P_{N-2} given that we know P_{N-1}.

  • First, if S_{N-1} contains zero or one occurrence of 1, P_{N-2} will be empty: no swap in S_{N-2} is ever possible, since rearranging S_{N-1} cannot change how many 1's it contains.
  • Otherwise, P_{N-2} will contain, at the very least, all those indices that are 1 in S_{N-1}.
    However, as noted earlier, it might be larger: some of those ones in S_{N-1} can be moved to different indices, depending on the contents of P_{N-1}.
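  • In particular, we have the following:
    • If there exists an index j \in P_{N-1} such that S_{N-1, j} = 1, then P_{N-2} will contain all of P_{N-1}, in addition to every index where S_{N-1} has a 1.
      This is because that 1 can be moved to any index p \in P_{N-1} while another 1 of S_{N-1} (one exists, since there are at least two) stays at some position b, so S_{N-2} can swap positions b and p for every p \neq b. Together with the swaps between fixed 1's, these transpositions connect every index of the union, and transpositions forming a connected graph generate all permutations of it.
    • If no such j \in P_{N-1} exists, every 1 of S_{N-1} is stuck in place, so P_{N-2} consists of exactly the indices where S_{N-1} has a 1.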

So, knowing P_{N-1} and S_{N-1}, it's quite easy to compute P_{N-2} in \mathcal{O}(M) time: we only need to count the 1's in S_{N-1} and check whether any of them lies in P_{N-1}.
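The transition rule can be written as a small helper. This is a sketch with a hypothetical name, `next_free`, using Python sets of 0-indexed columns; it is not taken from any of the reference solutions below.

```python
def next_free(free, row):
    """Free set of the row above `row`, given the free set of `row` itself.

    free: set of 0-indexed columns that are freely permutable in `row`.
    row:  contents of that row as a '0'/'1' string.
    """
    ones = {j for j, c in enumerate(row) if c == '1'}
    if len(ones) < 2:   # no pair of 1's, so no swap is possible in the row above
        return set()
    if ones & free:     # some 1 is movable: all of `free` joins the 1's columns
        return ones | free
    return ones         # every 1 is stuck, so only those columns can be swapped


print(sorted(next_free({0, 1, 2, 4}, "01010")))  # [0, 1, 2, 3, 4]
```

Calling it with `free = set()` on S_N reproduces P_{N-1}, except that a single 1 in S_N gives an empty set instead of a singleton; the two are equivalent, since one free column cannot be rearranged.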

Simply repeat this process for every row, moving upwards (P_i is obtained from P_{i+1} and S_{i+1} by exactly the same rules), until you finally compute P_1: the set of free positions in S_1.
Now, we know that:

  • Indices not in P_1 have their values fixed.
  • Indices in P_1 can have their values freely shuffled around.
    If S_1 has x zeros and y ones at the indices of P_1, we can thus obtain \binom{x+y}{x} distinct strings, by choosing which of these positions the zeros occur at.
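Putting the pieces together, the whole procedure fits in a few lines. This sketch (hypothetical name `count_first_rows`) starts from P_N = \emptyset, applies the transition rule once per row from the bottom up, and returns the exact binomial coefficient using Python's arbitrary-precision integers, without any modulus.

```python
from math import comb


def count_first_rows(rows):
    """Number of distinct S_1 reachable (exact integer, no modulus)."""
    free = set()  # P_N: the last row is fixed
    # rows[i] is S_{i+1} (0-indexed list); each step turns P_{i+1} into P_i.
    for i in range(len(rows) - 1, 0, -1):
        ones = {j for j, c in enumerate(rows[i]) if c == '1'}
        if len(ones) < 2:
            free = set()
        elif ones & free:
            free = ones | free
        else:
            free = ones
    # `free` is now P_1: count the zeros (x) and ones (y) of S_1 there.
    x = sum(rows[0][j] == '0' for j in free)
    y = sum(rows[0][j] == '1' for j in free)
    return comb(x + y, x)


print(count_first_rows(["11000", "01010", "11101"]))  # 10
```

Each row is scanned a constant number of times, which matches the \mathcal{O}(N \cdot M) bound stated below.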

The answer is computed modulo 10^9 + 7, so the division inside the binomial coefficient is done with modular multiplicative inverses: since 10^9 + 7 is prime, Fermat's little theorem gives the inverse of a as a^{10^9 + 5} modulo 10^9 + 7.
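A sketch of the modular binomial coefficient, assuming the modulus 10^9 + 7 used by all three reference solutions; the helper name `binom_mod` is hypothetical. Instead of inverting each factor separately, it accumulates the falling product and r!, then performs a single modular exponentiation at the end.

```python
MOD = 10**9 + 7  # prime modulus used by the reference solutions


def binom_mod(n, r, mod=MOD):
    """C(n, r) modulo a prime `mod`, assuming 0 <= n < mod."""
    if r < 0 or r > n:
        return 0
    num = den = 1
    for i in range(r):
        num = num * (n - i) % mod  # n * (n-1) * ... * (n-r+1)
        den = den * (i + 1) % mod  # r!
    # Fermat's little theorem: den^(mod-2) is the inverse of den modulo a prime.
    return num * pow(den, mod - 2, mod) % mod


print(binom_mod(5, 2))  # 10
```

If many binomial coefficients were needed, precomputing factorials and inverse factorials would be the usual choice; here only one is needed per test case.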

TIME COMPLEXITY:

\mathcal{O}(N\cdot M) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long

const int mod = 1e9 + 7;

// Computes x^y mod `mod` by recursive binary exponentiation.
int power(int x, int y){
    if (y==0) return 1;
    
    int v = power(x, y/2);
    v *= v; v %= mod;
    
    if (y & 1) return v * x % mod;
    else return v;
}

void Solve() 
{
    int n, m; cin >> n >> m;
    
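    vector<string> s(n);
    for (int i = 0; i < n; i++) cin >> s[i];
    
    // dp[i][j] is true when column j of s[i - 1] is freely permutable,
    // so dp[i] is the set P_i from the editorial (rows 1-indexed there).
    vector<vector<bool>> dp(n, vector<bool>(m, 0));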
    
    if (n == 1){
        cout << 1 << '\n';
        return;
    }
    
    // dp[n - 1] = P_{N-1}: the columns where the last row has a 1.
    for (int i = 0; i < m; i++){
        if (s[n -1][i] == '1') dp[n - 1][i] = true;
    }
    
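    // Move upwards: dp[i] (P_i) is built from dp[i + 1] (P_{i+1}) and s[i] (S_{i+1}).
    for (int i = n - 2; i > 0; i--){
        int ones = 0;
        for (int j = 0; j < m; j++){
            if (s[i][j] == '1') ones++;
        }
        
        // Fewer than two 1s in s[i]: no swap is ever possible in s[i - 1], so P_i stays empty.
        if (ones < 2) continue;
        
        // Does some 1 of s[i] lie in a freely permutable column?
        bool inter = false;
        for (int j = 0; j < m; j++){
            if (s[i][j] == '1' && dp[i + 1][j]) inter = true;
        }
        
        // If so, P_i = (1s of s[i]) union P_{i+1}; otherwise P_i = 1s of s[i] only.
        for (int j = 0; j < m; j++){
            if (inter && dp[i + 1][j]) dp[i][j] = true;
            else if (s[i][j] == '1') dp[i][j] = true;
        }
    }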
    
    // Count zeros (c0) and ones (c1) of the first row inside P_1.
    int c0 = 0, c1 = 0;
    
    for (int i = 0; i < m; i++){
        if (dp[1][i]) {
            if (s[0][i] == '1') c1++;
            else c0++;
        }
    }
    
    // ans = (c0 + c1)! / (c0! * c1!), dividing via Fermat inverses.
    int ans = 1;
    
    for (int i = 1; i <= c0 + c1; i++) ans *= i, ans %= mod;
    for (int i = 1; i <= c0; i++) ans *= power(i, mod - 2), ans %= mod;
    for (int i = 1; i <= c1; i++) ans *= power(i, mod - 2), ans %= mod;
    
    cout << ans << "\n";
}

int32_t main() 
{
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        Solve();
    }
    return 0;
}
Tester's code (C++)
#include<bits/stdc++.h>

using namespace std;

#define all(x) begin(x), end(x)
#define sz(x) static_cast<int>((x).size())
#define int long long

const int mod = 1e9 + 7;

template<class T>
T power (T a, int b) {
        T res = 1;
        for (; b; b /= 2, a = a * a % mod) {
                if (b % 2) res = res * a % mod;
        }
        return res;
}


signed main() {

        ios::sync_with_stdio(0);
        cin.tie(0);
        cout.tie(0);

        int t;
        cin >> t;

        while (t--) {

                int n, m;
                cin >> n >> m;
                vector<string> s(n);
                for (int i = 0; i < n; i++) cin >> s[i];

                if (n == 1) {
                        cout << 1 << "\n";
                        continue;
                }
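                // v = free columns of the row below the current one (starts as P_{N-1}).
                vector<int> v(m, 0);
                for (int i = 0; i < m; i++) {
                        if (s[n - 1][i] == '1') v[i] = 1;
                }

                // v2 is scratch space for the next free set.
                vector<int> v2(m, 0);
                for (int i = n - 2; i > 0; i--) {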
                        bool intersect = 0;
                        int cn = 0;
                        for (int j = 0; j < m; j++) {
                                if (s[i][j] == '1') {
                                        v2[j] = 1;
                                        cn++;
                                        if (v[j]) intersect = 1;
                                }
                        }
                        // A movable 1 and at least two 1s: the previous free set joins in.
                        if (intersect && cn > 1) {
                                for (int j = 0; j < m; j++) v2[j] |= v[j];
                        }
                        swap(v, v2);
                        fill(all(v2), 0ll);
                }

                int c0 = 0, c1 = 0;
                for (int i = 0; i < m; i++) {
                        if (v[i] && s[0][i] == '0') c0++;
                        if (v[i] && s[0][i] == '1') c1++; 
                }
                // ans = C(c0 + c1, c0) = ((c1 + 1) * ... * (c0 + c1)) / c0!
                int ans = 1;
                for (int i = c1 + 1; i <= c0 + c1; i++) ans = ans * i % mod;
                for (int i = 2; i <= c0; i++) ans = ans * power(i, mod - 2) % mod; 

                cout << ans << "\n";

        }

        
}
Editorialist's code (Python)
mod = 10**9 + 7
for _ in range(int(input())):
    n, m = map(int, input().split())
    L = [input() for _ in range(n)]
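    # mark = free columns of the row being processed; starts as P_N (empty)
    mark = [0]*m
    for i in reversed(range(1, n)):
        keep = False  # does some 1 of row i lie in a free column?
        marknew = [0]*m
        for j in range(m):
            if L[i][j] == '0': continue
            if mark[j]: keep = True
            marknew[j] = 1
        mark, marknew = marknew, mark  # mark = 1s of row i, marknew = previous free set
        if keep and sum(mark) > 1:     # a movable 1 plus at least two 1s: merge the sets
            for j in range(m): mark[j] |= marknew[j]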
    zeros, ones = 0, 0
    for i in range(m):
        if mark[i]:
            zeros += L[0][i] == '0'
            ones += L[0][i] == '1'
    # ans = C(zeros + ones, zeros) mod p, built one factor at a time
    ans = 1
    for i in range(zeros):
        ans = ans*(zeros + ones - i) % mod
        ans = ans*pow(i+1, mod-2, mod) % mod
    print(ans)

Can someone explain to me the answer for this test case:
1
3 5
11000
01010
11101

The answer given by the AC code is 10.
But I think the answer should be 3. Let me explain:
There is only one possible string L3: 11101
There are four possible strings L2: 10010, 01010, 00110, 00011
There are three possible strings L1: 11000, 01010, 10010.

I would greatly appreciate the assistance if someone could guide me in obtaining the remaining L1 strings.

Let's consider L1: 01010.
For it, L2 is 10010, which can be changed to 00110, right?
With L2 = 00110, L1 can be converted to 01100, and after changing L2 back to 01010, L1 can be converted to 00110.
You missed these kinds of cases: after modifying L1 based on L2 you stopped, but L2 can later be changed, and then L1 can be changed further.
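To make the reply's point concrete, here is a small replay of one legal swap sequence. It uses a hypothetical helper `swap` that follows the statement's 1-indexed (i, j, k) convention and asserts that each operation is allowed. Starting from L1 = 01010 and L2 = 10010, it changes L2, then L1, then L2 again, then L1, ending at L1 = 00110.

```python
def swap(rows, i, j, k):
    """Apply operation (i, j, k) (1-indexed) in place, checking it is legal."""
    below = rows[i]  # S_{i+1} is rows[i] in the 0-indexed list
    assert below[j - 1] == below[k - 1] == '1', "illegal swap"
    row = list(rows[i - 1])
    row[j - 1], row[k - 1] = row[k - 1], row[j - 1]
    rows[i - 1] = ''.join(row)


rows = ["01010", "10010", "11101"]  # L1 = 01010, L2 = 10010, L3 = 11101
swap(rows, 2, 1, 3)  # L3 has 1s at columns 1, 3: L2 10010 -> 00110
swap(rows, 1, 3, 4)  # L2 has 1s at columns 3, 4: L1 01010 -> 01100
swap(rows, 2, 2, 3)  # L3 has 1s at columns 2, 3: L2 00110 -> 01010
swap(rows, 1, 2, 4)  # L2 has 1s at columns 2, 4: L1 01100 -> 00110
print(rows[0])  # 00110
```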

yeah, got it!
Thank you for explaining.