POWTREE - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: poetic_soul
Testers: nishant403, satyam_343
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

DFS, binary search

PROBLEM:

You have a tree on N vertices rooted at vertex 1; the i-th vertex has value A_i.
The power of a vertex is the maximum value present in its subtree.

In one operation, you can pick i and increase A_i by 1.
Find the minimum number of operations required to make the sum of powers of all the vertices at least X.

EXPLANATION:

Let p_i denote the parent of node i and B_i denote the power of node i.
Computing the initial values of B_i can be done with a single dfs.
Notice that B_{p_i} \geq B_i for every i \gt 1.
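
As a rough illustration of that single DFS, here is a minimal Python sketch. It assumes 0-indexed vertices with the root at 0 and an adjacency list `adj`; the name `compute_powers` is made up for this example. The explicit stack is a choice made here to avoid Python's recursion limit, not something the problem requires.

```python
def compute_powers(adj, a, root=0):
    """Return B, where B[u] is the maximum of a over the subtree of u."""
    n = len(a)
    parent = [-1] * n
    order = []          # vertices in DFS preorder
    stack = [root]
    while stack:
        u = stack.pop()
        order.append(u)
        for v in adj[u]:
            if v != parent[u]:
                parent[v] = u
                stack.append(v)

    power = list(a)
    # Every vertex appears after its parent in preorder, so scanning the
    # order backwards finalises a vertex before it is pushed into its parent.
    for u in reversed(order):
        if parent[u] != -1:
            power[parent[u]] = max(power[parent[u]], power[u])
    return power
```

On the tree with edges 0-1, 0-2, 1-3 and values [1, 5, 2, 3], this gives powers [5, 5, 2, 3], which also shows B_{p_i} \geq B_i along every edge.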

Suppose we decide that we’re going to increase the value of node u.
This can only affect the powers of u itself and of its ancestors.

In particular, if the ancestors of u are (in order from the root) 1 = v_1, v_2, \ldots, v_k = u, then the powers of some suffix of these ancestors will increase, while the rest will remain the same.
More specifically, suppose we increase A_u by m. Then,

  • Let i be the smallest index such that B_{v_i} \lt A_u + m.
    • Since B_{p_i} \geq B_i, the values B_{v_1}, \ldots, B_{v_k} are in non-increasing order, so this i can be found in \mathcal{O}(\log N) with binary search.
  • Then, B_{v_i}, B_{v_{i+1}}, \ldots, B_{v_k} will all be set to A_u+m, while the other vertices’ powers will remain unchanged.
  • The increase in the sum of powers can be computed in \mathcal{O}(1) if we had prefix sums of the B_{v_i}.
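
The steps above can be sketched as a small Python helper. This is only an illustration under assumptions: `path` is the list of powers B_{v_1}, \ldots, B_{v_k} from the root down to u, `pref[j]` is the sum of the first j entries of `path`, and `value` stands for A_u + m. The names `gain`, `path` and `pref` are hypothetical.

```python
def gain(path, pref, value):
    """Increase in the sum of powers if the vertex at the end of path gets `value`.

    path is non-increasing; pref[j] == sum(path[:j]).
    """
    k = len(path)
    # Binary search for the first index whose power is strictly below value.
    lo, hi = 0, k
    while lo < hi:
        mid = (lo + hi) // 2
        if path[mid] < value:
            hi = mid
        else:
            lo = mid + 1
    # path[lo:] all become `value`; subtract what they summed to before.
    return (k - lo) * value - (pref[k] - pref[lo])
```

For example, with path = [9, 7, 7, 4] and value = 8, the last three powers rise to 8, for a gain of 1 + 1 + 4 = 6.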

This analysis gives us one important observation required to solve this problem: the fact that it’s enough to perform all our operations on a single node.

Proof

Consider a solution where we increase both A_u and A_v.
Suppose their current values are C_u = A_u + r and C_v = A_v + s.

Now,

  • Let x be the increase in the sum of powers if we do C_u \to C_u+1
  • Let y be the increase in the sum of powers if we do C_v \to C_v+1.

Without loss of generality, let x \geq y.
Notice that the s moves done on v so far increased the sum of powers by at most s\cdot y in total: the number of vertices affected by a move on v never decreases as its value grows, so each earlier move affected the power of at most y vertices.

Let’s undo these s moves and perform s moves on u instead.
Each such move increases the sum of powers by at least x, for the same reason, giving us a total increase of at least x\cdot s.
Since x\geq y, we have at least the same sum of powers, using exactly the same number of moves.

Repeating this process over and over again allows us to reach a solution where only one vertex has its value increased, as claimed.
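
If you'd like to convince yourself of the claim experimentally, a brute force over tiny trees is one option. The sketch below is not part of any reference solution; it assumes 0-indexed vertices with parent[v] < v for every non-root v, and the function names are invented for this example. It tries every way to distribute k operations and compares the best result with the best result when all k operations go to one vertex.

```python
from itertools import combinations_with_replacement


def sum_of_powers(parent, a):
    # parent[v] < v is assumed, so going from high to low index
    # finalises each vertex before it is folded into its parent.
    power = list(a)
    for v in range(len(a) - 1, 0, -1):
        power[parent[v]] = max(power[parent[v]], power[v])
    return sum(power)


def best_sum(parent, a, k, single_vertex_only):
    """Largest sum of powers reachable with exactly k operations."""
    best = 0
    for chosen in combinations_with_replacement(range(len(a)), k):
        if single_vertex_only and len(set(chosen)) > 1:
            continue
        b = list(a)
        for v in chosen:
            b[v] += 1
        best = max(best, sum_of_powers(parent, b))
    return best
```

Comparing `best_sum(parent, a, k, False)` with `best_sum(parent, a, k, True)` for small k should give equal values on any small tree you try.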

Now that we know only one vertex needs its value to be increased, let’s see if we can figure out how much to increase it by.
Clearly, the larger m is, the larger the increase in the sum of powers.
If we fix m and u, as noted above the increase can be computed in \mathcal{O}(\log N) (if we had appropriate prefix sums present).
So, for a fixed u, we can find the smallest m for which the sum of powers becomes at least X by binary searching on m.

This allows us to solve for a single vertex in \mathcal{O}(\log N\log X), provided we have a list of its ancestors and the corresponding prefix sums.
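
A minimal sketch of this per-vertex binary search follows, assuming we are handed the powers on the root-to-u path. The names `min_ops_for_vertex`, `path`, `a_u` and `need` (the amount still missing, X minus the current sum of powers) are assumptions of this example. The upper bound used for m is one safe choice: raising u to B_1 + need lifts the root's power alone by `need`.

```python
def min_ops_for_vertex(path, a_u, need):
    """Smallest m >= 0 such that raising A_u to a_u + m gains at least `need`.

    path: powers on the root -> u path (non-increasing, path[0] is the root's).
    """
    k = len(path)
    pref = [0]
    for b in path:
        pref.append(pref[-1] + b)

    def gain(value):
        # First index whose power is strictly below value.
        lo, hi = 0, k
        while lo < hi:
            mid = (lo + hi) // 2
            if path[mid] < value:
                hi = mid
            else:
                lo = mid + 1
        return (k - lo) * value - (pref[k] - pref[lo])

    # gain is non-decreasing in m, so binary search on m is valid.
    lo, hi = 0, path[0] - a_u + max(need, 0)
    while lo < hi:
        mid = (lo + hi) // 2
        if gain(a_u + mid) >= need:
            hi = mid
        else:
            lo = mid + 1
    return lo
```

With path = [9, 7, 7, 4], a_u = 4 and need = 6, three operations only gain 3 (the leaf goes from 4 to 7), while four operations gain 1 + 1 + 4 = 6, so the result is 4.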

Building this list from scratch for each vertex would take \mathcal{O}(N) time per vertex, i.e. \mathcal{O}(N^2) overall in the worst case, which is too slow.
Instead, we can be a bit more clever.

Let’s maintain two stacks: the list of ancestors from the root, and the corresponding prefix sums.
Then, start a DFS from vertex 1, doing the following:

  • When you enter vertex u, push u onto the first stack and the appropriate prefix sum onto the second (that is, B_u plus the previous top of the prefix-sum stack, or just B_u if that stack is empty).
  • Now the two stacks contain exactly the lists we need to solve for u, so do that in \mathcal{O}(\log N\log X) as described above.
  • Then, DFS into all the children of u, thereby solving for those as well.
  • Finally, when leaving u, pop both stacks once.

This procedure ensures that whenever we’re at a vertex u, the two stacks contain exactly the information we need.
We do a total of N pushes and pops each, so this part is \mathcal{O}(N).

The algorithm as a whole thus takes \mathcal{O}(N\log N\log X) time.
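
Putting the pieces together, the whole procedure might look like the following Python sketch. It is an illustration rather than any of the reference solutions: vertices are assumed 0-indexed with root 0, `edges` is a list of pairs, and the names `gain` and `min_operations` are invented here. The DFS is iterative, with explicit enter/exit events standing in for the push and pop steps described above.

```python
def gain(path, pref, value):
    # path is non-increasing; find the first index with power < value.
    lo, hi = 0, len(path)
    while lo < hi:
        mid = (lo + hi) // 2
        if path[mid] < value:
            hi = mid
        else:
            lo = mid + 1
    return (len(path) - lo) * value - (pref[-1] - pref[lo])


def min_operations(n, x, a, edges):
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    # First pass: parents and subtree maxima (the powers B).
    parent = [-1] * n
    order, stack = [], [0]
    while stack:
        u = stack.pop()
        order.append(u)
        for v in adj[u]:
            if v != parent[u]:
                parent[v] = u
                stack.append(v)
    power = list(a)
    for u in reversed(order):
        if parent[u] != -1:
            power[parent[u]] = max(power[parent[u]], power[u])

    need = x - sum(power)
    if need <= 0:
        return 0

    # Second pass: keep the root -> u powers and their prefix sums on stacks.
    path, pref = [], [0]
    best = None
    events = [(0, True)]            # (vertex, True = entering / False = leaving)
    while events:
        u, entering = events.pop()
        if not entering:
            path.pop()
            pref.pop()
            continue
        path.append(power[u])
        pref.append(pref[-1] + power[u])
        events.append((u, False))   # popped only after all children are done
        for v in adj[u]:
            if v != parent[u]:
                events.append((v, True))
        # Binary search on the number of operations m spent on u.
        lo, hi = 0, power[0] - a[u] + need
        while lo < hi:
            mid = (lo + hi) // 2
            if gain(path, pref, a[u] + mid) >= need:
                hi = mid
            else:
                lo = mid + 1
        best = lo if best is None else min(best, lo)
    return best
```

For instance, on a root with two leaf children, values [1, 2, 3] and X = 12, the powers are [3, 2, 3] with sum 8; raising the leaf with value 3 to 5 lifts both it and the root, so two operations suffice.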

TIME COMPLEXITY:

\mathcal{O}(N\log N\log X) per testcase.

CODE:

Setter's code (C++)
#include <bits/stdc++.h>
using namespace std;

#define ll                   long long
#define vi vector<int>
#define pb push_back
#define allrev(v)            v.rbegin(), v.rend()
#define lb                   lower_bound
const int mod = 1e9+7;
const int N = 1e5+5;

vi g[N];
void dfs(int u,int p,vi &ass, vi&a)
{
	ass[u] = a[u];
	for(int v:g[u])
	{
		if(v!=p)
		{
			dfs(v,u,ass,a);
			ass[u] = max(ass[u],ass[v]);
		}
	}
}

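// For every vertex u, computes b[u] = smallest final value of a[u] whose gain is at least rem.
// cur holds the powers on the root -> u path (non-increasing); pref holds their prefix sums.
void dfs2(int u,int p,vi &ass,vector<ll> &b, vector<ll> &cur,vector<ll> &pref,ll rem)
{
	cur.pb(ass[u]);
	pref.pb(ass[u]);
	pref.back() += pref[pref.size()-2];
	// Binary search on the final value: l is never enough, r is always enough.
	ll l = ass[u],r = 1e12;
	while(l+1<r)
	{
		ll mid = (l+r)/2;
		// ind = number of path vertices whose power is below mid (a suffix of cur)
		ll ind = lb(allrev(cur),mid)-cur.rbegin();
		// gain = ind*mid minus the current sum of those ind powers
		ll sc = ind*mid-(pref[pref.size()-1]-pref[pref.size()-1-ind]);
		if(sc>=rem)r = mid;
		else l = mid;
	}
	b[u] = r;
	for(int v:g[u])
	{
		if(v!=p)
		{
			dfs2(v,u,ass,b,cur,pref,rem);
		}
	}
	// Leaving u: restore both stacks.
	cur.pop_back();
	pref.pop_back();
}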

void solve()
{
    ll n,x;
    cin >> n >> x;
    for(int i = 0; i<=n; i++)g[i].clear();
    vi a(n+1);
    
    for(int i = 1; i<=n; i++)cin >> a[i];
    
    for(int i = 1; i<n; i++)
    {
		int u,v;
		cin >> u >> v;
		g[u].pb(v);
		g[v].pb(u);
	}
	vi ass(n+1,0);
	dfs(1,-1,ass,a);
	ll sum = 0;
	for(int i:ass)sum += i;
	if(sum>=x)
	{
		cout << "0\n";
		return;
	}
	vector<ll> b(n+1,-1);
	vector<ll> cur,pref;
	pref.pb(0);
	dfs2(1,-1,ass,b,cur,pref,-sum+x);
	ll ans = 1e12;
	for(int i = 1; i<=n; i++)
	{
	    ans = min(ans,b[i]-a[i]);
	    assert(b[i]!=-1);
	}
	cout << ans << "\n";
}


int main()
{
    ios_base::sync_with_stdio(0); cin.tie(0);cout.tie(0);
    int t = 1;
     cin>> t;
    for(int i = 1; i<=t; i++) {
       solve();



   }
   return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;

/*
------------------------Input Checker----------------------------------
*/

long long readInt(long long l, long long r, char endd)
{
    long long x = 0;
    int cnt = 0;
    int fi = -1;
    bool is_neg = false;
    while (true)
    {
        char g = getchar();
        if (g == '-')
        {
            assert(fi == -1);
            is_neg = true;
            continue;
        }
        if ('0' <= g && g <= '9')
        {
            x *= 10;
            x += g - '0';
            if (cnt == 0)
            {
                fi = g - '0';
            }
            cnt++;
            assert(fi != 0 || cnt == 1);
            assert(fi != 0 || is_neg == false);

            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
        }
        else if (g == endd)
        {
            if (is_neg)
            {
                x = -x;
            }

            if (!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }

            return x;
        }
        else
        {
            assert(false);
        }
    }
}
string readString(int l, int r, char endd)
{
    string ret = "";
    int cnt = 0;
    while (true)
    {
        char g = getchar();
        assert(g != -1);
        if (g == endd)
        {
            break;
        }
        cnt++;
        ret += g;
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}
long long readIntSp(long long l, long long r)
{
    return readInt(l, r, ' ');
}
long long readIntLn(long long l, long long r)
{
    return readInt(l, r, '\n');
}
string readStringLn(int l, int r)
{
    return readString(l, r, '\n');
}
string readStringSp(int l, int r)
{
    return readString(l, r, ' ');
}

/*
------------------------Main code starts here----------------------------------
*/

#define int long long

struct FenwickTree {
    vector<int> bit;  // binary indexed tree
    int n;

    FenwickTree(int n) {
        this->n = n;
        bit.assign(n, 0);
    }

    FenwickTree(vector<int> a) : FenwickTree(a.size()) {
        for (size_t i = 0; i < a.size(); i++)
            add(i, a[i]);
    }

    int sum(int r) {
        int ret = 0;
        for (; r >= 0; r = (r & (r + 1)) - 1)
            ret += bit[r];
        return ret;
    }

    int sum(int l, int r) {
        if(l > r) return 0;
        return sum(r) - sum(l - 1);
    }

    void add(int idx, int delta) {
        for (; idx < n; idx = idx | (idx + 1))
            bit[idx] += delta;
    }
};

const int MAX_T = 1e4;
const int MAX_N = 1e5;
const int MAX_SUM_N = 1e5;
const int MAX_X = 1e12;
const int MAX_VAL = 1e6;

#define fast                      \
    ios_base::sync_with_stdio(0); \
    cin.tie(0);                   \
    cout.tie(0)

int sum_n = 0;
int max_n = 0;
int max_x = 0;
int max_ans = 0;

int n,x;
int a[MAX_N + 1];
vector<int> g[MAX_N + 1];
int vis[MAX_N + 1];
int mx[MAX_N + 1];
int ans;

FenwickTree a_sum(MAX_VAL + 1),a_cnt(MAX_VAL + 1);

void dfs(int node,int par = -1) {
    vis[node] = 1;
    mx[node] = a[node];
    
    for(auto x : g[node]) {
        if(x != par) {
            assert(vis[x] == false);
            dfs(x,node);
            mx[node] = max(mx[node],mx[x]);
        }
    }
    
    x -= mx[node];
}

void dfs2(int node,int par = -1) {
    
    a_cnt.add(mx[node],1);
    a_sum.add(mx[node],mx[node]);
    
    //calculate result for increasing current node
    int l = 0;
    int r = max(l,x);
    int res = r;
    
    while(l <= r) {
        int mid = (l + r)/2;
        
        //if we perform mid operations on current node, what will be final sum of value
        int sum_val = 0;
        
        sum_val += a_cnt.sum(1,min(MAX_VAL,a[node] + mid - 1)) * (a[node] + mid);
        sum_val -= a_sum.sum(1,min(MAX_VAL,a[node] + mid - 1));
        
        if(sum_val >= x) {
            res = mid;
            r = mid - 1;
        } else {
            l = mid + 1;
        }
    }
    
    ans = min(ans , res);
    
    for(auto x : g[node]) {
        if(x != par) {
            dfs2(x,node);
        }
    }
    
    a_cnt.add(mx[node],-1);
    a_sum.add(mx[node],-mx[node]);
}

void solve()
{
    n = readIntSp(1, MAX_N);
    max_n = max(max_n, n);
    sum_n += n;
    assert(sum_n <= MAX_SUM_N);

    x = readIntLn(1,MAX_X);
    ans = x;
    max_x = max(max_x, x);

    for (int i = 1; i <= n; i++)
    {
        if (i != n)
        {
            a[i] = readIntSp(1, MAX_VAL);
        }
        else
        {
            a[i] = readIntLn(1, MAX_VAL);
        }
    }
    
    for(int i = 1; i < n; i++) {
        int u,v;
        u = readIntSp(1 , n);
        v = readIntLn(1 , n);
        assert(u != v);
        g[u].push_back(v);
        g[v].push_back(u);
    }
    
    // Observations : 
    // It is optimal to increase only one node
    // More operations we perform in this node , More answer will be , More will be the value of this node , we can bsearch on these

    dfs(1);
    dfs2(1); 
    
    max_ans = max(max_ans, ans);

    cout << ans << '\n';
    
    for(int i=1;i<=n;i++) {
        vis[i] = 0;
        g[i].clear();
        mx[i] = 0;
    }
}

signed main()
{
    int t = 1;
    t = readIntLn(1, MAX_T);

    for (int i = 1; i <= t; i++)
    {
        solve();
    }

    assert(getchar() == -1);

    cerr << "SUCCESS\n";
    cerr << "Tests : " << t << '\n';
    cerr << "Sum of N : " << sum_n << '\n';
    cerr << "Maximum N : " << max_n << '\n';
    cerr << "Maximum X : " << max_x << '\n';
    cerr << "Maximum answer : " << max_ans << '\n';
}
Editorialist's code (Python)
def calc(tree, a):
	stack = [0]
	dp = [0]*len(a)
	visited = [0]*len(a)
	while stack:
		u = stack[-1]
		if visited[u] == 0:
			visited[u] = 1
			for v in tree[u]:
				if visited[v]: continue
				stack.append(v)
		else:
			dp[u] = a[u]
			for v in tree[u]:
				if visited[v] == 2:
					dp[u] = max(dp[u], dp[v])
			visited[u] = 2
			stack.pop()
	return dp

for _ in range(int(input())):
	n, x = map(int, input().split())
	a = list(map(int, input().split()))
	tree = [[] for _ in range(n)]
	for i in range(n-1):
		u, v = map(int, input().split())
		tree[u - 1].append(v - 1)
		tree[v - 1].append(u - 1)
	vals = calc(tree, a)
	req = max(0, x - sum(vals))
	
	stack = [0]
	pref = [vals[0]]
	ans = 10**18
	while stack:
		u = stack[-1]
		# Push children onto stack
		if tree[u]: 
			v = tree[u][-1]
			tree[u].pop()
			if u > 0 and stack[-2] == v: continue
			stack.append(v)
			pref.append(pref[-1] + vals[v])
		# No children left, stack contains root -> u path
		else:
			sz = len(stack)
			def find_index(x):
				# Find smallest i such that vals[stack[i]] <= x
				lo, hi = 0, sz
				while lo < hi:
					mid = (lo + hi)//2
					if vals[stack[mid]] <= x: hi = mid
					else: lo = mid+1
				return lo
			
			lo, hi = 0, vals[0] + req
			while lo < hi:
				mid = (lo + hi)//2
				
				# Find first element on stack that's <= a[u] + mid
				# Use prefix sums to figure out how much is added
				idx = find_index(a[u] + mid)
				if idx == sz:
					lo = mid+1
				else:
					increase = pref[-1]
					if idx > 0: increase -= pref[idx-1]
					increase = (sz-idx) * (a[u] + mid) - increase
					if increase >= req: hi = mid
					else: lo = mid+1
			ans = min(ans, lo)
			stack.pop()
			pref.pop()
	print(ans)