XORCIST - Editorial

PROBLEM LINK:

Contest

Author: Shahjalal Shohag
Tester: Ildar Gainullin
Editorialist: Rajarshi Basu

DIFFICULTY:

Easy-Medium

PREREQUISITES:

Prefix Sums

PROBLEM:

You are given a sequence of integers A_1, A_2, \ldots, A_N. You have to answer Q queries.

In each query, you are given two integers L and R, and you have to find the number of ordered pairs (X, Y) such that L \le X, Y \le R and A_X \le A_X \oplus A_Y \le A_Y. Here, \oplus is the bitwise XOR operator.

Constraints

  • 1 \le T \le 50,000
  • 1 \le N, Q \le 5 \cdot 10^5
  • 0 \le A_i \lt 2^{30} for each valid i
  • 1 \le L \le R \le N
  • the sum of N over all test cases does not exceed 5 \cdot 10^5
  • the sum of Q over all test cases does not exceed 5 \cdot 10^5

EXPLANATION:

Brief Solution

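For a_x > 0, the condition a_x \leq a_x \oplus a_y \leq a_y holds exactly when the bit at position MSB_{a_x} is set in a_y and MSB_{a_y} > MSB_{a_x}. For a_x = 0 the condition holds for every a_y.
After this, we can answer each query in O(\log_2 MAX_{A_i}) time, where MAX_{A_i} = 2^{30}, using prefix sums over each bit.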

=====================

How to approach

Ah, a problem about XOR again. This should immediately make us think about the binary representation of the numbers instead of brute-force techniques. Let us first check the condition for one particular pair a_x and a_y with a_x < a_y. Try to think about the second condition, i.e., a_x \oplus a_y \leq a_y.

Analysing how the bits work

From the upper bound, i.e., a_x \oplus a_y \leq a_y, and assuming a_x < a_y (the two bounds together force a_x \leq a_y anyway), we can derive the following information:

Info 1
  • notice that MSB_{a_x} \leq MSB_{a_y}, because otherwise a_x > a_y.
Hint For Info 2
  • In binary, if the k^{th} bit of a number is set, then that number is greater than any number whose highest set bit is below k, even one with all bits up to the (k-1)^{th} set. Try to apply this along with the MSB positions of a_x and a_y.
Info 2
  • The MSB_{a_x}^{th} bit must be set in a_y. Otherwise, the XOR turns that bit on in a_y without changing any higher bit, so a_x \oplus a_y will be greater than a_y. For example, consider the following bit patterns:
    • 101001
    • 010010
  • The \oplus of the above two numbers (111011) is greater than the larger number. However, for the following pair it is not (the \oplus is 100011):
    • 101001
    • 001010
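
The two pairs of bit patterns above can be checked mechanically. This is only a minimal sketch; the helper name `upper_bound_holds` is made up for illustration and does not appear in the setter's or tester's code.

```cpp
#include <cstdint>

// Hypothetical helper: checks only the upper bound a_x XOR a_y <= a_y.
bool upper_bound_holds(std::uint32_t ax, std::uint32_t ay) {
    return (ax ^ ay) <= ay;
}
```

With a_y = 101001 (41 in decimal), calling it with a_x = 010010 (18) returns false, while a_x = 001010 (10) returns true, matching the two examples.
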
But is that all?

There is, however, another bound we need to maintain, namely a_x \leq a_x \oplus a_y. How can we ensure that it holds? Hint: it’s just a tiny modification to the information above.

The tiny change

The only change needed is observing that, for the lower bound to hold, the inequality must be strict: MSB_{a_x} < MSB_{a_y}. Think why on your own. Try to take some examples.
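
One way to try such examples is a small experiment like the sketch below. The helpers `msb_pos` and `lower_bound_holds` are hypothetical names introduced here for illustration only.

```cpp
#include <cstdint>

// Position of the highest set bit, or -1 for 0 (illustrative helper).
int msb_pos(std::uint32_t v) {
    int pos = -1;
    while (v) {
        ++pos;
        v >>= 1;
    }
    return pos;
}

// Checks only the lower bound a_x <= a_x XOR a_y.
bool lower_bound_holds(std::uint32_t ax, std::uint32_t ay) {
    return ax <= (ax ^ ay);
}
```

For a_x = 5 (101) and a_y = 6 (110) both MSBs are at position 2, the XOR is 3, and the lower bound fails. For a_x = 2 (010) and a_y = 6 (110) the MSB of a_y is strictly higher, the XOR is 4, and the lower bound holds.
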


Final solution idea

For a_x > 0, the condition a_x \leq a_x \oplus a_y \leq a_y holds exactly when the bit at position MSB_{a_x} is set in a_y and MSB_{a_y} > MSB_{a_x}. For a_x = 0 the condition holds for every a_y.
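
To gain some confidence in this criterion, it can be compared against the definition from the problem statement on all small values. The sketch below is an illustration, not part of the reference solutions: `ok_direct`, `ok_by_msb`, `highest_bit` and `criteria_agree` are hypothetical names, and treating a_x = 0 as always valid is an assumption that mirrors the separate zero term in the setter's code.

```cpp
#include <cstdint>

// Position of the highest set bit, or -1 for 0 (illustrative helper).
int highest_bit(std::uint32_t v) {
    int pos = -1;
    while (v) {
        ++pos;
        v >>= 1;
    }
    return pos;
}

// The condition exactly as written in the problem statement.
bool ok_direct(std::uint32_t ax, std::uint32_t ay) {
    std::uint32_t x = ax ^ ay;
    return ax <= x && x <= ay;
}

// The MSB-based criterion; a_x == 0 is handled as a special case.
bool ok_by_msb(std::uint32_t ax, std::uint32_t ay) {
    if (ax == 0) return true;
    int k = highest_bit(ax);
    return ((ay >> k) & 1u) != 0 && highest_bit(ay) > k;
}

// Returns true if both predicates agree on every pair of values below `limit`.
bool criteria_agree(std::uint32_t limit) {
    for (std::uint32_t ax = 0; ax < limit; ++ax)
        for (std::uint32_t ay = 0; ay < limit; ++ay)
            if (ok_direct(ax, ay) != ok_by_msb(ax, ay)) return false;
    return true;
}
```

Calling `criteria_agree(256)` compares the two predicates on all 65,536 pairs of 8-bit values.
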

Implementation
Hint

What can we iterate over in each query?

Details
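  • Let us say that for a particular query [L,R], msb(i) is the number of values a_x with L \leq x \leq R whose MSB is i.
  • Similarly, let present(i) be the number of values a_y with L \leq y \leq R whose i^{th} bit is set but whose MSB is not i (so their MSB is higher than i).
  • For the non-zero values of a_x, the count is S = \sum\limits_{i=0}^{29}{msb(i) \cdot present(i)}. Each ordered pair (x, y) is counted exactly once, under i = MSB_{a_x}. Every a_x = 0 additionally pairs with all R - L + 1 choices of y, so those pairs have to be added separately, as both solutions below do.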

  • In order to find msb(i) and present(i) in a particular range [L,R], use prefix sums on each bit. Think about the details on your own.
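
The details above could be sketched roughly as follows. This is one possible arrangement under stated assumptions, not the setter's or tester's code: the names `BitPrefix`, `query` and `brute` are hypothetical, values are assumed to be below 2^{30}, queries are 1-indexed and inclusive, and the zero term is added separately. Note that it stores about 60 integers per array element, which may need trimming under a tight memory limit.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical container for the per-bit prefix counts.
struct BitPrefix {
    static constexpr int BITS = 30;
    std::vector<std::array<int, BITS>> msb;      // msb[i][k]: among the first i values, how many have MSB == k
    std::vector<std::array<int, BITS>> present;  // present[i][k]: how many have bit k set and MSB > k
    std::vector<int> zeros;                      // zeros[i]: how many of the first i values are 0

    explicit BitPrefix(const std::vector<std::uint32_t>& a)
        : msb(a.size() + 1), present(a.size() + 1), zeros(a.size() + 1, 0) {
        for (std::size_t i = 0; i < a.size(); ++i) {
            msb[i + 1] = msb[i];
            present[i + 1] = present[i];
            zeros[i + 1] = zeros[i] + (a[i] == 0 ? 1 : 0);
            int top = -1;  // position of the MSB, -1 if a[i] == 0
            for (int k = 0; k < BITS; ++k)
                if ((a[i] >> k) & 1u) top = k;
            if (top >= 0) ++msb[i + 1][top];
            for (int k = 0; k < top; ++k)
                if ((a[i] >> k) & 1u) ++present[i + 1][k];
        }
    }

    // Number of ordered pairs (X, Y) in [l, r] satisfying the condition.
    long long query(int l, int r) const {
        assert(1 <= l && l <= r && r < static_cast<int>(zeros.size()));
        long long ans = 1LL * (zeros[r] - zeros[l - 1]) * (r - l + 1);
        for (int k = 0; k < BITS; ++k)
            ans += 1LL * (msb[r][k] - msb[l - 1][k]) * (present[r][k] - present[l - 1][k]);
        return ans;
    }
};

// Quadratic reference that applies the condition from the statement directly.
long long brute(const std::vector<std::uint32_t>& a, int l, int r) {
    long long cnt = 0;
    for (int x = l - 1; x < r; ++x)
        for (int y = l - 1; y < r; ++y) {
            std::uint32_t v = a[x] ^ a[y];
            if (a[x] <= v && v <= a[y]) ++cnt;
        }
    return cnt;
}
```

For the array {0, 1, 3}, `query(1, 3)` gives 4: the zero pairs with all three positions, and (1, 3) is the only valid pair with a non-zero first element. Comparing `query` against `brute` on small arrays is a cheap way to catch mistakes such as counting unordered instead of ordered pairs.
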

SOLUTIONS:

Setter’s Code
#include<bits/stdc++.h>
using namespace std;
 
const int N = 1e6 + 9;
int on[30][N], msb[30][N], z[N];
int32_t main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t; cin >> t;
    assert(1 <= t && t <= 100000); // a chained comparison (1 <= t <= 100000) would always be true in C++
    int n_sum = 0, q_sum = 0;
    while (t--) {
        int n, q; cin >> n >> q;
        assert(1 <= n && n <= 1000000);
        assert(1 <= q && q <= 1000000);
        n_sum += n; q_sum += q;
        for (int i = 1; i <= n; i++) {
            int x; cin >> x;
            assert(x >= 0 && x < (1 << 30));
            z[i] = z[i - 1] + (x == 0);
            int cur_msb = -1;
            for (int k = 0; k < 30; k++) {
                int b = x >> k & 1;
                on[k][i] = on[k][i - 1] + b;
                if (b) cur_msb = k;
            }
            for (int k = 0; k < 30; k++) {
                msb[k][i] = msb[k][i - 1] + (cur_msb == k);
            }
        }
        while (q--) {
            int l, r; cin >> l >> r;
            assert(1 <= l && l <= r && r <= n);
            long long ans = 1LL * (z[r] - z[l - 1]) * (r - l + 1);
            for (int k = 0; k < 29; k++) {
                int a = msb[k][r] - msb[k][l - 1];
                int b = on[k][r] - on[k][l - 1];
                ans += 1LL * a * (b - a);
            }
            cout << ans << '\n';
        }
    }
    assert(n_sum <= 1000000);
    assert(q_sum <= 1000000);
    return 0;
} 
Tester’s Code
#include <cmath>
#include <functional>
#include <fstream>
#include <iostream>
#include <vector>
#include <algorithm>
#include <string>
#include <set>
#include <map>
#include <list>
#include <time.h>
#include <math.h>
#include <random>
#include <deque>
#include <queue>
#include <cassert>
#include <unordered_map>
#include <unordered_set>
#include <iomanip>
#include <bitset>
#include <sstream>
#include <chrono>
#include <cstring>
 
using namespace std;
 
typedef long long ll;
 
#ifdef iq
  mt19937 rnd(228);
#else
  mt19937 rnd(chrono::high_resolution_clock::now().time_since_epoch().count());
#endif
 
const int N = 5e5 + 10;
 
int l[31][N], r[31][N];
 
int main() {
#ifdef iq
  freopen("a.in", "r", stdin);
#endif
  ios::sync_with_stdio(0);
  cin.tie(0);
  int t;
  cin >> t;
  while (t--) {
    int n, q;
    cin >> n >> q;
    vector <int> a(n);
    vector <int> b(n);
    vector <int> p(n + 1);
    for (int i = 0; i < n; i++) {
      cin >> a[i];
      if (a[i] == 0) b[i] = 30;
      else while (b[i] + 1 < 30 && a[i] >= (1 << (b[i] + 1))) b[i]++;
      p[i + 1] = p[i] + (a[i] == 0);
    }
    for (int z = 0; z <= 30; z++) {
      for (int i = 0; i < n; i++) {
        int p = (b[i] > z && ((a[i] >> z) & 1));
        int q = (b[i] == z);
        l[z][i + 1] = l[z][i] + p;
        r[z][i + 1] = r[z][i] + q;
      }
    }
    while (q--) {
      int vl, vr;
      cin >> vl >> vr;
      vl--, vr--;
      int ret = p[vr + 1] - p[vl];
      ll ans = ret * (ll) (vr - vl + 1);
      for (int z = 0; z <= 30; z++) {
        ans += (l[z][vr + 1] - l[z][vl]) * (ll) (r[z][vr + 1] - r[z][vl]);
      }
      cout << ans << '\n';
    }
  }
}
In the “Final solution idea” tab there is a typo: it says “MSB(a[y]) > MSB(a[y])”, but it should be “MSB(a[y]) > MSB(a[x])”. @rajarshi_basu

fixed
Thanks for pointing it out.

In the brief solution there should be a_x; please correct it.

silly me, thanks again.

Could you explain what MSB is?

MSB is the most significant bit in the binary representation of a number.

MSB(i) actually stores the position of the most significant bit in the binary representation of that number.

Hello everyone !

I’m unable to think of the cases where my code goes wrong. Please have a look at it. I have done everything the same way as the editorial and handled the zero cases separately.

        #include<bits/stdc++.h>
        #define ll long long
        #define FIO ios_base::sync_with_stdio(false);cin.tie(NULL);

        using namespace std;

        int main()
        {
        	FIO
            ll t; 
            cin>>t;
        	while(t--){
        		ll n, q;
                cin >> n >> q;
        		ll ar[n];
        		ll present[n + 1][31] = {};
        		ll msb[n + 1][31] = {}; // stores base on 1 indexing
        		ll c0[n + 1] = {}; // stores the count of 0
        		ll flg;
        		for(ll i=0;i<n;i++) {
        			flg = 1;
        			cin >> ar[i];
        			if (ar[i] == 0) c0[i + 1]++;
        			//
        			for (ll j = 30; j >= 0; j--) {
        				if (ar[i] & (1LL << j)) {
        					present[i + 1][j] = 1;
        					if (flg) {
        						msb[i + 1][j]++; flg = 0;
        					}
        				}
        			}
        		}
        		for (ll i = 1; i <= n; i++) {
        			c0[i] += c0[i - 1];
        			for (ll j = 0; j < 31; j++) {
        				msb[i][j] += msb[i - 1][j];
        				present[i][j] += present[i - 1][j];
        			}
        		}
        		while (q--) {
        			ll l, r; cin >> l >> r; --l;
        			ll len = r - l;
        			ll cc0 = c0[r] - c0[l];
        			ll ans = cc0 + ((len - cc0) * cc0) + (cc0 * (cc0 - 1)) / 2;
        			for (ll i = 0; i <= 30; i++) {
        				ans += ((msb[r][i] - msb[l][i]) * (present[r][i] - present[l][i] - (msb[r][i] - msb[l][i])));
        			}
        			cout << ans << endl;
        		}
        	}
        	return 0;
        }

Thanks for the reply. Does that mean the leftmost set bit in the number? For example:
0011001101

Is the leftmost 1 the MSB of this number?

Yup, the leftmost set bit of a number.

Replaced it with:

  ll ans = ((len - cc0) * cc0) + cc0*cc0;

and it gives AC now. Try to think why.

Yeah

Ohk I got it now. The pairs are ordered.

Thanks man. Really appreciate your effort.

Wrong
Right

In above two solutions, the difference is in line 73 and 101.

In ‘Wrong’, I haven’t counted the MSB itself in present, and hence in the final answer I simply computed chote × bade (the “small” count times the “big” count).

present[i][j] = present[i-1][j] + (j<mx_bit && ((1<<j)&arr[i]));
ans+=(1LL*chote*(bade));

In ‘Right’, I have counted the MSB itself in present, and hence in the final answer I computed chote × (bade - chote).

 present[i][j] = present[i-1][j] + (j<=mx_bit && ((1<<j)&arr[i])); 

 ans+=(1LL*chote*(bade-chote));

Why am I receiving WA for first solution? Please give test-case if possible. Thank you!

//code by usernameharsh

//man_it_makes_me_mad
#include<bits/stdc++.h>

using namespace std;
#define ll long long int
#define co cout<<
#define ld long double
#define scany(T) scanf("%lld\n",&T)
#define scanyy(X,Y) scanf("%lld %lld\n",&X,&Y)

int main()
{

    ll test,N,Q,Sop,l,r;
    scany(test);
    
    while(test--)
    {
        scanyy(N,Q);
        ll A[N];
          
        for(ll i=0;i<N-1;i++)
            scanf("%lld ",&A[i]);
        scanf("%lld\n",&A[N-1]);
        
        /*for(ll i=0;i<N;i++)
            co A[i]<<" ";*/
        
        
        ll rec[N][31];                  
        
        
        // storing the bits 
        for(ll i=0;i<N;i++)
        {
            ll a=A[i];

            rec[i][0]=1;
            // co A[i]<<" "<<rec[i][0]<<" ";
            
            for(ll j=1;j<=30;j++)
            {
                rec[i][j]=a%2;
                a/=2;
            }
        }
        //co "\n";
        
        //displaying no in binary form
        for(ll i=0;i<N;i++){
            for(ll j=30;j>=0;j--)
                co rec[i][j]<<" ";
            co "\n";
        }
        
        ll msb[N]={0};            //array storing the posn of msb in each no
        for(ll i=0;i<N;i++)
        {
            ll m=30;
            while(m>=0){
                if(rec[i][m]==1){
                    msb[i]=m;
                    break;
                }
                m--;
            }
        }
        
        for(ll i=0;i<N;i++)
            co msb[i]<<" ";
        co "\n";
        
        while(Q--)
        {
            Sop=0;
            scanyy(l,r);
            
            for(ll i=l-1;i<r;i++){
               for(ll j=l-1;j<r;j++){
                   if(msb[j]>msb[i]&&rec[j][msb[i]]==1){
                        Sop++;
                        co msb[i]<<" "<<A[i]<<" "<<A[j]<<"\n";
                   }
               }
            }
            co Sop<<"\n";
        }   
    }
return 0;

}

Hi! Can anyone tell me where the bug in my code is?
It passes the 1st sample input,
but on the 2nd sample input it gives 22, i.e., one less than the required output 23.
I don’t know where I am going wrong.

Kindly help!
rec[N][31] stores each bit (0 or 1) of every number.
msb[N] stores the index of the MSB of each number.

I have set the 1st index to 1 in every number,
so that if the number is 0, its MSB is taken to be at the 1st index.

In the setter’s solution, how are z[0] and on[k][0] being initialized?
When i = 1, z[i-1] is used for calculating z[i]…