# GREATSACK - Editorial

Author: Onkar Ratnaparkhi
Tester: Saptarshi Shome
Editorialist: Onkar Ratnaparkhi

EASY-MEDIUM

# PREREQUISITES:

Euler tour/in-out times in a tree, binary search, Basic number theory

# PROBLEM:

You are given a tree, rooted at node 1. Each node is given a value, as described by array A.
There are queries of the form (u,v, X): find which node among u and v has more power.
The power of a node is defined as the count of nodes in its subtree whose value divides X.

# QUICK EXPLANATION:

Store in-times corresponding to each value in the tree and at the time of the query, traverse on all divisors of X and count occurrences of each divisor in the subtree of u and v using their in-out times and binary search.

# EXPLANATION:

First of all, we should know that we can store all the divisors of all the numbers from 1 to X_{max} in X_{max}log(X_{max}) time.

Why?

n/1 + n/2 + n/3 + ... + n/n is approximately equal to nlogn.

Now, let’s solve the problem.

We will store the following things while doing the dfs:

• in-out times of all the nodes.
• in-times corresponding to all the values of the nodes. i.e. numTimes[val] is an array having in times corresponding to val.

Note that the “Value” of a node is the one that is given in the input and the “number” of a node is its serial number. eg. node 1 has value 124, node 2 has value 2423, etc.

Now, let’s solve the queries online: (u,v, X)
We can simply traverse on all the stored divisors of X. Let’s say we are currently on divisor di.
We need to count its occurrences in the subtree of u and in the subtree of v. We can do it in this way:

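• For u, we need to see how many numbers in numTimes[di] lie in the range [inTime[u], outTime[u]], because exactly the nodes in the subtree of u have their in-times in that range.
• First make sure numTimes[di] is in sorted order. This holds automatically if the in-times are appended while doing the Euler tour, since in-times are handed out in increasing order.
• Now just apply binary search on numTimes[di] using inTime and outTime of u.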
This can be done in c++ like this
int count(vector<int> &v, int t1, int t2){
    // v is the sorted vector numTimes[di] for the divisor di being processed;
    // t1 is inTime[u] and t2 is outTime[u].
    // The in-times inside [t1, t2] form one contiguous run of v, so the
    // answer is the distance between the two binary-search positions.
    auto firstInside = lower_bound(v.begin(), v.end(), t1);  // first in-time >= t1
    auto firstAfter = upper_bound(v.begin(), v.end(), t2);   // first in-time > t2
    return firstAfter - firstInside;
}

Now all that is left is to calculate the power of node u and node v for the battle and compare them.

Code Snippet for queries
while(q--){

    int u, v, x;
    cin >> u >> v >> x;

    // Power of a node = sum, over every divisor d of x, of the number of
    // nodes in its subtree whose value is exactly d.
    int power_u = 0;
    int power_v = 0;
    for(auto d : divisors[x]){
        auto &stamps = numTimes[d];
        power_u += count(stamps, in[u], out[u]);
        power_v += count(stamps, in[v], out[v]);
    }

    // Report the stronger node, or "Draw" on equal power.
    if(power_u > power_v){
        cout << u << endl;
    }
    else if(power_v > power_u){
        cout << v << endl;
    }
    else{
        cout << "Draw" << endl;
    }

}
Time Complexity

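O(X_{max} log(X_{max}) + N + Q · D · log N)
X_{max} is the maximum value X in a query can take, N is the number of nodes, Q is the number of queries, and D is the maximum number of divisors of any X ≤ X_{max} (each query performs two binary searches per divisor of X).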

# ALTERNATE EXPLANATION:

We can also solve this problem using DSU-on-tree

Short Explanation
• Store all the queries.
• Store divisors of all possible X.
• Perform DFS (following small to large technique).
• While present at node i, traverse over all divisors of all X's corresponding to queries on this node and store the powers of nodes corresponding to those queries.
• Solve the queries offline, using the stored data.

# SOLUTIONS:

Setter's Solution
//    I solemnly swear that I am upto no good //

#include <bits/stdc++.h>
using namespace std;
#define sub             freopen("input.txt", "r", stdin);//freopen("output.txt", "w", stdout);
#define ll              long long
#define ull             unsigned long long
#define ld              long double
#define ttime           {cerr << '\n'<<"Time (in s): " << double(clock() - clk) * 1.0 / CLOCKS_PER_SEC << '\n';}
#define helpUs          template<typename T = ll , typename U = ll>
#define helpMe          template<typename T = ll>
#define pb              push_back
#define sz(x)           (int)((x).size())
#define fast            ios_base::sync_with_stdio(false);cin.tie(0);
#define all(x)          (x).begin(),(x).end()
#define rep(i,a,b)      for(ll i=a;i<b;i++)
#define pr(x)           cout << #x " = " << (x) << "\n"
#define mp              make_pair
#define ff              first
#define ss              second
#define YY              cout<<"Yes"<<endl
#define NN              cout<<"No"<<endl
#define ppc             __builtin_popcount
#define ppcll           __builtin_popcountll

// #include <ext/pb_ds/assoc_container.hpp>
// #include <ext/pb_ds/tree_policy.hpp>
// using namespace __gnu_pbds;
// #define ordered_set tree<int, null_type,less<int>, rb_tree_tag,tree_order_statistics_node_update>

// order_of_key and find_by_order

const long long INF=1e18;
const long long N=200005;
const long long mod=1000000007;                    // 998244353, 2971215073, 1000050131, 433494437

#define endl "\n"
#define int ll

typedef pair<ll,ll> pairll;
typedef map<ll,ll>  mapll;
typedef map<char,ll> mapch;
typedef vector<ll> vll;
// Comparator that orders values in descending order (a before b when a > b).
template<typename T = long long, typename U = long long>
class comp{
public:
    bool operator()(T a, U b) const { return a > b; }
};

// Reads a pair as "first second" from the given stream.
template<typename T = long long, typename U = long long>
std::istream& operator>>(std::istream& aa, std::pair<T,U> &p){
    aa >> p.first >> p.second;
    return aa;
}

// Writes a pair as "first second" (no trailing separator).
template<typename T = long long, typename U = long long>
std::ostream& operator<<(std::ostream& ja, const std::pair<T,U> &p){
    ja << p.first << " " << p.second;
    return ja;
}

// Writes every element of the vector followed by a single space.
template<typename T = long long>
std::ostream& operator<<(std::ostream& ja, const std::vector<T> &v){
    for(const auto &item : v) ja << item << " ";
    return ja;
}

// Fills an already-sized vector from the given stream.
// The elements are read from `aa` itself, not from std::cin, so the
// operator also works on file and string streams.
template<typename T = long long>
std::istream& operator>>(std::istream& aa, std::vector<T> &v){
    for(auto &item : v) aa >> item;
    return aa;
}

ll n,k,tt=1;

// One battle query: which of nodes u and v has more power with respect to x.
struct Q{
    Q(){}
    Q(long long a, long long b, long long c) : u(a), v(b), x(c) {}
    long long u,v,x;
};

// Offline solver based on DSU-on-tree (small-to-large "sack").
//
// While the DFS stands on node x, cnt[w] holds the number of nodes with
// value w inside the subtree of x. The power of x for a query value X is
// then the sum of cnt[d] over all divisors d of X.
struct Solution{
    // Largest node value / query X covered by the divisor and count tables.
    static const long long MAX_VALUE = 200000;

    long long n,k,tt=1;
    std::vector<long long> value;   // value[i]: value of node i (1-indexed)
    std::vector<long long> sz;      // sz[i]: number of nodes in the subtree of i
    std::vector<long long> cnt;     // cnt[w]: nodes with value w currently in the sack
    std::map<std::pair<long long,long long>, long long> ans;  // (node, X) -> power
    std::vector<std::vector<long long>> correspondingX;       // X's asked about each node
    std::vector<std::vector<long long>> adj;                  // adjacency list of the tree

    std::vector<std::vector<long long>> divisors;             // divisors[m]: all divisors of m
    std::vector<Q> query;

    // Allocates the (empty) divisor table.
    void init(){
        divisors.assign(MAX_VALUE + 5, std::vector<long long>());
    }

    // Harmonic sieve: appends d to divisors[m] for every multiple m of d.
    void pre(){
        for(long long d = 1; d <= MAX_VALUE; d++){
            for(long long m = d; m <= MAX_VALUE; m += d){
                divisors[m].push_back(d);
            }
        }
    }

    // Computes sz[] for the subtree of x, where p is the parent of x.
    void dfs_size(long long x, long long p){

        for(long long child : adj[x]){
            if(child != p){
                dfs_size(child, x);
                sz[x] += sz[child];
            }
        }

        sz[x]++;

    }

    // Adds `val` (+1 or -1) to cnt[] for every node in the subtree of x.
    void add(long long x, long long p, long long val){

        for(long long child : adj[x]){
            if(child != p){
                add(child, x, val);
            }
        }

        cnt[value[x]] += val;

    }

    // Sack DFS. After the call, cnt[] still contains the subtree of x when
    // keep != 0 and has been restored to its state before the call otherwise.
    void dfs_cnt(long long x, long long p, long long keep){

        // Pick the child with the largest subtree (the "heavy" child).
        long long mx = -1, bigChild = -1;
        for(long long child : adj[x]){
            if(child != p){
                if(sz[child] > mx)
                    mx = sz[child], bigChild = child;
            }
        }

        // Light children first; their contribution is removed again.
        for(long long child : adj[x]){
            if(child != p and child != bigChild){
                dfs_cnt(child, x, 0);
            }
        }

        // Heavy child last; its contribution stays in cnt[].
        if(bigChild != -1){
            dfs_cnt(bigChild, x, 1);
        }

        cnt[value[x]]++;

        // Re-insert the light subtrees so cnt[] covers the whole subtree of x.
        for(long long child : adj[x]){
            if(child != p and child != bigChild){
                add(child, x, 1);
            }
        }

        // Answer every query value attached to x. The result is assigned, not
        // accumulated, so an X listed more than once for the same node
        // (e.g. u == v, or a repeated query) is not counted twice.
        for(long long s : correspondingX[x]){
            long long total = 0;
            for(long long d : divisors[s]){
                total += cnt[d];
            }
            ans[{x, s}] = total;
        }

        if(keep == 0){
            add(x, p, -1);
        }

    }

    // Reads the tree and the queries from std::cin, runs the sack DFS from
    // node 1 and prints one line per query: the stronger node or "Draw".
    void solve(){

        long long q;
        std::cin >> n >> q;
        value.assign(n + 1, 0);
        sz.assign(n + 1, 0);
        cnt.assign(MAX_VALUE + 5, 0);
        ans.clear();
        query.clear();
        correspondingX.assign(n + 1, std::vector<long long>());
        adj.assign(n + 1, std::vector<long long>());

        for(long long i = 1; i <= n; i++)
            std::cin >> value[i];

        // The tree is undirected: store every edge in both directions.
        for(long long i = 1; i < n; i++){
            long long a, b;
            std::cin >> a >> b;
            adj[a].push_back(b);
            adj[b].push_back(a);
        }

        for(long long i = 0; i < q; i++){
            long long u, v, x;
            std::cin >> u >> v >> x;
            query.push_back(Q(u, v, x));
            correspondingX[u].push_back(x);
            correspondingX[v].push_back(x);
        }

        dfs_size(1, 0);
        dfs_cnt(1, 0, 0);

        for(const Q &item : query){
            const long long A = ans[{item.u, item.x}];
            const long long B = ans[{item.v, item.x}];

            if(A > B) std::cout << item.u << "\n";
            else if(B > A) std::cout << item.v << "\n";
            else std::cout << "Draw" << "\n";
        }

    }
};

void solve(){

Solution S;
S.init();
S.pre();
S.solve();

}

signed main(){

fast;
ll t=1;
// sub;

clock_t clk = clock();

while(t--)
solve();

ttime;

return 0;

}

// Mischief Managed //
Tester's Solution
#include<bits/stdc++.h>
using namespace std;
//#include <ext/pb_ds/assoc_container.hpp> // Common file
//#include <ext/pb_ds/tree_policy.hpp> // Including tree_order_statistics_node_update
//using namespace __gnu_pbds; //  order of key(keys strictly less than)  // find_by_order
//typedef tree<long long,null_type,less<>,rb_tree_tag,tree_order_statistics_node_update> ordered_set;
//typedef tree<long long, null_type, less_equal<>, rb_tree_tag, tree_order_statistics_node_update> indexed_multiset;
//IF WA CHECK FOR : -
// 1 > EDGE CASES LIKE N=1 , N=0
// 2 > SIGNED INTEGER OVERFLOW IN MOD
// 3 > CHECK THE CODE FOR LOGICAL ERRORS AND SEG FAULTS
// 4 > READ THE PS ONCE AGAIN , if having double diff less than 1e-8 is same.
// 5 > You Have got AC .
#define ll long long
#define NUM (ll)998244353
#define inf (long long)(2e18)
#define ff first
#define ss second
#define f(i,a,b) for(ll i=a;(i)<long(b);(i)++)
#define fr(i,a,b) for(ll i=a;(i)>=(long long)(b);(i)--)
#define it(b)  for(auto &it:(b))
#define pb push_back
#define mp make_pair
typedef vector<ll> vll;
typedef pair<ll,ll> pll;
// Computes base^ex modulo `mod` by binary exponentiation.
//
// base - any value; it is first reduced into [0, mod)
// ex   - exponent; values <= 0 are treated as 0
// mod  - positive modulus, defaults to 998244353 (the value of NUM)
//
// Returns a value in [0, mod). As before, a base that is 0 modulo `mod`
// yields 0 for every exponent, including ex == 0.
// Products are formed in 128-bit arithmetic, so the result is correct for
// every modulus that fits in a signed 64-bit integer.
long long binpow( long long base , long long ex, long long mod = 998244353LL) {
    base %= mod;
    if (base < 0) {
        base += mod;  // keep the representative non-negative
    }
    if (base == 0) {
        return 0;
    }
    long long ans = 1;
    while (ex > 0) {
        if (ex % 2 == 1) {
            ans = static_cast<long long>(static_cast<__int128>(ans) * base % mod);
        }
        base = static_cast<long long>(static_cast<__int128>(base) * base % mod);
        ex = ex / 2;
    }
    return ans;
}
// NOTE(review): the signature line of this helper is missing from the file;
// the name `read` and the parameter types are reconstructed from the body.
// Confirm against the original source. Nothing else in this program calls it.
//
// Ensures `arr` holds exactly n elements (resetting it to n zeros when the
// size differs) and then reads n values from standard input into it.
void read(std::vector<long long> &arr, long long n) {
    if (n < 0) {
        n = 0;  // a negative count means "nothing to read"
    }
    if (arr.size() != static_cast<std::size_t>(n)) { arr.assign(n, 0); }
    for (long long i = 0; i < n; i++) std::cin >> arr[i];
}
// Returns the smaller of the two arguments.
inline ll min(ll a,ll b){
    return (a > b) ? b : a;
}
// Returns the larger of the two arguments.
inline ll max(ll a, ll b){
    return (a > b) ? a : b;
}
// Returns the absolute difference |a - b|.
inline ll dif(ll a,ll b) {
    return (a > b) ? (a - b) : (b - a);
}
// Greatest common divisor by the Euclidean algorithm.
// gcd(a, 0) == a, hence gcd(0, 0) == 0.
long long gcd(long long a,long long b) {
    while (b != 0) {
        const long long rest = a % b;
        a = b;
        b = rest;
    }
    return a;
}
// Least common multiple of a and b.
// Divides before multiplying so the intermediate value never exceeds the
// result itself, and returns 0 for lcm(0, 0) instead of dividing by zero.
long long lcm(long long a,long long b) {
    const long long g = gcd(a, b);
    if (g == 0) {
        return 0;
    }
    return (a / g) * b;
}
// Per-node data, 1-indexed by node number:
//   val[i] - value of node i
//   in[i]  - Euler-tour entry time of node i
//   out[i] - largest entry time inside the subtree of node i
std::vector<long long> val,in,out;
// Adjacency list of the tree, 1-indexed by node number.
std::vector<std::vector<long long>> adj;
long long tim = 0;

// Euler tour from `start`, whose parent is `par`. A node lies in the subtree
// of `start` exactly when its entry time is within [in[start], out[start]].
void dfs(long long start,long long par){
    tim++;in[start]=tim;
    for(long long child : adj[start]){
        if(child!=par){
            dfs(child,start);
        }
    }
    out[start]=tim;
}

// For a sorted `arr`: index of the last element <= x, or -1 when every
// element is greater than x. An empty `arr` yields 0, so that the
// difference fun(arr, hi) - fun(arr, lo - 1) is always the number of
// elements in [lo, hi].
long long fun(const std::vector<long long> &arr,long long x){
    if(arr.empty()){
        return 0;
    }
    return (std::upper_bound(arr.begin(), arr.end(), x) - arr.begin()) - 1;
}

// fact[m]  - divisors of m; filled by the caller before solve() runs.
// times[w] - sorted entry times of all nodes whose value is w.
std::vector<std::vector<long long>> fact(200001);
std::vector<std::vector<long long>> times(200001);

// Reads one test case from std::cin, validates it against the problem
// constraints with assert, and prints one line per query: the node with
// the greater power, or "Draw".
void solve() {
    long long n,q;std::cin>>n>>q;
    assert(n>=1 and n<=50000 and q>=1 and q<=100000);

    val.assign(n+1,0);
    in.assign(n+1,0);
    out.assign(n+1,0);
    adj.assign(n+1,std::vector<long long>());
    tim = 0;

    for(long long node=1;node<=n;node++){
        long long a;std::cin>>a;
        assert(a<=200000 and a>=1);
        val[node]=a;
    }

    // The tree is undirected: store every edge in both directions.
    for(long long e=0;e<n-1;e++){
        long long a,b;std::cin>>a>>b;
        assert(a>=1 and a<=n and b>=1 and b<=n);
        adj[a].push_back(b);
        adj[b].push_back(a);
    }

    dfs(1,1);

    // Group the entry times by node value and keep each group sorted.
    for(std::vector<long long> &group : times){
        group.clear();
    }
    for(long long node=1;node<=n;node++){
        times[val[node]].push_back(in[node]);
    }
    for(std::vector<long long> &group : times){
        std::sort(group.begin(), group.end());
    }

    std::set<long long> z;long long xx=0;
    while(q--){
        xx++;
        long long u,v,x;std::cin>>u>>v>>x;
        z.insert(x);
        assert(u>=1 and u<=n and v>=1 and v<=n and x>=1 and x<=200000);

        // Sum, over every divisor d of x, the nodes with value d in each subtree.
        long long lef = 0;
        long long rig = 0;
        for(long long d : fact[x]){
            lef += fun(times[d],out[u])-fun(times[d],in[u]-1);
            rig += fun(times[d],out[v])-fun(times[d],in[v]-1);
        }

        // '\n' instead of endl: no flush after every single query.
        if(lef<rig){
            std::cout<<v<<'\n';
        }
        else if(lef==rig){
            std::cout<<"Draw"<<'\n';
        }
        else{
            std::cout<<u<<'\n';
        }
    }
    // Validator check: every query must use a distinct X.
    assert(static_cast<std::size_t>(xx)==z.size());
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
cout << fixed << showpoint;
cout << setprecision(12);
long long test_m = 1;
int k=1;
//cin >> test_m;
//WE WILL WIN .
for(int i=1;i<=2*1e5;i++){
for(ll j=i;j<=2*1e5;j+=i){
fact[j].pb(i);
}
}
while (test_m--) {
//cout<<"Case #"<<k++<<": ";
solve();
}
}
Editorialist's Solution
//    I solemnly swear that I am upto no good //

#include <bits/stdc++.h>
using namespace std;
#define sub             freopen("input.txt", "r", stdin);//freopen("output.txt", "w", stdout);
#define ll              long long
#define ttime           {cerr << '\n'<<"Time (in s): " << double(clock() - clk) * 1.0 / CLOCKS_PER_SEC << '\n';}
#define helpUs          template<typename T = ll , typename U = ll>
#define helpMe          template<typename T = ll>
#define pb              push_back
#define sz(x)           (int)((x).size())
#define fast            ios_base::sync_with_stdio(false);cin.tie(0);
#define all(x)          (x).begin(),(x).end()

#define endl "\n"
#define int ll

typedef pair<ll,ll> pairll;
typedef map<ll,ll>  mapll;
typedef map<char,ll> mapch;
typedef vector<ll> vll;
// Comparator that orders values in descending order (a before b when a > b).
template<typename T = long long, typename U = long long>
class comp{
public:
    bool operator()(T a, U b) const { return a > b; }
};

// Reads a pair as "first second" from the given stream.
// Uses .first/.second directly: this program defines no ff/ss shorthands.
template<typename T = long long, typename U = long long>
std::istream& operator>>(std::istream& aa, std::pair<T,U> &p){
    aa >> p.first >> p.second;
    return aa;
}

// Writes a pair as "first second" (no trailing separator).
template<typename T = long long, typename U = long long>
std::ostream& operator<<(std::ostream& ja, const std::pair<T,U> &p){
    ja << p.first << " " << p.second;
    return ja;
}

// Writes every element of the vector followed by a single space.
template<typename T = long long>
std::ostream& operator<<(std::ostream& ja, const std::vector<T> &v){
    for(const auto &item : v) ja << item << " ";
    return ja;
}

// Fills an already-sized vector from the given stream.
// The elements are read from `aa` itself, not from std::cin, so the
// operator also works on file and string streams.
template<typename T = long long>
std::istream& operator>>(std::istream& aa, std::vector<T> &v){
    for(auto &item : v) aa >> item;
    return aa;
}

// Online solver: Euler tour + per-value sorted entry times + binary search.
//
// A node lies in the subtree of u exactly when its entry time is within
// [in[u], out[u]]. The power of u for X is therefore the sum, over every
// divisor d of X, of the number of entries of numTimes[d] in that range.
struct Solution{

    long long n,k,q;
    std::vector<std::vector<long long>> numTimes;  // numTimes[w]: entry times of nodes with value w
    std::vector<std::vector<long long>> divisors;  // divisors[m]: all divisors of m
    std::vector<std::vector<long long>> adj;       // adjacency list of the tree
    std::vector<long long> values;                 // values[i]: value of node i (1-indexed)
    std::vector<long long> in, out;                // Euler-tour entry time / last entry time in subtree

    long long tim=0;
    long long MAX = 200000;

    // Harmonic sieve: appends d to divisors[m] for every multiple m of d.
    void pre(){
        for(long long d = 1; d <= MAX; d++){
            for(long long m = d; m <= MAX; m += d){
                divisors[m].push_back(d);
            }
        }
    }

    // Euler tour from x, whose parent is p. Entry times are handed out in
    // increasing order, so every numTimes[w] ends up sorted without an
    // explicit sort.
    void dfs(long long x, long long p){

        in[x] = (++tim);
        numTimes[values[x]].push_back(in[x]);

        for(long long child : adj[x]){
            if(child != p){
                dfs(child, x);
            }
        }

        out[x] = tim;
    }

    // Number of elements of the sorted vector v that lie in [t1, t2].
    long long count(const std::vector<long long> &v, long long t1, long long t2) const {
        return std::upper_bound(v.begin(), v.end(), t2)
             - std::lower_bound(v.begin(), v.end(), t1);
    }

    // Reads the tree and the queries from std::cin and prints one line per
    // query: the node with the greater power, or "Draw".
    void solve(){

        divisors.assign(MAX + 1, std::vector<long long>());
        pre();
        std::cin >> n >> q;

        numTimes.assign(MAX + 1, std::vector<long long>());
        values.assign(n + 1, 0);
        in.assign(n + 1, 0);
        out.assign(n + 1, 0);
        adj.assign(n + 1, std::vector<long long>());
        tim = 0;

        for(long long i = 1; i <= n; i++){
            std::cin >> values[i];
        }

        // The tree is undirected: store every edge in both directions.
        for(long long i = 1; i < n; i++){
            long long a, b;
            std::cin >> a >> b;
            adj[a].push_back(b);
            adj[b].push_back(a);
        }

        dfs(1, 0);

        for(long long remaining = q; remaining > 0; remaining--){

            long long u, v, x;
            std::cin >> u >> v >> x;

            long long cnt_u = 0, cnt_v = 0;

            for(long long d : divisors[x]){
                cnt_u += count(numTimes[d], in[u], out[u]);
                cnt_v += count(numTimes[d], in[v], out[v]);
            }

            if(cnt_u == cnt_v){
                std::cout << "Draw" << "\n";
            }
            else if(cnt_u > cnt_v){
                std::cout << u << "\n";
            }
            else{
                std::cout << v << "\n";
            }

        }

    }

};

signed main(){

fast;
ll t=1;

// freopen("input.txt" , "r" , stdin);

// clock_t clk = clock();

Solution S;
S.solve();

// ttime;

return 0;

}

// Mischief Managed //