TREETREE - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: Abhinav Sharma
Testers: Nishank Suresh, Nishant Shah
Editorialist: Nishank Suresh

DIFFICULTY:

2266

PREREQUISITES:

Depth-first search, rerooting

PROBLEM:

You have N edge-weighted trees, and N-1 edges of weights A_1, \ldots, A_{N-1}. Use these N-1 edges to connect the trees into a single large tree while minimizing the sum of distances over all pairs of vertices of the large tree.

EXPLANATION:

This problem has several parts, so let’s go over them one by one.

First, for each input tree, compute the sum of distances between pairs of vertices that both belong to that tree. This can be done with a plain dfs in \mathcal{O}(M_i) for a tree with M_i vertices, using the fact that an edge of weight w between u and v contributes w\cdot sz_u \cdot (M_i - sz_u) to the sum. Here v is the parent of u in the dfs, and sz_u is the number of vertices in the subtree of u.
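The edge-contribution idea above can be sketched as follows. This is a minimal illustration, not the reference solution: the function name `pairwiseDistanceSum` and the 0-indexed adjacency-list layout are assumptions made for the example, and the modular reduction used by the solutions further down is left out for readability.

```cpp
#include <bits/stdc++.h>
using namespace std;

// Sum of distances over all unordered vertex pairs of one weighted tree.
// adj[u] holds (neighbour, edge weight) pairs; vertices are 0-indexed (assumed layout).
// No modulus is applied here, so this is only safe for small inputs.
long long pairwiseDistanceSum(const vector<vector<pair<int, long long>>>& adj) {
    int m = adj.size();
    vector<int> sz(m, 1);
    long long total = 0;
    function<void(int, int)> dfs = [&](int u, int parent) {
        for (auto [v, w] : adj[u]) {
            if (v == parent) continue;
            dfs(v, u);
            sz[u] += sz[v];
            // The edge (u, v) lies on every path between the sz[v] vertices
            // below it and the m - sz[v] vertices on the other side.
            total += w * sz[v] * (m - sz[v]);
        }
    };
    dfs(0, -1);
    return total;
}
```

For the path 0 - 1 - 2 with edge weights 1 and 2, the pairwise distances are 1, 2 and 3, and the function returns 6.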

Now let’s look at what happens when we connect the trees with edges. We have to compute two new values:

  • The contribution of each newly added edge
  • The extra contribution of existing edges in terms of paths to all the other vertices

Subtask 1 (N = 2)

Let the two trees have M_1 and M_2 vertices. The newly added edge, with weight A_1, clearly contributes (M_1 \cdot M_2 \cdot A_1) to the answer. This takes care of the first part.

Now, suppose this edge joins vertex u in the first tree and vertex v in the second. How much do existing edges contribute to the answer?

If you think about it,

  • Every path in T_1 with one endpoint at u is added M_2 times more to the answer
  • Every path in T_2 with one endpoint at v is added M_1 times more to the answer

So, if D_1 denotes the sum of all path lengths in T_1 starting at u, and D_2 denotes the same for v in T_2, the answer increases by D_1 \cdot M_2 + D_2 \cdot M_1.

Now note that D_1 and D_2 are entirely independent. To minimize this sum, it is enough to pick the minimum possible value of D_1 across all u \in T_1, and the minimum possible value of D_2 across all v \in T_2.

To compute this minimum for a given tree, one can use the technique of rerooting. Compute the length of paths starting at a given vertex using a single dfs, then ‘reroot’ the tree at each vertex using another dfs, computing the change in the value each time. This can be implemented in linear time. An explanation of this technique can be found in the first part of this blog.
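A rough sketch of the two-dfs rerooting described above is given here. The name `minDistanceSum` and the adjacency-list layout are assumptions for illustration; the value is kept as an exact `long long` (not reduced modulo anything) because a minimum has to be taken over it.

```cpp
#include <bits/stdc++.h>
using namespace std;

// Minimum, over all vertices u, of the sum of distances from u to every vertex.
// adj[u] holds (neighbour, edge weight) pairs; vertices are 0-indexed (assumed layout).
long long minDistanceSum(const vector<vector<pair<int, long long>>>& adj) {
    int m = adj.size();
    vector<long long> sz(m, 1), down(m, 0);
    // First dfs: sz[u] = subtree size, down[u] = sum of distances from u
    // to the vertices of its own subtree.
    function<void(int, int)> dfs = [&](int u, int p) {
        for (auto [v, w] : adj[u]) {
            if (v == p) continue;
            dfs(v, u);
            sz[u] += sz[v];
            down[u] += down[v] + w * sz[v];
        }
    };
    dfs(0, -1);
    long long best = down[0];
    // Second dfs: moving the root from u to its child v over an edge of weight w
    // brings sz[v] vertices closer by w and pushes the other m - sz[v] away by w.
    function<void(int, int, long long)> reroot = [&](int u, int p, long long cur) {
        best = min(best, cur);
        for (auto [v, w] : adj[u]) {
            if (v == p) continue;
            reroot(v, u, cur + w * (m - 2 * sz[v]));
        }
    };
    reroot(0, -1, down[0]);
    return best;
}
```

On the path 0 - 1 - 2 with edge weights 1 and 2, the three vertices have distance sums 4, 3 and 5, so the function returns 3.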

This solves the problem for N = 2.
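As a small worked example of the N = 2 case, take T_1 to be a single edge of weight 3 (so M_1 = 2), T_2 a single vertex (M_2 = 1), and A_1 = 5. The symbols W_1 and W_2 are notation introduced only for this example: they denote the sums of pairwise distances inside T_1 and T_2.

```latex
\begin{aligned}
W_1 &= 3,\quad W_2 = 0,\quad D_1 = 3,\quad D_2 = 0 \\
\text{answer} &= W_1 + W_2 + M_1 \cdot M_2 \cdot A_1 + D_1 \cdot M_2 + D_2 \cdot M_1 \\
&= 3 + 0 + 2 \cdot 1 \cdot 5 + 3 \cdot 1 + 0 \cdot 2 = 16
\end{aligned}
```

This agrees with a direct count: the combined tree is a path with edge weights 3 and 5, whose pairwise distances are 3, 5 and 8.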

Subtask 2

Now, we generalize this solution further. Let S = M_1 + M_2 + \ldots + M_N be the total number of vertices.

Let us compute the values D_1, D_2, \ldots, D_N for the N trees, where D_i is the minimum, over all vertices of the i-th tree, of the sum of all path lengths starting at that vertex.

It can be seen that the existing edges of the i-th tree contribute D_i \cdot (S - M_i) to the answer.

Proof

Consider some tree T. Suppose it has k external edges connected to it, to vertices u_1, u_2, \ldots, u_k. Let the total number of vertices on the other side of the i-th edge be s_i.

Let w_i be the sum of path lengths in T starting from vertex u_i. Then, the contribution of this configuration is

w_1\cdot s_1 + w_2\cdot s_2 + \ldots + w_k\cdot s_k

Clearly this is at least as large as D \cdot (s_1 + \ldots + s_k), where D is the minimum of w over all vertices of T, since every w_i \geq D.

s_1 + \ldots + s_k is just the sum of number of vertices in all trees other than T, and so is a constant. Achieving D \cdot (s_1 + \ldots + s_k) is possible by connecting everything to a single vertex with that value of D.

All that remains is to compute the minimum contribution of the newly added edges A_1, A_2, \ldots, A_{N-1}.

To do this, let’s simplify the problem a little. Compress tree T_i into a single vertex, with a weight of M_i (i.e, its weight is its number of vertices). Now, adding the new N-1 edges is the same as creating a tree out of this graph, such that the contribution of an edge with weight A_i is as follows:

Let the sum of vertex weights on one side of this edge be x. Then, the sum of vertex weights on the other side is S - x, and the contribution of this edge is A_i \cdot x \cdot (S - x).

The question now is how to create a tree that minimizes the sum of this value across all edges.

It turns out that the optimal way to do this is to create a ‘star’ graph, i.e, a tree with one ‘central’ vertex and every other vertex directly connected to this vertex.

Proof

This property can be proved by noticing the following fact: a tree is a star if and only if every path in the tree has length at most 2 edges.

Let’s take the example of four vertices. Suppose they were connected as

a \xleftrightarrow{w_1} b \xleftrightarrow{w_2} c \xleftrightarrow{w_3} d

where a, b, c, d are the vertex weights and w_1, w_2, w_3 are the edge weights.

The contribution in this setup is the sum of:

  • w_1 \cdot a \cdot (b + c + d)
  • w_2 \cdot (a + b) \cdot (c + d)
  • w_3 \cdot (a + b + c) \cdot d

Now suppose we remove the c \leftrightarrow d edge and make it a b \leftrightarrow d edge instead. The contribution is the sum of

  • w_1 \cdot a \cdot (b + c + d)
  • w_2 \cdot c \cdot (a + b + d)
  • w_3 \cdot d \cdot (a + b + c)

The first and third terms are the same, so the difference is only in the second term. Cancelling out terms, you can see that the first sum is strictly smaller if and only if a + b \lt c.
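Written out, the cancellation in the second term is as follows (path configuration minus the configuration with d attached to b):

```latex
\begin{aligned}
w_2 \cdot (a + b) \cdot (c + d) - w_2 \cdot c \cdot (a + b + d)
&= w_2 \cdot \bigl( (a + b)c + (a + b)d - (a + b)c - cd \bigr) \\
&= w_2 \cdot d \cdot (a + b - c)
\end{aligned}
```

Since w_2 and d are positive, the sign of the difference is the sign of a + b - c.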

So when a + b \geq c, moving this edge to form a star gives an answer that is no worse.

Now, consider the case when a + b \lt c. Let’s instead move the a \leftrightarrow b edge to form a star centered at c.

A similar calculation shows that the non-star graph has a lower weight if and only if c+d \lt b.

However, since all values are non-negative, it is not possible to have both c+d \lt b and a+b \lt c, because that tells us b \lt c and c \lt b: a contradiction.

So, in the case of 4 vertices, a non-star graph is never optimal.
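The four-vertex claim is easy to sanity-check numerically. The helper below is a hypothetical illustration (the names `Costs` and `fourVertexCosts` are not part of any solution): it evaluates the path a - b - c - d and the two stars obtained by the edge moves described above, so that one can confirm that at least one star is never worse than the path.

```cpp
#include <bits/stdc++.h>
using namespace std;

// Total contribution of the three edges in each of the three configurations.
struct Costs {
    long long path;     // a - b - c - d
    long long starAtB;  // edge c - d replaced by b - d
    long long starAtC;  // edge a - b replaced by a - c
};

Costs fourVertexCosts(long long a, long long b, long long c, long long d,
                      long long w1, long long w2, long long w3) {
    long long s = a + b + c + d;
    // The edges of weight w1 and w3 always cut off exactly a and d respectively.
    long long common = w1 * a * (s - a) + w3 * d * (s - d);
    return {
        common + w2 * (a + b) * (c + d),  // middle edge of the path
        common + w2 * c * (s - c),        // c is a leaf of the star centred at b
        common + w2 * b * (s - b)         // b is a leaf of the star centred at c
    };
}
```

For example, with a = b = d = 1, c = 5 and all edge weights 1, the costs are 26 for the path, 29 for the star at b and 21 for the star at c: here a + b \lt c, so only the star centred at c improves on the path.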

For more than 4 vertices, this proof can be mimicked by choosing a diameter of the tree, and moving one of its endpoints closer: at least one of these movements will always be possible without making the answer worse.

Note that each time this operation of edge-shifting is performed,

  • If there is exactly one diameter, the length of the diameter decreases by 1
  • otherwise, the number of diameters decreases

So, it is always possible to reach a star tree within a finite number of moves, hence completing the proof.

Once this property has been noted, calculating the minimum is simple. Note that we essentially want to match the N-1 edges to some N-1 vertices. If an edge of weight A_i is matched to a vertex of weight M_j, the contribution of such a matching is A_i \cdot M_j \cdot (S - M_j).

Note that the M_j \cdot (S - M_j) term depends only on M_j and is independent of A_i. So, if we create an array B such that B_j = M_j \cdot (S - M_j), we would like to minimize the sum of pairwise product of elements of A and B.

By the rearrangement inequality, solving this is simple:

  • Take the smallest N-1 elements of B, and sort them so that B_1 \leq B_2 \leq \ldots \leq B_{N-1}
  • Sort A so that A_1 \leq A_2 \leq \ldots \leq A_{N-1}
  • The minimum total contribution of the new edges is then A_1B_{N-1} + A_2B_{N-2} + \ldots + A_{N-1}B_1
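The matching step can be sketched as below. The function name `minNewEdgeCost` and its signature are assumptions for illustration; it expects exactly one fewer edge weight than there are trees, and it skips the modular reduction that a full solution would need.

```cpp
#include <bits/stdc++.h>
using namespace std;

// Minimum total contribution of the N-1 new edges, assuming the compressed tree is a star.
// a: the N-1 new edge weights; treeSizes: M_1..M_N. No modulus is applied in this sketch.
long long minNewEdgeCost(vector<long long> a, const vector<long long>& treeSizes) {
    long long S = accumulate(treeSizes.begin(), treeSizes.end(), 0LL);
    vector<long long> b;
    for (long long m : treeSizes) b.push_back(m * (S - m));  // B_j = M_j * (S - M_j)
    sort(b.begin(), b.end());
    b.pop_back();                  // the tree with the largest B_j becomes the centre
    sort(a.rbegin(), a.rend());    // largest edge weight first
    long long cost = 0;
    // Rearrangement inequality: pair the largest A with the smallest B.
    for (size_t i = 0; i < a.size(); ++i) cost += a[i] * b[i];
    return cost;
}
```

With tree sizes 1, 2, 3 and edge weights 1 and 10, B is 5, 8, 9; the value 9 is dropped, and the result is 10 \cdot 5 + 1 \cdot 8 = 58 (the other pairing would give 85).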

TIME COMPLEXITY

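\mathcal{O}(S + N\log N) per test case, where S is the total number of vertices: the dfs passes are linear in S, and sorting the N-1 edge weights and the N values of B takes \mathcal{O}(N\log N).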

CODE:

Setter's Code (C++)
#include <bits/stdc++.h>
using namespace std;
#define ll long long
ll max(ll l, ll r){ if(l > r) return l ; return r;}
ll min(ll l, ll r){ if(l < r) return l ; return r;}

 
 
/*
------------------------Input Checker----------------------------------
*/
 
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
 
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }
 
            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }
 
            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
 
 
/*
------------------------Main code starts here----------------------------------
*/
 
const int MAX_T = 1;
const int MAX_N = 100;
const int SUM_N = 300000;
const int MAX_VAL = 100; 
const int SUM_VAL = 20005 ;
const int OFFSET = 10000 ;

#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define ff first
#define ss second
#define mp make_pair
#define ll long long
#define rep(i,n) for(int i=0;i<n;i++)
#define rev(i,n) for(int i=n;i>=0;i--)
#define rep_a(i,a,n) for(int i=a;i<n;i++)
#define pb push_back
#define int ll

ll sum_n = 0, sum_m = 0, sum_nm = 0;
int max_n = 0, max_m = 0;
int yess = 0;
int nos = 0;
int total_ops = 0;
ll mod = 998244353;

using ii = pair<ll,ll>;

vector<vector<vector<pair<int,ll> > > > adj;
vector<ll> sub_sz;

ll comp_wt(int c, int p, int t, ll &wi, ll &di, ll e_wt, int n){
    ll tmp = 1;
    for(auto h:adj[t][c]){
        if(h.ff!=p){
            tmp += comp_wt(h.ff, c, t, wi, di, h.ss, n);
        }
    }

    sub_sz[c] = tmp;

    wi += (tmp*(n-tmp)*e_wt)%mod;
    wi %= mod;
    di += tmp*e_wt;
    return tmp;
}

void comp_mn(int c, int p, int t, ll curr, ll &di, int n){
    for(auto h:adj[t][c]){
        if(h.ff!=p){
            ll tmp = curr+(n-2*sub_sz[h.ff])*h.ss;
            di = min(di, tmp);
            comp_mn(h.ff, c, t, tmp, di, n);
        }
    }
}


void solve()
{   
    int n = readIntLn(2,10000);

    adj.resize(n);
    vector<pair<ll,pair<ll,ll> > > z;
    ll tot = 0;
    rep(i,n){
        int m = readIntLn(1,5e4);
        sum_n += m;
        adj[i].assign(m, vector<pair<int,ll> >());

        rep(j, m-1){
            int x = readIntSp(1,m);
            int y = readIntSp(1,m);
            ll w = readIntLn(1,1e8);
            x--, y--;

            adj[i][x].pb(mp(y,w));
            adj[i][y].pb(mp(x,w));
        }

        tot+=m;

        ll wi = 0, di = 0;
        sub_sz.assign(m,0);
        ll cnt = comp_wt(0,-1,i,wi,di,0,m); // computes the sum of distances over all pairs of nodes within this tree
        assert(cnt==m); // the call is kept outside assert so it still runs when NDEBUG is defined

        comp_mn(0,-1,i,di,di,m); // computes the minimum sum of distances from a single node
        di%=mod;

        z.pb(mp(m, mp(wi,di)));
    }


    ll a[n-1];
    rep(i,n-1){
        if(i<n-2) a[i] = readIntSp(1,1e8);
        else a[i] = readIntLn(1,1e8);
    }

    sort(a,a+n-1);
    ll ans = 0;

    rep(i,n){
        ans += z[i].ss.ff;
        ans += (z[i].ss.ss*(tot-z[i].ff))%mod;
        ans %= mod;
        z[i].ff = (z[i].ff*(tot-z[i].ff));
    }

     sort(z.begin(), z.end());
     rep(i,n-1){
        z[i].ff%=mod;
        ans += (z[i].ff*a[n-2-i])%mod;
        ans %= mod;
     }

     cout<<ans<<'\n';
}
 
signed main()
{
    fast;
    #ifndef ONLINE_JUDGE
    freopen("input.txt" , "r" , stdin) ;
    freopen("output.txt" , "w" , stdout) ;
    #endif
    
    int t = 1;

    for(int i=1;i<=t;i++)
    {    
        solve() ;
    }
    
    assert(getchar() == -1);
    assert(sum_n<=1e5);
 
    cerr<<"SUCCESS\n";
}
Tester's Code (C++)
/*
   - Check file formatting
   - Assert every constraint
   - Analyze testdata
*/

#include <bits/stdc++.h>
using namespace std;

/*
---------Input Checker(ref : https://pastebin.com/Vk8tczPu )-----------
*/

long long readInt(long long l, long long r, char endd)
{
    long long x = 0;
    int cnt = 0;
    int fi = -1;
    bool is_neg = false;
    while (true)
    {
        char g = getchar();
        if (g == '-')
        {
            assert(fi == -1);
            is_neg = true;
            continue;
        }
        if ('0' <= g && g <= '9')
        {
            x *= 10;
            x += g - '0';
            if (cnt == 0)
            {
                fi = g - '0';
            }
            cnt++;
            assert(fi != 0 || cnt == 1);
            assert(fi != 0 || is_neg == false);

            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
        }
        else if (g == endd)
        {
            if (is_neg)
            {
                x = -x;
            }

            if (!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }

            return x;
        }
        else
        {
            assert(false);
        }
    }
}
string readString(int l, int r, char endd)
{
    string ret = "";
    int cnt = 0;
    while (true)
    {
        char g = getchar();
        assert(g != -1);
        if (g == endd)
        {
            break;
        }
        cnt++;
        ret += g;
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}
long long readIntSp(long long l, long long r)
{
    return readInt(l, r, ' ');
}
long long readIntLn(long long l, long long r)
{
    return readInt(l, r, '\n');
}

/*
-------------Main code starts here------------------------
*/

// Note here all the constants from constraints
const int MAX_N = 1e4;
const int MAX_M = 5e4;
const int MAX_A = 1e8;
const int SUM_M = 1e5;
const int MOD = 998244353;

// Variables to measure some parameters on test-data
long long sum_m = 0;
long long max_m = 0;

// Vertices are 1-indexed, so per-vertex arrays need MAX_M + 1 slots
vector<pair<long long, long long>> g[MAX_M + 1];

long long A[MAX_N];
long long M[MAX_N];
long long best_node_val[MAX_N];
long long sz[MAX_M + 1];
long long dp[MAX_M + 1];
long long dp2[MAX_M + 1];

long long res = 0;
int cur_id;

void dfs(int node, int par = 0)
{
    sz[node] = 1;

    for (auto x : g[node])
    {
        if (x.first != par)
        {
            dfs(x.first, node);
            sz[node] += sz[x.first];

            dp[node] += dp[x.first];
            dp[node] += x.second * sz[x.first];

            dp2[node] += dp2[x.first];
            dp2[node] += x.second;
        }
    }

    for (auto x : g[node])
    {
        if (x.first != par)
        {
            res += x.second * sz[x.first] * (M[cur_id] - sz[x.first]);
            res %= MOD;
        }
    }
}

void dfs2(int node, int par = 0)
{
    best_node_val[cur_id] = min(best_node_val[cur_id], dp[node]);

    for (auto x : g[node])
    {
        if (x.first != par)
        {
            dp[x.first] = dp[node] + x.second * (M[cur_id] - sz[x.first] - sz[x.first]);
            dfs2(x.first, node);
        }
    }
}

void solve()
{
    int n;
    n = readIntLn(2, MAX_N);

    cerr << "N : " << n << '\n';

    for (int i = 0; i < n; i++)
    {
        M[i] = readIntLn(1, MAX_M);

        sum_m += M[i];
        max_m = max(max_m, M[i]);

        assert(sum_m <= SUM_M);

        for (int j = 0; j < M[i] - 1; j++)
        {
            int u, v, w;

            u = readIntSp(1, M[i]);
            v = readIntSp(1, M[i]);
            w = readIntLn(1, MAX_A);

            g[u].push_back({v, w});
            g[v].push_back({u, w});
        }

        cur_id = i;
        best_node_val[i] = 1e18;

        dfs(1);
        dfs2(1);

        for (int j = 1; j <= M[i]; j++)
        {
            sz[j] = 0;
            dp[j] = 0;
            dp2[j] = 0;
            g[j].clear();
        }
    }

    for (int i = 0; i < n; i++)
    {
        best_node_val[i] %= MOD;
        res += best_node_val[i] * (sum_m - M[i]);
    }

    sort(M, M + n);

    for (int i = 0; i < n - 1; i++)
    {
        if (i != n - 2)
        {
            A[i] = readIntSp(1, MAX_A);
        }
        else
        {
            A[i] = readIntLn(1, MAX_A);
        }
    }

    sort(A, A + n - 1);
    reverse(A, A + n - 1);

    for (int i = 0; i < n - 1; i++)
    {
        res += A[i] * M[i] * (sum_m - M[i]);
    }

    res %= MOD;

    cout << res;
}

signed main()
{
    solve();

    // Make sure there are no extra characters at the end of input
    assert(getchar() == -1);
    cerr << "SUCCESS\n";
    cerr << "Sum M : " << sum_m << '\n';
    cerr << "MAX M : " << max_m << '\n';

    // Some important parameters which can help identify weakness in testdata
}
Editorialist's Code (C++)
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

/**
 * Integers modulo p, where p is a prime
 * Source: Aeren (modified from tourist?)
 *         Modmul for 64-bit mod from kactl:ModMulLL
 * Works with p < 7.2e18 with x87 80-bit long double, and p < 2^52 ~ 4.5e12 with 64-bit
 */
template<typename T>
struct Z_p{
	using Type = typename decay<decltype(T::value)>::type;
	static vector<Type> MOD_INV;
	constexpr Z_p(): value(){ }
	template<typename U> Z_p(const U &x){ value = normalize(x); }
	template<typename U> static Type normalize(const U &x){
		Type v;
		if(-mod() <= x && x < mod()) v = static_cast<Type>(x);
		else v = static_cast<Type>(x % mod());
		if(v < 0) v += mod();
		return v;
	}
	const Type& operator()() const{ return value; }
	template<typename U> explicit operator U() const{ return static_cast<U>(value); }
	constexpr static Type mod(){ return T::value; }
	Z_p &operator+=(const Z_p &otr){ if((value += otr.value) >= mod()) value -= mod(); return *this; }
	Z_p &operator-=(const Z_p &otr){ if((value -= otr.value) < 0) value += mod(); return *this; }
	template<typename U> Z_p &operator+=(const U &otr){ return *this += Z_p(otr); }
	template<typename U> Z_p &operator-=(const U &otr){ return *this -= Z_p(otr); }
	Z_p &operator++(){ return *this += 1; }
	Z_p &operator--(){ return *this -= 1; }
	Z_p operator++(int){ Z_p result(*this); *this += 1; return result; }
	Z_p operator--(int){ Z_p result(*this); *this -= 1; return result; }
	Z_p operator-() const{ return Z_p(-value); }
	template<typename U = T>
	typename enable_if<is_same<typename Z_p<U>::Type, int>::value, Z_p>::type &operator*=(const Z_p& rhs){
		#ifdef _WIN32
		uint64_t x = static_cast<int64_t>(value) * static_cast<int64_t>(rhs.value);
		uint32_t xh = static_cast<uint32_t>(x >> 32), xl = static_cast<uint32_t>(x), d, m;
		asm(
			"divl %4; \n\t"
			: "=a" (d), "=d" (m)
			: "d" (xh), "a" (xl), "r" (mod())
		);
		value = m;
		#else
		value = normalize(static_cast<int64_t>(value) * static_cast<int64_t>(rhs.value));
		#endif
		return *this;
	}
	template<typename U = T>
	typename enable_if<is_same<typename Z_p<U>::Type, int64_t>::value, Z_p>::type &operator*=(const Z_p &rhs){
		uint64_t ret = static_cast<uint64_t>(value) * static_cast<uint64_t>(rhs.value) - static_cast<uint64_t>(mod()) * static_cast<uint64_t>(1.L / static_cast<uint64_t>(mod()) * static_cast<uint64_t>(value) * static_cast<uint64_t>(rhs.value));
		value = normalize(static_cast<int64_t>(ret + static_cast<uint64_t>(mod()) * (ret < 0) - static_cast<uint64_t>(mod()) * (ret >= static_cast<uint64_t>(mod()))));
		return *this;
	}
	template<typename U = T>
	typename enable_if<!is_integral<typename Z_p<U>::Type>::value, Z_p>::type &operator*=(const Z_p &rhs){
		value = normalize(value * rhs.value);
		return *this;
	}
	template<typename U>
	Z_p &operator^=(U e){
		if(e < 0) *this = 1 / *this, e = -e;
		Z_p res = 1;
		for(; e; *this *= *this, e >>= 1) if(e & 1) res *= *this;
		return *this = res;
	}
	template<typename U>
	Z_p operator^(U e) const{
		return Z_p(*this) ^= e;
	}
	Z_p &operator/=(const Z_p &otr){
		Type a = otr.value, m = mod(), u = 0, v = 1;
		if(a < (int)MOD_INV.size()) return *this *= MOD_INV[a];
		while(a){
			Type t = m / a;
			m -= t * a; swap(a, m);
			u -= t * v; swap(u, v);
		}
		assert(m == 1);
		return *this *= u;
	}
	template<typename U> friend const Z_p<U> &abs(const Z_p<U> &v){ return v; }
	Type value;
};
template<typename T> bool operator==(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value == rhs.value; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator==(const Z_p<T>& lhs, U rhs){ return lhs == Z_p<T>(rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator==(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) == rhs; }
template<typename T> bool operator!=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return !(lhs == rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator!=(const Z_p<T> &lhs, U rhs){ return !(lhs == rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator!=(U lhs, const Z_p<T> &rhs){ return !(lhs == rhs); }
template<typename T> bool operator<(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value < rhs.value; }
template<typename T> bool operator>(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value > rhs.value; }
template<typename T> bool operator<=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value <= rhs.value; }
template<typename T> bool operator>=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value >= rhs.value; }
template<typename T> Z_p<T> operator+(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator+(const Z_p<T> &lhs, U rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator+(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T> Z_p<T> operator-(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator-(const Z_p<T>& lhs, U rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator-(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T> Z_p<T> operator*(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator*(const Z_p<T>& lhs, U rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator*(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T> Z_p<T> operator/(const Z_p<T> &lhs, const Z_p<T> &rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator/(const Z_p<T>& lhs, U rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator/(U lhs, const Z_p<T> &rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T> istream &operator>>(istream &in, Z_p<T> &number){
	typename common_type<typename Z_p<T>::Type, int64_t>::type x;
	in >> x;
	number.value = Z_p<T>::normalize(x);
	return in;
}
template<typename T> ostream &operator<<(ostream &out, const Z_p<T> &number){ return out << number(); }

/*
using ModType = int;
struct VarMod{ static ModType value; };
ModType VarMod::value;
ModType &mod = VarMod::value;
using Zp = Z_p<VarMod>;
*/

// constexpr int mod = 1e9 + 7; // 1000000007
constexpr int mod = (119 << 23) + 1; // 998244353
// constexpr int mod = 1e9 + 9; // 1000000009
using Zp = Z_p<integral_constant<decay<decltype(mod)>::type, mod>>;

template<typename T> vector<typename Z_p<T>::Type> Z_p<T>::MOD_INV;
template<typename T = integral_constant<decay<decltype(mod)>::type, mod>>
void precalc_inverse(int SZ){
	auto &inv = Z_p<T>::MOD_INV;
	if(inv.empty()) inv.assign(2, 1);
	for(; inv.size() <= SZ; ) inv.push_back((mod - 1LL * mod / (int)inv.size() * inv[mod % (int)inv.size()]) % mod);
}

template<typename T>
vector<T> precalc_power(T base, int SZ){
	vector<T> res(SZ + 1, 1);
	for(auto i = 1; i <= SZ; ++ i) res[i] = res[i - 1] * base;
	return res;
}

template<typename T>
vector<T> precalc_factorial(int SZ){
	vector<T> res(SZ + 1, 1); res[0] = 1;
	for(auto i = 1; i <= SZ; ++ i) res[i] = res[i - 1] * i;
	return res;
}

int main()
{
	ios::sync_with_stdio(false); cin.tie(0);

	int n; cin >> n;
	Zp ans = 0;
	vector<array<ll, 2>> vals;
	Zp m = 0;
	for (int i = 0; i < n; ++i) {
		int s; cin >> s;
		m += s;
		vector<vector<array<int, 2>>> adj(s);
		for (int j = 0; j < s-1; ++j) {
			int u, v, w; cin >> u >> v >> w;
			adj[--u].push_back({--v, w});
			adj[v].push_back({u, w});
		}

		vector<int> subsz(s);
		auto dfs = [&] (const auto &self, int u, int p) -> ll {
			subsz[u] = 1;
			ll ret = 0;
			for (auto [v, w] : adj[u]) {
				if (v == p) continue;
				auto res = self(self, v, u);
				subsz[u] += subsz[v];
				ans += Zp(subsz[v])*(s-subsz[v])*w;
				ret += res + 1LL*w*subsz[v];
			}
			return ret;
		};
		auto reroot = [&] (const auto &self, int u, int p, ll cur) -> ll {
			ll ret = cur;
			for (auto [v, w] : adj[u]) {
				if (v == p) continue;
				ll res = self(self, v, u, cur - 1LL*w*subsz[v] + 1LL*w*(s-subsz[v]));
				ret = min(ret, res);
			}
			return ret;
		};
		ll mn = dfs(dfs, 0, 0);
		mn = min(mn, reroot(reroot, 0, 0, mn));
		vals.push_back({s, mn});
	}
	vector<ll> vals2;
	for (auto [s, mn] : vals) {
		ans += (m-s)*mn;
		vals2.push_back(1LL*s*(m-s).value);
	}
	sort(begin(vals2), end(vals2));
	vector<int> a(n-1);
	for (int &x : a) cin >> x;
	sort(rbegin(a), rend(a));
	for (int i = 0; i < n-1; ++i) ans += Zp(a[i])*vals2[i];

	cout << ans << '\n';
}

When N > 2, you don’t seem to show why it is optimal to attach the edges to the same vertex (the one that minimizes the sum for each tree). You showed the ‘condensed’ tree should be a star. Why could the edges not be attached to different vertices from the center tree in the star? I see that in this case the distance between two leaves of the condensed tree may be larger, but there may be shorter distances involving the vertices from the center tree, so it doesn’t seem obvious to me.

Ah, you’re right, I indeed did not prove that part.

However, the idea follows almost immediately from what was done for N = 2, and in fact does not even depend on the fact that the compressed tree is a star.

Proof

Consider some tree T. Suppose it has k external edges connected to it, to vertices u_1, u_2, \ldots, u_k. Let the total number of vertices on the other side of the i-th edge be s_i.

Let w_i be the sum of path lengths in T starting from vertex u_i. Then, the contribution of this configuration is

w_1\cdot s_1 + w_2\cdot s_2 + \ldots + w_k\cdot s_k

Clearly this is at least as large as D \cdot (s_1 + \ldots + s_k), where D is the minimum of w over all vertices of T, since every w_i \geq D.

s_1 + \ldots + s_k is just the sum of number of vertices in all trees other than T, and so is a constant. Achieving D \cdot (s_1 + \ldots + s_k) is possible by connecting everything to a single vertex with that value of D.

I’ll add this to the editorial.


Let the sum of vertex weights on one side of this edge be x. Then, the sum of vertex weights on the other side is S−x, and the contribution of this edge is A_i⋅x⋅(A_i - x).

Shouldn’t the contribution of the edge be A_i⋅x⋅(S - x)? Please correct me if I’m wrong.

Correct, thanks for pointing it out. I’ve fixed it.