NCOMMONPATHS - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: progokcoe
Tester: jay_1048576
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

DFS

PROBLEM:

You’re given a tree on N vertices.
Find the number of ways to add two edges to this tree so that the resulting graph has exactly two simple cycles, and these cycles don’t intersect.

EXPLANATION:

We’ll first count ordered ways of adding two edges (i.e., adding e_1 followed by e_2 is considered different from adding e_2 followed by e_1), and divide by 2 at the end.

Consider two different vertices u and v of the tree.
Suppose the first edge we add is between them. Let’s think about what choices we have for the second edge.

One nice way to visualize the current graph is as a single cycle (the tree path from u to v, closed by the new edge) with a rooted tree hanging off each of its vertices.
Let v_1, v_2, \ldots, v_k be the vertices forming the cycle, and let T_i be the rooted tree hanging off v_i.
When adding the second edge,

  • Clearly, the endpoints of the second edge can’t be any of the cycle vertices; otherwise there’d be an intersection between the formed cycles.
  • Both endpoints must lie within the same T_i; if they don’t, the new cycle created by adding an edge between them will intersect the existing cycle.
  • Even within the same T_i, there’s one more constraint: the path between them shouldn’t contain v_i (the root of T_i), since v_i already lies on the first cycle.
    • A simple way to state the previous two conditions is that both endpoints should lie in the subtree of the same child of v_i (looking only at T_i).

With this in mind, let’s count the number of options we have for the second edge.
If the subtree of some child of v_i has size s, we can choose any two distinct vertices from it, which can be done in \frac{s\cdot (s-1)}{2} ways.

So, with the first edge (u, v) fixed, the number of choices of second edge is \sum_s \frac{s\cdot (s-1)}{2}, where the summation is taken across all subtree sizes of children of vertices on the cycle.
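For small trees, this rule can be checked directly. The sketch below (hypothetical helpers `tree_path` and `second_edge_choices`, 0-indexed vertices, not part of the official solutions) removes the cycle vertices, i.e. the tree path from u to v, and sums \frac{s\cdot (s-1)}{2} over the sizes s of the remaining connected components; those components are exactly the subtrees hanging off children of the cycle vertices.

```python
from collections import deque


def tree_path(adj, u, v):
    """Vertices on the unique u-v path of a tree: BFS from u, then walk back from v."""
    parent = {u: None}
    queue = deque([u])
    while queue:
        x = queue.popleft()
        for y in adj[x]:
            if y not in parent:
                parent[y] = x
                queue.append(y)
    path = [v]
    while path[-1] != u:
        path.append(parent[path[-1]])
    return path


def second_edge_choices(n, edges, u, v):
    """Number of valid second edges once the edge (u, v) has been added."""
    adj = [[] for _ in range(n)]
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)
    seen = set(tree_path(adj, u, v))  # cycle vertices can't be used at all
    total = 0
    for start in range(n):
        if start in seen:
            continue
        # Flood-fill one component of the tree with the cycle vertices removed.
        size, stack = 0, [start]
        seen.add(start)
        while stack:
            x = stack.pop()
            size += 1
            for y in adj[x]:
                if y not in seen:
                    seen.add(y)
                    stack.append(y)
        total += size * (size - 1) // 2  # any two vertices of this piece
    return total


edges = [(0, 1), (1, 2), (1, 3), (3, 4), (3, 5), (0, 6), (6, 7)]
print(second_edge_choices(8, edges, 2, 4))  # cycle 2-1-3-4, pieces {0, 6, 7} and {5}: prints 3
```

This is quadratic or worse if repeated for every first edge, which is why the rest of the editorial reorganizes the sum.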


Now, we need to speed this algorithm up.
Let’s root the tree at vertex 1 and look a bit more closely at what happens in the actual tree when the edge (u, v) is added.

Specifically, it can be seen that the ‘subtrees’ we’re looking for are exactly:

  • The entire subtree of some vertex x such that x doesn’t lie on the path from u to v, but its parent does; or
  • All the vertices “above” L = \text{lca}(u, v).
    That is, every vertex other than ones in the subtree of L.
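This characterization needs only subtree sizes. The sketch below assumes vertex 0 as the root (the text uses vertex 1 with 1-indexing) and uses hypothetical helpers `rooted_info` and `hanging_sizes`: it lifts the deeper of u and v until they meet at the LCA, marking the path, then collects the subtree sizes of off-path vertices whose parent is on the path, plus N minus the subtree size of the LCA for the part above it.

```python
def rooted_info(n, edges, root=0):
    """Parent, depth and subtree size of every vertex, rooting the tree at `root`."""
    adj = [[] for _ in range(n)]
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)
    parent, depth = [-1] * n, [0] * n
    order, seen = [root], [False] * n
    seen[root] = True
    i = 0
    while i < len(order):  # BFS; `order` lists parents before children
        x = order[i]
        i += 1
        for y in adj[x]:
            if not seen[y]:
                seen[y] = True
                parent[y], depth[y] = x, depth[x] + 1
                order.append(y)
    subsz = [1] * n
    for x in reversed(order):  # children are finished before their parent
        if parent[x] != -1:
            subsz[parent[x]] += subsz[x]
    return parent, depth, subsz


def hanging_sizes(n, edges, u, v, root=0):
    """Sizes of the pieces a second edge may use after adding the edge (u, v)."""
    parent, depth, subsz = rooted_info(n, edges, root)
    path, a, b = set(), u, v
    while a != b:  # always lift the deeper endpoint
        if depth[a] < depth[b]:
            a, b = b, a
        path.add(a)
        a = parent[a]
    lca = a
    path.add(lca)
    # Off-path vertices whose parent is on the path (parent[root] == -1 never is).
    sizes = [subsz[x] for x in range(n) if x not in path and parent[x] in path]
    if n - subsz[lca] > 0:
        sizes.append(n - subsz[lca])  # everything outside the subtree of the LCA
    return sorted(sizes)


edges = [(0, 1), (1, 2), (1, 3), (3, 4), (3, 5), (0, 6), (6, 7)]
print(hanging_sizes(8, edges, 2, 4))  # [1, 3]
print(hanging_sizes(8, edges, 7, 5))  # [1, 1]
```

For the first query the LCA is 1, so the part above it is {0, 6, 7}; for the second the LCA is the root, so there is no part above.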

With this information in hand, let’s look at things from the opposite direction.
Suppose you fix a vertex x other than the root.
For how many paths (u, v) will the subtree of x be one of the ones we’re looking for?
Well, as we noted above, such a path must pass through the parent of x, but not through x itself.

To compute this number, we can first find the total number of paths that pass through the parent of x, and from this subtract the number of paths that pass through both x and its parent.

  • To find the number of paths passing through a vertex, you can start from \binom{N}{2} and subtract the paths that don’t pass through it.
    These are exactly the paths lying fully within the subtree of a single child of the vertex, plus the paths lying fully outside the subtree of the vertex.
    Both of these are easy to compute if you know subtree sizes.
  • The number of paths that pass through both x and its parent p_x is similarly not hard to find: one endpoint needs to lie within the subtree of x, and the other endpoint needs to lie outside it.
    If the subtree of x has size s, there are s\cdot (N-s) such paths.
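Both counts can be sketched in a few lines once subtree sizes are known. This assumes 0-indexed vertices rooted at 0 and hypothetical names: `through[x]` is the number of paths containing x (including paths where x is an endpoint), and `off[x]` is the number of paths containing p_x but not x.

```python
def c2(s):
    """Number of unordered pairs of distinct vertices among s vertices."""
    return s * (s - 1) // 2


def path_counts(n, edges, root=0):
    adj = [[] for _ in range(n)]
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)
    parent, order = [-1] * n, [root]
    for x in order:  # BFS: the list grows while we iterate over it
        for y in adj[x]:
            if y != parent[x]:
                parent[y] = x
                order.append(y)
    subsz = [1] * n
    for x in reversed(order):
        if parent[x] != -1:
            subsz[parent[x]] += subsz[x]

    # All pairs, minus pairs entirely outside the subtree of x ...
    through = [c2(n) - c2(n - subsz[x]) for x in range(n)]
    for x in range(n):
        if parent[x] != -1:
            # ... minus pairs inside the subtree of a single child.
            through[parent[x]] -= c2(subsz[x])

    off = [0] * n
    for x in range(n):
        if parent[x] != -1:
            # Remove paths through both x and p_x: one endpoint in subtree(x), one outside.
            off[x] = through[parent[x]] - subsz[x] * (n - subsz[x])
    return subsz, through, off


edges = [(0, 1), (1, 2), (1, 3), (3, 4), (3, 5), (0, 6), (6, 7)]
subsz, through, off = path_counts(8, edges)
print(through[1], off[3], off[1])  # 22 7 2
```

For instance, the 28 - 22 = 6 paths avoiding vertex 1 are the 3 pairs inside {0, 6, 7} and the 3 pairs inside {3, 4, 5}.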

So, compute the number of such paths for x, and multiply it by the number of ways of choosing two vertices from the subtree of x.

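Similarly, the part of the tree “above” x (every vertex outside the subtree of x) is one of the pieces exactly when x is the LCA of u and v. So you’ll need to count the paths whose LCA is x, and multiply that by the number of ways of choosing two vertices outside the subtree of x.
This count is again easy to find once you have subtree sizes: take all pairs of vertices within the subtree of x, then subtract out pairs whose LCA isn’t x (meaning both vertices lie within the subtree of the same child of x).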
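A minimal sketch of the LCA count, again assuming 0-indexed vertices rooted at 0 and hypothetical names: `within[x]` counts the paths whose LCA is x, and `contrib[x]` multiplies it by \binom{N - s}{2}, where s is the subtree size of x.

```python
def c2(s):
    """Number of unordered pairs of distinct vertices among s vertices."""
    return s * (s - 1) // 2


def lca_contributions(n, edges, root=0):
    """within[x] = #paths whose LCA is x; contrib[x] = within[x] * C(n - sub[x], 2)."""
    adj = [[] for _ in range(n)]
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)
    parent, order = [-1] * n, [root]
    for x in order:  # BFS: the list grows while we iterate over it
        for y in adj[x]:
            if y != parent[x]:
                parent[y] = x
                order.append(y)
    subsz = [1] * n
    for x in reversed(order):
        if parent[x] != -1:
            subsz[parent[x]] += subsz[x]

    within = [c2(subsz[x]) for x in range(n)]  # all pairs inside subtree(x)
    for x in range(n):
        if parent[x] != -1:
            within[parent[x]] -= c2(subsz[x])  # both endpoints under the same child
    contrib = [within[x] * c2(n - subsz[x]) for x in range(n)]
    return within, contrib


edges = [(0, 1), (1, 2), (1, 3), (3, 4), (3, 5), (0, 6), (6, 7)]
within, contrib = lca_contributions(8, edges)
print(within)   # [17, 7, 0, 3, 0, 0, 1, 0]
print(contrib)  # [0, 21, 0, 30, 0, 0, 15, 0]
```

The root always contributes 0, since nothing lies above it.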

Summing both contributions across all x gives twice the final answer, since each unordered pair of added edges was counted once in each order.
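Putting the two contributions together gives the full count. The sketch below is a hypothetical function-form re-implementation (`count_fast`, 0-indexed, rooted at 0), not one of the official solutions, cross-checked against a brute force `count_brute` that enumerates every unordered pair of vertex-disjoint tree paths. Like the solutions below, it treats every pair of distinct vertices as a possible path.

```python
import itertools
import random


def c2(s):
    return s * (s - 1) // 2


def rooted(n, edges, root=0):
    adj = [[] for _ in range(n)]
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)
    parent, order = [-1] * n, [root]
    for x in order:  # BFS: the list grows while we iterate over it
        for y in adj[x]:
            if y != parent[x]:
                parent[y] = x
                order.append(y)
    subsz = [1] * n
    for x in reversed(order):
        if parent[x] != -1:
            subsz[parent[x]] += subsz[x]
    return adj, parent, subsz


def count_fast(n, edges):
    _, parent, subsz = rooted(n, edges)
    through = [c2(n) - c2(n - subsz[x]) for x in range(n)]
    within = [c2(subsz[x]) for x in range(n)]
    for x in range(n):
        if parent[x] != -1:
            through[parent[x]] -= c2(subsz[x])
            within[parent[x]] -= c2(subsz[x])
    twice = 0
    for x in range(n):
        twice += within[x] * c2(n - subsz[x])  # piece above the LCA x
        if parent[x] != -1:
            off = through[parent[x]] - subsz[x] * (n - subsz[x])
            twice += off * c2(subsz[x])  # subtree of x hangs off the path
    return twice // 2


def path_vertices(adj, u, v):
    """Vertex set of the tree path u-v (DFS from u, then walk back from v)."""
    prev, stack = {u: None}, [u]
    while stack:
        x = stack.pop()
        for y in adj[x]:
            if y not in prev:
                prev[y] = x
                stack.append(y)
    out = {v}
    while v != u:
        v = prev[v]
        out.add(v)
    return frozenset(out)


def count_brute(n, edges):
    adj, _, _ = rooted(n, edges)
    paths = [path_vertices(adj, u, v) for u, v in itertools.combinations(range(n), 2)]
    return sum(1 for p, q in itertools.combinations(paths, 2) if not (p & q))


print(count_fast(4, [(0, 1), (1, 2), (2, 3)]))  # 1
```

On the path 0-1-2-3 the only valid choice is the pair of edges (0, 1) and (2, 3); on a star every path uses the center, so the count is 0.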

Note that the answer for a tree of N vertices is bounded by \binom{N}{4}: the two added edges have four distinct endpoints, and it can be shown that for a fixed set of 4 vertices there’s at most one way to pair them up such that the cycles don’t intersect. For N = 10^5 this bound is about 4\cdot 10^{18}.
So, twice the answer still fits in a signed 64-bit integer (whose limit is about 9\cdot 10^{18}), and there’s no issue with overflow.
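Both parts of this argument can be sanity-checked. The sketch below (hypothetical helper `max_disjoint_pairings`, 0-indexed, random small trees only, so it is evidence rather than a proof) checks that no 4-vertex set admits two different pairings into vertex-disjoint paths, and confirms with exact integer arithmetic that twice \binom{10^5}{4} stays below 2^{63} - 1.

```python
import itertools
import math
import random

INT64_MAX = 2**63 - 1


def path_vertices(adj, u, v):
    """Vertex set of the tree path u-v (DFS from u, then walk back from v)."""
    prev, stack = {u: None}, [u]
    while stack:
        x = stack.pop()
        for y in adj[x]:
            if y not in prev:
                prev[y] = x
                stack.append(y)
    out = {v}
    while v != u:
        v = prev[v]
        out.add(v)
    return out


def max_disjoint_pairings(n, edges):
    """Largest number of pairings of any 4-vertex set into vertex-disjoint paths."""
    adj = [[] for _ in range(n)]
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)
    worst = 0
    for a, b, c, d in itertools.combinations(range(n), 4):
        pairings = [((a, b), (c, d)), ((a, c), (b, d)), ((a, d), (b, c))]
        good = sum(
            1
            for (p, q), (r, s) in pairings
            if not (path_vertices(adj, p, q) & path_vertices(adj, r, s))
        )
        worst = max(worst, good)
    return worst


bound = math.comb(10**5, 4)
print(f"{bound:.3e}")          # 4.166e+18
print(2 * bound < INT64_MAX)   # True
```

The intuition behind the pairing claim: if the paths a-b and c-d are disjoint, any path from {a, b} to {c, d} crosses the same connecting segment, so the other two pairings always share a vertex.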

TIME COMPLEXITY:

\mathcal{O}(N) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
#define ll long long int
using namespace std;

ll dp[200000][2];
vector <bool> vis(200005,false);
vector <ll>     v[200005];
vector <ll>   sub(200005,0);
vector <ll>   par(200005);

void  dfs(ll node)
{
    vis[node] = true;
    vector <ll> temp;
    ll sum=0;
    for(auto z:v[node])
    {
        if(!vis[z])
        {
            dfs(z);
            temp.push_back(sub[z]);
            sum+=sub[z];
        }
    }
    sub[node]=1+sum;
    dp[node][0]=sum;
    for(auto k:temp) 
    {
        dp[node][0] += k*(sum-k);
        sum -= k;
    }
}
ll C(ll x)
{
    return (x*(x-1))/2;
}
void dfs_again(ll node,ll p,ll level)
{
    vis[node]=false;
    par[node]=p;
    if(node == 1) dp[node][1]=0;
    else          dp[node][1]=dp[p][1] + dp[p][0] - (sub[node]*(sub[p]-sub[node]))  + (sub[1]-sub[p])*(sub[p]-sub[node]);


    for(auto z:v[node])
    {
        if(vis[z])
        {
            dfs_again(z,node,1+level);
        }
    }
}

int main() {
    ll t,n;cin>>t;
    while(t--)
    {
        cin>>n;
        for(ll i=1;i<n;i++)
        {
            ll x,y;cin>>x>>y;
            v[x].push_back(y);
            v[y].push_back(x);
        }
        dfs(1);
        dfs_again(1,0,0);
        ll ans=0;
        dp[0][0]=dp[0][1]=0;
        ll temp=0;
        for(ll i=1;i<=n;i++)
        {
        ans+=dp[i][0]*dp[i][1];
        temp+=(dp[i][0]*(C(n-sub[i])-dp[i][1]));
        v[i].clear();
        }
        ans+=temp/2;
        cout<<ans<<"\n";
    }
}

Tester's code (C++)
// Input Checker
// Input verification
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

struct input_checker {
	string buffer;
	int pos;

	const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
	const string number = "0123456789";
	const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
	const string lower = "abcdefghijklmnopqrstuvwxyz";

	input_checker() {
		pos = 0;
		while (true) {
			int c = cin.get();
			if (c == -1) {
				break;
			}
			buffer.push_back((char) c);
		}
	}

	int nextDelimiter() {
		int now = pos;
		while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
			now++;
		}
		return now;
	}

	string readOne() {
		assert(pos < (int) buffer.size());
		int nxt = nextDelimiter();
		string res;
		while (pos < nxt) {
			res += buffer[pos];
			pos++;
		}
		return res;
	}

	string readString(int minl, int maxl, const string &pattern = "") {
		assert(minl <= maxl);
		string res = readOne();
		assert(minl <= (int) res.size());
		assert((int) res.size() <= maxl);
		for (int i = 0; i < (int) res.size(); i++) {
			assert(pattern.empty() || pattern.find(res[i]) != string::npos);
		}
		return res;
	}

	int readInt(int minv, int maxv) {
		assert(minv <= maxv);
		int res = stoi(readOne());
		assert(minv <= res);
		assert(res <= maxv);
		return res;
	}

	long long readLong(long long minv, long long maxv) {
		assert(minv <= maxv);
		long long res = stoll(readOne());
		assert(minv <= res);
		assert(res <= maxv);
		return res;
	}

	auto readIntVec(int n, int minv, int maxv) {
		assert(n >= 0);
		vector<int> v(n);
		for (int i = 0; i < n; ++i) {
			v[i] = readInt(minv, maxv);
			if (i+1 < n) readSpace();
		}
		return v;
	}

	auto readLongVec(int n, long long minv, long long maxv) {
		assert(n >= 0);
		vector<long long> v(n);
		for (int i = 0; i < n; ++i) {
			v[i] = readLong(minv, maxv);
			if (i+1 < n) readSpace();
		}
		return v;
	}

	void readSpace() {
		assert((int) buffer.size() > pos);
		assert(buffer[pos] == ' ');
		pos++;
	}

	void readEoln() {
		assert((int) buffer.size() > pos);
		assert(buffer[pos] == '\n');
		pos++;
	}

	void readEof() {
		assert((int) buffer.size() == pos);
	}
};

void dfsMark(vector<int> adj[],bool marked[],int u,int p)
{
    marked[u]=true;
    for(auto v:adj[u])
        if(v!=p)
            dfsMark(adj,marked,v,u);
}

void checkTree(vector<int> adj[],int n)
{
    bool marked[n];
    memset(marked,false,sizeof(marked));
    dfsMark(adj,marked,0,0);
    for(int i=0;i<n;i++)
        assert(marked[i]);
}

void dfs(vector<int> adj[],unsigned long long dis[],unsigned long long par[],unsigned long long subtree[],unsigned long long cycles[],int u,int p)
{
    par[u]=p;
    cycles[u]=0;
    unsigned long long sum=0,sum2=0;
    for(auto v:adj[u])
    {
        if(v!=p)
        {
            dis[v]=dis[u]+1;
            dfs(adj,dis,par,subtree,cycles,v,u);
            sum += subtree[v];
            sum2 += subtree[v]*subtree[v];
        }
    }
    subtree[u] = 1+sum;
    cycles[u] = sum+(sum*sum-sum2)/2;
}

void dfs2(vector<int> adj[],unsigned long long dis[],unsigned long long anticycles[],unsigned long long subtree[],unsigned long long cycles[],int u,int p,int n)
{
    if(u!=0)
        anticycles[u]=anticycles[p]+(cycles[p]-subtree[u]*(subtree[p]-subtree[u]))+(n-subtree[p])*(subtree[p]-subtree[u]);
    // cout << u+1 << " " << anticycles[p] << " " << anticycles[u] << '\n';
    for(auto v:adj[u])
    {
        if(v!=p)
        {
            dfs2(adj,dis,anticycles,subtree,cycles,v,u,n);
        }
    }
}

int main()
{
	ios::sync_with_stdio(false); cin.tie(0);

	input_checker inp;
	
	int tc = inp.readInt(1, 10000);
	inp.readEoln();
	int sum_n=0;
	while(tc--)
	{
	    int n = inp.readInt(4, 100000);
	    inp.readEoln();
	    sum_n += n;
	    assert(sum_n <= 500000);
        vector<int> adj[n];
        for(int i=0;i<n-1;i++)
        {
            int u = inp.readInt(1,n);
            inp.readSpace();
            int v = inp.readInt(1,n);
            inp.readEoln();
            u--,v--;
            adj[u].push_back(v);
            adj[v].push_back(u);
        }
        checkTree(adj,n);
        unsigned long long dis[n],par[n],subtree[n],cycles[n];
        dis[0]=0;
        dfs(adj,dis,par,subtree,cycles,0,0);
        // for(int i=0;i<n;i++)
            // cout << cycles[i] << " ";
        // cout << '\n';
        unsigned long long anticycles[n];
        anticycles[0]=0;
        dfs2(adj,dis,anticycles,subtree,cycles,0,0,n);
        // for(int i=0;i<n;i++)
            // cout << anticycles[i] << " ";
        // cout << '\n';
        unsigned long long ans=0;
        for(int i=0;i<n;i++)
            ans += cycles[i]*(anticycles[i]+(n-subtree[i])*(n-subtree[i]-1)/2);
        ans /= 2;
        cout << ans << '\n';
    }
	inp.readEof();
}
Editorialist's code (Python)
def dfs(graph, start=0):
    n = len(graph)
    order = []
    visited, parent = [False] * n, [-1] * n

    stack = [start]
    while stack:
        start = stack.pop()

        visited[start] = True
        order.append(start)
        for child in graph[start]:
            if not visited[child]:
                parent[child] = start
                stack.append(child)

    return parent, order

for _ in range(int(input())):
    n = int(input())
    gr = [ [] for _ in range(n) ]
    for i in range(n-1):
        u, v = map(int, input().split())
        gr[u-1].append(v-1)
        gr[v-1].append(u-1)
    parent, order = dfs(gr)

    subsz = [0]*n
    for u in reversed(order):
        subsz[u] += 1
        if u > 0: subsz[parent[u]] += subsz[u]
    
    within = [0]*n
    through = [0]*n
    ans = 0
    for u in range(n):
        through[u] = n*(n-1)//2
        through[u] -= (n - subsz[u]) * (n - subsz[u] - 1) // 2

        within[u] = subsz[u]*(subsz[u] - 1) // 2
        for v in gr[u]:
            if parent[v] == u:
                through[u] -= subsz[v] * (subsz[v] - 1) // 2
                within[u] -= subsz[v] * (subsz[v] - 1) // 2
    
    for u in range(n):
        ans += within[u] * (n - subsz[u]) * (n - subsz[u] - 1) // 2
        if u > 0:
            ans += (through[parent[u]] - (subsz[u] * (n - subsz[u]))) * subsz[u] * (subsz[u] - 1) // 2
    print(ans // 2)