SUDH - EDITORIAL

PROBLEM LINK: https://www.codechef.com/COOW2019/problems/SUDH
DIFFICULTY: Medium-Hard
PREREQUISITES: Dynamic Programming, DFS, Basic Mathematics, LCA in O(log n) (https://www.geeksforgeeks.org/lca-for-general-or-n-ary-trees-sparse-matrix-dp-approach-onlogn-ologn/)
EXPLANATION:
We just have to find the lowest common ancestor (LCA) of the two nodes x and y. Let h[v] denote the depth of node v, with the root at depth 0.
If both Sudhanshu and Harsh can reach the LCA on their own, we only need to check whether their total stamina is at least $\frac{h[x](h[x]+1)}{2} + \frac{h[y](h[y]+1)}{2}$.
If either of them cannot reach the LCA on their own, or their total stamina is less than this value, print -1. Otherwise, print the stamina left over, i.e. the total stamina minus that value.
SOLUTION:
(Python 3)

def par_at_hei(node, hei, par, h):
    """Return the ancestor of ``node`` that sits at depth ``hei``.

    Args:
        node: starting node.
        hei: target depth (0 is the root).
        par: binary-lifting table; ``par[v][i]`` is the 2**i-th ancestor of
            ``v``, or 0 when that ancestor does not exist.
        h: depth of every node, with the root at depth 0 and ``h[0] == 0``
            for the sentinel index 0.

    Returns:
        The ancestor at depth ``hei``. Depth 0 is answered directly with
        node 1, the root used by this solution. If ``hei >= h[node]``, no jump
        is taken and ``node`` itself is returned.
    """
    if hei == 0:
        return 1
    # Try the largest jumps first; a jump is taken whenever it does not
    # overshoot the target depth. The sentinel 0 has depth 0 < hei, so we
    # never jump past the root.
    for i in reversed(range(len(par[node]))):
        anc = par[node][i]
        if h[anc] >= hei:
            node = anc
    return node


def find_lca(node, node2, h, par):
    """Return the lowest common ancestor of ``node`` and ``node2``.

    Args:
        node, node2: the two query nodes.
        h: depth of every node (root at depth 0).
        par: binary-lifting table; ``par[v][i]`` is the 2**i-th ancestor of
            ``v``, or 0 when it does not exist.

    Returns:
        The deepest node that is an ancestor of both inputs (a node counts
        as its own ancestor).
    """
    # Lift the deeper node to the depth of the shallower one.
    if h[node] > h[node2]:
        node = par_at_hei(node, h[node2], par, h)
    elif h[node2] > h[node]:
        node2 = par_at_hei(node2, h[node], par, h)

    if node == node2:
        return node

    # Jump both nodes together, largest jumps first, but only while their
    # ancestors still differ. Afterwards both are direct children of the LCA.
    for i in reversed(range(len(par[node]))):
        up1, up2 = par[node][i], par[node2][i]
        if up1 != up2:
            node, node2 = up1, up2

    return par[node][0]

test = int(input())
for _ in range(test):
    n, q = map(int, input().split())
    a = [[] for i in range(n + 1)]
    for i in range(n - 1):
        x, y = map(int, input().split())
        a[x].append(y)
        ind += 1
        a[y].append(x)
    par = [[0 for i in range(21)] for j in range(n + 1)]
    st = [1]
    h = [0] * (n + 1)
    vis = [0] * (n + 1)
    vis[1] = 1
    while st:
        p = st.pop()
        for i in a[p]:
            if vis[i] == 0:
                vis[i] = 1
                par[i][0] = p
                h[i] = h[p] + 1
                st.append(i)
                for j in range(1, 21):

                    if par[par[i][j - 1]][j - 1] == 0:
                        break
                    par[i][j] = par[par[i][j - 1]][j - 1]

    for i in range(q):
        x, y, s, t = map(int, input().split())
        ind += 1
        ss, tt = s, t
        comm = find_lca(x, y, h, par)
        s -= ((h[x] - h[comm]) * (h[x] - h[comm] + 1)) // 2
        t -= ((h[y] - h[comm]) * (h[y] - h[comm] + 1)) // 2

        if s < 0 or t < 0 or h[x] * (h[x] + 1) + h[y] * (h[y] + 1) > 2 * (ss + tt):
            print(-1)
        else:
            print(-  (h[x] * (h[x] + 1) + h[y] * (h[y] + 1)) // 2 + (ss + tt))