PROBLEM LINK:
Author: Onkar Ratnaparkhi
Tester: Saptarshi Shome
Editorialist: Onkar Ratnaparkhi
DIFFICULTY:
EASY-MEDIUM
PREREQUISITES:
Euler tour (in/out times) of a tree, binary search, basic number theory
PROBLEM:
You are given a tree, rooted at node 1. Each node is given a value, as described by array A.
There are queries of the form (u, v, X): find which of the nodes u and v has more power, or report "Draw" if their powers are equal.
Power of a node is defined as, the “Count of nodes in its subtree whose value divides X”.
QUICK EXPLANATION:
For every value, store the in-times of the nodes carrying it. To answer a query, iterate over all divisors of X and count the occurrences of each divisor in the subtrees of u and v, using their in/out times and binary search.
EXPLANATION:
First of all, note that we can store all the divisors of every number from 1 to $X_{max}$ in $O(X_{max} \log X_{max})$ time: for every $i$, append $i$ to the divisor list of each of its multiples $i, 2i, 3i, \dots \le X_{max}$.
Why?
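The number $i$ has $\lfloor n/i \rfloor$ multiples not exceeding $n$, so the total work is $\frac{n}{1} + \frac{n}{2} + \frac{n}{3} + \dots + \frac{n}{n} = n \cdot H_n \approx n \ln n$.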
Now, let’s solve the problem.
We will store the following things while doing the dfs:
- in-out times of all the nodes.
- the in-times grouped by node value, i.e. numTimes[val] is an array holding the in-times of all nodes whose value is val.
Note that the “value” of a node is the one given in the input, while the “number” of a node is its index, e.g. node 1 has value 124, node 2 has value 2423, etc.
Now, let’s solve the queries online: (u,v, X)
We simply iterate over all the stored divisors of X. Say we are currently at divisor $d_i$.
We need to count its occurrences in the subtree of u and in the subtree of v. We can do it in this way:
- For u, we need to count how many entries of numTimes[d_i] lie in the range [inTime[u], outTime[u]]: a node is in the subtree of u exactly when its in-time falls in this range.
- First make sure numTimes[d_i] is in sorted order. This comes for free, because in-times are appended in increasing order during the Euler tour.
- Now just apply binary search on numTimes[d_i] using inTime[u] and outTime[u]. Do the same for v.
In C++ this can be done as follows:
// Number of entries of the sorted vector v lying in the closed range [t1, t2].
// In the editorial, v is numTimes[d] for the current divisor d (the sorted in-times of
// all nodes whose value is d), t1 is inTime[u] and t2 is outTime[u].
int count(const std::vector<int> &v, int t1, int t2){
    if (t1 > t2) return 0;  // empty range
    int cnt = std::upper_bound(v.begin(), v.end(), t2) - std::lower_bound(v.begin(), v.end(), t1);
    return cnt;
}
Now all that is left is to sum these counts over all divisors of X to get the powers of node u and node v, and compare them.
Code Snippet for queries
while(q--){
    int u, v, x;
    cin >> u >> v >> x;
    // power_u / power_v: nodes in the subtree of u / v whose value divides x.
    int power_u = 0, power_v = 0;
    for (auto d : divisors[x]) {
        power_u += count(numTimes[d], in[u], out[u]);
        power_v += count(numTimes[d], in[v], out[v]);
    }
    if (power_u > power_v) {
        cout << u << endl;
    } else if (power_v > power_u) {
        cout << v << endl;
    } else {
        cout << "Draw" << endl;
    }
}
Time Complexity
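$O(X_{max} \log X_{max} + n + q \cdot D \cdot \log n)$, where $D$ is the maximum number of divisors of any $X \le X_{max}$ (for $X_{max} = 2 \cdot 10^5$, $D = 160$).
$X_{max}$ is the maximum value X in a query can take, n is the number of nodes and q is the number of queries.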
ALTERNATE EXPLANATION:
We can also solve this problem offline using DSU on tree (small-to-large merging).
Short Explanation
- Store all the queries.
- Store divisors of all possible X.
- Perform a DFS following the small-to-large technique: keep the value counts of the heavy child's subtree and re-add the light children's subtrees.
- While at node i (when the value counts of its whole subtree are available), iterate over all divisors of every X from the queries involving node i, and store the power of node i for each such X (once per distinct X).
- Solve the queries offline, using the stored data.
SOLUTIONS:
Setter's Solution
// I solemnly swear that I am upto no good //
#include <bits/stdc++.h>
using namespace std;
// ---- Shorthand macros (competitive-programming template) ----
// sub: redirects stdin from input.txt when invoked (its use in main is commented out).
#define sub freopen("input.txt", "r", stdin);//freopen("output.txt", "w", stdout);
#define ll long long
#define ull unsigned long long
#define ld long double
// ttime: prints elapsed CPU time to stderr; expects a clock_t named clk in scope.
#define ttime {cerr << '\n'<<"Time (in s): " << double(clock() - clk) * 1.0 / CLOCKS_PER_SEC << '\n';}
// Template heads whose type parameters default to long long.
#define helpUs template<typename T = ll , typename U = ll>
#define helpMe template<typename T = ll>
#define pb push_back
#define sz(x) (int)((x).size())
// fast: unsyncs C++ streams from C stdio and unties cin from cout.
#define fast ios_base::sync_with_stdio(false);cin.tie(0);
#define all(x) (x).begin(),(x).end()
#define rep(i,a,b) for(ll i=a;i<b;i++)
// pr(x): debug print of "x = <value>".
#define pr(x) cout << #x " = " << (x) << "\n"
#define mp make_pair
#define ff first
#define ss second
#define YY cout<<"Yes"<<endl
#define NN cout<<"No"<<endl
#define ppc __builtin_popcount
#define ppcll __builtin_popcountll
// #include <ext/pb_ds/assoc_container.hpp>
// #include <ext/pb_ds/tree_policy.hpp>
// using namespace __gnu_pbds;
// #define ordered_set tree<int, null_type,less<int>, rb_tree_tag,tree_order_statistics_node_update>
// order_of_key and find_by_order
const long long INF=1e18;
const long long N=200005;
const long long mod=1000000007; // 998244353, 2971215073, 1000050131, 433494437
// endl is redefined to a plain newline, so writing endl below does not flush the stream.
#define endl "\n"
// Every later `int` means long long; this is why the entry point is declared `signed main`.
#define int ll
typedef pair<ll,ll> pairll;
typedef map<ll,ll> mapll;
typedef map<char,ll> mapch;
typedef vector<ll> vll;
// 64-bit Mersenne Twister seeded from the steady clock.
mt19937_64 rng(std::chrono::steady_clock::now().time_since_epoch().count());
// Generic comparison and stream helpers (type parameters default to long long).
// comp: comparator object returning a > b.
template<typename T = long long, typename U = long long>
class comp{public:bool operator()(T a, U b) const {return a>b;}};
// Reads a pair as two whitespace-separated values from the stream aa.
template<typename T = long long, typename U = long long>
std::istream& operator>>(std::istream& aa, std::pair<T,U> &p){aa>>p.first>>p.second;return aa;}
// Writes every element of v, each followed by a single space.
template<typename T = long long>
std::ostream& operator<<(std::ostream& ja, const std::vector<T> &v){for(const auto& it:v)ja<<it<<" ";return ja;}
// Fills every element of v from the stream aa (the stream actually passed in, not std::cin).
template<typename T = long long>
std::istream& operator>>(std::istream& aa, std::vector<T> &v){for(auto &it:v)aa>>it;return aa;}
// Writes a pair as "first second".
template<typename T = long long, typename U = long long>
std::ostream& operator<<(std::ostream& ja, const std::pair<T,U> &p){ja<<p.first<<" "<<p.second;return ja;}
// File-scope variables; Solution declares members with the same names, which shadow these inside it.
ll n;
ll k;
ll tt = 1;
// One query: compare the powers of nodes u and v with respect to x.
struct Q{
    int u,v,x;
    Q(){}  // members are left uninitialized
    Q(int a, int b, int c) : u(a), v(b), x(c) {}
};
struct Solution{
ll n,k,tt=1;
vector<int> value;
vector<int> sz;
vector<int> cnt;
map<pairll, int> ans;
vector<vll> correspondingX;
vector<vll> adj;
vector<vll> divisors;
vector<Q> query;
void init(){
divisors = vector<vll> (200005);
}
void pre(){
for(int i=1;i<=200000;i++){
int j=i;
while(j<=200000){
divisors[j].pb(i);
j+=i;
}
}
}
void dfs_size(int x, int p){
for(auto it:adj[x]){
if(it != p){
dfs_size(it,x);
sz[x] += sz[it];
}
}
sz[x]++;
}
void add(int x, int p, int val){
for(auto it:adj[x]){
if(it != p){
add(it,x,val);
}
}
cnt[value[x]] += val;
}
void dfs_cnt(int x, int p, int keep){
int mx=-1, bigChild=-1;
for(auto it:adj[x]){
if(it != p){
if(sz[it] > mx)
mx=sz[it], bigChild=it;
}
}
for(auto it:adj[x]){
if(it != p and it != bigChild){
dfs_cnt(it,x,0);
}
}
if(bigChild != -1){
dfs_cnt(bigChild,x,1);
}
cnt[value[x]]++;
for(auto it:adj[x]){
if(it != p and it != bigChild){
add(it,x,1);
}
}
///////////////////////////////////////////////////////////////////////////////////////////////
for(int s:correspondingX[x]){
for(int j:divisors[s]){
ans[{x,s}] += (cnt[j]);
}
}
///////////////////////////////////////////////////////////////////////////////////////////////
if(keep == 0){
add(x,p,-1);
}
}
void solve(){
ll q;
cin>>n>>q;
value = vll(n+1);
sz = vll(n+1);
cnt = vll(200005);
query = vector<Q> ();
correspondingX = vector<vll> (n+1);
adj = vector<vll> (n+1);
for(int i=1;i<=n;i++)
cin>>value[i];
for(int i=1;i<n;i++){
int u,v;
cin>>u>>v;
adj[u].pb(v);
adj[v].pb(u);
}
for(int i=0;i<q;i++){
int u,v,x;
cin>>u>>v>>x;
Q p(u,v,x);
query.push_back(p);
correspondingX[u].pb(x);
correspondingX[v].pb(x);
}
dfs_size(1,0);
dfs_cnt(1,0,0);
for(int i=0;i<q;i++){
int u=query[i].u;
int v=query[i].v;
int x=query[i].x;
int A = ans[{u,x}];
int B = ans[{v,x}];
if(A>B)cout<<u<<endl;
else if(B>A)cout<<v<<endl;
else cout<<"Draw"<<endl;
}
}
};
void solve(){
Solution S;
S.init();
S.pre();
S.solve();
}
signed main(){
    fast;
    // sub;   // uncomment to read the input from input.txt
    clock_t clk = clock();   // read by the ttime macro
    for (ll t = 1; t > 0; --t) {
        solve();
    }
    ttime;
    return 0;
}
// Mischief Managed //
Tester's Solution
#include<bits/stdc++.h>
using namespace std;
//#include <ext/pb_ds/assoc_container.hpp> // Common file
//#include <ext/pb_ds/tree_policy.hpp> // Including tree_order_statistics_node_update
//using namespace __gnu_pbds; // order of key(keys strictly less than) // find_by_order
//typedef tree<long long,null_type,less<>,rb_tree_tag,tree_order_statistics_node_update> ordered_set;
//typedef tree<long long, null_type, less_equal<>, rb_tree_tag, tree_order_statistics_node_update> indexed_multiset;
//IF WA CHECK FOR : -
// 1 > EDGE CASES LIKE N=1 , N=0
// 2 > SIGNED INTEGER OVERFLOW IN MOD
// 3 > CHECK THE CODE FOR LOGICAL ERRORS AND SEG FAULTS
// 4 > READ THE PS ONCE AGAIN , if having double diff less than 1e-8 is same.
// 5 > You Have got AC .
// ---- Shorthand macros ----
#define ll long long
// Default modulus used by binpow.
#define NUM (ll)998244353
#define inf (long long)(2e18)
#define ff first
#define ss second
// f(i,a,b): ascending loop i = a .. b-1; fr(i,a,b): descending loop i = a .. b.
#define f(i,a,b) for(ll i=a;(i)<long(b);(i)++)
#define fr(i,a,b) for(ll i=a;(i)>=(long long)(b);(i)--)
// it(b): range-for over b, binding each element by reference to the name `it`.
#define it(b) for(auto &it:(b))
#define pb push_back
#define mp make_pair
typedef vector<ll> vll;
typedef pair<ll,ll> pll;
// Computes base^ex modulo mod by binary exponentiation; ex is expected to be >= 0.
// The default modulus 998244353 is the same constant as NUM.
long long binpow(long long base, long long ex, long long mod = 998244353) {
    base %= mod;
    if (base < 0) base += mod;  // C++ '%' keeps the sign of the dividend
    long long ans = 1 % mod;    // x^0 == 1, and every result is 0 when mod == 1
    while (ex > 0) {
        if (ex & 1) ans = ans * base % mod;
        base = base * base % mod;
        ex >>= 1;
    }
    return ans;
}
// Reads n values from standard input into arr. When arr's size differs from n, it is
// first reset to n zeros.
void read(vll &arr,ll n) {
    if (arr.size() != n)
        arr.assign(n, 0);
    for (int idx = 0; idx < n; ++idx)
        cin >> arr[idx];
}
// Returns the smaller of a and b.
inline ll min(ll a,ll b){
    return a > b ? b : a;
}
// Returns the larger of a and b.
inline ll max(ll a, ll b){
    return a > b ? a : b;
}
// Returns |a - b|, always subtracting the smaller value from the larger one.
inline ll dif(ll a,ll b) {
    return a > b ? a - b : b - a;
}
// Greatest common divisor via the Euclidean algorithm (iterative form).
long long gcd(long long a,long long b) {
    while (b != 0) {
        const long long rem = a % b;
        a = b;
        b = rem;
    }
    return a;
}
// Least common multiple of a and b. Dividing before multiplying keeps the intermediate
// value no larger than the result, so it cannot overflow when the result itself fits in
// long long. gcd is 0 only when a == b == 0; that case returns 0 instead of dividing by zero.
long long lcm(long long a,long long b) {
    long long k = gcd(a, b);
    if (k == 0) return 0;
    return a / k * b;
}
// Tree state shared by dfs() and solve(): adjacency lists, node values, Euler-tour times.
vector<vll> adj;
vll val;
vll in;
vll out;
// DFS clock; dfs() increments it before stamping, so the first visited node gets time 1.
ll tim = 0;
// Euler tour from start (parent par): in[start] is its entry time and out[start] the largest
// entry time inside its subtree, so the subtree of start is exactly the set of nodes whose
// in-time lies in [in[start], out[start]].
void dfs(ll start,ll par){
    in[start] = ++tim;
    for (auto &child : adj[start]) {
        if (child == par) continue;
        dfs(child, start);
    }
    out[start] = tim;
}
// Binary search on the sorted array arr. Returns the index of the last element <= x, or -1
// if every element is greater than x. An empty array returns 0; callers only use the
// difference of two calls on the same array, which is 0 either way.
ll fun(vll &arr,ll x){
    if (arr.empty()) return 0;
    ll lo = 0;
    ll hi = static_cast<ll>(arr.size()) - 1;
    ll best = -1;
    while (lo <= hi) {
        const ll mid = lo + (hi - lo) / 2;
        if (arr[mid] <= x) {
            best = mid;  // arr[mid] qualifies; look for a later one
            lo = mid + 1;
        } else {
            hi = mid - 1;
        }
    }
    return best;
}
// fact[j]: every divisor of j (filled in main). times[v]: sorted in-times of nodes with value v (filled in solve).
vector<vll> fact(200001);
vector<vll> times(200001);
// Reads and validates the tree and the queries, then answers each query (u, v, x) by
// summing, over every divisor d of x, how many nodes with value d lie in each subtree.
// Each count is a difference of two binary searches on the sorted in-times in times[d].
void solve() {
    int n,q;cin>>n>>q;
    assert(n>=1 and n<=5*1e4 and q>=1 and q<=1e5);
    adj.resize(n+1);val.resize(n+1);in=val;out=val;
    // ind[a] = nodes whose value is a
    vector<vll>ind(2*1e5+1);
    f(i,0,n){
        ll a;cin>>a;val[i+1]=a;assert(a<=2*1e5 and a>=1);
        ind[a].pb(i+1);
    }
    f(i,0,n-1){
        ll a,b;cin>>a>>b;
        assert(a>=1 and a<=n and b>=1 and b<=n);
        adj[a].pb(b);adj[b].pb(a);
    }
    dfs(1,1);
    f(i,1,2*1e5+1) {
        it(ind[i]) {
            times[i].pb(in[it]);
        }
        sort(times[i].begin(), times[i].end());
    }
    set<ll>z;int xx=0;
    while(q--){
        xx++;
        ll u,v,x;cin>>u>>v>>x;
        z.insert(x);
        assert(u>=1 and u<=n and v>=1 and v<=n and x>=1 and x<=2*1e5);
        ll lef = 0;
        ll rig =0;
        it(fact[x]){
            lef += fun(times[it],out[u])-fun(times[it],in[u]-1);
            rig += fun(times[it],out[v])-fun(times[it],in[v]-1);
        }
        // '\n' rather than std::endl: endl would flush once per answer (up to 1e5 times).
        if(lef<rig){
            cout<<v<<'\n';
        }
        else if(lef==rig){
            cout<<"Draw"<<'\n';
        }
        else{
            cout<<u<<'\n';
        }
    }
    // Every query must use a distinct X.
    assert(xx==static_cast<int>(z.size()));
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
cout << fixed << showpoint;
cout << setprecision(12);
long long test_m = 1;
int k=1;
//cin >> test_m;
//WE WILL WIN .
for(int i=1;i<=2*1e5;i++){
for(ll j=i;j<=2*1e5;j+=i){
fact[j].pb(i);
}
}
while (test_m--) {
//cout<<"Case #"<<k++<<": ";
solve();
}
}
Editorialist's Solution
// I solemnly swear that I am upto no good //
#include <bits/stdc++.h>
using namespace std;
// ---- Shorthand macros (competitive-programming template) ----
// sub: redirects stdin from input.txt when invoked (unused in main).
#define sub freopen("input.txt", "r", stdin);//freopen("output.txt", "w", stdout);
#define ll long long
// ttime: prints elapsed CPU time to stderr; expects a clock_t named clk in scope.
#define ttime {cerr << '\n'<<"Time (in s): " << double(clock() - clk) * 1.0 / CLOCKS_PER_SEC << '\n';}
// Template heads whose type parameters default to long long.
#define helpUs template<typename T = ll , typename U = ll>
#define helpMe template<typename T = ll>
#define pb push_back
#define sz(x) (int)((x).size())
// fast: unsyncs C++ streams from C stdio and unties cin from cout.
#define fast ios_base::sync_with_stdio(false);cin.tie(0);
#define all(x) (x).begin(),(x).end()
// endl is a plain newline here, so writing endl does not flush the stream.
#define endl "\n"
// Every later `int` means long long; this is why the entry point is declared `signed main`.
#define int ll
typedef pair<ll,ll> pairll;
typedef map<ll,ll> mapll;
typedef map<char,ll> mapch;
typedef vector<ll> vll;
// 64-bit Mersenne Twister seeded from the steady clock.
mt19937_64 rng(std::chrono::steady_clock::now().time_since_epoch().count());
// Generic comparison and stream helpers (type parameters default to long long).
// Pair members are accessed as .first/.second directly because this program defines
// no ff/ss shorthand; streams are always the ones passed in, never std::cin.
// comp: comparator object returning a > b.
template<typename T = long long, typename U = long long>
class comp{public:bool operator()(T a, U b) const {return a>b;}};
// Reads a pair as two whitespace-separated values from the stream aa.
template<typename T = long long, typename U = long long>
std::istream& operator>>(std::istream& aa, std::pair<T,U> &p){aa>>p.first>>p.second;return aa;}
// Writes every element of v, each followed by a single space.
template<typename T = long long>
std::ostream& operator<<(std::ostream& ja, const std::vector<T> &v){for(const auto& it:v)ja<<it<<" ";return ja;}
// Fills every element of v from the stream aa.
template<typename T = long long>
std::istream& operator>>(std::istream& aa, std::vector<T> &v){for(auto &it:v)aa>>it;return aa;}
// Writes a pair as "first second".
template<typename T = long long, typename U = long long>
std::ostream& operator<<(std::ostream& ja, const std::pair<T,U> &p){ja<<p.first<<" "<<p.second;return ja;}
// Online solution. A single Euler tour gives every node an in-time and an out-time and,
// for every value, the increasing list of in-times of the nodes carrying that value.
// A query (u, v, X) sums, over all divisors d of X, how many of those in-times fall inside
// u's (resp. v's) [in, out] interval.
struct Solution{
    ll n,k,q;
    vector<vll> numTimes;   // numTimes[val] = in-times of the nodes whose value is val, increasing
    vector<vll> divisors;   // divisors[j] = all divisors of j, for 1 <= j <= MAX
    vll values;             // values[i] = value of node i (1-indexed)
    vector<vll> adj;        // adjacency lists of the tree
    vll in, out;            // in[x] = entry time of x; out[x] = largest entry time in x's subtree
    int tim=0;              // DFS clock; the first visited node gets time 1
    ll MAX = 200000;        // largest value / X covered by divisors[] and numTimes[]

    // Fills divisors[] for 1..MAX: each i is appended to the list of every multiple of i,
    // which takes MAX/1 + MAX/2 + ... + MAX/MAX = O(MAX log MAX) steps.
    void pre(){
        for(int i=1;i<=MAX;i++){
            for(int j=i;j<=MAX;j+=i){
                divisors[j].pb(i);
            }
        }
    }

    // Euler tour from x (parent p). Nodes are stamped in visiting order, so every
    // numTimes[val] is built already sorted, and the subtree of x is exactly the set of
    // nodes whose in-time lies in [in[x], out[x]].
    void dfs(int x, int p){
        in[x] = (++tim);
        int val = values[x];
        numTimes[val].push_back(in[x]);
        for(auto it:adj[x]){
            if(it != p){
                dfs(it,x);
            }
        }
        out[x] = tim;
    }

    // Number of entries of the sorted vector v lying in the closed range [t1, t2].
    int count(const vector<int> &v, int t1, int t2){
        int cnt = upper_bound(all(v) , t2) - lower_bound(all(v) , t1);
        return cnt;
    }

    // Reads the tree, runs the Euler tour, then answers each query as it is read.
    void solve(){
        // Size the divisor table from MAX so it always agrees with the loop bound in pre().
        divisors = vector<vll> (MAX+1);
        pre();
        cin>>n>>q;
        numTimes = vector<vll> (MAX+1);
        values = vll(n+1);
        in = vll(n+1);
        out = vll(n+1);
        for(int i=1;i<=n;i++){
            cin>>values[i];
        }
        adj = vector<vll> (n+1);
        for(int i=1;i<n;i++){
            int a,b;
            cin>>a>>b;
            adj[a].pb(b);
            adj[b].pb(a);
        }
        dfs(1,0);
        int temp=q;
        while(temp--){
            int u,v,x;
            cin>>u>>v>>x;
            // Power of a node = number of nodes in its subtree whose value divides x.
            int cnt_u=0, cnt_v=0;
            for(auto d:divisors[x]){
                cnt_u += count(numTimes[d] , in[u] , out[u]);
                cnt_v += count(numTimes[d] , in[v] , out[v]);
            }
            if(cnt_u == cnt_v){
                cout<<"Draw"<<endl;
            }
            else if(cnt_u > cnt_v){
                cout<<u<<endl;
            }
            else{
                cout<<v<<endl;
            }
        }
    }
};
signed main(){
fast;
ll t=1;
// freopen("input.txt" , "r" , stdin);
// clock_t clk = clock();
Solution S;
S.solve();
// ttime;
return 0;
}
// Mischief Managed //