FLGZRO - Editorial


Contest Division 1
Contest Division 2
Contest Division 3

Setter: Manan Grover
Tester: Istvan Nagy
Editorialist: Taranpreet Singh




Prerequisites: Sack on Tree (DSU on Tree), Basic Combinatorics.


You are given a tree with N nodes, rooted at node 1, where each node is assigned a non-zero value; the values are given by an array A of length N.

You choose a node (say node u) and change its value to 0. Then, as long as the node u holding value 0 is not a leaf, a node v \neq u in the subtree of u is chosen, and A_u and A_v are swapped. The process stops when the node with value 0 is a leaf. Let's call the resulting tree a final tree.

At the end, exactly one node has the value 0, and it is a leaf.

Determine the number of distinct final trees, modulo 10^9+7. Two final trees are considered distinct if some node u has different values in them.
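For tiny trees, the process can be simulated exhaustively, which is a convenient way to check the counting arguments below. The following C++ sketch is not taken from any of the official solutions; the names `BruteForce` and `bruteForceCount` and the input convention (a 1-indexed edge list, `a[i]` holding the value of node i+1, root at node 1) are assumptions for illustration. Its running time is exponential, so it is only meant for a handful of nodes.

```cpp
#include <cassert>
#include <set>
#include <utility>
#include <vector>
using namespace std;

// Exhaustive simulation for tiny trees (hypothetical helper, not from the
// official solutions): try every initial node and every sequence of swaps,
// and collect the distinct final value arrays.
struct BruteForce {
    vector<vector<int>> children;   // children lists of the tree rooted at index 0
    set<vector<long long>> finals;  // distinct final trees

    // Append every node strictly below u to out.
    void collectBelow(int u, vector<int>& out) const {
        for (int c : children[u]) { out.push_back(c); collectBelow(c, out); }
    }
    // Node z currently holds 0; keep swapping it downwards until it is a leaf.
    void run(int z, vector<long long>& vals) {
        if (children[z].empty()) { finals.insert(vals); return; }
        vector<int> below;
        collectBelow(z, below);
        for (int v : below) {
            swap(vals[z], vals[v]);  // 0 moves to v, z receives v's value
            run(v, vals);
            swap(vals[z], vals[v]);  // undo before trying the next v
        }
    }
};

// Edges are 1-indexed; a[i] is the value of node i+1; the root is node 1.
long long bruteForceCount(int n, const vector<pair<int, int>>& edges,
                          const vector<long long>& a) {
    assert((int)a.size() == n && (int)edges.size() == n - 1);
    vector<vector<int>> adj(n);
    for (auto [u, v] : edges) { adj[u - 1].push_back(v - 1); adj[v - 1].push_back(u - 1); }
    BruteForce bf;
    bf.children.resize(n);
    vector<bool> seen(n, false);
    vector<int> st = {0};
    seen[0] = true;
    while (!st.empty()) {  // orient the edges away from node 1 (index 0)
        int u = st.back(); st.pop_back();
        for (int v : adj[u])
            if (!seen[v]) { seen[v] = true; bf.children[u].push_back(v); st.push_back(v); }
    }
    for (int u = 0; u < n; ++u) {
        vector<long long> vals(a);
        vals[u] = 0;  // the chosen node's value becomes 0
        bf.run(u, vals);
    }
    return (long long)bf.finals.size();
}
```

For instance, `bruteForceCount(3, {{1, 2}, {1, 3}}, {1, 2, 3})` returns 4 and `bruteForceCount(3, {{1, 2}, {1, 3}}, {2, 2, 3})` returns 3, matching the two examples discussed in the explanation.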


  • If node u is chosen, it can be swapped with any node v in its subtree having a different value, and each such choice leads to trees not counted elsewhere. So \displaystyle ways_u = \sum_{v \in S, A_u \neq A_v} ways_v, where ways_u denotes the number of final trees in which u is the topmost node whose value changes, and S denotes the set of nodes in the subtree of u excluding u. Leaf nodes have ways_u = 1, and the answer is \sum_u ways_u.
  • To compute these sums quickly, we can use either sack on tree or a stack for each distinct value.


Simplified Problem

Let's first assume that all values are distinct. What is the number of distinct final trees? We can group the final trees by the node chosen initially, i.e., the node whose value is first replaced with 0. Let ways_u denote the number of final trees obtained when node u is chosen. If u is a leaf, no swap happens, so ways_u = 1.

Now consider a non-leaf node u. Since all values are distinct, u can be swapped with any node v \neq u in its subtree. After this swap, node v holds 0 and the process continues exactly as if v had been chosen initially (only node u now holds A_v instead of A_u). So exactly ways_v distinct trees arise when node u is chosen and then swapped with node v.

Different choices of v put different values at node u (all values are distinct), so these groups of trees are disjoint, and by the addition rule of counting we can simply add them up.

So, we can write \displaystyle ways_u = \sum_{v \in S} ways_v, where S is the set of nodes in the subtree of node u, excluding u itself.

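This sum is just the total of ways over the subtrees of u's children, so all values can be computed in a single post-order traversal, and this simplified problem is solved in O(N).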
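A minimal C++ sketch of this O(N) computation, assuming all values are distinct (so the values themselves are not needed). The function name `countDistinctValues` and the 1-indexed edge-list input are hypothetical. Nodes are processed in reverse BFS order, so every child is finished before its parent, and the sum of ways over each subtree is accumulated along the way.

```cpp
#include <cassert>
#include <utility>
#include <vector>
using namespace std;

const long long MOD = 1'000'000'007LL;

// Simplified problem (all values distinct):
//   ways_u = 1 for a leaf, otherwise the sum of ways_v over the strict subtree of u.
// Returns the sum of ways_u over all nodes. Edges are 1-indexed; root = node 1.
long long countDistinctValues(int n, const vector<pair<int, int>>& edges) {
    assert(n >= 1 && (int)edges.size() == n - 1);
    vector<vector<int>> adj(n);
    for (auto [u, v] : edges) { adj[u - 1].push_back(v - 1); adj[v - 1].push_back(u - 1); }
    // BFS order: every parent appears before its children.
    vector<int> order = {0}, parent(n, -1);
    for (size_t i = 0; i < order.size(); ++i)
        for (int v : adj[order[i]])
            if (v != parent[order[i]]) { parent[v] = order[i]; order.push_back(v); }

    vector<long long> childSum(n, 0);    // sum of subtreeSum over the children
    vector<long long> subtreeSum(n, 0);  // sum of ways over the subtree, u included
    for (int i = n - 1; i >= 0; --i) {   // reverse BFS: children before parents
        int u = order[i];
        bool leaf = (u == 0) ? adj[u].empty() : adj[u].size() == 1;
        long long ways = leaf ? 1 : childSum[u];  // sum over the strict subtree
        subtreeSum[u] = (ways + childSum[u]) % MOD;
        if (parent[u] != -1)
            childSum[parent[u]] = (childSum[parent[u]] + subtreeSum[u]) % MOD;
    }
    return subtreeSum[0];
}
```

On the three-node example below it returns 4, and on a path of 4 nodes (1-2-3-4) it returns 8.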

For example, consider the following test case (the two edges, then the values A_1, A_2, A_3):

1 2
1 3
1 2 3

Nodes 2 and 3 are leaf nodes, so ways_2 = ways_3 = 1. For node 1, we have ways_1 = ways_2 + ways_3 = 2.

Hence, the total number of distinct trees is \displaystyle \sum_{u} ways_u = 4. We can check that there are exactly 4 trees, whose node values are 103, 120, 203 and 320 (the string 103 means that node 1 has value 1, node 2 has value 0 and node 3 has value 3).

Original problem

If we apply the same solution here, some trees are counted more than once. Let's figure out which ones.

Consider the same example, but with the value of node 1 changed to 2:

1 2
1 3
2 2 3

If we proceed as in the previous solution, the trees we count are 203, 220, 203 and 320. Note that 203 appears twice: this tree is counted once when node 2 is chosen, and again when node 1 is chosen and immediately swapped with node 2.

This happened because, after node 1 was chosen, it was immediately swapped with a node having the same value.


In general, consider two nodes u and v such that v is in the subtree of u and A_u = A_v. Choosing node u and then swapping it with node v produces exactly the same trees as choosing node v in the first operation.

This is because after the swap, node u again holds its original value and node v holds 0, which is exactly the state right after choosing node v.

Hence, let's add a constraint: after a non-leaf node u is chosen, it may be swapped with a node v \neq u in its subtree only if A_u \neq A_v. No tree is lost, since the trees produced by a swap with an equal-valued v are already counted when v itself is chosen. This way, the value of node u is guaranteed to change, and we can group all distinct trees by the topmost node whose value changes.

Let's redefine ways_u accordingly: ways_u is now the number of distinct trees in which node u is the topmost node whose value changes. As before, ways_u = 1 for leaf nodes.

For non-leaf nodes, \displaystyle ways_u = \sum_{v \in S, A_u \neq A_v} ways_v.

What remains is to compute ways_u efficiently; evaluating the sum directly takes O(N^2) time in the worst case (for example, on a path).
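For reference, the recurrence can be evaluated directly. In this C++ sketch (the name `countNaive` and the input convention are assumptions, as in the other sketches), an iterative preorder makes every subtree a contiguous range, and each node scans its own range. It runs in O(N^2) in the worst case and is mainly useful for cross-checking faster solutions on small inputs.

```cpp
#include <cassert>
#include <utility>
#include <vector>
using namespace std;

const long long MOD = 1'000'000'007LL;

// Direct evaluation of
//   ways_u = 1                                  if u is a leaf
//   ways_u = sum_{v in S, A_v != A_u} ways_v    otherwise,
// returning sum_u ways_u. Edges are 1-indexed; a[i] is the value of node i+1.
long long countNaive(int n, const vector<pair<int, int>>& edges,
                     const vector<long long>& a) {
    assert((int)a.size() == n && (int)edges.size() == n - 1);
    vector<vector<int>> adj(n);
    for (auto [u, v] : edges) { adj[u - 1].push_back(v - 1); adj[v - 1].push_back(u - 1); }
    // Iterative preorder: the subtree of u is pre[tin[u] .. tin[u] + sz[u] - 1].
    vector<int> pre, tin(n), sz(n, 1), parent(n, -1), st = {0};
    while (!st.empty()) {
        int u = st.back(); st.pop_back();
        tin[u] = (int)pre.size();
        pre.push_back(u);
        for (int v : adj[u])
            if (v != parent[u]) { parent[v] = u; st.push_back(v); }
    }
    for (int i = n - 1; i > 0; --i) sz[parent[pre[i]]] += sz[pre[i]];

    vector<long long> ways(n, 0);
    long long answer = 0;
    for (int i = n - 1; i >= 0; --i) {  // descendants are handled before ancestors
        int u = pre[i];
        if (sz[u] == 1) {
            ways[u] = 1;  // leaf
        } else {
            for (int j = tin[u] + 1; j < tin[u] + sz[u]; ++j)
                if (a[pre[j]] != a[u]) ways[u] = (ways[u] + ways[pre[j]]) % MOD;
        }
        answer = (answer + ways[u]) % MOD;
    }
    return answer;
}
```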

Approach 1

Let's use sack on tree (also known as DSU on tree). Maintaining some aggregate over the subtree of every node u is one of the standard applications of this technique. Here, we maintain \displaystyle F_x = \sum_{v \in S, A_v = x} ways_v, that is, the sum of ways_v over the nodes v with value x, together with \displaystyle T = \sum_x F_x.

This way, ways_u = T - F_{A_u} can be computed in O(1) (or O(log N) if F is stored in a map).

Whenever we add or remove a node v, we update both F_{A_v} and T.

We can compress the values in A so that an array can be used to store F, or use a map instead. Since sack on tree adds and removes each node O(log N) times, this approach runs in O(N log N) time. See the editorialist's solution for this approach.
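The following C++ sketch illustrates Approach 1. It was written for this explanation and is not the editorialist's Java code; `Sack`, `countSack`, the input convention, and the choice to add or remove a light subtree by scanning its preorder range are assumptions. Values are coordinate-compressed so that F can be a plain array.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>
using namespace std;

// Sketch of Approach 1 (sack / DSU on tree). While ways[u] is computed, the
// sack holds exactly the strict subtree of u: F[x] is the sum of ways over the
// sack's nodes with compressed value x, and T is the sum of all F[x].
struct Sack {
    static constexpr long long MOD = 1'000'000'007LL;
    vector<vector<int>> children;
    vector<int> col, sz, tin, pre;  // compressed values, subtree sizes, preorder
    vector<long long> ways, F;
    long long T = 0;

    void prepare(int u) {  // subtree sizes and preorder positions
        tin[u] = (int)pre.size();
        pre.push_back(u);
        for (int c : children[u]) { prepare(c); sz[u] += sz[c]; }
    }
    void update(int v, long long sign) {  // put node v into (+1) or out of (-1) the sack
        long long d = ((sign * ways[v]) % MOD + MOD) % MOD;
        F[col[v]] = (F[col[v]] + d) % MOD;
        T = (T + d) % MOD;
    }
    void updateSubtree(int u, long long sign) {  // the subtree of u is a preorder range
        for (int j = tin[u]; j < tin[u] + sz[u]; ++j) update(pre[j], sign);
    }
    void dfs(int u, bool keep) {
        int heavy = -1;
        for (int c : children[u])
            if (heavy == -1 || sz[c] > sz[heavy]) heavy = c;
        for (int c : children[u])
            if (c != heavy) dfs(c, false);  // light children: computed, then cleared
        if (heavy != -1) dfs(heavy, true);  // the heavy child's nodes stay in the sack
        for (int c : children[u])
            if (c != heavy) updateSubtree(c, +1);
        ways[u] = (heavy == -1) ? 1 : (T - F[col[u]] + MOD) % MOD;
        update(u, +1);
        if (!keep) updateSubtree(u, -1);
    }
};

// Edges are 1-indexed; a[i] is the value of node i+1; the root is node 1.
long long countSack(int n, const vector<pair<int, int>>& edges,
                    const vector<long long>& a) {
    assert((int)a.size() == n && (int)edges.size() == n - 1);
    vector<vector<int>> adj(n);
    for (auto [u, v] : edges) { adj[u - 1].push_back(v - 1); adj[v - 1].push_back(u - 1); }
    Sack s;
    s.children.resize(n);
    vector<int> order = {0}, parent(n, -1);  // BFS from node 1 to orient the edges
    for (size_t i = 0; i < order.size(); ++i)
        for (int v : adj[order[i]])
            if (v != parent[order[i]]) {
                parent[v] = order[i];
                s.children[order[i]].push_back(v);
                order.push_back(v);
            }
    vector<long long> vals(a);  // coordinate compression
    sort(vals.begin(), vals.end());
    vals.erase(unique(vals.begin(), vals.end()), vals.end());
    s.col.resize(n);
    for (int i = 0; i < n; ++i)
        s.col[i] = int(lower_bound(vals.begin(), vals.end(), a[i]) - vals.begin());
    s.sz.assign(n, 1);
    s.tin.assign(n, 0);
    s.ways.assign(n, 0);
    s.F.assign(vals.size(), 0);
    s.prepare(0);
    s.dfs(0, true);
    long long answer = 0;
    for (long long w : s.ways) answer = (answer + w) % Sack::MOD;
    return answer;
}
```

Recursion depth equals the tree height here, so a path with 10^5 nodes may need a larger stack than the default.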

Approach 2

In this approach, we maintain a stack for each distinct value; whenever we exit a node u (after its whole subtree has been processed), we push u onto the stack for value A_u.

When we exit node u, the top of stack A_u holds exactly the nodes v in the subtree of u with A_v = A_u such that no other node on the path from u to v has that value. Below them, the stack may also contain nodes that were exited before we entered u.

Hence, when computing ways_u, we pop every node at the top of stack A_u that lies in the subtree of u (an O(1) check with entry/exit times) and subtract its contribution from the sum of ways over the strict subtree of u. For the subtraction to also cover deeper nodes with value A_u, the contribution of a popped node v is ways_v plus everything that was popped when v itself was processed.

Each node is pushed and popped at most once, so apart from looking up the stack of each value (a map, a hash map, or prior value compression), the total work is O(N).

Refer to the setter’s solution for this approach.
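Below is a minimal C++ sketch of the stack-per-value idea, written independently of the setter's code. Assumptions: the name `countStacks`, the 1-indexed edge-list input, `std::unordered_map` for the per-value stacks, and an iterative traversal. Nodes are exited in reverse preorder, which is a valid post-order if children are visited in mirrored order; at the moment u is exited, the nodes of its subtree are the most recently pushed ones, so they sit on top of each stack.

```cpp
#include <cassert>
#include <unordered_map>
#include <utility>
#include <vector>
using namespace std;

const long long MOD = 1'000'000'007LL;

// Sketch of Approach 2. contrib[v] = ways_v + (what was popped at v), which
// equals the sum of ways_w over all w in the subtree of v with A_w = A_v.
long long countStacks(int n, const vector<pair<int, int>>& edges,
                      const vector<long long>& a) {
    assert((int)a.size() == n && (int)edges.size() == n - 1);
    vector<vector<int>> adj(n);
    for (auto [u, v] : edges) { adj[u - 1].push_back(v - 1); adj[v - 1].push_back(u - 1); }
    // Iterative preorder; the subtree of u occupies positions [tin[u], tin[u] + sz[u]).
    vector<int> pre, tin(n), sz(n, 1), parent(n, -1), st = {0};
    while (!st.empty()) {
        int u = st.back(); st.pop_back();
        tin[u] = (int)pre.size();
        pre.push_back(u);
        for (int v : adj[u])
            if (v != parent[u]) { parent[v] = u; st.push_back(v); }
    }
    for (int i = n - 1; i > 0; --i) sz[parent[pre[i]]] += sz[pre[i]];

    vector<long long> ways(n), below(n, 0), contrib(n);  // below[u]: ways summed over the strict subtree
    unordered_map<long long, vector<int>> stacks;        // value -> stack of exited nodes
    long long answer = 0;
    for (int i = n - 1; i >= 0; --i) {  // "exit" nodes in reverse preorder
        int u = pre[i];
        vector<int>& s = stacks[a[u]];
        long long removed = 0;
        // Pop the same-valued nodes of u's subtree; they are on top of the stack.
        while (!s.empty() && tin[s.back()] > tin[u] && tin[s.back()] < tin[u] + sz[u]) {
            removed = (removed + contrib[s.back()]) % MOD;
            s.pop_back();
        }
        ways[u] = (sz[u] == 1) ? 1 : (below[u] - removed + MOD) % MOD;
        contrib[u] = (ways[u] + removed) % MOD;
        s.push_back(u);
        if (parent[u] != -1)
            below[parent[u]] = (below[parent[u]] + below[u] + ways[u]) % MOD;
        answer = (answer + ways[u]) % MOD;
    }
    return answer;
}
```

On the path 1-2-3-4 with values 5, 5, 5, 7, node 2's contribution (1) is popped at node 2 and folded into node 2's contribution (2), which node 1 then subtracts in a single step; the result is 4, as the exhaustive simulation also gives.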


The time complexity per test case is O(N log N) for Approach 1 and O(N) for Approach 2 (excluding the lookup of each value's stack).


Setter's Solution
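#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define md 1000000007
// Node 0 is a virtual root with value 0 placed above node 1.
// mpp[val] is the stack of nodes with value val on the current root path (0 is a sentinel).
// dp[x]: sum of ways over the strict subtree of x (dp[leaf] = 1).
// cor[x]: sum of ways_v over nodes v below x with the same value as x.
void dfs(ll x, ll pr, vector<ll> tr[], ll dp[], ll cor[], map<ll, vector<ll>> &mpp, ll a[]){
  auto it = mpp.find(a[x]);
  (*it).second.push_back(x);
  for(ll i = 0; i < (ll)tr[x].size(); i++){
    ll y = tr[x][i];
    if(y == pr){
      continue;
    }
    dfs(y, x, tr, dp, cor, mpp, a);
    // contributes ways_y + dp[y]; for a non-leaf y, ways_y = dp[y] - cor[y]
    dp[x] += dp[y];
    if((ll)tr[y].size() > 1){
      dp[x] += dp[y];
    }
    dp[x] -= cor[y];
    dp[x] %= md;
  }
  if((ll)tr[x].size() == 1 && x != 0){
    dp[x] = 1;
  }
  (*it).second.pop_back();
  // report to the nearest ancestor with the same value
  cor[(*it).second.back()] += dp[x];
  cor[(*it).second.back()] %= md;
}
int main(){
  ios::sync_with_stdio(false);
  cin.tie(nullptr);
  ll t;
  cin >> t;
  while(t--){
    ll n;
    cin >> n;
    // std::vector instead of variable-length arrays; .data() keeps the dfs signature
    vector<ll> dp(n+1, 0), cor(n+1, 0), a(n+1, 0);
    vector<vector<ll>> tr(n+1);
    for(ll i = 0; i<n-1; i++){
      ll u,v;
      cin >> u >> v;
      tr[u].push_back(v);
      tr[v].push_back(u);
    }
    tr[0].push_back(1);
    tr[1].push_back(0);
    map<ll, vector<ll>> mpp;
    vector<ll> temp(1, 0);
    for(ll i = 0; i < n + 1; i++){
      if(i == 0){
        a[i] = 0;
      }else{
        cin >> a[i];
      }
      mpp.insert(make_pair(a[i], temp));
    }
    dfs(0, -1, tr.data(), dp.data(), cor.data(), mpp, a.data());
    ll ans = dp[0];
    ans %= md;
    ans += md;
    ans %= md;
    cout << ans << "\n";
  }
  return 0;
}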
Tester's Solution
#include <iostream>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>
#include <set>
#include <map>
#include <algorithm>
#include <random>

#ifdef HOME
    #include <windows.h>
#endif

#define all(x) (x).begin(), (x).end()
#define rall(x) (x).rbegin(), (x).rend()
#define forn(i, n) for (int i = 0; i < (int)(n); ++i)
#define for1(i, n) for (int i = 1; i <= (int)(n); ++i)
#define ford(i, n) for (int i = (int)(n) - 1; i >= 0; --i)
#define fore(i, a, b) for (int i = (int)(a); i <= (int)(b); ++i)

template<class T> bool umin(T &a, T b) { return a > b ? (a = b, true) : false; }
template<class T> bool umax(T &a, T b) { return a < b ? (a = b, true) : false; }

using namespace std;

long long readInt(long long l, long long r, char endd) {
    long long x = 0;
    int cnt = 0;
    int fi = -1;
    bool is_neg = false;
    while (true) {
        char g = getchar();
        if (g == '-') {
            assert(fi == -1);
            is_neg = true;
            continue;
        }
        if ('0' <= g && g <= '9') {
            x *= 10;
            x += g - '0';
            if (cnt == 0) {
                fi = g - '0';
            }
            cnt++;
            assert(fi != 0 || cnt == 1);
            assert(fi != 0 || is_neg == false);

            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
        }
        else if (g == endd) {
            assert(cnt > 0);
            if (is_neg) {
                x = -x;
            }
            assert(l <= x && x <= r);
            return x;
        }
        else {
            assert(false);
        }
    }
}

string readString(int l, int r, char endd) {
    string ret = "";
    int cnt = 0;
    while (true) {
        char g = getchar();
        assert(g != -1);
        if (g == endd) {
            break;
        }
        cnt++;
        ret += g;
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}
long long readIntSp(long long l, long long r) {
    return readInt(l, r, ' ');
}
long long readIntLn(long long l, long long r) {
    return readInt(l, r, '\n');
}
string readStringLn(int l, int r) {
    return readString(l, r, '\n');
}
string readStringSp(int l, int r) {
    return readString(l, r, ' ');
}

int main(int argc, char** argv) 
{
#ifdef HOME
    freopen("../FLGZRO_0.in", "rb", stdin);
    freopen("../out.txt", "wb", stdout);
#endif
    int T = readIntLn(1, 100'000);
    int sumN = 0;

    forn(tc, T)
    {
        int N = readIntLn(1, 100'000);
        sumN += N;
        vector<pair<int, int >> edges(N - 1);
        vector<vector<int> > neighb(N);
        for (auto& e : edges)
        {
            e.first = readIntSp(1, N);
            e.second = readIntLn(1, N);
            --e.first; --e.second;
            neighb[e.first].push_back(e.second);
            neighb[e.second].push_back(e.first);
        }
        vector<int> v(N);
        int cc = 0;
        for (auto& vi : v)
        {
            if (++cc == N)
                vi = readIntLn(1, 1'000'000'000);
            else
                vi = readIntSp(1, 1'000'000'000);
        }
        // BFS from node 1 (index 0) to find parents and children
        vector<vector<int>> childs(N);
        vector<int> parent(N, -1);

        vector<int> q(1);
        vector<bool> used(N);
        used[0] = true;
        forn(i, N)
        {
            int actN = q[i];
            for (auto ne : neighb[actN])
            {
                if (!used[ne])
                {
                    used[ne] = true;
                    parent[ne] = actN;
                    childs[actN].push_back(ne);
                    q.push_back(ne);
                }
            }
        }

        // dp[x]: sum of ways over the subtree of x (x included)
        // ex[x][val]: sum of ways over nodes in the subtree of x having value val
        int64_t MOD = 1'000'000'007;
        vector<int64_t> dp(N);
        vector<map<int, int64_t> > ex(N);
        reverse(q.begin(), q.end());
        for (auto qi : q)
        {
            int curV = v[qi];
            if (childs[qi].empty())
            {
                dp[qi] = 1;
                ex[qi][curV] = 1;
                continue;
            }
            int largestChildVal = 0, largestChildSize = 0;
            int64_t newV = 0, sumEx = 0;
            for (auto ci : childs[qi])
            {
                newV += dp[ci];
                if (ex[ci].count(curV))
                    sumEx += ex[ci][curV];
                if (largestChildSize < (int)ex[ci].size())
                {
                    largestChildSize = ex[ci].size();
                    largestChildVal = ci;
                }
            }

            newV %= MOD;
            sumEx %= MOD;
            int64_t newEx = newV - sumEx;  // ways of qi
            if (newEx < 0)
                newEx += MOD;

            newV += newV;
            newV += MOD - sumEx;
            newV %= MOD;

            dp[qi] = newV;

            // small-to-large: reuse the largest child's map, merge the others into it
            ex[qi].swap(ex[largestChildVal]);
            ex[qi][curV] += newEx;
            ex[qi][curV] %= MOD;

            for (auto ci : childs[qi])
            {
                if (ci == largestChildVal)
                    continue;
                for (auto mi : ex[ci])
                {
                    ex[qi][mi.first] += mi.second;
                    ex[qi][mi.first] %= MOD;
                }
            }
        }
        printf("%lld\n", (long long)dp[0]);
    }
    assert(sumN <= 1'000'000);
    return 0;
}
Editorialist's Solution
import java.util.*;
import java.io.*;
class Main{
    final long MOD = (long)1e9+7;
    long[] ways, valueMap;
    int[] sub, col;
    int[][] tree;
    long total = 0;
    void pre() throws Exception{}
    void solve(int TC) throws Exception{
        int N = ni();
        int[] from = new int[N-1], to = new int[N-1];
        for(int i = 0; i< N-1; i++){
            from[i] = ni()-1;
            to[i] = ni()-1;
        }
        col = new int[N];
        int[] tmp = new int[N];
        for(int i = 0; i< N; i++)tmp[i] = col[i] = ni();
        //Coordinate compression: sort, deduplicate, then map each value to its rank
        Arrays.sort(tmp);
        int C = 1;
        for(int i = 1; i< N; i++)if(tmp[i] != tmp[C-1])tmp[C++] = tmp[i];
        for(int i = 0; i< N; i++)col[i] = Arrays.binarySearch(tmp, 0, C, col[i]);
        valueMap = new long[C];
        tree = make(N, N-1, from, to, true);
        sub = new int[N];
        sub(0, -1);
        total = 0;
        ways = new long[N];
        dfs(0, -1, true);
        long ans = 0;
        for(int i = 0; i< N; i++){
            ans += ways[i];
            if(ans >= MOD)ans -= MOD;
        }
        pn(ans);
    }
    //Computes subtree sizes, used to pick the heavy child
    void sub(int u, int p){
        sub[u] = 1;
        for(int v:tree[u]){
            if(v == p)continue;
            sub(v, u);
            sub[u] += sub[v];
        }
    }
    //Sack on tree: light children are processed and cleared, the heavy child's data is kept,
    //then light subtrees are added back, so total/valueMap describe the strict subtree of u
    void dfs(int u, int p, boolean keep){
        int hc = -1;
        for(int v:tree[u]){
            if(v == p)continue;
            if(hc == -1 || sub[v] > sub[hc])
                hc = v;
        }
        for(int v:tree[u]){
            if(v == hc || v == p)continue;
            dfs(v, u, false);
        }
        if(hc != -1)dfs(hc, u, true);
        for(int v:tree[u]){
            if(v != hc && v != p)
                add(v, u, 1);
        }
        if(hc == -1)
            ways[u] = 1;
        else
            ways[u] = (total+MOD-valueMap[col[u]])%MOD;
        total += ways[u];
        if(total >= MOD)total -= MOD;
        valueMap[col[u]] += ways[u];
        if(valueMap[col[u]] >= MOD)valueMap[col[u]] -= MOD;
        if(!keep)add(u, p, -1);
    }
    //Adds (mul = 1) or removes (mul = -1) every node in the subtree of u
    void add(int u, int p, long mul){
        valueMap[col[u]] = (valueMap[col[u]]+MOD+ways[u]*mul)%MOD;
        total = (total+MOD+mul*ways[u])%MOD;
        for(int v:tree[u])
            if(v != p)
                add(v, u, mul);
    }
    int[][] make(int n, int e, int[] from, int[] to, boolean f){
        int[][] g = new int[n][];int[]cnt = new int[n];
        for(int i = 0; i< e; i++){
            cnt[from[i]]++;
            if(f)cnt[to[i]]++;
        }
        for(int i = 0; i< n; i++)g[i] = new int[cnt[i]];
        for(int i = 0; i< e; i++){
            g[from[i]][--cnt[from[i]]] = to[i];
            if(f)g[to[i]][--cnt[to[i]]] = from[i];
        }
        return g;
    }
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    static boolean multipleTC = true;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        new Main().run();
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}

    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }

        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }

        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException  e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }

        String nextLine() throws Exception{
            String str = "";
            try{
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }
            return str;
        }
    }
}

Feel free to share your approach. Suggestions are welcome as always.


I wouldn't have been able to solve this if I hadn't looked at this after the AtCoder contest. This problem uses exactly the same idea as in the editorial. My code.


I solved that problem too using sack on tree.

I coded the O(N log^2 N) DSU on trees and it passed. Coding the HLD-style one might have been a bit more time-consuming. I guess if N had been up to 1e6, only O(N log N) or O(N) solutions would have worked.

The video editorialist shared a solution using only DFS, nothing fancy, but it is O(N^2) in the worst case.


Yeah, I haven't learnt DSU on trees, but after seeing the Codeforces blog on it, the problem seemed very similar to the one from the recent AtCoder contest. Thanks for confirming that this can also be done with the in/out-time approach.


Can anyone help me solve this problem? I spent a lot of time on it but couldn't solve it. Here is my solution code, which is quite similar to the video editorial.

I used an Euler tour (for traversal) and a Fenwick tree (to store the answer). Link to my submission.

I have used the same approach as the video editorial, but it is giving TLE. Can anybody help me find my mistake? Solution: 46859731 | CodeChef



Instead of computing the dp of a vertex from the vertices in its subtree, we can also do the reverse and compute the dp of a vertex from its ancestors.

dp[x] = 1 + sum of dp[u] over all ancestors u of x
dp[x] = dp[x] - sum of dp[u] over all ancestors u of x with a[u] = a[x]

The answer to our problem is the sum of the dp values of all the leaves.
My submission
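The ancestor-based formulation above can be sketched in C++ as follows. This is an illustrative reconstruction, not the commenter's submission; `countFromAncestors`, the 1-indexed edge-list input, and the explicit enter/exit stack are assumptions. Along the current root path it keeps the total of dp and the total of dp per value, so each dp[x] costs O(1) expected time.

```cpp
#include <cassert>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
using namespace std;

const long long MOD = 1'000'000'007LL;

// dp[x] = 1 + sum of dp[u] over ancestors u of x with a[u] != a[x];
// the answer is the sum of dp over the leaves. Edges are 1-indexed; root = node 1.
long long countFromAncestors(int n, const vector<pair<int, int>>& edges,
                             const vector<long long>& a) {
    assert((int)a.size() == n && (int)edges.size() == n - 1);
    vector<vector<int>> adj(n);
    for (auto [u, v] : edges) { adj[u - 1].push_back(v - 1); adj[v - 1].push_back(u - 1); }
    vector<long long> dp(n, 0);
    unordered_map<long long, long long> pathByValue;  // value -> sum of dp on the root path
    long long pathTotal = 0, answer = 0;
    // Entries are (node, parent, exiting). Entering adds dp[u] to the path sums,
    // exiting (after the whole subtree) removes it again.
    vector<tuple<int, int, bool>> st = {{0, -1, false}};
    while (!st.empty()) {
        auto [u, p, exiting] = st.back();
        st.pop_back();
        long long& same = pathByValue[a[u]];
        if (exiting) {
            pathTotal = (pathTotal - dp[u] + MOD) % MOD;
            same = (same - dp[u] + MOD) % MOD;
            continue;
        }
        dp[u] = (1 + pathTotal - same + MOD) % MOD;  // excludes equal-valued ancestors
        pathTotal = (pathTotal + dp[u]) % MOD;
        same = (same + dp[u]) % MOD;
        st.push_back({u, p, true});
        bool leaf = true;
        for (int v : adj[u])
            if (v != p) { leaf = false; st.push_back({v, u, false}); }
        if (leaf) answer = (answer + dp[u]) % MOD;
    }
    return answer;
}
```

Here dp[x] counts the chains of nodes, each a proper descendant of the previous one with a different value, that end at x, which is the same quantity the editorial's ways_u counts from the top of the chain.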


I have checked my code against accepted submissions with node values up to 20, and it works fine. My only difficulty is the modulo: I don't know where I should take the modulo to get the right answer. Can anyone check and let me know where I should apply it?

This is where it sums up all the dp values and gives the answer:
long sum = 0;
for(long j = 0; j < N; j++)
    sum = (sum % mod + dp[j] % mod) % mod;

This is the method I wrote to calculate the dp value of every node:
public static void Traverse(long[] dp, long currNode, Dictionary<long, List<long>> tree, long[] value)
{
    if(dp[currNode] == 0 && !tree.ContainsKey(currNode))
        dp[currNode] = 1;
    else if(dp[currNode] == 0 && tree.ContainsKey(currNode))
    {
        foreach(long subNode in tree[currNode])
        {
            if(dp[subNode] == 0)
                Traverse(dp, subNode, tree, value);
            if(value[currNode] != value[subNode])
                dp[currNode] = (dp[currNode] % mod + dp[subNode] % mod) % mod;
        }
    }
}

That’s a nice approach.

Hey, sorry to bring this up so long after the contest, but I looked through your code and noticed that the vector of pairs in the map isn't sorted. How are you using lower_bound then?

The first element of each pair is the time of arrival, and pairs are inserted in increasing order of that time, so each vector is already sorted. I would recommend reading the AtCoder editorial thoroughly.