YANGRY - Editorial

PROBLEM LINK:

Contest Link

Can Yamato test

Author: Nihal Mittal

DIFFICULTY:

Medium.

PREREQUISITES:

DFS, GRAPH

PROBLEM:

Given an undirected tree with n nodes and two nodes x and y, count the ordered pairs (u, v) with u != v such that, on the path from u to v, node y is never visited after node x.

QUICK OVERVIEW:

There are n * (n-1) ordered pairs (u, v) with u != v, and each pair defines a unique path in the tree. Count the paths on which node y comes after node x and subtract that count from the total to get the answer.
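The complement idea above can be checked on small trees with a brute-force counter. The sketch below is not part of the setter's solution: the function name `countGoodPathsBrute` and the edge-list parameter are hypothetical, nodes are assumed to be numbered 1..n, and x != y is assumed. It runs in O(n^2), so it is only suitable for validating small cases.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical checker: counts ordered pairs (u, v), u != v, whose path
// from u to v does NOT visit x and then, later, y. O(n^2) time.
long long countGoodPathsBrute(int n, int x, int y,
                              const std::vector<std::pair<int, int>>& edges) {
    assert(x != y);  // assumption: x and y are different nodes
    std::vector<std::vector<int>> adj(n + 1);
    for (const auto& e : edges) {
        adj[e.first].push_back(e.second);
        adj[e.second].push_back(e.first);
    }

    long long good = 0;
    for (int u = 1; u <= n; ++u) {
        // parent[w] = previous node on the path from u to w (0 = unvisited).
        std::vector<int> parent(n + 1, 0);
        std::vector<int> stack = {u};
        parent[u] = u;  // sentinel: the start node is its own parent
        while (!stack.empty()) {
            int w = stack.back();
            stack.pop_back();
            for (int nb : adj[w]) {
                if (parent[nb] == 0) {
                    parent[nb] = w;
                    stack.push_back(nb);
                }
            }
        }

        for (int v = 1; v <= n; ++v) {
            if (v == u) continue;
            // Walk backwards from v to u. Meeting y first and x later in this
            // backward walk means x precedes y on the forward path u -> v.
            bool seenY = false, bad = false;
            for (int w = v;; w = parent[w]) {
                if (w == y) seenY = true;
                else if (w == x && seenY) bad = true;
                if (w == u) break;
            }
            if (!bad) ++good;
        }
    }
    return good;
}
```

For the chain 1 - 2 - 3 with x = 1 and y = 3, only the path from 1 to 3 is excluded, so `countGoodPathsBrute(3, 1, 3, {{1, 2}, {2, 3}})` returns 5.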

EXPLANATION:

Root the tree at node y and run a depth-first search from it to compute sub[v], the number of nodes in the subtree of each node v. Let z be the direct child of y that is an ancestor of x (z may be x itself). A path from u to v visits x and then y exactly when u lies in the subtree of x and v lies in the subtree of y but outside the subtree of z. There are sub[x] choices for u and sub[y] - sub[z] choices for v, so the number of paths to exclude is the product of these two counts.

The answer is therefore n * (n-1) - (sub[y] - sub[z]) * sub[x]. Here sub is the array of subtree sizes computed by the DFS rooted at y, and z is the direct child of y whose subtree contains x.
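As a rough illustration of the formula, here is a sketch that computes it with an iterative DFS instead of recursion, which may help on very deep trees. It is an alternative arrangement, not the setter's code: the function name `countGoodPaths` and the edge-list parameter are hypothetical, nodes are assumed to be numbered 1..n, and x != y is assumed.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical helper: evaluates n*(n-1) - sub[x]*(sub[y] - sub[z]) with the
// tree rooted at y, where z is the child of y whose subtree contains x.
long long countGoodPaths(int n, int x, int y,
                         const std::vector<std::pair<int, int>>& edges) {
    assert(x != y);  // assumption: x and y are different nodes
    std::vector<std::vector<int>> adj(n + 1);
    for (const auto& e : edges) {
        adj[e.first].push_back(e.second);
        adj[e.second].push_back(e.first);
    }

    // Iterative DFS from y: record each node's parent and the visit order.
    std::vector<int> parent(n + 1, 0), order;
    order.reserve(n);
    std::vector<int> stack = {y};
    parent[y] = y;  // sentinel: the root is its own parent
    while (!stack.empty()) {
        int w = stack.back();
        stack.pop_back();
        order.push_back(w);
        for (int nb : adj[w]) {
            if (parent[nb] == 0) {
                parent[nb] = w;
                stack.push_back(nb);
            }
        }
    }

    // Every node appears after its parent in `order`, so a reverse sweep
    // adds each finished subtree size into its parent. order[0] is the root.
    std::vector<long long> sub(n + 1, 1);
    for (int i = n - 1; i > 0; --i) {
        sub[parent[order[i]]] += sub[order[i]];
    }

    // Climb from x until reaching the node directly below y: that node is z.
    int z = x;
    while (parent[z] != y) z = parent[z];

    long long total = 1LL * n * (n - 1);
    return total - sub[x] * (sub[y] - sub[z]);
}
```

For the star with edges 1 - 2 and 1 - 3, x = 1 and y = 3, rooting at 3 gives sub[x] = 2, z = 1 and sub[y] - sub[z] = 3 - 2 = 1, so `countGoodPaths(3, 1, 3, {{1, 2}, {1, 3}})` returns 3 * 2 - 2 = 4.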

SOLUTIONS:

Setter's Solution
#include<bits/stdc++.h>
#define ll long long
#define ld long double
#define mk make_pair
#define pb push_back
#define in insert
#define se second
#define fi first
#define mod 1000000007
#define watch(x) cout << (#x) << " is " << (x) << "\n"
#define all(v) (v).begin(),(v).end()
#define fast ios_base::sync_with_stdio(false); cin.tie(NULL); cout.tie(0);
#define printclock cerr<<"Time : "<<1000*(ld)clock()/(ld)CLOCKS_PER_SEC<<"ms\n";
#define pii pair<ll,ll>
#define vi(x) vector<x>
#define maxn 100005
using namespace std;
vector<ll> v[300010];
ll sub[300010];
ll x,y,z;
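// Computes sub[so], the size of the subtree of so when the tree is rooted at y.
// child is the child of y whose subtree contains so (-1 while so is y itself).
void dfs(ll so, ll par,ll child) {
	sub[so] = 1;
	if(so == x) z = child; // z: child of y on the path from y to x
	for(auto u: v[so]){
		if(u != par) {
			if(so == y) dfs(u, so,u); // entering a new branch below y
			else dfs(u,so,child);
			sub[so] += sub[u];
		}
	}
}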
signed main()
{
    fast;
    ll n,i,j;
    cin>>n>>x>>y;
    for(i=0;i<n-1;i++) {
    	ll x1,y1;
    	cin>>x1>>y1;
    	v[x1].pb(y1);
    	v[y1].pb(x1);
    }
    dfs(y,-1,-1); // root the tree at y
    // all ordered pairs minus those whose path visits x and then y
    ll ans = n * (n-1) - (sub[x]) * (sub[y] - sub[z]);
    cout<<ans;
}