INVRET - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Authors: Udit Sanghi, Jatin Yadav
Tester: Harris Leung
Editorialist: Nishank Suresh

DIFFICULTY:

3869

PREREQUISITES:

Mergesort, ad-hoc optimizations

PROBLEM:

You have a hidden permutation F of \{0, 1, 2, \ldots, N-1\}. You would like to compute the number of inversions of this permutation. To do so, you may append values to F with the help of several different functions:

  • An arbitrary number
  • F_{F_i} for some i
  • F_i + F_j
  • F_i - F_j
  • F_i \times F_j
  • F_i \oplus F_j
  • [F_i \lt F_j]

Devise a strategy that computes the number of inversions of F regardless of what F is, using at most 10^6 queries.

EXPLANATION:

The final strategy uses several optimizations built upon ideas from earlier subtasks, so let’s go through them one by one.

Subtask 1 (N = 7000, at most two elements are not in position)

This subtask doesn’t really play into the final solution too much, so let’s get it out of the way first.

There are only two possible cases here: either F is already sorted, or F is obtained from a sorted array by making one swap. Let’s look at them separately.

One swap

Suppose u and v are swapped, where u \lt v. Let k = v - u.

The observation here is that the number of inversions equals 2\cdot (k-1) + 1, so finding k is enough. To find k, scan through the array and compare all adjacent elements, recording [F_{i+1} \lt F_i] for each i. Now consider the prefix xor of these N-1 comparison results:

  • If u+1 \lt v, exactly k-1 elements of the prefix xor will be 1, so adding them all up is enough
  • If u+1 = v, this is no longer true: some suffix of the prefix xors will be 1. However, this case can be differentiated from the previous one by looking at the last element of the prefix xor array, which will always be 1 here and always be 0 in the earlier case.

No swaps

Once we’ve done our adjacent comparisons and prefix xors for the first case, this case can be identified by noting that the prefix xors must end with zero and the (arithmetic) sum of the prefix xor array will also be 0.
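The case analysis above can be checked offline with a short sketch. This is a plain C++ simulation working directly on the array, not the sequence of operations a submission would print; the function name is made up for illustration, and the descent indicator [F_{i+1} \lt F_i] is assumed as the adjacent comparison.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Inversions of a permutation of {0..n-1} that is sorted except for at most
// one swapped pair, using only adjacent comparisons and their prefix xor.
long long inversions_at_most_one_swap(const std::vector<int>& f) {
    assert(f.size() >= 2);
    int prefix = 0;      // running xor of the descent indicators
    long long ones = 0;  // arithmetic sum of the prefix xor array
    for (std::size_t i = 0; i + 1 < f.size(); ++i) {
        const int descent = f[i + 1] < f[i] ? 1 : 0;  // one comparison per adjacent pair
        prefix ^= descent;
        ones += prefix;
    }
    if (prefix == 1) return 1;  // adjacent swap: the prefix xor ends with 1
    if (ones == 0) return 0;    // already sorted: every prefix xor is 0
    return 2 * ones + 1;        // ones == k - 1, so this is 2(k-1) + 1
}
```

For example, swapping 1 and 4 in the sorted array of length 5 gives k = 3 and the sketch returns 5.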

Subtasks 2 (N = 1000) and 3 (N = 1400)

Subtask 2 is rather trivial — simply compute the results of all \binom{N}{2} comparisons, and print their sum.

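The solution from subtask 2 doesn’t quite fit within the limit for subtask 3. To optimize it, note that rather than comparing every pair (i, j) with 1 \leq i \lt j \leq N, it suffices to compare only those pairs with:

  • 1 \leq i \lt j \leq \frac{N}{2}
  • \frac{N}{2} \lt i \lt j \leq N

This information is sufficient to recover the total number of inversions. One way to see this: because F is a permutation of \{0, 1, \ldots, N-1\}, exactly F_i elements are smaller than F_i overall, so for an element in the first half, subtracting the number of smaller elements found within the first half leaves the number of smaller elements in the second half. This roughly halves the number of comparisons, which now fits within the limit.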
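A minimal offline sketch of this idea follows, assuming the cross-half inversions are recovered from the fact that F is a permutation (so F_i equals the number of elements smaller than F_i). It uses 0-indexed positions, the function name is illustrative, and it works directly on the array rather than emitting queries.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Counts inversions of a permutation of {0..n-1} while comparing only pairs
// that lie in the same half of the array.
long long inversions_by_halves(const std::vector<int>& f) {
    assert(!f.empty());
    const std::size_t n = f.size(), h = n / 2;
    long long inv = 0;
    std::vector<long long> smaller_in_left(h, 0);  // per left element: smaller left elements
    for (std::size_t i = 0; i < h; ++i) {
        for (std::size_t j = i + 1; j < h; ++j) {
            const int less = f[i] < f[j] ? 1 : 0;  // one comparison inside the left half
            inv += 1 - less;                       // inversion inside the left half
            smaller_in_left[j] += less;
            smaller_in_left[i] += 1 - less;
        }
    }
    for (std::size_t i = h; i < n; ++i) {
        for (std::size_t j = i + 1; j < n; ++j) {
            inv += f[j] < f[i] ? 1 : 0;            // inversion inside the right half
        }
    }
    // f[i] elements are smaller than f[i] overall; those not in the left half
    // sit in the right half and each forms an inversion with position i.
    for (std::size_t i = 0; i < h; ++i) {
        inv += f[i] - smaller_in_left[i];
    }
    return inv;
}
```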

Subtasks 4 (N = 4000) and 5 (N = 5000)

We can no longer hope to stay within the limit doing a quadratic number of comparisons. Instead, we look towards a method of computing inversions that utilizes far fewer comparisons — merge sort.

It is well-known that the problem of computing inversions in an array can be done using mergesort, by modifying the merge function slightly to compute the number of inversions created while merging the two halves. Let’s use this to our advantage.

Suppose we run merge sort on F. Its two halves will be sorted recursively, so we only need to worry about counting inversions during the merging step. The standard merging step looks as follows:

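// L and R are the two (already sorted) halves being merged
i = 0, j = 0, invs = inversions(L) + inversions(R)
M = []
while i < |L| and j < |R|:
    if L[i] < R[j]:
        invs += j    // the j elements of R already placed are smaller than L[i]
        append L[i] to M
        i += 1
    else:
        append R[j] to M
        j += 1
// every element still left in L is larger than all of R
invs += (|L| - i) * |R|
append the rest of L to M if i < |L|
append the rest of R to M if j < |R|
return (M, invs)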
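For reference, here is a runnable C++ sketch of the standard merge-sort inversion count, written as an ordinary program on an in-memory array (the function name is illustrative). Elements of L left over after the main loop each account for |R| further inversions, which the sketch adds explicitly.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Returns a sorted copy of `a` together with the number of inversions in `a`.
std::pair<std::vector<int>, long long> sort_and_count(const std::vector<int>& a) {
    if (a.size() <= 1) return {a, 0LL};
    const auto mid = static_cast<std::ptrdiff_t>(a.size() / 2);
    const auto left = sort_and_count(std::vector<int>(a.begin(), a.begin() + mid));
    const auto right = sort_and_count(std::vector<int>(a.begin() + mid, a.end()));
    const std::vector<int>& L = left.first;
    const std::vector<int>& R = right.first;
    assert(L.size() + R.size() == a.size());

    std::vector<int> m;
    m.reserve(a.size());
    long long inv = left.second + right.second;
    std::size_t i = 0, j = 0;
    while (i < L.size() && j < R.size()) {
        if (L[i] < R[j]) {
            inv += static_cast<long long>(j);  // j elements of R precede L[i] in sorted order
            m.push_back(L[i++]);
        } else {
            m.push_back(R[j++]);
        }
    }
    while (i < L.size()) {                     // leftovers of L exceed all of R
        inv += static_cast<long long>(R.size());
        m.push_back(L[i++]);
    }
    while (j < R.size()) m.push_back(R[j++]);
    return {std::move(m), inv};
}
```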

To modify it for our purposes, we do the following:

  • First, sort L and R recursively
  • Then, arrange L and R in a line. To each of them, append an infinity value (for example, N). This is to ensure we don’t have to deal with any special cases once the main merging loop is done
  • Now, do the following:
i = start(L), j = start(R), invs = 0
sorted = []
repeat |L| + |R| times:
    invs += compare(L[i], R[j]) * (j - start(R))
    append (L[i]*compare(L[i], R[j]) + R[j]*compare(R[j], L[i])) to sorted
    i += compare(L[i], R[j])
    j += compare(R[j], L[i])
// Now, invs is the number of inversions created during the merge and sorted holds the sorted list
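The loop above can be simulated in ordinary C++ to check that it needs no special cases. In this sketch the function name and the `inf` parameter are illustrative, the ternary stands in for a single compare query, and 1 - take_left is used in place of the second comparison (the two agree as long as the values being compared differ, which holds here because the values are distinct and at most one pointer rests on a sentinel during the loop).

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Merges two sorted lists of distinct values (all smaller than `inf`) using
// only comparison results and arithmetic, and counts the pairs (l, r) with
// l in L, r in R and l > r.
std::pair<std::vector<int>, long long> merge_branch_free(std::vector<int> L,
                                                         std::vector<int> R,
                                                         int inf) {
    assert(inf > 0);
    const std::size_t total = L.size() + R.size();
    L.push_back(inf);  // sentinels: an exhausted list never wins a comparison
    R.push_back(inf);
    std::size_t i = 0, j = 0;
    long long invs = 0;
    std::vector<int> sorted;
    sorted.reserve(total);
    for (std::size_t step = 0; step < total; ++step) {
        const int take_left = L[i] < R[j] ? 1 : 0;  // compare(L[i], R[j])
        const int take_right = 1 - take_left;       // compare(R[j], L[i])
        invs += static_cast<long long>(take_left) * static_cast<long long>(j);
        sorted.push_back(L[i] * take_left + R[j] * take_right);
        i += static_cast<std::size_t>(take_left);
        j += static_cast<std::size_t>(take_right);
    }
    return {std::move(sorted), invs};
}
```

Merging {1, 4, 5} with {0, 2, 3} and inf = 6 yields the sorted list 0..5 and 7 inversions between the two halves.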

With proper implementation, this should pass subtask 5.

Subtasks 6 (N = 6000), 7 (N = 6500), 8 (N = 7000)

The final three subtasks are, for the most part, an exercise in micro-optimization. The base of the solution is still the merge sort from the previous subtasks; all that remains is to cut out some operations.

A variety of different optimizations can be applied here. For example,

  • When the size of the list is small (say, \leq 6), sort it using a different method, for example bubble sort. This uses a quadratic number of comparisons, but still saves a decent number of operations because merge sort requires a fairly large overhead in maintaining pointers and such.
  • It’s possible to play around with the order of operations in the merging phase a little to ensure that the step where L and R are first placed in a line is unnecessary — this, of course, saves \mathcal{O}(N\log N) operations.
  • With some precomputation, it is in fact even possible to pass N = 7500 in just under 10^6 queries — see this comment for details and implementation.
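To illustrate the first optimization, here is a sketch of a bubble-sort base case built from an arithmetic compare-exchange, in the spirit of the `order` helper in the setter's code below. It is an offline simulation on a plain array with an illustrative function name, not the query sequence itself.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Sorts `a` (distinct values) in place with bubble sort and returns the
// number of inversions, using one comparison plus arithmetic per step.
long long bubble_sort_count(std::vector<int>& a) {
    assert(!a.empty());
    long long inv = 0;
    for (std::size_t pass = 0; pass + 1 < a.size(); ++pass) {
        for (std::size_t j = 0; j + 1 + pass < a.size(); ++j) {
            const int w = a[j] < a[j + 1] ? 1 : 0;          // comparison result
            const int not_w = 1 - w;
            inv += not_w;                                   // out-of-order adjacent pair
            const int lo = a[j] * w + a[j + 1] * not_w;     // minimum of the pair
            const int hi = a[j] + a[j + 1] - lo;            // maximum of the pair
            a[j] = lo;
            a[j + 1] = hi;
        }
    }
    return inv;
}
```

Each adjacent exchange removes exactly one inversion, so the number of exchanges equals the inversion count of the small list.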

On a final note, this problem is somewhat detail-heavy in implementation, and for the most part the editorial above only goes into the broad ideas behind each step - not the actual implementation details behind how to achieve them using the given operations. You are encouraged to try that yourself, or look at others’ code for reference.

TIME COMPLEXITY:

\mathcal{O}(N\log N), but with a rather high constant factor — roughly, it is the number of operations performed by the solution, and hence is \leq 10^6.

CODE:

Setter's Code
#include <bits/stdc++.h>
using namespace std;

vector<string> res;

int indx;

int add(int i, int j){
    res.push_back("add " + to_string(i) + " " + to_string(j));
    return indx++;
}

int subtract(int i, int j){
    res.push_back("subtract " + to_string(i) + " " + to_string(j));
    return indx++;
}

int multiply(int i, int j){
    res.push_back("multiply " + to_string(i) + " " + to_string(j));
    return indx++;
}

int compare(int i, int j){
    res.push_back("compare " + to_string(i) + " " + to_string(j));
    return indx++;
}

int put(int i){
    res.push_back("put " + to_string(i));
    return indx++;
}

int compose(int i){
    res.push_back("compose " + to_string(i));
    return indx++;
}

int n,ans;
int k0,k1,kn;

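// Compare-exchange: returns the indices holding (min, max) of the values at
// x and y, and adds 1 to ans if the pair is an inversion (value at x is larger).
pair<int,int> order(int x,int y){
    int w = compare(x,y);       // w = [value(x) < value(y)]
    int w_ = subtract(k1,w);    // w_ = 1 - w
    ans = add(ans,w_);
    int val1 = add(multiply(x,w),multiply(y,w_));   // min: x if w, else y
    int val2 = subtract(add(x,y),val1);             // max = x + y - min
    return {val1, val2};
}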

vector<int> merge_sort(vector<int> indices){
    if(indices.size() <= 6){
        for(int i = 0; i < indices.size()-1; i ++){
            for(int j = 0; j < indices.size()-1-i; j ++){
                auto [i1,i2] = order(indices[j],indices[j+1]);
                indices[j] = i1;
                indices[j+1] = i2;
            }
        }
        vector<int> sorted_list;
        for(auto x:indices) sorted_list.push_back(add(x,k0));
        sorted_list.push_back(put(n+1));
        return sorted_list;
    }
    vector<int> lft, rgt;
    for(int i = 0; i < indices.size(); i ++){
        if(i < indices.size()/2) lft.push_back(indices[i]);
        else rgt.push_back(indices[i]);
    }
    lft = merge_sort(lft);
    rgt = merge_sort(rgt);
    // l and r are pointer cells: each holds the index of the current element of
    // lft / rgt, so compose(l) and compose(r) read the elements they point at
    int l = put(lft[0]),r = put(rgt[0]);
    int rptr = k0;
    vector<vector<int>> comp;
    vector<pair<int,int>> bruh;
    vector<int> sorted_list;
    for(int i = 0; i < indices.size(); i++){
        int lval = compose(l),rval = compose(r);
        int w = compare(lval,rval);
        int w_ = subtract(k1,w);
        ans = add(ans,multiply(w,rptr));
        l = add(l,w);
        r = add(r,w_);
        rptr = add(rptr,w_);
        comp.push_back({w,lval,w_,rval});
    }
    if(indices.size() != n){
        for(auto x:comp){
            bruh.push_back({multiply(x[0],x[1]),multiply(x[2],x[3])});
        }
        for(auto x:bruh) sorted_list.push_back(add(x.first,x.second));
        sorted_list.push_back(put(n+1));
    }
    return sorted_list;
}

signed main(){
    cin >> n;
    indx = n;
    vector<int> v;
    for(int i = 0; i < n; i ++) v.push_back(i);
    k0 = put(0);
    k1 = put(1);
    kn = put(n);
    ans = k0;
    merge_sort(v);
    add(ans,k0);
    cout << res.size() << "\n";
    for(auto x:res) cout << x << "\n";
}   