BEAUTY_SUM - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: piyush_2007
Tester: yash_daga
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Centroid decomposition, point-update range-query data structures

PROBLEM:

You’re given a tree on N vertices, where vertex i has the value A_i.
For two vertices x and y, let f(x, y) denote the minimum value on the path between them, and g(x, y) denote the maximum.
Find \sum_{x\lt y} f(x, y)\cdot g(x, y).
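For small inputs, the definition can be checked directly. The sketch below is a hypothetical O(N^2) brute force. The name bruteBeautySum, the 0-indexed adjacency-list input and the assumption of nonnegative values are all illustrative. Both reference solutions print the answer modulo 10^9+7, so this sketch does the same. From every vertex x, it walks the tree while carrying the running minimum and maximum of the path, and adds f(x, y)\cdot g(x, y) for every y \gt x.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical O(N^2) reference for sum over x < y of f(x, y) * g(x, y), mod 1e9+7.
// adj: 0-indexed adjacency list of a tree; A: vertex values (assumed nonnegative).
long long bruteBeautySum(const std::vector<std::vector<int>>& adj,
                         const std::vector<long long>& A) {
    const long long P = 1000000007LL;
    struct Item { int v, p; long long mn, mx; };  // vertex, parent, path min, path max
    int n = static_cast<int>(adj.size());
    long long total = 0;
    for (int x = 0; x < n; ++x) {
        std::vector<Item> st = {{x, -1, A[x], A[x]}};  // explicit DFS stack from x
        while (!st.empty()) {
            Item it = st.back();
            st.pop_back();
            // Count each unordered pair once, when starting from its smaller endpoint.
            if (it.v > x) total = (total + (it.mn % P) * (it.mx % P)) % P;
            for (int u : adj[it.v])
                if (u != it.p)
                    st.push_back({u, it.v, std::min(it.mn, A[u]), std::max(it.mx, A[u])});
        }
    }
    return total;
}
```

As an example, take the tree with edges 1-2, 2-3, 1-4 and A = (4, 3, 1, 5), which is 0-indexed in the code. Its six pairs contribute 12 + 4 + 20 + 3 + 15 + 5 = 59.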

EXPLANATION:

First, let’s try to solve easier versions of this problem.

For a fixed root vertex r, computing \sum_{x = 1}^N f(r, x)\cdot g(r, x) can be done easily in \mathcal{O}(N) time using a simple DFS to compute the required values.
However, there doesn’t seem to be a nice way to quickly recompute this value when changing roots.

Instead, we’ll try to do a bit more: we’ll try to compute the sum of f(x, y)\cdot g(x, y) for all pairs of vertices (x, y) whose paths pass through r. Let’s see how we can do this quickly.

First, compute f(r, x) and g(r, x) for all vertices; as mentioned above, this can be done in \mathcal{O}(N) using a DFS.
For convenience, let B_x = f(r, x) and C_x = g(r, x).
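One way to compute every B_x and C_x is a single traversal from r that extends each parent's values by one vertex. The sketch below assumes a 0-indexed adjacency list, and pathMinMaxFromRoot is a hypothetical helper name. It uses an explicit stack instead of recursion so that path-shaped trees cannot overflow the call stack.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical helper: B[x] = min and C[x] = max of A on the path r -> x.
// adj is a 0-indexed adjacency list of a tree; A holds the vertex values.
std::pair<std::vector<long long>, std::vector<long long>>
pathMinMaxFromRoot(const std::vector<std::vector<int>>& adj,
                   const std::vector<long long>& A, int r) {
    int n = static_cast<int>(adj.size());
    std::vector<long long> B(n), C(n);
    std::vector<int> parent(n, -1);
    std::vector<int> stack{r};  // explicit stack instead of recursion
    B[r] = C[r] = A[r];
    while (!stack.empty()) {
        int v = stack.back();
        stack.pop_back();
        for (int u : adj[v]) {
            if (u == parent[v]) continue;  // don't walk back up the tree
            parent[u] = v;
            // The path r -> u is the path r -> v extended by u.
            B[u] = std::min(B[v], A[u]);
            C[u] = std::max(C[v], A[u]);
            stack.push_back(u);
        }
    }
    return {B, C};
}
```

As an example, take the tree with edges 1-2, 2-3, 1-4, with A = (4, 3, 1, 5) and r = 1 (vertex 0 in the code). This gives B = (4, 3, 1, 4) and C = (4, 4, 4, 5).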

Notice that for a path (x, y) that passes through r, it holds that f(x, y) = \min(B_x, B_y) and g(x, y) = \max(C_x, C_y).
So, we essentially want to compute \sum \min(B_x, B_y) \cdot \max(C_x, C_y) across all valid pairs (x, y).

Once again, let’s relax the problem a bit. Ignore the condition that the path between x and y must pass through r, and just try to quickly compute that summation over all pairs.

How?

Let’s sort the pairs (B_x, C_x) by their first component and relabel them so that B_1 \leq B_2 \leq \ldots \leq B_N.

Then, for each i from 1 to N:

  • Let’s only consider pairs (i, j) with i \lt j \leq N.
  • The sorting guarantees that \min(B_i, B_j) = B_i, so we only need to deal with C_j values.
  • There are two cases here: C_i \geq C_j and C_i \lt C_j.
    • If C_i \geq C_j, then \max(C_i, C_j) = C_i, so the pair contributes B_i\cdot C_i. Overall, these pairs add k\cdot B_i\cdot C_i to the answer, where k is the number of indices j \gt i such that C_j \leq C_i.
      So all of them can be processed at once, provided we can quickly find k.
    • If C_i \lt C_j, the pair contributes B_i\cdot C_j. Overall, these pairs add B_i\cdot S, where S is the sum of all C_j (with j \gt i) that are larger than C_i.
      Once again, all of them can be processed at once if we can quickly compute S.

So, we need to be able to quickly answer the following questions about the suffix from i+1:

  • How many elements are \leq C_i?
  • What’s the sum of elements that are \gt C_i?

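If i is iterated in decreasing order from N down to 1, both questions can be answered in \mathcal{O}(\log N). Use a Fenwick tree (or segment tree) indexed by the coordinate-compressed C values, storing both counts and sums, and insert C_i right after index i has been processed.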

This allows us to compute the entire sum in \mathcal{O}(N\log N).
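The procedure above can be sketched as follows. This is an illustrative version and not the setter's code: it sorts by B as in the text. The reference solutions instead sort by the maximum and build their Fenwick trees over the minimums, which is symmetric. The sketch uses two 1-indexed Fenwick trees over the compressed C values (one for counts, one for sums) and reduces modulo 10^9+7 as the reference solutions do. The names Fenwick and sumMinTimesMax are hypothetical, and values are assumed nonnegative.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

const long long P = 1000000007LL;  // modulus used by the reference solutions

// 1-indexed Fenwick tree over compressed C values, values kept mod P.
struct Fenwick {
    std::vector<long long> t;
    explicit Fenwick(int n) : t(n + 1, 0LL) {}
    void add(int i, long long d) {  // point update at position i
        for (; i < static_cast<int>(t.size()); i += i & -i) t[i] = (t[i] + d) % P;
    }
    long long prefix(int i) const {  // sum of positions 1..i
        long long s = 0;
        for (; i > 0; i -= i & -i) s = (s + t[i]) % P;
        return s;
    }
};

// Sum over unordered pairs {x, y} of min(B_x, B_y) * max(C_x, C_y), mod P.
// Input: the pairs (B_x, C_x); values assumed nonnegative.
long long sumMinTimesMax(std::vector<std::pair<long long, long long>> v) {
    std::sort(v.begin(), v.end());  // B ascending, so min(B_i, B_j) = B_i for i < j
    std::vector<long long> vals;
    for (const auto& p : v) vals.push_back(p.second);
    std::sort(vals.begin(), vals.end());
    vals.erase(std::unique(vals.begin(), vals.end()), vals.end());
    int m = static_cast<int>(vals.size());
    Fenwick cnt(m), sum(m);  // both hold only the suffix j > i
    long long ans = 0;
    for (int i = static_cast<int>(v.size()) - 1; i >= 0; --i) {
        long long b = v[i].first % P, c = v[i].second;
        int pos = static_cast<int>(std::lower_bound(vals.begin(), vals.end(), c) - vals.begin()) + 1;
        long long k = cnt.prefix(pos);                            // #j > i with C_j <= C_i
        long long S = (sum.prefix(m) - sum.prefix(pos) + P) % P;  // sum of C_j > C_i
        long long contrib = (k * (c % P) + S) % P;
        ans = (ans + b * contrib) % P;
        cnt.add(pos, 1);  // now C_i belongs to the suffix for smaller i
        sum.add(pos, c % P);
    }
    return ans;
}
```

Each query and update costs O(log N), so a call on N pairs runs in O(N log N) including the sorts.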

Now that we’ve done this, we have some excess that needs to be removed.
In particular, we need to remove the sum of \min(B_x, B_y)\cdot \max(C_x, C_y) for all pairs of vertices whose paths don’t pass through r.

In fact, such pairs can be found quite easily.
Let’s root the tree at r, and let c_1, c_2, \ldots, c_m be its children.

What we need to remove is then the contribution of all pairs lying entirely in the subtree of c_1, all pairs lying entirely in the subtree of c_2, \ldots, all pairs lying entirely in the subtree of c_m. These contributions must still be computed with B and C measured from r, since that is what the full sum added for them.
This is quite easy to do: we already have a way to compute the sum over all pairs, given the list of their (B_x, C_x) values. So, simply run the above on each child’s set of values, and subtract the results from the answer.
This also takes \mathcal{O}(N\log N) time in total, since the sum of the sizes of the sets is N-1.
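Here is a small worked example of this subtraction, using a tree chosen purely for illustration. Take edges 1-2, 2-3, 1-4 with A = (4, 3, 1, 5) and r = 1. Measured from r, B = (4, 3, 1, 4) and C = (4, 4, 4, 5). The children of r are 2 and 4, with subtrees \{2, 3\} and \{4\}. Below, P(S) denotes the all-pairs sum over a set S; this notation is introduced only for the example.

```latex
\begin{aligned}
P(S) &= \sum_{\{x, y\} \subseteq S} \min(B_x, B_y)\cdot\max(C_x, C_y)\\
P(\{1, 2, 3, 4\}) &= \underbrace{12}_{(1,2)} + \underbrace{4}_{(1,3)} + \underbrace{20}_{(1,4)} + \underbrace{4}_{(2,3)} + \underbrace{15}_{(2,4)} + \underbrace{5}_{(3,4)} = 60\\
P(\{2, 3\}) &= \min(3, 1)\cdot\max(4, 4) = 4, \qquad P(\{4\}) = 0\\
\text{sum over paths through } r &= 60 - 4 - 0 = 56
\end{aligned}
```

The subtracted 4 is not f(2, 3)\cdot g(2, 3) = 1\cdot 3 = 3. It is the value the full sum wrongly assigned to the pair (2, 3) using B and C measured from r, which is exactly what has to be removed. The true contribution of (2, 3) is counted later, when the component \{2, 3\} is processed on its own.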

Now that we’ve done this, we can move on to a solution to the original problem.

Notice that once we root the tree at some r and compute this value, the only pairs left to count are those lying entirely within one of the smaller trees obtained by deleting r (one for each neighbor of r). Each of those trees can be solved with the same algorithm.

It’s well known from centroid decomposition that choosing r to be the centroid of the current tree each time keeps the recursion depth at \mathcal{O}(\log N). This is because every tree left after deleting the centroid has at most half as many vertices. Each level of the recursion does a total of \mathcal{O}(N\log N) work, so the solution runs in \mathcal{O}(N\log^2 N) overall.
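Finding the centroid can be sketched as below, as an iterative alternative to the getsz_cd/findct_cd pair in the setter’s code. The name findCentroid and the removed array (which marks centroids chosen at earlier levels) are assumptions of this sketch. It computes subtree sizes of the current component and returns a vertex whose removal leaves no component with more than half of the component’s vertices.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical helper: a centroid of the component containing `start`,
// ignoring vertices with removed[v] == true (earlier centroids).
int findCentroid(const std::vector<std::vector<int>>& adj,
                 const std::vector<bool>& removed, int start) {
    int n = static_cast<int>(adj.size());
    std::vector<int> order, parent(n, -1), sz(n, 1);
    order.push_back(start);
    parent[start] = start;  // sentinel: no neighbor equals start
    // BFS over the current component: parents appear before their children.
    for (size_t i = 0; i < order.size(); ++i) {
        int v = order[i];
        for (int u : adj[v])
            if (!removed[u] && u != parent[v]) {
                parent[u] = v;
                order.push_back(u);
            }
    }
    int total = static_cast<int>(order.size());
    // Subtree sizes, accumulated bottom-up in reverse BFS order.
    for (int i = total - 1; i > 0; --i) sz[parent[order[i]]] += sz[order[i]];
    for (int v : order) {
        int largest = total - sz[v];  // the part of the component "above" v
        for (int u : adj[v])
            if (!removed[u] && u != parent[v]) largest = std::max(largest, sz[u]);
        if (2 * largest <= total) return v;  // every remaining piece has <= total/2 vertices
    }
    return start;  // not reached for a valid tree
}
```

In the decomposition, one would call findCentroid, mark the result as removed, process the pairs through it, and then recurse into each neighboring component. Since the component size at least halves at each level, every vertex takes part in O(log N) levels.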

TIME COMPLEXITY:

\mathcal{O}(N \log^2 N) per test case.

CODE:

Setter's code (C++)
                                    //  ॐ
#include <bits/stdc++.h>
using namespace std;
#define PI 3.14159265358979323846
#define ll long long int


const int MOD = 1e9+7;  // check mod
struct mod_int {
    int val;
 
    mod_int(long long v = 0) {
        if (v < 0)
            v = v % MOD + MOD;
 
        if (v >= MOD)
            v %= MOD;
 
        val = v;
    }
 
    static int mod_inv(int a, int m = MOD) {
        int g = m, r = a, x = 0, y = 1;
 
        while (r != 0) {
            int q = g / r;
            g %= r; swap(g, r);
            x -= q * y; swap(x, y);
        }
 
        return x < 0 ? x + m : x;
    }
 
    explicit operator int() const {
        return val;
    }
 
    mod_int& operator+=(const mod_int &other) {
        val += other.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
 
    mod_int& operator-=(const mod_int &other) {
        val -= other.val;
        if (val < 0) val += MOD;
        return *this;
    }
 
    static unsigned fast_mod(uint64_t x, unsigned m = MOD) {
           #if !defined(_WIN32) || defined(_WIN64)
                return x % m;
           #endif
           unsigned x_high = x >> 32, x_low = (unsigned) x;
           unsigned quot, rem;
           asm("divl %4\n"
            : "=a" (quot), "=d" (rem)
            : "d" (x_high), "a" (x_low), "r" (m));
           return rem;
    }
 
    mod_int& operator*=(const mod_int &other) {
        val = fast_mod((uint64_t) val * other.val);
        return *this;
    }
 
    mod_int& operator/=(const mod_int &other) {
        return *this *= other.inv();
    }
 
    friend mod_int operator+(const mod_int &a, const mod_int &b) { return mod_int(a) += b; }
    friend mod_int operator-(const mod_int &a, const mod_int &b) { return mod_int(a) -= b; }
    friend mod_int operator*(const mod_int &a, const mod_int &b) { return mod_int(a) *= b; }
    friend mod_int operator/(const mod_int &a, const mod_int &b) { return mod_int(a) /= b; }
 
    mod_int& operator++() {
        val = val == MOD - 1 ? 0 : val + 1;
        return *this;
    }
 
    mod_int& operator--() {
        val = val == 0 ? MOD - 1 : val - 1;
        return *this;
    }
 
    mod_int operator++(int32_t) { mod_int before = *this; ++*this; return before; }
    mod_int operator--(int32_t) { mod_int before = *this; --*this; return before; }
 
    mod_int operator-() const {
        return val == 0 ? 0 : MOD - val;
    }
 
    bool operator==(const mod_int &other) const { return val == other.val; }
    bool operator!=(const mod_int &other) const { return val != other.val; }
 
    mod_int inv() const {
        return mod_inv(val);
    }
 
    mod_int pow(long long p) const {
        assert(p >= 0);
        mod_int a = *this, result = 1;
 
        while (p > 0) {
            if (p & 1)
                result *= a;
 
            a *= a;
            p >>= 1;
        }
 
        return result;
    }
 
    friend ostream& operator<<(ostream &stream, const mod_int &m) {
        return stream << m.val;
    }
    friend istream& operator >> (istream &stream, mod_int &m) {
        return stream>>m.val;   
    }
};

#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;
struct custom_hash {
    static uint64_t splitmix64(uint64_t x) {
        // http://xorshift.di.unimi.it/splitmix64.c
        x += 0x9e3779b97f4a7c15;
        x = (x ^ (x >> 30)) * 0xbf58476d1ce4e5b9;
        x = (x ^ (x >> 27)) * 0x94d049bb133111eb;
        return x ^ (x >> 31);
    }

    size_t operator()(uint64_t x) const {
        static const uint64_t FIXED_RANDOM = chrono::steady_clock::now().time_since_epoch().count();
        return splitmix64(x + FIXED_RANDOM);
    }
};

gp_hash_table<int, int, custom_hash> mp;

const int N = 2e5+5; 
vector<int> adj[N];
int subtr[N],a[N];
mod_int ans=0;
vector<pair<int,int>> vec;
bool vis[N];

struct FenwickTree {
    vector<mod_int> bit;  // binary indexed tree
    int n;
 
    FenwickTree(int n) {
        this->n = n;
        bit=vector<mod_int>(n,0);
    }
 
    mod_int sum(int r) {
        mod_int ret = 0;
        for (; r >= 0; r = (r & (r + 1)) - 1)
            ret += bit[r];
        return ret;
    }
 
    mod_int sum(int l, int r) {
        return sum(r) - sum(l - 1);
    }
 
    void add(int idx,int delta) {
        for (; idx < n; idx = idx | (idx + 1))
            bit[idx] += delta;
    }
};


int getsz_cd(int v, int p) {
    subtr[v] = 1;
    for (int u : adj[v]) {
        if (vis[u] || u == p)  continue;
        subtr[v] += getsz_cd(u, v);
    }
    return subtr[v];
}
 
int findct_cd(int v, int p, int n) {
    for (int u : adj[v]) {
        if (!vis[u] &&  u!= p && subtr[u] * 2 > n)  return findct_cd(u, v, n);
    }
    return v;
}

void dfs(int v,int p,int mx,int mi){
    
       vec.push_back({mx,mi});
       for(auto u : adj[v]){
           if(u==p || vis[u]){
             continue;
           }
           dfs(u,v,max(a[u],mx),min(mi,a[u]));
       }  
}

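// Sum over all pairs of min(mi) * max(mx).
// Mirrors the editorial with the roles swapped: pairs are processed in
// decreasing order of mx (so the current mx is the pair's maximum), and two
// Fenwick trees over compressed mi values give the count of remaining mi
// greater than the current one and the sum of remaining mi not exceeding it.
inline mod_int solve(vector<pair<int,int>> &v){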
      
      mod_int ret=0;
      sort(v.rbegin(),v.rend());

      vector<int> ord;

      for(auto [mx,mi] : v){
            ord.push_back(mi);
      }

      sort(ord.begin(),ord.end());
      mp.clear();

      int curr=0;

      for(int i=0;i<(int)ord.size();i++){

           int temp=i;
           while(temp+1<(int)ord.size() && ord[temp+1]==ord[i])
               temp++;

           mp[ord[i]]=(curr++); 
           i=temp;     
      }

      FenwickTree ft_mi_sum(curr+1),ft_mi_cnt(curr+1);
 

      for(auto [mx,mi] : v){
           ft_mi_cnt.add(mp[mi],1);
           ft_mi_sum.add(mp[mi],mi);
      }

      for(auto [mx,mi] : v){
          ft_mi_cnt.add(mp[mi],-1);
          ft_mi_sum.add(mp[mi],-mi);
          mod_int temp=ft_mi_cnt.sum(mp[mi]+1,curr);
          temp*=mi;
          temp+=ft_mi_sum.sum(mp[mi]);
          temp*=mx;
          ret+=temp;
      }

      return ret;
}

 
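// Add all pairs among the centroid plus its remaining components (path min/max
// measured from the centroid), subtract the pairs lying inside a single child
// component (same centroid-relative values), then recurse into each component.
void decompose_cd(int u, int p) {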
    int n = getsz_cd(u, p); 
    int ct = findct_cd(u, p, n);
    vector<pair<int,int>> tot;
    tot.push_back({a[ct],a[ct]});
    vis[ct]=1;

    
    for(auto chi : adj[ct]){
        if(vis[chi])
           continue;
        vec.clear();
        dfs(chi,chi,max(a[ct],a[chi]),min(a[ct],a[chi]));
        ans-=solve(vec);
        for(auto u : vec){
            tot.push_back(u);
        }
    }

    ans+=solve(tot);

    for(auto chi : adj[ct]){
         if(!vis[chi])
           decompose_cd(chi,ct);
    }
}

            
int main(){
   
    ios_base::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
 
    int test = 1;
    cin>>test;

    assert(test<=1000);
    int sum_n=0;

    
    while(test--){
                          
                          int n;
                          cin>>n;

                          for(int i=0;i<n;i++){
                              cin>>a[i];
                              adj[i].clear();
                              vis[i]=0;
                          }

                          for(int i=1;i<n;i++){
                              int a,b;
                              cin>>a>>b;
                              --a,--b;
                              adj[a].push_back(b);
                              adj[b].push_back(a);
                          }

                          ans=0;
                          decompose_cd(0,-1);
                          cout<<ans;
                         
                          cout<<'\n';
        
        
   }

        return 0;
}
Tester's code (C++)
//clear adj and visited vector declared globally after each test case
//check for long long overflow   
//For modular-arithmetic questions, add a check at the end, i.e. if ans<0 then ans+=mod;
//Incase of close mle change language to c++17 or c++14  
//Check ans for n=1 
// #pragma GCC target ("avx2")    
// #pragma GCC optimize ("O3")  
// #pragma GCC optimize ("unroll-loops")
#include <bits/stdc++.h>                   
#include <ext/pb_ds/assoc_container.hpp>  
#define int long long     
#define IOS std::ios::sync_with_stdio(false); cin.tie(NULL);cout.tie(NULL);cout.precision(dbl::max_digits10);
#define pb push_back 
#define mod 1000000007ll //998244353ll
#define lld long double
#define mii map<int, int> 
#define pii pair<int, int>
#define ll long long 
#define ff first
#define ss second 
#define all(x) (x).begin(), (x).end()
#define rep(i,x,y) for(int i=x; i<y; i++)    
#define fill(a,b) memset(a, b, sizeof(a))
#define vi vector<int>
#define setbits(x) __builtin_popcountll(x)
#define print2d(dp,n,m) for(int i=0;i<=n;i++){for(int j=0;j<=m;j++)cout<<dp[i][j]<<" ";cout<<"\n";}
typedef std::numeric_limits< double > dbl;
using namespace __gnu_pbds;
using namespace std;
typedef tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update> indexed_set;
//member functions :
//1. order_of_key(k) : number of elements strictly lesser than k
//2. find_by_order(k) : k-th element in the set
const long long N=200005, INF=2000000000000000000;
const int inf=2e9 + 5;
lld pi=3.1415926535897932;
int lcm(int a, int b)
{
    int g=__gcd(a, b);
    return a/g*b;
}
int power(int a, int b, int p)
    {
        if(a==0)
        return 0;
        int res=1;
        a%=p;
        while(b>0)
        {
            if(b&1)
            res=(1ll*res*a)%p;
            b>>=1;
            a=(1ll*a*a)%p;
        }
        return res;
    }
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());

int getRand(int l, int r)
{
    uniform_int_distribution<int> uid(l, r);
    return uid(rng);
}

const int MOD=mod;
struct Mint {
    int val;
 
    Mint(long long v = 0) {
        if (v < 0)
            v = v % MOD + MOD;
        if (v >= MOD)
            v %= MOD;
        val = v;
    }
 
    static int mod_inv(int a, int m = MOD) {
        int g = m, r = a, x = 0, y = 1;
        while (r != 0) {
            int q = g / r;
            g %= r; swap(g, r);
            x -= q * y; swap(x, y);
        } 
        return x < 0 ? x + m : x;
    } 
    explicit operator int() const {
        return val;
    }
    Mint& operator+=(const Mint &other) {
        val += other.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    Mint& operator-=(const Mint &other) {
        val -= other.val;
        if (val < 0) val += MOD;
        return *this;
    }
    static unsigned fast_mod(uint64_t x, unsigned m = MOD) {
           #if !defined(_WIN32) || defined(_WIN64)
                return x % m;
           #endif
           unsigned x_high = x >> 32, x_low = (unsigned) x;
           unsigned quot, rem;
           asm("divl %4\n"
            : "=a" (quot), "=d" (rem)
            : "d" (x_high), "a" (x_low), "r" (m));
           return rem;
    }
    Mint& operator*=(const Mint &other) {
        val = fast_mod((uint64_t) val * other.val);
        return *this;
    }
    Mint& operator/=(const Mint &other) {
        return *this *= other.inv();
    }
    friend Mint operator+(const Mint &a, const Mint &b) { return Mint(a) += b; }
    friend Mint operator-(const Mint &a, const Mint &b) { return Mint(a) -= b; }
    friend Mint operator*(const Mint &a, const Mint &b) { return Mint(a) *= b; }
    friend Mint operator/(const Mint &a, const Mint &b) { return Mint(a) /= b; }
    Mint& operator++() {
        val = val == MOD - 1 ? 0 : val + 1;
        return *this;
    }
    Mint& operator--() {
        val = val == 0 ? MOD - 1 : val - 1;
        return *this;
    }
    // friend Mint operator<=(const Mint &a, const Mint &b) { return (int)a <= (int)b; }
    Mint operator++(int32_t) { Mint before = *this; ++*this; return before; }
    Mint operator--(int32_t) { Mint before = *this; --*this; return before; }
    Mint operator-() const {
        return val == 0 ? 0 : MOD - val;
    }
    bool operator==(const Mint &other) const { return val == other.val; }
    bool operator!=(const Mint &other) const { return val != other.val; }
    Mint inv() const {
        return mod_inv(val);
    }
    Mint power(long long p) const {
        assert(p >= 0);
        Mint a = *this, result = 1;
        while (p > 0) {
            if (p & 1)
                result *= a;
 
            a *= a;
            p >>= 1;
        }
        return result;
    }
    friend ostream& operator << (ostream &stream, const Mint &m) {
        return stream << m.val;
    }
    friend istream& operator >> (istream &stream, Mint &m) {
        return stream>>m.val;   
    }
};


#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;
struct custom_hash {
    static uint64_t splitmix64(uint64_t x) {
        // http://xorshift.di.unimi.it/splitmix64.c
        x += 0x9e3779b97f4a7c15;
        x = (x ^ (x >> 30)) * 0xbf58476d1ce4e5b9;
        x = (x ^ (x >> 27)) * 0x94d049bb133111eb;
        return x ^ (x >> 31);
    }

    size_t operator()(uint64_t x) const {
        static const uint64_t FIXED_RANDOM = chrono::steady_clock::now().time_since_epoch().count();
        return splitmix64(x + FIXED_RANDOM);
    }
};

gp_hash_table<int, int, custom_hash> mp;

vector <int> v[N];
int a[N], siz[N];
Mint ans=0;
vector<pair<int,int>> vec;
bool vis[N];

struct FenwickTree {
    vector<Mint> bit;  // binary indexed tree
    int n;
 
    FenwickTree(int n) {
        this->n = n;
        bit=vector<Mint>(n,0);
    }
 
    Mint sum(int r) {
        Mint ret = 0;
        for (; r >= 0; r = (r & (r + 1)) - 1)
            ret += bit[r];
        return ret;
    }
 
    Mint sum(int l, int r) {
        return sum(r) - sum(l - 1);
    }
 
    void add(int idx,int delta) {
        for (; idx < n; idx = idx | (idx + 1))
            bit[idx] += delta;
    }
};


int getsz_cd(int u, int p)
{
    siz[u]=1;
    for (int to : v[u])
    {
        if(vis[to] || to==p)
            continue;
        siz[u]+=getsz_cd(to, u);
    }
    return siz[u];
}
 
int findct_cd(int u, int p, int n)
{
    for(int to:v[u])
    {
        if(!vis[to] &&  to!=p && (siz[to]*2)>n) 
            return findct_cd(to, u, n);
    }
    return u;
}

void dfs(int u, int p, int mx, int mi)
{
    vec.pb({mx, mi});
    for(auto to : v[u])
    {
        if(to==p || vis[to])
            continue;
        dfs(to, u, max(a[to], mx), min(mi, a[to]));
    }  
}

Mint solve(vector <pii> &cur)
{
    Mint res=0;
    sort(cur.rbegin(), cur.rend());
    vi ord;

    for(pii p : cur)
        ord.push_back(p.ss);

    sort(all(ord));
    mp.clear();
    int curr=0;
    for(int i=0;i<(int)ord.size();i++){
        int temp=i;
        while(temp+1<(int)ord.size() && ord[temp+1]==ord[i])
            temp++;
        mp[ord[i]]=(curr++); 
        i=temp;     
    }
    FenwickTree ft_mi_sum(curr+1), ft_mi_cnt(curr+1);

    for(pii p:cur){
        ft_mi_cnt.add(mp[p.ss],1);
        ft_mi_sum.add(mp[p.ss], p.ss);
    }

    for(pii p:cur){
        int mx=p.ff, mi=p.ss;
        ft_mi_cnt.add(mp[mi],-1);
        ft_mi_sum.add(mp[mi],-mi);
        Mint temp=ft_mi_cnt.sum(mp[mi]+1,curr);
        temp*=mi;
        temp+=ft_mi_sum.sum(mp[mi]);
        temp*=mx;
        res+=temp;
    }
    return res;
}

 
void decompose_cd(int u, int p)
{
    int n = getsz_cd(u, p); 
    int ct = findct_cd(u, p, n);
    vector <pii> tot;
    tot.pb({a[ct],a[ct]});
    vis[ct]=1;

    for(auto to:v[ct])
    {
        if(vis[to])
           continue;
        vec.clear();
        dfs(to,to,max(a[ct],a[to]),min(a[ct],a[to]));
        ans-=solve(vec);
        for(auto u : vec)
            tot.push_back(u);
    }
    ans+=solve(tot);
    for(auto to:v[ct])
    {
        if(!vis[to])
            decompose_cd(to,ct);
    }
}

            
int32_t main()
{   
    IOS;
    int t;
    cin>>t;
    while(t--)
    {    
        int n;
        cin>>n;
        fill(vis, false);
        rep(i,1,n+1)
        {
            cin>>a[i];
            v[i].clear();
            // vis[i]=false;
        }
        rep(i,1,n)
        {
            int a, b;
            cin>>a>>b;
            v[a].pb(b);
            v[b].pb(a);
        }
        ans=0;
        decompose_cd(1, 0);
        cout<<ans<<"\n";
   }
}