PROBLEM LINK:
Author: Akib Tonnoy, Alex Danilyuk
Preparer: Yahor Dubovik
Tester: Harris Leung
Editorialist: Ashley Khoo
DIFFICULTY:
3013
Prerequisites:
Data Structures, Divide and Conquer
Problem:
You are given an array A consisting of N integers.
For each K from 0 to N find the maximum prefix sum of A after changing the sign of at most K elements of the array.
Explanation:
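For a fixed K, the best prefix sum of the range [1,P] is c(K,P) = S_P - 2(\text{the sum of the } K \text{ smallest negative elements of the range } [1,P]), where S_P = A_1 + A_2 + \dots + A_P is the prefix sum of the original array (if the range contains fewer than K negative elements, all of them are taken). This is true since the gain we get from flipping an element A_i is -2A_i.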
For a fixed K, we have to find the value of P where c(K,P) is maximised. Call this value P_K. If there are multiple P where c(K,P) is maximised, then we will choose the minimum such P to be P_K (this is just an arbitrary choice and it does not really matter).
Claim: P_K \leq P_{K+1}
Proof
Usually, the technique to prove such inequalities is contradiction. Let us assume that P_K > P_{K+1}.
Let L_j and R_j denote twice the sum of the j smallest negative elements in the ranges [1,P_{K+1}] and [P_{K+1}+1,P_K] respectively, so that for example c(K,P_{K+1}) = S_{P_{K+1}} - L_K. Note that L and R are both convex as functions of j.
Suppose that in the optimal solution of c(K,P_K), we flipped K_1 elements in [1,P_{K+1}] and K_2 elements in [P_{K+1}+1,P_K]. By definition of P_K, c(K,P_{K+1}) < c(K,P_K). That is, c(K,P_{K+1}) = S_{P_{K+1}}-L_{K} < S_{P_K} - L_{K_1}-R_{K_2} = c(K,P_K).
Now, c(K+1,P_K) \geq S_{P_K} - L_{K_1+1}-R_{K_2} since we can achieve this cost by flipping K_1+1 elements in [1,P_{K+1}] and K_2 elements in [P_{K+1}+1,P_K]. Since L is convex and K_1 \leq K, L_{K} - L_{K+1} \leq L_{K_1} - L_{K_1+1}.
Combining all inequalities, we get c(K+1,P_{K+1}) = S_{P_{K+1}}-L_{K+1} < S_{P_K} - L_{K_1 +1}-R_{K_2} \leq c(K+1,P_K). c(K+1,P_{K+1}) < c(K+1,P_K) contradicts our assumption of the optimality of P_{K+1}.
Therefore, it must be true that P_K \leq P_{K+1}. \blacksquare
Because of this property, one can use a divide and conquer approach. Specifically, we can make a function like solve(l,r,optl,optr). If we want to find the answer for m = \lfloor \frac{l+r}{2} \rfloor, we will know that optl \leq P_m \leq optr.
Now, we will be able to solve the problem if we are able to find the value of c(K,P) for arbitrary K and P quickly. This problem boils down to finding the sum of the K smallest elements of the range [1,P], which can be solved in O(\log N) using wavelet trees. However, there is an easier (in the humble opinion of the editorialist) way to accomplish this.
Consider a data structure that stores a (multi)set S and an integer K that can handle the following operations all in logarithmic time:
- increment/decrement K by 1
- insert/delete element from S
- find the sum of the minimum \min(K,|S|) elements of S
This data structure can be accomplished by maintaining 2 (multi)sets A and B where \max(A) \leq \min(B) and either |A|=K or |A|< K and B=\varnothing. This way, the answer to the queries is the sum of elements in A. It is easy to maintain sets A and B.
Time Complexity
O(N \log^2 N) per test case.
Code:
Preparer's Code
#include <bits/stdc++.h>
#define f first
#define s second
#define pb push_back
#define mp make_pair
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
typedef pair<long long, long long> pll;
// Array capacity; `inf` and `mod` are not referenced by the code visible in this solution.
const int N = 300500, inf = 1e9, mod = 998244353;
// n: length of the current test's array; a: the input values (0-based);
// p: indices of a sorted by ascending value; pp: inverse of p (rank of each index).
int n, a[N], p[N], pp[N];
// was[i]: whether index i is currently inserted into the running prefix state.
bool was[N];
// Sum of a[i] over all currently inserted indices.
ll S;
// ans[k]: answer for k sign changes; `tup` is not referenced by the code visible in this solution.
ll ans[N], tup[N];
// Segment tree over ranks: first = sum of max(0, -a) of inserted elements in the
// node's range, second = number of inserted elements in the node's range.
pll t[N << 2];
void upd(int v, int tl, int tr, int pos, bool flag)
{
if (tl == tr)
{
if (flag)
t[v] = {max(0, -a[p[pos]]), 1};
else
t[v] = {0, 0};
return;
}
int tm = (tl + tr) >> 1;
if (pos <= tm)
upd(v << 1, tl, tm, pos, flag);
else
upd(v << 1 | 1, tm + 1, tr, pos, flag);
t[v].f = t[v << 1].f + t[v << 1 | 1].f;
t[v].s = t[v << 1].s + t[v << 1 | 1].s;
}
// Returns the total stored gain of the first `k` inserted elements in rank
// order (leftmost leaves first) inside the subtree of node `v`, which covers
// ranks [tl, tr].
//   k - remaining budget, passed by reference; it is decreased by the number
//       of elements consumed so the caller can continue with what is left.
// A node whose whole element count fits in the budget is taken in one step.
ll get(int v, int tl, int tr, int &k)
{
    if (k == 0)
        return 0;
    if (t[v].second <= k)
    {
        k -= t[v].second;
        return t[v].first;
    }
    const int mid = (tl + tr) >> 1;
    // The two recursive calls share `k`. The operands of `+` may be evaluated
    // in either order, so the calls are sequenced explicitly: the left subtree
    // (smaller ranks) must spend the budget before the right one is visited.
    const ll fromLeft = get(v << 1, tl, mid, k);
    const ll fromRight = get(v << 1 | 1, mid + 1, tr, k);
    return fromLeft + fromRight;
}
void upd(int i)
{
if (was[i])
{
S -= a[i];
upd(1, 0, n - 1, pp[i], 0);
}
else
{
S += a[i];
upd(1, 0, n - 1, pp[i], 1);
}
was[i] ^= 1;
}
// Value of the current state when up to `k` elements may be taken from the
// tree: the plain sum S plus twice the gain reported by the tree query.
ll get(int k)
{
    int budget = k;
    const ll gain = get(1, 0, n - 1, budget);
    return S + 2 * gain;
}
// Divide and conquer over the number of sign changes.
// Fills ans[k] for every k in [l, r], considering only prefix lengths in [bl, br].
// The calls below are arranged so that, on entry, exactly the indices that were
// inserted by the enclosing calls (those below bl) are in the running state, and
// the state is put back to that condition before returning.
void solve(int l, int r, int bl, int br)
{
if (l > r)
return;
int m = (l + r) >> 1;
// Candidate: prefix of length bl (nothing from [bl, br) inserted yet).
pll res = {get(m), bl};
for (int i = bl; i < br; i++)
{
// Insert index i, which makes the prefix length i + 1.
upd(i);
// Pairs compare by value first, then by prefix length, so on equal values
// the longer prefix is kept.
res = max(res, {get(m), i + 1});
}
// Toggle the same indices again to remove them from the state.
for (int i = bl; i < br; i++)
upd(i);
ans[m] = res.f;
int bm = res.s;
// Smaller k: prefix lengths restricted to [bl, bm].
solve(l, m - 1, bl, bm);
// Insert [bl, bm) so the right half starts from a prefix of length bm.
for (int i = bl; i < bm; i++)
upd(i);
// Larger k: prefix lengths restricted to [bm, br].
solve(m + 1, r, bm, br);
// Remove [bl, bm) again, restoring the state seen on entry.
for (int i = bl; i < bm; i++)
upd(i);
}
// Handles one test case: reads n and the array, builds the value ordering
// (p = indices sorted by ascending a, pp = rank of each index), runs the
// divide and conquer for k in [0, n] over prefix lengths [0, n], and prints
// ans[0..n] separated by single spaces on one line.
void solve()
{
    cin >> n;
    for (int i = 0; i < n; i++)
    {
        cin >> a[i];
        p[i] = i;
    }

    sort(p, p + n, [](int lhs, int rhs) { return a[lhs] < a[rhs]; });
    for (int rank = 0; rank < n; rank++)
        pp[p[rank]] = rank;

    solve(0, n, 0, n);

    for (int k = 0; k <= n; k++)
    {
        if (k)
            cout << " ";
        cout << ans[k];
    }
    // Plain newline instead of endl: no forced flush after every test case.
    cout << '\n';
}
// Entry point: reads the number of test cases and solves each one in turn.
int main()
{
    ios_base::sync_with_stdio(false);
    // Untie cin from cout so that reading input does not flush pending output.
    cin.tie(nullptr);

    // Named `tests` rather than `t` to avoid shadowing the global segment tree.
    int tests = 1;
    cin >> tests;
    for (int done = 0; done < tests; done++)
        solve();
}
Tester's Solution
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define fi first
#define se second
const ll mod=998244353;
const int N=1e5+1;
// Length of the current test's array.
int n;
// Input values, read into a[1..n]; a[0] is never written by the visible code.
ll a[N];
// Number of segment-tree node slots.
const int ts=262144;
// nx[id]: maximum a over "chosen" leaves in the node's range (-2e9 when none).
ll nx[ts];
// fn[id]: (minimum a, its index) over not-chosen leaves in the node's range ({2e9,0} when none).
pair<ll,int>fn[ts];
// cnt[id]: number of chosen leaves; s[id]: sum over leaves of a (not chosen) or |a| (chosen).
ll cnt[ts],s[ts];
// Walks down from node `id` (covering [l,r]) and returns the largest index
// whose leaf is chosen (cnt > 0). The right child is preferred whenever it
// holds a chosen leaf; if nothing in the range is chosen the walk keeps
// turning left and ends at l.
int getr(int id,int l,int r){
    while(l!=r){
        const int mid=(l+r)/2;
        const int rightChild=id*2+1;
        if(cnt[rightChild]>0){
            id=rightChild;
            l=mid+1;
        }
        else{
            id=id*2;
            r=mid;
        }
    }
    return l;
}
// Range minimum of fn over [ql,qr]: the (value, index) pair of the smallest
// not-chosen element. Node `id` covers [l,r]; parts outside the query
// contribute the neutral pair {2e9, 0}.
pair<ll,int> qmn(int id,int l,int r,int ql,int qr){
    const bool disjoint = qr<l || r<ql;
    if(disjoint) return make_pair((ll)2e9,0);

    const bool covered = ql<=l && r<=qr;
    if(covered) return fn[id];

    const int mid=(l+r)/2;
    const pair<ll,int> leftBest=qmn(id*2,l,mid,ql,qr);
    const pair<ll,int> rightBest=qmn(id*2+1,mid+1,r,ql,qr);
    return rightBest<leftBest ? rightBest : leftBest;
}
// Recomputes every aggregate of node `id` from its two children:
// maximum of nx, minimum of fn, and the sums of cnt and s.
void pull(int id){
    const int lc=id*2;
    const int rc=lc+1;
    nx[id]= nx[lc]<nx[rc] ? nx[rc] : nx[lc];
    fn[id]= fn[rc]<fn[lc] ? fn[rc] : fn[lc];
    cnt[id]=cnt[lc]+cnt[rc];
    s[id]=s[lc]+s[rc];
}
void upd(int id,int l,int r,int p,int v){
if(l==r){
if(v==0){
nx[id]=-2e9;fn[id]={a[l],l};
cnt[id]=0;s[id]=a[l];
}
else{
nx[id]=a[l];fn[id]={2e9,0};
cnt[id]=1;s[id]=max(a[l],-a[l]);
}
return;
}
int mid=(l+r)/2;
if(p<=mid) upd(id*2,l,mid,p,v);
else upd(id*2+1,mid+1,r,p,v);
pull(id);
}
void pop(int id,int l,int r){
if(l==r){
nx[id]=-2e9;fn[id]={a[l],l};
cnt[id]=0;s[id]=a[l];
return;
}
int mid=(l+r)/2;
if(nx[id*2]>nx[id*2+1]) pop(id*2,l,mid);
else pop(id*2+1,mid+1,r);
pull(id);
}
void build(int id,int l,int r){
if(l==r){
nx[id]=-2e9;fn[id]={a[l],l};
cnt[id]=0;s[id]=a[l];
return;
}
int mid=(l+r)/2;
build(id*2,l,mid);
build(id*2+1,mid+1,r);
pull(id);
}
// Sum of s over the indices in [ql,qr]; node `id` covers [l,r].
ll qry(int id,int l,int r,int ql,int qr){
    const bool disjoint = qr<l || r<ql;
    if(disjoint) return 0;

    const bool covered = ql<=l && r<=qr;
    if(covered) return s[id];

    const int mid=(l+r)/2;
    const ll leftSum=qry(id*2,l,mid,ql,qr);
    const ll rightSum=qry(id*2+1,mid+1,r,ql,qr);
    return leftSum+rightSum;
}
// ans[k]: best value found for exactly index k of the output (k sign changes).
ll ans[N];
// Divide and conquer over the number of sign changes.
// Fills ans[k] for k in [l,r], scanning prefix end positions from
// max(gl,mid) up to gr. The segment tree is a single shared structure that is
// re-adjusted at the start of each call rather than rebuilt.
void solve(int l,int r,int gl,int gr){
if(l>r) return;
int mid=(l+r)/2;
// First prefix end that is examined; positions below mid are skipped
// (presumably because a shorter prefix cannot beat position mid when mid
// changes are available - confirm against the editorial argument).
int st=max(gl,mid);
// Un-choose every chosen index to the right of st.
while(true){
int x=getr(1,0,n);
if(x<=st) break;
else upd(1,0,n,x,0);
}
// Too many chosen: drop the chosen element with the largest a each time.
while(cnt[1]>mid){
pop(1,0,n);
}
// Too few chosen: choose the not-chosen element with the smallest a in [0,st].
while(cnt[1]<mid){
int x=qmn(1,0,n,0,st).se;
upd(1,0,n,x,1);
}
// Value for the prefix ending at st.
ll best=qry(1,0,n,0,st);int pos=st;
for(int i=st+1; i<=gr ;i++){
// Extend the prefix by i: choose it, then drop the chosen element with the
// largest a, which keeps the number of chosen elements unchanged.
upd(1,0,n,i,1);
pop(1,0,n);
ll cur=qry(1,0,n,0,i);
// Strict comparison: on equal values the earlier position is kept.
if(cur>best){
best=cur;pos=i;
}
}
//cout << "!! " << mid << ' ' << best << endl;
ans[mid]=best;
// Smaller k uses positions up to pos, larger k uses positions from pos.
solve(l,mid-1,gl,pos);
solve(mid+1,r,pos,gr);
}
// Handles one test case: reads n and a[1..n], rebuilds the segment tree over
// [0,n], runs the divide and conquer for k in [0,n] and prints ans[0..n], each
// followed by a space, then a newline.
void solve(){
    cin >> n;
    int idx=1;
    while(idx<=n){
        cin >> a[idx];
        ++idx;
    }

    build(1,0,n);
    solve(0,n,0,n);

    for(int k=0;k<=n;++k){
        cout << ans[k] << ' ';
    }
    cout << '\n';
}
// Entry point: unsynchronised, untied stream setup, then one solve() per test case.
int main(){
    ios::sync_with_stdio(false);
    cin.tie(0);

    int cases;
    cin >> cases;
    while(cases--){
        solve();
    }
}
Editorialist's Solution
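//I don't want to get out of the futon any more
//because it is too cold outside the futon
//I don't want to get out of the futon
//because it is too warm inside the futon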
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define ll long long
#define ii pair<ll,ll>
#define iii pair<ii,ll>
#define fi first
#define se second
#define endl '\n'
#define debug(x) cout << #x << ": " << x << endl
#define pub push_back
#define pob pop_back
#define puf push_front
#define pof pop_front
#define lb lower_bound
#define ub upper_bound
#define rep(x,start,end) for(int x=(start)-((start)>(end));x!=(end)-((start)>(end));((start)<(end)?x++:x--))
#define all(x) (x).begin(),(x).end()
#define sz(x) (int)(x).size()
mt19937 rng(chrono::system_clock::now().time_since_epoch().count());
// Multiset of values with a movable limit `lim`, keeping `sum` equal to the
// total of the min(lim, size) largest stored values.
//   big   - the values counted in `sum` (ascending, so begin() is the smallest)
//   small - all remaining values (descending, so begin() is the largest)
// `long long` is written out explicitly; it is the type `int` expands to in
// this solution. `lim` defaults to 0 so an instance is usable before the first
// change(), and del() of a value that is not stored is a no-op.
struct KHEAP{
    long long sum=0;
    long long lim=0;
    std::multiset<long long> big;
    std::multiset<long long,std::greater<long long> > small;

    // Restores the invariant after `lim` or the contents changed: first push
    // the smallest counted values out while `big` is over the limit, then pull
    // the largest uncounted values in while there is room.
    void proc(){
        while (!big.empty() && static_cast<long long>(big.size())>lim){
            const auto lowest=big.begin();
            sum-=*lowest;
            small.insert(*lowest);
            big.erase(lowest);
        }
        while (static_cast<long long>(big.size())<lim && !small.empty()){
            const auto highest=small.begin();
            sum+=*highest;
            big.insert(*highest);
            small.erase(highest);
        }
    }

    // Sets how many of the largest values are counted in `sum`.
    void change(long long _lim){
        lim=_lim;
        proc();
    }

    // Stores one more copy of value `i`.
    void add(long long i){
        big.insert(i);
        sum+=i;
        proc();
    }

    // Removes one copy of value `i`, preferring a counted copy; does nothing
    // if the value is not stored at all.
    void del(long long i){
        const auto counted=big.find(i);
        if (counted!=big.end()){
            big.erase(counted);
            sum-=i;
        }
        else{
            const auto uncounted=small.find(i);
            if (uncounted!=small.end()) small.erase(uncounted);
        }
        proc();
    }
} kheap;
// Length of the current test's array.
int n;
// Input values, read into arr[1..n].
int arr[100005];
// pref[x] = arr[1] + ... + arr[x]; pref[0] is never written and stays 0.
int pref[100005];
// ans[k]: answer for k sign changes.
int ans[100005];
// Divide and conquer over the number of sign changes.
// Fills ans[k] for k in [l,r], trying prefix lengths optl..optr for the middle k.
// The shared `kheap` holds the magnitudes (-arr[x]) of negative elements; the
// loops below add positions (optl, optr] and later delete exactly those again,
// so the heap contents on return match the contents on entry.
void dnc(int l,int r,int optl,int optr){
int m=l+r>>1;
// NOTE(review): sentinel assumes at least one candidate exceeds -1e9;
// otherwise optm would stay -1.
int best=-1e9;
int optm=-1;
// Count the m largest magnitudes in kheap.sum.
kheap.change(m);
rep(x,optl,optr+1){
// Position optl itself is skipped here: it is not added by this call.
if (x!=optl && arr[x]<0) kheap.add(-arr[x]);
// Prefix sum plus twice the counted magnitudes.
int curr=pref[x]+2*kheap.sum;
// Strict comparison: on equal values the smaller prefix length is kept.
if (best<curr){
best=curr;
optm=x;
}
}
ans[m]=best;
// Drop positions (optm, optr] before handling larger k with prefix lengths optm..optr.
rep(x,optm+1,optr+1) if (arr[x]<0) kheap.del(-arr[x]);
if (m!=r) dnc(m+1,r,optm,optr);
// Drop positions (optl, optm] before handling smaller k with prefix lengths optl..optm.
rep(x,optl+1,optm+1) if (arr[x]<0) kheap.del(-arr[x]);
if (l!=m) dnc(l,m-1,optl,optm);
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
cin.exceptions(ios::badbit | ios::failbit);
int TC;
cin>>TC;
while (TC--){
cin>>n;
rep(x,1,n+1) cin>>arr[x];
rep(x,1,n+1) pref[x]=pref[x-1]+arr[x];
dnc(0,n,0,n);
rep(x,0,n+1) cout<<ans[x]<<" "; cout<<endl;
}
}