P8209 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: satyam_343
Tester: kingmessi
Editorialist: iceknight1093

DIFFICULTY:

Easy-Medium

PREREQUISITES:

Dynamic programming, combinatorics

PROBLEM:

You’re given N and M.
Count the number of integer arrays A satisfying the following conditions:

  • A has length N, and 1 \le A_i \le M
  • There exists an index i (1 \le i \lt N) such that f(A, 1, i) = f(A, i+1, N), where f(A, l, r) denotes the number of distinct elements in the array [A_l, \ldots, A_r].
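For testing any faster solution, a brute-force reference that checks the definition directly can be handy. The sketch below is not from the original editorial; the helper names `f` and `brute_force_count` are illustrative, and it enumerates all M^N arrays, so it is only usable for tiny inputs.

```python
from itertools import product

def f(a, l, r):
    """Number of distinct elements in a[l..r] (1-indexed, inclusive)."""
    return len(set(a[l - 1:r]))

def brute_force_count(n, m):
    """Count arrays of length n over [1, m] having some split index i
    (1 <= i < n) with f(A, 1, i) == f(A, i+1, n). Exponential time."""
    total = 0
    for a in product(range(1, m + 1), repeat=n):
        if any(f(a, 1, i) == f(a, i + 1, n) for i in range(1, n)):
            total += 1
    return total

print(brute_force_count(3, 2))  # 6
```

For N = 3 and M = 2, the only arrays without a valid split are [1, 2, 1] and [2, 1, 2], so 8 - 2 = 6 arrays are counted.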

EXPLANATION:

Let’s understand what it means for an array to have f(A, 1, i) = f(A, i+1, N) for some index i.

Let P_i = f(A, 1, i) denote the number of distinct elements in the prefix ending at i, and S_i = f(A, i, N) denote the number of distinct elements in the suffix starting at i.
We’re looking for an index that satisfies P_i = S_{i+1}.

Observe that for any i, we have P_i \le P_{i+1} \le P_i + 1, and S_{i+1} \le S_{i} \le S_{i+1} + 1.
This is because extending a prefix/suffix by one element can add at most 1 to the number of distinct elements.
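To make the definitions concrete, here is a small sketch (the name `prefix_suffix_distinct` is hypothetical) that computes all P_i and S_i in one pass each and checks the "changes by at most 1 per step" property on an example array.

```python
def prefix_suffix_distinct(a):
    """Return (P, S) as 1-indexed lists (index 0 unused, S[n+1] = 0):
    P[i] = distinct count of a[1..i], S[i] = distinct count of a[i..n]."""
    n = len(a)
    P = [0] * (n + 2)
    S = [0] * (n + 2)
    seen = set()
    for i in range(1, n + 1):
        seen.add(a[i - 1])
        P[i] = len(seen)
    seen = set()
    for i in range(n, 0, -1):
        seen.add(a[i - 1])
        S[i] = len(seen)
    return P, S

a = [3, 1, 3, 2, 1]
P, S = prefix_suffix_distinct(a)
print(P[1:6])  # [1, 2, 2, 3, 3]
print(S[1:6])  # [3, 3, 3, 2, 1]
# Each one-step extension changes the count by 0 or 1.
assert all(P[i + 1] - P[i] in (0, 1) for i in range(1, len(a)))
assert all(S[i] - S[i + 1] in (0, 1) for i in range(1, len(a)))
```

Here i = 3 gives P_3 = S_4 = 2, so this array satisfies the condition.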

Now, note that we have P_1 = S_N = 1, and P_N = S_1 = f(A, 1, N).
So, we start with P_1 \le S_2 and end with P_{N-1} \ge S_N.
As i moves from 1 to N-1, P_i never decreases and grows by at most 1 per step, while S_{i+1} never increases and shrinks by at most 1 per step.
Now suppose P_i \lt S_{i+1}, and consider what happens when moving to (P_{i+1}, S_{i+2}):

  • If neither value changes, the strict inequality is maintained.
  • If only P_{i+1} changes compared to P_i, or only S_{i+2} changes compared to S_{i+1}, the gap shrinks by exactly 1, so we still have P_{i+1} \le S_{i+2}.
    This means either we reach equality, or the strict inequality continues.
  • If both P_{i+1} and S_{i+2} change, then P_{i+1} = P_i + 1 and S_{i+2} = S_{i+1} - 1, so P_{i+1} \le S_{i+2} holds exactly when P_i + 2 \le S_{i+1}.
    The only exception is P_i + 1 = S_{i+1}, in which case after the changes we have P_{i+1} = S_{i+2} + 1.
    If this ever happens, no later index can give equality: P_j only increases and S_{j+1} only decreases as j grows, so the prefix count stays strictly above the suffix count from then on.
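The case analysis can be sanity-checked exhaustively for small sizes. The sketch below (helper names `distinct_counts` and `check_crossing_claim` are hypothetical, and it assumes N ≥ 2) verifies that every array with no equality index has exactly one step where P_i \lt S_{i+1} flips to P_{i+1} \gt S_{i+2}, and that this step is always the exceptional case.

```python
from itertools import product

def distinct_counts(a):
    """1-indexed prefix counts P and suffix counts S; S[n+1] = 0."""
    n = len(a)
    P, S = [0] * (n + 2), [0] * (n + 2)
    seen = set()
    for i in range(1, n + 1):
        seen.add(a[i - 1])
        P[i] = len(seen)
    seen = set()
    for i in range(n, 0, -1):
        seen.add(a[i - 1])
        S[i] = len(seen)
    return P, S

def check_crossing_claim(n, m):
    """For all arrays of length n >= 2 over m values: if no i has
    P_i == S_{i+1}, there is exactly one flip from '<' to '>', and it is
    the case P_i + 1 == S_{i+1}, P_{i+1} == S_{i+2} + 1."""
    assert n >= 2
    for a in product(range(m), repeat=n):
        P, S = distinct_counts(a)
        if any(P[i] == S[i + 1] for i in range(1, n)):
            continue  # a valid split exists; nothing to check
        flips = [i for i in range(1, n - 1)
                 if P[i] < S[i + 1] and P[i + 1] > S[i + 2]]
        assert len(flips) == 1
        i = flips[0]
        assert P[i] + 1 == S[i + 1] and P[i + 1] == S[i + 2] + 1
    return True

print(check_crossing_claim(5, 3))  # True
```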

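Since we start with P_1 \le S_2 and end with P_{N-1} \ge S_N, an array with no valid index must switch from P_i \lt S_{i+1} to P_{i+1} \gt S_{i+2} at some step, and by the case analysis above, that switch can only be the exceptional case.
Let’s try to understand when this “bad” case happens.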
We need P_{i+1} = P_i + 1, which means that A_{i+1} hasn’t appeared in the prefix [1, i].
We need S_{i+2} = S_{i+1}-1, which means that A_{i+1} doesn’t have any other appearances in [i+2, N].
So, index i+1 must be the sole appearance of A_{i+1} in the array.

Further, since P_i + 1 = S_{i+1} = S_{i+2} + 1, the number of distinct elements in [1, i] and [i+2, N] must be equal.

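Conversely, suppose some index j holds a value that appears nowhere else, and the parts [1, j-1] and [j+1, N] both contain exactly d distinct elements.
Then P_i \le d \lt d+1 \le S_{i+1} for every i \lt j, and P_i \ge d+1 \gt d \ge S_{i+1} for every j \le i \lt N, so the array has no valid index.
In particular, any array has at most one such “bad” position (it is exactly where the comparison flips from \lt to \gt), so the “bad” arrays are exactly those having one, and we can count them by conditioning on the bad position.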
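This characterization can also be checked by brute force. In the sketch below (helper names are hypothetical; values are 0..m-1 instead of 1..M, which does not change any counts), `bad_positions` finds every index whose value is unique and whose two sides have equal distinct counts, and the check confirms there is at most one such index, present exactly when no valid split exists.

```python
from itertools import product

def has_valid_split(a):
    """True if some 1 <= i < N gives f(A, 1, i) == f(A, i+1, N)."""
    return any(len(set(a[:i])) == len(set(a[i:])) for i in range(1, len(a)))

def bad_positions(a):
    """0-indexed positions j where a[j] occurs exactly once and the parts
    a[:j] and a[j+1:] contain the same number of distinct values."""
    return [j for j in range(len(a))
            if a.count(a[j]) == 1 and len(set(a[:j])) == len(set(a[j + 1:]))]

def check_characterization(n, m):
    """Exhaustively check: at most one bad position, and an array has one
    exactly when it has no valid split index."""
    for a in product(range(m), repeat=n):
        pos = bad_positions(a)
        assert len(pos) <= 1
        assert (len(pos) == 1) == (not has_valid_split(a))
    return True

print(bad_positions((1, 2, 1)))      # [1]
print(check_characterization(5, 3))  # True
```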

So, let’s fix the index i to be the “bad” position of the array A.
A_i must be unique in the array; there are M choices for what it can be.
Then, we need to fill in the other elements.
The prefix [1, i-1] and suffix [i+1, N] must have an equal number of distinct elements.
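These two parts occupy disjoint positions, and their values may overlap freely; the only constraint linking them is that their distinct counts agree. So we can fix that common count d, count the fillings of each part separately, and multiply.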

That is, let’s define g(k, d) to be the number of arrays of length k containing exactly d distinct elements from [1, M-1].
(We use M-1 and not M because after fixing A_i, only M-1 distinct elements are available to us.)

If i and the value of A_i are fixed, the number of ways of filling in the rest of the array will simply equal \displaystyle\sum_d g(i-1, d) \cdot g(N-i, d).
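A brute-force sketch of this step, under the assumption that values are relabeled as 0..m-1 (hypothetical helpers `g_brute`, `fillings_brute`, `fillings_formula`): it counts arrays whose bad position is i with A_i = v directly and compares against the sum over d. The direct count does not depend on v, which is where the factor M in the final formula comes from.

```python
from itertools import product

def g_brute(k, d, pool):
    """Length-k arrays over `pool` values with exactly d distinct values."""
    return sum(1 for b in product(range(pool), repeat=k) if len(set(b)) == d)

def fillings_brute(n, m, i, v):
    """Arrays over values 0..m-1 with A_i = v (1-indexed i) such that v occurs
    only at i and both sides of i have the same number of distinct values."""
    count = 0
    for a in product(range(m), repeat=n):
        if a[i - 1] != v or a.count(v) != 1:
            continue
        if len(set(a[:i - 1])) == len(set(a[i:])):
            count += 1
    return count

def fillings_formula(n, m, i):
    """sum_d g(i-1, d) * g(N-i, d); the sides draw from the other m-1 values.
    d = 0 is included here; it only contributes when n == 1."""
    return sum(g_brute(i - 1, d, m - 1) * g_brute(n - i, d, m - 1)
               for d in range(m))

print(fillings_brute(4, 3, 2, 0), fillings_formula(4, 3, 2))  # 4 4
```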


Let’s now focus on computing g(k, d).
This is quite easy using dynamic programming:

  • The k-th element might be an element that already appears in the prefix of length k-1.
    There are d choices for which element is chosen, so we have g(k-1, d) \cdot d.
  • The k-th element might be new.
    This means the prefix of length k-1 had d-1 distinct elements, and the new element can be any of the (M-1) - (d-1) = M-d unused values.
    So, we have g(k-1, d-1) \cdot (M-d).

Thus, we obtain g(k, d) = g(k-1, d)\cdot d + g(k-1, d-1)\cdot (M-d), with base case g(0, 0) = 1 and g(0, d) = 0 for d \gt 0.
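The recurrence translates directly into a table. This is a sketch (the name `g_table` and the optional `mod` parameter are illustrative), with d ranging over 0..M-1 since only M-1 values are available.

```python
def g_table(max_k, m, mod=None):
    """g[k][d] = number of length-k arrays using exactly d distinct values
    from a fixed pool of m-1 values, via
    g(k, d) = g(k-1, d) * d + g(k-1, d-1) * (m - d)."""
    g = [[0] * m for _ in range(max_k + 1)]  # d ranges over 0..m-1
    g[0][0] = 1                              # the empty array
    for k in range(1, max_k + 1):
        for d in range(m):
            val = g[k - 1][d] * d                    # repeat a used value
            if d >= 1:
                val += g[k - 1][d - 1] * (m - d)     # pick an unused value
            g[k][d] = val % mod if mod is not None else val
    return g

g = g_table(3, 3)  # pool of 2 values
print(g[3])        # [0, 2, 6]
```

As a quick check, each row sums to (M-1)^k, since every array over the pool is counted once under its own distinct count.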

Since N, M \le 5000, all the values of g(k, d) can be computed in \mathcal{O}(NM) time, which is fast enough.

Once that’s done, the total number of “bad” arrays is, as noted above,

\sum_{i=1}^N \sum_{d=1}^{M-1} M\cdot g(i-1, d) \cdot g(N-i, d)

which again takes \mathcal{O}(NM) time.
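As a small hand-checkable illustration (not part of the original editorial), take N = 3 and M = 2, where the two bad arrays are [1, 2, 1] and [2, 1, 2]. With M - 1 = 1 value available on each side, the formula gives:

```latex
% Worked example: N = 3, M = 2
g(0,1) = 0, \qquad g(1,1) = 1, \qquad g(2,1) = 1
\sum_{i=1}^{3} \sum_{d=1}^{1} 2\, g(i-1,d)\, g(3-i,d)
  = 2\big(g(0,1)\,g(2,1) + g(1,1)\,g(1,1) + g(2,1)\,g(0,1)\big)
  = 2\,(0 + 1 + 0) = 2
```

Subtracting from 2^3 = 8 leaves 6 valid arrays.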

To obtain the number of valid arrays, subtract this count from M^N, the total number of arrays.
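Putting it together, here is a function-style sketch cross-checked against brute force (the names `count_good` and `count_good_brute` are hypothetical). One deliberate deviation from the sum above: d runs from 0, which changes nothing for N \ge 2 but makes N = 1 return 0 (no split index exists at all).

```python
from itertools import product

MOD = 998244353

def count_good(n, m, mod=MOD):
    """Valid arrays = m^n minus bad arrays, where
    bad = sum_i sum_d m * g(i-1, d) * g(n-i, d)."""
    # g[k][d]: length-k arrays with exactly d distinct values from m-1 values
    g = [[0] * m for _ in range(n + 1)]
    g[0][0] = 1
    for k in range(1, n + 1):
        for d in range(m):
            val = g[k - 1][d] * d
            if d >= 1:
                val += g[k - 1][d - 1] * (m - d)
            g[k][d] = val % mod
    bad = 0
    for i in range(1, n + 1):       # position of the unique value
        for d in range(m):          # common distinct count of both sides
            bad = (bad + m * g[i - 1][d] * g[n - i][d]) % mod
    return (pow(m, n, mod) - bad) % mod

def count_good_brute(n, m):
    """Direct check of the definition; tiny inputs only."""
    return sum(
        1 for a in product(range(m), repeat=n)
        if any(len(set(a[:i])) == len(set(a[i:])) for i in range(1, n)))

print(count_good(3, 2), count_good(4, 2))  # 6 12
```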

TIME COMPLEXITY:

\mathcal{O}(NM) per testcase.

CODE:

Editorialist's code (PyPy3)
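```python
mod = 998244353

for _ in range(int(input())):
    n, m = map(int, input().split())

    # dp[i][j] = g(i, j): arrays of length i with exactly j distinct values,
    # drawn from the m-1 values other than the unique one
    dp = [ [0 for i in range(max(n, m)+1)] for j in range(n+1)]
    dp[0][0] = 1
    for i in range(n):
        for j in range(m):
            dp[i][j] %= mod

            # next element repeats one of the j values already used
            dp[i+1][j] += dp[i][j] * j % mod
            # next element is one of the (m-1-j) unused values
            dp[i+1][j+1] += dp[i][j] * (m-1-j) % mod
    
    # bad arrays: position i holds the unique value (m choices), and both
    # sides have the same number k of distinct values
    ans = 0
    for i in range(1, n+1):
        for k in range(1, n+1):
            ans += dp[i-1][k] * dp[n-i][k] * m % mod
    # valid arrays = all m^n arrays minus the bad ones
    print((pow(m, n, mod) - ans) % mod)
```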

Tough to understand, but a good LR DP-type question, I think.