MAXIMISESUM - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: satyam_343
Tester: apoorv_me
Editorialist: iceknight1093

DIFFICULTY:

1715

PREREQUISITES:

None

PROBLEM:

You have an array A.
In one move, you can pick indices 1 \leq i \lt j \leq N and set A_k:=\min(A_i, A_j) for each i \lt k \lt j.
Find the maximum possible sum of the final array obtained after performing some operations.
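To experiment with tiny inputs, the move can be simulated directly. The sketch below is a hypothetical brute force (the name `brute_force_max_sum` is made up, and indices are 0-based). It runs a BFS over every array reachable by some sequence of moves. The search terminates because every value ever written is one of the original values, so only finitely many arrays are reachable. That number still grows exponentially, so this is only meant for very small N.

```python
from collections import deque


def brute_force_max_sum(a):
    """Exhaustively explore every array reachable from `a` (tiny inputs only)."""
    start = tuple(a)
    seen = {start}
    queue = deque([start])
    best = sum(start)
    n = len(start)
    while queue:
        cur = queue.popleft()
        best = max(best, sum(cur))
        # Try every move (i, j); pairs with j - i < 2 change nothing, so skip them.
        for i in range(n):
            for j in range(i + 2, n):
                v = min(cur[i], cur[j])
                nxt = cur[:i + 1] + (v,) * (j - i - 1) + cur[j:]
                if nxt not in seen:
                    seen.add(nxt)
                    queue.append(nxt)
    return best
```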

EXPLANATION:

Let’s analyze what the final array might look like.
Intuitively, if we have indices i \lt k \lt j such that A_k \lt \min(A_i, A_j), then we should be able to bring A_k up to \min(A_i, A_j) by performing the operation.
Of course, this operation also affects other elements in between (and might reduce some of them, which isn’t what we want).

Let’s call a move (i, j) a good move if there’s no index i \lt k \lt j such that A_k \gt \min(A_i, A_j).
That is, a good move is one that only increases the elements it affects.
It’d be nice if we could use only good moves, and in fact, we can!
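The definition can be written down directly. A minimal sketch, assuming 0-based indices; the helper names `apply_move` and `is_good_move` are hypothetical and not taken from any of the reference solutions.

```python
def apply_move(a, i, j):
    """Return a copy of `a` after move (i, j): each a[k] with i < k < j becomes min(a[i], a[j])."""
    v = min(a[i], a[j])
    return a[:i + 1] + [v] * (j - i - 1) + a[j:]


def is_good_move(a, i, j):
    """A move is good if no element strictly between i and j exceeds min(a[i], a[j])."""
    v = min(a[i], a[j])
    return all(a[k] <= v for k in range(i + 1, j))
```

For example, on `[3, 1, 4, 1, 5]` the move `(0, 2)` is good and yields `[3, 3, 4, 1, 5]`, while `(0, 4)` is not good: it lowers the 4 to 3.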

Claim: The final array will be such that for any three indices i \lt k \lt j, we’ll have A_k \geq \min(A_i, A_j); and further, this can be achieved by using only good moves.

Proof

Suppose there are indices i \lt k \lt j such that A_k \lt \min(A_i, A_j).
Let x be the index of the maximum element of the segment [i, k-1] (if there are multiple, choose the rightmost). We definitely have A_x \geq A_i.
Similarly, let y be the index of the leftmost maximum of segment [k+1, j]. Once again, we know A_y \geq A_j.

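Without loss of generality, let’s say A_x \geq A_y, i.e., A_y = \min(A_x, A_y) (the other case is symmetric).
Let z be the largest index in [x, k-1] such that A_z \geq A_y.
Notice that the operation (z, y) is a good operation, because every element strictly between z and y is at most A_y = \min(A_z, A_y): elements in (z, k-1] are smaller than A_y by the choice of z, A_k \lt \min(A_i, A_j) \leq A_j \leq A_y, and elements in [k+1, y-1] are smaller than A_y since y is the leftmost maximum of [k+1, j].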
Further, this range includes index k as well.

So, we can perform the good operation (z, y), which increases A_k to A_y.
Since A_y \geq A_j, we have A_k \geq\min(A_i, A_j), as we wanted!

This way, we can keep performing good operations to increase elements until no further increase is possible.
It’s not hard to see that the process will terminate after finitely many steps.
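The proof is constructive, so it can be turned into a (slow) procedure. The sketch below assumes 0-based indices and, for a violating k, implicitly takes i and j to be the positions of the largest values on each side of k (the strongest possible choice). The names `find_violation` and `fix_with_good_moves` are hypothetical, and the inner `assert` checks that each chosen move really is good.

```python
def find_violation(a):
    """Return some k with a[k] < min(max(a[:k]), max(a[k+1:])), or None if there is none."""
    for k in range(1, len(a) - 1):
        if a[k] < min(max(a[:k]), max(a[k + 1:])):
            return k
    return None


def fix_with_good_moves(a):
    """Repeat the proof's step until no violation remains. Slow; tiny inputs only."""
    a = list(a)
    n = len(a)
    while True:
        k = find_violation(a)
        if k is None:
            return a
        left_max, right_max = max(a[:k]), max(a[k + 1:])
        # x: rightmost maximum left of k; y: leftmost maximum right of k.
        x = max(t for t in range(k) if a[t] == left_max)
        y = min(t for t in range(k + 1, n) if a[t] == right_max)
        if a[x] >= a[y]:
            # z: largest index in [x, k-1] with a[z] >= a[y]; apply (z, y).
            z = max(t for t in range(x, k) if a[t] >= a[y])
            i, j = z, y
        else:
            # Mirror case: z is the smallest index in [k+1, y] with a[z] >= a[x]; apply (x, z).
            z = min(t for t in range(k + 1, y + 1) if a[t] >= a[x])
            i, j = x, z
        v = min(a[i], a[j])
        assert all(a[t] <= v for t in range(i + 1, j)), "the chosen move should be good"
        a[i + 1:j] = [v] * (j - i - 1)
```

For instance, `fix_with_good_moves([2, 1, 3, 1, 2])` applies `(0, 2)` and then `(2, 4)`, ending at `[2, 2, 3, 2, 2]`.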


Now that we know this, let’s see what it actually means for A.
Let M be the index of the maximum element of A.
Then, the above claim tells us that the final array must have:

  • A_1 \leq A_2 \leq A_3 \leq\ldots\leq A_M
  • A_M \geq A_{M+1}\geq\ldots\geq A_N

That is, the array will look like a pyramid.
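The pyramid shape is easy to test for. A small sketch, assuming 0-based indices and taking M as the first occurrence of the maximum; `is_pyramid` is a hypothetical helper name.

```python
def is_pyramid(a):
    """Non-decreasing up to the first maximum M, non-increasing from M onward."""
    m = a.index(max(a))
    left_ok = all(a[t] <= a[t + 1] for t in range(m))
    right_ok = all(a[t] >= a[t + 1] for t in range(m, len(a) - 1))
    return left_ok and right_ok
```

For example, `[5, 5, 7, 7, 7]` is a pyramid, but `[5, 5, 7, 5, 7]` is not, since the 5 between the two 7s breaks the non-increasing part.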

So, we just need to find out what the prefix of the array up to M will look like in the end.
It’s not too hard to see that:

  • Let M_1 be the index of the maximum element of the range [1, M-1] (if there are multiple occurrences, choose the leftmost).
    Then, we can set A_{M_1} = A_{M_1+1} = A_{M_1+2} = \ldots = A_{M-1}, all equal to this maximum, by performing the operation (M_1, M).
  • Again, let M_2 be the index of the leftmost maximum element of [1, M_1-1]. Everything from M_2 till M_1-1 can be set to A_{M_2}.
  • Repeat this process for [1, M_2-1], and so on till all the elements are set.
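The steps above can be simulated literally. A sketch, assuming 0-based indices and that `a[m]` is a maximum of the array; `fill_prefix_naive` is a hypothetical name. Each round rescans the unsettled prefix, which is where the quadratic worst case comes from.

```python
def fill_prefix_naive(a, m):
    """Settle a[0:m] with the repeated M_1, M_2, ... process. Worst case O(N^2)."""
    a = list(a)
    right = m  # invariant: a[right] >= every value to its left
    while right > 0:
        best = max(a[:right])
        left = a.index(best, 0, right)  # leftmost maximum of a[0:right]
        # Operation (left, right): the gap gets min(a[left], a[right]), which is a[left].
        for t in range(left + 1, right):
            a[t] = a[left]
        right = left
    return a
```

On `[3, 1, 4, 1, 5, 9, 2, 6]` with `m = 5`, the rounds pick indices 4, 2 and 0, giving the prefix `[3, 3, 4, 4, 5]`.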

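Implemented directly, this process can take \mathcal{O}(N^2) time (for example, when the prefix is strictly increasing, each step settles only one position but rescans the whole remaining prefix), which is too slow.
To speed it up, note that the indices M_1, M_2, M_3, \ldots we found are exactly the positions where a new (strictly larger) prefix maximum of A appears.
That is, whenever there’s a new prefix maximum, its index is one of the M_i.
So for every k \lt M, the final value of A_k is simply \max(A_1, A_2, \ldots, A_k), and all of these can be computed in \mathcal{O}(N) time with a single left-to-right pass.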
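The linear version is just a running maximum. A sketch assuming 0-based indices with `a[m]` the maximum; `fill_prefix_fast` is a hypothetical name, and `itertools.accumulate(iterable, max)` produces the running prefix maxima.

```python
from itertools import accumulate


def fill_prefix_fast(a, m):
    """O(N): every position before m becomes its prefix maximum."""
    a = list(a)
    a[:m] = accumulate(a[:m], max)  # each strictly new running maximum is one of the M_i
    return a
```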

The suffix after M can be solved similarly by finding suffix maximums.
Once the entire array is known, just output its sum.
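Putting the pieces together, a sketch of the whole computation (the name `max_final_sum` is hypothetical). Positions before the first maximum take prefix maxima, positions after the last maximum take suffix maxima, and everything from the first to the last occurrence of the maximum becomes the maximum, as the pyramid shape forces.

```python
from itertools import accumulate


def max_final_sum(a):
    """Sum of the final pyramid-shaped array."""
    m = max(a)
    first = a.index(m)
    last = len(a) - 1 - a[::-1].index(m)
    left = sum(accumulate(a[:first], max))              # prefix maxima before `first`
    right = sum(accumulate(reversed(a[last + 1:]), max))  # suffix maxima after `last`
    return left + m * (last - first + 1) + right
```

In a submission this would sit inside the usual input loop (read T, then N and the array for each test). Python integers do not overflow, whereas a C++ version needs `long long` for the sum.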

TIME COMPLEXITY

\mathcal{O}(N) per testcase.

CODE:

Author's code (C++)
#pragma GCC optimize("O3,unroll-loops")
#include <bits/stdc++.h>   
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;   
using namespace std;
#define ll long long
const ll INF_MUL=1e15;
const ll INF_ADD=1e18;
#define pb push_back               
#define mp make_pair          
#define nline "\n"                           
#define f first                                          
#define s second                                             
#define pll pair<ll,ll> 
#define all(x) x.begin(),x.end()     
#define vl vector<ll>           
#define vvl vector<vector<ll>>    
#define vvvl vector<vector<vector<ll>>>          
#ifndef ONLINE_JUDGE    
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);  
#endif     
void _print(ll x){cerr<<x;}  
void _print(char x){cerr<<x;}   
void _print(string x){cerr<<x;}    
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());   
template<class T,class V> void _print(pair<T,V> p) {cerr<<"{"; _print(p.first);cerr<<","; _print(p.second);cerr<<"}";}
template<class T>void _print(vector<T> v) {cerr<<" [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T>void _print(set<T> v) {cerr<<" [ "; for (T i:v){_print(i); cerr<<" ";}cerr<<"]";}
template<class T>void _print(multiset<T> v) {cerr<< " [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T,class V>void _print(map<T, V> v) {cerr<<" [ "; for(auto i:v) {_print(i);cerr<<" ";} cerr<<"]";} 
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
//--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------  
const ll MOD=998244353;
const ll MAX=500500;
void solve(){  
    ll n; cin>>n;
    ll ans=0;
    vector<ll> a(n+5);
    for(ll i=1;i<=n;i++){  
        cin>>a[i];
    } 
    vector<ll> pref(n+5,0),suff(n+5,0);
    for(ll i=1;i<=n;i++){  
        pref[i]=max(pref[i-1],a[i]);
    }
    for(ll i=n;i>=1;i--){
        suff[i]=max(suff[i+1],a[i]);
    }
    for(ll i=2;i<n;i++){
        a[i]=max(a[i],min(pref[i-1],suff[i+1]));
    }
    for(ll i=1;i<=n;i++){
        ans+=a[i];
    }
    cout<<ans<<nline;
    return;    
}                                          
int main()                                                                             
{     
    ios_base::sync_with_stdio(false);                         
    cin.tie(NULL);                              
    #ifndef ONLINE_JUDGE                 
    freopen("input.txt", "r", stdin);                                           
    freopen("output.txt", "w", stdout);      
    freopen("error.txt", "w", stderr);                        
    #endif     
    ll test_cases=1;               
    cin>>test_cases;
    while(test_cases--){
        solve();
    }
    cout<<fixed<<setprecision(10);
    cerr<<"Time:"<<1000*((double)clock())/(double)CLOCKS_PER_SEC<<"ms\n"; 
}  
Tester's code (C++)
#ifndef LOCAL
#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx,avx2,sse,sse2,sse3,sse4,popcnt,fma")
#endif

#include <bits/stdc++.h>
using namespace std;

#ifdef LOCAL
#include "../debug.h"
#else
#define dbg(...) "11-111"
#endif

struct input_checker {
	string buffer;
	int pos;

	const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
	const string number = "0123456789";
	const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
	const string lower = "abcdefghijklmnopqrstuvwxyz";

	input_checker() {
		pos = 0;
		while (true) {
			int c = cin.get();
			if (c == -1) {
				break;
			}
			buffer.push_back((char) c);
		}
	}

	int nextDelimiter() {
		int now = pos;
		while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
			now++;
		}
		return now;
	}

	string readOne() {
		assert(pos < (int) buffer.size());
		int nxt = nextDelimiter();
		string res;
		while (pos < nxt) {
			res += buffer[pos];
			pos++;
		}
		return res;
	}

	string readString(int minl, int maxl, const string &pattern = "") {
		assert(minl <= maxl);
		string res = readOne();
		assert(minl <= (int) res.size());
		assert((int) res.size() <= maxl);
		for (int i = 0; i < (int) res.size(); i++) {
			assert(pattern.empty() || pattern.find(res[i]) != string::npos);
		}
		return res;
	}

	int readInt(int minv, int maxv) {
		assert(minv <= maxv);
		int res = stoi(readOne());
		assert(minv <= res);
		assert(res <= maxv);
		return res;
	}

	long long readLong(long long minv, long long maxv) {
		assert(minv <= maxv);
		long long res = stoll(readOne());
		assert(minv <= res);
		assert(res <= maxv);
		return res;
	}

	auto readInts(int n, int minv, int maxv) {
		assert(n >= 0);
		vector<int> v(n);
		for (int i = 0; i < n; ++i) {
			v[i] = readInt(minv, maxv);
			if (i+1 < n) readSpace();
		}
		return v;
	}

	auto readLongs(int n, long long minv, long long maxv) {
		assert(n >= 0);
		vector<long long> v(n);
		for (int i = 0; i < n; ++i) {
			v[i] = readLong(minv, maxv);
			if (i+1 < n) readSpace();
		}
		return v;
	}

	void readSpace() {
		assert((int) buffer.size() > pos);
		assert(buffer[pos] == ' ');
		pos++;
	}

	void readEoln() {
		assert((int) buffer.size() > pos);
		assert(buffer[pos] == '\n');
		pos++;
	}

	void readEof() {
		assert((int) buffer.size() == pos);
	}
};

int32_t main() {
    ios_base::sync_with_stdio(0);   cin.tie(0);

    input_checker input;
    int T = input.readInt(1, (int)1e5); input.readEoln();
    int sum_N = 0;
    while(T-- > 0) {
        int n = input.readInt(1, (int)1e5);    input.readEoln();
        sum_N += n;
        vector<int> a = input.readInts(n, 1, (int)1e9);  input.readEoln();

        int pos = max_element(a.begin(), a.end()) - a.begin();
        int mx = 0;
        for(int i = 0 ; i < pos ; i++) {
            mx = max(mx, a[i]);
            a[i] = mx;
        }
        mx = 0;
        for(int i = n - 1 ; i > pos ; i--) {
            mx = max(mx, a[i]);
            a[i] = mx;
        }

        cout << accumulate(a.begin(), a.end(), 0ll) << '\n';
    }
    input.readEof();
    assert(sum_N <= (int)5e5);

    return 0;
}
Editorialist's code (Python)
for _ in range(int(input())):
    n = int(input())
    a = list(map(int, input().split()))
    m = max(a)
    ans, mx = 0, 0
    lo, hi = 0, 0
    for i in range(n):
        x = a[i]
        if x == m:
            lo = i
            break
        mx = max(mx, x)
        ans += mx
    mx = 0
    for i in reversed(range(n)):
        x = a[i]
        if x == m:
            hi = i
            break
        mx = max(mx, x)
        ans += mx
    ans += m*(hi-lo+1)
    print(ans)

I also implemented the same idea in C++, here’s the code.

#include <iostream>
#include <vector>
using namespace std;

int minimum(int& a, int& b){
    if(a < b) return a;
    else return b;
}

int main() {
	// your code goes here
	int T; cin >> T;
	
	while(T--){
	    int N; cin >> N;
        vector<int> A;
        int max = -1; int max_pos;
        
        for(int i = 0; i < N; i++){
            int Ai; cin >> Ai;
            if(Ai >= max){
                max = Ai; max_pos = i;
            }
            A.push_back(Ai);
        }
        
        int i = 0; int j = i + 1;
        while(i != max_pos){
            if(A[j] >= A[i]){
                i = j; j = i + 1;
            }
            else{
                A[j++] = A[i];
            }
        }
        
        i = N - 1; j = i - 1;
	    while(i != max_pos){
	        if(A[j] >= A[i]){
	            i = j; j = i - 1;
	        }
	        else{
	            A[j--] = A[i];
	        }
	    }
	    
	    
	    long long sum = 0; // 64-bit: the sum can reach about 1e14
	    for(int i = 0; i < N; i++){
	        sum += A[i];
	    }
	    
	    cout << sum << "\n";
	}
	return 0;
}

One mistake I made during the contest was taking sum as int instead of long.
Hope it helps


Very helpful!

For a detailed explanation, visit the link below:
Easy Solution than Editor

for _ in range(int(input())):
    n = int(input())
    arr = list(map(int, input().split()))
    
    l = [0] * n
    r = [0] * n
    
    l[0] = arr[0]
    r[-1] = arr[-1]
    for i in range(1, n):
        l[i] = max(l[i - 1], arr[i])
        r[n - i - 1] = max(r[n - i], arr[n - i - 1])
    
    s = 0
    for i in range(n):
        s += min(l[i], r[i])
    print(s)

@yash_visavadia
WOW! Really easy to understand.


Has anyone thought of using local maxima? Please comment your code.


What’s wrong with my logic? Can anyone help?

for _ in range(int(input())):
    n=int(input())

    arr=[-float('inf')]+ list(map(int,input().split())) +[-float('inf')]

    final=[arr[0]]
    for i in range(1,len(arr)):
        if arr[i]!=final[-1]:final.append(arr[i])
    
    #lets find the local maximum
    lmax=[]
    for i in range(1,len(final)-1):
        if final[i-1]<final[i]>final[i+1]:
            lmax.append(i)
    c=0
    for i in range(len(lmax)-1):
        l=lmax[i]
        r=lmax[i+1]
        if arr[l]<arr[r]:
            for j in range(l+1,r):
                if arr[j]>arr[l]:
                    break
                c+=arr[l]-arr[j]
        else:
            for j in range(l+1,r):
                if arr[j]<arr[r]:
                    c+=arr[r]-arr[j]
    print(sum(arr[1:-1])+c)



@apoorv_me @satyam_343 @iceknight1093 It looks like the test cases don’t cover all cases.
For example, my code got AC, but it fails on the test case below:
Input

1
5
5 5 7 5 7

Expected output:

31

My code output:

29
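The expected answer can be confirmed with the prefix/suffix maximum idea used in the `min(l[i], r[i])` solution posted earlier in this thread: the final value at each position is the smaller of its (inclusive) prefix maximum and suffix maximum. A small sketch; `final_array` is a hypothetical name.

```python
from itertools import accumulate


def final_array(a):
    """Final value at each position: min(prefix max, suffix max), both inclusive."""
    pref = list(accumulate(a, max))
    suff = list(accumulate(reversed(a), max))[::-1]
    return [min(p, s) for p, s in zip(pref, suff)]


arr = [5, 5, 7, 5, 7]
print(final_array(arr), sum(final_array(arr)))  # [5, 5, 7, 7, 7] 31
```

The only change is raising the 5 between the two 7s, which adds 2 to the original sum of 29.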

No extra space, easier to understand, O(N) time complexity:

for _ in range(int(input())):
    n=int(input())
    l=list(map(int,input().split()))
    if n<=2:
        print(sum(l))
    elif len(set(l))==1:
        print(sum(l))
    else:
        z=l.index(max(l))
        # prefix maxima before the first maximum
        m=l[0]
        for i in range(1,z):
            m=max(m,l[i])
            l[i]=m
        # suffix maxima after it
        m=l[n-1]
        for i in range(n-2,z,-1):
            m=max(m,l[i])
            l[i]=m
        print(sum(l))

This problem is very similar to Trapping Rain Water.
Time complexity: O(n)
No extra space

#include<bits/stdc++.h>
#define ll long long
using namespace std;

int main(){
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int t; cin>>t;
while(t--){
	int n; cin>>n;
	vector<ll> v(n);
	for(auto &i:v) cin>>i;
	ll ans=0;
	int l=0,r=n-1,lm=0,rm=0;
	while(l<=r){
		if(v[l]<=v[r]){
			if(v[l]>=lm) lm=v[l];
			ans+=lm;
			l++;
		}else{
			if(v[r]>=rm) rm=v[r];
			ans+=rm;
			r--;
		}
	}
	cout<<ans<<endl;
}
return 0;
}
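The analogy can be made precise. In Trapping Rain Water, the water level above bar i is min(prefix max, suffix max), which is exactly the final value of A_i here. So the answer equals the original sum plus the trapped water. A sketch of that relation (not the two-pointer code above); `trapped_water` and `answer_via_rain_water` are hypothetical names.

```python
from itertools import accumulate


def trapped_water(h):
    """Classic Trapping Rain Water: water above bar i is min(prefix max, suffix max) - h[i]."""
    pref = list(accumulate(h, max))
    suff = list(accumulate(reversed(h), max))[::-1]
    return sum(min(p, s) - x for p, s, x in zip(pref, suff, h))


def answer_via_rain_water(a):
    # The final array is the "water level", so the answer is the original sum plus the water.
    return sum(a) + trapped_water(a)
```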

#include <bits/stdc++.h>
using namespace std;

int main() {
// your code goes here
int t;
cin>>t;

while(t--)
{
    int n;
    cin>>n;
    int arr[n];

	for(int i=0;i<n;i++)
	{
	    cin>>arr[i];
	}
	long ans=0;
	if(arr[0]<=arr[n-1])
	{
	    
	    stack<int>st;
	    st.push(0);
	    for(int i=0;i<n;i++)
	    {
	        if(arr[st.top()]<arr[i])
	        {
	            ans+=(i-st.top())*arr[st.top()];
	            st.pop();
	            st.push(i);
	        }
	        else if(i==n-1 and arr[st.top()]==arr[i])
	        {
	            ans+=(i-st.top())*arr[st.top()];
	        }
	        else
	        {
	            continue;
	        }
	    }
	    
	    
	    
	    ans+=arr[n-1];
	}
	else
	{
	    stack<int>st;
	    st.push(n-1);
	    for(int i=n-1;i>=0;i--)
	    {
	        if(arr[st.top()]<arr[i])
	        {
	            ans+=(st.top()-i)*arr[st.top()];
	            st.pop();
	            st.push(i);
	        }
	        else if(i==0 and arr[st.top()]==arr[i])
	        {
	            ans+=(st.top()-i)*arr[st.top()];
	        }
	        else
	        {
	            continue;
	        }
	    }
	    
	    
	    ans+=arr[0];
	    
	}
	cout<<ans<<endl;
}

}
Hey, can you please tell me what the problem is in this code?