PARTITION - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: wuhudsm
Tester: satyam_343
Editorialist: iceknight1093

DIFFICULTY:

3000

PREREQUISITES:

Dynamic programming, stacks

PROBLEM:

You have an array A.
The score of a partition of A into k subarrays B_1, B_2, \ldots, B_k is

\sum_{i=1}^k \left(1 - \max(B_i)\right)

Find the maximum score across all partitions of A into subarrays.

EXPLANATION:

Given that we’re splitting into subarrays, there’s a natural dynamic programming setup.
Let dp_i be the maximum score obtainable by partitioning the first i elements of A, with dp_0 = 0 being the base case for the empty prefix.

Then, by considering all possible split points, we have

dp_i = 1 + \max_{0 \leq j\lt i} \left(dp_j - \max(A[j+1:i])\right)

This is easy to compute in \mathcal{O}(N^2), but we need to optimize it further.

Let’s look at the structure of the subarrays we consider.
Suppose we fix an index i. Then, as we move the left endpoint j from i to 1, the value of \max(A[j:i]) never decreases — after all, we’re only adding elements.

So, we have several distinct maximums to consider, and for each maximum, a range of elements where it’s the maximum.
Note that for a fixed maximum M, with corresponding range [L, R], we only really care about the maximum dp_j value for L-1 \leq j \leq R-1.

Let’s keep the following information:

  • The current list of maximums, in descending order from left to right.
    Notice that this can be represented using a stack.
  • The maximum dp value corresponding to the range of each maximum.
  • The set of (max(dp) - M) values across all maximums.
    Let this set be S.

Then, when moving from index i-1 to index i:

  • Push A_i onto the stack, as a new maximum.
    The singular dp value corresponding to it is dp_{i-1}.
  • Then, while the top element of the stack is not less than the previous element, their ranges can be combined into a single one (with A_i being the maximum value of the new range).
    When combining segments, the new maximum dp value of the segment equals the larger of the previous ones.
  • Finally, we keep the set S updated.
    Each time, two elements are removed from the stack and one is added to it.
    This corresponds to deleting two elements from S and inserting one; which can be done quickly using a multiset.
  • Once the above process is finished, we simply have dp_i = 1 + \max(S).

It’s easy to see that \mathcal{O}(N) push/pop operations are done on the stack in total; and each corresponds to one multiset operation for \mathcal{O}(N\log N) in total.

There are also ways to implement this using other data structures, such as a segment tree. A couple of these implementations are linked below; in the author’s and tester’s code.
The base idea of maintaining a stack of maximums remains the same, however.

TIME COMPLEXITY

\mathcal{O}(N \log N) per testcase.

CODE:

Author's code (C++)
#include <map>
#include <set>
#include <cmath>
#include <ctime>
#include <queue>
#include <stack>
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <cstring>
#include <algorithm>
#include <iostream>
using namespace std;
// Short numeric aliases used by the author's solution.
using db  = double;
using ll  = long long;
using ull = unsigned long long;

// Capacity of the global arrays below.
const int N = 1000010;
// LOGN and TMD are not referenced by the code shown in this solution.
const int LOGN = 28;
const ll TMD = 0;
// (2^31 - 1)^2; its negation is the initial "minus infinity" of every segment-tree node.
const ll INF = 2147483647LL * 2147483647LL;

// Number of test cases.
int T;
// Length of the array in the current test case.
int n;
// Input array; main() fills and reads positions 1..n.
int a[N];
// dp[i] = best score of a partition of a[1..i].
ll dp[N];

// An (index, dp value) pair kept in an ordered set alongside the monotonic stack.
struct Data
{
	ll num;   // array index
	ll val;   // dp value stored for that index

	Data(ll num,ll val):num(num),val(val) {}

	// Larger val sorts first; equal vals are ordered by ascending num.
	// With this ordering set::begin() is the entry with the greatest val.
	friend bool operator<(Data x,Data y)
	{
		return x.val==y.val ? x.num<y.num : x.val>y.val;
	}
};

// One node of the pointer-based segment tree.
struct nod
{
	int l;     // left end of the covered range (inclusive)
	int r;     // right end of the covered range (inclusive)
	ll  mx;    // maximum stored in [l, r]
	nod *lc;   // child covering the left half, NULL for a leaf
	nod *rc;   // child covering the right half, NULL for a leaf
};

// Pointer-based range-maximum segment tree over positions 1..n, where n is the
// global array length at construction time. Every position starts at -INF.
struct Segtree
{
	nod *root;
	
	// Builds the complete tree for the current value of the global n.
	Segtree()
	{
		build(&root,1,n);
	}
	
	// Frees every node. One tree is constructed per test case, so without this
	// each test case would leak all of its nodes.
	~Segtree()
	{
		destroy(root);
	}
	
	// The tree owns its nodes through raw pointers; a member-wise copy would
	// make two trees free the same nodes, so copying is disabled.
	Segtree(const Segtree&) = delete;
	Segtree& operator=(const Segtree&) = delete;
	
	// Allocates a childless node covering [L, R] with maximum -INF and stores it in *p.
	void newnod(nod **p,int L,int R)
	{
		*p=new nod;
		(*p)->l=L;(*p)->r=R;
		(*p)->mx=-INF;
		(*p)->lc=(*p)->rc=NULL;
	}
	
	// Recursively creates the subtree covering [L, R] and stores its root in *p.
	void build(nod **p,int L,int R)
	{
		newnod(p,L,R);
		if(L==R) return ;
		int M=(L+R)>>1;
		build(&(*p)->lc,L,M);
		build(&(*p)->rc,M+1,R);
	}
	
	// Recursively deletes the subtree rooted at p; a NULL pointer is ignored.
	void destroy(nod *p)
	{
		if(p==NULL) return ;
		destroy(p->lc);
		destroy(p->rc);
		delete p;
	}
	
	// Overwrites the value at position pos with val.
	void insert(int pos,ll val)
	{
		_insert(root,pos,val);
	}
	
	// Walks down to the leaf for pos, sets it, and recomputes maxima on the way back up.
	void _insert(nod *p,int pos,ll val)
	{
		if(p->l==p->r)
		{
			p->mx=val;
			return ;
		}
		int M=(p->l+p->r)>>1;
		if(pos<=M) _insert(p->lc,pos,val);
		else       _insert(p->rc,pos,val);
		p->mx=max(p->lc->mx,p->rc->mx);
	}
	
	// Returns the maximum over positions L..R (inclusive).
	// An empty range (L > R) yields 0, which main() relies on as the dp value
	// of the empty prefix when it queries getmax(1, 0).
	ll getmax(int L,int R)
	{
		if(L>R) return 0;
		return _getmax(root,L,R);
	}
	
	// Maximum over [L, R]; the range must lie inside the range covered by p.
	ll _getmax(nod *p,int L,int R)
	{
		if(p->l==L&&p->r==R) return p->mx;
		int M=(p->l+p->r)>>1;
		if(R<=M)     return _getmax(p->lc,L,R);
		else if(L>M) return _getmax(p->rc,L,R);
		else         return max(_getmax(p->lc,L,M),_getmax(p->rc,M+1,R));
	}
};

int main()
{
	scanf("%d",&T);
	while(T--)
	{
    	scanf("%d",&n);
    	for(int i=1;i<=n;i++) scanf("%d",&a[i]);
    	Segtree ST;
    	set<Data>  S;
    	stack<int> stk; 
    	for(int i=1;i<=n;i++)
    	{
	    	while((!stk.empty())&&a[stk.top()]<=a[i])
	    	{
	    		S.erase(Data(stk.top(),dp[stk.top()]));
	    		stk.pop();
	    	}
	    	if(!stk.empty()) dp[i]=max(S.begin()->val,ST.getmax(stk.top(),i-1)-a[i]+1);	
	    	else dp[i]=max((ll)-a[i]+1,ST.getmax(1,i-1)-a[i]+1);
	    	ST.insert(i,dp[i]);
	    	stk.push(i);
	    	S.insert(Data(i,dp[i]));
	    }
	    printf("%lld\n",dp[n]);
	}
	
	return 0;
}
Tester's code (C++)
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast,unroll-loops")
 
 
#include <bits/stdc++.h>   
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;   
using namespace std;  
// 64-bit integer shorthand. This is a macro, so it is substituted textually everywhere below.
#define ll long long
// Large sentinel values. INF_ADD and its negation act as +/- infinity in the solution below;
// INF_MUL is not referenced by the code shown here.
const ll INF_MUL=1e13;
const ll INF_ADD=2e18; 
// Abbreviation macros. Note that `f` and `s` rewrite every later identifier token with that name.
#define pb push_back                   
#define mp make_pair          
#define nline "\n"                           
#define f first                                          
#define s second                                             
#define pll pair<ll,ll> 
#define all(x) x.begin(),x.end()     
#define vl vector<ll>             
#define vvl vector<vector<ll>>    
#define vvvl vector<vector<vector<ll>>>          
// debug(x) prints the expression text and its value to cerr unless ONLINE_JUDGE is defined,
// in which case it expands to a lone `;`.
#ifndef ONLINE_JUDGE    
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);    
#endif       
// _print overloads used by debug(): scalars and strings first, containers further down.
void _print(ll x){cerr<<x;}  
void _print(char x){cerr<<x;}     
void _print(string x){cerr<<x;}    
// Clock-seeded random generator; not referenced by the solution code shown below.
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());   
// Prints a pair as {first,second}.
template<class T,class V> void _print(pair<T,V> p) {cerr<<"{"; _print(p.first);cerr<<","; _print(p.second);cerr<<"}";}
// Print the elements of a container, space separated, inside [ ].
template<class T>void _print(vector<T> v) {cerr<<" [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T>void _print(set<T> v) {cerr<<" [ "; for (T i:v){_print(i); cerr<<" ";}cerr<<"]";}
template<class T>void _print(multiset<T> v) {cerr<< " [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T,class V>void _print(map<T, V> v) {cerr<<" [ "; for(auto i:v) {_print(i);cerr<<" ";} cerr<<"]";} 
// GNU pb_ds order-statistics trees; none of these aliases is used by the solution code shown below.
// NOTE(review): ordered_multiset uses less_equal, which is not a strict weak ordering;
// confirm how lookup and erase behave before relying on it.
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
//--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
// Template constants; neither is referenced by the solution code shown below.
const ll MOD=1e9+7;     
const ll MAX=500500; 
class ST{
public:
    vector<ll> segs;
    ll size=0;                       
    ll ID=-INF_ADD;
 
    ST(ll sz) {
        segs.assign(2*sz,ID);
        size=sz;  
    }   
   
    ll comb(ll a,ll b) {
        return max(a,b);  
    }
 
    void upd(ll idx, ll val) {
        segs[idx+=size]=val;
        for(idx/=2;idx;idx/=2){
            segs[idx]=comb(segs[2*idx],segs[2*idx+1]);
        }
    }
 
    ll query(ll l,ll r) {
        ll lans=ID,rans=ID;
        for(l+=size,r+=size+1;l<r;l/=2,r/=2) {
            if(l&1) {
                lans=comb(lans,segs[l++]);
            }
            if(r&1){
                rans=comb(segs[--r],rans);
            }  
        }  
        return comb(lans,rans);
    }
};
void solve(){     
    ll n; cin>>n;
    vector<ll> dp(n+5),a(n+5,INF_ADD),track={0};
    ST get_dp(n+5),active(n+5);  
    get_dp.upd(0,0); 
    for(ll i=1;i<=n;i++){
        cin>>a[i]; 
        while(a[track.back()]<a[i]){
            active.upd(track.back(),-INF_ADD); 
            track.pop_back();
        }
        dp[i]=max(active.query(1,n),1-a[i]+get_dp.query(track.back(),n));
        active.upd(i,dp[i]);
        get_dp.upd(i,dp[i]); 
        track.push_back(i);
    }
    cout<<dp[n]<<nline;
    return;                                
}                                                  
// Entry point of the tester's solution: sets up fast I/O, optionally redirects
// the standard streams to local files, runs every test case and reports CPU time.
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    // Local runs (ONLINE_JUDGE not defined) read and write files instead of the console.
    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    freopen("error.txt", "w", stderr);
    #endif
    ll test_cases=1;
    cin>>test_cases;
    for(;test_cases!=0;--test_cases){
        solve();
    }
    const double elapsed_ms=1000*((double)clock())/(double)CLOCKS_PER_SEC;
    cerr<<"Time:"<<elapsed_ms<<"ms\n";
}
Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
// 64-bit signed integer alias used for array values and scores.
using ll = long long int;
// 64-bit Mersenne Twister seeded from the clock; not referenced by the solution code shown below.
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

	int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        ll ans = 0;
        vector<array<ll, 2>> st;
        map<ll, int> vals;
        for (int i = 0; i < n; ++i) {
            int x; cin >> x;
            st.push_back({x, ans});
            ++vals[ans - x];
            int sz = st.size();
            while (sz > 1) {
                if (st[sz-1][0] >= st[sz-2][0]) {
                    auto [mx1, val1] = st.back(); st.pop_back();
                    auto [mx2, val2] = st.back(); st.pop_back();
                    st.push_back({mx1, max(val1, val2)});
                    
                    --vals[val1 - mx1]; if (vals[val1 - mx1] == 0) vals.erase(val1 - mx1);
                    --vals[val2 - mx2]; if (vals[val2 - mx2] == 0) vals.erase(val2 - mx2);
                    ++vals[max(val1, val2) - mx1];
                    --sz;
                }
                else break;
            }
            ans = 1 + (*vals.rbegin()).first;
        }
        cout << ans << '\n';
    }
}
3 Likes

O(N) solution using monotonic stack. CodeChef: Practical coding for everyone

Is it possible to improve the readability of codes?
They are just like the solutions for submission, not for editorial