# PROBLEM LINK:

Contest Division 1

Contest Division 2

Contest Division 3

Practice

**Setter:** Manan Grover

**Tester:** Istvan Nagy

**Editorialist:** Taranpreet Singh

# DIFFICULTY

Easy-Medium

# PREREQUISITES

Sack on Tree, Basic Combinatorics.

# PROBLEM

You are given a tree with N nodes, where each node is assigned a non-zero value given by the array A of length N.

You randomly choose a node (say node u) and change its value to zero. Now, if A_u = 0 for some node u and u is not a leaf node, a node v in the subtree of u is chosen such that u \neq v, and A_u and A_v are swapped. This process is repeated until the node with value 0 is a leaf node. Let's call the result a final tree.

In the end, there is exactly one node with the value 0 which is a leaf in the tree.

Determine the number of different final trees possible. Two trees are considered distinct if there's a node u having different values in the two trees.

# QUICK EXPLANATION

- If node u is chosen, it can be swapped with any node v in its subtree having a different value to obtain a unique tree. So \displaystyle ways_u = \sum_{v \in S, A_u \neq A_v} ways_v, where ways_u denotes the number of final trees in which node u is the topmost node whose value is changed, and S denotes the nodes in the subtree of u excluding node u. Leaf nodes have ways_u = 1.
- To quickly sum up such values, we can either use sack on tree or maintain a stack for each distinct value.

# EXPLANATION

### Simplified Problem

Let's assume all values are distinct. What is the number of distinct trees? We can group the final trees by the node chosen first, that is, the node whose value is replaced with 0. Let ways_u denote the number of such trees for node u. Clearly, ways_u = 1 if node u is a leaf node.

Now consider a non-leaf node u. Since all values are distinct, we can swap with any node v in the subtree of node u. After choosing u and swapping with v, node v holds the 0, so the process continues exactly as if node v had been chosen. So a total of ways_v distinct trees arise when node u is chosen and then swapped with node v.

Since node u can be swapped with any node in its subtree independently, by the addition rule of counting we can simply add these up to obtain the number of distinct trees.

So, we can write \displaystyle ways_u = \sum_{v \in S} ways_v where S is the set of nodes in subtree of node u excluding node u.

Hence, this simplified problem is solved in O(N) by accumulating subtree sums bottom-up.

For example, consider a test case

```
3
1 2
1 3
1 2 3
```

Nodes 2 and 3 are leaf nodes, so ways_2 = ways_3 = 1. For node 1, we have ways_1 = ways_2+ways_3 = 2.

Hence, the total number of distinct trees is \displaystyle \sum_{u} ways_u = 4. We can check that there are 4 trees, node values represented as 103, 120, 203, 320 (String 103 means node 1 has value 1, node 2 has value 0 and node 3 has value 3).
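The bottom-up summation for the distinct-values case can be sketched as follows. This is a minimal illustration, not code from the reference solutions: the function name `countDistinctValueTrees` and the parent-array input are assumptions made for the example, nodes are 0-indexed with node 0 as the root (node 1 in the statement), and `parent[i] < i` is assumed so that a reverse index scan visits children before their parents.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: counts final trees when all values are distinct.
// parent[0] must be -1 (root); parent[i] < i is assumed for every i > 0.
long long countDistinctValueTrees(const std::vector<int>& parent) {
    const long long MOD = 1000000007LL;
    const int n = static_cast<int>(parent.size());
    assert(n >= 1 && parent[0] == -1);
    std::vector<long long> below(n, 0);    // sum of ways_v over proper descendants v
    std::vector<bool> hasChild(n, false);
    long long answer = 0;
    for (int u = n - 1; u >= 0; --u) {
        // A leaf yields exactly one tree; otherwise ways_u is the sum below u.
        long long ways = hasChild[u] ? below[u] : 1;
        answer = (answer + ways) % MOD;
        if (u > 0) {
            int p = parent[u];
            assert(0 <= p && p < u);
            hasChild[p] = true;
            // The whole subtree of u (u itself plus its descendants) lies below p.
            below[p] = (below[p] + ways + below[u]) % MOD;
        }
    }
    return answer;
}
```

For the example above, `countDistinctValueTrees({-1, 0, 0})` evaluates to 4, in line with the four trees listed.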

### Original problem

Let's apply the same solution as above here. Now, some trees are double-counted. Let's figure out which ones.

Consider the same example, but with the value of node 1 being 2.

```
3
1 2
1 3
2 2 3
```

Here, if we proceed as in the previous solution, the node values in the trees are 203, 220, 203, 320. Notice that 203 appears twice: this tree is counted once when node 2 is chosen, and again when node 1 is chosen and immediately swapped with node 2.

It happened because after node u was chosen, it was immediately swapped with a node having the same value.
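To see the double counting concretely, one can enumerate every swap sequence on a tiny tree and collect the distinct final arrays in a set. The brute force below is only a sketch for experimentation: the name `countFinalTreesBruteForce` and the parent-array input (0-indexed, `parent[0] == -1` for the root) are assumptions of this example, and the running time is exponential.

```cpp
#include <cassert>
#include <set>
#include <utility>
#include <vector>

// Moves the zero sitting at node z further down in every possible way and
// records each final value array (zero at a leaf) in `seen`.
void pushZeroDown(int z, std::vector<int>& val,
                  const std::vector<std::vector<int>>& descendants,
                  std::set<std::vector<int>>& seen) {
    if (descendants[z].empty()) {   // z is a leaf, the process stops
        seen.insert(val);
        return;
    }
    for (int v : descendants[z]) {
        std::swap(val[z], val[v]);
        pushZeroDown(v, val, descendants, seen);
        std::swap(val[z], val[v]);  // undo the swap before trying the next v
    }
}

// Hypothetical brute force: number of distinct final trees, tiny inputs only.
int countFinalTreesBruteForce(const std::vector<int>& parent,
                              const std::vector<int>& a) {
    const int n = static_cast<int>(parent.size());
    assert(a.size() == parent.size());
    // descendants[p] lists every node strictly inside the subtree of p.
    std::vector<std::vector<int>> descendants(n);
    for (int v = 1; v < n; ++v)
        for (int p = parent[v]; p != -1; p = parent[p])
            descendants[p].push_back(v);
    std::set<std::vector<int>> seen;
    for (int u = 0; u < n; ++u) {   // u is the node whose value is zeroed first
        std::vector<int> val = a;
        val[u] = 0;
        pushZeroDown(u, val, descendants, seen);
    }
    return static_cast<int>(seen.size());
}
```

With values `1 2 3` this reports 4 trees, while with values `2 2 3` it reports 3, since the array 203 is reached in two different ways but stored once.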

### Observation

Consider two nodes u and v where v is in the subtree of node u and A_u = A_v. Choosing node u and swapping it with node v results in the same trees as choosing node v in the first operation.

It happens because, after the swap, node u contains its original value, and node v is 0, which is what happens when node v is chosen.

Hence, let's add another constraint: after a node u is chosen, it may be swapped with a node v in the subtree of node u (excluding node u) if and only if A_u \neq A_v. This way, we know that the value of node u will change, and we can group all different trees by the topmost node whose value is changed.

Let's redefine ways_u a bit. ways_u now represents the number of distinct trees such that node u is the highest node in the tree whose value is changed. We can see that ways_u = 1 still holds for leaf nodes.

For non-leaf nodes, \displaystyle ways_u = \sum_{v \in S, A_u \neq A_v} ways_v.
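As a quick check of this recurrence, here is the computation for the second example (values `2 2 3`), written out as a worked equation using only the definitions above.

```latex
\begin{aligned}
ways_2 &= ways_3 = 1 && \text{(leaf nodes)} \\
ways_1 &= \sum_{v \in \{2, 3\},\ A_v \neq A_1} ways_v = ways_3 = 1 && \text{(node 2 is skipped, since } A_2 = A_1 = 2\text{)} \\
\sum_{u} ways_u &= 1 + 1 + 1 = 3
\end{aligned}
```

This agrees with the three distinct trees 203, 220 and 320 found in that example.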

The answer is \displaystyle \sum_{u} ways_u, so all that remains is to compute ways_u efficiently.

### Approach 1

Let's use sack on tree. This problem is one of the standard applications of this technique, which maintains some value over the subtree of node u. Here, let's maintain \displaystyle F_x = \sum_{v \in S, A_v = x} ways_v, that is, the sum of ways_v over the nodes v in S that have value x. Also, maintain \displaystyle T = \sum_x F_x.

This way, ways_u = T - F_{A_u} can be computed quickly.

Whenever we add or remove a node v, we update both F_{A_v} and T.

We can optionally compress the values in A so that an array can be used for storing F, or use a map instead.

This approach has time complexity O(N*log(N)). See the editorialist's solution for an implementation.
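For readers who prefer C++, the same F and T bookkeeping can be sketched with small-to-large merging of per-subtree maps. Note that this swaps in a different technique from the heavy-child sack used in the editorialist's solution, and with `std::map` it runs in O(N*log^2(N)). The function name `countTreesSmallToLarge` and the parent-array input are assumptions of this example: nodes are 0-indexed, `parent[0] == -1`, and `parent[i] < i` is assumed so that a reverse scan is a valid bottom-up order (for arbitrary input edges, a BFS order would be needed instead).

```cpp
#include <cassert>
#include <map>
#include <utility>
#include <vector>

// Hypothetical sketch: F[u] maps value x to the sum of ways_v over nodes v
// merged into u's bag so far; T[u] is the sum of all entries of F[u].
long long countTreesSmallToLarge(const std::vector<int>& parent,
                                 const std::vector<int>& a) {
    const long long MOD = 1000000007LL;
    const int n = static_cast<int>(parent.size());
    assert(n >= 1 && a.size() == parent.size());
    std::vector<std::map<int, long long>> F(n);
    std::vector<long long> T(n, 0);
    std::vector<bool> hasChild(n, false);
    long long answer = 0;
    for (int u = n - 1; u >= 0; --u) {
        // Here F[u] holds exactly the proper descendants of u.
        long long ways = 1;                        // leaf case
        if (hasChild[u]) {
            auto it = F[u].find(a[u]);
            long long same = (it == F[u].end()) ? 0 : it->second;
            ways = (T[u] - same + MOD) % MOD;      // ways_u = T - F_{A_u}
        }
        answer = (answer + ways) % MOD;
        // Add u itself to its bag, then hand the bag to the parent.
        F[u][a[u]] = (F[u][a[u]] + ways) % MOD;
        T[u] = (T[u] + ways) % MOD;
        if (u == 0) break;
        int p = parent[u];
        hasChild[p] = true;
        if (F[p].size() < F[u].size()) std::swap(F[p], F[u]);  // keep the larger map
        for (const auto& kv : F[u])
            F[p][kv.first] = (F[p][kv.first] + kv.second) % MOD;
        T[p] = (T[p] + T[u]) % MOD;
        F[u].clear();
    }
    return answer;
}
```

On the two small examples from this editorial, `countTreesSmallToLarge({-1, 0, 0}, {1, 2, 3})` gives 4 and `countTreesSmallToLarge({-1, 0, 0}, {2, 2, 3})` gives 3.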

### Approach 2

In this approach, we maintain a stack for each distinct value, and whenever we exit a node u, we insert node u into the stack for value A_u.

While processing node u, the top of this stack holds the nodes v such that v is in the subtree of node u, A_v = A_u and there's no node on the path from u to v with the same value. Below these entries, the stack may also contain some nodes visited before entering node u.

Hence, when computing ways_u, we subtract the contribution of the nodes at the top of the stack that lie in the subtree of node u, and remove them from the stack as well.

We can prove that each node is inserted into and removed from the stack at most once. Hence, the time complexity of this approach is O(N).

Refer to the setter's solution for this approach.
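The description above can be sketched roughly as follows. This is an illustrative reading, not the setter's code: the names `StackSolver` and `countTreesWithStacks` are hypothetical, the input is a 0-indexed parent array with `parent[0] == -1`, and the recursion assumes a tree shallow enough for the call stack. One detail is an assumption of this sketch rather than something spelled out in the text: each stack entry stores ways_v plus whatever was popped while processing v, so that popping an entry also accounts for deeper nodes with the same value. Entry times decide whether a stack entry lies inside the current subtree.

```cpp
#include <cassert>
#include <unordered_map>
#include <utility>
#include <vector>

struct StackSolver {
    static constexpr long long MOD = 1000000007LL;
    std::vector<std::vector<int>> children;
    std::vector<int> a;
    // Per value: stack of (entry time, accumulated ways), pushed on exit.
    std::unordered_map<int, std::vector<std::pair<int, long long>>> stacks;
    long long exitedSum = 0;  // sum of ways over all nodes exited so far
    int timer = 0;

    void dfs(int u) {
        const int tin = timer++;
        const long long before = exitedSum;
        for (int c : children[u]) dfs(c);
        // Every node exited since entering u is a proper descendant of u.
        const long long below = (exitedSum - before + MOD) % MOD;
        // Pop same-valued entries that were pushed inside the subtree of u.
        auto& st = stacks[a[u]];
        long long same = 0;
        while (!st.empty() && st.back().first > tin) {
            same = (same + st.back().second) % MOD;
            st.pop_back();
        }
        const long long ways =
            children[u].empty() ? 1 : (below - same + MOD) % MOD;
        exitedSum = (exitedSum + ways) % MOD;
        st.push_back({tin, (ways + same) % MOD});
    }
};

// Hypothetical entry point: returns the sum of ways_u over all nodes.
long long countTreesWithStacks(const std::vector<int>& parent,
                               const std::vector<int>& a) {
    const int n = static_cast<int>(parent.size());
    assert(n >= 1 && a.size() == parent.size() && parent[0] == -1);
    StackSolver s;
    s.children.assign(n, {});
    s.a = a;
    for (int v = 1; v < n; ++v) s.children[parent[v]].push_back(v);
    s.dfs(0);
    return s.exitedSum;
}
```

Each node is pushed once and popped at most once, so apart from the hash map lookups the work is linear in N. For the example with values `2 2 3`, `countTreesWithStacks({-1, 0, 0}, {2, 2, 3})` gives 3.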

# TIME COMPLEXITY

The time complexity is O(N) or O(N*log(N)) per test case.

# SOLUTIONS

## Setter's Solution

```
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define md 1000000007
// dp[x]:  sum of ways over the proper descendants of x (1 for a leaf), modulo md.
// cor[x]: sum of ways over the proper descendants of x whose value equals a[x].
// mpp[val]: stack of nodes on the current root-to-x path having value val.
void dfs(ll x, ll pr, vector<ll> tr[], ll dp[], ll cor[], map<ll, vector<ll>> &mpp, ll a[]){
    auto it = mpp.find(a[x]);
    (*it).second.push_back(x);
    for(ll i = 0; i < (ll)tr[x].size(); i++){
        ll y = tr[x][i];
        if(y == pr){
            continue;
        }
        dfs(y, x, tr, dp, cor, mpp, a);
        // Add the subtree sum of y: dp[y] for a leaf, 2 * dp[y] - cor[y] otherwise.
        dp[x] += dp[y];
        if((ll)tr[y].size() > 1){
            dp[x] += dp[y];
        }
        dp[x] -= cor[y];
        dp[x] %= md;
    }
    if((ll)tr[x].size() == 1 && x != 0){
        dp[x] = 1;
    }
    (*it).second.pop_back();
    // Report x to its nearest ancestor with the same value, if any.
    if((*it).second.size()){
        cor[(*it).second.back()] += dp[x];
        cor[(*it).second.back()] %= md;
    }
}
int main(){
    ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL);
    ll t;
    cin>>t;
    while(t--){
        ll n;
        cin>>n;
        ll dp[n+1] = {};
        ll cor[n+1] = {};
        ll a[n+1] = {};
        vector<ll> tr[n+1];
        // Node 0 is a virtual root with value 0 attached above node 1.
        tr[0].push_back(1);
        tr[1].push_back(0);
        for(ll i = 0; i<n-1; i++){
            ll u,v;
            cin>>u>>v;
            tr[u].push_back(v);
            tr[v].push_back(u);
        }
        map<ll, vector<ll>> mpp;
        vector<ll> temp;
        for(ll i = 0; i < n + 1; i++){
            if(i){
                cin>>a[i];
            }else{
                a[i] = 0;
            }
            mpp.insert(make_pair(a[i], temp));
        }
        dfs(0, -1, tr, dp, cor, mpp, a);
        // dp[0] is the sum of ways over all real nodes; normalise it into [0, md).
        ll ans = dp[0];
        ans %= md;
        ans += md;
        ans %= md;
        cout<<ans<<"\n";
    }
    return 0;
}
```

## Tester's Solution

```
#include <iostream>
#include <cassert>
#include <vector>
#include <set>
#include <map>
#include <algorithm>
#include <random>
#ifdef HOME
#include <windows.h>
#endif
#define all(x) (x).begin(), (x).end()
#define rall(x) (x).rbegin(), (x).rend()
#define forn(i, n) for (int i = 0; i < (int)(n); ++i)
#define for1(i, n) for (int i = 1; i <= (int)(n); ++i)
#define ford(i, n) for (int i = (int)(n) - 1; i >= 0; --i)
#define fore(i, a, b) for (int i = (int)(a); i <= (int)(b); ++i)
template<class T> bool umin(T &a, T b) { return a > b ? (a = b, true) : false; }
template<class T> bool umax(T &a, T b) { return a < b ? (a = b, true) : false; }
using namespace std;
long long readInt(long long l, long long r, char endd) {
    long long x = 0;
    int cnt = 0;
    int fi = -1;
    bool is_neg = false;
    while (true) {
        char g = getchar();
        if (g == '-') {
            assert(fi == -1);
            is_neg = true;
            continue;
        }
        if ('0' <= g && g <= '9') {
            x *= 10;
            x += g - '0';
            if (cnt == 0) {
                fi = g - '0';
            }
            cnt++;
            assert(fi != 0 || cnt == 1);
            assert(fi != 0 || is_neg == false);
            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
        }
        else if (g == endd) {
            assert(cnt > 0);
            if (is_neg) {
                x = -x;
            }
            assert(l <= x && x <= r);
            return x;
        }
        else {
            assert(false);
        }
    }
}
string readString(int l, int r, char endd) {
    string ret = "";
    int cnt = 0;
    while (true) {
        char g = getchar();
        assert(g != -1);
        if (g == endd) {
            break;
        }
        cnt++;
        ret += g;
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}
long long readIntSp(long long l, long long r) {
    return readInt(l, r, ' ');
}
long long readIntLn(long long l, long long r) {
    return readInt(l, r, '\n');
}
string readStringLn(int l, int r) {
    return readString(l, r, '\n');
}
string readStringSp(int l, int r) {
    return readString(l, r, ' ');
}
int main(int argc, char** argv)
{
#ifdef HOME
    if(IsDebuggerPresent())
    {
        freopen("../FLGZRO_0.in", "rb", stdin);
        freopen("../out.txt", "wb", stdout);
    }
#endif
    int T = readIntLn(1, 100'000);
    int sumN = 0;
    forn(tc, T)
    {
        int N = readIntLn(1, 100'000);
        sumN += N;
        vector<pair<int, int >> edges(N - 1);
        vector<vector<int> > neighb(N);
        for (auto& e : edges)
        {
            e.first = readIntSp(1, N);
            e.second = readIntLn(1, N);
            --e.first;
            --e.second;
            neighb[e.first].push_back(e.second);
            neighb[e.second].push_back(e.first);
        }
        vector<int> v(N);
        int cc = 0;
        for (auto& vi : v)
        {
            if (++cc == N)
                vi = readIntLn(1, 1'000'000'000);
            else
                vi = readIntSp(1, 1'000'000'000);
        }
        vector<vector<int>> childs(N);
        vector<int> parent(N, -1);
        vector<int> q(1);
        vector<bool> used(N);
        used[0] = true;
        forn(i, N)
        {
            int actN = q[i];
            for (auto ne : neighb[actN])
            {
                if (!used[ne])
                {
                    used[ne] = true;
                    parent[ne] = actN;
                    childs[actN].push_back(ne);
                    q.push_back(ne);
                }
            }
        }
        int64_t MOD = 1'000'000'007;
        vector<int64_t> dp(N);
        vector<map<int, int64_t> > ex(N);
        reverse(q.begin(), q.end());
        for(auto qi : q)
        {
            int curV = v[qi];
            if (childs[qi].empty())
            {
                dp[qi] = 1;
                ex[qi][curV] = 1;
                continue;
            }
            int largestChildVal = 0, largestChildSize = 0;
            int64_t newV = 0, sumEx = 0;
            for (auto ci : childs[qi])
            {
                newV += dp[ci];
                if (ex[ci].count(curV))
                    sumEx += ex[ci][curV];
                if (largestChildSize < ex[ci].size())
                {
                    largestChildSize = ex[ci].size();
                    largestChildVal = ci;
                }
            }
            newV %= MOD;
            sumEx %= MOD;
            int64_t newEx = newV - sumEx;
            if (newEx < 0)
                newEx += MOD;
            newV += newV;
            newV += MOD- sumEx;
            newV %= MOD;
            dp[qi] = newV;
            ex[qi].swap(ex[largestChildVal]);
            ex[qi][curV] += newEx;
            ex[qi][curV] %= MOD;
            for (auto ci : childs[qi])
            {
                if(ci == largestChildVal)
                    continue;
                for (auto mi : ex[ci])
                {
                    ex[qi][mi.first] += mi.second;
                    ex[qi][mi.first] %= MOD;
                }
            }
        }
        printf("%lld\n", dp[0]);
    }
    assert(sumN <= 1'000'000);
    return 0;
}
```

## Editorialist's Solution

```
import java.util.*;
import java.io.*;
class Main{
    //SOLUTION BEGIN
    final long MOD = (long)1e9+7;
    long[] ways, valueMap;
    int[] sub, col;
    int[][] tree;
    long total = 0;
    void pre() throws Exception{}
    void solve(int TC) throws Exception{
        int N = ni();
        int[] from = new int[N-1], to = new int[N-1];
        for(int i = 0; i< N-1; i++){
            from[i] = ni()-1;
            to[i] = ni()-1;
        }
        col = new int[N];
        int[] tmp = new int[N];
        for(int i = 0; i< N; i++)tmp[i] = col[i] = ni();
        Arrays.sort(tmp);
        int C = 1;
        for(int i = 1; i< N; i++)if(tmp[i] != tmp[C-1])tmp[C++] = tmp[i];
        for(int i = 0; i< N; i++)col[i] = Arrays.binarySearch(tmp, 0, C, col[i]);
        valueMap = new long[C];
        tree = make(N, N-1, from, to, true);
        sub = new int[N];
        sub(0, -1);
        total = 0;
        ways = new long[N];
        dfs(0, -1, true);
        long ans = 0;
        for(int i = 0; i< N; i++){
            ans += ways[i];
            if(ans >= MOD)ans -= MOD;
        }
        pn(ans);
    }
    void sub(int u, int p){
        sub[u]++;
        for(int v:tree[u]){
            if(v == p)continue;
            sub(v, u);
            sub[u] += sub[v];
        }
    }
    void dfs(int u, int p, boolean keep){
        int hc = -1;
        for(int v:tree[u]){
            if(v == p)continue;
            if(hc == -1 || sub[v] > sub[hc])
                hc = v;
        }
        for(int v:tree[u]){
            if(v == hc || v == p)continue;
            dfs(v, u, false);
        }
        if(hc != -1)dfs(hc, u, true);
        for(int v:tree[u]){
            if(v != hc && v != p)
                add(v, u, 1);
        }
        if(hc == -1)
            ways[u] = 1;
        else
            ways[u] = (total+MOD-valueMap[col[u]])%MOD;
        total += ways[u];
        if(total >= MOD)total -= MOD;
        valueMap[col[u]] += ways[u];
        if(valueMap[col[u]] >= MOD)valueMap[col[u]] -= MOD;
        if(!keep)add(u, p, -1);
    }
    void add(int u, int p, long mul){
        valueMap[col[u]] = (valueMap[col[u]]+MOD+ways[u]*mul)%MOD;
        total = (total+MOD+mul*ways[u])%MOD;
        for(int v:tree[u])
            if(v != p)
                add(v, u, mul);
    }
    int[][] make(int n, int e, int[] from, int[] to, boolean f){
        int[][] g = new int[n][];int[]cnt = new int[n];
        for(int i = 0; i< e; i++){
            cnt[from[i]]++;
            if(f)cnt[to[i]]++;
        }
        for(int i = 0; i< n; i++)g[i] = new int[cnt[i]];
        for(int i = 0; i< e; i++){
            g[from[i]][--cnt[from[i]]] = to[i];
            if(f)g[to[i]][--cnt[to[i]]] = from[i];
        }
        return g;
    }
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    static boolean multipleTC = true;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        new Main().run();
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}
    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }
        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }
        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }
        String nextLine() throws Exception{
            String str = "";
            try{
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }
            return str;
        }
    }
}
```

Feel free to share your approach. Suggestions are welcomed as always.