PROBLEM LINK:
Contest Division 1
Contest Division 2
Contest Division 3
Contest Division 4
Practice
Setter:
Testers: Tejas Pandey and Abhinav sharma
Editorialist: Taranpreet Singh
DIFFICULTY
Easy-Medium
PREREQUISITES
PROBLEM
Consider a tree with N nodes, rooted at node 1.
The value of the i^{th} node (1 \leq i \leq N) is denoted by A_i.
Chef defines the function MEX(u, v) as follows:
Let B denote the set of all the values of nodes that lie in the shortest path from u to v (including both u and v). Then, MEX(u, v) denotes the MEX of the set B.
Find the maximum value of MEX(1, i), where 1 \leq i \leq N.
QUICK EXPLANATION
- Maintain a set of integers not yet found on the current path from the root to u for some node u.
- If A_u is present in the set, we shall remove it at the start of DFS, process its subtree, and then add back A_u to set.
- The MEX on a path from the root to u is the minimum element present in this set which the ordered set can answer fast enough.
EXPLANATION
We need to compute the MEX of values on the path from node 1 to node u for all u and take its maximum. Let’s denote B_u as the collection of values on the path from 1 to u, and let p be the parent of u. Then we have B_u = \{A_u\} \cup B_p for every u \neq 1, and B_1 = \{A_1\}.
So, if we process the nodes of trees in DFS order, we only need to add and remove each value exactly once. The value would be added when DFS enters node u, and the value A_u would be removed when DFS exits node u.
Hence, we have to support adding an element, removing an element, and computing MEX of the current set.
There are multiple possible solutions for this problem, I’d discuss two for now.
Clever solution, simpler implementation
Instead of maintaining the set of values found on the path from 1 to u, let’s maintain the set of values not found on the path from 1 to u. This helps us because now, the MEX is simply the smallest value present in this set.
Now, when DFS enters a node u, we check whether this special set contains A_u. If it does, we remove A_u from the set at the start.
Now, MEX for the current set is the minimum element in this special set. Also, we can make recursive calls to solve for all nodes in subtree of u.
Lastly, if A_u was removed at the start of DFS, we add A_u back to this set as we exit node u.
The implementation of this is added in Editorialist solution 1.
Generic overkill solution using segment tree
We maintain the frequency of each element in the segment tree. Specifically, leaf x denotes the number of occurrences of x. Our segment tree should support
- Increase/decrease frequency of an element x by 1.
- Find MEX of elements currently present.
The MEX of elements would be the leftmost leaf with 0 frequency, which can be found using a technique called tree descent, described here.
The implementation of this is added in Editorialist solution 2.
TIME COMPLEXITY
The time complexity of both solutions is O(N*log_2(N)) per test case.
SOLUTIONS
Editorialist's Solution 1
import java.util.*;
import java.io.*;
class Main{
//SOLUTION BEGIN
void pre() throws Exception{}
void solve(int TC) throws Exception{
int N = ni();
int[] A = new int[N];
for(int i = 0; i< N; i++)A[i] = ni();
int[] from = new int[N-1], to = new int[N-1];
for(int i = 0; i< N-1; i++){
from[i] = ni()-1;
to[i] = ni()-1;
}
int[][] tree = make(N, from, to);
TreeSet<Integer> set = new TreeSet<>();
for(int i = 0; i<= N; i++)set.add(i); //Building set of values not found
pn(dfs(tree, set, A, 0, -1));
}
int dfs(int[][] tree, TreeSet<Integer> set, int[] A, int u, int p){
boolean rem = false;
if(set.contains(A[u])){
set.remove(A[u]);
rem = true;
}
int ans = set.first();
for(int v:tree[u])if(v != p)ans = Math.max(ans, dfs(tree, set, A, v, u));
if(rem)set.add(A[u]);
return ans;
}
int[][] make(int N, int[] f, int[] t){
int[] cnt = new int[N];
for(int x:f)cnt[x]++;
for(int x:t)cnt[x]++;
int[][] g = new int[N][];
for(int i = 0; i< N; i++){
g[i] = new int[cnt[i]];
cnt[i] = 0;
}
for(int i = 0; i< N-1; i++){
g[f[i]][cnt[f[i]]++] = t[i];
g[t[i]][cnt[t[i]]++] = f[i];
}
return g;
}
//SOLUTION END
void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
static boolean multipleTC = true;
FastReader in;PrintWriter out;
void run() throws Exception{
in = new FastReader();
out = new PrintWriter(System.out);
//Solution Credits: Taranpreet Singh
int T = (multipleTC)?ni():1;
pre();for(int t = 1; t<= T; t++)solve(t);
out.flush();
out.close();
}
public static void main(String[] args) throws Exception{
new Main().run();
}
int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
void p(Object o){out.print(o);}
void pn(Object o){out.println(o);}
void pni(Object o){out.println(o);out.flush();}
String n()throws Exception{return in.next();}
String nln()throws Exception{return in.nextLine();}
int ni()throws Exception{return Integer.parseInt(in.next());}
long nl()throws Exception{return Long.parseLong(in.next());}
double nd()throws Exception{return Double.parseDouble(in.next());}
class FastReader{
BufferedReader br;
StringTokenizer st;
public FastReader(){
br = new BufferedReader(new InputStreamReader(System.in));
}
public FastReader(String s) throws Exception{
br = new BufferedReader(new FileReader(s));
}
String next() throws Exception{
while (st == null || !st.hasMoreElements()){
try{
st = new StringTokenizer(br.readLine());
}catch (IOException e){
throw new Exception(e.toString());
}
}
return st.nextToken();
}
String nextLine() throws Exception{
String str = "";
try{
str = br.readLine();
}catch (IOException e){
throw new Exception(e.toString());
}
return str;
}
}
}
Editorialist's Solution 2
import java.util.*;
import java.io.*;
class Main{
//SOLUTION BEGIN
void pre() throws Exception{}
void solve(int TC) throws Exception{
int N = ni();
int[] A = new int[N];
for(int i = 0; i< N; i++)A[i] = ni();
int[] from = new int[N-1], to = new int[N-1];
for(int i = 0; i< N-1; i++){
from[i] = ni()-1;
to[i] = ni()-1;
}
int[][] tree = make(N, from, to);
int S = 1;
while(S <= N)S<<=1;
int[] segtree = new int[S<<1];
pn(dfs(tree, segtree, A, 0, -1, S));
}
int dfs(int[][] tree, int[] segtree, int[] A, int u, int p, int S){
update(segtree, 0, S-1, 1, A[u], 1);
int ans = mex(segtree, 0, S-1, 1);
for(int v:tree[u])if(v != p)ans = Math.max(ans, dfs(tree, segtree, A, v, u, S));
update(segtree, 0, S-1, 1, A[u], -1);
return ans;
}
void update(int[] segtree, int ll, int rr, int i, int p, int x){
if(ll == rr)segtree[i] += x;
else{
int mid = (ll+rr)/2;
if(p <= mid)update(segtree, ll, mid, i<<1, p, x);
else update(segtree, mid+1, rr, i<<1|1, p, x);
segtree[i] = Math.min(segtree[i<<1], segtree[i<<1|1]);
}
}
int mex(int[] segtree, int ll, int rr, int i){
if(ll == rr)return ll;
int mid = (ll+rr)/2;
if(segtree[i<<1] > 0)return mex(segtree, mid+1, rr, i<<1|1);
return mex(segtree, ll, mid, i<<1);
}
int[][] make(int N, int[] f, int[] t){
int[] cnt = new int[N];
for(int x:f)cnt[x]++;
for(int x:t)cnt[x]++;
int[][] g = new int[N][];
for(int i = 0; i< N; i++){
g[i] = new int[cnt[i]];
cnt[i] = 0;
}
for(int i = 0; i< N-1; i++){
g[f[i]][cnt[f[i]]++] = t[i];
g[t[i]][cnt[t[i]]++] = f[i];
}
return g;
}
//SOLUTION END
void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
static boolean multipleTC = true;
FastReader in;PrintWriter out;
void run() throws Exception{
in = new FastReader();
out = new PrintWriter(System.out);
//Solution Credits: Taranpreet Singh
int T = (multipleTC)?ni():1;
pre();for(int t = 1; t<= T; t++)solve(t);
out.flush();
out.close();
}
public static void main(String[] args) throws Exception{
new Main().run();
}
int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
void p(Object o){out.print(o);}
void pn(Object o){out.println(o);}
void pni(Object o){out.println(o);out.flush();}
String n()throws Exception{return in.next();}
String nln()throws Exception{return in.nextLine();}
int ni()throws Exception{return Integer.parseInt(in.next());}
long nl()throws Exception{return Long.parseLong(in.next());}
double nd()throws Exception{return Double.parseDouble(in.next());}
class FastReader{
BufferedReader br;
StringTokenizer st;
public FastReader(){
br = new BufferedReader(new InputStreamReader(System.in));
}
public FastReader(String s) throws Exception{
br = new BufferedReader(new FileReader(s));
}
String next() throws Exception{
while (st == null || !st.hasMoreElements()){
try{
st = new StringTokenizer(br.readLine());
}catch (IOException e){
throw new Exception(e.toString());
}
}
return st.nextToken();
}
String nextLine() throws Exception{
String str = "";
try{
str = br.readLine();
}catch (IOException e){
throw new Exception(e.toString());
}
return str;
}
}
}
Setter's Solution
#define ll long long
#define dd long double
#define pub push_back
#define pob pop_back
#define puf push_front
#define pof pop_front
#define mp make_pair
#define mt make_tuple
#define fo(i , n) for(ll i = 0 ; i < n ; i++)
#define tll tuple<ll ,ll , ll>
#define pll pair<ll ,ll>
#include<bits/stdc++.h>
/*#include<iomanip>
#include<cmath>
#include<cstdio>
#include<utility>
#include<iostream>
#include<vector>
#include<string>
#include<algorithm>
#include<map>
#include<queue>
#include<set>
#include<stack>
#include<bitset>*/
// Template-wide constants. dfs(), solve() and main() of this solution do not read any
// of them; z and fact are only read by the ncr() helper.
dd pi = acos(-1) ; // pi obtained as acos(-1)
ll z = 1000000007 ; // modulus passed to inverse() inside ncr()
ll inf = 10000000000000 ; // 10^13
ll p1 = 37 ;
ll p2 = 53 ;
ll mod1 = 202976689 ;
ll mod2 = 203034253 ;
ll fact[200] ; // table indexed by ncr(); no visible code fills it in
// Returns a with its remainder modulo b taken off, i.e. a - (a % b).
ll gdp(ll a , ll b)
{
    ll remainder = a % b ;
    return a - remainder ;
}
// Least number >= a that is divisible by b.
ll ld(ll a , ll b)
{
    if(a < 0)
        return -1*gdp(abs(a) , b) ; // mirror the non-negative "round down" case
    ll remainder = a % b ;
    if(remainder == 0)
        return a ;
    return a + (b - remainder) ;
}
// Greatest number <= a that is divisible by b.
ll gd(ll a , ll b)
{
    if(a < 0)
        return -1 * ld(abs(a) , b) ; // mirror the non-negative "round up" case
    ll remainder = a % b ;
    return a - remainder ;
}
// Euclid's algorithm; the arguments are swapped first so that a >= b.
ll gcd(ll a , ll b)
{
    if(b > a)
        return gcd(b , a) ;
    return (b == 0) ? a : gcd(b , a%b) ;
}
// Extended Euclid: returns g = gcd(a, b) and writes coefficients into x and y
// such that a*x + b*y = g.
ll e_gcd(ll a , ll b , ll &x , ll &y)
{
    if(b > a)
        return e_gcd(b , a , y , x) ; // swap the operands together with their coefficients
    if(b == 0)
    {
        x = 1 ;
        y = 0 ;
        return a ;
    }
    ll sub_x , sub_y ;
    ll g = e_gcd(b , a%b , sub_x , sub_y) ;
    x = sub_y ;
    y = sub_x - ((a/b) * sub_y) ;
    return g ;
}
// a^b modulo p by repeated squaring (recursive on b/2).
ll power(ll a ,ll b , ll p)
{
    if(b == 0)
        return 1 ;
    ll half = power(a , b/2 , p) ;
    ll squared = (half*half)%p ;
    if(b%2 == 0)
        return squared ;
    return (squared*a)%p ;
}
// a^(n-2) modulo n; by Fermat's little theorem this is the modular inverse of a
// when n is prime and a is not a multiple of n.
ll inverse(ll a ,ll n)
{
    ll exponent = n - 2 ;
    return power(a , exponent , n) ;
}
// Larger of the two arguments (b when they are equal).
ll max(ll a , ll b)
{
    return (a > b) ? a : b ;
}
// Smaller of the two arguments (b when they are equal).
ll min(ll a , ll b)
{
    return (a < b) ? a : b ;
}
// Returns 2*i + 1 (presumably the left-child index in a 0-based implicit binary tree).
ll left(ll i)
{
    ll doubled = 2*i ;
    return doubled + 1 ;
}
// Returns 2*i + 2 (presumably the right-child index in a 0-based implicit binary tree).
ll right(ll i)
{
    ll doubled = 2*i ;
    return doubled + 2 ;
}
// Binomial coefficient n-choose-r modulo z, computed from the fact[] table as
// fact[n] * inverse(fact[r]) * inverse(fact[n-r]). Returns 0 for out-of-range arguments.
ll ncr(ll n , ll r)
{
    bool out_of_range = (n < r) || (n < 0) || (r < 0) ;
    if(out_of_range)
        return 0 ;
    ll result = (fact[n] * inverse(fact[r] , z)) % z ;
    result = (result * inverse(fact[n-r] , z)) % z ;
    return result ;
}
// Exchanges the values of a and b.
void swap(ll&a , ll&b)
{
    ll old_a = a ;
    a = b ;
    b = old_a ;
}
//ios_base::sync_with_stdio(0);
//cin.tie(0); cout.tie(0);
using namespace std ;
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace __gnu_pbds;
#define ordered_set tree<int, null_type,less<int>, rb_tree_tag,tree_order_statistics_node_update>
// ordered_set s ; s.order_of_key(val) no. of elements strictly less than val
// s.find_by_order(i) itertor to ith element (0 indexed)
//__builtin_popcount(n) -> returns number of set bits in n
// Receives the seed value; it is assigned inside rnd's constructor argument below.
ll seed;
// Mersenne Twister generator seeded with the steady clock's current tick count.
// dfs(), solve() and main() of this solution do not use it.
mt19937 rnd(seed=chrono::steady_clock::now().time_since_epoch().count()); // include bits
// Depth-first walk that keeps `s` equal to the set of values missing from the path
// root..u and raises `ans` to the largest MEX (= smallest element of s) seen so far.
//   adj : adjacency lists, val : node values, cnt : occurrences of each value on the path
//   u   : current node,    p   : its parent (-1 for the root)
void dfs(vector<ll> adj[], vector<ll> &val, vector<ll> &cnt, ll &ans, ll u, ll p , set<ll> &s)
{
    const ll here = val[u] ;

    // First occurrence of this value on the path: it is no longer missing.
    if(++cnt[here] == 1)
        s.erase(s.find(here)) ;

    ans = max(ans , (*s.begin())) ;

    for(ll child : adj[u])
    {
        if(child != p)
            dfs(adj , val , cnt , ans , child , u , s) ;
    }

    // Last occurrence leaves the path: the value is missing again.
    if(--cnt[here] == 0)
        s.insert(here) ;
}
// Reads one test case (n, the n node values, then n-1 edges given 1-based) from cin
// and prints the maximum MEX over all paths that start at node 1.
void solve()
{
    ll n ;
    cin >> n ;

    vector<ll> val(n) , cnt(n+1) ;
    set<ll> s ; // values 0..n that are missing from the current path
    for(int i = 0 ; i < n ; i++)
    {
        cin >> val[i] ;
        s.insert(i) ;
    }
    s.insert(n) ;

    // Heap-allocated adjacency lists. A variable-length array of vectors
    // (`vector<ll> adj[n]`) is a compiler extension, not standard C++, and would put
    // all n vector objects on the stack.
    vector<vector<ll>> adj(n) ;
    for(int i = 0 ; i < n-1 ; i++)
    {
        ll u , v ;
        cin >> u >> v ;
        u-- ; v-- ;
        adj[u].push_back(v) ;
        adj[v].push_back(u) ;
    }

    ll ans = 0 ;
    // dfs() takes a pointer to the first adjacency list.
    dfs(adj.data() , val , cnt , ans , 0 , -1 , s) ;
    cout << ans << endl ;
    return ;
}
// Entry point: configures stream I/O, redirects the standard streams to local files
// unless ONLINE_JUDGE is defined, runs every test case and reports the CPU time on stderr.
int main()
{
    ios_base::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
#ifndef ONLINE_JUDGE
    freopen("inputf.txt" , "r" , stdin) ;
    freopen("outputf.txt" , "w" , stdout) ;
    freopen("error.txt" , "w" , stderr) ;
#endif
    ll t;
    cin >> t ;
    for( ; t != 0 ; --t)
        solve() ;
    cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
    return 0;
}
Tester's Solution 1
#include <bits/stdc++.h>
using namespace std;
/*
------------------------Input Checker----------------------------------
*/
// Strict integer reader for input validation. Consumes characters from stdin up to and
// including the delimiter `endd`, asserting along the way that the token has no leading
// zeros, is not "-0", has at most 19 digits, and that the value lies in [l, r].
long long readInt(long long l,long long r,char endd){
    long long value=0;
    int digits=0;
    int first_digit=-1;
    bool negative=false;
    for(;;){
        char g=getchar();
        if(g=='-'){
            // A sign is only allowed before the first digit.
            assert(first_digit==-1);
            negative=true;
            continue;
        }
        if(g>='0' && g<='9'){
            value*=10;
            value+=g-'0';
            if(digits==0){
                first_digit=g-'0';
            }
            digits++;
            assert(first_digit!=0 || digits==1);      // no leading zeros
            assert(first_digit!=0 || negative==false); // no "-0"
            assert(!(digits>19 || ( digits==19 && first_digit>1) )); // length guard against overflow
            continue;
        }
        if(g==endd){
            if(negative){
                value= -value;
            }
            if(value < l || value > r)
            {
                cerr << l << ' ' << r << ' ' << value << '\n';
                assert(1 == 0);
            }
            return value;
        }
        // Any other character is a format violation.
        assert(false);
    }
}
// Reads characters from stdin up to (and consuming) the delimiter `endd`, asserting
// that input does not end first and that the token length lies in [l, r].
// Returns the token without the delimiter.
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        // getchar() returns int so that EOF can be told apart from every valid byte.
        // Storing it in a char makes the EOF check unreachable where char is unsigned,
        // and confuses byte 0xFF with EOF where char is signed.
        int g=getchar();
        assert(g!=EOF);
        char ch=static_cast<char>(g);
        if(ch==endd){
            break;
        }
        cnt++;
        ret+=ch;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
// Reads an integer in [l, r] that must be followed by a single space.
long long readIntSp(long long l,long long r){
    const char delimiter=' ';
    return readInt(l,r,delimiter);
}
// Reads an integer in [l, r] that must be followed by a newline.
long long readIntLn(long long l,long long r){
    const char delimiter='\n';
    return readInt(l,r,delimiter);
}
// Reads a newline-terminated token whose length must lie in [l, r].
string readStringLn(int l,int r){
    const char delimiter='\n';
    return readString(l,r,delimiter);
}
// Reads a space-terminated token whose length must lie in [l, r].
string readStringSp(int l,int r){
    const char delimiter=' ';
    return readString(l,r,delimiter);
}
/*
------------------------Main code starts here----------------------------------
*/
// Input bounds checked by main() (MAX_T) and solve() (MAX_N).
const int MAX_T = 10000;
const int MAX_N = 100000;
// Not referenced by the visible code of this solution.
const int MAX_SUM_N = 100000;
// Capacity of the global arrays below and upper bound of the Fenwick tree loops.
const int lim = 1000007;
#define ll long long int
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
// Not referenced by the visible code of this solution.
long long int sum_len=0;
// bit : Fenwick tree over the shifted values (value + 1) present on the current DFS path
// cnt : number of occurrences of each shifted value on the current DFS path
// val : node values, stored shifted by +1 (see solve())
// dsu : union-find parent array, used to check that the edges form no cycle
int bit[lim], cnt[lim], val[lim], dsu[lim];
// cmex : MEX of the current path as computed by updateMEX(); mmex : largest cmex so far
int cmex = 0, mmex;
// Adjacency lists, indexed by 1-based node id.
vector<int> ed[lim];
// Union-find lookup with path compression: returns the representative of a's set.
int fnd(int a) {
    if(dsu[a] != a)
        dsu[a] = fnd(dsu[a]);
    return dsu[a];
}
// Merges the sets containing a and b. Returns false when they were already in the
// same set, true when a merge actually happened.
bool unite(int a, int b) {
    const int root_a = fnd(a);
    const int root_b = fnd(b);
    if(root_a != root_b) {
        dsu[root_a] = root_b;
        return true;
    }
    return false;
}
void updateMEX(){
int cur = 0; ll csum = 0;
for(int i = 18;i > -1;i--){
if(cur + (1LL<<i) < lim && bit[cur + (1LL<<i)] + csum == cur + (1LL<<i)) cur += (1LL<<i),csum += bit[cur];
}
cmex = cur;
mmex = max(cmex, mmex);
}
// Fenwick point update: adds val at index pos, then refreshes the MEX bookkeeping.
void upd(int pos,int val){
    for(; pos < lim; pos += (pos&(-pos)))
        bit[pos] += val;
    updateMEX();
}
// Depth-first walk from `node` (parent `par`). A value is marked present in the Fenwick
// tree when its first occurrence enters the path and unmarked when its last one leaves.
void dfs(int node, int par) {
    const int here = val[node];
    if(++cnt[here] == 1)
        upd(here, 1);
    for(int nxt : ed[node]) {
        if(nxt != par)
            dfs(nxt, node);
    }
    if(--cnt[here] == 0)
        upd(here, -1);
}
// Validates and solves one test case: reads n, the n values (each in [0, n]) and the
// n-1 edges, asserts the edges contain no cycle, and prints the maximum MEX over all
// paths starting at node 1.
void solve()
{
    int n = readIntLn(1, MAX_N);
    for(int i = 1; i <= n; i++) {
        // The last value on the line is newline-terminated, the others space-terminated.
        val[i] = (i < n) ? readIntSp(0, n) : readIntLn(0, n);
        val[i]++;       // shift by one: the Fenwick tree is 1-indexed
        ed[i].clear();  // reset per-node state left from the previous test case
        dsu[i] = i;
    }
    for(int i = 1; i < n; i++) {
        int a = readIntSp(1, n);
        int b = readIntLn(1, n);
        // Do the union outside assert(): with NDEBUG the asserted expression is not
        // evaluated at all, which would silently skip the merge.
        bool merged = unite(a, b);
        assert(merged);
        (void)merged;
        ed[a].push_back(b);
        ed[b].push_back(a);
    }
    mmex = 0;
    dfs(1, 0);
    cout << mmex << "\n";
}
// Entry point: reads the number of test cases, solves each one, and finally asserts
// that nothing but end-of-file remains on stdin.
signed main()
{
    //fast;
    #ifndef ONLINE_JUDGE
    //freopen("input.txt", "r", stdin);
    //freopen("output.txt", "w", stdout);
    #endif
    int t = readIntLn(1, MAX_T);
    for(int done = 0; done < t; done++)
        solve();
    assert(getchar() == -1);
}
Tester's Solution 2
#include <bits/stdc++.h>
using namespace std;
/*
------------------------Input Checker----------------------------------
*/
// Strict integer reader for input validation. Consumes characters from stdin up to and
// including the delimiter `endd`, asserting along the way that the token has no leading
// zeros, is not "-0", has at most 19 digits, and that the value lies in [l, r].
long long readInt(long long l,long long r,char endd){
    long long value=0;
    int digits=0;
    int first_digit=-1;
    bool negative=false;
    for(;;){
        char g=getchar();
        if(g=='-'){
            // A sign is only allowed before the first digit.
            assert(first_digit==-1);
            negative=true;
            continue;
        }
        if(g>='0' && g<='9'){
            value*=10;
            value+=g-'0';
            if(digits==0){
                first_digit=g-'0';
            }
            digits++;
            assert(first_digit!=0 || digits==1);      // no leading zeros
            assert(first_digit!=0 || negative==false); // no "-0"
            assert(!(digits>19 || ( digits==19 && first_digit>1) )); // length guard against overflow
            continue;
        }
        if(g==endd){
            if(negative){
                value= -value;
            }
            if(value < l || value > r)
            {
                cerr << l << ' ' << r << ' ' << value << '\n';
                assert(1 == 0);
            }
            return value;
        }
        // Any other character is a format violation.
        assert(false);
    }
}
// Reads characters from stdin up to (and consuming) the delimiter `endd`, asserting
// that input does not end first and that the token length lies in [l, r].
// Returns the token without the delimiter.
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        // getchar() returns int so that EOF can be told apart from every valid byte.
        // Storing it in a char makes the EOF check unreachable where char is unsigned,
        // and confuses byte 0xFF with EOF where char is signed.
        int g=getchar();
        assert(g!=EOF);
        char ch=static_cast<char>(g);
        if(ch==endd){
            break;
        }
        cnt++;
        ret+=ch;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
// Reads an integer in [l, r] that must be followed by a single space.
long long readIntSp(long long l,long long r){
    const char delimiter=' ';
    return readInt(l,r,delimiter);
}
// Reads an integer in [l, r] that must be followed by a newline.
long long readIntLn(long long l,long long r){
    const char delimiter='\n';
    return readInt(l,r,delimiter);
}
// Reads a newline-terminated token whose length must lie in [l, r].
string readStringLn(int l,int r){
    const char delimiter='\n';
    return readString(l,r,delimiter);
}
// Reads a space-terminated token whose length must lie in [l, r].
string readStringSp(int l,int r){
    const char delimiter=' ';
    return readString(l,r,delimiter);
}
/*
------------------------Main code starts here----------------------------------
*/
// Declared limits. The visible code of this solution does not reference them;
// solve() and main() pass their bounds to the readers as literals instead.
const int MAX_T = 1e5;
const int MAX_N = 1e5;
const int MAX_SUM_LEN = 1e5;
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define ff first
#define ss second
#define mp make_pair
#define ll long long
#define rep(i,n) for(int i=0;i<n;i++)
#define rev(i,n) for(int i=n;i>=0;i--)
#define rep_a(i,a,n) for(int i=a;i<n;i++)
#define pb push_back
// Validator statistics. main() asserts sum_n <= 1e5 and prints sum_n and max_n to
// stderr after all test cases; sum_m and max_m are not referenced by the visible code.
int sum_n = 0, sum_m = 0;
int max_n = 0, max_m = 0;
// Counters that appear only in commented-out diagnostics at the end of main().
int yess = 0;
int nos = 0;
int total_ops = 0;
// Modulus used by po().
ll mod = 1000000007;
// Computes x^n modulo `mod` by binary exponentiation.
ll po(ll x, ll n){
    ll result=1;
    for(; n>0; n/=2){
        if(n&1) result=(result*x)%mod; // this bit of the exponent is set
        x=(x*x)%mod;
    }
    return result;
}
// Best MEX found so far for the current test case; reset by solve().
int ans;

// Depth-first walk from node c (parent p). `s` holds the values missing from the
// current root..c path and `cnt` the occurrences of each value on it, so the MEX of
// the path is the smallest element of `s`.
void dfs(int c, int p, vector<vector<int> >&g, vector<int>&a, vector<int>&cnt, set<int>&s){
    const int value = a[c];

    // First occurrence on the path: the value is no longer missing.
    if(cnt[value]++ == 0)
        s.erase(value);

    ans = max(ans, *s.begin());

    for(int child : g[c]){
        if(child != p)
            dfs(child, c, g, a, cnt, s);
    }

    // Last occurrence leaves the path: the value is missing again.
    if(--cnt[value] == 0)
        s.insert(value);
}
void solve()
{
int n = readIntLn(1, 1e5);
vector<int> a(n);
rep(i,n){
if(i<n-1) a[i] = readIntSp(0,n);
else a[i] = readIntLn(0,n);
}
vector<vector<int> > g(n);
int x,y;
rep(i,n-1){
x = readIntSp(1,n);
y = readIntLn(1,n);
assert(x!=y);
x--;
y--;
g[x].pb(y);
g[y].pb(x);
}
set<int> s;
rep(i,n+1) s.insert(i);
vector<int> cnt(n+1, 0);
ans = 0;
dfs(0, -1, g, a, cnt, s);
cout<<ans<<'\n';
}
// Entry point: redirects the standard streams to local files unless ONLINE_JUDGE is
// defined, runs every test case, checks that the input is fully consumed, and prints a
// short validation summary to stderr.
signed main()
{
    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r" , stdin);
    freopen("output.txt", "w" , stdout);
    #endif
    fast;
    int t = 1;
    t = readIntLn(1,10000);
    for(int done = 0; done < t; done++)
        solve();
    assert(getchar() == -1);
    assert(sum_n<=1e5);
    cerr<<"SUCCESS\n";
    cerr<<"Tests : " << t << '\n';
    cerr<<"Sum of lengths : " << sum_n <<'\n';
    cerr<<"Maximum length : " << max_n <<'\n';
    // cerr<<"Total operations : " << total_ops << '\n';
    //cerr<<"Answered yes : " << yess << '\n';
    //cerr<<"Answered no : " << nos << '\n';
}
Feel free to share your approach. Suggestions are welcomed as always.