PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: Sachin Deb
Testers: Harris Leung
Editorialist: Nishank Suresh
DIFFICULTY:
1898
PREREQUISITES:
Depth-first search
PROBLEM:
You are given a tree T and an integer C. Count the number of ways to color the vertices of T with colors 1, 2, \ldots, C such that no two distinct vertices at distance at most 2 from each other receive the same color.
EXPLANATION:
First, root the tree at some node, say 1. We then color the tree top-down, i.e., from the root toward the leaves, and count how many choices each vertex has at the moment it is colored.
Look at some vertex u. What restriction do we have on its possible choice of colors, in relation to vertices that have been colored already?
Answer
Let p be the parent of u, and let g be the parent of p. For now, assume p and g both exist.
Both p and g have already been colored. Vertex u is at distance 1 from p and at distance 2 from g, so u cannot share a color with either of them. Since p and g are adjacent, their colors are distinct, which leaves C-2 choices.
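Further, u cannot share a color with any sibling v (another child of p) that has already been colored, since u and v are at distance 2. These siblings have pairwise distinct colors, because they are also at distance 2 from one another. Each of them also differs from both p and g. So every already-colored sibling removes exactly one more choice.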
So, suppose s children of p have been colored already. Then, there are C-s-2 choices for the color of u.
The counts change only when p or g is missing. If u is a child of the root (p exists but g does not), u has C-s-1 choices. If u is the root itself (no p, and hence s = 0), it has C choices. Make sure not to forget these cases.
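Once we know the number of choices for each vertex u, the final answer is simply their product. If some vertex has no valid choice left (its count is \leq 0 because C is too small), the product — and hence the answer — is 0.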
Implementing this is relatively simple, and can be done with a single DFS. Checking whether p and g exist is straightforward, and computing s is also trivial since we know how many children of p we have processed already.
TIME COMPLEXITY
\mathcal{O}(N) per test case.
CODE:
Setter's code (C++)
#include <bits/stdc++.h>
using namespace std;
// Short names for the integer / pair types used throughout this solution.
using ll = long long;
using pii = pair<int, int>;
// Whole-container iterator range, e.g. sort(ALL(v)).
#define ALL(a) a.begin(), a.end()
// Disable C-stdio synchronisation and untie the streams for faster I/O.
#define FastIO ios::sync_with_stdio(false); cin.tie(0);cout.tie(0)
// Local-testing helpers: redirect stdin / stdout to files.
#define IN freopen("input.txt","r+",stdin)
#define OUT freopen("output.txt","w+",stdout)
// Debug helpers: print an expression with its source line, or a blank line, to stderr.
#define DBG(a) cerr<< "line "<<__LINE__ <<" : "<< #a <<" --> "<<(a)<<endl
#define NL cerr<<endl
// Debug helper: writes a pair as "{first,second}" (nested pairs recurse).
template <class First, class Second>
ostream &operator<<(ostream &out, const pair<First, Second> &pr)
{
    return out << "{" << pr.first << "," << pr.second << "}";
}
// Computes a^p mod m by binary exponentiation in O(log p) multiplications.
//
// a may be any value (negative or >= m): it is first reduced into [0, m), so
// every intermediate product is below m^2. m must be >= 1 and small enough
// that (m-1)^2 fits in a long long (true for m = 1e9+7). p is expected to be
// >= 0; a negative p is treated like 0 (returns 1 % m) instead of looping forever.
long long bigmod ( long long a, long long p, long long m )
{
    long long base = a % m;
    if (base < 0) base += m;   // C++ '%' keeps the sign of a
    long long res = 1 % m;     // 1 % m makes m == 1 yield 0, as every power is 0 mod 1
    while (p > 0)
    {
        if (p & 1)             // current bit of the exponent is set
            res = res * base % m;
        base = base * base % m;
        p >>= 1;
    }
    return res;
}
// Size of the factorial tables and of the 1-indexed adjacency array.
constexpr int N = 1000001;
// Prime modulus 10^9 + 7 applied to every product.
constexpr ll oo = 1000000007;
vector<int> g[N];   // adjacency lists, indexed by vertex number
ll fact[N];         // fact[i] = i! mod oo (filled by init())
ll inv_fact[N];     // inv_fact[i] = modular inverse of i! (filled by init())
void init()
{
fact[0]=1;
for(ll i=1;i<N;i++)
fact[i]=(fact[i-1]*i)%oo;
inv_fact[N-1] = bigmod(fact[N-1],oo-2,oo);
for(ll i=N-2;i>=0;i--)
inv_fact[i]=(inv_fact[i+1]*(i+1))%oo;
}
// Binomial coefficient C(n, r) mod oo, read from the precomputed factorial
// tables. Returns 0 whenever the coefficient is not a valid choice count
// (n < 0, r < 0 or r > n); rejecting a negative argument before it is used
// as a table index avoids an out-of-bounds read. n must stay below N.
ll ncr(int n,int r)
{
    if (n < 0 || r < 0 || r > n) return 0;
    return fact[n] * inv_fact[r] % oo * inv_fact[n - r] % oo;
}
int n,c;
// Returns, modulo oo, the number of ways to color every proper descendant of
// u (u's parent being p), given that u, p and everything above them are
// already colored.
//
// A vertex v with x children contributes P(c-2, x) = C(c-2, x) * x!, because
// its children need pairwise distinct colors that differ from v and from v's
// parent. These per-vertex factors are independent, so the subtree can be
// walked in any order. An explicit stack is used instead of recursion, so a
// path-shaped tree with ~N vertices cannot overflow the call stack.
ll dfs(int u,int p)
{
    ll ret = 1;
    vector<pair<int, int>> stk;   // pending (vertex, parent) pairs
    stk.emplace_back(u, p);
    while (!stk.empty())
    {
        const int v = stk.back().first;
        const int par = stk.back().second;
        stk.pop_back();
        int x = 0;                // number of children of v
        for (int w : g[v])
        {
            if (w == par) continue;
            ++x;
            stk.emplace_back(w, v);
        }
        ret = ret * ncr(c - 2, x) % oo * fact[x] % oo;
    }
    return ret;
}
// Reads the tree, roots it at vertex 1 and prints the number of valid colorings.
int32_t main()
{
    FastIO;
    cin >> n >> c;
    for (int e = 1; e < n; ++e)
    {
        int a, b;
        cin >> a >> b;
        g[a].push_back(b);
        g[b].push_back(a);
    }
    init();

    // The root may take any of the c colors.
    ll ans = c;
    // Each child subtree of the root contributes an independent factor.
    for (int child : g[1])
        ans = ans * dfs(child, 1) % oo;
    // The root's children need distinct colors different from the root's:
    // an ordered choice of deg(1) colors out of the remaining c-1.
    const int deg = g[1].size();
    ans = ans * ncr(c - 1, deg) % oo * fact[deg] % oo;
    cout << ans << "\n";
}
Tester's code (C++)
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define fi first
#define se second
// Prime modulus for the answer.
const ll mod=1e9+7;
// Capacity of the 1-indexed adjacency array.
const int N=2e6+1;
// n = number of vertices, k = number of available colors.
ll n,k;
// Adjacency lists of the tree, indexed by vertex number.
vector<int>adj[N];
// Running product of per-vertex color choices, modulo mod.
ll ans=1;

// Multiplies into `ans` the number of color choices of every vertex in the
// subtree of `id`. Its parent is `p` (0 meaning "no parent"), and `id` must
// avoid `hp` already-colored vertices.
//
// The j-th child (j = 0, 1, ...) of a vertex u must avoid u, u's parent (if
// any) and the j earlier siblings, i.e. 1 + (u has a parent) + j colors.
// An explicit stack is used instead of recursion, so deep (path-like) trees
// cannot overflow the call stack. Children are pushed in reverse, so they are
// processed in exactly the pre-order a recursive DFS would use.
void dfs(int id,int p,int hp){
    vector<array<int, 3>> stk;   // pending (vertex, parent, forbidden count)
    stk.push_back({id, p, hp});
    while (!stk.empty()) {
        const array<int, 3> top = stk.back();
        stk.pop_back();
        const int u = top[0], par = top[1], forbidden = top[2];
        // No color left for u: no valid coloring exists, the product is 0.
        ans = k > forbidden ? ans * ((k - forbidden) % mod) % mod : 0;
        const size_t mark = stk.size();
        int shp = 1 + (par != 0);
        for (int c : adj[u]) {
            if (c == par) continue;
            stk.push_back({c, u, shp});
            ++shp;
        }
        reverse(stk.begin() + mark, stk.end());
    }
}
void solve(){
cin >> n >> k;
for(int i=1; i<n ;i++){
int u,v;cin >> u >> v;
adj[u].push_back(v);
adj[v].push_back(u);
}
dfs(1,0,0);
cout << ans << '\n';
}
// Entry point: unsynchronised fast I/O, then a single test case.
int main(){
    ios::sync_with_stdio(false);
    cin.tie(0);
    solve();
    return 0;
}
Editorialist's code (C++)
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
// 64-bit integer alias; NOTE(review): not referenced by main() below.
using ll = long long int;
// 64-bit Mersenne Twister seeded from the current time.
// NOTE(review): not used by main() below — looks like template boilerplate.
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());
// Reads a tree with n vertices (1-indexed edges) and the color count c, and
// prints the number of colorings in which vertices at distance <= 2 get
// different colors, modulo 1e9+7.
int main()
{
    ios::sync_with_stdio(false); cin.tie(0);
    int n, c; cin >> n >> c;
    vector<vector<int>> adj(n);
    for (int i = 0; i < n-1; ++i) {
        int u, v; cin >> u >> v;
        adj[--u].push_back(--v);
        adj[v].push_back(u);
    }
    const int mod = 1e9 + 7;

    // Root the tree at vertex 0, which may take any of the c colors
    // (reduced mod `mod` so that the n == 1 answer is printed reduced too).
    long long ans = c % mod;

    // Visit the vertices in BFS order (every parent before its children) with
    // an explicit queue instead of recursion, so a path-shaped tree with many
    // vertices cannot overflow the call stack. Assumes n >= 1.
    vector<int> parent(n, -1);
    vector<int> order;
    order.reserve(n);
    order.push_back(0);
    for (size_t head = 0; head < order.size(); ++head) {
        const int u = order[head];
        // The first child of u must avoid u and, unless u is the root, u's
        // parent; each later child must also avoid every earlier sibling.
        long long poss = c - 2 + (u == 0);
        for (int v : adj[u]) {
            if (v == parent[u]) continue;
            parent[v] = u;
            order.push_back(v);
            // A non-positive count means too few colors: no valid coloring exists.
            ans = poss > 0 ? ans * (poss % mod) % mod : 0;
            --poss;
        }
    }
    cout << ans << '\n';
}