How to solve the Xor sum problem from HackerEarth?

I am trying to solve the Xor sum problem from HackerEarth Practice.

I don't see how to apply the merge operation when building the segment tree for this problem. Any hints, ideally with code, are appreciated.

Thank you :slight_smile:

You need the count of each bit. Let $a_i$ be the number of elements in the given range that have bit $i$ set, and $b_i$ the number that have it clear. The answer is

$$\sum_i \left(\binom{a_i}{1}\binom{b_i}{2} + \binom{a_i}{3}\binom{b_i}{0}\right) \cdot 2^i$$
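One way to convince yourself of the per-bit formula is to compare it with brute force on a small array. This is a minimal sketch, not the judge solution: the function names `xorSumBrute` and `xorSumFormula` are hypothetical, no modulus is applied, and values are assumed to fit in 43 bits (the range the code in this thread loops over).

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

using u64 = std::uint64_t;

// Direct O(n^3) sum of a[i] ^ a[j] ^ a[k] over all i < j < k (no modulus).
u64 xorSumBrute(const std::vector<u64>& a) {
    u64 total = 0;
    const std::size_t n = a.size();
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t j = i + 1; j < n; ++j)
            for (std::size_t k = j + 1; k < n; ++k)
                total += a[i] ^ a[j] ^ a[k];
    return total;
}

// Per-bit formula: sum_i (C(ones,1)*C(zeros,2) + C(ones,3)*C(zeros,0)) * 2^i,
// where ones/zeros count the elements with bit i set/clear. Assumes values < 2^bits.
u64 xorSumFormula(const std::vector<u64>& a, int bits = 43) {
    u64 total = 0;
    const u64 n = a.size();
    for (int i = 0; i < bits; ++i) {
        u64 ones = 0;
        for (u64 x : a) ones += (x >> i) & 1;
        const u64 zeros = n - ones;
        const u64 triples = ones * (zeros * (zeros - 1) / 2)     // one 1, two 0s
                          + ones * (ones - 1) * (ones - 2) / 6;  // three 1s
        total += triples << i;
    }
    return total;
}
```

For `{1, 2, 3, 4}` the four triples XOR to 0, 7, 6 and 5, so both functions should give 18. On a query range the same formula applies with the counts restricted to that range.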


I got it: for the XOR of a triple to have a 1 in some bit, you need either one set bit and two clear bits, or three set bits in that position. The formula counts the number of ways to pick such triples, right?

Thank you :slightly_smiling_face:
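The "one or three set bits" observation can be checked by walking all eight combinations of a single bit taken from three numbers. A tiny sketch; `patternsWithXorOne` is a hypothetical name used only for illustration.

```cpp
#include <cassert>

// Walks all 8 combinations of one bit from each of three numbers and counts
// those whose XOR is 1; each such combination must have one or three 1s.
int patternsWithXorOne() {
    int count = 0;
    for (int x = 0; x <= 1; ++x)
        for (int y = 0; y <= 1; ++y)
            for (int z = 0; z <= 1; ++z)
                if ((x ^ y ^ z) == 1) {
                    const int ones = x + y + z;
                    assert(ones == 1 || ones == 3);  // 1 set + 2 clear, or 3 set
                    ++count;
                }
    return count;
}
```

Four patterns qualify: the three with a single 1 (`100`, `010`, `001`) and `111`, which is exactly why the count splits into $\binom{a_i}{1}\binom{b_i}{2}$ and $\binom{a_i}{3}$.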

I guess the segment tree is only there to get the number of set and clear bits in a range.
A Fenwick tree would be easier to implement, and since there are no updates, plain prefix counts (used below) are enough.
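A minimal sketch of the Fenwick-tree idea, assuming one Fenwick (binary indexed) tree per bit position. Since this problem has no updates, prefix counts already suffice; the point update below is an assumption beyond the original problem and is the only reason to prefer a Fenwick tree here. The class name `BitCountFenwick` and its `set`/`countOnes` methods are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using u64 = std::uint64_t;

// tree_[b] is a Fenwick tree over positions 1..n counting elements with bit b set.
class BitCountFenwick {
public:
    BitCountFenwick(int n, int bits)
        : n_(n), bits_(bits), tree_(bits, std::vector<int>(n + 1, 0)), val_(n + 1, 0) {}

    // Sets a[pos] = x (1-indexed), adjusting only the bits that changed.
    void set(int pos, u64 x) {
        assert(1 <= pos && pos <= n_);
        for (int b = 0; b < bits_; ++b) {
            const int delta = int((x >> b) & 1) - int((val_[pos] >> b) & 1);
            if (delta != 0) add(b, pos, delta);
        }
        val_[pos] = x;
    }

    // Number of elements in a[l..r] (1-indexed, inclusive) with bit b set.
    int countOnes(int b, int l, int r) const {
        return prefix(b, r) - prefix(b, l - 1);
    }

private:
    void add(int b, int pos, int delta) {
        for (; pos <= n_; pos += pos & -pos) tree_[b][pos] += delta;
    }
    int prefix(int b, int pos) const {
        int s = 0;
        for (; pos > 0; pos -= pos & -pos) s += tree_[b][pos];
        return s;
    }

    int n_, bits_;
    std::vector<std::vector<int>> tree_;
    std::vector<u64> val_;  // current values, needed to compute per-bit deltas
};
```

With `cnt1 = countOnes(i, l, r)` and `cnt0 = (r - l + 1) - cnt1`, the per-bit formula is applied exactly as with prefix counts; each query costs O(43 log n) instead of O(43).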

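```cpp
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
const ll MOD = 1000000007LL;
const int MAXN = 100005;
const int BITS = 43;        // bit positions 0..42

int n, q;
ll a[MAXN];
int seg[MAXN][BITS];        // prefix counts: seg[j][i] = how many of a[1..j] have bit i set
ll po[BITS];                // po[i] = 2^i mod MOD

/* Segment-tree alternative (not used below). The merge is just adding the
   per-bit counts of the two children. It needs its own array with 4*n nodes:

int tree[4 * MAXN][BITS];
void build(int node, int st, int en)
{
    if (st == en) {
        for (int i = 0; i < BITS; i++)
            tree[node][i] = (a[st] >> i) & 1;          // shift the value, safe for bits >= 31
        return;
    }
    int mid = (st + en) / 2;
    build(2 * node, st, mid);
    build(2 * node + 1, mid + 1, en);
    for (int i = 0; i < BITS; i++)
        tree[node][i] = tree[2 * node][i] + tree[2 * node + 1][i];   // merge
}
int qry(int node, int st, int en, int l, int r, int idx)
{
    if (st > r || en < l) return 0;                    // no overlap
    if (l <= st && en <= r) return tree[node][idx];    // fully inside [l, r]
    int mid = (st + en) / 2;
    return qry(2 * node, st, mid, l, r, idx) + qry(2 * node + 1, mid + 1, en, l, r, idx);
}
*/

// No updates, so prefix counts are enough: a range count is a difference of two prefixes.
void create()
{
    for (int i = 0; i < BITS; i++)
        for (int j = 1; j <= n; j++)
            seg[j][i] = seg[j - 1][i] + (int)((a[j] >> i) & 1);
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n;
    for (int i = 1; i <= n; i++)
        cin >> a[i];

    po[0] = 1;
    for (int i = 1; i < BITS; i++)
        po[i] = po[i - 1] * 2 % MOD;

    create();

    ll extra;                // the input has a second value after q; it is read and ignored
    cin >> q >> extra;
    while (q--)
    {
        int l, r;
        cin >> l >> r;
        ll ans1 = 0;
        for (int i = 0; i < BITS; i++)
        {
            ll cnt1 = seg[r][i] - seg[l - 1][i];        // elements of [l, r] with bit i set
            ll cnt0 = (r - l + 1) - cnt1;               // elements of [l, r] with bit i clear
            ll ans = cnt1 * (cnt0 * (cnt0 - 1) / 2);    // one 1 and two 0s
            ans += cnt1 * (cnt1 - 1) * (cnt1 - 2) / 6;  // three 1s
            ans %= MOD;
            ans = ans * po[i] % MOD;                    // weight by 2^i; both < MOD, fits in ll
            ans1 = (ans1 + ans) % MOD;
        }
        cout << ans1 << "\n";
    }
    return 0;
}
```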
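For the segment-tree route the question asks about, the merge is element-wise addition of per-bit counts, because counts over two disjoint halves simply add up. A runnable sketch, assuming 0-indexed input; `BitCountSegTree`, `mergeCounts` and `Counts` are hypothetical names. An out-of-range query returns all zeros, which is the identity of the merge.

```cpp
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <vector>

using u64 = std::uint64_t;
constexpr int BITS = 43;                       // bit positions 0..42
using Counts = std::array<long long, BITS>;    // per-bit count of set bits

// The merge: a node's per-bit counts are the sum of its children's counts.
Counts mergeCounts(const Counts& left, const Counts& right) {
    Counts out{};
    for (int b = 0; b < BITS; ++b) out[b] = left[b] + right[b];
    return out;
}

struct BitCountSegTree {
    int n;
    std::vector<Counts> t;

    explicit BitCountSegTree(const std::vector<u64>& a)
        : n(static_cast<int>(a.size())), t(4 * std::max<std::size_t>(a.size(), 1)) {
        if (n > 0) build(1, 0, n - 1, a);
    }

    void build(int node, int lo, int hi, const std::vector<u64>& a) {
        if (lo == hi) {
            t[node] = Counts{};
            for (int b = 0; b < BITS; ++b) t[node][b] = (a[lo] >> b) & 1;
            return;
        }
        const int mid = (lo + hi) / 2;
        build(2 * node, lo, mid, a);
        build(2 * node + 1, mid + 1, hi, a);
        t[node] = mergeCounts(t[2 * node], t[2 * node + 1]);
    }

    // Per-bit counts of set bits in a[l..r] (0-indexed, inclusive); requires n > 0.
    Counts query(int l, int r) const { return query(1, 0, n - 1, l, r); }

private:
    Counts query(int node, int lo, int hi, int l, int r) const {
        if (r < lo || hi < l) return Counts{};         // no overlap: identity of merge
        if (l <= lo && hi <= r) return t[node];        // node fully inside [l, r]
        const int mid = (lo + hi) / 2;
        return mergeCounts(query(2 * node, lo, mid, l, r),
                           query(2 * node + 1, mid + 1, hi, l, r));
    }
};
```

With `c = query(l, r)`, use `cnt1 = c[i]` and `cnt0 = (r - l + 1) - c[i]` in the per-bit formula. Each query costs O(43 log n) versus O(43) with prefix counts, so the segment tree only pays off if the problem had updates.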

Explanation

- `cnt1` is the number of elements in the range [L, R] whose binary representation has a 1 in bit position i: `cnt1=seg[r][i]-seg[l-1][i];`
- `cnt0` is the number of elements in [L, R] with a 0 in bit position i: `cnt0=r-l+1-cnt1;`
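These two lookups are plain prefix-count differences. A small sketch of the same layout as the `seg[j][i]` table (row 0 all zeros, rows 1..n for the elements), assuming the input arrives 0-indexed; `buildPrefix` and `bitCounts` are hypothetical helper names.

```cpp
#include <cstdint>
#include <utility>
#include <vector>

using u64 = std::uint64_t;

// pre[j][b] = number of the first j elements with bit b set; pre[0][*] = 0.
std::vector<std::vector<int>> buildPrefix(const std::vector<u64>& a, int bits) {
    const int n = static_cast<int>(a.size());
    std::vector<std::vector<int>> pre(n + 1, std::vector<int>(bits, 0));
    for (int j = 1; j <= n; ++j)
        for (int b = 0; b < bits; ++b)
            pre[j][b] = pre[j - 1][b] + static_cast<int>((a[j - 1] >> b) & 1);
    return pre;
}

// {cnt1, cnt0} for bit b over the 1-indexed range [l, r].
std::pair<int, int> bitCounts(const std::vector<std::vector<int>>& pre, int b, int l, int r) {
    const int cnt1 = pre[r][b] - pre[l - 1][b];
    const int cnt0 = (r - l + 1) - cnt1;
    return {cnt1, cnt0};
}
```

For the array `5, 6, 3, 8` and range [2, 4] (elements 6, 3, 8), only 3 has bit 0 set, so `cnt1 = 1` and `cnt0 = 2`.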

- `(cnt0*(cnt0-1))/2` is $\binom{cnt_0}{2}$, the number of distinct pairs that can be formed from the elements with a 0 in bit position i.
- Multiplying it by `cnt1` gives the number of triples made of one 1 and two 0s, and each of them has XOR 1 in that bit (because 1 xor 0 xor 0 = 1): `ans=cnt1*(cnt0*(cnt0-1))/2;`

- `(cnt1*(cnt1-1)*(cnt1-2))/6` is $\binom{cnt_1}{3}$, the number of distinct triples that can be formed from the elements with a 1 in bit position i.
- Each of these triples also has XOR 1 in that bit (because 1 xor 1 xor 1 = 1), so it is added to the count: `ans += (cnt1*(cnt1-1)*(cnt1-2))/6;`
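To double-check the two combination counts for a single bit, the sketch below compares the closed form `cnt1 * C(cnt0, 2) + C(cnt1, 3)` with brute-force enumeration over one 0/1 bit column. The function names are hypothetical; the closed form is the same expression used for `ans` before the modulus.

```cpp
#include <cstddef>
#include <vector>

// Closed form: one 1 with two 0s, plus three 1s.
long long triplesWithXorOne(long long cnt1, long long cnt0) {
    return cnt1 * (cnt0 * (cnt0 - 1) / 2) + cnt1 * (cnt1 - 1) * (cnt1 - 2) / 6;
}

// Same count by enumerating every index triple i < j < k of a 0/1 column.
long long triplesWithXorOneBrute(const std::vector<int>& column) {
    long long count = 0;
    const std::size_t n = column.size();
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t j = i + 1; j < n; ++j)
            for (std::size_t k = j + 1; k < n; ++k)
                count += column[i] ^ column[j] ^ column[k];
    return count;
}
```

For the column `1 0 0 1 1 0` (three 1s, three 0s) both give 3 * 3 + 1 = 10. The formula also returns 0 automatically when there are fewer than two 0s or three 1s, since one factor becomes 0.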

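- Finally, the count is reduced modulo `MOD`, multiplied by $2^i$ (precomputed modulo `MOD` in `po[i]`), reduced again, and added to the running total `ans1`: `ans%=MOD; ans=ans*po[i]; ans%=MOD; ans1+=ans; ans1%=MOD;`
- Reducing `ans` before the multiplication matters: both factors are then below `MOD`, so the product stays below about $10^{18}$ and fits in a `long long`.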