CNTSEQ - Editorial

Cnt Number of Sequences:

PROBLEM LINK:

Practice
Division 1
Division 2

Author: Nitin Gupta

Testers: Lavish Gupta, Takuki Kurokawa, Nishant

DIFFICULTY:

MEDIUM

PREREQUISITES:

Centroid Decomposition and Combinatorics

PROBLEM:

f(x) is the largest divisor of x that is not equal to x itself; by definition, f(1)=1.

You are given a tree T containing N nodes, and an array B of size N, where B[i] is the value of node i (1 \leq i \leq N). You have to answer Q queries of the form

X R

where X denotes a node and R a distance. For each query, find the number of distinct non-empty sequences of nodes that satisfy the following three restrictions.

Let the sequence of nodes be an array A (1-based indexing) of length l.

1. A[1] = X

2. For every i \gt 1, the distance between node A[i] and node A[1] is at most R.
The distance between two nodes is the number of edges on the simple path between them.

3. For every i \gt 1, at least one of the following must hold:

  • B[A[i]]=f(B[A[i-1]])
  • B[A[i]]=B[A[i-1]]

All nodes in the sequence must be distinct.

QUICK EXPLANATION:

For every value that can appear in a valid sequence, we need the number of nodes with that value within distance R of X. Centroid decomposition answers such radius-limited frequency queries efficiently.

EXPLANATION:

Consider a query (X, R), where X is a node.
First, let us count the sequences assuming we already know the required frequencies from the tree.
The first element of the sequence is fixed to the given node X. We explain the counting with an example: let B[X]=18, so f(18)=9, f(9)=3 and f(3)=1.
Let a, b, c and d be the numbers of nodes with values 18, 9, 3 and 1 respectively, among the nodes at distance at most R from X (a \geq 1, because X itself is counted).

There are four types of sequences. Note that we count sequences of nodes, not of values; values are shown below only because the restrictions are stated on node values:
18…18,18
18…18,18, 9,9,9…9
18…18,18, 9,9,9…9,3,3,…3
18…18,18, 9,9,9…9,3,3,…3,1,1…1

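Let P(n)=\binom{n}{1}\cdot 1!+\binom{n}{2}\cdot 2!+\binom{n}{3}\cdot 3!+\dots+\binom{n}{n}\cdot n! = \sum_{k=1}^{n}\binom{n}{k}\,k!. This is the number of non-empty ordered selections of distinct elements from n elements, with P(0)=0.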

  • The number of sequences of the first type (only nodes with value 18) is P(a-1)+1. The first node is fixed to X, and the sequence may continue with any ordered selection of the remaining a-1 nodes, including the empty one.
  • The number of sequences of the second type is (P(a-1)+1)\cdot P(b).
  • The number of sequences of the third type is (P(a-1)+1)\cdot P(b)\cdot P(c).
  • The number of sequences of the fourth type is (P(a-1)+1)\cdot P(b)\cdot P(c)\cdot P(d).
    So the total number of sequences is (P(a-1)+1)+(P(a-1)+1)P(b)+(P(a-1)+1)P(b)P(c)+(P(a-1)+1)P(b)P(c)P(d) = (P(a-1)+1)\left(1+P(b)+P(b)P(c)+P(b)P(c)P(d)\right).

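P(n) satisfies the recurrence P(n)=n\cdot(1+P(n-1)), with P(0)=0. To see this, choose the first node in n ways, then either stop or continue with a non-empty ordered selection of the remaining n-1 nodes.
The same reasoning applies to any starting value: walk the chain B[X], f(B[X]), f(f(B[X])), \dots down to 1, and multiply in one factor P(\cdot) per chain value.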

Now we need the frequencies of the values that occur in a particular chain; this part is done with centroid decomposition.
Suppose we have built the centroid tree and want the number of nodes with value V at distance at most R from a given node X.
For every node of the centroid tree, we store, for each value, the sorted list of distances (in the original tree) from that centroid to the nodes of its centroid subtree that have this value. Binary searching the list for value V at node X gives the number of such nodes within distance R, but only among the nodes of X's own centroid subtree.

Let Par(node) denote the parent of a node in the centroid tree. We move to Par(X). Let D be the distance between X and Par(X) in the original tree; more generally, at every ancestor A of X in the centroid tree, D is the original-tree distance between X and A. We then count the nodes with value V at distance at most R-D from Par(X).

This overcounts the nodes that lie in the centroid subtree we just came from, because they were already accounted for at the previous step. To fix this, we also store, for every centroid-tree node, the distances from Par(node) to the nodes of its own centroid subtree, and we subtract the number of those within R-D. We repeat this process until we reach the root of the centroid tree.

Overall, all queries are answered in O(Q \cdot \log(\max B[i]) \cdot \log^2 N).

SOLUTIONS:

Setter's Solution
     #include <bits/stdc++.h>
    
    #define int long long
    #define endl           "\n"
    #define mod            1000000007
    #define all(x)         x.begin(), x.end()
    #define s              second
    #define sz(x)          (int)(x).size()
    #define nitin          ios_base::sync_with_stdio(false); cin.tie(nullptr)
    using namespace std;
    
    
    // Array capacity: node ids are 1-based and n <= 100000 (asserted in solve()).
    const int N = 100001;
    // Number of distance layers: dist[] is indexed by centroid-tree level (root centroid = level 1).
    const int M = 22;
    // Exclusive bound on node values for the smallest-prime-factor table.
    const int MV=100001;
    // del[e] is set once decompose() has cut the tree at edge e.
    bool del[N];
    // eg[x] holds the indices of the edges incident to node x; edge e joins u[e] and v[e].
    vector<int> eg[N];
    // B[i] is the value of node i.
    int B[N];
    int u[N],v[N];
    // dist[lvl][x] = original-tree distance from x to its centroid ancestor at level lvl.
    int dist[M][N];
    // sz: subtree sizes during one decomposition step; c_par: centroid-tree parent (0 = none);
    // level: depth in the centroid tree (level[0] stays 0, so the first centroid gets level 1).
    int sz[N], c_par[N], level[N];
    // d[c][val]: sorted distances from centroid c to every node with value val in c's centroid subtree.
    map<int,vector<int>>d[N];
    // cp_d[c][val]: the same nodes of c's centroid subtree, but with distances measured from
    // c_par[c]; subtracted in qry() to avoid double counting when climbing to the parent.
    map<int,vector<int>>cp_d[N];
    
    // ncr[n] = P(n) = sum_{k=1..n} n!/(n-k)! (mod 1e9+7), filled by pre(); not a binomial despite the name.
    int ncr[N];
    
    int prime[MV];
    // Fills prime[x] with the smallest prime factor of x (prime[1] = 1), so that
    // f(x) = x / prime[x], and tabulates ncr[n] = P(n) using P(n) = n * (1 + P(n-1)).
    void pre() {
        for (int x = 1; x < MV; ++x) {
            prime[x] = x;
        }
        for (int step = 2; step < MV; ++step) {
            for (int mult = step; mult < MV; mult += step) {
                if (step < prime[mult]) prime[mult] = step;
            }
        }
        ncr[0] = 0;
        for (int k = 1; k < N; ++k) {
            ncr[k] = k * (ncr[k - 1] + 1) % mod;
        }
    }
    // Node count of the component currently being decomposed (written by sz_cal, read by cent).
    int nodes = 0;
    // Returns the endpoint of edge e that is not x.
    int adj(int e,int x){
        return (u[e] == x) ? v[e] : u[e];
    }
    void sz_cal(int node,int par=-1){
        nodes++;
        sz[node]=1;
        for(auto &e:eg[node]){
            int c=adj(e,node);
            if(del[e] || c==par) continue;
            sz_cal(c,node);
            sz[node]+=sz[c];
        }
    }
    int cent(int node,int par=-1){
        for(auto &e:eg[node]){
            int c=adj(e,node);
            if(del[e] || c==par) continue;
            if(sz[c]>(nodes)/2){
                return cent(c,node);
            }
        }
        return node;
    }
    void dfs_dist(int node,int par,int lvl,int d){
        dist[lvl][node]=d;
        for(auto &e:eg[node]){
            int c=adj(e,node);
            if(del[e] || c==par) continue;
            dfs_dist(c,node,lvl,d+1);
        }
    }
    // Recursively builds the centroid tree of the component containing `vert`
    // (components are delimited by edges with del[e] == true).
    // cent_par: centroid-tree parent of the centroid found here (0 = none, i.e. the root).
    void decompose(int vert, int cent_par = 0) {
        // Measure the component: sz_cal() counts its nodes into `nodes` and fills sz[].
        nodes = 0;
        sz_cal(vert);
        int child = cent(vert);
        c_par[child] = cent_par;
        level[child] = level[cent_par] + 1;
        // Distances from this centroid to every node of its component, stored at its level.
        dfs_dist(child,-1,level[child],0);
        for (auto e:eg[child]) {
            int c=adj(e,child);
            if(del[e]) continue;
            // Cut the edge to the centroid before recursing, so the neighbour's component
            // cannot reach back through the centroid.
            del[e]=true;
            decompose(c, child);
        }
    }
    // Number of entries of the sorted vector t that are <= l.
    int how_many(vector<int>&t,int l){
        const auto past = upper_bound(t.begin(), t.end(), l);
        return past - t.begin();
    }
    // Counts the nodes with value `val` at original-tree distance <= r from `node`.
    // Climbs the centroid tree. At each ancestor A (at distance D from the query node),
    // it adds the nodes of A's centroid subtree within r - D of A. It then subtracts those
    // inside the child subtree it came from (cp_d), because they were already counted below.
    // Lookups use find() rather than operator[], so values absent from a centroid's subtree
    // do not insert empty vectors into d / cp_d on every query.
    int qry(int node,int r,int val){
        // Entries <= limit stored under `val` in table, or 0 if `val` never occurs there.
        auto count_in = [val](map<int,vector<int>> &table, int limit) -> int {
            auto it = table.find(val);
            if (it == table.end()) return 0;
            return how_many(it->second, limit);
        };
        const int org_node = node;
        int tot = count_in(d[node], r);
        while (c_par[node] != 0) {
            const int child = node;
            node = c_par[node];
            // A negative budget simply yields 0 from both lookups.
            const int distance = r - dist[level[node]][org_node];
            tot += count_in(d[node], distance) - count_in(cp_d[child], distance);
        }
        return tot;
    }
    // Reads the tree and the queries, builds the per-centroid distance tables,
    // and prints the answer for each query (X, R).
    void solve() {
        int n;
        cin>>n;
        assert(n>=1 && n<=100000);
        for(int i=0;i<n;i++){
            cin>>B[i+1];
            assert(B[i+1]>=1 && B[i+1]<=100000);
        }
        // Edge i joins u[i] and v[i]; adjacency lists store edge indices.
        for(int i=0;i<n-1;i++){
            cin>>u[i]>>v[i];
            assert(u[i]>=1 && u[i]<=n);
            assert(v[i]>=1 && v[i]<=n);
            eg[u[i]].push_back(i);
            eg[v[i]].push_back(i);
        }
        decompose(1);
        // For every node i and each of its centroid ancestors `node` (starting with i itself):
        // d[node] gets dist(node, i), and cp_d[node] gets dist(c_par[node], i).
        for(int i=1;i<=n;i++){
            int node=i;
            while(node!=0){
                d[node][B[i]].push_back(dist[level[node]][i]);
                int par=c_par[node];
                if(par!=0){
                    cp_d[node][B[i]].push_back(dist[level[par]][i]);
                }
                node=par;
            }
        }
        // Sort, so that qry() can count distances <= r with upper_bound.
        for(int i=1;i<=n;i++){
            for(auto &c:d[i]){
                sort(all(c.second));
            }
            for(auto &c:cp_d[i]){
                sort(all(c.second));
            }
        }
        int q;
        cin>>q;
        assert(q<=10000 && q>=1);
        while(q--){
            int node,r;
            cin>>node>>r;
            assert(node>=1 && node<=n);
            assert(r>=1 && r<=n);
            // Value chain B[node], f(B[node]), ..., 1, where f(x) = x / prime[x].
            int val=B[node];
            vector<int>seq;
            while(val!=1){
                seq.push_back(val);
                val/=prime[val];
            }
            seq.push_back(1);
            // freq[i] = number of nodes within distance r of `node` that have value seq[i].
            vector<int>freq;
            for(auto &c:seq){
                freq.push_back(qry(node,r,c));
            }
            // The query node itself is A[1]. The first value contributes P(freq[0]-1) ordered
            // picks of its other nodes, plus the empty pick. Each later value multiplies the
            // running product by P(freq[i]), and every prefix product is summed into ans.
            int ans=ncr[freq[0]-1]+1;
            int tot=ans;
            for(int i=1;i<freq.size();i++){
                tot*=ncr[freq[i]];
                tot%=mod;
                ans+=tot;
                ans%=mod;
            }
            cout<<ans<<endl;
        }
    
    }
    
    // Entry point: fast unsynchronised I/O, table precomputation, then the single test case.
    int32_t main() {
        ios_base::sync_with_stdio(false);
        cin.tie(nullptr);
        pre();
        solve();
        return 0;
    }
Tester's Solution (tabr)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

// Arithmetic modulo the compile-time constant `mod`; `value` is always kept in [0, mod).
// NOTE(review): operator*= multiplies two values in long long, so this relies on
// (mod - 1)^2 fitting in long long (true for the 1e9+7 used below).
template <long long mod>
struct modular {
    long long value;
    // Reduces any signed x into the canonical range [0, mod).
    modular(long long x = 0) {
        value = x % mod;
        if (value < 0) value += mod;
    }
    modular& operator+=(const modular& other) {
        if ((value += other.value) >= mod) value -= mod;
        return *this;
    }
    modular& operator-=(const modular& other) {
        if ((value -= other.value) < 0) value += mod;
        return *this;
    }
    modular& operator*=(const modular& other) {
        value = value * other.value % mod;
        return *this;
    }
    // Multiplies by the inverse of `other`, computed with the extended Euclidean algorithm
    // on (other.value, mod); the inverse exists only when gcd(other.value, mod) == 1.
    // If other.value == 0, the loop never runs and the result becomes 0.
    modular& operator/=(const modular& other) {
        long long a = 0, b = 1, c = other.value, m = mod;
        while (c != 0) {
            long long t = m / c;
            m -= t * c;
            swap(c, m);
            a -= t * b;
            swap(a, b);
        }
        a %= mod;
        if (a < 0) a += mod;
        value = value * a % mod;
        return *this;
    }
    friend modular operator+(const modular& lhs, const modular& rhs) { return modular(lhs) += rhs; }
    friend modular operator-(const modular& lhs, const modular& rhs) { return modular(lhs) -= rhs; }
    friend modular operator*(const modular& lhs, const modular& rhs) { return modular(lhs) *= rhs; }
    friend modular operator/(const modular& lhs, const modular& rhs) { return modular(lhs) /= rhs; }
    modular& operator++() { return *this += 1; }
    modular& operator--() { return *this -= 1; }
    modular operator++(int) {
        modular res(*this);
        *this += 1;
        return res;
    }
    modular operator--(int) {
        modular res(*this);
        *this -= 1;
        return res;
    }
    modular operator-() const { return modular(-value); }
    bool operator==(const modular& rhs) const { return value == rhs.value; }
    bool operator!=(const modular& rhs) const { return value != rhs.value; }
    // Orders by canonical representative (allows use as an ordered key).
    bool operator<(const modular& rhs) const { return value < rhs.value; }
};
// Decimal representation of a modular value (used by the debug helpers).
template <long long mod>
string to_string(const modular<mod>& x) {
    const long long v = x.value;
    return to_string(v);
}
// Writes the canonical representative in [0, mod).
template <long long mod>
ostream& operator<<(ostream& stream, const modular<mod>& x) {
    stream << x.value;
    return stream;
}
// Reads a signed integer straight into x and reduces it into [0, mod).
template <long long mod>
istream& operator>>(istream& stream, modular<mod>& x) {
    stream >> x.value;
    long long& v = x.value;
    v %= mod;
    if (v < 0) {
        v += mod;
    }
    return stream;
}

constexpr long long mod = (long long) 1e9 + 7;
using mint = modular<mod>;

// Binary exponentiation: returns a^n modulo `mod` (1 when n <= 0).
mint power(mint a, long long n) {
    mint result = 1;
    for (mint base = a; n > 0; n >>= 1) {
        if (n & 1) result *= base;
        base *= base;
    }
    return result;
}

// Lazily grown tables: fact[i] = i! and finv[i] = 1 / i! (mod `mod`).
vector<mint> fact(1, 1);
vector<mint> finv(1, 1);

// Binomial coefficient C(n, k) mod `mod`; 0 when k < 0 or k > n.
// Side effect: extends fact/finv up to index n when needed.
mint C(int n, int k) {
    if (k < 0 || k > n) return mint(0);
    for (int i = (int) fact.size(); i <= n; i++) {
        fact.emplace_back(fact.back() * i);
        finv.emplace_back(mint(1) / fact.back());
    }
    return fact[n] * finv[k] * finv[n - k];
}

// Offline solution. Queries are pushed down the centroid decomposition (components are
// processed FIFO through `que`). At each centroid they are counted with a sweep over the
// component's nodes sorted by depth.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    const int m = (int) 1e5 + 10;
    // f[x] = largest proper divisor of x (f[1] = 1). For each prime i (f[i] still 1)
    // and cofactor j, j divides i*j; the maximum over primes i is x / (smallest prime factor).
    vector<int> f(m, 1);
    for (int i = 2; i < m; i++) {
        if (f[i] != 1) {
            continue;
        }
        for (int j = 2; i * j < m; j++) {
            f[i * j] = max(f[i * j], j);
        }
    }
    // ways[i] = i! * sum_{k=0..i} 1/k! - 1 = sum_{k=1..i} i!/(i-k)! = P(i); ways[0] = 0.
    vector<mint> ways(m);
    mint inv_facts = 1;
    C(m, 0);  // forces fact/finv to be tabulated up to index m
    for (int i = 1; i < m; i++) {
        inv_facts += finv[i];
        ways[i] = fact[i] * inv_facts - 1;
    }
    debug(ways[0]);
    debug(ways[1]);
    debug(ways[2]);
    debug(ways[3]);
    debug(ways[4]);
    debug(ways[5]);
    int n;
    cin >> n;
    vector<int> b(n);
    for (int i = 0; i < n; i++) {
        cin >> b[i];
    }
    vector<vector<int>> g(n);
    for (int i = 0; i < n - 1; i++) {
        int x, y;
        cin >> x >> y;
        x--, y--;
        g[x].emplace_back(y);
        g[y].emplace_back(x);
    }
    int q;
    cin >> q;
    vector<int> x(q), r(q);
    // cnts[i][v] accumulates the number of nodes with value v within distance r[i] of x[i].
    vector<map<int, int>> cnts(q);
    for (int i = 0; i < q; i++) {
        cin >> x[i] >> r[i];
        x[i]--;
    }
    vector<int> used(n);       // 1 once the node has been used as a centroid
    vector<int> child(n);      // child[v]: neighbour of the current centroid on v's side
    vector<int> sz(n);
    vector<int> depth(n);      // distance from the current centroid
    vector<int> pv(n, -1);     // DFS parent inside the current component
    vector<map<int, int>> mp_solo(n);  // mp_solo[ch][val]: swept nodes with value val on side ch
    queue<int> que;
    que.emplace(0);
    // who[root]: queries whose node lies in the component that is entered from `root`.
    vector<vector<int>> who(n);
    for (int i = 0; i < q; i++) {
        who[0].emplace_back(i);
    }
    while (!que.empty()) {
        int root = que.front();
        que.pop();
        int c = -1;
        vector<int> vs;
        // Collects the component's nodes into vs and fills subtree sizes rooted at `root`.
        function<void(int)> find_centroid = [&](int v) -> void {
            vs.emplace_back(v);
            sz[v] = 1;
            for (int to : g[v]) {
                if (to == pv[v] || used[to]) {
                    continue;
                }
                pv[to] = v;
                find_centroid(to);
                sz[v] += sz[to];
            }
        };
        pv[root] = -1;
        find_centroid(root);
        // A centroid has at most half of the component above it and in every child subtree.
        for (int v : vs) {
            if (sz[root] - sz[v] <= sz[root] / 2) {
                int mx = 0;
                for (int to : g[v]) {
                    if (to == pv[v] || used[to]) {
                        continue;
                    }
                    mx = max(mx, sz[to]);
                }
                if (mx <= sz[root] / 2) {
                    c = v;
                }
            }
        }
        assert(c != -1);
        // Depth from the centroid, and which neighbour of c each node hangs under.
        function<void(int, int)> set_child = [&](int v, int p) {
            for (int to : g[v]) {
                if (to == p || used[to]) {
                    continue;
                }
                depth[to] = depth[v] + 1;
                child[to] = (v == c ? to : child[v]);
                set_child(to, v);
            }
        };
        depth[c] = 0;
        child[c] = -1;
        set_child(c, -1);
        auto wr = who[root];
        who[root].clear();
        // Queries by remaining budget r - depth[x]; nodes by depth from the centroid.
        sort(wr.begin(), wr.end(), [&](int i, int j) { return r[i] - depth[x[i]] < r[j] - depth[x[j]]; });
        sort(vs.begin(), vs.end(), [&](int i, int j) { return depth[i] < depth[j]; });
        int id = 0;
        map<int, int> mp_all;
        for (int v : vs) {
            mp_solo[v].clear();
        }
        debug(c, wr);
        for (int v : vs) {
            debug(v, depth[v]);
            // Before adding v, finalize every query whose budget is < depth[v]. At this
            // point, mp_all / mp_solo hold exactly the nodes with depth <= budget.
            while (id < (int) wr.size() && r[wr[id]] - depth[x[wr[id]]] < depth[v]) {
                int cc = b[x[wr[id]]];
                debug(c, v, wr[id], r[wr[id]] - depth[x[wr[id]]], depth[v]);
                while (true) {
                    cnts[wr[id]][cc] += mp_all[cc];
                    if (x[wr[id]] != c) {
                        // Nodes on x's own side are counted after the query moves down
                        // into that side's component (queued once, on the last value 1).
                        cnts[wr[id]][cc] -= mp_solo[child[x[wr[id]]]][cc];
                        if (cc == 1) {
                            who[child[x[wr[id]]]].emplace_back(wr[id]);
                        }
                    }
                    if (cc == 1) {
                        break;
                    }
                    cc = f[cc];
                }
                id++;
            }
            mp_all[b[v]]++;
            if (v != c) {
                mp_solo[child[v]][b[v]]++;
            }
        }
        // Remaining queries have a budget >= every depth in the component, so the whole
        // component is within distance r of x. It is counted here in full, and the query
        // is not pushed further down.
        while (id < (int) wr.size()) {
            int cc = b[x[wr[id]]];
            debug(c, wr[id]);
            while (true) {
                cnts[wr[id]][cc] += mp_all[cc];
                if (cc == 1) {
                    break;
                }
                cc = f[cc];
            }
            id++;
        }
        used[c] = 1;
        for (int to : g[c]) {
            if (!used[to]) {
                que.emplace(to);
            }
        }
    }
    for (int i = 0; i < q; i++) {
        // Counts along the chain b[x], f(b[x]), ..., 1. Stopping at a zero count is safe,
        // because ways[0] == 0 would zero out every later product anyway.
        int c = b[x[i]];
        vector<int> a;
        while (true) {
            int cnt = cnts[i][c];
            if (cnt == 0) {
                break;
            }
            a.emplace_back(cnt);
            if (c == 1) {
                break;
            }
            c = f[c];
        }
        debug(a);
        // x itself is the fixed first element; the rest of its value may be empty (+1).
        a[0]--;
        mint ans = 1 + ways[a[0]];
        mint now = 1 + ways[a[0]];
        for (int j = 1; j < (int) a.size(); j++) {
            now *= ways[a[j]];
            ans += now;
        }
        cout << ans << '\n';
    }
    return 0;
}
Tester's Solution (lavish315)

#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define pii pair<int , int>
 
 
/*
------------------------Input Checker----------------------------------
*/
 
// Strict integer reader for the input checker. Reads characters until `endd`, which is
// consumed, and returns the parsed value. Aborts via assert on any other character
// (including EOF), on a leading zero in a multi-digit number, on "-0", on more than
// 19 digits (or 19 digits starting with a digit > 1), or when the value lies outside [l, r].
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            // A sign is only allowed before the first digit.
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            // A leading 0 must be the only digit and must not be negated.
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
 
            // Overflow guard for long long.
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }

            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }

            return x;
        } else {
            assert(false);
        }
    }
}
// Reads characters up to the delimiter `endd` (consumed, not stored). Asserts that EOF
// does not come first and that the token length lies in [l, r].
string readString(int l,int r,char endd){
    string ret;
    while (true) {
        const char g = getchar();
        assert(g != -1);
        if (g == endd) break;
        ret.push_back(g);
    }
    const int len = (int) ret.size();
    assert(l <= len && len <= r);
    return ret;
}
// Delimiter-fixing wrappers: "Sp" expects a trailing space, "Ln" a trailing newline.
long long readIntSp(long long l,long long r){ return readInt(l, r, ' '); }
long long readIntLn(long long l,long long r){ return readInt(l, r, '\n'); }
string readStringLn(int l,int r){ return readString(l, r, '\n'); }
string readStringSp(int l,int r){ return readString(l, r, ' '); }
 
 
/*
------------------------Main code starts here----------------------------------
*/
 
// Input limits enforced by the checker.
const int MAX_T = 10;
const int MAX_N = 100000;
const int MAX_Q = 10000;
const int MAX_val = 100000;
const int MAX_SUM_N = 100000;
 
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
 
// Checker-template counters; in this file they are referenced only by commented-out code in main().
int sum_n = 0;
int max_n = 0;
int yess = 0;
int nos = 0;
int total_ops = 0;
// Answer modulus.
ll z = 1000000007;
ll sum_nk = 0 ;

// cent[c]: centroid-tree neighbours of c (children and parent) with an edge weight.
vector<pii> cent[MAX_N] ;
// adj: original tree, 0-based node ids.
vector<vector<int> > adj(MAX_N , vector<int>()) ;
// par[v]: centroid-tree parent of v (-1 for the root centroid); par_dist[v]: the weight stored with it.
vector<int> par(MAX_N , -1) ;
vector<int> par_dist(MAX_N) ;
// val[u]: value of node u.
vector<int> val(MAX_N) ;
// queries[i] = {x (0-based), r}.
vector<pii> queries ;
// p[n] = P(n), f[x] = smallest prime factor of x (both filled by initialize()).
vector<ll> p(MAX_N+5) ;
vector<ll> f(MAX_val+5) ;
// val_dist[c][v]: sorted distances from centroid c to the nodes of its centroid subtree with value v.
map<int , vector<int> > val_dist[MAX_N] ;
// child_val_dist[c][{ch, v}]: same, restricted to the subtree of c's centroid child ch (distances from c).
map<pii , vector<int> > child_val_dist[MAX_N] ;
// orig_dist[c][u]: original-tree distance from centroid c to u, for u in c's component.
map<int , int> orig_dist[MAX_N] ;

// Diagnostic statistics reported on stderr at the end of solve().
map<int , int> f_used ;
ll tot_chain_length = 0 ;
ll max_chain_length = 0 ;
map<int , int> nodes_query ;
ll tot_chain_sum = 0 ;
ll max_chain_sum = 0 ;
ll upto_1 = 0 ;
map<int , int> overall_freq ;
ll tot_chain_sum_left = 0 ;
ll max_chain_sum_left = 0 ;


// Finds the centroid c of the component containing `root` (nodes with dead[] set are
// excluded), and records orig_dist[c][u] for every node u of that component.
// Returns {c, d} where d is always 0 in this implementation.
pii OneCentroid(int root, const vector<vector<int>> &g, const vector<bool> &dead) {
        // Sized from g on the first call only (function-local static).
        static vector<int> sz(g.size());
        function<void (int, int)> get_sz = [&](int u, int prev) {
                sz[u] = 1;
                for (auto v : g[u]) if (v != prev && !dead[v]) {
                        get_sz(v, u);
                        sz[u] += sz[v];
                }
        };
        get_sz(root, -1);
        int n = sz[root];
        // Descend into any child holding more than half of the component.
        function<int (int, int)> dfs = [&](int u, int prev) {
                for (auto v : g[u]) if (v != prev && !dead[v]) {
                        if (sz[v] > n / 2) {
                                return dfs(v, u);
                        }
                }
                return u;
        };

        // Here `root` (shadowing the outer parameter) is the centroid whose distances are recorded.
        function<void (int , int , int , int)> get_dist = [&](int u , int prev , int root , int dist)
        {
            orig_dist[root][u] = dist ;
            for(auto v : g[u])
            {
                if(v != prev && !dead[v])
                    get_dist(v , u , root , dist+1) ;
            }
        };
        // NOTE(review): d is never updated, so callers always receive 0 as the second element.
        int d = 0 ;
        int c = dfs(root, -1);
        get_dist(c , -1 , c , 0) ;
        return {c , d} ;
}

// Builds the centroid tree of g starting from node 0. It fills cent[] (edges in both
// directions), par[] and par_dist[], plus orig_dist[] via OneCentroid.
void CentroidDecomposition(const vector<vector<int>> &g) {
        int n = (int) g.size();
        vector<bool> dead(n, false);
        // Decomposes the component containing `start`; returns OneCentroid's {centroid, 0}.
        function<pii (int)> rec = [&](int start) {
                pii cd = OneCentroid(start, g, dead);
                int c = cd.first ;
                dead[c] = true;
                for (auto u : g[c]) if (!dead[u]) {
                        pii g = rec(u);
                        // Weight = returned second element + 1, which is always 1 here.
                        int v = g.first , w = g.second+1 ;
                        cent[c].push_back({v , w}) ;
                        cent[v].push_back({c , w}) ;
                        par[v] = c ;
                        par_dist[v] = w ;
                }
                // Cleared again on return; the caller's own centroid stays dead, which
                // keeps the caller's remaining sibling components separated.
                dead[c] = false;
                return cd ;
        };
        rec(0);
}

// Precomputes p[i] = P(i) = i * (P(i-1) + 1) mod z, and f[x] = smallest prime factor
// of x (f[1] = 1), so that the next value in a chain is x / f[x].
void initialize()
{
    p[0] = 0;
    for (ll i = 1; i < MAX_N + 5; ++i)
        p[i] = i * (p[i - 1] + 1) % z;

    f[1] = 1;
    for (int cand = 2; cand < MAX_val + 5; ++cand)
    {
        if (f[cand] != 0)
            continue;   // already has a smaller prime factor, so not prime
        for (int j = cand; j < MAX_val + 5; j += cand)
            if (f[j] == 0)
                f[j] = cand;
    }
}

void dfs_cnt(int u , int p , int child)
{
    assert(u >= 0 && u < MAX_N) ;
    int curr_val = val[u] ;
    val_dist[p][curr_val].push_back(orig_dist[p][u]) ;

    if(u != p)
        child_val_dist[p][{child , curr_val}].push_back(orig_dist[p][u]) ;

    for(int i = 0 ; i < cent[u].size() ; i++)
    {
        pii g = cent[u][i] ;
        int v = g.first , w = g.second ;
        if(v != par[u])
        {
            if(u == p)
                child = v ;
            dfs_cnt(v , p , child) ;
        }
    }
    return ;
}

// Adjusts val_freq[i] by the number of nodes with value all_val[i] within distance r
// of centroid x. mult == 1 adds the counts over x's whole centroid subtree. Any other
// mult is interpreted as -child and subtracts the counts restricted to the subtree of
// x's centroid child `child` (the part already counted at a lower level).
// Uses find() instead of operator[], so that query-time lookups of absent keys do not
// grow val_dist / child_val_dist with empty vectors.
void get_curr_freq(vector<int> &all_val , vector<int> &val_freq , int x , int r , int mult)
{
    // Number of entries <= r in a sorted distance list.
    auto upto = [r](const vector<int> &dists) -> int {
        return upper_bound(dists.begin() , dists.end() , r) - dists.begin() ;
    };

    for (int i = 0 ; i < (int) all_val.size() ; i++)
    {
        const int curr_val = all_val[i] ;
        if (mult == 1)
        {
            auto it = val_dist[x].find(curr_val) ;
            if (it != val_dist[x].end())
                val_freq[i] += upto(it->second) ;
        }
        else
        {
            auto it = child_val_dist[x].find({-mult , curr_val}) ;
            if (it != child_val_dist[x].end())
                val_freq[i] -= upto(it->second) ;
            // The child's nodes are a subset of x's, so the running count stays >= 0.
            assert(val_freq[i] >= 0) ;
        }
    }
}

vector<int> get_freq(int x , int r)
{
    int curr_val = val[x] ;
    vector<int> all_val; 

    all_val.push_back(curr_val) ;
    while(curr_val > 1)
    {
        curr_val /= f[curr_val] ;
        all_val.push_back(curr_val) ;
    }

    for(int i = 0 ; i < all_val.size() ; i++)
        f_used[all_val[i]]++ ;

    if(curr_val == 1)
        upto_1++ ;

    ll k = all_val.size() ;
    tot_chain_length += k;
    max_chain_length = max(max_chain_length , k) ;


    vector<int> val_freq(k) ;

    get_curr_freq(all_val , val_freq , x , r , 1) ;

    assert(x >= 0 && x < MAX_N) ;
    int np = par[x] , nx = x ;
    
    while(np != -1)
    {
        assert(np >= 0 && np < MAX_N) ;
        int dist = orig_dist[np][x];
        // cerr << "nx = " << nx << " np = " << np << '\n' ;
        get_curr_freq(all_val , val_freq , np , r-dist , 1) ;
        get_curr_freq(all_val , val_freq , np , r-dist , -nx) ;

        nx = np ;
        assert(nx >= 0 && nx < MAX_N) ;
        np = par[nx] ;
    }
    return val_freq ;
}

int get_ans(int x , int r)
{
    
    vector<int> freq = get_freq(x , r) ;

    ll curr_freq_sum = 0 ;
    ll tot_freq_sum = 0 ;
    for(int i = 0 ; i < freq.size() ; i++)
    {
        tot_freq_sum += overall_freq[val[i]] ;
        curr_freq_sum += freq[i] ;
    }

    tot_chain_sum += curr_freq_sum ;
    max_chain_sum = max(max_chain_sum , curr_freq_sum) ;

    tot_chain_sum_left += (tot_freq_sum - curr_freq_sum) ;
    max_chain_sum_left = max(max_chain_sum_left ,tot_freq_sum - curr_freq_sum ) ;


    freq[0]-- ;

    ll ans = 0 ;
    ll curr_val = 1 ;

    for(int i = 0 ; i < freq.size() ; i++)
    {
        assert(freq[i] >= 0 && freq[i] < MAX_N) ;
        ll curr_contri = p[freq[i]] ;
        curr_contri += (i == 0) ;

        curr_val = (curr_val * curr_contri)%z ;
        ans += curr_val ;
        ans %= z ;
    }
    return ans ;

}

// Reads and validates the input with the strict readInt* checkers, builds the centroid
// decomposition and the per-centroid distance tables, answers every query, and then
// reports diagnostic statistics on stderr.
void solve()
{   
    int n = readIntLn(1 , MAX_N) ;
    // Values are space-separated; the last one is terminated by a newline.
    for(int i = 0 ; i < n-1 ; i++)
    {
        val[i] = readIntSp(1 , MAX_val) ;
        overall_freq[val[i]]++ ;
    }
    val[n-1] = readIntLn(1 , MAX_val) ;
    overall_freq[val[n-1]]++ ;

    // Edges, converted to 0-based node ids.
    for(int i = 0 ; i < n-1 ; i++)
    {
        int u , v ;
        u = readIntSp(1 , n) ;
        v = readIntLn(1 , n) ;

        u-- ; v-- ;
        adj[u].push_back(v) ;
        adj[v].push_back(u) ;
    }

    int q = readIntLn(1 , MAX_Q) ;

    for(int i = 0 ; i < q ; i++)
    {
        int x , r ;
        x = readIntSp(1 , n) ;
        r = readIntLn(1 , n) ;
        x-- ;
        queries.push_back({x , r}) ;
        nodes_query[x]++ ;
    }

    // cout << "Input Verified" << endl ;
    /**************************** Input Verified ****************************/

    CentroidDecomposition(adj) ;
    initialize() ;

    // cout << "root size: " << cent[3].size() << endl ;

    // cout << "Decomposed" << endl ;
    // for(int i = 0 ; i < n ; i++)
    //     for(int j = 0 ; j < n ; j++)
    //         cout << "dist : " << i << ' ' << j << " " << orig_dist[i][j] << endl ;

    // Every node is the centroid of exactly one component: collect its subtree's
    // distances, then sort them for upper_bound in get_curr_freq.
    for(int i = 0 ; i < n ; i++)
    {
        //cout << "i = " << i << endl ;
        dfs_cnt(i , i , 0) ;
        for(auto itr = val_dist[i].begin() ; itr != val_dist[i].end() ; itr++)
        {
            int curr_val = (*itr).first ;
            sort(val_dist[i][curr_val].begin() , val_dist[i][curr_val].end()) ;
        }
        for(auto itr = child_val_dist[i].begin() ; itr != child_val_dist[i].end() ; itr++)
        {
            pii curr_val = (*itr).first ;
            sort(child_val_dist[i][curr_val].begin() , child_val_dist[i][curr_val].end()) ;
        }
    }

    //cout << "Values stored" << endl ;

    for(int i = 0 ; i < q ; i++)
    {
        //cout << "query num = " << i << endl ;
        int x = queries[i].first , r = queries[i].second ;
        assert(x >= 0 && x < n) ;
        int ans = get_ans(x , r) ;
        cout << ans << '\n' ;
    }


    // Diagnostics only; not part of the judged output.
    cerr << "chains reaching upto 1 = " << upto_1 << endl ;
    cerr << "Total Chain length = " << tot_chain_length << endl ;
    cerr << "Max Chain length = " << max_chain_length << endl ;
    cerr << "Total Chain freq sum = " << tot_chain_sum << endl ;
    cerr << "Max Chain freq sum = " << max_chain_sum << endl ;
    cerr << "total distinc f values used = " << f_used.size() << endl ;
    cerr << "distinct nodes used = " << nodes_query.size() << endl ;
    cerr << "total favourable nodes exceeding dist = " << tot_chain_sum_left << endl ;
    cerr << "max favourable nodes exceeding dist = " << max_chain_sum_left << endl ;

    return ;

}
 
// Entry point. Unless ONLINE_JUDGE is defined, stdin/stdout/stderr are redirected to local
// files. Solves the single test case and checks that the whole input was consumed.
signed main()
{
    fast;
    #ifndef ONLINE_JUDGE
    freopen("inputf.txt" , "r" , stdin) ;
    freopen("outputf.txt" , "w" , stdout) ;
    freopen("error.txt" , "w" , stderr) ;
    #endif

    // Multi-test input (readIntLn(1, MAX_T)) is disabled for this problem.
    const int t = 1;
    for (int test = 1; test <= t; ++test)
        solve();

    // Nothing may follow the last query line.
    assert(getchar() == -1);

    cerr << "SUCCESS\n";
    cerr << "Tests : " << t << '\n';
}
 
Tester's Solution (nishant403)
#include <bits/stdc++.h>
using namespace std;

#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace __gnu_pbds;
   
#define int long long 
#define pb push_back
#define S second
#define F first
#define f(i,n) for(int i=0;i<n;i++)
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define vi vector<int>
#define pii pair<int,int>
#define all(x) x.begin(),x.end()

#define ordered_set tree<pii, null_type,less<pii>, rb_tree_tag,tree_order_statistics_node_update>

const int MOD = 1e9 + 7;

// Array bounds: node ids and values up to 1e5, query count up to 1e4 (plus slack).
const int N = 1e5 + 10;
const int Q = 1e4 + 10;

vector<int> g[N];   // tree adjacency lists (1-based node ids)
bool vis[N];        // set once a node has been used as a centroid
int n,q;
int b[N];           // node values

int s[N];           // after pre(): s[i] = P(i) = sum_{k=1..i} i!/(i-k)! mod MOD

int last_div[N]; // f(x) : largest divisor not equal to x 

vector< pair<pii,pii> > go[N]; //go[i] contains queries ; [[distance , value],[index,count]]

vector<pii> c_vals[Q]; //For each query, store {value , count} after computing go

int sz[N]; //size of the centroid subtree
// Elements are {depth, node} pairs, so equal depths stay distinct entries.
ordered_set dep[N]; //dep[i] is multiset of depths w.r.t centroid having value i 
int nn; //number of nodes in the centroid subtree
 
// Precomputes last_div[x] = f(x), the largest proper divisor of x (last_div[1] = 1).
// Also computes s[i] = P(i): first t[i] = 1 + i * t[i-1] with t[0] = 1 (so t[i] = P(i) + 1),
// then s[i] = t[i] - 1.
void pre()
{
    // d is a proper divisor of 2d, 3d, ...; the largest one seen wins.
    for (int d = 1; d < N; d++)
        for (int multiple = 2 * d; multiple < N; multiple += d)
            if (d > last_div[multiple]) last_div[multiple] = d;
    last_div[1] = 1;

    s[0] = 1;
    s[1] = 2;
    for (int i = 2; i < N; i++)
        s[i] = (i * s[i - 1] + 1) % MOD;

    // Guard against a residue of 0, which would turn into -1 after the decrement.
    for (int i = 0; i < N; i++)
    {
        assert(s[i] != 0);
    }
    for (int i = 0; i < N; i++)
        s[i]--;
}
    
 
//---------------Centroid Decomposition-------------------------
//reference : https://www.codechef.com/viewsolution/7402256
 
//To find sizes of nodes
// Fills sz[] for the DFS tree of the current component rooted at `node`,
// and adds each visited node to nn.
void dfs1(int node,int par)
{
    sz[node] = 1;
    nn++;
    for (auto nxt : g[node])
    {
        if (vis[nxt] || nxt == par) continue;
        dfs1(nxt, node);
        sz[node] += sz[nxt];
    }
}
 
//Find centroid from the sizes
// From `node`, repeatedly steps into a child whose subtree has more than nn/2 nodes,
// and returns the node where no such child exists (the centroid).
int dfs2(int node,int par)
{
    for (;;)
    {
        bool found = false;
        for (auto nxt : g[node])
        {
            if (!vis[nxt] && nxt != par && sz[nxt] > (nn / 2))
            {
                par = node;
                node = nxt;
                found = true;
                break;
            }
        }
        if (!found) return node;
    }
}
 
//Add/Remove the data of the distances  
void dfs3(int node,int par,int dist,int ad)
{
    nn++;
    
    if(ad == 1)
    {
        dep[b[node]].insert({dist,node});
    }
    else
    {
        dep[b[node]].erase({dist,node});
    }
    
    for(auto x : g[node])
        if(!vis[x] && x != par) dfs3(x,node,dist+1,ad);
}
 
//Add/Remove the answers to the nodes
// For every pending query entry stored at a node reachable from `node`, adds `ad` times
// the number of nodes in dep[value] whose depth is <= (query radius - dist).
void dfs4(int node,int par,int dist,int ad)
{
    for (auto &entry : go[node])
    {
        const int radius = entry.first.first;
        const int value = entry.first.second;
        // {limit + 1, -1} precedes every {limit + 1, node} since node ids are >= 1,
        // so order_of_key counts exactly the pairs with depth <= limit.
        entry.second.second += dep[value].order_of_key({radius + 1 - dist, -1}) * ad;
    }
    for (auto nxt : g[node])
        if (!vis[nxt] && nxt != par)
            dfs4(nxt, node, dist + 1, ad);
}
 
//Main centroid decomposition function
// Centroid decomposition. For each pending query entry at a node of this component, it
// adds the nodes of the same value whose distance through the centroid is within the
// query's radius. It then subtracts the pairs lying in the same child subtree (their real
// path avoids the centroid; they are counted deeper), and recurses into the pieces.
void decompose(int root)
{
	nn = 0;
    dfs1(root,root);
    
    int centroid = dfs2(root,root);
   
    //add for all and calculate for all
    nn = 0;
    dfs3(centroid,-1,0,1);
    dfs4(centroid,-1,0,1);
    dfs3(centroid,-1,0,-1);
    
    //now remove the repeated things for subtrees
    // Depths start at 1 here, so both endpoints' distances are still measured via the centroid.
    for(auto x : g[centroid])
    {
        if(vis[x]) continue;
        
        nn = 1;
        dfs3(x,centroid,1,1);
        dfs4(x,centroid,1,-1);
        dfs3(x,centroid,1,-1);
    }
    
    vis[centroid] = 1;
    
    //remove centroid from the tree and recurse
    for(auto x : g[centroid])
        if(!vis[x]) decompose(x);
}

//number of ways to form valid sequence and where we can select 0 elements as well
int get_ways(vector<pii> & c_vals,int ee)
{
   if(c_vals.empty()) return 1; 
    
   if(ee != c_vals.back().F && c_vals.back().S == 0) return 1;
   else if(ee == c_vals.back().F && c_vals.back().S == 0)
   {
       c_vals.pop_back();
       return get_ways(c_vals,ee);
   }
    
   int res = s[c_vals.back().S];

   if(ee == c_vals.back().F)
   {
       c_vals.pop_back();
       res = (res + 1)*get_ways(c_vals,ee);
   }
   else
   {
       c_vals.pop_back();
       res = 1 + (res*get_ways(c_vals,ee));
   }
    
   res %= MOD;
    
   return res;
}

// Reads the tree and the queries and registers one pending entry per (query, chain value)
// at the query node. Then runs the decomposition to fill in the counts and prints answers.
signed main()
{
    fast;
    
    pre();
    
    cin >> n;
    
    for(int i=1;i<=n;i++) cin >> b[i];
    
    int u,v;
    
    f(i,n-1)
    {
        cin >> u >> v;
        g[u].pb(v);
        g[v].pb(u);
    }
    
    cin >> q;
    
    // Query i is (u, v): node u, radius v. It gets one entry per chain value
    // b[u], f(b[u]), ..., 1, each with an initial count of 0.
    f(i,q)
    {
        cin >> u >> v;
        
        int val = b[u];
        
        while(val != 1)
        {
            go[u].push_back({{v,val},{i,0}});
            val = last_div[val];
        }
        
        go[u].push_back({{v,val},{i,0}});
    }
    
    decompose(1);
    
    // Gather {value, count} per query index.
    for(int i=1;i<=n;i++)
    {
        for(auto x : go[i])
        {
            c_vals[x.S.F].pb({x.F.S,x.S.S});
        }
    }
    
    for(int i=0;i<q;i++)
    {
        // Chain values are distinct and decreasing, so after sorting, the start value b[u]
        // is last; the query node itself is removed from its count.
        sort(all(c_vals[i]));
        
        c_vals[i].back().S--;
        
        cout << get_ways(c_vals[i] , c_vals[i].back().F)  << '\n';
    }
}

2 Likes

Can you explain what `map<int, vector<int>> d[N]` and `cp_d[N]` store, and how their contents are organised?