
# EXPLANATION:

Let's make a few observations, assuming the tree is rooted at some node:

1. At most one operation needs to be performed on each node of the tree.
2. Before an operation is performed on some node, all of its ancestors must already have been made perfect squares.
3. Except for the root, an operation needs to be performed on a node if and only if the product of A[node] and A[parent of that node] is not a perfect square.
4. An operation must be performed on the root if and only if A[root] is not a perfect square.
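Observations 3 and 4 already give a direct way to count the operations for one fixed root: walk the tree from the root and test every parent-child product. The sketch below assumes 1-indexed nodes, an adjacency list, and values small enough that every product fits in a `long long`; `isSquare` and `opsForRoot` are hypothetical names, not taken from the setter's code. Running it once per root costs O(N) per root, so O(N^2) for all roots, which is what the formula that follows avoids.

```cpp
#include <bits/stdc++.h>
using namespace std;

// Exact perfect-square test for x >= 0 (assumes x <= about 1e18).
bool isSquare(long long x) {
    long long r = (long long)std::sqrt((long double)x);  // floating-point estimate
    while (r * r > x) --r;                               // estimate too large
    while ((r + 1) * (r + 1) <= x) ++r;                  // estimate too small
    return r * r == x;
}

// Hypothetical helper: number of operations when the tree is rooted at `root`,
// counted with a BFS using observations 3 and 4. adj is 1-indexed (adj[0] unused).
int opsForRoot(const vector<long long>& A, const vector<vector<int>>& adj, int root) {
    int n = (int)adj.size() - 1;
    vector<bool> seen(n + 1, false);
    queue<int> q;
    q.push(root);
    seen[root] = true;
    int ops = isSquare(A[root]) ? 0 : 1;  // observation 4: the root itself
    while (!q.empty()) {
        int u = q.front();
        q.pop();
        for (int v : adj[u]) {
            if (seen[v]) continue;  // skip the parent
            seen[v] = true;
            // observation 3: child v needs an operation iff A[v] * A[parent] is not a square
            if (!isSquare(A[v] * A[u])) ++ops;
            q.push(v);
        }
    }
    return ops;
}
```

For the path 1 - 2 - 3 with values 2, 8, 4, the product 2 · 8 = 16 is a square and 8 · 4 = 32 is not, so rooting at node 3 (value 4, already a square) needs only one operation, while roots 1 and 2 need two.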

Let's denote the vertices connected by the $i$-th edge as $X_i$ and $Y_i$. Also, define the function $\text{check}(P) = 1$ if $P$ is not a perfect square, and $\text{check}(P) = 0$ otherwise.
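The setter's code implements the perfect-square test with a binary search for $\lfloor\sqrt{P}\rfloor$. An alternative is a floating-point square root followed by an exact integer correction; the sketch below writes it in the editorial's convention (returns 1 when $P$ is *not* a square). It assumes $0 \le P \le 10^{18}$ so that `(r + 1) * (r + 1)` cannot overflow.

```cpp
#include <bits/stdc++.h>
using namespace std;

// check(P): 1 if P is NOT a perfect square, otherwise 0.
// Assumes 0 <= P <= 1e18.
int check(long long P) {
    long long r = (long long)std::sqrt((long double)P);  // estimate of floor(sqrt(P))
    while (r * r > P) --r;                               // fix an estimate that is too large
    while ((r + 1) * (r + 1) <= P) ++r;                  // fix an estimate that is too small
    return r * r == P ? 0 : 1;
}
```

The two correction loops make the result exact even when the floating-point estimate is off by one, which can happen once $P$ exceeds $2^{53}$ on platforms where `long double` is the same as `double`.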

Then, from observations 3 and 4, the number of operations $F(x)$ needed when the tree is rooted at node $x$ is

$$F(x) = \text{check}(A[x]) + \sum_{i=1}^{N-1} \text{check}(A[X_i] \cdot A[Y_i])$$

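In this equation, the sum over the edges does not depend on the root $x$; only the term $\text{check}(A[x]) \in \{0, 1\}$ does. Writing $s$ for the edge sum, $F(i) \in \{s, s + 1\}$ for every node $i \in [1, N]$, so $F$ takes at most 2 distinct values, and they are consecutive integers. Consecutive integers are always coprime, since any common divisor of $a$ and $a + 1$ also divides their difference, which is $1$.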
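Because the edge sum in $F(x)$ is the same for every root, all $N$ values of $F$ can be computed with $(N - 1) + N$ square checks instead of one traversal per root. A minimal sketch, assuming 1-indexed values `A` (index 0 unused) and an edge list of pairs $(X_i, Y_i)$; `isSquare` and `allF` are hypothetical names.

```cpp
#include <bits/stdc++.h>
using namespace std;

// Exact perfect-square test for x >= 0 (assumes x <= about 1e18).
bool isSquare(long long x) {
    long long r = (long long)std::sqrt((long double)x);
    while (r * r > x) --r;
    while ((r + 1) * (r + 1) <= x) ++r;
    return r * r == x;
}

// Returns F(1..N) (F[0] unused). The edge term s is computed once,
// then each root only adds check(A[x]), so every F(x) is s or s + 1.
vector<int> allF(const vector<long long>& A, const vector<pair<int, int>>& edges) {
    int n = (int)A.size() - 1;
    int s = 0;
    for (const auto& e : edges)
        if (!isSquare(A[e.first] * A[e.second])) ++s;
    vector<int> F(n + 1, 0);
    for (int x = 1; x <= n; ++x)
        F[x] = s + (isSquare(A[x]) ? 0 : 1);
    return F;
}
```

On the path 1 - 2 - 3 with values 2, 8, 4, this gives $s = 1$ and $F = (2, 2, 1)$: exactly the two consecutive values $s$ and $s + 1$.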

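Let's denote the smaller of these 2 numbers as $a$; the other one, if it occurs, is $a + 1$. If $a = 0$, the answer must be 0. Otherwise, to get the maximum answer, we have to divide the occurrences of $a$ and of $a + 1$ equally between the sets $S_1$ and $S_2$. If $c_v$ denotes the number of nodes $i$ with $F(i) = v$, this gives $a^{\lfloor c_a/2 \rfloor} \cdot (a+1)^{\lfloor c_{a+1}/2 \rfloor}$ modulo $10^9 + 7$, which is the product the setter's solution computes.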
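The setter's code builds this product by multiplying each value $v$ into the result $\lfloor c_v/2 \rfloor$ times, one multiplication at a time. A sketch of the same computation with binary exponentiation, assuming `F` is a 1-indexed vector holding $F(1), \dots, F(N)$ (index 0 unused); `power` and `answerFromF` are hypothetical names.

```cpp
#include <bits/stdc++.h>
using namespace std;

const long long MOD = 1000000007LL;

// Binary exponentiation: b^e mod MOD.
long long power(long long b, long long e) {
    long long r = 1;
    b %= MOD;
    while (e > 0) {
        if (e & 1) r = r * b % MOD;
        b = b * b % MOD;
        e >>= 1;
    }
    return r;
}

// Counts how often each value appears among F(1..N), then multiplies
// value^(count / 2) for every distinct value, as the setter's loop does.
long long answerFromF(const vector<int>& F) {
    map<long long, long long> cnt;
    for (size_t i = 1; i < F.size(); ++i) ++cnt[F[i]];
    long long res = 1;
    for (const auto& kv : cnt)
        res = res * power(kv.first, kv.second / 2) % MOD;
    return max(res, 1LL);  // mirrors res = max(res, 1) in the setter's code
}
```

For example, $F = (3, 3, 3, 3, 4, 4, 4)$ gives $3^2 \cdot 4^1 = 36$. With only two distinct values, the setter's repeated-multiplication loop is already O(N), so binary exponentiation is a stylistic choice here rather than a necessary speedup.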

# SOLUTION:

Setter's Solution
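```cpp
#include <bits/stdc++.h>
using namespace std;

// Upper bound for the square-root binary search: handles x up to inf * inf = 1e16.
const long long inf = 100000000LL;
// Modulus for the final product.
const long long hell = 1000000007LL;

// Returns true if x is a perfect square (binary search for floor(sqrt(x))).
bool check(long long x) {
    long long l = 1, r = inf;

    while (l <= r) {
        long long m = (l + r) / 2;
        if (m * m <= x) {
            l = m + 1;
        } else {
            r = m - 1;
        }
    }

    // r is now the largest value with r * r <= x.
    return r * r == x;
}

int main() {
    int T;
    cin >> T;

    while (T--) {
        int n;
        cin >> n;

        vector<long long> A(n + 1);  // 1-indexed node values
        for (int i = 1; i <= n; i++) {
            cin >> A[i];
        }

        // sum = number of edges whose endpoint product is not a perfect square.
        // This part of F(x) is the same for every root x.
        int sum = 0;
        for (int i = 1; i <= n - 1; i++) {
            int a, b;
            cin >> a >> b;
            if (!check(A[a] * A[b])) {
                sum += 1;
            }
        }

        // m[v] = number of roots x with F(x) = v.
        // Rooting at x costs one extra operation if A[x] is not a perfect square.
        map<int, int> m;
        for (int i = 1; i <= n; i++) {
            if (!check(A[i])) {
                m[sum + 1] += 1;
            } else {
                m[sum] += 1;
            }
        }

        // Split each value equally between S_1 and S_2:
        // multiply value x into the result floor(cnt / 2) times.
        long long res = 1;
        for (auto it = m.begin(); it != m.end(); it++) {
            int x = it->first, cnt = it->second, cur = 0;
            while (cur < cnt / 2) {
                res = (res * x) % hell;
                cur++;
            }
        }

        res = max(res, 1LL);

        cout << res << "\n";
    }
}
```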