WSQUER - Editorial

Problem Link:
Contest
Practice link

Setter: Aneesh D H
Tester: Suryansh Kumar (@hellb0y_suru)

Difficulty: Medium

Pre-requisite: Sparse Tables, LCA

Problem:
You are given a tree with N nodes where each node i has a value A_i and Q queries each of which has a pair of nodes u and v. For each query, find the sum of (1 + d_{ui}) \times A_i for each node i in P_{uv}. Here, P_{uv} is the path from u to v and d_{ui} is the distance from node u to node i.

Explanation:
Construct two sparse tables:

  • S1: For a table entry that covers the upward path from a node p to one of its ancestors q, the entry stores the plain sum A_p + ... + A_q.
  • S2: For the same path from p up to q, the entry stores the weighted sum 1 \times A_p + ... + (1 + d_{pq}) \times A_q, i.e. every node on the path is weighted by one plus its distance from p.

Consider a query for nodes u and v. Let the least common ancestor be m = lca(u, v). There are 3 possible cases:

  • m = v: The answer to the query = S2[u, v].
  • m = u: The answer to the query = (2+d_{uv}) \times S1[v, u] - S2[v, u].
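  • m \neq u and m \neq v: Let m' be the neighbour of m that lies on the path from v to m, i.e. m' \in N(m) \land m' \in P_{vm} (the child of m on v's side).
    The answer to the query = S2[u, m] + (2+d_{uv}) \times S1[v, m'] - S2[v, m']. The first term covers the nodes from u up to m with weights 1, 2, ...; the remaining terms cover the nodes from m' down to v.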
    The image below shows m'.
    Untitled Diagram (10)

Solutions:

Setter's Code
#pragma GCC optimize("O3")
#include <bits/stdc++.h>
// ---- Contest template shorthands -------------------------------------------
// Most of these aliases are not referenced by the solution below; the ones that
// are used are endl, mod, mp, pb, pi, vi and fastio.

// Output prefix string; not used in this file.
#define gcj "Case #"
// Adjacency-list alias; relies on the `vi` macro defined further down.
#define adj_list vector<vi>
// Replaces the token `endl` with a plain newline string, so `cout << endl`
// writes "\n" instead of calling std::endl.
#define endl "\n"
// "Infinity" constants. NOTE: both are floating-point literals, not integers.
#define INF_INT 2e9
#define INF_LL 2e18
// Not used in this file.
#define matmax 25
// Default modulus of pow1().
#define mod 1000000007
// Short names for make_pair / push_back.
#define mp make_pair
#define pb push_back
// Pair and nested-pair aliases (int and ll flavours).
#define pi pair<int, int>
#define pii pair<int, pair<int, int> >
#define pl pair<ll, ll>
#define pll pair<ll, pair<ll, ll> >
// Vector aliases.
#define vi vector<int>
#define vl vector<ll>
// Unties the C++ streams from C stdio and from each other; expands to three
// statements, so it must be used as a standalone statement.
#define fastio ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
using namespace std;
// 64-bit signed integer used for all sums in the solution.
typedef long long int ll;

// Debug printers for scalar values. Every overload writes its argument to
// std::cerr: numbers verbatim, characters wrapped in single quotes, strings
// wrapped in double quotes, and booleans spelled out as true / false.
void __print(int x)
{
	std::cerr << x;
}

void __print(long x)
{
	std::cerr << x;
}

void __print(long long x)
{
	std::cerr << x;
}

void __print(unsigned x)
{
	std::cerr << x;
}

void __print(unsigned long x)
{
	std::cerr << x;
}

void __print(unsigned long long x)
{
	std::cerr << x;
}

void __print(float x)
{
	std::cerr << x;
}

void __print(double x)
{
	std::cerr << x;
}

void __print(long double x)
{
	std::cerr << x;
}

void __print(char x)
{
	std::cerr << '\'';
	std::cerr << x;
	std::cerr << '\'';
}

void __print(const char *x)
{
	std::cerr << '\"';
	std::cerr << x;
	std::cerr << '\"';
}

void __print(const std::string &x)
{
	std::cerr << '\"';
	std::cerr << x;
	std::cerr << '\"';
}

void __print(bool x)
{
	if (x)
		std::cerr << "true";
	else
		std::cerr << "false";
}

// Prints a pair to std::cerr as {first,second}, delegating each half to the
// matching __print overload.
template<typename T, typename V>
void __print(const std::pair<T, V> &x)
{
	std::cerr << '{';
	__print(x.first);
	std::cerr << ',';
	__print(x.second);
	std::cerr << '}';
}

// Prints anything usable in a range-for loop to std::cerr as {e1,e2,...}.
template<typename T>
void __print(const T &x)
{
	bool needSeparator = false;
	std::cerr << '{';
	for (auto &item : x)
	{
		if (needSeparator)
			std::cerr << ",";
		needSeparator = true;
		__print(item);
	}
	std::cerr << "}";
}

// End of the argument list: closes the bracket opened by debug() and ends the line.
void _print()
{
	std::cerr << "]\n";
}

// Prints the first argument, a ", " separator when more arguments follow, and
// then recurses on the remaining arguments.
template <typename T, typename... V>
void _print(T t, V... v)
{
	__print(t);
	if (sizeof...(v) > 0)
		std::cerr << ", ";
	_print(v...);
}
// debug(a, b) writes `[a, b] = [<a>, <b>]` to cerr. It expands to two
// statements, so wrap it in braces when it is the body of an if/else/loop.
#define debug(x...) cerr << "[" << #x << "] = ["; _print(x)

// Modular exponentiation: returns a^b modulo m, computed by binary
// exponentiation in O(log b) multiplications.
//
// a - base. It is reduced into [0, m) before use, so bases >= m (or negative
//     bases) no longer leak through unreduced or overflow the first squaring.
// b - exponent. Exponents <= 0 are treated as 0.
// m - modulus, must be positive; defaults to `mod`. m * m has to fit in a
//     signed 64-bit integer, because products are formed before reduction.
//
// The result is always in [0, m); in particular pow1(a, 0, 1) is 0 and
// pow1(a, 1, m) is a % m.
ll pow1(ll a, ll b, ll m = mod)
{
	ll result = 1 % m;

	a %= m;
	if (a < 0)
		a += m;

	while (b > 0)
	{
		if (b & 1)
			result = result * a % m;
		a = a * a % m;
		b >>= 1;
	}
	return result;
}

// Returns half of a + b. The sum and the division are evaluated in whatever
// type `a + b` has, and the quotient is converted back to avgType on return
// (so integer arguments truncate toward zero).
template <class avgType>
avgType avg(avgType a, avgType b) {
	auto total = a + b;
	return total / 2;
}

// Mersenne Twister engines (32-bit and 64-bit), each seeded once at start-up
// from the steady clock's tick count.
std::mt19937 rng(std::chrono::steady_clock::now().time_since_epoch().count());
std::mt19937_64 rng64(std::chrono::steady_clock::now().time_since_epoch().count());

// Returns the next output of `rng` reduced modulo INT_MAX, i.e. a value in
// [0, INT_MAX - 1].
int randInt() {
	const auto draw = rng();
	return draw % INT_MAX;
}

// Returns the next output of `rng64` reduced modulo LLONG_MAX, i.e. a value in
// [0, LLONG_MAX - 1].
ll randLL() {
	const auto draw = rng64();
	return draw % LLONG_MAX;
}

// Candidate moduli (judging by the name, intended for hashing); not referenced
// anywhere in the solution below.
vector< int > hashmods = {1000000007, 1000000009, 1000000021, 1000000033, 1000000087, 1000000093};

// n: node count, q: query count, u/v: scratch endpoints reused for edges and
// queries, vals[i]: value of node i (nodes are numbered from 1).
ll n, q, u, v, vals[110000];
// Binary-lifting tables filled by dfs() and build():
//   par[k][x]      - the 2^k-th ancestor of x (par[0][x] is the parent set by dfs)
//   lev[x]         - depth of x as assigned by dfs
//   stree[0][k][x] - sum of vals over the 2^k nodes x, parent(x), ... going up
//   stree[1][k][x] - the same 2^k nodes weighted 1, 2, ..., 2^k starting at x
// Entries for which the 2^k-node block does not exist are left at their zero
// initial value.
ll stree[2][20][110000], par[20][110000], lev[110000];
// Adjacency lists of the tree, indexed by node number.
vector< vi > tree(110000);

// Walks the tree from `index` and records, for every node reached, its parent
// in par[0][.] and its depth in lev[.].
//   index - node to start from
//   p     - value stored as the parent of `index`; a neighbour equal to p is
//           not descended into
//   l     - depth assigned to `index`; each step away from it adds one
// The traversal keeps its own stack of (node, node it was reached from) pairs
// instead of recursing.
void dfs(int index, int p, int l) {
	std::vector< std::pair<int, int> > pending;

	par[0][index] = p;
	lev[index] = l;
	pending.push_back(std::make_pair(index, p));

	while (!pending.empty()) {
		const int node = pending.back().first;
		const int from = pending.back().second;
		pending.pop_back();

		for (int child : tree[node]) {
			if (child == from)
				continue;
			// A node's parent and depth are final as soon as it is discovered.
			par[0][child] = node;
			lev[child] = lev[node] + 1;
			pending.push_back(std::make_pair(child, node));
		}
	}
}

void build() {
	for (int i = 0; i <= 19; i++) {
		for (int j = 1; j <= n; j++) {
			if (i == 0) {
				stree[0][i][j] = vals[j];
				stree[1][i][j] = vals[j];
			}
			else if (par[i-1][j] == 0) {
				continue;
			} else {
				par[i][j] = par[i-1][par[i-1][j]];
				stree[0][i][j] = stree[0][i-1][j] + stree[0][i-1][par[i-1][j]];
				stree[1][i][j] = stree[1][i-1][j] + stree[1][i-1][par[i-1][j]] + (1ll<<(i-1)) * stree[0][i-1][par[i-1][j]];
			}
		}
	}
}

// Plain sum of vals over `num` consecutive nodes, starting at `index` and
// walking towards the root. Returns 0 when num <= 0.
// The run is cut into power-of-two blocks, largest first, and each block is
// read from stree[0].
ll query1(int index, int num) {
	ll total = 0;

	while (num > 0) {
		// Largest k with 2^k <= num.
		int k = 0;
		while ((1<<k) <= num)
			k++;
		k--;

		total += stree[0][k][index];
		index = par[k][index];
		num -= (1<<k);
	}

	return total;
}

// Weighted sum over `num` consecutive nodes, starting at `index` and walking
// towards the root: the t-th node visited (t = 0, 1, ...) is weighted by
// beg + t + 1. Returns 0 when num <= 0.
// Each power-of-two block contributes its own 1..2^k weighted sum from
// stree[1] plus `beg` times its plain sum from stree[0]; `beg` then advances
// by the block length.
ll query2(int index, int num, ll beg) {
	ll total = 0;

	while (num > 0) {
		// Largest k with 2^k <= num.
		int k = 0;
		while ((1<<k) <= num)
			k++;
		k--;

		total += stree[1][k][index] + beg * (stree[0][k][index]);
		index = par[k][index];
		num -= (1<<k);
		beg += (1<<k);
	}

	return total;
}

// Returns the node reached from `index` after following par links for `l`
// steps; `index` itself when l <= 0. The steps are taken as power-of-two
// jumps, largest first.
int goup(int index, int l) {
	while (l > 0) {
		// Largest k with 2^k <= l.
		int k = 0;
		while ((1<<k) <= l)
			k++;
		k--;

		index = par[k][index];
		l -= (1<<k);
	}
	return index;
}

// Returns (steps from a up to the lowest common ancestor of a and b,
//          steps from b up to that ancestor).
// The deeper node is first lifted to the other node's depth; then both are
// lifted together by every power of two that still keeps them on different
// nodes, which leaves them one step below the common ancestor.
pi lca(int a, int b) {
	int stepsA = 0, stepsB = 0;

	if (lev[a] < lev[b]) {
		stepsB = lev[b] - lev[a];
		b = goup(b, stepsB);
	} else {
		stepsA = lev[a] - lev[b];
		a = goup(a, stepsA);
	}

	// One node was an ancestor of the other.
	if (a == b)
		return mp(stepsA, stepsB);

	int climbed = 0;
	for (int k = 19; k >= 0; k--) {
		if (par[k][a] != par[k][b]) {
			climbed += (1<<k);
			a = par[k][a];
			b = par[k][b];
		}
	}

	// +1 for the final step from the two distinct children onto the ancestor.
	return mp(stepsA + climbed + 1, stepsB + climbed + 1);
}

// Setter's solution driver.
//
// Input, as read below: n, the n node values, n - 1 edges, q, then q pairs.
// Each query pair is XOR-ed with the previous answer (0 before the first
// query) to obtain the real endpoints u and v. For every query it prints the
// sum over the path u -> v of (1 + distance from u) * value.
//
// The tree is rooted at node 1 with parent 0. build() skips entries whose
// ancestor is 0, so 0 is the only sentinel that keeps the table construction
// from indexing par[][] and stree[][][] with an invalid node.
int main() {
	fastio;
	cin>>n;
	assert(1 <= n and n <= 100000);

	for (int i = 1; i <= n; i++) {
		cin>>vals[i];
		assert(1 <= vals[i] and vals[i] <= 50000000);
	}

	for (int i = 1; i < n; i++) {
		cin>>u>>v;
		// Endpoints index the adjacency array directly, so both bounds matter.
		assert(1 <= u and u <= n);
		assert(1 <= v and v <= n);
		tree[u].pb(v);
		tree[v].pb(u);
	}

	dfs(1, 0, 1);
	build();

	cin>>q;
	assert(q <= 200000);
	ll prevans = 0;
	for (int i = 1; i <= q; i++) {
		cin>>u>>v;
		assert(0 <= u and u <= 1e18);
		assert(0 <= v and v <= 1e18);
		// Queries are encoded with the previous answer.
		u = prevans ^ u;
		v = prevans ^ v;
		assert(1 <= u and u <= n);
		assert(1 <= v and v <= n);

		// up.first / up.second: steps from u / v up to their common ancestor.
		pi up = lca(u, v);

		// len is the number of nodes on the path u -> v.
		ll ans = 0, len = up.first + up.second + 1;

		if (up.first == 0) {
			// u is the common ancestor: walking up from v, the node t steps
			// above v needs weight len - t = (len + 1) - (t + 1).
			ans = (len + 1) * query1(v, up.second + 1) - query2(v, up.second + 1, 0);
		} else if (up.second == 0) {
			// v is the common ancestor: weights 1, 2, ... going up from u.
			ans = query2(u, up.first + 1, 0);
		} else {
			// u up to and including the common ancestor, then the v side
			// without the ancestor (up.second nodes), weighted as in case one.
			ans = query2(u, up.first + 1, 0);
			ll rem = (len + 1) * query1(v, up.second) - query2(v, up.second, 0);
			ans += rem;
		}
		cout<<ans<<endl;
		assert(1 <= ans and ans <= 1e18);
		prevans = ans;
	}

	return 0;
}
Tester's Code
#include <bits/stdc++.h>
#define ll long long int
using namespace std;
 
// Adjacency lists of the tree; main() sizes this to n + 10 and uses node
// numbers starting at 1.
vector<vector<int>> g;
// a[i]: value of node i. main() zeroes a[0], the dummy parent of the root.
int a[100010];
// sSum[x]: sum of a[] over the path from the root down to x (inclusive).
// sum[x]:  sum of sSum[] over that same root-to-x path.
// Both are filled by the free function dfs(); index 0 is zeroed in main().
ll sum[100010] , sSum[100010];
// height[x]: depth of x as assigned by the free function dfs() (root = 0).
int height[100010];
// Lowest-common-ancestor queries through an Euler tour and a segment tree.
//
// The constructor records the tour of the tree rooted at `root` (a node is
// appended when first entered and again after each child returns), stores each
// node's depth and the position of its first occurrence, and builds a segment
// tree over the tour positions that keeps the shallowest node of every range.
// lca(u, v) is then the shallowest node between the first occurrences of u
// and v.
//
// All tour positions are 0-based everywhere: build(), query() and lca() agree
// on the range [0, euler.size() - 1].
struct LCA {
    // height[x]: depth of x (root = 0); euler: the tour; first[x]: index of
    // x's first occurrence in euler; segtree: 1-based heap layout, each entry
    // is the shallowest node of the tour range it covers.
    std::vector<int> height, euler, first, segtree;
    std::vector<bool> visited;

    // adj  - adjacency lists indexed by node number
    // n    - largest node number (arrays get n + 1 slots)
    // root - node the tour starts from
    LCA(std::vector<std::vector<int>> &adj, int n ,int root = 1) {
        n++;
        height.resize(n);
        first.resize(n);
        euler.reserve(n * 2);
        visited.assign(n, false);
        dfs(adj, root);
        const int m = static_cast<int>(euler.size());
        segtree.resize(m * 4 + 10);
        build(1, 0, m - 1);
    }

    // Records depth, first occurrence and tour entries for the subtree of `node`.
    void dfs(std::vector<std::vector<int>> &adj, int node, int h = 0) {
        visited[node] = true;
        height[node] = h;
        first[node] = static_cast<int>(euler.size());
        euler.push_back(node);
        for (auto to : adj[node]) {
            if (!visited[to]) {
                dfs(adj, to, h + 1);
                // Coming back from a child puts the parent on the tour again.
                euler.push_back(node);
            }
        }
    }

    // Depth of `node` as recorded by the tour.
    int Height(int node)
    {
        return height[node];
    }

    // Builds the segment-tree entry `node` covering tour positions [b, e]
    // (0-based, inclusive).
    void build(int node, int b, int e) {
        if (b > e) return;
        if (b == e) {
            segtree[node] = euler[b];
            return;
        }
        const int mid = (b + e) / 2;
        build(node << 1, b, mid);
        build(node << 1 | 1, mid + 1, e);
        const int l = segtree[node << 1], r = segtree[node << 1 | 1];
        segtree[node] = (height[l] < height[r]) ? l : r;
    }

    // Shallowest node among tour positions [L, R], looking inside the entry
    // `node` that covers [b, e]; returns -1 when the two ranges do not overlap.
    int query(int node, int b, int e, int L, int R) {
        if (b > R || e < L)
            return -1;
        if (b >= L && e <= R)
            return segtree[node];
        const int mid = (b + e) >> 1;
        const int left = query(node << 1, b, mid, L, R);
        const int right = query(node << 1 | 1, mid + 1, e, L, R);
        if (left == -1) return right;
        if (right == -1) return left;
        return height[left] < height[right] ? left : right;
    }

    // Lowest common ancestor of u and v.
    int lca(int u, int v) {
        int left = first[u], right = first[v];
        if (left > right)
            std::swap(left, right);
        return query(1, 0, static_cast<int>(euler.size()) - 1, left, right);
    }
};
 
void dfs(int node, int par, int h){
    height[node] = h;
    sSum[node] = a[node] + sSum[par];
    sum[node] = sSum[node] + sum[par];
    for(auto &i: g[node]){
        if(i!=par) {
            dfs(i,node,h+1);
        }
    }
}
 
// Returns sum[u] - sum[v] - d * sSum[v].
// When v is an ancestor of u and d is the number of edges between them, this
// is the sum over the nodes strictly below v on the path down to u of
// a[y] * (1 + distance from u to y).
ll get(int u, int v , int d){
    ll weighted = sum[u];
    weighted -= sum[v];
    weighted -= d*sSum[v];
    return weighted;
}
 
int main() {
    ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0);
    int n; cin >> n; g.assign(n+10 , vector<int>());
    for(int i=1;i<=n;i++) cin >> a[i];
    for(int i=1;i<=n-1;i++){
        int u,v; cin >> u >> v;
        g[u].push_back(v); g[v].push_back(u);
    }
    LCA data(g , n);
    sSum[0] = 0 , sum[0] = 0 , a[0]=0;
    dfs(1,0,0);
    int q; cin >> q;
  //  cout << get(5,2,2) << "\n";
 //  for(int i=1;i<=n;i++) cout << sum[i] << " ";
 //  cout << "\n";
    ll prev = 0;
    while(q--){
        ll u,v; cin >> u >> v;
        u = prev^u , v = prev^v;
        int l = data.lca(u,v);
       // cout << l << "\n";
        if(u==v){
            prev = a[u];
            cout << prev << "\n";
            continue;
        }
        if(l==u || l==v){
            if(l==u){
                ll d = height[v] - height[l] + 1;
                ll left = (d+1)*(sSum[v] - sSum[l])  -   get(v,l,d-1);
                prev = left + a[l];
                cout << prev << "\n";
            }
            else{
                ll d = height[u] - height[l] + 1 ;
                ll left = get(u,l,d-1);
                prev = left + d*a[l];
                cout << prev << "\n";
            }
        }
        else{
            ll left = get(u,l , height[u] - height[l] ) ;
            ll d = height[u] + height[v] - 2*height[l]   + 1;
            ll temp = sSum[v] - sSum[l];
            ll right = (d+1)*temp - get(v,l , height[v] - height[l]);
            d = height[u] - height[l] + 1;
            prev  = left + d*a[l] + right;
            cout << prev << "\n";
        }
    }
}

Time Complexity: O(n \log_2 n + q \log_2 n).


I tried to explain it as properly as I could. Please let me know if there are any mistakes or if something is unclear. Thanks.

4 Likes