SORTXOR - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: notsoloud
Testers: iceknight1093, rivalq
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Cycle decomposition of permutations

PROBLEM:

You have an array A of length N.
In one move, you can pick an index i and a non-empty subsequence S of A, and set A_i to the bitwise XOR of the elements in S.

Find a sequence of at most \left\lceil \frac{3N}{2}\right\rceil moves that sorts A.

EXPLANATION:

First, let’s try to solve a more restricted version of this task: assume that all N elements are distinct (i.e., A is a permutation).

A standard technique when dealing with permutations is to look at their cycle decompositions.
If you’re unfamiliar with cycle decompositions of permutations, I recommend reading through the blog linked in the prerequisites.

Let’s look at the cycle decomposition of A; in particular, let’s look at a specific cycle.

Consider a cycle C = [c_1, c_2, \ldots, c_k, c_1].
This means we need to place value A_{c_i} at position c_{i+1}.
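As a rough illustration, the cycles of a permutation of 1..N can be extracted as follows. This is a minimal Python sketch rather than the setter's code; the function name `cycles_of` and the 0-indexed positions are assumptions made for the example.

```python
def cycles_of(a):
    """Return the cycles of length >= 2 of a permutation a of 1..n.

    Positions are 0-indexed. Each cycle lists indices c_1, c_2, ..., c_k
    such that the value at c_i belongs at position c_{i+1}.
    """
    n = len(a)
    seen = [False] * n
    cycles = []
    for start in range(n):
        cyc = []
        i = start
        while not seen[i]:
            seen[i] = True
            cyc.append(i)
            i = a[i] - 1  # the value a[i] belongs at position a[i] - 1
        if len(cyc) >= 2:
            cycles.append(cyc)
    return cycles
```

For instance, `cycles_of([2, 1, 4, 3])` finds two cycles of length 2, while an already sorted array yields no cycles at all.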

Let’s attempt to sort this.
If you try working out a strategy for small cycles on paper (try sorting the arrays [2, 1], [3, 1, 2], [4, 1, 2, 3], for instance), you may find the following:

  • Set A_{c_1} to the XOR of the whole cycle
  • Set A_{c_2} to the XOR of the whole cycle
  • Set A_{c_3} to the XOR of the whole cycle
    \vdots
  • Set A_{c_k} to the XOR of the whole cycle
  • Set A_{c_1} to the XOR of the whole cycle

In every move, S consists of the current values at positions c_1, c_2, \ldots, c_k.

Why does this work?

Let’s look at how the values change after each move.
Note that our values start out at [A_{c_1}, A_{c_2}, \ldots, A_{c_k}] and we want to reach [A_{c_k}, A_{c_1}, \ldots, A_{c_{k-2}}, A_{c_{k-1}}].

Let X = A_{c_1} \oplus A_{c_2} \oplus \ldots \oplus A_{c_k}.
Then,

  • After the first move, the array is [X, A_{c_2}, A_{c_3}, \ldots, A_{c_k}]
  • After the second move, the array is [X, A_{c_1}, A_{c_3}, \ldots, A_{c_k}]
    • This is because X\oplus(A_{c_2}\oplus\ldots \oplus A_{c_k}) = (A_{c_1}\oplus A_{c_2}\oplus\ldots\oplus A_{c_k}) \oplus (A_{c_2}\oplus\ldots\oplus A_{c_k}) = A_{c_1}.
  • After the third move, the array is [X, A_{c_1}, A_{c_2}, A_{c_4}, \ldots, A_{c_k}]
    \vdots
  • After the k-th move, the array becomes [X, A_{c_1}, A_{c_2}, \ldots, A_{c_{k-1}}]
  • After the final move, the array becomes [A_{c_k}, A_{c_1}, A_{c_2}, \ldots, A_{c_{k-1}}] which is exactly what we want.
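The trace above can be replayed with a few lines of Python. This is only a sketch under stated assumptions: `rotate_cycle` is a hypothetical helper name, indices are 0-indexed, and the cycle is passed in explicitly in the order c_1, c_2, ..., c_k.

```python
from functools import reduce
from operator import xor


def rotate_cycle(a, cyc):
    """Apply the k+1 moves to the list a in place and return the targets used.

    cyc lists the indices c_1, ..., c_k of one cycle. Every move overwrites
    one position with the XOR of the current values on the whole cycle.
    """
    targets = []
    for target in cyc + [cyc[0]]:  # c_1, c_2, ..., c_k, then c_1 again
        a[target] = reduce(xor, (a[j] for j in cyc))
        targets.append(target)
    return targets


a = [3, 1, 2]
rotate_cycle(a, [0, 2, 1])  # value 3 goes to index 2, value 2 to index 1, value 1 to index 0
print(a)  # [1, 2, 3]
```

Note that the argument never uses distinctness of the values, only the fact that XOR-ing the same value twice cancels it.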

This takes exactly k+1 moves, where k is the length of the cycle.
Note that a cycle of length 1 corresponds to an element that’s already in its place, so it requires no operations.

Now, for each cycle of length k we require k+1 operations.
Summing this across several cycles, it can be seen that the total sum is simply the sum of the lengths of the cycles, plus the number of cycles.

The cycles decompose the permutation, so the sum of their lengths is at most N.
Further, there are at most \left\lfloor \frac{N}{2} \right\rfloor cycles of length \geq 2, again because the cycles decompose the permutation.

This bounds the number of operations we make by N + \left\lfloor \frac{N}{2} \right\rfloor, which is within the limit.
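For small N, the bound can also be checked exhaustively. The sketch below is an illustration, not part of the intended solution; `move_count` and `worst_case` are hypothetical helper names, and permutations are taken over 0..n-1 for convenience.

```python
from itertools import permutations


def move_count(p):
    """Moves used by the strategy on a permutation p of 0..n-1.

    Every cycle of length k >= 2 costs k + 1 moves; fixed points cost nothing.
    """
    n = len(p)
    seen = [False] * n
    total = 0
    for start in range(n):
        k = 0
        i = start
        while not seen[i]:
            seen[i] = True
            k += 1
            i = p[i]
        if k >= 2:
            total += k + 1
    return total


def worst_case(n):
    """Largest move count over all permutations of length n."""
    return max(move_count(p) for p in permutations(range(n)))


for n in range(1, 8):
    # (3 * n + 1) // 2 equals ceil(3n / 2) for integer n
    assert worst_case(n) <= (3 * n + 1) // 2
```

The worst case is reached when the permutation splits into as many cycles of length 2 as possible.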

In case A is not a permutation, simply convert it into one!
For each i, let B_i be the pair (A_i, i).
B then consists of distinct elements, and sorting B lexicographically is the same as sorting A, so just do that using the above algorithm.
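Putting the pieces together, one possible end-to-end sketch for arrays with repeated values looks like this. It is a minimal illustration under stated assumptions, not the reference solution: `sort_with_xor_moves` is a hypothetical name, indices are 0-indexed, and the function returns the moves instead of printing them in the required output format.

```python
def sort_with_xor_moves(a):
    """Return (moves, final_array) for the cycle strategy.

    Each move is a pair (target_index, cycle_indices). Ties between equal
    values are broken by index, i.e. we sort the pairs (A_i, i).
    """
    a = list(a)
    n = len(a)
    # order[r] = original index of the r-th smallest pair (A_i, i)
    order = sorted(range(n), key=lambda i: (a[i], i))
    # dest[i] = position where the element currently at index i must end up
    dest = [0] * n
    for rank, i in enumerate(order):
        dest[i] = rank

    seen = [False] * n
    moves = []
    for start in range(n):
        cyc = []
        i = start
        while not seen[i]:
            seen[i] = True
            cyc.append(i)
            i = dest[i]
        if len(cyc) < 2:
            continue  # already in place
        for target in cyc + [cyc[0]]:
            x = 0
            for j in cyc:
                x ^= a[j]
            a[target] = x
            moves.append((target, cyc))
    return moves, a


moves, result = sort_with_xor_moves([2, 1, 4, 3, 2, 2, 2, 2])
print(len(moves), result)  # 10 [1, 2, 2, 2, 2, 2, 3, 4]
```

In this example the pairs form one cycle of length 2 and one of length 6, so 3 + 7 = 10 moves are used, within the limit of 12 for N = 8.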

TIME COMPLEXITY:

\mathcal{O}(N^2) per test case.

CODE:

Setter's code (C++)
#include <iostream> 
#include <string> 
#include <set> 
#include <map> 
#include <stack> 
#include <queue> 
#include <vector> 
#include <utility> 
#include <iomanip> 
#include <sstream> 
#include <bitset> 
#include <cstdlib> 
#include <iterator> 
#include <algorithm> 
#include <cstdio> 
#include <cctype> 
#include <cmath> 
#include <math.h> 
#include <ctime> 
#include <cstring> 
#include <unordered_set> 
#include <unordered_map> 
#include <cassert>
#define int long long int
#define pb push_back
#define mp make_pair
#define mod 1000000007
#define vl vector <ll>
#define all(c) (c).begin(),(c).end()
using namespace std;

const int N=500023;
bool vis[N];
vector <int> adj[N];
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);

            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }

            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }

            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}

int sumN = 0;

void solve()
{
    int n = readInt(1, 1000, '\n');
    sumN += n;
    vector<int> a(n);
    set<int> s;
    for(int i = 0; i < n-1; i++) {
        a[i] = readInt(1, n, ' ');
        s.insert(a[i]);
    }
    a[n-1] = readInt(1, n, '\n');
    s.insert(a[n-1]);
    //assert(s.size() == n);
    cerr << "Input read successfully" <<endl;
    vector<int> b = a;
    sort(all(b));

    map<int, vector<int>> indices;
    unordered_map<int, bool> vis;
    for(int i = 0; i < n; i++) {
        indices[b[i]].push_back(i);
    }

    vector<vector<int> > cycles;
    for(int i = 0; i<n; i++){
        if(indices[a[i]].empty()) continue;
        vector<int> cycle;
        int j = i;
        while(!vis[j]){
            //cerr << j << " ";
            vis[j] = true;
            cycle.pb(j);
            int temp = j;
            j = indices[a[j]].back();
            indices[a[temp]].pop_back();
        }
        //cerr << endl;
        if(cycle.size() > 1)
            cycles.pb(cycle);
    }

    int ans = 0;
    vector<pair<int, int>> index;
    vector<vector<int>> operations;
    for(auto cycle : cycles) {
        ans += cycle.size()+1;
        for(int i = 0; i < cycle.size(); i++) {
            index.pb({cycle[i], cycle.size()}); 
            operations.push_back(cycle);
            //cerr << cycle[i] << " " << cycle.size() << endl;
        }
        index.pb({cycle[0], cycle.size()}); 
        operations.push_back(cycle);
    }

    //cerr << ans << endl;

    for(int i = 0; i<ans; i++){
        int toUpdate = index[i].first;
        int updatedVal = 0;
        for(int j = 0; j < operations[i].size(); j++) {
            updatedVal ^= a[operations[i][j]];
        }
        a[toUpdate] = updatedVal;
    }

    for(int i = 0; i < n; i++) {
        //cerr << a[i] << " ";
        assert(b[i] == a[i]);
    }
    //cerr << endl;

    cout<<ans<<'\n';
    for(int i = 0; i < index.size(); i++) {
        cout<<index[i].first+1<<" "<<index[i].second<<'\n';
        for(int j = 0; j < operations[i].size(); j++) {
            cout<<operations[i][j]+1<<" ";
        }
        cout<<'\n';
    }
    cerr << "Operations printed" <<endl;
}

int32_t main()
{
    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    #endif
    ios_base::sync_with_stdio(false);
    cin.tie(NULL),cout.tie(NULL);
    int T=readInt(1,2000,'\n');
    cerr << "#Testcases read successfully" <<endl;
    while(T--){
        solve();
        //cout<<'\n';
    }
    cerr << sumN << '\n';
    assert(sumN <= 6000);
    cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
}

Tester's code (C++)
// Jai Shree Ram  
  
#include<bits/stdc++.h>
using namespace std;

#define rep(i,a,n)     for(int i=a;i<n;i++)
#define ll             long long
#define int            long long
#define pb             push_back
#define all(v)         v.begin(),v.end()
#define endl           "\n"
#define x              first
#define y              second
#define gcd(a,b)       __gcd(a,b)
#define mem1(a)        memset(a,-1,sizeof(a))
#define mem0(a)        memset(a,0,sizeof(a))
#define sz(a)          (int)a.size()
#define pii            pair<int,int>
#define hell           1000000007
#define elasped_time   1.0 * clock() / CLOCKS_PER_SEC



template<typename T1,typename T2>istream& operator>>(istream& in,pair<T1,T2> &a){in>>a.x>>a.y;return in;}
template<typename T1,typename T2>ostream& operator<<(ostream& out,pair<T1,T2> a){out<<a.x<<" "<<a.y;return out;}
template<typename T,typename T1>T maxs(T &a,T1 b){if(b>a)a=b;return a;}
template<typename T,typename T1>T mins(T &a,T1 b){if(b<a)a=b;return a;}

// -------------------- Input Checker Start --------------------
 
long long readInt(long long l, long long r, char endd)
{
    long long x = 0;
    int cnt = 0, fi = -1;
    bool is_neg = false;
    while(true)
    {
        char g = getchar();
        if(g == '-')
        {
            assert(fi == -1);
            is_neg = true;
            continue;
        }
        if('0' <= g && g <= '9')
        {
            x *= 10;
            x += g - '0';
            if(cnt == 0)
                fi = g - '0';
            cnt++;
            assert(fi != 0 || cnt == 1);
            assert(fi != 0 || is_neg == false);
            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
        }
        else if(g == endd)
        {
            if(is_neg)
                x = -x;
            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(false);
            }
            return x;
        }
        else
        {
            assert(false);
        }
    }
}
 
string readString(int l, int r, char endd)
{
    string ret = "";
    int cnt = 0;
    while(true)
    {
        char g = getchar();
        assert(g != -1);
        if(g == endd)
            break;
        cnt++;
        ret += g;
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}
 
long long readIntSp(long long l, long long r) { return readInt(l, r, ' '); }
long long readIntLn(long long l, long long r) { return readInt(l, r, '\n'); }
string readStringLn(int l, int r) { return readString(l, r, '\n'); }
string readStringSp(int l, int r) { return readString(l, r, ' '); }
void readEOF() { assert(getchar() == EOF); }
 
vector<int> readVectorInt(int n, long long l, long long r)
{
    vector<int> a(n);
    for(int i = 0; i < n - 1; i++)
        a[i] = readIntSp(l, r);
    a[n - 1] = readIntLn(l, r);
    return a;
}
 
// -------------------- Input Checker End --------------------

int solve(){
 		int n = readIntLn(2, 3000);
 		static int sum_n = 0;
 		sum_n += n;
 		assert(sum_n <= 6000);
 		vector<int> a = readVectorInt(n, 1, n);


 		auto b = a;

 		sort(all(b));

 		map<int, vector<int>> mp;

 		for(int i = 0; i < n; i++){
 			mp[b[i]].push_back(i);
 		} 

 		vector<pair<int, vector<int>>> op;

 		vector<int> vis(n + 1);

 		for(int i = 0; i < n; i++){
 			if(mp[a[i]].empty()) continue;
 			vector<int> cyc;
 			int j = i;
 			while(!vis[j]){
 				vis[j] = 1;
 				cyc.push_back(j);
 				int tmp = mp[a[j]].back();
 				mp[a[j]].pop_back();
 				j = tmp;
 			}
 			if(cyc.size() <= 1) continue; 
 			for(auto j: cyc){
 				op.push_back({j, cyc});
 			}
 			op.push_back({cyc[0], cyc});

 		}
 		cout << op.size() << endl;
 		assert(op.size() * 2 <= 3*n);
 		for(auto [i,j] :op){
 			cout << i + 1 << " " << j.size() << endl;
 			for(auto k: j) cout << k + 1 << " ";
 			cout << endl; 
 		}


 return 0;
}
signed main(){
    ios_base::sync_with_stdio(0);cin.tie(0);cout.tie(0);
    //freopen("input.txt", "r", stdin);
    //freopen("output.txt", "w", stdout);
    #ifdef SIEVE
    sieve();
    #endif
    #ifdef NCR
    init();
    #endif
    int t = readIntLn(1, 3000);
    while(t--){
        solve();
    }
    return 0;
}

Editorialist's code (Python)
for _ in range(int(input())):
	n = int(input())
	a = []
	for i, x in enumerate(list(map(int, input().split()))):
		a.append((x, i))
	b = sorted(a)
	ops = []
	mark = [0]*n
	for i in range(n):
		if mark[i] == 1: continue
		cycle = []
		x = i
		while not mark[x]:
			cycle.append(x+1)
			mark[x] = 1
			x = b[x][1]
		if len(cycle) == 1: continue
		cycle = cycle[::-1]
		for x in cycle: ops.append((x, cycle))
		ops.append((cycle[0], cycle))
	print(len(ops))
	for x, ind in ops:
		print(x, len(ind))
		print(*ind)

Can this question be solved using DSU? Anyone with a DSU solution is welcome!

For input
6
6 5 1 2 3 4
8
2 1 4 3 2 2 2 2
What will the cycles be? How will you do the XOR?

Minimum swaps to sort array : Minimum swaps to sort array - gfg, Minimum swaps to sort array - leetcode
Nice question : Lucky Permutation
XOR trick to swap in cycle : Swap three numbers using xor

@iceknight1093
Can you please explain these lines in more detail, in the context of the setter’s solution?
I tried a lot but still have many doubts.

What exactly is your confusion?

The fact that all the (A_i, i) pairs are distinct should be obvious, since their second parts are all distinct.
This means you can sort them lexicographically and then treat the smallest element as 1, second smallest as 2, and so on. Obviously, this gives you a permutation and now you apply the solution for permutations.