CSED - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: nirbhaypaliwal
Tester: wasd2401
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

DFS, dynamic programming

PROBLEM:

There’s a tree with N nodes rooted at node 1. The i-th node initially has value P_i.
You start at node 1, and repeatedly do the following as many times as you like:

  • Move from the current node to one of its neighbors (its parent or one of its children)
  • If you’re visiting some node u for the first time, add P_u to your score.

However, the tree is cursed: whenever you move from a vertex to one of its children c, the value P_w of every vertex w in the subtree of c (including c itself) decreases by 1.

This curse can be relaxed for exactly one node.
For each u = 2, 3, 4, \ldots, N, find the maximum possible score if the curse is relaxed for node u.

EXPLANATION:

First, let’s solve a slightly easier version of this task: suppose the curse can’t be removed at all. What’s the maximum score possible?

This can be computed with the help of dynamic programming.
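First, observe that the effective value of node u is P_u - \text{dep}_u, where \text{dep}_u denotes the number of edges on the 1\to u path.
To see why, note that each of these edges must be traversed downwards before u is first reached, and each such traversal decreases P_u by 1.
Traversing an edge downwards a second time can only lower the values of unvisited nodes, so an optimal walk never needs to.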

Now, let dp[u] denote the answer for the subtree of u (that is, we start at u and only move within its subtree).
It can be seen that:

  • We definitely get a value of P_u - \text{dep}_u for starting at u.
  • Then, for each child v of u, we can get a value of dp[v] by going into it.
    However, if dp[v] is negative, it’s optimal to just not go into the subtree of v at all.
  • So, dp[u] = P_u - \text{dep}_u + \sum_v \max(0, dp[v]).
    The summation is across all children v of u.

This can be computed in \mathcal{O}(N) time for all vertices, with a DFS.
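
The recurrence above can be sketched as a short recursive DFS. This is only an illustrative sketch: the function name subtreeBest, the 1-indexed value array p, and the edge-list input are assumptions made here, not part of the reference solutions below.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical helper (not from the reference solutions): returns dp[1..n]
// for a tree rooted at node 1. p is 1-indexed, so p[0] is unused.
std::vector<long long> subtreeBest(
    int n, const std::vector<long long>& p,
    const std::vector<std::pair<int, int>>& edges) {
    assert(n >= 1 && static_cast<int>(p.size()) == n + 1);
    std::vector<std::vector<int>> adj(n + 1);
    for (const auto& e : edges) {
        adj[e.first].push_back(e.second);
        adj[e.second].push_back(e.first);
    }
    std::vector<long long> dp(n + 1, 0);
    // Recursive DFS carrying the current depth; an explicit stack may be
    // safer for very deep trees.
    auto dfs = [&](auto&& self, int u, int parent, long long dep) -> void {
        dp[u] = p[u] - dep;  // effective value of u
        for (int v : adj[u]) {
            if (v == parent) continue;
            self(self, v, u, dep + 1);
            dp[u] += std::max(0LL, dp[v]);  // enter v only if it pays off
        }
    };
    dfs(dfs, 1, 0, 0);
    return dp;
}
```

For instance, on the 5-node tree with edges 1-2, 1-3, 2-4, 2-5 and P = (5, 1, 0, 3, 1), the effective values are (5, 0, -1, 1, -1), so dp[2] = 0 + 1 + 0 = 1 and dp[1] = 5 + 1 + 0 = 6.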


Now, suppose vertex u is uncursed.
This effectively increases the value of every node within the subtree of u by 1 (entering u no longer triggers a decrease), which might change which subtrees are worth entering.

Let’s store this information also with dynamic programming.
Let dp_2[u] denote the answer for the subtree of u, if u is uncursed.
Then, similar reasoning as the first part tells us that

dp_2[u] = P_u - \text{dep}_u + 1 + \sum_v \max(0, dp_2[v])

Note that we take dp_2[v] in the summation because uncursing u modifies its entire subtree; so when moving to a child we can pretend the child is itself uncursed (since the effect on its subtree is the same).
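
As a rough sketch, dp_2 can be computed with the same kind of DFS, just with the extra +1 at every node. The name uncursedBest and the input format are assumptions made for illustration. The sketch applies the formula to the root as well, even though the root's entry is never needed for the answers.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical helper: returns dp2[1..n], where dp2[u] is the best score
// inside the subtree of u when every node there is worth 1 more than in
// the fully cursed case. p is 1-indexed, so p[0] is unused.
std::vector<long long> uncursedBest(
    int n, const std::vector<long long>& p,
    const std::vector<std::pair<int, int>>& edges) {
    assert(n >= 1 && static_cast<int>(p.size()) == n + 1);
    std::vector<std::vector<int>> adj(n + 1);
    for (const auto& e : edges) {
        adj[e.first].push_back(e.second);
        adj[e.second].push_back(e.first);
    }
    std::vector<long long> dp2(n + 1, 0);
    auto dfs = [&](auto&& self, int u, int parent, long long dep) -> void {
        dp2[u] = p[u] - dep + 1;  // effective value of u, plus 1
        for (int v : adj[u]) {
            if (v == parent) continue;
            self(self, v, u, dep + 1);
            // The child is treated as uncursed too: the +1 reaches all of it.
            dp2[u] += std::max(0LL, dp2[v]);
        }
    };
    dfs(dfs, 1, 0, 0);
    return dp2;
}
```

On the tree with edges 1-2, 1-3, 2-4, 2-5 and P = (5, 1, 0, 3, 1), this gives dp_2[4] = 2, dp_2[5] = 0 and dp_2[2] = 1 + 2 + 0 = 3.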


Finally, let’s compute the actual answer for a vertex u, given that we start at 1.
Let the path from 1 to u be 1 \to x_1 \to x_2 \to\ldots\to x_k \to u.
There are two possibilities for the optimal solution when u is uncursed: either we visit u at some point, or we never visit it.

If we never visit u, the value obtained is just dp[1] — the answer for the whole tree assuming every node is still cursed.
That’s because:

  • If dp[1] was obtained by never visiting u, clearly it is still the answer.
  • If dp[1] did visit u, then if u is uncursed we can only get an even higher value (meaning we move to the case where u is visited, and take that as the answer).

That leaves us with the case when the optimal solution includes visiting u.
Notice that the optimal sequence of visits will look as follows:

  • Start at 1, getting a value of P_1.
  • For each child v of 1 other than x_1, get a value of \max(0, dp[v]) by visiting them.
  • Then, move to x_1, getting a value of P_{x_1} - 1.
  • Again, for each child of x_1 other than x_2, get \max(0, dp[v]).
  • Then move to x_2, and so on.
    \vdots
  • Finally, when we reach u, the best we can do is dp_2[u].
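
The walk described above can be sketched directly for one fixed u by climbing from u to the root and adding up what each ancestor contributes. The function name answerForNode and the input format are assumptions made for illustration; this is the slow per-node version, not the final algorithm.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical helper: answer for a single node u (2 <= u <= n) in O(N).
// p is 1-indexed, so p[0] is unused.
long long answerForNode(int n, const std::vector<long long>& p,
                        const std::vector<std::pair<int, int>>& edges, int u) {
    assert(2 <= u && u <= n && static_cast<int>(p.size()) == n + 1);
    std::vector<std::vector<int>> adj(n + 1);
    for (const auto& e : edges) {
        adj[e.first].push_back(e.second);
        adj[e.second].push_back(e.first);
    }
    std::vector<long long> dp(n + 1), dp2(n + 1), dep(n + 1);
    std::vector<int> par(n + 1, 0);
    auto dfs = [&](auto&& self, int node, int parent, long long d) -> void {
        par[node] = parent;
        dep[node] = d;
        dp[node] = p[node] - d;       // fully cursed
        dp2[node] = p[node] - d + 1;  // node is uncursed
        for (int v : adj[node]) {
            if (v == parent) continue;
            self(self, v, node, d + 1);
            dp[node] += std::max(0LL, dp[v]);
            dp2[node] += std::max(0LL, dp2[v]);
        }
    };
    dfs(dfs, 1, 0, 0);

    long long total = dp2[u];  // best we can do once we reach u
    for (int cur = u; cur != 1; cur = par[cur]) {
        int x = par[cur];
        total += p[x] - dep[x];  // value of the ancestor itself
        for (int c : adj[x]) {
            if (c == par[x] || c == cur) continue;
            total += std::max(0LL, dp[c]);  // side trips off the path
        }
    }
    return std::max(dp[1], total);  // we may also choose not to visit u
}
```

On the tree with edges 1-2, 1-3, 2-4, 2-5 and P = (5, 1, 0, 3, 1), the answer for u = 4 is 2 (from dp_2[4]) + 0 (node 2, with nothing gained from node 5) + 5 (node 1, with nothing gained from node 3) = 7.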

This is fairly easy to compute in \mathcal{O}(N) time for a single fixed u, but doing so separately for every u takes \mathcal{O}(N^2) time overall, which is too slow.

To optimize it, we can simply send information downwards when performing our DFS.
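That is, let S be the total value collected before entering the current vertex: the effective values of its proper ancestors, plus the ‘extra’ value obtained by going off of our path into their other subtrees (so S = 0 at the root).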
Then, when at a vertex u, do the following:

  • Set \text{ans}_u = \max(dp[1], dp_2[u] + S).
  • For each child v of u, when recursing into v, update S as follows:
    • Add P_u - \text{dep}_u to S.
    • Also add \sum_{x\neq v} \max(0, dp[x]) to S, to account for everything that’s ‘away’ from the path.
    • To do this quickly, notice that the total added value simply equals dp[u] - \max(0, dp[v]): after all, that’s the formula dp[u] was calculated with in the first place.
      This allows us to transition to a child of u in constant time.

A single DFS now lets us find the answer for all N vertices, and we’re done!
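
Putting the pieces together, the whole procedure might look like the sketch below: one DFS for dp and dp_2, then a second DFS that pushes S downwards. The function name solveAll and the input format are assumptions made for illustration; the reference implementations follow in the CODE section.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical helper: returns the answers for u = 2, 3, ..., n in order.
// p is 1-indexed, so p[0] is unused.
std::vector<long long> solveAll(
    int n, const std::vector<long long>& p,
    const std::vector<std::pair<int, int>>& edges) {
    assert(n >= 2 && static_cast<int>(p.size()) == n + 1);
    std::vector<std::vector<int>> adj(n + 1);
    for (const auto& e : edges) {
        adj[e.first].push_back(e.second);
        adj[e.second].push_back(e.first);
    }
    std::vector<long long> dp(n + 1), dp2(n + 1), ans(n + 1, 0);
    // Pass 1: dp (fully cursed) and dp2 (node itself uncursed).
    auto dfs = [&](auto&& self, int u, int parent, long long dep) -> void {
        dp[u] = p[u] - dep;
        dp2[u] = p[u] - dep + 1;
        for (int v : adj[u]) {
            if (v == parent) continue;
            self(self, v, u, dep + 1);
            dp[u] += std::max(0LL, dp[v]);
            dp2[u] += std::max(0LL, dp2[v]);
        }
    };
    // Pass 2: s is the value collected before entering u (0 at the root).
    auto push = [&](auto&& self, int u, int parent, long long s) -> void {
        ans[u] = std::max(dp[1], dp2[u] + s);
        for (int v : adj[u]) {
            if (v == parent) continue;
            // Everything in dp[u] except what child v itself contributed.
            self(self, v, u, s + dp[u] - std::max(0LL, dp[v]));
        }
    };
    dfs(dfs, 1, 0, 0);
    push(push, 1, 0, 0);
    return std::vector<long long>(ans.begin() + 2, ans.end());
}
```

On the tree with edges 1-2, 1-3, 2-4, 2-5 and P = (5, 1, 0, 3, 1), this produces 8 6 7 6 for u = 2, 3, 4, 5. Both passes use recursion here for brevity; an explicit stack may be needed if the tree can be very deep.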

TIME COMPLEXITY:

\mathcal{O}(N) per testcase.

CODE:

Tester's code (C++)
/*

*       *  *  ***       *       *****
 *     *   *  *  *     * *        *
  *   *    *  ***     *****       *
   * *     *  * *    *     *   *  *
    *      *  *  *  *       *   **

                                 *
                                * *
                               *****
                              *     *
        *****                *       *
      _*     *_
     | * * * * |                ***
     |_*  _  *_|               *   *
       *     *                 *  
        *****                  *  **
       *     *                  ***
  {===*       *===}
      *  IS   *                 ***
      *  IT   *                *   *
      * RATED?*                *  
      *       *                *  **
      *       *                 ***
       *     *
        *****                  *   *
                               *   *
                               *   *
                               *   *
                                ***   

*/

#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

using namespace __gnu_pbds;
using namespace std;

#define osl tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update>
#define ll long long
#define ld long double
#define forl(i, a, b) for(ll i = a; i < b; i++)
#define rofl(i, a, b) for(ll i = a; i > b; i--)
#define fors(i, a, b, c) for(ll i = a; i < b; i += c)
#define fora(x, v) for(auto x : v)
#define vl vector<ll>
#define vb vector<bool>
#define pub push_back
#define pob pop_back
#define fbo find_by_order
#define ook order_of_key
#define yesno(x) cout << ((x) ? "YES" : "NO")
#define all(v) v.begin(), v.end()

const ll N = 3e5 + 4;
const ll mod = 1e9 + 7;
// const ll mod = 998244353;

vl v[N];
vl depth(N);
vl a(N),a1(N),a2(N);
vl b(N,1);
vl v1(N);
ll modinverse(ll a) {
	ll m = mod, y = 0, x = 1;
	while (a > 1) {
		ll q = a / m;
		ll t = m;
		m = a % m;
		a = t;
		t = y;
		y = x - q * y;
		x = t;
	}
	if (x < 0) x += mod;
	return x;
}
ll gcd(ll a, ll b) {
	if (b == 0)
		return a;
	return gcd(b, a % b);
}
ll lcm(ll a, ll b) {
	return (a / gcd(a, b)) * b;
}
bool poweroftwo(ll n) {
	return !(n & (n - 1));
}
ll power(ll a, ll b, ll md = mod) {
	ll product = 1;
	a %= md;
	while (b) {
		if (b & 1) product = (product * a) % md;
		a = (a * a) % md;
		b /= 2;
	}
	return product % md;
}
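// First DFS: fills depth[] and a1[n] = dp[n], the best score within the
// subtree of n when every node is cursed.
void barfi(ll n){
	a1[n]=a[n]-depth[n];
	fora(x,v[n]){
		if(depth[x]==-1){
			depth[x]=depth[n]+1;
			barfi(x);
			if(a1[x]>0) a1[n]+=a1[x];
		}
	}
}
// cnt = sum of min(0, dp[w]) over the non-root nodes w on the current
// root-to-n path: the loss relative to dp[1] from being forced down this path.
ll cnt;
// Second DFS: fills a2[n] = dp_2[n] and the answer v1[n] for uncursing n.
void kulfi(ll n){
	b[n]=0; // mark n as visited
	if(n>1){
		if(a1[n]<0) cnt+=a1[n];
	}
	a2[n]=a[n]-depth[n]+1;
	fora(x,v[n]){
		if(b[x]){
			kulfi(x);
			if(a2[x]>0) a2[n]+=a2[x];
		}
	}
	if(n>1){
		ll x=a2[n]-a1[n]; // gain inside the subtree of n from uncursing n
		if(cnt+x<=0) v1[n]=a1[1]; // visiting n does not pay off
		else v1[n]=a1[1]+cnt+x;
		if(a1[n]<0) cnt-=a1[n]; // undo this node's contribution on the way up
	}
}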
void panipuri() {
	ll n, m = 0, k = -1, c = 0, sum = 0, q = 0, ans = 0, p = 1;
	string s;
	bool ch = true;
	cin >> n;
	forl(i, 1, n+1) {
		cin >> a[i];
		v[i].clear();
		depth[i]=-1;
		a1[i]=-1e18;
		a2[i]=-1e18;
		b[i]=1;
	}
	forl(i,0,n-1){
		ll x,y;
		cin>>x>>y;
		v[x].pub(y);
		v[y].pub(x);
	}
	cnt=0;
	depth[1]=0;
	barfi(1);
	kulfi(1);
	forl(i,2,n+1) cout<<v1[i]<<' ';
	return;
}
int main() {
	ios::sync_with_stdio(false);
	cin.tie(NULL);
	#ifndef ONLINE_JUDGE
	freopen("input.txt", "r", stdin);
	freopen("output.txt", "w", stdout);
	#endif
	int laddu = 1;
	cin >> laddu;
	forl(i, 1, laddu + 1) {
		// cout << "Case #" << i << ": ";
		panipuri();
		cout << '\n';
	}
}
Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        vector<int> a(n);
        for (int &x : a) cin >> x;
        vector adj(n, vector<int>());
        for (int i = 0; i < n-1; ++i) {
            int u, v; cin >> u >> v;
            adj[--u].push_back(--v);
            adj[v].push_back(u);
        }

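        // dp1[u]: best score in the subtree of u with every node cursed.
        // dp2[u]: best score in the subtree of u if u is uncursed (the root gets no +1).
        // sm[u]: sum of max(0, dp1[v]) over the children v of u.
        vector<ll> dp1(n), dp2(n), ans(n), sm(n);
        auto dfs = [&] (const auto &self, int u, int p, int dep = 0) -> void {
            for (int v : adj[u]) if (v != p) {
                self(self, v, u, dep + 1);
                dp1[u] += max(0ll, dp1[v]);
                dp2[u] += max(0ll, dp2[v]);
                sm[u] += max(0ll, dp1[v]);
            }
            dp1[u] += a[u] - dep;
            dp2[u] += a[u] - dep;
            if (dep) ++dp2[u];
        };
        // up: value collected on the way down before entering u (S in the explanation).
        auto fix = [&] (const auto &self, int u, int p, ll up = 0, int dep = 0) -> void {
            ans[u] = up + dp2[u];
            for (int v : adj[u]) if (v != p) {
                self(self, v, u, up + sm[u] + a[u] - dep - max(0ll, dp1[v]), dep + 1);
            }
        };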
        dfs(dfs, 0, 0);
        fix(fix, 0, 0);

        for (int i = 1; i < n; ++i) cout << max(dp1[0], ans[i]) << ' ';
        cout << '\n';
    }
}