PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: raysh07
Tester: sushil2006
Editorialist: iceknight1093
DIFFICULTY:
Easy
PREREQUISITES:
Dynamic programming, combinatorics
PROBLEM:
Given N and M, count the number of arrays A of length N with elements in [0, N] such that no subarray of A has its MEX equal to M.
In this version of the problem, M \le N \le 3000.
EXPLANATION:
For a set to have a MEX of M, it must contain (at least) one copy of each of the elements 0, 1, 2, \ldots, M-1, and must not contain M.
So, if in an array A the indices containing the value M are i_1, i_2, \ldots, i_k, then the “worst case” subarrays for us are exactly those of the form A[i_j+1, i_{j+1}-1], i.e. the entire subarray between two consecutive occurrences of M.
(For convenience, we treat i_0 = 0 and i_{k+1} = N+1).
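As long as none of these worst-case subarrays has a MEX of M, no subarray has a MEX of M, and vice versa. Indeed, a subarray with MEX M cannot contain M, so it lies inside one of these blocks. The block contains every value the subarray does and still contains no M, so its MEX is M as well.
So, we can focus our efforts on these blocks.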
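One way to gain confidence in this characterization is to test it by brute force on tiny inputs. The sketch below is not part of the intended solution, and the helper names (mexOf, anySubarrayHasMex, anyBlockHasMex, forEachArray, characterizationHolds, bruteCount) are hypothetical. It enumerates every array of length N with entries in [0, N] and checks that "some subarray has MEX M" agrees with "some maximal block between consecutive copies of M has MEX M". It assumes 1 \le M \le N, so an empty block (whose MEX is 0) never matters.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// MEX of a[l..r] (inclusive): the smallest non-negative integer not present.
int mexOf(const vector<int>& a, int l, int r) {
    vector<bool> seen(a.size() + 2, false);
    for (int k = l; k <= r; ++k)
        if (a[k] < (int)seen.size()) seen[a[k]] = true;
    int x = 0;
    while (seen[x]) ++x;
    return x;
}

// Direct check: does any non-empty subarray have MEX exactly m?
bool anySubarrayHasMex(const vector<int>& a, int m) {
    int n = a.size();
    for (int l = 0; l < n; ++l)
        for (int r = l; r < n; ++r)
            if (mexOf(a, l, r) == m) return true;
    return false;
}

// Characterization: only inspect the maximal blocks strictly between
// consecutive copies of m (with sentinel copies at indices -1 and n).
bool anyBlockHasMex(const vector<int>& a, int m) {
    int n = a.size();
    int prev = -1;
    for (int i = 0; i <= n; ++i) {
        if (i == n || a[i] == m) {
            // Empty blocks have MEX 0, which never equals m >= 1.
            if (prev + 1 <= i - 1 && mexOf(a, prev + 1, i - 1) == m) return true;
            prev = i;
        }
    }
    return false;
}

// Calls f on every array of length n with entries in [0, n] (odometer order).
template <class F>
void forEachArray(int n, F f) {
    vector<int> a(n, 0);
    while (true) {
        f(a);
        int k = 0;
        while (k < n && a[k] == n) a[k++] = 0;
        if (k == n) return;
        ++a[k];
    }
}

// True if the direct check and the block check agree on every array.
bool characterizationHolds(int n, int m) {
    assert(1 <= m && m <= n);
    bool ok = true;
    forEachArray(n, [&](const vector<int>& a) {
        if (anySubarrayHasMex(a, m) != anyBlockHasMex(a, m)) ok = false;
    });
    return ok;
}

// Brute-force answer: number of arrays with no subarray of MEX m.
long long bruteCount(int n, int m) {
    assert(1 <= m && m <= n);
    long long cnt = 0;
    forEachArray(n, [&](const vector<int>& a) {
        if (!anySubarrayHasMex(a, m)) ++cnt;
    });
    return cnt;
}
```

The enumeration grows as (N+1)^N, so this is only practical for N up to about 5 or 6, but bruteCount also gives reference answers for checking the dp below.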
The above characterization lends itself quite naturally to a dynamic programming solution.
Let’s define dp_i to be the number of ways to fill in the first i elements such that:
- No subarray of A_1, \ldots, A_i has a MEX of M, and
- A_i = M
Essentially, we’re building up the array by placing entire blocks of elements between occurrences of M.
The transitions are quite simple (in theory): to compute dp_i, let j \lt i be the previous occurrence of M; then there are dp_j ways to fill in elements till index j, and we only need to figure out how to fill in elements at indices j+1, \ldots, i-1 (which can be done completely independently of the rest of the array.)
Note that this requires us to solve the following subproblem:
Given an integer L, how many arrays of length L containing elements in [0, M-1] \cup [M+1, N] do not have a MEX of M?
We also want to know this for each value of L from 0 to N, to assist with transitions.
There are a few different ways to solve this subproblem: one of them is using dynamic programming as follows.
Define ways[i][j] to be the number of arrays of length i with elements in [0, M-1]\cup [M+1, N] such that they contain exactly j distinct elements among [0, M-1].
Transitions are fairly simple:
- If the next element is either one of the j existing ones in [0, M-1] or an element from [M+1, N], the value of j doesn’t change. There are j choices for the former and N-M for the latter, so we add ways[i][j] \cdot (N-M+j) to ways[i+1][j].
- If the next element is a new one from [0, M-1], j increases by 1. There are M-j choices for the element, so we add ways[i][j] \cdot (M-j) to ways[i+1][j+1].
- The base case is ways[0][0] = 1.
There are \mathcal{O}(N\cdot M) states and constant-time transitions from each, which is fast enough for the constraints of the easy version.
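The ways table on its own can be sketched as follows. This assumes the answer is wanted modulo 998244353 (the modulus in the editorialist's code); buildWays is a hypothetical name. Unlike the code at the end, it reduces modulo after every addition and sizes the table exactly (N+1) \times (M+1).

```cpp
#include <cassert>
#include <vector>
using namespace std;

const long long MOD = 998244353;

// ways[i][j]: number of length-i arrays over [0, m-1] U [m+1, n]
// that use exactly j distinct values from [0, m-1], modulo MOD.
vector<vector<long long>> buildWays(int n, int m) {
    assert(0 <= m && m <= n);
    vector<vector<long long>> ways(n + 1, vector<long long>(m + 1, 0));
    ways[0][0] = 1;  // the empty array
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j <= m; ++j) {
            if (ways[i][j] == 0) continue;
            // Next value: one of the j seen values, or one of the n-m values above m.
            ways[i + 1][j] = (ways[i + 1][j] + ways[i][j] * (n - m + j)) % MOD;
            // Next value: one of the m-j unseen values in [0, m-1].
            if (j < m)
                ways[i + 1][j + 1] = (ways[i + 1][j + 1] + ways[i][j] * (m - j)) % MOD;
        }
    }
    return ways;
}
```

Each row i sums to N^i, since every position has N allowed values. For example, with N = 3 and M = 2 (allowed values 0, 1, 3), row 2 is (1, 6, 2).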
Note that this is not the only way to find these coefficients: there also exist non-dp solutions.
However, that’s more relevant to solving the harder version, and so will be further expanded upon there.
With the ways table computed, let’s return to our original dp.
Define w_L = ways[L][0] + ways[L][1] + \ldots + ways[L][M-1].
This is the number of arrays of length L whose MEX is not M (while not containing M.)
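As a cross-check for w_L, here is one non-dp formula based on inclusion-exclusion. It is a standard identity, not necessarily the method the harder version's editorial uses. An array of length L over the N allowed values (everything in [0, N] except M) contains all of 0, 1, \ldots, M-1 in exactly \sum_{k=0}^{M} (-1)^k \binom{M}{k} (N-k)^L ways, so w_L is N^L minus that sum. The function name wByInclusionExclusion and the modulus 998244353 are assumptions of this sketch.

```cpp
#include <cassert>
#include <vector>
using namespace std;

const long long MOD = 998244353;

// Fast modular exponentiation; power(0, 0) returns 1.
long long power(long long b, long long e) {
    long long r = 1;
    b %= MOD;
    while (e > 0) {
        if (e & 1) r = r * b % MOD;
        b = b * b % MOD;
        e >>= 1;
    }
    return r;
}

// w[L] for L = 0..n: length-L arrays over the n values of [0, n] \ {m}
// that do NOT contain all of 0..m-1 (i.e. whose MEX is not m).
vector<long long> wByInclusionExclusion(int n, int m) {
    assert(1 <= m && m <= n);
    // binom[k] = C(m, k) mod MOD, built with modular inverses (MOD is prime).
    vector<long long> binom(m + 1);
    binom[0] = 1;
    for (int k = 1; k <= m; ++k)
        binom[k] = binom[k - 1] * (m - k + 1) % MOD * power(k, MOD - 2) % MOD;

    vector<long long> w(n + 1);
    for (int L = 0; L <= n; ++L) {
        // containsAll = sum_k (-1)^k C(m, k) (n - k)^L
        long long containsAll = 0;
        for (int k = 0; k <= m; ++k) {
            long long term = binom[k] * power(n - k, L) % MOD;
            containsAll = (k % 2 == 0) ? (containsAll + term) % MOD
                                       : (containsAll - term + MOD) % MOD;
        }
        w[L] = (power(n, L) - containsAll + MOD) % MOD;
    }
    return w;
}
```

As written this spends \mathcal{O}(NM \log N) time because of the modular power in the inner loop; precomputing the powers (N-k)^L incrementally would bring it down to \mathcal{O}(NM).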
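In terms of transitions, we now simply have
dp_i = \sum_{j=0}^{i-1} dp_j \cdot w_{i-j-1}
with the base case dp_0 = 1 (the virtual copy of M at position i_0 = 0), because the block strictly between consecutive copies of M at positions j and i has length i-j-1.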
This allows us to compute all the values of dp in \mathcal{O}(N^2) time overall, which is fast enough.
The final answer is dp_{N+1} since we can pretend the array has a copy of M at position N+1 and it won’t change anything.
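The outer transition can be sketched on its own by taking the values w_0, \ldots, w_N as input. countFromBlocks is a hypothetical name, and the modulus 998244353 matches the editorialist's code.

```cpp
#include <cassert>
#include <vector>
using namespace std;

const long long MOD = 998244353;

// Given w[L] for L = 0..n, computes dp_i = sum_{j<i} dp_j * w_{i-j-1}
// with dp_0 = 1 and returns dp_{n+1}.
long long countFromBlocks(const vector<long long>& w, int n) {
    assert((int)w.size() >= n + 1);
    vector<long long> dp(n + 2, 0);
    dp[0] = 1;  // virtual copy of M at position 0
    for (int i = 1; i <= n + 1; ++i)
        for (int j = 0; j < i; ++j)  // j = previous copy of M
            dp[i] = (dp[i] + dp[j] * w[i - j - 1]) % MOD;
    return dp[n + 1];  // virtual copy of M at position n+1
}
```

For instance, with N = 3 and M = 2 the allowed block values are 0, 1, 3, so w = (1, 3, 7, 15) and the sketch returns 48, which matches a direct count over all 4^3 = 64 arrays.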
TIME COMPLEXITY:
\mathcal{O}(NM + N^2) per testcase.
CODE:
Editorialist's code (C++)
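// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    const int mod = 998244353;
    int t; cin >> t;
    while (t--) {
        int n, m; cin >> n >> m;

        // ways[i][j]: arrays of length i over [0, m-1] U [m+1, n]
        // containing exactly j distinct values from [0, m-1]
        vector ways(n+2, vector(m+2, 0ll));
        // w[L]: arrays of length L without m whose MEX is not m
        vector w(n+1, 0ll);
        ways[0][0] = 1;
        for (int i = 0; i <= n; ++i) for (int j = 0; j <= m; ++j) {
            ways[i][j] %= mod;
            if (j < m) w[i] = (w[i] + ways[i][j]) % mod;
            // next value is a seen one from [0, m-1] or one of the n-m values above m
            ways[i+1][j] += (n-m+j) * ways[i][j];
            // next value is one of the m-j unseen values in [0, m-1]
            ways[i+1][j+1] += (m-j) * ways[i][j];
        }

        // dp[i]: fillings of positions 1..i with A_i = m and no bad block;
        // dp[0] = 1 stands for a virtual copy of m at position 0
        vector dp(n+2, 0ll);
        dp[0] = 1;
        for (int i = 1; i <= n+1; ++i) {
            // j is the previous copy of m; the block between them has length i-j-1
            for (int j = i-1; j >= 0; --j) {
                dp[i] = (dp[i] + dp[j] * w[i-j-1]) % mod;
            }
        }
        // position n+1 holds a virtual copy of m, so dp[n+1] is the answer
        cout << dp[n+1] << '\n';
    }
}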