MEXXOR2 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: raysh07
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

Medium

PREREQUISITES:

Binary search, XOR basis

PROBLEM:

You’re given an integer N.
Count the number of subsets S of \{0, 1, 2, \ldots, N\} for which \text{MEX}(S) = \text{XOR}(S).

EXPLANATION:

To begin, we need to understand what it means for a set to have equal mex and XOR.
Let M = \text{MEX}(S) = \text{XOR}(S). This then means:

  • M \notin S, and
  • \text{XOR}(S\cup \{M\}) = \text{XOR}(S) \oplus M = 0

That is, inserting the MEX of S into it results in a set with zero XOR; so S itself is one element away from having zero XOR.

Let’s look at this from the opposite perspective: suppose Z is a set with zero XOR, and we create a valid set S from it by removing one element.
The removed element must be the MEX of the resulting set, which is only possible if every smaller value is also in S (and hence in Z).

In particular, suppose all the integers 0, 1, 2, \ldots, K exist in Z, while K+1 does not.
Then, removing any integer \leq K from Z results in a valid set S, and these are our only options.
So, such a set Z gives rise to K+1 different sets S.
Note that K+1 = \text{MEX}(Z), so one way to state this is just that there are \text{MEX}(Z) choices.


The above gives us a (slow) way to compute the answer, as a starting point.
Define f(K) to be the number of subsets of \{0, 1, 2, \ldots, N\} whose XOR is 0 and MEX is K.
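The answer is then

\sum_{K=0}^{N+1} f(K) \cdot K

There is one special case. If \text{XOR}(\{0, 1, \ldots, N\}) = N+1, the full set S = \{0, 1, \ldots, N\} is also valid. The sum above misses it, because its MEX N+1 cannot be inserted into a subset of \{0, 1, \ldots, N\}. Both codes below check this case separately.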
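The brute-force sketch below is a sanity check of this reduction. The function names are our own, and it only works for tiny N.

- `brute_count` counts the valid S directly.
- `via_zero_xor_sets` sums \text{MEX}(Z) over all zero-XOR subsets Z. It adds the full set \{0, \ldots, N\} separately when its XOR is N+1, as both reference codes do.

Both functions count the empty set. The reference codes subtract 1 for it.

```cpp
#include <cassert>
#include <cstdint>

// Direct brute force: count subsets S of {0, ..., n} (empty set included)
// with MEX(S) == XOR(S). Subsets are bitmasks, so keep n small.
long long brute_count(int n) {
    long long cnt = 0;
    for (uint32_t mask = 0; mask < (1u << (n + 1)); ++mask) {
        int mex = 0, x = 0;
        while (mex <= n && (mask >> mex & 1)) ++mex;
        for (int v = 0; v <= n; ++v)
            if (mask >> v & 1) x ^= v;
        if (mex == x) ++cnt;
    }
    return cnt;
}

// Same count through zero-XOR sets Z: each Z yields MEX(Z) valid sets S.
// The full set {0, ..., n} is added on its own when its XOR equals n + 1,
// because that MEX cannot be inserted into a subset of {0, ..., n}.
long long via_zero_xor_sets(int n) {
    long long total = 0;
    int full_xor = 0;
    for (int v = 0; v <= n; ++v) full_xor ^= v;
    for (uint32_t mask = 0; mask < (1u << (n + 1)); ++mask) {
        int mex = 0, x = 0;
        while (mex <= n && (mask >> mex & 1)) ++mex;
        for (int v = 0; v <= n; ++v)
            if (mask >> v & 1) x ^= v;
        if (x == 0) total += mex;  // contributes to f(MEX) * MEX, one Z at a time
    }
    if (full_xor == n + 1) ++total;
    return total;
}
```

For example, N = 2 gives 2 (the empty set and \{0, 1, 2\}), and N = 3 gives 5.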

There are now two things to figure out:

  1. How do we compute f(K)?
  2. Even if we’re able to do that, how do we compute the above summation quickly enough? \mathcal{O}(N) or worse is clearly too slow.

To get an idea of what to do, let’s compute just f(K) for a fixed K first.

We know that 0, 1, 2, \ldots, K-1 should be in the subset, and K should not.
Everything \gt K can be chosen or discarded freely: the only constraint is that the XOR of the entire subset must be 0.

The XOR of the entire subset being 0 means that the XOR of elements \gt K must equal the XOR of elements \leq K-1.
The XOR of elements \leq K-1 is easy to compute. It is well known to be 0, K-1, 1, or K when K \bmod 4 is 0, 1, 2, or 3, respectively.
Let this XOR be p_K.
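Here is a minimal helper for p_K that encodes the period-4 pattern of prefix XORs. It assumes 64-bit integers, and the name `prefix_xor_below` is ours.

```cpp
#include <cassert>

// p(K) = 0 ^ 1 ^ ... ^ (K - 1), with p(0) = 0. Writing m = K - 1, the prefix
// XOR 0 ^ ... ^ m equals m, 1, m + 1, 0 for m % 4 == 0, 1, 2, 3.
long long prefix_xor_below(long long K) {
    if (K <= 0) return 0;
    long long m = K - 1;
    switch (m % 4) {
        case 0: return m;      // K % 4 == 1: p = K - 1
        case 1: return 1;      // K % 4 == 2: p = 1
        case 2: return m + 1;  // K % 4 == 3: p = K
        default: return 0;     // K % 4 == 0: p = 0
    }
}
```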

Once p_K is known, we’re interested in the number of subsets of \{K+1, \ldots, N\} that have XOR equal to p_K.
This is a standard result from linear algebra. If r_K = \text{rank}(\text{span}\{K+1, \ldots, N\}), then:

  • If p_K \in \text{span}\{K+1, \ldots, N\}, there are 2^{N-K-r_K} subsets.
  • If p_K \notin \text{span}\{K+1, \ldots, N\}, there are 0 subsets.

Here N-K is the number of elements in \{K+1, \ldots, N\}. For K = N+1 the set is empty, so the count is simply 1 if p_K = 0 and 0 otherwise.

If the rank result is new or unfamiliar to you, one or more of these might be good reads: USACO Guide, blog 1, blog 2.
The rest of this editorial will assume knowledge of XOR basis and its properties.
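The sketch below makes the counting rule concrete: a standard bit-indexed XOR basis plus the 2^{\text{size} - \text{rank}} count. The struct and function names are hypothetical, and values are assumed to fit in 31 bits.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

// Bit-indexed XOR basis for values below 2^31: b[i] is either 0 or a vector
// whose highest set bit is i.
struct XorBasis {
    std::array<uint32_t, 31> b{};
    int rank = 0;
    // Reduce x against the basis; the result is 0 iff x is in the span.
    uint32_t reduce(uint32_t x) const {
        for (int i = 30; i >= 0; --i)
            if ((x >> i & 1) && b[i]) x ^= b[i];
        return x;
    }
    void insert(uint32_t x) {
        x = reduce(x);
        if (x == 0) return;
        int hb = 30;
        while (!(x >> hb & 1)) --hb;  // highest set bit; its slot is free
        b[hb] = x;
        ++rank;
    }
};

// Number of subsets of `values` whose XOR is `target`: 2^(|values| - rank)
// if target is in the span, else 0. Assumes |values| - rank < 63.
long long count_subsets_with_xor(const std::vector<uint32_t>& values, uint32_t target) {
    XorBasis basis;
    for (uint32_t v : values) basis.insert(v);
    if (basis.reduce(target) != 0) return 0;
    return 1LL << ((int)values.size() - basis.rank);
}
```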


You may notice that we in fact have a non-trivial subproblem in the above discussion: how exactly do we compute \text{rank}(\text{span}\{K+1, \ldots, N\}) quickly?

The span of a range

The standard method of doing this is to build a basis on all the values and then look at its size, but that would be \mathcal{O}(N\log N) which we can’t afford for even a single K.

To get around this, we will instead build a basis only on “important” values.

Observe that for any integer X such that 2^b divides X, the segment [X, X + 2^b) (that is, X, X+1, \ldots, X+2^b-1) goes through all possible values of the b lower bits while keeping the higher bits fixed.
This means that one basis for this segment is \{X, 1, 2, 4, \ldots, 2^{b-1}\} (if X = 0, X itself is not included).

A slight generalization: if 2^b divides X and Y \lt 2^b, then one basis for the segment [X, X+Y] is \{X, 1, 2, 4, \ldots, 2^m\}, where m is the largest integer such that 2^m \leq Y. As before, X is dropped if X = 0, and no powers of 2 are taken if Y = 0.
We’ll call segments of this form simple.
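Here is a small sketch of the simple-segment basis, checked against a generic rank computation. `simple_segment_basis` and `rank_of` are illustrative names. The caller is assumed to guarantee that 2^b divides X and Y \lt 2^b.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Basis of a simple segment [X, X + Y]: X is divisible by 2^b and Y < 2^b
// (X = 0 is allowed). It is X itself (if nonzero) plus 1, 2, ..., 2^m, where
// m is the largest integer with 2^m <= Y (no powers at all when Y = 0).
std::vector<long long> simple_segment_basis(long long X, long long Y) {
    std::vector<long long> basis;
    if (X != 0) basis.push_back(X);
    for (long long p = 1; p <= Y; p <<= 1) basis.push_back(p);
    return basis;
}

// Rank over GF(2), using the reduction x = min(x, x ^ y) against the
// vectors kept so far (in insertion order).
int rank_of(const std::vector<long long>& values) {
    std::vector<long long> kept;
    for (long long x : values) {
        for (long long y : kept) x = std::min(x, x ^ y);
        if (x != 0) kept.push_back(x);
    }
    return (int)kept.size();
}
```

Checking that the basis has full rank and that adding it to the segment does not raise the segment's rank confirms that both span the same space.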

Now, consider an arbitrary segment [L, R] of integers.
Observe that we can break it into \mathcal{O}(\log R) simple segments using the following algorithm:

  • Start with X = L (if L = 0, start from 1 instead, since 0 never changes a span).
    Compute the maximum b such that 2^b divides X (this is the number of trailing zeros of X).
  • Now, consider the segment [X, X + 2^b).
    If this segment contains R, consider [X, R] instead.
    Either way, the current segment is simple, so its basis can be computed directly.
  • Next, if the segment processed was [X, R], stop and break.
    Otherwise, update X to X + 2^b and repeat the process from the first step.

This works because each time we restart the process, the largest power of 2 dividing X strictly increases.
That power never exceeds X \leq R, so there are at most \lfloor \log_2 R \rfloor + 1 steps before we finish.
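The splitting procedure can be sketched as follows. `split_simple` is a hypothetical name, and L \geq 1 is assumed (0 can be dropped, since it never changes a span). Instead of counting trailing zeros explicitly, the sketch uses the lowest set bit `X & -X`, which equals 2^b for the largest b with 2^b dividing X.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Split [L, R] (1 <= L <= R) into simple segments, returned as pairs (X, Y)
// meaning [X, X + Y]. Each step takes 2^b = lowest set bit of X and cuts at
// X + 2^b - 1 or at R, whichever comes first.
std::vector<std::pair<long long, long long>> split_simple(long long L, long long R) {
    std::vector<std::pair<long long, long long>> segs;
    long long X = L;
    while (X <= R) {
        long long len = X & -X;  // 2^b; nonzero because X >= 1
        if (X + len - 1 >= R) {  // [X, X + 2^b) reaches R: use [X, R] and stop
            segs.push_back({X, R - X});
            break;
        }
        segs.push_back({X, len - 1});
        X += len;
    }
    return segs;
}
```

For instance, `split_simple(5, 20)` returns the segments [5, 5], [6, 7], [8, 15], [16, 20].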

With this algorithm, we’ve hence broken [L, R] into \mathcal{O}(\log R) segments, and found a basis of size \mathcal{O}(\log R) for each one.
Simply merging all these bases will give us a basis for the entire segment [L, R].
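Putting the two pieces together gives a rank-of-range routine. This sketch uses our own naming. It inserts every element of every small basis, so it is the \mathcal{O}(\log^2 R)-insertion version. The merged basis uses the compact reduction x = \min(x, x \oplus y).

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Rank of span{L, ..., R}: split [L, R] into simple segments and merge their
// bases (X plus the powers 1, 2, ..., 2^m with 2^m <= Y) into one basis.
// 0 never changes a span, so L is clamped to at least 1; L > R gives rank 0.
int range_rank(long long L, long long R) {
    std::vector<long long> kept;  // reduced with x = min(x, x ^ y)
    auto insert = [&](long long x) {
        for (long long y : kept) x = std::min(x, x ^ y);
        if (x != 0) kept.push_back(x);
    };
    for (long long X = std::max(L, 1LL); X <= R; X += X & -X) {
        long long Y = std::min((X & -X) - 1, R - X);  // simple segment [X, X + Y]
        insert(X);
        for (long long p = 1; p <= Y; p <<= 1) insert(p);
    }
    return (int)kept.size();
}
```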

Finally, note that it’s possible to shave off one \mathcal{O}(\log R) factor here.
Looking closely at the structure of each small basis, at most one element in each of them is not a power of 2. Further, the powers of 2 always form a prefix 1, 2, 4, \ldots, 2^m.
So, we only need to merge the non-trivial elements and keep track of the largest power of two encountered. At the end, we merge the powers 1, 2, \ldots up to that one. This means merging only \mathcal{O}(\log R) elements rather than \mathcal{O}(\log^2 R).
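Here is a sketch of the optimized variant, in the spirit of the editorialist's `calc_span` but with our own names. It does three things:

- Insert only the X of each simple segment.
- Track the largest power exponent m.
- Insert 1, 2, \ldots, 2^m once at the end.

```cpp
#include <algorithm>
#include <array>
#include <cassert>

// Same rank with only O(log R) insertions: every simple segment [X, X + Y]
// contributes X and a prefix 1, 2, ..., 2^m of powers of two, so insert the X's
// and only remember the longest prefix (largest m) seen.
int range_rank_fast(long long L, long long R) {
    std::array<long long, 63> b{};  // b[i]: basis vector with highest bit i, or 0
    int rank = 0;
    auto insert = [&](long long x) {
        for (int i = 62; i >= 0 && x != 0; --i) {
            if (!(x >> i & 1)) continue;
            if (b[i] == 0) { b[i] = x; ++rank; return; }
            x ^= b[i];
        }
    };
    int max_pow = -1;  // largest m such that 2^m <= Y for some segment
    for (long long X = std::max(L, 1LL); X <= R; X += X & -X) {
        long long Y = std::min((X & -X) - 1, R - X);
        insert(X);
        int m = -1;
        while ((1LL << (m + 1)) <= Y) ++m;
        max_pow = std::max(max_pow, m);
    }
    for (int m = 0; m <= max_pow; ++m) insert(1LL << m);
    return rank;
}
```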


Now, we know that the answer is the sum of K\cdot 2^{N - K - r_K} across all “valid” K. All that remains is to compute this fast.

Observe that r_K is in general not going to be very large: it’s certainly not more than 30 since we’re dealing with 30-bit numbers at most.
Further, since r_K is the rank of the suffix \{K+1, \ldots, N\}, r_K never increases as K increases.

This means r_K takes at most 31 distinct values (0 through 30), each occurring on a contiguous interval of K.
Specifically, define b_i to be the largest K such that r_K \geq i. Removing a single element lowers the rank by at most 1, so this is also the largest K with r_K = i.
Then, we’ll have b_1 \gt b_2 \gt b_3 \gt \ldots, and the interval [b_{i+1} + 1, b_i] is exactly the set of K with r_K = i.

Finding these b_i values is not too hard: you can for example binary search on the largest value with rank \geq i, utilizing the “compute rank of range” function described earlier as the check function.
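Below is a sketch of the breakpoint search. The names are hypothetical; K ranges over [0, N] here, and r_K = 0 for K \geq N. Since b_i \lt b_{i-1}, each binary search can start from the previous breakpoint, similar to the editorialist's `right_rank` loop. For brevity, the sketch uses the simple \mathcal{O}(\log^2 N)-insertion span computation.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// r(K) = rank of span{K+1, ..., N}, via the simple-segment merge (0 if K >= N).
int suffix_rank(long long K, long long N) {
    std::vector<long long> kept;
    auto insert = [&](long long x) {
        for (long long y : kept) x = std::min(x, x ^ y);
        if (x != 0) kept.push_back(x);
    };
    for (long long X = std::max(K + 1, 1LL); X <= N; X += X & -X) {
        insert(X);
        for (long long p = 1; p <= std::min((X & -X) - 1, N - X); p <<= 1) insert(p);
    }
    return (int)kept.size();
}

// b[i] = largest K in [0, N] with r(K) >= i, for i = 0, 1, ..., r(0).
// r is non-increasing, so each b[i] is found by binary search, and
// b[i] < b[i-1] lets us shrink the search range.
std::vector<long long> rank_breakpoints(long long N) {
    std::vector<long long> b{N};  // r(N) = 0 >= 0
    int top = suffix_rank(0, N);
    for (int i = 1; i <= top; ++i) {
        long long lo = 0, hi = b.back() - 1;  // r(0) >= i, so lo = 0 is valid
        while (lo < hi) {
            long long mid = lo + (hi - lo + 1) / 2;
            if (suffix_rank(mid, N) >= i) lo = mid;
            else hi = mid - 1;
        }
        b.push_back(lo);
    }
    return b;
}
```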


Now that all the b_i are known, let’s process all K with r_K = i.
As noted earlier, this is exactly the interval [b_{i+1}+1, b_i].
Which of the points in this interval are “good”?

It turns out, nearly all of them!
Recall that K is good iff p_K = \text{XOR}(\{0 ,1, \ldots, K-1\}) \in \text{span}\{K+1, \ldots, N\}.
However, p_K can only take the values 0, 1, K-1, or K.
Now,

  • 0 is always in the span.
  • 1 will be in the span as soon as both 2x and 2x+1 belong to \{K+1, \ldots, N\} for some x, since 2x \oplus (2x+1) = 1.
    In particular, as soon as the set \{K+1, \ldots, N\} has at least three numbers, 1 will be in the span.
  • K will be in the span as long as b_{i+1}+1 \lt K.
    This is because, for such K, we know that inserting K into \text{span}\{K+1, \ldots, N\} doesn’t change its rank; meaning K already exists in the span.
  • Similarly, K-1 will be in the span as long as b_{i+1}+2 \lt K.

So, other than potentially a couple of elements at the start of the interval (and a few elements near N, for the p_K = 1 case), every point is good.
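The case analysis can be checked by brute force for small N. This hypothetical checker works in three steps:

- Build the basis of \{K+1, \ldots, N\} incrementally as K decreases.
- Mark each K in [0, N+1] as good or bad.
- Verify that every bad K is among the first two points of its constant-rank interval, or has N - K \lt 3.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Brute-force check of the case analysis for one N. K (0 <= K <= N + 1) is bad
// when p_K = 0 ^ 1 ^ ... ^ (K - 1) is not in span{K+1, ..., N}. The claim: a bad K
// is among the first two points of its constant-rank interval, or N - K < 3.
bool good_points_claim_holds(int N) {
    std::vector<long long> p(N + 2);
    for (long long K = 0, acc = 0; K <= N + 1; acc ^= K, ++K) p[K] = acc;

    std::vector<int> rank(N + 2);
    std::vector<bool> good(N + 2);
    std::vector<long long> kept;  // basis of {K+1, ..., N}, built as K decreases
    auto reduce = [&](long long x) {
        for (long long y : kept) x = std::min(x, x ^ y);
        return x;
    };
    for (int K = N + 1; K >= 0; --K) {
        if (K + 1 <= N) {
            long long x = reduce(K + 1);
            if (x != 0) kept.push_back(x);
        }
        rank[K] = (int)kept.size();
        good[K] = reduce(p[K]) == 0;
    }
    for (int K = 0; K <= N + 1; ++K) {
        if (good[K]) continue;
        int start = K;  // first point of K's constant-rank interval
        while (start > 0 && rank[start - 1] == rank[K]) --start;
        if (K - start >= 2 && N - K >= 3) return false;
    }
    return true;
}
```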

In particular, once these \mathcal{O}(1) edge cases are excluded, the set of good points forms a contiguous range.
Let this range be [L, R]. Then, we want to compute

\sum_{K=L}^R K\cdot 2^{N-K-r_K}

r_K is now a constant though, so this is simply an arithmetico-geometric sequence whose sum can be computed with a direct formula.
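With ratio 1/2, the sum also telescopes, which avoids modular division. Write E = N - r_K, which is constant on the interval. One can check that

K \cdot 2^{E-K} = (K+1)\cdot 2^{E-K+1} - (K+2)\cdot 2^{E-K}

so

\sum_{K=L}^{R} K \cdot 2^{E-K} = (L+1)\cdot 2^{E-L+1} - (R+2)\cdot 2^{E-R}

The sketch below uses this identity. It is an alternative to the general arithmetico-geometric formula in the editorialist's `agp_sum`. It assumes R \leq E, which holds because r_K \leq N - K.

```cpp
#include <cassert>
#include <cstdint>

constexpr int64_t MOD = 998244353;  // the modulus used by both reference codes

int64_t pow2(int64_t e) {  // 2^e mod MOD, for e >= 0
    int64_t base = 2, res = 1;
    for (; e > 0; e >>= 1, base = base * base % MOD)
        if (e & 1) res = res * base % MOD;
    return res;
}

// sum_{K=L}^{R} K * 2^(E - K) mod MOD, where E = N - r_K is constant on the
// interval and 0 <= L <= R <= E. Uses the telescoping identity
//   K * 2^(E-K) = (K+1) * 2^(E-K+1) - (K+2) * 2^(E-K),
// so the sum is (L+1) * 2^(E-L+1) - (R+2) * 2^(E-R), with no modular division.
int64_t interval_sum(int64_t L, int64_t R, int64_t E) {
    int64_t first = (L + 1) % MOD * pow2(E - L + 1) % MOD;
    int64_t last = (R + 2) % MOD * pow2(E - R) % MOD;
    return (first - last + MOD) % MOD;
}
```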


That leaves the potential edge cases: a few values near N, along with \mathcal{O}(1) values in the prefix of each interval, making for \mathcal{O}(\log N) edge cases in total.
Each of these can be processed quickly since all we need to do is compute the appropriate span (by building the basis) and check if p_K exists in it or not, so this is not a problem.

Adding up the answers for the edge cases and the sums over the intervals gives the final answer.
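Finally, here is an end-to-end sketch of the approach, written from scratch rather than taken from either reference solution. All names are ours.

- It counts the empty set and includes the \{0, \ldots, N\} special case.
- For each constant-rank interval, it handles the first two points and any K \gt N-3 individually.
- It sums the remaining points in closed form.

It assumes N \lt 2^{31} and the modulus 998244353 used by both reference codes.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

constexpr int64_t MOD = 998244353;

int64_t pow2(int64_t e) {  // 2^e mod MOD, e >= 0
    int64_t base = 2, res = 1;
    for (; e > 0; e >>= 1, base = base * base % MOD)
        if (e & 1) res = res * base % MOD;
    return res;
}

struct Basis {  // bit-indexed XOR basis for values below 2^32
    std::array<int64_t, 32> b{};
    int rank = 0;
    int64_t reduce(int64_t x) const {
        for (int i = 31; i >= 0; --i)
            if ((x >> i & 1) && b[i]) x ^= b[i];
        return x;
    }
    void insert(int64_t x) {
        x = reduce(x);
        if (x == 0) return;
        int hb = 31;
        while (!(x >> hb & 1)) --hb;
        b[hb] = x;
        ++rank;
    }
};

// Basis of span{L, ..., R}: the X of every simple segment plus 1, 2, ..., 2^max_pow.
Basis range_basis(int64_t L, int64_t R) {
    Basis B;
    int max_pow = -1;
    for (int64_t X = std::max<int64_t>(L, 1); X <= R; X += X & -X) {
        int64_t Y = std::min((X & -X) - 1, R - X);
        B.insert(X);
        while ((1LL << (max_pow + 1)) <= Y) ++max_pow;
    }
    for (int m = 0; m <= max_pow; ++m) B.insert(int64_t(1) << m);
    return B;
}

int64_t prefix_xor_below(int64_t K) {  // 0 ^ 1 ^ ... ^ (K - 1)
    if (K <= 0) return 0;
    int64_t m = K - 1;
    return m % 4 == 0 ? m : m % 4 == 1 ? 1 : m % 4 == 2 ? m + 1 : 0;
}

int64_t interval_sum(int64_t L, int64_t R, int64_t E) {  // sum of K * 2^(E-K), K in [L, R]
    return ((L + 1) % MOD * pow2(E - L + 1) % MOD - (R + 2) % MOD * pow2(E - R) % MOD + MOD) % MOD;
}

// Subsets S of {0, ..., N} with MEX(S) == XOR(S), empty set included, mod MOD.
int64_t count_valid(int64_t N) {
    std::vector<int64_t> b{N + 1};  // b[i]: largest K with rank(K+1..N) >= i
    int top = range_basis(1, N).rank;
    for (int i = 1; i <= top; ++i) {
        int64_t lo = 0, hi = b.back() - 1;
        while (lo < hi) {
            int64_t mid = lo + (hi - lo + 1) / 2;
            if (range_basis(mid + 1, N).rank >= i) lo = mid; else hi = mid - 1;
        }
        b.push_back(lo);
    }
    b.push_back(-1);  // the highest-rank interval starts at K = 0
    int64_t ans = 0;
    auto single = [&](int64_t K) {  // handle one K directly
        Basis B = range_basis(K + 1, N);
        if (B.reduce(prefix_xor_below(K)) == 0)
            ans = (ans + K % MOD * pow2(std::max<int64_t>(0, N - K) - B.rank)) % MOD;
    };
    for (int i = 0; i <= top; ++i) {  // every K in [b[i+1] + 1, b[i]] has rank i
        int64_t lo = b[i + 1] + 1, hi = b[i];
        int64_t bulk_lo = lo + 2, bulk_hi = std::min(hi, N - 3);  // all good here
        for (int64_t K = lo; K <= hi && K < bulk_lo; ++K) single(K);
        if (bulk_lo <= bulk_hi) ans = (ans + interval_sum(bulk_lo, bulk_hi, N - i)) % MOD;
        for (int64_t K = std::max(bulk_lo, bulk_hi + 1); K <= hi; ++K) single(K);
    }
    if (prefix_xor_below(N + 1) == N + 1) ans = (ans + 1) % MOD;  // S = {0, ..., N}
    return ans;
}
```

Subtracting 1 (mod 998244353) from `count_valid(N)` should then match what the reference codes print, assuming the empty set is excluded as their comments suggest.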


What’s the time complexity of our solution?
Let’s go over each step.

  • The “compute span of range” method uses either \mathcal{O}(\log N) or \mathcal{O}(\log^2 N) basis insertions, depending on whether the powers-of-two optimization is used. Each insertion costs \mathcal{O}(\log N), for \mathcal{O}(\log^2 N) or \mathcal{O}(\log^3 N) overall.
  • Computing the rank breakpoints, i.e. the b_i values, multiplies this by \mathcal{O}(\log^2 N).
    This is because there are \mathcal{O}(\log N) possible ranks, and each needs its own binary search.
    Using monotonicity, this factor can be reduced to \mathcal{O}(\log N) as well.
  • For computing the answer itself, each interval [L, R] is processed in \mathcal{O}(\log{MOD}) time, since it’s a single formula involving division. Each edge case requires one call to “span of range”, and there are \mathcal{O}(\log N) edge cases.

So, the overall complexity is between \mathcal{O}(\log^3 N) and \mathcal{O}(\log^5 N) depending on which optimizations are chosen.
For our constraints, even the quintic-log complexity is perfectly fine and should pass without issue.

TIME COMPLEXITY:

Between \mathcal{O}(\log^3 N) and \mathcal{O}(\log^5 N) per test case, depending on the optimizations used.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

const int mod = 998244353;

struct mint{
    int x;

    mint (){ x = 0;}
    mint (int32_t xx){ x = xx % mod; if (x < 0) x += mod;}
    mint (long long xx){ x = xx % mod; if (x < 0) x += mod;}

    int val(){
        return x;
    }
    mint &operator++(){
        x++;
        if (x == mod) x = 0;
        return *this;
    }
    mint &operator--(){
        if (x == 0) x = mod;
        x--;
        return *this;
    }
    mint operator++(int32_t){
        mint result = *this;
        ++*this;
        return result;
    }
    
    mint operator--(int32_t){
        mint result = *this;
        --*this;
        return result;
    }
    mint& operator+=(const mint &b){
        x += b.x;
        if (x >= mod) x -= mod;
        return *this;
    }
    mint& operator-=(const mint &b){
        x -= b.x;
        if (x < 0) x += mod;
        return *this;
    }
    mint& operator*=(const mint &b){
        long long z = x;
        z *= b.x;
        z %= mod;
        x = (int)z;
        return *this;
    }
    mint operator+() const {
        return *this;
    }
    mint operator-() const {
        return mint() - *this;
    }
    mint operator/=(const mint &b){
        return *this = *this * b.inv();
    }
    mint power(long long n) const {
        mint ok = *this, r = 1;
        while (n){
            if (n & 1){
                r *= ok;
            }
            ok *= ok;
            n >>= 1;
        }
        return r;
    }
    mint inv() const {
        return power(mod - 2);
    }
    friend mint operator+(const mint& a, const mint& b){ return mint(a) += b;}
    friend mint operator-(const mint& a, const mint& b){ return mint(a) -= b;}
    friend mint operator*(const mint& a, const mint& b){ return mint(a) *= b;}
    friend mint operator/(const mint& a, const mint& b){ return mint(a) /= b;}
    friend bool operator==(const mint& a, const mint& b){ return a.x == b.x;}
    friend bool operator!=(const mint& a, const mint& b){ return a.x != b.x;}
    mint power(mint a, long long n){
        return a.power(n);
    }
    friend ostream &operator<<(ostream &os, const mint &m) {
        os << m.x;
        return os;
    }
    explicit operator bool() const {
        return x != 0;
    }
};

// Remember to check MOD

void Solve() 
{
    auto basis = [&](int l, int r){
        int ans = 0;
        vector <int> who(31, -1);
        
        auto add = [&](int x){
            for (int i = 30; i >= 0; i--) if (x >> i & 1){
                if (who[i] == -1){
                    who[i] = x;
                    ans++;
                    break;
                }
                x ^= who[i];
            }  
        };
        
        int prev_suf = 0;
        
        while (l <= r){
            add(l);
            
            int suf = 0;
            while (!(l >> suf & 1)) ++suf;
            
            for (int i = prev_suf; i < suf; i++){
                if (l + (1 << i) <= r){
                    add(l + (1 << i));
                }
            }
            prev_suf = suf;
            l += 1 << suf;
        }
        
        return ans;
    };
    
    int n; cin >> n;
    mint ans = 0;
    
    vector <int> a(31, -1);
    for (int i = 1; i < 31; i++){
        if (basis(1, n) < i){
            continue;
        }
        
        int lo = 1, hi = n;
        while (lo != hi){
            int mid = (lo + hi + 1) / 2;
            
            if (basis(mid, n) < i){
                hi = mid - 1;
            } else {
                lo = mid;
            }
        }
        
        a[i] = lo;
        // cout << i << " " << a[i] << "\n";
    }
    
    for (int i = 1; i < 31; i++) if (a[i] != -1){
        // size of basis = i for which places? 
        int r = a[i];
        int l;
        if (i == 30 || a[i + 1] == -1){
            l = 1;
        } else {
            l = a[i + 1] + 1;
        }
        
        auto sum = [&](int l, int r){
            mint ans = mint(2).power(r + 1) - mint(2).power(l);
            return ans;
        };
        
        // 2^(size - i) 
        // 2^(n + 1 - j - i) for j in range(L, R) 
        // n + 1 - R - i to n + 1 - L - i 
        mint got = sum(n + 1 - r - i, n + 1 - l - i); 
        ans += got;
    }
    
    auto get = [&](int n){
        if (n <= 0){
            return 0LL;
        }  
        if (n % 4 == 0){
            return n;
        }
        if (n % 4 == 3){
            return 0LL;
        }
        if (n % 4 == 1){
            return n ^ (n - 1);
        } else {
            return n ^ (n - 1) ^ (n - 2);
        }
    };
    
    if (get(n) == n + 1){
        ans++;
    }
    if (get(n - 1) == n){
        ans++;
    }
    
    // edges might not be correct   
    for (int i = max(n - 5, 1LL); i <= n; i++){
        // [i, n] is our xor basis 
        vector <int> basis;
        for (int j = i; j <= n; j++){
            int x = j;
            for (int y : basis){
                x = min(x, x ^ y);
            }
            if (x){
                basis.push_back(x);
            }
        }
        
        // mex of 0...i - 2 
        int me = get(i - 2) ^ (i - 1);
        
        for (int x : basis){
            me = min(me, me ^ x);
        }
        
        if (me){
            ans -= mint(2).power(n + 1 - i - (int)basis.size());
        }
    }
    
    ans -= 1;
    
    cout << ans << "\n";
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}
Editorialist's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

template<typename T>
struct Basis {
    static const int SZ = 30;
    array<T, SZ> basis{};
    Basis() {}
    int rank() {
        int res = 0;
        for (auto x : basis)
            res += (x > 0);
        return res;
    }
    void insert(T x) {
        for (auto &y : basis) {
            if (y) x = min(x, x ^ y);
            else { y = x; x = 0; }
            if (!x) break;
        }
    }
    bool inSpan(T x) {
        for (const auto &y : basis) {
            x = min(x, x ^ y);
        }
        return x == 0;
    }
    void Merge(Basis &other) {
        for (const auto &y : other.basis) {
            insert(y);
        }
    }
};

/**
 * Integers modulo p, where p is a prime
 * Source: Aeren (modified from tourist?)
 *         Modmul for 64-bit mod from kactl:ModMulLL
 * Works with p < 7.2e18 with x87 80-bit long double, and p < 2^52 ~ 4.5e12 with 64-bit
 */
template<typename T>
struct Z_p{
	using Type = typename decay<decltype(T::value)>::type;
	static vector<Type> MOD_INV;
	constexpr Z_p(): value(){ }
	template<typename U> Z_p(const U &x){ value = normalize(x); }
	template<typename U> static Type normalize(const U &x){
		Type v;
		if(-mod() <= x && x < mod()) v = static_cast<Type>(x);
		else v = static_cast<Type>(x % mod());
		if(v < 0) v += mod();
		return v;
	}
	const Type& operator()() const{ return value; }
	template<typename U> explicit operator U() const{ return static_cast<U>(value); }
	constexpr static Type mod(){ return T::value; }
	Z_p &operator+=(const Z_p &otr){ if((value += otr.value) >= mod()) value -= mod(); return *this; }
	Z_p &operator-=(const Z_p &otr){ if((value -= otr.value) < 0) value += mod(); return *this; }
	template<typename U> Z_p &operator+=(const U &otr){ return *this += Z_p(otr); }
	template<typename U> Z_p &operator-=(const U &otr){ return *this -= Z_p(otr); }
	Z_p &operator++(){ return *this += 1; }
	Z_p &operator--(){ return *this -= 1; }
	Z_p operator++(int){ Z_p result(*this); *this += 1; return result; }
	Z_p operator--(int){ Z_p result(*this); *this -= 1; return result; }
	Z_p operator-() const{ return Z_p(-value); }
	template<typename U = T>
	typename enable_if<is_same<typename Z_p<U>::Type, int>::value, Z_p>::type &operator*=(const Z_p& rhs){
		#ifdef _WIN32
		uint64_t x = static_cast<int64_t>(value) * static_cast<int64_t>(rhs.value);
		uint32_t xh = static_cast<uint32_t>(x >> 32), xl = static_cast<uint32_t>(x), d, m;
		asm(
			"divl %4; \n\t"
			: "=a" (d), "=d" (m)
			: "d" (xh), "a" (xl), "r" (mod())
		);
		value = m;
		#else
		value = normalize(static_cast<int64_t>(value) * static_cast<int64_t>(rhs.value));
		#endif
		return *this;
	}
	template<typename U = T>
	typename enable_if<is_same<typename Z_p<U>::Type, int64_t>::value, Z_p>::type &operator*=(const Z_p &rhs){
		uint64_t ret = static_cast<uint64_t>(value) * static_cast<uint64_t>(rhs.value) - static_cast<uint64_t>(mod()) * static_cast<uint64_t>(1.L / static_cast<uint64_t>(mod()) * static_cast<uint64_t>(value) * static_cast<uint64_t>(rhs.value));
		value = normalize(static_cast<int64_t>(ret + static_cast<uint64_t>(mod()) * (ret < 0) - static_cast<uint64_t>(mod()) * (ret >= static_cast<uint64_t>(mod()))));
		return *this;
	}
	template<typename U = T>
	typename enable_if<!is_integral<typename Z_p<U>::Type>::value, Z_p>::type &operator*=(const Z_p &rhs){
		value = normalize(value * rhs.value);
		return *this;
	}
	template<typename U>
	Z_p &operator^=(U e){
		if(e < 0) *this = 1 / *this, e = -e;
		Z_p res = 1;
		for(; e; *this *= *this, e >>= 1) if(e & 1) res *= *this;
		return *this = res;
	}
	template<typename U>
	Z_p operator^(U e) const{
		return Z_p(*this) ^= e;
	}
	Z_p &operator/=(const Z_p &otr){
		Type a = otr.value, m = mod(), u = 0, v = 1;
		if(a < (int)MOD_INV.size()) return *this *= MOD_INV[a];
		while(a){
			Type t = m / a;
			m -= t * a; swap(a, m);
			u -= t * v; swap(u, v);
		}
		assert(m == 1);
		return *this *= u;
	}
	template<typename U> friend const Z_p<U> &abs(const Z_p<U> &v){ return v; }
	Type value;
};
template<typename T> bool operator==(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value == rhs.value; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator==(const Z_p<T>& lhs, U rhs){ return lhs == Z_p<T>(rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator==(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) == rhs; }
template<typename T> bool operator!=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return !(lhs == rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator!=(const Z_p<T> &lhs, U rhs){ return !(lhs == rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator!=(U lhs, const Z_p<T> &rhs){ return !(lhs == rhs); }
template<typename T> bool operator<(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value < rhs.value; }
template<typename T> bool operator>(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value > rhs.value; }
template<typename T> bool operator<=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value <= rhs.value; }
template<typename T> bool operator>=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value >= rhs.value; }
template<typename T> Z_p<T> operator+(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator+(const Z_p<T> &lhs, U rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator+(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T> Z_p<T> operator-(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator-(const Z_p<T>& lhs, U rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator-(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T> Z_p<T> operator*(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator*(const Z_p<T>& lhs, U rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator*(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T> Z_p<T> operator/(const Z_p<T> &lhs, const Z_p<T> &rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator/(const Z_p<T>& lhs, U rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator/(U lhs, const Z_p<T> &rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T> istream &operator>>(istream &in, Z_p<T> &number){
	typename common_type<typename Z_p<T>::Type, int64_t>::type x;
	in >> x;
	number.value = Z_p<T>::normalize(x);
	return in;
}
template<typename T> ostream &operator<<(ostream &out, const Z_p<T> &number){ return out << number(); }

/*
using ModType = int;
struct VarMod{ static ModType value; };
ModType VarMod::value;
ModType &mod = VarMod::value;
using Zp = Z_p<VarMod>;
*/

// constexpr int mod = 1e9 + 7; // 1000000007
constexpr int mod = (119 << 23) + 1; // 998244353
// constexpr int mod = 1e9 + 9; // 1000000009
using Zp = Z_p<integral_constant<decay<decltype(mod)>::type, mod>>;

template<typename T> vector<typename Z_p<T>::Type> Z_p<T>::MOD_INV;
template<typename T = integral_constant<decay<decltype(mod)>::type, mod>>
void precalc_inverse(int SZ){
	auto &inv = Z_p<T>::MOD_INV;
	if(inv.empty()) inv.assign(2, 1);
	for(; inv.size() <= SZ; ) inv.push_back((mod - 1LL * mod / (int)inv.size() * inv[mod % (int)inv.size()]) % mod);
}

template<typename T>
vector<T> precalc_power(T base, int SZ){
	vector<T> res(SZ + 1, 1);
	for(auto i = 1; i <= SZ; ++ i) res[i] = res[i - 1] * base;
	return res;
}

template<typename T>
vector<T> precalc_factorial(int SZ){
	vector<T> res(SZ + 1, 1); res[0] = 1;
	for(auto i = 1; i <= SZ; ++ i) res[i] = res[i - 1] * i;
	return res;
}

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    auto pre = [&] (int n) {
        if (n%4 == 0) return n;
        if (n%4 == 1) return 1;
        if (n%4 == 2) return n+1;
        return 0;
    };

    auto calc_span = [] (int L, int R) {
        // span of (L...R)
        Basis<int> b;
        L = max(L, 1);
        int mxp = -1;
        while (L <= R) {
            int k = __builtin_ctz(L);
            b.insert(L);
            int seg = min(R-L, (1 << k)-1);          // simple segment is [L, L+seg]
            if (seg > 0) mxp = max(mxp, __lg(seg));  // __lg(0) is not well-defined
            L += 1 << k;
        }
        for (int i = 0; i <= mxp; ++i)
            b.insert(1 << i);
        return b;
    };

    auto agp_sum = [] (Zp a, int n, Zp d, Zp r) {
        Zp res = a - (a + d*n)*(r ^ (n+1));
        res += d*r*(1 - (r^n)) / (1 - r);
        res /= 1 - r;
        return res;
    };
    
    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        Zp ans = -1; // subtract empty set
        
        const int tail = 5;
        {
            Basis<int> cur;
            int ct = 0;
            for (int i = 0; i < tail; ++i) {
                if (n - i < 0) break;
                if (i >= 2) cur.insert(n-i+2), ++ct;

                if (cur.inSpan(pre(n-i))) {
                    ans += (Zp(2) ^ (ct - cur.rank())) * (n-i+1);
                }
            }
        }
        if (pre(n) == n+1) ++ans;
        if (n < tail) {
            cout << ans << '\n';
            continue;
        }

        // Compute rank breakpoints
        vector<int> right_rank(33, n+1);
        right_rank[1] = n;
        for (int r = 2; r <= 30; ++r) {
            // rightmost i such that rank(i...n) >= r

            if (calc_span(0, n).rank() < r) break;
            int lo = 1, hi = right_rank[r-1] - 1;
            while (lo < hi) {
                int mid = (lo + hi + 1)/2;
                if (calc_span(mid, n).rank() >= r) lo = mid;
                else hi = mid - 1;
            }
            right_rank[r] = lo;
        }

        for (int rank = 1; rank <= 30; ++rank) {
            if (right_rank[rank] == n+1) break;
            
            int L = right_rank[rank+1], R = min(n - tail + 1, right_rank[rank] - 1);
            if (L > n) L = 1;
            if (L > R) continue;

            auto go = [&] (int x) {
                auto b = calc_span(x+1, n);
                if (b.inSpan(pre(x-1))) {
                    ans += (Zp(2)^(n - x - rank)) * x;
                }
            };
            
            // [L...R]
            for (int x : {L, L+1}) {
                if (x <= R) go(x);
            }
            L += 2;
            if (L <= R) ans += agp_sum(L, R-L, 1, Zp(1) / 2) * (Zp(2) ^ (n - rank - L));
        }

        
        cout << ans << '\n';
    }
}

The number of A’s such that MEX(A) = XOR(A) = x is either 0 or 2^{N - x - b}, where b is the size of a basis of (x + 1, x + 2, … , N).

Using this, the problem can be solved in O(\log^2 N), and with more careful precomputation, in O(\log N).
