PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Authors: raysh_07 and everule1
Tester: mridulahi
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
Elementary combinatorics
PROBLEM:
You have N binary strings, each of length M. The i-th is S_i.
You can do the following:
- Choose a triple (i, j, k) such that 1 \leq i \lt N, 1 \leq j \lt k \leq M, and S_{i+1, j} = S_{i+1, k} = 1.
- Then, swap S_{i, j} and S_{i, k}.
Find the number of distinct strings S_1 attainable via this process.
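Before the analysis, an exhaustive checker for tiny inputs can be useful. The sketch below is an illustration and not part of any reference solution. `brute_force` is a hypothetical helper that takes the rows as 0-indexed Python strings (`rows[0]` is S_1). It runs a BFS over whole-grid states, applying every legal swap, and counts the distinct first rows it reaches. The state space is exponential, so it is only meant for cross-checking very small cases.

```python
from collections import deque


def brute_force(rows):
    """Count distinct first rows reachable from `rows` (rows[0] is S_1).

    Hypothetical helper: exhaustive BFS over whole-grid states, so it is
    only practical for very small N and M.
    """
    n, m = len(rows), len(rows[0])
    start = tuple(rows)
    seen = {start}
    queue = deque([start])
    while queue:
        grid = queue.popleft()
        for i in range(n - 1):  # row i is modified, row i + 1 decides which swaps are legal
            below = grid[i + 1]
            for j in range(m):
                for k in range(j + 1, m):
                    if below[j] == '1' and below[k] == '1':
                        row = list(grid[i])
                        row[j], row[k] = row[k], row[j]
                        nxt = grid[:i] + (''.join(row),) + grid[i + 1:]
                        if nxt not in seen:
                            seen.add(nxt)
                            queue.append(nxt)
    return len({grid[0] for grid in seen})


print(brute_force(["100", "011", "110"]))  # 3
```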
EXPLANATION:
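Think of the strings as the rows of an N \times M grid, with S_1 at the top and S_N at the bottom. Let’s start from the bottom of the grid and work our way upwards.
S_N can’t be changed at all (no operation modifies it, since i \lt N), so there’s only 1 possibility for it.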
Next, let’s look at S_{N-1}.
Note that:
- If S_{N, i} = 0, then S_{N-1, i} can’t be moved, and is fixed.
- On the other hand, if S_{N, i} = 1, then S_{N-1, i} can be swapped to some other position where S_N contains a 1.
In fact, if we look at only positions where S_N contains a 1, the values at these positions in S_{N-1} can be freely rearranged since any pair among them can be swapped.
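For N = 2 this observation already settles the problem. Below is a minimal sketch, assuming 0-indexed strings. `count_two_rows` is a hypothetical name. It collects the columns where the lower row has a 1 and counts the arrangements of the upper row's values on those columns.

```python
from math import comb


def count_two_rows(upper, lower):
    """Distinct versions of `upper` reachable when `lower` is the row below it.

    Hypothetical helper for the bottom two rows (S_{N-1}, S_N); for N = 2
    this is the answer to the whole problem.
    """
    # Columns where the lower row has a 1: any two of them may be swapped.
    free = [j for j, c in enumerate(lower) if c == '1']
    ones = sum(upper[j] == '1' for j in free)
    # Every arrangement of upper's values on `free` is reachable.
    return comb(len(free), ones)


print(count_two_rows("1010", "1110"))  # 3
```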
Next, look at S_{N-2}.
Once again, the elements at positions that are 1's in S_{N-1} can be freely moved around.
However, there’s more: the power to rearrange some part of S_{N-1} gives us the ability to choose which set of positions contain 1's in S_{N-1} (to some extent).
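Here is a small illustration of this effect using hypothetical helpers `reachable_rows` and `possible_one_positions`, which are not from the reference solutions and use 0-indexed columns. With S_N = 110, the first two columns of S_{N-1} are free. If S_{N-1} = 011, its 1 in column 1 can move to column 0, so every column holds a 1 in some reachable version of S_{N-1}. Enumerating permutations is exponential, so this is only for building intuition.

```python
from itertools import permutations


def reachable_rows(row, free):
    """All strings obtained from `row` by permuting its values on columns `free`."""
    cols = sorted(free)
    values = [row[j] for j in cols]
    result = set()
    for perm in set(permutations(values)):
        chars = list(row)
        for j, c in zip(cols, perm):
            chars[j] = c
        result.add(''.join(chars))
    return result


def possible_one_positions(row, free):
    """Columns that hold a 1 in at least one reachable version of `row`."""
    return {j for r in reachable_rows(row, free) for j, c in enumerate(r) if c == '1'}


# S_N = "110" frees columns {0, 1} of S_{N-1} = "011".
print(sorted(reachable_rows("011", {0, 1})))          # ['011', '101']
print(sorted(possible_one_positions("011", {0, 1})))  # [0, 1, 2]
```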
In particular, let’s define P_i to be the set of indices that can be freely rearranged in S_i.
If an index is not in P_i, that means it’s fixed.
For instance, P_N is empty (because all the elements of S_N are fixed), and P_{N-1} consists of all those indices i such that S_{N, i} = 1.
Let’s try to compute P_{N-2} given that we know P_{N-1}.
- First, if S_{N-1} contains zero or one occurrence of 1, P_{N-2} will be empty. Swaps never change the number of ones in S_{N-1}, so no operation with i = N-2 is ever possible, and S_{N-2} can’t be rearranged at all.
- Otherwise, P_{N-2} will contain, at the very least, all those indices that are 1 in S_{N-1}.
  However, as noted earlier, it might be larger: some of those ones in S_{N-1} can be moved to different indices, depending on the contents of P_{N-1}. In particular, we have the following:
  - If there exists an index i \in P_{N-1} such that S_{N-1, i} = 1, then P_{N-2} will contain all of P_{N-1} (in addition to all indices that contain ones in S_{N-1}).
    This is because this 1 can be shuffled around among the indices of P_{N-1}. With at least two ones available, the pairs of positions that can be made to hold a 1 simultaneously link all of these indices together, and swapping along such linked pairs is enough to realize any rearrangement.
  - If no such i \in P_{N-1} exists, P_{N-2} will consist of only those indices that contain ones in S_{N-1}.
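The case analysis above translates directly into a single transition step. The following sketch assumes that columns are 0-indexed and that each P_i is stored as a Python set. `next_free` is a hypothetical name.

```python
def next_free(free_below, row):
    """Compute P_{i-1} from P_i (`free_below`) and S_i (`row`); columns are 0-indexed.

    Hypothetical helper following the case analysis above.
    """
    ones = {j for j, c in enumerate(row) if c == '1'}
    if len(ones) < 2:
        # Swaps never change how many 1s `row` has, so the row above can never be touched.
        return set()
    if ones & free_below:
        # A 1 sitting on a free column can be moved to any free column.
        return ones | free_below
    return ones


print(sorted(next_free({0, 1}, "011")))  # [0, 1, 2]
print(sorted(next_free({0}, "011")))     # [1, 2]
```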
So, knowing P_{N-1} and S_{N-1}, it’s quite easy to compute P_{N-2} in \mathcal{O}(M) time. All we need is the number of ones in S_{N-1}, and whether there’s any “intersection” between P_{N-1} and the ones in S_{N-1}.
Simply repeat this process, computing each P_{i-1} from P_i and S_i, until you finally compute P_1: the set of free positions in S_1.
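Putting the steps together, here is a minimal sketch of the upward pass. It uses a hypothetical function name, 0-indexed columns, and `rows[0]` = S_1. Like the reference solutions, it walks the rows S_{N-1}, S_{N-2}, \ldots, S_2 in that order.

```python
def free_positions_of_first_row(rows):
    """Return P_1 as a set of 0-indexed columns; rows[0] is S_1, rows[-1] is S_N.

    Hypothetical helper; it inlines the transition rule described above.
    """
    n = len(rows)
    if n == 1:
        return set()  # no operation is possible at all
    # P_{N-1}: the columns where S_N has a 1.
    free = {j for j, c in enumerate(rows[-1]) if c == '1'}
    # rows[i] is S_{i+1}; combine it with P_{i+1} to get P_i, for i = N-2 down to 1.
    for i in range(n - 2, 0, -1):
        ones = {j for j, c in enumerate(rows[i]) if c == '1'}
        if len(ones) < 2:
            free = set()
        elif ones & free:
            free = ones | free
        else:
            free = ones
    return free


print(sorted(free_positions_of_first_row(["100", "011", "110"])))  # [0, 1, 2]
```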
Now, we know that:
- Indices not in P_1 have their values fixed.
- Indices in P_1 can have their values freely shuffled around.
If the positions in P_1 hold x zeros and y ones of S_1, we can thus obtain \binom{x+y}{x} distinct strings, by choosing which of these positions hold the zeros.
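Here is a sketch of this final count. It uses a hypothetical helper `count_first_rows` with 0-indexed columns, and computes the exact value with Python's `math.comb`. The reference solutions reduce the same quantity modulo 10^9 + 7.

```python
from math import comb


def count_first_rows(first_row, free):
    """Number of distinct S_1, given S_1 and its free columns P_1 (0-indexed).

    Columns outside `free` are fixed; the x zeros and y ones inside it can be
    placed in any order, giving C(x + y, x) strings.
    """
    x = sum(first_row[j] == '0' for j in free)
    y = len(free) - x
    return comb(x + y, x)


print(count_first_rows("1100", {0, 1, 2, 3}))  # 6
print(count_first_rows("1100", {0, 2}))        # 2
```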
Computing the final binomial coefficient modulo 10^9 + 7 requires division, which can be done with the help of multiplicative inverses. Since the modulus is prime, Fermat's little theorem gives the inverse of a as a^{10^9 + 5}.
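One way to do this is sketched below with a hypothetical helper `binom_mod` and the modulus 10^9 + 7 used by the reference solutions. It accumulates the numerator and denominator separately and divides once, using den^{mod-2} as the inverse of den. The reference codes instead multiply in one inverse per factor, and both approaches give the same result.

```python
MOD = 10**9 + 7  # modulus used by the reference solutions


def binom_mod(a, b, mod=MOD):
    """C(a, b) modulo a prime `mod`, dividing via a Fermat inverse.

    Hypothetical helper; assumes 0 <= a < mod, so no factor 1..b is divisible by mod.
    """
    if b < 0 or b > a:
        return 0
    num, den = 1, 1
    for i in range(b):
        num = num * (a - i) % mod  # a * (a-1) * ... * (a-b+1)
        den = den * (i + 1) % mod  # b!
    # mod is prime, so den^(mod-2) is the inverse of den (Fermat's little theorem).
    return num * pow(den, mod - 2, mod) % mod


print(binom_mod(4, 2))  # 6
```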
TIME COMPLEXITY:
\mathcal{O}(N\cdot M) per testcase.
CODE:
Author's code (C++)
```cpp
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18
#define f first
#define s second
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());
const int mod = 1e9 + 7;

int power(int x, int y){
    if (y==0) return 1;
    int v = power(x, y/2);
    v *= v; v %= mod;
    if (y & 1) return v * x % mod;
    else return v;
}

void Solve()
{
    int n, m; cin >> n >> m;
    string s[n];
    for (int i = 0; i < n; i++) cin >> s[i];
    vector<vector<bool>> dp(n, vector<bool>(m, 0));
    if (n == 1){
        cout << 1 << '\n';
        return;
    }
    for (int i = 0; i < m; i++){
        if (s[n -1][i] == '1') dp[n - 1][i] = true;
    }
    for (int i = n - 2; i > 0; i--){
        int ones = 0;
        for (int j = 0; j < m; j++){
            if (s[i][j] == '1') ones++;
        }
        if (ones < 2) continue;
        bool inter = false;
        for (int j = 0; j < m; j++){
            if (s[i][j] == '1' && dp[i + 1][j]) inter = true;
        }
        for (int j = 0; j < m; j++){
            if (inter && dp[i + 1][j]) dp[i][j] = true;
            else if (s[i][j] == '1') dp[i][j] = true;
        }
    }
    int c0 = 0, c1 = 0;
    for (int i = 0; i < m; i++){
        if (dp[1][i]) {
            if (s[0][i] == '1') c1++;
            else c0++;
        }
    }
    // cout << c0 << " " << c1 << "\n";
    int ans = 1;
    for (int i = 1; i <= c0 + c1; i++) ans *= i, ans %= mod;
    for (int i = 1; i <= c0; i++) ans *= power(i, mod - 2), ans %= mod;
    for (int i = 1; i <= c1; i++) ans *= power(i, mod - 2), ans %= mod;
    cout << ans << "\n";
}

int32_t main()
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in", "r", stdin);
    // freopen("out", "w", stdout);
    cin >> t;
    for(int i = 1; i <= t; i++)
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    // cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n";
    return 0;
}
```
Tester's code (C++)
```cpp
#include<bits/stdc++.h>
using namespace std;
#define all(x) begin(x), end(x)
#define sz(x) static_cast<int>((x).size())
#define int long long
const int mod = 1e9 + 7;

template<class T>
T power (T a, int b) {
    T res = 1;
    for (; b; b /= 2, a = a * a % mod) {
        if (b % 2) res = res * a % mod;
    }
    return res;
}

signed main() {
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    int t;
    cin >> t;
    while (t--) {
        int n, m;
        cin >> n >> m;
        string s[n];
        for (int i = 0; i < n; i++) cin >> s[i];
        if (n == 1) {
            cout << 1 << "\n";
            continue;
        }
        vector<int> v(m, 0);
        for (int i = 0; i < m; i++) {
            if (s[n - 1][i] == '1') v[i] = 1;
        }
        vector<int> v2(m, 0);
        for (int i = n - 2; i > 0; i--) {
            bool intersect = 0;
            int cn = 0;
            for (int j = 0; j < m; j++) {
                if (s[i][j] == '1') {
                    v2[j] = 1;
                    cn++;
                    if (v[j]) intersect = 1;
                }
            }
            if (intersect && cn > 1) {
                for (int j = 0; j < m; j++) v2[j] |= v[j];
            }
            swap(v, v2);
            fill(all(v2), 0ll);
        }
        int c0 = 0, c1 = 0;
        for (int i = 0; i < m; i++) {
            if (v[i] && s[0][i] == '0') c0++;
            if (v[i] && s[0][i] == '1') c1++;
        }
        int ans = 1;
        for (int i = c1 + 1; i <= c0 + c1; i++) ans = ans * i % mod;
        for (int i = 2; i <= c0; i++) ans = ans * power(i, mod - 2) % mod;
        cout << ans << "\n";
    }
}
```
Editorialist's code (Python)
```python
mod = 10**9 + 7
for _ in range(int(input())):
    n, m = map(int, input().split())
    L = [input() for _ in range(n)]
    L.append('0'*m)
    mark = [0]*m
    for i in reversed(range(1, n)):
        keep = False
        marknew = [0]*m
        for j in range(m):
            if L[i][j] == '0': continue
            if mark[j]: keep = True
            marknew[j] = 1
        mark, marknew = marknew, mark
        if keep and sum(mark) > 1:
            for j in range(m): mark[j] |= marknew[j]
    zeros, ones = 0, 0
    for i in range(m):
        if mark[i]:
            zeros += L[0][i] == '0'
            ones += L[0][i] == '1'
    ans = 1
    for i in range(zeros):
        ans = ans*(zeros + ones - i) % mod
        ans = ans*pow(i+1, mod-2, mod) % mod
    print(ans)
```