MPTREE - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: versiansmart
Tester: mexomerf
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

DFS

PROBLEM:

The \text{MEX-P} of a set of integers is defined to be the smallest prime number that doesn’t divide at least one element of the set.

You’re given a tree with N vertices. Each vertex has a value written on it.
For each vertex u of the tree, compute the sum of \text{MEX-P} values of all paths starting at u.
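
To make the definition concrete, here is a minimal sketch that computes the \text{MEX-P} of a set directly from the definition. The function name `mex_p` is hypothetical (it does not appear in the reference solutions), and the prime table assumes all values are at most 10^9.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: smallest prime that fails to divide at least one
// element of `values`. Assumes every value lies in [1, 10^9], in which case
// the result is at most 29.
long long mex_p(const std::vector<long long>& values) {
    assert(!values.empty());
    const int primes[] = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29};
    for (int p : primes) {
        for (long long x : values) {
            if (x % p != 0) return p;  // p does not divide x, so p is the MEX-P
        }
    }
    return -1;  // not reached when all values are <= 10^9
}
```

For example, the set {6, 12} has \text{MEX-P} 5 (both are divisible by 2 and 3, but neither by 5), while {6, 10} has \text{MEX-P} 3.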

EXPLANATION:

If the \text{MEX-P} of a set of integers equals p, that would mean every prime \lt p divides every element of the set.
If an integer x is divisible by distinct primes p_1, p_2, \ldots, p_k, then x will also be divisible by their product p_1\cdot p_2\cdot\ldots\cdot p_k.

Now, note that the product of primes grows extremely fast.
The first ten primes are 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, and their product is already 6469693230 \gt 10^9.
A \text{MEX-P} greater than 29 would need every element to be divisible by all ten of these primes, and hence by their product. Since we’re working with numbers \leq 10^9, that’s impossible, so no set we consider has a \text{MEX-P} larger than 29.
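
The bound can be checked mechanically. The sketch below multiplies primes in order until the running product exceeds a given limit; the function name `max_mex_p` and its `limit` parameter are assumptions made for illustration only.

```cpp
#include <cassert>

// Hypothetical helper: largest MEX-P attainable by a single value in
// [1, limit]. This is the first prime p_i whose prefix product
// p_1 * ... * p_i exceeds limit.
int max_mex_p(long long limit) {
    assert(limit >= 1);
    const int primes[] = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37};
    long long prod = 1;  // product of all 12 primes is about 7.4e12, fits in 64 bits
    for (int p : primes) {
        prod *= p;
        if (prod > limit) return p;  // nothing <= limit is divisible by prod
    }
    return -1;  // limit is too large for this prime table
}
```

With `limit = 1000000000` the loop stops at 29: the product of the first nine primes is 223092870, which still fits, and multiplying by 29 gives 6469693230.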


Let p_i denote the i-th prime number (so p_1 = 2, p_2 = 3, \ldots).
The maximum possible \text{MEX-P} is 29, so we only need to care about i \leq 10.

We define c_{u, i} to be the number of paths starting at u with \text{MEX-P} equal to p_i.
If we can compute all these values quickly, then the final answer for u is just the sum of p_i\cdot c_{u, i} across 1 \leq i \leq 10.

Let’s look at how c_{u, i} can be computed.
For the \text{MEX-P} to equal p_i, the value at every vertex on the path must be a multiple of every prime \lt p_i, and at least one of these values must not be a multiple of p_i.
The second condition is a bit annoying to deal with, so let’s relax it a bit, and not care whether there are any non-multiples of p_i on the path. Note that this might lead us to consider paths whose \text{MEX-P} is greater than p_i as well.

So, define d_{u, i} to be the number of paths starting at u, such that their \text{MEX-P} is at least p_i.
If we can compute d_{u, i} for all i, then c_{u, i} = d_{u, i} - d_{u, i+1} (with d_{u, 11} = 0, since no \text{MEX-P} exceeds 29), so we’ll be done.
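
As a small illustration of this difference trick, the sketch below turns the ten values d_{u, 1}, \ldots, d_{u, 10} for a single vertex into its answer, treating d_{u, 11} as 0. The name `answer_from_d` is hypothetical and the array is 0-indexed, so `d[i]` stands for d_{u, i+1}.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper. d[i] (0-indexed) is the number of paths from u whose
// MEX-P is at least the (i+1)-th prime; d must have exactly 10 entries.
long long answer_from_d(const std::vector<long long>& d) {
    const int primes[10] = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29};
    assert(d.size() == 10);
    long long total = 0;
    for (int i = 0; i < 10; ++i) {
        long long next = (i + 1 < 10) ? d[i + 1] : 0;  // d_{u,11} taken as 0
        long long exact = d[i] - next;                  // paths with MEX-P exactly primes[i]
        total += exact * primes[i];
    }
    return total;
}
```

For instance, d = (3, 2, 1, 0, \ldots, 0) means one path each with \text{MEX-P} 2, 3 and 5, for a total of 10.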

To compute d_{u, i}, the only condition now is that all the vertices on the path from u should be multiples of every prime \lt p_i - which as noted above means that they should all be multiples of the product of all primes \lt p_i.
So, if P_{i-1} = p_1\cdot p_2 \cdot\ldots\cdot p_{i-1},

  • If A_u is not divisible by P_{i-1}, then d_{u, i} = 0.
  • Otherwise, we can start at u, and then move to a neighbor of u that’s divisible by P_{i-1}, then a neighbor of that node that’s divisible by P_{i-1}, and so on.

For the second case, observe that each edge (x, y) is usable if and only if both A_x and A_y are multiples of P_{i-1}.
So, essentially we’re keeping only usable edges and counting the vertices that u can reach using them (including u itself). Paths in a tree are unique, so each reachable vertex gives exactly one path starting at u, and d_{u, i} is just the size of the connected component containing u.

This gives us a way to compute d_{u, i} for all u quickly once i is fixed: keep only usable edges, and then find the size of the connected component containing u (which can be done in linear time using DFS for example).
Each connected component only needs to be processed once since its size is common to all vertices in it, so the overall complexity remains linear per i.
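
The step for a single fixed product can be sketched as follows. This is not the code from either reference solution: the name `component_sizes` is hypothetical, vertices are assumed 0-indexed, and an explicit stack is used in place of recursive DFS to avoid deep recursion on long paths.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper. For a fixed product P, returns d where d[u] is the size
// of u's component after keeping only edges whose two endpoint values are
// both multiples of P, or 0 if a[u] itself is not a multiple of P.
std::vector<int> component_sizes(const std::vector<long long>& a,
                                 const std::vector<std::vector<int>>& adj,
                                 long long P) {
    const int n = (int)a.size();
    assert((int)adj.size() == n);
    std::vector<int> d(n, 0);
    std::vector<char> seen(n, 0);
    for (int s = 0; s < n; ++s) {
        if (seen[s] || a[s] % P != 0) continue;
        // Iterative DFS that collects the whole component containing s.
        std::vector<int> comp, stack = {s};
        seen[s] = 1;
        while (!stack.empty()) {
            int u = stack.back();
            stack.pop_back();
            comp.push_back(u);
            for (int v : adj[u]) {
                if (!seen[v] && a[v] % P == 0) {  // edge (u, v) is usable
                    seen[v] = 1;
                    stack.push_back(v);
                }
            }
        }
        // Every vertex of the component shares the same size.
        for (int u : comp) d[u] = (int)comp.size();
    }
    return d;
}
```

On the path 0-1-2-3 with values 6, 6, 5, 6 and P = 2, vertex 2 is excluded, so the components are {0, 1} and {3} and the result is (2, 2, 0, 1). Calling this once per prefix product gives all the d_{u, i} values.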

As noted at the start, only i \leq 10 matters, so repeating this process 10 times gives us all the d_{u, i} values we need.
From there, all c_{u, i} can be computed, and from them the answer, so we’re done.

TIME COMPLEXITY:

\mathcal{O}(10\cdot N) per testcase.

CODE:

Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
long long readInt(long long l, long long r, char endd) {
    long long x = 0;
    int cnt = 0;
    int fi = -1;
    bool is_neg = false;
    while (true) {
        char g = getchar();
        if (g == '-') {
            assert(fi == -1);
            is_neg = true;
            continue;
        }
        if ('0' <= g && g <= '9') {
            x *= 10;
            x += g - '0';
            if (cnt == 0) {
                fi = g - '0';
            }
            cnt++;
            assert(fi != 0 || cnt == 1);
            assert(fi != 0 || is_neg == false);
 
            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
        }
        else if (g == endd) {
            assert(cnt > 0);
            if (is_neg) {
                x = -x;
            }
            assert(l <= x && x <= r);
            return x;
        }
        else {
            assert(false);
        }
    }
}
void dfs(int x, int vis[], vector<int> tr[], int a[], int temp, vector<int> &v){
    if(a[x - 1] % temp){
        return;
    }
    vis[x - 1]++;
    v.push_back(x - 1);
    for(int i = 0; i < (int)tr[x].size(); i++){
        int y = tr[x][i];
        if(vis[y - 1] == 0){
            //cout<<y<<"\n";
            dfs(y, vis, tr, a, temp, v);
        }
    }
}
int32_t main() {
	int t;
	t = readInt(1, 10000, '\n');
	int ns = 0;
	vector<int> p = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37};
	int ps = (int)p.size();
	while(t--){
	    int n;
	    n = readInt(2, 200000, '\n');
	    ns += n;
	    assert(ns <= 200000);
	    int a[n];
	    for(int i = 0; i < n; i++){
	        if(i != n - 1){
	            a[i] = readInt(1, 1000000000, ' ');
	        }else{
	            a[i] = readInt(1, 1000000000, '\n');
	        }
	    }
	    vector<int> tr[n + 1];
	    for(int i = 0; i < n - 1; i++){
	        int u, v;
	        u = readInt(1, n, ' ');
	        v = readInt(1, n, '\n');
	        tr[u].push_back(v);
	        tr[v].push_back(u);
	    }
	    int ans[n]={0};
	    int temp = 1;
	    int b[n];
	    for(int i = 0; i < n; i++){
	        b[i] = n;
	    }
	    vector<int> v;
	    for(int i = 0; i < ps; i++){
	        temp *= p[i];
	        int vis[n] = {};
	        for(int j = 0; j < n; j++){
	            if(vis[j] == 0){
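	                if (a[j] % temp) {
	                    // a[j] fails at p[i]: all b[j] remaining paths have MEX-P exactly p[i]
	                    ans[j] += b[j] * p[i];
	                    b[j] = 0;
	                    continue;
	                }
	                v.clear();
	                dfs(j+1,vis,tr,a,temp,v);
	                for (int u : v) {
	                    // b[u] - |v| paths from u have MEX-P exactly p[i]
	                    ans[u] += (b[u] - (int)v.size()) * p[i];
	                    b[u] = (int)v.size();
	                }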
	            }
	        }
	    }
	    for(int i = 0; i < n; i++){
	        cout<<ans[i]<<" ";
	    }
	    cout<<"\n";
	}
}
Editorialist's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    vector<int> primes = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37};

    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        vector<int> a(n);
        for (int &x : a) cin >> x;

        vector adj(n, vector<int>());
        for (int i = 0; i < n-1; ++i) {
            int u, v; cin >> u >> v;
            adj[--u].push_back(--v);
            adj[v].push_back(u);
        }

        vector<int> ans(n), curcomp(n, n);
        ll prod = 1;
        for (auto p : primes) {
            prod *= p;
            
            vector<int> mark(n), component;
            auto dfs = [&] (const auto &self, int u) -> void {
                if (a[u] % prod) return;
                
                mark[u] = 1;
                component.push_back(u);
                for (int v : adj[u]) if (!mark[v]) {
                    self(self, v);
                }
            };

            for (int i = 0; i < n; ++i) {
                if (mark[i]) continue;
                if (a[i] % prod) {
                    ans[i] += curcomp[i] * p;
                    curcomp[i] = 0;
                    continue;
                }
                
                dfs(dfs, i);
                for (int u : component) {
                    ans[u] += (curcomp[u] - (int)component.size()) * p;
                    curcomp[u] = (int)component.size();
                }
                component.clear();
            }
        }

        for (int x : ans) cout << x << ' ';
        cout << '\n';
    }
}

okay, really nice solution!
I just shut my brain and did reroot dp.

Can you explain how reroot dp works here? I mean, there are combinations that would require removing some data in the process of dfs if we start from only one node. How do you handle that?

just maintain an array of 10 elements, where each i is the number of paths from node to any of its child that has mex-p equal to p_i

then transitions are simple. if you have any questions just msg me on LinkedIn, i have the same handle, parag776