# MEXSUBTR - Editorial

Author: Ashish Gangwar
Preparer: Yahor Dubovik
Tester: Harris Leung
Editorialist: Nishank Suresh

# DIFFICULTY:

3121

# PREREQUISITES:

Depth-first search, small-to-large merging

# PROBLEM:

You are given a tree with N vertices rooted at vertex 1, and an array B of length N. Assign a non-negative integer A_i to each vertex i such that:

• For each vertex u, the mex of the values assigned to vertices in its subtree equals B_u.
• The sum A_1 + A_2 + \ldots + A_N is minimized.

If no such assignment exists, the answer is -1.

# EXPLANATION:

First, we have the following two observations:

• Consider a vertex u and its parent p. Then, for a valid assignment to exist, it must hold that B_u \leq B_p.
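• Consider a vertex u. Let m_u be the largest integer such that every one of 0, 1, 2, \ldots, m_u-1 is present as the B_v value of some ancestor v of u (u is also considered to be an ancestor of itself). That is, m_u is the mex of the B values on the path from the root to u.
Then, no matter how we assign values, it must hold that A_u \geq m_u.

Both observations are easy to prove. For the first, the subtree of u is contained in the subtree of p, and the mex of a set can only grow when elements are added to it, so B_u \leq B_p. For the second, for each j \lt m_u there is an ancestor v of u with B_v = j; since u lies in the subtree of v, the value j must not appear anywhere in that subtree, so A_u \neq j. The second observation also gives us a lower bound on the answer.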

Using these observations, let us try to assign the values of A in a bottom-up manner.

Suppose we are considering a vertex u, and we have processed all its children already. We now want to assign the value A_u. Let M = \max(B_c) across all children c of u (and M = 0 if u is a leaf).

• From the first observation, we know that for a valid assignment to exist at all, it must hold that B_u \geq M. This gives us two cases.
• Suppose B_u = M. Then, the mex of assigned values in the subtree of u is already M, so A_u can be (almost) freely assigned. There are only two restrictions: A_u \geq m_u and A_u \neq M (since assigning A_u = M would make the mex \gt M). Let us hold off on assigning A_u for now.
• Suppose B_u \gt M. We now have to assign the values M, M+1, M+2, \ldots, B_u-1 to some free vertices in the subtree of u. However, there are a couple of restrictions:
• First, whenever we assign a value of x to a vertex v, we must ensure that x \geq m_v.
• Second, we cannot assign the value of M to a vertex that lies inside one of the children c of u with B_c = M, because this would mess up the mex of this subtree.

This gives us an algorithm for the B_u \gt M case:

• First, pick some free vertex v such that m_v \leq M and v is not in the subtree of a child with mex M, and assign A_v := M
• Then, for x = M+1, M+2, \ldots, B_u - 1, pick a free vertex v such that m_v \leq x and assign A_v := x.
• Note that, for our purposes, free vertices with lower values of m_v are better since they can be assigned a larger range of values. So, whenever we pick a free vertex to assign a value of x, pick the one with the highest possible value of m_v that is \leq x.
• If at any point we are unable to pick a valid free vertex, no assignment exists and the answer is -1.

All that remains is to implement the different working parts of this to run fast enough.

• First, we need to be able to compute the values of m_u for every vertex u. This can be done with a DFS that maintains the current set of active and inactive values on the path from the root to u.
• Initially, the inactive set contains every integer from 0 to N while the active set is empty.
• When we enter u, insert one instance of B_u into the active set and remove B_u from the inactive set.
• m_u is the smallest element of the inactive set.
• When exiting u, remove one instance of B_u from the active set. If no instances of B_u remain in the active set, insert B_u back into the inactive set.
• This can be implemented with a std::set (plus a count array for the active multiset) for \mathcal{O}(N\log N) overall.
• Next, we need to do the assigning values part. For this, when we are at a vertex u, we need to know all free vertices in the subtree of u.
• When picking a vertex to assign a value of x, we choose the one with the largest value of m_v that is \leq x. If the available free vertices are sorted by m_v, this can be found in \mathcal{O}(\log N) with a single binary search.
• Keeping a sorted set of free vertices for each vertex can be done in \mathcal{O}(N\log^2 N) with small-to-large merging: each free vertex is moved \mathcal{O}(\log N) times, and each move costs \mathcal{O}(\log N).

Finally, remember to assign a value of m_v to any vertex v that hasn’t been assigned a value yet.

# TIME COMPLEXITY

\mathcal{O}(N \log^2 N) per test case.

# CODE:

Preparer's code (C++)
#ifdef DEBUG
#define _GLIBCXX_DEBUG
#endif
//#pragma GCC optimize("O3")
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef long double ld;
// Capacity of every per-vertex global array (vertices are 1-indexed).
constexpr int maxN = 1000000 + 10;

int n;               // number of vertices in the current test case
vector<int> g[maxN]; // g[v] = children of v; the tree is rooted at vertex 1
int sb[maxN];        // sb[v] = required mex B_v of v's subtree
int val[maxN];       // val[v] = m_v, the mex of the B values on the root->v path (set by dfs)
int cnt[maxN];       // cnt[x] = number of vertices on the current DFS path with B == x
set<int> not_here;   // values in [0, n] that do not occur on the current DFS path
bool OK = true;      // cleared as soon as the instance is found to be infeasible
void dfs(int v) {
if (!cnt[sb[v]]) {
not_here.erase(sb[v]);
}
cnt[sb[v]]++;
val[v] = *not_here.begin();
for (int to : g[v]) {
if (sb[to] > sb[v]) {
OK = false;
}
dfs(to);
}
cnt[sb[v]]--;
if (!cnt[sb[v]]) {
not_here.insert(sb[v]);
}
}
// all_vals[v]: m-values of the still-free (unassigned) vertices owned by v's
// subtree during dfs2; a child's set is emptied once merged into its parent.
multiset<int> all_vals[maxN];
// Running sum of assigned values (the answer accumulator).
ll tot_s{0};
// Bottom-up greedy assignment described in the editorial.
// After dfs2(v) returns successfully, all_vals[v] holds the m-values of the
// vertices in v's subtree that are still free, and tot_s includes every value
// assigned so far. Sets OK = false if a required value cannot be placed.
void dfs2(int v) {
    // Moves every element of src into dst while iterating only over the smaller
    // of the two sets (small-to-large), and leaves src empty.
    auto absorb = [](multiset<int>& dst, multiset<int>& src) {
        if (dst.size() < src.size()) {
            swap(dst, src);
        }
        for (int x : src) {
            dst.insert(x);
        }
        src.clear();
    };

    int mx = 0;  // M = max B over the children (0 for a leaf)
    for (int to : g[v]) {
        dfs2(to);
        mx = max(mx, sb[to]);
    }

    if (mx == sb[v]) {
        // B_v == M: nothing has to be placed here; v simply becomes free as well.
        all_vals[v].insert(val[v]);
        for (int to : g[v]) {
            absorb(all_vals[v], all_vals[to]);
        }
        return;
    }

    // B_v > M: the values M, M+1, ..., B_v - 1 must be placed on free vertices.
    // A free vertex inside a child c with B_c == M must not receive M, so those
    // vertices ("bad") are kept apart from the rest ("good", which includes v).
    multiset<int> good_sub, bad_sub;
    good_sub.insert(val[v]);
    for (int to : g[v]) {
        absorb(sb[to] == mx ? bad_sub : good_sub, all_vals[to]);
    }

    // The editorial picks the free vertex with the largest m-value <= the value
    // being placed; this code takes the overall largest, relying on every
    // candidate already satisfying that bound.
    if (good_sub.empty()) {  // unreachable (v itself is good), kept as a guard
        OK = false;
        return;
    }
    tot_s += mx;
    good_sub.erase(prev(good_sub.end()));

    // Every remaining free vertex, good or bad, may take the larger values.
    absorb(good_sub, bad_sub);
    for (int t = mx + 1; t < sb[v]; t++) {
        if (good_sub.empty()) {
            OK = false;
            return;
        }
        tot_s += t;
        good_sub.erase(prev(good_sub.end()));
    }
    swap(all_vals[v], good_sub);
}
// Reads one test case (n, B_1..B_n, then the parents of vertices 2..n) and
// prints the minimum possible sum of A, or -1 if no valid assignment exists.
void solve() {
    cin >> n;
    for (int i = 1; i <= n; i++) {
        cin >> sb[i];
        g[i].clear();
        all_vals[i].clear();
    }
    // Reset the path-tracking structures used by dfs: no value is on the path yet.
    not_here.clear();
    for (int i = 0; i <= n; i++) {
        cnt[i] = 0;
        not_here.insert(not_here.end(), i);  // ascending order: hinted insert is amortized O(1)
    }
    for (int i = 2; i <= n; i++) {
        int x;
        cin >> x;
        g[x].emplace_back(i);
    }
    OK = true;
    tot_s = 0;
    dfs(1);
    if (!OK) {
        cout << -1 << '\n';
        return;
    }
    dfs2(1);
    if (!OK) {
        cout << -1 << '\n';
        return;
    }
    // Vertices still free at the root take their lower bound m_v. The free set
    // of the whole tree is all_vals[1]; all_vals itself is an array of sets.
    for (int x : all_vals[1]) {
        tot_s += x;
    }
    cout << tot_s << '\n';
}
// Entry point: reads the number of test cases and solves each one in turn.
int main() {
    cin.tie(nullptr);
    ios_base::sync_with_stdio(false);
#ifdef DEBUG
    freopen("input.txt", "r", stdin);  // local debugging only
#endif
    int tests = 0;
    cin >> tests;
    while (tests--) solve();
    return 0;
}

Tester's code (C++)
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define fi first
#define se second
const ll mod = 998244353;  // not referenced anywhere in this listing
constexpr int N = 200001;  // capacity of the per-vertex arrays (vertices are 1-indexed)

int n;       // number of vertices in the current test case
bool ok;     // false once the instance has been found infeasible
ll ans = 0;  // sum of the values assigned so far

int b[N];                 // b[v] = required mex B_v of v's subtree
int sz[N], bc[N], bs[N];  // subtree size, heavy child (0 if none), index of v's map in mp

map<int, int> mp[N];      // mp[bs[v]]: bound -> number of free vertices with that bound
vector<int> ch[N];        // ch[v] = children of v
void dfs(int id){
sz[id]=1;
bs[id]=id;bc[id]=0;
for(auto c:ch[id]){
dfs(c);
sz[id]+=sz[c];
if(sz[c]>sz[bc[id]]){
bc[id]=c;
bs[id]=bs[bc[id]];
}
}
//cout << "stats " << id << ' ' << bs[id] << ' ' << bc[id] << endl;
}
void clean(int id){
while(true){
auto it=mp[bs[id]].end();
if(it==mp[bs[id]].begin()) return;
--it;
if(it->se==0) mp[bs[id]].erase(it);
else return;
}
}
void collapse(int id){
for(auto c:ch[id]){
if(c==bc[id]) continue;
for(auto d:mp[bs[c]]){
mp[bs[id]][d.fi]+=d.se;
}
mp[bs[c]].clear();
}
}
// Greedy bottom-up assignment for the subtree of id (tester's version).
// mp[bs[id]] maps a bound k to the number of free vertices in id's subtree
// whose bound is k: the mex of the B values on the path from that vertex up to
// the current vertex. ans accumulates assigned values; ok is cleared on failure.
void solve(int id){
    if(!ok) return;
    int mx=0;
    for(auto c:ch[id]){
        solve(c);
        mx=max(mx,b[c]);
    }
    if(b[id]<mx){
        ok=false;return;
    }
    // Passing through id: a free vertex whose bound equals b[id] may not take
    // b[id] (it would break id's mex), so its bound becomes b[id]+1.
    auto bump = [&](){
        map<int,int>& m = mp[bs[id]];
        m[b[id]+1] += m[b[id]];
        m[b[id]] = 0;
    };
    if(b[id]==mx){
        collapse(id);
        mp[bs[id]][0]++;  // id itself joins the free vertices with bound 0
        bump();
        return;
    }
    // b[id] > mx: first place mx on a free vertex that is not inside a child
    // whose mex is mx, namely the one with the largest bound among id itself
    // (bound 0) and the free vertices of the remaining children.
    int gd=0;
    for(auto c:ch[id]){
        if(b[c]==mx) continue;
        clean(c);
        if(!mp[bs[c]].empty()){
            gd=max(gd,mp[bs[c]].rbegin()->first);
        }
    }
    collapse(id);
    mp[bs[id]][0]++;   // id itself
    mp[bs[id]][gd]--;  // the vertex that receives mx
    ans+=mx;
    // Place mx+1, ..., b[id]-1, each on the free vertex with the largest bound.
    for(int i=mx+1; i<b[id] ;i++){
        clean(id);
        if(mp[bs[id]].empty()){
            ok=false;return;
        }
        ans+=i;
        prev(mp[bs[id]].end())->second--;
    }
    bump();
}
// Reads one test case (n, B_1..B_n, then the parents of vertices 2..n) and
// prints the minimum possible sum of A, or -1 if no valid assignment exists.
void solve(){
    cin >> n;
    ok=true;ans=0;
    for(int i=1; i<=n ;i++){
        cin >> b[i];
        ch[i].clear();
        mp[i].clear();
    }
    for(int i=2; i<=n ;i++){
        int p;cin >> p;ch[p].push_back(i);
    }
    dfs(1);
    solve(1);
    if(!ok){
        cout << "-1\n";
        return;
    }
    // Vertices still free at the root take their bound, which is now m_v.
    // The whole tree's free vertices live in mp[bs[1]], not mp[bs].
    for(const auto& [bound, count] : mp[bs[1]]){
        ans+=1LL*bound*count;
    }
    cout << ans << '\n';
}
// Entry point: reads the number of test cases and runs solve() for each.
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int t = 0;
    cin >> t;
    while (t--) {
        solve();
    }
}

Editorialist's code (C++)
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
using ll = long long int;
// Random generator from the editorialist's template; not used below.
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

// Editorialist's solution. For each test case, reads N, the array B (called
// `mex`, 0-indexed) and the 1-indexed parents of vertices 2..N, then prints the
// minimum possible sum of A, or -1 if no valid assignment exists.
int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        vector<int> mex(n);
        for (int &x : mex) cin >> x;
        // adj[u] = children of u, 0-indexed (the input parents are 1-indexed).
        vector<vector<int>> adj(n);
        for (int i = 1; i < n; ++i) {
            int p; cin >> p;
            adj[p - 1].push_back(i);
        }

        // active[x] = occurrences of x among the B values on the current DFS path;
        // inactive = values in [0, n] that do not occur on that path.
        vector<int> active(n+1);
        set<int> inactive;
        for (int i = 0; i <= n; ++i) inactive.insert(i);
        vector<int> val(n);
        // available[u] = {m_v, v} for every still-free vertex v owned by u's subtree.
        vector<set<array<int, 2>>> available(n);

        // Returns false if the subtree of u admits no valid assignment.
        auto dfs = [&] (const auto &self, int u) -> bool {
            auto ret = true;
            int mx = 0;
            ++active[mex[u]];
            if (active[mex[u]] == 1) inactive.erase(mex[u]);
            int low = *inactive.begin();  // m_u
            for (int v : adj[u]) {
                ret &= self(self, v);
                mx = max(mx, mex[v]);
            }
            --active[mex[u]];
            if (active[mex[u]] == 0) inactive.insert(mex[u]);
            if (mx > mex[u]) return false;

            // Small-to-large: moves every element of available[v] into available[u].
            auto absorb = [&] (int v) {
                if (available[u].size() < available[v].size()) swap(available[u], available[v]);
                for (auto it : available[v]) available[u].insert(it);
                available[v].clear();  // never read again; release its memory
            };

            if (mx == mex[u]) {
                // Nothing to place: u becomes free and all children's free vertices merge.
                available[u] = {{low, u}};
                for (int v : adj[u]) absorb(v);
                return ret;
            }

            // Gives value k to the free vertex with the largest m_v <= k.
            auto assign = [&] (int k) {
                auto it = available[u].lower_bound({k+1, 0});
                if (it == begin(available[u])) return false;
                --it;
                auto [_, v] = *it;
                available[u].erase(it);
                val[v] = k;
                return true;
            };

            // Value mx may not go inside a child whose mex is mx, so merge only the
            // other children before placing it.
            available[u] = {{low, u}};
            for (int v : adj[u]) {
                if (mex[v] != mx) absorb(v);
            }
            ret &= assign(mx);
            for (int v : adj[u]) {
                if (mex[v] == mx) absorb(v);
            }
            for (int i = mx+1; ret and i < mex[u]; ++i) ret &= assign(i);
            return ret;
        };
        auto res = dfs(dfs, 0);
        // Vertices still free at the root (vertex 0) take their lower bound m_v.
        for (auto [x, y] : available[0]) val[y] = x;
        if (res) cout << accumulate(begin(val), end(val), 0LL) << '\n';
        else cout << -1 << '\n';
    }
}


Solved this problem just 1 minute before the end and reached 7 stars.

Hope you liked the problem.