TREETOSTAR - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: everule1
Tester: apoorv_me
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Bitmasks

PROBLEM:

For a tree T, define f(T) to be the minimum number of the following operations needed to convert it into a star tree:

  • Choose two distinct vertices u and v of the tree.
  • Let A be the set of edges (x, y) such that x lies on the u-v path (though x\neq u), and y doesn’t.
  • Choose any subset B of A, and replace each (x, y) in this subset with (u, y).

You’re given a connected graph. Find the minimum possible value of f(T) across all spanning trees T of G.

EXPLANATION:

First, let’s try to find f(T) for a fixed tree T.
Our objective is to turn T into a star tree - equivalently, T should have |T|-1 leaf vertices (where a leaf vertex is a vertex with degree 1).

How does the given operation change the number of leaves of T?

Answer

It can be seen that:

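  • The endpoint v is the only vertex that can become a leaf: if v was originally not a leaf, it can be turned into a leaf by moving every edge of A incident to it over to u (its single path edge stays).
  • If u was originally a leaf, it can be turned into a non-leaf by moving at least one edge to it.
  • For any other vertex, its ‘leafness’ doesn’t change. Every vertex y outside the path keeps its degree, since its edge merely changes its other endpoint. Every vertex strictly between u and v keeps its two path edges, so it remains a non-leaf.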

That is, the number of leaves can change by +1, 0, or -1 depending on our operation.
In particular, if T has L leaves initially, a lower bound on f(T) is |T|-1-L. Each operation increases the number of leaves by at most 1, and we must gain |T|-1-L leaves to reach |T|-1.

This lower bound is in fact tight - that is, it’s always possible to use |T|-1-L operations to turn T into a star.

How?

If T is already a star, the statement is trivially true.

Otherwise, T must contain at least 2 non-leaves.
Pick two of them, x and y, and perform the operation with u = x and v = y.
Move every edge incident to y, except the one on the x-y path, to x.

This turns y into a leaf, and since x was already a non-leaf, the overall number of leaves increases by 1.

Repeat this process while T has \geq 2 non-leaves.


Let’s apply this observation to our original problem.
We have a graph G, and we want to minimize f(T) across all spanning trees T of G.
Since T spans G, |T| = N; meaning we want to minimize N-1-L (L being the number of leaves of T).

This is, of course, equivalent to maximizing L.
So, our objective turns to finding a spanning tree of G with as many leaves as possible.

This can be done by essentially bruteforce.
Fix some subset S of the N vertices that will form the leaves of your spanning tree.
Let S^c denote the other vertices.
S can form the leaves of a spanning tree of G only if:

  • Each vertex of S is adjacent to some vertex of S^c; and
  • The subgraph induced by S^c is itself connected (and hence has a spanning tree of its own).

Note that these conditions only guarantee that some superset of S will form the leaves of the spanning tree; but we’re looking for the maximum-sized subset anyway so that’s ok — after all, no strict superset of the maximum-sized subset can be valid, otherwise it wouldn’t be maximum!

This can be implemented in \mathcal{O}(2^N N^2) or even \mathcal{O}(2^N N), both of which should pass comfortably.

TIME COMPLEXITY:

\mathcal{O}(2^N \cdot N) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;

// 64-bit Mersenne Twister seeded from the steady clock.
// NOTE(review): nothing in this program references RNG; it looks like
// template boilerplate and could likely be removed.
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

// Union-find (disjoint set union) over the indices 0..n, with union by size
// and path compression.
struct ufds{
    vector <int> root, sz;
    int n = 0;

    // (Re)initializes to nn + 1 singleton sets {0}, {1}, ..., {nn}.
    // assign() (rather than resize()) resets every entry, so the object can be
    // re-initialized safely. Index 0 is explicitly made its own root because
    // callers use 0-based vertices.
    void init(int nn){
        n = nn;
        root.assign(n + 1, 0);
        iota(root.begin(), root.end(), 0);
        sz.assign(n + 1, 1);
    }

    // Returns the representative of x's set, compressing the path on the way.
    int find(int x){
        if (root[x] == x) return x;
        return root[x] = find(root[x]);
    }

    // Merges the sets containing x and y, attaching the smaller under the
    // larger. Returns false if they were already in the same set.
    bool unite(int x, int y){
        x = find(x); y = find(y);
        if (x == y) return false;

        if (sz[y] > sz[x]) swap(x, y);
        sz[x] += sz[y];
        root[y] = x;
        return true;
    }
};

void Solve() 
{
    int n; cin >> n;

    vector<string> a(n);
    for (auto &x : a) cin >> x;
    
    ufds uf;
    uf.init(n);
    for (int i = 0; i < n; i++) for (int j = 0; j < n; j++) if (a[i][j] == '1') uf.unite(i, j);
    
    for (int i = 0; i < n; i++){
       // cout << uf.find(i) << " \n"[i + 1 == n];
        assert(uf.find(i) == uf.find(0));
    }

    int ans = n - 1;
    for (int i = 0; i < (1 << n); i++){
        // can this mask be all leaves 
        // rest of the mask must be connected 
        // each membner of mask must have edge to non mask 
        ufds uf;
        uf.init(n);
        for (int j = 0; j < n; j++) for (int k = j + 1; k < n; k++){
            if (!(i >> j & 1) && !(i >> k & 1) && a[j][k] == '1'){
                uf.unite(j, k);
            }
        }

        bool good = true;

        for (int j = 0; j < n; j++) for (int k = j + 1; k < n; k++){
            if (!(i >> j & 1) && !(i >> k & 1)){
                good &= uf.find(j) == uf.find(k);
            }
        }

        for (int j = 0; j < n; j++) if (i >> j & 1){
            bool ok = false;
            for (int k = 0; k < n; k++) if (!(i >> k & 1) && a[j][k] == '1'){
                ok = true;
            }
            good &= ok;
        }

        if (good) ans = min(ans, n - 1 - __builtin_popcount(i));
    }

    cout << ans << "\n";
}

// Entry point for the author's solution. It reads the test count t (asserted
// to be at most 2^12), runs Solve() once per test, and reports the total
// elapsed wall-clock time on stderr.
int32_t main() 
{
    const auto started = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);

    int t = 1;
    cin >> t;
    assert(t <= (1 << 12));
    for (int done = 0; done < t; ++done) Solve();

    const auto finished = std::chrono::high_resolution_clock::now();
    const auto took = std::chrono::duration_cast<std::chrono::nanoseconds>(finished - started);
    cerr << "Time measured: " << took.count() * 1e-9 << " seconds.\n"; 
    return 0;
}
Tester's code (C++)
#ifndef LOCAL
#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx,avx2,sse,sse2,sse3,sse4,popcnt,fma")
#endif

#include <bits/stdc++.h>
using namespace std;

// Strict input validator. The constructor reads all of stdin into `buffer`.
// The read* methods then consume it token by token, asserting the exact
// whitespace layout (single spaces, '\n' line ends, end of file) and the
// allowed value ranges.
struct input_checker {
    string buffer;  // entire stdin contents
    int pos;        // index of the next unread character in buffer

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    // Reads every character of stdin into buffer.
    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == char_traits<char>::eof()) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    // Returns the index one past the end of the token starting at pos. That is
    // the first whitespace character at or after pos, or buffer.size().
    int nextDelimiter() {
        int now = pos;
        // isspace() has undefined behavior for negative arguments other than
        // EOF. Bytes >= 0x80 are negative when char is signed, so convert to
        // unsigned char first.
        while (now < (int) buffer.size() && !isspace(static_cast<unsigned char>(buffer[now]))) {
            now++;
        }
        return now;
    }

    // Consumes and returns the run of non-whitespace characters starting at
    // pos. The result is empty if pos is on whitespace. Asserts that input
    // remains.
    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res = buffer.substr(pos, nxt - pos);
        pos = nxt;
        return res;
    }

    // Reads a token whose length is in [minl, maxl] and whose characters all
    // occur in pattern. If pattern is empty, any character is allowed.
    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    // Reads an int in [minv, maxv]. The whole token must be consumed by the
    // conversion; std::stoi on its own silently accepts trailing garbage
    // such as "12abc".
    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        const string tok = readOne();
        size_t used = 0;
        int res = stoi(tok, &used);
        assert(used == tok.size());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    // Reads a long long in [minv, maxv]. As with readInt, the whole token must
    // be numeric.
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        const string tok = readOne();
        size_t used = 0;
        long long res = stoll(tok, &used);
        assert(used == tok.size());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    // Reads n ints in [minv, maxv], separated by single spaces.
    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    // Reads n long longs in [minv, maxv], separated by single spaces.
    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    // Consumes exactly one ' '.
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    // Consumes exactly one '\n'.
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    // Asserts that the entire input has been consumed.
    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

// Disjoint-set union over the indices 0..n, with union by subtree size and
// path compression. cmps starts at n_ and drops by one on each successful
// join. cnt[r] holds the element count of the set rooted at r and is zeroed
// for roots that get absorbed.
struct DSU {
    // In-class initializers give a default-constructed DSU a well-defined
    // (empty) state instead of indeterminate n / cmps.
    int n = 0, cmps = 0;
    vector<int> p, sub, cnt;
    DSU () {}
    DSU(int n_) : n(n_), cmps(n_) {
        p.resize(n+1);
        sub.resize(n+1, 1);
        cnt.resize(n + 1, 1);
        iota(p.begin(), p.end(), 0);
    }
    // Returns the representative of i's set and compresses the path to it.
    int parent(int i) {
        assert(0 <= i && i <= n);
        return p[i] = (p[i] == i ? i : parent(p[i]));
    }
    // Unites the sets of x and y, hanging the smaller one (by sub) under the
    // larger. Returns false if x and y were already in the same set.
    bool join(int x, int y) {
        assert(0 <= x && x <= n && 0 <= y && y <= n);
        x = parent(x), y = parent(y);
        if(x == y) {
            return false;
        }
        --cmps;
        if(sub[x] > sub[y]) {
            swap(x, y);
        }
        sub[y] += sub[x];
        cnt[y] += cnt[x];
        cnt[x] = 0;
        p[x] = y;
        return true;
    }
};

// Tester's validator plus solution. For each test it enumerates every
// nonempty set of internal (non-leaf) vertices. A set is feasible when it
// induces a connected subgraph and every vertex outside it has a neighbour
// inside it. The answer is the size of the smallest feasible set minus one.
// The validator also checks the bounds on T and N, the 0/1 matrix format,
// and that the sum of 2^N over all tests is at most 2^18.
int32_t main() {
    ios_base::sync_with_stdio(0);   cin.tie(0);

    input_checker input;

    int T = input.readInt(1, 1 << 12);
    input.readEoln();
    int sumPow = 0;
    for (int tc = 0; tc < T; ++tc) {
        const int N = input.readInt(3, 18);
        input.readEoln();
        sumPow += 1 << N;

        vector<string> G(N);
        for (auto &row : G) {
            row = input.readString(N, N, "01");
            input.readEoln();
        }

        int best = N;
        for (int inner = 1; inner < (1 << N); ++inner) {
            auto inside = [inner](int v) { return (inner >> v & 1) == 1; };

            // Join internal vertices along internal-internal edges; `last`
            // ends up as the highest-indexed internal vertex.
            DSU ds(N);
            int last = -1;
            for (int u = 0; u < N; ++u) {
                if (!inside(u)) continue;
                last = u;
                for (int v = u + 1; v < N; ++v)
                    if (inside(v) && G[u][v] == '1') ds.join(u, v);
            }

            // Connectivity: every internal vertex shares last's component.
            bool ok = true;
            const int rep = ds.parent(last);
            for (int u = 0; u < N; ++u)
                if (inside(u) && ds.parent(u) != rep) ok = false;

            // Domination: every outside vertex touches the internal set.
            for (int u = 0; u < N; ++u) {
                if (inside(u)) continue;
                bool touches = false;
                for (int v = 0; v < N; ++v)
                    if (inside(v) && G[u][v] == '1') touches = true;
                ok = ok && touches;
            }

            if (ok) best = min(best, __builtin_popcount(inner));
        }
        cout << best - 1 << '\n';
    }
    assert(sumPow <= (1 << 18));

    input.readEof();

    return 0;
}

Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;

// Editorialist's solution, O(2^n * n) per test. For each test it reads an
// n x n 0/1 adjacency matrix and prints n - 1 minus the maximum number of
// leaves over all spanning trees.
//
// connected[mask] is computed by DP over increasing masks. A nonempty vertex
// set induces a connected subgraph iff some vertex v in it can be removed,
// leaving a connected set that is either empty or adjacent to v.
int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    int t; cin >> t;
    while (t--) {
        int n; cin >> n;

        // nbr[v] has bit w set iff the matrix has an edge v-w.
        vector<int> nbr(n);
        for (int v = 0; v < n; ++v) {
            string row;
            cin >> row;
            for (int w = 0; w < n; ++w)
                if (row[w] == '1') nbr[v] |= 1 << w;
        }

        const int total = 1 << n;
        vector<int> connected(total);
        connected[0] = 1;
        int mostLeaves = 0;
        for (int internal = 1; internal < total; ++internal) {
            // dominated: every vertex outside `internal` has a neighbour in it.
            bool dominated = true;
            for (int v = 0; v < n; ++v) {
                const int bit = 1 << v;
                if (internal & bit) {
                    const int rest = internal ^ bit;
                    if (connected[rest] && (rest == 0 || (nbr[v] & rest)))
                        connected[internal] = 1;
                } else if ((nbr[v] & internal) == 0) {
                    dominated = false;
                }
            }
            if (dominated && connected[internal])
                mostLeaves = max(mostLeaves, n - __builtin_popcount(internal));
        }
        cout << n - 1 - mostLeaves << '\n';
    }
}