DUDIST - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: iceknight1093
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

Medium

PREREQUISITES:

Dynamic programming

PROBLEM:

You’re given a tree.
Count the number of ways of coloring the vertices of the tree either red or blue such that:

  • Exactly K vertices are colored red.
  • Each red vertex is at most D distance away from a blue vertex.
  • Each blue vertex is at most D distance away from a red vertex.

EXPLANATION:

The setup of the problem (and the constraints) scream dynamic programming, so naturally that’s what we shall try.

Suppose we run a DFS, and when we’re at a vertex u we want to finish coloring everything in the subtree of u.
Let us figure out what information we need from this coloring.

First, we certainly need the number of vertices colored red in this subtree, since there's a constraint on their overall count.

Next, we need to worry about the distance criterion.
Let’s call a vertex x satisfied if there exists another vertex y within the subtree having opposite color, such that \text{dist}(x, y) \le D.
A vertex that’s not satisfied is called unsatisfied.

Now, after we’ve colored the subtree of u, there might be some unsatisfied vertices, both red and blue.
Clearly, only the deepest unsatisfied vertex of each color matters. For example, if the deepest unsatisfied red vertex becomes satisfied by some blue vertex outside the subtree, then every other unsatisfied red vertex is satisfied by that same blue vertex too, since the path to it also passes through u and is no longer.

We also need to consider the case of unsatisfied vertices from outside the subtree becoming satisfied by vertices from within the subtree.
For this, note that by similar reasoning it's enough to store, for each color, the minimum distance from u to a vertex of that color within its subtree.


Now, everything we did above is indeed enough information to write a dynamic programming solution; where for each vertex we keep its count of red vertices, deepest unsatisfied red/blue vertex, and nearest red/blue vertex.

However, this will be too slow: there are five pieces of information per vertex (the red count, the two deepest unsatisfied distances, and the two nearest distances), each with \mathcal{O}(N) possible values, giving \mathcal{O}(N^5) states per vertex. Multiplied by the number of vertices, that's already \mathcal{O}(N^6), and that's before even looking at the transitions, which add further cost!

Our goal now is hence to reduce the amount of information we need to store.

To begin, one simple observation is that storing both the nearest red and blue vertices is a bit pointless - because one of them is guaranteed to be u itself, whichever color it’s given.
So, we can instead just store the color of u, and then the nearest opposite color vertex.
This removes a factor of N.

Next, let’s look at the unsatisfied vertices.
Suppose u is colored red.
Then, observe that if there are any blue unsatisfied vertices in the subtree of u, the entire coloring must be invalid!
This is because if a blue vertex remains unsatisfied, to become satisfied it must end up being paired with an outside red vertex - but then the path to this outside vertex will pass through u, so the distance to u will be strictly shorter; contradicting the fact that the blue vertex was unsatisfied in the subtree.

Thus, if u is colored red, the only possible unsatisfied vertices in its subtree must also be red.
So, rather than deepest unsatisfied vertex for both colors, we only need to store the deepest unsatisfied vertex that’s of the same color as u.
(Note that we’re already storing the color of u.)
This removes another factor of N.

We're now down to \mathcal{O}(N^3) states per vertex.


However, it turns out we can do better!
The key here is to observe that we never need to care about both the deepest unsatisfied vertex and the nearest opposite-color vertex simultaneously - we only need to care about one of them.
More precisely: if there exists an unsatisfied vertex, we need to care about the deepest one; if there doesn’t exist one, we care about the nearest opposite-color vertex instead.

Proof

Suppose we’re looking at the subtree of u.
Without loss of generality, let u be colored red.

Let v be the deepest unsatisfied red vertex in the subtree of u, and define d_1 = \text{dist}(u, v).
Let w be the nearest blue vertex in the subtree of u, and define d_2 = \text{dist}(u, w).

First, note that since v is unsatisfied, d_1+d_2 \gt D must hold.
Otherwise, v could just be satisfied by w.

Now, v must be satisfied by some blue vertex outside the subtree of u.
Let this outside blue vertex be x, and let d_3 = \text{dist}(u, x).
Then, d_1 + d_3 \le D must hold.

Now, consider some arbitrary red vertex y outside the subtree of u.
Suppose y is satisfied by w.
Then, the path from y to w passes through u, so its length is \text{dist}(y, u) + d_2. We can instead go from y to u and then from u to x (strictly speaking this is a walk, but the actual path from y to x cannot be longer than it), which has length \text{dist}(y, u) + d_3.
Since d_1 + d_3 \le D \lt d_1 + d_2, we have d_3 \lt d_2, so the walk y \to u \to x is strictly shorter than y \to u \to w.

So, if w could satisfy y, so can x.
This means we don’t really need to care about the vertex w at all, as claimed.

With this result, we're down to \mathcal{O}(N^2) states per vertex, for \mathcal{O}(N^3) overall, which is a good place to be in, given the constraints.

Let’s now work on transitions.
We define dp(u, c, x, y) to be the number of ways to assign colors to the subtree of u such that:

  1. u is given the color c
  2. There are exactly x red vertices in the subtree of u.
  3. y denotes the following:
    • if y \ge 0, then there exists an unsatisfied vertex in the subtree; and the deepest such vertex is at a distance of y from u.
      We limit y to being no more than D-1 in this case; any higher would be invalid.
    • if y \lt 0, then all vertices are satisfied; and instead the nearest opposite-color vertex is at a distance of (-y) from u.
      Here -y \le D always holds, since u itself is satisfied; a distance of exactly D is still recorded, even though it's too far to satisfy any vertex outside the subtree.

For the transitions, let’s look at some vertex u, let v be a child of u, and suppose we’re merging v into u (to form a larger tree rooted at u).

Then, suppose we’re considering state (c_1, x_1, y_1) of u and (c_2, x_2, y_2) of v.

  • if c_1\ne c_2, every vertex from both sides is definitely satisfied (given that we limit unsatisfied distances to D-1): an unsatisfied vertex on u's side is within distance D of v, and one on v's side is within distance D of u.
    Further, the nearest opposite-color vertex to u is now v itself, at distance 1.
    So, the new state is just (c_1, x_1+x_2, -1).
  • if c_1 = c_2, a little more work needs to be done for merging.
    However, the casework isn’t too bad:
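    • If both sides are unsatisfied, keep the larger distance (don't forget to add 1 to the distance from v); if the result reaches D, the state is invalid.
    • If one side is unsatisfied and the other is satisfied, check whether the satisfied side's nearest opposite-color vertex is within distance D of the unsatisfied side's deepest vertex (both measured through u). If so, everything becomes satisfied and that nearest distance is kept; otherwise the unsatisfied distance is kept, and the state is invalid if it reaches D.
    • If both are satisfied, keep the smaller one (again, after adding 1 to the distance from v).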

The exact details can be seen in the code below.


We do a total of \mathcal{O}(N^4) work in this merge, since we iterate through \mathcal{O}(N^2) states for each of u and v.
More precisely, it’s \mathcal{O}((N\cdot D)^2) since there are \mathcal{O}(N\cdot D) states; but D can be as large as N so we treat it as N for now.
Paired with doing this N-1 times, the overall complexity is \mathcal{O}(N^5).

Except, it’s not!

Observe that we can bound the count of “number of red vertices in the subtree” by the size of the subtree itself, and not just plain N.

So, the work done during merging is more along the lines of \mathcal{O}(D^2 \cdot \text{sub}[u]\cdot\text{sub}[v]) where \text{sub}[u] denotes the (current) size of the subtree of u.

It can be shown that the sum of \text{sub}[u]\cdot \text{sub}[v] across all edges (u, v) when implemented this way is in fact \mathcal{O}(N^2) - see point 7 of this blog for more detail.

So, the true complexity is really \mathcal{O}(N^2 D^2), which is \mathcal{O}(N^4) in the worst case and fast enough already!


It’s in fact possible to further optimize this algorithm to run in \mathcal{O}(N^3) time with the help of prefix sums.
The key idea is that when merging states (c_1, x_1, y_1) and (c_2, x_2, y_2), they will be merged into (c_1, x_1+x_2, f(y_1, y_2)) for some function f.
Nicely enough, f(y_1, y_2) will only take the values y_1 or y_2\pm 1.
This allows us to fix all the parameters (c_1, c_2, x_1, x_2, y_1) and then compute the contribution of all y_2 in constant time by using some prefix sums.
This was not needed to get AC, though.

TIME COMPLEXITY:

\mathcal{O}(N^2 D^2), i.e. \mathcal{O}(N^4), per testcase.

CODE:

Editorialist's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

/**
 * Integers modulo p, where p is a prime
 * Source: Aeren (modified from tourist?)
 *         Modmul for 64-bit mod from kactl:ModMulLL
 * Works with p < 7.2e18 with x87 80-bit long double, and p < 2^52 ~ 4.5e12 with 64-bit
 */
template<typename T>
struct Z_p{
	using Type = typename decay<decltype(T::value)>::type;
	static vector<Type> MOD_INV;
	constexpr Z_p(): value(){ }
	template<typename U> Z_p(const U &x){ value = normalize(x); }
	template<typename U> static Type normalize(const U &x){
		Type v;
		if(-mod() <= x && x < mod()) v = static_cast<Type>(x);
		else v = static_cast<Type>(x % mod());
		if(v < 0) v += mod();
		return v;
	}
	const Type& operator()() const{ return value; }
	template<typename U> explicit operator U() const{ return static_cast<U>(value); }
	constexpr static Type mod(){ return T::value; }
	Z_p &operator+=(const Z_p &otr){ if((value += otr.value) >= mod()) value -= mod(); return *this; }
	Z_p &operator-=(const Z_p &otr){ if((value -= otr.value) < 0) value += mod(); return *this; }
	template<typename U> Z_p &operator+=(const U &otr){ return *this += Z_p(otr); }
	template<typename U> Z_p &operator-=(const U &otr){ return *this -= Z_p(otr); }
	Z_p &operator++(){ return *this += 1; }
	Z_p &operator--(){ return *this -= 1; }
	Z_p operator++(int){ Z_p result(*this); *this += 1; return result; }
	Z_p operator--(int){ Z_p result(*this); *this -= 1; return result; }
	Z_p operator-() const{ return Z_p(-value); }
	template<typename U = T>
	typename enable_if<is_same<typename Z_p<U>::Type, int>::value, Z_p>::type &operator*=(const Z_p& rhs){
		#ifdef _WIN32
		uint64_t x = static_cast<int64_t>(value) * static_cast<int64_t>(rhs.value);
		uint32_t xh = static_cast<uint32_t>(x >> 32), xl = static_cast<uint32_t>(x), d, m;
		asm(
			"divl %4; \n\t"
			: "=a" (d), "=d" (m)
			: "d" (xh), "a" (xl), "r" (mod())
		);
		value = m;
		#else
		value = normalize(static_cast<int64_t>(value) * static_cast<int64_t>(rhs.value));
		#endif
		return *this;
	}
	template<typename U = T>
	typename enable_if<is_same<typename Z_p<U>::Type, int64_t>::value, Z_p>::type &operator*=(const Z_p &rhs){
		uint64_t ret = static_cast<uint64_t>(value) * static_cast<uint64_t>(rhs.value) - static_cast<uint64_t>(mod()) * static_cast<uint64_t>(1.L / static_cast<uint64_t>(mod()) * static_cast<uint64_t>(value) * static_cast<uint64_t>(rhs.value));
		value = normalize(static_cast<int64_t>(ret + static_cast<uint64_t>(mod()) * (ret < 0) - static_cast<uint64_t>(mod()) * (ret >= static_cast<uint64_t>(mod()))));
		return *this;
	}
	template<typename U = T>
	typename enable_if<!is_integral<typename Z_p<U>::Type>::value, Z_p>::type &operator*=(const Z_p &rhs){
		value = normalize(value * rhs.value);
		return *this;
	}
	template<typename U>
	Z_p &operator^=(U e){
		if(e < 0) *this = 1 / *this, e = -e;
		Z_p res = 1;
		for(; e; *this *= *this, e >>= 1) if(e & 1) res *= *this;
		return *this = res;
	}
	template<typename U>
	Z_p operator^(U e) const{
		return Z_p(*this) ^= e;
	}
	Z_p &operator/=(const Z_p &otr){
		Type a = otr.value, m = mod(), u = 0, v = 1;
		if(a < (int)MOD_INV.size()) return *this *= MOD_INV[a];
		while(a){
			Type t = m / a;
			m -= t * a; swap(a, m);
			u -= t * v; swap(u, v);
		}
		assert(m == 1);
		return *this *= u;
	}
	template<typename U> friend const Z_p<U> &abs(const Z_p<U> &v){ return v; }
	Type value;
};
template<typename T> bool operator==(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value == rhs.value; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator==(const Z_p<T>& lhs, U rhs){ return lhs == Z_p<T>(rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator==(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) == rhs; }
template<typename T> bool operator!=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return !(lhs == rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator!=(const Z_p<T> &lhs, U rhs){ return !(lhs == rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator!=(U lhs, const Z_p<T> &rhs){ return !(lhs == rhs); }
template<typename T> bool operator<(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value < rhs.value; }
template<typename T> bool operator>(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value > rhs.value; }
template<typename T> bool operator<=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value <= rhs.value; }
template<typename T> bool operator>=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value >= rhs.value; }
template<typename T> Z_p<T> operator+(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator+(const Z_p<T> &lhs, U rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator+(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T> Z_p<T> operator-(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator-(const Z_p<T>& lhs, U rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator-(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T> Z_p<T> operator*(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator*(const Z_p<T>& lhs, U rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator*(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T> Z_p<T> operator/(const Z_p<T> &lhs, const Z_p<T> &rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator/(const Z_p<T>& lhs, U rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator/(U lhs, const Z_p<T> &rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T> istream &operator>>(istream &in, Z_p<T> &number){
	typename common_type<typename Z_p<T>::Type, int64_t>::type x;
	in >> x;
	number.value = Z_p<T>::normalize(x);
	return in;
}
template<typename T> ostream &operator<<(ostream &out, const Z_p<T> &number){ return out << number(); }

/*
using ModType = int;
struct VarMod{ static ModType value; };
ModType VarMod::value;
ModType &mod = VarMod::value;
using Zp = Z_p<VarMod>;
*/

// constexpr int mod = 1e9 + 7; // 1000000007
constexpr int mod = (119 << 23) + 1; // 998244353
// constexpr int mod = 1e9 + 9; // 1000000009
using Zp = Z_p<integral_constant<decay<decltype(mod)>::type, mod>>;

template<typename T> vector<typename Z_p<T>::Type> Z_p<T>::MOD_INV;
template<typename T = integral_constant<decay<decltype(mod)>::type, mod>>
void precalc_inverse(int SZ){
	auto &inv = Z_p<T>::MOD_INV;
	if(inv.empty()) inv.assign(2, 1);
	for(; inv.size() <= SZ; ) inv.push_back((mod - 1LL * mod / (int)inv.size() * inv[mod % (int)inv.size()]) % mod);
}

template<typename T>
vector<T> precalc_power(T base, int SZ){
	vector<T> res(SZ + 1, 1);
	for(auto i = 1; i <= SZ; ++ i) res[i] = res[i - 1] * base;
	return res;
}

template<typename T>
vector<T> precalc_factorial(int SZ){
	vector<T> res(SZ + 1, 1); res[0] = 1;
	for(auto i = 1; i <= SZ; ++ i) res[i] = res[i - 1] * i;
	return res;
}


int main()
{
    ios::sync_with_stdio(false); cin.tie(0);
	
    int t; cin >> t;
    while (t--) {
        int n, k, d; cin >> n >> k >> d;
        vector adj(n+1, vector<int>());
        for (int i = 0; i < n-1; ++i) {
            int u, v; cin >> u >> v;
            adj[u].push_back(v);
            adj[v].push_back(u);
        }

        vector dpr(n+1, vector(n+1, vector(2*d + 2, Zp(0))));
        vector dpb(n+1, vector(n+1, vector(2*d + 2, Zp(0))));
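        // index in [0, d-1]: deepest unsatisfied vertex, at distance = index
        // index in [d+1, 2d]: all satisfied, nearest opposite color at distance (index - d)
        // nearest-opposite distances >= d are clamped to index 2d; index d is unused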
        auto get_state = [&] (int p, int q) {
            if (p < d and q < d) {
                if (max(p, q+1) == d) return -1;
                return max(p, q+1);
            }
            if (p < d) {
                if (p+(q-d)+1 <= d) return min(2*d, q+1);
                else return p;
            }
            if (q < d) {
                if ((p-d)+q+1 <= d) return min(2*d, p);
                else if (q == d-1) return -1;
                else return q+1;
            }
            return min({p, q+1, 2*d});
        };

        vector subsz(n+1, 0);
        auto dfs = [&] (const auto &self, int u, int p) -> void {
            dpr[u][1][0] = 1;
            dpb[u][0][0] = 1;
            subsz[u] = 1;

            for (int v : adj[u]) if (v != p) {
                self(self, v, u);

                vector ndpr(subsz[u]+subsz[v]+1, vector(2*d + 2, Zp(0)));
                vector ndpb(subsz[u]+subsz[v]+1, vector(2*d + 2, Zp(0)));
                for (int x = 0; x <= subsz[u]; ++x) for (int p = 0; p <= 2*d; ++p) if (dpr[u][x][p] != 0 or dpb[u][x][p] != 0) {
                    for (int y = 0; y <= subsz[v]; ++y) for (int q = 0; q <= 2*d+1; ++q) {
                        // red blue
                        // no issues, nearest opposite becomes 1
                        ndpr[x+y][d+1] += dpr[u][x][p] * dpb[v][y][q];

                        // blue red is similar
                        ndpb[x+y][d+1] += dpb[u][x][p] * dpr[v][y][q];

                        // red red or blue blue
                        int st = get_state(p, q);
                        if (st != -1) {
                            ndpr[x+y][st] += dpr[u][x][p] * dpr[v][y][q];
                            ndpb[x+y][st] += dpb[u][x][p] * dpb[v][y][q];
                        }
                    }
                }

                swap(dpr[u], ndpr);
                swap(dpb[u], ndpb);
                subsz[u] += subsz[v];
            }
        };
        dfs(dfs, 1, 0);

        Zp ans = 0;
        for (int i = d; i <= 2*d; ++i) ans += dpr[1][k][i] + dpb[1][k][i];
        cout << ans << '\n';
    }
}