PROBLEM LINK:
EXPLANATION:
Let's make a few observations when the tree is rooted at some node:
- At most one operation needs to be performed on each node of the tree.
- Before an operation is performed on some node, all of its ancestors must already have been made perfect squares.
- Except for the root node, an operation needs to be performed on a node if and only if the product of A[node] and A[parent of that node] is not a perfect square.
- An operation must be performed on the root node if and only if A[root] is not a perfect square.
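To make the third and fourth observations concrete, here is a minimal C++ sketch that counts the operations for one fixed root by walking the tree iteratively. The names `isPerfectSquare` and `operationsForRoot`, the 1-indexed adjacency-list input, and the bound 0 ≤ value ≤ 10^18 on every number tested are assumptions for illustration. The square test uses a `long double` square root followed by an integer correction, which is a swapped-in technique and not the setter's binary search.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Exact perfect-square test: floating-point estimate, then integer correction.
// Assumes 0 <= x <= 1e18 (hypothetical bound) so (r + 1) * (r + 1) cannot overflow.
bool isPerfectSquare(long long x) {
    assert(x >= 0 && x <= 1000000000000000000LL);
    long long r = static_cast<long long>(std::sqrt(static_cast<long double>(x)));
    while (r * r > x) --r;               // estimate was too large
    while ((r + 1) * (r + 1) <= x) ++r;  // estimate was too small
    return r * r == x;
}

// Hypothetical helper: number of operations when the tree is rooted at `root`.
// One operation for the root if A[root] is not a square (observation 4), plus one
// for every other node whose product with its parent is not a square (observation 3).
// adj and A are 1-indexed; index 0 is unused.
int operationsForRoot(const std::vector<std::vector<int>>& adj,
                      const std::vector<long long>& A, int root) {
    int ops = isPerfectSquare(A[root]) ? 0 : 1;
    std::vector<int> parent(adj.size(), 0);
    parent[root] = -1;
    std::vector<int> stack = {root};  // explicit stack avoids deep recursion
    while (!stack.empty()) {
        int u = stack.back();
        stack.pop_back();
        for (int v : adj[u]) {
            if (v == parent[u]) continue;  // in a tree, every other neighbour is a child
            parent[v] = u;
            if (!isPerfectSquare(A[u] * A[v])) ++ops;
            stack.push_back(v);
        }
    }
    return ops;
}
```

For example, on the path 1 - 2 - 3 with A = [4, 9, 2], rooting at node 1 or node 2 needs one operation (on the edge 9 · 2 = 18), while rooting at node 3 needs two, because A[3] = 2 is not a square.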
Let X_i and Y_i denote the two endpoints of the i-th edge. Also, define the function check(P) = 1 if P is not a perfect square, and 0 otherwise.
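A minimal sketch of check(P) exactly as defined above, using an integer-only binary search for ⌊√P⌋. Note that the setter's `check` below returns the opposite boolean (true for a perfect square). The bound 0 ≤ P ≤ 10^18 is a hypothetical assumption chosen so that `mid * mid` never overflows a 64-bit integer.

```cpp
#include <cassert>

// check(P): 1 if P is NOT a perfect square, otherwise 0.
// Assumes 0 <= P <= 1e18 (hypothetical bound), so floor(sqrt(P)) <= 1e9.
int check(long long P) {
    assert(P >= 0 && P <= 1000000000000000000LL);
    long long lo = 0, hi = 1000000000;  // invariant: lo * lo <= P, answer in [lo, hi]
    while (lo < hi) {
        long long mid = lo + (hi - lo + 1) / 2;  // upper midpoint so lo always advances
        if (mid * mid <= P) lo = mid;
        else hi = mid - 1;
    }
    return lo * lo == P ? 0 : 1;  // lo == floor(sqrt(P))
}
```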
Then, from the third and fourth observations, the number of operations F(x) required when the tree is rooted at node x is:
F(x) = check(A[x]) + Σ_{i=1}^{N-1} check(A[X_i] · A[Y_i])
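Because the edge sum in this formula does not depend on x, all N values of F can be computed with only N + (N - 1) perfect-square tests: evaluate the sum once, then add check(A[x]) for each node. A minimal sketch; `allRootCosts`, the edge-list input, and the bound 0 ≤ value ≤ 10^18 are hypothetical assumptions.

```cpp
#include <cassert>
#include <cmath>
#include <utility>
#include <vector>

// Exact perfect-square test (long double estimate plus integer correction).
// Assumes 0 <= x <= 1e18 (hypothetical bound).
bool isPerfectSquare(long long x) {
    assert(x >= 0 && x <= 1000000000000000000LL);
    long long r = static_cast<long long>(std::sqrt(static_cast<long double>(x)));
    while (r * r > x) --r;
    while ((r + 1) * (r + 1) <= x) ++r;
    return r * r == x;
}

// Returns F[1..n] (index 0 unused). A is 1-indexed; edges hold the N - 1 pairs (X_i, Y_i).
std::vector<long long> allRootCosts(int n, const std::vector<long long>& A,
                                    const std::vector<std::pair<int, int>>& edges) {
    assert(static_cast<int>(A.size()) == n + 1);
    long long edgeSum = 0;  // sum of check(A[X_i] * A[Y_i]) over all edges
    for (const auto& [x, y] : edges)
        if (!isPerfectSquare(A[x] * A[y])) ++edgeSum;
    std::vector<long long> F(n + 1, 0);
    for (int v = 1; v <= n; ++v)
        F[v] = edgeSum + (isPerfectSquare(A[v]) ? 0 : 1);  // + check(A[v])
    return F;
}
```

On the path 1 - 2 - 3 with A = [4, 9, 2], this gives F = [1, 1, 2], matching a direct count of operations for each root.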
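The sum over the edges does not depend on the chosen root; only the first term check(A[x]) changes, and it is either 0 or 1. Hence, over all nodes x ∈ [1, N], F(x) takes at most 2 distinct values, and when there are two, they are consecutive integers. Consecutive integers are coprime, since any common divisor of a and a+1 also divides their difference, 1.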
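The claim can be checked mechanically: group the root costs by value and verify that there are at most two keys, that they differ by 1, and that their gcd is 1. A minimal sketch; `groupCosts` is a hypothetical helper that takes already computed values F(1), ..., F(N), and `std::gcd` requires C++17.

```cpp
#include <cassert>
#include <map>
#include <numeric>  // std::gcd (C++17)
#include <vector>

// Hypothetical helper: counts how many roots have each cost value and
// asserts the structure argued above (at most two keys, a and a + 1, coprime).
std::map<long long, int> groupCosts(const std::vector<long long>& F) {
    std::map<long long, int> cnt;
    for (long long f : F) ++cnt[f];
    assert(cnt.size() <= 2);
    if (cnt.size() == 2) {
        long long a = cnt.begin()->first;   // smaller value
        long long b = cnt.rbegin()->first;  // larger value
        assert(b == a + 1);
        assert(std::gcd(a, b) == 1);        // any common divisor divides b - a = 1
    }
    return cnt;
}
```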
Let a denote the smaller of these values; the other one, if it exists, is a+1. If a = 0, the answer must be 0. Otherwise, to get the maximum answer, the nodes with F = a, and separately the nodes with F = a+1, have to be divided as equally as possible between the sets S_1 and S_2.
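Following the setter's code, the value printed for a ≥ 1 is the product, over each cost value x occurring cnt times, of x^⌊cnt/2⌋ modulo 10^9 + 7. A minimal sketch that swaps the setter's repeated multiplication for binary exponentiation; `powMod`, `combine`, and the map input are hypothetical names, and the sketch assumes every key is at least 1 (the a = 0 case is handled separately).

```cpp
#include <cassert>
#include <map>

const long long MOD = 1000000007LL;  // same modulus as the setter's `hell`

// Binary exponentiation: base^exp mod `mod` in O(log exp) multiplications.
long long powMod(long long base, long long exp, long long mod) {
    long long result = 1 % mod;
    base %= mod;
    while (exp > 0) {
        if (exp & 1) result = result * base % mod;
        base = base * base % mod;
        exp >>= 1;
    }
    return result;
}

// For every cost value x that occurs cnt times, multiply x into the result
// floor(cnt / 2) times, as the setter's final loop does.
// Assumes every key x >= 1 (hypothetical precondition; a = 0 is handled separately).
long long combine(const std::map<long long, int>& costCount) {
    long long res = 1;
    for (const auto& [x, cnt] : costCount) {
        assert(x >= 1);
        res = res * powMod(x, cnt / 2, MOD) % MOD;
    }
    return res;
}
```

For instance, if five roots have cost 2 and four have cost 3, the result is 2^2 · 3^2 = 36.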
SOLUTION:
Setter's Solution
```cpp
#include <bits/stdc++.h>
using namespace std;

long long int A[100005], inf = 100000000;  // 10^8: upper bound for the square-root search
const long long int hell = 1000000007;     // 10^9 + 7

// Returns true if x is a perfect square (binary search for floor(sqrt(x)))
bool check(long long int x){
    long long int l = 1, r = inf;
    while(l <= r){
        long long int m = (l + r) / 2;
        if(m * m <= x){
            l = m + 1;
        }
        else{
            r = m - 1;
        }
    }
    // r is now floor(sqrt(x)), capped at inf
    return r * r == x;
}

int main(){
    int T;
    cin >> T;
    while(T--){
        int n;
        cin >> n;
        for(int i = 1; i <= n; i++){
            cin >> A[i];
        }
        // sum = number of edges whose endpoint product is not a perfect square
        int sum = 0;
        for(int i = 1; i <= n - 1; i++){
            int a, b;
            cin >> a >> b;
            if(!check(A[a] * A[b])){
                sum += 1;
            }
        }
        // m[v] = number of roots i with F(i) = v:
        // F(i) = sum + 1 if A[i] is not a perfect square, otherwise sum
        map<int, int> m;
        for(int i = 1; i <= n; i++){
            if(!check(A[i])){
                m[sum + 1] += 1;
            }
            else{
                m[sum] += 1;
            }
        }
        // Multiply each value x into the answer floor(cnt / 2) times
        long long int res = 1;
        for(auto it = m.begin(); it != m.end(); it++){
            int x = it->first, cnt = it->second;
            for(int cur = 0; cur < cnt / 2; cur++){
                res = (res * x) % hell;
            }
        }
        res = max(res, 1LL);
        cout << res << "\n";
    }
}
```