TREUPS - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3

Author: Manan Grover
Tester: Danny Mittal
Editorialist: Nishank Suresh

DIFFICULTY:

Easy-medium

PREREQUISITES:

Observation, dynamic programming

PROBLEM:

Given a tree T and a query type Q, you are to color each of its vertices either black or white such that:

If w_i and b_i are the numbers of white and black vertices on the path from 1 to i (both endpoints included), then \sum_{i = 1}^n |b_i - w_i| is as small as possible.

Let B be the number of black vertices and W be the number of white vertices in the tree.
Over all colorings satisfying the condition above,

  • If Q = 1, maximize the value of |B - W|
  • If Q = 2, minimize the value of |B - W|
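For testing on tiny trees, a brute force that follows the statement literally can be handy. The following is a minimal C++ sketch, not the setter's or tester's code. `bruteForce` is a hypothetical name; it assumes 1-indexed vertices with edges given as pairs, and because it enumerates all 2^n colorings it is only meant for n up to about 20. It returns the Q = 1 and Q = 2 answers as a pair.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <cstdlib>
#include <utility>
#include <vector>

// Brute force for tiny trees (exponential in n). Vertex 1 is the root.
// Returns {answer for Q = 1, answer for Q = 2}.
std::pair<int, int> bruteForce(int n, const std::vector<std::pair<int, int>>& edges) {
    assert(n >= 1 && n <= 20);
    std::vector<std::vector<int>> adj(n + 1);
    for (auto [u, v] : edges) {
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    // BFS order from the root, so every parent appears before its children.
    std::vector<int> parent(n + 1, 0), order{1};
    for (std::size_t i = 0; i < order.size(); ++i) {
        int u = order[i];
        for (int w : adj[u])
            if (w != parent[u]) {
                parent[w] = u;
                order.push_back(w);
            }
    }
    int bestCost = INT_MAX, maxDiff = 0, minDiff = INT_MAX;
    for (int mask = 0; mask < (1 << n); ++mask) {
        std::vector<int> d(n + 1, 0);  // d[v] = b_v - w_v; d[0] = 0 acts as the root's parent
        int cost = 0, diff = 0;        // diff accumulates B - W
        for (int v : order) {
            int self = ((mask >> (v - 1)) & 1) ? 1 : -1;  // bit v-1 set: v is black
            d[v] = d[parent[v]] + self;
            cost += std::abs(d[v]);
            diff += self;
        }
        diff = std::abs(diff);  // |B - W|
        if (cost < bestCost) {
            bestCost = cost;
            maxDiff = minDiff = diff;
        } else if (cost == bestCost) {
            maxDiff = std::max(maxDiff, diff);
            minDiff = std::min(minDiff, diff);
        }
    }
    return {maxDiff, minDiff};
}
```

For example, on the single edge 1-2 both answers are 0, since every minimum-cost coloring uses one black and one white vertex.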

QUICK EXPLANATION:

Let the distance of a vertex u from 1 be the number of vertices on the path from 1 to u.
Every vertex at an odd distance from 1 can be colored independently of the others, and that fixes the colors of its children.
With this information:

  • The maximum difference between B and W is simply \displaystyle \sum_{v} |c(v) - 1| over all vertices v at odd distance from 1, where c(v) is the number of children of v.
  • The minimum difference can be solved by converting it into a subset-sum dynamic programming problem.

EXPLANATION:

Let’s first try to see what a minimum-cost coloring of the tree looks like.

Define the level of a vertex v to be the number of vertices on the path from 1 to v.
We will call a vertex odd if its level is odd, and even otherwise.

Note that any odd vertex v contributes at least 1 to the cost of the tree: b_v + w_v equals the level of v, which is odd, so b_v - w_v is also odd and its absolute value cannot be 0.
Thus, the cost is at minimum the number of odd vertices in the tree.
It is then not too hard to see that the minimum cost will be exactly the number of odd vertices.

Proof

Color each odd vertex white and each even vertex black.
Under this coloring, each even vertex contributes 0 to the cost and each odd vertex contributes 1, proving our claim.

In particular, this tells us that in any minimum cost coloring, every even vertex must contribute 0 to the cost.
Let’s analyze this condition a bit more.

Suppose we arbitrarily choose a color for 1 - say, white.
Then all the children of 1 have no choice but to be colored black, since each of them is even and must have b = w.
However, there is no restriction on vertices at level 3: their parents have b = w, so each of them can independently be colored black or white (either color gives it cost exactly 1), and that choice forces the colors of its children.
Then each level 5 vertex can again be colored independently, and its children are forced.
This process continues till each vertex is colored.

Notice that this essentially decomposes the tree into a bunch of pieces, each consisting of an odd vertex and its children.
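Each piece can be colored independently: an odd vertex with c children contributes either 1 white and c black vertices, or 1 black and c white vertices. Since vertex 1 is odd and every even vertex is the child of exactly one odd vertex, these pieces cover every vertex exactly once.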
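The decomposition can be checked directly. The sketch below uses a hypothetical helper `costOfPieceColoring` (1-indexed vertices, edges as pairs): the caller picks the color of each odd vertex freely, every even vertex gets the opposite color of its parent, and the function returns the resulting cost. Under these assumptions the result should always equal the number of odd vertices, matching the proof above.

```cpp
#include <cassert>
#include <cstdlib>
#include <utility>
#include <vector>

// Color each odd vertex v black iff blackOdd(v), give every even vertex the
// opposite color of its (odd) parent, and return sum over i of |b_i - w_i|.
template <class Choice>
int costOfPieceColoring(int n, const std::vector<std::pair<int, int>>& edges, Choice blackOdd) {
    assert(n >= 1);
    std::vector<std::vector<int>> adj(n + 1);
    for (auto [u, v] : edges) {
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    std::vector<int> parent(n + 1, 0), level(n + 1, 0), order{1};
    std::vector<bool> black(n + 1, false);
    level[1] = 1;
    for (std::size_t i = 0; i < order.size(); ++i) {  // BFS: parents before children
        int u = order[i];
        // Odd vertices choose freely; even vertices are forced by their parent.
        black[u] = (level[u] % 2 == 1) ? blackOdd(u) : !black[parent[u]];
        for (int w : adj[u])
            if (w != parent[u]) {
                parent[w] = u;
                level[w] = level[u] + 1;
                order.push_back(w);
            }
    }
    std::vector<int> d(n + 1, 0);  // d[v] = b_v - w_v on the path from 1 to v
    int cost = 0;
    for (int v : order) {
        d[v] = d[parent[v]] + (black[v] ? 1 : -1);
        cost += std::abs(d[v]);
    }
    return cost;
}
```

On the tree with edges 1-2, 2-3, 2-4 the odd vertices are 1, 3 and 4, so every choice function should give cost 3.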
Now we move on to solving the problem.

For convenience, let the odd vertices be v_1, v_2, \dots, v_k, and let c(v_i) denote the number of children of v_i.

Maximizing the value (Q = 1)

We want to maximize |B - W|. Note that B, W\geq 0 and B + W = N.
W.l.o.g. let B\geq W. Then B - W = 2B - N, so it is maximized when B is maximized (which also minimizes W).
So let’s try to color as many vertices black as possible.

Looking at this in the context of our above decomposition, it’s easy to see that:

  • If an odd vertex v_i has no children, color it black.
  • If it does have children, color it white and its children black.

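This coloring is optimal: each piece contributes either c(v_i) - 1 or 1 - c(v_i) to B - W, so by the triangle inequality |B - W| can never exceed \sum_i |c(v_i) - 1|, while this coloring makes every piece contribute exactly |c(v_i) - 1|.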
The value of this coloring is seen to be

\sum_{i = 1}^k |c(v_i) - 1|

which is easily computed with a single DFS.
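A minimal sketch of the Q = 1 computation, assuming 1-indexed vertices and an edge list (`maxDifference` is a hypothetical name). It uses an explicit stack instead of recursion; that is a design choice to avoid deep recursion on path-like trees with N up to 10^5, not something the editorial requires.

```cpp
#include <cassert>
#include <cstdlib>
#include <utility>
#include <vector>

// Q = 1: sum of |c(v) - 1| over all odd vertices v (vertex 1 has level 1).
int maxDifference(int n, const std::vector<std::pair<int, int>>& edges) {
    assert(n >= 1);
    std::vector<std::vector<int>> adj(n + 1);
    for (auto [u, v] : edges) {
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    std::vector<int> parent(n + 1, 0), level(n + 1, 0), stack{1};
    level[1] = 1;
    int answer = 0;
    while (!stack.empty()) {
        int u = stack.back();
        stack.pop_back();
        int children = 0;
        for (int w : adj[u]) {
            if (w == parent[u]) continue;
            parent[w] = u;
            level[w] = level[u] + 1;
            stack.push_back(w);
            ++children;
        }
        if (level[u] % 2 == 1) answer += std::abs(children - 1);
    }
    return answer;
}
```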

Minimizing the value (Q = 2)

Minimizing |B - W| means we would like to make them as close to each other as possible.

In terms of the decomposition of the tree, we see that we have k pairs \{(c(v_1), 1), (c(v_2), 1), \dots, (c(v_k), 1)\}.
A cost-minimizing coloring chooses exactly one element of each pair to color white, while the other element is colored black.
Thinking of it slightly differently, each pair contributes either c(v_i) - 1 or 1 - c(v_i) to the difference B - W.

So, we have k elements \{c(v_1)-1, c(v_2)-1, \dots, c(v_k)-1\}, and to each of these we assign a multiplier of either +1 or -1, and then sum them.
Since the multipliers are only +1 and -1, we might as well assume that every element is non-negative, i.e., we work with the multiset C = \{|c(v_1)-1|, |c(v_2)-1|, \dots, |c(v_k)-1|\}.
Let the subset of elements to which we assign +1 be S.
What is the final difference B - W?

  • Each element x to which we assign +1 contributes x.
    The contribution of this part is then \displaystyle\sum_{x\in S} x
  • Each element to which we assign -1 contributes -x.
    The contribution of this part is \displaystyle\sum_{x\in C\setminus S} -x

Putting them together, we get

\sum_{x\in S} x + \sum_{x\in C\setminus S} (-x) = \left(\sum_{x\in S} x + \sum_{x\in C\setminus S} x\right) - 2\sum_{x\in C\setminus S} x = \sum_{x\in C} x - 2\sum_{x\in C\setminus S} x

Note that \displaystyle\sum_{x\in C} x is a constant, independent of our choice of S. Let it be M.
Our goal is to minimize the absolute value of the above expression, which is equivalent to making \displaystyle 2\sum_{x\in C\setminus S} x as close to M as possible.
S can be chosen arbitrarily by us, so we really just want to find a subset of C whose (doubled) sum is as close to M as possible.
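To make the identity concrete, take hypothetical child counts for three odd vertices:

```latex
c(v_1) = 3,\quad c(v_2) = 0,\quad c(v_3) = 2
\;\Longrightarrow\; C = \{2, 1, 1\},\qquad M = 4

C \setminus S = \{2\},\; S = \{1, 1\}:\qquad
\sum_{x\in S} x - \sum_{x\in C\setminus S} x = (1 + 1) - 2 = 0 = M - 2\cdot 2
```

Here 2\sum_{x\in C\setminus S} x = 4 = M exactly, so |B - W| = 0 is achievable for these counts.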

This can be done in \mathcal{O}(k\cdot M) by the classical dynamic programming approach to the subset-sum problem.
However, in our case both k and M can be \mathcal{O}(N), so this is \mathcal{O}(N^2) in the worst case - too slow.
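For reference, the classical DP looks roughly like this. It is a sketch with a hypothetical name, `minDifferenceQuadratic`; its input is assumed to be the multiset C = \{|c(v_i) - 1|\}, and it returns the minimum of |M - 2s| over all achievable subset sums s (by symmetry it does not matter whether s is the sum of S or of C \setminus S).

```cpp
#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <vector>

// Classical 0/1 subset-sum DP in O(k * M); returns min |M - 2s| over reachable s.
int minDifferenceQuadratic(const std::vector<int>& values) {
    int M = 0;
    for (int x : values) {
        assert(x >= 0);
        M += x;
    }
    std::vector<char> reachable(M + 1, 0);  // reachable[s]: some subset sums to s
    reachable[0] = 1;
    for (int x : values)
        for (int s = M; s >= x; --s)  // descending, so each element is used at most once
            if (reachable[s - x]) reachable[s] = 1;
    int best = M;
    for (int s = 0; s <= M; ++s)
        if (reachable[s]) best = std::min(best, std::abs(M - 2 * s));
    return best;
}
```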

Speeding up the dp

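Note that this version of the subset-sum problem is special: the values are non-negative and their sum M is at most N. Indeed, the pieces partition the N vertices, and a piece with c(v) children has c(v) + 1 vertices but contributes only |c(v) - 1| \leq c(v) + 1 to M.
This means that there are only \mathcal{O}(\sqrt{N}) distinct positive values to consider (zeros never change a sum, so they can be ignored).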

Why?
If there are d distinct positive values, their sum is at least 1 + 2 + \dots + d = \frac{d(d+1)}{2}. Since this cannot exceed M \leq N, we get d(d+1) \leq 2N, so d < \sqrt{2N}.

We can use this to speed up our solution to \mathcal{O}(N\sqrt{N}) in various ways.

Method 1 (Setter and Tester)

Suppose we have pairs (x_i, y_i), denoting that x_i appears y_i times in the multiset. Let there be K pairs in total. As noted above, K \leq \sqrt{2N} once zeros are dropped.
Define dp(i, s) to be the minimum number of copies of x_i needed to obtain a sum of s using the first i pairs, or -1 if this sum cannot be achieved at all. The base case is dp(0, 0) = 0 and dp(0, s) = -1 for s > 0.
Then, we have the following transitions:

  • If dp(i-1, s) \neq -1, dp(i, s) = 0
  • Else, if dp(i, s-x_i) \neq -1 and dp(i, s-x_i) < y_i, dp(i, s) = dp(i, s-x_i) + 1
  • Else, dp(i, s) = -1

Once this is computed, iterate through all values s such that dp(K, s) \neq -1 and find the minimum value of |M - 2s|.
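A possible implementation of Method 1, written as a sketch rather than the setter's exact code. Instead of a K \times (M + 1) table it keeps one row and updates it in place for increasing s, so when dp[s] is read it still holds dp(i-1, s), while dp[s - x_i] already holds dp(i, s - x_i). The name `minDifferenceMethod1` is hypothetical, and zero values are skipped since they never change a sum.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <map>
#include <vector>

// Method 1: bounded subset sum over pairs (x_i, y_i), O(K * M), one rolling row.
int minDifferenceMethod1(const std::vector<int>& values) {
    std::map<int, int> freq;  // x_i -> y_i
    int M = 0;
    for (int x : values) {
        assert(x >= 0);
        ++freq[x];
        M += x;
    }
    // dp[s] == -1: sum s unreachable; otherwise the minimum number of copies
    // of the current x_i used to reach s.
    std::vector<int> dp(M + 1, -1);
    dp[0] = 0;  // dp(0, 0) = 0
    for (auto [x, y] : freq) {
        if (x == 0) continue;  // zeros never change a sum
        for (int s = 0; s <= M; ++s) {
            if (dp[s] != -1) {
                dp[s] = 0;  // reachable with the first i-1 pairs
            } else if (s >= x && dp[s - x] != -1 && dp[s - x] < y) {
                dp[s] = dp[s - x] + 1;  // dp[s - x] already holds row i
            }
        }
    }
    int best = M;
    for (int s = 0; s <= M; ++s)
        if (dp[s] != -1) best = std::min(best, std::abs(M - 2 * s));
    return best;
}
```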

Method 2 (Editorialist)

Once again, suppose we have K pairs (x_i, y_i). We will convert this to a usual subset-sum problem with not too many elements and then solve that.
Let A be a new array, initially empty.
For each pair (x_i, y_i),

  1. For each integer k \geq 1 such that 2^k - 1 \leq y_i, add x_i\cdot 2^{k-1} to A.
  2. Let k be the smallest integer such that 2^k - 1 > y_i. Add x_i\cdot (y_i - (2^{k-1}-1)) to A.

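The intuition here is that these values let us simulate picking x_i anywhere between 0 and y_i times. With r = y_i - (2^{k-1}-1), the powers x_i, 2x_i, \dots, 2^{k-2}x_i alone give every count from 0 to 2^{k-1} - 1 via binary representation, and adding the extra x_i\cdot r covers every count from r to y_i; since r \leq 2^{k-1} - 1, the two ranges together cover 0 to y_i.
Each of the \mathcal{O}(\sqrt{N}) pairs contributes \mathcal{O}(\log N) elements, so A has \mathcal{O}(\sqrt{N}\log(N)) elements.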
Then we run the usual subset-sum dynamic programming on A and we are done.
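A sketch of Method 2 in C++, assuming the same input multiset C as before; `binarySplit` and `minDifferenceMethod2` are hypothetical names. `binarySplit(x, y)` produces the items from steps 1 and 2 above (dropping a zero remainder), and the result is fed to the usual 0/1 subset-sum DP.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <map>
#include <vector>

// Items x*1, x*2, x*4, ... while the running count of copies stays <= y,
// then x * (remaining copies) if any are left.
std::vector<int> binarySplit(int x, int y) {
    assert(y >= 1);
    std::vector<int> items;
    int used = 0;  // copies covered so far; equals pw - 1 at each loop check
    for (int pw = 1; used + pw <= y; pw *= 2) {
        items.push_back(x * pw);
        used += pw;
    }
    if (y > used) items.push_back(x * (y - used));  // the remainder item
    return items;
}

// Method 2: split every (x_i, y_i), then run the ordinary 0/1 subset sum.
int minDifferenceMethod2(const std::vector<int>& values) {
    std::map<int, int> freq;  // x_i -> y_i
    int M = 0;
    for (int x : values) {
        ++freq[x];
        M += x;
    }
    std::vector<int> items;
    for (auto [x, y] : freq)
        if (x > 0)
            for (int item : binarySplit(x, y)) items.push_back(item);
    std::vector<char> reachable(M + 1, 0);
    reachable[0] = 1;
    for (int item : items)
        for (int s = M; s >= item; --s)
            if (reachable[s - item]) reachable[s] = 1;
    int best = M;
    for (int s = 0; s <= M; ++s)
        if (reachable[s]) best = std::min(best, std::abs(M - 2 * s));
    return best;
}
```

For instance, `binarySplit(3, 5)` gives {3, 6, 6}: every count of copies from 0 to 5 is a sum of a subset of {1, 2, 2}.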

Method 3?

Run the plain subset-sum dynamic programming using bitsets. This speeds up the quadratic solution by a constant factor (roughly the machine word size, e.g. 64), which ends up being extremely fast in practice :slight_smile:.
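A sketch of Method 3 using `std::bitset`: `reachable |= reachable << x` is the word-parallel form of the usual descending subset-sum update. The bound `MAXS = 100001` is an assumption based on M \leq N \leq 10^5 (the limit the tester's validator checks), and `minDifferenceBitset` is a hypothetical name. Unlike the editorialist's code, it skips the binary splitting and shifts once per element.

```cpp
#include <algorithm>
#include <bitset>
#include <cassert>
#include <cstdlib>
#include <vector>

// Assumed bound: M <= N <= 1e5, so sums 0..100000 need 100001 bits.
constexpr int MAXS = 100001;

// Method 3: plain 0/1 subset sum, one bitset shift per element.
int minDifferenceBitset(const std::vector<int>& values) {
    int M = 0;
    for (int x : values) M += x;
    assert(M < MAXS);
    std::bitset<MAXS> reachable;  // bit s set: some subset sums to s
    reachable.set(0);
    for (int x : values)
        reachable |= reachable << x;  // bits shifted past MAXS - 1 are dropped
    int best = M;
    for (int s = 0; s <= M; ++s)
        if (reachable[s]) best = std::min(best, std::abs(M - 2 * s));
    return best;
}
```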

TIME COMPLEXITY:

\mathcal{O}(N\sqrt{N}) or \mathcal{O}(N\sqrt{N}\log{N}), depending on implementation

CODE:

Setter (C++)
#include <bits/stdc++.h>
using namespace std;
void dfs(int x, int pr, vector<int> tr[], int cur, vector<int> &v){
  int cnt = 0;
  for(int i = 0; i < tr[x].size(); i++){
    int y = tr[x][i];
    if(y != pr){
      cnt++;
      dfs(y, x, tr, cur + 1, v);
    }
  }
  if(cnt == 1){
    return;
  }
  if(cur % 2){
    v.push_back(abs(cnt - 1));
  }
}
int main(){
  ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL);
  int t;
  cin>>t;
  while(t--){
    int n, k;
    cin>>n>>k;
    vector<int> tr[n + 1];
    for(int i = 0; i < n - 1; i++){
      int u, v;
      cin>>u>>v;
      if(u==v){
        break;
      }
      tr[u].push_back(v);
      tr[v].push_back(u);
    }
    vector<int> v;
    dfs(1, 0, tr, 1, v);
    int ans = 0;
    if(k == 1){
      for(int i = 0; i < v.size(); i++){
        ans += v[i];
      }
    }else if(v.empty()){
      ans = 0; // every odd vertex has exactly one child, so B == W
    }else{
      vector<int> a, b;
      int sum = 0;
      map<int, int> mpp;
      for(int i = 0; i < v.size(); i++){
        mpp[v[i]]++;
        sum += v[i];
      }
      for(auto it : mpp){
        a.push_back(it.first);
        b.push_back(it.second);
      }
      int m = a.size();
      // heap-allocated table: a stack array of m * (sum + 1) ints can overflow
      vector<vector<int>> dp(m, vector<int>(sum + 1, -1));
      for(int i = 0; i < m; i++){
        dp[i][0] = 0;
      }
      for(int i = 0; i < sum + 1; i++){
        if(i % a[0] == 0 && i / a[0] <= b[0]){
          dp[0][i] = i / a[0];
        }
      }
      for(int i = 1; i < m; i++){
        for(int j = 0; j < sum + 1; j++){
          if(dp[i - 1][j] != -1){
            dp[i][j] = 0;
            continue;
          }
          if(j >= a[i]){
            if(dp[i][j - a[i]] != -1 && dp[i][j - a[i]] + 1 <= b[i]){
              dp[i][j] = dp[i][j - a[i]] + 1;
              continue;
            }
          }
        }
      }
      ans = sum;
      for(int i = 0; i < sum + 1; i++){
        if(dp[m - 1][i] != -1){
          ans = min(ans, abs(sum - 2 * i));
        }
      }
    }
    cout<<ans<<"\n";
  }
  return 0;
}
Tester (Kotlin)
import java.io.BufferedInputStream
import java.util.*
import kotlin.math.abs

fun main(omkar: Array<String>) {
    val jin = FastScanner()
    var nSum = 0
    repeat(jin.nextInt(1000)) {
        val n = jin.nextInt(100000, false)
        val q = jin.nextInt(1, 2)
        nSum += n
        if (nSum > 100000) {
            throw InvalidInputException("constraint on sum n exceeded")
        }
        val adj = Array(n + 1) { mutableListOf<Int>() }
        repeat(n - 1) {
            val a = jin.nextInt(n, false)
            val b = jin.nextInt(n)
            if (a == b) {
                throw InvalidInputException("edge from $a to itself")
            }
            adj[a].add(b)
            adj[b].add(a)
        }
        val sign = IntArray(n + 1)
        val stack = Stack<Int>()
        sign[1] = 1
        stack.push(1)
        val freq = IntArray(n + 1)
        while (stack.isNotEmpty()) {
            val a = stack.pop()
            var d = -1
            for (b in adj[a]) {
                if (sign[b] == 0) {
                    d++
                    sign[b] = -sign[a]
                    stack.push(b)
                }
            }
            if (sign[a] == 1) {
                d = abs(d)
                freq[d]++
            }
        }
        if ((1..n).any { sign[it] == 0 }) {
            throw InvalidInputException("input does not form a tree, ${(1..n).find { sign[it] == 0 }!!} not reachable from 1")
        }
        val maxValue = (0..n).sumBy { d -> freq[d] * d }
        if (q == 1) {
            println(maxValue)
        } else {
            var dp = BooleanArray(n + 1)
            dp[0] = true
            for (d in 1..n) {
                if (freq[d] != 0) {
                    val f = freq[d]
                    val newDP = BooleanArray(n + 1)
                    for (r in 0 until d) {
                        var amt = 0
                        for (x in r..n step d) {
                            if (dp[x]) {
                                amt++
                            }
                            if (x >= (f + 1) * d && dp[x - ((f + 1) * d)]) {
                                amt--
                            }
                            newDP[x] = amt > 0
                        }
                    }
                    dp = newDP
                }
            }
            val minValue = (0..n).filter { dp[it] }.map { abs((2 * it) - maxValue) }.min()!!
            println(minValue)
        }
    }
    jin.endOfInput()
}

class InvalidInputException(message: String): Exception(message)

class FastScanner {
    private val BS = 1 shl 16
    private val NC = 0.toChar()
    private val buf = ByteArray(BS)
    private var bId = 0
    private var size = 0
    private var c = NC
    private var `in`: BufferedInputStream? = null
    private val validation: Boolean

    constructor(validation: Boolean) {
        this.validation = validation
        `in` = BufferedInputStream(System.`in`, BS)
    }

    constructor() : this(true)

    private val char: Char
        private get() {
            while (bId == size) {
                size = try {
                    `in`!!.read(buf)
                } catch (e: Exception) {
                    return NC
                }
                if (size == -1) return NC
                bId = 0
            }
            return buf[bId++].toChar()
        }

    fun validationFail(message: String) {
        if (validation) {
            throw InvalidInputException(message)
        }
    }

    fun endOfInput() {
        if (char != NC) {
            validationFail("excessive input")
        }
        if (validation) {
            System.err.println("input validated")
        }
    }

    fun nextInt(from: Int, to: Int, endsLine: Boolean = true) = nextLong(from.toLong(), to.toLong(), endsLine).toInt()

    fun nextInt(to: Int, endsLine: Boolean = true) = nextInt(1, to, endsLine)

    fun nextLong(endsLine: Boolean): Long {
        var neg = false
        c = char
        if (c !in '0'..'9' && c != '-' && c != ' ' && c != '\n') {
            validationFail("found character other than digit, negative sign, space, and newline, character code = ${c.toInt()}")
        }
        if (c == '-') {
            neg = true
            c = char
        }
        var res = 0L
        while (c in '0'..'9') {
            res = (res shl 3) + (res shl 1) + (c - '0').toLong()
            c = char
        }
        if (endsLine) {
            if (c != '\n') {
                validationFail("found character other than newline, character code = ${c.toInt()}")
            }
        } else {
            if (c != ' ') {
                validationFail("found character other than space, character code = ${c.toInt()}")
            }
        }
        return if (neg) -res else res
    }

    fun nextLong(from: Long, to: Long, endsLine: Boolean = true): Long {
        val res = nextLong(endsLine)
        if (res !in from..to) {
            validationFail("$res not in range $from..$to")
        }
        return res
    }

    fun nextLong(to: Long, endsLine: Boolean = true) = nextLong(1L, to, endsLine)
}
Editorialist (C++)
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,mmx,avx,avx2")
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(0); cin.tie(0);

    bitset<100005> B;

    int t; cin >> t;
    while (t--) {
        int n, q; cin >> n >> q;
        vector<vector<int>> adj(n);
        for (int i = 0; i < n-1; ++i) {
            int u, v; cin >> u >> v;
            adj[--u].push_back(--v);
            adj[v].push_back(u);
        }

        map<int, int> freq;

        auto dfs = [&] (const auto &self, int u, int par, int level) -> void {
            int childct = 0;
            for (int v : adj[u]) {
                if (v == par) continue;
                ++childct;
                self(self, v, u, level^1);
            }
            if (level == 1) ++freq[childct];
        };
        dfs(dfs, 0, 0, 1);

        if (q == 1) {
            int ans = 0;
            for (auto &[x, y] : freq)
                ans += y*abs(x-1);
            cout << ans << '\n';
            continue;
        }

        B.reset();
        B[0] = 1;
        vector<int> v;
        int M = 0;
        for (auto &[x, y] : freq) {
            int cur = 1, pw = 1;
            M += y*abs(x-1);
            while (cur <= y) {
                v.push_back(abs(x-1)*pw);
                pw *= 2;
                cur += pw;
            }
            cur -= pw;
            v.push_back((y - cur)*abs(x-1));
        }
        for (int x : v) {
            B |= B << x;
        }
        int ans = n+1;
        for (int i = 0; i <= M; ++i) {
            if (B[i])
                ans = min(ans, abs(M-2*i));
        }
        cout << ans << '\n';
    }
}

There is a typo under proof. I think odd vertices and even vertices were meant to be colored differently. :))


I submitted using subset dp, but WITHOUT any speedup techniques and was expecting TLE, but got AC. How is this possible lol? CodeChef: Practical coding for everyone


Thanks for catching that, edited :slight_smile:

Do mention why you are not getting TLE once you find out.

Perhaps the test cases were not strong and didn’t contain many repeated elements, so there were only about root(N) numbers. In the cases with many repeated elements, the sum was perhaps much less than 1e5. So somehow O(n*sum) worked, but it was not supposed to. :stuck_out_tongue: Anyway, these are just my guesses.

But still it was an amazing problem which taught me a lot! Especially those speedup methods

I didn’t think of the speedup method, instead used bitsets and got AC (0.12 sec).
PS: I also made it a little bit faster by not considering vi if c(vi)==1.

for (auto &[x, y] : freq) {
    int cur = 1, pw = 1;
    M += y*abs(x-1);
    while (cur <= y) {
        v.push_back(abs(x-1)*pw);
        pw *= 2;
        cur += pw;
    }
    cur -= pw;
    v.push_back((y - cur)*abs(x-1));
}

What are we doing here?

Can you explain how the code below ensures that we get the minimum value?
for(int i=0;i<=100000;i++) { if(dp[i]) ans = min(ans,abs(sum-(2*i))); }

This is the optimization technique described in Method 2 at the end, where the pairs are converted into an equivalent knapsack problem with at most \mathcal{O}(\sqrt{N}\log{N}) elements.

freq is a map maintaining the pairs, where freq[x] = y. That loop iterates over all of them and decomposes y as described.
Also note that the code stores the count of c(v) instead of |c(v) - 1|, which is why you see abs(x-1)*pw instead of x*pw (which the editorial describes).


We want two subsets whose difference is as close to 0 as possible.
Note that if you pick some subset S whose sum is x, and the sum of the full set is M, the other subset has sum M - x, and the difference between the sums is |x - (M - x)| = |2x - M|.

The answer is thus the minimum value of |2x - M| across all possible x which are achievable. Knapsack dp finds which x can be achieved, the loop you mentioned computes the minimum value after this is done.

 for (int x : v) {
            B |= B << x;
        }

What is this part doing? Can anyone explain?

Can someone please check my solution?

I am getting just one RE; all the rest are AC.

PLEASEEEEE… Also, I have done this without using DP.

https://www.codechef.com/viewsolution/53346553

Just one RE rest all are AC!!!


You’re getting RE because you allocate too much memory - there’s no need to create 10^6 lists for every graph.

If you fix that, you now get AC on the first subtask and WA on the second (Submission 53348405), which is expected because the first subtask can be solved greedily but the second cannot.


That’s the same thing as setting

B[i] = B[i] | B[i-x]

simultaneously for every i \geq x (each update reads the old value of B[i-x]). Compare that with the standard subset-sum dp and you’ll see where it comes from.

You can also check cppreference to see what that operator does.
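To see the equivalence concretely, here is a small C++ sketch comparing the shift with an explicit loop (`shiftOr` and `loopOr` are hypothetical helper names). The loop has to run i downward: going upward would let B[i] read an already-updated B[i - x], which would allow the same item to be used more than once.

```cpp
#include <bitset>
#include <cassert>
#include <cstddef>

// B |= B << x: every bit i >= x becomes B[i] | B[i - x], computed from the old B.
template <std::size_t N>
std::bitset<N> shiftOr(std::bitset<N> B, std::size_t x) {
    B |= B << x;
    return B;
}

// The same update one bit at a time. Iterating i downward guarantees that
// B[i - x] still holds its old value when B[i] reads it.
template <std::size_t N>
std::bitset<N> loopOr(std::bitset<N> B, std::size_t x) {
    for (std::size_t i = N; i-- > x;)  // i = N - 1, N - 2, ..., x
        if (B[i - x]) B[i] = true;
    return B;
}
```

Starting from bits {0, 2, 5} and x = 3, both give bits {0, 2, 3, 5, 8}.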

Why not use bitset to speed up? CodeChef: Practical coding for everyone