YATP - Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Practice

Setter: Aryan Choudhary
Tester: Nishant Shah
Editorialist: Taranpreet Singh

DIFFICULTY

Medium

PREREQUISITES

Auxiliary tree, Disjoint set union.

PROBLEM

Nayra doesn’t like stories of people receiving random trees as birthday presents, but this time she received a tree as a present for her own birthday! After struggling for a day trying to figure out what to do with this tree, she asked Aryan for help. He gave her this problem.

You are given a tree with N vertices and weighted edges. It is rooted at vertex 1.

Define g(u,v) to be the maximum weight of an edge on the shortest path between vertex u and vertex v in this tree.

Just calculating g(u,v) is too easy, so you have to process Q queries.

In each query, you are given K distinct vertices v_1, v_2, \ldots, v_K. You have to compute the sum of the g-values over all pairs of vertices among these K, i.e., the value

\sum_{i=1}^{K-1} \sum_{j=i+1}^K g(v_i, v_j)

Note: The input and output are large, so use fast input-output methods.
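As a reference point, the definition can be checked directly on small inputs. The sketch below is a brute-force checker, not the intended solution. It assumes 0-indexed vertices and an adjacency-list representation; the names `Adj`, `addEdge`, `maxEdgeFrom` and `bruteForceQuery` are made up for illustration. It runs in O(K*N) per query, so it is only useful for validating a faster solution.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// adj[u] holds (neighbour, edge weight) pairs; vertices are 0-indexed (assumption).
using Adj = std::vector<std::vector<std::pair<int, long long>>>;

// Hypothetical helper: adds an undirected weighted edge.
void addEdge(Adj& adj, int u, int v, long long w) {
    adj[u].push_back({v, w});
    adj[v].push_back({u, w});
}

// Fills best[v] with the maximum edge weight on the path from src to v.
void maxEdgeFrom(const Adj& adj, int src, std::vector<long long>& best) {
    best.assign(adj.size(), 0);
    std::vector<std::pair<int, int>> stack;  // (vertex, parent)
    stack.push_back({src, -1});
    while (!stack.empty()) {
        const int u = stack.back().first;
        const int p = stack.back().second;
        stack.pop_back();
        for (const auto& edge : adj[u]) {
            if (edge.first == p) continue;
            best[edge.first] = std::max(best[u], edge.second);
            stack.push_back({edge.first, u});
        }
    }
}

// Sums g(v_i, v_j) over all pairs i < j of the query vertices.
long long bruteForceQuery(const Adj& adj, const std::vector<int>& query) {
    long long total = 0;
    std::vector<long long> best;
    for (std::size_t i = 0; i < query.size(); ++i) {
        maxEdgeFrom(adj, query[i], best);
        for (std::size_t j = i + 1; j < query.size(); ++j) total += best[query[j]];
    }
    return total;
}
```

For example, on the tree with edges 0-1 (weight 3), 1-2 (weight 1) and 1-3 (weight 5), the query {0, 2, 3} gives g(0,2) + g(0,3) + g(2,3) = 3 + 5 + 5 = 13.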

QUICK EXPLANATION

  • The auxiliary tree trick reduces each query to a single-query problem on a tree whose size is proportional to K.
  • To solve the single-query problem, sort the edges in non-decreasing order of weight and maintain the connected components. When an edge joins two components, it is the maximum-weight edge on the path between every pair of special vertices that has one vertex in the first component and the other in the second.
  • Maintaining the number of special (marked) vertices in each component gives the number of pairs for which a specific edge is the maximum-weight edge on the path. Summing weight times pair count over all edges gives the required answer.

EXPLANATION

Let’s try to solve a simpler problem first.

Simpler problem

Given a tree with N vertices and a list S of K distinct vertices, compute \displaystyle\sum_{i=1}^{K-1} \sum_{j=i+1}^K g(v_i, v_j), where v_i denotes the i-th vertex in S.

Let’s assume all edges have distinct weights for now, so that every path has a unique edge with maximum weight. If, for each edge e, we can count the pairs of special vertices whose path has e as its largest-weight edge, the answer is the sum over all edges of the weight times that count. For an edge (u_i, v_i, w_i) to be the largest on a path, the path must contain it; edges with weight up to w_i may appear on the path freely, but no edge with weight greater than w_i may be present.

So, let’s start with no edges, where every node forms its own component. We also maintain the number of marked nodes in each component. We process the edges in non-decreasing order of weight and add them one by one. If, on adding edge e with weight w_e, the nodes x and y become reachable from each other for the first time, then g(x, y) = w_e. Hence, on adding edge e, we need to count the new pairs of marked nodes that became reachable after adding this edge.

Let’s say two components S_1 and S_2 are joined by edge e. The pairs of special vertices with this edge as the edge with the largest weight must have one endpoint in S_1 and one endpoint in S_2. If C_i denotes the number of special vertices in S_i, then the number of such pairs is C_1 * C_2, contributing a total weight of w_e * C_1 * C_2. Summing this over all edges gives the required answer. The distinct-weights assumption is not actually needed: with ties, each pair still becomes connected exactly once, by an edge whose weight equals the maximum on its path.

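In order to merge these two components, the disjoint set union can be used. With both path compression and union by size (or rank), the DSU operations take O(N*\alpha(N)) in total; with only one of the two, they take O(N*log(N)). Sorting the edges adds O(N*log(N)) unless the weights can be sorted in linear time.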
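The procedure above can be sketched as follows. This is a minimal illustration under stated assumptions, not one of the reference solutions: vertices are 0-indexed, and the names `Edge` and `sumOfPairwiseMax` are hypothetical. It uses path halving together with union by size.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <utility>
#include <vector>

// Hypothetical edge record: endpoints u, v (0-indexed) and weight w.
struct Edge { int u, v; long long w; };

// Returns the sum of g(x, y) over all unordered pairs of special vertices.
long long sumOfPairwiseMax(int n, std::vector<Edge> edges,
                           const std::vector<int>& special) {
    std::vector<int> parent(n), size(n, 1);
    std::vector<long long> marked(n, 0);  // special vertices per component root
    std::iota(parent.begin(), parent.end(), 0);
    for (int v : special) marked[v] = 1;

    // Find the component root, shortening the path as we go (path halving).
    auto find = [&](int v) -> int {
        while (parent[v] != v) {
            parent[v] = parent[parent[v]];
            v = parent[v];
        }
        return v;
    };

    // Process edges in non-decreasing order of weight.
    std::sort(edges.begin(), edges.end(),
              [](const Edge& a, const Edge& b) { return a.w < b.w; });

    long long answer = 0;
    for (const Edge& e : edges) {
        int a = find(e.u), b = find(e.v);
        if (a == b) continue;  // never happens when the input is a tree
        // Every (special in a, special in b) pair has e as its maximum edge.
        answer += e.w * marked[a] * marked[b];
        if (size[a] < size[b]) std::swap(a, b);  // union by size
        parent[b] = a;
        size[a] += size[b];
        marked[a] += marked[b];
    }
    return answer;
}
```

On the tree with edges 0-1 (weight 3), 1-2 (weight 1) and 1-3 (weight 5) and special vertices {0, 2, 3}, the edges contribute 1*0*1, then 3*1*1, then 5*2*1, for a total of 13.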

One good practice problem for disjoint set union is here

Original problem

Now we know how to solve a single query in time nearly linear in N. But this doesn’t scale well to Q queries: we need to answer each query in time proportional to K, up to logarithmic factors. There is a neat trick, called the auxiliary tree (also known as the virtual tree), which can be used directly to reduce the original problem to the simpler problem.

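The core idea of the trick is that, for a subset of vertices S, only the vertices of S and the vertices where the paths between them branch (the LCAs of vertices of S) matter. Every other vertex on those paths lies on a chain between two such vertices. So, if we have a long chain of edges between two kept nodes, we can replace it with a single edge whose weight is the maximum of the weights on the chain.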
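One common way to get the maximum weight on such a chain is binary lifting, which both the tester's and the editorialist's solutions use in some form. The sketch below is an independent illustration: the struct name `MaxLifting` and its members are hypothetical, and it assumes a 0-indexed parent array with parent[root] == root and weight[v] being the weight of the edge from v to its parent.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Binary lifting that also tracks the maximum edge weight on each jump.
struct MaxLifting {
    int n, LOG;
    std::vector<int> depth;
    std::vector<std::vector<int>> up;        // up[k][v]: 2^k-th ancestor of v
    std::vector<std::vector<long long>> mx;  // mx[k][v]: max weight on that jump

    MaxLifting(const std::vector<int>& parent,
               const std::vector<long long>& weight, int root)
        : n((int)parent.size()), LOG(1), depth(parent.size(), 0) {
        while ((1 << LOG) < n) ++LOG;
        up.assign(LOG, parent);
        mx.assign(LOG, weight);
        mx[0][root] = 0;  // the root has no parent edge
        // Compute depths with a BFS over child lists.
        std::vector<std::vector<int>> children(n);
        for (int v = 0; v < n; ++v)
            if (v != root) children[parent[v]].push_back(v);
        std::vector<int> order(1, root);
        for (std::size_t i = 0; i < order.size(); ++i) {
            const int u = order[i];
            for (int c : children[u]) { depth[c] = depth[u] + 1; order.push_back(c); }
        }
        for (int k = 1; k < LOG; ++k)
            for (int v = 0; v < n; ++v) {
                const int mid = up[k - 1][v];
                up[k][v] = up[k - 1][mid];
                mx[k][v] = std::max(mx[k - 1][v], mx[k - 1][mid]);
            }
    }

    // Maximum edge weight on the chain from v up to its ancestor anc.
    long long maxToAncestor(int v, int anc) const {
        long long best = 0;
        const int d = depth[v] - depth[anc];
        for (int k = LOG - 1; k >= 0; --k)
            if ((d >> k) & 1) { best = std::max(best, mx[k][v]); v = up[k][v]; }
        return best;
    }

    int lca(int u, int v) const {
        if (depth[u] < depth[v]) std::swap(u, v);
        const int d = depth[u] - depth[v];
        for (int k = LOG - 1; k >= 0; --k)
            if ((d >> k) & 1) u = up[k][u];
        if (u == v) return u;
        for (int k = LOG - 1; k >= 0; --k)
            if (up[k][u] != up[k][v]) { u = up[k][u]; v = up[k][v]; }
        return up[0][u];
    }
};
```

Each compressed edge of the auxiliary tree connects a vertex to one of its ancestors, so `maxToAncestor` is enough to weight it in O(log N).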

This way, for a subset with K vertices, we can build an auxiliary tree with at most 2*K vertices. The idea is explained in detail here and here

The above trick allows us to build an auxiliary tree in O(K*log(K)) after preprocessing the original tree for LCA queries. Applying the solution to the simpler problem on it answers each query in O(K*log(N)) time, i.e. proportional to K up to a logarithmic factor.
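A minimal sketch of one standard construction follows: sort the query vertices by DFS entry time, add the LCA of each adjacent pair, then link every vertex to its nearest kept ancestor using a stack. The function name `buildAuxTree` and its parameters are assumptions for illustration; `tin` and `tout` are entry and exit times from a DFS of the original tree, and `lca` is any LCA oracle supplied by the caller.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <utility>
#include <vector>

// Returns the (parent, child) edges of the auxiliary tree spanning `keys`.
std::vector<std::pair<int, int>> buildAuxTree(
    std::vector<int> keys,
    const std::vector<int>& tin, const std::vector<int>& tout,
    const std::function<int(int, int)>& lca) {
    auto byTin = [&](int a, int b) { return tin[a] < tin[b]; };
    std::sort(keys.begin(), keys.end(), byTin);

    // LCAs of adjacent vertices in DFS order are the only extra vertices needed.
    const std::size_t k = keys.size();
    for (std::size_t i = 0; i + 1 < k; ++i)
        keys.push_back(lca(keys[i], keys[i + 1]));
    std::sort(keys.begin(), keys.end(), byTin);
    keys.erase(std::unique(keys.begin(), keys.end()), keys.end());

    // a is an ancestor of b (or equal to b) iff b's interval nests inside a's.
    auto isAncestor = [&](int a, int b) {
        return tin[a] <= tin[b] && tout[b] <= tout[a];
    };

    std::vector<std::pair<int, int>> edges;
    std::vector<int> stack;  // current root-to-vertex chain of kept vertices
    for (int v : keys) {
        while (!stack.empty() && !isAncestor(stack.back(), v)) stack.pop_back();
        if (!stack.empty()) edges.push_back({stack.back(), v});
        stack.push_back(v);
    }
    return edges;
}
```

Each returned edge joins a kept vertex to its nearest kept ancestor, so its weight for this problem would be the maximum edge weight on that vertical chain in the original tree. The result has at most 2*K-1 vertices and can be passed to the DSU procedure from the simpler problem after relabelling.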

Practice problems for this trick are

TIME COMPLEXITY

The overall time complexity is O((N+S)*log(N)) per test case, where S is the sum of K over all queries, with a large constant.

SOLUTIONS

Setter's Solution
/*
  Compete against Yourself.
  Author - Aryan (@aryanc403)
*/
/*
  Credits -
  Atcoder library - https://atcoder.github.io/ac-library/production/document_en/ (namespace atcoder)
  Github source code of library - https://github.com/atcoder/ac-library/tree/master/atcoder
*/
 
#include <algorithm>
#include <cassert>
#include <vector>
 
namespace atcoder {
 
struct dsu {
  public:
    dsu() : _n(0) {}
    explicit dsu(int n) : _n(n), parent_or_size(n, -1) {}
 
    int merge(int a, int b) {
        assert(0 <= a && a < _n);
        assert(0 <= b && b < _n);
        int x = leader(a), y = leader(b);
        if (x == y) return x;
        // if (-parent_or_size[x] < -parent_or_size[y]) std::swap(x, y);
        parent_or_size[x] += parent_or_size[y];
        parent_or_size[y] = x;
        return x;
    }
 
    bool same(int a, int b) {
        assert(0 <= a && a < _n);
        assert(0 <= b && b < _n);
        return leader(a) == leader(b);
    }
 
    int leader(int a) {
        assert(0 <= a && a < _n);
        if (parent_or_size[a] < 0) return a;
        return parent_or_size[a] = leader(parent_or_size[a]);
    }
 
    int size(int a) {
        assert(0 <= a && a < _n);
        return -parent_or_size[leader(a)];
    }
 
    std::vector<std::vector<int>> groups() {
        std::vector<int> leader_buf(_n), group_size(_n);
        for (int i = 0; i < _n; i++) {
            leader_buf[i] = leader(i);
            group_size[leader_buf[i]]++;
        }
        std::vector<std::vector<int>> result(_n);
        for (int i = 0; i < _n; i++) {
            result[i].reserve(group_size[i]);
        }
        for (int i = 0; i < _n; i++) {
            result[leader_buf[i]].push_back(i);
        }
        result.erase(
            std::remove_if(result.begin(), result.end(),
                           [&](const std::vector<int>& v) { return v.empty(); }),
            result.end());
        return result;
    }
 
  private:
    int _n;
    std::vector<int> parent_or_size;
};
 
}  // namespace atcoder
 
#ifdef ARYANC403
    #include <header.h>
#else
    #pragma GCC optimize ("Ofast")
    #pragma GCC target ("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx")
    #pragma GCC optimize ("-ffloat-store")
    #include<bits/stdc++.h>
    #include <ext/pb_ds/assoc_container.hpp>
    #include <ext/pb_ds/tree_policy.hpp>
    #define dbg(args...) 42;
    #define endl "\n"
#endif
 
// y_combinator from @neal template https://codeforces.com/contest/1553/submission/123849801
// http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2016/p0200r0.html
template<class Fun> class y_combinator_result {
    Fun fun_;
public:
    template<class T> explicit y_combinator_result(T &&fun): fun_(std::forward<T>(fun)) {}
    template<class ...Args> decltype(auto) operator()(Args &&...args) { return fun_(std::ref(*this), std::forward<Args>(args)...); }
};
template<class Fun> decltype(auto) y_combinator(Fun &&fun) { return y_combinator_result<std::decay_t<Fun>>(std::forward<Fun>(fun)); }
 
using namespace std;
#define fo(i,n)   for(i=0;i<(n);++i)
#define repA(i,j,n)   for(i=(j);i<=(n);++i)
#define repD(i,j,n)   for(i=(j);i>=(n);--i)
#define all(x) begin(x), end(x)
#define sz(x) ((lli)(x).size())
#define pb push_back
#define mp make_pair
#define X first
#define Y second
 
typedef long long int lli;
typedef long double mytype;
typedef pair<lli,lli> ii;
typedef vector<ii> vii;
typedef vector<lli> vi;
template <class T>
using ordered_set =  __gnu_pbds::tree<T,__gnu_pbds::null_type,less<T>,__gnu_pbds::rb_tree_tag,__gnu_pbds::tree_order_statistics_node_update>;
// X.find_by_order(k) return kth element. 0 indexed.
// X.order_of_key(k) returns count of elements strictly less than k.
 
const auto start_time = std::chrono::high_resolution_clock::now();
void aryanc403()
{
#ifdef ARYANC403
auto end_time = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = end_time-start_time;
    cerr<<"Time Taken : "<<diff.count()<<"\n";
#endif
}
 
const lli INF = 0xFFFFFFFFFFFFFFFLL;
const lli SEED=chrono::steady_clock::now().time_since_epoch().count();
mt19937 rng(SEED);
inline lli rnd(lli l=0,lli r=INF)
{return uniform_int_distribution<lli>(l,r)(rng);}
 
class CMP
{public:
bool operator()(ii a , ii b) //For min priority_queue .
{    return ! ( a.X < b.X || ( a.X==b.X && a.Y <= b.Y ));   }};
 
void add( map<lli,lli> &m, lli x,lli cnt=1)
{
    auto jt=m.find(x);
    if(jt==m.end())         m.insert({x,cnt});
    else                    jt->Y+=cnt;
}
 
void del( map<lli,lli> &m, lli x,lli cnt=1)
{
    auto jt=m.find(x);
    if(jt->Y<=cnt)            m.erase(jt);
    else                      jt->Y-=cnt;
}
 
bool cmp(const ii &a,const ii &b)
{
    return a.X<b.X||(a.X==b.X&&a.Y<b.Y);
}
 
const lli mod = 1000000007L;
// const lli maxN = 1000000007L;
 
#define rep(i, a, b) for(int i = a; i < (b); ++i)
#define trav(a, x) for(auto& a : x)
typedef long long ll;
typedef pair<int, int> pii;
 
typedef vector<pii> vpi;
typedef vector<vpi> graph;
 
template<class T>
struct RMQ {
    vector<vector<T>> jmp;
    RMQ(const vector<T>& V) {
        int N = sz(V), on = 1, depth = 1;
        while (on < N) on *= 2, depth++;
        jmp.assign(depth, V);
        rep(i,0,depth-1) rep(j,0,N)
            jmp[i+1][j] = min(jmp[i][j],
            jmp[i][min(N - 1, j + (1 << i))]);
    }
    T query(int a, int b) {
        assert(a < b); // or return inf if a == b
        int dep = 31 - __builtin_clz(b - a);
        return min(jmp[dep][a], jmp[dep][b - (1 << dep)]);
    }
};
 
struct LCA {
    lli root;
    vi time;
    RMQ<pii> rmq;
    LCA(graph& C,const int rt) : root(rt),time(sz(C), -99), rmq(dfs(C)) {}
 
    vpi dfs(graph& C) {
        vector<tuple<int, int, int>> q;
        q.emplace_back(root,root,0);
        vpi ret;
        int T = 0, v, p, d;
        while (!q.empty()) {
            tie(v, p, d) = q.back();
            q.pop_back();
            if (d) ret.emplace_back(d, p);
            time[v] = T++;
            trav(e, C[v]) if (e.first != p)
                q.emplace_back(e.first, v, d+1);
        }
        return ret;
    }
 
    int query(int a, int b) {
        if (a == b) return a;
        a = time[a], b = time[b];
        return rmq.query(min(a, b), max(a, b)).second;
    }
};
 
void compressTree(vpi &ret,LCA& lca,vi& li) {
    // static vi rev; rev.resize(sz(lca.time));
    vi &T = lca.time;
    auto cmp = [&](int a, int b) { return T[a] < T[b]; };
    sort(all(li), cmp);
    int m = sz(li)-1;
    rep(i,0,m) {
        int a = li[i], b = li[i+1];
        li.push_back(lca.query(a, b));
    }
    sort(all(li), cmp);
    li.erase(unique(all(li)), li.end());
    rep(i,0,sz(li)-1) {
        int a = li[i], b = li[i+1];
        ret.emplace_back(lca.query(a, b), b);
    }
}
 
const lli maxN = 1e6;
lli rev[maxN];
 
//cities will be changed to new ids.
void init(LCA& lca,vi &cities,const vi &weights,vector<vi> &edges)
{
    auto subset=cities;
    vpi ctree;
    compressTree(ctree,lca,subset);
    int itr=0;
    for(auto x:subset)
        rev[x]=itr++;
    dbg(ctree,subset);
    edges.clear();
    for(auto x:ctree)
    {
        const lli u=rev[x.X];
        const lli v=rev[x.Y];
        const lli d=weights[lca.query(x.X,x.Y)];
        edges.pb({d,u,v});
    }
 
    for(auto &x:cities)
        x=rev[x];
}
 
 
int main(void) {
    ios_base::sync_with_stdio(false);cin.tie(NULL);
    // freopen("txt.in", "r", stdin);
    // freopen("txt.out", "w", stdout);
// cout<<std::fixed<<std::setprecision(35);
const auto solve=[](vector<vi> &edges,const vi &a)->lli{
    dbg(edges);
    const int n=sz(edges)+1;
    sort(all(edges));
    vi size(n);
    for(auto x:a)
        size[x]=1;
    atcoder::dsu d(n);
    lli ans=0;
    for(const auto &z:edges){
        const lli w=z[0],u=d.leader(z[1]),v=d.leader(z[2]);
        if(u==v)
            continue;
        ans+=w*size[u]*size[v];
        size[d.merge(u,v)] = size[u]+size[v];
    }
    return ans;
};
lli T;
cin>>T;while(T--)
{
    lli n,q;
    cin>>n>>q;
    dbg(T,n,q);
    vector<vi> edges(n-1,vi(3));
    for(int i=1;i<n;++i)
        edges[i-1][2]=i;
    for(auto &v:edges)
        cin>>v[1],v[1]--;
    for(auto &v:edges)
        cin>>v[0];
    sort(all(edges));
    const int N = 2*n-1;
    graph e(N);
    lli p=n;
    vi weights(N);
    atcoder::dsu d(N); // modified version
    for(const auto &z:edges){
        const lli w=z[0],u=d.leader(z[1]),v=d.leader(z[2]);
        assert(u!=v);
        weights[p]=w;
        e[p].pb({u,1});
        e[p].pb({v,1});
        e[v].pb({p,1});
        e[u].pb({p,1});
        d.merge(p,u);
        d.merge(p,v);
        assert(p==d.leader(p));
        p++;
    }
    assert(d.size(0)==N);
    for(auto &v:e)
        sort(all(v));
    LCA lca(e,N-1);
    vector<vi> queryEdges;
    while(q--){
        lli k;
        cin>>k;
        vi a(k);
        for(auto &x:a)
            cin>>x,x--;
        init(lca,a,weights,queryEdges);
        const lli ans = solve(queryEdges,a);
        cout<<ans<<" \n"[q==0];
    }
 
}   aryanc403();
    return 0;
}
Tester's Solution
#include <bits/stdc++.h>
using namespace std;
 
const int logN = 20;
const int maxN = 1000005;
 
vector<int> graph[maxN];
int timer, parent[maxN], weight[maxN], ancestor[maxN][logN], maximum[maxN][logN];
int depth[maxN], DsuParent[maxN], DsuSpecialCnt[maxN], outTime[maxN];
 
int find(int v){
    return (DsuParent[v] < 0)?v:(DsuParent[v] = find(DsuParent[v]));
}
 
long long combine(int u, int v, int w){
    u = find(u);
    v = find(v);
    assert(u != v);
    if(DsuParent[u] > DsuParent[v])swap(u, v);
 
    long long ret = 1LL*DsuSpecialCnt[u]*DsuSpecialCnt[v]*w;
 
    DsuSpecialCnt[u] += DsuSpecialCnt[v];
    DsuParent[u] += DsuParent[v]; DsuParent[v] = u;
    
    return ret;
}
 
inline int findLca(int u, int v){
    if(depth[u] < depth[v])swap(u, v);
    int depthDiff = depth[u] - depth[v];
    for(int i = logN - 1; i >= 0; i--){
	    if((depthDiff >> i) & 1){
		    u = ancestor[u][i];
	    }
    }
    if(u == v)return u;
    for(int i = logN - 1; i >= 0; i--){
	    if(ancestor[u][i] != ancestor[v][i]){
		    u = ancestor[u][i];
		    v = ancestor[v][i];
	    }
    }
    return ancestor[u][0];
}
 
inline int findAncestorPathWeight(int u, int v){
    int ret = 0;
    int depthDiff = depth[u] - depth[v];
    for(int i = logN - 1; i >= 0; i--){
	    if((depthDiff >> i) & 1){
		    ret = max(ret, maximum[u][i]);
		    u = ancestor[u][i];
	    }
    }
    return ret;
}
 
void dfs(int v){
    for(int u : graph[v]){
	    depth[u] = depth[v] + 1;
	    dfs(u);
    }
    outTime[v] = timer++;
}
 
void solve(){
    int N, Q;
    cin >> N >> Q;
    for(int i = 1; i <= N; i++){
	    graph[i].clear();
	    DsuParent[i] = -1;
	    DsuSpecialCnt[i] = 0;
    }
    for(int i = 2; i <= N; i++){
	    cin >> parent[i];
	    graph[parent[i]].emplace_back(i);
    }
    for(int i = 2; i <= N; i++){
	    cin >> weight[i];
    }
    depth[1] = 0;
    parent[1] = 1;
    weight[1] = 0;
    for(int i = 1; i <= N; i++){
	    maximum[i][0] = weight[i];
	    ancestor[i][0] = parent[i];
    }
    for(int k = 1; k < logN; k++){
	    for(int i = 1; i <= N; i++){
		    ancestor[i][k] = ancestor[ancestor[i][k - 1]][k - 1];
		    maximum[i][k] = max(maximum[i][k - 1], maximum[ancestor[i][k - 1]][k - 1]);
	    }
    }
    
    timer = 1; dfs(1);
 
    while(Q--){
	    int K;
	    cin >> K;
	    vector<int> vert(K);
	    for(int i = 0; i < K; i++){
		    cin >> vert[i];
	    }
	    set<pair<int, int> > vertSet;
	    for(int i = 0; i < K; i++){
		    vertSet.insert({outTime[vert[i]], vert[i]});
	    }
	    vector<array<int, 3> > compressedEdges;
	    while((int)vertSet.size() > 1){
		    int u = vertSet.begin()->second; vertSet.erase({outTime[u], u});
		    int v = vertSet.begin()->second;
		    int lca = findLca(u, v);
		    
		    compressedEdges.push_back({findAncestorPathWeight(u, lca), u, lca});
 
		    vertSet.insert({outTime[lca], lca});
	    }
	    for(int i = 0; i < K; i++){
		    DsuSpecialCnt[vert[i]] = 1;
	    }
	    long long answer = 0;
	    sort(compressedEdges.begin(), compressedEdges.end());
	    
	    for(auto &edge : compressedEdges){
		    answer += combine(edge[1], edge[2], edge[0]);
	    }
	    cout << answer << endl;
	    for(auto &edge : compressedEdges){
		    DsuParent[edge[1]] = -1;
		    DsuParent[edge[2]] = -1;
		    DsuSpecialCnt[edge[1]] = 0;
		    DsuSpecialCnt[edge[2]] = 0;
	    }
    }
}
 
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
 
    int T;
    cin >> T;
    while(T--)solve();
 
    return 0;
}
Editorialist's Solution
import java.util.*;
import java.io.*;
class YATP{
    //SOLUTION BEGIN
    static int B = 17;
    void pre() throws Exception{}
    void solve(int TC) throws Exception{
        int N = ni(), Q = ni();
        int[] from = new int[N-1], to = new int[N-1], w = new int[N-1];
        for(int i = 0; i< N-1; i++){
            from[i] = ni()-1;
            to[i] = i+1;
        }
        for(int i = 0; i< N-1; i++)w[i] = ni();
        int[][][] weightedTree = makeS(N, N-1, from, to, true);
        par = new int[B][N]; maxW = new int[B][N];
        dep = new int[N];
        for(int b = 0; b< B; b++){
            Arrays.fill(par[b], -1);
            Arrays.fill(maxW[b], Integer.MIN_VALUE);
        }
        dfs(weightedTree, w, 0, -1);
        
        VirtualTree vt = new VirtualTree(make(N, N-1, from, to, true));
        for(int q = 0; q< Q; q++){
            int K = ni();
            Integer[] V = new Integer[K];
            for(int i = 0; i< K; i++)V[i] = ni()-1;
            List<Object> o = vt.buildAuxTree(V);
            boolean[] original = (boolean[])o.get(0);
            int[][] edges = (int[][])o.get(1);
            p(solve(edges, original)+" ");
        }
        pn("");
    }
    //Given edge list of virtual tree, solve the problem
    long solve(int[][] e, boolean[] original){
        int N = original.length;
        Arrays.sort(e, (int[] i1, int[] i2) -> Integer.compare(i1[2], i2[2]));
        int[] set = java.util.stream.IntStream.range(0, N).toArray();
        int[] sz = new int[N];
        for(int i = 0; i< N; i++)if(original[i])sz[i]++;
        
        long ans = 0;
        for(int[] ee:e){
            int u = find(set, ee[0]), v = find(set, ee[1]);
            ans += ee[2] * (long)sz[u] * (long)sz[v];
            sz[u] += sz[v];
            set[v] = u;
        }
        return ans;
    }
    int find(int[] set, int u){return set[u] = (set[u] == u?u:find(set, set[u]));}
    
    static int[][] par, maxW;
    static int[] dep;
    void dfs(int[][][] tree, int[] w, int u, int p){
        for(int b = 1; b< B; b++)
            if(par[b-1][u] != -1){
                par[b][u] = par[b-1][par[b-1][u]];
                maxW[b][u] = Math.max(maxW[b-1][u], maxW[b-1][par[b-1][u]]);
            }
        
        for(int[] ee:tree[u]){
            int v = ee[0], weight = w[ee[1]];
            if(v == p)continue;
            dep[v] = dep[u]+1;
            par[0][v] = u;
            maxW[0][v] = weight;
            dfs(tree, w, v, u);
        }
    }
    //u must be ancestor of v
    int maxW(int u, int v){
        int d = dep[v]-dep[u];
        int ans = 0;
        for(int b = B-1; b>= 0; b--){
            if(((d>>b)&1)==1){
                ans = Math.max(ans, maxW[b][v]);
                v = par[b][v];
            }
        }
        return ans;
    }
    class VirtualTree {
        int N, time;
        int[] dep, st, en;
        int[][] tree;
        LCA lca;
        public VirtualTree(int[][] tree){
            time = -1;
            this.N = tree.length;
            this.tree = tree;
            dep = new int[N];
            st = new int[N];
            en = new int[N];
            pre(0, -1);
            lca = new LCA(tree);
        }
        void pre(int u, int p){
            st[u] = ++time;
            for(int v:tree[u])if(v != p){
                dep[v] = dep[u]+1;
                pre(v, u);
            }
            en[u] = time;
        }
        //Returns the edges of Virtual tree
        List<Object> buildAuxTree(Integer[] V){
            TreeMap<Integer, Integer> map = new TreeMap<>();
            int c = 0;
            for(Integer x:V)map.put(x, c++);
            Arrays.sort(V, (Integer i1, Integer i2) -> Integer.compare(st[i1], st[i2]));//Sorted by euler in time
            for(int i = 1; i< V.length; i++){
                int w = lca.lca(V[i-1], V[i]);
                if(!map.containsKey(w))map.put(w, c++);
            }
            //The set of vertices to be present in aux Tree is ready. Now let's add edges between them
            //We also relabel nodes from 0 to SZ-1, where labels from 0 to K-1 are initial labels
            int SZ = c;
            int[][] edges = new int[SZ-1][];
            int cnt = 0;
            Integer[] vertices = map.keySet().toArray(new Integer[SZ]);
            Arrays.sort(vertices, (Integer i1, Integer i2) -> Integer.compare(dep[i2], dep[i1]));//sorting by depth in descending order
            TreeMap<Integer, Integer> tin = new TreeMap<>();//Contains pair (tin[u], u) for vertices u which are processed, and whose parents are not yet assigned
            for(int u:vertices){
                //Processing vertex u, all deeper vertices already processed
                Map.Entry<Integer, Integer> e;
                //Following loop runs over all vertices v such that st[u] <= st[v] && en[v] <= en[u]
                while((e = tin.ceilingEntry(st[u])) != null && en[e.getValue()] <= en[u]){
                    int v = e.getValue();
                    //add edge u -> v weighted by the maximum edge weight on the path from u down to v
                    edges[cnt++] = new int[]{map.get(u), map.get(v), maxW(u, v)};
                    tin.remove(e.getKey());
                }
                tin.put(st[u], u);
            }
            
            boolean[] original = new boolean[SZ];
            for(int x:V)original[map.get(x)] = true;
            return Arrays.asList(original, edges);
        }
    }
    class LCA{
        int n = 0, ti= -1;
        int[] eu, fi, d;
        RMQ rmq;
        public LCA(int[][] g){
            n = g.length;
            eu = new int[2*n-1];fi = new int[n];d = new int[n];
            Arrays.fill(fi, -1);Arrays.fill(eu, -1);
            for(int i = 0; i< n; i++)if(fi[i] == -1)dfs(g, i, -1);
            rmq = new RMQ(1+ti, eu, d);
        }
        public LCA(int[][][] g){
            n = g.length;
            eu = new int[2*n-1];fi = new int[n];d = new int[n];
            Arrays.fill(fi, -1);Arrays.fill(eu, -1);
            for(int i = 0; i< n; i++)if(fi[i] == -1)dfs(g, i, -1);
            rmq = new RMQ(1+ti, eu, d);
        }
        void dfs(int[][] g, int u, int p){
            eu[++ti] = u;fi[u] = ti;
            for(int v:g[u])if(v!=p){
                d[v] = d[u]+1;
                dfs(g, v, u);eu[++ti] = u;
            }
        }
        void dfs(int[][][] g, int u, int p){
            eu[++ti] = u;fi[u] = ti;
            for(int[] v:g[u])if(v[0]!=p){
                d[v[0]] = d[u]+1;
                dfs(g, v[0], u);eu[++ti] = u;
            }
        }
        int lca(int u, int v){return rmq.query(Math.min(fi[u], fi[v]), Math.max(fi[u], fi[v]));}
        int dist(int u, int v){return d[u]+d[v]-2*d[lca(u,v)];}
        class RMQ{
            int[] len, d;
            int[][] rmq;
            public RMQ(int L, int[] ar, int[] weight){
                len = new int[L+1];
                this.d = weight;
                for(int i = 2; i<= L; i++)len[i] = len[i>>1]+1;
                rmq = new int[len[L]+1][L];
                for(int i = 0; i< rmq.length; i++)
                    for(int j = 0; j< rmq[i].length; j++)
                        rmq[i][j] = -1;
                for(int i = 0; i< L; i++)rmq[0][i] = ar[i];
                for(int b = 1; b<= len[L]; b++)
                    for(int i = 0; i + (1<<b)-1< L; i++)
                        if(weight[rmq[b-1][i]]<weight[rmq[b-1][i+(1<<(b-1))]])rmq[b][i] =rmq[b-1][i];
                        else rmq[b][i] = rmq[b-1][i+(1<<(b-1))];
            }
            int query(int l, int r){
                if(l==r)return rmq[0][l];
                int b = len[r-l];
                if(d[rmq[b][l]]<d[rmq[b][r-(1<<b)]])return rmq[b][l];
                return rmq[b][r-(1<<b)];
            }
        }
    }
    int[][] make(int n, int e, int[] from, int[] to, boolean f){
        int[][] g = new int[n][];int[]cnt = new int[n];
        for(int i = 0; i< e; i++){
            cnt[from[i]]++;
            if(f)cnt[to[i]]++;
        }
        for(int i = 0; i< n; i++)g[i] = new int[cnt[i]];
        for(int i = 0; i< e; i++){
            g[from[i]][--cnt[from[i]]] = to[i];
            if(f)g[to[i]][--cnt[to[i]]] = from[i];
        }
        return g;
    }
    int[][][] makeS(int n, int e, int[] from, int[] to, boolean f){
        int[][][] g = new int[n][][];int[]cnt = new int[n];
        for(int i = 0; i< e; i++){
            cnt[from[i]]++;
            if(f)cnt[to[i]]++;
        }
        for(int i = 0; i< n; i++)g[i] = new int[cnt[i]][];
        for(int i = 0; i< e; i++){
            g[from[i]][--cnt[from[i]]] = new int[]{to[i], i, 0};
            if(f)g[to[i]][--cnt[to[i]]] = new int[]{from[i], i, 1};
        }
        return g;
    }
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    static boolean multipleTC = true;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        new YATP().run();
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}

    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }

        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }

        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException  e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }

        String nextLine() throws Exception{
            String str = "";
            try{   
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }  
            return str;
        }
    }
}

Feel free to share your approach. Suggestions are welcome as always.

Centroid Decomp kinda trivializes this problem. My (scuffed) in-contest submission: Solution: 54739158 | CodeChef

The queries can be answered offline using only DSU.

Would you please elaborate on your approach?