BSTRING - Editorial


Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: isheoran
Tester: yash_daga
Editorialist: iceknight1093






A binary string is said to be good if it has an equal number of 01 and 10 substrings.

Given a binary string S of length N, count the number of its good subsequences modulo 10^9 + 7. Subsequences taken from different positions are counted separately.
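For small inputs, the definition can be checked directly. The sketch below enumerates every non-empty subsequence by its set of indices and tests the definition literally. It runs in exponential time and is only meant for cross-checking faster solutions. The helper names `is_good` and `brute_count` are illustrative.

```python
from itertools import combinations


def is_good(a: str) -> bool:
    """Good = equal number of '01' and '10' substrings (adjacent pairs)."""
    pairs = [a[k:k + 2] for k in range(len(a) - 1)]
    return pairs.count("01") == pairs.count("10")


def brute_count(s: str) -> int:
    """Count non-empty good subsequences of s, distinguishing them by index set."""
    total = 0
    for length in range(1, len(s) + 1):
        for idx in combinations(range(len(s)), length):
            if is_good("".join(s[k] for k in idx)):
                total += 1
    return total


print(brute_count("0110"))  # 9
```

For `"0110"`, the 9 good subsequences are the four single characters, the four subsequences that use both 0s as endpoints, and `"11"`.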


First, we need a simpler characterization of exactly when a binary string is good.

Let A be a binary string.
Suppose its first character is 0.
Then, A looks like a bunch of zeros, followed by several ones, followed by several zeros, and so on.
Only adjacent characters that differ form a 01 or 10 substring, so we can compress each maximal block of equal characters into a single character without changing either count, i.e., A = 010101\ldots.
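The compression step can be illustrated with `itertools.groupby`. Merging each run of equal characters only removes `00` and `11` pairs and leaves every `01` and `10` boundary in place, so both counts are unchanged. The helper names are illustrative.

```python
from itertools import groupby


def compress(a: str) -> str:
    """Replace each maximal run of equal characters by a single copy."""
    return "".join(ch for ch, _ in groupby(a))


def count_01_10(a: str):
    """Return (number of '01' substrings, number of '10' substrings)."""
    pairs = [a[k:k + 2] for k in range(len(a) - 1)]
    return pairs.count("01"), pairs.count("10")


a = "0011100"
print(compress(a))                               # 010
print(count_01_10(a), count_01_10(compress(a)))  # (1, 1) (1, 1)
```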

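In this form, the substrings alternate 01, 10, 01, \ldots, so the two counts are equal exactly when the number of changes is even, that is, exactly when A ends with a 0, the same character it starts with. The case where A starts with 1 is symmetric.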
This condition is both necessary and sufficient; more importantly, it’s quite a simple one, since it depends only on the first and last characters.

This is the characterization we'll use: a non-empty binary string A is good if and only if its first and last characters are equal, i.e., A_1 = A_{|A|}.
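The characterization can be sanity-checked exhaustively for short strings. The sketch below compares the literal definition with the first-equals-last test on every non-empty binary string up to a chosen length. The function names are illustrative.

```python
from itertools import product


def is_good(a: str) -> bool:
    pairs = [a[k:k + 2] for k in range(len(a) - 1)]
    return pairs.count("01") == pairs.count("10")


def check_characterization(max_len: int) -> bool:
    """True if, for every binary string of length 1..max_len, good <=> first == last."""
    for n in range(1, max_len + 1):
        for bits in product("01", repeat=n):
            a = "".join(bits)
            if is_good(a) != (a[0] == a[-1]):
                return False
    return True


print(check_characterization(12))  # True
```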

Given S, our aim is now to count the number of subsequences of S that start and end with the same character.
First off, there are N subsequences of length 1, and they always satisfy this criterion, so let's set them aside.

Suppose we have S_i = S_j, where i \lt j.
Let’s count the number of subsequences with these two as the endpoints.

  • We can’t pick any elements before i or after j.
  • Between i and j, we can freely pick any elements.
  • There are (j - i - 1) elements between them, so there are 2^{j-i-1} subsequences in total.
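The pair count can be confirmed by enumeration. The sketch below, with 0-based indices and a hypothetical helper name, counts the index sets whose smallest element is i and largest is j by choosing any subset of the positions strictly between them.

```python
from itertools import combinations


def subsequences_with_endpoints(i: int, j: int) -> int:
    """Number of index sets with minimum i and maximum j (0-based, i < j)."""
    middle = range(i + 1, j)
    return sum(1 for r in range(len(middle) + 1) for _ in combinations(middle, r))


for i, j in [(0, 1), (0, 3), (2, 7)]:
    print(subsequences_with_endpoints(i, j), 2 ** (j - i - 1))
# 1 1
# 4 4
# 16 16
```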

This gives us a (slow) solution in \mathcal{O}(N^2) by fixing each pair of (i, j). All that remains is to optimize it.
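The slow solution translates directly into a double loop over pairs. This is a minimal sketch using 0-based indices, so the pair weight is still 2^{j-i-1}. It is useful as a reference for testing the optimized version.

```python
MOD = 10**9 + 7


def count_good_quadratic(s: str) -> int:
    """O(N^2): N single characters plus 2^{j-i-1} for every pair i < j with s[i] == s[j]."""
    n = len(s)
    ans = n  # the n subsequences of length 1
    for j in range(n):
        for i in range(j):
            if s[i] == s[j]:
                ans += pow(2, j - i - 1, MOD)
    return ans % MOD


print(count_good_quadratic("0110"))  # 9
```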

Let’s fix the value of j.
Then, we want to compute the sum of 2^{j-i-1} across all i \lt j such that S_i = S_j.
Notice that 2^{j-i-1} = 2^{j-1}\cdot(2^{-i}), so all we really want is the sum of 2^{-i} across all i \lt j such that S_i = S_j: we can later multiply this sum by 2^{j-1} to account for all subsequences ending at index j.

Maintaining this sum is quite easy!
Computing 2^{-i} for a given i is a standard application of binary exponentiation and modular inverses.
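Since 10^9 + 7 is prime, Fermat's little theorem gives 2^{-1} = 2^{MOD-2} modulo MOD. The sketch below uses Python's built-in three-argument `pow` for binary exponentiation. It also shows a table built with one multiplication per entry. Computing every inverse with fast exponentiation costs \mathcal{O}(\log{MOD}) each, while the incremental table is \mathcal{O}(1) per index, which presumably explains the two bounds in the complexity line. Names are illustrative.

```python
MOD = 10**9 + 7               # prime modulus used by the solutions
INV2 = pow(2, MOD - 2, MOD)   # Fermat: 2^(MOD-2) is the inverse of 2 modulo MOD


def inv_pow2(i: int) -> int:
    """2^{-i} mod MOD via binary exponentiation."""
    return pow(INV2, i, MOD)


def inv_pow2_table(n: int) -> list:
    """[2^0, 2^{-1}, ..., 2^{-(n-1)}] mod MOD, one multiplication per entry."""
    table = [1] * n
    for k in range(1, n):
        table[k] = table[k - 1] * INV2 % MOD
    return table


# The identity used above: 2^{j-i-1} = 2^{j-1} * 2^{-i} (mod MOD)
i, j = 3, 10
print(pow(2, j - 1, MOD) * inv_pow2(i) % MOD == pow(2, j - i - 1, MOD))  # True
```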
After that, the only thing that needs to be maintained is two separate sums: the sum of 2^{-i} for all indices i such that S_i = 0, and the same sum for all indices such that S_i = 1.

Then, for a fixed j, the required contribution can be computed in \mathcal{O}(1); after which the corresponding sum can be increased by 2^{-j}.
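Putting the steps together gives a linear-time sketch that follows the text's 1-based indexing. For each j, it adds 2^{j-1} times the running sum of 2^{-i} for the matching character, then adds 2^{-j} to that sum. The dictionary `sums` holds the two running sums. The names are illustrative and not taken from the reference solutions.

```python
MOD = 10**9 + 7
INV2 = pow(2, MOD - 2, MOD)


def count_good_linear(s: str) -> int:
    sums = {"0": 0, "1": 0}  # sums[c] = sum of 2^{-i} over earlier 1-based i with S_i == c
    ans = len(s)             # the subsequences of length 1
    pow2_prev = 1            # 2^{j-1} for the current j
    inv_pow = 1              # becomes 2^{-j} inside the loop
    for j, ch in enumerate(s, start=1):
        inv_pow = inv_pow * INV2 % MOD                  # now 2^{-j}
        ans = (ans + pow2_prev * sums[ch]) % MOD        # add 2^{j-1} * sum of 2^{-i}
        sums[ch] = (sums[ch] + inv_pow) % MOD           # index j becomes an earlier i
        pow2_prev = pow2_prev * 2 % MOD                 # 2^{j} for the next step
    return ans


print(count_good_linear("0110"))  # 9
```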


Time complexity: \mathcal{O}(N) or \mathcal{O}(N\log{MOD}) per testcase.


Setter's code (C++)
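#include <iostream>
#include <string>
using namespace std;
#define  enl          '\n'
#define  int          long long

const int mod = 1e9+7;

// a^b modulo mod, by binary exponentiation
int binpow(int a,int b) {
    if(b<0) return 0;
    int res = 1;
    a %= mod;
    while(b > 0) {
        if(b & 1) res = res*a%mod;
        a = a*a%mod;
        b >>= 1;
    }
    return res;
}

void solve() {
    int n;
    string s;
    cin >> n >> s;
    // pre0 / pre1: sum of 2^{-(i+1)} over earlier (1-based) positions i holding '0' / '1'
    int pre0 = 0, pre1 = 0;
    int ans = s.size();          // the N subsequences of length 1

    int inv2 = binpow(2,mod-2);
    int inv2Pow = inv2;
    int pow2 = 1;

    for(auto u:s) {
        // at 1-based position j: inv2Pow = 2^{-(j+1)}, pow2 = 2^j
        inv2Pow = inv2Pow*inv2%mod;
        pow2 = pow2*2%mod;

        // each earlier i with S_i == S_j contributes 2^j * 2^{-(i+1)} = 2^{j-i-1}
        if(u == '1') {
            ans = (ans + pre1*pow2)%mod;
            pre1 = (pre1 + inv2Pow)%mod;
        }
        else {
            ans = (ans + pre0*pow2)%mod;
            pre0 = (pre0 + inv2Pow)%mod;
        }
    }
    cout << ans << enl;
}

signed main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    int testcases = 1;
    cin >> testcases;
    while(testcases--) solve();
    return 0;
}
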
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define mod 1000000007ll

int32_t main()
{
    ios::sync_with_stdio(false);
    cin.tie(NULL);
    int t;
    cin >> t;
    while(t--)
    {
        int n;
        string s;
        cin >> n >> s;
        // co[a][b]: subsequences starting with a, ending with b, containing both characters
        // c[a]: non-empty subsequences consisting only of character a
        int co[2][2], c[2];
        memset(co, 0, sizeof(co));
        memset(c, 0, sizeof(c));
        int ans=0;
        for(int i=0; i<n; i++)
        {
            if(s[i]=='1')
            {
                // good subsequences ending here: append '1' to one starting with '1', or "1" alone
                ans+=(co[1][0] + co[1][1] + c[1] + 1);
                co[1][1]=(co[1][1]*2 + co[1][0])%mod;
                co[0][1]=(co[0][1]*2 + co[0][0] + c[0])%mod;
                c[1]=(c[1]*2 + 1)%mod;
            }
            else
            {
                ans+=(co[0][1] + co[0][0] + c[0] + 1);
                co[0][0]=(co[0][0]*2 + co[0][1])%mod;
                co[1][0]=(co[1][0]*2 + co[1][1] + c[1])%mod;
                c[0]=(c[0]*2 + 1)%mod;
            }
            ans%=mod;
        }
        cout << ans << "\n";
    }
    return 0;
}
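The tester's solution is a DP rather than the prefix-sum formula. Based on its updates, `c[x]` appears to count non-empty subsequences made only of character x, and `co[a][b]` those that start with a, end with b, and contain both characters. A character x completes a good subsequence when appended to anything that starts with x, or when taken alone. The Python port below reflects our reading of that state, not a documented description.

```python
MOD = 10**9 + 7


def count_good_dp(s: str) -> int:
    co = [[0, 0], [0, 0]]  # co[a][b]: starts with a, ends with b, contains both digits
    c = [0, 0]             # c[a]: non-empty subsequences using only digit a
    ans = 0
    for ch in s:
        x = int(ch)
        y = 1 - x
        # Good subsequences ending at this position: append x to any
        # subsequence starting with x, or take x on its own.
        ans = (ans + co[x][y] + co[x][x] + c[x] + 1) % MOD
        # "* 2": every existing subsequence either skips or takes this x.
        # Taking x turns those ending in y (or made only of y) into ones ending in x.
        co[x][x] = (co[x][x] * 2 + co[x][y]) % MOD
        co[y][x] = (co[y][x] * 2 + co[y][y] + c[y]) % MOD
        c[x] = (c[x] * 2 + 1) % MOD
    return ans


print(count_good_dp("0110"))  # 9
```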
Editorialist's code (Python)
mod = 10**9 + 7
inv2 = pow(2, mod-2, mod)

import sys
input = sys.stdin.readline

for _ in range(int(input())):
    n = int(input())
    s = input()
    zsum, osum = 0, 0
    pw, invpw = 1, 1
    ans = 0
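    for i in range(n):
        # pw = 2^i and invpw = 2^{-i} for the current 0-based index i
        if s[i] == '0':
            ans = (ans + pw * zsum)%mod
            zsum = (zsum + invpw)%mod
        else:
            ans = (ans + pw * osum)%mod
            osum = (osum + invpw)%mod
        pw = (pw * 2) % mod
        invpw = (invpw * inv2) % mod
    # each pair (i, j) was added as 2^{j-i}; halve to get 2^{j-i-1}
    ans *= inv2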
    print((ans + n) % mod)

Can this problem be solved with a dynamic programming approach?