ADVITIYA6 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: ladchat
Tester: apoorv_me
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Familiarity with bitwise operations, prefix sums

PROBLEM:

An array is called magical if its bitwise AND equals its bitwise XOR.
Given an array A, find the number of its subarrays that are magical.
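A naive quadratic count is handy for checking faster solutions. The sketch below is our own (the name `count_magical_brute` is hypothetical, and arrays are 0-indexed Python lists). It extends each left endpoint rightwards while maintaining the running AND and XOR. Every single-element subarray is magical, since its AND and XOR both equal the element.

```python
def count_magical_brute(a):
    """Count subarrays whose bitwise AND equals their bitwise XOR, in O(N^2)."""
    n = len(a)
    total = 0
    for l in range(n):
        and_val = a[l]  # AND of a[l..r], updated as r grows
        xor_val = 0     # XOR of a[l..r], updated as r grows
        for r in range(l, n):
            and_val &= a[r]
            xor_val ^= a[r]
            if and_val == xor_val:
                total += 1
    return total
```

For example, `[1, 2, 3]` has 4 magical subarrays: the three singletons and the whole array, where 1 & 2 & 3 = 1 ^ 2 ^ 3 = 0.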

EXPLANATION:

The main observation here is that although there are \mathcal{O}(N^2) subarrays, the number of distinct subarray ANDs is small.
In particular, we have the following (rather well-known) result.
Claim: Let A be an array of length N consisting of integers that are \lt 2^B. Then, there are \mathcal{O}(N\cdot B) distinct subarray ANDs in A.
In particular, there are at most B+1 distinct subarray ANDs when considering all subarrays that end at a fixed position, and each such one occurs for a contiguous range of starting positions.

Proof

Suppose the right endpoint of a subarray is fixed, say to R.
Let f(L, R) = A_L \& A_{L+1}\&\ldots\& A_R.

Observe that f(L, R) = A_L \& f(L+1, R), so f(L, R) either equals f(L+1, R), or is a strict submask of it.
That is, when starting with L = R and moving L leftwards, the bitwise AND either remains the same or becomes a strict submask of the current mask.
Every move to a strict submask clears at least one set bit, and a B-bit number has at most B set bits, so this can happen at most B times.
This bounds the number of times the bitwise AND changes by B, leading to our bound of B+1 values.
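It's also easy to see that, since the subarray ANDs are monotonic (each value is a submask of the previous one as L decreases, so a value never reappears once it changes), each distinct AND occurs for a contiguous range of L.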
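The claim can be checked empirically on small arrays. The sketch below (helper names `and_values_by_left` and `max_distinct_ands` are hypothetical, 0-indexed) computes f(L, r) for L = r, r-1, ..., 0, asserts that each step moves to a submask, asserts that every distinct value forms one contiguous run, and returns the largest number of distinct ANDs seen for any right endpoint.

```python
def and_values_by_left(a, r):
    """Return [f(L, r) for L = r, r-1, ..., 0], computed right to left."""
    vals = []
    cur = -1  # -1 has every bit set, so -1 & x == x for non-negative x
    for l in range(r, -1, -1):
        cur &= a[l]
        vals.append(cur)
    return vals


def max_distinct_ands(a, bits):
    """Largest number of distinct subarray ANDs ending at a single position."""
    best = 0
    for r in range(len(a)):
        vals = and_values_by_left(a, r)
        for prev, nxt in zip(vals, vals[1:]):
            assert nxt & prev == nxt  # moving left gives a submask
        runs = 1 + sum(1 for p, q in zip(vals, vals[1:]) if p != q)
        assert runs == len(set(vals))  # each value occupies one contiguous block
        best = max(best, len(set(vals)))
    assert best <= bits + 1  # the B+1 bound from the claim
    return best
```

With B = 3, the array `[0, 4, 6, 7]` attains the bound: the ANDs ending at the last index are 7, 6, 4, 0, i.e. B+1 = 4 distinct values.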


With this observation in hand, let’s solve the problem.
Fix the right endpoint R of a subarray, and find all the distinct ANDs of subarrays ending there (along with the range of L for each one).

How?

There are a few different ways to do it; here's one of them.

Let E be a sorted list of pairs of (value, index), denoting pairs of subarray ANDs ending at R and the lowest left index for this AND.
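That means the subarray AND E_{i, 0} occurs for all left endpoints from E_{i, 1} to E_{i+1, 1}-1, since E_{i+1, 1} is the leftmost endpoint of the next higher bitwise AND; the AND of the last pair occurs for all left endpoints from its index up to R.
Sorting by value also sorts by index here, because a smaller AND is always a submask reached by moving further left.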
As noted above, this list has at most B+1 elements.
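To make the interval interpretation concrete, here is a small sketch (0-indexed, with the hypothetical helper name `left_ranges`) that turns E into (value, L_1, L_2) triples, meaning the AND of [L, R] equals value for L_1 \leq L \leq L_2.

```python
def left_ranges(E, r):
    """Expand E, a list of (AND value, leftmost start) sorted by value, into
    (value, l1, l2) triples covering every left endpoint 0..r exactly once."""
    out = []
    for i, (value, left) in enumerate(E):
        # a value's range ends just before the next (higher) value starts;
        # the last value extends to the right endpoint itself
        right = E[i + 1][1] - 1 if i + 1 < len(E) else r
        out.append((value, left, right))
    return out
```

For A = [4, 7, 5] and R = 2, the ANDs are 5 (L = 2), 7 & 5 = 5 (L = 1) and 4 & 7 & 5 = 4 (L = 0), so E = [(4, 0), (5, 1)] and the ranges are (4, 0, 0) and (5, 1, 2).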

If we know this list for R, the list for endpoint R+1 is easy to obtain, as follows:

  • Create a new list of pairs E'.
    Initially, it contains only the pair (A_{R+1}, R+1), for the left endpoint R+1.
  • Now, observe that any other subarray ending at R+1 can be obtained by taking a subarray ending at R and extending it by one step.
  • So, for each pair (x, y) \in E, add the pair (x\ \&\ A_{R+1}, y) to E'.
  • Now, sort E', and for each distinct value (first element), keep only the pair with the lowest index (second element).
  • Finally, set E := E', and we have the required list for endpoint R+1.
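The update steps above can be sketched as follows (0-indexed; the name `extend` is hypothetical). Because Python sorts tuples lexicographically, the first pair seen for each value after sorting is the one with the smallest index, so keeping only the first occurrence does the deduplication.

```python
def extend(E, x, idx):
    """Given E for right endpoint idx-1, return E for right endpoint idx,
    where x = A[idx]. Pairs are (AND value, leftmost start), sorted by value."""
    # the new single-element subarray, plus every old subarray extended by x
    candidates = [(x, idx)] + [(v & x, left) for v, left in E]
    candidates.sort()
    new_E = []
    for v, left in candidates:
        # first pair with a given value has the lowest left index
        if not new_E or new_E[-1][0] != v:
            new_E.append((v, left))
    return new_E
```

Starting from an empty list and feeding A = [4, 7, 5] one element at a time gives [(4, 0)], then [(4, 0), (7, 1)], then [(4, 0), (5, 1)].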

We sort a list of size at most B+2, so this is \mathcal{O}(B\log B) for a single index.

Suppose the bitwise AND X occurs for all subarrays [L, R] such that L_1 \leq L \leq L_2.
Then, we want to count the number of L in this range such that the bitwise XOR of [L, R] also equals X.
Recall that bitwise XOR is invertible (every value is its own inverse), which lets us compute subarray XORs with the help of prefix XORs.
In particular, if we let P_i = A_1 \oplus A_2 \oplus\ldots\oplus A_i, the bitwise XOR of subarray [L, R] equals simply P_R \oplus P_{L-1}.
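A minimal sketch of the prefix-XOR identity, following the 1-indexed convention of the text by storing P_0 = 0 at index 0 (function names are hypothetical).

```python
from itertools import accumulate
from operator import xor


def prefix_xor(a):
    """P[0] = 0 and P[i] = A_1 ^ ... ^ A_i (1-indexed, as in the text)."""
    return [0] + list(accumulate(a, xor))


def subarray_xor(P, l, r):
    """XOR of A_l..A_r (1-indexed, inclusive) as P_r ^ P_{l-1}."""
    return P[r] ^ P[l - 1]
```

For A = [1, 2, 3], P = [0, 1, 3, 0], and the XOR of [2, 3] is P_3 \oplus P_1 = 0 \oplus 1 = 1 = 2 \oplus 3.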

So, what we’re looking for is the number of L (L_1 \leq L \leq L_2) such that P_R \oplus P_{L-1} = X, or in other words, P_{L-1} = P_R\oplus X.
This is easy to do in \mathcal{O}(\log N) time: for each prefix XOR value, store the sorted list of indices where it occurs, then binary search on the list corresponding to P_R\oplus X for indices in the range [L_1-1, L_2-1].
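Here is one way to write that lookup, assuming P is the 1-indexed prefix XOR array with P[0] = 0 (the names `build_positions` and `count_matches` are hypothetical). Since indices are appended in increasing order, each list is already sorted, and two `bisect` calls count the indices that fall inside [L_1-1, L_2-1].

```python
from bisect import bisect_left, bisect_right
from collections import defaultdict


def build_positions(P):
    """Map each prefix-XOR value to the sorted list of indices i with P[i] == value."""
    positions = defaultdict(list)
    for i, v in enumerate(P):
        positions[v].append(i)  # increasing i keeps each list sorted
    return positions


def count_matches(positions, P, r, x, l1, l2):
    """Count L in [l1, l2] (1-indexed) with P[r] ^ P[L-1] == x."""
    lst = positions.get(P[r] ^ x, [])  # .get avoids inserting empty lists
    return bisect_right(lst, l2 - 1) - bisect_left(lst, l1 - 1)
```

For A = [1, 2, 3] (so P = [0, 1, 3, 0]) and R = 3, the AND is 0 for L = 1, 2 for L = 2 and 3 for L = 3; the lookups return 1, 0 and 1 respectively, matching the magical subarrays [1, 3] and [3, 3].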

This binary search is done \mathcal{O}(N\cdot B) times, for an overall complexity of \mathcal{O}(B\cdot N\log N).

TIME COMPLEXITY:

\mathcal{O}(B\cdot N\log N) per testcase, where B = 30 here.

CODE:

Author's code (C++)
#include<bits/stdc++.h>
using namespace std;
#pragma GCC optimize ("O3","unroll-loops")
#pragma GCC optimize("inline","-ffast-math")
#pragma GCC target("fma,sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,avx2,tune=native")
const int N1 = 500001;
int spr[N1][21];
void bldsp(int a[],int n){
    for(int i=0;i<n;i++)spr[i][0]=a[i];
    for(int j = 1;(1<<j)<=n;j++)
        for(int i = 0;i+(1<<j)-1<n;i++)
            spr[i][j]=(spr[i][j-1])&(spr[i+(1<<(j-1))][j-1]);
}
int quespr(int l,int r){
    int lng = log2l(r-l+1);
    return ((spr[l][lng])&(spr[r-(1<<lng)+1][lng]));
}
int main(){
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int testcase=1;
cin>>testcase;
while(testcase--){
   int n;cin>>n;
   map<int,vector<int>> mp;
   int arr[n+1],prfx[n+1];arr[0]=0;prfx[0]=0;
   for(int i=1;i<=n;i++)cin>>arr[i];
   for(int i=1;i<=n;i++){
      prfx[i]=prfx[i-1]^arr[i];
      mp[prfx[i]].push_back(i);
   }
   long long fnlans=0; // up to N(N+1)/2 subarrays (about 5e9 for N = 1e5), which overflows int
   bldsp(arr,n+1);
   for(int i = 1;i<=n;i++){
      int pt=i;
      while(pt<=n){
         int crn = quespr(i,pt);
         int p1=pt,p2=n,ans=pt;
         while(p2>=p1){
            int md = (p1+p2)>>1;
            int er = quespr(i,md);
            if(er>=crn){ans=md;p1=md+1;}
            else p2=md-1;
         }
         int checkinxor = crn^prfx[i-1];
         int lastid = upper_bound(mp[checkinxor].begin(),mp[checkinxor].end(),ans)-mp[checkinxor].begin();
         int firstid = lower_bound(mp[checkinxor].begin(),mp[checkinxor].end(),pt)-mp[checkinxor].begin();
         fnlans+=lastid-firstid;
         pt=ans+1;
      }
   }
   cout<<fnlans<<endl;
   
}
  return 0;
}
Tester's code (C++)
#ifndef LOCAL
#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx,avx2,sse,sse2,sse3,sse4,popcnt,fma")
#endif

#include <bits/stdc++.h>
using namespace std;

#ifdef LOCAL
#include "../debug.h"
#else
#define dbg(...) "11-111"
#endif

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && !isspace(buffer[now])) {
            now++;
        }
        return now;
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

int32_t main() {
    ios_base::sync_with_stdio(0);   cin.tie(0);

    input_checker input;

    int T = input.readInt(1, (int)1e5); input.readEoln();
    int NN = 0;
    while(T-- > 0) {
        int N = input.readInt(1, (int)1e5); input.readEoln();
        vector<int> A = input.readInts(N, 0, (1 << 30) - 1);    input.readEoln();
        NN += N;

        vector<int> B(N), P(N + 1);
        vector<pair<int, int>> C(1, {0, 0});
        for(int i = 0 ; i < N ; ++i) {
            P[i + 1] = P[i] ^ A[i];
            C.push_back({C.back().first ^ A[i], i + 1});
        }
        sort(C.begin(), C.end());

        auto get_count = [&](int l, int r, int v) {
            return int(upper_bound(C.begin(), C.end(), make_pair(v, r))
                - lower_bound(C.begin(), C.end(), make_pair(v, l)));
        };

        int64_t res = 0;
        for(int i = 0 ; i < N ; ++i) {
            B[i] = A[i];
            for(int j = i - 1 ; j >= 0 ; --j) {
                if((B[j] & A[i]) == B[j])   break;
                B[j] &= A[i];
            }

            int in = i;
            while(in >= 0) {
                int l = lower_bound(B.begin(), B.begin() + in + 1, B[in]) - B.begin();
                res += get_count(l, in, P[i + 1] ^ B[l]);
                in = l - 1;
            }
        }
        cout << res << '\n';
    }

    assert(NN <= (int)1e5);
    input.readEof();

    return 0;
}

Editorialist's code (Python)
import bisect
for _ in range(int(input())):
    n = int(input())
    a = list(map(int, input().split()))
    prefxor = [0]*n
    for i in range(n):
        prefxor[i] = a[i]
        if i > 0: prefxor[i] ^= prefxor[i-1]
    positions = dict()
    positions[0] = [-1]
    for i in range(n):
        if prefxor[i] not in positions: positions[prefxor[i]] = []
        positions[prefxor[i]].append(i)
    sufs = []
    ans = 0
    for i, x in enumerate(a):
        nsufs = [(x, i)]
        for suf, j in sufs:
            nsufs.append((suf & x, j))
        nsufs.sort()
        sufs.clear()
        for suf, j in nsufs:
            if not sufs or sufs[-1][0] != suf: sufs.append((suf, j))
        rt = i
        for suf, j in reversed(sufs):
            want = prefxor[i] ^ suf
            if want in positions:
                how_many = bisect.bisect_left(positions[want], rt) - bisect.bisect_left(positions[want], j-1)
                ans += how_many
            rt = j-1
    print(ans)

I did what was said in the editorial, but I am getting Wrong Answer on the 5th test case.
Here is the submission link: CodeChef: Practical coding for everyone. Can anyone check what the problem is? Thank you.