NOTEQUALTREE - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: envyaims
Tester: tabr
Editorialist: iceknight1093

DIFFICULTY:

Easy-Medium

PREREQUISITES:

DP on trees

PROBLEM:

You’re given a tree on N vertices.
Vertex i has the values A_i and B_i assigned to it.

You can do the following operation at most once.

  1. Choose any subset of vertices, say \{x_1, x_2, \ldots, x_k\}, that forms a connected subgraph of the tree.
  2. Rearrange the A_i values of these vertices as you please.

Find the minimum size of a subset you must choose so that after the operation, A_i \neq B_i for all i.
Also find the number of subsets of minimum size that can be chosen.
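To make the task concrete, here is a brute-force reference for tiny trees. It is a sketch and is not taken from either code below. It assumes, as the explanation does, that the chosen subset must be connected. Vertices are 0-indexed, and every permutation of the chosen A values is tried, so it only works for N up to about 8. The impossible case returns {-1, -1}, mirroring what the author's code prints. The names `bruteForce`, `connectedMask` and `fixable` are illustrative.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>
using namespace std;

// True if the vertex set encoded by `mask` induces a connected subgraph
// (the empty set is treated as connected).
bool connectedMask(const vector<vector<int>>& adj, int mask) {
    if (mask == 0) return true;
    int start = 0;
    while (!((mask >> start) & 1)) ++start;
    int seen = 1 << start;
    vector<int> todo = {start};
    while (!todo.empty()) {
        int u = todo.back();
        todo.pop_back();
        for (int v : adj[u]) {
            if (((mask >> v) & 1) && !((seen >> v) & 1)) {
                seen |= 1 << v;
                todo.push_back(v);
            }
        }
    }
    return seen == mask;
}

// True if permuting the A-values inside `mask` can make A[i] != B[i] for all i.
bool fixable(const vector<int>& A, const vector<int>& B, int mask) {
    int n = A.size();
    vector<int> idx, vals;
    for (int i = 0; i < n; ++i) {
        if ((mask >> i) & 1) {
            idx.push_back(i);
            vals.push_back(A[i]);
        } else if (A[i] == B[i]) {
            return false;  // unchosen vertices keep their A-value
        }
    }
    sort(vals.begin(), vals.end());
    do {  // try every distinct arrangement of the chosen values
        bool ok = true;
        for (size_t j = 0; j < idx.size() && ok; ++j) ok = vals[j] != B[idx[j]];
        if (ok) return true;
    } while (next_permutation(vals.begin(), vals.end()));
    return false;
}

// {minimum subset size, number of minimum subsets}, or {-1, -1} if impossible.
pair<int, long long> bruteForce(const vector<int>& A, const vector<int>& B,
                                const vector<vector<int>>& adj) {
    int n = A.size(), best = -1;
    long long ways = 0;
    for (int mask = 0; mask < (1 << n); ++mask) {
        if (!connectedMask(adj, mask) || !fixable(A, B, mask)) continue;
        int k = 0;
        for (int i = 0; i < n; ++i) k += (mask >> i) & 1;
        if (best == -1 || k < best) { best = k; ways = 1; }
        else if (k == best) ++ways;
    }
    if (best == -1) return {-1, -1};
    return {best, ways};
}
```

For example, on a star with centre 0 and leaves 1, 2, where A = (1, 2, 3) and B = (1, 4, 5), the centre is forced and either leaf can join it. The result is two subsets of size 2.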

EXPLANATION:

To even begin solving this problem, we need to be able to recognize when a valid rearrangement is possible.
That is, if we fix a subset S of vertices, we must understand when their A_i values can be arranged to make A and B pointwise different.

Answer

To keep things simple, let’s relabel vertices so that S = \{1, 2, 3, \ldots, k\}.

Of course, the very first condition that must hold is that A_i \neq B_i for i \gt k, since we can't change the values at those indices.

Now, observe that if some value appears “too many” times, we’re in trouble.
Specifically, if some integer x appears more than k times in total across the 2k elements A_1 to A_k and B_1 to B_k, then no valid rearrangement can exist. There are only k index pairs (A_i, B_i), so by the pigeonhole principle some pair would have to contain x twice.

On the other hand, if every element appears \leq k times in total, a valid rearrangement always exists.
This can be proved via construction:

  • Suppose some element, say x, appears exactly k times.
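    Then, place all occurrences of x in A at the indices where B does not contain x. If x appears c times in A, it appears k - c times in B, so there are exactly c such indices. The other elements of A then fill the indices where B_i = x, in any order.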
    This way, at each of the k indices, an instance of x will be paired with an instance of not-x.
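  • If every element appears \lt k times, set A_k to be any element that isn't equal to B_k. This is always possible: B_k occurs at most k-1 times in total, and once of those in B, so at most k-2 of the k values in A equal B_k. Then recursively solve for the remaining k-1 indices. Every element still appears at most k-1 times among the remaining 2(k-1) elements, so the condition carries over.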
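The two-case construction can be turned into a short routine. This is a sketch under the stated precondition (no value occurs more than k times across A and B). It recomputes frequencies every round, so it runs in roughly \mathcal{O}(k^2 \log k). That is fine for illustration but not tuned. `constructArrangement` is a hypothetical name.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <vector>
using namespace std;

// Precondition: no value occurs more than k = B.size() times in A and B combined.
// Returns a permutation R of A with R[i] != B[i] for all i, following the proof:
// either some value is "tight" (exactly m copies among the 2m remaining entries),
// or the last remaining index can be filled with any value different from B there.
vector<int> constructArrangement(const vector<int>& A, const vector<int>& B) {
    int k = B.size();
    vector<int> R(k), pool(A), idx;
    for (int i = 0; i < k; ++i) idx.push_back(i);
    while (!idx.empty()) {
        int m = idx.size();
        map<int, int> freq;  // counts over remaining A-values and B at remaining indices
        for (int v : pool) freq[v]++;
        for (int i : idx) freq[B[i]]++;
        bool tight = false;
        int x = 0;
        for (const auto& [v, f] : freq) {
            assert(f <= m);  // the invariant from the proof
            if (f == m && !tight) { tight = true; x = v; }
        }
        if (tight) {
            // Case 1: x's from the pool go where B != x, all other values where B == x.
            vector<int> xs, others;
            for (int v : pool) (v == x ? xs : others).push_back(v);
            for (int i : idx) {
                vector<int>& src = (B[i] != x) ? xs : others;
                R[i] = src.back();
                src.pop_back();
            }
            return R;
        }
        // Case 2: all counts are < m, so some pool value differs from B at the last index.
        int i = idx.back();
        auto it = find_if(pool.begin(), pool.end(), [&](int v) { return v != B[i]; });
        R[i] = *it;
        pool.erase(it);
        idx.pop_back();
    }
    return R;
}
```

Case 1 finishes in one step because the counts match exactly. There are c copies of x for the c indices without x in B, and k - c other values for the k - c indices with B_i = x.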

In summary, when picking a subset of size k, we’re happy if \text{freq}[x] \leq k for every x, and not happy otherwise.
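The check itself needs only a frequency table. Here is a minimal sketch, assuming values are positive integers so that -1 can serve as "no violating value". `violatingValue` is an illustrative name. The 2k entries can contain at most one value with more than k copies, so the returned value, if any, is unique.

```cpp
#include <cassert>
#include <map>
#include <vector>
using namespace std;

// Returns the value occurring more than k = A.size() times across A and B,
// or -1 if none does (then a valid rearrangement exists).
// Assumes values are positive so that -1 is a safe sentinel.
int violatingValue(const vector<int>& A, const vector<int>& B) {
    int k = A.size();
    map<int, int> freq;
    for (int v : A) freq[v]++;
    for (int v : B) freq[v]++;
    for (const auto& [v, f] : freq)
        if (f > k) return v;  // at most one value can exceed k among 2k entries
    return -1;
}
```

Calling it on all N vertices gives the global feasibility test. Calling it on a chosen set S returns the unique offending element x, when one exists.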

If upon choosing all N vertices the above check still fails, there’s no way to satisfy the condition.
Otherwise, we know that a solution definitely exists.

First off, any vertex i which initially has A_i = B_i must definitely be chosen in the subset S.
(If there are no such vertices, the answer is of course 0.)

Let’s mark all vertices that must be chosen this way.
Since we must pick a connected subset of vertices, we’re also forced to include every vertex that lies on a path between some two of these marked vertices.
Finding all such vertices is fairly easy, and doable in \mathcal{O}(N) time with a single DFS (though the constraints also allow quadratic approaches).
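One way to do the single-DFS marking is sketched below, iteratively to avoid deep recursion. A vertex must be taken if any of the following holds:

  • it is marked;
  • marked vertices exist both inside and outside its subtree;
  • at least two of its child subtrees contain marked vertices.

This follows the same idea as the author's and editorialist's codes but is not copied from either. Vertices are 0-indexed, and `forcedVertices` is a hypothetical name.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// inS[v] is true iff v is marked or lies on the path between two marked vertices.
// Tree given as a 0-indexed adjacency list; the DFS is iterative.
vector<bool> forcedVertices(const vector<vector<int>>& adj, const vector<bool>& marked) {
    int n = adj.size(), total = 0;
    for (int v = 0; v < n; ++v) total += marked[v];
    vector<bool> inS(n, false);
    if (total == 0) return inS;  // nothing forced: the answer is 0

    // Preorder from root 0, remembering parents.
    vector<int> parent(n, -1), order, todo = {0};
    vector<bool> seen(n, false);
    seen[0] = true;
    while (!todo.empty()) {
        int u = todo.back();
        todo.pop_back();
        order.push_back(u);
        for (int v : adj[u])
            if (!seen[v]) { seen[v] = true; parent[v] = u; todo.push_back(v); }
    }

    // cnt[u]: marked vertices in u's subtree; branches[u]: children whose subtree has one.
    vector<int> cnt(n, 0), branches(n, 0);
    for (int i = n - 1; i >= 0; --i) {  // reverse preorder visits children before parents
        int u = order[i];
        cnt[u] += marked[u];
        if (parent[u] != -1) {
            cnt[parent[u]] += cnt[u];
            if (cnt[u] > 0) branches[parent[u]]++;
        }
    }
    for (int u = 0; u < n; ++u)
        inS[u] = marked[u] || branches[u] >= 2 || (cnt[u] > 0 && cnt[u] < total);
    return inS;
}
```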

Let S denote the subset of vertices we currently have, and let k = |S| denote the size of S.
If S satisfies the condition for being rearrangeable, we're done: the smallest subset size is k, and there's exactly one way to choose it.

Otherwise, there exists exactly one element x which occurs strictly more than k times. There can't be two such elements, since there are only 2k values in total.
Let’s look at the quantity k - \text{freq}[x].
Whenever we expand our connected component by adding a new vertex to it,

  • k, being the size of the set, increases by 1.
  • \text{freq}[x] increases by 0, 1, or 2, depending on the A_i and B_i values of the new vertex.
  • So, as a whole, k - \text{freq}[x] changes by -1, 0, or 1.

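k - \text{freq}[x] starts off negative, and our goal is quite simple: make it equal to 0.
As long as it remains negative, rearrangement is impossible. We want to choose the minimum number of vertices, and each chosen vertex increases the value by at most 1, so ending at a positive value is never optimal.
Other values never become a problem either. Initially each of them occurs fewer than k times, and every vertex outside S has A_i \neq B_i. So adding a vertex increases any other value's frequency by at most 1 while k increases by exactly 1.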

The size of the smallest connected subset containing S that achieves this can now be computed with a fairly classical tree DP.

Let C_i = 1 - [A_i = x] - [B_i = x] denote the change in value if we choose the i-th node.
Note that C_i is defined only for vertices not initially in S.
We compress the entirety of S into a single node, and root the tree at this compressed node.
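For convenience, let the new compressed node be (N+1), and let C_{N+1} = k - \text{freq}[x], the (negative) starting value of k - \text{freq}[x].
We then want a connected subset containing node N+1 whose C_i values sum to exactly 0.

Then, define \text{dp}[u][y] to be the minimum size of a connected subset of u's subtree that contains u, with sum of C_i values equal to y.
The compressed node counts as a single vertex here even though it stands for k original ones, so the final answer is \text{dp}[N+1][0] + k - 1.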
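Building the compressed tree and the C_i values might look like the sketch below. It assumes 0-indexed vertices, with index n playing the role of node N+1 from the text. S is connected and the graph is a tree, so every vertex outside S has at most one neighbour in S, and no duplicate edges appear. The struct `Compressed` and the function `compress` are illustrative names.

```cpp
#include <cassert>
#include <vector>
using namespace std;

struct Compressed {
    vector<vector<int>> adj;  // n + 1 nodes; node n stands for the whole of S
    vector<int> C;            // C[v] = 1 - [A_v = x] - [B_v = x] for v outside S
};

// Collapses the connected set S (inS) into node n and computes the C-values;
// C[n] = k - freq[x] is the (negative) starting value. Vertices of S other than
// node n are left isolated.
Compressed compress(const vector<vector<int>>& adj, const vector<bool>& inS,
                    const vector<int>& A, const vector<int>& B, int x) {
    int n = adj.size(), k = 0, fx = 0;
    Compressed res{vector<vector<int>>(n + 1), vector<int>(n + 1, 0)};
    for (int u = 0; u < n; ++u) {
        if (inS[u]) {
            ++k;
            fx += (A[u] == x) + (B[u] == x);
        } else {
            res.C[u] = 1 - (A[u] == x) - (B[u] == x);
        }
        for (int v : adj[u]) {
            if (inS[u] && inS[v]) continue;  // edges inside S disappear
            res.adj[inS[u] ? n : u].push_back(inS[v] ? n : v);
        }
    }
    res.C[n] = k - fx;
    return res;
}
```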

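\text{dp}[u] can be computed from the children of u by merging them in one at a time.
Specifically, first set \text{dp}[u][C_u] = 1 (every other entry starts at +\infty).
Then, for each child v of u, each value x, and each value y,
update \text{dp}[u][x+y] to \min(\text{dp}[u][x+y], \text{dp}[u][x] + \text{dp}[v][y]), where \text{dp}[u][x] is taken from before v was merged. Entries that use nothing from v's subtree simply keep their old values.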
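Below is a direct version of this merge, which also carries the number of optimal subsets needed for the counting part of the problem. It is a sketch. Sums are stored with a fixed offset over the whole range [-R, R], which is exactly the cubic version analysed next. `Cell` and `mergeChild` are hypothetical names, and counts are taken modulo 10^9 + 7 as in both reference codes.

```cpp
#include <cassert>
#include <climits>
#include <vector>
using namespace std;

const int INF = INT_MAX / 2;
const long long MOD = 1'000'000'007;

// One DP entry: smallest subset size for this sum, and how many subsets attain it.
struct Cell {
    int size = INF;
    long long ways = 0;
};

// Both vectors have the same length 2R + 1; entry s + R describes sum s in [-R, R].
// Returns the parent's table after merging child v, reading only the old dpU.
vector<Cell> mergeChild(const vector<Cell>& dpU, const vector<Cell>& dpV) {
    int len = dpU.size(), off = len / 2;
    vector<Cell> res = dpU;  // subsets that use nothing from v's subtree
    for (int i = 0; i < len; ++i) {
        if (dpU[i].size == INF) continue;
        for (int j = 0; j < len; ++j) {
            if (dpV[j].size == INF) continue;
            int t = i + j - off;  // index of the sum (i - off) + (j - off)
            if (t < 0 || t >= len) continue;
            int s = dpU[i].size + dpV[j].size;
            long long w = dpU[i].ways * dpV[j].ways % MOD;
            if (s < res[t].size) res[t] = {s, w};
            else if (s == res[t].size) res[t].ways = (res[t].ways + w) % MOD;
        }
    }
    return res;
}
```

Merging two leaves with C = 1 into a centre with C = 0 gives sum 0 with size 1 (the centre alone), sum 1 with size 2 in two ways, and sum 2 with size 3 in one way.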

This algorithm, if implemented directly, takes \mathcal{O}(N^3) time: each time we merge a child into the parent, we do \mathcal{O}(N^2) work by iterating over all pairs (x, y), and this happens N-1 times in total since there are N-1 edges.

We can optimize this slightly by noting that we don’t really need to iterate over all pairs of (x, y).
For instance, y is the sum over some connected subset rooted at v, and every C_i takes values between -1 and 1. So the largest value of y for which such a subset can exist at all is the subtree size of v (similarly, the smallest possible value is the negative of the subtree size of v).

So, rather than iterate y from -N to N, we can iterate it only over -\text{subsz}_v to \text{subsz}_v.
Similarly, x only needs to be iterated from -\text{subsz}_u to \text{subsz}_u.
(Here, \text{subsz}_u is the size of the part of u's subtree processed so far: u itself together with the subtrees of the children of u that have already been merged. After merging v, increase \text{subsz}_u by \text{subsz}_v.)
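Putting the bounded loops together gives a complete sketch of the DP. The table for u only covers sums in [-\text{subsz}_u, \text{subsz}_u] and grows as children are merged, so each merge costs \mathcal{O}(\text{subsz}_u \cdot \text{subsz}_v). As in the editorialist's code, the root's own value is treated as 0, and the target sum \text{freq}[x] - k is looked up at the root instead. The answer to the original problem is then the returned size plus k - 1. The names `dfs`, `relax` and `bestAtRoot` are illustrative, and recursion depth is not handled.

```cpp
#include <cassert>
#include <climits>
#include <utility>
#include <vector>
using namespace std;

const int INF = INT_MAX / 2;
const long long MOD = 1'000'000'007;

struct Cell {
    int size = INF;      // smallest subset size with this sum
    long long ways = 0;  // number of such subsets, modulo MOD
};

void relax(Cell& c, int size, long long ways) {
    if (size < c.size) c = {size, ways % MOD};
    else if (size == c.size) c.ways = (c.ways + ways) % MOD;
}

// DP over connected subsets of u's subtree that contain u. The returned vector has
// length 2 * sz[u] + 1 and entry s + sz[u] describes sum s. Requires C[v] in [-1, 1].
vector<Cell> dfs(int u, int p, const vector<vector<int>>& adj, const vector<int>& C,
                 vector<int>& sz) {
    sz[u] = 1;
    vector<Cell> cur(3);  // sums -1, 0, 1
    cur[C[u] + 1] = {1, 1};
    for (int v : adj[u]) {
        if (v == p) continue;
        vector<Cell> child = dfs(v, u, adj, C, sz);
        int a = sz[u], b = sz[v];  // a: u plus the children merged so far
        vector<Cell> nxt(2 * (a + b) + 1);
        for (int i = 0; i <= 2 * a; ++i) {
            if (cur[i].size == INF) continue;
            relax(nxt[i + b], cur[i].size, cur[i].ways);  // take nothing from v
            for (int j = 0; j <= 2 * b; ++j)
                if (child[j].size != INF)  // sum (i - a) + (j - b) sits at index i + j
                    relax(nxt[i + j], cur[i].size + child[j].size, cur[i].ways * child[j].ways);
        }
        cur = move(nxt);
        sz[u] += b;
    }
    return cur;
}

// Smallest connected subset containing `root` whose non-root C-values sum to
// `target`, as {node count including root, number of such subsets}; {-1, 0} if none.
pair<int, long long> bestAtRoot(int root, const vector<vector<int>>& adj, vector<int> C,
                                int target) {
    C[root] = 0;  // the root's own contribution is folded into `target`
    vector<int> sz(adj.size(), 0);
    vector<Cell> d = dfs(root, -1, adj, C, sz);
    int off = sz[root];
    if (target < -off || target > off || d[target + off].size == INF) return {-1, 0};
    return {d[target + off].size, d[target + off].ways};
}
```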

This simple-looking heuristic in fact brings the complexity down to \mathcal{O}(N^2)!
A proof (and some discussion) can be found in point number 7 of this blog.

Note that the above discussion was entirely about computing the minimum size of a subset.
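The problem also asks us to compute the number of ways to choose such a smallest subset. This can be done with a second DP computed alongside the first: for each state, store how many subsets attain the minimum size. Counts are multiplied when a parent state is combined with a child state, and added together on ties. The remaining details are elementary combinatorics and left as an exercise to the reader.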

TIME COMPLEXITY:

\mathcal{O}(N^2) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

void Solve() 
{
    int n; cin >> n;
    
    vector <int> a(n + 1), b(n + 1);
    for (int i = 1; i <= n; i++) cin >> a[i];
    for (int i = 1; i <= n; i++) cin >> b[i];
    
    vector <int> ga(n + 1, 0), gb(n + 1, 0);
    for (int i = 1; i <= n; i++){
        ga[a[i]]++;
        gb[b[i]]++;
    }
    
    vector<vector<int>> adj(n + 2);
    for (int i = 1; i < n; i++){
        int u, v; cin >> u >> v;
        
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    
    for (int i = 1; i <= n; i++) if (ga[i] + gb[i] > n){
        cout << -1 << " " << -1 << "\n";
        return;
    }
    
    vector <bool> vis(n + 1, false);
    vector<vector<int>> g(n + 2);
    vector <int> insub(n + 1);
    
    auto add_edge = [&](int u, int v){
        g[u].push_back(v);
        g[v].push_back(u);
        
     //   cout << u << " " << v << "\n";
    };
    
    vector <bool> need(n + 1, false);
    
    auto dfs = [&](auto self, int u, int par) -> void{
        insub[u] = (a[u] == b[u]);
        if (insub[u]) need[u] = true;
        int have = 0;
        
        for (int v : adj[u]) if (v != par){
            self(self, v, u);
            insub[u] += insub[v];
            if (insub[v]) have++;
        }
        
        if (have >= 2){
            need[u] = true;
        }
    };
    
    dfs(dfs, 1, -1);
    
    int tot = 0;
    for (int i = 1; i <= n; i++){
        tot += a[i] == b[i];
    }
    
    for (int i = 1; i <= n; i++){
        if (insub[i] > 0 && insub[i] < tot){
            need[i] = true;
        }
        
        if (!need[i]){
            vis[i] = true;
        }
    }
    
    // for (int i = 1; i <= n; i++){
    //     cout << need[i] << " \n"[i == n];
    // }
    
    for (int i = 1; i <= n; i++) if (vis[i]){
        bool got = false;
        for (int v : adj[i]){
            if (vis[v] && v < i){
                add_edge(v, i);
            } else if (!vis[v]) {
                got = true;
            }
        }
        
        if (got){
            add_edge(n + 1, i);
        }
    }
  //  return;
    
    vector <int> fa(n + 1, 0), fb(n + 1, 0);
    int cnt = 0;
    for (int i = 1; i <= n; i++) if (!vis[i]){
        fa[a[i]]++;
        fb[b[i]]++;
        cnt++;
    }
    
  //  cout << cnt << "\n";
    
    bool bad = false;
    int col = -1;
    for (int i = 1; i <= n; i++){
        if (fa[i] + fb[i] > cnt){
            bad = true;
            col = i;
        }
    }
    
    if (!bad){
        cout << cnt << " " << 1 << "\n";
        return;
    }
    
    const int mod = 1e9 + 7;
    vector<vector<int>> dp(n + 2, vector<int>(2 * n + 4 , INF));
    vector<vector<int>> cdp(n + 2, vector<int>(2 * n + 4, 0));
    
    vector<int> ndp(2 * n + 4), ncnt(2 * n + 4);
    vector<int> sub(n + 2);
    
    // a[u] = col or b[u] = col is a -1 
    // otherwise +1 
    
    // dp[i][j] -> min size, cnt[i][j] -> ways 
    
    auto rec = [&](auto self, int u, int par) -> void{
        sub[u] = 1; 
        if (u == n + 1){
           dp[u][n] = 0;
           cdp[u][n] = 1;
        } else if (a[u] == col || b[u] == col){
            dp[u][n] = 1;
            cdp[u][n] = 1;
        } else {
            dp[u][n + 1] = 1;
            cdp[u][n + 1] = 1;
        }
        
        for (int v : g[u]) if (v != par){
            self(self, v, u);
            
            for (int i = 0; i <= 2 * n + 3; i++){
                ndp[i] = INF;
                ncnt[i] = 0;
            }
            
            for (int i = 0; i <= sub[u]; i++){
                for (int j = 0; j <= sub[v]; j++){
                    if (dp[u][i + n] + dp[v][j + n] < ndp[i + j + n]){
                        ndp[i + j + n] = dp[u][i + n] + dp[v][j + n];
                        ncnt[i + j + n] = cdp[u][i + n] * cdp[v][j + n] % mod;
                    } else if (dp[u][i + n] + dp[v][j + n] == ndp[i + j + n]){
                        ncnt[i + j + n] += cdp[u][i + n] * cdp[v][j + n];
                        ncnt[i + j + n] %= mod; 
                    }
                }
            }
            
            sub[u] += sub[v];
            
            for (int i = 0; i <= 2 * n + 3; i++){
                dp[u][i] = ndp[i];
                cdp[u][i] = ncnt[i];
            }
        }
        
        dp[u][n] = 0;
        cdp[u][n] = 1;
    };
    
    // for (int i = 1; i <= n + 1; i++){
    //     for (int v : g[i]){
    //         cout << v << " ";
    //     }
    //     cout << "\n";
    // }
    // return;
    
    rec(rec, n + 1, -1);
    
    int diff = fa[col] + fb[col] - cnt;
    
    cout << (dp[n + 1][diff + n] + cnt) << " " << cdp[n + 1][diff + n] << "\n";
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
      //  cout << "Case #" << i << ": \n";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}
Editorialist's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    int t; cin >> t;
    while (t--) {
        int n; cin >> n;

        vector<int> a(n), b(n);
        for (int i = 0; i < n; ++i) cin >> a[i];
        for (int i = 0; i < n; ++i) cin >> b[i];
        
        vector adj(n, vector<int>());
        for (int i = 0; i < n-1; ++i) {
            int u, v; cin >> u >> v;
            adj[--u].push_back(--v);
            adj[v].push_back(u);
        }

        vector<int> mark(n);
        for (int i = 0; i < n; ++i)
            mark[i] = a[i] == b[i];
        
        int total_marked = accumulate(begin(mark), end(mark), 0);
        if (total_marked == 0) {
            cout << 0 << ' ' << 1 << '\n';
            continue;
        }

        auto markall = [&] (const auto &self, int u, int p) -> int {
            int ch = 0, ret = mark[u];
            for (int v : adj[u]) {
                if (v != p) ret += self(self, v, u);
                ch += mark[v];
            }
            mark[u] |= (ch > 1) or (ret > 0 and ret < total_marked);
            return ret;
        };
        markall(markall, 0, 0);

        total_marked = accumulate(begin(mark), end(mark), 0);
        int bad = -1;
        for (int i = 1; i <= n; ++i) {
            int ct = 0;
            for (int u = 0; u < n; ++u) {
                if (mark[u]) {
                    ct += a[u] == i;
                    ct += b[u] == i;
                }
            }
            if (ct > total_marked) bad = i;
        }

        if (bad == -1) {
            cout << total_marked << ' ' << 1 << '\n';
            continue;
        }

        vector new_adj(n+1, vector<int>());
        vector<int> val(n+1);
        for (int u = 0; u < n; ++u) {
            for (int v : adj[u]) {
                if (mark[u] and mark[v]) continue;
                
                if (mark[u]) new_adj[n].push_back(v);
                else if (!mark[v]) new_adj[u].push_back(v);
            }

            if (!mark[u]) {
                val[u] = 1 - (a[u] == bad) - (b[u] == bad);
            }
        }

        const int mod = 1e9 + 7;
        vector mn(n+1, vector(n+1, -1));
        vector ct(n+1, vector(n+1, -1));
        vector<int> subsz(n+1);
        auto dfs = [&] (const auto &self, int u, int p) -> void {
            subsz[u] = 1;
            if (val[u] >= 0) mn[u][val[u]] = 1, ct[u][val[u]] = 1;
            
            for (int v : new_adj[u]) {
                if (v == p) continue;
                self(self, v, u);

                for (int w = subsz[u] + subsz[v]; w >= 0; --w) {
                    for (int x = min(w, subsz[v]); x > 0; --x) {
                        int y = w - x;
                        if (y > subsz[u]) break;
                        if (y < 0) continue;

                        if (mn[u][y] == -1 or mn[v][x] == -1) continue;

                        if (mn[u][w] == -1 or mn[u][w] > mn[u][y] + mn[v][x]) {
                            mn[u][w] = mn[u][y] + mn[v][x];
                            ct[u][w] = 0;
                        }
                        if (mn[u][w] == mn[u][y] + mn[v][x]) {
                            ct[u][w] = (ct[u][w] + 1ll*ct[u][y]*ct[v][x]) % mod;
                        }
                    }
                }

                if (val[u] == -1) {
                    for (int x = 1; x <= subsz[v]; ++x) {
                        if (mn[v][x] == -1) continue;
                        if (mn[u][x-1] == -1 or mn[u][x-1] > mn[v][x] + 1) {
                            mn[u][x-1] = mn[v][x] + 1;
                            ct[u][x-1] = 0;
                        }
                        if (mn[u][x-1] == 1 + mn[v][x]) {
                            ct[u][x-1] = (ct[u][x-1] + ct[v][x]) % mod;
                        }
                    }
                }
                subsz[u] += subsz[v];
            }
        };
        dfs(dfs, n, n);

        int balance = -total_marked;
        for (int i = 0; i < n; ++i) {
            if (mark[i]) {
                balance += a[i] == bad;
                balance += b[i] == bad;
            }
        }

        cout << (total_marked-1)*(mn[n][balance] >= 0) + mn[n][balance] << ' ' << ct[n][balance] << '\n';
    }
}

Here, \text{subsz}_u represents the total size only considering children of u that have been processed already. After processing v, this value is to be increased by \text{subsz}_v.

Can someone please explain this?