RGBGRID - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: bernarb01
Tester: mexomerf
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

DP

PROBLEM:

Count the number of N\times M grids filled with the characters R, G, B that contain at least one occurrence of RGB as a contiguous substring, either horizontally or vertically, forwards or backwards. The count is taken modulo a value given in the input.

In this version, N \leq 6.
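A brute-force reference is useful for checking any of the DPs on tiny inputs. This is a minimal sketch of our own, not from the editorial. It assumes the encoding 0 = R, 1 = G, 2 = B. The names is_rgb_triple and count_with_rgb_bruteforce are hypothetical, and the count is exact rather than taken modulo anything.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: true iff (a, b, c) spells RGB or BGR,
// with colors encoded as 0 = R, 1 = G, 2 = B.
bool is_rgb_triple(int a, int b, int c) {
    return b == 1 && ((a == 0 && c == 2) || (a == 2 && c == 0));
}

// Brute-force count of n x m grids containing at least one RGB
// (horizontal or vertical, forwards or backwards). Enumerates all
// 3^(n*m) grids, so it is only meant for tiny grids.
std::int64_t count_with_rgb_bruteforce(int n, int m) {
    const int cells = n * m;
    assert(cells <= 12);  // keep the enumeration small
    std::int64_t total = 1;
    for (int k = 0; k < cells; ++k) total *= 3;

    std::vector<int> g(cells);  // g[r * m + c] is the color of cell (r, c)
    std::int64_t count = 0;
    for (std::int64_t code = 0; code < total; ++code) {
        // Decode the grid from the base-3 digits of `code`.
        std::int64_t x = code;
        for (int k = 0; k < cells; ++k) { g[k] = int(x % 3); x /= 3; }

        bool found = false;
        for (int r = 0; r < n && !found; ++r)
            for (int c = 0; c < m && !found; ++c) {
                // Horizontal triple starting at (r, c).
                if (c + 2 < m && is_rgb_triple(g[r*m + c], g[r*m + c + 1], g[r*m + c + 2]))
                    found = true;
                // Vertical triple starting at (r, c).
                if (r + 2 < n && is_rgb_triple(g[r*m + c], g[(r+1)*m + c], g[(r+2)*m + c]))
                    found = true;
            }
        if (found) ++count;
    }
    return count;
}
```

For instance, count_with_rgb_bruteforce(2, 3) returns 104. Every length-3 row avoids RGB/BGR in 25 of its 27 colorings, and columns of height 2 cannot contain one, so the count is 3^6 - 25^2 = 104.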

EXPLANATION:

This solution builds on the editorial for the easy version.

Now that N \leq 6, our initial DP is too slow.
Unlike in the medium version, we move to a slightly different solution and optimize that instead.

Instead of going column by column, let’s fill in the grid one cell at a time, but in a specific order.
We’ll first fill in the first column, top to bottom.
Then the second column, top to bottom, followed by the third, and so on.

Suppose we’re currently filling cell (i, j), in row i and column j. Let’s look at the information we need.

  1. The values in cells (i-1, j) and (i-2, j) (if they exist), to determine whether there’s a vertical RGB substring.
  2. The values in cells (i, j-1) and (i, j-2) (if they exist), to determine whether there’s a horizontal RGB substring.
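Here is a minimal sketch of this constant-time check. It assumes colors are encoded as 0 = R, 1 = G, 2 = B, the encoding both codes below use. It also assumes that g already holds every cell placed before (i, j). The names forms_rgb and can_place are hypothetical.

```cpp
#include <cassert>
#include <vector>

// True iff (a, b, c) reads RGB or BGR: the middle is G and the ends are R and B.
bool forms_rgb(int a, int b, int c) {
    return b == 1 && a != 1 && c != 1 && a != c;
}

// Hypothetical helper: can color c go at (i, j) (row i, column j)?
// Only the two cells above and the two cells to the left are inspected.
bool can_place(const std::vector<std::vector<int>>& g, int i, int j, int c) {
    assert(i < (int)g.size() && j < (int)g[i].size());
    if (i >= 2 && forms_rgb(g[i - 2][j], g[i - 1][j], c)) return false;  // vertical
    if (j >= 2 && forms_rgb(g[i][j - 2], g[i][j - 1], c)) return false;  // horizontal
    return true;
}
```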

Of course, one way to maintain this information is to store the last two columns, and also the values placed in this column so far.
Our transitions are now constant-time. We only need to check whether each of R, G, B can be placed at (i, j) without completing an RGB (forwards or backwards). This works because we count the grids that avoid RGB entirely and subtract that from 3^{NM}.
However, there are now about 3^{3N} possible states for each cell, so the overall time complexity hasn’t really improved.

We can improve this, though.
Note that when we’re placing (i, j), the cells (x, j-2) for x \lt i no longer matter. Such a cell could only still be part of a horizontal triple with (x, j-1) and (x, j), and (x, j) has already been placed. So these cells cannot affect any future placements.

So, instead of storing the entire (j-2)-th column, we store it only partially. Specifically, with rows numbered from 1, we store:

  • All previously placed elements in the current column, of which there are i-1.
  • The entire (j-1)-th column, for N elements.
  • The elements of the (j-2)-th column at or after row i, of which there are N-i+1.

This brings us down to a total of just 2N elements that need to be stored!
While the optimization might seem simple, it significantly affects the complexity. Each cell now has 3^{2N} states with three transitions each, so we’re down to \mathcal{O}(NM\cdot 3^{2N}) overall, which is fast enough to get AC for the given constraints.
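Counting the stored cells (rows numbered from 1) confirms the total of 2N. For N = 6, it also shows how much smaller the new state space is than the naive 3^{3N} one:

```latex
% cells stored just before placing (i, j)
\underbrace{(i-1)}_{\text{column } j} + \underbrace{N}_{\text{column } j-1} + \underbrace{(N-i+1)}_{\text{column } j-2} = 2N
% 3^{2N} states per cell, 3 transitions per state, NM cells
NM \cdot 3^{2N} \cdot 3 = \mathcal{O}(NM \cdot 3^{2N})
% for N = 6
3^{2N} = 3^{12} = 531441 \qquad \text{versus} \qquad 3^{3N} = 3^{18} = 387420489
```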


Now, as for implementation: recall that we place elements in a specific order, top to bottom within each column.
Under this order, rather than thinking about which elements from each column we’re storing, we can simply store the last 2N placed elements!

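This is exactly the information we want. Just before placing (i, j), the last 2N placed cells are:

  • the i-1 cells above it in column j,
  • all N cells of column j-1,
  • the N-i+1 cells of column j-2 from row i downward.

Suppose we encode them as a base-3 number with the most recently placed cell as the lowest digit, as the editorialist’s code does. Then (i-1, j) and (i-2, j) are digits 0 and 1, while (i, j-1) and (i, j-2) are digits N-1 and 2N-1. Placing a new cell just shifts the number by one digit and drops the oldest one, which makes for a rather simple final implementation.
It’s also easy to make this implementation use only \mathcal{O}(3^{2N}) memory. Keep just two DP arrays (for the current cell and the next one) and alternate between them. This might help if your code is slow.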
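To see why the last 2N placed cells suffice, here is a small sketch under our own assumptions:

  • Cells are placed column by column, top to bottom.
  • The window is stored as a base-3 integer whose digit 0 is the most recently placed cell.

It computes which digit holds a given cell just before (i, j) is placed. The helper names order_index, digit_of, push_color and get_digit are hypothetical.

```cpp
#include <cassert>

// Position of cell (r, c) in column-major placement order, for an n-row grid.
int order_index(int n, int r, int c) { return c * n + r; }

// Digit of the window holding cell (r, c), just before (i, j) is placed.
// The cell is still inside the window iff the result lies in [0, 2n).
int digit_of(int n, int i, int j, int r, int c) {
    return order_index(n, i, j) - order_index(n, r, c) - 1;
}

// Place a new color: shift every digit up by one and drop the oldest.
// window_size is 3^(2n).
int push_color(int mask, int color, int window_size) {
    assert(0 <= color && color < 3);
    return (mask * 3 + color) % window_size;
}

// Read digit k of the window.
int get_digit(int mask, int k) {
    for (int t = 0; t < k; ++t) mask /= 3;
    return mask % 3;
}
```

With n = 6 and (i, j) = (3, 4), the cells above are digits 0 and 1. The cells to the left are digits 5 = N-1 and 11 = 2N-1. The cell (2, 2) would be digit 12, which has already dropped out of the window, matching the observation that it no longer matters.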

TIME COMPLEXITY:

\mathcal{O}(N\cdot M \cdot 3^{2N}) per testcase.

CODE:

Author's code (C++)
/**
 *    author:  BERNARD B.01
**/
#include <bits/stdc++.h>
 
using namespace std;
 
#ifdef B01
#include "deb.h"
#else
#define deb(...)
#endif
 
int main() {
  ios::sync_with_stdio(false);
  cin.tie(nullptr);
  int n, m, md;
  cin >> n >> m >> md;
  auto Add = [&](int& a, int b) {
    a += b;
    if (a >= md) a -= md;
  };
  auto Sub = [&](int& a, int b) {
    a -= b;
    if (a < 0) a += md;
  };
  auto Mul = [&](int a, int b) {
    return int(int64_t(a) * b % md);
  };
  vector<int> p3(n + 1);
  p3[0] = 1;
  for (int i = 1; i <= n; i++) {
    p3[i] = p3[i - 1] * 3;
  }
  // Digit i of a mask is the color in row i: 0 = R, 1 = G, 2 = B.
  auto GetDigit = [&](int mask, int i) {
    return (mask / p3[i]) % 3;
  };
  int s = p3[n];
  // Complete columns with no vertical RGB/BGR; only these can appear as mask2.
  vector<int> valid_masks;
  for (int mask = 0; mask < s; mask++) {
    bool fail = false;
    for (int i = 1; i < n - 1; i++) {
      if (GetDigit(mask, i) == 1) {
        if (GetDigit(mask, i - 1) == 0 && GetDigit(mask, i + 1) == 2) {
          fail = true;
          break;
        }
        if (GetDigit(mask, i - 1) == 2 && GetDigit(mask, i + 1) == 0) {
          fail = true;
          break;
        }
      }
    }
    if (!fail) {
      valid_masks.push_back(mask);
    }
  }
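  // dp[i][mask1][mask2]: about to fill row i of the current column j.
  // mask1 holds column j in rows < i and column j-2 in rows >= i;
  // mask2 holds the complete column j-1.
  vector dp(n, vector(s, vector<int>(s)));
  dp[0][0][0] = 1;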
  for (int j = 0; j < m; j++) {
    vector new_dp(n, vector(s, vector<int>(s)));
    for (int i = 0; i < n; i++) {
      for (int mask1 = 0; mask1 < s; mask1++) {
        for (int mask2 : valid_masks) {
          int new_mask1 = mask1 - GetDigit(mask1, i) * p3[i];
          for (int cur = 0; cur < 3; cur++, new_mask1 += p3[i]) {
            // Horizontal: (i, j-2), (i, j-1), (i, j) must not read RGB or BGR.
            if (j > 1 && GetDigit(mask2, i) == 1) {
              if (GetDigit(mask1, i) == 0 && cur == 2) {
                continue;
              }
              if (GetDigit(mask1, i) == 2 && cur == 0) {
                continue;
              }
            }
            // Vertical: (i-2, j), (i-1, j), (i, j) must not read RGB or BGR.
            if (i > 1 && GetDigit(mask1, i - 1) == 1) {
              if (GetDigit(mask1, i - 2) == 0 && cur == 2) {
                continue;
              }
              if (GetDigit(mask1, i - 2) == 2 && cur == 0) {
                continue;
              }
            }
            // Column done: column j-1 becomes the new mask1, column j the new mask2.
            if (i + 1 == n) {
              Add(new_dp[0][mask2][new_mask1], dp[i][mask1][mask2]);
            } else {
              Add(dp[i + 1][new_mask1][mask2], dp[i][mask1][mask2]);
            }
          }
        }
      }
    }
    swap(dp, new_dp);
  }
  // Answer: all 3^(NM) grids minus those containing no RGB at all.
  int ans = 1;
  for (int i = 0; i < m; i++) {
    ans = Mul(ans, p3[n]);
  }
  for (int mask1 = 0; mask1 < s; mask1++) {
    for (int mask2 = 0; mask2 < s; mask2++) {
      Sub(ans, dp[0][mask1][mask2]);
    }
  }
  cout << ans << '\n';
  return 0;
}
Editorialist's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    vector<int> pow3(15, 1);
    for (int i = 1; i < 15; ++i) pow3[i] = pow3[i-1] * 3;
    
    int n, m, mod; cin >> n >> m >> mod;
    const int SZ = pow3[2*n];

    array<vector<int>, 2> dp;
    dp[0] = dp[1] = vector<int>(SZ);
    dp[0][0] = 1;
    int ans = 1, id = 0;
    
    for (int i = 0; i < m; ++i) for (int j = 0; j < n; ++j) {
        ans = (ans * 3ll) % mod;
        
        auto &cur = dp[id];
        auto &nxt = dp[id^1];
        id ^= 1;

        for (auto &x : nxt) x = 0;

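        // mask: the last 2n placed cells in base 3 (0 = R, 1 = G, 2 = B),
        // newest in digit 0; i is the column and j the row being filled.
        for (int mask = 0; mask < SZ; ++mask) {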
            // place g
            int nmask = (mask*3 + 1) % SZ;
            nxt[nmask] = (nxt[nmask] + cur[mask]) % mod;

            // place r
            nmask = (mask*3) % SZ;
            bool allowed = true;
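            if (j >= 2) allowed &= (mask%9) != 7;  // above: G, two above: B (7 = 1 + 3*2)
            if (i >= 2) {
                int x = (mask / pow3[n-1]) % 3;    // cell to the left (digit n-1)
                int y = (mask / pow3[2*n-1]) % 3;  // two cells to the left (digit 2n-1)
                allowed &= (x != 1) or (y != 2);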
            }
            if (allowed) nxt[nmask] = (nxt[nmask] + cur[mask]) % mod;

            // place b
            nmask = (mask*3 + 2) % SZ;
            allowed = true;
            if (j >= 2) allowed &= (mask%9) != 1;  // above: G, two above: R (1 = 1 + 3*0)
            if (i >= 2) {
                int x = (mask / pow3[n-1]) % 3;
                int y = (mask / pow3[2*n-1]) % 3;
                allowed &= (x != 1) or (y != 0);
            }
            if (allowed) nxt[nmask] = (nxt[nmask] + cur[mask]) % mod;
        }
    }

    for (int mask = 0; mask < SZ; ++mask)
        ans = (ans + mod - dp[id][mask]) % mod;
    cout << ans << '\n';

}

Actually, we can combine the ideas of the medium and hard versions and store only 5^N states for the last N elements (R, B, RG, BG, GG), with each transition filling one cell. This gives \mathcal{O}(NM \cdot 5^N \cdot 3), which is fast enough for N \le 8, M \le 50.
The idea was implemented by @lympanda. His solution: https://www.codechef.com/viewsolution/1088719646
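Below is one possible reading of this idea, sketched under our own assumptions. It is not @lympanda’s code, which is not reproduced here.

Each row keeps one of five states describing its cells filled so far:

  • last cell R,
  • last cell B,
  • last two cells RG,
  • last two cells BG,
  • GG, which also stands for an empty row.

Cells are filled in column-major order. The vertical check reads the last colors of the two rows above, which were already updated in the current column. RowState and count_with_rgb_5n are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Per-row state (hypothetical names for this sketch); GG also means "empty row".
enum RowState { LAST_R = 0, LAST_B = 1, RG = 2, BG = 3, GG = 4 };

// Color of the most recent cell in a row: 0 = R, 1 = G, 2 = B.
int last_color(int st) { return st == LAST_R ? 0 : st == LAST_B ? 2 : 1; }

// Number of n x m grids containing at least one RGB, modulo mod.
std::int64_t count_with_rgb_5n(int n, int m, std::int64_t mod) {
    assert(n >= 1 && mod >= 1);
    std::vector<int> p5(n + 1, 1);
    for (int k = 1; k <= n; ++k) p5[k] = p5[k - 1] * 5;
    const int S = p5[n];

    int start = 0;  // every row starts in state GG (no constraint yet)
    for (int k = 0; k < n; ++k) start += GG * p5[k];

    std::vector<std::int64_t> cur(S, 0), nxt(S, 0);
    cur[start] = 1 % mod;
    std::int64_t total = 1 % mod;  // becomes 3^(nm) mod `mod`

    for (int j = 0; j < m; ++j)
        for (int i = 0; i < n; ++i) {
            total = total * 3 % mod;
            std::fill(nxt.begin(), nxt.end(), 0);
            for (int mask = 0; mask < S; ++mask) {
                if (cur[mask] == 0) continue;
                const int st = mask / p5[i] % 5;  // row i, up to column j-1
                const int base = mask - st * p5[i];
                // Colors of (i-1, j) and (i-2, j), already placed in this column.
                const int up1 = i >= 1 ? last_color(mask / p5[i - 1] % 5) : -1;
                const int up2 = i >= 2 ? last_color(mask / p5[i - 2] % 5) : -1;
                for (int c = 0; c < 3; ++c) {
                    // Horizontal: R after BG, or B after RG, completes BGR/RGB.
                    if ((c == 0 && st == BG) || (c == 2 && st == RG)) continue;
                    // Vertical: same rule using the two cells above.
                    if (up1 == 1 && ((c == 0 && up2 == 2) || (c == 2 && up2 == 0))) continue;
                    int ns;
                    if (c == 0) ns = LAST_R;
                    else if (c == 2) ns = LAST_B;
                    else ns = st == LAST_R ? RG : st == LAST_B ? BG : GG;
                    std::int64_t& slot = nxt[base + ns * p5[i]];
                    slot = (slot + cur[mask]) % mod;
                }
            }
            std::swap(cur, nxt);
        }

    std::int64_t avoid = 0;  // grids with no RGB anywhere
    for (std::int64_t v : cur) avoid = (avoid + v) % mod;
    return ((total - avoid) % mod + mod) % mod;
}
```

For example, count_with_rgb_5n(2, 3, 1000000007) returns 104, the same as the brute-force count. Each cell does 3 transitions per state over 5^N states, which matches the \mathcal{O}(NM \cdot 5^N \cdot 3) bound above.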


I wonder how CodeChef’s judge is implemented. I figured out the intended solution in 5 minutes, but spent 2 hours dealing with TLE, thinking my solution wasn’t good enough. Ironically, the intended solution is exactly the same as mine. I tried running my solution on the extreme test case on other platforms such as AtCoder, and it only takes around 2s. Why does it get TLE with a 6s time limit on CodeChef?