PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: raysh07
Tester: sky_nik
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
Stacks, or sets/DSU
PROBLEM:
Let g(B, x) denote the maximum possible mex of B after x integers are appended to it.
Let f(B) denote the minimum integer C such that g(B, x+1) - g(B, x) = 1 for all x\geq C.
Given A, find the sum of f(A[l, r]) across all subarrays of A.
EXPLANATION:
Let’s work out the given function definitions first.
g(B, x) is not hard to figure out: the objective is to maximize the mex, so it's best to only append numbers that don't already exist in B (appending a duplicate never helps).
Among those, the best we can do is to append the x smallest non-negative integers that don't appear in B.
So, g(B, x) equals the (x+1)-th smallest non-negative integer that doesn’t appear in B.
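The greedy claim can be checked by brute force on tiny inputs. The sketch below tries every possible multiset of x appended values, takes the best mex, and compares it with the (x+1)-th missing value. The function names `mex`, `g_exhaustive` and `kth_missing` are hypothetical, B is assumed to hold non-negative integers, and the search is far too slow for anything but testing.

```python
from itertools import combinations_with_replacement


def mex(values):
    """Smallest non-negative integer not present in `values`."""
    s = set(values)
    m = 0
    while m in s:
        m += 1
    return m


def g_exhaustive(b, x):
    """Maximum mex of b after appending x integers, found by trying every multiset.

    A mex never exceeds the number of elements, so appended values >= len(b) + x
    can never lie below it; restricting candidates to this range loses nothing.
    """
    candidates = range(len(b) + x + 1)
    return max(mex(list(b) + list(extra))
               for extra in combinations_with_replacement(candidates, x))


def kth_missing(b, k):
    """k-th (1-based) smallest non-negative integer absent from b."""
    present = set(b)
    v = 0
    while True:
        if v not in present:
            k -= 1
            if k == 0:
                return v
        v += 1


b = [0, 2, 5]
print([g_exhaustive(b, x) for x in range(5)])     # [1, 3, 4, 6, 7]
print([kth_missing(b, x + 1) for x in range(5)])  # [1, 3, 4, 6, 7]
```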
Now, for f(B).
Observe that g(B, x+1) - g(B, x) = 1 means that the (x+2)-th missing element is one more than the (x+1)-th.
The only way this can hold for all x\geq C is if B contains no element larger than the (C+1)-th missing value; otherwise, the missing values would have to skip over that element, creating a jump of \gt 1.
In other words, the (C+1)-th missing element should be strictly larger than any element of B.
Let M = \max(B). Then, we want the smallest C such that g(B, C) = M+1, since M+1 is the smallest value larger than everything in B. (For any smaller C, g(B, C) \lt M because M itself is present in B, so the missing values must later skip over M, giving a jump of at least 2.)
Among the M+1 values 0, 1, \ldots, M, exactly dist(B) are present in B, where dist(B) denotes the number of distinct elements in B. So exactly M+1 - dist(B) of them are missing, and appending all of them makes the mex equal to M+1.
This gives C = M+1 - dist(B).
So, quite simply, f(B) = \max(B)+1 - dist(B).
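A brute-force sketch can cross-check this formula against the definition. For example, with B = [0, 2, 5] the missing values are 1, 3, 4, 6, 7, \ldots, so the differences g(B, x+1) - g(B, x) are 2, 1, 2, 1, 1, \ldots and f(B) = 3 = 5 + 1 - 3. The helper names below are hypothetical, and B is assumed non-empty with non-negative entries.

```python
from itertools import product


def kth_missing(b, k):
    """k-th (1-based) smallest non-negative integer absent from b, i.e. g(b, k - 1)."""
    present = set(b)
    v = 0
    while True:
        if v not in present:
            k -= 1
            if k == 0:
                return v
        v += 1


def f_by_definition(b):
    """Smallest C with g(b, x+1) - g(b, x) = 1 for every x >= C."""
    # g(b, x) >= x, so for x > max(b) both g(b, x) and g(b, x+1) exceed max(b)
    # and differ by exactly 1; checking x = 0 .. max(b) + 1 is therefore enough.
    limit = max(b) + 2
    g = [kth_missing(b, x + 1) for x in range(limit + 1)]
    c = 0
    for x in range(limit):
        if g[x + 1] - g[x] != 1:
            c = x + 1  # the condition fails at x, so C must be larger than x
    return c


def f_formula(b):
    return max(b) + 1 - len(set(b))


print(f_by_definition([0, 2, 5]), f_formula([0, 2, 5]))  # 3 3

# Exhaustive check on every array of length 1..4 with values in 0..3.
for length in range(1, 5):
    for b in product(range(4), repeat=length):
        assert f_by_definition(list(b)) == f_formula(list(b))
```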
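Now, let's move to computing this quantity across all subarrays of A, i.e.,
\sum_{l=1}^{N}\sum_{r=l}^{N} f(A[l, r]) = \sum_{l=1}^{N}\sum_{r=l}^{N} \left(\max(A[l, r]) + 1 - dist(A[l, r])\right)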
Let’s break this summation into three parts: \max(A[l, r]), 1, and dist(A[l, r]).
The constant 1 appears once for each subarray, so it contributes \frac{N\cdot (N+1)}{2} in total.
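Before optimizing, an \mathcal{O}(N^2) reference that evaluates the three parts directly can be handy for testing a faster solution. This is a minimal sketch assuming 0-indexed Python lists of non-negative integers; `three_parts` and `total_f` are hypothetical names.

```python
def three_parts(a):
    """Brute-force sums of max, the constant 1, and dist over all subarrays."""
    n = len(a)
    sum_max = sum_dist = 0
    for l in range(n):
        cur_max, seen = -1, set()  # -1 works because values are non-negative
        for r in range(l, n):
            cur_max = max(cur_max, a[r])
            seen.add(a[r])
            sum_max += cur_max
            sum_dist += len(seen)
    count = n * (n + 1) // 2  # the constant 1, once per subarray
    return sum_max, count, sum_dist


def total_f(a):
    sum_max, count, sum_dist = three_parts(a)
    return sum_max + count - sum_dist


print(three_parts([0, 2, 1]), total_f([0, 2, 1]))  # (9, 6, 10) 5
```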
Next, we want to compute the sum of the maximums of all subarrays faster than \mathcal{O}(N^2).
This is a rather standard problem, and has several solutions.
For example:
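- For index i, let L_i denote the largest index \lt i that has an element \gt A_i, and let R_i denote the smallest index \gt i that has an element \geq A_i (with 1-indexing, set L_i = 0 or R_i = N+1 if no such index exists).
Then, A_i is the maximum of every subarray A[l, r] such that L_i \lt l \leq i \leq r\lt R_i. The asymmetric comparisons (\gt on the left, \geq on the right) ensure that a subarray whose maximum occurs several times is counted only once, at its rightmost maximum.
There are exactly (i-L_i)\cdot (R_i - i) such subarrays, so summing A_i\cdot (i-L_i)\cdot (R_i - i) over all i gives the answer.
Computing the L_i and R_i values is a standard problem, solvable using a monotonic stack.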
- Alternately, you can start with an empty array and insert elements into it one by one in descending order of value, inserting the larger index first among equal values (this reproduces the \gt/\geq convention above).
Each time you insert an element, L_i and R_i are the closest already-inserted indices to its left and right, which can be found quickly using a set or DSU.
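Both approaches can be sketched in Python as follows, assuming 0-indexed lists (so the sentinels become -1 and N instead of 0 and N+1). The second function swaps in a doubly linked list for the set/DSU: deleting indices in ascending order of (value, index) is the descending insertion run backwards, so the neighbours of an index at the moment it is deleted are exactly L_i and R_i. Function names are hypothetical.

```python
def sum_max_stack(a):
    """Sum of subarray maximums via monotonic stacks, O(N)."""
    n = len(a)
    left = [-1] * n   # L_i: nearest index to the left with a value > a[i]
    right = [n] * n   # R_i: nearest index to the right with a value >= a[i]
    stack = []
    for i in range(n):
        while stack and a[stack[-1]] <= a[i]:
            stack.pop()
        left[i] = stack[-1] if stack else -1
        stack.append(i)
    stack = []
    for i in range(n - 1, -1, -1):
        while stack and a[stack[-1]] < a[i]:
            stack.pop()
        right[i] = stack[-1] if stack else n
        stack.append(i)
    return sum(a[i] * (i - left[i]) * (right[i] - i) for i in range(n))


def sum_max_linked_list(a):
    """Same sum via deletion from a doubly linked list, O(N log N) due to the sort."""
    n = len(a)
    prev = list(range(-1, n - 1))  # prev[i] = i - 1, with -1 as the left sentinel
    nxt = list(range(1, n + 1))    # nxt[i] = i + 1, with n as the right sentinel
    total = 0
    for i in sorted(range(n), key=lambda j: (a[j], j)):
        left, right = prev[i], nxt[i]  # remaining neighbours are L_i and R_i
        total += a[i] * (i - left) * (right - i)
        if left >= 0:                  # unlink i from the list
            nxt[left] = right
        if right < n:
            prev[right] = left
    return total


print(sum_max_stack([1, 3, 2]), sum_max_linked_list([1, 3, 2]))  # 15 15
```

The stack version runs in \mathcal{O}(N); the linked-list version pays \mathcal{O}(N\log N) for the initial sort, matching the set/DSU approach.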
Finally, we want to compute the sum of the number of distinct elements present in all subarrays.
For this, let’s fix an integer x and count the number of subarrays that contain at least one occurrence of x. Summing this up across all x will give us the value we want.
For a fixed x,
- Initially, assume that x is present in every subarray, so it has an overall contribution of \frac{N\cdot (N+1)}{2}.
- Suppose x occurs at indices i_1, i_2, \ldots, i_k of A.
- Then, x doesn’t appear in any subarray that’s completely contained between some i_j and i_{j+1}.
It's easy to count the number of such subarrays: let d = i_{j+1} - i_j - 1; then there are \frac{d\cdot (d+1)}{2} subarrays lying strictly between them, none of which include x.
- Don't forget to also subtract the subarrays lying entirely before i_1 or entirely after i_k, which don't include x either. Equivalently, add sentinels i_0 = 0 and i_{k+1} = N+1 (1-indexed) and apply the same formula to every gap.
Doing this for all x from 0 to N-1 takes \mathcal{O}(N) time overall, since each index of A is processed exactly once and each value x needs only constant extra work for its prefix and suffix.
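The per-value counting can be sketched as below, assuming 0-indexed Python lists. Instead of an array indexed by value, it groups positions in a dictionary (a substitution that drops the 0 \leq A_i \lt N assumption); values absent from A are skipped since they contribute 0. `sum_distinct` is a hypothetical name.

```python
from collections import defaultdict


def sum_distinct(a):
    """Sum of dist(A[l, r]) over all subarrays, counted one value at a time."""
    n = len(a)
    positions = defaultdict(list)  # value -> sorted list of its indices
    for i, x in enumerate(a):
        positions[x].append(i)
    total = 0
    for idx in positions.values():
        contrib = n * (n + 1) // 2   # assume x is present in every subarray
        bounds = [-1] + idx + [n]    # sentinels cover the prefix and the suffix
        for lo, hi in zip(bounds, bounds[1:]):
            d = hi - lo - 1          # positions strictly between consecutive bounds
            contrib -= d * (d + 1) // 2
        total += contrib
    return total


print(sum_distinct([0, 2, 1]))  # 10
```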
Now that all three parts of the summation are individually known, add/subtract them appropriately to obtain the final answer.
TIME COMPLEXITY:
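\mathcal{O}(N) per testcase with the stack-based approach (the sorting-based set/DSU approach takes \mathcal{O}(N\log N) instead).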
CODE:
Author's code (C++)
```cpp
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18
#define f first
#define s second
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());
int n;
const int N = 2e5 + 69;
int root[N], sz[N], a[N], last[N];
int find(int x){
    if (x == root[x]) return x;
    return root[x] = find(root[x]);
}
void unite(int x, int y){
    x = find(x); y = find(y);
    root[x] = y;
    sz[y] += sz[x];
}
void Solve()
{
    cin >> n;
    for (int i = 1; i <= n; i++){
        last[i - 1] = 0;
        root[i] = i;
        sz[i] = 1;
        cin >> a[i];
    }
    vector <int> ord(n);
    vector <bool> ac(n + 1, false);
    iota(ord.begin(), ord.end(), 1);
    sort(ord.begin(), ord.end(), [&](int x, int y){
        return a[x] < a[y];
    });
    int ans = 0;
    for (auto i : ord){
        int l = 1, r = 1;
        if (i != 1 && ac[i - 1]){
            l += sz[find(i - 1)];
        }
        if (i != n && ac[i + 1]){
            r += sz[find(i + 1)];
        }
        ac[i] = true;
        ans += l * r * a[i];
        if (i != 1 && ac[i - 1])
            unite(i - 1, i);
        if (i != n && ac[i + 1])
            unite(i, i + 1);
    }
    for (int i = 1; i <= n; i++){
        int l = i - last[a[i]];
        int r = (n - i + 1);
        ans -= l * r;
        last[a[i]] = i;
    }
    ans += n * (n + 1) / 2;
    cout << ans << "\n";
}
int32_t main()
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in", "r", stdin);
    // freopen("out", "w", stdout);
    cin >> t;
    for(int i = 1; i <= t; i++)
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n";
    return 0;
}
```
Tester's code (C++)
```cpp
#include <bits/stdc++.h>
using namespace std;
int main() {
  cin.tie(0)->sync_with_stdio(0);
  int t; cin >> t; while (t--) {
    int n;
    cin >> n;
    vector<int> a(n);
    for (auto& ai : a) {
      cin >> ai;
    }
    // f(b) = max(b) + 1 - len(set(b))
    vector<int> prev_gt(n, -1), prev_eq(n, -1); {
      stack<int> mono;
      for (int i = 0; i < n; ++i) {
        while (!mono.empty() && a[mono.top()] <= a[i]) {
          mono.pop();
        }
        if (!mono.empty()) {
          prev_gt[i] = mono.top();
        }
        mono.push(i);
      }
    }
    vector<int> next_ge(n, n), next_eq(n, n); {
      stack<int> mono;
      map<int, int> mp;
      for (int i = n - 1; i >= 0; --i) {
        if (mp.count(a[i])) {
          next_eq[i] = mp[a[i]];
        }
        mp[a[i]] = i;
        while (!mono.empty() && a[mono.top()] < a[i]) {
          mono.pop();
        }
        if (!mono.empty()) {
          next_ge[i] = mono.top();
        }
        mono.push(i);
      }
    }
    int64_t ans = 0;
    for (int i = 0; i < n; ++i) {
      ans += (i - prev_gt[i]) * (a[i] + 1ll) * (next_ge[i] - i);
      ans -= (i + 1ll) * (next_eq[i] - i);
    }
    cout << ans << '\n';
  }
}
```
Editorialist's code (Python)
```python
def maxsum(a):
    n = len(a)
    lt, rt = [-1]*n, [n]*n
    stk = [-1]*(n+1)
    ptr = 0
    for i in range(n):
        while ptr > 0 and a[stk[ptr]] <= a[i]: ptr -= 1
        lt[i] = stk[ptr]
        ptr += 1
        stk[ptr] = i
    ptr = 0
    stk[0] = n
    for i in reversed(range(n)):
        while ptr > 0 and a[stk[ptr]] < a[i]: ptr -= 1
        rt[i] = stk[ptr]
        ptr += 1
        stk[ptr] = i
    ans = 0
    for i in range(n):
        l, r = lt[i], rt[i]
        ans += a[i] * (i-l) * (r-i)
    return ans

def distsum(a):
    n = len(a)
    ans = n*n*(n+1)//2
    last = [-1]*n
    for i in range(n):
        d = i - last[a[i]] - 1
        ans -= d*(d+1)//2
        last[a[i]] = i
    for i in range(n):
        d = n - last[i] - 1
        ans -= d*(d+1)//2
    return ans

for _ in range(int(input())):
    n = int(input())
    a = list(map(int, input().split()))
    print(maxsum(a) - distsum(a) + n*(n+1)//2)
```