MARKTREEHD - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: snow_29
Tester: iceknight1093
Editorialist: iceknight1093

DIFFICULTY:

Medium

PREREQUISITES:

DFS, dynamic programming

PROBLEM:

You’re given a tree, with two markers: one each at vertices 1 and N.

Consider a set S of vertices of the tree marked as special.

In one move, you can choose a marker and move it to an adjacent vertex.
Define f(S) to be the minimum number of moves needed to visit every special vertex with at least one marker, and end up in a state where both markers are back at their initial positions.

Compute the sum of f(S) across all choices of S.

EXPLANATION:

Recall from the easy version of the problem that the solution to a fixed subset S was as follows:

  • Consider the path from 1 to N, say 1 = x_1 \to x_2 \to\ldots\to x_k = N
  • Try every x_i as the highest vertex visited by the marker starting at N.
  • This splits the tree into two smaller trees: the subtree rooted at x_i and everything else.
  • Each smaller tree contains only one marker, so solve it on its own (the answer is twice the number of edges in the smallest subtree containing that marker and all special vertices of that tree), and add up the values for both trees.
  • The overall answer is then the minimum of this across all choices of x_i.

Let’s now gain a deeper understanding of the above solution.
Split the edges of the tree into two classes: edges that lie on the 1 \to N path, and those that don’t.

First, we look at non-path edges.
Let (u, \text{par}_u) be one such edge.
Then, observe that:

  • If there exists an element of S in the subtree of u, then this edge will definitely be visited no matter what the choice of x_i is (i.e. this edge must always be visited by one of the markers).
  • If there’s no element of S in the subtree of u, then this edge never needs to be visited.

This is really just the single-marker argument, with the entire 1\to N path compressed into a single node.

In particular, note that for every non-path edge, its contribution can be directly computed independently of the choice of x_i.
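For a fixed S, the rule above can be sketched as follows. This is a minimal sketch, not the author's code: it assumes the tree is given as a 1-indexed parent array `par` rooted at vertex 1 with `par[1] = 0` (mirroring how the author's code reads `par[i]` for i ≥ 2), and `neededNonPathEdges` is a hypothetical name. It counts the non-path edges (u, par_u) whose subtree of u contains a special vertex; each such edge costs exactly 2 moves.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Hypothetical helper (not from the original code).
// par: 1-indexed parent array of a tree rooted at 1, with par[1] = 0 (assumed format).
// special[u]: whether u belongs to S. N is taken to be n.
// Returns the number of non-path edges (u, par[u]) with a special vertex in u's subtree.
int neededNonPathEdges(int n, const vector<int>& par, const vector<bool>& special) {
    // Mark the 1 -> n path by walking parent pointers up from n.
    vector<bool> onPath(n + 1, false);
    for (int v = n; v != 0; v = par[v]) onPath[v] = true;

    // BFS order from the root, so every vertex appears after its parent.
    vector<vector<int>> children(n + 1);
    for (int v = 2; v <= n; v++) children[par[v]].push_back(v);
    vector<int> order = {1};
    for (size_t i = 0; i < order.size(); i++)
        for (int c : children[order[i]]) order.push_back(c);

    vector<bool> has(special);  // has[u]: u's subtree contains a special vertex
    int count = 0;
    for (int i = n - 1; i >= 1; i--) {  // reverse BFS order: children before parents
        int u = order[i];
        if (has[u]) has[par[u]] = true;
        if (!onPath[u] && has[u]) count++;  // (u, par[u]) is a needed non-path edge
    }
    return count;
}
```

For example, with edges 1-2, 2-3, 1-4, 4-5 and N = 5, the path is 1, 4, 5; choosing S = {3} forces both edges (3, 2) and (2, 1), so the count is 2.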

Next, let’s look at path edges.
Since we know the markers will never meet, the set of path edges traversed will be some prefix and some suffix.
That is, there will exist some indices i \lt j such that one marker will visit x_1, x_2, \ldots, x_i, the other marker will visit x_j, x_{j+1}, \ldots, x_k, and all of x_{i+1}, \ldots, x_{j-1} will be unvisited.

Clearly, it’s optimal for us to maximize the number of unvisited vertices - so let’s figure out how to do that.

Suppose x_i is unvisited on the path.
Then, that also means any non-path vertices that must be reached from x_i cannot be visited.
That is, let’s define C_i to be the set of vertices u such that if we repeatedly move upwards from u, the first vertex on the 1\to N path we touch is x_i (this is equivalent to \text{lca}(N, u) = x_i).

Then, if we decide not to visit x_i with either marker, we also cannot visit any vertex in C_i with any marker.
This is clearly only possible if no element of C_i is in S, i.e. C_i \cap S = \emptyset.

With this in mind, let’s compute for each x_i on the 1\to N path whether it can be possibly skipped or not (which is doable with a straightforward DFS).
Note in particular that 1 and N themselves are always treated as unskippable, since the markers start and end there.
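One possible way to compute these flags for a fixed S is sketched below, under the same assumed input format (1-indexed `par` with `par[1] = 0`, tree rooted at 1); the function name `unskippable` is hypothetical. Each off-path vertex inherits the index i of the set C_i it belongs to from its parent, and x_i is then flagged if i is 1 or k, or if some special vertex lands in C_i. The sketch uses a BFS order instead of a recursive DFS; either traversal works.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>
using namespace std;

// Hypothetical helper. par: 1-indexed parent array rooted at 1, par[1] = 0 (assumed); N = n.
// special[u]: whether u is in S. Returns keep[1..k], where keep[i] says x_i cannot be skipped.
vector<bool> unskippable(int n, const vector<int>& par, const vector<bool>& special) {
    // Path x_1 = 1, ..., x_k = n: walk up from n, then reverse.
    vector<int> path;
    for (int v = n; v != 0; v = par[v]) path.push_back(v);
    reverse(path.begin(), path.end());
    int k = (int)path.size();

    vector<int> attach(n + 1, 0);  // attach[u] = i  <=>  u belongs to C_i
    for (int i = 0; i < k; i++) attach[path[i]] = i + 1;

    // BFS order from the root, so each vertex is handled after its parent.
    vector<vector<int>> children(n + 1);
    for (int v = 2; v <= n; v++) children[par[v]].push_back(v);
    vector<int> order = {1};
    for (size_t i = 0; i < order.size(); i++)
        for (int c : children[order[i]]) order.push_back(c);
    for (int u : order)
        if (attach[u] == 0) attach[u] = attach[par[u]];  // off-path: same C_i as its parent

    vector<bool> keep(k + 1, false);
    keep[1] = true;  // x_1 = 1 and x_k = N are the markers' starting points
    keep[k] = true;
    for (int u = 1; u <= n; u++)
        if (special[u]) keep[attach[u]] = true;  // C_i intersects S
    return keep;
}
```

For the tree with edges 1-2, 2-3, 2-4 and N = 4, the path is 1, 2, 4 and vertex 3 lies in C_2, so x_2 becomes unskippable exactly when 2 or 3 is special.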

Suppose the path vertices that cannot be skipped are at indices i_1, i_2, \ldots, i_r.
Then, because of the prefix/suffix visitation condition, we can only skip visiting vertices of the form x_{i_j + 1}, \ldots, x_{i_{j+1}-1} for some j.
That is, some prefix of the ‘important’ vertices must be visited from 1, the remaining suffix is then forced to be visited from N, and the best we can do is to skip everything in between.

Note that if the first j important path vertices are visited from 1, then the number of skipped edges equals exactly i_{j+1} - i_j.
We’re looking to maximize this, so it’s clearly optimal to just take the maximum adjacent difference here.
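Given the unskippable flags, the number of path edges that must still be walked is (k-1) minus the largest gap. A minimal sketch, assuming `keep` is a 1-indexed `vector<bool>` of length k+1 with `keep[1]` and `keep[k]` set (the name `traversedPathEdges` is hypothetical):

```cpp
#include <algorithm>
#include <cassert>
#include <vector>
using namespace std;

// keep[i] (1 <= i <= k): whether path vertex x_i is important; keep[0] is unused.
// Assumes keep[1] and keep[k] are true, as the markers start at x_1 = 1 and x_k = N.
// Returns the number of path edges some marker must traverse: (k - 1) - (largest gap).
int traversedPathEdges(const vector<bool>& keep) {
    int k = (int)keep.size() - 1;
    int last = 1, maxGap = 0;
    for (int i = 2; i <= k; i++) {
        if (!keep[i]) continue;
        maxGap = max(maxGap, i - last);  // path edges between consecutive important x_last, x_i
        last = i;
    }
    return (k - 1) - maxGap;
}
```

For example, with k = 5 and important indices 1, 4, 5, the gaps are 3 and 1, so only 4 - 3 = 1 path edge is walked (by the marker from N).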

This is how we obtain a solution in \mathcal{O}(N) for the fixed-subset version.

  • Edges not on the 1\to N path have their contributions counted independently.
  • As for edges on the 1\to N path, we first find all ‘important’ vertices on this path; the number of traversed path edges is then the length of the path (k-1 edges) minus the maximum distance between adjacent important vertices.
  • All of this can be implemented in \mathcal{O}(N) with a single DFS, since it only needs a little subtree information.

We now adapt the above solution to counting across all subsets.

Since the answer is twice the number of traversed edges, we'll count the contribution of each edge separately.
Let’s look at path edges and non-path edges separately.

First, consider non-path edges.
Say we’re looking at edge (u, \text{par}_u) which is not on the 1\to N path.
As noted above, this edge will be traversed (twice) if and only if some element of S lies in the subtree of u.

Let s_u denote the number of vertices in the subtree of u.
There are then (2^{s_u} - 1) ways to choose a non-empty subset of vertices in the subtree of u, and 2^{N-s_u} choices for the vertices outside it (the part chosen outside the subtree may be empty; that’s not an issue).

So, this edge has an overall contribution of 2\cdot 2^{N-s_u} \cdot (2^{s_u}-1).
The values of s_u can be precomputed in linear time with a DFS, after which the above quantity is easy to compute.
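A sketch of this sum, assuming the same hypothetical parent-array input (`par[1] = 0`, rooted at 1) and the modulus 998244353 used in the author's code; `nonPathContribution` is a hypothetical name. It computes s_u for every vertex in one bottom-up pass and adds 2 · 2^{N-s_u} · (2^{s_u}-1) for every vertex u that is not on the 1→N path (exactly then is (u, par_u) a non-path edge).

```cpp
#include <cassert>
#include <vector>
using namespace std;

const long long MOD = 998244353;  // modulus used by the author's code

// Hypothetical helper. par: 1-indexed parent array rooted at 1, par[1] = 0 (assumed); N = n.
// Returns the sum over non-path edges (u, par[u]) of 2 * 2^(n - s_u) * (2^(s_u) - 1) mod MOD.
long long nonPathContribution(int n, const vector<int>& par) {
    vector<long long> p2(n + 1, 1);
    for (int i = 1; i <= n; i++) p2[i] = p2[i - 1] * 2 % MOD;

    vector<bool> onPath(n + 1, false);
    for (int v = n; v != 0; v = par[v]) onPath[v] = true;

    // BFS order from the root; processing it backwards finalizes s[u] before s[par[u]].
    vector<vector<int>> children(n + 1);
    for (int v = 2; v <= n; v++) children[par[v]].push_back(v);
    vector<int> order = {1};
    for (size_t i = 0; i < order.size(); i++)
        for (int c : children[order[i]]) order.push_back(c);

    vector<int> s(n + 1, 1);  // s[u] = subtree size, each vertex counting itself
    long long total = 0;
    for (int i = n - 1; i >= 1; i--) {
        int u = order[i];
        s[par[u]] += s[u];
        if (!onPath[u])  // (u, par[u]) is a non-path edge
            total = (total + 2 * p2[n - s[u]] % MOD * ((p2[s[u]] - 1 + MOD) % MOD)) % MOD;
    }
    return total;
}
```

For the tree with edges 1-2, 2-3, 2-4 and N = 4, only the edge (3, 2) is off the path; it is walked in the 2^3 = 8 subsets containing 3, giving 2 · 8 = 16.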

Next, let’s look at path edges.
As we saw, for the 1\to N path, what really matters is the maximum distance between ‘important’ vertices on it.

So, a natural idea would be: for each integer r, compute the number of subsets S for which the maximum distance between adjacent important vertices is exactly r, and then add this count multiplied by 2\cdot (k-1-r) to the answer (where k is the number of vertices on the 1\to N path, endpoints included).

However, it’s somewhat hard to ensure that the maximum distance is exactly r.
Instead, let’s relax the criterion a bit and allow for the maximum distance to be at most r.

This is much easier to compute, and can be done with the help of dynamic programming.

Specifically, we can do the following.
First, define y_i to be the count of vertices u such that \text{lca}(N, u) = x_i.
Essentially, y_i is the size of the set C_i we talked about above - this is important because as noted previously, x_i can potentially be skipped only if C_i \cap S = \emptyset.

Now, define dp(i, r) to be the number of choices of subsets such that:

  • We only consider the sets C_1, C_2, \ldots, C_i (i.e. the path up to x_i and the non-path vertices hanging off it),
  • Every adjacent distance between important path vertices is \le r; and
  • x_i is an important vertex.

To compute dp(i, r),

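  • We need to ensure that x_i is important.
    For 1 \lt i \lt k, that means there must be some intersection between C_i and S, and there are 2^{y_i}-1 choices for a non-empty subset of these elements.
    For i = 1 and i = k the vertex is important regardless (the markers start there), so all 2^{y_i} subsets of C_i are allowed.
  • Next, we need to ensure that the previous important vertex is no further than r away.
    Further, if we fix j \lt i as the previous important vertex, then for every t in [j+1, i-1] we must choose nothing from C_t.
    So, once j is fixed, the number of valid choices for C_1, \ldots, C_j is exactly dp(j, r).
  • Thus, for 1 \lt i \lt k we obtain dp(i, r) = (2^{y_i}-1)\cdot (dp(i-1, r) + dp(i-2, r) + \ldots + dp(i-r, r)), where terms with index below 1 are treated as 0.
    The base case is dp(1, r) = 2^{y_1}, and for i = k the factor 2^{y_i}-1 is replaced by 2^{y_k}.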

Now, dp(k, r) gives us the number of subsets with maximum distance at most r (since C_1, \ldots, C_k partition all N vertices, these are subsets of the entire vertex set).
To obtain the number of subsets with maximum distance exactly r, simply subtract dp(k, r-1) from it.
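Writing A_r = dp(k, r), the path-edge part of the answer is the sum below. Every adjacent distance is at least 1, so A_0 = 0, and A_{k-1} counts all 2^N subsets. The second form is just a telescoping rearrangement (each coefficient k-1-r drops by exactly 1 as r increases); it is an optional simplification, not something the solution above relies on.

```latex
\sum_{r=1}^{k-1} 2(k-1-r)\,\bigl(A_r - A_{r-1}\bigr) \;=\; 2\sum_{r=1}^{k-2} A_r,
\qquad A_r = dp(k, r),\quad A_0 = 0.
```

As a check, for the path 1 - 2 - 3 (N = 3, k = 3, every y_i = 1) only vertex 2 matters: A_1 = 4 (the subsets containing 2), so the path edges contribute 2 · 4 = 8, which matches a direct count (each of those 4 subsets needs one marker to step onto 2 and back).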


The above DP has \mathcal{O}(N^2) states, and an \mathcal{O}(N) transition from each state.
However, the transitions are easily optimized to constant time since they’re just range sums over the dp(\cdot, r) array (keep prefix sums for each fixed r), and we thus obtain a solution that’s \mathcal{O}(N^2) overall, which is fast enough.
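A sketch of the path-edge DP with the range-sum optimization, assuming the sizes y_1, …, y_k are already known (1-indexed `vector<int>`, index 0 unused) and using the modulus 998244353 from the author's code. The names `waysImportant`, `countAtMost` and `pathContribution` are hypothetical. Unlike the author's code, which works with probabilities p[i] = 2^{-y_i} and modular inverses, this version counts subsets directly and keeps prefix sums of dp(·, r), so each r costs O(k) and the whole loop O(k^2).

```cpp
#include <algorithm>
#include <cassert>
#include <vector>
using namespace std;

const long long MOD = 998244353;  // same modulus as the author's code

// w[i] = number of ways to choose S ∩ C_i so that x_i is important.
// x_1 = 1 and x_k = N are always important, so any of the 2^{y_i} subsets works there.
vector<long long> waysImportant(const vector<int>& y) {
    int k = (int)y.size() - 1;
    vector<long long> w(k + 1, 0);
    for (int i = 1; i <= k; i++) {
        long long p = 1;
        for (int t = 0; t < y[i]; t++) p = p * 2 % MOD;  // 2^{y_i}; the y_i sum to N
        w[i] = (i == 1 || i == k) ? p : (p - 1 + MOD) % MOD;
    }
    return w;
}

// dp(k, r): number of subsets in which every adjacent important distance is <= r.
long long countAtMost(const vector<long long>& w, int r) {
    int k = (int)w.size() - 1;
    vector<long long> dp(k + 1, 0), pre(k + 1, 0);  // pre[i] = dp[1] + ... + dp[i]
    dp[1] = pre[1] = w[1];
    for (int i = 2; i <= k; i++) {
        int lo = max(1, i - r);  // previous important index j needs i - j <= r
        long long window = lo <= i - 1 ? (pre[i - 1] - pre[lo - 1] + MOD) % MOD : 0;
        dp[i] = w[i] * window % MOD;
        pre[i] = (pre[i - 1] + dp[i]) % MOD;
    }
    return dp[k];
}

// Sum over all S of 2 * (number of traversed path edges), modulo MOD.
long long pathContribution(const vector<int>& y) {
    int k = (int)y.size() - 1;
    vector<long long> w = waysImportant(y);
    long long total = 0, prev = 0;  // prev = dp(k, r - 1); dp(k, 0) = 0 when k >= 2
    for (int r = 1; r <= k - 1; r++) {
        long long cur = countAtMost(w, r);
        long long exact = (cur - prev + MOD) % MOD;  // maximum distance exactly r
        total = (total + exact * (2 * (k - 1 - r)) % MOD) % MOD;
        prev = cur;
    }
    return total;
}
```

For instance, on the path 1 - 2 - 3 - 4 (every y_i = 1) this gives 32: of the 16 subsets, the 4 containing both 2 and 3 walk 2 path edges, the 8 containing exactly one of them walk 1, and the remaining 4 walk none, so 2 · (4·2 + 8·1) = 32.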

TIME COMPLEXITY:

\mathcal{O}(N^2) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

const int mod = 998244353;

struct mint{
    int x;

    mint (){ x = 0;}
    mint (int32_t xx){ x = xx % mod; if (x < 0) x += mod;}
    mint (long long xx){ x = xx % mod; if (x < 0) x += mod;}

    int val(){
        return x;
    }
    mint &operator++(){
        x++;
        if (x == mod) x = 0;
        return *this;
    }
    mint &operator--(){
        if (x == 0) x = mod;
        x--;
        return *this;
    }
    mint operator++(int32_t){
        mint result = *this;
        ++*this;
        return result;
    }
    
    mint operator--(int32_t){
        mint result = *this;
        --*this;
        return result;
    }
    mint& operator+=(const mint &b){
        x += b.x;
        if (x >= mod) x -= mod;
        return *this;
    }
    mint& operator-=(const mint &b){
        x -= b.x;
        if (x < 0) x += mod;
        return *this;
    }
    mint& operator*=(const mint &b){
        long long z = x;
        z *= b.x;
        z %= mod;
        x = (int)z;
        return *this;
    }
    mint operator+() const {
        return *this;
    }
    mint operator-() const {
        return mint() - *this;
    }
    mint operator/=(const mint &b){
        return *this = *this * b.inv();
    }
    mint power(long long n) const {
        mint ok = *this, r = 1;
        while (n){
            if (n & 1){
                r *= ok;
            }
            ok *= ok;
            n >>= 1;
        }
        return r;
    }
    mint inv() const {
        return power(mod - 2);
    }
    friend mint operator+(const mint& a, const mint& b){ return mint(a) += b;}
    friend mint operator-(const mint& a, const mint& b){ return mint(a) -= b;}
    friend mint operator*(const mint& a, const mint& b){ return mint(a) *= b;}
    friend mint operator/(const mint& a, const mint& b){ return mint(a) /= b;}
    friend bool operator==(const mint& a, const mint& b){ return a.x == b.x;}
    friend bool operator!=(const mint& a, const mint& b){ return a.x != b.x;}
    mint power(mint a, long long n){
        return a.power(n);
    }
    friend ostream &operator<<(ostream &os, const mint &m) {
        os << m.x;
        return os;
    }
    explicit operator bool() const {
        return x != 0;
    }
};

// Remember to check MOD

void Solve() 
{
    int n; cin >> n;
    
    vector <int> par(n + 1);
    vector<vector<int>> adj(n + 1);
    for (int i = 2; i <= n; i++){
        cin >> par[i];
        adj[i].push_back(par[i]);
        adj[par[i]].push_back(i);
    }
    
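    // on[v] = true iff v lies on the 1 -> N path (walk parent pointers up from N)
    vector <bool> on(n + 1, false);
    int s = n;
    on[s] = true;
    
    // ans starts as (k - 1) - 1 = k - 2, where k is the number of path vertices;
    // the DP below indexes by r = (largest gap) - 1, so (k - 2) - r = k - 1 - (largest gap)
    mint ans = 0;
    while (s != 1){
        ans++;
        s = par[s];
        on[s] = true;
    }
    ans--;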
    
    vector <int> sub(n + 1, 0);
    vector <int> vec;
    vector <mint> p2(n + 1);
    p2[0] = 1;
    for (int i = 1; i <= n; i++) p2[i] = p2[i - 1] * 2;
    vector <mint> i2(n + 1);
    for (int i = 0; i <= n; i++) i2[i] = 1 / p2[i];
    
    ans *= p2[n];
    
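    // sub[u] = size of u's subtree, never descending into other path vertices; for a path
    // vertex i this is y_i = |C_i|. Each non-path edge (v, u) adds (2^{sub[v]} - 1) * 2^{n - sub[v]};
    // the overall factor 2 is applied at the end.
    auto dfs = [&](auto self, int u, int par) -> void{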
        sub[u] += 1;
        
        for (int v : adj[u]) if (v != par && !on[v]){
            self(self, v, u);
            sub[u] += sub[v];
            ans += (p2[sub[v]] - 1) * p2[n - sub[v]];
        }
    };
    
    for (int i = 1; i <= n; i++){
        if (on[i]){
            dfs(dfs, i, -1);
            
            if (i != 1 && i != n){
                vec.push_back(sub[i]);
            }
        }
    }
    
    int m = vec.size();
    vector <mint> p(m + 2, 0);
    for (int i = 0; i < vec.size(); i++){
        p[i + 1] = i2[vec[i]];
    }
    
    vector <mint> prob(n + 1, 0);
    vector <mint> pp(m + 2, 1);
    for (int i = 1; i <= m + 1; i++) pp[i] = pp[i - 1] * p[i];
    vector <mint> ipp(m + 2);
    for (int i = 0; i <= m + 1; i++){
        ipp[i] = 1 / pp[i];
    }
    
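    // prob[r] = probability (S uniformly random) that every gap between consecutive important
    // path vertices skips at most r vertices. p[i] = 2^{-y_i} is the chance that x_i is not
    // important; sum holds the sliding window of dp[j] / pp[j], so each step is O(1).
    for (int r = 0; r <= n; r++){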
        mint sum = 0;
        vector <mint> dp(m + 2, 0);
        dp[0] = 1;
        
        sum += dp[0] * ipp[0];
        
        for (int i = 1; i <= m + 1; i++){
            dp[i] = sum * (1 - p[i]) * pp[i - 1];
            sum += dp[i] * ipp[i];
            if (i > r){
                sum -= dp[i - r - 1] * ipp[i - r - 1];
            }
        }
        
        prob[r] = dp[m + 1];
    }
    
    // turn "at most r skipped vertices" into "exactly r skipped vertices"
    for (int i = n; i >= 1; i--){
        prob[i] -= prob[i - 1];
    }
    
    for (int r = 0; r <= n; r++){
        ans -= prob[r] * r * p2[n];
    }
    
    ans *= 2;
    
    cout << ans << "\n";
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}