PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: bernarb01
Tester: mexomerf
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
DP
PROBLEM:
Count the number of N\times M grids containing the characters R,G,B
that contain at least one occurrence of RGB
as a contiguous sequence, either horizontally or vertically, forwards or backwards.
In the medium version, N \leq 5.
EXPLANATION:
This solution will continue from the editorial to the easy version.
Now that N \leq 5, our initial DP is too slow.
Let’s try to reduce the amount of information we need to maintain.
In the easy version, for each column, we maintained 3^N \cdot 3^N pieces of information: the last two columns.
However, the information of some pairs among them can be combined.
Let’s look at a single row, and the last two characters in it - say c_1 and c_2 (with c_1 being in column (i-2) and c_2 in (i-1)).
There are 9 possibilities for them in total, of course. However,
- If c_2 = \text{R}, c_1 can be anything: a forbidden horizontal triple (RGB or BGR) must have G in the middle, so placing another character after an R cannot create one.
So, we can compress all of RR, GR, BR into a single state since they’ll have the same transitions anyway.
- If c_2 = \text{B}, similarly c_1 can be anything - compress everything into a single state again.
- If c_2 = \text{G}, then c_1 does matter, giving us three different states.
So, rather than 9 possible states, we really have only 5 for each row.
This gives us a total of 5^N states to care about once the column is fixed.
Once again, there are 3^N possible transitions from each state, each of which can be processed in \mathcal{O}(N) time - the details of the transitions are basically exactly the same as the easy version.
This brings our complexity to \mathcal{O}(NM\cdot 15^N), which for N \leq 5 and M \leq 50 is fast enough.
TIME COMPLEXITY:
\mathcal{O}(NM\cdot 15^N) per testcase.
CODE:
Author's code (C++)
/**
* author: BERNARD B.01
**/
#include <bits/stdc++.h>
using namespace std;
#ifdef B01
#include "deb.h"
#else
#define deb(...)
#endif
// Modulus shared by the arithmetic helpers below; main() reads it from the input.
int md;

// a <- a + b, followed by one conditional subtraction of md.
// The result is fully reduced only when both operands were already in [0, md).
inline void add(int& a, int b) {
  a += b;
  a -= (a >= md) ? md : 0;
}

// a <- a - b, followed by one conditional addition of md.
// The result is fully reduced only when both operands were already in [0, md).
inline void sub(int& a, int b) {
  a -= b;
  a += (a < 0) ? md : 0;
}

// Returns (a * b) % md; the product is formed in 64 bits before reduction.
inline int mul(int a, int b) {
  const int64_t wide = static_cast<int64_t>(a) * b;
  return static_cast<int>(wide % md);
}
// Power tables: main() fills them so that p3[k] = 3^k and p5[k] = 5^k for k in [0, 6].
int p3[7], p5[7];

// Returns the i-th base-3 digit of mask (i = 0 is the least significant digit).
// Relies on p3[] having been filled beforehand.
inline int get_digit(int mask, int i) {
  const int shifted = mask / p3[i];
  return shifted % 3;
}
int n, m;                   // read by main(): n base-3 digits (rows) per column mask, m columns
bool valid[729];            // valid[mask]: set by main() when the column has no vertical 0-1-2 / 2-1-0 run
bool can[729][729];         // not referenced anywhere in the visible file
int dp[50][15625];          // memo for sol(): dp[column][base-5 state]; main() initialises it to -1
vector<short> valid_masks;  // all masks with valid[mask] set, in increasing order

/**
 * Counts, modulo md, the ways to fill columns i..m-1 using only masks from
 * valid_masks so that no row ever shows the digit triple 0-1-2 or 2-1-0 across
 * three consecutive columns.
 *
 * @param i      index of the next column to fill; returns 1 once i == m
 * @param mask1  base-3 mask of column i-2
 * @param mask2  base-3 mask of column i-1
 * @param mask12 base-5 compression of (mask1, mask2). Per row: newest digit 0 -> 2,
 *               newest digit 2 -> 3, newest digit 1 -> 0 / 4 / 1 when the digit
 *               before it is 0 / 1 / 2.
 * @return the count reduced by add(), memoised in dp[i][mask12]
 *
 * The memo key is (i, mask12) only; mask1 and mask2 are not part of it, so the
 * caller must pass a mask12 that was derived from exactly those two masks.
 */
inline int sol(int i, int mask1, int mask2, int mask12) {
  if (i == m) {
    return 1;
  }
  int& memo = dp[i][mask12];
  if (memo != -1) {
    return memo;
  }
  memo = 0;
  // State code for a row whose newest digit is 1, indexed by the digit before it.
  static constexpr int kBeforeOne[3] = {0, 4, 1};
  for (int mask3 : valid_masks) {
    int mask23 = 0;
    bool clash = false;
    // One pass per row: detect a forbidden triple and build the next state together.
    for (int j = 0; j < n && !clash; j++) {
      const int oldest = get_digit(mask1, j);
      const int middle = get_digit(mask2, j);
      const int newest = get_digit(mask3, j);
      clash = (middle == 1) &&
              ((oldest == 0 && newest == 2) || (oldest == 2 && newest == 0));
      if (newest == 1) {
        mask23 += kBeforeOne[middle] * p5[j];
      } else {
        mask23 += (newest == 0 ? 2 : 3) * p5[j];
      }
    }
    if (!clash) {
      add(memo, sol(i + 1, mask2, mask3, mask23));
    }
  }
  return memo;
}
/**
 * Reads n, m and the modulus md from standard input and prints, modulo md, the
 * number of n-by-m base-3 grids that contain the digit run 0-1-2 or 2-1-0 at
 * least once, vertically or horizontally.
 *
 * The answer is computed by complement: (3^n)^m minus the number of grids with
 * no such run, the latter obtained from sol() over every ordered pair of valid
 * first two columns.
 *
 * Assumes md >= 1 and that n and m fit the fixed-size tables (p3[7], dp[50][...]).
 */
int main() {
  ios::sync_with_stdio(false);
  cin.tie(nullptr);
  p3[0] = p5[0] = 1;
  for (int i = 1; i < 7; i++) {
    p3[i] = p3[i - 1] * 3;
    p5[i] = p5[i - 1] * 5;
  }
  cin >> n >> m >> md;

  // A column mask is usable only if it has no vertical 0-1-2 / 2-1-0 run.
  for (int mask = 0; mask < p3[n]; mask++) {
    valid[mask] = true;
    for (int j = 0; j + 2 < n; j++) {
      const int first = get_digit(mask, j);
      const int second = get_digit(mask, j + 1);
      const int third = get_digit(mask, j + 2);
      if (second == 1 && ((first == 0 && third == 2) || (first == 2 && third == 0))) {
        valid[mask] = false;
        break;
      }
    }
    if (valid[mask]) {
      valid_masks.push_back(mask);
    }
  }

  if (m == 1) {
    // With a single column only vertical runs exist, so the answer is the number
    // of invalid masks. Reduce it modulo md so this path agrees with the general
    // one; dp is never read here, so it is left uninitialised.
    const int with_run = p3[n] - static_cast<int>(valid_masks.size());
    cout << with_run % md << '\n';
    return 0;
  }

  memset(dp, -1, sizeof dp);

  // Total number of grids: (3^n)^m mod md.
  int ans = 1;
  for (int i = 0; i < m; i++) {
    ans = mul(ans, p3[n]);
  }

  // Subtract every run-free grid, grouped by its first two columns.
  for (int mask1 : valid_masks) {
    for (int mask2 : valid_masks) {
      // Base-5 compression of (mask1, mask2), using the same per-row codes as sol().
      int mask12 = 0;
      for (int j = 0; j < n; j++) {
        const int older = get_digit(mask1, j);
        const int newer = get_digit(mask2, j);
        int code;
        if (newer == 1) {
          code = (older == 2) ? 1 : (older == 1 ? 4 : 0);
        } else {
          code = (newer == 0) ? 2 : 3;
        }
        mask12 += code * p5[j];
      }
      sub(ans, sol(2, mask1, mask2, mask12));
    }
  }
  cout << ans << '\n';
  return 0;
}