‘’’
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
using ull = unsigned long long;
using ld = long double;
constexpr int mod = 1000000007;
constexpr ll INF = 2e18;
constexpr int esp = (int)1e-9;
define all(x) x.begin(),x.end()
define rev(x) reverse(all(x))
define int long long
void speed_io(){ios_base::sync_with_stdio(false);cin.tie(0);cout.tie(0);}
int add(int a,int b){ return (a%mod + b%mod)%mod; }
int sub(int a,int b){ return ((a%mod - b%mod) + mod)%mod; }
int mul(int a,int b){ return (a%mod * b%mod)%mod; }
void log(vector&a){for(int x:a)cout<<x<<’ ‘;cout<<’\n’;}
void log(vector<vector >&grid){for(int i=0;i<grid.size();i++){for(int j=0;j<grid[i].size();j++)cout<<grid[i][j]<<’ ‘;cout<<’\n’;}}
void log(set&s){for(auto it:s)cout<<it<<’ ‘;cout<<’\n’;}
void log(vector<pair<int,int> >&a){for(int i=0;i<a.size();i++){cout<<a[i].first<<’ ‘<<a[i].second<<’\n’;}}
void log(map<int,int>&m){for(auto it:m)cout<<it.first<<’ ‘<<it.second<<’\n’;}
void log(set<pair<int,int>>&s){for(auto it:s)cout<<it.first<<’ ‘<<it.second<<’\n’;}
void log(map<int,vector >&m){for(auto it:m){cout<<it.first<<“–>”;for(auto it2:it.second)cout<<it2<<’ ‘;cout<<’\n’;}}
ll power(ll a, ll b)
{
ll res=1;
a=a%mod;
while(b>0)
{
if(b&1){res=mul(res,a);b–;}
a=mul(a,a);
b>>=1;
}
return res;
}
ll fermat_inv(ll y){return power(y,mod-2ll);}
int GCD(int a,int b)
{
while(b)
{
a%=b;
swap(a,b);
}
return a;
}
int LCM(int a,int b)
{
return ((ll)a * b) / GCD(a,b);
}
class SGTree
{
public:
vector<long long>seg;
SGTree(int n)
{
seg.resize(4*n);
}
void build(int idx,int lo,int hi,vector<int>&a)
{
if(lo==hi)
{
seg[idx]=a[lo];
return ;
}
int mid = lo + (hi-lo)/2;
build(2*idx + 1,lo,mid,a);
build(2*idx + 2,mid+1,hi,a);
seg[idx] = seg[2*idx + 1] + seg[2*idx + 2];
}
void update(int idx,int lo,int hi,int i,int val)
{
if(lo==hi)
{
seg[idx]=val;
return ;
}
int mid = lo + (hi-lo)/2;
if(i<=mid)
update(2*idx+1 , lo , mid , i , val);
else
update(2*idx+2 , mid+1 , hi , i , val);
seg[idx] = seg[2*idx+1] + seg[2*idx+2];
}
int query(int idx,int lo,int hi,int l,int r)
{
//[lo hi l r] [l r lo hi] -> No overlapp
if(hi<l || r<lo)
return 0LL;
//[l lo hi r] Complete overlap
if(l<=lo && r>=hi)
return seg[idx];
//Partial overlap
int mid = lo + (hi-lo)/2;
long long left = query(2*idx+1,lo,mid,l,r);
long long right = query(2*idx+2,mid+1,hi,l,r);
return left+right;
}
};
int req[200005];
int moves[200005];
int dp[200005];
vectora;
int fact[2000005];
int factinv[2000005];
void factorial()
{
fact[0]=1;
factinv[0]=1;
for(int i=1;i<2000005;i++)
{
fact[i]=mul(i,fact[i-1]);
factinv[i]=fermat_inv(fact[i]);
}
}
void dfs(int src,vectoradj[],int prev=0)
{
for(auto adj_vertex:adj[src])
{
if(adj_vertex!=prev)
dfs(adj_vertex,adj,src);
}
req[src] = (a[src]>0) ? 1 : 0 ;
for(auto adj_vertex:adj[src])
{
if(adj_vertex!=prev && req[adj_vertex])
req[src]=1;
}
}
void dfs2(int src,vectoradj[],int prev=0)
{
for(auto adj_vertex:adj[src])
{
if(adj_vertex!=prev)
dfs2(adj_vertex,adj,src);
}
if(src==1)
moves[src]=a[src];
else
{
if(req[src])
moves[src]=1+a[src];
}
for(auto adj_vertex:adj[src])
{
if(adj_vertex!=prev)
moves[src]+=moves[adj_vertex];
}
}
void dfs3(int src,vectoradj[],int prev=0)
{
bool flg=false;
for(auto adj_vertex:adj[src])
{
if(adj_vertex!=prev && req[adj_vertex])
{
dfs3(adj_vertex,adj,src);
flg=true;
}
}
if(!flg)
{
dp[src]=1;
return ;
}
int ans = (src==1) ? fact[moves[src]] : fact[moves[src]-1];
int deno = factinv[a[src]];
for(auto adj_vertex:adj[src])
{
if(adj_vertex!=prev && req[adj_vertex])
deno = mul(deno,factinv[moves[adj_vertex]]);
}
ans=mul(ans,deno);
for(auto adj_vertex:adj[src])
{
if(adj_vertex!=prev && req[adj_vertex])
ans=mul(ans,dp[adj_vertex]);
}
dp[src]=ans;
}
void solve(int tc)
{
int n;
cin>>n;
a.resize(n+1,0);
for(int i=1;i<n+1;i++)
cin>>a[i];
vector<int>adj[n+1];
for(int i=0;i<n-1;i++)
{
int u,v;
cin>>u>>v;
adj[u].push_back(v);
adj[v].push_back(u);
}
memset(dp,0,sizeof dp);
memset(req,0,sizeof req);
memset(moves,0,sizeof moves);
dfs(1,adj);
dfs2(1,adj);
dfs3(1,adj);
cout<<dp[1]<<'\n';
}
int32_t main()
{
#ifndef ONLINE_JUDGE
freopen(“input.txt”, “r”, stdin);
freopen(“output.txt”, “w”, stdout);
#else
// online submission
#endif
speed_io();
cout<<setprecision(12)<<fixed;
factorial();
int t;
cin>>t;
for(int tc=1;tc<=t;tc++)
{
solve(tc);
}
return 0;
}
‘’’
Question : BUILDT Problem - CodeChef
The above code passes for all tests except test set 2 where it shows TLE…Please help!!