TREESFUN343 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: satyam_343
Tester: tabr
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

DFS, Small-to-large merging

PROBLEM:

You have a tree on N vertices. Vertex i has value A_i.
The following operation can be performed:

  • Pick two vertices u and v such that d(u, v) = K (where d(u, v) is the number of edges on the path between u and v), pick an integer x, and add x to both A_u and A_v.

It is guaranteed that K is odd.
Find the minimum possible value of \sum_{i=1}^N |A_i| after performing any number of operations.

EXPLANATION:

We’ll start out with a few easier variants to get intuition, and work our way up to the full solution.
Subtasks, if you will.

Case 1: The tree is a line (so we’re just working with an array A), and K = 1, meaning we operate on adjacent elements.
This version is not all that hard to solve.

Solution

Let S_1 = A_1 + A_3 + A_5 + \ldots denote the sum of values at odd indices, and S_2 = A_2 + A_4 + \ldots denote the sum of values at even indices.
Notice that each operation adds x to both S_1 and S_2, since two adjacent indices always consist of one odd index and one even index.

In particular, S_1 - S_2 stays the same no matter what operations are performed; let D = |S_1 - S_2|.
By the triangle inequality, \sum_{i=1}^N |A_i| \geq |S_1 - S_2| = D, so D is a lower bound. It’s not hard to attain exactly D either: for i = 1, 2, 3, \ldots, N-1 in order, apply the operation to (i, i+1) with x = -A_i, which sets A_i to zero. At that point only A_N can be nonzero, and it must be either D or -D.


Case 2: The tree is a line, arbitrary K.
This follows fairly easily from the previous case.

Solution

For each i = 1, 2, 3, \ldots, K, let B^{(i)} = [A_i, A_{i+K}, A_{i+2K}, \ldots] denote the slice of A starting at i and with separation K.
Each operation on A now corresponds to an operation on adjacent elements of exactly one B^{(i)}, and operations on different slices don’t affect each other.

So, simply solve for each array B^{(i)} separately using the solution from case 1, and add up the answers.


Case 3: K = 1, but the tree can be anything.
This case also follows fairly directly from the first.

Solution

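Rather than even and odd indices, we now look at depths: each move affects one vertex at even depth and one at odd depth, since adjacent vertices have depths differing by exactly 1.
Once again, if we let S_1 denote the sum of values of odd-depth vertices and S_2 the sum of values of even-depth vertices, the answer is just |S_1 - S_2|. The lower bound works exactly as in case 1. Attainability is covered by the proof at the bottom, since with K = 1 the whole tree forms a single group.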

Note that separating vertices of a tree by even/odd depth gives us a bipartite coloring of the tree.


We’re now ready to tackle the original problem: an arbitrary tree, and an arbitrary odd K.
Let’s use the ideas from cases 2 and 3 above:

  • We’ll partition the vertices into several “groups”, where u and v belong to the same group if d(u, v) = K, extended transitively.
    You can think of this as creating a new graph on N vertices with an edge between u and v iff d(u, v) = K, and then looking at connected components of this new graph.
  • Since K is odd, each move still acts on two vertices on different sides of the tree’s bipartite coloring: d(u, v) = depth(u) + depth(v) - 2\cdot depth(lca(u, v)) has the same parity as depth(u) + depth(v), so u and v have depths of different parity.
    Hence, within each group, S_1 - S_2 (odd-depth sum minus even-depth sum) never changes.
  • So, for each group, find S_1 and S_2 as per the bipartition, and add |S_1 - S_2| to the answer. ^\dagger
    The lower bound follows from the triangle inequality exactly as in case 1, but attaining it needs a bit of care: we’re not on an array, so the simple left-to-right greedy no longer works.
    A proof is included at the bottom.

All we need to do now is figure out how to partition the vertices into their groups.

One way is to use small-to-large merging.
Let’s perform a DFS on the tree. Suppose we’re at vertex u, and its children are v_1, v_2, \ldots, v_m.
Let L_{u, d} be the list of vertices in the subtree of u that are at depth d below u (so L_{u, 0} contains only u, and L_{u, 1} contains only the children of u).

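For each v_i in turn, we’ll merge the lists of v_i into those of u, with L_{v_i, d} becoming part of L_{u, d+1}.
Before doing so, note that some new pairs of vertices at distance K form.
Specifically, for each 0 \leq d \lt K, each vertex in L_{v_i, d} wants to be in the same component as each vertex in (the current) L_{u, K-1-d}: the path between two such vertices passes through u, so their distance is (d+1) + (K-1-d) = K.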

Now, small-to-large merging lets us maintain the L_u lists quickly enough; the only issue is that uniting every such pair directly may take \mathcal{O}(N^2) DSU unions.

However, this can be optimized by noting that we don’t actually need to perform every such union.
In particular, suppose the vertices of L_{v_i, d} are to be united with the vertices of L_{u, K-1-d}.
Then,

  • Rather than doing |L_{v_i, d}|\times |L_{u, K-1-d}| unions, it suffices to do |L_{v_i, d}| + |L_{u, K-1-d}| of them: first unite all the vertices within each list, then do a single union between the two lists.
  • Further, after this, all but one element of L_{u, K-1-d} can be removed, since they’re all in the same DSU set anyway.

This simple-looking optimization brings the number of unions down to \mathcal{O}(N\log N).
To see why:

  • The number of “cross-set” unions is \mathcal{O}(N\log N), because of the small-to-large part (there’s at most one such union for each element of the smaller set).
  • As for “within-set” unions, the ones in the smaller set also total \mathcal{O}(N\log N) since they’re bounded by the size of the smaller set.
  • The larger set’s unions can be counted via deletions, since each of them is paired with removing one element from the larger set.
    For each vertex u, draw an edge connecting it to the next vertex in DFS order with the same depth as it (imagine this as a linked list).
    Then, each ‘deletion’ from the larger set corresponds to deleting one node from this linked list, which can of course happen at most \mathcal{O}(N) times.

A nice way to implement this is to make L_u a deque of lists.
This allows for all distances to be “shifted” by 1 in constant time (with just a push_front operation).


^\dagger Here’s a proof of the claim that |S_1 - S_2| is always attainable within a group.

Proof

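As noted, we have a bipartition based on depth, and every move acts on one vertex from each side.
First, all vertices of one side of the bipartition can be made to have value 0. For each such vertex w, pick any vertex of the group at distance K from w and apply the operation with x = -A_w. Such a vertex exists unless the group is a single vertex, in which case the claim is trivial. This zeroes A_w while only changing one vertex on the other side, so vertices zeroed earlier stay at 0.

Now, observe that if we have vertices u and v on the same side of the bipartition, it’s always possible to do A_u \gets A_u + x and A_v \gets A_v - x for any integer x. The group is connected, so there is a path from u to v made of distance-K pairs. It has even length because every step switches sides, and we can alternate +x and -x along that path (so the values of intermediate vertices don’t change).

This way, it’s possible to set all but one value within the group to 0. Since S_1 - S_2 never changes, the final remaining value is \pm(S_1 - S_2), so the group contributes exactly |S_1 - S_2|, matching the lower bound.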

TIME COMPLEXITY:

\mathcal{O}(N\log N) per testcase.

CODE:

Author's code (C++)
//#pragma GCC optimize("O3,unroll-loops")
#include <bits/stdc++.h>   
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;   
using namespace std;
#define ll long long
const ll INF_ADD=1e18;
#define pb push_back                  
#define mp make_pair          
#define nline "\n"                            
#define f first                                            
#define s second                                             
#define pll pair<ll,ll> 
#define all(x) x.begin(),x.end()      
#define vl vector<ll>           
#define vvl vector<vector<ll>>      
#define vvvl vector<vector<vector<ll>>>          
#ifndef ONLINE_JUDGE    
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);  
#endif     
void _print(ll x){cerr<<x;}  
void _print(char x){cerr<<x;}   
void _print(string x){cerr<<x;}    
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());   
template<class T,class V> void _print(pair<T,V> p) {cerr<<"{"; _print(p.first);cerr<<","; _print(p.second);cerr<<"}";}
template<class T>void _print(vector<T> v) {cerr<<" [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T>void _print(set<T> v) {cerr<<" [ "; for (T i:v){_print(i); cerr<<" ";}cerr<<"]";}
template<class T>void _print(multiset<T> v) {cerr<< " [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T,class V>void _print(map<T, V> v) {cerr<<" [ "; for(auto i:v) {_print(i);cerr<<" ";} cerr<<"]";} 
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
//--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------  
const ll MOD=998244353;
const ll MAX=500500;
vector<ll> adj[MAX];
ll n,k;
struct dsu{
    vector<ll> parent,height;
    ll n,len;
    dsu(ll n){
        this->n=n;
        parent.resize(n);
        height.resize(n);
        len=n;
        for(ll i=0;i<n;i++){
            parent[i]=i;
            height[i]=1;
        }
    }
    ll find_set(ll x){   
        return find_set(x,x); 
    } 
    ll find_set(ll x,ll orig){
        if(parent[x]==x){
            return x;   
        }
        parent[orig]=find_set(parent[x]);
        return parent[orig]; 
    }
    void union_set(ll u,ll v){
        //debug(mp(u,v));
        u=find_set(u),v=find_set(v);
        if(u==v){
            return;
        }
        len--; 
        if(height[u]<height[v]){
            swap(u,v); 
        }
        parent[v]=u;
        height[u]+=height[v]; 
    }
    ll getv(ll l){
        l=find_set(l);
        return height[l]; 
    }
}; 
map<ll,vector<ll>> track[MAX];
vector<ll> depth(MAX,0),use(MAX,0);
vector<ll> total_size(MAX,0);
void dfs(ll cur,ll par, dsu &groups){
    track[cur].clear();
    use[cur]=depth[cur]%k;
    //debug(mp(cur,use[cur]));
    track[cur][use[cur]].push_back(cur);
    total_size[cur]=1;
    for(auto chld:adj[cur]){
        if(chld!=par){
            depth[chld]=depth[cur]+1;
            dfs(chld,cur,groups);  
            //debug(cur);
            if(track[cur].size()<track[chld].size()){
                swap(track[chld],track[cur]);
                swap(total_size[chld],total_size[cur]);
            }
            for(auto &it:track[chld]){
                ll need=(k-(it.f-use[cur]))%k;
                need=(use[cur]+need)%k;
                //debug(mp(need,it.f));
                //assert(!it.s.empty());
                if(track[cur].find(need)==track[cur].end()){
                    continue;
                } 
                ll l=track[cur][need][0];
                auto &vec=it.s; 
                ll r=vec[0];
                groups.union_set(l,r);
                while(track[cur][need].size()!=1){
                    total_size[cur]--;
                    groups.union_set(track[cur][need].back(),r);
                    track[cur][need].pop_back();
                }
                while(vec.size()!=1){
                    total_size[chld]--;
                    groups.union_set(l,vec.back());
                    vec.pop_back();
                }
                //assert(it.s.size()==1);
            }
            if(total_size[cur]<total_size[chld]){
                swap(track[chld],track[cur]);
                swap(total_size[chld],total_size[cur]);
            }
            for(auto it:track[chld]){
                for(auto d:it.s){
                    track[cur][it.f].push_back(d);
                }
            }
        }
    }
}
void solve(){        
    cin>>n>>k; 
    vector<ll> a(n+5);
    for(ll i=1;i<=n;i++){
        cin>>a[i]; 
    } 
    for(ll i=1;i<n;i++){ 
        ll u,v; cin>>u>>v; 
        adj[u].push_back(v);
        adj[v].push_back(u);
    }  
    dsu groups(n+5);
    dfs(1,-1,groups);  
    vector<ll> check(n+5,0);
    ll freq=0;    
    for(ll i=1;i<=n;i++){
        if(groups.find_set(i)==i){
            freq++;  
        }
        // debug(i);
        adj[i].clear();
        ll p=groups.find_set(i);
        if(depth[i]&1){
            check[p]+=a[i]; 
        }
        else{
            check[p]-=a[i];
        }
    }
    debug(mp(k,freq));
    ll ans=0;
    for(auto it:check){
        ans+=abs(it);
    }
    cout<<ans<<nline;
    return;      
}                                           
int main()                                                                               
{     
    ios_base::sync_with_stdio(false);                         
    cin.tie(NULL);                               
    #ifndef ONLINE_JUDGE                 
    freopen("input.txt", "r", stdin);                                           
    freopen("output.txt", "w", stdout);      
    freopen("error.txt", "w", stderr);                        
    #endif     
    ll test_cases=1;                 
    cin>>test_cases;
    while(test_cases--){
        solve();
    }
    cout<<fixed<<setprecision(10);
    cerr<<"Time:"<<1000*((double)clock())/(double)CLOCKS_PER_SEC<<"ms\n"; 
}  

Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

struct DSU {
private:
	std::vector<int> parent_or_size;
public:
	DSU(int n = 1): parent_or_size(n, -1) {}
	int get_root(int u) {
		if (parent_or_size[u] < 0) return u;
		return parent_or_size[u] = get_root(parent_or_size[u]);
	}
	int size(int u) { return -parent_or_size[get_root(u)]; }
	bool same_set(int u, int v) {return get_root(u) == get_root(v); }
	bool merge(int u, int v) {
		u = get_root(u), v = get_root(v);
		if (u == v) return false;
		if (parent_or_size[u] > parent_or_size[v]) std::swap(u, v);
		parent_or_size[u] += parent_or_size[v];
		parent_or_size[v] = u;
		return true;
	}
	std::vector<std::vector<int>> group_up() {
		int n = parent_or_size.size();
		std::vector<std::vector<int>> groups(n);
		for (int i = 0; i < n; ++i) {
			groups[get_root(i)].push_back(i);
		}
		groups.erase(std::remove_if(groups.begin(), groups.end(), [&](auto &s) { return s.empty(); }), groups.end());
		return groups;
	}
};

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    int t; cin >> t;
    while (t--) {
        int n, k; cin >> n >> k;
        vector<int> a(n);
        for (int &x : a) cin >> x;
        vector adj(n, vector<int>());
        for (int i = 0; i < n-1; ++i) {
            int u, v; cin >> u >> v;
            adj[--u].push_back(--v);
            adj[v].push_back(u);
        }

        DSU D(n);
        vector<int> path, depth(n), sz(n);
        vector<deque<vector<int>>> lst(n);

        auto merge = [&] (int u, int v) {
            if ((int)lst[u].size() == k) {
                sz[u] -= lst[u].back().size();
                lst[u].pop_back();
            }
            if ((int)lst[v].size() == k) {
                sz[v] -= lst[v].back().size();
                lst[v].pop_back();
            }

            if (sz[u] < sz[v]) {
                swap(sz[u], sz[v]);
                swap(lst[u], lst[v]);
            }
            for (int i = 0; i < (int)lst[v].size(); ++i) {
                // This is depth i+1 from v
                // To merge with depth k-i-1 of u -> index k-i-2
                if (k-i-2 >= (int)lst[u].size()) continue;
                for (int y : lst[v][i]) {
                    while (true) {
                        int x = lst[u][k-i-2].back();
                        D.merge(x, y);
                        if (lst[u][k-i-2].size() == 1) break;
                        lst[u][k-i-2].pop_back();
                        --sz[u];
                    }
                }
            }
            for (int i = 0; i < (int)lst[v].size(); ++i) {
                sz[u] += lst[v][i].size();
                if (i == (int)lst[u].size()) lst[u].push_back(lst[v][i]);
                else {
                    for (int y : lst[v][i]) lst[u][i].push_back(y);
                }
            }
        };
        auto dfs = [&] (const auto &self, int u, int p) -> void {
            path.push_back(u);
            if (path.size() > k) D.merge(u, path[path.size()-k-1]);
            for (int v : adj[u]) {
                if (v == p) continue;
                depth[v] = 1 + depth[u];
                self(self, v, u);
                merge(u, v);
            }
            lst[u].push_front(vector(1, u));
            ++sz[u];
            path.pop_back();
        };
        dfs(dfs, 0, 0);
        ll ans = 0;
        for (auto g : D.group_up()) {
            ll s1 = 0, s2 = 0;
            for (int u : g) {
                if (depth[u]%2) s1 += a[u];
                else s2 += a[u];
            }
            ans += abs(s1 - s2);
        }
        cout << ans << '\n';
    }
}
