PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: raysh07
Tester: sushil2006
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
Dynamic programming, dfs
PROBLEM:
You’re given a tree with N vertices, such that \text{par}_i \lt i for each 2 \le i \le N.
For a non-empty subset S of vertices, define an array A of length N by A_i = \min_{x\in S} \text{lca}(x, i), where the minimum is taken over vertex labels.
Compute the sum of \text{sum}(A) across all distinct arrays A that can be formed this way, modulo 998244353.
EXPLANATION:
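The condition \text{par}_i \lt i is key, and we’ll utilize it heavily. It means every proper ancestor of a vertex has a smaller label than the vertex itself. For a fixed i, every \text{lca}(x, i) is an ancestor of i (or i itself), so all these values lie on the path from i to the root, where labels strictly decrease towards the root; hence A_i is simply the shallowest of these LCAs.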
First, let’s try to figure out what arrays A are possible to achieve.
Suppose 1 \in S.
Then \text{lca}(1, i) = 1 for every i, and no LCA can have a label smaller than 1, so A_i = 1 for all i.
This gives the array [1, 1, \ldots, 1] with a sum of N.
From now on, we assume 1 \not \in S.
Let the children of vertex 1 be c_1, c_2, \ldots, c_k.
Observe that if S contains vertices from the subtrees of at least two different c_i’s, then we again end up with A = [1, 1, \ldots, 1].
This is because for any vertex u, we can find some vertex x \in S such that \text{lca}(u, x) = 1 (just choose a vertex of S that lies in a different child subtree from u).
Since we only care about distinct arrays A, this case is useless to us: [1, 1, \ldots, 1] has already been counted.
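Spelled out as an inequality chain: let x_1, x_2 \in S (our labels) lie in the subtrees of two different children of 1. Any vertex u lies in at most one child subtree, so at least one of them, say x_2, is outside the child subtree containing u (this also holds trivially for u = 1). Then:

```latex
1 \le A_u = \min_{x \in S} \text{lca}(x, u) \le \text{lca}(x_2, u) = 1
\quad\Longrightarrow\quad A_u = 1 \ \text{for every vertex } u
```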
So, the only “interesting” case is when all the elements of S lie within a single child subtree.
Now, suppose all the elements of S lie within the subtree of c_1.
Observe that every vertex i outside this subtree has A_i = 1, because its LCA with any vertex in the subtree of c_1 is just 1 itself.
So, if there are s_{c_1} vertices in the subtree of c_1, every array A under consideration has exactly (N - s_{c_1}) of its elements equal to 1.
On the other hand, for the vertices within the subtree of c_1, we are simply solving the same problem independently on that subtree!
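Written out, with A' denoting (our notation) the array that the same S produces when the subtree of c_1 is treated as a tree on its own. LCAs of two vertices inside that subtree are the same in both trees, so:

```latex
A_i =
\begin{cases}
1, & i \notin \text{subtree}(c_1) \\
A'_i, & i \in \text{subtree}(c_1)
\end{cases}
\qquad\Longrightarrow\qquad
\text{sum}(A) = (N - s_{c_1}) + \text{sum}(A')
```

Since A is determined by A' and vice versa, distinct arrays of the subproblem correspond exactly to distinct arrays of this case.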
This gives us the idea of using dynamic programming.
Define dp_u to be the answer for the subtree rooted at u.
Also define ct_u to be the number of distinct arrays possible when considering only the subtree rooted at u (we’ll need this for the transitions).
Recall that s_u denotes the number of vertices present in the subtree of u.
Let’s compute dp_u.
As seen earlier, there are really only two options: either every element of the array equals u, or all elements of the chosen subset lie within a single child subtree.
The first case is trivial: all the elements equal u, and there are s_u of them, so the overall sum is s_u \cdot u.
Otherwise, let c be a child of u and suppose all elements of the set lie within the subtree of c.
Then,
- There are s_u - s_c vertices in the subtree of u that lie outside the subtree of c.
All of them have u itself as their minimum LCA.
So, each possible array formed from within c gains (s_u - s_c) additional copies of u.
- By definition, there are ct_c distinct arrays possible when considering only the subtree of c.
- Also by definition, the sum of their sums equals dp_c.
- So, when lifting these arrays to u, each of them has its sum increased by (s_u - s_c) \cdot u.
This gives an overall contribution of dp_c + ct_c\cdot (s_u - s_c) \cdot u.
Summing this contribution over all children c, and adding s_u \cdot u from the first case, gives dp_u.
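Putting both cases together, the transitions read as follows. A leaf u has no children, so dp_u = u and ct_u = 1. The last line checks the formula on the star with \text{par}_2 = \text{par}_3 = 1, where s_1 = 3 and both leaves have s = ct = 1:

```latex
\begin{aligned}
dp_u &= s_u \cdot u + \sum_{c \in \text{children}(u)} \bigl( dp_c + ct_c \cdot (s_u - s_c) \cdot u \bigr) \\
ct_u &= 1 + \sum_{c \in \text{children}(u)} ct_c \\[4pt]
\text{star, } N = 3:\quad dp_1 &= 3 \cdot 1 + (2 + 1 \cdot 2 \cdot 1) + (3 + 1 \cdot 2 \cdot 1) = 3 + 4 + 5 = 12
\end{aligned}
```

The three terms 3, 4, 5 are the sums of the distinct arrays [1, 1, 1], [1, 2, 1] and [1, 1, 3].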
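ct_u can be computed similarly: it is the sum of ct_c over all children c, plus 1 for the array [u, u, \ldots, u]. (Arrays coming from different children never coincide: if S lies in the subtree of c, then A_c = c, whereas in every other case A_c = u.)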
(In fact, you may notice that we always have ct_u = s_u, so we don’t even need to maintain ct as a separate array.)
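The claim ct_u = s_u follows by induction on the subtree, using s_u = 1 + \sum_c s_c; substituting it into the transition removes ct entirely:

```latex
\begin{aligned}
\text{leaf } u:&\quad ct_u = 1 = s_u \\
\text{inductive step:}&\quad ct_u = 1 + \sum_{c} ct_c = 1 + \sum_{c} s_c = s_u \\
\text{hence:}&\quad dp_u = s_u \cdot u + \sum_{c} \bigl( dp_c + s_c \cdot (s_u - s_c) \cdot u \bigr)
\end{aligned}
```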
This gives us a dynamic programming solution in \mathcal{O}(N) time.
TIME COMPLEXITY:
\mathcal{O}(N) per testcase.
CODE:
Editorialist's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);
    const int mod = 998244353;

    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        vector par(n+1, 0);
        for (int i = 2; i <= n; ++i) {
            cin >> par[i];
        }

        // subsz[u] = s_u; children have larger labels, so a reverse sweep suffices
        vector dp(n+1, 0), ct(n+1, 0), subsz(n+1, 1);
        for (int i = n; i > 1; --i) {
            subsz[par[i]] += subsz[i];
        }

        // first case: the single array [u, u, ..., u] with sum s_u * u
        for (int i = 1; i <= n; ++i) {
            dp[i] = (1ll * i * subsz[i]) % mod;
            ct[i] = 1;
        }

        // when i is reached, all of its children (larger labels) are already merged into it
        for (int i = n; i > 1; --i) {
            // go into i from par[i]
            // everything outside has value par[i]
            int out = subsz[par[i]] - subsz[i];
            ct[par[i]] += ct[i];
            dp[par[i]] = (dp[par[i]] + 1ll*out*par[i]*ct[i] + dp[i]) % mod;
        }
        cout << dp[1] << '\n';
    }
}