XORDCTN - Editorial

PROBLEM LINK:

Practice
Contest

Author: Sahil Tiwari
Tester: Jay Sharma
Editorialist: Sarvesh Kesharwani

DIFFICULTY:

EASY

PREREQUISITES:

MSB (most significant bit), prefix sums.

PROBLEM:

Mycroft and Sherlock are playing the game of XOR-duction.

In this game, Mycroft gives Sherlock an array A of size N whose elements are denoted by a_1, a_2, \ldots, a_N.

Mycroft then asks Sherlock Q queries:

  • Each query consists of three integers l, r and x, and Sherlock has to find the number of indices i with l \le i \le r such that a_i \oplus x > x. Here \oplus denotes the bitwise XOR operation.

EXPLANATION:

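We have to count the indices i with l \le i \le r such that a_i \oplus x > x.
The most significant bit (MSB) of a positive number is its leftmost set bit.
For example, 17 = 10001_2, and its leftmost set bit is the 5th bit from the right, so the MSB of 17 is bit 5. (This explanation counts bits from 1; the solution code counts from 0, so there the MSB of 17 is bit 4.)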
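A minimal C++ sketch of computing the MSB position, assuming 32-bit unsigned values that are strictly positive (0 has no set bit). The function name `msbIndex` is hypothetical, and it returns the 0-based index used by the solution code, i.e. one less than the 1-based bit number used in this explanation.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: 0-based index of the most significant set bit of v.
// Requires v > 0, because 0 has no set bit.
int msbIndex(std::uint32_t v) {
    assert(v > 0);
    int idx = 0;
    // Count the right shifts that still leave v nonzero;
    // that count is the position of the highest set bit.
    while (v >>= 1) {
        ++idx;
    }
    return idx;
}
```

On GCC and Clang, `31 - __builtin_clz(v)` gives the same value for a nonzero 32-bit `v`, and C++20 provides `std::bit_width(v) - 1` in `<bit>`.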

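Recall that 0 \oplus 1 = 1 and 0 \oplus 0 = 1 \oplus 1 = 0, so XOR-ing x with a_i flips exactly the bits of x where a_i has a 1. Let bit j be the MSB of a_i. Then a_i \oplus x agrees with x on every bit above j and differs from x at bit j, so bit j alone decides the comparison; the lower bits cannot change it. Hence a_i \oplus x > x if and only if bit j of x is 0, because that bit becomes 0 \oplus 1 = 1. (If a_i = 0, then a_i \oplus x = x, so such an element is never counted.)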

For example, let a_i = 10 = 1010_2 and x = 5 = 101_2. The MSB of a_i is the 4th bit from the right, and the 4th bit of x is 0 (write 101 as 0101), so a_i \oplus x = 1111_2 = 15 > 5.

If the MSB of a_i is bit j and bit j of x is set, then a_i \oplus x can never be greater than x, because 1 \oplus 1 = 0 turns bit j from 1 into 0 while all higher bits stay the same.
For example, let a_i = 10 = 1010_2 and x = 8 = 1000_2. The MSB of a_i is the 4th bit, which is also set in x, so a_i \oplus x = 0010_2 = 2 < 8.
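The MSB rule can be sanity-checked against the direct comparison on small values. This sketch, with hypothetical helper names (`topBit`, `greaterByRule`, `ruleMatchesBruteForce`) and 32-bit unsigned values assumed, tries every a in [1, limit) and every x in [0, limit).

```cpp
#include <cassert>
#include <cstdint>

// 0-based index of the highest set bit; requires a > 0.
int topBit(std::uint32_t a) {
    assert(a > 0);
    int idx = 0;
    while (a >>= 1) ++idx;
    return idx;
}

// MSB rule from the text: a ^ x > x exactly when x has a 0
// at the position of a's most significant bit.
bool greaterByRule(std::uint32_t a, std::uint32_t x) {
    return ((x >> topBit(a)) & 1u) == 0;
}

// Compares the rule with the direct comparison (a ^ x) > x
// for all a in [1, limit) and x in [0, limit).
bool ruleMatchesBruteForce(std::uint32_t limit) {
    for (std::uint32_t a = 1; a < limit; ++a) {
        for (std::uint32_t x = 0; x < limit; ++x) {
            if (greaterByRule(a, x) != ((a ^ x) > x)) return false;
        }
    }
    return true;
}
```

The two examples above correspond to `greaterByRule(10, 5)` being true and `greaterByRule(10, 8)` being false.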

However, checking every a_i with l \le i \le r separately costs O(N) per query, i.e. O(N \cdot Q) in total (O(N^2) when Q is about N), which leads to TLE.
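For reference, here is a minimal sketch of that direct per-query scan. It is too slow for the full constraints but handy as a brute-force checker on small inputs; `countNaive` is a hypothetical name, and l and r are assumed 1-based and inclusive as in the statement.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical brute-force helper: counts indices i in [l, r]
// (1-based, inclusive) with (a_i XOR x) > x. Costs O(r - l + 1).
int countNaive(const std::vector<std::uint32_t>& a, int l, int r, std::uint32_t x) {
    assert(1 <= l && l <= r && r <= static_cast<int>(a.size()));
    int count = 0;
    for (int i = l; i <= r; ++i) {
        if ((a[i - 1] ^ x) > x) ++count;
    }
    return count;
}
```

For a = {10, 8, 17, 5, 18} and the query (l, r, x) = (1, 5, 8), only 17, 5 and 18 qualify, so the answer is 3.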

Instead, we pre-compute the MSB of every a_i (i = 1, 2, \ldots, N) once, in O(32 \cdot N), so that each query can be answered in O(32).

How can we pre-compute the MSBs?
Make a 2D array binary[N+1][32] with 1-based row indices, one column per bit position, and fill it with 0.
For every a_i, set binary[i][MSB of a_i] = 1.
For example, a_i = 18 = 10010_2 has its MSB at bit 5, so we set binary[i][5] = 1 (binary[i][4] with the 0-based bit numbering used in the code).
This way, row i records the MSB of a_i for every element of the array.
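A minimal sketch of the marking step, assuming 32-bit unsigned values and 0-based bit columns as in the solution code. The function name `buildMsbMarks` is hypothetical; an element equal to 0 has no MSB and is simply left unmarked, which is harmless because 0 \oplus x = x is never greater than x.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: marks[i][j] == 1 iff a_i (1-based) has its most
// significant bit at 0-based position j. Row 0 stays all zero so that the
// prefix-sum step can read row i - 1 for i = 1.
std::vector<std::array<int, 32>> buildMsbMarks(const std::vector<std::uint32_t>& a) {
    // Value-initialised, so every entry starts at 0.
    std::vector<std::array<int, 32>> marks(a.size() + 1);
    for (std::size_t i = 1; i <= a.size(); ++i) {
        std::uint32_t v = a[i - 1];
        if (v == 0) continue;   // 0 has no set bit, so nothing is marked
        int msb = 0;
        while (v >>= 1) ++msb;  // count shifts that leave v nonzero
        marks[i][msb] = 1;
    }
    return marks;
}
```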

Using prefix sums, we can then count how many of a_1, \ldots, a_i have their MSB at bit j: processing i = 1, 2, \ldots, N in increasing order, set binary[i][j] += binary[i-1][j] for every j.
Afterwards, binary[i][j] is the number of elements among a_1 to a_i whose MSB is bit j.
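The marking and prefix-sum steps can also be fused into a single pass, as sketched below under the same assumptions (32-bit unsigned values, 0-based bit columns). Copying row i-1 into the still-zero row i and then incrementing the MSB column yields exactly the table produced by marking first and then applying binary[i][j] += binary[i-1][j]. The names `MsbTable` and `buildMsbPrefix` are hypothetical.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical alias: one row per prefix length, one column per bit.
using MsbTable = std::vector<std::array<int, 32>>;

// pre[i][j] = number of a_1..a_i (1-based) whose MSB is at 0-based bit j.
// Row 0 is all zeros, so a range [l, r] can always use row l - 1.
MsbTable buildMsbPrefix(const std::vector<std::uint32_t>& a) {
    MsbTable pre(a.size() + 1);   // value-initialised: all zeros
    for (std::size_t i = 1; i <= a.size(); ++i) {
        pre[i] = pre[i - 1];      // carry over counts from a_1..a_{i-1}
        std::uint32_t v = a[i - 1];
        if (v == 0) continue;     // 0 has no MSB and is never counted
        int msb = 0;
        while (v >>= 1) ++msb;    // 0-based index of the highest set bit
        ++pre[i][msb];            // the "mark" for a_i
    }
    return pre;
}
```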

For each query, iterate over every bit position j. If bit j of x is 0, add the number of elements among a_l, \ldots, a_r whose MSB is bit j, which is binary[r][j] - binary[l-1][j], where [l, r] is the query range:
answer = \sum_{j : bit j of x is 0} (binary[r][j] - binary[l-1][j]).
Each query therefore takes only O(32) time.
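Putting the pieces together, a minimal sketch of the per-query step follows. It repeats a hypothetical `buildMsbPrefix` helper so that it compiles on its own, and the hypothetical `countGreater` answers one query (1-based, inclusive l and r) by summing binary[r][j] - binary[l-1][j] over the zero bits of x, assuming 32-bit unsigned values.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using MsbTable = std::vector<std::array<int, 32>>;

// pre[i][j] = how many of a_1..a_i have their MSB at 0-based bit j.
MsbTable buildMsbPrefix(const std::vector<std::uint32_t>& a) {
    MsbTable pre(a.size() + 1);   // all zeros
    for (std::size_t i = 1; i <= a.size(); ++i) {
        pre[i] = pre[i - 1];
        std::uint32_t v = a[i - 1];
        if (v == 0) continue;     // 0 never satisfies a_i ^ x > x
        int msb = 0;
        while (v >>= 1) ++msb;
        ++pre[i][msb];
    }
    return pre;
}

// Answers one query in O(32): sums pre[r][j] - pre[l-1][j] over every
// bit j that is 0 in x. l and r are 1-based and inclusive.
long long countGreater(const MsbTable& pre, int l, int r, std::uint32_t x) {
    assert(1 <= l && l <= r && r < static_cast<int>(pre.size()));
    long long ans = 0;
    for (int j = 0; j < 32; ++j) {
        if (((x >> j) & 1u) == 0) ans += pre[r][j] - pre[l - 1][j];
    }
    return ans;
}
```

Building the table costs O(32 \cdot N) and each call to `countGreater` costs O(32), independent of the range length.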

TIME COMPLEXITY:

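O(32 \cdot (N + Q)) per test case: O(32 \cdot N) to build the prefix table and O(32) per query. The table uses O(32 \cdot N) memory.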

SOLUTIONS:

Setter's Solution
```cpp
#include<bits/stdc++.h>
#define endl "\n"
#define int long long int
#define tt int tc;cin>>tc;while(tc--)
#define arrin(a,n) for(int INPUT=0;INPUT<n;INPUT++)cin>>a[INPUT]
#define jaldi_chal ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL)
using namespace std;

signed main()
{
    jaldi_chal;

    tt{
        int n;
        cin>>n;
        vector<int> a(n);
        arrin(a,n);

        // b[i][j] = number of a_1..a_i whose MSB is at (0-based) bit j
        vector<vector<int>> b(n+1 , vector<int> (32));

        for(int i=1;i<=n;i++){
            // 0 has no MSB and never satisfies a_i ^ x > x
            if(a[i-1]==0) continue;
            // integer MSB index; avoids floating-point log2
            int msb = 63 - __builtin_clzll(a[i-1]);
            b[i][msb]++;
        }
        for(int i=1;i<=n;i++){
            for(int j=0;j<32;j++)
                b[i][j]+=b[i-1][j];
        }

        int q;
        cin>>q;
        while(q--){
            int l , r , x;
            cin>>l>>r>>x;
            int ans=0;
            // count elements in [l, r] whose MSB falls on a zero bit of x
            for(int i=0;i<32;i++){
                if(((x>>i)&1)==0)
                    ans+=(b[r][i]-b[l-1][i]);
            }
            cout<<ans<<endl;
        }
    }

    return 0;
}
```
Tester's Solution
```cpp
#include <bits/stdc++.h>
using namespace std;

int32_t main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);
    int tt=1;
    cin >> tt;
    while(tt--)
    {
        int n;
        cin >> n;
        vector<int> a(n);   // std::vector instead of a non-standard VLA
        for(int i=0;i<n;i++)
            cin >> a[i];
        // pre[j] = sorted 1-based positions of elements whose MSB is bit j
        vector<int> pre[32];
        for(int i=0;i<n;i++)
        {
            for(int j=31;j>=0;j--)
            {
                if(a[i]&(1ll<<j))
                {
                    pre[j].push_back(i+1);
                    break;
                }
            }
        }
        int q;
        cin >> q;
        for(int i=0;i<q;i++)
        {
            int l,r,x;
            cin >> l >> r >> x;
            int ans=0;
            for(int j=31;j>=0;j--)
            {
                if((x&(1ll<<j))==0)
                {
                    // binary search: number of bit-j MSB positions inside [l, r]
                    ans+=upper_bound(pre[j].begin(),pre[j].end(),r)-lower_bound(pre[j].begin(),pre[j].end(),l);
                }
            }
            cout << ans << '\n';
        }
    }
    return 0;
}
```
Editorialist's Solution
```cpp
#include <bits/stdc++.h>
using namespace std;
using namespace chrono;
#define flash ios_base::sync_with_stdio(0);cin.tie(0);
#define ll long long
#define ndl '\n'
void solve()
{
    int tc;cin>>tc;while(tc--)
    {
        ll n;
        cin>>n;
        vector<ll> a(n);
        // b[i][j] = number of a[0..i] (0-based) whose MSB is at bit j
        vector<vector<int>> b(n,vector<int>(32));
        for(int i=0;i<n;i++){
          cin>>a[i];
          if(a[i]==0) continue;  // __lg(0) is undefined; 0 never satisfies a_i ^ x > x
          int msb = __lg(a[i]);
          b[i][msb]++;
        }
        for(int i=1;i<n;i++){
          for(int j=0;j<32;j++){
            b[i][j]+=b[i-1][j];
          }
        }
        int q; cin>>q;
        while(q--){
          int l,r,x;
          cin>>l>>r>>x;
          l--;  // switch to 0-based indices
          r--;
          int ans = 0 ;
          for(int i=0;i<32;i++){
            if(((x>>i)&1) == 0){
              ans+=b[r][i] - (l>0?b[l-1][i]:0);
            }
          }
          cout<<ans<<ndl;
        }
    }
}
int main()
{
    auto starttime = high_resolution_clock::now();
    flash
    solve();
    auto endtime = high_resolution_clock::now();
    double duration = duration_cast<microseconds>(endtime - starttime).count();
    duration/=1000000;
    cerr<<"Time Taken : "<<fixed<<setprecision(6)<<duration<<" secs"<<'\n';
    return 0;
}
```