# PROBLEM LINK:

*Author:* naveen19991124

*Editorialist:* naveen19991124

*Testers:* valiant_vidit, hellb0y_suru

# PROBLEM:

Given a tree rooted at node 1 in which all nodes are initially uncolored (color/value 0). You are given Q queries; each query gives a node, a distance D and a color, and you need to color all the nodes within distance D of the given node with the given color. After processing all the queries, for each node i of the tree you need to find the number of distinct colors in the subtree rooted at node i.

# PREREQUISITES:

- DP
- Euler Tour
- Range data structures (segment tree, Fenwick tree)

# DIFFICULTY:

EASY-MEDIUM

# EXPLANATION:

**Brute Force Approach:**

According to the problem, each query asks us to color every node within distance D of the given node with the given color.

So to get the final tree, we can process the queries in the given order and, for each one, run a simple DFS from the query node that colors every node within distance D, overwriting any earlier color.

Hence each query takes O(V+E) = O(N) time in the worst case.
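A minimal sketch of the brute-force coloring step for one query, assuming 0-indexed nodes and an adjacency list `adj`; the name `paintWithin` is hypothetical. Because the graph is a tree, every node has a unique path from the query node, so remembering the parent is enough to avoid walking back along the same edge (no visited array is needed).

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: paint every node within distance d of v with color c,
// overwriting any earlier color. `parent` is the node we came from (-1 at the
// start); in a tree this alone prevents revisiting nodes.
void paintWithin(const std::vector<std::vector<int>>& adj, std::vector<int>& color,
                 int v, int parent, int d, int c) {
    color[v] = c;
    if (d == 0) return;  // reached the distance limit
    for (int u : adj[v])
        if (u != parent) paintWithin(adj, color, u, v, d - 1, c);
}
```

Calling `paintWithin(adj, color, v, -1, d, c)` once per query, in input order, produces the final coloring in O(Q*N) time overall; the recursion depth is at most D+1.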

Over all the queries, we get our final tree in O(Q*(V+E)) time, which is not fast enough for the full problem (but good enough for the **1st subtask**).

Now, to get the number of distinct colors in the subtree rooted at node i, we can simply run a DFS over that subtree and insert every color into a set; the answer is the size of the set.

This takes O(S*log S) time for node i, where S is the number of nodes in its subtree, i.e. O(N^2*log N) over all nodes in the worst case, which is fast enough to pass the **1st subtask**.

You can also use another data structure to count the distinct values in the subtree, but in either case you need to perform a DFS.
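The per-node set counting can be sketched as below, assuming 0-indexed nodes with node 0 as the root; `collect` and `distinctPerSubtree` are hypothetical names. Like the setter's solution, it counts the value 0 (uncolored) as a value in its own right.

```cpp
#include <cassert>
#include <set>
#include <vector>

// Insert the color of every node in the subtree of v into `seen`.
// `parent` is v's parent in the tree rooted at node 0 (-1 for the root).
void collect(const std::vector<std::vector<int>>& adj, const std::vector<int>& color,
             int v, int parent, std::set<int>& seen) {
    seen.insert(color[v]);
    for (int u : adj[v])
        if (u != parent) collect(adj, color, u, v, seen);
}

// Brute force: answer[i] = number of distinct values in the subtree of node i.
std::vector<int> distinctPerSubtree(const std::vector<std::vector<int>>& adj,
                                    const std::vector<int>& color) {
    int n = static_cast<int>(adj.size());
    // BFS from the root 0 to find every node's parent.
    std::vector<int> par(n, -1), order{0};
    for (int k = 0; k < static_cast<int>(order.size()); ++k) {
        int v = order[k];
        for (int u : adj[v])
            if (u != par[v]) { par[u] = v; order.push_back(u); }
    }
    std::vector<int> answer(n);
    for (int i = 0; i < n; ++i) {
        std::set<int> seen;  // fresh set for each subtree
        collect(adj, color, i, par[i], seen);
        answer[i] = static_cast<int>(seen.size());
    }
    return answer;
}
```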

**Efficient Approach:**

**One of the key observations here is to reverse the queries.**

A later query overwrites an earlier one, so the final color of a node is decided by the last query that reaches it. The idea is that if we process the queries from the Q-th down to the 1st and only color nodes that are still uncolored, a node that is already colored will never be recolored, even if it falls within the distance of queries 1…(Q-1).

But even though we no longer recolor nodes, we still need to perform a DFS for every query to check whether some uncolored nodes remain within its distance.

**Now, we need to observe that the maximum distance given is only 20.**

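Here we can observe that if a node v has already been reached with remaining distance d (by the current query or by one processed earlier), then every node within distance d of v has already been colored, so we don't need to explore its neighbors again to check whether some uncolored nodes still exist.

So we can maintain dp[node][distance], which marks that node has already been visited with that remaining distance; if we reach a node with a remaining distance it has already been visited with, we can simply return. Note that the check is on the (node, distance) pair, not on whether the node is colored: a colored node may still have to be expanded if we now reach it with a larger remaining distance.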
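A minimal sketch of the reversed-query coloring with the dp[node][distance] memo, modeled on the setter's `dfs` below but with hypothetical names (`Query`, `spread`, `finalColors`). It assumes 0-indexed nodes, query colors ≥ 1 (so that 0 can mean "uncolored"), and every query distance ≤ `maxD`.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

struct Query { int v, d, c; };  // node (0-indexed), distance, color (assumed >= 1)

// seen[v][d] != 0 means v was already entered with remaining distance d,
// so every node within distance d of v is already colored.
void spread(const std::vector<std::vector<int>>& adj, std::vector<std::vector<char>>& seen,
            std::vector<int>& color, int v, int d, int c) {
    if (seen[v][d]) return;
    seen[v][d] = 1;
    if (color[v] == 0) color[v] = c;  // the first (i.e. latest) query to reach v wins
    if (d == 0) return;
    for (int u : adj[v]) spread(adj, seen, color, u, d - 1, c);
}

std::vector<int> finalColors(const std::vector<std::vector<int>>& adj,
                             std::vector<Query> queries, int maxD) {
    int n = static_cast<int>(adj.size());
    std::vector<std::vector<char>> seen(n, std::vector<char>(maxD + 1, 0));
    std::vector<int> color(n, 0);
    std::reverse(queries.begin(), queries.end());  // process the last query first
    for (const Query& q : queries) spread(adj, seen, color, q.v, q.d, q.c);
    return color;
}
```

Walking back to the node we came from is harmless here: it is reached with a smaller remaining distance, and the memo stops any repeated state.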

So, to obtain the final tree, each node can be entered with at most 21 different remaining distances (0…20), and each such state scans the node's neighbors only once. In the worst case this is 21*N states and O(Q + 21*(V+E)) = O(Q + D*N) work overall, instead of O(Q*N).
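Counting the work more precisely, with D = 20 the maximum distance and deg(v) the number of neighbors of v: there is one initial call per query, and each (node, remaining distance) pair is expanded at most once, scanning its adjacency list once. In a tree the degrees sum to 2(N-1), so:

$$
Q + \sum_{v=1}^{N} (D+1)\,\deg(v) \;=\; Q + (D+1)\cdot 2(N-1) \;=\; O(Q + D\cdot N)
$$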

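**How to get distinct colors in the subtree rooted at node i:**

This is the standard problem of counting the distinct values in a given subarray, once per query.

But we have a tree?

The tree can be visualised as an array in which each node's subtree is a contiguous subarray, i.e. an Euler tour: if nodes are numbered by their DFS entry time intime[v], the subtree of v occupies positions intime[v]…outtime[v], where outtime[v] is the largest entry time inside that subtree.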
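A sketch of the Euler tour, assuming 0-indexed nodes; `eulerTour`, `tin` and `tout` are hypothetical names playing the roles of the setter's `flatten`, `intime` and `outtime`. Unlike the setter's recursive `flatten`, this version uses an explicit stack, so a long path cannot overflow the call stack.

```cpp
#include <cassert>
#include <vector>

// After the call, node v's subtree occupies positions tin[v] .. tout[v]
// (1-indexed, inclusive) of the flattened array.
void eulerTour(const std::vector<std::vector<int>>& adj, int root,
               std::vector<int>& tin, std::vector<int>& tout) {
    int n = static_cast<int>(adj.size()), timer = 0;
    tin.assign(n, 0);
    tout.assign(n, 0);
    std::vector<int> parent(n, -1);
    std::vector<int> nextEdge(n, 0);  // index of the next neighbor to try
    std::vector<int> stack{root};
    tin[root] = ++timer;
    while (!stack.empty()) {
        int v = stack.back();
        if (nextEdge[v] < static_cast<int>(adj[v].size())) {
            int u = adj[v][nextEdge[v]++];
            if (u == parent[v]) continue;  // do not go back up the tree
            parent[u] = v;
            tin[u] = ++timer;  // entry time = position in the flattened array
            stack.push_back(u);
        } else {
            tout[v] = timer;  // largest entry time inside v's subtree
            stack.pop_back();
        }
    }
}
```

The flattened array is then `linear[tin[v]] = color[v]`, and the answer for node v is the number of distinct values in `linear[tin[v] .. tout[v]]`.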

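So after the Euler tour, the problem reduces to counting the distinct values in one subarray per node. These queries can be answered offline either with Mo's algorithm / square root decomposition, or by sorting the queries by their right endpoint and sweeping with a Fenwick (or segment) tree that keeps a 1 only at the last occurrence of each value seen so far. The setter's solution below uses the Fenwick tree sweep.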
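A minimal sketch of the offline Fenwick-tree sweep, the same idea as the main loop of the setter's solution but with hypothetical names (`Fenwick`, `distinctInRanges`). It assumes `a[0]` is unused so that positions are 1-indexed like Euler tour times, and that every range satisfies 1 ≤ l ≤ r ≤ n.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <numeric>
#include <utility>
#include <vector>

struct Fenwick {
    std::vector<int> t;
    explicit Fenwick(int n) : t(n + 1, 0) {}
    void add(int i, int delta) {  // i is 1-indexed
        for (; i < static_cast<int>(t.size()); i += i & -i) t[i] += delta;
    }
    int prefix(int i) const {  // sum of positions 1..i
        int s = 0;
        for (; i > 0; i -= i & -i) s += t[i];
        return s;
    }
};

// For each inclusive range [l, r], count the distinct values in a[l..r].
// Sweep r left to right, keeping a 1 only at the last occurrence of each value.
std::vector<int> distinctInRanges(const std::vector<int>& a,
                                  const std::vector<std::pair<int, int>>& ranges) {
    int n = static_cast<int>(a.size()) - 1;
    std::vector<int> order(ranges.size());
    std::iota(order.begin(), order.end(), 0);
    std::sort(order.begin(), order.end(),
              [&](int x, int y) { return ranges[x].second < ranges[y].second; });
    Fenwick fw(n);
    std::map<int, int> last;  // value -> last position seen so far
    std::vector<int> ans(ranges.size());
    std::size_t k = 0;
    for (int i = 1; i <= n; ++i) {
        auto it = last.find(a[i]);
        if (it != last.end()) fw.add(it->second, -1);  // drop the older occurrence
        last[a[i]] = i;
        fw.add(i, 1);
        while (k < order.size() && ranges[order[k]].second == i) {
            int l = ranges[order[k]].first, r = ranges[order[k]].second;
            ans[order[k]] = fw.prefix(r) - fw.prefix(l - 1);
            ++k;
        }
    }
    return ans;
}
```

For the subtree answers, pass the flattened colors as `a` and one range `{tin[v], tout[v]}` per node.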

# SOLUTION:

## Setter's Solution in C++

```
//AUTHOR - NAVEEN KUMAR
//The Game Is ON!!!
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp> // Common file
#include <ext/pb_ds/tree_policy.hpp> // Including tree_order_statistics_node_update
//#include<boost/multiprecision/cpp_int.hpp>
//#inlcude<boost/multiprecision/cpp_dec_float.hpp>
#define ll long long
#define ld long double
#define F first
#define S second
#define nl "\n"
#define mem(v, t) memset(v, t, sizeof(v))
#define all(v) v.begin(), v.end()
#define srt(v) sort(all(v))
#define rsrt(v) sort(v.rbegin(), v.rend())
#define pb push_back
#define f(a) for (ll i = 0; i < a; i++)
#define rep(i, a, b) for (ll i = a; i < b; i++)
#define rrep(i, a, b) for (ll i = a; i > b; i--)
#define vll vector<ll>
#define pll pair<ll, ll>
#define vpll vector<pair<ll, ll>>
#define mp make_pair
using namespace std;
vector<vector<int>> vec(110000);
vector<int> color(110000, 0);
int dp[110000][21];
bool visited[110000];int intime[110000];int outtime[110000];
int linear[110000];int ans[110000];int val[110000];
int curtime = 0;
struct query
{
int v;
int d;
int c;
};
int n;
void flatten(int v)
{
visited[v] = true;
intime[v] = ++curtime;
for (int i = 0; i < vec[v].size(); i++)
{
if (!visited[vec[v][i]])
{
flatten(vec[v][i]);
}
}
outtime[v] = curtime;
}
struct BIT
{
vector<ll> bit;
void init(ll N)
{
N = n;
bit.assign(N + 1, 0);
}
void update(ll idx, ll val)
{
while (idx <= n)
{
bit[idx] += val;
idx += idx & (-idx);
}
}
ll pref(ll idx)
{
ll ans = 0;
while (idx > 0)
{
ans += bit[idx];
idx -= idx & (-idx);
}
return ans;
}
ll rsum(ll l, ll r)
{
return pref(r) - pref(l - 1);
}
};
void dfs(int v, int d, int c)
{
if (dp[v][d] != -1)
{
return;
}
dp[v][d]++;
if (color[v] == 0)
{
color[v] = c;
}
if (d == 0)
{
return;
}
for (int next : vec[v])
{
if (dp[next][d - 1] == -1)
{
dfs(next, d - 1, c);
}
}
}
//MAIN CODE
int main()
{
mem(dp, -1);
cin >> n;
for (int i = 0; i < n - 1; i++)
{
int a, b;
cin >> a >> b;
a--;
b--;
vec[a].push_back(b);
vec[b].push_back(a);
}
int q;
cin >> q;
vector<query> op(q);
for (int i = 0; i < q; i++)
{
int v, d, c;
cin >> v >> d >> c;
v--;
op[i] = query{v, d, c};
}
reverse(op.begin(), op.end());
for (int i = 0; i < q; i++)
{
dfs(op[i].v, op[i].d, op[i].c);
}
memset(visited, false, sizeof(visited));
flatten(0);
for (int i = 1; i <= n; i++)
{
linear[intime[i - 1]] = color[i - 1];
}
vector<pair<pair<ll, ll>, ll>> p;
for (int i = 1; i <= n; i++)
{
p.push_back(mp(mp(outtime[i - 1], intime[i - 1]), i));
}
BIT t1;
t1.init(n);
sort(all(p));
ll ind = 0;
map<ll, ll> last;
ll k = 0;
for (ll i = 1; i <= n; i++)
{
if (last.find(linear[i]) == last.end())
{
last[linear[i]] = i;
t1.update(i, 1);
}
else
{
t1.update(last[linear[i]], -1);
last[linear[i]] = i;
t1.update(i, 1);
}
// answer every subtree whose Euler range [intime, outtime] ends at position i;
// check ind < n first so p[ind] is never read past the end of p
while (ind < n && p[ind].first.first == i)
{
ans[p[ind].second] = t1.rsum(p[ind].first.second, p[ind].first.first);
ind++;
}
}
for (ll i = 1; i <= n - 1; i++)
{
cout << ans[i] << " ";
}
cout << ans[n] << "\n";
return 0;
}
//The Game Is OVER!!!
```