TREESAREFUN5 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: satyam_343
Testers: apoorv_me, tabr
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Trees, Euler tour, segment trees/Fenwick trees

PROBLEM:

Define f(B) for an array B as follows:

  • While B is not empty, choose a value x and delete all occurrences of x from B.
    The cost of this operation is R-L, where L and R are the indices of the leftmost and rightmost occurrences of x in the current array.

f(B) is the minimum possible cost of performing this process on B.
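To experiment with the definition, here is a minimal brute-force sketch. The name fBrute is ours and positions are 0-indexed. It tries every order of deleting the distinct values, pays R-L in the current array for each deletion, and keeps the minimum. It is exponential in the number of distinct values, so it is only meant for checking small cases.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <vector>

// f(B) straight from the definition: try every order in which the distinct
// values can be deleted and keep the cheapest. Exponential, tiny inputs only.
long long fBrute(const std::vector<int>& B) {
    std::vector<int> vals = B;
    std::sort(vals.begin(), vals.end());
    vals.erase(std::unique(vals.begin(), vals.end()), vals.end());
    assert(vals.size() <= 8);  // 8! orders is already plenty for a sanity check

    long long best = LLONG_MAX;
    do {  // vals starts sorted, so next_permutation visits every order exactly once
        std::vector<int> cur = B;
        long long cost = 0;
        for (int x : vals) {
            int L = -1, R = -1;  // leftmost/rightmost index of x in the current array
            for (int i = 0; i < (int)cur.size(); ++i)
                if (cur[i] == x) { if (L == -1) L = i; R = i; }
            cost += R - L;
            cur.erase(std::remove(cur.begin(), cur.end(), x), cur.end());  // delete all copies of x
        }
        best = std::min(best, cost);
    } while (std::next_permutation(vals.begin(), vals.end()));
    return best;
}
```

For example, fBrute({1, 2, 2, 1}) is 2: deleting 2 first costs 1, after which the two 1's are adjacent and cost 1 more.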

You’re given a tree with 2N vertices, each carrying a value; every value from 1 to N appears on exactly two vertices.
Find the sum of f(P[u, v]) across all u \leq v, where P[u, v] denotes the sequence of values on the unique path from u to v, modulo 998244353.

EXPLANATION:

First, let’s figure out how to compute f(B) for a fixed array B.
We only need to handle arrays that contain each element at most twice, since every value appears exactly twice in the tree and hence at most twice on any path.

Let l_x and r_x (l_x \leq r_x) denote the positions of the occurrences of x in B (if there is only one occurrence, l_x = r_x).
This gives one segment [l_x, r_x] of B for every value x.
It can be proved that f(B) is attained by deleting the values in increasing order of their original segment lengths r_x - l_x (ties broken arbitrarily).
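A minimal sketch of this greedy, under the assumption that each value occurs at most twice (the name fGreedy is hypothetical). Values are sorted by their original segment length, and each deletion is charged R-L in the array as it currently stands, exactly as in the definition.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <utility>
#include <vector>

// Greedy: delete values in increasing order of their original segment length
// r_x - l_x, paying R - L in the *current* array each time.
// Assumes every value occurs at most twice in B.
long long fGreedy(const std::vector<int>& B) {
    std::map<int, std::pair<int, int>> span;  // value -> (l_x, r_x), 0-indexed
    for (int i = 0; i < (int)B.size(); ++i) {
        auto it = span.find(B[i]);
        if (it == span.end()) {
            span.emplace(B[i], std::make_pair(i, i));
        } else {
            assert(it->second.first == it->second.second);  // at most two occurrences
            it->second.second = i;
        }
    }
    std::vector<std::pair<int, int>> order;  // (segment length, value)
    for (const auto& kv : span)
        order.push_back({kv.second.second - kv.second.first, kv.first});
    std::sort(order.begin(), order.end());   // shortest segment first, ties by value

    std::vector<int> cur = B;
    long long cost = 0;
    for (const auto& p : order) {
        int x = p.second, L = -1, R = -1;
        for (int i = 0; i < (int)cur.size(); ++i)
            if (cur[i] == x) { if (L == -1) L = i; R = i; }
        cost += R - L;                                                 // cost in the current array
        cur.erase(std::remove(cur.begin(), cur.end(), x), cur.end());  // delete every copy of x
    }
    return cost;
}
```

On [1, 3, 2, 1, 2] it deletes 3, then 2, then 1, for a total cost of 0 + 2 + 1 = 3.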

Proof

Take any deletion order, and suppose the value x is deleted immediately before the value y.
Swapping x and y does not change the cost of deleting any other value, and:

  • If the two segments are disjoint, the order in which they’re deleted doesn’t matter.
    So, we can always just delete the shorter one first.
  • If one segment is completely contained inside the other one, it’s strictly better to delete the inner one first: this reduces the cost of deleting the outer one by the number of occurrences of the inner value (2, or 1 if it occurs only once), while the inner one’s cost is unchanged.
  • If the segments cross (they overlap, but neither contains the other), it again doesn’t matter which one is deleted first: the total cost remains the same.
    This is because exactly one endpoint of the segment deleted later lies inside the segment deleted earlier, contributing a cost of 1 to it either way.
    Again, we can just delete the shorter segment first.

Applying these exchanges repeatedly turns any optimal order into one sorted by segment length without increasing the cost.

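Note that the above proof also gives us a closed form for f(B). When values are deleted in increasing order of segment length, every segment nested inside the segment of x is already gone when x is deleted, so deleting a value x that occurs twice costs 1 plus the number of values whose segments cross that of x and are deleted after it. Values occurring only once cost nothing. Summing up, f(B) equals the number of values that occur twice in B, plus the number of pairs of values whose segments cross, i.e. pairs x, y with l_x < l_y < r_x < r_y (containment is not considered crossing here).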
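The closed form can be evaluated directly in O(k^2) for k distinct values. A minimal sketch (fFormula is a hypothetical name), assuming each value occurs at most twice; every crossing pair is counted once, from the segment that starts first.

```cpp
#include <cassert>
#include <map>
#include <utility>
#include <vector>

// f(B) via the characterization: (# values occurring twice) + (# crossing pairs),
// where x and y cross iff l_x < l_y < r_x < r_y. Assumes each value occurs at most twice.
long long fFormula(const std::vector<int>& B) {
    std::map<int, std::pair<int, int>> span;  // value -> (l_x, r_x)
    for (int i = 0; i < (int)B.size(); ++i) {
        auto it = span.find(B[i]);
        if (it == span.end()) {
            span.emplace(B[i], std::make_pair(i, i));
        } else {
            assert(it->second.first == it->second.second);  // at most two occurrences
            it->second.second = i;
        }
    }
    std::vector<std::pair<int, int>> segs;  // only values occurring twice; single ones cost 0
    for (const auto& kv : span)
        if (kv.second.first < kv.second.second) segs.push_back(kv.second);

    long long res = (long long)segs.size();  // each such value pays 1 for itself
    for (const auto& s : segs)
        for (const auto& t : segs)
            if (s.first < t.first && t.first < s.second && s.second < t.second)
                ++res;  // crossing pair, counted once (s is the one starting first)
    return res;
}
```

For [1, 3, 2, 1, 2], two values occur twice and their segments [0, 3] and [2, 4] cross, giving 2 + 1 = 3.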


Using this characterization, let’s move to computing the answer on a tree.
Root the tree at vertex 1, and let \text{subsz}[u] denote the subtree size of vertex u.
Let u_x and v_x denote the two vertices labelled x.

To compute the answer, we’ll compute contributions.

  • For each value x, count the number of paths that contain both u_x and v_x.
    x adds 1 to the answer of all such paths.
    • This is easy: if u_x and v_x aren’t ancestors/descendants of each other the number of paths is \text{subsz}[u_x] \cdot \text{subsz}[v_x].
      Otherwise, assuming u_x is the lower vertex, it’s \text{subsz}[u_x] \cdot (2N - \text{subsz}[c]), where c is the child of v_x that contains u_x.
  • Then, for each pair of values x and y whose four vertices lie on a common path with their segments crossing, we’d like to count the number of paths that contain all four of these vertices; each such path gains an extra 1 from this pair.
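A small sketch of the first part, assuming 0-indexed vertices rooted at vertex 0 (the editorial roots at vertex 1) and the hypothetical names RootedTree/pathsThrough. Here n is the total number of vertices, i.e. 2N. To keep it short, the child c is found by walking up parent pointers; the reference codes below find it faster (binary lifting in the author’s code, a binary search over the children’s DFS entry times in the editorialist’s).

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical helper: tree rooted at vertex 0, with Euler-tour times and subtree sizes.
struct RootedTree {
    int n;  // number of vertices (2N in the problem)
    std::vector<std::vector<int>> adj;
    std::vector<int> par, tin, tout;
    std::vector<long long> subsz;

    RootedTree(int n_, const std::vector<std::pair<int, int>>& edges)
        : n(n_), adj(n_), par(n_, -1), tin(n_), tout(n_), subsz(n_, 1) {
        for (const auto& e : edges) {
            adj[e.first].push_back(e.second);
            adj[e.second].push_back(e.first);
        }
        int timer = 0;
        dfs(0, -1, timer);
    }

    void dfs(int u, int p, int& timer) {
        par[u] = p;
        tin[u] = timer++;
        for (int w : adj[u])
            if (w != p) { dfs(w, u, timer); subsz[u] += subsz[w]; }
        tout[u] = timer;  // subtree of u occupies [tin[u], tout[u])
    }

    // true if a is an ancestor of b (a == b counts)
    bool isAnc(int a, int b) const { return tin[a] <= tin[b] && tin[b] < tout[a]; }

    // Number of paths (unordered endpoints) whose vertex set contains both u and v.
    long long pathsThrough(int u, int v) const {
        assert(u != v);
        if (isAnc(u, v)) std::swap(u, v);   // if related, make u the lower vertex
        if (!isAnc(v, u)) return subsz[u] * subsz[v];
        int c = u;                           // child of v on the way down to u
        while (par[c] != v) c = par[c];      // O(depth), fine for a sketch
        return subsz[u] * (n - subsz[c]);
    }
};
```

On the tree with edges 0-1, 0-2, 1-3, 1-4, 2-5, pathsThrough(0, 1) is 3 · 3 = 9: one endpoint lies in {1, 3, 4} and the other in {0, 2, 5}.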

The second part can be solved in several ways; one of them is detailed below.

Let’s fix a value x, and look at its vertices u_x and v_x.

Observe that for the segments of x and y to cross, one of u_y and v_y must lie on the path (u_x, v_x), and the other must lie beyond one of its ends: within either the subtree of u_x or the subtree of v_x (assuming for now that neither of u_x, v_x is an ancestor of the other).
In any other configuration, no path contains all four of these vertices in crossing order.
Without loss of generality, let’s say u_y lies on the (u_x, v_x) path.
Then,

  • If v_y lies in the subtree of u_x, the number of paths containing all four vertices is \text{subsz}[v_y] \cdot \text{subsz}[v_x].
    • We’d like to sum this across all valid y, so we really want to know the sum of \text{subsz}[v_y] across all such vertices.
  • If v_y lies in the subtree of v_x, the number of paths is \text{subsz}[v_y] \cdot \text{subsz}[u_x].
    • Again, we only need the sum of \text{subsz}[v_y] across all such vertices.

Note that if one of u_x and v_x is an ancestor of the other, the analysis is slightly different: only the subtree of the lower vertex needs to be considered, and the multiplier changes too, becoming 2N - \text{subsz}[c], where c is the child of the higher vertex that contains the lower one.

Now, since the vertices of the (u_x, v_x) path can be expressed through root → vertex paths by inclusion-exclusion (with the help of the LCA of u_x and v_x and its parent), we only really need to be able to answer the following query:

  • Given two vertices a and b, find the sum of \text{subsz}[v_y] across all v_y such that v_y lies in the subtree of b, and u_y lies on the root → a path.
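To make the reduction explicit, write Q(a, b) for the query above. For a path (s, t), let l = \text{LCA}(s, t) and let p be the parent of l, with Q(p, b) = 0 when l is the root (S, s, t, l and p are our own notation). The root → s and root → t paths overlap exactly on the root → l path, so inclusion-exclusion gives the first identity below; the second one handles only the part of the path from l down to s.

```latex
\begin{aligned}
S(s, t; b) &= \sum_{\substack{u_y \in \text{path}(s, t) \\ v_y \in \text{subtree}(b)}} \text{subsz}[v_y]
            = Q(s, b) + Q(t, b) - Q(l, b) - Q(p, b), \\
S(l \to s; b) &= Q(s, b) - Q(p, b).
\end{aligned}
```

Each vertex strictly above l appears in all four root paths and cancels (1 + 1 - 1 - 1 = 0), while l itself is counted 1 + 1 - 1 = 1 time.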

Queries of this type can be answered offline using an Euler tour and a point-add range-sum data structure (a segment tree or Fenwick tree), as follows:

  • Let \text{in}[u] and \text{out}[u] denote the DFS entry and exit times of vertex u, with \text{out}[u] exclusive, so the subtree of u occupies positions \text{in}[u] through \text{out}[u] - 1.
  • Perform a DFS starting from the root.
    When entering vertex u,
    • Let \text{oth}[u] be the other vertex that has the same label as u.
    • Assign the value \text{subsz}[\text{oth}[u]] to position \text{in}[\text{oth}[u]] in the data structure.
    • Then, answer all queries involving the root → u path, i.e. those with a = u.
      • For the subtree of vertex b, this simply corresponds to a range sum over positions \text{in}[b] to \text{out}[b] - 1 in the data structure.
    • Once all queries at u have been processed, continue the DFS into the children of u.
    • Finally, when leaving u, set the value at position \text{in}[\text{oth}[u]] back to 0.

This way, at every point of the DFS, the only nonzero positions in the data structure are \text{in}[\text{oth}[w]] for vertices w on the path from the root to the current vertex, which is why a simple subtree range sum suffices.
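A minimal sketch of this offline DFS, assuming 0-indexed vertices with root 0, a Fenwick tree as the point-add range-sum structure, and hypothetical names (Fenwick, Query, answerOffline). Instead of overwriting the position with 0 when leaving u, it subtracts the value it added. This is equivalent because each position \text{in}[\text{oth}[u]] is written by exactly one vertex u.

```cpp
#include <cassert>
#include <vector>

// Fenwick tree over positions 0..n-1 (point add, half-open range sum).
struct Fenwick {
    std::vector<long long> t;
    explicit Fenwick(int n) : t(n + 1, 0) {}
    void add(int i, long long d) {
        for (++i; i < (int)t.size(); i += i & -i) t[i] += d;
    }
    long long prefix(int i) const {  // sum over positions [0, i)
        long long s = 0;
        for (; i > 0; i -= i & -i) s += t[i];
        return s;
    }
    long long range(int l, int r) const { return prefix(r) - prefix(l); }  // [l, r)
};

struct Query { int a, b; };  // hypothetical: root -> a path, subtree of b

// For each query (a, b): sum of subsz[oth[w]] over vertices w on the root -> a path
// whose partner oth[w] lies in the subtree of b. The tree is rooted at vertex 0.
std::vector<long long> answerOffline(const std::vector<std::vector<int>>& adj,
                                     const std::vector<int>& oth,
                                     const std::vector<Query>& qs) {
    int n = (int)adj.size();
    assert((int)oth.size() == n);
    std::vector<int> tin(n), tout(n);
    std::vector<long long> subsz(n, 1);
    int timer = 0;
    auto euler = [&](auto&& self, int u, int p) -> void {
        tin[u] = timer++;
        for (int w : adj[u])
            if (w != p) { self(self, w, u); subsz[u] += subsz[w]; }
        tout[u] = timer;  // subtree of u occupies [tin[u], tout[u])
    };
    euler(euler, 0, -1);

    std::vector<std::vector<int>> at(n);  // query indices grouped by path endpoint a
    for (int i = 0; i < (int)qs.size(); ++i) at[qs[i].a].push_back(i);

    Fenwick fw(n);
    std::vector<long long> ans(qs.size(), 0);
    auto walk = [&](auto&& self, int u, int p) -> void {
        fw.add(tin[oth[u]], subsz[oth[u]]);   // u is now on the root -> current path
        for (int i : at[u]) ans[i] = fw.range(tin[qs[i].b], tout[qs[i].b]);
        for (int w : adj[u])
            if (w != p) self(self, w, u);
        fw.add(tin[oth[u]], -subsz[oth[u]]);  // u leaves the path: position back to 0
    };
    walk(walk, 0, -1);
    return ans;
}
```

Each vertex causes two Fenwick updates and each query one range sum, so the whole pass runs in O((V + Q) log V) for V vertices and Q queries.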

TIME COMPLEXITY:

\mathcal{O}(N\log N) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h> 
using namespace std;
#define ll long long
#define nline "\n"
#define all(x) x.begin(),x.end()
const ll MAX=500500;
const ll till=20;
const ll MOD=998244353;
struct FenwickTree{
    vector<ll> bit; 
    ll n;
    FenwickTree(ll n){
        this->n = n;
        bit.assign(n, 0);
    }
    FenwickTree(vector<ll> a):FenwickTree(a.size()){  
        ll x=a.size();
        for(size_t i=0;i<x;i++)
            add(i,a[i]);
    }
    ll sum(ll r) {
        ll ret=0;
        for(;r>=0;r=(r&(r+1))-1){
            ret+=bit[r];
        }
        return ret;
    }
    ll sum(ll l,ll r) {
        if(l>r)
            return 0;
        return sum(r)-sum(l-1);
    }
    void add(ll idx,ll delta) {
        for(;idx<n;idx=idx|(idx+1)){
            bit[idx]+=delta;
        }
    }
};
vector<ll> track;
vector<vector<ll>> adj;
ll subtree[MAX],tin[MAX],tout[MAX],now=1;
ll jump[MAX][till];
void dfs(ll cur,ll par){
	subtree[cur]=1;
	tin[cur]=now++;
	track.push_back(cur);
	jump[cur][0]=par;
	for(ll i=1;i<till;i++){
		jump[cur][i]=jump[jump[cur][i-1]][i-1];
	}
	for(auto chld:adj[cur]){
		if(chld!=par){
			dfs(chld,cur);
			subtree[cur]+=subtree[chld];
		}
	}
	tout[cur]=now++;
	track.push_back(-cur);
}
ll is_ancestor(ll v,ll u){
	if(tin[v]<=tin[u] and tout[v]>=tout[u]){
		return 1;
	}
	else{
		return 0;
	}
}
ll binpow(ll a,ll b,ll MOD){
    ll ans=1;
    a%=MOD;  
    while(b){
        if(b&1)
            ans=(ans*a)%MOD;
        b/=2;
        a=(a*a)%MOD;
    }
    return ans;
}
ll inverse(ll a,ll MOD){
    return binpow(a,MOD-2,MOD);
} 
void solve(){
	ll n; cin>>n;
	now=1;
	track={0};
	ll len=2*n;
	adj=vector<vector<ll>>();
	adj.resize(len+5);
	vector<ll> a(len+5),l(len+5,len),r(len+5);
	for(ll i=1;i<=len;i++){
		cin>>a[i];
		l[a[i]]=min(l[a[i]],i);
		r[a[i]]=i;
	}
	for(ll i=1;i<len;i++){
	    ll u,v; cin>>u>>v;
	    adj[u].push_back(v);
	    adj[v].push_back(u);
	}
	dfs(1,1);
	ll ans=0;
	auto get_almost_lca=[&](ll u,ll anot){
		if(is_ancestor(u,anot)){
			assert(0);
		}
		for(ll b=till-1;b>=0;b--){
			if(!is_ancestor(jump[u][b],anot)){
				u=jump[u][b];
			}
		}
		return u;
	};
	FenwickTree ft(now+5);
	vector<ll> marked(n+5,0);
	vector<ll> from_self(n+5,0),from_anot(n+5,0);
	for(ll i=1;i<now;i++){
		ll node=abs(track[i]);
		ll anot=l[a[node]];
		if(anot==node){
			anot=r[a[node]];
		}
		if(is_ancestor(node,anot)){
			continue;
		}
		ll mul=1;
		if(is_ancestor(anot,node)){
			mul=len-subtree[get_almost_lca(node,anot)];
		}
		else{
			mul=subtree[anot];
		}
		ll lc=jump[get_almost_lca(node,anot)][0];
		ll sum=2ll*ft.sum(tin[lc],tin[node]-1)+(1+is_ancestor(anot,node))*ft.sum(tin[lc]+1,tin[anot]-1);
		if(anot==lc){
			sum-=2ll*ft.sum(tin[lc],tin[lc]);
		}
		ll val=a[node];
		if(track[i]>=1){ // entering node (positive entry of the Euler track)
			if(!marked[a[node]]){
				ans=(ans+2ll*mul*subtree[node])%MOD;
				marked[a[node]]=1;
				from_self[val]+=2ll*mul*subtree[node];
			}
			ans=(ans-sum*mul)%MOD;
			from_anot[val]-=sum*mul;
			ft.add(tin[anot],subtree[node]);
			ft.add(tout[anot],-subtree[node]);
		}
		else{ // leaving node (negative entry of the Euler track)
			ans=(ans+sum*mul)%MOD;
			from_anot[val]+=sum*mul;
		}
	}
	
	ans=(ans+MOD)%MOD;
	ans=(ans*inverse(2,MOD))%MOD;
	cout<<ans<<nline;
}
int main()                                                                                 
{         
  ios_base::sync_with_stdio(false);                         
  cin.tie(NULL);                                  
  ll test_cases=1;                 
  cin>>test_cases;
  while(test_cases--){
      solve();
  }
  cout<<fixed<<setprecision(10);
  cerr<<"Time:"<<1000*((double)clock())/(double)CLOCKS_PER_SEC<<"ms\n"; 
} 
Tester's code (apoorv_me, C++)
#include<bits/stdc++.h>
using namespace std;

#ifdef LOCAL
#include "../debug.h"
#else
#define dbg(...)
#endif

template<class T>
struct RMQ{
    int n, logn;
    vector<vector<int>> b;
    vector<T> A;
    void build(const vector<T> &a) {
        A = a, n = (int)a.size();
        logn = 32 - __builtin_clz(n);
        b.resize(logn, vector<int>(n));
        iota(b[0].begin(), b[0].end(), 0);
        for(int i = 1; i < logn ; i++){
            for(int j = 0; j < n ; j++){
                b[i][j] = b[i - 1][j];
                if(j + (1 << (i - 1)) < n && A[b[i - 1][j + (1 << (i - 1))]] < A[b[i][j]])
                    b[i][j] = b[i - 1][j + (1 << (i - 1))];
            }
        }
    }
    int rangeMin(int x, int y){
        int k = 31 - __builtin_clz(y - x + 1);
        return min(A[b[k][x]], A[b[k][y - (1 << k) + 1]]);
    }
    int minIndx(int x, int y){
        int k = 31 - __builtin_clz(y - x + 1);
        return A[b[k][x]] < A[b[k][y - (1 << k) + 1]] ? b[k][x] : b[k][y - (1 << k) + 1];
    }
};

namespace mint_ns {
template<auto P>
struct Modular {
    using value_type = decltype(P);
    value_type value;
 
    Modular(long long k = 0) : value(norm(k)) {}
 
    friend Modular<P>& operator += (      Modular<P>& n, const Modular<P>& m) { n.value += m.value; if (n.value >= P) n.value -= P; return n; }
    friend Modular<P>  operator +  (const Modular<P>& n, const Modular<P>& m) { Modular<P> r = n; return r += m; }
 
    friend Modular<P>& operator -= (      Modular<P>& n, const Modular<P>& m) { n.value -= m.value; if (n.value < 0)  n.value += P; return n; }
    friend Modular<P>  operator -  (const Modular<P>& n, const Modular<P>& m) { Modular<P> r = n; return r -= m; }
    friend Modular<P>  operator -  (const Modular<P>& n)                      { return Modular<P>(-n.value); }
 
    friend Modular<P>& operator *= (      Modular<P>& n, const Modular<P>& m) { n.value = n.value * 1ll * m.value % P; return n; }
    friend Modular<P>  operator *  (const Modular<P>& n, const Modular<P>& m) { Modular<P> r = n; return r *= m; }
 
    friend Modular<P>& operator /= (      Modular<P>& n, const Modular<P>& m) { return n *= m.inv(); }
    friend Modular<P>  operator /  (const Modular<P>& n, const Modular<P>& m) { Modular<P> r = n; return r /= m; }
 
    Modular<P>& operator ++ (   ) { return *this += 1; }
    Modular<P>& operator -- (   ) { return *this -= 1; }
    Modular<P>  operator ++ (int) { Modular<P> r = *this; *this += 1; return r; }
    Modular<P>  operator -- (int) { Modular<P> r = *this; *this -= 1; return r; }
 
    friend bool operator == (const Modular<P>& n, const Modular<P>& m) { return n.value == m.value; }
    friend bool operator != (const Modular<P>& n, const Modular<P>& m) { return n.value != m.value; }
 
    explicit    operator       int() const { return value; }
    explicit    operator      bool() const { return value; }
    explicit    operator long long() const { return value; }
 
    constexpr static value_type mod()      { return     P; }
 
    value_type norm(long long k) {
        if (!(-P <= k && k < P)) k %= P;
        if (k < 0) k += P;
        return k;
    }
 
    Modular<P> inv() const {
        value_type a = value, b = P, x = 0, y = 1;
        while (a != 0) { value_type k = b / a; b -= k * a; x -= k * y; swap(a, b); swap(x, y); }
        return Modular<P>(x);
    }
 
    friend void __print(Modular<P> v) {
        cerr << v.value;
    }
};
template<auto P> Modular<P> pow(Modular<P> m, long long p) {
    Modular<P> r(1);
    while (p) {
        if (p & 1) r *= m;
        m *= m;
        p >>= 1;
    }
    return r;
}
 
template<auto P> ostream& operator << (ostream& o, const Modular<P>& m) { return o << m.value; }
template<auto P> istream& operator >> (istream& i,       Modular<P>& m) { long long k; i >> k; m.value = m.norm(k); return i; }
template<auto P> string   to_string(const Modular<P>& m) { return to_string(m.value); }
 
}
constexpr int mod = 998244353;
using mod_int = mint_ns::Modular<mod>;


struct LCA {
    vector<int> tour, Findx, dpth;
    RMQ<int> rmq;
    void build(const vector<vector<int>> &adj, int src = 0) {
        vector<bool> vis((int)adj.size());
        vector<int> dpth1((int)adj.size());
        function<void(int, int)> dfs = [&](int i, int d) {
            tour.push_back(i);
            vis[i] = 1;
            dpth1[i] = d;
            for(auto &u: adj[i]) if(!vis[u])    dfs(u, d + 1), tour.push_back(i);
        };
        dfs(src, 0);
        Findx.resize((int)adj.size());
        dpth.resize((int)tour.size());
        for(int i = (int)tour.size() - 1 ; i >= 0 ; i--) {
            dpth[i] = dpth1[tour[i]], Findx[tour[i]] = i;
        }
        rmq.build(dpth);
    }
    int lca(int x, int y) {
        x = Findx[x], y = Findx[y];
        if(x > y)     swap(x, y);
        return tour[rmq.minIndx(x, y)];
    }
    int dist(int x, int y) {
        x = Findx[x], y = Findx[y];
        if(x > y)     swap(x, y);
        return dpth[x] + dpth[y] - 2 * rmq.rangeMin(x, y);
    }
};

template <typename T>
struct BIT {
  int N;
  vector<T> tree;

  BIT(int n_) : N(n_ + 1) {
    tree.resize(N);
  }

  void update(int ind, T val) {
    for(++ind ; ind < N ; ind += ind & -ind)	tree[ind] += val;
  }

  T query(int in) {
    T sum = 0;
    for(++in ; in > 0 ; in -= in & -in)	sum += tree[in];
    return sum;
  }

  T query(int l, int r) {
    return query(r) - query(l - 1);
  }
};

struct HLD{
    vector<int> parent, depth, heavy, head, pos;
    int cur_pos;
    BIT<mod_int> bit;

    int dfs(int v, vector<vector<int>> const& adj) {
        int size = 1;
        int max_c_size = 0;
        // dbg(v, parent[v]);
        assert(parent[0] == -1);
        for (int c : adj[v]) {
            if (c != parent[v]) {
                parent[c] = v, depth[c] = depth[v] + 1;
                int c_size = dfs(c, adj);
                size += c_size;
                if (c_size > max_c_size)
                    max_c_size = c_size, heavy[v] = c;
            }
        }
        return size;
    }

    void decompose(int v, int h, vector<vector<int>> const& adj) {
        head[v] = h, pos[v] = cur_pos++;
        if (heavy[v] != -1)
            decompose(heavy[v], h, adj);
        for (int c : adj[v]) {
            if (c != parent[v] && c != heavy[v])
                decompose(c, c, adj);
        }
    }

    HLD(vector<vector<int>> const& adj) : bit(adj.size()){
        int n = adj.size();
        parent = vector<int>(n, -1);
        depth = vector<int>(n);
        heavy = vector<int>(n, -1);
        head = vector<int>(n);
        pos = vector<int>(n);
        cur_pos = 0;
        dbg(parent);
        dfs(0, adj);
        decompose(0, 0, adj);
    }

    mod_int query(int a) {
      return bit.query(pos[a]);
    }

    void update(int a, int b, mod_int v) {
        for (; head[a] != head[b]; b = parent[head[b]]) {
            if (depth[head[a]] > depth[head[b]])
                swap(a, b);
            bit.update(pos[head[b]], v);
            bit.update(pos[b] + 1, -v);
        }
        if (depth[a] > depth[b])
            swap(a, b);
        bit.update(pos[a], v);
        bit.update(pos[b] + 1, -v);
    }
};

struct BLD{
  int N;
  vector<vector<int>> adj;
  LCA lca;
  vector<mod_int> A;
  BLD(vector<vector<int>> &g) {
    adj = g;
    N = g.size();
    lca.build(adj);
    A.resize(g.size());
  }

  void update(int u, int v, mod_int val) {
    for(int i = 0 ; i < N ; ++i) {
      if(lca.dist(i, u) + lca.dist(i, v) == lca.dist(u, v))
        A[i] += val;
    }
  }
  mod_int query(int v) {
    return A[v];
  }
};

int32_t main() {
  ios_base::sync_with_stdio(0);
  cin.tie(0);

  auto __solve_testcase = [&](int test) {
    int N;  cin >> N;
    vector<vector<int>> adj(2 * N), col(N);
    for(int i = 0 ; i < 2 * N ; ++i) {
      int x;  cin >> x;
      col[x - 1].push_back(i);
    }
    for(int i = 1 ; i < 2 * N ; ++i) {
      int u, v; cin >> u >> v;
      adj[u - 1].push_back(v - 1);
      adj[v - 1].push_back(u - 1);
    }

    LCA lca;  lca.build(adj);

    vector<int> other(2 * N);
    for(int i = 0 ; i < N ; ++i) {
      for(int j = 0 ; j < 2 ; ++j) {
        other[col[i][j]] = col[i][j ^ 1];
      }
    }

    vector<int> sub(2 * N, 1), dep(2 * N), in(2 * N), out(2 * N);
    int cnt = 0;
    auto getd = [&](auto &&getd, int node, int par) -> void {
      in[node] = cnt++;
      for(auto &u: adj[node]) if(u != par) {
        dep[u] = dep[node] + 1;
        getd(getd, u, node);
        sub[node] += sub[u];
      }
      out[node] = cnt;
    };
    getd(getd, 0, -1);

    mod_int ans = 0;
    vector<bool> vis(2 * N);

    HLD hld(adj);
    N *= 2;
    vector<int> path;
    path.reserve(N);
    auto dfs = [&](auto &&dfs, int node) -> void {
      vis[node] = 1;
      bool upd = false;
      path.push_back(node);
      if(vis[other[node]]) {
        dbg(node, other[node]);
        upd = true;
        ans += sub[node] * hld.query(other[node]);
        dbg(ans);
        if(lca.lca(node, other[node]) == other[node]) {
          ans += sub[node] * 1ll * (N - sub[path[dep[other[node]] + 1]]);
          hld.update(node, other[node], N - sub[path[dep[other[node]] + 1]]);
        } else {
          ans += sub[node] * 1ll * sub[other[node]];
          hld.update(node, other[node], sub[other[node]]);
        }
        dbg(ans);
      }
      for(auto &u: adj[node]) if(!vis[u]) {
        dfs(dfs, u);
        if(in[other[node]] >= in[u] && out[other[node]] <= out[u]) {
          if(lca.lca(node, other[node]) == node) {
            hld.update(node, other[node], sub[other[node]]);        
          }
        }
      }
      if(upd) {
        if(lca.lca(node, other[node]) == other[node]) {
          hld.update(node, other[node], sub[path[dep[other[node]] + 1]] - N);
        } else {
          hld.update(node, other[node], -sub[other[node]]);
        }
      }
      path.pop_back();
    };
    dfs(dfs, 0);

    cout << ans << '\n';

  };
  
  int NumTest = 1;
  cin >> NumTest;
  for(int testno = 1; testno <= NumTest ; ++testno) {
    __solve_testcase(testno);
  }
  
  return 0;
}

Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

/**
 * Integers modulo p, where p is a prime
 * Source: Aeren (modified from tourist?)
 *         Modmul for 64-bit mod from kactl:ModMulLL
 * Works with p < 7.2e18 with x87 80-bit long double, and p < 2^52 ~ 4.5e12 with 64-bit
 */
template<typename T>
struct Z_p{
    using Type = typename decay<decltype(T::value)>::type;
    static vector<Type> MOD_INV;
    constexpr Z_p(): value(){ }
    template<typename U> Z_p(const U &x){ value = normalize(x); }
    template<typename U> static Type normalize(const U &x){
        Type v;
        if(-mod() <= x && x < mod()) v = static_cast<Type>(x);
        else v = static_cast<Type>(x % mod());
        if(v < 0) v += mod();
        return v;
    }
    const Type& operator()() const{ return value; }
    template<typename U> explicit operator U() const{ return static_cast<U>(value); }
    constexpr static Type mod(){ return T::value; }
    Z_p &operator+=(const Z_p &otr){ if((value += otr.value) >= mod()) value -= mod(); return *this; }
    Z_p &operator-=(const Z_p &otr){ if((value -= otr.value) < 0) value += mod(); return *this; }
    template<typename U> Z_p &operator+=(const U &otr){ return *this += Z_p(otr); }
    template<typename U> Z_p &operator-=(const U &otr){ return *this -= Z_p(otr); }
    Z_p &operator++(){ return *this += 1; }
    Z_p &operator--(){ return *this -= 1; }
    Z_p operator++(int){ Z_p result(*this); *this += 1; return result; }
    Z_p operator--(int){ Z_p result(*this); *this -= 1; return result; }
    Z_p operator-() const{ return Z_p(-value); }
    template<typename U = T>
    typename enable_if<is_same<typename Z_p<U>::Type, int>::value, Z_p>::type &operator*=(const Z_p& rhs){
        #ifdef _WIN32
        uint64_t x = static_cast<int64_t>(value) * static_cast<int64_t>(rhs.value);
        uint32_t xh = static_cast<uint32_t>(x >> 32), xl = static_cast<uint32_t>(x), d, m;
        asm(
            "divl %4; \n\t"
            : "=a" (d), "=d" (m)
            : "d" (xh), "a" (xl), "r" (mod())
        );
        value = m;
        #else
        value = normalize(static_cast<int64_t>(value) * static_cast<int64_t>(rhs.value));
        #endif
        return *this;
    }
    template<typename U = T>
    typename enable_if<is_same<typename Z_p<U>::Type, int64_t>::value, Z_p>::type &operator*=(const Z_p &rhs){
        uint64_t ret = static_cast<uint64_t>(value) * static_cast<uint64_t>(rhs.value) - static_cast<uint64_t>(mod()) * static_cast<uint64_t>(1.L / static_cast<uint64_t>(mod()) * static_cast<uint64_t>(value) * static_cast<uint64_t>(rhs.value));
        value = normalize(static_cast<int64_t>(ret + static_cast<uint64_t>(mod()) * (ret < 0) - static_cast<uint64_t>(mod()) * (ret >= static_cast<uint64_t>(mod()))));
        return *this;
    }
    template<typename U = T>
    typename enable_if<!is_integral<typename Z_p<U>::Type>::value, Z_p>::type &operator*=(const Z_p &rhs){
        value = normalize(value * rhs.value);
        return *this;
    }
    template<typename U>
    Z_p &operator^=(U e){
        if(e < 0) *this = 1 / *this, e = -e;
        Z_p res = 1;
        for(; e; *this *= *this, e >>= 1) if(e & 1) res *= *this;
        return *this = res;
    }
    template<typename U>
    Z_p operator^(U e) const{
        return Z_p(*this) ^= e;
    }
    Z_p &operator/=(const Z_p &otr){
        Type a = otr.value, m = mod(), u = 0, v = 1;
        if(a < (int)MOD_INV.size()) return *this *= MOD_INV[a];
        while(a){
            Type t = m / a;
            m -= t * a; swap(a, m);
            u -= t * v; swap(u, v);
        }
        assert(m == 1);
        return *this *= u;
    }
    template<typename U> friend const Z_p<U> &abs(const Z_p<U> &v){ return v; }
    Type value;
};
template<typename T> bool operator==(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value == rhs.value; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator==(const Z_p<T>& lhs, U rhs){ return lhs == Z_p<T>(rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator==(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) == rhs; }
template<typename T> bool operator!=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return !(lhs == rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator!=(const Z_p<T> &lhs, U rhs){ return !(lhs == rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator!=(U lhs, const Z_p<T> &rhs){ return !(lhs == rhs); }
template<typename T> bool operator<(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value < rhs.value; }
template<typename T> bool operator>(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value > rhs.value; }
template<typename T> bool operator<=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value <= rhs.value; }
template<typename T> bool operator>=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value >= rhs.value; }
template<typename T> Z_p<T> operator+(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator+(const Z_p<T> &lhs, U rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator+(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T> Z_p<T> operator-(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator-(const Z_p<T>& lhs, U rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator-(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T> Z_p<T> operator*(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator*(const Z_p<T>& lhs, U rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator*(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T> Z_p<T> operator/(const Z_p<T> &lhs, const Z_p<T> &rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator/(const Z_p<T>& lhs, U rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator/(U lhs, const Z_p<T> &rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T> istream &operator>>(istream &in, Z_p<T> &number){
    typename common_type<typename Z_p<T>::Type, int64_t>::type x;
    in >> x;
    number.value = Z_p<T>::normalize(x);
    return in;
}
template<typename T> ostream &operator<<(ostream &out, const Z_p<T> &number){ return out << number(); }

/*
using ModType = int;
struct VarMod{ static ModType value; };
ModType VarMod::value;
ModType &mod = VarMod::value;
using Zp = Z_p<VarMod>;
*/

// constexpr int mod = 1e9 + 7; // 1000000007
constexpr int mod = (119 << 23) + 1; // 998244353
// constexpr int mod = 1e9 + 9; // 1000000009
using Zp = Z_p<integral_constant<decay<decltype(mod)>::type, mod>>;

template<typename T> vector<typename Z_p<T>::Type> Z_p<T>::MOD_INV;
template<typename T = integral_constant<decay<decltype(mod)>::type, mod>>
void precalc_inverse(int SZ){
    auto &inv = Z_p<T>::MOD_INV;
    if(inv.empty()) inv.assign(2, 1);
    for(; inv.size() <= SZ; ) inv.push_back((mod - 1LL * mod / (int)inv.size() * inv[mod % (int)inv.size()]) % mod);
}

template<typename T>
vector<T> precalc_power(T base, int SZ){
    vector<T> res(SZ + 1, 1);
    for(auto i = 1; i <= SZ; ++ i) res[i] = res[i - 1] * base;
    return res;
}

template<typename T>
vector<T> precalc_factorial(int SZ){
    vector<T> res(SZ + 1, 1); res[0] = 1;
    for(auto i = 1; i <= SZ; ++ i) res[i] = res[i - 1] * i;
    return res;
}

template<class T>
struct RMQ {
    vector<vector<T>> jmp;
    RMQ(const vector<T>& V) : jmp(1, V) {
        for (int pw = 1, k = 1; pw * 2 <= (int)size(V); pw *= 2, ++k) {
            jmp.emplace_back(size(V) - pw * 2 + 1);
            for (int j = 0; j < (int)size(jmp[k]); ++j)
                jmp[k][j] = min(jmp[k - 1][j], jmp[k - 1][j + pw]);
        }
    }
    T query(int a, int b) {
        assert(a < b); // or return inf if a == b
        int dep = 31 - __builtin_clz(b - a);
        return min(jmp[dep][a], jmp[dep][b - (1 << dep)]);
    }
};

struct LCA {
    int T = 0;
    vector<int> time, path, ret, depth;
    RMQ<int> rmq;

    LCA(vector<vector<int>>& C) : time(size(C)), depth(size(C)), rmq((dfs(C,0,-1), ret)) {}
    void dfs(vector<vector<int>>& C, int v, int par) {
        time[v] = T++;
        for (int y : C[v]) if (y != par) {
            depth[y] = 1 + depth[v];
            path.push_back(v), ret.push_back(time[v]);
            dfs(C, y, v);
        }
    }

    int lca(int a, int b) {
        if (a == b) return a;
        tie(a, b) = minmax(time[a], time[b]);
        return path[rmq.query(a, b)];
    }
    int dist(int a, int b) {
        return depth[a] + depth[b] - 2*depth[lca(a,b)];
    }
};

/**
 * Point-update Segment Tree
 * Source: kactl
 * Description: Iterative point-update segment tree, ranges are half-open i.e [L, R).
 *              f is any associative function.
 * Time: O(logn) update/query
 */

template<class T, T unit = T()>
struct SegTree {
    T f(T a, T b) { return a+b; }
    vector<T> s; int n;
    SegTree(int _n = 0, T def = unit) : s(2*_n, def), n(_n) {}
    void update(int pos, T val) {
        for (s[pos += n] = val; pos /= 2;)
            s[pos] = f(s[pos * 2], s[pos * 2 + 1]);
    }
    T query(int b, int e) {
        T ra = unit, rb = unit;
        for (b += n, e += n; b < e; b /= 2, e /= 2) {
            if (b % 2) ra = f(ra, s[b++]);
            if (e % 2) rb = f(s[--e], rb);
        }
        return f(ra, rb);
    }
};

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        vector<int> a(2*n);
        vector<array<int, 2>> pos(n, array{-1, -1});
        for (int i = 0; i < 2*n; ++i) {
            cin >> a[i];
            --a[i];
            if (pos[a[i]][0] == -1) pos[a[i]][0] = i;
            else pos[a[i]][1] = i;
        }

        n *= 2;
        vector adj(n, vector<int>());
        for (int i = 0; i < n - 1; ++i) {
            int u, v; cin >> u >> v;
            adj[--u].push_back(--v);
            adj[v].push_back(u);
        }

        vector<ll> subsz(n), in(n), out(n);
        vector<int> par(n);
        int timer = 0;
        auto dfs = [&] (const auto &self, int u, int p) -> void {
            in[u] = timer++;
            subsz[u] = 1;
            for (int v : adj[u]) {
                if (v == p) continue;
                par[v] = u;
                self(self, v, u);
                subsz[u] += subsz[v];
            }
            out[u] = timer;
        };
        dfs(dfs, 0, 0);
        LCA L(adj);
        for (int u = 0; u < n; ++u)
            sort(begin(adj[u]), end(adj[u]), [&] (int x, int y) {return in[x] < in[y];});

        using Query = tuple<int, int, Zp>;
        vector queries(n, vector<Query>());
        array<Zp, 2> ans = {0, 0};
        for (int c = 0; c < n/2; ++c) {
            auto [u, v] = pos[c];
            int l = L.lca(u, v);

            if (l != u and l != v) {
                queries[u].push_back({u, 0, subsz[v]});
                queries[v].push_back({v, 0, subsz[u]});

                if (l) {
                    queries[par[l]].push_back({u, 0, -subsz[v]});
                    queries[par[l]].push_back({v, 0, -subsz[u]});
                }

                queries[par[u]].push_back({v, 1, subsz[u]});
                queries[par[v]].push_back({u, 1, subsz[v]});
                queries[l].push_back({v, 1, -subsz[u]});
                queries[l].push_back({u, 1, -subsz[v]});
                
                ans[0] += Zp(1) * subsz[u] * subsz[v];
            }
            else {
                int bot = u + v - l;
                if (par[u] == v or par[v] == u) {
                    ans[0] += Zp(1) * subsz[bot] * (n - subsz[bot]);
                    continue;
                }
                int child = 0;
                {
                    int lo = 0, hi = adj[l].size() - 1;
                    while (lo < hi) {
                        int mid = (lo + hi + 1) / 2;
                        int x = adj[l][mid];
                        if (in[x] > in[bot]) hi = mid - 1;
                        else lo = mid;
                    }
                    child = adj[l][lo];
                }
                assert(par[child] == l);
                assert(in[bot] >= in[child] and out[bot] <= out[child]);
                queries[bot].push_back({bot, 0, n - subsz[child]});
                queries[l].push_back({bot, 0, subsz[child] - n});
                
                ans[0] += Zp(1) * subsz[bot] * (n - subsz[child]);
            }
        }

        SegTree<Zp> seg(n);
        auto solve = [&] (const auto &self, int u, int p) -> void {
            int oth = pos[a[u]][0] + pos[a[u]][1] - u;
            seg.update(in[oth], subsz[oth]);
            for (auto [v, id, mul] : queries[u]) {
                ans[id] += mul * seg.query(in[v], out[v]);
            }
            
            for (int v : adj[u]) {
                if (v != p) self(self, v, u);
            }
            seg.update(in[oth], 0);
        };
        solve(solve, 0, 0);
        cout << ans[0] + ans[1]/2 << '\n';
    }
}