APTREE - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: yash5507
Testers: IceKnight1093, tabr
Editorialist: IceKnight1093

DIFFICULTY:

3244

PREREQUISITES:

Binary lifting or heavy-light decomposition

PROBLEM:

Given a tree on N vertices where the i-th vertex has A_i written on it, answer Q queries of the following form:

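  • Given u and v, find the length (number of vertices) of the longest arithmetic progression on the simple path from u to v.
    Note that this arithmetic progression must itself be a contiguous subpath: the values of consecutive vertices along it must form an arithmetic progression.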

EXPLANATION:

This is pretty much a pure data structure problem. There are a couple of different ways to solve it, I’ll explain one below.

A simpler version

First, let’s solve a simpler version of the queries: instead of arbitrary u and v, let’s assume u is an ancestor of v (when the tree is rooted at 1).

To solve this, let’s transform the problem a bit:

Root the tree at 1 and let p_u denote the parent of u.
Write the value A_u - A_{p_u} on the edge between u and p_u.
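Then, notice that “longest arithmetic progression on path” is now asking for “longest contiguous run of equal edge values on path”, which is a bit easier to deal with. A run of k equal consecutive edge values corresponds to an arithmetic progression on k+1 vertices, so the final answer is the longest run plus one (a path consisting of a single vertex has answer 1).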

Several different data structures can solve this query: for example, binary lifting or heavy-light decomposition.

The devil is in the details here: you’ll need to maintain several different quantities and merge them correctly when doing binary lifting/segtree merging.
For instance, the editorialist’s code linked below uses HLD, and maintains the following quantities for each segment tree node that represents some path:

  • The answer for the path, i.e., the length of the longest run of consecutive equal values
  • The length of the path
  • The first edge value seen on this path, and how many of them form the prefix
  • The last edge value seen on this path, and how many of them form the suffix
  • Merging two adjacent nodes requires merging these values appropriately, which takes a bit of casework (or represent the whole state as a matrix and use matrix multiplication, which is probably a bit less typing)

Binary lifting will require you to maintain several similar quantities in your lifting table.

Either way, this allows us to answer a single query in \mathcal{O}(\log^2 N) time with HLD (binary lifting needs only \mathcal{O}(\log N) merges), given that u is an ancestor of v.

The original problem

Let’s now deal with arbitrary (u, v) queries.

Let L = lca(u, v). Finding L can be done in \mathcal{O}(\log N) or \mathcal{O}(1) in various ways.

First, let’s apply our above solution independently to (L, u) and (L, v).
Now, note that the only thing we’re missing is subpaths that pass through L without having it as an endpoint.

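However, such paths can be considered by just merging the answers for (L, u) and (L, v) appropriately. Going from u to v, the u–L part is traversed upwards and the L–v part downwards, so the edge differences of the second part change sign. For example, keep the summary of the u–L part as it is; orient the summary of the L–v part so that it starts at L (swapping its prefix and suffix information if needed) and negate its differences; then merge the two.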

Our merge function is \mathcal{O}(1), so each query is now answered in \mathcal{O}(\log^2 N).

Once again, this problem’s difficulty mainly lies in correctly working out the details. For example, if you’re using HLD, you need to be careful about the order in which merges are done, since the merge operation is not commutative: it is associative (which is what segment trees and binary lifting rely on), but merge(x, y) and merge(y, x) describe different paths. I recommend looking through the code below if you’re stuck and can’t debug.

TIME COMPLEXITY:

\mathcal{O}((N+Q)\log^2 N) per testcase.

CODE:

Setter's code (C++, binary lifting)
#include <bits/stdc++.h>
using namespace std;
// Contest-template boilerplate. Within this program only `ll` (used by
// remender) and `pb` (used in solve) are referenced; the other macros and
// constants below are unused.
#define ll long long
#define ull unsigned long long 
#define pb(e) push_back(e)
#define sv(a) sort(a.begin(),a.end())
#define sa(a,n) sort(a,a+n)
#define mp(a,b) make_pair(a,b)
#define vf first
#define vs second
#define ar array
#define all(x) x.begin(),x.end()
const int inf = 0x3f3f3f3f;
const int mod = 998244353; 
const double PI=3.14159265358979323846264338327950288419716939937510582097494459230;
// Unused helper: true when b does not divide a (b must be non-zero).
bool remender(ll a , ll b){return a%b;}

//freopen("problemname.in", "r", stdin);
//freopen("problemname.out", "w", stdout);

// Summary of a path segment, combined by merge() and built by dfs()/corner().
// The "up" end is the end of the segment passed first to merge (the end
// nearer the root in dfs and corner); the "down" end is the other one.
// A value of -1 in vald1/valu1 marks "no such vertex" (single-vertex segment).
// Field order matters: the code builds items with positional brace-init.
struct item {
	int down;   // length of the longest AP that ends at the down end
	int vald;   // value at the down end
	int up;     // length of the longest AP that starts at the up end
	int valu;   // value at the up end
	int vald1;  // value adjacent to the down end (-1 for a single vertex)
	int valu1;  // value adjacent to the up end (-1 for a single vertex)
	int best;   // longest AP anywhere inside the segment
	int full;   // 1 if the whole segment is a single AP, else 0
};

// Classifies the differences around the junction of two path segments.
// a, b, c, d are consecutive values along the path and the junction lies
// between b and c. a == -1 (resp. d == -1) is a sentinel meaning that side
// has a single vertex, so there is no outer difference to compare there.
// NOTE(review): a genuine vertex value of -1 is indistinguishable from the
// sentinel.
// Return codes (consumed by merge):
//   1 - all available differences agree (the runs on both sides join);
//   2 - only (b - a) == (c - b);
//   3 - only (c - b) == (d - c);
//   0 - no match.
// With exactly one sentinel the result is 1 or 0. The value 4 is never
// produced: "both sides match" is exactly the all-agree case, which is 1.
int isap(int a , int b , int c , int d){
	// Differences are taken in 64-bit: for large |A_i| they overflow int.
	const long long dPrev = static_cast<long long>(b) - a;
	const long long dMid = static_cast<long long>(c) - b;
	const long long dNext = static_cast<long long>(d) - c;
	if (a == -1 && d == -1) return 1;
	if (a == -1) return dMid == dNext ? 1 : 0;
	if (d == -1) return dPrev == dMid ? 1 : 0;
	const bool prevMatches = (dPrev == dMid);
	const bool nextMatches = (dMid == dNext);
	if (prevMatches && nextMatches) return 1;
	if (prevMatches) return 2;
	if (nextMatches) return 3;
	return 0;
}

// Array capacities: vertices are 1-indexed and ids are assumed to be < N.
// L is the number of binary-lifting levels (2^(L-1) is far above N).
const int N = 200003;
const int L = 22;

vector<int> adj[N];   // adj[v]: neighbours of vertex v
int arr[N];           // arr[v]: value written on vertex v
int timer;            // running DFS clock shared by tin/tout
int tin[N];           // DFS entry time of each vertex
int tout[N];          // DFS exit time of each vertex
item up[N][L];        // up[v][i]: summary of the lifted path starting at v (see dfs)
int p[N][L];          // p[v][i]: 2^i-th ancestor of v (the root maps to itself)

// Concatenates two path summaries. `a` is the part nearer the "up" end and
// `b` the part nearer the "down" end, so the junction joins a's down end
// (a.vald, preceded by a.vald1) with b's up end (b.valu, followed by b.valu1).
// In dfs/corner, `a` is the segment closer to the root.
// The third parameter `pr` is not used.
item merge(item a , item b , int pr = 0){
	item ans;
	// APs lying entirely inside one side remain valid.
	ans.best = max(a.best , b.best);
	// End values of the combined segment.
	ans.valu = a.valu;
	ans.vald = b.vald;
	// Second value from the up end: taken from b when a is a single vertex.
	ans.valu1 = a.valu1;
	if(a.valu1 == -1){
		ans.valu1 = b.valu;
	}
	// Second value from the down end: taken from a when b is a single vertex.
	ans.vald1 = b.vald1;
	ans.full = 0;
	if(b.vald1 == -1){
		ans.vald1 = a.vald;
	}
	// Default end runs; the junction cases below may extend them.
	ans.up = a.up;
	ans.down = b.down;
	// Compare the differences around the junction, reading the values
	// b.valu1 -> b.valu -> a.vald -> a.vald1 (see isap for the codes).
	int x = isap(b.valu1 , b.valu , a.vald , a.vald1);
	if(x > 0){
		if(x == 1){
			// Both end runs continue across the junction and join.
			ans.best = max(ans.best , a.down + b.up);
			if(a.full && b.full){
				ans.full = 1;
			}
			// b entirely an AP: the down-end run now reaches into a.
			if(b.full){
				ans.down = a.down + b.up;
			}
			// a entirely an AP: the up-end run now reaches into b.
			if(a.full){
				ans.up = a.down + b.up;
			}
		}
		else if(x == 2){
			// Only b's up-end run extends, by one vertex (a.vald).
			ans.best = max(ans.best , b.up + 1);
			if(b.full){
				ans.down++;
			}
		}
		else if(x == 3){
			// Only a's down-end run extends, by one vertex (b.valu).
			ans.best = max(ans.best , a.down + 1);
			if(a.full)ans.up++;
		}
		else {
			// Defensive: isap as written never returns 4.
			ans.best = max({ans.best , a.down + 1 , b.up + 1});
			if(b.full){
				ans.down++;
			}
			if(a.full)ans.up++;
		}
	}
	// A fully-AP segment: both end runs span the whole segment.
	if(ans.full){
		ans.up = ans.down = ans.best;
	}
	// Any two adjacent vertices form an AP, so end runs are at least 2.
	ans.down = max(ans.down , 2);
	ans.up = max(ans.up , 2);
	return ans;
}

// Recursive DFS from `node` whose parent is `par`; `dis` is the number of
// vertices on the root..node path (the root is called with dis == 1).
// Records entry/exit times and fills the lifting tables p[node][*] and
// up[node][*].
void dfs(int node , int par , int dis){
	tin[node] = timer++;
	p[node][0] = par;
	up[node][0] = {1 , arr[node] , 1 , arr[node] , -1 , -1 , 1 , 1};
	for (int lvl = 1; lvl < L; ++lvl) {
		const int half = p[node][lvl - 1];  // 2^(lvl-1)-th ancestor
		p[node][lvl] = p[half][lvl - 1];
		// up[node][lvl] should cover 2^lvl vertices starting at node; when fewer
		// exist on the way to the root, the previous level is carried over.
		const bool enoughAbove = dis >= (1 << lvl);
		up[node][lvl] = enoughAbove ? merge(up[half][lvl - 1], up[node][lvl - 1])
		                            : up[node][lvl - 1];
	}
	for (int child : adj[node]) {
		if (child == par) continue;
		dfs(child, node, dis + 1);
	}
	tout[node] = timer++;
}

// True when x is an ancestor of y (or x == y), judged by DFS entry/exit times.
bool islca(int x , int y){
	const bool enteredFirst = tin[x] <= tin[y];
	const bool leftLast = tout[x] >= tout[y];
	return enteredFirst && leftLast;
}

// Lowest common ancestor of u and v using the binary-lifting table p.
int find(int u , int v){
	if (islca(u, v)) return u;
	if (islca(v, u)) return v;
	// Lift u to the highest ancestor that is still not an ancestor of v;
	// the parent of that vertex is the LCA.
	for (int lvl = L - 1; lvl >= 0; --lvl) {
		const int cand = p[u][lvl];
		if (!islca(cand, v)) u = cand;
	}
	return p[u][0];
}

// Summarises the vertical path from x up to lca. The callers pass a proper
// descendant x of lca. The result's "up" end is the lca side and its "down"
// end is x. With todo == 0 the lca vertex itself is included; otherwise the
// summary stops at the child of lca.
item corner(int lca , int x , int todo = 0){
	item acc = {1 , arr[x] , 1 , arr[x] , -1 , -1 , 1 , 1};
	x = p[x][0];
	// Jump while the 2^k-th ancestor is still strictly below lca.
	for (int k = L - 1; k >= 0; --k) {
		const int target = p[x][k];
		if (islca(target, lca)) continue;
		acc = merge(up[x][k], acc);
		x = target;
	}
	// At most the child of lca remains before reaching lca.
	if (x != lca) {
		acc = merge(up[x][0], acc);
		x = p[x][0];
	}
	if (todo == 0) acc = merge(up[x][0], acc);
	return acc;
}

void solve(){
	int n;
	cin >> n;
	for(int i = 1; i <= n; i++)cin >> arr[i];	
	for(int i = 0; i < n-1; i++){
		int u , v;
		cin >> u >> v;
		adj[u].pb(v);
		adj[v].pb(u);
	}
	dfs(1 , 1 , 1);
	int q;
	cin >> q;
	while(q--){
		int u , v;
		cin >> u >> v;
		if(u == v){
			cout << 1 << '\n';
			continue;
		}
		int lca = find(u , v);
		if(lca == u){
			cout << corner(lca , v).best << '\n';
		}
		else if(lca == v){
			cout << corner(lca , u).best << '\n';
		}
		else {
			item x = corner(lca ,u);
			item y = corner(lca,  v,  1);
			swap(x.valu , x.vald);
			swap(x.valu1,x.vald1);
			swap(x.up , x.down);
			cout << merge(x,y).best << '\n';
		}
	}
}

// Entry point: fast I/O, then a single test case (no test count is read).
int main(){
	ios_base::sync_with_stdio(false);
	cin.tie(nullptr);
	solve();
	return 0;
}
Editorialist's code (C++, HLD)
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
// Template boilerplate: neither `ll` nor `rng` is referenced by the code below.
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

// Sparse table for range-minimum queries (adapted from kactl).
// Build is O(n log n); query(a, b) returns min over the half-open range
// [a, b) in O(1) and requires a < b.
template<class T>
struct RMQ {
    // jmp[k][j] holds min(V[j], ..., V[j + 2^k - 1]).
    vector<vector<T>> jmp;
    RMQ(const vector<T>& V) : jmp(1, V) {
        const int n = (int)size(V);
        for (int k = 1; (1 << k) <= n; ++k) {
            const int half = 1 << (k - 1);
            jmp.emplace_back(n - 2 * half + 1);
            vector<T>& row = jmp.back();
            const vector<T>& prev = jmp[k - 1];
            for (int j = 0; j < (int)row.size(); ++j)
                row[j] = min(prev[j], prev[j + half]);
        }
    }
    T query(int a, int b) {
        assert(a < b); // an empty range has no minimum
        const int k = 31 - __builtin_clz(b - a);  // floor(log2(b - a))
        return min(jmp[k][a], jmp[k][b - (1 << k)]);
    }
};

// Lowest common ancestor with O(1) queries (kactl), for a tree given as
// adjacency lists and rooted at vertex 0.
// NOTE: `rmq` is declared after `time`, `path` and `ret`, so it is
// constructed after them; its initializer first runs dfs (comma operator)
// and then builds the sparse table over the filled `ret`.
struct LCA {
    int T = 0;
    // time[v]: preorder index of v.
    // path/ret: one entry per tree edge, appended when the DFS descends from
    // a vertex v into a child; path stores v and ret stores time[v].
    vector<int> time, path, ret;
    RMQ<int> rmq;

    LCA(vector<vector<int>>& C) : time(size(C)), rmq((dfs(C,0,-1), ret)) {}
    void dfs(vector<vector<int>>& C, int v, int par) {
        time[v] = T++;
        for (int y : C[v]) if (y != par) {
            path.push_back(v), ret.push_back(time[v]);
            dfs(C, y, v);
        }
    }

    // With a, b replaced by their preorder indices (smaller first), entries
    // [a, b) of ret are the parents' preorder indices of the vertices with
    // preorder a+1..b; their minimum t is the preorder index of the LCA.
    // When a vertex w first descends into a child, exactly time[w] entries
    // already exist, so path[time[w]] == w and path[t] is the LCA itself.
    int lca(int a, int b) {
        if (a == b) return a;
        tie(a, b) = minmax(time[a], time[b]);
        return path[rmq.query(a, b)];
    }
};

// Summary of a sequence of edge differences (the segment-tree monoid).
// An empty summary is marked by prefval == INT_MAX.
struct Node {
	int ans = 0;            // longest run of equal consecutive values inside
	int preflen = 0;        // length of the equal run at the front
	int suflen = 0;         // length of the equal run at the back
	int len = 0;            // number of values summarised
	int prefval = INT_MAX;  // first value (INT_MAX means "empty")
	int sufval = INT_MAX;   // last value
	Node() = default;
	Node(int x) : ans(1), preflen(1), suflen(1), len(1), prefval(x), sufval(x) {}
};
Node unit;  // identity element: the empty summary

/**
 * Iterative bottom-up segment tree over Node (adapted from kactl).
 * Positions are 0-based; query ranges are half-open [b, e).
 * f is associative but NOT commutative, so operand order matters.
 * Time: O(log n) per update/query.
 */
struct SegTree {
	// Summary of the sequence of a followed by the sequence of b.
	Node f(Node a, Node b) {
		if (b.prefval == INT_MAX) return a;  // empty right operand
		if (a.prefval == INT_MAX) return b;  // empty left operand
		const bool joins = (a.sufval == b.prefval);
		Node out;
		out.len = a.len + b.len;
		out.prefval = a.prefval;
		out.sufval = b.sufval;
		out.preflen = (joins && a.suflen == a.len) ? a.preflen + b.preflen : a.preflen;
		out.suflen = (joins && b.preflen == b.len) ? b.suflen + a.suflen : b.suflen;
		const int across = joins ? a.suflen + b.preflen : 0;
		out.ans = max({across, a.ans, b.ans, out.preflen, out.suflen});
		return out;
	}
	vector<Node> s;
	int n;
	SegTree(int _n = 0) : s(2 * _n), n(_n) {}
	// Sets position pos to the single value val and refreshes its ancestors.
	void update(int pos, int val) {
		pos += n;
		s[pos] = Node(val);
		while ((pos /= 2) > 0)
			s[pos] = f(s[2 * pos], s[2 * pos + 1]);
	}
	// Summary of positions [b, e) in left-to-right order (unit if empty).
	Node query(int b, int e) {
		Node lhs = unit;
		Node rhs = unit;
		for (b += n, e += n; b < e; b /= 2, e /= 2) {
			if (b % 2 == 1) lhs = f(lhs, s[b++]);
			if (e % 2 == 1) rhs = f(s[--e], rhs);
		}
		return f(lhs, rhs);
	}
};

/**
 * Heavy-light decomposition (adapted from kactl), rooted at vertex 0.
 * With VALS_EDGES == true, the value of edge (v, par[v]) lives at position
 * pos[v], so path ranges skip the slot of the topmost vertex.
 * Heavy children are moved to adj[v][0] and visited first, so each heavy
 * chain occupies contiguous positions that increase from its head downward.
 */
template <bool VALS_EDGES> struct HLD {
	int N, tim = 0;
	vector<vector<int>> adj;
	// par: parent (-1 for the root); siz: subtree size; rt: head of the heavy
	// chain containing the vertex; pos: index in the segment tree.
	vector<int> par, siz, depth, rt, pos;
	SegTree seg;
	HLD(vector<vector<int>> adj_)
		: N(size(adj_)), adj(adj_), par(N, -1), siz(N, 1), depth(N),
		  rt(N),pos(N){ dfsSz(0); dfsHld(0); seg = SegTree(N);}
	// Computes parents, depths and subtree sizes; removes the parent from
	// adj[v] and swaps the largest child into adj[v][0].
	void dfsSz(int v) {
		if (par[v] != -1) adj[v].erase(find(begin(adj[v]), end(adj[v]), par[v]));
		for (int& u : adj[v]) {
			par[u] = v, depth[u] = depth[v] + 1;
			dfsSz(u);
			siz[v] += siz[u];
			if (siz[u] > siz[adj[v][0]]) swap(u, adj[v][0]);
		}
	}
	// Assigns positions in DFS order (heavy child first) and chain heads.
	void dfsHld(int v) {
		pos[v] = tim++;
		for (int u : adj[v]) {
			rt[u] = (u == adj[v][0] ? rt[v] : u);
			dfsHld(u);
		}
	}
	// Calls op(l, r) for each heavy-chain piece of the u-v path, where [l, r)
	// is a half-open range of positions. The inner swap keeps v as the
	// endpoint whose chain head is deeper, so each loop iteration consumes
	// [pos[rt[v]], pos[v]] and moves v above its chain. The final piece lies
	// on a single chain, from the shallower u down to v.
	template <class B> void process(int u, int v, B op) {
		for (; rt[u] != rt[v]; v = par[rt[v]]) {
			if (depth[rt[u]] > depth[rt[v]]) swap(u, v);
			op(pos[rt[v]], pos[v] + 1);
		}
		if (depth[u] > depth[v]) swap(u, v);
		op(pos[u] + VALS_EDGES, pos[v] + 1);
	}
	// Writes val at the FIRST position of every non-empty piece only. This is
	// correct when (u, v) is a single edge, as in main's (child, parent) calls:
	// the only non-empty piece is then the child's slot.
	void modifyPath(int u, int v, int val) {
		process(u, v, [&](int l, int r) {
			if (l < r) seg.update(l, val); 
		});
	}
	// Summary of the edge values on the u-v path, read from the deeper
	// endpoint towards the shallower one. Assumes one endpoint is an ancestor
	// of the other (main always passes the LCA); otherwise pieces from the
	// two sides would be concatenated in the wrong order.
	// Each piece comes out of the segment tree top-to-bottom, so its
	// prefix/suffix data is swapped before appending it on the right.
	Node queryPath(int u, int v) { // Modify depending on problem
		Node res = unit;
		process(u, v, [&](int l, int r) {
			auto cur = seg.query(l, r);
			swap(cur.prefval, cur.sufval);
			swap(cur.preflen, cur.suflen);
			res = seg.f(res, cur);
		});
		return res;
	}
};

// Reads the tree, stores A[child] - A[parent] on every edge, and answers each
// query (u, v) as (longest run of equal differences on the u-v path) + 1.
int main()
{
	ios::sync_with_stdio(false);
	cin.tie(nullptr);

	int n;
	cin >> n;
	vector<int> a(n);
	for (int i = 0; i < n; ++i) cin >> a[i];

	vector<vector<int>> g(n);
	for (int e = 0; e + 1 < n; ++e) {
		int x, y;
		cin >> x >> y;
		--x;
		--y;
		g[x].push_back(y);
		g[y].push_back(x);
	}

	HLD<true> hld(g);
	LCA lca(g);
	// Vertex 0 is the root; every other vertex owns the edge to its parent.
	for (int child = 1; child < n; ++child) {
		const int parent = hld.par[child];
		hld.modifyPath(child, parent, a[child] - a[parent]);
	}

	int q;
	cin >> q;
	while (q--) {
		int u, v;
		cin >> u >> v;
		--u;
		--v;
		const int top = lca.lca(u, v);
		auto fromU = hld.queryPath(u, top);  // edges read from u up to top
		auto toV = hld.queryPath(v, top);    // edges read from v up to top
		int runs;
		if (fromU.prefval == INT_MAX) {
			runs = toV.ans;
		} else if (toV.prefval == INT_MAX) {
			runs = fromU.ans;
		} else {
			// Re-orient toV to read top -> v; walking downward flips the sign
			// of every stored difference.
			const int firstVal = -toV.sufval;
			const int lastVal = -toV.prefval;
			toV.prefval = firstVal;
			toV.sufval = lastVal;
			swap(toV.preflen, toV.suflen);
			runs = SegTree().f(fromU, toV).ans;
		}
		cout << runs + 1 << '\n';
	}
}