# PROBLEM LINK

**Setter:** Abdullah Aslam

**Tester:** Michael Nematollahi

**Editorialist:** Taranpreet Singh

# DIFFICULTY

Easy-Medium.

# PREREQUISITES

Greedy and Dynamic Programming.

# PROBLEM

Given a map of Tolland with N cities connected by bidirectional roads in the form of a tree. Abdullah lives in city X. For the next 2*N days, he plans to travel from city X to city i on day 2*i-1 and to return from city i to city X on day 2*i. Each city has an associated toll, given in array TL, which must be paid whenever Abdullah visits that city. So, for traveling from city i to city j, the toll of every city on the path from i to j (including cities i and j) has to be paid.

By using coupons, he can reduce the toll of any city to any non-negative value. Determine the minimum amount he should spend on coupons so that on every day, the total toll he pays does not exceed K.

# EXPLANATION

Let us root the tree at node X.

The statement is long, but it boils down to this: we need to buy coupons at minimum cost so that the total toll paid on any single day is at most K.

**Lemma:** The maximum toll paid on any day is attained on a trip to one of the leaf nodes.

**Proof:** For any non-leaf node, if he needs to pay A units of toll to reach this node, then to reach any child of this node he needs to pay A+x units for some non-negative x, which can never be less than A. Hence, the maximum number of coins needed to reach any node in the tree is the same as the maximum number of coins needed to reach any leaf.

Hence, if we manage to use coupons in such a manner so that all leaves are reachable using at most K units of money, it shall ensure that we can reach any city using at most K units of money.
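To see the lemma on a concrete tree, here is a minimal sketch that computes the total toll from the root to every node. The helper name `pathTolls` and the 0-based adjacency-list layout are assumptions made for illustration and do not come from any of the solutions below.

```cpp
#include <cassert>
#include <vector>

// pathToll[v] = sum of tolls on the path from root to v, both ends included.
// adj is the undirected adjacency list of a tree with 0-based nodes (assumed layout).
std::vector<long long> pathTolls(const std::vector<std::vector<int>>& adj,
                                 const std::vector<long long>& toll, int root) {
    const int n = static_cast<int>(adj.size());
    assert(static_cast<int>(toll.size()) == n && 0 <= root && root < n);
    std::vector<long long> pathToll(n, 0);
    std::vector<int> parent(n, -1);
    std::vector<int> stack = {root};
    pathToll[root] = toll[root];
    while (!stack.empty()) {
        int u = stack.back();
        stack.pop_back();
        for (int v : adj[u]) {
            if (v == parent[u]) continue;
            parent[v] = u;
            // A child always costs at least as much as its parent (tolls are non-negative).
            pathToll[v] = pathToll[u] + toll[v];
            stack.push_back(v);
        }
    }
    return pathToll;
}
```

For the tree with edges 0-1, 0-2, 1-3 and tolls 2, 3, 1, 4 rooted at node 0, the path tolls are 2, 5, 3, 9, and the largest value (9) sits at leaf 3.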

Let us calculate, for each node u, the minimum number of coins V_u that suffices to reach every leaf in the subtree of u when starting from u, i.e. the largest total toll over all paths from u down to a leaf. This can be computed easily using bottom-up dynamic programming.

For a leaf, the minimum number of coins needed is the toll of that leaf node itself.

For a non-leaf node u, the minimum number of coins needed is the toll of u plus the maximum, over all direct children v of u, of the coins needed to reach the leaves in the subtree of v.

Recursively, the minimum number of coins is given by V_u = mx[u] = cost[u] + max(mx[v]), where v ranges over the direct children of u.
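The recurrence can be sketched as follows. This is a hedged illustration, not one of the reference solutions: the function name `maxDownToll` and the 0-based adjacency list are assumptions, and the traversal is iterative only to avoid deep recursion on path-like trees.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// V[u] = toll[u] + max(V[v]) over the children v of u (just toll[u] for a leaf).
// adj is the undirected adjacency list of a tree with 0-based nodes (assumed layout).
std::vector<long long> maxDownToll(const std::vector<std::vector<int>>& adj,
                                   const std::vector<long long>& toll, int root) {
    const int n = static_cast<int>(adj.size());
    assert(static_cast<int>(toll.size()) == n && 0 <= root && root < n);
    std::vector<int> parent(n, -1), order;
    std::vector<int> stack = {root};
    while (!stack.empty()) {  // preorder: every parent appears before its children
        int u = stack.back();
        stack.pop_back();
        order.push_back(u);
        for (int v : adj[u]) {
            if (v == parent[u]) continue;
            parent[v] = u;
            stack.push_back(v);
        }
    }
    std::vector<long long> best(n, 0);  // best[u] = largest V among children of u
    std::vector<long long> V(n, 0);
    for (int i = static_cast<int>(order.size()) - 1; i >= 0; --i) {
        int u = order[i];  // reverse preorder: children are finished before parents
        V[u] = toll[u] + best[u];
        if (parent[u] != -1) best[parent[u]] = std::max(best[parent[u]], V[u]);
    }
    return V;
}
```

On the tree with edges 0-1, 0-2, 1-3 and tolls 2, 3, 1, 4 rooted at node 0, this gives V = 9, 7, 1, 4.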

Now, we can see that the toll of node X is included in every path, so reducing the toll of node X reduces the toll of all paths. After reducing it to zero, we cannot reduce it further. If the number of coins needed still exceeds K, we move one step towards each leaf that still needs more than K coins, and so on. The key idea is that reducing the toll of nodes near the root is better than reducing the toll of their descendants, since reducing the toll of city u reduces the total toll for every node in the subtree of u by the same amount.

So, let us visit every city u. If V_u, the toll needed to reach the farthest leaf from u, exceeds K, we try to reduce the toll of u by V_u-K. But we cannot reduce a toll below zero, so we reduce the toll of city u by min(toll_u, V_u-K). Some cities already have V_u <= K, and we do not reduce their toll at all. Hence, for each city u, we reduce the toll by min(toll_u, max(V_u-K, 0)).

The sum of this quantity over all cities gives the minimum amount of coupons to be used. But each trip involves going and then coming back, which requires every coupon to be bought twice. Hence the answer is doubled.
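Putting the pieces together, a minimal end-to-end sketch for one test case might look like this. The function name `minCouponCost`, its signature, and the 1-based `toll` vector with an unused slot 0 are assumptions for illustration. The recursion mirrors the description above, so very deep trees may need an iterative traversal instead.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical helper: cities are 1-based as in the statement, toll[0] is unused.
long long minCouponCost(int n, int x, long long K,
                        const std::vector<long long>& toll,
                        const std::vector<std::pair<int, int>>& edges) {
    assert(static_cast<int>(toll.size()) == n + 1 && 1 <= x && x <= n);
    std::vector<std::vector<int>> adj(n + 1);
    for (const auto& e : edges) {
        adj[e.first].push_back(e.second);
        adj[e.second].push_back(e.first);
    }
    long long saved = 0;  // total toll removed by coupons, counted for one direction
    // Returns V_u and adds min(toll_u, max(V_u - K, 0)) to `saved`.
    std::function<long long(int, int)> dfs = [&](int u, int p) -> long long {
        long long best = 0;
        for (int v : adj[u])
            if (v != p) best = std::max(best, dfs(v, u));
        long long V = toll[u] + best;
        saved += std::min(toll[u], std::max(V - K, 0LL));
        return V;
    };
    dfs(x, 0);         // 0 is not a valid city, so it works as "no parent"
    return 2 * saved;  // doubled, as explained above
}
```

For the tree with edges 1-2, 1-3, 2-4, tolls 2, 3, 1, 4, X = 1 and K = 5, the reductions are 2 at city 1 and 2 at city 2, so the function returns 8.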

## You Believe I'm wrong here?

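Some of you might think I'm wrong here, as seemingly I didn't take into account the reduction already applied to the toll of a node's parent. But V_child never included toll_u of the parent u in the first place, which is why we don't need to subtract it. If the parent's toll was reduced only partially (by V_u-K < toll_u), then every child v has V_v <= V_u-toll_u < K, so nothing more is cut below u. If the parent's toll was reduced to zero, the parent no longer contributes anything, and the child's subtree becomes an independent problem with the same limit K.

The editorialist's solution follows this implementation exactly, while the other two implementations differ slightly in their approach, though the key ideas are the same.

Still believe I'm wrong? Drop a comment.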
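For the skeptical, here is a small sketch that computes the per-city reductions in two ways: with the subtree-only rule from the text, and with an ancestor-aware variant written in the spirit of the setter's `dfs1`, which tracks how much was already cut above each node. Both function names and the input format are assumptions for illustration. To keep it short, the tree is assumed to be rooted at node 0 already and given as a parent array with `parent[v] < v`.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Assumption for brevity: parent[0] == -1 and parent[v] < v for every other node.

// Rule from the text: V_u looks only at the subtree of u, ancestors are ignored.
std::vector<long long> cutsSubtreeRule(const std::vector<int>& parent,
                                       const std::vector<long long>& toll, long long K) {
    const int n = static_cast<int>(parent.size());
    assert(static_cast<int>(toll.size()) == n);
    std::vector<long long> best(n, 0), cut(n, 0);
    for (int u = n - 1; u >= 0; --u) {  // children before parents
        long long V = toll[u] + best[u];
        cut[u] = std::min(toll[u], std::max(V - K, 0LL));
        if (parent[u] >= 0) best[parent[u]] = std::max(best[parent[u]], V);
    }
    return cut;
}

// Ancestor-aware variant: mx[u] is the largest root-to-leaf toll passing through u,
// and df[u] is the total amount already cut on the strict ancestors of u.
std::vector<long long> cutsAncestorAware(const std::vector<int>& parent,
                                         const std::vector<long long>& toll, long long K) {
    const int n = static_cast<int>(parent.size());
    assert(static_cast<int>(toll.size()) == n);
    std::vector<long long> pre(n, 0), df(n, 0), cut(n, 0);
    for (int u = 0; u < n; ++u)  // parents before children
        pre[u] = toll[u] + (parent[u] >= 0 ? pre[parent[u]] : 0);
    std::vector<long long> mx = pre;
    for (int u = n - 1; u > 0; --u)  // push the maxima up to the parents
        mx[parent[u]] = std::max(mx[parent[u]], mx[u]);
    for (int u = 0; u < n; ++u) {
        if (parent[u] >= 0) df[u] = df[parent[u]] + cut[parent[u]];
        cut[u] = std::min(toll[u], std::max(mx[u] - K - df[u], 0LL));
    }
    return cut;
}
```

With `parent = {-1, 0, 0, 1}`, tolls 2, 3, 1, 4 and K = 5, both functions return the reductions 2, 2, 0, 0. With K = 3 both return 2, 3, 0, 1.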

Surprise, the problem is solved! Will meet again in the journey of any other problem. xD

# TIME COMPLEXITY

Time complexity is O(N) per test case.

# SOLUTIONS:

## Setter's Solution

```
#include <bits/stdc++.h>
using namespace std;
#define LL long long
#define PB push_back
#define N 300101
#define SZ 3001
#define LB lower_bound
#define M 1000000007
#define UB upper_bound
#define MP make_pair
#define LD long double
#define F first
#define S second
#define endl "\n"
vector<LL> adj[N];
LL val[N],cst[N],mx[N],k,rs;
LL vis[N];
void chk0(LL x)
{
vis[x]=1;
for(auto ch:adj[x])
if(!vis[ch])
chk0(ch);
}
LL dfs(LL x,LL p,LL c)
{
// cout<<x<<" "<<p<<" "<<c<<endl;
c+=val[x];
cst[x]=c;
mx[x]=c;
for(auto ch:adj[x])
if(ch!=p)
mx[x]=max(mx[x],dfs(ch,x,c));
// cout<<x<<" "<<mx[x]<<endl;
return mx[x];
}
void dfs1(LL x,LL p,LL df)
{
LL cr=max(mx[x]-k-df,0ll);
cr=min(cr,val[x]);
rs+=cr;
df+=cr;
for(auto ch:adj[x])
if(ch!=p)
dfs1(ch,x,df);
}
int main()
{
LL i,j,lt,tc,d,r,q,y,z,v,x,m,n,u;
ios_base::sync_with_stdio(false);
cin.tie(NULL);
cout.tie(NULL);
cin>>lt;
while(lt--)
{
rs=0;
cin>>n>>x>>k;
for(i=1;i<=n;i++)
{
adj[i].clear();
cin>>val[i];
}
for(i=1;i<n;i++)
{
cin>>u>>v;
adj[u].PB(v);
adj[v].PB(u);
}
fill(vis,vis+n+1,0);
chk0(x);
for(i=1;i<=n;i++)
assert(vis[i]==1); // every city must be reachable from x
dfs(x,0,0);
dfs1(x,0,0);
cout<<rs*2<<endl;
}
}
```

## Tester's Solution

```
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
#define F first
#define S second
const int MAXN = 1e4 + 10;
int n, root, cost[MAXN];
ll bud, ans;
deque<int> deq;
vector<int> adj[MAXN];
ll dfs(int v, int p = -1, ll sm = 0){
deq.push_back(v);
sm += cost[v];
ll ret = 0;
while (sm > bud){
int u = deq.front();
ll g = min((ll)cost[u], sm-bud);
ans += g<<1;
ret += g;
cost[u] -= g;
sm -= g;
if (!cost[u])
deq.pop_front();
}
for (int u:adj[v])
if (u^p){
ll temp = dfs(u, v, sm);
ret += temp;
sm -= temp;
sm = max(0ll, sm);
}
if (deq.size() && deq.back() == v)
deq.pop_back();
return ret;
}
int main(){
ios::sync_with_stdio(false);
cin.tie(0);
int te; cin >> te;
while (te--){
cin >> n >> root >> bud, root--;
for (int i = 0; i < n; i++) adj[i].clear(), cin >> cost[i];
for (int i = 0; i < n-1; i++){
int a, b; cin >> a >> b, a--, b--;
adj[a].push_back(b);
adj[b].push_back(a);
}
ans = 0;
deq.clear();
dfs(root);
cout << ans << "\n";
}
return 0;
}
```

## Editorialist's Solution

```
import java.util.*;
import java.io.*;
class ABDTOLL{
//SOLUTION BEGIN
void pre() throws Exception{}
void solve(int TC) throws Exception{
int n = ni(), x = ni()-1;long k = nl();
long[] cost = new long[n];
for(int i = 0; i< n; i++)cost[i] = nl();
int[][] e = new int[n-1][];
for(int i = 0; i< n-1; i++)e[i] = new int[]{ni()-1, ni()-1};
int[][] g = makeU(n, e);
long[] sum = new long[n];
fillSum(g, sum, cost, x, -1);
pn(2*calc(g, sum, cost, x, -1, k));
}
long calc(int[][] g, long[] sum, long[] cost, int u, int p, long k){
long cur = Math.min(Math.max(0, sum[u]-k), cost[u]);
long ans = cur;
for(int v:g[u]){
if(v==p)continue;
ans+=calc(g, sum, cost, v, u, k);
}
return ans;
}
void fillSum(int[][] g, long[] sum, long[] cost, int u, int p){
for(int v:g[u]){
if(v==p)continue;
fillSum(g, sum, cost, v, u);
sum[u] = Math.max(sum[u], sum[v]);
}
sum[u] += cost[u];
}
int[][] makeU(int n, int[][] edge){
int[][] g = new int[n][];int[] cnt = new int[n];
for(int i = 0; i< edge.length; i++){cnt[edge[i][0]]++;cnt[edge[i][1]]++;}
for(int i = 0; i< n; i++)g[i] = new int[cnt[i]];
for(int i = 0; i< edge.length; i++){
g[edge[i][0]][--cnt[edge[i][0]]] = edge[i][1];
g[edge[i][1]][--cnt[edge[i][1]]] = edge[i][0];
}
return g;
}
//SOLUTION END
void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
static boolean multipleTC = true;
FastReader in;PrintWriter out;
void run() throws Exception{
in = new FastReader();
out = new PrintWriter(System.out);
//Solution Credits: Taranpreet Singh
int T = (multipleTC)?ni():1;
pre();for(int t = 1; t<= T; t++)solve(t);
out.flush();
out.close();
}
public static void main(String[] args) throws Exception{
new ABDTOLL().run();
}
int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
void p(Object o){out.print(o);}
void pn(Object o){out.println(o);}
void pni(Object o){out.println(o);out.flush();}
String n()throws Exception{return in.next();}
String nln()throws Exception{return in.nextLine();}
int ni()throws Exception{return Integer.parseInt(in.next());}
long nl()throws Exception{return Long.parseLong(in.next());}
double nd()throws Exception{return Double.parseDouble(in.next());}
class FastReader{
BufferedReader br;
StringTokenizer st;
public FastReader(){
br = new BufferedReader(new InputStreamReader(System.in));
}
public FastReader(String s) throws Exception{
br = new BufferedReader(new FileReader(s));
}
String next() throws Exception{
while (st == null || !st.hasMoreElements()){
try{
st = new StringTokenizer(br.readLine());
}catch (IOException e){
throw new Exception(e.toString());
}
}
return st.nextToken();
}
String nextLine() throws Exception{
String str = "";
try{
str = br.readLine();
}catch (IOException e){
throw new Exception(e.toString());
}
return str;
}
}
}
```

Feel free to share your approach if you want to (even if it's the same). Suggestions are welcome, as always.