Editorial - GGANBU

PROBLEM LINK: LINK

Author: Nishit Sharma

Tester: Abhishek Jugdar

Editorialist: Nishit Sharma

DIFFICULTY:

Medium.

PREREQUISITES:

DFS, Euler tour, Segment tree

PROBLEM:

You are given a tree with N nodes rooted at node 1, where the i-th node has value A[i], and a set S of special edges of the tree. You have to answer Q queries of 3 types:

  • 1 u K - Divide the subtree of u (denoted by T_u) into K disconnected components by removing K-1 non-special edges such that the value $$ \sum_{v \in T_u} Z_v\cdot A_v$$ is maximized, where Z_v denotes 1 plus the number of removed edges on the path from u to v. If this is not possible, print IMPOSSIBLE.
  • 2 u v - Add edge u-v to set S.
  • 3 u v - Remove edge u-v from the set S

OBSERVATION 1:

Tap to view

It is always optimal to break the (non-special) edge between the node with the highest subtree sum strictly inside T_u and its parent.

EXPLANATION:

Tap to view

Let us first solve the problem while ignoring the set S; we will adapt the solution to S afterwards.

Let’s denote the parent of any node X as par(X) and the sum of values of nodes in the subtree of X as subtreeSum(X).

First consider K = 2, and let R be the query node u. Pick any node R' \ne R in the subtree of R and break the edge between par(R') and R'. The nodes in the subtree of R are divided into two sets, \{nodes(R) - nodes(R')\} and \{nodes(R')\}: every node in the first set has Z value 1 and every node in the second set has Z value 2.
We then compute the value F = 1 \cdot (subtreeSum(R) - subtreeSum(R')) + 2 \cdot subtreeSum(R').

This simplifies to subtreeSum(R) + subtreeSum(R'). After this operation we have 2 separate subtrees: one rooted at R (excluding the subtree of R') and the other rooted at R'. If K is greater than 2, we can repeat the operation on either of the two subtrees and obtain a similar equation.

Overall the equation for any general K will be:
F = subtreeSum(R) + subtreeSum(R_1) + subtreeSum(R_2) + \dots + subtreeSum(R_{K-1})

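Here R_1, R_2, \dots, R_{K-1} are the nodes whose edges with their respective parents have been broken. All subtree sums refer to the original tree. Breaking the edge par(R_i)–R_i increases Z_v by exactly 1 for every node v in the original subtree of R_i. Hence each broken edge contributes exactly subtreeSum(R_i), independently of which other edges are broken.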

Hence the problem reduces to choosing the K-1 nodes with the greatest subtree sums from the subtree of R, excluding R itself (R has no parent edge inside T_R).

Implementation:
Do an Euler tour of the tree and store the subtree sum of each node at that node's position in the tour; call this array E. In this ordering every subtree occupies a contiguous range, so we can build a max segment tree on E. For a type 1 query, query the range of the subtree of u (excluding u itself) for the maximum, K-1 times. If a maximum occurs at index j, add it to the answer and temporarily update the value at index j to -\infty so it is not picked again. After answering, restore all temporarily removed values.
If we cannot obtain K-1 values different from -\infty from the subtree, the answer is IMPOSSIBLE. Otherwise, let the values returned by the K-1 queries be P = \{V_1, V_2, \dots, V_{K-1}\}.
Then the answer to the query is subtreeSum(u) + V_1 + V_2 + \dots + V_{K-1}.

To take the set S into account, consider a special edge u-v, where v is the child of u. It must never be broken, so we set v's entry in the segment tree to -\infty. When the edge is removed from S, we restore that entry to the correct subtreeSum(v). The stored subtreeSum(v) itself is never changed.

Time Complexity: O(N + Q \cdot K \log N)

Space Complexity: O(N)

CODE:

Setter's Solution (C++)
#include<bits/stdc++.h>
#define ll long long int
#define fab(a,b,i) for(int i=a;i<b;i++)
#define pb push_back
#define db double
#define mp make_pair
#define endl "\n"
#define f first
#define se second
#define all(x) x.begin(),x.end()
#define vll vector<ll>
#define vi vector<int>
#define pii pair<int,int>
#define pll pair<ll,ll>
#define quick ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL)

using namespace std;

const int MOD = 1e9 + 7;

// Modular helpers. Operands are expected to already lie in [0, MOD).

// (x + y) mod MOD, using a single conditional subtraction.
ll add(ll x, ll y) {ll res = x + y; return (res >= MOD ? res - MOD : res);}
// (x * y) mod MOD; only reduces when the product leaves [0, MOD).
ll mul(ll x, ll y) {ll res = x * y; return (res >= MOD ? res % MOD : res);}
// (x - y) mod MOD, using a single conditional addition.
ll sub(ll x, ll y) {ll res = x - y; return (res < 0 ? res + MOD : res);}
// x^y mod MOD by binary exponentiation. y must be >= 0 (a negative y never terminates).
ll power(ll x, ll y) {ll res = 1; x %= MOD; while (y) {if (y & 1)res = mul(res, x); y >>= 1; x = mul(x, x);} return res;}
// Modular inverse via Fermat's little theorem (MOD is prime); x must not be a multiple of MOD.
ll mod_inv(ll x) {return power(x, MOD - 2);}
// Least common multiple. Returns 0 when either argument is 0 instead of dividing by
// gcd(0, 0) == 0. Uses the standard std::gcd rather than the libstdc++-only __gcd.
ll lcm(ll x, ll y) { if (x == 0 || y == 0) return 0; ll res = x / std::gcd(x, y); return (res * y);}


#define int ll
// Bottom-up max segment tree. Leaves live in seg[n .. 2n-1]; node i combines
// seg[2i] and seg[2i+1]. Every entry is a (value, index) pair, so a query
// reports both the maximum and the array position where it was found.
class segtree
{
public:
    vector<pair<int, int>> seg;  // tree nodes: (value, array index)
    vector<int> a;               // current copy of the underlying array
    int n;                       // number of leaves
    int placeHolder;             // merge identity; what an empty query returns
    segtree(vector<int> &v)
    {
        a = v;
        n = a.size();
        placeHolder = -1e18;
        seg.resize(n + n);
    }

    // Keeps the larger value; when the values are equal the first argument wins.
    pair<int, int> merge(pair<int, int> a, pair<int, int> b)
    {
        if (b.first > a.first)
            return b;
        return a;
    }

    // Fill the leaves from `a`, then compute every internal node bottom-up.
    void build()
    {
        int leaf = 0;
        while (leaf < n)
        {
            seg[n + leaf] = make_pair(a[leaf], leaf);
            ++leaf;
        }
        for (int node = n - 1; node >= 1; --node)
            seg[node] = merge(seg[node << 1], seg[node << 1 | 1]);
    }

    // Point assignment a[ind] = val, followed by a walk up to the root.
    void update( int ind , int val)
    {
        a[ind] = val;
        int pos = ind + n;
        seg[pos] = make_pair(val, ind);
        while (pos > 1)
        {
            // The node being refreshed is passed first, so it wins ties.
            seg[pos >> 1] = merge(seg[pos], seg[pos ^ 1]);
            pos >>= 1;
        }
    }

    // Maximum (value, index) over the half-open range [l, r).
    pair<int, int> query(int l, int r)
    {
        pair<int, int> ans = make_pair(placeHolder, -1);
        for (l += n, r += n; l < r; l >>= 1, r >>= 1)
        {
            if (l & 1)
                ans = merge(ans, seg[l++]);
            if (r & 1)
                ans = merge(ans, seg[--r]);
        }
        return ans;
    }

};


// Pre-order traversal of the subtree rooted at `src`, where v[x] lists the children of x.
// For every visited node x it fills:
//   subtree[x]      - number of nodes in x's subtree,
//   subtreeSum[x]   - sum of a[] over x's subtree,
//   indexInEuler[x] - x's pre-order position (starting from the incoming value of tim),
//   euler[pos]      - the node visited at pre-order position pos,
// and advances `tim` past the last position used.
// An explicit stack replaces recursion, so a path-shaped tree (depth ~N) cannot overflow
// the call stack. The visiting order is the same as the recursive version's.
void dfs(int src, vector<vector<int>> &v, vector<int> &subtree, vector<int> &euler, vector<int> &indexInEuler, int &tim, vector<int> &a, vector<int> &subtreeSum) {

    vector<int> order;                       // nodes in pre-order
    vector<int> parentOf(v.size(), -1);      // parent of each node within this traversal
    vector<int> stk = {src};

    while (!stk.empty()) {
        int x = stk.back();
        stk.pop_back();

        assert(subtree[x] == 0);
        subtree[x] = 1;
        assert(euler[tim] == 0);
        euler[tim] = x;
        subtreeSum[x] = a[x];
        assert(indexInEuler[x] == 0);
        indexInEuler[x] = tim;
        tim++;
        order.push_back(x);

        // Push children in reverse so the first child is popped (visited) first.
        for (auto it = v[x].rbegin(); it != v[x].rend(); ++it) {
            parentOf[*it] = x;
            stk.push_back(*it);
        }
    }

    // In reverse pre-order every node comes after all of its descendants, so its size and
    // sum are final by the time they are folded into its parent.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        int x = *it;
        if (x != src) {
            subtree[parentOf[x]] += subtree[x];
            subtreeSum[parentOf[x]] += subtreeSum[x];
        }
    }
}


int32_t main()
{
    quick;
    int t = 1;
    cin >> t;
    while (t--)
    {
        int n, q;
        cin >> n >> q;
        // v[p] lists the children of p; parent[c] is the parent of c (0-indexed, root = 0).
        vector<vector<int>> v(n);
        vector<int> parent(n);
        fab(1, n, i)
        {
            int parentNode;
            cin >> parentNode;
            parentNode--;
            parent[i] = parentNode;
            v[parentNode].push_back(i);
        }

        vector<int> a(n);
        fab(0, n, i)
        {
            cin >> a[i];
        }

        // Initial special edges; applied once the segment tree has been built.
        int notAllowedSz;
        cin >> notAllowedSz;
        vector<pair<int, int>> cancel;

        for (int i = 0; i < notAllowedSz; i++)
        {
            int x, y;
            cin >> x >> y;
            x--, y--;
            cancel.pb({x, y});
        }

        vector<int> subtree(n), euler(n), indexInEuler(n), subtreeSum(n);
        int tim = 0;
        dfs(0, v, subtree, euler, indexInEuler, tim, a, subtreeSum);
        // alter[pos] = node at pre-order position pos. euler is then overwritten in place
        // with that node's subtree sum and becomes the segment tree's base array.
        vector<int> alter = euler;

        for (int i = 0; i < n; i++)
        {
            assert(subtree[i] > 0);
            euler[i] = subtreeSum[euler[i]];

        }
        segtree seg(euler);
        seg.build();
        const int inf = 1e18;
        // Entries below this are disabled (-inf) slots, never real subtree sums.
        const int compareVal = -1e17;

        // Type 1 query: print subtreeSum(ind) plus the k-1 largest enabled subtree sums
        // strictly inside ind's subtree, or IMPOSSIBLE if fewer than k-1 are enabled.
        // Chosen entries are disabled while searching and restored before returning.
        auto breakEdges = [&](int ind, int k) {

            // Strict descendants of ind occupy pre-order positions [leftIndex, rightIndex).
            const int leftIndex = indexInEuler[ind] + 1;
            const int rightIndex = indexInEuler[ind] + subtree[ind];
            vector<int> indicesUpdated;
            bool ok = true;
            int sum = subtreeSum[ind];
            for (int i = 0; i < k - 1; i++) {

                auto currMax = seg.query(leftIndex, rightIndex);

                if (currMax.first < compareVal)
                {
                    ok = false;
                    break;
                }
                int index = currMax.second;
                indicesUpdated.push_back(index);
                sum += currMax.first;
                seg.update(index, -inf);
            }

            // Only finite (enabled) entries were picked, so restoring the true sum is correct.
            for (const int pos : indicesUpdated) {
                seg.update(pos, subtreeSum[alter[pos]]);
            }

            if (!ok)
            {
                cout << "IMPOSSIBLE" << endl;
                return;
            }
            cout << sum << endl;
        };

        // The edge x-y is unordered: return whichever endpoint is the child. The
        // segment-tree slot that controls the edge is the child's pre-order position.
        auto childOf = [&](int x, int y) { return parent[x] == y ? x : y; };

        // Mark x-y as special: its child's subtree can no longer be cut off.
        auto addEdge = [&](int x, int y) {
            seg.update(indexInEuler[childOf(x, y)], -inf);
        };

        // Unmark x-y: re-enable the child's slot with its real subtree sum.
        auto removeEdge = [&](int x, int y) {
            const int child = childOf(x, y);
            seg.update(indexInEuler[child], subtreeSum[child]);
        };


        for (auto &i : cancel)
        {
            addEdge(i.f, i.se);
        }
        while (q--) {

            int type;
            cin >> type;

            if (type == 1) {

                int index, k;
                cin >> index >> k;
                assert(index >= 1 and index <= n);
                index--;
                breakEdges(index, k);
            }
            else if (type == 2) {
                int x, y;
                cin >> x >> y;
                assert(x >= 1 and x <= n and y >= 1 and y <= n);
                x--, y--;
                addEdge(x, y);
            } else if (type == 3) {
                int x, y;
                cin >> x >> y;
                assert(x >= 1 and x <= n and y >= 1 and y <= n);
                x--, y--;
                removeEdge(x, y);
            }
        }

    }
    cerr << "time taken : " << (float)clock() / CLOCKS_PER_SEC << " secs" << endl;
    return 0;
}
Tester's Solution
#include <bits/stdc++.h>

using namespace std;

// Sentinel for disabled slots (special edges and already-chosen nodes are set to -INF).
// Assumes every real subtree sum has magnitude well below 4e15 — TODO confirm against constraints.
const int64_t INF = 4e15;

// Bottom-up max segment tree over n leaves holding (value, index) pairs.
// Leaves are value[n .. 2n-1]; internal node i = max(value[2i], value[2i+1]).
// std::pair ordering breaks ties on value toward the larger index.
class segtree {
private:
	vector<pair<int64_t, int>> value;
	int n;

public:
	segtree(int _n) : n(_n) {
		value.assign(2 * n, make_pair(-INF, -1));
	}

	// Load v[0 .. n-1] into the leaves and recompute every internal node.
	void build(vector<int64_t>& v) {
		for (int i = 0; i < n; i++) {
			value[n + i] = make_pair(v[i], i);
		}

		// Internal nodes are 1 .. n-1. Starting at i = n would overwrite leaf 0 with
		// max(value[2n], value[2n+1]), reading past the end of the vector.
		for (int i = n - 1; i > 0; i--) {
			value[i] = max(value[i << 1], value[i << 1 | 1]);
		}
	}

	// Set position ind to val, then refresh all of its ancestors.
	void upd(int ind, int64_t val) {
		int pos = ind + n;
		// Store the original (unshifted) index; spelled out rather than relying on
		// C++17 right-to-left sequencing of `value[ind += n] = make_pair(val, ind)`.
		value[pos] = make_pair(val, ind);
		for (; pos > 1; pos >>= 1) {
			value[pos >> 1] = max(value[pos], value[pos ^ 1]);
		}
	}

	// Maximum (value, index) over the half-open range [l, r); (-INF, -1) if the range is empty.
	pair<int64_t, int> qry(int l, int r) {
		pair<int64_t, int> res = make_pair(-INF, -1);
		for (l += n, r += n; l < r; l >>= 1, r >>= 1) {
			if (l & 1) {
				res = max(res, value[l]);
				l++;
			}

			if (r & 1) {
				r--;
				res = max(res, value[r]);
			}
		}

		return res;
	}
};

// Per-node arrays, indexed 1..n (index 0 doubles as "no parent").
// tin/tout: entry/exit timestamps; a: node values; par: parent read from input;
// sub: subtree sums of a[].
const int maxN = 1e5 + 5;
vector<vector<int>> adj(maxN);
vector<int> tin(maxN), tout(maxN), a(maxN), par(maxN);
vector<int64_t> sub(maxN);
int curr_time = 0;

// Reset adjacency lists of nodes 1..n and the timestamp counter for a new test case.
void init(int n) {
	for (int i = 1; i <= n; i++) {
		adj[i].clear();
	}
	curr_time = 0;
}

// Euler tour from x (p is x's parent; 0 = none). Each node gets tin on entry and tout on
// exit from the shared counter curr_time, so the strict descendants of x are exactly the
// nodes whose timestamps lie strictly between tin[x] and tout[x]. sub[y] becomes the sum
// of a[] over y's subtree. An explicit stack replaces recursion, so a path-shaped tree
// (depth ~n) cannot overflow the call stack. Visiting order and timestamps are unchanged.
void dfs(int x, int p = 0) {
	// One frame per open node: (node, its parent, index of the next neighbour to examine).
	struct Frame {
		int node;
		int parent;
		size_t next;
	};
	vector<Frame> frames;

	tin[x] = curr_time++;
	sub[x] = a[x];
	frames.push_back({x, p, 0});

	while (!frames.empty()) {
		const int node = frames.back().node;
		const int parent = frames.back().parent;
		size_t& next = frames.back().next;

		if (next < adj[node].size()) {
			const int child = adj[node][next++];
			if (child != parent) {
				// Entering child. `next` must not be used after this push (it may dangle).
				tin[child] = curr_time++;
				sub[child] = a[child];
				frames.push_back({child, node, 0});
			}
		} else {
			// Every neighbour handled: leave node and fold its sum into its parent.
			tout[node] = curr_time++;
			frames.pop_back();
			if (!frames.empty()) {
				sub[frames.back().node] += sub[node];
			}
		}
	}
}

int main() {
	ios::sync_with_stdio(false); cin.tie(0);
	int t;
	cin >> t;
	while (t--) {
		int n, q;
		cin >> n >> q;
		init(n);

		// Parents of nodes 2..n. Both directions are stored because dfs walks an
		// undirected adjacency list. (Named p so it does not shadow the global a[].)
		for (int i = 2; i <= n; i++) {
			int p;
			cin >> p;

			par[i] = p;
			adj[i].push_back(p);
			adj[p].push_back(i);
		}

		for (int i = 1; i <= n; i++) {
			cin >> a[i];
		}

		dfs(1);

		// Each node occupies two Euler positions (tin and tout), both holding its subtree sum.
		vector<int64_t> v(2 * n, 0);
		for (int i = 1; i <= n; i++) {
			v[tin[i]] = v[tout[i]] = sub[i];
		}

		segtree st(2 * n);
		st.build(v);

		// mp[pos] = the node whose tin or tout equals pos.
		vector<int> mp(2 * n, 0);
		for (int i = 1; i <= n; i++) {
			mp[tin[i]] = mp[tout[i]] = i;
		}

		// Write val into both Euler positions of node.
		auto setNode = [&](int node, int64_t val) {
			st.upd(tin[node], val);
			st.upd(tout[node], val);
		};

		// The edge x-y is unordered; return whichever endpoint is the child.
		auto childOf = [&](int x, int y) { return par[x] == y ? x : y; };

		int m;
		cin >> m;

		for (int i = 0; i < m; i++) {
			int x, y;
			cin >> x >> y;
			setNode(childOf(x, y), -INF);
		}

		while (q--) {
			int type;
			cin >> type;

			if (type == 1) {
				int r, k;
				cin >> r >> k;

				// Answer = sub[r] + the k-1 largest enabled subtree sums strictly inside r's subtree.
				int64_t sum = sub[r];
				bool possible = true;
				vector<pair<int, int64_t>> chosen;  // (node, value) pairs to restore afterwards

				for (int i = 0; i < k - 1; i++) {
					auto [val, ind] = st.qry(tin[r] + 1, tout[r]);
					if (val == -INF) {
						possible = false;
						break;
					}

					sum += val;
					const int node = mp[ind];
					chosen.emplace_back(node, val);
					setNode(node, -INF);
				}

				if (possible) cout << sum << '\n';
				else cout << "IMPOSSIBLE\n";

				for (const auto& [node, val] : chosen) {
					setNode(node, val);
				}
			}
			else if (type == 2) {
				int x, y;
				cin >> x >> y;
				setNode(childOf(x, y), -INF);
			}
			else {
				int x, y;
				cin >> x >> y;
				const int child = childOf(x, y);
				setNode(child, sub[child]);
			}
		}
	}
}

If anything is unclear please let me know in the comments, it will help me improve.

2 Likes