PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: gunpoint_88
Tester: apoorv_me
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
DP on trees, basic combinatorics
PROBLEM:
You’re given a tree on N vertices, rooted at 1. Each vertex should be colored either red or blue.
How many colorings exist such that in every subtree, the difference between the number of red and blue vertices is at most 1?
EXPLANATION:
We have a counting problem on a tree, with conditions that every subtree needs to satisfy — which screams at us to try dynamic programming.
Let dp(u, k) denote the number of ways of coloring the subtree of u such that the number of red vertices minus the number of blue vertices is exactly k.
Note that k is allowed to be negative here.
Since we only care about the differences being \leq 1, it’s enough to consider -1 \leq k \leq 1.
Further, observe that dp(u, 1) = dp(u, -1) for every u. This arises from the fact that given a coloring with one extra red vertex, we can flip the colors of all vertices to obtain a coloring with one extra blue vertex; and vice versa.
So, it’s enough to consider the states k = 0 and k = 1.
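The symmetry claim can be sanity-checked on tiny trees by brute force. The sketch below is illustrative only: the function name `brute_force` and the parent-array input format (with the assumption `parent[i] < i`) are hypothetical choices, not part of the problem's input format.

```python
from itertools import product

def brute_force(parent):
    """Count valid colorings of a small rooted tree, grouped by root difference.

    parent[i] is the parent of vertex i (0-indexed), with parent[0] = -1.
    Assumption for this sketch: parent[i] < i for every i > 0.
    Returns a dict mapping k = (#red - #blue) at the root to a count.
    """
    n = len(parent)
    counts = {-1: 0, 0: 0, 1: 0}
    for colors in product((1, -1), repeat=n):  # +1 = red, -1 = blue
        diff = list(colors)
        # Accumulate subtree differences bottom-up; children have larger indices.
        for v in range(n - 1, 0, -1):
            diff[parent[v]] += diff[v]
        # A coloring is valid when every subtree difference is in {-1, 0, 1}.
        if all(abs(d) <= 1 for d in diff):
            counts[diff[0]] += 1
    return counts
```

On any small tree, the counts for k = 1 and k = -1 should come out equal. This runs in exponential time, so it is only useful for checking a faster solution on small inputs.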
We now have 2N states, which is small enough. Let’s attempt to actually compute their values.
Let \text{sub}[u] be either 0 or 1, denoting the parity of the number of vertices in the subtree of u.
This is easily computed with a DFS.
Observe that if \text{sub}[u] = 0, the only possibility is to have a difference of 0, i.e., dp(u, 1) = 0 will hold. This is because the difference between the number of red and blue vertices has the same parity as the subtree size, so it is even, and being at most 1 in absolute value means it can only be 0.
Conversely, if \text{sub}[u] = 1, dp(u, 0) = 0 will be true.
(This also means it’s enough to use N states rather than 2N).
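One possible way to compute the parities is sketched below, using an iterative DFS to avoid recursion limits. The function name `subtree_parity` and the adjacency-list input are assumptions made for illustration.

```python
def subtree_parity(adj, root=0):
    """Return sub[u] = (size of subtree of u) mod 2 for every vertex u.

    adj is an adjacency list of an undirected tree with 0-indexed vertices.
    """
    n = len(adj)
    parent = [-1] * n
    seen = [False] * n
    seen[root] = True
    order = []          # vertices in DFS preorder
    stack = [root]
    while stack:
        u = stack.pop()
        order.append(u)
        for v in adj[u]:
            if not seen[v]:
                seen[v] = True
                parent[v] = u
                stack.append(v)
    sub = [1] * n       # every vertex counts itself
    # In reversed preorder, each vertex is handled after all of its descendants.
    for u in reversed(order):
        if parent[u] != -1:
            sub[parent[u]] ^= sub[u]
    return sub
```

For the path 1-2-3 rooted at 1 (0-indexed as `[[1], [0, 2], [1]]`), the subtree sizes are 3, 2, 1, so the parities are 1, 0, 1.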
Consider a vertex u. Let its children be v_1, v_2, \ldots, v_x and w_1, w_2, \ldots, w_y, where \text{sub}[v_i] = 0 for every i and \text{sub}[w_i] = 1 for every i.
Observe that all the subtrees corresponding to v_i must contain an equal number of red and blue vertices; so they don’t affect the difference at all.
So, any valid arrangement within those subtrees can be chosen, with the total number of ways being the product of all dp(v_i, 0).
Next, let’s look at the w_i, the vertices with odd subtree size. There are two cases.
- Suppose y = 2m+1, i.e., there’s an odd number of odd children.
Then, \text{sub}[u] = 0 (since u itself counts too), so we compute dp(u, 0).
How do we get a difference of 0?
The color of u itself contributes either 1 or -1 to the difference, so the odd children together should contribute -1 or +1, respectively.
That is, of the 2m+1 odd children, either m of them contribute +1 and the rest -1, or vice versa (note that making this choice uniquely fixes the color of u as well).
As noted earlier, dp(w_i, 1) = dp(w_i, -1) for every i, so after choosing which vertices are +1/-1, the contribution is the product of all dp(w_i, 1).
This is independent of the actual choice of m vertices, so the overall contribution is
2\cdot \binom{2m+1}{m} \cdot \prod dp(w_i, 1)
- Suppose y = 2m, i.e., there’s an even number of odd children.
Then, \text{sub}[u] = 1, so we compute dp(u, 1).
How do we get a difference of 1?
This can be done just as the previous case was.
The odd children must give a total difference of either 0 or +2, which then uniquely determines the color of u.
A difference of 0 can be obtained by choosing an equal number of +1's and -1's, so there are \binom{2m}{m} ways.
A difference of 2 requires choosing 2 more +1's than -1's, i.e., m+1 of the former and m-1 of the latter, which can be done in \binom{2m}{m+1} ways.
So, the answer is
(\binom{2m}{m} + \binom{2m}{m+1}) \cdot \prod dp(w_i, 1)
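The two cases above can be sketched as a single per-vertex transition. This is a minimal illustration with exact integers (no modulus); the function name `combine` and its list arguments are hypothetical and do not appear in either reference solution.

```python
from math import comb

def combine(even_child_ways, odd_child_ways):
    """Return (k, ways) for a vertex u, given the counts of its children.

    even_child_ways: list of dp(v_i, 0) for children with even subtree size
    odd_child_ways:  list of dp(w_i, 1) for children with odd subtree size
    k is the single state (0 or 1) that can be nonzero for u.
    """
    prod = 1
    for ways in even_child_ways + odd_child_ways:
        prod *= ways
    y = len(odd_child_ways)
    m = y // 2
    if y % 2 == 1:
        # Odd number of odd children: subtree of u is even, difference must be 0.
        return 0, 2 * comb(y, m) * prod
    # Even number of odd children: subtree of u is odd, count difference +1.
    return 1, (comb(y, m) + comb(y, m + 1)) * prod
```

A leaf gives `combine([], [])`, which evaluates to `(1, 1)`. By Pascal's rule, both multipliers equal \binom{y+1}{\lfloor (y+1)/2 \rfloor}, which appears to be the form the author's code uses in `ncr(c, c/2)` with c = y+1.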
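Binomial coefficients can be computed quickly by precomputing factorials and inverse factorials modulo 10^9+7, after which our solution runs in \mathcal{O}(N).
The final answer is dp(1, 0) + dp(1, 1) + dp(1, -1) = dp(1, 0) + 2\cdot dp(1, 1), since any of the three differences is allowed at the root.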
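The factorial precomputation might look like the following sketch. The helper name `build_binomial` is hypothetical; the modulus 10^9+7 is taken from the reference solutions, and the inverse step assumes the modulus is prime and larger than `limit`.

```python
MOD = 10**9 + 7

def build_binomial(limit, mod=MOD):
    """Precompute factorials up to `limit` and return a function binom(n, r) mod `mod`.

    Assumes `mod` is prime and limit < mod, so every factorial is invertible.
    """
    fac = [1] * (limit + 1)
    for i in range(1, limit + 1):
        fac[i] = fac[i - 1] * i % mod
    inv_fac = [1] * (limit + 1)
    # Fermat's little theorem: a^(mod-2) is the inverse of a modulo a prime.
    inv_fac[limit] = pow(fac[limit], mod - 2, mod)
    for i in range(limit, 0, -1):
        inv_fac[i - 1] = inv_fac[i] * i % mod

    def binom(n, r):
        if r < 0 or r > n:
            return 0
        return fac[n] * inv_fac[r] % mod * inv_fac[n - r] % mod

    return binom
```

After one O(limit) setup with a single modular exponentiation, each call to `binom` costs O(1), which is what keeps the whole DP linear.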
TIME COMPLEXITY:
\mathcal{O}(N) per testcase.
CODE:
Author's code (C++)
#include<bits/stdc++.h>
using namespace std;
using ll=long long;
const ll mod=1e9+7;
#ifdef ANI
#include "D:/DUSTBIN/local_inc.h"
#else
#define dbg(...) 0
#endif
ll fastexp(ll base,ll exp) {
    ll res=1;
    base%=mod;
    while(exp) {
        if(exp&1) res=(res*base)%mod;
        base=(base*base)%mod;
        exp>>=1;
    }
    return res;
}
ll modinv(ll num) {
    return fastexp(num, mod-2);
}
vector<ll> factorial,inv_factorial;
void compute_factorial(ll n=1e6) {
    factorial=inv_factorial=vector<ll>(n+1,1);
    for(ll i=2;i<n+1;i++) {
        factorial[i]=factorial[i-1]*i%mod;
    }
    inv_factorial[n]=modinv(factorial[n]);
    for(ll i=n-1;i>0;i--) {
        inv_factorial[i]=inv_factorial[i+1]*(i+1)%mod;
    }
}
ll ncr(ll n,ll r) {
    if(r>n or r<0 or n<0) return 0;
    return factorial[n]*((inv_factorial[r]*inv_factorial[n-r])%mod)%mod;
}
class Testcase{
public:
    ll n,ans=-1;
    vector<array<ll,2>> edges;
    vector<vector<ll>> adj;
    Testcase(vector<array<ll,2>> edges) {
        // 0 indexed
        this->n=edges.size()+1;
        this->edges=edges;
        this->adj=vector<vector<ll>>(n);
        for(ll i=0;i<n-1;i++) {
            adj[edges[i][0]].push_back(edges[i][1]);
            adj[edges[i][1]].push_back(edges[i][0]);
        }
    }
    Testcase(vector<vector<ll>> adj) {
        this->adj=adj;
        this->n=adj.size();
        for(ll i=0;i<n;i++) {
            for(ll j:adj[i]) {
                if(i<j) {
                    (this->edges).push_back({i,j});
                }
            }
        }
    }
    void write(ofstream &inp, ofstream &out) {
        inp<<n<<"\n";
        for(ll i=0;i<n-1;i++) {
            inp<<edges[i][0]+1<<" "<<edges[i][1]+1<<"\n";
        }
        assert(this->ans != -1);
        out<<(this->ans)<<"\n";
    }
};
ll solution(Testcase &tc) {
    auto e=tc.adj;
    ll n=tc.n;
    vector<ll> dp(n,0);
    function<ll(ll,ll)> dfs=[&](ll cur,ll par)->ll{
        ll p=1,c=1;
        for(ll node:e[cur]) {
            if(node==par) continue;
            c+=dfs(node,cur);
            p=(p*dp[node])%mod;
        }
        dp[cur]=(ncr(c,c/2)*p)%mod;
        return c&1;
    };
    return tc.ans=(1+dfs(0,-1))*dp[0]%mod;
}
void main_() {
    compute_factorial();
    ll t; cin>>t;
    assert(t<=1e4);
    ll nsum=0;
    while(t--) {
        ll n; cin>>n; nsum+=n;
        vector<array<ll,2>> tree;
        for(ll i=0;i<n-1;i++) {
            ll u,v; cin>>u>>v;
            tree.push_back({u-1,v-1});
        }
        // dbg(tree);
        auto tc_=Testcase(tree);
        solution(tc_);
        cout<<tc_.ans<<"\n";
    }
    assert(nsum<=3e5);
}
int main() {
    main_();
}
Editorialist's code (Python)
mod = 10**9 + 7
def dfs(graph, start=0):
    n = len(graph)
    fac = [1]*(n+1)
    for i in range(2, n+1): fac[i] = i*fac[i-1] % mod
    inv = fac[:]
    inv[-1] = pow(fac[-1], mod-2, mod)
    for i in reversed(range(n)): inv[i] = inv[i+1] * (i+1) % mod
    def C(n, r):
        if n < r or r < 0: return 0
        return fac[n] * inv[r] % mod * inv[n-r] % mod
    dp = [[0, 0] for _ in range(n)]
    subsz = [0]*n
    visited, finished = [False] * n, [False] * n
    stack = [start]
    while stack:
        start = stack[-1]
        # push unvisited children into stack
        if not visited[start]:
            visited[start] = True
            for child in graph[start]:
                if not visited[child]:
                    stack.append(child)
        else:
            stack.pop()
            # base case
            subsz[start] = 1
            prod, odds = 1, 0
            # update with finished children
            for child in graph[start]:
                if finished[child]:
                    subsz[start] += subsz[child]
                    prod = prod*dp[child][subsz[child] % 2] % mod
                    odds += subsz[child]%2
            finished[start] = True
            if odds%2 == 0:
                dp[start][1] = prod * (C(odds, odds//2) + C(odds, odds//2 + 1)) % mod
            else:
                dp[start][0] = 2 * prod * C(odds, odds//2) % mod
    # print(dp)
    return (dp[0][0] + 2*dp[0][1])%mod
for _ in range(int(input())):
    n = int(input())
    tree = [ [] for _ in range(n)]
    for i in range(n-1):
        u, v = map(int, input().split())
        tree[u-1].append(v-1)
        tree[v-1].append(u-1)
    print(dfs(tree))