SUB_XOR - Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Contest Division 4
Practice

Setter: Aryan Raj
Testers: Tejas Pandey and Abhinav Sharma
Editorialist: Taranpreet Singh

DIFFICULTY

Easy

PREREQUISITES

Bitwise operations, prefix sums, contribution trick

PROBLEM

Given a binary string S of length N, find the beauty of S, defined as the bitwise XOR of the decimal representations of all substrings of S. Since the beauty can be huge, output it modulo 998244353.

QUICK EXPLANATION

  • Consider each occurrence of 1 separately, and compute its contribution to the XOR.
  • A 1 at position p (0-based) in S is included in (1+p)*(|S|-p) substrings.
  • Counting bit positions from the least significant bit, this 1 appears at position 0 in (1+p) substrings, at position 1 in (1+p) substrings, \ldots, at position |S|-p-1 in (1+p) substrings.
  • For each bit position x, we maintain the number of substrings having bit x ON using a difference array, and recover the counts with a prefix (or suffix) sum.
  • A bit is ON in the final XOR if and only if an odd number of substrings have that bit ON.
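A brute-force reference is handy for checking the claims above on small strings. This is a sketch, not one of the official solutions; the function name `bruteBeauty` is hypothetical, and it assumes |S| ≤ 64 so that every substring value fits in an unsigned 64-bit integer.

```cpp
#include <cstdint>
#include <string>

// Brute-force reference, small inputs only (assumes |S| <= 64 so each
// substring value fits in 64 bits). Enumerates every substring S[l..r],
// reads it as a binary number, and XORs all of them together.
// Returns the exact XOR; reduce it modulo 998244353 separately if needed.
std::uint64_t bruteBeauty(const std::string& s) {
    std::uint64_t total = 0;
    const int n = static_cast<int>(s.size());
    for (int l = 0; l < n; ++l) {
        std::uint64_t value = 0;                  // decimal value of S[l..r]
        for (int r = l; r < n; ++r) {
            value = value * 2 + (s[r] - '0');     // append S[r] as the new lowest bit
            total ^= value;
        }
    }
    return total;
}
```

For instance, `bruteBeauty("011")` returns 1, the same result as the hand computation for 011 in the comments.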

EXPLANATION

Since XOR operates bitwise and the string is already binary, we solve the whole problem in binary: first we find the binary representation of the answer, then we compute its decimal value modulo 998244353.

For some 0 \leq p \leq N-1 (bits counted from the least significant one), the p-th bit of the final XOR is ON if and only if an odd number of substrings have their p-th bit ON. So, if we can compute, for each p, the number of substrings of S having the p-th bit ON, we get the binary representation of the answer.

Reduced problem

Now, for each 0 \leq p \lt |S|, the problem is to count the number of substrings of S whose p-th bit is ON. Let's denote this count as cnt_p.

Now, let's consider each ON bit in S and find its contribution to cnt_p. Suppose position c has S_c = 1. Firstly, position c is included in exactly (1+c)*(|S|-c) substrings of S (there are (1+c) choices for the left end and (|S|-c) choices for the right end).

For example, for the string 101000 and c = 2, there are 3 substrings where S_2 is the lowest bit (right end at 2), 3 substrings where S_2 is the second lowest bit (right end at 3), and so on.

We can see that for position c and a substring with right end r, S_c sits at bit position r-c, so it contributes (1+c) substrings to cnt_{r-c}. The right end r can take values c \leq r \lt |S|. Hence, if the c-th character of S is ON, it contributes (1+c) substrings to every cnt_{r-c} for c \leq r \lt |S|, that is, to cnt_x for all 0 \leq x \lt |S|-c.
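The contribution rule can be transcribed directly, which helps when checking the counts by hand. This is only a reference sketch (the name `naiveCounts` is hypothetical); it runs in O(|S|^2) time, so it is not fast enough for the real limits.

```cpp
#include <string>
#include <vector>

// Direct O(|S|^2) transcription of the contribution rule, for checking only:
// a 1 at position c, in a substring whose right end is r, sits at bit r - c
// and has (1 + c) possible left ends, so it adds (1 + c) to cnt[r - c].
std::vector<long long> naiveCounts(const std::string& s) {
    const int n = static_cast<int>(s.size());
    std::vector<long long> cnt(n, 0LL);   // cnt[x]: number of substrings with bit x ON
    for (int c = 0; c < n; ++c) {
        if (s[c] != '1') continue;
        for (int r = c; r < n; ++r) {
            cnt[r - c] += 1 + c;
        }
    }
    return cnt;
}
```

For 101000 this gives cnt = [4, 4, 4, 4, 1, 1]; only bits 4 and 5 have odd counts, so the beauty is 16 + 32 = 48.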

But naively updating cnt_{r-c} for every pair (r, c) takes O(|S|^2) time, which would time out, so we need something faster.

Standard problem

For each position c with S_c = 1, we need to increase cnt_x by (1+c) for all 0 \leq x \lt |S|-c, and then compute the final array cnt. This is a range increment operation.

We can use a difference array, which supports this operation in O(1), or a segment tree, which supports it in O(\log_2 N).
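A generic difference array can be sketched as follows. The struct and method names are hypothetical, chosen for illustration: each range update touches two cells, and one prefix-sum pass recovers the real values.

```cpp
#include <cstddef>
#include <vector>

// Minimal difference-array sketch (hypothetical names).
// rangeAdd(l, r, v) adds v to every element in [l, r] in O(1);
// build() recovers the actual values with a single prefix-sum pass.
struct DifferenceArray {
    std::vector<long long> diff;
    explicit DifferenceArray(int n) : diff(n + 1, 0LL) {}

    void rangeAdd(int l, int r, long long v) {
        diff[l] += v;
        diff[r + 1] -= v;   // diff has n + 1 slots, so r + 1 <= n is always valid
    }

    std::vector<long long> build() const {
        std::vector<long long> values(diff.size() - 1);
        long long running = 0;
        for (std::size_t i = 0; i + 1 < diff.size(); ++i) {
            running += diff[i];
            values[i] = running;
        }
        return values;
    }
};
```

In this problem, each position c with S_c = 1 would call `rangeAdd(0, |S|-c-1, 1+c)`. Because every range starts at 0, the setter's and editorialist's solutions skip the subtraction entirely: they add (1+c) at index |S|-c-1 and take a suffix sum.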

Taking the prefix or suffix sum (depending on the implementation) recovers the original cnt array, from which the final answer can be computed.

If in doubt, refer to my solution with comments mentioning each step.

TIME COMPLEXITY

The time complexity is O(|S|) per test case.

SOLUTIONS

Setter's Solution
// author: Aryan Raj
#include "bits/stdc++.h"
using namespace std;
#define int   long long int
#define mod   998244353

signed main()
{
#ifndef ONLINE_JUDGE
  freopen("input.txt", "r", stdin);
  freopen("output.txt", "w", stdout);
#endif
  ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0);
  int t, n;
  cin >> t;
  while (t--)
  {
    string s;
    cin >> n >> s;
    vector<int> cnt(n);
    for (int i = 0; i < n; ++i)
      if (s[i] == '1')cnt[n - i - 1] += (i + 1);

    for (int j = n - 2; j >= 0; --j)
      cnt[j] += cnt[j + 1];

    int ans = 0, cur = 1;
    for (int i = 0; i < n; ++i)
    {
      if (cnt[i] % 2)
        ans = (ans + cur) % mod;
      cur = (cur * 2) % mod;
    }
    cout << ans << "\n";
  }
  return 0;
}
Tester's Solution 1
#include <bits/stdc++.h>
#pragma GCC optimize("Ofast")
#pragma GCC target("avx,avx2,fma")
using namespace std;


/*
------------------------Input Checker----------------------------------
*/

long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);

            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }

            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }

            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}


/*
------------------------Main code starts here----------------------------------
*/

const int MAX_T = 100;
const int MAX_N = 100000;
const int MAX_SUM_N = 200000;
const int mod = 998244353;

#define ll long long int
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)

long long int sum_len=0;

ll mpow(ll a, ll b) {
    ll res = 1;
    while(b) {
        if(b&1) res *= a, res %= mod;
        a *= a;
        a %= mod;
        b >>= 1;
    }
    return res;
}

void solve()
{
    int n = readIntLn(1, MAX_N);
    sum_len += n;
    assert(sum_len <= MAX_SUM_N);
    string s = readStringLn(n, n);
    ll mv = mpow(2, n - 1), inv = mpow(2, mod - 2);
    ll cnt = 0, ans = 0;
    for(int i = 0; i < n; i+=2) {
        if(s[i] - '0') cnt ^= 1;
        if(cnt) {
            ans += mv;
            ans %= mod;
            mv *= inv;
            mv %= mod;
            if(i + 1 < n) {
                ans += mv;
                ans %= mod;
                mv *= inv;
                mv %= mod;
            }
        } else {
            mv *= inv;
            mv %= mod;
            mv *= inv;
            mv %= mod;
        }
    }
    cout << ans << "\n";
}

signed main()
{
    //fast;
    #ifndef ONLINE_JUDGE
    //freopen("input.txt", "r", stdin);
    //freopen("output.txt", "w", stdout);
    #endif


    int t = readIntLn(1, MAX_T);

    for(int i=1;i<=t;i++)
    {
        solve();
    }

    assert(getchar() == -1);
}
Tester's Solution 2
#include <bits/stdc++.h>
using namespace std;
 
 
/*
------------------------Input Checker----------------------------------
*/
 
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
 
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }
 
            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }
 
            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
 
 
/*
------------------------Main code starts here----------------------------------
*/
 
const int MAX_T = 1e5;
const int MAX_N = 1e5;
const int MAX_SUM_LEN = 1e5;
 
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define ff first
#define ss second
#define mp make_pair
#define ll long long
#define rep(i,n) for(int i=0;i<n;i++)
#define rev(i,n) for(int i=n;i>=0;i--)
#define rep_a(i,a,n) for(int i=a;i<n;i++)
 
int sum_n = 0, sum_m = 0;
int max_n = 0, max_m = 0;
int yess = 0;
int nos = 0;
int total_ops = 0;
ll mod = 998244353;

ll po(ll x, ll n){ 
    ll ans=1;
    while(n>0){ if(n&1) ans=(ans*x)%mod; x=(x*x)%mod; n/=2;}
    return ans;
}



void solve()
{   
    
    int n = readIntLn(1,1e5);
    sum_n+=n;
    max_n = max(max_n, n);

    string s = readStringLn(n,n);
    for(auto h:s) assert(h=='0' || h=='1');

    ll df[n] = {0};

    rep(i,n){
        if(s[i]=='1'){
            df[n-1-i]+=i+1;
        }
    }

    rev(i,n-2) df[i]+=df[i+1];

    ll ans = 0;
    rep(i,n){
        if(df[i]&1) ans+=po(2,i);
    }

    ans%=mod;
    cout<<ans<<'\n';
}
 
signed main()
{

    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r" , stdin);
    freopen("output.txt", "w" , stdout);
    #endif
    fast;
    
    int t = 1;
    
    t = readIntLn(1,100);
    
    for(int i=1;i<=t;i++)
    {    
       solve();
    }
   
    //assert(getchar() == -1);
    assert(sum_n<=2e5);
 
    cerr<<"SUCCESS\n";
    cerr<<"Tests : " << t << '\n';
    cerr<<"Sum of lengths : " << sum_n <<'\n';
    cerr<<"Maximum length : " << max_n <<'\n';
    // cerr<<"Total operations : " << total_ops << '\n';
    //cerr<<"Answered yes : " << yess << '\n';
    //cerr<<"Answered no : " << nos << '\n';
}
Editorialist's Solution (with comments)
import java.util.*;
import java.io.*;
class SUB_XOR{
    //SOLUTION BEGIN
    void pre() throws Exception{}
    void solve(int TC) throws Exception{
        int N = ni();
        char[] C = n().toCharArray();
        long[] cnt = new long[N];
        for(int i = 0; i< N; i++)
            if(C[i] == '1')
                cnt[N-i-1] += (1+i);//Adding contribution of on bits
        
        for(int i = N-2; i>= 0; i--)cnt[i] += cnt[i+1]; // Taking suffix sum to recover cnt
        
        //Converting cnt to decimal number
        long ans = 0, f = 1, MOD = 998244353;
        for(int i = 0; i< N; i++){
            cnt[i] %= 2; // Only the parity of count matters
            ans += f*cnt[i]%MOD;
            if(ans >= MOD)ans -= MOD;
            f = (f*2)%MOD;
        }
        pn(ans);
    }
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    static boolean multipleTC = true;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        new SUB_XOR().run();
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}

    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }

        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }

        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException  e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }

        String nextLine() throws Exception{
            String str = "";
            try{   
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }  
            return str;
        }
    }
}

Feel free to share your approach. Suggestions are welcome as always.

Why is my answer wrong? Kindly check my approach:
My approach was based on the idea that, going from the most significant bit to the least significant bit, the number of times each bit is XORed (with itself and with all less significant bits) is 1, 2, ..., n respectively.
So I go from MSB to LSB: if a bit is 1, its contribution is added to a variable called "i1", and if i1 is odd, that bit's positional value is added to the final answer.

#include <bits/stdc++.h>
using namespace std;

void solve(){long long i,n,j,ans,i1,m=998244353;n=0;j=0;i=0;i1=0;
    cin>>n;//char ch;
   string s;cin>>s;
  // for(i=0;i<n;i++){cin>>ch;s.push_back(ch);}
   //vector<int> a;
   ans=0;
   if(s[0]=='1')i1=1;else i1=0;i=pow(2,n-1);
   ans=i*i1;ans=ans%m;
   for(j=1;j<n;j++){if(s[j]=='1')i1+=j+1;
       i1=i1%2;
       if(i1){i=pow(2,n-1-j);i=i%m;//cout<<n-1-j<<" ";
       ans=ans+i;
       ans=ans%m;
       }
   }
// a.clear();
      ans=ans%m;
       cout<<ans<<"\n";

}

int main()
{int t;cin>>t;while(t--){solve();}
    
    return 0;
}

One observation that can be made about the XOR of all substrings: if the string s is
s = { a0,a1,a2,a3,a4,a5,a6,…,a(n-1)}
then the binary representation of the final decimal number, from the most significant bit to the least significant bit, is:
num = {a0,a0,(a0^a2),(a0^a2),(a0^a2^a4),(a0^a2^a4),(a0^a2^a4^a6),(a0^a2^a4^a6),…}
One can then easily iterate from index n-1 down to index 0 and convert it to its decimal representation.
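A sketch of this observation in C++ (the function name is hypothetical): bit k of num, counted from the most significant end, is the XOR of a_c over even c ≤ k, so a running parity is enough. Instead of iterating from n-1 down to 0 with explicit powers of two, it scans from a0 and uses Horner's rule (ans = 2*ans + bit), which gives the same value modulo 998244353. Tester's Solution 1 appears to use the same pattern, walking down from 2^{n-1} with the modular inverse of 2.

```cpp
#include <cstddef>
#include <string>

// Pattern-based O(n) sketch: num[k] = a0 ^ a2 ^ ... over even indices <= k,
// read from the most significant bit (k = 0) to the least significant one.
long long beautyFromPattern(const std::string& s) {
    const long long MOD = 998244353;
    long long ans = 0;
    int parity = 0;   // XOR of s[c] over even c seen so far
    for (std::size_t k = 0; k < s.size(); ++k) {
        if (k % 2 == 0 && s[k] == '1') parity ^= 1;   // only even indices change it
        ans = (ans * 2 + parity) % MOD;               // append bit k as the new lowest bit
    }
    return ans;
}
```

For example, 101000 gives num = 110000, i.e. 48, and 011 gives 001, i.e. 1.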

Hi,
Can you please explain in detail how you got this formula?
“num = {a0,a0,(a0^a2),(a0^a2),(a0^a2^a4),(a0^a2^a4),(a0^a2^a4^a6),(a0^a2^a4^a6),…}”

Starting from size 2, I kept increasing the size and noticed this pattern.
E.g. for S = {a0, a1, a2}, the substrings in binary form are a0, a1, a2, a0a1, a1a2, a0a1a2; XOR all of them to get the final string a0, a0, a0^a2.
Likewise, keep increasing the size and you will get the pattern mentioned above.

Sorry, but please explain this!

They asked us to take all possible substrings of the string and XOR their decimal representations.
E.g. if the string is 011, there are substrings of size 1 and 2, and finally the whole string of size 3.
Size 1: 0, 1, 1
Size 2: 01, 11
Size 3: 011
Now, if I want to XOR their decimal representations, I can do the same in binary form as well.
See, 11 is 3 in decimal and 01 is 1 in decimal, so 3^1 = 2, which is 10 in binary; I could get the same result by XORing directly in binary.
Now, to XOR the size-1 and size-2 numbers with the size-3 number, I simply add the required leading zeros (the original substring is shown in parentheses). So:
0 0 (0)
0 0 (1)
0 0 (1)
0 (0 1)
0 (1 1)
(0 1 1)
-------
0 0 1
Just replace the bits with a0, a1, a2 and you can see the pattern.

I don't mean to impose, but I noticed the same pattern, so maybe I can make it clear too. Here, index 0 is the leftmost (most significant) position, and shorter substrings are padded with leading zeros.

For any position, the bits at indices to its right never reach that position in any of the substrings. (The first observation)

Now, if you look at index 0, only the original element at index 0 can affect it, and only once (the original string).

When you go to index 1, the element at index 0 can reach position 1 once (the substring from index 0 to the second-last index), and the original element at index 1 can reach position 1 in 2 ways (the original string and the substring from index 1 to the last index).

Similarly, when you check index 2, you will notice the element at index 0 affecting it once, the element at index 1 affecting it twice, and the element at index 2 affecting it thrice.

The pattern is:
Index 0 -> affects all positions once.
Index 1 -> affects all positions twice (other than index 0).
Index 2 -> affects all positions thrice (other than indices 0 and 1).
This shows that elements at even indices affect their own position and every position to their right an odd number of times.
Elements at odd indices affect each position an even number of times, so they cancel out (no matter whether they are 0 or 1) and have no effect on higher indices either.

So, checking only the even indices, and giving each odd index position the same value as the preceding index, leads to the formula.
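The counting claim above (index c reaches every position k ≥ c exactly c+1 times, and never reaches a position k < c) can be checked by brute force for small n. Both helper names are hypothetical; positions are measured in the n-bit frame, with position 0 the most significant bit.

```cpp
// Counts the substrings s[l..r] (l <= c <= r) that place the character at
// index c at position k of the n-bit frame (position 0 = most significant
// bit; shorter substrings are padded with leading zeros).
int timesReached(int n, int c, int k) {
    int count = 0;
    for (int l = 0; l <= c; ++l)
        for (int r = c; r < n; ++r)
            if (c + (n - 1 - r) == k) ++count;   // right-aligning moves index c by n-1-r
    return count;
}

// True if, for this n, every index c reaches each position k >= c exactly
// c + 1 times and never reaches a position k < c.
bool claimHolds(int n) {
    for (int c = 0; c < n; ++c)
        for (int k = 0; k < n; ++k)
            if (timesReached(n, c, k) != (k >= c ? c + 1 : 0)) return false;
    return true;
}
```

Since c+1 is odd exactly when c is even, only even indices survive the XOR, which is the formula above.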

**I made a solution using the video solution provided by CodeChef.**
Can someone check why my code fails?

#include <iostream>
using namespace std;
#include <cmath>

int main() {
    // your code goes here
    int T;
    cin>>T;
    while(T--){
        long int N;
        cin>>N;
        char s[N];
        for(int i=0; i<N; i++){
            cin>>s[i];
        }

        int arr[N];
        for(int i=0; i<N; i++){
            arr[i]=0;
        }
        for(int j=0; j<N; j++){
            if(s[j]=='1'){
                for(int i=0; i<=N-j-1; i++){
                    arr[i]+=j+1;
                }
            }
        }

        for(int i=0; i<N; i++){
            if(arr[i]%2==0){
                arr[i]=0;
            }
            else{
                arr[i]=1;
            }
        }

        long int res=0;
        for(int i=0; i<N; i++){
            res=res+pow(2, i)*arr[i];
        }
        cout<<res%998244353<<endl;
    }
    return 0;

}

Well, it fails because this takes O(n^2) time. You need it to run in linear time, afaik.

Why does this give a wrong answer?
Instead of a suffix sum, I have taken a prefix sum.

#include <bits/stdc++.h>
using namespace std;

#define vi vector<int>
#define ll long long
#define vll vector<long long>
#define mod 998244353;
int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);
    ll t;
    cin>>t;
    while(t--){
        ll n;
        cin>>n;
        string str;
        cin>>str;
        vll prefix(n,0);
        for(ll i=0;i<n;i++){
            if(str[i]=='1'){
                if(i==0)
                    prefix[0]=i+1;
                else
                    prefix[i]=prefix[i-1]+i+1;
            }
            else{
                prefix[i]=prefix[i-1];
            }
        }
        ll xor_ans=0,expo=1;
        for(int i=n-1;i>=0;i--){
            if(prefix[i]%2){
                xor_ans = (xor_ans+expo) % mod;
            }
            expo=(expo*2)%mod;
            
        }
        cout<<xor_ans<<"\n";
    }
}

There is an overflow in the pow function: it works in floating point and does not reduce modulo anything while calculating large powers, so 2^(n-1) does not fit in a long long for large n. Better use a self-defined power function (preferably binary exponentiation, to avoid TLE), or just make an array that computes and stores the powers of 2 from 0 to n.
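Both suggested replacements for `pow` can be sketched as follows (the function names are hypothetical): a table of powers of two built in O(n), and a binary-exponentiation helper for one-off powers. Every product stays below 998244353^2 < 2^63, so `long long` does not overflow.

```cpp
#include <vector>

// Table of 2^0, 2^1, ..., 2^n modulo 998244353, built in O(n).
std::vector<long long> powersOfTwo(int n) {
    const long long MOD = 998244353;
    std::vector<long long> pw(n + 1);
    pw[0] = 1;
    for (int i = 1; i <= n; ++i) pw[i] = pw[i - 1] * 2 % MOD;
    return pw;
}

// Binary exponentiation: base^exp modulo mod in O(log exp) multiplications.
long long modPow(long long base, long long exp, long long mod) {
    long long result = 1;
    base %= mod;
    while (exp > 0) {
        if (exp & 1) result = result * base % mod;   // use this bit of exp
        base = base * base % mod;                    // square for the next bit
        exp >>= 1;
    }
    return result;
}
```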

The reason your approach fails is that, in the binary representation of the input string (let's call it s), s[0] is the most significant bit and s[s.length() - 1] is the least significant one. But judging from your code, you seem to have assumed it the other way around.

Here, I have made some changes to your code and got it accepted. Note that I have reversed the string s so that only minimal changes to your code are required.

#include <bits/stdc++.h>
using namespace std;

#define vi vector<int>
#define ll long long
#define vll vector<long long>
#define mod 998244353;
int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);
    ll t;
    cin>>t;
    while(t--){
        ll n;
        cin>>n;
        string str;
        cin>>str;
        reverse(str.begin(), str.end());
        vll prefix(n+1,0);

        for(ll i=n-1;i>=0;i--){
            if(str[i]=='1'){
                if(i==n-1)
                    prefix[i]=(n-i);
                else
                    prefix[i]=prefix[i+1]+(n-i);
            }
            else{
                prefix[i]=prefix[i+1];
            }
            
        }
        ll xor_ans=0,expo=1;
        for(int i=0;i<n;i++){
            if(prefix[i]%2){
                xor_ans = (xor_ans+expo) % mod;
            }
            expo=(expo*2)%mod;
            
        }
        cout<<xor_ans<<"\n";
    }
}

I made a mistake in my implementation. It was not really an overflow bug, but something similar: I computed the bit counts in modular arithmetic instead of exactly, and reducing modulo an odd number may remove an odd number of copies of that odd modulus from the exact value, which flips its parity …
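The parity pitfall described above fits in a tiny example. This is a sketch with hypothetical helper names, assuming non-negative counts. Every cnt_p is at most |S|(|S|+1)/2, about 5·10^9 for |S| = 10^5 (the limit in the testers' input checkers), so the exact count fits in a `long long` but not in a 32-bit `int`.

```cpp
// Modulus used by the problem; it is odd, which is what causes the pitfall.
constexpr long long kMod = 998244353;

// Correct: parity of the exact count.
bool exactIsOdd(long long cnt) { return cnt % 2 == 1; }

// Buggy variant: parity taken after reducing modulo kMod. Reducing removes
// (cnt / kMod) copies of an odd number, so the parity flips whenever that
// quotient is odd.
bool reducedIsOdd(long long cnt) { return (cnt % kMod) % 2 == 1; }
```

For example, a count equal to 998244353 is odd, but its reduction is 0, which is even. Keeping exact 64-bit counts, or tracking only the parity with `^= 1` as Tester's Solution 1 does, avoids the problem.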

See Tech Vineyard: Codechef SUB_XOR