SUBSEQI - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: raysh_07
Tester: mridulahi
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

None

PROBLEM:

For an array B of length M, consider the following process:

  • For i = M-1, M-2, \ldots, 2, 1 in order, choose j\gt i and set B_j := \max(0, B_j - B_i).
    That is, choose some index j\gt i and subtract B_i from it, not going below 0.

f(B) is the maximum possible value of the sum of B after this process.
You're given an array A. Find f(A), then process Q point updates to it, after each of which f(A) must be recomputed.
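
To make the process concrete, here is a minimal brute-force sketch in Python that tries every choice of j at every step. The name `brute_f` is hypothetical (it does not appear in any of the solutions below), and the search is exponential, so it is only meant for checking tiny arrays by hand.

```python
def brute_f(b):
    """Try every choice of j for each i and return the best final sum.

    Indices are 0-based here, so the statement's i = M-1, ..., 1
    becomes i = m-2, ..., 0. Exponential time: tiny inputs only.
    """
    b = list(b)
    m = len(b)

    def go(i):
        if i < 0:
            # All subtractions have been performed.
            return sum(b)
        best = 0
        for j in range(i + 1, m):
            old = b[j]
            b[j] = max(0, old - b[i])  # subtract B_i, not going below 0
            best = max(best, go(i - 1))
            b[j] = old                 # undo before trying the next j
        return best

    return go(m - 2)
```

For instance, `brute_f([1, 2, 3])` evaluates to 3: the first step turns the array into [1, 2, 1], and either choice for the last step leaves a sum of 3.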

EXPLANATION:

Our first order of business is computing the answer for a single array, i.e., finding f(A) for a fixed A.
Dealing with updates will come later.

Intuitively, it seems nice to try and subtract everything from a single index: once it reaches zero, further subtractions from it cost nothing.
This is indeed true: if there's a way to ensure that the final array contains a 0, it's optimal to do so.

Proof

If A doesn't contain any zeros in the end, that means every subtraction operation was done in full.
That is, the total loss in the sum of the array is A_1 + A_2 + \ldots + A_{N-1}.

This is, quite literally, the most we can lose by performing operations.
If it's possible to form a zero in the array, we have the chance of "avoiding" some loss by subtracting a larger number from a smaller one, which is better than nothing.

Also, note that it's never optimal to make two different indices reach 0: if i \lt j and both A_i and A_j reach zero, we could've used all the operations that decreased A_i on A_j instead; that way A_j would still be zero but A_i wouldn't change (which increases the sum of the final array).

Now, when is it possible at all to make A_i = 0?
Because of the order of subtractions, this can only happen when the sum of everything before it is at least A_i. That is, A_1 + A_2 + \ldots + A_{i-1} \geq A_i.

Conversely, this means that the only time when it's impossible to create a 0 is when A_1 + \ldots + A_{i-1} \lt A_i for every i.
However, this restricts the size of A a lot: in fact, A can't be of length \gt 30.

Why?

If that condition holds, we'll have:

  • A_1 \lt A_2
  • A_1 + A_2 \lt A_3 \implies 2A_1 \lt A_3
  • A_1 + A_2 + A_3 \lt A_4 \implies 2\cdot (A_1 + A_2) \lt A_4 \implies 4A_1 \lt A_4
    \vdots
  • 2^{k-2} A_1 \lt A_k

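With positive integer values, the same chain gives the slightly stronger bound A_k \geq 2^{k-1}, so if N\gt 30, A_{31} \geq 2^{30} would exceed 10^9, which isn't allowed.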
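
The bound can be checked numerically with a short Python sketch. It greedily builds the slowest-growing array in which every element is strictly larger than the sum of everything before it, and stops once the next element would exceed the limit. The function name `longest_zero_free` is hypothetical, and the sketch assumes the values are positive integers (the 10^9 limit is the one quoted above).

```python
def longest_zero_free(limit=10**9):
    """Slowest-growing array with A_i > A_1 + ... + A_{i-1} for every i.

    Assumes positive integer values bounded by `limit`.
    """
    seq = []
    prefix = 0
    # The smallest valid next element is prefix + 1.
    while prefix + 1 <= limit:
        seq.append(prefix + 1)
        prefix += seq[-1]
    return seq


seq = longest_zero_free()
print(len(seq), seq[-1])
```

The elements come out as powers of two (1, 2, 4, ...), and the array stops at 30 elements, the last being 2^29.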

In particular, this observation can be applied to the last 31 elements of A: just among them, there is guaranteed to be one that can be made 0!

In fact, this observation can be extended a bit: making A_i zero is optimal only if it isn't possible to form a zero after it. This is because making A_i zero means that the subarray [A_i, A_{i+1}, \ldots, A_N] is itself an independent instance of this problem (and as noted earlier, will contain a zero if it is able to).

This gives us a method to find f(A).
Let's fix i, the index that will be made 0. As noted above, i can always be found among the last 31 indices.
Then,

  • Check if A_1 + \ldots + A_{i-1} \geq A_i, i.e., if it's possible to make A_i zero at all.
  • If it is, then we can assume that nothing after i will become zero (which is true in the optimal solution), and so our total "loss" is A_i + (A_i + A_{i+1} + \ldots + A_{N-1}), because:
    • We lose A_i by making it zero with everything before it.
    • We lose A_i + A_{i+1} + \ldots + A_{N-1} from the part after i since there are no zeros there (and hence subtractions happen in full).
  • In particular, note that if i = N the "loss" is exactly A_N.

f(A) then equals the sum of A minus the minimum loss across all these indices.

Don't forget to also consider the case of when no index can be made 0 (for small N), in which case the loss is just A_1 + A_2 + \ldots + A_{N-1}.
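
The steps above can be sketched in Python as follows. This is an illustrative version, not any of the reference solutions: the name `f_fast` and the `window` parameter are hypothetical, and the default of 31 relies on the bound A_i \leq 10^9 discussed earlier.

```python
def f_fast(a, window=31):
    """Compute f(A) by scanning only the last `window` indices.

    Sketch of the editorial's method; 0-based indexing.
    """
    n = len(a)
    total = sum(a)
    # No index can be made zero: we lose A_1 + ... + A_{N-1}.
    best_loss = total - a[-1]
    prefix = total  # becomes the sum of elements strictly before i
    tail = 0        # A_i + ... + A_{N-1}; stays 0 when i is the last index
    for i in range(n - 1, max(-1, n - 1 - window), -1):
        prefix -= a[i]
        if i < n - 1:
            tail += a[i]
        if prefix >= a[i]:
            # A_i can be made zero; everything from i onward is lost in full.
            best_loss = min(best_loss, a[i] + tail)
    return total - best_loss
```

For A = [5, 1, 2] the candidates are a loss of 6 (no zero), 2 (zero at the last index) and 1 + 1 = 2 (zero at the middle index), so `f_fast([5, 1, 2])` evaluates to 8 - 2 = 6.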


Now that we know how to compute f(A) for a fixed A, you may note that dealing with point updates is surprisingly simple.
The only things that really matter are the last 30 or so elements of A, and the sum of A.

So, it suffices to directly keep the updated array A and recalculate the answer entirely after each update.
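
One possible way to organize this in Python is sketched below. The class `PointUpdateSolver` and its method names are hypothetical. It assumes an update (i, v) means "set A_i to v" with 1-based i, as in the solutions in the CODE section, and it reuses the last-31-indices scan described above.

```python
class PointUpdateSolver:
    """Keeps A and its running sum; recomputes f(A) on demand."""

    def __init__(self, values, window=31):
        self.a = list(values)
        self.total = sum(self.a)
        self.window = window  # assumed bound on how far back a zero can sit

    def assign(self, pos, value):
        """Set A_pos = value, with 1-based pos; O(1)."""
        self.total += value - self.a[pos - 1]
        self.a[pos - 1] = value

    def answer(self):
        """Recompute f(A) from the sum and the last `window` elements."""
        a, n = self.a, len(self.a)
        best_loss = self.total - a[-1]  # case where no zero can be formed
        prefix, tail = self.total, 0
        for i in range(n - 1, max(-1, n - 1 - self.window), -1):
            prefix -= a[i]
            if i < n - 1:
                tail += a[i]
            if prefix >= a[i]:
                best_loss = min(best_loss, a[i] + tail)
        return self.total - best_loss


solver = PointUpdateSolver([1, 2, 3])
print(solver.answer())
solver.assign(1, 5)  # array becomes [5, 2, 3]
print(solver.answer())
```

Each update touches only one array cell and the running sum, so the cost per query is dominated by the short scan in `answer`.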

TIME COMPLEXITY:

\mathcal{O}(N + Q\log\max A) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18
#define f first
#define s second

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

const int N = 2e5 + 69;
int seg[4 * N], lazy[4 * N];
int b[N];

void Build(int l, int r, int pos){
    lazy[pos] = 0;
    if (l == r){
        seg[pos] = b[l];
        return;
    }

    int mid = (l + r) / 2;
    Build(l, mid, pos * 2);
    Build(mid + 1, r, pos * 2 + 1);

    seg[pos] = min(seg[pos * 2], seg[pos * 2 + 1]);
}

void updlz(int l, int r, int pos){
    seg[pos] += lazy[pos];
    if (l != r){
        lazy[pos * 2] += lazy[pos];
        lazy[pos * 2 + 1] += lazy[pos];
    }

    lazy[pos] = 0;
}

void upd(int l, int r, int pos, int ql, int qr, int v){
    if (lazy[pos] != 0) updlz(l, r, pos);
    if (l >= ql && r <= qr){
        seg[pos] += v;
        if (l != r){
            lazy[pos * 2] += v;
            lazy[pos * 2 + 1] += v;
        }
    } else if (l > qr || r < ql){
        return;
    } else {
        int mid = (l + r) / 2;
        upd(l, mid, pos * 2, ql, qr, v);
        upd(mid + 1, r, pos * 2 + 1, ql, qr, v);
        seg[pos] = min(seg[pos * 2], seg[pos * 2 + 1]);
    }
}

void Solve() 
{
    int n, q; cin >> n >> q;

    vector <int> a(n);
    int sum = 0;
    for (auto &x : a) cin >> x, sum += x;

    // min loss is p[n - 1] or a_i + p[n - 1] - p[i - 1] 

    vector <int> p(n + 1, 0);
    for (int i = 1; i <= n; i++) p[i] = p[i - 1] + a[i - 1];

    int ok = 0;
    
    for (int i = 1; i <= n; i++){
        b[i] = a[i - 1] + p[n - 1] - p[i - 1];
        if (i != n) ok += a[i - 1];
    }

    Build(1, n, 1);

    auto res = [&](){
        cout << sum - min(ok, seg[1]) << "\n";
    };

    auto update = [&](int i, int x){
        // added x to a_i 
        if (i == n){
            upd(1, n, 1, i, i, x);
            return;
        }
        upd(1, n, 1, i, i, 2 * x);
        upd(1, n, 1, 1, i - 1, x);
    };

    res(); 

    while (q--){
        int i, v; cin >> i >> v;

        update(i, -a[i - 1]);
        sum -= a[i - 1];
        sum += v;
        if (i != n) ok -= a[i - 1];
        a[i - 1] = v;
        update(i, v);
        if (i != n) ok += a[i - 1];

        res();
    }
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}
Tester's code (C++)
#include<bits/stdc++.h>

using namespace std;

#define all(x) begin(x), end(x)
#define sz(x) static_cast<int>((x).size())
#define int long long

int get (int n, int s, int a[]) {
        int ans = a[n - 1];   
        ans = max(ans, s - a[n - 1]); 
        for (int i = n - 2; i >= max(0ll, n - 28); i--) {
                ans = max(ans, s - 2 * a[i]);
                s -= a[i];
        }    
        return ans;
}

signed main() {

        ios::sync_with_stdio(0);
        cin.tie(0);
        cout.tie(0);

        int t;
        cin >> t;

        while (t--) {

                int n, q;
                cin >> n >> q;
                int a[n];
                for (auto &x : a) cin >> x;
                int s = accumulate(a, a + n, 0ll);

                cout << get(n, s, a) << "\n";

                while (q--) {
                        int i, j;
                        cin >> i >> j;
                        i--;
                        s += j - a[i];
                        a[i] = j;
                        cout << get(n, s, a) << "\n";
                }

        }

        
}
Editorialist's code (Python)
for _ in range(int(input())):
    n, q = map(int, input().split())
    a = list(map(int, input().split()))
    totsum = sum(a)
    def calc():
        best = totsum - a[-1]
        if best >= a[-1]: best = min(best, a[-1])
        curpref = totsum - a[-1]
        cursuf = 0
        for i in range(1, 32):
            if i >= n: break
            curpref -= a[-1-i]
            cursuf += a[-1-i]
            if curpref >= a[-1-i]:
                best = min(best, a[-1-i] + cursuf)
        return best
    print(totsum - calc())
    for i in range(q):
        x, y = map(int, input().split())
        totsum -= a[x-1]
        a[x-1] = y
        totsum += y
        print(totsum - calc())

Can someone explain me the question.

what do you mean by optimal solution here?
@iceknight1093

"Optimal solution" here means any sequence of moves that achieves the maximum possible final sum, which is f(A).

Can someone please tell what is wrong in this code?
It is failing for one test case.

#include <bits/stdc++.h>
using namespace std;

int main() {
	// your code goes here
	int t;
	cin>>t;
	while(t--){
	    int n,q;
	    cin>>n>>q;
	    vector<int>arr(n);
	    int sum=0;
	    for(int i=0;i<n;i++){
	        cin>>arr[i];
	        sum+=arr[i];
	    }
	    auto solve = [&](){
	        int ans=sum-arr[n-1]; //max loss
            int pre=ans;
            int suff=0;
            if(ans>=arr[n-1]){
                ans=min(ans,arr[n-1]);
            }
            for(int i=1;i<32;i++){
                if(i>=n){
                    break;
                }
                pre-=arr[n-i-1];
                suff+=arr[n-i-1];
                if(pre>=arr[n-i-1]){
                    ans=min(ans,suff+arr[n-i-1]);
                }
            }
            return ans;
	    };
	    cout<<sum-solve()<<endl;
	    for(int i=0;i<q;i++){
	        int q1,q2;
	        cin>>q1>>q2;
	        sum-=arr[q1-1];
	        arr[q1-1]=q2;
	        sum+=q2;
	        cout<<sum-solve()<<endl;
	    }
	}
}

I found that if I use:
#define int long long int
and, instead of int main():
signed main()
then it gets accepted.
Why is this happening?

A_i \leq 10^9 and N \leq 2\cdot 10^5 means that the sum of A can be quite large, up to 2\cdot 10^{14}.
The limit of a 32-bit int is 2^{31}-1, which is a bit more than 2\cdot 10^9, so the issue is quite simply overflow.

long long can hold up to 2^{63}-1, which is about 9\cdot 10^{18}, so it is large enough.
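
The arithmetic can be double-checked with a few lines of Python, whose integers do not overflow. The constant names are hypothetical; the limits on N and A_i are the ones quoted above, and `int` is assumed to be 32 bits wide, as it is on typical judges.

```python
MAX_N = 2 * 10**5    # largest array length quoted above
MAX_A = 10**9        # largest element value quoted above
INT32_MAX = 2**31 - 1
INT64_MAX = 2**63 - 1

max_sum = MAX_N * MAX_A  # worst-case sum of the array

print(max_sum)
print(max_sum <= INT32_MAX)  # does the sum fit in a 32-bit int?
print(max_sum <= INT64_MAX)  # does the sum fit in a 64-bit long long?
```

The worst-case sum is 200000000000000, which fails the first check and passes the second.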