# GRANDPAPA - Editorial

Setter: Cozma Tiberiu-Stefan
Tester: Harris Leung
Editorialist: Trung Dang

# DIFFICULTY:

Medium-Hard

# PREREQUISITES:

Dynamic Programming, Combinatorics

# PROBLEM:

You are given three integers N, M and MOD. A pair of arrays (A, B) is called K-beautiful if and only if:

• |A| = |B| = N (i.e. both their lengths are equal to N)
• For all i such that 1 \le i \le N, 0 \le A_i \le M and 0 \le B_i \le M.
• mex(A) \leq mex(B)
• A is lexicographically smaller than B
• The first index i where A_i \lt B_i is K (since A is lexicographically smaller than B, this means A_{1..K-1} = B_{1..K-1} and A_K \lt B_K)

For each K (1 \le K \le N), find the number of K-beautiful pairs and print those values modulo MOD.
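Because the conditions interact in subtle ways, it helps to have a reference count for tiny inputs. The brute-force sketch below enumerates every pair (A, B) and returns exact counts for K = 1..N, without reducing modulo MOD. The helpers `mexOf`, `nextArray` and `bruteForce` are hypothetical names introduced here, not part of the setter's or editorialist's code, and the enumeration of (M + 1)^{2N} pairs only makes sense for very small N and M.

```cpp
#include <cassert>
#include <vector>

// mex of v: the smallest non-negative integer missing from v (always <= v.size()).
int mexOf(const std::vector<int>& v) {
    std::vector<bool> seen(v.size() + 1, false);
    for (int x : v)
        if (x <= (int)v.size()) seen[x] = true;
    int r = 0;
    while (seen[r]) ++r;
    return r;
}

// Advance v to the next array of [0..m]^n in odometer order;
// returns false (leaving v all zeros) after the last array.
bool nextArray(std::vector<int>& v, int m) {
    for (int i = (int)v.size() - 1; i >= 0; --i) {
        if (v[i] < m) { ++v[i]; return true; }
        v[i] = 0;
    }
    return false;
}

// cnt[K - 1] = number of K-beautiful pairs, computed exactly (no modulus).
std::vector<long long> bruteForce(int n, int m) {
    assert(n >= 1 && m >= 0);
    std::vector<long long> cnt(n, 0);
    std::vector<int> a(n, 0);
    do {
        int mexA = mexOf(a);
        std::vector<int> b(n, 0);
        do {
            int d = 0;  // first position where A and B differ (0-based)
            while (d < n && a[d] == b[d]) ++d;
            if (d < n && a[d] < b[d] && mexA <= mexOf(b))
                ++cnt[d];  // K = d + 1
        } while (nextArray(b, m));
    } while (nextArray(a, m));
    return cnt;
}
```

For N = 2 and M = 1 it returns {2, 1}: for K = 1 the pairs are ([0, 0], [1, 0]) and ([0, 1], [1, 0]), and for K = 2 the only pair is ([0, 0], [0, 1]).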

# EXPLANATION:

We first loop over K, then over the exact value of mex(A) (which I will denote as MEX), and finally over the exact number of distinct values that are less than MEX and appear in the common prefix A_{1..K-1} (equivalently, B_{1..K-1}).

Denote by S_1 the set of distinct values that are less than MEX and appear in the common prefix, by S_2 the set of values that are less than MEX and do not appear in the common prefix, by S_3 the set of values in [0..M] larger than MEX, and by S_4 the set of values in [0..M] larger than or equal to MEX. Thus S_1 \cup S_2 = [0..MEX - 1], and S_4 = S_3 \cup \{MEX\} whenever MEX \le M (both are empty when MEX = M + 1). If the prefix contains s distinct values below MEX, the sizes are |S_1| = s, |S_2| = MEX - s, |S_3| = \max(0, M - MEX) and |S_4| = M - MEX + 1. We then have the following observations:

• Since we want the mex of A to be exactly MEX, the suffix A_{K..N} must contain all values in S_2. Similarly, since we want the mex of B to be at least MEX, the suffix B_{K..N} must contain all values in S_2.
• Every element of A that is not in S_1 or S_2 must be in S_3 (A must not contain MEX, since we want mex(A) to be exactly MEX).
• Every element of B that is not in S_1 or S_2 must be in S_4 (since we want mex(B) to be at least MEX), except in the prefix B_{1..K - 1}, where it must come from S_3 (since this is the common prefix).
• Any value in S_1 or S_2 is less than any value in S_3 or S_4.

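This leads us to split on which sets A_K and B_K belong to. A_K must lie in S_1 \cup S_2 \cup S_3 and B_K in S_1 \cup S_2 \cup S_4, with A_K < B_K. If A_K \in S_3, then B_K > A_K > MEX, so B_K \in S_3 as well; if A_K \in S_1 \cup S_2, then B_K can be in any of S_1, S_2, S_4. This gives 2 \cdot 3 + 1 = 7 cases: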

• A_K \in S_1, B_K \in S_1.
• A_K \in S_1, B_K \in S_2.
• A_K \in S_1, B_K \in S_4.
• A_K \in S_2, B_K \in S_1.
• A_K \in S_2, B_K \in S_2.
• A_K \in S_2, B_K \in S_4.
• A_K \in S_3, B_K \in S_3.
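The case analysis can be double-checked mechanically. The sketch below (`feasibleCases` is a hypothetical helper) fixes a concrete M, MEX and S_1, enumerates every pair of values x < y that A_K = x and B_K = y could take, and records which sets they fall in. On B's side any value \ge MEX is labelled S4, since B may use all of them.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <utility>
#include <vector>

using Case = std::pair<std::string, std::string>;

// inS1[v] (for v < mex) says whether v belongs to S_1; other values below mex are in S_2.
std::set<Case> feasibleCases(int m, int mex, const std::vector<bool>& inS1) {
    assert((int)inS1.size() == mex && mex <= m + 1);
    // A_K may be any value except MEX itself: S_1, S_2 or S_3.
    auto classA = [&](int v) -> std::string {
        if (v < mex) return inS1[v] ? "S1" : "S2";
        return v > mex ? "S3" : "";  // "" marks MEX, which A cannot contain
    };
    // B_K may be any value in [0..m]: S_1, S_2 or S_4 (every value >= MEX).
    auto classB = [&](int v) -> std::string {
        if (v < mex) return inS1[v] ? "S1" : "S2";
        return "S4";
    };
    std::set<Case> cases;
    for (int x = 0; x <= m; ++x) {
        if (classA(x).empty()) continue;
        for (int y = x + 1; y <= m; ++y)  // enforce A_K < B_K
            cases.insert({classA(x), classB(y)});
    }
    return cases;
}
```

With M = 6, MEX = 4 and S_1 = \{0, 2\} (so S_2 = \{1, 3\} and S_3 = \{5, 6\}), the result has exactly 7 entries. The entry (S3, S4) is the editorial's last case: because B_K > A_K > MEX, B_K always lies in S_3.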

I will guide you through the first case, but other cases can be reasoned similarly. The number of ways to have A_K \in S_1, B_K \in S_1 is:

\binom{MEX}{|S_1|} \cdot \frac{|S_1| \cdot (|S_1| - 1)}{2} \cdot f(K - 1, |S_1|, |S_3|) \cdot f(N - K, |S_2|, |S_1| + |S_3|) \cdot f(N - K, |S_2|, |S_1| + |S_4|)

where f(i, j, k) is the number of arrays of length i whose elements are drawn from j "fixed" values and k "free" values, such that every fixed value appears at least once (fixed values may repeat, and free values may be used any number of times, including zero). To explain the formula above:

• \binom{MEX}{|S_1|} comes from choosing the set S_1 from [0..MEX - 1].
• \frac{|S_1| \cdot (|S_1| - 1)}{2} comes from choosing the two values A_K < B_K from S_1.
• f(K - 1, |S_1|, |S_3|) comes from filling the common prefix so that it contains every value in S_1, and every element not in S_1 comes from S_3.
• f(N - K, |S_2|, |S_1| + |S_3|) comes from filling A_{K+1..N} so that it contains every value in S_2, and every element not in S_2 comes from S_1 \cup S_3.
• f(N - K, |S_2|, |S_1| + |S_4|) comes from filling B_{K+1..N} so that it contains every value in S_2, and every element not in S_2 comes from S_1 \cup S_4.
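As a sanity check, the case-1 term can be compared against exhaustive enumeration for tiny N and M. Two assumptions in this sketch: f is evaluated by inclusion-exclusion, f(i, j, k) = \sum_{t=0}^{j} (-1)^t \binom{j}{t} (j + k - t)^i, rather than by the recurrence given below, so that the check does not depend on it; and all helper names (`case1Formula`, `case1Brute`, `checkCase1`, ...) are hypothetical. Arithmetic is exact, with no modulus.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

long long ipow(long long b, int e) { long long r = 1; while (e-- > 0) r *= b; return r; }

long long binom(int n, int r) {
    if (r < 0 || r > n) return 0;
    long long c = 1;
    for (int i = 1; i <= r; ++i) c = c * (n - r + i) / i;  // c = C(n - r + i, i)
    return c;
}

// f(i, j, k) by inclusion-exclusion over fixed values that are missing.
long long f(int i, int j, int k) {
    long long s = 0;
    for (int t = 0; t <= j; ++t)
        s += (t % 2 ? -1 : 1) * binom(j, t) * ipow(j + k - t, i);
    return s;
}

// The case-1 term (A_K and B_K both in S_1) exactly as in the formula above.
long long case1Formula(int n, int m, int K, int mex, int s1) {
    int s2 = mex - s1, s3 = std::max(0, m - mex), s4 = m - mex + 1;
    return binom(mex, s1) * (s1 * (s1 - 1) / 2) * f(K - 1, s1, s3)
         * f(n - K, s2, s1 + s3) * f(n - K, s2, s1 + s4);
}

int mexOf(const std::vector<int>& v) {
    std::vector<bool> seen(v.size() + 1, false);
    for (int x : v) if (x <= (int)v.size()) seen[x] = true;
    int r = 0;
    while (seen[r]) ++r;
    return r;
}

bool nextArray(std::vector<int>& v, int m) {  // odometer over [0..m]^n
    for (int i = (int)v.size() - 1; i >= 0; --i) {
        if (v[i] < m) { ++v[i]; return true; }
        v[i] = 0;
    }
    return false;
}

// Brute-force count of the same case over all pairs (A, B).
long long case1Brute(int n, int m, int K, int mex, int s1) {
    long long cnt = 0;
    std::vector<int> a(n, 0);
    do {
        if (mexOf(a) != mex) continue;  // jumps to the loop condition (next a)
        std::vector<int> b(n, 0);
        do {
            int d = 0;  // first differing position, 0-based
            while (d < n && a[d] == b[d]) ++d;
            if (d != K - 1 || a[d] >= b[d] || mexOf(b) < mex) continue;
            std::vector<bool> inPrefix(m + 1, false);
            for (int i = 0; i < K - 1; ++i) inPrefix[a[i]] = true;
            int seenBelow = 0;  // |S_1|: distinct prefix values below MEX
            for (int v = 0; v < mex && v <= m; ++v) seenBelow += inPrefix[v];
            bool aInS1 = a[d] < mex && inPrefix[a[d]];
            bool bInS1 = b[d] < mex && inPrefix[b[d]];
            if (seenBelow == s1 && aInS1 && bInS1) ++cnt;
        } while (nextArray(b, m));
    } while (nextArray(a, m));
    return cnt;
}

// True if formula and brute force agree for every (K, MEX, |S_1|).
bool checkCase1(int n, int m) {
    for (int K = 1; K <= n; ++K)
        for (int mex = 0; mex <= m + 1; ++mex)
            for (int s1 = 0; s1 <= mex; ++s1)
                if (case1Formula(n, m, K, mex, s1) != case1Brute(n, m, K, mex, s1))
                    return false;
    return true;
}
```

For instance, with N = 4, M = 2, K = 3, MEX = 2 and |S_1| = 2 the formula gives 1 \cdot 1 \cdot 2 \cdot 2 \cdot 3 = 12: the prefix is (0, 1) or (1, 0), A_3 = 0, B_3 = 1, A_4 \in \{0, 1\} and B_4 \in \{0, 1, 2\}.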

We have yet to cover how to compute f(i, j, k), but it turns out that this is a fairly easy task (reminiscent of counting the number of arrays with n elements and exactly k distinct values):

f(i, j, k) = f(i - 1, j - 1, k) + (j + k) \cdot f(i - 1, j, k), with base cases f(0, 0, k) = 1 and f(0, j, k) = 0 for j > 0.

That is because there are 3 ways to obtain an array counted by f(i, j, k) by appending one element to a shorter array:

• Append a fixed value that has already appeared: this can be done in j ways and comes from state f(i - 1, j, k).
• Append a new fixed value, which comes from state f(i - 1, j - 1, k).
• Append a free (non-fixed) value: this can be done in k ways and comes from state f(i - 1, j, k).

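The above recurrence does not quite compute what we want: the transition that introduces a new fixed value contributes a factor of 1 rather than a choice among the unused fixed values, so it only counts arrays in which the fixed values make their first appearances in one fixed order. Each of the j! orders yields the same number of arrays, so after the recursion we multiply f(i, j, k) by j!.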
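The recurrence and the final j! correction can be sketched for a single value of k as below. `buildF` is a hypothetical name, and plain 64-bit integers are used (assuming the values stay small) instead of the arithmetic modulo MOD that the actual solutions use.

```cpp
#include <cassert>
#include <vector>

// Returns F with F[i][j] = f(i, j, k) for 0 <= j <= i <= maxI.
std::vector<std::vector<long long>> buildF(int maxI, int k) {
    std::vector<std::vector<long long>> g(maxI + 1, std::vector<long long>(maxI + 1, 0));
    g[0][0] = 1;  // empty array, no fixed values
    for (int i = 1; i <= maxI; ++i)
        for (int j = 0; j <= i; ++j)
            g[i][j] = (j >= 1 ? g[i - 1][j - 1] : 0)  // new fixed value (one canonical order)
                    + (j + k) * g[i - 1][j];          // repeat a fixed value or use a free one
    // Undo the canonical order of first appearances: multiply by j!.
    for (int i = 0; i <= maxI; ++i) {
        long long fact = 1;
        for (int j = 0; j <= i; ++j) {
            if (j > 0) fact *= j;
            g[i][j] *= fact;
        }
    }
    return g;
}
```

For example, `buildF(3, 1)[3][2]` is 12 (the table holds 6 before the j! correction), matching the 3^3 - 2 \cdot 2^3 + 1 = 12 arrays of length 3 over \{a, b, c\} that contain both fixed values a and b.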

# TIME COMPLEXITY:

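The time complexity is O(M \cdot N \cdot \min(M, N)), as achieved by the setter's memoised solution below. The editorialist's solution fills the whole f table and loops over every |S_1| \le MEX, so it runs in O(N \cdot M \cdot (N + M)) instead.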

# SOLUTION:

Setter's Solution
```cpp
#include <bits/stdc++.h>
using namespace std;
int32_t MOD;
struct modint {
    int32_t value;
    modint() = default;
    modint(int32_t value_) : value((value_ < MOD ? value_ : value_% MOD)) {}
    modint(int64_t value_) : value((value_ < MOD ? value_ : value_ % MOD)) {}
    inline modint operator + (modint other) const { int32_t c = this->value + other.value; return modint(c >= MOD ? c - MOD : c); }
    inline modint operator - (modint other) const { int32_t c = this->value - other.value; return modint(c < 0 ? c + MOD : c); }
    inline modint operator * (modint other) const { int32_t c = (int64_t)this->value * other.value % MOD; return modint(c < 0 ? c + MOD : c); }
    inline modint& operator += (modint other) { this->value += other.value; if (this->value >= MOD) this->value -= MOD; return *this; }
    inline modint& operator -= (modint other) { this->value -= other.value; if (this->value < 0) this->value += MOD; return *this; }
    inline modint& operator *= (modint other) { this->value = (int64_t)this->value * other.value % MOD; if (this->value < 0) this->value += MOD; return *this; }
    inline modint operator - () const { return modint(this->value ? MOD - this->value : 0); }
    modint pow(int32_t k) const { modint x = *this, y = 1; for (; k; k >>= 1) { if (k & 1) y *= x; x *= x; } return y; }
    modint inv() const { return pow(MOD - 2); }  // MOD must be a prime
    inline modint operator /  (modint other) const { return *this * other.inv(); }
    inline modint operator /= (modint other) { return *this *= other.inv(); }
    inline bool operator == (modint other) const { return value == other.value; }
    inline bool operator != (modint other) const { return value != other.value; }
    inline bool operator < (modint other) const { return value < other.value; }
    inline bool operator > (modint other) const { return value > other.value; }
};
modint operator * (int64_t value, modint n) { return modint(value) * n; }
modint operator * (int32_t value, modint n) { return modint(value) * n; }
istream& operator >> (istream& in, modint& n) { return in >> n.value; }
ostream& operator << (ostream& out, modint n) { return out << n.value; }
vector<vector<modint>> pas;
vector<modint> fact;
vector<vector<vector<modint>>> dp;
modint combs(int n, int k) {
    if (n < k || k < 0) {
        return 0;
    }
    return pas[n][k];
}
modint solve(int i, int j, int k) {
    if (i < 0 || j < 0 || k < 0) {
        return 0;
    }
    if (i == 0) {
        return (j == 0);
    }
    if (dp[i][j][k] != -1) {
        return dp[i][j][k];
    }
    dp[i][j][k] = solve(i - 1, j, k) * j + solve(i - 1, j - 1, k) + solve(i - 1, j, k) * k;
    return dp[i][j][k];
}
modint count(int i, int j, int k) {
    if (i < 0 || j < 0 || k < 0) {
        return 0;
    }
    return solve(i, j, k) * fact[j];
}
int main() {
    std::ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    int n, m;
    cin >> n >> m >> MOD;
    ++m;
    pas = vector<vector<modint>>(m + 1, vector<modint>(m + 1));
    fact = vector<modint>(m + 1);
    fact[0] = 1;
    for (int i = 1; i <= m; ++i) {
        fact[i] = fact[i - 1] * i;
    }
    pas[0][0] = 1;
    for (int i = 1; i <= m; ++i) {
        for (int j = 0; j <= i; ++j) {
            if (j == i || j == 0) {
                pas[i][j] = 1;
            }
            else {
                pas[i][j] = pas[i - 1][j] + pas[i - 1][j - 1];
            }
        }
    }
    dp = vector<vector<vector<modint>>>(n + 1, vector<vector<modint>>(m + 1, vector<modint>(m + 1, -1)));
    vector<modint> ans(n + 1);
    for (int i = 0; i <= m; ++i) {
        for (int j = 1; j <= n; ++j) {
            for (int k = 0; k <= min(i, j - 1); ++k) {
                modint rep = 0;
                if (i >= 2) {
                    if (i == m) {
                        rep += (i * (i - 1) / 2) * count(j - 1, k, m - i) * count(n - j, i - k - 1, m - (i - k - 1)) * count(n - j, i - k - 1, m - (i - k - 1)) * combs(i - 2, k);
                        rep += (i * (i - 1) / 2) * count(j - 1, k, m - i) * count(n - j, i - k, m - (i - k)) * count(n - j, i - k - 1, m - (i - k - 1)) * combs(i - 2, k - 1);
                        rep += (i * (i - 1) / 2) * count(j - 1, k, m - i) * count(n - j, i - k - 1, m - (i - k - 1)) * count(n - j, i - k, m - (i - k)) * combs(i - 2, k - 1);
                        rep += (i * (i - 1) / 2) * count(j - 1, k, m - i) * count(n - j, i - k, m - (i - k)) * count(n - j, i - k, m - (i - k)) * combs(i - 2, k - 2);
                    }
                    else {
                        rep += (i * (i - 1) / 2) * count(j - 1, k, m - i - 1) * count(n - j, i - k - 1, m - (i - k - 1) - 1) * count(n - j, i - k - 1, m - (i - k - 1)) * combs(i - 2, k);
                        rep += (i * (i - 1) / 2) * count(j - 1, k, m - i - 1) * count(n - j, i - k, m - (i - k) - 1) * count(n - j, i - k - 1, m - (i - k - 1)) * combs(i - 2, k - 1);
                        rep += (i * (i - 1) / 2) * count(j - 1, k, m - i - 1) * count(n - j, i - k - 1, m - (i - k - 1) - 1) * count(n - j, i - k, m - (i - k)) * combs(i - 2, k - 1);
                        rep += (i * (i - 1) / 2) * count(j - 1, k, m - i - 1) * count(n - j, i - k, m - (i - k) - 1) * count(n - j, i - k, m - (i - k)) * combs(i - 2, k - 2);
                    }
                }
                if (i >= 1 && i < m) {
                    rep += (i * (m - i)) * count(j - 1, k, m - i - 1) * count(n - j, i - k - 1, m - 1 - (i - k - 1)) * count(n - j, i - k, m - (i - k)) * combs(i - 1, k);
                    rep += (i * (m - i)) * count(j - 1, k, m - i - 1) * count(n - j, i - k, m - 1 - (i - k)) * count(n - j, i - k, m - (i - k)) * combs(i - 1, k - 1);
                }
                if (m - i - 2 > 0) {
                    rep += ((m - i - 2) * (m - i - 1) / 2) * count(j - 1, k, m - 1 - i) * count(n - j, i - k, m - 1 - (i - k)) * count(n - j, i - k, m - (i - k)) * combs(i, k);
                }
                ans[j] += rep;
            }
        }
    }
    for (int i = 1; i <= n; ++i) {
        cout << ans[i] << ' ';
    }
}
```

Editorialist's Solution
```cpp
#include <bits/stdc++.h>
using namespace std;

const int N = 205;

int n, m, mod;
long long f[N][N][N], fct[N], c[N][N];

void init() {
    fct[0] = 1;
    for (int i = 1; i < N; i++) {
        fct[i] = fct[i - 1] * i % mod;
    }
    for (int i = 0; i < N; i++) {
        c[i][0] = 1;
        for (int j = 1; j <= i; j++) {
            c[i][j] = (c[i - 1][j] + c[i - 1][j - 1]) % mod;
        }
    }
    for (int oth = 0; oth <= m + 1; oth++) {
        f[0][0][oth] = 1;
        for (int i = 1; i <= n; i++) {
            for (int j = 0; j <= i; j++) {
                f[i][j][oth] = ((j >= 1 ? f[i - 1][j - 1][oth] : 0) + f[i - 1][j][oth] * (j + oth)) % mod;
            }
        }
    }
    for (int oth = 0; oth <= m + 1; oth++) {
        for (int i = 0; i <= n; i++) {
            for (int j = 0; j <= i; j++) {
                (f[i][j][oth] *= fct[j]) %= mod;
            }
        }
    }
}

long long C(int n, int k) {
    return k < 0 || k > n ? 0 : c[n][k];
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m >> mod;
    init();
    for (int i = 1; i <= n; i++) {
        long long cur = 0;
        for (int mex = 0; mex <= m + 1; mex++) { // mex of A
            for (int sa = 0; sa <= mex; sa++) {  // # of distinct values < mex appearing in the common prefix a[1..i-1] (b[1..i-1])
                // S_1 = set of distinct values < mex appearing before i. |S_1| = sa
                int s1 = sa;
                // S_2 = set of distinct values < mex not appearing before i. |S_2| = mex - sa
                int s2 = mex - sa;
                // S_3 = set of values > mex. |S_3| = max(0, m - mex)
                int s3 = max(0, m - mex);
                // S_4 = set of values >= mex. |S_4| = m - mex + 1
                int s4 = m - mex + 1;

                // a[i] in S_1, b[i] in S_1
                if (s1 >= 2) {
                    cur += 1LL * s1 * (s1 - 1) / 2 % mod // choose a[i], b[i] from s1
                        * c[mex][s1] % mod // pick s1 values from mex
                        * f[i - 1][s1][s3] % mod // common prefix
                        * f[n - i][s2][s1 + s3] % mod // choose a[i+1..n]
                        * f[n - i][s2][s1 + s4] % mod; // choose b[i+1..n]
                }

                // a[i] in S_1, b[i] in S_2
                if (s1 >= 1 && s2 >= 1) {
                    cur += 1LL * mex * (mex - 1) / 2 % mod // choose a[i], b[i] from mex
                        * c[mex - 2][s1 - 1] % mod // pick remaining s1 - 1 values from mex - 2
                        * f[i - 1][s1][s3] % mod // common prefix
                        * f[n - i][s2][s1 + s3] % mod // choose a[i+1..n]
                        * f[n - i][s2 - 1][s1 + 1 + s4] % mod; // choose b[i+1..n]
                }

                // a[i] in S_1, b[i] in S_4
                if (s1 >= 1 && s4 >= 1) {
                    cur += 1LL * s1 * s4 % mod // choose a[i] from s1, b[i] from s4
                        * c[mex][s1] % mod // pick s1 values from mex
                        * f[i - 1][s1][s3] % mod // common prefix
                        * f[n - i][s2][s1 + s3] % mod // choose a[i+1..n]
                        * f[n - i][s2][s1 + s4] % mod; // choose b[i+1..n]
                }

                // a[i] in S_2, b[i] in S_1
                if (s2 >= 1 && s1 >= 1) {
                    cur += 1LL * mex * (mex - 1) / 2 % mod // choose a[i], b[i] from mex
                        * c[mex - 2][s1 - 1] % mod // pick remaining s1 - 1 values from mex - 2
                        * f[i - 1][s1][s3] % mod // common prefix
                        * f[n - i][s2 - 1][s1 + 1 + s3] % mod // choose a[i+1..n]
                        * f[n - i][s2][s1 + s4] % mod; // choose b[i+1..n]
                }

                // a[i] in S_2, b[i] in S_2
                if (s2 >= 2) {
                    cur += 1LL * s2 * (s2 - 1) / 2 % mod // choose a[i], b[i] from s2
                        * c[mex][s1] % mod // pick s1 values from mex
                        * f[i - 1][s1][s3] % mod // common prefix
                        * f[n - i][s2 - 1][s1 + 1 + s3] % mod // choose a[i+1..n]
                        * f[n - i][s2 - 1][s1 + 1 + s4] % mod; // choose b[i+1..n]
                }

                // a[i] in S_2, b[i] in S_4
                if (s2 >= 1 && s4 >= 1) {
                    cur += 1LL * s2 * s4 % mod // choose a[i] from s2, b[i] from s4
                        * c[mex][s1] % mod // pick s1 values from mex
                        * f[i - 1][s1][s3] % mod // common prefix
                        * f[n - i][s2 - 1][s1 + 1 + s3] % mod // choose a[i+1..n]
                        * f[n - i][s2][s1 + s4] % mod; // choose b[i+1..n]
                }

                // a[i] in S_3, b[i] in S_3
                if (s3 >= 2) {
                    cur += 1LL * s3 * (s3 - 1) / 2 % mod // choose a[i], b[i] from s3
                        * c[mex][s1] % mod // pick s1 values from mex
                        * f[i - 1][s1][s3] % mod // common prefix
                        * f[n - i][s2][s1 + s3] % mod // choose a[i+1..n]
                        * f[n - i][s2][s1 + s4] % mod; // choose b[i+1..n]
                }

                cur %= mod;
            }
        }
        cout << cur << " ";
    }
}
```
