CP_04 Editorial

Problem Link

Click here

Difficulty

Medium

Solution

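Root the tree at node 1. A path is "alternating" when its edges alternate between counted and not counted; its value is the sum of the weights of the counted edges. Let dp[i][1] be the best value of an alternating path that starts at node i and goes down into i's subtree with its first edge (from i to a child) counted, and let dp[i][0] be the best value of such a path whose first edge is not counted — equivalently, the best path that starts at some child c of i and goes down into c's subtree with its first edge counted. In formulas: $dp[i][1] = \max_c \left(dp[c][0] + w_{i,c}\right)$ and $dp[i][0] = \max_c dp[c][1]$, where c ranges over the children of i.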
To find the answer, treat every node i as the topmost node of a path and take the maximum of $dp[j][1] + dp[k][0] + w_{i,k}$ over pairs of two different children j and k of i, where $w_{i,k}$ is the weight of the edge between i and k (the edge i–j is not counted, the edge i–k is). Paths that have i as an endpoint and go straight down from it are covered by also taking dp[i][1] for every node i.

Code

#include <bits/stdc++.h>
using namespace std;
using ll = long long;
// adj[u] holds (neighbour, edge weight) pairs; nodes are 1-indexed.
std::vector<std::vector<std::pair<int, ll>>> adj;
// dp[u][1]: best value of a downward alternating path from u whose first edge (u -> child) counts.
// dp[u][0]: best value of a downward alternating path from u whose first edge does not count.
// Entries of leaves are never written and keep whatever the caller initialised them to.
std::vector<std::vector<ll>> dp;
// Best path value found so far for the current test case.
ll ans = 0;

// Fills dp[] for every node of the subtree rooted at u (the tree is oriented
// away from p) and raises ans to the best alternating path whose topmost node
// lies in that subtree.
//
// Runs iteratively (explicit stack + reverse preorder) instead of recursing,
// so path-shaped trees with hundreds of thousands of nodes cannot overflow the
// call stack.
//
// @param u  root of the subtree to process
// @param p  parent of u; only edges leading back to p are skipped, so the
//           caller passes a non-node (0) for the root
void dfs(int u, int p)
{
    // parent[x] = node x was reached from, used to skip the edge back up.
    std::vector<int> parent(adj.size(), 0);
    std::vector<int> order;          // preorder: each node appears after its parent
    order.reserve(adj.size());
    std::vector<int> pending{u};
    parent[u] = p;
    while (!pending.empty())
    {
        const int x = pending.back();
        pending.pop_back();
        order.push_back(x);
        for (const auto& e : adj[x])
        {
            const int y = e.first;
            if (y == parent[x]) continue;
            parent[y] = x;
            pending.push_back(y);
        }
    }

    // Reverse preorder guarantees every child is finished before its parent.
    for (auto it = order.rbegin(); it != order.rend(); ++it)
    {
        const int x = *it;
        bool seen = false;  // has some child of x already been folded in?
        ll best1 = 0;       // max over earlier children c of dp[c][0] + w(x,c)
        ll best0 = 0;       // max over earlier children c of dp[c][1]
        for (const auto& e : adj[x])
        {
            const int c = e.first;
            if (c == parent[x]) continue;
            const ll withEdge = dp[c][0] + e.second;  // edge x-c counted
            const ll withoutEdge = dp[c][1];          // edge x-c not counted
            if (seen)
            {
                // Join c with a different, earlier child through x: exactly one
                // of the two edges meeting at x is counted. Pairing each child
                // with all earlier ones covers every unordered pair once.
                ans = std::max(ans, std::max(withoutEdge + best1, withEdge + best0));
                best1 = std::max(best1, withEdge);
                best0 = std::max(best0, withoutEdge);
            }
            else
            {
                best1 = withEdge;
                best0 = withoutEdge;
                seen = true;
            }
        }
        if (seen)
        {
            dp[x][1] = best1;
            dp[x][0] = best0;
            // Paths that start at x and only go down.
            ans = std::max(ans, std::max(best1, best0));
        }
    }
}
void solve()
{
    int n;
    cin>>n;
    adj = vector<vector<pair<int,ll> > >(n+1);
    dp = vector<vector<ll> >(n+1,vector<ll>(2));
    for(int i=0;i<(n-1);i++)
    {
        int u,v,w;
        cin>>u>>v>>w;
        adj[u].push_back({v,w});
        adj[v].push_back({u,w});
    }
    dfs(1,0);
    cout<<ans<<endl;
    ans=0;
}
// Entry point: reads the number of test cases and solves each one in turn.
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    int tests = 0;  // stays 0 (nothing to solve) if the count cannot be read
    cin >> tests;
    for (int t = 0; t < tests; t++)
    {
        solve();
    }
    return 0;
}