BIN_OD - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: Sahil Tiwari
Testers: Takuki Kurokawa, Utkarsh Gupta
Editorialist: Nishank Suresh

DIFFICULTY:

1776

PREREQUISITES:

Prefix sums

PROBLEM:

You are given an array A and Q queries on it. For each query, you are given two subarrays and an integer k.
Find the number of pairs of elements, one from the first subarray and one from the second, such that their bitwise xor has the k-th bit set.

EXPLANATION:

Let’s look at answering a single query (k, L_1, R_1, L_2, R_2) first: speeding it up to answer multiple queries can come later.

Suppose A_i \oplus A_j has its k-th bit set. This is only possible when:

  • A_i has its k-th bit set and A_j doesn’t; or
  • A_j has its k-th bit set and A_i doesn’t

In particular, if we take some A_i from [L_1, R_1] with its k-th bit set, we can pair it with any A_j from [L_2, R_2] whose k-th bit is unset.
Similarly, if we take some A_i from [L_1, R_1] with its k-th bit unset, we can pair it with any A_j from [L_2, R_2] whose k-th bit is set.

This gives us a rather simple solution:

  • Let S_1 be the number of elements in subarray [L_1, R_1] that have the k-th bit set
  • Let U_1 be the number of elements in subarray [L_1, R_1] that have the k-th bit unset
  • Let S_2 be the number of elements in subarray [L_2, R_2] that have the k-th bit set
  • Let U_2 be the number of elements in subarray [L_2, R_2] that have the k-th bit unset

Then, the answer to this query is simply S_1\cdot U_2 + S_2\cdot U_1.
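As a sanity check of this formula, here is a small Python sketch (the function names are hypothetical, not from the editorial) that answers one query by scanning both subarrays, next to a brute force that tests every pair explicitly. Ranges are assumed to be 1-indexed and inclusive, as in the statement.

```python
def answer_query(a, k, l1, r1, l2, r2):
    """Answer a single query by scanning both subarrays (1-indexed, inclusive)."""
    first = a[l1 - 1:r1]
    second = a[l2 - 1:r2]
    s1 = sum((x >> k) & 1 for x in first)        # k-th bit set
    u1 = sum(1 - ((x >> k) & 1) for x in first)  # k-th bit unset
    s2 = sum((x >> k) & 1 for x in second)
    u2 = sum(1 - ((x >> k) & 1) for x in second)
    return s1 * u2 + s2 * u1


def brute_force(a, k, l1, r1, l2, r2):
    """Check every pair (i, j) explicitly; only meant for cross-checking."""
    return sum(
        1
        for i in range(l1, r1 + 1)
        for j in range(l2, r2 + 1)
        if ((a[i - 1] ^ a[j - 1]) >> k) & 1
    )
```

For example, with a = [1, 2, 3, 4], k = 0 and ranges [1, 2] and [3, 4], both functions give 2.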

Computing S_1, S_2, U_1, U_2 is easy to do by looping across the subarrays, but that takes \mathcal{O}(N) time per query, which is not fast enough to answer many queries: we need something faster.

Using prefix sums

Notice that, if k is fixed, we can treat each element of the array as being either 0 or 1 depending on whether it has the k-th bit set or not.

Then, the above variables simplify quite nicely:

  • S_1 and S_2 are the number of ones in their respective ranges, or more specifically, just the sums of those ranges.
  • U_1 and U_2 are the number of zeros in their respective ranges. Knowing S_1, S_2, and the lengths of the ranges is enough to compute these values (since S_1 + U_1 = R_1-L_1+1 and S_2 + U_2 = R_2-L_2+1).
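The 0/1 view above can be sketched in a few lines of Python (the helper names are hypothetical). For a fixed k, each element becomes its k-th bit, S is a plain range sum, and U follows from the range length.

```python
def bit_array(a, k):
    """Map each element to 1 if its k-th bit is set, else 0."""
    return [(x >> k) & 1 for x in a]


def set_and_unset(bits, l, r):
    """Return (S, U) for the 1-indexed inclusive range [l, r]."""
    s = sum(bits[l - 1:r])  # number of ones = sum of the range
    u = (r - l + 1) - s     # S + U = R - L + 1
    return s, u
```

For a = [1, 2, 3, 4] and k = 1 the 0/1 array is [0, 1, 1, 0], and the range [2, 4] has S = 2 and U = 1.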

Computing range sums quickly is a well-known application of prefix sums.
We need to maintain separate prefix sums for each k, but there are only 60 possible values of k anyway so this is not an issue.

That is, for each 0 \leq k \lt 60, let pref_{k, i} denote the number of elements in [1, i] that have the k-th bit set.
Then,

  • S_1 = pref_{k, R_1} - pref_{k, L_1-1}
  • S_2 = pref_{k, R_2} - pref_{k, L_2-1}
  • U_1 and U_2 can be computed as noted above.

This allows us to answer each query in \mathcal{O}(1) time.
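Here is a minimal Python sketch of the per-bit prefix sums, assuming the editorial's bound 0 ≤ k < 60. Unlike the editorialist's code below, it stores the table as pref[k][i] (one prefix array per bit) and builds each row with itertools.accumulate; the names BITS, build_prefix and query are hypothetical.

```python
from itertools import accumulate

BITS = 60  # assumption from the editorial: 0 <= k < 60


def build_prefix(a):
    """pref[k][i] = number of elements among the first i with bit k set."""
    return [
        [0] + list(accumulate((x >> k) & 1 for x in a))
        for k in range(BITS)
    ]


def query(pref, k, l1, r1, l2, r2):
    """Answer one query in O(1) using 1-indexed inclusive ranges."""
    s1 = pref[k][r1] - pref[k][l1 - 1]
    s2 = pref[k][r2] - pref[k][l2 - 1]
    u1 = (r1 - l1 + 1) - s1
    u2 = (r2 - l2 + 1) - s2
    return s1 * u2 + s2 * u1
```

Keeping one list per bit is just a layout choice: pref[i][k] (as in the setter's and editorialist's code) gives the same answers.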

TIME COMPLEXITY

\mathcal{O}(60\cdot N + Q) per test case.

CODE:

Setter's code (C++)
//	Code by Sahil Tiwari (still_me)

#include<bits/stdc++.h>
#define still_me main
#define endl "\n"
#define int long long int
#define all(a) (a).begin() , (a).end()
#define print(a) for(auto TEMPORARY: a) cout<<TEMPORARY<<" ";cout<<endl;
#define tt int TESTCASE;cin>>TESTCASE;while(TESTCASE--)
#define arrin(a,n) for(int INPUT=0;INPUT<n;INPUT++)cin>>a[INPUT]

using namespace std;
const int mod = 1e9+7;
const int inf = 1e18;

void solve() {
    int n , q;
    cin>>n>>q;
    vector<int> a(n);
    arrin(a , n);
    vector<vector<int>> b(n+1 , vector<int>(61));
    for(int i=0;i<n;i++) {
        for(int j=0;j<61;j++) {
            if(a[i] & (1ll << j))
                b[i+1][j]++;
            b[i+1][j] += b[i][j];
        }
    }
    while(q--) {
        int k , l , r , x , y;
        cin>>k>>l>>r>>x>>y;
        int o1 = b[r][k] - b[l-1][k];
        int o2 = b[y][k] - b[x-1][k];
        int z1 = r-l+1 - o1;
        int z2 = y-x+1 - o2;
        cout<<(o1*z2 + o2*z1)<<endl;
    }

}

signed still_me()
{
    ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL);

    tt{
        solve();
    }
    return 0;
}
Tester's code (C++)
//Utkarsh.25dec
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <cmath>
#include <vector>
#include <set>
#include <map>
#include <unordered_set>
#include <unordered_map>
#include <queue>
#include <ctime>
#include <cassert>
#include <complex>
#include <string>
#include <cstring>
#include <chrono>
#include <random>
#include <bitset>
#include <array>
#define ll long long int
#define pb push_back
#define mp make_pair
#define mod 1000000007
#define vl vector <ll>
#define all(c) (c).begin(),(c).end()
using namespace std;
ll power(ll a,ll b) {ll res=1;a%=mod; assert(b>=0); for(;b;b>>=1){if(b&1)res=res*a%mod;a=a*a%mod;}return res;}
ll modInverse(ll a){return power(a,mod-2);}
const int N=500023;
bool vis[N];
vector <int> adj[N];
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);

            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }

            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }

            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
int sumN=0,sumQ=0;
void solve()
{
    int n=readInt(1,100000,' ');
    sumN+=n;
    int q=readInt(1,500000,'\n');
    sumQ+=q;
    assert(sumN<=100000);
    assert(sumQ<=500000);
    int sum[n+1][65];
    memset(sum,0,sizeof(sum));
    long long A[n+1];
    memset(A,0,sizeof(A));
    for(int i=1;i<=n;i++)
    {
        if(i==n)
            A[i]=readInt(0,1LL<<60,'\n');
        else
            A[i]=readInt(0,1LL<<60,' ');
        for(int j=0;j<60;j++)
        {
            sum[i][j]=sum[i-1][j];
            if((A[i]&(1LL<<j))!=0)
                sum[i][j]++;
        }
    }
    while(q--)
    {
        int k=readInt(0,59,' ');
        int l1=readInt(1,n,' ');
        int r1=readInt(l1,n,' ');
        int l2=readInt(r1+1,n,' ');
        int r2=readInt(l2,n,'\n');
        long long left1s=sum[r1][k]-sum[l1-1][k];
        long long left0s=(r1-l1+1)-left1s;
        long long right1s=sum[r2][k]-sum[l2-1][k];
        long long right0s=(r2-l2+1)-right1s;
        cout<<(left1s*right0s)+(left0s*right1s)<<'\n';
    }
}
int main()
{
    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    #endif
    ios_base::sync_with_stdio(false);
    cin.tie(NULL),cout.tie(NULL);
    int T=readInt(1,50000,'\n');
    while(T--)
        solve();
    assert(getchar()==-1);
    cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
}
Editorialist's code (Python)
for _ in range(int(input())):
    n, q = map(int, input().split())
    a = list(map(int, input().split()))
    pref = [[0 for i in range(60)] for _ in range(n+1)]
    for i in range(n):
        for k in range(60):
            pref[i+1][k] = pref[i][k] + ((a[i] >> k) & 1)
    for i in range(q):
        k, l1, r1, l2, r2 = map(int, input().split())
        on1, on2 = pref[r1][k] - pref[l1-1][k], pref[r2][k] - pref[l2-1][k]
        off1, off2 = r1-l1+1 - on1, r2-l2+1 - on2
        print(on1*off2 + on2*off1)

I can’t understand what is wrong with my code. Please help!
https://www.codechef.com/viewsolution/80193617


1 << j is computed in int, so it does not give 2^j for j \geq 31 (and shifting an int by 32 or more is undefined behavior). Use 1LL << j instead.

Unfortunately, that appears to be your only mistake.


https://www.codechef.com/viewsolution/80275344

Runtime error (SIGSEGV) on some of the tests.
Does anyone know why?

@zoharbarak
It is because you have defined a as a vector of int, and it should be long long .

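What happens is that cin tries to read a value that does not fit in an int, so the read fails and cin sets its failbit. Every later read then fails as well, so the remaining variables (such as the query bounds) never receive their intended values, and indexing with them can go out of bounds.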

1
2 1
1000000000000000 1000000000000000
1 1 1 2 2

If you try this custom test case in your original code, you will get a runtime error on CodeChef ide.

Also, watch out for overflow in the final calculation of res.
Modified Accepted Solution: CodeChef: Practical coding for everyone


Thanks 🙂

https://www.codechef.com/viewsolution/80284830

I’m not sure why this gets TLE (I just loop n & q)

You are not reading any input inside the loop over q, so the first value of the first query (k) gets read as n for the next test case. Everything after that is misaligned, which can cause TLE.

There can be up to 5\cdot 10^4 test cases, and you're creating an array of size 10^5 \times 61 for each one. That's over 10^{11} operations just to allocate the memory, so it's no surprise that you get TLE.
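Spelling out the arithmetic, assuming the limits asserted in the tester's code (at most 5\cdot 10^4 test cases and \sum N \le 10^5), and comparing against sizing the table to N+1 rows per test case as the setter's and editorialist's code do:

```latex
\underbrace{5\cdot 10^{4}}_{\text{test cases}} \times \underbrace{10^{5}\times 61}_{\text{cells per array}} = 3.05\cdot 10^{11}
\qquad\text{vs.}\qquad
61\cdot\sum (N+1) \le 61\cdot\left(10^{5} + 5\cdot 10^{4}\right) \approx 9.2\cdot 10^{6}
```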


1<<j will overflow; use 1LL<<j instead.


Hello All,
Can anyone please tell me what’s wrong with my code? It fails one test case.
My Code
Thanks.

@celestialidiot
Using k = log2(mx) was causing errors.
Use k = 59.
Modified Accepted Solution
https://www.codechef.com/viewsolution/80307160

Thanks

It took me the whole day to code. I learned a new concept: prefix sums on bits. I didn't think this existed, and I couldn't find any blog post related to it.
🥹 If anyone has a set of questions to practice prefix sums on bits, please reply or mail them to me (yadav11adu@gmail.com).

#include <bits/stdc++.h>
using namespace std;

#define int     long long int

// count the xor value whose kth bit is set
void solve(){
    int n,q;
    cin>>n>>q;
    vector<int> v(n);
    vector<vector<int>> prefix(n,vector<int> (62));

    for(int i=0;i<n;i++){
        cin>>v[i];

        vector<int> temp(61);  // bits 0..60 are written below, so 61 slots are needed
        for(int j=0;j<=60;j++){
            if((v[i]>>j)&1) temp[j] = 1;
            else temp[j] = 0;

            if(i == 0) prefix[i][j] = temp[j];
            else prefix[i][j] = prefix[i-1][j] + temp[j];
            // cout<<prefix[i][j]<<" ";
        }
        // cout<<endl;

    }

    while(q--){
        int k,l1,r1,l2,r2;
        cin>>k>>l1>>r1>>l2>>r2;

        l1--,r1--,l2--,r2--;

        int pr1;
        if(l1 - 1 < 0) pr1 = 0;
        else pr1 = prefix[l1-1][k];

        int pr2;
        if(l2 - 1 < 0) pr2 = 0;
        else pr2 = prefix[l2-1][k];

        int FirstSet = prefix[r1][k] - pr1;
        int FirstUnSet = (r1 - l1 + 1) - FirstSet;

        int SecondSet = prefix[r2][k] - pr2;
        int SecondUnSet = (r2 - l2 + 1) - SecondSet;

        // cout<<FirstSet<<" "<<FirstUnSet<<" "<<SecondSet<<" "<<SecondUnSet<<endl;
        cout<<FirstSet*SecondUnSet + FirstUnSet*SecondSet<<endl;
    }
}

signed main() {
    ios::sync_with_stdio(0);
    cin.tie(0);

    int t=1;
    cin>>t;
    while(t--) solve();
    return 0;
}


https://www.codechef.com/viewsolution/80299712
Why does this code give me TLE?

The issue there is the lines

counts1[j]=counts1[j]+[counts1[j][-1]+((l[i])%2)]
counts0[j]=counts0[j]+[counts0[j][-1]+((((l[i])%2)+1)%2)]

If A is an array of length N, doing A = A + [x] in Python takes \mathcal{O}(N) time since it creates a copy of A, appends to it, then assigns the new list to A.
Because of this, your code is actually \mathcal{O}(60N^2).

Since you want to append to the array, just use Python’s inbuilt append function instead, which works in \mathcal{O}(1): this change alone makes your code fast enough, see submission.
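To make the difference concrete, here is a hedged Python sketch (hypothetical function names, not the poster's actual code) that builds the same prefix-count list both ways. Each `counts = counts + [x]` copies the whole list, while `list.append` adds to the end in amortized constant time.

```python
def prefix_counts_slow(values):
    """O(N^2) overall: each concatenation copies the whole list."""
    counts = [0]
    for v in values:
        counts = counts + [counts[-1] + v]
    return counts


def prefix_counts_fast(values):
    """O(N) overall: list.append extends the list in place."""
    counts = [0]
    for v in values:
        counts.append(counts[-1] + v)
    return counts
```

Both return [0, 1, 1, 2, 3] for the input [1, 0, 1, 1]; only the running time differs.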

You aren’t going to find a blog post on it because it’s not actually anything special or a ‘technique’.
If you can find the prefix sums of one array, you can obviously do it for 2 arrays, 3 arrays, \ldots, 60 arrays, right? That’s essentially what you’re doing here: applying prefix sums to 60 different arrays.


Can you help me understand how you have used prefix sums here, or maybe provide some links on how to use prefix sums?

I can’t understand why we are taking prefix[n][62].

prefix[n][62]
Here n is the number of elements in the array.
62 is the number of bit positions stored for each element. Queries only ask about bits 0 to 59 (k < 60), so 60 columns would already be enough; 62 just leaves some slack.
prefix[i][j] is the bit prefix sum: how many of the first i+1 elements have bit j set. It is a small part of bit manipulation.

https://www.youtube.com/watch?v=L_fIn5TM3mM&list=PL-Jc9J83PIiFJRioti3ZV7QabwoJK6eKe&index=27&ab_channel=Pepcoding

Here is some code that may help you understand it better:

#include "bits/stdc++.h"
using namespace std;


#define int        long long int
#define now(x)     cout<<#x<<" : "<<x<<endl;

bool arrSame(vector<int>& bits,vector<int>& ors){
    for(int i=0;i<bits.size();i++){
        if((bits[i] > 0 and ors[i] == 0) or 
            (bits[i] == 0 and ors[i] > 0))
            return false;
    }
    return true;
}

bool orCalculate(vector<int> v,vector<int> ors,int mid){
    vector<int> bits(32,0);
    for(int i=0;i<mid;i++) for(int j=30;j>=0;j--) 
        if(v[i] & (1ll<<j)) bits[j]++,ors[j]--;

    if(arrSame(bits,ors)){
        // cout<<mid<<endl;
        return true;
    } 

    for(int i=mid;i<v.size();i++){

        for(int j=30;j>=0;j--){
            if(v[i - mid] & (1ll<<j)) bits[j]--,ors[j]++;
            if(v[i] & (1ll<<j)) bits[j]++,ors[j]--;
        }
        
        if(arrSame(bits,ors)){
            // cout<<mid<<endl;
            return true;  
        } 
    }   

    return false;
}

void solve() {
    int n;
    cin >> n;
    vector<int> v(n);
    for (int i = 0; i < v.size(); i++) {cin >> v[i];}

    vector<int> bits(32,0);
    for(int i=0;i<v.size();i++) for(int j=30;j>=0;j--) 
        if(v[i] & (1ll << j)) bits[j]++;

    int low = 1, high = n, ans = -1;  // a window has at most n elements; mid = n+1 would read v[n] out of bounds
    while (low <= high) {
        int mid = (low + high) >> 1;
        if (orCalculate(v, bits, mid)) {
            ans = max(ans, mid);
            low = mid + 1;
        } else {
            high = mid - 1;
        }
    }

    // cout<<"ASDF : ";
    cout << ans << endl;
}

signed main() {
    ios_base::sync_with_stdio(false); cin.tie(NULL);
    int t = 1;
    cin >> t;
    while (t--) solve();
    return 0;
}

Some basic code for better understanding

***Count the set bits in array***

vector<int> set(32,0);
for(int i=0;i<n;i++)
    for(int j=30;j>=0;j--)
        if(v[i] & (1ll << j)) set[j]++;
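A Python counterpart of the snippet above, as a sketch (count_set_bits is a hypothetical name). Like the C++ loop over j = 30..0, it looks at bits 0 to 30 by default.

```python
def count_set_bits(v, width=31):
    """cnt[j] = how many numbers in v have bit j set (bits 0..width-1)."""
    cnt = [0] * width
    for x in v:
        for j in range(width):
            if (x >> j) & 1:
                cnt[j] += 1
    return cnt
```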

***Basic Manipulation***

int onmask = (1<<i);
int offmask = ~(1<<j);
int togglemask = (1<<k);
int checkmask = (1<<m);
cout<<( n | onmask )<<endl;
cout<<( n & offmask )<<endl;
cout<<( n ^ togglemask )<<endl;
cout<<( (n & checkmask) == 0 ? "false" : "true" )<<endl;

// check whether the bits are set or not
for(int j = 0;j<32;j++){
    unsigned mask = (1u << j);  // 1u avoids signed overflow at j = 31
    bits[j] += ((mask & v[i]) == mask);
}
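The same four mask operations in Python, as a sketch with hypothetical helper names. Python integers are unbounded, so `1 << i` never overflows here, unlike an int in C++.

```python
def set_bit(n, i):
    """Turn bit i on."""
    return n | (1 << i)


def clear_bit(n, j):
    """Turn bit j off."""
    return n & ~(1 << j)


def toggle_bit(n, k):
    """Flip bit k."""
    return n ^ (1 << k)


def is_set(n, m):
    """True if bit m of n is set."""
    return ((n >> m) & 1) == 1
```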

***Power of Two***

1<<n equals pow(2,n) (for n \leq 30 with int; use 1LL<<n for larger n). For double values, fmod(n,mod) computes the remainder.

***How many BST are possible with n nodes?***

(2n)! / ((n+1)! n!), the n-th Catalan number
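A small Python check of this formula (hypothetical function names): it compares the closed form against a direct count that tries every key as the root and multiplies the number of left and right subtrees.

```python
from math import factorial


def num_bsts_formula(n):
    """(2n)! / ((n+1)! * n!), the n-th Catalan number."""
    return factorial(2 * n) // (factorial(n + 1) * factorial(n))


def num_bsts_dp(n):
    """Count BST shapes directly: pick a root, multiply subtree counts."""
    dp = [1] + [0] * n
    for nodes in range(1, n + 1):
        dp[nodes] = sum(dp[left] * dp[nodes - 1 - left] for left in range(nodes))
    return dp[n]
```

For n = 0..5 both give 1, 1, 2, 5, 14, 42.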

***How do I pass multiple ints into a vector at once?***

vector<int> array;
array.insert(array.end(), { 1, 2, 3, 4, 5, 6 });

***Concatenating two vectors***

vector1.insert( vector1.end(), vector2.begin(), vector2.end() );

***Power Set***
Check whether a bit is set with respect to the number.

***Check K-th Bit Set***
return (n&(1<<k));
or
return (1&(n>>k));

***Set K-th Bit***
return (N | (1 << K));


***Storing in desc order in map***

map<int, string, greater<int> > mymap;
multimap<int, string, greater<int> > mymap;
set<int, greater<int> > s1;

***Priority Queue***

// Creates a max heap
priority_queue <int> pq;
// Creates a min heap
priority_queue <int, vector<int>, greater<int> > pq;

***Finding No of Digit in Integer***

// Find total number of digits - 1 (for n > 0)
int digits = (int)log10(n);

**Negative Number Modulo**

((a%n) + n)%n
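Why the extra "+ n" is needed: in C/C++ the remainder takes the sign of a, so -7 % 3 is -1. The Python sketch below (hypothetical helper names) emulates the C/C++ remainder explicitly, since Python's own % already returns a non-negative result for n > 0, and shows that the formula then gives the same value as Python's %.

```python
def c_remainder(a, n):
    """Remainder with C/C++ semantics (quotient truncated toward zero)."""
    q = abs(a) // abs(n)
    if (a < 0) != (n < 0):
        q = -q
    return a - n * q


def positive_mod(a, n):
    """((a % n) + n) % n, with % taken as the C/C++ remainder (n > 0)."""
    return c_remainder(c_remainder(a, n) + n, n)
```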

**Binary Exponentiation**

// assumes #define int long long and a global mod (e.g. 1e9+7)
int power(int a , int b) {
    if(b == 0)
        return 1;
    int res = power(a , b>>1);
    if(b & 1)
        return (res * res % mod) * a % mod;
    return res * res % mod;
}

int binExpIter(int a,int b){
    int ans = 1;
    while(b){
        if(b&1) ans = (ans * a) % mod;
        a = (a*a) % mod;
        b >>= 1;
    }
    return ans;
}
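Both versions translated to Python as a sketch (assuming MOD = 10^9 + 7, as in the setter's and tester's code), so they can be checked against Python's built-in three-argument pow.

```python
MOD = 10**9 + 7  # assumption: the usual modulus from the setter/tester code


def power_rec(a, b, mod=MOD):
    """Recursive binary exponentiation: a^b mod `mod`."""
    if b == 0:
        return 1
    half = power_rec(a, b >> 1, mod)
    result = half * half % mod
    if b & 1:
        result = result * a % mod
    return result


def power_iter(a, b, mod=MOD):
    """Iterative binary exponentiation: square a, use it when the low bit of b is 1."""
    ans = 1
    a %= mod
    while b:
        if b & 1:
            ans = ans * a % mod
        a = a * a % mod
        b >>= 1
    return ans
```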


***Negative Mod***

#define ma_mod(a,n) ((((a)%(n))+(n))%(n))


**Number of Subsequences** nC0 + nC1 + nC2 + … + nCn = 2^n
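A quick Python check of this identity (hypothetical function names): summing the binomial coefficients gives 2^n, and enumerating bitmasks from 0 to 2^n - 1 produces exactly that many subsequences, including the empty one.

```python
from math import comb


def count_subsequences(n):
    """Sum of C(n, k) for k = 0..n (includes the empty subsequence)."""
    return sum(comb(n, k) for k in range(n + 1))


def all_subsequences(s):
    """Each bitmask 0..2^n - 1 selects one subsequence of s."""
    n = len(s)
    return [
        "".join(s[i] for i in range(n) if (mask >> i) & 1)
        for mask in range(1 << n)
    ]
```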

***Decimal to Binary***

void decToBinary(int n) {
    for (int i = 31; i >= 0; i--) {
        int k = n >> i;
        if (k & 1)
            cout << "1";
        else
            cout << "0";
    }
}
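A Python counterpart that returns the bits as a string instead of printing them (dec_to_binary is a hypothetical name). Python's >> on a negative number behaves like an arithmetic shift, so -1 gives 32 ones, matching what the 32-bit C++ loop prints.

```python
def dec_to_binary(n, width=32):
    """Bits of n from position width-1 down to 0, like the C++ loop."""
    return "".join("1" if (n >> i) & 1 else "0" for i in range(width - 1, -1, -1))
```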

***Substrings of a String***

for (int i = 0; i < str.length(); i++) {
    string subStr;
    // Second loop is generating sub-string
    for (int j = i; j < str.length(); j++) {
        subStr += str[j];
        cout << subStr << endl;
    }
}
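The same nested loops in Python, as a sketch (all_substrings is a hypothetical name) that collects the substrings in the order the C++ code prints them. A string of length n has n(n+1)/2 of them.

```python
def all_substrings(s):
    """Every contiguous substring s[i..j], in the same order as the C++ loops."""
    result = []
    for i in range(len(s)):
        for j in range(i, len(s)):
            result.append(s[i:j + 1])
    return result
```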

Sir, I have been observing your skill over the past few contests. Please suggest some topics I should study in order to improve my coding skills.