WARTLND - Editorial

PROBLEM LINK:

Practice

Contest: Division 1

Contest: Division 2

Setter: Anik Sarker, Ezio Auditore

Tester: Teja Vardhan Reddy

Editorialist: Taranpreet Singh

DIFFICULTY:

PREREQUISITES:

Inclusion-Exclusion, Union-Find (Disjoint Set Union), Divisors.

PROBLEM:

Given a tree with N nodes and weighted edges, let f(u, v) be the gcd of the edge weights on the path from u to v. We need to find the value of the summation \sum_{i = 1}^{N}\sum_{j = i+1}^{N}f(i, j), that is, the sum of f over all unordered pairs of distinct nodes.
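For reference, the required sum can be computed directly in O(N^2) by running a DFS from every node and carrying the gcd of the path so far. This is only a brute-force sketch for checking a fast solution on small inputs; the `Edge` struct, the function name and the 0-indexed vertices are assumptions of this sketch, not the problem's input format.

```cpp
#include <array>
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

struct Edge { int u, v, w; };  // hypothetical: 0-indexed endpoints, weight w

// Brute force: sum of gcd(path weights) over all unordered pairs i < j.
// Runs an iterative DFS from every start vertex s, so it is O(N^2).
int64_t bruteForceSum(int n, const std::vector<Edge>& edges) {
    std::vector<std::vector<std::pair<int, int>>> adj(n);  // (neighbor, weight)
    for (const Edge& e : edges) {
        adj[e.u].push_back({e.v, e.w});
        adj[e.v].push_back({e.u, e.w});
    }
    int64_t total = 0;
    for (int s = 0; s < n; ++s) {
        // Stack entries: {vertex, parent, gcd of the path from s}; gcd(0, w) == w.
        std::vector<std::array<int, 3>> stack;
        stack.push_back({s, -1, 0});
        while (!stack.empty()) {
            auto [u, p, g] = stack.back();
            stack.pop_back();
            if (u > s) total += g;  // count each unordered pair exactly once
            for (auto [v, w] : adj[u])
                if (v != p) stack.push_back({v, u, std::gcd(g, w)});
        }
    }
    return total;
}
```

For example, for the tree with edges (1, 2, 6), (2, 3, 4), (2, 4, 9) (written 1-indexed), the six path gcds are 6, 2, 3, 4, 9 and 1, so the answer is 25.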

EXPLANATION

Let us first consider a simpler problem: for each x in the range [1, 10^5], count the number of pairs of cities such that every edge on the path between them has weight divisible by x.

Suppose we are computing this for a fixed value x. Only edges whose weight is divisible by x matter, so we need to count the pairs of cities that can reach each other using only these edges. Given the cities and this subset of edges, each connected component with z cities contains exactly (z*(z-1))/2 pairs of cities reachable from each other.
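As a sketch of this subproblem for a single x, the function below keeps only the edges whose weight is divisible by x, measures each connected component with BFS, and adds (z*(z-1))/2 per component. It costs O(N) per value of x, so it only illustrates the counting; the function name and the `Edge` struct are hypothetical.

```cpp
#include <cstdint>
#include <queue>
#include <vector>

struct Edge { int u, v, w; };  // hypothetical: 0-indexed endpoints, weight w

// Naive count for one x: pairs of cities connected using only edges whose
// weight is divisible by x. Each component of size z adds z*(z-1)/2.
int64_t countPairsDivisibleBy(int n, const std::vector<Edge>& edges, int x) {
    std::vector<std::vector<int>> adj(n);
    for (const Edge& e : edges) {
        if (e.w % x == 0) {  // keep only edges divisible by x
            adj[e.u].push_back(e.v);
            adj[e.v].push_back(e.u);
        }
    }
    std::vector<bool> seen(n, false);
    int64_t total = 0;
    for (int s = 0; s < n; ++s) {
        if (seen[s]) continue;
        // BFS from s to measure the size z of its component.
        int64_t z = 0;
        std::queue<int> q;
        q.push(s);
        seen[s] = true;
        while (!q.empty()) {
            int u = q.front();
            q.pop();
            ++z;
            for (int v : adj[u]) {
                if (!seen[v]) {
                    seen[v] = true;
                    q.push(v);
                }
            }
        }
        total += z * (z - 1) / 2;  // size-1 components add 0
    }
    return total;
}
```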

So all we need are the sizes of the connected components. Components of size 1 contribute zero pairs, so we can ignore them.

We use a disjoint set union (DSU), initialized with N singleton components, and for each selected edge we merge the components of its two endpoints. An important observation is that every component of size greater than one is rooted at an endpoint of one of the selected edges. So we only need to look at these endpoints: for each one that is the root of its component, we add (z*(z-1))/2 to the count of pairs, where z is the size of the component. After that, we must reset the DSU for the next value of x in time proportional to the number of selected edges, by resetting only those endpoints. A separate boolean array marks the roots already counted, so that no component is added twice; it is cleared together with the DSU. Refer to the implementations below if this is unclear.
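A minimal C++ sketch of this idea, assuming 0-indexed vertices and edge weights in [1, maxW]. It follows the bucket-by-weight loop used by Setter 2 and the editorialist (for each x, visit the edges whose weight is a multiple of x) rather than precomputed divisor lists. `ResettableDSU`, `computeG` and the other names are hypothetical, and the union by size and path halving are design choices of this sketch, not taken from the posted solutions.

```cpp
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

struct Edge { int u, v, w; };  // hypothetical: 0-indexed endpoints, 1 <= w <= maxW

// Union-find that remembers which vertices it touched, so that it can be
// restored to N singletons in time proportional to the work done, not O(N).
struct ResettableDSU {
    std::vector<int> parent, sz;
    std::vector<int> touched;   // edge endpoints seen since the last reset
    std::vector<char> counted;  // marks roots already summed (the boolean array)

    explicit ResettableDSU(int n) : parent(n), sz(n, 1), counted(n, 0) {
        std::iota(parent.begin(), parent.end(), 0);
    }
    int find(int u) {
        while (parent[u] != u) {
            parent[u] = parent[parent[u]];  // path halving
            u = parent[u];
        }
        return u;
    }
    void unite(int a, int b) {
        touched.push_back(a);
        touched.push_back(b);
        a = find(a);
        b = find(b);
        if (a == b) return;  // never happens for tree edges
        if (sz[a] < sz[b]) std::swap(a, b);
        parent[b] = a;
        sz[a] += sz[b];
    }
    // Every component of size > 1 is rooted at a touched vertex, so scanning
    // the touched list finds all roots that contribute z*(z-1)/2 pairs.
    int64_t sumPairsAndReset() {
        int64_t total = 0;
        for (int t : touched) {
            int r = find(t);
            if (!counted[r]) {
                counted[r] = 1;
                total += (int64_t)sz[r] * (sz[r] - 1) / 2;
            }
        }
        for (int t : touched) { parent[t] = t; sz[t] = 1; counted[t] = 0; }
        touched.clear();
        return total;
    }
};

// g[x] = number of pairs whose path uses only edges with weight divisible by x.
std::vector<int64_t> computeG(int n, const std::vector<Edge>& edges, int maxW) {
    std::vector<std::vector<int>> byWeight(maxW + 1);  // edge ids grouped by weight
    for (int i = 0; i < (int)edges.size(); ++i) byWeight[edges[i].w].push_back(i);

    ResettableDSU dsu(n);
    std::vector<int64_t> g(maxW + 1, 0);
    for (int x = 1; x <= maxW; ++x) {
        for (int m = x; m <= maxW; m += x)  // weights that are multiples of x
            for (int id : byWeight[m]) dsu.unite(edges[id].u, edges[id].v);
        g[x] = dsu.sumPairsAndReset();
    }
    return g;
}
```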

This pretty much sums it up. We consider every value of x, take only the edges whose weight is a multiple of x, and use the DSU to obtain the number of pairs whose path gcd is a multiple of x.

Coming back to the original problem, let f(x) denote the number of pairs (u, v) such that the gcd on the path from u to v is exactly x, and let g(x) denote the number of pairs (u, v) such that this gcd is a multiple of x. (Here f is reused as a count, not as the path gcd from the problem statement.) Using the above process, we have computed g(x) for all values of x.

Every pair whose path gcd is a multiple of x has gcd exactly k*x for a unique k >= 1, so g(x) = \sum_{k \ge 1} f(k*x) = f(x) + \sum_{k \ge 2} f(k*x). Also, f(x) = 0 for x > MX, where MX = 10^5 is the maximum edge weight. Rearranging gives f(x) = g(x) - \sum_{k = 2}^{\lfloor MX/x \rfloor} f(k*x). We can compute this for x from MX down to 1, since f(x) depends only on g(x) and on f(y) with y > x. Iterating over all multiples of x up to MX takes MX/x iterations, for a total of MX*(1 + 1/2 + 1/3 + \ldots + 1/MX) iterations, which is of the order of MX*ln(MX).
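The recurrence can be sketched as follows, assuming `g` is indexed by x with index 0 unused (the function name is hypothetical). Because x is processed in decreasing order, every f(y) with y > x is already final when it is subtracted; this mirrors the loop in Setter 2's and the editorialist's solutions, while Setter 1 obtains the same values with Möbius inversion, f(x) = \sum_{k \ge 1} \mu(k) g(k*x).

```cpp
#include <cstdint>
#include <vector>

// g[x]: pairs whose path gcd is a multiple of x; returns f[x]: pairs whose
// path gcd is exactly x. Index 0 is unused.
// Uses f(x) = g(x) - sum_{k>=2, k*x<=MX} f(k*x), largest x first.
std::vector<int64_t> exactFromMultiples(const std::vector<int64_t>& g) {
    int mx = (int)g.size() - 1;
    std::vector<int64_t> f(g);
    for (int x = mx; x >= 1; --x)
        for (int y = 2 * x; y <= mx; y += x)
            f[x] -= f[y];  // f[y] is already final because y > x
    return f;
}
```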

After knowing f(x) for each x in the range [1, 10^5], the final answer is simply \sum_{i = 1}^{MX} i*f(i), since each pair of cities contributes the gcd of the path between them to the summation.
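A small sketch of this last step (the function name is hypothetical). On the tree with edges (1, 2, 6), (2, 3, 4), (2, 4, 9), the path gcds 1, 2, 3, 4, 6 and 9 each occur exactly once, so f(1) = f(2) = f(3) = f(4) = f(6) = f(9) = 1, every other f is 0, and the sum is 1 + 2 + 3 + 4 + 6 + 9 = 25.

```cpp
#include <cstdint>
#include <vector>

// f[i] = number of pairs whose path gcd is exactly i (index 0 unused).
// Each such pair contributes i to the required sum.
int64_t answerFromExact(const std::vector<int64_t>& f) {
    int64_t ans = 0;
    for (int i = 1; i < (int)f.size(); ++i) ans += (int64_t)i * f[i];
    return ans;
}
```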

That wraps up this problem.

TIME COMPLEXITY

Each edge is processed once for every divisor of its weight, and no value in the range [1, 10^5] has more than 128 divisors (for example, 83160 has exactly 128), so this part takes at most 128*(N-1) edge operations.
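The 128 bound can be checked with a harmonic-series sieve that adds 1 to every multiple of each d. This is the same O(MX log MX) enumeration that builds divisor lists such as `D[c]` in Setter 1's solution; the function names here are hypothetical.

```cpp
#include <algorithm>
#include <vector>

// cnt[v] = number of divisors of v for 1 <= v <= maxV (index 0 unused).
// Each d adds 1 to all of its multiples: O(maxV log maxV) in total.
std::vector<int> divisorCounts(int maxV) {
    std::vector<int> cnt(maxV + 1, 0);
    for (int d = 1; d <= maxV; ++d)
        for (int m = d; m <= maxV; m += d) ++cnt[m];
    return cnt;
}

// Largest divisor count among 1..maxV (requires maxV >= 1).
int maxDivisorCount(int maxV) {
    std::vector<int> cnt = divisorCounts(maxV);
    return *std::max_element(cnt.begin() + 1, cnt.end());
}
```

Calling maxDivisorCount(100000) returns 128, which is the constant used in the bound above.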

Each edge operation touches two vertices, so at most 256*(N-1) vertex visits (counting repeats) happen over the whole process, and resetting the DSU takes time proportional to this.

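So the overall time complexity is O(128*N) DSU operations (about 256*(N-1) find calls, counting each find as one operation), plus O(MX*log(MX)) for iterating over the multiples of every x. The inclusion-exclusion step is also O(MX*log(MX)), so for large N it is dominated by the DSU work.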

SOLUTIONS:

Setter 1 Solution
#include<bits/stdc++.h>
using namespace std;
#define ll long long int
const int MAX = 100005;
int mobius[MAX];
vector<int> D[MAX];
ll g[MAX], f[MAX];
int Par[MAX], Sz[MAX];
vector< pair<int,int> > Edge[MAX];

int Find(int u){
	if(Par[u] == u) return u;
	return Par[u] = Find(Par[u]);
}

int main(){
	mobius[1] = 1;
	for(int i=1; i<MAX; i++){
	    D[i].push_back(i);
	    for(int j=i+i; j<MAX; j+=i){
	        mobius[j] -= mobius[i];
	        D[j].push_back(i);
	    }
	}


	int t;
	scanf("%d",&t);
	assert(1<=t && t<=10);

	for(int cs=1; cs<=t; cs++) {
	    int n;
	    scanf("%d",&n);
	    assert(1<=n && n<=100000);

	    for(int i=1; i<MAX; i++) Edge[i].clear();
	    for(int i=1; i<=n; i++) Par[i] = i, Sz[i] = 1;
	    for(int i=1; i<n; i++){
	        int u,v,c;
	        scanf("%d %d %d",&u,&v,&c);

	        assert(1<=u && u<=n);
	        assert(1<=v && v<=n);
	        assert(1<=c && c<=100000);

	        for(int d : D[c]) Edge[d].push_back({u,v});
	    }

	    for(int i=1; i<MAX; i++){
	        g[i] = 0;
	        vector<int> rollBack;

	        for(auto edge : Edge[i]){
	            int u = edge.first;
	            int v = edge.second;
	            rollBack.push_back(u);
	            rollBack.push_back(v);

	            u = Find(u);
	            v = Find(v);
	            assert(u != v);

	            g[i] += Sz[u] * 1LL * Sz[v];
	            Par[v] = u;
	            Sz[u] += Sz[v];
	        }
	        for(int x : rollBack) Par[x] = x, Sz[x] = 1;
	    }
	    assert(g[1] + g[1] == n * 1LL * (n-1));

	    for(int i=1; i<MAX; i++){
	        f[i] = 0;
	        for(int y=i; y<MAX; y+=i){
	            f[i] += mobius[y/i] * g[y];
	        }
	    }

	    ll Sum = 0;
	    for(int i=1; i<MAX; i++) Sum += f[i] * i;
	    printf("%lld\n",Sum);
	}
}
Setter 2 Solution
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll maxn = 100009;

ll t, cs = 1;
vector < pair < ll , ll > > perval[maxn];
ll par[maxn], sz[maxn];
ll F[maxn];

ll fndpar(ll pos)
{
	if(par[pos] == pos) return pos;
	return par[pos] = fndpar(par[pos]);
}

int main()
{
//    freopen("input05.txt", "r", stdin);
//    freopen("output05.txt", "w", stdout);

	cin >> t;
	if(t < 1 || t > 10) assert(false);

	while(t--){
	    ll n;
	    scanf("%lld", &n);
	    if(n < 1 || n > 100000) assert(false);
	    for(ll i = 1; i < maxn; i++) perval[i].clear();
	    for(ll i = 1; i <= n; i++) par[i] = i, sz[i] = 1;

//        cout << "yo " << endl;

	    for(ll i = 1; i < n; i++){
	        ll x, y, z;
	        scanf("%lld %lld %lld", &x, &y, &z);
	        if(x < 1 || x > n || y < 1 || y > n) assert(false);
	        if(z < 1 || z > 100000) assert(false);

	        perval[z].push_back({x, y});

	        x = fndpar(x);
	        y = fndpar(y);

	        if(x == y){
	            assert(false);
	        }

	        if(x < 1 || x > n || y < 1|| y > n) assert(false);

	        par[y] = x;
	        sz[x] += sz[y];
	    }
	    ll xx = fndpar(1);
	    if(sz[xx] != n){
	        assert(false);
	    }


	    for(ll i = 1; i <= n; i++) par[i] = i, sz[i] = 1;

	    vector < ll > curpars;
	    ll ans = 0;

	    for(ll i = 100000; i >= 1; i--){
	        curpars.clear();
	        F[i] = 0;
	        for(ll j = i; j <= 100000; j += i){

	            for(auto e : perval[j]){
	                ll x = e.first;
	                ll y = e.second;
	                curpars.push_back(x);
	                curpars.push_back(y);
	                x = fndpar(x);
	                y = fndpar(y);

	                if(x == y) continue;
	                par[y] = x;
	                F[i] += sz[x] * sz[y];
	                sz[x] += sz[y];
	            }
	        }

	        for(ll j = i + i; j <= 100000; j += i) F[i] -= F[j];
	        for(ll p : curpars) par[p] = p, sz[p] = 1;

	        ans += i * F[i];

	    }

	    printf("%lld\n", ans);

	}

	return 0;
}
Tester's Solution
//teja349
#include <bits/stdc++.h>
#include <vector>
#include <set>
#include <map>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <utility>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <iomanip>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp> 
//setbase - cout << setbase (16); cout << 100 << endl; Prints 64
//setfill -   cout << setfill ('x') << setw (5); cout << 77 << endl; prints xxx77
//setprecision - cout << setprecision (14) << f << endl; Prints x.xxxx
//cout.precision(x)  cout<<fixed<<val;  // prints x digits after decimal in val

using namespace std;
using namespace __gnu_pbds;

#define f(i,a,b) for(i=a;i<b;i++)
#define rep(i,n) f(i,0,n)
#define fd(i,a,b) for(i=a;i>=b;i--)
#define pb push_back
#define mp make_pair
#define vi vector< int >
#define vl vector< ll >
#define ss second
#define ff first
#define ll long long
#define pii pair< int,int >
#define pll pair< ll,ll >
#define sz(a) a.size()
#define inf (1000*1000*1000+5)
#define all(a) a.begin(),a.end()
#define tri pair<int,pii>
#define vii vector<pii>
#define vll vector<pll>
#define viii vector<tri>
#define mod (1000*1000*1000+7)
#define pqueue priority_queue< int >
#define pdqueue priority_queue< int,vi ,greater< int > >
#define flush fflush(stdout) 
#define primeDEN 727999983
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());

// find_by_order()  // order_of_key
typedef tree<
int,
null_type,
less<int>,
rb_tree_tag,
tree_order_statistics_node_update>
ordered_set;
#define int ll
int ans[123456];
int x[123456],y[123456],z[123456];
vector<vi> adj(123456),facto(123456);
vi vec;
int dsu[123456];
int paren(int u){
	//cout<<"dsa"<<endl;
	if(dsu[u]<0)
		return u;
	dsu[u]=paren(dsu[u]);
	return dsu[u];
}
int merge(int u,int v){
	u=paren(u);
	v=paren(v);
	if(dsu[u]==-1)
		vec.pb(u);
	if(dsu[v]==-1)
		vec.pb(v);
	if(u==v){
		return 0;
	}
	if(dsu[u]<dsu[v])
		swap(u,v);
	dsu[v]+=dsu[u];
	dsu[u]=v;
	return 0;
}
main(){
	std::ios::sync_with_stdio(false); cin.tie(NULL);
	int t;
	cin>>t;
	int i,j;
	for(i=1;i<=1e5;i++){
		for(j=i;j<=1e5;j+=i){
			facto[j].pb(i);
		}
	}
	rep(i,1e5+5){
		dsu[i]=-1;
	}
	while(t--){
		int n;
		cin>>n;
		int i;
		rep(i,1e5+5){
			adj[i].clear();
		}
		int j,k;
		rep(i,n-1){
			cin>>x[i]>>y[i]>>z[i];
			x[i]--;
			y[i]--;
			adj[z[i]].pb(i);
		}
		int ind;
		int tot=0;
		for(i=1e5;i>=1;i--){
			for(j=i;j<=1e5;j+=i){
				rep(k,adj[j].size()){
					ind = adj[j][k];
					merge(x[ind],y[ind]);
				}
			}
			// /return 0;
			rep(j,vec.size()){
				if(dsu[vec[j]]<0)
					ans[i]+=dsu[vec[j]]*(dsu[vec[j]]+1)/2;
				dsu[vec[j]]=-1;
			}
			vec.clear();
			tot+=ans[i]*i;
			rep(j,facto[i].size()){
				ans[facto[i][j]]-=ans[i];
			}
		}
		cout<<tot<<endl;
	}
	return 0;   
}
Editorialist's Solution
import java.util.*;
import java.io.*;
import java.text.*;
class WARTLND{
	//SOLUTION BEGIN
	//Into the Hardware Mode
	void pre() throws Exception{}
	void solve(int TC) throws Exception{
	    int n = ni(), mx = (int)1e5;
	    int[][][] e = new int[1+mx][][];int[] cnt = new int[1+mx];
	    int[][] ee = new int[n-1][];
	    for(int i = 0; i< n-1; i++){
	        ee[i] = new int[]{ni()-1, ni()-1, ni()};
	        cnt[ee[i][2]]++;
	    }
	    for(int i = 1; i<= mx; i++)e[i] = new int[cnt[i]][];
	    for(int[] i:ee)e[i[2]][--cnt[i[2]]] = new int[]{i[0], i[1]};
	    int[][] set = new int[n][];
	    for(int i = 0; i< n; i++)set[i] = new int[]{i, 1};
	    long[] count = new long[1+mx];
	    int[] q = new int[2*n];
	    int qptr = 0;
	    boolean[] inc = new boolean[n];
	    for(int i = 1; i<= mx; i++){
	        qptr = 0;
	        for(int j = i; j<= mx; j+=i){
	            for(int[] edge:e[j]){
	                int u = find(set, edge[0]), v = find(set, edge[1]);
	                if(u == v)continue;
	                set[u][1] += set[v][1];
	                set[v][0] = u;
	                q[qptr++] = u;
	                q[qptr++] = v;
	            }
	        }
	        count[i] = 0;
	        for(int x = 0; x< qptr; x++){
	            int k = q[x];
	            if(find(set, k) == k && !inc[k])count[i] += ((set[k][1]-1)*(long)set[k][1])/2;
	            inc[k] = true;
	        }
	        for(int x = 0; x< qptr; x++){
	            int k = q[x];
	            set[k] = new int[]{k, 1};
	            inc[k] = false;
	        }
	    }
	    for(int i = mx; i>= 1; i--)
	        for(int j = i+i; j<= mx; j+=i)
	            count[i] -= count[j];
	    long ans = 0;
	    for(int i = 1; i<= mx; i++)
	        ans += i*count[i];
	    pn(ans);
	}
	int find(int[][] set, int u){return set[u][0] = (set[u][0] == u?u:find(set, set[u][0]));}
	//SOLUTION END
	void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
	long IINF = (long)1e18, mod = (long)1e9+7;
	final int INF = (int)1e9, MX = (int)2e5+5;
	DecimalFormat df = new DecimalFormat("0.00000000000");
	double PI = 3.141592653589793238462643383279502884197169399, eps = 1e-6;
	static boolean multipleTC = true, memory = false, fileIO = false;
	FastReader in;PrintWriter out;
	void run() throws Exception{
	    if(fileIO){
	        in = new FastReader("input.txt");
	        out = new PrintWriter("output.txt");
	    }else {
	        in = new FastReader();
	        out = new PrintWriter(System.out);
	    }
	    //Solution Credits: Taranpreet Singh
	    int T = (multipleTC)?ni():1;
	    pre();for(int t = 1; t<= T; t++)solve(t);
	    out.flush();
	    out.close();
	}
	public static void main(String[] args) throws Exception{
	    if(memory)new Thread(null, new Runnable() {public void run(){try{new WARTLND().run();}catch(Exception e){e.printStackTrace();}}}, "1", 1 << 28).start();
	    else new WARTLND().run();
	}
	long gcd(long a, long b){return (b==0)?a:gcd(b,a%b);}
	int gcd(int a, int b){return (b==0)?a:gcd(b,a%b);}
	int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
	void p(Object o){out.print(o);}
	void pn(Object o){out.println(o);}
	void pni(Object o){out.println(o);out.flush();}
	String n()throws Exception{return in.next();}
	String nln()throws Exception{return in.nextLine();}
	int ni()throws Exception{return Integer.parseInt(in.next());}
	long nl()throws Exception{return Long.parseLong(in.next());}
	double nd()throws Exception{return Double.parseDouble(in.next());}
 
	class FastReader{
	    BufferedReader br;
	    StringTokenizer st;
	    public FastReader(){
	        br = new BufferedReader(new InputStreamReader(System.in));
	    }
 
	    public FastReader(String s) throws Exception{
	        br = new BufferedReader(new FileReader(s));
	    }
 
	    String next() throws Exception{
	        while (st == null || !st.hasMoreElements()){
	            try{
	                st = new StringTokenizer(br.readLine());
	            }catch (IOException  e){
	                throw new Exception(e.toString());
	            }
	        }
	        return st.nextToken();
	    }
 
	    String nextLine() throws Exception{
	        String str = "";
	        try{   
	            str = br.readLine();
	        }catch (IOException e){
	            throw new Exception(e.toString());
	        }  
	        return str;
	    }
	}
}

Feel free to share your approach (even if it's the same :stuck_out_tongue:). Suggestions are welcome, as always. :slight_smile:

Corporate wants to know the difference between this editorial and editorial of https://codeforces.com/contest/990/problem/G

Can you elaborate this part?

Yeah, it could be added to the similar problems list… :laughing:

It could be added to a “same problem” list… :stuck_out_tongue:

Initially all components have size 1. Think about when a component gets merged: only when there is an edge connecting it to another component. So the components with size > 1 are exactly the components containing the endpoints of the selected edges.

Please update the setter’s and tester’s solutions.

Updated the solutions now.

Hey, I am not able to understand the solution… I have been trying for a long time… please help…