PROBLEM LINK: LINK
Author: Nishit Sharma
Tester: Abhishek Jugdar
Editorialist: Nishit Sharma
DIFFICULTY:
Medium.
PREREQUISITES:
DFS, Euler tour, Segment tree
PROBLEM:
Given a tree with N nodes rooted at 1, where the i-th node has the value A[i]. You are also given a set S of special edges, each of which is an edge of the tree. You have to answer Q queries of 3 different types:
- 1 u K - Divide the subtree of u (denoted by T_u) into K disconnected components by removing K-1 non-special edges such that the value $$ \sum_{v \in T_u} Z_v\cdot A_v$$ is the maximum possible, where Z_v denotes 1 plus the number of removed non-special edges on the shortest path from u to v.
- 2 u v - Add edge u-v to set S.
- 3 u v - Remove edge u-v from the set S
OBSERVATION 1:
Tap to view
It is always optimal to break the edge between the node with the highest subtree sum (among the nodes strictly inside the subtree of u) and its parent.
EXPLANATION:
Tap to view
Let us first solve the problem without considering the set S; we will modify the solution later to account for S.
Let’s denote the parent of any node X as par(X) and the sum of values of nodes in the subtree of X as subtreeSum(X).
Initially consider K = 2, and let R be the queried node. Pick any node R' (other than R) in the subtree of R and break the edge between par(R') and R'. The nodes in the subtree of R are divided into two sets, \{nodes(R) - nodes(R')\} and \{nodes(R')\}: every node in the first set has a Z value equal to 1 and every node in the second set has a Z value equal to 2.
We then compute the value of F as $$F = 1\cdot(subtreeSum(R) - subtreeSum(R')) + 2\cdot subtreeSum(R'),$$
which simplifies to subtreeSum(R) + subtreeSum(R'). After the above operation we have 2 separate subtrees: one rooted at node R (excluding the subtree of R') and the other rooted at R'. If K is greater than 2, we simply repeat the operation on either of the two subtrees and obtain a similar equation.
Overall the equation for any general K will be:
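$$subtreeSum(R) + subtreeSum(R_1) + subtreeSum(R_2) + \dots + subtreeSum(R_{K-1})$$
where R_1, R_2, \dots, R_{K-1} are the nodes whose edges with their respective parents have been broken.
Each broken edge adds the subtree sum of its lower endpoint independently of the other broken edges, because every node below a removed edge gains exactly 1 in its Z value for that edge. Therefore any K-1 distinct nodes other than R form a valid choice, and the problem reduces to choosing the K-1 nodes in the subtree of R (excluding R itself) with the greatest subtree sums.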
Implementation:
We do an Euler tour of the tree and store the subtree sum of each node at that node's position in the tour; denote this array by E. We then build a max segment tree over E. For each query of type 1, we query the part of the tour that covers the subtree of u (excluding u itself) K-1 times for the maximum subtree sum. If the maximum occurs at index j, we add it to the answer and update the value at index j to -\infty so that it cannot be picked again. After answering the query, every index that was set to -\infty during it is restored to its original subtree sum.
If we are unable to get K-1 values from the subtree that are not equal to -\infty, the answer is IMPOSSIBLE.
Otherwise, let the values returned by the K-1 queries be P = \{V_1, V_2, \dots, V_{K-1}\}.
Then the answer to the query is subtreeSum(u) + V_1 + V_2 + \dots + V_{K-1}.
To take the set S into account: a special edge connects some node to its child v and must never be broken. To enforce this, we set the value at v's position in E to -\infty. If the edge is later removed from S, we restore the value at that position to subtreeSum(v).
Time Complexity: O(N + Q K \log N)
Space Complexity: O(N)
CODE:
Setter's Solution(C++)
// Setter's competitive-programming template: one catch-all header plus
// shorthand macros. The macros are plain token substitutions, so their exact
// spelling matters to every line that uses them below.
#include<bits/stdc++.h>
// 64-bit integer type used throughout.
#define ll long long int
// fab(a, b, i): loop i over the half-open range [a, b).
#define fab(a,b,i) for(int i=a;i<b;i++)
#define pb push_back
#define db double
#define mp make_pair
// NOTE: `endl` is redefined to a plain newline, so it does not flush the stream.
#define endl "\n"
// Used as `.f` / `.se` to access pair members.
#define f first
#define se second
#define all(x) x.begin(),x.end()
#define vll vector<ll>
// Because `int` is itself redefined later in the file, `vi` / `pii` expand to
// 64-bit element types wherever they are used after that point.
#define vi vector<int>
#define pii pair<int,int>
#define pll pair<ll,ll>
// Fast I/O: unsync C++ streams from C stdio and untie cin/cout.
#define quick ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL)
using namespace std;
// Modular-arithmetic helpers (modulus 1'000'000'007) from the author's
// template; the solution below does not call any of them.
const int MOD = 1000000007;

// (x + y) mod MOD, assuming both operands are already in [0, MOD).
ll add(ll x, ll y)
{
    const ll total = x + y;
    if (total >= MOD)
        return total - MOD;
    return total;
}

// x * y, reduced with % only once the product reaches MOD.
ll mul(ll x, ll y)
{
    const ll product = x * y;
    if (product < MOD)
        return product;
    return product % MOD;
}

// (x - y) mod MOD, assuming both operands are already in [0, MOD).
ll sub(ll x, ll y)
{
    const ll diff = x - y;
    return diff < 0 ? diff + MOD : diff;
}

// x^y mod MOD by square-and-multiply; expects a non-negative exponent.
ll power(ll x, ll y)
{
    ll result = 1;
    ll base = x % MOD;
    for (; y; y >>= 1)
    {
        if (y & 1)
            result = mul(result, base);
        base = mul(base, base);
    }
    return result;
}

// Modular inverse via Fermat's little theorem (MOD is prime).
ll mod_inv(ll x) { return power(x, MOD - 2); }

// Least common multiple; divides before multiplying to delay overflow.
ll lcm(ll x, ll y) { return x / __gcd(x, y) * y; }

// From here on every `int` in the program is a 64-bit integer.
#define int ll
class segtree
{
public:
vector<pair<int, int>> seg;
vector<int> a;
int n;
int placeHolder;
segtree(vector<int> &v)
{
n = v.size();
a = v;
placeHolder = -1e18;
seg.resize(2 * n);
}
pair<int, int> merge(pair<int, int> a, pair<int, int> b)
{
return (a.first >= b.first ? a : b);
}
void build()
{
for (int i = 0; i < n; i++)
{
seg[i + n] = {a[i], i};
}
for ( int i = n - 1; i > 0; i--)
{
seg[i] = merge(seg[2 * i] , seg[2 * i + 1]);
}
}
void update( int ind , int val)
{
a[ind] = val;
ind += n;
seg[ind] = {val, ind - n};
for ( ; ind > 1; ind >>= 1)
{
seg[ind >> 1] = merge(seg[ind] , seg[ind ^ 1]);
}
}
pair<int, int> query(int l, int r)
{
l += n;
r += n;
pair<int, int> ans = {placeHolder, -1};
while (l < r)
{
if (l % 2)
{
ans = merge(ans, seg[l]);
l++;
}
if (r % 2)
{
--r;
ans = merge(ans, seg[r]);
}
l >>= 1;
r >>= 1;
}
return ans;
}
};
// Pre-order Euler tour of the subtree of src, visiting children in the order
// they appear in v. It uses an explicit stack instead of recursion, so a
// path-shaped tree (depth ~N) cannot overflow the call stack.
//   euler[t]        <- node visited at time t (t starts at tim; tim advances)
//   indexInEuler[x] <- time at which x was visited
//   subtree[x]      <- number of nodes in the subtree of x
//   subtreeSum[x]   <- sum of a[] over the subtree of x
// The asserts require the output slots to start out zeroed.
void dfs(int src, std::vector<std::vector<int>> &v, std::vector<int> &subtree, std::vector<int> &euler, std::vector<int> &indexInEuler, int &tim, std::vector<int> &a, std::vector<int> &subtreeSum) {
    std::vector<int> order;        // nodes in the order they were visited
    std::vector<int> pending{src}; // nodes still to be visited
    while (!pending.empty()) {
        const int node = pending.back();
        pending.pop_back();
        assert(subtree[node] == 0);
        subtree[node] = 1;
        assert(euler[tim] == 0);
        euler[tim] = node;
        subtreeSum[node] = a[node];
        assert(indexInEuler[node] == 0);
        indexInEuler[node] = tim;
        tim++;
        order.push_back(node);
        // Push children in reverse so the first child is popped first,
        // reproducing the recursive pre-order exactly.
        for (auto it = v[node].rbegin(); it != v[node].rend(); ++it)
            pending.push_back(*it);
    }
    // In reverse pre-order every node comes after all of its descendants, so
    // a child's totals are already final when they are added to its parent.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        for (const int child : v[*it]) {
            subtree[*it] += subtree[child];
            subtreeSum[*it] += subtreeSum[child];
        }
    }
}
int32_t main()
{
    quick;
    int t = 1;
    cin >> t;
    while (t--)
    {
        int n, q;
        cin >> n >> q;
        // v[x] lists the children of x and parent[x] is x's parent
        // (0-indexed, root 0; parent[0] stays 0).
        vector<vector<int>> v(n);
        vector<int> parent(n);
        fab(1, n, i)
        {
            int parentNode;
            cin >> parentNode;
            parentNode--;
            parent[i] = parentNode;
            v[parentNode].push_back(i);
        }
        vector<int> a(n);
        fab(0, n, i)
        {
            cin >> a[i];
        }
        // Initial special edges; they are applied once the segment tree exists.
        int notAllowedSz;
        cin >> notAllowedSz;
        vector<pair<int, int>> cancel;
        set<pair<int, int>> s;
        for (int i = 0; i < notAllowedSz; i++)
        {
            int x, y;
            cin >> x >> y;
            x--, y--;
            cancel.pb({x, y});
        }
        vector<int> subtree(n), euler(n), indexInEuler(n), subtreeSum(n);
        int tim = 0;
        dfs(0, v, subtree, euler, indexInEuler, tim, a, subtreeSum);
        // alter[i] keeps the node at Euler position i. euler[i] is then
        // overwritten with that node's subtree sum and indexes the segment tree.
        vector<int> alter = euler;
        for (int i = 0; i < n; i++)
        {
            assert(subtree[i] > 0);
            euler[i] = subtreeSum[euler[i]];
        }
        segtree seg(euler);
        seg.build();
        // -inf marks a slot that must not be chosen. A maximum below
        // compareVal therefore means no breakable edge is left.
        const int inf = 1e18;
        const int compareVal = -1e17;
        // Type-1 query: greedily take the k-1 largest breakable subtree sums
        // strictly inside the subtree of ind, then restore every masked slot.
        auto breakEdges = [&](int ind, int k) {
            int eulerIndex = indexInEuler[ind];
            int subSize = subtree[ind];
            vector<int> indicesUpdated;
            bool ok = 1;
            int sum = subtreeSum[ind];
            for (int i = 0; i < k - 1; i++) {
                // Half-open range covering the subtree of ind without ind itself.
                int leftIndex = eulerIndex + 1;
                int rightIndex = eulerIndex + subSize;
                auto currMax = seg.query(leftIndex, rightIndex);
                if (currMax.first < compareVal)
                {
                    ok = 0;
                    break;
                }
                int index = currMax.second;
                indicesUpdated.push_back(index);
                sum += currMax.first;
                seg.update(index, -inf);
            }
            for (int &i : indicesUpdated) {
                seg.update(i, subtreeSum[alter[i]]);
            }
            if (!ok)
            {
                cout << "IMPOSSIBLE" << endl;
                return;
            }
            cout << sum << endl;
        };
        // A special edge may be given as (parent, child) or (child, parent).
        // Swap so that y is always the child: the child's Euler slot is the one
        // that represents (and masks) the edge.
        auto orient = [&](int &x, int &y) {
            if (parent[x] == y)
                swap(x, y);
        };
        auto addEdge = [&](int x, int y) {
            orient(x, y);
            s.insert({x, y});
            seg.update(indexInEuler[y], -inf);
        };
        auto removeEdge = [&](int x, int y) {
            orient(x, y);
            // Erasing by key is a no-op for an edge that is not in S, whereas
            // s.erase(s.find(...)) would erase end(), which is undefined behaviour.
            s.erase(make_pair(x, y));
            seg.update(indexInEuler[y], subtreeSum[y]);
        };
        for (auto &i : cancel)
        {
            addEdge(i.f, i.se);
        }
        while (q--) {
            int type;
            cin >> type;
            if (type == 1) {
                int index, k;
                cin >> index >> k;
                assert(index >= 1 and index <= n);
                index--;
                breakEdges(index, k);
            }
            else if (type == 2) {
                int x, y;
                cin >> x >> y;
                assert(x >= 1 and x <= n and y >= 1 and y <= n);
                x--, y--;
                addEdge(x, y);
            } else if (type == 3) {
                int x, y;
                cin >> x >> y;
                assert(x >= 1 and x <= n and y >= 1 and y <= n);
                x--, y--;
                removeEdge(x, y);
            }
        }
    }
    cerr << "time taken : " << (float)clock() / CLOCKS_PER_SEC << " secs" << endl;
    return 0;
}
Tester's Solution
#include <bits/stdc++.h>
using namespace std;
// Sentinel for "unusable" slots and for the identity of the max.
const int64_t INF = 4e15;

// Bottom-up max segment tree over n leaves storing (value, position) pairs.
// std::max on pairs breaks value ties in favour of the larger position.
class segtree {
private:
    std::vector<std::pair<int64_t, int>> value;  // value[n + i] is leaf i; 1..n-1 internal
    int n;
public:
    segtree(int _n) : n(_n) {
        value.assign(2 * n, std::make_pair(-INF, -1));
    }
    // Load leaves from v[0..n-1] and rebuild every internal node.
    void build(std::vector<int64_t>& v) {
        for (int i = 0; i < n; i++) {
            value[n + i] = std::make_pair(v[i], i);
        }
        // Internal nodes are 1..n-1. Node n is leaf 0; treating it as internal
        // would read value[2n] and value[2n + 1] past the end of the vector and
        // overwrite leaf 0 with that garbage.
        for (int i = n - 1; i > 0; i--) {
            value[i] = std::max(value[i << 1], value[i << 1 | 1]);
        }
    }
    // Point assignment: leaf ind := (val, ind), then refresh its ancestors.
    void upd(int ind, int64_t val) {
        int pos = ind + n;
        // Store the leaf index explicitly instead of relying on the C++17
        // right-to-left sequencing of `value[ind += n] = make_pair(val, ind)`.
        value[pos] = std::make_pair(val, ind);
        for (; pos > 1; pos >>= 1) {
            value[pos >> 1] = std::max(value[pos], value[pos ^ 1]);
        }
    }
    // Maximum over the half-open range [l, r); (-INF, -1) when it is empty.
    std::pair<int64_t, int> qry(int l, int r) {
        std::pair<int64_t, int> res = std::make_pair(-INF, -1);
        for (l += n, r += n; l < r; l >>= 1, r >>= 1) {
            if (l & 1) {
                res = std::max(res, value[l]);
                l++;
            }
            if (r & 1) {
                r--;
                res = std::max(res, value[r]);
            }
        }
        return res;
    }
};
// Global state shared by the tester's solution (nodes are 1-indexed).
constexpr int maxN = 100005;
vector<vector<int>> adj(maxN);   // undirected adjacency lists
vector<int> tin(maxN);           // Euler-tour entry timestamp of each node
vector<int> tout(maxN);          // Euler-tour exit timestamp of each node
vector<int> a(maxN);             // node values
vector<int> par(maxN);           // parent of each node
vector<int64_t> sub(maxN);       // subtree sums of a[]
int curr_time = 0;               // next timestamp to hand out

// Prepare for a new test case: reset the DFS clock and empty the adjacency
// lists of nodes 1..n.
void init(int n) {
    curr_time = 0;
    for (int node = 1; node <= n; ++node) adj[node].clear();
}

// Euler tour rooted at x (whose parent is p). It gives every node an entry
// and an exit timestamp from curr_time and fills sub[] with subtree sums.
// Implemented with an explicit stack of frames instead of recursion.
void dfs(int x, int p = 0) {
    struct Frame {
        int node;
        int parent;
        size_t next;  // index of the next neighbour of `node` to examine
    };
    vector<Frame> frames;
    tin[x] = curr_time++;
    sub[x] = a[x];
    frames.push_back({x, p, 0});
    while (!frames.empty()) {
        const int node = frames.back().node;
        const int parent = frames.back().parent;
        size_t &next = frames.back().next;
        if (next < adj[node].size()) {
            const int to = adj[node][next++];
            if (to == parent) continue;
            tin[to] = curr_time++;
            sub[to] = a[to];
            // `next` may dangle after this push; it is not used again this round.
            frames.push_back({to, node, 0});
        } else {
            tout[node] = curr_time++;
            frames.pop_back();
            if (!frames.empty()) sub[frames.back().node] += sub[node];
        }
    }
}
int main() {
ios::sync_with_stdio(false); cin.tie(0);
int t;
cin >> t;
while (t--) {
int n, q;
cin >> n >> q;
init(n);
for (int i = 2; i <= n; i++) {
int a;
cin >> a;
par[i] = a;
adj[i].push_back(a);
adj[a].push_back(i);
}
for (int i = 1; i <= n; i++) {
cin >> a[i];
}
dfs(1);
vector<int64_t> v(2 * n, 0);
for (int i = 1; i <= n; i++) {
v[tin[i]] = v[tout[i]] = sub[i];
}
segtree st(2 * n);
st.build(v);
vector<int> mp(2 * n, 0);
for (int i = 1; i <= n; i++) {
mp[tin[i]] = mp[tout[i]] = i;
}
int m;
cin >> m;
for (int i = 0; i < m; i++) {
int u, v;
cin >> u >> v;
if (par[u] == v) swap(u, v);
st.upd(tin[v], -INF); st.upd(tout[v], -INF);
}
while (q--) {
int type;
cin >> type;
if (type == 1) {
int r, k;
cin >> r >> k;
int64_t sum = sub[r];
vector<pair<int, int64_t>> v;
for (int i = 0; i < k - 1; i++) {
auto [val, ind] = st.qry(tin[r] + 1, tout[r]);
if (val == -INF) {
sum = -INF;
break;
}
sum += val;
ind = mp[ind];
v.emplace_back(ind, val);
st.upd(tin[ind], -INF); st.upd(tout[ind], -INF);
}
if (sum == -INF) cout << "IMPOSSIBLE\n";
else cout << sum << '\n';
for (const auto& [ind, val] : v) {
st.upd(tin[ind], val);
st.upd(tout[ind], val);
}
}
else if (type == 2) {
int u, v;
cin >> u >> v;
if (par[u] == v) swap(u, v);
st.upd(tin[v], -INF); st.upd(tout[v], -INF);
}
else {
int u, v;
cin >> u >> v;
if (par[u] == v) swap(u, v);
st.upd(tin[v], sub[v]); st.upd(tout[v], sub[v]);
}
}
}
}
If anything is unclear, please let me know in the comments; it will help me improve.