# PROBLEM LINK:

Practice

Div-3 Contest

Div-2 Contest

Div-1 Contest

**Author:** Vishesh Saraswat

**Tester:** Istvan Nagy

**Editorialist:** Vishesh Saraswat

# DIFFICULTY:

Easy

# PREREQUISITES:

Graph, DFS, Prefix Sums

# PROBLEM:

There is a tree with $N$ nodes, rooted at node $1$. Every second, a new raindrop falls on each leaf, and every existing raindrop moves one node closer to the root. The tree is shaken $M$ times, at seconds $A_1 < A_2 < \dots < A_M$; each shake makes all raindrops except those at the root fall off. Find the number of raindrops at the root after the $M$-th shake.

# SOLUTION AND EXPLANATION

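Each shake effectively resets the tree to its initial state (only the root keeps its raindrops). So we can split the timeline into $M$ independent intervals, namely from second $0$ to $A_1$ and from $A_{i-1}$ to $A_i$ for $i > 1$. We solve each interval separately and add up the results.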

Now we need an efficient way to compute the answer for a single interval. A raindrop that falls on a leaf at distance $d$ (in edges) from the root takes $d+1$ seconds to reach the root. Once the first raindrop from a leaf reaches the root, that leaf keeps delivering one raindrop every second from then on. Building on this, let $D$ be the distance of the leaf farthest from the root. Then, from second $D+1$ onwards, the root receives exactly $X$ raindrops every second, where $X$ is the number of leaves.

All that remains is to compute how many raindrops the root has received during the first $D$ seconds after a reset. To do so, define two arrays:

- $R$, where $R_i$ is the number of raindrops at the root $i$ seconds after a reset.
- $L$, where $L_i$ is the number of leaves at distance $i$ from the root.

$R_1 = 0$, because no raindrop can reach the root during the first second. For every $1 < i \leq D$ we have $R_i = R_{i-1} + \sum_{j=1}^{i-1} L_j$. This is the raindrops that had already arrived within the first $i-1$ seconds, plus the new ones arriving during second $i$ (one from every leaf at distance at most $i-1$). The inner sum is a prefix sum of $L$, so all of $R$ can be computed in $O(D)$.

Finally, we go over all $M$ intervals. Let $Y$ be the length of the current interval. If $Y \leq D$, we add $R_Y$ to the answer; otherwise we add $R_D + X \cdot (Y - D)$, where $X$ is the total number of leaves in the tree. Together with the DFS that computes the leaf depths, this takes $O(N + M)$ time per test case.

# SOLUTIONS

## Setter's / Editorialist's Solution

```
#include "bits/stdc++.h"
using namespace std;
/*
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;
using ordered_set = tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update>;
*/
// Contest-template shorthands: whole-range / reversed-range iterator pairs,
// container size as int, and a 64-bit integer alias.
#define all(x) begin(x), end(x)
#define rall(x) rbegin(x), rend(x)
#define sz(x) (int)(x).size()
using ll = long long;
// Common contest modulus; not referenced by this solution.
const int mod = 1e9+7;
// Per-test-case state: cleared and resized in solve(), filled by dfs().
vector<vector<int>> adj;  // adjacency lists over 0-indexed nodes
vector<int> L;            // L[d] = number of leaves at depth d
vector<int> dist;         // dist[v] = depth of v, in edges from node 0
ll X;                     // total number of leaves (node 0 never counted)
// Traverses the subtree of `u` (entered from parent `p`; -1 means no parent),
// setting dist[v] = dist[parent] + 1 for every node below `u`, and counting
// leaves: each node other than node 0 with exactly one neighbour increments
// the global leaf total X and L[dist[node]].
//
// Uses an explicit stack instead of recursion. The recursive form nests one
// call per tree level, so a path-shaped tree with ~1e5 nodes can exhaust the
// call stack. The iterative form leaves dist, L and X in the same final state;
// only the order in which leaves are counted differs.
void dfs(int u, int p = -1) {
    std::vector<std::pair<int, int>> pending;  // (node, parent) still to visit
    pending.emplace_back(u, p);
    while (!pending.empty()) {
        const int node = pending.back().first;
        const int parent = pending.back().second;
        pending.pop_back();
        for (int v : adj[node]) {
            if (v != parent) {
                dist[v] = dist[node] + 1;
                pending.emplace_back(v, node);
            }
        }
        if (adj[node].size() == 1 && node != 0) {
            X++;
            L[dist[node]]++;
        }
    }
}
// Solves one test case: reads a tree of n nodes (1-indexed in the input,
// stored 0-indexed with node 0 as the root) and m shake times a[1..m]. Prints
// the total number of raindrops collected at the root, summed over the m
// intervals between consecutive shakes (the first interval starts at time 0).
void solve(int /*tc*/) {
    int n, m;
    cin >> n >> m;
    X = 0;
    adj.clear(); L.clear(); dist.clear();
    adj.resize(n); L.resize(n); dist.resize(n);
    for (int i = 0; i < n-1; ++i) {
        int u, v;
        cin >> u >> v;
        --u, --v;  // input nodes are 1-indexed
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    // a[0] = 0 acts as the start time, so a[i] - a[i-1] is the length of
    // interval i.
    vector<int> a(m+1);
    for (int i = 1; i <= m; ++i)
        cin >> a[i];
    dfs(0);
    // D = depth of the deepest leaf. The D > 0 guard keeps a tree without
    // leaves (n == 1) from indexing L[-1].
    ll D = n-1;
    while (D > 0 && L[D] == 0)
        --D;
    // R[i] = drops at the root i seconds after a reset. After the update,
    // cursum is the number of leaves at depth < i, i.e. the leaves delivering
    // a drop during second i.
    vector<ll> R(D+1);
    ll cursum = 0;
    for (int i = 1; i <= D; ++i) {
        cursum += L[i-1];
        R[i] = R[i-1] + cursum;
    }
    ll ans = 0;
    for (int i = 1; i <= m; ++i) {
        ll Y = a[i] - a[i-1];
        // Past D seconds every leaf delivers exactly one drop per second.
        ans += (Y <= D) ? R[Y] : R[D] + (Y - D) * X;
    }
    cout << ans << '\n';
}
// Entry point: switches std::cin/std::cout to unsynchronised, untied mode for
// speed, reads the number of test cases, and runs solve() once per case with
// its 1-based index.
signed main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    int tc = 1;
    std::cin >> tc;
    for (int caseNo = 1; caseNo <= tc; ++caseNo) {
        solve(caseNo);
    }
    return 0;
}
```

## Tester's Solution

```
#include <iostream>
#include <algorithm>
#include <string>
#include <cassert>
#include <vector>
#include <numeric>
using namespace std;
#ifdef HOME
#define NOMINMAX
#include <windows.h>
#endif
// Validator helper: reads a base-10 integer from stdin that must lie in
// [l, r] and be terminated by the character `endd`. Malformed input aborts via
// assert: a misplaced or repeated '-', leading zeros, "-0", more than 19
// digits (or 19 digits starting above 1), a value outside [l, r], or EOF
// before the terminator. Any other character is silently skipped.
long long readInt(long long l, long long r, char endd) {
    long long x = 0;
    int cnt = 0;          // digits read so far
    int fi = -1;          // first digit; -1 until one has been read
    bool is_neg = false;
    while (true) {
        // Keep getchar()'s int result so EOF stays distinguishable from a real
        // byte even where plain char is unsigned.
        int g = getchar();
        // The catch-all branch below ignores unexpected characters, so without
        // this check the loop would spin forever at end of input.
        assert(g != EOF);
        if (g == '-') {
            assert(fi == -1);   // the sign must precede every digit
            assert(!is_neg);    // ...and may appear only once
            is_neg = true;
            continue;
        }
        if ('0' <= g && g <= '9') {
            if (cnt == 0) {
                fi = g - '0';
            }
            cnt++;
            assert(fi != 0 || cnt == 1);          // no leading zeros
            assert(fi != 0 || is_neg == false);   // no "-0"
            // Checked before accumulating: with at most 19 digits and a
            // leading 1, x stays below 2e18 and cannot overflow.
            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
            x = x * 10 + (g - '0');
        }
        else if (g == endd) {
            assert(cnt > 0);
            if (is_neg) {
                x = -x;
            }
            assert(l <= x && x <= r);
            return x;
        }
        else {
            // Deliberately lenient (the original strict assert was commented
            // out); presumably to tolerate stray characters such as '\r'.
        }
    }
}
// Validator helper: reads characters from stdin up to (not including) the
// terminator `endd` and returns them. Asserts that EOF is not reached first
// and that the number of characters read lies in [l, r].
std::string readString(int l, int r, char endd) {
    std::string ret = "";
    int cnt = 0;
    while (true) {
        // getchar() returns int. Narrowing it to char broke the EOF test: on
        // unsigned-char targets EOF became 255 (infinite loop), and on
        // signed-char targets a genuine 0xFF byte looked like EOF.
        int g = getchar();
        assert(g != EOF);
        if (g == endd) {
            break;
        }
        cnt++;
        ret += static_cast<char>(g);
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}
long long readIntSp(long long l, long long r) {
return readInt(l, r, ' ');
}
long long readIntLn(long long l, long long r) {
return readInt(l, r, '\n');
}
string readStringLn(int l, int r) {
return readString(l, r, '\n');
}
string readStringSp(int l, int r) {
return readString(l, r, ' ');
}
// Tester's validator + solution. Reads T test cases and asserts the input
// constraints checked below: 1 <= T <= 1000, 2 <= N <= 1e5, 1 <= M <= 1e5,
// edges between distinct nodes that connect all N nodes, strictly increasing
// shake times in [1, 1e9], sum of N <= 5e5, and nothing after the last case.
// Prints the answer for each case.
int main() {
#ifdef HOME
    if (IsDebuggerPresent())
    {
        freopen("../in.txt", "rb", stdin);
        freopen("../out.txt", "wb", stdout);
    }
#endif
    int T = readIntLn(1, 1000);
    uint32_t sumN = 0;
    for (int tc = 0; tc < T; ++tc)
    {
        uint32_t N = readIntSp(2, 100'000);
        sumN += N;
        uint32_t M = readIntLn(1, 100'000);
        vector<vector<uint32_t>> neighb(N);
        for (uint32_t i = 0; i < N - 1; ++i)
        {
            uint32_t u = readIntSp(1, N);
            uint32_t v = readIntLn(1, N);
            assert(u != v);
            --u; --v;
            neighb[u].push_back(v);
            neighb[v].push_back(u);
        }
        // Shake times on one line, strictly increasing.
        vector<uint32_t> A(M);
        uint32_t actr = 0;
        uint32_t prevA = 0;
        for (auto&& ai : A)
        {
            ++actr;
            if (actr == M)
                ai = readIntLn(1, 1'000'000'000);
            else
                ai = readIntSp(1, 1'000'000'000);
            assert(ai > prevA);
            prevA = ai;
        }
        // BFS from node 0 computes depths d[]; reaching all N nodes with N-1
        // edges confirms the input is a tree.
        vector<int> d(N);
        vector<int> nodes({0});
        vector<bool> used(N);
        used[0] = true;
        for (uint32_t i = 0; i < nodes.size(); ++i)
        {
            auto actNode = nodes[i];
            for (auto cand : neighb[actNode])
            {
                if (used[cand])
                    continue;
                used[cand] = true;
                nodes.push_back(cand);
                d[cand] = d[actNode] + 1;
            }
        }
        assert(nodes.size() == N);
        // Leaves are the non-root nodes of degree 1.
        uint32_t maxDepth = 0;
        uint64_t leafCount = 0;
        for (uint32_t i = 1; i < N; ++i)
        {
            if (neighb[i].size() == 1)
            {
                maxDepth = std::max<uint32_t>(maxDepth, d[i]);
                ++leafCount;
            }
        }
        // vRainDrops[k] first holds the number of leaves at depth k. It is then
        // turned in place into the number of drops at the root k + 1 seconds
        // after a shake. Kept 64-bit: the values grow to about
        // maxDepth * leafCount.
        vector<uint64_t> vRainDrops(maxDepth + 1);
        for (uint32_t i = 1; i < N; ++i)
        {
            if (neighb[i].size() == 1)
                vRainDrops[d[i]]++;
        }
        uint64_t dropCtr = 0;  // leaves at depth <= i
        for (uint32_t i = 1; i < vRainDrops.size(); ++i)
        {
            dropCtr += vRainDrops[i];
            vRainDrops[i] = vRainDrops[i-1] + dropCtr;
        }
        uint64_t res = 0;
        uint32_t prev = 0;
        for (auto ai : A)
        {
            // An interval of (ai - prev) seconds maps to index ai - prev - 1.
            uint64_t dist = ai - prev;
            --dist;
            if (dist < vRainDrops.size())
                res += vRainDrops[dist];
            else
                // Beyond that, every leaf delivers one drop per second.
                res += vRainDrops[maxDepth] + (dist - maxDepth) * leafCount;
            prev = ai;
        }
        // %llu requires unsigned long long; uint64_t may be unsigned long.
        printf("%llu\n", static_cast<unsigned long long>(res));
    }
    assert(sumN <= 500'000);
    assert(getchar() == -1);
}
```