PROBLEM LINK:
Contest Division 1
Contest Division 2
Contest Division 3
Practice
Setter: Daanish Mahajan
Tester: Istvan Nagy
Editorialist: Taranpreet Singh
DIFFICULTY
Easy-Medium
PREREQUISITES
Tree DP.
PROBLEM
Given a tree with N nodes, answer Q queries of the following type.
- Given two nodes a and b, compute \displaystyle \sum_{i = 1}^N \text{min}(\text{dist}(i, a), \text{dist}(i, b))
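A brute-force reference is useful for checking faster solutions on small inputs. The sketch below assumes 1-indexed nodes stored in an adjacency list `adj` with `adj[0]` unused. The helpers `bfsDist` and `bruteQuery` are hypothetical names and do not come from any of the solutions below. It runs one BFS from a and one from b, so each query costs O(N) and the whole run costs O(NQ), which is far too slow for the real limits.

```cpp
#include <algorithm>
#include <cassert>
#include <queue>
#include <vector>

// BFS distances from src in an unweighted tree (1-indexed, adj[0] unused).
std::vector<int> bfsDist(const std::vector<std::vector<int>>& adj, int src) {
    std::vector<int> d(adj.size(), -1);  // -1 = not reached yet
    std::queue<int> q;
    d[src] = 0;
    q.push(src);
    while (!q.empty()) {
        int u = q.front();
        q.pop();
        for (int v : adj[u]) {
            if (d[v] == -1) {
                d[v] = d[u] + 1;
                q.push(v);
            }
        }
    }
    return d;
}

// Sum over all nodes i of min(dist(i, a), dist(i, b)): O(N) per query.
long long bruteQuery(const std::vector<std::vector<int>>& adj, int a, int b) {
    std::vector<int> da = bfsDist(adj, a), db = bfsDist(adj, b);
    long long sum = 0;
    for (std::size_t i = 1; i < adj.size(); ++i) sum += std::min(da[i], db[i]);
    return sum;
}
```

On the path 1-2-3-4, for example, `bruteQuery(adj, 1, 4)` returns 0 + 1 + 1 + 0 = 2.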
QUICK EXPLANATION
EXPLANATION
Some Observations
Let’s assume we have to solve only 1 query. For given nodes a and b, compute \displaystyle \sum_{i = 1}^N \text{min}(\text{dist}(i, a), \text{dist}(i, b)).
Let us write down all nodes on the path from a to b, in the order they appear.
Consider the tree in the image above with a = 1 and b = 5. The nodes on the path from a to b are 1 3 2 4 5. Let’s call these nodes special nodes (highlighted in red).
For every node u, let sp_u denote the special node that can be reached from u by travelling only through non-special nodes (for a special node u, sp_u = u). This is well defined: from any non-special node, travelling only through non-special nodes, we can reach exactly 1 special node.
In the example above, nodes 1, 6 and 11 reach node 1, so sp_1 = sp_6 = sp_{11} = 1. Nodes 3, 8, 9 and 10 reach node 3, so sp_3 = sp_8 = sp_9 = sp_{10} = 3. Nodes 2 and 12 reach node 2, so sp_2 = sp_{12} = 2. Only node 4 reaches node 4, so sp_4 = 4. Nodes 5 and 7 reach node 5, so sp_5 = sp_7 = 5.
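The grouping can be computed directly. The sketch below makes a few assumptions. Nodes are 1-indexed in an adjacency list with `adj[0]` unused, and `Graph`, `pathBetween` and `computeSp` are hypothetical names. It first extracts the path from a to b, then runs a multi-source BFS from all special nodes. Every special node is labelled before the search starts, so a label can never spread across another special node.

```cpp
#include <algorithm>
#include <cassert>
#include <queue>
#include <vector>

using Graph = std::vector<std::vector<int>>;  // 1-indexed adjacency list, adj[0] unused

// Nodes on the path a -> b in order: the "special" nodes.
std::vector<int> pathBetween(const Graph& adj, int a, int b) {
    std::vector<int> par(adj.size(), 0);  // 0 = not visited yet
    std::queue<int> q;
    par[a] = a;
    q.push(a);
    while (!q.empty()) {
        int u = q.front();
        q.pop();
        for (int v : adj[u])
            if (par[v] == 0) {
                par[v] = u;
                q.push(v);
            }
    }
    std::vector<int> path;
    for (int u = b; u != a; u = par[u]) path.push_back(u);
    path.push_back(a);
    std::reverse(path.begin(), path.end());
    return path;
}

// sp[u] = the special node reachable from u through non-special nodes only.
std::vector<int> computeSp(const Graph& adj, const std::vector<int>& special) {
    std::vector<int> sp(adj.size(), 0);  // 0 = not labelled yet
    std::queue<int> q;
    for (int s : special) {
        sp[s] = s;  // a special node is its own group
        q.push(s);
    }
    while (!q.empty()) {
        int u = q.front();
        q.pop();
        for (int v : adj[u])
            if (sp[v] == 0) {
                sp[v] = sp[u];  // inherit the group of the node we came from
                q.push(v);
            }
    }
    return sp;
}
```

Take any tree in which the path from 1 to 5 is 1 3 2 4 5. On such a tree, `computeSp(adj, pathBetween(adj, 1, 5))` tags each node with the special node of its group. The image's exact edges are not reproduced in this text.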
Claim: If we root the tree at node u, the LCA of a and b is sp_u.
The implication of this statement is that the shortest paths from u to a and from u to b share the segment from u to sp_u, and only separate after it.
Hence, for every node u, dist(u, a) = dist(u, sp_u) + dist(sp_u, a) and dist(u, b) = dist(u, sp_u) + dist(sp_u, b).
Hence, \text{min}(\text{dist}(u, a), \text{dist}(u, b)) = \text{dist}(u, sp_u) + \text{min}(\text{dist}(sp_u, a), \text{dist}(sp_u, b)).
Therefore, if we group nodes by sp_u, every node in a group makes the same choice. Either every node in the group is at least as close to a as to b, or every node is at least as close to b as to a, and which case applies depends only on the group's special node.
So if we know, for each special node, whether it is better to connect it to a or to b, the same decision holds for its whole group, since each group contains exactly one special node.
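The grouping already gives an O(N) algorithm for a single query. The sketch below assumes 1-indexed nodes with `adj[0]` unused, and `groupedQuery` is a hypothetical name. The special node at index k of the path has dist(·, a) = k and dist(·, b) = len - k, where len = dist(a, b). A multi-source BFS from the special nodes then gives dist(u, sp_u) for every u and lets each node inherit its group's min(k, len - k).

```cpp
#include <algorithm>
#include <cassert>
#include <queue>
#include <vector>

using Graph = std::vector<std::vector<int>>;  // 1-indexed adjacency list, adj[0] unused

// One query in O(N) via min(dist(u,a), dist(u,b)) = dist(u, sp_u) + min(dist(sp_u,a), dist(sp_u,b)).
long long groupedQuery(const Graph& adj, int a, int b) {
    const int n = static_cast<int>(adj.size());
    // 1) BFS from a for parent pointers, then read off the path a -> b.
    std::vector<int> par(n, 0);
    std::queue<int> q;
    par[a] = a;
    q.push(a);
    while (!q.empty()) {
        int u = q.front(); q.pop();
        for (int v : adj[u]) if (par[v] == 0) { par[v] = u; q.push(v); }
    }
    std::vector<int> path;
    for (int u = b; u != a; u = par[u]) path.push_back(u);
    path.push_back(a);
    std::reverse(path.begin(), path.end());
    const int len = static_cast<int>(path.size()) - 1;  // dist(a, b)

    // 2) Special node at index k: best = min(k, len - k), distance to itself 0.
    std::vector<int> best(n, -1), toSp(n, -1);
    for (int k = 0; k <= len; ++k) {
        best[path[k]] = std::min(k, len - k);
        toSp[path[k]] = 0;
        q.push(path[k]);
    }
    // 3) Multi-source BFS: toSp[v] = dist(v, sp_v), best[v] inherited from sp_v.
    while (!q.empty()) {
        int u = q.front(); q.pop();
        for (int v : adj[u])
            if (toSp[v] == -1) { toSp[v] = toSp[u] + 1; best[v] = best[u]; q.push(v); }
    }
    long long sum = 0;
    for (int u = 1; u < n; ++u) sum += toSp[u] + best[u];
    return sum;
}
```

This is still O(N) per query. Its value is that it makes the group decomposition concrete before the per-query work is reduced to a few precomputed lookups.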
Another Observation
Now consider the nodes on the path from a to b, in the order they appear, and call this list P. As we move along P from a towards b, dist(u, a) increases by 1 at each step and dist(u, b) decreases by 1.
So the first half of P has dist(u, a) \leq dist(u, b), and the second half has dist(u, a) \geq dist(u, b). If P has an odd number of nodes, the middle node is equally far from a and b and may be put in either half. It must go in only one of them, though, or it would be counted twice.
In the example above, nodes 1 3 2 are in the left half and 4 5 are in the right half. Node 2, the middle node, could go in either half.
Hence, for every node u whose sp_u lies in the left half of P, we add dist(u, a) to the answer. For every node u whose sp_u lies in the right half, we add dist(u, b).
Intuition
Think of it as removing the edge (2, 4) from the tree above. What do we get?
We get two trees, one containing a and the other containing b. For every node in a's tree we want \text{dist}(u, a), and for every node in b's tree we want \text{dist}(u, b). The observation above shows that adding up these distances gives exactly the required sum of minimum distances. Let A denote the set of nodes connected to a after removing the edge, and B the set of nodes connected to b.
In the picture above, for the set of nodes S_a = A = \{1,2,3,6,8,9,10,11,12\} we must compute \displaystyle X = \sum_{u \in S_a} dist(u, a), and for the set of nodes S_b = B = \{4,5,7\} we must compute Y = \displaystyle \sum_{u \in S_b} dist(u, b). Then X+Y is the required answer.
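The "remove the middle edge" view can be checked directly with two BFS runs that refuse to cross the cut edge. This is a sketch under the same assumptions as before: 1-indexed nodes with `adj[0]` unused, and `sumDistAvoiding` and `cutQuery` are hypothetical names. The cut is placed after index ⌊len/2⌋ of P, so the middle node (when there is one) stays on a's side.

```cpp
#include <algorithm>
#include <cassert>
#include <queue>
#include <vector>

using Graph = std::vector<std::vector<int>>;  // 1-indexed adjacency list, adj[0] unused

// Sum of distances from src to every node it reaches when edge (x, y) is skipped.
long long sumDistAvoiding(const Graph& adj, int src, int x, int y) {
    std::vector<int> d(adj.size(), -1);
    std::queue<int> q;
    d[src] = 0;
    q.push(src);
    long long total = 0;
    while (!q.empty()) {
        int u = q.front(); q.pop();
        total += d[u];
        for (int v : adj[u]) {
            if ((u == x && v == y) || (u == y && v == x)) continue;  // the removed edge
            if (d[v] == -1) { d[v] = d[u] + 1; q.push(v); }
        }
    }
    return total;
}

// X + Y with the cut edge (P[len/2], P[len/2 + 1]) on the path P from a to b.
long long cutQuery(const Graph& adj, int a, int b) {
    std::vector<int> par(adj.size(), 0);
    std::queue<int> q;
    par[a] = a;
    q.push(a);
    while (!q.empty()) {
        int u = q.front(); q.pop();
        for (int v : adj[u]) if (par[v] == 0) { par[v] = u; q.push(v); }
    }
    std::vector<int> path;
    for (int u = b; u != a; u = par[u]) path.push_back(u);
    path.push_back(a);
    std::reverse(path.begin(), path.end());
    const int len = static_cast<int>(path.size()) - 1;  // >= 1 because a != b
    const int x = path[len / 2], y = path[len / 2 + 1];
    return sumDistAvoiding(adj, a, x, y) + sumDistAvoiding(adj, b, x, y);
}
```

For a = 1, b = 5 and P = 1 3 2 4 5, this cuts the edge (2, 4), exactly as in the text.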
Enough about the idea. How do we compute the distances?
From this point on, we have to answer Q queries, each specifying a and b, so we cannot iterate over all nodes computing distances for every query. What information do we need?
Using precomputation, we can compute the distance between a and b, and hence find the edge to remove. We only imagine removing this edge; the tree itself is never modified.
So for the pair (1, 5), we find the edge (2, 4) to remove.
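One way to find that edge per query is binary lifting, which all three solutions below use in some form. The sketch below assumes the tree is rooted at node 1 and that b is taken to be the deeper endpoint (swapping if needed). The cut edge is then (x, y), where y is the ancestor of b that is ⌊(L-1)/2⌋ levels up (L = dist(a, b)) and x is the parent of y. Both lie on the path, because at least ⌈L/2⌉ of the path's edges sit between the deeper endpoint and the LCA. The struct `Lifting` and its members are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <queue>
#include <utility>
#include <vector>

// Binary lifting on a tree rooted at node 1 (1-indexed, adj[0] unused).
struct Lifting {
    static constexpr int LOG = 18;  // 2^18 > 1e5 levels; raise it for deeper trees
    std::vector<int> depth;
    std::vector<std::vector<int>> up;  // up[k][v] = 2^k-th ancestor; the root maps to itself

    explicit Lifting(const std::vector<std::vector<int>>& adj)
        : depth(adj.size(), -1), up(LOG, std::vector<int>(adj.size(), 1)) {
        std::queue<int> q;  // BFS avoids deep recursion on path-like trees
        depth[1] = 0;
        q.push(1);
        while (!q.empty()) {
            int u = q.front(); q.pop();
            for (int v : adj[u])
                if (depth[v] == -1) { depth[v] = depth[u] + 1; up[0][v] = u; q.push(v); }
        }
        for (int k = 1; k < LOG; ++k)
            for (std::size_t v = 1; v < adj.size(); ++v) up[k][v] = up[k - 1][up[k - 1][v]];
    }
    int kthAncestor(int v, int k) const {
        for (int i = 0; i < LOG; ++i)
            if ((k >> i) & 1) v = up[i][v];
        return v;
    }
    int lca(int u, int v) const {
        if (depth[u] < depth[v]) std::swap(u, v);
        u = kthAncestor(u, depth[u] - depth[v]);
        if (u == v) return u;
        for (int i = LOG - 1; i >= 0; --i)
            if (up[i][u] != up[i][v]) { u = up[i][u]; v = up[i][v]; }
        return up[0][u];
    }
    // Returns {x, y}: y on the deeper endpoint's side, x = parent of y.
    std::pair<int, int> cutEdge(int a, int b) const {
        if (depth[a] > depth[b]) std::swap(a, b);
        int len = depth[a] + depth[b] - 2 * depth[lca(a, b)];
        int y = kthAncestor(b, (len - 1) / 2);
        return {up[0][y], y};
    }
};
```

Each query costs O(log N), and preprocessing costs O(N log N).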
Let’s choose any node as the root, and let sub_u be the size of the subtree of node u. Precompute D_u = \displaystyle \sum_{v \in T_u} dist(u, v), where T_u denotes the set of nodes in the subtree of u. A simple DFS does this, since D_u = \sum_{c} (D_c + sub_c) over the children c of u.
Then compute U_u = \displaystyle \sum_{v \notin T_u} dist(u, v), the sum of distances from node u to all nodes outside the subtree of u. This requires tree DP, specifically in-out DP (rerooting). The root has U_{root} = 0, and for a child v of p, U_v = U_p + D_p - (D_v + sub_v) + (N - sub_v). This video is a good place to learn it.
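Both arrays can be filled without recursion. The sketch below assumes the tree is rooted at node 1 with 1-indexed nodes and `adj[0]` unused, and `DistSums` is a hypothetical name. It computes one BFS order, does a down pass over it in reverse (children before parents) for sub and D, then a forward pass (parents before children) applying the in-out recurrence for U. D_u + U_u is then the sum of distances from u to every node.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// sub[u] = subtree size, D[u] = sum of dist(u, v) for v in u's subtree,
// U[u] = sum of dist(u, v) for v outside u's subtree. Rooted at node 1.
struct DistSums {
    std::vector<long long> sub, D, U;

    explicit DistSums(const std::vector<std::vector<int>>& adj) {
        const int n = static_cast<int>(adj.size()) - 1;  // nodes are 1..n
        sub.assign(n + 1, 0);
        D.assign(n + 1, 0);
        U.assign(n + 1, 0);
        // BFS order from the root: every parent appears before its children.
        std::vector<int> order{1}, par(n + 1, 0);
        for (std::size_t i = 0; i < order.size(); ++i) {
            int u = order[i];
            for (int v : adj[u])
                if (v != par[u]) {
                    par[v] = u;
                    order.push_back(v);
                }
        }
        // Down pass (children first): sub[p] += sub[v], D[p] += D[v] + sub[v].
        for (int i = n - 1; i >= 0; --i) {
            int v = order[i];
            sub[v] += 1;
            if (par[v] != 0) {
                sub[par[v]] += sub[v];
                D[par[v]] += D[v] + sub[v];
            }
        }
        // Up pass (parents first): U[v] = U[p] + D[p] - (D[v] + sub[v]) + (n - sub[v]).
        for (int i = 1; i < n; ++i) {
            int v = order[i], p = par[v];
            U[v] = U[p] + D[p] - (D[v] + sub[v]) + (n - sub[v]);
        }
    }
};
```

The setter's solution stores D_u + U_u directly (`totdist`) through the equivalent update `totdist[v] = totdist[u] - 2 * subsize[v] + n`.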
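What we can do is compute \displaystyle\sum_{u} \left(dist(u, a) + dist(u, b)\right), which is just the total distance sum of a plus that of b (both available after the precomputation). We then subtract \displaystyle\sum_{u \in A} dist(u, b) and \displaystyle\sum_{u \in B} dist(u, a).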
Let’s say edge (x, y) was removed, where a is reachable from x and b is reachable from y.
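Then \displaystyle\sum_{u \in B} dist(u, a) = \sum_{u \in B} \left(dist(u, y) + dist(y, a)\right), because every path from a node of B to a passes through y. Here dist(y, a) is a constant, and \displaystyle \sum_{u \in B} dist(y, u) is the subtree distance sum of node y whenever x is the parent of y (B is then exactly the subtree of y).
Similarly, \displaystyle\sum_{u \in A} dist(u, b) = \sum_{u \in A} \left(dist(u, x) + dist(x, b)\right). Here dist(x, b) is a constant, and \displaystyle \sum_{u \in A} dist(x, u) is the subtree distance sum of node x whenever y is the parent of x.
In the opposite orientation, these sums can still be calculated from the D and U arrays together with the sub array, where sub denotes subtree sizes. For example, if x is the parent of y, then \displaystyle \sum_{u \in A} dist(x, u) = D_x + U_x - (D_y + sub_y), since D_y + sub_y is the sum of distances from x to the nodes in the subtree of y.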
It is worth trying to work out the exact formula yourself now. If you get stuck, my solution below has comments explaining each term.
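For reference, here is one way to assemble the formula. It writes D_u and U_u for the sums of distances from u to the nodes inside and outside its subtree, sub_u for subtree sizes, and L = dist(a, b). It also assumes b is taken to be the deeper endpoint, so that x is the parent of y and B is exactly the subtree of y. The terms match the comments in the editorialist's solution.

```latex
\begin{aligned}
\text{dist}(y, b) &= \left\lfloor \tfrac{L-1}{2} \right\rfloor, \qquad
\text{dist}(y, a) = L - \text{dist}(y, b), \qquad
\text{dist}(x, b) = \text{dist}(y, b) + 1 \\
\sum_{u \in B} \text{dist}(u, a) &= sub_y \cdot \text{dist}(y, a) + D_y \\
\sum_{u \in A} \text{dist}(u, b) &= (N - sub_y) \cdot \text{dist}(x, b) + \left(D_x + U_x - D_y - sub_y\right) \\
\text{answer} &= (D_a + U_a) + (D_b + U_b) - \sum_{u \in B} \text{dist}(u, a) - \sum_{u \in A} \text{dist}(u, b)
\end{aligned}
```

The term D_x + U_x - D_y - sub_y is the sum of distances from x to the nodes of A: it is x's total distance sum minus its distances into the subtree of y.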
TIME COMPLEXITY
For each query, we need to compute the distance between a and b and then the middle edge, which takes O(log(N)) with binary lifting. The rest is O(1) per query after O(N*log(N)) precomputation, giving time complexity O((N+Q)*log(N)) per test case.
SOLUTIONS
Setter's Solution
```cpp
#include<bits/stdc++.h>
# define pb push_back
#define pii pair<int, int>
#define mp make_pair
# define ll long long int
using namespace std;
const int maxtq = 5e5, maxtn = 5e5, maxn = 1e5, maxq = 1e5;
const string newln = "\n", space = " ";
vector<int> g[maxn + 10];
bool visit[maxn + 10];
int depth[maxn + 10], parent[maxn + 10][20], subsize[maxn + 10];
ll subdist[maxn + 10], totdist[maxn + 10];
int n, q;
bool isGraph(int u, int pa){
    parent[u][0] = pa; subsize[u] = 1; subdist[u] = 0;
    if(visit[u])return false;
    visit[u] = true;
    for(int v : g[u]){
        if(v == pa)continue;
        depth[v] = depth[u] + 1;
        if(!isGraph(v, u))return false;
        subsize[u] += subsize[v];
        subdist[u] += subdist[v] + subsize[v];
    }
    return true;
}
void dfs(int u, int pa){
    for(int v : g[u]){
        if(v == pa)continue;
        totdist[v] = totdist[u] - 2 * subsize[v] + n; //subdist[v] + totdist[u] - subdist[v] - subsize[v] + n - subsize[v]
        dfs(v, u);
    }
}
int lca(int u, int v){
    if(u == v)return u;
    if(depth[u] < depth[v])swap(u, v);
    for(int i = 19; i >= 0; i--){
        if(depth[u] - (1 << i) >= depth[v]){
            u = parent[u][i];
        }
    }
    if(u == v)return u;
    for(int i = 19; i >= 0; i--){
        if(parent[u][i] != parent[v][i]){
            u = parent[u][i]; v = parent[v][i];
        }
    }
    return parent[u][0];
}
int main()
{
    ios_base::sync_with_stdio(false); cin.tie(NULL); cout.tie(NULL);
    int t; cin >> t;
    int tn = 0, tq = 0;
    while(t--){
        cin >> n >> q;
        for(int i = 0; i <= n; i++){
            g[i].clear();
            visit[i] = false;
        }
        int u, v;
        for(int i = 1; i < n; i++){
            cin >> u >> v;
            assert(u != v); assert(u != 0); assert(v != 0);
            g[u].pb(v); g[v].pb(u);
        }
        depth[1] = 0;
        bool ok = isGraph(1, 0); // call outside assert so the DFS still runs if NDEBUG is defined
        assert(ok);
        for(int i = 1; i < 20; i++){
            for(int j = 1; j <= n; j++){
                parent[j][i] = parent[parent[j][i - 1]][i - 1];
            }
        }
        totdist[1] = subdist[1];
        dfs(1, 0);
        while(q--){
            cin >> u >> v;
            assert(u != v);
            if(depth[u] < depth[v])swap(u, v);
            int dist = depth[u] + depth[v] - 2 * depth[lca(u, v)];
            int jump = (dist - 1) / 2;
            int pa1 = u, pa2; // breaking edge
            for(int i = 19; i >= 0; i--){
                if(jump - (1 << i) >= 0){
                    jump -= (1 << i);
                    pa1 = parent[pa1][i];
                }
            }
            pa2 = parent[pa1][0];
            // u
            ll ans = totdist[u] - (totdist[pa2] - (subdist[pa1] + subsize[pa1]) + (ll)(n - subsize[pa1]) * ((dist + 1) / 2));
            // v
            ans += totdist[v] - (subdist[pa1] + (ll)subsize[pa1] * (dist - (dist - 1) / 2));
            cout << ans << newln; // newln instead of endl: no flush after every query
        }
    }
}
```
Tester's Solution
```cpp
#include <iostream>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>
#include <set>
#include <map>
#include <algorithm>
#include <random>
#ifdef HOME
#include <windows.h>
#endif
#define all(x) (x).begin(), (x).end()
#define rall(x) (x).rbegin(), (x).rend()
#define forn(i, n) for (int i = 0; i < (int)(n); ++i)
#define for1(i, n) for (int i = 1; i <= (int)(n); ++i)
#define ford(i, n) for (int i = (int)(n) - 1; i >= 0; --i)
#define fore(i, a, b) for (int i = (int)(a); i <= (int)(b); ++i)
template<class T> bool umin(T &a, T b) { return a > b ? (a = b, true) : false; }
template<class T> bool umax(T &a, T b) { return a < b ? (a = b, true) : false; }
using namespace std;
long long readInt(long long l, long long r, char endd) {
    long long x = 0;
    int cnt = 0;
    int fi = -1;
    bool is_neg = false;
    while (true) {
        char g = getchar();
        if (g == '-') {
            assert(fi == -1);
            is_neg = true;
            continue;
        }
        if ('0' <= g && g <= '9') {
            x *= 10;
            x += g - '0';
            if (cnt == 0) {
                fi = g - '0';
            }
            cnt++;
            assert(fi != 0 || cnt == 1);
            assert(fi != 0 || is_neg == false);
            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
        }
        else if (g == endd) {
            assert(cnt > 0);
            if (is_neg) {
                x = -x;
            }
            assert(l <= x && x <= r);
            return x;
        }
        else {
            //assert(false);
        }
    }
}
string readString(int l, int r, char endd) {
    string ret = "";
    int cnt = 0;
    while (true) {
        char g = getchar();
        assert(g != -1);
        if (g == endd) {
            break;
        }
        cnt++;
        ret += g;
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}
long long readIntSp(long long l, long long r) {
    return readInt(l, r, ' ');
}
long long readIntLn(long long l, long long r) {
    return readInt(l, r, '\n');
}
string readStringLn(int l, int r) {
    return readString(l, r, '\n');
}
string readStringSp(int l, int r) {
    return readString(l, r, ' ');
}
int main(int argc, char** argv)
{
#ifdef HOME
    if(IsDebuggerPresent())
    {
        freopen("../in.txt", "rb", stdin);
        freopen("../out.txt", "wb", stdout);
    }
#endif
    int T = readIntLn(1, 8);
    int sumN = 0;
    int sumQ = 0;
    forn (tc, T)
    {
        int N = readIntSp(2, 100'000);
        sumN += N;
        int Q = readIntLn(1, 100'000);
        sumQ += Q;
        vector<pair<int, int>> edges(N - 1);
        vector<vector<int>> neighb(N);
        vector<int64_t> childCtr(N);
        vector<int64_t> childDist(N);
        vector<int64_t> topDist(N);
        for(auto& ei: edges)
        {
            ei.first = readIntSp(1, N);
            ei.second = readIntLn(1, N);
            --ei.first;
            --ei.second;
            assert(ei.first != ei.second);
            neighb[ei.first].push_back(ei.second);
            neighb[ei.second].push_back(ei.first);
        }
        vector<int> parent(N, -1);
        vector<int> depth(N, 0);
        vector<bool> visited(N);
        vector<int> q(1);
        parent[0] = 0;
        forn(i, q.size())
        {
            int actNode = q[i];
            visited[actNode] = true;
            for (const auto& ne : neighb[actNode])
            {
                if (!visited[ne])
                {
                    visited[ne] = true;
                    parent[ne] = actNode;
                    q.push_back(ne);
                    depth[ne] = depth[actNode] + 1;
                }
            }
        }
        reverse(q.begin(), q.end());
        for(const auto& actNode: q)
        {
            for (const auto& ne : neighb[actNode])
            {
                if (parent[actNode] != ne)
                {
                    childCtr[actNode] += childCtr[ne] + 1;
                    childDist[actNode] += childDist[ne] + childCtr[ne] + 1;
                }
            }
        }
        reverse(q.begin(), q.end());
        for (const auto& actNode : q)
        {
            if(actNode==0)
                continue;
            int p = parent[actNode];
            topDist[actNode] = topDist[p] + childDist[p] + N- 2*(childCtr[actNode] + 1) - childDist[actNode];
        }
        assert(q.size() == N);
        vector<vector<int>> dp(20, vector<int>(N));
        dp[0] = parent;
        forn(i, 18)
        {
            forn(j, N)
            {
                dp[i + 1][j] = dp[i][dp[i][j]];
            }
        }
        auto findAncestor = [&](int u, int h)
        {
            forn(i, dp.size())
            {
                if (h & 1)
                {
                    u = dp[i][u];
                }
                h >>= 1;
                if(h == 0)
                    break;
            }
            return u;
        };
        auto findLCA = [&](int u, int v) {
            if (depth[u] < depth[v])
                swap(u, v);
            int dd = depth[u] - depth[v];
            //int res = dd;
            u = findAncestor(u, dd);
            //now u, v on the same level
            //find LCA 0, depth[0] where they have the same parent
            int td = depth[u];
            if (u == v)
                return u;
            int lo = 1, hi = td;
            while (lo < hi)
            {
                int mi = (lo + hi) / 2;
                int vp = findAncestor(v, mi);
                int up = findAncestor(u, mi);
                if (up == vp)
                {
                    hi = mi;
                }
                else
                {
                    lo = mi + 1;
                }
            }
            return findAncestor(u, lo);
        };
        forn(q, Q)
        {
            int a = readIntSp(1, N);
            int b = readIntLn(1, N);
            assert(a != b);
            --a;
            --b;
            if (depth[a] > depth[b])
                swap(a, b);
            int dd = findLCA(a, b);
            int dab = depth[a] + depth[b] - 2 * depth[dd];
            int mid = findAncestor(b, dab / 2);
            int dbm = dab/2;
            int dam = dab - dbm;
            //find the node in the middle
            int64_t ans = topDist[a] + childDist[a] - childDist[mid] - dam*(childCtr[mid]+1);
            ans += childDist[b] + topDist[b] - topDist[mid] - dbm * (N - childCtr[mid]);
            ans += dbm;
            printf("%lld\n", (long long)ans); // cast: int64_t is not always long long
        }
    }
    assert(sumN <= 500'000);
    assert(sumQ <= 500'000);
    return 0;
}
```
Editorialist's Solution
```java
import java.util.*;
import java.io.*;
class DUALDIST{
    //SOLUTION BEGIN
    final int B = 18;
    void pre() throws Exception{}
    void solve(int TC) throws Exception{
        int N = ni(), Q = ni();
        int[] from = new int[N-1], to = new int[N-1];
        for(int i = 0; i< N-1; i++){
            from[i] = ni()-1;
            to[i] = ni()-1;
        }
        int[][] g = make(N, from, to);
        long[] sub = new long[N], dDown = new long[N], dUp = new long[N];
        int[] depth = new int[N];
        int[][] par = new int[B][N];
        for(int b = 0; b< B; b++)Arrays.fill(par[b], -1);
        dfs(g, par, depth, sub, dDown, 0, -1);
        dfs2(g, sub, dDown, dUp, 0, -1);
        for(int q = 0; q< Q; q++){
            int u = ni()-1, v = ni()-1;
            if(depth[u] > depth[v]){
                int tmp = u;
                u = v;
                v = tmp;
            }
            int lca = lca(par, depth, u, v);
            int dist = depth[u]+depth[v]-2*depth[lca];
            int lift = (dist-1)/2;
            int ev = lift(par, v, lift);
            hold(ev != lca);
            int eu = par[0][ev];
            //Edge (eu, ev) is removed, where u is reachable from eu and v is reachable from ev
            //node eu is parent of ev
            //sub[ev]*dist(u, ev) => distance from u to ev for each node in subtree of ev
            //dDown[ev] = sum of distances of nodes in subtree of ev from ev
            //((N-sub[ev])*dist(par, depth, eu, v) => distance of node v to node eu for each node not in subtree of ev
            //dUp[eu] + dDown[eu] => distance of eu to all nodes
            //-(sub[ev]+dDown[ev])) => distance of eu to all nodes in subtree of ev
            long distU = dDown[u]+dUp[u] - (sub[ev]*dist(par, depth, u, ev)+dDown[ev]);
            long distV = dDown[v]+dUp[v] - ((N-sub[ev])*dist(par, depth, eu, v) + dUp[eu] + dDown[eu]-(sub[ev]+dDown[ev]));
            pn(distU+distV);
        }
    }
    void dfs2(int[][] g, long[] sub, long[] dDown, long[] dUp, int u, int p){
        int N = g.length;
        for(int v:g[u]){
            if(v == p)continue;
            dUp[v] = dUp[u] + dDown[u]-(sub[v]+dDown[v]) + (N-sub[v]);
            dfs2(g, sub, dDown, dUp, v, u);
        }
    }
    void dfs(int[][] g, int[][] par, int[] depth, long[] sub, long[] dDown, int u, int p){
        par[0][u] = p;
        for(int b = 1; b< B; b++)
            if(par[b-1][u] != -1)
                par[b][u] = par[b-1][par[b-1][u]];
        sub[u]++;
        for(int v:g[u]){
            if(v == p)continue;
            depth[v] = depth[u]+1;
            dfs(g, par, depth, sub, dDown, v, u);
            sub[u] += sub[v];
            dDown[u] += sub[v]+dDown[v];
        }
    }
    int dist(int[][] par, int[] d, int u, int v){
        return d[u]+d[v]-2*d[lca(par, d, u, v)];
    }
    int lca(int[][] par, int[] d, int u, int v){
        if(d[u] > d[v])u = lift(par, u, d[u]-d[v]);
        if(d[v] > d[u])v = lift(par, v, d[v]-d[u]);
        if(u == v)return u;
        for(int b = B-1; b >= 0; b--)
            if(par[b][u] != par[b][v]){
                u = par[b][u];
                v = par[b][v];
            }
        return par[0][u];
    }
    int lift(int[][] par, int u, int delta){
        for(int b = B-1; b >= 0; b--)
            if(((delta>>b)&1)==1)
                u = par[b][u];
        return u;
    }
    int[][] make(int N, int[] from, int[] to){
        int[] cnt = new int[N];
        for(int x:from)cnt[x]++;
        for(int x:to)cnt[x]++;
        int[][] g = new int[N][];
        for(int i = 0; i< N; i++)g[i] = new int[cnt[i]];
        for(int i = 0; i< N-1; i++){
            g[from[i]][--cnt[from[i]]] = to[i];
            g[to[i]][--cnt[to[i]]] = from[i];
        }
        return g;
    }
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    static boolean multipleTC = true;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        new DUALDIST().run();
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}
    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }
        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }
        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }
        String nextLine() throws Exception{
            String str = "";
            try{
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }
            return str;
        }
    }
}
```
Feel free to share your approach. Suggestions are welcome as always.