TRTOKENS - Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Practice

Setter: Jatin Yadav
Tester: Riley Borgard
Editorialist: Aman Dwivedi

DIFFICULTY

Medium

PREREQUISITES

Tree, DFS, RMQ

PROBLEM

You are given a rooted tree with N nodes numbered 1, 2, \ldots, N. Node 1 is the root node. Some of the nodes have a token in them. In one move, you can choose a non-root node that has a token, but its parent doesn’t, and move the token from this node to its parent. What is the maximum number of moves you can make?

Note: When a token is moved out of a node, the node becomes empty, and other tokens will be able to move there.

QUICK EXPLANATION:

Run a DFS on the tree. DFS(s): first call DFS on every child v of s. Then, if node s initially had a token, do nothing; otherwise, find the deepest token in its subtree and shift it up to node s (every token on the path between them moves one step up).

EXPLANATION:

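The idea is to repeatedly take the deepest token that has a free ancestor and move it to its closest (deepest) free ancestor. Here, moving a token to an ancestor means shifting every token on the path between them one step up, so it costs (depth of the token) - (depth of the ancestor) moves.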
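Why a shift to an ancestor k levels up costs exactly k moves: every single move takes one token from some depth x to depth x - 1, so it lowers the total depth of all tokens by exactly one. Writing depth_start(t) and depth_end(t) for the initial and final depth of token t, the answer is the total drop in depth:

```latex
\text{moves} \;=\; \sum_{t \in \text{tokens}} \operatorname{depth}_{\text{start}}(t) \;-\; \sum_{t \in \text{tokens}} \operatorname{depth}_{\text{end}}(t)
```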

Proof

Let us prove it by contradiction:

Consider an optimal solution that doesn’t move this token. Initially, this token has more ancestor vertices than ancestor tokens, which is exactly what having a free ancestor means.

So in the end:

  • It has a free ancestor, which is definitely not optimal: we could still move this token up to that ancestor (shifting the tokens in between) and gain more moves, while our goal is to maximize the number of moves.

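Hence, we go through all the vertices starting from the deepest. If a vertex currently holds a token, we find its closest free ancestor and move this token there. The ancestor that receives the token is shallower, so it is visited later and its token may move further up.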
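A direct (slow) sketch of this greedy, assuming 0-indexed nodes with `par[i] < i` (the setter's solution asserts this), `par[0]` unused, and `type[v] == '1'` meaning node v has a token. The name `greedyByDepth` is hypothetical. It finds the closest free ancestor by walking parent pointers, so it runs in O(N * depth) and only illustrates the idea.

```cpp
#include <algorithm>
#include <numeric>
#include <string>
#include <vector>
using namespace std;

// Naive sketch of "deepest vertex first, move its token to the closest free
// ancestor". Assumes 0-indexed nodes, par[0] unused and par[i] < i for i >= 1.
// Runs in O(N * depth) because it walks parent pointers.
long long greedyByDepth(int n, const vector<int>& par, const string& type) {
    vector<int> depth(n, 0);
    for (int i = 1; i < n; i++) depth[i] = depth[par[i]] + 1;
    vector<int> order(n);
    iota(order.begin(), order.end(), 0);
    stable_sort(order.begin(), order.end(),
                [&](int a, int b) { return depth[a] > depth[b]; });  // deepest first
    vector<bool> occ(n);
    for (int i = 0; i < n; i++) occ[i] = (type[i] == '1');
    long long moves = 0;
    for (int v : order) {
        if (v == 0 || !occ[v]) continue;      // only vertices that currently hold a token
        int a = par[v];
        while (a != 0 && occ[a]) a = par[a];  // climb to the closest free ancestor
        if (occ[a]) continue;                 // even the root is occupied: token is stuck
        // Shifting the chain between v and a costs one move per level.
        moves += depth[v] - depth[a];
        occ[v] = false;                       // the bottom of the chain becomes empty
        occ[a] = true;                        // a is shallower, so it is visited later
    }
    return moves;
}
```

For example, on the chain 1-2-3-4-5 with tokens on nodes 4 and 5 (string "00011"), the sketch returns 6: both tokens end up on nodes 1 and 2.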

To simplify the implementation, we can run a DFS on the tree instead:

DFS(s): first call DFS on every child of s, so that the whole subtree rooted at s has been explored. Then, if node s initially had a token, we do nothing; otherwise, we find the deepest token in its subtree and shift it to node s.

Subtask 1:

T \le 10, N \le 17

Since N is so small, we can try every possible way of shifting tokens to nodes and pick the one that maximizes the number of moves.

But this is definitely an overkill solution.
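An even simpler brute force than the setter's DP below (our own sketch, not the setter's method) is to explore every sequence of moves with memoization over the occupancy bitmask. Every move lowers the total depth by one, so the state graph has no cycles and the recursion terminates. It assumes 0-indexed nodes, `par[0]` unused, and n small enough for 2^n states; `bruteForceMoves` is a hypothetical name.

```cpp
#include <algorithm>
#include <functional>
#include <string>
#include <vector>
using namespace std;

// Exhaustive search over occupancy bitmasks (bit v set = node v holds a token).
// Assumes 0-indexed nodes, par[0] unused, and n small enough for 2^n states.
int bruteForceMoves(int n, const vector<int>& par, const string& type) {
    int start = 0;
    for (int v = 0; v < n; v++)
        if (type[v] == '1') start |= 1 << v;
    vector<int> memo(1 << n, -1);  // memo[mask] = max moves from this configuration
    function<int(int)> go = [&](int mask) {
        if (memo[mask] != -1) return memo[mask];
        int best = 0;
        for (int v = 1; v < n; v++) {
            // A move is legal if v has a token and its parent does not.
            if ((mask >> v & 1) && !(mask >> par[v] & 1)) {
                int next = mask ^ (1 << v) ^ (1 << par[v]);
                best = max(best, 1 + go(next));
            }
        }
        return memo[mask] = best;
    };
    return go(start);
}
```

With N \le 17 this is at most 2^{17} states times N candidate moves each, which is small. It is also handy for cross-checking the faster solutions on random small trees.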

Solution
#include <bits/stdc++.h>
using namespace std;
 
#define ll long long
#define pii pair<int, int>
#define F first
#define S second
#define all(c) ((c).begin()), ((c).end())
#define sz(x) ((int)(x).size())
#define ld double
 
template<class T,class U>
ostream& operator<<(ostream& os,const pair<T,U>& p){
	os<<"("<<p.first<<", "<<p.second<<")";
	return os;
}
 
template<class T>
ostream& operator <<(ostream& os,const vector<T>& v){
	os<<"{";
	for(int i = 0;i < (int)v.size(); i++){
		if(i)os<<", ";
		os<<v[i];
	}
	os<<"}";
	return os;
}
 
#ifdef LOCAL
#define cerr cout
#else
#endif
 
#define TRACE
 
#ifdef TRACE
#define trace(...) __f(#__VA_ARGS__, __VA_ARGS__)
template <typename Arg1>
void __f(const char* name, Arg1&& arg1){
	cerr << name << " : " << arg1 << std::endl;
}
template <typename Arg1, typename... Args>
void __f(const char* names, Arg1&& arg1, Args&&... args){
	const char* comma = strchr(names + 1, ',');cerr.write(names, comma - names) << " : " << arg1<<" | ";__f(comma+1, args...);
}
#else
#define trace(...)
#endif
 
long long readInt(long long l,long long r,char endd){
	long long x=0;
	int cnt=0;
	int fi=-1;
	bool is_neg=false;
	while(true){
		char g=getchar();
		if(g=='-'){
			assert(fi==-1);
			is_neg=true;
			continue;
		}
		if('0'<=g && g<='9'){
			x*=10;
			x+=g-'0';
			if(cnt==0){
				fi=g-'0';
			}
			cnt++;
			assert(fi!=0 || cnt==1);
			assert(fi!=0 || is_neg==false);
 
			assert(!(cnt>19 || ( cnt==19 && fi>1) ));
		} else if(g==endd){
			if(is_neg){
				x= -x;
			}
			if(!(l<=x && x<=r))cerr<<l<<"<="<<x<<"<="<<r<<endl;
			assert(l<=x && x<=r);
			return x;
		} else {
			assert(false);
		}
	}
}
string readString(int l,int r,char endd){
	string ret="";
	int cnt=0;
	while(true){
		char g=getchar();
		assert(g!=-1);
		if(g==endd){
			break;
		}
		cnt++;
		ret+=g;
	}
	assert(l<=cnt && cnt<=r);
	return ret;
}
long long readIntSp(long long l,long long r){
	return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
	return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
	return readString(l,r,'\n');
}
string readStringSp(int l,int r){
	return readString(l,r,' ');
}
template<class T>
vector<T> readVector(int n, long long l, long long r){
    vector<T> ret(n);
    for(int i = 0; i < n; i++){
        ret[i] = i == n - 1 ? readIntLn(l, r) : readIntSp(l, r);
    }
    return ret;
}
 
const int SN = 1000;
// O(N 2^N)
int get(int n, vector<int> par,  string type){
    vector<vector<int>> children(n);
    for(int i = 1; i < n; i++){
        assert(par[i] >= 0 && par[i] < i);
        children[par[i]].push_back(i);
    }
    vector<int> depth(n);
    vector<int> masks(n);
    int sum = 0;
    vector<vector<bool>> isAncestor(n, vector<bool>(n, 0));
    vector<vector<bool>> isOK(1 << n, vector<bool>(n, 0));
    for(int i = 1; i < n; i++){
        depth[i] = depth[par[i]] + 1;
        sum += (type[i] - '0') * depth[i];
        isAncestor[0][i] = true;
        int u = i;
        while(u != 0){
            isAncestor[u][i] = true;
            u = par[u];
        }
    }
    for(int mask = 0; mask < (1 << n); mask++){
        for(int i = 0; i < n; i++) if(mask >> i & 1){
            isOK[mask][i] = true;
            for(int j = 0; j < i; j++) if((mask >> j & 1) && isAncestor[j][i]) isOK[mask][i] = false;
        }
    }
    const int INF = 1 << 29;
    vector<vector<int>> dp(n, vector<int>(1 << n, INF));
    for(int s = n - 1; s >= 0; s--){
        if(type[s] == '1') masks[s] |= 1<<s;
        for(int v : children[s]){
            masks[s] |= masks[v];
        }
        dp[s][0] = 0;
        for(int submask = masks[s]; submask; submask = (submask - 1) & masks[s]){
            for(int i = 0; i < n; i++) if(isOK[submask][i]){
                int cost = depth[s];
                int mask = submask ^ (1 << i);
                for(int v : children[s]){
                    cost += dp[v][mask & masks[v]];
                    cost = min(cost, INF);
                }
                dp[s][submask] = min(dp[s][submask], cost);
            }
        }
    }
    return sum - dp[0][masks[0]];
}
 
int main(){
	int t; cin >> t;
	int sn = 0;
	while(t--){
		int n; cin >> n;
		string type; cin >> type;
        vector<int> par(n);
        for(int i = 1; i < n; i++){
            cin >> par[i];
            par[i]--;
        }
		cout << get(n, par, type) << endl;
	}
}

Subtask 2:

The sum of N over all test cases doesn’t exceed 2000.

During the DFS, suppose we are at some node s whose subtree is already explored and which doesn’t have a token. Then we need to find the deepest node in its subtree that has a token, so that we can shift that token to s.

Since N is small, we can find this deepest node by simply running another DFS over this subtree, which costs O(N^2) overall. Once we find that node, we shift its token to node s and add the number of moves needed (the depth difference) to our answer.
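A sketch of this O(N^2) approach, under the same assumptions as before (0-indexed nodes, `par[i] < i`, hypothetical name `solveQuadratic`). Because every child has a larger index than its parent, visiting nodes from N-1 down to 0 finishes each subtree before its root, which plays the role of the outer DFS; the inner stack loop is the "DFS again" over the subtree.

```cpp
#include <string>
#include <vector>
using namespace std;

// O(N^2) sketch: for every empty node, re-scan its subtree for the deepest token.
// Assumes 0-indexed nodes, par[0] unused and par[i] < i for i >= 1.
long long solveQuadratic(int n, const vector<int>& par, const string& type) {
    vector<vector<int>> children(n);
    vector<int> depth(n, 0);
    for (int i = 1; i < n; i++) {
        children[par[i]].push_back(i);
        depth[i] = depth[par[i]] + 1;
    }
    vector<bool> occ(n);
    for (int i = 0; i < n; i++) occ[i] = (type[i] == '1');
    long long moves = 0;
    for (int s = n - 1; s >= 0; s--) {  // children before parents, like a post-order DFS
        if (occ[s]) continue;           // s initially had a token: nothing to do
        int deepest = -1;
        vector<int> stk = {s};          // the "DFS again" over the subtree of s
        while (!stk.empty()) {
            int v = stk.back();
            stk.pop_back();
            if (occ[v] && (deepest == -1 || depth[v] > depth[deepest])) deepest = v;
            for (int c : children[v]) stk.push_back(c);
        }
        if (deepest == -1) continue;    // no token below s
        moves += depth[deepest] - depth[s];
        occ[deepest] = false;           // every token on the path moves up one step
        occ[s] = true;
    }
    return moves;
}
```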

Subtask 3:

The sum of N over all test cases doesn’t exceed 10^5.

The idea is the same: during the DFS, whenever a node doesn’t have a token, we find the deepest token in its subtree and shift that token to this node.

But finding the deepest node by traversing the whole subtree again results in TLE, since N is much larger this time.

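To optimize it, we maintain for every subtree a multiset of the distances of its tokens below the subtree root. Moving up to the parent adds 1 to every distance, so instead of updating every element we store each value relative to an offset (real distance = stored value + offset) and just increment the offset. Child sets are merged small to large: the elements of the smaller set are re-based onto the larger set’s offset and inserted into it.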

This results in an O(N \log^2 N) solution, which is good enough to pass this subtask.
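To illustrate what the offset does, here is a stripped-down version of the structure used in the solution below (the names `ShiftedSet` and `mergeInto` are ours, not the tester's). Each stored value x stands for the real distance x + offset, so "add 1 to every distance" is a single `offset++`, and merging re-bases each element of the smaller set onto the larger set's offset.

```cpp
#include <iterator>
#include <set>
#include <utility>
using namespace std;

// Multiset of distances with a lazy "+offset" applied to every element.
// A stored value x represents the real distance x + offset.
struct ShiftedSet {
    multiset<int> vals;
    int offset = 0;
    void insert(int dist) { vals.insert(dist - offset); }
    void shiftAll() { offset++; }          // every real distance grows by 1 in O(1)
    int popMax() {                         // assumes the set is non-empty
        auto it = prev(vals.end());
        int dist = *it + offset;
        vals.erase(it);
        return dist;
    }
    size_t size() const { return vals.size(); }
};

// Small to large: the result ends up in a; b is left empty.
// Swapping two sets is O(1), so only the smaller set's elements are re-inserted.
void mergeInto(ShiftedSet& a, ShiftedSet& b) {
    if (a.size() < b.size()) swap(a, b);
    for (int x : b.vals) a.insert(x + b.offset);  // convert to real distance, then re-base
    b.vals.clear();
    b.offset = 0;
}
```

In the tester's code, `ma[y].offset++` corresponds to `shiftAll()` when going from child y up to x, and `insert(-ma[x].offset)` corresponds to `insert(0)`, i.e. node x itself at distance 0.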

Solution
#include <bits/stdc++.h>
 
#define ll long long
#define sz(x) ((int) (x).size())
#define all(x) (x).begin(), (x).end()
#define vi vector<int>
#define pii pair<int, int>
#define rep(i, a, b) for(int i = (a); i < (b); i++)
using namespace std;
template<typename T>
using minpq = priority_queue<T, vector<T>, greater<T>>;
 
// O(n log^2 n) solution
// dfs on the tree
// if a node is empty and has a token in the subtree, jump the deepest token up
// maintain the depths with a multiset + offset, use small to large merging
 
struct tokens {
    multiset<int> depths;
    int offset = 0;
};
 
void solve() {
    int n;
    string s;
    cin >> n >> s;
    vector<vi> g(n + 1);
    rep(i, 2, n + 1) {
        int p;
        cin >> p;
        g[p].push_back(i);
    }
    ll ans = 0;
    vector<tokens> ma(n + 1);
    function<void(int)> dfs = [&](int x) {
        for(int y : g[x]) {
            dfs(y);
            ma[y].offset++;
            if(sz(ma[x].depths) < sz(ma[y].depths)) {
                for(int d : ma[x].depths) {
                    ma[y].depths.insert(d + ma[x].offset - ma[y].offset);
                }
                ma[x].depths.swap(ma[y].depths);
                swap(ma[x].offset, ma[y].offset);
            }else {
                for(int d : ma[y].depths) {
                    ma[x].depths.insert(d + ma[y].offset - ma[x].offset);
                }
            }
        }
        ma[x].depths.insert(-ma[x].offset);
        if(s[x - 1] == '0') {
            int d = *prev(ma[x].depths.end()) + ma[x].offset;
            ma[x].depths.erase(prev(ma[x].depths.end()));
            ans += d;
        }
    };
    dfs(1);
    cout << ans << '\n';
}
 
int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    int te;
    cin >> te;
    while(te--) solve();
}

Subtask 4:

Original Constraints

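With the original constraints, N is large enough that the O(N \log^2 N) solution may get the TLE verdict. We can optimize it further to O(N \log N) using an Euler tour and a range maximum query (a segment tree with point updates, since token positions change as we go).

Do an Euler tour of the given tree, so that the subtree of every node s occupies a contiguous range of positions [tin_s, tin_s + size_s - 1], and build the range-maximum structure over the depths of the tokens in Euler-tour order. When we are at a node s that doesn’t have a token, we query the range of its subtree to find the deepest token, add the depth difference to the answer, clear that token’s position, and mark s as occupied.

Each node needs O(1) queries and updates, each costing O(\log N), hence the whole solution runs in O(N \log N).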
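A compact sketch of this idea, again assuming 0-indexed nodes with `par[i] < i` and a hypothetical function name `solveEulerRmq`. It uses an iterative preorder (so a deep chain does not overflow the call stack), subtree sizes to get each subtree's Euler range, and a bottom-up segment tree that returns the position of the maximum depth, with point updates. Its structure is similar to the setter's solution below, but it is a separate sketch.

```cpp
#include <string>
#include <vector>
using namespace std;

// Sketch: Euler tour (preorder) + segment tree returning the position of the
// deepest available token, with point updates.
// Assumes 0-indexed nodes, par[0] unused and par[i] < i for i >= 1.
long long solveEulerRmq(int n, const vector<int>& par, const string& type) {
    vector<vector<int>> children(n);
    vector<int> depth(n, 0), tin(n), sub(n, 1);
    for (int i = 1; i < n; i++) {
        children[par[i]].push_back(i);
        depth[i] = depth[par[i]] + 1;
    }
    // Iterative preorder: subtree of v occupies positions tin[v] .. tin[v] + sub[v] - 1.
    int timer = 0;
    vector<int> stk = {0};
    while (!stk.empty()) {
        int v = stk.back();
        stk.pop_back();
        tin[v] = timer++;
        for (int c : children[v]) stk.push_back(c);
    }
    for (int i = n - 1; i >= 1; i--) sub[par[i]] += sub[i];  // subtree sizes

    int base = 1;
    while (base < n) base <<= 1;
    vector<int> val(base, -1);       // depth of the token at a position, -1 = empty
    vector<int> best(2 * base, -1);  // position of the max val in a segment, -1 = none
    auto better = [&](int a, int b) {
        if (a == -1) return b;
        if (b == -1) return a;
        return val[a] >= val[b] ? a : b;
    };
    auto setVal = [&](int pos, int x) {
        val[pos] = x;
        int i = pos + base;
        best[i] = x < 0 ? -1 : pos;
        for (i >>= 1; i >= 1; i >>= 1) best[i] = better(best[2 * i], best[2 * i + 1]);
    };
    auto query = [&](int l, int r) {  // inclusive range [l, r]
        int res = -1;
        for (l += base, r += base + 1; l < r; l >>= 1, r >>= 1) {
            if (l & 1) res = better(res, best[l++]);
            if (r & 1) res = better(res, best[--r]);
        }
        return res;
    };

    long long moves = 0;
    for (int v = n - 1; v >= 0; v--) {  // every subtree is finished before its root
        if (type[v] == '1') {
            setVal(tin[v], depth[v]);
            continue;
        }
        int u = query(tin[v], tin[v] + sub[v] - 1);
        if (u == -1) continue;          // no token left in v's subtree
        moves += val[u] - depth[v];     // chain shift: u's slot empties, v fills
        setVal(u, -1);
        setVal(tin[v], depth[v]);
    }
    return moves;
}
```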

TIME COMPLEXITY:

O(N \log N) per test case

SOLUTIONS:

Setter
 #include <bits/stdc++.h>
using namespace std;

#define ll long long
#define pii pair<int, int>
#define all(c) ((c).begin()), ((c).end())
#define sz(x) ((int)(x).size())

#ifdef LOCAL
#include <print.h>
#else
#define trace(...)
#endif
 
long long readInt(long long l,long long r,char endd){
	long long x=0;
	int cnt=0;
	int fi=-1;
	bool is_neg=false;
	while(true){
		char g=getchar();
		if(g=='-'){
			assert(fi==-1);
			is_neg=true;
			continue;
		}
		if('0'<=g && g<='9'){
			x*=10;
			x+=g-'0';
			if(cnt==0){
				fi=g-'0';
			}
			cnt++;
			assert(fi!=0 || cnt==1);
			assert(fi!=0 || is_neg==false);
 
			assert(!(cnt>19 || ( cnt==19 && fi>1) ));
		} else if(g==endd){
			if(is_neg){
				x= -x;
			}
			if(!(l<=x && x<=r))cerr<<l<<"<="<<x<<"<="<<r<<endl;
			assert(l<=x && x<=r);
			return x;
		} else {
			assert(false);
		}
	}
}
string readString(int l,int r,char endd, char minc = 'a', char maxc = 'z'){
	string ret="";
	int cnt=0;
	while(true){
		char g=getchar();
		assert(g!=-1);
		if(g==endd){
			break;
		}
		assert(g >= minc && g <= maxc);
		cnt++;
		ret+=g;
	}
	assert(l<=cnt && cnt<=r);
	return ret;
}
long long readIntSp(long long l,long long r){
	return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
	return readInt(l,r,'\n');
}
string readStringLn(int l,int r, char minc = 'a', char maxc = 'z'){
	return readString(l,r,'\n', minc, maxc);
}
string readStringSp(int l,int r, char minc = 'a', char maxc = 'z'){
	return readString(l,r,' ', minc, maxc);
}
template<class T>
vector<T> readVector(int n, long long l, long long r){
    vector<T> ret(n);
    for(int i = 0; i < n; i++){
        ret[i] = i == n - 1 ? readIntLn(l, r) : readIntSp(l, r);
    }
    return ret;
}

template<class T>
struct segtree{
	int n;
	vector<T> t, A;
	T def;
	inline T combine(T a, T b){
		if(a == -1) return b;
		if(b == -1) return a;
		return A[a] > A[b] ? a : b;
	}
	segtree(vector<T> inp) : n(sz(inp)), A(inp), def(-1){
		t.resize(2 * n, def);
		for(int i = 0; i < n; i++) t[n + i] = i;
		for(int i = n - 1; i > 0; --i) t[i] = combine(t[i<<1], t[i<<1|1]);
	}

	void modify(int p, T value) { // modify A[p] = value
		// value = combine(value, t[p + n]); // if a[p] = combine(a[p], value)
		A[p] = value;
		for (p += n; p >>= 1; ) t[p] = combine(t[p<<1], t[p<<1|1]);
	}

	T query(int l, int r) {  // compute on interval [l, r]
    	r++;
		T resl = def, resr = def;
		for (l += n, r += n; l < r; l >>= 1, r >>= 1) {
			if (l&1) resl = combine(resl, t[l++]);
			if (r&1) resr = combine(t[--r], resr);
		}
		return combine(resl, resr);
	}
};

long long get(int n, vector<int> par, string type){
	vector<vector<int>> children(n);
	for(int i = 1; i < n; i++){
		assert(par[i] >= 0 && par[i] < i);
		children[par[i]].push_back(i);
	}
	vector<int> depth(n, 1);
	long long sum = (type[0] - '0');
	for(int i = 1; i < n; i++){
		depth[i] = depth[par[i]] + 1;
		sum += (type[i] - '0') * depth[i];
	}
	vector<int> st(n), en(n);
	stack<int> stk;
	stk.push(0);
	int timer = 0;
	while(!stk.empty()){
		int s = stk.top(); stk.pop();
		st[s] = timer++;
		for(int v : children[s]) stk.push(v);
	}
	for(int s = n - 1; s >= 0; s--){
		en[s] = st[s];
		reverse(all(children[s]));
		for(int v : children[s]) en[s] = en[v];
	}
	segtree<int> stree(vector<int>(n, 0));
	for(int s = n - 1; s >= 0; s--){
		stree.modify(st[s], depth[s]);
		sum -= depth[s];
		if(type[s] == '0'){
			int u = stree.query(st[s], en[s]);
			sum += stree.A[u];
			stree.modify(u, 0);
		}
	}
	return sum;
}


const int SN = 1000000;
int main(){
	int t = readIntLn(1, SN);
	int sn = 0;
	while(t--){
		int n = readIntLn(1, SN);
		sn += n;
		assert(sn <= SN);
		string type = readStringLn(n, n, '0', '1');
        vector<int> par = readVector<int>(n - 1, 0, n);
		reverse(all(par)); par.push_back(0); reverse(all(par));
        for(int i = 1; i < n; i++){
            par[i]--;
        }
		cout << get(n, par, type) << endl;
	}
}
Tester
#pragma GCC optimize ("Ofast")
 
#include <bits/stdc++.h>
 
#define ll long long
#define sz(x) ((int) (x).size())
#define all(x) (x).begin(), (x).end()
#define vi vector<int>
#define pii pair<int, int>
#define rep(i, a, b) for(int i = (a); i < (b); i++)
using namespace std;
template<typename T>
using minpq = priority_queue<T, vector<T>, greater<T>>;
 
// O(n log n), optimized from O(n log^2 n) solution
// to query deepest token in subtree, use euler tour tree and RMQ
 
void solve() {
    int n;
    string s;
    cin >> n >> s;
    vector<vi> g(n + 1);
    rep(i, 2, n + 1) {
        int p;
        cin >> p;
        g[p].push_back(i);
    }
    ll ans = 0;
    vi tree(4 * n, n);
    vi a(n + 1, -1);
    function<int(int, int, int, int, int)> query = [&](int i, int l, int r, int L, int R) {
        if(r < L || R < l) return n;
        if(L <= l && r <= R) return tree[i];
        int m = (l + r) / 2;
        int j = query(2 * i + 1, l, m, L, R);
        int k = query(2 * i + 2, m + 1, r, L, R);
        return a[j] > a[k] ? j : k;
    };
    function<void(int, int, int, int, int)> upd = [&](int i, int l, int r, int k, int x) {
        if(l == r) {
            a[k] = x;
            tree[i] = k;
            return;
        }
        int m = (l + r) / 2;
        if(k <= m) upd(2 * i + 1, l, m, k, x);
        else upd(2 * i + 2, m + 1, r, k, x);
        tree[i] = (a[tree[2 * i + 1]] > a[tree[2 * i + 2]] ? tree[2 * i + 1] : tree[2 * i + 2]);
    };
 
    vi tin(n + 1), tout(n + 1), dep(n + 1);
    int ti = 0;
    function<void(int)> dfs = [&](int x) {
        tin[x] = ti++;
        for(int y : g[x]) {
            dep[y] = 1 + dep[x];
            dfs(y);
        }
        tout[x] = ti;
        upd(0, 0, n - 1, tin[x], dep[x]);
        if(s[x - 1] == '0') {
            int j = query(0, 0, n - 1, tin[x], tout[x] - 1);
            ans += a[j] - dep[x];
            upd(0, 0, n - 1, j, -1);
        }
    };
    dfs(1);
    cout << ans << '\n';
}
 
int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    int te;
    cin >> te;
    while(te--) solve();
}
Editorialist
#include<bits/stdc++.h>
using namespace std;
 
#define int long long
 
const int mxN=1e6+5;
vector <int> adj[mxN];
int n;
string s;
bool visited[mxN];
int level[mxN];
int tin[mxN];
int sz[mxN];
int timer;
int ans;
vector <int> a;
pair<int,int> t[4*mxN];
 
 
void build(int arr[],int v,int tl,int tr)
{
  if(tl==tr)
  {
    t[v].first=arr[tl];
    t[v].second=tl;
  }
  else
  {
    int tm=(tl+tr)/2;
    build(arr,v*2,tl,tm);
    build(arr,v*2+1,tm+1,tr);
 
    t[v].first=max(t[v*2].first,t[v*2+1].first);
 
    if(t[v].first==t[v*2].first)
      t[v].second=t[v*2].second;
    else
      t[v].second=t[v*2+1].second;
  }
}
 
pair<int,int> find_max(int v,int tl,int tr,int l,int r)
{
  if(l>r)
    return {-1,-1};
  if(l==tl && r==tr)
    return t[v];
 
  int tm=(tl+tr)/2;
 
  pair<int,int> fst=find_max(v*2,tl,tm,l,min(r,tm));
  pair<int,int> snd=find_max(v*2+1,tm+1,tr,max(l,tm+1),r);
 
  if(fst.first>snd.first)
    return fst;
  else
    return snd;
}
 
void update(int v,int tl,int tr,int pos,int new_val)
{
  if(tl==tr)
    t[v].first=new_val;
  else
  {
    int tm=(tl+tr)/2;
    if(pos<=tm)
      update(v*2,tl,tm,pos,new_val);
    else
      update(v*2+1,tm+1,tr,pos,new_val);
 
    t[v].first=max(t[v*2].first,t[v*2+1].first);
    if(t[v].first==t[v*2].first)
      t[v].second=t[v*2].second;
    else
      t[v].second=t[v*2+1].second;
  }
}
 
void dfs(int v)
{
  visited[v]=false;
 
  for(auto x: adj[v])
  {
    if(visited[x])
      dfs(x);
  }
 
  if(s[v]=='0')
  {
    int l=tin[v];
    int r=l+sz[v]-1;
    // cout<<v<<" "<<l<<" "<<r<<endl;
    pair <int,int> val=find_max(1,0,n-1,l,r);
    // cout<<val.first<<" "<<val.second<<endl;
    if(val.first!=0)
    {
      ans+=(val.first-level[v]);
      update(1,0,n-1,val.second,0);
      update(1,0,n-1,tin[v],level[v]);
    }
  }
}
 
int euler_tour(int v,int he)
{
  tin[v]=timer;
  level[v]=he;
  visited[v]=true;
  timer++;
  a.push_back(v);
 
  for(auto x: adj[v])
  {
    if(!visited[x])
      sz[v]+=euler_tour(x,he+1);
  }
 
  sz[v]++;
 
  return sz[v];
}
 
void solve()
{
    cin>>n;
    cin>>s;
 
    ans=0;
    timer=0;
    a.clear();
 
    for(int i=0;i<n;i++)
    {
        adj[i].clear();
        visited[i]=false;
        level[i]=0;
        tin[i]=-1;
        sz[i]=0;
    }
 
    for(int i=1;i<n;i++)
    {
        int x;
        cin>>x;
 
        adj[x-1].push_back(i);
    }
 
    int waste=euler_tour(0,1);
 
 
    int arr[n];
 
    for(int i=0;i<n;i++)
    {
      if(s[a[i]]=='1')
        arr[i]=level[a[i]];
      else
        arr[i]=0;
    }
 
    build(arr,1,0,n-1);
 
    dfs(0);
    cout<<ans<<"\n";
}
 
int32_t main()
{
    ios_base::sync_with_stdio(0);
    cin.tie(0);
 
    int tc;
    cin>>tc;
 
    while(tc--)
        solve();
 
return 0;
}
 
What is meant by offset?

The editorial mentions “move token to closest free ancestor”, but didn’t the problem say we can move a token only to a free parent? Or do you mean moving all those tokens which lie on the path between the free ancestor and the deepest node with a token?

Of course: moving all those tokens which lie on the path between the free ancestor and the deepest node with a token.

Got it, thanks. I just wanted to clarify it, as the statement said “find the deepest token in its subtree and shift it to node s”, which kind of seemed like moving the deepest token to that ancestor directly.

Can someone give me a case where this approach fails?

First of all, I have a DFS which returns the longest chain of tokens going down from each node. A node with no token can have many child subtrees with different longest chains. I claim that it is optimal to bring the longest chain one step up to fill the empty node, and while doing this we get ans += (length of the longest chain), because each node in the chain moves one step up.

I do this recursively and I’m not able to understand where this wouldn’t be optimal. Any help would be appreciated!
My submission: CodeChef: Practical coding for everyone

The solution for subtask 3 can be made O(N log N) using linked lists or vectors, since the nodes in the subtree occupy a contiguous range of depths.

\mathcal O(N \log^2 N) passes if you use a priority queue instead of a multiset, thanks to its lower constant factor. Submission.

Also, I believe \mathcal O(N) can be achieved if, instead of a logarithmic data structure, we use a deque with depths as indices and the count of nodes at each depth as values, because when merging by depth, small to large merging can be shown to have a better bound of \mathcal O(N). See the last paragraph of this blog.
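A minimal sketch of the deque idea from the comment above, under our own assumptions (0-indexed nodes, `par[i] < i`, hypothetical name `solveDeque`). Each subtree keeps a deque where index i counts the available tokens exactly i levels below the subtree root; a node takes over its longest child deque in O(1) and adds the others element by element, which is the depth-bounded small to large merging the comment refers to. It has not been benchmarked against the actual time limit.

```cpp
#include <deque>
#include <string>
#include <utility>
#include <vector>
using namespace std;

// Sketch of the deque idea: cnt[v][i] = number of available tokens exactly
// i levels below v inside v's subtree (index 0 is v itself).
// Assumes 0-indexed nodes, par[0] unused and par[i] < i for i >= 1.
long long solveDeque(int n, const vector<int>& par, const string& type) {
    vector<vector<int>> children(n);
    for (int i = 1; i < n; i++) children[par[i]].push_back(i);
    vector<deque<int>> cnt(n);
    long long moves = 0;
    for (int v = n - 1; v >= 0; v--) {        // children have larger indices
        int heavy = -1;                        // child with the longest deque
        for (int c : children[v])
            if (heavy == -1 || cnt[c].size() > cnt[heavy].size()) heavy = c;
        deque<int>& d = cnt[v];
        if (heavy != -1) swap(d, cnt[heavy]);  // O(1): take over the longest deque
        d.push_front(type[v] == '1' ? 1 : 0);  // old index i now means i + 1 levels below v
        for (int c : children[v]) {
            if (c == heavy) continue;
            // cnt[c].size() <= d.size() - 1, so d[i + 1] is always in range
            for (size_t i = 0; i < cnt[c].size(); i++) d[i + 1] += cnt[c][i];
            deque<int>().swap(cnt[c]);         // release the merged deque
        }
        if (type[v] == '0') {
            while (!d.empty() && d.back() == 0) d.pop_back();  // trim empty deepest levels
            if (!d.empty()) {                  // d[0] == 0 here, so the back index is >= 1
                moves += (long long)d.size() - 1;  // shift the deepest token's chain up to v
                d.back()--;
                d.front() = 1;                 // v now holds a token
            }
        }
    }
    return moves;
}
```

Every deque entry is created by one `push_front` and removed at most once by `pop_back`, and each merge costs the length of the shorter deque, which is where the linear bound comes from.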

I’m not sure about this, but I think you aren’t checking what happens after one of the max chains of a node s has moved up.
Suppose node s had 3 chains of sizes (8, 8, 7). You checked the first chain, brought it up, and assigned the new max as chain[1]+1. What if, after some moves, you could bring up the other chains too, that is, the remaining chains with sizes (7, 7)?

I did the same thing but later realized that it was wrong.
Suppose there is a long chain (say 10 nodes) with no tokens, and at its end there are two diverging chains (each of length 2, both having 2 tokens). According to your logic, one of the small chains is shifted along the long chain, but once that chain has moved high enough, we can start shifting the other one too (your logic doesn’t include this part).

Edit: Yes, I got my mistake. Thanks!

If you want to look at code with exactly the same idea, but bringing the chain up by brute force, you can look at my code.

Can someone tell me why this approach does not work? I thought it should have worked. In the map I store (level, node) pairs, and whenever we reach a node with a coin, we move it to the uppermost free level. It is guaranteed that on the path the coins will only be at the top, not in the middle, since they are removed when reached. Also, the sum of jumps should not depend on whether the coins from the deepest level are moved first. Failing cases and critiques are much appreciated.
Thanks
My approach: CodeChef: Practical coding for everyone

Yup this is what I meant with my comment. Thanks for clarifying

Oh right. Thanks a lot !

You are wrong about this: moving a branch up doesn’t mean moving the whole subtree up, so you need to take care of the subtree again. As a counter-argument: if that were not the case, you would have solved it even for the real constraints using this approach.

Nice problem. Just solved it. If you do not know how the Euler tour is used here, check out the ‘Tree Queries’ chapter of the CSES handbook.

Yeah, it’s the correct O(N) solution. Ref impl - CodeChef: Practical coding for everyone
Simply put in “Small to large merging with each subtree size bounded by the depth of the deepest node is just O(N)”.

My logic for this question was different: I used a set in C++ and only the depth of each node to compute the answer.

Briefly, here are my implementation details:

Run a DFS on the tree and store the depth of each node while traversing. Also, for each node, store in a separate array, say arr, the node which is its closest free ancestor. This can be computed during the DFS.

After this, insert all the nodes with tokens into a set as pairs {depth[node], node}.
Run a while loop until the set is empty, popping an item from the end each time.
The depth to which this node should move is given by depth[arr[node]], but that node may already have been filled by some deeper node before. So, applying a path compression technique, find the closest free ancestor of this node and update the answer with depth[node] - depth[free_ancestor].

What does “closest free ancestor” mean?