PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: awoholo, able_joy_75
Tester: raysh07
Editorialist: iceknight1093
DIFFICULTY:
Medium
PREREQUISITES:
Dynamic programming, combinatorics
PROBLEM:
You’re given a tree on N vertices, and X colors.
Count the number of ways of coloring the tree such that there exists at least one pair of vertices having the same color and having the maximum distance across all pairs of vertices.
EXPLANATION:
A diameter of a tree is a longest path present in it.
A tree can have many diameters.
Our goal is to count the number of ways of coloring the tree such that at least one diameter has both its endpoints given the same color.
First, let’s find any one diameter of the tree.
This can be done in linear time via the following well-known algorithm:
- Root the tree at some vertex v_1 and run a DFS to find the distances of all vertices from v_1.
Let v_2 be one vertex with maximum distance from v_1.
- Next, root the tree at v_2 and run a DFS to find the distances of all vertices from v_2.
Let v_3 be one vertex with maximum distance from v_2.
- v_2 and v_3 form the endpoints of one diameter.
(The proof of correctness is in the link above.)
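The two passes above can be sketched as follows. This is a minimal illustration, not the reference solution: it uses BFS instead of DFS (either works on an unweighted tree), assumes 0-indexed vertices with v_1 = 0, and the names `bfsDistances` and `diameterEndpoints` are hypothetical.

```cpp
#include <algorithm>
#include <queue>
#include <utility>
#include <vector>

// Distances (in edges) from src to every vertex of an unweighted tree.
std::vector<int> bfsDistances(const std::vector<std::vector<int>>& adj, int src) {
    std::vector<int> dist(adj.size(), -1);
    std::queue<int> q;
    dist[src] = 0;
    q.push(src);
    while (!q.empty()) {
        int u = q.front();
        q.pop();
        for (int v : adj[u]) {
            if (dist[v] == -1) {
                dist[v] = dist[u] + 1;
                q.push(v);
            }
        }
    }
    return dist;
}

// Returns the endpoints (v_2, v_3) of one diameter.
std::pair<int, int> diameterEndpoints(const std::vector<std::vector<int>>& adj) {
    std::vector<int> fromV1 = bfsDistances(adj, 0);  // assumption: v_1 = vertex 0
    int v2 = int(std::max_element(fromV1.begin(), fromV1.end()) - fromV1.begin());
    std::vector<int> fromV2 = bfsDistances(adj, v2);
    int v3 = int(std::max_element(fromV2.begin(), fromV2.end()) - fromV2.begin());
    return {v2, v3};
}
```

The number of edges on the diameter is then `bfsDistances(adj, v2)[v3]`.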
After finding one diameter, we consider two cases depending on the parity of its length.
Let’s consider first the case of the diameter having an odd number of vertices, say 2d+1 vertices (and so 2d edges.)
In this case, there will be some midpoint vertex m of the diameter.
Now, it can be shown that any diameter of the tree will have m as its midpoint: because if two diameters do not have the same midpoint, it becomes possible to stitch together a longer path than these two, contradicting them being diameters.
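One way to locate m, sketched under the assumption that the distances from both diameter endpoints are already available: m is the unique vertex at distance d from each endpoint. The function name and parameters below are hypothetical.

```cpp
#include <vector>

// distA[i], distB[i]: distance from each diameter endpoint to vertex i.
// diameterEdges: number of edges on the diameter, assumed even (= 2d).
// Returns the midpoint vertex, or -1 if none is found.
int findMidpoint(const std::vector<int>& distA,
                 const std::vector<int>& distB,
                 int diameterEdges) {
    int d = diameterEdges / 2;
    for (int i = 0; i < (int)distA.size(); i++) {
        // A vertex at distance d from both endpoints lies on the diameter,
        // because the two distances add up to the full length 2d.
        if (distA[i] == d && distB[i] == d) return i;
    }
    return -1;
}
```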
Now, let’s root the tree at vertex m.
Let the children of m be c_1, \ldots, c_k.
Since any diameter must pass through m (indeed, have m as its midpoint), we can characterize all possible diameters.
Specifically, the path between vertices u and v is a diameter if and only if:
- u and v are leaves,
- u must be in the subtree of some child c_i and v must be in the subtree of c_j, where i \ne j (so that the u-v path passes through m), and
- There are d edges on the m \to u and m \to v paths.
So, for each child c_i of m, let s_i denote the number of leaves present in its subtree that are at distance d from m.
These are the only relevant vertices we need to care about: the color of everything else is arbitrary.
We now have k “groups” of leaves.
Any diameter is formed by taking two leaves from two different groups.
So, a coloring is good if and only if there exist a pair of leaves from different groups that have the same color.
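A minimal sketch of computing the group sizes s_1, \ldots, s_k, assuming the tree is given as an adjacency list and that m and d are already known. The name `groupSizes` is hypothetical. Any vertex at depth d is automatically a leaf, since d is the largest depth possible when rooting at m.

```cpp
#include <array>
#include <vector>

// For each child c_i of m, count the vertices in its subtree at distance d from m.
std::vector<int> groupSizes(const std::vector<std::vector<int>>& adj, int m, int d) {
    std::vector<int> sizes;
    for (int child : adj[m]) {
        int count = 0;
        // Explicit DFS stack of (vertex, parent, depth measured from m).
        std::vector<std::array<int, 3>> stack;
        stack.push_back({child, m, 1});
        while (!stack.empty()) {
            std::array<int, 3> top = stack.back();
            stack.pop_back();
            int u = top[0], parent = top[1], depth = top[2];
            if (depth == d) count++;
            for (int v : adj[u]) {
                if (v != parent) stack.push_back({v, u, depth + 1});
            }
        }
        sizes.push_back(count);
    }
    return sizes;
}
```

Children whose subtrees contain no vertex at depth d get s_i = 0 and do not constrain the coloring.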
This is a bit hard to count, so we’ll instead count the opposite: the number of ways to assign colors such that no two leaves from different groups have the same color.
This can be done using dynamic programming.
First, we define an auxiliary array ways, where ways[x][p] denotes the number of ways of assigning exactly p distinct colors to a single group of x leaves such that every color is used.
Here, we assume that the colors are \{1, 2, \dots, p\}, and color i is used only after colors 1, 2, \ldots, i-1 have all been used.
This can be computed via the following recursion:
ways[x][p] = ways[x-1][p-1] + p \cdot ways[x-1][p]
This holds because the last vertex has two options: either it gets its own distinct color (in which case it must receive color p, and the remaining x-1 vertices need p-1 colors), or it doesn’t (in which case the other vertices already use all p colors, and there are p choices for the last one).
The base cases are ways[x][1] = 1 and ways[x][0] = 0 for all x \ge 1, along with ways[0][0] = 1 and ways[0][p] = 0 for p \ge 1.
This can be precomputed in quadratic time.
In fact, these are really just the Stirling numbers of the second kind.
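A short sketch of the precomputation, following the recursion and base cases described above. The modulus 998244353 is the one used in the tester's code; the function name `buildWays` is hypothetical.

```cpp
#include <vector>

const long long MOD = 998244353;  // modulus taken from the tester's code

// ways[x][p] for 0 <= p <= x <= maxX (Stirling numbers of the second kind).
std::vector<std::vector<long long>> buildWays(int maxX) {
    std::vector<std::vector<long long>> ways(
        maxX + 1, std::vector<long long>(maxX + 1, 0));
    ways[0][0] = 1;  // base case: no leaves, no colors
    for (int x = 1; x <= maxX; x++) {
        for (int p = 1; p <= x; p++) {
            // last vertex opens color p, or reuses one of the p colors
            ways[x][p] = (ways[x - 1][p - 1] + p * ways[x - 1][p]) % MOD;
        }
    }
    return ways;
}
```

For instance, `buildWays(5)[4][2]` is 7 and `buildWays(5)[5][3]` is 25, matching the known values S(4,2) and S(5,3).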
Now, let’s define function dp(i, p) to be the number of ways of assigning colors to the leaves in the first i groups, such that they use exactly p colors in total amongst themselves and no two leaves from different groups share a color.
To compute this, let’s look at the leaves in the i-th group.
There are s_i of them.
Suppose we decide that they must receive y distinct colors.
Then,
- The previous groups must use p-y colors.
There are, by definition, dp(i-1, p-y) ways to do this.
- For the current group, there are X - (p-y) colors remaining. We choose y among them, which can be done in \binom{X - p + y}{y} ways.
- There are ways[s_i][y] possible colorings for the leaves in this group.
- However, recall that when computing ways, we fixed an order for the colors to be used, as in 1, 2, 3, \ldots
Different orders will result in different colorings though, and there are y! possible ways to choose an order of the chosen y colors.
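Putting everything together, we obtain
dp(i, p) = \sum_{y=0}^{\min(p, s_i)} dp(i-1, p-y) \cdot \binom{X - p + y}{y} \cdot ways[s_i][y] \cdot y!
with dp(0, 0) = 1 as the base case.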
There are \mathcal{O}(N^2) states and linear work done in each, so this is \mathcal{O}(N^3) overall…except (after a small optimization) it is not!
Note that we only care about p \le s_1 + s_2 + \ldots + s_i since that’s the largest number of colors we can use.
Restricting to this, the amount of work we do is \mathcal{O}(\sum_{i=1}^k s_i \cdot (\sum_{j \le i} s_j)).
This is bounded by \mathcal{O}((s_1 + \ldots + s_k)^2) and hence \mathcal{O}(N^2) since s_1 + \ldots + s_k \le N.
Thus, as long as we bound p correctly, our true work is only quadratic!
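The group DP with the bounded range of p can be sketched as below. This is one possible arrangement, not the tester's: it pushes transitions forward from q = p - y colors already used, and it folds \binom{X - q}{y} \cdot y! into the falling factorial (X-q)(X-q-1)\cdots(X-q-y+1), which is maintained incrementally so no binomial table is needed. The function name is hypothetical.

```cpp
#include <algorithm>
#include <vector>

const long long MOD = 998244353;  // modulus taken from the tester's code

// Sum over p of dp(k, p): colorings of the grouped leaves in which
// no two leaves from different groups share a color.
long long countBadLeafColorings(const std::vector<int>& sizes, long long X) {
    int total = 0, largest = 0;
    for (int s : sizes) {
        total += s;
        largest = std::max(largest, s);
    }
    // ways[x][p]: Stirling numbers of the second kind.
    std::vector<std::vector<long long>> ways(
        largest + 1, std::vector<long long>(largest + 1, 0));
    ways[0][0] = 1;
    for (int x = 1; x <= largest; x++)
        for (int p = 1; p <= x; p++)
            ways[x][p] = (ways[x - 1][p - 1] + p * ways[x - 1][p]) % MOD;

    std::vector<long long> dp(total + 1, 0);
    dp[0] = 1;
    int used = 0;  // s_1 + ... + s_{i-1}: upper bound on colors used so far
    for (int s : sizes) {
        std::vector<long long> next(total + 1, 0);
        for (int q = 0; q <= used; q++) {
            long long fall = 1;  // (X-q)(X-q-1)...(X-q-y+1) = C(X-q, y) * y!
            for (int y = 0; y <= s; y++) {
                if (y > 0) {
                    long long factor = ((X - q - (y - 1)) % MOD + MOD) % MOD;
                    fall = fall * factor % MOD;
                }
                long long term = dp[q] * fall % MOD * ways[s][y] % MOD;
                next[q + y] = (next[q + y] + term) % MOD;
            }
        }
        used += s;
        dp = next;
    }
    long long bad = 0;
    for (long long v : dp) bad = (bad + v) % MOD;
    return bad;
}
```

As a check by hand: with groups of sizes 2 and 1 and X = 2, the two leaves of the first group must share a color and the third leaf takes the other one, giving 2 colorings.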
Once all this has been computed, the number of “bad” colorings is obtained by just summing up dp(k, p) across all choices of p, and multiplying it by X^{N-S} (where S = s_1 + \ldots + s_k) to account for the free vertices.
Finally, subtract this quantity from X^N to obtain the number of good colorings.
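The final combination step, as a small sketch. It assumes the sum of dp(k, p) has already been computed and is passed in as `badLeafColorings`; both function names are hypothetical, and the modulus is the one from the tester's code.

```cpp
const long long MOD = 998244353;  // modulus taken from the tester's code

// Fast exponentiation modulo MOD.
long long modPow(long long base, long long exp) {
    long long result = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return result;
}

// X^N - badLeafColorings * X^(N - S), reduced to the range [0, MOD).
long long goodColorings(long long N, long long X, long long S,
                        long long badLeafColorings) {
    long long bad = badLeafColorings % MOD * modPow(X, N - S) % MOD;
    return ((modPow(X, N) - bad) % MOD + MOD) % MOD;
}
```

For a path on 3 vertices with X = 2, the two endpoints form groups of size 1 each, so there are 2 bad leaf colorings and S = 2; the result is 2^3 - 2 \cdot 2^1 = 4, which matches coloring both endpoints alike (2 ways) and the middle vertex freely (2 ways).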
That handled the case where the diameter had an odd number of vertices.
Luckily, the even case is not too different!
When the diameter has an even number of vertices, its center is not a single vertex but two vertices connected by an edge.
The same argument holds here: every diameter will pass through this edge, so once again you can root here and do the same type of thing as above.
In fact, by inserting a new vertex into the middle of this central edge, the problem just reduces to the odd version entirely so you don’t even need to do all the work again!
(Just make sure to not count the newly inserted vertex as a free vertex.)
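A minimal sketch of the reduction, assuming an adjacency-list tree and that the two central vertices a and b are already known. The name `subdivideEdge` is hypothetical.

```cpp
#include <algorithm>
#include <vector>

// Replaces the edge a-b by a-mid and mid-b, where mid is a new vertex.
// Returns the index of the new vertex.
int subdivideEdge(std::vector<std::vector<int>>& adj, int a, int b) {
    int mid = (int)adj.size();
    std::replace(adj[a].begin(), adj[a].end(), b, mid);  // a now points to mid
    std::replace(adj[b].begin(), adj[b].end(), a, mid);  // b now points to mid
    adj.push_back({a, b});                               // neighbours of mid
    return mid;
}
```

After this, the odd case applies with root `mid`. The exponents in the final formula should still use the original N, so that the new vertex is not counted as free.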
TIME COMPLEXITY:
\mathcal{O}(N^2) per testcase.
CODE:
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());
const int mod = 998244353;
struct mint{
    int x;
    mint (){ x = 0;}
    mint (int32_t xx){ x = xx % mod; if (x < 0) x += mod;}
    mint (long long xx){ x = xx % mod; if (x < 0) x += mod;}
    int val(){
        return x;
    }
    mint &operator++(){
        x++;
        if (x == mod) x = 0;
        return *this;
    }
    mint &operator--(){
        if (x == 0) x = mod;
        x--;
        return *this;
    }
    mint operator++(int32_t){
        mint result = *this;
        ++*this;
        return result;
    }
    mint operator--(int32_t){
        mint result = *this;
        --*this;
        return result;
    }
    mint& operator+=(const mint &b){
        x += b.x;
        if (x >= mod) x -= mod;
        return *this;
    }
    mint& operator-=(const mint &b){
        x -= b.x;
        if (x < 0) x += mod;
        return *this;
    }
    mint& operator*=(const mint &b){
        long long z = x;
        z *= b.x;
        z %= mod;
        x = (int)z;
        return *this;
    }
    mint operator+() const {
        return *this;
    }
    mint operator-() const {
        return mint() - *this;
    }
    mint operator/=(const mint &b){
        return *this = *this * b.inv();
    }
    mint power(long long n) const {
        mint ok = *this, r = 1;
        while (n){
            if (n & 1){
                r *= ok;
            }
            ok *= ok;
            n >>= 1;
        }
        return r;
    }
    mint inv() const {
        return power(mod - 2);
    }
    friend mint operator+(const mint& a, const mint& b){ return mint(a) += b;}
    friend mint operator-(const mint& a, const mint& b){ return mint(a) -= b;}
    friend mint operator*(const mint& a, const mint& b){ return mint(a) *= b;}
    friend mint operator/(const mint& a, const mint& b){ return mint(a) /= b;}
    friend bool operator==(const mint& a, const mint& b){ return a.x == b.x;}
    friend bool operator!=(const mint& a, const mint& b){ return a.x != b.x;}
    mint power(mint a, long long n){
        return a.power(n);
    }
    friend ostream &operator<<(ostream &os, const mint &m) {
        os << m.x;
        return os;
    }
    explicit operator bool() const {
        return x != 0;
    }
};
// Remember to check MOD
struct factorials{
    int n;
    vector <mint> ff, iff;
    factorials(int nn){
        n = nn;
        ff.resize(n + 1);
        iff.resize(n + 1);
        ff[0] = 1;
        for (int i = 1; i <= n; i++){
            ff[i] = ff[i - 1] * i;
        }
        iff[n] = ff[n].inv();
        for (int i = n - 1; i >= 0; i--){
            iff[i] = iff[i + 1] * (i + 1);
        }
    }
    mint C(int n, int r){
        if (n == r) return mint(1);
        if (n < 0 || r < 0 || r > n) return mint(0);
        return ff[n] * iff[r] * iff[n - r];
    }
    mint P(int n, int r){
        if (n < 0 || r < 0 || r > n) return mint(0);
        return ff[n] * iff[n - r];
    }
    mint solutions(int n, int r){
        // Solutions to x1 + x2 + ... + xn = r, xi >= 0
        return C(n + r - 1, n - 1);
    }
    mint catalan(int n){
        return ff[2 * n] * iff[n] * iff[n + 1];
    }
};
const int PRECOMP = 3e6 + 69;
factorials F(PRECOMP);
// REMEMBER To check MOD and PRECOMP
void Solve()
{
    int n, m; cin >> n >> m;
    vector<vector<int>> adj(n);
    for (int i = 1; i < n; i++){
        int u, v; cin >> u >> v;
        u--; v--;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    // BFS distances from s to every vertex
    auto bfs = [&](int s){
        vector <int> d(n, INF);
        queue <int> q;
        q.push(s);
        d[s] = 0;
        while (!q.empty()){
            int u = q.front(); q.pop();
            for (int v : adj[u]) if (d[v] == INF){
                d[v] = d[u] + 1;
                q.push(v);
            }
        }
        return d;
    };
    // v1 and v2 are the endpoints of one diameter; d1[v2] is its length in edges
    auto d0 = bfs(0);
    int v1 = max_element(d0.begin(), d0.end()) - d0.begin();
    auto d1 = bfs(v1);
    int v2 = max_element(d1.begin(), d1.end()) - d1.begin();
    auto d2 = bfs(v2);
    vector <int> d(n);
    for (int i = 0; i < n; i++) d[i] = max(d1[i], d2[i]);
    // sizes[i] = number of diameter endpoints in the i-th group
    vector <int> sizes;
    if (d1[v2] % 2 == 1){
        // two centres
        int m1 = -1, m2 = -1;
        int g1 = d1[v2] / 2;
        int g2 = g1 + 1;
        for (int i = 0; i < n; i++){
            if (d1[i] + d2[i] == d1[v2]){
                if (d1[i] == g1) m1 = i;
                if (d1[i] == g2) m2 = i;
            }
        }
        int t1 = 0, t2 = 0;
        auto d3 = bfs(m1);
        auto d4 = bfs(m2);
        for (int i = 0; i < n; i++){
            // on m1's side, at distance g1 from m1
            if (d3[i] == d4[i] - 1 && d3[i] == g1){
                t1++;
            }
            // on m2's side, at distance g1 from m2
            if (d4[i] == d3[i] - 1 && d4[i] == g1){
                t2++;
            }
        }
        sizes.push_back(t1);
        sizes.push_back(t2);
    } else {
        // single centre m; one group per neighbour of m
        int m = -1, g = d1[v2] / 2;
        for (int i = 0; i < n; i++){
            if (d1[i] + d2[i] == d1[v2]){
                if (d1[i] == g) m = i;
            }
        }
        auto d3 = bfs(m);
        int c = 0;
        auto dfs = [&](auto self, int u, int par) -> void{
            c += d3[u] == g;
            for (int v : adj[u]) if (v != par){
                self(self, v, u);
            }
        };
        for (int v : adj[m]){
            c = 0;
            dfs(dfs, v, m);
            sizes.push_back(c);
        }
    }
    vector <mint> dp(n + 1);
    dp[0] = 1;
    // gen(s)[i] = colourings of s vertices with i fixed colours, all i used
    auto gen = [&](int s){
        vector <mint> ans(s + 1);
        if (s == 0){
            ans[0] = 1;
            return ans;
        }
        for (int i = 1; i <= s; i++){
            // colour with exactly i colours
            ans[i] = mint(i).power(s);
            for (int j = 1; j < i; j++){
                ans[i] -= ans[j] * F.C(i, j);
            }
        }
        return ans;
    };
    int tot = 0;
    for (int s : sizes){
        // color these s with some distinct colours
        auto arr = gen(s);
        vector <mint> ndp(n + 1);
        for (int i = 0; i <= tot; i++){
            for (int j = 0; j <= s; j++){
                // pick which j of the i + j colours belong to this group
                ndp[i + j] += F.C(i + j, j) * dp[i] * arr[j];
            }
        }
        tot += s;
        dp = ndp;
    }
    // binomial coefficient C(n, r) for large n
    auto choose = [&](int n, int r){
        // r is small
        mint ans = 1;
        for (int i = 1; i <= r; i++){
            ans *= n + 1 - i;
            // ans /= i;
            ans *= F.ff[i - 1];
            ans *= F.iff[i];
        }
        return ans;
    };
    mint bad = 0;
    for (int i = 1; i <= n; i++){
        bad += choose(m, i) * dp[i];
    }
    // remaining n - tot vertices are free
    bad *= mint(m).power(n - tot);
    mint ans = mint(m).power(n) - bad;
    cout << ans << "\n";
}
int32_t main()
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in", "r", stdin);
    // freopen("out", "w", stdout);
    cin >> t;
    for(int i = 1; i <= t; i++)
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n";
    return 0;
}