DUALDIST - Editorial


Contest Division 1
Contest Division 2
Contest Division 3

Setter: Daanish Mahajan
Tester: Istvan Nagy
Editorialist: Taranpreet Singh




Prerequisites: Tree DP, binary lifting (LCA).


Given a tree with N nodes, answer Q queries of the following type.

  • Given two nodes a and b, compute \displaystyle \sum_{i = 1}^N \text{min}(\text{dist}(i, a), \text{dist}(i, b))



First Observation

Let’s assume we have to solve only 1 query. For given nodes a and b, compute \displaystyle \sum_{i = 1}^N \text{min}(\text{dist}(i, a), \text{dist}(i, b)).
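Before building the fast solution, it helps to have an O(N)-per-query brute force to check answers against. Below is a minimal C++ sketch. The names `Tree`, `buildTree`, `bfsDist` and `bruteForce` are illustrative and do not come from the official solutions, and nodes are assumed to be numbered 1..N.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <queue>
#include <utility>
#include <vector>

using Tree = std::vector<std::vector<int>>;  // 1-indexed adjacency list; index 0 unused

Tree buildTree(int n, const std::vector<std::pair<int, int>>& edges) {
    Tree adj(n + 1);
    for (const auto& [u, v] : edges) {
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    return adj;
}

// Distances (number of edges) from src to every node, by BFS.
std::vector<int> bfsDist(const Tree& adj, int src) {
    std::vector<int> dist(adj.size(), -1);
    std::queue<int> q;
    dist[src] = 0;
    q.push(src);
    while (!q.empty()) {
        int u = q.front();
        q.pop();
        for (int v : adj[u]) {
            if (dist[v] < 0) {
                dist[v] = dist[u] + 1;
                q.push(v);
            }
        }
    }
    return dist;
}

// O(N) per query: sum over all nodes i of min(dist(i, a), dist(i, b)).
int64_t bruteForce(const Tree& adj, int a, int b) {
    std::vector<int> da = bfsDist(adj, a), db = bfsDist(adj, b);
    int64_t total = 0;
    for (size_t i = 1; i < adj.size(); ++i) total += std::min(da[i], db[i]);
    return total;
}
```

Take a hypothetical 12-node tree with edges 1-3, 3-2, 2-4, 4-5, 1-6, 6-11, 3-8, 8-9, 3-10, 2-12, 5-7. It was chosen to agree with the grouping in the example that follows; the figure's actual edges may differ. On this tree, `bruteForce(adj, 1, 5)` returns 18.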

Let us write down all nodes on the path from a to b, in the order they appear.

Consider the tree in the image above with a = 1 and b = 5. The nodes on the path from a to b are 1 3 2 4 5. Let's call these nodes special nodes (highlighted in red).

For every node u, let sp_u denote the special node that u can reach by travelling only through non-special nodes (for a special node u, sp_u = u).

This is well defined: from every non-special node, exactly one special node can be reached by travelling only through non-special nodes.

In the above example:

  • nodes 1, 6 and 11 reach node 1, so sp_1 = sp_6 = sp_{11} = 1;
  • nodes 3, 8, 9 and 10 reach node 3, so sp_3 = sp_8 = sp_9 = sp_{10} = 3;
  • nodes 2 and 12 reach node 2, so sp_2 = sp_{12} = 2;
  • only node 4 reaches node 4, so sp_4 = 4;
  • nodes 5 and 7 reach node 5, so sp_5 = sp_7 = 5.

Claim: If we root the tree at node u, the lca of a and b is sp_u.

In other words, the shortest paths from u to a and from u to b share the part from u to sp_u, and separate only after it.

Hence, for every node u, \text{dist}(u, a) = \text{dist}(u, sp_u) + \text{dist}(sp_u, a) and \text{dist}(u, b) = \text{dist}(u, sp_u) + \text{dist}(sp_u, b).

Hence, \text{min}(\text{dist}(u, a), \text{dist}(u, b)) = \text{dist}(u, sp_u) + \text{min}(\text{dist}(sp_u, a), \text{dist}(sp_u, b)).

So if we group all nodes by sp_u, all nodes in the same group make the same choice. Either every node in the group is at least as close to a as to b, or every node is at least as close to b as to a. Which case applies depends only on the group's special node.

So, if for each special node we know whether it is closer to a or to b, we can decide for the whole group, since each group contains exactly one special node.
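The first observation can be sketched in C++ as follows. This is an illustrative sketch, not the official method, and all names are hypothetical. The steps are:

  • find the special nodes;
  • run a multi-source BFS from all of them at once to get sp_u and dist(u, sp_u);
  • add dist(u, sp_u) + min(dist(sp_u, a), dist(sp_u, b)) for every node.

Because the BFS starts from every special node simultaneously, it never re-enters a special node. So each non-special node is reached only from its own sp_u.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <queue>
#include <utility>
#include <vector>

using Tree = std::vector<std::vector<int>>;  // 1-indexed adjacency list; index 0 unused

Tree buildTree(int n, const std::vector<std::pair<int, int>>& edges) {
    Tree adj(n + 1);
    for (const auto& [u, v] : edges) {
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    return adj;
}

// Nodes on the a-b path, in order from a to b (parent pointers from a BFS rooted at a).
std::vector<int> pathNodes(const Tree& adj, int a, int b) {
    std::vector<int> par(adj.size(), 0);
    std::vector<bool> seen(adj.size(), false);
    std::queue<int> q;
    q.push(a);
    seen[a] = true;
    while (!q.empty()) {
        int u = q.front();
        q.pop();
        for (int v : adj[u])
            if (!seen[v]) { seen[v] = true; par[v] = u; q.push(v); }
    }
    std::vector<int> path;
    for (int x = b; x != a; x = par[x]) path.push_back(x);
    path.push_back(a);
    std::reverse(path.begin(), path.end());
    return path;
}

// first[u] = sp_u, second[u] = dist(u, sp_u), via multi-source BFS from all special nodes.
std::pair<std::vector<int>, std::vector<int>> specialOwner(const Tree& adj, const std::vector<int>& path) {
    std::vector<int> sp(adj.size(), 0), d(adj.size(), -1);
    std::queue<int> q;
    for (int s : path) { sp[s] = s; d[s] = 0; q.push(s); }
    while (!q.empty()) {
        int u = q.front();
        q.pop();
        for (int v : adj[u])
            if (d[v] < 0) { d[v] = d[u] + 1; sp[v] = sp[u]; q.push(v); }
    }
    return {sp, d};
}

// Still O(N) per query: dist(u, sp_u) + min(dist(sp_u, a), dist(sp_u, b)) summed over u.
int64_t groupAnswer(const Tree& adj, int a, int b) {
    std::vector<int> path = pathNodes(adj, a, b);
    int last = (int)path.size() - 1;                    // last == dist(a, b)
    std::vector<int> pos(adj.size(), -1);
    for (int i = 0; i <= last; ++i) pos[path[i]] = i;   // dist(path[i], a) = i, dist(path[i], b) = last - i
    auto [sp, toSp] = specialOwner(adj, path);
    int64_t total = 0;
    for (size_t u = 1; u < adj.size(); ++u) {
        int i = pos[sp[u]];
        total += toSp[u] + std::min(i, last - i);
    }
    return total;
}
```

Consider a hypothetical tree with edges 1-3, 3-2, 2-4, 4-5, 1-6, 6-11, 3-8, 8-9, 3-10, 2-12, 5-7. There, `pathNodes(adj, 1, 5)` is {1, 3, 2, 4, 5}, the groups match the ones listed above, and `groupAnswer(adj, 1, 5)` returns 18.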

Another Observation

Let P be the list of nodes on the path from a to b, in the order they appear. As we move along P, dist(u, a) increases by 1 and dist(u, b) decreases by 1 at every step.

Hence the first half of P has dist(u, a) \leq dist(u, b), and the second half has dist(u, a) \geq dist(u, b). If P has an odd number of nodes, the middle node is equidistant from a and b. It may go in either half, but not in both, as that would count it twice.

In the above example, nodes 1 3 2 are in the left half and 4 5 are in the right half (node 2 may go in either half).

Hence, for every node u whose sp_u lies in the left half of P, we add dist(u, a) to the answer. For every node u whose sp_u lies in the right half, we add dist(u, b).


Think of it as removing the edge (2, 4) from the tree above. What do we get?

We get two trees, one containing a and the other containing b. For all nodes in the tree containing a, we want the sum of \text{dist}(u, a). For all nodes in the tree containing b, we want the sum of \text{dist}(u, b). By the observation above, the total of these two sums is exactly the required sum of minimum distances. Let A denote the set of nodes connected to a after removing the edge, and B the set of nodes connected to b.

In the picture above, for A = \{1,2,3,6,8,9,10,11,12\} we must compute \displaystyle X = \sum_{u \in A} \text{dist}(u, a), and for B = \{4,5,7\} we must compute \displaystyle Y = \sum_{u \in B} \text{dist}(u, b). Then X+Y is the required answer.
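To see the split concretely, the following C++ sketch (names hypothetical) runs one BFS from a and one from b. Each BFS simply refuses to cross the removed edge and sums the distances inside its component. This is still O(N) per query; it only illustrates the X + Y decomposition.

```cpp
#include <cassert>
#include <cstdint>
#include <queue>
#include <utility>
#include <vector>

using Tree = std::vector<std::vector<int>>;  // 1-indexed adjacency list; index 0 unused

Tree buildTree(int n, const std::vector<std::pair<int, int>>& edges) {
    Tree adj(n + 1);
    for (const auto& [u, v] : edges) {
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    return adj;
}

// Sum of dist(src, u) over every u in src's component after virtually removing edge {cx, cy}.
// The tree itself is not modified; the BFS just never traverses that edge.
int64_t componentDistSum(const Tree& adj, int src, int cx, int cy) {
    std::vector<int> d(adj.size(), -1);
    std::queue<int> q;
    d[src] = 0;
    q.push(src);
    int64_t sum = 0;
    while (!q.empty()) {
        int u = q.front();
        q.pop();
        sum += d[u];
        for (int v : adj[u]) {
            bool removed = (u == cx && v == cy) || (u == cy && v == cx);
            if (!removed && d[v] < 0) {
                d[v] = d[u] + 1;
                q.push(v);
            }
        }
    }
    return sum;
}
```

Use a hypothetical tree with edges 1-3, 3-2, 2-4, 4-5, 1-6, 6-11, 3-8, 8-9, 3-10, 2-12, 5-7, which is consistent with A and B above.

  • Cutting (2, 4) gives X = `componentDistSum(adj, 1, 2, 4)` = 16 and Y = `componentDistSum(adj, 5, 2, 4)` = 2.
  • Cutting (3, 2) instead puts the tied node 2 on b's side and gives 11 + 7.

Both totals are 18.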

Enough about the idea. How do we compute the distances?

From this point on, we have to answer Q queries, each specifying a and b, so we cannot afford to iterate over all nodes computing distances for every query. What information do we need?

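Using precomputation (binary lifting for LCA), we can compute dist(a, b) and from it locate the edge to remove. With the tree rooted and b the deeper endpoint, this is the edge between the ancestor of b at distance \lfloor (\text{dist}(a, b) - 1)/2 \rfloor and that ancestor's parent. (We only imagine removing this edge; the tree is never actually modified.)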

So for the pair (1, 5), we find the edge (2, 4), which we remove.

Let's choose any node as the root and precompute D_u = \displaystyle \sum_{v \in sub_u} \text{dist}(u, v), where sub_u denotes the set of nodes in the subtree of node u. A simple DFS does this.

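Then compute U_u = \displaystyle \sum_{v \notin sub_u} \text{dist}(u, v), the sum of distances from node u to all nodes outside the subtree of u. This requires tree DP, specifically in-out DP (rerooting). For a child v of u, U_v = U_u + D_u - (D_v + |sub_v|) + (N - |sub_v|). The reason: the distances from u to the nodes outside sub_v add up to U_u + D_u - (D_v + |sub_v|), and each of those N - |sub_v| nodes is one edge farther from v. This video is a good place to learn it.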

We can compute \displaystyle\sum_{u} (\text{dist}(u, a) + \text{dist}(u, b)) = (D_a + U_a) + (D_b + U_b), and then subtract \displaystyle\sum_{u \in A} \text{dist}(u, b) and \displaystyle\sum_{u \in B} \text{dist}(u, a).

Let’s say edge (x, y) was removed, where a is reachable from x and b is reachable from y.

Then \displaystyle\sum_{u \in B} \text{dist}(u, a) = \sum_{u \in B} (\text{dist}(y, a) + \text{dist}(y, u)) = |B| \cdot \text{dist}(y, a) + \sum_{u \in B} \text{dist}(y, u). Here \text{dist}(y, a) is a constant. If x is the parent of y, then B = sub_y, so the last sum is just D_y.

Similarly, \displaystyle\sum_{u \in A} \text{dist}(u, b) = \sum_{u \in A} (\text{dist}(x, b) + \text{dist}(x, u)) = |A| \cdot \text{dist}(x, b) + \sum_{u \in A} \text{dist}(x, u). Here \text{dist}(x, b) is a constant. If y is the parent of x, then A = sub_x, so the last sum is just D_x.

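In the other orientation, these sums can still be computed from the D and U arrays and the subtree sizes |sub_u|. For example, if x is the parent of y, the nodes of sub_y contribute D_y + |sub_y| to the total D_x + U_x, so \displaystyle\sum_{u \in A} \text{dist}(x, u) = D_x + U_x - (D_y + |sub_y|). The case where y is the parent of x is symmetric for the sum over B.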

It is worth trying to work out the exact formula yourself now. If you get stuck, I have added comments in my solution below.
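For reference, here is the formula assembled for the case the solutions below use. We assume b is the deeper endpoint and the cut edge is (x, y) with x the parent of y, so B is the subtree of y. Write d = dist(a, b) and k = \lfloor (d - 1)/2 \rfloor, so y is the ancestor of b at distance k:

```latex
\begin{aligned}
\text{dist}(y, a) &= d - k, \qquad \text{dist}(x, b) = k + 1 \\
\sum_{u \in B} \text{dist}(u, a) &= |sub_y| \cdot (d - k) + D_y \\
\sum_{u \in A} \text{dist}(u, b) &= (N - |sub_y|) \cdot (k + 1) + D_x + U_x - (D_y + |sub_y|) \\
\text{answer} &= (D_a + U_a) + (D_b + U_b) - \sum_{u \in B} \text{dist}(u, a) - \sum_{u \in A} \text{dist}(u, b)
\end{aligned}
```

Up to renaming the endpoints, this matches the setter's expression, where (dist + 1) / 2 equals k + 1 and dist - (dist - 1) / 2 equals d - k.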


For each query, we need dist(a, b) and then the middle edge, both computable in O(log(N)) with binary lifting. The rest is O(1) per query after O(N*log(N)) precomputation, giving time complexity O((N+Q)*log(N)) per test case.


Setter's Solution
#include <bits/stdc++.h>
# define pb push_back 
#define pii pair<int, int>
#define mp make_pair
# define ll long long int

using namespace std;

const int maxtq = 5e5, maxtn = 5e5, maxn = 1e5, maxq = 1e5;
const string newln = "\n", space = " ";
vector<int> g[maxn + 10];
bool visit[maxn + 10];
int depth[maxn + 10], parent[maxn + 10][20], subsize[maxn + 10];
ll subdist[maxn + 10], totdist[maxn + 10];
int n, q;
// checks the input is acyclic and fills parent, depth, subsize and
// subdist[u] = sum of distances from u to the nodes of its subtree
bool isGraph(int u, int pa){
    parent[u][0] = pa; subsize[u] = 1; subdist[u] = 0;
    if(visit[u])return false;
    visit[u] = true;
    for(int v : g[u]){
        if(v == pa)continue;
        depth[v] = depth[u] + 1;
        if(!isGraph(v, u))return false;
        subsize[u] += subsize[v];
        subdist[u] += subdist[v] + subsize[v];
    }
    return true;
}

// rerooting: totdist[v] = sum of distances from v to all nodes
void dfs(int u, int pa){
    for(int v : g[u]){
        if(v == pa)continue;
        totdist[v] = totdist[u] - 2 * subsize[v] + n; //subdist[v] + totdist[u] - subdist[v] - subsize[v] + n - subsize[v]
        dfs(v, u);
    }
}

int lca(int u, int v){
    if(u == v)return u;
    if(depth[u] < depth[v])swap(u, v);
    for(int i = 19; i >= 0; i--){
        if(depth[u] - (1 << i) >= depth[v]){
            u = parent[u][i];
        }
    }
    if(u == v)return u;
    for(int i = 19; i >= 0; i--){
        if(parent[u][i] != parent[v][i]){
            u = parent[u][i]; v = parent[v][i];
        }
    }
    return parent[u][0];
}
int main()
{
    ios_base::sync_with_stdio(false); cin.tie(NULL); cout.tie(NULL);
    int t; cin >> t;
    int tn = 0, tq = 0;
    while(t--){
        cin >> n >> q;
        tn += n; tq += q;
        for(int i = 0; i <= n; i++){
            visit[i] = false;
            g[i].clear(); // reset adjacency between test cases
        }
        int u, v;
        for(int i = 1; i < n; i++){
            cin >> u >> v;
            assert(u != v); assert(u != 0); assert(v != 0);
            g[u].pb(v); g[v].pb(u);
        }
        depth[1] = 0;
        bool ok = isGraph(1, 0); // keep the call outside assert so it still runs under NDEBUG
        assert(ok);

        for(int i = 1; i < 20; i++){
            for(int j = 1; j <= n; j++){
                parent[j][i] = parent[parent[j][i - 1]][i - 1];
            }
        }

        totdist[1] = subdist[1];
        dfs(1, 0);

        while(q--){
            cin >> u >> v;
            assert(u != v);
            if(depth[u] < depth[v])swap(u, v); // u is the deeper endpoint
            int dist = depth[u] + depth[v] - 2 * depth[lca(u, v)];
            int jump = (dist - 1) / 2;

            int pa1 = u, pa2; // breaking edge (pa1, pa2), pa2 = parent of pa1
            for(int i = 19; i >= 0; i--){
                if(jump - (1 << i) >= 0){
                    jump -= (1 << i);
                    pa1 = parent[pa1][i];
                }
            }
            pa2 = parent[pa1][0];
            // u keeps the subtree of pa1: drop distances from u to nodes outside it
            ll ans = totdist[u] - (totdist[pa2] - (subdist[pa1] + subsize[pa1]) + (ll)(n - subsize[pa1]) * ((dist + 1) / 2));
            // v keeps the rest: drop distances from v to nodes inside the subtree of pa1
            ans += totdist[v] - (subdist[pa1] + (ll)subsize[pa1] * (dist - (dist - 1) / 2));

            cout << ans << newln;
        }
    }
    assert(tn <= maxtn); assert(tq <= maxtq);
    return 0;
}
Tester's Solution
#include <iostream>
#include <cassert>
#include <cstdio>
#include <string>
#include <vector>
#include <set>
#include <map>
#include <algorithm>
#include <random>

#ifdef HOME
    #include <windows.h>
#endif

#define all(x) (x).begin(), (x).end()
#define rall(x) (x).rbegin(), (x).rend()
#define forn(i, n) for (int i = 0; i < (int)(n); ++i)
#define for1(i, n) for (int i = 1; i <= (int)(n); ++i)
#define ford(i, n) for (int i = (int)(n) - 1; i >= 0; --i)
#define fore(i, a, b) for (int i = (int)(a); i <= (int)(b); ++i)

template<class T> bool umin(T &a, T b) { return a > b ? (a = b, true) : false; }
template<class T> bool umax(T &a, T b) { return a < b ? (a = b, true) : false; }

using namespace std;

long long readInt(long long l, long long r, char endd) {
    long long x = 0;
    int cnt = 0;
    int fi = -1;
    bool is_neg = false;
    while (true) {
        char g = getchar();
        if (g == '-') {
            assert(fi == -1);
            is_neg = true;
            continue;
        }
        if ('0' <= g && g <= '9') {
            x *= 10;
            x += g - '0';
            if (cnt == 0) {
                fi = g - '0';
            }
            cnt++;
            assert(fi != 0 || cnt == 1);
            assert(fi != 0 || is_neg == false);

            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
        }
        else if (g == endd) {
            assert(cnt > 0);
            if (is_neg) {
                x = -x;
            }
            assert(l <= x && x <= r);
            return x;
        }
        else {
            assert(false);
        }
    }
}

string readString(int l, int r, char endd) {
    string ret = "";
    int cnt = 0;
    while (true) {
        char g = getchar();
        assert(g != -1);
        if (g == endd) {
            break;
        }
        cnt++;
        ret += g;
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}
long long readIntSp(long long l, long long r) {
    return readInt(l, r, ' ');
}
long long readIntLn(long long l, long long r) {
    return readInt(l, r, '\n');
}
string readStringLn(int l, int r) {
    return readString(l, r, '\n');
}
string readStringSp(int l, int r) {
    return readString(l, r, ' ');
}

int main(int argc, char** argv) 
{
#ifdef HOME
    freopen("../in.txt", "rb", stdin);
    freopen("../out.txt", "wb", stdout);
#endif
    int T = readIntLn(1, 8);
    int sumN = 0;
    int sumQ = 0;
    forn (tc, T)
    {
        int N = readIntSp(2, 100'000);
        sumN += N;
        int Q = readIntLn(1, 100'000);
        sumQ += Q;
        vector<pair<int, int>> edges(N - 1);
        vector<vector<int>> neighb(N);
        vector<int64_t> childCtr(N);   // number of proper descendants
        vector<int64_t> childDist(N);  // sum of distances to nodes in the subtree
        vector<int64_t> topDist(N);    // sum of distances to nodes outside the subtree
        for(auto& ei: edges)
        {
            ei.first = readIntSp(1, N);
            ei.second = readIntLn(1, N);
            assert(ei.first != ei.second);
            --ei.first; --ei.second;   // 0-indexed
            neighb[ei.first].push_back(ei.second);
            neighb[ei.second].push_back(ei.first);
        }
        vector<int> parent(N, -1);
        vector<int> depth(N, 0);
        vector<bool> visited(N);
        vector<int> q(1);              // BFS queue, starting at root 0
        parent[0] = 0;
        forn(i, q.size())
        {
            int actNode = q[i];
            visited[actNode] = true;
            for (const auto& ne : neighb[actNode])
            {
                if (!visited[ne])
                {
                    visited[ne] = true;
                    parent[ne] = actNode;
                    depth[ne] = depth[actNode] + 1;
                    q.push_back(ne);
                }
            }
        }
        reverse(q.begin(), q.end());   // children before parents
        for(const auto& actNode: q)
        {
            for (const auto& ne : neighb[actNode])
            {
                if (parent[actNode] != ne)
                {
                    childCtr[actNode] += childCtr[ne] + 1;
                    childDist[actNode] += childDist[ne] + childCtr[ne] + 1;
                }
            }
        }

        reverse(q.begin(), q.end());   // parents before children
        for (const auto& actNode : q)
        {
            if (actNode == 0)
                continue;              // the root has no nodes above it
            int p = parent[actNode];
            topDist[actNode] = topDist[p] + childDist[p] + N - 2*(childCtr[actNode] + 1) - childDist[actNode];
        }
        assert(q.size() == N);
        vector<vector<int>> dp(20, vector<int>(N));
        dp[0] = parent;
        forn(i, 18)
            forn(j, N)
                dp[i + 1][j] = dp[i][dp[i][j]];

        auto findAncestor = [&](int u, int h)
        {
            forn(i, dp.size())
            {
                if (h & 1)
                    u = dp[i][u];
                h >>= 1;
                if(h == 0)
                    break;
            }
            return u;
        };

        auto findLCA = [&](int u, int v) {
            if (depth[u] < depth[v])
                swap(u, v);
            int dd = depth[u] - depth[v];
            //int res = dd;
            u = findAncestor(u, dd);
            //now u, v on the same level
            //find LCA 0, depth[0] where they have the same parent
            int td = depth[u];
            if (u == v)
                return u;
            int lo = 1, hi = td;
            while (lo < hi)
            {
                int mi = (lo + hi) / 2;
                int vp = findAncestor(v, mi);
                int up = findAncestor(u, mi);
                if (up == vp)
                    hi = mi;
                else
                    lo = mi + 1;
            }
            return findAncestor(u, lo);
        };
        forn(q, Q)
        {
            int a = readIntSp(1, N);
            int b = readIntLn(1, N);
            assert(a != b);
            --a; --b;                  // 0-indexed
            if (depth[a] > depth[b])
                swap(a, b);            // b is the deeper endpoint
            int dd = findLCA(a, b);
            int dab = depth[a] + depth[b] - 2 * depth[dd];
            //find the node in the middle
            int mid = findAncestor(b, dab / 2);
            int dbm = dab/2;
            int dam = dab - dbm;

            int64_t ans = topDist[a] + childDist[a] - childDist[mid] - dam*(childCtr[mid]+1);
            ans += childDist[b] + topDist[b] - topDist[mid] - dbm * (N - childCtr[mid]);
            ans += dbm;
            printf("%lld\n", (long long)ans);
        }
    }
    assert(sumN <= 500'000);
    assert(sumQ <= 500'000);
    return 0;
}
Editorialist's Solution
import java.util.*;
import java.io.*;
class DUALDIST{
    final int B = 18;
    void pre() throws Exception{}
    void solve(int TC) throws Exception{
        int N = ni(), Q = ni();
        int[] from = new int[N-1], to = new int[N-1];
        for(int i = 0; i< N-1; i++){
            from[i] = ni()-1;
            to[i] = ni()-1;
        }
        int[][] g = make(N, from, to);
        long[] sub = new long[N], dDown = new long[N], dUp = new long[N];
        int[] depth = new int[N];
        int[][] par = new int[B][N];
        for(int b = 0; b< B; b++)Arrays.fill(par[b], -1);
        dfs(g, par, depth, sub, dDown, 0, -1);
        dfs2(g, sub, dDown, dUp, 0, -1);
        for(int q = 0; q< Q; q++){
            int u = ni()-1, v = ni()-1;
            if(depth[u] > depth[v]){
                int tmp = u;
                u = v;
                v = tmp;
            }
            int lca = lca(par, depth, u, v);
            int dist = depth[u]+depth[v]-2*depth[lca];
            int lift = (dist-1)/2;
            int ev = lift(par, v, lift);
            hold(ev != lca);
            int eu = par[0][ev];
            //Edge (eu, ev) is removed, where u is reachable from eu and v is reachable from ev
            //node eu is parent of ev
            //sub[ev]*dist(u, ev) => distance from u to ev for each node in subtree of ev
            //dDown[ev] = sum of distances of nodes in subtree of ev from ev
            //((N-sub[ev])*dist(par, depth, eu, v) => distance of node v to node eu for each node not in subtree of ev
            //dUp[eu] + dDown[eu] => distance of eu to all nodes
            //-(sub[ev]+dDown[ev])) => distance of eu to all nodes in subtree of ev
            long distU = dDown[u]+dUp[u] - (sub[ev]*dist(par, depth, u, ev)+dDown[ev]);
            long distV = dDown[v]+dUp[v] - ((N-sub[ev])*dist(par, depth, eu, v) + dUp[eu] + dDown[eu]-(sub[ev]+dDown[ev]));
            pn(distU+distV);
        }
    }
    void dfs2(int[][] g, long[] sub, long[] dDown, long[] dUp, int u, int p){
        int N = g.length;
        for(int v:g[u]){
            if(v == p)continue;
            dUp[v] = dUp[u] + dDown[u]-(sub[v]+dDown[v]) + (N-sub[v]);
            dfs2(g, sub, dDown, dUp, v, u);
        }
    }
    void dfs(int[][] g, int[][] par, int[] depth, long[] sub, long[] dDown, int u, int p){
        par[0][u] = p;
        for(int b = 1; b< B; b++)
            if(par[b-1][u] != -1)
                par[b][u] = par[b-1][par[b-1][u]];
        sub[u] = 1; //u itself
        for(int v:g[u]){
            if(v == p)continue;
            depth[v] = depth[u]+1;
            dfs(g, par, depth, sub, dDown, v, u);
            sub[u] += sub[v];
            dDown[u] += sub[v]+dDown[v];
        }
    }
    int dist(int[][] par, int[] d, int u, int v){
        return d[u]+d[v]-2*d[lca(par, d, u, v)];
    }
    int lca(int[][] par, int[] d, int u, int v){
        if(d[u] > d[v])u = lift(par, u, d[u]-d[v]);
        if(d[v] > d[u])v = lift(par, v, d[v]-d[u]);
        if(u == v)return u;
        for(int b = B-1; b >= 0; b--){
            if(par[b][u] != par[b][v]){
                u = par[b][u];
                v = par[b][v];
            }
        }
        return par[0][u];
    }
    int lift(int[][] par, int u, int delta){
        for(int b = B-1; b >= 0; b--)
            if(((delta>>b)&1) == 1)
                u = par[b][u];
        return u;
    }
    int[][] make(int N, int[] from, int[] to){
        int[] cnt = new int[N];
        for(int x:from)cnt[x]++;
        for(int x:to)cnt[x]++;
        int[][] g = new int[N][];
        for(int i = 0; i< N; i++)g[i] = new int[cnt[i]];
        for(int i = 0; i< N-1; i++){
            g[from[i]][--cnt[from[i]]] = to[i];
            g[to[i]][--cnt[to[i]]] = from[i];
        }
        return g;
    }
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    static boolean multipleTC = true;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        //dfs and dfs2 recurse up to N levels deep, so run with a large stack
        Thread t = new Thread(null, () -> {
            try{
                new DUALDIST().run();
            }catch(Exception e){
                e.printStackTrace();
                System.exit(1);
            }
        }, "DUALDIST", 1 << 28);
        t.start();
        t.join();
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}

    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }

        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }

        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException  e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }

        String nextLine() throws Exception{
            String str = "";
            try{
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }
            return str;
        }
    }
}

Feel free to share your approach. Suggestions are welcome as always. :)


Nice little extension to this problem:
→ In each query find \displaystyle\sum_{i=1}^N \left\lvert D(i,a)-D(i,b)\right\rvert


Hint: After cutting the edge, one component has D(i, a) \geq D(i, b) and the other component has D(i, a) \leq D(i, b).


Cool! Another, more mathematical way to go about it, for any node i:

\min(D(i,a),D(i,b)) = \frac{1}{2}\left(D(i,a)+D(i,b)-\lvert D(i,a)-D(i,b)\rvert\right)

\lvert D(i,a)-D(i,b)\rvert = D(i,a)+D(i,b)-2\times\min(D(i,a),D(i,b))

\sum_{i=1}^N\lvert D(i,a)-D(i,b)\rvert = \sum_{i=1}^N D(i,a)+\sum_{i=1}^N D(i,b)-2\times\sum_{i=1}^N \min(D(i,a),D(i,b))

Now, while solving the above problem, we have already evaluated every term on the right-hand side.
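Here is a small C++ sketch of this identity (function names hypothetical). Per query, the extension answer is one line once three values are known: the DUALDIST answer and the two distance totals D_a + U_a and D_b + U_b. A brute force is included for checking.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <queue>
#include <utility>
#include <vector>

using Tree = std::vector<std::vector<int>>;  // 1-indexed adjacency list; index 0 unused

Tree buildTree(int n, const std::vector<std::pair<int, int>>& edges) {
    Tree adj(n + 1);
    for (const auto& [u, v] : edges) {
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    return adj;
}

std::vector<int> bfsDist(const Tree& adj, int src) {
    std::vector<int> dist(adj.size(), -1);
    std::queue<int> q;
    dist[src] = 0;
    q.push(src);
    while (!q.empty()) {
        int u = q.front();
        q.pop();
        for (int v : adj[u])
            if (dist[v] < 0) { dist[v] = dist[u] + 1; q.push(v); }
    }
    return dist;
}

// sum |D(i,a) - D(i,b)| = sum D(i,a) + sum D(i,b) - 2 * sum min(D(i,a), D(i,b)),
// where totalA = D_a + U_a, totalB = D_b + U_b and minSum is the DUALDIST answer.
int64_t absDiffFromSums(int64_t totalA, int64_t totalB, int64_t minSum) {
    return totalA + totalB - 2 * minSum;
}

// O(N) brute-force reference for the extension.
int64_t absDiffBrute(const Tree& adj, int a, int b) {
    std::vector<int> da = bfsDist(adj, a), db = bfsDist(adj, b);
    int64_t total = 0;
    for (size_t i = 1; i < adj.size(); ++i) total += std::abs(da[i] - db[i]);
    return total;
}
```

Take a hypothetical tree with edges 1-3, 3-2, 2-4, 4-5, 1-6, 6-11, 3-8, 8-9, 3-10, 2-12, 5-7. For a = 1 and b = 5, the totals are 28 and 38 and the DUALDIST answer is 18. So `absDiffFromSums(28, 38, 18)` gives 30, the same as `absDiffBrute(adj, 1, 5)`.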


@taran_1407 I came up with a similar O((N+Q) log N) solution in Python but got TLE. Why?

Can @cubefreak777, @taran_1407 or anyone else explain the formula for ans in the main function?

I haven't gone through the approach that the editorial suggests; I can explain it once I do. Until then, you can check my \mathcal{O}(N \log N) solution, in case it happens to make more sense to you.


Instead of a single DP storing the sum of distances directly, I broke it into two parts: the sum of distances from the current node to the nodes outside its subtree (dp_up in the implementation below), and the sum of distances to the nodes in its subtree (dp_down).
With this, it becomes quite evident why such a formula appears.
Breaking the edge removes some contributions from both a's total and b's total.
Let a be the node farther from the LCA of a and b.
Let the broken edge be (c, d).

Then for a, we have to remove the contribution of the nodes above the broken-edge node, i.e. dp_up of c. But these nodes are farther from a than from c, so for every removed node we also add the distance between c and a. In total, we remove dp_up[c] + (n - number of nodes in subtree of c) * (distance between c and a, which is (dist+1) / 2).
Similarly, for b we need to remove dp_down of d after adjusting it in the same way.
Implementation: Solution: 47663817 | CodeChef

Thanks, this one really helped. :)


@taran_1407 when will you reveal the remaining editorials?


Great editorial. I had this in mind on the 4th day of the contest.


All editorials except OPTSET are added. SUBTRCOV is ready, would be posted today.


When will the editorial of OPTSET be posted? It looks like a really interesting problem to me.