PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author:
Tester: sushil2006
Editorialist: iceknight1093
DIFFICULTY:
Easy-Medium
PREREQUISITES:
Dynamic programming, Prefix sums
PROBLEM:
For a binary string S, define f(S) = \max(c_0, c_1), where c_0 and c_1 denote the number of zeros and ones in S, respectively.
Given a binary string S, compute \sum_{L=1}^N\sum_{R=L}^N f(S[L\ldots R]).
EXPLANATION:
Let’s fix the right endpoint R of the substring, and try to compute the sum of answers across all L \leq R.
Observe that as L starts at R and moves leftwards, the most frequent character will be S_R for a while - until we reach a point where the substring contains an equal number of zeros and ones.
Let i \lt R be the rightmost index such that S[i\ldots R] contains an equal number of zeros and ones.
Then, note that for any L \lt i, we have f(S[L\ldots R]) = f(S[L\ldots (i - 1)]) + \frac{(R - i + 1)}{2}.
This is because S[i\ldots R] contains an equal number of zeros and ones, so the maximum frequency is determined purely by the part to the left of i - and no matter which character attains that maximum, S[i\ldots R] adds exactly half its length to the count.
This hints at a dynamic programming approach.
Let dp_R denote the sum of answers of subarrays ending at R.
Then,
- Find the rightmost index i \lt R such that S[i\ldots R] contains an equal number of zeros and ones.
- For all L \lt i, f(S[L\ldots R]) = f(S[L\ldots (i - 1)]) + \frac{(R - i + 1)}{2}.
This means we can add dp_{i-1} + (i-1)\cdot \frac{(R - i + 1)}{2} to dp_R to account for this.
- For all i \leq L \leq R, the most frequent element is always S_R.
So, all we really want to do is sum up the number of occurrences of S_R in these subarrays quickly.
A direct implementation of the above is still quadratic time - specifically the first and third steps.
To optimize the first step of finding i, we use a common trick when dealing with binary strings: replace every occurrence of 0 by -1 instead, and look at subarray sums.
Note that some substring has an equal number of zeros and ones if and only if its sum after the transformation is 0.
Let A be the transformed array (i.e. A_i = -1 if S_i = 0, and A_i = 1 if S_i = 1).
When R is fixed, we’re looking for the largest index i \lt R such that A_i + A_{i+1} + \cdots + A_{R} = 0.
Since we’re working with subarray sums, it’s natural to think in terms of prefix sums instead.
That is, let pA denote the prefix sum array of A, so that pA_i = A_1 + A_2 + \cdots + A_i.
Then, the sum of A[i\ldots R] equals 0 if and only if pA_R - pA_{i-1} = 0, meaning pA_R = pA_{i-1}.
This immediately tells us how to find i: simply find the previous index with the same prefix sum as R.
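This is quite easy to do by, say, storing all indices corresponding to each prefix sum in a sorted list; since R is processed in increasing order, it is in fact enough to remember only the most recent index for each prefix-sum value.
If no earlier prefix has the same sum as pA_R, then no such i exists, and S_R is the most frequent character for every L \leq R.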
Now that i is known, we need to optimize the third part of the algorithm, i.e., finding the sum of the number of occurrences of S_R across all the substrings S[L\ldots R] for i \leq L \leq R.
This can also be done with the help of prefix sums.
Note that the number of ones in S[L\ldots R] simply equals S_L + S_{L+1} + \ldots + S_R.
So, if pS denotes the prefix sum array of S, the value we want to compute is exactly
\sum_{L=i}^{R} \left(pS_R - pS_{L-1}\right) = (R - i + 1)\cdot pS_R - \sum_{L=i}^{R} pS_{L-1}
The first term is easily found in constant time, the second can also be found in constant time by building prefix sums over pS.
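Note that this handles the case where S_R = 1. If S_R = 0, the expression is similar but not exactly the same: the number of zeros in S[L\ldots R] is (R - L + 1) - (pS_R - pS_{L-1}), so that quantity is summed over i \leq L \leq R instead.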
TIME COMPLEXITY:
\mathcal{O}(N) per testcase.
CODE:
Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
// From here on every `int` token in this solution is rewritten to `long long`,
// so declarations such as `int n` or `vector <int>` are 64-bit.
#define int long long
#define ll long long
// 10^18 cast to long long; not referenced by the code visible in this solution.
#define INF (ll)1e18
// 64-bit Mersenne Twister seeded from the steady clock; not referenced by the
// code visible in this solution.
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());
// Because of the `int` macro above this is a long long holding 10^18.
const int mod = 1e18;
// Prints, followed by a newline, the sum of max(#zeros, #ones) taken over
// every substring of a binary string.
//
// Parameters:
//   n - length of the binary string.
//   s - the string in 1-indexed form: s[0] is a placeholder and the real
//       characters are s[1..n]; s must therefore hold at least n + 1 chars.
//
// For every right end i, dp[i] is the sum of f over all substrings ending at
// i. Let j be the latest earlier position whose prefix balance (#ones minus
// #zeros) equals the balance at i, or 0 when there is none:
//   * for L in (j, i] the majority character of s[L..i] is s[i], so that part
//     is a plain "count of s[i]" sum, evaluated with prefix sums of prefix
//     counts;
//   * for L <= j the balanced tail s[j+1..i] adds (i - j) / 2 to f(s[L..j]),
//     giving dp[j] + j * (i - j) / 2.
//
// The exact 64-bit sum is printed; no modular reduction is applied, which is
// what the problem asks for. All arithmetic uses explicit `long long`, so the
// function does not depend on the file-wide `int` -> `long long` macro.
void Solve(long long n, const std::string& s)
{
    const std::size_t len = static_cast<std::size_t>(n);

    // ones[i]      : number of '1' characters in s[1..i]
    // ones_acc[i]  : ones[1] + ... + ones[i]
    // zeros_acc[i] : (1 - ones[1]) + ... + (i - ones[i])
    std::vector<long long> ones(len + 1, 0);
    std::vector<long long> ones_acc(len + 1, 0);
    std::vector<long long> zeros_acc(len + 1, 0);
    for (long long i = 1; i <= n; ++i) {
        ones[i] = ones[i - 1] + (s[i] == '1' ? 1 : 0);
        ones_acc[i] = ones_acc[i - 1] + ones[i];
        zeros_acc[i] = zeros_acc[i - 1] + (i - ones[i]);
    }

    std::vector<long long> dp(len + 1, 0);
    // last[b] is the most recent index whose shifted balance equals b. The
    // zero default doubles as "the empty prefix" for the starting balance and
    // as "never seen" for every other balance; both cases want j = 0.
    std::vector<long long> last(2 * len + 1, 0);

    long long balance = n;  // shifted by n so that it stays inside [0, 2n]
    long long total = 0;
    for (long long i = 1; i <= n; ++i) {
        const bool is_one = (s[i] == '1');
        balance += is_one ? 1 : -1;
        const long long j = last[balance];

        // Occurrences of s[i] inside s[1..i], and the matching running sums.
        const long long same_upto_i = is_one ? ones[i] : i - ones[i];
        const std::vector<long long>& acc = is_one ? ones_acc : zeros_acc;

        // Sum over L in (j, i] of the number of s[i] characters in s[L..i].
        long long before = acc[i - 1];
        if (j != 0) {
            before -= acc[j - 1];
        }
        const long long near_part = same_upto_i * (i - j) - before;

        // Starts L <= j: reuse dp[j]; the balanced tail contributes its number
        // of ones (half of its length) to each of the j substrings.
        const long long far_part = dp[j] + j * (ones[i] - ones[j]);

        dp[i] = near_part + far_part;
        total += dp[i];
        last[balance] = i;
    }

    std::cout << total << "\n";
}
// Driver for the author's solution: reads the number of test cases, then for
// each case the length and the binary string, prepends a placeholder so the
// string becomes 1-indexed, and hands it to Solve. The wall-clock time of the
// whole run is reported on stderr.
int32_t main()
{
    using wall_clock = std::chrono::high_resolution_clock;
    const auto started = wall_clock::now();

    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    int tests = 1;
    cin >> tests;
    for (int done = 0; done < tests; ++done) {
        int n;
        string s;
        cin >> n >> s;
        s = "0" + s;  // shift to 1-based indexing
        Solve(n, s);
    }

    const auto finished = wall_clock::now();
    const auto nanos =
        std::chrono::duration_cast<std::chrono::nanoseconds>(finished - started);
    cerr << "Time measured: " << nanos.count() * 1e-9 << " seconds.\n";
    return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;
// pb_ds red-black tree with the order-statistics node update policy.
// Not referenced by the code visible in this solution.
template<typename T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
// Short names for common numeric and pair types.
typedef long long int ll;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
// Turns off stdio synchronisation and unties cin from cout.
#define fastio ios_base::sync_with_stdio(false); cin.tie(NULL)
#define pb push_back
// NOTE: replaces every `endl` token with the newline character, so writing
// `endl` in this solution does not flush the stream.
#define endl '\n'
#define sz(a) (int)a.size()
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
// `conts` is an alias for the `continue` statement.
#define conts continue
// (x + y - 1) / y; note that x is not parenthesised inside the expansion.
#define ceil2(x,y) ((x+y-1)/(y))
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
#define yes cout << "YES" << endl
#define no cout << "NO" << endl
// Loop helpers: rep counts 0..n-1, rep1 counts 1..n, rev counts down from s
// to e inclusive, trav iterates a container by reference.
#define rep(i,n) for(int i = 0; i < n; ++i)
#define rep1(i,n) for(int i = 1; i <= n; ++i)
#define rev(i,s,e) for(int i = s; i >= e; --i)
#define trav(i,a) for(auto &i : a)
// Lowers `a` to `b` when `b` compares smaller; otherwise `a` keeps its value.
template<typename T>
void amin(T &a, T b) {
    if (b < a) {
        a = b;
    }
}
// Raises `a` to `b` when `b` compares larger; otherwise `a` keeps its value.
template<typename T>
void amax(T &a, T b) {
    if (a < b) {
        a = b;
    }
}
// With LOCAL defined, debug(...) comes from the project-local "debug.h";
// otherwise every debug(...) call expands to the literal 42.
#ifdef LOCAL
#include "debug.h"
#else
#define debug(...) 42
#endif
/*
*/
// Common constants; none of them is referenced by the code visible in this
// solution.
const int MOD = 1e9 + 7;
const int N = 1e5 + 5;
const int inf1 = int(1e9) + 5;
const ll inf2 = ll(1e18) + 5;
// Fenwick (binary indexed) tree over positions 1..n holding values of type T.
//
// Supports point updates and prefix sums in O(log n) each. lower_bound and
// upper_bound descend the implicit tree and are only meaningful when every
// stored value is non-negative, so that prefix sums are non-decreasing.
//
// Out-of-range arguments are tolerated instead of being undefined behaviour:
//   * pupd ignores positions outside 1..n (a position of 0 would otherwise
//     never advance, because lsb(0) == 0);
//   * sum clamps its argument into 0..n.
// A default-constructed tree is a valid empty tree with n == 0.
template<typename T>
struct fenwick {
    int n = 0;           // number of positions; valid indices are 1..n
    std::vector<T> tr;   // tr[i] covers the lsb(i) positions ending at i
    int LOG = 0;         // smallest value with (1 << LOG) > n

    fenwick() {
    }

    // Builds an all-zero tree over positions 1..n_ (negative sizes act as 0).
    fenwick(int n_) {
        n = n_ < 0 ? 0 : n_;
        tr = std::vector<T>(n + 1);
        // 64-bit shift so the comparison cannot overflow for very large n.
        while ((1LL << LOG) <= n) LOG++;
    }

    // Lowest set bit of x.
    int lsb(int x) {
        return x & -x;
    }

    // Adds v to position i; positions outside 1..n are ignored.
    void pupd(int i, T v) {
        if (i <= 0) return;
        for (; i <= n; i += lsb(i)) {
            tr[i] += v;
        }
    }

    // Sum of positions 1..i, with i clamped into 0..n.
    T sum(int i) {
        if (i > n) i = n;
        T res = 0;
        for (; i > 0; i ^= lsb(i)) {
            res += tr[i];
        }
        return res;
    }

    // Sum of positions l..r, or 0 for an empty range.
    T query(int l, int r) {
        if (l > r) return 0;
        T res = sum(r) - sum(l - 1);
        return res;
    }

    // First position whose prefix sum is >= s, or n + 1 when there is none.
    int lower_bound(T s) {
        if (sum(n) < s) return n + 1;
        int i = 0;
        for (int bit = LOG - 1; bit >= 0; --bit) {
            const long long j = static_cast<long long>(i) + (1LL << bit);
            if (j > n) continue;
            if (tr[j] < s) {
                s -= tr[j];
                i = static_cast<int>(j);
            }
        }
        return i + 1;
    }

    // First position whose prefix sum is > s (integer T), or n + 1.
    int upper_bound(T s) {
        return lower_bound(s + 1);
    }
};
void solve(int test_case){
ll n; cin >> n;
string s; cin >> s;
s = "$" + s;
vector<ll> p(n+5);
rep1(i,n) p[i] = p[i-1]+(s[i]=='1');
ll ans = 0;
auto go = [&](ll c){
fenwick<ll> fenw_cnt(3*n+5), fenw_sum(3*n+5);
rep1(i,n){
ll pos = 2*p[i-1]-(i-1)+n+1;
fenw_cnt.pupd(pos,1);
fenw_sum.pupd(pos,p[i-1]);
ll j = 2*p[i]-i+n+1;
ll cnt = fenw_cnt.sum(j-c), sum = fenw_sum.sum(j-c);
ans += cnt*p[i]-sum;
}
};
go(0);
rep1(i,n) p[i] = i-p[i];
go(1);
cout << ans << endl;
}
int main()
{
fastio;
int t = 1;
cin >> t;
rep1(i, t) {
solve(i);
}
return 0;
}
Editorialist's code (PyPy3)
from collections import defaultdict


def sum_of_max_counts(s):
    """Return the sum of max(#zeros, #ones) over every substring of ``s``.

    ``s`` is a binary string (characters '0' and '1'), indexed from 0 here.

    For each right end i (1-based), dp[i] is the sum of f over all substrings
    ending at i. Let j be the latest earlier position with the same prefix
    balance (#ones - #zeros), or 0 if the balance has not occurred before:

    * starts L in (j, i]: the majority character is s[i], so the contribution
      is the total number of occurrences of s[i], obtained from prefix sums
      of the prefix counts of ones;
    * starts L <= j: the balanced block (j, i] adds (i - j) // 2 to each of
      the j substrings ending at j, giving dp[j] + j * ((i - j) // 2).
    """
    n = len(s)

    # pref[i]     = number of '1' in the first i characters
    # prefpref[i] = pref[1] + ... + pref[i]
    pref = [0] * (n + 1)
    prefpref = [0] * (n + 1)
    for i in range(1, n + 1):
        pref[i] = pref[i - 1] + (s[i - 1] == '1')
        prefpref[i] = prefpref[i - 1] + pref[i]

    dp = [0] * (n + 1)
    # Missing balances default to 0, which is also the index of the empty
    # prefix; both situations want j = 0.
    last = defaultdict(int)
    dif = 0
    for i in range(1, n + 1):
        if s[i - 1] == '0':
            dif -= 1
        else:
            dif += 1
        j = last[dif]
        sz = i - j

        # Sum over L in (j, i] of the number of ones in s[L..i].
        ones_total = sz * pref[i] - prefpref[i - 1] + (prefpref[j - 1] if j >= 1 else 0)

        dp[i] = dp[j] + j * (sz // 2)
        if s[i - 1] == '1':
            dp[i] += ones_total
        else:
            # zeros = length - ones, and the lengths sum to sz * (sz + 1) // 2
            dp[i] += sz * (sz + 1) // 2 - ones_total
        last[dif] = i

    return sum(dp)


def main():
    """Read the test cases from stdin and print one answer per line."""
    for _ in range(int(input())):
        n = int(input())
        print(sum_of_max_counts(input()[:n]))


if __name__ == "__main__":
    main()