CLBIT Editorial

Problem Link

Practice
Contest

Author and Editorialist: Sarthak Manna
Testers: Jatin Yadav, Jatin Nagpal, Avijit Agarwal

Difficulty

Medium-Hard

Prerequisites

Tree flattening using DFS, Segment Tree, Trie, Implementation

Problem

Given a tree rooted at node 1 in which every node has a value associated with it, support the following types of operations:

  • Type 1: Given two nodes a and b and a value x, apply a bitwise OR with x to the value of every node on the simple path from a to b.
  • Type 2: Given a node u and a value y, find the maximum value that can be obtained by applying a bitwise XOR of y with the value of any node that lies in the subtree of u.

Quick Explanation

Flatten the tree using a dfs traversal. After this, the subtree of any node can be represented by a contiguous range on the flattened tree.
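
As an illustration, here is a minimal sketch of the flattening step (the names tin, tout and flat are illustrative and do not match the solutions below): after the DFS, the subtree of any node u occupies exactly the index range [tin[u], tout[u]].

#include <bits/stdc++.h>
using namespace std;

// Sketch of tree flattening: flat[] lists nodes in DFS order, and the
// subtree of u occupies the contiguous range [tin[u], tout[u]] of flat[].
vector<vector<int>> adj;
vector<int> tin, tout, flat;
int timer_ = 0;

void dfs(int u, int p) {
    tin[u] = timer_++;
    flat.push_back(u);                       // flat[tin[u]] == u
    for (int v : adj[u]) if (v != p) dfs(v, u);
    tout[u] = timer_ - 1;                    // last index inside u's subtree
}

int main() {
    int n = 5;
    adj.assign(n + 1, {});
    tin.assign(n + 1, 0); tout.assign(n + 1, 0);
    // Example tree rooted at 1:  1-2, 1-3, 3-4, 3-5
    for (auto [u, v] : vector<pair<int,int>>{{1,2},{1,3},{3,4},{3,5}}) {
        adj[u].push_back(v); adj[v].push_back(u);
    }
    dfs(1, 0);
    for (int i = tin[3]; i <= tout[3]; ++i)  // the subtree of node 3
        cout << flat[i] << ' ';              // prints: 3 4 5
    cout << '\n';
}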

It is important to note that the total number of bit flips caused by update operations can never exceed O(N * 17) [17 = bit length of the values], because a bitwise OR can only set a bit that wasn't set already; each node can therefore gain each of the 17 bits at most once over all the updates. This observation is important because, for every update, we now only need to iterate over the values which will actually change.

Finding the maximum possible XOR value is a fairly standard problem and can be solved by maintaining tries. You can study and practice the technique at link1 and link2.
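
For completeness, here is a minimal standalone sketch of the maximum-XOR trie (the node layout is illustrative and differs from the tries used in the solutions below):

#include <bits/stdc++.h>
using namespace std;

// Max-XOR trie sketch for 17-bit values: insert numbers bit by bit from the
// most significant bit; to maximise (y XOR stored value), greedily walk
// towards the child holding the opposite bit of y whenever it exists.
struct XorTrie {
    static const int B = 17;
    struct Node { int child[2] = {-1, -1}; };
    vector<Node> t;
    XorTrie() : t(1) {}                        // node 0 is the root

    void insert(int x) {
        int cur = 0;
        for (int i = B - 1; i >= 0; --i) {
            int b = (x >> i) & 1;
            if (t[cur].child[b] == -1) { t[cur].child[b] = t.size(); t.push_back({}); }
            cur = t[cur].child[b];
        }
    }
    int maxXor(int y) const {                  // assumes at least one insert()
        int cur = 0, res = 0;
        for (int i = B - 1; i >= 0; --i) {
            int want = ((y >> i) & 1) ^ 1;     // the opposite bit is preferred
            if (t[cur].child[want] != -1) { res |= 1 << i; cur = t[cur].child[want]; }
            else cur = t[cur].child[want ^ 1];
        }
        return res;
    }
};

int main() {
    XorTrie tr;
    for (int v : {3, 10, 25}) tr.insert(v);
    cout << tr.maxXor(6) << '\n';              // best is 25 ^ 6 = 31
}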

Explanation

Flatten the tree using a dfs traversal as explained above. After flattening, every subtree can be represented by a single contiguous range. Then, build a segment tree on the flattened tree.

Starting with the easier part, let's handle the query operations. Every query operation requires us to find the answer over a contiguous range. To solve this, maintain a trie at each node of the segment tree; the trie should contain all the values present in the range covered by that node (think of a merge sort tree, but with tries instead of sorted arrays at each node). A range query on the segment tree decomposes into O(lg N) nodes, and each node (trie) takes O(17) time to compute the maximum XOR value. Therefore, the time complexity of each query operation is O(lg N * 17).

Next comes the update part. As argued earlier, we only need to find the indexes (tree nodes) which will actually be affected by a particular update operation; we can then naively iterate over all those nodes and update them in the required tries. To find the affected values, we maintain 17 separate DSU (disjoint set union) type data structures. The i-th index of the j-th DSU points to the nearest ancestor of node i whose j-th bit is 0. As the values of the nodes get updated, the DSU structures are updated accordingly, and path compression keeps these jumps efficient. Refer to the solution(s) given below for implementation details; a simplified single-bit sketch of the jump also follows this paragraph. Coming to the time complexity: for each successful bit flip there are O(lg N) tries to modify (the point-update cost in the segment tree), and each trie takes O(17) time to update (the height of each trie). There are O(N * 17) successful bit flips overall, so the total update cost comes to O(N * 17 * lg N * 17).
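
To make the DSU idea concrete, here is a minimal single-bit sketch of the "jump to the nearest unset ancestor" trick (names such as nxt and orPathBit, and the path example, are purely illustrative; the full solutions below run 17 such structures alongside the segment tree of tries):

#include <bits/stdc++.h>
using namespace std;

// Single-bit sketch: nxt[v] points towards the nearest node on the path from
// v to the root (v itself included) whose tracked bit is still 0. Once the
// bit of v becomes 1 we redirect nxt[v] to its parent, and path compression
// keeps later jumps cheap, so each node is "flipped" at most once.
const int BIT = 0;                       // the single bit this sketch tracks
vector<int> par, dep, a, nxt;            // parent, depth, values, skip pointers

int jump(int v) {                        // path-compressed "find"
    if (v == 0) return 0;                // 0 acts as a virtual super-root
    if (((a[v] >> BIT) & 1) == 0) return v;
    return nxt[v] = jump(nxt[v]);
}

// OR the bit into every node on the path from u up to (and including) anc.
void orPathBit(int u, int anc) {
    int v = jump(u);
    while (v != 0 && dep[v] >= dep[anc]) {
        a[v] |= 1 << BIT;                // the actual flip; at most once per node
        nxt[v] = par[v];                 // future jumps skip over v
        v = jump(v);
    }
}

int main() {
    // Path 1-2-3-4 rooted at 1; node 2 already has the bit set.
    int n = 4;
    par = {0, 0, 1, 2, 3};
    dep = {0, 1, 2, 3, 4};
    a   = {0, 0, 1, 0, 0};
    nxt.resize(n + 1);
    for (int v = 0; v <= n; ++v)         // nodes with the bit set skip to parent
        nxt[v] = ((a[v] >> BIT) & 1) ? par[v] : v;
    orPathBit(4, 1);                     // flips the bit on nodes 4, 3 and 1 only
    for (int v = 1; v <= n; ++v) cout << a[v] << ' ';
    cout << '\n';                        // prints: 1 1 1 1
}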

Time Complexity

The overall time complexity of the solution is intended to be O(Q * lg N * 17 + N * 17 * lg N * 17).

Solutions

C++ Solution :

#include <bits/stdc++.h>
using namespace std;

#define all(c) ((c).begin()), ((c).end())

const int N = 1 << 15, logN = 15;
int in[N], rin[N], out[N], timer;
int par[logN][N], depth[N];
vector<int> g[N];

struct dsu{
	int n;
	vector<int> par;
	dsu(){}
	dsu(int n) : n(n), par(n + 1){
		iota(all(par), 0);
	}
	int root(int x){
		return x == par[x] ? x : (par[x] = root(par[x]));
	}
	bool merge(int x, int y){
		x = root(x); y = root(y);
		if(x == y) return false;
		par[x] = y;
		return true;
	}
};

void dfs_sz(int v = 1, int p = 0){
	depth[v] = depth[p] + 1;
	par[0][v] = p;
	in[v] = ++timer;
	rin[in[v]] = v;
	for(auto &u: g[v]){
	    if(u == p) continue;    
	    dfs_sz(u, v);
	}
	out[v] = timer;
}

int lca(int a, int b){
	if(depth[a]<depth[b])
	    swap(a,b);
	int l = depth[a]-depth[b];
	for(int i = 0;i<logN;i++) if(l&(1<<i)) a = par[i][a];    
	if(a==b) return a;
	assert(depth[a] == depth[b]);
	for(int i = logN-1;i>=0;i--)
	    if(par[i][a]!=par[i][b])
	        a = par[i][a],b = par[i][b];
	return par[0][a];
}

const int LN = 17;
const int SN = 3e7;
int cur;
int lft[SN];
int rgt[SN];
int val[SN];

void clr(int n){
	for(int i = 0; i <= n; i++){
		in[i] = rin[i] = out[i] = timer = 0;
		g[i].clear();
	}
}

struct trie{
	int root;
	trie(){root = ++cur;}
	void insert(int num){
	    int node = root;
		val[node]++;
	    for(int i = LN - 1 ; i >= 0 ; --i){
	        if(num & (1 << i)){
	            if(!rgt[node]){
	                rgt[node] = ++cur;
	            }
	            node = rgt[node];
	        }
	        else{
	            if(!lft[node]){
	                lft[node] = ++cur;
	            }
	            node = lft[node];
	        }
			val[node]++;
	    }
	}
	void remove(int num){
		int node = root;
		val[node]--;
		for(int i = LN - 1 ; i >= 0 ; --i){
	        if(num & (1 << i)){
				assert(rgt[node]);
	            node = rgt[node];
	        }
	        else{
				assert(lft[node]);
	            node = lft[node];
	        }
			val[node]--;
	    }
	}
	int query(int num){ // maximum xor
	    int node = root;
		if(!val[node]) return 0;
	    int res = 0;
	    for(int i = LN - 1 ; i >= 0 ; --i){
	        if(num & (1 << i)){
	            if(val[lft[node]]){
	                res += 1 << i;
	                node = lft[node];
	            }
	            else{
	                node = rgt[node];
	            }
	        }
	        else{
	            if(val[rgt[node]]){
	                res += 1 << i;
	                node = rgt[node];
	            }
	            else{
	                node = lft[node];
	            }
	        }
	    }
	    return res;
	}
};

// 0-indexed
struct segtree{
	int n;
	vector<trie> t;
	vector<int> curr;
	segtree(int n) : n(n), t(4 * n + 10), curr(n + 1){
	}
	void update(int i, int v, int orig, int s, int e, int ind){
		if(i > e || i < s) return;
		if(!orig){
			t[ind].remove(curr[i]);
		}
		t[ind].insert(v);
		if(s == e){
			return;
		}
		int mid = (s + e) >> 1;
		update(i, v, orig, s, mid, ind << 1);
		update(i, v, orig, mid + 1, e, ind << 1 | 1);
	}
	void update(int i, int v, int orig){
		assert(i >= 1 && i <= n);
		update(i, v, orig, 1, n, 1);
		curr[i] = v;
	}
	int get(int l, int r, int x, int s, int e, int ind){
		if(l > e || s > r) return 0;
		if(s >= l && e <= r) return t[ind].query(x);
		int mid = (s + e) >> 1;
		return max(get(l, r, x, s, mid, ind << 1), get(l, r, x, mid + 1, e, ind << 1 | 1));
	}
	int get(int l, int r, int x){
		return get(l, r, x, 1, n, 1);
	}
};

dsu D[LN];
int main(){
	cin.tie(0); ios_base::sync_with_stdio(0);
	int t; cin >> t;
	while(t--){
		int n, q; cin >> n >> q;
		vector<int> a(n + 1);
		for(int i = 0; i < LN; i++) D[i] = dsu(n);
		for(int i = 1; i <= n; i++) cin >> a[i];
		for(int i = 1; i < n; i++){
			int u, v; cin >> u >> v;
			g[u].push_back(v);
			g[v].push_back(u);
		}
		dfs_sz();
		for(int i = 1; i < logN; i++) for(int j = 1; j <= n; j++) par[i][j] = par[i - 1][par[i - 1][j]];
		segtree st(n);
		for(int i = 1; i <= n; i++){
			st.update(i, a[rin[i]], 1);
			for(int j = 0; j < LN; j++) if(a[rin[i]] >> j & 1){
				if(i != 1) D[j].merge(rin[i], par[0][rin[i]]);
			}
		}
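		// For every set bit of x, repeatedly jump (via the i-th DSU) to the
		// nearest ancestor of u whose i-th bit is still 0, OR the bit into it
		// and merge it with its parent, until the jump passes above the LCA l.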
		function<void(int, int, int)> update = [&](int u, int l, int x){
			for(int i = 0; i < LN; i++) if(x >> i & 1){
				int node = u;
				while(depth[node] >= depth[l]){
					node = D[i].root(node);
					if(depth[node] < depth[l]) break;
					if(a[node] >> i & 1){
						break;
					}
					st.update(in[node], a[node] |= (1 << i), 0);
					if(node != 1) D[i].merge(node, par[0][node]);
				}
			}
		};

		while(q--){
			int type; cin >> type;
			int u, v, k, x;
			if(type == 1){
				cin >> u >> v >> x;
				int l = lca(u, v);
				update(u, l, x);
				update(v, l, x);
			} else{
				cin >> k >> x;
				cout << st.get(in[k], out[k], x) << '\n';
			}
		}
		clr(n);
	}
}

Java Solution :

import java.io.*;
import java.util.*;

public class Main {
	public static void main(String[] args) throws Exception {
	    new Solver().solve();
	}
}

class Solver {
	final FastIO hp = new FastIO();

	void solve() throws Exception {
	    int tc = TESTCASES ? hp.nextInt() : 1;
	    for (int tce = 1; tce <= tc; ++tce) solve(tce);
	    hp.flush();
	}

	boolean TESTCASES = true;
	final static int BITLEN = 17;

	void solve(int tc) throws Exception {
	    int i, j, k;

	    int N = hp.nextInt(), Q = hp.nextInt();
	    A = hp.getIntArray(N);

	    ArrayList<Integer>[] graph = new ArrayList[N];
	    for (i = 0; i < N; ++i) graph[i] = new ArrayList<>();
	    for (i = 1; i < N; ++i) {
	        int a = hp.nextInt() - 1, b = hp.nextInt() - 1;
	        graph[a].add(b); graph[b].add(a);
	    }

	    TreeUtil util = new TreeUtil(graph, N, 0, A);
	    depth = util.depth;
	    parent = new int[BITLEN][];
	    for (i = 0; i < BITLEN; ++i) parent[i] = util.parent.clone();
	    bits = new boolean[BITLEN][N];

	    for (i = 0; i < N; ++i) setBits(i);

	    HashSet<Integer> set = new HashSet<>();
	    for (i = 0; i < Q; ++i) {
	        int choice = hp.nextInt();

	        if (choice == 1) {
	            int u = hp.nextInt() - 1, v = hp.nextInt() - 1, o = hp.nextInt();
	            int lcaDep = depth[util.getLCA(u, v)];

	            for (j = 0; j < BITLEN; ++j) if (((o >> j) & 1) > 0) {
	                int x = u, y = v;

	                x = getParent(x, j);
	                while (x >= 0 && depth[x] >= lcaDep) {
	                    set.add(x);
	                    x = getParent(parent[j][x], j);
	                }

	                y = getParent(y, j);
	                while (y >= 0 && depth[y] >= lcaDep) {
	                    set.add(y);
	                    y = getParent(parent[j][y], j);
	                }
	            }

	            for (int node : set) {
	                A[node] |= o;
	                setBits(node);
	                util.pointUpdate(node, o);
	            }

	            set.clear();
	        } else if (choice == 2) {
	            int node = hp.nextInt() - 1, x = hp.nextInt();
	            hp.println(util.subtreeQuery(node, x));
	        }
	    }
	}

	int[] A, depth;
	int[][] parent;
	boolean[][] bits;

	int getParent(int node, final int bitPos) {
	    if (node < 0 || !bits[bitPos][node]) return node;
	    else return parent[bitPos][node] = getParent(parent[bitPos][node], bitPos);
	}

	void setBits(int node) {
	    for (int i = 0; i < BITLEN; ++i) if (((A[node] >> i) & 1) > 0) {
	        bits[i][node] = true;
	    }
	}
}

class TreeUtil {
	ArrayList<Integer>[] graph;
	int[] depth, parent, chCount, queue;
	int N, root;
	int[] weight;

	SegmentTree st;
	int[] treePos, linearTree, segRoot;

	TreeUtil(ArrayList<Integer>[] g, int n, int r, int[] wt) {
	    graph = g;
	    N = n;
	    root = r;
	    weight = wt;
	    iterativeDFS();

	    precompute();
	}

	private void iterativeDFS() {
	    parent = new int[N];
	    depth = new int[N];
	    chCount = new int[N];
	    queue = new int[N];
	    Arrays.fill(chCount, 1);

	    int i, st = 0, end = 0;
	    parent[root] = -1;
	    depth[root] = 1;
	    queue[end++] = root;

	    while (st < end) {
	        int node = queue[st++], h = depth[node] + 1;
	        Iterator<Integer> itr = graph[node].iterator();
	        while (itr.hasNext()) {
	            int ch = itr.next();
	            if (depth[ch] > 0) continue;
	            depth[ch] = h;
	            parent[ch] = node;
	            queue[end++] = ch;
	        }
	    }
	    for (i = N - 1; i >= 0; --i)
	        if (queue[i] != root)
	            chCount[parent[queue[i]]] += chCount[queue[i]];
	}

	private void precompute() {
	    int i, j, treeRoot = -7;

	    treePos = new int[N];
	    linearTree = new int[N];
	    segRoot = new int[N];

	    Stack<Integer> stack = new Stack<>();
	    stack.ensureCapacity(N << 1);
	    stack.push(root);
	    for (i = 0; !stack.isEmpty(); ++i) {
	        int node = stack.pop();
	        if (i == 0 || linearTree[i - 1] != parent[node])
	            treeRoot = node;
	        linearTree[i] = node;
	        treePos[node] = i;
	        segRoot[node] = treeRoot;

	        int bigChild = -7, bigChildPos = -7, lastPos = graph[node].size() - 1;
	        for (j = 0; j < graph[node].size(); ++j) {
	            int tempNode = graph[node].get(j);
	            if (tempNode == parent[node]) continue;
	            if (bigChild < 0 || chCount[bigChild] < chCount[tempNode]) {
	                bigChild = tempNode;
	                bigChildPos = j;
	            }
	        }
	        if (bigChildPos >= 0) {
	            int temp = graph[node].get(lastPos);
	            graph[node].set(lastPos, bigChild);
	            graph[node].set(bigChildPos, temp);
	        }

	        for (int itr : graph[node])
	            if (parent[node] != itr)
	                stack.push(itr);
	    }

	    int[] respectiveWeights = new int[N];
	    for (i = 0; i < N; ++i)
	        respectiveWeights[i] = weight[linearTree[i]];
	    st = new SegmentTree(respectiveWeights);
	}

	void pointUpdate(int node, int value) {
	    st.pointUpdate(treePos[node], value);
	}

	int subtreeQuery(int node, int key) {
	    int pos = treePos[node];
	    return st.rangeQuery(pos, pos + chCount[node] - 1, key);
	}

	int getLCA(int node1, int node2) {
	    while (segRoot[node1] != segRoot[node2]) {
	        if (depth[segRoot[node1]] > depth[segRoot[node2]]) {
	            node1 ^= node2;
	            node2 ^= node1;
	            node1 ^= node2;
	        }
	        node2 = parent[segRoot[node2]];
	    }
	    return (depth[node1] < depth[node2]) ? node1 : node2;
	}
}

class SegmentTree {
	private final int[] A;
	private int N;
	private Trie[] tree;

	public SegmentTree(int[] ar) {
	    A = ar;

	    N = 1; while (N < ar.length) N <<= 1;
	    tree = new Trie[N << 1];
	    for (int i = 1; i < tree.length; ++i) tree[i] = new Trie();
	    for (int i = 0; i < ar.length; ++i) addValueAtIndex(i, ar[i]);
	}

	public void addValueAtIndex(int idx, int val) {
	    idx += N;
	    while (idx > 0) {
	        tree[idx].addValue(val);
	        idx >>= 1;
	    }
	}

	public void removeValueAtIndex(int idx, int val) {
	    idx += N;
	    while (idx > 0) {
	        tree[idx].removeValue(val);
	        idx >>= 1;
	    }
	}

	void pointUpdate(int idx, int orVal) {
	    removeValueAtIndex(idx, A[idx]);
	    A[idx] |= orVal;
	    addValueAtIndex(idx, A[idx]);
	}

	public int rangeQuery(int l, int r, int xorWith) {
	    return rangeQuery(1, 0, N - 1, l, r, xorWith);
	}

	private int rangeQuery(int idx, int l, int r, int ql, int qr, int xorWith) {
	    if (l > qr || r < ql) {
	        return -7;
	    } else if (l >= ql && r <= qr) {
	        return tree[idx].findMaxXOR(xorWith);
	    } else {
	        int c1 = idx << 1, c2 = c1 | 1, mid = l + r >> 1;
	        return Math.max(rangeQuery(c1, l, mid, ql, qr, xorWith),
	                rangeQuery(c2, mid + 1, r, ql, qr, xorWith));
	    }
	}
}

class TrieNode {
	int count;
	TrieNode left, right;
}

class Trie {
	private final static int BITLEN = Solver.BITLEN;
	private TrieNode root = new TrieNode();

	void addValue(int val) {
	    TrieNode curr = root;

	    for (int i = BITLEN - 1; i >= 0; --i) {
	        if (((val >> i) & 1) == 0) {
	            if (curr.left == null) curr.left = new TrieNode();
	            curr = curr.left;
	        } else {
	            if (curr.right == null) curr.right = new TrieNode();
	            curr = curr.right;
	        }
	        ++curr.count;
	    }
	}

	void removeValue(int val) {
	    TrieNode curr = root;

	    for (int i = BITLEN - 1; i >= 0 && curr != null; --i) {
	        if (((val >> i) & 1) == 0) {
	            if (--curr.left.count <= 0) curr.left = null;
	            curr = curr.left;
	        } else {
	            if (--curr.right.count <= 0) curr.right = null;
	            curr = curr.right;
	        }
	    }
	}

	int findMaxXOR(int val) {
	    int maxXor = 0;
	    TrieNode curr = root;

	    for (int i = BITLEN - 1; i >= 0; --i) {
	        if (((val >> i) & 1) == 0) {
	            if (curr.right != null) {
	                curr = curr.right;
	                maxXor |= 1 << i;
	            } else {
	                curr = curr.left;
	            }
	        } else {
	            if (curr.left != null) {
	                curr = curr.left;
	                maxXor |= 1 << i;
	            } else {
	                curr = curr.right;
	            }
	        }
	    }

	    return maxXor;
	}
}

class FastIO {
	BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
	StringTokenizer st = new StringTokenizer("");
	StringBuilder sb = new StringBuilder();

	public String next() throws Exception {
	    while (!st.hasMoreTokens()) st = new StringTokenizer(br.readLine());
	    return st.nextToken();
	}

	public int nextInt() throws Exception {
	    return Integer.parseInt(next());
	}

	public void print(Object o) {
	    sb.append(o);
	}

	public void println() {
	    print("\n");
	}

	public void println(Object o) {
	    print(o);
	    println();
	}

	public void flush() {
	    System.out.print(sb);
	    sb = new StringBuilder();
	}

	int[] getIntArray(int size) throws Exception {
	    int[] ret = new int[size];
	    for (int i = 0; i < size; ++i) ret[i] = nextInt();
	    return ret;
	}
}

Damn why did I use HLD which was useless

Actually, we used an HLD approach initially and later changed it to this approach.

Invitation to Coders' Legacy 2020 (Rated for all) - Codeforces: seems this O(nq) sol got passed. Maybe next time you should strengthen the test cases.

Solving CLBIT the hard way here, check this out

I'm unable to understand how a trie will help in calculating the max XOR. Can anyone please explain it?

EDIT: nevermind, I saw the code now :rofl:

Woah, I feel dumb now. My solution is much more complicated. Basically we use HLD, and on each heavy segment we build

  1. A segment tree for updates
  2. A set-ish structure where we store values of nodes from "child" heavy segments. (You also have to make sure that when you use the nodes from a child segment from a node of depth d it doesn't impact the nodes below, etc., so that's why I said -ish. I ended up using a kind of hashing and other stuff to save space.) [It's basically like a trie but more complex, assume.]

Update is easy: walk up the heavy segments and update each one's segment tree. Also, on reaching the LCA, walk 'upwards' and update the structures of the parent heavy segments. This takes \mathcal O (\log ^2 n) per update (the set-ish struct gives the extra log). For a query you have to query both the segment tree of the heavy segment containing the query node and its set-ish structure, which is \mathcal O (\log n) overall. So the overall complexity is \mathcal O (n \log ^2 n) (assuming Q \sim n). [Note that each node's record is stored in at most \log n heavy segments, so precomputation is not too long.]

And after doing ALL this for an hour I now see that there is an easier solution. :face_with_thermometer:

Why is my code giving a runtime error? It's working for custom input. I first flatten the tree using dfs, then used LCA (binary lifting) for the path query. It's not giving TLE but WA. Can someone please look at my code?

code_link

@sarthakmanna Can you please elaborate on the time complexity of path compression in this case? I think it can be linear sometimes. So how do we prove that the total number of operations is small overall?

This is kinda intuitive, I'd say. You can refer to Union By Rank and Path Compression in Union-Find Algorithm - GeeksforGeeks.

Suppose a ā† b ā† c [a is the parent of b which is the parent of c]. Now, suppose, you call findParent(c). This operation will indeed take O(3) time. But, after this call, parent[c] and parent[b] will become a. So, the next time, when you call findParent(b) or findParent(c), itā€™ll return in O(1). Thatā€™s why the overall amortized complexity becomes O(N. lg N + Q) where Q is the number of findParent() calls.

Thank You