# Solving CLBIT the hard way

## Note

This is not the intended solution; the intended one does not require Heavy Light Decomposition. It is an alternative approach that was accepted during the contest.

Author and Editorialist: Sarthak Manna

Difficulty: Medium-Hard

### Prerequisites

Heavy-light decomposition, Segment Tree, Trie, Implementation

### Problem

Given a tree rooted at node 1 whose every node has a value associated with it, implement a technique to support the following types of operations:

• Type 1: Given two nodes a and b and a value x, replace the value of every node on the simple path from a to b with (value OR x).
• Type 2: Given a node u and a value y, find the maximum of (value XOR y) over all nodes that lie in the subtree of u.
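To pin down what the two operations mean, here is a brute-force reference that can be used to check a fast solution on small trees. It is only a sketch: the struct and member names are illustrative, and it assumes `parent`, `depth` and `children` have already been filled in for a tree rooted at node 1 (with `parent[1] == 0`).

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Brute-force reference. Nodes are 1-indexed; all names here are illustrative.
struct NaiveTree {
    std::vector<int> parent, depth, value;
    std::vector<std::vector<int>> children;

    // Type 1: OR x into every node on the simple path a..b.
    void pathOr(int a, int b, int x) {
        while (a != b) {
            if (depth[a] < depth[b]) std::swap(a, b);
            value[a] |= x;   // a is the deeper (or equally deep) endpoint: update it and climb
            a = parent[a];
        }
        value[a] |= x;       // a == b is now the lowest common ancestor
    }

    // Type 2: maximum of (value[v] XOR y) over all v in the subtree of u.
    int subtreeMaxXor(int u, int y) const {
        int best = value[u] ^ y;
        for (int c : children[u]) best = std::max(best, subtreeMaxXor(c, y));
        return best;
    }
};
```

Each operation costs O(N) here, which is far too slow for the real limits but makes the expected answers unambiguous.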

### Quick Explanation

Flatten the tree using Heavy Light Decomposition. After this, all updates and queries are performed on contiguous ranges.

Note that the total number of bits changed by all update operations can never exceed N * 17 (17 being the bit length of the values), because a bitwise OR can only set a bit that was not set already. This observation matters because, for every update, we can then afford to iterate over only those values which actually change.
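The counting argument can be seen in a few lines of code. The helper below is a hypothetical illustration, not part of either solution: it applies an OR to a range naively and reports how many bits went from 0 to 1. Summed over any sequence of updates, that count cannot exceed 17 per element.

```cpp
#include <bitset>
#include <cassert>
#include <vector>

// ORs x into values[l..r] and returns how many bits changed from 0 to 1.
long long orRangeCountingFlips(std::vector<int>& values, int l, int r, int x) {
    long long flips = 0;
    for (int i = l; i <= r; ++i) {
        int gained = x & ~values[i];   // bits of x that values[i] does not have yet
        flips += static_cast<long long>(std::bitset<32>(static_cast<unsigned>(gained)).count());
        values[i] |= x;
    }
    return flips;
}
```

Repeating the same update a second time returns 0, which is exactly the work the fast solution wants to skip.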

Finding the maximum possible XOR value is a fairly standard problem and can be solved by maintaining binary tries. You can study and practice the technique at link1 and link2.
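For readers who have not seen the technique, a minimal sketch of such a trie follows. It stores a count on every node so that values can be removed again, which the update operation needs. The names (`XorTrie`, `add`, `maxXor`) are illustrative, and returning -1 for an empty trie is a choice made for this sketch only.

```cpp
#include <array>
#include <cassert>
#include <vector>

// Binary trie over BITS-bit values; count[node] = number of stored values passing through node.
struct XorTrie {
    static constexpr int BITS = 17;
    std::vector<std::array<int, 2>> child;   // child[node][bit], 0 means "no child"
    std::vector<int> count;

    XorTrie() : child(1, std::array<int, 2>{0, 0}), count(1, 0) {}   // node 0 is the root

    // delta = +1 inserts value, delta = -1 removes one previously inserted copy.
    void add(int value, int delta) {
        int node = 0;
        count[node] += delta;
        for (int b = BITS - 1; b >= 0; --b) {
            int bit = (value >> b) & 1;
            if (!child[node][bit]) {
                int fresh = static_cast<int>(child.size());
                child.push_back({0, 0});
                count.push_back(0);
                child[node][bit] = fresh;
            }
            node = child[node][bit];
            count[node] += delta;
        }
    }

    // Largest (x XOR v) over stored v, or -1 if nothing is stored.
    int maxXor(int x) const {
        if (count[0] == 0) return -1;
        int node = 0, best = 0;
        for (int b = BITS - 1; b >= 0; --b) {
            int want = ((x >> b) & 1) ^ 1;   // the opposite bit wins this position
            int next = child[node][want];
            if (next && count[next] > 0) { best |= 1 << b; node = next; }
            else node = child[node][want ^ 1];
        }
        return best;
    }
};
```

Both operations walk one root-to-leaf path, so each costs O(17).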

### Explanation

Flatten the tree using Heavy Light Decomposition (henceforth referred to as HLD). After flattening, every subtree can be represented as a single contiguous range (link) and every path can be represented by at most O(lg N) contiguous ranges. Build a segment tree on the flattened tree.
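A compact sketch of the flattening is given below, under the assumption that recursion depth is not a problem for the input sizes involved. The struct and its member names are illustrative and differ from the ones used in the solutions; `in[v]` and `out[v]` are 1-based positions in the flattened order.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

struct Hld {
    int n, timer = 0;
    std::vector<std::vector<int>> adj;
    std::vector<int> parent, depth, heavy, head, in, out, sz;

    explicit Hld(int n) : n(n), adj(n + 1), parent(n + 1, 0), depth(n + 1, 0),
        heavy(n + 1, 0), head(n + 1, 0), in(n + 1, 0), out(n + 1, 0), sz(n + 1, 1) {}

    void addEdge(int u, int v) { adj[u].push_back(v); adj[v].push_back(u); }

    void dfsSize(int v, int p) {            // subtree sizes and heavy children
        parent[v] = p;
        for (int u : adj[v]) if (u != p) {
            depth[u] = depth[v] + 1;
            dfsSize(u, v);
            sz[v] += sz[u];
            if (!heavy[v] || sz[u] > sz[heavy[v]]) heavy[v] = u;
        }
    }
    void dfsOrder(int v, int h) {           // assign positions, heavy child first
        head[v] = h;
        in[v] = ++timer;
        if (heavy[v]) dfsOrder(heavy[v], h);
        for (int u : adj[v])
            if (u != parent[v] && u != heavy[v]) dfsOrder(u, u);
        out[v] = timer;                     // subtree(v) occupies positions [in[v], out[v]]
    }
    void build(int root = 1) { dfsSize(root, 0); dfsOrder(root, root); }

    // Positions covered by the path a..b, as O(lg N) inclusive ranges.
    std::vector<std::pair<int, int>> pathRanges(int a, int b) const {
        std::vector<std::pair<int, int>> ranges;
        while (head[a] != head[b]) {
            if (depth[head[a]] < depth[head[b]]) std::swap(a, b);
            ranges.push_back({in[head[a]], in[a]});
            a = parent[head[a]];
        }
        if (in[a] > in[b]) std::swap(a, b);
        ranges.push_back({in[a], in[b]});
        return ranges;
    }
};
```

Because the heavy child is always visited first, every heavy chain occupies consecutive positions, and because the order is still a DFS preorder, every subtree does too.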

Let’s start with the easy part and handle the query operations. Every query operation asks for the answer in a single contiguous range. To solve this, maintain a trie at each node of the segment tree. The trie should contain all the values present in the range covered by that node (think of a merge sort tree, but with a trie instead of a sorted array at each node). A range query on a segment tree decomposes into O(lg N) nodes, and each node (trie) takes O(17) time to compute its maximum XOR value. Therefore, the time complexity of each query operation is O(lg N * 17).
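The structure described above might be sketched as follows. This version keeps all trie nodes in one shared pool, similar in spirit to the global arrays of the C++ solution, but the names (`TriePool`, `XorSegTree`, `assign`) and the 0-based positions are assumptions of this sketch.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <vector>

constexpr int BITS = 17;

// Shared pool of trie nodes; node 0 is a dummy "missing child" whose count stays 0.
struct TriePool {
    std::vector<std::array<int, 2>> child;
    std::vector<int> count;
    TriePool() { newNode(); }
    int newNode() { child.push_back({0, 0}); count.push_back(0); return static_cast<int>(child.size()) - 1; }
    void add(int root, int value, int delta) {
        int node = root;
        count[node] += delta;
        for (int b = BITS - 1; b >= 0; --b) {
            int bit = (value >> b) & 1;
            if (!child[node][bit]) { int fresh = newNode(); child[node][bit] = fresh; }
            node = child[node][bit];
            count[node] += delta;
        }
    }
    int maxXor(int root, int x) const {
        if (count[root] == 0) return 0;
        int node = root, best = 0;
        for (int b = BITS - 1; b >= 0; --b) {
            int want = ((x >> b) & 1) ^ 1;
            if (count[child[node][want]] > 0) { best |= 1 << b; node = child[node][want]; }
            else node = child[node][want ^ 1];
        }
        return best;
    }
};

// Segment tree over positions 0..n-1; node k owns the trie rooted at root[k].
struct XorSegTree {
    int n;
    TriePool pool;
    std::vector<int> root, current;
    explicit XorSegTree(const std::vector<int>& values)
        : n(static_cast<int>(values.size())), root(4 * values.size() + 4), current(values) {
        for (int& r : root) r = pool.newNode();
        for (int i = 0; i < n; ++i) walk(1, 0, n - 1, i, values[i], +1);
    }
    // Adds (delta = +1) or removes (delta = -1) value in every node covering pos.
    void walk(int node, int lo, int hi, int pos, int value, int delta) {
        pool.add(root[node], value, delta);
        if (lo == hi) return;
        int mid = (lo + hi) / 2;
        if (pos <= mid) walk(2 * node, lo, mid, pos, value, delta);
        else walk(2 * node + 1, mid + 1, hi, pos, value, delta);
    }
    void assign(int pos, int value) {   // point update: O(lg N) tries, O(17) each
        walk(1, 0, n - 1, pos, current[pos], -1);
        walk(1, 0, n - 1, pos, value, +1);
        current[pos] = value;
    }
    int query(int node, int lo, int hi, int l, int r, int x) const {
        if (r < lo || hi < l) return 0;
        if (l <= lo && hi <= r) return pool.maxXor(root[node], x);
        int mid = (lo + hi) / 2;
        return std::max(query(2 * node, lo, mid, l, r, x), query(2 * node + 1, mid + 1, hi, l, r, x));
    }
    int maxXor(int l, int r, int x) const { return query(1, 0, n - 1, l, r, x); }
};
```

Memory is the price of this design: every value lives in O(lg N) tries of depth 17.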

Next comes the update part. As argued earlier, we only need to find the indexes (or tree nodes) which will actually be changed by a particular update operation. We can then naively iterate over all those indexes and update them in the required tries. To find the affected values, we can maintain 17 separate TreeSets (set in C++). The i-th TreeSet stores the indexes whose i-th bit is not set. E.g. if the value at index 7 is 1101 (in binary), then, counting bits from 1 at the least significant end, the TreeSets numbered (2, 5, 6, 7, 8, … 17) will store the index 7. Now suppose you want to apply a bitwise OR with the value v to the range [l, r]. To do that, iterate over all the set bit positions of v and look up the indexes in the range [l, r] in the corresponding TreeSet; each index found gets the bit set and is removed from that TreeSet. There will only be O(N * 17) such changes overall (as argued earlier). Every change requires us to update the value in O(lg N) tries (point update on the segment tree) and each update on a trie takes O(17) time. Therefore, the overall complexity will be O(N * 17 * lg N * 17).
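The per-bit sets can be sketched on a plain array as shown below. The struct name and the choice to return the list of changed positions are assumptions made for illustration; in the full solution those positions are the ones that get a point update in the segment tree of tries. Bits are numbered from 0 here.

```cpp
#include <cassert>
#include <set>
#include <vector>

constexpr int BITS = 17;

struct OrTracker {
    std::vector<int> value;
    std::vector<std::set<int>> unset;   // unset[b] = positions whose bit b is still 0

    explicit OrTracker(const std::vector<int>& init) : value(init), unset(BITS) {
        for (int i = 0; i < static_cast<int>(init.size()); ++i)
            for (int b = 0; b < BITS; ++b)
                if (!((init[i] >> b) & 1)) unset[b].insert(i);
    }

    // ORs x into value[l..r] and returns, in increasing order, the positions that changed.
    std::vector<int> rangeOr(int l, int r, int x) {
        std::set<int> changed;
        for (int b = 0; b < BITS; ++b) {
            if (!((x >> b) & 1)) continue;
            auto it = unset[b].lower_bound(l);
            while (it != unset[b].end() && *it <= r) {
                value[*it] |= 1 << b;
                changed.insert(*it);
                it = unset[b].erase(it);   // each (position, bit) pair is erased at most once
            }
        }
        return std::vector<int>(changed.begin(), changed.end());
    }
};
```

Collecting the changed positions first, and touching the tries once per position rather than once per bit, is the same batching the C++ solution does with its `updates` map.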

### Deeper Analysis

Consider a case where every value is initially 0. Let’s analyze the time complexity of each update operation with the value v = 2^k - 1 (k is arbitrary and changes with every query).

• There are O(lg N) contiguous segments [l_i, r_i] since this is a path operation on HLD.
• Iterate over the set bits of v. This step takes O(k) time.
• Query the TreeSet of each set bit to find an index in the range [l, r]. This step takes at least O(lg N) even when the range contains no such index (lookup time for a TreeSet).

Therefore, in this case, the total time complexity turns out to be O(Q * lg N * k * lg N), which is not fast enough to pass the time limit.

To prevent this from happening, we can iterate over only those set bits of v that guarantee at least one change, i.e., bits for which at least one index of the corresponding TreeSet lies in the range [l, r]. This optimisation removes the wasted O(lg N) TreeSet lookups for bits that change nothing, so every remaining lookup is paid for by an actual bit change. The setter and tester solutions maintain a second segment tree, storing the bitwise OR of the complements of the values, to tell which bits are still missing somewhere in a range. There are other possible methods too. Kindly refer to the solution(s) for implementation details.
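One way to build such a filter is a segment tree over the complements of the values, combined with OR. The sketch below uses a bottom-up iterative layout and 0-based positions; these choices and the names are assumptions of this example and differ from both reference solutions.

```cpp
#include <cassert>
#include <vector>

constexpr int BITS = 17;
constexpr int FULL = (1 << BITS) - 1;

// Leaf i holds FULL ^ value[i]; an inner node holds the OR of its children.
struct MissingBits {
    int n;
    std::vector<int> tree;

    explicit MissingBits(const std::vector<int>& values)
        : n(static_cast<int>(values.size())), tree(2 * values.size(), 0) {
        for (int i = 0; i < n; ++i) tree[n + i] = FULL ^ values[i];
        for (int i = n - 1; i >= 1; --i) tree[i] = tree[2 * i] | tree[2 * i + 1];
    }

    void setValue(int pos, int value) {
        int i = pos + n;
        tree[i] = FULL ^ value;
        for (i /= 2; i >= 1; i /= 2) tree[i] = tree[2 * i] | tree[2 * i + 1];
    }

    // Bit b of the result is 1 iff some position in [l, r] does not have bit b yet.
    int query(int l, int r) const {
        int res = 0;
        for (int lo = l + n, hi = r + n + 1; lo < hi; lo /= 2, hi /= 2) {
            if (lo & 1) res |= tree[lo++];
            if (hi & 1) res |= tree[--hi];
        }
        return res;
    }
};
```

Before scanning the per-bit sets for an update with value v on [l, r], mask it as `v & query(l, r)`; only bits that survive the mask can change anything, so only those are worth a TreeSet lookup.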

### Solutions

C++ Solution :

``````cpp
#include <bits/stdc++.h>
using namespace std;

#define ll long long
#define pii pair<int, int>
#define F first
#define S second
#define all(c) ((c).begin()), ((c).end())
#define sz(x) ((int)(x).size())
#define ld long double

template<class T,class U>
ostream& operator<<(ostream& os,const pair<T,U>& p){
os<<"("<<p.first<<", "<<p.second<<")";
return os;
}

template<class T>
ostream& operator <<(ostream& os,const vector<T>& v){
os<<"{";
for(int i = 0;i < (int)v.size(); i++){
if(i)os<<", ";
os<<v[i];
}
os<<"}";
return os;
}

#ifdef LOCAL
#define cerr cout
#else
#endif

#define TRACE

#ifdef TRACE
#define trace(...) __f(#__VA_ARGS__, __VA_ARGS__)
template <typename Arg1>
void __f(const char* name, Arg1&& arg1){
cerr << name << " : " << arg1 << std::endl;
}
template <typename Arg1, typename... Args>
void __f(const char* names, Arg1&& arg1, Args&&... args){
const char* comma = strchr(names + 1, ',');cerr.write(names, comma - names) << " : " << arg1<<" | ";__f(comma+1, args...);
}
#else
#define trace(...)
#endif

long long readInt(long long l,long long r,char endd){
long long x=0;
int cnt=0;
int fi=-1;
bool is_neg=false;
while(true){
char g=getchar();
if(g=='-'){
assert(fi==-1);
is_neg=true;
continue;
}
if('0'<=g && g<='9'){
x*=10;
x+=g-'0';
if(cnt==0){
fi=g-'0';
}
cnt++;
assert(fi!=0 || cnt==1);
assert(fi!=0 || is_neg==false);

assert(!(cnt>19 || ( cnt==19 && fi>1) ));
} else if(g==endd){
if(is_neg){
x= -x;
}
if(!(l<=x && x<=r))cerr<<l<<"<="<<x<<"<="<<r<<endl;
assert(l<=x && x<=r);
return x;
} else {
assert(false);
}
}
}

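// Header missing from this listing; the name readString is assumed by analogy with readInt.
string readString(int l,int r,char endd){
string ret="";
int cnt=0;
while(true){
char g=getchar();
assert(g!=-1);
if(g==endd){
break;
}
cnt++;
ret+=g;
}
assert(l<=cnt && cnt<=r);
return ret;
}
// Reads an integer in [l, r] that must be followed by a single space.
long long readIntSp(long long l,long long r){
return readInt(l,r,' ');
}
// Reads an integer in [l, r] that must be followed by a newline.
long long readIntLn(long long l,long long r){
return readInt(l,r,'\n');
}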

const int N = 1 << 15, logN = 15;
int size[N], nxt[N], in[N], rin[N], out[N], timer;
int par[logN][N], depth[N];
vector<int> g[N];

void dfs_sz(int v = 1, int p = 0){
depth[v] = depth[p] + 1;
par[0][v] = p;
size[v] = 1;
nxt[v] = v;
for(auto &u: g[v]){
if(u == p) continue;
dfs_sz(u, v);
size[v] += size[u];
if(size[u] >= size[g[v][0]])
swap(u, g[v][0]);
}
}

void dfs_hld(int v = 1, int p = 0){
in[v] = ++timer;
rin[in[v]] = v;
for(auto u: g[v]){
if(u == p) continue;
nxt[u] = (u == g[v][0] ? nxt[v] : u);
dfs_hld(u, v);
}
out[v] = timer;
}

int lca(int a, int b){
if(depth[a]<depth[b])
swap(a,b);
int l = depth[a]-depth[b];
for(int i = 0;i<logN;i++) if(l&(1<<i)) a = par[i][a];
if(a==b) return a;
assert(depth[a] == depth[b]);
for(int i = logN-1;i>=0;i--)
if(par[i][a]!=par[i][b])
a = par[i][a],b = par[i][b];
return par[0][a];
}

const int LN = 17;
const int SN = 3e7;
int cur;
int lft[SN];
int rgt[SN];
int val[SN];

void clr(int n){
for(int i = 0; i <= n; i++){
size[i] = nxt[i] = in[i] = rin[i] = out[i] = timer = 0;
g[i].clear();
}
}

struct trie{
int root;
trie(){root = ++cur;}
void insert(int num){
int node = root;
val[node]++;
for(int i = LN - 1 ; i >= 0 ; --i){
if(num & (1 << i)){
if(!rgt[node]){
rgt[node] = ++cur;
}
node = rgt[node];
}
else{
if(!lft[node]){
lft[node] = ++cur;
}
node = lft[node];
}
val[node]++;
}
}
void remove(int num){
int node = root;
val[node]--;
for(int i = LN - 1 ; i >= 0 ; --i){
if(num & (1 << i)){
assert(rgt[node]);
node = rgt[node];
}
else{
assert(lft[node]);
node = lft[node];
}
val[node]--;
}
}
int query(int num){ // maximum xor
int node = root;
if(!val[node]) return 0;
int res = 0;
for(int i = LN - 1 ; i >= 0 ; --i){
if(num & (1 << i)){
if(val[lft[node]]){
res += 1 << i;
node = lft[node];
}
else{
node = rgt[node];
}
}
else{
if(val[rgt[node]]){
res += 1 << i;
node = rgt[node];
}
else{
node = lft[node];
}
}
}
return res;
}
};

// 0-indexed
struct segtree{
int n;
vector<trie> t;
vector<int> t2;
vector<int> curr;
segtree(int n) : n(n), curr(n + 1), t(4 * n + 10), t2(4 * n + 10, (1 << LN) - 1){
}
void update(int i, int v, int orig, int s, int e, int ind){
if(i > e || i < s) return;
if(!orig){
t[ind].remove(curr[i]);
}
t[ind].insert(v);
if(s == e){
t2[ind] = ((1 << LN) - 1) ^ v;
return;
}
int mid = (s + e) >> 1;
update(i, v, orig, s, mid, ind << 1);
update(i, v, orig, mid + 1, e, ind << 1 | 1);
t2[ind] = t2[ind << 1] | t2[ind << 1 ^ 1];
}
void update(int i, int v, int orig){
assert(i >= 1 && i <= n);
update(i, v, orig, 1, n, 1);
curr[i] = v;
}
int get(int l, int r, int x, int s, int e, int ind){
if(l > e || s > r) return 0;
if(s >= l && e <= r) return t[ind].query(x);
int mid = (s + e) >> 1;
return max(get(l, r, x, s, mid, ind << 1), get(l, r, x, mid + 1, e, ind << 1 | 1));
}
int get_missing_bits(int l, int r, int s, int e, int ind){
if(l > e || s > r) return 0;
if(s >= l && e <= r) return t2[ind];
int mid = (s + e) >> 1;
return get_missing_bits(l, r, s, mid, ind << 1) | get_missing_bits(l, r, mid + 1, e, ind << 1 | 1);
}

int get_missing_bits(int l, int r){
return get_missing_bits(l, r, 1, n, 1);
}

int get(int l, int r, int x){
return get(l, r, x, 1, n, 1);
}
};

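int main(){
int sn = 0, sq = 0;
// The statements reading t, n, q and the edges are missing from this listing.
// The bounds passed to the readers here are assumptions derived from the final assert.
int t = readIntLn(1, 30000);
while(t--){
int n = readIntSp(1, 30000), q = readIntLn(1, 500000);
sn += n; sq += q;
vector<int> a(n + 1);
for(int i = 1; i <= n; i++) a[i] = i == n ? readIntLn(0, (1 << LN) - 1) : readIntSp(0, (1 << LN) - 1);
for(int i = 1; i < n; i++){
int u = readIntSp(1, n), v = readIntLn(1, n);
assert(u != v);
g[u].push_back(v);
g[v].push_back(u);
}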
dfs_sz();
dfs_hld();
for(int i = 1; i < logN; i++) for(int j = 1; j <= n; j++) par[i][j] = par[i - 1][par[i - 1][j]];
vector<int> b(n + 1);
segtree st(n);
vector<set<int>> remaining(LN);
for(int i = 1; i <= n; i++){
b[i] = a[rin[i]];
st.update(i, b[i], 1);
for(int j = 0; j < LN; j++) if(!(b[i] >> j & 1)) remaining[j].insert(i);
}
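function<void(int, int, int)> update = [&](int l, int r, int x){
map<int, int> updates; // position -> new value; declaration missing from this listing, type inferred from its use below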
x &= st.get_missing_bits(l, r);
for(int i = 0; i < LN; i++) if(x >> i & 1){
auto it = remaining[i].lower_bound(l);
while(it != remaining[i].end() && *it <= r){
int pos = *it;
updates[pos] = (b[pos] = b[pos] | (1 << i));
remaining[i].erase(it++);
}
}
for(auto it : updates) st.update(it.F, it.S, 0);
};

function<void(int, int, int)> add = [&](int a, int l, int x){
while(depth[a] >= depth[l]){
int na = nxt[a];
int Q = in[a];
int P = depth[na] >= depth[l] + 1 ? in[na] : in[l];
update(P, Q, x);
a = par[0][na];
}
};

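while(q--){
// The reads of type, u, v and k are missing from this listing; their bounds are assumptions.
int type = readIntSp(1, 2);
if(type == 1){
int u = readIntSp(1, n), v = readIntSp(1, n);
int x = readIntLn(0, (1 << LN) - 1);
int l = lca(u, v);
add(u, l, x); // OR x onto the path u -> lca
add(v, l, x); // OR x onto the path v -> lca
} else{
int k = readIntSp(1, n);
int x = readIntLn(0, (1 << LN) - 1);
printf("%d\n", st.get(in[k], out[k], x));
}
}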
clr(n);
}
assert(sn <= 30000 && sq <= 500000);
}
``````

Java Solution :

``````java
import java.io.*;
import java.util.*;

public class Main implements Runnable {
@Override
public void run() {
try {
new Solver().solve();
System.exit(0);
} catch (Exception | Error e) {
e.printStackTrace();
System.exit(1);
}
}

public static void main(String[] args) throws Exception {
//new Thread(null, new Main(), "Solver", 1l << 25).start();
new Main().run();
}
}

class Solver {
final FastIO hp;

Solver() {
hp = new FastIO();
hp.initIO(System.in, System.out);
}

void solve() throws Exception {
int tc = TESTCASES ? hp.nextInt() : 1;
for (int tce = 1; tce <= tc; ++tce) solve(tce);
hp.flush();
}

boolean TESTCASES = true;

static final int BITLEN = 17;
int N, Q;
int[] A;
ArrayList<Integer>[] graph;

void solve(int tc) throws Exception {
int i, j, k;

N = hp.nextInt(); Q = hp.nextInt();
A = hp.getIntArray(N);

graph = new ArrayList[N];
for (i = 0; i < N; ++i) graph[i] = new ArrayList<>();
for (i = 1; i < N; ++i) {
int a = hp.nextInt() - 1, b = hp.nextInt() - 1;
// Store the undirected edge in both adjacency lists.
graph[a].add(b);
graph[b].add(a);
}

HLD_LCA hld = new HLD_LCA(graph, N, 0, A);

for (i = 0; i < Q; ++i) {
int choice = hp.nextInt();

if (choice == 1) {
int u = hp.nextInt() - 1, v = hp.nextInt() - 1, o = hp.nextInt();
hld.pathUpdate(u, v, o);
} else if (choice == 2) {
int node = hp.nextInt() - 1, x = hp.nextInt();
hp.println(hld.subtreeQuery(node, x));
}
}
}
}

class FastIO {
static final int BUFSIZE = 1 << 20;
static byte[] buf;
static int index, total;
static InputStream in;
static BufferedWriter bw;

public void initIO(InputStream is, OutputStream os) {
try {
in = is;
bw = new BufferedWriter(new OutputStreamWriter(os));
buf = new byte[BUFSIZE];
} catch (Exception e) {
}
}

public void initIO(String inputFile, String outputFile) {
try {
in = new FileInputStream(inputFile);
bw = new BufferedWriter(new OutputStreamWriter(
new FileOutputStream(outputFile)));
buf = new byte[BUFSIZE];
} catch (Exception e) {
}
}

private int scan() throws Exception {
if (index >= total) {
index = 0;
total = in.read(buf); // refill the buffer; returns -1 at end of input
if (total <= 0)
return -1;
}
return buf[index++];
}

public String next() throws Exception {
int c;
for (c = scan(); c <= 32; c = scan()) ;
StringBuilder sb = new StringBuilder();
for (; c > 32; c = scan())
sb.append((char) c);
return sb.toString();
}

public int nextInt() throws Exception {
int c, val = 0;
for (c = scan(); c <= 32; c = scan()) ;
boolean neg = c == '-';
if (c == '-' || c == '+')
c = scan();
for (; c >= '0' && c <= '9'; c = scan())
val = (val << 3) + (val << 1) + (c & 15);
return neg ? -val : val;
}

public long nextLong() throws Exception {
int c;
long val = 0;
for (c = scan(); c <= 32; c = scan()) ;
boolean neg = c == '-';
if (c == '-' || c == '+')
c = scan();
for (; c >= '0' && c <= '9'; c = scan())
val = (val << 3) + (val << 1) + (c & 15);
return neg ? -val : val;
}

public long[] getLongArray(int size) throws Exception {
long[] ar = new long[size];
for (int i = 0; i < size; ++i) ar[i] = nextLong();
return ar;
}

public int[] getIntArray(int size) throws Exception {
int[] ar = new int[size];
for (int i = 0; i < size; ++i) ar[i] = nextInt();
return ar;
}

public String[] getStringArray(int size) throws Exception {
String[] ar = new String[size];
for (int i = 0; i < size; ++i) ar[i] = next();
return ar;
}

public void print(Object a) throws Exception {
bw.write(a.toString());
}

public void printsp(Object a) throws Exception {
print(a);
print(" ");
}

public void println() throws Exception {
bw.write("\n");
}

public void println(Object a) throws Exception {
print(a);
println();
}

public void flush() throws Exception {
bw.flush();
}
}

class HLD_LCA {
static final int BITLEN = Solver.BITLEN;

ArrayList<Integer>[] graph;
int[] depth, parent, chCount, queue;
int N, root;
int[] weight;

SegmentTree st;
int[] treePos, linearTree, segRoot;
static SegmentTreeOR segTOR;

HLD_LCA(ArrayList<Integer>[] g, int n, int r, int[] wt) {
graph = g;
N = n;
root = r;
weight = wt;
iterativeDFS();

HLDify();
}

private void iterativeDFS() {
parent = new int[N];
depth = new int[N];
chCount = new int[N];
queue = new int[N];
Arrays.fill(chCount, 1);

int i, st = 0, end = 0;
parent[root] = -1;
depth[root] = 1;
queue[end++] = root;

while (st < end) {
int node = queue[st++], h = depth[node] + 1;
Iterator<Integer> itr = graph[node].iterator();
while (itr.hasNext()) {
int ch = itr.next();
if (depth[ch] > 0) continue;
depth[ch] = h;
parent[ch] = node;
queue[end++] = ch;
}
}
for (i = N - 1; i >= 0; --i)
if (queue[i] != root)
chCount[parent[queue[i]]] += chCount[queue[i]];
}

private void HLDify() {
int i, j, treeRoot = -7;

treePos = new int[N];
linearTree = new int[N];
segRoot = new int[N];

Stack<Integer> stack = new Stack<>();
stack.ensureCapacity(N << 1);
stack.push(root);
for (i = 0; !stack.isEmpty(); ++i) {
int node = stack.pop();
if (i == 0 || linearTree[i - 1] != parent[node])
treeRoot = node;
linearTree[i] = node;
treePos[node] = i;
segRoot[node] = treeRoot;

int bigChild = -7, bigChildPos = -7, lastPos = graph[node].size() - 1;
for (j = 0; j < graph[node].size(); ++j) {
int tempNode = graph[node].get(j);
if (tempNode == parent[node]) continue;
if (bigChild < 0 || chCount[bigChild] < chCount[tempNode]) {
bigChild = tempNode;
bigChildPos = j;
}
}
if (bigChildPos >= 0) {
int temp = graph[node].get(lastPos);
graph[node].set(lastPos, bigChild);
graph[node].set(bigChildPos, temp);
}

for (int itr : graph[node])
if (parent[node] != itr)
stack.push(itr);
}

int[] respectiveWeights = new int[N], complements = new int[N];
for (i = 0; i < N; ++i) {
respectiveWeights[i] = weight[linearTree[i]];
complements[i] = (1 << BITLEN) - 1 ^ weight[linearTree[i]];
}

segTOR = new SegmentTreeOR(complements);
st = new SegmentTree(respectiveWeights);
}

void pathUpdate(int node1, int node2, int value) {
while (segRoot[node1] != segRoot[node2]) {
if (depth[segRoot[node1]] > depth[segRoot[node2]]) {
node1 ^= node2;
node2 ^= node1;
node1 ^= node2;
}
st.rangeUpdate(treePos[segRoot[node2]], treePos[node2], value);
node2 = parent[segRoot[node2]];
}
if (treePos[node1] > treePos[node2]) {
node1 ^= node2;
node2 ^= node1;
node1 ^= node2;
}
st.rangeUpdate(treePos[node1], treePos[node2], value);     // ...treePos[node1] + 1... for Edge Update
}

int subtreeQuery(int node, int key) {
int pos = treePos[node];
return st.rangeQuery(pos, pos + chCount[node] - 1, key); // ...(pos + 1,... for Edge Query
}
}

class Trie {
private final static int BITLEN = Solver.BITLEN;
private TrieNode root = new TrieNode();

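// Method header missing from this listing; the name addValue is assumed by analogy with removeValue.
void addValue(int val) {
TrieNode curr = root;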

for (int i = BITLEN - 1; i >= 0; --i) {
if (((val >> i) & 1) == 0) {
if (curr.left == null) curr.left = new TrieNode();
curr = curr.left;
} else {
if (curr.right == null) curr.right = new TrieNode();
curr = curr.right;
}
++curr.count;
}
}

void removeValue(int val) {
TrieNode curr = root;

for (int i = BITLEN - 1; i >= 0 && curr != null; --i) {
if (((val >> i) & 1) == 0) {
if (--curr.left.count <= 0) curr.left = null;
curr = curr.left;
} else {
if (--curr.right.count <= 0) curr.right = null;
curr = curr.right;
}
}
}

int findMaxXOR(int val) {
int maxXor = 0;
TrieNode curr = root;

for (int i = BITLEN - 1; i >= 0; --i) {
if (((val >> i) & 1) == 0) {
if (curr.right != null) {
curr = curr.right;
maxXor |= 1 << i;
} else {
curr = curr.left;
}
} else {
if (curr.left != null) {
curr = curr.left;
maxXor |= 1 << i;
} else {
curr = curr.right;
}
}
}

return maxXor;
}
}

class TrieNode {
int count;
TrieNode left, right;
}

class SegmentTree {
private final static int BITLEN = Solver.BITLEN;
private SegmentTreeOR segTOR = HLD_LCA.segTOR;
private final int[] A;
private int N;
private Trie[] tree;
private BitSet[] indexes;

public SegmentTree(int[] ar) {
A = ar;

indexes = new BitSet[BITLEN];
for (int i = 0; i < BITLEN; ++i) {
indexes[i] = new BitSet();
}

N = 1; while (N < ar.length) N <<= 1;
tree = new Trie[N << 1];
for (int i = 1; i < tree.length; ++i) tree[i] = new Trie();
for (int i = 0; i < ar.length; ++i) addValueAtIndex(i, ar[i]);
}

public void addValueAtIndex(int idx, int val) {
for (int i = 0; i < BITLEN; ++i) if (((val >> i) & 1) == 0) {
indexes[i].set(idx);
}

idx += N;
while (idx > 0) {
tree[idx].addValue(val); // counterpart of removeValue below; method name assumed
idx >>= 1;
}
}

public void removeValueAtIndex(int idx, int val) {
for (int i = 0; i < BITLEN; ++i) if (((val >> i) & 1) == 0) {
indexes[i].clear(idx);
}

idx += N;
while (idx > 0) {
tree[idx].removeValue(val);
idx >>= 1;
}
}

public void rangeUpdate(int l, int r, int orVal) {
int rangeOr = segTOR.rangeQuery(l, r);
int reqBits = rangeOr & orVal;
for (int i = 0; i < BITLEN; ++i) if (((reqBits >> i) & 1) > 0) {
while (true) {
int next = indexes[i].nextSetBit(l);
if (next < 0 || next > r) break;

removeValueAtIndex(next, A[next]);
A[next] |= orVal;
addValueAtIndex(next, A[next]); // re-insert the updated value into the tries and bit sets

segTOR.pointUpdate(next, (1 << BITLEN) - 1 ^ A[next]);
}
}
}

public int rangeQuery(int l, int r, int xorWith) {
return rangeQuery(1, 0, N - 1, l, r, xorWith);
}

private int rangeQuery(int idx, int l, int r, int ql, int qr, int xorWith) {
if (l > qr || r < ql) {
return -7;
} else if (l >= ql && r <= qr) {
return tree[idx].findMaxXOR(xorWith);
} else {
int c1 = idx << 1, c2 = c1 | 1, mid = l + r >> 1;
return Math.max(rangeQuery(c1, l, mid, ql, qr, xorWith),
rangeQuery(c2, mid + 1, r, ql, qr, xorWith));
}
}
}
class SegmentTreeOR {
private int N;
private int[] tree;

public SegmentTreeOR(int[] ar) {
N = 1;
while (N < ar.length) N <<= 1;
tree = new int[N * 2 - 1];

for (int i = 0; i < ar.length; ++i) tree[i + N - 1] = ar[i];
for (int i = N - 2; i >= 0; --i) tree[i] = tree[i * 2 + 1] | tree[i * 2 + 2];
}

public void pointUpdate(int i, int val) {
i += N - 1;
tree[i] = val;
i = i - 1 >> 1;

while (i >= 0) {
tree[i] = tree[i * 2 + 1] | tree[i * 2 + 2];
i = i - 1 >> 1;
}
}

public int rangeQuery(int l, int r) {
return query(0, 0, N - 1, l, r);
}

private int query(int i, int l, int r, int ql, int qr) {
int mid = l + r >> 1, i2 = i << 1;
if (l > qr || r < ql) return 0;
else if (l >= ql && r <= qr) return tree[i];
else {
return query(i2 + 1, l, mid, ql, qr) | query(i2 + 2, mid + 1, r, ql, qr);
}
}
}
``````