# MAXMEXPATH - Editorial

Setter:
Testers: Tejas Pandey and Abhinav sharma
Editorialist: Taranpreet Singh

Difficulty: Easy-Medium

Prerequisites: DFS, Segment Tree

# PROBLEM

Consider a tree with N nodes, rooted at node 1.
The value of the i^{th} node (1 \leq i \leq N) is denoted by A_i.

Chef defines the function MEX(u, v) as follows:
Let B denote the set of all the values of nodes that lie on the shortest path from u to v (including both u and v). Then, MEX(u, v) denotes the MEX of the set B, i.e. the smallest non-negative integer that is not present in B.

Find the maximum value of MEX(1, i), where 1 \leq i \leq N.

# QUICK EXPLANATION

• Maintain a set of integers not yet found on the current path from the root to u for some node u.
• If A_u is present in the set, remove it when the DFS enters u, process the subtree of u, and add A_u back to the set when the DFS leaves u.
• The MEX of the path from the root to u is the minimum element of this set, which an ordered set can report quickly.

# EXPLANATION

We need to compute the MEX of the values on the path from node 1 to node u for every u and take the maximum. Let B_u denote the set of values on the path from 1 to u, and let p be the parent of u. Then B_u = \{A_u\} \cup B_p for u \neq 1, and B_1 = \{A_1\}.

So, if we process the nodes of the tree in DFS order, each node's value A_u is added exactly once (when the DFS enters u) and removed exactly once (when the DFS exits u). Since values may repeat along a path, a value should leave the set only when its last occurrence on the current path is removed.

Hence, we have to support adding an element, removing an element, and computing MEX of the current set.

There are multiple possible solutions for this problem; we discuss two of them here.

Clever solution, simpler implementation

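Instead of maintaining the set of values found on the path from 1 to u, let’s maintain the set of values in \{0, 1, \dots, N\} that are not found on the path from 1 to u. Initially this set contains all of 0, 1, \dots, N. This range is enough because a path has at most N nodes, so its MEX is at most N. This helps us because the MEX is now simply the smallest value present in this set.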

When the DFS enters a node u, we check whether this special set contains A_u. If it does, we remove A_u from the set and remember that u removed it. If it does not, some ancestor of u already has the value A_u.

Now the MEX of the path from 1 to u is the minimum element of this special set. Then we recurse into the children of u to solve for all nodes in the subtree of u.

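Finally, if A_u was removed when the DFS entered u, we add A_u back to the set as the DFS exits u. If u did not remove it, we must not add it back, because an ancestor of u still holds that value.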

The implementation of this is added in Editorialist solution 1.

Generic overkill solution using segment tree

We maintain the frequency of each element in the segment tree. Specifically, leaf x denotes the number of occurrences of x. Our segment tree should support

• Increase/decrease frequency of an element x by 1.
• Find MEX of elements currently present.

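The MEX of the current elements is the leftmost leaf with frequency 0. If every internal node stores the minimum frequency in its range, we can find it by descending from the root: go to the left child if its minimum is 0, otherwise go to the right child. This technique is known as segment tree descent.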

The implementation of this is added in Editorialist solution 2.

# TIME COMPLEXITY

The time complexity of both solutions is O(N \log N) per test case.

# SOLUTIONS

Editorialist's Solution 1
import java.util.*;
import java.io.*;
class Main{
    //SOLUTION BEGIN
    void pre() throws Exception{}

    /**
     * Reads one test case (N, the N node values, then N-1 undirected 1-based edges) and
     * prints the maximum of MEX(1, i) over all nodes i.
     */
    void solve(int TC) throws Exception{
        int N = ni();
        int[] A = new int[N];
        for(int i = 0; i< N; i++)A[i] = ni();

        int[] from = new int[N-1], to = new int[N-1];
        for(int i = 0; i< N-1; i++){
            from[i] = ni()-1;
            to[i] = ni()-1;
        }
        int[][] tree = make(N, from, to);
        // The set holds the values in [0, N] that are missing from the current root path.
        // A path has at most N nodes, so some value in 0..N is always missing and the MEX
        // is the smallest element; hence the set starts as {0, 1, ..., N}.
        TreeSet<Integer> set = new TreeSet<>();
        for(int x = 0; x <= N; x++)set.add(x);

        pn(dfs(tree, set, A, 0, -1));
    }

    /**
     * Returns the largest MEX over all root-to-node paths that end in the subtree of u.
     *
     * @param set on entry, the values in [0, N] missing from the path root..p; it is
     *            restored to exactly that state before this method returns
     * @param u   current node (0-based)
     * @param p   parent of u, or -1 for the root
     */
    int dfs(int[][] tree, TreeSet<Integer> set, int[] A, int u, int p){
        // remove() is true only if A[u] was still missing, i.e. no ancestor carries it.
        boolean rem = set.remove(A[u]);
        int ans = set.first();
        for(int v:tree[u])if(v != p)ans = Math.max(ans, dfs(tree, set, A, v, u));
        // Undo only our own removal: if an ancestor has the same value it must stay out.
        if(rem)set.add(A[u]);
        return ans;
    }

    /** Builds adjacency arrays for an undirected graph with N nodes and edges f[i]-t[i]. */
    int[][] make(int N, int[] f, int[] t){
        int[] cnt = new int[N];
        for(int x:f)cnt[x]++;
        for(int x:t)cnt[x]++;
        int[][] g = new int[N][];
        for(int i = 0; i< N; i++){
            g[i] = new int[cnt[i]];
            cnt[i] = 0;
        }
        for(int i = 0; i< N-1; i++){
            g[f[i]][cnt[f[i]]++] = t[i];
            g[t[i]][cnt[t[i]]++] = f[i];
        }
        return g;
    }
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    static boolean multipleTC = true;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }

    /**
     * Runs the solver on a thread with a large explicit stack: dfs recurses once per tree
     * level (up to N deep), which can overflow the default thread stack.
     */
    public static void main(String[] args) throws Exception{
        Throwable[] failure = new Throwable[1];
        Thread worker = new Thread(null, () -> {
            try{
                new Main().run();
            }catch(Throwable t){
                failure[0] = t;
            }
        }, "solver", 1 << 27);
        worker.start();
        worker.join();
        Throwable f = failure[0];
        if(f instanceof Exception)throw (Exception)f;
        if(f instanceof Error)throw (Error)f;
        if(f != null)throw new Exception(f);
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}

    /** Whitespace-token reader over a buffered character stream. */
    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }

        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }

        /** Returns the next whitespace-separated token, reading more lines as needed. */
        String next() throws Exception{
            while (st == null || !st.hasMoreTokens()){
                String line = br.readLine();
                // readLine() returns null at end of input; fail clearly instead of an NPE.
                if(line == null)throw new EOFException("Unexpected end of input");
                st = new StringTokenizer(line);
            }
            return st.nextToken();
        }

        /** Returns the rest of the current input line (null at end of input). */
        String nextLine() throws Exception{
            return br.readLine();
        }
    }
}

Editorialist's Solution 2
import java.util.*;
import java.io.*;
class Main{
    //SOLUTION BEGIN
    void pre() throws Exception{}

    /**
     * Reads one test case and prints max over i of MEX(1, i), using a segment tree over
     * the frequencies of the values on the current root path.
     */
    void solve(int TC) throws Exception{
        int N = ni();
        int[] A = new int[N];
        for(int i = 0; i< N; i++)A[i] = ni();
        int[] from = new int[N-1], to = new int[N-1];
        for(int i = 0; i< N-1; i++){
            from[i] = ni()-1;
            to[i] = ni()-1;
        }
        int[][] tree = make(N, from, to);
        // Leaves cover values 0..S-1 with S > N. A path of at most N nodes always misses
        // some value in 0..N, so the MEX is always one of these leaves.
        int S = 1;
        while(S <= N)S<<=1;
        int[] segtree = new int[S<<1];
        pn(dfs(tree, segtree, A, 0, -1, S));
    }

    /**
     * Returns the largest MEX over root-to-node paths ending in the subtree of u. On entry
     * segtree holds the value frequencies of the path root..p; it is restored on return.
     */
    int dfs(int[][] tree, int[] segtree, int[] A, int u, int p, int S){
        // Values outside [0, S) can never be the MEX (which is at most N < S). Skip them
        // instead of letting update() fold them into the first or last leaf.
        boolean tracked = A[u] >= 0 && A[u] < S;
        if(tracked)update(segtree, 0, S-1, 1, A[u], 1);
        int ans = mex(segtree, 0, S-1, 1);
        for(int v:tree[u])if(v != p)ans = Math.max(ans, dfs(tree, segtree, A, v, u, S));
        if(tracked)update(segtree, 0, S-1, 1, A[u], -1);
        return ans;
    }

    /**
     * Adds x to the frequency of value p. Node i covers values [ll, rr]; an internal node
     * stores the minimum frequency over its range, so 0 means some value there is absent.
     */
    void update(int[] segtree, int ll, int rr, int i, int p, int x){
        if(ll == rr)segtree[i] += x;
        else{
            int mid = (ll+rr)/2;
            if(p <= mid)update(segtree, ll, mid, i<<1, p, x);
            else update(segtree, mid+1, rr, i<<1|1, p, x);
            segtree[i] = Math.min(segtree[i<<1], segtree[i<<1|1]);
        }
    }

    /** Returns the leftmost value in [ll, rr] whose frequency is 0 (one must exist). */
    int mex(int[] segtree, int ll, int rr, int i){
        if(ll == rr)return ll;
        int mid = (ll+rr)/2;
        // Left half fully present (minimum > 0): the first gap must be on the right.
        if(segtree[i<<1] > 0)return mex(segtree, mid+1, rr, i<<1|1);
        return mex(segtree, ll, mid, i<<1);
    }

    /** Builds adjacency arrays for an undirected graph with N nodes and edges f[i]-t[i]. */
    int[][] make(int N, int[] f, int[] t){
        int[] cnt = new int[N];
        for(int x:f)cnt[x]++;
        for(int x:t)cnt[x]++;
        int[][] g = new int[N][];
        for(int i = 0; i< N; i++){
            g[i] = new int[cnt[i]];
            cnt[i] = 0;
        }
        for(int i = 0; i< N-1; i++){
            g[f[i]][cnt[f[i]]++] = t[i];
            g[t[i]][cnt[t[i]]++] = f[i];
        }
        return g;
    }
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    static boolean multipleTC = true;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }

    /**
     * Runs the solver on a thread with a large explicit stack: dfs recurses once per tree
     * level (up to N deep), which can overflow the default thread stack.
     */
    public static void main(String[] args) throws Exception{
        Throwable[] failure = new Throwable[1];
        Thread worker = new Thread(null, () -> {
            try{
                new Main().run();
            }catch(Throwable t){
                failure[0] = t;
            }
        }, "solver", 1 << 27);
        worker.start();
        worker.join();
        Throwable f = failure[0];
        if(f instanceof Exception)throw (Exception)f;
        if(f instanceof Error)throw (Error)f;
        if(f != null)throw new Exception(f);
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}

    /** Whitespace-token reader over a buffered character stream. */
    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }

        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }

        /** Returns the next whitespace-separated token, reading more lines as needed. */
        String next() throws Exception{
            while (st == null || !st.hasMoreTokens()){
                String line = br.readLine();
                // readLine() returns null at end of input; fail clearly instead of an NPE.
                if(line == null)throw new EOFException("Unexpected end of input");
                st = new StringTokenizer(line);
            }
            return st.nextToken();
        }

        /** Returns the rest of the current input line (null at end of input). */
        String nextLine() throws Exception{
            return br.readLine();
        }
    }
}

Setter's Solution
#define ll long long
#define dd long double
#define pub push_back
#define pob pop_back
#define puf push_front
#define pof pop_front
#define mp make_pair
#define mt make_tuple
#define fo(i , n) for(ll i = 0 ; i < n ; i++)
#define tll tuple<ll ,ll , ll>
#define pll pair<ll ,ll>
#include<bits/stdc++.h>
/*#include<iomanip>
#include<cmath>
#include<cstdio>
#include<utility>
#include<iostream>
#include<vector>
#include<string>
#include<algorithm>
#include<map>
#include<queue>
#include<set>
#include<stack>
#include<bitset>*/
// ---- Generic competitive-programming template. Apart from max(), none of the helpers
// ---- below are called by dfs()/solve() in the visible code.
// Common constants: pi, the prime modulus z = 1e9+7 (used by ncr), a large "infinity",
// and p1/p2/mod1/mod2 (presumably bases/moduli for double polynomial hashing; unused here).
dd pi = acos(-1) ;
ll z =  1000000007 ;
ll inf = 10000000000000 ;
ll p1 = 37 ;
ll p2 = 53 ;
ll mod1 =  202976689 ;
ll mod2 =  203034253 ;
// Factorial table for ncr(); no visible code fills it, so it stays all zeros.
ll fact[200] ;
// a - a%b: for a >= 0 and b > 0, the greatest multiple of b that is <= a.
ll gdp(ll a , ll b){return (a - (a%b)) ;}
// NOTE(review): ld/gd are defined before `using namespace std`, so abs() here may resolve
// to C's int abs(int) and truncate large long long arguments - verify.
ll ld(ll a , ll b){if(a < 0) return -1*gdp(abs(a) , b) ; if(a%b == 0) return a ; return (a + (b - a%b)) ;} // least number >=a divisible by b
ll gd(ll a , ll b){if(a < 0) return(-1 * ld(abs(a) , b)) ;    return (a - (a%b)) ;} // greatest number <= a divisible by b
// Euclid's gcd (intended for non-negative arguments); swaps so that a >= b.
ll gcd(ll a , ll b){ if(b > a) return gcd(b , a) ; if(b == 0) return a ; return gcd(b , a%b) ;}
// Extended Euclid: returns g = gcd(a, b) and sets x, y so that a*x + b*y == g.
ll e_gcd(ll a , ll b , ll &x , ll &y){ if(b > a) return e_gcd(b , a , y , x) ; if(b == 0){x = 1 ; y = 0 ; return a ;}
ll x1 , y1 , g; g = e_gcd(b , a%b , x1 , y1) ; x = y1 ; y = (x1 - ((a/b) * y1)) ; return g ;}
// a^b mod p by recursive squaring. NOTE(review): c*c overflows long long when p > ~3e9.
ll power(ll a ,ll b , ll p){if(b == 0) return 1 ; ll c = power(a , b/2 , p) ; if(b%2 == 0) return ((c*c)%p) ; else return ((((c*c)%p)*a)%p) ;}
// Modular inverse via Fermat's little theorem: only valid when n is prime.
ll inverse(ll a ,ll n){return power(a , n-2 , n) ;}
// Non-template long long max/min. When both these and the std:: templates are viable with
// equally good conversions, overload resolution prefers these non-templates.
ll max(ll a , ll b){if(a > b) return a ; return b ;}
ll min(ll a , ll b){if(a < b) return a ; return b ;}
// Children of node i in a 0-indexed implicit binary tree (heap layout).
ll left(ll i){return ((2*i)+1) ;}
ll right(ll i){return ((2*i) + 2) ;}
// nCr mod z from the fact table. NOTE(review): fact is never filled and has only 200
// entries, so this returns 0 and indexes out of bounds for n >= 200.
ll ncr(ll n , ll r){if(n < r|| (n < 0) || (r < 0)) return 0 ; return ((((fact[n] * inverse(fact[r] , z)) % z) * inverse(fact[n-r] , z)) % z);}
void swap(ll&a , ll&b){ll c = a ; a = b ; b = c ; return ;}
//ios_base::sync_with_stdio(0);
//cin.tie(0); cout.tie(0);
using namespace std ;
// Policy-based data structures: ordered_set is an order-statistics red-black tree
// (not used by this solution).
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace __gnu_pbds;
#define ordered_set tree<int, null_type,less<int>, rb_tree_tag,tree_order_statistics_node_update>
// ordered_set s ; s.order_of_key(val)  no. of elements strictly less than val
// s.find_by_order(i)  itertor to ith element (0 indexed)
//__builtin_popcount(n) -> returns number of set bits in n
ll seed;

void dfs(vector<ll> adj[], vector<ll> &val, vector<ll> &cnt, ll &ans, ll u, ll p , set<ll> &s)
{
ll curr_val = val[u] ;
cnt[curr_val]++ ;
if(cnt[curr_val] == 1)
s.erase(s.find(curr_val)) ;

ans = max(ans , (*s.begin())) ;

for(int i = 0 ; i < adj[u].size() ; i++)
{
if(v == p)
continue ;
dfs(adj , val , cnt , ans , v , u , s) ;
}

cnt[curr_val]-- ;
if(cnt[curr_val] == 0)
s.insert(curr_val) ;
return ;
}

void solve()
{

ll n ;
cin >> n ;
vector<ll> val(n) , cnt(n+1) ;
set<ll> s ;
for(int i = 0 ; i < n ; i++)
{
cin >> val[i] ;
s.insert(i) ;
}
s.insert(n) ;

for(int i = 0 ; i < n-1 ; i++)
{
ll u , v ;
cin >> u >> v ;
u-- ; v-- ;
}

ll ans = 0 ;

dfs(adj , val , cnt , ans , 0 , -1 , s) ;
cout << ans << endl ;
return ;
}

int main()
{
ios_base::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
#ifndef ONLINE_JUDGE
freopen("inputf.txt" , "r" , stdin) ;
freopen("outputf.txt" , "w" , stdout) ;
freopen("error.txt" , "w" , stderr) ;
#endif

ll t;
cin >> t ;
while(t--)
{
solve() ;
}

cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";

return 0;
}

Tester's Solution 1
#include <bits/stdc++.h>
using namespace std;

/*
------------------------Input Checker----------------------------------
*/

// Input-checker reader: reads one integer token terminated by the character endd
// (e.g. ' ' or '\n') and returns it, asserting that it lies in [l, r].
// Malformed input fails an assertion: any character other than '-', a digit or endd,
// leading zeros ("007"), "-0", and tokens with more than 19 digits (or 19 digits
// starting above 1), which could overflow long long.
// NOTE(review): a token with no digits (endd right away, or just "-") is accepted as 0,
// and repeated '-' signs before the first digit are allowed.
long long readInt(long long l,long long r,char endd){
long long x=0;
int cnt=0;      // digits read so far
int fi=-1;      // first digit of the token, -1 until one is read
bool is_neg=false;
while(true){
char g=getchar();
if(g=='-'){
// a sign is only allowed before the first digit
assert(fi==-1);
is_neg=true;
continue;
}
if('0'<=g && g<='9'){
x*=10;
x+=g-'0';
if(cnt==0){
fi=g-'0';
}
cnt++;
// no leading zeros: a token starting with 0 must be exactly "0" ...
assert(fi!=0 || cnt==1);
// ... and must not be negative ("-0")
assert(fi!=0 || is_neg==false);

// overflow guard: at most 19 digits, and a 19-digit token must start with 1
assert(!(cnt>19 || ( cnt==19 && fi>1) ));
} else if(g==endd){
if(is_neg){
x= -x;
}

// out-of-range values are reported on stderr before failing
if(!(l <= x && x <= r))
{
cerr << l << ' ' << r << ' ' << x << '\n';
assert(1 == 0);
}

return x;
} else {
// unexpected character (including end of input)
assert(false);
}
}
}
// Reads characters up to (not including) the delimiter endd and returns them,
// asserting that input does not end first and that the length lies in [l, r].
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        // Keep getchar()'s int result: storing it in a char first breaks the EOF check
        // on platforms where char is unsigned.
        int g=getchar();
        assert(g!=EOF);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=(char)g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
}
long long readIntLn(long long l,long long r){
}
}
}

/*
------------------------Main code starts here----------------------------------
*/

// Problem limits for the input checker.
const int MAX_T = 10000;
const int MAX_N = 100000;
const int MAX_SUM_N = 100000;
// Size of the per-node / per-value arrays and the index range of the Fenwick tree.
const int lim = 1000007;

#define ll long long int
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)

// Presumably the running sum of n across test cases, checked against MAX_SUM_N.
long long int sum_len=0;

// bit : Fenwick tree marking (value+1) for each value present on the current root path
// cnt : multiplicity of each (shifted) value on the current root path
// val : val[node] = value of node + 1, nodes are 1-based
// dsu : union-find parents, used to check that the edges form a tree
int bit[lim], cnt[lim], val[lim], dsu[lim];
// cmex: MEX of the current path; mmex: largest MEX seen in the current test case
int cmex = 0, mmex;
// adjacency lists, 1-based
vector<int> ed[lim];

// Union-find: returns the root of a's component, compressing the path on the way.
int fnd(int a) {
    if(dsu[a] != a)
        dsu[a] = fnd(dsu[a]);
    return dsu[a];
}

bool unite(int a, int b) {
a = fnd(a);
b = fnd(b);
if(a == b) return false;
dsu[a] = b;
return true;
}

// Recomputes cmex, the MEX of the values currently marked in the Fenwick tree, and folds
// it into the running maximum mmex.
// solve() stores every node value shifted by +1, so value v is marked at index v+1, and each
// index holds 0 or 1 because upd() is only called when a count goes 0->1 or 1->0.
// The loop is a Fenwick-tree descent: it finds the largest cur whose prefix sum over
// indices 1..cur equals cur, i.e. every index 1..cur is marked. Then values 0..cur-1 are
// all present, so cur is the MEX.
// NOTE(review): starting at bit 18, cur can reach at most 2^19 - 1; this relies on the MEX
// staying below that (values here are at most MAX_N + 1).
void updateMEX(){
int cur = 0; ll csum = 0;   // csum = prefix sum over indices 1..cur
for(int i = 18;i > -1;i--){
if(cur + (1LL<<i) < lim && bit[cur + (1LL<<i)] + csum == cur + (1LL<<i)) cur += (1LL<<i),csum += bit[cur];
}
cmex = cur;
mmex = max(cmex, mmex);
}

// Adds delta at index pos of the Fenwick tree, then refreshes cmex / mmex.
void upd(int pos,int delta){
    for(int idx = pos; idx < lim; idx += idx & -idx)
        bit[idx] += delta;
    updateMEX();
}

// Depth-first walk from node (parent par). cnt holds value multiplicities on the current
// root path; a value is marked in the Fenwick tree only while its count is non-zero.
void dfs(int node, int par) {
    const int v = val[node];
    if(++cnt[v] == 1)
        upd(v, 1);
    for(int child : ed[node]) {
        if(child != par)
            dfs(child, node);
    }
    if(--cnt[v] == 0)
        upd(v, -1);
}

void solve()
{
for(int i = 1; i < n; i++) val[i] = readIntSp(0, n), val[i]++, ed[i].clear(), dsu[i] = i;
val[n] = readIntLn(0, n), val[n]++, ed[n].clear(), dsu[n] = n;
ed[n].clear();
for(int i = 1; i < n; i++) {
assert(unite(a, b));
ed[a].push_back(b);
ed[b].push_back(a);
}
mmex = 0;
dfs(1, 0);
cout << mmex << "\n";
}

signed main()
{
//fast;
#ifndef ONLINE_JUDGE
//freopen("input.txt", "r", stdin);
//freopen("output.txt", "w", stdout);
#endif

for(int i=1;i<=t;i++)
{
solve();
}

assert(getchar() == -1);
}

Tester's Solution 2
#include <bits/stdc++.h>
using namespace std;

/*
------------------------Input Checker----------------------------------
*/

// Input-checker reader: reads one integer token terminated by the character endd
// (e.g. ' ' or '\n') and returns it, asserting that it lies in [l, r].
// Malformed input fails an assertion: any character other than '-', a digit or endd,
// leading zeros ("007"), "-0", and tokens with more than 19 digits (or 19 digits
// starting above 1), which could overflow long long.
// NOTE(review): a token with no digits (endd right away, or just "-") is accepted as 0,
// and repeated '-' signs before the first digit are allowed.
long long readInt(long long l,long long r,char endd){
long long x=0;
int cnt=0;      // digits read so far
int fi=-1;      // first digit of the token, -1 until one is read
bool is_neg=false;
while(true){
char g=getchar();
if(g=='-'){
// a sign is only allowed before the first digit
assert(fi==-1);
is_neg=true;
continue;
}
if('0'<=g && g<='9'){
x*=10;
x+=g-'0';
if(cnt==0){
fi=g-'0';
}
cnt++;
// no leading zeros: a token starting with 0 must be exactly "0" ...
assert(fi!=0 || cnt==1);
// ... and must not be negative ("-0")
assert(fi!=0 || is_neg==false);

// overflow guard: at most 19 digits, and a 19-digit token must start with 1
assert(!(cnt>19 || ( cnt==19 && fi>1) ));
} else if(g==endd){
if(is_neg){
x= -x;
}

// out-of-range values are reported on stderr before failing
if(!(l <= x && x <= r))
{
cerr << l << ' ' << r << ' ' << x << '\n';
assert(1 == 0);
}

return x;
} else {
// unexpected character (including end of input)
assert(false);
}
}
}
// Reads characters up to (not including) the delimiter endd and returns them,
// asserting that input does not end first and that the length lies in [l, r].
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        // Keep getchar()'s int result: storing it in a char first breaks the EOF check
        // on platforms where char is unsigned.
        int g=getchar();
        assert(g!=EOF);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=(char)g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
}
long long readIntLn(long long l,long long r){
}
}
}

/*
------------------------Main code starts here----------------------------------
*/

// Problem limits for the input checker.
const int MAX_T = 1e5;
const int MAX_N = 1e5;
const int MAX_SUM_LEN = 1e5;

#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define ff first
#define ss second
#define mp make_pair
#define ll long long
#define rep(i,n) for(int i=0;i<n;i++)
#define rev(i,n) for(int i=n;i>=0;i--)
#define rep_a(i,a,n) for(int i=a;i<n;i++)
#define pb push_back

// Checker statistics: main() asserts on sum_n and reports sum_n / max_n on stderr.
// The *_m counters, yess, nos and total_ops are not used by the visible code.
int sum_n = 0, sum_m = 0;
int max_n = 0, max_m = 0;
int yess = 0;
int nos = 0;
int total_ops = 0;
// modulus used by po()
ll mod = 1000000007;

// Returns x^n modulo the global `mod` by binary exponentiation (n >= 0).
ll po(ll x, ll n){
    ll result = 1;
    for(; n > 0; n >>= 1){
        if(n & 1) result = (result * x) % mod;
        x = (x * x) % mod;
    }
    return result;
}

// Largest MEX seen so far in the current test case.
int ans;

// Visits node c (parent p). cnt[v] counts occurrences of value v on the root..c path and
// s holds the candidate values not on that path, so *s.begin() is the path's MEX.
void dfs(int c, int p, vector<vector<int> >&g, vector<int>&a, vector<int>&cnt, set<int>&s){
    const int value = a[c];
    if(cnt[value]++ == 0)
        s.erase(value);
    ans = max(ans, *s.begin());

    for(int next : g[c])
        if(next != p)
            dfs(next, c, g, a, cnt, s);

    if(--cnt[value] == 0)
        s.insert(value);
}

void solve()
{

vector<int> a(n);
rep(i,n){
}

vector<vector<int> > g(n);
int x,y;

rep(i,n-1){

assert(x!=y);

x--;
y--;
g[x].pb(y);
g[y].pb(x);
}

set<int> s;
rep(i,n+1) s.insert(i);
vector<int> cnt(n+1, 0);

ans = 0;

dfs(0, -1, g, a, cnt, s);

cout<<ans<<'\n';
}

signed main()
{

#ifndef ONLINE_JUDGE
freopen("input.txt", "r" , stdin);
freopen("output.txt", "w" , stdout);
#endif
fast;

int t = 1;

for(int i=1;i<=t;i++)
{
solve();
}

assert(getchar() == -1);
assert(sum_n<=1e5);

cerr<<"SUCCESS\n";
cerr<<"Tests : " << t << '\n';
cerr<<"Sum of lengths : " << sum_n <<'\n';
cerr<<"Maximum length : " << max_n <<'\n';
// cerr<<"Total operations : " << total_ops << '\n';
//cerr<<"Answered yes : " << yess << '\n';
//cerr<<"Answered no : " << nos << '\n';
}


Feel free to share your approach. Suggestions are welcomed as always.

4 Likes

Didn’t get what MEX is: What is MEX(1,1) if N is 1 ?

MEX(1,1) is the MEX of \{A_1\}: it is 1 if A_1 = 0, and 0 otherwise.
The MEX of a set is the smallest non-negative integer that is not present in the set.

I was amazed to see 26 test files for this problem. Unfortunately, the problem setter forgot to cover this kind of tree.

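Input Tree Structure (Nodes):           Values of Nodes:

            1                                  0
           / \                                / \
          2   3                              1   1
         / \                                / \
        4   5                              2   2
       / \                                / \
      6   7                              3   3
     / \                                / \
    8   9                              4   4

(The pattern continues in the same way up to N nodes: node 2k has children 2k+2 and 2k+3, and node k has value floor(k/2).)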

Input: T = 3 test cases, each with N = 10^5 nodes.


Here’s the Generator

Generator in C++
/**
* @author: Sai Suman Chitturi [suman_18733097]
*/

#include <bits/stdc++.h>
using namespace std;

// Returns a pseudo-random integer in the inclusive range [a, b] using rand().
int randint(int a, int b) {
    int span = b - a + 1;
    return a + rand() % span;
}

// Prints one test case shaped like the diagrams above: n = 1e5 nodes, node k (1-based)
// has value k/2, nodes 2 and 3 hang off node 1, every other even node k hangs off k-2,
// and every other odd node k hangs off k-3.
void generate() {
    const int n = 1e5;
    printf("%d\n", n);

    // Node values, space-separated, the last one followed by a newline.
    for(int node = 1; node <= n; node++) {
        printf("%d%c", node / 2, node == n ? '\n' : ' ');
    }

    // Edges, printed as "child parent" from node n down to node 2.
    for(int child = n; child > 1; child--) {
        int parent;
        if(child <= 3)
            parent = 1;
        else if(child % 2 == 0)
            parent = child - 2;
        else
            parent = child - 3;
        printf("%d %d\n", child, parent);
    }
}

// Writes three worst-case test cases to stdout. An optional first argument seeds rand();
// the generated data does not depend on it, so it defaults to 0 when omitted.
int main(int argc, char* argv[]) {
    // argv[1] must not be dereferenced unless it was actually passed.
    unsigned seed = argc > 1 ? (unsigned)atoi(argv[1]) : 0u;
    srand(seed);
    int test_cases = 3;
    printf("%d\n", test_cases);
    for(int test = 0; test < test_cases; test++) {
        generate();
    }
    return 0;
}


My solution fails to run within the given time limit for this input.

Edit: I mean, there could be many such submissions.

2 Likes

My solution: why is it failing on the last test case?

Guys, I know my code will give TLE.
But when I submitted it, I got wrong answers on most of the test cases.
Can you please tell me where I went wrong?

#include <bits/stdc++.h>
//#include<ext/pb_ds/assoc_container.hpp>
//#include<ext/pb_ds/tree_policy.hpp>

using namespace std;
using namespace chrono;
//using namespace __gnu_pbds;

//template <class T> using Tree = tree<T, null_type, less<T>, rb_tree_tag,tree_order_statistics_node_update>;

using ll = long long;
using ld = long double;

#pragma GCC optimize("Ofast")
#pragma GCC optimize("no-stack-protector")
#pragma GCC optimize("unroll-loops")
#pragma GCC target("sse,sse2,sse3,ssse3,popcnt,abm,mmx,tune=native")
#pragma GCC optimize("fast-math")

#define sz(x) int(x.size())
#define all(x) x.begin(), x.end()
#define ALL(x, n) x, x + n
#define SZ(x) sizeof(x)
#define endl "\n"
#define f first
#define s second
#define pb push_back
#define pf push_front
#define sorted(v) is_sorted(all(v))
#define pqueue priority_queue < pair < int, int > , vector < pair < int, int >> , greater < pair < int, int >>> pq;
#define pop pop_back
#define ar array
#define MAX(x) * max_element(all(x))
#define MIN(x) * min_element(all(x))

#define debug(x) cout << #x << " " << x << endl
// Prints the first n elements of v as "[ a b c ]". Kept on a single logical line: a macro
// spread over several lines needs '\' continuations, otherwise its tail becomes a bare
// statement at namespace scope. Wrapped in do/while(0) so it behaves as one statement.
#define debug_(v, n) do { cout << "[ "; for (int i_ = 0; i_ < (n); ++i_) cout << (v)[i_] << " "; cout << "]"; } while (0)
#define debug__(v) cout << "[ " << v.f << " " << v.s << " " << "]" << endl;

#define gcd(x, y) __gcd(x, y)
#define lcm(x, y) x / (gcd(x, y)) * y

typedef pair < int, int > pi;
typedef vector < int > vi;
typedef set < int > si;
typedef vector < string > vs;
typedef vector < ll > vl;
typedef unordered_map < char, int > unchar;
typedef unordered_map < int, int > unint;

// Returns a uniformly distributed integer in [l, r]. The generator is created once, seeded
// from the clock, and declared locally so the function does not depend on a global `rng`.
int Rand(int l, int r) {
    static mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
    uniform_int_distribution < int > uid(l, r);
    return uid(rng);
}

// 8-neighbourhood offsets: (dx[k], dy[k]) for k = 0..7 enumerates all eight king moves.
int dx[8] = {1, 1, 1, 0, 0, -1, -1, -1}; //dx();
int dy[8] = {1, -1, 0, 1, -1, 1, -1, 0}; //dy();

const int mZX = 1e5 + 1;
vector < vector < ll >> v(mZX);
vector < ll > val(mZX), MEXS;
unordered_set < ll > s;

ll MEX_(unordered_set < ll > s) {
ll ans = (1e9 + 1) / 2, TP = MAX(s);
for (ll i = 0; i <= TP + 1; ++i) {
if (!s.count(i)) {
ans = i;
break;
}
}
return ans;
}

void dfs(ll i) {
s.insert(val[i - 1]);
ll VL = MEX_(s);
MEXS.pb(VL);
for (ll j: v[i]) {
dfs(j);
}
s.erase(val[i - 1]);
}

//O((n^2) * m);

int main() {
cin.tie(NULL) -> sync_with_stdio(0);
cin.exceptions(cin.failbit);

ll t;
cin >> t;

while (t--) {
ll n;
cin >> n;
for (ll i = 0; i < n; ++i) cin >> val[i];
for (ll i = 0; i < n - 1; ++i) {
ll a, b;
cin >> a >> b;
v[a].pb(b);
//v[b].pb(a);
}
dfs(1);
cout << MAX(MEXS) << endl;
for (ll i = 0; i <= n + 1; ++i) v[i].clear();
val.clear(), MEXS.clear(), s.clear();
}
}


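Alternatively, we can use a Fenwick tree. Maintain the frequency of every value on the current path. When a value's frequency becomes 1, call upd(a[u], +1); when it becomes 0, call upd(a[u], -1). This way the Fenwick tree marks exactly the values present on the path. Here upd adds the given amount at an index, which is why +1 and -1 track whether a value is present or not.

Store value v at index v + 1, since Fenwick trees are 1-indexed. At every node in the DFS, find the largest i such that prefixSum(i) = i. This costs O(\log^2 N) with a plain binary search, or O(\log N) by descending the Fenwick tree. Then the values 0, \dots, i-1 are all present, so i is the current MEX.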
Code :- Solution: 60357488 | CodeChef

Can someone please explain why, in the dfs function, re-inserting the node's value with
if(cnt[value[node]] == 0)
st.insert(value[node]);

works, while
if(st.find(value[node]) == st.end())
st.insert(value[node]);
does not?

We can optimise it further by computing the \text{MEX} only at leaf nodes. The values on the path from 1 to a leaf K include the values on the path from 1 to every node along the way, and adding values can never decrease the \text{MEX}. So the \text{MEX} of the path to K is always greater than or equal to that of every internal node on that path.

1 Like

My Solution

Can someone help me with this. I’ve tried to traverse the tree with node indexes.

Sample test case:
5
1 2 3 4 0
1 2
2 3
2 4
1 5

While reading the input, I try to build a paths array which works as follows
paths[y] = x

So,
paths[1] = 0 (zero for root node)
paths[2] = 1
paths[3] = 2
paths[4] = 2
paths[5] = 1
and so on

set[0] = A[0] (value of root)
Now, I can make my sets by iterating from (1,N) as follows:
set[i] = A[i] U set[paths[i]]

now iterating over this set list, I can find mex and max(mex) accordingly.

I did not realise the input is a tree, and instead applied Dijkstra's algorithm to find the shortest paths. Only 2 cases passed and all the rest failed.

Why doesn't Dijkstra's algorithm work here? I am still struggling to figure it out.

This is my first post here, and I'm not sure whether a long code snippet can be attached, so I am pasting it below.
Any help is much appreciated.

from collections import defaultdict

def find_mex(s):
    """Return the smallest non-negative integer that is not in the set ``s``.

    A set of k values cannot contain all of 0..k, so the answer is at most len(s).
    """
    # min() is required: set iteration order is not sorted, so taking the first element
    # of the difference can return a larger missing value than the true MEX.
    return min(set(range(len(s) + 1)) - s)

def get_min_position(dist, visited):
    """Return the index of the smallest ``dist`` entry whose index is not in ``visited``.

    Ties go to the lowest index. If no unvisited entry is below infinity,
    ``len(dist)`` is returned as a sentinel.
    """
    best_index, best_value = len(dist), float('inf')
    for index, value in enumerate(dist):
        if index in visited or not value < best_value:
            continue
        best_index, best_value = index, value
    return best_index

def testcase():
n = int(input())
nodes = list(map(int, input().split()))
edges = defaultdict(list)
for _ in range(n-1):
u, v= map(int, input().split())
edges[u-1].append(v-1)

dist = [ float('inf') for _ in range(n)]
dist[0] = nodes[0]
visited = set()
max_mex = find_mex({nodes[0]})
nodes_mex = [ [ {nodes[0]}, max_mex ] for _ in range(n)]
i = 0
while i < n:
if nodes_mex[i][1] > max_mex:
max_mex = nodes_mex[i][1]
for node in edges[i]:
if node not in visited:
new_set = nodes_mex[i][0].union({nodes[node]})
new_mex = find_mex(new_set)
if dist[i] + 1 < dist[node]:
dist[node] = dist[i] + 1
nodes_mex[node][0] = new_set
nodes_mex[node][1] = new_mex
elif dist[i] + 1 == dist[node] and new_mex > nodes_mex[node][1]:
nodes_mex[node][0] = new_set
nodes_mex[node][1] = new_mex
i = get_min_position(dist, visited)
return max_mex

# Driver: read the number of test cases, then print one answer per test case.
num_cases = int(input())
for _case in range(num_cases):
    print(testcase())


Your code is somehow dependent on the ordering of edges.
For this tc the answer should be 2 but your code gives 0.
1
4
1 2 1 0
4 1
1 2
1 3
But when I switch 4 1 with 1 4 the output is correct.

Shouldn't an input edge E(u, v) mean an edge that starts at u and ends at v? Then E(4, 1) would imply that 4 is the root, and that there is a path from 4 to 1 but no path from 1 to 4.

Given the result of this test case, I have to assume that the order of u and v in an edge does not matter, and that whichever of u and v is smaller should be the source. Is my understanding right?

Even with the understanding of having smaller node as the source node, most of the test cases fail.

My solution with better Time Complexity
Here

2 Likes

// Visits node s (parent par). vis[x] counts how many nodes on the current root..s path
// carry value x; st holds the values missing from that path, so *st.begin() is its MEX.
void dfs(int s,int par){
    vis[a[s]]++;
    // The value leaves st only on its first occurrence on the path. After the increment
    // vis[a[s]] is always >= 1, so a bare `if(vis[a[s]])` would always be true (it only
    // appeared to work because erasing an absent value is a no-op).
    if(vis[a[s]] == 1)
        st.erase(a[s]);
    for(auto it:g[s]){
        if(it!=par)
            dfs(it,s);
    }
    // The children have restored st, so it again describes the root..s path.
    ans=max(ans,*st.begin());
    vis[a[s]]--;
    // Re-insert only when no ancestor still carries the value; otherwise it would be
    // reported as missing while an ancestor still has it.
    if(vis[a[s]]==0)
        st.insert(a[s]);
}

The given tree is undirected. So when there is an edge from 4 to 1 then there is also an edge from 1 to 4.

How do you test this input against your code?

1 Like

/* By Krishna Kumar */

#include<bits/stdc++.h>

// #include<ext/pb_ds/assoc_container.hpp>

//#include<ext/pb_ds/tree_policy.hpp>

//#include<ext/pb_ds/trie_policy.hpp>

//using namespace __gnu_pbds;

using namespace std;

#define ll long long int

#define ld long double

// NOTE(review): 10^7, not the usual 10^9+7; unused in the visible code.
#define mod 10000000

#define inf 1e9

#define endl "\n"

#define pb push_back

// Template arguments below are restored: the forum rendering stripped every "<...>".
#define vi vector<ll>

#define vs vector<string>

#define pil pair<ll,ll>

#define ump unordered_map

#define mp make_pair

#define pq_max priority_queue<ll>

#define pq_min priority_queue<ll, vi , greater<ll>>

#define all(v) v.begin(), v.end()

#define ff first

#define ss second

// Midpoint of [l, r] without overflowing l + r.
#define mid(l,r) ((l)+((r)-(l))/2)

#define bitx(x) __builtin_popcount(x)

#define loop(i,a,b) for(int i = (a);i<=(b);i++)

#define looprev(i,a,b) for(int i = (a); i>=(b); i--)

#define iter(c,it) for(__typeof(c.begin()) it = c.begin();it!=c.end();it++)

// log(a, b, ...) prints "a=<value>" for every argument, one per line: the stringised
// argument list is split on commas and walked in step with the values by err().
#define log(args...) { string _s = #args; replace(_s.begin(), _s.end(), ',', ' '); stringstream _ss(_s); istream_iterator<string> _it(_ss); err(_it, args); }

// Base case: no values left to print.
void err(istream_iterator<string> it){}

template<typename T, typename... Args>
void err(istream_iterator<string> it, T a, Args... args){
    cout<<*it<< "="<<a<<endl;
    err(++it, args...);
}

#define logarr(arr,a,b) for(int i = a;i<=b;i++) cout<<arr[i]<<" "; cout<<endl;

// Greatest common divisor by Euclid's algorithm; gcd(a, 0) == a.
template<typename T> T gcd(T a, T b){ return b == 0 ? a : gcd(b, a % b); }

// Least common multiple; divides before multiplying so the intermediate a*b cannot overflow.
template<typename T> T lcm(T a, T b){return (a / gcd(a,b)) * b;}

// Splits str on every occurrence of ch, as std::getline does: consecutive delimiters
// yield empty fields and a trailing delimiter adds no extra field.
vector<string> tokenizer(string str, char ch ){
    istringstream stream(str);
    vector<string> parts;
    string part;
    while(getline(stream, part, ch)) parts.push_back(part);
    return parts;
}

//typedef tree<ll, null_type, less, rb_tree_tag, tree_order_statistics_node_update> pbds;

//typedef trie<string, null_type, trie_string_access_traits<>, pat_trie_tag, trie_prefix_search_node_update> pbtrie;

// void file_i_o(){

// ios_base::sync_with_stdio(0);

// cin.tie(0);

// cout.tie(0);

// #ifndef ONLINE_JUDGE

// freopen(“input.txt”, “r”, stdin);

// freopen(“output.txt”, “w”, stdout);

// #endif

// }

// Adjacency lists, 1-based node ids.
vector<vector<int>> graph;

// a[node] = value of node (1-based).
vector<int> a;

// vis[node] = 1 once the DFS has entered node.
vector<int> vis;

// Values in [0, n] that are missing from the current root path; *st.begin() is its MEX.
set<int> st;

// Largest MEX seen so far.
int ans;

// Visits node s and its unvisited neighbours, keeping st equal to the set of values in
// [0, n] missing from the root..s path, and raises ans to the largest MEX seen.
void dfs(int s){
    vis[s] = 1;

    // erase() reports whether a[s] was still missing. If an ancestor already carries the
    // same value it is absent here, and this node must not re-insert it on the way out;
    // doing so unconditionally is what made some test cases fail.
    bool removedHere = st.erase(a[s]) > 0;

    for(int child: graph[s]){
        if(!vis[child]){
            dfs(child);
        }
    }

    // The children have restored st, so it again describes the root..s path.
    ans = max(ans, *st.begin());

    if(removedHere) st.insert(a[s]);
}

void solve(){

int n;

cin >> n;

//clearing the variables

graph.clear();

a.clear();

vis.clear();

// resizing the variables

graph.resize(n+1);

a.resize(n+1);

vis.resize(n+1, 0);

loop(i, 1, n) cin >> a[i];

loop(i, 1, n-1){

int u, v;

cin >> u >> v;

graph[u].push_back(v);

graph[v].push_back(u);

}

loop(i, 0, n){

st.insert(i);

}

ans = -1;

// dfs(1, -1);

dfs(1);

cout << ans << endl;


}

// Reads the number of test cases and solves each one. Local (non-judge) builds also
// print the elapsed time after the answers.
int main(int argc, char const *argv[]){
    clock_t startTicks = clock();

    int T;
    cin >> T;
    for(; T != 0; --T){
        solve();
    }

#ifndef ONLINE_JUDGE
    clock_t stopTicks = clock();
    cout<<"\n\n\nExecuted In: "<<double(stopTicks-startTicks)/CLOCKS_PER_SEC*1000<<" ms";
#endif
}
It passes most of the test cases but fails on a few. Can anyone explain why?

Suppose we are at some point of the DFS traversal and about to backtrack. Let the value at the current node (a leaf, by assumption) be v, with cnt[v] = 2, i.e. an ancestor also has the value v. Currently v is not in the set st, so while backtracking we must not insert v into st. However, your second version checks if(st.find(value[node]) == st.end()). That check is true, because we removed v when we first came across it on the path, so the code inserts v into st. This is wrong, because the ancestor still has the value v, and it can lead to a wrong answer.

Not to mention if(cnt[value[node]] == 0) was the exact thing needed.