# TREDIFF - Editorial

Setter: Vivek Chauhan
Editorialist: Taranpreet Singh

# DIFFICULTY

Easy-Medium

# PREREQUISITES

Pigeonhole principle

# PROBLEM

You are given a tree with N nodes, where node i has an associated value A_i. Answer Q queries of the following form.
For given nodes u and v, let S be the set of nodes lying on the simple path from u to v. Find min(|A_x-A_y|) over all pairs x, y \in S such that x \neq y.

Note that 1 \leq A_i \leq maxA for 1 \leq i \leq N, where maxA = 100.

# QUICK EXPLANATION

• If the path from u to v contains more than maxA nodes, at least one value of A_i must occur more than once. Hence, the answer is zero in that case.
• Otherwise, we can naively visit all nodes on the path from u to v and find min(|A_x-A_y|) by using a frequency table in O(maxA) per query.

# EXPLANATION

Let’s solve this problem naively. For each query, we walk along the path from u to v, add all values to a set, and then compute the answer. Each query can take O(N) time in the worst case (a linear tree, i.e., a single long path). This solution is sufficient for subtask 1, but we need to be smarter than this for subtask 2.

So far, we have ignored the constraint 1 \leq A_i \leq maxA, where maxA = 100.

What does this constraint tell us? The values in the set can take at most maxA distinct values. So, if the path from u to v has more than maxA nodes, at least one value must repeat (pigeonhole principle), and the answer is zero in that case.

Otherwise, the path contains at most maxA nodes, so walking over it costs only O(maxA) for that query.
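To spell out the pigeonhole step, here is a short derivation. The symbol d(u, v), the number of edges on the path from u to v, is introduced here only for illustration and does not appear in the problem statement.

```latex
\begin{align*}
|S| &= d(u, v) + 1, \qquad A_x \in \{1, \dots, maxA\} \ \text{for all } x \in S \\
|S| > maxA &\implies \exists\, x \neq y \in S : A_x = A_y \\
           &\implies \min_{x \neq y} |A_x - A_y| = 0
\end{align*}
```

In terms of edges, a query can therefore be answered with 0 whenever d(u, v) \geq maxA.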

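Hence, we can still walk along the path from u to v, and as soon as we see a repeated value, we return zero. Otherwise, we end up with at most maxA distinct values. The minimum difference is always attained by two values that are adjacent in sorted order, so scanning a frequency table in increasing order of value and taking the smallest gap between consecutive present values answers the query in O(maxA), which is fast enough for the second subtask.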
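The frequency-table scan described above can be sketched as follows. This is a minimal illustration, not the setter's code: the function name `minPairDifference` and the convention of returning `maxA` when fewer than two values are given are assumptions made for this example.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical helper (name chosen for this sketch): smallest |a - b| over all
// pairs of entries in `values`, assuming every entry lies in [1, maxA].
// Returns 0 as soon as a value repeats, and maxA if no pair exists.
int minPairDifference(const std::vector<int>& values, int maxA = 100) {
    std::vector<int> freq(maxA + 1, 0);
    for (int v : values) {
        assert(1 <= v && v <= maxA);   // precondition from the problem statement
        if (++freq[v] > 1) return 0;   // repeated value, difference is zero
    }
    int best = maxA;  // any real difference is at most maxA - 1
    int prev = -1;    // previous present value in increasing order, -1 if none
    for (int v = 1; v <= maxA; ++v) {
        if (freq[v] == 0) continue;
        if (prev != -1) best = std::min(best, v - prev);
        prev = v;
    }
    return best;
}
```

For instance, the values {10, 3, 7} are seen by the scan as 3, 7, 10, so the gaps are 4 and 3 and the result is 3.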

Implementation

• Root the tree at any node and compute the depths and parent of each node.
• For moving on the path from u to v, first move the deeper node up (using the parent array) until both nodes have the same depth. Then move both nodes up simultaneously until they meet at the same node, which is the LCA of u and v.
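The two steps above can be sketched as a single loop that always lifts the deeper endpoint, which is a slight variation of the bullet description (on a depth tie either endpoint may be lifted). The function name `collectPath`, its parameters, and the early-stop `limit` are assumptions made for this sketch; `parent` and `depth` are assumed to come from a DFS from the root.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical helper: returns the nodes on the simple path between u and v
// (in no particular order). It stops early once more than `limit` nodes have
// been collected, so a result larger than `limit` means "path too long".
std::vector<int> collectPath(int u, int v,
                             const std::vector<int>& parent,
                             const std::vector<int>& depth,
                             std::size_t limit) {
    assert(parent.size() == depth.size());
    std::vector<int> path;
    while (u != v && path.size() <= limit) {
        // Keep u as the deeper endpoint; it cannot be the meeting point yet.
        if (depth[u] < depth[v]) std::swap(u, v);
        path.push_back(u);
        u = parent[u];
    }
    // If we did not stop early, u == v is the meeting point (the LCA).
    if (path.size() <= limit) path.push_back(u);
    return path;
}
```

With `limit` set to maxA, a caller could answer 0 whenever the returned vector has more than maxA entries and otherwise look up the values of the returned nodes.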

Do you have any other solution not depending upon 1 \leq A_i \leq 100? Comment below.

Bonus:

• Solve the same problem, except that you now need to find max(A_x-A_y) over all pairs x, y \in S, where S is the set of nodes on the path from u to v.
• What if 1 \leq A_i \leq 10^6?
• Can you solve this in O(1) per query?

# TIME COMPLEXITY

The time complexity is O(N+Q*maxA) per test case.

# SOLUTIONS

Setter's Solution
#include <bits/stdc++.h>
using namespace std;
typedef int ll;
typedef long double ld;
const ll N = 200005;
char en = '\n';
ll inf = 1e9; // ll is a 32-bit int here, so 1e16 would not fit
ll mod = 1e9 + 7;
ll power(ll x, ll n, ll mod) {
ll res = 1;
x %= mod;
while (n) {
if (n & 1)
res = (res * x) % mod;
x = (x * x) % mod;
n >>= 1;
}
return res;
}
ll n, q;
ll arr[N];
ll depth[N];
ll parent[N];
vector<ll> adj[N]; // adjacency list of the tree

void dfs(ll curr, ll prev1 = -1, ll depth1 = 0) {
parent[curr] = prev1;
depth[curr] = depth1;

for (ll &x : adj[curr]) {
if (x != prev1) {
dfs(x, curr, depth1 + 1);
}
}
}
/*
ll oprn = 0;
long long sumDist = 0;
*/
ll solve(ll a, ll b) {
ll freq[105];
memset(freq, 0, sizeof(freq));
while (a != b) {
// oprn++;
if (depth[a] > depth[b]) {
freq[arr[a]]++;
if (freq[arr[a]] > 1)
return 0;
a = parent[a];
} else {
freq[arr[b]]++;
if (freq[arr[b]] > 1)
return 0;
b = parent[b];
}
}
freq[arr[a]]++;
if (freq[arr[a]] > 1)
return 0;

ll prev1 = -200;
ll res = 105;
for (ll i = 1; i <= 100; i++) {
if (freq[i]) {
res = min(res, i - prev1);
prev1 = i;
}
}
return res;
}

/*
// dont declare n in main,clear adj,as init doesnt do it,call init()
ll dp[N][20];
void dfs2(ll curr, ll prev1 = -1, ll depth1 = 0) {
dp[curr][0] = prev1;

for (ll x : adj[curr]) {
if (x != prev1) {
dfs2(x, curr, depth1 + 1);
}
}
}
ll findLca(ll a, ll b) {
if (a == b)
return a;
if (depth[a] < depth[b])
swap(a, b);
ll rem = depth[a] - depth[b];
for (ll i = 19; i >= 0; i--) {
if (rem & (1 << i))
a = dp[a][i];
}

if (a == b)
return a;

for (ll i = 19; i >= 0; i--) {
if (dp[a][i] != dp[b][i])
a = dp[a][i], b = dp[b][i];
}
return dp[a][0];
}

ll distance(ll a, ll b) {
ll lc = findLca(a, b);
ll res = depth[a] + depth[b] - 2 * depth[lc];
res++;
return res;
}

void init(ll n) {
memset(dp, -1, sizeof(dp));
dfs2(1);
for (ll i = 1; i < 20; i++) {
for (ll j = 1; j <= n; j++) {
if (dp[j][i - 1] != -1)
dp[j][i] = dp[dp[j][i - 1]][i - 1];
}
}
}
*/
int32_t main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);

ll t;
cin >> t;
while (t--) {

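cin >> n >> q;
for (ll i = 1; i <= n; i++) {
cin >> arr[i];
adj[i].clear(); // reset the adjacency list between test cases
}

for (ll i = 1; i < n; i++) {
ll x, y;
cin >> x >> y;
// store the undirected edge in both directions
adj[x].push_back(y);
adj[y].push_back(x);
}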
dfs(1);
/*
oprn = 0;
sumDist = 0;
*/
while (q--) {
ll a, b;
cin >> a >> b;
/*
sumDist += distance(a, b);
solve(a, b);
*/
cout << solve(a, b) << en;
}
}

return 0;
}

Tester's Solution
#include <bits/stdc++.h>
#define endl '\n'

#define SZ(x) ((int)x.size())
#define ALL(V) V.begin(), V.end()
#define L_B lower_bound
#define U_B upper_bound
#define pb push_back

using namespace std;
template<class T, class T1> int chkmin(T &x, const T1 &y) { return x > y ? x = y, 1 : 0; }
template<class T, class T1> int chkmax(T &x, const T1 &y) { return x < y ? x = y, 1 : 0; }
const int MAXN = (1 << 20);
const int MAXLOG = 20;

struct sparse_table_pair
{
pair<int, int> dp[MAXN << 1][MAXLOG];
int prec_lg2[MAXN << 1], n;

sparse_table_pair() { }

void init(vector<pair<int, int> > &a)
{
n = a.size();
for(int i = 2; i < 2 * n; i++) prec_lg2[i] = prec_lg2[i >> 1] + 1;
for(int i = 0; i < n; i++) dp[i][0] = a[i];
for(int j = 1; (1 << j) <= n; j++)
for(int i = 0; i < n; i++)
dp[i][j] = min(dp[i][j - 1], dp[i + (1 << (j - 1))][j - 1]);
}

pair<int, int> query(int l, int r)
{
int k = prec_lg2[r - l + 1];
return min(dp[l][k], dp[r - (1 << k) + 1][k]);
}
};

struct LCA
{
int dep[MAXN];
int pos[MAXN], par[MAXN];

sparse_table_pair rmq;
vector<pair<int, int> > order;
vector<int> adj[MAXN];

void add_edge(int u, int v)
{
adj[u].pb(v);
adj[v].pb(u);
}

void pre_dfs(int u, int pr = -1, int d = 0)
{
pos[u] = SZ(order);
dep[u] = d;
order.pb({d, u});

for(int v: adj[u])
if(v != pr)
{
par[v] = u;
pre_dfs(v, u, d + 1);
order.pb({d, u});
}
}

void clear(int n)
{
order.clear();
for(int i = 0; i <= n; i++)
adj[i].clear();
}

void init(int root)
{
order.clear();
pre_dfs(root);
rmq.init(order);
}

int lca(int u, int v)
{
if(pos[u] > pos[v]) swap(u, v);
return rmq.query(pos[u], pos[v]).second;
}

int dist(int u, int v) { return dep[u] + dep[v] - 2 * dep[lca(u, v)]; }
};

int n, q, a[MAXN];
LCA T;

void read() {
cin >> n >> q;
for(int i = 1; i <= n; i++) {
cin >> a[i];
}

T.clear(n);
for(int i = 0; i < n - 1; i++) {
int u, v;
cin >> u >> v;
T.add_edge(u, v);
}
}

int linear_solve(int u, int v) {
int anc = T.lca(u, v);

vector<int> vals;
while(u != anc) {
vals.pb(a[u]);
u = T.par[u];
}

while(v != anc) {
vals.pb(a[v]);
v = T.par[v];
}

vals.pb(a[anc]);

// We can use radix/counting sort here
sort(ALL(vals));

int mn = 101, last = -(int)1e9;
for(int x: vals) {
chkmin(mn, x - last);
last = x;
}

return mn;
}

void solve() {
T.init(1);

for(int i = 1; i <= q; i++) {
int u, v;
cin >> u >> v;
// dist counts edges, so dist >= 100 means at least 101 nodes on the path
if(T.dist(u, v) >= 100) {
cout << 0 << endl;
} else {
cout << linear_solve(u, v) << endl;
}
}
}

int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);

int T;
cin >> T;
while(T--) {
read();
solve();
}

return 0;
}

Editorialist's Solution
import java.util.*;
import java.io.*;
import java.text.*;
class TREDIF{
//SOLUTION BEGIN
void pre() throws Exception{}
void solve(int TC) throws Exception{
int N = ni(), Q = ni();
int[] A = new int[N];
for(int i = 0; i< N; i++)A[i] = ni();
int[] from = new int[N-1], to = new int[N-1];
for(int i = 0; i< N-1; i++){
from[i] = ni()-1;
to[i] = ni()-1;
}
int[][] g = make(N, from, to);
int[] D = new int[N], par = new int[N];
dfs(g, D, par, 0, -1);
while(Q-->0)pn(query(A, D, par, ni()-1, ni()-1));
}
int[] F = new int[101];
int query(int[] A, int[] D, int[] P, int u, int v){
if(D[v] < D[u]){int tmp = u;u = v; v = tmp;}
Arrays.fill(F, 0);
while(D[v] > D[u]){
if(F[A[v]] > 0)return 0;
F[A[v]]++;
v = P[v];
}
while(u != v){
if(F[A[u]] > 0)return 0;
F[A[u]]++;
if(F[A[v]] > 0)return 0;
F[A[v]]++;
u = P[u];
v = P[v];
}
if(F[A[v]] > 0)return 0;
F[A[v]]++;
int prev = -1, ans = 200;
for(int i = 1; i<= 100; i++){
if(F[i] > 0){
if(prev != -1)ans = Math.min(ans, i-prev);
prev = i;
}
}
return ans;
}
void dfs(int[][] g, int[] d, int[] par, int u, int p){
par[u] = p;
for(int v:g[u]){
if(v != p){
d[v] = d[u]+1;
dfs(g, d, par, v, u);
}
}
}
int[][] make(int N, int[] from, int[] to){
int[] cnt = new int[N];int[][] g = new int[N][];
for(int i:from)cnt[i]++;for(int i:to)cnt[i]++;
for(int i = 0; i< N; i++)g[i] = new int[cnt[i]];
for(int i = 0; i< N-1; i++){
g[from[i]][--cnt[from[i]]] = to[i];
g[to[i]][--cnt[to[i]]] = from[i];
}
return g;
}
//SOLUTION END
void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
DecimalFormat df = new DecimalFormat("0.00000000000");
static boolean multipleTC = true;
FastReader in;PrintWriter out;
void run() throws Exception{
in = new FastReader();
out = new PrintWriter(System.out);
//Solution Credits: Taranpreet Singh
int T = (multipleTC)?ni():1;
pre();for(int t = 1; t<= T; t++)solve(t);
out.flush();
out.close();
}
public static void main(String[] args) throws Exception{
new TREDIF().run();
}
int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
void p(Object o){out.print(o);}
void pn(Object o){out.println(o);}
void pni(Object o){out.println(o);out.flush();}
String n()throws Exception{return in.next();}
String nln()throws Exception{return in.nextLine();}
int ni()throws Exception{return Integer.parseInt(in.next());}
long nl()throws Exception{return Long.parseLong(in.next());}
double nd()throws Exception{return Double.parseDouble(in.next());}

class FastReader{
BufferedReader br;
StringTokenizer st;
public FastReader(){
br = new BufferedReader(new InputStreamReader(System.in));
}

public FastReader(String s) throws Exception{
br = new BufferedReader(new FileReader(s));
}

String next() throws Exception{
while (st == null || !st.hasMoreElements()){
try{
st = new StringTokenizer(br.readLine());
}catch (IOException  e){
throw new Exception(e.toString());
}
}
return st.nextToken();
}

String nextLine() throws Exception{
String str = "";
try{
str = br.readLine();
}catch (IOException e){
throw new Exception(e.toString());
}
return str;
}
}
}


Feel free to share your approach. Suggestions are welcomed as always.

38 Likes

Problem TREE DIFFERENCE was the same as this problem: https://www.codechef.com/problems/CLOSEFAR . Could you please look into this, as its solution is just a Google search away? Also, your bonus question is included in that earlier problem. Questions should not be repeated in any sense. @taran_1407 @admin

6 Likes

@codingisfood The problem in the link you shared has a solution in Q*sqrt(N)*log(N), which cannot pass under these constraints. Secondly, the idea behind the main solution is completely different… In this problem, the solution relies entirely on the constraints on A.

6 Likes

I would really love to see if someone can discuss the approach for bonus part. It looks like a challenging task.

5 Likes

The two problems appear similar, but the intended solutions are really different. I don’t think the intended solution for CLOSEFAR can even fit the time limit.

4 Likes


After seeing the solution I feel like I am so dumb. I had seen a similar type of problem on Codeforces, yet I wasn’t able to solve it. Now I will remember it. Btw @taran_1407 thanks for the editorial. Your editorials are always so good.

4 Likes

My solution involved using LCA to find the path length and if it was less than maxA then traverse from a to b and find the required answer.

1 Like

Another similar problem is well framed by NIT Patna Codecube Team in their contest

1 Like

Can anyone tell me why my solution is giving a WA for a simple BFS per query? Couldn’t even score 30 points.

https://www.codechef.com/viewsolution/33507872

My solution for this problem.
The complexity is supposed to be O(100) per query. Why am I getting TLE? Can anyone help me out? Was this problem set so that this solution shouldn’t pass or should I optimize the code?

2 Likes

See, its nothing to get offended about. I was just presenting my thoughts. You could have been more polite with this

4 Likes

Great Editorial! If anyone knows some similar problems from CC, CF,SPOJ, please attach them in this thread. It would be helpful for everyone. Thanks!

1 Like

I wrote this code with the complexity O(n^2logn) and should work completely fine for 30 points.
But it gives WA. Can someone please tell where I went wrong, I am so upset.
I actually cried during the contest as I was unable to find the mistake.
Also my code isn’t ugly so please do take a look at it.
Thanks!!
https://www.codechef.com/viewsolution/33511662

same story with me! I’ve just posted my link above yours.

Alternate solution:

Observe that dist(a,b) = (dist(a) - dist(lca(a,b))) + (dist(b) - dist(lca(a,b))), where dist(x) is the distance of x from the root.

If this quantity is greater than or equal to 100, the answer is 0.

Otherwise, find the simple path from a to b and calculate the answer

dist(a,b) can be calculated in O(logN) with LCA using Binary Lifting Technique

6 Likes

Yeah I was. Until you started posting everywhere and then used your account @spam_123 to upvote those comments.

6 Likes