MPTREE0 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: versiansmart
Tester: mexomerf
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

DFS

PROBLEM:

The \text{MEX-P} of a set of integers is defined to be the smallest prime number p such that at least one element of the set is not divisible by p.

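You’re given a tree with N vertices. Vertex i has the value A_i written on it.
For each vertex u of the tree, compute the sum of \text{MEX-P} values of all paths starting at u, where the \text{MEX-P} of a path is that of the set of values on its vertices. The sum ranges over every endpoint v, including v = u (the path consisting of u alone).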

N \leq 2000 in this version.

EXPLANATION:

The constraint N \leq 2000 allows a solution in \mathcal{O}(N^2), meaning we can afford to solve for each u independently.
So, let’s fix a vertex u and compute f(u, v) for every vertex v, where f(u, v) denotes the \text{MEX-P} of the values on the path from u to v; the answer for u is the sum of these values.

To do this, we make the following observation about the \text{MEX-P} function:

\text{MEX-P}(\{x_1, x_2, \ldots, x_k\}) = \min(\text{MEX-P}(\{x_1\}), \text{MEX-P}(\{x_2\}), \ldots, \text{MEX-P}(\{x_k\}))

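That is, the \text{MEX-P} of a set of elements is simply the minimum of the \text{MEX-P} values of its individual elements. This holds because a prime p fails to divide at least one element of the set exactly when it fails to divide some single element x_j, so the smallest such prime for the whole set is the smallest one over all individual elements.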
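To make the identity concrete, here is a small C++ sketch that evaluates \text{MEX-P} of a set straight from the definition and compares it with the minimum of the per-element values. The helper names `mexpOfSet`, `mexpOfElement` and `mexpViaMin` are hypothetical, and the table of the first 10 primes is assumed to suffice because values are taken to lie in [1, 10^9] (why this is enough is explained later in this editorial).

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <vector>

// Assumption: all values lie in [1, 1e9], so the first 10 primes are enough.
const std::vector<long long> kPrimes = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29};

// MEX-P of a non-empty set, straight from the definition: the smallest
// prime p such that at least one element is NOT divisible by p.
long long mexpOfSet(const std::vector<long long>& s) {
    for (long long p : kPrimes) {
        for (long long x : s) {
            if (x % p != 0) return p;  // p fails to divide x, so p qualifies
        }
    }
    return -1;  // unreachable for non-empty sets of values in [1, 1e9]
}

// MEX-P of a single element: the smallest prime that does not divide x.
long long mexpOfElement(long long x) { return mexpOfSet({x}); }

// The observation: MEX-P of a non-empty set is the minimum per-element MEX-P.
long long mexpViaMin(const std::vector<long long>& s) {
    long long best = LLONG_MAX;
    for (long long x : s) best = std::min(best, mexpOfElement(x));
    return best;
}
```

For example, both `mexpOfSet({6, 30, 4})` and `mexpViaMin({6, 30, 4})` return 3.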

So, suppose we calculate m_i = \text{MEX-P}(A_i) for every i.
Then, by our earlier observation, f(u, v) simply equals the minimum m-value on the u\to v path, so our aim is to fix some u and sum this up for all v.
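For instance (values chosen purely for illustration), suppose the u \to v path passes through three vertices with values 6, 30 and 4. Then

m = \left(\text{MEX-P}(6), \text{MEX-P}(30), \text{MEX-P}(4)\right) = (5, 7, 3), \qquad f(u, v) = \min(5, 7, 3) = 3

which agrees with the definition applied directly: 2 divides all three values, but 3 does not divide 4.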

This turns out to be a simple task with the help of a DFS.
Let’s perform a DFS starting at u, where f(u, u) = m_u.
Suppose during the DFS we’re at vertex v (so we know f(u, v)), and we’re trying to use an edge to move to vertex w.
Since the u \to w path is just the u \to v path extended by w, the earlier observation tells us that f(u, w) = \min(f(u, v), m_w), so we can move to w while knowing f(u, w) and then repeat the process.
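A minimal C++ sketch of this traversal for a single root, assuming the m-values are already known and the tree is stored as a 0-indexed adjacency list (the function name `sumFromRoot` is hypothetical). It uses an explicit stack rather than recursion; since only the sum of the f-values matters, the visiting order is irrelevant.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Sum of f(u, v) over all vertices v of the tree, including v = u.
// Assumptions: adj is a 0-indexed adjacency list of a tree, and m[i] already
// holds MEX-P(A_i), so f(u, v) is the minimum m-value on the u -> v path.
long long sumFromRoot(int u, const std::vector<std::vector<int>>& adj,
                      const std::vector<int>& m) {
    struct Frame {
        int v;       // current vertex
        int parent;  // vertex we came from (-1 for the root)
        int f;       // f(u, v)
    };
    long long total = 0;
    std::vector<Frame> todo;
    todo.push_back({u, -1, m[u]});  // f(u, u) = m_u
    while (!todo.empty()) {
        Frame cur = todo.back();
        todo.pop_back();
        total += cur.f;
        for (int w : adj[cur.v]) {
            if (w == cur.parent) continue;
            // Extending the path by w: f(u, w) = min(f(u, v), m_w).
            todo.push_back({w, cur.v, std::min(cur.f, m[w])});
        }
    }
    return total;
}
```

On the path 0-1-2 with m = (5, 7, 3), for example, rooting at vertex 0 gives 5 + 5 + 3 = 13.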

So, as long as all the m_i values are known, the answer for a single u can be found using a simple DFS in linear time.
All that remains is to actually compute the m_i values.

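However, this is not too hard: a number can’t have too many distinct prime divisors. In particular, a value \leq 10^9 has at most 9 distinct prime divisors, because the product of the first 10 primes is 2 \cdot 3 \cdot 5 \cdots 29 = 6469693230 > 10^9.
So, it’s enough to check A_i for divisibility by the first 10 primes (2, 3, 5, 7, 11, 13, 17, 19, 23, 29): if all of them divided A_i, their product would divide it too, which is impossible for A_i \leq 10^9. Hence the first of these primes that doesn’t divide A_i is exactly m_i. (The codes below check a couple of extra primes, which is harmless.)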
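A small sketch of this step, assuming 1 \leq A_i \leq 10^9; `mexpOfValue` and `primorial` are hypothetical names. The `static_assert`s check the primorial arithmetic from the argument at compile time.

```cpp
#include <cassert>

// The first 10 primes; their product exceeds 1e9.
constexpr long long kFirstPrimes[10] = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29};

// Product of the first k primes (k <= 10), evaluated at compile time below.
constexpr long long primorial(int k) {
    long long prod = 1;
    for (int i = 0; i < k; ++i) prod *= kFirstPrimes[i];
    return prod;
}

// If all 10 primes divided x, their product would divide x, forcing x > 1e9.
static_assert(primorial(9) == 223092870LL, "product of the first 9 primes");
static_assert(primorial(10) == 6469693230LL && primorial(10) > 1000000000LL,
              "product of the first 10 primes exceeds 1e9");

// MEX-P of a single value x, assuming 1 <= x <= 1e9:
// the first prime in increasing order that does not divide x.
int mexpOfValue(long long x) {
    for (long long p : kFirstPrimes) {
        if (x % p != 0) return static_cast<int>(p);
    }
    return -1;  // unreachable under the assumption 1 <= x <= 1e9
}
```

For example, 10^9 = 2^9 \cdot 5^9 gives `mexpOfValue(1000000000) == 3`, while 223092870 (the product of the first 9 primes) gives 29.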

TIME COMPLEXITY:

\mathcal{O}(N^2) per testcase.
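For reference, the two steps can be bundled into one routine. `allPathSums` is a hypothetical helper, not taken from either reference solution; it assumes 1 \leq A_i \leq 10^9 and 1-indexed edges (the input format read by the tester's code), computes every m_i from the first 10 primes, and then runs one \mathcal{O}(N) traversal per root, which is where the \mathcal{O}(N^2) bound comes from.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// answer[u] = sum over all v of MEX-P(values on the u -> v path), v = u included.
// Assumes 1 <= a[i] <= 1e9 and 1-indexed edges of a tree.
std::vector<long long> allPathSums(const std::vector<long long>& a,
                                   const std::vector<std::pair<int, int>>& edges) {
    const int n = static_cast<int>(a.size());
    const long long primes[10] = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29};

    // m[i] = MEX-P(a[i]): the first of the 10 primes that does not divide a[i].
    std::vector<int> m(n);
    for (int i = 0; i < n; ++i) {
        for (long long p : primes) {
            if (a[i] % p != 0) { m[i] = static_cast<int>(p); break; }
        }
    }

    std::vector<std::vector<int>> adj(n);
    for (auto [x, y] : edges) {
        adj[x - 1].push_back(y - 1);
        adj[y - 1].push_back(x - 1);
    }

    // One O(N) traversal per root: O(N^2) overall.
    std::vector<long long> answer(n, 0);
    std::vector<std::pair<int, int>> todo;  // (vertex, f(root, vertex))
    std::vector<int> parent(n);
    for (int root = 0; root < n; ++root) {
        parent[root] = -1;
        todo.push_back({root, m[root]});
        while (!todo.empty()) {
            auto [v, f] = todo.back();
            todo.pop_back();
            answer[root] += f;
            for (int w : adj[v]) {
                if (w == parent[v]) continue;
                parent[w] = v;
                todo.push_back({w, std::min(f, m[w])});  // f(root, w)
            }
        }
    }
    return answer;
}
```

For the path 1-2-3 with values 6, 30, 4 (so m = (5, 7, 3)), it returns {13, 15, 9}.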

CODE:

Tester's code (C++)
```cpp
#include <bits/stdc++.h>
using namespace std;
#define int long long
// Reads an integer in [l, r] followed by the character endd, asserting that the input is well-formed.
long long readInt(long long l, long long r, char endd) {
    long long x = 0;
    int cnt = 0;
    int fi = -1;
    bool is_neg = false;
    while (true) {
        char g = getchar();
        if (g == '-') {
            assert(fi == -1);
            is_neg = true;
            continue;
        }
        if ('0' <= g && g <= '9') {
            x *= 10;
            x += g - '0';
            if (cnt == 0) {
                fi = g - '0';
            }
            cnt++;
            assert(fi != 0 || cnt == 1);
            assert(fi != 0 || is_neg == false);

            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
        }
        else if (g == endd) {
            assert(cnt > 0);
            if (is_neg) {
                x = -x;
            }
            assert(l <= x && x <= r);
            return x;
        }
        else {
            assert(false);
        }
    }
}
// Adds f(root, v) to ans for every vertex v in the subtree of x (1-indexed).
// temp = f(root, pr), the minimum m-value on the root -> pr path (INT_MAX above the root);
// b[x - 1] = MEX-P(a[x - 1]).
void dfs(int x, int pr, const vector<vector<int>> &tr, const vector<int> &b, int temp, int &ans){
    int cur = min(temp, b[x - 1]);  // f(root, x)
    ans += cur;
    for(int y : tr[x]){
        if(y != pr){
            dfs(y, x, tr, b, cur, ans);
        }
    }
}
int32_t main() {
	int t;
	t = readInt(1, 1000, '\n');
	int ns = 0;
	// A value <= 1e9 can't be divisible by all of the first 10 primes, so MEX-P(a[i]) is always in this list.
	vector<int> p = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37};
	int ps = (int)p.size();
	while(t--){
	    int n;
	    n = readInt(2, 2000, '\n');
	    ns += n;
	    assert(ns <= 2000);  // the sum of N over all testcases is at most 2000
	    vector<int> a(n), b(n);  // b[i] = MEX-P(a[i])
	    for(int i = 0; i < n; i++){
	        if(i != n - 1){
	            a[i] = readInt(1, 1000000000, ' ');
	        }else{
	            a[i] = readInt(1, 1000000000, '\n');
	        }
	        for(int j = 0; j < ps; j++){
	            if(a[i] % p[j]){
	                b[i] = p[j];
	                break;
	            }
	        }
	    }
	    vector<vector<int>> tr(n + 1);  // 1-indexed adjacency list
	    for(int i = 0; i < n - 1; i++){
	        int u, v;
	        u = readInt(1, n, ' ');
	        v = readInt(1, n, '\n');
	        tr[u].push_back(v);
	        tr[v].push_back(u);
	    }
	    vector<int> ans(n, 0);
	    // One O(N) DFS per root, O(N^2) in total.
	    for(int i = 0; i < n; i++){
	        dfs(i + 1, 0, tr, b, INT_MAX, ans[i]);
	        cout<<ans[i]<<" ";
	    }
	    cout<<"\n";
	}
}
```
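Editorialist's code (C++)
```cpp
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

// This code takes a different route from the per-vertex DFS described above.
// Primes are processed in increasing order, and prod is the product of all primes seen so far.
// f(u, v) > p exactly when every vertex on the u -> v path is divisible by prod, i.e. when
// u and v lie in the same connected component of the vertices with a[x] % prod == 0.
// Before handling p, curcomp[u] is the number of v with f(u, v) >= p, so subtracting the new
// component size gives the number of v with f(u, v) == p.
// Total work is O(N * |primes|) per testcase.
int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    vector<int> primes = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37};

    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        vector<int> a(n);
        for (int &x : a) cin >> x;

        vector adj(n, vector<int>());
        for (int i = 0; i < n-1; ++i) {
            int u, v; cin >> u >> v;
            adj[--u].push_back(--v);
            adj[v].push_back(u);
        }

        // curcomp starts at n: every path has MEX-P >= 2.
        vector<int> ans(n), curcomp(n, n);
        ll prod = 1;
        for (auto p : primes) {
            prod *= p;

            // Flood fill over vertices divisible by prod, collecting one component at a time.
            vector<int> mark(n), component;
            auto dfs = [&] (const auto &self, int u) -> void {
                if (a[u] % prod) return;

                mark[u] = 1;
                component.push_back(u);
                for (int v : adj[u]) if (!mark[v]) {
                    self(self, v);
                }
            };

            for (int i = 0; i < n; ++i) {
                if (mark[i]) continue;
                if (a[i] % prod) {
                    // No v has f(i, v) > p, so all curcomp[i] vertices with f(i, v) >= p have f(i, v) == p.
                    ans[i] += curcomp[i] * p;
                    curcomp[i] = 0;
                    continue;
                }

                dfs(dfs, i);
                for (int u : component) {
                    // (count with f >= p) - (count with f > p) = count with f == p
                    ans[u] += (curcomp[u] - (int)component.size()) * p;
                    curcomp[u] = component.size();
                }
                component.clear();
            }
        }

        for (int x : ans) cout << x << ' ';
        cout << '\n';
    }
}
```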