Problem Link
Author and Editorialist: Sarthak Manna
Tester: Jatin Yadav, Jatin Nagpal, Avijit Agarwal
Difficulty
Medium-Hard
Prerequisites
Tree flattening using DFS, Segment Tree, Trie, Implementation
Problem
Given a tree rooted at node 1 whose every node has a value associated with it, implement a technique to support the following types of operations:
- Type 1: Given two nodes a and b and a value x, replace the value of every node on the simple path from a to b with its bitwise OR with x.
- Type 2: Given a node u and a value y, find the maximum of (y XOR value of w) over all nodes w in the subtree of u.
Quick Explanation
Flatten the tree using a dfs traversal. After this, the subtree of any node can be represented by a contiguous range on the flattened tree.
It is important to note that the total number of bits set by update operations over the whole run can never exceed O(N * 17) [17 = bit length of the values]. A bitwise OR can only set a bit that wasn't already set, and a set bit is never cleared again. This observation is important because, for every update, we can now iterate over only those values that will actually change.
Finding the maximum possible XOR value is a rather standard problem and can be solved by maintaining tries. You can study and practice the technique at link1 and link2
Explanation
Flatten the tree using dfs traversal as explained above. After flattening, every subtree can be represented by a single contiguous range. Then, build a segment tree on the flattened tree.
Let's handle the easy part first: the query operations. Every query asks for an answer over a contiguous range. To solve this, maintain a trie at each node of the segment tree. The trie should contain all the values present in the covered range (think of a merge sort tree, but with a trie instead of a sorted array at each node). A range query on a segment tree visits O(lg N) nodes, and each node's trie computes its maximum XOR value in O(17) time. Therefore, the time complexity of each query operation is O(lg N * 17).
Next comes the update part. As argued earlier, we only need to find the indexes (tree nodes) that are actually affected by a particular update operation. We can then naively iterate over all those nodes and update them in the required tries.

To find the affected values, maintain 17 separate DSU (disjoint set union) structures. In the j-th DSU, the representative of node i is its nearest ancestor (possibly i itself) whose j-th bit is 0. Whenever a node's j-th bit gets set, merge the node into its parent's set in the j-th DSU. With path compression, each lookup jumps directly to the next node that still needs updating. Refer to the solution(s) given below for implementation details.

Coming to the time complexity: each successful bit update modifies O(lg N) tries (a point update in the segment tree), and each trie update takes O(17) time (the height of a trie). There are O(N * 17) successful bit updates overall. Therefore, the total time spent on updates is O(N * 17 * lg N * 17).
Time Complexity
The overall time complexity of the solution is intended to be O(Q * lg N * 17 + N * 17 * lg N * 17).
Solutions
C++ Solution :
#include <bits/stdc++.h>
using namespace std;
// Shorthand for passing a whole container as an iterator range.
#define all(c) ((c).begin()), ((c).end())
// N: capacity of the per-node global arrays (node ids 1..N-1).
// logN: number of binary-lifting levels stored in par[][].
// NOTE(review): assumes n < 2^15 — confirm against the problem constraints.
const int N = 1 << 15, logN = 15;
// Euler-tour data filled by dfs_sz:
//   in[v]  = entry time of v,
//   rin[t] = node whose entry time is t,
//   out[v] = largest entry time inside v's subtree,
//   timer  = last entry time handed out.
int in[N], rin[N], out[N], timer;
// par[k][v] = 2^k-th ancestor of v (0 above the root); depth[root] = 1, depth[0] = 0.
int par[logN][N], depth[N];
// Adjacency lists of the tree (both directions of every edge).
vector<int> g[N];
// Disjoint-set union over elements 0..n with full path compression (no rank).
// merge(x, y) always hangs x's representative under y's representative, so the
// caller decides which side ends up as the representative.
struct dsu{
    int n;
    vector<int> par;
    dsu(){}
    dsu(int n) : n(n), par(n + 1){
        // Every element starts out as a singleton set.
        for(int i = 0; i <= n; ++i) par[i] = i;
    }
    // Returns the representative of x and re-points every element visited on
    // the way directly at that representative.
    int root(int x){
        int top = x;
        while(par[top] != top) top = par[top];
        while(par[x] != top){
            int next = par[x];
            par[x] = top;
            x = next;
        }
        return top;
    }
    // Joins the sets containing x and y; returns false if they were already joined.
    bool merge(int x, int y){
        int rx = root(x);
        int ry = root(y);
        if(rx == ry) return false;
        par[rx] = ry;
        return true;
    }
};
// Recursive DFS from v (with parent p). It records depth, the immediate parent
// and Euler-tour entry/exit times, so that the subtree of v occupies the
// contiguous entry-time range [in[v], out[v]].
void dfs_sz(int v = 1, int p = 0){
    in[v] = ++timer;
    rin[timer] = v;
    par[0][v] = p;
    depth[v] = 1 + depth[p];
    for(size_t idx = 0; idx < g[v].size(); ++idx){
        int child = g[v][idx];
        if(child != p) dfs_sz(child, v);
    }
    out[v] = timer;
}
// Lowest common ancestor of a and b using the binary-lifting table par[][].
int lca(int a, int b){
    if(depth[b] > depth[a]) swap(a, b);
    // Lift the deeper node a up to b's depth, one power of two at a time.
    int diff = depth[a] - depth[b];
    for(int k = 0; k < logN; ++k){
        if(diff >> k & 1) a = par[k][a];
    }
    if(a == b) return a;
    assert(depth[a] == depth[b]);
    // Raise both nodes together while their 2^k-th ancestors still differ.
    for(int k = logN - 1; k >= 0; --k){
        if(par[k][a] == par[k][b]) continue;
        a = par[k][a];
        b = par[k][b];
    }
    return par[0][a];
}
// LN: number of value bits stored in each trie (bits LN-1..0).
// NOTE(review): values are assumed to be < 2^17 — confirm against constraints.
const int LN = 17;
// SN: capacity of the global trie-node pool below.
const int SN = 3e7;
// cur: index of the most recently allocated trie node. Index 0 is never handed
// out and acts as the "absent child" marker (its val stays 0).
int cur;
// Trie-node pool: lft/rgt = child on bit 0/1 (0 = no child),
// val = number of stored values whose path passes through the node.
int lft[SN];
int rgt[SN];
int val[SN];
void clr(int n){
for(int i = 0; i <= n; i++){
in[i] = rin[i] = out[i] = timer = 0;
g[i].clear();
}
}
// Binary trie over LN-bit values, stored in the global node pool (lft/rgt/val).
// Each trie owns the root node allocated by its constructor.
struct trie{
    int root;
    trie(){root = ++cur;}
    // Adds one copy of num, allocating any missing nodes along its path.
    void insert(int num){
        int node = root;
        ++val[node];
        for(int bit = LN - 1; bit >= 0; --bit){
            int &child = (num >> bit & 1) ? rgt[node] : lft[node];
            if(child == 0) child = ++cur;
            node = child;
            ++val[node];
        }
    }
    // Removes one copy of num; num must currently be stored.
    void remove(int num){
        int node = root;
        --val[node];
        for(int bit = LN - 1; bit >= 0; --bit){
            if(num >> bit & 1){
                assert(rgt[node]);
                node = rgt[node];
            } else {
                assert(lft[node]);
                node = lft[node];
            }
            --val[node];
        }
    }
    // Maximum of (num XOR v) over stored values v; 0 when the trie is empty.
    int query(int num){
        int node = root;
        if(val[node] == 0) return 0;
        int best = 0;
        for(int bit = LN - 1; bit >= 0; --bit){
            bool has = num >> bit & 1;
            // Greedily prefer the branch holding the opposite bit.
            int want = has ? lft[node] : rgt[node];
            int other = has ? rgt[node] : lft[node];
            if(val[want]){
                best |= 1 << bit;
                node = want;
            } else {
                node = other;
            }
        }
        return best;
    }
};
// Segment tree over positions 1..n (1-indexed). Every node keeps a trie of the
// values currently stored in its range, so a range query can return the
// maximum XOR with a key. curr[i] remembers the value stored at position i.
struct segtree{
    int n;
    vector<trie> t;
    vector<int> curr;
    segtree(int n) : n(n), t(4 * n + 10), curr(n + 1){
    }
    // Places value v at position i in every trie covering i, walking from the
    // node ind (covering [s, e]) down to the leaf. When orig is 0, the value
    // previously stored at i (curr[i]) is removed from each of those tries first.
    void update(int i, int v, int orig, int s, int e, int ind){
        if(i < s || i > e) return;
        for(;;){
            if(!orig) t[ind].remove(curr[i]);
            t[ind].insert(v);
            if(s == e) break;
            int mid = (s + e) / 2;
            if(i <= mid){
                e = mid;
                ind = 2 * ind;
            } else {
                s = mid + 1;
                ind = 2 * ind + 1;
            }
        }
    }
    // Stores v at position i (1..n); orig != 0 marks the first assignment of i.
    void update(int i, int v, int orig){
        assert(i >= 1 && i <= n);
        update(i, v, orig, 1, n, 1);
        curr[i] = v;
    }
    // Max of (x XOR value) over positions in [l, r] within node ind covering [s, e];
    // 0 for a disjoint node.
    int get(int l, int r, int x, int s, int e, int ind){
        if(r < s || e < l) return 0;
        if(l <= s && e <= r) return t[ind].query(x);
        int mid = (s + e) / 2;
        int leftBest = get(l, r, x, s, mid, 2 * ind);
        int rightBest = get(l, r, x, mid + 1, e, 2 * ind + 1);
        return max(leftBest, rightBest);
    }
    // Max of (x XOR value) over positions l..r.
    int get(int l, int r, int x){
        return get(l, r, x, 1, n, 1);
    }
};
// D[j]: one DSU per bit j. main merges a node into its parent's set once the
// node's bit j is 1, so D[j].root(v) reaches the nearest ancestor-or-self of v
// whose bit j is still 0. Exception: the tree root (node 1) is never merged,
// so it can be its own representative even when its bit j is already 1.
dsu D[LN];
// Reads t test cases. Each test case has n, q, the n node values and n-1 edges,
// followed by q queries:
//   "1 u v x": OR x into every node value on the path u..v;
//   "2 k x":   print the max of (x XOR value) over the subtree of k.
int main(){
cin.tie(0); ios_base::sync_with_stdio(0);
int t; cin >> t;
while(t--){
int n, q; cin >> n >> q;
vector<int> a(n + 1);
// Fresh per-bit DSUs over nodes 0..n.
for(int i = 0; i < LN; i++) D[i] = dsu(n);
for(int i = 1; i <= n; i++) cin >> a[i];
for(int i = 1; i < n; i++){
int u, v; cin >> u >> v;
g[u].push_back(v);
g[v].push_back(u);
}
dfs_sz();
// Binary-lifting table: par[i][j] = 2^i-th ancestor of j.
for(int i = 1; i < logN; i++) for(int j = 1; j <= n; j++) par[i][j] = par[i - 1][par[i - 1][j]];
segtree st(n);
// Visit positions in Euler order (a parent's position precedes its children's).
// Insert each node's value at its position. For every set bit j, merge the node
// into its parent's set in D[j], so lookups skip nodes whose bit j is already 1.
// The root (position 1) is never merged.
for(int i = 1; i <= n; i++){
st.update(i, a[rin[i]], 1);
for(int j = 0; j < LN; j++) if(a[rin[i]] >> j & 1){
if(i != 1) D[j].merge(rin[i], par[0][rin[i]]);
}
}
// Applies "OR x" to the vertical path from u up to its ancestor l (inclusive).
// For each bit i of x, D[i].root jumps straight to the next node on that path
// whose bit i is still 0. Each updated node is then merged away, so every
// (node, bit) pair is changed at most once over the whole test case.
function<void(int, int, int)> update = [&](int u, int l, int x){
for(int i = 0; i < LN; i++) if(x >> i & 1){
int node = u;
while(depth[node] >= depth[l]){
node = D[i].root(node);
if(depth[node] < depth[l]) break;
// Presumably reachable only for the never-merged root (node 1) whose bit
// is already set: nothing left to change on this path for bit i.
if(a[node] >> i & 1){
break;
}
// Replace the node's old value in the segment tree with the updated one,
// then retire the node for bit i by merging it into its parent's set.
st.update(in[node], a[node] |= (1 << i), 0);
if(node != 1) D[i].merge(node, par[0][node]);
}
}
};
while(q--){
int type; cin >> type;
int u, v, k, x;
if(type == 1){
cin >> u >> v >> x;
int l = lca(u, v);
// Both halves include l. The second call skips it, because its bits from x
// are already set by then.
update(u, l, x);
update(v, l, x);
} else{
cin >> k >> x;
// The subtree of k is exactly the position range [in[k], out[k]].
cout << st.get(in[k], out[k], x) << '\n';
}
}
clr(n);
}
}
Java Solution :
import java.io.*;
import java.util.*;
public class Main {
    /** Program entry point: delegates all reading, solving and printing to Solver. */
    public static void main(String[] args) throws Exception {
        Solver solver = new Solver();
        solver.solve();
    }
}
class Solver {
    final FastIO hp = new FastIO();

    /** Reads the test-case count (when TESTCASES is set), solves each case, then flushes output. */
    void solve() throws Exception {
        int tc = TESTCASES ? hp.nextInt() : 1;
        for (int tce = 1; tce <= tc; ++tce) solve(tce);
        hp.flush();
    }

    boolean TESTCASES = true;
    /** Number of value bits handled (values are assumed to fit in 17 bits). */
    final static int BITLEN = 17;

    /**
     * Solves one test case: reads N, Q, the node values and N-1 edges, then Q queries.
     * Type 1 "u v o": OR o into every node value on the path u..v.
     * Type 2 "node x": print the max of (x XOR value) over the subtree of node.
     */
    @SuppressWarnings("unchecked")
    void solve(int tc) throws Exception {
        int N = hp.nextInt(), Q = hp.nextInt();
        A = hp.getIntArray(N);
        ArrayList<Integer>[] graph = new ArrayList[N];
        for (int i = 0; i < N; ++i) graph[i] = new ArrayList<>();
        for (int i = 1; i < N; ++i) {
            int a = hp.nextInt() - 1, b = hp.nextInt() - 1;
            graph[a].add(b);
            graph[b].add(a);
        }
        TreeUtil util = new TreeUtil(graph, N, 0, A);
        depth = util.depth;
        // parent[j] is a per-bit copy of the parent array. getParent compresses it
        // so that walks skip nodes whose bit j is already set.
        parent = new int[BITLEN][];
        for (int i = 0; i < BITLEN; ++i) parent[i] = util.parent.clone();
        bits = new boolean[BITLEN][N];
        for (int i = 0; i < N; ++i) setBits(i);
        HashSet<Integer> set = new HashSet<>();
        for (int i = 0; i < Q; ++i) {
            int choice = hp.nextInt();
            if (choice == 1) {
                int u = hp.nextInt() - 1, v = hp.nextInt() - 1, o = hp.nextInt();
                int lcaDep = depth[util.getLCA(u, v)];
                for (int j = 0; j < BITLEN; ++j) if (((o >> j) & 1) > 0) {
                    collectUnset(u, lcaDep, j, set);
                    collectUnset(v, lcaDep, j, set);
                }
                // Path nodes not collected already contain every bit of o.
                for (int node : set) {
                    A[node] |= o;
                    setBits(node);
                    util.pointUpdate(node, o);
                }
                set.clear();
            } else if (choice == 2) {
                int node = hp.nextInt() - 1, x = hp.nextInt();
                hp.println(util.subtreeQuery(node, x));
            }
        }
    }

    int[] A, depth;
    int[][] parent;
    boolean[][] bits;

    /**
     * Adds to out every node whose bit bitPos is unset on the vertical path from
     * start up to (and including) depth lcaDep.
     */
    private void collectUnset(int start, int lcaDep, int bitPos, Set<Integer> out) {
        int x = getParent(start, bitPos);
        while (x >= 0 && depth[x] >= lcaDep) {
            out.add(x);
            x = getParent(parent[bitPos][x], bitPos);
        }
    }

    /**
     * Returns the nearest ancestor-or-self of node whose bit bitPos is unset, or a
     * negative value if there is none. Every node with that bit set on the way is
     * re-pointed directly at the result (path compression). Iterative, so very deep
     * trees cannot overflow the call stack.
     */
    int getParent(int node, final int bitPos) {
        int[] up = parent[bitPos];
        boolean[] isSet = bits[bitPos];
        int top = node;
        while (top >= 0 && isSet[top]) top = up[top];
        while (node >= 0 && isSet[node]) {
            int next = up[node];
            up[node] = top;
            node = next;
        }
        return top;
    }

    /** Marks in bits[][] every bit currently set in A[node]. */
    void setBits(int node) {
        for (int i = 0; i < BITLEN; ++i) if (((A[node] >> i) & 1) > 0) {
            bits[i][node] = true;
        }
    }
}
/**
 * Rooted-tree helper. It computes parents, depths and subtree sizes, then lays
 * the nodes out in a DFS preorder that always visits a node's largest child
 * first (heavy-path order). Every subtree is therefore a contiguous range of
 * positions. A SegmentTree of per-range tries is built over that order, and
 * LCA queries are answered by jumping between chain heads.
 */
class TreeUtil {
ArrayList<Integer>[] graph;
// depth: root has depth 1; parent: root has -1; chCount: subtree size; queue: BFS order.
int[] depth, parent, chCount, queue;
int N, root;
// Node weights indexed by node id (the caller's array, not copied).
int[] weight;
SegmentTree st;
// treePos[v]: position of v in linearTree; linearTree[i]: node at position i;
// segRoot[v]: head (topmost node) of the chain containing v.
int[] treePos, linearTree, segRoot;
/** Builds all tree data for graph g (n nodes, rooted at r) with node weights wt. */
TreeUtil(ArrayList<Integer>[] g, int n, int r, int[] wt) {
graph = g;
N = n;
root = r;
weight = wt;
iterativeDFS();
precompute();
}
/**
 * Despite its name, this is a breadth-first traversal using queue[]. It fills
 * parent and depth, then accumulates subtree sizes into chCount.
 */
private void iterativeDFS() {
parent = new int[N];
depth = new int[N];
chCount = new int[N];
queue = new int[N];
Arrays.fill(chCount, 1);
int i, st = 0, end = 0;
parent[root] = -1;
depth[root] = 1;
queue[end++] = root;
while (st < end) {
int node = queue[st++], h = depth[node] + 1;
Iterator<Integer> itr = graph[node].iterator();
while (itr.hasNext()) {
int ch = itr.next();
// depth 0 means "not visited yet" (the root starts at 1).
if (depth[ch] > 0) continue;
depth[ch] = h;
parent[ch] = node;
queue[end++] = ch;
}
}
// Reverse BFS order finishes every child before its parent.
for (i = N - 1; i >= 0; --i)
if (queue[i] != root)
chCount[parent[queue[i]]] += chCount[queue[i]];
}
/**
 * Builds the heavy-first preorder (linearTree/treePos), the chain heads
 * (segRoot) and the SegmentTree over weights in that order.
 * Reorders each graph[node] list in place: the largest child is moved to the end.
 */
private void precompute() {
int i, j, treeRoot = -7;
treePos = new int[N];
linearTree = new int[N];
segRoot = new int[N];
Stack<Integer> stack = new Stack<>();
stack.ensureCapacity(N << 1);
stack.push(root);
for (i = 0; !stack.isEmpty(); ++i) {
int node = stack.pop();
// Only a heavy child is popped immediately after its parent; any other node starts a new chain.
if (i == 0 || linearTree[i - 1] != parent[node])
treeRoot = node;
linearTree[i] = node;
treePos[node] = i;
segRoot[node] = treeRoot;
int bigChild = -7, bigChildPos = -7, lastPos = graph[node].size() - 1;
for (j = 0; j < graph[node].size(); ++j) {
int tempNode = graph[node].get(j);
if (tempNode == parent[node]) continue;
if (bigChild < 0 || chCount[bigChild] < chCount[tempNode]) {
bigChild = tempNode;
bigChildPos = j;
}
}
// Move the largest child to the end of the list so it is pushed last and popped next.
if (bigChildPos >= 0) {
int temp = graph[node].get(lastPos);
graph[node].set(lastPos, bigChild);
graph[node].set(bigChildPos, temp);
}
for (int itr : graph[node])
if (parent[node] != itr)
stack.push(itr);
}
int[] respectiveWeights = new int[N];
for (i = 0; i < N; ++i)
respectiveWeights[i] = weight[linearTree[i]];
st = new SegmentTree(respectiveWeights);
}
/** ORs value into the segment-tree copy of node's weight (see SegmentTree.pointUpdate). */
void pointUpdate(int node, int value) {
st.pointUpdate(treePos[node], value);
}
/** Max of (key XOR weight) over the subtree of node, which is positions [pos, pos + size - 1]. */
int subtreeQuery(int node, int key) {
int pos = treePos[node];
return st.rangeQuery(pos, pos + chCount[node] - 1, key);
}
/** Lowest common ancestor of node1 and node2, found by climbing chain heads. */
int getLCA(int node1, int node2) {
while (segRoot[node1] != segRoot[node2]) {
// XOR swap so that node2 is the node whose chain head is deeper.
if (depth[segRoot[node1]] > depth[segRoot[node2]]) {
node1 ^= node2;
node2 ^= node1;
node1 ^= node2;
}
node2 = parent[segRoot[node2]];
}
return (depth[node1] < depth[node2]) ? node1 : node2;
}
}
/**
 * Bottom-up-indexed segment tree (leaves at N..2N-1) in which every node holds a
 * Trie of all values in its range. Positions past ar.length are padding with
 * empty tries. The array passed in is kept (not copied) and mutated by pointUpdate.
 */
class SegmentTree {
    private final int[] A;
    private int N;
    private Trie[] tree;

    public SegmentTree(int[] ar) {
        A = ar;
        int size = 1;
        while (size < ar.length) size *= 2;
        N = size;
        tree = new Trie[2 * size];
        for (int node = 1; node < tree.length; node++) tree[node] = new Trie();
        for (int pos = 0; pos < ar.length; pos++) addValueAtIndex(pos, ar[pos]);
    }

    /** Inserts val into every trie on the path from leaf idx up to the root. */
    public void addValueAtIndex(int idx, int val) {
        for (int node = idx + N; node >= 1; node /= 2) tree[node].addValue(val);
    }

    /** Removes one copy of val from every trie on the path from leaf idx up to the root. */
    public void removeValueAtIndex(int idx, int val) {
        for (int node = idx + N; node >= 1; node /= 2) tree[node].removeValue(val);
    }

    /** ORs orVal into the value at idx, replacing the old value in every covering trie. */
    void pointUpdate(int idx, int orVal) {
        int old = A[idx];
        int updated = old | orVal;
        removeValueAtIndex(idx, old);
        A[idx] = updated;
        addValueAtIndex(idx, updated);
    }

    /** Max of (xorWith XOR value) over positions l..r (inclusive). */
    public int rangeQuery(int l, int r, int xorWith) {
        return rangeQuery(1, 0, N - 1, l, r, xorWith);
    }

    private int rangeQuery(int idx, int l, int r, int ql, int qr, int xorWith) {
        // A disjoint node yields a sentinel below any real (non-negative) answer.
        if (qr < l || ql > r) return -7;
        if (ql <= l && r <= qr) return tree[idx].findMaxXOR(xorWith);
        int mid = (l + r) >> 1;
        int leftBest = rangeQuery(2 * idx, l, mid, ql, qr, xorWith);
        int rightBest = rangeQuery(2 * idx + 1, mid + 1, r, ql, qr, xorWith);
        return Math.max(leftBest, rightBest);
    }
}
/**
 * Node of Trie. left/right are the children for bit 0/1.
 * count = number of stored values whose path passes through this node.
 */
class TrieNode {
int count;
TrieNode left, right;
}
/**
 * Binary trie over BITLEN-bit non-negative values. It supports insertion,
 * removal of one copy, and maximum-XOR queries. A child is pruned once its
 * count drops to zero, so an empty trie has no children at the root.
 */
class Trie {
    private final static int BITLEN = Solver.BITLEN;
    private TrieNode root = new TrieNode();

    /** Inserts one copy of val, incrementing the count of every node on its path. */
    void addValue(int val) {
        TrieNode curr = root;
        for (int i = BITLEN - 1; i >= 0; --i) {
            if (((val >> i) & 1) == 0) {
                if (curr.left == null) curr.left = new TrieNode();
                curr = curr.left;
            } else {
                if (curr.right == null) curr.right = new TrieNode();
                curr = curr.right;
            }
            ++curr.count;
        }
    }

    /**
     * Removes one copy of val, pruning nodes whose count reaches zero.
     * val must currently be stored.
     */
    void removeValue(int val) {
        TrieNode curr = root;
        for (int i = BITLEN - 1; i >= 0 && curr != null; --i) {
            if (((val >> i) & 1) == 0) {
                if (--curr.left.count <= 0) curr.left = null;
                curr = curr.left;
            } else {
                if (--curr.right.count <= 0) curr.right = null;
                curr = curr.right;
            }
        }
    }

    /**
     * Returns the maximum of (val XOR v) over stored values v, or 0 if the trie is
     * empty (matching the C++ trie's query).
     */
    int findMaxXOR(int val) {
        // Without this guard, an empty trie would dereference a null child below.
        if (root.left == null && root.right == null) return 0;
        int maxXor = 0;
        TrieNode curr = root;
        for (int i = BITLEN - 1; i >= 0; --i) {
            if (((val >> i) & 1) == 0) {
                if (curr.right != null) {
                    curr = curr.right;
                    maxXor |= 1 << i;
                } else {
                    curr = curr.left;
                }
            } else {
                if (curr.left != null) {
                    curr = curr.left;
                    maxXor |= 1 << i;
                } else {
                    curr = curr.right;
                }
            }
        }
        return maxXor;
    }
}
/** Buffered token reader over stdin plus an output buffer written to stdout by flush(). */
class FastIO {
    BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    StringTokenizer st = new StringTokenizer("");
    StringBuilder sb = new StringBuilder();

    /**
     * Returns the next whitespace-separated token from standard input.
     * @throws EOFException if input ends before another token is available
     */
    public String next() throws Exception {
        while (!st.hasMoreTokens()) {
            String line = br.readLine();
            // readLine() returns null at end of stream; StringTokenizer(null) would throw an opaque NPE.
            if (line == null) throw new EOFException("Unexpected end of input");
            st = new StringTokenizer(line);
        }
        return st.nextToken();
    }

    /** Parses the next token as an int. */
    public int nextInt() throws Exception {
        return Integer.parseInt(next());
    }

    /** Appends o to the output buffer; nothing is written until flush(). */
    public void print(Object o) {
        sb.append(o);
    }

    /** Appends a newline to the output buffer. */
    public void println() {
        print("\n");
    }

    /** Appends o followed by a newline to the output buffer. */
    public void println(Object o) {
        print(o);
        println();
    }

    /** Writes the buffered output to standard output, flushes it, and clears the buffer. */
    public void flush() {
        System.out.print(sb);
        System.out.flush();
        sb = new StringBuilder();
    }

    /** Reads size ints into a new array. */
    int[] getIntArray(int size) throws Exception {
        int[] ret = new int[size];
        for (int i = 0; i < size; ++i) ret[i] = nextInt();
        return ret;
    }
}