MEXSUBTR - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: Ashish Gangwar
Preparer: Yahor Dubovik
Tester: Harris Leung
Editorialist: Nishank Suresh

DIFFICULTY:

3121

PREREQUISITES:

Depth-first search, small-to-large merging

PROBLEM:

You are given a tree with N vertices rooted at vertex 1, and an array B. Assign a non-negative integer A_i to each vertex i such that:

  • For each vertex u, the mex of the values assigned to vertices in its subtree equals B_u.
  • The sum A_1 + A_2 + \ldots + A_N is minimized.

If no such assignment exists, the answer is -1.
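An exhaustive search is a handy reference when checking an implementation on tiny inputs. The sketch below is not part of the official solutions. The names isValid and bruteForce are hypothetical, vertices are 0-indexed with par[root] == -1, and -1 means that no valid assignment exists. It only tries values 0 to N, and this loses nothing. Every mex is at most N, so a value above N never matters, and lowering such a value to the mex of the B values on its root path keeps every condition while decreasing the sum.

```cpp
#include <cassert>
#include <vector>

// Hypothetical checker: true iff, for every vertex u, the mex of A over
// subtree(u) equals B[u]. Vertices are 0-indexed and par[root] == -1.
bool isValid(const std::vector<int>& par, const std::vector<int>& B,
             const std::vector<int>& A) {
    int n = (int)par.size();
    // present[u][x]: some vertex of subtree(u) holds value x. Only x <= n
    // matters, because a subtree has at most n vertices and so its mex is <= n.
    std::vector<std::vector<bool>> present(n, std::vector<bool>(n + 1, false));
    for (int v = 0; v < n; ++v)
        for (int w = v; w != -1; w = par[w])  // v lies in the subtree of each ancestor w
            if (A[v] <= n) present[w][A[v]] = true;
    for (int u = 0; u < n; ++u) {
        int mex = 0;
        while (present[u][mex]) ++mex;
        if (mex != B[u]) return false;
    }
    return true;
}

// Exhaustive search over A in {0, ..., n}^n; returns the minimum sum or -1.
// Runs in O((n+1)^n * n^2), so it is only usable for very small n.
long long bruteForce(const std::vector<int>& par, const std::vector<int>& B) {
    assert(par.size() == B.size());
    int n = (int)par.size();
    std::vector<int> A(n, 0);
    long long best = -1;
    while (true) {
        if (isValid(par, B, A)) {
            long long s = 0;
            for (int a : A) s += a;
            if (best == -1 || s < best) best = s;
        }
        // Advance A like an odometer in base n + 1.
        int i = 0;
        while (i < n && A[i] == n) A[i++] = 0;
        if (i == n) break;
        ++A[i];
    }
    return best;
}
```

For example, bruteForce({-1, 0, 0}, {2, 1, 0}) describes a root with two leaf children. It returns 1, which is attained by A = (0, 0, 1).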

EXPLANATION:

First, we have the following two observations:

  • Consider a vertex u and its parent p. Then, for a valid assignment to exist, it must hold that B_u \leq B_p.
  • Consider a vertex u. Let m_u be the largest integer such that every one of 0, 1, 2, \ldots, m_u-1 is present as the B_v value of some ancestor v of u (u is also considered to be an ancestor of itself). That is, m_u is the mex of the B values on the path from the root to u.
    Then, no matter how we assign values, it must hold that A_u \geq m_u.

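Both observations are easy to prove. For the first, the subtree of u is contained in the subtree of p, and the mex of a set cannot decrease when elements are added to it. For the second, u lies in the subtree of every ancestor v, so A_u \neq B_v for each of them. Every one of 0, 1, \ldots, m_u-1 equals some such B_v, so A_u \geq m_u. The second observation also gives a lower bound on the answer: m_1 + m_2 + \ldots + m_N.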
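A direct, quadratic-time sketch may help make both observations concrete. It computes each m_u straight from its definition by walking parent pointers. The names pathMex and lowerBound and the 0-indexed par[root] == -1 convention are assumptions of this sketch, and it is only meant for small inputs. The efficient way to compute m_u is described later in this editorial.

```cpp
#include <cassert>
#include <vector>

// m_u = mex of B over the root-to-u path (u included), found by walking parents.
// O(N^2) overall; only for illustrating the definition on small inputs.
std::vector<int> pathMex(const std::vector<int>& par, const std::vector<int>& B) {
    int n = (int)par.size();
    std::vector<int> m(n);
    for (int u = 0; u < n; ++u) {
        // The path has at most n vertices, so its mex is at most n.
        std::vector<bool> seen(n + 1, false);
        for (int w = u; w != -1; w = par[w])
            if (B[w] <= n) seen[B[w]] = true;
        int mex = 0;
        while (seen[mex]) ++mex;
        m[u] = mex;
    }
    return m;
}

// Returns -1 if observation 1 (B[u] <= B[parent]) fails somewhere;
// otherwise returns the lower bound m_1 + ... + m_N from observation 2.
long long lowerBound(const std::vector<int>& par, const std::vector<int>& B) {
    assert(par.size() == B.size());
    int n = (int)par.size();
    for (int u = 0; u < n; ++u)
        if (par[u] != -1 && B[u] > B[par[u]]) return -1;
    long long total = 0;
    for (int x : pathMex(par, B)) total += x;
    return total;
}
```

For a root with two leaf children and B = {2, 1, 0}, this gives m = {0, 0, 1} and a lower bound of 1.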

Using these observations, let us try to assign the values of A in a bottom-up manner.

Suppose we are considering a vertex u, and we have processed all its children already. We now want to assign the value A_u. Let M = \max(B_c) across all children c of u (and M = 0 if u is a leaf).

  • From the first observation, we know that for a valid assignment to exist at all, it must hold that B_u \geq M. This gives us two cases.
    • Suppose B_u = M. Then, the mex of assigned values in the subtree of u is already M, so A_u can be (almost) freely assigned. There are only two restrictions: A_u \geq m_u and A_u \neq M (since assigning A_u = M would make the mex \gt M). Let us hold off on assigning A_u for now.
    • Suppose B_u \gt M. We now have to assign the values M, M+1, M+2, \ldots, B_u-1 to some free vertices in the subtree of u. However, there are a couple of restrictions:
      • First, whenever we assign a value of x to a vertex v, we must ensure that x \geq m_v.
      • Second, we cannot assign the value M to a vertex inside the subtree of a child c of u with B_c = M, because M would then appear in that subtree and push its mex above M.

This gives us an algorithm for the B_u \gt M case:

  • First, pick some free vertex v such that m_v \leq M and v is not in the subtree of a child c with B_c = M, and assign A_v := M.
  • Then, for x = M+1, M+2, \ldots, B_u - 1, pick a free vertex v such that m_v \leq x and assign A_v := x.
  • Free vertices with smaller m_v are more flexible, since they can receive a wider range of values. So, whenever we pick a free vertex to receive a value x, pick the one with the largest m_v that is \leq x. This choice also keeps the sum small: the assigned values themselves are fixed, and every vertex that stays free eventually receives A_v = m_v, so at each step we remove the largest eligible m_v from what remains.
  • If at any point we are unable to pick a valid free vertex, no assignment exists and the answer is -1.
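The B_u \gt M step can be sketched with two multisets of m_v values. The first holds free vertices that may receive M (u itself and vertices outside children c with B_c = M), and the second holds the free vertices inside children c with B_c = M. Like the preparer's code, the sketch tracks only m_v values, since which particular vertex receives a value does not change the sum. The helper name assignRange and its interface are hypothetical. A full solution would call it bottom-up and maintain the multisets by small-to-large merging.

```cpp
#include <cassert>
#include <iterator>
#include <set>

// Hypothetical helper for one vertex u with B_u > M.
// good: m_v of free vertices allowed to receive M; bad: m_v of free vertices
// inside children c with B_c == M. On success, returns M + (M+1) + ... + (Bu-1)
// and leaves the still-free vertices in `good`; returns -1 if a value cannot be placed.
long long assignRange(std::multiset<int>& good, std::multiset<int>& bad, int M, int Bu) {
    assert(M < Bu);
    // Remove the largest m_v <= x from s; report false if every m_v exceeds x.
    auto take = [](std::multiset<int>& s, int x) {
        auto it = s.upper_bound(x);          // first element > x
        if (it == s.begin()) return false;   // no m_v <= x
        s.erase(std::prev(it));              // largest m_v <= x
        return true;
    };
    if (!take(good, M)) return -1;           // M must avoid the "bad" children
    long long cost = M;
    // Small-to-large: insert the smaller multiset into the larger one.
    if (good.size() < bad.size()) good.swap(bad);
    good.insert(bad.begin(), bad.end());
    bad.clear();
    for (int x = M + 1; x < Bu; ++x) {       // any free vertex with m_v <= x works now
        if (!take(good, x)) return -1;
        cost += x;
    }
    return cost;
}
```

For instance, with good = {0, 1, 3}, M = 1 and B_u = 3, it assigns 1 and 2 (returning 3) and leaves {3} free, because no value below 3 can go to a vertex with m_v = 3.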

All that remains is to make each of these parts run fast enough.

  • First, we need to be able to compute the values of m_u for every vertex u. This can be done with a DFS that maintains the current set of active and inactive values on the path from the root to u.
    • Initially, the inactive set contains every integer from 0 to N while the active set is empty.
    • When we enter u, insert one instance of B_u into the active set and remove B_u from the inactive set (if it is there).
    • m_u is then the smallest element of the inactive set.
    • When exiting u, remove one instance of B_u from the active set. If no instances of B_u remain in the active set, insert B_u back into the inactive set.
    • This can be implemented with a std::set for \mathcal{O}(N\log N) overall.
  • Next, we need to do the assigning values part. For this, when we are at a vertex u, we need to know all free vertices in the subtree of u.
    • When picking a vertex to receive a value x, we choose the one with the largest m_v that is \leq x. If the free vertices are kept sorted by m_v (for example, in a std::set), this vertex can be found with a single binary search in \mathcal{O}(\log N).
    • Keeping the sorted list of free vertices at each vertex can be done in \mathcal{O}(N\log^2 N) with the help of small-to-large merging.
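The active/inactive bookkeeping for m_u might look like the following sketch. The function name pathMexFast, the children-list input adj (root 0), and the assumption 0 \leq B_u \leq N belong to this sketch only. The recursion is kept for readability; on very deep trees an iterative DFS may be needed to avoid a stack overflow.

```cpp
#include <cassert>
#include <set>
#include <vector>

// Computes m_u for every vertex in O(N log N) total.
// adj[u] lists the children of u; vertex 0 is the root; 0 <= B[u] <= N is assumed.
std::vector<int> pathMexFast(const std::vector<std::vector<int>>& adj,
                             const std::vector<int>& B) {
    int n = (int)adj.size();
    std::vector<int> active(n + 1, 0), m(n);   // active[x]: copies of x on the current path
    std::set<int> inactive;                     // values in [0, n] absent from the current path
    for (int x = 0; x <= n; ++x) inactive.insert(x);
    // Generic lambda that receives itself, so it can recurse without std::function.
    auto dfs = [&](auto&& self, int u) -> void {
        assert(0 <= B[u] && B[u] <= n);
        if (active[B[u]]++ == 0) inactive.erase(B[u]);   // enter u: B[u] is now on the path
        m[u] = *inactive.begin();                         // smallest value missing from the path
        for (int v : adj[u]) self(self, v);
        if (--active[B[u]] == 0) inactive.insert(B[u]);  // leave u: drop one copy of B[u]
    };
    dfs(dfs, 0);
    return m;
}
```

On a root with leaf children 1 and 2 and B = {2, 1, 0}, it returns {0, 0, 1}. The value 1 entered for vertex 1 is restored to the inactive set before vertex 2 is visited, which is what lets m_2 be 1.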

Finally, remember to assign a value of m_v to any vertex v that hasn’t been assigned a value yet.

TIME COMPLEXITY:

\mathcal{O}(N \log^2 N) per test case.
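A rough accounting of where the two logarithmic factors come from, given as a sketch rather than a full proof. If assigned vertices were never deleted, each merge could be charged to the elements of the smaller set, and every element would be charged at most \log_2 N times because its set at least doubles each time. Deleting assigned vertices only shrinks sets, and the smaller of two real sets is never larger than the smaller of the corresponding deletion-free sets, so the bound on merge insertions survives. Each std::set insertion, binary search, or erase costs \mathcal{O}(\log N), and at most N values are ever assigned, since each assignment uses up a distinct free vertex.

```latex
\underbrace{N \log_2 N}_{\text{merge insertions}} \cdot \mathcal{O}(\log N)
\;+\; \underbrace{N}_{\text{assignments}} \cdot \mathcal{O}(\log N)
\;+\; \underbrace{\mathcal{O}(N \log N)}_{m_u \text{ DFS}}
\;=\; \mathcal{O}(N \log^2 N)
```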

CODE:

Preparer's code (C++)
#ifdef DEBUG
#define _GLIBCXX_DEBUG
#endif
//#pragma GCC optimize("O3")
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef long double ld;
const int maxN = 1e6 + 10;
vector<int> g[maxN];
int n;
int sb[maxN];
bool OK = true;
int val[maxN];
set<int> not_here;
int cnt[maxN];
void dfs(int v) {
    if (!cnt[sb[v]]) {
        not_here.erase(sb[v]);
    }
    cnt[sb[v]]++;
    val[v] = *not_here.begin();
    for (int to : g[v]) {
        if (sb[to] > sb[v]) {
            OK = false;
        }
        dfs(to);
    }
    cnt[sb[v]]--;
    if (!cnt[sb[v]]) {
        not_here.insert(sb[v]);
    }
}
multiset<int> all_vals[maxN];
ll tot_s = 0;
void dfs2(int v) {
    int mx = 0;
    for (int to : g[v]) {
        dfs2(to);
        mx = max(mx, sb[to]);
    }
    if (mx == sb[v]) {
        all_vals[v].insert(val[v]);
        for (int to : g[v]) {
            if ((int)all_vals[to].size() > (int)all_vals[v].size()) {
                swap(all_vals[to], all_vals[v]);
            }
            for (int x : all_vals[to]) {
                all_vals[v].insert(x);
            }
            all_vals[to].clear();
        }
    }
    else {
        multiset<int> bad_sub, good_sub;
        good_sub.insert(val[v]);
        for (int to : g[v]) {
            if (mx == sb[to]) {
                if ((int)bad_sub.size() < (int)all_vals[to].size()) {
                    swap(all_vals[to], bad_sub);
                }
                for (int x : all_vals[to]) {
                    bad_sub.insert(x);
                }
            }
            else {
                if ((int)good_sub.size() < (int)all_vals[to].size()) {
                    swap(all_vals[to], good_sub);
                }
                for (int x : all_vals[to]) {
                    good_sub.insert(x);
                }
            }
            all_vals[to].clear();
        }
        if (good_sub.empty()) {
            OK = false;
            return;
        }
        else {
            tot_s += mx;
            good_sub.erase(good_sub.find(*(--good_sub.end())));
        }
        if ((int)good_sub.size() < (int)bad_sub.size()) {
            swap(good_sub, bad_sub);
        }
        for (int r : bad_sub) {
            good_sub.insert(r);
        }
        bad_sub.clear();
        for (int t = mx + 1; t < sb[v]; t++) {
            if (good_sub.empty()) {
                OK = false;
                return;
            }
            tot_s += t;
            good_sub.erase(good_sub.find(*(--good_sub.end())));
        }
        swap(all_vals[v], good_sub);
    }
}
void solve() {
    cin >> n;
    for (int i = 1; i <= n; i++) {
        cin >> sb[i];
        g[i].clear();
        all_vals[i].clear();
    }
    not_here.clear();
    for (int i = 0; i <= n; i++) {
        cnt[i] = 0;
        not_here.insert(i);
    }
    for (int i = 2; i <= n; i++) {
        int x;
        cin >> x;
        g[x].emplace_back(i);
    }
    OK = true;
    tot_s = 0;
    dfs(1);
    if (!OK) {
        cout << -1 << '\n';
        return;
    }
    dfs2(1);
    if (!OK) {
        cout << -1 << '\n';
        return;
    }
    for (int x : all_vals[1]) {
        tot_s += x;
    }
    cout << tot_s << '\n';
}
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
#ifdef DEBUG
    freopen("input.txt", "r", stdin);
#endif
    int tst;
    cin >> tst;
    while (tst--) {
        solve();
    }
    return 0;
}
Tester's code (C++)
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define fi first
#define se second
const ll mod=998244353;
const int N=2e5+1;
int n;
bool ok;
ll ans=0;

int b[N];
int sz[N],bc[N],bs[N];

map<int,int>mp[N];
vector<int>ch[N];
void dfs(int id){
	sz[id]=1;
	bs[id]=id;bc[id]=0;
	for(auto c:ch[id]){
		dfs(c);
		sz[id]+=sz[c];
		if(sz[c]>sz[bc[id]]){
			bc[id]=c;
			bs[id]=bs[bc[id]];
		}
	}
	//cout << "stats " << id << ' ' << bs[id] << ' ' << bc[id] << endl;
}
void clean(int id){
	while(true){
		auto it=mp[bs[id]].end();
		if(it==mp[bs[id]].begin()) return;
		--it;
		if(it->se==0) mp[bs[id]].erase(it);
		else return;
	}
}
void collapse(int id){
	for(auto c:ch[id]){
		if(c==bc[id]) continue;
		for(auto d:mp[bs[c]]){
			mp[bs[id]][d.fi]+=d.se;
		}
		mp[bs[c]].clear();
	}
}
void solve(int id){
	if(!ok) return;
	int mx=0;
	for(auto c:ch[id]){
		solve(c);
		mx=max(mx,b[c]);
	}
	//cout << "do " << id << ' ' << mp[bs[1]][1] << endl;
	if(b[id]<mx){
		ok=false;return;
	}
	if(b[id]==mx){
		collapse(id);
		mp[bs[id]][0]++;
		mp[bs[id]][b[id]+1]+=mp[bs[id]][b[id]];
		mp[bs[id]][b[id]]=0;
		return;
	}
	int gd=0;
	for(auto c:ch[id]){
		if(b[c]==mx) continue;
		clean(c);
		if(!mp[bs[c]].empty()){
			auto it=mp[bs[c]].rbegin();
			gd=max(gd,it->fi);
		}
	}
	collapse(id);
	mp[bs[id]][0]++;
	mp[bs[id]][gd]--;
	ans+=mx;
	//cout << "!! " << id << ' ' << gd << endl;
	for(int i=mx+1; i<b[id] ;i++){
		clean(id);
		auto it=mp[bs[id]].end();
		if(it==mp[bs[id]].begin()){
			ok=false;return;
		}
		--it;
		ans+=i;
		it->se--;
		
	}
	mp[bs[id]][b[id]+1]+=mp[bs[id]][b[id]];
	mp[bs[id]][b[id]]=0;
}
void solve(){
	cin >> n;
	ok=true;ans=0;
	for(int i=1; i<=n ;i++){
		cin >> b[i];
		ch[i].clear();
		mp[i].clear();
	}
	for(int i=2; i<=n ;i++){
		int p;cin >> p;ch[p].push_back(i);
	}
	dfs(1);
	solve(1);
	//cout << "! " << ans << endl;
	for(auto c:mp[bs[1]]){
		ans+=1LL*c.fi*c.se;
		//cout << "sin " << c.fi << ' ' << c.se << endl;
	}
	if(!ok) cout << "-1\n";
	else cout << ans << '\n';
}
int main(){
	ios::sync_with_stdio(false);cin.tie(0);
	int t;cin >> t;while(t--) solve();
}
Editorialist's code (C++)
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
	ios::sync_with_stdio(false); cin.tie(0);

	int t; cin >> t;
	while (t--) {
		int n; cin >> n;
		vector<int> mex(n);
		for (int &x : mex) cin >> x;
		vector<vector<int>> adj(n);
		for (int i = 1; i < n; ++i) {
			int p; cin >> p;
			adj[--p].push_back(i);
		}

		vector<int> active(n+1);
		set<int> inactive;
		for (int i = 0; i <= n; ++i) inactive.insert(i);
		vector<int> val(n);
		vector<set<array<int, 2>>> available(n);

		auto dfs = [&] (const auto &self, int u) -> bool {
			auto ret = true;
			int mx = 0;
			++active[mex[u]];
			if (active[mex[u]] == 1) inactive.erase(mex[u]);
			int low = *inactive.begin();
			for (int v : adj[u]) {
				ret &= self(self, v);
				mx = max(mx, mex[v]);
			}
			--active[mex[u]];
			if (active[mex[u]] == 0) inactive.insert(mex[u]);
			if (mx > mex[u]) return false;

			if (mx == mex[u]) {
				available[u] = {{low, u}};
				for (int v : adj[u]) {
					if (available[u].size() < available[v].size()) swap(available[u], available[v]);
					for (auto it : available[v]) available[u].insert(it);
				}
				return ret;
			}

			auto assign = [&] (int k) {
				auto it = available[u].lower_bound({k+1, 0});
				if (it == begin(available[u])) return false;
				--it;
				auto [_, v] = *it;
				available[u].erase(it);
				val[v] = k;
				return true;
			};

			available[u] = {{low, u}};
			for (int v : adj[u]) {
				if (mex[v] == mx) continue;
				if (available[u].size() < available[v].size()) swap(available[u], available[v]);
				for (auto it : available[v]) available[u].insert(it);
			}
			ret &= assign(mx);
			for (int v : adj[u]) {
				if (mex[v] != mx) continue;
				if (available[u].size() < available[v].size()) swap(available[u], available[v]);
				for (auto it : available[v]) available[u].insert(it);
			}
			for (int i = mx+1; ret and i < mex[u]; ++i) ret &= assign(i);
			return ret;
		};
		auto res = dfs(dfs, 0);
		for (auto [x, y] : available[0]) val[y] = x;
		if (res) cout << accumulate(begin(val), end(val), 0LL) << '\n';
		else cout << -1 << '\n';
	}
}

Solved this problem just 1 minute before end and reached 7-stars :smile:

Hope you liked the problem.