XOP - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: rudra_1232
Tester: kingmessi
Editorialist: iceknight1093

DIFFICULTY:

Easy

PREREQUISITES:

PROBLEM:

For an array B, define f(B) to be 1 if all its elements can be made equal by performing the following operation any number of times, and 0 otherwise:

  • Choose a subarray [l, r] of B.
  • For each l \leq i \leq r, replace B_i by B_i \oplus (i - l + 1).

Given an array A, find \sum_{i=1}^n\sum_{j=i}^n f([A_i, A_{i+1}, \ldots, A_j]).

EXPLANATION:

Let’s first understand when all the elements of a single array can be made equal.

So, suppose we have an array B of length N.
Our move is essentially to choose some subarray and XOR its elements by 1, 2, 3, \ldots in order.
In particular, note that B_i can only be XOR-ed with some integer that’s \leq i, though this can be done multiple times.

In fact, note that it’s possible to XOR B_i with any x \leq i, without changing the rest of the array.
This arises due to XOR being its own inverse:

  • Choose l = i-x+1 and r = i, which will result in each of B_{i-x+1}, B_{i-x+2}, \ldots, B_i being XOR-ed with 1, 2, 3, \ldots, x.
  • Then, choose l = i-x+1 and r = i-1, which will result in the same thing except without affecting B_i itself.
    This will reset each of B_{i-x+1}, \ldots, B_{i-1} to their original values.
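The two-step trick above can be sketched in code. This is only an illustration: the helper names `applyOp` and `xorSingle` are made up here and are not part of any of the reference solutions, and indices are kept 1-based to match the text.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Apply one operation to the 1-indexed subarray [l, r] of b:
// b_i ^= (i - l + 1) for every l <= i <= r. An empty range (l > r) changes nothing.
void applyOp(std::vector<std::uint64_t>& b, int l, int r) {
    for (int i = l; i <= r; ++i) {
        b[i - 1] ^= static_cast<std::uint64_t>(i - l + 1);
    }
}

// XOR only b_i with x (1 <= x <= i), leaving every other element unchanged.
void xorSingle(std::vector<std::uint64_t>& b, int i, int x) {
    assert(1 <= x && x <= i && i <= static_cast<int>(b.size()));
    applyOp(b, i - x + 1, i);      // b_{i-x+1}, ..., b_i get XOR-ed with 1, 2, ..., x
    applyOp(b, i - x + 1, i - 1);  // undoes the change on b_{i-x+1}, ..., b_{i-1}
}
```

For instance, calling `xorSingle(b, 4, 3)` on `{5, 9, 12, 7, 3}` should change only the fourth element, from 7 to 7 \oplus 3 = 4.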

This now tells us exactly which values B_i can take: if h is the largest integer such that 2^h \leq i, we can freely change any of its bits \leq h (for example by choosing the appropriate power of 2), and cannot change its bits \gt h at all.

In particular, note that all the bits of B_1, other than its lowest bit, are fixed. This essentially uniquely determines the final value in the array (because the lowest bit of every number can be freely changed anyway).
Extending this to further indices,

  • The lowest two bits of B_2 and B_3 can be changed freely, but all bits \geq 2 are fixed.
    This means that all their bits that are \geq 2 must match the corresponding bit of B_1; otherwise no solution exists.
  • Similarly, B_4, B_5, B_6, B_7 all must match B_1 at bits \geq 3, and so on.
  • In general, for each i \geq 1, B_i must match B_1 at all bits \gt \left\lfloor \log_2 i \right\rfloor.
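As a rough sketch of this criterion for a single array (the function name `canEqualize` is hypothetical, and bits are numbered from 0 as in the text), one could write:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Returns 1 if all elements of b can be made equal, 0 otherwise.
// The element at 1-indexed position i must agree with the first element
// on every bit strictly above floor(log2(i)).
int canEqualize(const std::vector<std::uint64_t>& b) {
    int h = 0;  // maintained so that 2^h <= i < 2^(h+1)
    for (std::size_t i = 1; i <= b.size(); ++i) {
        while ((std::size_t{2} << h) <= i) ++h;
        // Bits 0..h of b_i are free; any difference above them is fatal.
        if (((b[i - 1] ^ b[0]) >> (h + 1)) != 0) return 0;
    }
    return 1;
}
```

For example, `{0, 3, 2, 7}` passes (each element differs from the first only in bits it can change), while `{0, 4}` fails because bit 2 of the second element cannot be altered.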

We now have a relatively easy check for a single array B. Let’s extend this to counting for all subarrays.

Let’s fix the left end L of the subarray, and try to count all valid R.
A direct check, using the criterion devised above, is as follows:

  • For each R = L, L+1, L+2, \ldots in order, check if B_L and B_R match at all bits other than the lowest \left\lfloor \log_2(R-L+1) \right\rfloor + 1 ones.
  • If they do match, [L, R] is valid so add 1 to the answer.
    Otherwise, [L, R] is invalid, and also all [L, R'] for R'\gt R will be invalid (since this index will prevent equality happening no matter what), so we can break out immediately.
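A minimal sketch of this direct check, useful mainly as a brute-force reference for testing faster solutions, might look as follows. The name `countBrute` is an assumption for illustration, and indices are 0-based here.

```cpp
#include <cstdint>
#include <vector>

// Counts subarrays [l, r] whose elements can all be made equal,
// by extending r from each l until the first failing index.
long long countBrute(const std::vector<std::uint64_t>& a) {
    const int n = static_cast<int>(a.size());
    long long ans = 0;
    for (int l = 0; l < n; ++l) {
        int h = 0;  // floor(log2(r - l + 1)), updated as r grows
        for (int r = l; r < n; ++r) {
            const int len = r - l + 1;
            while ((2 << h) <= len) ++h;
            // a[r] may differ from a[l] only in its lowest h + 1 bits.
            if (((a[l] ^ a[r]) >> (h + 1)) != 0) break;  // every larger r fails too
            ++ans;
        }
    }
    return ans;
}
```

On `{0, 0, 0, 4}` this gives 8: the full array is valid because the 4 sits at position 4 of the subarray starting at index 0, but it blocks the subarrays starting at the second and third elements.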

This algorithm, while correct, can take quadratic time - so we must optimize it.


One way of optimization is to iterate over bits rather than indices.
That is, once L is fixed, we’ll iterate over values of b = 0, 1, 2, \ldots and try to perform the check for all the elements at indices R such that \left\lfloor \log_2(R-L+1) \right\rfloor = b, simultaneously.
That is, for all R in the range [L + 2^b - 1, \min(N, L + 2^{b+1} - 2)].

Note that all of them have the same check: we want to know if their bits \gt b match the corresponding bits of B_L.
To check this, let’s look at some bit b' \gt b.

  • If b' is not set in B_L, it must then not be set in any of the B_R values in the range we’re looking at.
    This can be checked in constant time by counting the number of integers in this segment that have b' set using prefix sums built on this bit alone.
    Alternately, compute the bitwise OR of the range (using, say, a sparse table) and check the bit b' of this OR.
  • Similarly, if b' is set in B_L, it must be set in every B_R in this range, which can again be checked by looking at the count of values in this range that have it set; or verifying that the bitwise AND of the range has it set.

This check needs to be performed for each b \lt b' \lt 60, and all of them must pass.
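The prefix-sum variant of this range check could be sketched as below. The names `buildPrefix` and `rangeMatches` are illustrative assumptions, ranges are half-open and 0-indexed, and values are assumed to fit in 60 bits.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr int BITS = 60;
using BitCounts = std::array<int, BITS>;

// pref[k][c] = how many of a[0], ..., a[k-1] have bit c set.
std::vector<BitCounts> buildPrefix(const std::vector<std::uint64_t>& a) {
    std::vector<BitCounts> pref(a.size() + 1);  // value-initialized to all zeros
    for (std::size_t i = 0; i < a.size(); ++i) {
        pref[i + 1] = pref[i];
        for (int c = 0; c < BITS; ++c) {
            pref[i + 1][c] += static_cast<int>((a[i] >> c) & 1);
        }
    }
    return pref;
}

// True if every a[k] with lo <= k < hi agrees with ref on all bits above b.
bool rangeMatches(const std::vector<BitCounts>& pref, std::uint64_t ref,
                  int b, int lo, int hi) {
    for (int c = b + 1; c < BITS; ++c) {
        const int cnt = pref[hi][c] - pref[lo][c];
        // Bit set in ref: all hi - lo elements need it. Otherwise: none may have it.
        const int want = ((ref >> c) & 1) ? (hi - lo) : 0;
        if (cnt != want) return false;
    }
    return true;
}
```

Each call costs up to 60 constant-time lookups, which is where the factor of 60 in the complexity comes from.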

Once this check is done, we have two possibilities:

  1. Every R corresponding to this b is valid.
    Here, add the length of the range to the answer and move to the next b (or stop, if the end of the array has been reached).
  2. Not every R is valid.
    We now need to find the first R in this range that’s invalid. That can be done using binary search on the range - the check function is the exact same as done above.

This way, we do \mathcal{O}(60\cdot N\log N) work, which is fast enough.

TIME COMPLEXITY:

\mathcal{O}(60\cdot N \log N) per testcase.

CODE:

Author's code (C++)
#include<bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;
template<typename T>
using ordered_set = tree<T,null_type,less<T>,rb_tree_tag,tree_order_statistics_node_update>;
typedef long long int  ll;
typedef long double ld;
#define len(x) (ll)(x).size()
#define F first
#define S second
#define all(x) (x).begin(),(x).end()
#define pb push_back
#define mp make_pair
#define nl '\n'
ll N = 1e9+7;
ll N1 =998244353;
const int NN=2e5+5;

        vector<ll> a(NN),pc(NN);



int main(){
    ios_base::sync_with_stdio(0);
    cin.tie(NULL);
    int t=1; 
    bool take_t=true;
    if(take_t)cin>>t;
    while(t--){
        int n;
        cin>>n;
        for(int i=0;i<n;i++){cin>>a[i];}
        if(n==1)cout<<1;
        else{
            ll ln=__lg(n);
            vector<ll> cb0(ln+1,n),cb1(ln+1,n);
            auto get_sublen=[&](int x,int ind){
                ll ret=n-1;
                for(int j=ln;j>-1;j--){
                    if((x&(1<<j))){
                        if(cb0[j]!=n&&ind+(1<<j)-1>cb0[j])
                            ret=min(cb0[j]-1,ret);
                    }
                    else {
                        if(cb1[j]!=n&&ind+(1<<j)-1>cb1[j])
                            ret=min(cb1[j]-1,ret);
                    }
                }
                return ret;
            }; 
            // ll pc[n];
            pc[n-1]=n-1;
            for(int i=n-2;i>-1;i--){
                if((a[i]>>(ln+1))!=(a[i+1]>>(ln+1)))pc[i]=i;
                else pc[i]=pc[i+1];
            }
            ll ans=0;
            for(int i=n-1;i>-1;i--){
                for(int j=0;j<ln+1;j++){
                    if((a[i]&(1LL<<j)))cb1[j]=i;
                    else cb0[j]=i;
                }
                // for a[i] 
                ans+=(min(pc[i],get_sublen(a[i],i))+1-i);
                
                // for a[i]^1 
                // ans+=(min(pc[i],get_sublen((a[i]^1),i))+1-i);
            }
            cout<<ans;
        }
        if(t)cout<<nl;
    }

    // cerr<<"worked in "<<1000*((double)clock())/(double)CLOCKS_PER_SEC<<"ms\n"; 
    return 0;
}
Tester's code (C++)
//Har Har Mahadev
#include<bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp> // Common file
#include <ext/pb_ds/tree_policy.hpp>
#define ll long long
#define int long long
#define rep(i,a,b) for(int i=a;i<b;i++)
#define rrep(i,a,b) for(int i=a;i>=b;i--)
#define repin rep(i,0,n)
#define precise(i) cout<<fixed<<setprecision(i)
#define vi vector<int>
#define si set<int>
#define mii map<int,int>
#define take(a,n) for(int j=0;j<n;j++) cin>>a[j];
#define give(a,n) for(int j=0;j<n;j++) cout<<a[j]<<' ';
#define vpii vector<pair<int,int>>
#define db double
#define be(x) x.begin(),x.end()
#define pii pair<int,int>
#define pb push_back
#define pob pop_back
#define ff first
#define ss second
#define lb lower_bound
#define ub upper_bound
#define bpc(x) __builtin_popcountll(x) 
#define btz(x) __builtin_ctz(x)
using namespace std;

using namespace __gnu_pbds;

typedef tree<int, null_type, less<int>, rb_tree_tag,tree_order_statistics_node_update> ordered_set;
typedef tree<pair<int, int>, null_type,less<pair<int, int> >, rb_tree_tag,tree_order_statistics_node_update> ordered_multiset;

const long long INF=1e18;
const long long M=1e9+7;
const long long MM=998244353;
  
int power( int N, int M){
    int power = N, sum = 1;
    if(N == 0) sum = 0;
    while(M > 0){if((M & 1) == 1){sum *= power;}
    power = power * power;M = M >> 1;}
    return sum;
}

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && !isspace(buffer[now])) {
            now++;
        }
        return now;
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
}inp;

int smn = 0;

void solve()
{
    int n;
    // cin >> n;
    n = inp.readInt(1,200'000);
    smn += n;
    inp.readEoln();
    vi a(n);
    // take(a,n);
    int uB = 1ll<<60;
    uB--;
    repin{
        a[i] = inp.readLong(0,uB);
        if(i == n-1)inp.readEoln();
        else inp.readSpace();
    }
    set<int> s;
    s.insert(n);
    vi st;
    int x = 1;
    st.pb(0);
    vi b(n);
    while(true){
        rep(i,st.back(),min(st.back()+x,n)){
            b[i] = (a[i]|(x*2-1))^(x*2-1);
            b[i] |= (a[0]&(x*2-1));
        }
        if(st.back()+x >= n)break;
        st.pb(st.back()+x);
        x *= 2;
    }
    rep(i,0,n-1){
        if(b[i] != b[i+1])s.insert(i+1);
    }

    reverse(be(st));
    st.pob();
    reverse(be(st));
    for(auto &x : st)x--;


    int ans = (*s.begin());
    rep(i,1,n){
        if(s.count(i))s.erase(i);
        rep(j,0,st.size()){
            if(s.count(i+st[j]+1))s.erase(i+st[j]+1);
            if(s.count(i+st[j]))s.erase(i+st[j]);
        }
        s.insert(n);
        int x = 2;
        rep(j,0,st.size()){
            if(i+st[j] >= n)break;
            b[i+st[j]] = ((a[i+st[j]]|(x-1))^(x-1));
            b[i+st[j]] |= (a[i]&(x-1));
            if(i+st[j]-1 >= i){
                b[i+st[j]-1] = ((a[i+st[j]-1]|(x-1))^(x-1));
                b[i+st[j]-1] |= (a[i]&(x-1));
            }
            if(i+st[j]+1 < n){
                b[i+st[j]+1] = ((a[i+st[j]+1]|(2*x-1))^(2*x-1));
                b[i+st[j]+1] |= (a[i]&(2*x-1));
            }
            x *= 2;
        }
        rep(j,0,st.size()){
            if(i+st[j]+1 < n && b[i+st[j]+1] != b[i+st[j]])s.insert(i+st[j]+1);
            if(i+st[j] < n && st[j] &&  b[i+st[j]] != b[i+st[j]-1])s.insert(i+st[j]);
        }
        ans += (*s.begin())-i;
    }
    cout << ans << '\n';


}

signed main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    #ifdef NCR
        init();
    #endif
    #ifdef SIEVE
        sieve();
    #endif
    int t;
    // cin >> t;
    t = inp.readInt(1,200'000);
    inp.readEoln();
    while(t--)
        solve();
    assert(smn <= 200'000);
    inp.readEof();
    return 0;
}
Editorialist's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);
    
    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        vector<ll> a(n);
        for (ll &x : a) cin >> x;
        vector<array<int, 60>> pref(n+1);
        for (int i = 0; i < n; ++i) {
            pref[i+1] = pref[i];
            for (int b = 0; b < 60; ++b)
                pref[i+1][b] += (a[i] >> b) & 1;
        }

        ll ans = 0;
        for (int i = n-1; i >= 0; --i) {
            for (int b = 0; i + (1 << b) - 1 < n; ++b) {
                auto check = [&] (int L, int R) {
                    bool good = true;
                    for (int b2 = b+1; b2 < 60; ++b2) {
                        int x = (a[i] >> b2) & 1;
                        int y = pref[R][b2] - pref[L][b2];
                        if (x == 0) good &= y == 0;
                        else good &= y == (R - L);
                    }
                    return good;
                };
                
                int L = i - 1 + (1 << b);
                int R = min(n, L + (1 << b));
                // [L, R)

                if (check(L, R)) {
                    ans += R - L;
                    continue;
                }

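                // Some index in [L, R) fails; binary search for the first one.
                // Starting lo at L - 1 is safe: the b = 0 block always passes, so b >= 1
                // here, and a[L - 1] already passed the stricter check of the previous block.
                int lo = L - 1, hi = R - 1;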
                while (lo < hi) {
                    int mid = (lo + hi) / 2;
                    if (check(lo, mid + 1)) {
                        lo = mid + 1;
                    }
                    else {
                        hi = mid;
                    }
                }
                ans += hi - L;
                break;
            }
        }
        cout << ans << '\n';
    }
}
Note that asking if all numbers in a range share a certain number of high bits is equivalent to asking what is the longest common prefix of these numbers, so all related techniques also work. For example, one may record the most significant bit of B_{i-1} \oplus B_i for all i = 2, \ldots, n and make range maximum queries to find the LCP.

I had the same solution (except no binary search), what is wrong in my implementation?
https://www.codechef.com/viewsolution/1121748080

can only be XOR-ed with some integer that’s <=i,

Since you mentioned this, I just checked whether the number I would need to XOR the current element with to reach the starting element is less than i or not, but got WA with this. Is there some error in my logic?

Notice that we only need to deal with the last \log N bits, because the maximum subarray length is N. We maintain a variable next(i, j), meaning the last index k \ge i such that a_i and a_k are the same apart from their last j bits (which are not required to match). This variable can easily be precomputed with an O(N \log N) reverse DP.
After this precomputation, we start from each i, check next(i, j) for j from 0 to \log N, and add the contribution to the answer. This part is the same as in the editorial. The total time complexity is O(N \log N).

I found the bug (a missing -1), here is the accepted submission.
https://www.codechef.com/viewsolution/1122339778