TREEQUERIES - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: Vishesh Saraswat
Testers: Nishank Suresh, Satyam
Editorialist: Nishank Suresh

DIFFICULTY:

3242

PREREQUISITES:

Binary lifting

PROBLEM:

You are given a tree and Q queries on it. Each query is of the form (u, K), where you need to find the K-th vertex encountered if a DFS is started from u.
Queries must be answered online.

EXPLANATION:

There’s an obvious \mathcal{O}(N) solution for each query, where you perform the DFS directly. This is of course too slow, and there’s no obvious way to optimize it.
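For reference, the brute force might look like the sketch below. It assumes vertices are labelled 1..N, that adjacency lists are sorted so neighbours are visited in increasing order, and that K is 1-indexed with K = 1 meaning u itself. The names `dfsOrder` and `kthByDfs` are ours, not from the problem.

```cpp
#include <cassert>
#include <vector>

// Appends the vertices reachable from v (without going back to `parent`)
// to `order`, in the order a DFS visits them.
void dfsOrder(const std::vector<std::vector<int>>& adj, int v, int parent,
              std::vector<int>& order) {
    order.push_back(v);
    for (int w : adj[v])
        if (w != parent) dfsOrder(adj, w, v, order);
}

// Hypothetical brute force: the K-th vertex (1-indexed, K = 1 is u itself)
// visited by a DFS started at u. Costs O(N) per query.
int kthByDfs(const std::vector<std::vector<int>>& adj, int u, int K) {
    std::vector<int> order;
    dfsOrder(adj, u, 0, order);  // 0 is not a vertex label, so nothing is skipped at u
    assert(1 <= K && K <= (int)order.size());
    return order[K - 1];
}
```

Something like this is mostly useful as a checker to stress-test a faster solution on small trees.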

Instead, let’s do something else.
Let’s root the tree at vertex 1 and run a DFS from it, recording for every vertex v its entry time tin[v], i.e. the number of vertices visited before v. This precomputes the answer to every query of the form (1, K). From now on, we assume the tree to be rooted at 1 whenever necessary, for example when talking about subtrees/children/ancestors/etc.

Let s_v denote the size of the subtree rooted at vertex v.
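A minimal sketch of this preprocessing, assuming 1-indexed vertices and that neighbours are visited in increasing order of label (as the sorted adjacency lists in the solutions below suggest). The struct and function names are hypothetical, and the DFS is written iteratively only to stay clear of deep recursion.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

struct RootedTree {
    std::vector<int> tin;  // tin[v]: 0-indexed entry time of v
    std::vector<int> ord;  // ord[t]: vertex with entry time t
    std::vector<int> sz;   // sz[v]: size of the subtree of v
    std::vector<int> par;  // par[v]: parent of v, 0 for the root
};

// Roots the tree at vertex 1. adj[0] is unused.
RootedTree rootAtOne(std::vector<std::vector<int>> adj) {
    int n = (int)adj.size() - 1;
    assert(n >= 1);
    for (auto& nb : adj) std::sort(nb.begin(), nb.end());

    RootedTree t;
    t.tin.assign(n + 1, 0);
    t.sz.assign(n + 1, 1);
    t.par.assign(n + 1, 0);

    std::vector<std::size_t> next(n + 1, 0);  // next neighbour index to try
    std::vector<int> stack = {1};
    t.ord.push_back(1);
    while (!stack.empty()) {
        int v = stack.back();
        if (next[v] == adj[v].size()) {       // v is finished
            stack.pop_back();
            if (t.par[v] != 0) t.sz[t.par[v]] += t.sz[v];
            continue;
        }
        int w = adj[v][next[v]++];
        if (w == t.par[v]) continue;          // skip the edge back to the parent
        t.par[w] = v;
        t.tin[w] = (int)t.ord.size();
        t.ord.push_back(w);
        stack.push_back(w);
    }
    return t;
}
```

With these arrays, a query (1, K) with 1-indexed K is just `ord[K - 1]`.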

Now, let’s see what happens when we have a query (u, K).
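For convenience, let the neighbours of u in increasing order be c_1 \lt c_2 \lt \ldots \lt c_x \lt p \lt c_{x+1} \lt \ldots \lt c_r, where the c_j are the children of u and p is its parent. This is the order in which a DFS started from u explores them. (If u is the root there is no p, and the first case below always applies.)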
Then,

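  • If K \leq 1 + s_{c_1} + s_{c_2} + \ldots + s_{c_x}, the DFS from u has only visited u and the subtrees of c_1, \ldots, c_x so far, in exactly the order the DFS from 1 visits them. So, the answer to this query is simply the answer to the query (1, K + tin[u]) instead (which has been precomputed already).
  • If K \gt 1 + s_{c_1} + \ldots + s_{c_x} + (N - s_u), the DFS has already finished everything on the parent's side and is back inside the subtree of u, visiting c_{x+1}, \ldots, c_r. Skipping the N - s_u outside vertices, the answer to this query is the answer to query (1, K+tin[u] - (N - s_u)) (once again, precomputed).
  • Otherwise, we have 1 + s_{c_1} + s_{c_2} + \ldots + s_{c_x} \lt K \leq 1 + s_{c_1} + \ldots + s_{c_x} + (N - s_u). In other words, the answer to this query doesn’t lie inside the subtree of u, and we have to look outside.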
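The three cases can be sketched as follows. This assumes K is 1-indexed, tin is 0-indexed, and the rooted-tree arrays are already available. All names here are hypothetical.

```cpp
#include <cassert>
#include <vector>

// Assumed precomputed data for the tree rooted at vertex 1:
//   adj[v]: neighbours in increasing order, tin[v]: 0-indexed entry time,
//   ord[t]: vertex with entry time t, sz[v]: subtree size,
//   par[v]: parent of v (0 for the root).
struct RootedTree {
    int n;
    std::vector<std::vector<int>> adj;
    std::vector<int> tin, ord, sz, par;
};

// Number of vertices a DFS from u visits before stepping to u's parent:
// 1 + s_{c_1} + ... + s_{c_x}. For the root this is the whole tree.
int visitedBeforeParent(const RootedTree& t, int u) {
    int cnt = 1;
    for (int c : t.adj[u]) {
        if (c == t.par[u]) break;
        cnt += t.sz[c];
    }
    return cnt;
}

// Returns the answer in cases 1 and 2, or 0 if the K-th vertex lies
// outside the subtree of u (case 3).
int answerInsideSubtree(const RootedTree& t, int u, int K) {
    assert(1 <= K && K <= t.n);
    int before = visitedBeforeParent(t, u);
    int outside = t.n - t.sz[u];
    if (K <= before) return t.ord[t.tin[u] + K - 1];                     // case 1
    if (K > before + outside) return t.ord[t.tin[u] + K - outside - 1];  // case 2
    return 0;                                                            // case 3
}
```

Returning 0 for the third case is just a convenience of this sketch, since 0 is not a valid vertex label here.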

To deal with the third case, one obvious way would be to simply move to the parent p of u, update K appropriately, and once again run this process.
Updating K can be done with a bit of casework, once again depending on the relative order of u as a child of p.
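One way to write a single such step is sketched below, under the same assumptions as before (1-indexed K, 0-indexed tin, sorted adjacency lists). The function name and signature are ours. It turns a third-case query (u, K) into an equivalent query (p, K') on the parent.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// One 'move to parent' step. Only valid when (u, K) is in the third case,
// i.e. the K-th vertex of the DFS from u lies outside the subtree of u.
std::pair<int, int> moveToParent(int n, const std::vector<std::vector<int>>& adj,
                                 const std::vector<int>& tin,
                                 const std::vector<int>& sz,
                                 const std::vector<int>& par, int u, int K) {
    int p = par[u];
    assert(p != 0);  // the root never reaches the third case

    // Vertices visited from u before stepping to p:
    // u itself plus the subtrees of the children listed before p.
    int before = 1;
    for (int c : adj[u]) {
        if (c == p) break;
        before += sz[c];
    }
    assert(before < K && K <= before + (n - sz[u]));

    // Position in a DFS from p that never enters the subtree of u.
    int k = K - before;

    // Number of vertices a DFS from p (on the full tree) visits before
    // entering u. If p's own parent comes before u, everything outside
    // the subtree of p is visited first as well.
    int beforeU = tin[u] - tin[p];
    if (par[p] != 0 && par[p] < u) beforeU += n - sz[p];

    // Re-insert the subtree of u if the target comes after it.
    if (k > beforeU) k += sz[u];
    return {p, k};
}
```

Repeating this step until the query lands in the first or second case gives a correct, but slow, solution.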

However, this can degenerate to \mathcal{O}(N) per query, since you might have to move up to a parent \mathcal{O}(N) times.

Notice that the only operation that really needs to be sped up is the ‘move to parent’ operation, where it’d be nice if we were able to move up multiple steps at the same time.

This is exactly what binary lifting accomplishes!

In fact, that’s pretty much the remainder of the solution: use binary lifting to maintain appropriate data so that each query can be answered in \mathcal{O}(\log N) time.
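As a reminder, plain binary lifting on ancestors alone can be sketched like this. It is the textbook version, not the full solution, and the names are hypothetical. The real solution stores extra values next to each jump.

```cpp
#include <cassert>
#include <vector>

// up[i][v] is the 2^i-th ancestor of v, or 0 if it does not exist.
// par[v] is the parent of v (0 for the root), and par[0] must be 0.
std::vector<std::vector<int>> buildLifting(const std::vector<int>& par) {
    int n = (int)par.size() - 1;
    assert(n >= 1 && par[0] == 0);
    int LOG = 1;
    while ((1 << LOG) < n) ++LOG;  // 2^LOG >= n exceeds any possible depth
    std::vector<std::vector<int>> up(LOG, std::vector<int>(n + 1, 0));
    up[0] = par;
    for (int i = 1; i < LOG; ++i)
        for (int v = 1; v <= n; ++v)
            up[i][v] = up[i - 1][up[i - 1][v]];  // two jumps of 2^(i-1)
    return up;
}

// k-th ancestor of v in O(log N), or 0 if v has fewer than k ancestors.
int kthAncestor(const std::vector<std::vector<int>>& up, int v, int k) {
    assert(k >= 0);
    int LOG = (int)up.size();
    if (k >= (1 << LOG)) return 0;  // deeper than any vertex can be
    for (int i = 0; i < LOG; ++i)
        if ((k >> i) & 1) v = up[i][v];
    return v;
}
```

For this problem, each jump of length 2^i additionally needs to know for which values of K it may be taken and how K changes across it.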

Unfortunately, the devil is in the details: the hard part here is maintaining ‘appropriate data’ across the lift.
A bunch of things need to be maintained so that a query can be answered properly, for example:

  • Ancestors of elements
  • Subtree sizes
  • The order in which children of a vertex are visited, and, for each child, the total number of other vertices visited before entering it when the DFS starts at its parent.
  • Left and right borders of vertices: the hard part above was dealing with K when it fell in the ‘middle’ range for u, so we maintain the left and right borders of this middle range for every jump length.

It all reduces to maintaining a few formulae in terms of these values, and then answering a query is a simple binary lift in \mathcal{O}(\log N) time where u and K are changed appropriately.
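For instance, the recurrence for the borders, as we read it off the two solutions below, can be written as follows. The notation is ours: positions are 0-indexed, P is the 2^i-th ancestor of u, q is the child of P on the path down to u, and [L_i(u), R_i(u)] is the range of positions in the DFS from u whose vertices lie outside the subtree of the (2^i - 1)-th ancestor of u.

```latex
% [cond] is 1 if cond holds and 0 otherwise; s_q is the subtree size of q.
\begin{align*}
\delta     &= [\,P \neq \text{root and } q < \mathrm{par}(P)\,] \cdot s_q \\
L_{i+1}(u) &= L_i(u) + L_i(P) - \delta \\
R_{i+1}(u) &= L_i(u) + R_i(P) - \delta
\end{align*}
```

The correction \delta is there because L_i(P) and R_i(P) are measured in a DFS from P on the full tree, which enters the subtree of q before going up exactly when q \lt \mathrm{par}(P), while the DFS arriving from u never enters q's subtree after reaching P. The L_i(u) term then shifts the range by the number of vertices already visited before reaching P.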

I recommend looking at the code linked below for how to implement this, if you are stuck.

TIME COMPLEXITY:

\mathcal{O}((N + Q) \log N) per test case.

CODE:

Setter's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int int64_t
#define sp << ' ' <<
#define nl << '\n'

const int Z = 3e5, B = 20;

// INPUT, 0-indexed
int N;
vector<int> g[Z];
// ................

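// tin[u]: entry time of u in the DFS from the root, e[x]: vertex with entry time x, sz[u]: subtree size
// t[v]: number of vertices visited before v in a DFS started at v's parent
// p[u][i]: 2^i-th ancestor of u, h[u][i]: whether that ancestor exists
// q[u][i]: child of p[u][i] on the path down to u
// [l[u][i], r[u][i]]: positions, in the DFS from u, of the vertices outside the subtree of q[u][i]
int dfsTimer, tin[Z], sz[Z], t[Z], e[Z], p[Z][B], q[Z][B], l[Z][B], r[Z][B];
bool h[Z][B];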

void dfs(int u) {
	e[tin[u] = dfsTimer++] = u;

	for(int i = 0; i + 1 < B; ++i) {
		p[u][i+1] = p[p[u][i]][i];
		q[u][i+1] = q[p[u][i]][i];
		h[u][i+1] = h[p[u][i]][i];
	}
	sort(begin(g[u]), end(g[u]));

	for(int v : g[u]) {
		if(v != p[u][0]) {
			p[v][0] = u;
			q[v][0] = v;
			h[v][0] = 1;
			dfs(v);
		} else
			l[u][0] = dfsTimer - tin[u];
	}
	sz[u] = dfsTimer - tin[u];

	int add {};
	for(int v : g[u]) {
		if(v != p[u][0]) t[v] = tin[v] - tin[u] + add;
		else add = N - sz[u];
	}

	r[u][0] = l[u][0] + (N - sz[u]) - 1;
}

int query(int u, int x) {
	for(int i = B; i--; ) if(h[u][i]) {
		if(l[u][i] <= x && x <= r[u][i]) {
			x -= l[u][i];
			if(t[q[u][i]] <= x) x += sz[q[u][i]];
			u = p[u][i];
		}
	}
	if(!u) return e[x];

	if(x < l[u][0])	return e[tin[u] + x];
	return e[tin[u] + x - (N - sz[u])];
}

signed main() {
	cin.tie(0)->sync_with_stdio(0);

	int T; cin >> T;
	while(T--) {
		cin >> N;
		for(int i = 0; i < N; ++i) {
			g[i].clear();
			fill(h[i], h[i] + B, 0);
		}
		dfsTimer = 0;

		for(int i = 1; i < N; ++i) {
			int u, v; cin >> u >> v;
			--u, --v;
			g[u].push_back(v);
			g[v].push_back(u);
		}

		dfs(0);

		for(int i = 0; i + 1 < B; ++i) {
			for(int u = 1; u < N; ++u) {
				int &lv = l[u][i+1] = l[p[u][i]][i];
				int &rv = r[u][i+1] = r[p[u][i]][i];

				if(p[u][i] && q[u][i] < p[p[u][i]][0]) {
					lv -= sz[q[u][i]];
					rv -= sz[q[u][i]];
				}

				lv += l[u][i];
				rv += l[u][i];
			}
		}

		int Q, last {}; cin >> Q;
		while(Q--) {
			int u, x; cin >> u >> x;
			cout << (last = query((u ^ last) - 1, x ^ last) + 1) nl;
		}
	}
}
Editorialist's code (C++)
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());


int main()
{
	ios::sync_with_stdio(false); cin.tie(0);

	const int LOG = 19;

	int t; cin >> t;
	while (t--) {
		int n; cin >> n;
		vector<vector<int>> adj(n+1);
		for (int i = 1; i < n; ++i) {
			int u, v; cin >> u >> v;
			adj[u].push_back(v);
			adj[v].push_back(u);
		}
		for (int i = 1; i <= n; ++i) sort(begin(adj[i]), end(adj[i]));

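		// tin: entry time, ord: vertices by entry time, subsz: subtree size,
		// enter[v]: number of vertices visited before v in a DFS started at v's parent
		vector<int> tin(n+1), ord, subsz(n+1, 1), enter(n+1);
		// anc[u][i]: 2^i-th ancestor of u (0 if none); [left[u][i], right[u][i]]: positions, in the
		// DFS from u, of the vertices outside the subtree of u's (2^i - 1)-th ancestor
		vector<array<int, LOG>> anc(n+1), left(n+1), right(n+1);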
		int timer = 0;
		auto dfs = [&] (const auto &self, int u, int par) -> void {
			tin[u] = timer++;
			ord.push_back(u);
			for (int i = 1; i < LOG; ++i) anc[u][i] = anc[anc[u][i-1]][i-1];
			for (int v : adj[u]) {
				if (v == par) {
					left[u][0] = timer - tin[u];
					continue;
				}
				anc[v][0] = u;
				self(self, v, u);
				subsz[u] += subsz[v];
			}
			right[u][0] = left[u][0] + n - subsz[u] - 1;
			for (int v : adj[u]) {
				if (v == par) continue;
				enter[v] = tin[v] - tin[u];
				if (v > par) enter[v] += n - subsz[u];
			}
		};
		auto kth = [&] (int u, int k) {
			for (int i = LOG - 1; i >= 0; --i) {
				if (k >= (1 << i)) {
					k -= 1 << i;
					u = anc[u][i];
				}
			}
			return u;
		};
		auto upd = [&] (auto &table, int u, int i) {
			table[u][i] = left[u][i-1] + table[anc[u][i-1]][i-1];
			if (anc[u][i-1] > 1) {
				int x = kth(u, (1 << (i-1)) - 1);
				if (x < anc[x][1]) table[u][i] -= subsz[x];
			}
		};
		auto query = [&] (int u, int x) {
			for (int i = LOG - 1; i >= 0; --i) {
				if (anc[u][i] and x >= left[u][i] and x <= right[u][i]) {
					int y = kth(u, (1 << i) - 1);
					x -= left[u][i];
					if (enter[y] <= x) x += subsz[y];
					u = anc[u][i];
				}
			}
			return ord[x + tin[u] - (x >= left[u][0])*(n - subsz[u])];
		};
		dfs(dfs, 1, 0);
		for (int i = 1; i < LOG; ++i) {
			for (int u = 1; u <= n; ++u) {
				upd(left, u, i);
				upd(right, u, i);
			}
		}

		int q, ans = 0; cin >> q;
		while (q--) {
			int u, x; cin >> u >> x;
			u ^= ans;
			x ^= ans;
			ans = query(u, x);
			cout << ans << '\n';
		}
	}
}