RGBGRID_SUB1 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: bernarb01
Tester: mexomerf
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

DP

PROBLEM:

Count the number of N\times M grids containing the characters R,G,B that contain at least one occurrence of RGB as a contiguous sequence, either horizontally or vertically, forwards or backwards.

EXPLANATION:

We’ll instead count the number of grids that contain no occurrence of RGB in any of the four directions (equivalently, no RGB or BGR read left-to-right or top-to-bottom), and subtract this number from the total number of grids, which is just 3^{NM}.
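To make the complement idea concrete, here is a minimal brute-force sketch that enumerates every grid for very small N and M and counts both kinds. The names `has_rgb` and `brute_force` are made up for illustration and are not from the author's solution; this is only meant as a checker for tiny inputs.

```cpp
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

// True if the grid contains RGB or BGR in some row (left to right)
// or in some column (top to bottom).
bool has_rgb(const std::vector<std::string>& g) {
  const int n = static_cast<int>(g.size());
  const int m = static_cast<int>(g[0].size());
  auto bad = [](char a, char b, char c) {
    return b == 'G' && ((a == 'R' && c == 'B') || (a == 'B' && c == 'R'));
  };
  for (int i = 0; i < n; i++) {
    for (int j = 0; j < m; j++) {
      if (j + 2 < m && bad(g[i][j], g[i][j + 1], g[i][j + 2])) return true;
      if (i + 2 < n && bad(g[i][j], g[i + 1][j], g[i + 2][j])) return true;
    }
  }
  return false;
}

// Enumerates all 3^(n*m) grids, so it is only feasible for tiny n*m.
// Returns {grids with an occurrence, grids without one}.
std::pair<std::int64_t, std::int64_t> brute_force(int n, int m) {
  const std::string letters = "RGB";
  std::int64_t total = 1;
  for (int k = 0; k < n * m; k++) total *= 3;
  std::int64_t with = 0, without = 0;
  for (std::int64_t code = 0; code < total; code++) {
    // Decode the grid from the base-3 digits of `code`.
    std::vector<std::string> g(n, std::string(m, 'R'));
    std::int64_t c = code;
    for (int i = 0; i < n; i++) {
      for (int j = 0; j < m; j++) {
        g[i][j] = letters[c % 3];
        c /= 3;
      }
    }
    if (has_rgb(g)) with++; else without++;
  }
  return {with, without};
}
```

For a 1 x 3 grid this gives 2 grids with an occurrence (RGB and BGR themselves) and 25 without, and the two counts always add up to 3^{NM}.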

Note that the constraint on N is really small, which hints towards a solution that is exponential in N, such as one that enumerates every possible column.

Let’s try to fill in the grid column-by-column, as in, first choose the entire first column, then the entire second column, and so on.

A single column has 3^N choices for what it can be.
Let \text{dp}[i][S] denote the number of ways of filling in the first i columns, such that the i-th column has the value S (where S is a string containing the characters R,G,B only).

When S is placed at the i-th column, we of course need to ensure that RGB and BGR don’t get created as substrings.
For that:

  1. S itself must not contain them, which is easy to check.
  2. We also need to ensure that no horizontal occurrence is formed.
    To check this, we need to know the (i-1)-th and (i-2)-th columns as well.
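The two checks above can be sketched on plain strings as follows. The helper names `column_ok` and `triple_ok` are hypothetical and chosen only for this illustration; the three columns passed to `triple_ok` are assumed to have equal length.

```cpp
#include <cstddef>
#include <string>

// Check 1: the column has neither "RGB" nor "BGR" as a contiguous substring.
bool column_ok(const std::string& s) {
  return s.find("RGB") == std::string::npos &&
         s.find("BGR") == std::string::npos;
}

// Check 2: placing columns s3, s2, s1 side by side (left to right)
// creates no horizontal RGB or BGR in any row.
bool triple_ok(const std::string& s3, const std::string& s2,
               const std::string& s1) {
  for (std::size_t row = 0; row < s1.size(); row++) {
    if (s2[row] != 'G') continue;  // the middle letter must be G
    if ((s3[row] == 'R' && s1[row] == 'B') ||
        (s3[row] == 'B' && s1[row] == 'R')) {
      return false;
    }
  }
  return true;
}
```

The second check inspects one row at a time, which is why it costs time linear in N.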

So, let’s change our dp definition to accommodate this information.
Let \text{dp}[i][S_1][S_2] denote the number of ways of placing the first i columns, such that the i-th column is S_1 and (i-1)-th column is S_2.
To compute this:

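  1. First, check that S_1 and S_2 are themselves valid.
  2. Then, fix S_3, the value of the (i-2)-th column.
    After this, check if placing the columns S_3S_2S_1 in this order is valid, i.e., no RGB or BGR is formed horizontally.
    This check takes \mathcal{O}(N) time, since there are N rows to check.
  3. If everything is valid, add \text{dp}[i-1][S_2][S_3] to \text{dp}[i][S_1][S_2].

For the base case, \text{dp}[2][S_1][S_2] = 1 for every pair of valid columns, since a horizontal occurrence needs three columns.
The number of grids without any occurrence is the sum of \text{dp}[M][S_1][S_2] over all S_1 and S_2; when M = 1, it is simply the number of valid columns.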
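As a rough bottom-up sketch of these three steps (the author's code further down does the same thing top-down with memoization), the function below keeps only the previous layer of the table. The name `count_with_rgb` and the encoding of columns as base-3 integers with R=0, G=1, B=2 are choices made for this illustration.

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Number of n x m grids with at least one RGB/BGR occurrence, modulo mod.
std::int64_t count_with_rgb(int n, int m, std::int64_t mod) {
  int cols = 1;
  for (int k = 0; k < n; k++) cols *= 3;
  // digit[c][row] = letter (0, 1 or 2) of column c in the given row.
  std::vector<std::vector<int>> digit(cols, std::vector<int>(n));
  for (int c = 0; c < cols; c++) {
    int x = c;
    for (int row = 0; row < n; row++) { digit[c][row] = x % 3; x /= 3; }
  }
  // (0,1,2) and (2,1,0) are exactly RGB and BGR.
  auto bad = [](int a, int b, int c) { return b == 1 && a + c == 2 && a != c; };

  // Step 1: keep only columns with no vertical occurrence.
  std::vector<int> valid;
  for (int c = 0; c < cols; c++) {
    bool ok = true;
    for (int row = 0; row + 2 < n; row++)
      if (bad(digit[c][row], digit[c][row + 1], digit[c][row + 2])) ok = false;
    if (ok) valid.push_back(c);
  }
  const int v = static_cast<int>(valid.size());

  std::int64_t total = 1 % mod;  // 3^(n*m) mod mod
  for (int k = 0; k < n * m; k++) total = total * 3 % mod;

  std::int64_t good = 0;  // grids with no occurrence
  if (m == 1) {
    good = v % mod;
  } else {
    using Table = std::vector<std::vector<std::int64_t>>;
    // dp[a][b]: last column is valid[a], the one before it is valid[b].
    Table dp(v, std::vector<std::int64_t>(v, 1 % mod));  // base case, i = 2
    for (int i = 3; i <= m; i++) {
      Table nxt(v, std::vector<std::int64_t>(v, 0));
      for (int s1 = 0; s1 < v; s1++)
        for (int s2 = 0; s2 < v; s2++)
          for (int s3 = 0; s3 < v; s3++) {
            // Step 2: check the columns S3 S2 S1 row by row.
            bool ok = true;
            for (int row = 0; row < n && ok; row++)
              ok = !bad(digit[valid[s3]][row], digit[valid[s2]][row],
                        digit[valid[s1]][row]);
            // Step 3: extend every valid arrangement ending in S3 S2.
            if (ok) nxt[s1][s2] = (nxt[s1][s2] + dp[s2][s3]) % mod;
          }
      dp = std::move(nxt);
    }
    for (int a = 0; a < v; a++)
      for (int b = 0; b < v; b++) good = (good + dp[a][b]) % mod;
  }
  return ((total - good) % mod + mod) % mod;
}
```

For example, `count_with_rgb(2, 3, 1000000007)` returns 104: each of the two rows independently avoids RGB/BGR in 25 ways, so 729 - 625 grids contain an occurrence.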

Since there are 3^N possible columns, we have M\cdot 3^N\cdot 3^N states in our DP.
Each of these states has 3^N transitions (fixing column (i-2)), and checking if the transition is valid takes \mathcal{O}(N) time.

So, we have an overall runtime of \mathcal{O}(N\cdot M \cdot 3^{3N}).
This is fast enough for the easy version.


A small note about implementation: in the DP states defined above, S_1, S_2, S_3 were called ‘strings’.
However, working with strings as the DP states is somewhat unwieldy (and can potentially slow things down).
To overcome this, represent the columns as integers in base 3 instead (which functionally means they can just be stored as integers).

For example, if we set R=0, G=1, B=2 then the string RBBG can be represented as
0\cdot 3^0 + 2\cdot 3^1 + 2\cdot 3^2 + 1\cdot 3^3 = 51.
Note that the first character of the string is the least significant digit here.
It’s quite simple to get an individual digit in this representation: to get the coefficient of 3^k, floor divide by 3^k and then take the result modulo 3.
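A minimal sketch of this encoding, assuming the same mapping R=0, G=1, B=2; the helper names `encode` and `digit_at` are made up for this example.

```cpp
#include <string>

// Encodes a column as a base-3 integer; s[k] becomes the coefficient of 3^k.
int encode(const std::string& s) {
  int value = 0, power = 1;
  for (char ch : s) {
    int d = (ch == 'R') ? 0 : (ch == 'G') ? 1 : 2;
    value += d * power;
    power *= 3;
  }
  return value;
}

// Returns the coefficient of 3^k: floor-divide by 3^k, then reduce modulo 3.
int digit_at(int value, int k) {
  int power = 1;
  for (int t = 0; t < k; t++) power *= 3;
  return (value / power) % 3;
}
```

With these, `encode("RBBG")` gives 51 and `digit_at(51, 1)` gives 2, matching the worked example above. In a real solution the powers of 3 would typically be precomputed once, as the author's code does with its `p3` array.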

TIME COMPLEXITY:

\mathcal{O}(N\cdot M \cdot 3^{3N}) per testcase.

CODE:

Author's code (C++)
/**
 *    author:  BERNARD B.01
**/
#include <bits/stdc++.h>
 
using namespace std;
 
#ifdef B01
#include "deb.h"
#else
#define deb(...)
#endif
 
int md;
 
inline void add(int& a, int b) {
  a += b;
  if (a >= md) a -= md;
}
 
inline void sub(int& a, int b) {
  a -= b;
  if (a < 0) a += md;
}
 
inline int mul(int a, int b) {
  return int(int64_t(a) * b % md);
}
 
// p3[k] = 3^k; a column is stored as a base-3 number (e.g. R=0, G=1, B=2).
int p3[7];
 
inline int get_digit(int mask, int i) {
  return (mask / p3[i]) % 3;
}
 
int n, m;
bool valid[729];
bool can[729][729];
int dp[100][729][729];
vector<short> valid_masks;
 
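// Number of ways to fill columns i..m-1 (0-indexed) without creating RGB/BGR,
// given that column i-2 is mask1 and column i-1 is mask2. Memoized in dp.
inline int sol(int i, int mask1, int mask2) {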
  if (i == m) {
    return 1;
  }
  int& ret = dp[i][mask1][mask2];
  if (ret != -1) {
    return ret;
  }
  ret = 0;
  for (int mask3 : valid_masks) {
    bool fail = false;
    for (int j = 0; j < n; j++) {
      if (get_digit(mask1, j) == 0 && get_digit(mask2, j) == 1 && get_digit(mask3, j) == 2) {
        fail = true;
        break;
      }
      if (get_digit(mask1, j) == 2 && get_digit(mask2, j) == 1 && get_digit(mask3, j) == 0) {
        fail = true;
        break;
      }
    }
    if (fail) {
      continue;
    }
    add(ret, sol(i + 1, mask2, mask3));
  }
  return ret;
}
 
int main() {
  ios::sync_with_stdio(false);
  cin.tie(nullptr);
  p3[0] = 1;
  for (int i = 1; i < 7; i++) {
    p3[i] = p3[i - 1] * 3;
  }
  cin >> n >> m >> md;
  for (int mask = 0; mask < p3[n]; mask++) {
    valid[mask] = 1;
    for (int j = 0; j < n - 2; j++) {
      if (get_digit(mask, j) == 0 && get_digit(mask, j + 1) == 1 && get_digit(mask, j + 2) == 2) {
        valid[mask] = 0;
        break;
      }
      if (get_digit(mask, j) == 2 && get_digit(mask, j + 1) == 1 && get_digit(mask, j + 2) == 0) {
        valid[mask] = 0;
        break;
      }
    }
    if (valid[mask]) {
      valid_masks.push_back(mask);
    }
  }
  memset(dp, -1, sizeof dp);
  if (m == 1) {
    cout << (p3[n] - (int) valid_masks.size()) % md << '\n';
    return 0;
  }
  int ans = 1;
  for (int i = 0; i < m; i++) {
    ans = mul(ans, p3[n]);
  }
  for (int mask1 : valid_masks) {
    for (int mask2 : valid_masks) {
      sub(ans, sol(2, mask1, mask2));
    }
  }
  cout << ans << '\n';
  return 0;
}