MEXSUM - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author:
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

Easy

PREREQUISITES:

None

PROBLEM:

You’re given an array A of length N.
A pair of integers (L, R) is called good if 1 \leq L \lt R \lt N and \text{MEX}(A[1\ldots L]) = \text{MEX}(A[L+1\ldots R]) = \text{MEX}(A[R+1\ldots N]).

Across all good pairs corresponding to A, find the minimum and maximum value of \text{SUM}(A[1\ldots L]) - \text{SUM}(A[L+1\ldots R]) + \text{SUM}(A[R+1\ldots N]).

EXPLANATION:

In simpler terms, a good pair corresponds to partitioning A into three non-empty subarrays, each of which has the same mex.
The score of such a partition is the sum of the first and third subarrays, minus the sum of the second subarray.
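
Before optimizing, it can help to have a slow reference that follows the definition literally. The sketch below is only meant for checking a faster solution on small arrays; the names `mex` and `brute_force`, and the choice to return `None` when no good pair exists, are assumptions made for illustration and are not taken from the problem statement.

```python
def mex(values):
    """Smallest non-negative integer that does not appear in values."""
    seen = set(values)
    m = 0
    while m in seen:
        m += 1
    return m


def brute_force(a):
    """Try every split into three non-empty parts.

    Returns (min_score, max_score) over all good splits,
    or None if no split has three equal mexes.
    """
    n = len(a)
    best = None
    for l in range(1, n - 1):        # first part is a[:l]
        for r in range(l + 1, n):    # middle is a[l:r], last part is a[r:]
            first, middle, last = a[:l], a[l:r], a[r:]
            if mex(first) == mex(middle) == mex(last):
                score = sum(first) - sum(middle) + sum(last)
                if best is None:
                    best = (score, score)
                else:
                    best = (min(best[0], score), max(best[1], score))
    return best
```

This runs in roughly cubic time, so it is useful only for small random tests against the linear solution described in the rest of the editorial.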

Let’s go back to the definition of mex.
Note that if \text{MEX}(S) = x for a set S, it means S contains all the integers 0, 1, 2, \ldots, x-1, but does not contain x.

So, if (L, R) is a good pair, with the common mex being M, it means all three subarrays in the partition contain the elements 0, 1, 2, \ldots, M-1, and do not contain M.
But this means \text{MEX}(A) = M as well, since these three subarrays together comprise the entire array!

Notice that this puts a rather strong condition on both L and R: the mex of the prefix ending at L must equal M, and so must the mex of the suffix starting at R+1.
Motivated by this, let’s define arrays \text{premex} and \text{sufmex}, where \text{premex}_i = \text{MEX}(A[1\ldots i]) and \text{sufmex}_i = \text{MEX}(A[i\ldots N]) are the prefix and suffix mexes of A.

How to compute these?

We’ll look at computing \text{premex}, since applying this to the reverse of the array will compute \text{sufmex}.

Let S be a set of elements encountered so far.
For each i from 1 to N, do the following:

  1. Insert A_i into S.
  2. Start with \text{premex}_i = \text{premex}_{i-1} (and \text{premex}_0 = 0). Then, while \text{premex}_i is present in S, increment \text{premex}_i by 1.

This seemingly brute-force algorithm is in fact fast enough for our needs.
This is because starting with \text{premex}_i = \text{premex}_{i-1} means we’ll only ever increase \text{premex} values; since a prefix mex can never exceed N, there are at most N increments in total across all i.

Each check can be done in \mathcal{O}(\log N) or \mathcal{O}(1) time using an appropriate data structure, so this is fast enough.
For the data structure: we need to insert elements, and check whether an element exists - a set can handle this, or you can just use an array of size N to mark elements.
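
The procedure above can be sketched as follows, using a boolean array in place of a set. The function names `prefix_mex` and `suffix_mex` are illustrative, and the code is 0-indexed, so `prefix_mex(a)[i]` is the mex of `a[0..i]` and `suffix_mex(a)[i]` is the mex of `a[i..n-1]`.

```python
def prefix_mex(a):
    """Return a list whose i-th entry is the mex of a[0..i]."""
    n = len(a)
    # Values >= n can never affect a mex of at most n, so they are not marked.
    seen = [False] * (n + 1)
    result = []
    cur = 0
    for x in a:
        if x < n:
            seen[x] = True
        # cur only moves forward, so this loop runs at most n times in total.
        while seen[cur]:
            cur += 1
        result.append(cur)
    return result


def suffix_mex(a):
    """Return a list whose i-th entry is the mex of a[i..n-1]."""
    return prefix_mex(a[::-1])[::-1]
```

For example, on `[0, 1, 0, 1, 0, 1]` the prefix mexes are `[1, 2, 2, 2, 2, 2]` and the suffix mexes are `[2, 2, 2, 2, 2, 0]`.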


Let’s now fix an index L such that \text{premex}_L = M, and try to find what the ideal choice of R is.
As noted above, R should satisfy \text{sufmex}_{R+1} = M. It’s also not hard to see that \text{sufmex}_i \geq \text{sufmex}_{i+1} for any i, so there’s some rightmost index R_0 satisfying \text{sufmex}_{R_0 + 1} = M, and only R \leq R_0 is valid.
This R_0 is independent of L obviously, so we can just compute and store it.
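
A minimal sketch of finding this index, assuming 0-indexed Python lists. The function name `rightmost_suffix_start` is hypothetical; it returns the mex M of the whole array together with the largest 0-based index s such that the suffix starting at s has mex M. Since the 0-based start s is the 1-based position s+1, the returned s equals R_0 in the editorial's notation.

```python
def rightmost_suffix_start(a):
    """Return (M, s): M = mex(a), s = largest 0-based index with mex(a[s:]) == M."""
    n = len(a)
    seen = [False] * (n + 1)
    cur = 0
    suf = [0] * n
    # Build suffix mexes right to left; cur only ever increases.
    for i in range(n - 1, -1, -1):
        if a[i] < n:
            seen[a[i]] = True
        while seen[cur]:
            cur += 1
        suf[i] = cur
    target = suf[0]              # mex of the whole array
    s = n - 1
    # suf is non-increasing from left to right, so scan from the right end.
    while suf[s] != target:
        s -= 1
    return target, s
```

The scan always stops, because the suffix starting at index 0 is the whole array and therefore has mex M.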

Now, back to L being fixed.
We need to ensure that the middle part also has a mex of M, i.e. \text{MEX}(A_{L+1}, \ldots, A_R) = M.
Observe that since M is the mex of the entire array, if this condition is satisfied for some value of R then all R' \geq R will satisfy it too.
So, if we define c_L to be the smallest index such that \text{MEX}(A_{L+1}, \ldots, A_{c_L}) = M, we must choose R \geq c_L.
The array c can be computed quickly for all indices in a variety of ways.

Details

One way of doing this is to use a two-pointer algorithm.

We’ll compute c in descending order of indices.

Suppose we want to compute c_i, i.e. the smallest end of a subarray starting at index i+1 that has mex M.
We already know c_{i+1}, meaning that the subarray [i+2, c_{i+1}] has a mex of M.
This means [i+1, c_{i+1}] also has a mex of M (recall that M doesn’t exist in the array, after all).
So, we can start with c_i = c_{i+1}, and then attempt to reduce c_i as much as possible.
If no subarray starting at i+2 has mex M yet, keep the right end at N until one does.

When reducing c_i, we need to check if doing so will make the mex less than M.
This is easy to do: maintain the frequencies of all elements \lt M in the range [i+1, c_i], and check if reducing c_i will make any of them 0. If it won’t, c_i can be reduced; otherwise it can’t.

This algorithm computes every c_i value in linear time overall, since both pointers only ever move left.
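
Here is one way the two-pointer idea might look in code. It is a sketch under a few assumptions: indices are 0-based, `m` is the mex of the whole array (so `m` never appears in `a`), and the hypothetical function `min_end_for_each_start` stores, for each start index `i`, the smallest inclusive end `e` such that `a[i..e]` contains every value below `m`, or `n` if no such end exists. In the editorial's notation, c_L corresponds to the entry for the 0-based start L, shifted to 1-based.

```python
def min_end_for_each_start(a, m):
    """For each start i, smallest inclusive end e with mex(a[i..e]) == m, else n."""
    n = len(a)
    freq = [0] * m          # occurrences of each value < m inside the window
    missing = m             # how many values < m are absent from the window
    min_end = [n] * n
    e = n - 1               # right end of the window; it only moves left
    for i in range(n - 1, -1, -1):
        x = a[i]
        if x < m:
            if freq[x] == 0:
                missing -= 1
            freq[x] += 1
        if missing > 0:
            continue        # a[i..n-1] does not contain all of 0..m-1 yet
        # Drop a[e] while the window stays non-empty and keeps every value < m.
        while e > i and (a[e] >= m or freq[a[e]] > 1):
            if a[e] < m:
                freq[a[e]] -= 1
            e -= 1
        min_end[i] = e
    return min_end
```

Each index is added to the window once and removed at most once, which is why the total work is linear rather than quadratic.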

So, we now have both c_L and R_0, meaning our choice of R must satisfy c_L \leq R \leq R_0.
Analyzing this,

  • If c_L \gt R_0, no valid R exists.
  • Otherwise,
    • To minimize the sum, it’s optimal to make R as large as possible, so that more elements are subtracted and fewer are added. Since all elements are non-negative, choosing R = R_0 is optimal here.
    • To maximize the sum, the exact opposite applies: R should be as small as possible, so choose R = c_L.

So, for a fixed L, there is at most one candidate R for the minimum and at most one candidate R for the maximum.
Once L and R are fixed, the actual score is easy enough to compute in constant time (using, say, prefix sums), so we can just try every option and take the best among them.
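
The constant-time score computation can be sketched with prefix sums as below. The helper name `make_score` is an assumption for illustration; it uses 0-based half-open cuts, so the three parts are `a[:l]`, `a[l:r]` and `a[r:]`.

```python
from itertools import accumulate


def make_score(a):
    """Return a function score(l, r) = sum(a[:l]) - sum(a[l:r]) + sum(a[r:])."""
    pre = [0] + list(accumulate(a))   # pre[k] = sum of the first k elements
    total = pre[-1]

    def score(l, r):
        # pre[l] - (pre[r] - pre[l]) + (total - pre[r])
        return 2 * pre[l] - 2 * pre[r] + total

    return score
```

Written this way, the score for a fixed `l` depends on `r` only through `-2 * pre[r]`. Because the elements are non-negative, `pre[r]` never decreases as `r` grows, which matches the choice of the largest valid R for the minimum and the smallest valid R for the maximum.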

TIME COMPLEXITY:

\mathcal{O}(N) per testcase.

CODE:

Editorialist's code (PyPy3)
def calc_prefix(a):
    n = len(a)
    mark = [0]*(n+1)
    pref = [0]*n
    
    ans = 0
    for i in range(n):
        mark[min(a[i], n)] = 1
        while mark[ans]: ans += 1
        pref[i] = ans
    return pref

for _ in range(int(input())):
    n = int(input())
    a = list(map(int, input().split()))
    
    premex = calc_prefix(a)
    sufmex = calc_prefix(a[::-1])[::-1]

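    # target = mex of the whole array
    # optsuf = rightmost index whose suffix still has mex equal to target
    target = premex[-1]
    optsuf = n-1
    while sufmex[optsuf] != target: optsuf -= 1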

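    # nxt[v] = first index holding value v (for v < target)
    # link[i] = next index after i holding the same value as a[i], or n if none
    link = [n]*n
    nxt = [n]*n
    for i in reversed(range(n)):
        if a[i] < target:
            link[i] = nxt[a[i]]
            nxt[a[i]] = i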
    
    presum = [a[0]]
    for i in range(1, n): presum.append(presum[-1] + a[i])

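    # from_here = smallest end of a middle part containing every value < target;
    # it starts as the answer for a part beginning at index 0 and is updated as i advances
    from_here = 1
    if target > 0: from_here = max(nxt[:target])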

    mn, mx, pref = 10**18, -10**18, 0
    for i in range(optsuf - 1):
        if a[i] < target: from_here = max(from_here, link[i])
        if target == 0: from_here = i+1
        pref += a[i]

        if premex[i] != target: continue
        if from_here >= optsuf: break

        # max: [0, i], [i+1, from_here], [from_here+1, n-1]
        # min: [0, i], [i+1, optsuf-1], [optsuf, n-1]
        mx = max(mx, pref + presum[-1] - presum[from_here] - presum[from_here] + presum[i])
        mn = min(mn, pref + presum[-1] - presum[optsuf - 1] - presum[optsuf - 1] + presum[i])
    if mn == 10**18: mn = mx = -1
    print(mn, mx)

Check out this solution with a code explanation if you have difficulty following the editorial.

While I get this for a single L, how do you find the lower limit c_L for every L? Doesn’t that make it O(N^2)?