Setter- Alei Reyes
Tester- Pranjal Jain
Editorialist- Abhishek Pandey

### DIFFICULTY:

Medium-Hard

### PREREQUISITES:

DP on Tries

### PROBLEM:

Given a function W and a set $S_x$ of distinct non-negative integers, find the value of W on the set after removing the i-th element. You need to do this for every i = 1 to N. Please refer to the problem statement itself for the rules to compute W.

### QUICK-EXPLANATION:

Key to AC- Realizing that we can build the trie from the Least Significant Bit (LSB) to the Most Significant Bit (MSB). This shows that, at any time, the two leaves whose LCA (Lowest Common Ancestor) is deepest are the numbers to be removed from the set in the current operation. The rest is an application of dynamic programming on the trie.

The first thing to handle is finding which 2 elements to remove from the set. If we build the trie from LSB to MSB, the leaves with the deepest LCA (the LCA farthest from the root) are automatically the numbers to remove first. Hence, a simple DFS on the trie solves the problem in a bottom-up manner.

Define a pair $\langle W(u), R(u) \rangle$ for every node u: $W(u)$ is the answer (the value of W) for the elements in the subtree of u, and $R(u)$ is the element left unpaired in that subtree, if any. Say we need to compute the pair for v, the parent of u. We do it as follows-

$W(v) = W(u_1) \oplus W(u_2)$, where $u_1$ and $u_2$ are the children of v. (If v has only one child, take $W(u_2) = 0$; if v is a leaf, $W(u_1) = 0$ as well.) Also, if both $u_1$ and $u_2$ have an unpaired element, we pair them and additionally set $W(v) = W(v) \oplus [(R(u_1) \oplus R(u_2)) - 1]$.

$R(v)$ = the unpaired element/leaf left in the subtree of v, if any. If v is a leaf, it is unpaired by default, so for a leaf $W(v) = 0$ and $R(v) = val$, where val is the value represented by the leaf.

Compute $\langle W(u), R(u) \rangle$ for every node, considering all elements in the trie.

Notice that when we remove an element from the set, only about 60 nodes (those on the path from the root to that element's leaf) change their values; the rest of the tree stays the same. Hence, we walk down that one path to recompute the values as if the element were absent, and reuse the memoized values everywhere else.

### EXPLANATION:

Ouch, that was surely some “Quick” Explanation :). If you’re looking for a concise editorial, the quick explanation has it all. If one or more points are not clear, hop on to the appropriate section below.

The quick explanation covers all the important topics. We will now elaborate on its points and then conclude the editorial.

1. Handling the congruence condition.

At first sight, it is not obvious how to find the two elements to pick from the set.

Notice the effect of the modulo $2^i$ operation: all bits at positions $\ge i$ become 0, and all bits at positions $[0, i)$ stay the same. Two examples follow.


Find $11 \bmod 8$.

$11 \bmod 8 = 1011_2 \bmod 1000_2 = 0011_2 = 3$

Similarly, $45 \bmod 16 = 101101_2 \bmod 10000_2 = 1101_2 = 13$

The basic observation: if the power of 2 is $2^K$ (a 1 followed by K zeros in binary), the last K bits are retained as they are and the rest become 0.

With this in mind, let's see how to use this fact in our problem. Tries are the natural tool to think of when a problem involves a lot of bit manipulation, especially XOR (e.g. finding the subarray with maximum XOR).

Notice that if we build the trie from LSB to MSB, we can solve this part efficiently. Think about what the LCA of 2 leaves means here. The path from the root to the LCA is shared by both leaves, so the two numbers agree on every bit along that path. Since the trie goes from LSB to MSB, every bit from the LSB up to the depth of the LCA is the same in both numbers: if the LCA is at depth d, the numbers are congruent modulo $2^d$. The farther the LCA is from the root, the more low bits the two numbers have in common.

Hence, the pair of leaves whose LCA is farthest from the root is removed first, and so on. This allows us to traverse the trie using DFS and compute our DP values in a bottom-up fashion.

2. How to use DFS with the above observation to calculate the DPs-

(Optional: if you are not sure how the bottom-up DP is calculated, read the paragraph below.)


Notice the intrinsic property of DFS: at a node, it goes into each of the node’s children, as deep as it can. Once at a leaf, it hits the base case of our recurrence (the recurrence described in the quick explanation) and begins the calculations. (The recurrence and calculations are in the next section.) For intuition, say you are at a leaf $L_1$ whose parent is u, and the other child of u is $C_2$. Since it is a DFS, the calculation for u does not start until both $L_1$ and $C_2$ are done. Since $L_1$ is a leaf, its values are computed immediately. DFS then goes deeper into $C_2$ and computes the values of any deeper nodes there before computing the value for u.

One important thing to realize is that we only care about which element is paired with whom, not the order in which the pairs are removed. In other words, if we know that (a,b) are paired, it does not matter whether we compute $W(S_x \setminus \{a,b\})$ first or $W(\{a,b\})$ first, since $A \oplus B = B \oplus A$.

To see this, let's do a dry run. Say the set whose W we want is $S_x = \{a,b,c,d,e\}$, and we know that a is paired with b and c with d. Then W is calculated as:

$$W = W(\{a,b,c,d,e\}) = W(\{a,b\}) \oplus W(\{c,d,e\}) = ((a \oplus b) - 1) \oplus [W(\{c,d\}) \oplus W(\{e\})] = ((a \oplus b) - 1) \oplus ((c \oplus d) - 1) \oplus e$$

As long as the pairing stays the same, we can equally well remove the pair (c,d) first:

$$W = W(\{a,b,c,d,e\}) = W(\{c,d\}) \oplus W(\{a,b,e\}) = ((c \oplus d) - 1) \oplus [W(\{a,b\}) \oplus W(\{e\})] = ((c \oplus d) - 1) \oplus ((a \oplus b) - 1) \oplus e$$

Hence, the condition "has remainders congruent modulo the highest power of 2" serves to make us pair the elements correctly, rather than to dictate the order of removal. The next paragraph discusses this further; it is not a trivial observation, so take your time going through the proof and discussion.

To illustrate, say we are at node u, and there are some leaves with a deeper LCA that are not in the subtree of u but in the subtree of some other node v. Ask yourself: does computing the answer for the subtree of u first affect the answer? Think for a while; the answer is given below.


No! It does not affect the answer at all. Recall that the order of pairing elements doesn’t matter; what matters is that they are paired correctly. The congruence condition with powers of 2 is about pairing elements correctly, not about their order. Let's prove this now-

The leaves with a deeper LCA are not in the subtree of u. This means that, right now, they are congruent with each other modulo a higher power of 2 than they are with any leaf of u, so they will be paired with each other and removed from the set. But during this process, can a leaf of v get paired with a leaf of u? No, because such a pair would violate the condition of congruence modulo the highest possible power of 2. Hence, the leaves in the subtree of v are paired within v (except for the remaining unpaired leaf, if any), and the same holds for the leaves in the subtree of u. By the time we reach the power of 2 modulo which leaves of both u and v are congruent, all paired leaves have already been removed!

These two observations are the key to proving that the DFS is correct.

3. DP recurrences-

Notice that all the operations in the calculation of W (except the one where W has only two elements) are XORs. As shown above, we can compute the answer for each subtree independently, since all leaves in a subtree, except the remaining unpaired one, are paired with each other. (A leaf stores the value of an element, so pairing leaves is another way of saying pairing elements.)

Once we realize this, computing W becomes easy. We first compute W for the complete set; we will discuss removals soon. From the expansion of W and our subtree argument, let's reason out the recurrence-

• $W(v) = W(u_1) \oplus W(u_2)$, where $u_1$ and $u_2$ are the children of v. $W(u_1)$ and $W(u_2)$ hold the answers for the elements/leaves paired inside their subtrees, and the final answer depends on them. As the expansion above shows, the final answer is the XOR-sum of different W values, so we XOR the values of the subtrees. If either or both of $u_1$ and $u_2$ do not exist, their W value is 0.
• If both $u_1$ and $u_2$ have an unpaired element, we pair them and additionally set $W(v) = W(v) \oplus [(R(u_1) \oplus R(u_2)) - 1]$. A subtree with an odd number of leaves leaves one leaf unpaired. To obey the condition of congruence modulo the highest power of 2, we want to pair it as soon as possible with another unpaired leaf as we go up the tree bottom-up. Once we find such a leaf, we pair the two and XOR the corresponding value into the answer.
• $R(v)$ = the unpaired element/leaf left in the subtree, if any. If v is a leaf, it is unpaired by default. Hence R just holds the value of the unpaired leaf/element and is easy to compute. If, after processing the entire tree, the root still has an unpaired leaf (the set has an odd number of elements), we use the definition of W for a single element, $W(\{x\}) = x$, and XOR $R(\text{root})$ into the answer.

Let's reason out the base case for a leaf-

• A leaf is a single element, so no pairing is possible: $W(L) = 0 \oplus 0 = 0$, as discussed in the recurrence.
• A single leaf is unpaired by default, hence $R(L) = val$, where val is the value of the element stored at the leaf.

Now we proceed to the last part: computing the answer when an element is removed from the set. As a hint for those who want to try it themselves: it is based on recomputing W while traversing the tree, using the memoized values of W for the unchanged parts and recomputing only the $\log_2 S_i \approx 60$ nodes on the affected path.

An implementation of the function that computes this recurrence (taken from the setter's solution) is given below-


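```cpp
typedef long long uli;         // as in the setter's code (signed despite the name)
const int mx = 1e5 + 10;
const int mxs = mx * 62;       // enough trie nodes for N values of 60 bits each
int g[mxs][2];                 // g[u][bit] = child of u for that bit, -1 if absent
uli f[mxs], r[mxs];            // f[u] = W(u), r[u] = unpaired element or -1

void dfs(int u, uli val, uli pwr) {
    f[u] = 0;                  // f is the same as W
    r[u] = -1;                 // -1 = no unpaired element yet
    int cnt = 0;
    for (int it = 0; it < 2; it++) {
        int v = g[u][it];
        if (v == -1) continue;
        cnt++;
        dfs(v, val + pwr * uli(it), pwr << 1);  // calculate for the subtree first
        f[u] ^= f[v];                           // XOR in the result of the subtree
        if (r[u] == -1) r[u] = r[v];            // take over v's unpaired element, if any
        else if (r[v] != -1) {                  // both have an unpaired leaf: pair them!
            f[u] ^= ((r[u] ^ r[v]) - 1);        // XOR the pair's value into W
            r[u] = -1;                          // no unpaired leaf left
        }
    }
    if (cnt == 0) {                             // base case: a leaf
        f[u] = 0;
        r[u] = val;
    }
}
```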



4. Dealing with removals.

One elegant way to do this is to compute W while ignoring the leaf corresponding to the element we have to remove. Most of the tree stays the same; only the ~60 trie nodes on the path to that leaf need to be handled.

So we start from the root of the tree with an initial answer of 0 and proceed with the DFS as before. If the current node does not lie on the path to the removed leaf, we simply reuse its memoized value. Otherwise, we recurse down the path to compute the answer with this leaf removed; on reaching the leaf itself, we return $\langle W(\text{Leaf}) = 0, R(\text{Leaf}) = \text{N/A} \rangle$. Note that we are NOT traversing the entire trie again and again: we only traverse the path to the removed leaf and reuse the memoized values for the parts where nothing has changed.

A short implementation is given below-


```cpp
// Returns <W, R> for the subtree of u (at depth i) with the leaf holding b removed.
// Uses g, f, r as filled in by dfs() above.
pair<uli, uli> rem(int u, int i, uli b) {
    uli fu = 0ll;                          // current answer starts at 0
    uli ru = -1;                           // unpaired element: N/A
    int bit = 0;
    if (b & (1ll << i)) bit = 1;           // which child of u lies on b's path
    for (int it = 0; it < 2; it++) {
        int v = g[u][it];
        if (v == -1) continue;
        pair<uli, uli> frv = {f[v], r[v]}; // start from the memoized values of subtree v
        if (bit == it) {                   // v lies on the path to the removed leaf:
            frv = rem(v, i + 1, b);        // recompute its contribution instead
        }
        fu ^= frv.first;                   // same recurrence as described in part 3
        if (ru == -1) ru = frv.second;
        else if (frv.second != -1) {
            fu ^= ((ru ^ frv.second) - 1);
            ru = -1;
        }
    }
    return {fu, ru};                       // at the removed leaf itself: <0, N/A>
}
```



### SOLUTION

Setter


#include<bits/stdc++.h>
using namespace std;
typedef long long int uli;
uli rint(char nxt){
char ch=getchar();
uli v=0;
uli sgn=1;
if(ch=='-')sgn=-1;
else{
assert('0'<=ch&&ch<='9');
v=ch-'0';
}
while(true){
ch=getchar();
if('0'<=ch && ch<='9')v=v*10ll+uli(ch-'0');
else{
assert(ch==nxt);
break;
}
}
return v*sgn;
}
const int mx=1e5+10;
const int mxs=mx*62;
int g[mxs][2];
//f[u] = W(all descendants of u)
//r[u] = unmatched element when calculating f[u]
uli f[mxs],r[mxs];
int ns=1;
void add(uli b){//add binary b to the trie ( LSB -> MSB)
int u=0;
for(int i=0;i<60;i++){
int ch=0;
if(b&(1ll<<i))ch=1;
if(g[u][ch]==-1){
for(int c=0;c<2;c++)g[ns][c]=-1;
g[u][ch]=ns++;
}
u=g[u][ch];
}
}
//calculate f,r
void dfs(int u,uli val,uli pwr){
f[u]=0;
r[u]=-1;
int cnt=0;
for(int it=0;it<2;it++){
int v=g[u][it];
if(v==-1)continue;
cnt++;
dfs(v,val+pwr*uli(it),pwr<<1);
f[u]^=f[v];
if(r[u]==-1)r[u]=r[v];
else if(r[v]!=-1){
f[u]^=( (r[u]^r[v])-1 );
r[u]=-1;
}
}
if(cnt==0){
f[u]=0;
r[u]=val;
}
}
//<w(u),r(u)> if we remove the leaf containing b
pair<uli,uli>rem(int u,int i,uli b){
uli fu=0ll;
uli ru=-1;
for(int it=0;it<2;it++){
int v=g[u][it];
if(v==-1)continue;
int bit=0;
if(b&(1ll<<i))bit=1;

pair<uli,uli>frv={f[v],r[v]};
if(bit==it){
frv=rem(v,i+1,b);
}
fu^=frv.first;
if(ru==-1)ru=frv.second;
else if(frv.second!=-1){
fu^=( (ru^frv.second)-1 );
ru=-1;
}
}
return {fu,ru};
}
vector<uli>all;
int main(){
int t=rint('\n');
int sn=0;
for(int tt=1;tt<=t;tt++){
int n=rint('\n');
assert(1<=n&&n<=1e5);
sn+=n;
assert(1<=sn&&sn<=1e6);
for(int c=0;c<2;c++)g[0][c]=-1;
ns=1;
all.resize(n);
set<uli>s;
for(int i=0;i<n;i++){
uli b=rint(i==n-1?'\n':' ');
assert(0<=b&&b<=1e18);
s.insert(b);
all[i]=b;
}
assert(int(s.size())==n);

dfs(0,0,1);
int cnt=0;

vector<uli>resp;
for(uli b:all){
pair<uli,uli>ans=rem(0,0,b);
uli o=ans.first;
if(ans.second!=-1)o^=ans.second;
if(cnt!=0)printf(" ");
printf("%lld",o);
resp.push_back(o);
cnt++;
}
puts("");
}
assert(getchar()==EOF);
return 0;
}



Tester


#ifndef _GLIBCXX_NO_ASSERT
#include <cassert>
#endif
#include <cctype>
#include <cerrno>
#include <cfloat>
#include <ciso646>
#include <climits>
#include <clocale>
#include <cmath>
#include <csetjmp>
#include <csignal>
#include <cstdarg>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>

#if __cplusplus >= 201103L
#include <ccomplex>
#include <cfenv>
#include <cinttypes>
#include <cstdbool>
#include <cstdint>
#include <ctgmath>
#include <cwchar>
#include <cwctype>
#endif

// C++
#include <algorithm>
#include <bitset>
#include <complex>
#include <deque>
#include <exception>
#include <fstream>
#include <functional>
#include <iomanip>
#include <ios>
#include <iosfwd>
#include <iostream>
#include <istream>
#include <iterator>
#include <limits>
#include <list>
#include <locale>
#include <map>
#include <memory>
#include <new>
#include <numeric>
#include <ostream>
#include <queue>
#include <set>
#include <sstream>
#include <stack>
#include <stdexcept>
#include <streambuf>
#include <string>
#include <typeinfo>
#include <utility>
#include <valarray>
#include <vector>

#if __cplusplus >= 201103L
#include <array>
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <forward_list>
#include <future>
#include <initializer_list>
#include <mutex>
#include <random>
#include <ratio>
#include <regex>
#include <scoped_allocator>
#include <system_error>
#include <tuple>
#include <typeindex>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#endif

#define ll          long long
#define pb          push_back
#define mp          make_pair
#define pii         pair<int,int>
#define vi          vector<int>
#define all(a)      (a).begin(),(a).end()
#define F           first
#define S           second
#define sz(x)       (int)x.size()
#define hell        1000000007
#define endl        '\n'
#define rep(i,a,b)  for(int i=a;i<b;i++)
using namespace std;

string to_string(string s) {
return '"' + s + '"';
}

string to_string(const char* s) {
return to_string((string) s);
}

string to_string(bool b) {
return (b ? "true" : "false");
}

string to_string(char ch) {
return string("'")+ch+string("'");
}

template <typename A, typename B>
string to_string(pair<A, B> p) {
return "(" + to_string(p.first) + ", " + to_string(p.second) + ")";
}

template <class InputIterator>
string to_string (InputIterator first, InputIterator last) {
bool start = true;
string res = "{";
while (first!=last) {
if (!start) {
res += ", ";
}
start = false;
res += to_string(*first);
++first;
}
res += "}";
return res;
}

template <typename A>
string to_string(A v) {
bool first = true;
string res = "{";
for (const auto &x : v) {
if (!first) {
res += ", ";
}
first = false;
res += to_string(x);
}
res += "}";
return res;
}

void debug_out() { cerr << endl; }

template <typename Head, typename... Tail>
void debug_out(Head H, Tail... T) {
cerr << " " << to_string(H);
debug_out(T...);
}

template <typename A, typename B>
istream& operator>>(istream& input,pair<A,B>& x){
input>>x.F>>x.S;
return input;
}

template <typename A>
istream& operator>>(istream& input,vector<A>& x){
for(auto& i:x)
input>>i;
return input;
}

#ifdef PRINTERS
#define debug(...) cerr << "[" << #__VA_ARGS__ << "]:", debug_out(__VA_ARGS__)
#else
#define debug(...) 42
#endif

int trie[6100005][2];
long long result[6100005];
long long rem_num[6100005];
void solve(){
int N;
cin>>N;
assert(1<=N and N<=100000);
static long long sumN=0;
sumN+=N;
assert(sumN<=1000000);
vector<long long>nums(N);
cin>>nums;
int cur_idx = 1;
auto add = [&cur_idx](long long num){
int cur=0;
for(int i=0;i<60;i++){
int bit = (num>>i)&1;
if(trie[cur][bit]==-1){
trie[cur][bit]=cur_idx++;
}
cur=trie[cur][bit];
}
};
for(auto i:nums){
add(i);
}
function<void(int,long long,int)> dfs = [&dfs](int idx,long long val,int depth){
if(trie[idx][0]==-1 and trie[idx][1]==-1){
rem_num[idx] = val;
result[idx] = 0;
return;
}
if(trie[idx][0]!=-1)dfs(trie[idx][0],val+(1LL<<depth)*0,depth+1);
if(trie[idx][1]!=-1)dfs(trie[idx][1],val+(1LL<<depth)*1,depth+1);
if(trie[idx][0]==-1){
result[idx]=result[trie[idx][1]];
rem_num[idx]=rem_num[trie[idx][1]];
return;
}
if(trie[idx][1]==-1){
result[idx]=result[trie[idx][0]];
rem_num[idx]=rem_num[trie[idx][0]];
return;
}
result[idx]=(result[trie[idx][0]])^(result[trie[idx][1]]);
if(rem_num[trie[idx][0]]==-1 and rem_num[trie[idx][1]]==-1){
rem_num[idx]=-1;
}
else if(rem_num[trie[idx][0]]==-1){
rem_num[idx]=rem_num[trie[idx][1]];
}
else if(rem_num[trie[idx][1]]==-1){
rem_num[idx]=rem_num[trie[idx][0]];
}
else{
rem_num[idx]=-1;
result[idx]^=(rem_num[trie[idx][0]]^rem_num[trie[idx][1]])-1;
}
};
dfs(0,0,0);
function<pair<long long,long long>(int,long long,int)> get_res=[&get_res](int idx,long long num,int depth){
if(trie[idx][0]==-1 and trie[idx][1]==-1){
return mp(-1LL,0LL);
}
int bit = (num>>depth)&1;
auto a=get_res(trie[idx][bit],num,depth+1);
auto b=(trie[idx][bit^1]==-1?mp(-1LL,0LL):mp(rem_num[trie[idx][bit^1]],result[trie[idx][bit^1]]));
auto res=mp(0LL,0LL);
res.S=a.S^b.S;
if(a.F!=-1 and b.F!=-1){
res.S^=(a.F^b.F)-1;
res.F=-1;
}
else if(a.F!=-1){
res.F=a.F;
}
else if(b.F!=-1){
res.F=b.F;
}
else{
res.F=-1;
}
return res;
};
for(auto i:nums){
auto res = get_res(0,i,0);
if(res.F!=-1)res.S^=res.F;
cout<<res.S<<" ";
}
cout<<endl;
auto clear = [&cur_idx](){
for(int i=0;i<cur_idx;i++){
trie[i][0]=-1;
trie[i][1]=-1;
rem_num[i]=-1;
result[i]=0;
}
};
clear();
}

int main(){
ios_base::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
memset(trie,-1,sizeof trie);
memset(rem_num,-1,sizeof rem_num);
int t=1;
cin>>t;
while(t--){
solve();
}
return 0;
}



Time Complexity: $O(N \log S_i)$, where $S_i$ denotes the maximum value of an element in the set.
Space Complexity: $O(N \log S_i)$, where $S_i$ denotes the maximum value of an element in the set.

### CHEF VIJJU’S CORNER

1. Setter’s Notes-


Add the numbers to a trie, and calculate the following DPs: $W(u)$ = the W function considering only the numbers that are leaves of u, and $R(u)$ = the element that remains unmatched after applying the algorithm, considering only the numbers that are leaves of u. Using these DPs and iterating over the trie, it is possible to find the W function after removing each leaf.

2. Tester’s Notes-

Click to view

Create a trie of the numbers (from least significant bits to most significant).
Run a DFS on the trie to calculate the result for each subtree.
Now, when you remove a number, the results change for only 60 nodes (those on the path of the removed number). Recalculate these results assuming the number is removed.

3. Related Problems-
