TREEQUER - Editorial

PROBLEM LINK:

Practice
Div-1 Contest

Author: Jatin Yadav
Tester: gamegame
Editorialist: Jatin Yadav

DIFFICULTY:

3575

PREREQUISITES:

Trees, DFS

PROBLEM:

You are given a tree with N vertices, and an array of start times in its DFS. You have to process Q queries, each consisting of a set of disjoint intervals T = [L_1, R_1] \cup [L_2, R_2] \cup \ldots \cup [L_k, R_k], and for each one you have to report the number of connected components in the subgraph induced by the set S of vertices whose start times lie in T.

EXPLANATION:

Since in all subtasks the bound on N is the same as the bound on the sum of k, we'll loosely use N to denote the problem size everywhere. Also, for simplicity, we will assume that the vertices are indexed by their start times.

Let's denote by p_i the parent of i, and by en_i the end time of node i in the DFS (equivalently, the greatest node in the subtree of i).


Subtask 1

The required number of connected components equals the number of vertices in S minus the number of edges with both endpoints in S: the subgraph induced by S is a forest, and a forest with V vertices and E edges has V - E components. Counting the vertices is easy. The number of edges equals the number of vertices v \in S for which p_v \in S. Both can be found in O(N) per query by iterating over all the vertices.
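A minimal sketch of this count, assuming (as above) that vertices are numbered 1..n by start time, that `par[v]` holds p_v with `par[1] = 0` for the root, and that intervals are disjoint, 1-indexed and inclusive. The name `componentsBrute` is ours, not from the setter's code.

```cpp
#include <bits/stdc++.h>
using namespace std;

// Subtask 1 sketch: O(n) per query.
// Assumptions: vertices are 1..n in DFS start-time order, par[v] is the parent
// of v (par[1] = 0 for the root), intervals are disjoint, 1-indexed, inclusive.
long long componentsBrute(int n, const vector<int>& par,
                          const vector<pair<int, int>>& intervals) {
    vector<char> inS(n + 1, 0);
    long long vertices = 0;
    for (auto [l, r] : intervals)
        for (int v = l; v <= r; v++) {
            inS[v] = 1;
            vertices++;
        }
    // The tree edge (p_v, v) lies inside S exactly when both v and p_v are in S.
    long long edges = 0;
    for (int v = 2; v <= n; v++)
        if (inS[v] && inS[par[v]]) edges++;
    // S induces a forest, so #components = #vertices - #edges.
    return vertices - edges;
}
```

For example, in the tree with parents 1→{2, 5}, 2→{3, 4}, 5→{6, 7}, the set [2, 3] ∪ [6, 7] splits into {2, 3}, {6} and {7}.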


Subtask 2

Consider one query [L_1, R_1], [L_2, R_2], \ldots, [L_k, R_k]. We will use square root decomposition on k. Fix a value B (to be decided later). If k > B, just run the solution of subtask 1. Since the sum of k is at most N, there are at most N/B such queries, so this takes O \left( \dfrac{N^2}{B} \right) in total. If k \leq B, we use a different approach.

We need to count the v with v \in S and p_v \in S. Number the intervals from left to right. Since p_v < v, the interval containing p_v is never to the right of the one containing v, so this is equivalent to iterating over all pairs 1 \leq i \leq j \leq k and counting the v with v \in [L_j, R_j] and p_v \in [L_i, R_i]. That count equals the number of v with v \leq R_j and p_v \in [L_i, R_i], minus the number of v with v < L_j and p_v \in [L_i, R_i].
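The same counting as a single formula, with the intervals numbered from left to right; the answer subtracts it from |S|:

```latex
\mathrm{edges}(S) = \sum_{1 \le i \le j \le k}
  \Big( \#\{ v \le R_j : p_v \in [L_i, R_i] \} - \#\{ v < L_j : p_v \in [L_i, R_i] \} \Big),
\qquad
\mathrm{answer} = \sum_{j=1}^{k} (R_j - L_j + 1) - \mathrm{edges}(S)
```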

To do this, we process the queries offline (that is, not in the order they are given). For each query q and each interval [L_j, R_j] of it, store (j, q, 1) at index R_j and (j, q, -1) at index L_j - 1.

Now iterate over v from 1 to N, and for every v, add 1 at index p_v in a Fenwick tree. Suppose we are currently at v = x, so the tree holds p_v for every v \leq x. For each value (j, q, z) stored at index x, iterate over all intervals [L_i, R_i] of the q-th query and add z \times (the sum over [L_i, R_i] in the tree) to the edge count of query q. Intervals with i > j contribute nothing here, since every p_v in the tree is smaller than x \leq R_j < L_i.
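A minimal sketch of this sweep, under the same assumptions as before (1-indexed start-time order, `par[1] = 0`, each query given as sorted disjoint intervals). The names `edgesOffline` and `events` are ours. It stores only (q, z) per event and drops j, because, as noted above, intervals to the right of [L_j, R_j] add nothing when its events are processed.

```cpp
#include <bits/stdc++.h>
using namespace std;

// Offline sweep for the k <= B case. Returns, for each query, the number of tree
// edges with both endpoints in S. Assumes 1-indexed start-time order, par[1] = 0,
// and each query given as sorted, disjoint, inclusive intervals.
vector<long long> edgesOffline(int n, const vector<int>& par,
                               const vector<vector<pair<int, int>>>& queries) {
    vector<long long> bit(n + 1, 0);  // Fenwick tree indexed by the parent p_v
    auto add = [&](int i) {
        for (; i <= n; i += i & -i) bit[i]++;
    };
    auto prefix = [&](int i) {
        long long s = 0;
        for (; i > 0; i -= i & -i) s += bit[i];
        return s;
    };
    // events[x] holds (q, z): at time x, add z * (#p_v in each interval of q).
    vector<vector<pair<int, int>>> events(n + 1);
    for (int q = 0; q < (int)queries.size(); q++)
        for (auto [l, r] : queries[q]) {
            events[r].push_back({q, +1});
            events[l - 1].push_back({q, -1});  // events[0] would add 0, so it is never read
        }
    vector<long long> edges(queries.size(), 0);
    for (int x = 1; x <= n; x++) {
        if (x >= 2) add(par[x]);  // now the tree holds p_v for every v <= x
        for (auto [q, z] : events[x])
            for (auto [l, r] : queries[q])
                edges[q] += z * (prefix(r) - prefix(l - 1));
    }
    return edges;
}
```

The answer to query q is then its number of vertices minus `edges[q]`.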

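The sweep above processes O(N) stored triples, each in O(B \log N) time (at most B intervals, one Fenwick query each), so it takes O(N B \log N) time. Together with the O(N^2/B) part, our overall complexity is O(N \sqrt{N \log N}) if we choose B = \sqrt{\frac{N}{\log N}}, which easily passes subtask 2. But we can improve this further.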

Notice that we did O(N) point updates and O(N B) range queries on the Fenwick tree, and both operations take O(\log N) time. Can we make queries faster at the cost of slower updates, for a better total complexity? There is a simple way. Divide the range [1, N] into \sqrt{N} blocks. Maintain the prefix sums inside each block (the sum of the first j values of the block), and also the prefix sums of blocks (the sum of the first i blocks, for each i \leq \sqrt N). A point update at position i only changes the in-block prefix sums of i's block (say b) and the block prefix sums after block b, so it takes O(\sqrt N). A prefix sum up to position i is the sum of the blocks before b plus an in-block prefix sum of block b, so it takes O(1). This leads to a total complexity of O(N \sqrt N + \frac{N^2}{B} + N B), which is O(N \sqrt N) for B = \sqrt N.
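A sketch of this block structure, with our own names (`BlockPrefix`, `inBlock`, `blockPref`). Positions are 1..n, position i goes to block i / B with B = ⌊√n⌋, and position 0 is unused. Since the sweep above only needs point increments and prefix sums, it could replace the Fenwick tree there.

```cpp
#include <bits/stdc++.h>
using namespace std;

// Point update in O(sqrt n), prefix sum in O(1), over positions 1..n.
// Position i belongs to block i / B (position 0 is unused).
struct BlockPrefix {
    int n, B;
    vector<long long> inBlock;    // inBlock[i] = a[first position of i's block .. i]
    vector<long long> blockPref;  // blockPref[b] = sum of all blocks before block b

    explicit BlockPrefix(int n_) : n(n_) {
        B = max(1, (int)sqrt((double)n));
        inBlock.assign(n + 1, 0);
        blockPref.assign(n / B + 2, 0);
    }
    void add(int i, long long v) {  // a[i] += v
        int b = i / B;
        int last = min(n, (b + 1) * B - 1);
        for (int j = i; j <= last; j++) inBlock[j] += v;                        // O(B)
        for (int c = b + 1; c < (int)blockPref.size(); c++) blockPref[c] += v;  // O(n / B)
    }
    long long prefix(int i) const {  // a[1] + ... + a[i], for 0 <= i <= n
        if (i <= 0) return 0;
        return blockPref[i / B] + inBlock[i];
    }
};
```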


Subtasks 3 and 4

Let us first solve the case k = 1, where we want the number of connected components for a single interval [L, R]. Define J_v = en_v + 1: the node the DFS visits right after finishing the subtree of v (equivalently, the smallest node greater than v that is outside the subtree of v). Also, define J_{n+1} = n + 1. Note that J_v is either a sibling of v, some node whose parent is an ancestor of v, or n + 1.

Consider the sequence l_0 = L, l_1 = J_{l_0}, l_2 = J_{l_1}, \ldots, and let c be the smallest index for which l_c > R. We claim that the answer is c. For j < c - 1, the range [l_j, l_{j+1}) is exactly the subtree rooted at l_j, so it is connected. The last range [l_{c-1}, R] is also connected, since every v \in (l_{c-1}, R] lies in the subtree of l_{c-1}, so p_v \in [l_{c-1}, R]. Finally, for i < j there is no edge between [l_i, l_{i+1}) and [l_j, l_{j+1}): the smaller endpoint of such an edge is the parent and lies in the subtree of l_i, so the larger endpoint would lie in that subtree too, i.e. be at most en_{l_i} < l_{i+1} \leq l_j.
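A naive O(c) sketch of this claim, assuming the start-time numbering with `par[1] = 0`; `computeJ` and `componentsOneIntervalNaive` are our names. It derives en_v with a reverse sweep, which works because every descendant of v has a larger index than v.

```cpp
#include <bits/stdc++.h>
using namespace std;

// J_v = en_v + 1 for v in 1..n, and J_{n+1} = n + 1 (index 0 is unused).
// Assumes 1-indexed start-time order with par[1] = 0 and par[v] < v.
vector<int> computeJ(int n, const vector<int>& par) {
    vector<int> en(n + 1);
    for (int v = 1; v <= n; v++) en[v] = v;
    // Descendants of v have larger indices, so en[v] is final when the sweep reaches v.
    for (int v = n; v >= 2; v--) en[par[v]] = max(en[par[v]], en[v]);
    vector<int> J(n + 2, 0);
    for (int v = 1; v <= n; v++) J[v] = en[v] + 1;
    J[n + 1] = n + 1;
    return J;
}

// Components of the single interval [L, R]: the number c of elements
// l_0 = L, l_1 = J_{l_0}, ... that are <= R. Takes O(c) time.
int componentsOneIntervalNaive(const vector<int>& J, int L, int R) {
    int c = 0;
    for (int x = L; x <= R; x = J[x]) c++;
    return c;
}
```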

So, we could just iterate over this sequence, but that takes O(N) per query in the worst case (for example, in a star rooted at vertex 1, J_i = i + 1 for all i \neq 1). Instead, we store jump pointers: J_v^{(i)} is the node reached from v after 2^i applications of J, so J_v^{(0)} = J_v and J_v^{(i)} = J^{(i-1)}_{J^{(i-1)}_v}. Now we can jump in powers of 2 to the last element of the sequence that is \leq R, counting the jumps on the way, which yields an O(\log N) per query algorithm.
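A sketch of the jump pointers for the k = 1 case; `JumpTable` is a hypothetical name (the setter's code below keeps the same table as `jump[i][v]`, 0-indexed). It counts l_0 and then greedily takes the largest jumps that stay \leq R.

```cpp
#include <bits/stdc++.h>
using namespace std;

// Binary lifting over J for k = 1 queries. Assumes 1-indexed start-time order,
// par[1] = 0, par[v] < v. up[i][v] = J applied 2^i times to v.
struct JumpTable {
    int n, LOG;
    vector<vector<int>> up;

    JumpTable(int n_, const vector<int>& par) : n(n_) {
        vector<int> en(n + 1);
        for (int v = 1; v <= n; v++) en[v] = v;
        for (int v = n; v >= 2; v--) en[par[v]] = max(en[par[v]], en[v]);
        vector<int> J(n + 2, 0);
        for (int v = 1; v <= n; v++) J[v] = en[v] + 1;
        J[n + 1] = n + 1;  // fixed point, so jumping past the end is harmless
        LOG = 1;
        while ((1 << LOG) <= n) LOG++;  // 2^LOG > n, more than any chain needs
        up.assign(LOG, J);
        for (int i = 1; i < LOG; i++)
            for (int v = 0; v <= n + 1; v++) up[i][v] = up[i - 1][up[i - 1][v]];
    }

    // Components of [L, R] for 1 <= L <= R <= n: one for l_0 = L, plus every
    // further chain element <= R, found by greedy jumps from the largest power.
    int components(int L, int R) const {
        int x = L, c = 1;
        for (int i = LOG - 1; i >= 0; i--)
            if (up[i][x] <= R) {
                x = up[i][x];
                c += 1 << i;
            }
        return c;
    }
};
```

On the star rooted at 1 mentioned above, `components(2, n)` returns n - 1 after only O(\log n) jumps.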

Now, let's solve for k > 1. Consider two disjoint intervals [L_i, R_i] and [L_j, R_j] with i < j, and let X be the sequence l_0 = L_j, l_1 = J_{l_0}, l_2 = J_{l_1}, \ldots, as in the k = 1 case. Notice that any v \in [L_j, R_j] with p_v \in [L_i, R_i] must be in X: otherwise v lies strictly inside the subtree of some l_t \geq L_j, which forces p_v \geq l_t \geq L_j. Also, p_v is an ancestor of L_j (a proper one, since p_v \leq R_i < L_j).

Let's break the line [1, n] into at most 2k + 1 alternating intervals (one interval belongs to S, the next doesn't, and so on; gaps that would be empty are dropped). For example, if n = 12 and the query has intervals [2, 3], [6, 7], [9, 9], we get [1, 1], [2, 3], [4, 5], [6, 7], [8, 8], [9, 9], [10, 12].

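From here on, [L_i, R_i] denotes the i-th of these alternating intervals. The parents of the nodes l_0, l_1, \ldots of X are non-increasing (p_{l_{t+1}} is a proper ancestor of l_t, hence an ancestor of p_{l_t}), so the indices of the intervals containing them are non-increasing too. Say we are currently at a node v of the sequence whose parent lies in [L_i, R_i] (where 1 \leq i \leq 2k + 1). Then we can find the first node of the sequence whose parent lies in some [L_r, R_r] with r < i in O(\log N) time using binary jumping. This way, we only visit those i for which some tree edge connects [L_i, R_i] and [L_j, R_j]. Also, if [L_i, R_i] \subseteq S, each node of the sequence whose parent lies in [L_i, R_i] contributes one edge inside S, so we add their number (obtained from the jumps) to the edge count. Equivalently, as the setter's code does, one can count the nodes of S whose parent is outside S: each component has exactly one such node, its topmost vertex.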
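A 1-indexed sketch of the whole query, following the same idea as the setter's getFast below: it counts the vertices of S whose parent lies outside S. The struct name `Solver` and its layout are ours; it assumes sorted, disjoint, inclusive intervals (adjacent ones are merged first).

```cpp
#include <bits/stdc++.h>
using namespace std;

// Sketch of a k >= 1 query. Assumes 1-indexed start-time order, par[1] = 0,
// par[v] < v. Counts vertices of S whose parent is outside S (one per component).
struct Solver {
    int n, LOG;
    vector<int> par;
    vector<vector<int>> up;  // up[i][v] = J applied 2^i times; J[n+1] = n+1

    Solver(int n_, const vector<int>& p) : n(n_), par(p) {
        par.resize(n + 2, 0);
        vector<int> en(n + 1), J(n + 2, 0);
        for (int v = 1; v <= n; v++) en[v] = v;
        for (int v = n; v >= 2; v--) en[par[v]] = max(en[par[v]], en[v]);
        for (int v = 1; v <= n; v++) J[v] = en[v] + 1;
        J[n + 1] = n + 1;
        LOG = 1;
        while ((1 << LOG) <= n) LOG++;
        up.assign(LOG, J);
        for (int i = 1; i < LOG; i++)
            for (int v = 0; v <= n + 1; v++) up[i][v] = up[i - 1][up[i - 1][v]];
    }

    long long query(const vector<pair<int, int>>& iv) const {
        // b = [L_1, R_1 + 1, L_2, R_2 + 1, ...]: boundaries of the alternating intervals.
        vector<int> b;
        for (auto [l, r] : iv) {
            if (!b.empty() && b.back() == l) b.back() = r + 1;  // merge adjacent intervals
            else { b.push_back(l); b.push_back(r + 1); }
        }
        long long ans = 0;
        for (size_t s = 0; s < b.size(); s += 2) {
            int l = b[s], r = b[s + 1] - 1;
            // Walk X = l, J_l, ... inside [l, r], one parent-interval per iteration.
            for (int x = l; x <= r; x = up[0][x]) {
                // pos = index of the last boundary <= par[x]; even means par[x] is in S.
                int pos = int(upper_bound(b.begin(), b.end(), par[x]) - b.begin()) - 1;
                int take = (pos < 0 || pos % 2 == 1) ? 1 : 0;
                int lo = pos >= 0 ? b[pos] : -1;  // lower end of par[x]'s interval
                ans += take;
                // Parents along X are non-increasing, so "z <= r and par[z] >= lo"
                // holds on a prefix of the chain: binary-jump to its last element.
                for (int i = LOG - 1; i >= 0; i--) {
                    int z = up[i][x];
                    if (z <= r && par[z] >= lo) {
                        ans += (long long)take << i;
                        x = z;
                    }
                }
            }
        }
        return ans;
    }
};
```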

Consider a graph G whose nodes are the 2k + 1 intervals, with an edge (i, j) if some tree edge connects a vertex in [L_i, R_i] to a vertex in [L_j, R_j]. Let E be the edge set of G. Each step of the procedure above either follows an edge of G or starts a new interval of S, so a query takes O((k + |E|) \log N) time. To bound the number of edges, we note the following property:

Non-crossing property: For any p < q < r < s, either (p, r) \notin E or (q, s) \notin E.

In other words, G has no crossing edges. Suppose otherwise, and let a \in [L_p, R_p], b \in [L_q, R_q], c \in [L_r, R_r], d \in [L_s, R_s] be such that there is a tree edge between a and c, and one between b and d. Clearly a < b < c < d. Since nodes are indexed by start times, a parent always has a smaller index than its child, so c is a child of a and d is a child of b. The subtree of a is the contiguous range [a, en_a] and contains c, so a < b < c puts b in the subtree of a. Now c is not in the subtree of b (otherwise its parent a would be too, but a < b), and b < c, so the DFS finishes the whole subtree of b before visiting c. Since d is a child of b, this means d < c, a contradiction.

Now we prove that the number of edges is O(k). There are multiple ways to prove this (for example, induction). A simple one: place the 2k + 1 nodes on a circle in order and draw each edge of G as a chord. By the non-crossing property no two chords intersect, so G is planar (in fact outerplanar), and an outerplanar graph on V \geq 2 nodes has at most 2V - 3 edges, which gives |E| \leq 4k - 1.
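A brute-force sanity check of both claims, not part of the solution. The proof above never uses which intervals belong to S, so this sketch takes any split of [1, n] into consecutive blocks numbered 0, 1, \ldots, V - 1, builds E, looks for a crossing pair and checks |E| \leq 2V - 3. The name `checkIntervalGraph` is ours.

```cpp
#include <bits/stdc++.h>
using namespace std;

// Checks, for one tree and one partition of [1, n] into consecutive blocks, that
// the interval graph G has no crossing edges and at most 2V - 3 edges.
// Assumes par[v] < v (start-time order, par[1] = 0) and that blockOf[v] is
// non-decreasing in v with blocks numbered 0..V-1 (blockOf[1] = 0).
bool checkIntervalGraph(int n, const vector<int>& par, const vector<int>& blockOf) {
    set<pair<int, int>> E;
    for (int v = 2; v <= n; v++) {
        int a = blockOf[par[v]], c = blockOf[v];  // par[v] < v, so a <= c
        if (a != c) E.insert({a, c});
    }
    for (auto [p, r] : E)
        for (auto [q, s] : E)
            if (p < q && q < r && r < s) return false;  // (p, r) and (q, s) cross
    int V = blockOf[n] + 1;
    return V < 2 ? E.empty() : (int)E.size() <= 2 * V - 3;
}
```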

Therefore, each query is solved in O(k \log N), for a total complexity of O(N \log N), including the O(N \log N) precomputation of the jump pointers.

SOLUTIONS

Setter's Solution
#include <bits/stdc++.h>
using namespace std;

#define ll long long
#define pii pair<int, int>
#define all(c) ((c).begin()), ((c).end())
#define sz(x) ((int)(x).size())

#ifdef LOCAL
#include <print.h>
#else
#define trace(...)
#endif

const int LOGN = 20;
const int N = 5e5 + 10;
int jump[LOGN][N];
int st[N], par[N];
void precompute(vector<vector<int>>& adj){
    int n = adj.size();
    const int LOG = __lg(n) + 1;
    fill(par, par + N, -1); 
    for(int i = 0; i < LOG; i++) fill(jump[i], jump[i] + N, n);
    jump[0][n] = n;

    int timer = 0;
    for(int i = 0; i < n; i++) sort(all(adj[i]));
    // after calling dfs, indexing is start time based
    function<void(int, int)> dfs = [&](int s, int p){
        sort(all(adj[s]));
        st[s] = timer++;
        for(int v: adj[s]) if(v != p){
            par[timer] = st[s]; 
            dfs(v, s);
        }
        jump[0][st[s]] = timer;
    };

    dfs(0, 0);
    
    for(int i = 1; i < LOG; i++){
        for(int j = 0; j < n; j++){
            jump[i][j] = jump[i - 1][jump[i - 1][j]];
        }
    }
}
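// Returns the number of vertices of S whose parent is outside S; each component
// of the induced forest has exactly one such vertex (its topmost one).
int getFast(int n, const vector<pair<int, int>>& intervals){
    // L = [l_0, r_0 + 1, l_1, r_1 + 1, ...], with adjacent intervals merged
    vector<int> L;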
    pii top = intervals[0];
    for(int i = 1; i < sz(intervals); i++){
        auto it = intervals[i];
        if(top.second != it.first - 1){
            L.push_back(top.first);
            L.push_back(top.second + 1);
            top = it;
        }
        else top.second = it.second;
    }
    L.push_back(top.first); L.push_back(top.second + 1);
    int ans = 0;
    const int LOG = __lg(n) + 1;
    for(int i = 0; i < sz(L); i += 2){
        int l = L[i], r = L[i + 1] - 1;
        int x = l;
        while(x <= r){
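            // pos = index of the last boundary <= par[x]; -1 or odd means par[x] is outside S
            int pos = (upper_bound(all(L), par[x]) - L.begin()) - 1;
            int take = (pos + 2) % 2;
            int y = pos == -1 ? -2 : L[pos];   // lower end of the interval holding par[x]
            ans += take;
            // jump along J while par[z] >= y (parent stays in that interval) and z <= r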
            for(int u = LOG - 1; u >= 0; u--){
                int z = jump[u][x];
                if(par[z] >= y && z <= r){
                    ans += take << u;
                    x = z;
                }
            }

            x = jump[0][x];
        }
    }
    return ans;
}

int main(){
    ios_base::sync_with_stdio(false); 
    cin.tie(NULL); // Remove in interactive problems   
    
    int n, q, b; cin >> n >> q >> b;
    vector<vector<int>> adj(n);
    for(int i = 1; i < n; i++){
        int u, v; cin >> u >> v;
        u--; v--;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    precompute(adj);
    int ans = 0;
    while(q--){
        int l; cin >> l;
        vector<pair<int, int>> intervals(l);
        for(int j = 0; j < l; j++){
            cin >> intervals[j].first >> intervals[j].second;
            intervals[j].first ^= (b * ans);
            intervals[j].second ^= (b * ans);
            intervals[j].first--; intervals[j].second--;
        }
        
        cout << (ans = getFast(n, intervals)) << endl;
    }
}
Tester's Solution
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define fi first
#define se second
const int N=5e5+1;
int n,q;
vector<int>adj[N];
int par[N],h[N];
int st[N],en[N],od[N];
int lg[N],sp[N][19];
int ptr=0;

vector<int>ch[N];
void dfs(int id,int p){
	sort(adj[id].begin(),adj[id].end());
	st[id]=++ptr;od[ptr]=id;
	par[id]=p;h[id]=h[p]+1;
	ch[p].push_back(id);
	for(auto c:adj[id]){
		if(c==p) continue;
		dfs(c,id);
	}
	en[id]=ptr;
}
int boss(int l,int r){
	int k=lg[r-l+1];
	int lc=sp[l][k],rc=sp[r-(1<<k)+1][k];
	return (h[lc]<h[rc])?lc:rc;
}
int lca(int l,int r){
	int id=boss(l,r);
	return par[id];
}
int k;
int ql[N],qr[N],qm[N];
int ans[N];
vector<pair<int,int> >qs[N];


int rt[N];
int sz;
const int ts=1e7;
int s[ts];
int lc[ts],rc[ts];
int upd(int id,int l,int r,int p,int v){
	if(l==r){
		s[++sz]=s[id]+1;return sz;
	}
	int mid=(l+r)/2;
	int cur=++sz;
	lc[cur]=lc[id];rc[cur]=rc[id];
	if(p<=mid) lc[cur]=upd(lc[id],l,mid,p,v);
	else rc[cur]=upd(rc[id],mid+1,r,p,v);
	s[cur]=s[lc[cur]]+s[rc[cur]];
	return cur;
}
int qry(int id,int l,int r,int ql,int qr){
	if(id==0) return 0;
	if(l>qr || r<ql) return 0;
	if(ql<=l && r<=qr) return s[id];
	int mid=(l+r)/2;
	return qry(lc[id],l,mid,ql,qr)+qry(rc[id],mid+1,r,ql,qr);
}
void stamp(int id,int pf,int sf,int qt){
	if(sf-1>1) ans[qt]-=qry(rt[sf-1],1,n,pf+1,sf-1);
	ans[qt]+=qry(rt[en[id]],1,n,pf+1,sf-1);
}
void deal(int bm,int bl,int br,int qt){
	int king=qm[bm];
	bool dad=bl!=0 && (st[king]<=qr[bl]);
	int cl,cr;
	{
		int l=0,r=ch[king].size()-1;
		while(l!=r){
			int mid=(l+r+1)/2;
			if(st[ch[king][mid]]>ql[bm]) r=mid-1;
			else l=mid;
		}
		cl=l;
		if(st[ch[king][l]]==ql[bm]) cl--;
	}
	{
		int l=0,r=ch[king].size()-1;
		while(l!=r){
			int mid=(l+r+1)/2;
			if(st[ch[king][mid]]>qr[bm]) r=mid-1;
			else l=mid;
		}
		cr=l;
	}
	if(!dad) ans[qt]+=cr-cl;
	if(cl!=-1){
		int tl=ch[king][cl];
		if(bl!=0) stamp(tl,max(st[tl]-1,qr[bl]),ql[bm],qt);
		else stamp(tl,st[tl]-1,ql[bm],qt);
	}
	{
		int tr=ch[king][cr];
		if(br!=0 && ql[br]<=en[tr]){
			stamp(tr,qr[bm],ql[br],qt);
			ql[br]=en[tr]+1;//should not change qm but i might be clown
		}
		else ;//stamp(tr,qr[bm],en[tr]+1,qt);
	}
}
int bit[N];
void upd(int id,int v){
	for(int i=id; i<=n ;i+=i&-i) bit[i]+=v;
}
int qry(int id){
	int res=0;
	for(int i=id; i>=1 ;i-=i&-i) res+=bit[i];
	return res;
}
int main(){
	ios::sync_with_stdio(false);cin.tie(0);
	int b;
	cin >> n >> q >> b;
	for(int i=1; i<n ;i++){
		int u,v;cin >> u >> v;
		adj[u].push_back(v);
		adj[v].push_back(u);
		lg[i+1]=lg[(i+1)/2]+1;
	}
	dfs(1,0);
	for(int i=1; i<=n ;i++) sp[i][0]=od[i];
	for(int j=1; j<=lg[n] ;j++){
		for(int i=1; i+(1<<j)<=n+1 ;i++){
			int lc=sp[i][j-1],rc=sp[i+(1<<(j-1))][j-1];
			sp[i][j]=(h[lc]<h[rc])?lc:rc;
		}
	}
	rt[1]=0;
	for(int i=2; i<=n ;i++){
		int x=od[i];
		rt[i]=upd(rt[i-1],1,n,st[par[x]],1);
	}
	for(int i=1; i<=q ;i++){
		cin >> k;
		stack<int>s;
		s.push(0);
		for(int j=1; j<=k ;j++){
			cin >> ql[j] >> qr[j];
			ql[j]^=b*ans[i-1];
			qr[j]^=b*ans[i-1];
			qm[j]=lca(ql[j],qr[j]);
			while(s.top()!=0 && h[qm[s.top()]]>=h[qm[j]]){
				int x=s.top();s.pop();
				deal(x,s.top(),j,i);
			}
			s.push(j);
		}
		while(s.top()!=0){
			int x=s.top();s.pop();
			deal(x,s.top(),0,i);
		}
		cout << ans[i] << '\n';
	}
}

I solved this question partially in the contest. I was given points for it in the ranklist, but my total score and rank did not get updated. I request you to look into this. My Submission

Why is the total complexity of the first solution of Subtask 2 O(NB log N)? "Iterate over all intervals" needs O(B log N), and there are O(NB) triples (j, q, x).

As mentioned at the beginning of the editorial, N is assumed to be the upper bound on both the number of vertices and the sum of the number of intervals over all queries. So there are only O(N) stored triples, each processed in O(B log N).