PROBLEM LINK:
Contest Division 1
Contest Division 2
Contest Division 3
Setter: Ma Zihang
Tester: Taranpreet Singh
Editorialist: Kanhaiya Mohan
DIFFICULTY:
Easy-Medium
PREREQUISITES:
Trees, LCA
PROBLEM:
Given is a tree with N weighted vertices and N-1 weighted edges. The i^{th} vertex has a weight of A_i. The i^{th} edge connects vertices u_i and v_i and has a weight of W_i.
Let dist(x,y) be the sum of weights of the edges in the unique simple path connecting vertices x and y. Let V(x,y) be the set of vertices appearing in the unique simple path connecting vertices x and y (including x, y).
You are asked Q queries of the form x_i y_i. For each query, find the value of \sum_{k\in V(x_i,y_i)}(dist(x_i, k) - dist(k, y_i)) \cdot A_k
EXPLANATION:
Subtask 1: u_i = i, v_i = i+1
The tree is a simple path visiting vertices 1, 2, \dots, N in order, so we can answer queries with prefix sums (a simple form of dynamic programming).
Define the following prefix arrays, computed from left to right with sum_0=weight_0=0 and dis_1=0:
- sum_i = sum_{i-1} + A_i
- dis_i = dis_{i-1} + W_{i-1}
- weight_i = weight_{i-1} + dis_i\cdot A_i
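The answer to a query (x, y) with x<y is 2\cdot(weight_y-weight_x-(sum_y-sum_x)\cdot dis_x) - (sum_y-sum_{x-1})\cdot (dis_y-dis_x). Swapping x and y negates every term of the sum. So for x>y the answer is the negation of the answer for (y, x), and for x=y it is 0.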
The time taken for precomputation is O(N) and each query can be answered in O(1). Thus, the complexity is O(N+Q).
Subtask 2: Original Constraints
Let acn=lca(x,y). We can precompute the LCA using binary lifting in O(Nlog(N)) time. Now, for each query, we can find the LCA in O(log(N)) time.
Let sum_x=\sum_{k\in V(x,root)}A_k and weight_x=\sum_{k\in V(x,root)}dist(root,k)\cdot A_k. The values of sum_x and weight_x can be precomputed using a single DFS in O(N).
We can split the path from x to y into two parts: x to acn and y to acn. Each part looks like a linear graph. Let us consider the path from x to acn.
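For a vertex k on this path other than acn, dist(x,k)-dist(k,y)=(dist(x,acn)-dist(acn,y))-2\cdot(dist(root,k)-dist(root,acn)). Summing over these vertices, the contribution of this part (excluding acn) is (sum_x-sum_{acn})\cdot (dist(x,acn)-dist(acn,y))-2\cdot(weight_x-weight_{acn}-(sum_x-sum_{acn})\cdot dist(acn,root)).
Similarly, for a vertex k on the path from y to acn (excluding acn), dist(x,k)-dist(k,y)=(dist(x,acn)-dist(acn,y))+2\cdot(dist(root,k)-dist(root,acn)). That part therefore contributes (sum_y-sum_{acn})\cdot (dist(x,acn)-dist(acn,y))+2\cdot(weight_y-weight_{acn}-(sum_y-sum_{acn})\cdot dist(acn,root)). Note the sign flip in the second term.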
Note that the contribution of acn was skipped in the above two paths. Thus, we add (dist(x,acn)-dist(acn,y))\cdot A_{acn} to the answer.
TIME COMPLEXITY:
The time complexity is O((N+Q)log(N)) per test case.
SOLUTION:
Setter's Solution
#include <iostream>
#include <vector>
// Maximum number of vertices; per-vertex arrays are sized N + 1 for 1-based ids.
int const N = 2e5;
// Highest sparse-table level. Rows 0..LGN are allocated; a tour of length
// 2N - 1 = 399999 needs levels up to floor(log2(399999)) = 18.
int const LGN = 18;
// Adjacency-list entry: neighbour `to` and the weight `value` of the connecting edge.
struct Node {
int to;
int value;
};
std::vector<Node> tree[N + 1];
// Vertex weights A_i (read in main).
int value[N + 1];
// Depth of each vertex; dfs gives the root depth 1 (depth[0] stays 0 for the sentinel parent 0).
int depth[N + 1];
// sum[v]: sum of value[k] over the root..v path.
long long sum[N + 1];
// dist[v]: weighted distance from the root (vertex 1) to v.
long long dist[N + 1];
// weight[v]: sum of value[k] * dist[k] over the root..v path.
long long weight[N + 1];
// Written by dfs; not read anywhere else in this solution.
int parent[N + 1];
// lg[i] = floor(log2(i)), filled by prepare().
int lg[2 * N];
// euler[0][1..euler_cnt] is the Euler tour of the tree; euler[k][j] is the
// shallowest vertex (per better()) among euler[0][j .. j + 2^k - 1].
int euler[LGN + 1][2 * N];
// pos[v]: index of the first occurrence of v in the Euler tour.
int pos[N + 1];
// Number of Euler-tour entries written so far (2n - 1 after dfs on a tree).
int euler_cnt;
// Vertex and query counts of the current test case.
int n, q;
// Forward declarations: lca() calls better() before its definition appears.
void dfs(int, int);
int better(int, int);
void prepare();
int lca(int, int);
int lca(int a, int b) {
a = pos[a];
b = pos[b];
if (a > b) {
std::swap(a, b);
}
int k = lg[b - a + 1];
return better(euler[k][a], euler[k][b - (1 << k) + 1]);
}
void prepare() {
for (int i = 1; (1 << i) < 2 * n; i++) {
for (int j = 1; j < 2 * n; j++) {
euler[i][j] = better(euler[i - 1][j], euler[i - 1][j + (1 << (i - 1))]);
}
}
for (int i = 2; i < 2 * n; i++) {
lg[i] = lg[i / 2] + 1;
}
}
// Returns whichever of the two vertices is shallower; on equal depth, returns b.
int better(int a, int b) {
    return depth[b] <= depth[a] ? b : a;
}
void dfs(int root, int father) {
sum[root] = sum[father] + value[root];
depth[root] = depth[father] + 1;
parent[root] = father;
euler_cnt++;
euler[0][euler_cnt] = root;
pos[root] = euler_cnt;
for (auto edge : tree[root]) {
int to = edge.to;
int val = edge.value;
if (to != father) {
dist[to] = dist[root] + val;
weight[to] = weight[root] + value[to] * dist[to];
dfs(to, root);
euler_cnt++;
euler[0][euler_cnt] = root;
}
}
}
// Difference of root distances, dist[x] - dist[y]; when y is an ancestor of x this is
// the length of the x..y path. dist[] holds 64-bit path sums, so the result is returned
// as long long. The previous int return type could silently truncate large distances.
long long dis(int x, int y) {
    return dist[x] - dist[y];
}
// Reads T test cases. For each, it builds the tree, precomputes root-path prefix data
// and the LCA structure, then answers each query (x, y) with
//   sum over k on path(x, y) of (dist(x, k) - dist(k, y)) * value[k].
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(0);
    std::cout.tie(0);
    int T;
    std::cin >> T;
    while (T--) {
        std::cin >> n >> q;
        // Reset per-test state: the Euler tour restarts, and vertices 1..n drop their
        // adjacency lists and prefix values from the previous test case.
        euler_cnt = 0;
        for (int i = 1; i <= n; i++) {
            tree[i].clear();
            pos[i] = sum[i] = dist[i] = weight[i] = parent[i] = 0;
        }
        for (int i = 1; i <= n; i++) {
            std::cin >> value[i];
        }
        for (int i = 1; i < n; i++) {
            int a, b, v;
            std::cin >> a >> b >> v;
            tree[a].push_back({ b, v });
            tree[b].push_back({ a, v });
        }
        dfs(1, 0);
        prepare();
        for (int i = 1; i <= q; i++) {
            int x, y;
            std::cin >> x >> y;
            const int acn = lca(x, y);
            // dist(x, acn) - dist(acn, y) == dist[x] - dist[y], kept in 64-bit arithmetic
            // (previously computed via the int-returning dis(), which could truncate).
            const long long delta = dist[x] - dist[y];
            // Sums of value[] over the x..acn and y..acn arms, acn excluded.
            const long long armX = sum[x] - sum[acn];
            const long long armY = sum[y] - sum[acn];
            long long ans = 0;
            ans += armX * delta - 2 * (weight[x] - weight[acn] - armX * dist[acn]);
            ans += armY * delta + 2 * (weight[y] - weight[acn] - armY * dist[acn]);
            // acn itself is on neither arm; its term is delta * value[acn].
            ans += delta * value[acn];
            std::cout << ans << '\n';
        }
    }
    return 0;
}
Tester's Solution
import java.util.*;
import java.io.*;
/**
 * Tester's solution.
 *
 * Answers each query with binary lifting. For vertex u and level b:
 *   par[b][u]   - the 2^b-th ancestor of u, or -1 if it does not exist;
 *   sumA[b][u]  - sum of A[k] over the 2^b vertices u, parent(u), ... (par[b][u] excluded),
 *                 filled only when par[b][u] != -1;
 *   sumAD[b][u] - the same sum of A[k] * dist[k].
 * Vertices are 0-based internally; vertex 0 is the root.
 */
class TRMT{
    //SOLUTION BEGIN
    // Number of lifting levels; handles depths below 2^18 = 262144.
    int B = 18;
    void pre() throws Exception{}
    /** Reads one test case (N, Q, vertex weights, edges, queries) and prints one answer per query. */
    void solve(int TC) throws Exception{
        int N = ni(), Q = ni();
        long[] A = new long[N];
        for(int i = 0; i< N; i++)A[i] = nl();
        int[] from = new int[N-1], to = new int[N-1];
        long[] W = new long[N-1];
        for(int i = 0; i< N-1; i++){
            from[i] = ni()-1;
            to[i] = ni()-1;
            W[i] = nl();
        }
        int[][] par = new int[B][N];
        long[][] sumA = new long[B][N], sumAD = new long[B][N];
        long[] dist = new long[N];
        int[] dep = new int[N];
        for(int b = 0; b< B; b++)Arrays.fill(par[b], -1);
        int[][][] g = makeS(N, N-1, from, to, true);
        dfs(g, A, W, par, sumA, sumAD, dist, dep, 0, -1);
        for(int q = 0; q< Q; q++){
            int x = ni()-1, y = ni()-1;
            int lca = lca(par, dep, x, y);
            // Accumulate sums of A and A*dist over x..lca and y..lca, lca itself excluded.
            long sumA1 = 0, sumAD1 = 0, sumA2 = 0, sumAD2 = 0;
            for(int b = B-1, u = x, v = y; b>= 0; b--){
                if(par[b][u] != -1 && dep[par[b][u]] >= dep[lca]){
                    sumA1 += sumA[b][u];
                    sumAD1 += sumAD[b][u];
                    u = par[b][u];
                }
                if(par[b][v] != -1 && dep[par[b][v]] >= dep[lca]){
                    sumA2 += sumA[b][v];
                    sumAD2 += sumAD[b][v];
                    v = par[b][v];
                }
            }
            // For k on the x side: dist(x,k)-dist(k,y) = dist[x]-dist[y]+2*dist[lca]-2*dist[k];
            // on the y side:       dist(x,k)-dist(k,y) = dist[x]-dist[y]-2*dist[lca]+2*dist[k];
            // lca itself contributes A[lca]*(dist[x]-dist[y]).
            long ans = (dist[x]-dist[y]+2*dist[lca])*sumA1 - 2*sumAD1+
                    (dist[x]-dist[y]-2*dist[lca])*sumA2 + 2*sumAD2 +
                    A[lca] * (dist[x]-dist[y]);
            pn(ans);
        }
    }
    /**
     * Fills the lifting tables for u (its ancestors are already filled), then sets
     * level-0 data, depth and root distance for each child and recurses into it.
     * Recursion depth equals the tree height, which is why main() runs the solver
     * on a thread with an enlarged stack.
     */
    void dfs(int[][][] g, long[] A, long[] W, int[][] par, long[][] sumA, long[][] sumAD, long[] dist, int[] dep, int u, int p){
        for(int b = 1; b< B; b++){
            if(par[b-1][u] != -1 && par[b-1][par[b-1][u]] != -1){
                par[b][u] = par[b-1][par[b-1][u]];
                sumA[b][u] = sumA[b-1][u] + sumA[b-1][par[b-1][u]];
                sumAD[b][u] = sumAD[b-1][u] + sumAD[b-1][par[b-1][u]];
            }
        }
        for(int[] edge:g[u]){
            int v = edge[0], edge_id = edge[1];
            long w = W[edge_id];
            if(v == p)continue;
            par[0][v] = u;
            dep[v] = dep[u]+1;
            dist[v] = dist[u] + w;
            sumA[0][v] = A[v];
            sumAD[0][v] = A[v]*dist[v];
            dfs(g, A, W, par, sumA, sumAD, dist, dep, v, u);
        }
    }
    /** Lowest common ancestor via binary lifting: equalize depths, then lift both while ancestors differ. */
    int lca(int[][] par, int[] dep, int u, int v){
        if(dep[v] > dep[u]){
            int tmp = v;
            v = u;
            u = tmp;
        }
        for(int b = B-1; b >= 0; b--)
            if((((dep[u]-dep[v])>>b)&1) == 1)
                u = par[b][u];
        if(u == v)return u;
        for(int b = B-1; b>= 0; b--)
            if(par[b][u] != par[b][v]){
                u = par[b][u];
                v = par[b][v];
            }
        return par[0][u];
    }
    /**
     * Builds an adjacency structure: g[a] holds {neighbour, edge index, direction} entries,
     * where direction is 0 for from->to and 1 for the reverse entry added when f is true.
     */
    int[][][] makeS(int n, int e, int[] from, int[] to, boolean f){
        int[][][] g = new int[n][][];int[]cnt = new int[n];
        for(int i = 0; i< e; i++){
            cnt[from[i]]++;
            if(f)cnt[to[i]]++;
        }
        for(int i = 0; i< n; i++)g[i] = new int[cnt[i]][];
        for(int i = 0; i< e; i++){
            g[from[i]][--cnt[from[i]]] = new int[]{to[i], i, 0};
            if(f)g[to[i]][--cnt[to[i]]] = new int[]{from[i], i, 1};
        }
        return g;
    }
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    static boolean multipleTC = true;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        // dfs recurses once per tree level, so a path-shaped tree with ~2e5 vertices can
        // overflow the default thread stack. Run the solver on a thread with a 256 MiB
        // stack, then re-throw whatever it threw so failures are not silently swallowed.
        final Throwable[] failure = new Throwable[1];
        Thread worker = new Thread(null, () -> {
            try{
                new TRMT().run();
            }catch(Throwable e){
                failure[0] = e;
            }
        }, "TRMT", 1L << 28);
        worker.start();
        worker.join();
        Throwable err = failure[0];
        if(err instanceof Exception)throw (Exception)err;
        if(err instanceof Error)throw (Error)err;
        if(err != null)throw new RuntimeException(err);
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}
    /** Whitespace-token reader over a BufferedReader (stdin by default). */
    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }
        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }
        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }
        String nextLine() throws Exception{
            String str = "";
            try{
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }
            return str;
        }
    }
}
Editorialist's Solution
#include <bits/stdc++.h>
using namespace std;
// Every `int` below this point is long long, so all stored values are 64-bit
// (main is therefore declared as int32_t).
#define int long long
// `endl` is redefined as a plain newline, so output is not flushed after each line.
#define endl "\n"
// N, Q: sizes of the current test case. LIMIT: highest binary-lifting level used by
// dfs/getLCA (levels 0..LIMIT).
int N, Q, LIMIT = 20;
// sum[v]: sum of arr[k] over the root..v path.
vector<int> sum;
// arr[v]: weight A_v of vertex v (1-based).
vector<int> arr;
// dist[v]: weighted distance from the root (vertex 1) to v.
vector<int> dist;
// depth[v]: level of v, with the root at 1 and the sentinel vertex 0 at 0.
vector<int> depth;
// weight[v]: sum of arr[k] * dist[k] over the root..v path.
vector<int> weight;
// tree[v]: list of (neighbour, edge weight) pairs.
vector<vector<pair<int, int>>> tree;
// table[v][i]: the 2^i-th ancestor of v, or 0 once the jump goes above the root.
vector<vector<int>> table;
// Depth-first traversal from src (parent `parent`, 0 for the root, at depth `level`).
// Fills root-path prefix data, the binary-lifting row of src, and then recurses into
// every child after setting that child's root distance and weighted prefix.
void dfs(int src, int parent, int level = 1) {
    // sum[0] == 0 serves as the base above the root.
    sum[src] = sum[parent] + arr[src];
    depth[src] = level;

    // Binary-lifting row: each level jumps twice as far as the previous one.
    table[src][0] = parent;
    for (int i = 1; i <= LIMIT && table[src][i - 1] != -1; i++) {
        const int halfway = table[src][i - 1];
        table[src][i] = table[halfway][i - 1];
    }

    for (const auto& [next, wt] : tree[src]) {
        if (next == parent) continue;
        dist[next] = dist[src] + wt;
        weight[next] = weight[src] + dist[next] * arr[next];
        dfs(next, src, level + 1);
    }
}
// Lowest common ancestor of x and y by binary lifting.
int getLCA(int x, int y) {
    // Make x the deeper endpoint.
    if (depth[y] > depth[x]) swap(x, y);

    // Lift x to y's depth, trying the largest jumps first.
    for (int j = LIMIT; j >= 0; --j)
        if (depth[x] - depth[y] >= (1 << j))
            x = table[x][j];
    if (x == y) return x;

    // Lift both while their ancestors differ; they end as children of the LCA.
    for (int j = LIMIT; j >= 0; --j) {
        const int upX = table[x][j];
        const int upY = table[y][j];
        if (upX != upY) {
            x = upX;
            y = upY;
        }
    }
    return table[x][0];
}
// Runs the DFS from vertex 1, then reads and answers Q queries (x, y).
// For a vertex k on the x..y path, dist(x, k) - dist(k, y) equals
//   x side (above lca excluded): (dist[x] - dist[y] + 2*dist[lca]) - 2*dist[k]
//   y side (above lca excluded): (dist[x] - dist[y] - 2*dist[lca]) + 2*dist[k]
//   k == lca:                    dist[x] - dist[y]
void solve() {
    dfs(1, 0);
    while (Q--) {
        int x, y;
        cin >> x >> y;
        const int lca = getLCA(x, y);
        const int delta = dist[x] - dist[y];
        const int sideX = (sum[x] - sum[lca]) * (delta + 2 * dist[lca]) - 2 * (weight[x] - weight[lca]);
        const int sideY = (sum[y] - sum[lca]) * (delta - 2 * dist[lca]) + 2 * (weight[y] - weight[lca]);
        cout << sideX + sideY + delta * arr[lca] << endl;
    }
}
int32_t main() {
ios_base::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
int T; cin >> T;
while(T --) {
cin >> N >> Q;
tree.resize(N + 1);
arr.assign(N + 1, 0LL);
depth.assign(N + 1, 0LL);
sum.assign(N + 1, 0LL);
dist.assign(N + 1, 0LL);
weight.assign(N + 1, 0LL);
table.assign(N + 1, vector<int>(21, 0));
for(int i = 1; i <= N; i ++) cin >> arr[i];
for(int i = 1; i < N; i ++) {
int u, v, w;
cin >> u >> v >> w;
tree[u].push_back({v, w});
tree[v].push_back({u, w});
}
solve();
tree.clear();
table.clear();
}
}