PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: raysh07
Tester: iceknight1093
Editorialist: iceknight1093
DIFFICULTY:
Easy
PREREQUISITES:
Combinatorics, binary exponentiation
PROBLEM:
For an array A, define f(A) to be the minimum possible value of (i-j) across all pairs of indices such that j \lt i and A_i = A_j.
If no such pair exists, f(A) = 0 instead.
Given N and K, compute the sum of f(A) across all arrays A of length N with elements in [1, K].
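For small inputs, the definition can be checked directly by enumeration. The following is a brute-force sketch only (it visits all K^N arrays, so it is useful purely for testing); the names `f` and `brute_force` are illustrative and not taken from the original solution.

```python
from itertools import product

def f(a):
    """Minimum i - j over pairs j < i with a[i] == a[j]; 0 if no such pair."""
    best = 0
    last = {}  # value -> most recent index at which it appeared
    for i, v in enumerate(a):
        if v in last:
            gap = i - last[v]  # the closest equal pair is always consecutive occurrences
            if best == 0 or gap < best:
                best = gap
        last[v] = i
    return best

def brute_force(n, k):
    """Sum of f(A) over all k**n arrays of length n with values in [1, k]."""
    return sum(f(a) for a in product(range(1, k + 1), repeat=n))
```

For example, `brute_force(3, 2)` evaluates to 10: of the eight arrays, six have f(A) = 1 and two (1 2 1 and 2 1 2) have f(A) = 2.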
EXPLANATION:
Let c_x denote the number of arrays A such that f(A) = x.
Assuming we’re able to compute all the values of c_x, the answer we want is then
\sum_{x=1}^{N-1} x\cdot c_x
Let’s now attempt to compute c_x for a fixed x.
For an array to have an answer of x, two conditions must hold:
- For any pair (i, j) such that j \lt i and A_i = A_j, we must have i - j \ge x.
This is obvious: if (i - j) were any smaller, the answer would be \lt x.
- Second, there must exist some (i, j) such that A_i = A_j and i - j = x.
If no such pair exists, the answer for the array will be strictly larger than x instead (or 0, if the array has no equal pair at all).
Let’s focus on only the first constraint for now - ensuring that all pairs of equal elements have a difference of at least x in their indices.
We’ll try to build the array from index 1 to index N while satisfying this constraint.
- At index 1, we can choose any element.
Since elements are in [1, K], there are K choices.
- At index 2, we can choose any element other than A_1.
There are K-1 choices.
- At index 3, we can choose any element other than A_1 or A_2.
Since A_1 \ne A_2, there are K-2 choices.
- In general, note that for each i \le x, when choosing A_i we must ensure it is distinct from all of A_1, A_2, \ldots, A_{i-1}, hence giving us K - (i-1) choices.
This takes care of all the indices till x.
- Next, let’s look at index x+1.
The value here cannot be equal to any of A_x, A_{x-1}, \ldots, A_2.
However, there’s no issue if it equals A_1 (though it is not forced to be equal to it).
Since all of A_x, \ldots, A_2 are themselves distinct, we thus have x-1 forbidden values for A_{x+1}, and hence K - (x-1) choices for it.
- It’s not hard to see that the exact same logic applies to index x+2 as well - that is, it has K - (x-1) choices, since it cannot equal any of A_{x+1}, \ldots, A_3 but anything else is ok.
- In fact, this reasoning applies to any index i \gt x, i.e. for each such index, there are K - (x-1) valid choices of A_i, after A_1, \ldots, A_{i-1} have been chosen.
There are N - x indices after x, and each of them has K - (x-1) choices.
So, that’s (K - (x-1))^{N-x} choices for the entire array.
As for the first x indices, we’ve already seen that the number of choices is
K\cdot (K-1)\cdot (K-2)\cdot\ldots\cdot (K-(x-1))
Thus, the number of arrays where every pair of equal elements is at least x apart equals
K\cdot (K-1)\cdot\ldots\cdot (K-(x-1))\cdot (K-(x-1))^{N-x}
Let’s call this value d_x for simplicity.
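As a quick sanity check of this count, here is a minimal sketch that evaluates it with exact integers: the falling product for the first x indices, times K - (x-1) raised to the power N - x for the remaining ones. The function name `d` and its parameter order are assumptions made for illustration.

```python
def d(n, k, x):
    """Arrays of length n over [1, k] in which equal elements are at least x apart.

    Assumes 1 <= x <= n. Uses exact integers, so it is only meant for small inputs.
    """
    falling = 1
    for i in range(x):  # K * (K-1) * ... * (K-(x-1))
        falling *= k - i
    # Each of the remaining n - x indices has K - (x-1) valid values.
    return falling * (k - (x - 1)) ** (n - x)
```

With N = 4 and K = 3, this gives d(4, 3, 2) = 24 (adjacent elements differ: 3 * 2^3) and d(4, 3, 3) = 6 (the first three elements form a permutation and the fourth must repeat A_1). With N = 3 and K = 2, d(3, 2, 3) = 0, because the falling product contains the factor K - 2 = 0.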
One thing to note is that when performing the above computations, we implicitly assumed that we always had elements to work with (i.e., that K - (x-1) \gt 0), since otherwise the combinatorics technically doesn’t make sense (you can’t have -1 ways to choose an element, when speaking in natural language).
However, we luckily don’t need any special case handling for impossible situations, since the product K\cdot (K-1)\cdot (K-2)\cdot\ldots\cdot (K-(x-1)) will just end up being 0 anyway when K \lt x.
Computing all the values of d_x can be done fairly easily in \mathcal{O}(N\log N) time.
d_x essentially has two parts: one is a falling product of elements from K, the other is (K - (x-1)) raised to the power of N-x.
The falling product can be simply maintained as you iterate x, and updated in constant time (alternately, note that it’s one factorial divided by another, though this interpretation will need special handling for the value becoming 0.)
The power of (K - (x-1)) can be computed in \mathcal{O}(\log N) time with binary exponentiation.
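The two parts can be combined in a single loop over x. The sketch below is one possible way to do it, not the reference implementation: the helper name `all_d` is hypothetical, and the modulus 998244353 is the one that appears in the editorialist's code.

```python
MOD = 998244353  # modulus taken from the editorialist's code

def all_d(n, k):
    """Return [d_1, d_2, ..., d_n] modulo MOD in O(n log n) time."""
    result = []
    falling = 1  # running value of K * (K-1) * ... * (K-(x-1)) mod MOD
    for x in range(1, n + 1):
        base = k - (x - 1)
        falling = falling * base % MOD  # constant-time update of the falling product
        # Binary exponentiation via three-argument pow; once the falling
        # product has hit 0 the result stays 0 regardless of the power.
        result.append(falling * pow(base, n - x, MOD) % MOD)
    return result
```

For instance, `all_d(4, 3)` returns `[81, 24, 6, 0]`.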
Once all the d_x values are known, observe that we simply have c_x = d_x - d_{x+1}.
This is because we take all arrays where every equal pair has a distance of at least x, and then remove all those arrays for which every equal pair has a distance of at least (x+1) - this will leave exactly those arrays with all distances \ge x, as well as at least one distance being exactly x.
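This subtraction argument can be confirmed by enumeration on small inputs. The sketch below counts arrays by their value of f(A) and compares each count with d_x - d_{x+1}; all function names are illustrative, and the closed form for d_x is restated so that the snippet runs on its own.

```python
from collections import Counter
from itertools import product

def min_equal_gap(a):
    """f(A): smallest index gap between equal elements, 0 if none are equal."""
    gaps = [i - j for i in range(len(a)) for j in range(i) if a[i] == a[j]]
    return min(gaps, default=0)

def d(n, k, x):
    """Closed form for d_x, evaluated with exact integers (1 <= x <= n)."""
    falling = 1
    for i in range(x):
        falling *= k - i
    return falling * (k - (x - 1)) ** (n - x)

def check(n, k):
    """True if the number of arrays with f(A) = x equals d_x - d_{x+1} for all 1 <= x < n."""
    counts = Counter(min_equal_gap(a) for a in product(range(k), repeat=n))
    return all(counts[x] == d(n, k, x) - d(n, k, x + 1) for x in range(1, n))
```

Calling `check(n, k)` for a handful of small pairs, such as n up to 5 and k up to 3, is a cheap way to catch off-by-one mistakes in the formula.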
To further simplify the implementation, it can be observed with a bit of algebra that the c_x values don’t even need to be computed explicitly: instead, the answer is simply
(d_1 + d_2 + \ldots + d_N) - N\cdot d_N
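The algebra is a telescoping sum. Written out, using c_x = d_x - d_{x+1} and the fact that only x = 1, \ldots, N-1 contribute to the answer:

```latex
\begin{aligned}
\sum_{x=1}^{N-1} x \cdot c_x
  &= \sum_{x=1}^{N-1} x \cdot (d_x - d_{x+1}) \\
  &= \sum_{x=1}^{N-1} x \cdot d_x - \sum_{x=2}^{N} (x-1) \cdot d_x \\
  &= \sum_{x=1}^{N-1} d_x - (N-1) \cdot d_N \\
  &= \sum_{x=1}^{N} d_x - N \cdot d_N
\end{aligned}
```

Here d_N is the number of arrays with all elements distinct, which is the quantity the code below computes as `sub`.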
TIME COMPLEXITY:
\mathcal{O}(N\log N) per testcase.
CODE:
Editorialist's code (PyPy3)
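mod = 998244353
for _ in range(int(input())):
    n, k = map(int, input().split())
    # ans accumulates d_1 + d_2 + ... + d_n, starting with d_1 = k^n
    ans = pow(k, n, mod)
    init = k  # falling product k * (k-1) * ... * (k-(x-1))
    for x in range(2, n+1):
        init = (init * (k + 1 - x)) % mod
        ways = init
        if x <= k:
            ways = (ways * pow(k-x+1, n-x, mod)) % mod
        ans += ways
    # sub = d_n, the number of arrays with all elements distinct
    sub = 1
    for i in range(n):
        sub = (sub * (k - i)) % mod
    print((ans - n*sub) % mod)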