DDQUERY Editorial

PROBLEM LINK:

Practice
Contest

Author & Editorialist: Jafar Badour
Tester: Teja Reddy

PROBLEM EXPLANATION

You are given an unweighted tree with N nodes (numbered 1 through N). Let’s denote the distance between any two nodes p and q by d(p, q).

You should answer Q queries. In each query, you are given parameters a, d_a, b, d_b, and you should find a node x such that d(x, a) = d_a and d(x, b) = d_b, or determine that there is no such node.

1 \leq N,Q \leq 10^6

DIFFICULTY:

Medium

PREREQUISITES:

Trees, LCA, dp

EXPLANATION:

This editorial assumes you are familiar with the following:
Finding the LCA of two nodes in a tree.
DP on trees.

Take two nodes x, y and some other node mid that lies strictly on the path between them, at distances dx, dy from x and y respectively.

If we move from mid one edge closer to y (staying on the path between x and y), the distances become (dx+1,dy-1). If we move one edge closer to x, they become (dx-1,dy+1).

If we move from mid to an adjacent node that is not on the path between x and y, the distances become (dx+1,dy+1). (*)

Notice that in every case the parity of the sum of the two distances stays the same. So if for some query
(d_a+d_b) \% 2 \neq dis(a,b) \% 2, the answer is -1.
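
As a minimal sketch of this parity test (the helper name parityAllows is hypothetical and not part of the author's solution), the check can be isolated into a function that takes the two requested distances and the already computed distance between a and b:

```cpp
#include <cassert>

// Hypothetical helper: returns false when the parity argument alone
// rules out any answer. dab is the tree distance between a and b.
bool parityAllows(int da, int db, int dab) {
    assert(da >= 0 && db >= 0 && dab >= 0);
    // Every step away from the path a..b changes da + db by 0 or 2,
    // so da + db must have the same parity as dab.
    return (da + db) % 2 == dab % 2;
}
```

Passing this test is necessary but not sufficient: a query that passes can still have answer -1.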

Let’s denote by d the distance between x and y, and let P=\{u_0,u_1,\ldots,u_d\} be the set of nodes on the path between them (d+1 nodes, with u_0=x and u_d=y). Take any node of this path and call it mid.

Clearly dis(mid,x)+dis(mid,y)=d, and there is no other node mid_2 \in P such that dis(mid_2,x)=dis(mid,x), because each step along the path changes the distance to x by exactly one (as shown a few lines above).

Also notice that for any node mid \in P we have -d \leq dis(mid,x)-dis(mid,y) \leq d

And for any K such that -d \leq K \leq d and K has the same parity as d, there is exactly one node mid \in P that satisfies dis(mid,x)-dis(mid,y)=K

What does this mean?

For a query (x,dx,y,dy) we first look for the node mid on the path between x and y such that dis(mid,x)-dis(mid,y)=dx-dy

After finding mid we move further into the tree until the query is satisfied, always in a direction that takes us further from both x and y (see *).

Finding mid is easy. Assume we have a function Kth(x,y,K) that returns the K-th node on the path from x to y, where the 1st node is x.

I will not explain how to write this function (you should know it already), but here is a snippet. In it, jump(x , K) returns the K-th ancestor of x.

int Kth(int x , int y , int K){
    int lca = LCA(x , y);
    int a = depth[x] - depth[lca] + 1;
    if(a >= K)
        return jump(x , K - 1);
    K -= a;
    return jump(y , depth[y] - depth[lca] - K);
}
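
When testing Kth, a slow reference version can help. The sketch below lists the whole path explicitly by climbing parent pointers; the names pathNodes and naiveKth, and the layout of the parent and depth arrays (nodes numbered from 1, index 0 unused), are assumptions of this sketch and not part of the author's code.

```cpp
#include <cassert>
#include <vector>

// Returns the nodes on the path x -> y in order, by repeatedly lifting
// the deeper endpoint until both meet at the LCA. O(path length).
std::vector<int> pathNodes(const std::vector<int>& parent,
                           const std::vector<int>& depth, int x, int y) {
    std::vector<int> left, right;
    while (x != y) {
        if (depth[x] >= depth[y]) { left.push_back(x); x = parent[x]; }
        else                      { right.push_back(y); y = parent[y]; }
    }
    left.push_back(x);  // the LCA itself
    // The y side was collected bottom-up, so append it reversed.
    left.insert(left.end(), right.rbegin(), right.rend());
    return left;
}

// K is 1-based: K = 1 returns x, K = path length returns y.
int naiveKth(const std::vector<int>& parent,
             const std::vector<int>& depth, int x, int y, int K) {
    std::vector<int> p = pathNodes(parent, depth, x, y);
    assert(1 <= K && K <= (int)p.size());
    return p[K - 1];
}
```

On the tree with edges 1-2, 1-3, 2-4, 2-5 rooted at 1, the path from 4 to 3 is 4, 2, 1, 3, so K = 3 gives node 1.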

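For a query x,dx,y,dy we need dis(mid,x)+dis(mid,y)=dis(x,y) and dis(mid,x)-dis(mid,y)=dx-dy, so dis(mid,x)=(dis(x,y)+dx-dy)/2 and therefore mid = Kth(x , y , (dis(x,y) + dx - dy)/2 + 1)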

Now we keep moving from mid, getting further from both x and y, until we reach the desired distances.

Note that dx-dis(mid,x) must be equal to dy-dis(mid,y); otherwise something went wrong. If this common value is negative, the answer is -1.
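
The arithmetic for locating mid can be sketched on its own, without any tree. The struct MidInfo and the function locateMid below are hypothetical names; the range checks are necessary conditions derived from the inequalities above and are written out explicitly here as an assumption of this sketch.

```cpp
#include <cassert>

// Where mid sits on the path and how many edges remain to be walked
// away from the path.
struct MidInfo {
    int fromX;  // dis(mid, x)
    int fromY;  // dis(mid, y)
    int extra;  // dx - dis(mid, x), which equals dy - dis(mid, y)
};

// d is dis(x, y). Returns false if the necessary conditions already fail.
bool locateMid(int d, int dx, int dy, MidInfo& out) {
    assert(d >= 0 && dx >= 0 && dy >= 0);
    if ((dx + dy) % 2 != d % 2) return false;      // parity mismatch
    if (dx - dy > d || dy - dx > d) return false;  // mid would leave the path
    if (dx + dy < d) return false;                 // extra would be negative
    out.fromX = (d + dx - dy) / 2;
    out.fromY = d - out.fromX;
    out.extra = dx - out.fromX;
    return true;
}
```

For example, with d = 4, dx = 3, dy = 3 the node mid is the middle of the path (2 edges from each end) and one extra edge must be walked off the path. The position passed to Kth is then fromX + 1.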

For each node x, let’s keep the 3 furthest leaves from x such that each of them lies in a different direction (it is reached through a different neighbour of x).

If we have the 3 furthest leaves from mid, our desired node (if it exists) is guaranteed to lie in the direction of one of them, because in the worst case 2 of them are in the directions of x and y. Suppose that in one of these 3 directions we have a leaf v at distance l from mid; then we should move dx-dis(mid,x) edges in that direction. The node we reach is t=Kth(v , mid , l - (dx-dis(mid,x)) + 1)

For each of the three directions, if the leaf is at least dx-dis(mid,x) edges away, we find the node t in this direction. After that, to confirm that we did not move in the direction of x or y, we simply check that dis(t,x)==dx and dis(t,y)==dy. If the check passes, t is our answer.
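
To validate an implementation of this idea on small trees, a brute-force checker is useful. The sketch below runs one BFS from each query endpoint and scans all nodes; bfsDist and bruteQuery are hypothetical names, and the adjacency list is assumed to be 1-indexed with index 0 unused. It costs O(N) per query, so it is only meant for testing, not for the full constraints.

```cpp
#include <cassert>
#include <queue>
#include <vector>

// Distances from src to every node of an unweighted tree (or graph).
std::vector<int> bfsDist(const std::vector<std::vector<int>>& adj, int src) {
    std::vector<int> dist(adj.size(), -1);
    std::queue<int> q;
    dist[src] = 0;
    q.push(src);
    while (!q.empty()) {
        int u = q.front();
        q.pop();
        for (int w : adj[u]) {
            if (dist[w] == -1) {
                dist[w] = dist[u] + 1;
                q.push(w);
            }
        }
    }
    return dist;
}

// Returns the smallest node x with d(x, a) == da and d(x, b) == db,
// or -1 if no such node exists.
int bruteQuery(const std::vector<std::vector<int>>& adj,
               int a, int da, int b, int db) {
    assert(a >= 1 && b >= 1 && a < (int)adj.size() && b < (int)adj.size());
    std::vector<int> fromA = bfsDist(adj, a), fromB = bfsDist(adj, b);
    for (int x = 1; x < (int)adj.size(); ++x)
        if (fromA[x] == da && fromB[x] == db) return x;
    return -1;
}
```

Since any valid node is accepted, a fast solution should be compared against this checker by verifying the distances of its answer rather than by comparing node numbers.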

Bonus:

How do we find the 3 furthest leaves from a certain node x? We can do it with dp. Let’s keep a table dp[N][3] of pairs (dis,which), such that dp[x][0].first is the distance to the furthest leaf and dp[x][0].second is the leaf itself. dp[x][1] describes the second furthest leaf and dp[x][2] the third furthest. Take a look at the code:

Snippet 1
pair < int , int > inc(pair < int , int > p){
    return {p.first + 1 , p.second};
}
void Pdfs(int x , int p){
    int nxt , C , sz=v[x].size();

    dp[x][0] = dp[x][1] = dp[x][2] = {-1 , -1};

    if(sz == 1 && p != -1){
        dp[x][0] = {0 , x};
        return;
    }

    for(int j=0;j<sz;j++){
        nxt=v[x][j];
        if(nxt == p) continue;
        depth[nxt]=depth[x]+1;
        Pdfs(nxt , x);
        dp[x][2] = max(dp[x][2] , inc(dp[nxt][0]));
        if(dp[x][2] > dp[x][1]) swap(dp[x][2] , dp[x][1]);
        if(dp[x][1] > dp[x][0]) swap(dp[x][1] , dp[x][0]);
    }
}
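
The three-line update at the end of the loop in Snippet 1 keeps the three best candidates in sorted order. A minimal standalone sketch of that update follows; the names Entry, Top3, emptyTop3 and offer are hypothetical and only used for illustration.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <utility>

using Entry = std::pair<int, int>;  // (distance, leaf)
using Top3 = std::array<Entry, 3>;  // sorted, best[0] is the largest

Top3 emptyTop3() {
    Top3 t;
    t.fill(Entry{-1, -1});  // same sentinel as in the snippet
    return t;
}

// Offers one candidate: it replaces the third slot if it is better,
// then bubbles up with at most two swaps.
void offer(Top3& best, Entry cand) {
    best[2] = std::max(best[2], cand);
    if (best[2] > best[1]) std::swap(best[2], best[1]);
    if (best[1] > best[0]) std::swap(best[1], best[0]);
}
```

Because each child is offered exactly once, the three stored leaves always come from three different children.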

Now we have found, for each node, the 3 furthest leaves in its subtree, each reached through a different child. We also need a leaf in the direction of the parent. We can do this with 1 extra dfs.

Snippet 2
pair < int , int > inc(pair < int , int > p){
    return {p.first + 1 , p.second};
}
void dfs(int x , int p , pair < int , int > toPar){
    for(auto nxt : v[x]){
        if(nxt == p) continue;
        if(dp[nxt][0].second == dp[x][0].second) dfs(nxt , x , inc(max(dp[x][1] , toPar)));
        else dfs(nxt , x , inc(max(dp[x][0] , toPar)));
    }
    dp[x][2] = max(dp[x][2] , toPar);
    if(dp[x][2] > dp[x][1]) swap(dp[x][2] , dp[x][1]);
    if(dp[x][1] > dp[x][0]) swap(dp[x][1] , dp[x][0]);

}

In the second snippet, toPar is a pair describing the furthest leaf in the direction of the parent.

These snippets are easy to understand, and the ideas in them are fundamental.

Complexity: O((N+Q)logN)

AUTHOR’S AND TESTER’S SOLUTIONS:

Setter's Solution
#include<bits/stdc++.h>
using namespace std;
const int MX=(1<<20) , MXL=20;
vector < int > v[MX];
int n , QN , depth[MX] , par[MXL][MX];
pair < int , int > dp[MX][3];
pair < int , int > inc(pair < int , int > p){
    return {p.first + 1 , p.second};
}
void Pdfs(int x , int p){
    int nxt , C , sz=v[x].size();
 
    dp[x][0] = dp[x][1] = dp[x][2] = {-1 , -1};
 
    if(sz == 1 && p != -1){
        dp[x][0] = {0 , x};
        return;
    }
 
    for(int j=0;j<sz;j++){
        nxt=v[x][j];
        if(nxt == p) continue;
        depth[nxt]=depth[x]+1;
        par[0][nxt]=x;
        Pdfs(nxt , x);
        dp[x][2] = max(dp[x][2] , inc(dp[nxt][0]));
        if(dp[x][2] > dp[x][1]) swap(dp[x][2] , dp[x][1]);
        if(dp[x][1] > dp[x][0]) swap(dp[x][1] , dp[x][0]);
    }
}
void process(){
    for(int j=1;j<MXL;j++)
        for(int i=1;i<=n;i++)
            par[j][i]=par[j-1][par[j-1][i]];
}
int jump(int x , int K){
    int node=x;
    for(int j=0;j<MXL;j++)
        if((K&(1<<j)))
            node=par[j][node];
    return node;
}
int LCA(int x , int y){
    if(depth[x] < depth[y]) swap(x , y);
    x = jump(x , depth[x] - depth[y]);
    if(x == y) return x;
    for(int j = MXL - 1 ; j >= 0 ; j--)
        if(par[j][x] != par[j][y])
            x = par[j][x] , y = par[j][y];
    return par[0][x];
}
int Kth(int x , int y , int K){
    int lca = LCA(x , y);
    int a = depth[x] - depth[lca] + 1;
    if(a >= K)
        return jump(x , K - 1);
    K -= a;
    return jump(y , depth[y] - depth[lca] - K);
}
int DIS(int x , int y){
    return depth[x] + depth[y] - 2 * depth[LCA(x , y)];
}
void dfs(int x , int p , pair < int , int > toPar){
    for(auto nxt : v[x]){
        if(nxt == p) continue;
        if(dp[nxt][0].second == dp[x][0].second) dfs(nxt , x , inc(max(dp[x][1] , toPar)));
        else dfs(nxt , x , inc(max(dp[x][0] , toPar)));
    }
    dp[x][2] = max(dp[x][2] , toPar);
    if(dp[x][2] > dp[x][1]) swap(dp[x][2] , dp[x][1]);
    if(dp[x][1] > dp[x][0]) swap(dp[x][1] , dp[x][0]);
 
}
int main(){
 
    int T;
    cin>>T;
 
    int ln = 0;
 
    while(T--){
        scanf("%d %d",&n,&QN);
        for(int j = 1 ; j <= n ; j++){
            v[j].clear();
            for(int i = 0 ; i < MXL ; i++)
                par[i][j] = 0;
        }
        for(int j=1;j<n;j++){
            int a , b;
            scanf("%d %d",&a,&b);
            v[a].push_back(b);
            v[b].push_back(a);
        }
        depth[1] = 1;
        Pdfs(1 , -1);
        dfs(1 , -1 , {0 , 1});
        process();
 
        for(int qidx = 1 ; qidx <= QN ; qidx++){
            int x , dx , y , dy , qq , ww;
            scanf("%d %d %d %d",&x,&dx,&y,&dy); qq = dx , ww = dy;
            int lca = LCA(x , y);
            int dis = DIS(x , y);
            if(dis % 2 != (dx + dy)%2 ){
                puts("-1");
                continue;
            }
 
            int center = Kth(x , y , (dis + dx - dy)/2 + 1);
            int dcx = DIS(x , center) , dcy = DIS(y , center);
 
            dx -= dcx , dy -= dcy;
 
            //cout<<dx<<' '<<dy<<endl;
            if(dx != dy || dx < 0){
                puts("-1");
                continue;
            }
            int ans = -1;
 
            if(dx == 0){
                ans = center;
            }
 
            for(int k = 0 ; k < 3 && ans == -1; k++){
                int other = dp[center][k].second , len = dp[center][k].first;
                if(other == -1 || len < dx) continue;
                int tmp = Kth(other , center , len - dx + 1);
                if(DIS(tmp,x) == qq && DIS(tmp , y) == ww)
                    ans = tmp;
 
            }
            printf("%d\n",ans);
 
        }
 
    }
} 

Can we apply BFS to nodes a and b instead?

You can only get partial with that.

Why is this wrong?

class Graph:
    
    def __init__(self):
        self.d={}
    def edge(self,x,y):
        xo=self.d
        if(x not in xo):
            xo[x]=[y]
        else:
            xo[x].append(y)
        if(y not in xo):
            xo[y]=[x]
        else:
            xo[y].append(x)
    
    def dist(self,x):
        q=[[x,0]]
        visited=[0 for i in range(n)]
        visited[x-1]=1
        c=[]
        
        while(len(q)>0):
            
            a=q.pop(0)
            c.append(a)
            
            
            for v in self.d[a[0]]:
                if(visited[v-1]==0):
                    q.append([v,a[1]+1])
                    visited[v-1]=1
        return c
#This was my code

Please write the editorial in a bit more detail. Editorials like this are very difficult to understand and just don’t help at all.