XSORT - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Authors: dannyboy1204
Testers: tabr, yash_daga
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

None

PROBLEM:

You have an array A of length N.
In one move, you can pick i and j such that A_i \neq A_j and set them both to A_i \oplus A_j.

Find a sequence of at most 20\cdot N moves that results in a sorted array.

EXPLANATION:

The XOR operation is a bit of a red herring: there exists a solution that is completely independent of the replacement function.

A simpler case

First, let’s deal with a simpler version of the problem, where N is always a power of two; say N = 2^K.

In this case, we can do more than sort the array: we can make all of its elements equal using at most K\cdot 2^{K-1} operations:

  • First, recursively make the first and second halves of the array equal, since they’re also power-of-two sized.
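  • Then, perform the operations (1, N/2+1), (2, N/2+2), (3, N/2+3), \ldots, (N/2, N), skipping any pair whose two elements are already equal (a move requires A_i \neq A_j, and such a pair needs no change anyway).
    Since the first half is all equal and the second half is all equal, every pair (i, N/2+i) holds the same two values, so every operation produces the same result and the entire array is now equal.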

It’s easy to see that this uses K\cdot 2^{K-1} moves: there are K levels of the recursion, and each one uses N/2 = 2^{K-1} operations in total.

Solving for any N

When N isn’t a power of 2, we can in fact use the above method as a subroutine.

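First, perform at most \lfloor N/2 \rfloor operations to make A a palindrome: (1, N), (2, N-1), (3, N-2), \ldots, (\lfloor N/2 \rfloor, N+1-\lfloor N/2 \rfloor), again skipping pairs that are already equal. After handling the pair (i, N+1-i), we have A_i = A_{N+1-i}.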

Now, let x be the largest power of 2 that’s \leq N.

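Let’s first make the first x elements equal using the earlier method, and then make the last x elements equal.
Since x > N/2, these two windows together cover the whole array. The second step makes the last x elements all equal to some value q, while the first N-x elements (outside the second window) keep the common value p from the first step. So the array now looks like p, \ldots, p, q, \ldots, q: it is either sorted or reverse-sorted, with at most two distinct elements.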

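If it’s sorted, we’re done.
If it’s reverse-sorted, note that A is a palindrome after the first step, and our method for powers of 2 is symmetric with respect to reversing the array. Applying every operation mirrored (index i \to N+1-i) therefore produces exactly the mirror image of the previous result.
So, we can instead make the last x elements equal, then make the first x elements equal. This yields q, \ldots, q, p, \ldots, p with p > q, which is guaranteed to be sorted.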

This way, we use the K\cdot 2^{K-1} method twice (with 2^K = x \leq N), plus at most N/2 further operations.
N \leq 10^5 gives us K \leq 16, so the number of operations we use is bounded by 2\cdot 16\cdot N/2 + N/2 = 16.5N \leq 17N, which is well within the allowed 20\cdot N.

TIME COMPLEXITY

\mathcal{O}(N\log N) per test case.

CODE:

Setter's code (C++)
#include <bits/stdc++.h>
#define ll long long
#define int long long
#define fi first
#define se second
#define mat vector<vector<ll>> 
using namespace std;
// Debug printer: writes every argument followed by a single space, then ends the line.
// Base case: nothing left to print, so terminate the line.
void db() {
    cout << '\n';
}
// Recursive case: print the first argument and a space, then handle the rest.
template <typename Head, typename... Tail>
void db(Head head, Tail... tail) {
    cout << head << ' ';
    db(tail...);
}
// `file` sets up I/O for main(): when compiled with -DCloud it redirects stdin/stdout
// to input.txt/output.txt; otherwise it only unsyncs and unties the C++ streams.
#ifdef Cloud
#define file freopen("input.txt", "r", stdin), freopen("output.txt", "w", stdout)
#else
#define file ios::sync_with_stdio(false); cin.tie(0)
#endif
// Clock-seeded RNG; not referenced anywhere in this solution (template leftover).
auto SEED = chrono::steady_clock::now().time_since_epoch().count();
mt19937 rng(SEED);
// N is the capacity of a[] (presumably n <= 1e5 — the tester's validator checks
// n in [1, 1e5]). mod and inf are unused here. Note `int` is #defined to long long.
const int N = 1e5 + 1, mod = 998244353, inf = 1ll << 60;
// Current test case's array (0-indexed); moves are simulated on it in place.
int a[N];
// 0-indexed (i, j) pairs of the moves made by work(); cleared at the end of solve().
vector<pair<int, int>> v;
// Makes a[l..r] (inclusive, 0-indexed) all equal, appending each move made to `v`.
// Intended for ranges whose length r - l + 1 is a power of two, so that both
// halves below have the same length.
// Each half is equalised recursively first; then element l+off of the left half is
// paired with element m+1+off of the right half. Pairs that already hold equal
// values are skipped, since a move needs two different values.
void work(int l, int r){
    if (l == r) return;                 // a single element is trivially "all equal"
    const int m = (l + r) >> 1;         // last index of the left half
    work(l, m);
    work(m + 1, r);
    const int rightLen = r - m;
    for (int off = 0; off < rightLen; ++off) {
        const int lhs = l + off;
        const int rhs = m + 1 + off;
        if (a[lhs] != a[rhs]) {
            const auto merged = a[lhs] ^ a[rhs];
            a[lhs] = merged;
            a[rhs] = merged;
            v.push_back({lhs, rhs});
        }
    }
}
// Returns log2(n) * n / 2: for n = 2^K this is K * 2^(K-1), the most moves work()
// can make on a range of length n (one move per pair on each of the K levels).
// Not called anywhere in this solution.
// n <= 1 needs no moves; returning 0 early also avoids __lg(0), which is undefined.
int f(int n){
    if (n <= 1) return 0;
    return __lg(n) * n / 2;
}
void solve(){
    int n;
    cin >> n;
    vector<pair<int, int>> ans;
    for (int i = 0; i < n; i++) cin >> a[i];
    int k = 1;
    while (k * 2 < n) k *= 2;
    int mid = k * 2 - n;
    for (int i = 0; i < n - 1 - i; i++){
        if (a[i] == a[n - 1 - i]) continue;
        ans.push_back({i + 1, n - i});
        a[i] = a[n - 1 - i] = a[i] ^ a[n - 1 - i];
    }
    work(0, k - 1);
    work(n - k, n - 1);
    for (auto i : v){
        if (a[0] > a[n - 1]){
            ans.push_back({n - i.fi, n - i.se});
        }
        else{
        ans.push_back({i.fi + 1, i.se + 1});
        }
    }
    v.clear();
    cout << ans.size() << '\n';
    for (auto i : ans) cout << i.fi << ' ' << i.se << '\n';
}
// Entry point: applies the `file` I/O setup, reads the number of test cases,
// and runs solve() once per test case.
signed main(){
    file;
    int testCount = 0;
    cin >> testCount;
    while (testCount-- != 0) {
        solve();
    }
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

// Strict input validator. The constructor reads all of stdin into `buffer`;
// the read* methods then consume it token by token, asserting the exact
// whitespace layout (single ' ' separators, '\n' line ends, nothing after the
// end) and that every value lies in the requested range.
struct input_checker {
    string buffer;  // the entire raw input
    int pos;        // index of the next unread character in `buffer`

    // Character sets intended for readString's `pattern` argument.
    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    // Reads stdin byte by byte until cin.get() reports end of input (-1).
    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    // Returns the index of the first ' ' or '\n' at or after pos, or buffer.size().
    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }

    // Consumes and returns the characters up to (not including) the next delimiter.
    // The result is empty if pos already sits on a delimiter.
    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    // Reads a token whose length is in [minl, maxl]; when `pattern` is non-empty,
    // every character of the token must occur in it.
    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    // Reads a base-10 int in [minv, maxv]; the whole token must be that number.
    // stoi throws if the token is not a number or does not fit in an int.
    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        string token = readOne();
        // An empty token means a stray delimiter (e.g. a doubled space).
        assert(!token.empty());
        size_t used = 0;
        int res = stoi(token, &used);
        // stoi stops at the first invalid character, so without this check a
        // malformed token such as "12abc" would be accepted as 12.
        assert(used == token.size());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    // Reads a base-10 long long in [minv, maxv]; the whole token must be that number.
    // stoll throws if the token is not a number or does not fit in a long long.
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        string token = readOne();
        // An empty token means a stray delimiter (e.g. a doubled space).
        assert(!token.empty());
        size_t used = 0;
        long long res = stoll(token, &used);
        // Reject trailing garbage that stoll would otherwise silently ignore.
        assert(used == token.size());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    // Consumes exactly one ' '.
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    // Consumes exactly one '\n'.
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    // Asserts that the whole input has been consumed.
    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

// Validates the input (T test cases, each a line with n and a line with n values)
// and, for each test case, prints the number of moves followed by 1-indexed move
// pairs that sort the array.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    input_checker in;
    int tt = in.readInt(1, 1e4);
    in.readEoln();
    int sn = 0;  // sum of n over all test cases, checked against 1e5 at the end
    while (tt--) {
        int n = in.readInt(1, 1e5);
        in.readEoln();
        // Without this, sn stays 0 and the final sum-of-n assertion checks nothing.
        sn += n;
        vector<int> a(n);
        for (int i = 0; i < n; i++) {
            a[i] = in.readInt(0, 1e9);
            (i == n - 1 ? in.readEoln() : in.readSpace());
        }
        // k = largest power of two <= n, so 2k > n and the windows [0, k) and
        // [n-k, n) together cover the whole array.
        int k = 1;
        while (2 * k <= n) {
            k *= 2;
        }
        vector<pair<int, int>> ans;  // 0-indexed moves; printed 1-indexed
        // Applies the move (x, y) to a[] and records it, unless a[x] == a[y]
        // (a move requires two different values).
        auto F = [&](int x, int y) {
            if (a[x] == a[y]) {
                return;
            }
            ans.emplace_back(x, y);
            a[x] = a[y] = a[x] ^ a[y];
        };
        // Makes a[l, r) all equal; assumes r - l is a power of two. Pairing i with
        // i + (m - l) makes the right half an exact copy of the left half, so the
        // two recursive calls perform identical moves and end on the same value.
        function<void(int, int)> D = [&](int l, int r) {
            if (l + 1 >= r) {
                return;
            }
            int m = (l + r) >> 1;
            for (int i = l; i < m; i++) {
                F(i, i + m - l);
            }
            D(l, m);
            D(m, r);
        };
        // Step 1: make a[] a palindrome by pairing i with n-1-i.
        for (int i = 0; i < n - 1 - i; i++) {
            F(i, n - 1 - i);
        }
        // Moves from this index on belong to step 2 and may need mirroring.
        const int windowStart = (int) ans.size();
        // Step 2: equalise the first k, then the last k elements -> p..p q..q.
        D(0, k);
        D(n - k, n);
        // If p > q, mirror (i -> n-1-i) the step-2 moves: since a[] was a palindrome
        // after step 1, the mirrored moves produce the reversed, ascending array.
        if (!is_sorted(a.begin(), a.end())) {
            assert(is_sorted(a.rbegin(), a.rend()));
            for (int i = windowStart; i < (int) ans.size(); i++) {
                ans[i].first = n - 1 - ans[i].first;
                ans[i].second = n - 1 - ans[i].second;
            }
        }
        cout << ans.size() << '\n';
        for (auto [x, y] : ans) {
            cout << x + 1 << " " << y + 1 << '\n';
        }
    }
    assert(sn <= 1e5);
    in.readEof();
    return 0;
}