BLOCKTREE - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: mannshah1211
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

Easy

PREREQUISITES:

DFS

PROBLEM:

You’re given a tree with N vertices. You must give exactly K of its vertices the value 1, and the rest the value 0.

Define f(u, v) to be the number of blocks of equal values when looking at the unique simple path from u to v.
Find a way to color the tree such that the maximum of f(u, v) across all (u, v) is minimized.

EXPLANATION:

First, let’s rewrite the definition of f(u, v) slightly.
The number of blocks of equal elements in a binary string can also be seen as the number of pairs of unequal adjacent elements in the string (i.e. substrings that are either 10 or 01), plus one.
For example, if S = 11001010, there are 6 blocks, and 5 pairs of unequal adjacent elements.

This means we can look at the cost of a path as the number of times the value changes when walking along it.
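
As a small illustration of this rewrite, here is a minimal sketch that counts blocks by counting value changes. The helper name `countBlocks` is hypothetical and does not appear in any of the reference solutions below.

```cpp
#include <cstddef>
#include <string>

// Number of maximal blocks of equal characters in s,
// computed as (number of adjacent unequal pairs) + 1.
// An empty string is treated as having 0 blocks.
int countBlocks(const std::string& s) {
    if (s.empty()) return 0;
    int changes = 0;
    for (std::size_t i = 1; i < s.size(); ++i) {
        if (s[i] != s[i - 1]) ++changes;  // a "10" or "01" boundary
    }
    return changes + 1;
}
```

For the string 11001010 from the example above, this counts 5 changes and so reports 6 blocks.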

Let A_u denote the value assigned to vertex u.
We can now see that:

  1. If A_u \neq A_v, then f(u, v) \geq 2 because the value must change at least once when moving from u to v.
  2. If A_u = A_v, then there are two possibilities:
    • If all the values along the path from u to v are equal (to A_u), then f(u, v) = 1 since there are no changes.
    • If not all the values are equal, then f(u, v) \geq 3 since there must be at least two changes (once when changing away from A_u, and then once more to return to it by the end).

Since our aim is to minimize the maximum f(u, v), the last sub-case is the worst one: it forces a value of at least 3.
If we are to avoid this case, we must ensure that every (u, v) pair falls into one of the other two cases. In particular, observe that this means all the vertices given the same value must form a connected subset.

This means we want to essentially split the tree into two connected components, such that one of them has K vertices.
If we can do this, everything in the component of size K can be given the value 1 and everything in the other component can be given the value 0, and we’ll have f(u, v) \leq 2 for all pairs (u, v) (which is clearly the best we can do when both values are present).
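
When experimenting with small trees, it can help to check a coloring directly. The following is a rough brute-force sketch, assuming 1-indexed adjacency lists with an unused slot 0 and a `value` array of the same length; the name `maxBlocks` is made up for illustration. It runs in O(N^2), so it is only meant for testing, not for submission.

```cpp
#include <algorithm>
#include <vector>

// Maximum of f(u, v) over all pairs, for a tree given as 1-indexed
// adjacency lists (adj[0] unused) and a 0/1 value per vertex.
int maxBlocks(const std::vector<std::vector<int>>& adj,
              const std::vector<int>& value) {
    int n = (int)adj.size() - 1;
    int best = 0;
    for (int src = 1; src <= n; ++src) {
        // blocks[v] = f(src, v); 0 means "not visited yet".
        std::vector<int> blocks(n + 1, 0);
        std::vector<int> stack = {src};
        blocks[src] = 1;
        while (!stack.empty()) {
            int u = stack.back();
            stack.pop_back();
            best = std::max(best, blocks[u]);
            for (int v : adj[u]) {
                if (blocks[v] != 0) continue;
                // Stepping across an edge adds a block only if the value changes.
                blocks[v] = blocks[u] + (value[v] != value[u] ? 1 : 0);
                stack.push_back(v);
            }
        }
    }
    return best;
}
```

On a path 1-2-3-4 with vertices 1 and 2 given the value 1, this reports 2, matching the bound above.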

To check for this, we use the fact that we’re working with a tree.
Let’s root the tree at vertex 1. Then, if the tree is split into two components, one of these components will be a subtree of this rooted tree.

So, for each vertex u, let’s look at the subtree of vertex u (when the tree is rooted at 1).

  • If this subtree has size K, we’ve found a component of size K as needed.
  • If this subtree has size N-K, then everything outside it will form a component of size K instead.

To quickly check for these conditions, all we need to do is precompute all subtree sizes (which is a standard usage of DFS).
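
The check above can be sketched roughly as follows. This is not the code of any of the reference solutions: the function names are hypothetical, the tree is assumed to be given as 1-indexed adjacency lists with an unused slot 0, and the DFS is written iteratively (an assumption made here to avoid deep recursion on path-like trees).

```cpp
#include <cassert>
#include <vector>

// Subtree sizes with the tree rooted at vertex 1 (adj[0] is unused).
std::vector<int> subtreeSizes(const std::vector<std::vector<int>>& adj) {
    int n = (int)adj.size() - 1;
    std::vector<int> parent(n + 1, 0), order, size(n + 1, 1);
    std::vector<int> stack = {1};
    while (!stack.empty()) {
        int u = stack.back();
        stack.pop_back();
        order.push_back(u);  // every parent appears before its children
        for (int v : adj[u]) {
            if (v == parent[u]) continue;
            parent[v] = u;
            stack.push_back(v);
        }
    }
    // Walk the order backwards so children are added before their parent is used.
    for (int i = n - 1; i > 0; --i) size[parent[order[i]]] += size[order[i]];
    return size;
}

// Returns some vertex whose subtree has size k or n - k, or 0 if none exists.
int findSplitVertex(const std::vector<std::vector<int>>& adj, int k) {
    int n = (int)adj.size() - 1;
    assert(0 < k && k < n);  // k = 0 and k = n are handled separately
    std::vector<int> size = subtreeSizes(adj);
    for (int u = 1; u <= n; ++u) {
        if (size[u] == k || size[u] == n - k) return u;
    }
    return 0;
}
```

If a vertex u is returned, its subtree receives the value 1 when its size is K, and the value 0 (with everything else receiving 1) when its size is N-K.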


What about when we fail to find a valid split above?
Then we certainly know that the maximum f(u, v) value will be \geq 3.

Taking the same idea as the previous case, let’s anyway try to make all the vertices with value 1 form a single component.
It’s easy to see that doing this will result in f(u, v) \leq 3 for all pairs (u, v), because:

  • If A_u \neq A_v then f(u, v) = 2 always: the vertices with value 1 are connected, so once the path leaves them it can never return.
  • If A_u = A_v = 1 then f(u, v) = 1.
  • If A_u = A_v = 0 then there are two further cases:
    • If u and v lie in the same connected component of vertices with value 0, f(u, v) = 1.
    • Otherwise, f(u, v) = 3, because the path must start at 0, pass through the vertices with value 1, and then return to 0.

Since 3 was a lower bound, and it’s been achieved, it’s also optimal.
Thus, in this case, all we need to do is give the value 1 to some connected set of K vertices.
This is easy to do: for example, start a BFS or DFS from an arbitrary vertex, and halt it once K vertices have been visited.
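
One possible way to write this truncated traversal is sketched below, using BFS. The function name `colorConnected` and its `start` parameter are illustrative assumptions; the tree is again taken as 1-indexed adjacency lists with an unused slot 0.

```cpp
#include <cassert>
#include <queue>
#include <vector>

// Gives the value 1 to the first k vertices visited by a BFS from `start`.
// Each visited vertex other than `start` was discovered from an earlier
// visited vertex, so the vertices with value 1 form a connected set.
std::vector<int> colorConnected(const std::vector<std::vector<int>>& adj,
                                int k, int start = 1) {
    int n = (int)adj.size() - 1;
    assert(0 <= k && k <= n);
    std::vector<int> value(n + 1, 0);
    std::vector<char> seen(n + 1, 0);
    std::queue<int> q;
    if (k > 0) {
        q.push(start);
        seen[start] = 1;
    }
    int taken = 0;
    while (!q.empty() && taken < k) {  // halt once k vertices are colored
        int u = q.front();
        q.pop();
        value[u] = 1;
        ++taken;
        for (int v : adj[u]) {
            if (!seen[v]) {
                seen[v] = 1;
                q.push(v);
            }
        }
    }
    return value;
}
```

Any traversal order works here, since every prefix of a BFS or DFS visiting order is connected.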


There is one final edge case.
All the discussion above relied on moving between vertices of different values. However, if K = 0 or K = N then every vertex must be given the same value and we have no choice (and in these two cases, the maximum f(u, v) will be 1).

TIME COMPLEXITY:

\mathcal{O}(N) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

void Solve() 
{
    int n, k; cin >> n >> k;
    
    vector<vector<int>> adj(n + 1);
    for (int i = 1; i < n; i++){
        int u, v; cin >> u >> v;
        
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    
    if (k == 0 || k == n){
        cout << 1 << "\n";
        for (int i = 1; i <= n; i++){
            cout << (k == n) << " \n"[i == n];
        }
        return;
    }
    
    vector <int> sub(n + 1);
    vector <int> tin(n + 1), tout(n + 1), p(n + 1), deg(n + 1);
    int timer = 0;
    
    auto dfs = [&](auto self, int u, int par) -> void{
        sub[u] = 1;
        tin[u] = ++timer;
        p[u] = par;
        
        for (int v : adj[u]) if (v != par){
            self(self, v, u);
            deg[u]++;
            sub[u] += sub[v];
        }
        tout[u] = timer;
    };
    
    dfs(dfs, 1, 0);
    
    for (int i = 1; i <= n; i++){
        if (sub[i] == k || sub[i] == n - k){
            int xo = (sub[i] == (n - k));
            cout << 2 << "\n";
            for (int j = 1; j <= n; j++){
                if (tin[j] >= tin[i] && tout[j] <= tout[i]){
                    cout << (1 ^ xo) << " \n"[j == n];
                } else {
                    cout << (xo) << " \n"[j == n];
                }
            }
            return;
        }
    }
    
    queue <int> q;
    for (int i = 1; i <= n; i++){
        if (deg[i] == 0){
            q.push(i);
        }
    }
    
    vector <int> ans(n + 1, 0);
    
    while (!q.empty()){
        int u = q.front();
        q.pop();
        k--;
        ans[u] = 1;
        if (k == 0){
            break;
        }
        
        deg[p[u]]--;
        if (deg[p[u]] == 0){
            q.push(p[u]);
        }
    }
    
    cout << 3 << "\n";
    for (int i = 1; i <= n; i++){
        cout << ans[i] << " \n"[i == n];
    }
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

using namespace std;
using namespace __gnu_pbds;

template<typename T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
typedef long long int ll;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;

#define fastio ios_base::sync_with_stdio(false); cin.tie(NULL)
#define pb push_back
#define endl '\n'
#define sz(a) (int)a.size()
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
#define conts continue
#define ceil2(x,y) ((x+y-1)/(y))
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
#define yes cout << "Yes" << endl
#define no cout << "No" << endl

#define rep(i,n) for(int i = 0; i < n; ++i)
#define rep1(i,n) for(int i = 1; i <= n; ++i)
#define rev(i,s,e) for(int i = s; i >= e; --i)
#define trav(i,a) for(auto &i : a)

template<typename T>
void amin(T &a, T b) {
    a = min(a,b);
}

template<typename T>
void amax(T &a, T b) {
    a = max(a,b);
}

#ifdef LOCAL
#include "debug.h"
#else
#define debug(...) 42
#endif

/*



*/

const int MOD = 1e9 + 7;
const int N = 2e3 + 5;
const int inf1 = int(1e9) + 5;
const ll inf2 = ll(1e18) + 5;

vector<ll> adj[N];

void solve(int test_case){
    ll n,k; cin >> n >> k;
    rep1(i,n){
        adj[i].clear();
    }
    rep1(i,n-1){
        ll u,v; cin >> u >> v;
        adj[u].pb(v), adj[v].pb(u);
    }

    if(k == 0 or k == n){
        cout << 1 << endl;
        if(k == 0){
            rep1(i,n) cout << 0 << " ";
        }
        else{
            rep1(i,n) cout << 1 << " ";
        }
        cout << endl;
        return;
    }

    vector<ll> subsiz(n+5);
    vector<ll> ord;
    vector<ll> ans(n+5);
    vector<ll> par(n+5);

    auto dfs1 = [&](ll u, ll p, auto &&dfs1) -> void{
        subsiz[u] = 1;
        par[u] = p;
        ord.pb(u);
        trav(v,adj[u]){
            if(v == p) conts;
            dfs1(v,u,dfs1);
            subsiz[u] += subsiz[v];
        }
    };

    auto col_sub = [&](ll u, ll c, auto &&col_sub) -> void{
        ans[u] = c;
        trav(v,adj[u]){
            if(v == par[u]) conts;
            col_sub(v,c,col_sub);
        }
    };

    dfs1(1,-1,dfs1);
    bool found = false;

    rep1(u,n){
        if(subsiz[u] == k){
            col_sub(u,1,col_sub);
            found = true;
            break;
        }
        else if(subsiz[u] == n-k){
            col_sub(u,1,col_sub);
            rep1(i,n) ans[i] ^= 1;
            found = true;
            break;
        }
    }

    if(found){
        cout << 2 << endl;
        rep1(i,n) cout << ans[i] << " ";
        cout << endl;
    }
    else{
        cout << 3 << endl;
        rep(i,k) ans[ord[i]] = 1;
        rep1(i,n) cout << ans[i] << " ";
        cout << endl;
    }
}

int main()
{
    fastio;

    int t = 1;
    cin >> t;

    rep1(i, t) {
        solve(i);
    }

    return 0;
}
Editorialist's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    int t; cin >> t;
    while (t--) {
        int n, k; cin >> n >> k;
        vector adj(n, vector<int>());
        for (int i = 0; i < n-1; ++i) {
            int u, v; cin >> u >> v;
            adj[--u].push_back(--v);
            adj[v].push_back(u);
        }

        vector<int> subsz(n), order;
        auto dfs = [&] (const auto &self, int u, int p) -> void {
            subsz[u] = 1;
            order.push_back(u);
            for (int v : adj[u]) if (v != p) {
                self(self, v, u);
                subsz[u] += subsz[v];
            }
        };
        dfs(dfs, 0, 0);

        vector<int> ans(n);
        auto go = [&] () {
            if (k == 0) return 1;
            if (k == n) {
                ranges::fill(ans, 1);
                return 1;
            }
            for (int i = 0; i < n; ++i) {
                int u = order[i];
                if (subsz[u] == k or subsz[u] == n-k) {
                    ranges::fill(ans, subsz[u] != k);
                    for (int j = 0; j < subsz[u]; ++j)
                        ans[order[i+j]] = subsz[u] == k;
                    return 2;
                }
            }
            for (int i = 0; i < k; ++i) ans[order[i]] = 1;
            return 3;
        };
        
        cout << go() << '\n';
        for (int x : ans) cout << x << ' ';
        cout << '\n';
    }
}

Code Link
Can anyone check why I am getting Wrong Answer? I did exactly what is mentioned in the editorial.

Line 23 par_ instead of par

https://www.codechef.com/viewsolution/1141729418


Ohhh that was a silly one. THANKS for help