PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Setter: Yogesh Sharma
Tester: Rahul Dugar
Editorialist: Taranpreet Singh
DIFFICULTY
Medium
PREREQUISITES
Bitmasking, Dynamic programming, Disjoint Set Union
PROBLEM
Given a tree with N nodes, where each node has a color between 0 and K inclusive, count the tuples of K nodes such that the i-th node of the tuple has color i and, for every pair of nodes in the tuple, the path between them contains at least one node of color 0. (The reference solutions print the count modulo 998244353.)
QUICK EXPLANATION
- Removing the nodes of color 0 splits the tree into components. Two nodes in the same component have no 0-colored node on the path between them, so we can select at most one node from each component, and we must select exactly one node of each color from 1 to K.
- Use a K-bit mask to record which colors have already been selected after processing the first x components. Each component offers only 1+K choices (select no node, or select exactly one node of some color c), so a dynamic program over (component, mask) counts the valid selections.
EXPLANATION
The first subtask can be solved by brute force, so we skip it.
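For reference, a brute force for the first subtask might look like the sketch below. It assumes 0-indexed nodes, an edge list, and tiny inputs; `bruteForce` and `pathHasZero` are hypothetical names, not taken from the reference solutions. It roots the tree at node 0, tests each pair by walking both endpoints up to their lowest common ancestor, and enumerates every tuple, so it is exponential and only useful as a checker for the faster approaches.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

// Walks u and v up to their lowest common ancestor and reports whether any
// node on the path (endpoints included) has color 0. Assumes par/depth
// describe the tree rooted at node 0.
bool pathHasZero(int u, int v, const std::vector<int>& par,
                 const std::vector<int>& depth, const std::vector<int>& col) {
    while (u != v) {
        if (depth[u] < depth[v]) std::swap(u, v);  // always lift the deeper node
        if (col[u] == 0) return true;
        u = par[u];
    }
    return col[u] == 0;  // u == v is the lowest common ancestor
}

// Hypothetical brute force for tiny inputs: nodes are 0-indexed, col[i] in [0, K].
long long bruteForce(int n, int k, const std::vector<int>& col,
                     const std::vector<std::pair<int, int>>& edges) {
    assert(static_cast<int>(col.size()) == n);
    std::vector<std::vector<int>> adj(n);
    for (auto [u, v] : edges) {
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    // Root the tree at node 0 with a BFS to get parent and depth arrays.
    std::vector<int> par(n, -1), depth(n, 0), order{0};
    std::vector<bool> seen(n, false);
    seen[0] = true;
    for (std::size_t i = 0; i < order.size(); ++i) {
        int u = order[i];
        for (int w : adj[u]) {
            if (seen[w]) continue;
            seen[w] = true;
            par[w] = u;
            depth[w] = depth[u] + 1;
            order.push_back(w);
        }
    }
    std::vector<std::vector<int>> byColor(k + 1);
    for (int i = 0; i < n; ++i) byColor[col[i]].push_back(i);
    // Pick a node for colors 1..K in order, keeping only tuples in which
    // every pair is separated by a 0-colored node.
    std::vector<int> chosen;
    std::function<long long(int)> go = [&](int c) -> long long {
        if (c > k) return 1;
        long long total = 0;
        for (int v : byColor[c]) {
            bool ok = true;
            for (int u : chosen)
                if (!pathHasZero(u, v, par, depth, col)) { ok = false; break; }
            if (!ok) continue;
            chosen.push_back(v);
            total += go(c + 1);
            chosen.pop_back();
        }
        return total;
    };
    return go(1);
}
```

For example, on a star with a 0-colored center (node 0) and leaves colored 1, 2, 3, 1, plus one extra color-1 node hanging off the color-2 leaf, only the two color-1 leaves attached to the center can be paired with the color-2 and color-3 leaves.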
The second subtask has N ≤ 2000 and K = 2, meaning we need to count the valid pairs (one node of color 1 and one node of color 2) such that the path between them contains at least one node of color 0.
To check this efficiently, we split the tree into components separated by the 0-colored nodes. Any two nodes in the same component have no 0-colored node on the path between them, while any two nodes in different components always have at least one: if their path contained no 0-colored node, every edge on it would join two non-zero nodes, and the endpoints would lie in the same component. Hence, we use DSU to merge non-zero nodes joined directly by an edge, and while checking pairs, verify that the two nodes belong to different components.
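A minimal sketch of this subtask-2 idea, assuming 0-indexed nodes and an edge list (`DSU` and `countPairsK2` are hypothetical names). It merges only edges whose endpoints are both non-zero and then checks every (color 1, color 2) pair directly, which is O(N^2) and fine for N ≤ 2000.

```cpp
#include <cassert>
#include <numeric>
#include <utility>
#include <vector>

// Minimal disjoint-set union with path halving (no union by rank).
struct DSU {
    std::vector<int> parent;
    explicit DSU(int n) : parent(n) { std::iota(parent.begin(), parent.end(), 0); }
    int find(int x) {
        while (parent[x] != x) {
            parent[x] = parent[parent[x]];  // path halving
            x = parent[x];
        }
        return x;
    }
    void unite(int a, int b) { parent[find(a)] = find(b); }
};

// Subtask 2 sketch (K = 2): merge edges whose endpoints are both non-zero,
// then count (color-1, color-2) pairs that end up in different components.
long long countPairsK2(int n, const std::vector<int>& col,
                       const std::vector<std::pair<int, int>>& edges) {
    assert(static_cast<int>(col.size()) == n);
    DSU dsu(n);
    for (auto [u, v] : edges)
        if (col[u] != 0 && col[v] != 0) dsu.unite(u, v);
    std::vector<int> ones, twos;
    for (int i = 0; i < n; ++i) {
        if (col[i] == 1) ones.push_back(i);
        else if (col[i] == 2) twos.push_back(i);
    }
    // At most (N/2)^2 = 10^6 pairs for N = 2000, so no modulus is needed here.
    long long pairs = 0;
    for (int u : ones)
        for (int v : twos)
            if (dsu.find(u) != dsu.find(v)) ++pairs;
    return pairs;
}
```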
Towards the Complete Solution
The idea of splitting the tree at 0-colored nodes is highly useful, as it lets us quickly check whether there's a 0-colored node on the path between two nodes. But it helps even more: since two nodes in the same component have no 0-colored node on the path between them, we cannot select more than one node from any component.
Consequently, for each component we only need the number of nodes of each color, and the problem becomes a subset-selection problem over components.
Let f_{i, c} denote the number of nodes with color c in the i-th component. If we want to select a c-colored node from the i-th component, there are f_{i, c} ways to do so.
Since we need to track which colors have already been selected, and K is small, we use a bitmask with K bits, where a set bit means that color has already been selected.
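The frequency table f_{i, c} can be built with one traversal per component. This is a sketch assuming 0-indexed nodes and an edge list; `componentFrequencies` is a hypothetical helper, and column 0 of each row is left unused so that row[c] is simply f_{i, c}. It uses an iterative DFS over the non-zero nodes, whereas the setter's solution uses a recursive DFS and the editorialist's uses DSU.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical helper: freq[i][c] = f_{i,c}, the number of color-c nodes in the
// i-th component of non-zero nodes (column 0 is unused). Nodes are 0-indexed.
std::vector<std::vector<long long>> componentFrequencies(
        int n, int k, const std::vector<int>& col,
        const std::vector<std::pair<int, int>>& edges) {
    assert(static_cast<int>(col.size()) == n);
    // Keep only edges whose endpoints are both non-zero, so a traversal can
    // never cross a 0-colored node.
    std::vector<std::vector<int>> adj(n);
    for (auto [u, v] : edges) {
        if (col[u] != 0 && col[v] != 0) {
            adj[u].push_back(v);
            adj[v].push_back(u);
        }
    }
    std::vector<bool> seen(n, false);
    std::vector<std::vector<long long>> freq;
    for (int s = 0; s < n; ++s) {
        if (col[s] == 0 || seen[s]) continue;
        freq.push_back(std::vector<long long>(k + 1, 0));  // new component
        std::vector<int> stack{s};  // iterative DFS avoids deep recursion
        seen[s] = true;
        while (!stack.empty()) {
            int u = stack.back();
            stack.pop_back();
            ++freq.back()[col[u]];
            for (int w : adj[u]) {
                if (!seen[w]) {
                    seen[w] = true;
                    stack.push_back(w);
                }
            }
        }
    }
    return freq;
}
```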
Let ways_{x, mask} be the number of ways to select, from the first x components, exactly one node of each color in mask (and no node of any other color), with at most one node per component. For component x we either select no node, or select exactly one node of some color c in mask, which gives the following recurrence.
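\displaystyle \text{ways}_{x, mask} = \text{ways}_{x-1, mask} + \sum_{c \in S} \text{ways}_{x-1, mask \oplus 2^{c-1}} \cdot f_{x, c}
where S is the set of colors c whose bit (bit c-1) is set in mask. The base case is ways_{0, 0} = 1, and the answer is ways_{C, 2^K - 1}, where C is the number of components.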
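Evaluated directly, the recurrence might look like this sketch. It assumes the frequency rows use column c for color c (column 0 unused) and stores color c in bit c-1; `countByRecurrence` is a hypothetical name. Only two rows of 2^K entries are kept, since row x depends only on row x-1.

```cpp
#include <cassert>
#include <vector>

constexpr long long MOD = 998244353;

// Direct evaluation of the recurrence. freq[x][c] = f_{x,c} for colors
// c = 1..K (column 0 unused); color c is stored in bit c-1 of the mask.
long long countByRecurrence(int k, const std::vector<std::vector<long long>>& freq) {
    assert(k >= 1);
    const int full = 1 << k;
    std::vector<long long> prev(full, 0), cur(full, 0);
    prev[0] = 1;  // ways_{0, 0} = 1: nothing selected yet
    for (const auto& f : freq) {
        for (int mask = 0; mask < full; ++mask) {
            long long value = prev[mask];  // select no node from this component
            for (int c = 1; c <= k; ++c) {
                const int bit = 1 << (c - 1);
                if (mask & bit)  // select one node of color c from this component
                    value = (value + prev[mask ^ bit] * (f[c] % MOD)) % MOD;
            }
            cur[mask] = value;
        }
        prev.swap(cur);  // row x becomes the previous row for component x+1
    }
    return prev[full - 1];
}
```

Each state costs O(K) here, which is exactly the O(N*K*2^K) behaviour discussed next.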
The number of components is O(N) and there are 2^K masks, giving O(N*2^K) states; each state takes O(K) time, so this approach runs in O(N*K*2^K), which is fast enough for all subtasks except the last.
The last trick is to change the loop order: for each component, loop over the colors c with f_{x, c} > 0, and only for those colors make one pass over all 2^K masks. A component of size s contains at most min(K, s) distinct colors, so it triggers at most min(K, s) passes. Since the component sizes sum to at most N, the total work over all components is O(N*2^K), obtained just by skipping the mask updates whenever f_{x, c} is zero.
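A sketch of the reordered loop, under the same assumptions as before (column c holds f_{x, c}, color c uses bit c-1; `countSkippingEmptyColors` is a hypothetical name). Like the setter's solution, it starts each component from a copy of the previous row (the "select nothing" case) and adds one pass per color that is actually present.

```cpp
#include <cassert>
#include <utility>
#include <vector>

constexpr long long MOD = 998244353;

// Same DP with colors in the outer loop: each component costs one 2^K copy
// plus one 2^K pass per color it actually contains.
long long countSkippingEmptyColors(int k, const std::vector<std::vector<long long>>& freq) {
    assert(k >= 1);
    const int full = 1 << k;
    std::vector<long long> prev(full, 0);
    prev[0] = 1;
    for (const auto& f : freq) {
        std::vector<long long> cur = prev;  // "select nothing" term for every mask
        for (int c = 1; c <= k; ++c) {
            if (f[c] == 0) continue;  // color absent: no pass over the masks
            const int bit = 1 << (c - 1);
            const long long w = f[c] % MOD;
            for (int mask = 0; mask < full; ++mask)
                if (mask & bit)  // reads only prev, so the update order is irrelevant
                    cur[mask] = (cur[mask] + prev[mask ^ bit] * w) % MOD;
        }
        prev = std::move(cur);
    }
    return prev[full - 1];
}
```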
TIME COMPLEXITY
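The time complexity is O(N*2^K) per test case.
The memory complexity is O(N*2^K) per test case if the whole ways table is stored, as in the editorialist's solution. Since row x depends only on row x-1, keeping two rows of 2^K entries (as the setter's solution does) reduces the DP memory to O(2^K).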
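One way to see the time bound, writing C for the number of components and s_i for the size of component i (notation introduced here), is to charge each component one 2^K copy plus one 2^K pass per color it contains:

```latex
\sum_{i=1}^{C} \Bigl( 2^K + \min(K, s_i)\, 2^K \Bigr)
  \;\le\; 2^K \Bigl( C + \sum_{i=1}^{C} s_i \Bigr)
  \;\le\; 2^K (N + N) = O(N \cdot 2^K)
```

This uses C ≤ N and the fact that the component sizes sum to at most N; scanning the K colors of each component adds only O(N*K), which is dominated.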
SOLUTIONS
Setter's Solution
#pragma GCC optimize("Ofast")
#include <bits/stdc++.h>
using namespace std;
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/rope>
using namespace __gnu_pbds;
using namespace __gnu_cxx;
#ifndef rd
#define trace(...)
//#define endl '\n'
#endif
#define pb push_back
#define fi first
#define se second
#define int long long
typedef long long ll;
typedef long double f80;
#define double long double
#define pii pair<int,int>
#define pll pair<ll,ll>
#define sz(x) ((long long)x.size())
#define fr(a,b,c) for(int a=b; a<=(c); a++)
#define rep(a,b,c) for(int a=b; a<(c); a++)
#define trav(a,x) for(auto &a:x)
#define all(con) con.begin(),con.end()
const ll infl=0x3f3f3f3f3f3f3f3fLL;
const int infi=0x3f3f3f3f;
const int mod=998244353;
//const int mod=1000000007;
typedef vector<int> vi;
typedef vector<ll> vl;
typedef tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update> oset;
auto clk=clock();
mt19937_64 rang(chrono::high_resolution_clock::now().time_since_epoch().count());
int rng(int lim) {
    uniform_int_distribution<int> uid(0,lim-1);
    return uid(rang);
}
int powm(int a, int b) {
    int res=1;
    while(b) {
        if(b&1)
            res=(res*a)%mod;
        a=(a*a)%mod;
        b>>=1;
    }
    return res;
}
long long readInt(long long l, long long r, char endd) {
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true) {
        char g=getchar();
        if(g=='-') {
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g&&g<='9') {
            x*=10;
            x+=g-'0';
            if(cnt==0) {
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd) {
            if(is_neg) {
                x=-x;
            }
            assert(l<=x&&x<=r);
            return x;
        } else {
            cout<<ll(g)<<" "<<g<<endl;
            assert(false);
        }
    }
}
string readString(int l, int r, char endd) {
    string ret="";
    int cnt=0;
    while(true) {
        char g=getchar();
        assert(g!=-1);
        if(g==endd) {
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt&&cnt<=r);
    return ret;
}
long long readIntSp(long long l, long long r) {
    return readInt(l,r,' ');
}
long long readIntLn(long long l, long long r) {
    return readInt(l,r,'\n');
}
string readStringLn(int l, int r) {
    return readString(l,r,'\n');
}
string readStringSp(int l, int r) {
    return readString(l,r,' ');
}
int sum_n=0;
int subtask_n=100000,subtask_k=11;
vi gra[100005];
bool vis[100005];
int c[100005];
void dfs(vi &cnt, int fr, int at) {
    vis[at]=1;
    cnt[c[at]-1]++;
    for(int i:gra[at])
        if(i!=fr&&c[i])
            dfs(cnt, at,i);
}
int cntr=0;
void tree_check(int fr, int at) {
    cntr++;
    for(int i:gra[at])
        if(i!=fr)
            tree_check(at,i);
}
int dp[1<<11],dp2[1<<11];
void solve() {
    int n=readIntSp(1,subtask_n),k=readIntLn(2,subtask_k);
    // int n,k;
    // cin>>n>>k;
    fr(i,1,n) {
        gra[i].clear();
        vis[i]=0;
    }
    memset(dp,0,sizeof(int)*(1<<k));
    memset(dp2,0,sizeof(int)*(1<<k));
    sum_n+=n;
    assert(sum_n<=100000);
    fr(i,1,n) {
        // cin>>c[i];
        if(i!=n)
            c[i]=readIntSp(0,k);
        else
            c[i]=readIntLn(0,k);
    }
    rep(i,1,n) {
        // int u,v;
        // cin>>u>>v;
        int u=readIntSp(1,n),v=readIntLn(1,n);
        assert(u!=v);
        gra[u].pb(v);
        gra[v].pb(u);
    }
    cntr=0;
    tree_check(1,1);
    assert(cntr==n);
    // cnts[x][j] = number of nodes of color j+1 in the x-th component of non-zero nodes
    vector<vi> cnts;
    fr(i,1,n)
        if(vis[i]==0&&c[i]) {
            cnts.pb(vi(k,0LL));
            dfs(cnts.back(),i,i);
        }
    dp[0]=dp2[0]=1;
    for(auto &i: cnts) {
        // dp2 equals dp here (the "select nothing" case); for every color j+1
        // present in the component, add the "select one node of that color" transitions
        rep(j,0,k)
            if(i[j])
                rep(l,0,1<<k)
                    if((l>>j)&1)
                        dp2[l]=(dp2[l]+dp[l^(1<<j)]*i[j])%mod;
        memcpy(dp,dp2,sizeof(int)*(1<<k));
    }
    cout<<dp[(1<<k)-1]<<endl;
}
signed main() {
    ios_base::sync_with_stdio(0),cin.tie(0);
    srand(chrono::high_resolution_clock::now().time_since_epoch().count());
    cout<<fixed<<setprecision(7);
    cerr<<100<<endl;
    int t=readIntLn(1,10);
    cerr<<t<<endl;
    // int t;
    // cin>>t;
    fr(i,1,t)
        solve();
    assert(getchar()==EOF);
#ifdef rd
    cerr<<endl<<endl<<endl<<"Time Elapsed: "<<((double)(clock()-clk))/CLOCKS_PER_SEC<<endl;
#endif
}
Tester's Solution
#pragma GCC optimize("Ofast")
#include <bits/stdc++.h>
using namespace std;
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/rope>
using namespace __gnu_pbds;
using namespace __gnu_cxx;
#ifndef rd
#define trace(...)
#define endl '\n'
#endif
#define pb push_back
#define fi first
#define se second
#define int long long
typedef long long ll;
typedef long double f80;
#define double long double
#define pii pair<int,int>
#define pll pair<ll,ll>
#define sz(x) ((long long)x.size())
#define fr(a,b,c) for(int a=b; a<=(c); a++)
#define rep(a,b,c) for(int a=b; a<(c); a++)
#define trav(a,x) for(auto &a:x)
#define all(con) con.begin(),con.end()
const ll infl=0x3f3f3f3f3f3f3f3fLL;
const int infi=0x3f3f3f3f;
const int mod=998244353;
//const int mod=1000000007;
typedef vector<int> vi;
typedef vector<ll> vl;
typedef tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update> oset;
auto clk=clock();
mt19937_64 rang(chrono::high_resolution_clock::now().time_since_epoch().count());
int rng(int lim) {
    uniform_int_distribution<int> uid(0,lim-1);
    return uid(rang);
}
int powm(int a, int b) {
    int res=1;
    while(b) {
        if(b&1)
            res=(res*a)%mod;
        a=(a*a)%mod;
        b>>=1;
    }
    return res;
}
long long readInt(long long l, long long r, char endd) {
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true) {
        char g=getchar();
        if(g=='-') {
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g&&g<='9') {
            x*=10;
            x+=g-'0';
            if(cnt==0) {
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd) {
            if(is_neg) {
                x=-x;
            }
            assert(l<=x&&x<=r);
            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l, int r, char endd) {
    string ret="";
    int cnt=0;
    while(true) {
        char g=getchar();
        assert(g!=-1);
        if(g==endd) {
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt&&cnt<=r);
    return ret;
}
long long readIntSp(long long l, long long r) {
    return readInt(l,r,' ');
}
long long readIntLn(long long l, long long r) {
    return readInt(l,r,'\n');
}
string readStringLn(int l, int r) {
    return readString(l,r,'\n');
}
string readStringSp(int l, int r) {
    return readString(l,r,' ');
}
int sum_n=0;
int subtask_n=100000,subtask_k=11;
vi gra[100005];
bool vis[100005];
int c[100005];
void dfs(vi &cnt, int fr, int at) {
    vis[at]=1;
    cnt[c[at]-1]++;
    for(int i:gra[at])
        if(i!=fr&&c[i])
            dfs(cnt, at,i);
}
int cntr=0;
void tree_check(int fr, int at) {
    cntr++;
    for(int i:gra[at])
        if(i!=fr)
            tree_check(at,i);
}
int dp[1<<11],dp2[1<<11];
void solve() {
    int n=readIntSp(1,subtask_n),k=readIntLn(2,subtask_k);
    fr(i,1,n) {
        gra[i].clear();
        vis[i]=0;
    }
    memset(dp,0,sizeof(int)*(1<<k));
    memset(dp2,0,sizeof(int)*(1<<k));
    sum_n+=n;
    assert(sum_n<=100000);
    fr(i,1,n) {
        if(i!=n)
            c[i]=readIntSp(0,k);
        else
            c[i]=readIntLn(0,k);
    }
    rep(i,1,n) {
        int u=readIntSp(1,n),v=readIntLn(1,n);
        assert(u!=v);
        gra[u].pb(v);
        gra[v].pb(u);
    }
    cntr=0;
    tree_check(1,1);
    assert(cntr==n);
    vector<vi> cnts;
    fr(i,1,n)
        if(vis[i]==0&&c[i]) {
            cnts.pb(vi(k,0LL));
            dfs(cnts.back(),i,i);
        }
    dp[0]=dp2[0]=1;
    for(auto &i: cnts) {
        rep(j,0,k)
            if(i[j])
                rep(l,0,1<<k)
                    if((l>>j)&1)
                        dp2[l]=(dp2[l]+dp[l^(1<<j)]*i[j])%mod;
        memcpy(dp,dp2,sizeof(int)*(1<<k));
    }
    cout<<dp[(1<<k)-1]<<endl;
}
signed main() {
    ios_base::sync_with_stdio(0),cin.tie(0);
    srand(chrono::high_resolution_clock::now().time_since_epoch().count());
    cout<<fixed<<setprecision(7);
    int t=readIntLn(1,10);
    // int t;
    // cin>>t;
    fr(i,1,t)
        solve();
    assert(getchar()==EOF);
#ifdef rd
    cerr<<endl<<endl<<endl<<"Time Elapsed: "<<((double)(clock()-clk))/CLOCKS_PER_SEC<<endl;
#endif
}
Editorialist's Solution
import java.util.*;
import java.io.*;
class KTTREE{
    //SOLUTION BEGIN
    long MOD = 998244353;
    void pre() throws Exception{}
    void solve(int TC) throws Exception{
        int N = ni(), K = ni();
        int[] col = new int[N];
        for(int i = 0; i< N; i++)col[i] = ni();
        int[] from = new int[N-1], to = new int[N-1];
        int[] set = java.util.stream.IntStream.range(0, N).toArray();
        for(int i = 0; i< N-1; i++){
            from[i] = ni()-1;
            to[i] = ni()-1;
            if((col[from[i]] == 0) == (col[to[i]] == 0)){
                //merging: joins two non-zero endpoints (two zero endpoints are joined too,
                //but their counts land in column 0, which the DP never reads)
                set[find(set, from[i])] = find(set, to[i]);
            }
        }
        //relabeling
        int cnt = 0;
        int[] map = new int[N];
        for(int i = 0; i< N; i++)if(find(set, i) == i)map[i] = cnt++;
        int[][] count = new int[cnt][1+K];
        for(int i = 0; i< N; i++)count[map[find(set, i)]][col[i]]++;
        //ways[i][mask]: ways to pick the colors in mask (color c = bit c-1) from the first i components
        long[][] ways = new long[1+cnt][1<<K];
        ways[0][0] = 1;
        for(int i = 0; i< cnt; i++){
            for(int mask = 0; mask < 1<<K; mask++)ways[i+1][mask] = ways[i][mask];
            for(int color = 1; color <= K; color++){
                if(count[i][color] > 0){
                    int cur = 1<<(color-1);
                    long way = count[i][color];
                    for(int mask = 0; mask < 1<<K; mask++){
                        if((mask&cur) == 0){
                            ways[i+1][mask|cur] += ways[i][mask]*way%MOD;
                            if(ways[i+1][mask|cur] >= MOD)ways[i+1][mask|cur] -= MOD;
                        }
                    }
                }
            }
        }
        pn(ways[cnt][(1<<K)-1]);
    }
    int find(int[] set, int u){return set[u] = set[u] == u?u:find(set, set[u]);}
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    static boolean multipleTC = true;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        new KTTREE().run();
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}
    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }
        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }
        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }
        String nextLine() throws Exception{
            String str = "";
            try{
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }
            return str;
        }
    }
}
Feel free to share your approach. Suggestions are welcome as always.