GREATSACK - Editorial

PROBLEM LINK:

Practice
Contest

Author: Onkar Ratnaparkhi
Tester: Saptarshi Shome
Editorialist: Onkar Ratnaparkhi

DIFFICULTY:

EASY-MEDIUM

PREREQUISITES:

Euler tour/in-out times in a tree, binary search, Basic number theory

PROBLEM:

You are given a tree rooted at node 1. Each node i has a value A_i.
There are queries of the form (u, v, X): determine which of the nodes u and v has more power (or report a draw if their powers are equal).
The power of a node is defined as the number of nodes in its subtree whose value divides X.

QUICK EXPLANATION:

For every value, store the sorted in-times of the nodes that hold it. To answer a query, iterate over all divisors of X and, for each divisor, count the nodes with that value inside the subtrees of u and v by binary searching their in/out times.

EXPLANATION:

First, note that we can precompute the list of divisors of every number from 1 to $X_{max}$ in $O(X_{max} \log X_{max})$ time: for every $i$, append $i$ to the divisor list of each of its multiples $i, 2i, 3i, \dots$.

Why?

The total work is $\sum_{i=1}^{n} \lfloor n/i \rfloor \le n\left(1 + \frac{1}{2} + \dots + \frac{1}{n}\right) \approx n \ln n$ (harmonic series).

Now, let’s solve the problem.

We will store the following things while doing the dfs:

  • in-out times of all the nodes.
  • the in-times corresponding to every node value, i.e. numTimes[val] is the list of in-times of all nodes whose value is val.

Note that the “value” of a node is the one given in the input, while the “number” of a node is its index, e.g. node 1 has value 124, node 2 has value 2423, etc.

Now, let’s solve the queries online: (u,v, X)
We can simply iterate over all the stored divisors of X. Let’s say we are currently at divisor d_i.
We need to count its occurrences in the subtree of u and in the subtree of v. We can do it in this way:

  • For u, we need to count how many entries of numTimes[d_i] lie in the closed range [inTime[u], outTime[u]]; each such in-time belongs to a node in u's subtree whose value is d_i.
  • numTimes[d_i] must be sorted. No explicit sort is needed: in-times are appended in DFS order during the Euler tour, so every list is already increasing.
  • Now apply binary search on numTimes[d_i] using inTime[u] and outTime[u].
    In C++ this can be done like this:
      // Counts the entries of the sorted vector v that lie in the closed range [t1, t2].
      // Here v is numTimes[d] for the current divisor d (the sorted in-times of all nodes
      // whose value is d), t1 = inTime[u] and t2 = outTime[u]; the result is therefore the
      // number of nodes in u's subtree whose value is d.
      int count(const std::vector<int> &v, int t1, int t2){
          if (t1 > t2) return 0;  // empty range: avoid a negative difference of iterators
          int cnt = std::upper_bound(v.begin(), v.end(), t2) - std::lower_bound(v.begin(), v.end(), t1);
          return cnt;
      }
    

Summing these counts over all divisors of X gives the powers of u and v; all that is left is to compare them.

Code Snippet for queries
while(q--){
    int u, v, x;
    cin >> u >> v >> x;

    // Power of a node = number of nodes in its subtree whose value divides x, i.e. the
    // sum over divisors d of x of the number of nodes with value d in that subtree.
    int power_u = 0, power_v = 0;
    for (auto d : divisors[x]) {
        power_u += count(numTimes[d], in[u], out[u]);
        power_v += count(numTimes[d], in[v], out[v]);
    }

    if (power_u > power_v)      cout << u << endl;
    else if (power_v > power_u) cout << v << endl;
    else                        cout << "Draw" << endl;
}
Time Complexity

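O(X_{max} \log X_{max} + n + q \cdot D \cdot \log n)
where X_{max} is the maximum value X in a query can take, n is the number of nodes, q is the number of queries, and D is the maximum number of divisors of any X ≤ X_{max} (at most 160 for X_{max} = 2·10^5). The first term is the divisor sieve, the second is the Euler tour, and each query performs two binary searches for every divisor of X.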

ALTERNATE EXPLANATION:

We can also solve this problem offline using DSU on tree (small-to-large merging).

Short Explanation
  • Store all the queries.
  • Store divisors of all possible X.
  • Perform DFS (following small to large technique).
  • While at node i (when the counts of all values in i's subtree are available), iterate over all divisors of each X from the queries that mention i, and store i's power for each such X.
  • Solve the queries offline, using the stored data.

SOLUTIONS:

Setter's Solution
                    //    I solemnly swear that I am upto no good //

#include <bits/stdc++.h>
using namespace std;
// Shorthand macros used throughout the setter's solution.
// `sub` redirects stdin to input.txt; it is only referenced by the commented-out `// sub;` in main.
#define sub             freopen("input.txt", "r", stdin);//freopen("output.txt", "w", stdout);
#define ll              long long
#define ull             unsigned long long
#define ld              long double
// Prints the elapsed time in seconds to stderr. The expansion reads a variable named `clk`,
// which must be in scope wherever `ttime` is used (main declares it).
#define ttime           {cerr << '\n'<<"Time (in s): " << double(clock() - clk) * 1.0 / CLOCKS_PER_SEC << '\n';}
// Template headers with long long defaults, used by the stream helpers below.
#define helpUs          template<typename T = ll , typename U = ll>
#define helpMe          template<typename T = ll>
#define pb              push_back
#define sz(x)           (int)((x).size())
// Disables C-stdio synchronisation and unties cin from cout for faster I/O.
#define fast            ios_base::sync_with_stdio(false);cin.tie(0);
#define all(x)          (x).begin(),(x).end()
#define rep(i,a,b)      for(ll i=a;i<b;i++)
#define pr(x)           cout << #x " = " << (x) << "\n"
#define mp              make_pair
#define ff              first
#define ss              second
#define YY              cout<<"Yes"<<endl
#define NN              cout<<"No"<<endl
#define ppc             __builtin_popcount
#define ppcll           __builtin_popcountll

// #include <ext/pb_ds/assoc_container.hpp>
// #include <ext/pb_ds/tree_policy.hpp>
// using namespace __gnu_pbds;
// #define ordered_set tree<int, null_type,less<int>, rb_tree_tag,tree_order_statistics_node_update>

// order_of_key and find_by_order

// Template constants; none of them is referenced by the setter's code shown here.
const long long INF=1e18;
const long long N=200005;
const long long mod=1000000007;                    // 998244353, 2971215073, 1000050131, 433494437

// From here on `endl` is a plain newline (no flush) and every `int` is `long long`;
// this is why the entry point is declared as `signed main`.
#define endl "\n"
#define int ll

typedef pair<ll,ll> pairll;
typedef map<ll,ll>  mapll;
typedef map<char,ll> mapch;
typedef vector<ll> vll;
// 64-bit Mersenne Twister seeded from the steady clock (not used by the setter's code shown here).
mt19937_64 rng(std::chrono::steady_clock::now().time_since_epoch().count());
// Descending comparator: comp<T,U>()(a, b) is true when a > b. Its call operator is const
// so that standard containers and algorithms can invoke it through a const object.
helpUs class comp{public:bool operator()(T a, U b) const {return a>b;}};
// Reads a pair as "first second" from the given stream.
helpUs istream& operator>>(istream& aa, pair<T,U> &p){aa>>p.first>>p.second;return aa;}
// Writes the elements of v, each followed by a space. Takes const& so const vectors and temporaries print too.
helpMe ostream& operator<<(ostream& ja, const vector<T> &v){for(const auto &it:v)ja<<it<<" ";return ja;}
// Fills every element of v from the given stream `aa` (not from cin), so any istream works.
helpMe istream& operator>>(istream& aa, vector<T> &v){for(auto &it:v)aa>>it;return aa;}
// Writes a pair as "first second".
helpUs ostream& operator<<(ostream& ja, const pair<T,U> &p){ja<<p.first<<" "<<p.second;return ja;}

// File-scope variables; the setter's code shown here does not read them (Solution declares its own n, k, tt).
ll n;
ll k;
ll tt = 1;

// One stored (offline) query: compare the powers of nodes u and v with respect to x.
struct Q{
    Q() {}
    Q(int a, int b, int c) : u(a), v(b), x(c) {}
    int u, v, x;
};
// Offline solution using DSU on tree (small-to-large): while `cnt` holds the value counts
// of a node's whole subtree, answer every X that a query asks about that node.
struct Solution{
    ll n,k,tt=1;
    vector<int> value;           // value[node], 1-indexed
    vector<int> sz;              // sz[node] = subtree size, filled by dfs_size
    vector<int> cnt;             // cnt[val] = number of tracked nodes whose value is val
    map<pairll, int> ans;        // ans[{node, X}] = power of node with respect to X
    vector<vll> correspondingX;  // correspondingX[node] = X of every query that mentions node

    vector<vll> adj;
    vector<vll> divisors;        // divisors[x] = all divisors of x, in increasing order
    vector<Q> query;

    // Allocates the divisor table (indices 0..200004).
    void init(){
        divisors = vector<vll> (200005);
    }

    // Divisor sieve for 1..200000: append i to every multiple of i (harmonic-sum total work).
    void pre(){
        for(int i=1;i<=200000;i++){
            int j=i;
            while(j<=200000){
                divisors[j].pb(i);
                j+=i;
            }
        }
    }

    // Post-order pass computing sz[] for the subtree rooted at x (p = parent).
    void dfs_size(int x, int p){

        for(auto it:adj[x]){
            if(it != p){
                dfs_size(it,x);
                sz[x] += sz[it];
            }
        }

        sz[x]++;

    }

    // Adds val (+1 or -1) to cnt[value[w]] for every node w in the subtree of x.
    void add(int x, int p, int val){

        for(auto it:adj[x]){
            if(it != p){
                add(it,x,val);
            }
        }

        cnt[value[x]] += val;

    }

    // DSU on tree. Before the answer loop, cnt describes exactly the subtree of x.
    // On return, x's subtree stays in cnt if keep != 0, otherwise it is removed again.
    void dfs_cnt(int x, int p, int keep){

        // Pick the child with the largest subtree (the heavy child).
        int mx=-1, bigChild=-1;
        for(auto it:adj[x]){
            if(it != p){
                if(sz[it] > mx)
                    mx=sz[it], bigChild=it;
            }
        }

        // Light children first; each one clears its own contribution afterwards.
        for(auto it:adj[x]){
            if(it != p and it != bigChild){
                dfs_cnt(it,x,0);
            }
        }

        // Heavy child last; its counts are kept in cnt.
        if(bigChild != -1){
            dfs_cnt(bigChild,x,1);
        }

        // Add x itself and re-add the light subtrees: cnt now covers x's whole subtree.
        cnt[value[x]]++;

        for(auto it:adj[x]){
            if(it != p and it != bigChild){
                add(it,x,1);
            }
        }

        ///////////////////////////////////////////////////////////////////////////////////////////////

        // Answer every X asked about x. The power is assigned, not accumulated with +=:
        // correspondingX[x] gets one entry per query mentioning x, so the same X can occur
        // more than once (a query with u == v, or two queries sharing node x and X), and
        // += would double-count it and corrupt comparisons against other nodes.
        for(int s:correspondingX[x]){
            int power=0;
            for(int j:divisors[s]){
                power += cnt[j];
            }
            ans[{x,s}] = power;
        }

        ///////////////////////////////////////////////////////////////////////////////////////////////

        if(keep == 0){
            add(x,p,-1);
        }

    }

    // Reads the tree and all queries, computes the powers offline, then prints one line
    // per query: the node with the larger power, or "Draw".
    void solve(){

        ll q;
        cin>>n>>q;
        value  = vll(n+1);
        sz  = vll(n+1);
        cnt  = vll(200005);
        query = vector<Q> ();
        correspondingX = vector<vll> (n+1);
        adj = vector<vll> (n+1);

        for(int i=1;i<=n;i++)
            cin>>value[i];

        for(int i=1;i<n;i++){
            int u,v;
            cin>>u>>v;
            adj[u].pb(v);
            adj[v].pb(u);
        }

        // Remember each query and register its X on both endpoints.
        for(int i=0;i<q;i++){
            int u,v,x;
            cin>>u>>v>>x;
            Q p(u,v,x);
            query.push_back(p);
            correspondingX[u].pb(x);
            correspondingX[v].pb(x);
        }

        dfs_size(1,0);
        dfs_cnt(1,0,0);

        for(int i=0;i<q;i++){
            int u=query[i].u;
            int v=query[i].v;
            int x=query[i].x;
            int A = ans[{u,x}];
            int B = ans[{v,x}];

            if(A>B)cout<<u<<endl;
            else if(B>A)cout<<v<<endl;
            else cout<<"Draw"<<endl;
        }

    }
};


void solve(){

    Solution S;
    S.init();
    S.pre();
    S.solve();

}
           
signed main(){
    fast;
    // sub;                    // uncomment to read from input.txt
    clock_t clk = clock();     // read by the ttime macro below

    for (ll t = 1; t > 0; --t)
        solve();

    ttime;
    return 0;
}

                                // Mischief Managed //
Tester's Solution
        #include<bits/stdc++.h>
        using namespace std;
        //#include <ext/pb_ds/assoc_container.hpp> // Common file
        //#include <ext/pb_ds/tree_policy.hpp> // Including tree_order_statistics_node_update
        //using namespace __gnu_pbds; //  order of key(keys strictly less than)  // find_by_order
        //typedef tree<long long,null_type,less<>,rb_tree_tag,tree_order_statistics_node_update> ordered_set;
        //typedef tree<long long, null_type, less_equal<>, rb_tree_tag, tree_order_statistics_node_update> indexed_multiset;
        //IF WA CHECK FOR : -
        // 1 > EDGE CASES LIKE N=1 , N=0
        // 2 > SIGNED INTEGER OVERFLOW IN MOD
        // 3 > CHECK THE CODE FOR LOGICAL ERRORS AND SEG FAULTS
        // 4 > READ THE PS ONCE AGAIN , if having double diff less than 1e-8 is same.
        // 5 > You Have got AC .
        // Shorthand used throughout the tester's solution.
        #define ll long long
        // Default modulus for binpow.
        #define NUM (ll)998244353
        #define inf (long long)(2e18)
        #define ff first
        #define ss second
        // f(i,a,b): i runs a .. b-1; the bound is cast to `long`.
        // NOTE(review): `long` is 32-bit on some platforms (e.g. Windows); the bounds used here are <= 2e5.
        #define f(i,a,b) for(ll i=a;(i)<long(b);(i)++)
        // fr(i,a,b): i runs from a down to b inclusive.
        #define fr(i,a,b) for(ll i=a;(i)>=(long long)(b);(i)--)
        // it(b): range-for over b by reference; the loop variable is named `it`.
        #define it(b)  for(auto &it:(b))
        #define pb push_back
        #define mp make_pair
        typedef vector<ll> vll;
        typedef pair<ll,ll> pll;
        // Computes base^ex modulo mod by binary exponentiation, for ex >= 0 and mod >= 1.
        // Products are formed in long long, so mod must satisfy (mod-1)^2 < 2^63 (about 3e9).
        // Fixes over the previous version: binpow(0, 0) now returns 1 (it returned 0),
        // any exponent modulo 1 returns 0, and negative bases are normalised into [0, mod).
        ll binpow( ll base , ll ex,ll mod=NUM) {
            ll ans = 1 % mod;            // 1 % mod handles mod == 1
            base %= mod;
            if (base < 0) base += mod;   // C++ % keeps the sign of the dividend
            while (ex > 0) {
                if (ex & 1) {
                    ans = (ans * base) % mod;
                }
                base = (base * base) % mod;
                ex >>= 1;
            }
            return ans;
        }
        // Reads n values from stdin into arr. If arr's size differs from n, it is first
        // reset to n zeros.
        void read(vll &arr,ll n) {
            if (static_cast<ll>(arr.size()) != n) {
                arr.assign(n, 0);
            }
            for (ll idx = 0; idx < n; ++idx) {
                cin >> arr[idx];
            }
        }
        // Smaller of a and b (a when they are equal).
        inline ll min(ll a,ll b){
            return (b < a) ? b : a;
        }
        // Larger of a and b (b when they are equal).
        inline ll max(ll a, ll b){
            return (b < a) ? a : b;
        }
        // Absolute difference |a - b|.
        inline ll dif(ll a,ll b) {
            return (b < a) ? a - b : b - a;
        }
        // Euclid's algorithm. For non-negative a and b it returns gcd(a, b), with gcd(a, 0) == a.
        long long gcd(long long a,long long b) {
            if (b == 0) return a;
            return gcd(b, a % b);
        }
        // Least common multiple of non-negative a and b.
        // Divides before multiplying, so no intermediate exceeds the result (the old a*b/k
        // overflowed even when the lcm itself fits). Returns 0 when either argument is 0,
        // which also avoids the division by gcd(0, 0) == 0.
        long long lcm(long long a,long long b) {
            if (a == 0 || b == 0) return 0;
            return a / gcd(a, b) * b;
        }
        // Tree adjacency lists plus per-node value and Euler-tour times (all 1-indexed).
        vector<vll>adj;
        vll val,in,out;
        ll tim = 0;
        // Sets in[start] on entry and out[start] to the largest in-time inside start's subtree,
        // so a node w lies in start's subtree exactly when in[start] <= in[w] <= out[start].
        void dfs(ll start,ll par){
            in[start] = ++tim;
            for (auto &child : adj[start]) {
                if (child == par) continue;
                dfs(child, start);
            }
            out[start] = tim;
        }
        // For a sorted arr, returns the index of the last element <= x, or -1 if all elements
        // are > x. An empty arr yields 0 rather than -1. Callers only use the difference of
        // two calls on the same arr, which is 0 either way.
        ll fun(vll &arr,ll x){
            if (arr.empty()) return 0;
            // Half-open search [lo, hi) for the first index whose element is > x.
            ll lo = 0;
            ll hi = static_cast<ll>(arr.size());
            while (lo < hi) {
                ll mid = lo + (hi - lo) / 2;
                if (arr[mid] <= x) lo = mid + 1;
                else hi = mid;
            }
            return lo - 1;
        }
        // fact[x]: all divisors of x in increasing order (filled in main).
        // times[v]: sorted in-times of the nodes whose value is v (filled in solve).
        vector<vll>fact(2*1e5+1);
        vector<vll>times(2*1e5+1);
        // Reads one input, validates it against the constraints, and answers every query
        // with the node of larger power, or "Draw".
        void solve() {
            int n,q;cin>>n>>q;
            assert(n>=1 and n<=5*1e4 and q>=1 and q<=1e5);
            adj.resize(n+1);val.resize(n+1);in=val;out=val;
            // ind[a] = nodes whose value is a
            vector<vll>ind(2*1e5+1);
            f(i,0,n){
                ll a;cin>>a;val[i+1]=a;assert(a<=2*1e5 and a>=1);
                ind[a].pb(i+1);
            }
            f(i,0,n-1){
                ll a,b;cin>>a>>b;
                assert(a>=1 and a<=n and b>=1 and b<=n);
                adj[a].pb(b);adj[b].pb(a);
            }
            dfs(1,1);
            f(i,1,2*1e5+1) {
                it(ind[i]) {
                    times[i].pb(in[it]);
                }
                sort(times[i].begin(), times[i].end());
            }
            set<ll>z;int xx=0;
            while(q--){
                xx++;
                ll u,v,x;cin>>u>>v>>x;
                z.insert(x);
                // Check lower bounds as well: u and v index in/out, and x indexes fact.
                assert(u>=1 and u<=n and v>=1 and v<=n and x>=1 and x<=2*1e5);
                ll lef = 0;
                ll rig =0;
                it(fact[x]){
                    // number of nodes with value `it` in each subtree
                    lef  += fun(times[it],out[u])-fun(times[it],in[u]-1);
                    rig  += fun(times[it],out[v])-fun(times[it],in[v]-1);
                }
                // '\n' instead of endl: avoids flushing the stream for each of up to 1e5 answers.
                if(lef<rig){
                    cout<<v<<'\n';
                }
                else if(lef==rig){
                    cout<<"Draw"<<'\n';
                }
                else{
                    cout<<u<<'\n';
                }
            }
            // Test-data check: every query uses a distinct X.
            assert(xx==(int)z.size());
        }
        int main() {
            ios_base::sync_with_stdio(false);
            cin.tie(NULL);
            cout << fixed << showpoint;
            cout << setprecision(12);
            long long test_m = 1;
            int k=1;
            //cin >> test_m;
            //WE WILL WIN .
            for(int i=1;i<=2*1e5;i++){
                for(ll j=i;j<=2*1e5;j+=i){
                    fact[j].pb(i);
                }
            }
            while (test_m--) {
                //cout<<"Case #"<<k++<<": ";
                solve();
            }
        }
Editorialist's Solution
                    //    I solemnly swear that I am upto no good //

#include <bits/stdc++.h>
using namespace std;
// Shorthand macros used by the editorialist's solution.
// NOTE(review): unlike the setter's preamble, no `ff`/`ss` macros are defined here.
#define sub             freopen("input.txt", "r", stdin);//freopen("output.txt", "w", stdout);
#define ll              long long
// Prints elapsed seconds to stderr; the expansion reads a variable named `clk` that must be in scope.
#define ttime           {cerr << '\n'<<"Time (in s): " << double(clock() - clk) * 1.0 / CLOCKS_PER_SEC << '\n';}
// Template headers with long long defaults, used by the stream helpers below.
#define helpUs          template<typename T = ll , typename U = ll>
#define helpMe          template<typename T = ll>
#define pb              push_back
#define sz(x)           (int)((x).size())
// Disables C-stdio synchronisation and unties cin from cout for faster I/O.
#define fast            ios_base::sync_with_stdio(false);cin.tie(0);
#define all(x)          (x).begin(),(x).end()

// From here on `endl` is a plain newline (no flush) and every `int` is `long long`;
// this is why the entry point is declared as `signed main`.
#define endl "\n"
#define int ll

typedef pair<ll,ll> pairll;
typedef map<ll,ll>  mapll;
typedef map<char,ll> mapch;
typedef vector<ll> vll;
// 64-bit Mersenne Twister seeded from the steady clock (not used by the code shown here).
mt19937_64 rng(std::chrono::steady_clock::now().time_since_epoch().count());
// Descending comparator: comp<T,U>()(a, b) is true when a > b. Its call operator is const
// so that standard containers and algorithms can invoke it through a const object.
helpUs class comp{public:bool operator()(T a, U b) const {return a>b;}};
// Reads a pair as "first second". Spelled with .first/.second because this solution never
// defines the `ff`/`ss` macros, so the old body failed as soon as the template was used.
helpUs istream& operator>>(istream& aa, pair<T,U> &p){aa>>p.first>>p.second;return aa;}
// Writes the elements of v, each followed by a space. Takes const& so const vectors and temporaries print too.
helpMe ostream& operator<<(ostream& ja, const vector<T> &v){for(const auto &it:v)ja<<it<<" ";return ja;}
// Fills every element of v from the given stream `aa` (not from cin), so any istream works.
helpMe istream& operator>>(istream& aa, vector<T> &v){for(auto &it:v)aa>>it;return aa;}
// Writes a pair as "first second".
helpUs ostream& operator<<(ostream& ja, const pair<T,U> &p){ja<<p.first<<" "<<p.second;return ja;}



// Online solution: an Euler tour gives each subtree a contiguous in-time range. Per-value
// sorted in-time lists then let us count, by binary search, the nodes of a given value in
// any subtree; a query sums these counts over all divisors of X.
struct Solution{

    ll n,k,q;
    vector<vll> numTimes;   // numTimes[val] = in-times of nodes with value val, increasing
    vector<vll> divisors;   // divisors[x] = all divisors of x (1 <= x <= MAX), increasing
    vll values;             // values[node], 1-indexed
    vector<vll> adj;        // adjacency lists, 1-indexed
    vll in, out;            // entry time, and largest entry time inside the subtree

    int tim=0;
    ll MAX = 200000;        // largest value / X handled; the tables are sized from it

    // Divisor sieve: append i to every multiple of i (harmonic-sum total work).
    void pre(){
        for(int i=1;i<=MAX;i++){
            for(int j=i;j<=MAX;j+=i){
                divisors[j].pb(i);
            }
        }
    }

    // Records in/out times. In-times are appended to numTimes in visiting order and the
    // counter only increases, so every numTimes[val] is already sorted; no sort is needed.
    void dfs(int x, int p){

        in[x] = (++tim);
        numTimes[values[x]].push_back(in[x]);

        for(auto it:adj[x]){
            if(it != p){
                dfs(it,x);
            }
        }

        out[x] = tim;
    }

    // Number of entries of the sorted vector v within the closed range [t1, t2].
    int count(const vector<int> &v, int t1, int t2){
        if(t1 > t2) return 0;   // empty range: avoid a negative iterator difference
        return upper_bound(all(v) , t2) - lower_bound(all(v) , t1);
    }

    // Reads the tree, builds the lookup tables, then answers each query as it is read:
    // prints the node with the larger power, or "Draw".
    void solve(){

        // Size the divisor table from MAX so that pre(), which writes up to divisors[MAX], stays in bounds.
        divisors = vector<vll> (MAX+1);
        pre();
        cin>>n>>q;

        numTimes = vector<vll> (MAX+1);
        values = vll(n+1);
        in = vll(n+1);
        out = vll(n+1);

        for(int i=1;i<=n;i++){
            cin>>values[i];
        }

        adj = vector<vll> (n+1);
        for(int i=1;i<n;i++){
            int a,b;
            cin>>a>>b;
            adj[a].pb(b);
            adj[b].pb(a);
        }

        dfs(1,0);

        // Assumes 1 <= u, v <= n and 1 <= x <= MAX, as in the problem constraints.
        for(int qi=0; qi<q; qi++){

            int u,v,x;
            cin>>u>>v>>x;

            // power(w) = sum over divisors d of x of (# nodes with value d in w's subtree)
            int cnt_u=0, cnt_v=0;
            for(auto it:divisors[x]){
                cnt_u += count(numTimes[it] , in[u] , out[u]);
                cnt_v += count(numTimes[it] , in[v] , out[v]);
            }

            if(cnt_u == cnt_v){
                cout<<"Draw"<<endl;
            }
            else if(cnt_u > cnt_v){
                cout<<u<<endl;
            }
            else{
                cout<<v<<endl;
            }

        }

    }

};

           
signed main(){

    fast;
    ll t=1;

   // freopen("input.txt" , "r" , stdin);
   
    // clock_t clk = clock();
    
    Solution S;
    S.solve();
    
    // ttime;

    return 0;

}        

                                // Mischief Managed //