TREESAREFUN7 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: satyam_343
Tester: tabr
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Heavy-light decomposition, segment trees

PROBLEM:

There’s a tree with N vertices. Vertex u has value A_u.
There are Q updates of the following form: given u and v, swap A_u and A_v.
After each update, do the following:

  • Let P_u denote the sorted set of values in the subtree of u.
  • Let L = \min(P_1, P_2, \ldots, P_N) when ordered lexicographically.
  • Compute the sum of all u such that P_u = L.

EXPLANATION:

Our first order of business is to analyze the query structure to figure out what we need to maintain.

Observe that the root (vertex 1) contains all the values in the tree, and so P_1 is almost lexicographically smallest.
The only way that some P_u can be even smaller than it, is if P_u forms a prefix of the sorted values.
That is, if \text{sz}_u denotes the size of the subtree of u, then the subtree of u should contain the smallest \text{sz}_u values.

Let B denote the sorted set of values of A. Then, it’s easy to see that a vertex u should contain the first \text{sz}_u elements of B (and in particular, this condition
is nice because it doesn’t change no matter what swaps are made).

Here’s one way of checking for this condition.
For each vertex u, define D_u to be the difference between the current sum of values in its subtree and the sum of the values we want to be in its subtree (that is, the \text{sz}_u smallest values).
That is,

D_u = \sum_{v \in S(u)} A_v - \sum_{i=1}^{\text{sz}_u} B_i

where S(u) denotes the set of all vertices in the subtree of u.
Clearly, D_u \geq 0 for every vertex, and D_u = 0 if and only if u contains all the required values in its subtree (and hence P_u is lexicographically smallest).
Further, note that as we observed earlier, D_1 = 0 always.

Among all the vertices with D_u = 0, we want the ones with shortest P_u: in other words, the ones with smallest \text{sz}_u.


Let’s first precompute the D_u values for each vertex, which is easy to do with a DFS.
Now, suppose an update swaps the values of vertices x and y.
Notice that the only changes are:

  • For all ancestors of x (including x itself), their D_u values change by A_y - A_x.
  • For all ancestors of y (including y itself), their D_u values change by A_x - A_y.

So, we need to support the following:

  • Add a value to a root → vertex path
  • Find the sum of all u such that D_u = 0 and \text{sz}_u is minimum.
    For this, note that D_u \geq 0 and D_1 = 0 means that we really just want the sum of all u such that the pair (D_u, \text{sz}_u) is minimum.

This is a rather standard problem that can be solved with the help of heavy-light decomposition and a lazy segment tree.

TIME COMPLEXITY:

\mathcal{O}((N + Q)\log^2 N) per testcase.

CODE:

Author's code (C++)
#pragma GCC optimize("O3,unroll-loops")
#include <bits/stdc++.h>   
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;   
using namespace std;
#define ll long long
#define pb push_back                  
#define mp make_pair          
#define nline "\n"                            
#define f first                                            
#define s second                                             
#define pll pair<ll,ll>  
#define all(x) x.begin(),x.end()     
#define vl vector<ll>           
#define vvl vector<vector<ll>>      
#define vvvl vector<vector<vector<ll>>>          
#ifndef ONLINE_JUDGE    
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);  
#endif     
// Debug printers used by the debug(x) macro; every overload writes its argument to cerr.
void _print(ll x){cerr<<x;}  
void _print(char x){cerr<<x;}    
void _print(string x){cerr<<x;}    
// Random engine seeded from the steady clock's current tick count.
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());   
// Prints a pair as {first,second}.
template<class T,class V> void _print(pair<T,V> p) {cerr<<"{"; _print(p.first);cerr<<","; _print(p.second);cerr<<"}";}
// The container overloads below print " [ e1 e2 ... ]", formatting each element with _print.
template<class T>void _print(vector<T> v) {cerr<<" [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T>void _print(set<T> v) {cerr<<" [ "; for (T i:v){_print(i); cerr<<" ";}cerr<<"]";}
template<class T>void _print(multiset<T> v) {cerr<< " [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T,class V>void _print(map<T, V> v) {cerr<<" [ "; for(auto i:v) {_print(i);cerr<<" ";} cerr<<"]";} 
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
//--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------  
const ll MOD=1e9+7;
const ll MAX=500500;
// Heavy-light decomposition of an undirected tree on vertices 0..n-1.
// After work(root):
//   parent / depth / subtree_size describe the tree rooted at `root`,
//   adj[v] holds only the children of v with the largest-subtree child first,
//   order is the preorder that follows that child order, position[v] is v's index in it,
//   top[v] is the first vertex of the chain containing v.
struct HLD{
    ll n,now;
    vector<vector<ll>> adj;
    vector<ll> position,top,order;
    vector<ll> depth,subtree_size,parent;
    // Allocates zero-filled per-vertex arrays for n vertices.
    HLD(ll n)
        :n(n),now(0),adj(n),
         position(n),top(n),order(n),
         depth(n),subtree_size(n),parent(n){}
    // Registers the undirected edge u-v.
    void addEdge(ll u,ll v){
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    // Roots the tree at `root` and computes every array described above.
    void work(ll root=0){
        parent[root]=-1;
        depth[root]=0;
        top[root]=root;
        dfs_1(root);
        dfs_2(root);
    }
    // First pass: drops the parent from adj[cur], fills depth/parent/subtree_size and
    // moves the child with the largest subtree to the front of adj[cur].
    void dfs_1(ll cur){
        vector<ll> &kids=adj[cur];
        if(parent[cur]!=-1){
            kids.erase(find(kids.begin(),kids.end(),parent[cur]));
        }
        subtree_size[cur]=1;
        for(size_t idx=0;idx<kids.size();idx++){
            const ll chld=kids[idx];
            parent[chld]=cur;
            depth[chld]=depth[cur]+1;
            dfs_1(chld);
            subtree_size[cur]+=subtree_size[chld];
            if(subtree_size[chld]>subtree_size[kids[0]]){
                swap(kids[0],kids[idx]);
            }
        }
    }
    // Second pass: assigns preorder positions; the first child continues the current
    // chain, every other child starts a chain of its own.
    void dfs_2(ll cur){
        order[now]=cur;
        position[cur]=now;
        ++now;
        const vector<ll> &kids=adj[cur];
        for(size_t idx=0;idx<kids.size();idx++){
            const ll chld=kids[idx];
            top[chld]=(chld==kids[0])?top[cur]:chld;
            dfs_2(chld);
        }
    }
    // Adds a[v] for every v in the subtree of cur into sum[cur] (call after work()).
    void dfs_3(ll cur,vector<ll> &sum,vector<ll> &a){
        sum[cur]+=a[cur];
        const vector<ll> &kids=adj[cur];
        for(size_t idx=0;idx<kids.size();idx++){
            dfs_3(kids[idx],sum,a);
            sum[cur]+=sum[kids[idx]];
        }
    }
};
// Lazy segment tree over positions 1..n whose elements are triples {key0, key1, payload}.
// A node stores the lexicographically smallest (key0, key1) in its range together with
// the sum of the payloads of all elements attaining that key. Range updates add a delta
// to key0 of every element in the range.
struct Lazy{
    vector<ll> lazy;
    vector<array<ll,3>> segt;
    ll n,inf_add;
    // Creates a tree for n positions; every slot starts as the sentinel {inf,inf,0}.
    Lazy(ll n){
        this->n=n;
        inf_add=1e18;
        const array<ll,3> blank={inf_add,inf_add,0};
        lazy.assign(4*n,0);
        segt.assign(4*n,blank);
    }
    // Combines two summaries: equal keys add their payloads, otherwise the smaller wins.
    array<ll,3> merge(array<ll,3> l,array<ll,3> r){
        const bool same_key=(l[0]==r[0]) and (l[1]==r[1]);
        if(same_key){
            l[2]+=r[2];
            return l;
        }
        return (r<l)?r:l;
    }
    // Hands the pending addition of node v down to both children.
    void push(ll v){
        const ll pending=lazy[v];
        for(ll child=2*v;child<=2*v+1;child++){
            segt[child][0]+=pending;
            lazy[child]+=pending;
        }
        lazy[v]=0;
    }
    // Overwrites the element at position pos (1-based) with info.
    void set_upd(ll pos,array<ll,3> info){
        set_upd(1,1,n,pos,info);
    }
    void set_upd(ll v,ll l,ll r,ll pos,array<ll,3> info){
        if(l==r){
            segt[v]=info;
            return;
        }
        push(v);
        const ll mid=(l+r)/2;
        if(pos>mid){
            set_upd(2*v+1,mid+1,r,pos,info);
        }
        else{
            set_upd(2*v,l,mid,pos,info);
        }
        segt[v]=merge(segt[2*v],segt[2*v+1]);
    }
    // Adds delta to key0 of every element in [l,r]; an empty range (l>r) is a no-op.
    void add_upd(ll l,ll r,ll delta){
        add_upd(1,1,n,l,r,delta);
    }
    void add_upd(ll v,ll tl,ll tr,ll l,ll r,ll delta){
        if(l>r){
            return;
        }
        if((tl==l) and (tr==r)){
            lazy[v]+=delta;
            segt[v][0]+=delta;
            return;
        }
        push(v);
        const ll mid=(tl+tr)/2;
        add_upd(2*v,tl,mid,l,min(r,mid),delta);
        add_upd(2*v+1,mid+1,tr,max(l,mid+1),r,delta);
        segt[v]=merge(segt[2*v],segt[2*v+1]);
    }
    // Returns the payload sum of the minimum key in [l,r]; asserts that the minimum key0 is 0.
    ll query(ll l,ll r){
        const array<ll,3> best=query(1,1,n,l,r);
        assert(best[0]==0);
        return best[2];
    }
    array<ll,3> query(ll v,ll l,ll r,ll ql,ll qr){
        if(ql>qr){
            return {inf_add,inf_add,-1};
        }
        if((l==ql) and (r==qr)){
            return segt[v];
        }
        push(v);
        const ll mid=(l+r)/2;
        const array<ll,3> left_part=query(2*v,l,mid,ql,min(mid,qr));
        const array<ll,3> right_part=query(2*v+1,mid+1,r,max(mid+1,ql),qr);
        return merge(left_part,right_part);
    }
}; 
void solve(){ 
    ll n,q; cin>>n>>q;
    vector<ll> a(n);
    for(auto &it:a){
        cin>>it;
    }
    vector<ll> pref=a;
    sort(all(pref));
    for(ll i=1;i<n;i++){
        pref[i]+=pref[i-1];
    }
    HLD tree(n);
    for(ll i=1;i<=n-1;i++){
        ll u,v; cin>>u>>v;
        u--,v--;
        tree.addEdge(u,v);
    }
    tree.work();
    Lazy track(n+5);
    auto index_no=tree.position;
    for(auto &it:index_no){
        it++;
    }
    for(ll i=0;i<n;i++){
        ll consider=tree.subtree_size[i];
        track.set_upd(index_no[i],{-pref[consider-1],consider,i+1});
    }
    auto upd=[&](ll node,ll val){
        while(node!=-1){
            ll till=tree.top[node];
            track.add_upd(index_no[till],index_no[node],val);
            node=tree.parent[till];
        }
    };
    for(ll i=0;i<n;i++){
        upd(i,a[i]);
    }
    while(q--){
        ll u,v; cin>>u>>v;
        u--,v--;  
        upd(u,-a[u]); 
        upd(v,-a[v]);
        swap(a[u],a[v]);
        upd(u,a[u]);
        upd(v,a[v]);
        cout<<track.query(1,n)<<nline;
    }
    return;  
}                                       
// Program entry point: configures fast stream I/O, optionally redirects the standard
// streams to local files, then runs solve() once per test case.
int main()                                                                               
{     
    ios_base::sync_with_stdio(false);                         
    cin.tie(NULL);                               
    // When ONLINE_JUDGE is not defined, stdin/stdout/stderr are redirected to files.
    #ifndef ONLINE_JUDGE                 
    freopen("input.txt", "r", stdin);                                           
    freopen("output.txt", "w", stdout);      
    freopen("error.txt", "w", stderr);                        
    #endif     
    // The number of test cases is read from the input.
    ll test_cases=1;                 
    cin>>test_cases;
    while(test_cases--){
        solve();
    }
    cout<<fixed<<setprecision(10);
    // Report the processor time used so far, in milliseconds, on stderr.
    cerr<<"Time:"<<1000*((double)clock())/(double)CLOCKS_PER_SEC<<"ms\n"; 
}  
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

//#define IGNORE_CR

// Strict reader over the whole of stdin, used to validate the exact input format.
// Every read* call asserts on malformed input. Separators are never skipped implicitly:
// callers consume them with readSpace()/readEoln(), and readEof() checks nothing is left.
struct input_checker {
    string buffer;  // entire contents of stdin
    int pos;        // index of the next unread character in buffer

    // Character classes that can be passed as `pattern` to readString().
    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    // Reads stdin to the end into buffer.
    input_checker() {
        pos = 0;
        const int eof = char_traits<char>::eof();
        for (int c = cin.get(); c != eof; c = cin.get()) {
            buffer.push_back(static_cast<char>(c));
        }
    }

    // Returns the token starting at pos, i.e. everything up to (not including) the next
    // ' ' or '\n'. Asserts that input remains and that the token holds no other whitespace.
    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
#ifdef IGNORE_CR
            if (buffer[pos] == '\r') {
                pos++;
                continue;
            }
#endif
            // isspace is only defined for values representable as unsigned char (or EOF),
            // so a plain char holding a byte >= 0x80 must be converted first.
            assert(!isspace(static_cast<unsigned char>(buffer[pos])));
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    // Reads a token whose length is in [min_len, max_len]; when pattern is non-empty,
    // every character of the token must occur in pattern.
    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    // Reads a token, parses it with stoi and asserts it lies in [min_val, max_val].
    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    // Reads a token, parses it with stoll and asserts it lies in [min_val, max_val].
    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    // Reads `size` integers separated by single spaces (no separator after the last one).
    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    // Same as readInts, for 64-bit values.
    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    // Consumes exactly one ' '.
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    // Consumes exactly one '\n'.
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    // Asserts that the whole input has been consumed.
    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

// Rooted forest with DFS bookkeeping, binary-lifting LCA and heavy-light decomposition.
// Vertices are 0..n-1; edges are stored once in `edges` and referenced by id from `g`.
template <typename T>
struct forest {
    // Undirected edge between `from` and `to` with weight `cost`.
    struct edge {
        int from;
        int to;
        T cost;
        edge(int _from, int _to, T _cost) : from(_from), to(_to), cost(_cost) {}
    };

    int n;
    vector<edge> edges;
    vector<vector<int>> g;   // g[v] = ids of the edges incident to v
    vector<int> pv;          // parent vertex, -1 for a root
    vector<int> pe;          // id of the edge to the parent, -1 for a root
    vector<int> depth;       // number of edges from the root, -1 while unvisited
    vector<int> root;        // root of the tree containing the vertex
    vector<int> sz;          // subtree size
    vector<int> order;       // vertices in DFS preorder
    vector<int> beg;         // beg[v] = index of v in order
    vector<int> end;         // end[v] = size of order once v's subtree is finished
    vector<T> dist;          // sum of edge costs from the root

    forest(int _n) : n(_n) {
        g = vector<vector<int>>(n);
        init();
    }

    // Resets all DFS results; the edges themselves are kept.
    void init() {
        pv = vector<int>(n, -1);
        pe = vector<int>(n, -1);
        depth = vector<int>(n, -1);
        root = vector<int>(n, -1);
        sz = vector<int>(n, 0);
        order = vector<int>();
        beg = vector<int>(n, -1);
        end = vector<int>(n, -1);
        dist = vector<T>(n, 0);
    }

    // Adds an undirected edge and returns its id.
    int add(int from, int to, T cost = 1) {
        int id = (int) edges.size();
        g[from].emplace_back(id);
        g[to].emplace_back(id);
        edges.emplace_back(from, to, cost);
        return id;
    }

    // Recursive DFS from v; visits incident edges in the order stored in g[v],
    // skipping the edge that leads back to the parent.
    void do_dfs(int v) {
        beg[v] = (int) order.size();
        order.emplace_back(v);
        sz[v] = 1;
        for (int id : g[v]) {
            if (id == pe[v]) {
                continue;
            }
            edge e = edges[id];
            // The endpoint of e that is not v.
            int to = e.from ^ e.to ^ v;
            pv[to] = v;
            pe[to] = id;
            depth[to] = depth[v] + 1;
            root[to] = (root[v] != -1 ? root[v] : to);
            dist[to] = dist[v] + e.cost;
            do_dfs(to);
            sz[v] += sz[to];
        }
        end[v] = (int) order.size();
    }

    // Roots the tree containing v at v and traverses it. Note that `order` is cleared first.
    void dfs(int v) {
        pv[v] = -1;
        pe[v] = -1;
        depth[v] = 0;
        root[v] = v;
        order.clear();
        dist[v] = 0;
        do_dfs(v);
    }

    // Traverses every tree of the forest, rooting each at its lowest-numbered vertex.
    // Because dfs() clears `order`, `order` afterwards lists only the tree visited last.
    void dfs_all() {
        init();
        for (int v = 0; v < n; v++) {
            if (depth[v] == -1) {
                dfs(v);
            }
        }
    }

    int h = -1;              // number of levels in the lifting table, -1 until build_lca()
    vector<vector<int>> p;   // p[i][v] = 2^i-th ancestor of v, or -1 if there is none

    // Builds the binary-lifting table from pv; requires a completed DFS.
    inline void build_lca() {
        int max_depth = *max_element(depth.begin(), depth.end());
        h = 1;
        while ((1 << h) <= max_depth) {
            h++;
        }
        p = vector<vector<int>>(h, vector<int>(n));
        p[0] = pv;
        for (int i = 1; i < h; i++) {
            for (int j = 0; j < n; j++) {
                p[i][j] = (p[i - 1][j] == -1 ? -1 : p[i - 1][p[i - 1][j]]);
            }
        }
    }

    inline bool anc(int x, int y) {  // return x is y's ancestor or not
        return (beg[x] <= beg[y] && end[y] <= end[x]);
    }

    // Returns the vertex `up` steps above x, or -1 when the walk leaves the tree.
    inline int go_up(int x, int up) {
        assert(h != -1);
        up = min(up, (1 << h) - 1);
        for (int i = h - 1; i >= 0; i--) {
            if (up & (1 << i)) {
                x = p[i][x];
                if (x == -1) {
                    break;
                }
            }
        }
        return x;
    }

    // Lowest common ancestor of x and y; requires build_lca().
    inline int lca(int x, int y) {
        assert(h != -1);
        if (anc(x, y)) {
            return x;
        }
        if (anc(y, x)) {
            return y;
        }
        // Lift x as high as possible while it is still not an ancestor of y.
        for (int i = h - 1; i >= 0; i--) {
            if (p[i][x] != -1 && !anc(p[i][x], y)) {
                x = p[i][x];
            }
        }
        return p[0][x];
    }

    // Sum of edge costs on the path between x and y.
    inline T distance(int x, int y) {
        return dist[x] + dist[y] - 2 * dist[lca(x, y)];
    }

    vector<int> head;        // head[v] = topmost vertex of the chain containing v

    // Builds the heavy-light decomposition: moves the edge to the largest-subtree child
    // to the front of every adjacency list, re-runs the DFS so that this child directly
    // follows its parent in `order`, then builds the LCA table and the chain heads.
    void build_hld() {
        dfs_all();
        for (int i = 0; i < n; i++) {
            if (g[i].empty()) {
                continue;
            }
            int best = -1, bid = 0;
            for (int j = 0; j < (int) g[i].size(); j++) {
                int id = g[i][j];
                int v = edges[id].from ^ edges[id].to ^ i;
                // Only children of i are candidates (this skips the parent edge).
                if (pv[v] != i) {
                    continue;
                }
                if (sz[v] > best) {
                    best = sz[v];
                    bid = j;
                }
            }
            swap(g[i][0], g[i][bid]);
        }
        init();
        dfs_all();
        build_lca();
        head.resize(n);
        for (int i = 0; i < n; i++) {
            head[i] = i;
        }
        // A vertex directly following its parent in preorder is that parent's first
        // child, so it continues the parent's chain.
        for (int i = 0; i < n - 1; i++) {
            int x = order[i];
            int y = order[i + 1];
            if (pv[y] == x) {
                head[y] = head[x];
            }
        }
    }

    // Splits the path x -> lca -> y into segments that are contiguous in `beg` order and
    // calls f(first, last, flag) for each, with both bounds inclusive. flag is true for
    // segments on the x side and false for the lca itself and for segments on the y side.
    // The lca is reported only when with_lca is set.
    void apply(int x, int y, bool with_lca, function<void(int, int, bool)> f) {
        int z = lca(x, y);
        int v = x;
        // Climb from x towards z, one chain per step.
        while (v != z) {
            if (depth[head[v]] <= depth[z]) {
                f(beg[z] + 1, beg[v], true);
                break;
            }
            f(beg[head[v]], beg[v], true);
            v = pv[head[v]];
        }
        if (with_lca) {
            f(beg[z], beg[z], false);
        }
        v = y;
        // On the y side the chains are collected while climbing and reported afterwards
        // in reverse, i.e. from the one nearest z down to y.
        vector<int> visited;
        while (v != z) {
            if (depth[head[v]] <= depth[z]) {
                f(beg[z] + 1, beg[v], false);
                break;
            }
            visited.emplace_back(v);
            v = pv[head[v]];
        }
        for (int i = (int) visited.size() - 1; i >= 0; i--) {
            v = visited[i];
            f(beg[head[v]], beg[v], false);
        }
    }
};

// Bottom-up lazy segment tree over positions 0..n-1.
// An element is a tuple (key0, key1, payload). A node keeps the smallest (key0, key1) of
// its range and the sum of the payloads of the elements attaining it. A lazy tag F is
// an amount added to key0.
struct segtree {
    using T = tuple<long long, int, long long>;
    using F = long long;

    // Identity element of op for the values used here (key0 = 1e18, payload 0).
    T e() {
        return make_tuple((long long) 1e18, 0, 0);
    }

    // Identity tag: add nothing.
    F id() {
        return 0;
    }

    // Compares a and b with the payload zeroed out, i.e. on (key0, key1) only.
    // Returns the smaller one; on a tie returns that key with both payloads added.
    T op(T a, T b) {
        auto a0 = a;
        auto b0 = b;
        std::get<2>(a0) = 0;
        std::get<2>(b0) = 0;
        if (a0 < b0) {
            return a;
        }
        if (a0 > b0) {
            return b;
        }
        std::get<2>(a) += std::get<2>(b);
        return a;
    }

    // Applies tag f to value x: adds f to key0.
    T mapping(F f, T x) {
        std::get<0>(x) += f;
        return x;
    }

    // Combines two tags into one.
    F composition(F f, F g) {
        return f + g;
    }

    int n;           // number of elements
    int size;        // number of leaves: smallest power of two >= n (1 when n <= 1)
    int log_size;    // log2(size)
    vector<T> node;  // 1-based heap layout; leaves occupy [size, 2 * size)
    vector<F> lazy;  // pending tags of the internal nodes [1, size)

    segtree() : segtree(0) {}
    segtree(int _n) {
        build(vector<T>(_n, e()));
    }
    segtree(const vector<T>& v) {
        build(v);
    }

    // (Re)builds the tree from v.
    void build(const vector<T>& v) {
        n = (int) v.size();
        if (n <= 1) {
            log_size = 0;
        } else {
            log_size = 32 - __builtin_clz(n - 1);
        }
        size = 1 << log_size;
        node.resize(2 * size, e());
        lazy.resize(size, id());
        for (int i = 0; i < n; i++) {
            node[i + size] = v[i];
        }
        for (int i = size - 1; i > 0; i--) {
            pull(i);
        }
    }

    // Moves the pending tag of x onto its two children.
    void push(int x) {
        node[2 * x] = mapping(lazy[x], node[2 * x]);
        node[2 * x + 1] = mapping(lazy[x], node[2 * x + 1]);
        // Leaves have no lazy slot, so tags are forwarded only to internal children.
        if (2 * x < size) {
            lazy[2 * x] = composition(lazy[x], lazy[2 * x]);
            lazy[2 * x + 1] = composition(lazy[x], lazy[2 * x + 1]);
        }
        lazy[x] = id();
    }

    // Recomputes x from its children.
    void pull(int x) {
        node[x] = op(node[2 * x], node[2 * x + 1]);
    }

    // Overwrites element p with v.
    void set(int p, T v) {
        assert(0 <= p && p < n);
        p += size;
        // Flush every tag on the root-to-leaf path before touching the leaf.
        for (int i = log_size; i >= 1; i--) {
            push(p >> i);
        }
        node[p] = v;
        for (int i = 1; i <= log_size; i++) {
            pull(p >> i);
        }
    }

    // Returns element p.
    T get(int p) {
        assert(0 <= p && p < n);
        p += size;
        for (int i = log_size; i >= 1; i--) {
            push(p >> i);
        }
        return node[p];
    }

    // Returns op over the half-open range [l, r); e() when the range is empty.
    T get(int l, int r) {
        assert(0 <= l && l <= r && r <= n);
        l += size;
        r += size;
        // Flush the tags of the ancestors that only partially cover the range.
        for (int i = log_size; i >= 1; i--) {
            if (((l >> i) << i) != l) {
                push(l >> i);
            }
            if (((r >> i) << i) != r) {
                push((r - 1) >> i);
            }
        }
        T vl = e();
        T vr = e();
        while (l < r) {
            if (l & 1) {
                vl = op(vl, node[l++]);
            }
            if (r & 1) {
                vr = op(node[--r], vr);
            }
            l >>= 1;
            r >>= 1;
        }
        return op(vl, vr);
    }

    // Applies tag f to the single element p.
    void apply(int p, F f) {
        assert(0 <= p && p < n);
        p += size;
        for (int i = log_size; i >= 1; i--) {
            push(p >> i);
        }
        node[p] = mapping(f, node[p]);
        for (int i = 1; i <= log_size; i++) {
            pull(p >> i);
        }
    }

    // Applies tag f to every element of the half-open range [l, r).
    void apply(int l, int r, F f) {
        assert(0 <= l && l <= r && r <= n);
        l += size;
        r += size;
        for (int i = log_size; i >= 1; i--) {
            if (((l >> i) << i) != l) {
                push(l >> i);
            }
            if (((r >> i) << i) != r) {
                push((r - 1) >> i);
            }
        }
        // Keep the original bounds for the pull phase below.
        int ll = l;
        int rr = r;
        while (l < r) {
            if (l & 1) {
                node[l] = mapping(f, node[l]);
                if (l < size) {
                    lazy[l] = composition(f, lazy[l]);
                }
                l++;
            }
            if (r & 1) {
                r--;
                node[r] = mapping(f, node[r]);
                // NOTE: this guard tests l although it protects lazy[r]. Inside this
                // branch l and r are on the same tree level, so the test gives the same
                // answer as `r < size`.
                if (l < size) {
                    lazy[r] = composition(f, lazy[r]);
                }
            }
            l >>= 1;
            r >>= 1;
        }
        l = ll;
        r = rr;
        // Recompute the partially covered ancestors from the bottom up.
        for (int i = 1; i <= log_size; i++) {
            if (((l >> i) << i) != l) {
                pull(l >> i);
            }
            if (((r >> i) << i) != r) {
                pull((r - 1) >> i);
            }
        }
    }
};

// Disjoint-set union over 0..n-1 with path compression (no union by size/rank).
struct dsu {
    vector<int> p;   // parent links; a representative points to itself
    vector<int> sz;  // sz[r] = size of the set whose representative is r
    int n;

    // Starts with n singleton sets.
    dsu(int _n) : p(_n), sz(_n, 1), n(_n) {
        iota(p.begin(), p.end(), 0);
    }

    // Returns the representative of x and re-links every vertex on the way directly to it.
    inline int get(int x) {
        int rep = x;
        while (p[rep] != rep) {
            rep = p[rep];
        }
        while (p[x] != rep) {
            const int above = p[x];
            p[x] = rep;
            x = above;
        }
        return rep;
    }

    // Merges the set of x into the set of y; returns false if they were already joined.
    inline bool unite(int x, int y) {
        const int from = get(x);
        const int into = get(y);
        if (from == into) {
            return false;
        }
        p[from] = into;
        sz[into] += sz[from];
        return true;
    }

    // True when x and y belong to the same set.
    inline bool same(int x, int y) {
        const int rx = get(x);
        const int ry = get(y);
        return rx == ry;
    }

    // Number of elements in the set containing x.
    inline int size(int x) {
        const int rep = get(x);
        return sz[rep];
    }

    // True when x is the representative of its set.
    inline bool root(int x) {
        return get(x) == x;
    }
};

int main() {
    input_checker in;
    int tt = in.readInt(1, 1e5);
    in.readEoln();
    int sn = 0, sq = 0;
    while (tt--) {
        int n = in.readInt(1, 2e5);
        in.readSpace();
        int q = in.readInt(1, 2e5);
        in.readEoln();
        sn += n;
        sq += q;
        auto a = in.readLongs(n, 1, 1e9);
        in.readEoln();
        dsu uf(n);
        forest<int> g(n);
        for (int i = 0; i < n - 1; i++) {
            int x = in.readInt(1, n);
            in.readSpace();
            int y = in.readInt(1, n);
            in.readEoln();
            assert(x != y);
            x--;
            y--;
            g.add(x, y);
            assert(uf.unite(x, y));
        }
        g.build_hld();
        auto b = a;
        sort(b.begin(), b.end());
        for (int i = 1; i < n; i++) {
            b[i] += b[i - 1];
        }
        vector<tuple<long long, int, long long>> sdef(n);
        for (int i = 0; i < n; i++) {
            int j = g.order[i];
            sdef[i] = make_tuple(-b[g.sz[j] - 1], g.sz[j], j + 1);
        }
        segtree seg(sdef);
        auto Add = [&](int x, long long t) {
            function<void(int, int, bool)> f = [&](int i, int j, bool) {
                seg.apply(min(i, j), max(i, j) + 1, t);
            };
            g.apply(0, x, true, f);
        };
        for (int i = 0; i < n; i++) {
            Add(i, a[i]);
        }
        while (q--) {
            int x = in.readInt(1, n);
            in.readSpace();
            int y = in.readInt(1, n);
            in.readEoln();
            // assert(x != y);
            x--;
            y--;
            long long t;
            function<void(int, int, bool)> f = [&](int i, int j, bool) {
                seg.apply(min(i, j), max(i, j) + 1, t);
            };
            Add(x, -a[x]);
            Add(y, -a[y]);
            swap(a[x], a[y]);
            Add(x, a[x]);
            Add(y, a[y]);
            cout << get<2>(seg.node[1]) << "\n";
        }
    }
    assert(sn <= 2e5);
    assert(sq <= 2e5);
    in.readEof();
    return 0;
}
Editorialist's code (C++)
#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

const ll inf = 1e18;
// Pointer-based lazy segment tree node covering the half-open index range [lo, hi).
// A value is a triple {key0, key1, payload}; a node keeps the lexicographically smallest
// (key0, key1) of its range together with the sum of the payloads attaining that key.
// Range updates add an amount to key0.
struct Node {
	using T = array<ll, 3>;
	T unit = {inf, inf, 0};
	// Combines two summaries: equal keys add their payloads, otherwise the smaller wins.
	T f(T a, T b) {
		const bool same_key = (a[0] == b[0]) && (a[1] == b[1]);
		if (same_key) {
			a[2] += b[2];
			return a;
		}
		return (b < a) ? b : a;
	}

	Node *l = nullptr, *r = nullptr;
	int lo, hi;
	ll madd = 0;     // addition not yet handed down to the children
	T val = unit;
	// Empty node for [_lo, _hi); children are created on demand by push().
	Node(int _lo,int _hi):lo(_lo),hi(_hi){}
	// Builds the subtree for [_lo, _hi) from the values in v.
	Node(vector<array<ll, 3>>& v, int _lo, int _hi) : lo(_lo), hi(_hi) {
		if (hi - lo <= 1) {
			val = v[lo];
			return;
		}
		const int mid = lo + (hi - lo)/2;
		l = new Node(v, lo, mid);
		r = new Node(v, mid, hi);
		val = f(l->val, r->val);
	}
	// Summary of the indices in [L, R); `unit` when the ranges do not intersect.
	T query(int L, int R) {
		const bool disjoint = (R <= lo) || (hi <= L);
		if (disjoint) return unit;
		const bool covered = (L <= lo) && (hi <= R);
		if (covered) return val;
		push();
		const T left_part = l->query(L, R);
		const T right_part = r->query(L, R);
		return f(left_part, right_part);
	}
	// Adds x to key0 of every index in [L, R).
	void add(int L, int R, ll x) {
		if (R <= lo || hi <= L) return;
		if (L <= lo && hi <= R) {
			val[0] += x;
			madd += x;
			return;
		}
		push();
		l->add(L, R, x);
		r->add(L, R, x);
		val = f(l->val, r->val);
	}
	// Creates missing children and hands the pending addition down to them.
	void push() {
		if (l == nullptr) {
			const int mid = lo + (hi - lo)/2;
			l = new Node(lo, mid);
			r = new Node(mid, hi);
		}
		if (madd != 0) {
			l->add(lo, hi, madd);
			r->add(lo, hi, madd);
			madd = 0;
		}
	}
};

template <bool VALS_EDGES> struct HLD {
	int N, tim = 0;
	vector<vector<int>> adj;
	vector<int> par, siz, depth, rt, pos;
	Node *seg;
	HLD(vector<vector<int>> &adj_)
		: N(size(adj_)), adj(adj_), par(N, -1), siz(N, 1), depth(N),
		  rt(N),pos(N){ dfsSz(0); dfsHld(0); }
	void dfsSz(int v) {
		if (par[v] != -1) adj[v].erase(find(begin(adj[v]), end(adj[v]), par[v]));
		for (int& u : adj[v]) {
			par[u] = v, depth[u] = depth[v] + 1;
			dfsSz(u);
			siz[v] += siz[u];
			if (siz[u] > siz[adj[v][0]]) swap(u, adj[v][0]);
		}
	}
	void dfsHld(int v) {
		pos[v] = tim++;
		for (int u : adj[v]) {
			rt[u] = (u == adj[v][0] ? rt[v] : u);
			dfsHld(u);
		}
	}
	template <class B> void process(int u, int v, B op) {
		for (; rt[u] != rt[v]; v = par[rt[v]]) {
			if (depth[rt[u]] > depth[rt[v]]) swap(u, v);
			op(pos[rt[v]], pos[v] + 1);
		}
		if (depth[u] > depth[v]) swap(u, v);
		op(pos[u] + VALS_EDGES, pos[v] + 1);
	}
	void modifyPath(int u, int v, ll val) {
		process(u, v, [&](int l, int r) {
            seg -> add(l, r, val);
        });
	}
	ll querySubtree(int v) { // modifySubtree is similar
        return (seg -> query(pos[v] + VALS_EDGES, pos[v] + siz[v]))[2];
	}
};


int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    int t; cin >> t;
    while (t--) {
        int n, q; cin >> n >> q;
        vector<int> a(n);
        vector<ll> D(n);
        for (int &x : a) cin >> x;
        auto sorted_a = a;
        sort(begin(sorted_a), end(sorted_a));
        for (int i = 1; i < n; ++i) sorted_a[i] += sorted_a[i-1];
        vector adj(n, vector<int>());
        for (int i = 0; i < n-1; ++i) {
            int u, v;  cin >> u >> v;
            adj[--u].push_back(--v);
            adj[v].push_back(u);
        }
        auto dfs = [&] (const auto &self, int u, int p) -> void {
            D[u] = a[u];
            for (int v : adj[u]) if (v != p) {
                self(self, v, u);
                D[u] += D[v];
            }
        };
        dfs(dfs, 0, 0);
        HLD<false> H(adj);
        for (int i = 0; i < n; ++i) D[i] -= sorted_a[H.siz[i]-1];
        {
            vector<array<ll, 3>> v(n);
            for (int i = 0; i < n; ++i) v[H.pos[i]] = {D[i], H.siz[i], i+1};
            H.seg = new Node(v, 0, n);
        }

        while (q--) {
            int u, v; cin >> u >> v; --u, --v;
            int x = a[u] - a[v], y = a[v] - a[u];
            H.modifyPath(0, v, x);
            H.modifyPath(0, u, y);
            swap(a[u], a[v]);
            cout << H.querySubtree(0) << '\n';
        }
    }
}
1 Like