# TREESUB - Editorial

Contest

Author: Naman Jain
Tester: Raja Vardhan Reddy
Editorialist: Rajarshi Basu

Medium

# PREREQUISITES:

Tree DP, Depth First Search

# PROBLEM:

We are given a rooted tree with N nodes (numbered 1 through N). Node 1 is the root. You are also given integer sequences x_1, x_2, \ldots, x_N and v_1, v_2, \ldots, v_N.

Let S be a subset of nodes. It is called valid if it is non-empty and the following conditions hold:

• There is no pair of nodes (i, j) such that i,j \in S and i is an ancestor of j.
• The greatest common divisor of the values x_i for all nodes i \in S (let’s denote it by G) is greater than 1.

Next, let’s define the value of S as G \cdot V, where G is defined above and V = \sum_{i \in S}{v_i}.

You need to find a valid subset of nodes with the maximum value.

• 1 \le T \le 100,000
• 1 \le N \le 100,000
• 1 \le x_i, v_i \le 100,000 for each valid i
• at least one valid subset exists
• the sum of N over all test cases does not exceed 1,000,000

# QUICK EXPLANATION:

We do a DFS and maintain a global array A. We maintain the invariant that whenever we reach a node p for the first time, A[i] is the maximum sum of v over a set of already-processed nodes that contains no ancestor of p, has no ancestor–descendant pair, and whose x values are all divisible by i. On visiting p, we store in node p the current values of A for all factors of X[p], and then call DFS on p's children. When backtracking from p, we update A using the stored values, preserving the invariant. Throughout this process we also keep track of the best answer and the gcd that gives rise to it.

# EXPLANATION:

Observation 1

This problem has something to do with factors: the gcd G must be a common factor of all the chosen numbers. The number of divisors of a number is small — a common rule of thumb is roughly X_{max}^\frac{1}{3} (for X_{max} = 10^5 the maximum is 128, attained at 83160). N can also be as large as 10^6 overall across all test cases. Hence it is fair to assume the intended complexity is O(NX_{max}^\frac{1}{3}). We cannot really afford an extra log factor, since that would almost certainly TLE.

Observation 2

If we did not have to worry about G, and just had to maximise V, this would be a standard tree DP problem. Now, if we build a separate auxiliary tree for each factor, the total number of nodes is still O(NX_{max}^\frac{1}{3}), since every node is present in at most about X_{max}^\frac{1}{3} such trees (one per divisor of its x value). After that, we can simply run the tree DP to maximise V in each tree separately.

Details of the DP?
// this is to get the best answer in one of the auxiliary trees
#define ll long long int
// Tree DP over one auxiliary tree. Returns the largest sum of V over a set of
// nodes in node's subtree with no ancestor/descendant pair. That set is either
// `node` alone, or the union of the best sets of all its children.
ll dfs1(int node,int p = -1){
    ll childBest = 0;
    for (auto child : gg[node]) {
        if (child == p) continue;
        childBest += dfs1(child, node);
    }
    return childBest > V[node] ? childBest : V[node];
}


Unoptimal Solution 1

What if we simply build each of the “factor trees” separately? This can be done using a stack during a DFS. Then, for each of the trees created, we run a simple DP, just as described in Observation 2. Time Complexity: O(NX_{max}^\frac{1}{3}). For more details on how to construct the individual factor trees (“auxiliary trees”) using stacks, see the code below.

Code
#include <iostream>
#include <vector>
#include <algorithm>
#include <fstream>
#include <queue>
#include <deque>
#include <iomanip>
#include <cmath>
#include <set>
#include <stack>
#include <map>
#include <unordered_map>

// Shorthand macros used throughout this listing.
// Note: `vi` is a vector of long long (not int), and `endl` is redefined to a
// plain '\n' character so that it does not flush.
#define FOR(i,n) for(int i=0;i<n;i++)
#define FORE(i,a,b) for(int i=a;i<=b;i++)
#define ll long long
//#define int long long
#define ld long double
#define vi vector<ll>
#define pb push_back
#define ff first
#define ss second
#define ii pair<int,int>
#define iii pair<int,ii>
#define vv vector
#define endl '\n'

using namespace std;

// Array size: covers node indices and x values up to 100000.
const int MAXN = (100*1000 + 5);

// Adjacency list of the input tree; solve() also uses index 0 as a virtual root.
vi g[MAXN];
// facs[k]: all divisors of k, filled by precalc().
vi facs[MAXN];
// Per-node input values, 1-indexed.
int x[MAXN];
int v[MAXN];

// Node values of the auxiliary tree currently being processed (compressed ids).
ll v2[MAXN];

// allGraphs[f]: {node, nearest ancestor whose x is divisible by f (0 = virtual root)}
// pairs, appended in DFS pre-order by dfs().
vv<ii> allGraphs[MAXN];
// lastOcc[f]: stack of the nodes on the current DFS path whose x is divisible by f
// (solve() pushes the sentinel 0 at the bottom).
vi lastOcc[MAXN];

// Child lists of the current auxiliary tree (compressed ids).
vi gg[MAXN];
// Original node -> compressed id in the current auxiliary tree.
int mapValue[MAXN];
// Compressed id -> original node.
int revMap[MAXN];
// Child lists kept by dfs2() when reconstructing the chosen subset (compressed ids).
vi nextNodes[MAXN];

// Builds every auxiliary ("factor") tree in a single DFS over the real tree.
// For each divisor d of x[node], the nearest ancestor whose x is also divisible
// by d is on top of lastOcc[d] (0 = the virtual root). The pair
// {node, that ancestor} is appended to allGraphs[d], so allGraphs[d] lists its
// nodes in DFS pre-order, each with its parent in the factor tree.
void dfs(int node,int p = -1){
    const bool real = (node != 0);
    if (real) {
        for (auto d : facs[x[node]])
            allGraphs[d].pb({node, lastOcc[d].back()});
        for (auto d : facs[x[node]])
            lastOcc[d].push_back(node);
    }

    for (auto nxt : g[node])
        if (nxt != p) dfs(nxt, node);

    // Leaving node: it is no longer an ancestor of anything still to be visited.
    if (real)
        for (auto d : facs[x[node]])
            lastOcc[d].pop_back();
}
// Tree DP on the current auxiliary tree (children in gg, values in v2).
// Returns the best total over an ancestor-free subset inside node's subtree:
// either node alone, or the children's best subsets combined.
ll dfs1(int node,int p = -1){
    ll fromChildren = 0;
    for (auto child : gg[node]) {
        if (child == p) continue;
        fromChildren += dfs1(child, node);
    }
    return fromChildren > v2[node] ? fromChildren : v2[node];
}
// this is to construct the answer in the best auxiliary tree
// Same DP as dfs1, but it also records the chosen structure. Every non-root node
// registers itself as a child of p in nextNodes. If taking `node` itself is at
// least as good as its children's total (sum <= v2[node]), its own child list is
// cleared so that `node` becomes a leaf, i.e. a selected node. solve() then
// collects the leaves reachable from 0 through nextNodes.
ll dfs2(int node,int p = -1){
ll sum = 0;
for(auto e : gg[node]){
if(e != p)sum += dfs2(e,node);
}
if(node != 0){
nextNodes[p].pb(node);
if(sum <= v2[node]){
// ties favour taking node itself; drop everything recorded below it
nextNodes[node].clear();
}
}
return max(sum,v2[node]);
}

// Sieve-style divisor lists: afterwards facs[m] holds every divisor of m
// (1 <= m < MAXN) in increasing order.
void precalc(){
    for (int d = 1; d < MAXN; ++d)
        for (int m = d; m < MAXN; m += d)
            facs[m].pb(d);
}

void solve(){

int n;
cin >> n;
FOR(i,n+1){
g[i].clear();
gg[i].clear();

nextNodes[i].clear();
}
FOR(i,n-1){
int a,b;
cin >> a >> b;
g[a].pb(b);
g[b].pb(a);
}

vi usedFactors;
FOR(i,n){
cin >> x[i+1] >> v[i+1];
for(auto e : facs[x[i+1]])usedFactors.pb(e);
}

g.pb(1);
for(auto i : usedFactors){
lastOcc[i].push_back(0);
allGraphs[i].clear();
}

dfs(0);
for(auto i : usedFactors)lastOcc[i].pop_back();

ll best = 0;
int bestid = 0;
// this is to loop over all the factor trees.
for(auto i : usedFactors){
if(i == 1)continue;

mapValue = 0;
int id = 1;
if(allGraphs[i].size() == 0)continue;
for(auto e : allGraphs[i]){
// we make a map of the values to smaller values so as to avoid using a map.
mapValue[e.ff] = id;
revMap[id] = e.ff;
v2[id] = v[e.ff];
gg[mapValue[e.ss]].pb(mapValue[e.ff]);
id++;
}
ll val = dfs1(0);
if(val*i > best){
best = val*i;
bestid = i;
}
FOR(j,id)gg[j].clear();
}

// recreate
mapValue = 0;
revMap = 0;
int id = 1;
for(auto e : allGraphs[bestid]){
mapValue[e.ff] = id;
revMap[id] = e.ff;
v2[id] = v[e.ff];
gg[mapValue[e.ss]].pb(mapValue[e.ff]);
id++;
}

dfs2(0);
FOR(j,id)gg[j].clear();
vi allNodes;
queue<int> q;
q.push(0);
ll V = 0;
int G = bestid;
while(!q.empty()){
int nextNode = q.front();q.pop();

if(nextNodes[nextNode].size() == 0){
V += v2[nextNode];
allNodes.pb(revMap[nextNode]);
}
for(auto e : nextNodes[nextNode])q.push(e);
}

cout << G*V << " " << G << endl;
cout << allNodes.size() << endl;
for(auto e : allNodes)if(e> 0)cout << e << " ";cout << endl;

}

// Entry point: build the divisor lists once, then answer every test case.
signed main(){
    precalc();

    // Fast iostreams (no sync with C stdio, no tied flushing).
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    int t;
    cin >> t;
    for (; t--; )
        solve();
    return 0;
}


However, this TLEs by a large margin, probably due to the high constant factor of calling so many DFS’s, as well as using push_back and pop_back on vectors so many times.

Observation 3

Instead of constructing each tree separately, it would be great if we could process all of them simultaneously, right?

Unoptimal Solution 2

The first idea that comes to mind is to maintain a map of factors at each node (we obviously cannot keep a full array of factors per node, due to memory constraints). Then, while doing the DFS, if for every node p and each of its factors f we know the closest ancestor of p that also has factor f, we can run all the DPs simultaneously. But maps are costly: they add an extra \log X_{max} factor, and this TLEs.

Full Solution
hint:

Instead of maintaining maps at each node, why not maintain a global array?

In Detail:

We keep a global array A[.] with the invariant that when we reach a node p, A[f] contains the best possible answer for factor f that uses no ancestor of p. Next, we call dfs on all of p's children. Afterwards, A[f] also accounts for the nodes of p's subtree (but not p itself). Now, as in the normal DP, for every factor f of X[p] we have to choose between keeping the best answer that uses p's subtree and taking p itself instead. This is easy to do.

Even more details

For every factor f of X[p], store the value A[f] had when we first entered p; call it val_{f:prev}. After the DFS calls on all the children are complete, let the value in A[f] be val_{f:curr}. When we exit node p, we only need to compare val_{f:curr} with val_{f:prev} + v_p and keep the larger one in A[f]: the first option uses the nodes of p's subtree, while the second discards them and takes p instead.
While reconstructing the solution, we also need to keep track of the nodes, but it can also be done in a similar fashion. See Setter’s code for clarity.

Time Complexity:

As discussed, it is O(NX_{max}^\frac{1}{3}).

# SOLUTIONS:

Setter's Solution
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;
using namespace std;

// Debug printing helpers (used together with trace()).
// Both overloads take their argument by const reference to avoid copying the
// container or pair. The pair overload is declared before the vector overload
// is defined, so that a vector of pairs can be printed: unqualified lookup
// inside a template only sees declarations that precede the template.
template<class L, class R> std::ostream& operator<<(std::ostream &os, const std::pair<L,R> &P);

template<class T> std::ostream& operator<<(std::ostream &os, const std::vector<T> &V) {
    os << "[ ";
    for (const auto &v : V) os << v << " ";
    return os << "]";
}
template<class L, class R> std::ostream& operator<<(std::ostream &os, const std::pair<L,R> &P) {
    return os << "(" << P.first << "," << P.second << ")";
}

// Debug helper: trace(a, b, ...) writes each argument's source text, " : ",
// and its value to cout. When TRACE is not defined, trace(...) expands to 1.
// NOTE(review): identifiers containing a double underscore (`__f`) are
// reserved for the implementation; consider renaming.
#define TRACE
#ifdef TRACE
#define trace(...) __f(#__VA_ARGS__, __VA_ARGS__)
// Base case: print the last name/value pair and end the line (std::endl flushes).
template <typename Arg1>
void __f(const char* name, Arg1&& arg1){
cout << name << " : " << arg1 << std::endl;
}
// Recursive case: print the text up to the next comma as the name, then recurse
// on the remaining names and arguments.
template <typename Arg1, typename... Args>
void __f(const char* names, Arg1&& arg1, Args&&... args){
const char* comma = strchr(names + 1, ',');cout.write(names, comma - names) << " : " << arg1<<" | ";__f(comma+1, args...);
}
#else
#define trace(...) 1
#endif

// Shorthand macros for the setter's solution. Note that `endl` is redefined to
// the string "\n", so later `endl` uses do not flush the stream.
#define ll long long
#define ld long double
#define vll vector<ll>
#define pll pair<ll,ll>
#define vpll vector<pll>
#define I insert
#define pb push_back
#define F first
#define S second
#define endl "\n"
#define vi vector<int>
#define pii pair<int, int>
#define vpii vector< pii >

// const int mod=1e9+7;
// inline int mul(int a,int b){return (a*1ll*b)%mod;}
// inline int add(int a,int b){a+=b;if(a>=mod)a-=mod;return a;}
// inline int sub(int a,int b){a-=b;if(a<0)a+=mod;return a;}
// inline int power(int a,int b){int rt=1;while(b>0){if(b&1)rt=mul(rt,a);a=mul(a,a);b>>=1;}return rt;}
// inline int inv(int a){return power(a,mod-2);}
// inline void modadd(int &a,int &b){a+=b;if(a>=mod)a-=mod;}

const int M = 100005;
vi fac[M];

// Sieve-style divisor enumeration: afterwards fac[n] lists every divisor of n
// (for 1 <= n < M) in increasing order. Runs once, before any test case.
void pre(){
    for (int d = 1; d < M; ++d)
        for (int multiple = d; multiple < M; multiple += d)
            fac[multiple].pb(d);
}

// Test-case counter. Together with cur_ty, it lets check() reset per-factor
// state lazily instead of clearing whole arrays for every test.
int ty = 0;
// cur_ty[z]: test case in which sumV[z] / lst[z] were last initialised.
int cur_ty[M];
// sumV[z]: the editorial's global array A. It holds the best total v over an
// ancestor-free set of nodes, all of whose x values are divisible by z, among
// the nodes whose DFS has finished so far.
ll sumV[M];
// lst[z]: {node, index into fac[x[node]]} of the most recently chosen node of
// that set, or {-1,-1} if the set is empty. Following incLst from here
// recovers the whole set.
pii lst[M];
// Adjacency list of the tree (1-indexed nodes).
vi g[M];
int v[M], x[M];
// incV[c][i]: value of "take c" for factor fac[x[c]][i], i.e. the sumV snapshot
// taken on entry to c plus v[c].
vll incV[M];
// incLst[c][i]: the lst snapshot taken on entry to c for that factor, i.e. the
// rest of the set if c is taken.
vpii incLst[M];
bool vis[M];
// Best (value, gcd) pair seen so far, over factors greater than 1.
pll Ans;

// Lazily (re)initialises the state of factor z the first time it is touched in
// the current test case: empty set, value 0, no last node.
inline void check(int z){
    if (cur_ty[z] == ty) return;
    cur_ty[z] = ty;
    sumV[z] = 0;
    lst[z] = {-1, -1};
}

void dfs(int c){
vis[c]= 1;
incV[c].clear();
for(auto z:fac[x[c]]){
check(z);
incV[c].pb(sumV[z]+v[c]);
incLst[c].pb(lst[z]);
}
for(auto z:g[c]){
if(!vis[z]) dfs(z);
}
for(int i=0;i<fac[x[c]].size();i++){
int a = fac[x[c]][i];
if(incV[c][i] > sumV[a]){
sumV[a] = incV[c][i];
lst[a] = {c, i};
if(a!=1) Ans = max(Ans, make_pair(sumV[a]*1ll*a, (ll)a) );
}
}
}

// Handles one test case. It reads the tree and the (x, v) pairs, runs the DFS
// DP and prints "<value> <gcd>". It then walks the lst/incLst chain to list the
// chosen nodes.
void solve(){
    int N;
    cin >> N;
    ++ty;  // invalidates every factor's sumV/lst lazily (see check)

    for (int node = 0; node <= N; ++node) {
        g[node].clear();
        incV[node].clear();
        incLst[node].clear();
        vis[node] = 0;
    }
    for (int edge = 1; edge < N; ++edge) {
        int a, b;
        cin >> a >> b;
        g[a].pb(b);
        g[b].pb(a);
    }
    for (int node = 1; node <= N; ++node)
        cin >> x[node] >> v[node];

    Ans = {0, 0};
    dfs(1);
    cout << Ans.F << " " << Ans.S << "\n";
    assert(Ans.S > 1 && Ans.F > 0);

    // Every chosen node remembers (in incLst) the chain head that preceded it.
    vi subset;
    for (pii cur = lst[Ans.S]; cur.F != -1; cur = incLst[cur.F][cur.S])
        subset.pb(cur.F);

    cout << subset.size() << "\n";
    for (int node : subset) cout << node << " ";
    cout << "\n";
}

// Entry point: fast I/O setup, divisor precomputation, then T test cases.
int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    cout << setprecision(25);

    pre();

    int T;
    cin >> T;
    for (; T--; )
        solve();
}

Tester's Solution
//#pragma comment(linker, "/stack:200000000")
//#pragma GCC optimize("Ofast")
//#pragma GCC target("sse,sse2,sse3,ssse3,sse4,avx,avx2")

#include <bits/stdc++.h>
#include <vector>
#include <set>
#include <map>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <utility>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <iomanip>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
//setbase - cout << setbase (16)a; cout << 100 << endl; Prints 64
//setfill -   cout << setfill ('x') << setw (5); cout << 77 <<endl;prints xxx77
//setprecision - cout << setprecision (14) << f << endl; Prints x.xxxx
//cout.precision(x)  cout<<fixed<<val;  // prints x digits after decimal in val

using namespace std;
using namespace __gnu_pbds;
// Shorthand macros for the tester's solution.
// The loop macros f / rep / fd assign to an already-declared variable `i`;
// they do not declare it.
#define f(i,a,b) for(i=a;i<b;i++)
#define rep(i,n) f(i,0,n)
#define fd(i,a,b) for(i=a;i>=b;i--)
#define pb push_back
#define mp make_pair
#define vi vector< int >
#define vl vector< ll >
#define ss second
#define ff first
#define ll long long
#define pii pair< int,int >
#define pll pair< ll,ll >
#define sz(a) a.size()
#define inf (1000*1000*1000+5)
#define all(a) a.begin(),a.end()
#define tri pair<int,pii>
#define vii vector<pii>
#define vll vector<pll>
#define viii vector<tri>
#define mod (1000*1000*1000+7)
#define pqueue priority_queue< int >
#define pdqueue priority_queue< int,vi ,greater< int > >
//#define int ll

// Policy-based order-statistics set (not used in the visible code).
typedef tree<
int,
null_type,
less<int>,
rb_tree_tag,
tree_order_statistics_node_update>
ordered_set;

//std::ios::sync_with_stdio(false);

// NOTE(review): these declarations do not match how the names are used below.
// x, v, last, res, act, sel and dp are all indexed like arrays, and `divisor`
// is used but never declared. The listing appears damaged by extraction; the
// tester's original array declarations are presumably missing — TODO confirm.
int par,x,v,divi;
// Best value found for the current test case.
ll maxx;
pii last;
// sum[u][i]: accumulator attached to node u for its i-th divisor. Row N
// (see main) serves as the sentinel that collects top-level totals per divisor.
vector<vl>sum(100005);
int res;
// Number of nodes written into res by dfs1.
int ans=0;
// Sentinel row index into sum.
int N=100002;
int act;
int sel;
int dp;
// NOTE(review): this function appears to have lost lines during extraction.
// The closing braces after `//child++;` leave it unbalanced, and no recursion
// into the children is visible, so it does not compile as shown. The comments
// below describe only what the visible lines do.
int dfs(int u,int pa){
int i;
// temp[i] saves the previous `last` entry for the i-th divisor of x[u], so
// that it can be restored when leaving u.
vii temp;
temp.resize(divisor[x[u]].size());
sum[u].resize(divisor[x[u]].size());
rep(i,divisor[x[u]].size()){
sum[u][i]=0;
temp[i]=last[divisor[x[u]][i]];
// While u's subtree is explored, (u, i) is the accumulation slot for this divisor.
last[divisor[x[u]][i]]=mp(u,i);
}
//child++;
}
}
pii p;
int q;
// Presumably runs after the (missing) child traversal. It restores `last` and
// adds the better of "take u" (v[u]) and "what u's subtree accumulated"
// (sum[u][i]) into the slot of the nearest enclosing entry for that divisor.
rep(i,divisor[x[u]].size()){
ll val=max((ll)v[u],sum[u][i]);
q=divisor[x[u]][i];
last[q]=temp[i];
sum[last[q].ff][last[q].ss]+=val;
}
return 0;
}
// NOTE(review): lines appear to be missing here. There are stray closing braces
// right after `sel[u]=0;` and no loop over the children is visible, so this
// does not compile as shown. Presumably the missing lines add the children's dp
// values into dp[u] — TODO confirm.
// Visible intent: for the chosen gcd d, mark u as selected (dp[u] = v[u]) when
// d divides x[u] and v[u] is at least dp[u]; ties favour u.
int solve(int u,int p,int d){
int i;
dp[u]=0;
sel[u]=0;
}
}
if(x[u]%d==0&&v[u]>=dp[u]){
dp[u]=v[u];
sel[u]=1;
}
return 0;
}
// NOTE(review): truncated during extraction (unbalanced braces, no visible
// recursion). Visible behaviour: if solve() selected u, append u to res and do
// not descend further. Presumably the missing lines recurse into the children
// otherwise — TODO confirm.
int dfs1(int u,int p){
if(sel[u]==1){
res[ans++]=u;
return 0;
}
int i;
}
}

}
// NOTE(review): parts of this function appear to have been lost during
// extraction. The `rep(i,n){ }` loop is empty, the edges read below are never
// stored, and `divisor` is not declared in the visible text. The comments
// describe the visible lines only.
int main(){
//std::ios::sync_with_stdio(false); cin.tie(NULL);
int t,i,j,iter=0;
// Row N of sum is a sentinel: each divisor d's top-level total accumulates in sum[N][d].
sum[N].resize(N+2);
// Divisor lists start at 2, so the factor 1 is never considered.
for(i=2;i<N;i++){
for(j=i;j<N;j+=i){
divisor[j].pb(i);
}
// The initial "nearest enclosing slot" for divisor i is the sentinel (N, i).
last[i].ff=N;
last[i].ss=i;
}
scanf("%d",&t);
while(t--){
int n,u,vv;
// iter stamps act[] so that each distinct x value is scanned once per test.
iter++;
scanf("%d",&n);
rep(i,n){
}
ans=0;
// Nodes are converted to 0-indexed here and printed back as 1-indexed.
rep(i,n-1){
scanf("%d %d",&u,&vv);
u--;
vv--;
}
rep(i,n){
scanf("%d %d",&x[i],&v[i]);
}
ll val;
maxx=0;
dfs(0,-1);
// Pick the divisor d maximising d * (total collected for d), resetting the
// sentinel entries for the next test case as they are read.
rep(i,n){
if(act[x[i]]==iter){
continue;
}
act[x[i]]=iter;
rep(j,divisor[x[i]].size()){
val=sum[N][divisor[x[i]][j]];
if(val*divisor[x[i]][j]>maxx){
maxx=val*divisor[x[i]][j];
divi=divisor[x[i]][j];
}
sum[N][divisor[x[i]][j]]=0;
}
}
//return 0;
// Recompute the DP for the chosen gcd only, then collect the selected nodes.
solve(0,-1,divi);
//return 0;
dfs1(0,-1);
//return 0;
printf("%lld %d\n",maxx,divi);
printf("%d\n",ans);
rep(i,ans){
printf("%d ",res[i]+1);
}
printf("\n");
}
return 0;
}



Please give me suggestions if anything is unclear so that I can improve. Thanks 4 Likes

The tough part was to print the nodes too. I was unable to come up how to do that (facepalm).

very well written editorial

Another excellent question! Enjoyed solving it post contest.

@rajarshi_basu you are the best editorialist!!
Hope you do write editorials for future contests too

1 Like

Haha, thanks.