MAXPREFFLIP - Editorial

PROBLEM LINK:

Practice
Contest

Author: Akib Tonnoy, Alex Danilyuk
Preparer: Yahor Dubovik
Tester: Harris Leung
Editorialist: Ashley Khoo

DIFFICULTY:

3013

Prerequisites:

Data Structures, Divide and Conquer

Problem:

You are given an array A consisting of N integers.

For each K from 0 to N, find the maximum prefix sum of A after changing the sign of at most K elements of the array. (The empty prefix, whose sum is 0, is also a candidate.)

Explanation:

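For a fixed K, the best prefix sum of the range [1,P] is c(K,P) = S_P - 2(\text{the sum of the } K \text{ smallest negative elements of the range } [1,P]), where S_P = A_1 + A_2 + \dots + A_P is the original prefix sum (if [1,P] contains fewer than K negative elements, take all of them). This is true since flipping an element A_i changes the sum by -2A_i, which is a gain only when A_i < 0, and the gain is largest for the smallest (most negative) elements.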

For a fixed K, we have to find the value of P where c(K,P) is maximised. Call this value P_K. If there are multiple P where c(K,P) is maximised, then we choose the minimum such P to be P_K (this is just an arbitrary choice and it does not really matter).

Claim: P_K \leq P_{K+1}

Proof

We prove this by contradiction, which is the usual technique for such inequalities. Assume that P_K > P_{K+1}.

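Let L_j and R_j denote twice the sum of the j smallest negative elements in the ranges [1,P_{K+1}] and [P_{K+1}+1,P_K] respectively (taking all negative elements of the range if it has fewer than j). Note that L and R are both convex in j. The increment L_{j+1} - L_j is twice the (j+1)-th smallest negative element (or 0 once they run out), and these increments are non-decreasing in j. The same holds for R.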

Suppose that in the optimal solution of c(K,P_K), we flipped K_1 elements in [1,P_{K+1}] and K_2 elements in [P_{K+1}+1,P_K], where K_1 + K_2 \leq K. Since P_K is the smallest maximiser of c(K,\cdot) and P_{K+1} < P_K, we have c(K,P_{K+1}) < c(K,P_K). That is, c(K,P_{K+1}) = S_{P_{K+1}}-L_{K} < S_{P_K} - L_{K_1}-R_{K_2} = c(K,P_K).

Now, c(K+1,P_K) \geq S_{P_K} - L_{K_1+1}-R_{K_2}. We can achieve this value by flipping K_1+1 elements in [1,P_{K+1}] and K_2 elements in [P_{K+1}+1,P_K] (at most K+1 flips in total). Since L is convex and K_1 \leq K, we have L_{K} - L_{K+1} \leq L_{K_1} - L_{K_1+1}.

First add the last inequality to the strict inequality above. Then apply the lower bound on c(K+1,P_K). Together these give c(K+1,P_{K+1}) = S_{P_{K+1}}-L_{K+1} < S_{P_K} - L_{K_1 +1}-R_{K_2} \leq c(K+1,P_K). But c(K+1,P_{K+1}) < c(K+1,P_K) contradicts the fact that P_{K+1} maximises c(K+1,\cdot).

Therefore, it must be true that P_K \leq P_{K+1}. \blacksquare

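Because of this property, one can use a divide and conquer approach. Define a function solve(l,r,optl,optr) that computes the answers for every K \in [l,r], given that optl \leq P_K \leq optr for all of them. For m = \lfloor \frac{l+r}{2} \rfloor, we find P_m by trying every P \in [optl,optr]. We then recurse with solve(l,m-1,optl,P_m) and solve(m+1,r,P_m,optr). Each level of the recursion examines O(N) candidate prefixes in total, and there are O(\log N) levels.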

Now, we will be able to solve the problem if we can compute c(K,P) for arbitrary K and P quickly. This boils down to finding the sum of the K smallest negative elements of the prefix [1,P], which can be done in O(\log N) using wavelet trees. However, there is an easier way to accomplish this (in the humble opinion of the editorialist). Within one call of solve, the candidates P are scanned in increasing order, so we only need to insert elements one at a time and delete them again afterwards.

Consider a data structure that stores a (multi)set S and an integer K that can handle the following operations all in logarithmic time:

  • increment/decrement K by 1
  • insert/delete element from S
  • find the sum of the smallest \min(K,|S|) elements of S

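This data structure can be implemented by maintaining two (multi)sets A and B where \max(A) \leq \min(B) and either |A|=K, or |A|< K and B=\varnothing. The answer to a query is then the sum of the elements in A, which we keep as a running total. After each operation, restore the invariant as follows. While |A| > K, move the largest element of A into B. While |A| < K and B \neq \varnothing, move the smallest element of B into A. Every operation moves at most one element.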

Time Complexity

O(N \log^2 N) per test case.

Code:

Preparer's Code
#include <bits/stdc++.h>

// Short-hand macros relied on by the rest of this solution.
#define f first
#define s second
#define pb push_back
#define mp make_pair

using namespace std;

using ll = long long;
using pii = pair<int, int>;
using pll = pair<long long, long long>;

// Capacity of every array below, plus contest constants (inf and mod are not
// used by the visible code).
const int N = 300500;
const int inf = 1e9;
const int mod = 998244353;

int n;       // number of elements in the current test
int a[N];    // the input array, 0-indexed
int p[N];    // p[r]: index of the element with rank r when sorted by value
int pp[N];   // pp[i]: rank of element i (inverse permutation of p)
bool was[N]; // was[i]: element i is currently part of the active prefix

ll S;        // plain sum of the active elements
ll ans[N];   // ans[K]: answer for at most K sign flips
ll tup[N];   // not used by the visible code

// Segment tree indexed by rank; every node stores
// {sum of max(0, -value) over active elements, number of active elements}.
pll t[N << 2];

// Point update on the rank segment tree: switches leaf `pos` on (flag = true)
// or off (flag = false), then recomputes the aggregates on the way back up.
void upd(int v, int tl, int tr, int pos, bool flag)
{
    if (tl != tr)
    {
        const int mid = (tl + tr) >> 1;
        const int lc = v << 1;
        const int rc = lc | 1;
        if (pos > mid)
            upd(rc, mid + 1, tr, pos, flag);
        else
            upd(lc, tl, mid, pos, flag);
        t[v] = {t[lc].f + t[rc].f, t[lc].s + t[rc].s};
        return;
    }
    // Leaf. An inactive element contributes nothing.
    if (!flag)
    {
        t[v] = {0, 0};
        return;
    }
    // An active element contributes max(0, -value) and a count of one.
    t[v] = {max(0, -a[p[pos]]), 1};
}
// Returns the sum of t[].f over the first k active leaves (in rank order,
// i.e. smallest values first) inside node v's segment [tl, tr].
// k is passed by reference and is decremented by the number of active
// leaves consumed, so the caller can chain calls across siblings.
ll get(int v, int tl, int tr, int &k)
{
    if (k == 0)
        return 0;
    if (t[v].s <= k)
    {
        // The whole segment fits into the remaining budget.
        k -= t[v].s;
        return t[v].f;
    }
    int tm = (tl + tr) >> 1;
    // Both recursive calls update k, and the left (smaller-rank) child must be
    // consumed first. The evaluation order of the operands of `+` is
    // unspecified, so the order is fixed with a separate statement.
    ll res = get(v << 1, tl, tm, k);
    res += get(v << 1 | 1, tm + 1, tr, k);
    return res;
}
// Toggles element i in or out of the active prefix, keeping the running sum S
// and leaf pp[i] of the rank segment tree consistent with was[i].
void upd(int i)
{
    const bool activate = !was[i];
    if (activate)
        S += a[i];
    else
        S -= a[i];
    upd(1, 0, n - 1, pp[i], activate);
    was[i] = activate;
}
// Value of the active prefix when up to k of its smallest elements are
// flipped: its plain sum plus twice the k largest flip gains.
ll get(int k)
{
    const ll bestGain = get(1, 0, n - 1, k);
    return S + bestGain * 2;
}
// Divide and conquer over K. Computes ans[K] for every K in [l, r], given that
// the optimal prefix length for each of them lies in [bl, br]. The active
// prefix (elements [0, bl)) is the same on exit as on entry.
void solve(int l, int r, int bl, int br)
{
    if (r < l)
        return;

    const int m = (l + r) >> 1;

    // Try every prefix length in [bl, br] for K = m. Ties go to the longer
    // prefix, matching a lexicographic max over {value, length}.
    long long bestVal = get(m);
    int bestLen = bl;
    for (int len = bl + 1; len <= br; ++len)
    {
        upd(len - 1);
        const long long cur = get(m);
        if (cur >= bestVal)
        {
            bestVal = cur;
            bestLen = len;
        }
    }
    // Undo the scan: toggle [bl, br) back to inactive.
    for (int idx = bl; idx < br; ++idx)
        upd(idx);

    ans[m] = bestVal;

    // Smaller K: optimum lies in [bl, bestLen]; the active prefix already fits.
    solve(l, m - 1, bl, bestLen);

    // Larger K: optimum lies in [bestLen, br]. Activate [bl, bestLen) around
    // the call so it also starts with exactly its lower bound active.
    for (int idx = bl; idx < bestLen; ++idx)
        upd(idx);
    solve(m + 1, r, bestLen, br);
    for (int idx = bl; idx < bestLen; ++idx)
        upd(idx);
}

void solve()
{

    cin >> n;
    for (int i = 0; i < n; i++)
        cin >> a[i];


    for (int i = 0; i < n; i++)
        p[i] = i;
    sort(p, p + n, [](int i, int j)
         { return a[i] < a[j]; });
    for (int i = 0; i < n; i++)
        pp[p[i]] = i;

    solve(0, n, 0, n);

    for (int i = 0; i <= n; i++)
    {
        if (i)
            cout << " ";
        cout << ans[i];
    }
    cout << endl;
}
// Reads the number of test cases and solves each one independently.
int main()
{
    ios_base::sync_with_stdio(false);
    int tests = 1;
    cin >> tests;
    while (tests-- > 0)
        solve();
}
Tester's Solution
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
#define fi first
#define se second

const ll mod = 998244353; // not used by the visible code
const int N = 1e5 + 1;    // capacity for a[1..n] and ans[0..n]

int n;   // array length of the current test
ll a[N]; // input, 1-indexed; a[0] is never written and stays 0

// Segment tree over positions 0..n (2^18 node slots, enough for up to 2^17 leaves).
const int ts = 262144;
ll nx[ts];            // max value among "flipped" positions (-2e9 if none)
pair<ll, int> fn[ts]; // {min value, position} among unflipped positions
ll cnt[ts];           // number of flipped positions
ll s[ts];             // sum of contributions: |a| if flipped, a otherwise
// Returns the rightmost flipped position in [l, r] below node id. If nothing
// is flipped, the walk keeps going left and returns l.
int getr(int id,int l,int r){
	while(l!=r){
		int mid=(l+r)/2;
		if(cnt[id*2+1]>0){
			id=id*2+1;
			l=mid+1;
		}
		else{
			id=id*2;
			r=mid;
		}
	}
	return l;
}
// Minimum {value, position} over the unflipped positions in [ql, qr].
// {2e9, 0} is the neutral element for segments outside the query.
pair<ll,int> qmn(int id,int l,int r,int ql,int qr){
	if(qr<l || ql>r) return {static_cast<ll>(2e9),0};
	if(l>=ql && r<=qr) return fn[id];
	int mid=(l+r)/2;
	pair<ll,int> lhs=qmn(id*2,l,mid,ql,qr);
	pair<ll,int> rhs=qmn(id*2+1,mid+1,r,ql,qr);
	return min(lhs,rhs);
}
// Recomputes every aggregate of node id from its two children.
void pull(int id){
	const int lc=id*2, rc=lc+1;
	nx[id]=max(nx[lc],nx[rc]);
	fn[id]=min(fn[lc],fn[rc]);
	cnt[id]=cnt[lc]+cnt[rc];
	s[id]=s[lc]+s[rc];
}
// Marks position p as flipped (v != 0) or unflipped (v == 0).
void upd(int id,int l,int r,int p,int v){
	if(l!=r){
		int mid=(l+r)/2;
		if(p>mid) upd(id*2+1,mid+1,r,p,v);
		else upd(id*2,l,mid,p,v);
		pull(id);
		return;
	}
	if(v){
		// Flipped: visible to pop via nx, contributes |a[l]|.
		nx[id]=a[l];fn[id]={2e9,0};
		cnt[id]=1;s[id]=max(a[l],-a[l]);
	}
	else{
		// Unflipped: visible to qmn via fn, contributes a[l].
		nx[id]=-2e9;fn[id]={a[l],l};
		cnt[id]=0;s[id]=a[l];
	}
}
// Un-flips the flipped position with the largest value; on equal maxima
// the right child is chosen.
void pop(int id,int l,int r){
	if(l==r){
		nx[id]=-2e9;fn[id]={a[l],l};
		cnt[id]=0;s[id]=a[l];
		return;
	}
	int mid=(l+r)/2;
	if(nx[id*2+1]>=nx[id*2]) pop(id*2+1,mid+1,r);
	else pop(id*2,l,mid);
	pull(id);
}
// Initialises every position in [l, r] as unflipped.
void build(int id,int l,int r){
	if(l<r){
		int mid=(l+r)/2;
		build(id*2,l,mid);
		build(id*2+1,mid+1,r);
		pull(id);
		return;
	}
	nx[id]=-2e9;fn[id]={a[l],l};
	cnt[id]=0;s[id]=a[l];
}
// Sum of the contributions s over positions [ql, qr].
ll qry(int id,int l,int r,int ql,int qr){
	if(qr<l || r<ql) return 0;
	if(ql<=l && r<=qr) return s[id];
	int mid=(l+r)/2;
	ll total=qry(id*2,l,mid,ql,qr);
	total+=qry(id*2+1,mid+1,r,ql,qr);
	return total;
}
ll ans[N] = {}; // ans[K]: answer for at most K flips, K = 0..n
// Divide and conquer over K: computes ans[K] for every K in [l, r], given that
// an optimal prefix length for each of them lies in [gl, gr].
// Tree model: "flipped" positions contribute |a| to s[], the others a, so
// qry(0, P) is the prefix value when exactly the flipped positions are flipped.
// The tree is NOT reset between calls: each call reshapes the flipped set it
// inherits from the previously executed call.
void solve(int l,int r,int gl,int gr){
	if(l>r) return;
	int mid=(l+r)/2;
	// Prefixes shorter than mid are never better than the prefix of length mid:
	// with mid flips every element of such a prefix can be made non-negative.
	// So the scan may start at st = max(gl, mid).
	int st=max(gl,mid);
	// Un-flip every flipped position beyond st (getr returns the rightmost one).
	while(true){
		int x=getr(1,0,n);
		if(x<=st) break;
		else upd(1,0,n,x,0);
	}
	// Too many flips: drop the flipped positions with the largest values.
	while(cnt[1]>mid){
		pop(1,0,n);
	}
	// Too few flips: flip the smallest-valued unflipped positions in [0, st].
	while(cnt[1]<mid){
		int x=qmn(1,0,n,0,st).se;
		upd(1,0,n,x,1);
	}
	// Exactly mid positions of [0, st] are now flipped. This is intended to be the
	// mid smallest values there (relies on the inherited state also being a
	// "smallest values" set), so qry(0, st) is the value of prefix st.
	ll best=qry(1,0,n,0,st);int pos=st;
	for(int i=st+1; i<=gr ;i++){
		// Extend the prefix by i: flip it, then un-flip the largest flipped
		// value, keeping exactly mid flips on the smallest values.
		upd(1,0,n,i,1);
		pop(1,0,n);
		ll cur=qry(1,0,n,0,i);
		// Strict '>' keeps the smallest maximising prefix length.
		if(cur>best){
			best=cur;pos=i;
		}
	}
	// Debug output (disabled).
	//cout << "!! " << mid << ' ' << best << endl;
	ans[mid]=best;
	// Smaller K have their optimum <= pos; larger K have it >= pos.
	solve(l,mid-1,gl,pos);
	solve(mid+1,r,pos,gr);
}
// Reads one test case into a[1..n], rebuilds the tree over positions 0..n,
// and prints ans[0..n], each value followed by a space.
void solve(){
	cin >> n;
	for(int idx=1; idx<=n; ++idx)
		cin >> a[idx];
	build(1,0,n);   // every position starts unflipped
	solve(0,n,0,n);
	for(int k=0; k<=n; ++k)
		cout << ans[k] << ' ';
	cout << '\n';
}
// Fast I/O, then one call of solve() per test case.
int main(){
	ios::sync_with_stdio(false);
	cin.tie(nullptr);
	int tests;
	cin >> tests;
	for(; tests!=0; --tests) solve();
}
Editorialist's Solution
//もう布団の中から出たくない
//布団の外は寒すぎるから
//布団の中から出たくない
//布団の中はあたたかすぎるから

#include <bits/stdc++.h>
using namespace std;

// Every `int` after this line is a 64-bit `long long` (hence `signed main`
// below); this applies to all declarations that follow.
#define int long long
#define ll long long
#define ii pair<ll,ll>
#define iii pair<ii,ll>
#define fi first
#define se second
// `endl` becomes a plain newline character, so it does not flush the stream.
#define endl '\n'
// debug(x): prints the expression text and its value.
#define debug(x) cout << #x << ": " << x << endl

// Container shorthands.
#define pub push_back
#define pob pop_back
#define puf push_front
#define pof pop_front
#define lb lower_bound
#define ub upper_bound

// rep(x, start, end): when start < end, x runs upward over start, ..., end-1;
// when start > end, x runs downward over start-1, ..., end; when start == end
// the body never runs. The arguments are re-evaluated on every iteration.
#define rep(x,start,end) for(int x=(start)-((start)>(end));x!=(end)-((start)>(end));((start)<(end)?x++:x--))
// all(x): begin/end iterator pair; sz(x): size cast to `int` (long long here).
#define all(x) (x).begin(),(x).end()
#define sz(x) (int)(x).size()

// Random generator seeded from the system clock; not used by the visible code.
mt19937 rng(chrono::system_clock::now().time_since_epoch().count());

// Stores a multiset of values and maintains `sum` = the sum of the `lim`
// largest of them (or of all of them, if there are fewer than `lim`).
//   big   : the (at most) lim largest values; their total is `sum`
//   small : all remaining values, ordered largest-first
// Invariant after every public call: every value in big is >= every value in
// small, and either |big| == lim, or |big| < lim and small is empty.
struct KHEAP{
	// lim is explicitly initialised so non-static instances are well-defined.
	int sum=0,lim=0;
	std::multiset<int> big;
	std::multiset<int,std::greater<int> > small;

	// Restores the invariant by moving boundary values between the two sets.
	void proc(){
		while (static_cast<int>(big.size())>lim){
			auto it=big.begin(); // smallest value currently counted
			sum-=*it;
			small.insert(*it);
			big.erase(it);
		}
		while (static_cast<int>(big.size())<lim && !small.empty()){
			auto it=small.begin(); // largest value not yet counted
			sum+=*it;
			big.insert(*it);
			small.erase(it);
		}
	}

	// Changes how many of the largest values are summed.
	void change(int _lim){
		lim=_lim;
		proc();
	}

	// Inserts value i.
	void add(int i){
		big.insert(i),sum+=i;
		proc();
	}

	// Removes one copy of value i; i must currently be stored.
	void del(int i){
		auto it=big.find(i);
		if (it!=big.end()){
			big.erase(it);
			sum-=i;
		}
		else{
			auto jt=small.find(i);
			// Erasing end() would be undefined behaviour: catch the caller bug.
			assert(jt!=small.end() && "KHEAP::del: value not present");
			small.erase(jt);
		}
		proc();
	}
} kheap;

const int MAXN=100005; // array capacity (indices 0..n; presumably n <= 1e5)
int n;
int arr[MAXN];  // input, 1-indexed
int pref[MAXN]; // pref[x] = arr[1] + ... + arr[x]; pref[0] is never written and stays 0
int ans[MAXN];  // ans[K] = answer for at most K flips

// Divide and conquer over K: fills ans[K] for every K in [l, r], given that the
// smallest prefix length maximising c(K, .) lies in [optl, optr] for each of them.
// On entry kheap must hold -arr[x] for every negative arr[x] with 1 <= x <= optl.
// The same contents are restored before returning.
void dnc(int l,int r,int optl,int optr){
	int m=(l+r)>>1;
	// Value-independent sentinel: every candidate beats it, so optm is always
	// set (a fixed -1e9 would leave optm == -1 if all candidates were below it).
	int best=std::numeric_limits<int>::min();
	int optm=-1;

	// Sum only the m largest magnitudes, i.e. flip the m most negative values.
	kheap.change(m);

	// Grow the prefix one element at a time and evaluate c(m, x).
	rep(x,optl,optr+1){
		if (x!=optl && arr[x]<0) kheap.add(-arr[x]);
		int curr=pref[x]+2*kheap.sum;
		// Strict comparison keeps the smallest maximising prefix length.
		if (best<curr){
			best=curr;
			optm=x;
		}
	}

	ans[m]=best;

	// Shrink to [1, optm]; larger K have their optimum in [optm, optr].
	rep(x,optm+1,optr+1) if (arr[x]<0) kheap.del(-arr[x]);
	if (m!=r) dnc(m+1,r,optm,optr);
	// Shrink to [1, optl]; smaller K have their optimum in [optl, optm].
	rep(x,optl+1,optm+1) if (arr[x]<0) kheap.del(-arr[x]);
	if (l!=m) dnc(l,m-1,optl,optm);
}

signed main(){
	// Fast I/O; malformed input throws instead of being silently ignored.
	ios::sync_with_stdio(0);
	cin.tie(0);
	cout.tie(0);
	cin.exceptions(ios::badbit | ios::failbit);

	int TC;
	cin>>TC;
	for (; TC!=0; --TC){
		cin>>n;
		// Read the array and build prefix sums in one pass (pref[0] stays 0).
		rep(x,1,n+1){
			cin>>arr[x];
			pref[x]=pref[x-1]+arr[x];
		}

		dnc(0,n,0,n);

		// Every answer is followed by a space; the line ends with a newline.
		rep(x,0,n+1) cout<<ans[x]<<" ";
		cout<<endl;
	}
}