SWAP_SHIFT - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: blobo2_blobo2
Testers: iceknight1093, tabr
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

None

PROBLEM:

You’re given a permutation A.
In one move, you can either swap any two elements of A, or left-rotate the array by one step.
However, you can make at most 2 swaps.

What’s the lexicographically smallest array you can attain?

EXPLANATION:

Since we want the lexicographically smallest array, ideally we’d want our array to look like [1, 2, 3, \ldots].
We have unlimited rotations, so we can always make A_1 = 1.
Let’s first rotate the array to make this the case.
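
As a small illustration, the rotation step might look like the following in Python. The helper name `rotate_to_one` is made up for this sketch, and the input is assumed to be a list holding a permutation of 1..N.

```python
def rotate_to_one(a):
    """Return a left-rotated copy of the permutation `a` that starts with 1."""
    k = a.index(1)        # number of left rotations needed
    return a[k:] + a[:k]  # slicing builds the rotated copy in O(N)

print(rotate_to_one([3, 1, 4, 2]))  # [1, 4, 2, 3]
```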

Now there are two cases: we either swap the 1 somewhere else, or we don’t.
Let’s consider both cases separately and take the best answer among them.

Don't swap 1

As noted at the start, we want our array to look like [1, 2, 3, \ldots]; that is, A_i = i.

This case can be solved greedily.
For each i from 1 to N:

  • If A_i = i, nothing needs to be done.
  • Otherwise, we know A_i \gt i for sure, since every value smaller than i is already in place.
    • Find the position of i in A, and perform one swap to bring it to position i.
  • If you can’t perform any more swaps, stop immediately.

This is clearly the best we can do.
Each time we don’t swap, it takes \mathcal{O}(1) time.
Each time we do swap, it takes \mathcal{O}(N) time. However, this step is done at most twice, so the complexity is still \mathcal{O}(N).

In particular, notice that we can always make A_i = i for at least the first three elements of A.
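
The greedy described in this case can be sketched as follows. This is only a minimal illustration: it assumes a 0-indexed Python list that already starts with 1, and the name `greedy_fix` and its `swaps` parameter are hypothetical, not taken from the solutions in the CODE section.

```python
def greedy_fix(a, swaps=2):
    """Assuming a[0] == 1, fix the leftmost mismatches with at most `swaps` swaps."""
    a = a[:]                      # work on a copy
    for i in range(len(a)):
        if swaps == 0:            # no budget left: stop immediately
            break
        if a[i] != i + 1:         # 0-indexed, so position i should hold i + 1
            j = a.index(i + 1)    # O(N) search, done at most `swaps` times
            a[i], a[j] = a[j], a[i]
            swaps -= 1
    return a

print(greedy_fix([1, 4, 2, 5, 3]))  # [1, 2, 3, 5, 4]
```

With two swaps available, the first three positions always end up as 1, 2, 3, matching the observation above.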

Swap 1

Suppose we do swap 1 somewhere else.
We then need to decide two things: where to swap it to, and what the second swap should be.

In fact, if we fix the first swap, the second one can be made greedily just as in the previous case: rotate the array so that A_1 = 1, then find the first position such that A_i \neq i and swap i to this position.

However, each such attempt takes \mathcal{O}(N) time, so trying every possible first swap would cost \mathcal{O}(N^2) overall. We need to be a bit more clever about which first swaps to try.
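
Evaluating one fixed first swap could look roughly like this. The function name `try_first_swap` and the 0-indexed target `x` are assumptions made for this sketch.

```python
def try_first_swap(a, x):
    """Swap the value 1 to index x, rotate it to the front, then spend
    at most one more swap greedily. Returns the resulting list."""
    b = a[:]
    p = b.index(1)
    b[p], b[x] = b[x], b[p]     # first swap: 1 now sits at index x
    b = b[x:] + b[:x]           # rotate so that 1 is first again
    for i in range(1, len(b)):  # second swap: fix the first mismatch
        if b[i] != i + 1:
            j = b.index(i + 1)
            b[i], b[j] = b[j], b[i]
            break
    return b

print(try_first_swap([1, 5, 2, 3, 4], 1))     # [1, 2, 3, 4, 5]
print(try_first_swap([1, 5, 2, 3, 6, 4], 1))  # [1, 2, 3, 4, 6, 5]
```

In the first example the single swap of 1 already sorts the array cyclically, which the "don't swap 1" greedy cannot achieve (it ends at [1, 2, 3, 5, 4]).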

First, recall that the “don’t swap 1” case always gives us an array starting with [1, 2, 3] at the very least.

So, in this case it’s pointless to consider doing anything that doesn’t at least match that.

In particular, suppose we swap 1 to position i (with indices taken cyclically). If A_{i+1} \geq 4 and A_{i+2} \geq 4 after this swap, then making it is pointless: we’ll be forced to use our second swap to achieve [1, 2], but the third element will never be 3, so the result is strictly worse than the previous case.

This means there are at most 6 valid positions for the first swap: the two positions preceding 2, the two positions preceding 3, and potentially swapping 1 with 2 or 1 with 3.

Each of these 6 cases can then be solved in \mathcal{O}(N), giving us 6 candidate arrays for the answer.
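
One way to collect the candidate destinations for 1 is sketched below, assuming N \geq 3 and 0-indexed positions that wrap around. `candidate_positions` is a hypothetical helper; it returns at most 6 indices and drops the current position of 1, since swapping 1 with itself is not a real swap.

```python
def candidate_positions(a):
    """Indices worth trying as the destination of 1 (requires len(a) >= 3)."""
    n = len(a)
    p2, p3 = a.index(2), a.index(3)
    cands = {
        p2, p3,                      # swap 1 with 2, or 1 with 3
        (p2 - 1) % n, (p2 - 2) % n,  # the two positions preceding 2
        (p3 - 1) % n, (p3 - 2) % n,  # the two positions preceding 3
    }
    cands.discard(a.index(1))        # 1 is already there: not a swap
    return sorted(cands)

print(candidate_positions([1, 5, 2, 3, 4]))     # [1, 2, 3]
print(candidate_positions([4, 1, 6, 2, 5, 3]))  # [2, 3, 4, 5]
```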

Together with the array from the “don’t swap 1” case, this gives us 7 candidate arrays; the answer is then the minimum of them all.
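
For cross-checking an implementation on small inputs, a brute-force reference can be handy. The sketch below is not part of the intended solution. It relies on the observation that a swap followed by a rotation equals a rotation followed by a swap at shifted indices, so it is enough to apply up to two swaps first and rotate once at the end. The name `brute_force` is made up here.

```python
from itertools import combinations

def rotate_to_one(a):
    k = a.index(1)
    return a[k:] + a[:k]

def brute_force(a):
    """Smallest array reachable with at most two swaps and any rotations.
    Exponentially slower than the editorial's approach; for tiny N only."""
    pairs = list(combinations(range(len(a)), 2))
    reach = [a[:]]                      # arrays reachable so far
    for _ in range(2):                  # add one more swap per round
        nxt = []
        for b in reach:
            for i, j in pairs:
                c = b[:]
                c[i], c[j] = c[j], c[i]
                nxt.append(c)
        reach += nxt
    # The best rotation of any array is the one that starts with 1.
    return min(rotate_to_one(b) for b in reach)

print(brute_force([1, 5, 2, 3, 4]))  # [1, 2, 3, 4, 5]
print(brute_force([1, 4, 2, 5, 3]))  # [1, 2, 3, 5, 4]
```

Comparing such a reference against the fast solution over all permutations of small N is a cheap way to catch a missed candidate.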

TIME COMPLEXITY:

\mathcal{O}(N) per testcase.

CODE:

Setter's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define all(v) v.begin(),v.end()
#define endl "\n"
#define gen(arr,n,nxt)  generate(arr,arr+n,nxt)
const int mod = 1e9+7;
int nxt(){int x;cin>>x;return x;}
int n;
void shift(vector<int>&arr,int x){
	vector<int> a(n); // avoid a non-standard variable-length array
	for(int i=0;i<n;i++){
		int idx = i - x;
		idx+=n;
		idx%=n;
		a[idx] = arr[i];
	}
	for(int i=0;i<n;i++)arr[i] = a[i];
}
void swapSmall(vector<int>&a){
    int x=-1;
	for(int i=0;i<n;i++){
		if(a[i]!=i+1){
			x = i+1;
			break;
		}
	}
    if(x == -1)return;
	for(int i=0;i<n;i++){
		if(a[i] == x){
			swap(a[i],a[x-1]);
			break;
		}
	}
}
signed main() {
    int qu=nxt();
    while(qu--){
        n=nxt();
        vector<int>arr(n),a(n),b(n),c(n),d(n);
        for(int i=0;i<n;i++)arr[i]=nxt(),a[i] = b[i] = c[i] = arr[i];
        if(n == 1){
            cout<<"1\n";
            continue;
        }
        if(n == 2){
            cout<<"1 2\n";
            continue;
        }
        int one=-1,two=-1,three=-1;
        for(int i=0;i<n;i++){
            if(arr[i] == 1)
                one = i;
            else if(arr[i] == 2)
                two = i;
            else if(arr[i] == 3)
                three = i;
        }
        shift(a,one);
        shift(b,two - 1);
        shift(c,three -  2);
        swapSmall(a);
        swapSmall(a);
        swapSmall(b);
        swapSmall(b);
        swapSmall(c);
        swapSmall(c);
        vector<vector<int>>v;
        v.push_back(a);
        v.push_back(b);
        v.push_back(c);
        int pre[] = {one,two,three};
        sort(pre,pre+3);
        do{
            for(int i=0;i<n;i++)d[i] = arr[i];
            int pre2[] = {one,two,three};
            int swaps = 2;
            for(int i=0;i<3;i++){
                if(pre[i] != pre2[i]){
                    swaps--;
                    for(int j=0;j<3;j++){
                        if(pre2[j] == pre[i]){
                            swap(pre2[j],pre2[i]);
                            swap(d[pre2[j]],d[pre2[i]]);
                            break;
                        }
                    }
                }
            }
            for(int i=0;i<n;i++)a[i] = b[i] = c[i] = d[i];
            int One=-1,Two=-1,Three=-1;
            for(int i=0;i<n;i++){
                if(d[i] == 1)
                    One = i;
                else if(d[i] == 2)
                    Two = i;
                else if(d[i] == 3)
                    Three = i;
            }
            shift(a,One);
            shift(b,Two - 1);
            shift(c,Three -  2);
            if(swaps == 1){
                swapSmall(a);
                swapSmall(b);
                swapSmall(c);
            }
            else if(swaps == 2){
                swapSmall(a);
                swapSmall(a);
                swapSmall(b);
                swapSmall(b);
                swapSmall(c);
                swapSmall(c);
            }
            v.push_back(a);
            v.push_back(b);
            v.push_back(c);
        }while(next_permutation(pre,pre+3));
        sort(all(v));
        for(auto x:v[0])cout<<x<<' ';
        cout<<endl;
    }
	return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int minl, int maxl, const string& pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    input_checker in;
    int tt = in.readInt(1, 5e5);
    in.readEoln();
    int sn = 0;
    while (tt--) {
        int n = in.readInt(1, 1e5);
        in.readEoln();
        sn += n;
        vector<int> a(n);
        for (int i = 0; i < n; i++) {
            a[i] = in.readInt(1, n);
            (i == n - 1 ? in.readEoln() : in.readSpace());
            a[i]--;
        }
        {
            auto b = a;
            sort(b.begin(), b.end());
            for (int i = 0; i < n; i++) {
                assert(b[i] == i);
            }
        }
        rotate(a.begin(), min_element(a.begin(), a.end()), a.end());
        vector<int> ans = a;
        auto Eval = [&](vector<int> b, int c) {
            rotate(b.begin(), min_element(b.begin(), b.end()), b.end());
            for (int i = 1; i < n; i++) {
                if (b[i] != i) {
                    c--;
                    for (int j = i + 1; j < n; j++) {
                        if (b[j] == i) {
                            swap(b[i], b[j]);
                            break;
                        }
                    }
                    if (c == 0) {
                        break;
                    }
                }
            }
            ans = min(ans, b);
        };
        Eval(a, 2);
        for (int i = 1; i < n; i++) {
            swap(a[0], a[i]);
            if (a[(i + 1) % n] <= 2 || a[(i + 2) % n] <= 2) {
                Eval(a, 1);
            }
            swap(a[0], a[i]);
        }
        for (int i = 0; i < n; i++) {
            cout << ans[i] + 1 << " \n"[i == n - 1];
        }
    }
    assert(sn <= 5e5);
    in.readEof();
    return 0;
}
Editorialist's code (Python)
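def solve(a, swaps):
	# Rotate so that 1 comes first, then greedily fix the leftmost
	# mismatches while swaps remain.
	onepos = a.index(1)
	a = a[onepos:] + a[:onepos]
	for i in range(1, len(a)):
		if swaps == 0 or a[i] == i+1: continue
		pos = a.index(i+1)  # O(N) search, but done at most `swaps` times
		swaps -= 1
		a[i], a[pos] = a[pos], a[i]
	return a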

for _ in range(int(input())):
	n = int(input())
	a = list(map(int, input().split()))
	ans = solve(a[:], 2)
	onepos = a.index(1)
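	# Candidate destinations for 1: the position of 2 or 3 itself,
	# or a position shortly before it (indices wrap around).
	starts = []
	for i in range(n):
		if a[i] == 2:
			starts.append(i)
			starts.append((i-1)%n)
			starts.append((i-2)%n)
		if a[i] == 3:
			starts.append((i-2)%n)
			starts.append(i)
	# Try each first swap, finish with one greedy swap, then undo it.
	for x in starts:
		a[onepos], a[x] = a[x], a[onepos]
		ans = min(ans, solve(a[:], 1))
		a[onepos], a[x] = a[x], a[onepos]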
	print(*ans)