PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: kingmessi
Tester: pols_agyi_pols
Editorialist: iceknight1093
DIFFICULTY:
Easy
PREREQUISITES:
Elementary combinatorics
PROBLEM:
You’re given two integers N and M.
For an array A of length N, define f(A) as follows:
- You must do the following N-1 times:
  - Choose two elements X, Y \in A and remove them from A.
  - Add \max(X, Y) to your score.
  - Append X \ \& \ Y to A.
- f(A) is the maximum possible final score.
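For checking small cases by hand, here is a brute-force sketch of f(A) that tries every possible sequence of operations. It assumes, as the explanation does, that the two chosen elements are removed before their AND is appended. The helper name `f_bruteforce` is hypothetical and not part of the editorial. It runs in exponential time, so it is only meant for arrays of length about 5 or less.

```python
from functools import lru_cache


def f_bruteforce(arr):
    """Maximum final score, found by trying every sequence of operations.

    Hypothetical helper for tiny inputs only (exponential time).
    """

    @lru_cache(maxsize=None)
    def best(state):
        # state is a sorted tuple, so equal multisets share one cache entry
        if len(state) <= 1:
            return 0
        result = 0
        n = len(state)
        for i in range(n):
            for j in range(i + 1, n):
                x, y = state[i], state[j]
                # remove positions i and j, then append x & y
                rest = state[:i] + state[i + 1:j] + state[j + 1:]
                nxt = tuple(sorted(rest + (x & y,)))
                result = max(result, max(x, y) + best(nxt))
        return result

    return best(tuple(sorted(arr)))


print(f_bruteforce([1, 2, 3]))  # 5
```

Every order of operations on [1, 2, 3] ends with a score of 5 here, which matches 1 + 2 + 3 minus the minimum.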
Compute the sum of f(A) across all integer arrays of length N with elements in [1, M].
EXPLANATION:
Let’s try to find f(A) for a given array A.
Since the order of elements doesn’t matter, let’s sort A in ascending order, so that
A_1 \le A_2 \le\ldots\le A_N
Each operation removes two elements from A, adds the larger one to our score, and then inserts their bitwise AND back into the array.
Notably, the bitwise AND of two numbers cannot exceed the smaller of the numbers.
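Since we want to maximize the score, the optimistic maximum is then simply
A_2 + A_3 + \ldots + A_N,
i.e. every element other than A_1 becomes the maximum element of some operation.
Nothing larger is possible: view the process as repeatedly merging groups of original elements. The value left over from a group never exceeds that group's minimum, so merging two groups with minima p and q adds at most \max(p, q) to the score. By induction, the score collected within any group is at most its sum minus its minimum, which for the whole array is exactly A_2 + \ldots + A_N.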
It turns out that this is indeed achievable!
A simple construction is to always perform the operation with the smallest two remaining values.
So,
- Operate on A_1 and A_2. This adds A_2 to the score and inserts A_1 \& A_2 into the array. Notably, the inserted value is not larger than A_1; denote it by X_1.
- Next, operate on A_3 and X_1, giving a score of A_3. The newly inserted element (call it X_2) cannot exceed X_1, which in turn means it doesn’t exceed A_1 either.
- Next, operate on X_2 and A_4 for a score of A_4, again inserting an element that doesn’t exceed A_1, and so on.

This way, we’re able to turn every element from A_2 to A_N into a maximum exactly once, which, as noted earlier, is the best we can do.
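The construction above can be simulated directly with a min-heap: always pop the two smallest values, score the larger, and push their AND. This is a sketch for illustration; the name `f_greedy` is hypothetical and the heap is just one convenient way to find the two smallest values.

```python
import heapq


def f_greedy(arr):
    """Simulate the construction: always operate on the two smallest values."""
    heap = list(arr)
    heapq.heapify(heap)
    score = 0
    while len(heap) > 1:
        x = heapq.heappop(heap)  # smallest remaining value
        y = heapq.heappop(heap)  # second smallest remaining value
        score += max(x, y)       # this is y, since x <= y
        # x & y <= x, so the new value is <= every remaining element
        heapq.heappush(heap, x & y)
    return score


a = [6, 3, 5, 3]
print(f_greedy(a), sum(a) - min(a))  # 14 14
```

Because the pushed value never exceeds the current minimum, each later operation scores the next original element in sorted order, which is why the result equals sum(A) - min(A).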
Let’s use the above observation to solve the problem at hand.
Observe that one way to write f(A) is f(A) = \text{sum}(A) - \min(A), since we’re adding up everything except the minimum element.
To sum this across all arrays A, we can separately sum \text{sum}(A) across all arrays, and then subtract the sum of \min(A) across all arrays.
Consider computing the sum of \text{sum}(A) across all arrays A.
Let’s look at a single index i, and place an element X there.
This element then contributes X to the sum of each of the M^{N-1} arrays that have X at index i, since the remaining N-1 positions can take any value at all.
So, with a fixed i, we can sum this up across all X to obtain an overall contribution of
\sum_{X=1}^{M} X\cdot M^{N-1} = \frac{M(M+1)}{2}\cdot M^{N-1}
from elements in index i alone.
We then multiply this value by N to account for all indices of the array, to obtain an overall value of
N\cdot\frac{M(M+1)}{2}\cdot M^{N-1}
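A quick way to convince yourself of this count is to compare it with direct enumeration for small N and M. The function names below are hypothetical, and the sketch uses exact Python integers with no modulus.

```python
from itertools import product


def sum_of_sums_formula(n, m):
    # N * (1 + 2 + ... + M) * M^(N-1), exact integers (no modulus)
    return n * (m * (m + 1) // 2) * m ** (n - 1)


def sum_of_sums_bruteforce(n, m):
    # enumerate all M^N arrays with entries in [1, M]
    return sum(sum(a) for a in product(range(1, m + 1), repeat=n))


print(sum_of_sums_formula(2, 2), sum_of_sums_bruteforce(2, 2))  # 12 12
```

For N = 2, M = 2 the four arrays (1,1), (1,2), (2,1), (2,2) have sums 2, 3, 3, 4, totalling 12.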
Next, we look at computing the sum of \min(A) across all arrays; which we need to subtract out.
Let’s fix the value of the minimum to be Y, and try to count the number of arrays whose minimum element is Y.
Every element of the array must be \ge Y, meaning it must be in [Y, M] for (M-Y+1) choices per index.
Naturally, that corresponds to (M-Y+1)^N arrays.
However, this is the number of arrays whose minimum element is at least Y, not exactly Y.
So, we need to remove the arrays whose minimum element is \gt Y.
This corresponds to all arrays whose elements all lie in [Y+1, M], of which, by the same reasoning, there are (M-Y)^N.
Hence, the contribution of Y is Y\cdot \left((M-Y+1)^N - (M-Y)^N\right).
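The "at least Y minus more than Y" count can be checked against enumeration as well. This is an illustrative sketch with hypothetical function names; it assumes N \ge 1 so that 0^N = 0 when Y = M.

```python
from itertools import product


def count_min_exactly(n, m, y):
    # arrays with all entries in [y, m], minus those with all entries in [y+1, m]
    return (m - y + 1) ** n - (m - y) ** n


def count_min_exactly_bruteforce(n, m, y):
    # enumerate all arrays and count those whose minimum is exactly y
    return sum(1 for a in product(range(1, m + 1), repeat=n) if min(a) == y)


print(count_min_exactly(2, 3, 2))  # 3: (2,2), (2,3), (3,2)
```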
We need to sum this up across all Y, i.e.
\sum_{Y=1}^{M} Y\cdot\left((M-Y+1)^N - (M-Y)^N\right)
This sum also telescopes to just being
\sum_{Y=1}^{M} Y^N
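One way to see the telescoping is summation by parts. Writing c_k = k^N (so c_0 = 0, assuming N \ge 1) and substituting k = M - Y + 1:

$$
\begin{aligned}
\sum_{Y=1}^{M} Y\left((M-Y+1)^N - (M-Y)^N\right)
&= \sum_{k=1}^{M} (M-k+1)\,(c_k - c_{k-1}) \\
&= \sum_{k=1}^{M} (M-k+1)\,c_k - \sum_{k=0}^{M-1} (M-k)\,c_k \\
&= c_M + \sum_{k=1}^{M-1} \big((M-k+1) - (M-k)\big)\,c_k - M\,c_0 \\
&= \sum_{k=1}^{M} k^N.
\end{aligned}
$$

Every c_k with 1 \le k \le M ends up with coefficient exactly 1, which is the telescoped form.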
This is easily computed in \mathcal{O}(M\log N) time with binary exponentiation, which is fast enough for the given constraints since the sum of M across tests is bounded.
TIME COMPLEXITY:
\mathcal{O}(M\log N) per testcase.
CODE:
Editorialist's code (PyPy3)
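```python
mod = 998244353

for _ in range(int(input())):
    n, m = map(int, input().split())
    # sum of sum(A) over all arrays: N * M(M+1)/2 * M^(N-1)
    ans = (n * m * (m + 1) // 2) % mod
    ans = (ans * pow(m, n - 1, mod)) % mod
    # subtract sum of min(A) over all arrays: 1^N + 2^N + ... + M^N
    for i in range(1, m + 1):
        ans -= pow(i, n, mod)
    print(ans % mod)
```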
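To gain extra confidence in the counting, a small stress test can compare the closed form against direct enumeration. This is a hypothetical sketch, not part of the editorialist's solution: it takes f(A) = \text{sum}(A) - \min(A) as established in the explanation, and it uses exact Python integers without the modulus, so it checks the formula rather than the modular code itself.

```python
from itertools import product


def closed_form(n, m):
    # N * M^(N-1) * M(M+1)/2  minus  sum_{Y=1}^{M} Y^N, exact integers
    return n * m ** (n - 1) * (m * (m + 1) // 2) - sum(y ** n for y in range(1, m + 1))


def bruteforce(n, m):
    # uses f(A) = sum(A) - min(A) for every array with entries in [1, M]
    return sum(sum(a) - min(a) for a in product(range(1, m + 1), repeat=n))


for n in range(1, 5):
    for m in range(1, 5):
        assert closed_form(n, m) == bruteforce(n, m), (n, m)
print("closed form matches brute force for N, M <= 4")
```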