NTA-Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Contest Division 4

Setter: Devendra Singh
Tester: Manan Grover, Lavish Gupta
Editorialist: Devendra Singh

DIFFICULTY:

2586

PREREQUISITES:

Trie, small to large technique, C++ multiset, basic graph algorithms.

PROBLEM:

You are given a tree with N nodes rooted at node 1.
The i^{th} node of the tree has a value A_i assigned to it.

For the i^{th} node (1\le i \le N), output the minimum value of bitwise XOR of the values of any two distinct nodes in the subtree of node i.

Formally, for each node i, find the value of min(A_x \oplus A_y), where x \ne y and x, y \in subtree of node i.
Note that since leaf nodes have only 1 node in their subtree, output -1 for all the leaf nodes.

EXPLANATION:

Solution using C++ multiset
Claim: Given an array A of N non-negative integers, the minimum XOR over all pairs is attained by two adjacent elements of the sorted array, i.e. it equals the minimum of A[i] XOR A[i+1] over all i from 1 to N-1 after sorting.

Proof

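The first step is to sort the array.
Suppose that the minimum is attained not by an adjacent pair A[i] XOR A[i+1] but by X XOR Y with X < Y, and that there exists another element Z in the array such that X <= Z <= Y. (If X = Y the XOR is 0, and equal values are adjacent after sorting.)

We now prove that either X XOR Z or Z XOR Y is smaller than X XOR Y, which contradicts the minimality of X XOR Y.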

Let X[i] = 0/1 be the i-th bit in the binary representation of X
Let Y[i] = 0/1 be the i-th bit in the binary representation of Y
Let Z[i] = 0/1 be the i-th bit in the binary representation of Z

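Assume that X, Y and Z are padded with 0 on the left until they all have the same length.
Let i be the leftmost (most significant) index such that X[i] differs from Y[i]. Since X < Y, we have X[i] = 0 and Y[i] = 1. Because X <= Z <= Y, Z agrees with both X and Y on every bit to the left of i, so X XOR Y, X XOR Z and Y XOR Z are all 0 in those positions and the comparison is decided at bit i. There are 2 cases now: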

  1. Z[i] = X[i] = 0,
    then (X XOR Z)[i] = 0 and (X XOR Y)[i] = 1
    This implies (X XOR Z) < (X XOR Y)
  2. Z[i] = Y[i] = 1,
    then (Y XOR Z)[i] = 0 and (X XOR Y)[i] = 1
    This implies (Y XOR Z) < (X XOR Y)
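
The claim can be checked on small inputs with a minimal sketch like the one below. The function names minXorSorted and minXorBrute are made up for this illustration and do not appear in the attached solutions; the brute-force version is only there for comparison.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Minimum XOR over all pairs, looking only at neighbours in sorted order.
// Requires at least two non-negative values.
std::int64_t minXorSorted(std::vector<std::int64_t> a) {
    assert(a.size() >= 2);
    std::sort(a.begin(), a.end());
    std::int64_t best = a[0] ^ a[1];
    for (std::size_t i = 1; i + 1 < a.size(); ++i)
        best = std::min(best, a[i] ^ a[i + 1]);
    return best;
}

// Reference O(n^2) check over every pair, for comparison only.
std::int64_t minXorBrute(const std::vector<std::int64_t>& a) {
    assert(a.size() >= 2);
    std::int64_t best = a[0] ^ a[1];
    for (std::size_t i = 0; i < a.size(); ++i)
        for (std::size_t j = i + 1; j < a.size(); ++j)
            best = std::min(best, a[i] ^ a[j]);
    return best;
}
```

For the values 9, 5, 3, 12, 6 the sorted order is 3, 5, 6, 9, 12, the adjacent XORs are 6, 3, 15, 5, and both functions return 3.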

Therefore, whenever a value is added to the sorted set of values of a subtree, it is enough to XOR it with its two neighbours in sorted order: the just greater (or equal) value and the just smaller value. Building the set of every subtree from scratch is not fast enough, since on a path-like tree it needs O(N^2) insertions.
The technique used in the solution to maintain the set of values of every subtree is called small to large merging, or dsu on trees. It performs O(Nlog N) insertions in total. Since each insertion into a multiset takes O(log(N)) time, the overall complexity becomes O(Nlog^2(N)).
The technique:
First, we calculate the subtree size of each node u from 1 to N using a simple dfs:
sz[u] = 1 + \sum sz[child\: of\: u]
Then, using the pre-calculated sizes, we find the big child of each node u. The big child of vertex u is the child with the maximum subtree size. All other children are referred to as small children.
Now, we know that
subtree[u] = subtree[big child] \cup \{u\} \cup subtree[small child_1] \cup subtree[small child_2] \cup \dots
Let subtree[u] be empty in the beginning. Once the big child has been processed, its set is no longer needed on its own, so we can swap subtree[big child] and subtree[u]; subtree[u] then holds all values of the big child's subtree. The swap takes O(1). Initialize ans[u] = ans[bigchild].
Then recurse into all small children and insert the value of u and every value of the small children's subtrees into subtree[u] one by one. Before inserting a value, XOR it with the just greater (or equal) value and with the just smaller value currently in subtree[u], and update ans[u] with both results. Both neighbours can be found using the lower_bound function of the set.
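
The steps above can be sketched as follows. This is a minimal illustration under a few assumptions, not the attached solution: the names Solver, addValue and subtreeMinXor are made up here, nodes are 1-indexed with values[0] unused, and the big child is chosen by the size of its multiset after its dfs returns (which equals its subtree size, since a multiset keeps duplicates) instead of using a pre-computed sz array.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iterator>
#include <limits>
#include <set>
#include <utility>
#include <vector>

using i64 = std::int64_t;
const i64 kNone = std::numeric_limits<i64>::max();  // "no pair seen yet"

// Update best with x XOR its sorted neighbours in s, then insert x into s.
void addValue(std::multiset<i64>& s, i64 x, i64& best) {
    auto it = s.lower_bound(x);  // first element >= x
    if (it != s.end()) best = std::min(best, *it ^ x);
    if (it != s.begin()) best = std::min(best, *std::prev(it) ^ x);
    s.insert(x);
}

struct Solver {
    std::vector<std::vector<int>> adj;
    std::vector<i64> val, ans;
    std::vector<std::multiset<i64>> bag;  // bag[u]: values in the subtree of u

    void dfs(int u, int parent) {
        int big = -1;
        for (int v : adj[u]) {
            if (v == parent) continue;
            dfs(v, u);
            ans[u] = std::min(ans[u], ans[v]);
            if (big == -1 || bag[v].size() > bag[big].size()) big = v;
        }
        if (big != -1) bag[u].swap(bag[big]);  // O(1): inherit the largest set
        addValue(bag[u], val[u], ans[u]);
        for (int v : adj[u]) {
            if (v == parent || v == big) continue;
            for (i64 x : bag[v]) addValue(bag[u], x, ans[u]);  // small into large
            bag[v].clear();
        }
    }
};

// Returns a vector r where r[i] is the answer for node i (r[0] is unused),
// with -1 for subtrees that contain a single node.
std::vector<i64> subtreeMinXor(int n, const std::vector<std::pair<int, int>>& edges,
                               const std::vector<i64>& values) {
    assert(static_cast<int>(values.size()) == n + 1);
    Solver s;
    s.adj.resize(n + 1);
    s.bag.resize(n + 1);
    s.ans.assign(n + 1, kNone);
    s.val = values;
    for (const auto& e : edges) {
        s.adj[e.first].push_back(e.second);
        s.adj[e.second].push_back(e.first);
    }
    s.dfs(1, 0);
    for (i64& a : s.ans)
        if (a == kNone) a = -1;
    return s.ans;
}
```

For example, on the tree with edges 1-2, 1-3, 3-4 and values 4, 5, 8, 10, the answers for nodes 1 to 4 are 1, -1, 2, -1.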

Why are only O(Nlog N) insertions needed overall?

Let y=sz[big child] and x=sz[small child] (y\geq x).

Consider the value of a vertex u (1\leq u\leq N).

Whenever it is moved from a small child's set into the bigger set, the size of the set that contains it becomes x+y\geq x+x = 2x. So each move at least doubles the size of the set containing u, and since a set never holds more than N values, u cannot be moved more than log(N) times.
There are N vertices u, so the total number of insertions is O(Nlog N).

Solution using Trie
The solution with a trie is almost the same as the one above. For each node of the tree we maintain a binary trie of the values in its subtree instead of a multiset. While merging, instead of looking up the just greater and just smaller values, we compute the minimum XOR of the new value with the values in the trie greedily: iterate over the bits of the value from the most significant to the least significant, follow the child with the same bit whenever it exists (that bit of the XOR is then 0), and take the other child otherwise (that bit of the XOR is then 1).
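
The greedy query can be sketched with an index-based trie like the one below. This is only an illustration under stated assumptions: the class name XorTrie and the constant kBits = 51 are chosen here (51 bits cover values up to 10^15), and the layout differs from the pointer-based trie in the attached solution.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

// Binary trie over the low kBits bits of non-negative integers.
class XorTrie {
public:
    XorTrie() { next_.push_back({-1, -1}); }  // node 0 is the root

    bool empty() const { return count_ == 0; }

    void insert(std::int64_t x) {
        int cur = 0;
        for (int b = kBits - 1; b >= 0; --b) {
            int bit = static_cast<int>((x >> b) & 1);
            if (next_[cur][bit] == -1) {
                next_.push_back({-1, -1});
                next_[cur][bit] = static_cast<int>(next_.size()) - 1;
            }
            cur = next_[cur][bit];
        }
        ++count_;
    }

    // Smallest x XOR y over all stored y. The trie must not be empty.
    std::int64_t minXor(std::int64_t x) const {
        assert(!empty());
        int cur = 0;
        std::int64_t result = 0;
        for (int b = kBits - 1; b >= 0; --b) {
            int bit = static_cast<int>((x >> b) & 1);
            if (next_[cur][bit] != -1) {
                cur = next_[cur][bit];      // same bit exists: XOR bit is 0
            } else {
                cur = next_[cur][bit ^ 1];  // forced to differ: XOR bit is 1
                result |= (std::int64_t{1} << b);
            }
        }
        return result;
    }

private:
    static constexpr int kBits = 51;
    std::vector<std::array<int, 2>> next_;  // next_[node][bit], -1 if absent
    int count_ = 0;
};
```

Every stored value reaches full depth, so an inner node always has at least one child and the fallback branch never follows a missing edge. With 4, 8 and 10 stored, minXor(5) returns 1 (from 4) and minXor(10) returns 0.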

For details of implementation please refer to the solutions attached.

TIME COMPLEXITY:

O(Nlog^2(N))\: or\: O(Nlog(N)\cdot log(max(A_i))) for each test case.

SOLUTION:

Tester-1's solution(C++ set)
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;
#define asc(i,a,n) for(I i=a;i<n;i++)
#define dsc(i,a,n) for(I i=n-1;i>=a;i--)
#define forw(it,x) for(A it=(x).begin();it!=(x).end();it++)
#define bacw(it,x) for(A it=(x).rbegin();it!=(x).rend();it++)
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define lb(x) lower_bound(x)
#define ub(x) upper_bound(x)
#define fbo(x) find_by_order(x)
#define ook(x) order_of_key(x)
#define all(x) (x).begin(),(x).end()
#define sz(x) (I)((x).size())
#define clr(x) (x).clear()
#define U unsigned
#define I long long int
#define S string
#define C char
#define D long double
#define A auto
#define B bool
#define CM(x) complex<x>
#define V(x) vector<x>
#define P(x,y) pair<x,y>
#define OS(x) set<x>
#define US(x) unordered_set<x>
#define OMS(x) multiset<x>
#define UMS(x) unordered_multiset<x>
#define OM(x,y) map<x,y>
#define UM(x,y) unordered_map<x,y>
#define OMM(x,y) multimap<x,y>
#define UMM(x,y) unordered_multimap<x,y>
#define BS(x) bitset<x>
#define L(x) list<x>
#define Q(x) queue<x>
#define PBS(x) tree<x,null_type,less<I>,rb_tree_tag,tree_order_statistics_node_update>
#define PBM(x,y) tree<x,y,less<I>,rb_tree_tag,tree_order_statistics_node_update>
#define pi (D)acos(-1)
#define md 1000000007
#define rnd randGen(rng)
long long readInt(long long l, long long r, char endd) {
    long long x = 0;
    int cnt = 0;
    int fi = -1;
    bool is_neg = false;
    while (true) {
        char g = getchar();
        if (g == '-') {
            assert(fi == -1);
            is_neg = true;
            continue;
        }
        if ('0' <= g && g <= '9') {
            x *= 10;
            x += g - '0';
            if (cnt == 0) {
                fi = g - '0';
            }
            cnt++;
            assert(fi != 0 || cnt == 1);
            assert(fi != 0 || is_neg == false);
 
            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
        }
        else if (g == endd) {
            assert(cnt > 0);
            if (is_neg) {
                x = -x;
            }
            assert(l <= x && x <= r);
            return x;
        }
        else {
            assert(false);
        }
    }
}
 
string readString(int l, int r, char endd) {
    string ret = "";
    int cnt = 0;
    while (true) {
        char g = getchar();
        assert(g != -1);
        if (g == endd) {
            break;
        }
        cnt++;
        ret += g;
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}
void comp(I x,I &y,OS(I) &s){
  A it=s.lb(x);
  if(it!=s.end()){
    y=min(y,(*it)^x);
  }
  if(it!=s.begin()){
    it--;
    y=min(y,(*it)^x);
  }
  s.insert(x);
}
void dfs(I x,I pr,V(I) tr[],OS(I) *s[],I dp[],I a[]){
  I z=-1;
  I siz=0;
  dp[x]=LLONG_MAX;
  asc(i,0,sz(tr[x])){
    I y=tr[x][i];
    if(y!=pr){
      dfs(y,x,tr,s,dp,a);
      if(siz<sz(*s[y])){
        siz=sz(*s[y]);
        z=y;
      }
      dp[x]=min(dp[x],dp[y]);
    }
  }
  if(z!=-1){
    comp(a[x-1],dp[x],*s[z]);
    asc(i,0,sz(tr[x])){
      I y=tr[x][i];
      if(y!=pr && y!=z){
        forw(it,*s[y]){
          comp((*it),dp[x],*s[z]);
        }
      }
    }
    s[x]=s[z];
  }else{
    s[x]=new OS(I);
    (*s[x]).insert(a[x-1]);
  }
}
int main(){
  mt19937_64 rng(chrono::steady_clock::now().time_since_epoch().count());
  uniform_int_distribution<I> randGen;
  ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL);
  #ifndef ONLINE_JUDGE
  freopen("input.txt", "r", stdin);
  freopen("output.txt", "w", stdout);
  #endif
  I t;
  t=readInt(1,100000,'\n');
  I ns=0;
  while(t--){
    I n;
    n=readInt(1,100000,'\n');
    ns+=n;
    assert(ns<=100000);
    V(I) tr[n+1];
    asc(i,0,n-1){
      I u,v;
      u=readInt(1,n,' ');
      v=readInt(1,n,'\n');
      tr[u].pb(v);
      tr[v].pb(u);
    }
    OS(I) *s[n+1];
    I a[n];
    OS(I) x;
    asc(i,0,n){
      if(i!=n-1){
        a[i]=readInt(0,1000000000000000,' ');
      }else{
        a[i]=readInt(0,1000000000000000,'\n');
      }
      x.insert(a[i]);
    }
    I dp[n+1];
    memset(dp,-1,sizeof(dp));
    dfs(1,0,tr,s,dp,a);
    asc(i,1,n+1){
      assert(dp[i]!=-1);
      if(dp[i]==LLONG_MAX){
        dp[i]=-1;
      }
      cout<<dp[i]<<" ";
    }
    cout<<"\n";
  }
  return 0;
}
Tester-2's solution(C++ multiset)
#include <bits/stdc++.h>
using namespace std;
#define ll long long
ll max(ll l, ll r){ if(l > r) return l ; return r;}
ll min(ll l, ll r){ if(l < r) return l ; return r;}

 
 
/*
------------------------Input Checker----------------------------------
*/
 
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
 
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }
 
            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }
 
            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
 
 
/*
------------------------Main code starts here----------------------------------
*/
 
const int MAX_T = 100000;
const int MAX_N = 100000;
const int MAX_M = 1000;
const ll MAX_val = 1e15;
const int MAX_SUM_N = 100000;
 
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define pll pair<ll , ll>
 
int sum_n = 0;
int max_n = 0 ;
ll z = 1000000007;
ll inf = 1e18 ;

ll find(ll f[], ll u)
{
    if(f[u] == u)
        return u ;
    return f[u] = find(f , f[u]) ;
}

void merge(ll f[], ll u, ll v)
{
    ll a = find(f, u) ;
    ll b = find(f, v) ;
    f[a] = b ;
    return ;
}

ll c = 1 ;

void get_ans(multiset<ll> &v1, multiset<ll> &v2, ll &curr_ans)
{
    
    for(auto itr = v2.begin() ; itr != v2.end() ; itr++)
    {
        ll val = *itr ;
        
        auto itr2 = v1.lower_bound(val) ;
        
        if(itr2 != v1.end())
        {
            ll val2 = *itr2 ;
            curr_ans = min(curr_ans , val ^ val2) ;
        }

        if(itr2 != v1.begin())
        {
            itr2 -- ;
            ll val2 = *itr2 ;
            curr_ans = min(curr_ans , val ^ val2) ;
        }
    }
    

    for(auto itr = v2.begin() ; itr != v2.end() ; itr++)
    {
        v1.insert(*itr) ;
    }

    return ;
}


void dfs(vector<ll> adj[], ll arr[], vector<ll>&ans, multiset<ll> values[], ll u, ll p)
{
    // cout << "u = " << u << endl ;

    ll max_size = 0 ;
    ll cnt = 0 ;
    for(int i = 0 ; i < adj[u].size() ; i++)
    {
        ll v = adj[u][i] ;
        if(v == p)
            continue ;

        dfs(adj , arr , ans, values, v, u) ;
        ans[u] = min(ans[u], ans[v]) ;
        max_size = max(max_size , values[v].size()) ;
        cnt++ ;
    }

    if(ans[u] == 0)
        return ;

    for(int i = 0 ; i < adj[u].size() ; i++)
    {
        ll v = adj[u][i] ;
        if(v == p)
            continue ;

        if(values[v].size() == max_size)
        {
            values[v].swap(values[u]) ;
            values[v].insert(arr[u]) ;
            break ;
        }
    }
    for(int i = 0 ; i < adj[u].size() ; i++)
    {
        ll v1 = adj[u][i] ;
        if(v1 == p)
            continue ;

        get_ans(values[u], values[v1], ans[u]) ;
    }

    if(cnt == 0)
        values[u].insert(arr[u]) ;

    return ;
}



void solve()
{   
    int n = readIntLn(1, MAX_N) ;
    sum_n += n ;
    max_n = max(n , max_n) ;
    assert(sum_n <= MAX_SUM_N) ;

    vector<ll> adj[n] ;
    ll f[n] ;
    for(int i = 0 ; i < n ; i++)
        f[i] = i ;

    for(int i = 0 ; i < n-1 ; i++)
    {
        ll u, v ;
        u = readIntSp(1 , n) ;
        v = readIntLn(1 , n) ;
        u-- ; v-- ;

        assert(find(f , u) != find(f , v)) ;
        merge(f , u , v) ;
        adj[u].push_back(v) ;
        adj[v].push_back(u) ;
    }

    ll arr[n] ;
    for(int i = 0 ; i < n-1 ; i++)
    {
        arr[i] = readIntSp(0, MAX_val) ;
    }
    arr[n-1] = readIntLn(0 , MAX_val) ;

    /*************** Input verified ***************/

    vector<ll> ans(n , inf) ;
    multiset<ll> values[n] ;

    // cout << "starting dfs" << endl ;

    dfs(adj, arr, ans, values, 0, -1) ;
    for(int i = 0 ; i < n ; i++)
    {
        if(ans[i] == inf)
            ans[i] = -1 ;
        cout << ans[i] << ' ';
    }
    cout << endl ;
    return ;

}
 
signed main()
{
    //fast;
    #ifndef ONLINE_JUDGE
    freopen("inputf.txt" , "r" , stdin) ;
    freopen("outputf.txt" , "w" , stdout) ;
    freopen("error.txt" , "w" , stderr) ;
    #endif
    
    int t = 1;
    
    t = readIntLn(1,MAX_T);

    for(int i=1;i<=t;i++)
    {    
        solve() ;
    }
    
    assert(getchar() == -1);
 
    cerr<<"SUCCESS\n";
    cerr<<"Tests : " << t << '\n';
    cerr<<"Sum of lengths : " << sum_n << '\n';
    cerr<<"Maximum length : " << max_n << '\n';
    // cerr << "Sum o f product : " << sum_nk << '\n' ;
    // cerr<<"Total operations : " << total_ops << '\n';
    // cerr<<"Answered yes : " << yess << '\n';
    // cerr<<"Answered no : " << nos << '\n';
}

Editorialist's Solution(Trie)
#include "bits/stdc++.h"
using namespace std;
#define ll long long
#define pb push_back
#define all(_obj) _obj.begin(), _obj.end()
#define F first
#define S second
#define pll pair<ll, ll>
#define vll vector<ll>
ll INF = 1e16;
const int N = 1e5 + 11, mod = 1e9 + 7;
ll max(ll a, ll b) { return ((a > b) ? a : b); }
ll min(ll a, ll b) { return ((a > b) ? b : a); }
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
ll sz[N], ans[N], value[N];
vll adj[N], V[N];
class Node
{
public:
    Node *one;
    Node *zero;
};
class trie
{
    Node *root;

public:
    trie() { root = new Node(); }
    void insert(ll n)
    {
        Node *temp = root;
        for (int i = 50; i >= 0; i--)
        {
            int bit = (((ll)n)>> i) & 1;
            if (bit == 0)
            {
                if (temp->zero == NULL)
                {
                    temp->zero = new Node();
                }
                temp = temp->zero;
            }
            else
            {
                if (temp->one == NULL)
                {
                    temp->one = new Node();
                }
                temp = temp->one;
            }
        }
    }
    ll min_xor_helper(ll value)
    {
        Node *temp = root;
        ll current_ans = 0;
        for (int i = 50; i >= 0; i--)
        {
            int bit = (((ll)value) >> i) & 1;
            if (bit == 1)
            {
                if (temp->one)
                {
                    temp = temp->one;
                }
                else
                {
                    temp = temp->zero;
                    current_ans += (1LL << i);
                }
            }
            else
            {
                if (temp->zero)
                {
                    temp = temp->zero;
                }
                else
                {
                    temp = temp->one;
                    current_ans += (1LL << i);
                }
            }
        }
        return current_ans;
    }
};
vector<trie> vec;
void dfs_size(int v, int p)
{
    sz[v] = 1;

    for (auto u : adj[v])
    {
        if (u != p)
        {
            dfs_size(u, v);
            sz[v] += sz[u];
        }
    }
}
void dfs(int v, int p)
{
    int Max = -1, bigchild = -1;
    for (auto u : adj[v])
    {
        if (u != p && Max < sz[u])
        {
            Max = sz[u];
            bigchild = u;
        }
    }
    for (auto u : adj[v])
    {
        if (u != p && u != bigchild)
        {
            dfs(u, v);
        }
    }
    if (bigchild != -1)
    {
        dfs(bigchild, v);
        swap(vec[v], vec[bigchild]);
        swap(V[v], V[bigchild]);
        ans[v] = ans[bigchild];
    }
    V[v].pb(v);
    if (bigchild != -1)
        ans[v] = min(ans[v], vec[v].min_xor_helper(value[v]));
    vec[v].insert(value[v]);
    for (auto u : adj[v])
    {
        if (u != p && u != bigchild)
        {
            for (auto x : V[u])
            {
                ans[v] = min(ans[v], vec[v].min_xor_helper(value[x]));
                vec[v].insert(value[x]);
                V[v].pb(x);
            }
        }
    }
}
void sol(void)
{
    int n;
    cin >> n;
    for (int i = 1; i <= n; i++)
        ans[i] = 1e18, adj[i].clear(), V[i].clear();
    vec.clear();
    vec.resize(n + 1);
    for (int i = 1; i <= n - 1; i++)
    {
        int a, b;
        cin >> a >> b;
        adj[a].pb(b);
        adj[b].pb(a);
    }
    for (int i = 1; i <= n; i++)
        cin >> value[i];
    dfs_size(1, -1);
    dfs(1, -1);
    for (int i = 1; i <= n; i++)
        cout << (ans[i] <= (INF) ? ans[i] : -1) << ' ';
    cout << '\n';
    return;
}
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(NULL), cout.tie(NULL);
    int test = 1;
    cin >> test;
    while (test--)
        sol();
}



When I click on the problem link, it says that the contest is locked and should be opened from the contest page, and when I open the contest page no problems are visible there either.


If we don’t keep track of the sizes of the children’s subtrees and just merge all the children as they come in the dfs, how will it affect the time complexity?

I was getting TLE by doing this. Code: CodeChef: Practical coding for everyone


I solved this problem with Mo’s algorithm and the Euler tour technique.
Basically, I do an Euler tour, and the answer for node i is the minimum XOR pair in the range from intime[i] to outtime[i].

So we have N queries, and I answered them using Mo’s algorithm.
Insertion and deletion were similar and were done using a multiset and a map.
I think my complexity is O(N\sqrt{N}\cdot log(N)).

It was initially failing on test cases 18 and 19 (TLE), even after several optimisations.
Finally, I optimised the way the queries are sorted for Mo’s algorithm using
the Hilbert order technique, and it got accepted. Link to Code. It’s an ideone link since CodeChef seems to be down now.
You can view my attempts at this problem by filtering the problem’s submissions by the user name ‘sumit_kp’.


If you don’t keep track of sizes, then the complexity blows up to O(n^2). Try reading about Small to Large Merging. You can also try this problem CSES: Distinct Colours.
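
To see the difference, here is a rough sketch that counts set insertions on a path 1 - 2 - ... - n rooted at 1, which is a bad case for merging without looking at sizes. The function countInsertions and its flag are made up for this illustration; the node index stands in for the node value.

```cpp
#include <cassert>
#include <cstdint>
#include <set>
#include <utility>

// Counts the insertions needed to build the value set of every subtree on the
// path 1 - 2 - ... - n rooted at 1, where node u has the single child u + 1.
std::int64_t countInsertions(int n, bool swapWithLargerSet) {
    assert(n >= 1);
    std::int64_t insertions = 0;
    std::multiset<int> child;  // set of the only child; empty below the leaf
    for (int u = n; u >= 1; --u) {  // process from the leaf up to the root
        std::multiset<int> cur;
        if (swapWithLargerSet) {
            cur.swap(child);  // O(1): reuse the child's set
        } else {
            for (int x : child) {  // copy the whole child set into a new one
                cur.insert(x);
                ++insertions;
            }
        }
        cur.insert(u);  // the node's own value
        ++insertions;
        child = std::move(cur);
    }
    return insertions;
}
```

For n = 1000 this gives 500500 insertions when every child set is copied into its parent, and 1000 when the parent takes over the child's set by swapping.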


Oh…ok…
Thanks mate

How to find just smaller value using lower_bound of set?

Use the lower_bound function to find the just greater or equal value to X, then just decrease the iterator by one to find the just smaller value if it exists.
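
A small sketch of that, assuming a std::set of long long; the helper name justSmaller and its out-parameter are made up for this example. The same steps work on a multiset.

```cpp
#include <iterator>
#include <set>

// Stores the largest element strictly smaller than x in out and returns true,
// or returns false if no such element exists.
bool justSmaller(const std::set<long long>& s, long long x, long long& out) {
    auto it = s.lower_bound(x);            // first element >= x (or end())
    if (it == s.begin()) return false;     // nothing smaller than x
    out = *std::prev(it);                  // step back one position
    return true;
}
```

With the set {3, 7, 10}, the just smaller value for 7 is 3, for 11 it is 10, and for 3 there is none.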


Thanks @devendra7700