REC15E - Editorial

PROBLEM LINK:

Practice
Contest

Setter: Hritesh Mourya
Tester: Chilukuri Sri Harsha
Editorialist: Hritesh Mourya

DIFFICULTY:

MEDIUM-HARD

PREREQUISITES:

Tree, Dynamic Programming

PROBLEM:

We have a tree with N nodes where any one node can be blocked. In each query we need to find the maximum distance between any two nodes without traversing the blocked node.

QUICK EXPLANATION:

When a node with K adjacent nodes is blocked, K sub-trees are formed. We then need the number of nodes on the longest path that lies inside any one of these sub-trees. The naïve approach is to perform a depth first search (DFS) in each sub-tree, calculate its longest path length and take the maximum over all of them. The time complexity will be O(Q × N), or O(N^{2} + Q) if the answer for every node is precomputed, for each testcase. Let us now look into a more efficient approach using Dynamic Programming.

EXPLANATION:

Refer to the solution for better understanding.

For any node the longest path can
(i) be present in its subtree.
(ii) include the node.
We will find the maximum of (i) and (ii) using DFS and DP.
We can use 2 arrays: mxLen[i] to store the maximum depth (counted in nodes) reachable downwards from the i^{th} node, and ddp[i] to store the maximum of (i) and (ii) for that node.

Calculating the longest path length in all child subtree for any node
mxInside = 0;
Iterating over all the children 
	mxInside = max(ddp[child], mxInside);

Now we have calculated (i) for a node.
In (ii) if the node is included then the path length is cross = 1 + mx1 + mx2, where mx1 and mx2 are the maximum depth and 2nd maximum depth from the given node.

Then we have the maximum (i) and (ii) as
ddp[i] = max(mxInside, cross)

The longest path inside every child subtree is now calculated. But when a node is blocked, the longest path may also lie in the part of the tree on its parent's side, so we need to compute that too.

For any node X with Z child nodes C_{0} … C_{Z-1}, if child node C_{k} is blocked then the longest path may lie entirely inside any of the other Z-1 child sub-trees. Given that the k^{th} child node is blocked, let inMx be the maximum path length over the sub-trees of C_{0} to C_{k-1} and C_{k+1} to C_{Z-1}. We can find it efficiently by using prefix and suffix maximum vectors of ddp[node] over all child nodes.

    inMx = max(prefix[k-1], suffix[k+1]); 

For any node X with Z child nodes, if child node C_{k} is blocked then the longest path may also pass through X itself, joining the two deepest chains that hang off X among the sub-trees of C_{0} to C_{k-1} and C_{k+1} to C_{Z-1}. This can be achieved by keeping a prefix and a suffix structure (since we need the 2 largest values we can use a vector of multisets or pairs) over mxLen[node] (since it stores the max depth for any node) for all child nodes. In this DFS we also pass the longest chain coming from the parent's side down to the child node (upperLong), which is considered along with the sibling subtrees.
Given that the k^{th} child is blocked, mx1 and mx2 are the 2 longest chains among all its sibling subtrees and the parent's side (i.e. upperLong). So the longest path through X can be

	cross = 1 + mx1 + mx2;

For any {parent, child} pair store[par][C_{i}] denotes the longest path length when C_{i} child is blocked and par is its corresponding parent.
store[par][C_{i}] is the maximum of all the paths we calculated above and the longest path for the parent when the given child node is blocked. (Note that we already calculated the longest path for the parent in this DFS itself.)

	store[par][C[i]] = max(inMx, cross, longest_path(par)); 

Finally we can compute the longest path length in the Tree using a DFS and memoize it in ans[node] to answer each query in constant time where ans[node] denotes longest path length when node is blocked.

	ans[node] = max(ddp[child] for all children of the given node, store[par][node])

TIME COMPLEXITY:

O(N+Q) for each testcase.

SOLUTIONS:

Setter's Solution
#include <bits/stdc++.h>
#define endl '\n'
#define flash ios_base::sync_with_stdio(false); cin.tie(NULL)
using namespace std;
 
// Capacity of every per-node table below; node indices are used directly as subscripts.
const int mxN = 1e5 + 5;
// Number of nodes of the current test case (read in main).
int n;
 
// store[p][c] is written by dfsUp for each (parent p, child c) pair and read back by dfs as store[par][curr].
vector<unordered_map<int,int>> store(mxN);
// dia[v]   : value printed for a query on node v (filled by dfs).
// ddp[v]   : max of "best ddp among v's children" and "1 + two largest child mxLen" (filled by dfsDown).
// dpUp[v]  : scratch value dfsUp recomputes for each child of v just before recursing into that child.
// mxLen[v] : 1 + largest mxLen among v's children, i.e. 1 for a leaf (filled by dfsDown).
int dia[mxN],ddp[mxN],dpUp[mxN],mxLen[mxN];
// Undirected adjacency lists; main pushes each edge in both directions.
vector<int> adj[mxN];
 
// Post-order pass over the tree rooted at the initial call.
// For every node it fills
//   mxLen[curr] = number of nodes on the longest downward chain starting at curr, and
//   ddp[curr]   = longest path (in nodes) lying entirely inside curr's subtree.
// curr: node being processed; par: the neighbour we arrived from (skipped while iterating).
void dfsDown(int curr,int par)
{
    int deepest = 0;     // largest mxLen among the children seen so far
    int runnerUp = 0;    // second largest (may equal deepest)
    int bestBelow = 0;   // best ddp among the children, i.e. a path that avoids curr

    for (const int child : adj[curr]) {
        if (child == par) {
            continue;
        }
        dfsDown(child, curr);

        const int depth = mxLen[child];
        if (depth > deepest) {
            runnerUp = deepest;
            deepest = depth;
        } else if (depth > runnerUp) {
            runnerUp = depth;
        }
        bestBelow = std::max(bestBelow, ddp[child]);
    }

    // A path through curr joins the two deepest child chains.
    const int through = 1 + runnerUp + deepest;
    ddp[curr] = std::max(bestBelow, through);
    mxLen[curr] = 1 + deepest;
}
 
// Returns the (at most) two largest values of st1 ∪ st2 as a multiset.
// Duplicates are kept, so the result may hold the same value twice.
// The arguments are taken by const reference: the sets are only read, and the
// by-value signature used to copy both of them on every call.
std::multiset<int> merge(const std::multiset<int>& st1,const std::multiset<int>& st2)
{
    std::multiset<int> res(st1);
    res.insert(st2.begin(), st2.end());
    // Drop the smallest elements until only the top two remain.
    while(res.size() > 2){
        res.erase(res.begin());
    }
    return res;
}
 
void dfsUp(int curr,int par,int upperLong){
    vector<int> childs{0},inMxpref{0},inMxsuff{0};
    vector<multiset<int>> joinPref(1),joinSuff(1);
    joinPref[0].insert(0);
    joinSuff[0].insert(0);
    
    multiset<int> st;
    for (int& it:adj[curr]){
        if (it == par) continue;
        childs.push_back(it);
        inMxpref.push_back(ddp[it]);inMxsuff.push_back(ddp[it]);
        st.clear();
        st.insert(mxLen[it]);
        joinPref.push_back(st);joinSuff.push_back(st);
    }
    int sz = childs.size();
    sz--;
    inMxpref.push_back(0);inMxsuff.push_back(0);
    st.clear();st.insert(0);
    joinPref.push_back(st);joinSuff.push_back(st);
    
    for (int i = 1;i <= sz;i++){
        inMxpref[i] = max(inMxpref[i-1],inMxpref[i]);
        joinPref[i] = merge(joinPref[i-1],joinPref[i]);
        
        inMxsuff[sz - i + 1] = max(inMxsuff[sz - i + 2],inMxsuff[sz - i + 1]);
        joinSuff[sz - i + 1] = merge(joinSuff[sz - i + 2],joinSuff[sz - i + 1]);
    }
    
    for (int inMx,lMX,i = 1;i <= sz;i++){
        inMx = max(inMxpref[i - 1],inMxsuff[i + 1]);
        
        st.clear();
        st.insert(upperLong);
        st = merge(st,merge(joinPref[i - 1],joinSuff[i + 1]));
        
        dpUp[curr] = max({dpUp[par] , inMx ,1 + *st.begin() + *st.rbegin()});
        store[curr][childs[i]] = dpUp[curr];
        dfsUp(childs[i],curr,1 + *st.rbegin());
    }
}
 
// Final pass: dia[curr] becomes the larger of store[par][curr] and the best ddp
// among curr's children. The store lookup uses operator[], so a missing entry
// (e.g. for the root call with par == 0) is created and reads as 0.
void dfs(int curr,int par)
{
    int best = store[par][curr];
    for (const int child : adj[curr]) {
        if (child != par) {
            dfs(child, curr);
            best = std::max(best, ddp[child]);
        }
    }
    dia[curr] = best;
}
 
// Driver: for each test case reads the tree, runs the three passes from node 1,
// then prints dia[x] for every queried node x, one value per line.
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int i,u,v,t,q;
    cin >> t;
    while (t--) {
        cin >> n;

        // Reset the per-node containers that the previous test case filled.
        i = 0;
        while (i <= n) {
            adj[i].clear();
            store[i].clear();
            ++i;
        }

        // n - 1 undirected edges.
        i = 0;
        while (i < n - 1) {
            cin >> u >> v;
            adj[u].push_back(v);
            adj[v].push_back(u);
            ++i;
        }

        dfsDown(1, 0);
        dfsUp(1, 0, 0);
        dfs(1, 0);

        cin >> q;
        for (int asked = 0; asked < q; ++asked) {
            int nod;
            cin >> nod;
            cout << dia[nod] << '\n';
        }
    }
    return 0;
}
Tester's Solution
#include<bits/stdc++.h>
using namespace std;
// ASS(a,b,c): asserts a <= b <= c. The expansion already ends with ';', so a call
// site written as `ASS(...);` produces an extra empty statement.
#define ASS(a,b,c) assert(a<=b and b<=c);
// Answers "longest path (counted in nodes) once node x is removed" for a tree
// with nodes 1..n. Usage: construct with n, then in(), process(), out().
//
// Idea: take one diameter (mainPath). Removing a node off the diameter leaves
// the diameter intact. Removing the i-th diameter node leaves (a) the part left
// of it, (b) the part right of it and (c) the subtrees hanging off it; the
// answer is the best path found in any of those.
//
// Storage is held in std::vector (no manual new/delete, safe to copy), and BFS
// uses a stamp-based visited table so a call costs only the nodes it touches.
class Tree{
private:
    int n;                                  // number of nodes
    std::vector<std::vector<int>> edge;     // adjacency lists, index 0 unused
    std::vector<int> sol;                   // sol[v]: answer when v is blocked
    std::vector<int> mainPath;              // nodes of one diameter, end to end
    std::vector<int> longestSubPath;        // [i]: levels reachable from mainPath[i] without using the path
    std::vector<int> longestSubMax;         // [i]: best path inside a subtree hanging off mainPath[i]
    std::vector<int> leftLongest;           // [i]: best path using only path nodes 0..i and their hangers
    std::vector<int> rightLongest;          // [i]: same for path nodes i..end
    std::vector<int> visitStamp;            // visitStamp[v] == stamp  <=>  v seen in the current BFS
    int stamp;

    struct BfsResult{
        int lastNode;   // last node dequeued, i.e. a node on the deepest level
        int levels;     // number of BFS levels, i.e. nodes on the longest chain from the start
    };

    // Level-order traversal from `node` that never enters blockNode1/blockNode2.
    // 0 is not a valid node label, so passing 0 means "nothing blocked".
    BfsResult BFS(int node,int blockNode1=0,int blockNode2=0){
        ++stamp;
        std::queue<int> nodes;
        nodes.push(node);
        visitStamp[node]=stamp;
        BfsResult res{node,0};
        while(!nodes.empty()){
            ++res.levels;
            for(auto remaining=nodes.size();remaining>0;--remaining){
                res.lastNode=nodes.front();
                nodes.pop();
                for(int next:edge[res.lastNode]){
                    if(visitStamp[next]==stamp || next==blockNode1 || next==blockNode2)
                        continue;
                    visitStamp[next]=stamp;
                    nodes.push(next);
                }
            }
        }
        return res;
    }
    int findLargestNode(int node,int blockNode1=0,int blockNode2=0){
        return BFS(node,blockNode1,blockNode2).lastNode;
    }
    int findLargestLength(int node,int blockNode1=0,int blockNode2=0){
        return BFS(node,blockNode1,blockNode2).levels;
    }
    // Stores the unique x -> y path in mainPath (left empty if y is unreachable).
    // Iterative on purpose: the path can be as long as n, so recursion would be that deep.
    void generatePath(int x,int y){
        mainPath.clear();
        std::vector<int> parent(n+1,0);     // 0 == not discovered yet
        std::vector<int> pending(1,x);
        parent[x]=x;
        while(!pending.empty()){
            const int cur=pending.back();
            pending.pop_back();
            if(cur==y)
                break;
            for(int next:edge[cur]){
                if(parent[next]!=0)
                    continue;
                parent[next]=cur;
                pending.push_back(next);
            }
        }
        if(parent[y]==0)
            return;
        for(int cur=y;cur!=x;cur=parent[cur])
            mainPath.push_back(cur);
        mainPath.push_back(x);
        std::reverse(mainPath.begin(),mainPath.end());
    }
    // Longest path inside any subtree hanging off `node`, ignoring the two
    // blocked neighbours (the adjacent diameter nodes). Each hanging subtree's
    // longest path is found with the usual double BFS, with `node` blocked.
    int findLongestSubMax(int node,int blockNode1,int blockNode2){
        int ans=0;
        for(int x:edge[node]){
            if(blockNode1==x || blockNode2==x)
                continue;
            const int farEnd=findLargestNode(x,node,0);
            ans=std::max(ans,findLargestLength(farEnd,node,0));
        }
        return ans;
    }
public:
    Tree(int n)
        : n(n),
          edge(n>0?n+1:1),
          sol(n>0?n+1:1,0),
          visitStamp(n>0?n+1:1,0),
          stamp(0){}

    // Precomputes sol[v] for every node v. Call after in().
    void process(){
        mainPath.clear();
        longestSubPath.clear();
        longestSubMax.clear();
        leftLongest.clear();
        rightLongest.clear();
        if(n<1)
            return;

        const int end1=findLargestNode(1);
        const int end2=findLargestNode(end1);
        generatePath(end1,end2);
        const int longestPathLength=static_cast<int>(mainPath.size());

        // Blocking a node off the diameter leaves the diameter untouched.
        sol.assign(n+1,longestPathLength);

        for(int i=0;i<longestPathLength;i++){
            const int prev=(i==0)?0:mainPath[i-1];
            const int next=(i==longestPathLength-1)?0:mainPath[i+1];
            longestSubPath.push_back(findLargestLength(mainPath[i],prev,next));
            longestSubMax.push_back(findLongestSubMax(mainPath[i],prev,next));
            // i path nodes precede mainPath[i]; extend them with its deepest hanging chain.
            leftLongest.push_back(std::max(longestSubMax[i],longestSubPath[i]+i));
            if(i!=0)
                leftLongest[i]=std::max(leftLongest[i-1],leftLongest[i]);
            rightLongest.push_back(0);
        }
        for(int i=longestPathLength-1;i>=0;i--){
            rightLongest[i]=std::max(
                longestSubMax[i],
                longestSubPath[i]+longestPathLength-i-1
            );
            if(i!=longestPathLength-1)
                rightLongest[i]=std::max(rightLongest[i+1],rightLongest[i]);
        }
        for(int i=0;i<longestPathLength;i++){
            int best=longestSubMax[i];
            if(i!=longestPathLength-1)
                best=std::max(best,rightLongest[i+1]);
            if(i!=0)
                best=std::max(best,leftLongest[i-1]);
            sol[mainPath[i]]=best;
        }
    }
    // Reads the n-1 edges from std::cin, asserting that each is a valid non-loop edge.
    void in(){
        for(int i=1;i<n;i++){
            // Zero-initialised so a failed read trips the range assert instead of
            // using an indeterminate value.
            int a=0,b=0;
            std::cin>>a>>b;
            assert(1<=a && a<=n);
            assert(1<=b && b<=n);
            assert(a!=b);
            edge[a].push_back(b);
            edge[b].push_back(a);
        }
    }
    // Reads the query count and the queries from std::cin; prints sol[x] per query.
    void out(){
        int q=0,x=0;
        std::cin>>q;
        // NOTE(review): the original check here was ASS(1,1,n), which only
        // verifies n >= 1; it presumably meant to bound q — confirm the limit.
        assert(1<=n);
        while(q--){
            std::cin>>x;
            assert(1<=x && x<=n);
            std::cout<<sol[x]<<"\n";
        }
    }
};
 
// Driver for the tester's solution: validates the number of test cases and each
// n with ASS, then builds a Tree per test case and runs in / process / out.
int main(){
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(NULL);
    std::cout.tie(NULL);

    int t,n;
    std::cin>>t;
    ASS(1,t,10);

    for(;t--;){
        std::cin>>n;
        ASS(2,n,5e4);

        Tree T(n);
        T.in();
        T.process();
        T.out();
    }
}