XORPARTSORT - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: sushil2006
Tester: apoorv_me
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Binary lifting, familiarity with bitwise XOR

PROBLEM:

For an array B of length M, define f(B) to be the minimum possible integer k such that there exist a partition of B into k non-empty subarrays and an array X = [X_1, X_2, \ldots, X_k] satisfying the following:

  • XOR every element of the i-th subarray with X_i.
  • The resulting array (of length M) should be sorted.

You’re given an array A.
Answer Q queries on it: given L and R, find f([A_L, A_{L+1}, \ldots, A_R]).

EXPLANATION:

Let A[L, R] \oplus X denote the array [A_L\oplus X, A_{L+1}\oplus X, \ldots, A_R\oplus X].

Let’s ignore the queries to begin with, and just try to find f(A) for a given array A.

Since the X_i can be arbitrary integers, partitioning into subarrays and XOR-ing each subarray with something so that the final array is sorted is equivalent to simply ensuring that each individual subarray can be sorted with a single XOR operation.

Why?

One direction is obvious: if the entire array is sorted after the operation, surely each chosen subarray is also going to be sorted.

As for the other direction: suppose you partition A into k subarrays A[L_1, R_1], A[L_2, R_2], \ldots, A[L_k,R_k] and choose values X_1, X_2, \ldots, X_k such that each A[L_i, R_i] \oplus X_i is sorted.
Then,

  • Keep X_1 the same.
  • Choose some b_2 such that 2^{b_2} is larger than every element of A and every X_i, and convert X_2 \to X_2\oplus 2^{b_2}.
  • Choose some b_3 \gt b_2 (so that 2^{b_3} is larger than everything so far, including 2^{b_2}), and convert X_3 \to X_3\oplus 2^{b_3}.
    \vdots

Essentially, we can ‘separate’ the subarrays from each other by orders of magnitude, so that every element of an earlier subarray is smaller than every element of a later one.
Within a subarray, elements are still sorted (that's how the X_i were chosen, and the extra high bit is set in all of them, so their relative order doesn't change), so the overall array will also be sorted.
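As a small illustration of this construction, here is a sketch that applies the modified masks and returns the final array. The helper name `separate`, the 0-based piece starts, and the concrete choice of bit 19 + p for the p-th piece (0-indexed, p ≥ 1) are assumptions of this sketch; it also assumes all values and masks are below 2^{20} (the tester's input bound) and at most 44 pieces, so every extra bit fits into 64 bits.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>
using namespace std;

// Hypothetical helper illustrating the separation argument.
// starts[p] = first index (0-based) of the p-th piece; pieces are contiguous and cover a.
// x[p] = a mask that sorts the p-th piece on its own.
// Assumes every a[i] and x[p] is below 2^20 and there are at most 44 pieces,
// so the extra bits 2^20, 2^21, ..., 2^62 fit into 64 bits.
vector<uint64_t> separate(const vector<uint64_t>& a, const vector<int>& starts,
                          const vector<uint64_t>& x) {
    int k = starts.size();
    assert(k >= 1 && k <= 44 && (int)x.size() == k);
    vector<uint64_t> out(a.size());
    for (int p = 0; p < k; ++p) {
        int lo = starts[p];
        int hi = (p + 1 < k) ? starts[p + 1] : (int)a.size();
        uint64_t mask = x[p];
        // Pieces after the first also get bit 19 + p, which is above every bit
        // used by the values, the masks, and all earlier pieces.
        if (p > 0) mask ^= 1ULL << (19 + p);
        for (int i = lo; i < hi; ++i) out[i] = a[i] ^ mask;
    }
    return out;
}
```

For example, with a = [3, 1, 2, 0] split as [3, 1] | [2, 0] and both masks equal to 2, plain XOR-ing gives [1, 3, 0, 2], which is not sorted; the separated version gives [1, 3, 2^{20}, 2^{20}+2], which is.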

So, our goal is to simply partition A into the minimum possible number of subarrays, each of which can be sorted with a single XOR operation.
Let’s call a subarray A[L, R] “XOR-sortable”, if there exists an integer X such that A[L, R]\oplus X is sorted.

Finding the minimum number of XOR-sortable subarrays can be done greedily: repeatedly choose the longest XOR-sortable prefix of the remaining array.

Proof

Let k be the largest index such that A[1,k] is XOR-sortable.

Consider an optimal partition of A into XOR-sortable subarrays.
Suppose its first subarray is A[1, x]. Since A[1, x] is an XOR-sortable prefix, x \leq k.

  • If x = k, inductively solve for the remaining array.
  • Otherwise, x \lt k.
    Move index x+1 from the second subarray into the first one. The first subarray becomes A[1, x+1], which is a prefix of A[1, k] and hence XOR-sortable (any subarray of an XOR-sortable array is XOR-sortable, with the same X). The second subarray only loses an element, so it stays XOR-sortable too; if it was a singleton, it disappears entirely, giving fewer subarrays than the optimum, which is a contradiction.
    Repeat this until the first subarray is exactly A[1, k]; the answer never gets worse. Now inductively solve for the remaining array.

The only detail here is that we need to be able to recognize when an array is XOR-sortable.

How do I do that?

There are several different ways; here’s one of them.

Suppose we have just two numbers a and b.
For what values of X is a\oplus X \leq b\oplus X?

Answer

If a = b, then any X will do.

Otherwise, consider the highest bit where a and b differ.
No matter what X is chosen, a\oplus X and b\oplus X will still differ at this bit.
So, we need to ensure that a\oplus X has this bit unset, and b\oplus X has this bit set.
That uniquely fixes the value of this bit for X.
All the other bits of X can be anything at all!
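In code, this gives a constant-time rule per pair. A minimal sketch (the name `pairConstraint` is hypothetical; `__builtin_clz` is a GCC/Clang builtin, and a 32-bit `unsigned` is assumed):

```cpp
#include <cassert>
#include <utility>
using namespace std;

// For a != b, returns {bit, v}: a ^ X <= b ^ X holds exactly when bit `bit` of X equals v.
// `bit` is the highest bit where a and b differ, and v is a's value at that bit,
// so that a ^ X ends up with a 0 there and b ^ X with a 1.
// For a == b there is no constraint, signalled by {-1, -1}.
pair<int, int> pairConstraint(unsigned a, unsigned b) {
    if (a == b) return {-1, -1};
    int bit = 31 - __builtin_clz(a ^ b);  // argument is nonzero here
    return {bit, (int)((a >> bit) & 1u)};
}
```

For instance, pairConstraint(1, 2) returns {1, 0} and pairConstraint(2, 3) returns {0, 0}: both pairs require an unset bit, but at different positions.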

Applying this to our array: for any index i (1 \leq i \lt N) with A_i \neq A_{i+1}, let b_i be the highest bit where A_i and A_{i+1} differ. For A_i \oplus X \leq A_{i+1} \oplus X to hold, bit b_i of X must equal bit b_i of A_i; that is, it must be set if A_i \gt A_{i+1}, and unset if A_i \lt A_{i+1}. If A_i = A_{i+1}, the pair imposes no constraint at all.

Then, the subarray A[i, j] is XOR-sortable if and only if there are no ‘conflicts’ among the constraints of the pairs (p, p+1) with i \leq p \lt j.
That is, if some bit must be set to keep one adjacent pair sorted and unset to keep another adjacent pair sorted, no valid X can exist.
Otherwise, a valid X always exists: certain bits are forced to certain values, and everything else can be freely chosen.

This gives us a way to check in \mathcal{O}(\log\max A) time whether a subarray is XOR-sortable: for each bit, find out whether both a ‘set’ and an ‘unset’ constraint occur within it (which can be done in constant time for a fixed bit using prefix sums, for instance), and make sure that no bit has such a conflict.
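One possible prefix-sum implementation of this check, as a sketch: the struct name `SortableChecker` is hypothetical, indices are 0-based (pair q means (A_q, A_{q+1}), so the subarray a[l..r] involves pairs l, \ldots, r-1), and 0 \leq A_i \lt 2^{20} is assumed from the tester's input bounds. It keeps one prefix-count table per required value, indexed by position and bit.

```cpp
#include <array>
#include <cassert>
#include <vector>
using namespace std;

// Hypothetical helper: answers "is a[l..r] XOR-sortable?" in O(BITS) per query.
// Assumes 0 <= a[i] < 2^20 and a non-empty array.
struct SortableChecker {
    static constexpr int BITS = 20;
    // pre[v][p][b] = number of pairs (q, q+1) with q < p that force bit b of X to value v.
    vector<array<int, BITS>> pre[2];

    explicit SortableChecker(const vector<int>& a) {
        int n = a.size();
        assert(n >= 1);
        pre[0].assign(n, array<int, BITS>{});
        pre[1].assign(n, array<int, BITS>{});
        for (int q = 0; q + 1 < n; ++q) {
            pre[0][q + 1] = pre[0][q];
            pre[1][q + 1] = pre[1][q];
            if (a[q] == a[q + 1]) continue;  // equal neighbours impose nothing
            int bit = 31 - __builtin_clz((unsigned)(a[q] ^ a[q + 1]));
            int v = (a[q] >> bit) & 1;       // X must match a[q] at this bit
            pre[v][q + 1][bit]++;
        }
    }

    // 0-indexed, inclusive; the relevant pairs are q = l, ..., r-1.
    bool check(int l, int r) const {
        for (int b = 0; b < BITS; ++b) {
            bool needUnset = pre[0][r][b] - pre[0][l][b] > 0;
            bool needSet = pre[1][r][b] - pre[1][l][b] > 0;
            if (needUnset && needSet) return false;  // conflict on bit b
        }
        return true;
    }
};
```

This uses \mathcal{O}(20N) integers per table; the editorialist's code instead stores, for each bit, the sorted positions of each constraint type and binary-searches them, trading memory for an extra log factor.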


We are now ready to answer queries.
For a subarray A[L, R], the greedy process still works: simply keep cutting off XOR-sortable subarrays as long as you can. We just need to be able to simulate this fast enough.
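For reference, the direct simulation of this greedy on a single range fits in one left-to-right pass: keep, for each bit, the value forced on it within the current piece, and start a new piece whenever the next adjacent pair would force the opposite value. This is only a sketch (the function name `greedyPieces` is hypothetical, indices are 0-based, and 0 \leq A_i \lt 2^{20} is assumed from the tester's input bounds); it needs \mathcal{O}(R-L+1) steps per query, which is exactly what the precomputation is meant to avoid.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Direct simulation of the greedy on a[l..r] (0-indexed, inclusive):
// extend the current piece while no bit of X is forced both ways,
// and start a new piece as soon as a conflict appears.
// Assumes 0 <= a[i] < 2^20.
int greedyPieces(const vector<int>& a, int l, int r) {
    const int BITS = 20;
    vector<int> forced(BITS, -1);  // -1: bit still free in the current piece; else its forced value
    int pieces = 1;
    for (int q = l; q < r; ++q) {
        if (a[q] == a[q + 1]) continue;  // equal neighbours impose nothing
        int bit = 31 - __builtin_clz((unsigned)(a[q] ^ a[q + 1]));
        int v = (a[q] >> bit) & 1;       // bit `bit` of X must equal a[q]'s bit
        if (forced[bit] != -1 && forced[bit] != v) {
            // Conflict: a[q+1] starts a new piece, which so far is just {a[q+1]}.
            ++pieces;
            forced.assign(BITS, -1);
        } else {
            forced[bit] = v;
        }
    }
    return pieces;
}
```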

For each index i, let’s compute the value \text{right}_i, where A[i, \text{right}_i] is the longest possible XOR-sortable subarray starting at i.
This is fairly easy to do given our criterion for recognizing XOR-sortable arrays; for instance with binary search (or even two pointers, if you’d like to save a log factor).
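Here is a sketch of the two-pointer variant (the function name `computeRight` is hypothetical; indices are 0-based and 0 \leq A_i \lt 2^{20} is assumed). It relies on every subarray of an XOR-sortable array being XOR-sortable, so \text{right}_i never decreases as i grows; the window a[i..r] keeps, per bit and required value, how many of its adjacent pairs impose that constraint.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// right[i] = largest r such that a[i..r] is XOR-sortable (0-indexed, inclusive), via two pointers.
// Assumes 0 <= a[i] < 2^20.
vector<int> computeRight(const vector<int>& a) {
    const int BITS = 20;
    int n = a.size();
    // Pair (q, q+1) forces bit cb[q] of X to cv[q]; cb[q] == -1 means no constraint.
    vector<int> cb(n, -1), cv(n, 0);
    for (int q = 0; q + 1 < n; ++q) {
        if (a[q] == a[q + 1]) continue;
        cb[q] = 31 - __builtin_clz((unsigned)(a[q] ^ a[q + 1]));
        cv[q] = (a[q] >> cb[q]) & 1;
    }
    // cnt[b][v] = number of pairs inside the current window a[i..r] forcing bit b to v.
    vector<vector<int>> cnt(BITS, vector<int>(2, 0));
    vector<int> right(n);
    int r = 0;
    for (int i = 0; i < n; ++i) {
        if (r < i) r = i;  // previous window was a single element: no pairs are counted
        // Try to add pair (r, r+1); stop at the first conflict.
        while (r + 1 < n) {
            int b = cb[r];
            if (b != -1 && cnt[b][cv[r] ^ 1] > 0) break;
            if (b != -1) cnt[b][cv[r]]++;
            ++r;
        }
        right[i] = r;
        // Remove pair (i, i+1) before the left end moves to i+1.
        if (i < r && cb[i] != -1) cnt[cb[i]][cv[i]]--;
    }
    return right;
}
```

For a = [3, 1, 2, 0] this yields right = [1, 2, 3, 3] (0-based).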

Now, to answer a query (L, R), we’d like to do the following:

  • Let x = L initially.
  • While x \leq R, replace x with \text{right}_x + 1.
  • The answer is the number of times this replacement is done.

Such repeated jumping can be done quickly with the help of binary lifting.
Specifically, create a directed graph on N+1 vertices, with edges i \to (\text{right}_i + 1) for each 1 \leq i \leq N, and a self-loop at vertex N+1.
Then, precompute jumps of length 2^k from each vertex.
Finally, for a query (L, R), do the following:

  • Start at L.
  • Go over jumps in descending order of length; and if your jump takes you to an index \leq R, make it.

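This way, multiple jumps are compressed into a single one. If k jumps are made in total, the answer is k+1: the last greedy piece is the one whose jump would take us past R. Each query is answered in \mathcal{O}(\log N) time.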
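A sketch of the binary-lifting step with 0-based indices, so vertex n plays the role of N+1. The struct name `JumpTable` is hypothetical, and `right` is assumed to be precomputed as described above (0-based, so right[i] \leq n-1).

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Binary lifting over the "next piece" pointers.
// right[i] (0-indexed) = last index of the longest XOR-sortable subarray starting at i.
// Vertex n is a sink: jumping from it stays at n.
struct JumpTable {
    int n, LOG;
    vector<vector<int>> up;  // up[k][i] = position reached after 2^k greedy pieces from i

    explicit JumpTable(const vector<int>& right) : n(right.size()), LOG(1) {
        while ((1 << LOG) <= n) ++LOG;  // 2^LOG > n, enough for any number of jumps
        up.assign(LOG, vector<int>(n + 1, n));
        for (int i = 0; i < n; ++i) up[0][i] = right[i] + 1;
        for (int k = 1; k < LOG; ++k)
            for (int i = 0; i <= n; ++i) up[k][i] = up[k - 1][up[k - 1][i]];
    }

    // f(a[l..r]), 0-indexed inclusive.
    int query(int l, int r) const {
        int ans = 1;  // the last piece, whose jump goes past r
        for (int k = LOG - 1; k >= 0; --k) {
            if (up[k][l] <= r) {
                l = up[k][l];
                ans += 1 << k;
            }
        }
        return ans;
    }
};
```

With right = [1, 2, 3, 3] (from a = [3, 1, 2, 0]), query(0, 3) returns 2 and query(1, 2) returns 1, matching the step-by-step simulation.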

TIME COMPLEXITY:

\mathcal{O}(N\log\max A + (N+Q)\log N) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

using namespace std;
using namespace __gnu_pbds;

template<typename T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
typedef long long int ll;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;

#define fastio ios_base::sync_with_stdio(false); cin.tie(NULL)
#define pb push_back
#define endl '\n'
#define sz(a) (int)a.size()
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
#define conts continue
#define ceil2(x,y) ((x+y-1)/(y))
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
#define yes cout << "Yes" << endl
#define no cout << "No" << endl

#define rep(i,n) for(int i = 0; i < n; ++i)
#define rep1(i,n) for(int i = 1; i <= n; ++i)
#define rev(i,s,e) for(int i = s; i >= e; --i)
#define trav(i,a) for(auto &i : a)

template<typename T>
void amin(T &a, T b) {
    a = min(a,b);
}

template<typename T>
void amax(T &a, T b) {
    a = max(a,b);
}

#ifdef LOCAL
#include "debug.h"
#else
#define debug(x) 42
#endif

/*



*/

const int MOD = 1e9 + 7;
const int N = 2e5 + 5;
const int inf1 = int(1e9) + 5;
const ll inf2 = ll(1e18) + 5;
const int LOG = 18;

struct node{
    int f[2];
    pii prev;
    node(){
        memset(f,-1,sizeof f);
        prev = {-1,-1};
    }
};

vector<node> tr;
int cnt[20][2];
vector<pii> leave[N];

bool insert(int i, int x){
    // check if ok, then insert
    int u = 0;
    bool ok = true;

    rev(bit,19,0){
        int b = 0;
        if(x&(1<<bit)) b = 1;
        if(tr[u].f[b] == -1){
            tr[u].f[b] = sz(tr);
            tr.pb(node());
        }

        auto [j,p] = tr[u].prev;

        if(p != -1 and p != b){
            if(cnt[bit][p^1]){
                ok = false;
            }
        }

        u = tr[u].f[b];
    }

    if(!ok) return false;

    u = 0;

    rev(bit,19,0){
        int b = 0;
        if(x&(1<<bit)) b = 1;

        auto [j,p] = tr[u].prev;

        if(p != -1 and p != b){
            leave[j].pb({bit,p});
            cnt[bit][p]++;
        }

        tr[u].prev = {i,b};
        u = tr[u].f[b];
    }

    return true;
}

void erase(int i, int x){
    for(auto [b,p] : leave[i]){
        cnt[b][p]--;
    }

    int u = 0;

    rev(bit,19,0){
        int b = 0;
        if(x&(1<<bit)) b = 1;

        auto [j,p] = tr[u].prev;
        if(j == i){
            tr[u].prev = {-1,-1};
        }

        u = tr[u].f[b];
    }
}

void solve(int test_case)
{
    int n,q; cin >> n >> q;
    rep1(i,n){
        leave[i].clear();
    }
    tr.clear();
    tr.pb(node());

    vector<int> a(n+5);
    rep1(i,n) cin >> a[i];

    vector<int> nxt(n+5);
    int ptr = 1;

    rep1(i,n){
        while(ptr <= n and insert(ptr,a[ptr])){
            ptr++;
        }

        nxt[i] = ptr;
        erase(i,a[i]);
    }

    int up[n+5][LOG];
    rep(j,LOG) up[n+1][j] = n+1;
    rep1(i,n) up[i][0] = nxt[i];
    rep1(j,LOG-1){
        rep1(i,n){
            up[i][j] = up[up[i][j-1]][j-1];
        }
    }

    while(q--){
        int l,r; cin >> l >> r;
        int ans = 0;

        rev(j,LOG-1,0){
            if(up[l][j] <= r){
                l = up[l][j];
                ans += (1<<j);
            }
        }

        ans++;
        cout << ans << endl;
    }
}

int main()
{
    fastio;

    int t = 1;
    cin >> t;

    rep1(i, t) {
        solve(i);
    }

    return 0;
}
Tester's code (C++)
#ifndef LOCAL
#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx,avx2,sse,sse2,sse3,sse4,popcnt,fma")
#endif

#include <bits/stdc++.h>
using namespace std;


#ifdef LOCAL
#define dbg(x...) cerr << "[" << #x << "] = ["; _print(x)
#else
#define dbg(...)
#endif

void __print(int32_t x) {cerr << x;}
void __print(int64_t x) {cerr << x;}
void __print(unsigned x) {cerr << x;}
void __print(unsigned long x) {cerr << x;}
void __print(unsigned long long x) {cerr << x;}
void __print(float x) {cerr << x;}
void __print(double x) {cerr << x;}
void __print(long double x) {cerr << x;}
void __print(char x) {cerr << '\'' << x << '\'';}
void __print(const char *x) {cerr << '\"' << x << '\"';}
void __print(string x) {cerr << '\"' << x << '\"';}
void __print(bool x) {cerr << (x ? "true" : "false");}
template<typename T>void __print(complex<T> x) {cerr << '{'; __print(x.real()); cerr << ','; __print(x.imag()); cerr << '}';}

template<typename T>
void __print(const T &x);
template<typename T, typename V>
void __print(const pair<T, V> &x) {cerr << '{'; __print(x.first); cerr << ','; __print(x.second); cerr << '}';}
template<typename T>
void __print(const T &x) {int f = 0; cerr << '{'; for (auto it = x.begin() ; it != x.end() ; it++) cerr << (f++ ? "," : ""), __print(*it); cerr << "}";}
void _print() {cerr << "]\n";}
template <typename T, typename... V>
void _print(T t, V... v) {__print(t); if (sizeof...(v)) cerr << ", "; _print(v...);}


struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && !isspace(buffer[now])) {
            now++;
        }
        return now;
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

int32_t main() {
    ios_base::sync_with_stdio(0);   cin.tie(0);

    input_checker input;
    int T = input.readInt(1, (int)1e6); input.readEoln();
    int NN = 0, MM = 0;
    while(T-- > 0) {
        int N = input.readInt(1, (int)2e5); input.readSpace(); NN += N;
        int NQ = input.readInt(1, (int)2e5);    input.readEoln(); MM += NQ;
        vector<int> A = input.readInts(N, 0, (1 << 20) - 1);    input.readEoln();

        vector<int> B(N + 1); B[N - 1] = N;
        vector<array<int, 2>> P(20, {N, N});

        auto get = [&](int l, int r) -> int {
            return (l * 2 + r) - 1;
        };

        for(int i = N - 2; i >= 0 ; --i) {
            int Bit = 19;
            while(Bit >= 0 && ((A[i] >> Bit) == A[i + 1] >> Bit))
                --Bit;
            if (Bit >= 0) {
                P[Bit][get(A[i] >> Bit & 1, A[i + 1] >> Bit & 1)] = i + 1;
            }
            B[i] = N;
            for(int bit = 19 ; bit >= 0 ; --bit)
                B[i] = min(B[i], max(P[bit][0], P[bit][1]));
        }

        vector C(N + 1, vector<int>(20));
        for(int i = 0 ; i < N ; ++i)    C[i][0] = B[i];
        C[N][0] = N;
        for(int bit = 1 ; bit < 20 ; ++bit) {
            for(int i = N ; i >= 0 ; --i) {
                C[i][bit] = C[C[i][bit - 1]][bit - 1];
            }
        }

        for(int _ = 0 ; _ < NQ ; ++_) {
            int L = input.readInt(1, N);    input.readSpace();
            int R = input.readInt(1, N);    input.readEoln();
            --L, --R;
            assert(L <= R);
            int ans = 0;
            for(int bit = 19 ; bit >= 0 ; --bit) {
                if (C[L][bit] > R)
                    continue;
                ans |= 1 << bit;
                L = C[L][bit];
            }
            cout << (ans + 1) << '\n';
        }
    }
    assert(NN <= (int)2e5 && MM <=(int)2e5);

    input.readEof();

    return 0;
}
Editorialist's code (C++)
#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    int t; cin >> t;
    while (t--) {
        int n, q; cin >> n >> q;
        vector<int> a(n);
        for (int &x : a) cin >> x;
        vector<array<int, 18>> jump(n+1);
        vector<basic_string<int>> pos1(20), pos2(20);
        for (int i = 0; i+1 < n; ++i) {
            if (a[i] == a[i+1]) continue;
            int x = a[i] ^ a[i+1];
            for (int j = 19; j >= 0; --j) {
                if (x & (1 << j)) {
                    if (a[i] < a[i+1]) pos1[j].push_back(i);
                    else pos2[j].push_back(i);
                    break;
                }
            }
        }
        
        for (int i = 0; i < 18; ++i) jump[n][i] = n;
        for (int i = n-1; i >= 0; --i) {
            auto check = [&] (int R) {
                // Can [i...R] be sorted?
                if (i == R) return true;
                for (int j = 0; j < 20; ++j) {
                    auto it1 = lower_bound(begin(pos1[j]), end(pos1[j]), R);
                    if (it1 == begin(pos1[j])) continue;
                    auto it2 = lower_bound(begin(pos2[j]), end(pos2[j]), R);
                    if (it2 == begin(pos2[j])) continue;
                    --it1, --it2;
                    if (*it1 >= i and *it2 >= i) return false;
                }
                return true;
            };

            int lo = i, hi = n-1;
            while (lo < hi) {
                int mid = (lo + hi + 1)/2;
                if (check(mid)) lo = mid;
                else hi = mid-1;
            }
            jump[i][0] = lo+1;
            for (int j = 1; j < 18; ++j) jump[i][j] = jump[jump[i][j-1]][j-1];
        }

        while (q--) {
            int L, R; cin >> L >> R;
            --L, --R;
            int ans = 1;
            for (int i = 17; i >= 0; --i) {
                if (jump[L][i] <= R) {
                    ans += 1 << i;
                    L = jump[L][i];
                }
            }
            cout << ans << '\n';
        }
    }
}

In the problem XOR Partition Sort, can someone explain to me how to find the next right index in a simpler way? I am having a hard time understanding the editorial.

For example, take the array [1, 2, 3]. For index 0 (a[0] = 1), if the binary search checks [i, j] = [0, 2], then for bit 2^0 = 1 it looks like we have one pair each of the set and unset cases, namely (2, 3) and (1, 2) respectively. So by the editorial there is a conflict, but we know the array is already sorted, right?

To determine whether a\oplus X \leq b\oplus X, only the highest bit where a and b differ matters (because it’ll still differ after xor-ing with X).

So in your example, for (1, 2) the highest differing bit is 2^1, while for (2, 3) it’s 2^0. There’s no conflict here because they’re different bits entirely.

This gives the following algorithm:

  • For each i\lt N, find the highest bit where A_i and A_{i+1} differ (this is what the b_i in the editorial is).
    You’ll also need to store whether this bit needs to be set or not.
  • To check if [i, j] is XOR-sortable, the only thing that matters is whether there’s a conflict in its range or not.
    • For each bit from 0 to 19 (the values are below 2^{20}), check if there exists an index in [i, j-1] where this bit needs to be set, and an index where it needs to be unset.
      This can be done with binary search (if you store for each bit, a list of positions where it should be set/unset) or prefix sums (again, separate prefix sum array for each bit as to whether it should be set or not).
  • The check is either \mathcal{O}(20\log N) or \mathcal{O}(20) for a single subarray depending on the method you choose; now throw binary search on this to find \text{right}_i for each i.
    (Technically this can be as bad as \mathcal{O}(20N\log^2 N) overall if you use binary search too much, but all instances of it can easily be replaced with two-pointers to improve complexity.
    With two pointers you can even achieve \mathcal{O}(N) precalculation; it’s also possible to answer queries in constant time but I don’t really recommend it).

You most likely missed the fact that we’re looking at only the highest differing bit; since imo the rest of it is fairly straightforward once you have that.


Thanks for clarifying, I got it!