DIVDIST Editorial

Root the tree at some node. Denote by d_v the depth of node v, by glob[v] the answer for the subtree of v, and by dp[v][x] the maximum number of nodes we can select from the subtree of v if we may only select nodes u with d_u \equiv x \bmod K (we want this for every x from 0 to K-1). Let’s calculate these values for each node, from the leaves up.
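To check the recurrences below on small inputs, it helps to have an exhaustive reference. The following Python sketch assumes that the task is to select as many vertices as possible so that every pairwise distance is divisible by K (the condition the case analysis below relies on). It also assumes vertices are numbered 0..n-1 and the tree is given as a list of edges. The name brute_force is hypothetical, and the search is exponential, so it is only meant for tiny trees.

```python
from collections import deque
from itertools import combinations


def brute_force(n, K, edges):
    """Largest vertex set with all pairwise distances divisible by K.

    Assumed problem formulation; exponential time, tiny trees only.
    """
    adj = [[] for _ in range(n)]
    for u, w in edges:
        adj[u].append(w)
        adj[w].append(u)
    # all-pairs distances via one BFS per vertex
    dist = []
    for s in range(n):
        d = [-1] * n
        d[s] = 0
        q = deque([s])
        while q:
            v = q.popleft()
            for w in adj[v]:
                if d[w] < 0:
                    d[w] = d[v] + 1
                    q.append(w)
        dist.append(d)
    # try subsets from largest to smallest; a single vertex is always valid
    for size in range(n, 0, -1):
        for subset in combinations(range(n), size):
            if all(dist[u][w] % K == 0 for u, w in combinations(subset, 2)):
                return size
    return 0
```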

Suppose we already know these values for all nodes in the subtree of v except v itself; let’s find them for v.

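  • dp[v][x]. If d_v \equiv x \bmod K, it is 1 plus the sum of dp[son][x] over all children of v, since every such node u is then at distance d_u - d_v \equiv 0 \bmod K from v. Otherwise, we can’t select v, so we can only select nodes from the subtrees of its children.

Two such nodes u and w in the subtrees of different children are at distance d_u + d_w - 2d_v \equiv 2(x - d_v) \bmod K. So if 2(d_v - x) is divisible by K, dp[v][x] is just the sum of dp[son][x] over all children of v; otherwise, it is the maximum of these values over its children.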

  • glob[v]. If we select v, every other selected node u must satisfy d_u \equiv d_v \bmod K, so the answer is dp[v][d_v \bmod K]. Otherwise, there are three subcases:
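  1. If we select nodes from the subtrees of at least 3 different sons of v, they all must have the same depth modulo K: for nodes with depths \equiv a, b, c \bmod K in three different subtrees we need a + b \equiv a + c \equiv b + c \equiv 2d_v \bmod K, which forces a \equiv b \equiv c. So the answer is just the largest of dp[v][x].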

  2. If we select nodes from the subtree of only one son, the answer is glob[son] for that son.

  3. If we select nodes from the subtrees of exactly two sons, then all selected nodes in the first subtree must have the same depth modulo K (say a), and likewise for the second (say b). In addition, a + b - 2d_v must be divisible by K. This gives dp[son_1][a] + dp[son_2][b].

Now we just have to take the maximum of all obtained values.
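The two bullets above translate almost literally into per-node code. The Python sketch below assumes the pairwise-distance formulation and that every child’s dp is a list of length K indexed by depth residue; node_dp and node_glob are hypothetical names. Case 3 simply tries every ordered pair of distinct sons, which is quadratic in the number of children, so this is only a reference for the recurrences, not the intended complexity.

```python
def node_dp(dv, K, child_dps):
    """dp[v][x] for x = 0..K-1, given d_v and the children's dp lists."""
    dp = [0] * K
    for x in range(K):
        vals = [c[x] for c in child_dps]
        if (dv - x) % K == 0:
            # d_v % K == x: select v together with all nodes of residue x below it
            dp[x] = 1 + sum(vals)
        elif (2 * (dv - x)) % K == 0:
            # nodes of residue x in different child subtrees are compatible
            dp[x] = sum(vals)
        else:
            # they are not compatible, so only one child subtree can be used
            dp[x] = max(vals, default=0)
    return dp


def node_glob(dv, K, dp_v, child_dps, child_globs):
    """glob[v] from dp[v] and the children's dp and glob values."""
    # select v (x = d_v % K) or case 1: the largest dp[v][x]
    best = max(dp_v)
    # case 2: everything lies in a single child subtree
    best = max([best] + child_globs)
    # case 3: two distinct sons with residues a and b, a + b = 2 d_v (mod K)
    for i, di in enumerate(child_dps):
        for j, dj in enumerate(child_dps):
            if i != j:
                for a in range(K):
                    b = (2 * dv - a) % K
                    best = max(best, di[a] + dj[b])
    return best
```

For example, with K = 2 a leaf at depth 1 gets node_dp(1, 2, []) == [0, 1], and the center of a star with four such leaves gets node_dp(0, 2, [[0, 1]] * 4) == [1, 4]: the four leaves are pairwise at distance 2.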

Now, how do we calculate all these values fast? It’s quite clear how to calculate everything in O(NK). The only not immediately trivial part is the largest value of dp[son_1][a] + dp[son_2][b] over distinct sons son_1 \neq son_2 of v and pairs (a, b) with a + b - 2d_v divisible by K. For this, we iterate over all sons son and all residues b, take a \equiv 2d_v - b \bmod K, and need the largest dp[son_1][a] over son_1 \neq son. Storing, for each residue a, the two largest values of dp[son_1][a] together with the son attaining the largest one is enough to answer this in O(1).
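Putting everything together, here is a Python sketch of the O(NK) algorithm, again assuming the pairwise-distance formulation, vertices 0..n-1 and an edge list. The name solve_nk and the arrays top1, who and top2 (the two largest dp[son][a] per residue and the son attaining the largest) are choices made for this sketch. The tree is rooted with an iterative DFS so that deep trees do not hit Python’s recursion limit.

```python
def solve_nk(n, K, edges, root=0):
    """O(NK) version of the DP; returns glob[root] (hypothetical helper)."""
    adj = [[] for _ in range(n)]
    for u, w in edges:
        adj[u].append(w)
        adj[w].append(u)
    # iterative DFS: parents come before children in `order`
    depth, parent, order = [0] * n, [-1] * n, []
    seen = [False] * n
    seen[root] = True
    stack = [root]
    while stack:
        v = stack.pop()
        order.append(v)
        for w in adj[v]:
            if not seen[w]:
                seen[w] = True
                parent[w] = v
                depth[w] = depth[v] + 1
                stack.append(w)
    children = [[] for _ in range(n)]
    for v in order[1:]:
        children[parent[v]].append(v)

    dp, glob = [None] * n, [0] * n
    for v in reversed(order):  # bottom-up
        dv, ch = depth[v] % K, children[v]
        # per residue: largest value, the son attaining it, second largest
        top1, who, top2 = [0] * K, [-1] * K, [0] * K
        for c in ch:
            for x, val in enumerate(dp[c]):
                if val > top1[x]:
                    top1[x], top2[x], who[x] = val, top1[x], c
                elif val > top2[x]:
                    top2[x] = val
        cur = [0] * K
        for x in range(K):
            if x == dv:
                cur[x] = 1 + sum(dp[c][x] for c in ch)
            elif (2 * (dv - x)) % K == 0:
                cur[x] = sum(dp[c][x] for c in ch)
            else:
                cur[x] = top1[x]
        # select v / case 1 / case 2
        best = max([max(cur)] + [glob[c] for c in ch])
        # case 3: son c holds residue b, the best *other* son holds a
        for c in ch:
            for b in range(K):
                a = (2 * dv - b) % K
                partner = top2[a] if who[a] == c else top1[a]
                best = max(best, dp[c][b] + partner)
        dp[v], glob[v] = cur, best
    return glob[root]
```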

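It turns out that this can be optimized to O(N). Note that we don’t need to store dp[v][x] if the subtree of v contains no node of depth \equiv x \bmod K, so node v stores at most min(K, h_v + 1) values, where h_v is the height of its subtree. Surprisingly, this optimization alone is sufficient to reduce the complexity to O(N), as long as v reuses the storage of its deepest child instead of copying it (no reindexing is needed, because dp is indexed by absolute depth modulo K). Also, in case 1 only the residues x with 2(d_v - x) divisible by K need to be checked: there are at most two of them, and for every other x, dp[v][x] comes from a single son and is already covered by case 2. You can prove the bound as follows: for each node, and for each x such that its subtree contains a node of depth \equiv x \bmod K, choose one such node and place a token on it. When recalculating the dp for a node v, we process each child subtree except the deepest one in time proportional to its depth, so let’s then remove all tokens from every child subtree except the deepest one. Each node gets a token placed on it and removed from it at most once, so the amortized complexity is O(N).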
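A Python sketch of the O(N) idea, under the same assumptions as before (solve_linear is a hypothetical name). Each dp[v] is a dict whose keys are only the depth residues present in the subtree. Node v takes over the dict of its deepest child by reference. It merges the other children into that dict, pairing each of them with the per-residue maxima of the sons merged before it (case 3). Finally it overwrites the at most two residues x with 2(d_v - x) divisible by K with sums. Because dict operations are O(1) only in expectation, the bound here is expected rather than worst-case O(N).

```python
def solve_linear(n, K, edges, root=0):
    """O(N)-style DP with per-node dicts of present residues (hypothetical helper)."""
    adj = [[] for _ in range(n)]
    for u, w in edges:
        adj[u].append(w)
        adj[w].append(u)
    depth, parent, order = [0] * n, [-1] * n, []
    seen = [False] * n
    seen[root] = True
    stack = [root]
    while stack:  # iterative DFS, parents before children
        v = stack.pop()
        order.append(v)
        for w in adj[v]:
            if not seen[w]:
                seen[w] = True
                parent[w] = v
                depth[w] = depth[v] + 1
                stack.append(w)
    children = [[] for _ in range(n)]
    for v in order[1:]:
        children[parent[v]].append(v)

    height, dp, glob = [0] * n, [None] * n, [0] * n
    for v in reversed(order):
        dv, ch = depth[v] % K, children[v]
        # residues x with 2(d_v - x) divisible by K: at most two of them
        special = {dv, (dv + K // 2) % K} if K % 2 == 0 else {dv}
        # sums for the special residues, read before any dict is modified
        spsum = {x: sum(dp[c].get(x, 0) for c in ch) for x in special}
        best = max((glob[c] for c in ch), default=0)  # case 2
        M = {}
        if ch:
            heavy = max(ch, key=lambda c: height[c])
            height[v] = height[heavy] + 1
            M = dp[heavy]  # reuse the deepest child's dict in O(1)
            for c in ch:
                if c == heavy:
                    continue
                D = dp[c]
                # case 3: pair son c with the best earlier son at residue a
                for b, val in D.items():
                    a = (2 * dv - b) % K
                    if a in M:
                        best = max(best, val + M[a])
                # merge: M keeps the per-residue maximum over merged sons
                for b, val in D.items():
                    if val > M.get(b, 0):
                        M[b] = val
        # the special residues use sums (plus v itself) instead of maxima
        for x in special:
            if x == dv:
                M[x] = 1 + spsum[x]
            elif x in M:
                M[x] = spsum[x]
        best = max([best] + [M.get(x, 0) for x in special])  # select v / case 1
        dp[v], glob[v] = M, best
    return glob[root]
```

Keeping the per-residue maximum in the shared dict during the merge is what lets case 3 run in time proportional to the light child’s size. The sums for the special residues are computed up front because the merge overwrites the inherited dict.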

We had a subtask for a solution with centroid decomposition in O(N\log{N}), but I won’t describe it here.