PROBLEM LINK:
Setter: Aman Kumar Singh
Tester: Radoslav Dimitrov
Editorialist: Teja Vardhan Reddy
DIFFICULTY:
Medium
PREREQUISITES:
Math, fast LCA queries.
PROBLEM:
Given a tree with n vertices, answer queries of the following type.
Query: given u and v, find the number of unordered pairs (a,b), with a = b allowed, such that the path (a,b) has exactly one vertex in common with the path (u,v). Let us call such a path a perfect path.
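A brute-force reference is handy for checking a fast solution on small trees. The sketch below is a hypothetical helper (`pathBetween` and `countPerfectPathsBrute` are illustrative names, not taken from any of the solutions). It uses 0-indexed vertices, enumerates every unordered pair (a,b) with a <= b, and counts the pairs whose path shares exactly one vertex with the (u,v) path. It runs in O(n^3) per query, so it is only meant for testing.

```cpp
#include <cassert>
#include <queue>
#include <vector>

// Vertices on the tree path between a and b: BFS parents from a, then walk up from b.
std::vector<int> pathBetween(const std::vector<std::vector<int>>& adj, int a, int b) {
    int n = (int)adj.size();
    std::vector<int> parent(n, -1);
    std::vector<bool> seen(n, false);
    std::queue<int> bfs;
    bfs.push(a);
    seen[a] = true;
    while (!bfs.empty()) {
        int x = bfs.front();
        bfs.pop();
        for (int y : adj[x])
            if (!seen[y]) { seen[y] = true; parent[y] = x; bfs.push(y); }
    }
    std::vector<int> path;
    for (int x = b; x != -1; x = parent[x]) path.push_back(x);  // ends at a
    return path;
}

// Unordered pairs (a, b), a == b allowed, whose path shares exactly one vertex with (u, v).
long long countPerfectPathsBrute(const std::vector<std::vector<int>>& adj, int u, int v) {
    int n = (int)adj.size();
    std::vector<bool> onQuery(n, false);
    for (int x : pathBetween(adj, u, v)) onQuery[x] = true;
    long long count = 0;
    for (int a = 0; a < n; ++a)
        for (int b = a; b < n; ++b) {
            int common = 0;
            for (int x : pathBetween(adj, a, b)) common += onQuery[x];
            if (common == 1) ++count;
        }
    return count;
}
```

For example, on the 6-vertex tree with edges 0-1, 0-2, 2-3, 2-4, 2-5, the query u = v = 2 counts every path through vertex 2, which is 15.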
EXPLANATION
Let us root the tree at vertex 1.
For every vertex, we store its ancestors at power-of-two distances (binary lifting). This lets us answer LCA queries and find the k-th ancestor of a vertex in O(log n) time.
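A minimal binary-lifting sketch of the ancestor table described above. `BinaryLifting`, `kthAncestor` and `lca` are illustrative names; the sketch assumes a connected tree with 0-indexed vertices rooted at vertex 0 (vertex 1 in the editorial's numbering), and it builds parents with an iterative DFS so that deep trees do not overflow the call stack.

```cpp
#include <cassert>
#include <utility>
#include <vector>

struct BinaryLifting {
    int LOG = 1;
    std::vector<int> depth;
    std::vector<std::vector<int>> up;  // up[j][x] = 2^j-th ancestor of x (root maps to itself)

    explicit BinaryLifting(const std::vector<std::vector<int>>& adj) {
        int n = (int)adj.size();
        while ((1 << LOG) < n) ++LOG;
        depth.assign(n, 0);
        up.assign(LOG, std::vector<int>(n, 0));
        std::vector<int> stack = {0};
        std::vector<bool> seen(n, false);
        seen[0] = true;
        while (!stack.empty()) {  // iterative DFS: record parent and depth
            int x = stack.back();
            stack.pop_back();
            for (int y : adj[x])
                if (!seen[y]) {
                    seen[y] = true;
                    up[0][y] = x;
                    depth[y] = depth[x] + 1;
                    stack.push_back(y);
                }
        }
        for (int j = 1; j < LOG; ++j)
            for (int x = 0; x < n; ++x) up[j][x] = up[j - 1][up[j - 1][x]];
    }

    // k-th ancestor of x in O(log n); assumes k <= depth[x].
    int kthAncestor(int x, int k) const {
        for (int j = 0; j < LOG; ++j)
            if (k >> j & 1) x = up[j][x];
        return x;
    }

    int lca(int a, int b) const {
        if (depth[a] < depth[b]) std::swap(a, b);
        a = kthAncestor(a, depth[a] - depth[b]);  // bring a to b's depth
        if (a == b) return a;
        for (int j = LOG - 1; j >= 0; --j)
            if (up[j][a] != up[j][b]) { a = up[j][a]; b = up[j][b]; }
        return up[0][a];
    }
};
```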
The key strategy is to count, for each vertex on the path from u to v, the perfect paths passing through that vertex, and then sum these counts. No path is counted twice: a path passing through two or more vertices of the (u,v) path is not perfect, so every perfect path passes through exactly one vertex of the (u,v) path.
Let us first develop some tools before solving the problem.
1. How many unordered paths lie inside the subtree of vertex u?
Ans: A path is determined by its two endpoints. A path lies inside the subtree exactly when both endpoints do: the LCA of two vertices of the subtree is also in the subtree, so the whole path stays inside. With s = subtree[u] vertices there are s*(s-1)/2 pairs of distinct endpoints plus s single-vertex paths. Hence the number of paths in the subtree of vertex u is subtree[u]*(subtree[u]+1)/2.
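A quick check of the question-1 count: enumerating unordered pairs with repetition from s vertices gives the same number as the closed form. `pathsWithin` and `pathsWithinByEnumeration` are illustrative names.

```cpp
#include <cassert>

// Closed form from question 1: s*(s-1)/2 pairs of distinct endpoints + s single-vertex paths.
long long pathsWithin(long long s) { return s * (s + 1) / 2; }

// The same pairs (a <= b) counted one by one, for comparison.
long long pathsWithinByEnumeration(int s) {
    long long count = 0;
    for (int a = 0; a < s; ++a)
        for (int b = a; b < s; ++b) ++count;
    return count;
}
```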
2. How do we count the paths that pass through vertex u and lie inside the subtree of vertex u?
Ans: We already know the total number of paths in the subtree of u, so we subtract the paths that do not pass through u. Any such path lies entirely inside the subtree of one of the children of u, and by question 1 a child c contributes subtree[c]*(subtree[c]+1)/2 such paths. Summing this over all children gives the paths that do not pass through u. Let us call the remaining count ans[u]:
ans[u] = subtree[u]*(subtree[u]+1)/2 - sum over children c of subtree[c]*(subtree[c]+1)/2.
We can compute ans[u] for every vertex in a single DFS, because each vertex only needs to iterate over its children once.
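The DFS for questions 1 and 2 can be sketched as follows. `SubtreeCounts` and `computeSubtreeCounts` are hypothetical names; the sketch assumes 0-indexed vertices rooted at vertex 0 and replaces the recursive DFS with an iterative preorder, processed in reverse so that every child is finished before its parent.

```cpp
#include <cassert>
#include <vector>

struct SubtreeCounts {
    std::vector<long long> sub;  // subtree sizes
    std::vector<long long> ans;  // question 2: paths inside subtree(x) passing through x
    std::vector<int> parent;     // -1 for the root
};

SubtreeCounts computeSubtreeCounts(const std::vector<std::vector<int>>& adj) {
    int n = (int)adj.size();
    SubtreeCounts r{std::vector<long long>(n, 1), std::vector<long long>(n, 0),
                    std::vector<int>(n, -1)};
    std::vector<int> order, stack = {0};
    std::vector<bool> seen(n, false);
    seen[0] = true;
    while (!stack.empty()) {  // preorder: each vertex before its descendants
        int x = stack.back();
        stack.pop_back();
        order.push_back(x);
        for (int y : adj[x])
            if (!seen[y]) { seen[y] = true; r.parent[y] = x; stack.push_back(y); }
    }
    std::vector<long long> childPairs(n, 0);  // sum over children c of sub[c]*(sub[c]+1)/2
    for (int i = n - 1; i >= 0; --i) {        // reverse preorder: children before parents
        int x = order[i];
        // sub[x] is final here: all descendants were processed earlier.
        r.ans[x] = r.sub[x] * (r.sub[x] + 1) / 2 - childPairs[x];
        int p = r.parent[x];
        if (p != -1) {
            r.sub[p] += r.sub[x];
            childPairs[p] += r.sub[x] * (r.sub[x] + 1) / 2;
        }
    }
    return r;
}
```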
3. How do we count the paths that pass through vertex u (not necessarily inside the subtree of vertex u)?
Ans: We can reduce this to the previous case by imagining the tree rooted at u; then the question is the same as question 2. Let us call this count ans1[u]. In the re-rooted tree, the parent side of u becomes one more child subtree with n-subtree[u] vertices. Equivalently, the paths added compared to ans[u] are exactly those with one endpoint inside the subtree of u and the other outside it, so ans1[u] = ans[u] + subtree[u]*(n-subtree[u]). This can be computed for every vertex in the same DFS.
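The re-rooting argument can also be checked algebraically. Writing S = subtree[u], m = n - S and s_c = subtree[c] for the children c of u, apply the question-2 formula once with u as the root (where the parent side is an extra child of size m) and once in the original rooting:

```latex
\begin{aligned}
\mathrm{ans1}[u] &= \frac{n(n+1)}{2} - \frac{m(m+1)}{2} - \sum_{c} \frac{s_c(s_c+1)}{2}, \\
\mathrm{ans}[u]  &= \frac{S(S+1)}{2} - \sum_{c} \frac{s_c(s_c+1)}{2}, \\
\mathrm{ans1}[u] - \mathrm{ans}[u] &= \frac{(S+m)(S+m+1) - m(m+1) - S(S+1)}{2} = S\,m .
\end{aligned}
```

So ans1[u] = ans[u] + subtree[u]*(n - subtree[u]); for the root, m = 0 and ans1 equals ans.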
4. How do we count the paths that pass through vertex u, lie inside the subtree of vertex u, and do not pass through one of its children c_1?
Ans: We count the paths passing through u using the idea above, and then subtract the paths passing through both c_1 and u. For this to happen, one endpoint must lie in the subtree of c_1 and the other in the subtree of u outside the subtree of c_1, which gives subtree[c_1]*(subtree[u]-subtree[c_1]) paths. Hence the answer is ans[u] - subtree[c_1]*(subtree[u]-subtree[c_1]).
5. How do we count the paths that pass through vertex u (not necessarily inside the subtree of vertex u) and do not pass through one of its children c_1?
Ans: We use the same strategy of imagining the tree rooted at u and solving question 4 there; the subtree of u then has n vertices. Hence the answer is ans1[u] - subtree[c_1]*(n-subtree[c_1]).
6. How do we count the paths that pass through vertex u (not necessarily inside the subtree of vertex u) and pass through neither of two of its children c_1, c_2?
Ans: We can use inclusion-exclusion.
The answer equals
count of paths passing through u
- count of paths passing through c_1 and u
- count of paths passing through c_2 and u
+ count of paths passing through c_1, c_2 and u.
The first three terms are known from questions 3 and 5. For the last term, note that every path passing through both c_1 and c_2 also passes through u, because u is the LCA of c_1 and c_2. Such a path has one endpoint in each of the two subtrees, so there are subtree[c_1]*subtree[c_2] of them. Altogether the answer is ans1[u] - subtree[c_1]*(n-subtree[c_1]) - subtree[c_2]*(n-subtree[c_2]) + subtree[c_1]*subtree[c_2].
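Once subtree sizes, ans and ans1 are known, questions 4 to 6 are O(1) formulas. The helpers below are hypothetical names that transcribe those formulas directly; the test values assume the 6-vertex example tree used in the next part (subtree[3] = 4, ans[3] = 7, ans1[3] = 15, leaf children of size 1).

```cpp
#include <cassert>

// Question 4: through u, inside subtree(u), not through child c.
long long insideAvoidingChild(long long ansU, long long subU, long long subC) {
    return ansU - subC * (subU - subC);
}

// Question 5: through u anywhere in the tree, not through child c.
long long anywhereAvoidingChild(long long ans1U, long long n, long long subC) {
    return ans1U - subC * (n - subC);
}

// Question 6: inclusion-exclusion over two children c1, c2 of u.
long long anywhereAvoidingTwoChildren(long long ans1U, long long n, long long s1, long long s2) {
    return ans1U - s1 * (n - s1) - s2 * (n - s2) + s1 * s2;
}
```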
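Let us list the paths for each of the above 6 questions on a small example tree with n = 6, taking u = 3, c_1 = 4 and c_2 = 5. In this tree, vertices 4, 5 and 6 are the leaf children of 3, and vertices 1 and 2 lie outside the subtree of 3.
Each path is written by its two endpoints.
1. (3,3),(4,4),(5,5),(6,6),(3,4),(3,5),(3,6),(4,5),(4,6),(5,6)
2. (3,3),(3,4),(3,5),(3,6),(4,5),(4,6),(5,6)
3. (3,3),(3,4),(3,5),(3,6),(4,5),(4,6),(5,6),(1,3),(1,4),(1,5),(1,6),(2,3),(2,4),(2,5),(2,6)
4. (3,3),(3,5),(3,6),(5,6)
5. (3,3),(3,5),(3,6),(5,6),(1,3),(1,5),(1,6),(2,3),(2,5),(2,6)
6. (3,3),(3,6),(1,3),(1,6),(2,3),(2,6)
The sizes agree with the formulas: 4*5/2 = 10, ans[3] = 10 - 3 = 7, ans1[3] = 7 + 4*2 = 15, 7 - 1*3 = 4, 15 - 1*5 = 10 and 15 - 1*5 - 1*5 + 1*1 = 6.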
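The six counts for this example can be re-derived by enumeration. The sketch assumes edges 1-2 and 1-3 for the part of the tree outside the subtree of 3; the lists only fix that vertices 1 and 2 lie outside it, and any such placement gives the same counts because only component sizes matter. `pathOf` and `exampleCounts` are hypothetical names; vertices are 1-indexed to match the lists.

```cpp
#include <cassert>
#include <queue>
#include <vector>

// Vertices on the tree path between a and b (BFS parents from a, then walk up from b).
std::vector<int> pathOf(const std::vector<std::vector<int>>& adj, int a, int b) {
    std::vector<int> parent(adj.size(), -1), path;
    std::vector<bool> seen(adj.size(), false);
    std::queue<int> bfs;
    bfs.push(a);
    seen[a] = true;
    while (!bfs.empty()) {
        int x = bfs.front();
        bfs.pop();
        for (int y : adj[x])
            if (!seen[y]) { seen[y] = true; parent[y] = x; bfs.push(y); }
    }
    for (int x = b; x != -1; x = parent[x]) path.push_back(x);
    return path;
}

// counts[k] = number of pairs listed under question k+1 (u = 3, c_1 = 4, c_2 = 5).
std::vector<int> exampleCounts() {
    std::vector<std::vector<int>> adj(7);  // index 0 unused
    int edges[5][2] = {{1, 2}, {1, 3}, {3, 4}, {3, 5}, {3, 6}};  // 1-2, 1-3 assumed
    for (auto& e : edges) { adj[e[0]].push_back(e[1]); adj[e[1]].push_back(e[0]); }
    std::vector<int> counts(6, 0);
    for (int a = 1; a <= 6; ++a)
        for (int b = a; b <= 6; ++b) {
            std::vector<bool> on(7, false);
            for (int x : pathOf(adj, a, b)) on[x] = true;
            bool inside = a >= 3 && b >= 3;            // subtree of 3 is {3, 4, 5, 6}
            if (inside) ++counts[0];                   // question 1
            if (!on[3]) continue;                      // questions 2-6 need vertex 3
            if (inside) ++counts[1];                   // question 2
            ++counts[2];                               // question 3
            if (inside && !on[4]) ++counts[3];         // question 4
            if (!on[4]) ++counts[4];                   // question 5
            if (!on[4] && !on[5]) ++counts[5];         // question 6
        }
    return counts;
}
```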
Now let us answer the queries.
Case 1: u = v.
A path is perfect exactly when it passes through u, so the answer is ans1[u] (question 3 answers this). This takes O(1) time.
Case 2: u is an ancestor of v (if v is an ancestor of u, we can just swap u and v).
To check for this case, we find the LCA and see whether it equals u or v.
For v, we want the paths in the subtree of v that pass through v: any path reaching v from outside its subtree must come through the parent of v, which is on the (u,v) path, so such a path is not perfect. So ans[v] is what we need here (question 2 answers this). This takes O(1) time.
For u, we need the paths that pass through u but not through its child on the path to v. We can find that child with a k-th ancestor query from v (k = depth[v] - depth[u] - 1). This is question 5, and it takes O(log n) time: O(log n) for the k-th ancestor and O(1) for the formula.
For every other vertex x on the path, a path leaving the subtree of x would pass through the parent of x, which is also on the path. So we need question 4: the paths inside the subtree of x that pass through x but not through its child c(x) on the path, which is ans[x] - subtree[c(x)]*(subtree[x]-subtree[c(x)]). Evaluating this vertex by vertex costs time proportional to the path length, which can be O(n) per query, so we need to speed up the summation over all vertices strictly between u and v.
To do this, we maintain preans[x] = sum of ans[y] over all vertices y on the path from the root to x.
Let the path from u to v be u, x_1, x_2, ..., x_k, v.
Then the sum of ans[x] over the vertices strictly between u and v (x_1, x_2, ..., x_k) is preans[x_k] - preans[u].
Let p(x) denote the parent of x.
It remains to compute the sum of subtree[c(x)]*(subtree[x]-subtree[c(x)]) over {x_1, x_2, ..., x_k}. Since c(x_i) = x_{i+1} and c(x_k) = v, substituting y = c(x) (so x = p(y)) rewrites this as the sum of subtree[y]*(subtree[p(y)]-subtree[y]) over {x_2, x_3, ..., x_k, v}. We again maintain prefix sums of subtree[y]*(subtree[p(y)]-subtree[y]) from the root to each vertex; the required sum is then the prefix value at v minus the prefix value at x_1, which takes O(1) time. Precomputing these prefix sums takes one DFS over the tree.
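Putting the pieces of case 2 together gives one closed form. Here preans3[x] denotes the prefix sum of subtree[y]*(subtree[p(y)]-subtree[y]) over the non-root vertices y on the path from the root to x (the editorialist's solution uses an array with this name); adding ans[v] to the sum over x_1, ..., x_k turns preans[x_k] - preans[u] into preans[v] - preans[u]:

```latex
\begin{aligned}
\text{answer} ={}& \mathrm{ans1}[u] - \mathrm{subtree}[x_1]\,\bigl(n - \mathrm{subtree}[x_1]\bigr)
  && \text{(vertex } u\text{, question 5)}\\
&+ \mathrm{preans}[v] - \mathrm{preans}[u]
  && \textstyle\bigl(\sum_{i=1}^{k}\mathrm{ans}[x_i] + \mathrm{ans}[v]\bigr)\\
&- \bigl(\mathrm{preans3}[v] - \mathrm{preans3}[x_1]\bigr)
  && \textstyle\bigl(\sum_{y \in \{x_2,\dots,x_k,v\}} \mathrm{subtree}[y]\,(\mathrm{subtree}[p(y)] - \mathrm{subtree}[y])\bigr)
\end{aligned}
```

When v is a child of u (k = 0), x_1 = v and the last term vanishes.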
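Case 3: u and v are not ancestors of each other; let g be their LCA.
For u and v themselves we use question 2. It takes O(1) time.
For the vertices strictly between g and u we use the last part of case 2 (and similarly for the vertices strictly between g and v). With the prefix sums, this takes O(1) time once the children of g on the two paths are known.
For g, we use question 6 with c_1 and c_2 being the children of g on the paths to u and v. It takes O(log n) time, because we need k-th ancestor queries to find those children.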
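The three cases can be combined into one query routine. The sketch below is our own reimplementation of the formulas above, not the editorialist's code, although `preans3` mirrors that solution's array of the same name; `PerfectPaths`, `ancestorPath` and `query` are hypothetical names. It assumes 0-indexed vertices rooted at vertex 0 and uses iterative DFS.

```cpp
#include <cassert>
#include <utility>
#include <vector>

struct PerfectPaths {
    int n, LOG = 1;
    std::vector<int> depth;
    std::vector<std::vector<int>> up;  // up[j][x] = 2^j-th ancestor (root maps to itself)
    std::vector<long long> sub, ans, ans1, preans, preans3;

    explicit PerfectPaths(const std::vector<std::vector<int>>& adj) : n((int)adj.size()) {
        while ((1 << LOG) < n) ++LOG;
        depth.assign(n, 0);
        up.assign(LOG, std::vector<int>(n, 0));
        sub.assign(n, 1);
        ans.assign(n, 0); ans1.assign(n, 0); preans.assign(n, 0); preans3.assign(n, 0);
        std::vector<int> order, stack = {0};
        std::vector<bool> seen(n, false);
        seen[0] = true;
        while (!stack.empty()) {  // preorder: each vertex before its descendants
            int x = stack.back();
            stack.pop_back();
            order.push_back(x);
            for (int y : adj[x])
                if (!seen[y]) {
                    seen[y] = true;
                    up[0][y] = x;
                    depth[y] = depth[x] + 1;
                    stack.push_back(y);
                }
        }
        std::vector<long long> childPairs(n, 0);  // sum of sub[c]*(sub[c]+1)/2 over children
        for (int i = n - 1; i >= 0; --i) {        // reverse preorder: children first
            int x = order[i];
            ans[x] = sub[x] * (sub[x] + 1) / 2 - childPairs[x];  // question 2
            ans1[x] = ans[x] + sub[x] * (n - sub[x]);           // question 3
            if (x != 0) {
                sub[up[0][x]] += sub[x];
                childPairs[up[0][x]] += sub[x] * (sub[x] + 1) / 2;
            }
        }
        for (int x : order) {  // preorder: the parent's prefix sums are ready
            if (x == 0) { preans[x] = ans[x]; continue; }
            int p = up[0][x];
            preans[x] = preans[p] + ans[x];
            preans3[x] = preans3[p] + sub[x] * (sub[p] - sub[x]);
        }
        for (int j = 1; j < LOG; ++j)
            for (int x = 0; x < n; ++x) up[j][x] = up[j - 1][up[j - 1][x]];
    }

    int kth(int x, int k) const {  // k-th ancestor, assumes k <= depth[x]
        for (int j = 0; j < LOG; ++j)
            if (k >> j & 1) x = up[j][x];
        return x;
    }

    int lca(int a, int b) const {
        if (depth[a] < depth[b]) std::swap(a, b);
        a = kth(a, depth[a] - depth[b]);
        if (a == b) return a;
        for (int j = LOG - 1; j >= 0; --j)
            if (up[j][a] != up[j][b]) { a = up[j][a]; b = up[j][b]; }
        return up[0][a];
    }

    // Case 2 (u a proper ancestor of v): u via question 5, x_1..x_k via question 4,
    // v via question 2, all summed with the prefix arrays.
    long long ancestorPath(int u, int v) const {
        int x1 = kth(v, depth[v] - depth[u] - 1);  // child of u towards v
        return ans1[u] - sub[x1] * (n - sub[x1]) + (preans[v] - preans[u]) -
               (preans3[v] - preans3[x1]);
    }

    long long query(int u, int v) const {
        if (u == v) return ans1[u];  // case 1
        int g = lca(u, v);
        if (g == v) std::swap(u, v);
        if (g == u) return ancestorPath(u, v);  // case 2
        int c1 = kth(u, depth[u] - depth[g] - 1), c2 = kth(v, depth[v] - depth[g] - 1);
        // Case 3: both calls include g's question-5 term; correct it to question 6.
        return ancestorPath(g, u) + ancestorPath(g, v) - ans1[g] + sub[c1] * sub[c2];
    }
};
```

Calling the case-2 helper twice from g counts ans1[g] twice and subtracts both children's terms; subtracting ans1[g] once and adding subtree[c_1]*subtree[c_2] yields exactly the question-6 count, which is also how the editorialist's solution combines the two halves.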
TIME COMPLEXITY
Computing the power-of-two ancestors takes O(n log n) time.
Initially, we precompute the arrays ans, ans1, preans and the prefix sums of subtree[x]*(subtree[p(x)]-subtree[x]) from the root to x for all vertices in the tree. Each takes one DFS, so this precomputation takes O(n) time.
Case 1 takes O(1) time.
Case 2 takes O(log n) time, for the LCA and for finding the child of u on the path to v.
Case 3 takes O(log n) time, for the LCA and for finding the children of g on the paths to u and to v.
Hence, the total time complexity is O((n+q) log n).
SOLUTIONS:
Setter's Solution
import java.io.OutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.ArrayList;
/**
* Built using CHelper plug-in
* Actual solution is at the top
*
* @author Aman Kumar Singh
*/
public class Main {
public static void main(String[] args) {
InputStream inputStream = System.in;
OutputStream outputStream = System.out;
InputReader in = new InputReader(inputStream);
PrintWriter out = new PrintWriter(outputStream);
IntersectingPaths solver = new IntersectingPaths();
int testCount = Integer.parseInt(in.next());
for (int i = 1; i <= testCount; i++)
solver.solve(i, in, out);
out.close();
}
static class IntersectingPaths {
int lgN = 20;
PrintWriter out;
InputReader in;
int n;
ArrayList<Integer>[] tree;
long[] sz;
long[] all_possible;
long[] cum;
int[][] anc;
int[] tin;
int[] tout;
int[] dist;
int timer = 0;
void dfs1(int v, int p) {
anc[0][v] = p;
tin[v] = timer++;
for (int i = 1; i < lgN; i++)
anc[i][v] = anc[i - 1][anc[i - 1][v]];
sz[v] = 1;
for (int u : tree[v]) {
if (u != p) {
dist[u] = dist[v] + 1;
dfs1(u, v);
sz[v] += sz[u];
}
}
for (int u : tree[v]) {
if (u != p)
all_possible[v] += sz[u] * (sz[v] - sz[u] - 1);
}
all_possible[v] /= 2;
all_possible[v] += sz[v];
for (int u : tree[v]) {
if (u != p) {
long to_be_excluded = (sz[v] - sz[u] - 1) * sz[u] + sz[u];
cum[u] = all_possible[v] - to_be_excluded;
}
}
tout[v] = timer++;
}
void dfs2(int v, int p) {
for (int u : tree[v]) {
if (u != p) {
cum[u] += cum[v];
dfs2(u, v);
}
}
}
boolean is_ancestor(int u, int v) {
return tin[u] <= tin[v] && tout[u] >= tout[v];
}
int lca_of(int u, int v) {
if (is_ancestor(u, v))
return u;
if (is_ancestor(v, u))
return v;
int i = 0;
for (i = lgN - 1; i >= 0; i--) {
if (!is_ancestor(anc[i][u], v))
u = anc[i][u];
}
return anc[0][u];
}
int k_th(int u, int k) {
int j = 0;
while (k > 0) {
if ((k & 1) == 1)
u = anc[j][u];
k = k >> 1;
j++;
}
return u;
}
public void solve(int testNumber, InputReader in, PrintWriter out) {
this.out = out;
this.in = in;
n = ni();
int q = ni();
tree = new ArrayList[n];
tin = new int[n];
tout = new int[n];
dist = new int[n];
int i = 0;
for (i = 0; i < n; i++)
tree[i] = new ArrayList<>();
for (i = 0; i < n - 1; i++) {
int u = ni() - 1;
int v = ni() - 1;
tree[u].add(v);
tree[v].add(u);
}
cum = new long[n];
sz = new long[n];
all_possible = new long[n];
anc = new int[lgN][n];
timer = 0;
dfs1(0, 0);
dfs2(0, 0);
while (q-- > 0) {
int u = ni() - 1;
int v = ni() - 1;
if (u == v) {
long ans = all_possible[u];
long rem = (long) n - sz[u];
ans += rem * sz[u];
pn(ans);
continue;
}
int lca = lca_of(u, v);
if (lca != u && lca != v) {
int dis1 = dist[v] - dist[lca];
int dis2 = dist[u] - dist[lca];
long ans = 0;
int child1_lca = k_th(v, dis1 - 1);
ans += cum[v] - cum[child1_lca];
int child2_lca = k_th(u, dis2 - 1);
ans += cum[u] - cum[child2_lca];
ans += all_possible[u];
ans += all_possible[v];
long rem = (long) n - sz[lca];
long tot_sz = sz[lca] - sz[child1_lca] - sz[child2_lca];
long to_include = all_possible[lca];
to_include -= (sz[lca] - sz[child1_lca] - 1) * sz[child1_lca];
to_include -= (sz[lca] - sz[child2_lca] - 1) * sz[child2_lca];
to_include += sz[child1_lca] * sz[child2_lca];
to_include += (tot_sz - 1) * rem;
to_include -= sz[child1_lca];
to_include -= sz[child2_lca];
to_include += rem;
ans += to_include;
pn(ans);
} else {
if (lca == u) {
int dis1 = dist[v] - dist[lca];
long ans = 0;
int child1_lca = k_th(v, dis1 - 1);
ans += cum[v] - cum[child1_lca];
ans += all_possible[v];
long rem = (long) n - sz[lca];
long tot_sz = sz[lca] - sz[child1_lca];
long to_include = all_possible[lca];
to_include -= (sz[lca] - sz[child1_lca] - 1) * sz[child1_lca];
to_include += (tot_sz - 1) * rem;
to_include -= sz[child1_lca];
to_include += rem;
ans += to_include;
pn(ans);
} else {
int dis2 = dist[u] - dist[lca];
long ans = 0;
int child2_lca = k_th(u, dis2 - 1);
ans += cum[u] - cum[child2_lca];
ans += all_possible[u];
long rem = (long) n - sz[lca];
long tot_sz = sz[lca] - sz[child2_lca];
long to_include = all_possible[lca];
to_include -= (sz[lca] - sz[child2_lca] - 1) * sz[child2_lca];
to_include += (tot_sz - 1) * rem;
to_include -= sz[child2_lca];
to_include += rem;
ans += to_include;
pn(ans);
}
}
}
}
int ni() {
return in.nextInt();
}
void pn(Object o) {
out.println(o);
}
}
static class InputReader {
private InputStream stream;
private byte[] buf = new byte[1024];
private int curChar;
private int numChars;
public InputReader(InputStream stream) {
this.stream = stream;
}
public int read() {
if (numChars == -1)
throw new UnknownError();
if (curChar >= numChars) {
curChar = 0;
try {
numChars = stream.read(buf);
} catch (IOException e) {
throw new UnknownError();
}
if (numChars <= 0)
return -1;
}
return buf[curChar++];
}
public int nextInt() {
return Integer.parseInt(next());
}
public String next() {
int c = read();
while (isSpaceChar(c))
c = read();
StringBuffer res = new StringBuffer();
do {
res.appendCodePoint(c);
c = read();
} while (!isSpaceChar(c));
return res.toString();
}
private boolean isSpaceChar(int c) {
return c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == -1;
}
}
}
Tester's Solution
#include <bits/stdc++.h>
#define endl '\n'
//#pragma GCC optimize ("O3")
//#pragma GCC target ("sse4")
#define SZ(x) ((int)x.size())
#define ALL(V) V.begin(), V.end()
#define L_B lower_bound
#define U_B upper_bound
#define pb push_back
using namespace std;
template<class T, class T2> inline int chkmax(T &x, const T2 &y) { return x < y ? x = y, 1 : 0; }
template<class T, class T2> inline int chkmin(T &x, const T2 &y) { return x > y ? x = y, 1 : 0; }
const int MAXN = (1 << 19);
// We will solve the problem with HLD and partial sums. The complexity will be O(N log N). The idea is the same as the one for the O(N^2) solution,
// except for the way we will compute the contribution of every chain. If we do partial sums on the chains, it can be easily seen that a sub-chain's contribution
// can be computed in O(1). For more information check the function "solve_fast(l, r)" which gives the answer for the vertices with dfs order in the range [l; r].
int read_int();
int n, q;
vector<int> adj[MAXN];
void read()
{
cin >> n >> q;
for(int i = 1; i <= n; i++) adj[i].clear();
for(int i = 0; i < n - 1; i++)
{
int u, v;
cin >> u >> v;
adj[u].pb(v);
adj[v].pb(u);
}
}
int head[MAXN], par[MAXN], tr_sz[MAXN];
int st[MAXN], en[MAXN], dfs_time;
long long through_ver[MAXN];
void pre_hld(int u, int pr = -1)
{
par[u] = pr;
tr_sz[u] = 1;
for(int v: adj[u])
if(v != pr)
{
pre_hld(v, u);
tr_sz[u] += tr_sz[v];
}
}
int ver[MAXN];
int memo_sz[MAXN];
long long psum[MAXN];
void hld(int u, int chead, int pr = -1)
{
pair<int, int> mx = {-1, -1};
for(int v: adj[u])
if(v != pr)
chkmax(mx, make_pair(tr_sz[v], v));
st[u] = ++dfs_time;
ver[st[u]] = u;
memo_sz[st[u]] = tr_sz[u];
head[u] = chead;
int sum = 1;
through_ver[st[u]] = 1;
if(mx.second != -1)
{
int v = mx.second;
hld(v, chead, u);
through_ver[st[u]] += sum * 1ll * tr_sz[v];
sum += tr_sz[v];
}
for(int v: adj[u])
if(v != pr && v != mx.second)
{
hld(v, v, u);
through_ver[st[u]] += sum * 1ll * tr_sz[v];
sum += tr_sz[v];
}
int down = 0;
if(st[u] != n && head[u] == head[ver[st[u] + 1]])
down = memo_sz[st[u] + 1];
psum[st[u]] = through_ver[st[u]] - (down * 1ll * (memo_sz[st[u]] - down));
en[u] = dfs_time;
}
void compute_down(int u, int pr = -1)
{
if(pr != -1 && head[u] == head[par[u]])
psum[st[u]] += psum[st[u] - 1];
for(int v: adj[u])
if(v != pr)
compute_down(v, u);
}
// Contribution of [l; r] subsegment. The lowest vertex is ver[l].
inline void solve_fast(int l, int r, int &prv, long long &answer)
{
answer += through_ver[r] - (prv * 1ll * (memo_sz[r] - prv));
if(l <= r - 1)
{
if(ver[l] == head[ver[l]]) answer += psum[r - 1];
else answer += psum[r - 1] - psum[l - 1];
}
prv = memo_sz[l];
}
int solve_up(int u, int x, long long &answer)
{
int prv = 0;
while(st[x] < st[u])
{
int l = max(st[x] + 1, st[head[u]]), r = st[u];
solve_fast(l, r, prv, answer);
if(l == st[x] + 1) return ver[st[x] + 1];
if(par[head[u]] == x) return head[u];
u = par[head[u]];
}
return MAXN - 1;
}
int lca(int u, int v)
{
while(true)
{
if(st[u] > st[v]) swap(u, v);
if(head[u] == head[v]) return u;
v = par[head[v]];
}
}
long long solve(int u, int v)
{
int x = lca(u, v);
long long answer = 0;
int up1 = solve_up(u, x, answer);
int up2 = solve_up(v, x, answer);
answer += (n - tr_sz[x]) * 1ll * (tr_sz[x] - tr_sz[up1] - tr_sz[up2]);
answer += through_ver[st[x]];
answer -= (tr_sz[up1] * 1ll * (memo_sz[st[x]] - tr_sz[up1] - tr_sz[up2]));
answer -= (tr_sz[up2] * 1ll * (memo_sz[st[x]] - tr_sz[up2] - tr_sz[up1]));
answer -= tr_sz[up1] * 1ll * tr_sz[up2];
return answer;
}
void solve()
{
dfs_time = 0;
pre_hld(1);
hld(1, 1);
compute_down(1); // pr = -1: the root has no parent, so par[1] (= -1) is never used as an index
while(q--)
{
int u, v;
cin >> u >> v;
cout << solve(u, v) << endl;
}
}
int main()
{
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int T;
cin >> T;
while(T--)
{
read();
solve();
}
return 0;
}
Editorialist's Solution
//teja349
#include <bits/stdc++.h>
#include <vector>
#include <set>
#include <map>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <utility>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <iomanip>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
//setbase - cout << setbase (16); cout << 100 << endl; Prints 64
//setfill - cout << setfill ('x') << setw (5); cout << 77 << endl; prints xxx77
//setprecision - cout << setprecision (14) << f << endl; Prints x.xxxx
//cout.precision(x) cout<<fixed<<val; // prints x digits after decimal in val
using namespace std;
using namespace __gnu_pbds;
#define f(i,a,b) for(i=a;i<b;i++)
#define rep(i,n) f(i,0,n)
#define fd(i,a,b) for(i=a;i>=b;i--)
#define pb push_back
#define mp make_pair
#define vi vector< int >
#define vl vector< ll >
#define ss second
#define ff first
#define ll long long
#define pii pair< int,int >
#define pll pair< ll,ll >
#define sz(a) a.size()
#define inf (1000*1000*1000+5)
#define all(a) a.begin(),a.end()
#define tri pair<int,pii>
#define vii vector<pii>
#define vll vector<pll>
#define viii vector<tri>
#define mod (1000*1000*1000+7)
#define pqueue priority_queue< int >
#define pdqueue priority_queue< int,vi ,greater< int > >
#define flush fflush(stdout)
#define primeDEN 727999983
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
// find_by_order() // order_of_key
typedef tree<
int,
null_type,
less<int>,
rb_tree_tag,
tree_order_statistics_node_update>
ordered_set;
#define int ll
int par[312345][20];
int subtree[312345],ans1[312345],ans2[312345],ans3[312345];
int preans2[312345],preans3[312345];
int dep[312345];
int n;
vector<vi> adj(312345);
int getlca(int u,int v){
int i;
if(dep[u]>dep[v])
swap(u,v);
fd(i,19,0){
if(dep[v]-(1<<i)>=dep[u]){
v=par[v][i];
}
}
if(u==v)
return u;
fd(i,19,0){
if(par[u][i]!=par[v][i]){
u=par[u][i];
v=par[v][i];
}
}
return par[u][0];
}
int getpar(int u,int deep){
int i;
fd(i,19,0){
if(dep[u]-(1<<i)>=deep){
u=par[u][i];
}
}
return u;
}
int solve(int u,int v){
int foo,wow=0;
foo=getpar(v,dep[u]+1);
wow=preans2[v]-preans2[u];
wow-=preans3[v]-preans3[foo];
//return wow;
wow += ans1[u];
wow-= (subtree[foo])*(n-subtree[foo]);
return wow;
}
int dfs(int cur,int paren){
int i;
par[cur][0]=paren;
ans3[cur]=0;
if(paren==-1){
dep[cur]=0;
}
else{
dep[cur]=dep[paren]+1;
}
subtree[cur]=1;
rep(i,adj[cur].size()){
if(adj[cur][i]!=paren){
dfs(adj[cur][i],cur);
subtree[cur]+=subtree[adj[cur][i]];
}
}
ans2[cur] = subtree[cur]*(subtree[cur]+1);
rep(i,adj[cur].size()){
if(adj[cur][i]!=paren){
ans2[cur]-=(subtree[adj[cur][i]])*(subtree[adj[cur][i]]+1);
}
}
ans2[cur]/=2;
ans1[cur]=ans2[cur]+subtree[cur]*(n-subtree[cur]);
return 0;
}
int dfs1(int cur,int paren){
int i;
if(paren==-1){
preans2[cur]=ans2[cur];
preans3[cur]=ans3[cur];
}
else{
preans2[cur]=preans2[paren]+ans2[cur];
preans3[cur]=preans3[paren]+ans3[cur];
}
rep(i,adj[cur].size()){
if(adj[cur][i]!=paren){
ans3[adj[cur][i]]=(subtree[cur]-subtree[adj[cur][i]])*(subtree[adj[cur][i]]);
dfs1(adj[cur][i],cur);
}
}
return 0;
}
signed main(){ // "int" is #defined as long long, and main must return int
//std::ios::sync_with_stdio(false); cin.tie(NULL);
int t;
cin>>t;
while(t--){
int q;
//cin>>n>>q;
scanf("%lld",&n);
scanf("%lld",&q);
int i;
int u,v;
rep(i,n+10){
adj[i].clear();
}
rep(i,n-1){
//cin>>u>>v;
scanf("%lld",&u);
scanf("%lld",&v);
u--;
v--;
adj[u].pb(v);
adj[v].pb(u);
}
dfs(0,-1);
int j;
f(j,1,20){
rep(i,n){
if(par[i][j-1]==-1)
par[i][j]=-1;
else
par[i][j]=par[par[i][j-1]][j-1];
}
}
dfs1(0,-1);
int gg;
rep(i,q){
//cin>>u>>v;
scanf("%lld",&u);
scanf("%lld",&v);
u--;
v--;
if(dep[u]>dep[v]){
swap(u,v);
}
gg=getlca(u,v);
int how;
if(u==v){
how = ans1[u];
}
else if(gg==u){
//cout<<"Dsa"<<endl;
how = solve(u,v);
}
else{
int foo1 = getpar(v,dep[gg]+1);
int foo2 = getpar(u,dep[gg]+1);
how = solve(gg,v) + solve(gg,u) + subtree[foo1]*subtree[foo2]-ans1[gg];
}
//cout<<how<<endl;
printf("%lld\n",how);
}
}
return 0;
}
Feel free to share your approach if it differs. Suggestions are always welcome.