XORSRT - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: iceknight1093
Tester: raysh07
Editorialist: iceknight1093

DIFFICULTY:

Easy

PREREQUISITES:

None

PROBLEM:

For a permutation P, define f(P) to be the minimum non-negative integer K that satisfies the following property:

  • It’s possible to sort P in ascending order by repeatedly swapping adjacent elements with a bitwise XOR that’s \le K.

You’re given N and K.
Find any permutation of [1, N] for which f(P) = K.

EXPLANATION:

First, let’s figure out how to compute f(P) for a fixed permutation P.
More specifically, let’s fix a value of K and see whether it works or not.

Consider some pair of indices i and j with i \lt j.

  • If P_i \lt P_j, we never need to swap these values to reach ascending order.
    So, their bitwise XOR does not matter.
  • If P_i \gt P_j, we always need to eventually swap these two values, since we can only swap adjacent values.
    So, we surely must have (P_i \oplus P_j) \le K.

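So, K must be at least as large as (P_i \oplus P_j) for each inversion (i, j) in P.
This is also the only condition that matters: bubble sort only ever swaps adjacent pairs that form an inversion, so if every inversion has a XOR that’s \le K, bubble sort is a valid sequence of swaps.
This tells us that the minimum K that works is the maximum of (P_i \oplus P_j) across all inversions (i, j), or 0 if P is already sorted.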
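
As a sanity check on this characterization, here is a minimal brute-force sketch of f(P). The function name max_inversion_xor is an illustrative choice, not something from the problem, and the O(N^2) loop is only meant for checking small cases by hand.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Brute-force f(P): the largest XOR over all inversions of P (0 if P is sorted).
// Hypothetical helper for small tests only; it runs in O(N^2).
int max_inversion_xor(const std::vector<int>& p) {
    assert(!p.empty());
    int best = 0;
    for (std::size_t i = 0; i < p.size(); ++i) {
        for (std::size_t j = i + 1; j < p.size(); ++j) {
            if (p[i] > p[j]) {  // (i, j) is an inversion
                best = std::max(best, p[i] ^ p[j]);
            }
        }
    }
    return best;
}
```

For instance, on [2, 3, 4, 1] the inversions are (2, 1), (3, 1) and (4, 1), with XORs 3, 2 and 5, so the function returns 5.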


Now that we understand f(P), we need to think about constructing a permutation that achieves f(P) = K.
Such a permutation P must satisfy the following conditions:

  • For each (i, j) such that i \lt j and P_i \gt P_j, (P_i \oplus P_j) \le K must hold.
    That is, no inversion must have a XOR exceeding K.
  • There must exist at least one pair (i, j) such that i \lt j and P_i \gt P_j, with (P_i \oplus P_j) = K.
    That is, at least one inversion must have a XOR of K.

Note that the second condition is necessary: if the first is satisfied but not the second, then we’ll have f(P) \lt K instead.

Now, we need some pair (X, Y) such that X\oplus Y = K and 1 \le X, Y \le N.

Suppose we find one such pair - this is easy to do by just iterating X from 1 to N, computing Y = X\oplus K, and checking if Y lies in [1, N].
Without loss of generality, assume X \lt Y.
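
The linear search described above can be sketched as follows. The name find_pair and the use of std::optional (C++17) are assumptions made for illustration; the sketch also assumes K \ge 1, since K = 0 would give X = Y.

```cpp
#include <algorithm>
#include <cassert>
#include <optional>
#include <utility>

// Returns some (X, Y) with X < Y, X ^ Y == k and 1 <= X, Y <= n, if one exists.
// Hypothetical helper; assumes k >= 1 so that X != Y.
std::optional<std::pair<int, int>> find_pair(int n, int k) {
    assert(n >= 1 && k >= 1);
    for (int x = 1; x <= n; ++x) {
        int y = x ^ k;
        if (y >= 1 && y <= n) {
            // Order the pair so that the smaller value comes first.
            return std::make_pair(std::min(x, y), std::max(x, y));
        }
    }
    return std::nullopt;  // no two values in [1, n] have XOR equal to k
}
```

With N = 3 and K = 1 this finds (2, 3), while N = 2 and K = 1 yields no pair at all.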

Let’s look at the range of integers [X+1, Y-1].
It can be proved that for each of these integers Z, either (Z\oplus X) \le K or (Z\oplus Y) \le K must be true.
In fact, there’s an even stronger condition: there exists a “breakpoint” B in the range [X+1, Y-1] such that:

  • For all X+1 \le Z \lt B, (Z\oplus X) \le K will hold.
  • For all B \le Z \le Y-1, (Z\oplus Y) \le K will hold.
Proof

Let 2^b be the highest set bit in K.
Observe that since X\oplus Y = K and X \lt Y, X must not have 2^b set, while Y must have it set.
Further, all bits \gt b must be either unset in both X and Y, or set in both of them.
This gives an even stronger condition: all elements in [X, Y] have the same set bits among those \gt b, i.e. they have the same prefix till b+1.

Choose B to be the nearest element \gt X that has bit 2^b set.
Note that we surely have B \le Y, since Y itself is an element \gt X that has bit 2^b set.

Now,

  • For each Z \in [X+1, B), Z\oplus X \lt K.
    To see why, let’s analyze each bit.
    • Bits \gt b: Z and X have the same prefix till b+1, so all such bits are equal in Z and X, hence will cancel out in the XOR.
    • Bit b: X and Z both have it unset.
    • Bits \lt b: it doesn’t matter what happens here, since these bits contribute at most 2^b - 1 to the XOR, whereas K \ge 2^b.
  • For each Z \in [B, Y-1], Z\oplus Y \lt K.
    The reasoning is similar: analyze bit by bit, just that now the bit b will be set in both Z and Y which is why it cancels out.

This proves our claim.

The above proof also tells exactly how to find one such breakpoint: simply choose B to be the nearest element \gt X that has bit 2^b set, where 2^b is the highest set bit in K.
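
A minimal sketch of computing B with bit operations, plus a brute-force check of the claim proved above. The names highest_bit, breakpoint and breakpoint_claim_holds are hypothetical, and the sketch assumes K \ge 1 and that bit b is unset in X (which holds whenever X \lt X \oplus K).

```cpp
#include <cassert>

// Index b of the highest set bit of k (requires k > 0).
int highest_bit(int k) {
    assert(k > 0);
    int b = 0;
    while ((k >> (b + 1)) != 0) ++b;
    return b;
}

// Smallest integer greater than x that has bit b set, where 2^b is the
// highest set bit of k. Assumes bit b is unset in x.
int breakpoint(int x, int k) {
    int b = highest_bit(k);
    assert(((x >> b) & 1) == 0);
    // Keep the bits above b, set bit b, and clear everything below it.
    return ((x >> b) | 1) << b;
}

// Brute-force check of the claim for one pair (x, y = x ^ k) with x < y.
bool breakpoint_claim_holds(int x, int k) {
    int y = x ^ k;
    int B = breakpoint(x, k);
    if (B > y) return false;
    for (int z = x + 1; z < B; ++z) {
        if ((z ^ x) >= k) return false;
    }
    for (int z = B; z < y; ++z) {
        if ((z ^ y) >= k) return false;
    }
    return true;
}
```

For example, X = 1 and K = 6 give Y = 7 and B = 4: the values 2 and 3 pair with X, while 4, 5 and 6 pair with Y.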

Once B is known, we have the following construction:

  • Place all the elements from X+1 to B-1 in ascending order.
  • Then, place Y.
  • Then, place X.
  • Finally, place all the elements from B to Y-1 in ascending order.

That is, the array will look like

[X+1, X+2, \ldots, B-1, Y, X, B, \ldots, Y-1]

It’s easy to see that this can be sorted using a maximum bitwise XOR of exactly K: swap X with Y which needs K; then repeatedly swap X left till it reaches the beginning (each swap needs \lt K) and repeatedly swap Y right till it reaches the end (each swap needs \lt K again).

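That leaves the ranges [1, X-1] and [Y+1, N] to be taken care of.
They can just be arranged in ascending order at the beginning/end, respectively: every element of [1, X-1] is smaller than everything after it and every element of [Y+1, N] is larger than everything before it, so no new inversions are created.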
That is, once X, Y, and B are known, our construction is simply

[1, 2, \ldots, X-1, X+1, X+2, \ldots, B-1, Y, X, B, \ldots, Y-1, Y+1, Y+2, \ldots, N]
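
The layout above can be sketched directly. This is an illustrative helper (the name build_permutation and its parameter order are assumptions) that takes X, Y and B as already known and only assembles the array.

```cpp
#include <cassert>
#include <vector>

// Builds [1..x-1, x+1..b-1, y, x, b..y-1, y+1..n].
// Hypothetical helper: assumes 1 <= x < b <= y <= n have already been found.
std::vector<int> build_permutation(int n, int x, int y, int b) {
    assert(1 <= x && x < b && b <= y && y <= n);
    std::vector<int> p;
    p.reserve(n);
    for (int v = 1; v < x; ++v) p.push_back(v);       // [1, X-1]
    for (int v = x + 1; v < b; ++v) p.push_back(v);   // [X+1, B-1]
    p.push_back(y);                                   // Y
    p.push_back(x);                                   // X
    for (int v = b; v < y; ++v) p.push_back(v);       // [B, Y-1]
    for (int v = y + 1; v <= n; ++v) p.push_back(v);  // [Y+1, N]
    return p;
}
```

With N = 7, K = 6, X = 1, Y = 7 and B = 4 this produces [2, 3, 7, 1, 4, 5, 6], whose inversions have XORs 3, 2, 6, 3, 2 and 1, so the maximum is exactly 6.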

As noted earlier, finding a pair (X, Y) can be done in linear time by just trying all X, and then finding B is easy once X is known.
If we find any such pair (X, Y), we have a valid construction.
If no valid pair (X, Y) exists, then clearly no solution can exist since it’s impossible to even achieve a XOR of exactly K using elements in [1, N], so we output -1 for that case.

There is one edge case to be careful of here: if K = 0 you’ll have X = Y which might mess up your construction, depending on implementation.
However, for K = 0 there’s a trivial solution of just taking the array [1, 2, \ldots, N], so that can be handled separately.

TIME COMPLEXITY:

\mathcal{O}(N) per testcase.

CODE:

Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

void Solve() 
{
    int n, k; cin >> n >> k;
    
    if (k == 0){
        for (int i = 1; i <= n; i++){
            cout << i << " ";
        }
        cout << "\n";
        return;
    }
    
    for (int x = 1; x <= n; x++){
        int y = x ^ k;
        if (1 <= y && y <= n){
            vector <int> l, r;
            for (int z = x + 1; z < y; z++){
                if ((z ^ x) <= k){
                    l.push_back(z);
                } else if ((z ^ y) <= k){
                    r.push_back(z);
                } else {
                    assert(false);
                }
            }
            
            vector <int> ans;
            for (int i = 1; i < x; i++) ans.push_back(i);
            for (int z : l) ans.push_back(z);
            ans.push_back(y);
            ans.push_back(x);
            for (int z : r) ans.push_back(z);
            for (int i = y + 1; i <= n; i++) ans.push_back(i);
            
            for (int i = 1; i <= n; i++){
                cout << ans[i - 1] << " ";
            }
            cout << "\n";
            return;
        }
    }
    
    cout << -1 << "\n";
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}
Editorialist's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    int t; cin >> t;
    while (t--) {
        int n, k; cin >> n >> k;

        vector ans(n+1, 0);
        for (int i = 1; i <= n; ++i) {
            ans[i] = i;
        }
        
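        if (k & (k - 1)) { // Not power of 2
            // p = highest set bit of k; use the pair X = p^k, Y = p (so X < Y and B = Y)
            int p = 1;
            while (2*p < k) p *= 2;

            if (p > n or (p^k) > n) ans = {0, -1};
            else {
                // Shift X+1, ..., Y one step left, then place X right after Y
                for (int i = (p^k); i < p; ++i) {
                    ans[i] = i+1;
                }
                ans[p] = p^k;
            }
        }
        else if (k > 0) { // Power of 2
            if (k == 1) {
                // Smallest pair with XOR 1 is (2, 3), which needs n >= 3
                if (n <= 2) ans = {0, -1};
                else {
                    swap(ans[2], ans[3]);
                }
            }
            else if (k+1 > n) ans = {0, -1}; // Smallest pair is X = 1, Y = k+1
            else {
                // X = 1, Y = k+1, B = k: build [2, ..., k-1, k+1, 1, k]
                for (int i = 1; i < k; ++i) ans[i] = ans[i+1];
                ans[k-1] = k+1;
                ans[k] = 1;
                ans[k+1] = k;
            }
        }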

        for (int i = 1; i < size(ans); ++i) {
            cout << ans[i] << ' ';
        }
        cout << '\n';
    }
}

wow!
Such a beautiful explanation!
