# TRERMB - Editorial

Author: Hasin Rayhan Dewan Dhruboo
Tester: Raja Vardhan Reddy
Editorialist: William Lin

# DIFFICULTY:

Medium Hard

# PREREQUISITES:

Trees, Lowest Common Ancestor, Prefix sums, DFS order

# PROBLEM:

Given a tree with N nodes, the beauty of the tree is $\sum_{i=1}^N A_i \cdot F_i$, where A_i is given and F_i is the number of distinct colors in the subtree of i. You are given Q updates of the form (a, x), each of which sets the color of node a to x. Find the beauty of the tree after each update.
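To make the definition concrete, here is a minimal brute-force sketch of the beauty computation. It is only a reference for checking small cases, not the intended solution. The function name `bruteBeauty`, the 0-indexed nodes, and the choice of node 0 as the root are assumptions made for this illustration.

```cpp
#include <set>
#include <vector>

// Brute-force beauty: collect the set of colors in every subtree.
// Assumes adj is the adjacency list of a tree rooted at node 0 (hypothetical layout).
long long bruteBeauty(const std::vector<std::vector<int>>& adj,
                      const std::vector<int>& color,
                      const std::vector<long long>& A) {
    int n = adj.size();
    std::vector<std::set<int>> colors(n);
    // Iterative DFS: every parent appears before its children in `order`.
    std::vector<int> parent(n, -1), order, stack = {0};
    while (!stack.empty()) {
        int u = stack.back();
        stack.pop_back();
        order.push_back(u);
        for (int v : adj[u])
            if (v != parent[u]) {
                parent[v] = u;
                stack.push_back(v);
            }
    }
    long long beauty = 0;
    // Walk the order backwards so children are finished before their parent.
    for (int i = n - 1; i >= 0; --i) {
        int u = order[i];
        colors[u].insert(color[u]);
        beauty += A[u] * (long long)colors[u].size();  // A_u * F_u
        if (parent[u] >= 0)
            colors[parent[u]].insert(colors[u].begin(), colors[u].end());
    }
    return beauty;
}
```

Recomputing this after every update is far too slow for the real constraints, but it is handy for validating a faster solution on random small trees.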

# QUICK EXPLANATION:

Whenever we color/uncolor a node u, all nodes on the path from u to one of its ancestors will have F_i increased by 1 or decreased by 1, so we update the beauty by adding or subtracting the sum of A_i on that path. To find that path, for each color, we maintain a set of the nodes with that color, sorted by dfs order.

# EXPLANATION:

Let's split each update into two operations: uncoloring a node (so that it has no color) and coloring it with a new color. Both operations are similar, so I will only focus on the second operation.

When we color a node u with a color x, some of the ancestors of u will now have color x in their subtrees, so their F_i will increase by 1. An example is shown in the picture below, where color 1 is added to u:

In this example, 3 ancestors (including u) have their F_i increased by 1. However, w (and its ancestors) do not have their F_i increased as they already contained a node with color 1 in their subtree.

Let's suppose that we somehow know how to find w. How does the answer change? Since all nodes on the path from u to w (not including w) have their F_i increased by 1, the beauty of the tree increases by the sum of A_i over those nodes. We can find this sum efficiently by precalculating prefix sums of A_i from the root to each node: the sum is the prefix sum at u minus the prefix sum at w (or just the prefix sum at u if no such w exists).
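The prefix-sum step can be sketched as follows. The names `buildRootSums` and `pathGain`, the 0-indexed nodes, and the convention that `w == -1` means "no ancestor already has the color" are assumptions for this illustration.

```cpp
#include <vector>

// rootSum[v] = sum of A over the path from the root down to v (inclusive).
// Assumes parent[root] == -1 and that parents come before children in `order`.
std::vector<long long> buildRootSums(const std::vector<int>& parent,
                                     const std::vector<int>& order,
                                     const std::vector<long long>& A) {
    std::vector<long long> rootSum(A.size(), 0);
    for (int v : order)
        rootSum[v] = A[v] + (parent[v] >= 0 ? rootSum[parent[v]] : 0);
    return rootSum;
}

// Sum of A on the path from u up to, but not including, its ancestor w.
// w == -1 means the path runs all the way to the root.
long long pathGain(const std::vector<long long>& rootSum, int u, int w) {
    return rootSum[u] - (w >= 0 ? rootSum[w] : 0);
}
```

Coloring adds `pathGain(rootSum, u, w)` to the beauty and uncoloring subtracts it, which mirrors the `ans+=x*(...)` line in the Editorialist's solution.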

What remains is to find w, the lowest ancestor of u that does not have its F_i changed.

One inefficient approach is as follows: Let's consider all nodes with color x. For each of those nodes, we will find the lowest common ancestor of that node with u. That lowest common ancestor is the lowest node which contains both u and the node with color x in its subtree. The lowest of all such lowest common ancestors will be w (the lowest node which contains u with any node of color x in its subtree).
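A minimal sketch of this inefficient approach, useful mainly as a reference to test against. The names `naiveLca` and `findWNaive` are hypothetical, nodes are 0-indexed with `parent[root] == -1`, and the LCA here is found by simple parent climbing rather than binary lifting.

```cpp
#include <utility>
#include <vector>

// Naive LCA: repeatedly lift the deeper endpoint until both meet.
int naiveLca(const std::vector<int>& parent, const std::vector<int>& depth,
             int u, int v) {
    while (u != v) {
        if (depth[u] < depth[v]) std::swap(u, v);
        u = parent[u];  // u is now the deeper (or equally deep) endpoint
    }
    return u;
}

// Lowest ancestor of u whose subtree contains some other node of color x,
// or -1 if no other node has color x. u itself is treated as uncolored.
int findWNaive(const std::vector<int>& parent, const std::vector<int>& depth,
               const std::vector<int>& color, int u, int x) {
    int w = -1;
    for (int v = 0; v < (int)color.size(); ++v) {
        if (v == u || color[v] != x) continue;
        int l = naiveLca(parent, depth, u, v);
        if (w < 0 || depth[l] > depth[w]) w = l;  // keep the deepest LCA
    }
    return w;
}
```

Note that if a node of color x already lies inside the subtree of u, this returns u itself, and the path from u to w is empty.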

How do we optimize this? If we sort all nodes with color x by their dfs order, we only need to check the node which is before u in the dfs order and the node which is after u in the dfs order. This is kind of intuitive and is somewhat well-known, but unfortunately I don't know how to prove it rigorously. Update: You can see @kshitij_789's proof in this comment below.
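The neighbor lookup can be sketched with an ordered set keyed by dfs order. This is an illustration under assumptions, not the exact code of any solution below: `findW` is a hypothetical name, `tin[v]` stands for the dfs order of v, the LCA routine is passed in as a callback, and u is assumed to be absent from the set at the time of the call.

```cpp
#include <functional>
#include <iterator>
#include <set>
#include <utility>
#include <vector>

// nodes: all nodes currently holding the color, stored as (dfs order, node).
// Returns the lowest ancestor of u that already has the color in its subtree,
// or -1 if the set is empty. Assumes u itself is not in `nodes`.
int findW(const std::set<std::pair<int, int>>& nodes,
          const std::vector<int>& tin, const std::vector<int>& depth,
          const std::function<int(int, int)>& lca, int u) {
    int w = -1;
    // First stored node whose dfs order is greater than that of u.
    auto it = nodes.lower_bound({tin[u], -1});
    if (it != nodes.end())
        w = lca(u, it->second);
    // Last stored node whose dfs order is smaller than that of u.
    if (it != nodes.begin()) {
        int w2 = lca(u, std::prev(it)->second);
        if (w < 0 || depth[w2] > depth[w]) w = w2;
    }
    return w;
}
```

Each call costs one set lookup plus at most two LCA queries, so with binary lifting an update runs in O(log N).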

# SOLUTIONS:

Setter's Solution
#include<bits/stdc++.h>
using namespace std;
typedef vector<int> vi;
typedef vector<pair<int,int> > vpi;
const int maxn = 100009;
const int LOG = 19;
vector < int > edges[maxn];
pair < int , int > upds[maxn];

vector < int > perColorNode[maxn];

int shurutime[maxn], seshtime[maxn], Parmain[maxn], Dmain[maxn], Color[maxn], perNodeVal[maxn];
long long nodetorootsum[maxn];
int tmme;
long long curAns = 0;
map < int , int > initial_map[maxn];

pair < int , int > allupdates[maxn];

void Merge(int pos1, int pos2)
{
if(initial_map[pos1].size() < initial_map[pos2].size()) {
initial_map[pos1].swap(initial_map[pos2]);
}
for(auto e : initial_map[pos2]) initial_map[pos1][e.first] = 1;
}
void dfs1(int pos, int par, int lvl, long long val)
{
shurutime[pos] = ++tmme;
Parmain[pos] = par;
Dmain[pos] = lvl;
nodetorootsum[pos] = val;
initial_map[pos][Color[pos]] = 1;
//    cout << pos << endl;
for(int to : edges[pos]){
if(to == par) continue;
dfs1(to, pos, lvl + 1, val + (long long) perNodeVal[to]);
Merge(pos, to);
}
curAns += (long long)perNodeVal[pos] * (long long)initial_map[pos].size();
seshtime[pos] = tmme;
}

struct sparseTable{

int color;
vector < vector < int > > P, edges;
vector < int > Par, D, intime, outtime;
map < int , int > mapping, rev;
int ttme;

int n;
vector < int > tree;
void Clear(){
color = n = ttme = 0;
tree.clear(); P.clear(); edges.clear();
Par.clear(), D.clear(), intime.clear(), outtime.clear();
mapping.clear(), rev.clear();
}
///point update, range query
///initial elements at tree[n]....tree[2*n-1]
void build(int _n){
n = _n;
tree.resize(2 * n + 10, 0);
}

void update(int p,int value){
for(tree[p+=n]=value; p>1; p>>=1)
tree[p>>1]=(tree[p] + tree[p^1]);
}

//returns the sum over the half-open range [l, r)
int query(int l,int r){
int res=0;
for(l+=n, r+=n; l<r; l>>=1, r>>=1) {
if(l&1) res=(res + tree[l++]);
if(r&1) res=(res + tree[--r]);
}
return res;
}

void sparseMain(int n)
{
for(int i=1; i<=n; i++) for(int j=0; j<LOG; j++) P[i][j] = 0;

for(int i=1; i<=n; i++) P[i][0] = Par[i];
for(int j=1; j<LOG; j++){
for(int i=1; i<=n; i++){
if(P[i][j-1] != 0){
int x = P[i][j-1];
P[i][j] = P[x][j-1];
}
}
}
}

void buildSparse(int n){
P.resize(n + 1), Par.resize(n + 1, 0), D.resize(n + 1, 0), mapping.clear(), rev.clear();
intime.resize(n + 1, 0), outtime.resize(n + 1, 0);
for(int i = 1; i <= n; i++) P[i].resize(LOG), Par[i] = Parmain[i], D[i] = Dmain[i], mapping[i] = rev[i] = i, intime[i] = shurutime[i], outtime[i] = seshtime[i];
sparseMain(n);
}

void dfs(int pos, int par, int lvl)
{
D[pos] = lvl, Par[pos] = par, intime[pos] = ++ttme;
for(int to : edges[pos]) dfs(to, pos, lvl + 1);
outtime[pos] = ttme;
}

void buildSparse(vpi &alledges, vector < int > &subset){
int cur = 0;
mapping.clear(), rev.clear();
for(auto val : subset) mapping[val] = ++cur, rev[cur] = val;
for(auto &e : alledges) {
if(mapping[e.first] == 0) assert(0);
if(mapping[e.second] == 0) assert(0);
e.first = mapping[e.first];
e.second = mapping[e.second];
}

edges.resize(cur + 1), P.resize(cur + 1), Par.resize(cur + 1, 0), D.resize(cur + 1, 0);
intime.resize(cur + 1, 0), outtime.resize(cur + 1, 0), ttme = 0;
for(int i = 1; i <= cur; i++) P[i].resize(LOG);

for(auto e : alledges){
if(e.first > cur || e.second > cur) assert(0);
edges[e.first].push_back(e.second);
}
dfs(1, 0, 1);

sparseMain(cur);
build(cur);
for(int i = 1; i <= cur; i++) if(Color[rev[i]] == color) update(i - 1, 1);
}

int LCA(int p,int q){
p = mapping[p];q = mapping[q];
if(p == 0 || q == 0) assert(0);
if(D[p]<D[q]) swap(p,q);

int Log = log2(D[p])+1;
for(int i=Log;i>=0;i--) if(D[p]-D[q] >= (1<<i)) p = P[p][i];
if(p==q) return p;

for(int i=Log;i>=0;i--) if(P[p][i]!=0 && P[p][i] != P[q][i]) {p = P[p][i]; q = P[q][i];}
int LCA = Par[p];
return LCA;
}

int FindClosestNonzeroParent(int pos)
{
for(int i = LOG - 1; i >= 0; i--){
if(P[pos][i] == 0) continue;
int tmp = P[pos][i];
int beg = intime[tmp], ed = outtime[tmp];
if(query(beg - 1, ed) == 0) pos = tmp;
}
return Par[pos];
}

void Set(int pos, int setBit){
int curpos = pos;
pos = mapping[pos];
if(pos == 0) assert(0);
if(setBit != 1 && setBit != -1) assert(0);

if(setBit == -1) update(intime[pos] - 1, 0);
//        cout << "yo : " << curpos << ' ' << color << ' ' << query(intime[pos] - 1, outtime[pos]) << endl;
///if the subtree-sum of pos is already non-zero, that means changing it won't affect any nodes distinct color count
if(query(intime[pos] - 1, outtime[pos]) == 0){
int closestPar = FindClosestNonzeroParent(pos);
closestPar = rev[closestPar];/// even if pos = 0, rev[pos] should be 0, so it's okay
curAns += (nodetorootsum[curpos] - nodetorootsum[closestPar]) * setBit;

//            cout << curpos << ' ' << setBit << ' ' << closestPar << ' ' << nodetorootsum[curpos] << ' ' << nodetorootsum[closestPar] << endl;

}

if(setBit == 1) update(intime[pos] - 1, 1);
return;
}

} trees[maxn];

void compressTree(vector<int>& li, vpi& ret) {

auto cmp = [&](int a, int b) {
return shurutime[a] < shurutime[b];
};
sort(li.begin(),li.end(), cmp);
int m = li.size()-1;
for(int i=0; i<m; i++) {
int a = li[i], b = li[i+1];
li.push_back(trees[0].LCA(a, b));
}
sort(li.begin(),li.end(), cmp);
li.erase(unique(li.begin(),li.end()), li.end());
ret.clear();
for(int i=0; i<li.size()-1; i++) {
int a = li[i], b = li[i+1];
ret.emplace_back(trees[0].LCA(a, b),b);
}
}

int main()
{
//    freopen("in01.txt", "r", stdin);
//freopen("outj01.txt", "w", stdout);
int t, cs = 1;
cin >> t;
while(t--){
int n, q;
scanf("%d %d", &n, &q);

tmme = curAns = 0;
for(int i = 1; i <= n; i++){
edges[i].clear();
perColorNode[i].clear(); initial_map[i].clear();
trees[i].Clear();
}
for(int i = 1; i < n; i++){
int x, y;
scanf("%d %d", &x, &y);
edges[x].push_back(y);
edges[y].push_back(x);
}
for(int i = 1; i <= n; i++) {
int x;
scanf("%d", &x);
Color[i] = x;
perColorNode[x].push_back(i);
}
for(int i = 1; i <= n; i++) scanf("%d", &perNodeVal[i]);

for(int i = 1; i <= q; i++){
int x, y;
scanf("%d %d", &x, &y);
upds[i] = {x, y};
perColorNode[y].push_back(x);
}
dfs1(1, 0, 1, perNodeVal[1]);
trees[0].buildSparse(n);

int totsize = 1;
for(int i = 1; i <= n; i++){
if(perColorNode[i].size() == 0) continue;
vpi tmpEdge;
compressTree(perColorNode[i], tmpEdge);
//            cout << "for i : " << i << endl;
//            for(auto e : tmpEdge) cout << e.first << ' ' << e.second << endl;
//            cout << endl << endl;
if(tmpEdge.size() + 1 != perColorNode[i].size()) assert(0);
trees[i].color = i;
trees[i].buildSparse(tmpEdge, perColorNode[i]);
totsize += trees[i].mapping.size();
//            cout << "done" << endl;
}

if(totsize > 2 * (n + q)) assert(0);

for(int i = 1; i <= q; i++){
int pos = upds[i].first, nxtClr = upds[i].second;
int prvColor = Color[pos];

//            cout << "after : " << i << ' ' << pos << ' ' << prvColor << ' ' << nxtClr << endl;
trees[prvColor].Set(pos, -1);
trees[nxtClr].Set(pos, 1);
Color[pos] = nxtClr;
printf("%lld\n", curAns);
//            cout << endl;
}

}

return 0;
};

Tester's Solution
//raja1999

//#pragma GCC optimize("Ofast")
//#pragma GCC target("sse,sse2,sse3,ssse3,sse4,avx,avx2")

#include <bits/stdc++.h>
#include <vector>
#include <set>
#include <map>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <utility>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <iomanip>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
//setbase - cout << setbase (16)a; cout << 100 << endl; Prints 64
//setfill -   cout << setfill ('x') << setw (5); cout << 77 <<endl;prints xxx77
//setprecision - cout << setprecision (14) << f << endl; Prints x.xxxx
//cout.precision(x)  cout<<fixed<<val;  // prints x digits after decimal in val

using namespace std;
using namespace __gnu_pbds;
#define f(i,a,b) for(i=a;i<b;i++)
#define rep(i,n) f(i,0,n)
#define fd(i,a,b) for(i=a;i>=b;i--)
#define pb push_back
#define mp make_pair
#define vi vector< int >
#define vl vector< ll >
#define ss second
#define ff first
#define ll long long
#define pii pair< int,int >
#define pll pair< ll,ll >
#define sz(a) a.size()
#define inf (1000*1000*1000+5)
#define all(a) a.begin(),a.end()
#define tri pair<int,pii>
#define vii vector<pii>
#define vll vector<pll>
#define viii vector<tri>
#define mod (1000*1000*1000+7)
#define pqueue priority_queue< int >
#define pdqueue priority_queue< int,vi ,greater< int > >
//#define int ll

typedef tree<
int,
null_type,
less<int>,
rb_tree_tag,
tree_order_statistics_node_update>
ordered_set;

//std::ios::sync_with_stdio(false);
int tim=0;
int a[100005],in[100005],out[100005],par[100005][25],arr[100005],st[500005];
ll sum[100005],res[100005];
int n;
vector<vi>pos(100005);
vector<viii>op(100005);
vi col;
ll ans[100005];
int c[100005],q;
set<int>act;
set<int>::iterator it;
int dfs(int u,int p,ll s){
int i;
sum[u]=a[u]+s;
par[u][0]=p;
in[u]=tim;
arr[tim]=u;
tim++;
}
out[u]=tim;
}
int build(int s,int e,int node){
if(s==e){
st[node]=0;
return 0;
}
int mid=(s+e)/2;
build(s,mid,2*node);
build(mid+1,e,2*node+1);
st[node]=st[2*node]+st[2*node+1];
return 0;
}
int update(int s,int e,int node,int pos,int val){
if(s>pos||e<pos){
return 0;
}
if(s==e){
st[node]=val;
return 0;
}
int mid=(s+e)/2;
update(s,mid,2*node,pos,val);
update(mid+1,e,2*node+1,pos,val);
st[node]=st[2*node]+st[2*node+1];
return 0;
}
int query(int s,int e,int node,int l,int r){
if(s>r||e<l){
return 0;
}
if(l<=s&&r>=e){
return st[node];
}
int mid=(s+e)/2;
return query(s,mid,2*node,l,r)+query(mid+1,e,2*node+1,l,r);
}
if(query(0,n-1,1,in[u],out[u]-1)>=1){
update(0,n-1,1,in[u],1);
return 0;
}
int u1=u,i;
fd(i,20,0){
if(par[u][i]!=-1){
if(query(0,n-1,1,in[par[u][i]],out[par[u][i]]-1)==0){
u=par[u][i];
}
}
}
update(0,n-1,1,in[u1],1);
ll val=0;
if(u==0){
val+=sum[u1];
}
else{
val+=sum[u1]-sum[par[u][0]];
}
return val;
}
ll remove(int u){
update(0,n-1,1,in[u],0);
if(query(0,n-1,1,in[u],out[u]-1)>=1){
return 0;
}
int u1=u,i;
fd(i,20,0){
if(par[u][i]!=-1){
if(query(0,n-1,1,in[par[u][i]],out[par[u][i]]-1)==0){
u=par[u][i];
}
}
}
ll val=0;
if(u==0){
val-=sum[u1];
}
else{
val-=sum[u1]-sum[par[u][0]];
}
return val;
}
int reset(){
int i;
build(0,n-1,1);
rep(i,n){
}
rep(i,col.size()){
op[col[i]].clear();
pos[col[i]].clear();
}
col.clear();
rep(i,q+2){
ans[i]=0;
}
act.clear();
tim=0;
}
int main(){
//std::ios::sync_with_stdio(false); cin.tie(NULL);
int t;
//cin>>t;
//t=1;
scanf("%d",&t);
while(t--){
int i,u,v,j,x,id;
ll val;
//cin>>n>>q;
scanf("%d %d",&n,&q);
reset();
rep(i,n-1){
//cin>>u>>v;
scanf("%d %d",&u,&v);
u--;
v--;
}
rep(i,n){
//cin>>c[i];
scanf("%d",&c[i]);
op[c[i]].pb(mp(i,mp(0,1)));
col.pb(c[i]);
}
rep(i,n){
//cin>>a[i];
scanf("%d",&a[i]);
}
dfs(0,-1,0);
f(i,1,21){
rep(j,n){
if(par[j][i-1]==-1){
par[j][i]=-1;
}
else{
par[j][i]=par[par[j][i-1]][i-1];
}
}
}
rep(i,q){
//cin>>x>>v;
scanf("%d %d",&x,&v);
x--;
if(c[x]==v){
continue;
}
pos[c[x]].pb(i+1);
op[c[x]].pb(mp(x,mp(i+1,-1)));
op[v].pb(mp(x,mp(i+1,1)));
c[x]=v;
col.pb(v);
pos[v].pb(i+1);
}
sort(all(col));
col.resize(unique(col.begin(),col.end())-col.begin());
int pre;
rep(i,col.size()){
for(it=act.begin();it!=act.end();it++){
update(0,n-1,1,in[*it],0);
}
act.clear();
val=0;
rep(j,op[col[i]].size()){
if(op[col[i]][j].ss.ff!=0){
id=j;
break;
}
act.insert(op[col[i]][j].ff);
ans[0]+=val;
}
f(j,id,op[col[i]].size()){
if(op[col[i]][j].ss.ss==-1){
val=remove(op[col[i]][j].ff);
act.erase(op[col[i]][j].ff);
ans[op[col[i]][j].ss.ff]+=val;
}
else{
act.insert(op[col[i]][j].ff);
ans[op[col[i]][j].ss.ff]+=val;
}
}
}
res[0]=ans[0];
rep(i,q){
res[i+1]=res[i]+ans[i+1];
//cout<<res[i+1]<<"\n";
printf("%lld\n",res[i+1]);
}

}
return 0;
}

Editorialist's Solution
#include <bits/stdc++.h>
using namespace std;

#define ll long long
#define ar array

const int mxN=1e5;
int n, q, c[mxN], dt, ds[mxN], d[mxN], anc[mxN][17];
ll a[mxN], ans;
vector<int> adj[mxN];
set<ar<int, 2>> s[mxN];

void dfs(int u=0, int p=-1) {
    ds[u]=dt++;
    anc[u][0]=p;
    for(int i=1; i<17; ++i)
        anc[u][i]=~anc[u][i-1]?anc[anc[u][i-1]][i-1]:-1;
    for(int v : adj[u]) {
        if(v==p)
            continue;
        //a[] becomes the prefix sum from the root
        a[v]+=a[u];
        d[v]=d[u]+1;
        dfs(v, u);
    }
}

int lca(int u, int v) {
    if(d[u]<d[v])
        swap(u, v);
    for(int i=16; ~i; --i)
        if(d[u]-(1<<i)>=d[v])
            u=anc[u][i];
    if(u==v)
        return u;
    for(int i=16; ~i; --i) {
        if(anc[u][i]^anc[v][i]) {
            u=anc[u][i];
            v=anc[v][i];
        }
    }
    return anc[u][0];
}

void upd(int i, int x) {
    //find min ancestor which already has c[i]
    int w=-1;
    //before i
    auto it=s[c[i]].lower_bound({ds[i]});
    if(it!=s[c[i]].begin()) {
        --it;
        w=lca(i, (*it)[1]);
        ++it;
    }
    //after i
    if(it!=s[c[i]].end()) {
        int w2=lca(i, (*it)[1]);
        if(w<0||d[w2]>d[w])
            w=w2;
    }
    //update ans for path
    ans+=x*(a[i]-(~w?a[w]:0));
}

void solve() {
    //input
    cin >> n >> q;
    for(int i=0; i<n; ++i)
        adj[i].clear();
    for(int i=1, u, v; i<n; ++i) {
        cin >> u >> v, --u, --v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    for(int i=0; i<n; ++i)
        cin >> c[i], --c[i];
    for(int i=0; i<n; ++i)
        cin >> a[i];

    //dfs for info
    dfs();

    //prepare for queries
    ans=0;
    for(int i=0; i<n; ++i)
        s[i].clear();
    for(int i=0; i<n; ++i) {
        upd(i, 1);
        s[c[i]].insert({ds[i], i});
    }

    for(int u, x; q--; ) {
        cin >> u >> x, --u, --x;
        //remove color of u
        s[c[u]].erase({ds[u], u});
        upd(u, -1);
        //add new color of u
        c[u]=x;
        upd(u, 1);
        s[x].insert({ds[u], u});
        //output
        cout << ans << "\n";
    }
}

int main() {
    ios::sync_with_stdio(0);
    cin.tie(0);

    int t;
    cin >> t;
    while(t--)
        solve();
}


Please give me suggestions if anything is unclear so that I can improve. Thanks

@tmwilliamlin Setter's solution and Tester's solution are the same by mistake. Please update those. Thanks

Thanks for letting me know, I have updated it

Can you suggest some similar problem like this?

Can someone explain what is the need of finding the node which is after u in the dfs order ?

In this case the lowest lowest common ancestor comes from the node with color 1 after u in the dfs order.

I see. Thanks.

I think we can also solve this using a persistent segment tree, @tmwilliamlin?

I haven't thought about that, could you explain the persistent segtree solution?

Anyways, it's good to not use data structures when unnecessary.

@tmwilliamlin could you please explain the anc[mxN][17] array in your solution, how you calculate it for each node in dfs and how you use it to find the lca?

LCA can be obtained by binary lifting (Editorialist's solution) as well as by using a segment tree (Tester's solution). Of the two methods, which is better for LCA?

@tmwilliamlin I got a proof

Claim

Let x be the node just before u in dfs order.
Then there is no node y such that dfs_order(y) < dfs_order(x) and lca(y, u) has more depth than lca(x, u).

Suppose such a node y exists, so that depth(lca(y, u)) > depth(lca(x, u)).
Let lca(x, u) be the kth ancestor of u.
Then y and u both lie in the subtree of the (k-1)th ancestor of u, since their lca has higher depth.

Now x doesn't lie in the subtree of the (k-1)th ancestor of u, and x comes before u in dfs order, so dfs(x) is called before dfs((k-1)th ancestor of u). Hence x is visited earlier than every node in that subtree, including y. This gives dfs_order(y) > dfs_order(x), which is a contradiction.

We can prove it similarly for the node just after u in dfs order.

Both methods have the same complexity, O(log n). Binary lifting is better because the method has many other benefits, like finding the xth parent of a given node, and is quite intuitive.

I generally prefer the binary lifting method, because the code is shorter and less likely to be buggy when you write it yourself. Besides, segment tree solutions have bigger hidden constants in terms of time complexity, and are slower due to their recursive nature, unless of course you write an iterative implementation.

Well, the proof looks ok. One thing I would add here is some graphical notion so that it's easier for people to understand.

Let's say we have node u appearing somewhere at index i in the DFS order, which we will denote by order[..] henceforth. So, u = order[i], let v = order[i-1], w = order[i-2].
w is some node present somewhere in the tree and let L be the LCA of (u, w).

If you follow the properties of the DFS tree, it will be clear that v is present in the subtree of L for sure. Otherwise, it wouldn't have been clubbed between w and u. So the LCA of (u, v) will be at least as deep as L. It could be L itself or any other node on the path from L to u in the tree, depending upon the location of node v.

Try drawing a tree by hand around this scenario.

What can be done if the tree in the question was not rooted?

If the tree was unrooted, "the subtree of i" would not be defined, and the problem would not be possible.
