Problem Link:
Contest
Practice link
Setter: Aneesh D H
Tester: Suryansh Kumar (@hellb0y_suru)
Difficulty: Medium
Pre-requisite: Sparse Tables, LCA
Problem:
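You are given a tree with N nodes, where each node i has a value A_i, and Q queries, each consisting of a pair of nodes u and v. For each query, find the sum of (1 + d_{ui}) \times A_i over all nodes i in P_{uv}. Here, P_{uv} is the path from u to v and d_{ui} is the distance from node u to node i. The queries are encoded: as both solutions below show, the actual u and v are obtained by XOR-ing the given values with the answer to the previous query (0 before the first query).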
Explanation:
Construct two sparse tables (binary lifting over the ancestors of each node). Each entry starts at some node p and covers the nodes on the path from p up to one of its ancestors q:
- S1: the entry for the range [p, q] contains A_p + ... + A_q.
- S2: the entry for the range [p, q] contains 1 \times A_p + ... + (1 + d_{pq}) \times A_q, i.e., every node in the range is weighted by one more than its distance from p.
Consider a query for nodes u and v. Let the least common ancestor be m = lca(u, v). There are 3 possible cases:
- m = v: The answer to the query = S2[u, v].
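- m = u: The answer to the query = (2+d_{uv}) \times S1[v, u] - S2[v, u]. This works because, for every node i on the path, d_{ui} + d_{vi} = d_{uv}, so 1 + d_{ui} = (2 + d_{uv}) - (1 + d_{vi}).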
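- m \neq u and m \neq v: Find m' such that m' \in N(m) \land m' \in P_{vm}, i.e., m' is the neighbour of m on the path from m to v. Split the path into the part from u up to m and the part from v up to m'.
The answer to the query = S2[u, m] + (2+d_{uv}) \times S1[v, m'] - S2[v, m'].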
The image below shows m'.
Solutions:
Setter's Code
#pragma GCC optimize("O3")
#include <bits/stdc++.h>
// Short-hand macros used throughout the setter's solution.
#define gcj "Case #"
#define adj_list vector<vi>
// Plain newline instead of std::endl, so output is not flushed on every line.
#define endl "\n"
// NOTE: these two expand to double literals (2e9, 2e18), not integer constants.
#define INF_INT 2e9
#define INF_LL 2e18
#define matmax 25
// Default modulus of pow1().
#define mod 1000000007
#define mp make_pair
#define pb push_back
#define pi pair<int, int>
#define pii pair<int, pair<int, int> >
#define pl pair<ll, ll>
#define pll pair<ll, pair<ll, ll> >
#define vi vector<int>
#define vl vector<ll>
// Faster iostreams: stop syncing with C stdio and untie the streams.
#define fastio ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
using namespace std;
typedef long long int ll;
// Debug-printing helpers: debug(a, b) writes "[a, b] = [<a>, <b>]" followed by a
// newline to stderr. In this file it is referenced only from a commented-out line in main.
// NOTE(review): identifiers beginning with a double underscore are reserved for the
// implementation in C++; a name such as dbg_print would be safer.
// Scalar overloads: print the value itself (chars/strings quoted, bools as words).
void __print(int x) {cerr << x;}
void __print(long x) {cerr << x;}
void __print(long long x) {cerr << x;}
void __print(unsigned x) {cerr << x;}
void __print(unsigned long x) {cerr << x;}
void __print(unsigned long long x) {cerr << x;}
void __print(float x) {cerr << x;}
void __print(double x) {cerr << x;}
void __print(long double x) {cerr << x;}
void __print(char x) {cerr << '\'' << x << '\'';}
void __print(const char *x) {cerr << '\"' << x << '\"';}
void __print(const string &x) {cerr << '\"' << x << '\"';}
void __print(bool x) {cerr << (x ? "true" : "false");}
// Pairs print as {first,second}.
template<typename T, typename V>
void __print(const pair<T, V> &x) {cerr << '{'; __print(x.first); cerr << ','; __print(x.second); cerr << '}';}
// Fallback for any other type: iterate it and print {e1,e2,...}.
template<typename T>
void __print(const T &x) {int f = 0; cerr << '{'; for (auto &i: x) cerr << (f++ ? "," : ""), __print(i); cerr << "}";}
// Variadic driver: prints each argument separated by ", ", then "]\n".
void _print() {cerr << "]\n";}
template <typename T, typename... V>
void _print(T t, V... v) {__print(t); if (sizeof...(v)) cerr << ", "; _print(v...);}
// NOTE(review): expands to two statements, so an unbraced `if (c) debug(x);` only guards the first.
#define debug(x...) cerr << "[" << #x << "] = ["; _print(x)
// Computes a^b mod m by iterative binary exponentiation.
// The base is reduced into [0, m) first, so:
//   - b == 1 returns a properly reduced value (previously it returned `a` unchanged),
//   - the squaring step cannot overflow for inputs a >= m.
// The result for b == 0 is 1 % m (i.e. 0 when m == 1).
// Preconditions: b >= 0, m >= 1, and (m - 1) * (m - 1) must fit in ll.
ll pow1(ll a, ll b, ll m = mod)
{
    ll base = a % m;
    if (base < 0)
        base += m;
    ll result = 1 % m;
    while (b > 0) {
        if (b & 1)
            result = result * base % m;
        base = base * base % m;
        b >>= 1;
    }
    return result;
}
// Mean of two values: (a + b) / 2 evaluated in the promoted type of a + b
// (integers therefore truncate toward zero), then converted back to avgType.
// NOTE(review): a + b can overflow for large integral arguments.
template <class avgType>
avgType avg(avgType a, avgType b) {
    const auto total = a + b;
    return total / 2;
}
// Random engines seeded from the steady clock, so sequences differ between runs.
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
mt19937_64 rng64(chrono::steady_clock::now().time_since_epoch().count());
// Next draw from rng, reduced into [0, INT_MAX - 1].
int randInt() {
    const auto draw = rng();
    return static_cast<int>(draw % INT_MAX);
}
// Next 64-bit draw from rng64, reduced into [0, LLONG_MAX - 1].
ll randLL() {
    const auto draw = rng64();
    return static_cast<ll>(draw % LLONG_MAX);
}
// Hash moduli; not referenced anywhere else in this solution.
vector< int > hashmods = {1000000007, 1000000009, 1000000021, 1000000033, 1000000087, 1000000093};
// n: node count (<= 100000, asserted in main), q: query count, u/v: endpoints read in
// main, vals[i]: value A_i of node i (1-based).
ll n, q, u, v, vals[110000];
// Binary-lifting tables indexed [level][node], 20 levels (2^19 > 100000):
//   stree[0][k][j]: sum of vals over the 2^k nodes j, parent(j), ... (upward from j),
//   stree[1][k][j]: the same nodes weighted 1, 2, ..., 2^k,
//   par[k][j]:      the 2^k-th ancestor of j,
//   lev[j]:         depth of j as assigned by dfs() (main starts the root at level 1).
ll stree[2][20][110000], par[20][110000], lev[110000];
// Adjacency lists, 1-based node ids.
vector< vi > tree(110000);
void dfs(int index, int p, int l) {
par[0][index] = p;
lev[index] = l;
for (int neigh : tree[index]) {
if (neigh == p)
continue;
dfs(neigh, index, l+1);
}
}
void build() {
for (int i = 0; i <= 19; i++) {
for (int j = 1; j <= n; j++) {
if (i == 0) {
stree[0][i][j] = vals[j];
stree[1][i][j] = vals[j];
}
else if (par[i-1][j] == 0) {
continue;
} else {
par[i][j] = par[i-1][par[i-1][j]];
stree[0][i][j] = stree[0][i-1][j] + stree[0][i-1][par[i-1][j]];
stree[1][i][j] = stree[1][i-1][j] + stree[1][i-1][par[i-1][j]] + (1ll<<(i-1)) * stree[0][i-1][par[i-1][j]];
}
}
}
}
// Returns vals summed over `num` consecutive nodes starting at `index` and walking toward
// the root (index, parent(index), ...); 0 when num <= 0.
// Precondition: num does not exceed the number of nodes from `index` up to the root.
// One table jump per set bit of num (lowest bit first), so the cost is O(log num).
// The previous version re-searched the highest bit at every recursive step, which cost
// O(log^2 num) and could evaluate 1 << 31.
ll query1(int index, int num) {
    ll ans = 0;
    for (int j = 0; num > 0 && index > 0; j++) {
        const int step = 1 << j;
        if (num & step) {
            ans += stree[0][j][index];
            index = par[j][index];
            num -= step;
        }
    }
    return ans;
}
// Returns sum over t = 0 .. num-1 of (1 + beg + t) * vals[t-th ancestor of index], i.e. the
// path from `index` upward weighted 1 + beg, 2 + beg, ...; 0 when num <= 0.
// Precondition: num does not exceed the number of nodes from `index` up to the root.
// Each jump of 2^j nodes contributes stree[1] (weights 1..2^j) plus beg times the plain
// sum; beg then advances by the nodes consumed. One jump per set bit of num, lowest bit
// first, so the cost is O(log num) instead of the previous O(log^2 num).
ll query2(int index, int num, ll beg) {
    ll ans = 0;
    for (int j = 0; num > 0 && index > 0; j++) {
        const int step = 1 << j;
        if (num & step) {
            ans += stree[1][j][index] + beg * stree[0][j][index];
            beg += step;
            index = par[j][index];
            num -= step;
        }
    }
    return ans;
}
// Returns the l-th ancestor of `index` (index itself when l <= 0), applying one
// binary-lifting jump per set bit of l, so the cost is O(log l).
// If the walk runs past the root, the result is the non-positive "no node" sentinel.
int goup(int index, int l) {
    for (int j = 0; l > 0 && index > 0; j++) {
        const int step = 1 << j;
        if (l & step) {
            index = par[j][index];
            l -= step;
        }
    }
    return index;
}
// Returns (distance from a to lca(a, b), distance from b to lca(a, b)).
// First lifts the deeper endpoint to the other's level. Then both climb together with
// binary lifting while their ancestors still differ; the LCA is one step above where
// they stop.
pi lca(int a, int b) {
    int liftA = 0, liftB = 0;
    if (lev[a] < lev[b]) {
        liftB = lev[b] - lev[a];
        b = goup(b, liftB);
    } else {
        liftA = lev[a] - lev[b];
        a = goup(a, liftA);
    }
    if (a == b)
        return mp(liftA, liftB);
    int climbed = 0;
    for (int i = 19; i >= 0; i--) {
        if (par[i][a] != par[i][b]) {
            a = par[i][a];
            b = par[i][b];
            climbed += 1 << i;
        }
    }
    return mp(liftA + climbed + 1, liftB + climbed + 1);
}
int main() {
    fastio;
    cin>>n;
    assert(1 <= n and n <= 100000);
    for (int i = 1; i <= n; i++) {
        cin>>vals[i];
        assert(1 <= vals[i] and vals[i] <= 50000000);
    }
    // Read the n - 1 undirected edges (1-based node ids).
    for (int i = 1; i < n; i++) {
        cin>>u>>v;
        assert(1 <= u and u <= n);
        assert(1 <= v and v <= n);
        tree[u].pb(v);
        tree[v].pb(u);
    }
    // Root the tree at node 1 with parent 0. 0 is the "no ancestor" sentinel build()
    // checks for, and index 0 of par/stree is all zeros. Passing -1 here made build()
    // read par[0][-1] and stree[k][0][-1], which are out of bounds.
    dfs(1, 0, 1);
    build();
    cin>>q;
    assert(q <= 200000);
    ll prevans = 0;
    for (int i = 1; i <= q; i++) {
        cin>>u>>v;
        assert(0 <= u and u <= 1e18);
        assert(0 <= v and v <= 1e18);
        // Queries are encoded with the previous answer.
        u = prevans ^ u;
        v = prevans ^ v;
        assert(1 <= u and u <= n);
        assert(1 <= v and v <= n);
        // up = (distance u -> lca, distance v -> lca); len = number of nodes on the path.
        pi up = lca(u, v);
        ll ans = 0, len = up.first + up.second + 1;
        if (up.first == 0) {
            // u is the LCA. Walk up from v: the node t steps above v needs weight
            // 1 + d(u, node) = (len + 1) - (1 + t).
            ans = (len + 1) * query1(v, up.second + 1) - query2(v, up.second + 1, 0);
        } else if (up.second == 0) {
            // v is the LCA: walking up from u already yields weights 1, 2, ..., len.
            ans = query2(u, up.first + 1, 0);
        } else {
            // u .. LCA (inclusive), then v .. the LCA's child on v's side.
            ans = query2(u, up.first + 1, 0);
            ans += (len + 1) * query1(v, up.second) - query2(v, up.second, 0);
        }
        cout<<ans<<endl;
        assert(1 <= ans and ans <= 1e18);
        prevans = ans;
    }
    return 0;
}
Tester's Code
#include <bits/stdc++.h>
#define ll long long int
using namespace std;
// Adjacency lists (1-based node ids); sized in main with g.assign(n + 10, ...).
vector<vector<int>> g;
// a[i]: value of node i. main sets a[0] = 0, and node 0 serves as the root's parent in dfs().
int a[100010];
// sSum[x]: sum of a over the root..x path.
// sum[x]: sum of sSum over the root..x path. Both are filled by dfs().
ll sum[100010] , sSum[100010];
// height[x]: depth of x; main roots the tree at node 1 with depth 0.
int height[100010];
// Lowest-common-ancestor oracle built from an Euler tour of the tree plus a segment
// tree. For every range of tour positions, the segment tree stores the vertex of
// minimum height in that range. lca(u, v) is the minimum-height vertex between the
// first occurrences of u and v in the tour. Each query visits O(log n) tree nodes.
struct LCA {
    vector<int> height, euler, first, segtree;
    vector<bool> visited;
    // adj:  adjacency lists indexed by vertex id (adj.size() must exceed n),
    // n:    largest vertex id,
    // root: vertex the tour starts from (height 0).
    LCA(vector<vector<int>> &adj, int n ,int root = 1) {
        n++;
        height.resize(n);
        first.resize(n);
        euler.reserve(n * 2);
        visited.assign(n, false);
        dfs(adj, root);
        int m = euler.size();
        segtree.resize(m * 4 + 10);
        // Segment-tree positions are 1-based: position p holds euler[p - 1].
        build(1, 1, m );
    }
    // Euler tour from `node` at height h. A vertex is appended when it is entered, and
    // its parent is appended again after each child's subtree finishes. It records
    // height[] and first[] (index of the first occurrence in euler).
    // Iterative on purpose: the recursive form nests once per tree level, and a
    // path-shaped tree of ~1e5 vertices can overflow the call stack. Each stack frame
    // holds (vertex, index of the next neighbour to examine), so the tour produced is
    // identical to the recursive one.
    void dfs(vector<vector<int>> &adj, int node, int h = 0) {
        auto enter = [&](int x, int depth) {
            visited[x] = true;
            height[x] = depth;
            first[x] = euler.size();
            euler.push_back(x);
        };
        vector<pair<int, size_t>> stk;
        enter(node, h);
        stk.push_back(make_pair(node, (size_t)0));
        while (!stk.empty()) {
            const int cur = stk.back().first;
            const size_t idx = stk.back().second;
            if (idx == adj[cur].size()) {
                stk.pop_back();
                // Back in the parent after finishing cur's subtree.
                if (!stk.empty())
                    euler.push_back(stk.back().first);
                continue;
            }
            stk.back().second = idx + 1;
            const int to = adj[cur][idx];
            if (!visited[to]) {
                enter(to, height[cur] + 1);
                stk.push_back(make_pair(to, (size_t)0));
            }
        }
    }
    int Height(int node)
    {
        return height[node];
    }
    // Builds segment-tree node `node`, which covers the 1-based tour positions [b, e].
    void build(int node, int b, int e) {
        if(b>e) return;
        if (b == e) {
            segtree[node] = euler[b-1];
            return;
        } else {
            int mid = (b + e) / 2;
            build(node << 1, b, mid);
            build(node << 1 | 1, mid + 1, e);
            int l = segtree[node << 1], r = segtree[node << 1 | 1];
            segtree[node] = (height[l] < height[r]) ? l : r;
            return;
        }
    }
    // Minimum-height vertex over positions [L, R] intersected with node's range [b, e];
    // returns -1 when the two ranges do not overlap.
    int query(int node, int b, int e, int L, int R) {
        if (b > R || e < L)
            return -1;
        if (b >= L && e <= R)
            return segtree[node];
        int mid = (b + e) >> 1;
        int left = query(node << 1, b, mid, L, R);
        int right = query(node << 1 | 1, mid + 1, e, L, R);
        if (left == -1) return right;
        if (right == -1) return left;
        return height[left] < height[right] ? left : right;
    }
    int lca(int u, int v) {
        // first[] holds 0-based tour indices. Shift them to the 1-based positions the
        // tree was built with, so query() walks exactly the ranges build() stored.
        int left = first[u] + 1, right = first[v] + 1;
        if (left > right)
            swap(left, right);
        return query(1, 1, (int)euler.size(), left, right);
    }
};
// For every vertex x in the subtree entered at `node` (whose parent is `par` and whose
// depth is h), computes:
//   height[x] = depth,
//   sSum[x]   = sSum[parent] + a[x]  (sum of a over the root..x path),
//   sum[x]    = sum[parent] + sSum[x] (sum of sSum over the root..x path).
// Uses an explicit stack of (vertex, parent, depth) frames instead of recursion: a
// path-shaped tree would otherwise nest about n calls deep. A parent's values are always
// written before its children are popped, so each child reads finished values.
void dfs(int node, int par, int h){
    vector<array<int, 3>> todo;
    todo.push_back({node, par, h});
    while(!todo.empty()){
        const array<int, 3> cur = todo.back();
        todo.pop_back();
        const int x = cur[0], p = cur[1], depth = cur[2];
        height[x] = depth;
        sSum[x] = a[x] + sSum[p];
        sum[x] = sSum[x] + sum[p];
        for(int y : g[x]){
            if(y != p) todo.push_back({y, x, depth + 1});
        }
    }
}
// For an ancestor v of u, returns the sum over nodes z on the path from u up to (but
// excluding) v of (1 + dist(u, z)) * a[z]. Every caller in main passes
// d = height[u] - height[v].
// sum[u] - sum[v] adds sSum[y] for each of those d nodes y. Each sSum[y] also carries the
// root..v prefix sSum[v], so that prefix is counted d times and subtracted here.
ll get(int u, int v , int d){
    const ll prefixAboveV = sSum[v];
    const ll stacked = sum[u] - sum[v];
    return stacked - d * prefixAboveV;
}
int main() {
    ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0);
    // Read the tree: n values, then n - 1 undirected edges (1-based ids).
    int n; cin >> n; g.assign(n+10 , vector<int>());
    for(int i=1;i<=n;i++) cin >> a[i];
    for(int i=1;i<=n-1;i++){
        int u,v; cin >> u >> v;
        g[u].push_back(v); g[v].push_back(u);
    }
    LCA data(g , n);
    // Node 0 is the virtual parent of the root, so its prefix sums and value are zero.
    sSum[0] = 0 , sum[0] = 0 , a[0]=0;
    dfs(1,0,0);
    int q; cin >> q;
    ll prev = 0;
    while(q--){
        ll u,v; cin >> u >> v;
        // Queries are encoded with the previous answer.
        u = prev^u , v = prev^v;
        int l = data.lca(u,v);
        // Single-node path: weight 1.
        if(u==v){
            prev = a[u];
            cout << prev << "\n";
            continue;
        }
        if(l==u || l==v){
            if(l==u){
                // u is the ancestor; d = number of nodes on the path. For z strictly below u,
                // the weight 1 + dist(u, z) equals (d + 1) - (1 + dist(v, z)); u itself has weight 1.
                ll d = height[v] - height[l] + 1;
                ll left = (d+1)*(sSum[v] - sSum[l]) - get(v,l,d-1);
                prev = left + a[l];
                cout << prev << "\n";
            }
            else{
                // v is the ancestor: get() covers u up to just below v; v itself has weight d.
                ll d = height[u] - height[l] + 1 ;
                ll left = get(u,l,d-1);
                prev = left + d*a[l];
                cout << prev << "\n";
            }
        }
        else{
            // Part from u up to just below the LCA l.
            ll left = get(u,l , height[u] - height[l] ) ;
            // d = total number of nodes on the u-v path.
            ll d = height[u] + height[v] - 2*height[l] + 1;
            // Part from v up to just below l: weight (d + 1) - (1 + dist(v, z)).
            ll temp = sSum[v] - sSum[l];
            ll right = (d+1)*temp - get(v,l , height[v] - height[l]);
            // The LCA itself sits at distance height[u] - height[l] from u.
            d = height[u] - height[l] + 1;
            prev = left + d*a[l] + right;
            cout << prev << "\n";
        }
    }
}
Time Complexity: O(n \log_2 n + q \log_2 n).
I have tried to explain this as clearly as I could. Please let me know if there are any mistakes or if something is unclear. Thanks.