PROBLEM LINKS:
Practice
Contest Link
Author and Editorialist: Abhishek Garg
DIFFICULTY:
Medium-Hard
PREREQUISITES:
Lowest Common Ancestor, Heavy-Light Decomposition, Merge-sort tree
PROBLEM:
You are given a tree with a distinct value attached to every node. You need to answer queries of the form U V L R: for the two given nodes U and V, report the number of nodes on the simple path between U and V whose values lie in the range [L, R] (inclusive). The problem must be solved online.
QUICK EXPLANATION:
Using Heavy-Light decomposition, we can split the tree into disjoint paths so that any node v can be reached from the root by traversing at most O(log n) of these paths. The answer for each of these paths can then be computed with a merge-sort tree built over it. Before reading the explanation, make sure you are familiar with the prerequisites.
EXPLANATION:
First, split the query Q : U V L R into 2 parts:
q1 : U V R :- Tell the number of nodes on the simple path from U to V such that their value is less than or equal to R.
q2 : U V L-1 :- Tell the number of nodes on the simple path from U to V such that their value is less than or equal to L-1.
Then the answer to the original query will be Ans_{Q} = Ans_{q1} - Ans_{q2}.
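A brute-force sketch of this reduction on a single path stored as a plain array. The helper names countAtMost and countInRange are hypothetical, not taken from the author's code. The threshold is a long long so that L-1 cannot overflow when L is the smallest int.

```cpp
#include <cassert>
#include <vector>

// Answer of a q-type query on one path: number of values <= x.
// x is long long so that passing L-1 never overflows.
int countAtMost(const std::vector<int>& path, long long x) {
    int cnt = 0;
    for (int v : path)
        if (v <= x) ++cnt;
    return cnt;
}

// Ans_Q = Ans_q1 - Ans_q2: (values <= R) - (values <= L-1) = values in [L, R].
int countInRange(const std::vector<int>& path, long long L, long long R) {
    assert(L <= R);
    return countAtMost(path, R) - countAtMost(path, L - 1);
}
```

For example, on the path values {5, 1, 9, 3, 7}, countInRange(path, 3, 7) returns 4 - 1 = 3 (the values 5, 3 and 7).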
So we are now left with queries of the type q : U V X, i.e. count the nodes on the path from U to V having value less than or equal to X.
The simple path from U to V can be broken down into 2 upward paths: path_{(U, V)} = path_{(U, LCA(U, V))} + path_{(V, LCA(U,V))}. The benefit of using these two paths is that, while traversing each of them from its first node to its second, we only move towards the root. Note that both paths contain LCA(U, V), so it appears twice in this sum. Let Set(A, B) denote the set of all nodes on the path from node A to node B (both inclusive). Then,
Set(U, LCA(U, V)) = Set(root, U) - Set(root, parent_{LCA(U,V)}) and
Set(V, LCA(U, V)) = Set(root, V) - Set(root, parent_{LCA(U,V)})
where root is the arbitrarily chosen root of the tree and parent_{X} is the parent of X, i.e. its first ancestor in the direction of the root. If LCA(U, V) is the root itself, Set(root, parent_{LCA(U,V)}) is taken to be empty.
From the above, the query q : U V X can be broken down into 3 different queries:
que1 : U X :- Tell the number of nodes on the path from root to U having value less than or equal to X.
que2 : V X :- Tell the number of nodes on the path from root to V having value less than or equal to X.
que3 : parent_{LCA(U,V)} X :- Tell the number of nodes on the path from root to parent_{LCA(U,V)} having value less than or equal to X.
Since LCA(U, V) is counted by both que1 and que2, the answer to q is
Ans_{q} = Ans_{que1} + Ans_{que2} - 2*Ans_{que3} - [value_{LCA(U,V)} <= X]
where [value_{LCA(U,V)} <= X] is 1 if the value of LCA(U, V) is at most X and 0 otherwise. The author's solution applies the same correction by subtracting 1 when the value of the LCA lies in [L, R].
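To sanity-check this formula, here is a brute-force sketch. It assumes the tree is given as a parent array (par[root] = -1) together with node depths; the struct and its member names are hypothetical. pathCountDirect walks the path explicitly, while pathCountFormula evaluates Ans_{que1} + Ans_{que2} - 2*Ans_{que3} - [value_{LCA(U,V)} <= X], with que(-1, X) = 0 standing for the empty set above the root.

```cpp
#include <cassert>
#include <vector>

// Hypothetical brute-force model of a rooted tree: par[root] == -1,
// depth[root] == 0, val[i] is the value attached to node i.
struct BruteTree {
    std::vector<int> par, depth, val;

    // que(node, x): nodes on the path root..node with value <= x.
    // que(-1, x) == 0 models Set(root, parent_{root}) being empty.
    int que(int node, int x) const {
        int cnt = 0;
        for (; node != -1; node = par[node])
            cnt += (val[node] <= x);
        return cnt;
    }

    // Naive LCA: lift the deeper node, then lift both nodes together.
    int lca(int u, int v) const {
        assert(u >= 0 && v >= 0);
        while (depth[u] > depth[v]) u = par[u];
        while (depth[v] > depth[u]) v = par[v];
        while (u != v) { u = par[u]; v = par[v]; }
        return u;
    }

    // Walk the path explicitly, counting the LCA exactly once.
    int pathCountDirect(int u, int v, int x) const {
        int w = lca(u, v), cnt = 0;
        for (; u != w; u = par[u]) cnt += (val[u] <= x);
        for (; v != w; v = par[v]) cnt += (val[v] <= x);
        return cnt + (val[w] <= x);
    }

    // Ans_q = Ans_que1 + Ans_que2 - 2*Ans_que3 - [value_LCA <= X].
    int pathCountFormula(int u, int v, int x) const {
        int w = lca(u, v);
        return que(u, x) + que(v, x) - 2 * que(par[w], x) - (val[w] <= x ? 1 : 0);
    }
};
```

For example, with par = {-1, 0, 0, 1, 1, 2}, depth = {0, 1, 1, 2, 2, 2} and val = {4, 2, 7, 5, 1, 3}, pathCountFormula(3, 5, 4) returns 2 + 2 - 0 - 1 = 3, matching the nodes 1, 0 and 5 (values 2, 4 and 3) on the path from 3 to 5.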
Now, we are left with a simpler query que : Node Value i.e. tell the number of nodes on the path from root to Node having value less than or equal to Value.
To answer this query, we split the tree into disjoint paths such that any node v can be reached from the root by traversing at most O(log n) of them. Heavy-Light decomposition does exactly this (see the prerequisites). Each of these paths can then be treated as an array, and for each one we need the number of elements having value less than or equal to Value.
This is done with another data structure, the Merge-sort tree (again, see the prerequisites). It answers such a query on any contiguous range in O(log^2 n) time: the range splits into O(log n) tree nodes, and the sorted list stored in each of them is binary searched. So, in total, we traverse O(log n) paths from the root to Node and spend O(log^2 n) time per path counting the nodes with value less than or equal to Value.
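A minimal, self-contained merge-sort tree sketch in the same spirit as the author's build/que functions. Unlike the author's code, it uses 1-based child indices (2k and 2k+1 instead of 2u+1 and 2u+2), and the class and method names are hypothetical. Each node keeps the sorted values of its segment, and countAtMost binary searches every fully covered node with std::upper_bound.

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>
#include <vector>

// Node k stores the sorted values of its segment; children are 2k and 2k+1.
class MergeSortTree {
public:
    explicit MergeSortTree(const std::vector<int>& a)
        : n_(static_cast<int>(a.size())), node_(4 * std::max(1, n_)) {
        if (n_ > 0) build(1, 0, n_ - 1, a);
    }

    // Number of indices i in [l, r] with a[i] <= x, in O(log^2 n).
    int countAtMost(int l, int r, int x) const {
        assert(0 <= l && l <= r && r < n_);
        return query(1, 0, n_ - 1, l, r, x);
    }

private:
    int n_;
    std::vector<std::vector<int>> node_;

    void build(int k, int s, int e, const std::vector<int>& a) {
        if (s == e) { node_[k] = {a[s]}; return; }
        int m = (s + e) / 2;
        build(2 * k, s, m, a);
        build(2 * k + 1, m + 1, e, a);
        // Merging the sorted children costs O(n) per level, O(n log n) total.
        std::merge(node_[2 * k].begin(), node_[2 * k].end(),
                   node_[2 * k + 1].begin(), node_[2 * k + 1].end(),
                   std::back_inserter(node_[k]));
    }

    int query(int k, int s, int e, int l, int r, int x) const {
        if (e < l || r < s) return 0;
        if (l <= s && e <= r)  // fully covered: binary search the sorted list
            return static_cast<int>(
                std::upper_bound(node_[k].begin(), node_[k].end(), x) - node_[k].begin());
        int m = (s + e) / 2;
        return query(2 * k, s, m, l, r, x) + query(2 * k + 1, m + 1, e, l, r, x);
    }
};
```

For a = {5, 1, 9, 3, 7, 2}, MergeSortTree(a).countAtMost(0, 5, 4) returns 3 (the values 1, 3 and 2).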
Thus, a single query takes O(log^3 n) time in the worst case.
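The root-to-Node walk can be sketched as follows, assuming a tree rooted at node 0 and given as adjacency lists; the HLD struct and queRootPath are hypothetical names. rootPathSegments returns the contiguous position ranges that cover the path from the root to a node, one per heavy chain visited, so there are O(log n) of them. For brevity, queRootPath scans each range linearly; in the actual solution each range would instead be answered by a merge-sort tree count in O(log^2 n).

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical minimal heavy-light decomposition of a tree rooted at node 0.
struct HLD {
    std::vector<std::vector<int>> adj;
    std::vector<int> par, heavy, head, pos, sz;
    int cur = 0;

    explicit HLD(const std::vector<std::vector<int>>& g)
        : adj(g), par(g.size(), -1), heavy(g.size(), -1),
          head(g.size()), pos(g.size()), sz(g.size(), 1) {
        dfsSize(0, -1);
        decompose(0, 0);
    }

    // Subtree sizes; the heavy child is the child with the largest subtree.
    void dfsSize(int u, int p) {
        par[u] = p;
        int best = 0;
        for (int v : adj[u]) {
            if (v == p) continue;
            dfsSize(v, u);
            sz[u] += sz[v];
            if (sz[v] > best) { best = sz[v]; heavy[u] = v; }
        }
    }

    // Heavy child first, so every heavy chain gets consecutive positions.
    void decompose(int u, int h) {
        head[u] = h;
        pos[u] = cur++;
        if (heavy[u] != -1) decompose(heavy[u], h);
        for (int v : adj[u])
            if (v != par[u] && v != heavy[u]) decompose(v, v);
    }

    // Position ranges [l, r] covering root..v, one per chain: O(log n) ranges.
    std::vector<std::pair<int, int>> rootPathSegments(int v) const {
        std::vector<std::pair<int, int>> segs;
        while (v != -1) {
            segs.emplace_back(pos[head[v]], pos[v]);
            v = par[head[v]];
        }
        return segs;
    }
};

// que(node, x) on the root..node path; base[pos[u]] holds the value of node u.
// Linear scan per range for brevity (a merge-sort tree would replace it).
int queRootPath(const HLD& h, const std::vector<int>& base, int node, int x) {
    assert(base.size() == h.pos.size());
    int cnt = 0;
    for (const auto& seg : h.rootPathSegments(node))
        for (int i = seg.first; i <= seg.second; ++i) cnt += (base[i] <= x);
    return cnt;
}
```

Here base[pos[u]] plays the role of the aux array in the author's solution. For example, on the tree with edges 0-1, 0-2, 1-3, 1-4 and 2-5, the path from the root to node 5 is covered by 2 ranges: the chain containing nodes 2 and 5, and the chain starting at the root.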
The tree can be decomposed in O(n) (the binary-lifting table used for the LCA in the author's solution adds O(n log n)).
The merge-sort tree can be built in O(n log n).
Hence, the total time complexity of the algorithm is O(n log n + q log^3 n).
SOLUTION:
Author Solution
```cpp
#include <bits/stdc++.h>
using namespace std;
#define pb push_back
#define fill(arr,x) memset(arr,x,sizeof(arr))
#define all(x) x.begin(),x.end()
#define sz(x) (int)x.size()
#define lb lower_bound
#define ub upper_bound
#define mex 100005
#define lgn 18

vector<int> adj[mex];                                   // adjacency lists
int in[mex],out[mex],par[mex],sub[mex],heavy[mex],tym;  // Euler times, parent, subtree size, heavy child
int dp[mex][lgn],pos[mex],head[mex],curr;               // binary lifting table, HLD position, chain head
vector<int> tree[5*mex];                                // merge-sort tree: sorted values of each segment

// Computes parent, binary-lifting ancestors, Euler in/out times, subtree size
// and heavy child of every node in the subtree of u; returns sub[u].
inline int dfs(int u,int p)
{
    par[u]=dp[u][0]=p;
    in[u]=++tym;
    sub[u]++;
    int max_val=-1;
    for(int i=1;i<lgn;i++) dp[u][i]=dp[dp[u][i-1]][i-1];
    for(int i=0;i<sz(adj[u]);i++){
        int v=adj[u][i];
        if(v==p) continue;
        int x=dfs(v,u);
        if(x>max_val){
            max_val=x;
            heavy[u]=v;
        }
        sub[u]+=x;
    }
    out[u]=tym;
    return sub[u];
}

// True if u is an ancestor of v (or u == v).
inline bool check(int u,int v)
{
    return (in[u]<=in[v] && out[u]>=out[v]);
}

// Lowest common ancestor by binary lifting.
inline int lca(int u,int v)
{
    if(check(u,v)) return u;
    if(check(v,u)) return v;
    for(int i=lgn-1;i>=0;i--){
        if(!check(dp[u][i],v)) u=dp[u][i];
    }
    return dp[u][0];
}

// Heavy-light decomposition: the heavy child is visited first, so every
// heavy chain occupies a contiguous range of positions.
inline void decompose(int u,int h)
{
    pos[u]=curr++;
    head[u]=h;
    if(heavy[u]!=-1) decompose(heavy[u],h);
    for(int i=0;i<sz(adj[u]);i++){
        int v=adj[u][i];
        if(v==par[u] || v==heavy[u]) continue;
        decompose(v,v);
    }
}

// Builds the merge-sort tree over aux[s..e]; node u stores its segment sorted.
inline vector<int> build(int u,int s,int e,int aux[])
{
    if(s==e){
        tree[u].pb(aux[s]);
        return tree[u];
    }
    int m=(s+e)/2;
    vector<int> x=build(2*u+1,s,m,aux);
    vector<int> y=build(2*u+2,m+1,e,aux);
    vector<int> v(sz(x)+sz(y));
    merge(all(x),all(y),v.begin());
    return tree[u]=v;
}

// Number of values <= x among positions [l, r]: O(log^2 n).
inline int que(int u,int s,int e,int l,int r,int x)
{
    if(s>r || e<l) return 0;
    if(s>=l && e<=r) return distance(tree[u].begin(),ub(all(tree[u]),x));
    int m=(s+e)/2;
    return que(2*u+1,s,m,l,r,x)+que(2*u+2,m+1,e,l,r,x);
}

// Number of nodes with value in [l, r] on the path from u down to v,
// where u must be an ancestor of v. Walks up chain by chain from v.
inline int query(int u,int v,int l,int r,int n)
{
    int x=0,y=0;
    while(head[u]!=head[v]){
        x+=que(0,0,n-1,pos[head[v]],pos[v],l-1);
        y+=que(0,0,n-1,pos[head[v]],pos[v],r);
        v=par[head[v]];
    }
    x+=que(0,0,n-1,pos[u],pos[v],l-1);
    y+=que(0,0,n-1,pos[u],pos[v],r);
    return y-x;
}

int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(NULL); cout.tie(NULL);
    int n,q; cin>>n>>q;
    vector<int> arr(n);
    for(int i=0;i<n;i++) cin>>arr[i];
    for(int i=0;i<n-1;i++){
        int u,v; cin>>u>>v;
        u--; v--;
        adj[u].pb(v);
        adj[v].pb(u);
    }
    fill(heavy,-1);
    dfs(0,0);
    decompose(0,0);
    // aux[pos[i]] = value of node i, so each heavy chain is a contiguous range.
    vector<int> aux(n);
    for(int i=0;i<n;i++) aux[pos[i]]=arr[i];
    build(0,0,n-1,aux.data());
    int ans=0;
    while(q--){
        int u,v,l,r;
        cin>>u>>v>>l>>r;
        // Online decoding with the previous answer.
        u=1+((u^ans)%n);
        v=1+((v^ans)%n);
        u--; v--;
        int x=lca(u,v);
        int a=query(x,u,l,r,n),b=query(x,v,l,r,n);
        // Both halves include the LCA, so remove it once if it is in [l, r].
        if(arr[x]>=l && arr[x]<=r) ans=a+b-1;
        else ans=a+b;
        cout<<ans<<'\n';
    }
}
```
Bonus: try solving this problem offline (assuming that U and V are given directly, so you don't need to XOR them with the previous answer) without using HLD or a Merge-sort tree.
If you have any doubts regarding the problem, feel free to ask them in the comments