KTTREE - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2

Setter: Yogesh Sharma
Tester: Rahul Dugar
Editorialist: Taranpreet Singh

DIFFICULTY

Medium

PREREQUISITES

Bitmasking, Dynamic programming, Disjoint Set Union

PROBLEM

Given a tree with N nodes where each node has a color from 0 to K, find the number of ways to choose a tuple of K nodes such that the i-th node of the tuple has color i and, for every pair of nodes in the tuple, there is at least one node with color 0 on the path between them.

QUICK EXPLANATION

  • Dividing the whole tree into separate components by 0-colored nodes, we can select at most one node from any component. Also, we should select exactly one node of each color.
  • We can use a bitmask to represent colors of already selected nodes after considering the first i components. For each component, we have only (1+K) choices, either select no node from this component or select exactly one, which can be simulated by dynamic programming.

EXPLANATION

The first subtask is just brute force, so we skip it.

The second subtask has N = 2000 and K = 2, meaning we need to count the number of valid pairs, i.e. pairs with at least one node of color 0 on the path between them.

To check this efficiently, we can form components by cutting the tree at the nodes colored 0. This way, no pair of nodes within the same component has a 0-colored node on the path between them, while every pair of nodes in different components has at least one 0-colored node on the path connecting them. Hence, we merge non-zero nodes directly connected by an edge using DSU and, while checking pairs, require that the two nodes belong to different components.
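The K = 2 check described above can be sketched as follows. This is a minimal illustration, not the setter's code: the names `Dsu` and `countPairsK2` are hypothetical, nodes are assumed to be 0-indexed, and the quadratic pair loop is only meant for the small limits of this subtask.

```cpp
#include <cassert>
#include <numeric>
#include <utility>
#include <vector>

// Minimal disjoint set union with path halving (hypothetical helper).
struct Dsu {
    std::vector<int> parent;
    explicit Dsu(int n) : parent(n) { std::iota(parent.begin(), parent.end(), 0); }
    int find(int v) {
        while (parent[v] != v) {
            parent[v] = parent[parent[v]];
            v = parent[v];
        }
        return v;
    }
    void unite(int a, int b) { parent[find(a)] = find(b); }
};

// Counts pairs (u, v) with color[u] == 1 and color[v] == 2 that lie in
// different components once all 0-colored nodes are removed.
long long countPairsK2(const std::vector<int>& color,
                       const std::vector<std::pair<int, int>>& edges) {
    int n = (int)color.size();
    Dsu dsu(n);
    // Merge only across edges whose endpoints are both non-zero.
    for (const auto& e : edges)
        if (color[e.first] != 0 && color[e.second] != 0) dsu.unite(e.first, e.second);
    long long count = 0;
    for (int u = 0; u < n; ++u) {
        if (color[u] != 1) continue;
        for (int v = 0; v < n; ++v)
            if (color[v] == 2 && dsu.find(u) != dsu.find(v)) ++count;
    }
    return count;
}
```

For the path with colors 1, 0, 2 the only pair is separated by the 0-colored node, so the function returns 1; for a single edge joining colors 1 and 2 it returns 0.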

Towards complete Solution

The idea of splitting the tree by nodes colored 0 is highly useful, as it allows us to quickly check whether there’s a node colored 0 on the path between two nodes. But it can help even more. Since two nodes in the same component never have a 0-colored node on the path between them, we cannot select more than one node from each component.

The significance of the above statement is that for each component, we can compute the frequency of nodes with each color, and treat this problem as a subset selection problem.

Let f_{i, c} denote the number of nodes with color c in the i-th component. Hence, if we want to select a c-colored node from the i-th component, there are f_{i, c} ways to do so.

Now, we need to remember which colors have already been selected, and K is quite small. This suggests using a bitmask with K bits, where each bit records whether the node of the corresponding color has already been selected.
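One possible way to compute these per-component frequencies is an iterative traversal that never steps onto a 0-colored node. The sketch below is an assumption-laden illustration: `componentColorCounts` is a hypothetical name, nodes are 0-indexed, `adj` is an adjacency list, and color c (1..k) is stored at index c - 1.

```cpp
#include <cassert>
#include <vector>

// Returns f, where f[i][c - 1] is the number of nodes of color c (1..k)
// in the i-th component left after deleting all 0-colored nodes.
std::vector<std::vector<long long>> componentColorCounts(
        int k, const std::vector<int>& color,
        const std::vector<std::vector<int>>& adj) {
    int n = (int)color.size();
    std::vector<char> visited(n, 0);
    std::vector<std::vector<long long>> f;
    for (int start = 0; start < n; ++start) {
        if (visited[start] || color[start] == 0) continue;
        f.push_back(std::vector<long long>(k, 0LL));  // new component
        std::vector<int> stack = {start};             // explicit stack avoids deep recursion
        visited[start] = 1;
        while (!stack.empty()) {
            int v = stack.back();
            stack.pop_back();
            ++f.back()[color[v] - 1];
            for (int to : adj[v]) {
                if (!visited[to] && color[to] != 0) {
                    visited[to] = 1;
                    stack.push_back(to);
                }
            }
        }
    }
    return f;
}
```

On the path with colors 1, 1, 0, 2 this yields two components: the first holds two nodes of color 1, the second a single node of color 2.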

Hence, we can now maintain the number of ways to select nodes represented by mask from the first x components by ways_{x, mask}. For each component, we may select no node, or exactly one. Working out all cases, we get the following recurrence.

\displaystyle \text{ways}_{x, mask} = \text{ways}_{x-1, mask} + \sum_{c \in S} \text{ways}_{x-1, mask \oplus 2^c} * \text{f}_{x, c} where set S denotes the indices of bits set in mask
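The recurrence can be evaluated directly, as in the sketch below. It assumes the frequencies are already available as `f[x][c]` with bit c of the mask standing for color c + 1, uses the modulus 998244353 that appears in the solutions further down, and keeps only the previous and current layer instead of the full table. The name `countTuples` is hypothetical.

```cpp
#include <cassert>
#include <vector>

const long long MOD = 998244353;  // modulus used by the solutions in this editorial

// f[x][c] = number of nodes of color c + 1 in component x.
// Returns the number of ways to pick one node of every color,
// taking at most one node per component.
long long countTuples(int k, const std::vector<std::vector<long long>>& f) {
    std::vector<long long> prev(1 << k, 0), cur(1 << k, 0);
    prev[0] = 1;  // nothing selected before the first component
    for (const auto& freq : f) {
        for (int mask = 0; mask < (1 << k); ++mask) {
            long long total = prev[mask];  // select no node from this component
            for (int c = 0; c < k; ++c)
                if ((mask >> c) & 1)       // select a node of color c + 1 here
                    total = (total + prev[mask ^ (1 << c)] * (freq[c] % MOD)) % MOD;
            cur[mask] = total;
        }
        prev.swap(cur);
    }
    return prev[(1 << k) - 1];
}
```

With two components whose frequencies are (2, 0) and (0, 1), the result is 2: either of the two nodes of color 1 can be paired with the single node of color 2.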

Since the number of components is of the order N and there are 2^K masks, there are N*2^K states, each computed in O(K) time. The time complexity of this approach is therefore O(N*K*2^K), which is sufficient for all except the last subtask.

One last trick is to notice that the sizes of all components sum to at most N. If, for component x, we loop over all masks only for those colors c with non-zero f_{x, c}, then that component needs at most min(K, size) passes over the masks, where size is the number of nodes in the component. Summing over all components gives at most N passes, i.e. N*2^K operations in total, just by skipping the mask updates when f_{x, c} is zero.
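A sketch of this optimization, under the same assumptions as before (`f[x][c]` holds the count of color c + 1 in component x, modulus 998244353), is given below. The function name `countTuplesSkippingZeros` is hypothetical; the loop order follows the idea of looping over colors first and over masks only when the color is present.

```cpp
#include <cassert>
#include <vector>

const long long MOD = 998244353;  // modulus used by the solutions in this editorial

long long countTuplesSkippingZeros(int k, const std::vector<std::vector<long long>>& f) {
    // cur starts every component as a copy of prev, which covers the
    // "select no node from this component" case implicitly.
    std::vector<long long> prev(1 << k, 0), cur(1 << k, 0);
    prev[0] = cur[0] = 1;
    for (const auto& freq : f) {
        for (int c = 0; c < k; ++c) {
            if (freq[c] == 0) continue;  // absent color: no transition, skip the mask loop
            for (int mask = 0; mask < (1 << k); ++mask)
                if ((mask >> c) & 1)
                    cur[mask] = (cur[mask] + prev[mask ^ (1 << c)] * (freq[c] % MOD)) % MOD;
        }
        prev = cur;
    }
    return prev[(1 << k) - 1];
}
```

A component with s nodes has at most min(K, s) colors present, so the inner mask loop runs at most N times over the whole input.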

TIME COMPLEXITY

The time complexity is O(N*2^K) per test case.
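The memory complexity is O(N*2^K) per test case if every layer of ways is stored, as in the editorialist's solution below; the setter's solution keeps only two DP arrays of 2^K entries.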

SOLUTIONS

Setter's Solution
#pragma GCC optimize("Ofast")
#include <bits/stdc++.h>
using namespace std;
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/rope>
using namespace __gnu_pbds;
using namespace __gnu_cxx;
#ifndef rd
#define trace(...)
//#define endl '\n'
#endif
#define pb push_back
#define fi first
#define se second
#define int long long
typedef long long ll;
typedef long double f80;
#define double long double
#define pii pair<int,int>
#define pll pair<ll,ll>
#define sz(x) ((long long)x.size())
#define fr(a,b,c) for(int a=b; a<=(c); a++)
#define rep(a,b,c) for(int a=b; a<(c); a++)
#define trav(a,x) for(auto &a:x)
#define all(con) con.begin(),con.end()
const ll infl=0x3f3f3f3f3f3f3f3fLL;
const int infi=0x3f3f3f3f;
const int mod=998244353;
//const int mod=1000000007;
typedef vector<int> vi;
typedef vector<ll> vl;
 
typedef tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update> oset;
auto clk=clock();
mt19937_64 rang(chrono::high_resolution_clock::now().time_since_epoch().count());
int rng(int lim) {
	uniform_int_distribution<int> uid(0,lim-1);
	return uid(rang);
}
 
int powm(int a, int b) {
	int res=1;
	while(b) {
		if(b&1)
			res=(res*a)%mod;
		a=(a*a)%mod;
		b>>=1;
	}
	return res;
}
 
long long readInt(long long l, long long r, char endd) {
	long long x=0;
	int cnt=0;
	int fi=-1;
	bool is_neg=false;
	while(true) {
		char g=getchar();
		if(g=='-') {
			assert(fi==-1);
			is_neg=true;
			continue;
		}
		if('0'<=g&&g<='9') {
			x*=10;
			x+=g-'0';
			if(cnt==0) {
				fi=g-'0';
			}
			cnt++;
			assert(fi!=0 || cnt==1);
			assert(fi!=0 || is_neg==false);
 
			assert(!(cnt>19 || ( cnt==19 && fi>1) ));
		} else if(g==endd) {
			if(is_neg) {
				x=-x;
			}
			assert(l<=x&&x<=r);
			return x;
		} else {
			cout<<ll(g)<<" "<<g<<endl;
			assert(false);
		}
	}
}
string readString(int l, int r, char endd) {
	string ret="";
	int cnt=0;
	while(true) {
		char g=getchar();
		assert(g!=-1);
		if(g==endd) {
			break;
		}
		cnt++;
		ret+=g;
	}
	assert(l<=cnt&&cnt<=r);
	return ret;
}
long long readIntSp(long long l, long long r) {
	return readInt(l,r,' ');
}
long long readIntLn(long long l, long long r) {
	return readInt(l,r,'\n');
}
string readStringLn(int l, int r) {
	return readString(l,r,'\n');
}
string readStringSp(int l, int r) {
	return readString(l,r,' ');
}
 
int sum_n=0;
int subtask_n=100000,subtask_k=11;
vi gra[100005];
bool vis[100005];
int c[100005];
void dfs(vi &cnt, int fr, int at) {
	vis[at]=1;
	cnt[c[at]-1]++;
	for(int i:gra[at])
		if(i!=fr&&c[i])
			dfs(cnt, at,i);
}
int cntr=0;
void tree_check(int fr, int at) {
	cntr++;
	for(int i:gra[at])
		if(i!=fr)
			tree_check(at,i);
}
int dp[1<<11],dp2[1<<11];
void solve() {
	int n=readIntSp(1,subtask_n),k=readIntLn(2,subtask_k);
//	int n,k;
//	cin>>n>>k;
	fr(i,1,n) {
		gra[i].clear();
		vis[i]=0;
	}
	memset(dp,0,sizeof(int)*(1<<k));
	memset(dp2,0,sizeof(int)*(1<<k));
	sum_n+=n;
	assert(sum_n<=100000);
	fr(i,1,n) {
//		cin>>c[i];
		if(i!=n)
			c[i]=readIntSp(0,k);
		else
			c[i]=readIntLn(0,k);
	}
	rep(i,1,n) {
//		int u,v;
//		cin>>u>>v;
		int u=readIntSp(1,n),v=readIntLn(1,n);
		assert(u!=v);
		gra[u].pb(v);
		gra[v].pb(u);
	}
	cntr=0;
	tree_check(1,1);
	assert(cntr==n);
	vector<vi> cnts;
	fr(i,1,n)
		if(vis[i]==0&&c[i]) {
			cnts.pb(vi(k,0LL));
			dfs(cnts.back(),i,i);
		}
	dp[0]=dp2[0]=1;
	for(auto &i: cnts) {
		rep(j,0,k)
			if(i[j])
				rep(l,0,1<<k)
					if((l>>j)&1)
						dp2[l]=(dp2[l]+dp[l^(1<<j)]*i[j])%mod;
		memcpy(dp,dp2,sizeof(int)*(1<<k));
	}
	cout<<dp[(1<<k)-1]<<endl;
}
 
 
signed main() {
	ios_base::sync_with_stdio(0),cin.tie(0);
	srand(chrono::high_resolution_clock::now().time_since_epoch().count());
	cout<<fixed<<setprecision(7);
	cerr<<100<<endl;
	int t=readIntLn(1,10);
	cerr<<t<<endl;
//	int t;
//	cin>>t;
	fr(i,1,t)
		solve();
	assert(getchar()==EOF);
#ifdef rd
	cerr<<endl<<endl<<endl<<"Time Elapsed: "<<((double)(clock()-clk))/CLOCKS_PER_SEC<<endl;
#endif
}
Tester's Solution
#pragma GCC optimize("Ofast")
#include <bits/stdc++.h>
using namespace std;
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/rope>
using namespace __gnu_pbds;
using namespace __gnu_cxx;
#ifndef rd
#define trace(...)
#define endl '\n'
#endif
#define pb push_back
#define fi first
#define se second
#define int long long
typedef long long ll;
typedef long double f80;
#define double long double
#define pii pair<int,int>
#define pll pair<ll,ll>
#define sz(x) ((long long)x.size())
#define fr(a,b,c) for(int a=b; a<=(c); a++)
#define rep(a,b,c) for(int a=b; a<(c); a++)
#define trav(a,x) for(auto &a:x)
#define all(con) con.begin(),con.end()
const ll infl=0x3f3f3f3f3f3f3f3fLL;
const int infi=0x3f3f3f3f;
const int mod=998244353;
//const int mod=1000000007;
typedef vector<int> vi;
typedef vector<ll> vl;

typedef tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update> oset;
auto clk=clock();
mt19937_64 rang(chrono::high_resolution_clock::now().time_since_epoch().count());
int rng(int lim) {
	uniform_int_distribution<int> uid(0,lim-1);
	return uid(rang);
}

int powm(int a, int b) {
	int res=1;
	while(b) {
		if(b&1)
			res=(res*a)%mod;
		a=(a*a)%mod;
		b>>=1;
	}
	return res;
}

long long readInt(long long l, long long r, char endd) {
	long long x=0;
	int cnt=0;
	int fi=-1;
	bool is_neg=false;
	while(true) {
		char g=getchar();
		if(g=='-') {
			assert(fi==-1);
			is_neg=true;
			continue;
		}
		if('0'<=g&&g<='9') {
			x*=10;
			x+=g-'0';
			if(cnt==0) {
				fi=g-'0';
			}
			cnt++;
			assert(fi!=0 || cnt==1);
			assert(fi!=0 || is_neg==false);

			assert(!(cnt>19 || ( cnt==19 && fi>1) ));
		} else if(g==endd) {
			if(is_neg) {
				x=-x;
			}
			assert(l<=x&&x<=r);
			return x;
		} else {
			assert(false);
		}
	}
}
string readString(int l, int r, char endd) {
	string ret="";
	int cnt=0;
	while(true) {
		char g=getchar();
		assert(g!=-1);
		if(g==endd) {
			break;
		}
		cnt++;
		ret+=g;
	}
	assert(l<=cnt&&cnt<=r);
	return ret;
}
long long readIntSp(long long l, long long r) {
	return readInt(l,r,' ');
}
long long readIntLn(long long l, long long r) {
	return readInt(l,r,'\n');
}
string readStringLn(int l, int r) {
	return readString(l,r,'\n');
}
string readStringSp(int l, int r) {
	return readString(l,r,' ');
}

int sum_n=0;
int subtask_n=100000,subtask_k=11;
vi gra[100005];
bool vis[100005];
int c[100005];
void dfs(vi &cnt, int fr, int at) {
	vis[at]=1;
	cnt[c[at]-1]++;
	for(int i:gra[at])
		if(i!=fr&&c[i])
			dfs(cnt, at,i);
}
int cntr=0;
void tree_check(int fr, int at) {
	cntr++;
	for(int i:gra[at])
		if(i!=fr)
			tree_check(at,i);
}
int dp[1<<11],dp2[1<<11];
void solve() {
	int n=readIntSp(1,subtask_n),k=readIntLn(2,subtask_k);
	fr(i,1,n) {
		gra[i].clear();
		vis[i]=0;
	}
	memset(dp,0,sizeof(int)*(1<<k));
	memset(dp2,0,sizeof(int)*(1<<k));
	sum_n+=n;
	assert(sum_n<=100000);
	fr(i,1,n) {
		if(i!=n)
			c[i]=readIntSp(0,k);
		else
			c[i]=readIntLn(0,k);
	}
	rep(i,1,n) {
		int u=readIntSp(1,n),v=readIntLn(1,n);
		assert(u!=v);
		gra[u].pb(v);
		gra[v].pb(u);
	}
	cntr=0;
	tree_check(1,1);
	assert(cntr==n);
	vector<vi> cnts;
	fr(i,1,n)
		if(vis[i]==0&&c[i]) {
			cnts.pb(vi(k,0LL));
			dfs(cnts.back(),i,i);
		}
	dp[0]=dp2[0]=1;
	for(auto &i: cnts) {
		rep(j,0,k)
			if(i[j])
				rep(l,0,1<<k)
					if((l>>j)&1)
						dp2[l]=(dp2[l]+dp[l^(1<<j)]*i[j])%mod;
		memcpy(dp,dp2,sizeof(int)*(1<<k));
	}
	cout<<dp[(1<<k)-1]<<endl;
}


signed main() {
	ios_base::sync_with_stdio(0),cin.tie(0);
	srand(chrono::high_resolution_clock::now().time_since_epoch().count());
	cout<<fixed<<setprecision(7);
	int t=readIntLn(1,10);
//	int t;
//	cin>>t;
	fr(i,1,t)
		solve();
	assert(getchar()==EOF);
#ifdef rd
	cerr<<endl<<endl<<endl<<"Time Elapsed: "<<((double)(clock()-clk))/CLOCKS_PER_SEC<<endl;
#endif
}
Editorialist's Solution
import java.util.*;
import java.io.*;
class KTTREE{
	//SOLUTION BEGIN
	long MOD = 998244353;
	void pre() throws Exception{}
	void solve(int TC) throws Exception{
	    int N = ni(), K = ni();
	    int[] col = new int[N];
	    for(int i = 0; i< N; i++)col[i] = ni();
	    int[] from = new int[N-1], to = new int[N-1];
	    int[] set = java.util.stream.IntStream.range(0, N).toArray();
	    for(int i = 0; i< N-1; i++){
	        from[i] = ni()-1;
	        to[i] = ni()-1;
	        if((col[from[i]] == 0) == (col[to[i]] == 0)){
	            //merging
	            set[find(set, from[i])] = find(set, to[i]);
	        }
	    }

	    //relabeling
	    int cnt = 0;
	    int[] map = new int[N];
	    for(int i = 0; i< N; i++)if(find(set, i) == i)map[i] = cnt++;

	    int[][] count = new int[cnt][1+K];
	    for(int i = 0; i< N; i++)count[map[find(set, i)]][col[i]]++;

	    long[][] ways = new long[1+cnt][1<<K];
	    ways[0][0] = 1;
	    for(int i = 0; i< cnt; i++){
	        for(int mask = 0; mask < 1<<K; mask++)ways[i+1][mask] = ways[i][mask];
	        for(int color = 1; color <= K; color++){
	            if(count[i][color] > 0){
	                int cur = 1<<(color-1);
	                long way = count[i][color];
	                for(int mask = 0; mask < 1<<K; mask++){
	                    if((mask&cur) == 0){
	                        ways[i+1][mask|cur] += ways[i][mask]*way%MOD;
	                        if(ways[i+1][mask|cur] >= MOD)ways[i+1][mask|cur] -= MOD;
	                    }
	                }
	            }
	        }
	    }
	    pn(ways[cnt][(1<<K)-1]);
	}
	int find(int[] set, int u){return set[u] = set[u] == u?u:find(set, set[u]);}
	//SOLUTION END
	void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
	static boolean multipleTC = true;
	FastReader in;PrintWriter out;
	void run() throws Exception{
	    in = new FastReader();
	    out = new PrintWriter(System.out);
	    //Solution Credits: Taranpreet Singh
	    int T = (multipleTC)?ni():1;
	    pre();for(int t = 1; t<= T; t++)solve(t);
	    out.flush();
	    out.close();
	}
	public static void main(String[] args) throws Exception{
	    new KTTREE().run();
	}
	int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
	void p(Object o){out.print(o);}
	void pn(Object o){out.println(o);}
	void pni(Object o){out.println(o);out.flush();}
	String n()throws Exception{return in.next();}
	String nln()throws Exception{return in.nextLine();}
	int ni()throws Exception{return Integer.parseInt(in.next());}
	long nl()throws Exception{return Long.parseLong(in.next());}
	double nd()throws Exception{return Double.parseDouble(in.next());}

	class FastReader{
	    BufferedReader br;
	    StringTokenizer st;
	    public FastReader(){
	        br = new BufferedReader(new InputStreamReader(System.in));
	    }

	    public FastReader(String s) throws Exception{
	        br = new BufferedReader(new FileReader(s));
	    }

	    String next() throws Exception{
	        while (st == null || !st.hasMoreElements()){
	            try{
	                st = new StringTokenizer(br.readLine());
	            }catch (IOException  e){
	                throw new Exception(e.toString());
	            }
	        }
	        return st.nextToken();
	    }

	    String nextLine() throws Exception{
	        String str = "";
	        try{   
	            str = br.readLine();
	        }catch (IOException e){
	            throw new Exception(e.toString());
	        }  
	        return str;
	    }
	}
}

Feel free to share your approach. Suggestions are welcomed as always. :slight_smile:

Woahhhh… this is such a cool approach. Never seen splitting a tree into components before.

Very nice problem… :+1:
Can anyone please explain the last trick part of the editorial a bit more…

Can someone please help me understand why this solution gives RE on some subtasks?
Submission Link

Can somebody clarify the last optimization for me? I could not understand how the complexity dropped from N * K * 2^K to N * 2^K.

The N * K essentially means ‘for each of the C (C <= N) components, iterate over each of the K colors, include this color (update the mask), and check in the next component’. However, if you do this cleverly, you observe that the total houses of all colors in all components (the C * K factor) cannot be greater than N. Hence, you only need to skip colors with 0 frequency in each component to ensure that you don’t iterate more than N times.

A much better simplification would be to disregard maintaining the frequency, but instead only keep track of which node lies in which component, iterate over these nodes (instead of colors) for every component, update the mask and add to the answer accordingly. It probably works a bit slower compared to the editorial’s solution, but passes with ease. My Solution

@taran_1407 Could you explain the “last trick” in more detail? I guess others also do not understand it very well.

Sorry for the delay, saw this now.

The last trick is this: we iterate over the components and, for each component, over the colors within it. If some color c is not present in component x at all, then f_{x, c} is zero and its transitions do not affect any state. So, for each component, iterate only over the colors which have at least one node.

This way, summing over all components, there would be at most N iterations.