BLWHTREE - Editorial

PROBLEM LINK:

Practice

Contest: Division 1

Contest: Division 2

Setter: Andrew

Tester: Teja Vardhan Reddy

Editorialist: Taranpreet Singh

DIFFICULTY:

Medium

PREREQUISITES:

Heavy-light Decomposition, Inclusion-Exclusion, Combinatorics and Segment Tree.

PROBLEM:

Given a tree with N nodes, where each node is colored either Black or White (denoted by 1 or 0 respectively), we have to answer Q queries. Each query specifies a path, say from node L to node R, and we have to count the number of ways to choose a set of three distinct nodes such that

  • The chosen nodes lie on the shortest path between node L and R
  • For any two chosen nodes, say u and v, there is at least one black colored node on the shortest path between nodes u and v (both inclusive).

DEFINITIONS

  • For a query (L, R), a triplet (u, v, w) represents the set of nodes u, v and w such that all these nodes lie on the shortest path from L to R and, while moving from L to R, u is found first, then v, and then w.
  • Let P(u, v) denote the path from u to v
  • A triplet (u, v, w) is valid for a query (L, R) if there is at least one black node on P(u, v) and at least one black node on P(v, w). The pair (u, w) then needs no separate check, since P(u, w) contains P(u, v).

QUICK EXPLANATION

  • The total number of triplets is \displaystyle\binom{K}{3}, where K is the number of nodes on the path from L to R. The invalid triplets fall into the following two disjoint classes, which together cover all invalid triplets.
    • The triplet has no black node on P(u, w). The number of such triplets is \displaystyle\sum_{x \in S} \binom{x}{3}, where S is the multiset of lengths of the maximal blocks of consecutive 0s on P(L, R).
    • The triplet has no black node on P(u, v) and at least one black node on P(v, w)
      OR
      The triplet has at least one black node on P(u, v) and no black node on P(v, w).
      The number of such triplets is \displaystyle \sum_{x \in S} (K-x)*\binom{x}{2}.
    Subtracting both classes from the total gives the number of valid triplets.
  • We need to maintain this information over ranges, so we use a segment tree in which each node stores the above information for its range, and we handle the merging of two nodes while maintaining the same set of information.
  • To extend the above solution from arrays to paths in a tree, we apply Heavy-Light Decomposition.

EXPLANATION

Let’s first solve a simpler problem.
Given an array A of length K consisting only of 0s and 1s, answer queries that ask for the number of valid triplets in a specified subarray.
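When testing a faster method, a direct count straight from the definition is a handy reference (the editorialist's Java solution contains a similar `check` routine). Below is a minimal C++ sketch, assuming 0-based positions and the color encoding above; the name `bruteValidTriplets` is illustrative only. It runs in O(K^3), so it is meant for small arrays only.

```cpp
#include <cassert>
#include <vector>

// Illustrative O(K^3) reference count; the function name is hypothetical.
// A[i] is 1 for a black node and 0 for a white node; u < v < w are the
// 0-based positions of the three chosen nodes along the path.
long long bruteValidTriplets(const std::vector<int>& A) {
    const int K = static_cast<int>(A.size());
    // ones[i] = number of 1s among A[0 .. i-1].
    std::vector<int> ones(K + 1, 0);
    for (int i = 0; i < K; ++i) ones[i + 1] = ones[i] + A[i];
    // True if A[l .. r] (both inclusive) contains at least one 1.
    auto hasBlack = [&](int l, int r) { return ones[r + 1] - ones[l] > 0; };

    long long count = 0;
    for (int u = 0; u < K; ++u)
        for (int v = u + 1; v < K; ++v)
            for (int w = v + 1; w < K; ++w)
                if (hasBlack(u, v) && hasBlack(v, w)) ++count;
    return count;
}
```

For the example array used below it returns 308, and for 0 0 1 0 it returns 2.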

Now, let’s count the invalid triplets instead. Writing A[u, v] for the subarray from index u to index v (both inclusive), a triplet (u, v, w) is invalid if A[u, v] contains only zeroes, A[v, w] contains only zeroes, or both.

If we group consecutive 0s and 1s, we can see that in an invalid triplet, at least two consecutively chosen nodes (u and v, or v and w) lie in the same group of 0s. Let’s see an example.

Consider K = 15 and the array 1 1 1 0 0 0 0 0 1 1 1 1 0 0 0.
Grouping the 0s in this array, we get one group of 5 zeroes and one group of 3 zeroes, so the multiset S = \{5, 3\}.

Let’s first count the invalid triplets (u, v, w) where there are no 1s in A[u, w]. Here u, v and w must all lie in the same group, and the number of ways to select three nodes from a group of size x is \displaystyle\binom{x}{3}. Hence, the number of such triplets is \displaystyle \binom{5}{3}+\binom{3}{3} = 10+1 = 11.
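This first count can be sketched in C++: split the array into maximal runs of 0s (the multiset S) and sum \binom{x}{3} over them. The helper names `zeroBlocks`, `choose3` and `allZeroTriplets` are hypothetical, chosen for illustration.

```cpp
#include <cassert>
#include <vector>

// Illustrative sketch; helper names are hypothetical.
// Lengths of the maximal blocks of consecutive 0s (the multiset S).
std::vector<long long> zeroBlocks(const std::vector<int>& A) {
    std::vector<long long> S;
    long long run = 0;
    for (int x : A) {
        if (x == 0) {
            ++run;
        } else if (run > 0) {
            S.push_back(run);
            run = 0;
        }
    }
    if (run > 0) S.push_back(run);  // a block that ends the array
    return S;
}

long long choose3(long long x) { return x * (x - 1) * (x - 2) / 6; }

// Triplets u < v < w with no 1 anywhere in A[u .. w]: all three nodes come
// from the same block of 0s, giving the sum over S of C(x, 3).
long long allZeroTriplets(const std::vector<int>& A) {
    long long total = 0;
    for (long long x : zeroBlocks(A)) total += choose3(x);
    return total;
}
```

On the example array, `zeroBlocks` returns {5, 3} and `allZeroTriplets` returns 11, matching the hand computation.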

Now, let us count the invalid triplets where exactly one of the two parts contains only 0s:

  • A[u, v] contains only 0s and A[v, w] contains at least one 1
  • A[u, v] contains at least one 1 and A[v, w] contains only 0s

For this, either (u, v) must lie in the same group of 0s and w outside that group (which implies at least one 1 in A[v, w]), OR (v, w) must lie in the same group and u outside it.

The number of ways to choose a pair of nodes in the same group is \displaystyle\binom{x}{2}, and there are K-x choices for the third node, where x is the group size; the position of the third node (before or after the group) decides which of the two cases applies. Hence, the total number of invalid triplets of this type is \displaystyle \sum_{x \in S} (K-x)*\binom{x}{2}. In the current example, it becomes (15-5)*\binom{5}{2}+(15-3)*\binom{3}{2} = 100+36 = 136.
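The second count can be sketched the same way and cross-checked against a brute-force enumeration of the two cases listed above. Both function names are illustrative assumptions, and positions are 0-based.

```cpp
#include <cassert>
#include <vector>

// Illustrative sketch; function names are hypothetical.
// Closed form: sum over blocks of x zeros of (K - x) * C(x, 2).
long long oneSidedInvalid(const std::vector<int>& A) {
    const int n = static_cast<int>(A.size());
    const long long K = n;
    long long total = 0, run = 0;
    for (int i = 0; i <= n; ++i) {
        if (i < n && A[i] == 0) {
            ++run;  // still inside a block of 0s
        } else {
            // A block of `run` 0s ends here (run == 0 contributes nothing).
            total += (K - run) * (run * (run - 1) / 2);
            run = 0;
        }
    }
    return total;
}

// Brute force: triplets where exactly one of A[u .. v], A[v .. w] is all 0s.
long long oneSidedInvalidBrute(const std::vector<int>& A) {
    const int K = static_cast<int>(A.size());
    // True if A[l .. r] (both inclusive) contains no 1.
    auto allZero = [&](int l, int r) {
        for (int i = l; i <= r; ++i)
            if (A[i] == 1) return false;
        return true;
    };
    long long count = 0;
    for (int u = 0; u < K; ++u)
        for (int v = u + 1; v < K; ++v)
            for (int w = v + 1; w < K; ++w)
                if (allZero(u, v) != allZero(v, w)) ++count;  // exactly one holds
    return count;
}
```

Both functions return 136 on the example array.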

Hence, the total number of invalid triplets is 11+136 = 147, while the total number of triplets is \displaystyle\binom{15}{3} = 455.

Therefore, the number of valid triplets is 455-147 = 308.

In general, the number of valid triplets becomes \displaystyle\binom{K}{3}-\sum_{x \in S} \bigg[(K-x)*\binom{x}{2} + \binom{x}{3}\bigg]
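Putting the two counts together, the general formula can be evaluated in a single pass over the array. A minimal C++ sketch with hypothetical helper names `C2`, `C3` and `validTriplets`; since \binom{x}{2} and \binom{x}{3} are 0 for small x, empty or short runs need no special case.

```cpp
#include <cassert>
#include <vector>

// Illustrative sketch; helper names are hypothetical.
long long C2(long long x) { return x * (x - 1) / 2; }            // 0 for x < 2
long long C3(long long x) { return x * (x - 1) * (x - 2) / 6; }  // 0 for x < 3

// C(K, 3) minus, for every maximal block of x zeros, (K - x) * C(x, 2) + C(x, 3).
long long validTriplets(const std::vector<int>& A) {
    const int n = static_cast<int>(A.size());
    const long long K = n;
    long long invalid = 0, run = 0;
    for (int i = 0; i <= n; ++i) {
        if (i < n && A[i] == 0) {
            ++run;  // extend the current block of 0s
        } else {
            invalid += (K - run) * C2(run) + C3(run);  // contributes 0 if run == 0
            run = 0;
        }
    }
    return C3(K) - invalid;
}
```

For the example array it returns 455 - 147 = 308; an array of 0s only yields 0, as expected.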

Now that we know how to count the valid triplets, it is natural to use a range data structure such as a segment tree. But what information should each node store to answer subarray queries?

A node of the segment tree corresponds to a range, and a group of 0s may span several nodes. So in each node, we store whether the whole range contains only 0s and, if not, the number of 0s at the start of the range and the number of 0s at the end of the range.

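Apart from that, we store the sums of \displaystyle\binom{x}{2}, \displaystyle x*\binom{x}{2} and \displaystyle\binom{x}{3} over all groups of 0s lying completely inside the range, i.e. touching neither of its ends. This information turns out to be enough: K enters the invalid count only as a multiplier of \displaystyle\sum\binom{x}{2}, so it can be applied when a query is answered, and at that point the prefix and suffix groups of 0s of the queried range are added as complete groups.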
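Expanding the invalid count shows why these three sums suffice: K only multiplies one of them, so each node can store K-independent sums and the range length is applied at query time. Checked against the earlier example:

```latex
\sum_{x \in S}\bigg[(K-x)\binom{x}{2} + \binom{x}{3}\bigg]
  = K\sum_{x \in S}\binom{x}{2} - \sum_{x \in S} x\binom{x}{2} + \sum_{x \in S}\binom{x}{3}

\text{Example } (K = 15,\ S = \{5, 3\}):\quad
  15\cdot(10 + 3) - (5\cdot 10 + 3\cdot 3) + (10 + 1) = 195 - 59 + 11 = 147
```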

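When merging two adjacent ranges, we need to handle the prefix and suffix 0s carefully: the suffix 0s of the left range and the prefix 0s of the right range join into a single group, which becomes a complete inner group when neither range consists of 0s only.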

This allows us to answer subarray queries in O((Q+K)*log(K)) time.
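The node contents and merge rule described above can be sketched as follows. This is a hedged C++ illustration of the same idea as the editorialist's Java `Node`/`merge`, not a copy of any posted solution; the names `Node`, `leaf`, `merge`, `evaluate`, `SegTree` and `allSubarraysMatch` are assumptions made for the example. Queries are 0-based and inclusive.

```cpp
#include <cassert>
#include <vector>

// Illustrative sketch; all names are hypothetical.
long long C2(long long x) { return x * (x - 1) / 2; }
long long C3(long long x) { return x * (x - 1) * (x - 2) / 6; }

// Summary of a contiguous range of the 0/1 array (order-sensitive).
struct Node {
    long long len = 0;       // elements in the range; 0 means "empty"
    bool allZero = true;     // no 1 in the range
    long long pre = 0;       // length of the run of 0s at the start
    long long suf = 0;       // length of the run of 0s at the end
    long long c2 = 0, xc2 = 0, c3 = 0;  // sums of C(x,2), x*C(x,2), C(x,3)
                                        // over 0-blocks touching neither end
};

Node leaf(int bit) {
    Node n;
    n.len = 1;
    n.allZero = (bit == 0);
    n.pre = n.suf = (bit == 0) ? 1 : 0;
    return n;
}

// `l` covers the positions immediately before those of `r`.
Node merge(const Node& l, const Node& r) {
    if (l.len == 0) return r;
    if (r.len == 0) return l;
    Node res;
    res.len = l.len + r.len;
    res.allZero = l.allZero && r.allZero;
    res.pre = l.allZero ? l.len + r.pre : l.pre;
    res.suf = r.allZero ? r.len + l.suf : r.suf;
    res.c2 = l.c2 + r.c2;
    res.xc2 = l.xc2 + r.xc2;
    res.c3 = l.c3 + r.c3;
    if (!l.allZero && !r.allZero) {
        // l's trailing 0s and r's leading 0s fuse into one block bounded by 1s.
        const long long x = l.suf + r.pre;
        res.c2 += C2(x);
        res.xc2 += x * C2(x);
        res.c3 += C3(x);
    }
    return res;
}

// Valid triplets in the summarised range: C(K,3) minus the invalid ones.
long long evaluate(const Node& n) {
    if (n.allZero) return 0;  // no black node at all
    const long long K = n.len;
    long long invalid = K * n.c2 - n.xc2 + n.c3;
    for (long long x : {n.pre, n.suf})  // the end runs are complete blocks now
        invalid += (K - x) * C2(x) + C3(x);
    return C3(K) - invalid;
}

// Bottom-up segment tree; padding leaves are empty Nodes.
struct SegTree {
    int size = 1;
    std::vector<Node> t;
    explicit SegTree(const std::vector<int>& a) {
        while (size < static_cast<int>(a.size())) size <<= 1;
        t.assign(2 * size, Node{});
        for (int i = 0; i < static_cast<int>(a.size()); ++i) t[size + i] = leaf(a[i]);
        for (int i = size - 1; i > 0; --i) t[i] = merge(t[2 * i], t[2 * i + 1]);
    }
    // Summary of a[l .. r], merged in left-to-right order.
    Node query(int l, int r) const {
        Node left, right;
        for (l += size, r += size + 1; l < r; l >>= 1, r >>= 1) {
            if (l & 1) left = merge(left, t[l++]);
            if (r & 1) right = merge(t[--r], right);
        }
        return merge(left, right);
    }
};

// Cross-checks every subarray against a direct O(length) closed-form count.
bool allSubarraysMatch(const std::vector<int>& a) {
    SegTree st(a);
    const int n = static_cast<int>(a.size());
    for (int l = 0; l < n; ++l)
        for (int r = l; r < n; ++r) {
            const long long K = r - l + 1;
            long long invalid = 0, run = 0;
            for (int i = l; i <= r + 1; ++i) {
                if (i <= r && a[i] == 0) { ++run; continue; }
                invalid += (K - run) * C2(run) + C3(run);
                run = 0;
            }
            if (evaluate(st.query(l, r)) != C3(K) - invalid) return false;
        }
    return true;
}
```

Because `merge` is not commutative, the iterative query appends to `left` and prepends to `right`. The `allSubarraysMatch` helper only exists to cross-check the tree against the direct formula on small arrays.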

Now, returning to our original problem: we have done all the hard work, and we just need to extend this to paths in the tree. As those of you familiar with this type of problem may have guessed, we use Heavy-Light Decomposition.

To state it briefly, quoting from the above link: the essence of heavy-light decomposition is to split the tree into several paths so that we can reach the root vertex from any v by traversing at most log N paths. In addition, none of these paths should intersect with another. This is exactly what we do here: we solve the problem for an array, then use HLD to extend it to paths in the tree.
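A compact sketch of the decomposition itself, assuming 0-indexed vertices rooted at 0 (the struct and helper names are hypothetical). `pathRanges` returns position ranges whose union is the path; running one segment tree query per range gives O(log N) queries per path, each costing O(log N), in line with the O(Q log^2 N) term. This helper swaps u and v freely and so loses the direction of the path, which the real solution has to track.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Illustrative sketch; names are hypothetical. Rooted at vertex 0, each heavy
// chain occupies a contiguous block of positions with its head first.
struct HLD {
    std::vector<std::vector<int>> g;
    std::vector<int> parent, depth, heavy, head, pos;
    int timer = 0;

    explicit HLD(const std::vector<std::vector<int>>& adj)
        : g(adj), parent(adj.size(), -1), depth(adj.size(), 0),
          heavy(adj.size(), -1), head(adj.size(), 0), pos(adj.size(), 0) {
        if (!adj.empty()) { subtreeSize(0); decompose(0, 0); }
    }
    // Returns the subtree size of v and records v's heaviest child.
    int subtreeSize(int v) {
        int size = 1, best = 0;
        for (int to : g[v]) {
            if (to == parent[v]) continue;
            parent[to] = v;
            depth[to] = depth[v] + 1;
            const int s = subtreeSize(to);
            size += s;
            if (s > best) { best = s; heavy[v] = to; }
        }
        return size;
    }
    void decompose(int v, int h) {
        head[v] = h;
        pos[v] = timer++;
        if (heavy[v] != -1) decompose(heavy[v], h);  // heavy child extends the chain
        for (int to : g[v])
            if (to != parent[v] && to != heavy[v]) decompose(to, to);  // new chain
    }
    // Position ranges [lo, hi] whose union is the path u - v.
    std::vector<std::pair<int, int>> pathRanges(int u, int v) const {
        std::vector<std::pair<int, int>> res;
        while (head[u] != head[v]) {
            if (depth[head[u]] < depth[head[v]]) std::swap(u, v);
            res.push_back({pos[head[u]], pos[u]});
            u = parent[head[u]];
        }
        res.push_back({std::min(pos[u], pos[v]), std::max(pos[u], pos[v])});
        return res;
    }
};

HLD fromEdges(int n, const std::vector<std::pair<int, int>>& edges) {
    std::vector<std::vector<int>> adj(n);
    for (const auto& e : edges) {
        adj[e.first].push_back(e.second);
        adj[e.second].push_back(e.first);
    }
    return HLD(adj);
}

// Naive distance by repeatedly lifting the deeper endpoint.
int naiveDist(const HLD& h, int a, int b) {
    int d = 0;
    while (a != b) {
        if (h.depth[a] < h.depth[b]) std::swap(a, b);
        a = h.parent[a];
        ++d;
    }
    return d;
}

// For every pair (u, v): ranges are disjoint and cover exactly the vertices x
// with dist(u, x) + dist(x, v) == dist(u, v).
bool rangesMatchPaths(const HLD& h) {
    const int n = static_cast<int>(h.g.size());
    std::vector<int> at(n);
    for (int x = 0; x < n; ++x) at[h.pos[x]] = x;
    for (int u = 0; u < n; ++u)
        for (int v = 0; v < n; ++v) {
            const int d = naiveDist(h, u, v);
            std::vector<bool> seen(n, false);
            int covered = 0;
            for (const auto& r : h.pathRanges(u, v))
                for (int p = r.first; p <= r.second; ++p) {
                    const int x = at[p];
                    if (seen[p] || naiveDist(h, u, x) + naiveDist(h, x, v) != d) return false;
                    seen[p] = true;
                    ++covered;
                }
            if (covered != d + 1) return false;
        }
    return true;
}
```

The recursive DFS keeps the sketch short; on a path-shaped tree with N around 10^5 it may exceed the default stack size, so an iterative DFS is the safer choice in a contest.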

Implementation details to take care of, if you don’t want to go crazy debugging:

  • While merging the two halves of a path, be careful not to include the LCA node twice.
  • Direction matters here. If a node stores information for P(u, v), we need to swap its start and end block information to get the information for P(v, u). This is needed while merging paths.
  • Likewise, merging is not commutative, so the order of the operands matters, both inside the segment tree and when combining chain segments.
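The directional pitfalls can already be seen on the boundary part of a node (whether the segment is all 0s, and the lengths of its leading and trailing 0-runs). The C++ sketch below uses the hypothetical names `Ends`, `merge`, `flip`, `fromColors` and `sameEnds`; it assumes segments are read in path order and that, as with HLD positions, the segment from the LCA down to u comes out of the segment tree in ancestor-first order.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Illustrative sketch; names are hypothetical. Boundary part of a path-segment
// summary: runs of 0s at each end, read in path order.
struct Ends {
    int len = 0;          // nodes in the segment (0 = empty)
    bool allZero = true;  // no black node in the segment
    int pre = 0;          // 0s at the start of the segment
    int suf = 0;          // 0s at the end of the segment
};

// `l` is traversed before `r`; swapping the arguments can change the result.
Ends merge(const Ends& l, const Ends& r) {
    if (l.len == 0) return r;
    if (r.len == 0) return l;
    Ends res;
    res.len = l.len + r.len;
    res.allZero = l.allZero && r.allZero;
    res.pre = l.allZero ? l.len + r.pre : l.pre;
    res.suf = r.allZero ? r.len + l.suf : r.suf;
    return res;
}

// Reading a segment in the opposite direction swaps its start and end runs.
Ends flip(Ends e) {
    std::swap(e.pre, e.suf);
    return e;
}

// Summary of a segment given its colors in path order (1 = black, 0 = white).
Ends fromColors(const std::vector<int>& colors) {
    Ends res;
    for (int c : colors) {
        Ends one;
        one.len = 1;
        one.allZero = (c == 0);
        one.pre = one.suf = (c == 0) ? 1 : 0;
        res = merge(res, one);
    }
    return res;
}

bool sameEnds(const Ends& a, const Ends& b) {
    return a.len == b.len && a.allZero == b.allZero && a.pre == b.pre && a.suf == b.suf;
}
```

For instance, if LCA → u has colors 1 0 0 and the LCA's child → v has colors 0 1, the path u → v is 0 0 1 0 1, which `merge(flip(fromColors({1,0,0})), fromColors({0,1}))` reproduces. Omitting the flip gives a wrong prefix run, and starting the second part at the LCA again yields a length of 6 instead of 5.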

While the above is the intended solution, there also exists a solution without Heavy-Light Decomposition that relies on precomputation; it was used by the tester as well as by most contestants.

TIME COMPLEXITY

The time complexity is O(N*log(N)+Q*log^2(N)) per test case.

SOLUTIONS:

Setter's Solution
#include<bits/stdc++.h>
using namespace std;


const int N = 1e5 + 5;


int n, m, a[N];
vector < int > g[N];
int nxt[N], sz[N], d[N], p[N];
int chain[N], num_chain, top[N], num[N], all, ob[N];

/// ans, kol_0, kol_1, sum_l_1, sum_r_1, sum_l_0, sum_r_0, suf_0, pref_0, sz

struct my{
	long long ans;
	int kol_0, kol_1;
	long long sum_l_1, sum_r_1, sum_l_0, sum_r_0;
	int suf_0, pref_0, sz;
	my(long long _ans = 0, int _kol_0 = 0, int _kol_1 = 0, long long _sum_l_1 = 0,
	   long long _sum_r_1 = 0, long long _sum_l_0 = 0, long long _sum_r_0 = 0,
	   int _suf_0 = 0, int _pref_0 = 0, int _sz = 0){
	    ans = _ans;
	    kol_0 = _kol_0;
	    kol_1 = _kol_1;
	    sum_l_1 = _sum_l_1;
	    sum_r_1 = _sum_r_1;
	    sum_l_0 = _sum_l_0;
	    sum_r_0 = _sum_r_0;
	    suf_0 = _suf_0;
	    pref_0 = _pref_0;
	    sz = _sz;
	}
};
my t[4 * N];

my cmb(my l, my r){
	my res;
	res.sz = l.sz + r.sz;
	res.kol_1 = l.kol_1 + r.kol_1;
	res.kol_0 = l.kol_0 + r.kol_0;
	res.pref_0 = l.pref_0;
	if(l.sz == l.kol_0){
	    res.pref_0 += r.pref_0;
	}
	res.suf_0 = r.suf_0;
	if(r.sz == r.kol_0){
	    res.suf_0 += l.suf_0;
	}
	res.ans = l.ans + r.ans;
	res.ans += l.sum_l_1 * r.sz;
	res.ans += r.sum_r_1 * l.sz;
	res.sum_r_1 = l.sum_r_1 + r.sum_r_1;
	res.sum_r_1 += 1LL * l.kol_1 * r.sz;
	res.sum_l_1 = r.sum_l_1 + l.sum_l_1;
	res.sum_l_1 += 1LL * r.kol_1 * l.sz;
	res.ans += 1LL * r.pref_0 * (r.sz - r.pref_0) * (l.sz - l.suf_0) + (r.sum_r_0 - 1LL * r.pref_0 * (r.sz - r.pref_0)) * l.sz;
	res.ans += 1LL * l.suf_0 * (l.sz - l.suf_0) * (r.sz - r.pref_0) + (l.sum_l_0 - 1LL * l.suf_0 * (l.sz - l.suf_0)) * r.sz;
	res.sum_r_0 = l.sum_r_0 + r.sum_r_0;
	res.sum_r_0 += 1LL * (l.kol_0 - l.suf_0) * r.sz + 1LL * l.suf_0 * (r.sz - r.pref_0);
	res.sum_l_0 = l.sum_l_0 + r.sum_l_0;
	res.sum_l_0 += 1LL * (r.kol_0 - r.pref_0) * l.sz + 1LL * r.pref_0 * (l.sz - l.suf_0);
	return res;
}

void build(int v, int l, int r){
	if(l == r){
	    int ans = 0,
	        kol_0 = (a[ob[l]] == 0),
	        kol_1 = (a[ob[l]] == 1),
	        sum_l_1 = 0,
	        sum_r_1 = 0,
	        sum_l_0 = 0,
	        sum_r_0 = 0,
	        suf_0 = (a[ob[l]] == 0),
	        pref_0 = (a[ob[l]] == 0),
	        sz = 1;
	    t[v] = my(ans, kol_0, kol_1, sum_l_1, sum_r_1, sum_l_0, sum_r_0, suf_0, pref_0, sz);
	    return;
	}
	int mid = (r + l) >> 1;
	build(v + v, l, mid);
	build(v + v + 1, mid + 1, r);
	t[v] = cmb(t[v + v], t[v + v + 1]);
}

my res;
void get(int v, int l, int r, int tl, int tr){
	if(l > r || l > tr || tl > r){
	    return;
	}
	if(tl <= l && r <= tr){
	    if(res.ans == -1){
	        res = t[v];
	    }
	    else{
	        res = cmb(res, t[v]);
	    }
	    return;
	}
	int mid = (r + l) >> 1;
	get(v + v, l, mid, tl, tr);
	get(v + v + 1, mid + 1, r, tl, tr);
}

void dfs(int v, int pr = 0){
	p[v] = pr;
	sz[v] = 1;
	for(int to : g[v]){
	    if(to == pr){
	        continue;
	    }
	    d[to] = d[v] + 1;
	    dfs(to, v);
	    sz[v] += sz[to];
	    if(nxt[v] == -1 || sz[to] > sz[nxt[v]]){
	        nxt[v] = to;
	    }
	}
}
void hld(int v, int pr = -1){
	chain[v] = num_chain;
	num[v] = ++all;
	ob[all] = v;
	if(nxt[v] != -1){
	    top[nxt[v]] = top[v];
	    hld(nxt[v], v);
	}
	for(int to : g[v]){
	    if(to == pr || to == nxt[v]){
	        continue;
	    }
	    num_chain += 1;
	    top[to] = to;
	    hld(to, v);
	}
}

void solve(){
	cin >> n;
	for(int i = 1; i <= n; i++){
	    cin >> a[i];
	    g[i].clear();
	}
	for(int i = 1; i < n; i++){
	    int x, y;
	    cin >> x >> y;
	    g[x].push_back(y);
	    g[y].push_back(x);
	}
	memset(nxt, -1, sizeof(nxt));
	dfs(1);
	all = 0;
	top[1] = num_chain = 1;
	hld(1);
	build(1, 1, n);
	cin >> m;
	while(m--){
	    int x, y;
	    cin >> x >> y;
	    vector < my > l, r;
	    while(chain[x] != chain[y]){
	        if(d[top[x]] > d[top[y]]){
	            res.ans = -1;
	            get(1, 1, n, num[top[x]], num[x]);
	            l.push_back(res);
	            x = p[top[x]];
	        }
	        else{
	            res.ans = -1;
	            get(1, 1, n, num[top[y]], num[y]);
	            r.push_back(res);
	            y = p[top[y]];
	        }
	    }
	    int xx = min(num[x], num[y]),
	        yy = max(num[x], num[y]);
	    res.ans = -1;
	    get(1, 1, n, xx, yy);
	    if(num[x] < num[y]){
	        r.push_back(res);
	    }
	    else{
	        l.push_back(res);
	    }
	    reverse(r.begin(), r.end());
	    for(auto &it : r){
	        swap(it.pref_0, it.suf_0);
	        swap(it.sum_l_0, it.sum_r_0);
	        swap(it.sum_l_1, it.sum_r_1);
	    }
	    for(auto it : r){
	        l.push_back(it);
	    }
	    my ans = l[0];
	    for(int i = 1; i < (int)l.size(); i++){
	        ans = cmb(l[i], ans);
	    }
	    cout << ans.ans << "\n";
	}
}
int main(){
	ios_base::sync_with_stdio(0);cin.tie(0);cout.tie(0);
	int tt = 1;
	cin >> tt;
	while(tt--){
	    solve();
	}
}
Tester's Solution
//teja349
#include <bits/stdc++.h>
#include <vector>
#include <set>
#include <map>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <utility>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <iomanip>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp> 
//setbase - cout << setbase (16); cout << 100 << endl; Prints 64
//setfill -   cout << setfill ('x') << setw (5); cout << 77 << endl; prints xxx77
//setprecision - cout << setprecision (14) << f << endl; Prints x.xxxx
//cout.precision(x)  cout<<fixed<<val;  // prints x digits after decimal in val
 
using namespace std;
using namespace __gnu_pbds;
 
#define f(i,a,b) for(i=a;i<b;i++)
#define rep(i,n) f(i,0,n)
#define fd(i,a,b) for(i=a;i>=b;i--)
#define pb push_back
#define mp make_pair
#define vi vector< int >
#define vl vector< ll >
#define ss second
#define ff first
#define ll long long
#define pii pair< int,int >
#define pll pair< ll,ll >
#define sz(a) a.size()
#define inf (1000*1000*1000+5)
#define all(a) a.begin(),a.end()
#define tri pair<int,pii>
#define vii vector<pii>
#define vll vector<pll>
#define viii vector<tri>
#define mod (1000*1000*1000+7)
#define pqueue priority_queue< int >
#define pdqueue priority_queue< int,vi ,greater< int > >
#define flush fflush(stdout) 
#define primeDEN 727999983
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
 
// find_by_order()  // order_of_key
typedef tree<
int,
null_type,
less<int>,
rb_tree_tag,
tree_order_statistics_node_update>
ordered_set;
 
#define int ll
vector<vi> adj(123456);
int paren[123456][18];
int d0[123456],d1[123456],d2[123456];
int gg[123456][3][3],nearfw[123456];
int dep[123456];
int a[123456];
 
int dfs(int cur,int par,int lastw,int lastb){

	paren[cur][0]=par;
	dep[cur]=dep[par]+1;
	int h;
	if(a[cur]==0){
		d0[cur]=d0[par]+1;
		d1[cur]=d1[par]+dep[cur];
		d2[cur]=d2[par]+dep[cur]*dep[cur];
	}
	else{
		d0[cur]=d0[par];
		d1[cur]=d1[par];
		d2[cur]=d2[par];
	}
	if(a[cur]==0 && a[par]==1){
	    nearfw[cur]=cur;
	    //cout<<lastb<<endl;
		h=dep[cur]-dep[lastb]-1;
		gg[cur][0][0]=gg[par][0][0]+1;
		gg[cur][1][0]=gg[par][1][0]+h;
		gg[cur][1][1]=gg[par][1][1]+h*dep[cur];
		gg[cur][1][2]=gg[par][1][2]+h*dep[cur]*dep[cur];
		gg[cur][2][1]=gg[par][2][1]+h*h*dep[cur];
		gg[cur][2][0]=gg[par][2][0]+h*h;
 
	}
	else{
		nearfw[cur]=nearfw[par];
		gg[cur][0][0]=gg[par][0][0];
		gg[cur][1][0]=gg[par][1][0];
		gg[cur][1][1]=gg[par][1][1];
		gg[cur][1][2]=gg[par][1][2];
		gg[cur][2][1]=gg[par][2][1];
		gg[cur][2][0]=gg[par][2][0];
	}
	if(a[cur]==0)
		lastb=cur;
	else
		lastw=cur;
	int i;
	rep(i,adj[cur].size()){
	    if(adj[cur][i]==par)
	        continue;
		dfs(adj[cur][i],cur,lastw,lastb);
	}
	return 0;
}
int getlca(int u,int v){
	int i;
	if(dep[u]>dep[v])
		swap(u,v);
	fd(i,17,0){
		if(dep[v]-(1<<i)>=dep[u]){
			v=paren[v][i];
	    }
	}
	if(u==v)
		return u;
	//cout<<paren[u][0]<<endl;
	fd(i,17,0){
		if(paren[u][i]!=paren[v][i]){
	        u=paren[u][i];
			v=paren[v][i];
		}
	}
	return paren[u][0];
}
signed main(){ // `int` is #defined to `long long` above; `signed` keeps main's return type int
	std::ios::sync_with_stdio(false); cin.tie(NULL);
	int t;
	cin>>t;
	while(t--){
	    int n;
	    cin>>n;
	    int i;
	    rep(i,n+10){
	        adj[i].clear();
	        nearfw[i]=0;
	    }
	    f(i,1,n+1){
	        cin>>a[i];
	        a[i]^=1;
	    }
	    int u,v;
	    rep(i,n-1){
	        cin>>u>>v;
	        adj[u].pb(v);
	        adj[v].pb(u);
	    }
	    dfs(1,0,0,0);
	    int j;
	    f(j,1,18){
	        f(i,1,n+1){
	            if(paren[i][j-1]==0){
	                paren[i][j]=0;
	            }
	            else
	                paren[i][j]=paren[paren[i][j-1]][j-1];
	        }
	    }
	    int lca,len,dist,z,haha0,haha1,haha2;
	    int haha10,haha11,haha20,haha21,haha12;
	    int ans,nearestu,nearestv,gao,cur,h,x;
	    int m;
	    cin>>m;
	    ans=0;
	    rep(i,m){
	        cin>>u>>v;
	        //cout<<u<<" "<<v<<endl;
	        lca=getlca(u,v);
	        if(u==v){
	            cout<<0<<endl;
	            continue;
	        }
	        len=dep[u]+dep[v]-2*dep[lca]+1;
	        // black vertices
	        z=dep[u];
	        haha0=d0[u]-d0[paren[lca][0]];
	        haha1=d1[u]-d1[paren[lca][0]];
	        haha2=d2[u]-d2[paren[lca][0]];
	        ans=haha0*((len-1)*z);
	        ans+=(1-len)*haha1;
	        ans-=haha0*z*z;
	        ans-=haha2;
	        ans+=2*z*haha1;
	        //cout<<ans<<endl;
	        
	        swap(u,v);
	        z=dep[u];
	        haha0=d0[u]-d0[lca];
	        haha1=d1[u]-d1[lca];
	        haha2=d2[u]-d2[lca];
	        ans+=haha0*((len-1)*z);
	        ans+=(1-len)*haha1;
	        ans-=haha0*z*z;
	        ans-=haha2;
	        ans+=2*z*haha1;
	        //cout<<ans<<endl;
	        swap(u,v);
	        // white vertices.
	        z=dep[u];
	        if(gg[u][0][0]-gg[lca][0][0]==0){
	            if(a[lca]==0){
	                nearestu=lca;
	            }
	            else{
	                nearestu=-1;
	            }
	        }
	        else{
	            gao=nearfw[u];
	            cur=u;
	            fd(j,17,0){
	                if(dep[cur]-(1<<j)>=dep[lca] && d0[paren[cur][j]]>d0[paren[lca][0]]){
	                    //cout<<j<<endl;
	                    cur=paren[cur][j];
	                }
	            }
	            //cout<<gao<<" "<<cur<<endl;
	            haha10 = gg[gao][1][0]-gg[cur][1][0];
	            haha11 = gg[gao][1][1]-gg[cur][1][1];
	            haha21 = gg[gao][2][1]-gg[cur][2][1];
	            haha12 = gg[gao][1][2]-gg[cur][1][2]; 
	            haha20 = gg[gao][2][0]-gg[cur][2][0]; 
	            ans+=len*haha10*z - haha11*len -  haha20*z + haha21 + len*haha10
	            - haha20 -(z+1)*(z+1)*haha10 -haha12 + 2*(z+1)*haha11;
	            nearestu = cur;
	        }
	        nearestv = nearestu;
	        //cout<<ans<<endl;
	        swap(u,v);
	        z=dep[u];
	        if(gg[u][0][0]-gg[lca][0][0]==0){
	            if(a[lca]==0){
	                nearestu=lca;
	            }
	            else{
	                nearestu=-1;
	            }
	        }
	        else{
	            gao=nearfw[u];
	            cur=u;
	            fd(j,17,0){
	                if(dep[cur]-(1<<j)>=dep[lca] && d0[paren[cur][j]]>d0[paren[lca][0]]){
	                   // cout<<j<<endl;
	                    cur=paren[cur][j];
	                }
	            }
	            //cout<<gao<<" "<<cur<<endl;
	            haha10 = gg[gao][1][0]-gg[cur][1][0];
	            haha11 = gg[gao][1][1]-gg[cur][1][1];
	            haha21 = gg[gao][2][1]-gg[cur][2][1];
	            haha12 = gg[gao][1][2]-gg[cur][1][2]; 
	            haha20 = gg[gao][2][0]-gg[cur][2][0]; 
	            ans+=len*haha10*z - haha11*len -  haha20*z + haha21 + len*haha10
	            - haha20 -(z+1)*(z+1)*haha10 -haha12 + 2*(z+1)*haha11;
	            nearestu = cur;
	        }
	        //cout<<haha10<<endl;
	        if(nearestu!=-1 && nearestv !=-1){
	            dist=dep[nearestu]+dep[nearestv]-2*dep[lca]+1;
	            dist-=2;
	            x=dep[nearestu];
	            if(dist>0){
	                ans+=dist*(z-x+1)*(len-dist-(z-x+1));
	            }
 
	        }
	        cout<<ans<<endl;
	    }
	}
 
	return 0;   
}
Editorialist's Solution
import java.util.*;
import java.io.*;
import java.text.*;
class BLWHTREE{
	//SOLUTION BEGIN
	int B = 20;
	void pre() throws Exception{}
	void check(int[] a){
	    long brute = 0;
	    for(int i = 0; i< a.length; i++)
	        for(int j = i+1; j< a.length; j++)
	            for(int k = j+1; k< a.length; k++){
	                boolean f = false, g=  false;
	                for(int l = i; l <= j; l++)f |= a[l] == 1;
	                for(int l = j; l <= k; l++)g |= a[l] == 1;
	                if(f && g)brute++;
	            }
	    long a1 = nC3(a.length), a2 = 0, a3 = 0, a4 = 0;
	    for(int i = 0; i< a.length; i++)
	        for(int j = i+1; j< a.length; j++)
	            for(int k = j+1; k< a.length; k++){
	                boolean f = true, g = true;
	                for(int l = i; l <= j; l++)f &= a[l] == 0;
	                for(int l = j; l <= k; l++)g &= a[l] == 0;
	                if(f)a2++;
	                if(g)a3++;
	                if(f && g)a4++;
	            }
//        pn(brute);
//        pn(a1+" "+a2+" "+a3+" "+a4);
//        pn(brute == (a1+a4-a2-a3));
	}
	void solve(int TC) throws Exception{
	    int n = ni();
	    int[] a = new int[n];
	    for(int i = 0; i< n; i++)a[i] = ni();
	    int[] from = new int[n-1], to = new int[n-1];
	    for(int i = 0; i< n-1; i++){from[i] = ni()-1;to[i] = ni()-1;}
	    int[][] g = makeU(n, from, to);
	    int[] d = new int[n], sub = new int[n], ch = new int[n], ti = new int[n];
	    int[][] par = new int[B][n];
	    for(int i = 0; i< B; i++)Arrays.fill(par[i], -1);
	    pre(g, sub, d, par, 0, -1);
	    time = -1;
	    dfs(g, ti, sub, ch, 0, -1);
	    int[] b = new int[n];
	    for(int i = 0; i< n; i++)b[ti[i]] = a[i];
	    SegTree t = new SegTree(b);
	    for(int qq = ni(); qq>0; qq--){
	        int u = ni()-1, v = ni()-1;
	        int l = lca(par, d, u, v);
	        int length = d[u]+d[v]-2*d[l]+1;
	        if(l == u || l == v){
	            if(l == u){
	                Node ans = queryUp(t, par, ti, ch, d, v, d[l]);
	                pn(ans.eval(length));
	            }else if(l == v){
	                Node ans = queryUp(t, par, ti, ch, d, u, d[l]);
	                ans.flip();
	                pn(ans.eval(length));
	            }
	        }else{
	            Node le = queryUp(t, par, ti, ch, d, u, d[l]), ri = queryUp(t, par, ti, ch, d, v, d[l]+1);
	            le.flip();
	            Node ans = merge(le, ri);
	            pn(ans.eval(length));
	        }
	    }
	}
	Node queryUp(SegTree t, int[][] par, int[] ti, int[] ch,int[] d, int u, int dep){
	    Node ans = null;
	    while(dep < d[ch[u]]){
	        ans = merge(t.query(ti[ch[u]], ti[u]), ans);
	        u = par[0][ch[u]];
	    }
	    ans = merge(t.query(ti[lift(par, u, d[u]-dep)], ti[u]), ans);
	    return ans;
	}
	int lift(int[][] par, int u, int di){
	    for(int b = B-1; b>= 0; b--)
	        if(((di>>b)&1)==1)
	            u = par[b][u];
	    return u;
	}
	int lca(int[][] par, int[] d, int u, int v){
	    if(d[v] > d[u])v = lift(par, v, d[v]-d[u]);
	    if(d[u] > d[v])u = lift(par, u, d[u]-d[v]);
	    if(u == v)return u;
	    for(int b = B-1; b>= 0; b--)
	        if(par[b][u] != par[b][v]){
	            u = par[b][u];
	            v = par[b][v];
	        }
	    return par[0][u];
	}
	class Node{
	    boolean all;
	    int le, ri;
	    long xC2, xxC2, xC3;
	    public Node(int ai){
	        if(ai == 0){
	            all = true;
	            le = 1;ri = 0;
	        }else{
	            all = false;
	            le = 0;ri = 0;
	        }
	    }
	    public Node(boolean all, int le, int ri, long xC2, long xxC2, long xC3){
	        this.all = all;
	        this.le = le;
	        this.ri = ri;
	        this.xC2 = xC2;
	        this.xxC2 = xxC2;
	        this.xC3 = xC3;
	    }
	    public Node copy(){return new Node(all, le, ri, xC2, xxC2, xC3);}
	    void flip(){if(all)return;int tmp = le;le = ri;ri = tmp;}
	    long eval(long length){
	        if(all)//All zero
	            return 0;
	        long ans = nC3(length)+xxC2-length*xC2-xC3;
	        ans += le*nC2(le);
	        ans -= length*nC2(le);
	        ans -= nC3(le);
	        ans += ri*nC2(ri);
	        ans -= length*nC2(ri);
	        ans -= nC3(ri);
	        return ans;
	    }
	    void print(){
	        pn(all+" "+le+" "+ri+" "+xC2+" "+xxC2+" "+xC3);
	    }
	}
	Node merge(Node le, Node ri){
	    if(le == null && ri == null)return null;
	    if(le == null)return ri.copy();
	    if(ri == null)return le.copy();
	    Node ans = null;
	    if(le.all){
	        if(ri.all)ans = new Node(true, le.le+ri.le, 0, 0, 0, 0);
	        else ans = new Node(false, le.le+ri.le, ri.ri, ri.xC2, ri.xxC2, ri.xC3);
	    }else{
	        if(ri.all)ans = new Node(false, le.le, le.ri+ri.le, le.xC2, le.xxC2, le.xC3);
	        else ans = new Node(false, le.le, ri.ri, le.xC2+ri.xC2+nC2(le.ri+ri.le), le.xxC2+ri.xxC2+(le.ri+ri.le)*nC2(le.ri+ri.le), le.xC3+ri.xC3+nC3(le.ri+ri.le));
	    }
	    return ans;
	}
	class SegTree{
	    int m = 1;
	    Node[] t;
	    public SegTree(int[] a){
	        while(m<a.length)m<<=1;
	        t = new Node[m<<1];
	        for(int i = 0; i< a.length; i++)t[i+m] = new Node(a[i]);
	        for(int i = m-1; i> 0; i--)t[i] = merge(t[i<<1], t[i<<1|1]);
	    }
	    Node query(int l, int r){
	        Node le = null, ri = null;
	        for(l+=m, r+=m+1; l< r; l>>=1, r>>=1){
	            if((l&1)==1)le = merge(le, t[l++]);
	            if((r&1)==1)ri = merge(t[--r], ri);
	        }
	        return merge(le, ri);
	    }
	}
	long nC2(long x){
	    return (x*x-x)/2;
	}
	long nC3(long x){
	    return (x*(x-1)*(x-2))/6;
	}
	int time;
	void dfs(int[][] g, int[] ti, int[] sub, int[] ch, int u, int p){
	    ti[u] = ++time;
	    int hc = -1;
	    for(int i = 0; i< g[u].length; i++){
	        if(g[u][i] == p)continue;
	        if(hc == -1 || sub[g[u][i]] > sub[g[u][hc]])hc = i;
	    }
	    if(hc != -1){
	        int tmp = g[u][hc];
	        g[u][hc] = g[u][0];
	        g[u][0] = tmp;
	    }
	    for(int v:g[u]){
	        if(v == p)continue;
	        if(v == g[u][0])ch[v] = ch[u];
	        else ch[v] = v;
	        dfs(g, ti, sub, ch, v, u);
	    }
	}
	void pre(int[][] g, int[] sub, int[] d, int[][] par, int u, int p){
	    for(int b = 1; b< B; b++)
	        if(par[b-1][u] != -1)
	            par[b][u] = par[b-1][par[b-1][u]];
	    for(int v:g[u]){
	        if(v == p)continue;
	        d[v] = d[u]+1;
	        par[0][v] = u;
	        pre(g, sub, d, par, v, u);
	        sub[u]+=sub[v];
	    }
	    sub[u]++;
	}
	int[][] makeU(int n, int[] from, int[] to){
	    int[][] g = new int[n][];int[] cnt = new int[n];
	    for(int i:from)cnt[i]++;for(int i:to)cnt[i]++;
	    for(int i = 0; i< n; i++)g[i] = new int[cnt[i]];
	    for(int i = 0; i< n-1; i++){
	        g[from[i]][--cnt[from[i]]] = to[i];
	        g[to[i]][--cnt[to[i]]] = from[i];
	    }
	    return g;
	}
	//SOLUTION END
	void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
	DecimalFormat df = new DecimalFormat("0.00000000000");
	static boolean multipleTC = true;
	FastReader in;PrintWriter out;
	void run() throws Exception{
	    in = new FastReader();
	    out = new PrintWriter(System.out);
	    //Solution Credits: Taranpreet Singh
	    int T = (multipleTC)?ni():1;
	    pre();for(int t = 1; t<= T; t++)solve(t);
	    out.flush();
	    out.close();
	}
	public static void main(String[] args) throws Exception{
	    new BLWHTREE().run();
	}
	int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
	void p(Object o){out.print(o);}
	void pn(Object o){out.println(o);}
	void pni(Object o){out.println(o);out.flush();}
	String n()throws Exception{return in.next();}
	String nln()throws Exception{return in.nextLine();}
	int ni()throws Exception{return Integer.parseInt(in.next());}
	long nl()throws Exception{return Long.parseLong(in.next());}
	double nd()throws Exception{return Double.parseDouble(in.next());}

	class FastReader{
	    BufferedReader br;
	    StringTokenizer st;
	    public FastReader(){
	        br = new BufferedReader(new InputStreamReader(System.in));
	    }

	    public FastReader(String s) throws Exception{
	        br = new BufferedReader(new FileReader(s));
	    }

	    String next() throws Exception{
	        while (st == null || !st.hasMoreElements()){
	            try{
	                st = new StringTokenizer(br.readLine());
	            }catch (IOException  e){
	                throw new Exception(e.toString());
	            }
	        }
	        return st.nextToken();
	    }

	    String nextLine() throws Exception{
	        String str = "";
	        try{   
	            str = br.readLine();
	        }catch (IOException e){
	            throw new Exception(e.toString());
	        }  
	        return str;
	    }
	}
}

Feel free to share your approach. Suggestions are welcome, as always.