JMPFVR - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2

Setter: Md. Mahamudur Rahaman Sajib
Tester: Teja Vardhan Reddy
Editorialist: Taranpreet Singh

DIFFICULTY:

Medium-Hard

PREREQUISITES:

Dynamic Convex Hull Trick, DSU on Tree

PROBLEM:

Given a tree with N nodes where node i has value V_i, let f(u, v) be the value of the path from u to v, given by \sum_{i = 1}^L V_{p_i}*i (after converting the suffix sums of the statement into a plain sum), where the path from u to v contains L nodes, the first being u and the last being v, and p_i is the i-th node on the path.

Now consider choosing two distinct nodes u and v, and let the simple path from u to v be u = a_1, a_2, \ldots, a_L = v. Chef may choose any non-negative integer m and any sequence k_1, k_2, \ldots, k_m such that 2 \leq k_1 < k_2 < \ldots < k_m < L. The score of the trip is then f(u, a_{k_1}) + f(a_{k_1}, a_{k_2}) + \ldots + f(a_{k_m}, v). If m = 0, the score of the trip is simply f(u, v).

Find the maximum possible score of the trip by choosing u, v, and the sequence k optimally.
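As a sanity check of the definitions, here is a small brute-force sketch (our own, not taken from the setter, tester or editorialist). `segmentScore` evaluates f on a sub-path and `bestTripOnPath` picks the split points with an O(L^2) dynamic program over the position of the last split; both names are hypothetical, and the path is assumed to be given as a 0-indexed vector holding V_{a_1}, \ldots, V_{a_L}.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <vector>

// f on the sub-path path[i..j] (0-indexed, inclusive): path[i] gets weight 1,
// path[i+1] weight 2, and so on.
long long segmentScore(const std::vector<long long>& path, int i, int j) {
    long long s = 0;
    for (int t = i; t <= j; ++t) s += (long long)(t - i + 1) * path[t];
    return s;
}

// Best trip score for fixed endpoints u = path.front() and v = path.back().
// best[j] is the best score of a trip that starts at path[0] and currently stops
// at path[j]; consecutive segments share an endpoint and each has >= 2 nodes.
long long bestTripOnPath(const std::vector<long long>& path) {
    const int L = (int)path.size();
    std::vector<long long> best(L, LLONG_MIN);
    best[0] = 0;                                       // nothing scored yet
    for (int i = 0; i + 1 < L; ++i) {                  // i = start of the next segment
        long long seg = path[i];                       // segmentScore(path, i, i)
        for (int j = i + 1; j < L; ++j) {
            seg += (long long)(j - i + 1) * path[j];   // now segmentScore(path, i, j)
            best[j] = std::max(best[j], best[i] + seg);
        }
    }
    return best[L - 1];
}
```

For example, on the path with values 1, 1, -10 the unsplit score is -27, while splitting at the middle node gives 3 + (-19) = -16, which is the best. Maximizing this over all ordered pairs u \neq v is only viable for tiny trees, but it is handy for stress-testing a fast solution.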

QUICK EXPLANATION

  • Decompose a single path score as \displaystyle f(u, v) = up(u, a_k) + k*V_{a_k}+down(v, a_k) + k*s(a_{k+1}, v), where a_k is the LCA of u and v, up(u, l) and down(v, l) are the scores of the upward and downward parts of the path, and s(u, v) is the sum of values of the nodes on the path from u to v.
  • Rearrange the expression so that, with the LCA and the start point (or end point) fixed, the best other endpoint is a maximum over lines, and apply the convex hull trick to find the best path starting (or ending) at some node.
  • Use the DSU on tree trick to reduce the time complexity from O(N^2*log(N)) to O(N*log^2(N)).
  • To handle a trip consisting of multiple paths, note that all paths except the one containing the LCA move either only upward or only downward.
  • So compute up_u, the best value of a trip moving upward and ending at u, and down_u, the best value of a trip starting at u and moving downward. Only the intercepts of the lines in the above convex hulls change.

EXPLANATION

Let’s first consider a simpler problem: find the maximum score of a single path, i.e. the maximum value of f(u, v) over all pairs u, v.

Let u = a_1, a_2, \ldots, a_L = v be the required path, so its score is \sum_{i = 1}^L V_{a_i}*i, and let a_k be the lowest common ancestor of u and v.

Splitting the summation at the k-th node gives \displaystyle f(u, v) = \sum_{i = 1}^L i*V_{a_i} = \sum_{i = 1}^{k-1} i*V_{a_i}+k*V_{a_k}+\sum_{i = k+1}^L i*V_{a_i}

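Re-indexing the last sum so that it starts from 1 (each index becomes i+k with 1 \leq i \leq L-k) and separating the extra k per term gives \displaystyle f(u, v) = \sum_{i = 1}^{k-1} i*V_{a_i}+k*V_{a_k}+\sum_{i = 1}^{L-k} i*V_{a_{i+k}} + k*\sum_{i = 1}^{L-k} V_{a_{i+k}}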

The above can be written as \displaystyle f(u, v) = up(u, a_k) + k*V_{a_k}+down(v, a_k) + k*s(a_{k+1}, v)

where up(u, l) is the score of the path from u (inclusive) to l (exclusive), down(u, l) is the score of the path from l (exclusive) to u (inclusive), with the first node of each part getting weight 1, and s(u, v) denotes the sum of values on the path from u to v, both inclusive.

Hence, we have decomposed f(u, v) into sum of an upward path (since a_k is ancestor of u), a downward path (Since a_k is an ancestor of v) and some other terms.
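The decomposition can be checked numerically. In this sketch (hypothetical helper names), `a[1..L]` holds the values along the path and `splitScore` evaluates the right-hand side for a split index k. The identity is pure algebra and holds for every k; choosing the LCA as a_k is what makes the two parts vertical paths.

```cpp
#include <cassert>
#include <vector>

// a[1..L] holds the values along the path u = a_1, ..., a_L = v; a[0] is unused.
long long directScore(const std::vector<long long>& a) {
    long long s = 0;
    for (int i = 1; i < (int)a.size(); ++i) s += (long long)i * a[i];
    return s;
}

// Right-hand side up(u, a_k) + k*V_{a_k} + down(v, a_k) + k*s(a_{k+1}, v).
long long splitScore(const std::vector<long long>& a, int k) {
    const int L = (int)a.size() - 1;
    long long up = 0, down = 0, tail = 0;
    for (int i = 1; i < k; ++i) up += (long long)i * a[i];   // u .. a_{k-1}, a_k excluded
    for (int i = k + 1; i <= L; ++i) {
        down += (long long)(i - k) * a[i];                  // a_{k+1} .. v, weights from 1
        tail += a[i];                                       // s(a_{k+1}, v)
    }
    return up + (long long)k * a[k] + down + (long long)k * tail;
}
```

For the values 4, -7, 2, 5 and k = 2, the parts are 4, -14, 12 and 2*7 = 14, which add up to the direct score 16.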

Now, let’s use some pre-computation so that the score of an upward or downward path can be computed in O(1). Try writing the score of a path from the root to some node, and from some node to the root, and check whether prefix arrays let you compute these values efficiently.

It is worth attempting this yourself first, to understand the pre-computation.

Pre-computation

We compute the following values for each node. In all of them, root = a_1, a_2, \ldots, a_m = u is the path from the root to node u.
d_u, the depth of node u, with the root at depth 1 (so d_u = m)
\displaystyle sum_u = \sum_{i = 1}^m V_{a_i}
\displaystyle sum1_u = \sum_{i = 1}^m (m-i+1)*V_{a_i}
\displaystyle sum2_u = \sum_{i = 1}^m i*V_{a_i}
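Since every node extends its parent's root path by one node, these values can be filled top-down with d_u = d_{par} + 1, sum_u = sum_{par} + V_u, sum1_u = sum1_{par} + sum_u (every existing weight grows by one and the new node gets weight 1) and sum2_u = sum2_{par} + d_u*V_u. The tester's dfs1 uses the same recurrences, with the names sum1 and sum2 swapped. The sketch below assumes a hypothetical input format in which node 0 is the root, parent[0] = -1 and parent[u] < u, so a single forward loop replaces the DFS; `RootPrefix` and `buildRootPrefix` are illustrative names.

```cpp
#include <cassert>
#include <vector>

struct RootPrefix {
    std::vector<long long> d, sum, sum1, sum2;
};

// Hypothetical input convention: node 0 is the root, parent[0] == -1 and
// parent[u] < u, so one forward pass sees every parent before its children.
RootPrefix buildRootPrefix(const std::vector<int>& parent, const std::vector<long long>& V) {
    const int n = (int)V.size();
    RootPrefix P{std::vector<long long>(n), std::vector<long long>(n),
                 std::vector<long long>(n), std::vector<long long>(n)};
    for (int u = 0; u < n; ++u) {
        const int p = parent[u];
        P.d[u] = (p < 0 ? 0 : P.d[p]) + 1;                    // root depth 1, so d_u = m
        P.sum[u] = (p < 0 ? 0 : P.sum[p]) + V[u];
        P.sum1[u] = (p < 0 ? 0 : P.sum1[p]) + P.sum[u];       // old weights +1, new node weight 1
        P.sum2[u] = (p < 0 ? 0 : P.sum2[p]) + P.d[u] * V[u];  // new node weight d_u
    }
    return P;
}
```

For the root path with values 1, 2, 3 (root first), the last node gets sum1 = 3*1 + 2*2 + 1*3 = 10 and sum2 = 1*1 + 2*2 + 3*3 = 14.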

Figure out how to write up(u, l) and down(u, l) in terms of the above values, where l is an ancestor of node u.

Here's how

In the following, l is excluded from the path and u is included.
up(u, l) = sum1_u-sum1_l-sum_l*(d_u-d_l)
down(u, l) = sum2_u-sum2_l-d_l*(sum_u-sum_l)

The proof is left as an exercise.
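A short argument for both formulas: with depths counted from 1 at the root, sum1_u gives the node at depth i the weight d_u - i + 1. Its first d_l terms equal sum1_l plus (d_u - d_l) copies of sum_l, and the remaining terms are exactly up(u, l). Likewise, the node at depth i > d_l is the (i - d_l)-th node after l, so down(u, l) = (sum2_u - sum2_l) - d_l*(sum_u - sum_l). The sketch below checks both closed forms against brute force on a single root path; `a[i]` is assumed to be the value of the node at depth i, and all names are hypothetical.

```cpp
#include <cassert>
#include <vector>

// Prefix values along one root path: a[i] is the value of the node at depth i
// (the root has depth 1); a[0] is an unused placeholder.
struct PathPrefix {
    std::vector<long long> sum, sum1, sum2;
    explicit PathPrefix(const std::vector<long long>& a)
        : sum(a.size(), 0), sum1(a.size(), 0), sum2(a.size(), 0) {
        for (int i = 1; i < (int)a.size(); ++i) {
            sum[i] = sum[i - 1] + a[i];
            sum1[i] = sum1[i - 1] + sum[i];
            sum2[i] = sum2[i - 1] + (long long)i * a[i];
        }
    }
};

// Closed forms from the editorial; u and l are depths on the path, l <= u.
long long upFormula(const PathPrefix& P, int u, int l) {
    return P.sum1[u] - P.sum1[l] - P.sum[l] * (u - l);
}
long long downFormula(const PathPrefix& P, int u, int l) {
    return P.sum2[u] - P.sum2[l] - (long long)l * (P.sum[u] - P.sum[l]);
}

// Brute force: from u up to l (l excluded), weights 1, 2, ...
long long upBrute(const std::vector<long long>& a, int u, int l) {
    long long s = 0;
    for (int i = u, w = 1; i > l; --i, ++w) s += (long long)w * a[i];
    return s;
}
// Brute force: from the node just below l down to u, weights 1, 2, ...
long long downBrute(const std::vector<long long>& a, int u, int l) {
    long long s = 0;
    for (int i = l + 1, w = 1; i <= u; ++i, ++w) s += (long long)w * a[i];
    return s;
}
```

On the path 1, 2, 3 with u at depth 3 and l at depth 1, both sides give up(u, l) = 3*1 + 2*2 = 7 and down(u, l) = 2*1 + 3*2 = 8.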

Now that we know how to write f(u, v) as a combination of an upward path and a downward path, we need to group all pairs (u, v) by their lowest common ancestor, and efficiently find the best score path.

Writing l = a_k and substituting the formulas for up and down, let us expand \displaystyle f(u, v) = up(u, a_k) + k*V_{a_k}+down(v, a_k) + k*s(a_{k+1}, v) as follows.
\displaystyle f(u, v) = sum1_u-sum1_l-sum_l*(d_u-d_l) + k*V_l+sum2_v-sum2_l-d_l*(sum_v-sum_l) + k*s(a_{k+1}, v)

Also, we can observe k = d_u-d_l+1 (As a_k = l is the LCA node), so we get

\displaystyle f(u, v) = sum1_u-sum1_l-sum_l*(d_u-d_l) + (d_u-d_l+1)*V_l+sum2_v-sum2_l-d_l*(sum_v-sum_l) + (d_u-d_l+1)*(s(l, v)-V_l)

This is the most important expression of this problem: given u, v and their LCA l, we can compute it in O(1).
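The following sketch checks the expression against a brute-force walk on a small tree. It assumes the hypothetical input convention parent[0] = -1, parent[u] < u, finds the LCA by naive climbing (so only the final formula is O(1) here), and uses (d_u-d_l+1)*V_l + (d_u-d_l+1)*(s(l, v)-V_l) = k*s(l, v) with s(l, v) = sum_v - sum_l + V_l. The `Tree` struct and its member names are illustrative.

```cpp
#include <cassert>
#include <utility>
#include <vector>

struct Tree {
    std::vector<int> parent;
    std::vector<long long> V, d, sum, sum1, sum2;
    // Hypothetical input convention: parent[0] == -1 and parent[u] < u.
    Tree(std::vector<int> par, std::vector<long long> val)
        : parent(std::move(par)), V(std::move(val)) {
        const int n = (int)V.size();
        d.assign(n, 0); sum.assign(n, 0); sum1.assign(n, 0); sum2.assign(n, 0);
        for (int u = 0; u < n; ++u) {
            const int p = parent[u];
            d[u] = (p < 0 ? 0 : d[p]) + 1;
            sum[u] = (p < 0 ? 0 : sum[p]) + V[u];
            sum1[u] = (p < 0 ? 0 : sum1[p]) + sum[u];
            sum2[u] = (p < 0 ? 0 : sum2[p]) + d[u] * V[u];
        }
    }
    int lca(int u, int v) const {                    // naive climbing, fine for a sketch
        while (d[u] > d[v]) u = parent[u];
        while (d[v] > d[u]) v = parent[v];
        while (u != v) { u = parent[u]; v = parent[v]; }
        return u;
    }
    long long up(int u, int l) const { return sum1[u] - sum1[l] - sum[l] * (d[u] - d[l]); }
    long long down(int v, int l) const { return sum2[v] - sum2[l] - d[l] * (sum[v] - sum[l]); }
    // The O(1) expression once l is known; k = d_u - d_l + 1.
    long long fFormula(int u, int v) const {
        const int l = lca(u, v);
        const long long k = d[u] - d[l] + 1;
        const long long sLV = sum[v] - sum[l] + V[l];  // s(l, v)
        return up(u, l) + k * sLV + down(v, l);
    }
    // Brute force: list the path u -> l -> v and weight its i-th node by i.
    long long fBrute(int u, int v) const {
        const int l = lca(u, v);
        std::vector<int> path, tail;
        for (int x = u; x != l; x = parent[x]) path.push_back(x);
        path.push_back(l);
        for (int x = v; x != l; x = parent[x]) tail.push_back(x);
        path.insert(path.end(), tail.rbegin(), tail.rend());
        long long s = 0;
        for (int i = 0; i < (int)path.size(); ++i) s += (long long)(i + 1) * V[path[i]];
        return s;
    }
};
```

On the tree whose root has value 1 and whose two children have values 2 and 3, the path from the value-2 child to the value-3 child scores 1*2 + 2*1 + 3*3 = 13, and the formula gives up = 2, k*s(l, v) = 2*4 = 8 and down = 3, also 13.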

But what if we are given l and v, but not u? Using (d_u-d_l+1)*V_l + (d_u-d_l+1)*(s(l, v)-V_l) = (d_u-d_l+1)*s(l, v) and rearranging terms, we get

\displaystyle f(u, l, v) = \bigg(sum1_u\bigg) + d_u*\bigg(-sum_l+s(l, v)\bigg) + \bigg(-sum1_l+sum_l*d_l + (-d_l+1)*s(l, v)+down(v, l)\bigg)

Similarly, if we have u and l but not v, then substituting s(l, v) = sum_v - sum_l + V_l we get the following
\displaystyle f(u, l, v) = \bigg( sum2_v\bigg)+sum_v*\bigg(-2*d_l+d_u+1\bigg)+ \bigg( -sum2_l+d_l*sum_l+(d_u-d_l+1)*(V_l-sum_l) + up(u, l) \bigg)

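In both expressions the third term does not depend on the missing node, while the first two terms form a standard convex hull trick query: each candidate node contributes a line whose intercept is the first term (sum1_u, respectively sum2_v) and whose slope is d_u (respectively sum_v), evaluated at x = s(l, v) - sum_l (respectively x = d_u - 2*d_l + 1).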
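Both rearrangements can be checked mechanically. In this sketch, `NodeInfo` bundles the pre-computed values of one node, `lineForStart`/`lineForEnd` build the line a candidate start or end node would contribute to a hull, and `queryEnd`/`queryStart` and `constEnd`/`constStart` give the query point and the constant for a fixed pair. All names are hypothetical.

```cpp
#include <cassert>

// Pre-computed values of one node, as defined in the pre-computation section.
struct NodeInfo { long long d, sum, sum1, sum2, V; };

struct Line {
    long long slope, intercept;
    long long at(long long x) const { return slope * x + intercept; }
};

long long upV(const NodeInfo& u, const NodeInfo& l) { return u.sum1 - l.sum1 - l.sum * (u.d - l.d); }
long long downV(const NodeInfo& v, const NodeInfo& l) { return v.sum2 - l.sum2 - l.d * (v.sum - l.sum); }
long long sLV(const NodeInfo& l, const NodeInfo& v) { return v.sum - l.sum + l.V; }  // s(l, v)

// Fixed (l, v), unknown start u: line (d_u, sum1_u) queried at x = s(l, v) - sum_l.
Line lineForStart(const NodeInfo& u) { return {u.d, u.sum1}; }
long long queryEnd(const NodeInfo& l, const NodeInfo& v) { return sLV(l, v) - l.sum; }
long long constEnd(const NodeInfo& l, const NodeInfo& v) {
    return -l.sum1 + l.sum * l.d + (1 - l.d) * sLV(l, v) + downV(v, l);
}

// Fixed (u, l), unknown end v: line (sum_v, sum2_v) queried at x = d_u - 2*d_l + 1.
Line lineForEnd(const NodeInfo& v) { return {v.sum, v.sum2}; }
long long queryStart(const NodeInfo& u, const NodeInfo& l) { return u.d - 2 * l.d + 1; }
long long constStart(const NodeInfo& u, const NodeInfo& l) {
    return -l.sum2 + l.d * l.sum + (u.d - l.d + 1) * (l.V - l.sum) + upV(u, l);
}

// Direct form: up(u, l) + (d_u - d_l + 1) * s(l, v) + down(v, l).
long long fDirect(const NodeInfo& u, const NodeInfo& l, const NodeInfo& v) {
    return upV(u, l) + (u.d - l.d + 1) * sLV(l, v) + downV(v, l);
}
```

With the three-node tree whose root has value 1 and whose children have values 2 and 3, both forms give 13 for the path from the value-2 child to the value-3 child, matching the direct expression. Only the line depends on the candidate node, which is exactly what lets a hull answer the maximum.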

So, let’s maintain two convex hulls based on above equations.

Now, for a fixed LCA l, we iterate over all nodes in the sub-tree of l. On reaching a node w, we try it both as the end point of the path (querying the first convex hull, whose lines correspond to start points u) and as the start point (querying the second convex hull, whose lines correspond to end points v).

But we need to ensure that the LCA of the current node and every node in the hulls is l. Say the children of l are w_1, w_2, \ldots, w_m. We first try all nodes in the sub-tree of w_1 as the fixed node, then add all nodes of this sub-tree to both hulls, and repeat for each child in turn. This way, the LCA of the current node and every node in the hulls is guaranteed to be l.
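The sketch below follows this procedure for the simplified problem (maximum f(u, v) with no split points). To keep it short, each "hull" is a plain vector of lines scanned linearly, so it runs in O(N^2) instead of using a dynamic convex hull; it only illustrates the query-before-insert order. It also seeds both hulls with l itself so that paths having l as an endpoint are covered, which is our own choice. The input convention parent[0] = -1, parent[u] < u and the name `maxSinglePath` are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <utility>
#include <vector>

// Simplified problem: maximum f(u, v) over u != v, with no split points.
// Hypothetical input convention: parent[0] == -1 and parent[u] < u.
long long maxSinglePath(const std::vector<int>& parent, const std::vector<long long>& V) {
    const int n = (int)V.size();
    std::vector<long long> d(n), sum(n), sum1(n), sum2(n);
    std::vector<std::vector<int>> ch(n);
    for (int u = 0; u < n; ++u) {
        const int p = parent[u];
        d[u] = (p < 0 ? 0 : d[p]) + 1;
        sum[u] = (p < 0 ? 0 : sum[p]) + V[u];
        sum1[u] = (p < 0 ? 0 : sum1[p]) + sum[u];
        sum2[u] = (p < 0 ? 0 : sum2[p]) + d[u] * V[u];
        if (p >= 0) ch[p].push_back(u);
    }
    auto subtree = [&](int root) {                 // all nodes in root's sub-tree
        std::vector<int> out{root};
        for (size_t i = 0; i < out.size(); ++i)
            for (int c : ch[out[i]]) out.push_back(c);
        return out;
    };
    typedef std::pair<long long, long long> Line;  // (slope, intercept)
    auto best = [](const std::vector<Line>& lines, long long x) {  // linear scan, not a hull
        long long r = LLONG_MIN;
        for (const Line& ln : lines) r = std::max(r, ln.first * x + ln.second);
        return r;
    };
    long long ans = LLONG_MIN;
    for (int l = 0; l < n; ++l) {
        std::vector<Line> starts, ends;            // first hull / second hull
        starts.push_back({d[l], sum1[l]});         // l itself may be an endpoint
        ends.push_back({sum[l], sum2[l]});
        for (int c : ch[l]) {
            const std::vector<int> nodes = subtree(c);
            for (int w : nodes) {
                // w as the end point v: query the lines of start points.
                const long long sLW = sum[w] - sum[l] + V[l];
                const long long downW = sum2[w] - sum2[l] - d[l] * (sum[w] - sum[l]);
                ans = std::max(ans, best(starts, sLW - sum[l]) - sum1[l] + sum[l] * d[l]
                                        + (1 - d[l]) * sLW + downW);
                // w as the start point u: query the lines of end points.
                const long long upW = sum1[w] - sum1[l] - sum[l] * (d[w] - d[l]);
                ans = std::max(ans, best(ends, d[w] - 2 * d[l] + 1) - sum2[l] + d[l] * sum[l]
                                        + (d[w] - d[l] + 1) * (V[l] - sum[l]) + upW);
            }
            // Insert this child's sub-tree only after querying it, so every
            // pair found so far has LCA exactly l.
            for (int w : nodes) {
                starts.push_back({d[w], sum1[w]});
                ends.push_back({sum[w], sum2[w]});
            }
        }
    }
    return ans;
}
```

On the path 1, 2, 3 (root first) the best single path is the whole path read downward, with score 1 + 4 + 9 = 14.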

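The above procedure takes O(N*log(N)) operations for each LCA, i.e. O(N^2*log(N)) in total, which would time out. But notice that the line added for a node depends only on that node and not on l, so hulls can be reused higher up in the tree: we merge the small hulls into the large one (or simply reuse the large one), i.e. we keep the hull built for the child with the largest sub-tree and insert only the nodes of the other sub-trees. This article, called DSU on tree, would be helpful to understand this. Since a node is re-inserted only once per light edge on its path to the root, i.e. O(log(N)) times, this reduces the time complexity to O(N*log^2N).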
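Here is a skeleton of the heavy-child reuse, assuming the hypothetical input convention parent[0] = -1, parent[u] < u. The hull is replaced by a plain list `pool` of inserted node ids and the queries are left as a comment, so the only thing this sketch shows is the insertion pattern: light children are processed first and discarded, the heavy child's container is kept, and then the light sub-trees and u itself are inserted. `DsuOnTree` and its members are illustrative names.

```cpp
#include <cassert>
#include <vector>

struct DsuOnTree {
    std::vector<std::vector<int>> ch;
    std::vector<int> sz, pool;    // pool stands in for the shared hull
    long long insertions = 0;

    // Hypothetical input convention: parent[0] == -1 and parent[u] < u.
    explicit DsuOnTree(const std::vector<int>& parent) : ch(parent.size()), sz(parent.size(), 1) {
        for (int u = (int)parent.size() - 1; u > 0; --u) {   // children before parents
            ch[parent[u]].push_back(u);
            sz[parent[u]] += sz[u];
        }
    }
    void addSubtree(int u) {                       // insert every node of u's sub-tree
        pool.push_back(u); ++insertions;
        for (int c : ch[u]) addSubtree(c);
    }
    void process(int u, bool keep) {
        int heavy = -1;
        for (int c : ch[u]) if (heavy == -1 || sz[c] > sz[heavy]) heavy = c;
        for (int c : ch[u]) if (c != heavy) process(c, false);  // light children, then discard
        if (heavy != -1) process(heavy, true);                  // heavy child's pool is kept
        for (int c : ch[u]) {
            if (c == heavy) continue;
            // (queries for the nodes of c's sub-tree against the pool would go here)
            addSubtree(c);
        }
        pool.push_back(u); ++insertions;                        // finally u itself
        if (!keep) pool.clear();
    }
};
```

On a star with 5 nodes this performs 8 insertions (every leaf except the heavy one is inserted twice), and on a path with 5 nodes only 5.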

You can try the above problem here. Its editorial presents a different solution based on centroid decomposition (which won’t be useful here).

The final thing is to return to the original problem (if you remember!), where the trip from u to v may be split into several paths. The crucial observation is that every path not containing the LCA of u and v moves either only upward or only downward.

So, let’s introduce two more values:

  • up_u, the largest value of a trip that starts in the sub-tree of u and ends at u
  • down_u, the largest value of a trip that starts at u and ends in the sub-tree of u

Notice that

  • If the path of the trip that contains the LCA goes from u to v, the best trip score is up_u+f(u, v)+down_v.
  • up_u does not depend on v, and down_v does not depend on u.

Suppose these two values are already computed for all nodes in the sub-tree of l. Then we can still use the same line equations; only the intercepts become sum1_u+up_u and sum2_v+down_v respectively.

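Finally, to compute up_{lca} and down_{lca}, we fix lca as the end point and as the start point, respectively. For instance, setting v = lca in the first equation (so s(lca, lca) = V_{lca} and down(lca, lca) = 0) makes up_{lca} one more query on the hull of start points, at x = V_{lca} - sum_{lca}.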
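Spelled out, fixing lca as the end point gives up_{lca} = max(0, max over w in the sub-tree of lca, w \neq lca, of up_w + f(w, lca)), and symmetrically down_{lca} = max(0, max over w of f(lca, w) + down_w). The naive reference below evaluates exactly these recurrences by brute force, which is useful for testing a fast implementation. It assumes the hypothetical input convention parent[0] = -1, parent[u] < u, and clamps both values at 0 for the empty trip, mirroring the max(..., 0) in the setter's and tester's code; the names are illustrative.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

struct UpDown { std::vector<long long> up, down; };

// Naive reference for up_u and down_u (not the intended solution).
// Hypothetical input convention: parent[0] == -1 and parent[u] < u.
UpDown naiveUpDown(const std::vector<int>& parent, const std::vector<long long>& V) {
    const int n = (int)V.size();
    UpDown r{std::vector<long long>(n, 0), std::vector<long long>(n, 0)};  // 0 = empty trip
    // Descendants have larger indices, so scanning u downwards finishes
    // up_w and down_w for every descendant w before they are needed.
    for (int u = n - 1; u >= 0; --u) {
        for (int w = u + 1; w < n; ++w) {
            std::vector<int> path;                 // w, parent(w), ..., u
            int x = w;
            while (x > u) { path.push_back(x); x = parent[x]; }
            if (x != u) continue;                  // u is not an ancestor of w
            path.push_back(u);
            const int L = (int)path.size();
            long long fwu = 0, fuw = 0;            // f(w, u) and f(u, w)
            for (int i = 0; i < L; ++i) {
                fwu += (long long)(i + 1) * V[path[i]];
                fuw += (long long)(L - i) * V[path[i]];
            }
            r.up[u] = std::max(r.up[u], r.up[w] + fwu);        // last segment w -> u
            r.down[u] = std::max(r.down[u], fuw + r.down[w]);  // first segment u -> w
        }
    }
    return r;
}
```

For the path 0 - 1 - 2 with values 5, 1, 1 it returns up_0 = 18 (going straight up beats splitting at node 1, which gives 3 + 11 = 14) and down_0 = 10.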

Do refer to my implementation, which uses exactly the same variable names for convenience.

TIME COMPLEXITY

The time complexity is O(N*log^2N) per test case.

SOLUTIONS:

Setter's Solution
#include<bits/stdc++.h>
//#include <ext/pb_ds/assoc_container.hpp>
//#include <ext/pb_ds/tree_policy.hpp>
#include <cstring>
#include <iostream>
#define pie acos(-1)
#define si(a) scanf("%d",&a)
#define sii(a,b) scanf("%d %d",&a,&b)
#define siii(a,b,c) scanf("%d %d %d",&a,&b,&c)
#define sl(a) scanf("%lld",&a)
#define sll(a,b) scanf("%lld %lld",&a,&b)
#define slll(a,b,c) scanf("%lld %lld %lld",&a,&b,&c)
#define ss(st) scanf("%s",st)
#define sch(ch) scanf("%c",&ch)
#define ps(a) printf("%s",a)
#define newLine() printf("\n")
#define pi(a) printf("%d",a)
#define pii(a,b) printf("%d %d",a,b)
#define piii(a,b,c) printf("%d %d %d",a,b,c)
#define pl(a) printf("%lld",a)
#define pll(a,b) printf("%lld %lld",a,b)
#define plll(a,b,c) printf("%lld %lld %lld",a,b,c)
#define pch(c) printf("%c",c)
#define debug1(str,a) printf("%s=%d\n",str,a)
#define debug2(str1,str2,a,b) printf("%s=%d %s=%d\n",str1,a,str2,b)
#define debug3(str1,str2,str3,a,b,c) printf("%s=%d %s=%d %s=%d\n",str1,a,str2,b,str3,c)
#define debug4(str1,str2,str3,str4,a,b,c,d) printf("%s=%d %s=%d %s=%d %s=%d\n",str1,a,str2,b,str3,c,str4,d)
#define for0(i,n) for(i=0;i<n;i++)
#define for1(i,n) for(i=1;i<=n;i++)
#define forab(i,a,b) for(i=a;i<=b;i++)
#define forstl(i, s) for (__typeof ((s).end ()) i = (s).begin (); i != (s).end (); ++i)
#define nl puts("")
#define sd(a) scanf("%lf",&a)
#define sdd(a,b) scanf("%lf %lf",&a,&b)
#define sddd(a,b,c) scanf("%lf %lf %lf",&a,&b,&c)
#define sp printf(" ")
#define ll long long int
#define ull unsigned long long int
#define MOD 1000000007
#define mpr make_pair
#define pub(x) push_back(x)
#define pob(x) pop_back(x)
#define mem(ara,value) memset(ara,value,sizeof(ara))
#define INF INT_MAX
#define eps 1e-9
#define checkbit(n, pos) (n & (1<<pos))
#define setbit(n, pos) (n | (1<<pos))
#define para(i,a,b,ara)\
for(i=a;i<=b;i++){\
    if(i!=0){printf(" ");}\
    cout<<ara[i];\
}\
printf("\n");
#define pvec(i,vec)\
for(i=0;i<vec.size();i++){\
    if(i!=0){printf(" ");}\
    cout<<vec[i];\
}\
printf("\n");
#define ppara(i,j,n,m,ara)\
for(i=0;i<n;i++){\
    for(j=0;j<m;j++){\
        if(j!=0){printf(" ");}\
        cout<<ara[i][j];\
    }\
    printf("\n");\
}
#define ppstructara(i,j,n,m,ara)\
for(i=0;i<n;i++){\
    printf("%d:\n",i);\
    for(j=0;j<m;j++){\
        cout<<ara[i][j];printf("\n");\
    }\
}
#define ppvec(i,j,n,vec)\
for(i=0;i<n;i++){\
    printf("%d:",i);\
    for(j=0;j<vec[i].size();j++){\
        if(j!=0){printf(" ");}\
        cout<<vec[i][j];\
    }\
    printf("\n");\
}
#define ppstructvec(i,j,n,vec)\
for(i=0;i<n;i++){\
    printf("%d:",i);\
    for(j=0;j<vec[i].size();j++){\
        cout<<vec[i][j];printf("\n");\
    }\
}
#define sara(i,a,b,ara)\
for(i=a;i<=b;i++){\
    scanf("%d",&ara[i]);\
}
#define pstructara(i,a,b,ara)\
for(i=a;i<=b;i++){\
    cout<<ara[i];nl;\
}
#define pstructvec(i,vec)\
for(i=0;i<vec.size();i++){\
    cout<<vec[i];nl;\
}
#define pstructstl(stl,x)\
for(__typeof(stl.begin()) it=stl.begin();it!=stl.end();++it){\
    x=*it;\
    cout<<x;nl;\
}\
nl;
#define pstl(stl)\
for(__typeof(stl.begin()) it=stl.begin();it!=stl.end();++it){\
    if(it!=stl.begin()){sp;}\
    pi(*it);\
}\
nl;
#define ppairvec(i,vec)\
for(i=0;i<vec.size();i++){\
    cout<<vec[i].first;sp;cout<<vec[i].second;printf("\n");\
}
#define ppairara(i,a,b,ara)\
for(i=a;i<=b;i++){\
    cout<<ara[i].first;sp;cout<<ara[i].second;printf("\n");\
}
#define pppairvec(i,j,n,vec)\
for(i=0;i<n;i++){\
    printf("%d:\n",i);\
    for(j=0;j<vec[i].size();j++){\
        cout<<vec[i][j].first;sp;cout<<vec[i][j].second;nl;\
    }\
}
#define pppairara(i,j,n,m,ara)\
for(i=0;i<n;i++){\
    printf("%d:\n",i);\
    for(j=0;j<m;j++){\
        cout<<ara[i][j].first;printf(" ");cout<<ara[i][j].second;nl;\
    }\
}
#define SZ 2 * 100010
#define xx first
#define yy second
using namespace std;
//using namespace __gnu_pbds;
//bool status[100010];
//vector <int> prime;
//void siv(){
//    int N = 100005, i, j; prime.clear();
//    int sq = sqrt(N);
//    for(i = 4; i <= N; i += 2){ status[i] = true; }
//    for(i = 3; i <= sq; i+= 2){
//        if(status[i] == false){
//            for(j = i * i; j <= N; j += i){ status[j] = true; }
//        }
//    }
//    status[1] = true;
//    for1(i, N){ if(!status[i]){ prime.pub(i); } }
//}
//mt19937_64 mt(chrono::steady_clock::now().time_since_epoch().count());
//auto seed = chrono::high_resolution_clock::now().time_since_epoch().count();
//std::mt19937 mt(seed);
inline int add(int _a, int _b){
    if(_a < 0){ _a += MOD; }
    if(_b < 0){ _b += MOD; }
    if(_a + _b >= MOD){ return _a + _b - MOD; }
    return _a + _b;
}
inline int mul(int _a, int _b){
    if(_a < 0){ _a += MOD; }
    if(_b < 0){ _b += MOD; }
    return ((ll)((ll)_a * (ll)_b)) % MOD;
}
const long long LL_INF = (long long) 2e18 + 5;
struct point {
    long long x, y;
    point() : x(0), y(0) {}
    point(long long _x, long long _y) : x(_x), y(_y) {}
};
// dp_hull enables you to do the following two operations in amortized O(log n) time:
// 1. Insert a pair (a_i, b_i) into the structure
// 2. For any value of x, query the maximum value of a_i * x + b_i
// All values a_i, b_i, and x can be positive or negative.

struct dp_hull {
    struct segment {
        point p;
        mutable point next_p;
        segment(point _p = {0, 0}, point _next_p = {0, 0}) : p(_p), next_p(_next_p) {}
        bool operator<(const segment &other) const {
            // Sentinel value indicating we should binary search the set for a single x-value.
            if (p.y == LL_INF)
                return p.x * (other.next_p.x - other.p.x) <= other.p.y - other.next_p.y;
            return make_pair(p.x, p.y) < make_pair(other.p.x, other.p.y);
        }
    };
    set<segment> segments;
    int size() const {
        return segments.size();
    }
    set<segment>::iterator prev(set<segment>::iterator it) const {
        return it == segments.begin() ? it : --it;
    }
    set<segment>::iterator next(set<segment>::iterator it) const {
        return it == segments.end() ? it : ++it;
    }
    static long long floor_div(long long a, long long b) {
        return a / b - ((a ^ b) < 0 && a % b != 0);
    }
    static bool bad_middle(const point &a, const point &b, const point &c) {
        // This checks whether the x-value where b beats a comes after the x-value where c beats b. It's fine to round
        // down here if we will only query integer x-values. (Note: plain C++ division rounds toward zero)
        return floor_div(a.y - b.y, b.x - a.x) >= floor_div(b.y - c.y, c.x - b.x);
    }
    bool bad(set<segment>::iterator it) const {
        return it != segments.begin() && next(it) != segments.end() && bad_middle(prev(it)->p, it->p, next(it)->p);
    }
    void insert(const point &p) {
        set<segment>::iterator next_it = segments.lower_bound(segment(p));
        if (next_it != segments.end() && p.x == next_it->p.x)
            return;
        if (next_it != segments.begin()) {
            set<segment>::iterator prev_it = prev(next_it);
            if (p.x == prev_it->p.x)
                segments.erase(prev_it);
            else if (next_it != segments.end() && bad_middle(prev_it->p, p, next_it->p))
                return;
        }
        // Note we need the segment(p, p) here for the single x-value binary search.
        set<segment>::iterator it = segments.insert(next_it, segment(p, p));
        while (bad(prev(it)))
            segments.erase(prev(it));
        while (bad(next(it)))
            segments.erase(next(it));
        if (it != segments.begin())
            prev(it)->next_p = it->p;
        if (next(it) != segments.end())
            it->next_p = next(it)->p;
    }
    void insert(long long a, long long b) {
        insert(point(a, b));
    }
    // Queries the maximum value of ax + b.
    long long query(long long x) const {
        assert(size() > 0);
        set<segment>::iterator it = segments.upper_bound(segment(point(x, LL_INF)));
        return it->p.x * x + it->p.y;
    }
};
dp_hull convex, xevnoc;
const int N = 3e5;
int n, ara[N + 5], sbtr[N + 5], dpt[N + 5], vrtx[N + 5], order[N + 5], t = 0;
vector <int> adj[N + 5];
ll sum[N + 5], ru[N + 5], ur[N + 5], dp[N + 5], pd[N + 5];
ll global, val = 0;
void go(int u, int p, int d){
	int i, j;
    // pii(p, u); nl;
	for0(i, adj[u].size()){
		int v = adj[u][i];
		if(v == p) break;
	} if(i != adj[u].size()) adj[u].erase(adj[u].begin() + i);
	sbtr[u] = 1, dpt[u] = d;
	sum[u] = sum[p] + ara[u];
	ru[u] = (ll)ara[u] * (ll)dpt[u] + ru[p];
	ur[u] = ur[p] + sum[u];
    vrtx[t++] = u, order[u] = t - 1;
	for(int v : adj[u]) go(v, u, d + 1), sbtr[u] += sbtr[v];
}
//jump_cost(u, v) = (ru[v] - ru[p[u]]) - (sum[v] - sum[p[u]]) * (dpt[u] - 1)
//max_cost(u, v) = (ru[v] - ru[p[u]]) - (sum[v] - sum[p[u]]) * (dpt[u] - 1) + dp[v]
//               = -ru[p[u]] + sum[p[u]] * (dpt[u] - 1) + ru[v] + dp[v] - sum[v] * (dpt[u] - 1)
//               = -ru[p[u]] + sum[p[u]] * (dpt[u] - 1) + mx + c {m = -sum[v], x = (dpt[u] - 1), c = ru[v] + dp[v]}
//jump_cost(v, u) = (ur[v] - ur[p[u]]) - (dpt[v] - dpt[u] + 1) * sum[p[u]]
//max_cost(v, u)  = (ur[v] - ur[p[u]]) - (dpt[v] - dpt[u] + 1) * sum[p[u]] + pd[v]
//                = -ur[p[u]] + (dpt[u] - 1) * sum[p[u]] + ur[v] + pd[v] - dpt[v] * sum[p[u]]
//                =  -ur[p[u]] + (dpt[u] - 1) * sum[p[u]] + mx + c{m = -dpt[v], x = sum[p[u]], c = ur[v] + pd[v]}
void dsu_prep(int u, int p, bool keep){
    int i, j, mx = -1, bg = -1;
    for(int v : adj[u]) if(sbtr[v] > mx) mx = sbtr[v], bg = v;
    for(int v : adj[u]) if(v != bg) dsu_prep(v, u, 0);
    if(bg != -1) dsu_prep(bg, u, 1);
    for(int v : adj[u]){
        if(v == bg) continue;
        for(i = order[v]; i <= order[v] + sbtr[v] - 1; ++i){
            int x = vrtx[i];
            convex.insert(-sum[x], ru[x] + dp[x]);
            xevnoc.insert(-(ll)dpt[x], ur[x] + pd[x]);
        }
    }
    dp[u] = pd[u] = 0;
    if(adj[u].size()){
        dp[u] = -ru[p] + sum[p] * (ll)(dpt[u] - 1) + convex.query((ll)(dpt[u] - 1));
        pd[u] = -ur[p] + sum[p] * (ll)(dpt[u] - 1) + xevnoc.query(sum[p]);
        global = max(global, dp[u]), global = max(global, pd[u]);
        dp[u] = max(dp[u], 0ll);
        pd[u] = max(pd[u], 0ll);  
    }
    convex.insert(-sum[u], ru[u] + dp[u]);
    xevnoc.insert(-(ll)dpt[u], ur[u] + pd[u]);
    if(!keep) convex.segments.clear(), xevnoc.segments.clear();
}
//max_cost(u, v) = pd[u] + jump_cost(u, v) + dp[v]
//= pd[u] + (ur[u] - ur[p]) - (dpt[u] - dpt[lca] + 1) * sum[p] + (sum[v] - sum[lca]) * (dpt[u] - dpt[lca] + 1)
//+ (ru[v] - ru[lca]) - (sum[v] - sum[lca]) * dpt[lca] + dp[v]
//= pd[u] + (ur[u] - ur[p]) - (dpt[u] - dpt[lca] + 1) * sum[p] - sum[lca] * (dpt[u] - dpt[lca] + 1) - ru[lca] + sum[lca] * dpt[lca]
//+ ru[v] + dp[v] + sum[v] * (dpt[u] - 2 * dpt[lca] + 1){c = ru[v] + dp[v], m = sum[v], x = (dpt[u] - 2 * dpt[lca] + 1)}

//max_cost(u, lca, v) = pd[u] + jump_cost(u, lca) + jump_cost(lca, v) + dp[v]
//= pd[u] + (ur[u] - ur[p]) - (dpt[u] - dpt[lca] + 1) * sum[p]
//+ (ru[v] - ru[p]) - (sum[v] - sum[p]) * (dpt[lca] - 1) + dp[v]
//= pd[u] + (ur[u] - ur[p]) - (dpt[u] - dpt[lca] + 1) * sum[p] - ru[p] + sum[p] * (dpt[lca] - 1)
//+ ru[v] + dp[v] + sum[v] * (-dpt[lca] + 1)

//max_cost(v, u) = pd[v] + jump_cost(v, u) + dp[u]
//= pd[v] + (ur[v] - ur[p]) - (dpt[v] - dpt[lca] + 1) * sum[p] + (sum[u] - sum[lca]) * (dpt[v] - dpt[lca] + 1)
//+ (ru[u] - ru[lca]) - (sum[u] - sum[lca]) * dpt[lca] + dp[u]
//= (ru[u] - ru[lca]) - (sum[u] - sum[lca]) * dpt[lca] + dp[u] - ur[p] + (dpt[lca] - 1) * sum[p] + (sum[u] - sum[lca]) * (- dpt[lca] + 1)
//+ pd[v] + ur[v] + dpt[v] * (sum[u] - sum[lca] - sum[p]){c = pd[v] + ur[v], m = dpt[v], x = (sum[u] - sum[lca] - sum[p])}

//max_cost(v, lca, u) = pd[v] + jump_cost(v, lca) + jump_cost(lca, u) + dp[u]
//= pd[v] + (ur[v] - ur[p]) - (dpt[v] - dpt[lca] + 1) * sum[p]
//+ (ru[u] - ru[p]) - (sum[u] - sum[p]) * (dpt[lca] - 1) + dp[u]
//= (ru[u] - ru[p]) - (sum[u] - sum[p]) * (dpt[lca] - 1) + dp[u] - ur[p] + (dpt[lca] - 1) * sum[p]
// + pd[v] + ur[v] + dpt[v] * (-sum[p])
void dsu(int lca, int p, bool keep){
    int i, j, mx = -1, bg = -1;
    ll c_u_v, c_u_lca_v, c_v_u, c_v_lca_u, q_lca, q_acl;
    for(int x : adj[lca]) if(sbtr[x] > mx) mx = sbtr[x], bg = x;
    for(int x : adj[lca]) if(x != bg) dsu(x, lca, 0);
    if(bg != -1) dsu(bg, lca, 1);
    for(int x : adj[lca]){
        if(x == bg) continue;
        q_lca = convex.query(-(ll)dpt[lca] + 1);
        q_acl = xevnoc.query(-sum[p]);
        for(i = order[x]; i <= order[x] + sbtr[x] - 1; ++i){
            int u = vrtx[i];
            c_u_v = pd[u] + (ur[u] - ur[p]) - (ll)(dpt[u] - dpt[lca] + 1) * sum[p] - sum[lca] * (ll)(dpt[u] - dpt[lca] + 1) - ru[lca] + sum[lca] * (ll)dpt[lca];
            c_u_lca_v =  pd[u] + (ur[u] - ur[p]) - (ll)(dpt[u] - dpt[lca] + 1) * sum[p] - ru[p] + sum[p] * (ll)(dpt[lca] - 1);
            c_v_u = (ru[u] - ru[lca]) - (sum[u] - sum[lca]) * (ll)dpt[lca] + dp[u] - ur[p] + (ll)(dpt[lca] - 1) * sum[p] + (sum[u] - sum[lca]) * (ll)(- dpt[lca] + 1);
            c_v_lca_u = (ru[u] - ru[p]) - (sum[u] - sum[p]) * (ll)(dpt[lca] - 1) + dp[u] - ur[p] + (ll)(dpt[lca] - 1) * sum[p];
            c_u_v += convex.query((ll)dpt[u] - 2 * (ll)dpt[lca] + 1);
            c_v_u += xevnoc.query(sum[u] - sum[lca] - sum[p]);
            c_u_lca_v += q_lca;
            c_v_lca_u += q_acl;
            global = max(global, max(c_u_v, c_v_u));
            global = max(global, max(c_u_lca_v, c_v_lca_u));
        }
        for(i = order[x]; i <= order[x] + sbtr[x] - 1; ++i){
            int u = vrtx[i];
            convex.insert(sum[u], ru[u] + dp[u]);
            xevnoc.insert((ll)dpt[u], pd[u] + ur[u]);
        }
    }
    convex.insert(sum[lca], ru[lca] + dp[lca]);
    xevnoc.insert((ll)dpt[lca], pd[lca] + ur[lca]);
    if(!keep) convex.segments.clear(), xevnoc.segments.clear();
}
void solve(){
	int i, j;
    dp[n] = pd[n] = sum[n] = ru[n] = ur[n] = 0, dpt[n] = -1;
	t = 0, go(0, n, 0);
    assert(t == n);
    // for0(i, n) val += sbtr[i];
    // int mxht = -1;
    // for0(i, n) mxht = max(mxht, dpt[i]);
    global = -LLONG_MAX;
    dsu_prep(0, n, 0);
    dsu(0, n, 0);
    pl(global), nl;
    // pl(global), sp, pl(val), sp, pi(mxht), nl; 
}
int main(){
    // freopen("input.txt", "r", stdin); 
    // freopen("output.txt", "w", stdout); 
    // freopen("0.in", "r", stdin);
    // freopen("0.out", "wb", stdout);

    // freopen("1.in", "r", stdin);
    // freopen("1.out", "wb", stdout);

    // freopen("2.in", "r", stdin);
    // freopen("2.out", "wb", stdout);

    // freopen("3.in", "r", stdin);
    // freopen("3.out", "wb", stdout);

    // freopen("4.in", "r", stdin);
    // freopen("4.out", "wb", stdout);

    // freopen("5.in", "r", stdin);
    // freopen("5.out", "wb", stdout);

    // freopen("6.in", "r", stdin);
    // freopen("6.out", "wb", stdout);

    // freopen("7.in", "r", stdin);
    // freopen("7.out", "wb", stdout);

    // freopen("8.in", "r", stdin);
    //   freopen("8.out", "wb", stdout);

   //  freopen("9.in", "r", stdin);
   // freopen("9.out", "wb", stdout);
    int cs, ts, total = 0;
    si(ts);
    assert(ts >= 1 && ts <= 100); 
    for0(cs, ts){
    	int i, j;
    	si(n);
        total += n; 
        assert(n >= 2 && n <= 300000); 
    	for0(i, n){
            si(ara[i]); 
            assert(ara[i] >= -100000 && ara[i] <= 100000); 
        }
        for0(i, n) adj[i].clear();
    	for0(i, n - 1){
    		int u, v;
    		sii(u, v);
            assert(u >= 1 && u <= n); 
            assert(v >= 1 && v <= n); 
            --u, --v;
    		adj[u].push_back(v), adj[v].push_back(u);
    	}
    	solve();
    }
    assert(total <= 300000); 
}
Tester's Solution
//teja349
#include <bits/stdc++.h>
#include <vector>
#include <set>
#include <map>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <utility>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <iomanip>
//setbase - cout << setbase (16); cout << 100 << endl; Prints 64
//setfill -   cout << setfill ('x') << setw (5); cout << 77 << endl; prints xxx77
//setprecision - cout << setprecision (14) << f << endl; Prints x.xxxx
//cout.precision(x)  cout<<fixed<<val;  // prints x digits after decimal in val
 
using namespace std; 
#define f(i,a,b) for(i=a;i<b;i++)
#define rep(i,n) f(i,0,n)
#define fd(i,a,b) for(i=a;i>=b;i--)
#define pb push_back
#define mp make_pair
#define vi vector< int >
#define vl vector< ll >
#define ss second
#define ff first
#define ll long long
#define pii pair< int,int >
#define pll pair< ll,ll >
#define inf (1000*1000*1000+5)
#define all(a) a.begin(),a.end()
#define tri pair<int,pii>
#define vii vector<pii>
#define vll vector<pll>
#define viii vector<tri>
#define mod (1000*1000*1000+7)
#define pqueue priority_queue< int >
#define pdqueue priority_queue< int,vi ,greater< int > >
#define flush fflush(stdout) 
#define primeDEN 727999983
#define int ll
 
 
/*
 
Copied from rajat1603 (nice usage of comparator).
 
Dynamic convex Hull 
 
Cool!!!
increasing slopes (lower convex HULL of points(intuitive) , max)
/////////   also change for equality case of slopes **
*/
struct line{
	long long a , b;
	double xleft;
	bool type;
	line(long long _a , long long _b){
		a = _a;
		b = _b;
		type = 0;
	}
	bool operator < (const line &other) const{
		if(other.type){
			return xleft < other.xleft;
		}
		return a > other.a;
	}
};
double meet(line x , line y){
	return 1.0 * (y.b - x.b) / (x.a - y.a);
}
struct cht{
	set < line > hull;
	cht(){
		hull.clear();
	}
	void resethull(){
		hull.clear();
	}
	typedef set < line > :: iterator ite;
	bool hasleft(ite node){
		return node != hull.begin();
	}
	bool hasright(ite node){
		return node != prev(hull.end());
	}
	void updateborder(ite node){
		if(hasright(node)){
			line temp = *next(node);
			hull.erase(temp);
			temp.xleft = meet(*node , temp);
			hull.insert(temp);
		}
		if(hasleft(node)){
			line temp = *node;
			temp.xleft = meet(*prev(node) , temp);
			hull.erase(node);
			hull.insert(temp);
		}
		else{
			line temp = *node;
			hull.erase(node);
			temp.xleft = -1e18;
			hull.insert(temp);
		}
	}
	bool useless(line left , line middle , line right){
		return meet(left , middle) > meet(middle , right);
	}
	bool useless(ite node){
		if(hasleft(node) && hasright(node)){
			return useless(*prev(node) , *node , *next(node));
		}
		return 0;
	}
	void addline(long long a , long long b){
		a*=-1;
		b*=-1;
		line temp = line(a , b);
		auto it = hull.lower_bound(temp);
		if(it != hull.end() && it -> a == a){
			if(it -> b > b){
				hull.erase(it);
			}
			else{
				return;
			}
		}
		hull.insert(temp);
		it = hull.find(temp);
		if(useless(it)){
			hull.erase(it);
			return;
		}
		while(hasleft(it) && useless(prev(it))){
			hull.erase(prev(it));
		}
		while(hasright(it) && useless(next(it))){
			hull.erase(next(it));
		}
		updateborder(it);
	}
	long long getbest(long long x){
		if(hull.empty()){
			return -1e18;
		}
		line query(0 , 0);
		query.xleft = x;
		query.type = 1;
		auto it = hull.lower_bound(query);
		it = prev(it);
		return -1*(it -> a * x + it -> b);
	}
};
 
cht hull1;
cht hull2;
int offseta1,offseta2,offsetb1,offsetb2,adderx1,adderx2;
 
int resethull(){
	hull1.resethull();
	hull2.resethull();
	offseta1 = 0;
	offseta2 = 0;
	offsetb1 = 0;
	offsetb2 = 0;
	adderx1 = 0;
	adderx2 = 0;
	return 0;
}
 
void insertline1(int a,int b){
	hull1.addline(a - offseta1, b - offsetb1 - adderx1*a);
}
 
void insertline2(int a,int b){
	hull2.addline(a - offseta2, b - offsetb2 - adderx2*a);
}
 
int foo[312345];
int moveup(int cur){
	offsetb2-=adderx2;
	offseta2++;
	adderx2+=foo[cur];
	
	adderx1+=1;
	offseta1+=foo[cur];
	offsetb1-=adderx1*foo[cur];
	return 0;
}
 
int query2(int x){
	int val=hull2.getbest(x+adderx2);
	val+=offsetb2;
	val+=offseta2*(x+adderx2);
	//max2=max(max2,val);
	return val;
}
 
int query1(int x){
	int val=hull1.getbest(x+adderx1);
	val+=offsetb1;
	val+=offseta1*(x+adderx1);
	//max1=max(max1,val);
	return val;
}
 
vector<vi> adj(312345);
int par[312345],subtree[312345];
int dep[312345];
int sum1[312345],sum2[312345],sumn[312345];
int dfs1(int cur,int paren){
	dep[cur]=dep[paren]+1;
	par[cur]=paren;
	sumn[cur]=sumn[paren]+foo[cur];
	sum2[cur]=sum2[paren]+sumn[cur];
	sum1[cur]=dep[cur]*foo[cur]+sum1[paren];
	subtree[cur]=1;
	int i;
	rep(i,adj[cur].size()){
		if(adj[cur][i]!=paren){
			dfs1(adj[cur][i],cur);
			subtree[cur]+=subtree[adj[cur][i]];
		}
	}
	return 0;
}
 
int dp1[312345],dp2[312345];
 
int maxi;
int dfs2(int cur,int ances){
	int i;
	int val=query1(dep[cur]-dep[ances]+1)+dp2[cur];
	val+=sum2[cur]-sum2[ances];
	val-=sumn[ances]*(dep[cur]-dep[ances]);
	maxi=max(maxi,val);
	val=query2(sumn[cur]-sumn[ances])+dp1[cur];
	val+=sum1[cur]-sum1[ances];
	val-=(dep[ances])*(sumn[cur]-sumn[ances]);
	maxi=max(maxi,val);
	rep(i,adj[cur].size()){
		if(adj[cur][i]==par[cur])
			continue;
		dfs2(adj[cur][i],ances);
	}
	return 0;
}
 
int adddfs(int cur,int ances){
	int i;
	int val=sum1[cur]-sum1[ances] - (dep[ances]+1)*(sumn[cur]-sumn[ances]);
	insertline1(sumn[cur]-sumn[ances],dp1[cur]+val);
 
	val=sum2[cur]-sum2[ances]-(dep[cur]-dep[ances])*sumn[ances];
	insertline2(dep[cur]-dep[ances],dp2[cur]+val);
	rep(i,adj[cur].size()){
		if(adj[cur][i]==par[cur])
			continue;
		adddfs(adj[cur][i],ances);
	}
	return 0;
}
// dp1 starts after i and goes down
// dp2 ends after i and goes up from down. 
int dfs(int cur){
	int i;
	int maxim=0,child=-1;
	rep(i,adj[cur].size()){
		if(adj[cur][i]!=par[cur]){
			if(subtree[adj[cur][i]]>maxim){
				maxim=subtree[adj[cur][i]];
				child=adj[cur][i];
			}
		}
	}
	rep(i,adj[cur].size()){
		if(adj[cur][i]!=par[cur] && adj[cur][i]!=child){
			dfs(adj[cur][i]);
			resethull();
		}
	}
	if(child==-1){
		dp1[cur]=0;
		dp2[cur]=0;
		insertline1(0,dp1[cur]);
		insertline2(0,dp2[cur]);
		moveup(cur);
		return 0;
	}
	dfs(child);
	rep(i,adj[cur].size()){
		if(adj[cur][i]!=par[cur] && adj[cur][i]!=child){
			dfs2(adj[cur][i],par[cur]);
			adddfs(adj[cur][i],cur);
		}
	}
	dp1[cur]=query1(2)+foo[cur];
	dp2[cur]=query2(foo[cur])+foo[cur];
	maxi=max(maxi,dp1[cur]);
	maxi=max(maxi,dp2[cur]);
	dp1[cur]=max(dp1[cur],0LL);
	dp2[cur]=max(dp2[cur],0LL);
	insertline1(0,dp1[cur]);
	insertline2(0,dp2[cur]);
	moveup(cur);
	return 0;
} 
signed main(){
	std::ios::sync_with_stdio(false); cin.tie(NULL);
	int t;
	cin>>t;
	while(t--){
		maxi=inf;
		maxi*=inf;
		maxi*=-1;
		int n;
		cin>>n;
		resethull();
		int i;
		f(i,1,n+1){
			cin>>foo[i];
			adj[i].clear();
		}
		int u,v;
		rep(i,n-1){
			cin>>u>>v;
			adj[u].pb(v);
			adj[v].pb(u);
		}
		dfs1(1,0);
		dfs(1);
		cout<<maxi<<endl;
	}
}
Editorialist's Solution
import java.util.*;
import java.io.*;
import java.text.*;
class JMPFEVER{
    //SOLUTION BEGIN
    long IINF = (long)1e18, ans = -IINF;
    int ti = -1, n;
    int[][] g;
    int[] sub, par, dep, st, en, eu;
    long[] A, sum, sum1, sum2, up, down;
    CHT downToUp, upToDown;
    void pre() throws Exception{}
    void solve(int TC) throws Exception{
        n = ni();
        A = new long[1+n];
        for(int i = 1; i<= n; i++)A[i] = nl();
        int[] from = new int[n-1], to = new int[n-1];
        for(int i = 0; i< n-1; i++){from[i] = ni();to[i] = ni();}
        g = make(1+n, n-1, from, to);
        sub = new int[1+n];par = new int[1+n];dep = new int[1+n];
        st = new int[1+n];en = new int[1+n];eu = new int[1+n];
        sum = new long[1+n]; sum2 = new long[1+n]; sum1 = new long[1+n];
        up = new long[1+n];down = new long[1+n];
        ti = -1;
        prepare(1, 0);
        ans = -IINF;
        resetHulls();
        dfs(1, true);
        pn(ans);
    }
    void resetHulls(){
        downToUp = new CHT(CHT.MAX);
        upToDown = new CHT(CHT.MAX);
    }
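    void dfs(int lca, boolean keep) throws Exception{
        // hc = heavy child (largest subtree), or -1 for a leaf.
        int hc = -1;
        for(int v:g[lca]){
            if(v == par[lca])continue;
            if(hc == -1 || sub[v] > sub[hc])hc = v;
        }
        // DSU on tree: solve light children first; their hulls are discarded (keep = false).
        for(int v:g[lca])
            if(v != par[lca] && v != hc)
                dfs(v, false);
        
        if(hc != -1){
            // Solve the heavy child last; its lines stay in both hulls.
            dfs(hc, true);
            for(int v:g[lca]){
                if(v == par[lca] || v == hc)continue;
                
                // Pair every node w of this light subtree with the nodes already in the hulls:
                // trips that climb from an inserted node to lca and then descend to w (downToUp),
                // and trips that climb from w to lca and then descend to an inserted node (upToDown).
                for(int t = st[v]; t <= en[v]; t++){
                    int w = eu[t];
                    ans = Math.max(ans, downToUp.query(-sum[lca]+A[lca]+sum[w]-sum[lca])
                            -sum1[lca]+dep[lca]*sum[lca]+(-dep[lca]+1)*(A[lca]+sum[w]-sum[lca])+down(w, lca));
                    ans = Math.max(ans, upToDown.query(-dep[lca]*2+1+dep[w])
                            -sum2[lca]+dep[lca]*sum[lca]+up(w, lca)+(dep[w]-dep[lca]+1)*(A[lca]-sum[lca]));
                }
                // Only then insert this subtree's lines, so w is never paired with its own subtree.
                for(int t = st[v]; t <= en[v]; t++){
                    int w = eu[t];
                    downToUp.add(dep[w], up[w]+sum1[w]);
                    upToDown.add(sum[w], down[w]+sum2[w]);
                }
            }
            
            // The pairing formulas with w = lca: best trip climbing into lca (up) and descending from lca (down).
            up[lca] = downToUp.query(-sum[lca]+A[lca]+sum[lca]-sum[lca])
                            -sum1[lca]+dep[lca]*sum[lca]+(-dep[lca]+1)*(A[lca]+sum[lca]-sum[lca]);
            down[lca] = upToDown.query(-dep[lca]*2+1+dep[lca])
                            -sum2[lca]+dep[lca]*sum[lca]+(dep[lca]-dep[lca]+1)*(A[lca]-sum[lca]);
            
            ans = Math.max(ans, Math.max(up[lca], down[lca]));
        }
        // Clamp at 0: a trip may also simply start at lca (up) or stop at lca (down).
        up[lca] = Math.max(up[lca], 0);
        down[lca] = Math.max(down[lca], 0);
        
        downToUp.add(dep[lca], up[lca]+sum1[lca]);
        upToDown.add(sum[lca], down[lca]+sum2[lca]);
        
        if(!keep)resetHulls();
    }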
    void prepare(int u, int p){
        eu[++ti] = u;
        st[u] = ti;
        sub[u] = 1;
        par[u] = p;
        dep[u] = dep[p]+1;
        sum[u] = sum[p]+A[u];
        
        sum1[u] = sum1[p]+sum[u];
        sum2[u] = sum2[p]+dep[u]*A[u];
        
        for(int v:g[u]){
            if(v == p)continue;
            prepare(v, u);
            sub[u] += sub[v];
        }
        en[u] = ti;
    }
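    // Paths between u (inclusive) and its ancestor l (exclusive):
    // up(u, l)   = up[u] (best trip climbing into u) + value of the climb u -> ... -> child of l,
    //              weighted 1, 2, ... starting at u.
    // down(u, l) = value of the descent child of l -> ... -> u, weighted 1, 2, ... from the top,
    //              + down[u] (best trip descending from u).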
    long up(int u, int l){
        return up[u]+sum1[u]-sum1[l]-sum[l]*(dep[u]-dep[l]);
    }
    long down(int u, int l){
        return down[u]+sum2[u]-sum2[l]-dep[l]*(sum[u]-sum[l]);
    }
    int[][] make(int n, int e, int[] from, int[] to){
        int[] c = new int[n];int[][] g = new int[n][];
        for(int i = 0; i< e; i++){c[from[i]]++;c[to[i]]++;}
        for(int i = 0; i< n; i++)g[i] = new int[c[i]];
        for(int i = 0; i< e; i++){
            g[from[i]][--c[from[i]]] = to[i];
            g[to[i]][--c[to[i]]] = from[i];
        }
        return g;
    }
    class CHT {
        //http://codeforces.com/contest/932/submission/35323630
        static final int MIN = -1, MAX = 1;
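        // Empty-hull sentinel. (long)1e19 exceeds Long.MAX_VALUE and saturates; 1e18 leaves
        // headroom when callers add offsets to the query result.
        long IINF = (long)1e18;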
        public TreeSet<CHT.Line> hull;
        int type;
        boolean query = false;

        Comparator<CHT.Line> comp = new Comparator<CHT.Line>() {
            public int compare(CHT.Line a, CHT.Line b) {
                if (!query) return type * Long.compare(a.m, b.m);
                if (a.left == b.left)
                    return Long.compare(a.m, b.m);
                return Double.compare(a.left, b.left);
            }
        };
        public CHT(final int type) {
            this.type = type;
            hull = new TreeSet<>(comp);
        }
        public void add(long m, long c){
            add(new Line(m, c));
        }
        public void add(CHT.Line a) {
            CHT.Line[] LR = {hull.lower(a), hull.ceiling(a)};
            for (int i = 0; i < 2; i++)
                if (LR[i] != null && LR[i].m == a.m) {
                    if (type == 1 && LR[i].b >= a.b)
                        return;
                    if (type == -1 && LR[i].b <= a.b)
                        return;
                    remove(LR[i]);
                }
            hull.add(a);
            CHT.Line L = hull.lower(a), R = hull.higher(a);
            if (L != null && R != null && a.inter(R) <= R.left) {
                hull.remove(a);
                return;
            }
            CHT.Line LL = (L != null) ? hull.lower(L) : null;
            CHT.Line RR = (R != null) ? hull.higher(R) : null;
            if (L != null) a.left = a.inter(L);
            if (R != null) R.left = a.inter(R);
            while (LL != null && L.left >= a.inter(LL)) {
                remove(L);
                a.left = a.inter(L = LL);
                LL = hull.lower(L);
            }
            while (RR != null && R.inter(RR) <= a.inter(RR)) {
                remove(R);
                RR.left = a.inter(R = RR);
                RR = hull.higher(R);
            }
        }
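        public long query(long x){
            // Probe line: only its 'left' field matters while the comparator is in query mode.
            CHT.Line temp = new CHT.Line(0, 0, 0);
            temp.left = x;
            query = true;
            Line line = hull.floor(temp);
            // Always restore insertion ordering, including the empty-hull case below;
            // otherwise later add() calls would compare lines by 'left' instead of slope.
            query = false;
            if(line == null)return IINF*-this.type;
            return line.eval(x);
        }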
        private void remove(CHT.Line x) {
            hull.remove(x);
        }
        public int size() {
            return hull.size();
        }
        public class Line {
            long m;
            long b;
            double left = Long.MIN_VALUE;
            public Line(long m, long x, long y) {
                this.m = m;
                this.b = -m * x + y;
            }
            public Line(long m, long b) {
                this.m = m;
                this.b = b;
            }
            public long eval(long x) {
                return m * x + b;
            }
            public double inter(CHT.Line x) {
                return (double) (x.b - this.b) / (double) (this.m - x.m);
            }
        }
    }
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    DecimalFormat df = new DecimalFormat("0.00000000000");
    static boolean multipleTC = true;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        new JMPFEVER().run();
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}

    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }

        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }

        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException  e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }

        String nextLine() throws Exception{
            String str = "";
            try{   
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }  
            return str;
        }
    }
}
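To see what the helpers `up` and `down` compute, here is a hedged sketch. With the dp terms up[u] and down[u] taken as 0, up(u, l) reduces to the value of the upward walk from u to the child of l on that path (u weighted 1), and down(u, l) reduces to the value of the downward walk from that child to u (the child weighted 1). Every vertical path lies on a single root-to-node chain, so it is enough to check both identities on a chain built with the same recurrences as `prepare`. The class `PathIdentityCheck` and its method `check` are hypothetical names used only for this illustration.

```java
// Hypothetical check: on a single chain 1 - 2 - ... - n rooted at 1, compare the closed forms
// used by up(u, l) and down(u, l) (dp terms taken as 0) against direct weighted path sums.
class PathIdentityCheck {
    static boolean check(long[] A) {             // A[1..n] are node values; A[0] is unused
        int n = A.length - 1;
        int[] dep = new int[n + 1];
        long[] sum = new long[n + 1], sum1 = new long[n + 1], sum2 = new long[n + 1];
        for (int u = 1; u <= n; u++) {            // same recurrences as prepare(u, p) with p = u - 1
            int p = u - 1;
            dep[u] = dep[p] + 1;
            sum[u] = sum[p] + A[u];
            sum1[u] = sum1[p] + sum[u];
            sum2[u] = sum2[p] + dep[u] * A[u];
        }
        for (int l = 0; l <= n; l++)              // l = 0 plays the role of the virtual root parent
            for (int u = l + 1; u <= n; u++) {
                long upDirect = 0;                // climb u, u-1, ..., l+1 with weights 1, 2, ...
                for (int x = u, i = 1; x > l; x--, i++) upDirect += A[x] * i;
                long downDirect = 0;              // descend l+1, ..., u with weights 1, 2, ...
                for (int x = l + 1, i = 1; x <= u; x++, i++) downDirect += A[x] * i;
                long upFormula = sum1[u] - sum1[l] - sum[l] * (dep[u] - dep[l]);
                long downFormula = sum2[u] - sum2[l] - dep[l] * (sum[u] - sum[l]);
                if (upDirect != upFormula || downDirect != downFormula) return false;
            }
        return true;
    }

    public static void main(String[] args) {
        System.out.println(check(new long[]{0, 4, -2, 7, 1, -3})); // true
    }
}
```

The identities follow from swapping the order of summation. For example, sum1[u] - sum1[l] - sum[l] * (dep[u] - dep[l]) adds A[y] once for every node between y and u, and that count is exactly y's position when walking up from u.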

Feel free to share your approach. Suggestions are welcome, as always. 🙂