PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Author: Manan Grover
Tester: Danny Mittal
Editorialist: Nishank Suresh
DIFFICULTY:
Easy-medium
PREREQUISITES:
Observation, dynamic programming
PROBLEM:
Given a tree T and a query type Q, you are to color each of its vertices either black or white such that:
If w_i is the number of white vertices on the path from 1 to i and b_i is the number of black vertices on that path, then \sum_{i = 1}^n |b_i - w_i| is as small as possible.
Let B be the number of black vertices and W be the number of white vertices in the tree.
Over all colorings satisfying the condition above,
- If Q = 1, maximize the value of |B - W|
- If Q = 2, minimize the value of |B - W|
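To sanity-check the observations below on small inputs, the statement can be brute-forced directly. This is a minimal sketch and not part of the intended solution. The function name `bruteForce`, the edge-list input format and the limit of roughly n ≤ 20 are assumptions of the sketch. It tries every coloring, keeps only those of minimum cost, and reports the largest and smallest |B - W| among them.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdlib>
#include <utility>
#include <vector>

// Hypothetical brute-force reference (only for small n, say n <= 20).
// Vertices are 1..n and the tree is rooted at 1. Bit (v - 1) of mask set means v is black.
// Returns {max |B - W|, min |B - W|} over all colorings of minimum cost sum_i |b_i - w_i|.
std::pair<int, int> bruteForce(int n, const std::vector<std::pair<int, int>>& edges) {
    std::vector<std::vector<int>> adj(n + 1);
    for (const auto& [u, v] : edges) {
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    // BFS order from vertex 1, so every parent is processed before its children.
    std::vector<int> order{1}, parent(n + 1, 0);
    for (std::size_t i = 0; i < order.size(); ++i) {
        int u = order[i];
        for (int v : adj[u]) {
            if (v == parent[u]) continue;
            parent[v] = u;
            order.push_back(v);
        }
    }
    int bestCost = -1, maxDiff = 0, minDiff = 0;
    for (int mask = 0; mask < (1 << n); ++mask) {
        std::vector<int> d(n + 1, 0);  // d[v] = b_v - w_v; d[0] = 0 serves as the root's "parent"
        int cost = 0, diff = 0;
        for (int v : order) {
            int self = ((mask >> (v - 1)) & 1) ? 1 : -1;
            d[v] = d[parent[v]] + self;
            cost += std::abs(d[v]);
            diff += self;  // running B - W over the whole tree
        }
        int absDiff = std::abs(diff);
        if (bestCost == -1 || cost < bestCost) {
            bestCost = cost;
            maxDiff = minDiff = absDiff;
        } else if (cost == bestCost) {
            maxDiff = std::max(maxDiff, absDiff);
            minDiff = std::min(minDiff, absDiff);
        }
    }
    return {maxDiff, minDiff};
}
```

For example, the tree with edges 1-2, 1-3, 2-4 has min-cost colorings with |B - W| equal to 2 or 0, so the sketch returns {2, 0}.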
QUICK EXPLANATION:
Let the distance of a vertex u from 1 be the number of vertices on the path from 1 to u.
Every vertex at an odd distance from 1 can be colored independently of the others, and that fixes the colors of its children.
With this information:
- The maximum difference between B and W is simply \displaystyle \sum_{v} |c(v) - 1| over all vertices v at odd distance from 1, where c(v) is the number of children of v.
- The minimum difference can be solved by converting it into a subset-sum dynamic programming problem.
EXPLANATION:
Let’s first try to see what a minimum-cost coloring of the tree looks like.
Define the level of a vertex v to be the number of vertices on the path from 1 to v, so vertex 1 has level 1.
We will call a vertex odd if its level is odd, and even otherwise.
Note that any odd vertex v contributes at least 1 to the cost of the tree: b_v + w_v equals the level of v, which is odd, so b_v - w_v is also odd and its absolute value cannot be 0.
Thus, the cost is at minimum the number of odd vertices in the tree.
It is then not too hard to see that the minimum cost will be exactly the number of odd vertices.
Proof
Color each odd vertex white and each even vertex black.
Under this coloring, each even vertex contributes 0 to the cost and each odd vertex contributes 1, proving our claim.
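The coloring from the proof can also be checked mechanically. Below is a minimal sketch. The name `canonicalColoringCost` and the edge-list input are hypothetical. It colors odd-level vertices white and even-level vertices black, then returns the cost of that coloring together with the number of odd vertices; by the argument above the two numbers should always agree.

```cpp
#include <cstddef>
#include <cstdlib>
#include <utility>
#include <vector>

// Hypothetical helper (vertices 1..n, root 1): colors odd-level vertices white and
// even-level vertices black, as in the proof, and returns
// {cost sum_i |b_i - w_i| of that coloring, number of odd-level vertices}.
std::pair<int, int> canonicalColoringCost(int n, const std::vector<std::pair<int, int>>& edges) {
    std::vector<std::vector<int>> adj(n + 1);
    for (const auto& [u, v] : edges) {
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    std::vector<int> level(n + 1, 0), diff(n + 1, 0), order{1};  // diff[v] = b_v - w_v
    level[1] = 1;
    diff[1] = -1;  // vertex 1 is odd, hence white
    int cost = 0, oddCount = 0;
    for (std::size_t i = 0; i < order.size(); ++i) {
        int u = order[i];
        cost += std::abs(diff[u]);
        if (level[u] % 2 == 1) ++oddCount;
        for (int v : adj[u]) {
            if (level[v] != 0) continue;  // skip the already-visited parent
            level[v] = level[u] + 1;
            diff[v] = diff[u] + (level[v] % 2 == 1 ? -1 : +1);  // white on odd levels, black on even
            order.push_back(v);
        }
    }
    return {cost, oddCount};
}
```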
In particular, this tells us that in any minimum cost coloring, every even vertex must contribute 0 to the cost.
Let’s analyze this condition a bit more.
Suppose we arbitrarily choose a color for 1 - say, white.
Then, all the children of 1 have no choice but to be colored black.
However, there is no restriction on vertices at level 3 - each of them can be independently colored black or white, but that color forces the colors of its children.
Then each level 5 vertex can again be colored independently, and its children are forced.
This process continues till each vertex is colored.
Notice that this essentially decomposes the tree into a bunch of pieces, each consisting of an odd vertex and its children.
Each piece can be colored independently: an odd vertex v with c(v) children contributes either 1 white and c(v) black vertices, or 1 black and c(v) white vertices.
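The decomposition only needs c(v) for every odd vertex. Here is a minimal sketch that collects these values. The name `oddChildCounts` and the edge-list input are hypothetical. It uses an iterative BFS instead of a recursive DFS so that deep trees do not overflow the call stack.

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical helper: returns c(v) for every odd-level vertex v (vertex 1 has level 1),
// i.e. one entry per piece "odd vertex + its children". Iterative BFS avoids deep recursion.
std::vector<int> oddChildCounts(int n, const std::vector<std::pair<int, int>>& edges) {
    std::vector<std::vector<int>> adj(n + 1);
    for (const auto& [u, v] : edges) {
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    std::vector<int> level(n + 1, 0), order{1}, counts;
    level[1] = 1;
    for (std::size_t i = 0; i < order.size(); ++i) {
        int u = order[i], children = 0;
        for (int v : adj[u]) {
            if (level[v] != 0) continue;  // in a tree, the parent is the only visited neighbor
            level[v] = level[u] + 1;
            ++children;
            order.push_back(v);
        }
        if (level[u] % 2 == 1) counts.push_back(children);
    }
    return counts;
}
```

Since the pieces partition the vertices, the returned values always satisfy \sum (c(v) + 1) = N.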
Now we move on to solving the problem.
For convenience, let the odd vertices be v_1, v_2, \dots, v_k, and let c(v_i) denote the number of children of v_i.
Maximizing the value (Q = 1)
We want to maximize |B - W|. Note that B, W\geq 0 and B + W = N.
Without loss of generality, let B\geq W. Then B - W is maximized when B is maximized, because that also minimizes W.
So let’s try to color as many vertices black as possible.
Looking at this in the context of our above decomposition, it’s easy to see that:
- If an odd vertex v_i has no children, color it black.
- If it does have children, color it white and its children black.
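This is optimal: a piece whose odd vertex has c children contributes either c - 1 or 1 - c to B - W, so it can never contribute more than |c - 1|, and the rule above achieves exactly |c - 1| for every piece at once.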
The value of this coloring is seen to be
\displaystyle\sum_{i=1}^{k} |c(v_i) - 1|,
which is easily computed with a single DFS.
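The greedy rule can be applied piece by piece. This is a minimal sketch. The name `greedyMaxDifference` is hypothetical, and it assumes the values c(v) for all odd vertices have already been collected (for example with the DFS mentioned above). It counts B and W explicitly, so the result can be compared against the closed form \sum |c(v_i) - 1|.

```cpp
#include <vector>

// Q = 1 sketch: apply the greedy rule to every piece (odd vertex + its c children)
// and return B - W, which equals the sum of |c(v) - 1| over the odd vertices.
long long greedyMaxDifference(const std::vector<int>& childCounts) {
    long long black = 0, white = 0;
    for (int c : childCounts) {
        if (c == 0) {
            black += 1;  // a childless odd vertex is colored black
        } else {
            white += 1;  // the odd vertex is white ...
            black += c;  // ... so all c children are forced to be black
        }
    }
    return black - white;
}
```

For c = 1 both colorings of the piece give 0, so the choice made there does not matter.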
Minimizing the value (Q = 2)
Minimizing |B - W| means we would like to make them as close to each other as possible.
In terms of the decomposition of the tree, we see that we have k pairs \{(c(v_1), 1), (c(v_2), 1), \dots, (c(v_k), 1)\}.
A cost-minimizing coloring chooses exactly one element of each pair to color white, while the other element is colored black.
Thinking of it slightly differently, each pair contributes either c(v_i) - 1 or 1 - c(v_i) to the difference B - W.
So, we have k elements \{c(v_1)-1, c(v_2)-1, \dots, c(v_k)-1\}, and to each of these we assign a multiplier of either +1 or -1, and then sum them.
Since the multipliers are only +1 and -1, we might as well assume that every element in the set is non-negative (negating an element is the same as swapping its multiplier), i.e., we work with C = \{|c(v_1)-1|, |c(v_2)-1|, \dots, |c(v_k)-1|\}.
Let the subset of elements to which we assign +1 be S.
What is the final difference B - W?
- Each element x to which we assign +1 contributes x. The contribution of this part is then \displaystyle\sum_{x\in S} x.
- Each element x to which we assign -1 contributes -x. The contribution of this part is \displaystyle\sum_{x\in C\setminus S} -x.
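Putting them together, we get
\displaystyle B - W = \sum_{x\in S} x - \sum_{x\in C\setminus S} x = \sum_{x\in C} x - 2\sum_{x\in C\setminus S} x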
Note that \displaystyle\sum_{x\in C} x is a constant, independent of our choice of S. Let it be M.
Our goal is to minimize the absolute value of the above expression, which is equivalent to making \displaystyle 2\sum_{x\in C\setminus S} x as close to M as possible.
S can be chosen arbitrarily by us, so we really just want to find a subset of C whose (doubled) sum is as close to M as possible.
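Here is a small illustration. The values are chosen only for the example and are not taken from the problem's tests. Take C = \{2, 1, 1\}, so M = 4:

```latex
\begin{aligned}
C &= \{2, 1, 1\}, \qquad M = 2 + 1 + 1 = 4,\\
C \setminus S = \{2\} &\implies B - W = M - 2\sum_{x\in C\setminus S} x = 4 - 2\cdot 2 = 0,\\
C \setminus S = \{1\} &\implies B - W = 4 - 2\cdot 1 = 2.
\end{aligned}
```

The subset \{2\} has doubled sum exactly M, so the answer for this C would be 0.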
This can be done in \mathcal{O}(k\cdot M) by the classical dynamic programming approach to the subset-sum problem.
However, in our case both k and M can be \mathcal{O}(N), so this is \mathcal{O}(N^2) in the worst case - too slow.
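For reference, the classical \mathcal{O}(k\cdot M) DP can be sketched as follows. The function name is hypothetical, and storing the table as a `std::vector<char>` is a choice of this sketch. It is correct, but as just noted it is too slow in the worst case.

```cpp
#include <algorithm>
#include <cstdlib>
#include <numeric>
#include <vector>

// Plain O(k * M) 0/1 subset-sum baseline. C holds the values |c(v_i) - 1|.
// reachable[s] says whether some subset of C sums to s; the answer is min |M - 2s|.
int minAbsDifferenceQuadratic(const std::vector<int>& C) {
    int M = std::accumulate(C.begin(), C.end(), 0);
    std::vector<char> reachable(M + 1, 0);
    reachable[0] = 1;
    for (int x : C)
        for (int s = M; s >= x; --s)  // iterate downwards so each element is used at most once
            if (reachable[s - x]) reachable[s] = 1;
    int best = M;
    for (int s = 0; s <= M; ++s)
        if (reachable[s]) best = std::min(best, std::abs(M - 2 * s));
    return best;
}
```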
Speeding up the dp
Note that this version of the subset-sum problem is special: all values are non-negative and their sum M is at most N, because |c(v_i) - 1| \leq c(v_i) + 1 and the piece sizes c(v_i) + 1 add up to exactly N.
This means that there are only \mathcal{O}(\sqrt{N}) distinct values to consider.
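Why? Values equal to 0 never change a subset sum, so ignore them. If there are K distinct positive values, their sum is at least 1 + 2 + \dots + K = K(K+1)/2, and it cannot exceed M \leq N, so K \leq \sqrt{2N}.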
We can use this to speed up our solution to \mathcal{O}(N\sqrt{N}) in various ways.
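Both methods below start by compressing C into pairs (x_i, y_i). Here is a minimal sketch using `std::map`. The name `groupValues` is hypothetical. Dropping zeros is a design choice of this sketch, justified by the fact that they never change any subset sum.

```cpp
#include <map>
#include <utility>
#include <vector>

// Compress C into pairs (x_i, y_i): value x_i occurs y_i times, sorted by x_i.
// Zero values are dropped because they never change any subset sum.
// Distinct positive values sum to at most M <= N, so at most about sqrt(2N) pairs remain.
std::vector<std::pair<int, int>> groupValues(const std::vector<int>& C) {
    std::map<int, int> freq;
    for (int x : C)
        if (x > 0) ++freq[x];
    return std::vector<std::pair<int, int>>(freq.begin(), freq.end());
}
```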
Method 1 (Setter and Tester)
Suppose we have pairs (x_i, y_i), denoting that x_i appears y_i times in the set. Let there be K pairs in total. As noted above, K \leq \sqrt{2N}.
Define dp(i, s) to be the minimum number of copies of x_i needed to obtain a sum of s using the first i pairs (and -1 if it isn’t possible to achieve this sum at all).
Then, with dp(0, 0) = 0 and dp(0, s) = -1 for s > 0, we have the following transitions:
- If dp(i-1, s) \neq -1, dp(i, s) = 0
- Else, if dp(i, s-x_i) \neq -1 and dp(i, s-x_i) < y_i, dp(i, s) = dp(i, s-x_i) + 1
- Else, dp(i, s) = -1
Once this is computed, iterate through all values s such that dp(K, s) \neq -1 and take the minimum value of |M - 2s|.
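Method 1 only ever looks at dp(i-1, s) and dp(i, s-x_i), so a single row can be overwritten in place while s increases. The sketch below relies on that observation. Keeping only one row is a choice of this sketch (the setter's code keeps the full 2D table), and the function name is hypothetical. It assumes all pairs have x > 0 and y > 0.

```cpp
#include <algorithm>
#include <cstdlib>
#include <utility>
#include <vector>

// Method 1 with one rolling row (assumption: pairs (x, y) with x > 0 and y > 0).
// Before processing pair i, used[s] != -1 means s is reachable with the first i-1 pairs;
// afterwards used[s] is the fewest copies of x_i needed to reach s with the first i pairs (or -1).
int minAbsDifferenceBounded(const std::vector<std::pair<int, int>>& pairs) {
    int M = 0;
    for (auto [x, y] : pairs) M += x * y;
    std::vector<int> used(M + 1, -1);
    used[0] = 0;  // dp(0, 0) = 0
    for (auto [x, y] : pairs) {
        for (int s = 0; s <= M; ++s) {  // ascending: used[s - x] already belongs to row i
            if (used[s] != -1) {
                used[s] = 0;  // reachable without any copy of x
            } else if (s >= x && used[s - x] != -1 && used[s - x] < y) {
                used[s] = used[s - x] + 1;  // one more copy of x than for s - x
            }
        }
    }
    int best = M;
    for (int s = 0; s <= M; ++s)
        if (used[s] != -1) best = std::min(best, std::abs(M - 2 * s));
    return best;
}
```

Each pair costs \mathcal{O}(M) work and there are at most about \sqrt{2N} pairs, which gives the \mathcal{O}(N\sqrt{N}) bound.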
Method 2 (Editorialist)
Once again, suppose we have K pairs (x_i, y_i). We will convert this to a usual subset-sum problem with not too many elements and then solve that.
Let A be a new array, initially empty.
For each pair (x_i, y_i),
- For each k \geq 1 such that 2^k - 1 \leq y_i, add x_i\cdot 2^{k-1} to A.
- Let k be the smallest integer such that 2^k - 1 > y_i. Add x_i\cdot (y_i - (2^{k-1}-1)) to A.
The intuition here is that adding these values to A allows us to simulate picking x_i anywhere between 0 and y_i times via its binary representation and a little extra.
It’s easy to see that A has \mathcal{O}(\sqrt{N}\log(N)) elements.
Then we run the usual subset-sum dynamic programming on A and we are done.
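The splitting step can be sketched as follows. The names `binarySplit` and `minAbsDifferenceSplit` are hypothetical. The remainder item is skipped when it would be 0, which leaves every subset sum unchanged. After splitting, the usual 0/1 subset-sum runs on the resulting items.

```cpp
#include <algorithm>
#include <cstdlib>
#include <numeric>
#include <utility>
#include <vector>

// Method 2 sketch: split each pair (x, y) into items x*1, x*2, x*4, ... plus one remainder,
// so every count 0..y of copies of x is the sum of some subset of those items.
std::vector<int> binarySplit(const std::vector<std::pair<int, int>>& pairs) {
    std::vector<int> items;
    for (auto [x, y] : pairs) {
        int taken = 0;  // equals 2^k - 1 after pushing x*1, ..., x*2^(k-1)
        for (int p = 1; taken + p <= y; p *= 2) {
            items.push_back(x * p);
            taken += p;
        }
        if (y > taken) items.push_back(x * (y - taken));  // remainder y - (2^k - 1)
    }
    return items;
}

// Ordinary 0/1 subset-sum on the split items; returns min |M - 2s| over reachable sums s.
int minAbsDifferenceSplit(const std::vector<std::pair<int, int>>& pairs) {
    std::vector<int> items = binarySplit(pairs);
    int M = std::accumulate(items.begin(), items.end(), 0);
    std::vector<char> reachable(M + 1, 0);
    reachable[0] = 1;
    for (int item : items)
        for (int s = M; s >= item; --s)  // downwards: each item used at most once
            if (reachable[s - item]) reachable[s] = 1;
    int best = M;
    for (int s = 0; s <= M; ++s)
        if (reachable[s]) best = std::min(best, std::abs(M - 2 * s));
    return best;
}
```

For example, y = 5 copies of x become the items x, 2x, 2x, whose subset sums cover 0, x, \dots, 5x.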
Method 3?
Run the subset-sum dynamic programming using bitsets to obtain a constant-factor speedup over the quadratic solution, which ends up being extremely fast in practice. The editorialist's code below combines this with Method 2.
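Here is a minimal sketch of the bitset variant on the raw values of C. The name `minAbsDifferenceBitset` is hypothetical. The bound `MAX_SUM = 100000` is an assumption based on the tester's validator (the sum of N over all tests is at most 10^5) together with M \leq N. Each element costs one shift-or over the whole bitset, i.e. roughly MAX_SUM / 64 machine-word operations on a 64-bit machine.

```cpp
#include <algorithm>
#include <bitset>
#include <cstdlib>
#include <vector>

// Assumed upper bound on M (M <= N <= 1e5 per the tester's validator).
constexpr int MAX_SUM = 100000;

// Method 3 sketch: 0/1 subset-sum where bit s of `reach` records whether sum s is attainable.
int minAbsDifferenceBitset(const std::vector<int>& C) {
    std::bitset<MAX_SUM + 1> reach;
    reach[0] = 1;
    int M = 0;
    for (int x : C) {
        reach |= reach << x;  // every attainable s also makes s + x attainable; overflow bits drop off
        M += x;
    }
    int best = M;
    for (int s = 0; s <= std::min(M, MAX_SUM); ++s)
        if (reach[s]) best = std::min(best, std::abs(M - 2 * s));
    return best;
}
```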
TIME COMPLEXITY:
\mathcal{O}(N\sqrt{N}) or \mathcal{O}(N\sqrt{N}\log{N}), depending on implementation
CODE:
Setter (C++)
#include <bits/stdc++.h>
using namespace std;
// Collects |children - 1| for every odd-level vertex (the root has level cur = 1).
// Vertices with exactly one child would contribute 0, so they are skipped.
void dfs(int x, int pr, vector<vector<int>> &tr, int cur, vector<int> &v){
    int cnt = 0;
    for(int i = 0; i < (int)tr[x].size(); i++){
        int y = tr[x][i];
        if(y != pr){
            cnt++;
            dfs(y, x, tr, cur + 1, v);
        }
    }
    if(cnt == 1){
        return;
    }
    if(cur % 2){
        v.push_back(abs(cnt - 1));
    }
}
int main(){
    ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL);
    int t;
    cin>>t;
    while(t--){
        int n, k; // k is the query type Q
        cin>>n>>k;
        vector<vector<int>> tr(n + 1);
        for(int i = 0; i < n - 1; i++){
            int u, v;
            cin>>u>>v;
            tr[u].push_back(v);
            tr[v].push_back(u);
        }
        vector<int> v;
        dfs(1, 0, tr, 1, v);
        int ans = 0;
        if(k == 1){
            for(int i = 0; i < (int)v.size(); i++){
                ans += v[i];
            }
        }else if(!v.empty()){ // if v is empty, every piece is balanced and the answer is 0
            // Group equal values: a[i] occurs b[i] times.
            vector<int> a, b;
            int sum = 0;
            map<int, int> mpp;
            for(int i = 0; i < (int)v.size(); i++){
                mpp[v[i]]++;
                sum += v[i];
            }
            for(auto it : mpp){
                a.push_back(it.first);
                b.push_back(it.second);
            }
            int m = a.size();
            // dp[i][j]: fewest copies of a[i] needed to reach sum j with groups 0..i, -1 if unreachable.
            vector<vector<int>> dp(m, vector<int>(sum + 1, -1));
            for(int i = 0; i < m; i++){
                dp[i][0] = 0;
            }
            for(int i = 0; i < sum + 1; i++){
                if(i % a[0] == 0 && i / a[0] <= b[0]){
                    dp[0][i] = i / a[0];
                }
            }
            for(int i = 1; i < m; i++){
                for(int j = 0; j < sum + 1; j++){
                    if(dp[i - 1][j] != -1){
                        dp[i][j] = 0;
                        continue;
                    }
                    if(j >= a[i]){
                        if(dp[i][j - a[i]] != -1 && dp[i][j - a[i]] + 1 <= b[i]){
                            dp[i][j] = dp[i][j - a[i]] + 1;
                            continue;
                        }
                    }
                }
            }
            ans = sum;
            for(int i = 0; i < sum + 1; i++){
                if(dp[m - 1][i] != -1){
                    ans = min(ans, abs(sum - 2 * i));
                }
            }
        }
        cout<<ans<<"\n";
    }
    return 0;
}
Tester (Kotlin)
import java.io.BufferedInputStream
import java.util.*
import kotlin.math.abs
fun main(omkar: Array<String>) {
    val jin = FastScanner()
    var nSum = 0
    repeat(jin.nextInt(1000)) {
        val n = jin.nextInt(100000, false)
        val q = jin.nextInt(1, 2)
        nSum += n
        if (nSum > 100000) {
            throw InvalidInputException("constraint on sum n exceeded")
        }
        val adj = Array(n + 1) { mutableListOf<Int>() }
        repeat(n - 1) {
            val a = jin.nextInt(n, false)
            val b = jin.nextInt(n)
            if (a == b) {
                throw InvalidInputException("edge from $a to itself")
            }
            adj[a].add(b)
            adj[b].add(a)
        }
        val sign = IntArray(n + 1)
        val stack = Stack<Int>()
        sign[1] = 1
        stack.push(1)
        val freq = IntArray(n + 1)
        while (stack.isNotEmpty()) {
            val a = stack.pop()
            var d = -1
            for (b in adj[a]) {
                if (sign[b] == 0) {
                    d++
                    sign[b] = -sign[a]
                    stack.push(b)
                }
            }
            if (sign[a] == 1) {
                d = abs(d)
                freq[d]++
            }
        }
        if ((1..n).any { sign[it] == 0 }) {
            throw InvalidInputException("input does not form a tree, ${(1..n).find { sign[it] == 0 }!!} not reachable from 1")
        }
        val maxValue = (0..n).sumBy { d -> freq[d] * d }
        if (q == 1) {
            println(maxValue)
        } else {
            var dp = BooleanArray(n + 1)
            dp[0] = true
            for (d in 1..n) {
                if (freq[d] != 0) {
                    val f = freq[d]
                    val newDP = BooleanArray(n + 1)
                    for (r in 0 until d) {
                        var amt = 0
                        for (x in r..n step d) {
                            if (dp[x]) {
                                amt++
                            }
                            if (x >= (f + 1) * d && dp[x - ((f + 1) * d)]) {
                                amt--
                            }
                            newDP[x] = amt > 0
                        }
                    }
                    dp = newDP
                }
            }
            val minValue = (0..n).filter { dp[it] }.map { abs((2 * it) - maxValue) }.min()!!
            println(minValue)
        }
    }
    jin.endOfInput()
}
class InvalidInputException(message: String): Exception(message)
class FastScanner {
    private val BS = 1 shl 16
    private val NC = 0.toChar()
    private val buf = ByteArray(BS)
    private var bId = 0
    private var size = 0
    private var c = NC
    private var `in`: BufferedInputStream? = null
    private val validation: Boolean
    constructor(validation: Boolean) {
        this.validation = validation
        `in` = BufferedInputStream(System.`in`, BS)
    }
    constructor() : this(true)
    private val char: Char
        private get() {
            while (bId == size) {
                size = try {
                    `in`!!.read(buf)
                } catch (e: Exception) {
                    return NC
                }
                if (size == -1) return NC
                bId = 0
            }
            return buf[bId++].toChar()
        }
    fun validationFail(message: String) {
        if (validation) {
            throw InvalidInputException(message)
        }
    }
    fun endOfInput() {
        if (char != NC) {
            validationFail("excessive input")
        }
        if (validation) {
            System.err.println("input validated")
        }
    }
    fun nextInt(from: Int, to: Int, endsLine: Boolean = true) = nextLong(from.toLong(), to.toLong(), endsLine).toInt()
    fun nextInt(to: Int, endsLine: Boolean = true) = nextInt(1, to, endsLine)
    fun nextLong(endsLine: Boolean): Long {
        var neg = false
        c = char
        if (c !in '0'..'9' && c != '-' && c != ' ' && c != '\n') {
            validationFail("found character other than digit, negative sign, space, and newline, character code = ${c.toInt()}")
        }
        if (c == '-') {
            neg = true
            c = char
        }
        var res = 0L
        while (c in '0'..'9') {
            res = (res shl 3) + (res shl 1) + (c - '0').toLong()
            c = char
        }
        if (endsLine) {
            if (c != '\n') {
                validationFail("found character other than newline, character code = ${c.toInt()}")
            }
        } else {
            if (c != ' ') {
                validationFail("found character other than space, character code = ${c.toInt()}")
            }
        }
        return if (neg) -res else res
    }
    fun nextLong(from: Long, to: Long, endsLine: Boolean = true): Long {
        val res = nextLong(endsLine)
        if (res !in from..to) {
            validationFail("$res not in range $from..$to")
        }
        return res
    }
    fun nextLong(to: Long, endsLine: Boolean = true) = nextLong(1L, to, endsLine)
}
Editorialist (C++)
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,mmx,avx,avx2")
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());
int main()
{
    ios::sync_with_stdio(0); cin.tie(0);
    bitset<100005> B;
    int t; cin >> t;
    while (t--) {
        int n, q; cin >> n >> q;
        vector<vector<int>> adj(n);
        for (int i = 0; i < n-1; ++i) {
            int u, v; cin >> u >> v;
            adj[--u].push_back(--v);
            adj[v].push_back(u);
        }
        // freq[c] = number of odd-level vertices with exactly c children
        map<int, int> freq;
        auto dfs = [&] (const auto &self, int u, int par, int level) -> void {
            int childct = 0;
            for (int v : adj[u]) {
                if (v == par) continue;
                ++childct;
                self(self, v, u, level^1);
            }
            if (level == 1) ++freq[childct];
        };
        dfs(dfs, 0, 0, 1);
        if (q == 1) {
            int ans = 0;
            for (auto &[x, y] : freq)
                ans += y*abs(x-1);
            cout << ans << '\n';
            continue;
        }
        B.reset();
        B[0] = 1;
        // Binary splitting (Method 2): y copies of |x-1| become |x-1|*1, |x-1|*2, ... plus a remainder
        vector<int> v;
        int M = 0;
        for (auto &[x, y] : freq) {
            int cur = 1, pw = 1;
            M += y*abs(x-1);
            while (cur <= y) {
                v.push_back(abs(x-1)*pw);
                pw *= 2;
                cur += pw;
            }
            cur -= pw;
            v.push_back((y - cur)*abs(x-1));
        }
        // Bitset subset-sum (Method 3) over the split items
        for (int x : v) {
            B |= B << x;
        }
        int ans = n+1;
        for (int i = 0; i <= M; ++i) {
            if (B[i])
                ans = min(ans, abs(M-2*i));
        }
        cout << ans << '\n';
    }
}