STSMDM - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: mrmadness
Tester: raysh07
Editorialist: iceknight1093

DIFFICULTY:

Medium

PREREQUISITES:

Data structures, Euler tour of a tree

PROBLEM:

For a tree on N vertices, let \text{subsize}_r(x) denote the subtree size of x when the tree is rooted at r, and define f(u, v) to be the number of roots r for which \text{subsize}_r(u) \gt \text{subsize}_r(v).
For each u, compute \sum_v f(u, v).
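Before optimizing anything, it helps to have a brute-force reference for the definition. The sketch below is not from the editorial: it assumes a hypothetical input format where `par[v]` is the parent of v (for v ≥ 2) in the tree rooted at 1, and `bruteForce` / `sizesRootedAt` are illustrative names. It simply re-roots at every r, so it runs in O(N^3) and is only meant for checking faster solutions on small trees.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Subtree sizes of every vertex when the tree is rooted at r (iterative DFS).
vector<long long> sizesRootedAt(const vector<vector<int>>& adj, int r) {
    int n = (int)adj.size() - 1;
    vector<long long> sz(n + 1, 1);
    vector<int> parent(n + 1, 0), order, stk = {r};   // 0 means "no parent"
    while (!stk.empty()) {
        int x = stk.back();
        stk.pop_back();
        order.push_back(x);
        for (int y : adj[x])
            if (y != parent[x]) { parent[y] = x; stk.push_back(y); }
    }
    // every vertex appears after its parent in `order`, so accumulate backwards
    for (int i = n - 1; i >= 1; i--) sz[parent[order[i]]] += sz[order[i]];
    return sz;
}

// Hypothetical reference solver: ans[u] = sum over v of f(u, v), by definition.
// par[v] is the parent of v (v >= 2) when the tree is rooted at 1.
vector<long long> bruteForce(int n, const vector<int>& par) {
    assert((int)par.size() == n + 1);
    vector<vector<int>> adj(n + 1);
    for (int v = 2; v <= n; v++) {
        adj[v].push_back(par[v]);
        adj[par[v]].push_back(v);
    }
    vector<long long> ans(n + 1, 0);
    for (int r = 1; r <= n; r++) {                    // try every root
        vector<long long> sz = sizesRootedAt(adj, r);
        for (int u = 1; u <= n; u++)
            for (int v = 1; v <= n; v++)
                if (sz[u] > sz[v]) ans[u]++;          // r counts toward f(u, v)
    }
    return ans;
}
```

For example, on the tree with edges 1-2, 2-3, 3-4, 3-5, 1-6 (so `par = {0, 0, 1, 2, 3, 3, 1}`), it returns 19, 22, 22, 5, 5, 5 for vertices 1 through 6.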

EXPLANATION:

Let’s first compute only the values f(1, v).

To do this, root the tree at vertex 1 and compute all the subtree sizes; let s_v be the subtree size of v when the root is 1.

Now, let’s look at some vertex v \ne 1 and figure out what f(1, v) is.
Let c be the first vertex (other than 1) on the 1 \to v path (equivalently, c is the child of 1 whose subtree contains v.)

Observe that:

  • For any root r outside the subtree of c, we’ll surely have \text{subsize}_r(1) \gt \text{subsize}_r(v), because 1 will be an ancestor of v.
    There are N - s_c such roots.
  • For any root r inside the subtree of v, we’ll surely have \text{subsize}_r(1) \lt \text{subsize}_r(v) instead, because 1 will be a descendant of v.
    So, all such r can be ignored.
  • That leaves choices of r that lie within the subtree of c but not within the subtree of v.
    For every such r, the subtree sizes of 1 and v will be N - s_c and s_v respectively: the subtree of 1 is everything outside the subtree of c, while the subtree of v is unchanged because r does not lie below v.
    So, either all of these r are valid roots or none of them are, depending only on whether N - s_c \gt s_v.
    There are s_c - s_v possible roots here.

Putting this together: for any vertex v in the subtree of c, f(1, v) has a “base” value of N - s_c, plus an additional s_c - s_v if N - s_c \gt s_v.
Equivalently, f(1, v) = N - s_v when s_v \lt N - s_c, and f(1, v) = N - s_c otherwise.
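As a quick check of this case analysis, the resulting formula (base N - s_c, plus s_c - s_v when N - s_c > s_v, which sums to N - s_v in that case) can be evaluated directly. This is a minimal sketch under hypothetical assumptions, not from the editorial: the tree is given by `par[v]`, the parent of v when rooted at 1, with vertices numbered so that `par[v] < v`, which lets subtree sizes be accumulated in one backward pass; `f1ClosedForm` is an illustrative name.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// f(1, v) for v != 1: with c the child of 1 above v,
// f(1, v) = N - s_v if s_v < N - s_c, and N - s_c otherwise.
// Hypothetical input: par[v] is v's parent (root 1) and par[v] < v.
long long f1ClosedForm(int n, const vector<int>& par, int v) {
    assert(v >= 2 && v <= n);
    vector<long long> s(n + 1, 1);
    for (int w = n; w >= 2; w--) {
        assert(par[w] >= 1 && par[w] < w);
        s[par[w]] += s[w];                            // subtree sizes rooted at 1
    }
    int c = v;
    while (par[c] != 1) c = par[c];                   // child of 1 on the 1 -> v path
    return s[v] < n - s[c] ? n - s[v] : n - s[c];     // N - s_v, or just the base N - s_c
}
```

On the tree with edges 1-2, 2-3, 3-4, 3-5, 1-6 (`par = {0, 0, 1, 2, 3, 3, 1}`), this gives f(1, v) = 2, 2, 5, 5, 5 for v = 2, ..., 6, which matches counting over all six roots directly.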


Now, consider an arbitrary vertex u.
Let c be a child of u (we’re still treating the tree as being rooted at 1.)

The same argument, with u in place of 1, shows that for any v in the subtree of c, f(u, v) can be computed as follows:

  • There’s a base value of N - s_c.
  • Further, if N - s_c \gt s_v, there’s an additional value of s_c - s_v.

So, summing up across all such v:

  • There are s_c choices of v, each with a base value of N - s_c.
    We can thus add (N - s_c) \cdot s_c to the answer of u.
  • Let there be x vertices in this subtree whose subtree size is \lt N - s_c, and let the sum of their subtree sizes be y.
    We then need to add x\cdot s_c - y to the answer of u.
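A small worked example of these two sums, on a tree chosen purely for illustration (it is not from the problem statement): edges 1-2, 2-3, 3-4, 3-5, 1-6, rooted at 1, with u = 1.

```latex
\begin{aligned}
&N = 6,\qquad (s_1, s_2, s_3, s_4, s_5, s_6) = (6, 4, 3, 1, 1, 1).\\
&c = 2:\ N - s_c = 2,\ \{v \in \text{subtree}(2) : s_v < 2\} = \{4, 5\} \Rightarrow x = 2,\ y = 2,\\
&\qquad (N - s_c)\, s_c + x\, s_c - y = 2 \cdot 4 + 2 \cdot 4 - 2 = 14.\\
&c = 6:\ N - s_c = 5,\ \{v \in \text{subtree}(6) : s_v < 5\} = \{6\} \Rightarrow x = 1,\ y = 1,\\
&\qquad (N - s_c)\, s_c + x\, s_c - y = 5 \cdot 1 + 1 \cdot 1 - 1 = 5.\\
&\textstyle\sum_v f(1, v) = 14 + 5 = 19.
\end{aligned}
```

This agrees with counting f(1, v) = 2, 2, 5, 5, 5 for v = 2, ..., 6 directly over all roots.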

To find x, we’re essentially asking “how many vertices in this subtree have size smaller than a given value?”
Flattening the tree with an Euler tour places every subtree in a contiguous range of the resulting array, which converts this into a subarray query; queries of the form “how many values in this subarray are smaller than a given value” are very classical (see SPOJ KQUERY for an old instance).

Perhaps the simplest way to solve this is offline: build a segment tree or Fenwick tree over the array positions, insert the values into it in ascending order, and answer each query as soon as all values strictly smaller than its threshold have been inserted.
This works in \mathcal{O}(N\log N) time, with a small constant factor.
There are, of course, many other solutions, with varying complexity and constant factor; most reasonable ones should work.

Finding the value of y can be done similarly, using a second Fenwick tree that stores the sizes themselves instead of 1s.
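The offline approach for x and y might look as follows. This is a sketch under stated assumptions, not the tester's code: the tree is given as a hypothetical parent array `par` rooted at 1 with `par[v] < v`, so preorder (Euler tour) positions can be assigned from subtree sizes without an explicit DFS, and `insidePart` is an illustrative name. Vertices are inserted into two position-indexed Fenwick trees (one storing 1s, one storing sizes) in ascending order of size, and the query for child c runs once every size below its threshold N - s_c is in.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>
using namespace std;

struct Fenwick {                                      // point add, prefix sums, 1-indexed
    vector<long long> t;
    explicit Fenwick(int n) : t(n + 1, 0) {}
    void add(int i, long long x) { for (; i < (int)t.size(); i += i & -i) t[i] += x; }
    long long pre(int i) const { long long r = 0; for (; i > 0; i -= i & -i) r += t[i]; return r; }
    long long range(int l, int r) const { return l > r ? 0 : pre(r) - pre(l - 1); }
};

// Hypothetical helper: sum of f(u, v) over v inside the subtree of u, for every u.
// Assumes par[v] is v's parent when rooted at 1 and par[v] < v.
vector<long long> insidePart(int n, const vector<int>& par) {
    vector<long long> s(n + 1, 1);
    for (int v = n; v >= 2; v--) { assert(par[v] >= 1 && par[v] < v); s[par[v]] += s[v]; }
    // Preorder positions: the subtree of v occupies [tin[v], tin[v] + s[v] - 1].
    vector<int> tin(n + 1), nxt(n + 1);
    tin[1] = 1; nxt[1] = 2;
    for (int v = 2; v <= n; v++) { tin[v] = nxt[par[v]]; nxt[par[v]] += (int)s[v]; nxt[v] = tin[v] + 1; }
    vector<int> byVal(n), byThr(n - 1);
    iota(byVal.begin(), byVal.end(), 1);              // all vertices
    iota(byThr.begin(), byThr.end(), 2);              // one query per child c != 1
    sort(byVal.begin(), byVal.end(), [&](int a, int b) { return s[a] < s[b]; });
    sort(byThr.begin(), byThr.end(), [&](int a, int b) { return s[a] > s[b]; }); // N - s_c ascending
    Fenwick cnt(n), sum(n);
    vector<long long> ans(n + 1, 0);
    size_t k = 0;
    for (int c : byThr) {
        long long thr = n - s[c];
        for (; k < byVal.size() && s[byVal[k]] < thr; k++) {   // insert sizes < N - s_c
            cnt.add(tin[byVal[k]], 1);
            sum.add(tin[byVal[k]], s[byVal[k]]);
        }
        int l = tin[c], r = tin[c] + (int)s[c] - 1;
        long long x = cnt.range(l, r), y = sum.range(l, r);
        ans[par[c]] += thr * s[c] + x * s[c] - y;              // base + extra for child c
    }
    return ans;
}
```

On the tree with edges 1-2, 2-3, 3-4, 3-5, 1-6 it returns 19, 13, 10, 0, 0, 0 for vertices 1 through 6, i.e. the contribution of the vertices v inside each u’s subtree. For input given as an edge list, the positions would instead come from an ordinary DFS.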


Observe that this handles computing \sum f(u, v) for all v that lie in the subtree of u.

That leaves us with vertices not in the subtree of u.
For these, the basic computation is still the same:

  • There are N-s_u vertices outside the subtree of u, and for each of them, all s_u vertices in the subtree of u count as valid roots (with such a root, u is an ancestor of v).
    That gives a base value of s_u \cdot (N - s_u).
  • Then, any v whose relevant subtree size is “small enough” gains additional valid roots lying outside the subtree of u.

The issue here is that for a vertex v outside the subtree of u, its subtree size might no longer be s_v.
Specifically, only the ancestors of u have different subtree sizes: every non-ancestor still has size s_v when the tree is rooted at u.

Let’s analyze non-ancestors and ancestors separately.

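First, consider a non-ancestor v.
Here, the condition that needs to be satisfied is simply s_u \gt s_v.
If it is satisfied, every root outside both subtrees is valid, so we need to add N - s_u - s_v to the answer (a root inside the subtree of v makes v an ancestor of u, so it never works).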

Importantly, s_u \gt s_v can never be satisfied when v is an ancestor of u anyway (the subtree of an ancestor strictly contains the subtree of u), so we can simply pretend we’re querying the entire outside of u’s subtree!
So, what we want now is, outside a certain subtree:

  • The number of subtree sizes that are \lt s_u, and
  • The sum of subtree sizes of vertices satisfying the above.

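Each “outside subtree” query becomes a prefix plus a suffix of the Euler tour array, so this part can again be handled offline with a Fenwick tree.
If there are x' such vertices and their sizes sum to y', we add x' \cdot (N - s_u) - y' to the answer of u, on top of the base value.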
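A sketch of this outside-subtree part, under hypothetical assumptions that are not from the editorial: the tree is a parent array `par` rooted at 1 with `par[v] < v` (so preorder positions follow from subtree sizes), and `outsidePart` is an illustrative name. It uses the same ascending-size sweep with threshold s_u, reads the prefix before u’s subtree range plus the suffix after it, and also adds the base value s_u (N - s_u).

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>
using namespace std;

struct Fenwick {                                      // point add, prefix sums, 1-indexed
    vector<long long> t;
    explicit Fenwick(int n) : t(n + 1, 0) {}
    void add(int i, long long x) { for (; i < (int)t.size(); i += i & -i) t[i] += x; }
    long long pre(int i) const { long long r = 0; for (; i > 0; i -= i & -i) r += t[i]; return r; }
    long long range(int l, int r) const { return l > r ? 0 : pre(r) - pre(l - 1); }
};

// Hypothetical helper: base s_u (N - s_u) plus the extra from v outside the subtree
// of u with s_v < s_u, for every u. Assumes par[v] is v's parent (root 1), par[v] < v.
vector<long long> outsidePart(int n, const vector<int>& par) {
    vector<long long> s(n + 1, 1);
    for (int v = n; v >= 2; v--) { assert(par[v] >= 1 && par[v] < v); s[par[v]] += s[v]; }
    vector<int> tin(n + 1), nxt(n + 1);               // preorder positions
    tin[1] = 1; nxt[1] = 2;
    for (int v = 2; v <= n; v++) { tin[v] = nxt[par[v]]; nxt[par[v]] += (int)s[v]; nxt[v] = tin[v] + 1; }
    vector<int> byVal(n), byThr(n);
    iota(byVal.begin(), byVal.end(), 1);
    iota(byThr.begin(), byThr.end(), 1);
    auto bySize = [&](int a, int b) { return s[a] < s[b]; };
    sort(byVal.begin(), byVal.end(), bySize);         // insertion order: ascending size
    sort(byThr.begin(), byThr.end(), bySize);         // query order: ascending threshold s_u
    Fenwick cnt(n), sum(n);
    vector<long long> ans(n + 1, 0);
    size_t k = 0;
    for (int u : byThr) {
        for (; k < byVal.size() && s[byVal[k]] < s[u]; k++) {
            cnt.add(tin[byVal[k]], 1);
            sum.add(tin[byVal[k]], s[byVal[k]]);
        }
        int l = tin[u], r = tin[u] + (int)s[u] - 1;   // subtree of u = [l, r]
        long long x = cnt.range(1, l - 1) + cnt.range(r + 1, n);   // prefix + suffix
        long long y = sum.range(1, l - 1) + sum.range(r + 1, n);
        ans[u] = s[u] * (n - s[u]) + x * (n - s[u]) - y;
    }
    return ans;
}
```

On the tree with edges 1-2, 2-3, 3-4, 3-5, 1-6 it returns 0, 9, 11, 5, 5, 5 for vertices 1 through 6; ancestors never pass the s_v < s_u test, so they contribute nothing here.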

Next, we look at the ancestors of u.
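Let v be an ancestor of u, and let w be the child of v whose subtree contains u.
Then, for additional roots to be valid, the necessary and sufficient condition is s_u \gt N - s_w, and there are s_w - s_u such roots: for any root in the subtree of w but outside the subtree of u, the subtree size of v is N - s_w, while any root outside the subtree of w makes v an ancestor of u, so it never works.

This information can also be computed in \mathcal{O}(N\log N) time for all u simultaneously: as we DFS through the tree, maintain a couple of Fenwick trees over the current ancestors, keyed by the value N - s_w and storing both a count and a sum, and query them for keys smaller than s_u.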
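One way to realize this, sketched under hypothetical assumptions (a parent array `par` rooted at 1 with `par[v] < v`; `ancestorPart` is an illustrative name, not from the editorial): while the DFS is at u, every proper ancestor v of u, with w the child of v on the path down to u, has the key N - s_w stored in two value-indexed Fenwick trees, one counting keys and one summing them. Each qualifying ancestor adds s_w - s_u = (N - key) - s_u, so the total is the count of keys below s_u times (N - s_u), minus the sum of those keys.

```cpp
#include <cassert>
#include <functional>
#include <vector>
using namespace std;

struct Fenwick {                                      // point add, prefix sums, 1-indexed
    vector<long long> t;
    explicit Fenwick(int n) : t(n + 1, 0) {}
    void add(int i, long long x) { for (; i < (int)t.size(); i += i & -i) t[i] += x; }
    long long pre(int i) const { long long r = 0; for (; i > 0; i -= i & -i) r += t[i]; return r; }
};

// Hypothetical helper: the extra value coming from ancestors of u, for every u.
// Assumes par[v] is v's parent when rooted at 1 and par[v] < v.
vector<long long> ancestorPart(int n, const vector<int>& par) {
    vector<long long> s(n + 1, 1);
    vector<vector<int>> kids(n + 1);
    for (int v = n; v >= 2; v--) {
        assert(par[v] >= 1 && par[v] < v);
        s[par[v]] += s[v];
        kids[par[v]].push_back(v);
    }
    Fenwick cnt(n), sum(n);                           // keys N - s_w of current ancestors
    vector<long long> ans(n + 1, 0);
    function<void(int)> dfs = [&](int u) {
        long long x = cnt.pre((int)s[u] - 1), y = sum.pre((int)s[u] - 1);  // keys < s_u
        ans[u] = x * (n - s[u]) - y;                  // sum of (N - key) - s_u
        for (int w : kids[u]) {
            int key = n - (int)s[w];                  // size of u when rooted inside w
            cnt.add(key, 1); sum.add(key, key);
            dfs(w);
            cnt.add(key, -1); sum.add(key, -key);
        }
    };
    dfs(1);
    return ans;
}
```

On the tree with edges 1-2, 2-3, 3-4, 3-5, 1-6, only u = 3 gets a nonzero value (1, from ancestor 1 with w = 2). The DFS is recursive for brevity; very deep trees may need an iterative version to avoid stack overflow.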

This takes care of all computations for u, so we’re done!

TIME COMPLEXITY:

\mathcal{O}(N \log N) per testcase.

CODE:

Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

struct Tree {
    vector<vector<int>> adj, lift;    
    vector<int> d, tin, tout, par, sub, head, ord;
    int n, timer;
    bool initialized = false;
    bool dfsed = false;
 
    void init(int nn){
        n = nn;
        adj.resize(n + 1);
        d.resize(n + 1);
        lift.resize(n + 1);
        tin.resize(n + 1);
        tout.resize(n + 1);
        par.resize(n + 1);
        sub.resize(n + 1);
        head.resize(n + 1);
        for (int i = 1; i <= n; i++) adj[i].clear();
        for (int i = 0; i <= n; i++) lift[i].resize(20, 0);
        initialized = true;
        ord.resize(1);
    }
 
    void addEdge(int u, int v){
        if (!initialized){ cout << "STUPID INITIALIZE\n"; exit(0);}
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
 
    void build(){
        for (int j = 1; j < 20; j++){
            for (int i = 1; i <= n; i++){
                lift[i][j] = lift[lift[i][j - 1]][j - 1];
            }
        }
    }
    
    void dfs1(int u, int par1){
        sub[u] = 1;
        for (int v : adj[u]) if (v != par1){
            dfs1(v, u);
            sub[u] += sub[v];
        }
    }
 
    void dfs(int u, int par1, int h){
        par[u] = par1;
        tin[u] = ++timer;
        ord.push_back(u);
        head[u] = h;
        bool fir = true;
        for (int v : adj[u]){
            if (v != par1){
                d[v] = d[u] + 1;
                lift[v][0] = u;
                int hh;
                if (fir) hh = h;
                else hh = v;
                fir = false;
                dfs(v, u, hh);
            }
        }
        tout[u] = timer;
    }
 
    void dfs(int root = 1){
        if (!initialized){ cout << "STUPID INITIALIZE\n"; exit(0);}
        d[root] = 0;
        timer = 0;
        dfs1(root, 0);
        for (int i = 1; i <= n; i++){
            sort(adj[i].begin(), adj[i].end(), [&](int x, int y){
                return sub[x] > sub[y]; 
            });
        }
        dfs(root, 0, root);
        build();
        dfsed = true;
    }
 
    int jump(int x, int depth){
        for (int i = 0; i < 20; i++) if (depth >> i & 1){
            x = lift[x][i];
        }
        return x;
    }
 
    int lca(int a, int b){
        if (!dfsed){ cout << "STUPID DFS\n"; exit(0);}
        if (d[a] < d[b]) swap(a, b);
        int del = d[a] - d[b];
        for (int i = 0; i < 20; i++) if (del >> i & 1) a = lift[a][i];
 
        if (a == b) return a;
        for (int i = 19; i >= 0; i--) if (lift[a][i] != lift[b][i]){
            a = lift[a][i];
            b = lift[b][i];
        }
        return lift[a][0];
    }
 
    int dist(int a, int b){
        return d[a] + d[b] - 2 * d[lca(a, b)];
    }
 
    bool anc(int x, int y){
        return tin[x] <= tin[y] && tout[x] >= tout[y];
    }
};
 
struct FenwickTree{
    int n;
    vector <int> f;
    vector <int> b;
 
    inline void add(int i, int x){
        b[i] += x;
        for (int j = i; j <= n; j += j & (-j)){
            f[j] += x;
        }
    }
 
    inline void modify(int i, int x){
        add(i, x - b[i]);
    }
 
    inline void init(int nn, vector <int> a){
        n = nn;
        if (a.size() == n){
            vector <int> a2;
            a2.push_back(0);
            for (auto x : a) a2.push_back(x);
            a = a2;
        }
 
        f.resize(n + 1);
        b.resize(n + 1);
 
        for (int i = 0; i <= n; i++) f[i] = 0, b[i] = 0;
 
        for (int i = 1; i <= n; i++){
            modify(i, a[i]);
        }
    }
 
    inline int query(int x){
        int ans = 0;
        for (int i = x; i; i -= i & (-i)){
            ans += f[i];
        }
        return ans;
    }
 
    inline int query(int l, int r){
        return query(r) - query(l - 1);
    }
};

void Solve() 
{
    int n; cin >> n;
    
    Tree T;
    T.init(n);
    
    for (int i = 1; i < n; i++){
        int u, v; cin >> u >> v;
        
        T.addEdge(u, v);
    }
    
    T.dfs();
    
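    // Each v contributes f(u, v) = (t < x) ? n - t : x, where t is v's subtree size
    // when rooted at u, and the threshold x is
    //   n - sub[c]  if v is in the subtree of a child c of u,
    //   sub[u]      if v is outside the subtree of u.
    // Inside parts: offline range queries over the Euler tour (first loop).
    // Outside part: a whole-tree query for every u during a DFS, with the part
    // inside u's own subtree subtracted (second pair of entries pushed below).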
    
    vector <int> ans(n + 1, 0);
    
    vector<vector<array<int, 3>>> at(n + 1);
    
    for (int i = 2; i <= n; i++){
        // this is the child 
        // query range : tin[i], tout[i] 
        // query value : n - sub[i] 
        at[T.tin[i] - 1].push_back({n - T.sub[i], -1, T.par[i]});
        at[T.tout[i]].push_back({n - T.sub[i], +1, T.par[i]});
        
        at[T.tin[i] - 1].push_back({T.sub[i], +1, i});
        at[T.tout[i]].push_back({T.sub[i], -1, i});
    }
    
    FenwickTree sum, cnt;
    vector <int> vv(n);
    sum.init(n, vv);
    cnt.init(n, vv);
    
    auto upd = [&](int x, int v){
        sum.add(x, v * x);
        cnt.add(x, v);
    };
    
    for (int i = 1; i <= n; i++){
        int v = T.ord[i];
        upd(T.sub[v], 1);
        
        for (auto [x, y, u] : at[i]){
            // each stored size t contributes (t < x) ? n - t : x
            int count_ge = cnt.query(x, n);      // sizes >= x
            int count_le = cnt.query(1, x - 1);  // sizes <  x (strictly)
            int sum_le = sum.query(1, x - 1);    // total of sizes < x
            
            ans[u] += y * (count_le * n - sum_le + count_ge * x);
        }
    }
    
    auto dfs = [&](auto self, int u, int par) -> void{
        if (u > 1){
            int x = T.sub[u];
            int count_ge = cnt.query(x, n);
            int count_le = cnt.query(1, x - 1);
            int sum_le = sum.query(1, x - 1);
            
            ans[u] += count_le * n - sum_le + count_ge * x;
        }
        
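        for (int v : T.adj[u]) if (v != par){
            // entering v: u's stored size changes from sub[u] to n - sub[v],
            // its subtree size when the root lies inside v's subtree
            upd(T.sub[u], -1);
            upd(n - T.sub[v], +1);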
            
            self(self, v, u);
            
            upd(T.sub[u], +1);
            upd(n - T.sub[v], -1);
        }
    };
    
    dfs(dfs, 1, -1);
    
    for (int i = 1; i <= n; i++){
        cout << ans[i] << " \n"[i == n];
    }
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}