RATISTERA - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: raysh07
Tester: apoorv_me
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Sorting

PROBLEM:

CodeChef has N participants. Participant i's rating is always in the range [A_i, B_i].
If participant X is ratist, and participant Y had a lower rating than participant X before a contest but a higher rating after it, Y becomes ratist.
Multiple people can become ratist after a single contest.

For each i from 1 to N, suppose participant i is the only initial ratist; compute the maximum number of distinct people (including i) who can ever be ratist after several contests.

EXPLANATION:

Suppose participant i is ratist.
Participant j can then become ratist directly through participant i, if and only if there exist integers x and y such that:

  • A_i \leq x, y \leq B_i, so that participant i can have rating x before some contest and rating y after it.
  • A_j \lt x and y \lt B_j, so that participant j can have a rating less than x before that contest and a rating greater than y after it.

Increasing x or decreasing y only makes these conditions easier to satisfy, so it’s optimal to choose x as large as possible and y as small as possible, i.e., x = B_i and y = A_i.
So, participant j can become ratist through participant i if and only if A_j \lt B_i and A_i \lt B_j.
Note that this relation is symmetric: if i and j are swapped, the conditions remain exactly the same.
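As a sanity check of the choice x = B_i, y = A_i, here is a small Python sketch (the function names are made up for illustration) that tries every combination of integer ratings for both participants and compares the result with the closed-form condition on all intervals inside [1, 6]. It assumes ratings are integers, as in the argument above.

```python
from itertools import product


def direct_brute(a_i, b_i, a_j, b_j):
    """Can j become ratist directly through i? Try every integer rating combination.

    x, y: i's ratings before/after the contest; p, q: j's ratings before/after.
    """
    ri = range(a_i, b_i + 1)
    rj = range(a_j, b_j + 1)
    return any(p < x and q > y for x, y, p, q in product(ri, ri, rj, rj))


def direct_closed_form(a_i, b_i, a_j, b_j):
    """The condition derived above: A_j < B_i and A_i < B_j."""
    return a_j < b_i and a_i < b_j


small = [(a, b) for a in range(1, 7) for b in range(a, 7)]
for a_i, b_i in small:
    for a_j, b_j in small:
        assert direct_brute(a_i, b_i, a_j, b_j) == direct_closed_form(a_i, b_i, a_j, b_j)
        # The relation is symmetric in i and j.
        assert direct_closed_form(a_i, b_i, a_j, b_j) == direct_closed_form(a_j, b_j, a_i, b_i)
print("closed form agrees with brute force")
```

Brute force is only feasible for tiny ranges, but it confirms that nothing is lost by fixing x = B_i and y = A_i.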

Visually, the inequalities A_j \lt B_i and A_i \lt B_j mean that the intervals [A_i, B_i] and [A_j, B_j] intersect non-trivially: their intersection contains a point (not necessarily an integer) lying strictly inside at least one of the two intervals.
For example, [1, 3] and [2, 4] intersect non-trivially, as do [1, 3] and [2, 2]. However, [1, 2] and [2, 3] do not: their only common point is 2, which is an endpoint of both intervals.
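The geometric reading can be checked in the same way. The sketch below (hypothetical helper name) doubles all coordinates so that half-integer points become integers. For intervals with integer endpoints this is enough: if the intersection [L, R] has L < R, then L + 1/2 lies strictly inside both intervals, and if L = R, the only candidate point is the integer L. It checks the three examples above and compares with A_j < B_i and A_i < B_j on small intervals.

```python
def nontrivial_intersection(a_i, b_i, a_j, b_j):
    """Does the intersection of [a_i, b_i] and [a_j, b_j] contain a point
    strictly inside at least one of them?

    Coordinates are doubled so that half-integer points become integers.
    """
    lo, hi = 2 * max(a_i, a_j), 2 * min(b_i, b_j)  # doubled intersection (empty if lo > hi)
    return any(2 * a_i < p < 2 * b_i or 2 * a_j < p < 2 * b_j
               for p in range(lo, hi + 1))


# The three examples from the text.
assert nontrivial_intersection(1, 3, 2, 4)
assert nontrivial_intersection(1, 3, 2, 2)
assert not nontrivial_intersection(1, 2, 2, 3)

# Agreement with A_j < B_i and A_i < B_j on all intervals inside [1, 6].
small = [(a, b) for a in range(1, 7) for b in range(a, 7)]
for a_i, b_i in small:
    for a_j, b_j in small:
        assert nontrivial_intersection(a_i, b_i, a_j, b_j) == (a_j < b_i and a_i < b_j)
print("geometric and algebraic conditions agree")
```

Note the pair [1, 2] and [1, 2]: every integer in the intersection is an endpoint of both, but 1.5 is strictly inside, and the two participants can indeed swap order.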


Consider a simple undirected graph on N vertices, with an edge between i and j if and only if one of them being ratist allows the other one to become ratist directly.

If participant x is the only initial ratist, being ratist can only spread along edges, so the set of people who can possibly become ratist is contained in the connected component of x.
Moreover, all of them can actually be made ratist: since there is no limit on the number of contests, first make each neighbor of x ratist one at a time, then each of their neighbors, and so on. So the answer for participant x is exactly the size of its connected component.

So, all we need is the connected components of this graph; the answer for each participant is the size of the component containing it.
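For testing a faster solution on small inputs, it can help to have this quadratic approach spelled out. The sketch below (the function name is hypothetical) builds the graph explicitly using A_j < B_i and A_i < B_j, then runs a BFS from each unvisited participant and assigns every member of a component that component's size. It takes \Theta(N^2) time, so it is only meant for small N.

```python
from collections import deque


def brute_force_answers(intervals):
    """O(N^2) reference: explicit graph + BFS over connected components."""
    n = len(intervals)
    adj = [[] for _ in range(n)]
    for i, (a_i, b_i) in enumerate(intervals):
        for j in range(i + 1, n):
            a_j, b_j = intervals[j]
            if a_j < b_i and a_i < b_j:  # direct ratist relation (symmetric)
                adj[i].append(j)
                adj[j].append(i)

    ans = [0] * n
    seen = [False] * n
    for s in range(n):
        if seen[s]:
            continue
        # BFS collects the whole component containing s.
        seen[s] = True
        comp, queue = [s], deque([s])
        while queue:
            u = queue.popleft()
            for v in adj[u]:
                if not seen[v]:
                    seen[v] = True
                    comp.append(v)
                    queue.append(v)
        for u in comp:
            ans[u] = len(comp)
    return ans


print(brute_force_answers([(1, 3), (4, 6), (2, 5)]))  # [3, 3, 3]
print(brute_force_answers([(1, 2), (2, 3), (2, 2)]))  # [1, 1, 1]
```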

However, the graph can have \Theta(N^2) edges, so building it explicitly and then computing its components is too slow.
Instead, we use the fact that the edges come from interval intersections and run a sweep-line algorithm.

We process the intervals in increasing order of left endpoint, breaking ties by increasing right endpoint. The only tie-break that actually matters is that a singleton [x, x] must be processed before any longer interval starting at x; otherwise the singleton could be wrongly merged with an interval [x, B] (B \gt x), which it only touches at an endpoint of both.
Maintain the current connected component C, and the maximum right endpoint M over all intervals in C.
When processing the interval [A_i, B_i]:

  • If A_i \geq M, then [A_i, B_i] cannot belong to C. Every interval of C ends at or before M \leq A_i, while [A_i, B_i] and all later intervals start at or after A_i, so at best they share a single point with an interval of C, and that point is an endpoint of both; that isn’t good enough.
    So C is complete: start a new component containing only [A_i, B_i].
  • Otherwise A_i \lt M, and (thanks to the tie-break above) [A_i, B_i] intersects the interval of C that ends at M non-trivially.
    Insert it into C, and update M as M = \max(M, B_i).
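The sweep above can be sketched in Python as follows. The function name and the `key` parameter are hypothetical; `key` exists only to show what goes wrong if singletons are not processed first among intervals with the same left endpoint.

```python
def component_sizes(intervals, key=lambda iv: (iv[0], iv[1])):
    """Sweep over intervals [A_i, B_i] sorted by key; return each participant's answer.

    The default key sorts by left endpoint, then right endpoint, so a singleton
    [x, x] is processed before any longer interval starting at x.
    """
    order = sorted(range(len(intervals)), key=lambda k: key(intervals[k]))
    ans = [0] * len(intervals)
    comp, mx = [], float("-inf")  # current component C and its max right endpoint M
    for k in order:
        a, b = intervals[k]
        if a >= mx:  # nothing in C can reach [a, b] or anything after it: close C
            for u in comp:
                ans[u] = len(comp)
            comp = []
        comp.append(k)
        mx = max(mx, b)  # after closing C this is just b, since b >= a >= old M
    for u in comp:  # close the last component
        ans[u] = len(comp)
    return ans


print(component_sizes([(1, 3), (4, 6), (2, 5)]))  # [3, 3, 3]
print(component_sizes([(2, 5), (2, 2)]))          # [1, 1]
# Processing longer intervals first on ties wrongly merges [2, 2] into [2, 5]:
print(component_sizes([(2, 5), (2, 2)], key=lambda iv: (iv[0], -iv[1])))  # [2, 2]
```

Each interval is handled in constant amortized time after sorting, matching the complexity stated below.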

Processing each interval takes constant time, so the overall time complexity is \mathcal{O}(N\log N) due to sorting.
It’s also possible to implement this in \mathcal{O}(N): since A_i, B_i \leq 2N, the intervals can be bucketed by left endpoint instead of being sorted.
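A sketch of that linear-time variant, assuming 1 \leq A_i \leq B_i \leq 2N (the bounds the tester's input checker enforces). One bucket per left endpoint replaces the sort, and within each bucket singletons go first, which is the only ordering the sweep needs. The function name is hypothetical. Total work is O(N): there are 2N + 1 buckets and each interval is touched a constant number of times.

```python
def component_sizes_linear(intervals):
    """Bucketed sweep over [A_i, B_i]; assumes 1 <= A_i <= B_i <= 2N."""
    n = len(intervals)
    # buckets[x] = indices of intervals whose left endpoint is x
    buckets = [[] for _ in range(2 * n + 1)]
    for k, (a, _) in enumerate(intervals):
        buckets[a].append(k)

    ans = [0] * n
    comp, mx = [], 0  # current component C and its max right endpoint M
    for x in range(1, 2 * n + 1):
        # Singletons [x, x] must be handled before longer intervals starting at x.
        order = [k for k in buckets[x] if intervals[k][1] == x]
        order += [k for k in buckets[x] if intervals[k][1] != x]
        for k in order:
            if x >= mx:  # C is complete
                for u in comp:
                    ans[u] = len(comp)
                comp = []
            comp.append(k)
            mx = max(mx, intervals[k][1])
    for u in comp:  # close the last component
        ans[u] = len(comp)
    return ans


print(component_sizes_linear([(1, 3), (4, 6), (2, 5)]))  # [3, 3, 3]
```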

TIME COMPLEXITY:

\mathcal{O}(N\log N) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

struct ufds{
    vector <int> root, sz;
    int n;
 
    void init(int nn){
        n = nn;
        root.resize(n + 1);
        sz.resize(n + 1, 1);
        for (int i = 0; i <= n; i++) root[i] = i;
    }
 
    int find(int x){
        if (root[x] == x) return x;
        return root[x] = find(root[x]);
    }
 
    bool unite(int x, int y){
        x = find(x); y = find(y);
        if (x == y) return false;
 
        if (sz[y] > sz[x]) swap(x, y);
        sz[x] += sz[y];
        root[y] = x;
        return true;
    }
};

void Solve() 
{
    int n; cin >> n;

    ufds uf;
    uf.init(n);

    set <array<int, 3>> S;
    vector <pair<int, int>> vec;
    for (int i = 0; i < n; i++){
        int a, b; cin >> a >> b;
        if (a == b){
            vec.push_back({a, i});
            continue;
        }

        b--;

        int l = a, r = b;

        auto it = S.lower_bound({a, -1, -1});
        // this is first interval that starts after a 
        set <array<int, 3>> R;
        auto it2 = it;
        if (it2 != S.begin()){
            it2--;
            // last interval before a 
            if ((*it2)[1] >= a){
                uf.unite((*it2)[2], i);
                l = min(l, (*it2)[0]);
                r = max(r, (*it2)[1]);
                R.insert(*it2);
            }
        }

        while (it != S.end()){
            if ((*it)[0] > b) break;
            uf.unite((*it)[2], i);
            l = min(l, (*it)[0]);
            r = max(r, (*it)[1]);
            R.insert(*it);
            ++it;
        }

        for (auto x : R) S.erase(x);
        S.insert({l, r, i});
    }

    // for (auto [l, r, i] : S){
    //     cout << l << " " << r << " " << i << "\n";
    // }

    vector <int> who(2 * n + 1, -1);
    for (auto [l, r, i] : S){
        for (int j = l + 1; j <= r; j++){
            who[j] = i;
        }
    }

    for (auto [x, i] : vec){
        if (who[x] != -1){
            uf.unite(i, who[x]);
        }
    }

    for (int i = 0; i < n; i++){
        cout << uf.sz[uf.find(i)] << " \n"[i + 1 == n];
    }
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}
Tester's code (C++)
#include<bits/stdc++.h>
using namespace std;

#ifdef LOCAL
#include "../debug.h"
#else
#define dbg(...)
#endif

#ifndef LOCAL
struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};
#else

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
    }

    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int minl, int maxl, const string &pattern = "") {
      string X; cin >> X;
      return X;
    }

    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res;  cin >> res;
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res;  cin >> res;
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    void readSpace() {
    }

    void readEoln() {
    }

    void readEof() {
    }
};
#endif


struct DSU{
  int N, cmp;
  vector<int> par, sz;
  DSU(int N_) : N(N_ + 1), cmp(N_), par(N_ + 1), sz(N_ + 1, 1) {
    iota(par.begin(), par.end(), 0);
  }

  int find(int node) {
    if(par[node] == node) {
      return node;
    }
    return par[node] = find(par[node]);
  }

  bool join(int u, int v) {
    u = find(u);
    v = find(v);
    if(u == v) {
      return false;
    }
    if(sz[u] > sz[v]) {
      swap(u, v);
    }

    sz[v] += sz[u];
    par[u] = v;
    --cmp;
    return true;
  }
};

int32_t main() {
  ios_base::sync_with_stdio(0);
  cin.tie(0);

  input_checker inp;
  int T = inp.readInt(1, (int)1e4), NN = 0; inp.readEoln();
  while(T-- > 0) {
    int N = inp.readInt(1, (int)5e5); inp.readEoln();
    NN += N;
    vector<int> L(N), R(N);
    vector<vector<int>> upd(2 * N + 1);
    for(int i = 0 ; i < N ; ++i) {
      L[i] = inp.readInt(1, 2 * N); inp.readSpace();
      R[i] = inp.readInt(1, 2 * N); inp.readEoln();

      upd[L[i]].push_back(i);
      upd[R[i]].push_back(i + N);
    }

    set<pair<int, int>> St;
    DSU ds(N);
    for(int i = 1 ; i <= 2 * N ; ++i) {
      for(auto &ri : upd[i]) if(ri >= N) {
        St.erase({L[ri - N], ri - N});
        if(!St.empty() && L[ri - N] != R[ri - N]) {
          ds.join(St.begin() -> second, ri - N);
          ds.join(St.rbegin() -> second, ri - N);
        }
      }

      auto get = [&](int x) {
        if(x >= N)  x -= N;
        return x;
      };

      sort(upd[i].begin(), upd[i].end(), [&](int x, int y) {
        return R[get(x)] - L[get(x)] > R[get(y)] - L[get(y)];
      });

      for(auto &li: upd[i]) if(li < N) {
        if(!St.empty()) {
          ds.join(St.begin() -> second, li);
          ds.join(St.rbegin() -> second, li);
        }
      }
      for(auto &li: upd[i]) if(li < N && R[li] > i) {
        if(!St.empty())
          ds.join(St.rbegin() -> second, li);
        St.insert({L[li], li});
      }
    }

    for(int i = 0 ; i < N ; ++i) {
      cout << max(1, ds.sz[ds.find(i)]) << " \n"[i == N - 1];
    }
  }
  assert(NN <= (int)5e5);
  inp.readEof();
  
  return 0;
}

Editorialist's code (Python)
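for _ in range(int(input())):
    n = int(input())
    # Sentinel with a huge left endpoint: it forces the last real component to be flushed
    # (the sentinel itself is never written to ans).
    a = [(10**9 + 1, 0, 0)]
    for i in range(n):
        l, r = map(int, input().split())
        a.append((l, r, i))
    # Sort by left endpoint, then right endpoint, so singletons [x, x] come first.
    a.sort()
    # comp = indices in the current component, mx = its maximum right endpoint
    ans, comp, mx = [0]*n, [], -1
    for l, r, i in a:
        if l >= mx:
            # The current component is complete: record its size for every member.
            for u in comp: ans[u] = len(comp)
            comp = []
        comp.append(i)
        mx = max(mx, r)
    print(*ans)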
Editorialist's code (Python, linear)
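for _ in range(int(input())):
    n = int(input())
    # ends[l]: (r, i) for each longer interval [l, r]; alone[x]: singletons [x, x]
    ends = [ [] for _ in range(2*n + 2)]
    alone = [ [] for _ in range(2*n + 2)]
    for i in range(n):
        l, r = map(int, input().split())
        if l == r: alone[l].append(i)
        else: ends[l].append((r, i))
    # comp = current component, mx = its maximum right endpoint
    ans, comp, mx = [0]*n, [], 1
    # Position 2n + 1 lies past every endpoint, so the last component is flushed there.
    for i in range(1, 2*n + 2):
        if mx <= i:
            # Nothing processed so far extends past i: close the current component.
            for u in comp: ans[u] = len(comp)
            # Singletons at i are not strictly inside any interval, so they are isolated.
            for u in alone[i]: ans[u] = 1
            comp = []
        else:
            # Some earlier interval contains i strictly inside, so singletons at i join it.
            for u in alone[i]: comp.append(u)
        # Longer intervals starting at i join the current (possibly new) component.
        for r, u in ends[i]:
            mx = max(mx, r)
            comp.append(u)
    print(*ans)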

Can someone help me find the mistake? Is my implementation of the segment tree wrong? For each index i, I’m just counting the elements whose A < R and B > L, where L is arr[i][0] and R is arr[i][1].


That’s not enough, because being ratist can also spread indirectly. For example, on the input

1
3
1 3
4 6
2 5

the answer is 3 3 3: even though [1, 3] and [4, 6] don’t intersect, if one of them is ratist, the other one can be made ratist via [2, 5].
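A small sketch comparing the two quantities on this input. It assumes the count described in the question is "all j (including i itself) with A_j < B_i and B_j > A_i"; that gives 2 2 3 here, while the connected-component sizes, computed with a tiny union-find (names are illustrative), give 3 3 3.

```python
intervals = [(1, 3), (4, 6), (2, 5)]
n = len(intervals)


def linked(i, j):
    """Direct condition from the editorial: A_j < B_i and A_i < B_j."""
    return intervals[j][0] < intervals[i][1] and intervals[i][0] < intervals[j][1]


# The count described in the question: all j (including i) with A_j < B_i and B_j > A_i.
direct = [sum(linked(i, j) for j in range(n)) for i in range(n)]

# Component sizes via a tiny union-find (path halving, no union by size).
parent = list(range(n))


def find(u):
    while parent[u] != u:
        parent[u] = parent[parent[u]]
        u = parent[u]
    return u


for i in range(n):
    for j in range(i + 1, n):
        if linked(i, j):
            parent[find(i)] = find(j)

sizes = [sum(find(v) == find(u) for v in range(n)) for u in range(n)]
print(direct, sizes)  # [2, 2, 3] [3, 3, 3]
```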
