# TREEREQ2 - Editorial

Author: sushil2006
Tester: tabr
Editorialist: iceknight1093

Easy-Medium

# PREREQUISITES:

DFS, binary search/binary lifting

# PROBLEM:

You’re given a tree on N vertices, vertex i has value A_i.
There are M constraints of the form (u, r, k).

Find the minimum possible sum of values of a set S of vertices such that, for every constraint (u, r, k):

• The number of elements of S in the subtree of u, when rooted at r, is exactly k.

# EXPLANATION:

This editorial will continue from the solution to the easy version.

Recall that our solution was as follows: after rooting at 1, each (u, r, k) constraint converts to either “exactly k vertices from some subtree should be chosen”, or “exactly k vertices outside some subtree should be chosen.”
Also recall that for every constrained vertex, there was a certain subset of vertices it was allowed to choose from.

Now, rather than fixing the size X of the chosen set, let’s leave it as a variable.
Consider what happens during our DFS to compute the answer.
When we’re at vertex u, we can have constraints of the form “pick k vertices inside S_u” and/or “pick k vertices outside S_u”, where S_u denotes the subtree of u (with the tree rooted at 1).
If neither constraint is present, we can just ignore u.
If both constraints are present, that uniquely fixes the value of X after which the problem can be solved in linear time using the easy version.

Now, suppose u has the constraint “exactly k vertices from S_u should be chosen.”
Note that while performing the DFS, we’ve already selected some vertices from S_u while processing its descendants.
The key observation here is that the number of already selected vertices will be of the form p_u\cdot X + q_u for some constants p_u and q_u.
This is simply because the number of already chosen vertices is the sum of either constants, or expressions of the form (X-k) for some constant k, across certain descendants of u.

So, if we pick exactly y vertices from the ones available to satisfy u, we want

p_u\cdot X + q_u + y = k

to hold.
Note that if we fix y, everything other than X is a constant, so X is uniquely determined (unless p_u = 0 in which case X can be anything at all.)
Also note that if the constraint was instead “choose k vertices outside S_u”, the right side would be X-k instead, but we can still compute X uniquely upon fixing y (unless the coefficient of X is 0, in which case again any X is valid.)

So, for each valid y, compute the appropriate X (if it exists), and add the value of the selected vertices to \text{ans}_X.
When the coefficient of X is 0 we instead have y uniquely determined, and want to add the value to every \text{ans}_X - so just store the sum of all such things in a separate variable and add it in at the end.

Note that if for any X no valid y exists, there exists no solution for that X.
y must be \geq 0, and is bounded above by the number of vertices available to u (recall that in particular, the vertices available to u are exactly those for which u is the nearest constrained ancestor).
This gives us both lower and upper bounds for X; make sure to compute those.

Once this is done for every u, we’ll have the final values of every one of \text{ans}_1, \text{ans}_2, \ldots, \text{ans}_N and also the lower and upper bounds on X.
The final answer is then the minimum value of \text{ans}_X across all X between the lower and upper bounds.

A couple of final notes:

1. For a vertex u, iterating over all possible values of y is fine since it sums up to N across all vertices - after all, a vertex is available to exactly one of its ancestors.
2. One step of the solution is to find, given u and r such that u is an ancestor of r, the first vertex on the u\to r path.
In the easy version, this was doable with brute force.
Here, we must instead use binary lifting or binary search (on the in-times of children of u) to find the appropriate vertex in \mathcal{O}(\log N) time.

# TIME COMPLEXITY:

\mathcal{O}((N+M)\log N) per testcase.

# CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18

// Rooted-tree helper for 1-indexed vertices.
//
// Usage: init(n), addEdge(u, v) for every edge, then dfs(root). After dfs():
//   d[v]       depth of v (root has depth 0)
//   par[v]     parent of v (0 for the root)
//   tin/tout   DFS entry time and the largest entry time inside v's subtree
//   ch[v]      children of v, in the order the DFS visited them (increasing tin)
//   lift[v][j] the 2^j-th ancestor of v (0 once the walk leaves the tree)
struct Tree {
    // Number of binary-lifting levels; 2^20 exceeds the depths this table is sized for.
    static constexpr int LOG = 20;

    std::vector<std::vector<int>> adj, lift;
    std::vector<int> d, tin, tout, par;
    std::vector<std::vector<int>> ch;
    int n = 0, timer = 0;
    bool initialized = false;
    bool dfsed = false;

    // Sizes every table for vertices 1..nn and wipes any previous contents,
    // so the same object can be re-initialised safely.
    void init(int nn){
        n = nn;
        adj.assign(n + 1, std::vector<int>());
        lift.assign(n + 1, std::vector<int>(LOG, 0));
        ch.assign(n + 1, std::vector<int>());
        d.assign(n + 1, 0);
        tin.assign(n + 1, 0);
        tout.assign(n + 1, 0);
        par.assign(n + 1, 0);
        timer = 0;
        initialized = true;
        dfsed = false;
    }

    // Records the undirected edge u - v. Aborts the program if init() was not called.
    void addEdge(int u, int v){
        if (!initialized){
            std::cout << "STUPID INITIALIZE\n";
            std::exit(0);
        }
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    // Fills the higher levels of the lifting table from level 0.
    void build(){
        for (int j = 1; j < LOG; j++){
            for (int i = 1; i <= n; i++){
                lift[i][j] = lift[lift[i][j - 1]][j - 1];
            }
        }
    }

    // Recursive DFS from u whose parent is par1; fills par, tin, tout, d,
    // lift[.][0] and ch for the subtree of u.
    void dfs(int u, int par1){
        par[u] = par1;
        tin[u] = ++timer;
        for (int v : adj[u]){
            if (v != par1){
                d[v] = d[u] + 1;
                lift[v][0] = u;
                dfs(v, u);

                ch[u].push_back(v);
            }
        }
        tout[u] = timer;
    }

    // Roots the tree at `root` and builds all derived tables.
    // Aborts the program if init() was not called.
    void dfs(int root = 1){
        if (!initialized){
            std::cout << "STUPID INITIALIZE\n";
            std::exit(0);
        }
        d[root] = 0;
        timer = 0;
        dfs(root, 0);
        build();
        dfsed = true;
    }

    // Returns the ancestor of x that is `depth` levels above it
    // (0 if that walks past the root).
    int jump(int x, int depth){
        for (int i = 0; i < LOG; i++) if (depth >> i & 1){
            x = lift[x][i];
        }
        return x;
    }

    // Lowest common ancestor of a and b. Aborts the program if dfs() was not run.
    int lca(int a, int b){
        if (!dfsed){
            std::cout << "STUPID DFS\n";
            std::exit(0);
        }
        if (d[a] < d[b]) std::swap(a, b);
        // Bring the deeper vertex up to the depth of the other one.
        int del = d[a] - d[b];
        for (int i = 0; i < LOG; i++) if (del >> i & 1) a = lift[a][i];

        if (a == b) return a;

        // Lift both while their ancestors differ; they end just below the LCA.
        for (int i = LOG - 1; i >= 0; i--) if (lift[a][i] != lift[b][i]){
            a = lift[a][i];
            b = lift[b][i];
        }
        return lift[a][0];
    }

    // Number of edges on the path between a and b.
    int dist(int a, int b){
        return d[a] + d[b] - 2 * d[lca(a, b)];
    }

    // True when x is an ancestor of y (a vertex counts as its own ancestor).
    bool anc(int x, int y){
        return tin[x] <= tin[y] && tout[x] >= tout[y];
    }
};

void Solve()
{
int n, m; cin >> n >> m;

vector <int> a(n + 1);
vector <bool> b(n + 1, false);

for (int i = 1; i <= n; i++){
cin >> a[i];
}

Tree T;
T.init(n);

for (int i = 1; i < n; i++){
int u, v; cin >> u >> v;

}

T.dfs();

auto getnextindir = [&](int u, int v){
if (u == v){
return u;
}

// cout << "QUERYING " << u << " " << v << "\n";

if (T.tin[u] <= T.tin[v] && T.tout[u] >= T.tout[v]){
//   cout << "CASE 1\n";
int lo = 0, hi = (int)T.ch[u].size() - 1;

while (lo != hi){
int mid = (lo + hi + 1) / 2;

int got = T.ch[u][mid];

// cout << mid << " " << got << " " << T.tin[v] << " " << T.tin[got] << "\n";
if (T.tin[v] >= T.tin[got]){
lo = mid;
} else {
hi = mid - 1;
}
}

return T.ch[u][lo];
}
return T.par[u];
};

int tot = -1;

vector <int> f1(n + 1, -1), f2(n + 1, -1);
bool bad = false;

for (int i = 1; i <= m; i++){
int u, r, k; cin >> u >> r >> k;

int v = getnextindir(u, r);
if (v == u){
if (tot != k && tot != -1){
}
tot = k;
continue;
}
if (v == T.par[u]){
if (f1[u] != k && f1[u] != -1){
}
f1[u] = k;
b[u] = true;
} else {
if (f2[v] != k && f2[v] != -1){
}
f2[v] = k;
b[v] = true;
}
}

cout << "IMPOSSIBLE\n";
return;
}

for (int i = 1; i <= n; i++){
if (f1[i] != -1 && f2[i] != -1){
if (tot != f1[i] + f2[i] && tot != -1){
cout << "IMPOSSIBLE\n";
return;
}
tot = f1[i] + f2[i];
}
}

vector<vector<int>> groups(n + 1);
vector<pair<int, int>> coeff_groups(n + 1), coeff_sub(n + 1);
vector<int> spare;

if (tot != -1){
for (int i = 1; i <= n; i++){
if (f1[i] == -1 && f2[i] != -1){
f1[i] = tot - f2[i];
} else if (f1[i] != -1 && f2[i] != -1){
if (tot != f1[i] + f2[i]){
cout << "IMPOSSIBLE\n";
return;
}
}
}
}

// for (int i = 1; i <= n; i++){
//     cout << f1[i] << " " << f2[i] << "\n";
// }

auto dfs = [&](auto self, int u, int par, int last) -> void{
if (b[u]){
last = u;
}
if (last != -1){
groups[last].push_back(a[u]);
} else {
spare.push_back(a[u]);
}

for (int v : T.adj[u]){
if (v != par){
self(self, v, u, last);
coeff_sub[u].first += coeff_sub[v].first;
coeff_sub[u].second += coeff_sub[v].second;
}
}

if (b[u]){
pair <int, int> pi = {0, 0};
if (f1[u] != -1){
pi.first = f1[u];
} else if (f2[u] != -1) {
pi.first = -f2[u];
pi.second = 1;
}

coeff_groups[u].first = pi.first - coeff_sub[u].first;
coeff_groups[u].second = pi.second - coeff_sub[u].second;

coeff_sub[u] = pi;
}
};

dfs(dfs, 1, -1, -1);

vector <int> picked_by_tot(n + 1, 0), sum_by_tot(n + 1, 0), poss(n + 1, 0);
for (int i = 1; i <= n; i++){
sort(groups[i].begin(), groups[i].end());
}
sort(spare.begin(), spare.end());

int forced = 0, forced_sum = 0;
int need_to_good = 0;

// for (int i = 1; i <= n; i++){
//     cout << f1[i] << " " << f2[i] << "\n";
// }

for (int i = 1; i <= n; i++){
if (groups[i].size() == 0) continue;
if (coeff_groups[i].second == 0){
int req = coeff_groups[i].first;
if (req > groups[i].size()){
cout << "IMPOSSIBLE\n";
return;
}
for (int j = 0; j < req; j++){
forced++;
forced_sum += groups[i][j];
}
} else {
int sz = groups[i].size();
int sum_till_now = 0;
int x = coeff_groups[i].first;
int y = coeff_groups[i].second;

need_to_good++;

// cout << i << "\n";

for (int j = 0; j <= sz; j++){
// solve equation
// x + y * T = j
// T = (j - x) / y

if ((j - x) % y != 0){
if (j < sz){
sum_till_now += groups[i][j];
}
continue;
}
int tt = (j - x) / y;
if (tt < 1 || tt > n){
if (j < sz){
sum_till_now += groups[i][j];
}
continue;
}
poss[tt]++;
picked_by_tot[tt] += j;
sum_by_tot[tt] += sum_till_now;

// cout << tt << " " << j << "\n";

if (j < sz){
sum_till_now += groups[i][j];
}
}

// cout << "\n";
}
}

int ans = INF;
vector <int> ps(n + 1, 0);
for (int i = 0; i < (int)spare.size(); i++){
ps[i + 1] = ps[i] + spare[i];
}

// for (int i = 1; i <= n; i++){
//     cout << f1[i] << " " << f2[i] << "\n";
// }

// cout << forced << " " << forced_sum << "\n";

for (int i = 1; i <= n; i++){
if (poss[i] == need_to_good){
if (tot != -1 && i != tot){
continue;
}
// cout << i << "\n";
picked_by_tot[i] += forced;
sum_by_tot[i] += forced_sum;

int left = i - picked_by_tot[i];
// cout << left << "\n";

if (left < 0 || left > (int)spare.size()){
continue;
}

//  cout << i << "\n";

ans = min(ans, sum_by_tot[i] + ps[left]);
}
}

if (ans == INF){
cout << "IMPOSSIBLE\n";
return;
}

cout << ans << "\n";
}

// Entry point of the author's solution: reads the number of test cases,
// runs Solve() once per case and reports the wall-clock time on stderr.
int32_t main()
{
    const auto started = std::chrono::high_resolution_clock::now();

    ios_base::sync_with_stdio(0);
    cin.tie(0);

    int t = 1;
    cin >> t;

    for (int remaining = t; remaining > 0; --remaining)
    {
        Solve();
    }

    const auto finished = std::chrono::high_resolution_clock::now();
    const auto spent = std::chrono::duration_cast<std::chrono::nanoseconds>(finished - started);
    cerr << "Time measured: " << spent.count() * 1e-9 << " seconds.\n";
    return 0;
}

Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;

// Tester's solution for one test case. The case is parsed from `cin` (a string
// stream that shadows std::cin) and the answer is written to std::cout.
//
// Each constraint (u, r, k) with u != r is stored as cut[{u, w}] = k, where w
// is the neighbour of u on the path towards r. Cut edges split the tree into
// groups; the number of chosen vertices in a group is s_size * x + y for the
// unknown total s_size. For every candidate total k, e[k] counts the groups
// that can realise it and c[k] accumulates the cheapest cost of doing so.
void solve(istringstream cin) {
    int n, m;
    cin >> n >> m;
    vector<long long> a(n);
    for (auto& value : a) {
        cin >> value;
    }
    vector<vector<int>> g(n);
    for (int edge = 0; edge + 1 < n; edge++) {
        int x, y;
        cin >> x >> y;
        --x;
        --y;
        g[x].emplace_back(y);
        g[y].emplace_back(x);
    }

    // Depths and binary-lifting table for the tree rooted at vertex 0;
    // index n acts as the "no ancestor" sentinel.
    vector<int> depth(n);
    vector<vector<int>> go_up(20, vector<int>(n + 1, n));
    function<void(int, int)> dfs = [&](int v, int p) {
        for (int to : g[v]) {
            if (to != p) {
                depth[to] = depth[v] + 1;
                go_up[0][to] = v;
                dfs(to, v);
            }
        }
    };
    dfs(0, -1);
    for (int level = 1; level < 20; level++) {
        for (int v = 0; v < n; v++) {
            go_up[level][v] = go_up[level - 1][go_up[level - 1][v]];
        }
    }

    int s_size = -1;
    map<pair<int, int>, int> cut;
    for (int query = 0; query < m; query++) {
        int u, r, k;
        cin >> u >> r >> k;
        --u;
        --r;
        if (u == r) {
            s_size = k;
            continue;
        }

        // Replace r by the neighbour of u that lies on the path u -> r.
        if (depth[r] > depth[u]) {
            const int diff = depth[r] - depth[u] - 1;
            for (int bit = 0; bit < 20; bit++) {
                if ((diff >> bit) & 1) {
                    r = go_up[bit][r];
                }
            }
            // r was not below u after all: the path leaves through u's parent.
            if (go_up[0][r] != u) {
                r = go_up[0][u];
            }
        } else {
            r = go_up[0][u];
        }

        const pair<int, int> forward(u, r);
        const pair<int, int> backward(r, u);
        assert((cut.count(forward) == 0 || cut[forward] == k));
        cut[forward] = k;
        // Both sides of one edge are constrained: the total is their sum.
        const auto opposite = cut.find(backward);
        if (opposite != cut.end()) {
            s_size = opposite->second + k;
        }
    }

    if (cut.empty()) {
        // Only the total is fixed: take the s_size smallest values.
        sort(a.begin(), a.end());
        cout << accumulate(a.begin(), a.begin() + s_size, 0LL) << '\n';
        return;
    }

    vector<long long> c(n + 1);
    long long d = 0;         // cost of groups whose size does not depend on s_size
    vector<int> e(n + 2);    // e[n + 1] holds the number of conditions to satisfy
    vector<bool> checked(n);
    for (int start = 0; start < n; start++) {
        if (checked[start]) {
            continue;
        }
        // the size of this group -> s_size * x + y
        int x = 1, y = 0;
        vector<long long> b;
        function<void(int, int)> search_group = [&](int v, int p) {
            b.emplace_back(a[v]);
            checked[v] = true;
            for (int to : g[v]) {
                if (to == p) {
                    continue;
                }
                const auto far_side = cut.find(make_pair(to, v));
                if (far_side != cut.end()) {
                    y -= far_side->second;
                } else {
                    const auto near_side = cut.find(make_pair(v, to));
                    if (near_side != cut.end()) {
                        x -= 1;
                        y += near_side->second;
                    } else {
                        search_group(to, v);
                    }
                }
            }
        };
        search_group(start, -1);
        sort(b.begin(), b.end());

        if (x == 0) {
            d += accumulate(b.begin(), b.begin() + y, 0LL);
            continue;
        }

        e[n + 1]++;
        const int group_size = (int) b.size();
        long long prefix = 0;   // cost of the cnt smallest values of this group
        for (int cnt = 0; cnt <= group_size; cnt++) {
            // s_size * x + y == cnt
            if ((cnt - y) % x == 0) {
                const int k = (cnt - y) / x;
                if (0 <= k && k <= n) {
                    e[k]++;
                    c[k] += prefix;
                }
            }
            if (cnt < group_size) {
                prefix += b[cnt];
            }
        }
    }

    // A known total behaves like one more condition that only s_size meets.
    if (s_size != -1) {
        e[s_size]++;
        e[n + 1]++;
    }

    long long ans = 1e18;
    for (int total = 0; total <= n; total++) {
        if (e[total] == e[n + 1]) {
            ans = min(ans, c[total] + d);
        }
    }
    assert(ans < 1e18);
    cout << ans << '\n';
}

////////////////////////////////////////

#define IGNORE_CR

// Strict validator for whitespace-sensitive input.
//
// The constructor slurps all of std::cin into `buffer`; the read* methods then
// consume it token by token and assert on any deviation from the expected
// format (wrong separator, out-of-range value, trailing data, ...).
struct input_checker {
    std::string buffer;   // the whole input
    int pos;              // index of the next unread character in `buffer`

    // Character classes usable as the `pattern` argument of readString().
    const std::string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const std::string number = "0123456789";
    const std::string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const std::string lower = "abcdefghijklmnopqrstuvwxyz";

    // Reads std::cin until end of input; '\r' is dropped when IGNORE_CR is defined.
    input_checker() {
        pos = 0;
        while (true) {
            int c = std::cin.get();
            if (c == -1) {
                break;
            }
#ifdef IGNORE_CR
            if (c == '\r') {
                continue;
            }
#endif
            buffer.push_back((char) c);
        }
    }

    // Returns the token starting at `pos`, stopping before the next ' ' or
    // '\n'. Asserts that input remains and that the token holds no other
    // whitespace. The separator itself is not consumed.
    std::string readOne() {
        assert(pos < (int) buffer.size());
        std::string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            assert(!std::isspace(static_cast<unsigned char>(buffer[pos])));
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    // Reads a token whose length is in [min_len, max_len]; when `pattern` is
    // non-empty every character of the token must occur in it.
    std::string readString(int min_len, int max_len, const std::string& pattern = "") {
        assert(min_len <= max_len);
        std::string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != std::string::npos);
        }
        return res;
    }

    // Reads an int and asserts that it lies in [min_val, max_val].
    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = std::stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    // Reads a long long and asserts that it lies in [min_val, max_val].
    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = std::stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    // Reads `size` ints separated by single spaces; the separator after the
    // last value is left for the caller.
    std::vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        std::vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    // Same as readInts() for long long values.
    std::vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        std::vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    // Consumes exactly one space.
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    // Consumes exactly one newline.
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    // Asserts that the whole input has been consumed.
    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

// Disjoint-set union over the elements 0..n-1 with path compression.
struct dsu {
    int n;
    vector<int> p;    // parent links; a representative is its own parent
    vector<int> sz;   // sz[r] is the size of the set whose representative is r

    // Starts with n singleton sets.
    dsu(int _n) : n(_n) {
        p = vector<int>(n);
        iota(p.begin(), p.end(), 0);
        sz = vector<int>(n, 1);
    }

    // Representative of x's set; every vertex on the walked path is
    // re-attached directly to that representative.
    inline int get(int x) {
        int rep = x;
        while (p[rep] != rep) {
            rep = p[rep];
        }
        while (p[x] != rep) {
            const int above = p[x];
            p[x] = rep;
            x = above;
        }
        return rep;
    }

    // Merges the sets of x and y (x's representative is hung under y's).
    // Returns false when they already were in the same set.
    inline bool unite(int x, int y) {
        const int from = get(x);
        const int into = get(y);
        if (from == into) {
            return false;
        }
        p[from] = into;
        sz[into] += sz[from];
        return true;
    }

    // Whether x and y currently belong to the same set.
    inline bool same(int x, int y) {
        const int rx = get(x);
        const int ry = get(y);
        return rx == ry;
    }

    // Number of elements in x's set.
    inline int size(int x) {
        const int rep = get(x);
        return sz[rep];
    }

    // Whether x is the representative of its set.
    inline bool root(int x) {
        return get(x) == x;
    }
};

// Validating driver of the tester's solution.
//
// Checks the exact input format: value ranges, single-space separators, one
// newline after each line, the edge list forming a tree, and no trailing data.
// Each test case is then re-serialised and handed to solve(). Finally the sums
// of n and of m over all test cases are checked against 2e5.
int main() {
    input_checker in;
    int tt = in.readInt(1, 1e5);
    in.readEoln();
    // 64-bit accumulators: up to 1e5 cases of up to 2e5 each would overflow
    // int before the limit asserts below get a chance to fire.
    long long sn = 0, sm = 0;
    while (tt--) {
        int n = in.readInt(3, 2e5);
        in.readSpace();
        int m = in.readInt(1, 2e5);
        in.readEoln();
        sn += n;
        sm += m;
        auto a = in.readInts(n, -1e9, 1e9);
        in.readEoln();
        vector<int> u1(n - 1), v1(n - 1);
        for (int i = 0; i < n - 1; i++) {
            u1[i] = in.readInt(1, n);
            in.readSpace();
            v1[i] = in.readInt(1, n);
            in.readEoln();
        }
        // n - 1 edges that never close a cycle form a tree.
        dsu uf(n);
        for (int i = 0; i < n - 1; i++) {
            // The merge is kept outside assert() so it still runs under NDEBUG.
            const bool merged = uf.unite(u1[i] - 1, v1[i] - 1);
            assert(merged);
            (void) merged;
        }
        vector<int> u2(m), r2(m), k2(m);
        for (int i = 0; i < m; i++) {
            u2[i] = in.readInt(1, n);
            in.readSpace();
            r2[i] = in.readInt(1, n);
            in.readSpace();
            k2[i] = in.readInt(1, n);
            in.readEoln();
        }
        // Re-serialise the validated test case for solve().
        ostringstream sout;
        sout << n << " " << m << '\n';
        for (int i = 0; i < n; i++) {
            sout << a[i] << " \n"[i == n - 1];
        }
        for (int i = 0; i < n - 1; i++) {
            sout << u1[i] << " " << v1[i] << '\n';
        }
        for (int i = 0; i < m; i++) {
            sout << u2[i] << " " << r2[i] << " " << k2[i] << '\n';
        }
        solve(istringstream(sout.str()));
    }
    in.readEof();
    cerr << sn << " " << sm << '\n';
    assert(sn <= 2e5);
    assert(sm <= 2e5);
    return 0;
}

Editorialist's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
ios::sync_with_stdio(false); cin.tie(0);

int t; cin >> t;
while (t--) {
int n, m; cin >> n >> m;
vector<int> a(n);
for (int &x : a) cin >> x;

for (int i = 0; i < n-1; ++i) {
int u, v; cin >> u >> v;
}

vector<int> in(n), out(n), ord;
int timer = 0;
auto dfs = [&] (const auto &self, int u, int p) -> void {
in[u] = timer++;
ord.push_back(u);
for (int v : adj[u]) if (v != p)
self(self, v, u);
out[u] = timer;

sort(begin(adj[u]), end(adj[u]), [&] (int x, int y) {
return in[x] < in[y];
});
};
dfs(dfs, 0, 0);
reverse(begin(ord), end(ord));

vector<int> inside(n, -1), outside(n, -1);
for (int i = 0; i < m; ++i) {
int root, u, k; cin >> u >> root >> k;
--root, --u;

if (u == root) {
inside[0] = k;
}
else {
if (in[u] <= in[root] and out[u] >= out[root]) {
int lo = (u != 0), hi = size(adj[u]) - 1;
while (lo < hi) {
int mid = (lo + hi + 1)/2;
int v = adj[u][mid];
if (in[v] > in[root]) hi = mid-1;
else lo = mid;
}
int v = adj[u][lo];
outside[v] = k;
}
else inside[u] = k;
}
}

vector val(n, vector<int>());
auto populate = [&] (const auto &self, int u, int p, int who) -> void {
if (inside[u] != -1 or outside[u] != -1) who = u;
val[who].push_back(a[u]);
for (int v : adj[u]) if (v != p)
self(self, v, u, who);
sort(begin(val[u]), end(val[u]));
};
populate(populate, 0, 0, 0);

{
ll cur = 0;
vector<int> req(n), used(n);
auto solve = [&] (const auto &self, int u, int p) -> void {
for (int v : adj[u]) if (v != p) {
self(self, v, u);
used[u] += used[v];
}

if (req[u] != -1) {
int take = req[u] - used[u];
if (take < 0 or take > val[u].size()) cur = 1e18;
else {
for (int i = 0; i < take; ++i) cur += val[u][i];
}
used[u] += take;
}
};
bool done = false;
for (int i = 0; i < n; ++i) {
if ((inside[i] != -1 and outside[i] != -1) or (i == 0 and inside[i] != -1)) {
int k = inside[i] + outside[i]*(i > 0);

req.assign(n, -1);
for (int u : ord) {
if (inside[u] != -1) req[u] = inside[u];
if (outside[u] != -1) {
req[u] = k - outside[u];
}
}
req[0] = k;

cur = 0;
used.assign(n, 0);
solve(solve, 0, 0);

done = true;
break;
}
}

if (done) {
cout << cur << '\n';
continue;
}
}

vector<ll> ans(n+1, 0);
ll lo = 1, hi = n;
vector<array<ll, 2>> coef(n);
ll always = 0;
outside[0] = 0;
for (int u : ord) {
if (inside[u] == -1 and outside[u] == -1) continue;

auto &[x, y] = coef[u];
if (inside[u] != -1) {
// xk + y are taken
// take z from here -> xk + y + z = inside[u]
// z = inside[u] - y - xk
x *= -1;
y = inside[u] - y;
}
else {
// xk + y are taken
// take z from here -> xk + y + z = k - outside[u]
// z = k*(1-x) - outside[u] - y
x = 1-x;
y = -y - outside[u];
}

int s = size(val[u]);
if (x > 0) {
// xk+y >= 0 -> k >= ceil(-y/x)
lo = max(lo, (-y+x-1) / x);
// xk+y <= size(val[u]) -> k <= floor((S-y)/x)
hi = min(hi, (s - y)/x);
}
if (x < 0) {
// xk+y >= 0 -> k <= floor(-y/x)
hi = min(hi, (-y)/x);
// xk+y <= S -> xk <= S-y -> k >= ceil((S-y)/x)))
lo = max(lo, (y-s + abs(x)-1)/-x);
}

ll sm = 0;
for (int i = 1; i <= s; ++i) {
// xk + y = i -> k = (i-y)/x
sm += val[u][i-1];
ll k = i-y;
if (x == 0) {
if (i == y) always += sm;
continue;
}
if (k%x) continue;
k /= x;
if (k < 1 or k > n) continue;
ans[k] += sm;
}

if (inside[u] == -1) {