# TREESAREFUN - Editorial

Author: satyam_343
Tester: yash_daga
Editorialist: iceknight1093

3060

# PREREQUISITES:

Euler tour of a tree, Segment trees

# PROBLEM:

You’re given a tree on N vertices. Vertex i has color C_i.
Let F_x denote the frequency of color x in the entire tree, and S_u denote the set of all colors appearing in the subtree of u.

For each vertex u, compute

\prod_{x\in S_1 \setminus S_u} F_x

modulo M.

# EXPLANATION:

If M were prime, this would be a rather easy exercise in small-to-large merging.
Unfortunately, M isn’t guaranteed to be prime — this throws any approaches involving division out of the window.

Let’s perform an Euler tour of the tree, so that subtrees now correspond to subarrays.
Let [L_u, R_u] be the subarray corresponding to the subtree of vertex u.
This makes the set of vertices outside the subtree correspond to the union of the ranges [1, L_u-1] and [R_u+1, N].

In particular, suppose that we are able to ensure the following:

• For any color c that appears in S_u, F_c doesn’t appear outside [L_u, R_u]
• For any color c that doesn’t appear in S_u, F_c appears exactly once outside [L_u, R_u]

Then we can simply take the product of the values on ranges [1, L_u-1] and [R_u+1, N] to obtain our answer!
Finding the product of a range can be done in \mathcal{O}(\log N) using a segment tree, so we just need to figure out how to maintain the two invariants above.

That can be achieved with a bit of cleverness when performing a DFS.
Consider an array A of length N, such that A_i = 1 for all i initially.
Next, for each color x, select exactly one vertex u such that C_u = x and set A_{L_u} = F_x.

This ensures that the frequency of each color appears exactly once throughout the whole array.
We only need to focus on moving it around appropriately.

Let’s perform a DFS. When at vertex u with color x,

• First, recursively compute the answer for all the children of u.
• Then, suppose \text{pos}_x is the current position of F_x in the array.
Set A_{\text{pos}_x} = 1, and A_{L_u} = F_x.
Finally, query for the product of ranges [1, L_u-1] and [R_u+1, N]; the product of these two values is the answer for u.

This works because:

• First, we already ensured that each F_x value occurs exactly once in the array A.
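• Next, if color x doesn’t appear in the subtree of u, then F_x will surely lie outside the range [L_u, R_u] (it is only ever placed at position L_v of some vertex v with C_v = x, and no such vertex lies in the subtree of u) and hence be included in the product.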
• Finally, if color x does appear in the subtree of u, F_x will be set at the position of some occurrence of color x in the subtree of u — in particular, it will be inside the range [L_u, R_u], and hence will not be included in the product of the outside.

Point updates and range queries can be handled in \mathcal{O}(\log N) using a segment tree, and so we’re done.

# TIME COMPLEXITY

\mathcal{O}(N\log N) per test case.

# CODE:

Setter's code (C++)
#pragma GCC optimization("O3")
#pragma GCC optimization("Ofast,unroll-loops")

#include <bits/stdc++.h>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;
using namespace std;
#define ll long long
const ll INF_MUL=1e13;
#define pb push_back
#define mp make_pair
#define nline "\n"
#define f first
#define s second
#define pll pair<ll,ll>
#define all(x) x.begin(),x.end()
#define vl vector<ll>
#define vvl vector<vector<ll>>
#define vvvl vector<vector<vector<ll>>>
#ifndef ONLINE_JUDGE
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);
#endif
// ---- Debug printers: every overload writes a textual form of its argument to cerr. ----

// Scalars and strings are streamed unchanged.
void _print(ll value) { cerr << value; }
void _print(int value) { cerr << value; }
void _print(char value) { cerr << value; }
void _print(string value) { cerr << value; }

// A pair is rendered as "{first,second}".
template <class T, class V>
void _print(pair<T, V> p) {
    cerr << "{";
    _print(p.first);
    cerr << ",";
    _print(p.second);
    cerr << "}";
}

// A vector is rendered as " [ e1 e2 ... ]" (each element followed by one space).
template <class T>
void _print(vector<T> v) {
    cerr << " [ ";
    for (const T &item : v) {
        _print(item);
        cerr << " ";
    }
    cerr << "]";
}

// A set uses the same layout as a vector, in iteration order.
template <class T>
void _print(set<T> v) {
    cerr << " [ ";
    for (const T &item : v) {
        _print(item);
        cerr << " ";
    }
    cerr << "]";
}

// A multiset uses the same layout as a vector, in iteration order.
template <class T>
void _print(multiset<T> v) {
    cerr << " [ ";
    for (const T &item : v) {
        _print(item);
        cerr << " ";
    }
    cerr << "]";
}

// A map prints each entry through the pair overload, i.e. " [ {k,v} {k,v} ]".
template <class T, class V>
void _print(map<T, V> v) {
    cerr << " [ ";
    for (const auto &entry : v) {
        _print(entry);
        cerr << " ";
    }
    cerr << "]";
}
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
//--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
const ll MOD=998244353;
const ll MAX=2002000;
ll n,mod;
class ST{
public:
vector<ll> segs;
ll size=0;
ll ID=1;

ST(ll sz) {
segs.assign(2*sz,ID);
size=sz;
}

ll comb(ll a,ll b) {
a=(a*b)%mod;
return a;
}

void upd(ll idx, ll val) {
segs[idx+=size]=val;
for(idx/=2;idx;idx/=2){
segs[idx]=comb(segs[2*idx],segs[2*idx+1]);
}
}

ll query(ll l,ll r) {
ll lans=ID,rans=ID;
for(l+=size,r+=size+1;l<r;l/=2,r/=2) {
if(l&1) {
lans=comb(lans,segs[l++]);
}
if(r&1){
rans=comb(segs[--r],rans);
}
}
return comb(lans,rans);
}
};
ST dp(MAX);
ll now;
vector<ll> color(MAX),freq(MAX);
vector<ll> ans(MAX),tin(MAX),last(MAX);
vector<ll> visited(MAX,0);
void dfs(ll cur,ll par){
tin[cur]=now++;
dp.upd(last[color[cur]],1);
dp.upd(tin[cur],freq[color[cur]]);
last[color[cur]]=tin[cur];
debug(cur);
visited[cur]=1;
if(visited[chld]){
continue;
}
debug(mp(cur,chld));
dfs(chld,cur);
}
ans[cur]=dp.query(1,tin[cur]-1);
}
void solve(){
cin>>n>>mod;
for(ll i=1;i<=n;i++){
freq[i]=0;
visited[i]=0;
}
for(ll i=1;i<=n;i++){
cin>>color[i];
freq[color[i]]++;
}
for(ll i=1;i<n;i++){
ll u,v; cin>>u>>v;
}
for(ll i=1;i<=n;i++){
last[i]=i;
freq[i]=max(freq[i],1ll);
dp.upd(i,freq[i]);
}
now=n+1;
dfs(1,-1);
for(ll i=1;i<=n;i++){
cout<<ans[i]<<" \n"[i==n];
}
return;
}
int main()
{
ios_base::sync_with_stdio(false);
cin.tie(NULL);
#ifndef ONLINE_JUDGE
freopen("input.txt", "r", stdin);
freopen("output.txt", "w", stdout);
freopen("error.txt", "w", stderr);
#endif
ll test_cases=1;
cin>>test_cases;
while(test_cases--){
solve();
}
cout<<fixed<<setprecision(10);
cerr<<"Time:"<<1000*((double)clock())/(double)CLOCKS_PER_SEC<<"ms\n";
}

Tester's code (C++)
//clear adj and visited vector declared globally after each test case
//check for long long overflow
//Incase of close mle change language to c++17 or c++14
//Check ans for n=1
#pragma GCC target ("avx2")
#pragma GCC optimize ("O3")
#pragma GCC optimize ("unroll-loops")
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#define int long long
#define IOS std::ios::sync_with_stdio(false); cin.tie(NULL);cout.tie(NULL);cout.precision(dbl::max_digits10);
#define pb push_back
#define lld long double
#define mii map<int, int>
#define pii pair<int, int>
#define ll long long
#define ff first
#define ss second
#define all(x) (x).begin(), (x).end()
#define rep(i,x,y) for(int i=x; i<y; i++)
#define fill(a,b) memset(a, b, sizeof(a))
#define vi vector<int>
#define setbits(x) __builtin_popcountll(x)
#define print2d(dp,n,m) for(int i=0;i<=n;i++){for(int j=0;j<=m;j++)cout<<dp[i][j]<<" ";cout<<"\n";}
typedef std::numeric_limits< double > dbl;
using namespace __gnu_pbds;
using namespace std;
typedef tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update> indexed_set;
//member functions :
//1. order_of_key(k) : number of elements strictly lesser than k
//2. find_by_order(k) : k-th element in the set
#define prev prev2
const long long N=2000005, INF=2000000000000000000;
const int inf=2e9 + 5;
lld pi=3.1415926535897932;
// Least common multiple of a and b.
// The division by the gcd happens before the multiplication so the
// intermediate value stays as small as possible.
int lcm(int a, int b)
{
    return a/__gcd(a, b)*b;
}
// Modular exponentiation by repeated squaring: returns a^b mod p.
//
// a: base; a == 0 short-circuits to 0 (also for b == 0, as before).
// b: exponent, expected to be non-negative.
// p: modulus, expected to be positive.
int power(int a, int b, int p)
{
    if(a==0)
        return 0;
    // Start from 1 % p rather than 1 so that the result is already reduced
    // when the loop never runs (b == 0) and p == 1.
    int res=1%p;
    a%=p;
    while(b>0)
    {
        if(b&1)
            res=(1ll*res*a)%p;
        b>>=1;
        a=(1ll*a*a)%p;
    }
    return res;
}

// Returns a uniformly distributed random integer in the closed range [l, r].
// Requires l <= r.
int getRand(int l, int r)
{
    // The engine is a function-local static: it is seeded once from the clock
    // on first use, and needs no separately declared global generator.
    static mt19937_64 rng(chrono::steady_clock::now().time_since_epoch().count());
    uniform_int_distribution<int> uid(l, r);
    return uid(rng);
}

int ar[N], st[4*N], mod, n, n2, ans[N], tim, tin[N], col[N], prev[N], co[N];
vi v[N];
// Segment-tree merge operation: product of two values modulo the global `mod`.
int combine(int a, int b)
{
    int product=a*b;
    product%=mod;
    return product;
}
// Initialises the subtree rooted at node v, which covers positions [l, r]:
// every leaf is set to 1 and every internal node to the combination of its
// two children.
void build(int v, int l, int r)
{
    if(l!=r)
    {
        int mid=(l+r)/2;
        int left_child=v*2, right_child=(v*2)+1;
        build(left_child, l, mid);
        build(right_child, mid+1, r);
        st[v]=combine(st[left_child], st[right_child]);
    }
    else
    {
        st[v]=1;
    }
}
// Point assignment: stores `val` at position `pos` inside the subtree rooted at
// node v (covering [l, r]) and recomputes the nodes on the way back up.
void update(int v, int l, int r, int pos, int val)
{
    if(l==r)
    {
        // Reached the leaf for `pos`.
        st[v]=val;
        return;
    }
    int mid=(l+r)/2;
    int left_child=v*2, right_child=(v*2)+1;
    if(pos>mid)
        update(right_child, mid+1, r, pos, val);
    else
        update(left_child, l, mid, pos, val);
    st[v]=combine(st[left_child], st[right_child]);
}
// Combined value (product via combine) of positions [l, r] inside the subtree
// rooted at node v, which covers [tl, tr]. An empty range (l > r) yields 1.
int query(int v, int tl, int tr, int l, int r)
{
    if(l>r)
        return 1;
    if(tl==l&&tr==r)
        return st[v];
    int tm=(tl+tr)/2;
    // Clip the requested range to each half and merge the two partial results.
    int left_part=query((2*v), tl, tm, l, min(tm, r));
    int right_part=query((2*v)+1, tm+1, tr, max(tm+1, l), r);
    return combine(left_part, right_part);
}

void dfs(int u, int p)
{
tin[u]=tim++;
update(1, 1, n2, prev[col[u]], 1);
update(1, 1, n2, tin[u], co[col[u]]);
prev[col[u]]=tin[u];
for(auto to:v[u])
{
if(to==p)
continue;
dfs(to, u);
}
ans[u]=query(1, 1, n2, 1, tin[u]-1);
}
int32_t main()
{
IOS;
int t;
cin>>t;
while(t--)
{
cin>>n>>mod;
n2=2*n;
rep(i,1,n+1)
{
v[i].clear();
co[i]=0;
}
rep(i,1,n+1)
{
cin>>col[i];
co[col[i]]++;
}
rep(i,0,n-1)
{
int a, b;
cin>>a>>b;
v[a].pb(b);
v[b].pb(a);
}
build(1, 1, n2);
rep(i,1,n+1)
{
prev[i]=i;
co[i]=max(co[i], 1ll);
update(1, 1, n2, i, co[i]);
}
tim=n+1;
dfs(1, 0);
rep(i,1,n+1)
cout<<ans[i]<<" ";
cout<<"\n";
}
}

Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

using ull = unsigned long long;

// Division-free modular reduction by a fixed modulus b.
// m caches floor((2^64 - 1) / b); reduce() uses the upper 64 bits of m * a as
// an estimate of a / b, so the result may exceed a % b by exactly b.
struct FastMod {
    ull b, m;
    FastMod(ull b) : b(b), m(-1ULL / b) {}
    // Returns a % b + (0 or b).
    ull reduce(ull a) {
        const ull quotient_estimate = (ull)((__uint128_t(m) * a) >> 64);
        return a - quotient_estimate * b;
    }
};

// Iterative segment tree over products, reduced through FastMod.
// Leaves live at s[n .. 2n-1]; node k has children 2k and 2k+1.
// Because FastMod::reduce may return a value that is too large by one modulus,
// stored values are only congruent to the true product; callers that need a
// canonical residue apply a final `% mod` themselves.
template<class T, T unit = T()>
struct SegTree {
    vector<T> s; int n;
    FastMod F;
    // Merge operation: product of a and b passed through the fast reducer.
    T f(T a, T b) { return F.reduce(1LL * a * b); }
    // Builds a tree with _n leaves, all nodes initialised to `def`.
    SegTree(int _n, int _mod, T def = unit) : s(2*_n, def), n(_n), F(_mod) {}
    // Point assignment at index `pos`, followed by recomputing all ancestors.
    void update(int pos, T val) {
        pos += n;
        s[pos] = val;
        while (pos /= 2) {
            s[pos] = f(s[pos * 2], s[pos * 2 + 1]);
        }
    }
    // Combined value of the half-open index range [b, e).
    T query(int b, int e) {
        T left_acc = unit, right_acc = unit;
        b += n;
        e += n;
        while (b < e) {
            if (b % 2) { left_acc = f(left_acc, s[b]); ++b; }
            if (e % 2) { --e; right_acc = f(s[e], right_acc); }
            b /= 2;
            e /= 2;
        }
        return f(left_acc, right_acc);
    }
};

// Entry point of the editorialist's solution.
// Per test case: read the tree, compute an Euler tour, then run a post-order
// pass that keeps each colour's frequency at exactly one array position and
// answers every vertex with the product of the positions outside its subtree.
int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    int t; cin >> t;
    while (t--) {
        int n, mod; cin >> n >> mod;
        vector<int> col(n), freq(n+1);
        for (int &x : col) {
            cin >> x;
            ++freq[x];
        }
        // Adjacency list over 0-indexed vertices; the input edges are 1-indexed.
        vector<vector<int>> adj(n);
        for (int i = 0; i < n-1; ++i) {
            int u, v; cin >> u >> v;
            --u; --v;
            adj[u].push_back(v);
            adj[v].push_back(u);
        }
        // Euler tour: the subtree of u occupies positions [in[u], out[u]).
        vector<int> in(n), out(n);
        int timer = 0;
        auto dfs = [&] (const auto &self, int u, int p) -> void {
            in[u] = timer++;
            for (int v : adj[u]) if (v != p) self(self, v, u);
            out[u] = timer;
        };

        dfs(dfs, 0, 0);
        // pos[c] = array position currently holding freq[c], or -1 if unset.
        vector pos(n+1, -1), ans(n, 1);
        SegTree<int, 1> seg(n, mod);
        // Place the frequency of every occurring colour at the position of its
        // first vertex (in index order); all other positions stay 1.
        for (int i = 0; i < n; ++i) {
            if (pos[col[i]] >= 0) continue;
            pos[col[i]] = in[i];
            seg.update(in[i], freq[col[i]]);
        }
        auto solve = [&] (const auto &self, int u, int p) -> void {
            // Children first, so colours of the subtree are already inside it.
            for (int v : adj[u]) {
                if (v == p) continue;
                self(self, v, u);
            }

            // Move the frequency of u's colour onto u's own position.
            seg.update(pos[col[u]], 1);
            seg.update(in[u], freq[col[u]]);
            pos[col[u]] = in[u];
            int L = in[u], R = out[u];
            // Product of everything outside [L, R); the final % mod canonicalises
            // the possibly unreduced values returned by the segment tree.
            ans[u] = (1LL * seg.query(0, L) * seg.query(R, n)) % mod;
        };
        solve(solve, 0, 0);
        for (auto x : ans) cout << x << ' ';
        cout << '\n';
    }
}