KTHGRID - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3

Author: Alex Danilyuk
Tester: Radostin Chonev
Editorialist: Alex Danilyuk

DIFFICULTY:

Medium-Hard

PREREQUISITES:

Prefix sums, Sqrt-heuristics
Might be useful for some solutions: Fenwick tree (BIT), 2D Fenwick tree, Parallel Binary Search, Persistent Segment Tree, Wavelet matrix (I have no idea what that is)

PROBLEM:

Given an N \times N grid A_{rc} whose entries form a permutation of the numbers from 1 to N^2, process Q queries: in each query, find the k-th smallest value in the subgrid with corners (r_1, c_1) and (r_2, c_2).

QUICK EXPLANATION:

Split the values into blocks of size B, determine which block contains the answer (using one of many smart ideas), then go through the values of that block in increasing order until k of them fall inside the query rectangle.

EXPLANATION:

The problem has a deceptively simple statement and looks like it should be a textbook one. And it would be, if it were asked on an array instead of a grid.

Well-known solution for the one-dimensional problem

Complexity: \mathcal{O}(N \log N) precalc, \mathcal{O}(\log N) per query, online
We can binary search on the answer for each query; then we need a way to count the numbers \le X on the query segment. For that, we build a persistent segment tree over values (storing counts) and sweep the array from left to right, adding one element per version, so version i answers ‘how many numbers \le X are among the first i elements?’. The count on a segment is the difference of two prefixes, i.e. of two versions. That gives \mathcal{O}(\log^{2} N) per query. To achieve \mathcal{O}(\log N) per query, we replace the binary search with a descent down the segment tree; since we need 2 versions for the 2 prefixes, we descend in both of them in parallel.
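A minimal C++ sketch of the one-dimensional technique, assuming the array holds values in [0, V) (compress them first otherwise) and that ranges are 0-based and half-open. The struct name `PersistentKth` and its methods are hypothetical, not taken from the reference solution. `kth` walks down the two versions for the two prefixes at the same time, which is the parallel descent described above.

```cpp
#include <cassert>
#include <vector>

// Persistent segment tree over values [0, V); version i holds the first i array elements.
struct PersistentKth {
    std::vector<int> lc{0}, rc{0}, cnt{0};  // node 0 is the shared empty tree
    std::vector<int> roots{0};              // roots[i]: version after i insertions
    int V;

    PersistentKth(const std::vector<int>& a, int valueRange) : V(valueRange) {
        for (int x : a) roots.push_back(insert(roots.back(), 0, V, x));
    }

    // Copies the path from `node` down to leaf x and returns the new root.
    int insert(int node, int lo, int hi, int x) {
        int l = lc[node], r = rc[node], c = cnt[node];
        int id = (int)cnt.size();
        lc.push_back(l); rc.push_back(r); cnt.push_back(c + 1);
        if (hi - lo > 1) {
            int mid = (lo + hi) / 2;
            if (x < mid) { int child = insert(l, lo, mid, x); lc[id] = child; }
            else         { int child = insert(r, mid, hi, x); rc[id] = child; }
        }
        return id;
    }

    // k-th smallest (1-based k) among a[from..to): descend versions `to` and `from` together.
    int kth(int from, int to, int k) const {
        assert(1 <= k && k <= cnt[roots[to]] - cnt[roots[from]]);
        int plus = roots[to], minus = roots[from];
        int lo = 0, hi = V;
        while (hi - lo > 1) {
            int mid = (lo + hi) / 2;
            int leftCount = cnt[lc[plus]] - cnt[lc[minus]];  // values in [lo, mid) on the segment
            if (k <= leftCount) { plus = lc[plus]; minus = lc[minus]; hi = mid; }
            else { k -= leftCount; plus = rc[plus]; minus = rc[minus]; lo = mid; }
        }
        return lo;
    }
};
```

For example, with a = {3, 1, 4, 0, 2}, `kth(1, 4, 2)` looks at {1, 4, 0} and returns 1. Each insertion copies one root-to-leaf path, which is where the \mathcal{O}(N \log N) memory comes from.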

It is tempting to reuse our knowledge of the one-dimensional problem to do something similar in two dimensions. This temptation is a trap, though :)

Some big guns that don’t work

What stops us from doing the same trick with a persistent segment tree and parallel descent? Well, it is not clear what a ‘scanline’ means in two dimensions. We could try to obtain the versions of the persistent segment tree for all 2D prefixes we might be interested in, but I don’t see a good way to do that. Another idea is to use the fact that Q is small, so we don’t need to be logarithmic per query. We can build a persistent segment tree for each row separately, and then do a parallel descent in 2N versions (2 per row) per query. This actually works, and has a very nice complexity of \mathcal{O}(N^2 \log N) for precalc and \mathcal{O}(N \log N) per query, which is much better than most of the complexities we will see later in this editorial, and it works online! Seems like an overall great solution, so what’s the catch? Well, the constant is not great, but an even bigger issue is memory. Such persistent segment trees need \mathcal{O}(N^2 \log N) memory, and that won’t fly even with 1.5 GB of memory, even for the N \le 2500 subtask. This solution gets its rightful 35 points.

Another idea pros might have, when they want to binary search on the answer for each query but have to compute some hard-to-calculate sum on every iteration, is to do parallel binary search and calculate all the sums of one iteration at once. Alright, the query needed for the binary search is ‘what is the number of numbers \le X in the given rectangle?’. The reason we do parallel binary search is to answer a bunch of such queries offline: sort the numbers and turn them on one by one; then each query is just ‘sum in a rectangle’. 2D segment tree? No, 2D Fenwick tree! This solution works in \mathcal{O}(N^2 \log^{3} N + Q \log^{3} N) with a very small constant, and it also gets 35 points.
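A compact sketch of the parallel binary search with a 2D Fenwick tree, assuming 0-based values and queries given as {r1, c1, r2, c2, k} with half-open row and column ranges (`parallelKth` and this query layout are hypothetical). Each round turns the values on in increasing order and answers every query whose current midpoint has just been reached.

```cpp
#include <array>
#include <cassert>
#include <vector>

// 2D Fenwick tree counting switched-on cells; rows and columns are 0-based.
struct Fenwick2D {
    int n;
    std::vector<std::vector<int>> t;
    explicit Fenwick2D(int n) : n(n), t(n + 1, std::vector<int>(n + 1, 0)) {}
    void add(int r, int c) {
        for (int i = r + 1; i <= n; i += i & -i)
            for (int j = c + 1; j <= n; j += j & -j) t[i][j]++;
    }
    int prefix(int r, int c) const {  // cells in rows [0, r) x columns [0, c)
        int s = 0;
        for (int i = r; i > 0; i -= i & -i)
            for (int j = c; j > 0; j -= j & -j) s += t[i][j];
        return s;
    }
    int rect(int r1, int c1, int r2, int c2) const {
        return prefix(r2, c2) - prefix(r1, c2) - prefix(r2, c1) + prefix(r1, c1);
    }
};

// Query {r1, c1, r2, c2, k}: rows [r1, r2), columns [c1, c2), 1-based k.
std::vector<int> parallelKth(const std::vector<std::vector<int>>& g,
                             const std::vector<std::array<int, 5>>& qs) {
    int n = (int)g.size(), V = n * n, Q = (int)qs.size();
    std::vector<int> cellOf(V);  // value -> row * n + column
    for (int r = 0; r < n; r++)
        for (int c = 0; c < n; c++) cellOf[g[r][c]] = r * n + c;
    for (int i = 0; i < Q; i++) assert(qs[i][4] >= 1);
    std::vector<int> lo(Q, 0), hi(Q, V - 1);  // the answer of query i lies in [lo[i], hi[i]]
    while (true) {
        std::vector<std::vector<int>> askAt(V);  // queries whose current midpoint is v
        bool pending = false;
        for (int i = 0; i < Q; i++)
            if (lo[i] < hi[i]) { askAt[(lo[i] + hi[i]) / 2].push_back(i); pending = true; }
        if (!pending) break;
        Fenwick2D fw(n);
        for (int v = 0; v < V; v++) {
            fw.add(cellOf[v] / n, cellOf[v] % n);  // turn value v on
            for (int i : askAt[v]) {               // fw now holds exactly the values <= v
                const auto& q = qs[i];
                if (fw.rect(q[0], q[1], q[2], q[3]) >= q[4]) hi[i] = v; else lo[i] = v + 1;
            }
        }
    }
    return lo;
}
```

There are about \log_2(N^2) rounds, each doing N^2 Fenwick updates at \mathcal{O}(\log^2 N), which gives the \mathcal{O}(N^2 \log^3 N) term.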

One can see that this solution doesn’t exploit the fact that Q is small, and it is rather wasteful. Instead of a 2D Fenwick tree, we can keep a separate Fenwick tree for each row. Yes, you will need to go over all the rows of the rectangle to get the sum, but an update is now \mathcal{O}(\log N), and this makes the solution \mathcal{O}(N^2 \log^{2} N + QN \log^{2} N), which might be enough for 55 points depending on the implementation.
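The per-row replacement could look like this (hypothetical `RowFenwicks`, 0-based cells, half-open ranges). It offers the same `add`/`rect` pair as a 2D Fenwick tree, so it can replace one inside the parallel binary search without other changes; updates get cheaper and rectangle counts become \mathcal{O}(N \log N).

```cpp
#include <cassert>
#include <vector>

struct RowFenwicks {
    int n;
    std::vector<std::vector<int>> t;  // t[r]: Fenwick tree over the columns of row r
    explicit RowFenwicks(int n) : n(n), t(n, std::vector<int>(n + 1, 0)) {}
    void add(int r, int c) {  // O(log N): touches a single row
        assert(0 <= r && r < n && 0 <= c && c < n);
        for (int j = c + 1; j <= n; j += j & -j) t[r][j]++;
    }
    int prefix(int r, int c) const {  // switched-on cells in row r, columns [0, c)
        int s = 0;
        for (int j = c; j > 0; j -= j & -j) s += t[r][j];
        return s;
    }
    int rect(int r1, int c1, int r2, int c2) const {  // O(N log N): two lookups per row
        int s = 0;
        for (int r = r1; r < r2; r++) s += prefix(r, c2) - prefix(r, c1);
        return s;
    }
};
```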

Work smart, not hard

Instead of using some hardcore data structures, let’s look at the problem from a different angle: let’s use a sqrt-heuristic on values. Split the values into blocks of size B, determine which block the answer is in by calculating, for some t, the number of values less than tB in the given rectangle, and then add the numbers of that block one by one in increasing order to find out when we have k of them. Wait, we still need to count the numbers less than something in the rectangle, how is that better than before? Well, now the ‘something’ in ‘less than something’ can take only N^2 / B values, which is pretty cool. Q is small, so we can make B big, which leaves a really small number of possible thresholds. So small, in fact, that it is OK to just calculate 2D prefix sums for all of them, and even store them! That is \mathcal{O}(N^4 / B) time and memory for precalc and \mathcal{O}(N^2 / B + B) per query, which can be easily improved to \mathcal{O}(\log(N^2 / B) + B) by binary searching for the block, since the count is non-decreasing in t. Seems insane, but it is enough to get 55 points. And it works online!
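Here is a sketch of this ‘store everything’ version, assuming 0-based values in [0, N^2) and half-open ranges; `BlockKth` is a hypothetical name. `pre[t]` holds the 2D prefix sums of the indicator ‘value < tB’, the block is found by binary search over t, and the answer is then found by scanning that block.

```cpp
#include <cassert>
#include <vector>

struct BlockKth {
    int n, B, V;
    std::vector<std::vector<std::vector<int>>> pre;  // pre[t]: 2D prefix sums of [value < t * B]
    std::vector<int> cellOf;                         // value -> row * n + column

    BlockKth(const std::vector<std::vector<int>>& g, int blockSize)
        : n((int)g.size()), B(blockSize), V(n * n), cellOf(V) {
        for (int r = 0; r < n; r++)
            for (int c = 0; c < n; c++) cellOf[g[r][c]] = r * n + c;
        int T = (V + B - 1) / B;  // thresholds 0, B, ..., (T - 1) * B
        pre.assign(T, std::vector<std::vector<int>>(n + 1, std::vector<int>(n + 1, 0)));
        for (int t = 0; t < T; t++)  // O(N^4 / B) time and memory in total
            for (int r = 0; r < n; r++)
                for (int c = 0; c < n; c++)
                    pre[t][r + 1][c + 1] = pre[t][r][c + 1] + pre[t][r + 1][c] - pre[t][r][c] + (g[r][c] < t * B);
    }

    int count(int t, int r1, int c1, int r2, int c2) const {  // values < t * B in the rectangle
        return pre[t][r2][c2] - pre[t][r1][c2] - pre[t][r2][c1] + pre[t][r1][c1];
    }

    // k-th smallest (1-based) in rows [r1, r2), columns [c1, c2).
    int kth(int r1, int c1, int r2, int c2, int k) const {
        assert(1 <= k && k <= (r2 - r1) * (c2 - c1));
        int lo = 0, hi = (int)pre.size();  // count(lo) < k; count(hi) >= k or hi is past the end
        while (hi - lo > 1) {
            int mid = (lo + hi) / 2;
            if (count(mid, r1, c1, r2, c2) < k) lo = mid; else hi = mid;
        }
        k -= count(lo, r1, c1, r2, c2);
        for (int v = lo * B;; v++) {  // scan block lo in increasing order of values
            int r = cellOf[v] / n, c = cellOf[v] % n;
            if (r1 <= r && r < r2 && c1 <= c && c < c2 && --k == 0) return v;
        }
    }
};
```

Each threshold costs an (N+1) \times (N+1) array, so this is only viable while N^2 / B stays small, which is the memory problem addressed next.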
For complexity purists: choosing B = N^2 / \sqrt{Q} gives \mathcal{O}((N^2 + Q) \sqrt{Q}).
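Up to constant factors, the cost as a function of B and the balancing step look like this (the last term of T(B) is the binary search over blocks):

```latex
T(B) = \frac{N^4}{B} + QB + Q\log\frac{N^2}{B},
\qquad \frac{N^4}{B} = QB \iff B = \frac{N^2}{\sqrt{Q}},
\qquad T\!\left(\frac{N^2}{\sqrt{Q}}\right) = 2N^2\sqrt{Q} + \frac{Q}{2}\log Q = \mathcal{O}\big((N^2 + Q)\sqrt{Q}\big).
```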

To get at least 95 points, we have to think about memory again. We need to either store less for each block or… not store it at all! Why store the 2D prefix sums for each block when we can answer the queries offline? Remember parallel binary search? Well, let’s do a ‘parallel sqrt-heuristic’: go through the blocks in increasing order, keeping only the 2D prefix sums for the current threshold, and check every query whose block is not known yet. The time complexity is still the same, but we only need \mathcal{O}(N^2 + Q) memory now. Implemented well, this can get 95 and even 98 points (we were not able to get 100 with it, but it might be possible).
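A sketch of the offline ‘parallel sqrt-heuristic’, assuming a hypothetical {r1, c1, r2, c2, k} query layout with 0-based values and half-open ranges (`offlineBlockKth` is a made-up name). Thresholds are processed in increasing order with a single reusable prefix-sum array; a query is settled at the first threshold whose count reaches k, and its count for the previous threshold is kept so the final scan knows how many values to skip.

```cpp
#include <array>
#include <cassert>
#include <vector>

std::vector<int> offlineBlockKth(const std::vector<std::vector<int>>& g,
                                 const std::vector<std::array<int, 5>>& qs, int B) {
    assert(B >= 1);
    int n = (int)g.size(), V = n * n, Q = (int)qs.size();
    std::vector<int> cellOf(V);  // value -> row * n + column
    for (int r = 0; r < n; r++)
        for (int c = 0; c < n; c++) cellOf[g[r][c]] = r * n + c;
    std::vector<int> block(Q, -1), below(Q, 0);  // block of the answer; count of values < block * B
    std::vector<std::vector<int>> pre(n + 1, std::vector<int>(n + 1, 0));  // the only prefix array
    int T = (V + B - 1) / B;
    for (int t = 1; t <= T; t++) {
        for (int r = 0; r < n; r++)  // rebuild the 2D prefix sums of [value < t * B]
            for (int c = 0; c < n; c++)
                pre[r + 1][c + 1] = pre[r][c + 1] + pre[r + 1][c] - pre[r][c] + (g[r][c] < t * B);
        for (int i = 0; i < Q; i++) {
            if (block[i] != -1) continue;  // block already known
            const auto& q = qs[i];
            int cnt = pre[q[2]][q[3]] - pre[q[0]][q[3]] - pre[q[2]][q[1]] + pre[q[0]][q[1]];
            if (cnt >= q[4]) block[i] = t - 1; else below[i] = cnt;
        }
    }
    std::vector<int> ans(Q);
    for (int i = 0; i < Q; i++) {  // scan each query's block in increasing order of values
        const auto& q = qs[i];
        int k = q[4] - below[i];
        for (int v = block[i] * B;; v++) {
            int r = cellOf[v] / n, c = cellOf[v] % n;
            if (q[0] <= r && r < q[2] && q[1] <= c && c < q[3] && --k == 0) { ans[i] = v; break; }
        }
    }
    return ans;
}
```

Only `pre`, `cellOf` and two integers per query are kept, which is the \mathcal{O}(N^2 + Q) memory; the time is still \mathcal{O}(N^4 / B) for the rebuilds, plus \mathcal{O}(N^2 / B) checks and an \mathcal{O}(B) scan per query.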

Yo dawg, I heard you like sqrt-heuristics…

What did I say about ‘store less for each block’?
But let’s take a look at our complexity first. It consists of 3 parts: precalc + finding the block with the answer + finding the answer inside the block.
Right now it’s \mathcal{O}(N^4 / B) / \mathcal{O}(Q \log(N^2 / B)) / \mathcal{O}(QB). The third part is the staple of our solution; we won’t change it. The second part is very fast compared to the other two. Maybe we can make it slower, but decrease memory usage?
The idea is to coarsen the grid: split each side into segments of length K, thus covering the grid with a sparser grid of K \times K cells. Let’s store 2D prefix sums for this sparser grid instead: the memory (and time) usage becomes \mathcal{O}((N/K)^2 (N^2 / B)) = \mathcal{O}(N^4 / (BK^2)).
Counting in a rectangle is harder now, though: we can easily get the count in a ‘main rectangle’, one whose sides lie on lines of the sparser grid, but not in an arbitrary one. However, if we snap each side of the query rectangle to the nearest line of the sparser grid, the query differs from that main rectangle in only \mathcal{O}(NK) cells: at most 4N\frac{K-1}{2} for odd K. We can write these cells down, sort them, and then count them separately at each step of the binary search (with another binary search).
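The snapping argument can be checked with a small sketch. It counts the ones of a fixed 0/1 grid (think ‘value < tB’) in an arbitrary rectangle using only sparser-grid prefix sums plus individually counted boundary strips, following the same inclusion-exclusion as the `up`/`down`/`lft`/`rgt` branches of the code below. `CoarseCounter` is hypothetical, K is assumed odd, and the strips are summed by brute force here, whereas the real solution counts them with binary searches.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

struct CoarseCounter {
    int n, K, m;                           // K: side of a sparser cell (odd); m: sparser grid side
    std::vector<std::vector<int>> a, pre;  // a: fine 0/1 grid; pre: sparser-grid 2D prefix sums

    CoarseCounter(const std::vector<std::vector<int>>& grid, int cellSize)
        : n((int)grid.size()), K(cellSize), m((n + K - 1) / K), a(grid),
          pre(m + 1, std::vector<int>(m + 1, 0)) {
        assert(K % 2 == 1);
        std::vector<std::vector<int>> cell(m, std::vector<int>(m, 0));
        for (int x = 0; x < n; x++)
            for (int y = 0; y < n; y++) cell[x / K][y / K] += a[x][y];
        for (int i = 0; i < m; i++)
            for (int j = 0; j < m; j++)
                pre[i + 1][j + 1] = pre[i][j + 1] + pre[i + 1][j] - pre[i][j] + cell[i][j];
    }

    // Brute-force sum over rows [x1, x2) x columns [y1, y2); 0 if either range is empty.
    int strip(int x1, int x2, int y1, int y2) const {
        int s = 0;
        for (int x = x1; x < x2; x++)
            for (int y = y1; y < y2; y++) s += a[x][y];
        return s;
    }

    // Sum over rows [x1, x2) x columns [y1, y2).
    int count(int x1, int y1, int x2, int y2) const {
        if (x2 - x1 < K || y2 - y1 < K) return strip(x1, x2, y1, y2);  // thin: fewer than K lines
        int h = K / 2;  // = (K - 1) / 2: snapping moves a side by at most h lines
        int i1 = (x1 + h) / K, i2 = (x2 + h) / K, j1 = (y1 + h) / K, j2 = (y2 + h) / K;
        int X1 = std::min(i1 * K, n), X2 = std::min(i2 * K, n);
        int Y1 = std::min(j1 * K, n), Y2 = std::min(j2 * K, n);
        int res = pre[i2][j2] - pre[i1][j2] - pre[i2][j1] + pre[i1][j1];  // main rectangle
        // rows: add the query's extra rows (full query width), remove the main rectangle's extra rows
        res += strip(x1, X1, y1, y2) - strip(X1, x1, Y1, Y2);
        res += strip(X2, x2, y1, y2) - strip(x2, X2, Y1, Y2);
        // columns, only on the rows that both rectangles now share
        int lo = std::max(x1, X1), hi = std::min(x2, X2);
        res += strip(lo, hi, y1, Y1) - strip(lo, hi, Y1, y1);
        res += strip(lo, hi, Y2, y2) - strip(lo, hi, y2, Y2);
        return res;
    }
};
```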
Now the complexities for three parts are \mathcal{O}(N^4 / (BK^2)) / \mathcal{O}(Q(NK \log(NK) + \log(N^2 / B) \log(NK))) / \mathcal{O}(QB). This should get 95 points without big troubles.
Suddenly the second part is the slowest. Did I say ‘sort them’? No-no-no: let’s keep the values of every row and every column presorted, write the boundary cells down for each row and column separately, and count them with binary searches, also separately. That decreases the complexity of the second part to \mathcal{O}(Q(NK + K \log(N^2 / B) \log(NK))) and gets us 100 points. And this solution is fully online!
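A sketch of one boundary line’s helper, mirroring what `ordForRow` and `addLines` do in the code below (the `SortedLine` name and interface are hypothetical). The values of a row or column are sorted once; per query, one \mathcal{O}(N) pass marks which of them lie inside the relevant range, and after that each binary-search step needs just one `lower_bound` per line.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

struct SortedLine {
    std::vector<int> vals, pos;  // the line's values in increasing order, and where each one sits

    // Precomputation, once per row or column: O(N log N).
    explicit SortedLine(const std::vector<int>& line) {
        std::vector<int> idx(line.size());
        for (size_t i = 0; i < idx.size(); i++) idx[i] = (int)i;
        std::sort(idx.begin(), idx.end(), [&](int a, int b) { return line[a] < line[b]; });
        for (int i : idx) { vals.push_back(line[i]); pos.push_back(i); }
    }

    // Per query, O(N): prefix[i] = how many of the i smallest values sit in positions [lo, hi).
    std::vector<int> prefixIn(int lo, int hi) const {
        std::vector<int> prefix(vals.size() + 1, 0);
        for (size_t i = 0; i < vals.size(); i++)
            prefix[i + 1] = prefix[i] + (lo <= pos[i] && pos[i] < hi);
        return prefix;
    }

    // Per binary-search step, O(log N): how many values < X sit in positions [lo, hi).
    int countBelow(const std::vector<int>& prefix, int X) const {
        assert(prefix.size() == vals.size() + 1);
        return prefix[std::lower_bound(vals.begin(), vals.end(), X) - vals.begin()];
    }
};
```

For the line {5, 2, 9, 0, 7} restricted to positions [1, 4), `countBelow(prefixIn(1, 4), 3)` is 2 (the values 2 and 0).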
Another remark for complexity purists: choosing B = N^{3/2}Q^{-1/4}, K = N^{1/2}Q^{-1/4} gives us complexity \mathcal{O}(N^{3/2}Q^{3/4} + N^2 + N^{1/2}Q^{3/4} \log^2 N) which is… uh… cool?
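Where these exponents come from, assuming both logarithms are \mathcal{O}(\log N): balance the precalc \mathcal{O}(N^4/(BK^2)), the boundary collection \mathcal{O}(QNK) and the in-block scan \mathcal{O}(QB); the N^2 term is just generating the grid.

```latex
\frac{N^4}{BK^2} = QNK = QB
\;\Longrightarrow\; B = NK,\quad \frac{N^3}{K^3} = QNK
\;\Longrightarrow\; K = N^{1/2}Q^{-1/4},\quad B = N^{3/2}Q^{-1/4},
\qquad QB = N^{3/2}Q^{3/4},\quad QK\log^2 N = N^{1/2}Q^{3/4}\log^2 N.
```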

There are several other solutions, but I would like to post the editorial now (it is long overdue already) and add them later.

CODE:

sqrt in sqrt (C++)
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <cmath>
#include <vector>
#include <set>
#include <map>
#include <unordered_set>
#include <unordered_map>
#include <queue>
#include <ctime>
#include <cassert>
#include <complex>
#include <string>
#include <cstring>
#include <chrono>
#include <random>
#include <bitset>
#include <array>
using namespace std;

#ifdef LOCAL
	#define eprintf(...) {fprintf(stderr, __VA_ARGS__);fflush(stderr);}
#else
	#define eprintf(...) 42
#endif

using ll = long long;
using ld = long double;
using uint = unsigned int;
using ull = unsigned long long;
template<typename T>
using pair2 = pair<T, T>;
using pii = pair<int, int>;
using pli = pair<ll, int>;
using pll = pair<ll, ll>;
mt19937_64 rng(chrono::steady_clock::now().time_since_epoch().count());
ll myRand(ll B) {
	return (ull)rng() % B;
}

#define mp make_pair
#define all(x) (x).begin(),(x).end()

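const int N = 5001;        // maximal grid side + 1
const int S = N * N;       // maximal number of cells (and values)
const int K = 3;           // snapping a query side moves it by at most K lines
const int KK = 2 * K + 1;  // side of a sparser-grid cell (the editorial's K)
const int B = (int)1e5;    // size of a block of values
const int Z = N / KK + 2;  // sparser-grid side, with room for the prefix-sum border
const int C = S / B + 1;   // number of value blocks
int p[S];                  // p[cell] = 0-based value stored in that cell
int ord[S];                // ord[value] = cell (x * n + y) holding that value
int curPref[Z][Z];         // values seen so far, per sparser-grid cell
int mem[C][Z][Z];          // mem[t]: sparser-grid 2D prefix sums of the values < t * B
int ordForRow[2 * N][N];   // values of row x (index x) and column y (index n + y), increasing
int szForRow[2 * N];
int n, q;
int m;                     // number of boundary lines used by the current query
int addId[4 * K];          // row/column index of each boundary line
int addLines[4 * K][N];    // signed prefix counts along each boundary line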


uint64_t state;
uint64_t nextRand() {
    state ^= state << 13;
    state ^= state >> 7;
    state ^= state << 17;
    return state >> 10;
}
int nextBounded(int B) {
	return nextRand() % B;
}
int minR, maxR, minC, maxC;
void genGrid() {
    for (int i = 0; i < n * n; i++)
        p[i] = i;
    for (int i = 0; i < n * n; i++)
        swap(p[i], p[nextBounded(i + 1)]);
    for (int i = 0; i < n * n; i++)
    	ord[p[i]] = i;
}
array<int, 5> genQuery() {
    int lenR = minR + nextBounded(maxR - minR + 1);
    int lenC = minC + nextBounded(maxC - minC + 1);
    int r1 = 1 + nextBounded(n - lenR + 1);
    int c1 = 1 + nextBounded(n - lenC + 1);
    int r2 = r1 + lenR - 1;
    int c2 = c1 + lenC - 1;
    int k = 1 + nextBounded(lenR * lenC);
    return {r1, c1, r2, c2, k};
}

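// mem[t] gets the sparser-grid prefix sums of all values < t * B; ordForRow gets every
// row and column listed in increasing order of values.
void precalc() {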
	for (int i = 0; i < 2 * n; i++)
		szForRow[i] = 0;
	int m = (n + KK - 1) / KK;
	for (int i = 0; i < n * n; i++) {
		if (i % B == 0) {
			int p = i / B;
			for (int x = 0; x < m; x++)
				for (int y = 0; y < m; y++)
					mem[p][x + 1][y + 1] = mem[p][x][y + 1] + mem[p][x + 1][y] - mem[p][x][y] + curPref[x][y];
		}
		int x = ord[i] / n, y = ord[i] % n;
		curPref[x / KK][y / KK]++;
		ordForRow[x][szForRow[x]++] = i;
		ordForRow[n + y][szForRow[n + y]++] = i;
	}
}

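// Number of values < p * B in the query rectangle: the main rectangle with sparser-grid
// corners (x1, y1), (x2, y2) plus the signed counts of the m boundary lines.
int calcVal(int p, int x1, int x2, int y1, int y2) {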
	int res = mem[p][x1][y1] + mem[p][x2][y2] - mem[p][x1][y2] - mem[p][x2][y1];
	for (int i = 0; i < m; i++) {
		int pos = lower_bound(ordForRow[addId[i]], ordForRow[addId[i]] + n, p * B) - ordForRow[addId[i]];
		res += addLines[i][pos];
	}
	return res;
}
int solve() {
	auto qqq = genQuery();
	// 0-based, half-open: rows [x1, x2), columns [y1, y2)
	int x1 = qqq[0] - 1, y1 = qqq[1] - 1, x2 = qqq[2], y2 = qqq[3], k = qqq[4];
	m = 0;
	int xx1 = 0, yy1 = 0, xx2 = 0, yy2 = 0;
	// thin query (fewer than KK rows): every row is a boundary line, the main rectangle is empty
	if (x2 - x1 < KK) {
		for (int x = x1; x < x2; x++) {
			addId[m] = x;
			addLines[m][0] = 0;
			for (int i = 0; i < n; i++) {
				addLines[m][i + 1] = addLines[m][i];
				int p = ord[ordForRow[x][i]];
				int y = p % n;
				if (y1 <= y && y < y2)
					addLines[m][i + 1]++;
			}
			m++;
		}
	} else if (y2 - y1 < KK) {
		for (int y = y1; y < y2; y++) {
			addId[m] = n + y;
			addLines[m][0] = 0;
			for (int i = 0; i < n; i++) {
				addLines[m][i + 1] = addLines[m][i];
				int p = ord[ordForRow[n + y][i]];
				int x = p / n;
				if (x1 <= x && x < x2)
					addLines[m][i + 1]++;
			}
			m++;
		}
	} else {
		// snap every side to the nearest sparser-grid line; it moves by at most K
		xx1 = (x1 + K) / KK;
		xx2 = (x2 + K) / KK;
		yy1 = (y1 + K) / KK;
		yy2 = (y2 + K) / KK;
		int xxx1 = min(xx1 * KK, n), xxx2 = min(xx2 * KK, n), yyy1 = min(yy1 * KK, n), yyy2 = min(yy2 * KK, n);
		bool up = xxx1 > x1;
		bool down = xxx2 < x2;
		bool lft = yyy1 > y1;
		bool rgt = yyy2 < y2;
		if (up) {
			for (int x = x1; x < xxx1; x++) {
				addId[m] = x;
				addLines[m][0] = 0;
				int l = y1, r = y2;
				for (int i = 0; i < n; i++) {
					addLines[m][i + 1] = addLines[m][i];
					int p = ord[ordForRow[x][i]];
					int y = p % n;
					if (l <= y && y < r)
						addLines[m][i + 1]++;
				}
				m++;
			}
		} else {
			for (int x = xxx1; x < x1; x++) {
				addId[m] = x;
				addLines[m][0] = 0;
				int l = yyy1, r = yyy2;
				for (int i = 0; i < n; i++) {
					addLines[m][i + 1] = addLines[m][i];
					int p = ord[ordForRow[x][i]];
					int y = p % n;
					if (l <= y && y < r)
						addLines[m][i + 1]--;
				}
				m++;
			}
		}
		if (down) {
			for (int x = xxx2; x < x2; x++) {
				addId[m] = x;
				addLines[m][0] = 0;
				int l = y1, r = y2;
				for (int i = 0; i < n; i++) {
					addLines[m][i + 1] = addLines[m][i];
					int p = ord[ordForRow[x][i]];
					int y = p % n;
					if (l <= y && y < r)
						addLines[m][i + 1]++;
				}
				m++;
			}
		} else {
			for (int x = x2; x < xxx2; x++) {
				addId[m] = x;
				addLines[m][0] = 0;
				int l = yyy1, r = yyy2;
				for (int i = 0; i < n; i++) {
					addLines[m][i + 1] = addLines[m][i];
					int p = ord[ordForRow[x][i]];
					int y = p % n;
					if (l <= y && y < r)
						addLines[m][i + 1]--;
				}
				m++;
			}
		}
		if (lft) {
			for (int y = y1; y < yyy1; y++) {
				addId[m] = n + y;
				addLines[m][0] = 0;
				int l = (up ? xxx1 : x1), r = (down ? xxx2 : x2);
				for (int i = 0; i < n; i++) {
					addLines[m][i + 1] = addLines[m][i];
					int p = ord[ordForRow[n + y][i]];
					int x = p / n;
					if (l <= x && x < r)
						addLines[m][i + 1]++;
				}
				m++;
			}
		} else {
			for (int y = yyy1; y < y1; y++) {
				addId[m] = n + y;
				addLines[m][0] = 0;
				int l = (up ? xxx1 : x1), r = (down ? xxx2 : x2);
				for (int i = 0; i < n; i++) {
					addLines[m][i + 1] = addLines[m][i];
					int p = ord[ordForRow[n + y][i]];
					int x = p / n;
					if (l <= x && x < r)
						addLines[m][i + 1]--;
				}
				m++;
			}
		}
		if (rgt) {
			for (int y = yyy2; y < y2; y++) {
				addId[m] = n + y;
				addLines[m][0] = 0;
				int l = (up ? xxx1 : x1), r = (down ? xxx2 : x2);
				for (int i = 0; i < n; i++) {
					addLines[m][i + 1] = addLines[m][i];
					int p = ord[ordForRow[n + y][i]];
					int x = p / n;
					if (l <= x && x < r)
						addLines[m][i + 1]++;
				}
				m++;
			}
		} else {
			for (int y = y2; y < yyy2; y++) {
				addId[m] = n + y;
				addLines[m][0] = 0;
				int l = (up ? xxx1 : x1), r = (down ? xxx2 : x2);
				for (int i = 0; i < n; i++) {
					addLines[m][i + 1] = addLines[m][i];
					int p = ord[ordForRow[n + y][i]];
					int x = p / n;
					if (l <= x && x < r)
						addLines[m][i + 1]--;
				}
				m++;
			}
		}
	}
	// binary search for the last block l with fewer than k values < l * B in the rectangle
	int l = 0, r = (n * n - 1) / B + 1;
	while(r - l > 1) {
		int mid = (l + r) / 2;
		if (calcVal(mid, xx1, xx2, yy1, yy2) < k)
			l = mid;
		else
			r = mid;
	}
	k -= calcVal(l, xx1, xx2, yy1, yy2);
	// scan block l in increasing order; values are 0-based, so pos ends at the 1-based answer
	int pos;
	for (pos = l * B; k > 0; pos++) {
		int p = ord[pos];
		int x = p / n, y = p % n;
		if (x1 <= x && x < x2 && y1 <= y && y < y2) k--;
	}
	return pos;
}

int main()
{
	cin >> n >> q >> state >> minR >> maxR >> minC >> maxC;

	genGrid();

	precalc();
	while(q--) {
		printf("%d\n", solve());
	}

	return 0;
}