KFOREST - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: yash5507
Testers: iceknight1093, rivalq
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Depth-first search, binary search (in the binary-lifting style)

PROBLEM:

You’re given a tree on N vertices and an integer K.
The i-th vertex has value A_i written on it.

Split the tree into a forest of K trees such that the bitwise AND of the bitwise XORs of the trees is maximized.

EXPLANATION:

As in many problems dealing with bitwise operations, we can build the answer bit-by-bit.

To maximize the answer, it’s better to have a higher bit in the answer, even at the cost of lower bits.
So, let’s iterate over bits from 29 down to 0, each time checking whether it can be added to the answer.

That is, we have the following algorithm:

  • Initialize ans = 0.
  • For each bit b from 29 to 0:
    • If it’s possible for the answer to be at least (ans + 2^b), set ans \gets ans + 2^b.
    • Otherwise, do nothing.

The final value of ans is what we want.
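
The loop above can be sketched as follows. This is only an illustration: `build_max` and the predicate `feasible` are hypothetical names, and the toy predicate used here stands in for the real tree check described later.

```python
def build_max(feasible, bits=30):
    """Greedily build the largest value, highest bit first.

    `feasible(x)` is a hypothetical predicate: it should return True
    when the bits of x can all be present in the answer.
    """
    ans = 0
    for b in reversed(range(bits)):
        x = ans + (1 << b)
        if feasible(x):
            ans = x          # keep bit b
        # otherwise leave ans unchanged
    return ans


# Toy predicate (an assumption for the demo): x is feasible
# when all of its bits lie inside 0b101101.
target = 0b101101
print(build_max(lambda x: x & ~target == 0))  # 45
```

Each call to `feasible` only ever sees a candidate made of the bits already accepted plus one new bit, which is what the rest of the argument relies on.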

Notice that this requires us to check whether the answer can be at least ans + 2^b.
Let’s focus on doing that now.
For convenience, let x = ans + 2^b.

Notice that our process builds the answer bit-by-bit, so in our check, we can simply ignore all bits \lt b (which are all unset in x anyway).
Also, every bit larger than b has already been fixed, so we can’t change those.
In particular, if a bit larger than b is 0 in x, then we know there’s no way it can be in the answer.
So, we can ignore this bit when performing computations.

Note that the above discussion tells us that we’re ignoring all bits that are 0 in x.
In particular, that means we’re actually considering the values of A_i \ \& \ x for each i, instead of just A_i.

This means that we simply need to check whether we can divide the given tree into K trees such that:

  • Let B_j be the bitwise XOR of the (A_i \ \& \ x) values of the j-th tree.
  • Then, we want B_1 \ \& \ B_2 \ \&\ \ldots \ \& \ B_K = x

Now, note that for x to be the bitwise AND of the B_j values, each of them must be a supermask of x.
However, each B_j is the bitwise XOR of some values that are submasks of x; thus making it also a submask of x.
Being both a submask and a supermask means that every B_j must be equal to x.
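
As a sanity check of the submask/supermask argument, here is a small brute-force sketch on 4-bit values. The helper `submasks_of` and the chosen `x` are made up for illustration.

```python
from functools import reduce
from itertools import product
from operator import and_


def submasks_of(x, width=4):
    """All values below 2**width whose set bits lie inside x."""
    return [s for s in range(1 << width) if s & ~x == 0]


x = 0b1100
subs = submasks_of(x)  # [0, 4, 8, 12]

# XOR of two submasks of x never leaves x's bits.
assert all((p ^ q) & ~x == 0 for p, q in product(subs, repeat=2))

# So the AND of such values equals x only when every value equals x.
for bs in product(subs, repeat=3):
    if reduce(and_, bs) == x:
        assert all(b == x for b in bs)
```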

So, our objective is to figure out whether the tree can be partitioned into K trees, each with bitwise XOR x.

To do that, let’s find out the maximum number of subtrees with XOR x that the given tree can be partitioned into.
That can be done using DFS, as follows:

  • Start a DFS from some node, say 1.
  • When at a node u,
    • Let S_u denote the subtree XOR of u.
    • If S_u \neq x, do nothing.
    • If S_u = x, we’ve found one subtree with XOR x. We can ‘cut’ it out, and then pretend the edge between u and its parent doesn’t exist anymore (so u and its descendants won’t contribute to the subtree XOR of the parent of u).

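The number of subtrees we cut out is the maximum number we’re looking for.
Let it be y.
One caveat: once all cuts are made, whatever is still attached to the root must have XOR 0 (if the root’s own S equals x, it is cut out like any other vertex). If something non-zero is left over, no valid partition exists and the check fails immediately.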
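
A minimal iterative sketch of this greedy count, assuming 0-indexed vertices and an edge list. The name `count_cuts` and the convention of returning -1 when the part left at the root has non-zero XOR are choices made for this example.

```python
def count_cuts(n, edges, a, x):
    """Greedy count of pieces whose masked XOR equals x.

    Returns -1 when the part left attached to the root has a
    non-zero masked XOR, i.e. no valid partition exists.
    """
    graph = [[] for _ in range(n)]
    for u, v in edges:
        graph[u].append(v)
        graph[v].append(u)

    # Preorder from vertex 0; every vertex appears after its parent.
    parent = [-1] * n
    order = []
    stack = [0]
    while stack:
        u = stack.pop()
        order.append(u)
        for v in graph[u]:
            if v != parent[u]:
                parent[v] = u
                stack.append(v)

    sub = [v & x for v in a]
    cuts = 0
    for u in reversed(order):        # children before parents
        if sub[u] == x:
            cuts += 1
            sub[u] = 0               # cut: contributes nothing upward
        if parent[u] != -1:
            sub[parent[u]] ^= sub[u]
    return cuts if sub[0] == 0 else -1


path = [(0, 1), (1, 2)]
print(count_cuts(3, path, [1, 1, 1], 1))  # 3
print(count_cuts(3, path, [1, 0, 1], 1))  # 2
print(count_cuts(3, path, [1, 0, 0], 3))  # -1
```

Processing vertices in reverse preorder avoids recursion, which matters in Python for deep trees.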
Now,

  • If y \lt K, then forming K trees is obviously impossible.
  • Otherwise, y \geq K.
  • Note that when y \geq 3, we can form y-2 trees with XOR x by merging three adjacent trees into one (the merged tree has XOR x \oplus x \oplus x = x).
    • In particular, repeating this allows us to form y, y-2, y-4, \ldots trees, i.e., any number with the same parity as y.
  • It’s not hard to see that these are the only possible numbers achievable.
  • So, obtaining K trees is possible if and only if y \geq K and (y-K) is even.
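
The final condition is short enough to write down directly. In this sketch, `can_split` is a hypothetical helper name and `y` is assumed to be the greedy count from the DFS.

```python
def can_split(y, k):
    """True when exactly k pieces with XOR x can be formed from y greedy cuts."""
    return y >= k and (y - k) % 2 == 0


x = 0b1011
assert x ^ x ^ x == x   # merging three pieces keeps the XOR equal to x
assert x ^ x == 0       # merging only two does not (for x != 0)

print(can_split(5, 3))  # True
print(can_split(5, 4))  # False
print(can_split(2, 3))  # False
```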

Thus we have an \mathcal{O}(N) time algorithm to check whether the answer is at least x.
Doing this for each bit gives us a solution in \mathcal{O}(30N).

It can be noted that the method we use to build the answer bit-by-bit is technically the exact same thing as binary searching on the range [0, 2^{30}).
However, the function we’re searching on isn’t monotonic; we only get the correct answer by utilizing the fact that we can work bit-by-bit.

In particular, this means that if you binary search on [0, M) where M is not a power of 2, you most likely won’t get AC.
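
To see why the range matters, here is a toy illustration. It uses a made-up non-monotonic predicate ("x is a submask of 0b101"), not the actual tree check, and `binary_search_max` is a hypothetical helper.

```python
def binary_search_max(pred, hi):
    """Plain binary search on [0, hi); assumes pred(0) holds."""
    lo = 0
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if pred(mid):
            lo = mid
        else:
            hi = mid
    return lo


target = 0b101


def pred(x):
    # Non-monotonic: true for 0, 1, 4, 5 and false for 2, 3, 6, 7.
    return x & ~target == 0


print(binary_search_max(pred, 8))  # 5
print(binary_search_max(pred, 6))  # 1
```

With the power-of-two range, every midpoint is the current answer plus a single bit, so the search coincides with the bit-by-bit construction and finds 5. With the range [0, 6), the first midpoint is 3, the predicate fails there, and the search discards the upper half that contains the true maximum.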

TIME COMPLEXITY:

\mathcal{O}(N\log{\max(A)}) per test case.

CODE:

Setter's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define ull unsigned long long 
#define pb(e) push_back(e)
#define sv(a) sort(a.begin(),a.end())
#define sa(a,n) sort(a,a+n)
#define mp(a,b) make_pair(a,b)
#define vf first
#define vs second
#define ar array
#define all(x) x.begin(),x.end()
const int inf = 0x3f3f3f3f;
const int mod = 1000000007; 
const double PI=3.14159265358979323846264338327950288419716939937510582097494459230;
 
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());
 
bool remender(ll a , ll b){return a%b;}
 
//freopen("problemname.in", "r", stdin);
//freopen("problemname.out", "w", stdout);
 
const int N = 200003;
 
vector<int> adj[N];
int arr[N];
int cmp;
 
int dfs(int node , int par , int desire){
	int cur = arr[node];
	for(int i : adj[node]){
		if(i == par)continue;
		cur ^= dfs(i , node , desire);
	}
	if((cur&desire)==desire){
		cmp++;
		cur = 0;
	}
	return cur;
}
 
int solve(int k , int n){
	int ans = 0;
	for(int i = 30; i >= 0; i--){
		ans += (1 << i);
		cmp = 0;
		int x = dfs(1,1,ans);
		if((x&ans) != 0)cmp = 0;
		if(cmp < k || (cmp - k) % 2 == 1){
			ans -= (1<<i);
		}
	}
	return ans;
}
 
int main(){
ios_base::sync_with_stdio(false);
cin.tie(NULL);
	int t;cin >> t;while(t--){
	int n;
	cin >> n;
	int k;
	cin >> k;
	for(int i = 1; i <= n; i++)cin >> arr[i];
	for(int i = 0; i < n - 1; i++){
		int u , v;
		cin >> u >> v;
		adj[u].pb(v);
		adj[v].pb(u);
	}
	cout << solve(k , n) << '\n';
	for(int i = 0; i <= n; i++)adj[i].clear();
	}
	return 0;
}
Tester's code (C++)
// Jai Shree Ram  
  
#include<bits/stdc++.h>
using namespace std;

#define rep(i,a,n)     for(int i=a;i<n;i++)
#define ll             long long
#define int            long long
#define pb             push_back
#define all(v)         v.begin(),v.end()
#define endl           "\n"
#define x              first
#define y              second
#define gcd(a,b)       __gcd(a,b)
#define mem1(a)        memset(a,-1,sizeof(a))
#define mem0(a)        memset(a,0,sizeof(a))
#define sz(a)          (int)a.size()
#define pii            pair<int,int>
#define hell           1000000007
#define elasped_time   1.0 * clock() / CLOCKS_PER_SEC



template<typename T1,typename T2>istream& operator>>(istream& in,pair<T1,T2> &a){in>>a.x>>a.y;return in;}
template<typename T1,typename T2>ostream& operator<<(ostream& out,pair<T1,T2> a){out<<a.x<<" "<<a.y;return out;}
template<typename T,typename T1>T maxs(T &a,T1 b){if(b>a)a=b;return a;}
template<typename T,typename T1>T mins(T &a,T1 b){if(b<a)a=b;return a;}

// -------------------- Input Checker Start --------------------
 
long long readInt(long long l, long long r, char endd)
{
    long long x = 0;
    int cnt = 0, fi = -1;
    bool is_neg = false;
    while(true)
    {
        char g = getchar();
        if(g == '-')
        {
            assert(fi == -1);
            is_neg = true;
            continue;
        }
        if('0' <= g && g <= '9')
        {
            x *= 10;
            x += g - '0';
            if(cnt == 0)
                fi = g - '0';
            cnt++;
            assert(fi != 0 || cnt == 1);
            assert(fi != 0 || is_neg == false);
            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
        }
        else if(g == endd)
        {
            if(is_neg)
                x = -x;
            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(false);
            }
            return x;
        }
        else
        {
            assert(false);
        }
    }
}
 
string readString(int l, int r, char endd)
{
    string ret = "";
    int cnt = 0;
    while(true)
    {
        char g = getchar();
        assert(g != -1);
        if(g == endd)
            break;
        cnt++;
        ret += g;
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}
 
long long readIntSp(long long l, long long r) { return readInt(l, r, ' '); }
long long readIntLn(long long l, long long r) { return readInt(l, r, '\n'); }
string readStringLn(int l, int r) { return readString(l, r, '\n'); }
string readStringSp(int l, int r) { return readString(l, r, ' '); }
void readEOF() { assert(getchar() == EOF); }
 
vector<int> readVectorInt(int n, long long l, long long r)
{
    vector<int> a(n);
    for(int i = 0; i < n - 1; i++)
        a[i] = readIntSp(l, r);
    a[n - 1] = readIntLn(l, r);
    return a;
}
 
// -------------------- Input Checker End --------------------

const int maxn=2000005;
int p[maxn];
int sz[maxn];
void clear(int n=maxn){
    rep(i,0,n)p[i]=i,sz[i]=1;
}
int root(int x){
   while(x!=p[x]){
       p[x]=p[p[x]];
       x=p[x];
   }
   return x;  
}
void merge(int x,int y){
    int p1=root(x);
    int p2=root(y);
    if(p1==p2)return;
    if(sz[p1]>=sz[p2]){
        p[p2]=p1;
        sz[p1]+=sz[p2];
    }
    else{
        p[p1]=p2;
        sz[p2]+=sz[p1];
    }
}

int solve(){

		int n = readIntSp(1, 2e5);
		int k = readIntLn(1, n);
		static int sum_n = 0;
		sum_n += n;
		assert(sum_n <= 2e5);
		vector<int> a = readVectorInt(n, 0, 1e9);
		vector<vector<int>> g(n);
		clear(n);
		for(int i = 2; i <= n; i++){
			int u = readIntSp(1, n); u--;
			int v = readIntLn(1, n); v--;
			g[u].push_back(v);
			g[v].push_back(u);
			assert(root(u) != root(v));
			merge(u, v);
		}

		auto check = [&] (int mask) {
			int cnt = 0;

			function<int(int,int)> dfs = [&](int u, int p){
				int val = a[u];
				for(auto i: g[u]){
					if(i != p){
						val ^= dfs(i, u);
					}
				}
				if((val & mask) == mask) {
					cnt++;
					val = 0;
				}
				return val;
			};
			if((dfs(0,0) & mask) != 0) return -1LL;
			
			return cnt;
		};
		int ans = 0;
		for(int j = 30; j >= 0; j--){
			int c = check(ans + (1LL << j));
			if(c >= k and k % 2 == c % 2) ans += (1LL << j);
		}
		cout << ans << endl;


 return 0;
}
signed main(){
    ios_base::sync_with_stdio(0);cin.tie(0);cout.tie(0);
    //freopen("input.txt", "r", stdin);
    //freopen("output.txt", "w", stdout);
    #ifdef SIEVE
    sieve();
    #endif
    #ifdef NCR
    init();
    #endif
    int t = readIntLn(1, 5e4);
    while(t--){
        solve();
    }
    return 0;
}
Editorialist's code (Python)
import sys
input = sys.stdin.readline
for _ in range(int(input())):
	n, k = map(int, input().split())
	a = list(map(int, input().split()))
	graph = [[] for _ in range(n)]
	for i in range(n-1):
		u, v = map(int, input().split())
		graph[u-1].append(v-1)
		graph[v-1].append(u-1)

	subxor = [0] * n
	visited = [0] * n
	def dfs(mask):
		ct = 0
		for i in range(n): visited[i] = 0

		stack = [0]
		while stack:
			start = stack[-1]
			if not visited[start]:
				visited[start] = 1
				for child in graph[start]:
					if not visited[child]:
						stack.append(child)
			else:
				stack.pop()
				subxor[start] = a[start] & mask
				for child in graph[start]:
					if visited[child] == 2:
						subxor[start] ^= subxor[child]

				if subxor[start] == mask:
					subxor[start] = 0
					ct += 1
				visited[start] = 2
		if subxor[0] != 0: return 0
		return ct

	ans = 0
	for bit in reversed(range(30)):
		ct = dfs(ans + 2**bit)
		if ct%2 == k%2 and ct >= k: ans += 2**bit
	print(ans)