Problem Link
Author: Ivan Safonov
Tester: Hasan Jaddouh
Editorialist: Bhuvnesh Jain
Difficulty
MEDIUM-HARD
Prerequisites
Inversions, Bit Manipulation, Binary Search, Meet-in-the-middle
Problem
You are given an array A of length N. All the numbers in the array are less than 2^K, where K is also given in the input. We define the function f(x) as the number of inversions in the array B, where B[i] = A[i] \oplus x. You need to find the P^{th} element in the sorted sequence of pairs \{f(x), x\}.
Explanation
An inversion pair (i, j) in an array means that i < j but A[i] > A[j]. Let us first understand how an inversion pair changes when both the elements of the array are xor’ed with a number, x.
Consider 2 numbers, a and b. Assume a < b. We need to find for which values of x we get (a \oplus x) > (b \oplus x), i.e. the inequality flips. Consider the binary representations of a and b. Their longest common prefix doesn’t matter, and neither do the bits of x at those positions. The next bit will be 0 in a and 1 in b, as a < b. So, if that bit is 0 in x, the inequality remains the same; otherwise it flips. Again, the remaining bits in x do not matter. For example: a = 57 (111001 in binary) and b = 62 (111110 in binary).
The first 3 bits in x don’t matter, as those bits in (a \oplus x) and (b \oplus x) would remain equal. The fourth bit in x decides whether (a \oplus x) or (b \oplus x) is greater. If the bit is 0, (a \oplus x) < (b \oplus x); if the bit is 1, (a \oplus x) > (b \oplus x). The remaining last 2 bits in x again don’t matter, as the fourth bit is enough to decide the sign of the inequality.
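The claim for a = 57 and b = 62 can be checked exhaustively. The snippet below is only a small sanity check; the variable name deciding_bit is introduced here for illustration and counts bits from the least significant end, so the "fourth bit" of the 6-bit representation is bit index 2.

```python
a, b = 57, 62  # 111001 and 111110 in binary

# Index (from the least significant end) of the highest bit where a and b differ.
deciding_bit = (a ^ b).bit_length() - 1

# For every 6-bit x, the comparison flips exactly when that bit is set in x.
for x in range(64):
    flipped = (a ^ x) > (b ^ x)
    assert flipped == bool((x >> deciding_bit) & 1)

print(deciding_bit)  # 2
```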
The above result shows that, for any 2 numbers, a single bit of x decides whether the inequality between them is reversed, independently of the other bits. Thus, for every bit y, we need to count the pairs (i, j) in the array that form an inversion pair when the y^{th} bit is set in x, and the pairs that form an inversion pair when it is unset. Let us maintain an array inv_cnt[y][z], where y is the required bit and z is either 0 or 1, denoting whether the bit is unset or set respectively. Below is a small pseudocode for the above computation:
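# inv_cnt[y][z]: pairs decided at bit y that are inversions when bit y of the xor value is z
inv_cnt = [[0, 0] for _ in range(k)]

def rec(a, bit):
    if bit < 0 or len(a) == 0:
        return
    zero, one = 0, 0
    unset_bit = []
    set_bit = []
    for x in a:
        if x & (1 << bit):
            one += 1
            # earlier elements with a 0 here are smaller: inversion only if this bit is flipped
            inv_cnt[bit][1] += zero
            set_bit.append(x)
        else:
            zero += 1
            # earlier elements with a 1 here are larger: inversion only if this bit is kept
            inv_cnt[bit][0] += one
            unset_bit.append(x)
    rec(set_bit, bit - 1)
    rec(unset_bit, bit - 1)

rec(a, k - 1)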
Let us understand why the above pseudocode works. Note that we always scan and place elements into the arrays from left to right, so their relative order is preserved. So, if a number is found to be smaller than an already scanned number, the two form an inversion pair. For every bit y, we need to count the pairs whose longest common prefix ends just above bit y, i.e. the pairs that first differ at bit y.
So, we start with the largest possible bit, i.e. (K - 1). We separate the numbers based on whether the current bit is unset or set in them. When we call the recursive function on the new arrays, we know that their elements have the first bit in common. This way, in recursive step w, we know that the numbers in the array have their first (w - 1) bits in common.
Now, let us see how the inversion logic is taken care of. Since we move from left to right and the elements in the array agree on all higher bits (they were filtered on matching previous bits, as shown above), the current bit alone decides the comparison between an element with a 1 and an element with a 0. If the bit is 1 in the current number, all previously scanned numbers with a 0 there are smaller than it, so those pairs are not inversions if the bit is unset in x but become inversions if the bit is set in x. Similar logic applies to the other case. In case of any doubts till here, I request you to manually go through the pseudocode once for any array, say A = [7, 5, 3, 4, 1, 2] and K = 3.
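For readers who prefer to run the suggested example rather than trace it by hand, here is a minimal runnable sketch of the same recursion. The wrapper count_by_bit and the checker brute_inversions are names introduced here for illustration; they are not part of the setter's or editorialist's code.

```python
def count_by_bit(a, k):
    # inv_cnt[y][z]: pairs first differing at bit y that are inversions
    # when bit y of the xor value is z.
    inv_cnt = [[0, 0] for _ in range(k)]

    def rec(arr, bit):
        if bit < 0 or not arr:
            return
        zero = one = 0
        unset_bit, set_bit = [], []
        for v in arr:
            if v & (1 << bit):
                one += 1
                inv_cnt[bit][1] += zero
                set_bit.append(v)
            else:
                zero += 1
                inv_cnt[bit][0] += one
                unset_bit.append(v)
        rec(set_bit, bit - 1)
        rec(unset_bit, bit - 1)

    rec(a, k - 1)
    return inv_cnt


def brute_inversions(b):
    # Quadratic reference count, used only to cross-check the result.
    n = len(b)
    return sum(1 for i in range(n) for j in range(i + 1, n) if b[i] > b[j])


A, K = [7, 5, 3, 4, 1, 2], 3
inv_cnt = count_by_bit(A, K)
print(inv_cnt)  # [[2, 0], [3, 1], [8, 1]]

# Summing the per-bit entries reproduces the inversion count for every x.
for x in range(1 << K):
    fast = sum(inv_cnt[j][(x >> j) & 1] for j in range(K))
    assert fast == brute_inversions([v ^ x for v in A])
```

For instance, x = 0 gives 2 + 3 + 8 = 13 inversions, which matches a direct count on A, and x = 7 gives 0 + 1 + 1 = 2.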
The complexity of the above pseudocode is O(N * K) because the recursion can have a maximum depth of K (see the first condition in the function). At each step of recursion, we can iterate through maximum N elements.
Now, we know 2 things:
 Each bit in x contributes independently to inversion count in the array.
 The contribution of each bit to inversion count is already calculated.
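Put together, the two facts say that f(x) is a sum of per-bit terms. Writing x_j for bit j of x (notation introduced here for illustration only), this can be stated as:

```latex
f(x) = \sum_{j=0}^{K-1} \text{inv\_cnt}[j][x_j],
\qquad x_j = \left\lfloor x / 2^{j} \right\rfloor \bmod 2
```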
So, a small brute-force code which finds the number of inversions for all possible x (0 ≤ x < 2^K) is given below:
def get_all(A, K):
    rec(A, K - 1)  # fills inv_cnt
    b = []
    for i in range(1 << K):  # every x in [0, 2^K - 1]
        inv = 0
        for j in range(K):
            if i & (1 << j):
                inv += inv_cnt[j][1]
            else:
                inv += inv_cnt[j][0]
        b.append((inv, i))
    b.sort()  # by inversion count, ties broken by x
    return b
Thus, we can just print the P^{th} element of the above list. This logic is enough to pass the first 2 subtasks for 30 points. For the full solution, however, we can’t generate the inversion count for every possible x, as there are too many of them.
Let us restate the problem we have now. We have a function which can be evaluated independently for every bit, and we need to find the number which gives the P^{th} smallest value of the function. Since we can find the contribution of every bit independently, we will try to use binary search along with meet-in-the-middle here.
The idea is as follows:

We divide the bits into 2 groups of almost equal size, i.e. about (K/2) each. The most significant ones form one group and the least significant ones form another group. Let us denote the last floor(K/2) bits as low_bit and the remaining ones as high_bit.

Any number x can be represented as \text{high\_bit} * 2^{\lfloor K/2 \rfloor} + \text{low\_bit}.

For each group, we find the contribution of every possible setting of the bits in that group, i.e. each of them set or unset, ignoring the bits of the other group.

Now, we do two binary searches as follows:
 We first binary search for the number of inversions the array has for the P^{th} element. Let us say this value is denoted by W. (Note that the range for this binary search will be [0, N * (N - 1)/2].)
 Now, we know the number of inversions for the P^{th} element. So, we binary search on x, counting how many numbers not exceeding a given number have at most W inversions. (Note that the range for this binary search will be [0, 2^K - 1].)
For more details, you can refer to the commented editorialist’s solution for help.
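The steps above can be sketched in Python as follows. This is not the editorialist's solution, only a minimal sketch under stated assumptions: the name kth_pair and its helpers are hypothetical, P is taken to be 1-indexed, and inv_cnt is assumed to be already filled as described earlier. The first binary search (on W) follows the text; for the second step the sketch swaps in a plain scan over the high halves in increasing order instead of a second binary search, which is simpler to verify but is a different technique.

```python
from collections import Counter


def kth_pair(inv_cnt, K, P):
    """Return the P-th (1-indexed) pair (f(x), x) in sorted order."""
    low_bits = K // 2            # least significant group
    high_bits = K - low_bits     # most significant group

    def costs(first_bit, width):
        # out[v] = contribution of bits first_bit .. first_bit + width - 1
        # when they take the pattern v.
        return [
            sum(inv_cnt[first_bit + j][(v >> j) & 1] for j in range(width))
            for v in range(1 << width)
        ]

    low_cost = costs(0, low_bits)
    high_cost = costs(low_bits, high_bits)
    sorted_low = sorted(low_cost)
    sorted_high = sorted(high_cost)

    def count_le(W):
        # Number of x with f(x) <= W, via two pointers over the sorted halves.
        total, j = 0, len(sorted_low)
        for hc in sorted_high:
            while j > 0 and sorted_low[j - 1] + hc > W:
                j -= 1
            total += j
        return total

    # Binary search for the smallest W with at least P values of x at or below it.
    lo, hi = 0, sorted_low[-1] + sorted_high[-1]
    while lo < hi:
        mid = (lo + hi) // 2
        if count_le(mid) >= P:
            hi = mid
        else:
            lo = mid + 1
    W = lo

    # Rank of the answer among the x with f(x) == W, ordered by x.
    r = P - count_le(W - 1)
    low_freq = Counter(low_cost)
    for h, hc in enumerate(high_cost):       # increasing high half
        c = low_freq[W - hc]                 # Counter gives 0 for missing keys
        if r > c:
            r -= c
            continue
        for l, lc in enumerate(low_cost):    # increasing low half
            if lc == W - hc:
                r -= 1
                if r == 0:
                    return W, (h << low_bits) | l
    raise ValueError("P must be in [1, 2^K]")


# Per-bit counts for A = [7, 5, 3, 4, 1, 2], K = 3.
print(kth_pair([[2, 0], [3, 1], [8, 1]], 3, 3))  # (4, 6)
```

Because x = high * 2^{floor(K/2)} + low, scanning the high halves and then the low halves in increasing order visits the candidates with f(x) = W in increasing x, which is why the scan breaks ties correctly.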
The time complexity of the above approach will be O(2^{K/2} * (K + \log{N})). The first part is due to building the arrays low_bit and high_bit and sorting them in ascending order. The next part is due to the binary searches, where the first one requires O(2^{K/2} * \log{(N * (N - 1)/2)}) ~ O(2^{K/2} * \log{N}) and the second binary search takes O(2^{K/2} * K).
Thus, the overall complexity of the problem is O(N * K + 2^{K/2} * (K + \log{N})). The space complexity of the above solution is O(N + K).
Feel free to share your approach, if it was somewhat different.
Time Complexity
O(N * K + 2^{K/2} * (K + \log{N})) per test case.
Space Complexity
O(N + K)