SCP - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: kingmessi
Tester: kingmessi
Editorialist: iceknight1093

DIFFICULTY:

Medium

PREREQUISITES:

Small-to-large merging

PROBLEM:

For a tree with colors and values on its nodes, we define the following:

  • To compute the score of a leaf:
    • For each color c, compute the sum of values of nodes with color c in each of the two components obtained by removing the edge connecting the leaf to its unique neighbor
    • Then, add the absolute difference of these two sums to the score.
  • The score of a tree is the maximum score across any of its leaves.
    If the tree has no leaves, then its score is the value of its single vertex.

You’re given a tree. Compute the score of each of its subtrees.

EXPLANATION:

Let’s analyze the problem for a single tree first.

Define S_c to be the sum of values of all vertices with color c.
Let S = \sum_c S_c be the sum of all values of vertices in the tree.

If we choose a leaf v with value x and color c, observe that only the contribution of color c to the total score changes: the resulting score is

S - S_c + |x - (S_c - x)| = S - S_c + |2x - S_c|
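For readers who want the intermediate steps, here is one way to derive this expression. It is a sketch that assumes all values are non-negative (so that a color's sum equals its own absolute value); the names A_{c'} and B_{c'} are introduced only for this derivation.

```latex
% Assumption: all values are non-negative.
% A_{c'}: sum of color c' on the leaf side of the removed edge.
% B_{c'}: sum of color c' in the rest of the tree.
\begin{align*}
\text{score}(v)
  &= \sum_{c'} \left| A_{c'} - B_{c'} \right| \\
  &= \sum_{c' \neq c} \left| 0 - S_{c'} \right| + \left| x - (S_c - x) \right|
     && \text{(the leaf side contains only } v \text{)} \\
  &= (S - S_c) + \left| 2x - S_c \right|
\end{align*}
```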

In particular, this is either S + 2x - 2S_c or S - 2x depending on whether 2x - S_c \geq 0 or not.
We want to maximize this, so let’s look at both cases.

  1. 2x - S_c \geq 0.
    Here, the final score is S + 2x - 2S_c.
    If c is fixed, clearly it’s best to make x as large as possible.
    Since making x larger maintains the inequality 2x - S_c \geq 0, it’s enough to store the largest leaf value corresponding to this color.
  2. 2x - S_c \leq 0.
    Here, the score resolves to S - 2x.
    To maximize this, of course x should be as small as possible.
    Again, making x smaller preserves the inequality 2x - S_c \leq 0, so it’s best to take the minimum.

So, suppose we define mn_c and mx_c to be the minimum and maximum values among the leaves with color c.
The answer is then the maximum value of

\max(S - 2mn_c, S - 2S_c + 2mx_c)

across all colors c.

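Note that we didn’t check whether 2mn_c \leq S_c or 2mx_c \geq S_c, which were technically required to use those expressions. It doesn’t matter: since |t| \geq t and |t| \geq -t, neither expression can exceed the true score S - S_c + |2x - S_c| of the leaf it is evaluated at, so an invalid expression never overshoots the answer. At least one of the two is always valid, and if one is invalid you can verify that the other expression gives a larger value anyway.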
This allows us to check both expressions, which is quite nice: in particular, we can just maximize each of them across all c, and then take the maximum of both values obtained.
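The single-tree formula can be sanity-checked with a short sketch. The function names, and the representation of leaves as (value, color) pairs, are assumptions made for illustration; they do not come from the reference solutions. The first function tries every leaf directly, the second uses only mn_c and mx_c.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <cstdlib>
#include <map>
#include <utility>
#include <vector>

using ll = long long;

// Direct evaluation: score of every leaf, given as (value, color) pairs.
ll bruteForceScore(ll S, const std::map<int, ll>& colorSum,
                   const std::vector<std::pair<ll, int>>& leaves) {
    assert(!leaves.empty());
    ll best = LLONG_MIN;
    for (const auto& [x, c] : leaves) {
        ll Sc = colorSum.at(c);
        best = std::max(best, S - Sc + std::llabs(2 * x - Sc));
    }
    return best;
}

// Same answer using only the smallest and largest leaf value of each color.
ll formulaScore(ll S, const std::map<int, ll>& colorSum,
                const std::vector<std::pair<ll, int>>& leaves) {
    assert(!leaves.empty());
    std::map<int, std::pair<ll, ll>> minMax;  // color -> (mn_c, mx_c)
    for (const auto& [x, c] : leaves) {
        auto [it, inserted] = minMax.try_emplace(c, x, x);
        if (!inserted) {
            it->second.first = std::min(it->second.first, x);
            it->second.second = std::max(it->second.second, x);
        }
    }
    ll best = LLONG_MIN;
    for (const auto& [c, mm] : minMax) {
        ll Sc = colorSum.at(c);
        // Both expressions are evaluated without checking their validity.
        best = std::max({best, S - 2 * mm.first, S - 2 * Sc + 2 * mm.second});
    }
    return best;
}
```

For example, on a path whose vertices have values 3, 5, 4 and colors 1, 1, 2 (so S = 12, S_1 = 8, S_2 = 4, and the two endpoints are the leaves), both functions return 12.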


Now, let’s extend this to every subtree.
As seen above, for every subtree we need to know a few pieces of information:

  1. The sum of all values (S).
  2. The sum of values for each color (S_c)
  3. The minimum/maximum leaf values corresponding to each color (mn_c and mx_c).

We also need a way to put these together quickly enough.
Since S is the same for every color, maximizing the two expressions only requires:

  1. The minimum of mn_c across all c.
  2. The maximum of mx_c - S_c across all c.

The first one is trivial: it’s simply the minimal leaf value present in the subtree.
The second one is not as obvious, so we’ll get back to it in a bit.

Let’s go back to computing the information we need.
The sum of all values in each subtree is trivially found using a DFS.
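As a small illustration, subtree sums can be computed as below. This sketch uses an iterative traversal order instead of a recursive DFS (a substitution made here to avoid deep recursion on long paths); the function name and the 0-indexed adjacency-list input are assumptions.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using ll = long long;

// Returns sub[u] = sum of val over the subtree of u, for a tree rooted at `root`.
std::vector<ll> subtreeSums(const std::vector<std::vector<int>>& adj,
                            const std::vector<ll>& val, int root) {
    assert(adj.size() == val.size());
    assert(root >= 0 && root < static_cast<int>(adj.size()));
    std::vector<ll> sub(val);                 // start with each vertex's own value
    std::vector<int> parent(adj.size(), -1);
    std::vector<int> order{root};             // vertices in BFS order from the root
    for (std::size_t i = 0; i < order.size(); ++i) {
        int u = order[i];
        for (int v : adj[u]) {
            if (v != parent[u]) {
                parent[v] = u;
                order.push_back(v);
            }
        }
    }
    // Children appear after their parents, so a reverse sweep pushes sums upward.
    for (std::size_t i = order.size() - 1; i > 0; --i) {
        sub[parent[order[i]]] += sub[order[i]];
    }
    return sub;
}
```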

Next, consider the sum of values of each color.
Let S[u][c] denote the sum of values of all vertices with color c, in the subtree of u.
Computing this for all u, c is a classical application of small-to-large merging.

That is, start with S[u][C_u] = V_u. Then, for each child of u, compute the sum for it recursively; then merge the two tables together.
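The merge is commutative, so we can always merge the smaller table into the larger one. Each entry then moves at most \log N times (the table containing it at least doubles in size on every move), so the total number of entry insertions is bounded by \mathcal{O}(N\log N) across all vertices.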
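A minimal sketch of one such merge step for the color sums is shown below. The function name is hypothetical, and the tables are assumed to be stored as std::map from color to sum.

```cpp
#include <cassert>
#include <map>
#include <utility>

using ll = long long;

// Merges the color-sum table of a child into the table of its parent.
// Afterwards `parent` holds the combined sums; `child` holds leftovers
// and should not be used again.
void mergeColorSums(std::map<int, ll>& parent, std::map<int, ll>& child) {
    assert(&parent != &child);  // merging a table into itself is not meaningful
    // Swapping two maps is O(1), so the larger table always survives.
    if (parent.size() < child.size()) std::swap(parent, child);
    // Only the entries of the smaller table are touched.
    for (const auto& [color, sum] : child) parent[color] += sum;
}
```

Because addition is commutative, it does not matter which of the two tables ends up as the destination after the swap.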

In fact, this works not just for the sums, but also for the maximums.
That is, if mx[u][c] denotes the maximum value of a leaf with color c in the subtree of u, this can also be computed quickly for all u, c using small-to-large merging.

Since we’re using small-to-large for both the maximums and the sums, the remaining value we want, namely the maximum of mx[u][c] - S[u][c] across all c, can also be maintained while doing the merges.
Specifically, store a (multi)set of all the mx[u][c] - S[u][c] values corresponding to u.
When merging a child of u into it, note that only the values corresponding to colors existing in the child subtree can possibly change; so whenever one of them is being processed, just erase its corresponding value from the set, update the value, and insert the new value back in.
In the end, the maximum element of the set is the value we’re looking for.
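The erase, update, reinsert pattern described above might look like the following sketch. The struct and function names are hypothetical; the invariant assumed here is that `diffs` holds exactly one entry mx[u][c] - S[u][c] for every color c that has a leaf in the subtree.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <set>
#include <utility>

using ll = long long;

// Hypothetical per-subtree state.
struct SubtreeState {
    std::map<int, ll> colorSum;  // S[u][c]
    std::map<int, ll> maxLeaf;   // mx[u][c], only for colors that have a leaf
    std::multiset<ll> diffs;     // mx[u][c] - S[u][c], one entry per key of maxLeaf
};

// Merges v into u, keeping `diffs` consistent. v should not be used afterwards.
void mergeInto(SubtreeState& u, SubtreeState& v) {
    assert(&u != &v);
    if (u.colorSum.size() < v.colorSum.size()) std::swap(u, v);
    for (const auto& [c, sum] : v.colorSum) {
        auto it = u.maxLeaf.find(c);
        // Erase the stale difference for this color, if one is stored.
        if (it != u.maxLeaf.end()) {
            u.diffs.erase(u.diffs.find(it->second - u.colorSum[c]));
        }
        u.colorSum[c] += sum;
        // Bring over the child's best leaf of this color, if it has one.
        auto jt = v.maxLeaf.find(c);
        if (jt != v.maxLeaf.end()) {
            if (it == u.maxLeaf.end()) it = u.maxLeaf.emplace(c, jt->second).first;
            else it->second = std::max(it->second, jt->second);
        }
        // Reinsert the updated difference.
        if (it != u.maxLeaf.end()) u.diffs.insert(it->second - u.colorSum[c]);
    }
}
```

After all children of u have been merged, `*u.diffs.rbegin()` is the maximum of mx[u][c] - S[u][c], provided the subtree contains at least one leaf below u.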

Once the merges are done we have all the information we need to answer subtree queries.


Note that there’s one edge case to take care of.
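If a vertex u has exactly one child, then u has degree 1 in its own subtree, so it is a leaf there even though it may not be a leaf of the whole tree. Its value is not part of the leaf data gathered from its descendants, so evaluate S - S_c + |2x - S_c| for u itself (with x = V_u and c = C_u) and include it in the maximum when computing the answer for u.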
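Putting the pieces together, the answer for one vertex could be assembled roughly as follows. This is a sketch with hypothetical parameter names: `minLeaf` is the smallest leaf value among the descendants, `maxDiff` is the largest mx_c - S_c, and `rootColorSum` is S_c for the color of u.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdlib>

using ll = long long;

// Answer for the subtree of a vertex u with `children` children.
ll subtreeAnswer(ll S, ll minLeaf, ll maxDiff, int children,
                 ll rootValue, ll rootColorSum) {
    assert(children >= 0);
    // A single vertex has no leaves; its score is its own value.
    if (children == 0) return rootValue;
    // Best over the leaves below u: S - 2*mn_c and S - 2*S_c + 2*mx_c.
    ll best = std::max(S - 2 * minLeaf, S + 2 * maxDiff);
    // With exactly one child, u is itself a leaf of its subtree.
    if (children == 1) {
        best = std::max(best,
                        S - rootColorSum + std::llabs(2 * rootValue - rootColorSum));
    }
    return best;
}
```

For instance, take a path rooted at its first vertex with values 10, 1, 5 and colors 1, 2, 2. Then S = 16, the only descendant leaf has value 5 with mx_c - S_c = 5 - 6 = -1, and the root contributes 16 - 10 + |20 - 10| = 16, which beats the 14 obtained from the leaf below.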

TIME COMPLEXITY:

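\mathcal{O}(N\log^2 N) per testcase: small-to-large merging moves \mathcal{O}(N\log N) entries in total, and each move costs \mathcal{O}(\log N) in a map or (multi)set.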

CODE:

Author's code (C++)
// Har Har Mahadev
#include<bits/stdc++.h>
#define int long long
using namespace std;

const int N = 200005;
int par[N];
int sz[N];
map<int,int> m[N];      // m[g][col]: minimum leaf value of color col in group g
set<array<int,2>> s[N]; // s[g]: {msc[g][col], col} for each color col that has a leaf in group g
map<int,int> sc[N];     // sc[g][col]: sum of values of color col in group g
map<int,int> msc[N];    // msc[g][col] = abs(2*m[g][col] - sc[g][col]) - sc[g][col]


void make_set(int v) {
    par[v] = v;
    sz[v] = 1;
    m[v].clear();
    s[v].clear();
    sc[v].clear();
    msc[v].clear();
}

int find_set(int v) {
    if (v == par[v])
        return v;
    return par[v] = find_set(par[v]);
}

void union_sets(int a, int b) {
    
    a = find_set(a);
    b = find_set(b);
    if (a != b) {
        if (sz[a] < sz[b])
            swap(a, b);
        par[b] = a;
        sz[a] += sz[b];
        for(auto &[x,y] : sc[b]){
            if(msc[a].find(x) != msc[a].end())s[a].erase({msc[a][x],x});
            if(m[b].find(x) != m[b].end()){
                if(m[a].find(x) == m[a].end())m[a][x] = m[b][x];
                else m[a][x] = min(m[a][x],m[b][x]);
            }
            
            sc[a][x] += sc[b][x];

            if(m[a].find(x) != m[a].end()){
                msc[a][x] = abs(2*m[a][x]-sc[a][x]) - sc[a][x];
                s[a].insert({msc[a][x],x});
            }
            
        }
    }
}

vector<int> adj[N];
int c[N],a[N];
int ans[N],sm[N];
 
void dfs(int cur,int par){
    sm[cur] = a[cur];
    int child = 0;
    for(auto &x : adj[cur]){
        if(x == par)continue;
        child++;
        dfs(x,cur);
        union_sets(x,cur);
        sm[cur] += sm[x];
    }
    auto it = s[find_set(cur)].end();it--;

    ans[cur] = sm[cur] + (*it)[0];
    if(child == 1){
        ans[cur] = max(ans[cur],sm[cur] + abs(2*a[cur]-sc[find_set(cur)][c[cur]]) - sc[find_set(cur)][c[cur]]);
    }
}
 
void solve()
{
    int n;
    cin >> n;
    for(int i = 0;i < n-1;i++){
        int u,v;
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    for(int i = 1;i <= n;i++){
        cin >> c[i];
    }
    for(int i = 1;i <= n;i++){
        cin >> a[i];
    }

    for(int i = 1;i <= n;i++){
        make_set(i);
        sc[i][c[i]] = a[i];
        if(adj[i].size() > 1)continue;
        m[i][c[i]] = a[i];
        s[i].insert({0,c[i]});
        msc[i][c[i]] = 0;
    }

    dfs(1,-1);

    for(int i = 1;i <= n;i++){
        cout << ans[i] << "\n";
    }

    for(int i = 1;i <= n;i++){
        adj[i].clear();
    }

}

signed main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    int t;
    cin >> t;
    while(t--)
        solve();
    return 0;
}
Editorialist's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        vector adj(n, vector<int>());
        for (int i = 0; i < n-1; ++i) {
            int u, v; cin >> u >> v;
            adj[--u].push_back(--v);
            adj[v].push_back(u);
        }

        vector col(n, 0), val(n, 0);
        for (int &x : col) cin >> x;
        for (int &x : val) cin >> x;

        vector<ll> subsum(n), minleaf(n);
        vector<map<int, ll>> colsum(n), mxleaf(n), mxval(n);
        vector<multiset<ll>> curdifs(n);
        vector<ll> ans(n);
        auto merge = [&] (int u, int v) {
            if (colsum[u].size() < colsum[v].size()) {
                swap(colsum[u], colsum[v]);
                swap(mxleaf[u], mxleaf[v]);
                swap(mxval[u], mxval[v]);
                swap(curdifs[u], curdifs[v]);
            }

            for (auto [c, x] : colsum[v]) {
                colsum[u][c] += x;
                
                if (mxleaf[u].find(c) != mxleaf[u].end()) {
                    curdifs[u].erase(curdifs[u].find(mxval[u][c]));
                }

                if (mxleaf[v].find(c) != mxleaf[v].end()) {
                    mxleaf[u][c] = max(mxleaf[u][c], mxleaf[v][c]);
                    mxval[u][c] = mxleaf[u][c] - colsum[u][c];
                }
                else if (mxleaf[u].find(c) != mxleaf[u].end()) {
                    mxval[u][c] = mxleaf[u][c] - colsum[u][c];
                }
                
                if (mxleaf[u].find(c) != mxleaf[u].end()) {
                    curdifs[u].insert(mxval[u][c]);
                }
            }
        };
        
        auto dfs = [&] (const auto &self, int u, int p) -> void {
            subsum[u] = val[u];
            colsum[u][col[u]] = val[u];
            minleaf[u] = 1e18;
            int children = 0;
            for (int v : adj[u]) if (v != p) {
                self(self, v, u);
                ++children;
                merge(u, v);
                subsum[u] += subsum[v];
                minleaf[u] = min(minleaf[u], minleaf[v]);
            }

            if (!children) {
                ans[u] = val[u];
                minleaf[u] = val[u];
                mxleaf[u][col[u]] = val[u];
                mxval[u][col[u]] = 0;
                curdifs[u].insert(0);
            }
            else {
                // 1. S - 2*minleaf
                ans[u] = subsum[u] - 2*minleaf[u];

                // 2. S - 2S_c + 2max_c
                ans[u] = max(ans[u], subsum[u] + 2*(*rbegin(curdifs[u])));

                // 3. u is a leaf
                if (children == 1) {
                    ans[u] = max(ans[u], subsum[u] - colsum[u][col[u]] + abs(2*val[u] - colsum[u][col[u]]));
                }
            }
        };
        dfs(dfs, 0, 0);
        for (auto x : ans) cout << x << '\n';
    }
}

Correct me if I am wrong, but I am pretty sure it’s sufficient to only keep the minimum. Here is the proof:

Firstly, let’s assume that there are at least two leaves with the color c. Let’s take the minimum value of the leaves with color c and call it x, and the maximum value y (note that because we are dealing with absolute values, only the minimum or the maximum — or both, if x = y — can give us the maximum value).

Now, we want to prove that:
S - Sc + |Sc - 2x| >= S - Sc + |Sc - 2y|
=> |Sc - 2x| >= |Sc - 2y|

Obviously, x <= y, so we only need to consider two cases.

Let’s define:

Sc’ = Sc - x - y >= 0

Then:

Case 1: Sc - 2y >= 0

|Sc - 2x| = Sc’ + x + y - 2x = Sc’ + y - x
|Sc - 2y| = Sc’ + x + y - 2y = Sc’ + x - y

So,
Sc’ + y - x >= Sc’ + x - y
=> 2y >= 2x
=> y >= x

Which is always true.

Case 2: Sc - 2y < 0

|Sc - 2x| = Sc’ + y + x - 2x = Sc’ + y - x
|Sc - 2y| = 2y - (Sc’ + x + y) = y - Sc’ - x

So,

Sc’ + y - x >= y - Sc’ - x => 2Sc’ >= 0

Which is also always true by definition.

Therefore, |Sc - 2x| >= |Sc - 2y|, and it is always sufficient to only consider the minimum.

Also a small nitpick — it’s supposed to be Sc, not Sv.