PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: bernarb01
Tester: mexomerf
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
DP
PROBLEM:
Count the number of N\times M grids containing the characters R, G, B that contain at least one occurrence of RGB as a contiguous sequence, either horizontally or vertically, forwards or backwards.
In this version, N \leq 6.
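For sanity-checking any DP on tiny inputs, a brute force that enumerates all 3^{NM} grids is handy. This is only a sketch, not part of the official solutions. It assumes the color encoding 0 = R, 1 = G, 2 = B and NM \leq 12 so the enumeration stays small, and `isRGB`/`bruteCount` are hypothetical names.

```cpp
#include <cassert>
#include <vector>

// Colors: 0 = R, 1 = G, 2 = B (assumed encoding).
// True if the triple (a, b, c) reads RGB forwards or backwards.
bool isRGB(int a, int b, int c) {
    return b == 1 && ((a == 0 && c == 2) || (a == 2 && c == 0));
}

// Hypothetical brute force: counts n x m grids containing RGB/BGR in some
// row or column by trying all 3^(n*m) fillings. Only for tiny grids.
long long bruteCount(int n, int m) {
    assert(n >= 1 && m >= 1 && n * m <= 12);
    const int cells = n * m;
    long long total = 1;
    for (int k = 0; k < cells; ++k) total *= 3;
    std::vector<int> g(cells);  // g[r * m + c] = color of row r, column c
    long long good = 0;
    for (long long code = 0; code < total; ++code) {
        long long x = code;
        for (int k = 0; k < cells; ++k) { g[k] = static_cast<int>(x % 3); x /= 3; }
        bool found = false;
        for (int r = 0; r < n && !found; ++r)          // horizontal triples
            for (int c = 0; c + 2 < m && !found; ++c)
                found = isRGB(g[r * m + c], g[r * m + c + 1], g[r * m + c + 2]);
        for (int c = 0; c < m && !found; ++c)          // vertical triples
            for (int r = 0; r + 2 < n && !found; ++r)
                found = isRGB(g[r * m + c], g[(r + 1) * m + c], g[(r + 2) * m + c]);
        if (found) ++good;
    }
    return good;
}
```

For example, a 1\times 3 grid has exactly two valid fillings (RGB and BGR), and bruteCount(2, 3) is 104 = 3^6 - 25^2, since each row of length 3 avoids the pattern in 25 of its 27 fillings.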
EXPLANATION:
This solution continues from the editorial for the easy version.
Now that N \leq 6, our initial DP is too slow.
Unlike in the medium version, we switch to a slightly different DP and optimize that one instead.
Rather than going column by column, let’s fill in the grid one cell at a time, in a specific order.
We’ll first fill in the first column, top to bottom.
Then the second column, top to bottom, followed by the third, and so on.
Suppose we’re currently trying to fill in cell (i, j), where i is the column and j is the row (this matches the loop order in the editorialist’s code below). Let’s look at the information we need to know.
- The values in cells (i-1, j) and (i-2, j) (if they exist), to determine whether a horizontal RGB (forwards or backwards) would be formed.
- The values in cells (i, j-1) and (i, j-2) (if they exist), to determine whether a vertical RGB (forwards or backwards) would be formed.
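The per-cell check described above can be sketched as follows. This is an illustration under assumptions: the grid is stored as grid[column][row], so that i indexes the column and j the row, colors are encoded as 0 = R, 1 = G, 2 = B, and `formsRGB`/`canPlace` are hypothetical helper names.

```cpp
#include <cassert>
#include <vector>

// Colors: 0 = R, 1 = G, 2 = B (assumed encoding).
// True if (a, b, c), read in placement order, is RGB or BGR.
bool formsRGB(int a, int b, int c) {
    return b == 1 && ((a == 0 && c == 2) || (a == 2 && c == 0));
}

// Hypothetical helper: can color c go at column i, row j?
// Only cells filled earlier in column-major order are read.
bool canPlace(const std::vector<std::vector<int>>& grid, int i, int j, int c) {
    assert(0 <= c && c < 3);
    // Horizontal: the same row in the two previous columns.
    if (i >= 2 && formsRGB(grid[i - 2][j], grid[i - 1][j], c)) return false;
    // Vertical: the two cells directly above in the current column.
    if (j >= 2 && formsRGB(grid[i][j - 2], grid[i][j - 1], c)) return false;
    return true;
}
```

Each call inspects at most four cells, which is why every transition is constant-time.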
One straightforward way to maintain this information is to store the last two columns in full, along with the values placed in the current column so far.
Our transitions are now constant-time: we only need to check whether each of R, G, B can be placed at (i, j) without completing an RGB.
However, there are now about 3^{3N} possible states for each cell (two full columns plus part of the current one), so our overall time complexity hasn’t really improved.
We can improve this, though.
Note that when we’re placing (i, j), the cells (i-2, x) for x \lt j no longer matter: every horizontal or vertical triple containing such a cell ends at a cell that has already been filled, so they cannot affect any future placement.
So, instead of storing the entire (i-2)-th column, we store it only partially. Specifically, we store:
- All previously placed elements of the current column, of which there are j-1.
- The entire (i-1)-th column, for N elements.
- The elements of the (i-2)-th column at or after row j, of which there are N-j+1.
This brings us down to a total of just 2N elements that need to be stored!
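As a quick check of the count (an illustrative calculation, with j 1-indexed as in the list above), together with the resulting number of masks for N = 6:

```latex
\underbrace{(j-1)}_{\text{column } i} + \underbrace{N}_{\text{column } i-1} + \underbrace{(N-j+1)}_{\text{column } i-2} = 2N,
\qquad
3^{2N}\Big|_{N=6} = 3^{12} = 531\,441
\quad\text{vs.}\quad
3^{3N}\Big|_{N=6} = 3^{18} = 387\,420\,489.
```

So for N = 6 the number of states per cell drops by a factor of 3^{6} = 729.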
While the optimization might seem simple, it does significantly affect complexity: we’re now down to \mathcal{O}(NM\cdot 3^{2N}) overall, which is fast enough to get AC for the given constraints.
Now, as for implementation: recall that we’re placing elements in a fixed order, top to bottom in each column.
Under this order, rather than tracking which elements of each column we’re storing, we can simply store the last 2N placed elements!
This is exactly the information we want: if cells are numbered in placement order, (i, j) gets index N(i-1)+j, so stepping back N places lands on (i-1, j) and stepping back 2N places lands on (i-2, j), while stepping back 1 or 2 places gives (i, j-1) and (i, j-2) whenever those exist. This makes for a rather simple final implementation.
It’s also quite simple to make this implementation use only \mathcal{O}(3^{2N}) memory by keeping just two layers of the DP (the current cell and the next one), which might help if your code is slow.
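A compact restatement of the rolling-mask DP as a function, so it can be checked against small hand-computed cases. It follows the same idea as the editorialist’s code below but is a sketch under assumptions: colors are 0 = R, 1 = G, 2 = B, digit k of a mask (least significant first) is the cell placed k+1 steps ago, and `formsRGB`/`countGrids` are hypothetical names. Only two layers of size 3^{2N} are kept.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// True if (a, b, c), read in placement order, is RGB or BGR (0 = R, 1 = G, 2 = B).
static bool formsRGB(int a, int b, int c) {
    return b == 1 && ((a == 0 && c == 2) || (a == 2 && c == 0));
}

// Sketch: number of n x m grids that contain RGB/BGR, modulo mod.
long long countGrids(int n, int m, long long mod) {
    assert(n >= 1 && m >= 1 && mod >= 1);
    std::vector<long long> pw(2 * n + 1, 1);
    for (int k = 1; k <= 2 * n; ++k) pw[k] = pw[k - 1] * 3;
    const long long SZ = pw[2 * n];
    std::vector<long long> cur(SZ, 0), nxt(SZ, 0);  // two layers: O(3^(2n)) memory
    cur[0] = 1 % mod;           // padding digits are never read by the checks
    long long total = 1 % mod;  // becomes 3^(nm) mod mod
    for (int i = 0; i < m; ++i) {          // column
        for (int j = 0; j < n; ++j) {      // row
            total = total * 3 % mod;
            std::fill(nxt.begin(), nxt.end(), 0);
            for (long long mask = 0; mask < SZ; ++mask) {
                if (cur[mask] == 0) continue;
                // Each value is only meaningful when its check below applies.
                int up1 = mask % 3;                    // (i, j-1)
                int up2 = mask / 3 % 3;                // (i, j-2)
                int left1 = mask / pw[n - 1] % 3;      // (i-1, j)
                int left2 = mask / pw[2 * n - 1] % 3;  // (i-2, j)
                for (int c = 0; c < 3; ++c) {
                    if (j >= 2 && formsRGB(up2, up1, c)) continue;      // vertical
                    if (i >= 2 && formsRGB(left2, left1, c)) continue;  // horizontal
                    long long nmask = (mask * 3 + c) % SZ;  // shift in c, drop oldest
                    nxt[nmask] = (nxt[nmask] + cur[mask]) % mod;
                }
            }
            cur.swap(nxt);
        }
    }
    long long avoid = 0;  // grids with no occurrence at all
    for (long long v : cur) avoid = (avoid + v) % mod;
    return (total - avoid + mod) % mod;
}
```

Skipping zero entries is optional; it only saves time on masks that are never reached. For example, countGrids(2, 3, 1000000007) returns 104, matching the count 3^6 - 25^2.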
TIME COMPLEXITY:
\mathcal{O}(N\cdot M \cdot 3^{2N}) per testcase.
CODE:
Author's code (C++)
/**
 * author: BERNARD B.01
 **/
#include <bits/stdc++.h>
using namespace std;
#ifdef B01
#include "deb.h"
#else
#define deb(...)
#endif
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n, m, md;
    cin >> n >> m >> md;
    auto Add = [&](int& a, int b) {
        a += b;
        if (a >= md) a -= md;
    };
    auto Sub = [&](int& a, int b) {
        a -= b;
        if (a < 0) a += md;
    };
    auto Mul = [&](int a, int b) {
        return int(int64_t(a) * b % md);
    };
    // A column is a base-3 mask: digit 1 is G, digits 0 and 2 are R and B.
    vector<int> p3(n + 1);
    p3[0] = 1;
    for (int i = 1; i <= n; i++) {
        p3[i] = p3[i - 1] * 3;
    }
    auto GetDigit = [&](int mask, int i) {
        return (mask / p3[i]) % 3;
    };
    int s = p3[n];
    // valid_masks: columns with no vertical RGB/BGR inside them.
    vector<int> valid_masks;
    for (int mask = 0; mask < s; mask++) {
        bool fail = false;
        for (int i = 1; i < n - 1; i++) {
            if (GetDigit(mask, i) == 1) {
                if (GetDigit(mask, i - 1) == 0 && GetDigit(mask, i + 1) == 2) {
                    fail = true;
                    break;
                }
                if (GetDigit(mask, i - 1) == 2 && GetDigit(mask, i + 1) == 0) {
                    fail = true;
                    break;
                }
            }
        }
        if (!fail) {
            valid_masks.push_back(mask);
        }
    }
    // dp[i][mask1][mask2], while filling column j: about to place row i.
    // mask1 holds rows < i of column j and rows >= i of column j - 2
    // (zeros before the grid starts); mask2 is column j - 1.
    vector dp(n, vector(s, vector<int>(s)));
    dp[0][0][0] = 1;
    for (int j = 0; j < m; j++) {
        vector new_dp(n, vector(s, vector<int>(s)));
        for (int i = 0; i < n; i++) {
            for (int mask1 = 0; mask1 < s; mask1++) {
                for (int mask2 : valid_masks) {
                    int new_mask1 = mask1 - GetDigit(mask1, i) * p3[i];
                    for (int cur = 0; cur < 3; cur++, new_mask1 += p3[i]) {
                        // Horizontal: row i of columns j - 2, j - 1, then cur.
                        if (j > 1 && GetDigit(mask2, i) == 1) {
                            if (GetDigit(mask1, i) == 0 && cur == 2) {
                                continue;
                            }
                            if (GetDigit(mask1, i) == 2 && cur == 0) {
                                continue;
                            }
                        }
                        // Vertical: rows i - 2, i - 1 of column j, then cur.
                        if (i > 1 && GetDigit(mask1, i - 1) == 1) {
                            if (GetDigit(mask1, i - 2) == 0 && cur == 2) {
                                continue;
                            }
                            if (GetDigit(mask1, i - 2) == 2 && cur == 0) {
                                continue;
                            }
                        }
                        if (i + 1 == n) {
                            // Column done: the pair becomes (column j - 1, column j).
                            Add(new_dp[0][mask2][new_mask1], dp[i][mask1][mask2]);
                        } else {
                            Add(dp[i + 1][new_mask1][mask2], dp[i][mask1][mask2]);
                        }
                    }
                }
            }
        }
        swap(dp, new_dp);
    }
    // Answer = 3^(nm) minus the number of grids with no occurrence.
    int ans = 1;
    for (int i = 0; i < m; i++) {
        ans = Mul(ans, p3[n]);
    }
    for (int mask1 = 0; mask1 < s; mask1++) {
        for (int mask2 = 0; mask2 < s; mask2++) {
            Sub(ans, dp[0][mask1][mask2]);
        }
    }
    cout << ans << '\n';
    return 0;
}
Editorialist's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    vector<int> pow3(15, 1);
    for (int i = 1; i < 15; ++i) pow3[i] = pow3[i-1] * 3;

    int n, m, mod; cin >> n >> m >> mod;
    // Digit k of a mask = color of the cell placed k+1 steps ago (0 = R, 1 = G, 2 = B).
    const int SZ = pow3[2*n];
    array<vector<int>, 2> dp;
    dp[0] = dp[1] = vector<int>(SZ);
    dp[0][0] = 1;
    int ans = 1, id = 0;  // ans accumulates 3^(nm)
    for (int i = 0; i < m; ++i) for (int j = 0; j < n; ++j) {  // column i, row j
        ans = (ans * 3ll) % mod;
        auto &cur = dp[id];
        auto &nxt = dp[id^1];
        id ^= 1;
        for (auto &x : nxt) x = 0;
        for (int mask = 0; mask < SZ; ++mask) {
            // place g: always allowed
            int nmask = (mask*3 + 1) % SZ;
            nxt[nmask] = (nxt[nmask] + cur[mask]) % mod;

            // place r: forbidden after B, G above (mask%9 == 7) or to the left
            nmask = (mask*3) % SZ;
            bool allowed = true;
            if (j >= 2) allowed &= (mask%9) != 7;
            if (i >= 2) {
                int x = (mask / pow3[n-1]) % 3;    // cell (i-1, j)
                int y = (mask / pow3[2*n-1]) % 3;  // cell (i-2, j)
                allowed &= (x != 1) or (y != 2);
            }
            if (allowed) nxt[nmask] = (nxt[nmask] + cur[mask]) % mod;

            // place b: forbidden after R, G above (mask%9 == 1) or to the left
            nmask = (mask*3 + 2) % SZ;
            allowed = true;
            if (j >= 2) allowed &= (mask%9) != 1;
            if (i >= 2) {
                int x = (mask / pow3[n-1]) % 3;    // cell (i-1, j)
                int y = (mask / pow3[2*n-1]) % 3;  // cell (i-2, j)
                allowed &= (x != 1) or (y != 0);
            }
            if (allowed) nxt[nmask] = (nxt[nmask] + cur[mask]) % mod;
        }
    }
    // Subtract the grids with no occurrence.
    for (int mask = 0; mask < SZ; ++mask)
        ans = (ans + mod - dp[id][mask]) % mod;
    cout << ans << '\n';
}