GRANDPAPA - Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Contest Division 4

Setter: Cozma Tiberiu-Stefan
Tester: Harris Leung
Editorialist: Trung Dang

DIFFICULTY:

Medium-Hard

PREREQUISITES:

Dynamic Programming, Combinatorics

PROBLEM:

You are given three integers N, M and MOD. A pair of arrays (A, B) is called K-beautiful if and only if:

  • |A| = |B| = N (i.e. both their lengths are equal to N)
  • For all i such that 1 \le i \le N, 0 \le A_i \le M and 0 \le B_i \le M.
  • mex(A) \leq mex(B)
  • A is lexicographically smaller than B
  • The first index i where A_i \lt B_i is K

For each K (1 \le K \le N), find the number of K-beautiful pairs and print those values modulo MOD.
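To make the definition concrete, here is a brute force for very small N and M. The names `mexOf`, `decode` and `bruteForce` are hypothetical and do not come from the setter's or editorialist's code. It enumerates every pair (A, B) directly, so it is only useful for checking a faster solution on tiny inputs, and it returns exact counts without the modulus.

```cpp
#include <cassert>
#include <vector>

// Hypothetical brute force, only for tiny N and M: it enumerates all
// (M+1)^N arrays for A and for B, i.e. (M+1)^(2N) pairs.
// Counts are exact (no modulus).

// mex of an array whose values lie in [0, m]
int mexOf(const std::vector<int>& a, int m) {
    std::vector<bool> seen(m + 2, false);
    for (int x : a) seen[x] = true;
    int r = 0;
    while (seen[r]) ++r;  // stops at m + 1 at the latest
    return r;
}

// Interpret `code` as n base-(m+1) digits, giving one array in [0, m]^n
std::vector<int> decode(long long code, int n, int m) {
    std::vector<int> a(n);
    for (int i = 0; i < n; ++i) {
        a[i] = static_cast<int>(code % (m + 1));
        code /= (m + 1);
    }
    return a;
}

// cnt[K] = number of K-beautiful pairs for 1 <= K <= N (cnt[0] is unused)
std::vector<long long> bruteForce(int n, int m) {
    assert(n >= 1 && m >= 0);
    long long total = 1;
    for (int i = 0; i < n; ++i) total *= (m + 1);
    std::vector<long long> cnt(n + 1, 0);
    for (long long ca = 0; ca < total; ++ca) {
        std::vector<int> a = decode(ca, n, m);
        int mexA = mexOf(a, m);
        for (long long cb = 0; cb < total; ++cb) {
            std::vector<int> b = decode(cb, n, m);
            int k = 0;  // 0-based index of the first position where A and B differ
            while (k < n && a[k] == b[k]) ++k;
            if (k == n || a[k] > b[k]) continue;  // A is not lexicographically smaller than B
            if (mexA <= mexOf(b, m)) ++cnt[k + 1];
        }
    }
    return cnt;
}
```

For example, with N = 2 and M = 1 it finds 2 pairs for K = 1, namely ([0, 0], [1, 0]) and ([0, 1], [1, 0]), and 1 pair for K = 2, namely ([0, 0], [0, 1]).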

EXPLANATION:

We first loop over K, then loop over the exact value of mex(A) (which I will denote as MEX), and finally over the exact number of distinct values that are less than MEX and appear in the common prefix A_{1..K-1} (or B_{1..K-1}).

Denote by S_1 the set of distinct values that are less than MEX and appear in the common prefix, by S_2 the set of values that are less than MEX and do not appear in the common prefix, by S_3 the set of values larger than MEX, and by S_4 the set of values larger than or equal to MEX. All sets only contain values in [0..M], so S_1 \cup S_2 = [0..MEX - 1] and, when MEX \le M, S_4 = S_3 \cup \{MEX\}. Their sizes are |S_1| + |S_2| = MEX, |S_3| = \max(0, M - MEX) and |S_4| = M - MEX + 1 (matching s1, s2, s3 and s4 in the editorialist's code). We then have the following observations:

  • Since we want the mex value of A to be exactly MEX, the suffix A_{K..N} must contain all values within S_2. Similarly, since we want the mex value of B to be at least MEX and the common prefix only supplies S_1, the suffix B_{K..N} must contain all values within S_2.
  • Every element in A that is not in S_1 or S_2 must be in S_3 (since we want mex(A) to be exactly MEX).
  • Every element in B that is not in S_1 or S_2 must be in S_4 (since we want mex(B) to be at least MEX), except in the prefix B_{1..K - 1}, where it must come from S_3 (since this is the common prefix).
  • Any value in S_1 or S_2 is less than any value in S_3 or S_4.

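This leads us to loop over which sets A_K and B_K can be in. By the observations above, A_K \in S_1 \cup S_2 \cup S_3 and B_K \in S_1 \cup S_2 \cup S_4; moreover, if A_K \in S_3, then B_K > A_K > MEX forces B_K \in S_3. This leaves 7 cases: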

  • A_K \in S_1, B_K \in S_1.
  • A_K \in S_1, B_K \in S_2.
  • A_K \in S_1, B_K \in S_4.
  • A_K \in S_2, B_K \in S_1.
  • A_K \in S_2, B_K \in S_2.
  • A_K \in S_2, B_K \in S_4.
  • A_K \in S_3, B_K \in S_3.
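The case list can be double-checked mechanically. The sketch below uses hypothetical helpers `tagOf` and `occurringCases`: it fixes M, MEX and S_1, tags every value in [0..M], and collects the set pairs that actually occur for A_K < B_K with A_K \ne MEX. Whenever A_K \in S_3 it asserts that B_K \in S_3 too, which is why the last case is written as S_3, S_3 rather than S_3, S_4.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <utility>

// Tag a value v in [0, M] for a fixed MEX and S_1 (S_2 is the rest of [0, MEX))
std::string tagOf(int v, int mex, const std::set<int>& s1) {
    if (v < mex) return s1.count(v) ? "S1" : "S2";
    return v == mex ? "MEX" : "S3";
}

// Collect the distinct (set of A_K, set of B_K) pairs over all a = A_K < b = B_K,
// where a != MEX (mex(A) is exactly MEX) and b is any value in [0, M].
// B_K is reported as "S4" when it is >= MEX, except in the A_K in S_3 case.
std::set<std::pair<std::string, std::string>> occurringCases(int m, int mex, const std::set<int>& s1) {
    std::set<std::pair<std::string, std::string>> cases;
    for (int a = 0; a <= m; ++a) {
        std::string ta = tagOf(a, mex, s1);
        if (ta == "MEX") continue;  // A cannot contain MEX
        for (int b = a + 1; b <= m; ++b) {
            std::string tb = tagOf(b, mex, s1);
            if (ta == "S3") {
                assert(tb == "S3");  // b > a > MEX, so b cannot be MEX or below it
            } else if (tb == "MEX" || tb == "S3") {
                tb = "S4";
            }
            cases.insert(std::make_pair(ta, tb));
        }
    }
    return cases;
}
```

With M = 6, MEX = 4 and S_1 = \{0, 2\} all 7 cases appear. With S_1 = \{0\} only 5 appear, because (S_1, S_1) needs |S_1| \ge 2 and (S_2, S_1) needs some value of S_1 above a value of S_2; this is what the guards such as `s1 >= 2` in the editorialist's code account for.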

I will guide you through the first case, but other cases can be reasoned similarly. The number of ways to have A_K \in S_1, B_K \in S_1 is:

\binom{MEX}{|S_1|} \cdot \frac{|S_1| \cdot (|S_1| - 1)}{2} \cdot f(K - 1, |S_1|, |S_3|) \cdot f(N - K, |S_2|, |S_1| + |S_3|) \cdot f(N - K, |S_2|, |S_1| + |S_4|)

where f(i, j, k) is the number of ways to fill an array of i elements such that each of j “fixed” values appears at least once, and every element that is not a fixed value is one of k other allowed values. To explain the formula above:

  • \binom{MEX}{|S_1|} comes from choosing the set S_1 from [0..MEX - 1].
  • \frac{|S_1| \cdot (|S_1| - 1)}{2} comes from choosing the two values A_K < B_K, both from S_1.
  • f(K - 1, |S_1|, |S_3|) comes from filling the common prefix so that it contains all values in S_1 and every element not in S_1 comes from S_3.
  • f(N - K, |S_2|, |S_1| + |S_3|) comes from filling A_{K+1..N} so that it contains all values in S_2 and every element not in S_2 comes from S_1 \cup S_3.
  • f(N - K, |S_2|, |S_1| + |S_4|) comes from filling B_{K+1..N} so that it contains all values in S_2 and every element not in S_2 comes from S_1 \cup S_4.
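In case 1 the number of ways to pick A_K < B_K depends only on |S_1|. In the cases A_K \in S_1, B_K \in S_2 and A_K \in S_2, B_K \in S_1 it depends on which values form S_1, so the editorialist's code counts S_1 and the pair jointly as \binom{MEX}{2} \cdot \binom{MEX - 2}{|S_1| - 1}: pick the two values x < y from [0..MEX - 1], put each one in the set its case requires, then choose the remaining |S_1| - 1 members of S_1 from the other MEX - 2 values. The sketch below checks this identity by brute force over all subsets for small MEX; `binom`, `mixedPairsBrute` and `mixedPairsFormula` are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>

// Exact binomial coefficient for small arguments; 0 when k < 0 or k > n
long long binom(int n, int k) {
    if (k < 0 || k > n) return 0;
    long long r = 1;
    for (int i = 1; i <= k; ++i) r = r * (n - k + i) / i;  // r = C(n - k + i, i)
    return r;
}

// Brute force over every S_1 subset of [0, mex) with |S_1| = s1. Counts pairs
// x < y with x in S_1, y in S_2 (lowInS1 = true, case A_K in S_1, B_K in S_2)
// or x in S_2, y in S_1 (lowInS1 = false, case A_K in S_2, B_K in S_1).
long long mixedPairsBrute(int mex, int s1, bool lowInS1) {
    long long total = 0;
    for (uint32_t mask = 0; mask < (1u << mex); ++mask) {
        int size = 0;
        for (int v = 0; v < mex; ++v) size += (mask >> v) & 1;
        if (size != s1) continue;
        for (int x = 0; x < mex; ++x) {
            for (int y = x + 1; y < mex; ++y) {
                bool xIn = (mask >> x) & 1, yIn = (mask >> y) & 1;
                if (lowInS1 ? (xIn && !yIn) : (!xIn && yIn)) ++total;
            }
        }
    }
    return total;
}

// Closed form from the editorialist's code: C(MEX, 2) * C(MEX - 2, |S_1| - 1)
long long mixedPairsFormula(int mex, int s1) {
    return binom(mex, 2) * binom(mex - 2, s1 - 1);
}
```

For instance, with MEX = 4 and |S_1| = 2 both functions give 6 \cdot 2 = 12. The same factor appears in the setter's code as `(i * (i - 1) / 2) * ... * combs(i - 2, k - 1)`.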

We have yet to cover how to compute f(i, j, k), but it turns out to be a fairly easy task, reminiscent of counting the number of arrays with n elements and exactly k distinct values:

f(i, j, k) = f(i - 1, j - 1, k) + (j + k) \cdot f(i - 1, j, k).

This holds because there are 3 ways to extend a shorter array into one counted by f(i, j, k):

  • Add one of the fixed values that has already appeared. This can be done in j ways and comes from state f(i - 1, j, k).
  • Add a new fixed value. This comes from state f(i - 1, j - 1, k) and can be done in only 1 way, because the recurrence fixes the order in which the fixed values first appear.
  • Add a non-fixed value. This can be done in k ways and comes from state f(i - 1, j, k).

Because of that fixed order, the recurrence does not yet count what we want: in an actual array the j fixed values can make their first appearances in any of j! orders. Therefore, after evaluating the recurrence, we multiply f(i, j, k) by j!.
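A direct implementation of the recurrence followed by the j! correction might look like the sketch below. It mirrors the `init()` routine of the editorialist's solution but stores the table in a `std::vector` and adds a brute-force counter for small arguments; `Table`, `buildF` and `bruteF` are hypothetical names, and the bounds `maxI`, `maxJ`, `maxK` are assumed small.

```cpp
#include <cassert>
#include <vector>

using Table = std::vector<std::vector<std::vector<long long>>>;

// F[i][j][k]: arrays of length i in which each of j fixed values appears at least
// once and every other entry is one of k free values, modulo mod.
Table buildF(int maxI, int maxJ, int maxK, long long mod) {
    Table F(maxI + 1, std::vector<std::vector<long long>>(maxJ + 1, std::vector<long long>(maxK + 1, 0)));
    for (int k = 0; k <= maxK; ++k) {
        F[0][0][k] = 1 % mod;  // the empty array
        for (int i = 1; i <= maxI; ++i) {
            for (int j = 0; j <= maxJ; ++j) {
                // new fixed value: 1 way, as first appearances follow a fixed order
                long long fromNew = j >= 1 ? F[i - 1][j - 1][k] : 0;
                // repeat one of the j fixed values or use one of the k free values
                F[i][j][k] = (fromNew + F[i - 1][j][k] * ((j + k) % mod)) % mod;
            }
        }
    }
    // multiply by j! for the j! possible orders of first appearance
    long long fact = 1 % mod;
    for (int j = 0; j <= maxJ; ++j) {
        if (j > 0) fact = fact * j % mod;
        for (int i = 0; i <= maxI; ++i)
            for (int k = 0; k <= maxK; ++k)
                F[i][j][k] = F[i][j][k] * fact % mod;
    }
    return F;
}

// Brute force for tiny arguments: values 0..j-1 are fixed, j..j+k-1 are free
long long bruteF(int i, int j, int k) {
    int base = j + k;
    long long total = 1;  // base^i candidate arrays
    for (int t = 0; t < i; ++t) total *= base;
    long long good = 0;
    for (long long code = 0; code < total; ++code) {  // when base == 0 and i > 0, total is 0
        std::vector<bool> seen(j, false);
        long long c = code;
        for (int t = 0; t < i; ++t) {
            int v = static_cast<int>(c % base);
            c /= base;
            if (v < j) seen[v] = true;
        }
        bool ok = true;
        for (bool s : seen) ok = ok && s;
        if (ok) ++good;
    }
    return good;
}
```

For example, `F[2][2][0]` is 2 (the arrays xy and yx for fixed values x, y), whereas the recurrence alone gives 1 before the 2! factor; `F[2][1][1]` is 3 (xx, xy and yx, with y the free value).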

TIME COMPLEXITY:

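Time complexity is O(M \cdot N \cdot \min(M, N)): there are O(N \cdot M \cdot \min(M, N)) triples (K, MEX, |S_1|) with |S_1| \le \min(MEX, K - 1), each handled in O(1), and the needed entries of f are bounded similarly.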

SOLUTION:

Setter's Solution
#include <bits/stdc++.h>
using namespace std;
int32_t MOD;
struct modint {
    int32_t value;
    modint() = default;
    modint(int32_t value_) : value((value_ < MOD ? value_ : value_% MOD)) {}
    modint(int64_t value_) : value((value_ < MOD ? value_ : value_ % MOD)) {}
    inline modint operator + (modint other) const { int32_t c = this->value + other.value; return modint(c >= MOD ? c - MOD : c); }
    inline modint operator - (modint other) const { int32_t c = this->value - other.value; return modint(c < 0 ? c + MOD : c); }
    inline modint operator * (modint other) const { int32_t c = (int64_t)this->value * other.value % MOD; return modint(c < 0 ? c + MOD : c); }
    inline modint& operator += (modint other) { this->value += other.value; if (this->value >= MOD) this->value -= MOD; return *this; }
    inline modint& operator -= (modint other) { this->value -= other.value; if (this->value < 0) this->value += MOD; return *this; }
    inline modint& operator *= (modint other) { this->value = (int64_t)this->value * other.value % MOD; if (this->value < 0) this->value += MOD; return *this; }
    inline modint operator - () const { return modint(this->value ? MOD - this->value : 0); }
    modint pow(int32_t k) const { modint x = *this, y = 1; for (; k; k >>= 1) { if (k & 1) y *= x; x *= x; } return y; }
    modint inv() const { return pow(MOD - 2); }  // MOD must be a prime
    inline modint operator /  (modint other) const { return *this * other.inv(); }
    inline modint operator /= (modint other) { return *this *= other.inv(); }
    inline bool operator == (modint other) const { return value == other.value; }
    inline bool operator != (modint other) const { return value != other.value; }
    inline bool operator < (modint other) const { return value < other.value; }
    inline bool operator > (modint other) const { return value > other.value; }
};
modint operator * (int64_t value, modint n) { return modint(value) * n; }
modint operator * (int32_t value, modint n) { return modint(value) * n; }
istream& operator >> (istream& in, modint& n) { return in >> n.value; }
ostream& operator << (ostream& out, modint n) { return out << n.value; }
vector<vector<modint>> pas;
vector<modint> fact;
vector<vector<vector<modint>>> dp;
modint combs(int n, int k) {
    if (n < k || k < 0) {
        return 0;
    }
    return pas[n][k];
}
// solve(i, j, k) = f(i, j, k) before the j! factor (fixed values first appear in a fixed order)
modint solve(int i, int j, int k) {
    if (i < 0 || j < 0 || k < 0) {
        return 0;
    }
    if (i == 0) {
        return (j == 0);
    }
    if (dp[i][j][k] != -1) {
        return dp[i][j][k];
    }
    // repeat a fixed value (j ways) + new fixed value (1 way) + non-fixed value (k ways)
    dp[i][j][k] = solve(i - 1, j, k) * j + solve(i - 1, j - 1, k) + solve(i - 1, j, k) * k;
    return dp[i][j][k];
}
// count(i, j, k) = f(i, j, k) including the j! factor
modint count(int i, int j, int k) {
    if (i < 0 || j < 0 || k < 0) {
        return 0;
    }
    return solve(i, j, k) * fact[j];
}
int main() {
    std::ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    int n, m;
    cin >> n >> m >> MOD;
    ++m;
    pas = vector<vector<modint>>(m + 1, vector<modint>(m + 1));
    fact = vector<modint>(m + 1);
    fact[0] = 1;
    for (int i = 1; i <= m; ++i) {
        fact[i] = fact[i - 1] * i;
    }
    pas[0][0] = 1;
    for (int i = 1; i <= m; ++i) {
        for (int j = 0; j <= i; ++j) {
            if (j == i || j == 0) {
                pas[i][j] = 1;
            }
            else {
                pas[i][j] = pas[i - 1][j] + pas[i - 1][j - 1];
            }
        }
    }
    dp = vector<vector<vector<modint>>>(n + 1, vector<vector<modint>>(m + 1, vector<modint>(m + 1, -1)));
    vector<modint> ans(n + 1);
    for (int i = 0; i <= m; ++i) {                     // i = MEX = mex(A); m was incremented above, so m = M + 1
        for (int j = 1; j <= n; ++j) {                 // j = K
            for (int k = 0; k <= min(i, j - 1); ++k) { // k = |S_1|
                modint rep = 0;
                if (i >= 2) {
                    if (i == m) {
                        rep += (i * (i - 1) / 2) * count(j - 1, k, m - i) * count(n - j, i - k - 1, m - (i - k - 1)) * count(n - j, i - k - 1, m - (i - k - 1)) * combs(i - 2, k);
                        rep += (i * (i - 1) / 2) * count(j - 1, k, m - i) * count(n - j, i - k, m - (i - k)) * count(n - j, i - k - 1, m - (i - k - 1)) * combs(i - 2, k - 1);
                        rep += (i * (i - 1) / 2) * count(j - 1, k, m - i) * count(n - j, i - k - 1, m - (i - k - 1)) * count(n - j, i - k, m - (i - k)) * combs(i - 2, k - 1);
                        rep += (i * (i - 1) / 2) * count(j - 1, k, m - i) * count(n - j, i - k, m - (i - k)) * count(n - j, i - k, m - (i - k)) * combs(i - 2, k - 2);
                    }
                    else {
                        rep += (i * (i - 1) / 2) * count(j - 1, k, m - i - 1) * count(n - j, i - k - 1, m - (i - k - 1) - 1) * count(n - j, i - k - 1, m - (i - k - 1)) * combs(i - 2, k);
                        rep += (i * (i - 1) / 2) * count(j - 1, k, m - i - 1) * count(n - j, i - k, m - (i - k) - 1) * count(n - j, i - k - 1, m - (i - k - 1)) * combs(i - 2, k - 1);
                        rep += (i * (i - 1) / 2) * count(j - 1, k, m - i - 1) * count(n - j, i - k - 1, m - (i - k - 1) - 1) * count(n - j, i - k, m - (i - k)) * combs(i - 2, k - 1);
                        rep += (i * (i - 1) / 2) * count(j - 1, k, m - i - 1) * count(n - j, i - k, m - (i - k) - 1) * count(n - j, i - k, m - (i - k)) * combs(i - 2, k - 2);
                    }
                }
                if (i >= 1 && i < m) {
                    rep += (i * (m - i)) * count(j - 1, k, m - i - 1) * count(n - j, i - k - 1, m - 1 - (i - k - 1)) * count(n - j, i - k, m - (i - k)) * combs(i - 1, k);
                    rep += (i * (m - i)) * count(j - 1, k, m - i - 1) * count(n - j, i - k, m - 1 - (i - k)) * count(n - j, i - k, m - (i - k)) * combs(i - 1, k - 1);
                }
                if (m - i - 2 > 0) {
                    rep += ((m - i - 2) * (m - i - 1) / 2) * count(j - 1, k, m - 1 - i) * count(n - j, i - k, m - 1 - (i - k)) * count(n - j, i - k, m - (i - k)) * combs(i, k);
                }
                ans[j] += rep;
            }
        }
    }
    for (int i = 1; i <= n; ++i) {
        cout << ans[i] << ' ';
    }
}
Editorialist's Solution
#include <bits/stdc++.h>
using namespace std;

const int N = 205;

int n, m, mod;
long long f[N][N][N], fct[N], c[N][N];

void init() {
    fct[0] = 1;
    for (int i = 1; i < N; i++) {
        fct[i] = fct[i - 1] * i % mod;
    }
    for (int i = 0; i < N; i++) {
        c[i][0] = 1;
        for (int j = 1; j <= i; j++) {
            c[i][j] = (c[i - 1][j] + c[i - 1][j - 1]) % mod;
        }
    }
    for (int oth = 0; oth <= m + 1; oth++) {
        f[0][0][oth] = 1;
        for (int i = 1; i <= n; i++) {
            for (int j = 0; j <= i; j++) {
                f[i][j][oth] = ((j >= 1 ? f[i - 1][j - 1][oth] : 0) + f[i - 1][j][oth] * (j + oth)) % mod;
            }
        }
    }
    for (int oth = 0; oth <= m + 1; oth++) {
        for (int i = 0; i <= n; i++) {
            for (int j = 0; j <= i; j++) {
                (f[i][j][oth] *= fct[j]) %= mod;
            }
        }
    }
}

long long C(int n, int k) {
    return k < 0 || k > n ? 0 : c[n][k];
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m >> mod;
    init();
    for (int i = 1; i <= n; i++) {
        long long cur = 0;
        for (int mex = 0; mex <= m + 1; mex++) { // mex of A
            for (int sa = 0; sa <= mex; sa++) {  // # of distinct values < mex appearing in the common prefix a[1..i-1] (b[1..i-1])
                // S_1 = set of distinct values < mex appearing before i. |S_1| = sa
                int s1 = sa;
                // S_2 = set of distinct values < mex not appearing before i. |S_2| = mex - sa
                int s2 = mex - sa;
                // S_3 = set of values > mex. |S_3| = max(0, m - mex)
                int s3 = max(0, m - mex);
                // S_4 = set of values >= mex. |S_4| = m - mex + 1
                int s4 = m - mex + 1;
                
                // a[i] in S_1, b[i] in S_1
                if (s1 >= 2) {
                    cur += 1LL * s1 * (s1 - 1) / 2 % mod // choose a[i], b[i] from s1
                        * c[mex][s1] % mod // pick s1 values from mex
                        * f[i - 1][s1][s3] % mod // common prefix
                        * f[n - i][s2][s1 + s3] % mod // choose a[i+1..n]
                        * f[n - i][s2][s1 + s4] % mod; // choose b[i+1..n]
                }

                // a[i] in S_1, b[i] in S_2
                if (s1 >= 1 && s2 >= 1) {
                    cur += 1LL * mex * (mex - 1) / 2 % mod // choose a[i], b[i] from mex
                        * c[mex - 2][s1 - 1] % mod // pick remaining s1 - 1 values from mex - 2
                        * f[i - 1][s1][s3] % mod // common prefix
                        * f[n - i][s2][s1 + s3] % mod // choose a[i+1..n]
                        * f[n - i][s2 - 1][s1 + 1 + s4] % mod; // choose b[i+1..n]
                }

                // a[i] in S_1, b[i] in S_4
                if (s1 >= 1 && s4 >= 1) {
                    cur += 1LL * s1 * s4 % mod // choose a[i] from s1, b[i] from s4
                        * c[mex][s1] % mod // pick s1 values from mex
                        * f[i - 1][s1][s3] % mod // common prefix
                        * f[n - i][s2][s1 + s3] % mod // choose a[i+1..n]
                        * f[n - i][s2][s1 + s4] % mod; // choose b[i+1..n]
                }

                // a[i] in S_2, b[i] in S_1
                if (s2 >= 1 && s1 >= 1) {
                    cur += 1LL * mex * (mex - 1) / 2 % mod // choose a[i], b[i] from mex
                        * c[mex - 2][s1 - 1] % mod // pick remaining s1 - 1 values from mex - 2
                        * f[i - 1][s1][s3] % mod // common prefix
                        * f[n - i][s2 - 1][s1 + 1 + s3] % mod // choose a[i+1..n]
                        * f[n - i][s2][s1 + s4] % mod; // choose b[i+1..n]
                }

                // a[i] in S_2, b[i] in S_2
                if (s2 >= 2) {
                    cur += 1LL * s2 * (s2 - 1) / 2 % mod // choose a[i], b[i] from s2
                        * c[mex][s1] % mod // pick s1 values from mex
                        * f[i - 1][s1][s3] % mod // common prefix
                        * f[n - i][s2 - 1][s1 + 1 + s3] % mod // choose a[i+1..n]
                        * f[n - i][s2 - 1][s1 + 1 + s4] % mod; // choose b[i+1..n]
                }

                // a[i] in S_2, b[i] in S_4
                if (s2 >= 1 && s4 >= 1) {
                    cur += 1LL * s2 * s4 % mod // choose a[i] from s2, b[i] from s4
                        * c[mex][s1] % mod // pick s1 values from mex
                        * f[i - 1][s1][s3] % mod // common prefix
                        * f[n - i][s2 - 1][s1 + 1 + s3] % mod // choose a[i+1..n]
                        * f[n - i][s2][s1 + s4] % mod; // choose b[i+1..n]
                }

                // a[i] in S_3, b[i] in S_3
                if (s3 >= 2) {
                    cur += 1LL * s3 * (s3 - 1) / 2 % mod // choose a[i], b[i] from s3
                        * c[mex][s1] % mod // pick s1 values from mex
                        * f[i - 1][s1][s3] % mod // common prefix
                        * f[n - i][s2][s1 + s3] % mod // choose a[i+1..n]
                        * f[n - i][s2][s1 + s4] % mod; // choose b[i+1..n]
                }

                cur %= mod;
            }
        }
        cout << cur << " ";
    }
}