Author: Takuki Kurokawa

Tester: Aryan Choudhary




Prerequisites: Dynamic Programming.


You are given a positive integer K and a tree with N vertices, rooted at vertex 1.

For every integer i (2 \leq i \leq N), P_i is the parent of vertex i.

Let’s call an array A consisting of N positive integers brilliant if the following constraints are met:

  • A_{P_i}\bmod A_i = 0, for all integer i (2 \leq i \leq N).

  • \prod_{i=1}^N A_i \leq K.

Compute the number of brilliant arrays modulo 998244353.
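For checking a solution on tiny inputs, a brute-force counter is handy. The sketch below is not part of the editorial's solution; the function name `countBrilliantBrute` and the input format (a 0-indexed `parent` vector whose entry 0 is unused) are assumptions made for illustration. It enumerates every array with product at most K and only then checks the divisibility constraints, so it is usable only for very small N and K.

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Brute-force reference: counts arrays A (A[0] is vertex 1) with
// A[parent[i]] % A[i] == 0 for all i >= 1 and A[0] * ... * A[n-1] <= K.
// Exponential time; intended only for cross-checking on tiny inputs.
long long countBrilliantBrute(const std::vector<int>& parent, long long K) {
    const int n = static_cast<int>(parent.size());
    std::vector<long long> a(n);
    long long count = 0;
    std::function<void(int, long long)> go = [&](int v, long long prod) {
        if (v == n) {
            for (int i = 1; i < n; ++i)
                if (a[parent[i]] % a[i] != 0) return;  // violates A_{P_i} mod A_i = 0
            ++count;
            return;
        }
        for (long long x = 1; prod * x <= K; ++x) {  // keep the running product <= K
            a[v] = x;
            go(v + 1, prod * x);
        }
    };
    go(0, 1);
    return count;
}
```

For example, with N = 2 and K = 4 the brilliant arrays (A_1, A_2) are (1,1), (2,1), (3,1), (4,1) and (2,2), so `countBrilliantBrute({-1, 0}, 4)` returns 5.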


Hint 1

What if the product of the array is restricted to a power of 2?

Hint 2

Use tree DP to solve the case in Hint 1.

Hint 3

Solve the case of Hint 1 for each prime and combine the results. Can you solve it for K = 100000?

Hint 4

The combination in Hint 3 can be done with DP. The DP key should be floor(K / (product of A)), not the product of A itself. By the way, when K = 100, can A[2] be 11?

Hint 5

A[1] is the only element that can be a multiple of a prime greater than sqrt(K). If you can exclude these primes, this DP can be computed for K = 1e8. But how?

Hint 6

Add a new constraint: A[1] = lcm(A[2], A[3], …, A[N]).

Hint 7

With the constraint of Hint 6, you can compute this DP using two arrays.

Hint 8

With the constraint of Hint 6, some values can never be the product of A, for example 2, 12, and 42.

Hint 9

If the product of A has a prime factor p, it must also be divisible by p^2. Such numbers are called powerful numbers.

Hint 10

Let M be the number of powerful numbers not exceeding K. For K = 1e12, M is about 2e6, which is small enough.

Hint 11

Use an associative array instead of an array for the DP. The number of DP transitions is at most M.


First, consider the case where every element of A is a power of a single prime p, i.e. A_i = p^{j_i}.

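Let dp^1_{x,y,z} be the number of ways to assign A_i for all i \in \textrm{subtree}(x), satisfying the divisibility constraints inside the subtree, such that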

  • A_x = p^y

  • \prod_{i \in \textrm{subtree}(x)} A_i = p^z

You can compute this DP with a DFS on the tree. Since y, z \leq \log_2 K, merging a child into its parent costs O(\log^4 K), so the total time complexity is O(N \log^4 K) (setter’s solution). You can improve this to O(N \log^3 K) using prefix sums and the inclusion-exclusion principle (tester’s solution).

Next, let’s combine the primes. There are O(K/\log K) primes less than or equal to K, which is too many to handle one by one.

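Notice that if p>\sqrt K, A_1 is the only element that can be a multiple of p: if A_i were a multiple of p for some i\neq 1, then so would be its ancestor A_1, and the product would be at least p^2>K. Let’s add a new constraint: A_1 = \textrm{lcm}_{i\neq 1}(A_i). In other words, A_1 takes the smallest value allowed by the other elements. With this constraint, you can postpone all primes greater than \sqrt K. Every valid A_1 is \textrm{lcm}_{i\neq 1}(A_i) times some positive integer c, so later you can multiply A_1 by any c between 1 and \lfloor K/\prod_{i=1}^N A_i\rfloor, where the product is taken with A_1=\textrm{lcm}_{i\neq 1}(A_i).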

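Let F_z be the number of arrays A satisfying the divisibility constraints, with every element a power of p, such that the following hold (F_z does not depend on the choice of p):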

  • \prod_{i=1}^N A_i = p^z

  • A_1 = \textrm{lcm}_{i\neq 1}(A_i) = \max_{i\neq 1}(A_i)
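A direct transcription of dp^1 and F, in the spirit of the setter's O(N \log^4 K) DFS, might look like the sketch below. The function name `primePowerCounts`, the children-list input, and the exponent bound `L` are illustrative assumptions. Each vertex returns a table `t[y][z]` (exponent of A_x is y, exponent sum over the subtree is z). With `lcmRoot = true` the root's exponent is forced to equal the maximum of the other exponents, which yields F_z; with `lcmRoot = false`, summing `cnt[z]` over z \leq \lfloor\log_2 K\rfloor answers the single-prime case of Hint 1.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <vector>

constexpr long long kModF = 998244353;
// t[y][z]: exponent of the subtree root is y, exponent sum over the subtree is z.
using Table = std::vector<std::vector<long long>>;

// cnt[z] (z < L) = number of exponent assignments (A_i = p^{e_i}) with
// e_parent >= e_child and total exponent z. If lcmRoot, the root exponent must
// equal the maximum over the other vertices (A_1 = lcm of the rest): cnt = F.
// children[v] lists the children of v; vertex 0 is the root.
std::vector<long long> primePowerCounts(const std::vector<std::vector<int>>& children,
                                        int L, bool lcmRoot) {
    std::function<Table(int)> dfs = [&](int v) {
        // cur[m][s]: children merged so far have maximum exponent m and sum s.
        Table cur(L, std::vector<long long>(L, 0));
        cur[0][0] = 1;
        for (int c : children[v]) {
            Table sub = dfs(c), next(L, std::vector<long long>(L, 0));
            for (int m = 0; m < L; ++m)
                for (int s = m; s < L; ++s) {
                    if (!cur[m][s]) continue;
                    for (int y = 0; y < L; ++y)
                        for (int z = y; s + z < L; ++z)
                            next[std::max(m, y)][s + z] =
                                (next[std::max(m, y)][s + z] + cur[m][s] * sub[y][z]) % kModF;
                }
            cur = std::move(next);
        }
        // Pick v's own exponent y >= m so that A_v is divisible by every child.
        Table res(L, std::vector<long long>(L, 0));
        for (int m = 0; m < L; ++m)
            for (int s = 0; s < L; ++s)
                for (int y = m; s + y < L; ++y) {
                    if (v == 0 && lcmRoot && y != m) continue;  // A_1 = lcm constraint
                    res[y][s + y] = (res[y][s + y] + cur[m][s]) % kModF;
                }
        return res;
    };
    Table root = dfs(0);
    std::vector<long long> cnt(L, 0);
    for (int y = 0; y < L; ++y)
        for (int z = 0; z < L; ++z) cnt[z] = (cnt[z] + root[y][z]) % kModF;
    return cnt;
}
```

For the path 1 - 2, the constraint forces A_1 = A_2, so F_z is 1 for even z and 0 for odd z; for the star with two leaves, F_2 = 2, F_3 = 1 and F_4 = 2.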

Let dp^2_{x,y} be the number of arrays A satisfying the divisibility constraints such that

  • \prod_{i=1}^N A_i has only prime factors among P_1, \ldots, P_x. (In this section P_i denotes the i-th prime, not a parent.)

  • \left\lfloor\frac{K}{\prod_{i=1}^N A_i}\right\rfloor = y

  • A_1=\textrm{lcm}_{i\neq 1}(A_i)

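This DP’s transitions are dp^2_{x+1,\lfloor y/ P_{x+1}^i\rfloor}+=dp^2_{x,y}\times F_i for every i\geq 0 with P_{x+1}^i\leq y; taking floors step by step is fine because \lfloor\lfloor K/a\rfloor/b\rfloor=\lfloor K/(ab)\rfloor. The initial value is dp^2_{0,K}=1.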

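Since F_1=0, a prime that divides the product divides it at least twice, so only primes up to \sqrt K matter. There are O(\sqrt K / \log K) of them, which gives O(\sqrt K) pairs (p, i) of such a prime and an exponent. The keys of dp^2 have the form \lfloor K/d\rfloor, which takes only O(\sqrt K) distinct values, so dp^2 fits in two arrays of size \sqrt K: one indexed by y for y\leq\sqrt K and one indexed by \lfloor K/y\rfloor for y>\sqrt K. With O(\sqrt K) keys and O(\sqrt K) pairs (p, i), this DP runs in O(K).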
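The two-array layout can be sketched as follows. The struct name `FloorKeyTable` is hypothetical; the point is the index map. A key y \leq S = \lfloor\sqrt K\rfloor lives at `lo[y]`, and a larger key lives at `hi[K / y]`, which is at most S because (S+1)^2 > K. Distinct keys of the form \lfloor K/d\rfloor never share a slot, so each array needs only S + 1 entries.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Storage for values indexed by keys y = floor(K / d), d >= 1.
// Small keys (y <= S) use lo[y]; large keys (y > S) use hi[K / y] <= S.
struct FloorKeyTable {
    long long K, S;
    std::vector<long long> lo, hi;
    explicit FloorKeyTable(long long K_) : K(K_) {
        S = static_cast<long long>(std::sqrt(static_cast<long double>(K)));
        while (S * S > K) --S;               // correct floating-point rounding
        while ((S + 1) * (S + 1) <= K) ++S;
        lo.assign(S + 1, 0);
        hi.assign(S + 1, 0);
    }
    // Slot for key y; y must be of the form floor(K / d).
    long long& at(long long y) { return y <= S ? lo[y] : hi[K / y]; }
};
```

Writing each key into its own slot and reading all of them back is a quick way to see that no two keys collide.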

What happens if you use an associative array such as std::map instead? Let M be the number of positive integers not exceeding K in which every prime factor has exponent at least 2. Since F_1=0, every product reached by the DP is such a number, so the number of DP transitions is at most M. It can be shown that M=O(\sqrt K) (proof below), so this DP runs in O(\sqrt K\log K).

Alternatively, you can enumerate all M of these numbers directly with a DFS over the primes. This runs in O(\sqrt K) (tester’s solution).
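A minimal sketch of such an enumeration, assuming a callback-style interface (`forEachPowerful` and `primesUpToSqrt` are illustrative names, not taken from the tester's code). Primes are used in increasing order and each chosen prime gets exponent at least 2, so every powerful number d \leq K is visited exactly once. In the full solution you would carry the product of F_e along the DFS and add it times \lfloor K/d\rfloor at each visited d.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <vector>

// Primes p with p * p <= K, by a sieve of Eratosthenes up to floor(sqrt(K)).
std::vector<long long> primesUpToSqrt(long long K) {
    long long s = 1;
    while ((s + 1) * (s + 1) <= K) ++s;
    std::vector<bool> composite(s + 1, false);
    std::vector<long long> ps;
    for (long long i = 2; i <= s; ++i) {
        if (composite[i]) continue;
        ps.push_back(i);
        for (long long j = i * i; j <= s; j += i) composite[j] = true;
    }
    return ps;
}

// Calls visit(d) once for every powerful number d <= K (including d = 1).
void forEachPowerful(long long K, const std::function<void(long long)>& visit) {
    const std::vector<long long> ps = primesUpToSqrt(K);
    std::function<void(size_t, long long)> dfs = [&](size_t idx, long long d) {
        visit(d);
        // Only primes larger than those already in d, with d * p^2 <= K.
        for (size_t j = idx; j < ps.size() && ps[j] * ps[j] <= K / d; ++j) {
            long long q = d * ps[j] * ps[j];   // exponent 2
            while (true) {
                dfs(j + 1, q);
                if (q > K / ps[j]) break;      // q * p would exceed K
                q *= ps[j];                    // next exponent
            }
        }
    };
    dfs(0, 1);
}
```

For K = 100 it visits the 14 powerful numbers 1, 4, 8, 9, 16, 25, 27, 32, 36, 49, 64, 72, 81, 100.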

Finally, consider vertex 1. For a key y of dp^2, A_1 can still be multiplied by any integer between 1 and y, so the answer is \sum_y y\times dp^2_{y} over all keys y of the final dp^2.

The overall time complexity is O(N\log^4K +\sqrt K \log K).

A number X is called a powerful number if p^2 divides X for every prime factor p of X. Every powerful number can be written as X=A^2B^3 with positive integers A and B (write each exponent e\geq 2 as 2s+3t with t\in\{0,1\}). Let’s count M', the number of pairs (A,B) that satisfy A^2B^3\leq K.

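For a fixed A, the number of valid B is \lfloor (K/A^2)^{\frac{1}{3}}\rfloor\leq K^{\frac{1}{3}}A^{-\frac{2}{3}}, and A\leq\sqrt K, so

M\leq M'\leq K^{\frac{1}{3}}\sum_{i=1}^{\lfloor\sqrt K\rfloor}i^{-\frac{2}{3}}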

Since x^{-\frac{2}{3}} is decreasing, the sum is at most

1+\int_{1}^{\sqrt K}x^{-\frac{2}{3}}\ dx=1+\left[3x^\frac{1}{3} \right]_1^{\sqrt K}=3K^\frac{1}{6}-2

So M\leq K^{\frac{1}{3}}\left(3K^{\frac{1}{6}}-2\right)<3\sqrt K, i.e. M=O(\sqrt K).
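The bound can also be checked numerically for small K. The sketch below (all function names are illustrative) counts M by trial factorization and M' by enumerating pairs (A, B), and compares them with 3\sqrt K, an upper bound that follows from comparing the sum with the integral.

```cpp
#include <cassert>
#include <cmath>

// True if every prime factor of x appears with exponent at least 2.
bool isPowerful(long long x) {
    for (long long p = 2; p * p <= x; ++p) {
        if (x % p != 0) continue;
        int e = 0;
        while (x % p == 0) { x /= p; ++e; }
        if (e == 1) return false;
    }
    return x == 1;  // a leftover factor > 1 is a prime with exponent 1
}

// M: powerful numbers <= K (slow, small K only).
long long countPowerful(long long K) {
    long long m = 0;
    for (long long x = 1; x <= K; ++x) m += isPowerful(x) ? 1 : 0;
    return m;
}

// M': pairs (a, b) of positive integers with a^2 * b^3 <= K.
long long countPairs(long long K) {
    long long m = 0;
    for (long long a = 1; a * a <= K; ++a)
        for (long long b = 1; a * a * b * b * b <= K; ++b) ++m;
    return m;
}

double threeSqrt(long long K) { return 3.0 * std::sqrt(static_cast<double>(K)); }
```

For K = 100 this gives M = 14 and M' = 15, well below 3\sqrt{100} = 30.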


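Alternative approach: since F does not depend on p, the number g(d) of arrays with A_1=\textrm{lcm}_{i\neq 1}(A_i) and product d equals \prod_{p^e\parallel d}F_e, which is a multiplicative function of d. The answer is \sum_{d\leq K}g(d)\lfloor K/d\rfloor=\sum_{n\leq K}(g*1)(n), a prefix sum of the multiplicative function g*1.

Prefix sums of multiplicative functions can be computed in sublinear time.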
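As a small sanity check of this reformulation (not the sublinear algorithm itself), take the path 1 - 2. There F_e is 1 for even e and 0 for odd e, so g(d) is 1 exactly when d is a perfect square, and (g*1)(n) counts the square divisors of n. The brute-force sketch below, with illustrative function names, evaluates both sides of the identity for small K.

```cpp
#include <cassert>

// g(d) for the path 1 - 2: 1 iff d is a perfect square.
bool isSquare(long long d) {
    long long r = 0;
    while ((r + 1) * (r + 1) <= d) ++r;
    return r * r == d;
}

// Left side: sum over d <= K of g(d) * floor(K / d).
long long weightedSum(long long K) {
    long long s = 0;
    for (long long d = 1; d <= K; ++d)
        if (isSquare(d)) s += K / d;
    return s;
}

// Right side: sum over n <= K of (g * 1)(n), the number of square divisors of n.
long long dirichletPrefix(long long K) {
    long long s = 0;
    for (long long n = 1; n <= K; ++n)
        for (long long d = 1; d <= n; ++d)
            if (n % d == 0 && isSquare(d)) ++s;
    return s;
}
```

Both sides give 13 for K = 10, matching a direct count of brilliant arrays on this tree.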


Setter's Solution
#include <bits/stdc++.h>
using namespace std;

int main() {
    const long long mod = 998244353;
    const int LOG_K = 42;                // exponent sums stay below LOG_K (K < 2^42)
    const int SQRT_K = (int) 1e6 + 16;   // sieve bound, larger than sqrt(K)
    int n;
    long long k;
    cin >> n >> k;
    vector<int> p(n, -1);
    for (int i = 1; i < n; i++) {
        cin >> p[i];
        p[i]--;                          // 0-indexed parent
    }
    vector<vector<int>> g(n);
    for (int i = 1; i < n; i++) {
        g[p[i]].push_back(i);
    }
    // dp1[v][i][j]: exponent of A_v is i, exponent sum over subtree(v) is j.
    // While the children of v are merged, i holds the children's maximum exponent.
    vector dp1(n, vector(LOG_K, vector<long long>(LOG_K)));
    function<void(int)> dfs = [&](int v) {
        dp1[v][0][0] = 1;
        for (int to : g[v]) {
            dfs(to);
            vector new_dp1(LOG_K, vector<long long>(LOG_K));
            for (int i = 0; i < LOG_K; i++) {
                for (int j = i; j < LOG_K; j++) {
                    for (int ni = 0; ni < LOG_K; ni++) {
                        for (int nj = ni; j + nj < LOG_K; nj++) {
                            new_dp1[max(i, ni)][j + nj] = (new_dp1[max(i, ni)][j + nj] + dp1[v][i][j] * dp1[to][ni][nj]) % mod;
                        }
                    }
                }
            }
            swap(dp1[v], new_dp1);
        }
        // Choose the exponent ni of A_v; it must be at least the children's maximum i.
        // For the root, ni must equal i (constraint A_1 = lcm of the others).
        vector new_dp1(LOG_K, vector<long long>(LOG_K));
        for (int i = 0; i < LOG_K; i++) {
            for (int j = i; j < LOG_K; j++) {
                for (int ni = i; j + ni < LOG_K; ni++) {
                    if (v == 0 && ni != i) {
                        continue;
                    }
                    new_dp1[ni][j + ni] += dp1[v][i][j];
                    if (new_dp1[ni][j + ni] >= mod) {
                        new_dp1[ni][j + ni] -= mod;
                    }
                }
            }
        }
        swap(dp1[v], new_dp1);
    };
    dfs(0);
    // f[j] = F_j: arrays of powers of one prime with exponent sum j under the constraint.
    vector<long long> f(LOG_K);
    for (int i = 0; i < LOG_K; i++) {
        for (int j = 0; j < LOG_K; j++) {
            f[j] += dp1[0][i][j];
            if (f[j] >= mod) {
                f[j] -= mod;
            }
        }
    }
    // Primes are sieved on the fly; dp2 maps floor(K / product) to a count.
    vector<bool> is_prime(SQRT_K, true);
    is_prime[0] = is_prime[1] = false;
    map<long long, long long> dp2;
    dp2[k] = 1;
    for (int x = 2; x <= k / x; x++) {
        if (!is_prime[x]) {
            continue;
        }
        for (int y = x * 2; y < SQRT_K; y += x) {
            is_prime[y] = false;
        }
        vector<long long> pows(1, 1);    // pows[i] = x^i, up to the largest power <= k
        while (pows.back() <= k / x) {
            pows.emplace_back(pows.back() * x);
        }
        // Ascending order: each key only writes to smaller keys, which were already read.
        for (auto iter = dp2.lower_bound(pows[2]); iter != dp2.end(); iter++) {
            auto [remain, value] = *iter;
            for (int i = 2; i < (int) pows.size(); i++) {
                if (remain < pows[i]) {
                    break;
                }
                dp2[remain / pows[i]] = (dp2[remain / pows[i]] + value * f[i]) % mod;
            }
        }
    }
    // A_1 can additionally be multiplied by any integer in [1, remain].
    long long ans = 0;
    for (auto [remain, value] : dp2) {
        ans = (ans + remain % mod * value) % mod;
    }
    cout << ans << '\n';
    return 0;
}
Tester's Solution
/*
  Compete against Yourself.
  Author - Aryan (@aryanc403)
  Credits -
  Atcoder library - (namespace atcoder)
  Github source code of library -
*/
#include <bits/stdc++.h>
#include <atcoder/modint>
using namespace std;
using namespace atcoder;

// y_combinator from @neal template: lets a lambda call itself recursively
template<class Fun> class y_combinator_result {
    Fun fun_;
public:
    template<class T> explicit y_combinator_result(T &&fun): fun_(std::forward<T>(fun)) {}
    template<class ...Args> decltype(auto) operator()(Args &&...args) { return fun_(std::ref(*this), std::forward<Args>(args)...); }
};
template<class Fun> decltype(auto) y_combinator(Fun &&fun) { return y_combinator_result<std::decay_t<Fun>>(std::forward<Fun>(fun)); }

typedef long long int lli;
typedef vector<lli> vi;
using mint = modint998244353;
using vm = vector<mint>;

const int N = 43;                        // exponents and their sums stay below N
using dpType = array<array<mint,N>,N>;   // [max exponent][exponent sum]

// Appends every prime p with p * p <= n.
void primesUptoSqrt(lli n,vi &primes){
    const lli sq=(lli)sqrt((long double)n)+2;
    vector<bool> vis(sq+1);
    for(lli i=2;i<=sq;++i){
        if(vis[i]) continue;
        if(i*i<=n) primes.push_back(i);
        for(lli j=i*i;j<=sq;j+=i) vis[j]=true;
    }
}

int main(void) {
    lli n,k;
    cin>>n>>k;
    vector<vi> e(n);
    for(int i=1;i<n;++i){
        int p;
        cin>>p;
        e[p-1].push_back(i);
    }
    // Tables passed to merge are cumulative in the max (max <= am), so two
    // children combine by a convolution over the sum for each fixed am.
    auto merge=[&](const dpType &a,const dpType &b){
        dpType res{};
        for(int am=0;am<N;++am)
            for(int ai=0;ai<N;++ai)
                for(int bi=0;bi+ai<N;++bi)
                    res[am][ai+bi]+=a[am][ai]*b[am][bi];
        return res;
    };
    // Turns a cumulative table into exact-max form (inclusion-exclusion). For the
    // root (fl) it stops there; otherwise it picks the vertex's own exponent
    // pm >= the children's max and returns a cumulative table again.
    auto addValue=[&](dpType a,bool fl){
        for(int am=N-1;am>0;--am)
            for(int ai=0;ai<N;++ai)
                a[am][ai]-=a[am-1][ai];
        if(fl)
            return a;
        dpType res{};
        for(int pm=0;pm<N;++pm)
            for(int ai=0;ai+pm<N;++ai)
                for(int am=0;am<=pm&&am<=ai;++am)
                    res[pm][ai+pm]+=a[am][ai];
        for(int ai=0;ai<N;++ai)
            for(int am=1;am<N;++am)
                res[am][ai]+=res[am-1][ai];
        return res;
    };
    auto dpWithMax=y_combinator([&](const auto &self,const int u)->dpType{
        dpType cur{};
        for(int i=0;i<N;++i)
            cur[i][0]=1;                 // no children yet: max <= i, sum 0
        for(auto x:e[u])
            cur=merge(cur,self(x));
        return addValue(cur,u==0);
    });
    const dpType root=dpWithMax(0);
    // ways[at-2] = F_at: the root's exponent am equals the children's maximum.
    vm ways(N-2);
    for(int at=2;at<N;++at)
        for(int am=0;2*am<=at;++am)
            ways[at-2]+=root[am][at-am];
    vi primes;
    primesUptoSqrt(k,primes);
    mint ans=k;                          // d = 1 contributes floor(K / 1)
    // DFS over powerful numbers d: val = floor(K / d), fac = product of F_e over p^e || d.
    y_combinator([&](const auto &self,lli val,size_t pidx,mint fac)->void{
        for(size_t j=pidx;j<primes.size()&&primes[j]*primes[j]<=val;++j){
            const lli p=primes[j];
            lli rem=val/p/p;
            for(int i=2;i<N&&rem>0;++i,rem/=p){
                const mint w=fac*ways[i-2];
                ans+=w*rem;
                self(rem,j+1,w);
            }
        }
    })(k,(size_t)0,mint(1));
    cout<<ans.val()<<'\n';
    return 0;
}