AVOIDINGM - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: raysh07
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

Easy

PREREQUISITES:

Dynamic programming, combinatorics

PROBLEM:

Given N and M, count the number of arrays A of length N with elements in [0, N] such that no subarray of A has its MEX equal to M.

In this version of the problem, M \le N \le 3000.

EXPLANATION:

For a set to have a MEX of M, it must contain (at least) one copy of each of the elements 0, 1, 2, \ldots, M-1, and must not contain M.
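As a concrete illustration of this definition, here is a minimal sketch of a MEX computation. The helper name `mex_of` is hypothetical and is not part of the editorial's solution; it assumes all values are non-negative.

```cpp
#include <cassert>
#include <vector>

// Smallest non-negative integer that does not appear in `values`.
// Hypothetical helper for illustration; values are assumed non-negative.
int mex_of(const std::vector<int>& values) {
    // The MEX of k values is at most k, so k + 1 flags are enough.
    std::vector<bool> seen(values.size() + 1, false);
    for (int v : values) {
        assert(v >= 0);
        if (v < (int)seen.size()) seen[v] = true;
    }
    int result = 0;
    while (seen[result]) ++result;  // stops: at most k of the k + 1 flags are set
    return result;
}
```

For instance, {0, 1, 3} contains 0 and 1 but not 2, so its MEX is 2, while {1, 2} is missing 0 and has MEX 0.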

So, if in an array A the indices containing the value M are i_1, i_2, \ldots, i_k, then the “worst case” subarrays for us are exactly those of the form A[i_j+1, i_{j+1}-1], i.e. the entire subarray between two consecutive occurrences of M.
(For convenience, we treat i_0 = 0 and i_{k+1} = N+1).

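Any subarray that does not contain M lies entirely inside one of these worst-case subarrays, and if it contains all of 0, 1, \ldots, M-1 then so does the worst-case subarray around it. So, as long as we can ensure that none of these worst-case subarrays have a MEX of M, that will guarantee that no subarray has a MEX of M (and the converse is immediate, since each worst-case subarray is itself a subarray.)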
So, we can focus our efforts on that.
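The characterization above can be sketched as a single left-to-right scan that only looks at the maximal blocks between occurrences of M. This is an illustrative sketch, not the editorial's code: the function name `has_block_with_mex` is hypothetical, and it assumes M \ge 1 and non-negative values.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Returns true if some maximal M-free block of `a` contains every value in [0, m-1],
// i.e. if some block (and hence some subarray) has MEX equal to m.
// Hypothetical helper; assumes m >= 1 and non-negative values.
bool has_block_with_mex(const std::vector<int>& a, int m) {
    assert(m >= 1);
    std::vector<bool> seen(m, false);
    int distinct_below = 0;  // distinct values from [0, m-1] in the current block
    for (int x : a) {
        if (x == m) {
            // An occurrence of m ends the current block; start a fresh one.
            std::fill(seen.begin(), seen.end(), false);
            distinct_below = 0;
        } else if (x < m && !seen[x]) {
            seen[x] = true;
            if (++distinct_below == m) return true;
        }
    }
    return false;
}
```

With M = 2, the array [0, 1, 2, 1, 0] is rejected because the block [0, 1] already contains both 0 and 1, while [0, 2, 1] is accepted because the 2 separates them.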


The above characterization lends itself quite naturally to a dynamic programming solution.

Let’s define dp_i to be the number of ways to fill in the first i elements such that:

  • No subarray with MEX equal to M exists among them, and
  • A_i = M

Essentially, we’re building up the array by placing entire blocks of elements between occurrences of M.

The transitions are quite simple (in theory): to compute dp_i, let j \lt i be the previous occurrence of M; then there are dp_j ways to fill in elements till index j, and we only need to figure out how to fill in elements at indices j+1, \ldots, i-1 (which can be done completely independently of the rest of the array.)

Note that this requires us to solve the following subproblem:

Given an integer L, how many arrays of length L containing elements in [0, M-1] \cup [M+1, N] do not have a MEX of M?

We also want to know this for each value of L from 0 to N, to assist with transitions.
There are a few different ways to solve this subproblem: one of them is using dynamic programming as follows.

Define ways[i][j] to be the number of arrays of length i with elements in [0, M-1]\cup [M+1, N] such that they contain exactly j distinct elements among [0, M-1].
Transitions are fairly simple:

  • If the next element is either one of the j existing ones in [0, M-1], or an element from [M+1, N], then the value of j doesn’t change at all.
    There are j choices for the former and (N-M) for the latter, so we add ways[i][j] \cdot (N-M+j) to ways[i+1][j].
  • If the next element is a new one from [0, M-1], j increases by 1.
    There are M-j choices for the element, so we add ways[i][j] \cdot (M-j) to ways[i+1][j+1].
  • The base case is ways[0][0] = 1.
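The transitions listed above can be sketched as a standalone function. The name `build_ways` is hypothetical; the default modulus 998244353 is taken from the editorialist's code, and the sketch assumes 1 \le M \le N.

```cpp
#include <cassert>
#include <vector>

// ways[i][j]: number of arrays of length i over [0, M-1] and [M+1, N] that contain
// exactly j distinct values from [0, M-1], taken modulo `mod`.
// Hypothetical helper for illustration; assumes 1 <= m <= n.
std::vector<std::vector<long long>> build_ways(int n, int m, long long mod = 998244353) {
    assert(1 <= m && m <= n);
    std::vector<std::vector<long long>> ways(n + 1, std::vector<long long>(m + 1, 0));
    ways[0][0] = 1;  // base case: the empty array
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j <= m; ++j) {
            long long cur = ways[i][j];
            if (cur == 0) continue;
            // Next element is one of the j seen values or lies in [M+1, N]: j unchanged.
            ways[i + 1][j] = (ways[i + 1][j] + cur * (n - m + j)) % mod;
            // Next element is a new value from [0, M-1]: j increases by 1.
            if (j < m) ways[i + 1][j + 1] = (ways[i + 1][j + 1] + cur * (m - j)) % mod;
        }
    }
    return ways;
}
```

As a sanity check, each row i should sum to N^i (modulo the modulus), since every array over the N allowed values is counted in exactly one column.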

There are \mathcal{O}(N\cdot M) states and constant-time transitions from each, which is fast enough for the constraints of the easy version.

Note that this is not the only way to find these coefficients: there also exist non-dp solutions.
However, that’s more relevant to solving the harder version, and so will be further expanded upon there.
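As a rough illustration of what a non-dp route can look like (this is a standard inclusion-exclusion sketch, not necessarily the approach the harder version's editorial takes): an array of length L over the N allowed values has a MEX of M exactly when it contains every value in [0, M-1]. Counting those by inclusion-exclusion over the set of missing values and subtracting from N^L gives

\#\{\text{length-}L\text{ arrays with MEX } \ne M\} = N^L - \sum_{k=0}^{M} (-1)^k \binom{M}{k} (N-k)^L = \sum_{k=1}^{M} (-1)^{k+1} \binom{M}{k} (N-k)^L

with the convention 0^0 = 1. For example, with N = 3, M = 2, L = 3 this gives 2 \cdot 2^3 - 1 \cdot 1^3 = 15.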


With the ways table computed, let’s return to our original dp.

Define w_L = ways[L][0] + ways[L][1] + \ldots + ways[L][M-1].
This is the number of arrays of length L whose MEX is not M (while not containing M.)
In terms of transitions, we now simply have:

dp_i = \sum_{j \lt i} dp_j \cdot w_{i-j-1}

with the base case dp_0 = 1, corresponding to the sentinel occurrence i_0 = 0.

This allows us to compute all the values of dp in \mathcal{O}(N^2) time overall, which is fast enough.
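The recurrence can be sketched in isolation as follows. The function name `count_from_block_counts` is hypothetical; it takes the values w_0, \ldots, w_N as input (assumed already reduced modulo `mod`), treats positions 0 and N+1 as sentinel occurrences of M, and returns the dp value at position N+1.

```cpp
#include <cassert>
#include <vector>

// Given w[L] = number of valid M-free blocks of length L (for L = 0..N, reduced mod `mod`),
// returns dp[N+1], where dp[i] = sum over j < i of dp[j] * w[i-j-1] and dp[0] = 1.
// Hypothetical helper for illustration.
long long count_from_block_counts(const std::vector<long long>& w, long long mod = 998244353) {
    assert(!w.empty());
    int n = (int)w.size() - 1;
    std::vector<long long> dp(n + 2, 0);
    dp[0] = 1;  // sentinel occurrence of M at position 0
    for (int i = 1; i <= n + 1; ++i) {
        for (int j = 0; j < i; ++j) {
            // j is the previous occurrence of M; the block between has length i-j-1.
            dp[i] = (dp[i] + dp[j] * w[i - j - 1]) % mod;
        }
    }
    return dp[n + 1];
}
```

For N = 2, M = 2 the block counts are w = (1, 2, 2), and the sketch returns 7: of the 9 arrays over {0, 1, 2}, only [0, 1] and [1, 0] are excluded.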

The final answer is dp_{N+1} since we can pretend the array has a copy of M at position N+1 and it won’t change anything.
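For very small inputs, the final answer can be cross-checked against direct enumeration. The sketch below is only a testing aid, not part of the intended solution: the name `brute_force_count` and the N \le 7 cap are assumptions made here, and it enumerates all (N+1)^N arrays over [0, N].

```cpp
#include <cassert>
#include <vector>

// Counts arrays of length n over [0, n] with no subarray of MEX m by full enumeration.
// Hypothetical testing aid; assumes 1 <= m <= n and n small enough to enumerate.
long long brute_force_count(int n, int m) {
    assert(1 <= m && m <= n && n <= 7);
    std::vector<int> a(n, 0);
    long long good = 0;
    while (true) {
        bool bad = false;
        for (int l = 0; l < n && !bad; ++l) {
            std::vector<bool> seen(m, false);
            int distinct_below = 0;
            for (int r = l; r < n; ++r) {
                if (a[r] == m) break;  // every longer subarray from l contains m
                if (a[r] < m && !seen[a[r]]) { seen[a[r]] = true; ++distinct_below; }
                if (distinct_below == m) { bad = true; break; }  // a[l..r] has MEX m
            }
        }
        if (!bad) ++good;
        // Advance `a` like a base-(n+1) counter; stop after the last array.
        int pos = 0;
        while (pos < n && a[pos] == n) a[pos++] = 0;
        if (pos == n) break;
        ++a[pos];
    }
    return good;
}
```

For example, with N = 3 and M = 2 there are 64 arrays in total, 16 of which contain a 2-free block holding both 0 and 1, leaving 48.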

TIME COMPLEXITY:

\mathcal{O}(NM + N^2) per testcase.

CODE:

Editorialist's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    const int mod = 998244353;
    int t; cin >> t;
    while (t--) {
        int n, m; cin >> n >> m;

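        // ways[i][j]: number of length-i arrays over [0, m-1] and [m+1, n]
        // containing exactly j distinct values from [0, m-1].
        vector ways(n+2, vector(m+2, 0ll));
        // w[i]: number of such arrays with fewer than m distinct values from [0, m-1],
        // i.e. blocks of length i whose MEX is not m.
        vector w(n+1, 0ll);
        ways[0][0] = 1;
        for (int i = 0; i <= n; ++i) for (int j = 0; j <= m; ++j) {
            ways[i][j] %= mod;
            if (j < m) w[i] = (w[i] + ways[i][j]) % mod;

            // Append one of the j values already seen, or a value from [m+1, n]: j is unchanged.
            ways[i+1][j] += (n-m+j) * ways[i][j];
            // Append a new value from [0, m-1]: j increases by 1.
            ways[i+1][j+1] += (m-j) * ways[i][j];
        }

        // dp[i]: ways to fill the first i positions with A_i = m and no subarray of MEX m.
        // Position 0 acts as a sentinel occurrence of m.
        vector dp(n+2, 0ll);
        dp[0] = 1;
        for (int i = 1; i <= n+1; ++i) {
            // j is the previous occurrence of m; the block in between has length i-j-1.
            for (int j = i-1; j >= 0; --j) {
                dp[i] = (dp[i] + dp[j] * w[i-j-1]) % mod;
            }
        }
        // Position n+1 is a sentinel occurrence of m, so dp[n+1] counts all valid arrays.
        cout << dp[n+1] << '\n';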
    }
}