SATSUB - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: harsh_h
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

Easy-Medium

PREREQUISITES:

Segment trees

PROBLEM:

An array is said to be saturated if its sum equals its maximum subarray sum.

You’re given an array A. Answer Q queries on it: given L and R, count the number of subarrays of [A_L, A_{L+1}, \ldots, A_R] that are saturated.

EXPLANATION:

First, let’s characterize saturated arrays.

Consider some array B.
If some prefix of B has negative sum, cutting out this prefix will leave us with a suffix that has sum strictly larger than that of B - so B cannot be saturated.
Similarly, if B has a suffix with negative sum, it cannot be saturated.

These conditions are not just necessary: they’re also sufficient.
That is,

B is saturated if and only if every prefix and every suffix of B have non-negative sum.

Necessity has already been shown. For sufficiency, suppose every prefix and every suffix of B has sum \geq 0, and take any subarray B[l\ldots r]. The sum of B equals the sum of the prefix B[1\ldots l-1], plus the sum of B[l\ldots r], plus the sum of the suffix starting at r+1. The first and last terms are non-negative, so the sum of B is not less than that of B[l\ldots r].
Since the subarray was arbitrary, no subarray sum exceeds the array’s sum, which proves our claim.
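
The characterization is easy to sanity-check on small arrays. Below is a minimal sketch that compares the definition against the prefix/suffix condition. The function names are purely illustrative, and the sketch assumes that the maximum subarray sum allows the empty subarray (so it is never negative); without that assumption the two checks would disagree on arrays with negative total, such as [-1].

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

using ll = long long;

// Definition: the total sum equals the maximum subarray sum.
// Assumption: the empty subarray (sum 0) is allowed, so the maximum is >= 0.
bool saturatedByDefinition(const std::vector<ll>& b) {
    ll total = 0, best = 0, cur = 0;
    for (ll v : b) {
        total += v;
        cur = std::max<ll>(0, cur + v);  // Kadane: best sum of a subarray ending here
        best = std::max(best, cur);
    }
    return total == best;
}

// Characterization: every prefix and every suffix has non-negative sum.
bool saturatedByPrefixSuffix(const std::vector<ll>& b) {
    ll run = 0;
    for (std::size_t i = 0; i < b.size(); ++i) {
        run += b[i];
        if (run < 0) return false;       // a prefix with negative sum
    }
    run = 0;
    for (std::size_t i = b.size(); i-- > 0;) {
        run += b[i];
        if (run < 0) return false;       // a suffix with negative sum
    }
    return true;
}
```

For instance, [1, -1, 2] passes both checks, while [2, -3, 4] fails both because its prefix [2, -3] has sum -1.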


Let’s use this characterization to answer queries.

We’ll process queries offline using a sweepline.
Iterate over L from N down to 1, and define c_i to be the number of saturated subarrays that start at an index \geq L and end at i.
If we’re able to keep the array c updated as we move L downwards, answering a query [L, R] becomes simply finding the sum

c_L + c_{L+1} + \ldots + c_R

Let’s see what changes when we move from L+1 to L.

Only subarrays starting at L matter.
So, for each index i \geq L such that A[L\ldots i] is saturated, we want to add 1 to c_i.

Now, let’s use the characterization of saturated subarrays.
A[L\ldots i] is saturated if and only if:

  1. A_L + A_{L+1} + \ldots + A_j \geq 0 for every L \leq j \leq i, and
  2. A_j + A_{j+1} + \ldots + A_i \geq 0 for every L \leq j \leq i

Let p_i = A_1 + A_2 + \ldots + A_i denote the prefix sum array of A.
The above conditions can be rewritten in terms of p, as:

  1. p_j \geq p_{L-1} for each L \leq j \leq i, and
  2. p_{j-1} \leq p_i for each L \leq j \leq i
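
The two prefix-sum conditions translate directly into a per-subarray check. The sketch below is only an illustration: the helper names are made up, the array is assumed to be 1-indexed with a[0] as an unused placeholder, and the check is linear in the subarray length, so it is meant for verification rather than for answering queries.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using ll = long long;

// Builds p with p[0] = 0 and p[i] = a[1] + ... + a[i].
// Assumption: a is 1-indexed, a[0] is ignored.
std::vector<ll> prefixSums(const std::vector<ll>& a) {
    std::vector<ll> p(a.size(), 0);
    for (std::size_t i = 1; i < a.size(); ++i) p[i] = p[i - 1] + a[i];
    return p;
}

// Checks whether A[L..i] is saturated using only the prefix sums.
bool saturatedViaPrefix(const std::vector<ll>& p, int L, int i) {
    assert(1 <= L && L <= i && i < static_cast<int>(p.size()));
    for (int j = L; j <= i; ++j) {
        if (p[j] < p[L - 1]) return false;  // condition 1: prefix A[L..j] is negative
        if (p[j - 1] > p[i]) return false;  // condition 2: suffix A[j..i] is negative
    }
    return true;
}
```

With A = [1, -1, 2, -3, 4] the prefix sums are p = [0, 1, 0, 2, -1, 3], so A[1..3] is saturated, while A[3..4] is not since p_2 = 0 exceeds p_4 = -1.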

The first condition is in fact fairly easy to deal with: let x \geq L be the first index such that p_x \lt p_{L-1}, that is, the first prefix sum to the right that’s smaller than p_{L-1} (if no such index exists, take x = N+1).
Then, we must have i \lt x, since including index x in the subarray would violate the condition.
This gives us a range of right endpoints to work with.

The second condition is a bit harder to deal with, since it depends on i rather than L.
However, observe that a similar situation applies here: each right endpoint i is “active” only for a range of L. Let k \lt i be the nearest index to its left with p_k \gt p_i; then i is a valid right endpoint exactly while L-1 \gt k, and stops being one once L-1 reaches k.

Let d be a boolean array, where d_i denotes whether index i is currently “active” as a valid right endpoint.
Initially, d_i = 1 for all i.

Now, when we’re at index L,

  1. Deactivate all i such that L-1 is the end of their range, i.e., all i for which L-1 is the nearest index to the left with a larger prefix sum (p_{L-1} \gt p_i).
    This can be done by setting d_i = 0 for these indices.
  2. Let x \geq L be the nearest index such that p_x \lt p_{L-1}.
  3. For each i = L, L+1, L+2, \ldots, x-1, add d_i to c_i.

This way, the c_i values of each element in the range increase by 1 if i is active, and don’t change otherwise.
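
Before optimizing, the sweep can be simulated directly in quadratic time. The following sketch is a slow reference under a few assumptions: the function name is made up, a is 1-indexed with a[0] unused, and queries are inclusive (L, R) pairs. Its deactivation step simply zeroes every i \geq L with p_{L-1} \gt p_i, which also re-zeroes indices that were already inactive; that is harmless and avoids precomputing nearest larger elements.

```cpp
#include <cassert>
#include <utility>
#include <vector>

using ll = long long;

std::vector<ll> naiveSweep(const std::vector<ll>& a,
                           const std::vector<std::pair<int, int>>& queries) {
    int n = static_cast<int>(a.size()) - 1;
    std::vector<ll> p(n + 1, 0);
    for (int i = 1; i <= n; ++i) p[i] = p[i - 1] + a[i];

    // Bucket query ids by their left endpoint so they are answered during the sweep.
    std::vector<std::vector<int>> byLeft(n + 1);
    for (int id = 0; id < static_cast<int>(queries.size()); ++id) {
        assert(1 <= queries[id].first && queries[id].first <= queries[id].second);
        assert(queries[id].second <= n);
        byLeft[queries[id].first].push_back(id);
    }

    std::vector<ll> c(n + 1, 0), ans(queries.size(), 0);
    std::vector<int> d(n + 1, 1);  // every right endpoint starts active
    for (int L = n; L >= 1; --L) {
        // Step 1: p[L-1] now lies inside the window of every i >= L.
        for (int i = L; i <= n; ++i)
            if (p[L - 1] > p[i]) d[i] = 0;
        // Step 2: x is the first index >= L with p[x] < p[L-1], or n+1 if none.
        int x = L;
        while (x <= n && p[x] >= p[L - 1]) ++x;
        // Step 3: every active i in [L, x-1] gains the subarray A[L..i].
        for (int i = L; i < x; ++i) c[i] += d[i];
        // Answer the queries whose left endpoint is L.
        for (int id : byLeft[L]) {
            ll sum = 0;
            for (int i = L; i <= queries[id].second; ++i) sum += c[i];
            ans[id] = sum;
        }
    }
    return ans;
}
```

On A = [1, -1, 2, -3, 4], the query (1, 5) gives 4: the saturated subarrays are [1], [1, -1, 2], [2] and [4].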


Now, we need to be able to deal with these operations quickly enough.

One way is to use a segment tree with lazy propagation.
In each node of the tree, store two values: the sum of c_i values for the range it represents, and the sum of d_i values for the range it represents.

Operating on a range is now simply increasing the sum of c_i by the sum of d_i, and updates can be accumulated in a node by counting how many times they need to be performed.
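Apart from that, we have point updates to the d_i, which can be handled in the same tree as long as pending range updates are pushed down before d_i changes, so that earlier updates still use the old value.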
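
As a rough illustration of this idea, here is a compact sketch of such a tree. The struct and method names are hypothetical and it is not either of the implementations given at the end; positions are assumed to be 0-indexed, with every d_i starting at 1 and every c_i at 0.

```cpp
#include <cassert>
#include <vector>

using ll = long long;

struct ActiveSumTree {
    int n;
    // pending[node] = number of range updates not yet pushed to the children.
    std::vector<ll> sumC, sumD, pending;

    explicit ActiveSumTree(int size)
        : n(size), sumC(4 * size, 0), sumD(4 * size, 0), pending(4 * size, 0) {
        assert(size >= 1);
        build(1, 0, n - 1);
    }
    // c_i += d_i for every i in [l, r].
    void addActive(int l, int r) { if (l <= r) addActive(l, r, 1, 0, n - 1); }
    // d_pos = 0; c_pos keeps what it has accumulated so far.
    void deactivate(int pos) { deactivate(pos, 1, 0, n - 1); }
    // Sum of c_i over [l, r].
    ll query(int l, int r) { return l <= r ? query(l, r, 1, 0, n - 1) : 0; }

private:
    void build(int node, int lo, int hi) {
        if (lo == hi) { sumD[node] = 1; return; }
        int mid = (lo + hi) / 2;
        build(2 * node, lo, mid);
        build(2 * node + 1, mid + 1, hi);
        sumD[node] = sumD[2 * node] + sumD[2 * node + 1];
    }
    void apply(int node, ll times) {
        sumC[node] += times * sumD[node];  // each active index gains `times`
        pending[node] += times;
    }
    void push(int node) {
        if (pending[node] == 0) return;
        apply(2 * node, pending[node]);
        apply(2 * node + 1, pending[node]);
        pending[node] = 0;
    }
    void pull(int node) {
        sumC[node] = sumC[2 * node] + sumC[2 * node + 1];
        sumD[node] = sumD[2 * node] + sumD[2 * node + 1];
    }
    void addActive(int l, int r, int node, int lo, int hi) {
        if (r < lo || hi < l) return;
        if (l <= lo && hi <= r) { apply(node, 1); return; }
        push(node);
        int mid = (lo + hi) / 2;
        addActive(l, r, 2 * node, lo, mid);
        addActive(l, r, 2 * node + 1, mid + 1, hi);
        pull(node);
    }
    void deactivate(int pos, int node, int lo, int hi) {
        if (lo == hi) { sumD[node] = 0; return; }
        push(node);  // earlier updates must still see the old d value
        int mid = (lo + hi) / 2;
        if (pos <= mid) deactivate(pos, 2 * node, lo, mid);
        else deactivate(pos, 2 * node + 1, mid + 1, hi);
        pull(node);
    }
    ll query(int l, int r, int node, int lo, int hi) {
        if (r < lo || hi < l) return 0;
        if (l <= lo && hi <= r) return sumC[node];
        push(node);
        int mid = (lo + hi) / 2;
        return query(l, r, 2 * node, lo, mid) + query(l, r, 2 * node + 1, mid + 1, hi);
    }
};
```

The important detail is the push inside deactivate: updates that were applied lazily above the leaf are flushed with the old d value before it is set to 0.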

Finally, we need to be able to quickly find the ranges we’re querying/updating.
This is a standard problem: for each p_i, we want the nearest prefix sum to its left that’s larger than it, and the nearest prefix sum to its right that’s smaller than it. Both can be found for all indices in linear time with a stack.
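
A minimal sketch of the stack technique follows, with hypothetical function names. Both functions use strict comparisons, matching the conditions above, and return a sentinel (the array size, or -1) when no such index exists.

```cpp
#include <cassert>
#include <vector>

using ll = long long;

// res[i] = first index j > i with p[j] < p[i], or p.size() if none.
std::vector<int> nextStrictlySmaller(const std::vector<ll>& p) {
    int m = static_cast<int>(p.size());
    std::vector<int> res(m, m), stk;
    for (int i = 0; i < m; ++i) {
        // i is the answer for every stacked index holding a larger value.
        while (!stk.empty() && p[i] < p[stk.back()]) {
            res[stk.back()] = i;
            stk.pop_back();
        }
        stk.push_back(i);
    }
    return res;
}

// res[i] = nearest index j < i with p[j] > p[i], or -1 if none.
std::vector<int> prevStrictlyGreater(const std::vector<ll>& p) {
    int m = static_cast<int>(p.size());
    std::vector<int> res(m, -1), stk;
    for (int i = m - 1; i >= 0; --i) {
        // i is the answer for every stacked index holding a smaller value.
        while (!stk.empty() && p[i] > p[stk.back()]) {
            res[stk.back()] = i;
            stk.pop_back();
        }
        stk.push_back(i);
    }
    return res;
}
```

Each index is pushed and popped at most once, which is where the linear bound comes from. For p = [0, 1, 0, 2, -1, 3], the first function returns [4, 2, 4, 4, 6, 6] and the second returns [-1, -1, 1, -1, 3, -1].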

TIME COMPLEXITY:

\mathcal{O}((N + Q)\log N) per testcase.

CODE:

Tester's code (C++)
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

using namespace std;
using namespace __gnu_pbds;

template<typename T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
typedef long long int ll;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;

#define fastio ios_base::sync_with_stdio(false); cin.tie(NULL)
#define pb push_back
#define endl '\n'
#define sz(a) (int)a.size()
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
#define conts continue
#define ceil2(x,y) ((x+y-1)/(y))
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
#define yes cout << "Yes" << endl
#define no cout << "No" << endl

#define rep(i,n) for(int i = 0; i < n; ++i)
#define rep1(i,n) for(int i = 1; i <= n; ++i)
#define rev(i,s,e) for(int i = s; i >= e; --i)
#define trav(i,a) for(auto &i : a)

template<typename T>
void amin(T &a, T b) {
    a = min(a,b);
}

template<typename T>
void amax(T &a, T b) {
    a = max(a,b);
}

#ifdef LOCAL
#include "debug.h"
#else
#define debug(...) 42
#endif

/*



*/

const int MOD = 1e9 + 7;
const int N = 1e5 + 5;
const int inf1 = int(1e9) + 5;
const ll inf2 = ll(1e18) + 5;

template<typename T>
struct lazysegtree {
    /*=======================================================*/

    struct data {
        ll sum,active;
    };

    struct lazy {
        ll a;
    };

    data d_neutral = {0,0};
    lazy l_neutral = {0};

    void merge(data &curr, data &left, data &right) {
        curr.sum = left.sum+right.sum;
        curr.active = left.active+right.active;
    }

    void create(int x, int lx, int rx, T v) {
        tr[x].sum = 0;
        tr[x].active = 1;
    }

    void modify(int x, int lx, int rx, T v) {
        if(v.ff == 1){
            lz[x].a = v.ss;
        }
        else{
            tr[x].active = 0;
        }
    }

    void propagate(int x, int lx, int rx) {
        ll v = lz[x].a;
        if(!v) return;
        tr[x].sum += v*tr[x].active;

        if(rx-lx > 1){
            lz[x<<1].a += v;
            lz[x<<1|1].a += v;
        }

        lz[x] = l_neutral;
    }

    /*=======================================================*/

    int siz = 1;
    vector<data> tr;
    vector<lazy> lz;

    lazysegtree() {

    }

    lazysegtree(int n) {
        while (siz < n) siz *= 2;
        tr.assign(2 * siz, d_neutral);
        lz.assign(2 * siz, l_neutral);
    }

    void build(int n, int x, int lx, int rx) {
        if (rx - lx == 1) {
            if (lx < n) {
                create(x, lx, rx, {0,0});
            }

            return;
        }

        int mid = (lx + rx) >> 1;

        build(n, x<<1, lx, mid);
        build(n, x<<1|1, mid, rx);

        merge(tr[x], tr[x<<1], tr[x<<1|1]);
    }

    void build(int n) {
        build(n, 1, 0, siz);
    }

    void rupd(int l, int r, T v, int x, int lx, int rx) {
        propagate(x, lx, rx);

        if (lx >= r or rx <= l) return;
        if (lx >= l and rx <= r) {
            modify(x, lx, rx, v);
            propagate(x, lx, rx);
            return;
        }

        int mid = (lx + rx) >> 1;

        rupd(l, r, v, x<<1, lx, mid);
        rupd(l, r, v, x<<1|1, mid, rx);

        merge(tr[x], tr[x<<1], tr[x<<1|1]);
    }

    void rupd(int l, int r, T v) {
        rupd(l, r + 1, v, 1, 0, siz);
    }

    data query(int l, int r, int x, int lx, int rx) {
        propagate(x, lx, rx);

        if (lx >= r or rx <= l) return d_neutral;
        if (lx >= l and rx <= r) return tr[x];

        int mid = (lx + rx) >> 1;

        data curr;
        data left = query(l, r, x<<1, lx, mid);
        data right = query(l, r, x<<1|1, mid, rx);

        merge(curr, left, right);
        return curr;
    }

    data query(int l, int r) {
        return query(l, r + 1, 1, 0, siz);
    }
};

void solve(int test_case)
{
    ll n,q; cin >> n >> q;
    vector<ll> a(n+5);
    rep1(i,n) cin >> a[i];
    vector<ll> p(n+5);
    rep1(i,n) p[i] = p[i-1]+a[i];

    vector<ll> rx(n+5,n), lx(n+5,1);

    {
        stack<ll> stk;
        rep(i,n+1){
            while(!stk.empty() and p[i] < p[stk.top()]){
                rx[stk.top()] = i-1;
                stk.pop();
            }
            stk.push(i);
        }
    }

    {
        stack<ll> stk;
        rev(i,n,0){
            while(!stk.empty() and p[i] > p[stk.top()]){
                lx[stk.top()] = i+1;
                stk.pop();
            }
            stk.push(i);
        }
    }

    rep1(i,n) lx[i]--;
    vector<pll> queries[n+5];

    rep1(id,q){
        ll l,r; cin >> l >> r;
        queries[r].pb({l,id});

        /*

        ll ans = 0;
        for(int i = l-1; i < r; ++i){
            for(int j = i+1; j <= r; ++j){
                if(rx[i] >= j and lx[j] <= i){
                    ans++;
                }
            }
        }
        cout << ans << endl;
        
        */
    }

    vector<ll> leave[n+5];
    rep(i,n) leave[rx[i]+1].pb(i);

    lazysegtree<pll> st(n+5);
    st.build(n+1);
    vector<ll> ans(q+5);

    rep1(r,n){
        trav(l,leave[r]){
            st.rupd(l,l,{2,1});
        }    

        st.rupd(lx[r],r-1,{1,1});

        for(auto [l,id] : queries[r]){
            ans[id] = st.query(l-1,r-1).sum;
        }
    }

    rep1(i,q) cout << ans[i] << endl;
}

int main()
{
    fastio;

    int t = 1;
    cin >> t;

    rep1(i, t) {
        solve(i);
    }

    return 0;
}
Editorialist's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

struct Node {
    using T = array<ll, 2>;
    T unit = {0, 0};
    T f(T a, T b) {
        return {a[0] + b[0], a[1] + b[1]};
    }
 
    Node *l = 0, *r = 0;
    int lo, hi;
    ll mul = 0;
    T val = unit;
    Node(int _lo,int _hi):lo(_lo),hi(_hi){}
    T query(int L, int R) {
        if (R <= lo || hi <= L) return unit;
        if (L <= lo && hi <= R) return val;
        push();
        return f(l->query(L, R), r->query(L, R));
    }
    void add(int L, int R, ll x) {
        if (R <= lo || hi <= L) return;
        if (L <= lo && hi <= R) {
            mul += x;
            val[0] += x * val[1];
        }
        else {
            push(), l->add(L, R, x), r->add(L, R, x);
            val = f(l->val, r->val);
        }
    }
    void upd(int pos, int x) {
        if (lo > pos or hi <= pos) return;
        if (lo+1 == hi) {
            val[1] = x;
            return;
        }
        push();
        l->upd(pos, x), r->upd(pos, x);
        val = f(l->val, r->val);
    }
    void push() {
        if (!l) {
            int mid = lo + (hi - lo)/2;
            l = new Node(lo, mid); r = new Node(mid, hi);
        }
        if (mul)
            l->add(lo,hi,mul), r->add(lo,hi,mul), mul = 0;
    }
};


int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    int t; cin >> t;
    while (t--) {
        int n, q; cin >> n >> q;

        vector<ll> a(n+1), p(n+1);
        for (int i = 1; i <= n; ++i) {
            cin >> a[i];
            p[i] = a[i] + p[i-1];
        }

        vector<int> lt(n+1, -1), rt(n+1, n+1);
        stack<int> s;
        for (int i = 0; i <= n; ++i) {
            while (!s.empty()) {
                int u = s.top();
                if (p[i] >= p[u]) s.pop();
                else break;
            }
            if (!s.empty()) lt[i] = s.top();
            s.push(i);
        }
        s = stack<int>();
        for (int i = n; i >= 0; --i) {
            while (!s.empty()) {
                int u = s.top();
                if (p[i] <= p[u]) s.pop();
                else break;
            }
            if (!s.empty()) rt[i] = s.top();
            s.push(i);
        }
        vector<vector<int>> deactivate(n+1);
        for (int i = 1; i <= n; ++i) if (lt[i] >= 0)
            deactivate[lt[i]].push_back(i);

        vector<vector<array<ll, 2>>> queries(n+1);
        vector<ll> ans(q);
        for (int i = 0; i < q; ++i) {
            int L, R; cin >> L >> R;
            queries[L].push_back({R, i});
        }
        
        Node *seg = new Node(0, n+1);
        for (int i = n; i >= 1; --i) {
            int r = rt[i-1];
            
            seg->upd(i, 1);
            for (int j : deactivate[i-1]) {
                seg->upd(j, 0);
            }
            seg->add(i, r, 1);
            
            for (auto [R, id] : queries[i])
                ans[id] = seg->query(i, R+1)[0];
        }
        
        for (auto x : ans) cout << x << '\n';
    }
}