PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: iceknight1093
Tester: sushil2006
Editorialist: iceknight1093
DIFFICULTY:
Medium
PREREQUISITES:
Divide and conquer, principle of inclusion-exclusion, binary search, dynamic programming
PROBLEM:
Alice and Bob play a game on an array.
While the length of the array is at least 2, Alice must choose an index that’s not an endpoint, and Bob will delete either the prefix ending there or the suffix starting there.
Alice wants to maximize the largest remaining value, while Bob wants to minimize it.
Find Alice’s score under optimal play, summed up across all subarrays.
EXPLANATION:
From the easy version, we already know one valid solution: binary search on the answer K, then check whether the binary string with B_i = 1 \iff A_i \ge K can be partitioned into several “winning” subarrays of length at most 2, with exactly one separating character between adjacent subarrays.
This can be extrapolated into a solution that doesn’t require binary search.
What we’re doing is really just breaking the entire array into blocks of size at most 2, separating adjacent blocks by a single element, and then:
- Take the maximum element from each block.
- Take the minimum among all these maximums.
That’s the score of this split into blocks.
The answer is then the maximum score across all such splits.
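To make this definition concrete, here is a minimal brute-force sketch that enumerates every split directly. It assumes, as the sentinel values in the tester's code below suggest, that the array must both start and end with a block, so an endpoint is never a separator. The name `bestSplitBrute` is made up for illustration, and the function takes exponential time, so it is only useful for checking small cases.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <cstddef>
#include <vector>

// Best score over all splits of a[pos..n-1], assuming a block starts at pos.
// Returns LLONG_MIN when no valid split exists.
long long bestSplitBrute(const std::vector<long long>& a, std::size_t pos = 0) {
    assert(pos < a.size());
    const std::size_t n = a.size();
    long long best = LLONG_MIN;
    for (std::size_t len = 1; len <= 2 && pos + len <= n; ++len) {
        // Maximum of the block a[pos..pos+len-1].
        long long blockMax = *std::max_element(a.begin() + pos, a.begin() + pos + len);
        if (pos + len == n) {
            // The block reaches the end of the array.
            best = std::max(best, blockMax);
        } else if (pos + len + 1 < n) {
            // a[pos+len] is the separator, and the next block starts right after it.
            long long rest = bestSplitBrute(a, pos + len + 1);
            best = std::max(best, std::min(blockMax, rest));
        }
        // Otherwise the array would end on a separator (assumed invalid here).
    }
    return best;
}
```

Under these assumptions, `bestSplitBrute({4, 9, 2})` is 2 (the only valid split is `[4] 9 [2]`) and `bestSplitBrute({1, 5, 2, 8})` is 5 (from the split `[1, 5] 2 [8]`).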
This observation allows us to modify the binary search check-dp as follows.
Define dp_i to be the answer for the prefix ending at index i.
Then dp_i is the larger of two options:
- We can take A_i to be a single element, giving a score of \min(A_i, dp_{i-2})
- We can take A_{i-1} and A_i together, giving a score of \min(\max(A_i, A_{i-1}), dp_{i-3})
The final answer is dp_N.
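The recurrence can be sketched as follows. The base cases are an assumption read off the sentinels in the tester's code below (dp_{-1} = +\infty and dp_0 = -\infty, so the array cannot start with a separator), and the name `blockSplitScore` is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <vector>

// O(N) DP over prefixes; a is 0-indexed, so A_i in the text is a[i - 1].
long long blockSplitScore(const std::vector<long long>& a) {
    assert(!a.empty());
    const int n = static_cast<int>(a.size());
    // d[i + 1] holds dp_i, so d[0] is dp_{-1} and d[1] is dp_0.
    std::vector<long long> d(n + 2);
    d[0] = LLONG_MAX;  // dp_{-1}: nothing precedes the first block (assumed sentinel)
    d[1] = LLONG_MIN;  // dp_0: the array may not start with a separator (assumed sentinel)
    for (int i = 1; i <= n; ++i) {
        // Option 1: A_i is a block on its own, A_{i-1} is the separator.
        long long single = std::min(a[i - 1], d[i - 1]);
        // Option 2: A_{i-1} and A_i form a block, A_{i-2} is the separator.
        long long twoBlock = LLONG_MIN;
        if (i >= 2) twoBlock = std::min(std::max(a[i - 1], a[i - 2]), d[i - 2]);
        d[i + 1] = std::max(single, twoBlock);
    }
    return d[n + 1];
}
```

For instance, `blockSplitScore({1, 5, 2, 8})` evaluates to 5: the single-element option gives \min(8, \max(1, 5)) = 5 and the two-element option gives \min(\max(2, 8), 1) = 1.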
This DP will be our key to summing up the value across all subarrays.
There are a handful of standard approaches that commonly come up when attempting to sum up some quantity across all subarrays.
For instance,
- Fix one endpoint and maintain the answer across all choices of the other one via some data structure (commonly, sweepline + segment tree)
- Use contribution: fix a value of the answer and count the number of subarrays that attain it in some way.
- The one we’ll use to solve this problem: divide and conquer.
Define f(L, R) to be the sum of answers across all subarrays with indices in [L, R].
Let M = \text{midpoint}(L, R).
We can recursively compute f(L, M) and f(M+1, R), after which we only need to deal with subarrays ‘crossing’ the middle.
To do this, we make use of the following observation:
- When splitting any array A into blocks of size at most 2 separated by single elements, among any three consecutive elements at least one must be a ‘separator’, i.e. not included in any block.
In particular, for subarrays that cross the middle, at least one of the indices \{M, M+1, M+2\} must be a ‘separator’.
With this in mind, for each L \le x \le M and 0 \le i \le 2, let’s define dp(x, i) to be the answer for the subarray [x, M+i] such that index M+i is not included in any block.
This can be computed for all L \le x \le M in \mathcal{O}(M-L) time by running the \mathcal{O}(N) DP described at the start in reverse, starting from index M+i, while ensuring we never take M+i into a block.
Similarly, for all M+1 \le y \le R and 0 \le i \le 2, we can compute dp(y, i) to be the answer for [M+i, y] assuming M+i is not included in any block.
Now, consider some subarray [x, y] where x is on the left side and y is on the right.
As noted earlier, at least one of \{M, M+1, M+2\} must be a separator.
Further, the parts on each side of the separator are independent, so we can take the best answer on each side, but the result is limited by the lower of the two.
Thus, the answer for the pair (x, y) simply equals the maximum of:
- \min(dp(x, 0), dp(y, 0))
- \min(dp(x, 1), dp(y, 1))
- \min(dp(x, 2), dp(y, 2))
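As a reference point, the quantity to be summed can be written down directly in quadratic time. This is only a sketch: the type alias `Triple` and the function names are invented here, each triple is assumed to hold the three values dp(·, 0), dp(·, 1), dp(·, 2) for one endpoint, and every pair is assumed to have at least one finite option so that the sum does not overflow.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <climits>
#include <vector>

using Triple = std::array<long long, 3>;

// Answer for one crossing pair (x, y): the best of the three separator choices.
long long crossingAnswer(const Triple& left, const Triple& right) {
    long long best = LLONG_MIN;
    for (int i = 0; i < 3; ++i) {
        best = std::max(best, std::min(left[i], right[i]));
    }
    return best;
}

// Quadratic reference: sum of crossingAnswer over all (x, y) pairs.
long long crossingSumBrute(const std::vector<Triple>& lefts,
                           const std::vector<Triple>& rights) {
    assert(!lefts.empty() && !rights.empty());
    long long total = 0;
    for (const Triple& l : lefts) {
        for (const Triple& r : rights) {
            total += crossingAnswer(l, r);
        }
    }
    return total;
}
```

An option that is impossible on either side can be stored as `LLONG_MIN`, which makes its minimum drop out of the maximum automatically.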
This needs to be summed up across all x and y.
To do this, we’ll use inclusion-exclusion.
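Specifically, we’ll use the fact that, for any three numbers a, b, c,
\max(a, b, c) = a + b + c - \min(a, b) - \min(a, c) - \min(b, c) + \min(a, b, c)
Here a, b, c stand for the three minimums listed above, in that order.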
So, summing up the max-of-three-mins can be done by:
- Summing up each min individually
- Subtracting out sums of pairwise minimums
- Adding back the sum of the minimum of all three
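The inclusion-exclusion identity for the maximum of three numbers is easy to check numerically. The helper name below is made up, and the inputs are assumed small enough that the intermediate sums fit in 64 bits.

```cpp
#include <algorithm>
#include <cassert>

// max(a, b, c) rebuilt from minimums only, via inclusion-exclusion.
long long maxViaMins(long long a, long long b, long long c) {
    long long singles = a + b + c;
    long long pairs = std::min(a, b) + std::min(a, c) + std::min(b, c);
    long long triple = std::min({a, b, c});
    long long result = singles - pairs + triple;
    assert(result == std::max({a, b, c}));  // sanity check of the identity
    return result;
}
```

With a = 3, b = 1, c = 2 this gives 6 - 4 + 1 = 3, which is indeed the maximum.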
For example, suppose we want to find the sum of
\min(\min(dp(x, 0), dp(y, 0)), \min(dp(x, 2), dp(y, 2)))
across all pairs of x and y.
(This corresponds to the \min(a, c) term in the inclusion-exclusion step.)
Because everything is a minimum, this works out to be
\min(\min(dp(x, 0), dp(x, 2)), \min(dp(y, 0), dp(y, 2)))
To compute this, we can define two arrays P and Q such that:
- P_x = \min(dp(x, 0), dp(x, 2))
- Q_y = \min(dp(y, 0), dp(y, 2))
Then the goal is to compute the sum of \min(P_x, Q_y) across all x, y.
This is fairly simple using sorting and binary search: fix one element of P to be the minimum and count the number of elements in Q that are not smaller than it; then similarly fix an element of Q to be the minimum and count the number of elements in P that are strictly larger than it, so that ties are counted exactly once.
This takes \mathcal{O}((R-L)\log(R-L)) time in total, and we must repeat it 7 times (once for each non-empty subset of the three terms) to obtain the sum over all crossing subarrays.
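A minimal sketch of this counting step is shown below. The function name `sumPairwiseMin` is hypothetical, and the values are assumed to be finite and small enough that value times count fits in 64 bits, so infinite sentinels would need separate handling.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Sum of min(P_x, Q_y) over all pairs (x, y) in O((|P| + |Q|) log(|P| + |Q|)).
long long sumPairwiseMin(std::vector<long long> P, std::vector<long long> Q) {
    assert(!P.empty() && !Q.empty());
    std::sort(P.begin(), P.end());
    std::sort(Q.begin(), Q.end());
    long long total = 0;
    // P_x is the minimum for every Q_y that is not smaller than it.
    for (long long p : P) {
        long long cnt = Q.end() - std::lower_bound(Q.begin(), Q.end(), p);
        total += p * cnt;
    }
    // Q_y is the minimum for every P_x that is strictly larger than it,
    // so a tie P_x == Q_y is counted only in the first loop.
    for (long long q : Q) {
        long long cnt = P.end() - std::upper_bound(P.begin(), P.end(), q);
        total += q * cnt;
    }
    return total;
}
```

For P = {1, 4} and Q = {2, 3}, the four minimums are 1, 1, 2, 3, and the function returns 7.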
When combined with the divide and conquer, this gives a solution that runs in \mathcal{O}(N\log^2 N) time (with a hidden multiplier of 7), which is fast enough for the constraints and time limit of 4 seconds.
TIME COMPLEXITY:
\mathcal{O}(N\log^2 N) per testcase.
CODE:
Tester's code (C++)
#include <algorithm>
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;
template<typename T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
typedef long long int ll;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
#define fastio ios_base::sync_with_stdio(false); cin.tie(NULL)
#define pb push_back
#define endl '\n'
#define sz(a) (int)a.size()
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
#define conts continue
#define ceil2(x,y) ((x+y-1)/(y))
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
#define yes cout << "Yes" << endl
#define no cout << "No" << endl
#define rep(i,n) for(int i = 0; i < n; ++i)
#define rep1(i,n) for(int i = 1; i <= n; ++i)
#define rev(i,s,e) for(int i = s; i >= e; --i)
#define trav(i,a) for(auto &i : a)
template<typename T>
void amin(T &a, T b) {
    a = min(a,b);
}
template<typename T>
void amax(T &a, T b) {
    a = max(a,b);
}
#ifdef LOCAL
#include "debug.h"
#else
#define debug(...) 42
#endif
/*
*/
const int MOD = 1e9 + 7;
const int N = 1e5 + 5;
const int inf1 = int(1e9) + 5;
const ll inf2 = ll(1e18) + 5;
void solve(int test_case){
    ll n; cin >> n;
    vector<ll> a(n+5);
    rep1(i,n) cin >> a[i];
    vector<ll> dp(n+5,-inf2);
    ll ans = 0;
    auto naive = [&](ll l, ll r){
        // fix right endpoint, find ans for all left endpoints
        ll res = 0;
        for(int p = l; p <= r; ++p){
            dp[p+1] = -inf2, dp[p+2] = inf2;
            rev(i,p,l){
                ll v1 = -inf2, v2 = -inf2;
                // single
                v1 = min(dp[i+2],a[i]);
                // double
                if(i+1 <= p){
                    v2 = min(dp[i+3],max(a[i],a[i+1]));
                }
                dp[i] = max(v1,v2);
                res += dp[i];
            }
        }
        return res;
    };
    vector<vector<ll>> f(n+5,vector<ll>(3));
    vector<vector<ll>> g(n+5,vector<ll>(3));
    vector<ll> curr_state(n+5);
    vector<ll> cnt1(8), cnt2(8);
    auto go = [&](ll l, ll r, auto &&go) -> void{
        if(l > r) return;
        if(r-l+1 <= 9){
            ll val = naive(l,r);
            ans += val;
            return;
        }
        ll mid = (l+r)>>1;
        // left-side
        for(int p = mid+1; p <= mid+3; ++p){
            // p must be cutpoint --> should land EXACTLY at p+1
            dp[p] = -inf2, dp[p+1] = inf2;
            rev(i,p-1,l){
                ll v1 = -inf2, v2 = -inf2;
                // single
                v1 = min(dp[i+2],a[i]);
                // double
                if(i+1 < p){
                    v2 = min(dp[i+3],max(a[i],a[i+1]));
                }
                dp[i] = max(v1,v2);
            }
            for(int i = l; i <= mid; ++i){
                f[i][p-mid-1] = dp[i];
            }
        }
        // right-side
        for(int p = mid+1; p <= mid+3; ++p){
            // p must be cutpoint --> should land EXACTLY at p-1
            dp[p] = -inf2, dp[p-1] = inf2, dp[p-2] = -inf2;
            for(int i = p+1; i <= r; ++i){
                ll v1 = -inf2, v2 = -inf2;
                // single
                v1 = min(dp[i-2],a[i]);
                // double
                if(i-1 > p){
                    v2 = min(dp[i-3],max(a[i],a[i-1]));
                }
                dp[i] = max(v1,v2);
            }
            for(int i = mid+1; i <= r; ++i){
                g[i][p-mid-1] = dp[i];
            }
        }
        ll res = 0;
        vector<array<ll,4>> vals;
        for(int i = l; i <= mid; ++i){
            rep(x,3){
                vals.pb({f[i][x],i,x,0});
            }
        }
        for(int i = mid+1; i <= r; ++i){
            rep(x,3){
                vals.pb({g[i][x],i,x,1});
            }
        }
        for(int i = l; i <= r; ++i){
            curr_state[i] = 0;
        }
        fill(all(cnt1),0);
        fill(all(cnt2),0);
        cnt1[0] = mid-l+1;
        cnt2[0] = r-mid;
        sort(all(vals));
        ll siz = sz(vals);
        vals.pb({2*inf2,-1,-1,-1});
        ll pv = -2*inf2, pw = 0;
        rep(ind,siz){
            auto [v,i,x,t] = vals[ind];
            if(t == 0){
                cnt1[curr_state[i]]--;
                curr_state[i] |= (1<<x);
                cnt1[curr_state[i]]++;
            }
            else{
                cnt2[curr_state[i]]--;
                curr_state[i] |= (1<<x);
                cnt2[curr_state[i]]++;
            }
            if(v != vals[ind+1][0]){
                ll ways = 0;
                rep(x,8){
                    rep(y,8){
                        if((x|y) == 7){
                            ways += cnt1[x]*cnt2[y];
                        }
                    }
                }
                ll delta = v-max(pv,0ll);
                if(delta > 0){
                    res += (ways-pw)*v;
                }
                pv = v;
                pw = ways;
            }
        }
        ans += res;
        go(l,mid,go);
        go(mid+1,r,go);
    };
    go(1,n,go);
    cout << ans << endl;
}
int main()
{
    fastio;
    int t = 1;
    cin >> t;
    rep1(i, t) {
        solve(i);
    }
    cerr << "RUN SUCCESSFUL" << endl;
    return 0;
}