TREESUB - Editorial

PROBLEM LINK:

Contest

Author: Naman Jain
Tester: Raja Vardhan Reddy
Editorialist: Rajarshi Basu

DIFFICULTY:

Medium

PREREQUISITES:

Tree DP, Depth First Search

PROBLEM:

We are given a rooted tree with N nodes (numbered 1 through N). Node 1 is the root. You are also given integer sequences x_1, x_2, \ldots, x_N and v_1, v_2, \ldots, v_N.

Let S be a subset of nodes. It is called valid if it is non-empty and the following conditions hold:

  • There is no pair of nodes (i, j) such that i,j \in S and i is an ancestor of j.
  • The greatest common divisor of the values x_i for all nodes i \in S (let’s denote it by G) is greater than 1.

Next, let’s define the value of S as G \cdot V, where G is defined above and V = \sum_{i \in S}{v_i}.

You need to find a valid subset of nodes with the maximum value.

  • 1 \le T \le 100,000
  • 1 \le N \le 100,000
  • 1 \le x_i, v_i \le 100,000 for each valid i
  • at least one valid subset exists
  • the sum of N over all test cases does not exceed 1,000,000

QUICK EXPLANATION:

We do a DFS and maintain a global array A. We maintain the invariant that whenever we reach a node p for the first time, A[i] is the maximum value of \sum v_j over a valid subset that contains no ancestor of p and whose gcd is divisible by i. On visiting p, we store in node p the values of A corresponding to the factors of X[p], and then call DFS on p's children. When backtracking from p we make the updates [using the values we stored before], all while maintaining the invariant. During this process we also keep track of the best answer and of the gcd that gives rise to it.

EXPLANATION:

Observation 1

This problem has something to do with factors. Specifically, the gcd G can be a factor of any of the numbers. As a rule of thumb, a number up to X_{max} has at most about X_{max}^\frac{1}{3} factors. N can also be as large as 10^6 overall across all test cases. Hence it is fair to assume that the intended complexity is O(NX_{max}^\frac{1}{3}). We cannot really afford an extra log factor on top of that, since it would almost certainly result in TLE.

Observation 2

If we did not have to worry about G, and just had to maximise V, it would have been a Tree DP problem. Now, if we consider a separate Auxiliary Tree for each factor, the overall number of nodes would still be O(NX_{max}^\frac{1}{3}), since every node is present in at most X_{max}^\frac{1}{3} different such trees. After that, we could just run a Tree DP on each tree separately to maximise V.

Details of the DP?
// Tree DP on one auxiliary tree: the result for `node` is the larger of
// V[node] and the total of the results of its children.
#define ll long long int
ll dfs1(int node,int p = -1){
	ll childTotal = 0;
	for(auto child : gg[node]){
		if(child == p)continue; // do not walk back up to the parent
		childTotal += dfs1(child,node);
	}
	return max(childTotal,V[node]);
}

Suboptimal Solution 1

What if, we just separated each of the “factor trees”? This can be done using a stack in a DFS. Then, for each of the separate trees created, we do a simple DP, just as mentioned in observation 2. Time Complexity: O(NX_{max}^\frac{1}{3}). For more details on how to construct the individual factor trees, or “Auxiliary Trees” using stacks, see the below code.

Code
#include <iostream>
#include <vector>
#include <algorithm>
#include <fstream>
#include <queue>
#include <deque>
#include <iomanip>
#include <cmath>
#include <set>
#include <stack>
#include <map>
#include <unordered_map>

#define FOR(i,n) for(int i=0;i<n;i++)
#define FORE(i,a,b) for(int i=a;i<=b;i++)
#define ll long long 
//#define int long long
#define ld long double
#define vi vector<ll>
#define pb push_back
#define ff first
#define ss second
#define ii pair<int,int>
#define iii pair<int,ii>
#define vv vector
#define endl '\n'

using namespace std;

const int MAXN = (100*1000 + 5);

vi g[MAXN];
vi facs[MAXN];
int x[MAXN];
int v[MAXN];

ll v2[MAXN];

vv<ii> allGraphs[MAXN];
vi lastOcc[MAXN];

vi gg[MAXN];
int mapValue[MAXN];
int revMap[MAXN];
vi nextNodes[MAXN];

// Collects the auxiliary trees of all factors in a single walk of the tree.
// allGraphs[f] receives one {node, parent} pair per node whose x has factor f,
// where "parent" is the top of lastOcc[f] at the time the node is entered.
// lastOcc[f] is used as a stack: a node pushes itself for each of its factors
// before descending and pops itself afterwards. Node 0 records nothing.
void dfs(int node,int p = -1){
	const bool recordsEdges = (node != 0);

	if(recordsEdges){
		for(auto factor : facs[x[node]]){
			// stored as {node, parent}; the tree is rebuilt from these pairs later
			allGraphs[factor].pb({node,lastOcc[factor].back()});
			lastOcc[factor].push_back(node);
		}
	}

	for(auto nxt : g[node]){
		if(nxt == p)continue;
		dfs(nxt,node);
	}

	if(recordsEdges){
		for(auto factor : facs[x[node]]){
			lastOcc[factor].pop_back();
		}
	}
}
// Best value obtainable in the auxiliary tree currently stored in gg[]:
// for each node, the larger of v2[node] and the total of its children's results.
ll dfs1(int node,int p = -1){
	ll childTotal = 0;
	for(auto child : gg[node]){
		if(child == p)continue;
		childTotal += dfs1(child,node);
	}
	const ll own = v2[node];
	return max(childTotal,own);
}
// Same recurrence as dfs1, but it also records the decisions in nextNodes[]:
// every non-zero node is appended to nextNodes[p], and a node whose own value
// is at least the total of its children drops its recorded children, so it
// ends up with an empty nextNodes[] list.
ll dfs2(int node,int p = -1){
	ll childTotal = 0;
	for(auto child : gg[node]){
		if(child == p)continue;
		childTotal += dfs2(child,node);
	}
	const ll own = v2[node];
	if(node == 0)return max(childTotal,own);

	nextNodes[p].pb(node);
	if(own >= childTotal){
		nextNodes[node].clear();
	}
	return max(childTotal,own);
}

// Divisor sieve: afterwards facs[j] lists every divisor of j (1 and j included)
// in increasing order, for all j below MAXN.
void precalc(){
	int d = 1;
	while(d < MAXN){
		for(int multiple = d;multiple < MAXN;multiple += d){
			facs[multiple].pb(d);
		}
		++d;
	}
}

// Rebuilds the auxiliary tree of `factor` inside gg[], relabelling its nodes
// with consecutive ids starting at 1 (id 0 is the virtual root).
// Returns the number of gg[] slots in use, i.e. (number of tree nodes) + 1.
static int buildFactorTree(int factor){
	mapValue[0] = 0;
	revMap[0] = 0;
	int id = 1;
	for(auto e : allGraphs[factor]){
		// map node numbers to small consecutive ids so that plain arrays can be used instead of a map
		mapValue[e.ff] = id;
		revMap[id] = e.ff;
		v2[id] = v[e.ff];
		gg[mapValue[e.ss]].pb(mapValue[e.ff]);
		id++;
	}
	return id;
}

// Reads one test case from cin (n, the n-1 edges, then x and v for every node)
// and prints three lines: "<G*V> <G>", the number of chosen nodes, and the
// chosen nodes themselves.
void solve(){

	int n;
	cin >> n;
	FOR(i,n+1){
		g[i].clear();
		gg[i].clear();

		nextNodes[i].clear();
	}
	FOR(i,n-1){
		int a,b;
		cin >> a >> b;
		g[a].pb(b);
		g[b].pb(a);
	}

	// Every factor of some x[i], listed exactly once and in order of first appearance.
	// lastOcc[e] is empty between test cases, so it doubles as the "already
	// seen" marker. Listing a factor once per node would make the loop below
	// rebuild and re-solve the same factor tree once per node, which is
	// quadratic when many nodes share a factor.
	vi usedFactors;
	FOR(i,n){
		cin >> x[i+1] >> v[i+1];
		for(auto e : facs[x[i+1]]){
			if(!lastOcc[e].empty())continue;
			usedFactors.pb(e);
			lastOcc[e].push_back(0); // sentinel: virtual root 0 is the parent of the topmost nodes
			allGraphs[e].clear();
		}
	}

	// node 0 is a virtual root placed above the real root 1
	g[0].pb(1);

	dfs(0);
	for(auto i : usedFactors)lastOcc[i].pop_back();


	ll best = 0;
	int bestid = 0;
	// this is to loop over all the factor trees.
	for(auto i : usedFactors){
		if(i == 1)continue; // the gcd has to be greater than 1
		if(allGraphs[i].size() == 0)continue;

		int id = buildFactorTree(i);
		ll val = dfs1(0);
		if(val*i > best){
			best = val*i;
			bestid = i;
		}
		FOR(j,id)gg[j].clear();
	}


	// recreate the best tree and record which nodes the DP picks
	int id = buildFactorTree(bestid);

	dfs2(0);
	FOR(j,id)gg[j].clear();
	vi allNodes;
	queue<int> q;
	q.push(0);
	ll V = 0;
	int G = bestid;
	while(!q.empty()){
		int nextNode = q.front();q.pop();

		// a node with no recorded successors is one of the chosen nodes
		if(nextNodes[nextNode].size() == 0){
			V += v2[nextNode];
			allNodes.pb(revMap[nextNode]);
		}
		for(auto e : nextNodes[nextNode])q.push(e);
	}


	cout << G*V << " " << G << endl;
	cout << allNodes.size() << endl;
	for(auto e : allNodes){
		if(e> 0)cout << e << " ";
	}
	cout << endl;

}

// Entry point: builds the divisor table once, then handles each test case.
signed main(){
	precalc();

	ios_base::sync_with_stdio(false);
	cin.tie(nullptr);
	cout.tie(nullptr);

	int t;
	cin >> t;
	for(int remaining = t;remaining != 0;--remaining){
		solve();
	}
	return 0;
}

However, this TLEs by a large margin, probably due to the high constant factor of calling so many DFS’s, as well as using push_back and pop_back on vectors so many times.

Observation 3

Instead of constructing each tree separately, if we could do all the process simultaneously, it would be awesome right?

Suboptimal Solution 2

The first thought that comes to mind is to maintain, for each node, a map keyed by its factors (we obviously cannot keep a full array of factors per node, due to memory constraints). Then, while doing the DFS, if for every node p and each factor f of X[p] we know the closest ancestor of p that also has the factor f, we can run the DP for all factors simultaneously. But maps are costly in terms of efficiency: they contribute an additional \log X_{max} factor, and this TLEs.

Full Solution
hint:

Instead of maintaining maps at each node, why not maintain a global array?

In Detail:

We will have a global array A[.] with the invariant that when we reach a node p, A[f] contains, for the factor f, the best possible answer that uses no ancestor of p. Next, we call DFS on all of p's children. After these calls, A[f] additionally accounts for the best possible choice inside p's subtree (p itself excluded). Now, as in the normal DP, for every factor f of X[p] we have to choose either the sum of values coming from p's subtree or p itself. This is easy to do.

Even more details

For every factor f of X[p], remember the value of A[f] at the moment we enter p for the first time; call it val_{f:prev}. After the DFS on all the children is complete, let the value in A[f] be val_{f:curr}. Taking p itself gives val_{f:prev} + v_p, while keeping the choice made inside p's subtree gives val_{f:curr}; so when we exit node p we only need to compare these two and keep the larger one in A[f].
While reconstructing the solution we also need to keep track of the chosen nodes, which can be done in a similar fashion. See the Setter's code for clarity.


Time Complexity:

As discussed, it is O(NX_{max}^\frac{1}{3}).

SOLUTIONS:

Setter's Solution
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;
using namespace std;
 
// Debug printer for vectors: writes "[ ", then every element followed by a
// single space, then "]".
// The vector is taken by const reference and its elements are visited by
// reference, so printing a container no longer copies it.
template<class T> ostream& operator<<(ostream &os, const vector<T>& V) {
	os << "[ ";
	for(const auto& v : V) os << v << " ";
	return os << "]";
}
// Debug printer for pairs: writes "(first,second)".
// The pair is taken by const reference so that printing does not copy it.
template<class L, class R> ostream& operator<<(ostream &os, const pair<L,R>& P) {
	os << "(" << P.first << "," << P.second << ")";
	return os;
}
 
#define TRACE
#ifdef TRACE
#define trace(...) __f(#__VA_ARGS__, __VA_ARGS__)
// Last step of trace(): prints "<name> : <value>", ends the line and flushes.
template <typename Value>
void __f(const char* name, Value&& arg1){
	cout << name;
	cout << " : ";
	cout << arg1;
	cout << std::endl;
}
// Recursive step of trace(): `names` holds the comma-separated text of the
// remaining arguments. Prints the text up to the next comma together with the
// first value, followed by " | ", and then handles the rest.
template <typename Head, typename... Tail>
void __f(const char* names, Head&& arg1, Tail&&... args){
	// the search starts at names + 1, so a comma in the very first position is not taken as the separator
	const char* comma = strchr(names + 1, ',');
	const auto nameLength = comma - names;
	cout.write(names, nameLength);
	cout << " : " << arg1 << " | ";
	__f(comma + 1, args...);
}
#else
#define trace(...) 1
#endif
 
 
#define ll long long
#define ld long double
#define vll vector<ll>
#define pll pair<ll,ll>
#define vpll vector<pll>
#define I insert 
#define pb push_back
#define F first
#define S second
#define endl "\n"
#define vi vector<int>
#define pii pair<int, int>
#define vpii vector< pii >
 
 
// const int mod=1e9+7;
// inline int mul(int a,int b){return (a*1ll*b)%mod;}
// inline int add(int a,int b){a+=b;if(a>=mod)a-=mod;return a;}
// inline int sub(int a,int b){a-=b;if(a<0)a+=mod;return a;}
// inline int power(int a,int b){int rt=1;while(b>0){if(b&1)rt=mul(rt,a);a=mul(a,a);b>>=1;}return rt;}
// inline int inv(int a){return power(a,mod-2);}
// inline void modadd(int &a,int &b){a+=b;if(a>=mod)a-=mod;} 
 
const int M = 1e5+5;
vi fac[M];
 
// Divisor sieve: afterwards fac[j] lists every divisor of j (1 and j included)
// in increasing order, for all j below M.
void pre(){
	int d = 1;
	while(d < M){
		for(int multiple = d; multiple < M; multiple += d){
			fac[multiple].pb(d);
		}
		++d;
	}
}
 
// Test-case counter; solve() increments it once per call.
int ty = 0;
// cur_ty[z] is the value of ty at which sumV[z] and lst[z] were last reset by check().
int cur_ty[M];
// sumV[z]: running best sum of v for divisor z; dfs() raises it when taking the current node is better.
ll sumV[M];
// lst[z]: {node, index of z in fac[x[node]]} for the node that most recently raised sumV[z];
// {-1, -1} right after a reset.
pii lst[M];
// Adjacency lists of the tree.
vi g[M];
// Per-node inputs read by solve().
int v[M], x[M];
// incV[c][i]: sumV[d] as seen when dfs(c) started, plus v[c], where d = fac[x[c]][i].
vll incV[M];
// incLst[c][i]: lst[d] as seen when dfs(c) started, where d = fac[x[c]][i].
vpii incLst[M];
// Visited flags for dfs(); cleared by solve().
bool vis[M];
// Best {sumV[a] * a, a} seen so far over divisors a != 1.
pll Ans; 
 
// Lazily resets the per-divisor state of z the first time z is touched in the
// current test case (identified by ty); does nothing on later calls.
inline void check(int z){
	if(cur_ty[z] == ty) return;
	cur_ty[z] = ty;
	sumV[z] = 0;
	lst[z] = {-1, -1};
}
 
 
// Processes node c for all divisors of x[c] at once.
// On entry it snapshots, for each divisor d, the value "sumV[d] + v[c]" and
// the current lst[d]. After all unvisited neighbours have been processed it
// compares that snapshot with the current sumV[d]; if taking c is strictly
// better, sumV[d] and lst[d] are overwritten and Ans is updated (for d != 1).
void dfs(int c){
	vis[c] = true;
	incV[c].clear();
	const auto& divs = fac[x[c]];
	for(int d : divs){
		check(d);
		incV[c].pb(sumV[d]+v[c]);
		incLst[c].pb(lst[d]);
	}
	for(int nb : g[c]){
		if(vis[nb]) continue;
		dfs(nb);
	}
	const int cnt = divs.size();
	for(int i=0;i<cnt;i++){
		const int a = divs[i];
		if(incV[c][i] <= sumV[a]) continue; // what is already recorded is at least as good
		sumV[a] = incV[c][i];
		lst[a] = {c, i};
		if(a!=1) Ans = max(Ans, make_pair(sumV[a]*1ll*a, (ll)a) );
	}
}
 
// Reads one test case from cin, runs dfs from node 1, and prints
// "<best value> <divisor>", the size of the chosen subset, and its nodes.
// The subset is recovered by following lst[] / incLst[] until {-1, -1}.
void solve(){
	int N;
	cin>>N;
	ty++;
	for(int node=0;node<=N;node++){
		vis[node] = false;
		g[node].clear();
		incV[node].clear();
		incLst[node].clear();
	}
	for(int edge=0;edge<N-1;edge++){
		int a, b;
		cin>>a>>b;
		g[a].pb(b);
		g[b].pb(a);
	}
	for(int node=1;node<=N;node++){
		cin>>x[node]>>v[node];
	}

	Ans = {0, 0};
	dfs(1);
	cout<<Ans.first<<" "<<Ans.second<<"\n";
	assert(Ans.second > 1 && Ans.first > 0);

	vi subset;
	for(pii cur = lst[Ans.second]; cur.first != -1; cur = incLst[cur.first][cur.second]){
		subset.pb(cur.first);
	}

	cout<<subset.size()<<"\n";
	for(int node : subset) cout<<node<<" ";
	cout<<"\n";
}
 
 
// Entry point: configures the streams, builds the divisor table once, then
// handles each test case.
int main(){
	ios_base::sync_with_stdio(false);
	cin.tie(0);
	cout.tie(0);
	cout<<setprecision(25);

	pre();

	int T;
	cin>>T;
	for(int remaining = T; remaining != 0; --remaining){
		solve();
	}
}
Tester's Solution
//#pragma comment(linker, "/stack:200000000")
//#pragma GCC optimize("Ofast")
//#pragma GCC target("sse,sse2,sse3,ssse3,sse4,avx,avx2")
 
#include <bits/stdc++.h>
#include <vector>
#include <set>
#include <map>
#include <string> 
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <utility>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <iomanip> 
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp> 
//setbase - cout << setbase (16)a; cout << 100 << endl; Prints 64
//setfill -   cout << setfill ('x') << setw (5); cout << 77 <<endl;prints xxx77
//setprecision - cout << setprecision (14) << f << endl; Prints x.xxxx
//cout.precision(x)  cout<<fixed<<val;  // prints x digits after decimal in val
 
using namespace std;
using namespace __gnu_pbds;
#define f(i,a,b) for(i=a;i<b;i++)
#define rep(i,n) f(i,0,n)
#define fd(i,a,b) for(i=a;i>=b;i--)
#define pb push_back
#define mp make_pair
#define vi vector< int >
#define vl vector< ll >
#define ss second
#define ff first
#define ll long long
#define pii pair< int,int >
#define pll pair< ll,ll >
#define sz(a) a.size()
#define inf (1000*1000*1000+5)
#define all(a) a.begin(),a.end()
#define tri pair<int,pii>
#define vii vector<pii>
#define vll vector<pll>
#define viii vector<tri>
#define mod (1000*1000*1000+7)
#define pqueue priority_queue< int >
#define pdqueue priority_queue< int,vi ,greater< int > >
//#define int ll
 
// Order-statistics tree over int keys (GNU policy-based data structures).
// None of the functions in this listing refers to ordered_set.
typedef tree<
int,
null_type,
less<int>,
rb_tree_tag,
tree_order_statistics_node_update>
ordered_set;
 
 
//std::ios::sync_with_stdio(false);
 
// x[], v[]: per-node inputs (0-based); divi: divisor chosen by main() for the final answer.
int par[100005],x[100005],v[100005],divi;
// Best value of (sum * divisor) found by main() for the current test case.
ll maxx;
// last[d]: {node, index of d in divisor[x[node]]} for the most recently entered node
// that has divisor d; main() initialises it to {N, d}.
pii last[100005];
// adj: adjacency lists; divisor[j]: all divisors of j that are >= 2.
vector<vi>adj(100005),divisor(100005);
// sum[u][i]: accumulated value for the i-th divisor of x[u]; row N collects the totals per divisor.
vector<vl>sum(100005);
// res[0..ans-1]: nodes chosen by dfs1().
int res[100005];
int ans=0;
int N=100002;
// act[value] == current test number  <=>  this x value was already handled in main()'s scan.
int act[100005];
// sel[u] == 1  <=>  solve() decided to take node u itself.
int sel[100005];
// dp[u] accumulates sums of v over a subtree. With up to 100000 nodes of
// value up to 100000 the total can reach about 1e10, which does not fit in
// a 32-bit int, so a 64-bit type is required here.
long long dp[100005];
int dfs(int u,int pa){
	int i;
	vii temp;
	temp.resize(divisor[x[u]].size());
	sum[u].resize(divisor[x[u]].size());
	rep(i,divisor[x[u]].size()){
		sum[u][i]=0;
		temp[i]=last[divisor[x[u]][i]];
		last[divisor[x[u]][i]]=mp(u,i);
	}
	rep(i,adj[u].size()){
		if(adj[u][i]!=pa){
			dfs(adj[u][i],u);
			//child++;
		}
	}
	pii p;
	int q;
	rep(i,divisor[x[u]].size()){
		ll val=max((ll)v[u],sum[u][i]);
		q=divisor[x[u]][i];
		last[q]=temp[i];
		sum[last[q].ff][last[q].ss]+=val;
	}
	return 0;
}
// Tree DP for the fixed divisor d: dp[u] becomes the total of the children's
// dp values, unless d divides x[u] and v[u] is at least that total, in which
// case u itself is taken (dp[u] = v[u], sel[u] = 1). Always returns 0.
int solve(int u,int p,int d){
	dp[u]=0;
	sel[u]=0;
	for(int k=0;k<(int)adj[u].size();k++){
		const int child=adj[u][k];
		if(child==p)continue;
		solve(child,u,d);
		dp[u]+=dp[child];
	}
	const bool usable=(x[u]%d==0);
	if(usable&&v[u]>=dp[u]){
		dp[u]=v[u];
		sel[u]=1;
	}
	return 0;
}
// Collects the chosen nodes into res[]: walks down from u and stops at the
// first node on each path that has sel == 1, appending it to res[ans++].
// Always returns 0. The original fell off the end of this non-void function
// on the non-selected path, which is undefined behaviour.
int dfs1(int u,int p){
	if(sel[u]==1){
		res[ans++]=u;
		return 0;
	}
	int i;
	rep(i,adj[u].size()){
		if(adj[u][i]!=p){
			dfs1(adj[u][i],u);
		}
	}
	return 0;
}
int main(){
	//std::ios::sync_with_stdio(false); cin.tie(NULL);
	int t,i,j,iter=0;
	sum[N].resize(N+2);
	for(i=2;i<N;i++){
		for(j=i;j<N;j+=i){
			divisor[j].pb(i);
		}
		last[i].ff=N;
		last[i].ss=i;
	}
	scanf("%d",&t);
	while(t--){
		int n,u,vv;
		iter++;
		scanf("%d",&n);
		rep(i,n){
			adj[i].clear();
		}
		ans=0;
		rep(i,n-1){
			scanf("%d %d",&u,&vv);
			u--;
			vv--;
			adj[u].pb(vv);
			adj[vv].pb(u);
		}
		rep(i,n){
			scanf("%d %d",&x[i],&v[i]);
		}
		ll val;
		maxx=0;
		dfs(0,-1);
		rep(i,n){
			if(act[x[i]]==iter){
				continue;
			}
			act[x[i]]=iter;
			rep(j,divisor[x[i]].size()){
				val=sum[N][divisor[x[i]][j]];
				if(val*divisor[x[i]][j]>maxx){
					maxx=val*divisor[x[i]][j];
					divi=divisor[x[i]][j];
				}
				sum[N][divisor[x[i]][j]]=0;
			}
		}
		//return 0;
		solve(0,-1,divi);
		//return 0;
		dfs1(0,-1);
		//return 0;
		printf("%lld %d\n",maxx,divi);
		printf("%d\n",ans);
		rep(i,ans){
			printf("%d ",res[i]+1);
		}
		printf("\n");
	}
	return 0;
} 
	

Please give me suggestions if anything is unclear so that I can improve. Thanks :slight_smile:

4 Likes

The tough part was to print the nodes too. I was unable to come up how to do that (facepalm).

very well written editorial

Another excellent question! Enjoyed solving it post contest.

@rajarshi_basu you are the best editorialist!!
Hope you do write editorials for future contests too

1 Like

Haha, thanks.