RGBGRID_SUB2 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: bernarb01
Tester: mexomerf
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

DP

PROBLEM:

Count the number of N\times M grids containing the characters R,G,B that contain at least one occurrence of RGB as a contiguous sequence, either horizontally or vertically, forwards or backwards.

In the medium version, N \leq 5.

EXPLANATION:

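This solution continues from the editorial for the easy version. As there, we count the grids that contain no occurrence of RGB (in either direction) with a column-by-column DP, and subtract that count from the total 3^{NM}.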

With N \leq 5, the DP from the easy version is too slow.
We need to reduce the amount of information kept in each state.

In the easy version, the DP state at each column index was the pair of the last two columns, giving 3^N \cdot 3^N states.
However, many of these pairs behave identically for every future transition, so they can be merged.

Let’s look at a single row, and the last two characters in it - say c_1 and c_2 (with c_1 being in column (i-2) and c_2 in (i-1)).
There are 9 possibilities for them in total, of course. However,

  • If c_2 = \text{R}, then c_1 is irrelevant. Every horizontal occurrence (RGB or BGR) has G in the middle, so no choice of the next character can complete one together with c_1 and c_2.
    So, we can compress all of RR, GR, BR into a single state since they’ll have the same transitions anyway.
  • If c_2 = \text{B}, similarly c_1 can be anything - compress everything into a single state again.
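  • If c_2 = \text{G}, then c_1 does matter. After RG the next character cannot be B, after BG it cannot be R, and after GG it can be anything. This gives three different states.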

So, rather than 9 possible states, we really have only 5 for each row: ends in R, ends in B, RG, BG, GG.
This gives a total of 5^N states (instead of 9^N) once the column index is fixed.

Once again, there are 3^N possible choices for the next column from each state. Each can be checked and converted into the new compressed state in \mathcal{O}(N) time; the details of the transitions are basically exactly the same as in the easy version.
Since 5^N \cdot 3^N = 15^N, this brings our complexity to \mathcal{O}(NM\cdot 15^N), which for N \leq 5 and M \leq 50 is fast enough.

TIME COMPLEXITY:

\mathcal{O}(NM\cdot 15^N) per testcase.

CODE:

Author's code (C++)
/**
 *    author:  BERNARD B.01
**/
#include <bits/stdc++.h>

using namespace std;

// Debug-print hook: when the B01 macro is defined, deb(...) comes from the
// local "deb.h" header; otherwise it expands to nothing. deb is not called
// anywhere in this program.
#ifdef B01
#include "deb.h"
#else
#define deb(...)
#endif

// Modulus used by all arithmetic below; read from input in main().
int md;

// a = (a + b) mod md. Both operands are expected to already lie in [0, md).
inline void add(int& a, int b) {
  if ((a += b) >= md) {
    a -= md;
  }
}

// a = (a - b) mod md. Both operands are expected to already lie in [0, md).
inline void sub(int& a, int b) {
  if ((a -= b) < 0) {
    a += md;
  }
}

// Returns (a * b) mod md, widening to 64 bits so the product cannot overflow.
inline int mul(int a, int b) {
  const int64_t product = static_cast<int64_t>(a) * b;
  return static_cast<int>(product % md);
}

// Powers of three and five, filled in by main(): p3[i] = 3^i, p5[i] = 5^i.
int p3[7];
int p5[7];

// Returns the i-th base-3 digit of `mask` (i = 0 is the least significant).
// A column mask stores one cell per digit. Digit 1 is G; 0 and 2 are
// presumably R and B (the RGB/BGR checks are symmetric in them).
inline int get_digit(int mask, int i) {
  const int shifted = mask / p3[i];
  return shifted % 3;
}

// Grid dimensions: n rows (the number of base-3 digits in a column mask) and
// m columns.
int n, m;
// valid[mask] is true iff column `mask` has no vertical RGB/BGR (set in main()).
bool valid[729];
// dp[i][key]: memoized result of sol() at column index i for the compressed
// two-column state `key`; -1 means "not computed yet". 15625 = 5^6 states.
// (The former `bool can[729][729]` table was never read or written and has
// been removed.)
int dp[50][15625];
// Every column mask with valid[mask] == true, in increasing order.
std::vector<short> valid_masks;

// Counts, modulo md, the ways to fill columns i..m-1 so that no row reads
// RGB or BGR across three consecutive columns. Each new column is drawn from
// valid_masks, so columns are already free of vertical triples.
//   mask1  - base-3 mask of column i-2
//   mask2  - base-3 mask of column i-1
//   mask12 - base-5 compressed form of (mask1, mask2). It is the memo key,
//            because it fixes everything the transitions depend on.
// Results are cached in dp[i][mask12].
inline int sol(int i, int mask1, int mask2, int mask12) {
  if (i == m) {
    return 1;  // every column has been placed
  }
  int& memo = dp[i][mask12];
  if (memo != -1) {
    return memo;
  }
  memo = 0;
  for (int next : valid_masks) {
    // A row forms a triple iff its middle cell is G (1) and its outer cells
    // are the two distinct non-G colours, in either order.
    bool creates_triple = false;
    for (int row = 0; row < n && !creates_triple; row++) {
      const int a = get_digit(mask1, row);
      const int b = get_digit(mask2, row);
      const int c = get_digit(next, row);
      creates_triple = b == 1 && ((a == 0 && c == 2) || (a == 2 && c == 0));
    }
    if (creates_triple) {
      continue;
    }
    // Compress (mask2, next) into the next state: per row, 2 = ends in digit 0,
    // 3 = ends in digit 2, and for rows ending in G: 0 = 0G, 1 = 2G, 4 = GG.
    int key = 0;
    for (int row = 0; row < n; row++) {
      const int cur = get_digit(next, row);
      const int prev = get_digit(mask2, row);
      int code = 0;
      if (cur == 0) {
        code = 2;
      } else if (cur == 2) {
        code = 3;
      } else if (prev == 2) {
        code = 1;
      } else if (prev == 1) {
        code = 4;
      }
      key += code * p5[row];
    }
    add(memo, sol(i + 1, mask2, next, key));
  }
  return memo;
}

int main() {
  ios::sync_with_stdio(false);
  cin.tie(nullptr);
  p3[0] = p5[0] = 1;
  for (int i = 1; i < 7; i++) {
    p3[i] = p3[i - 1] * 3;
    p5[i] = p5[i - 1] * 5;
  }
  cin >> n >> m >> md;
  for (int mask = 0; mask < p3[n]; mask++) {
    valid[mask] = 1;
    for (int j = 0; j < n - 2; j++) {
      if (get_digit(mask, j) == 0 && get_digit(mask, j + 1) == 1 && get_digit(mask, j + 2) == 2) {
        valid[mask] = 0;
        break;
      }
      if (get_digit(mask, j) == 2 && get_digit(mask, j + 1) == 1 && get_digit(mask, j + 2) == 0) {
        valid[mask] = 0;
        break;
      }
    }
    if (valid[mask]) {
      valid_masks.push_back(mask);
    }
  }
  memset(dp, -1, sizeof dp);
  if (m == 1) {
    cout << p3[n] - valid_masks.size() << '\n';
    return 0;
  }
  int ans = 1;
  for (int i = 0; i < m; i++) {
    ans = mul(ans, p3[n]);
  }
  for (int mask1 : valid_masks) {
    for (int mask2 : valid_masks) {
      int mask12 = 0;
      for (int j = 0; j < n; j++) {
        if (get_digit(mask2, j) == 1) {
          if (get_digit(mask1, j) == 2) {
            mask12 += p5[j];
          } else if (get_digit(mask1, j) == 1) {
            mask12 += 4 * p5[j];
          }
        } else {
          if (get_digit(mask2, j) == 0) {
            mask12 += 2 * p5[j];
          } else {
            mask12 += 3 * p5[j];
          }
        }
      }
      sub(ans, sol(2, mask1, mask2, mask12));
    }
  }
  cout << ans << '\n';
  return 0;
}