DELMX - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: raysh07
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

Easy

PREREQUISITES:

None

PROBLEM:

You have a permutation A of the integers 1, 2, \ldots, N.
In one move, you can choose any subarray of length |A|-1 and delete the maximum element of this subarray.
How many distinct arrays are reachable by performing this operation any number of times (including zero)?
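
A brute-force sketch of the operation can help when checking small cases by hand. It assumes the values are distinct, so a value identifies its position. It also counts the starting array itself as reachable. `reachable_arrays` is a hypothetical name, and the search is exponential, so it is only meant for tiny arrays.

```python
from collections import deque


def reachable_arrays(a):
    """Brute force: return the set of all arrays (as tuples) reachable from a."""
    start = tuple(a)
    seen = {start}
    queue = deque([start])
    while queue:
        cur = queue.popleft()
        n = len(cur)
        if n < 2:
            continue  # a subarray of length 0 has no maximum to delete
        # The subarrays of length n-1: drop the last element, or drop the first one.
        for sub in (cur[:n - 1], cur[1:]):
            mx = max(sub)
            nxt = tuple(x for x in cur if x != mx)  # values are distinct
            if nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return seen


print(len(reachable_arrays([2, 1, 3])))  # 6
```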

EXPLANATION:

Each deletion operation requires us to choose a subarray whose length is one less than the length of the whole array.
For an array A of length N, there are only two such subarrays: A[1\ldots N-1] and A[2\ldots N], i.e. the first N-1 elements or the last N-1 elements.
Further, these two subarrays have a large overlap: they both contain all of A_2, A_3, \ldots, A_{N-1}.

So, if \max(A) lies among indices 2, 3, \ldots, N-1, it belongs to both subarrays, and whichever subarray is chosen, \max(A) itself is deleted.
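
The observation can be checked with a tiny helper. `candidate_deletions` is a hypothetical name. It returns the value that each of the two possible moves would delete. Indices in the Python code are 0-based, while the text uses 1-based indices.

```python
def candidate_deletions(a):
    """Return the values deleted by the two possible moves on a (needs len(a) >= 2)."""
    n = len(a)
    delete_if_left = max(a[:n - 1])  # subarray A[1..N-1], i.e. drop the last element
    delete_if_right = max(a[1:])     # subarray A[2..N], i.e. drop the first element
    return delete_if_left, delete_if_right


# Maximum 4 is at an interior index: both moves delete 4.
print(candidate_deletions([3, 4, 1, 2]))  # (4, 4)
# Maximum 4 is at the left end: the two moves delete different values.
print(candidate_deletions([4, 3, 1, 2]))  # (4, 3)
```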

This means the only real choice arises when \max(A) lies at an endpoint of A, i.e. A_1 = \max(A) or A_N = \max(A).
Suppose A_1 = \max(A) (the case A_N = \max(A) is symmetric).
There are then two choices available to us:

  1. Don’t delete \max(A).
    If we choose this, we might as well never delete \max(A) in the future, because if we did delete it later, we could have deleted it right now instead and reached the same arrays.
    Since \max(A) is never deleted, it stays at the border, so every future move must use the one subarray that excludes it (and hence deletes the largest of the other elements).
    Essentially, we're locked: if there are M elements remaining, exactly M arrays can be formed, one for each number of deletions 0, 1, \ldots, M-1. For example, if A = [2, 1, 3] and we never delete 3, the reachable arrays are [2, 1, 3], [1, 3], and [3].
  2. We do delete \max(A).
    Here, the process simply continues on the smaller array: we check where its maximum lies, and so on.

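Here is a minimal sketch of the "locked" case, assuming distinct values. `locked_arrays` is a hypothetical helper. It lists the arrays obtained when \max(A) sits at an end and is never deleted, and it reproduces the [2, 1, 3] example.

```python
def locked_arrays(a):
    """Arrays reachable from a if max(a), which must be at an end, is never deleted.

    Every move must then use the subarray that excludes max(a), so it deletes the
    largest of the other elements: the whole sequence of arrays is forced.
    """
    mx = max(a)
    if a[0] != mx and a[-1] != mx:
        raise ValueError("max(a) must be the first or last element")
    cur = list(a)
    result = [cur]
    while len(cur) > 1:
        drop = max(x for x in cur if x != mx)  # largest element other than mx
        cur = [x for x in cur if x != drop]
        result.append(cur)
    return result


print(locked_arrays([2, 1, 3]))  # [[2, 1, 3], [1, 3], [3]]
```

The output always has exactly len(a) arrays, matching the count of M arrays stated above.
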
Viewed globally, the process is: repeatedly delete the maximum element of A until we decide to stop, which is only possible when that maximum is the leftmost or rightmost of the remaining elements.
Once we decide not to delete the maximum, there are no more choices, since every later move must avoid deleting it.
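
The global view translates into a direct enumeration. `reachable_by_process` is a hypothetical name. It assumes distinct values and follows the two options at each step: keep the current maximum, when it sits at an end, or delete it. On small inputs it produces the same sets as an exhaustive search.

```python
def reachable_by_process(a):
    """Reachable arrays (as tuples) via: delete the current maximum until we stop.

    Stopping is only possible when the current maximum is at an end; after that,
    every move is forced (it deletes the largest of the other elements).
    """
    found = set()
    cur = list(a)
    while cur:
        found.add(tuple(cur))
        mx = max(cur)
        if cur[0] == mx or cur[-1] == mx:
            # Option 1: keep mx forever and follow the forced deletions.
            rest = cur
            while len(rest) > 1:
                drop = max(x for x in rest if x != mx)
                rest = [x for x in rest if x != drop]
                found.add(tuple(rest))
        if len(cur) == 1:
            break  # no move is possible on a single element
        # Option 2: delete mx (always possible) and continue on the smaller array.
        cur = [x for x in cur if x != mx]
    return found


print(len(reachable_by_process([2, 1, 3])))  # 6
```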

This allows us to count reachable arrays by conditioning on their maximum remaining element.
That is, fix an integer M (1 \leq M \leq N), and we’ll try to count the number of reachable arrays whose maximum element is M.

  • First, we must surely delete all elements larger than M.
  • This leaves exactly the elements 1, 2, \ldots, M, in the same relative order as in A.
  • Now,
    • If M itself is the leftmost or rightmost of these elements, we can choose not to delete it, and then there are M reachable arrays, as noted earlier.
    • If M is not the leftmost or rightmost remaining element, then any move we perform will end up deleting M, so there’s only one reachable array with M as its maximum element (that being the current array, obtained by deleting all elements \gt M).

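The case analysis can be transcribed literally as a slow but direct O(N^2) sketch. It assumes A is a permutation of 1..N, and `count_by_max` is a hypothetical name. For each M it keeps only the values up to M and checks whether M ends up at an end.

```python
def count_by_max(a):
    """Count reachable arrays by their maximum M, in O(N^2); a is a permutation of 1..N."""
    total = 0
    for m in range(1, len(a) + 1):
        kept = [x for x in a if x <= m]  # every element larger than m must be deleted
        if kept[0] == m or kept[-1] == m:
            total += m  # m can be kept at an end: m arrays have maximum m
        else:
            total += 1  # any further move deletes m: only this array has maximum m
    return total


print(count_by_max([2, 1, 3]), count_by_max([1, 3, 2]))  # 6 4
```
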
To process each value of M quickly, we need to know whether M is the leftmost or rightmost of the remaining elements 1, 2, \ldots, M.
One way to do this is to store \text{pos}[x] as the position of x in the array, and then check whether \text{pos}[M] = \max(\text{pos}[1], \ldots, \text{pos}[M]) or \text{pos}[M] = \min(\text{pos}[1], \ldots, \text{pos}[M]).
The prefix minimums/maximums can be precomputed, or just stored as you iterate.
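
The running-minimum/maximum check could look like the sketch below. It assumes a permutation of 1..N with 0-based positions, and `endpoint_flags` is a hypothetical name. It computes, for every M at once, whether M is an endpoint among 1..M.

```python
def endpoint_flags(a):
    """flags[m] is True iff m is the leftmost or rightmost of the values 1..m in a.

    Assumes a is a permutation of 1..N; flags[0] is unused.
    """
    n = len(a)
    pos = [0] * (n + 1)
    for i, x in enumerate(a):
        pos[x] = i
    flags = [False] * (n + 1)
    lo, hi = n, -1  # running min / max of pos[1..m]
    for m in range(1, n + 1):
        lo = min(lo, pos[m])
        hi = max(hi, pos[m])
        flags[m] = pos[m] == lo or pos[m] == hi
    return flags


print(endpoint_flags([1, 3, 2]))  # [False, True, True, False]
```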

The answer cannot exceed 1 + 2 + \ldots + N = \frac{N(N+1)}{2}, since there are at most M reachable arrays whose largest element is M.
This easily fits in a 64-bit integer type, but don't forget to print the answer modulo 998244353 at the end, since the answer can exceed the modulus.
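
The counting can be summarized in one formula. Here E is notation introduced only for this summary: the set of values M that are the leftmost or rightmost among 1, 2, \ldots, M. Splitting each contribution M of an M \in E into 1 + (M-1) explains why the editorialist's code starts from ans = n and adds m-1 for each such m.

```latex
\text{answer} = \sum_{M=1}^{N} \begin{cases} M & \text{if } M \in E \\ 1 & \text{if } M \notin E \end{cases}
             = N + \sum_{M \in E} (M - 1)
             \le \sum_{M=1}^{N} M = \frac{N(N+1)}{2}
```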

TIME COMPLEXITY:

\mathcal{O}(N) per testcase.

CODE:

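Editorialist's code (PyPy3)
```python
for _ in range(int(input())):
    n = int(input())
    a = list(map(int, input().split()))
    # pos[x] = 0-indexed position of value x in a
    pos = [0]*(n+1)
    for i in range(n):
        pos[a[i]] = i

    # Every M contributes at least one array, so start from n.
    # L, R track the leftmost/rightmost position among values 1..m.
    ans, L, R = n, n, -1
    for m in range(1, n+1):
        L = min(L, pos[m])
        R = max(R, pos[m])
        # m is an endpoint among 1..m: it contributes m arrays instead of 1
        if L == pos[m] or R == pos[m]: ans += m-1
    print(ans % 998244353)
```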
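
Comparing the counting argument against exhaustive search on small random permutations is a quick way to gain confidence in it. The harness below is a hypothetical sketch. The names `fast_count` and `brute_count` are not from the original code. `fast_count` mirrors the editorialist's loop without the final modulo, and `brute_count` explores every sequence of moves.

```python
import random
from collections import deque


def fast_count(a):
    """The editorial's O(N) count (before taking the modulo); a is a permutation of 1..N."""
    n = len(a)
    pos = [0] * (n + 1)
    for i, x in enumerate(a):
        pos[x] = i
    ans, lo, hi = n, n, -1
    for m in range(1, n + 1):
        lo = min(lo, pos[m])
        hi = max(hi, pos[m])
        if pos[m] == lo or pos[m] == hi:
            ans += m - 1
    return ans


def brute_count(a):
    """Number of arrays reachable from a, by breadth-first search (tiny inputs only)."""
    start = tuple(a)
    seen = {start}
    queue = deque([start])
    while queue:
        cur = queue.popleft()
        n = len(cur)
        if n < 2:
            continue  # no move on a single element
        for sub in (cur[:n - 1], cur[1:]):  # the only two subarrays of length n-1
            mx = max(sub)
            nxt = tuple(x for x in cur if x != mx)
            if nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return len(seen)


random.seed(1)
for _ in range(300):
    n = random.randint(1, 7)
    perm = random.sample(range(1, n + 1), n)
    assert fast_count(perm) == brute_count(perm), perm
print("all random tests passed")
```

Keeping N at 7 or below bounds the brute force by the 2^7 subsets of values, so the whole check runs in well under a second.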