PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: rudra_1232
Tester: kingmessi
Editorialist: iceknight1093
DIFFICULTY:
Easy
PREREQUISITES:
PROBLEM:
For an array B, define f(B) to be 1 if all its elements can be made equal by the following process, and 0 otherwise:
- Choose a subarray [l, r] of B.
- For each l \leq i \leq r, replace B_i by B_i \oplus (i - l + 1).
Given an array A, find \sum_{i=1}^n\sum_{j=i}^n f([A_i, A_{i+1}, \ldots, A_j]).
EXPLANATION:
Let’s first understand when all the elements of a single array can be made equal.
So, suppose we have an array B of length N.
Our move is essentially to choose some subarray and XOR its elements by 1, 2, 3, \ldots in order.
In particular, note that B_i can only be XOR-ed with some integer that’s \leq i, though this can be done multiple times.
In fact, note that it’s possible to XOR B_i with any x \leq i, without changing the rest of the array.
This arises due to XOR being its own inverse:
- Choose l = i-x+1 and r = i, which will result in each of B_{i-x+1}, B_{i-x+2}, \ldots, B_i being XOR-ed with 1, 2, 3, \ldots, x.
- Then, choose l = i-x+1 and r = i-1, which will result in the same thing except without affecting B_i itself.
This will reset each of B_{i-x+1}, \ldots, B_{i-1} to their original values.
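The two-move trick above can be sketched in code. This is only an illustration under the assumption that the array is stored 0-indexed while l, r, i follow the 1-indexed convention of the text; the helper names `apply_move` and `xor_single` are hypothetical and not part of any reference solution.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Applies one move to the 1-indexed subarray [l, r] of b (stored 0-indexed):
// position i is XOR-ed with (i - l + 1).
void apply_move(std::vector<std::uint64_t>& b, int l, int r) {
    assert(1 <= l && l <= r && r <= static_cast<int>(b.size()));
    for (int i = l; i <= r; ++i) {
        b[i - 1] ^= static_cast<std::uint64_t>(i - l + 1);
    }
}

// XORs only b_i (1-indexed) with x, for any 1 <= x <= i, using at most two moves.
void xor_single(std::vector<std::uint64_t>& b, int i, int x) {
    assert(1 <= x && x <= i && i <= static_cast<int>(b.size()));
    const int l = i - x + 1;
    apply_move(b, l, i);          // b_l..b_i are XOR-ed with 1..x
    if (l <= i - 1) {
        apply_move(b, l, i - 1);  // XOR is its own inverse: restores b_l..b_{i-1}
    }
}
```

For example, calling `xor_single(b, 4, 3)` on an array of length 5 changes only the fourth element, XOR-ing it with 3.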
This now tells us exactly which values B_i can take: if h is the largest integer such that 2^h \leq i, we can freely change any of its bits \leq h (for example by choosing the appropriate power of 2), and cannot change its bits \gt h at all.
In particular, note that all the bits of B_1, other than its lowest bit, are fixed. This essentially uniquely determines the final value in the array (because the lowest bit of every number can be freely changed anyway).
Extending this to further indices,
- The lowest two bits of B_2 and B_3 can be changed freely, but all bits \geq 2 are fixed.
This means that all their bits that are \geq 2 must match the corresponding bits of B_1; otherwise no solution exists.
- Similarly, B_4, B_5, B_6, B_7 all must match B_1 at bits \geq 3, and so on.
- In general, for each i \geq 1, B_i must match B_1 at all bits \gt \left\lfloor \log_2 i \right\rfloor.
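As a rough sketch of the single-array check: B_i (1-indexed) has \left\lfloor \log_2 i \right\rfloor + 1 freely changeable low bits, so it suffices to compare B_i and B_1 after shifting those bits away. The function names `free_bits` and `can_equalize` are made up for this illustration, and the code assumes values below 2^{60} stored 0-indexed.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Number of low bits that can be changed freely at 1-indexed position i,
// i.e. floor(log2(i)) + 1.
int free_bits(int i) {
    assert(i >= 1);
    int k = 0;
    while (i > 0) {
        ++k;
        i >>= 1;
    }
    return k;
}

// Returns true if all elements of b can be made equal, i.e. f(b) = 1.
bool can_equalize(const std::vector<std::uint64_t>& b) {
    for (int i = 1; i <= static_cast<int>(b.size()); ++i) {
        const int k = free_bits(i);
        // All bits at positions >= k are fixed and must agree with b_1.
        if ((b[i - 1] >> k) != (b[0] >> k)) return false;
    }
    return true;
}
```

For instance, `can_equalize({0, 4})` is false because bit 2 of the second element cannot be changed, while `can_equalize({5, 4})` is true.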
We now have a relatively easy check for a single array B. Let’s extend this to counting for all subarrays.
Let’s fix the left end L of the subarray, and try to count all valid R.
A direct check, using the criterion devised above, is as follows:
- For each R = L, L+1, L+2, \ldots in order, check if B_L and B_R match at all bits other than the lowest \left\lfloor \log_2(R-L+1) \right\rfloor + 1 ones.
- If they do match, [L, R] is valid so add 1 to the answer.
Otherwise, [L, R] is invalid, and also all [L, R'] for R'\gt R will be invalid (since this index will prevent equality happening no matter what), so we can break out immediately.
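The direct check can be written down as a quadratic reference implementation, which is handy for validating a faster solution on small inputs. This is a sketch only; the name `count_good_subarrays_slow` is hypothetical, and indices in the code are 0-based.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Counts subarrays [left, right] whose elements can all be made equal.
// Worst case O(N^2); intended as a brute-force reference only.
long long count_good_subarrays_slow(const std::vector<std::uint64_t>& a) {
    const int n = static_cast<int>(a.size());
    long long answer = 0;
    for (int left = 0; left < n; ++left) {
        int k = 0;  // number of free low bits for the current length
        for (int right = left; right < n; ++right) {
            const int len = right - left + 1;
            while ((1 << k) <= len) ++k;  // now k = floor(log2(len)) + 1
            // The fixed bits of a[right] must agree with those of a[left].
            if ((a[right] >> k) != (a[left] >> k)) break;
            ++answer;  // [left, right] is valid
        }
    }
    return answer;
}
```

On the array [0, 0, 0, 8], this returns 7: the three subarrays ending at the last element with length greater than 1 are all invalid.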
This algorithm, while correct, can take quadratic time, so we must optimize it.
One way of optimization is to iterate over bits rather than indices.
That is, once L is fixed, we’ll iterate over values of b = 0, 1, 2, \ldots and try to perform the check for all the elements at indices R such that \left\lfloor \log_2(R-L+1) \right\rfloor = b, simultaneously.
That is, for all R in the range [L + 2^b - 1, \min(N, L + 2^{b+1} - 2)].
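To make the grouping concrete, here is a small sketch that lists, for a fixed 1-indexed L, the inclusive range of R belonging to each b. The helper name `blocks_for_left` is an assumption made for this example.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// For a fixed 1-indexed left end L in an array of length N, returns the
// inclusive ranges of R grouped by b = floor(log2(R - L + 1)), for b = 0, 1, ...
std::vector<std::pair<int, int>> blocks_for_left(int L, int N) {
    assert(1 <= L && L <= N);
    std::vector<std::pair<int, int>> blocks;
    for (int b = 0; L + (1LL << b) - 1 <= N; ++b) {
        const long long lo = L + (1LL << b) - 1;
        const long long hi = std::min<long long>(N, L + (1LL << (b + 1)) - 2);
        blocks.emplace_back(static_cast<int>(lo), static_cast<int>(hi));
    }
    return blocks;
}
```

With L = 1 and N = 10 the blocks are [1, 1], [2, 3], [4, 7], [8, 10], so only about \log_2 N blocks exist per left end.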
Note that all of them have the same check: we want to know if their bits \gt b match the corresponding bits of B_L.
To check this, let’s look at some bit b' \gt b.
- If b' is not set in B_L, it must then not be set in any of the B_R values in the range we’re looking at.
This can be checked in constant time by counting the number of integers in this segment that have b' set using prefix sums built on this bit alone.
Alternately, compute the bitwise OR of the range (using, say, a sparse table) and check the bit b' of this OR.
- Similarly, if b' is set in B_L, it must be set in every B_R in this range, which can again be checked by looking at the count of values in this range that have it set; or verifying that the bitwise AND of the range has it set.
This check needs to be performed for each b \lt b' \lt 60, and all of them must pass.
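The prefix-sum variant of this range check might look as follows. It is a sketch under the assumptions that values are below 2^{60}, that the array is 0-indexed, and that ranges are half-open; the names `BitCounts`, `build_prefix` and `range_matches` are introduced here purely for illustration.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr int BITS = 60;  // assumption: every value is below 2^60
using BitCounts = std::vector<std::array<int, BITS>>;

// pref[i][c] = how many of a[0..i-1] have bit c set.
BitCounts build_prefix(const std::vector<std::uint64_t>& a) {
    BitCounts pref(a.size() + 1);  // value-initialized, so all counts start at 0
    for (std::size_t i = 0; i < a.size(); ++i) {
        pref[i + 1] = pref[i];
        for (int c = 0; c < BITS; ++c) {
            pref[i + 1][c] += static_cast<int>((a[i] >> c) & 1);
        }
    }
    return pref;
}

// True if every a[j] with lo <= j < hi agrees with `base` on all bits
// strictly above b. Costs O(BITS) regardless of the range length.
bool range_matches(const BitCounts& pref, std::uint64_t base, int b, int lo, int hi) {
    assert(0 <= lo && lo <= hi && hi < static_cast<int>(pref.size()));
    for (int c = b + 1; c < BITS; ++c) {
        const int cnt = pref[hi][c] - pref[lo][c];
        const int want = ((base >> c) & 1) ? (hi - lo) : 0;
        if (cnt != want) return false;
    }
    return true;
}
```

Here `base` plays the role of B_L. The table uses 60 integers per position, which trades memory for a very simple constant-factor check; a sparse table over OR and AND would be an alternative.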
Once this check is done, we have two possibilities:
- Every R corresponding to this b is valid.
Here, add the length of the range to the answer and move to the next b (or stop, if the end of the array has been reached).
- Not every R is valid.
We now need to find the first R in this range that’s invalid. That can be done using binary search on the range - the check function is the exact same as done above.
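For each L there are \mathcal{O}(\log N) values of b, each needing one \mathcal{O}(60) check, plus at most one binary search that performs \mathcal{O}(\log N) such checks.
In total, we do \mathcal{O}(60\cdot N\log N) work, which is fast enough.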
TIME COMPLEXITY:
\mathcal{O}(60\cdot N \log N) per testcase.
CODE:
Author's code (C++)
#include<bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;
template<typename T>
using ordered_set = tree<T,null_type,less<T>,rb_tree_tag,tree_order_statistics_node_update>;
typedef long long int ll;
typedef long double ld;
#define len(x) (ll)(x).size()
#define F first
#define S second
#define all(x) (x).begin(),(x).end()
#define pb push_back
#define mp make_pair
#define nl '\n'
ll N = 1e9+7;
ll N1 =998244353;
const int NN=2e5+5;
vector<ll> a(NN),pc(NN);
int main(){
    ios_base::sync_with_stdio(0);
    cin.tie(NULL);
    int t=1;
    bool take_t=true;
    if(take_t)cin>>t;
    while(t--){
        int n;
        cin>>n;
        for(int i=0;i<n;i++){cin>>a[i];}
        if(n==1)cout<<1;
        else{
            ll ln=__lg(n);
            vector<ll> cb0(ln+1,n),cb1(ln+1,n);
            auto get_sublen=[&](int x,int ind){
                ll ret=n-1;
                for(int j=ln;j>-1;j--){
                    if((x&(1<<j))){
                        if(cb0[j]!=n&&ind+(1<<j)-1>cb0[j])
                            ret=min(cb0[j]-1,ret);
                    }
                    else {
                        if(cb1[j]!=n&&ind+(1<<j)-1>cb1[j])
                            ret=min(cb1[j]-1,ret);
                    }
                }
                return ret;
            };
            // ll pc[n];
            pc[n-1]=n-1;
            for(int i=n-2;i>-1;i--){
                if((a[i]>>(ln+1))!=(a[i+1]>>(ln+1)))pc[i]=i;
                else pc[i]=pc[i+1];
            }
            ll ans=0;
            for(int i=n-1;i>-1;i--){
                for(int j=0;j<ln+1;j++){
                    if((a[i]&(1LL<<j)))cb1[j]=i;
                    else cb0[j]=i;
                }
                // for a[i]
                ans+=(min(pc[i],get_sublen(a[i],i))+1-i);
                // for a[i]^1
                // ans+=(min(pc[i],get_sublen((a[i]^1),i))+1-i);
            }
            cout<<ans;
        }
        if(t)cout<<nl;
    }
    // cerr<<"worked in "<<1000*((double)clock())/(double)CLOCKS_PER_SEC<<"ms\n";
    return 0;
}
Tester's code (C++)
//Har Har Mahadev
#include<bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp> // Common file
#include <ext/pb_ds/tree_policy.hpp>
#define ll long long
#define int long long
#define rep(i,a,b) for(int i=a;i<b;i++)
#define rrep(i,a,b) for(int i=a;i>=b;i--)
#define repin rep(i,0,n)
#define precise(i) cout<<fixed<<setprecision(i)
#define vi vector<int>
#define si set<int>
#define mii map<int,int>
#define take(a,n) for(int j=0;j<n;j++) cin>>a[j];
#define give(a,n) for(int j=0;j<n;j++) cout<<a[j]<<' ';
#define vpii vector<pair<int,int>>
#define db double
#define be(x) x.begin(),x.end()
#define pii pair<int,int>
#define pb push_back
#define pob pop_back
#define ff first
#define ss second
#define lb lower_bound
#define ub upper_bound
#define bpc(x) __builtin_popcountll(x)
#define btz(x) __builtin_ctz(x)
using namespace std;
using namespace __gnu_pbds;
typedef tree<int, null_type, less<int>, rb_tree_tag,tree_order_statistics_node_update> ordered_set;
typedef tree<pair<int, int>, null_type,less<pair<int, int> >, rb_tree_tag,tree_order_statistics_node_update> ordered_multiset;
const long long INF=1e18;
const long long M=1e9+7;
const long long MM=998244353;
int power( int N, int M){
    int power = N, sum = 1;
    if(N == 0) sum = 0;
    while(M > 0){if((M & 1) == 1){sum *= power;}
    power = power * power;M = M >> 1;}
    return sum;
}
struct input_checker {
    string buffer;
    int pos;
    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";
    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }
    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && !isspace(buffer[now])) {
            now++;
        }
        return now;
    }
    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }
    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }
    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }
    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }
    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }
    void readEof() {
        assert((int) buffer.size() == pos);
    }
}inp;
int smn = 0;
void solve()
{
    int n;
    // cin >> n;
    n = inp.readInt(1,200'000);
    smn += n;
    inp.readEoln();
    vi a(n);
    // take(a,n);
    int uB = 1ll<<60;
    uB--;
    repin{
        a[i] = inp.readLong(0,uB);
        if(i == n-1)inp.readEoln();
        else inp.readSpace();
    }
    set<int> s;
    s.insert(n);
    vi st;
    int x = 1;
    st.pb(0);
    vi b(n);
    while(true){
        rep(i,st.back(),min(st.back()+x,n)){
            b[i] = (a[i]|(x*2-1))^(x*2-1);
            b[i] |= (a[0]&(x*2-1));
        }
        if(st.back()+x >= n)break;
        st.pb(st.back()+x);
        x *= 2;
    }
    rep(i,0,n-1){
        if(b[i] != b[i+1])s.insert(i+1);
    }
    reverse(be(st));
    st.pob();
    reverse(be(st));
    for(auto &x : st)x--;
    int ans = (*s.begin());
    rep(i,1,n){
        if(s.count(i))s.erase(i);
        rep(j,0,st.size()){
            if(s.count(i+st[j]+1))s.erase(i+st[j]+1);
            if(s.count(i+st[j]))s.erase(i+st[j]);
        }
        s.insert(n);
        int x = 2;
        rep(j,0,st.size()){
            if(i+st[j] >= n)break;
            b[i+st[j]] = ((a[i+st[j]]|(x-1))^(x-1));
            b[i+st[j]] |= (a[i]&(x-1));
            if(i+st[j]-1 >= i){
                b[i+st[j]-1] = ((a[i+st[j]-1]|(x-1))^(x-1));
                b[i+st[j]-1] |= (a[i]&(x-1));
            }
            if(i+st[j]+1 < n){
                b[i+st[j]+1] = ((a[i+st[j]+1]|(2*x-1))^(2*x-1));
                b[i+st[j]+1] |= (a[i]&(2*x-1));
            }
            x *= 2;
        }
        rep(j,0,st.size()){
            if(i+st[j]+1 < n && b[i+st[j]+1] != b[i+st[j]])s.insert(i+st[j]+1);
            if(i+st[j] < n && st[j] && b[i+st[j]] != b[i+st[j]-1])s.insert(i+st[j]);
        }
        ans += (*s.begin())-i;
    }
    cout << ans << '\n';
}
signed main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    #ifdef NCR
        init();
    #endif
    #ifdef SIEVE
        sieve();
    #endif
    int t;
    // cin >> t;
    t = inp.readInt(1,200'000);
    inp.readEoln();
    while(t--)
        solve();
    assert(smn <= 200'000);
    inp.readEof();
    return 0;
}
Editorialist's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());
int main()
{
    ios::sync_with_stdio(false); cin.tie(0);
    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        vector<ll> a(n);
        for (ll &x : a) cin >> x;
        vector<array<int, 60>> pref(n+1);
        for (int i = 0; i < n; ++i) {
            pref[i+1] = pref[i];
            for (int b = 0; b < 60; ++b)
                pref[i+1][b] += (a[i] >> b) & 1;
        }
        ll ans = 0;
        for (int i = n-1; i >= 0; --i) {
            for (int b = 0; i + (1 << b) - 1 < n; ++b) {
                auto check = [&] (int L, int R) {
                    bool good = true;
                    for (int b2 = b+1; b2 < 60; ++b2) {
                        int x = (a[i] >> b2) & 1;
                        int y = pref[R][b2] - pref[L][b2];
                        if (x == 0) good &= y == 0;
                        else good &= y == (R - L);
                    }
                    return good;
                };
                int L = i - 1 + (1 << b);
                int R = min(n, L + (1 << b));
                // [L, R)
                if (check(L, R)) {
                    ans += R - L;
                    continue;
                }
                int lo = L - 1, hi = R - 1;
                while (lo < hi) {
                    int mid = (lo + hi) / 2;
                    if (check(lo, mid + 1)) {
                        lo = mid + 1;
                    }
                    else {
                        hi = mid;
                    }
                }
                ans += hi - L;
                break;
            }
        }
        cout << ans << '\n';
    }
}