COCR108 - Editorial

PROBLEM LINK:

Contest

Author: Abhishek Jugdar
Editorialist: Abhishek Jugdar

DIFFICULTY:

MEDIUM-HARD

PREREQUISITES:

Meet-in-the-middle, Trie

PROBLEM:

We have a 3D grid of dimensions N x M x K and a magic number P. Each move increases exactly one coordinate by 1. We need to count the paths that start at (1, 1, 1), end at (N, M, K), and have the xor of the numbers on the path not greater than P.

EXPLANATION:

Some initial observations

N, M and K are all less than 10, so there can't be too many paths in the grid, right? In fact, a 9 x 9 x 9 grid has 9,465,511,770 paths. How do we count them? Starting at (1,1,1) and ending at (N,M,K), we make exactly (N - 1) + (M - 1) + (K - 1) = (N + M + K - 3) moves, of which N - 1 increase x, M - 1 increase y and K - 1 increase z. So the number of paths is \frac{(N + M + K - 3)!}{(N - 1)! (M - 1)! (K - 1)!}. Clearly we can't iterate over all these paths and check whether xor \leq P, since in the worst case the number of paths exceeds 10^9. Even though brute force won't fit in the time limit, we will discuss it, since it paves the way to the actual solution.

Brute force

Let’s start with the good old brute force method. Let X = (N + M + K - 3). In each move we can go bottom, right or down, so with 3 options per move there are 3^X possible move sequences. Not all of them are valid: a sequence is valid only if it has exactly (N - 1) bottom moves (x+1,y,z), (M - 1) right moves (x,y+1,z) and (K - 1) down moves (x,y,z+1).

How to implement this?

Write every number from 0 to 3^X - 1 in base 3 (padded to X digits) and map the digits to moves, for example 0 - bottom, 1 - right, 2 - down. Then go over these sequences, check whether each one is valid according to the criteria above, and check whether the xor of the numbers on the path is \leq P.

Time Complexity of brute force

It is quite straightforward: O(3^X \cdot X).

Optimization - Meet-in-the-middle technique
What is meet-in-the-middle?

Meet-in-the-middle is an optimization technique for brute force solutions: split the search into two halves, enumerate each half separately, and combine the results of the two halves. If you want to read more about it, you can read it here.

Intuition behind using meet-in-the-middle

As defined above, X = (N + M + K - 3), which in the worst case is X = 9 + 9 + 9 - 3 = 24. While 3^X \cdot X isn't viable, 3^{X/2} \cdot X/2 certainly is.
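For the worst case X = 24, the two operation counts compare as follows (plain arithmetic, ignoring constant factors):

```latex
3^{24} \cdot 24 = 282\,429\,536\,481 \cdot 24 = 6\,778\,308\,875\,544 \approx 6.8 \times 10^{12}
\qquad\text{vs.}\qquad
3^{12} \cdot 12 = 531\,441 \cdot 12 = 6\,377\,292
```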

How to apply this technique - on a simpler version of the statement first

Let us first make one small change to the problem statement to make it simpler: find the number of magical paths having xor exactly P.
We can observe that every path from the start point to the end point can be broken into 2 parts: a path from (1,1,1) to some cell (x,y,z) and a path from (x,y,z) to (N,M,K). If we split after \lfloor X/2 \rfloor moves, the first part has \lfloor X/2 \rfloor moves and the second part has \lceil X/2 \rceil moves.

Now use the brute force method on \lfloor X/2 \rfloor moves instead of X, starting from (1,1,1). As mentioned before, not all 3^{\lfloor X/2 \rfloor} move sequences are valid, since some of them go out of the grid, so you need to handle that.
Consider a cell (x,y,z) where such a path of length \lfloor X/2 \rfloor from (1,1,1) ends. For every such cell, maintain a map that stores the count of each xor value over all paths from (1,1,1) to (x,y,z).
Now do the same thing from (N,M,K), moving backwards for \lceil X/2 \rceil moves. When a path reaches (x,y,z) with xor curr, note that the value A_{x,y,z} of the meeting cell was included in both halves and therefore cancels out, so xor it back in once: curr = curr \oplus A_{x,y,z}. Then find the frequency of (curr \oplus P) in map[x][y][z] and add it to the answer.
Doing this for all cells gives the number of paths whose xor is exactly P.

Full Solution

Whew! We have done quite a lot so far, yet one thing remains. We have seen how to count paths having xor equal to P; now we need a way to count paths having xor \leq P.

So what is different in this case?

The only difference is that instead of maintaining a map for each cell, we now have to maintain a trie!

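Why a trie?

If you haven’t heard of a trie or aren’t too familiar with it, I recommend learning about it first from here or from any other resource you prefer.

To obtain the answer, for each xor value currentpath_xor obtained along a path from (N,M,K) to (x,y,z) (with the value at (x,y,z) xored back in once, since both halves include it), we need to count the stored values, i.e. xors of paths from (1,1,1) to (x,y,z), whose xor with currentpath_xor is \leq P.
To do so, we walk the trie bit by bit from the most significant bit and compare each bit of the resulting xor with the corresponding bit of P. Whenever a branch makes the xor strictly smaller than P at the current bit, we add the count of that whole subtree; we keep walking only along the branch where the xor still equals P so far.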
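A small hand-worked example with 3-bit values (chosen for illustration only): take P = 5 = 101_2 and currentpath_xor = 6 = 110_2, and ask which stored values y satisfy y \oplus 6 \leq 5.

```latex
\begin{aligned}
&\text{bit 2: } p_2 = 1,\ val_2 = 1:\ y_2 = 1 \Rightarrow (y \oplus 6)_2 = 0 < 1,
  \text{ so every stored } y \in \{4,5,6,7\} \text{ counts; continue with } y_2 = 0 \\
&\text{bit 1: } p_1 = 0,\ val_1 = 1:\ (y \oplus 6)_1 \text{ must be } 0,
  \text{ so follow } y_1 = 1 \\
&\text{bit 0: } p_0 = 1,\ val_0 = 0:\ y_0 = 0 \Rightarrow y = 2 \text{ counts } (2 \oplus 6 = 4 < 5);\
  y_0 = 1 \Rightarrow y = 3 \text{ ties } (3 \oplus 6 = 5) \text{ and counts at the leaf} \\
&\Rightarrow \{\, y : y \oplus 6 \leq 5 \,\} = \{2, 3, 4, 5, 6, 7\}
\end{aligned}
```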

Now, we build the trie in the usual way, with one addition: each node of the trie stores a cnt variable. The role of this variable is explained a bit later.
So, the structure of our trie is:

Structure
struct Trie {
int cnt;
Trie *left, *right;

};

Trie Implementation
Insertion

The cnt variable stores the number of inserted values that pass through that particular node. Its importance will become clear in the querying part.

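// curr is the current bit index, from the most significant bit down to 0.
// newnode() returns an empty node (cnt = 0, no children), as in the setter's solution.
void insert(Trie* root, int curr, int v)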
{
	if(curr < 0) return;
	if(v & (1 << curr)) {
		if(root->right == NULL) root->right = newnode();
		root->right->cnt++;
		insert(root->right, curr - 1, v);
	}
	else {
		if(root->left == NULL) root->left = newnode();
		root->left->cnt++;
		insert(root->left, curr - 1, v);
	}
}
Querying the trie

The commented code snippet below explains in detail how querying works, why it is correct, and why the cnt variable matters.

int query(Trie* root, int curr, int val, int p)  // val is the currentpath_xor
{
	if(root == NULL) return 0;   // No such node present, so return 0;
	if(root->left == NULL && root->right == NULL) return root->cnt; // leaf: every bit matched p, so xor == p; count these numbers
	
	if(p & (1 << curr)) {
		if(val & (1 << curr)) {
			// Case : curr bit is set in both, p and currentpath_xor.
			
			// If we proceed to right, we are guaranteed to have XOR < p
			//(since our current bit becomes 0, but it is set in p), so instead of traversing just add count.
			
			// We also continue along the left, where the xor so far still equals the prefix of p.
			
			
			if(root->right != NULL) return root->right->cnt + query(root->left, curr - 1, val, p);
			return query(root-> left, curr - 1, val, p);
		} 
		else {
			// Case : curr bit is set in p, but unset in currentpath_xor.
			
			// If we proceed to left, we are guaranteed to have XOR < p
			//(since our current bit becomes 0, but it is set in p), so instead of traversing just add count.
			
			// We also continue along the right, where the xor so far still equals the prefix of p.
			
			if(root->left != NULL) return root->left->cnt + query(root->right, curr - 1, val, p);
			return query(root-> right, curr - 1, val, p);
		}
	}
	else {
		// Since the curr bit is unset in p, it must also be unset in the xor,
		// so we follow the child whose bit equals the bit of currentpath_xor.
		
		// Remember: if we have reached a node in the trie,
		// the xor so far is guaranteed to equal the prefix of p. If it were already less,
		// we wouldn't have walked down that path but simply added the cnt of that node.
		
		if(val & (1 << curr)) return query(root->right, curr - 1, val, p);
		return query(root->left, curr - 1, val, p);
	}
	return 0;
}
Time Complexity

Using X = (N + M + K - 3):
Time complexity: O(3^{X/2} \cdot X/2 + 3^{X/2} \cdot \log_2(\max A_{x,y,z}))

Things to look out for
  • The answer can exceed the int range, so don’t forget to use long long.
  • Not all of the move sequences you enumerate are valid; be sure to handle the cases where you go out of the grid (see the Brute force section for reference).

SOLUTIONS:

Setter's Solution
#include <bits/stdc++.h>
#define ll long long int
#define ld long double
#define f first
#define s second
#define pb push_back
#define eb emplace_back
#define mk make_pair
#define mt make_tuple
#define MOD 1000000007
#define fo(i,a,b) for(i=a;i<b;i++)
#define foe(i,a,b) for(i=a;i<=b;i++)
#define all(x) x.begin(), x.end()
#define vi vector<int>
#define vl vector <long long int>
#define pii pair <int,int>
#define pll pair <long long int, long long int>
#define vpii vector< pair<int,int> >
#define vpll vector < pair <long long int,long long int> >
#define boost ios::sync_with_stdio(false); cin.tie(0)
using namespace std;
const int inf = 1e9 + 5;
const ll inf64 = 1e18 + 5;
 
const int LN = 29;
const int MAXPOW = 13;
int pw[MAXPOW];
struct Trie {
	int cnt;
	Trie *left, *right;
};
Trie* newnode()
{
	Trie* node = new Trie();
	node->cnt = 0;
	node->left = node->right = NULL;
	return node;
}
void insert(Trie* root, int curr, int v)
{
	if(curr < 0) return;
	if(v & (1 << curr)) {
		if(root->right == NULL) root->right = newnode();
		root->right->cnt++;
		insert(root->right, curr - 1, v);
	}
	else {
		if(root->left == NULL) root->left = newnode();
		root->left->cnt++;
		insert(root->left, curr - 1, v);
	}
}
int query(Trie* root, int curr, int val, int k)
{
	if(root == NULL) return 0;
	if(root->left == NULL && root->right == NULL) return root->cnt;
	
	if(k & (1 << curr)) {
		if(val & (1 << curr)) {
			if(root->right != NULL) return root->right->cnt + query(root->left, curr - 1, val, k);
			return query(root-> left, curr - 1, val, k);
		} 
		else {
			if(root->left != NULL) return root->left->cnt + query(root->right, curr - 1, val, k);
			return query(root-> right, curr - 1, val, k);
		}
	}
	else {
		if(val & (1 << curr)) return query(root->right, curr - 1, val, k);
		return query(root->left, curr - 1, val, k);
	}
	return 0;
}
vector <int> base_3(int x, int sz)
{
	vector <int> v;
	while(x) {
		v.pb(x % 3);
		x /= 3;
	}
	while(v.size() < sz) v.pb(0);
	reverse(all(v));
	return v;
}
int main()
{
	boost;
	int n, m, l, p;
	cin >> n >> m >> l >> p;
	int arr[n][m][l];
	for(int i = 0; i < n; i++) {
		for(int j = 0; j < m; j++) {
			for(int k = 0; k < l; k++)
			cin >> arr[i][j][k];
		}
	}
	
	Trie* root[n][m][l];
	for(int i = 0; i < n; i++) {
		for(int j = 0; j < m; j++) {
			for(int k = 0; k < l; k++)
			root[i][j][k] = NULL;
		}
	}
	
	pw[0] = 1;
	for(int i = 1; i < MAXPOW; i++)
	pw[i] = pw[i - 1] * 3;
	
	int mid = (n + m + l - 3) / 2, rmid = (n + m + l - 3) - mid;
	for(int i = 0; i < pw[mid]; i++) {
		vector <int> v = base_3(i, mid);
		int x = 0, y = 0, z = 0;
		int curr = arr[x][y][z];
		for(int val : v) {
			if(val == 0) x++;
			else if(val == 1) y++;
			else z++; 
			
			if(x > n - 1 || y > m - 1 || z > l - 1) break;
			curr ^= arr[x][y][z];
		}
		if(x > n - 1 || y > m - 1 || z > l - 1) continue;
		if(root[x][y][z] == NULL) root[x][y][z] = newnode();
		insert(root[x][y][z], LN, curr);
	} 
	
	ll ans = 0;
	for(int i = 0; i < pw[rmid]; i++) {
		vector <int> v = base_3(i, rmid);
		int x = n - 1, y = m - 1, z = l - 1;
		int curr = arr[x][y][z];
		for(int val : v) {
			if(val == 0) x--;
			else if(val == 1) y--;
			else z--;
			
			if(x < 0 || y < 0 || z < 0) break;
			curr ^= arr[x][y][z];
		}
		if(x < 0 || y < 0 || z < 0) continue;
		curr ^= arr[x][y][z];
		ans += query(root[x][y][z], LN, curr, p);
	}
	cout << ans;
}

If you have any doubts regarding the explanation, feel free to ask them below. If you found another way to solve the problem, do mention it.
