TREEDEST - Editorial


Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: lawliet_p
Testers: tabr and errorgorn
Editorialist: iceknight1093




Prerequisites: Euler tour, Fenwick trees (or segment trees), binary lifting.


The score of a tree T with N vertices is obtained as follows:

  • Pick a leaf vertex u, and add the distance between u and the farthest node from u to the score.
    Then, delete u from T.
  • Repeat the above process |T| times, till T is empty.

You are given N trees. The first tree is a single vertex, and tree (i+1) is obtained from tree i by attaching a single new vertex to it as a leaf.
Find the maximum possible score of each tree.
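To sanity-check the claims below on small inputs, the score can be brute-forced straight from the definition. The sketch below is not part of the intended solution: it is an exponential memoized search over subsets of still-present vertices, so we assume tiny trees (at most 16 vertices) given as 0-indexed adjacency lists. `bruteMaxScore` is a hypothetical name. We treat a lone remaining vertex as a leaf (otherwise the tree could never be emptied). Since deleting leaves keeps the remaining vertices connected, distances inside the remaining tree equal distances in the original one, so all-pairs distances are computed once.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <queue>
#include <vector>

// Exhaustive maximum score, straight from the definition (tiny trees only).
// adj: 0-indexed adjacency lists of a tree with 1..16 vertices.
long long bruteMaxScore(const std::vector<std::vector<int>>& adj) {
    const int n = static_cast<int>(adj.size());
    assert(n >= 1 && n <= 16);

    // All-pairs distances by BFS from every vertex.
    std::vector<std::vector<int>> dist(n, std::vector<int>(n, -1));
    for (int s = 0; s < n; ++s) {
        std::queue<int> q;
        dist[s][s] = 0;
        q.push(s);
        while (!q.empty()) {
            const int v = q.front();
            q.pop();
            for (int w : adj[v])
                if (dist[s][w] < 0) { dist[s][w] = dist[s][v] + 1; q.push(w); }
        }
    }

    // memo[mask]: best score for the tree formed by the vertices in mask (-1 = unknown).
    std::vector<long long> memo(1u << n, -1LL);
    std::function<long long(unsigned)> best = [&](unsigned mask) -> long long {
        if (mask == 0) return 0;
        if (memo[mask] >= 0) return memo[mask];
        long long res = 0;
        for (int u = 0; u < n; ++u) {
            if (!(mask >> u & 1u)) continue;
            int deg = 0;
            for (int w : adj[u]) deg += static_cast<int>(mask >> w & 1u);
            if (deg > 1) continue;  // only a leaf (or a lone vertex) may be deleted
            int far = 0;            // distance from u to its farthest remaining vertex
            for (int v = 0; v < n; ++v)
                if (mask >> v & 1u) far = std::max(far, dist[u][v]);
            res = std::max(res, far + best(mask ^ (1u << u)));
        }
        return memo[mask] = res;
    };
    return best((1u << n) - 1);
}
```

For example, a path on 3 vertices gives 2 + 1 + 0 = 3, and a star with 3 leaves gives 5.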


Let T_i denote the i-th tree created.

First, we need to figure out how to compute the maximum score for a single tree.

Maximum score

Consider some tree T.
Its maximum score can be obtained as follows:

  • Fix the endpoints (x, y) of a diameter of T.
  • Then, remove all vertices that don’t lie on this diameter, one at a time.
  • Finally, we’re left with a path, which can be removed in any order, one endpoint at a time.

A proof is fairly straightforward, using the fact that for any vertex u\in T, one farthest vertex from u is either x or y.
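The strategy above can be sketched for a single tree as follows, assuming 0-indexed adjacency lists (`bfsDist` and `diameterScore` are illustrative names). It finds a diameter (x, y) with two BFS passes, adds \max(d(v, x), d(v, y)) for every vertex off the path between x and y, and then adds L + (L-1) + ... + 1 + 0 = L(L+1)/2 for removing the path of length L one endpoint at a time. A vertex v lies on that path exactly when d(v, x) + d(v, y) = L.

```cpp
#include <algorithm>
#include <cassert>
#include <queue>
#include <vector>

// BFS distances from src in an unweighted tree (0-indexed adjacency lists).
std::vector<int> bfsDist(const std::vector<std::vector<int>>& adj, int src) {
    std::vector<int> d(adj.size(), -1);
    std::queue<int> q;
    d[src] = 0;
    q.push(src);
    while (!q.empty()) {
        const int v = q.front();
        q.pop();
        for (int w : adj[v])
            if (d[w] < 0) { d[w] = d[v] + 1; q.push(w); }
    }
    return d;
}

// Score of the diameter strategy: off-diameter vertices first, then the path.
long long diameterScore(const std::vector<std::vector<int>>& adj) {
    assert(!adj.empty());
    const int n = static_cast<int>(adj.size());
    const std::vector<int> d0 = bfsDist(adj, 0);
    const int x = static_cast<int>(std::max_element(d0.begin(), d0.end()) - d0.begin());
    const std::vector<int> dx = bfsDist(adj, x);
    const int y = static_cast<int>(std::max_element(dx.begin(), dx.end()) - dx.begin());
    const std::vector<int> dy = bfsDist(adj, y);

    const long long L = dx[y];          // diameter length
    long long score = L * (L + 1) / 2;  // path removed one endpoint at a time
    for (int v = 0; v < n; ++v)
        if (dx[v] + dy[v] != L)         // v is not on the x-y path
            score += std::max(dx[v], dy[v]);
    return score;
}
```

For the star with center 0 and leaves 1, 2, 3, this returns 5: the off-diameter leaf contributes 2, and the path of length 2 contributes 2 + 1 + 0.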

This tells us that the diameter of each T_i is what matters, so we need to maintain it as vertices are added.
Let D_i = (x_i, y_i) denote the endpoints of a diameter of T_i.
Then, if T_{i+1} is obtained by adding vertex u to T_i, D_{i+1} is the pair among \{x_i, y_i, u\} with the maximum distance.
For convenience, we won’t pick u as a diameter endpoint unless it results in a strict increase.
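A minimal sketch of this update rule, assuming the final tree T_N is given as a parent array with par[v] < v (vertex v is the one added to form T_v) and vertex 1 as the root. All distances are computed in T_N with a binary-lifting LCA, since the path between two vertices of T_i is the same in T_N. `LcaTree` and `prefixDiameters` are hypothetical names; the function returns only the diameter lengths, though the endpoints x and y are available inside the loop.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical helper: depths and binary-lifting LCA on the final tree T_N.
struct LcaTree {
    int LOG = 1;
    std::vector<int> depth;
    std::vector<std::vector<int>> up;  // up[k][v]: 2^k-th ancestor of v (root maps to itself)

    // par[v] is the parent of v for v = 2..n (1-indexed; par[0], par[1] unused), par[v] < v.
    explicit LcaTree(const std::vector<int>& par) {
        const int n = static_cast<int>(par.size()) - 1;
        while ((1 << LOG) < n) ++LOG;
        depth.assign(n + 1, 0);
        up.assign(LOG, std::vector<int>(n + 1, 1));
        for (int v = 2; v <= n; ++v) {
            assert(1 <= par[v] && par[v] < v);
            up[0][v] = par[v];
            depth[v] = depth[par[v]] + 1;  // parent already processed
        }
        for (int k = 1; k < LOG; ++k)
            for (int v = 1; v <= n; ++v) up[k][v] = up[k - 1][up[k - 1][v]];
    }

    int lca(int u, int v) const {
        if (depth[u] < depth[v]) std::swap(u, v);
        const int diff = depth[u] - depth[v];
        for (int k = 0; k < LOG; ++k)
            if (diff >> k & 1) u = up[k][u];  // lift u to v's depth
        if (u == v) return u;
        for (int k = LOG - 1; k >= 0; --k)
            if (up[k][u] != up[k][v]) { u = up[k][u]; v = up[k][v]; }
        return up[0][u];
    }

    int dist(int u, int v) const { return depth[u] + depth[v] - 2 * depth[lca(u, v)]; }
};

// diam[i] = diameter length of T_i (the tree on vertices 1..i), for i = 1..n.
std::vector<int> prefixDiameters(const std::vector<int>& par) {
    const LcaTree t(par);
    const int n = static_cast<int>(par.size()) - 1;
    std::vector<int> diam(n + 1, 0);
    int x = 1, y = 1, len = 0;  // current diameter endpoints and length
    for (int u = 2; u <= n; ++u) {
        const int dx = t.dist(u, x), dy = t.dist(u, y);
        // Replace an endpoint by u only on a strict increase.
        if (dx > len && dx >= dy) { y = u; len = dx; }
        else if (dy > len) { x = u; len = dy; }
        diam[u] = len;
    }
    return diam;
}
```

For example, with par = {0, 0, 1, 1, 1} (a star centered at vertex 1), the diameters of T_1, ..., T_4 are 0, 1, 2, 2; when vertex 4 arrives, both candidate pairs only tie the current length, so the endpoints stay as they are.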

Now, suppose we know the score of T_i. Let’s see how we can compute the score of T_{i+1} from it; with u being the newly added vertex.

  • First, if D_{i+1} = D_i, then almost nothing changes.
    The newly added node isn’t part of the diameter, so it contributes its farthest distance and then vanishes, leaving us with T_i again (whose score we already know).
  • Otherwise, u is an endpoint of a diameter of T_{i+1}, and diam(T_{i+1}) \gt diam(T_i); in fact diam(T_{i+1}) = diam(T_i) + 1, since attaching a leaf increases no distance by more than 1.
    Let x be the other endpoint of the diameter.
    Then, if p_u is the vertex to which u is attached, we know (x, p_u) was itself a diameter of T_i (do you see why?).

The second point means that all we need is a way to compute the change in the answer when the diameter is extended by 1.

For that, let’s go back to analyzing the score-finding process itself.
Recall that for the fixed diameter (x, p_u), we did the following:

  • For each vertex v not on the diameter, add \max(d(v, x), d(v, p_u)) to the score.
  • Then, remove the diameter one endpoint at a time.

In particular, what matters for each v is which of x and p_u is farther from it.

For now, assume d(x, p_u) is even. The odd case is similar, though not exactly the same, and needs a little more care in the implementation.
Let S_1 be the set of all vertices that are strictly farther from x than from p_u, and let S_2 be the set of all other vertices.
Note that S_1 and S_2 are both connected.

When u is appended to the diameter, note that:

  • Every element of S_1 still has x as a farthest vertex; and
  • All the elements of S_2 will have u as their new farthest vertex.

So, there are only two real changes:

  • Elements of S_2 that aren’t on the diameter will contribute +1 to the score.
  • The diameter is increased to d(x, u), so that value gets added to the score.
    Everything else remains exactly the same.

d(x, u) is easy to compute; so if we’re able to quickly find the number of elements of S_2 that aren’t on the diameter, we’ll be done.
Note that exactly \left \lceil \frac{d(x, u) + 1}{2}\right\rceil elements of S_2 lie on the diameter, so we can instead just find the size of S_2 and subtract this number from it.
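For the even case, the bullets above combine into the following transition. Here score(T) is our shorthand for the maximum score of T, L = d(x, p_u), and |S_2| is measured in T_i. As a worked check: take T_4 to be the star with center 2 and leaves 1, 3, 4 (score 5), and attach vertex 5 to vertex 4. Then x = 1, p_u = 4, L = 2, S_2 = \{1, 2, 3\}, and the second line gives 5 + 3 + 3 - 2 = 9, which matches computing the score of T_5 directly.

```latex
\begin{aligned}
L \text{ even}:\quad & \#\{\text{diameter vertices in } S_2\} = \tfrac{L}{2} + 1 = \left\lceil \tfrac{d(x, u) + 1}{2} \right\rceil, \\
D_{i+1} \ne D_i:\quad & \operatorname{score}(T_{i+1}) = \operatorname{score}(T_i) + d(x, u) + |S_2| - \left\lceil \tfrac{d(x, u) + 1}{2} \right\rceil, \\
D_{i+1} = D_i:\quad & \operatorname{score}(T_{i+1}) = \operatorname{score}(T_i) + \max\big(d(u, x_i),\, d(u, y_i)\big).
\end{aligned}
```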

Since S_1 and S_2 partition T_i, we could also just find the size of S_1 instead.
Recall that S_1 and S_2 are both connected.
In particular, at least one of them is exactly the set of vertices of T_i that lie in some subtree of T_N.
That is, either S_1 or S_2 will satisfy the following condition:

  • There exists a unique vertex w, such that all elements of the chosen set lie in the subtree of w and none of the elements of the other set lie in the subtree of w.
    “Subtree” here refers to a subtree in T_N, the final tree.

Finding this “splitting vertex” w is possible with binary lifting, since w lies on the diameter and can be reached by jumping upward from one of its endpoints. After that, we only need the number of active vertices in the subtree of w; this is a subtree sum query, a standard task for an Euler tour combined with a Fenwick tree or segment tree.
After adding each new vertex, update the data structure accordingly to activate it.
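A sketch of the two primitives used here, under our own assumptions: vertices are numbered 1..N in insertion order, par[v] < v is the parent of v in T_N, and vertex 1 is the root. `Fenwick`, `SubtreeCounter`, `kthAncestor`, `activate` and `activeInSubtree` are illustrative names, not taken from the reference solutions. The Euler tour here is iterative to avoid deep recursion on path-shaped trees.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Fenwick tree over Euler-tour positions 1..n; an active vertex contributes 1.
struct Fenwick {
    std::vector<int> bit;
    explicit Fenwick(int n) : bit(n + 1, 0) {}
    void add(int i) {
        for (; i < static_cast<int>(bit.size()); i += i & -i) ++bit[i];
    }
    int prefix(int i) const {
        int s = 0;
        for (; i > 0; i -= i & -i) s += bit[i];
        return s;
    }
    int range(int l, int r) const { return prefix(r) - prefix(l - 1); }
};

// Hypothetical helper: Euler tour + binary lifting on T_N, plus active-vertex counts.
struct SubtreeCounter {
    int LOG = 1;
    std::vector<int> tin, tout;
    std::vector<std::vector<int>> up;  // up[k][v]: 2^k-th ancestor (root maps to itself)
    Fenwick fw;

    // par[v] is the parent of v for v = 2..n (1-indexed), assumed < v.
    explicit SubtreeCounter(const std::vector<int>& par)
        : fw(static_cast<int>(par.size()) - 1) {
        const int n = static_cast<int>(par.size()) - 1;
        while ((1 << LOG) < n) ++LOG;
        up.assign(LOG, std::vector<int>(n + 1, 1));
        std::vector<std::vector<int>> children(n + 1);
        for (int v = 2; v <= n; ++v) {
            up[0][v] = par[v];
            children[par[v]].push_back(v);
        }
        for (int k = 1; k < LOG; ++k)
            for (int v = 1; v <= n; ++v) up[k][v] = up[k - 1][up[k - 1][v]];

        // Iterative DFS: [tin[v], tout[v]] is the position range of v's subtree.
        tin.assign(n + 1, 0);
        tout.assign(n + 1, 0);
        int timer = 0;
        std::vector<std::pair<int, std::size_t>> stack;  // (vertex, next child index)
        stack.emplace_back(1, 0);
        tin[1] = ++timer;
        while (!stack.empty()) {
            const int v = stack.back().first;
            std::size_t& idx = stack.back().second;
            if (idx < children[v].size()) {
                const int c = children[v][idx++];
                tin[c] = ++timer;
                stack.emplace_back(c, 0);  // idx is not used after this point
            } else {
                tout[v] = timer;
                stack.pop_back();
            }
        }
    }

    // Ancestor of v that is k levels higher (k must not exceed v's depth).
    int kthAncestor(int v, int k) const {
        assert(k >= 0);
        for (int b = 0; b < LOG; ++b)
            if (k >> b & 1) v = up[b][v];
        return v;
    }

    void activate(int v) { fw.add(tin[v]); }

    int activeInSubtree(int w) const { return fw.range(tin[w], tout[w]); }
};
```

With the path 1-2-3-4 plus an extra child 5 of vertex 2 (par = {0, 0, 1, 2, 3, 2}), after activating vertices 1 to 4, `kthAncestor(4, 2)` is 2 and `activeInSubtree(2)` is 3. In the author's code below, the same two steps appear as `upPath` followed by `BIT.query(tin[root], tout[root])`.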

Note that we technically need distances in a “dynamic” tree, but we can simply compute distances in T_N instead: the path between two vertices of T_i is the same in T_i and in T_N, so the distances agree.


Time complexity: \mathcal{O}(N\log N) per testcase.


Author's code (C++)
#include <iostream>
#include <vector>
#include <bits/stdc++.h>

using namespace std;
typedef long long ll;

const int lg = 20;
const int maxn = 1000010;

// A diameter (a, b) of the current tree, its length and lca(a, b) in T_N.
struct DiamInfo
{
    int a, b;
    int size, lca;
};

// Fenwick tree counting active vertices by Euler-tour position.
class FenwickTree
{
    public:
        int query(int v)
        {
            int ans = 0;

            for( ; v > 0 ; v -= v & -v)
                ans += bit[v];

            return ans;
        }

        int query(int l, int r) { return query(r) - query(l - 1); }
        void update(int v) { for( ; v <= n ; v += v & -v) bit[v]++; }
        void init(int N) { n = N; memset( bit , 0 , sizeof(int)*(N + 1) ); }

    private:
        int n;
        int bit[maxn];
};

int n;

int depth[maxn];
int tab[lg][maxn];
int tin[maxn], tout[maxn];
int lcaA[maxn], lcaB[maxn];

DiamInfo diam[maxn];

vector<int> adj[maxn];

FenwickTree BIT;

// Euler tour of T_N. Recursive, so a very deep tree may need a larger stack.
void dfs(int node, int d, int& curTime)
{
    depth[node] = d;
    tin[node] = ++curTime;

    for(int neighbor: adj[node])
        dfs( neighbor , d + 1 , curTime );

    tout[node] = curTime;
}

// Moves `up` levels towards the root using binary lifting.
int upPath(int node, int up)
{
    for(int k = 0 ; k < lg ; k++)
        if( up & (1 << k) ) node = tab[k][node];

    return node;
}

int lca(int u, int v)
{
    // in(u, v): does the subtree of u contain v?
    auto in = [&](int u, int v) -> bool {
        return tin[u] <= tin[v] && tout[v] <= tout[u];
    };

    for(int k = lg - 1 ; k >= 0 ; k--)
        if( !in( tab[k][u] , v ) ) u = tab[k][u];

    return in(u, v) ? u : tab[0][u];
}

int dist(int u, int v, int l) { return depth[u] + depth[v] - 2*depth[l]; }

// diam[i] describes a diameter of T_i; an endpoint moves to i only on a strict increase.
void calculateDiameters()
{
    diam[1].size = 0;
    diam[1].a = diam[1].b = diam[1].lca = 1;

    for(int i = 2 ; i <= n ; i++)
    {
        diam[i] = diam[i - 1];

        lcaA[i] = lca( i , diam[i].a );
        lcaB[i] = lca( i , diam[i].b );

        if( diam[i].size < dist( i , diam[i].b , lcaB[i] ) )
        {
            diam[i].lca = lcaB[i];
            diam[i].a = lcaA[i] = i;

            diam[i].size = dist( i , diam[i].b , lcaB[i] );
        }

        if( diam[i].size < dist( diam[i].a , i , lcaA[i] ) )
        {
            diam[i].lca = lcaA[i];
            diam[i].b = lcaB[i] = i;

            diam[i].size = dist( diam[i].a , i , lcaA[i] );
        }
    }
}

// Called when `node` has just become endpoint b of a longer diameter (a, b).
// Returns how many vertices of S_2 (at least as close to a as to node's parent)
// are not on the diameter, i.e. the extra +1 contributions.
ll updateDiameter(int node)
{
    int a = diam[node].a, b = diam[node].b, l = diam[node].lca;
    int da = depth[a] - depth[l], db = depth[b] - depth[l], up = (diam[node].size - 1)/2;

    if( up < da )
    {
        int root = upPath( a , up );
        return BIT.query( tin[root] , tout[root] ) - (up + 1);
    }

    int upB = diam[node].size - up - 1;
    int root = upPath( b , upB );

    return node - BIT.query( tin[root] , tout[root] ) - (up + 1);
}

void init(int n)
{
    BIT.init( n );

    for(int i = 1 ; i <= n ; i++)
        adj[i].clear();
}

void solve()
{
    cin >> n;
    init( n );

    tab[0][1] = 1;

    for(int i = 2 ; i <= n ; i++)
        cin >> tab[0][i];

    for(int k = 1 ; k < lg ; k++)
        for(int i = 1 ; i <= n ; i++)
            tab[k][i] = tab[k - 1][ tab[k - 1][i] ];

    for(int i = 2 ; i <= n ; i++)
        adj[ tab[0][i] ].push_back( i );

    int curTime = 0;
    dfs( 1 , 0 , curTime );

    calculateDiameters();

    ll curAns = 0;
    BIT.update( tin[1] );

    cout << 0 << " ";

    for(int i = 2 ; i <= n ; i++)
    {
        // Vertex i first contributes its farthest distance.
        BIT.update( tin[i] );
        curAns += max( dist( diam[i].a , i , lcaA[i] ) , dist( i , diam[i].b , lcaB[i] ) );

        if( diam[i].a == i )
            swap( diam[i].a , diam[i].b );

        // If i extended the diameter, add the +1 gains of S_2.
        if( diam[i].b == i )
            curAns += updateDiameter( i );

        cout << curAns << " ";
    }

    cout << endl;
}

int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    int t;
    cin >> t;

    while( t-- )
        solve();

    return 0;
}

Tester (errorgorn)'s code
#include <bits/stdc++.h>
using namespace std;

#define int long long
#define ll long long
#define ii pair<int,int>
#define iii tuple<int,int,int>
#define fi first
#define se second
#define endl '\n'
#define debug(x) cout << #x << ": " << x << endl

#define pub push_back
#define pob pop_back
#define puf push_front
#define pof pop_front
#define lb lower_bound
#define ub upper_bound

#define rep(x,start,end) for(int x=(start)-((start)>(end));x!=(end)-((start)>(end));((start)<(end)?x++:x--))
#define all(x) (x).begin(),(x).end()
#define sz(x) (int)(x).size()

mt19937 rng(chrono::system_clock::now().time_since_epoch().count());

struct node{
	ii fen[1000005];
	void reset(int n){ rep(x,1,n+1) fen[x]={0,0};}
	ii add(ii i,ii j){ return {i.fi+j.fi,i.se+j.se}; }
	ii sub(ii i,ii j){ return {i.fi-j.fi,i.se-j.se}; }
	void update(int i,ii k){
		while (i<1000005){
			fen[i]=add(fen[i],k);
			i+=i&-i;
		}
	}
	ii query(int i){
		ii res={0,0};
		while (i){
			res=add(res,fen[i]);
			i-=i&-i;
		}
		return res;
	}
	ii query(int i,int j){
		return sub(query(j),query(i-1));
	}
} root;

int n;
vector<int> al[1000005];

int d[1000005];
int tkd[1000005][22];
int in[1000005];
int out[1000005];
int _TIME;

void dfs(int i){
	for (auto it:al[i]){
		int curr=tkd[it][0]=i;
			if (curr==-1) break;

int mov(int i,int j){
	rep(x,0,22) if (j&(1<<x)) i=tkd[i][x];
	return i;
}

int lca(int i,int j){
	if (d[i]<d[j]) swap(i,j);
	i=mov(i,d[i]-d[j]);
	if (i==j) return i;
	rep(x,22,0) if (tkd[i][x]!=tkd[j][x]) i=tkd[i][x],j=tkd[j][x];
	return tkd[i][0];
}

int dist(int i,int j){
	return d[i]+d[j]-2*d[lca(i,j)];
}

int onpath(int i,int j,int k){
	int g=lca(i,j);
	if (d[i]-k>=d[g]) return mov(i,k);
	else return mov(j,dist(i,j)-k);
}
signed main(){
	cin.exceptions(ios::badbit | ios::failbit);
	int TC;
	while (TC--){
		rep(x,1,n+1) al[x].clear();
			rep(y,0,22) tkd[x][y]=-1;
			int a; cin>>a;
		// rep(x,1,n+1) cout<<d[x]<<" "; cout<<endl;
		// rep(x,1,n+1) cout<<tkd[x][0]<<" "; cout<<endl;
		// rep(x,1,n+1) cout<<tkd[x][1]<<" "; cout<<endl;
		// rep(x,1,n+1) cout<<in[x]<<" "; cout<<endl;
		// rep(x,1,n+1) cout<<out[x]<<" "; cout<<endl;
		cout<<0<<" "<<1<<" ";
		int a=1,b=2,D=1; //diam
		vector<int> id={1,2};
		vector<int> cnt0={1,1};
		vector<int> cnt1={0,0};
			int mn=1e9;
			for (auto it:id) mn=min(mn,dist(u,it));
			rep(x,0,sz(id)) if (mn==dist(u,id[x])){
			int c=-1;
			if (dist(a,u)>D) c=a;
			if (dist(b,u)>D) c=b;
			if (c!=-1){
				if (D%2==0){
					if (id[0]!=onpath(a,b,D/2-1)){
					if (tkd[id[2]][0]==id[1]){ //this guy is below
						auto val=root.query(in[id[2]],out[id[2]]);
					else{ //this guy is ontop (stupid case)
						auto val0=root.query(in[id[0]],out[id[0]]);
						auto val1=root.query(in[id[1]],out[id[1]]);
				else{ //there are only 2 guys
					if (id[0]==onpath(a,b,D/2)),b,D/2+1));
					if (tkd[id[1]][0]==id[0]){
						auto val=root.query(in[id[1]],out[id[1]]);
						auto val=root.query(in[id[0]],out[id[0]]);
			int tot;
			if (D%2==0) tot=cnt1[0]+cnt1[1]+cnt1[2]+u*(D/2+1)-cnt0[1];
			else tot=cnt1[0]+cnt1[1]+u*(D/2+1);
			cout<<tot-((D+1)/2)*(D/2+1)<<" ";