LNDS - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: satyam_343
Tester: apoorv_me
Editorialist: iceknight1093

DIFFICULTY:

2839

PREREQUISITES:

Dynamic programming or divide and conquer, prefix sums

PROBLEM:

Given a binary string S of length N, let f(i, j) denote the length of the longest non-decreasing subsequence (LNDS) of the substring S[i\ldots j].
Find

\sum_{i=1}^N\sum_{j=i}^N f(i, j)
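A naive reference is handy for checking faster solutions on small inputs. It computes f(i, j) directly from the fact that a non-decreasing subsequence of a binary string is a block of zeros followed by a block of ones. The sketch below is an illustrative helper: the name `lndsSumNaive` is hypothetical and it uses 0-indexed positions. It runs in \mathcal{O}(N^3) time, so it is only meant for cross-checking.

```cpp
#include <algorithm>
#include <cassert>
#include <string>

// Naive O(N^3) reference: sum of f(i, j) over all substrings of a binary string.
// f(i, j) = max over split points m of (zeros in s[i..m]) + (ones in s[m+1..j]),
// where either part may be empty.
long long lndsSumNaive(const std::string& s) {
    const int n = static_cast<int>(s.size());
    long long total = 0;
    for (int i = 0; i < n; ++i) {
        for (int j = i; j < n; ++j) {
            // Split before i: the zeros part is empty, take every 1 in s[i..j].
            int zerosLeft = 0;
            int onesRight = static_cast<int>(std::count(s.begin() + i, s.begin() + j + 1, '1'));
            int best = onesRight;
            for (int m = i; m <= j; ++m) {  // move s[m] from the "ones" part to the "zeros" part
                if (s[m] == '0') ++zerosLeft; else --onesRight;
                best = std::max(best, zerosLeft + onesRight);
            }
            total += best;
        }
    }
    return total;
}
```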

EXPLANATION:

There are a couple of different solutions to this problem; I’ll present two of them below.

Solution 1

\mathcal{O}(N) Dynamic Programming

Let’s define \text{dp}[R] = \sum_{L=1}^R f(L, R).
That is, \text{dp}[R] denotes the sum of lengths of longest non-decreasing subsequences of all subarrays ending at R.
If we can compute \text{dp}[R] for all R, the final answer is just their sum.

First off, if S_R = 1, then \text{dp}[R] = \text{dp}[R-1] + R.
This is because, if S_R = 1, it’s always optimal to use it to extend the LNDS: f(L, R) = f(L, R-1) + 1 for each of the R-1 values L \lt R, and f(R, R) = 1.
Now we only need to deal with S_R = 0.

There are two types of subarrays: those whose LNDS includes the newly added zero, and those whose LNDS doesn’t.
Let’s characterize each type.

Consider a subarray [L, R] whose LNDS doesn’t include S_R, meaning its LNDS ends with a 1.
Suppose the leftmost 1 in its LNDS is at index k, where L \leq k \lt R.
Then, surely the range [k\ldots R] must contain strictly more ones than zeros; otherwise we could just as well take the zeros of [k\ldots R] instead (and hence S_R).
Conversely, if for some L there’s no k \geq L such that [k\ldots R] contains strictly more ones than zeros, then the LNDS of [L, R] can use only zeros, i.e., f(L, R) is the number of zeros in S[L\ldots R].

So, let’s find the rightmost index k such that [k\ldots R] contains strictly more ones than zeros (if there is no such index, set k = 0, with \text{dp}[0] = 0).
Then, note that:

  • For any L \geq k+1, f(L, R) will equal the number of zeros in its subarray, as our observations above show.
    Computing the sum of these over all L in the range can be done in \mathcal{O}(1) with the help of suffix sums.
  • Because k is the rightmost such index, S[k+1\ldots R] contains an equal number of zeros and ones, and every suffix of it contains at least as many zeros as ones; so after position k, the best we can do is take all the ones.
    So, for any L \leq k, f(L, R) equals the LNDS of S[L\ldots k] (the information of which is contained in \text{dp}[k]) plus the number of ones in [k+1, R].
    Summed over all L \leq k, this is \text{dp}[k] plus k times that count of ones, which is \mathcal{O}(1) to compute with prefix counts.
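As a small worked check of the two bullets, write \text{ones}(a, b) and \text{zeros}(a, b) for the number of 1's and 0's in S[a\ldots b]. For S_R = 0 the bullets combine to the following (with k = 0 and \text{dp}[0] = 0 when no valid index exists):

\text{dp}[R] = \text{dp}[k] + k\cdot\text{ones}(k+1, R) + \sum_{L=k+1}^{R}\text{zeros}(L, R)

For example, take S = 110 and R = 3. The rightmost valid index is k = 1, because S[1\ldots 3] = 110 has more ones than zeros while S[2\ldots 3] = 10 and S[3\ldots 3] = 0 do not. Since \text{dp}[1] = 1:

\text{dp}[3] = 1 + 1\cdot 1 + (1 + 1) = 4

This matches f(1, 3) + f(2, 3) + f(3, 3) = 2 + 1 + 1.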

So, as long as we find k quickly, we’re good.
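To do this, treat every 0 as a -1. We’re then looking for the shortest subarray ending at R with sum 1. Each extra element changes a suffix sum by exactly 1, so the shortest suffix with a positive sum has sum exactly 1.
With prefix sums P_0 = 0, P_1, \ldots, P_N of this \pm 1 array, that means finding the largest j \lt R with P_j = P_R - 1, and then k = j + 1. Storing the last index at which each prefix-sum value occurred (e.g., in a map) does this in \mathcal{O}(\log N). A plain array does it in \mathcal{O}(1), since all prefix sums are in the range [-N, N].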
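The following is a minimal sketch of the \mathcal{O}(1) lookup. It assumes 1-indexed positions and uses a hypothetical helper name, `rightmostK`. It scans R from left to right and keeps, in a plain array, the last index at which each prefix-sum value occurred; the array is shifted by an offset so that negative values fit. It reports k = 0 when no valid index exists.

```cpp
#include <cassert>
#include <string>
#include <vector>

// For every R in 1..n (1-indexed), the rightmost k in [1, R] such that S[k..R]
// has strictly more ones than zeros, or 0 if there is none. k[0] is unused.
std::vector<int> rightmostK(const std::string& s) {
    const int n = static_cast<int>(s.size());
    const int offset = n + 1;              // P[R] - 1 >= -n - 1, so index >= 0
    std::vector<int> last(2 * n + 3, -1);  // last[v + offset] = latest j with P[j] = v
    std::vector<int> k(n + 1, 0);
    int p = 0;                             // running prefix sum P[R] of the +1/-1 array
    last[offset] = 0;                      // P[0] = 0
    for (int r = 1; r <= n; ++r) {
        p += (s[r - 1] == '1') ? 1 : -1;
        const int j = last[p - 1 + offset];  // only indices j < r are stored at this point
        k[r] = (j == -1) ? 0 : j + 1;
        last[p + offset] = r;                // overwrite so the latest index wins
    }
    return k;
}
```

For example, for S = 110 it gives k = 1, 2, 1 for R = 1, 2, 3.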

This solves the problem in \mathcal{O}(N) or \mathcal{O}(N\log N) time, depending on implementation.
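Putting the pieces together, a compact sketch of Solution 1 might look like the code below. It is not the author's or tester's code. The name `lndsSumLinear` is hypothetical and positions are 1-indexed. The sum of \text{zeros}(L, R) over L = k+1, \ldots, R comes from a prefix sum of the prefix zero counts; this is a swapped-in choice, and suffix sums work equally well.

```cpp
#include <cassert>
#include <string>
#include <vector>

// O(N) sketch: dp[r] = sum over L of f(L, r) (1-indexed); the answer is the sum of all dp[r].
long long lndsSumLinear(const std::string& s) {
    const int n = static_cast<int>(s.size());
    // ones[t], zeros[t]: counts in S[1..t]; zsum[t] = zeros[0] + ... + zeros[t - 1].
    std::vector<long long> ones(n + 1, 0), zeros(n + 1, 0), zsum(n + 1, 0);
    for (int t = 1; t <= n; ++t) {
        ones[t] = ones[t - 1] + (s[t - 1] == '1');
        zeros[t] = zeros[t - 1] + (s[t - 1] == '0');
        zsum[t] = zsum[t - 1] + zeros[t - 1];
    }
    // Prefix sums P of the +1/-1 array lie in [-n, n]; last[v + offset] is the latest
    // index j with P[j] = v (or -1), and only indices < r are stored when r is queried.
    const int offset = n + 1;
    std::vector<int> last(2 * n + 3, -1);
    last[offset] = 0;  // P[0] = 0
    std::vector<long long> dp(n + 1, 0);
    long long answer = 0;
    int p = 0;
    for (int r = 1; r <= n; ++r) {
        if (s[r - 1] == '1') {
            ++p;
            // The 1 extends the LNDS of each S[L..r-1] with L < r, and f(r, r) = 1.
            dp[r] = dp[r - 1] + r;
        } else {
            --p;
            const int j = last[p - 1 + offset];
            const int k = (j == -1) ? 0 : j + 1;  // rightmost k with more ones than zeros in S[k..r]
            // L <= k: f(L, r) = f(L, k) + ones(k + 1, r).
            dp[r] = dp[k] + k * (ones[r] - ones[k]);
            // L > k: f(L, r) = zeros[r] - zeros[L - 1], summed over L = k + 1..r.
            dp[r] += (r - k) * zeros[r] - (zsum[r] - zsum[k]);
        }
        last[p + offset] = r;
        answer += dp[r];
    }
    return answer;
}
```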

Solution 2

\mathcal{O}(N\log^2 N) Divide & Conquer

Let \text{solve}(L, R) denote the sum of f(i, j) over all subarrays [i, j] contained in S[L\ldots R].
Let M = \lfloor\frac{L+R}{2}\rfloor denote the midpoint of the range.
Recursively compute \text{solve}(L, M) and \text{solve}(M+1, R).
This leaves us with only subarrays crossing the middle, i.e., something of the form [i\ldots M] joined to [M+1 \ldots j].

Let \text{left}[i] denote the subarray S[i\ldots M], and f(i, M) the length of its LNDS.
Similarly, let \text{right}[j] denote the subarray S[M+1\ldots j], and f(M+1, j) the length of its LNDS.
Observe that, for any pair i and j, f(i, j) is the larger of:

  • f(i, M) + \text{ones}(M+1, j); and
  • \text{zeros}(i, M) + f(M+1, j).

Here, \text{ones}(M+1, j) denotes the number of 1's in S[M+1\ldots j], and \text{zeros}(i, M) denotes the number of 0's in S[i\ldots M].

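That is, the LNDS for S[i\ldots j] can be found by either extending an LNDS for the left side to the right by ones, or extending an LNDS for the right side to the left by zeros.
The proof is simple: an LNDS looks like several zeros followed by several ones, so consider where its first 1 lies. If it lies in S[i\ldots M], everything the LNDS takes from the right half is a 1, giving the first expression. Otherwise, everything it takes from the left half is a 0, giving the second. Both expressions are always achievable, so f(i, j) is their maximum.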

Let’s attempt to use this information.

f(i, M) + \text{ones}(M+1, j) \geq \text{zeros}(i, M) + f(M+1, j) \\ \implies f(i, M) - \text{zeros}(i, M) \geq f(M+1, j) - \text{ones}(M+1, j)

So, for a fixed i on the left, sort the indices j on the right by their f(M+1, j) - \text{ones}(M+1, j) values. The set of j for which we choose the first option is then a prefix of that sorted order.
Once this range is found,

  • For everything in the range, the answer is f(i, M) + \text{ones}(M+1, j).
    Summing this across all j in the range is simple: multiply f(i, M) by the size of the range, and add a prefix sum of the \text{ones}(M+1, j) values.
  • For everything outside the range, the answer is f(M+1, j) + \text{zeros}(i, M).
    Again, summing this up is simple: use suffix sums of the f(M+1, j) values, and add \text{zeros}(i, M) times the number of such j.
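Written out for a single i, let J_i be the set of right endpoints j \in [M+1, R] that take the first option, and let c_i = |J_i|. The two bullets then add up to:

\sum_{j=M+1}^{R} f(i, j) = c_i\cdot f(i, M) + \sum_{j \in J_i}\text{ones}(M+1, j) + (R - M - c_i)\cdot\text{zeros}(i, M) + \sum_{j \notin J_i} f(M+1, j)

Over the sorted right side, the first sum is a prefix-sum lookup and the last is a suffix-sum lookup.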

Note that this does require us to know the f(i, M) and f(M+1, j) values.
However, those can be easily computed since one endpoint is fixed.
For example,

  • If S_j = 1, then f(M+1, j) = f(M+1, j-1) + 1.
    If there’s a 1 at the end, it’s optimal to use it.
  • If S_j = 0, then f(M+1, j) = \max(f(M+1, j-1), \text{zeros}(M+1, j))
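    If there’s a zero at the end, either the LNDS uses it (and then consists only of zeros), or it doesn’t use it at all.
The values f(i, M) are computed symmetrically by moving i leftward from M. If S_i = 0, then f(i, M) = f(i+1, M) + 1; if S_i = 1, then f(i, M) = \max(f(i+1, M), \text{ones}(i, M)).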
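Both fixed-endpoint recurrences are short enough to sketch directly. The helpers below are hypothetical and use 0-indexed, inclusive bounds. `lndsFixedLeft` follows the two bullets above, extending j to the right. `lndsFixedRight` applies the mirrored rule, extending i to the left: a leading 0 is always used, while a leading 1 either starts an all-ones subsequence or is skipped.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// f(lo, j) for every j = lo..hi: left endpoint fixed, extend to the right.
std::vector<int> lndsFixedLeft(const std::string& s, int lo, int hi) {
    std::vector<int> f;
    int cur = 0, zeros = 0;
    for (int j = lo; j <= hi; ++j) {
        if (s[j] == '1') {
            ++cur;                       // a trailing 1 always extends the LNDS
        } else {
            ++zeros;
            cur = std::max(cur, zeros);  // use the trailing 0 (all zeros) or ignore it
        }
        f.push_back(cur);
    }
    return f;
}

// f(i, hi) for every i, listed from i = hi down to i = lo: right endpoint fixed, extend to the left.
std::vector<int> lndsFixedRight(const std::string& s, int lo, int hi) {
    std::vector<int> f;
    int cur = 0, ones = 0;
    for (int i = hi; i >= lo; --i) {
        if (s[i] == '0') {
            ++cur;                       // a leading 0 always extends the LNDS
        } else {
            ++ones;
            cur = std::max(cur, ones);   // use the leading 1 (all ones) or ignore it
        }
        f.push_back(cur);
    }
    return f;
}
```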

So, for a single i, we require one binary search to find the appropriate range, taking \mathcal{O}(\log N) time.
Each index is processed \mathcal{O}(\log N) times throughout the entire computation tree, making the overall complexity \mathcal{O}(N\log^2 N).

We also sort the right side each time.
That’s \mathcal{O}(N\log N) for the entire array throughout a single level of the recursion tree, and there are \mathcal{O}(\log N) levels, making the whole thing \mathcal{O}(N\log^2 N) as well.

Just as in the first solution, one log factor can be shaved off by utilizing the fact that we’re dealing with small values (so sorting can be done in linear time and the binary search can be replaced by a direct index lookup).
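Below is a sketch of a linear-time combine step. It assumes that bucketing by value is what replaces the sort: the keys f(M+1, j) - \text{ones}(M+1, j) lie in [0, R - M], so counting them per value and taking prefix sums over values answers each query with one array lookup. This gives \mathcal{O}(N\log N) overall. The function names are hypothetical, and ranges are 0-indexed and half-open as in the editorialist's code.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Sum of f(i, j) over all subarrays of s[lo, hi) (0-indexed, half-open), hi - lo >= 1.
long long lndsSumRange(const std::string& s, int lo, int hi) {
    if (hi - lo == 1) return 1;  // a single character has an LNDS of length 1
    const int mid = (lo + hi) / 2;
    long long ans = lndsSumRange(s, lo, mid) + lndsSumRange(s, mid, hi);

    // Right half: bucket each right end j by v = f(mid, j) - ones(mid, j).
    // Taking all ones is non-decreasing, so v >= 0; also v <= zeros(mid, j) <= len.
    const int len = hi - mid;
    std::vector<long long> cnt(len + 1, 0), sumOnes(len + 1, 0), sumF(len + 1, 0);
    int f = 0, zeros = 0, ones = 0;
    for (int j = mid; j < hi; ++j) {
        if (s[j] == '1') { ++ones; ++f; }
        else { ++zeros; f = std::max(f, zeros); }
        const int v = f - ones;
        cnt[v] += 1;
        sumOnes[v] += ones;
        sumF[v] += f;
    }
    // Prefix sums over bucket values stand in for the sort and the binary search.
    for (int v = 1; v <= len; ++v) {
        cnt[v] += cnt[v - 1];
        sumOnes[v] += sumOnes[v - 1];
        sumF[v] += sumF[v - 1];
    }
    const long long totalF = sumF[len];

    // Left half: move i leftward from mid - 1, maintaining f(i, mid - 1) and zeros(i, mid - 1).
    f = zeros = ones = 0;
    for (int i = mid - 1; i >= lo; --i) {
        if (s[i] == '0') { ++zeros; ++f; }
        else { ++ones; f = std::max(f, ones); }
        // Right ends with v <= f - zeros take f(i, mid - 1) + ones(mid, j);
        // all others take zeros(i, mid - 1) + f(mid, j).
        const int t = std::min(f - zeros, len);
        ans += cnt[t] * f + sumOnes[t];
        ans += (totalF - sumF[t]) + (len - cnt[t]) * zeros;
    }
    return ans;
}

long long lndsSumDivide(const std::string& s) {
    return s.empty() ? 0 : lndsSumRange(s, 0, static_cast<int>(s.size()));
}
```

Ties, where both options give the same length, can go to either side, since both options then contribute the same value.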

TIME COMPLEXITY

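\mathcal{O}(N) per testcase for Solution 1 (\mathcal{O}(N\log N) if k is found with a logarithmic lookup), and \mathcal{O}(N\log^2 N) per testcase for Solution 2 as presented.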

CODE:

Author's code (C++, solution 1)
#pragma GCC optimize("O3,unroll-loops")
#include <bits/stdc++.h>   
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;   
using namespace std;
#define ll long long
const ll INF_MUL=1e15;
const ll INF_ADD=1e18;
#define pb push_back               
#define mp make_pair          
#define nline "\n"                           
#define f first                                          
#define s second                                             
#define pll pair<ll,ll> 
#define all(x) x.begin(),x.end()     
#define vl vector<ll>           
#define vvl vector<vector<ll>>    
#define vvvl vector<vector<vector<ll>>>          
#ifndef ONLINE_JUDGE    
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);  
#endif     
void _print(ll x){cerr<<x;}  
void _print(char x){cerr<<x;}   
void _print(string x){cerr<<x;}    
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());   
template<class T,class V> void _print(pair<T,V> p) {cerr<<"{"; _print(p.first);cerr<<","; _print(p.second);cerr<<"}";}
template<class T>void _print(vector<T> v) {cerr<<" [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T>void _print(set<T> v) {cerr<<" [ "; for (T i:v){_print(i); cerr<<" ";}cerr<<"]";}
template<class T>void _print(multiset<T> v) {cerr<< " [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T,class V>void _print(map<T, V> v) {cerr<<" [ "; for(auto i:v) {_print(i);cerr<<" ";} cerr<<"]";} 
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
//--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------  
const ll MOD=998244353;
const ll MAX=500500;
void solve(){  
    ll n; cin>>n;  
    vector<ll> pref(n+5,0);
    string s; cin>>s; s=" "+s;
    for(ll i=1;i<=n;i++){
        pref[i]=pref[i-1]+(s[i]-'0'); 
    }    
    auto getv=[&](ll l,ll r,ll val){
        if(l>r){  
            return 0ll;  
        }    
        ll len=r-l+1,now=pref[r]-pref[l-1];
        if(val==0){
            now=len-now;
        }  
        return now;  
    };    
    ll ans=0;    
    vector<ll> consider(n+5,0);  
    vector<array<ll,3>> track;
    track.push_back({INF_ADD,n+1,n+1});
    ll sum=0;
    for(ll i=n;i>=1;i--){
        consider[i]=getv(1,i,0)+getv(i,n,1);
        array<ll,3> now={consider[i],i,i};
        sum+=consider[i];
        while(consider[i]>track.back()[0]){
            auto it=track.back();
            sum+=(now[0]-it[0])*(it[2]-it[1]+1);
            now[2]=it[2];
            track.pop_back();
        }
        track.push_back(now);
        ans+=sum;
        ans-=getv(1,i-1,0)*(n-i+1);
        ans-=getv(i+1,n,1)*i;
    }
    cout<<ans<<nline;
    return;
}                                          
int main()                                                                           
{   
    ios_base::sync_with_stdio(false);                         
    cin.tie(NULL);                              
    #ifndef ONLINE_JUDGE                 
    freopen("input.txt", "r", stdin);                                           
    freopen("output.txt", "w", stdout);      
    freopen("error.txt", "w", stderr);                        
    #endif     
    ll test_cases=1;               
    cin>>test_cases;
    while(test_cases--){
        solve();
    }
    cout<<fixed<<setprecision(10);
    cerr<<"Time:"<<1000*((double)clock())/(double)CLOCKS_PER_SEC<<"ms\n"; 
}  
Tester's code (C++, solution 1)
#ifndef LOCAL
#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx,avx2,sse,sse2,sse3,sse4,popcnt,fma")
#endif

#include <bits/stdc++.h>
using namespace std;

#ifdef LOCAL
#include "../debug.h"
#else
#define dbg(...) "11-111"
#endif

struct input_checker {
	string buffer;
	int pos;

	const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
	const string number = "0123456789";
	const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
	const string lower = "abcdefghijklmnopqrstuvwxyz";

	input_checker() {
		pos = 0;
		while (true) {
			int c = cin.get();
			if (c == -1) {
				break;
			}
			buffer.push_back((char) c);
		}
	}

	int nextDelimiter() {
		int now = pos;
		while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
			now++;
		}
		return now;
	}

	string readOne() {
		assert(pos < (int) buffer.size());
		int nxt = nextDelimiter();
		string res;
		while (pos < nxt) {
			res += buffer[pos];
			pos++;
		}
		return res;
	}

	string readString(int minl, int maxl, const string &pattern = "") {
		assert(minl <= maxl);
		string res = readOne();
		assert(minl <= (int) res.size());
		assert((int) res.size() <= maxl);
		for (int i = 0; i < (int) res.size(); i++) {
			assert(pattern.empty() || pattern.find(res[i]) != string::npos);
		}
		return res;
	}

	int readInt(int minv, int maxv) {
		assert(minv <= maxv);
		int res = stoi(readOne());
		assert(minv <= res);
		assert(res <= maxv);
		return res;
	}

	long long readLong(long long minv, long long maxv) {
		assert(minv <= maxv);
		long long res = stoll(readOne());
		assert(minv <= res);
		assert(res <= maxv);
		return res;
	}

	auto readInts(int n, int minv, int maxv) {
		assert(n >= 0);
		vector<int> v(n);
		for (int i = 0; i < n; ++i) {
			v[i] = readInt(minv, maxv);
			if (i+1 < n) readSpace();
		}
		return v;
	}

	auto readLongs(int n, long long minv, long long maxv) {
		assert(n >= 0);
		vector<long long> v(n);
		for (int i = 0; i < n; ++i) {
			v[i] = readLong(minv, maxv);
			if (i+1 < n) readSpace();
		}
		return v;
	}

	void readSpace() {
		assert((int) buffer.size() > pos);
		assert(buffer[pos] == ' ');
		pos++;
	}

	void readEoln() {
		assert((int) buffer.size() > pos);
		assert(buffer[pos] == '\n');
		pos++;
	}

	void readEof() {
		assert((int) buffer.size() == pos);
	}
};

constexpr int N = 1000000;

int32_t main() {
    ios_base::sync_with_stdio(0);   cin.tie(0);

    input_checker input;
    int T = input.readInt(1, N); input.readEoln();
    int sum_testcase = 0;
    while(T-- > 0) {
        int n = input.readInt(1, N); input.readEoln();
        string s = input.readString(n, n, "01");    input.readEoln();
        
        sum_testcase += n;
        vector<int> a(n);
        for(int i = 0 ; i < n ; i++) {
            a[i] = s[i] - '0';
        }

        vector<long long> sf(n + 2), sum(n + 2), ans(n + 2), pf(n + 2);
        
        vector<int> index(2 * n + 5);
        // map<int, int> index;
        int sd = n + 2;
        for(int i = n - 1 ; i >= 0 ; i--) {
            sf[i + 1] = sf[i + 2] + 1 - a[i];
            sum[i + 1] = sum[i + 2] + sf[i + 1];
        }

        for(int i = 0 ; i < n ; i++) {
            pf[i + 1] = pf[i] + a[i];
            if(a[i] == 0) {
                int ind = index[sd + i + 1 - 2 * pf[i + 1]];
                ans[i + 1] += ans[ind] + ind * 1ll * (pf[i + 1] - pf[ind]);
                ans[i + 1] += sum[ind + 1] - sum[i + 1] + 1;
                ans[i + 1] -= sf[i + 2] * 1ll * (i - ind);
            } else {
                ans[i + 1] += ans[i] + i + 1;
            }
            index[sd + i - 2 * pf[i]] = i;
        }

        cout << accumulate(ans.begin(), ans.end(), 0ll) << '\n';
    }
    assert(sum_testcase <= N);
    input.readEof();

    return 0;
}
Editorialist's code (C++, solution 2)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    int t; cin >> t;
    while (t--) {
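        int n; cin >> n;
        // The input is a binary string (see the tester's validator), not space-separated integers.
        string s; cin >> s;
        vector<int> a(n);
        for (int i = 0; i < n; ++i) a[i] = s[i] - '0';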
        ll ans = 0;
        auto calc = [&] (const auto &self, int L, int R) -> void {
            if (L+1 == R) {
                ++ans;
                return;
            }
            int mid = (L + R)/2;
            self(self, L, mid);
            self(self, mid, R);

            // Right half a[mid..R): for each right end j, store {f - ones, ones, f} of a[mid..j].
            int curans = 0, zeros = 0, ones = 0;
            vector<array<ll, 3>> vals;
            for (int i = mid; i < R; ++i) {
                if (a[i] == 1) {
                    ++ones;
                    ++curans;
                }
                else {
                    ++zeros;
                    curans = max(zeros, curans);
                }
                vals.push_back({curans - ones, ones, curans});
            }
            // Sort by f - ones; then turn [1] into prefix sums of ones and [2] into suffix sums of f.
            sort(begin(vals), end(vals));
            for (int i = 1; i < R - mid; ++i) vals[i][1] += vals[i-1][1];
            for (int i = R - mid - 2; i >= 0; --i) vals[i][2] += vals[i+1][2];

            curans = zeros = ones = 0;
            for (int i = mid-1; i >= L; --i) {
                if (a[i] == 0) {
                    ++zeros;
                    ++curans;
                }
                else {
                    ++ones;
                    curans = max(curans, ones);
                }

                // Entries before `it` have f - ones <= curans - zeros: they extend the left LNDS by ones.
                auto it = lower_bound(begin(vals), end(vals), array{curans - zeros + 1LL, -1LL, -1LL});
                int lt = it - begin(vals), rt = end(vals) - it;
                if (lt) ans += 1LL*lt*curans + vals[lt-1][1];
                if (rt) ans += it->at(2) + 1LL*rt*zeros;
            }
        };
        calc(calc, 0, n);
        cout << ans << '\n';
    }
}