OR_XOR - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: shubham_grg
Tester: tabr
Editorialist: iceknight1093

DIFFICULTY:

2972

PREREQUISITES:

Familiarity with bitwise operations, binary search

PROBLEM:

Given an array A of length N, find the number of its subarrays whose bitwise OR is strictly greater than bitwise XOR.

EXPLANATION:

For convenience, let \text{OR}(L, R) denote the bitwise OR of the subarray [A_L, A_{L+1}, \ldots, A_R], and \text{XOR}(L, R) denote its bitwise XOR.

For any subarray [L, R], it will always hold that \text{OR}(L, R) \geq \text{XOR}(L, R).
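This is because any bit that appears in the subarray will be present in the OR, but may or may not be present in the XOR. Any bit that doesn’t appear in the subarray at all will be present in neither value. So every bit set in \text{XOR}(L, R) is also set in \text{OR}(L, R), which means the XOR is a submask of the OR and hence numerically no larger.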
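A quick brute-force check of this claim may help. The minimal Python sketch below uses hypothetical helper names, chosen only for illustration. For every subarray it verifies that the XOR has no set bit outside the OR, which is the submask property that implies OR ≥ XOR.

```python
from functools import reduce
from operator import or_, xor


def or_xor(sub):
    """Return (bitwise OR, bitwise XOR) of a non-empty list of non-negative ints."""
    return reduce(or_, sub), reduce(xor, sub)


def xor_bits_within_or(a):
    """Check, over all subarrays of a, that every set bit of the XOR is also set in the OR."""
    for l in range(len(a)):
        for r in range(l, len(a)):
            o, x = or_xor(a[l:r + 1])
            if x & ~o:            # a bit set in the XOR but not in the OR would break the claim
                return False
    return True


print(or_xor([5, 3]))                              # (7, 6)
print(xor_bits_within_or([5, 3, 6, 1, 7, 0, 2]))   # True
```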

So, rather than counting the subarrays for which \text{OR}(L, R) \gt \text{XOR}(L, R) directly, we can count the subarrays for which \text{OR}(L, R) = \text{XOR}(L, R), and subtract this from the total number of subarrays, \frac{N(N+1)}{2}.
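For testing a faster solution, an O(N^2) reference counter that follows exactly this complement idea can be handy. It is a hypothetical helper written for illustration, far too slow for the real constraints.

```python
def count_or_gt_xor_bruteforce(a):
    """O(N^2) reference: count subarrays with OR > XOR as total minus those with OR == XOR."""
    n = len(a)
    equal = 0
    for l in range(n):
        o = x = 0
        for r in range(l, n):             # grow the subarray [l, r] one element at a time
            o |= a[r]
            x ^= a[r]
            if o == x:
                equal += 1
    return n * (n + 1) // 2 - equal


print(count_or_gt_xor_bruteforce([1, 2, 3]))   # 2: the subarrays [2, 3] and [1, 2, 3]
```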

To facilitate this, we require one more observation:
Suppose we fix the right endpoint R of the subarray, and consider the set of all bitwise ORs ending at R, i.e., the set end(R) = \{\text{OR}(L, R) \mid 1 \leq L \leq R\}.
Then, since every element satisfies 0 \leq A_i \lt 2^{20}, end(R) contains at most 21 elements.

Proof

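Extending a subarray to the left can only add bits to its OR, so \text{OR}(L, R) is a submask of \text{OR}(L-1, R); in particular, \text{OR}(L, R) \leq \text{OR}(L-1, R).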
If \text{OR}(L, R) \lt \text{OR}(L-1, R), that means \text{OR}(L-1, R) contains at least one ‘new’ bit that wasn’t set in \text{OR}(L, R), while still containing all the bits already set in \text{OR}(L, R).

Since we’re dealing with 20-bit numbers, this addition of a new bit can happen at most 20 times before there are no more bits to add, and hence there are at most 21 distinct values in the set.
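The bound can also be checked empirically. The minimal Python sketch below (the function name is hypothetical) maintains end(R) incrementally as a set, using the fact that every OR ending at R is either A_R alone or an OR ending at R-1 combined with A_R. The example array is chosen, as an illustrative assumption, to hit the worst case: descending powers of two followed by a 0.

```python
def max_end_size(a):
    """Return the largest |end(R)| over all right endpoints R of the array a."""
    best, ors = 0, set()
    for v in a:
        # ORs ending here: v on its own, or any OR ending one step earlier extended by v
        ors = {y | v for y in ors} | {v}
        best = max(best, len(ors))
    return best


worst = [1 << k for k in range(19, -1, -1)] + [0]   # 2^19, 2^18, ..., 1, 0
print(max_end_size(worst))   # 21
```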

Notice that the above proof in fact told us something a bit more powerful: for each x \in end(R), there’s a range of indices [a_x, b_x] such that \text{OR}(L, R) = x if and only if a_x \leq L \leq b_x.
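To see these ranges concretely, the brute-force sketch below (a hypothetical helper using 0-indexed positions, O(N) per right endpoint) walks L leftwards from R. Since the running OR only ever gains bits, a value that has been left behind never reappears, so each value occupies a single contiguous block of left endpoints.

```python
def or_ranges_bruteforce(a, r):
    """For a fixed 0-indexed right end r, map each value x = OR(l, r) to its (a_x, b_x)."""
    ranges, cur = {}, 0
    for l in range(r, -1, -1):          # move the left endpoint one step left at a time
        cur |= a[l]                     # the OR can only grow
        if cur in ranges:
            ranges[cur] = (l, ranges[cur][1])   # same value: extend its block leftwards
        else:
            ranges[cur] = (l, l)                # a new, strictly larger value starts a block
    return ranges


print(or_ranges_bruteforce([4, 1, 2, 1], 3))   # {1: (3, 3), 3: (1, 2), 7: (0, 0)}
```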

Actually finding these ranges isn’t too hard, although there are both painful and painless ways to implement it.

Implementation details

The simplest way to implement this is probably to do something similar to what’s done in point 3 of this blog post, which describes the same idea but for GCD instead.

Let \text{mn}[i][x] denote the lowest position j such that \text{OR}(j, i) = x.
Suppose we’ve already computed the values of \text{mn}[i-1].
Then, for each x \in end(i-1), we have

\text{mn}[i][x \mid A_i] \gets \min(\text{mn}[i][x \mid A_i], \text{mn}[i-1][x])

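This is because every subarray ending at i, other than [i, i] itself, is a subarray ending at i-1 extended one step to the right, and extending a subarray whose bitwise OR is x by A_i gives a bitwise OR of x \mid A_i.
Don’t forget to set \text{mn}[i][A_i] = i as the first step, to account for the subarray [i, i].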
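The transition can be sketched in Python with a dictionary standing in for mn[i]. This is a minimal sketch under that assumption: the function name is hypothetical, and positions are 1-indexed as in the text.

```python
def next_mn(prev_mn, i, a_i):
    """Given mn[i-1] as {x: leftmost L with OR(L, i-1) = x}, return mn[i]."""
    cur = {a_i: i}                        # the subarray [i, i] on its own
    for x, left in prev_mn.items():
        y = x | a_i                       # extend a subarray ending at i-1 by A_i
        cur[y] = min(cur.get(y, left), left)
    return cur


mn = {}
for i, a_i in enumerate([4, 1, 2, 1], start=1):
    mn = next_mn(mn, i, a_i)
print(mn)   # {1: 4, 3: 2, 7: 1}
```

Since end(i-1) has at most 21 keys, each step does at most 21 dictionary updates.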

Now, \text{mn}[i][x] gives us the left endpoint a_x of the range corresponding to x.
To find the right endpoint b_x, find the left endpoint of the largest bitwise OR in end(i) that is smaller than x, and move one step to its left. For the smallest value in end(i), which is A_i itself, the right endpoint is simply i.
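A short Python sketch of this step, assuming mn[i] is stored as a dictionary {x: a_x} (the function name is hypothetical). Sorting the OR values in increasing order visits the blocks from right to left, because a smaller OR corresponds to larger left endpoints.

```python
def ranges_from_mn(mn, i):
    """Turn mn[i] ({x: a_x}) into {x: (a_x, b_x)} for subarrays ending at i."""
    ranges, right = {}, i                 # the smallest OR, A_i, ends its block at L = i
    for x in sorted(mn):                  # increasing OR value, decreasing left endpoints
        ranges[x] = (mn[x], right)
        right = mn[x] - 1                 # the next larger OR ends just left of a_x
    return ranges


print(ranges_from_mn({1: 4, 3: 2, 7: 1}, 4))   # {1: (4, 4), 3: (2, 3), 7: (1, 1)}
```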

If a map is used to maintain the \text{mn}[i][x] values, the complexity of this is \mathcal{O}(21N\log N), which is good enough.

There are other ways to implement this too, for example:

  • Use binary search along with a structure that allows for range OR queries, such as a segment tree or sparse table
  • Maintain some bitwise information to be able to quickly ‘jump’ to the next higher bitwise OR, for example by precomputing the closest element to the left with each bit set.
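As an illustration of the second option, here is a minimal Python sketch (hypothetical function name, 1-indexed positions, and it assumes every A_i fits in `bits` = 20 bits). It keeps, for each bit, the latest position so far where that bit is set. The OR of [L, R] can only change when L crosses one of these positions, so their distinct values, sorted in decreasing order, give the block boundaries directly without a map over OR values. The author's code follows a similar idea.

```python
def ranges_via_last_bit(a, bits=20):
    """Yield, for each 1-indexed R, a dict {OR(L, R): (a_x, b_x)}."""
    last = [0] * bits                     # last[b]: latest position <= R with bit b set, 0 if none
    for r, v in enumerate(a, start=1):
        for b in range(bits):
            if v >> b & 1:
                last[b] = r
        # block boundaries, right to left; r is included so that A_R = 0 still gets a block
        cuts = sorted({p for p in last if p > 0} | {r}, reverse=True)
        ranges, cur = {}, 0
        for k, p in enumerate(cuts):
            cur |= a[p - 1]               # adds every bit whose last occurrence is p
            left = cuts[k + 1] + 1 if k + 1 < len(cuts) else 1
            ranges[cur] = (left, p)
        yield ranges


print(list(ranges_via_last_bit([4, 1, 2, 1]))[-1])   # {1: (4, 4), 3: (2, 3), 7: (1, 1)}
```

Each step costs O(B log B) for B bits, independent of how many distinct OR values exist.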

Now, suppose we’ve fixed R, and we know the elements of end(R) and their corresponding ranges.
Let’s fix an element x \in end(R) and look at its range [a_x, b_x].
We want to count the indices L with a_x \leq L \leq b_x such that \text{OR}(L, R) = \text{XOR}(L, R) = x.

Note that \text{XOR}(L, R) = \text{pref}[R] \oplus \text{pref}[L-1], where \text{pref}[i] = \text{XOR}(1, i) denotes the prefix XOR array of A, with \text{pref}[0] = 0.
So, we have

\begin{align*} \text{XOR}(L, R) &= x \\ \text{pref}[R] \oplus \text{pref}[L-1] &= x \\ \text{pref}[L-1] &= x \oplus \text{pref}[R] \end{align*}
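A small worked example of this identity in Python, on an illustrative array chosen for this sketch (`xor_range` is a hypothetical helper). It fixes R = 2, where x = \text{OR}(1, 2) = 5 has the range [a_x, b_x] = [1, 1], and finds the valid L by matching \text{pref}[L-1] against x \oplus \text{pref}[R].

```python
from itertools import accumulate
from operator import xor

a = [4, 1, 2, 1]
pref = [0] + list(accumulate(a, xor))   # pref[i] = XOR(1, i) with pref[0] = 0, giving [0, 4, 5, 7, 6]


def xor_range(l, r):
    """XOR(l, r) for 1-indexed inclusive l <= r, via two prefix values."""
    return pref[r] ^ pref[l - 1]


R, x, a_x, b_x = 2, 5, 1, 1             # OR(1, 2) = 4 | 1 = 5, attained only for L = 1
target = x ^ pref[R]                    # 5 ^ 5 = 0
print([L for L in range(a_x, b_x + 1) if pref[L - 1] == target])   # [1]
print(xor_range(1, 2))                  # 5, so the subarray [4, 1] has OR == XOR
```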

Since x and R are fixed, we only need the number of indices L in [a_x, b_x] that satisfy this condition.
This is easy to compute: for each prefix XOR value, keep the sorted list of indices j (including j = 0) at which it occurs, then binary search on the list corresponding to x \oplus \text{pref}[R] to count the indices in the range [a_x-1, b_x-1].
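A minimal sketch of this lookup with Python's `bisect` module (the helper names are hypothetical). Indices are appended in increasing order, so each list is already sorted and two binary searches count the matches in [lo, hi].

```python
from bisect import bisect_left, bisect_right
from collections import defaultdict


def build_positions(pref):
    """Map each prefix-XOR value to the sorted list of indices j with pref[j] == value."""
    pos = defaultdict(list)
    for j, p in enumerate(pref):          # j increases, so each list stays sorted
        pos[p].append(j)
    return pos


def count_in_range(pos, value, lo, hi):
    """Count indices j in [lo, hi] with pref[j] == value, using two binary searches."""
    lst = pos.get(value, [])
    return bisect_right(lst, hi) - bisect_left(lst, lo)


pref = [0, 4, 5, 7, 6]                    # prefix XORs of [4, 1, 2, 1], with pref[0] = 0
pos = build_positions(pref)
# R = 4, x = 3 with [a_x, b_x] = [2, 3]: look for pref[L-1] == 3 ^ pref[4] == 5 in [1, 2]
print(count_in_range(pos, 3 ^ pref[4], 2 - 1, 3 - 1))   # 1 (L = 3, the subarray [2, 1])
```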

So, we’ve solved for a single (R, x) pair in \mathcal{O}(\log N).
As noted earlier, there are at most 21\cdot N such pairs, so this is fast enough for the given constraints.

TIME COMPLEXITY

\mathcal{O}(B\cdot N\log N) per test case, where B = 21 here.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
 
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
 
typedef long long int ll;
 
using namespace __gnu_pbds;
template <typename T> using ordered_set = tree<T, null_type,less<T>, rb_tree_tag,tree_order_statistics_node_update>;
// ordered_set ->  find_by_order(x)<itr, x being 0-indexed>; order_of_key(x)<count, strictly less>
 
 
#define int                 ll
#define fast                ios::sync_with_stdio(0),cin.tie(0), cout.tie(0);
#define rep(i, m, n)        for (ll i = m; i < n; i++)
#define ppi                 pair<int, int>
#define pb                  push_back
#define endl                "\n"
#define all(v)              (v).begin(), (v).end()
#define f                   first
#define ss                  second
#define in                  insert
#define lb                  lower_bound
#define ub                  upper_bound
#define sz                  size()
#define bg                  begin()
#define pq                  priority_queue
#define vc                  vector<int>
#define vcp                 vector<ppi>
#define mp                  map<int, int> 
#define gp                  gp_hash_table<int, int, chash>
#define mem1(a)             memset(a, -1 ,sizeof(a));
#define memt(a)             memset(a, true ,sizeof(a));
#define re(a)               {cout<<a<<enl;}
// #define re(a)               return a;
#define sd                  greater<int>()
#define sdp                  greater<ppi>()
#define enl                 "\n"; return;
// #define SET(n)              cout << fixed << setprecision(n)
#define ppc                 __builtin_popcountll
#ifndef ONLINE_JUDGE
#define debug(x) cerr << #x <<" : "; _print(x); cerr << endl;
#else
#define debug(x)
#endif
 
template<typename T> istream& operator>>(istream& is,  vector<T>  &v){ for(auto& i : v) is >> i; return is;}
template<typename T> ostream& operator<<(ostream& os,  vector<T>  v){for (auto& i : v) os << i << ' '; return os;}
 
template<class T> void _print(T n){cerr<<n;}
template<class T, class V> void _print(T a[], V n){cerr<<"Array: [ "; rep(i, 0, n){_print(a[i]); cerr<<" ";} cerr<<" ] \n";}
template<class T, class V> void _print(pair<T, T> a[], V n){cerr<<"Pair Array: [ "; rep(i, 0, n){cerr<<"{";_print(a[i].f); cerr<<", "; _print(a[i].ss); cerr<<"},";cerr<<" ";} cerr<<"] \n";}
template <class T, class V> void _print(pair <T, V> p) {cerr << "{"; _print(p.f); cerr << ","; _print(p.ss); cerr << "}";}
template <class T> void _print(vector <T> v) {cerr << "[ "; for (T i : v) {_print(i); cerr << " ";} cerr << "]";}
template <class T> void _print(set <T> v) {cerr << "[ "; for (T i : v) {_print(i); cerr << " ";} cerr << "]";}
template <class T, class V> void _print(map <T, V> v) {cerr << "[ "; for (auto i : v) {_print(i); cerr << " ";} cerr << "]";}
const double eps=1e-6;
const int MOD=1e9+7, inf=INT_MAX, inff=INT_MIN;
//998244353
const int N=(1e5)+5;
const int RANDOM = chrono::high_resolution_clock::now().time_since_epoch().count();
struct chash { // To use most bits rather than just the lowest ones:
    int MUL=1e9+3;
    int operator()(int x) const { return std::hash<ll>{}((x ^ RANDOM) % MOD * MUL); }
};
ll expo1(ll a, ll b)  {ll res = 1; while (b > 0) {    if (b & 1)res = (res * a);     a = (a * a);     b = b >> 1;} return res;}
ll expo(ll a, ll b, ll MOD=1e9+7)   {ll res = 1; a%=MOD; while (b > 0) {if (b & 1)res = (res * a) % MOD; a = (a * a) % MOD; b = b >> 1;} return res;}
int LOG(ll n, ll x) {int ans=-1;while(n>0){    ans++, n/=x;}return ans;}
int Ceil(ll a, ll b) {if(a%b==0 || a<0) return a/b; else return a/b+1;}
int dx[]={1, 0, -1, 0}, dy[]={0, -1, 0, 1};

int Solve(vector<int>&a)
{
    int n=a.size();
    vector<int> prefix(n); 

    map<int, vector<int>>m;
    vector<int>last(31, -1);
    int xo=0, ans=0;
    m[0].pb(-1);

    rep(i, 0, n)
    {
        for(int j=0; j<31; j++) 
        {
            if((a[i]>>j)&1) last[j]=i;
        }
        xo^=a[i];
        prefix[i]=xo;

        vector<int>t=last;
        sort(all(t), greater<int>());

        int OR=a[i], past=i;

        for(int j=0; j<31; j++)
        {
            if((j && t[j]==t[j-1]) || t[j]==i) continue;
            int k=t[j];
            int x=(xo^OR);
            auto it=lb(all(m[x]), min(past, i-1))-lb(all(m[x]), k);
            ans+=it;
            if(k<0) break;          // no element left of k = -1; avoids reading a[-1]
            OR|=a[k];
            past=k;
        }
        m[xo].pb(i);
    }
    return n*(n-1)/2-ans;
}

signed main()
{ 
    fast
    #ifndef ONLINE_JUDGE
    freopen("Error.txt", "w", stderr);  
    #endif

   
    int T;
    cin >> T;
    int i=1;
   
    while(T--)
    {
        int n; cin>>n;
        vc v(n); cin>>v;  
        cout<<Solve(v)<<endl;  
    } 

    #ifndef ONLINE_JUDGE
    cerr<<"\ntime taken : "<<(float)clock()/CLOCKS_PER_SEC<<" secs"<<"\n";
    #endif

    return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

int main() {
    input_checker in;
    int tt = in.readInt(1, 1e5);
    in.readEoln();
    int sn = 0;
    while (tt--) {
        int n = in.readInt(1, 2e5);
        in.readEoln();
        sn += n;
        auto a = in.readInts(n, 0, (1 << 20) - 1);
        in.readEoln();
        vector<int> b(n + 1);
        for (int i = 0; i < n; i++) {
            b[i + 1] = b[i] ^ a[i];
        }
        vector<int> c(21, -1);
        vector<vector<pair<int, int>>> e(n + 1);
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < 20; j++) {
                if (a[i] & (1 << j)) {
                    c[j] = i;
                }
            }
            auto d = c;
            d.emplace_back(i);
            sort(d.rbegin(), d.rend());
            d.resize(unique(d.begin(), d.end()) - d.begin());
            int sz = (int) d.size();
            for (int j = 0; j < sz - 1; j++) {
                int t = 0;
                for (int k = 0; k < 20; k++) {
                    if (c[k] >= d[j]) {
                        t |= 1 << k;
                    }
                }
                t ^= b[i + 1];
                e[d[j + 1] + 1].emplace_back(t, 1);
                e[min(i, d[j] + 1)].emplace_back(t, -1);
            }
        }
        map<int, int> cnt;
        long long ans = 0;
        for (int i = 0; i < n + 1; i++) {
            for (auto [x, y] : e[i]) {
                cnt[x] += y;
            }
            ans += cnt[b[i]];
        }
        cout << n * 1LL * (n - 1) / 2 - ans << '\n';
    }
    assert(sn <= 2e5);
    in.readEof();
    return 0;
}
Editorialist's code (Python)
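from collections import defaultdict
from bisect import bisect_left

for _ in range(int(input())):
	n = int(input())
	a = list(map(int, input().split()))
	# xor_pos[v]: sorted indices p whose prefix XOR is v (p = -1 is the empty prefix)
	# ors[y]: leftmost start of a subarray ending at the previous index with OR y
	xor_pos, ors, pref, ans = defaultdict(list), {}, 0, 0
	xor_pos[0].append(-1)
	for i in range(n):
		x, cur_ors = a[i], defaultdict(lambda: 10 ** 9)
		pref ^= x
		cur_ors[x] = i
		# every OR ending at i is an OR ending at i-1 extended by a[i], or a[i] alone
		for y in ors: cur_ors[x | y] = min(cur_ors[x | y], ors[y])
		ors = cur_ors

		# smallest OR first: its block of starts ends at i, and each next block
		# ends just left of the previous block's leftmost start
		R = i
		for y in sorted(ors.keys()):
			# starts L in [ors[y], R] whose prefix XOR at L-1 equals pref ^ y
			ans += bisect_left(xor_pos[pref ^ y], R) - bisect_left(xor_pos[pref ^ y], ors[y] - 1)
			R = ors[y] - 1
		xor_pos[pref].append(i)
	# ans counts subarrays with OR == XOR, single elements included
	print(n*(n+1)//2 - ans)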