TREESAREFUN - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: satyam_343
Tester: yash_daga
Editorialist: iceknight1093

DIFFICULTY:

3060

PREREQUISITES:

Euler tour of a tree, Segment trees

PROBLEM:

You’re given a tree on N vertices. Vertex i has color C_i.
Let F_x denote the frequency of color x in the entire tree, and S_u denote the set of all colors appearing in the subtree of u.

For each vertex u, compute

\prod_{x\in S_1 \setminus S_u} F_x

modulo M.

EXPLANATION:

If M were prime, this would be a rather easy exercise in small-to-large merging.
Unfortunately, M isn’t guaranteed to be prime, so a frequency F_x need not have an inverse modulo M — this throws any approach that "divides out" the colors of a subtree out of the window.

Let’s perform an Euler tour of the tree, so that subtrees now correspond to subarrays.
Let [L_u, R_u] be the subarray corresponding to the subtree of vertex u.
This makes the set of vertices outside the subtree correspond to the union of the ranges [1, L_u-1] and [R_u+1, N].

In particular, suppose that we are able to ensure the following:

  • For every color c \in S_u, the value F_c doesn’t appear anywhere outside [L_u, R_u]
  • For every color c \in S_1 \setminus S_u, the value F_c appears exactly once outside [L_u, R_u]

Then we can simply take the product of the values on ranges [1, L_u-1] and [R_u+1, N] to obtain our answer!
Finding the product of a range can be done in \mathcal{O}(\log N) using a segment tree, so we just need to figure out how to maintain the two invariants above.

That can be achieved with a bit of cleverness when performing a DFS.
Consider an array A of length N, such that A_i = 1 for all i initially.
Next, for each color x, select exactly one vertex u such that C_u =x and set A_{L_u} = F_x.

This ensures that the frequency of each color present in the tree appears exactly once in the whole array, and every other entry is 1.
We only need to focus on moving it around appropriately.

Let’s perform a DFS. When at vertex u with color x,

  • First, recursively compute the answer for all the children of u.
  • Then, suppose \text{pos}_x is the current position of F_x in the array.
    Set A_{\text{pos}_x} = 1 and A_{L_u} = F_x, and update \text{pos}_x := L_u.
    Finally, query for the product of ranges [1, L_u-1] and [R_u+1, N]; the product of these two values is the answer for u.

This works because:

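  • First, each F_x value occurs exactly once in the array A at all times: every step only moves it from one position to another.
  • Next, F_x always sits at position L_v for some vertex v with C_v = x. This holds initially by construction, and afterwards because we only ever move it to L_u of a vertex u of color x. So if color x doesn’t appear in the subtree of u, F_x surely lies outside the range [L_u, R_u] and hence is included in the product.
  • Finally, suppose color x does appear in the subtree of u. The DFS finishes the vertices of u’s subtree consecutively, ending with u itself, so the most recent move of F_x was made at some vertex of color x inside that subtree. Hence F_x lies inside the range [L_u, R_u] and is not included in the product of the outside.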

Point updates and range queries can be handled in \mathcal{O}(\log N) using a segment tree, and so we’re done.

TIME COMPLEXITY

\mathcal{O}(N\log N) per test case.

CODE:

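Setter's code (C++)
(Note: the setter uses a slightly different placement than the one described above. Positions 1..N initially hold F_c at index c, and DFS entry times are numbered N+1..2N. When a vertex is entered, its color's frequency is moved to the vertex's entry time; once its subtree is finished, the answer is the product over positions [1, \text{tin}_u - 1]. This works because every color present in the subtree was last moved to a position inside the subtree. Every other color sits either at its placeholder or at an earlier-entered vertex.)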
// `#pragma GCC optimization(...)` is not a pragma GCC recognises. It is silently
// ignored (or reported under -Wunknown-pragmas), so the original lines had no
// effect. The intended spelling is `optimize`. "Ofast" is deliberately left out:
// it implies -ffast-math, which buys nothing for this solution's integer
// arithmetic.
#pragma GCC optimize("O3,unroll-loops")

#include <bits/stdc++.h>   
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;   
using namespace std;  
// Shorthand types, helper macros, debug printers and global constants used by
// the setter's solution. Several of these are not referenced anywhere in the
// visible setter code (noted below).
#define ll long long
// INF_MUL / INF_ADD are not referenced in the visible setter code.
const ll INF_MUL=1e13;
const ll INF_ADD=1e18;
#define pb push_back
#define mp make_pair
#define nline "\n"
// NOTE(review): `f` and `s` are object-like macros, so ANY identifier spelled
// `f` or `s` later in this file is textually replaced by `first` / `second`.
#define f first
#define s second
#define pll pair<ll,ll>
#define all(x) x.begin(),x.end()
#define vl vector<ll>
#define vvl vector<vector<ll>>
#define vvvl vector<vector<vector<ll>>>
// debug(x): when ONLINE_JUDGE is not defined, prints the expression text, a
// space, the value (via _print) and a newline to stderr; otherwise it expands
// to an empty statement.
#ifndef ONLINE_JUDGE
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);
#endif
// _print overloads: write a value to stderr for debug(). Containers are
// printed as " [ e1 e2 ... ]", pairs as "{a,b}".
void _print(ll x){cerr<<x;}
void _print(int x){cerr<<x;}
void _print(char x){cerr<<x;}
void _print(string x){cerr<<x;}
// Clock-seeded random generator; not used by the visible setter code.
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
template<class T,class V> void _print(pair<T,V> p) {cerr<<"{"; _print(p.first);cerr<<","; _print(p.second);cerr<<"}";}
template<class T>void _print(vector<T> v) {cerr<<" [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T>void _print(set<T> v) {cerr<<" [ "; for (T i:v){_print(i); cerr<<" ";}cerr<<"]";}
template<class T>void _print(multiset<T> v) {cerr<< " [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T,class V>void _print(map<T, V> v) {cerr<<" [ "; for(auto i:v) {_print(i);cerr<<" ";} cerr<<"]";}
// Policy-based order-statistics trees; not used by the visible setter code.
// NOTE(review): ordered_multiset relies on less_equal, under which erase/find
// by key do not behave like a normal set.
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
//--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
// MOD is not referenced in the visible setter code; answers are taken modulo
// the per-test `mod` read in solve().
const ll MOD=998244353;
// Size of the global arrays and of the segment tree. solve() uses tree slots up
// to 2*n, so this presumably assumes 2*n < MAX — confirm against the constraints.
const ll MAX=2002000;
// n: number of vertices of the current test; mod: modulus M of the current test.
ll n,mod;
class ST{
public:
    vector<ll> segs;
    ll size=0;                       
    ll ID=1;
 
    ST(ll sz) {
        segs.assign(2*sz,ID);
        size=sz;  
    }   
   
    ll comb(ll a,ll b) {
        a=(a*b)%mod;
        return a;  
    }
 
    void upd(ll idx, ll val) {
        segs[idx+=size]=val;
        for(idx/=2;idx;idx/=2){
            segs[idx]=comb(segs[2*idx],segs[2*idx+1]);
        }
    }
 
    ll query(ll l,ll r) {
        ll lans=ID,rans=ID;
        for(l+=size,r+=size+1;l<r;l/=2,r/=2) {
            if(l&1) {
                lans=comb(lans,segs[l++]);
            }
            if(r&1){
                rans=comb(segs[--r],rans);
            }
        }
        return comb(lans,rans);
    }
};
// Product segment tree. In solve(), slots 1..n start out holding each colour's
// frequency (slot c for colour c); dfs() hands out entry times n+1..2n and
// moves frequencies onto those slots.
ST dp(MAX);
// Next entry time to hand out in dfs (reset to n+1 in solve).
ll now;
// Adjacency lists, 1-indexed.
vector<vector<ll>> adj;
// color[u]: colour of vertex u; freq[c]: number of vertices of colour c
// (clamped to at least 1 in solve so absent colours contribute a factor 1).
vector<ll> color(MAX),freq(MAX);
// ans[u]: answer for u; tin[u]: entry time of u; last[c]: slot currently
// holding colour c's frequency.
vector<ll> ans(MAX),tin(MAX),last(MAX);
// Marks vertices already entered by dfs (reset for 1..n in solve).
vector<ll> visited(MAX,0);
// Euler-tour DFS rooted at `root` (the second parameter is unused). It keeps an
// explicit stack instead of recursing; the sequence of tree updates, queries
// and debug output is the same as for the recursive pre/post-order traversal.
//  - On entering a vertex: take the next time stamp, move its colour's
//    frequency from slot last[colour] to the new stamp, and mark it visited.
//  - On leaving it (all neighbours examined): ans = product of slots
//    [1, tin-1], i.e. everything stamped before the vertex.
void dfs(ll root,ll par){
    // Each frame: (vertex, index of the next adjacency entry to examine).
    vector<pair<ll,size_t>> frames;
    auto enter=[&](ll cur){
        tin[cur]=now++;
        dp.upd(last[color[cur]],1);
        dp.upd(tin[cur],freq[color[cur]]);
        last[color[cur]]=tin[cur];
        debug(cur);
        visited[cur]=1;
        frames.emplace_back(cur,0);
    };
    enter(root);
    while(!frames.empty()){
        ll cur=frames.back().first;
        size_t &nxt=frames.back().second;
        if(nxt<adj[cur].size()){
            ll chld=adj[cur][nxt++];
            if(visited[chld]) continue;
            debug(mp(cur,chld));
            enter(chld);   // may reallocate `frames`; `nxt` is not used again
            continue;
        }
        ans[cur]=dp.query(1,tin[cur]-1);
        frames.pop_back();
    }
}
void solve(){
    cin>>n>>mod;
    adj.clear(); adj.resize(n+5);
    for(ll i=1;i<=n;i++){
        freq[i]=0;
        visited[i]=0;
    }
    for(ll i=1;i<=n;i++){
        cin>>color[i];
        freq[color[i]]++;
    }
    for(ll i=1;i<n;i++){  
        ll u,v; cin>>u>>v;
        adj[u].push_back(v);  
        adj[v].push_back(u);  
    }
    for(ll i=1;i<=n;i++){    
        last[i]=i;  
        freq[i]=max(freq[i],1ll);
        dp.upd(i,freq[i]);   
    }  
    now=n+1;   
    dfs(1,-1);
    for(ll i=1;i<=n;i++){
        cout<<ans[i]<<" \n"[i==n];
    }
    return;     
}                                                     
// Entry point. Sets up fast I/O and, when ONLINE_JUDGE is not defined,
// redirects the standard streams to local files. Then runs solve() once per
// test case and reports the elapsed CPU time on stderr.
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
#ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    freopen("error.txt", "w", stderr);
#endif
    ll test_cases=1;
    cin>>test_cases;
    while(test_cases--)
        solve();
    cout<<fixed<<setprecision(10);
    const double elapsed_ms=1000*static_cast<double>(clock())/static_cast<double>(CLOCKS_PER_SEC);
    cerr<<"Time:"<<elapsed_ms<<"ms\n";
}
Tester's code (C++)
//clear adj and visited vector declared globally after each test case
//check for long long overflow   
//Incase of close mle change language to c++17 or c++14  
//Check ans for n=1 
#pragma GCC target ("avx2")    
#pragma GCC optimize ("O3")  
#pragma GCC optimize ("unroll-loops")
#include <bits/stdc++.h>                   
#include <ext/pb_ds/assoc_container.hpp>  
// From here on every `int` is really `long long`; this is why the entry point
// is declared as `int32_t main()`.
#define int long long
// Fast I/O setup; also sets cout's precision to max_digits10 for double.
#define IOS std::ios::sync_with_stdio(false); cin.tie(NULL);cout.tie(NULL);cout.precision(dbl::max_digits10);
#define pb push_back
#define lld long double
#define mii map<int, int>
#define pii pair<int, int>
#define ll long long
#define ff first
#define ss second
#define all(x) (x).begin(), (x).end()
#define rep(i,x,y) for(int i=x; i<y; i++)
// NOTE(review): any call spelled fill(...) in this file is captured by this
// two-argument macro (a three-argument std::fill call would fail to
// preprocess), and memset only works for byte-repeating values such as 0 / -1.
#define fill(a,b) memset(a, b, sizeof(a))
#define vi vector<int>
#define setbits(x) __builtin_popcountll(x)
#define print2d(dp,n,m) for(int i=0;i<=n;i++){for(int j=0;j<=m;j++)cout<<dp[i][j]<<" ";cout<<"\n";}
typedef std::numeric_limits< double > dbl;
using namespace __gnu_pbds;
using namespace std;
// Order-statistics set (not used by the visible tester code).
typedef tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update> indexed_set;
//member functions :
//1. order_of_key(k) : number of elements strictly less than k
//2. find_by_order(k) : iterator to the k-th (0-indexed) element in the set
// Renames the global array `prev` to `prev2` so it cannot clash with std::prev
// (visible through `using namespace std`).
#define prev prev2
// N bounds the global arrays (tree slots go up to 2*n).
const long long N=2000005, INF=2000000000000000000;
// `int` is long long here, so 2e9 + 5 fits.
const int inf=2e9 + 5;
lld pi=3.1415926535897932;
// Least common multiple of a and b. It divides before multiplying (a/g*b), so
// the product a*b is never formed. An lcm with a zero argument is 0; the guard
// also avoids the division by gcd(0, 0) == 0 that the unguarded formula hit.
int lcm(int a, int b)
{
    if(a==0 || b==0)
        return 0;
    int g=__gcd(a, b);
    return a/g*b;
}
// Modular exponentiation: a^b mod p by repeated squaring, for b >= 0 and p >= 1.
// The base is normalised into [0, p) first, so a negative base still gives a
// result in [0, p). 0^0 is treated as 1 (reduced mod p, so the result is 0 when
// p == 1) instead of the former early "return 0".
int power(int a, int b, int p)
{
    a%=p;
    if(a<0)
        a+=p;
    int res=1%p;
    while(b>0)
    {
        if(b&1)
            res=(1ll*res*a)%p;
        b>>=1;
        a=(1ll*a*a)%p;
    }
    return res;
}
// Random engine seeded from the steady clock.
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());

// Returns a uniformly distributed integer from the closed range [l, r].
int getRand(int l, int r)
{
    return uniform_int_distribution<int>(l, r)(rng);
}

// Global state of the tester's solution (every `int` is long long via the macro):
//   st   - recursive segment tree over slots 1..n2 holding products mod `mod`
//   n2   - 2*n slots: 1..n are colour placeholders, n+1..2n are entry times
//   tim  - next entry time handed out by dfs (starts at n+1)
//   tin  - entry time of each vertex; ans - answer of each vertex
//   col  - colour of each vertex; co - frequency of each colour (clamped to >= 1)
//   prev - slot currently holding each colour's frequency
//   ar   - NOTE(review): not referenced in the visible tester code
int ar[N], st[4*N], mod, n, n2, ans[N], tim, tin[N], col[N], prev[N], co[N];
// Adjacency lists, 1-indexed.
vi v[N];
// Multiplies two segment-tree values and reduces modulo the global `mod`.
int combine(int a, int b)
{
    const int product=a*b;
    return product%mod;
}
// Initialises node v, covering [l, r], and all of its descendants: every leaf
// becomes the multiplicative identity 1.
void build(int v, int l, int r)
{
    if(l!=r)
    {
        const int m=(l+r)/2;
        build(2*v, l, m);
        build(2*v+1, m+1, r);
        st[v]=combine(st[2*v], st[2*v+1]);
    }
    else
        st[v]=1;
}
// Point assignment: sets slot pos to val below node v (which covers [l, r]),
// then recomputes the products on the way back up.
void update(int v, int l, int r, int pos, int val)
{
    if(l!=r)
    {
        const int m=(l+r)/2;
        if(pos>m)
            update(2*v+1, m+1, r, pos, val);
        else
            update(2*v, l, m, pos, val);
        st[v]=combine(st[2*v], st[2*v+1]);
    }
    else
        st[v]=val;
}
// Product over slots [l, r] inside node v (which covers [tl, tr]).
// An empty range (l > r) yields 1.
int query(int v, int tl, int tr, int l, int r)
{
    if(l>r)
        return 1;
    if(l==tl && r==tr)
        return st[v];
    const int tm=(tl+tr)/2;
    const int leftPart=query(2*v, tl, tm, l, min(tm, r));
    const int rightPart=query(2*v+1, tm+1, tr, max(tm+1, l), r);
    return combine(leftPart, rightPart);
}

void dfs(int u, int p)
{
    tin[u]=tim++;
    update(1, 1, n2, prev[col[u]], 1);
    update(1, 1, n2, tin[u], co[col[u]]);
    prev[col[u]]=tin[u];
    for(auto to:v[u])
    {
        if(to==p)
            continue;
        dfs(to, u);
    }
    ans[u]=query(1, 1, n2, 1, tin[u]-1);
}
int32_t main()
{
    IOS;
    int t;
    cin>>t;
    while(t--)
    {
        cin>>n>>mod;
        n2=2*n;
        rep(i,1,n+1)
        {
            v[i].clear();
            co[i]=0;
        }
        rep(i,1,n+1)
        {
            cin>>col[i];
            co[col[i]]++;
        }
        rep(i,0,n-1)
        {
            int a, b;
            cin>>a>>b;
            v[a].pb(b);
            v[b].pb(a);
        }
        build(1, 1, n2);
        rep(i,1,n+1)
        {
            prev[i]=i;
            co[i]=max(co[i], 1ll);
            update(1, 1, n2, i, co[i]);
        }
        tim=n+1;
        dfs(1, 0);
        rep(i,1,n+1)
        cout<<ans[i]<<" ";
        cout<<"\n";
    }
}
Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

typedef unsigned long long ull;
// Barrett-style reduction by a runtime modulus b, with m = floor((2^64-1)/b).
// reduce(a) returns a value congruent to a modulo b that is NOT always fully
// reduced: the result is a % b or a % b + b.
struct FastMod {
	ull b, m;
	FastMod(ull b) : b(b), m(-1ULL / b) {}
	ull reduce(ull a) { // a % b + (0 or b)
		return a - (ull)((__uint128_t(m) * a) >> 64) * b;
	}
};

// Iterative segment tree of products modulo `_mod` over positions [0, n).
// Leaf i is stored at s[n + i]; `unit` is the neutral element (1 for products).
template<class T, T unit = T()>
struct SegTree {
	vector<T> s; int n;
	FastMod F;
	// Product of a and b, fully reduced into [0, mod).
	// FastMod::reduce alone may return up to 2*mod - 1, which no longer fits in
	// T = int once mod exceeds INT_MAX / 2. That value would wrap negative and
	// corrupt every later product. One conditional subtraction makes the result
	// canonical.
	T f(T a, T b) {
		ull r = F.reduce(1LL * a * b);
		if (r >= F.b) r -= F.b;
		return (T)r;
	}
	SegTree(int _n, int _mod, T def = unit) : s(2*_n, def), n(_n), F(_mod) {}
	// Sets position pos to val and recomputes all of its ancestors.
	void update(int pos, T val) {
		for (s[pos += n] = val; pos /= 2;)
			s[pos] = f(s[pos * 2], s[pos * 2 + 1]);
	}
	// Product over the half-open range [b, e); returns f(unit, unit) if b >= e.
	T query(int b, int e) {
		T ra = unit, rb = unit;
		for (b += n, e += n; b < e; b /= 2, e /= 2) {
			if (b % 2) ra = f(ra, s[b++]);
			if (e % 2) rb = f(s[--e], rb);
		}
		return f(ra, rb);
	}
};

int main()
{
	ios::sync_with_stdio(false); cin.tie(0);

	int t; cin >> t;
	while (t--) {
		int n, mod; cin >> n >> mod;
		vector<int> col(n), freq(n+1);
		for (int &x : col) {
		    cin >> x;
		    ++freq[x];
		}
		vector adj(n, vector<int>());
		for (int i = 0; i < n-1; ++i) {
			int u, v; cin >> u >> v;
			adj[--u].push_back(--v);
			adj[v].push_back(u);
		}
		vector<int> in(n), out(n);
		int timer = 0;
		auto dfs = [&] (const auto &self, int u, int p) -> void {
			in[u] = timer++;
			for (int v : adj[u]) if (v != p) self(self, v, u);
			out[u] = timer;
		};

		dfs(dfs, 0, 0);
		vector pos(n+1, -1), ans(n, 1);
		SegTree<int, 1> seg(n, mod);
		for (int i = 0; i < n; ++i) {
		    if (pos[col[i]] >= 0) continue;
		    pos[col[i]] = in[i];
		    seg.update(in[i], freq[col[i]]);
		}
		auto solve = [&] (const auto &self, int u, int p) -> void {
			for (int v : adj[u]) {
				if (v == p) continue;
				self(self, v, u);
			}
			
			seg.update(pos[col[u]], 1);
			seg.update(in[u], freq[col[u]]);
			pos[col[u]] = in[u];
			int L = in[u], R = out[u];
			ans[u] = (1LL * seg.query(0, L) * seg.query(R, n)) % mod;
		};
		solve(solve, 0, 0);
		for (auto x : ans) cout << x << ' ';
		cout << '\n';
	}
}