APTREE - Editorial

Author: yash5507
Testers: IceKnight1093, tabr
Editorialist: IceKnight1093

3244

PROBLEM:

Given a tree on N vertices where the i-th vertex has A_i written on it, answer Q queries of the following form:

• Given u and v, find the length (number of vertices) of the longest arithmetic progression on the simple path from u to v.
Note that this arithmetic progression must itself be a path, i.e. its vertices must be consecutive on the u–v path.

EXPLANATION:

This is pretty much a pure data structure problem. There are a couple of different ways to solve it, I’ll explain one below.

A simpler version

First, let’s solve a simpler version of the queries: instead of arbitrary u and v, let’s assume u is an ancestor of v (when the tree is rooted at 1).

To solve this, let’s transform the problem a bit:

Root the tree at 1 and let p_u denote the parent of u.
Write the value A_u - A_{p_u} on the edge between u and p_u.
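Then, notice that “longest arithmetic progression on the path” is now asking for “longest contiguous run of equal edge values on the path”, which is a bit easier to deal with: a run of k equal consecutive edge values corresponds to an arithmetic progression of k+1 consecutive vertices.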

Several different data structures can solve this query: for example, binary lifting or heavy-light decomposition.

The devil is in the details here: you’ll need to maintain several different quantities and merge them correctly when doing binary lifting/segtree merging.
For instance, the editorialist’s code linked below uses HLD, and maintains the following quantities for each segment tree node that represents some path:

• The answer for the path, i.e., the length of the longest run of equal consecutive values
• The length of the path
• The first edge value seen on this path, and how many of them form the prefix
• The last edge value seen on this path, and how many of them form the suffix
• Merging two adjacent nodes requires merging these values appropriately, which takes a bit of casework (or represent the whole state as a matrix and use matrix multiplication, which is probably a bit less typing)

Binary lifting will require you to maintain several similar quantities in your lifting table.

Either way, this allows us to solve for a single query in \mathcal{O}(\log^2 N) time with HLD (\mathcal{O}(\log N) with binary lifting), given that u is an ancestor of v.

The original problem

Let’s now deal with arbitrary (u, v) queries.

Let L = lca(u, v). Finding L can be done in \mathcal{O}(\log N) or \mathcal{O}(1) in various ways.

First, let’s apply our above solution independently to (L, u) and (L, v).
Now, note that the only thing we’re missing is paths that pass through L and don’t have it as one endpoint.

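However, such paths can be considered by just merging the answers for (L, u) and (L, v) appropriately: for example, keep the (L, u) answer as it is; for the (L, v) answer, reverse its orientation (swap its prefix and suffix information) and negate its differences, since that part of the path is walked in the opposite direction; then merge them.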

Our merge function is \mathcal{O}(1), so each query is now answered in \mathcal{O}(\log^2 N).

Once again, this problem’s difficulty mainly lies in correctly working out the details: for example, if you’re using HLD, you need to be careful about the order in which merges are done, since the merge operation is not commutative (it is associative, which is what the segment tree relies on, but merge(x, y) and merge(y, x) differ in general). I recommend looking through the code below if you’re stuck and can’t debug.

TIME COMPLEXITY:

\mathcal{O}((N+Q)\log^2 N) per testcase.

CODE:

Setter's code (C++, binary lifting)
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define ull unsigned long long
#define pb(e) push_back(e)
#define sv(a) sort(a.begin(),a.end())
#define sa(a,n) sort(a,a+n)
#define mp(a,b) make_pair(a,b)
#define vf first
#define vs second
#define ar array
#define all(x) x.begin(),x.end()
// Template constants; none of them is used by this solution.
constexpr int inf = 0x3f3f3f3f;
constexpr int mod = 998244353;
const double PI = 3.14159265358979323846264338327950288419716939937510582097494459230;
// True when a leaves a non-zero remainder modulo b (i.e. b does not divide a).
bool remender(ll a, ll b) {
    const ll rem = a % b;
    return rem != 0;
}

//freopen("problemname.in", "r", stdin);
//freopen("problemname.out", "w", stdout);

// Summary of a vertical path segment, read from its top (ancestor) end to its bottom end.
// Field order matters: items are built with brace initialisers in exactly this order.
// -1 in vald1/valu1 means "no such vertex" (a single-vertex segment).
// NOTE(review): that sentinel is ambiguous if some A_i can equal -1 - confirm constraints.
struct item {
    int down;   // length of the longest AP run ending at the bottom vertex
    int vald;   // value at the bottom vertex
    int up;     // length of the longest AP run starting at the top vertex
    int valu;   // value at the top vertex
    int vald1;  // value of the vertex just above the bottom one (-1 if none)
    int valu1;  // value of the vertex just below the top one (-1 if none)
    int best;   // longest AP run anywhere inside the segment
    int full;   // 1 when the whole segment is a single AP run, else 0
};

// Classifies four consecutive path values a, b, c, d around a merge junction
// (b | c is the junction); -1 in a or d means that neighbour does not exist.
// Returns:
//   1 - the junction continues the AP on both sides (or the missing side is irrelevant),
//   2 - only a, b, c form an AP (b - a == c - b),
//   3 - only b, c, d form an AP (c - b == d - c),
//   0 - neither.
int isap(int a, int b, int c, int d) {
    const bool noLeft = (a == -1);
    const bool noRight = (d == -1);
    if (noLeft && noRight) return 1;
    if (noLeft) return (c - b == d - c) ? 1 : 0;
    if (noRight) return (b - a == c - b) ? 1 : 0;

    const int dl = b - a;
    const int dm = c - b;
    const int dr = d - c;
    const bool leftMatches = (dl == dm);
    const bool rightMatches = (dm == dr);
    if (leftMatches && rightMatches) return 1;
    if (leftMatches) return 2;
    if (rightMatches) return 3;
    return 0;
}

constexpr int N = 200003;  // vertex capacity (vertices are 1-indexed)
constexpr int L = 22;      // number of binary-lifting levels

int arr[N];     // arr[v]: the value A_v written on vertex v
int timer;      // Euler-tour clock used by dfs
int tin[N];     // entry time of each vertex
int tout[N];    // exit time of each vertex
item up[N][L];  // up[v][i]: summary of the 2^i vertices from v upward
int p[N][L];    // p[v][i]: the 2^i-th ancestor of v

// Concatenates two path summaries. `a` is the upper (ancestor-side) segment and `b` the
// lower segment, so a's bottom vertex is adjacent to b's top vertex; the result's top end
// is a's top end and its bottom end is b's bottom end. `pr` is unused.
item merge(item a , item b , int pr = 0){
item ans;
// Best AP lying entirely on one side; the junction may improve it below.
ans.best = max(a.best , b.best);
ans.valu = a.valu;
ans.vald = b.vald;
ans.valu1 = a.valu1;
// a is a single vertex: the result's second-from-top vertex is b's top vertex.
if(a.valu1 == -1){
ans.valu1 = b.valu;
}
ans.vald1 = b.vald1;
ans.full = 0;
// b is a single vertex: the result's second-from-bottom vertex is a's bottom vertex.
if(b.vald1 == -1){
ans.vald1 = a.vald;
}
ans.up = a.up;
ans.down = b.down;
// The four consecutive values around the junction, from the bottom upward:
// b's 2nd-from-top, b's top, a's bottom, a's 2nd-from-bottom.
// isap: 1 = the runs on both sides fuse, 2 = b's top run extends by one vertex into a,
// 3 = a's bottom run extends by one vertex into b, 0 = no extension.
int x = isap(b.valu1 , b.valu , a.vald , a.vald1);
if(x > 0){
if(x == 1){
// a's bottom run and b's top run share the junction's difference: they fuse.
ans.best = max(ans.best , a.down + b.up);
if(a.full && b.full){
ans.full = 1;
}
// b is one AP run continuing into a: the bottom run covers b plus a's bottom run.
if(b.full){
ans.down = a.down + b.up;
}
// a is one AP run continuing into b: the top run covers a plus b's top run.
if(a.full){
ans.up = a.down + b.up;
}
}
else if(x == 2){
// b's top run gains a's bottom vertex.
ans.best = max(ans.best , b.up + 1);
if(b.full){
ans.down++;
}
}
else if(x == 3){
// a's bottom run gains b's top vertex.
ans.best = max(ans.best , a.down + 1);
if(a.full)ans.up++;
}
else {
// NOTE(review): isap as written never returns 4 (its "both sides" case returns 1
// first), so this branch appears unreachable.
ans.best = max({ans.best , a.down + 1 , b.up + 1});
if(b.full){
ans.down++;
}
if(a.full)ans.up++;
}
}
// A segment that is one AP run has end runs spanning the whole segment.
if(ans.full){
ans.up = ans.down = ans.best;
}
// The merged segment has at least two vertices, and any two adjacent vertices form an
// AP, so both end runs are at least 2.
ans.down = max(ans.down , 2);
ans.up = max(ans.up , 2);
return ans;
}

void dfs(int node , int par , int dis){
tin[node] = timer++;
up[node][0] = {1 , arr[node] , 1 , arr[node] , -1 , -1 , 1 , 1};
p[node][0] = par;
for(int i = 1; i < L; i++){
if(dis < (1 << i)){
up[node][i] = up[node][i-1];
p[node][i] = p[p[node][i-1]][i-1];
continue;
}
up[node][i] = merge(up[p[node][i-1]][i-1] , up[node][i-1]);
p[node][i] = p[p[node][i-1]][i-1];
}
if(i != par){
dfs(i , node , dis + 1);
}
}
tout[node] = timer++;
}

bool islca(int x , int y){
return tin[x] <= tin[y] && tout[x] >= tout[y];
}

int find(int u , int v){
if(islca(u , v))return u;
else if(islca(v , u))return v;
for(int i = L - 1; i >= 0; i--){
if(!islca(p[u][i],v))u = p[u][i];
}
return p[u][0];
}

item corner(int lca , int x , int todo = 0){
item cur = {1 , arr[x] , 1 , arr[x] , -1 , -1 , 1 ,1};
x = p[x][0];
for(int i = L - 1; i >= 0; i--){
if(!islca(p[x][i] , lca)){
cur = merge(up[x][i] , cur);
x = p[x][i];
}
}
if(x != lca){
cur = merge(up[x][0], cur);
x = p[x][0];
}
if(todo == 0)cur = merge(up[x][0], cur);
return cur;
}

void solve(){
int n;
cin >> n;
for(int i = 1; i <= n; i++)cin >> arr[i];
for(int i = 0; i < n-1; i++){
int u , v;
cin >> u >> v;
}
dfs(1 , 1 , 1);
int q;
cin >> q;
while(q--){
int u , v;
cin >> u >> v;
if(u == v){
cout << 1 << '\n';
continue;
}
int lca = find(u , v);
if(lca == u){
cout << corner(lca , v).best << '\n';
}
else if(lca == v){
cout << corner(lca , u).best << '\n';
}
else {
item x = corner(lca ,u);
item y = corner(lca,  v,  1);
swap(x.valu , x.vald);
swap(x.valu1,x.vald1);
swap(x.up , x.down);
cout << merge(x,y).best << '\n';
}
}
}

int main(){
    // Untie iostreams from C stdio for fast input/output.
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    // Single test case. For several, wrap the call in: int t; cin >> t; while(t--)
    solve();
    return 0;
}

Editorialist's code (C++, HLD)
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
using ll = long long;  // identical to long long int
// Clock-seeded 64-bit Mersenne Twister (not used by this solution).
mt19937_64 rng(static_cast<mt19937_64::result_type>(
    chrono::high_resolution_clock::now().time_since_epoch().count()));

// Sparse table for static range-minimum queries (kactl).
// jmp[k][j] == min(V[j], ..., V[j + 2^k - 1]). Build O(n log n), query O(1).
template<class T>
struct RMQ {
    vector<vector<T>> jmp;

    RMQ(const vector<T>& V) : jmp(1, V) {
        const int n = static_cast<int>(V.size());
        for (int k = 1; (1 << k) <= n; ++k) {
            const int half = 1 << (k - 1);
            vector<T> row(n - (1 << k) + 1);
            for (int j = 0; j < static_cast<int>(row.size()); ++j)
                row[j] = min(jmp[k - 1][j], jmp[k - 1][j + half]);
            jmp.push_back(std::move(row));
        }
    }

    // Minimum of V[a .. b), half-open; requires a < b.
    T query(int a, int b) {
        assert(a < b); // or return inf if a == b
        const int k = 31 - __builtin_clz(b - a);
        return min(jmp[k][a], jmp[k][b - (1 << k)]);
    }
};

// Lowest common ancestor via DFS order + RMQ (kactl), tree rooted at vertex 0.
// time, path and ret are declared before rmq on purpose: rmq's initialiser runs dfs(),
// which fills them, so they must already be constructed.
struct LCA {
    int T = 0;                    // next DFS timestamp
    vector<int> time, path, ret;  // time[v]: DFS order of v; path/ret: parent and its time per tree edge
    RMQ<int> rmq;

    LCA(vector<vector<int>>& C)
        : time(C.size()),
          rmq((dfs(C, 0, -1), ret)) {}  // comma operator: run dfs first, then build on ret

    void dfs(vector<vector<int>>& C, int v, int par) {
        time[v] = T++;
        for (const int child : C[v]) {
            if (child == par) continue;
            path.push_back(v);
            ret.push_back(time[v]);
            dfs(C, child, v);
        }
    }

    int lca(int a, int b) {
        if (a == b) return a;
        const int lo = min(time[a], time[b]);
        const int hi = max(time[a], time[b]);
        return path[rmq.query(lo, hi)];
    }
};

// Summary of a sequence of edge values, ordered from its first to its last element.
// prefval == INT_MAX marks the empty sequence, which is the identity of SegTree::f.
struct Node {
    int ans = 0;            // longest run of equal consecutive values
    int preflen = 0;        // length of the run of prefval at the start
    int suflen = 0;         // length of the run of sufval at the end
    int len = 0;            // number of values
    int prefval = INT_MAX;  // first value (INT_MAX when empty)
    int sufval = INT_MAX;   // last value (INT_MAX when empty)
    Node() = default;
    Node(int x) : ans(1), preflen(1), suflen(1), len(1), prefval(x), sufval(x) {}
};
Node unit;

/**
 * Point-update segment tree (adapted from kactl): iterative, ranges are half-open [b, e).
 * f must be associative; it need NOT be commutative - query() keeps separate left and
 * right accumulators so operands are always combined in array order.
 * Time: O(log n) per update/query.
 */
struct SegTree {
    // Concatenation of sequence a followed by sequence b.
    Node f(Node a, Node b) {
        if (b.prefval == INT_MAX) return a;
        if (a.prefval == INT_MAX) return b;
        Node res;
        res.len = a.len + b.len;
        res.prefval = a.prefval;
        res.sufval = b.sufval;
        const bool joins = (a.sufval == b.prefval);
        // a's leading run crosses into b only if all of a is that run (and vice versa).
        res.preflen = (joins && a.suflen == a.len) ? a.preflen + b.preflen : a.preflen;
        res.suflen = (joins && b.preflen == b.len) ? b.suflen + a.suflen : b.suflen;
        const int across = joins ? a.suflen + b.preflen : 0;
        res.ans = max({across, a.ans, b.ans, res.preflen, res.suflen});
        return res;
    }

    vector<Node> s; int n;
    SegTree(int _n = 0) : s(2*_n), n(_n) {}

    // Sets position pos to the single value val and recomputes its ancestors.
    void update(int pos, int val) {
        pos += n;
        s[pos] = Node(val);
        for (pos /= 2; pos > 0; pos /= 2)
            s[pos] = f(s[2 * pos], s[2 * pos + 1]);
    }

    // Summary of positions [b, e) in order.
    Node query(int b, int e) {
        Node leftAcc = unit, rightAcc = unit;
        for (b += n, e += n; b < e; b >>= 1, e >>= 1) {
            if (b & 1) leftAcc = f(leftAcc, s[b++]);
            if (e & 1) rightAcc = f(s[--e], rightAcc);
        }
        return f(leftAcc, rightAcc);
    }
};

template <bool VALS_EDGES> struct HLD {
int N, tim = 0;
vector<int> par, siz, depth, rt, pos;
SegTree seg;
rt(N),pos(N){ dfsSz(0); dfsHld(0); seg = SegTree(N);}
void dfsSz(int v) {
for (int& u : adj[v]) {
par[u] = v, depth[u] = depth[v] + 1;
dfsSz(u);
siz[v] += siz[u];
}
}
void dfsHld(int v) {
pos[v] = tim++;
for (int u : adj[v]) {
rt[u] = (u == adj[v][0] ? rt[v] : u);
dfsHld(u);
}
}
template <class B> void process(int u, int v, B op) {
for (; rt[u] != rt[v]; v = par[rt[v]]) {
if (depth[rt[u]] > depth[rt[v]]) swap(u, v);
op(pos[rt[v]], pos[v] + 1);
}
if (depth[u] > depth[v]) swap(u, v);
op(pos[u] + VALS_EDGES, pos[v] + 1);
}
void modifyPath(int u, int v, int val) {
process(u, v, [&](int l, int r) {
if (l < r) seg.update(l, val);
});
}
Node queryPath(int u, int v) { // Modify depending on problem
Node res = unit;
process(u, v, [&](int l, int r) {
auto cur = seg.query(l, r);
swap(cur.prefval, cur.sufval);
swap(cur.preflen, cur.suflen);
res = seg.f(res, cur);
});
return res;
}
};

int main()
{
ios::sync_with_stdio(false); cin.tie(0);

int n; cin >> n;
vector<int> a(n);
for (auto &x : a) cin >> x;
vector<vector<int>> g(n);
for (int i = 0; i < n-1; ++i) {
int u, v; cin >> u >> v;
g[--u].push_back(--v);
g[v].push_back(u);
}
HLD<true> hld(g);
LCA lca(g);
for (int u = 1; u < n; ++u) {
int p = hld.par[u];
hld.modifyPath(u, p, a[u] - a[p]);
}
int q; cin >> q;
while (q--) {
int u, v; cin >> u >> v; --u, --v;
int l = lca.lca(u, v);
auto left = hld.queryPath(u, l);
auto right = hld.queryPath(v, l);
auto actual = left;
if (left.prefval == INT_MAX) actual = right;
else if (right.prefval != INT_MAX) {
right.prefval *= -1;
right.sufval *= -1;
swap(right.prefval, right.sufval);
swap(right.suflen, right.preflen);
actual = SegTree().f(left, right);
}
cout << actual.ans+1 << '\n';
}
}