PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: shubham_grg
Tester: tabr
Editorialist: iceknight1093
DIFFICULTY:
2972
PREREQUISITES:
Familiarity with bitwise operations, binary search
PROBLEM:
Given an array A of length N, find the number of its subarrays whose bitwise OR is strictly greater than bitwise XOR.
EXPLANATION:
For convenience, let \text{OR}(L, R) denote the bitwise OR of the subarray [A_L, A_{L+1}, \ldots, A_R], and \text{XOR}(L, R) denote its bitwise XOR.
For any subarray [L, R], it will always hold that \text{OR}(L, R) \geq \text{XOR}(L, R).
This is because any bit that is set in some element of the subarray is set in the OR, but is set in the XOR only if it occurs an odd number of times. A bit that is set in no element of the subarray is set in neither value. In other words, \text{XOR}(L, R) is always a submask of \text{OR}(L, R).
So, instead of counting the number of subarrays for which \text{OR}(L, R) \gt \text{XOR}(L, R), we can count the number of subarrays for which \text{OR}(L, R) = \text{XOR}(L, R) and subtract this from the total number of subarrays, which is N(N+1)/2.
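The complement-counting idea above can be checked on small inputs with a quadratic brute force. The following is only a reference sketch for testing, not the intended solution; the function name is made up for illustration.

```python
def count_or_greater_than_xor_bruteforce(a):
    """O(N^2) reference: number of subarrays whose OR is strictly greater than XOR."""
    n = len(a)
    equal = 0  # number of subarrays with OR == XOR
    for left in range(n):
        cur_or = 0
        cur_xor = 0
        for right in range(left, n):
            cur_or |= a[right]
            cur_xor ^= a[right]
            # XOR is always a submask of OR, hence OR >= XOR for every subarray
            assert (cur_or & cur_xor) == cur_xor
            if cur_or == cur_xor:
                equal += 1
    total = n * (n + 1) // 2
    return total - equal


print(count_or_greater_than_xor_bruteforce([1, 2, 3]))
```

For [1, 2, 3] only the subarrays [2, 3] and [1, 2, 3] have OR strictly greater than XOR, so the count is 2.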
To facilitate this, we require one more observation:
Suppose we fix the right endpoint R of the subarray. Consider the set of all bitwise ORs of subarrays ending at R, i.e., the set end(R) = \{\text{OR}(L, R) \mid 1 \leq L \leq R\}.
Then, end(R) contains at most 21 elements.
Proof
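Extending a subarray one step to the left can only add bits to its OR, so \text{OR}(L-1, R) contains every bit of \text{OR}(L, R), and in particular \text{OR}(L, R) \leq \text{OR}(L-1, R).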
If \text{OR}(L, R) \lt \text{OR}(L-1, R), that means \text{OR}(L-1, R) contains at least one ‘new’ bit that wasn’t set in \text{OR}(L, R), while still containing all the bits already set in \text{OR}(L, R).
Since we’re dealing with 20-bit numbers, this addition of a new bit can happen at most 20 times before there are no more bits to add, and hence there are at most 21 distinct values in the set.
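The bound from the proof can be observed directly. Below is a small sketch that builds end(R) for every R from end(R-1); the generator name is illustrative, and indices are 0-based here unlike the 1-based text.

```python
def ors_ending_at_each_index(a):
    """Yield, for every right endpoint R, the set {OR(L, R) : L <= R}."""
    prev = set()
    for x in a:
        # A subarray ending here is either [x] alone or an older one extended by x.
        cur = {x} | {x | y for y in prev}
        yield cur
        prev = cur


# Each element contributes one of 20 bits in rotation, so the OR keeps growing
# as the left endpoint moves left, until all 20 bits are present.
sample = [1 << (i % 20) for i in range(100)]
print(max(len(s) for s in ors_ending_at_each_index(sample)))
```

Because each set is derived from the previous one, the total work is proportional to the number of (R, x) pairs rather than to the number of subarrays.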
Notice that the above proof in fact told us something a bit more powerful: for each x \in end(R), there’s a range of indices [a_x, b_x] such that \text{OR}(L, R) = x if and only if a_x \leq L \leq b_x.
Actually finding these ranges isn’t too hard, although there are both painful and painless ways to implement it.
Implementation details
The simplest way to implement this is probably to do something similar to what’s done in point 3 of this blogpost, which describes the same idea but for GCD instead.
Let \text{mn}[i][x] denote the lowest position j such that \text{OR}(j, i) = x.
Suppose we’ve already computed the values of \text{mn}[i-1].
Then, for each x \in end(i-1), we have
\text{mn}[i][x \mid A_i] = \min(\text{mn}[i][x \mid A_i],\ \text{mn}[i-1][x])
because a subarray ending at i with bitwise OR x \mid A_i arises from extending a subarray ending at i-1 with bitwise OR x one step to the right.
Don’t forget to set \text{mn}[i][A_i] = i as the first step.
Now, \text{mn}[i][x] gives us the left endpoint a_x of the range corresponding to x.
To find the right endpoint b_x, take the largest value y \in end(i) with y \lt x: its range starts right after the range of x ends, so b_x = \text{mn}[i][y] - 1. If x is the smallest value in end(i), then b_x = i.
If a map is used to maintain the \text{mn}[i][x] values, the complexity of this is \mathcal{O}(21N\log N), which is good enough.
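As a rough illustration of the \text{mn} recurrence and of how the right endpoints fall out of it, here is a sketch that uses a plain dictionary per right endpoint. The function name and the (x, a_x, b_x) tuple layout are assumptions made for this example, and indices are 0-based.

```python
def or_ranges_per_endpoint(a):
    """For each R, list (x, a_x, b_x) such that OR(L, R) == x iff a_x <= L <= b_x."""
    result = []
    mn = {}  # mn[x] = smallest L with OR(L, R) == x, for the previous R
    for i, v in enumerate(a):
        cur = {v: i}  # the one-element subarray [i, i]
        for x, left in mn.items():
            key = x | v
            cur[key] = min(cur.get(key, left), left)
        mn = cur
        ranges = []
        right = i
        # Larger OR values are supermasks of smaller ones, so ascending order
        # of x walks the ranges from right to left.
        for x in sorted(mn):
            ranges.append((x, mn[x], right))
            right = mn[x] - 1
        result.append(ranges)
    return result


print(or_ranges_per_endpoint([1, 2, 1]))
```

For [1, 2, 1] the last endpoint yields (1, 2, 2) and (3, 0, 1): only L = 2 gives OR 1, while L = 0 and L = 1 both give OR 3.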
There are other ways to implement this too, for example:
- Use binary search along with a structure that allows for range OR queries, such as a segment tree or sparse table
- Maintain some bitwise information to be able to quickly ‘jump’ to the next higher bitwise OR, for example by precomputing the closest element to the left with each bit set.
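The second alternative in the list above can be sketched as follows: the OR of [L, R] changes only when L crosses the latest occurrence of some bit, so those positions are the only breakpoints. This assumes every value fits in `bits` bits (20 in this problem); the function name and tuple layout are illustrative, and indices are 0-based.

```python
def or_ranges_via_last_bit(a, bits=20):
    """For each R, list (x, a_x, b_x) such that OR(L, R) == x iff a_x <= L <= b_x."""
    result = []
    last = [-1] * bits  # last[j] = latest index <= R whose value has bit j set
    for i, v in enumerate(a):
        for j in range(bits):
            if (v >> j) & 1:
                last[j] = i
        # Distinct positions strictly left of i where some bit appears for the
        # first time when walking leftwards from i.
        breakpoints = sorted({p for p in last if 0 <= p < i}, reverse=True)
        ranges = []
        cur_or, right = v, i
        for p in breakpoints:
            ranges.append((cur_or, p + 1, right))  # OR is cur_or for L in (p, right]
            cur_or |= a[p]                          # crossing p adds at least one new bit
            right = p
        ranges.append((cur_or, 0, right))
        result.append(ranges)
    return result


print(or_ranges_via_last_bit([1, 2, 1]))
```

This avoids a map entirely; the cost per right endpoint is dominated by sorting at most `bits` positions.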
Now, suppose we’ve fixed R, and we know the elements of end(R) and their corresponding ranges.
Let’s fix an element x \in end(R) and look at its range [a_x, b_x].
We want to count the number of a_x \leq L \leq b_x such that \text{OR}(L, R) = \text{XOR}(L, R) = x.
However, \text{XOR}(L, R) = \text{pref}[R] \oplus \text{pref}[L-1], where \text{pref}[i] = \text{XOR}(1, i) denotes the prefix XOR array of A.
So, we have
\text{pref}[R] \oplus \text{pref}[L-1] = x \iff \text{pref}[L-1] = x \oplus \text{pref}[R]
Since x and R are fixed, we only need to know the number of a_x \leq L \leq b_x that satisfy this condition.
But this is easy to do: keep a list of positions corresponding to each prefix XOR, then binary search on the list corresponding to x \oplus \text{pref}[R] to find the number of positions in the range [a_x-1, b_x-1].
So, we’ve solved for a single (R, x) pair in \mathcal{O}(\log N).
As noted earlier, there are at most 21\cdot N such pairs, so this is fast enough for the given constraints.
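The counting step for a single (R, x) pair might look like the sketch below. It keeps the 1-based L and R of the text, with \text{pref}[0] = 0 stored at index 0; both helper names are hypothetical.

```python
from bisect import bisect_left, bisect_right
from collections import defaultdict


def build_prefix_positions(a):
    """Return (pref, positions): pref[k] = a[0] ^ ... ^ a[k-1], and
    positions[p] = increasing list of indices k with pref[k] == p."""
    pref = [0]
    for v in a:
        pref.append(pref[-1] ^ v)
    positions = defaultdict(list)
    for k, p in enumerate(pref):
        positions[p].append(k)
    return pref, positions


def count_lefts(pref, positions, x, lo, hi, r):
    """Count 1-based L in [lo, hi] with XOR(L, r) == x, i.e. pref[L-1] == x ^ pref[r]."""
    lst = positions.get(x ^ pref[r], [])
    # L - 1 must lie in [lo - 1, hi - 1]
    return bisect_right(lst, hi - 1) - bisect_left(lst, lo - 1)


pref, positions = build_prefix_positions([1, 2, 3])
print(count_lefts(pref, positions, x=3, lo=1, hi=2, r=2))
```

In the example, only L = 1 gives XOR(L, 2) = 3, so the printed count is 1. An online version would append prefix positions as R advances instead of building all lists up front.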
TIME COMPLEXITY:
\mathcal{O}(B\cdot N\log N) per test case, where B = 21 here.
CODE:
Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
typedef long long int ll;
using namespace __gnu_pbds;
template <typename T> using ordered_set = tree<T, null_type,less<T>, rb_tree_tag,tree_order_statistics_node_update>;
// ordered_set -> find_by_order(x)<itr, x being 0-indexed>; order_of_key(x)<count, strictly less>
#define int ll
#define fast ios::sync_with_stdio(0),cin.tie(0), cout.tie(0);
#define rep(i, m, n) for (ll i = m; i < n; i++)
#define ppi pair<int, int>
#define pb push_back
#define endl "\n"
#define all(v) (v).begin(), (v).end()
#define f first
#define ss second
#define in insert
#define lb lower_bound
#define ub upper_bound
#define sz size()
#define bg begin()
#define pq priority_queue
#define vc vector<int>
#define vcp vector<ppi>
#define mp map<int, int>
#define gp gp_hash_table<int, int, chash>
#define mem1(a) memset(a, -1 ,sizeof(a));
#define memt(a) memset(a, true ,sizeof(a));
#define re(a) {cout<<a<<enl;}
// #define re(a) return a;
#define sd greater<int>()
#define sdp greater<ppi>()
#define enl "\n"; return;
// #define SET(n) cout << fixed << setprecision(n)
#define ppc __builtin_popcountll
#ifndef ONLINE_JUDGE
#define debug(x) cerr << #x <<" : "; _print(x); cerr << endl;
#else
#define debug(x)
#endif
template<typename T> istream& operator>>(istream& is, vector<T> &v){ for(auto& i : v) is >> i; return is;}
template<typename T> ostream& operator<<(ostream& os, vector<T> v){for (auto& i : v) os << i << ' '; return os;}
template<class T> void _print(T n){cerr<<n;}
template<class T, class V> void _print(T a[], V n){cerr<<"Array: [ "; rep(i, 0, n){_print(a[i]); cerr<<" ";} cerr<<" ] \n";}
template<class T, class V> void _print(pair<T, T> a[], V n){cerr<<"Pair Array: [ "; rep(i, 0, n){cerr<<"{";_print(a[i].f); cerr<<", "; _print(a[i].ss); cerr<<"},";cerr<<" ";} cerr<<"] \n";}
template <class T, class V> void _print(pair <T, V> p) {cerr << "{"; _print(p.f); cerr << ","; _print(p.ss); cerr << "}";}
template <class T> void _print(vector <T> v) {cerr << "[ "; for (T i : v) {_print(i); cerr << " ";} cerr << "]";}
template <class T> void _print(set <T> v) {cerr << "[ "; for (T i : v) {_print(i); cerr << " ";} cerr << "]";}
template <class T, class V> void _print(map <T, V> v) {cerr << "[ "; for (auto i : v) {_print(i); cerr << " ";} cerr << "]";}
const double eps=1e-6;
const int MOD=1e9+7, inf=INT_MAX, inff=INT_MIN;
//998244353
const int N=(1e5)+5;
const int RANDOM = chrono::high_resolution_clock::now().time_since_epoch().count();
struct chash { // To use most bits rather than just the lowest ones:
int MUL=1e9+3;
int operator()(int x) const { return std::hash<ll>{}((x ^ RANDOM) % MOD * MUL); }
};
ll expo1(ll a, ll b) {ll res = 1; while (b > 0) { if (b & 1)res = (res * a); a = (a * a); b = b >> 1;} return res;}
ll expo(ll a, ll b, ll MOD=1e9+7) {ll res = 1; a%=MOD; while (b > 0) {if (b & 1)res = (res * a) % MOD; a = (a * a) % MOD; b = b >> 1;} return res;}
int LOG(ll n, ll x) {int ans=-1;while(n>0){ ans++, n/=x;}return ans;}
int Ceil(ll a, ll b) {if(a%b==0 || a<0) return a/b; else return a/b+1;}
int dx[]={1, 0, -1, 0}, dy[]={0, -1, 0, 1};
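int Solve(vector<int>&a)
{
    int n=a.size();
    vector<int> prefix(n);
    map<int, vector<int>>m;   // prefix XOR value -> increasing list of positions
    vector<int>last(31, -1);  // last[j] = latest index whose value has bit j set
    int xo=0, ans=0;
    m[0].pb(-1);              // the empty prefix sits at position -1
    rep(i, 0, n)
    {
        for(int j=0; j<31; j++)
        {
            if((a[i]>>j)&1) last[j]=i;
        }
        xo^=a[i];
        prefix[i]=xo;
        // The OR of [L, i] changes only when L crosses one of the positions in last.
        vector<int>t=last;
        sort(all(t), greater<int>());
        int OR=a[i], past=i;
        for(int j=0; j<31; j++)
        {
            if((j && t[j]==t[j-1]) || t[j]==i) continue;
            int k=t[j];
            int x=(xo^OR);
            // Count prefix positions p in [k, min(past, i-1)) whose prefix XOR is x,
            // i.e. left endpoints L = p+1 of subarrays of length >= 2 with OR == XOR.
            auto it=lb(all(m[x]), min(past, i-1))-lb(all(m[x]), k);
            ans+=it;
            if(k>=0) OR|=a[k];  // k == -1 marks bits that never occur; do not read a[-1]
            past=k;
        }
        m[xo].pb(i);
    }
    // Single-element subarrays always have OR == XOR, so only the n*(n-1)/2
    // subarrays of length >= 2 are considered here.
    return n*(n-1)/2-ans;
}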
signed main()
{
fast
#ifndef ONLINE_JUDGE
freopen("Error.txt", "w", stderr);
#endif
int T;
cin >> T;
int i=1;
while(T--)
{
int n; cin>>n;
vc v(n); cin>>v;
cout<<Solve(v)<<endl;
}
#ifndef ONLINE_JUDGE
cerr<<"\ntime taken : "<<(float)clock()/CLOCKS_PER_SEC<<" secs"<<"\n";
#endif
return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif
struct input_checker {
string buffer;
int pos;
const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
const string number = "0123456789";
const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
const string lower = "abcdefghijklmnopqrstuvwxyz";
input_checker() {
pos = 0;
while (true) {
int c = cin.get();
if (c == -1) {
break;
}
buffer.push_back((char) c);
}
}
string readOne() {
assert(pos < (int) buffer.size());
string res;
while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
res += buffer[pos];
pos++;
}
return res;
}
string readString(int min_len, int max_len, const string& pattern = "") {
assert(min_len <= max_len);
string res = readOne();
assert(min_len <= (int) res.size());
assert((int) res.size() <= max_len);
for (int i = 0; i < (int) res.size(); i++) {
assert(pattern.empty() || pattern.find(res[i]) != string::npos);
}
return res;
}
int readInt(int min_val, int max_val) {
assert(min_val <= max_val);
int res = stoi(readOne());
assert(min_val <= res);
assert(res <= max_val);
return res;
}
long long readLong(long long min_val, long long max_val) {
assert(min_val <= max_val);
long long res = stoll(readOne());
assert(min_val <= res);
assert(res <= max_val);
return res;
}
vector<int> readInts(int size, int min_val, int max_val) {
assert(min_val <= max_val);
vector<int> res(size);
for (int i = 0; i < size; i++) {
res[i] = readInt(min_val, max_val);
if (i != size - 1) {
readSpace();
}
}
return res;
}
vector<long long> readLongs(int size, long long min_val, long long max_val) {
assert(min_val <= max_val);
vector<long long> res(size);
for (int i = 0; i < size; i++) {
res[i] = readLong(min_val, max_val);
if (i != size - 1) {
readSpace();
}
}
return res;
}
void readSpace() {
assert((int) buffer.size() > pos);
assert(buffer[pos] == ' ');
pos++;
}
void readEoln() {
assert((int) buffer.size() > pos);
assert(buffer[pos] == '\n');
pos++;
}
void readEof() {
assert((int) buffer.size() == pos);
}
};
int main() {
    input_checker in;
    int tt = in.readInt(1, 1e5);
    in.readEoln();
    int sn = 0;
    while (tt--) {
        int n = in.readInt(1, 2e5);
        in.readEoln();
        sn += n;
        auto a = in.readInts(n, 0, (1 << 20) - 1);
        in.readEoln();
        vector<int> b(n + 1);
        for (int i = 0; i < n; i++) {
            b[i + 1] = b[i] ^ a[i];
        }
        vector<int> c(21, -1);
        vector<vector<pair<int, int>>> e(n + 1);
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < 20; j++) {
                if (a[i] & (1 << j)) {
                    c[j] = i;
                }
            }
            auto d = c;
            d.emplace_back(i);
            sort(d.rbegin(), d.rend());
            d.resize(unique(d.begin(), d.end()) - d.begin());
            int sz = (int) d.size();
            for (int j = 0; j < sz - 1; j++) {
                int t = 0;
                for (int k = 0; k < 20; k++) {
                    if (c[k] >= d[j]) {
                        t |= 1 << k;
                    }
                }
                t ^= b[i + 1];
                e[d[j + 1] + 1].emplace_back(t, 1);
                e[min(i, d[j] + 1)].emplace_back(t, -1);
            }
        }
        map<int, int> cnt;
        long long ans = 0;
        for (int i = 0; i < n + 1; i++) {
            for (auto [x, y] : e[i]) {
                cnt[x] += y;
            }
            ans += cnt[b[i]];
        }
        cout << n * 1LL * (n - 1) / 2 - ans << '\n';
    }
    assert(sn <= 2e5);
    in.readEof();
    return 0;
}
Editorialist's code (Python)
from collections import defaultdict
from bisect import bisect_left
for _ in range(int(input())):
    n = int(input())
    a = list(map(int, input().split()))
    xor_pos, ors, pref, ans = defaultdict(lambda: []), {}, 0, 0
    xor_pos[0].append(-1)
    for i in range(n):
        x, cur_ors = a[i], defaultdict(lambda: 10 ** 9)
        pref ^= x
        cur_ors[x] = i
        for y in ors: cur_ors[x | y] = min(cur_ors[x | y], ors[y])
        ors = cur_ors
        R = i
        for y in sorted(ors.keys()):
            ans += bisect_left(xor_pos[pref ^ y], R) - bisect_left(xor_pos[pref ^ y], ors[y] - 1)
            R = ors[y] - 1
        xor_pos[pref].append(i)
    print(n*(n+1)//2 - ans)