RANKQ - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: pvtr
Tester: jay_1048576
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Ordered set in C++, or a segment tree/similar data structure.

PROBLEM:

You’re given an array A.
The rank of A_i is defined to be L+1, where L is the number of elements in the array strictly larger than A_i.

Answer Q queries on this array:

  • Given K and R, find the smallest non-negative integer z such that you can delete \leq z elements from A[1 : K-1] and another \leq z elements from A[K+1 : N], so that the rank of A_K in the resulting array is \leq R.
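To make the query concrete, here is a brute-force reference that tries every possible deletion set on each side. It is only an illustration for checking faster solutions on tiny arrays (it is exponential in N), and the function names `rank_in` and `brute_force` are hypothetical.

```python
from itertools import combinations

def rank_in(values, target):
    """Rank as defined above: 1 + number of elements strictly larger than target."""
    return 1 + sum(1 for v in values if v > target)

def brute_force(a, k, r):
    """Answer one query (K, R) by trying every deletion set. K is 1-indexed.

    Exponential in len(a): only meant for checking faster solutions on tiny arrays.
    """
    n = len(a)
    target = a[k - 1]
    left = list(range(0, k - 1))   # 0-indexed positions of A[1 : K-1]
    right = list(range(k, n))      # 0-indexed positions of A[K+1 : N]
    for z in range(n + 1):         # z = allowed deletions on each side (the answer)
        for cl in range(min(z, len(left)) + 1):
            for dl in combinations(left, cl):
                for cr in range(min(z, len(right)) + 1):
                    for dr in combinations(right, cr):
                        gone = set(dl) | set(dr)
                        kept = [a[i] for i in range(n) if i not in gone]
                        if rank_in(kept, target) <= r:
                            return z
    return n  # unreachable for R >= 1: deleting every other element gives rank 1

print(brute_force([5, 3, 8, 1, 9, 7], 4, 1))  # 3
```

For A = [5, 3, 8, 1, 9, 7] and the query (K=4, R=1), the value 1 has three larger elements on its left and two on its right, so three deletions per side are needed.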

EXPLANATION:

Let’s see how we’d answer a single query (K, R).

Since we want to reduce the rank of A_K, it’s clearly optimal to only remove elements that are \gt A_K.
Elements that are \leq A_K don’t affect its rank at all.

So, suppose there are x elements in A[1:K-1] and y elements in A[K+1:N] that are \gt A_K.
That means the initial rank of A_K is x+y+1, and each of these elements we delete reduces it by exactly 1.

Our target is to make the rank of A_K be \leq R.
If x+y+1 \leq R, the answer is 0; otherwise, we need to delete at least (x+y+1) - R elements in total.
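As a sketch of these quantities for a single query, the following computes x, y, and the total number of required deletions directly in \mathcal{O}(N). This is far too slow for many queries and is meant only as a readable reference; `query_counts` is a hypothetical helper name.

```python
def query_counts(a, k, r):
    """For a 1-indexed query (K, R), return (x, y, need) as defined above.

    x: elements of A[1 : K-1] strictly larger than A_K
    y: elements of A[K+1 : N] strictly larger than A_K
    need: minimum total number of deletions, max(0, x + y + 1 - R)
    """
    target = a[k - 1]
    x = sum(1 for v in a[:k - 1] if v > target)
    y = sum(1 for v in a[k:] if v > target)
    need = max(0, x + y + 1 - r)
    return x, y, need

# Example: A = [5, 3, 8, 1, 9, 7], query (K=4, R=1), so A_K = 1
print(query_counts([5, 3, 8, 1, 9, 7], 4, 1))  # (3, 2, 5)
```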

That is, we’d like to find the smallest integer z such that deleting at most z elements from each side allows us to delete at least x+y+1-R elements in total.
Finding this z is not too hard.

  • One way is to note that as z increases, we can delete more elements, so binary search applies.
    For a fixed z, the maximum number of deleted elements is \min(x, z) + \min(y, z), so binary search on z for the smallest value at which this is at least x+y+1-R.
    This takes \mathcal{O}(\log N) time.
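  • Alternately, a bit of casework finds z in \mathcal{O}(1) time.
    Assume x \leq y (otherwise swap them) and x+y+1 \gt R (otherwise the answer is 0). Then:
    • If y-x \leq R-1, we have z = \left\lceil \frac{x+y+1-R}{2} \right\rceil.
    • Otherwise, we have z = y - (R-1).
      These follow from two cases. The condition y-x \leq R-1 is equivalent to x+y+1-R \leq 2x: in that case, deleting at most x elements from each side suffices, so the required deletions are split as evenly as possible between the two sides. Otherwise, all x larger elements on the left are deleted, and the remaining (x+y+1-R) - x = y-(R-1) deletions must come from the right.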
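Both ways of finding z can be sketched side by side and cross-checked against each other on small inputs. The function names are hypothetical, and the binary search bound uses the fact that z = \max(x, y) always works, since it allows deleting every larger element.

```python
def min_z_binary_search(x, y, r):
    """Smallest z with min(x, z) + min(y, z) >= x + y + 1 - r, by binary search."""
    need = x + y + 1 - r
    if need <= 0:
        return 0
    lo, hi = 0, max(x, y)  # z = max(x, y) deletes every larger element, so it always works
    while lo < hi:
        mid = (lo + hi) // 2
        if min(x, mid) + min(y, mid) >= need:
            hi = mid
        else:
            lo = mid + 1
    return lo

def min_z_closed_form(x, y, r):
    """Same value via the O(1) casework (after swapping so that x <= y)."""
    if x > y:
        x, y = y, x
    need = x + y + 1 - r
    if need <= 0:
        return 0
    if y - x <= r - 1:
        return (need + 1) // 2  # ceil(need / 2): delete equally from both sides
    return y - (r - 1)          # the smaller side is exhausted; the rest come from y

# Cross-check the two approaches on all small inputs.
for x in range(8):
    for y in range(8):
        for r in range(1, x + y + 3):
            assert min_z_binary_search(x, y, r) == min_z_closed_form(x, y, r)
```

The binary search is easier to get right; the closed form is what brings the per-query cost down to \mathcal{O}(1).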

At any rate, we can see that the ‘slow’ part of the query is finding x and y: once they’re known, the answer can be found in \mathcal{O}(\log N) or \mathcal{O}(1) time.


Essentially, we want to know the following information:

  • For each 1 \leq i \leq N, how many elements before i are greater than A_i? How many elements after i are greater than A_i?

Let’s call these values \text{left}[i] and \text{right}[i], respectively.
All the values of \text{left}[i] and \text{right}[i] can be found in \mathcal{O}(N\log N) time by using an appropriate data structure.

In C++, the simplest way to do this is to use the GNU policy-based data structure __gnu_pbds::tree (a GCC extension, not part of the standard library), also commonly called “ordered set” or “indexed set”.
A blog on the basic functionality and how to include it can be found here.
Essentially, this allows us to keep a data structure that supports everything std::set does, along with:

  • Given k, find the k-th smallest element present in the set, counting from 0 (find_by_order)
  • Given an element x, find the number of elements in the set that are \lt x (order_of_key)

both in \mathcal{O}(\log N) time.

For our use-case, the second query type, order_of_key, is what we want.

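For example, we can calculate the \text{left} values by iterating across the array from left to right. At index i, query the number of already-inserted elements that are \leq A_i; every other element before i is \gt A_i, so \text{left}[i] is (i-1) minus that count.
Then, insert A_i into the set and continue on. Since the tree stores distinct keys, duplicate values are handled by inserting pairs (A_i, i); the number of inserted elements \leq A_i is then order_of_key((A_i, \infty)), where \infty is anything larger than every index.
The \text{right} values are computed the same way, iterating from right to left.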
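The same left-to-right sweep can be sketched in Python, with a sorted list and the standard bisect module standing in for the ordered set. This is an illustrative substitute, not the C++ tree: bisect.insort costs \mathcal{O}(N) per insertion, so the sketch is \mathcal{O}(N^2) in the worst case. A Python list keeps duplicates, so no (value, index) pairs are needed, and bisect_right plays the role of "count of elements \leq A_i". The names `left_counts` and `right_counts` are hypothetical.

```python
import bisect

def left_counts(a):
    """left[i] = number of j < i with a[j] > a[i] (0-indexed).

    A sorted list stands in for the ordered set; insort is O(n) per
    insertion, so this is O(n^2) in the worst case (illustration only).
    """
    seen = []  # sorted values of a[0..i-1]
    left = []
    for i, v in enumerate(a):
        not_larger = bisect.bisect_right(seen, v)  # earlier elements <= v
        left.append(i - not_larger)                # the rest are strictly larger
        bisect.insort(seen, v)
    return left

def right_counts(a):
    """right[i] = number of j > i with a[j] > a[i]: the same sweep, right to left."""
    return left_counts(a[::-1])[::-1]

a = [5, 3, 8, 1, 9, 7]
print(left_counts(a))   # [0, 1, 0, 3, 0, 2]
print(right_counts(a))  # [3, 3, 1, 2, 0, 0]
```

For the query (K=4, R=1) on this array, x = \text{left}[4] and y = \text{right}[4] are the 0-indexed entries 3 and 2 at position 3.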

In languages other than C++, various other data structures can fulfill this purpose.
For example, you can use a segment tree or Fenwick tree built over the (coordinate-compressed) values, or write (or copy) a custom self-balancing BST (for example, a treap) augmented with subtree sizes.
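The editorialist's code below uses a Fenwick tree with prefix sums. As an alternative sketch of the "segment tree built on values" option, here is an iterative (bottom-up) segment tree over compressed values that counts the values strictly greater than A_i with a single range query. The class and function names are hypothetical, and this is one possible layout rather than the reference implementation.

```python
class CountSegmentTree:
    """Bottom-up segment tree over compressed value indices 0..size-1,
    storing how many copies of each value have been inserted."""

    def __init__(self, size):
        self.size = max(1, size)
        self.tree = [0] * (2 * self.size)

    def add(self, pos, delta):
        pos += self.size
        while pos:  # update the leaf and all of its ancestors
            self.tree[pos] += delta
            pos //= 2

    def count_range(self, lo, hi):
        """Number of inserted values with compressed index in [lo, hi)."""
        total = 0
        lo += self.size
        hi += self.size
        while lo < hi:
            if lo & 1:
                total += self.tree[lo]
                lo += 1
            if hi & 1:
                hi -= 1
                total += self.tree[hi]
            lo //= 2
            hi //= 2
        return total

def greater_counts(a):
    """Return (left, right) as defined above, via a segment tree on compressed values."""
    n = len(a)
    comp = {v: i for i, v in enumerate(sorted(set(a)))}
    m = len(comp)

    left = [0] * n
    st = CountSegmentTree(m)
    for i in range(n):
        c = comp[a[i]]
        left[i] = st.count_range(c + 1, m)  # earlier values strictly larger
        st.add(c, 1)

    right = [0] * n
    st = CountSegmentTree(m)
    for i in range(n - 1, -1, -1):
        c = comp[a[i]]
        right[i] = st.count_range(c + 1, m)  # later values strictly larger
        st.add(c, 1)
    return left, right

print(greater_counts([5, 3, 8, 1, 9, 7]))  # ([0, 1, 0, 3, 0, 2], [3, 3, 1, 2, 0, 0])
```

Querying the range of larger values directly avoids the "subtract from i" step, at the cost of a slightly longer data structure than a Fenwick tree.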

TIME COMPLEXITY:

\mathcal{O}(N\log N + Q) per testcase with the \mathcal{O}(1) casework, or \mathcal{O}((N+Q)\log N) if each query binary searches for z.

CODE:

Author's code (C++)
#include <bits/stdc++.h>

#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace __gnu_pbds;
  
#define ordered_set tree<pair<int,int>, null_type,less<pair<int,int>>, rb_tree_tag,tree_order_statistics_node_update>
using namespace std;
 
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

#define int long long
#define pb push_back
#define rep(i,a,b) for(int i = a; i < b; i++)
#define all(x) x.begin(),x.end()
#define in(a) for(int i = 0; i<a.size(); i++) cin>>a[i];
#define out(a) for(int i = 0; i<a.size(); i++) cout<<a[i]<<" ";
typedef vector<int> vi;
#define sqrt(x) sqrtl(x)
#define ret(a) cout<<a<<"\n"; return


const int T = 10000;
const int N = 2e5;
const int Q = 2e5;
const int MAX_A = 1e9;
const int MIN_A = -1e9;

int SUM_N = 0;
int SUM_Q = 0;


bool check(int x, int l, int r, int rank){
    int left = max(0LL, l - x);
    int right = max(0LL, r - x);
    
    if(left + right + 1 <= rank){
        return true;
    }

    return false;
}

void solve(){
    int n, q; cin>>n>>q;

    SUM_N += n;
    SUM_Q += q;

    vi a(n);
    in(a);

    for(int i = 0; i < n; i++){
        assert(MIN_A <= a[i] && a[i] <= MAX_A);
    }

    vector<vector<pair<int,int>>> m(n);

    rep(j,0,q){
        int i, x; cin>>i>>x;
        assert(1 <= i && i <= n);
        assert(1 <= x && x <= n);
        m[i-1].push_back({x, j});
    }

    ordered_set pref, suf;
    rep(i,0,n){
        suf.insert({a[i], n-i});
    }

    vector<int> ans(q);

    rep(i,0,n){
        suf.erase({a[i], n-i});

        int l = i - pref.order_of_key({a[i], i});
        int r = n - i - 1 - suf.order_of_key({a[i], n-i});

        for(auto &x: m[i]){
            int rank = x.first;
            int query = x.second;
           
            int f = 0, s = n;
            while(s - f > 1){
                int mid = (s + f)/2;
                if(check(mid, l, r, rank)){
                    s = mid;
                }
                else{
                    f = mid;
                }
            }
            if(check(f, l, r, rank)) s = f;

            ans[query] = s;
        }

        pref.insert({a[i], i});
    }

    for(int i = 0; i < q; i++) cout<<ans[i]<<"\n";
}   
 
int32_t main() {    
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(false); cin.tie(0); cout.tie(0); 
    int t = 1; 
    cin>>t;
    assert(1 <= t && t <= T);
    for(int i = 1; i<=t; i++){
        // cout<<"Case #"<<i<<": ";
        solve();
    }
    assert(1 <= SUM_N && SUM_N <= N);
    assert(1 <= SUM_Q && SUM_Q <= Q);
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-6 << "ms\n"; 
    return 0;
}
Tester's code (C++)
/*...................................................................*
 *............___..................___.....____...______......___....*
 *.../|....../...\........./|...../...\...|.............|..../...\...*
 *../.|...../.....\......./.|....|.....|..|.............|.../........*
 *....|....|.......|...../..|....|.....|..|............/...|.........*
 *....|....|.......|..../...|.....\___/...|___......../....|..___....*
 *....|....|.......|.../....|...../...\.......\....../.....|./...\...*
 *....|....|.......|../_____|__..|.....|.......|..../......|/.....\..*
 *....|.....\...../.........|....|.....|.......|.../........\...../..*
 *..__|__....\___/..........|.....\___/...\___/.../..........\___/...*
 *...................................................................*
 */
 
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <bits/stdc++.h>
using namespace std;
using namespace __gnu_pbds;
#define int long long
#define INF 1000000000000000000
#define MOD 1000000007
#define oset tree<pair<int,int>,null_type,less<pair<int,int> >,rb_tree_tag,tree_order_statistics_node_update>

void solve(int tc)
{
    int n,q;
    cin >> n >> q;
    int a[n];
    for(int i=0;i<n;i++)
        cin >> a[i];
    int left[n],right[n];
    oset s;
    for(int i=0;i<n;i++)
    {
        left[i] = i-s.order_of_key({a[i],INF});
        s.insert({a[i],i});
    }
    s.clear();
    for(int i=n-1;i>=0;i--)
    {
        right[i] = n-i-1-s.order_of_key({a[i],INF});
        s.insert({a[i],i});
    }
    while(q--)
    {
        int k,r;
        cin >> k >> r;
        k--;
        int mn = min(left[k],right[k]);
        int d = left[k]+right[k]+1-r;
        if(d<=0)
            cout << 0 << '\n';
        else if(d<=2*mn)
            cout << (d+1)/2 << '\n';
        else
            cout << d-mn << '\n';
    }
}

int32_t main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);
    int tc=1;
    cin >> tc;
    for(int ttc=1;ttc<=tc;ttc++)
        solve(ttc);
    return 0;
}
Editorialist's code (Python)
class FenwickTree:
    def __init__(self, x):
        """transform list into BIT"""
        self.bit = x
        for i in range(len(x)):
            j = i | (i + 1)
            if j < len(x):
                x[j] += x[i]

    def update(self, idx, x):
        """updates bit[idx] += x"""
        while idx < len(self.bit):
            self.bit[idx] += x
            idx |= idx + 1

    def query(self, end):
        """calc sum(bit[:end])"""
        x = 0
        while end:
            x += self.bit[end - 1]
            end &= end - 1
        return x

import sys
input = sys.stdin.readline

for _ in range(int(input())):
    n, q = map(int, input().split())
    a = list(map(int, input().split()))
    
    values = list(sorted(set(a)))
    compress = {}
    for i in range(len(values)):
        compress[values[i]] = i
    for i in range(n):
        a[i] = compress[a[i]]
    
    fen = FenwickTree([0]*n)
    left = [0]*n
    for i in range(n):
        left[i] = i - fen.query(a[i]+1)
        fen.update(a[i], 1)
    
    fen = FenwickTree([0]*n)
    right = [0]*n
    for i in reversed(range(n)):
        right[i] = n-1-i - fen.query(a[i]+1)
        fen.update(a[i], 1)
    
    for query in range(q):
        k, r = map(int, input().split())
        x, y = left[k-1], right[k-1]

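        if x+y+1 <= r: print(0)                      # rank is already small enough
        elif abs(x-y)+1 <= r: print((x+y+2-r) // 2)  # ceil((x+y+1-r)/2): balanced deletions
        else: print(max(x, y) - (r-1))               # smaller side exhausted, rest from larger side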