TRMT - Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3

Setter: Ma Zihang
Tester: Taranpreet Singh
Editorialist: Kanhaiya Mohan

DIFFICULTY:

Easy-Medium

PREREQUISITES:

Trees, LCA

PROBLEM:

Given is a tree with N weighted vertices and N-1 weighted edges. The i^{th} vertex has a weight of A_i. The i^{th} edge connects vertices u_i and v_i and has a weight of W_i.

Let dist(x,y) be the sum of weights of the edges in the unique simple path connecting vertices x and y. Let V(x,y) be the set of vertices appearing in the unique simple path connecting vertices x and y (including x, y).

You are asked Q queries of the form x_i y_i. For each query, find the value of \sum_{k\in V(x_i,y_i)}(dist(x_i, k) - dist(k, y_i)) \cdot A_k

EXPLANATION:

Subtask 1: u_i = i, v_i = i+1

This represents a line graph. In this case, we can use dynamic programming to find the solution.
Define the following prefix arrays, computed from left to right (with sum_0 = weight_0 = 0, dis_1 = 0, and W_{i-1} denoting the weight of the edge between vertices i-1 and i):

  • sum_i = sum_{i-1} + A_i
  • dis_i = dis_{i-1} + W_{i-1}
  • weight_i = weight_{i-1} + dis_i\cdot A_i

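The answer for a query \text{x y} with x<y is 2\cdot(weight_y-weight_x-(sum_y-sum_x)\cdot dis_x) - (sum_y-sum_{x-1})\cdot (dis_y-dis_x). If x>y, swapping x and y negates every term dist(x,k)-dist(k,y), so the answer is the negation of the answer for the query \text{y x}. If x=y, the answer is 0.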

The time taken for precomputation is O(N) and each query can be answered in O(1). Thus, the complexity is O(N+Q).

Subtask 2: Original Constraints

Let acn=lca(x,y). We can precompute a binary lifting table in O(N\log(N)) time. With it, the LCA of any pair of vertices can be found in O(\log(N)) time per query.
Let sum_x= \sum_{k\in V(x,root)}A_k and weight_x=\sum_ {k\in V(x,root)}dist(root,k)\cdot A_k. The values of sum_x and weight_x can be precomputed using a single DFS in O(N).

We can split the path from x to y into two parts: x to acn and y to acn. Each part looks like a linear graph. Let us consider the path from x to acn.
The answer for this path is (sum_x-sum_{acn})\cdot (dist(x,acn)-dist(acn,y))-2\cdot(weight_x-weight_{acn}-(sum_x-sum_{acn})\cdot dist(acn,root)).
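The answer for the path from y to acn (again excluding acn) is computed the same way, except that the second term enters with a plus sign. This is because dist(x,k)-dist(k,y) increases as k moves from acn towards y: (sum_y-sum_{acn})\cdot (dist(x,acn)-dist(acn,y))+2\cdot(weight_y-weight_{acn}-(sum_y-sum_{acn})\cdot dist(acn,root)).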
Neither of these two sums includes acn itself (sum_x-sum_{acn} and sum_y-sum_{acn} both exclude A_{acn}). So we add its contribution (dist(x,acn)-dist(acn,y))\cdot A_{acn} to the answer.

TIME COMPLEXITY:

The time complexity is O((N+Q)log(N)) per test case.

SOLUTION:

Setter's Solution
#include <iostream>
#include <vector>

// Capacity of every per-vertex array (vertices are 1-indexed, so arrays hold N + 1 slots).
constexpr int N = 200000;
// Highest sparse-table level over the Euler tour: 2^(LGN + 1) exceeds the tour length 2N - 1.
constexpr int LGN = 18;

// One adjacency-list entry: the neighbouring vertex and the weight of the connecting edge.
struct Node {
	int to, value;
};

// Adjacency lists: tree[v] holds one Node {neighbour, edge weight} per incident edge.
std::vector<Node> tree[N + 1];
// value[v]: input weight A_v of vertex v (1-indexed).
int value[N + 1];
// depth[v]: number of vertices on the root..v path (root has depth 1; index 0 stays 0).
int depth[N + 1];
// sum[v]: sum of value[] over the root..v path.
long long sum[N + 1];
// dist[v]: total edge weight from the root to v.
long long dist[N + 1];
// weight[v]: sum of value[k] * dist[k] over the root..v path.
long long weight[N + 1];
// parent[v]: parent of v in the DFS (0 for the root).
int parent[N + 1];
// lg[k]: floor(log2(k)), filled by prepare().
int lg[2 * N];
// euler[0][1..euler_cnt]: Euler tour written by dfs();
// euler[i][j]: shallowest vertex among euler[0][j .. j + 2^i - 1], filled by prepare().
int euler[LGN + 1][2 * N];
// pos[v]: index of the first occurrence of v in the Euler tour.
int pos[N + 1];
// Number of entries written to euler[0] so far.
int euler_cnt;
// Vertex and query counts of the current test case.
int n, q;

// Forward declarations: the definitions below do not appear in dependency order.
int better(int a, int b);
int lca(int a, int b);
void dfs(int root, int father);
void prepare();

// Lowest common ancestor of a and b: the shallowest vertex in the Euler tour between the
// first occurrences of a and b, answered by two overlapping sparse-table windows.
int lca(int a, int b) {
	int left = pos[a];
	int right = pos[b];
	if (left > right)
		std::swap(left, right);

	const int level = lg[right - left + 1];
	const int width = 1 << level;
	const int from_left = euler[level][left];
	const int from_right = euler[level][right - width + 1];
	return better(from_left, from_right);
}

void prepare() {
	for (int i = 1; (1 << i) < 2 * n; i++) {
		for (int j = 1; j < 2 * n; j++) {
			euler[i][j] = better(euler[i - 1][j], euler[i - 1][j + (1 << (i - 1))]);
		}
	}

	for (int i = 2; i < 2 * n; i++) {
		lg[i] = lg[i / 2] + 1;
	}
}

// Returns whichever of a and b is shallower in the tree; ties go to b.
int better(int a, int b) {
	return depth[a] < depth[b] ? a : b;
}

// Depth-first traversal from `root`, whose parent is `father` (0 for the real root).
// For each vertex v it sets sum[v], depth[v] and parent[v], and for each child c it sets
// dist[c] and weight[c] before descending. It appends the Euler tour to euler[0][1..]:
// v on entry, then v again after each child subtree finishes. pos[v] records the first
// occurrence of v.
// An explicit stack replaces recursion: a path-shaped tree (up to N vertices deep) would
// otherwise need one call frame per vertex and can overflow the call stack.
void dfs(int root, int father) {
	struct Frame {
		int vertex;
		int next_edge; // index into tree[vertex] of the next edge to examine
	};
	std::vector<Frame> stack;

	// Same work the recursive version did on entry to a vertex.
	auto enter = [&](int v, int p) {
		sum[v] = sum[p] + value[v];
		depth[v] = depth[p] + 1;
		parent[v] = p;

		euler_cnt++;
		euler[0][euler_cnt] = v;
		pos[v] = euler_cnt;

		stack.push_back({v, 0});
	};

	enter(root, father);
	while (!stack.empty()) {
		const int v = stack.back().vertex;
		const int i = stack.back().next_edge;

		if (i == static_cast<int>(tree[v].size())) {
			// v is finished: the tour returns to its parent (the root has none on the stack).
			stack.pop_back();
			if (!stack.empty()) {
				euler_cnt++;
				euler[0][euler_cnt] = stack.back().vertex;
			}
			continue;
		}

		// Advance before enter(): push_back may reallocate and invalidate stack.back().
		stack.back().next_edge = i + 1;

		const int to = tree[v][i].to;
		if (to == parent[v]) {
			continue;
		}
		dist[to] = dist[v] + tree[v][i].value;
		weight[to] = weight[v] + value[to] * dist[to];
		enter(to, v);
	}
}

// Difference of root distances, dist[x] - dist[y]. When y is an ancestor of x, this is the
// length of the x..y path. It is returned as long long because dist[] is long long and the
// difference may not fit in int for large edge weights.
long long dis(int x,int y) {
	return dist[x] - dist[y];
}

// Reads T test cases. Each has n vertex weights, n - 1 weighted edges and q queries (x, y).
// For each query it prints sum over k on path(x, y) of (dist(x,k) - dist(k,y)) * A_k.
int main() {
	std::ios::sync_with_stdio(false);
	std::cin.tie(0);
	std::cout.tie(0);
	int T;
	std::cin >> T;

	while (T--) {
		std::cin >> n >> q;

		// Reset per-test state. depth[], the Euler tour and lg[] are overwritten by
		// dfs() and prepare().
		euler_cnt = 0;
		for (int i = 1; i <= n; i++) {
			tree[i].clear();
			pos[i] = sum[i] = dist[i] = weight[i] = parent[i] = 0;
		}

		for (int i = 1; i <= n; i++) {
			std::cin >> value[i];
		}

		for (int i = 1; i < n; i++) {
			int a, b, v;
			std::cin >> a >> b >> v;

			tree[a].push_back({ b, v });
			tree[b].push_back({ a, v });
		}

		dfs(1, 0);
		prepare();

		for (int i = 1; i <= q; i++) {
			int x, y;
			std::cin >> x >> y;

			const int acn = lca(x, y);

			// Path lengths to the LCA, computed directly in 64 bits so the arithmetic
			// never depends on the return type of dis().
			const long long to_x = dist[x] - dist[acn];
			const long long to_y = dist[y] - dist[acn];
			const long long delta = to_x - to_y; // dist(x,acn) - dist(acn,y)

			// Sums of A over the vertices strictly below acn on each side.
			const long long sx = sum[x] - sum[acn];
			const long long sy = sum[y] - sum[acn];

			long long ans = 0;
			ans += sx * delta - 2 * (weight[x] - weight[acn] - sx * dist[acn]);
			ans += sy * delta + 2 * (weight[y] - weight[acn] - sy * dist[acn]);
			ans += delta * value[acn]; // the LCA itself

			std::cout << ans << '\n';
		}
	}
	return 0;
}

Tester's Solution
import java.util.*;
import java.io.*;
class TRMT{
    //SOLUTION BEGIN
    // Number of binary-lifting levels; lca() and the query loop assume every depth
    // difference is below 2^B.
    int B = 18;
    void pre() throws Exception{}
    /**
     * Solves one test case. Reads N, Q, the vertex weights A and the N-1 weighted edges,
     * builds the lifting tables with dfs() from vertex 0, then answers each query (x, y)
     * in O(B) time.
     */
    void solve(int TC) throws Exception{
        int N = ni(), Q = ni();
        long[] A = new long[N];
        for(int i = 0; i< N; i++)A[i] = nl();
        int[] from = new int[N-1], to = new int[N-1];
        long[] W = new long[N-1];
        for(int i = 0; i< N-1; i++){
            from[i] = ni()-1;
            to[i] = ni()-1;
            W[i] = nl();
        }
        
        // par[b][u]: 2^b-th ancestor of u (-1 if none).
        // sumA[b][u] / sumAD[b][u]: sum of A[k] / A[k]*dist[k] over the 2^b vertices from u
        // upward, i.e. from u up to but excluding par[b][u].
        int[][] par = new int[B][N];
        long[][] sumA = new long[B][N], sumAD = new long[B][N];
        long[] dist = new long[N];
        int[] dep = new int[N];
        for(int b = 0; b< B; b++)Arrays.fill(par[b], -1);
        int[][][] g = makeS(N, N-1, from, to, true);
        dfs(g, A, W, par, sumA, sumAD, dist, dep, 0, -1);

        for(int q = 0; q< Q; q++){
            int x = ni()-1, y = ni()-1;
            int lca = lca(par, dep, x, y);
            
            // Lift x and y toward (never past) the LCA. Accumulate sums over the vertices
            // strictly below the LCA: the x side into *1, the y side into *2.
            long sumA1 = 0, sumAD1 = 0, sumA2 = 0, sumAD2 = 0;
            for(int b = B-1, u = x, v = y; b>= 0; b--){
                if(par[b][u] != -1 && dep[par[b][u]] >= dep[lca]){
                    sumA1 += sumA[b][u];
                    sumAD1 += sumAD[b][u];
                    u = par[b][u];
                }
                if(par[b][v] != -1 && dep[par[b][v]] >= dep[lca]){
                    sumA2 += sumA[b][v];
                    sumAD2 += sumAD[b][v];
                    v = par[b][v];
                }
            }
            
            // Contributions:
            //   x-side vertex k: (dist[x]-dist[y]+2*dist[lca]-2*dist[k])*A[k]
            //   y-side vertex k: (dist[x]-dist[y]-2*dist[lca]+2*dist[k])*A[k]
            //   the LCA itself:  (dist[x]-dist[y])*A[lca]
            long ans = (dist[x]-dist[y]+2*dist[lca])*sumA1 - 2*sumAD1+
                    (dist[x]-dist[y]-2*dist[lca])*sumA2 + 2*sumAD2 +
                    A[lca] * (dist[x]-dist[y]);
            
            pn(ans);
        }
    }
    /**
     * Recursive DFS from u with parent p. On entry it completes u's lifting rows for b >= 1
     * (row 0 was filled by the caller). For every child v it then sets par[0][v], dep[v],
     * dist[v], sumA[0][v] = A[v] and sumAD[0][v] = A[v]*dist[v], and recurses.
     * Recursion depth equals the tree height, which is why main() runs the solver on a
     * large-stack thread.
     */
    void dfs(int[][][] g, long[] A, long[] W, int[][] par, long[][] sumA, long[][] sumAD, long[] dist, int[] dep, int u, int p){
        for(int b = 1; b< B; b++){
            if(par[b-1][u] != -1 && par[b-1][par[b-1][u]] != -1){
                par[b][u] = par[b-1][par[b-1][u]];
                sumA[b][u] = sumA[b-1][u] + sumA[b-1][par[b-1][u]];
                sumAD[b][u] = sumAD[b-1][u] + sumAD[b-1][par[b-1][u]];
            }
        }
        
        for(int[] edge:g[u]){
            int v = edge[0], edge_id = edge[1];
            long w = W[edge_id];
            if(v == p)continue;
            par[0][v] = u;
            dep[v] = dep[u]+1;
            dist[v] = dist[u] + w;
            sumA[0][v] = A[v];
            sumAD[0][v] = A[v]*dist[v];
            dfs(g, A, W, par, sumA, sumAD, dist, dep, v, u);
        }
    }
    /** Lowest common ancestor of u and v via binary lifting on par[][] and dep[]. */
    int lca(int[][] par, int[] dep, int u, int v){
        if(dep[v] > dep[u]){
            int tmp = v;
            v = u;
            u = tmp;
        }
        for(int b = B-1; b >= 0; b--)
            if((((dep[u]-dep[v])>>b)&1) == 1)
                u = par[b][u];
        if(u == v)return u;
        for(int b = B-1; b>= 0; b--)
            if(par[b][u] != par[b][v]){
                u = par[b][u];
                v = par[b][v];
            }
        return par[0][u];
    }
    /**
     * Builds a static adjacency list for n vertices and e edges (undirected when f is true).
     * Each entry is {neighbour, edge index, direction flag}.
     */
    int[][][] makeS(int n, int e, int[] from, int[] to, boolean f){
        int[][][] g = new int[n][][];int[]cnt = new int[n];
        for(int i = 0; i< e; i++){
            cnt[from[i]]++;
            if(f)cnt[to[i]]++;
        }
        for(int i = 0; i< n; i++)g[i] = new int[cnt[i]][];
        for(int i = 0; i< e; i++){
            g[from[i]][--cnt[from[i]]] = new int[]{to[i], i, 0};
            if(f)g[to[i]][--cnt[to[i]]] = new int[]{from[i], i, 1};
        }
        return g;
    }
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    static boolean multipleTC = true;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        // dfs() recurses once per tree level. On a path-shaped tree (height close to N)
        // that overflows the default JVM thread stack, so the solver runs on a thread
        // with an explicit 256 MB stack, and any failure is rethrown here.
        final Throwable[] failure = new Throwable[1];
        Thread solver = new Thread(null, new Runnable(){
            public void run(){
                try{
                    new TRMT().run();
                }catch(Throwable t){
                    failure[0] = t;
                }
            }
        }, "TRMT-solver", 1 << 28);
        solver.start();
        solver.join();
        if(failure[0] instanceof Exception)throw (Exception)failure[0];
        if(failure[0] instanceof Error)throw (Error)failure[0];
        if(failure[0] != null)throw new Exception(failure[0]);
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}

    /** Whitespace-token reader over stdin (or a file). */
    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }

        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }

        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException  e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }

        String nextLine() throws Exception{
            String str = "";
            try{   
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }  
            return str;
        }
    }
}
Editorialist's Solution
#include <bits/stdc++.h>
using namespace std;
#define int long long 

#define endl "\n"

// Global state shared by dfs(), getLCA() and solve(). In this file `int` is `long long`
// because of the macro at the top.
int N, Q;                              // vertex and query counts of the current test case
int LIMIT = 20;                        // highest binary-lifting level; table rows hold LIMIT + 1 entries
vector<int> sum;                       // sum[v]: total of arr[] on the root..v path
vector<int> arr;                       // arr[v]: weight A_v of vertex v
vector<int> dist;                      // dist[v]: edge-weight distance from the root to v
vector<int> depth;                     // depth[v]: level of v (root = 1)
vector<int> weight;                    // weight[v]: sum of arr[k] * dist[k] on the root..v path
vector<vector<pair<int, int>>> tree;   // tree[v]: (neighbour, edge weight) pairs
vector<vector<int>> table;             // table[v][j]: 2^j-th ancestor of v (0 past the root)

// Traverses the tree from `src` (parent `parent`, depth `level`). For every vertex v it
// fills sum[v] (total of arr on the root..v path), depth[v] and the binary-lifting row
// table[v][0..LIMIT]. For every child c it fills dist[c] and weight[c] (sum of
// arr[k] * dist[k] on the root..c path).
// Ancestors above the root resolve to vertex 0, because table is zero-filled and row 0 is
// never written. That is why no "-1" sentinel check is needed.
// An explicit stack replaces recursion so a path-shaped tree cannot overflow the call stack.
// Every value depends only on the vertex's ancestors, which are always processed first.
void dfs(int src, int parent, int level = 1) {
    struct Frame {
        int vertex, par, lvl;
    };
    vector<Frame> stack;
    stack.push_back({src, parent, level});

    while(!stack.empty()) {
        Frame cur = stack.back();
        stack.pop_back();
        int v = cur.vertex, p = cur.par;

        sum[v] = sum[p] + arr[v]; // Sum of arr[i] from root to v.
        depth[v] = cur.lvl;
        table[v][0] = p;
        for(int i = 1; i <= LIMIT; i ++) {
            table[v][i] = table[table[v][i-1]][i-1];
        }

        for(auto child : tree[v]) {
            int idx = child.first, wt = child.second;
            if(idx == p) continue;
            dist[idx] = dist[v] + wt;
            weight[idx] = weight[v] + (dist[idx] * arr[idx]); // sum of arr[i] * dist(root, i) from root to idx.
            stack.push_back({idx, v, cur.lvl + 1});
        }
    }
}

int getLCA(int x, int y) {
    if(depth[x] < depth[y]) {
        swap(x, y);
    }
    for(int j = LIMIT; j >= 0; j --) {
        if((depth[x] - (1 << j)) >= depth[y]) {
            x = table[x][j];
        }
    }
    if(x == y) return x;
    for(int j = LIMIT; j >= 0; j --) {
        if(table[x][j] != table[y][j]) {
            x = table[x][j];
            y = table[y][j];
        }
    }
    return table[x][0];
}

// Runs the DFS from vertex 1, then reads and answers the Q queries of the current test case.
void solve() {
    dfs(1, 0);
    while(Q --) {
        int x, y;
        cin >> x >> y;
        const int lca = getLCA(x, y);
        const int delta = dist[x] - dist[y];   // equals dist(x,lca) - dist(lca,y)

        // Vertices strictly between x and lca.
        const int fromX = (sum[x] - sum[lca]) * (delta + 2 * dist[lca]) - 2 * (weight[x] - weight[lca]);
        // Vertices strictly between y and lca.
        const int fromY = (sum[y] - sum[lca]) * (delta - 2 * dist[lca]) + 2 * (weight[y] - weight[lca]);
        // The LCA itself.
        const int atLCA = delta * arr[lca];

        cout << fromX + fromY + atLCA << endl;
    }
}

int32_t main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    int T; cin >> T;
    while(T --) {
        cin >> N >> Q;
        tree.resize(N + 1);
        arr.assign(N + 1, 0LL);
        depth.assign(N + 1, 0LL);
        sum.assign(N + 1, 0LL);
        dist.assign(N + 1, 0LL);
        weight.assign(N + 1, 0LL);
        table.assign(N + 1, vector<int>(21, 0));
        
        for(int i = 1; i <= N; i ++) cin >> arr[i];
        for(int i = 1; i < N; i ++) {
            int u, v, w;
            cin >> u >> v >> w;
            tree[u].push_back({v, w});
            tree[v].push_back({u, w});
        }
        solve();
        tree.clear();
        table.clear();
    }
}
5 Likes