Problem statement
Contest source
Author, Editorialist : Dhananjay Raghuwanshi
Tester : Miten Shah
DIFFICULTY:
Easy
PREREQUISITES:
Tree Flattening, prefix sums, DFS and similar
PROBLEM:
We are given a tree with N nodes, each carrying a positive integer value. We must delete exactly one node so that the sum, over all connected components formed by the deletion, of the GCD of the node values in that component is as large as possible.
EXPLANATION:
- A direct approach is to try deleting every node in turn, compute the resulting sum for each, and take the maximum. Done naively (recomputing every component's GCD from scratch for each deleted node), this takes O(N^2) time and will TLE. So we use tree flattening together with prefix/suffix GCDs to evaluate every node efficiently.
- First, root the tree at node 1 and run a single DFS from it. During this DFS:
  - flatten the tree (record the Euler-tour order);
  - store the in and out time of each node;
  - store the parent of each node;
  - store the GCD of each subtree.

  Then precompute prefix (forward) and suffix (backward) GCDs over the flattened array. With these, the sum of component GCDs after removing any node can be computed quickly, and we output the maximum over all nodes.
How to find answer for a particular node x
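- To find the answer for a node x, consider what happens when x is deleted. Every child of x becomes the root of its own component. If x is not the root, the rest of the tree (everything outside the subtree of x) forms one more component. So the number of components is the number of children of x, plus one if x is not the root.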
- For each child of x, the GCD of its component is simply the GCD of that child's subtree, which we have already precomputed. What remains is the component formed by removing the whole subtree of x from the tree.
- To calculate GCD of all nodes values in this part we will make use of in and out time of node x and of prefix and suffix arrays of GCD’s we have earlier calculated.
- This GCD equals GCD(forward[in[x]-1], backward[out[x]+1]), where:
  - forward[i] (the prefix array) stores the GCD of the values at positions 1 to i of the flattened array;
  - backward[j] (the suffix array) stores the GCD of the values from position j to the last position of the flattened array.

  The subtree of x occupies exactly positions in[x]..out[x], so these two ranges together cover every node outside that subtree. For the root this part is empty, so it contributes nothing.
- The answer for node x is the sum of all the parts computed above. The final answer is the maximum of this value over all nodes x.
SOLUTION :
C++ Solution
#include <bits/stdc++.h>
using namespace std;
// `int` is redefined as `long long` for the rest of this file (hence `signed main`).
#define int int long long
// Loop helpers: fr iterates i = 0..n-1, fr1 iterates i = 1..n.
#define fr(i, n) for (int i = 0; i < n; i++)
#define fr1(i, n) for (int i = 1; i <= n; i++)
#define S second
#define F first
#define pb(n) push_back(n)
// Replaces std::endl with a plain newline (no flush).
#define endl "\n"
// Adjacency lists, indexed by node (nodes are 1..n).
vector<int> v[1000001];
// Entry / exit positions of each node in the flattened tour built by dfs.
int in[1000001] = {0};
int out[1000001] = {0};
// timer: next free tour position (reset to 1 for every test); n: node count; t: test count.
// NOTE(review): m, k and eq are not referenced anywhere in this solution.
int timer = 1, n, m, t, k, eq = 0;
// flatentree[pos]: node at tour position pos; every node is written twice (entry and exit),
// so positions 1..2n are used.
int flatentree[1000001];
// parent[x]: parent of x in the DFS tree rooted at node 1 (-1 for the root).
int parent[1000001];
// gcd_subtree[x]: GCD of the values of all nodes in the subtree of x.
int gcd_subtree[1000001];
// value[x]: input value of node x.
int value[1000001];
// Greatest common divisor by Euclid's algorithm; gcd(a, 0) == a and gcd(0, b) == b,
// so 0 acts as the identity element.
int gcd(int a, int b)
{
    while (b != 0)
    {
        const int remainder = a % b;
        a = b;
        b = remainder;
    }
    return a;
}
// Depth-first traversal from x (whose parent is p) that:
//  - records parent[x],
//  - writes x at its entry position in[x] and exit position out[x] of flatentree,
//  - computes gcd_subtree[x] as the GCD of x's value and all of its children's subtree GCDs.
// Returns gcd_subtree[x]. Recursion depth equals the height of the tree.
int dfs(int x, int p)
{
    parent[x] = p;

    // Entry slot of x.
    flatentree[timer] = x;
    in[x] = timer;
    ++timer;

    int &acc = gcd_subtree[x];
    acc = value[x];
    for (const auto nb : v[x])
    {
        if (nb == p)
            continue;
        acc = gcd(dfs(nb, x), acc);
    }

    // Exit slot of x.
    flatentree[timer] = x;
    out[x] = timer;
    ++timer;

    return acc;
}
signed main()
{
cin >> t;
while (t--)
{
timer = 1;
cin >> n;
fr(i, n + 1)
{
v[i].clear();
}
fr1(i, n - 1)
{
int a, b;
cin >> a >> b;
v[a].pb(b);
v[b].pb(a);
}
fr1(i, n)
{
cin >> value[i];
}
dfs(1, -1);
int si = 2 * n;
int forward[si + 1], backward[si + 1];
forward[1] = backward[si] = value[1];
for (int i = 2; i <= si; i++)
{
forward[i] = gcd(forward[i - 1], value[flatentree[i]]);
}
for (int i = si - 1; i >= 1; i--)
{
backward[i] = gcd(backward[i + 1], value[flatentree[i]]);
}
int ans = 1;
int temp;
fr1(i, n)
{
vector<int> children;
for (auto child : v[i])
{
if (child != parent[i])
{
children.pb(gcd_subtree[child]);
}
}
if (children.size() != 0)
{
temp = 0;
for (auto j : children)
{
temp += j;
}
int temp1;
if (i != 1)
{
temp1 = gcd(forward[in[i] - 1], backward[out[i] + 1]);
ans = max(ans, temp + temp1);
}
else
{
ans = max(ans, temp);
}
}
else
{
temp = gcd(forward[in[i] - 1], backward[out[i] + 1]);
ans = max(ans, temp);
}
}
cout << ans << endl;
}
}
C++ Tester's solution
// created by mtnshh
#include<bits/stdc++.h>
#include<sys/resource.h>
using namespace std;
// Short-hand types and helpers used throughout the tester's solution.
#define ll long long int
#define pb push_back
#define rb pop_back
#define ti tuple<int, int, int>
#define pii pair<int, int>
#define pli pair<ll, int>
#define pll pair<ll, ll>
#define mp make_pair
#define mt make_tuple
// rep: ascending loop over [a, b); repb: descending loop from a down to b inclusive.
#define rep(i,a,b) for(ll i=a;i<b;i++)
#define repb(i,a,b) for(ll i=a;i>=b;i--)
// Debug-printing helpers; they write to stdout via cout.
#define err() cout<<"--------------------------"<<endl;
#define errA(A) for(auto i:A) cout<<i<<" ";cout<<endl;
#define err1(a) cout<<#a<<" "<<a<<endl
#define err2(a,b) cout<<#a<<" "<<a<<" "<<#b<<" "<<b<<endl
#define err3(a,b,c) cout<<#a<<" "<<a<<" "<<#b<<" "<<b<<" "<<#c<<" "<<c<<endl
#define err4(a,b,c,d) cout<<#a<<" "<<a<<" "<<#b<<" "<<b<<" "<<#c<<" "<<c<<" "<<#d<<" "<<d<<endl
#define all(A) A.begin(),A.end()
#define allr(A) A.rbegin(),A.rend()
#define ft first
#define sd second
#define V vector<ll>
#define S set<ll>
#define VV vector<V>
#define Vpll vector<pll>
// With the next line commented out, `endl` below is std::endl (which flushes).
// #define endl "\n"
// Reads one integer from stdin that must be immediately followed by the character
// `endd`, asserting strict formatting:
//  - an optional single leading '-';
//  - at least one digit, with no leading zeros and no "-0";
//  - at most 19 digits, and a 19-digit value must start with 1.
// Also asserts that the value lies in [l, r]. Returns the value read.
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;      // digits read so far
    int fi=-1;      // first digit, -1 until one has been read
    bool is_neg=false;
    while(true){
        // Keep getchar()'s result as int so EOF cannot alias a valid byte.
        const int g=getchar();
        if(g=='-'){
            // A single minus sign is allowed, and only before the first digit.
            assert(fi==-1 && !is_neg);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            const int d=g-'0';
            if(cnt==0){
                fi=d;
            }
            cnt++;
            // No leading zeros, and no negative zero.
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
            // Length limit is checked BEFORE accumulating, so x*10+d below can
            // never overflow a signed long long.
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
            x=x*10+d;
        } else if(g==static_cast<unsigned char>(endd)){
            // An empty token ("" or a lone "-") is not a number.
            assert(cnt>0);
            if(is_neg){
                x= -x;
            }
            assert(l<=x && x<=r);
            return x;
        } else {
            assert(false);
        }
    }
}
// Reads characters from stdin up to (and consuming, but not including) the
// terminator `endd`. Asserts that EOF is not reached first and that the token
// length lies in [l, r]. Returns the token.
std::string readString(int l,int r,char endd){
    std::string ret="";
    int cnt=0;
    while(true){
        // getchar() returns int: storing it in a plain char made `g != -1` never
        // fail on platforms where char is unsigned, so EOF looped forever.
        const int g=getchar();
        assert(g != EOF);
        if(g==static_cast<unsigned char>(endd)){
            break;
        }
        cnt++;
        ret+=static_cast<char>(g);
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
// Reads a space-terminated integer in [l, r] (see readInt).
long long readIntSp(long long l,long long r){
    const char separator = ' ';
    return readInt(l, r, separator);
}
// Reads a newline-terminated integer in [l, r] (see readInt).
long long readIntLn(long long l,long long r){
    const char terminator = '\n';
    return readInt(l, r, terminator);
}
string readStringLn(int l,int r){
return readString(l,r,'\n');
}
string readStringSp(int l,int r){
return readString(l,r,' ');
}
// Input limits checked by the validator.
const int max_q = 1e3;      // maximum number of test cases
const int max_n = 1e5;      // maximum N in a single test
const int max_sum_n = 1e6;  // maximum sum of N over all tests
const int max_ai = 1e9;     // maximum node value
// Capacity of the per-node arrays; indices 0..n+1 are used (pref[0], suff[n+1]).
const ll N = 100005;
// NOTE(review): INF is not referenced in this solution.
const ll INF = 1e12;
// adj[u]: neighbours of node u (nodes are 1..n).
V adj[N];
// A[i]: value of node i.
ll A[N];
// in[u] / out[u]: first and last pre-order position of u's subtree (filled by dfs1);
// cnt: next free pre-order position; pref / suff: prefix / suffix GCDs over pre-order;
// ans: best sum for the current test; ord[pos]: node at pre-order position pos.
ll in[N], out[N], cnt = 0, pref[N], suff[N], ans = 0, ord[N];
// Pre-order traversal from node n (parent p). It:
//  - assigns n the next free position `cnt` (stored in in[n] and ord[]);
//  - sets out[n] to the last position used by n's subtree.
// The subtree of n therefore occupies positions in[n]..out[n].
void dfs1(ll n, ll p){
    const ll pos = cnt++;
    in[n] = pos;
    ord[pos] = n;
    for(const ll child : adj[n]){
        if(child != p){
            dfs1(child, n);
        }
    }
    out[n] = cnt - 1;
}
// Post-order traversal from node n (parent p). For each node it evaluates the sum of
// component GCDs obtained by deleting that node:
//  - each child contributes its subtree GCD;
//  - the rest of the tree contributes GCD(pref[in[n]-1], suff[out[n]+1]), which is 0 for
//    the root because pref[0] and suff[n+1] are 0.
// It updates the global `ans` with the maximum and returns the GCD of n's subtree.
ll dfs2(ll n, ll p){
    ll componentSum = 0;
    ll subtreeGcd = A[n];
    for(const ll child : adj[n]){
        if(child == p){
            continue;
        }
        const ll childGcd = dfs2(child, n);
        subtreeGcd = __gcd(subtreeGcd, childGcd);
        componentSum += childGcd;
    }
    const ll outsideGcd = __gcd(pref[in[n] - 1], suff[out[n] + 1]);
    componentSum += outsideGcd;
    ans = max(ans, componentSum);
    return subtreeGcd;
}
// Validates and solves one test case with n nodes:
//  - reads n-1 edges, then n values;
//  - builds pre-order positions and prefix/suffix GCDs over them;
//  - prints the best achievable sum of component GCDs.
void solve(ll n){
    rep(i,0,n+1) adj[i].clear();
    rep(i,0,n-1){
        ll u = readIntSp(1, n), v = readIntLn(1, n);
        // A self-loop would make dfs1 recurse forever instead of failing an assert.
        assert(u != v);
        adj[u].pb(v);
        adj[v].pb(u);
    }
    rep(i,1,n+1) A[i] = (i != n) ? readIntSp(1, max_ai) : readIntLn(1, max_ai);
    cnt = 1;
    dfs1(1, 0);
    // With n-1 edges the graph is a tree iff it is connected. A disconnected input
    // leaves nodes unvisited (with stale in/out/ord data from an earlier test), so
    // require that exactly n pre-order positions were assigned.
    assert(cnt == n + 1);
    // 0 is the identity of gcd, so these sentinels make the "outside" part empty for the root.
    pref[0] = 0;
    suff[n+1] = 0;
    rep(i,1,n+1){
        pref[i] = __gcd(pref[i-1], A[ord[i]]);
    }
    repb(i,n,1){
        suff[i] = __gcd(suff[i+1], A[ord[i]]);
    }
    ans = 0;
    dfs2(1, 0);
    cout << ans << endl;
}
int main(){
ios_base::sync_with_stdio(0);cin.tie(0);cout.tie(0);
ll q = readIntLn(1, max_q);
ll sum_n = 0;
rlimit R;
getrlimit(RLIMIT_STACK, &R);
R.rlim_cur = R.rlim_max;
setrlimit(RLIMIT_STACK, &R);
while(q--){
ll n = readIntLn(2, max_n);
solve(n);
sum_n += n;
}
assert(sum_n <= max_sum_n);
assert(getchar()==-1);
}