CNTEMPTY - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: raysh07
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

Easy

PREREQUISITES:

Prefix sums, and either stacks or sets with binary search

PROBLEM:

You have a string S containing only the characters A and B.
In one move, you can choose an alternating subsequence and delete it.

Across all substrings of S, sum up the minimum number of moves needed to make the substring empty.

EXPLANATION:

As detailed in the easy version’s solution, the answer for a fixed string is obtained by replacing occurrences of A and B by +1 and -1 respectively, then computing the maximum absolute subarray sum of the resulting array.
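
For reference, here is a minimal Python sketch of the per-string answer. It takes the easy version's result as given rather than re-deriving it: map A to +1 and B to -1, then run Kadane's algorithm for both the largest and the smallest subarray sum. The name `moves_for_string` is hypothetical and does not come from the official solutions.

```python
def moves_for_string(s: str) -> int:
    """Largest absolute subarray sum after mapping A -> +1, B -> -1.

    Assumes (from the easy version) that this equals the minimum number of moves.
    """
    best_max = best_min = 0  # the empty string needs 0 moves
    cur_max = cur_min = 0
    for c in s:
        x = 1 if c == 'A' else -1
        # Kadane: best (largest / smallest) sum of a subarray ending at this position
        cur_max = max(x, cur_max + x)
        cur_min = min(x, cur_min + x)
        best_max = max(best_max, cur_max)
        best_min = min(best_min, cur_min)
    return max(best_max, -best_min)
```

For example, "AAB" maps to [+1, +1, -1], whose largest absolute subarray sum is 2 (two moves: delete AB, then A).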

We now need to sum this value over all subarrays of the +1/-1 array.

Having to take the larger of the maximum subarray sum and the (negated) minimum subarray sum makes this look hard, but let's analyze what that quantity actually is for a single array.

Let a be our transformed array of +1/-1 values, and p be its prefix sum array, so
p_i = a_1 + a_2 + \ldots + a_i.
We’ll treat p as a 0-indexed array of length n+1, with p_0 = 0.
The sum of the subarray [l, r] is then p_r - p_{l-1}.

Let mx = \max(p) and mn = \min(p).
The maximum absolute subarray sum is then exactly (mx - mn), because:

  • Every subarray sum is a difference p_r - p_{l-1} of two values of p, so its absolute value can never exceed (mx - mn). If we can attain this upper bound, it is clearly optimal.
  • Say mn = p_x and mx = p_y. If x \lt y (the minimum occurs to the left of the maximum), the subarray [x+1, y] has sum p_y - p_x = mx - mn, so the maximum subarray sum equals (mx - mn).
  • If x \gt y (the minimum occurs to the right of the maximum), the subarray [y+1, x] has sum mn - mx, so the minimum subarray sum equals (mn - mx), which is (mx - mn) in absolute value.
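
The claim can be sanity-checked by brute force. The sketch below uses hypothetical helper names and the same A → +1, B → -1 mapping. For every A/B string of length at most 8, it compares the largest |p_r - p_{l-1}| over all subarrays with max(p) - min(p). It is a test harness, not part of the solution.

```python
from itertools import product


def prefix_sums(s: str):
    """p[0] = 0 and p[i] = a_1 + ... + a_i, with A -> +1 and B -> -1."""
    p = [0]
    for c in s:
        p.append(p[-1] + (1 if c == 'A' else -1))
    return p


def abs_max_subarray_brute(s: str) -> int:
    """Largest |p_r - p_{l-1}| over all non-empty subarrays [l, r] (0 if s is empty)."""
    p = prefix_sums(s)
    n = len(s)
    return max((abs(p[r] - p[l - 1]) for l in range(1, n + 1) for r in range(l, n + 1)),
               default=0)


def abs_max_via_extremes(s: str) -> int:
    """The claim from the text: the answer is max(p) - min(p)."""
    p = prefix_sums(s)
    return max(p) - min(p)


# Exhaustive check over every A/B string of length at most 8.
for n in range(9):
    for chars in product("AB", repeat=n):
        s = "".join(chars)
        assert abs_max_subarray_brute(s) == abs_max_via_extremes(s), s
```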

So, our problem has really reduced to computing this quantity:

\sum_{l=1}^N \sum_{r=l}^N \left( \max(p_{l-1}, p_l, \ldots, p_r) - \min(p_{l-1}, p_l, \ldots, p_r) \right)
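
A direct O(N^2) evaluation of this double sum makes a handy reference when testing faster solutions. In this sketch (the name `total_via_formula` is hypothetical), for each l we extend r one step at a time and maintain the running maximum and minimum of p_{l-1}, \ldots, p_r.

```python
def total_via_formula(s: str) -> int:
    """Evaluate the double sum directly in O(N^2) time."""
    n = len(s)
    p = [0]
    for c in s:
        p.append(p[-1] + (1 if c == 'A' else -1))
    total = 0
    for l in range(1, n + 1):
        hi = lo = p[l - 1]          # the window always starts at p[l-1]
        for r in range(l, n + 1):
            hi = max(hi, p[r])      # running maximum of p[l-1..r]
            lo = min(lo, p[r])      # running minimum of p[l-1..r]
            total += hi - lo
    return total
```

For S = AB the substrings A, B and AB each need one move, and total_via_formula("AB") returns 3 accordingly.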

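We can compute the sum of the maximums and the sum of the minimums separately, then subtract the latter from the former.
Note that the windows (p_{l-1}, \ldots, p_r) above are exactly the subarrays of p with at least two elements. A single-element subarray has its maximum equal to its minimum and contributes 0, so we may equally well sum over all subarrays of p, including those of length 1.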

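Computing the sum of the maximums over all subarrays of an array is a fairly standard task.
The main observation is the following. Fix an index i. Let L \lt i be the closest index such that p_L \gt p_i (or L = -1 if there is none), and let R \gt i be the closest index such that p_R \geq p_i (or R = n+1 if there is none).
Then p_i is counted as the maximum of exactly those subarrays [l, r] of p such that L \lt l \leq i \leq r \lt R.
The strict inequality on the left and the non-strict one on the right matter when the maximum value occurs several times in a subarray: only its rightmost occurrence is credited, so every subarray is counted exactly once.
Since l and r can be chosen independently, there are (i - L) \cdot (R - i) such subarrays, so index i contributes p_i \cdot (i - L) \cdot (R - i) to the sum of maximums.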
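
The counting argument, including the tie-breaking, can be checked with a deliberately slow sketch. It finds L and R by scanning outward from each index (O(N^2) overall) and compares the resulting sum of maxima against a brute force over all subarrays. Both function names are hypothetical, and -1 and len(p) play the role of a missing L and R.

```python
def sum_of_subarray_maxima_by_contribution(p):
    """Sum of max(p[l..r]) over all 0 <= l <= r < len(p), via per-index contributions.

    L: closest index left of i with p[L] > p[i] (strict), or -1 if none.
    R: closest index right of i with p[R] >= p[i] (non-strict), or len(p) if none.
    With this split, a subarray whose maximum occurs several times is credited
    only to the rightmost occurrence, so nothing is double-counted.
    """
    m = len(p)
    total = 0
    for i in range(m):
        L = i - 1
        while L >= 0 and p[L] <= p[i]:  # skip values that are not strictly greater
            L -= 1
        R = i + 1
        while R < m and p[R] < p[i]:    # skip values that are strictly smaller
            R += 1
        total += p[i] * (i - L) * (R - i)
    return total


def sum_of_subarray_maxima_brute(p):
    """Reference: take the maximum of every subarray explicitly."""
    return sum(max(p[l:r + 1]) for l in range(len(p)) for r in range(l, len(p)))
```

With repeated values, e.g. p = [1, 1], both functions return 3. The subarray [1, 1] is credited only to index 1.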

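So all we really need is to find these indices L and R quickly for every index i.
There are a few different ways to do this:

  1. L is the previous strictly greater element and R is the next greater-or-equal element.
    Finding these for all elements of an array in linear time is a well-known application of (monotonic) stacks.
  2. Alternatively, you can maintain a sorted set of indices, initially containing the sentinels -1 and n+1, and process the indices of p in descending order of value. When processing an index, insert it into the set and then find its predecessor and successor using binary search; these are L and R respectively.
    If equal values are processed in increasing order of index, the predecessor is the previous greater-or-equal element and the successor is the next strictly greater one. This is the mirror image of the tie-breaking above, and it also counts every subarray exactly once.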
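
Here is a minimal sketch of the first approach, in which a single monotonic stack pass gives both L and R. It assumes the same tie-breaking as above. It sums maxima over all subarrays of p, including those of length 1. That is harmless because single-element windows cancel between the maxima and the minima. The sum of minima is obtained as minus the sum of maxima of -p. The function names are hypothetical. The author's C++ code below follows the same idea with the opposite tie-breaking and subtracts the length-1 terms explicitly.

```python
def sum_of_subarray_maxima(p):
    """O(len(p)) sum of subarray maxima using a monotonic stack.

    left[i]:  previous index with a strictly greater value, or -1
    right[i]: next index with a greater-or-equal value, or len(p)
    """
    m = len(p)
    left = [-1] * m
    right = [m] * m
    stack = []  # indices whose values strictly decrease from bottom to top
    for j in range(m):
        # every index whose value is <= p[j] has just found its "next >=" element
        while stack and p[stack[-1]] <= p[j]:
            right[stack.pop()] = j
        # whatever remains on top is the previous strictly greater element
        left[j] = stack[-1] if stack else -1
        stack.append(j)
    return sum(p[i] * (i - left[i]) * (right[i] - i) for i in range(m))


def cntempty_total(s):
    """Sum over all windows of the prefix-sum array of (max - min)."""
    p = [0]
    for c in s:
        p.append(p[-1] + (1 if c == 'A' else -1))
    # sum of minima of p == -(sum of maxima of -p)
    return sum_of_subarray_maxima(p) + sum_of_subarray_maxima([-x for x in p])
```

For example, for S = AAB the prefix sums are [0, 1, 2, 1]. The sum of maxima is 15 and the sum of minima is 7, giving 8. This matches the per-substring answers 1 + 1 + 1 + 2 + 1 + 2 for A, A, B, AA, AB, AAB.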

TIME COMPLEXITY:

\mathcal{O}(N) per testcase using stacks, or \mathcal{O}(N\log N) per testcase using a sorted set.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

void Solve() 
{
    int n; cin >> n;
    string s; cin >> s;
    
    vector <int> p(n + 1, 0);
    for (int i = 1; i <= n; i++){
        p[i] = p[i - 1] + (s[i - 1] == 'A') - (s[i - 1] == 'B');
    }
    
    auto f = [&](vector <int> a){
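        // sum of max over all subarrays, excluding length 1
        // nx[i]: first index j > i with a[j] > a[i] (n + 1 if none)
        // pv[i]: last index j < i with a[j] >= a[i] (-1 if none)
        // so each subarray's maximum is credited to its leftmost occurrence
        vector <int> nx(n + 1, n + 1), pv(n + 1, -1);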
        stack <pair<int, int>> st;
        for (int i = 0; i <= n; i++){
            while (!st.empty() && st.top().first < a[i]){
                nx[st.top().second] = i;
                st.pop();
            }
            
            st.push({a[i], i});
        }
        
        while (!st.empty()) st.pop();
        
        for (int i = n; i >= 0; i--){
            while (!st.empty() && st.top().first <= a[i]){
                pv[st.top().second] = i;
                st.pop();
            }
            
            st.push({a[i], i});
        }
        
        int ans = 0;
        for (int i = 0; i <= n; i++){
            ans += (i - pv[i]) * (nx[i] - i) * a[i];
            ans -= a[i];
        }
        
        return ans;
    };
    
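    int ans = f(p);            // sum of maxima over subarrays of p of length >= 2
    for (auto &x : p){
        x *= -1;
    }
    ans += f(p);               // f(-p) = -(sum of minima), so ans = sum of (max - min)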
    cout << ans << '\n';
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

using namespace std;
using namespace __gnu_pbds;

template<typename T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
typedef long long int ll;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;

#define fastio ios_base::sync_with_stdio(false); cin.tie(NULL)
#define pb push_back
#define endl '\n'
#define sz(a) (int)a.size()
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
#define conts continue
#define ceil2(x,y) ((x+y-1)/(y))
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
#define yes cout << "YES" << endl
#define no cout << "NO" << endl

#define rep(i,n) for(int i = 0; i < n; ++i)
#define rep1(i,n) for(int i = 1; i <= n; ++i)
#define rev(i,s,e) for(int i = s; i >= e; --i)
#define trav(i,a) for(auto &i : a)

template<typename T>
void amin(T &a, T b) {
    a = min(a,b);
}

template<typename T>
void amax(T &a, T b) {
    a = max(a,b);
}

#ifdef LOCAL
#include "debug.h"
#else
#define debug(...) 42
#endif

/*



*/

const int MOD = 1e9 + 7;
const int N = 1e5 + 5;
const int inf1 = int(1e9) + 5;
const ll inf2 = ll(1e18) + 5;

void solve(int test_case){
    ll n; cin >> n;
    string s; cin >> s;
    s = "$" + s;
    vector<ll> a(n+5);

    auto go = [&](){
        vector<ll> p(n+5);
        rep1(i,n) p[i] = p[i-1]+a[i];

        vector<ll> dp(n+5); // dp[i] = if we start at pos = i with x = 0, what is the sum of x over all ranges [i..r]?
        vector<ll> closest(2*n+5,n);
        rep(i,n+1) p[i] += n+1;

        vector<ll> pp(n+5);
        rep1(i,n) pp[i] = pp[i-1]+p[i];

        rev(i,n,1){
            if(a[i] < 0){
                dp[i] = dp[i+1];
            }
            else{
                // closest j to right s.t sum[i..j] = 0
                ll j = closest[p[i-1]];
                ll s = pp[j]-pp[i-1];
                ll add = s-p[i-1]*(j-i+1);
                dp[i] = add+dp[j+1];
            }

            closest[p[i]] = i;
        }

        ll res = accumulate(all(dp),0ll);

        return res;

        /*

        ll res = 0;

        rep1(i,n){
            ll s = 0;
            for(int j = i; j <= n; ++j){
                s += a[j];
                amax(s,0ll);
                res += s;
            }
        }

        return res;

        */
    };

    rep1(i,n){
        if(s[i] == 'A'){
            a[i] = 1;
        }
        else{
            a[i] = -1;
        }
    }

    ll ans = go();
    rep1(i,n) a[i] *= -1;
    ans += go();

    cout << ans << endl;
}

int main()
{
    fastio;

    int t = 1;
    cin >> t;

    rep1(i, t) {
        solve(i);
    }

    return 0;
}
Editorialist's code (PyPy3)
"""
The "sorted list" data-structure, with amortized O(n^(1/3)) cost per insert and pop.

Example:

A = SortedList()
A.insert(30)
A.insert(50)
A.insert(20)
A.insert(30)
A.insert(30)

print(A) # prints [20, 30, 30, 30, 50]

print(A.lower_bound(30), A.upper_bound(30)) # prints 1 4

print(A[-1]) # prints 50
print(A.pop(1)) # prints 30

print(A) # prints [20, 30, 30, 50]
print(A.count(30)) # prints 2

"""

from bisect import bisect_left as lower_bound
from bisect import bisect_right as upper_bound


class FenwickTree:
    def __init__(self, x):
        bit = self.bit = list(x)
        size = self.size = len(bit)
        for i in range(size):
            j = i | (i + 1)
            if j < size:
                bit[j] += bit[i]

    def update(self, idx, x):
        """updates bit[idx] += x"""
        while idx < self.size:
            self.bit[idx] += x
            idx |= idx + 1

    def __call__(self, end):
        """calc sum(bit[:end])"""
        x = 0
        while end:
            x += self.bit[end - 1]
            end &= end - 1
        return x

    def find_kth(self, k):
        """Find largest idx such that sum(bit[:idx]) <= k"""
        idx = -1
        for d in reversed(range(self.size.bit_length())):
            right_idx = idx + (1 << d)
            if right_idx < self.size and self.bit[right_idx] <= k:
                idx = right_idx
                k -= self.bit[idx]
        return idx + 1, k


class SortedList:
    block_size = 700

    def __init__(self, iterable=()):
        iterable = sorted(iterable)
        self.micros = [iterable[i:i + self.block_size - 1] for i in range(0, len(iterable), self.block_size - 1)] or [[]]
        self.macro = [i[0] for i in self.micros[1:]]
        self.micro_size = [len(i) for i in self.micros]
        self.fenwick = FenwickTree(self.micro_size)
        self.size = len(iterable)

    def insert(self, x):
        i = lower_bound(self.macro, x)
        j = upper_bound(self.micros[i], x)
        self.micros[i].insert(j, x)
        self.size += 1
        self.micro_size[i] += 1
        self.fenwick.update(i, 1)
        if len(self.micros[i]) >= self.block_size:
            self.micros[i:i + 1] = self.micros[i][:self.block_size >> 1], self.micros[i][self.block_size >> 1:]
            self.micro_size[i:i + 1] = self.block_size >> 1, self.block_size >> 1
            self.fenwick = FenwickTree(self.micro_size)
            self.macro.insert(i, self.micros[i + 1][0])

    def pop(self, k=-1):
        i, j = self._find_kth(k)
        self.size -= 1
        self.micro_size[i] -= 1
        self.fenwick.update(i, -1)
        return self.micros[i].pop(j)

    def __getitem__(self, k):
        i, j = self._find_kth(k)
        return self.micros[i][j]

    def count(self, x):
        return self.upper_bound(x) - self.lower_bound(x)

    def __contains__(self, x):
        return self.count(x) > 0

    def lower_bound(self, x):
        i = lower_bound(self.macro, x)
        return self.fenwick(i) + lower_bound(self.micros[i], x)

    def upper_bound(self, x):
        i = upper_bound(self.macro, x)
        return self.fenwick(i) + upper_bound(self.micros[i], x)

    def _find_kth(self, k):
        return self.fenwick.find_kth(k + self.size if k < 0 else k)

    def __len__(self):
        return self.size

    def __iter__(self):
        return (x for micro in self.micros for x in micro)

    def __repr__(self):
        return str(list(self))

for _ in range(int(input())):
    n = int(input())
    s = input()
    p = [0]
    for c in s:
        x = 1 if c == 'A' else -1
        p.append(p[-1] + x)
    
    ans = 0
    ind = list(range(0, n+1))
    ind.sort(key = lambda x: -p[x])
    S = SortedList([-1, n+1])
    for i in ind:
        S.insert(i)
        pos = S.lower_bound(i)
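        L, R = S[pos-1], S[pos+1]
        # p[i] is the maximum of (i - L) * (R - i) windows; the two terms below
        # sum to that count minus 1, excluding the length-1 window [i, i]
        ans += p[i] * (i - L - 1) * (R - i)
        ans += p[i] * (R - i - 1)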
    
    S = SortedList([-1, n+1])
    for i in reversed(ind):
        S.insert(i)
        pos = S.lower_bound(i)
        L, R = S[pos-1], S[pos+1]
        # same window count as above, now with p[i] as the minimum
        ans -= p[i] * (i - L - 1) * (R - i)
        ans -= p[i] * (R - i - 1)
    
    print(ans)

Great trick for handling repeated segments when equal numbers occur. Thanks!