PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: raysh07
Tester: iceknight1093
Editorialist: iceknight1093
DIFFICULTY:
Medium
PREREQUISITES:
Dynamic programming, knowledge of monotonic stack
PROBLEM:
You’re given a symmetric function f:[N]\times [N] \to [N], where [N] = \{1, 2, \ldots, N\}.
Define the score of an array to be f(M_1, M_2), where M_1 and M_2 are the two largest elements in the array (if the maximum occurs more than once, M_1 = M_2).
The value of an array is the maximum score across all its subarrays that have length at least 2.
Given N and the function f, compute the sum of values of all arrays of length N with elements in [1, N].
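A brute-force reference makes these definitions concrete. The sketch below is only for checking tiny cases, since it enumerates all N^N arrays. It assumes f is supplied as an N×N list of lists with f[x-1][y-1] = f(x, y), mirroring how the editorialist's code reads the input, and the function names are hypothetical. When the largest value occurs twice, sorting naturally gives M_1 = M_2.

```python
from itertools import product


def array_value(a, f):
    """Value of array a, straight from the definition (tiny inputs only).

    Assumes len(a) >= 2 and f[x - 1][y - 1] == f(x, y) for values x, y in [1, N].
    """
    n = len(a)
    best = 0  # every f value is >= 1, so 0 is a safe starting point
    for l in range(n):
        for r in range(l + 1, n):
            top = sorted(a[l:r + 1], reverse=True)  # M_1 = top[0], M_2 = top[1]
            best = max(best, f[top[0] - 1][top[1] - 1])
    return best


def brute_total(n, f):
    """Sum of values over all n**n arrays of length n with elements in [1, n]."""
    return sum(array_value(a, f) for a in product(range(1, n + 1), repeat=n))
```

For example, with f(x, y) = \max(x, y) the value of an array is just its maximum, so for N = 3 the total is 1 + 2·7 + 3·19 = 72.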
EXPLANATION:
The first step to solving this problem is to understand how to compute the value of any given array quickly.
An algorithm that runs in \mathcal{O}(N^3) is trivial: fix each subarray, compute its maximum and second maximum, and then look up the appropriate function value.
It’s also quite easy to optimize this to run in \mathcal{O}(N^2), for example by fixing one endpoint of the subarray and maintaining the two largest values as we extend the other endpoint, avoiding recomputation.
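A sketch of the \mathcal{O}(N^2) idea, under the same assumed layout of f (f[x-1][y-1] = f(x, y)); the name `array_value_quadratic` is hypothetical. Ties are handled by the `>=` comparison: a second copy of the current maximum becomes the second maximum.

```python
def array_value_quadratic(a, f):
    """O(N^2) value: fix the left end, extend the right end, track the two largest values.

    Assumes len(a) >= 2 and f[x - 1][y - 1] == f(x, y).
    """
    n = len(a)
    best = 0  # every f value is >= 1
    for l in range(n):
        m1, m2 = a[l], 0  # largest and second largest of a[l..r]; 0 means "none yet"
        for r in range(l + 1, n):
            x = a[r]
            if x >= m1:
                m1, m2 = x, m1  # a tie with the maximum makes M_2 = M_1
            elif x > m2:
                m2 = x
            best = max(best, f[m1 - 1][m2 - 1])
    return best
```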
However, it’s in fact possible to compute the value of an array in linear time!
To achieve this, observe that while there are around N^2 subarrays, there aren’t actually that many unique combinations of (maximum, second maximum) possible.
In particular, suppose we fix an index i, and try to look at all pairs where this element is the second maximum.
Then, the subarray certainly needs another element that’s \ge A_i.
So, let’s define L_i and R_i to be the indices of the nearest elements to the left/right respectively that are at least A_i.
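Then, the only possible pairs in which A_i can be the second maximum are exactly (A_{L_i}, A_i) and (A_{R_i}, A_i), for whichever of L_i and R_i exist.
This is because such a subarray must contain another element \ge A_i, hence L_i or R_i; and as soon as the subarray contains two other elements that are both \ge A_i, its two largest values are attained by those other elements, so A_i no longer needs to be treated as the second maximum.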
So, we see that there are at most 2N pairs of (maximum, second maximum) possible.
Finding these 2N pairs is fairly simple: we saw earlier that for an index i, only the indices L_i and R_i matter; and computing this next/previous greater-or-equal element is a classical problem, solved using a monotonic stack.
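Putting this observation together gives a linear-time value computation. In this sketch (same assumed layout of f, hypothetical names), two monotonic-stack passes compute L_i and R_i, with -1 standing for "no such index", and the value is the best f over the at most 2N candidate pairs.

```python
def array_value_linear(a, f):
    """O(N) value using L_i / R_i: nearest index on each side whose value is >= a[i]."""
    n = len(a)
    left = [-1] * n   # L_i, or -1 if it does not exist
    right = [-1] * n  # R_i, or -1 if it does not exist

    stack = []
    for i in range(n):  # left-to-right pass for L_i
        while stack and a[stack[-1]] < a[i]:
            stack.pop()
        left[i] = stack[-1] if stack else -1
        stack.append(i)

    stack = []
    for i in range(n - 1, -1, -1):  # right-to-left pass for R_i
        while stack and a[stack[-1]] < a[i]:
            stack.pop()
        right[i] = stack[-1] if stack else -1
        stack.append(i)

    best = 0
    for i in range(n):
        for j in (left[i], right[i]):
            if j != -1:
                best = max(best, f[a[j] - 1][a[i] - 1])
    return best
```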
We can use this idea to compute the sum of answers across all arrays.
Let’s try to build up the array by placing elements from left to right.
Suppose we’ve already placed the first i elements, and we’re now trying to place A_{i+1}.
To update the value, only the new pairs, i.e. those involving A_{i+1}, matter.
This means we only care about:
- The nearest element to the left of A_{i+1} that’s not smaller than it; and
- All those elements to the left of A_{i+1} for which A_{i+1} is the next not-smaller element.
To find these, let’s look back at what the monotonic stack algorithm does.
The general algorithm is as follows:
1. Let S be an empty stack of indices.
2. Iterate over the indices from left to right.
3. When processing index i, pop from the top of the stack every index j with A_j \le A_i.
4. Then, push i onto the stack.
The important parts here are steps 3 and 4, because:
- For each index j popped in step 3, we have R_j = i, i.e. index i is the next not-smaller element to the right of index j.
- Just before we push i onto the stack in step 4, S.\text{top} (if the stack is nonempty) is the previous strictly larger element of index i. This coincides with L_i unless A_{L_i} = A_i; in that case L_i was popped in step 3, so the pair (A_{L_i}, A_i) = (A_i, A_i) is already produced there.
So, the process of updating the stack is itself what gives us the requisite pairs of elements that need to be checked.
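Because the pairs fall out of the stack updates, a single left-to-right pass is enough. The sketch below (hypothetical names, same assumed layout of f) records every pair produced in steps 3 and 4 and takes the best score among them.

```python
def stack_pairs(a):
    """All (other value, a[i]) pairs generated while running the monotonic stack left to right."""
    pairs = []
    stack = []  # indices; their values are strictly decreasing from bottom to top
    for i, x in enumerate(a):
        while stack and a[stack[-1]] <= x:  # step 3
            j = stack.pop()                 # index i is the next not-smaller element of j
            pairs.append((a[j], x))
        if stack:                           # previous strictly larger element of i
            pairs.append((a[stack[-1]], x))
        stack.append(i)                     # step 4
    return pairs


def array_value_one_pass(a, f):
    """Value of a (len(a) >= 2), assuming f[x - 1][y - 1] == f(x, y)."""
    return max(f[p - 1][q - 1] for p, q in stack_pairs(a))
```

For A = [3, 1, 2], the pairs produced are (3, 1), (1, 2) and (3, 2), in line with the bound of at most 2N pairs.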
The above observation, combined with the fact that N is small, is what will lead us to a solution to the problem.
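Observe that after we’ve placed the first i elements and are attempting to place A_{i+1}, the only thing that matters (besides the current value) is the state of the monotonic stack.
Specifically, we only care about which values are present on the stack: since every element \le A_i is popped before A_i is pushed, the values on the stack are strictly decreasing from bottom to top, so the set of values determines the stack completely and fits in an N-bit mask.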
With that in mind, let’s define dp(i, mask, v) to be the number of ways to choose the first i elements such that:
- mask is a bitmask of the values currently on the stack; and
- the current value of the array (the maximum score over all subarrays of the prefix) is v.
Now, suppose we choose the element x to append to the array as the (i+1)-th element.
Then,
- All elements y \le x in mask are popped from the stack. Each such y gives us the pair (y, x), for which we set v \gets \max(v, f(y, x)).
- The smallest element z \gt x in mask (if one exists) also needs to be considered; we set v \gets \max(v, f(z, x)).
- Change mask appropriately: unset all bits \lt x, and then set bit x to denote x being pushed onto the stack.
- Let nmask be the updated stack mask, and nv be the new value of v after processing the newly obtained pairs.
We then transition to the state (i+1, nmask, nv), i.e. we add dp(i, mask, v) to dp(i+1, nmask, nv).
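A single transition of this DP can be sketched as follows. The encoding is an assumption for illustration: bit y-1 of the mask is set exactly when value y is on the stack, and f is laid out as f[y-1][x-1] = f(y, x). The function `transition` is hypothetical.

```python
def transition(mask, x, v, f):
    """Append value x (1-indexed) to a prefix with stack `mask` and current value v.

    Hypothetical helper. Assumed encoding: bit (y - 1) of mask is set iff value y
    is on the stack; f[y - 1][x - 1] == f(y, x). Returns (nmask, nv).
    """
    n = len(f)
    nmask, nv = mask, v
    for y in range(1, n + 1):          # scan stack values from smallest to largest
        bit = 1 << (y - 1)
        if not mask & bit:
            continue
        nv = max(nv, f[y - 1][x - 1])  # pair (y, x): popped y <= x, or the smallest z > x
        if y > x:
            break                      # z stays on the stack; larger values are untouched
        nmask &= ~bit                  # y <= x is popped
    nmask |= 1 << (x - 1)              # push x
    return nmask, nv
```

For example, with N = 3 and the stack holding the values 3 and 1 (mask 0b101), appending x = 2 pops 1 and keeps 3, so the new mask is 0b110 and the new pairs are (1, 2) and (3, 2).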
The complexity of this solution, if implemented directly, is \mathcal{O}(2^N N^4): there are 2^N\cdot N^2 states, and each has N transitions (one per choice of the next value), each of which takes \mathcal{O}(N) time to update the mask and value.
It also uses \mathcal{O}(2^N N^2) memory.
If implemented well and in a fast enough language, this may pass.
There are some optimizations that can be made:
- The memory is trivially optimized to \mathcal{O}(2^N N), since when computing dp(i, mask, v) we only need to store the values of dp(i-1, \cdot, \cdot), i.e. the “previous row”.
- The time complexity can also be improved by a factor of N.
To do this, note that once the input stack mask and the newly added value are fixed, the resulting stack mask and the (max, second max) pairs are fixed too, and we only care about the maximum score among these pairs anyway.
So, we can simply precompute and store the results of all possible transitions first, and then just look them up when running the DP.
The precomputation costs \mathcal{O}(2^N N^2) time and \mathcal{O}(2^N N) memory, but shaves a factor of N off the DP, leading to an overall time complexity of \mathcal{O}(2^N N^3).
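A compact sketch combining both optimizations: the transition table is precomputed once, and only one DP row is kept at a time. It uses the same assumed mask encoding as before and a sentinel value 0 for "no pair yet" (the editorialist's code starts from 1 instead, which gives the same result because every f value is at least 1). The name `total_value_sum` is hypothetical, the modulus is the one used in the editorialist's code, and N \ge 2 is assumed.

```python
MOD = 998244353  # modulus taken from the editorialist's code


def total_value_sum(n, f, mod=MOD):
    """Sum of values of all length-n arrays over [1, n], modulo `mod` (assumes n >= 2).

    Assumed encoding: bit (y - 1) of a mask is set iff value y is on the stack,
    and f[y - 1][x - 1] == f(y, x).
    """
    full = 1 << n
    # Precomputation: table[x][mask] = (new mask, best f among the new pairs)
    table = [[(0, 0)] * full for _ in range(n + 1)]
    for x in range(1, n + 1):
        for mask in range(full):
            best, nmask = 0, mask
            for y in range(1, n + 1):
                bit = 1 << (y - 1)
                if not mask & bit:
                    continue
                best = max(best, f[y - 1][x - 1])
                if y > x:
                    break          # smallest z > x stays on the stack
                nmask &= ~bit      # y <= x is popped
            table[x][mask] = (nmask | (1 << (x - 1)), best)

    # Rolling DP: dp[v][mask] for the current prefix length only.
    dp = [[0] * full for _ in range(n + 1)]
    dp[0][0] = 1  # empty prefix: empty stack, no pair yet (v = 0)
    for _ in range(n):
        ndp = [[0] * full for _ in range(n + 1)]
        for v in range(n + 1):
            for mask in range(full):
                ways = dp[v][mask]
                if ways == 0:
                    continue
                for x in range(1, n + 1):
                    nmask, best = table[x][mask]
                    nv = max(v, best)
                    ndp[nv][nmask] = (ndp[nv][nmask] + ways) % mod
        dp = ndp

    return sum(v * dp[v][mask] for v in range(n + 1) for mask in range(full)) % mod
```

For f(x, y) = \max(x, y) the value of each array is its maximum, so for N = 3 this should return 72, matching a brute-force enumeration.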
TIME COMPLEXITY:
\mathcal{O}(2^N N^3) or \mathcal{O}(2^N N^4) per testcase.
CODE:
Editorialist's code (PyPy3)
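```python
mod = 998244353

for _ in range(int(input())):
    n = int(input())
    f = [list(map(int, input().split())) for i in range(n)]

    # trans[v][mask] = (best f-value among the pairs created by appending value v,
    #                   stack mask after appending v).
    # Values are 0-indexed here; bit i of mask means value i is on the stack.
    trans = [[(0, 0) for _ in range(2**n)] for _ in range(n)]
    for v in range(n):
        for mask in range(2**n):
            nmask = mask
            score = 1
            for i in range(n):
                if mask & 2**i:
                    score = max(score, f[i][v])
                    if i >= v: break  # first stack value >= v is not popped
                    nmask ^= 2**i     # value i < v is popped
            nmask |= 2**v             # push v
            trans[v][mask] = (score, nmask)

    # dp[score][mask] = number of prefixes with current value `score` and stack `mask`
    dp = [[0 for _ in range(2**n)] for _ in range(n+1)]
    dp[1][0] = 1
    for i in range(n):
        ndp = [[0 for _ in range(2**n)] for _ in range(n+1)]
        for mask in range(2**n):
            for score in range(1, n+1):
                ways = dp[score][mask]
                if ways == 0: continue  # skip unreachable states
                for v in range(n):
                    newsc, newmask = trans[v][mask]
                    newsc = max(newsc, score)
                    ndp[newsc][newmask] = (ndp[newsc][newmask] + ways) % mod
        dp = ndp

    ans = 0
    for score in range(1, n+1):
        for mask in range(2**n):
            ans += score * dp[score][mask]
    print(ans % mod)
```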