TREELCA7 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: raysh07
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Dynamic programming, dfs

PROBLEM:

You’re given a tree with N vertices, rooted at vertex 1, such that \text{par}_i \lt i for each 2 \le i \le N.
For a non-empty subset S of vertices, define an array A of length N as A_i = \min_{x\in S} (\text{lca}(x, i)).

Compute the sum of \text{sum}(A) across all distinct arrays A that can possibly be formed this way.
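To make the definition concrete, here is a brute-force sketch. It is our own code, not the setters' checker. For tiny N it enumerates every non-empty subset S, builds A, and adds sum(A) once per distinct array. It assumes a 1-indexed parent array with par[1] = 0 and assumes S must be non-empty, since the minimum is undefined otherwise. The names naiveLca and bruteForce are hypothetical. The LCA loop already relies on \text{par}_i \lt i: the larger of two labels can never be an ancestor of the smaller one, so it is always safe to lift the larger one.

```cpp
#include <algorithm>
#include <cassert>
#include <set>
#include <vector>

// Hypothetical helper: LCA from the parent array alone (par[1] = 0 is unused).
// Since par[i] < i, the larger label is never an ancestor of the smaller,
// so lifting the larger one never skips past the LCA.
int naiveLca(const std::vector<int>& par, int a, int b) {
    while (a != b) {
        if (a > b) a = par[a];
        else b = par[b];
    }
    return a;
}

// Hypothetical brute force: try every non-empty subset S (bit x-1 of mask
// means vertex x is in S), build A, and add sum(A) once per distinct array.
long long bruteForce(int n, const std::vector<int>& par) {
    assert(n >= 1 && n <= 20);  // exponential in n: tiny trees only
    std::set<std::vector<int>> seen;
    long long total = 0;
    for (int mask = 1; mask < (1 << n); ++mask) {
        std::vector<int> a(n + 1, n + 1);  // a[i] for 1 <= i <= n; a[0] unused
        for (int x = 1; x <= n; ++x) {
            if (!((mask >> (x - 1)) & 1)) continue;
            for (int i = 1; i <= n; ++i) {
                a[i] = std::min(a[i], naiveLca(par, x, i));
            }
        }
        if (seen.insert(a).second) {  // first time this array appears
            for (int i = 1; i <= n; ++i) total += a[i];
        }
    }
    return total;
}
```

For example, bruteForce(4, {0, 0, 1, 1, 2}) (vertices 2 and 3 hang off 1, vertex 4 off 2) should give 24. That value comes from the four distinct arrays [1, 1, 1, 1], [1, 2, 1, 2], [1, 2, 1, 4] and [1, 1, 3, 1].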

EXPLANATION:

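The \text{par}_i \lt i condition here is key, and we’ll use it heavily to solve the problem.
In particular, it means every ancestor of a vertex has a smaller label than the vertex itself. All the values \text{lca}(x, i) for x \in S lie on the path from i to the root, so A_i is simply the one among them that is closest to the root.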

First, let’s try to figure out what arrays A are possible to achieve.

Suppose 1 \in S.
Then, it’s easy to see that \text{lca}(1, i) = 1 for every i, so we simply have A_i = 1 for all i.
This gives the array [1, 1, \ldots, 1] with a sum of N.

From now, we assume 1 \not \in S.
Let the children of vertex 1 be c_1, c_2, \ldots, c_k.
Observe that if S contains vertices from the subtrees of at least two different c_i’s, then we’ll just end up with A = [1, 1, \ldots, 1] again.
This is because for any vertex u, we’ll be able to find some vertex x \in S such that \text{lca}(u, x) = 1. For u = 1 this is immediate; otherwise, choose an x \in S lying in a different child subtree from u.

Because we only care about distinct arrays A, such a case is pretty useless to us (since [1, 1, \ldots, 1] has already been considered.)
So, the only “interesting” case is when all the elements of S lie within a single child subtree.
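This split between the "interesting" and "useless" cases can be checked mechanically. The sketch below uses hypothetical helper names. For each x \in S, it climbs to the child of vertex 1 whose subtree contains x, then reports whether S touches two or more such children. It assumes 1 \not\in S and a 1-indexed parent array with par[1] = 0.

```cpp
#include <cassert>
#include <set>
#include <vector>

// Hypothetical helper: the child of the root (vertex 1) whose subtree holds x.
int rootChildOf(const std::vector<int>& par, int x) {
    assert(x != 1);  // vertex 1 lies in no child subtree
    while (par[x] != 1) x = par[x];
    return x;
}

// Hypothetical check: does S (assumed not to contain 1) touch the subtrees of
// two or more children of the root? If so, A collapses to [1, 1, ..., 1].
bool spansSeveralRootChildren(const std::vector<int>& par, const std::vector<int>& s) {
    std::set<int> children;
    for (int x : s) children.insert(rootChildOf(par, x));
    return children.size() >= 2;
}
```

With par = {0, 0, 1, 1, 2}, S = \{3, 4\} is flagged, because 4 lies under child 2 while 3 is itself a child of the root. By contrast, S = \{2, 4\} stays inside the subtree of 2 and is not flagged.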


Now, suppose all the elements of S lie within the subtree of c_1.
Observe that for any vertex i that does not belong to this subtree, we’ll surely have A_i = 1, because its LCA with any vertex in the subtree of c_1 is just 1 itself.

So, if there are s_{c_1} vertices in the subtree of c_1, we know for sure that every array A under consideration is going to have (N - s_{c_1}) of its elements be equal to 1.
On the other hand, for vertices within the subtree of c_1, we are simply solving the problem independently for that subtree, since LCAs of vertices inside it don't depend on the rest of the tree.

This gives us the idea of using dynamic programming.

Define dp_u to be the answer for the subtree rooted at u.
Also define ct_u to be the number of distinct possible arrays when considering the subtree rooted at u (we’ll need this for transitions.)
Recall that s_u denotes the number of vertices present in the subtree of u.
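Since \text{par}_i \lt i, the sizes s_u can be computed without recursion by scanning labels downward, which is also what the editorialist's code does. Here is a minimal sketch, assuming a 1-indexed parent array; subtreeSizes is a hypothetical name.

```cpp
#include <cassert>
#include <vector>

// Sketch: subtree sizes s_u from the parent array alone. Every child has a
// larger label than its parent, so when the downward scan reaches i, all of
// i's descendants have already been folded into s[i].
std::vector<long long> subtreeSizes(int n, const std::vector<int>& par) {
    std::vector<long long> s(n + 1, 1);
    s[0] = 0;  // index 0 is unused
    for (int i = n; i >= 2; --i) {
        assert(par[i] >= 1 && par[i] < i);  // the problem's guarantee
        s[par[i]] += s[i];
    }
    return s;
}
```

On the tree with par = {0, 0, 1, 1, 2}, this yields s_1 = 4, s_2 = 2 and s_3 = s_4 = 1.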

Let’s compute dp_u.
As seen earlier, there are really only two options: either every element of the array equals u, or all elements of the chosen subset lie within a single child subtree.

The first case is trivial: all the elements equal u, and there are s_u of them, so the overall sum is s_u \cdot u.

Otherwise, let c be a child of u and suppose all elements of the set lie within the subtree of c.
Then,

  • There are s_u - s_c vertices in the subtree of u that lie outside the subtree of c.
    All of them will have their minimum LCA be just u itself.
    So, for each possible array formed from within c, we get an additional (s_u - s_c) copies of u.
  • There are, by definition, ct_c distinct arrays possible when considering only the subtree of c.
  • By definition, we also know the sum of their sums equals dp_c.
  • So, when lifting these arrays to u, each of them has its sum increased by (s_u - s_c) \cdot u.
    This gives an overall contribution of dp_c + ct_c\cdot (s_u - s_c) \cdot u.

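Arrays coming from different children never coincide. The ones built from child c have A_c \ge c \gt u, because every LCA inside that subtree has label at least c. By contrast, the all-u array and the arrays built from any other child have A_c = u.
So, to obtain dp_u, sum up the last quantity across all children c and add s_u \cdot u from the first case.
ct_u can similarly be computed as the sum of all ct_c across the children, plus 1 for the array [u, u, \ldots, u].
(In fact, by induction ct_u = 1 + \sum_c ct_c = 1 + \sum_c s_c = s_u, so we don't even need to maintain ct as a separate array.)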
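As a small hand-worked check (our own example, not from the problem's samples), take the path 1 - 2 - 3. Here \text{par}_2 = 1 and \text{par}_3 = 2, so s_1 = 3, s_2 = 2 and s_3 = 1:

```latex
dp_3 = 3 \cdot 1 = 3, \qquad ct_3 = 1
dp_2 = 2 \cdot 2 + \left(dp_3 + ct_3 \cdot (s_2 - s_3) \cdot 2\right) = 4 + (3 + 2) = 9, \qquad ct_2 = 2
dp_1 = 1 \cdot 3 + \left(dp_2 + ct_2 \cdot (s_1 - s_2) \cdot 1\right) = 3 + (9 + 2) = 14
```

Direct enumeration agrees. The distinct arrays are [1, 1, 1] (any S containing 1), [1, 2, 2] (S = \{2\} or \{2, 3\}) and [1, 2, 3] (S = \{3\}). Their sums are 3 + 5 + 6 = 14.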

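This gives us a dynamic programming solution in \mathcal{O}(N) time.
Since \text{par}_i \lt i, processing vertices in decreasing order of label handles every child before its parent, so the DP can be filled with a simple loop instead of an explicit DFS.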
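Below is a compact sketch of the same DP. It relies on ct_u = s_u to drop the ct array and takes the modulus as a parameter (the editorialist's code uses 998244353). The function name solveTree and its signature are our own. It assumes a 1-indexed parent array with \text{par}_i \lt i.

```cpp
#include <cassert>
#include <vector>

// Sketch of the O(N) DP. Uses ct_u = s_u, so only subtree sizes are stored.
long long solveTree(int n, const std::vector<int>& par, long long mod) {
    std::vector<long long> s(n + 1, 1), dp(n + 1, 0);
    for (int i = n; i >= 2; --i) {
        assert(par[i] >= 1 && par[i] < i);
        s[par[i]] += s[i];
    }
    // First case for every u: the all-u array, whose sum is s_u * u.
    for (int u = 1; u <= n; ++u) dp[u] = (s[u] % mod) * (u % mod) % mod;
    // Children have larger labels, so the downward scan finishes dp[i]
    // before lifting it into dp[par[i]].
    for (int i = n; i >= 2; --i) {
        int p = par[i];
        long long out = (s[p] - s[i]) % mod;  // vertices of p's subtree outside i's
        long long lifted = out * (p % mod) % mod * (s[i] % mod) % mod;  // ct_i = s_i
        dp[p] = (dp[p] + dp[i] + lifted) % mod;
    }
    return dp[1];
}
```

For the path 1 - 2 - 3 (par = {0, 0, 1, 2}) this returns 14. For the 4-vertex tree with par = {0, 0, 1, 1, 2} it returns 24, matching direct enumeration on both trees.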

TIME COMPLEXITY:

\mathcal{O}(N) per testcase.

CODE:

Editorialist's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    const int mod = 998244353;
    
    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        vector par(n+1, 0);
        for (int i = 2; i <= n; ++i) {
            cin >> par[i];
        }

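        // subsz[u] = size of u's subtree; children have larger labels (par[i] < i),
        // so a downward scan adds each finished subtree into its parent
        vector dp(n+1, 0), ct(n+1, 0), subsz(n+1, 1);
        for (int i = n; i > 1; --i) {
            subsz[par[i]] += subsz[i];
        }

        // base case: the all-i array [i, i, ..., i] over i's subtree
        for (int i = 1; i <= n; ++i) {
            dp[i] = (1ll * i * subsz[i]) % mod;
            ct[i] = 1;
        }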

        for (int i = n; i > 1; --i) {
            // go into i from par[i]
            // everything outside has value par[i]
            int out = subsz[par[i]] - subsz[i];
            ct[par[i]] += ct[i];
            dp[par[i]] = (dp[par[i]] + 1ll*out*par[i]*ct[i] + dp[i]) % mod;
        }

        cout << dp[1] << '\n';
    }
}

great problem. really enjoyed it