What is f(u,v)? It is just \lceil dist(u,v)/k \rceil. This can also be written as \frac{dist(u,v) + c_{u,v}}{k}, where c_{u,v} is the smallest non-negative integer such that k divides dist(u,v) + c_{u,v}. Since we need \sum \frac{dist(u,v) + c_{u,v}}{k}, we can split it into two parts: \frac{1}{k}\left(\sum dist(u,v) + \sum c_{u,v}\right).
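A quick numeric check of this identity, as a small sketch. The distance list is made up for illustration and is not taken from any real tree. `smallest_c` finds c_{u,v} by searching directly from its definition. Every dist + c is a multiple of k, so their sum is too, which is why the split total divides exactly.

```python
def f(d: int, k: int) -> int:
    # f(u, v) = ceil(dist(u, v) / k), using integer arithmetic only
    return (d + k - 1) // k

def smallest_c(d: int, k: int) -> int:
    # c_{u,v}: smallest non-negative c such that k divides d + c (direct search)
    c = 0
    while (d + c) % k != 0:
        c += 1
    return c

def total_direct(dists, k):
    return sum(f(d, k) for d in dists)

def total_split(dists, k):
    # (sum of distances + sum of c) / k; each d + c is a multiple of k, so the division is exact
    return (sum(dists) + sum(smallest_c(d, k) for d in dists)) // k

dists = [1, 2, 3, 4, 5, 7]  # made-up pairwise distances
print(total_direct(dists, 3), total_split(dists, 3))  # 10 10
```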

As we can see, c_{u,v} depends only on dist(u,v) \bmod k: if r = dist(u,v) \bmod k, then c_{u,v} = (k - r) \bmod k.
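The sketch below shows that c can be looked up from the residue alone. The helper is named `needed` to mirror the variable mentioned later in the text. Its body, (k - r) % k, is our own closed form and may not be the author's exact code. The check compares it against the brute-force definition for many distances.

```python
def needed(r: int, k: int) -> int:
    # c as a function of r = dist(u, v) % k alone (hypothetical helper)
    return (k - r) % k

def smallest_c(d: int, k: int) -> int:
    # the definition: smallest non-negative c with (d + c) % k == 0
    c = 0
    while (d + c) % k != 0:
        c += 1
    return c

k = 4
print([needed(r, k) for r in range(k)])  # [0, 3, 2, 1]
```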

How do we compute the distance between two nodes in a tree? We can use the depth of each node and the depth of their LCA (lowest common ancestor). Let L_{u,v} denote the LCA of u and v.

Then dist(u,v) = depth(u) + depth(v) - 2\times depth(L_{u,v}), because the path goes up from u to L_{u,v} and then back down to v. Let dp_{i,j} denote the number of nodes in the subtree of i at a depth d such that d \equiv j \pmod k.
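Here is a small sketch of the distance formula and of the dp_{i,j} table. It assumes a hypothetical input format: a `parent` list with `parent[0] == -1` for the root and `parent[v] < v` for every other node. That ordering lets depths be filled in increasing label order and subtree counts in decreasing label order. The LCA is the naive "lift the deeper node, then lift both" method, which is fine for illustration but is O(depth) per query.

```python
def depths(parent):
    # assumes parent[0] == -1 (root) and parent[v] < v for v > 0
    depth = [0] * len(parent)
    for v in range(1, len(parent)):
        depth[v] = depth[parent[v]] + 1
    return depth

def lca(u, v, parent, depth):
    # naive LCA: bring both nodes to the same depth, then climb together
    while depth[u] > depth[v]:
        u = parent[u]
    while depth[v] > depth[u]:
        v = parent[v]
    while u != v:
        u, v = parent[u], parent[v]
    return u

def dist(u, v, parent, depth):
    return depth[u] + depth[v] - 2 * depth[lca(u, v, parent, depth)]

def residue_counts(parent, depth, k):
    # dp[i][j]: nodes in the subtree of i whose depth d satisfies d % k == j
    n = len(parent)
    dp = [[0] * k for _ in range(n)]
    for v in range(n - 1, -1, -1):  # children (larger labels) finish before their parent
        dp[v][depth[v] % k] += 1
        if v > 0:
            for j in range(k):
                dp[parent[v]][j] += dp[v][j]
    return dp

parent = [-1, 0, 0, 1, 1, 3]
depth = depths(parent)
print(dist(5, 4, parent, depth), dist(5, 2, parent, depth))  # 3 4
```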

Now choose some node a and consider \sum_{L_{u,v}=a} c_{u,v}, the sum of c over all pairs of nodes whose LCA is a. Every pair of nodes has exactly one LCA, so summing this quantity over all a counts every pair exactly once.

The LCA of two distinct nodes is a only if they lie in different subtrees of a, or one of them is a itself. The two nested for loops (over residues i and j) count the number of such pairs for each combination of depths \bmod k.
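To see the "different subtrees, or a itself" rule in isolation, this sketch counts pairs per LCA using only subtree sizes and ignores residues. It uses the same hypothetical `parent` format as before (`parent[0] == -1`, `parent[v] < v`). A running total `acc` starts at 1 for a itself. Each new child subtree pairs with everything accumulated before it. The counts over all a add up to n(n-1)/2, which confirms that every pair is counted exactly once.

```python
def pairs_per_lca(parent):
    # assumes parent[0] == -1 and parent[v] < v for v > 0
    n = len(parent)
    children = [[] for _ in range(n)]
    for v in range(1, n):
        children[parent[v]].append(v)
    size = [1] * n
    for v in range(n - 1, 0, -1):
        size[parent[v]] += size[v]
    count = [0] * n
    for a in range(n):
        acc = 1  # node a itself
        for ch in children[a]:
            # one endpoint in ch's subtree, the other in a or an earlier child's subtree
            count[a] += acc * size[ch]
            acc += size[ch]
    return count

counts = pairs_per_lca([-1, 0, 0, 1, 1, 3])
print(counts, sum(counts))  # [9, 5, 0, 1, 0, 0] 15
```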

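At each point, `count_subtree[a][i]` (dp_{a,i}) stores the number of nodes at a depth \equiv i \pmod k in the part of a's subtree already accounted for (this must include a itself, so that pairs of a with its descendants are counted). So for each new child subtree, we count the pairs of nodes at depths \equiv i and \equiv j \pmod k, one in the new subtree and one in the part already accounted for. That is all the information needed to find c.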
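The per-child step can be sketched as two small functions. These are hypothetical helpers, not the author's code. The first builds the k \times k table of pair counts between the accumulated residues and a new child's residues. The second folds the child into the accumulator. In the example, a is node 1 at depth 1 with k = 2, so the accumulator starts as [0, 1]. Its children's residue arrays are [1, 1] and [1, 0].

```python
def cross_pair_counts(acc, child):
    # acc[i]: nodes already accounted for (a and earlier child subtrees) with depth % k == i
    # child[j]: nodes in the new child subtree with depth % k == j
    k = len(acc)
    return [[acc[i] * child[j] for j in range(k)] for i in range(k)]

def merge(acc, child):
    # after counting pairs, the child's nodes join the accounted-for part
    return [x + y for x, y in zip(acc, child)]

acc = [0, 1]                    # node a (depth 1) alone, k = 2
for child in ([1, 1], [1, 0]):  # residue counts of a's two child subtrees
    print(cross_pair_counts(acc, child))
    acc = merge(acc, child)
```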

Therefore (i + j - 2\times depth(a)) \bmod k gives us the distance \bmod k. Using this, we can compute `needed` (c_{u,v}) and multiply it by the number of such pairs.
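Putting the pieces together, here is an O(n k^2) sketch of the whole computation, checked against a brute force over all pairs. It uses the same hypothetical `parent` format (`parent[0] == -1`, `parent[v] < v`). The text does not spell out how to get \sum dist(u,v). Here we use the standard edge-contribution trick: the edge between v and its parent lies on size(v)(n - size(v)) paths. Python's `%` always returns a non-negative result for positive k. A C++ port would need to normalize (i + j - 2\times depth) before taking the remainder.

```python
from itertools import combinations

def solve(parent, k):
    # sum of ceil(dist(u, v) / k) over unordered pairs; assumes parent[v] < v
    n = len(parent)
    depth = [0] * n
    children = [[] for _ in range(n)]
    for v in range(1, n):
        depth[v] = depth[parent[v]] + 1
        children[parent[v]].append(v)
    # sum of distances: edge (v, parent[v]) lies on size[v] * (n - size[v]) paths
    size = [1] * n
    for v in range(n - 1, 0, -1):
        size[parent[v]] += size[v]
    sum_dist = sum(size[v] * (n - size[v]) for v in range(1, n))
    # sum of c, grouping pairs by their LCA a
    dp = [[0] * k for _ in range(n)]
    sum_c = 0
    for a in range(n - 1, -1, -1):  # children have larger labels, so they finish first
        acc = [0] * k
        acc[depth[a] % k] = 1  # node a itself
        for ch in children[a]:
            for i in range(k):
                for j in range(k):
                    pairs = acc[i] * dp[ch][j]
                    r = (i + j - 2 * depth[a]) % k  # dist(u, v) mod k
                    sum_c += pairs * ((k - r) % k)  # needed * number of pairs
            acc = [x + y for x, y in zip(acc, dp[ch])]
        dp[a] = acc
    return (sum_dist + sum_c) // k

def brute(parent, k):
    # reference: naive LCA for every pair, then ceil(dist / k)
    n = len(parent)
    depth = [0] * n
    for v in range(1, n):
        depth[v] = depth[parent[v]] + 1
    total = 0
    for u, v in combinations(range(n), 2):
        x, y = u, v
        while depth[x] > depth[y]:
            x = parent[x]
        while depth[y] > depth[x]:
            y = parent[y]
        while x != y:
            x, y = parent[x], parent[y]
        d = depth[u] + depth[v] - 2 * depth[x]
        total += (d + k - 1) // k
    return total

print(solve([-1, 0, 0, 1, 1, 3], 2), brute([-1, 0, 0, 1, 1, 3], 2))  # 20 20
```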