TREEREQ1 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: sushil2006
Tester: tabr
Editorialist: iceknight1093

DIFFICULTY:

Easy

PREREQUISITES:

DFS

PROBLEM:

You’re given a tree on N vertices, vertex i has value A_i.
There are M constraints of the form (u, r, k).

Find the minimum possible sum of values of a set S of vertices such that, for every constraint (u, r, k):

  • The number of elements of S in the subtree of u, when rooted at r, is exactly k.

EXPLANATION:

Let’s first try to solve an even simpler version of the problem: one where r = 1 for every constraint (so only a single root needs to be considered).
This means, rooting the tree at 1, we have constraints of the form “pick exactly k vertices from the subtree of u”.

Simple version solution

Call a vertex u important if the subtree of u has a constraint.
If u is important, let r_u denote the number of vertices we need to select from the subtree of u.

Perform a DFS starting at vertex 1.
When at u,

  • First, process all the children of u.
  • Then, if u is not important, nothing needs to be done.
  • If u is important, we also know exactly how many vertices in the subtree of u have been chosen already.
    Let this number be y - then we need to choose another r_u - y vertices from the subtree of u to satisfy it.

However, we can’t just choose any vertices in the subtree of u.
Since all the important descendants of u have already been satisfied, any vertex we now choose to satisfy u should not affect any other important vertex - that is, if we are to choose vertex v, then there shouldn’t be any other important descendant of u that contains v in its subtree.

Another way to think of this is that a vertex v can only be picked to satisfy the constraint of its closest important ancestor (which may be v itself, if v is important).

Hence, we can precompute for each vertex its closest important ancestor (easy with a DFS).
Then, when we’re at u, we have a list of vertices we can choose from: since we want r_u - y of them, clearly it’s best to choose the smallest r_u - y values among them.
(Note that if r_u - y \lt 0 or exceeds the size of the list, no solution exists).

This solves the problem in \mathcal{O}(N\log N) time (or \mathcal{O}(N) after the initial sort).
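
To make the steps above concrete, here is a minimal sketch of the simplified version. The function name `solveRooted`, the 0-indexed root, and the `need` array (with -1 meaning "no constraint") are illustrative choices and not taken from the reference solutions. It also assumes that a vertex with no important ancestor is free, so it is taken exactly when its value is negative; the text above does not cover this case, and it does not arise in the general version, where the root always ends up constrained.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <optional>
#include <vector>

// Hypothetical helper: minimum cost when every constraint is rooted at vertex 0.
// need[u] == -1 means u has no constraint; otherwise exactly need[u] chosen
// vertices must lie in the subtree of u. Returns nullopt if that is impossible.
std::optional<long long> solveRooted(const std::vector<std::vector<int>>& adj,
                                     const std::vector<long long>& a,
                                     const std::vector<int>& need) {
    assert(adj.size() == a.size() && need.size() == a.size());
    int n = (int)a.size();
    std::vector<std::vector<long long>> pool(n);  // values owned by each important vertex
    std::vector<int> used(n, 0);                  // chosen count inside each subtree
    long long cost = 0;
    bool ok = true;

    std::function<void(int, int, int)> dfs = [&](int u, int p, int owner) {
        if (need[u] != -1) owner = u;             // u is now the closest important ancestor
        if (owner != -1) pool[owner].push_back(a[u]);
        else if (a[u] < 0) cost += a[u];          // assumption: unconstrained vertices are free
        for (int v : adj[u]) {
            if (v == p) continue;
            dfs(v, u, owner);                     // children are processed first
            used[u] += used[v];
        }
        if (need[u] == -1) return;
        int take = need[u] - used[u];             // this is r_u - y from the text
        if (take < 0 || take > (int)pool[u].size()) { ok = false; return; }
        std::sort(pool[u].begin(), pool[u].end());
        for (int i = 0; i < take; ++i) cost += pool[u][i];  // cheapest candidates
        used[u] += take;
    };
    dfs(0, -1, -1);
    if (!ok) return std::nullopt;
    return cost;
}
```

Sorting each pool separately keeps the sketch short; the total work is still \mathcal{O}(N\log N), since every vertex lands in at most one pool.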


Now, let’s look at the general version.
For now, let’s root the tree at 1, and let S_v denote the subtree of v when rooted at 1.

Looking at some constraint (u, r, k), we find that there are three possibilities:

  1. If u = r, the constraint is for the entire tree, which is S_1; or
  2. The subtree of u, when rooted at r, remains S_u itself; or
  3. The subtree of u, when rooted at r, is everything other than S_v for some vertex v.
How?

The first case is obvious.

Now, suppose u \neq r.
Then,

  1. Suppose u is an ancestor of r.
    Let v be the first vertex other than u on the u\to r path.
    Then, the subtree of u when rooted at r, is exactly all vertices other than those in S_v.
  2. If u is not an ancestor of r, the subtree of u when rooted at r is just S_u.

Finding the vertex v when u is an ancestor of r can be done in several ways.
In this task, doing it in \mathcal{O}(N) per constraint by starting at r and repeatedly moving to the parent is fast enough.
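
The parent-climbing idea can be sketched as follows. The names `Kind` and `classify` are hypothetical, vertices are 0-indexed with the root at 0, and the sketch assumes a parent array with `par[0] == 0`.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical tag: which subtree (of the tree rooted at 0) a constraint refers to.
enum class Kind { WholeTree, Inside, Outside };

// par[] is the parent array for the tree rooted at vertex 0, with par[0] == 0.
// Returns {WholeTree, 0} if u == r,
//         {Outside, v}   if the subtree of u rooted at r is everything except S_v,
//         {Inside, u}    if it is S_u itself.
std::pair<Kind, int> classify(const std::vector<int>& par, int u, int r) {
    assert(u >= 0 && r >= 0 && u < (int)par.size() && r < (int)par.size());
    if (u == r) return {Kind::WholeTree, 0};
    int child = r;                   // vertex just below `cur` on the path to r
    int cur = par[r];
    while (cur != u && cur != 0) {   // climb until we reach u or the root
        child = cur;
        cur = par[cur];
    }
    if (cur == u) return {Kind::Outside, child};  // u is a proper ancestor of r
    return {Kind::Inside, u};                     // u is not an ancestor of r
}
```

Each call walks at most one root-to-vertex path, so classifying all M constraints costs \mathcal{O}(NM) in the worst case.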

Now, suppose we fix the total number of vertices we’re choosing, say to X.
The above observation tells us that every constraint gives us one of two conditions:

  1. Pick exactly k vertices within S_u; or
  2. Pick exactly k vertices outside S_v, which in turn means we must pick exactly X-k vertices inside S_v.
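
The second condition is just counting. Written out, with S the chosen set and |S| = X:

```latex
\begin{aligned}
|S| &= |S \cap S_v| + |S \setminus S_v| = X, \\
|S \setminus S_v| = k \;&\Longleftrightarrow\; |S \cap S_v| = X - k .
\end{aligned}
```

In particular, a given X can only work if 0 \le X - k \le |S_v| for every constraint of this type.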

So, if X is fixed, we end up with several constraints all telling us to pick a certain number of vertices from within certain subtrees, when rooted at 1.
This is exactly the simplified version we first solved!

For a fixed X, the optimal solution can be found in \mathcal{O}(N) time, provided each vertex's list of candidate values is sorted once beforehand.
In this version of the problem, N is small enough that we can simply try every value of X from 1 to N, find the answer for each separately, and take the best answer among them, for \mathcal{O}(N^2) runtime.
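
As a sketch of the per-X step, the following turns both kinds of constraints into "exactly this many inside S_v" requirements for one candidate X. The function name `requirementsFor` and the `inside` / `outside` arrays with -1 for "absent" are assumptions made for illustration (the encoding loosely mirrors the editorialist's code below); vertices are 0-indexed with the root at 0.

```cpp
#include <cassert>
#include <optional>
#include <vector>

// inside[v]  : required count within S_v, or -1 if there is no such constraint.
// outside[v] : required count outside S_v, or -1 if there is no such constraint.
// Returns, for total size X, the required count within S_v for every v
// (-1 if unconstrained), or nullopt if X contradicts some constraint.
std::optional<std::vector<int>> requirementsFor(const std::vector<int>& inside,
                                                const std::vector<int>& outside,
                                                int X) {
    assert(inside.size() == outside.size() && !inside.empty());
    int n = (int)inside.size();
    std::vector<int> req(inside);
    // S_root is the whole tree, so it must contain exactly X chosen vertices.
    if (req[0] != -1 && req[0] != X) return std::nullopt;
    req[0] = X;
    for (int v = 1; v < n; ++v) {
        if (outside[v] == -1) continue;
        int want = X - outside[v];                  // k outside S_v  <=>  X - k inside S_v
        if (want < 0) return std::nullopt;
        if (req[v] != -1 && req[v] != want) return std::nullopt;  // clashes with inside[v]
        req[v] = want;
    }
    return req;
}
```

The outer loop would then call this for each X from 1 to N, skip the values that return nothing, run the rooted solver on the rest, and keep the minimum.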

TIME COMPLEXITY:

\mathcal{O}(N^2) per testcase.

CODE:

Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;

void solve(istringstream cin) {
    int n, m;
    cin >> n >> m;
    vector<long long> a(n);
    for (int i = 0; i < n; i++) {
        cin >> a[i];
    }
    vector<vector<int>> g(n);
    for (int i = 0; i < n - 1; i++) {
        int x, y;
        cin >> x >> y;
        x--;
        y--;
        g[x].emplace_back(y);
        g[y].emplace_back(x);
    }

    int s_size = -1;
    map<pair<int, int>, int> cut;
    for (int i = 0; i < m; i++) {
        int u, r, k;
        cin >> u >> r >> k;
        u--;
        r--;
        if (u == r) {
            s_size = k;
            continue;
        }
        function<void(int, int)> add_cut = [&](int v, int p) {
            if (v == u) {
                cut[make_pair(u, p)] = k;
                if (cut.count(make_pair(p, u))) {
                    s_size = cut[make_pair(p, u)] + k;
                }
            }
            for (int to : g[v]) {
                if (to == p) {
                    continue;
                }
                add_cut(to, v);
            }
        };
        add_cut(r, -1);
    }

    if (cut.empty()) {
        sort(a.begin(), a.end());
        cout << accumulate(a.begin(), a.begin() + s_size, 0LL) << '\n';
        return;
    }

    set<int> st;
    for (int i = 0; i <= n; i++) {
        st.emplace(i);
    }

    int low = 0, high = n;
    vector<long long> c(n + 1);
    long long d = 0;
    vector<bool> checked(n);
    for (int i = 0; i < n; i++) {
        if (checked[i]) {
            continue;
        }
        // s_size * x + y
        int x = 1, y = 0;
        vector<long long> b;
        function<void(int, int)> dfs = [&](int v, int p) {
            b.emplace_back(a[v]);
            checked[v] = true;
            for (int to : g[v]) {
                if (to == p) {
                    continue;
                }
                if (cut.count(make_pair(to, v))) {
                    y -= cut[make_pair(to, v)];
                    continue;
                }
                if (cut.count(make_pair(v, to))) {
                    x -= 1;
                    y += cut[make_pair(v, to)];
                    continue;
                }
                dfs(to, v);
            }
        };
        dfs(i, -1);
        sort(b.begin(), b.end());

        if (x == 0) {
            d += accumulate(b.begin(), b.begin() + y, 0LL);
            continue;
        }

        set<int> t;
        long long s = 0;
        for (int j = 0; j <= (int) b.size(); j++) {
            // s_size * x + y == j
            if ((j - y) % x == 0) {
                int k = (j - y) / x;
                if (k >= 0) {
                    t.emplace(k);
                    c[k] += s;
                }
            }
            if (j < (int) b.size()) {
                s += b[j];
            }
        }
        assert(t.size());
        low = max(low, *t.begin());
        high = min(high, *t.rbegin());
    }

    if (s_size != -1) {
        low = max(low, s_size);
        high = min(high, s_size);
    }
    assert(low <= high);

    long long ans = 1e18;
    for (int i = low; i <= high; i++) {
        ans = min(ans, c[i]);
    }
    ans += d;
    cout << ans << '\n';
}

////////////////////////////////////////

#define IGNORE_CR

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
#ifdef IGNORE_CR
            if (c == '\r') {
                continue;
            }
#endif
            buffer.push_back((char) c);
        }
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            assert(!isspace(buffer[pos]));
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

struct dsu {
    int n;
    vector<int> p;
    vector<int> sz;

    dsu(int _n) : n(_n) {
        p = vector<int>(n);
        iota(p.begin(), p.end(), 0);
        sz = vector<int>(n, 1);
    }

    inline int get(int x) {
        if (p[x] == x) {
            return x;
        } else {
            return p[x] = get(p[x]);
        }
    }

    inline bool unite(int x, int y) {
        x = get(x);
        y = get(y);
        if (x == y) {
            return false;
        }
        p[x] = y;
        sz[y] += sz[x];
        return true;
    }

    inline bool same(int x, int y) {
        return (get(x) == get(y));
    }

    inline int size(int x) {
        return sz[get(x)];
    }

    inline bool root(int x) {
        return (x == get(x));
    }
};

int main() {
    input_checker in;
    int tt = in.readInt(1, 1000);
    in.readEoln();
    int sn = 0, sm = 0;
    while (tt--) {
        int n = in.readInt(3, 2000);
        in.readSpace();
        int m = in.readInt(1, 2000);
        in.readEoln();
        sn += n;
        sm += m;
        auto a = in.readInts(n, -1e9, 1e9);
        in.readEoln();
        vector<int> u1(n - 1), v1(n - 1);
        for (int i = 0; i < n - 1; i++) {
            u1[i] = in.readInt(1, n);
            in.readSpace();
            v1[i] = in.readInt(1, n);
            in.readEoln();
        }
        dsu uf(n);
        for (int i = 0; i < n - 1; i++) {
            assert(uf.unite(u1[i] - 1, v1[i] - 1));
        }
        vector<int> u2(m), r2(m), k2(m);
        for (int i = 0; i < m; i++) {
            u2[i] = in.readInt(1, n);
            in.readSpace();
            r2[i] = in.readInt(1, n);
            in.readSpace();
            k2[i] = in.readInt(1, n);
            in.readEoln();
        }
        ostringstream sout;
        sout << n << " " << m << '\n';
        for (int i = 0; i < n; i++) {
            sout << a[i] << " \n"[i == n - 1];
        }
        for (int i = 0; i < n - 1; i++) {
            sout << u1[i] << " " << v1[i] << '\n';
        }
        for (int i = 0; i < m; i++) {
            sout << u2[i] << " " << r2[i] << " " << k2[i] << '\n';
        }
        solve(istringstream(sout.str()));
    }
    cerr << sn << " " << sm << endl;
    assert(sn <= 2000);
    assert(sm <= 2000);
    in.readEof();
    return 0;
}
Editorialist's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    int t; cin >> t;
    while (t--) {
        int n, m; cin >> n >> m;
        vector<int> a(n);
        for (int &x : a) cin >> x;

        vector adj(n, vector<int>());
        for (int i = 0; i < n-1; ++i) {
            int u, v; cin >> u >> v;
            adj[--u].push_back(--v);
            adj[v].push_back(u);
        }

        vector<int> par(n);
        vector<int> ord;
        auto dfs = [&] (const auto &self, int u, int p) -> void {
            par[u] = p;
            ord.push_back(u);
            for (int v : adj[u]) if (v != p)
                self(self, v, u);
        };
        dfs(dfs, 0, 0);
        reverse(begin(ord), end(ord));

        vector<int> inside(n, -1), outside(n, -1);
        for (int i = 0; i < m; ++i) {
            int root, u, k; cin >> u >> root >> k;
            --root, --u;

            if (u == root) {
                inside[0] = k;
            }
            else {
                int y = par[root], py = root;
                while (y) {
                    if (y == u) break;
                    py = y;
                    y = par[y];
                }

                if (y == u) {
                    outside[py] = k;
                }
                else {
                    inside[u] = k;
                }
            }
        }

        vector val(n, vector<int>());
        auto populate = [&] (const auto &self, int u, int p, int who) -> void {
            if (inside[u] != -1 or outside[u] != -1) who = u;
            val[who].push_back(a[u]);
            for (int v : adj[u]) if (v != p)
                self(self, v, u, who);
            sort(begin(val[u]), end(val[u]));
        };
        populate(populate, 0, 0, 0);
        
        ll ans = 1e18, cur;
        vector<int> req(n), used(n);
        auto solve = [&] (const auto &self, int u, int p) -> void {
            for (int v : adj[u]) if (v != p) {
                self(self, v, u);
                used[u] += used[v];
            }

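            if (req[u] != -1) {
                // Number of vertices still to be picked from u's own (sorted) list.
                int take = req[u] - used[u];
                if (take < 0 or take > (int) val[u].size()) cur = 1e18; // infeasible for this k
                else {
                    for (int i = 0; i < take; ++i) cur += val[u][i];
                }
                used[u] += take;
            }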
        };
        for (int k = 1; k <= n; ++k) {
            if (inside[0] != -1 and k != inside[0]) continue;
            bool good = true;
            req.assign(n, -1);
            for (int u : ord) {
                if (inside[u] != -1) req[u] = inside[u];
                if (outside[u] != -1) {
                    req[u] = k - outside[u];
                    good &= req[u] >= 0;
                    if (inside[u] != -1) good &= inside[u] == req[u];
                }
            }
            req[0] = k;

            if (!good) continue;

            cur = 0;
            used.assign(n, 0);
            solve(solve, 0, 0);
            ans = min(ans, cur);
        }
        cout << ans << '\n';
    }
}

what a solution, it seems so simple, yet i couldn’t think for it for 1 hour straight.