NUMMAT - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: mexomerf
Tester: mexomerf
Editorialist: iceknight1093

DIFFICULTY:

Easy-Medium

PREREQUISITES:

Segment trees and lazy propagation

PROBLEM:

You’re given a binary string A of length N. When at index i, Chef will do the following:

  • If i = N or A_{i+1} = 0, stop.
  • Otherwise, jump to the closest index j \gt i such that A_j = 0. Add to his score the number of ones on the way.
    Then, repeat the process.

Process Q queries and updates:

  1. Given L and R, flip the values of A_i for all L \leq i \leq R.
  2. Given x, find the number of coins Chef will collect if he starts at x.

EXPLANATION:

When starting from an index x, observe that Chef will essentially keep moving right till he reaches either index N, or a pair of adjacent zeros.

Once the movement’s endpoint is known, the number of coins collected becomes a simple range sum.
So, we need a data structure that can handle the following:

  1. Compute the sum of a range.
  2. Find the first occurrence of two adjacent zeros after a specified index.
  3. Process range flips along with the above two.

One data structure that can handle such updates and queries is a segment tree, combined with lazy propagation.
In particular, in each node of the segtree, store the following:

  1. The number of zeros and number of ones.
  2. The leftmost and rightmost characters.
  3. The leftmost occurrence of substrings \texttt{00} and \texttt{11}.
  4. A flag denoting whether the range is flipped or not.

Merging two nodes is straightforward: frequencies are added, the left/right characters are taken from the left/right node respectively, and the leftmost occurrences of \texttt{00} and \texttt{11} can be obtained by taking the one from the left node if it exists; otherwise checking the pair that straddles the middle; and lastly taking the one from the right node.

Finding the closest occurrence of \texttt{00} after an index x is now easily done in \mathcal{O}(\log^2 N) with binary search, or \mathcal{O}(\log N) if the search is baked into the segment tree descent.
After that, range sums are a standard lazy segment tree affair (and can also be computed during the same descent, without requiring a separate second query).

If you’re stuck on implementation details, looking at the code attached below might help.

TIME COMPLEXITY:

\mathcal{O}((N+Q)\log N) per testcase.

CODE:

Author's code (C++)
// library link: https://github.com/manan-grover/My-CP-Library/blob/main/library.cpp
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;
// Author's template shorthands (used throughout the author's code below).
// Note that `I` expands to `long long int`, so everything spelled `I` is 64-bit.
// asc(i,a,n): ascending loop i = a .. n-1; dsc(i,a,n): descending loop i = n-1 .. a.
#define asc(i,a,n) for(I i=a;i<n;i++)
#define dsc(i,a,n) for(I i=n-1;i>=a;i--)
#define forw(it,x) for(A it=(x).begin();it!=(x).end();it++)
#define bacw(it,x) for(A it=(x).rbegin();it!=(x).rend();it++)
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define lb(x) lower_bound(x)
#define ub(x) upper_bound(x)
#define fbo(x) find_by_order(x)
#define ook(x) order_of_key(x)
#define all(x) (x).begin(),(x).end()
// sz(x): container size converted to the signed 64-bit type I.
#define sz(x) (I)((x).size())
#define clr(x) (x).clear()
#define U unsigned
#define I long long int
#define S string
#define C char
#define D long double
#define A auto
#define B bool
#define CM(x) complex<x>
#define V(x) vector<x>
#define P(x,y) pair<x,y>
#define OS(x) set<x>
#define US(x) unordered_set<x>
#define OMS(x) multiset<x>
#define UMS(x) unordered_multiset<x>
#define OM(x,y) map<x,y>
#define UM(x,y) unordered_map<x,y>
#define OMM(x,y) multimap<x,y>
#define UMM(x,y) unordered_multimap<x,y>
#define BS(x) bitset<x>
// NOTE(review): single-letter macro names (L, Q, I, S, ...) rewrite any later identifier
// with the same spelling; keep them out of code that is shared with other files.
#define L(x) list<x>
#define Q(x) queue<x>
// NOTE(review): PBS/PBM always order keys with less<I>, whatever x is — only suitable
// for keys convertible to I.
#define PBS(x) tree<x,null_type,less<I>,rb_tree_tag,tree_order_statistics_node_update>
#define PBM(x,y) tree<x,y,less<I>,rb_tree_tag,tree_order_statistics_node_update>
#define pi (D)acos(-1)
#define md 1000000007
// rnd only compiles where variables named randGen and rng are in scope.
#define rnd randGen(rng)
// Lazy segment tree over a binary array that supports range flips and the
// "sum of ones up to the first 00 pair" query used by the author's solution.
// Leaves are stored in an implicit heap: node x has children 2x+1 and 2x+2, and
// leaf i lives at index i+m-1.
class seg{
public:
  // Summary of the inclusive index range [lft, rgt].
  struct node{
    //........start
    long long ones;   // number of 1s in the range
    long long nxz;    // left index i of the leftmost "00" pair (i, i+1) inside the range; m-1 if none
    long long nxo;    // same as nxz, but for "11"
    bool ls;          // last (rightmost) value in the range
    bool fr;          // first (leftmost) value in the range
    bool lazy;        // a flip is pending: this node's own fields have NOT been flipped yet
    //........end
    long long lft,rgt;
  };
  long long m;               // number of leaves: the smallest power of two >= n
  std::vector<node> segarr;  // 2m-1 nodes in heap order

  // Combine two adjacent ranges: a on the left, b on the right.
  node merge(const node &a,const node &b) const{
    node ans{};
    ans.lft=std::min(a.lft,b.lft);
    ans.rgt=std::max(a.rgt,b.rgt);
    //..............start
    ans.ones=a.ones+b.ones;
    ans.lazy=false;
    ans.fr=a.fr;
    ans.ls=b.ls;
    // Pairs inside a come before the pair straddling the boundary (a.rgt, a.rgt+1),
    // which comes before pairs inside b, so taking the minimum picks the leftmost.
    ans.nxz=std::min(a.nxz,b.nxz);
    if(!a.ls && !b.fr){
      ans.nxz=std::min(ans.nxz,a.rgt);
    }
    ans.nxo=std::min(a.nxo,b.nxo);
    if(a.ls && b.fr){
      ans.nxo=std::min(ans.nxo,a.rgt);
    }
    //..............end
    return ans;
  }

  // Initialise a leaf holding value a (the caller sets lft/rgt).
  void make(node &temp,bool a){
    temp.ones=a;
    temp.nxz=m-1;   // a single cell holds no pair: use the "not found" sentinel
    temp.nxo=m-1;
    temp.ls=a;
    temp.fr=a;
    temp.lazy=false;
  }

  // Build over the first n characters of s ('0' -> 0, any other character -> 1).
  // Leaves n..m-1 are padded with 0. Assumes s has at least n characters.
  seg(long long n,std::string s){
    // Integer power of two instead of pow(2, ceil(log2(n))).
    m=1;
    while(m<n){
      m*=2;
    }
    // Value-initialise every node so no indeterminate values are ever copied.
    segarr.assign(2*m-1,node{});
    for(long long i=0;i<m;i++){
      node &leaf=segarr[i+m-1];
      make(leaf,i<n && s[i]!='0');
      leaf.lft=i;
      leaf.rgt=i;
    }
    for(long long i=m-2;i>=0;i--){
      segarr[i]=merge(segarr[2*i+1],segarr[2*i+2]);
    }
  }

  // Apply node x's pending flip to its own fields and hand the flip down to its children.
  void push(long long x){
    node &cur=segarr[x];
    if(!cur.lazy){
      return;
    }
    cur.ones=cur.rgt-cur.lft+1-cur.ones;
    cur.ls=!cur.ls;
    cur.fr=!cur.fr;
    std::swap(cur.nxz,cur.nxo);   // a flip turns every "00" into "11" and vice versa
    if(2*x+1<static_cast<long long>(segarr.size())){
      segarr[2*x+1].lazy=!segarr[2*x+1].lazy;
      segarr[2*x+2].lazy=!segarr[2*x+2].lazy;
    }
    cur.lazy=false;
  }

  // Summary of [l, r] restricted to node x's range; assumes the two overlap.
  node query_help(long long l,long long r,long long x){
    push(x);
    if(segarr[x].lft>=l && segarr[x].rgt<=r){
      return segarr[x];
    }
    const long long lc=2*x+1;
    const long long rc=2*x+2;
    if(l>segarr[lc].rgt){
      return query_help(l,r,rc);
    }
    if(r<segarr[rc].lft){
      return query_help(l,r,lc);
    }
    return merge(query_help(l,r,lc),query_help(l,r,rc));
  }

  // Sum of values over [l, p], where p is the left index of the first "00" pair lying
  // entirely inside [l, r]. If there is no such pair, p is the m-1 sentinel, so the
  // sum runs through the end of the array (not just to r).
  // NOTE(review): the sum starts at l, so the value at l itself is counted.
  long long query(long long l,long long r){
    const long long stop=query_help(l,r,0).nxz;
    return query_help(l,stop,0).ones;
  }

  // Flip every position of [l, r] that lies inside node x's range.
  void update_range_help(long long l,long long r,long long x){
    if(segarr[x].lft>=l && segarr[x].rgt<=r){
      segarr[x].lazy=!segarr[x].lazy;
      push(x);   // apply now so the parent's merge sees up-to-date values
      return;
    }
    push(x);
    if(segarr[x].lft>r || segarr[x].rgt<l){
      return;
    }
    update_range_help(l,r,2*x+1);
    update_range_help(l,r,2*x+2);
    segarr[x]=merge(segarr[2*x+1],segarr[2*x+2]);
  }

  // Flip every value with index in [l, r] (0-indexed, inclusive).
  void update_range(long long l,long long r){
    update_range_help(l,r,0);
  }
};
int main(){
  mt19937_64 rng(chrono::steady_clock::now().time_since_epoch().count());
  uniform_int_distribution<I> randGen;
  ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL);
  #ifndef ONLINE_JUDGE
  freopen("input.txt", "r", stdin);
  freopen("output.txt", "w", stdout);
  #endif
  I t;
  cin>>t;
  I tt=t;
  while(t--){
    I n,q;
    cin>>n>>q;
    S s;
    cin>>s;
    s+='0';
    n++;
    seg sg(n,s);
    while(q--){
      I tp;
      cin>>tp;
      if(tp==1){
        I l,r;
        cin>>l>>r;
        l--;
        r--;
        sg.update_range(l,r);
      }else{
        I x;
        cin>>x;
        x--;
        cout<<sg.query(x,n-1)<<"\n";
      }
    }
  }
  return 0;
}
Tester's code (C++)
// Clear the globally declared adjacency and visited vectors after each test case.
// Check for long long overflow.
// In modular-arithmetic problems, add a final check: if ans < 0 then ans += mod.
// If memory usage is close to the limit (MLE), try switching the language to C++17 or C++14.
// Check the answer for n = 1.
// #pragma GCC target ("avx2")    
// #pragma GCC optimization ("O3") 
// #pragma GCC optimization ("unroll-loops")
#include <bits/stdc++.h>                   
#include <ext/pb_ds/assoc_container.hpp>    
// Tester's template shorthands.
// IOS: desynchronise/untie the C++ streams and print doubles at full precision
// (dbl is the numeric_limits<double> typedef declared after this block).
#define IOS std::ios::sync_with_stdio(false); cin.tie(NULL);cout.tie(NULL);cout.precision(dbl::max_digits10);
#define pb push_back 
#define mod 1000000007ll //998244353ll
#define lld long double
#define mii map<int, int> 
#define pii pair<int, int>
#define ll long long 
#define ff first
#define ss second 
#define all(x) (x).begin(), (x).end()
#define rep(i,x,y) for(int i=x; i<y; i++)    
// NOTE(review): this function-like macro shadows std::fill — any later `fill(` with three
// arguments (including std::fill) fails to preprocess. sizeof(a) is only the full size for
// real arrays, not pointers.
#define fill(a,b) memset(a, b, sizeof(a))
#define vi vector<int>
#define setbits(x) __builtin_popcountll(x)
#define print2d(dp,n,m) for(int i=0;i<=n;i++){for(int j=0;j<=m;j++)cout<<dp[i][j]<<" ";cout<<"\n";}
// dbl: limits of double; the IOS macro reads dbl::max_digits10 for output precision.
typedef std::numeric_limits< double > dbl;
using namespace __gnu_pbds;
using namespace std;
// indexed_set: policy-based order-statistics tree over distinct ints.
typedef tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update> indexed_set;
// Member functions:
// 1. order_of_key(k): number of elements strictly less than k
// 2. find_by_order(k): iterator to the k-th smallest element (0-indexed)
// N sizes the global segment-tree arrays (4*N nodes); presumably the maximum string length
// from the constraints. INF is not referenced in the code shown.
const long long N=200005, INF=2000000000000000000;
// inf: "no such pair" sentinel stored in first_00/first_11 (set on every leaf in build).
const int inf=2e9 + 5;
lld pi=3.1415926535897932;
// Least common multiple of a and b. Returns 0 when either argument is 0
// (the original divided by __gcd(0, 0) == 0 for lcm(0, 0)).
// Divides before multiplying to delay overflow; the result must still fit in int.
int lcm(int a, int b)
{
    if(a==0 || b==0)
        return 0;
    int g=__gcd(a, b);
    return a/g*b;
}
// Modular exponentiation by repeated squaring: returns a^b mod p in [0, p).
// Assumes b >= 0 and p >= 1. Takes 0^0 as 1, consistent with a^0 for every other base.
int power(int a, int b, int p)
{
    a%=p;
    if(a<0)
        a+=p;          // normalise negative bases into [0, p)
    int res=1%p;       // 1 % p so that p == 1 yields 0
    while(b>0)
    {
        if(b&1)
            res=(1ll*res*a)%p;
        b>>=1;
        a=(1ll*a*a)%p;
    }
    return res;
}
// Global 32-bit Mersenne Twister, seeded from the steady clock's current tick count.
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());

// Uniformly distributed integer in the closed range [l, r], drawn from rng.
// Requires l <= r.
int getRand(int l, int r)
{
    return uniform_int_distribution<int>(l, r)(rng);
}

// Summary of a contiguous segment of the binary string s.
struct node1 {
    int left;       // first character of the segment (0/1); -1 marks the empty node
    int right;      // last character of the segment (0/1); -1 marks the empty node
    int first_00;   // left index i of the leftmost "00" pair (i, i+1), or inf if none
    int first_11;   // same as first_00, but for "11"
    int count_0;    // number of zeros in the segment
    int count_1;    // number of ones in the segment
};

// Empty segment: combine() returns the other operand when one side is this node.
node1 zero = {-1, -1, -1, -1, 0, 0};
// Segment tree, 1-indexed: node k has children 2k and 2k+1.
node1 st[4 * N];
// Pending-flip flag per tree node.
int lazy[4 * N];
// Current test case's binary string.
string s;

void flip(int node)
{
    st[node].left^=1;
    st[node].right^=1;
    swap(st[node].first_00, st[node].first_11);
    swap(st[node].count_0, st[node].count_1);
}
// Merge the summaries of two adjacent segments: a (left) followed by b (right).
// r1 is the last index covered by a (callers pass mid), i.e. the left index of the
// pair straddling the boundary. A node with left == -1 (see `zero`) is an identity.
node1 combine(node1 a, node1 b, int r1)
{
    if (a.left == -1) return b;
    if (b.left == -1) return a;

    node1 merged;
    merged.left = a.left;
    merged.right = b.right;
    merged.count_0 = a.count_0 + b.count_0;
    merged.count_1 = a.count_1 + b.count_1;
    merged.first_00 = min(a.first_00, b.first_00);
    merged.first_11 = min(a.first_11, b.first_11);

    // The straddling pair starts at r1; it wins only if no earlier pair exists.
    const bool sameAcross = (a.right == b.left);
    if (sameAcross && a.right == 0) merged.first_00 = min(merged.first_00, r1);
    if (sameAcross && a.right == 1) merged.first_11 = min(merged.first_11, r1);
    return merged;
}
// Materialise the pending flip of `node` (covering [l, r]): defer it to both children
// unless the node is a leaf, clear the flag, and flip st[node].
// Flips unconditionally, so callers invoke it only when lazy[node] is set.
void propogate(int node, int l, int r)
{
    const bool isLeaf = (l == r);
    if (!isLeaf)
    {
        lazy[2 * node] ^= lazy[node];
        lazy[2 * node + 1] ^= lazy[node];
    }
    lazy[node] = 0;
    flip(node);
}
// Build the subtree rooted at `node` over s[l..r] and clear its lazy flags.
void build(int node, int l, int r)
{
    lazy[node] = 0;
    if (l != r)
    {
        const int mid = (l + r) / 2;
        build(2 * node, l, mid);
        build(2 * node + 1, mid + 1, r);
        st[node] = combine(st[2 * node], st[2 * node + 1], mid);
        return;
    }
    // Leaf: a single character holds no pair, so both pair positions are inf.
    const int bit = s[l] - '0';
    st[node] = {bit, bit, inf, inf, s[l] == '0', s[l] == '1'};
}
// Flip every character with index in [x, y] inside the subtree of `node`, which covers [l, r].
void update(int node, int l, int r, int x, int y)
{
    if (lazy[node]) propogate(node, l, r);   // bring st[node] up to date first

    const bool disjoint = (y < x) || (x > r) || (y < l);
    if (disjoint) return;

    const bool covered = (l >= x) && (r <= y);
    if (!covered)
    {
        const int mid = (l + r) / 2;
        update(2 * node, l, mid, x, y);
        update(2 * node + 1, mid + 1, r, x, y);
        st[node] = combine(st[2 * node], st[2 * node + 1], mid);
        return;
    }

    // Entirely inside [x, y]: flip this node now and defer the flip to its children.
    flip(node);
    if (l != r)
    {
        lazy[2 * node] ^= 1;
        lazy[2 * node + 1] ^= 1;
    }
}
// Summary of [x, y] intersected with the range [l, r] of `node`.
// An empty intersection (or x > y) yields `zero`.
node1 query(int node, int l, int r, int x, int y)
{
    if (lazy[node]) propogate(node, l, r);
    if (y < x || y < l || x > r) return zero;
    if (x <= l && r <= y) return st[node];

    const int mid = (l + r) / 2;
    node1 leftPart = query(2 * node, l, mid, x, y);
    node1 rightPart = query(2 * node + 1, mid + 1, r, x, y);
    return combine(leftPart, rightPart, mid);
}

// Tester's driver: per test, build the tree over s and answer flip / query operations
// given with 1-indexed positions.
int main()
{
    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    #endif
    // Every operation is read through cin; desynchronising from stdio and untying cout
    // avoids the slow default stream behaviour on large inputs.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int t;
    cin>>t;
    while(t--)
    {
        int n, q;
        cin>>n>>q;
        cin>>s;
        build(1, 0, n-1);
        while(q--)
        {
            int typ;
            cin>>typ;
            if(typ==1)
            {
                int l, r;
                cin>>l>>r;
                l--, r--;
                update(1, 0, n-1, l, r);
            }
            else
            {
                int x;
                cin>>x;
                x--;
                // Left index of the first "00" pair inside [x, n-1], or inf if none.
                int last=query(1, 0, n-1, x, n-1).first_00;
                // Count the ones in [x+1, min(last, n-1)]; an empty range gives `zero` (0 ones).
                cout<<query(1, 0, n-1, x+1, min(last, n-1)).count_1<<"\n";
            }
        }
    }
}