PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Setter: Vivek Chauhan
Tester: Radoslav Dimitrov
Editorialist: Taranpreet Singh
DIFFICULTY
Easy-Medium
PREREQUISITES
Pigeonhole principle, tree traversal (DFS to compute parent and depth arrays)
PROBLEM
You are given a tree with N nodes, where node i has an associated value A_i. Answer Q queries of the following form.
For given nodes u and v, let S be the set of nodes lying on the simple path from u to v. Find min(|A_x-A_y|) over all pairs x, y \in S such that x \neq y.
Note that 1 \leq A_i \leq maxA for all 1 \leq i \leq N, where maxA = 100.
QUICK EXPLANATION
- If the path from u to v contains more than maxA nodes, at least one value of A_i must occur more than once. Hence, the answer is zero in that case.
- Otherwise, the path has at most maxA nodes, so we can naively visit all of them and compute min(|A_x-A_y|) using a frequency table over values in O(maxA) per query.
EXPLANATION
Let’s first solve this problem naively. For each query, we walk along the path from u to v, add all values to a set, and then compute the answer. Each query can take O(N) time in the worst case (when the tree is a line). This solution is sufficient for subtask 1, but we need to be smarter for subtask 2.
So far, we have ignored the constraint 1 \leq A_i \leq maxA, where maxA = 100.
What does this constraint tell us? The values in the set can take at most maxA distinct values. So, if the path from u to v has more than maxA nodes, at least one value must repeat (pigeonhole principle), and the answer is zero in that case.
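Otherwise, the path contains at most maxA nodes.
Hence, we can still walk along the path from u to v and return zero as soon as we see a repeated value. We do not even need to know the path length in advance: any walk over more than maxA nodes must hit a repeated value within its first maxA + 1 nodes, so every walk takes O(maxA) steps. If no value repeats, we have at most maxA distinct values. A frequency table over values then gives the minimum gap between consecutive present values in O(maxA), which is fast enough for the second subtask.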
Implementation
- Root the tree at any node and compute the depths and parent of each node.
- To walk along the path from u to v, first move the deeper node up (using the parent array) until both nodes have the same depth. Then move both nodes up simultaneously until they meet.
Do you have a solution that does not depend on 1 \leq A_i \leq 100? Comment below.
Bonus:
- Solve the same problem, but now find max(A_x-A_y) over all pairs x, y \in S, where S is the set of nodes on the path from u to v.
- What if 1 \leq A_i \leq 10^6?
- Can you solve this in O(1) per query?
- How about updates on the values of A_i?
TIME COMPLEXITY
The time complexity is O(N + Q \cdot maxA) per test case.
SOLUTIONS
Setter's Solution
#include <bits/stdc++.h>
using namespace std;
// NOTE: despite its name, `ll` is a 32-bit int in this solution, so every
// value stored in an `ll` below must fit in an int.
typedef int ll;
typedef long double ld;
// Capacity of the per-node arrays (indexed by node id).
const ll N = 200005;
char en = '\n';
// "Infinity" sentinel. The previous value 1e16 does not fit in a 32-bit int;
// converting an out-of-range floating value to int is undefined behavior.
ll inf = 1000000000;
ll mod = 1e9 + 7;

// Returns x^n modulo `mod` (result in [0, mod)) by binary exponentiation.
// Intermediate products are computed in long long: two residues below ~1e9
// multiply to ~1e18, which would overflow the 32-bit `ll`.
// A negative x is first reduced into [0, mod); n <= 0 yields 1 % mod.
ll power(ll x, ll n, ll mod) {
    long long base = x % mod;
    if (base < 0)
        base += mod;
    long long res = 1 % mod;
    while (n > 0) {
        if (n & 1)
            res = res * base % mod;
        base = base * base % mod;
        n >>= 1;
    }
    return static_cast<ll>(res);
}
// Node count and query count of the current test case.
ll n, q;
// Adjacency lists of the tree; node ids are used directly as indices.
vector<ll> adj[N];
// arr[v] is the value of node v; depth[v] and parent[v] are written by dfs().
ll arr[N], depth[N], parent[N];
// Roots the tree at `curr`: records `prev1` as the root's parent and `depth1`
// as its depth, then sets parent[] and depth[] (parent depth + 1) for every
// node reachable from it through adj[].
// Uses an explicit stack instead of recursion: on a path-shaped tree with
// ~2e5 nodes, one call frame per level can overflow the call stack.
void dfs(ll curr, ll prev1 = -1, ll depth1 = 0) {
    parent[curr] = prev1;
    depth[curr] = depth1;
    std::vector<ll> pending{curr};
    while (!pending.empty()) {
        const ll u = pending.back();
        pending.pop_back();
        for (const ll x : adj[u]) {
            if (x == parent[u])
                continue;  // do not walk back up to the parent
            parent[x] = u;
            depth[x] = depth[u] + 1;
            pending.push_back(x);
        }
    }
}
/*
ll oprn = 0;
long long sumDist = 0;
*/
// Answers one query: the minimum |arr[x] - arr[y]| over distinct nodes x, y on
// the path a..b. Repeatedly lifts whichever endpoint is deeper (ties lift b)
// until both meet, recording every value seen. Returns 0 on the first repeated
// value, which also bounds the walk to about 100 steps for values in [1, 100].
// For a single-node path the sentinel 105 is returned.
ll solve(ll a, ll b) {
    bool seen[105] = {};
    // Marks node's value as seen; reports true if it had already been seen.
    auto mark = [&](ll node) {
        bool &slot = seen[arr[node]];
        if (slot)
            return true;
        slot = true;
        return false;
    };
    while (a != b) {
        ll &deeper = (depth[a] > depth[b]) ? a : b;
        if (mark(deeper))
            return 0;
        deeper = parent[deeper];
    }
    if (mark(a))  // the meeting node (LCA) itself
        return 0;
    // Smallest gap between consecutive present values; the -200 start keeps the
    // first gap above the 105 sentinel so it never wins.
    ll best = 105;
    ll last = -200;
    for (ll v = 1; v <= 100; ++v) {
        if (!seen[v])
            continue;
        best = min(best, v - last);
        last = v;
    }
    return best;
}
/*
// dont declare n in main,clear adj,as init doesnt do it,call init()
ll dp[N][20];
void dfs2(ll curr, ll prev1 = -1, ll depth1 = 0) {
dp[curr][0] = prev1;
for (ll x : adj[curr]) {
if (x != prev1) {
dfs2(x, curr, depth1 + 1);
}
}
}
ll findLca(ll a, ll b) {
if (a == b)
return a;
if (depth[a] < depth[b])
swap(a, b);
ll rem = depth[a] - depth[b];
for (ll i = 19; i >= 0; i--) {
if (rem & (1 << i))
a = dp[a][i];
}
if (a == b)
return a;
for (ll i = 19; i >= 0; i--) {
if (dp[a][i] != dp[b][i])
a = dp[a][i], b = dp[b][i];
}
return dp[a][0];
}
ll distance(ll a, ll b) {
ll lc = findLca(a, b);
ll res = depth[a] + depth[b] - 2 * depth[lc];
res++;
return res;
}
void init(ll n) {
memset(dp, -1, sizeof(dp));
dfs2(1);
for (ll i = 1; i < 20; i++) {
for (ll j = 1; j <= n; j++) {
if (dp[j][i - 1] != -1)
dp[j][i] = dp[dp[j][i - 1]][i - 1];
}
}
}
*/
int32_t main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
ll t;
cin >> t;
while (t--) {
cin >> n >> q;
for (ll i = 1; i <= n; i++) {
cin >> arr[i];
adj[i].clear();
}
for (ll i = 1; i < n; i++) {
ll x, y;
cin >> x >> y;
adj[x].push_back(y);
adj[y].push_back(x);
}
dfs(1);
/*
oprn = 0;
sumDist = 0;
*/
while (q--) {
ll a, b;
cin >> a >> b;
/*
sumDist += distance(a, b);
solve(a, b);
*/
cout << solve(a, b) << en;
}
}
return 0;
}
Tester's Solution
#include <bits/stdc++.h>
// Shorthand macros used throughout the tester's solution.
// NOTE: every later `endl` token becomes a plain '\n', so `cout << x << endl`
// does not flush; `std::endl` cannot be written after this point.
#define endl '\n'
// Arguments are parenthesized so that expressions such as SZ(*p) or ALL(*p)
// expand correctly (`*p.size()` would otherwise parse as `*(p.size())`).
#define SZ(x) ((int)(x).size())
#define ALL(V) (V).begin(), (V).end()
#define L_B lower_bound
#define U_B upper_bound
#define pb push_back
using namespace std;
// chkmin: if x > y, assigns x = y. Returns 1 when x was changed, 0 otherwise.
template<class T, class T1> int chkmin(T &x, const T1 &y)
{
    if (!(x > y))
        return 0;
    x = y;
    return 1;
}
// chkmax: if x < y, assigns x = y. Returns 1 when x was changed, 0 otherwise.
template<class T, class T1> int chkmax(T &x, const T1 &y)
{
    if (!(x < y))
        return 0;
    x = y;
    return 1;
}
// Capacity of the per-node arrays (2^20) and number of sparse-table levels.
const int MAXN = (1 << 20);
const int MAXLOG = 20;
// Sparse table over pair<int, int> answering range-minimum queries in O(1)
// after O(n log n) preprocessing. Pairs compare lexicographically, so for
// (depth, node) entries the minimum is the shallowest node in the range.
struct sparse_table_pair
{
    // dp[i][j] = min of a[i .. i + 2^j - 1]; filled only where that range lies inside [0, n).
    pair<int, int> dp[MAXN << 1][MAXLOG];
    // prec_lg2[len] = floor(log2(len)) for 1 <= len <= n.
    int prec_lg2[MAXN << 1], n;
    sparse_table_pair() { }
    // Builds the table for `a`. Requires a.size() < (MAXN << 1) and a.size() < 2^MAXLOG.
    void init(vector<pair<int, int> > &a)
    {
        n = a.size();
        // Set explicitly instead of relying on zero-initialized static storage.
        prec_lg2[0] = prec_lg2[1] = 0;
        // Query lengths never exceed n, so the log table is only needed up to n.
        for(int i = 2; i <= n; i++) prec_lg2[i] = prec_lg2[i >> 1] + 1;
        for(int i = 0; i < n; i++) dp[i][0] = a[i];
        for(int j = 1; (1 << j) <= n; j++)
            // Only windows [i, i + 2^j) that fit inside the array; both halves
            // are then valid entries of level j - 1.
            for(int i = 0; i + (1 << j) <= n; i++)
                dp[i][j] = min(dp[i][j - 1], dp[i + (1 << (j - 1))][j - 1]);
    }
    // Minimum over the inclusive index range [l, r]; requires 0 <= l <= r < n.
    pair<int, int> query(int l, int r)
    {
        int k = prec_lg2[r - l + 1];
        return min(dp[l][k], dp[r - (1 << k) + 1][k]);
    }
};
struct LCA
{
int dep[MAXN];
int pos[MAXN], par[MAXN];
sparse_table_pair rmq;
vector<pair<int, int> > order;
vector<int> adj[MAXN];
void add_edge(int u, int v)
{
adj[u].pb(v);
adj[v].pb(u);
}
void pre_dfs(int u, int pr = -1, int d = 0)
{
pos[u] = SZ(order);
dep[u] = d;
order.pb({d, u});
for(int v: adj[u])
if(v != pr)
{
par[v] = u;
pre_dfs(v, u, d + 1);
order.pb({d, u});
}
}
void clear(int n)
{
order.clear();
for(int i = 0; i <= n; i++)
adj[i].clear();
}
void init(int root)
{
order.clear();
pre_dfs(root);
rmq.init(order);
}
int lca(int u, int v)
{
if(pos[u] > pos[v]) swap(u, v);
return rmq.query(pos[u], pos[v]).second;
}
int dist(int u, int v) { return dep[u] + dep[v] - 2 * dep[lca(u, v)]; }
};
// Node count and query count of the current test case.
int n, q;
// a[v] is the value of node v (1-indexed).
int a[MAXN];
// Shared LCA structure; it embeds arrays sized by MAXN, so it lives at namespace scope.
LCA T;
// Reads one test case: n, q, the n node values, and n - 1 edges into T.
void read() {
    cin >> n >> q;
    for(int idx = 1; idx <= n; ++idx) cin >> a[idx];
    T.clear(n);
    for(int e = 1; e < n; ++e) {
        int from, to;
        cin >> from >> to;
        T.add_edge(from, to);
    }
}
// Brute force for short paths. Collects the value of every node on the u..v
// path by climbing T.par[] from both endpoints to their LCA, sorts the values,
// and returns the smallest gap between neighbours (0 if a value repeats).
int linear_solve(int u, int v) {
    const int meet = T.lca(u, v);
    vector<int> vals;
    for(int x = u; x != meet; x = T.par[x]) vals.pb(a[x]);
    for(int x = v; x != meet; x = T.par[x]) vals.pb(a[x]);
    vals.pb(a[meet]);
    // We can use radix/counting sort here
    sort(ALL(vals));
    int best = 101;
    int prevVal = -(int)1e9;  // sentinel: the first gap is huge and never chosen
    for(const int cur : vals) {
        chkmin(best, cur - prevVal);
        prevVal = cur;
    }
    return best;
}
void solve() {
T.init(1);
for(int i = 1; i <= q; i++) {
int u, v;
cin >> u >> v;
if(T.dist(u, v) >= 100) {
cout << 0 << endl;
} else {
cout << linear_solve(u, v) << endl;
}
}
}
// Entry point: reads the number of test cases and processes each one.
// The counter is named `tests` so it does not shadow the global `LCA T`.
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    int tests;
    cin >> tests;
    while(tests--) {
        read();
        solve();
    }
    return 0;
}
Editorialist's Solution
import java.util.*;
import java.io.*;
import java.text.*;
class TREDIF{
    //SOLUTION BEGIN
    /** Precomputation hook run once before all test cases; nothing is needed here. */
    void pre() throws Exception{}

    /**
     * Solves one test case. Reads N, Q, the N node values and the N-1 edges
     * (1-indexed in the input, stored 0-indexed), roots the tree at node 0,
     * then prints one answer per query.
     */
    void solve(int TC) throws Exception{
        int N = ni(), Q = ni();
        int[] A = new int[N];
        for(int i = 0; i< N; i++)A[i] = ni();
        int[] from = new int[N-1], to = new int[N-1];
        for(int i = 0; i< N-1; i++){
            from[i] = ni()-1;
            to[i] = ni()-1;
        }
        int[][] g = make(N, from, to);
        int[] D = new int[N], par = new int[N];
        dfs(g, D, par, 0, -1);
        // Java evaluates arguments left to right, so u is read before v.
        while(Q-->0)pn(query(A, D, par, ni()-1, ni()-1));
    }

    /** Frequency table indexed by value (indices 0..100); cleared at the start of every query. */
    int[] F = new int[101];

    /**
     * Returns min |A[x] - A[y]| over distinct nodes x, y on the path between u and v.
     * The deeper endpoint is lifted first, then both climb together until they meet.
     * Each value is recorded in F, and 0 is returned on the first repeat. With values
     * in [1, 100], a repeat appears within 101 visited nodes, so a query is O(100).
     * If u == v (one node, no pair), the sentinel 200 is returned.
     *
     * @param A node values
     * @param D node depths
     * @param P parent of each node (-1 for the root)
     */
    int query(int[] A, int[] D, int[] P, int u, int v){
        if(D[v] < D[u]){int tmp = u;u = v; v = tmp;}
        Arrays.fill(F, 0);
        while(D[v] > D[u]){
            if(F[A[v]] > 0)return 0;
            F[A[v]]++;
            v = P[v];
        }
        while(u != v){
            if(F[A[u]] > 0)return 0;
            F[A[u]]++;
            if(F[A[v]] > 0)return 0;
            F[A[v]]++;
            u = P[u];
            v = P[v];
        }
        if(F[A[v]] > 0)return 0;
        F[A[v]]++;
        // Smallest gap between consecutive values present in F.
        int prev = -1, ans = 200;
        for(int i = 1; i<= 100; i++){
            if(F[i] > 0){
                if(prev != -1)ans = Math.min(ans, i-prev);
                prev = i;
            }
        }
        return ans;
    }

    /**
     * Roots the tree at u, recording p as u's parent. Sets par[] for every node
     * reachable from u, and d[x] = d[par[x]] + 1 for every such node except u
     * (d[u] is left unchanged). Uses an explicit stack instead of recursion:
     * a path-shaped tree with ~2*10^5 nodes would overflow Java's default thread stack.
     */
    void dfs(int[][] g, int[] d, int[] par, int u, int p){
        par[u] = p;
        // In a tree each node is pushed at most once, so N slots suffice.
        int[] stack = new int[g.length];
        int top = 0;
        stack[top++] = u;
        while(top > 0){
            int x = stack[--top];
            for(int v:g[x]){
                if(v != par[x]){
                    par[v] = x;
                    d[v] = d[x]+1;
                    stack[top++] = v;
                }
            }
        }
    }

    /**
     * Builds an adjacency array for an undirected graph with N nodes and N-1 edges
     * (from[i], to[i]). It counts degrees first, then fills each row from the back.
     */
    int[][] make(int N, int[] from, int[] to){
        int[] cnt = new int[N];int[][] g = new int[N][];
        for(int i:from)cnt[i]++;for(int i:to)cnt[i]++;
        for(int i = 0; i< N; i++)g[i] = new int[cnt[i]];
        for(int i = 0; i< N-1; i++){
            g[from[i]][--cnt[from[i]]] = to[i];
            g[to[i]][--cnt[to[i]]] = from[i];
        }
        return g;
    }
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    DecimalFormat df = new DecimalFormat("0.00000000000");
    static boolean multipleTC = true;
    FastReader in;PrintWriter out;
    /** Sets up I/O, reads the test count (if multipleTC), and solves every test case. */
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        new TREDIF().run();
    }
    /** Number of set bits in n. */
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}
    /** Whitespace-token reader over a BufferedReader (stdin or a file). */
    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }
        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }
        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }
        String nextLine() throws Exception{
            String str = "";
            try{
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }
            return str;
        }
    }
}
Feel free to share your approach. Suggestions are welcomed as always.