BIN_OD - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: Sahil Tiwari
Testers: Takuki Kurokawa, Utkarsh Gupta
Editorialist: Nishank Suresh

DIFFICULTY:

1776

PREREQUISITES:

Prefix sums

PROBLEM:

You are given an array A and Q queries on it. For each query, you are given two subarrays and an integer k.
Find the number of pairs of elements, one from the first subarray and one from the second, such that their bitwise xor has the k-th bit set.

EXPLANATION:

Let’s look at answering a single query (k, L_1, R_1, L_2, R_2) first: speeding it up to answer multiple queries can come later.

Suppose A_i \oplus A_j has its k-th bit set. This is only possible when:

  • A_i has its k-th bit set and A_j doesn’t; or
  • A_j has its k-th bit set and A_i doesn’t
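Below is a tiny brute-force check of this case analysis, written as a sketch. The helper names `bit` and `xor_bit_set` are hypothetical and only for illustration. Bit k of A_i \oplus A_j equals the xor of the two individual k-th bits, so it is set exactly when those bits differ.

```python
def bit(x: int, k: int) -> int:
    """Return the k-th bit (0 or 1) of a non-negative integer x."""
    return (x >> k) & 1


def xor_bit_set(a: int, b: int, k: int) -> bool:
    """True iff bit k of a XOR b is set, i.e. exactly one of a, b has bit k set."""
    return bit(a, k) != bit(b, k)


# Exhaustive check on small values: the case analysis agrees with computing a ^ b directly.
for a in range(64):
    for b in range(64):
        for k in range(6):
            assert xor_bit_set(a, b, k) == (bit(a ^ b, k) == 1)
```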

In particular, if we take some A_i from [L_1, R_1] with its k-th bit set, we can pair it with any A_j from [L_2, R_2] whose k-th bit is unset.
Similarly, if we take some A_i from [L_1, R_1] with its k-th bit unset, we can pair it with any A_j from [L_2, R_2] whose k-th bit is set.

This gives us a rather simple solution:

  • Let S_1 be the number of elements in subarray [L_1, R_1] that have the k-th bit set
  • Let U_1 be the number of elements in subarray [L_1, R_1] that have the k-th bit unset
  • Let S_2 be the number of elements in subarray [L_2, R_2] that have the k-th bit set
  • Let U_2 be the number of elements in subarray [L_2, R_2] that have the k-th bit unset

Then, the answer to this query is simply S_1\cdot U_2 + S_2\cdot U_1.
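Here is a minimal sketch of this single-query approach. It assumes 1-based inclusive indices as in the statement, and `answer_naive` and `answer_brute` are hypothetical names. The brute-force version counts the pairs directly, which makes it handy for checking the formula on small inputs.

```python
from itertools import product


def answer_naive(a, k, l1, r1, l2, r2):
    """One query in O(R1 - L1 + R2 - L2) time using S1*U2 + S2*U1 (1-based, inclusive)."""
    s1 = sum((a[i - 1] >> k) & 1 for i in range(l1, r1 + 1))  # set bits in range 1
    u1 = (r1 - l1 + 1) - s1                                    # unset bits in range 1
    s2 = sum((a[j - 1] >> k) & 1 for j in range(l2, r2 + 1))  # set bits in range 2
    u2 = (r2 - l2 + 1) - s2                                    # unset bits in range 2
    return s1 * u2 + s2 * u1


def answer_brute(a, k, l1, r1, l2, r2):
    """Directly count pairs (i, j) whose XOR has bit k set."""
    return sum(
        ((a[i - 1] ^ a[j - 1]) >> k) & 1
        for i, j in product(range(l1, r1 + 1), range(l2, r2 + 1))
    )
```

For example, take A = [1, 2, 3, 4, 5], k = 0, [L_1, R_1] = [1, 2] and [L_2, R_2] = [3, 5]. Both functions return 3.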

Computing S_1, S_2, U_1, U_2 is easy to do by looping across the subarrays, but that takes \mathcal{O}(N) time per query, which is too slow when there are many queries: we need something faster.

Using prefix sums

Notice that, if k is fixed, we can treat each element of the array as being either 0 or 1 depending on whether it has the k-th bit set or not.

Then, the above variables simplify quite nicely:

  • S_1 and S_2 are the number of ones in their respective ranges, or equivalently, just the sums of those ranges.
  • U_1 and U_2 are the number of zeros in their respective ranges. Knowing S_1, S_2, and the lengths of the ranges is enough to compute these values, since S_1 + U_1 = R_1 - L_1 + 1 and S_2 + U_2 = R_2 - L_2 + 1.

Computing range sums quickly is a well-known application of prefix sums.
We need to maintain separate prefix sums for each k, but there are only 60 possible values of k anyway so this is not an issue.

That is, for each 0 \leq k \lt 60, let pref_{k, i} denote the number of elements in [1, i] that have the k-th bit set.
Then,

  • S_1 = pref_{k, R_1} - pref_{k, L_1-1}
  • S_2 = pref_{k, R_2} - pref_{k, L_2-1}
  • U_1 and U_2 can be computed as noted above.
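Here is a sketch of the prefix-sum version under the same 1-based convention; the function names are hypothetical. It stores one row per bit, pref[k][i], to mirror the notation pref_{k, i} above. The editorialist's code below uses the transposed layout pref[i][k], which works equally well.

```python
def build_prefix(a, bits=60):
    """pref[k][i] = number of elements among A_1..A_i with bit k set (pref[k][0] = 0)."""
    n = len(a)
    pref = [[0] * (n + 1) for _ in range(bits)]
    for k in range(bits):
        row = pref[k]
        for i, x in enumerate(a, start=1):
            row[i] = row[i - 1] + ((x >> k) & 1)
    return pref


def answer(pref, k, l1, r1, l2, r2):
    """O(1) query: S from prefix differences, U from range length minus S."""
    s1 = pref[k][r1] - pref[k][l1 - 1]
    s2 = pref[k][r2] - pref[k][l2 - 1]
    u1 = (r1 - l1 + 1) - s1
    u2 = (r2 - l2 + 1) - s2
    return s1 * u2 + s2 * u1
```

The leading zero column (pref[k][0] = 0) means L - 1 = 0 needs no special case.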

This allows us to answer each query in \mathcal{O}(1) time.

TIME COMPLEXITY:

\mathcal{O}(60\cdot N + Q) per test case.

CODE:

Setter's code (C++)
//	Code by Sahil Tiwari (still_me)

#include<bits/stdc++.h>
#define still_me main
#define endl "\n"
#define int long long int
#define all(a) (a).begin() , (a).end()
#define print(a) for(auto TEMPORARY: a) cout<<TEMPORARY<<" ";cout<<endl;
#define tt int TESTCASE;cin>>TESTCASE;while(TESTCASE--)
#define arrin(a,n) for(int INPUT=0;INPUT<n;INPUT++)cin>>a[INPUT]

using namespace std;
const int mod = 1e9+7;
const int inf = 1e18;

void solve() {
    int n , q;
    cin>>n>>q;
    vector<int> a(n);
    arrin(a , n);
    vector<vector<int>> b(n+1 , vector<int>(61));
    for(int i=0;i<n;i++) {
        for(int j=0;j<61;j++) {
            if(a[i] & (1ll << j))
                b[i+1][j]++;
            b[i+1][j] += b[i][j];
        }
    }
    while(q--) {
        int k , l , r , x , y;
        cin>>k>>l>>r>>x>>y;
        int o1 = b[r][k] - b[l-1][k];
        int o2 = b[y][k] - b[x-1][k];
        int z1 = r-l+1 - o1;
        int z2 = y-x+1 - o2;
        cout<<(o1*z2 + o2*z1)<<endl;
    }

}

signed still_me()
{
    ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL);

    tt{
        solve();
    }
    return 0;
}
Tester's code (C++)
//Utkarsh.25dec
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <cmath>
#include <vector>
#include <set>
#include <map>
#include <unordered_set>
#include <unordered_map>
#include <queue>
#include <ctime>
#include <cassert>
#include <complex>
#include <string>
#include <cstring>
#include <chrono>
#include <random>
#include <bitset>
#include <array>
#define ll long long int
#define pb push_back
#define mp make_pair
#define mod 1000000007
#define vl vector <ll>
#define all(c) (c).begin(),(c).end()
using namespace std;
ll power(ll a,ll b) {ll res=1;a%=mod; assert(b>=0); for(;b;b>>=1){if(b&1)res=res*a%mod;a=a*a%mod;}return res;}
ll modInverse(ll a){return power(a,mod-2);}
const int N=500023;
bool vis[N];
vector <int> adj[N];
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);

            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }

            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }

            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
int sumN=0,sumQ=0;
void solve()
{
    int n=readInt(1,100000,' ');
    sumN+=n;
    int q=readInt(1,500000,'\n');
    sumQ+=q;
    assert(sumN<=100000);
    assert(sumQ<=500000);
    int sum[n+1][65];
    memset(sum,0,sizeof(sum));
    long long A[n+1];
    memset(A,0,sizeof(A));
    for(int i=1;i<=n;i++)
    {
        if(i==n)
            A[i]=readInt(0,1LL<<60,'\n');
        else
            A[i]=readInt(0,1LL<<60,' ');
        for(int j=0;j<60;j++)
        {
            sum[i][j]=sum[i-1][j];
            if((A[i]&(1LL<<j))!=0)
                sum[i][j]++;
        }
    }
    while(q--)
    {
        int k=readInt(0,59,' ');
        int l1=readInt(1,n,' ');
        int r1=readInt(l1,n,' ');
        int l2=readInt(r1+1,n,' ');
        int r2=readInt(l2,n,'\n');
        long long left1s=sum[r1][k]-sum[l1-1][k];
        long long left0s=(r1-l1+1)-left1s;
        long long right1s=sum[r2][k]-sum[l2-1][k];
        long long right0s=(r2-l2+1)-right1s;
        cout<<(left1s*right0s)+(left0s*right1s)<<'\n';
    }
}
int main()
{
    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    #endif
    ios_base::sync_with_stdio(false);
    cin.tie(NULL),cout.tie(NULL);
    int T=readInt(1,50000,'\n');
    while(T--)
        solve();
    assert(getchar()==-1);
    cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
}
Editorialist's code (Python)
for _ in range(int(input())):
    n, q = map(int, input().split())
    a = list(map(int, input().split()))
    pref = [[0 for i in range(60)] for _ in range(n+1)]
    for i in range(n):
        for k in range(60):
            pref[i+1][k] = pref[i][k] + ((a[i] >> k) & 1)
    for i in range(q):
        k, l1, r1, l2, r2 = map(int, input().split())
        on1, on2 = pref[r1][k] - pref[l1-1][k], pref[r2][k] - pref[l2-1][k]
        off1, off2 = r1-l1+1 - on1, r2-l2+1 - on2
        print(on1*off2 + on2*off1)

I can’t understand what is wrong with my code. Please help!
https://www.codechef.com/viewsolution/80193617


1 << j is computed in 32-bit int, so it breaks for j \geq 31: 1 << 31 already overflows into the sign bit, and larger shifts are undefined behaviour. Use 1LL << j instead.

Unfortunately, that appears to be your only mistake.


https://www.codechef.com/viewsolution/80275344

Runtime error (SIGSEGV) on some of the tests.
Does anyone know why?

@zoharbarak
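It is because you have declared a as a vector of int, but it should be long long.

When cin tries to read a value that doesn’t fit in an int, the extraction fails and sets the stream’s failbit. Every later read then fails too, so the following variables (such as the query indices) can be left holding garbage. They then end up used as out-of-range indices, which causes the SIGSEGV.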

1
2 1
1000000000000000 1000000000000000
1 1 1 2 2

If you try this custom test case with your original code, you will get a runtime error in the CodeChef IDE.

Also, you should watch for overflow in the final calculation of res.
Modified Accepted Solution
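To see why the final answer needs 64-bit arithmetic, here is a rough bound. It assumes N \leq 10^5 (the limit the tester's code checks) and disjoint query ranges, so that (R_1 - L_1 + 1) + (R_2 - L_2 + 1) \leq N:

```latex
S_1 U_2 + S_2 U_1 \le (S_1 + U_1)(S_2 + U_2) = (R_1 - L_1 + 1)(R_2 - L_2 + 1) \le \left(\tfrac{N}{2}\right)^2 = 2.5 \cdot 10^{9} > 2^{31} - 1 \approx 2.147 \cdot 10^{9}
```

The bound is reached when every element of the first range has bit k set and every element of the second range has it unset. In that case even the single product S_1 \cdot U_2 exceeds the range of a 32-bit int.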


Thanks 🙂

https://www.codechef.com/viewsolution/80284830

I’m not sure why this gets TLE (I just loop n & q)

You are not reading any input inside the loop over q. So the first value of the first query (k) is read as n of the next test case, and everything after that is misaligned, which can cause TLE.

There can be up to 5\cdot 10^4 test cases, and you’re creating an array of size 10^5 \times 61 for each one. That’s over 10^{11} operations just to allocate and zero the memory, so it’s no surprise that you get TLE.
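The reply above is about C++, but the fix carries over to any language. Below is a hedged Python sketch; `solve_all` is a hypothetical name, and the input format is assumed to match what the editorialist's code reads. It sizes the prefix table by each test case's own n, so the total work is proportional to 60 \cdot \sum N instead of T \cdot 60 \cdot \max N. It also reads each query's values inside the query loop.

```python
def solve_all(data: str) -> str:
    """Answer every query in the whole input; returns the answers joined by newlines.

    Assumed input format: T, then for each test case "N Q", the N array values,
    and Q lines "k L1 R1 L2 R2" (1-based, inclusive).
    """
    it = map(int, data.split())
    out = []
    for _ in range(next(it)):
        n, q = next(it), next(it)
        a = [next(it) for _ in range(n)]
        # 60 rows of length n + 1, sized by this test case only.
        pref = []
        for k in range(60):
            row = [0] * (n + 1)
            for i, x in enumerate(a, start=1):
                row[i] = row[i - 1] + ((x >> k) & 1)
            pref.append(row)
        for _ in range(q):
            # Each query's five values are consumed here, inside the query loop.
            k, l1, r1, l2, r2 = (next(it) for _ in range(5))
            s1 = pref[k][r1] - pref[k][l1 - 1]
            s2 = pref[k][r2] - pref[k][l2 - 1]
            out.append(s1 * ((r2 - l2 + 1) - s2) + s2 * ((r1 - l1 + 1) - s1))
    return "\n".join(map(str, out))


# Usage on a judge (reads the whole input at once):
#   import sys
#   sys.stdout.write(solve_all(sys.stdin.read()) + "\n")
```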


1<<j will overflow; use 1LL<<j instead.


Hello All,
Can anyone please tell me what’s wrong with my code? It was failing one test case.
My Code
Thanks.

@celestialidiot
Using k = log2(mx) was causing errors.
Use k = 59.
Modified Accepted Solution
https://www.codechef.com/viewsolution/80307160
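We can’t see the original code, but there are two plausible ways k = log2(mx) goes wrong:

  • Queries may ask about any bit 0 \leq k \lt 60, including bits above the highest set bit of the maximum element. A table with only log2(mx) + 1 columns can then be indexed out of range.
  • Floating-point log2 can be off by one for large values.

The sketch below (hypothetical helper names) shows the second pitfall. Always allocating all 60 bits, as suggested in the reply, sidesteps both.

```python
import math


def highest_bit_float(x: int) -> int:
    """Highest set bit via floating-point log2 (can be wrong for large x)."""
    return int(math.log2(x))


def highest_bit_exact(x: int) -> int:
    """Highest set bit via exact integer arithmetic (requires x > 0)."""
    return x.bit_length() - 1


# 2**60 - 1 has 60 one-bits, so its highest set bit is 59, but converting it
# to a double rounds it up to exactly 2**60, and log2 then reports 60.
x = 2**60 - 1
print(highest_bit_float(x), highest_bit_exact(x))  # prints: 60 59
```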

Thanks

It took me the whole day to code this, and I learned a new concept: prefix sums on bits. I didn’t think this existed, and I couldn’t find any blog post about it.
🥹 If anyone has a set of practice problems on prefix sums over bits, please reply or mail it to me (yadav11adu@gmail.com).

#include <bits/stdc++.h>
using namespace std;

#define int     long long int

// count the xor value whose kth bit is set
void solve(){
    int n,q;
    cin>>n>>q;
    vector<int> v(n);
    vector<vector<int>> prefix(n,vector<int> (62));

    for(int i=0;i<n;i++){
        cin>>v[i];

        vector<int> temp(61); // the loop below writes indices 0..60, so 60 slots were one too few
        for(int j=0;j<=60;j++){
            if((v[i]>>j)&1) temp[j] = 1;
            else temp[j] = 0;

            if(i == 0) prefix[i][j] = temp[j];
            else prefix[i][j] = prefix[i-1][j] + temp[j];
            // cout<<prefix[i][j]<<" ";
        }
        // cout<<endl;

    }

    while(q--){
        int k,l1,r1,l2,r2;
        cin>>k>>l1>>r1>>l2>>r2;

        l1--,r1--,l2--,r2--;

        int pr1;
        if(l1 - 1 < 0) pr1 = 0;
        else pr1 = prefix[l1-1][k];

        int pr2;
        if(l2 - 1 < 0) pr2 = 0;
        else pr2 = prefix[l2-1][k];

        int FirstSet = prefix[r1][k] - pr1;
        int FirstUnSet = (r1 - l1 + 1) - FirstSet;

        int SecondSet = prefix[r2][k] - pr2;
        int SecondUnSet = (r2 - l2 + 1) - SecondSet;

        // cout<<FirstSet<<" "<<FirstUnSet<<" "<<SecondSet<<" "<<SecondUnSet<<endl;
        cout<<FirstSet*SecondUnSet + FirstUnSet*SecondSet<<'\n'; // '\n' avoids flushing on every query
    }
}

signed main() {
    ios::sync_with_stdio(0);
    cin.tie(0);

    int t=1;
    cin>>t;
    while(t--) solve();
    return 0;
}


https://www.codechef.com/viewsolution/80299712
Why does this code give me TLE?

The issue is these lines:

counts1[j]=counts1[j]+[counts1[j][-1]+((l[i])%2)]
counts0[j]=counts0[j]+[counts0[j][-1]+((((l[i])%2)+1)%2)]

If A is an array of length N, doing A = A + [x] in Python takes \mathcal{O}(N) time since it creates a copy of A, appends to it, then assigns the new list to A.
Because of this, your code is actually \mathcal{O}(60N^2).

Since you want to append to the array, just use Python’s inbuilt append function instead, which works in \mathcal{O}(1): this change alone makes your code fast enough, see submission.
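Here is a small sketch of the difference; the function names are hypothetical. Both functions build the same prefix list for one bit, but the first copies the whole list on every step.

```python
def prefix_by_concat(bits):
    """O(N^2) overall: each `pref = pref + [...]` builds a brand-new list."""
    pref = [0]
    for b in bits:
        pref = pref + [pref[-1] + b]
    return pref


def prefix_by_append(bits):
    """O(N) overall: list.append is amortized O(1) and modifies the list in place."""
    pref = [0]
    for b in bits:
        pref.append(pref[-1] + b)
    return pref
```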

You aren’t going to find a blog post on it because it’s not actually anything special or a ‘technique’.
If you can find the prefix sums of one array, you can obviously do it for 2 arrays, 3 arrays, \ldots, 60 arrays, right? That’s essentially what you’re doing here: applying prefix sums on 60 different arrays.
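To make the “60 arrays” view concrete, here is a sketch that reuses one ordinary prefix-sum routine for each bit. The function names are hypothetical, and itertools.accumulate with the initial keyword requires Python 3.8 or newer.

```python
from itertools import accumulate


def prefix_sums(arr):
    """Ordinary prefix sums with a leading 0: result[i] = arr[0] + ... + arr[i-1]."""
    return list(accumulate(arr, initial=0))


def bit_prefix_table(a, bits=60):
    """Apply the same prefix-sum routine to 60 derived 0/1 arrays, one per bit k."""
    return [prefix_sums([(x >> k) & 1 for x in a]) for k in range(bits)]
```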
