PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: satyam_343
Tester: sushil2006
Editorialist: iceknight1093
DIFFICULTY:
Hard
PREREQUISITES:
DFS, sorting, segment tree/Fenwick tree, binary search
PROBLEM:
You’re given a tree with N vertices. Vertex i has value A_i.
The score of a tree is defined to be the minimum number of the following operations required to make every A_i equal to 0:
- Choose a subset S of vertices of the tree such that for any x, y\in S such that x\neq y, x is not an ancestor of y.
- Then, reduce A_x by 1 for each x \in S.
Answer Q queries.
In each query, you’re given a value K. You want the score of the tree to be \leq K.
Find the minimum number of operations needed to achieve this (or report -1 if it is impossible), where in one operation you can choose two nodes u, v such that A_u \gt 0, then reduce A_u by 1 and increase A_v by 1.
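The statement does not spell out how to evaluate the score itself, but the recap below relies on it being the largest root-to-vertex path sum \max_u S_u (with S_u the sum of values on the root-to-u path). This follows from a standard chain/antichain argument: every operation picks an antichain, so it lowers any root-to-u path sum by at most 1; conversely, putting vertex u into operations S_u - A_u + 1, ..., S_u never places an ancestor and a descendant in the same operation, because a proper ancestor v of u has S_v \leq S_u - A_u. A minimal sketch, assuming 0-indexed vertices, vertex 0 as the root and A_i \geq 0; scoreOf is a hypothetical name.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>
using namespace std;
using ll = long long;

// Hypothetical helper: score of the tree = maximum root-to-vertex path sum.
// Vertex 0 is the root; adj is an undirected adjacency list; all A[i] >= 0.
// Uses an explicit stack so that path-like trees do not overflow the call stack.
ll scoreOf(const vector<vector<int>>& adj, const vector<ll>& A) {
    int n = (int)adj.size();
    vector<ll> S(n, 0);                      // S[u] = sum of A on the path root -> u
    vector<int> parent(n, -1), stk = {0};
    S[0] = A[0];
    ll best = S[0];
    while (!stk.empty()) {
        int u = stk.back();
        stk.pop_back();
        for (int v : adj[u]) {
            if (v == parent[u]) continue;
            parent[v] = u;
            S[v] = S[u] + A[v];
            best = max(best, S[v]);
            stk.push_back(v);
        }
    }
    return best;
}
```

For a root with value 3 and two leaves with value 0, the score is 3: the root lies on every path, and each operation can lower it by at most 1.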
EXPLANATION:
First, a recap of the solution to the easy version, which solves a single query.
- Initially, for each vertex u, compute the following quantities with a DFS:
- S_u, the sum of values on the 1\to u path.
- \text{sub}_u, which is the maximum of S_v across all v in the subtree of u.
- L_u, the number of leaves in the subtree of u.
- If L_1\cdot K \lt \text{sum}(A), the answer is -1.
- First, the subtraction phase.
Define X to be the sum of \max(0, \min(A_u, \text{sub}_u - (S_u - A_u) - K)) across all u.
This is the number of subtractions necessary to make all root → node paths have a sum that’s \leq K.
- Next, the addition phase.
Every subtracted unit has to be added back to some vertex, and it costs nothing extra to put it on a leaf whose path sum is still below K.
Compute F, the amount of free space in the leaves after the subtractions: the sum of K - S_v over all leaves v, with S recomputed after the subtractions.
- If F \geq X, the answer is X.
- Otherwise, process vertices y in descending order of their L_y values.
- For each y, perform as many extra subtractions as needed to make F - X non-negative. Each extra subtraction at y frees L_y units of leaf space but creates one more unit that must be placed, so it increases F - X by L_y - 1; the required count is a single ceiling division.
Of course, A_y cannot fall below 0.
- The answer is the total number of subtraction operations done.
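For reference, here is a compact sketch of the single-query procedure recapped above. It assumes 0-indexed vertices with vertex 0 as the root, an undirected adjacency list, and A_i \geq 0; solveOne and its local variables are hypothetical names, not taken from the author's code. It uses a BFS order instead of a recursive DFS so that deep trees cannot overflow the stack, and it runs in \mathcal{O}(N\log N) because of the sort by leaf count.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>
using namespace std;
using ll = long long;

// Hypothetical single-query solver following the easy-version recap.
// Returns -1 if the score can never be brought down to <= K.
ll solveOne(const vector<vector<int>>& adj, const vector<ll>& A, ll K) {
    int n = (int)adj.size();
    vector<int> order{0}, parent(n, -1);
    vector<ll> S(n, 0), sub(n, 0), L(n, 0);
    vector<char> isLeaf(n, 0);
    S[0] = A[0];
    for (size_t i = 0; i < order.size(); ++i) {      // BFS: parents before children
        int u = order[i];
        for (int v : adj[u]) {
            if (v == parent[u]) continue;
            parent[v] = u;
            S[v] = S[u] + A[v];
            order.push_back(v);
        }
    }
    for (int i = n - 1; i >= 0; --i) {               // reverse BFS: children first
        int u = order[i];
        sub[u] = max(sub[u], S[u]);
        if (L[u] == 0) { isLeaf[u] = 1; L[u] = 1; }  // no child contributed a leaf
        int p = parent[u];
        if (p >= 0) { sub[p] = max(sub[p], sub[u]); L[p] += L[u]; }
    }
    ll total = accumulate(A.begin(), A.end(), 0LL);
    if (L[0] * K < total) return -1;

    // Subtraction phase: c[u] units are removed from vertex u.
    vector<ll> c(n);
    ll X = 0, leafSum = 0, removedFromLeaves = 0;
    for (int u = 0; u < n; ++u) {
        c[u] = max(0LL, min(A[u], sub[u] - (S[u] - A[u]) - K));
        X += c[u];
        removedFromLeaves += L[u] * c[u];            // each leaf below u loses c[u]
        if (isLeaf[u]) leafSum += S[u];
    }
    ll F = K * L[0] - (leafSum - removedFromLeaves); // free space in the leaves
    if (F >= X) return X;

    // Addition phase: extra subtractions, vertices with more leaves first.
    vector<int> byLeaves(n);
    iota(byLeaves.begin(), byLeaves.end(), 0);
    sort(byLeaves.begin(), byLeaves.end(), [&](int i, int j) { return L[i] > L[j]; });
    ll need = X - F, extra = 0;
    for (int y : byLeaves) {
        if (need <= 0 || L[y] <= 1) break;           // L[y] == 1 gains nothing
        ll gain = L[y] - 1, cap = A[y] - c[y];
        ll take = min(cap, (need + gain - 1) / gain);
        extra += take;
        need -= take * gain;
    }
    return X + extra;
}
```

On the star with root value 3 and two leaves of value 0, K = 2 needs two moves (one unit from the root to each leaf), which is what the sketch returns.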
This will take \mathcal{O}(N) per query (if the vertices are sorted by leaf count beforehand), which is of course too slow.
Some parts of the above solution don’t depend on the value of K, and so can be precomputed.
These include the values of S_u, L_u, \text{sub}_u, as well as sorting vertices based on leaf count.
Everything else depends on K though, so we’ll need to look into speeding that up.
First, let’s look at computing the initial value of X, which is
X = \sum_u \max(0, \min(A_u, \text{sub}_u - (S_u - A_u) - K)).
Notice that for a vertex u, there are three possible “states”:
- State 1: A_u \leq \text{sub}_u - (S_u - A_u) - K. The vertex adds A_u to the sum.
This can be rearranged to just the inequality K \leq \text{sub}_u - S_u.
- State 2: \text{sub}_u - (S_u - A_u) - K \leq 0. The vertex adds 0 to the sum.
This rearranges to \text{sub}_u - (S_u - A_u) \leq K.
- State 3: every other case. The vertex adds \text{sub}_u - (S_u - A_u) - K to the sum.
To compute the sum across all u, observe that we only need to know the sum of A_u across all state 1 vertices, the sum of \text{sub}_u - (S_u - A_u) across all state 3 vertices, and the number of state 3 vertices.
If these quantities are x, y, z respectively, the overall sum will be x+y-K\cdot z.
We want to compute these quantities for various values of K.
One way to do this quickly is to solve the problem offline.
That is, sort the queries, start with K = 0, and keep increasing K.
Then every vertex starts out in state 1 (since \text{sub}_u \geq S_u), and for each u:
- When K reaches \text{sub}_u - S_u + 1, the vertex shifts to state 3.
- When K reaches \text{sub}_u - S_u + A_u, the vertex shifts to state 2, after which it will always remain there.
So, there are only 2N “important” points where quantities can change.
Each transition is handled in constant time, since only one vertex moves from one summation to another.
Sorting the 2N transition points together with the queries and sweeping over them therefore gives x, y, z, and hence X, for every query in \mathcal{O}((N+Q)\log(N+Q)) total time.
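A minimal sketch of this ascending-K sweep, assuming the values d_u = \text{sub}_u - S_u have already been computed by the DFS and that all queried K are non-negative; offlineX is a hypothetical name. Vertices with A_u = 0 are skipped: they contribute 0 for every K, and their two transition points (d_u + 1 and d_u + A_u) would otherwise come in the wrong order.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>
using namespace std;
using ll = long long;

// Hypothetical helper. d[u] = sub_u - S_u. Vertex u is in state 1 for K <= d[u],
// in state 2 for K >= d[u] + A[u], and in state 3 in between.
// Returns X(K) = sum_u max(0, min(A_u, d_u + A_u - K)) for every query K.
vector<ll> offlineX(const vector<ll>& d, const vector<ll>& A, const vector<ll>& Ks) {
    int n = (int)d.size(), q = (int)Ks.size();
    vector<pair<ll, int>> to3, to2;  // (K at which the transition happens, vertex)
    ll x = 0, y = 0, z = 0;          // x: sum A over state 1, y: sum (d + A) over state 3, z: #state 3
    for (int u = 0; u < n; ++u) {
        if (A[u] == 0) continue;     // contributes 0 for every K
        x += A[u];                   // at K = 0 every vertex is in state 1
        to3.push_back({d[u] + 1, u});
        to2.push_back({d[u] + A[u], u});
    }
    sort(to3.begin(), to3.end());
    sort(to2.begin(), to2.end());
    vector<int> ord(q);
    iota(ord.begin(), ord.end(), 0);
    sort(ord.begin(), ord.end(), [&](int i, int j) { return Ks[i] < Ks[j]; });
    vector<ll> res(q);
    size_t p3 = 0, p2 = 0;
    for (int qi : ord) {
        ll K = Ks[qi];
        for (; p3 < to3.size() && to3[p3].first <= K; ++p3) {  // state 1 -> state 3
            int u = to3[p3].second;
            x -= A[u]; y += d[u] + A[u]; ++z;
        }
        for (; p2 < to2.size() && to2[p2].first <= K; ++p2) {  // state 3 -> state 2
            int u = to2[p2].second;
            y -= d[u] + A[u]; --z;
        }
        res[qi] = x + y - K * z;
    }
    return res;
}
```

Sweeping K upward is just one choice; the editorialist's code below instead sorts the queries in descending order and moves vertices between the same states in the opposite direction.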
We looked at X, the number of always-necessary subtractions.
Let’s now look at F, the amount of free space available to us.
One way to compute F is as follows:
- Start with F equal to the sum of S_u values of all leaves.
- Then, for each operation performed in the subtraction phase, say on vertex y, subtract L_y from F.
At the end of all operations, F will equal the amount of space taken up by the leaves.
- Finally, replace F by K\cdot L_1 - F, since that’s the empty space available to us.
K\cdot L_1 is trivial to evaluate for each query and the initial sum of S_u over all leaves does not depend on K, so the only hard part is the second step.
Observe that the second step is in fact quite similar to what we wanted when computing the value of X earlier.
Specifically, the value we want to compute now is just
\sum_u L_u \cdot \max(0, \min(A_u, \text{sub}_u - (S_u - A_u) - K)).
The exact same analysis shows that each vertex will always be in one of three states, and there are only 2N total transitions between states, so once again all relevant values can be computed by processing queries offline.
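As a sketch, the sweep for F is the X sweep with every vertex weighted by L_u, followed by the two constant-time corrections. It assumes d_u = \text{sub}_u - S_u, the leaf counts L_u, L_1 and the initial leaf sum are precomputed, and that queries are non-negative; offlineF and its parameter names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>
using namespace std;
using ll = long long;

// Hypothetical helper. d[u] = sub_u - S_u, leaves[u] = L_u, L1 = L_1,
// leafSum = sum of S over all leaves before any subtraction.
// c_u(K) = max(0, min(A_u, d_u + A_u - K)) units are subtracted at u, each lowering
// the path sum of leaves[u] leaves, so F(K) = K*L1 - (leafSum - sum_u leaves[u]*c_u(K)).
vector<ll> offlineF(const vector<ll>& d, const vector<ll>& A, const vector<ll>& leaves,
                    ll L1, ll leafSum, const vector<ll>& Ks) {
    int n = (int)d.size(), q = (int)Ks.size();
    vector<pair<ll, int>> to3, to2;
    ll xw = 0, yw = 0, zw = 0;   // x, y, z from the X sweep, each term weighted by leaves[u]
    for (int u = 0; u < n; ++u) {
        if (A[u] == 0) continue;
        xw += leaves[u] * A[u];
        to3.push_back({d[u] + 1, u});
        to2.push_back({d[u] + A[u], u});
    }
    sort(to3.begin(), to3.end());
    sort(to2.begin(), to2.end());
    vector<int> ord(q);
    iota(ord.begin(), ord.end(), 0);
    sort(ord.begin(), ord.end(), [&](int i, int j) { return Ks[i] < Ks[j]; });
    vector<ll> res(q);
    size_t p3 = 0, p2 = 0;
    for (int qi : ord) {
        ll K = Ks[qi];
        for (; p3 < to3.size() && to3[p3].first <= K; ++p3) {  // state 1 -> state 3
            int u = to3[p3].second;
            xw -= leaves[u] * A[u]; yw += leaves[u] * (d[u] + A[u]); zw += leaves[u];
        }
        for (; p2 < to2.size() && to2[p2].first <= K; ++p2) {  // state 3 -> state 2
            int u = to2[p2].second;
            yw -= leaves[u] * (d[u] + A[u]); zw -= leaves[u];
        }
        ll removed = xw + yw - K * zw;         // sum_u leaves[u] * c_u(K)
        res[qi] = K * L1 - (leafSum - removed);
    }
    return res;
}
```

The test values come from a root with value 1 and two leaves with values 2 and 0, where d = (2, 0, 0), L = (2, 1, 1), L_1 = 2 and the initial leaf sum is 4.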
Now, let’s look at a single query once again.
We know the initial values of X and F for it, computed offline as above.
If X \leq F then the answer is just X, of course.
This leaves the case X\gt F, where we need to perform some additional subtraction operations to obtain more free space.
Recall that these were done in decreasing order of subtree leaf count.
In particular, let c_u = \max(0, \min(A_u, \text{sub}_u - (S_u - A_u) - K)) be the number of subtraction operations done on vertex u during the subtraction phase.
We can then perform at most (A_u - c_u) more operations on u, each giving us a “profit” of L_u - 1; and we want a total profit of at least X - F.
Unsurprisingly, this can be handled with the same technique as before.
Let u_1, u_2, \ldots, u_N be the (non-increasing) sorted order of vertices in terms of their number of leaves.
Our goal is then to find the smallest index p such that
\sum_{i=1}^{p} (A_{u_i} - c_{u_i}) \cdot (L_{u_i} - 1) \geq X - F.
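If this p is known, the number of extra operations is \sum_{i=1}^{p-1} (A_{u_i} - c_{u_i}), plus \left\lceil \frac{X - F - \sum_{i=1}^{p-1} (A_{u_i} - c_{u_i})(L_{u_i} - 1)}{L_{u_p} - 1} \right\rceil operations at u_p itself to make up the difference. The final answer is X plus this number of extra operations.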
Our goal is hence to just find p.
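Without any data structure, finding p and the extra operations is a linear scan over u_1, u_2, ..., which is exactly the \mathcal{O}(N)-per-query step that the segment tree replaces. This sketch assumes the capacities A_{u_i} - c_{u_i} and gains L_{u_i} - 1 are already listed in non-increasing leaf-count order; extraOps is a hypothetical name.

```cpp
#include <cassert>
#include <vector>
using namespace std;
using ll = long long;

// Hypothetical helper. cap[i] = A_{u_i} - c_{u_i} (extra subtractions still allowed),
// gain[i] = L_{u_i} - 1 (profit per extra subtraction), vertices sorted by L descending.
// Returns the number of extra operations needed to gain at least `need`, or -1.
ll extraOps(const vector<ll>& cap, const vector<ll>& gain, ll need) {
    ll ops = 0;
    for (size_t i = 0; i < cap.size() && need > 0; ++i) {
        if (gain[i] <= 0) break;                   // only L = 1 vertices remain: no profit
        if (cap[i] * gain[i] < need) {             // i < p: use u_i completely
            ops += cap[i];
            need -= cap[i] * gain[i];
        } else {                                   // i == p: just enough at u_p
            ops += (need + gain[i] - 1) / gain[i];
            need = 0;
        }
    }
    return need > 0 ? -1 : ops;
}
```

For example, with capacities (2, 3) and gains (3, 1), a deficit of 7 uses both units of the first vertex (profit 6) and one more at the second, for 3 extra operations.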
Once again, each vertex has three states of contribution:
- For K \leq \text{sub}_u - S_u (state 1), it contributes 0.
- For K \geq \text{sub}_u - S_u + A_u (state 2), it contributes A_u\cdot (L_u - 1).
- For everything in between (state 3), it contributes (A_u - \text{sub}_u + (S_u - A_u) + K) \cdot (L_u - 1) = (K - (\text{sub}_u - S_u)) \cdot (L_u - 1) to the sum.
This is once again of the form r_u + K\cdot s_u where r_u and s_u are constants that depend only on u.
So, we’re still able to process K offline, with at most 2N state switches.
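For each query, we can now binary search for the smallest valid prefix p: build a segment tree over the sorted order that stores, for each range, what is needed to evaluate that range's contribution for a given K (for example, the sum of A_u(L_u - 1) over state-2 vertices, and the sums of L_u - 1 and (L_u - 1)(\text{sub}_u - S_u) over state-3 vertices; state-1 vertices contribute nothing), so that any prefix can be queried quickly.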
Depending on whether the binary search is done separately or built into the segment tree (“segment tree walking”), the complexity is either \mathcal{O}(\log^2 N) or \mathcal{O}(\log N) per query, plus \mathcal{O}((N+Q)\log(N+Q)) overall for the sorting and offline processing.
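One way to get the \mathcal{O}(\log N) variant is to replace the segment tree walk with a Fenwick-tree descent; this is a swapped-in alternative, not the authors' code. Each position (a vertex in non-increasing leaf-count order) stores a pair (r, s) whose contribution for query K is r + K\cdot s: state 1 maps to (0, 0), state 2 to (A_u(L_u - 1), 0), and state 3 to (-(L_u - 1)(\text{sub}_u - S_u), L_u - 1). Every state change during the sweep is a point update. Because all contributions are non-negative for the current K, prefix sums are monotone, so a single descent finds the smallest p whose prefix reaches X - F. LinearFenwick and firstReaching are hypothetical names.

```cpp
#include <cassert>
#include <vector>
using namespace std;
using ll = long long;

// Hypothetical Fenwick tree over positions 1..n. Position i holds (r_i, s_i) with
// contribution r_i + K * s_i for query K. Assumes every contribution is >= 0 for
// the K being queried, so prefix sums are non-decreasing in the position.
struct LinearFenwick {
    int n, LOG = 0;
    vector<ll> r, s;
    explicit LinearFenwick(int n_) : n(n_), r(n_ + 1, 0), s(n_ + 1, 0) {
        while ((1 << (LOG + 1)) <= n) ++LOG;   // for n >= 1: 1 << LOG = largest power of two <= n
    }
    // Point update: add (dr, ds) at position i (used when a vertex changes state).
    void add(int i, ll dr, ll ds) {
        for (; i <= n; i += i & -i) { r[i] += dr; s[i] += ds; }
    }
    // Sum of contributions of positions 1..i for query K.
    ll prefix(int i, ll K) const {
        ll res = 0;
        for (; i > 0; i -= i & -i) res += r[i] + K * s[i];
        return res;
    }
    // Smallest p with prefix(p, K) >= target (n + 1 if none), found by descending
    // the implicit Fenwick structure instead of binary searching over prefix().
    int firstReaching(ll K, ll target) const {
        if (target <= 0) return 0;
        int pos = 0;
        ll cum = 0;
        for (int step = 1 << LOG; step > 0; step >>= 1) {
            int nxt = pos + step;
            if (nxt <= n && cum + r[nxt] + K * s[nxt] < target) {
                pos = nxt;
                cum += r[nxt] + K * s[nxt];
            }
        }
        return pos + 1;
    }
};
```

With target X - F, firstReaching returns the p from the discussion above, and prefix(p - 1, K) gives the profit collected before u_p. A second instance of the same structure, storing (A_u, 0) for state 2 and (-(\text{sub}_u - S_u), 1) for state 3, would give the matching operation counts.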
TIME COMPLEXITY:
\mathcal{O}((N+Q)\log(N+Q) + Q\log^2 N) per testcase.
CODE:
Editorialist's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());
/**
* Point-update Segment Tree
* Source: kactl
* Description: Iterative point-update segment tree, ranges are half-open i.e [L, R).
* f is any associative function.
* Time: O(logn) update/query
*/
template<class T, T unit = T()>
struct SegTree {
T f(T a, T b) {
for (int i = 0; i < 4; ++i)
a[i] += b[i];
return a;
}
vector<T> s; int n;
SegTree(int _n = 0, T def = unit) : s(2*_n, def), n(_n) {}
void update(int pos, T val) {
for (s[pos += n] = val; pos /= 2;)
s[pos] = f(s[pos * 2], s[pos * 2 + 1]);
}
T query(int b, int e) {
T ra = unit, rb = unit;
for (b += n, e += n; b < e; b /= 2, e /= 2) {
if (b % 2) ra = f(ra, s[b++]);
if (e % 2) rb = f(s[--e], rb);
}
return f(ra, rb);
}
};
int main()
{
ios::sync_with_stdio(false); cin.tie(0);
int test = 0;
int t; cin >> t;
while (t--) {
++test;
int n, q; cin >> n >> q;
vector a(n, 0ll);
for (ll &x : a) cin >> x;
vector adj(n, vector<int>());
for (int i = 0; i < n-1; ++i) {
int u, v; cin >> u >> v;
adj[--u].push_back(--v);
adj[v].push_back(u);
}
/**
* Each node will be "active" for some interval of time, say [Li, Ri] for i
* Initial number of ops: sum of values of all "finished" nodes, plus sum(k - Li) across all "active" i
* Can find easily if I know count and sum of L, of active nodes
*
* Similarly for extra values, only sum of R/count of active nodes in each prefix matter, so can build segtree on nodes in sorted order of leaf count
* Process k in descending order, this becomes O(qlogn + nlogn)
*/
vector<ll> sm(n), mx(n);
vector<array<ll, 2>> starts, ends;
vector<int> leaves(n);
ll leafct = 0, leafsum = 0, totsum = 0;
vector<array<ll, 2>> interval(n);
auto dfs = [&] (const auto &self, int u, int p, ll up) -> void {
leaves[u] = 0;
bool leaf = true;
totsum += a[u];
for (int v : adj[u]) if (v != p) {
sm[v] = sm[u] + a[v];
self(self, v, u, up + a[u]);
leaf = false;
leaves[u] += leaves[v];
mx[u] = max(mx[u], mx[v]);
}
mx[u] = max(mx[u], sm[u]);
leaves[u] += leaf;
starts.push_back({mx[u] - up, u});
ends.push_back({mx[u] - up - a[u], u});
interval[u] = {mx[u] - up - a[u], mx[u] - up};
if (leaf) {
++leafct;
leafsum += sm[u];
}
};
sm[0] = a[0];
dfs(dfs, 0, 0, 0);
ranges::sort(starts); ranges::reverse(starts);
ranges::sort(ends); ranges::reverse(ends);
vector ord(n, 0), pos(n, 0);
iota(begin(ord), end(ord), 0);
ranges::sort(ord, [&] (int i, int j) {return leaves[i] > leaves[j];});
for (int i = 0; i < n; ++i)
pos[ord[i]] = i;
vector<array<ll, 3>> queries;
for (int i = 0; i < q; ++i) {
ll k; cin >> k;
queries.push_back({k, i, 0});
}
ranges::sort(queries); ranges::reverse(queries);
SegTree<array<ll, 4>> seg1(n), seg2(n);
for (int i = 0; i < n; ++i)
seg1.update(pos[i], {1ll*a[i]*(leaves[i] - 1), a[i] * (leaves[i] > 1), 0, 0});
int p1 = 0, p2 = 0;
ll activesum = 0, activect = 0, done = 0;
ll val1 = 0, val2 = 0, done2 = 0;
for (auto &[k, id, ans] : queries) {
if (k*leafct < totsum) {
ans = -1;
continue;
}
while (p1 < n) {
auto [L, i] = starts[p1];
if (L >= k) {
// Activate
// cerr << "Activate " << i << '\n';
activesum += L;
++activect;
seg1.update(pos[i], {0, 0, 0, 0});
if (leaves[i] > 1) seg2.update(pos[i], {1ll*(leaves[i]-1)*(L - a[i]), leaves[i]-1, L - a[i], 1});
val1 += 1ll*L*leaves[i];
val2 += leaves[i];
}
else break;
++p1;
}
while (p2 < n) {
auto [R, i] = ends[p2];
if (R >= k) {
// Deactivate
// cerr << "Deactivate " << i << '\n';
activesum -= R + a[i];
--activect;
done += a[i];
seg2.update(pos[i], {0, 0, 0, 0});
done2 += 1ll*a[i]*leaves[i];
val2 -= leaves[i];
val1 -= 1ll*(R+a[i])*leaves[i];
}
else break;
++p2;
}
ll cur = done + activesum - k*activect;
// How much free space do I have?
// k*leafcount - sum(leaf values)
// leaf values change by a[i]*leaf[i] when i is operated on
// store sum(a[i]*leaf[i]) across all completed operations
// and also sum(leaf[i] * (L[i] - k)) across active ones
ll have = k*leafct - (leafsum - done2 - (val1 - k*val2));
if (cur <= have) { // No extra moves needed
ans = cur;
continue;
}
ll req = cur - have;
int lo = -1, hi = n-1;
while (lo < hi) {
int mid = (lo + hi + 1) / 2;
auto res1 = seg1.query(0, mid+1), res2 = seg2.query(0, mid+1);
ll usage = res1[0] + k*res2[1] - res2[0];
if (usage <= req) lo = mid;
else hi = mid - 1;
}
// Everything that's <= lo, and then maybe something from lo+1
auto res1 = seg1.query(0, lo+1), res2 = seg2.query(0, lo+1);
ll usage = res1[0] + k*res2[1] - res2[0];
ll ops = res1[1] + k*res2[3] - res2[2];
req -= usage;
if (req > 0) {
int i = ord[lo+1];
ops += (req + leaves[i] - 2) / (leaves[i] - 1);
}
ans = cur + ops;
}
ranges::sort(queries, [&] (auto A, auto B) {return A[1] < B[1];});
for (auto [k, id, ans] : queries) cout << ans << '\n';
}
}
Tester's code (C++)
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;
template<typename T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
typedef long long int ll;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
#define fastio ios_base::sync_with_stdio(false); cin.tie(NULL)
#define pb push_back
#define endl '\n'
#define sz(a) (int)a.size()
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
#define conts continue
#define ceil2(x,y) ((x+y-1)/(y))
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
#define yes cout << "Yes" << endl
#define no cout << "No" << endl
#define rep(i,n) for(int i = 0; i < n; ++i)
#define rep1(i,n) for(int i = 1; i <= n; ++i)
#define rev(i,s,e) for(int i = s; i >= e; --i)
#define trav(i,a) for(auto &i : a)
template<typename T>
void amin(T &a, T b) {
a = min(a,b);
}
template<typename T>
void amax(T &a, T b) {
a = max(a,b);
}
#ifdef LOCAL
#include "debug.h"
#else
#define debug(...) 42
#endif
/*
*/
const int MOD = 1e9 + 7;
const int N = 2e5 + 5;
const int inf1 = int(1e9) + 5;
const ll inf2 = ll(1e18) + 5;
template<typename T>
struct fenwick {
int n;
vector<T> tr;
int LOG = 0;
fenwick() {
}
fenwick(int n_) {
n = n_;
tr = vector<T>(n + 1);
while((1<<LOG) <= n) LOG++;
}
int lsb(int x) {
return x & -x;
}
void pupd(int i, T v) {
for(; i <= n; i += lsb(i)){
tr[i] += v;
}
}
T sum(int i) {
T res = 0;
for(; i; i ^= lsb(i)){
res += tr[i];
}
return res;
}
T query(int l, int r) {
if (l > r) return 0;
T res = sum(r) - sum(l - 1);
return res;
}
int lower_bound(T s){
// first pos with sum >= s
if(sum(n) < s) return n+1;
int i = 0;
rev(bit,LOG-1,0){
int j = i+(1<<bit);
if(j > n) conts;
if(tr[j] < s){
s -= tr[j];
i = j;
}
}
return i+1;
}
int upper_bound(T s){
return lower_bound(s+1);
}
};
vector<ll> adj[N];
vector<ll> a(N), leaves(N);
vector<ll> dp(N);
vector<array<ll,3>> node_ranges;
void dfs1(ll u, ll p){
ll c = 0;
ll mx = 0;
leaves[u] = 0;
trav(v,adj[u]){
if(v == p) conts;
dfs1(v,u);
c++;
amax(mx,dp[v]);
leaves[u] += leaves[v];
}
leaves[u] += !c;
node_ranges.pb({mx,mx+a[u],u});
dp[u] = mx+a[u];
}
void solve(int test_case){
ll n,q; cin >> n >> q;
rep1(i,n) cin >> a[i];
rep1(i,n){
adj[i].clear();
}
rep1(i,n-1){
ll u,v; cin >> u >> v;
adj[u].pb(v), adj[v].pb(u);
}
node_ranges.clear();
dfs1(1,-1);
ll sum = accumulate(a.begin()+1,a.begin()+n+1,0ll);
ll mnv = ceil2(sum,leaves[1]);
map<ll,vector<pll>> mp;
for(auto [l,r,u] : node_ranges){
mp[l].pb({1,u});
mp[r+1].pb({2,u});
}
rep1(i,q){
ll k; cin >> k;
mp[k].pb({3,i});
}
vector<pll> range_u(n+5);
for(auto [l,r,u] : node_ranges){
range_u[u] = {l,r};
}
vector<ll> ans(q+5,-1);
ll untouched = 0;
rep1(i,n) untouched += a[i];
ll active_sumr = 0, active_countr = 0;
ll active_leaf_sum = 0, active_leaf_contrib = 0;
ll done_leaf_sum = 0;
fenwick<ll> fenw_active_cnt(n+5), fenw_active_sum2(n+5), fenw_active_sum(n+5), fenw_active_sum3(n+5);
fenwick<ll> fenw_done_cnt(n+5), fenw_done_sum(n+5);
for(auto [k,events] : mp){
for(auto [t,u] : events){
if(t == 1){
untouched -= a[u];
auto [l,r] = range_u[u];
active_sumr += r;
active_countr++;
active_leaf_sum += leaves[u];
active_leaf_contrib += l*leaves[u];
fenw_active_cnt.pupd(leaves[u],1);
fenw_active_sum.pupd(leaves[u],l*leaves[u]);
fenw_active_sum2.pupd(leaves[u],l);
fenw_active_sum3.pupd(leaves[u],leaves[u]);
}
else if(t == 2){
auto [l,r] = range_u[u];
active_sumr -= r;
active_countr--;
done_leaf_sum += a[u]*leaves[u];
active_leaf_sum -= leaves[u];
active_leaf_contrib -= l*leaves[u];
fenw_active_cnt.pupd(leaves[u],-1);
fenw_active_sum.pupd(leaves[u],-l*leaves[u]);
fenw_active_sum2.pupd(leaves[u],-l);
fenw_active_sum3.pupd(leaves[u],-leaves[u]);
fenw_done_cnt.pupd(leaves[u],a[u]);
fenw_done_sum.pupd(leaves[u],a[u]*leaves[u]);
}
else{
ll id = u;
ll removed = untouched+active_sumr-k*active_countr;
ll avail = done_leaf_sum+k*active_leaf_sum-active_leaf_contrib;
avail = leaves[1]*k-avail;
// debug(removed,avail);
if(avail >= removed){
ans[id] = removed;
conts;
}
if(k < mnv){
conts;
}
// avail < removed
// find min pos s.t avail < removed if suff is considered
ll lo = 1, hi = n;
ll mnp = -1;
ll fa = 0, fr = 0;
while(lo <= hi){
ll mid = (lo+hi)>>1;
ll curr_avail = avail, curr_removed = removed;
ll cnt_active = fenw_active_cnt.query(mid,n);
ll sum_active = fenw_active_sum.query(mid,n);
ll cnt_done = fenw_done_cnt.query(mid,n);
ll sum_done = fenw_done_sum.query(mid,n);
ll sum2 = fenw_active_sum2.query(mid,n);
ll sum3 = fenw_active_sum3.query(mid,n);
curr_removed += cnt_done+k*cnt_active-sum2;
curr_avail += sum_done+k*sum3-sum_active;
if(curr_avail < curr_removed){
mnp = mid;
fa = curr_avail, fr = curr_removed;
hi = mid-1;
}
else{
lo = mid+1;
}
}
ll res = fr+ceil2(fr-fa,mnp-2);
ans[id] = res;
}
}
}
rep1(id,q) cout << ans[id] << endl;
}
int main()
{
fastio;
int t = 1;
cin >> t;
rep1(i, t) {
solve(i);
}
return 0;
}