PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: aryan12
Tester: udhav2003
Editorialist: iceknight1093
DIFFICULTY:
2714
PREREQUISITES:
Computing distances in a tree
PROBLEM:
You’re given a tree with N vertices. Answer Q queries on it:
- Given a subset S of vertices of the tree, define the ugliness of a vertex x of the tree to be \max(\text{dist}(x, y)) across all y \in S.
Find the minimum ugliness across all vertices.
EXPLANATION:
Let’s try to solve a simpler version of the problem first: what if the subset S consisted of all N vertices?
Then, we want to find a vertex that isn’t too far away from all the vertices of the tree.
In particular, if there’s a path of length d in the tree, then any vertex is at a distance of at least \left\lceil \frac{d}{2} \right\rceil away from one of the endpoints of this path.
So, let D be the length of the longest path in the tree (i.e, its diameter).
Then, the answer is at least \left\lceil \frac{D}{2} \right\rceil.
It’s not hard to see that the answer is indeed exactly this: choose the midpoint of the diameter to attain it (if there are two midpoints, choose either one).
This idea can be extended to an arbitrary subset S of vertices.
If there’s a path of length d between two vertices of S, once again we see that the answer is at least \left\lceil \frac{d}{2} \right\rceil.
Yet again, if D is the longest path between two vertices of S, the answer will be exactly \left\lceil \frac{D}{2} \right\rceil by choosing the midpoint of this path.
So, our problem turns into the following:
Given a subset S of vertices of the tree, find the longest path whose endpoints lie in S.
To compute this quickly, recall one of the methods to compute the diameter of a tree:
- Fix an arbitrary vertex x.
- Find the furthest vertex from x; let it be d_1.
- Find the farthest vertex from d_1; let it be d_2.
- d_1 and d_2 are then endpoints of a diameter of the tree.
This algorithm adapts itself nicely to our version:
- Fix an arbitrary vertex x \in S.
- Find the farthest vertex from x that lies in S; say d_1.
- Find the farthest vertex from d_1 that lies in S; say d_2.
- d_1 and d_2 are the endpoints of one longest path, so D equals the distance between them.
The problem now is performing the second and third steps quickly.
Of course, we can run a full BFS/DFS each time to compute distances to all vertices; however that would take \mathcal{O}(N) time and definitely won’t be fast enough to answer multiple queries.
Instead, note that with a bit of precomputation, we can calculate the distances between any two nodes of a tree in \mathcal{O}(\log N) or \mathcal{O}(1) as follows:
- Root the tree at some node, say 1. Then, using a BFS/DFS calculate d_u — the distance of node u from 1 — for every node u.
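- Then, we have \text{dist}(x, y) = d_x + d_y - 2\cdot d_l, where l = LCA(x, y) is the lowest common ancestor of x and y (the path from the root to l is counted once in d_x and once in d_y, so it must be subtracted twice).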
l can be found in \mathcal{O}(\log N) time using binary lifting, for example.
This allows us to transform our algorithm to \mathcal{O}(|S|\log N) per query from \mathcal{O}(N), since in each step we only compute distances from a vertex of S to all the other vertices of S.
Since the sum of |S| across all queries is bounded by 4\cdot 10^5, this is good enough.
TIME COMPLEXITY
\mathcal{O}((N + \sum |S|)\log N) per test case, where \sum |S| is the total size of all queried subsets (note that \sum |S| \geq Q).
CODE:
Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
// NOTE: `int` is redefined as `long long` by the macro above, so every `int`
// below (including array element types) is 64-bit.

// Clock-seeded 64-bit random generator; not referenced by the author's
// functions in this solution.
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());
// INF is not referenced by the author's functions; N is the capacity of the
// per-vertex arrays below (vertices are numbered from 1).
const int INF = 1e18, N = 3e5 + 5;
// g[v]: adjacency list of vertex v; cleared and refilled by Solve() for each test case.
vector<int> g[N];
// dp[i][v]: the 2^i-th ancestor of v, or -1 if it does not exist.
// dp[0][*] is written by dfs(); higher levels are filled in Solve().
int dp[20][N];
// depth[v]: number of edges from the DFS root to v, written by dfs().
int depth[N];
// Records, for every vertex reachable from `node`, its depth in depth[] and
// its direct parent in dp[0][].
//
//   node - vertex to start from (the root of the traversal)
//   par  - value stored as the parent of `node` (the caller passes -1 for the root)
//   dep  - depth assigned to `node`; each child gets its parent's depth + 1
//
// The traversal uses an explicit stack instead of recursion: the recursion
// depth would equal the height of the tree, which for a path-shaped tree is
// the number of vertices and can exhaust the call stack. The visiting order
// differs from the recursive version, but depth[] and dp[0][] depend only on
// the tree, so the stored values are the same.
void dfs(int node, int par, int dep)
{
    depth[node] = dep;
    dp[0][node] = par;

    vector<int> pending;
    pending.push_back(node);
    while(!pending.empty())
    {
        int cur = pending.back();
        pending.pop_back();
        for(int to: g[cur])
        {
            // Skip the edge that leads back to the parent.
            if(to == dp[0][cur])
            {
                continue;
            }
            depth[to] = depth[cur] + 1;
            dp[0][to] = cur;
            pending.push_back(to);
        }
    }
}
// Returns the number of edges on the tree path between x and y.
//
// Uses the binary-lifting table dp[][] and depth[] to locate the lowest
// common ancestor, then applies
//     dist = depth[x] + depth[y] - 2 * depth[lca].
int dist(int x, int y)
{
    int shallow = x;
    int deep = y;
    if(depth[shallow] > depth[deep])
    {
        swap(shallow, deep);
    }
    const int depth_sum = depth[shallow] + depth[deep];

    // Lift the deeper vertex until both are on the same level.
    const int gap = depth[deep] - depth[shallow];
    for(int bit = 19; bit >= 0; --bit)
    {
        if((gap >> bit) & 1)
        {
            deep = dp[bit][deep];
        }
    }
    if(shallow == deep)
    {
        // One vertex is an ancestor of the other, so it is the LCA itself.
        return depth_sum - 2 * depth[shallow];
    }

    // Lift both vertices while their ancestors differ; afterwards they are
    // two distinct children of the LCA.
    for(int bit = 19; bit >= 0; --bit)
    {
        const int up_shallow = dp[bit][shallow];
        const int up_deep = dp[bit][deep];
        if(up_shallow == up_deep)
        {
            continue;
        }
        shallow = up_shallow;
        deep = up_deep;
    }
    // The LCA sits exactly one level above the current vertices.
    return depth_sum - 2 * (depth[shallow] - 1);
}
// Handles one test case: reads the tree, builds the ancestor table, then
// answers each query by maintaining the two endpoints of the longest path
// seen so far among the query's vertices and printing ceil(length / 2).
void Solve()
{
    int n, q;
    cin >> n >> q;

    // Reset the adjacency lists left over from the previous test case.
    for(int v = 1; v <= n; ++v)
    {
        g[v].clear();
    }
    for(int e = 1; e < n; ++e)
    {
        int a, b;
        cin >> a >> b;
        g[a].push_back(b);
        g[b].push_back(a);
    }

    // Root the tree at vertex 1; -1 marks "no parent".
    dfs(1, -1, 0);
    for(int lvl = 1; lvl < 20; ++lvl)
    {
        for(int v = 1; v <= n; ++v)
        {
            const int halfway = dp[lvl - 1][v];
            if(halfway == -1)
            {
                dp[lvl][v] = -1;
            }
            else
            {
                dp[lvl][v] = dp[lvl - 1][halfway];
            }
        }
    }

    for(int query = 0; query < q; ++query)
    {
        int k;
        cin >> k;
        if(k == 1)
        {
            // A single vertex is at distance 0 from itself.
            int only;
            cin >> only;
            cout << "0\n";
            continue;
        }

        // end_a / end_b are the endpoints of the longest path found so far.
        int end_a, end_b;
        cin >> end_a >> end_b;
        int best = dist(end_a, end_b);
        for(int idx = 3; idx <= k; ++idx)
        {
            int x;
            cin >> x;
            const int to_a = dist(x, end_a);
            const int to_b = dist(x, end_b);
            const int longer = max(to_a, to_b);
            if(longer <= best)
            {
                continue;
            }
            // Keep the endpoint that is farther from x and replace the other.
            if(to_a >= to_b)
            {
                end_b = x;
            }
            else
            {
                end_a = x;
            }
            best = longer;
        }
        cout << (best + 1) / 2 << "\n";
    }
}
// Entry point: reads the number of test cases, runs Solve() for each one and
// reports the elapsed wall-clock time on stderr.
// The return type is int32_t because `int` is a macro for `long long` here.
int32_t main()
{
    const auto started = std::chrono::high_resolution_clock::now();

    ios_base::sync_with_stdio(0);
    cin.tie(0);

    int t = 1;
    cin >> t;
    for(int done = 0; done < t; ++done)
    {
        Solve();
    }

    const auto finished = std::chrono::high_resolution_clock::now();
    const auto spent = std::chrono::duration_cast<std::chrono::nanoseconds>(finished - started);
    cerr << "Time measured: " << spent.count() * 1e-9 << " seconds.\n";
    return 0;
}
Tester's code (C++)
// Compiler tuning pragmas.
// NOTE(review): "optimisation" does not appear to be a pragma name GCC
// recognises (the supported spelling is "optimize", used two lines below), so
// the first line is presumably ignored — confirm against the GCC manual.
#pragma GCC optimisation("O3")
#pragma GCC target("avx,avx2,fma")
#pragma GCC optimize("Ofast,unroll-loops")
#include<bits/stdc++.h>
using namespace std;
// Integer aliases; only `ll` is used by the tester's dfs() and main().
typedef unsigned long long ull;
typedef long long ll;
// Template macros; none of the macros below are used by the tester's dfs() or main().
#define NUM1 1000000007LL
#define all(a) a.begin(), a.end()
// NOTE(review): expands to a.begin() twice, i.e. an empty range — possibly a typo.
#define beg(a) a.begin(), a.begin()
#define sq(a) ((a)*(a))
#define NUM2 998244353LL
#define MOD NUM2
#define LMOD 1000000006LL
#define fi first
#define se second
typedef long double ld;
// Size and sentinel constants; not referenced by the tester's dfs() or main().
const ll MAX = 200010;
const ll MAX2 = MAX;
const ll large = 1e18;
// Clock-seeded random generator; not referenced by the tester's dfs() or main().
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
// Fills prt[] (parent) and dist[] (depth) for every vertex reachable from `st`.
//
//   adj  - adjacency lists, indexed by vertex
//   st   - vertex to start from
//   p    - value stored as the parent of `st` (the caller passes the root itself)
//   prt  - output: prt[v] is the parent of v in the traversal
//   dist - output: dist[v] is dval plus the number of edges from `st` to v
//   dval - depth assigned to `st`
//
// The traversal keeps its own stack instead of recursing: the recursion depth
// would equal the height of the tree, which for a path-shaped tree is the
// number of vertices and can exhaust the call stack. prt[] and dist[] depend
// only on the tree, not on the visiting order, so the results are unchanged.
void dfs(vector<ll> adj[], ll st, ll p, vector<ll>& prt, vector<ll>& dist, ll dval)
{
    dist[st] = dval;
    prt[st] = p;

    vector<ll> pending;
    pending.push_back(st);
    while(!pending.empty()){
        const ll cur = pending.back();
        pending.pop_back();
        for(auto x: adj[cur]){
            // Skip the edge that leads back to the parent.
            if(x != prt[cur]){
                dist[x] = dist[cur] + 1;
                prt[x] = cur;
                pending.push_back(x);
            }
        }
    }
}
// Tester's entry point. For every test case it reads the tree, builds a
// binary-lifting ancestor table, and answers each query by keeping the two
// endpoints of the longest path seen so far among the query's vertices; the
// printed answer is ceil(longest / 2).
signed main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    ll t;
    cin >> t;
    while(t--){
        ll n, q;
        cin >> n >> q;
        // Heap-backed adjacency lists. The previous `vector<ll> adj[n]` was a
        // variable-length array of a non-trivial type, which is a compiler
        // extension rather than standard C++ and places n vector headers on
        // the stack.
        vector<vector<ll>> adj(n);
        for(ll i = 0; i < n - 1; i++){
            ll u, v;
            cin >> u >> v;
            u -= 1; v -= 1;   // input is 1-based, storage is 0-based
            adj[u].push_back(v);
            adj[v].push_back(u);
        }
        vector<ll> prt(n), dist(n);
        // Root at vertex 0; the root is recorded as its own parent, so lifting
        // past the root stays at the root.
        dfs(adj.data(), 0, 0, prt, dist, 0);
        // up[v][j] is the 2^j-th ancestor of v (clamped at the root).
        vector<vector<ll>> up(n, vector<ll>(20));
        for(ll i = 0; i < n; i++) up[i][0] = prt[i];
        for(ll j = 1; j < 20; j++){
            for(ll i = 0; i < n; i++){
                up[i][j] = up[up[i][j - 1]][j - 1];
            }
        }
        // Moves v exactly `val` levels towards the root.
        auto jump = [&](ll v, ll val){
            for(ll j = 19; j >= 0; j--){
                if(val >= (1LL << j)){
                    val -= (1LL << j);
                    v = up[v][j];
                }
            }
            assert(val == 0);
            return v;
        };
        // Lowest common ancestor of a and b.
        auto lca = [&](ll a, ll b){
            if(dist[a] < dist[b]) swap(a, b);
            a = jump(a, dist[a] - dist[b]);
            if(a == b) return a;
            for(ll j = 19; j >= 0; j--){
                if(up[a][j] != up[b][j]){
                    a = up[a][j];
                    b = up[b][j];
                }
            }
            return prt[a];
        };
        // Number of edges on the path between a and b.
        auto dis = [&](ll a, ll b){
            return dist[a] + dist[b] - 2*dist[lca(a, b)];
        };
        while(q--){
            ll k;
            cin >> k;
            vector<ll> v(k);
            for(auto& x: v){
                cin >> x;
                x--;
            }
            if(k == 1){
                // A single vertex is at distance 0 from itself.
                cout << 0 << '\n';
                continue;
            }
            // e[0], e[1]: endpoints of the longest path found so far; dval is its length.
            vector<ll> e(2);
            e[0] = v[0]; e[1] = v[1];
            ll dval = dis(e[0], e[1]);
            for(ll i = 2; i < k; i++){
                ll p0 = dis(v[i], e[0]);
                ll p1 = dis(v[i], e[1]);
                // Pair the new vertex with whichever endpoint is farther from
                // it, and replace the other endpoint if that path is longer.
                if(p1 > p0){
                    if(p1 > dval){
                        dval = p1;
                        e[0] = v[i];
                    }
                }
                else{
                    if(p0 > dval){
                        dval = p0;
                        e[1] = v[i];
                    }
                }
            }
            cout << (1 + dval)/2 << '\n';
        }
    }
    return 0;
}
Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
// 64-bit integer alias; not referenced by the editorialist's RMQ, LCA or main().
using ll = long long int;
// Clock-seeded 64-bit random generator; not referenced by the editorialist's RMQ, LCA or main().
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());
// Sparse table answering range-minimum queries over a fixed array.
//
// jmp[k][j] holds the minimum of the 2^k consecutive input elements starting
// at index j; level 0 is a copy of the input.
template<class T>
struct RMQ {
    vector<vector<T>> jmp;

    // Builds every level whose window length fits inside V.
    RMQ(const vector<T>& V) : jmp(1, V) {
        const int n = (int)size(V);
        for (int half = 1; 2 * half <= n; half *= 2) {
            // Each window of length 2*half is the union of two adjacent
            // windows of length half from the level below.
            const vector<T>& below = jmp.back();
            vector<T> level(n - 2 * half + 1);
            for (int j = 0; j < (int)size(level); ++j) {
                level[j] = min(below[j], below[j + half]);
            }
            jmp.emplace_back(std::move(level));
        }
    }

    // Minimum of the elements with indices in the half-open range [a, b).
    // Requires a < b.
    T query(int a, int b) {
        assert(a < b); // or return inf if a == b
        const int span = b - a;
        // Largest power of two not exceeding the range length.
        const int level = 31 - __builtin_clz(span);
        const T& head = jmp[level][a];
        const T& tail = jmp[level][b - (1 << level)];
        return min(head, tail);
    }
};
// Lowest-common-ancestor structure built from one DFS rooted at vertex 0.
//
//   time[v] - preorder index of v
//   out[v]  - value of the preorder counter after v's whole subtree was visited
//   dep[v]  - number of edges from vertex 0 to v
//   path    - for every tree edge, in traversal order, the parent endpoint
//   ret     - the preorder index of that parent (parallel to `path`)
//   rmq     - range-minimum structure over `ret`
//
// The DFS is recursive, so its recursion depth equals the height of the tree.
struct LCA {
    int T = 0;
    vector<int> time, out, dep, path, ret;
    RMQ<int> rmq;

    // `path` and `ret` are declared before `rmq`, so they are already
    // constructed when the comma expression runs the DFS and then hands
    // `ret` to the RMQ constructor.
    LCA(auto& C)
        : time(size(C)), out(size(C)), dep(size(C)),
          rmq((dfs(C, 0, -1), ret)) {}

    void dfs(auto& C, int v, int par) {
        time[v] = T++;
        for (int child : C[v]) {
            if (child == par) continue;
            // One entry per edge, recorded just before descending into it.
            path.push_back(v);
            ret.push_back(time[v]);
            dep[child] = dep[v] + 1;
            dfs(C, child, v);
        }
        out[v] = T;
    }

    // Lowest common ancestor of a and b.
    int lca(int a, int b) {
        if (a == b) return a;
        const int first = min(time[a], time[b]);
        const int last = max(time[a], time[b]);
        // The smallest parent preorder index among the edges entered between
        // the two vertices is that of their LCA; `path` maps it back to a vertex.
        return path[rmq.query(first, last)];
    }
};
int main()
{
ios::sync_with_stdio(false); cin.tie(0);
int t; cin >> t;
while (t--) {
int n, q; cin >> n >> q;
vector adj(n, basic_string<int>());
for (int i = 0; i < n-1; ++i) {
int u, v; cin >> u >> v;
adj[--u].push_back(--v);
adj[v].push_back(u);
}
LCA L(adj);
vector active(n, basic_string<array<int, 2>> ());
const int inf = 1e9;
vector<int> dist(n, inf);
auto farthest = [&] (int src) {
basic_string<int> used;
dist[src] = 0;
queue<int> q; q.push(src);
int far = src;
while (!q.empty()) {
int u = q.front(); q.pop();
used.push_back(u);
if (dist[u] > dist[far]) far = u;
for (auto [v, w] : active[u]) {
if (dist[v] <= w + dist[u]) continue;
dist[v] = w + dist[u];
q.push(v);
}
}
int mxd = dist[far];
for (auto x : used) dist[x] = inf;
return array{far, mxd};
};
auto solve = [&] (auto &vertices) {
// Build virtual tree of vertices
sort(begin(vertices), end(vertices), [&] (int u, int v) {return L.time[u] < L.time[v];});
int k = size(vertices);
for (int i = 0; i+1 < k; ++i) vertices.push_back(L.lca(vertices[i], vertices[i+1]));
sort(begin(vertices), end(vertices), [&] (int u, int v) {return L.time[u] < L.time[v];});
vertices.erase(unique(begin(vertices), end(vertices)), end(vertices));
stack<int> st;
for (int x : vertices) {
while (!st.empty()) {
int u = st.top();
if (L.out[u] >= L.out[x] and u != x) break;
st.pop();
}
if (!st.empty()) {
int u = st.top(); // u is the parent of x in this virtual tree
active[u].push_back({x, L.dep[x] - L.dep[u]});
active[x].push_back({u, L.dep[x] - L.dep[u]});
}
st.push(x);
}
int u = farthest(vertices[0])[0];
auto [v, diam] = farthest(u);
cout << (diam + 1)/2 << '\n';
for (int x : vertices) active[x].clear();
};
while (q--) {
int k; cin >> k;
basic_string<int> vertices;
for (int i = 0; i < k; ++i) {
int x; cin >> x;
vertices.push_back(--x);
}
solve(vertices);
}
}
}