Problem statement
Contest source
Author: Abhinav
Editorialist : Raghav Agarwal
Tester : Miten Shah
DIFFICULTY:
Easy-medium
PREREQUISITES:
DFS
PROBLEM:
You are given a tree in which every node has a value from $\{-2,-1,1,2\}$. The value of a simple path between two nodes is the product of the values of all nodes on the path, including both endpoints. Find the maximum value over all paths (both reference solutions print it modulo $10^9+7$).
EXPLANATION:
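Since every value has absolute value $1$ or $2$, the product of a path is $\pm 2^k$, where $k$ is the number of nodes on the path whose value is $\pm 2$. So for each candidate path we only need its sign and its exponent $k$. The answer is $2^k$ for the largest $k$ among paths with a positive product. We run a DFS on the tree and, for each node, consider the best path whose topmost node (the lowest common ancestor of its endpoints) is the current node. Such a path joins at most two downward paths coming from two different children. The path may also end at the current node, in which case only one child path, or none, is used. If the current node is positive, we combine either the two largest positive child paths or the two largest negative ones. If the current node is negative, we combine the largest negative child path with the largest positive one. If both of these come from the same child, we replace one of them with the second-largest path of the same sign from another child and take the better of the two options.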
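To do this in linear time, each DFS call returns the largest exponents $k$ of a positive and of a negative path that start at its node and go down into its subtree. The parent then only needs the best and second-best of these values over its children.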
#SOLUTION :
C++ Solution (Author's)
#include <bits/stdc++.h>
using namespace std;
// Adjacency lists of the tree; cleared and resized to n+1 for each test case.
vector<vector<int>>adj;
// val[i] is the value of node i (1-indexed), one of -2, -1, 1, 2.
vector<int>val;
// Largest k such that some path found so far has product +2^k; set by dfs().
int ans;
// Processes the subtree rooted at `node` (entered from `parent`).
// Every |value| is 1 or 2, so a path's product is +-2^k where k counts the
// nodes with |value| == 2; all quantities below are such exponents k.
// Returns (max_pos, max_neg): the largest k of a path that starts at `node`
// and goes down into its subtree with a positive / negative product;
// max_neg == -1 means no negative downward path exists.
// Side effect: raises the global `ans` with the best positive path whose
// topmost node is `node`.
pair<int,int> dfs(int node,int parent=-1){
// pos/neg collect (child's best exponent, child id) for each child.
vector<pair<int,int>>pos,neg;
for(int &it:adj[node]){
if(it!=parent){
auto p=dfs(it,node);
pos.push_back({p.first,it});
if(p.second!=-1)
neg.push_back({p.second,it});
}
}
// Candidate "arms" hanging below `node`: ((exponent, sign), source id) with
// sign 1 = positive product, 0 = negative. The two seed entries are empty arms
// (exponent 0, positive) with distinct fake ids -1 and -2, so the pair loop
// below also covers "node alone" and "node plus a single child arm".
vector<pair<pair<int,int>,int>>v={{{0,1},-1},{{0,1},-2}};
// Pass c == 1 works on pos; after the swap, pass c == 0 works on neg (a second
// swap restores both). Each pass keeps the best arm and the best arm from a
// different child: an optimal pair never needs more than the top two per sign.
for(int c=1;c>=0;c--){
if(pos.size()){
auto a=max_element(pos.begin(),pos.end());
v.push_back({{a->first,c},a->second});
if(pos.size()>1){
// b = best entry other than a, i.e. from a different child.
pair<int,int>b={-1,-1};
for(auto it=pos.begin();it!=pos.end();it++){
if(it!=a)
b=max(b,*it);
}
v.push_back({{b.first,c},b.second});
}
}
swap(pos,neg);
}
// Join two arms from different sources through `node`. The XOR is nonzero
// exactly when (sign of x) * (sign of y) * (sign of node) is positive;
// a node with |value| == 2 contributes one more factor of 2.
for(auto &x:v){
for(auto &y:v){
if(x.second==y.second)
continue;
if(x.first.second^y.first.second^(val[node]>0))
ans=max(ans,x.first.first+y.first.first+(abs(val[node])>1));
}
}
// Extend a single arm upward by `node`; a negative node flips the arm's sign.
// max_pos starts at 0, so it may report 0 even when every downward path is
// negative (e.g. a leaf with value -1). This is harmless for the parent, whose
// own empty arms already provide a positive arm of exponent 0.
int max_pos=0,max_neg=-1;
for(auto &it:v){
if(val[node]>0){
if(it.first.second==1)
max_pos=max(max_pos,it.first.first+(val[node]==2));
else
max_neg=max(max_neg,it.first.first+(val[node]==2));
}
else{
if(it.first.second==1)
max_neg=max(max_neg,it.first.first+(val[node]==-2));
else
max_pos=max(max_pos,it.first.first+(val[node]==-2));
}
}
return make_pair(max_pos,max_neg);
}
// Reads T test cases. For each one, builds the tree, runs dfs() from node 1
// and prints 2^ans modulo 1e9+7; a single-node tree prints its own value
// (shifted into [0, mod) when negative).
int main(){
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    const int mod = 1e9 + 7;

    int tests;
    cin >> tests;
    for (; tests != 0; --tests) {
        int n;
        cin >> n;

        adj.clear();
        adj.resize(n + 1);
        for (int edge = 1; edge < n; ++edge) {   // a tree has n - 1 edges
            int u, w;
            cin >> u >> w;
            adj[u].push_back(w);
            adj[w].push_back(u);
        }

        val.resize(n + 1);
        for (int node = 1; node <= n; ++node) {
            cin >> val[node];
        }

        if (n == 1) {
            cout << (val[1] + mod) % mod << '\n';
            continue;
        }

        ans = 0;
        dfs(1);

        // The best product is 2^ans; build it by repeated doubling.
        int result = 1;
        for (int remaining = ans; remaining > 0; --remaining) {
            result = result * 2 % mod;
        }
        cout << result << '\n';
    }
}
C++ Tester's Solution
// created by mtnshh
#include<bits/stdc++.h>
using namespace std;
// Contest shorthand: type aliases and container helpers used by the tester's solution.
#define ll long long int
#define pb push_back
#define rb pop_back
#define ti tuple<int, int, int>
#define pii pair<int, int>
#define pli pair<ll, int>
#define pll pair<ll, ll>
#define mp make_pair
#define mt make_tuple
// rep: ascending loop over [a, b); repb: descending loop from a down to b inclusive.
// Both use an ll index.
#define rep(i,a,b) for(ll i=a;i<b;i++)
#define repb(i,a,b) for(ll i=a;i>=b;i--)
// Debug printers (not used in this solution). errA expands to two statements
// (the loop and a trailing cout), so it is unsafe as the body of an unbraced
// if/for. err() already ends in ';'.
#define err() cout<<"--------------------------"<<endl;
#define errA(A) for(auto i:A) cout<<i<<" ";cout<<endl;
#define err1(a) cout<<#a<<" "<<a<<endl
#define err2(a,b) cout<<#a<<" "<<a<<" "<<#b<<" "<<b<<endl
#define err3(a,b,c) cout<<#a<<" "<<a<<" "<<#b<<" "<<b<<" "<<#c<<" "<<c<<endl
#define err4(a,b,c,d) cout<<#a<<" "<<a<<" "<<#b<<" "<<b<<" "<<#c<<" "<<c<<" "<<#d<<" "<<d<<endl
// all/allr expand to begin/end and rbegin/rend iterator pairs;
// sort(allr(x)) therefore sorts x in descending order.
#define all(A) A.begin(),A.end()
#define allr(A) A.rbegin(),A.rend()
// ft/sd abbreviate the pair members first/second.
#define ft first
#define sd second
#define V vector<ll>
#define S set<ll>
#define VV vector<V>
#define Vpll vector<pll>
// Redefines endl as a plain newline so `cout << endl` does not flush.
#define endl "\n"
// Strict validator for one integer token read from stdin.
// Accepts an optional single leading '-', then decimal digits with no leading
// zeros (and no "-0"), immediately followed by the delimiter `endd`. Anything
// else (other characters, EOF, an empty token, a repeated '-') fails an assert.
// Returns the value after asserting that it lies in [l, r].
long long readInt(long long l,long long r,char endd){
    long long x = 0;
    int cnt = 0;    // number of digits read so far
    int fi = -1;    // first digit; -1 until one has been read
    bool is_neg = false;
    while (true) {
        // int, not char: keeps EOF distinct from byte 0xFF on every platform.
        const int g = getchar();
        if (g == '-') {
            // A sign is allowed only once and only before the first digit.
            assert(fi == -1 && !is_neg);
            is_neg = true;
            continue;
        }
        if ('0' <= g && g <= '9') {
            const int d = g - '0';
            if (cnt == 0) {
                fi = d;
            }
            cnt++;
            assert(fi != 0 || cnt == 1);          // no leading zeros
            assert(fi != 0 || is_neg == false);   // no "-0"
            // Reject too-long numbers *before* multiplying, so x cannot overflow:
            // at most 19 digits, and a 19-digit value must start with 1.
            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
            x = x * 10 + d;
        } else if (g == static_cast<unsigned char>(endd)) {
            assert(cnt > 0);   // the token must contain at least one digit
            if (is_neg) {
                x = -x;
            }
            assert(l <= x && x <= r);
            return x;
        } else {
            assert(false);
        }
    }
}
// Reads characters from stdin up to (not including) the delimiter `endd`,
// consumes the delimiter, and asserts that the token length lies in [l, r].
// Reaching EOF before the delimiter fails an assert.
std::string readString(int l,int r,char endd){
    std::string ret;
    while (true) {
        // getchar() returns int. Storing it in a char makes an EOF check always
        // false where char is unsigned, which would loop forever at end of input.
        const int g = getchar();
        if (g == EOF) {
            assert(false && "readString: unexpected end of input");
            break;   // do not spin forever when asserts are compiled out
        }
        if (g == static_cast<unsigned char>(endd)) {
            break;
        }
        ret += static_cast<char>(g);
    }
    const int cnt = static_cast<int>(ret.size());
    assert(l <= cnt && cnt <= r);
    return ret;
}
// Reads one integer in [l, r] that must be followed by a single space.
long long readIntSp(long long l,long long r){
    const char separator = ' ';
    return readInt(l, r, separator);
}
// Reads one integer in [l, r] that must be the last token on its line.
long long readIntLn(long long l,long long r){
    const char terminator = '\n';
    return readInt(l, r, terminator);
}
// Reads the rest of the line (newline consumed, not stored); length in [l, r].
string readStringLn(int l,int r){
    const char terminator = '\n';
    return readString(l, r, terminator);
}
// Reads a space-terminated token whose length lies in [l, r].
string readStringSp(int l,int r){
    const char separator = ' ';
    return readString(l, r, separator);
}
// Validator limits and solution-wide constants.
constexpr int max_q = 10000;          // upper bound on the number of test cases
constexpr int max_n = 100000;         // upper bound on n in a single test case
constexpr int max_sum_n = 200000;     // upper bound on the sum of n over all tests
constexpr ll M = 1000000007;          // modulus of the printed answer
constexpr ll N = 100005;              // capacity of the per-node arrays (ids 1..n)
constexpr ll INF = 1000000000000LL;   // -INF marks "no such path" in dp
// Computes a^n mod m by binary exponentiation.
// Requires n >= 0 and m >= 1. Intermediate products stay below m*m, so m must
// be below about 3.03e9 to avoid overflowing long long. The base is reduced
// into [0, m) first, so large or negative bases are handled safely.
ll power(ll a,ll n,ll m=M){
    assert(n >= 0);   // a negative n would never reach 0 under n >>= 1
    a %= m;
    if (a < 0) {
        a += m;
    }
    ll result = 1 % m;   // also correct for m == 1, where every value is 0
    while (n > 0) {
        if (n & 1) {
            result = result * a % m;
        }
        a = a * a % m;
        n >>= 1;
    }
    return result;
}
// Per-test state indexed by node id 1..n (N leaves slack above max_n).
// A[v]: value of node v, validated to be one of -2, -1, 1, 2.
ll A[N];
// adj[v]: neighbours of v; cleared at the start of each solve().
V adj[N];
// dp[v] = (largest count of |value| == 2 nodes on a downward path from v with
// positive product, same for negative product), filled by dfs(). Values
// derived from -INF mean "no such path".
pll dp[N];
// Largest exponent k of a positive-product path found so far; answer is 2^k.
ll ans = 0;
// Prepends a node with value x to the downward paths summarised by p, where
// p = (exponent of the best positive-product path, exponent of the best
// negative-product path). A negative x swaps the two signs; |x| == 2 adds one
// factor of 2 to both.
pll change(ll x, pll p){
    switch (x) {
        case 1:  return p;
        case -1: return {p.sd, p.ft};
        case 2:  return {p.ft + 1, p.sd + 1};
        case -2: return {p.sd + 1, p.ft + 1};
    }
    // Unreachable for validated input (solve() asserts every value is a nonzero
    // integer in [-2, 2]). Flowing off the end of a non-void function would be
    // undefined behaviour, so fail loudly in debug builds and return p otherwise.
    assert(false && "change(): node value must be one of -2, -1, 1, 2");
    return p;
}
// Computes dp[n] for the subtree rooted at n (entered from parent p).
// Also raises the global `ans` with the best positive path whose topmost node is n.
// dp[v] = (largest number of |value| == 2 nodes on a downward path from v with
// positive product, same with negative product); values derived from -INF
// mean no such path exists.
void dfs(ll n, ll p){
// Before n's own value is applied, .ft = 0 stands for the path made of n alone.
dp[n] = {0, -INF};
// pos/neg hold (child's dp value, child id) for every child.
Vpll pos, neg;
for(auto i: adj[n]){
if(i == p) continue;
dfs(i, n);
pos.pb({dp[i].ft, i});
neg.pb({dp[i].sd, i});
}
// Descending order: index 0 / 1 are the best and second-best children.
sort(allr(pos));
sort(allr(neg));
ll sz = pos.size();
// Best single-child arm of each sign (signs still relative to the part below n).
if(sz >= 1){
dp[n].ft = max(dp[n].ft, pos[0].ft);
dp[n].sd = max(dp[n].sd, neg[0].ft);
}
// Multiply in n's own value: a negative value swaps signs, |2| adds one to the exponent.
dp[n] = change(A[n], dp[n]);
// Paths through n that join arms from two different children.
if(sz >= 2){
// NOTE: this local `p` shadows the parent parameter, which is not used below.
pll p = {0, -INF};
// Mixed-sign pair: the best pos and best neg arms must come from different children.
if(pos[0].sd == neg[0].sd){
p.sd = max(pos[0].ft + neg[1].ft, pos[1].ft + neg[0].ft);
}
else{
p.sd = pos[0].ft + neg[0].ft;
}
// Same-sign pairs (pos*pos or neg*neg) give a positive product below n.
p.ft = max(pos[0].ft + pos[1].ft, neg[0].ft + neg[1].ft);
p = change(A[n], p);
ans = max(ans, p.ft);
}
// Paths that have n as an endpoint (including n alone).
ans = max(ans, dp[n].ft);
}
// Handles one test case of size n. It reads and validates n - 1 edges, then a
// line of n values, each a nonzero integer in [-2, 2]. It prints the maximum
// path product modulo M; a single negative node prints its own value shifted
// into [0, M).
void solve(ll n){
    for (ll node = 1; node <= n; ++node) {
        adj[node].clear();
    }

    for (ll edge = 0; edge < n - 1; ++edge) {
        const ll from = readIntSp(1, n);
        const ll to = readIntLn(1, n);
        adj[from].pb(to);
        adj[to].pb(from);
    }

    for (ll node = 1; node <= n; ++node) {
        // The last value on the line is terminated by '\n', the others by ' '.
        if (node == n) {
            A[node] = readIntLn(-2, 2);
        } else {
            A[node] = readIntSp(-2, 2);
        }
        assert(A[node] != 0);
    }

    const bool lone_negative = (n == 1) && (A[1] < 0);
    if (lone_negative) {
        cout << (A[1] + M) % M << endl;
        return;
    }

    ans = 0;
    dfs(1, 0);
    cout << power(2, ans) << endl;
}
int main(){
ios_base::sync_with_stdio(0);cin.tie(0);cout.tie(0);
ll q = readIntLn(1, max_q);
ll sum_n = 0;
while(q--){
ll n = readIntLn(1, max_n);
solve(n);
sum_n += n;
}
assert(sum_n <= max_sum_n);
assert(getchar()==-1);
}