INTRPATH - EDITORIAL

PROBLEM LINK:

Practice

Contest: Division 1

Contest: Division 2

Setter: Aman Kumar Singh

Tester: Radoslav Dimitrov

Editorialist: Teja Vardhan Reddy

DIFFICULTY:

Medium

PREREQUISITES:

Math, fast LCA queries.

PROBLEM:

Given a tree with n vertices, answer queries of the following type.

Query: given u and v, find the number of unordered pairs (a, b) such that the path (a, b) has exactly one vertex in common with the path (u, v). The pair a = b is allowed and denotes the single-vertex path, as the counting below assumes. We call such a path (a, b) a perfect path.
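To make the definition concrete, here is a brute-force reference counter. It is a sketch written for this editorial (the names `pathVertices` and `bruteForce` are hypothetical, not taken from the official solutions). It runs in O(n^3) per query, so it is only useful for checking the fast formulas on tiny trees. It counts pairs with a <= b, including a = b.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Brute-force reference: O(n^3) per query, only meant for tiny trees.
// adj is 1-indexed (adj[0] is unused).

// Vertices on the path a..b, found by a DFS from a that records parents.
vector<int> pathVertices(const vector<vector<int>>& adj, int a, int b) {
    int n = (int)adj.size() - 1;
    vector<int> par(n + 1, 0), todo = {a};
    vector<bool> seen(n + 1, false);
    seen[a] = true;
    while (!todo.empty()) {
        int x = todo.back();
        todo.pop_back();
        for (int y : adj[x])
            if (!seen[y]) { seen[y] = true; par[y] = x; todo.push_back(y); }
    }
    vector<int> path;
    for (int x = b; x != a; x = par[x]) path.push_back(x);  // walk back from b to a
    path.push_back(a);
    return path;
}

// Unordered pairs (a, b), a <= b, whose path shares exactly one vertex with path (u, v).
long long bruteForce(const vector<vector<int>>& adj, int u, int v) {
    int n = (int)adj.size() - 1;
    assert(1 <= u && u <= n && 1 <= v && v <= n);
    vector<bool> onUV(n + 1, false);
    for (int x : pathVertices(adj, u, v)) onUV[x] = true;
    long long count = 0;
    for (int a = 1; a <= n; a++)
        for (int b = a; b <= n; b++) {
            int common = 0;
            for (int x : pathVertices(adj, a, b))
                if (onUV[x]) common++;
            if (common == 1) count++;
        }
    return count;
}
```

For example, on a 6-vertex tree with edges 1-2, 1-3, 3-4, 3-5, 3-6 (an assumed shape, consistent with the example used later), the query (4, 5) gives 8: the paths (4,4) and (5,5), plus six paths that meet (4, 5) only at vertex 3.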

EXPLANATION

Let us assume the tree is rooted at node 1.

For every vertex, we store its 2^j-th ancestor for every j (binary lifting). This lets us answer LCA queries and find the k-th ancestor of a vertex in O(log n).
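A minimal binary-lifting sketch follows, assuming 0-indexed vertices and an adjacency-list tree (the struct name `BinaryLifting` is hypothetical). Ancestors above the root are clamped to the root, which keeps the jump table well defined; the setter's solution does the same by making the root its own parent.

```cpp
#include <cassert>
#include <utility>
#include <vector>
using namespace std;

// Binary lifting over a rooted tree with 0-indexed vertices.
// up[j][x] is the 2^j-th ancestor of x; jumps past the root stay at the root.
struct BinaryLifting {
    int LOG = 1;
    vector<int> depth;
    vector<vector<int>> up;

    BinaryLifting(const vector<vector<int>>& adj, int root) {
        int n = (int)adj.size();
        while ((1 << LOG) < n) LOG++;              // 2^LOG >= n > any depth
        depth.assign(n, 0);
        up.assign(LOG, vector<int>(n, root));
        vector<int> order = {root};                // BFS order: parents before children
        vector<bool> seen(n, false);
        seen[root] = true;
        for (size_t i = 0; i < order.size(); i++)
            for (int y : adj[order[i]])
                if (!seen[y]) {
                    seen[y] = true;
                    up[0][y] = order[i];
                    depth[y] = depth[order[i]] + 1;
                    order.push_back(y);
                }
        for (int j = 1; j < LOG; j++)
            for (int x = 0; x < n; x++) up[j][x] = up[j - 1][up[j - 1][x]];
    }

    // k-th ancestor of x in O(log n): jump by 2^j for every set bit j of k.
    int kthAncestor(int x, int k) const {
        assert(0 <= k && k <= depth[x]);
        for (int j = 0; j < LOG; j++)
            if (k >> j & 1) x = up[j][x];
        return x;
    }

    // LCA in O(log n): lift the deeper vertex, then lift both while they differ.
    int lca(int a, int b) const {
        if (depth[a] < depth[b]) swap(a, b);
        a = kthAncestor(a, depth[a] - depth[b]);
        if (a == b) return a;
        for (int j = LOG - 1; j >= 0; j--)
            if (up[j][a] != up[j][b]) { a = up[j][a]; b = up[j][b]; }
        return up[0][a];
    }
};
```

The child of u on the path to a descendant v, which the cases below need, is then `kthAncestor(v, depth[v] - depth[u] - 1)`.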

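The key idea is to count, for each vertex w on the path from u to v, the perfect paths that meet the path (u, v) exactly at w, and to sum these counts. The sets counted at different vertices are disjoint, because a path that passes through two or more vertices of (u, v) is not perfect, so every perfect path is counted exactly once. Moreover, since the intersection of two paths in a tree is itself a path, a path through w meets (u, v) only at w if and only if it avoids the (at most two) neighbours of w on (u, v). So at each vertex we count the paths through w that avoid one or two specific neighbours.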
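In symbols, with notation introduced here: let P be the set of vertices on the path (u, v) and, for w in P, let N_P(w) be the neighbours of w that also lie on P. The decomposition reads:

```latex
\text{answer}(u,v) \;=\; \sum_{w \in P} \#\bigl\{ (a,b),\ a \le b \;:\; w \in \mathrm{path}(a,b) \ \text{and}\ \mathrm{path}(a,b) \cap N_P(w) = \varnothing \bigr\}
```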

Let us first develop some tools before solving the problem.

1. How many different unordered paths lie entirely inside the subtree of vertex u?

Ans: A path is determined by its two endpoints, and it lies inside the subtree of u if and only if both endpoints do, because the LCA of any two vertices of the subtree is also in the subtree. Counting pairs with a = b allowed, the number of paths in the subtree of u is subtree[u]*(subtree[u]+1)/2.

2. How many paths pass through vertex u and lie inside the subtree of u?

Ans: Start from the total number of paths in the subtree of u (question 1) and subtract the paths that do not pass through u. A path that avoids u lies entirely inside the subtree of one child c, and by question 1 there are subtree[c]*(subtree[c]+1)/2 such paths for each child. Summing over all children gives the number of paths that avoid u, so ans[u] = subtree[u]*(subtree[u]+1)/2 minus the sum of subtree[c]*(subtree[c]+1)/2 over the children c of u. We can compute ans[u] for every vertex in a single DFS, because each vertex iterates over its children only once.
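The single DFS can be sketched as follows. This version replaces recursion with an iterative BFS order (an assumption made here to avoid deep recursion on path-shaped trees), processes vertices in reverse order so that every child is finished before its parent, and then applies the formula for ans[u] directly. The names `SubtreeCounts`, `pathsAmong` and `computeSubtreeCounts` are hypothetical.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Results of one pass over the tree (0-indexed vertices, rooted at `root`).
struct SubtreeCounts {
    vector<int> parent;       // parent[x], -1 for the root
    vector<long long> sub;    // subtree[x]
    vector<long long> ans;    // ans[x]: paths inside subtree(x) that pass through x
};

// Paths (a, b), a == b allowed, among s vertices that all lie in one subtree.
long long pathsAmong(long long s) { return s * (s + 1) / 2; }

SubtreeCounts computeSubtreeCounts(const vector<vector<int>>& adj, int root) {
    int n = (int)adj.size();
    SubtreeCounts r{vector<int>(n, -1), vector<long long>(n, 1), vector<long long>(n, 0)};
    vector<int> order = {root};               // BFS order: parents before children
    vector<bool> seen(n, false);
    seen[root] = true;
    for (size_t i = 0; i < order.size(); i++)
        for (int y : adj[order[i]])
            if (!seen[y]) { seen[y] = true; r.parent[y] = order[i]; order.push_back(y); }
    assert((int)order.size() == n);           // the input must be a connected tree
    // Reverse BFS order finishes every child before its parent.
    for (int i = n - 1; i > 0; i--) r.sub[r.parent[order[i]]] += r.sub[order[i]];
    // ans[x] = paths in subtree(x) minus paths staying inside a single child's subtree.
    for (int x = 0; x < n; x++) r.ans[x] = pathsAmong(r.sub[x]);
    for (int x = 0; x < n; x++)
        if (r.parent[x] != -1) r.ans[r.parent[x]] -= pathsAmong(r.sub[x]);
    return r;
}
```

On a tree shaped like the example used later (assumed edges 1-2, 1-3, 3-4, 3-5, 3-6, stored 0-indexed so vertex 3 becomes index 2), this gives subtree = 4 and ans = 7 for that vertex.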

3. How many paths pass through vertex u (not necessarily inside the subtree of u)?

Ans: We can convert this to the previous case. Imagine the tree rooted at u; then the question is the same as question 2, with one extra child subtree: the part of the tree outside the subtree of u, which has n-subtree[u] vertices. Call this value ans1[u]. Equivalently, a path through u that is not inside the subtree of u has one endpoint among the n-subtree[u] vertices outside the subtree and the other among its subtree[u] vertices, so ans1[u] = ans[u] + subtree[u]*(n-subtree[u]). Hence ans1 comes out of the same DFS as ans.
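For completeness, the re-rooting argument and the direct argument give the same value. Writing s_x for subtree[x] and letting c range over the children of u in the original rooting, re-rooting at u adds one component of size n - s_u:

```latex
\begin{aligned}
\text{ans1}[u] &= \frac{n(n+1)}{2} - \sum_{c} \frac{s_c(s_c+1)}{2} - \frac{(n-s_u)(n-s_u+1)}{2} \\
&= \underbrace{\frac{s_u(s_u+1)}{2} - \sum_{c} \frac{s_c(s_c+1)}{2}}_{\text{ans}[u]}
   + \frac{n(n+1) - (n-s_u)(n-s_u+1) - s_u(s_u+1)}{2} \\
&= \text{ans}[u] + s_u\,(n - s_u).
\end{aligned}
```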

4. How many paths pass through vertex u, lie inside the subtree of u, and do not pass through a given child c_1?

Ans: Count the paths through u (question 2) and subtract those that pass through both u and c_1. Such a path has one endpoint in the subtree of c_1 and the other in the subtree of u outside the subtree of c_1 (u itself included), which gives subtree[c_1]*(subtree[u]-subtree[c_1]) paths. Hence the answer is ans[u] - subtree[c_1]*(subtree[u]-subtree[c_1]).

5. How many paths pass through vertex u (not necessarily inside the subtree of u) and do not pass through a given child c_1?

Ans: Use the same strategy with the tree re-rooted at u, so that the subtree of u now contains all n vertices. Hence the answer is ans1[u] - subtree[c_1]*(n-subtree[c_1]).

6. How many paths pass through vertex u (not necessarily inside the subtree of u) and through neither of two given children c_1, c_2?

Ans: We use inclusion-exclusion.
The answer is equal to
count of paths passing through u
- count of paths passing through c_1 and u
- count of paths passing through c_2 and u
+ count of paths passing through c_1, c_2 and u.

The two middle terms come from question 5. For the last term, note that any path through both c_1 and c_2 must pass through u, because u is the LCA of c_1 and c_2. Such a path has one endpoint in the subtree of c_1 and the other in the subtree of c_2, so there are subtree[c_1]*subtree[c_2] of them.
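Combining the four terms (writing s_x for subtree[x]; the two middle terms are the corrections from question 5), question 6 has the closed form:

```latex
\#\{\text{paths through } u \text{ avoiding } c_1 \text{ and } c_2\}
  = \text{ans1}[u] - s_{c_1}(n - s_{c_1}) - s_{c_2}(n - s_{c_2}) + s_{c_1}\, s_{c_2}
```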

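Below are the paths for each of the six questions above on an example tree, for u = 3, c_1 = 4, c_2 = 5. The figure of the tree is not reproduced here; from the lists, the subtree of 3 consists of 3 and its three leaf children 4, 5 and 6, while vertices 1 and 2 lie outside it.

Paths are represented by their endpoints; (x,x) is the path consisting of the single vertex x.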

  1. (3,3),(4,4),(5,5),(6,6),(3,4),(3,5),(3,6),(4,5),(4,6),(5,6)

  2. (3,3),(3,4),(3,5),(3,6),(4,5),(4,6),(5,6)

  3. (3,3),(3,4),(3,5),(3,6),(4,5),(4,6),(5,6),(1,3),(1,4),(1,5),(1,6),(2,3),(2,4),(2,5),(2,6)

  4. (3,3),(3,5),(3,6),(5,6)

  5. (3,3),(3,5),(3,6),(5,6),(1,3),(1,5),(1,6),(2,3),(2,5),(2,6)

  6. (3,3),(3,6),(1,3),(1,6),(2,3),(2,6)
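The closed forms can be checked against the example. The sketch below (hypothetical helper names, one function per question) takes subtree sizes as plain numbers: n = 6, subtree[3] = 4, the three children of 3 each have subtree size 1, and c_1 = 4 and c_2 = 5 both have size 1. Each count also includes the single-vertex path (3,3), which passes through u.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Closed forms of questions 1-6 for one vertex u (hypothetical helper names).
// n: tree size, su: subtree[u], ch: subtree sizes of all children of u,
// c1, c2: subtree sizes of the children that must be avoided.
long long q1PathsInSubtree(long long su) { return su * (su + 1) / 2; }

long long q2Ans(long long su, const vector<long long>& ch) {
    long long r = q1PathsInSubtree(su);
    for (long long c : ch) r -= q1PathsInSubtree(c);        // paths that avoid u
    return r;
}

long long q3Ans1(long long n, long long su, const vector<long long>& ch) {
    return q2Ans(su, ch) + su * (n - su);                   // one endpoint outside subtree(u)
}

long long q4(long long su, const vector<long long>& ch, long long c1) {
    return q2Ans(su, ch) - c1 * (su - c1);                  // drop paths through c1 and u
}

long long q5(long long n, long long su, const vector<long long>& ch, long long c1) {
    return q3Ans1(n, su, ch) - c1 * (n - c1);
}

long long q6(long long n, long long su, const vector<long long>& ch,
             long long c1, long long c2) {
    assert(c1 + c2 < su);                                   // two distinct children plus u
    return q3Ans1(n, su, ch) - c1 * (n - c1) - c2 * (n - c2) + c1 * c2;
}
```

With these numbers the six functions return 10, 7, 15, 4, 10 and 6.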

Now, let's answer the queries.

Case 1: u = v.

Then exactly the paths passing through u are perfect, so the answer is ans1[u] (question 3). This takes O(1) time.

Case 2: u is an ancestor of v. (If v is an ancestor of u, just swap u and v.)

To detect this case, find the LCA of u and v and check whether it equals u or v.

For v, we count the paths inside the subtree of v that pass through v, because any other path through v also passes through the parent of v, which lies on the path (u, v), so it is not perfect. This is ans[v] (question 2) and takes O(1) time.

For u, we need the paths that pass through u but not through its child on the path to v. That child is the (depth[v]-depth[u]-1)-th ancestor of v, found with a k-th ancestor query. This is question 5: finding the child takes O(log n) time, and the rest takes O(1).

For every other vertex x on the path, we need question 4: the number of paths in the subtree of x that pass through x but not through its child c(x) on the path. This equals ans[x] - subtree[c(x)]*(subtree[x]-subtree[c(x)]). Each term costs O(1), but the path can contain O(n) vertices, so we must sum these terms over all vertices strictly between u and v faster.

To do this, we maintain preans[x] = sum of ans[y] over all vertices y on the path from the root to x (inclusive).

Write the path from u to v as u, x_1, x_2, ..., x_k, v.

The sum of ans[x] over the vertices strictly between u and v (x_1, x_2, ..., x_k) is then preans[x_k] - preans[u].

Let p(x) denote the parent of x.

It remains to compute the sum of subtree[c(x)]*(subtree[x]-subtree[c(x)]) over {x_1, x_2, ..., x_k}. Since c(x_i) = x_{i+1} and c(x_k) = v, substituting y = c(x) rewrites it as the sum of subtree[y]*(subtree[p(y)]-subtree[y]) over {x_2, x_3, ..., x_k, v}. We again maintain prefix sums of this quantity from the root to each vertex, so the sum is the prefix value at v minus the prefix value at x_1, an O(1) computation. Precomputing these prefix sums takes one DFS over the tree.
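Here is a sketch of Case 2 that follows the steps above, assuming 0-indexed vertices with the root at vertex 0. It computes subtree sizes, ans and ans1 in BFS order, keeps preAns (the editorial's preans) and preCross (prefix sums of subtree[x]*(subtree[p(x)]-subtree[x])), and uses binary lifting to find the child of u. The struct and member names are hypothetical, and this is not the official code.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Case 2 sketch: u is a proper ancestor of v. Vertices are 0-indexed, root = 0.
// preAns = preans; preCross = prefix sums of subtree[x] * (subtree[p(x)] - subtree[x]).
struct Case2Solver {
    int n, LOG = 1;
    vector<int> depth;
    vector<vector<int>> up;                    // up[j][x]: 2^j-th ancestor, clamped at root
    vector<long long> sub, ans, ans1, preAns, preCross;

    explicit Case2Solver(const vector<vector<int>>& adj) : n((int)adj.size()) {
        while ((1 << LOG) < n) LOG++;
        depth.assign(n, 0);
        up.assign(LOG, vector<int>(n, 0));
        sub.assign(n, 1);
        ans.assign(n, 0);
        ans1.assign(n, 0);
        preAns.assign(n, 0);
        preCross.assign(n, 0);
        vector<int> order = {0};               // BFS order: parents before children
        vector<bool> seen(n, false);
        seen[0] = true;
        for (size_t i = 0; i < order.size(); i++)
            for (int y : adj[order[i]])
                if (!seen[y]) {
                    seen[y] = true;
                    up[0][y] = order[i];
                    depth[y] = depth[order[i]] + 1;
                    order.push_back(y);
                }
        for (int i = n - 1; i > 0; i--) sub[up[0][order[i]]] += sub[order[i]];
        auto paths = [](long long s) { return s * (s + 1) / 2; };
        for (int x = 0; x < n; x++) ans[x] = paths(sub[x]);      // question 2
        for (int x = 1; x < n; x++) ans[up[0][x]] -= paths(sub[x]);  // every x != 0 has a parent
        for (int x = 0; x < n; x++) ans1[x] = ans[x] + sub[x] * (n - sub[x]);  // question 3
        for (int x : order) {
            if (x == 0) { preAns[x] = ans[x]; continue; }        // preCross[root] stays 0
            int p = up[0][x];
            preAns[x] = preAns[p] + ans[x];
            preCross[x] = preCross[p] + sub[x] * (sub[p] - sub[x]);
        }
        for (int j = 1; j < LOG; j++)
            for (int x = 0; x < n; x++) up[j][x] = up[j - 1][up[j - 1][x]];
    }

    int kthAncestor(int x, int k) const {
        for (int j = 0; j < LOG; j++)
            if (k >> j & 1) x = up[j][x];
        return x;
    }

    // Number of perfect paths for the query (u, v), u a proper ancestor of v.
    long long ancestorQuery(int u, int v) const {
        assert(depth[u] < depth[v] && kthAncestor(v, depth[v] - depth[u]) == u);
        int c = kthAncestor(v, depth[v] - depth[u] - 1);  // x_1 (or v itself if k = 0)
        long long total = preAns[v] - preAns[u];          // ans[x_1] + ... + ans[x_k] + ans[v]
        total -= preCross[v] - preCross[c];               // question 4 corrections, y in x_2..v
        total += ans1[u] - sub[c] * (n - sub[c]);         // question 5 at u
        return total;
    }
};
```

For instance, on the path 0-1-2-3-4, `ancestorQuery(0, 4)` returns 5, since only the five single-vertex paths meet the whole path in exactly one vertex.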

Case 3: u and v are not ancestors of each other. Let g be their LCA.

For u and v, we use question 2. This takes O(1) time.

For the vertices strictly between g and u, we use the last part of Case 2 (similarly for the vertices between g and v). This takes O(1) time.

For g, we use question 6, where c_1 and c_2 are the children of g on the paths to u and v. This takes O(log n) time because we need to find those children.
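The three parts can also be assembled from two Case 2 answers, which is how the editorialist's solution combines them. Let A(g, x) be the Case 2 answer for the ancestor g and a descendant x, let T(g, x) be the part of it contributed by the vertices strictly below g on the path to x (notation introduced here), and write s_x for subtree[x]; c_1 and c_2 are the children of g towards u and v:

```latex
\begin{aligned}
A(g,x) &= T(g,x) + \text{ans1}[g] - s_c\,(n - s_c), \qquad c = \text{child of } g \text{ towards } x,\\
\text{answer}(u,v) &= A(g,u) + A(g,v) - \text{ans1}[g] + s_{c_1}\, s_{c_2}.
\end{aligned}
```

Expanding, the terms for g become ans1[g] - s_{c_1}(n - s_{c_1}) - s_{c_2}(n - s_{c_2}) + s_{c_1} s_{c_2}, which is exactly question 6.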

TIME COMPLEXITY

Computing the 2^j-th ancestors of all vertices takes O(n log n) time.

We precompute the arrays ans, ans1, preans and the prefix sums of subtree[x]*(subtree[p(x)]-subtree[x]) from the root to x for all vertices. Each needs one DFS, so this part takes O(n) time.

Case 1 takes O(1) time.

Case 2 takes O(log n) time, because we need to find the child of u on the path to v.

Case 3 takes O(log n) time to find the LCA g and the children of g on the paths to u and v.

Hence, total time complexity is O((n+q)log(n)).

SOLUTIONS:

Setter's Solution
import java.io.OutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.ArrayList;
 
/**
 * Built using CHelper plug-in
 * Actual solution is at the top
 *
 * @author Aman Kumar Singh
 */
public class Main {
    public static void main(String[] args) {
        InputStream inputStream = System.in;
        OutputStream outputStream = System.out;
        InputReader in = new InputReader(inputStream);
        PrintWriter out = new PrintWriter(outputStream);
        IntersectingPaths solver = new IntersectingPaths();
        int testCount = Integer.parseInt(in.next());
        for (int i = 1; i <= testCount; i++)
            solver.solve(i, in, out);
        out.close();
    }
 
    static class IntersectingPaths {
        int lgN = 20;
        PrintWriter out;
        InputReader in;
        int n;
        ArrayList<Integer>[] tree;
        long[] sz;
        long[] all_possible;
        long[] cum;
        int[][] anc;
        int[] tin;
        int[] tout;
        int[] dist;
        int timer = 0;
 
        void dfs1(int v, int p) {
            anc[0][v] = p;
            tin[v] = timer++;
            for (int i = 1; i < lgN; i++)
                anc[i][v] = anc[i - 1][anc[i - 1][v]];
            sz[v] = 1;
            for (int u : tree[v]) {
                if (u != p) {
                    dist[u] = dist[v] + 1;
                    dfs1(u, v);
                    sz[v] += sz[u];
                }
            }
 
            for (int u : tree[v]) {
                if (u != p)
                    all_possible[v] += sz[u] * (sz[v] - sz[u] - 1);
            }
            all_possible[v] /= 2;
            all_possible[v] += sz[v];
            for (int u : tree[v]) {
                if (u != p) {
                    long to_be_excluded = (sz[v] - sz[u] - 1) * sz[u] + sz[u];
                    cum[u] = all_possible[v] - to_be_excluded;
                }
            }
            tout[v] = timer++;
        }
 
        void dfs2(int v, int p) {
            for (int u : tree[v]) {
                if (u != p) {
                    cum[u] += cum[v];
                    dfs2(u, v);
                }
            }
        }
 
        boolean is_ancestor(int u, int v) {
            return tin[u] <= tin[v] && tout[u] >= tout[v];
        }
 
        int lca_of(int u, int v) {
            if (is_ancestor(u, v))
                return u;
            if (is_ancestor(v, u))
                return v;
            int i = 0;
            for (i = lgN - 1; i >= 0; i--) {
                if (!is_ancestor(anc[i][u], v))
                    u = anc[i][u];
            }
            return anc[0][u];
        }
 
        int k_th(int u, int k) {
            int j = 0;
            while (k > 0) {
                if ((k & 1) == 1)
                    u = anc[j][u];
                k = k >> 1;
                j++;
            }
            return u;
        }
 
        public void solve(int testNumber, InputReader in, PrintWriter out) {
            this.out = out;
            this.in = in;
            n = ni();
            int q = ni();
            tree = new ArrayList[n];
            tin = new int[n];
            tout = new int[n];
            dist = new int[n];
            int i = 0;
            for (i = 0; i < n; i++)
                tree[i] = new ArrayList<>();
            for (i = 0; i < n - 1; i++) {
                int u = ni() - 1;
                int v = ni() - 1;
                tree[u].add(v);
                tree[v].add(u);
            }
            cum = new long[n];
            sz = new long[n];
            all_possible = new long[n];
            anc = new int[lgN][n];
            timer = 0;
            dfs1(0, 0);
            dfs2(0, 0);
            while (q-- > 0) {
                int u = ni() - 1;
                int v = ni() - 1;
                if (u == v) {
                    long ans = all_possible[u];
                    long rem = (long) n - sz[u];
                    ans += rem * sz[u];
                    pn(ans);
                    continue;
                }
                int lca = lca_of(u, v);
                if (lca != u && lca != v) {
                    int dis1 = dist[v] - dist[lca];
                    int dis2 = dist[u] - dist[lca];
                    long ans = 0;
                    int child1_lca = k_th(v, dis1 - 1);
                    ans += cum[v] - cum[child1_lca];
                    int child2_lca = k_th(u, dis2 - 1);
                    ans += cum[u] - cum[child2_lca];
                    ans += all_possible[u];
                    ans += all_possible[v];
                    long rem = (long) n - sz[lca];
                    long tot_sz = sz[lca] - sz[child1_lca] - sz[child2_lca];
                    long to_include = all_possible[lca];
                    to_include -= (sz[lca] - sz[child1_lca] - 1) * sz[child1_lca];
                    to_include -= (sz[lca] - sz[child2_lca] - 1) * sz[child2_lca];
                    to_include += sz[child1_lca] * sz[child2_lca];
                    to_include += (tot_sz - 1) * rem;
                    to_include -= sz[child1_lca];
                    to_include -= sz[child2_lca];
                    to_include += rem;
                    ans += to_include;
                    pn(ans);
                } else {
                    if (lca == u) {
                        int dis1 = dist[v] - dist[lca];
                        long ans = 0;
                        int child1_lca = k_th(v, dis1 - 1);
                        ans += cum[v] - cum[child1_lca];
                        ans += all_possible[v];
                        long rem = (long) n - sz[lca];
                        long tot_sz = sz[lca] - sz[child1_lca];
                        long to_include = all_possible[lca];
                        to_include -= (sz[lca] - sz[child1_lca] - 1) * sz[child1_lca];
                        to_include += (tot_sz - 1) * rem;
                        to_include -= sz[child1_lca];
                        to_include += rem;
                        ans += to_include;
                        pn(ans);
                    } else {
                        int dis2 = dist[u] - dist[lca];
                        long ans = 0;
                        int child2_lca = k_th(u, dis2 - 1);
                        ans += cum[u] - cum[child2_lca];
                        ans += all_possible[u];
                        long rem = (long) n - sz[lca];
                        long tot_sz = sz[lca] - sz[child2_lca];
                        long to_include = all_possible[lca];
                        to_include -= (sz[lca] - sz[child2_lca] - 1) * sz[child2_lca];
                        to_include += (tot_sz - 1) * rem;
                        to_include -= sz[child2_lca];
                        to_include += rem;
                        ans += to_include;
                        pn(ans);
                    }
                }
            }
        }
 
        int ni() {
            return in.nextInt();
        }
 
        void pn(Object o) {
            out.println(o);
        }
 
    }
 
    static class InputReader {
        private InputStream stream;
        private byte[] buf = new byte[1024];
        private int curChar;
        private int numChars;
 
        public InputReader(InputStream stream) {
            this.stream = stream;
        }
 
        public int read() {
            if (numChars == -1)
                throw new UnknownError();
            if (curChar >= numChars) {
                curChar = 0;
                try {
                    numChars = stream.read(buf);
                } catch (IOException e) {
                    throw new UnknownError();
                }
                if (numChars <= 0)
                    return -1;
            }
            return buf[curChar++];
        }
 
        public int nextInt() {
            return Integer.parseInt(next());
        }
 
        public String next() {
            int c = read();
            while (isSpaceChar(c))
                c = read();
            StringBuffer res = new StringBuffer();
            do {
                res.appendCodePoint(c);
                c = read();
            } while (!isSpaceChar(c));
 
            return res.toString();
        }
 
        private boolean isSpaceChar(int c) {
            return c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == -1;
        }
 
    }
}
Tester's Solution
#include <bits/stdc++.h>
#define endl '\n'
 
//#pragma GCC optimize ("O3")
//#pragma GCC target ("sse4")
 
#define SZ(x) ((int)x.size())
#define ALL(V) V.begin(), V.end()
#define L_B lower_bound
#define U_B upper_bound
#define pb push_back
 
using namespace std;
template<class T, class T2> inline int chkmax(T &x, const T2 &y) { return x < y ? x = y, 1 : 0; }
template<class T, class T2> inline int chkmin(T &x, const T2 &y) { return x > y ? x = y, 1 : 0; }
const int MAXN = (1 << 19);
 
// We will solve the problem with HLD and partial sums. The complexity will be O(N log N). The idea is the same as the one for the O(N^2) solution,
// except for the way we will compute the contribution of every chain. If we do partial sums on the chains, it can be easily seen that a sub-chain's contribution
// can be computed in O(1). For more information check the function "solve_fast(l, r)" which gives the answer for the vertices with dfs order in the range [l; r].
 
int read_int();
 
int n, q;
vector<int> adj[MAXN];
 
void read()
{
    cin >> n >> q;
    for(int i = 1; i <= n; i++) adj[i].clear();
 
    for(int i = 0; i < n - 1; i++)
    {
        int u, v;
        cin >> u >> v;
        adj[u].pb(v);
        adj[v].pb(u);
    }
}
 
int head[MAXN], par[MAXN], tr_sz[MAXN];
int st[MAXN], en[MAXN], dfs_time;
 
long long through_ver[MAXN];
 
void pre_hld(int u, int pr = -1)
{
    par[u] = pr;
    tr_sz[u] = 1;
    for(int v: adj[u])
        if(v != pr)
        {
            pre_hld(v, u);
            tr_sz[u] += tr_sz[v];
        }
}
 
int ver[MAXN];
int memo_sz[MAXN];
 
long long psum[MAXN];
 
void hld(int u, int chead, int pr = -1)
{
    pair<int, int> mx = {-1, -1};
    for(int v: adj[u])
        if(v != pr) 
            chkmax(mx, make_pair(tr_sz[v], v));
 
    st[u] = ++dfs_time;
    
    ver[st[u]] = u;
    memo_sz[st[u]] = tr_sz[u];
    head[u] = chead;
 
    int sum = 1;
    through_ver[st[u]] = 1;
    
    if(mx.second != -1)
    {
        int v = mx.second;
        hld(v, chead, u);
        through_ver[st[u]] += sum * 1ll * tr_sz[v];
        sum += tr_sz[v];
    }
 
    for(int v: adj[u])
        if(v != pr && v != mx.second)
        {
            hld(v, v, u);
            through_ver[st[u]] += sum * 1ll * tr_sz[v];
            sum += tr_sz[v];
        }
 
    int down = 0;
    if(st[u] != n && head[u] == head[ver[st[u] + 1]])
        down = memo_sz[st[u] + 1];
 
    psum[st[u]] = through_ver[st[u]] - (down * 1ll * (memo_sz[st[u]] - down));  
    en[u] = dfs_time;
}
 
void compute_down(int u, int pr = -1)
{
    if(pr != -1 && head[u] == head[par[u]])
        psum[st[u]] += psum[st[u] - 1];
 
    for(int v: adj[u])
        if(v != pr)
            compute_down(v, u);
}
 
// Contribution of [l; r] subsegment. The lowest vertex is ver[l].
inline void solve_fast(int l, int r, int &prv, long long &answer)
{
    answer += through_ver[r] - (prv * 1ll * (memo_sz[r] - prv));    
    if(l <= r - 1) 
    {
        if(ver[l] == head[ver[l]]) answer += psum[r - 1];
        else answer += psum[r - 1] - psum[l - 1];
    }
 
    prv = memo_sz[l];
}
 
int solve_up(int u, int x, long long &answer)
{
    int prv = 0;
    while(st[x] < st[u])
    {
        int l = max(st[x] + 1, st[head[u]]), r = st[u];
 
        solve_fast(l, r, prv, answer);
        
        if(l == st[x] + 1) return ver[st[x] + 1];
        if(par[head[u]] == x) return head[u];
        u = par[head[u]];
    }
 
    return MAXN - 1;
}
 
int lca(int u, int v)
{
    while(true)
    {
        if(st[u] > st[v]) swap(u, v);
        if(head[u] == head[v]) return u;
        v = par[head[v]];
    }
}
 
long long solve(int u, int v)
{
    int x = lca(u, v);
    long long answer = 0;
    int up1 = solve_up(u, x, answer);
    int up2 = solve_up(v, x, answer);
 
    answer += (n - tr_sz[x]) * 1ll * (tr_sz[x] - tr_sz[up1] - tr_sz[up2]); 
    answer += through_ver[st[x]];
    answer -= (tr_sz[up1] * 1ll * (memo_sz[st[x]] - tr_sz[up1] - tr_sz[up2]));  
    answer -= (tr_sz[up2] * 1ll * (memo_sz[st[x]] - tr_sz[up2] - tr_sz[up1]));  
    answer -= tr_sz[up1] * 1ll * tr_sz[up2];
    return answer;
}
 
void solve()
{
    dfs_time = 0;
    pre_hld(1);
    hld(1, 1);
    compute_down(1);  // the root has no parent, so pr stays -1 (avoids reading head[par[1]] = head[-1])
 
    while(q--)
    {
        int u, v;
        cin >> u >> v;
        cout << solve(u, v) << endl;
    }
}
 
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
 
    int T;
    cin >> T;
    while(T--)
    {
        read();
        solve();
    }
 
    return 0;
}
Editorialist's Solution
//teja349
#include <bits/stdc++.h>
#include <vector>
#include <set>
#include <map>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <utility>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <iomanip>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp> 
//setbase - cout << setbase (16); cout << 100 << endl; Prints 64
//setfill -   cout << setfill ('x') << setw (5); cout << 77 << endl; prints xxx77
//setprecision - cout << setprecision (14) << f << endl; Prints x.xxxx
//cout.precision(x)  cout<<fixed<<val;  // prints x digits after decimal in val
 
using namespace std;
using namespace __gnu_pbds;
 
#define f(i,a,b) for(i=a;i<b;i++)
#define rep(i,n) f(i,0,n)
#define fd(i,a,b) for(i=a;i>=b;i--)
#define pb push_back
#define mp make_pair
#define vi vector< int >
#define vl vector< ll >
#define ss second
#define ff first
#define ll long long
#define pii pair< int,int >
#define pll pair< ll,ll >
#define sz(a) a.size()
#define inf (1000*1000*1000+5)
#define all(a) a.begin(),a.end()
#define tri pair<int,pii>
#define vii vector<pii>
#define vll vector<pll>
#define viii vector<tri>
#define mod (1000*1000*1000+7)
#define pqueue priority_queue< int >
#define pdqueue priority_queue< int,vi ,greater< int > >
#define flush fflush(stdout) 
#define primeDEN 727999983
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
 
// find_by_order()  // order_of_key
typedef tree<
int,
null_type,
less<int>,
rb_tree_tag,
tree_order_statistics_node_update>
ordered_set;
#define int ll
int par[312345][20];
int subtree[312345],ans1[312345],ans2[312345],ans3[312345];
int preans2[312345],preans3[312345];
int dep[312345];
int n;
vector<vi> adj(312345);
int getlca(int u,int v){
    int i;
    if(dep[u]>dep[v])
        swap(u,v);
    fd(i,19,0){
        if(dep[v]-(1<<i)>=dep[u]){
            v=par[v][i];
        }
    }
    if(u==v)
        return u;
    fd(i,19,0){
        if(par[u][i]!=par[v][i]){
            u=par[u][i];
            v=par[v][i];
        }
    }
    return par[u][0];
}
int getpar(int u,int deep){
    int i;
    fd(i,19,0){
        if(dep[u]-(1<<i)>=deep){
            u=par[u][i];
        }
    }
    return u;
}
int solve(int u,int v){
    int foo,wow=0;
    foo=getpar(v,dep[u]+1);
    wow=preans2[v]-preans2[u];
    wow-=preans3[v]-preans3[foo];
    //return wow;
    wow += ans1[u];
    wow-= (subtree[foo])*(n-subtree[foo]);
    
    return wow;
}
 
int dfs(int cur,int paren){
    int i;
    par[cur][0]=paren;
    ans3[cur]=0;
    if(paren==-1){
        dep[cur]=0;
    }
    else{
        dep[cur]=dep[paren]+1;
    }
    subtree[cur]=1;
    rep(i,adj[cur].size()){
        if(adj[cur][i]!=paren){
            dfs(adj[cur][i],cur);
            subtree[cur]+=subtree[adj[cur][i]];
        }
    }
    ans2[cur] = subtree[cur]*(subtree[cur]+1);
    rep(i,adj[cur].size()){
        if(adj[cur][i]!=paren){
            ans2[cur]-=(subtree[adj[cur][i]])*(subtree[adj[cur][i]]+1);
        }
    }
    ans2[cur]/=2;
    ans1[cur]=ans2[cur]+subtree[cur]*(n-subtree[cur]);
    return 0;
}
 
int dfs1(int cur,int paren){
    int i;
    if(paren==-1){
        preans2[cur]=ans2[cur];
        preans3[cur]=ans3[cur];
    }
    else{
        preans2[cur]=preans2[paren]+ans2[cur];
        preans3[cur]=preans3[paren]+ans3[cur];
    }
 
    rep(i,adj[cur].size()){
        if(adj[cur][i]!=paren){
            ans3[adj[cur][i]]=(subtree[cur]-subtree[adj[cur][i]])*(subtree[adj[cur][i]]);
            dfs1(adj[cur][i],cur);
        }
    }
    return 0;
}
signed main(){  // 'int' is #defined as long long, so main is declared with 'signed'
    //std::ios::sync_with_stdio(false); cin.tie(NULL);
    int t;
    cin>>t;
    while(t--){
        int q;
        //cin>>n>>q;
        scanf("%lld",&n);
        scanf("%lld",&q);
        
        int i;
        int u,v;
        rep(i,n+10){
            adj[i].clear();
        }
        rep(i,n-1){
            //cin>>u>>v;
            scanf("%lld",&u);
            scanf("%lld",&v);
            u--;
            v--;
            adj[u].pb(v);
            adj[v].pb(u);
        }
        dfs(0,-1);
        int j;
        f(j,1,20){
            rep(i,n){
                if(par[i][j-1]==-1)
                    par[i][j]=-1;
                else
                    par[i][j]=par[par[i][j-1]][j-1];
            }
        }
        dfs1(0,-1);
        int gg;
        rep(i,q){
            //cin>>u>>v;
            scanf("%lld",&u);
            scanf("%lld",&v);
            u--;
            v--;
            if(dep[u]>dep[v]){
                swap(u,v);
            }
            gg=getlca(u,v);
            int how;
            if(u==v){
                how = ans1[u];
            }
            else if(gg==u){
                //cout<<"Dsa"<<endl;
                how = solve(u,v);
            }
            else{
                int foo1 = getpar(v,dep[gg]+1);
                int foo2 = getpar(u,dep[gg]+1);
                how = solve(gg,v) + solve(gg,u) + subtree[foo1]*subtree[foo2]-ans1[gg];
            }
 
            //cout<<how<<endl;
            printf("%lld\n",how);
        }
 
    }
    return 0;   
}

Feel free to share your approach if it differs. Suggestions are always welcome. :slight_smile:


Thanks for the fast editorial.
Some more LCA problems:

208E - Blood Cousins
191C - Fools and Roads
519E - A and B and Lecture Rooms
587C - Duff in the Army
609E - Minimum spanning tree for each edge
178B3 - Greedy Merchants
176E - Archaeology


Can you explain what do you mean by this ?

Interesting observation:

Kind of random, but I have a query.
If I use only cin.tie(NULL) instead of both cin.tie(NULL) and cout.tie(NULL), will that make any difference? I mean, considering that the above operation unties the streams, wouldn't using one untie both (input and output)?

@carnage17 I think it means the total paths passing through vertex u include its parent one also…

where are the other editorials
@teja349 didn’t u prepare editorials for all probs even b4 the contest started
where are they ?


Can you add some visualization to make the explanation easier? You really need to reword most of your sentences; it's very vague and I'm not sure what you are referring to.

Remember you are trying to explain it to someone who doesn't understand how to solve it, not to someone who already knows how to solve it.


Let's consider that node u has adjacent subtrees of sizes a, b, c, d.
So,
ans = a*b + a*c + a*d + b*c + b*d + c*d + a + b + c + d + 1
This equation reduces to
ans = ((a + b + c + d)^2 - (a^2 + b^2 + c^2 + d^2))/2 + (a + b + c + d + 1)
which can be computed in O(1).

For that, a DFS traverses the tree (consider 1 as the root), and
for every vertex u we store (a+b+c+d) and (a^2 + b^2 + c^2 + d^2).

f(1,u) = ans of node u based on root 1.
now for query u,v
L= LCA(u,v)
ans = f(1,u) + f(1,v) -f(1,L)
Let cu be the child of L (the LCA) on the path L->u and cv the child of L on the path L->v.
ans = ans + subSize[cu]*subSize[cv]
Here, subSize[cu] is the size of the subtree of cu, which is precomputed.

To find cu and cv,

let the index of L in the Euler tour be X,
cu = query(u,X-1)
cv = query(X+1,v)

This is my solution:
https://www.codechef.com/viewsolution/24801321


I was thinking, it would use dp to store some stuff :joy::joy::joy::joy:

My solution didn't pass even with printf/scanf… then I used fastscan and it passed in 3.94 sec… too close…


scanf/printf is a common optimization in many problems when the input and output files are huge. The best optimization here was to use "\n" instead of endl, especially since the output files were as large as 15 MB.


Is it faster than ios_base::sync_with_stdio(0); then cin.tie(0);?

Yes obviously !!!

Hey, I used the same fast LCA approach with the same formula, but got WA on every test case except one. Can I get some help with it?

My Submission : https://www.codechef.com/viewsolution/24800370

I messaged him on CF if he could add a couple of images, he didn’t reply anything.

Hi
Currently I am working on the last problem editorial. I will add pictures in this editorial for visualisation. I will add the pictures by weekend. Sorry for the delay


I have submitted all except one to codechef team. I will check with them.

yep same happened to me , i think fast IO was a big part of the problem.

this problem tested contestants to extremes :sweat_smile:
