QRYLAND - EDITORIAL

PROBLEM LINK:

Practice

Contest: Division 1

Contest: Division 2

Setter: Mladen Puzic

Tester: Michael Nematollahi

Editorialist: Taranpreet Singh

DIFFICULTY:

PREREQUISITES:

Mo’s algorithm on tree or Heavy-Light Decomposition, Randomization.

PROBLEM:

You are given a tree with N nodes, where a value is assigned to each node, given in array A. You have to process the following updates and queries.

  • Update the value assigned to node p to X.
  • Consider all values assigned to the nodes on the path from u to v. Say the number of nodes on this path is L. Print Yes if the values under consideration form a permutation of the natural numbers in the range [1, L], and No otherwise.

EXPLANATION

First of all, assign a random number > 0 to each value from 1 to N, and replace each value x in the array A by the random number assigned to x.

Suppose we need to answer a query (u, v) such that there are L nodes on the path from u to v. Consider the xor of all values on the path from u to v. If the values on the path form a permutation of the first L natural numbers, then this xor is equal to the xor of the random numbers assigned to each i from 1 to L. Conversely, if the values do not form such a permutation, the two xors can coincide only through an unlucky collision of the random numbers, which is very unlikely when they are drawn from a large range. The xor of the numbers assigned to 1 through i can be precomputed as a prefix xor for each i from 1 to N. So the answer to our query is Yes if the xor of the numbers on the path from u to v equals the prefix xor up to position L.
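The check described above can be sketched on a plain list of values, before any tree is involved. This is only an illustration: `PermutationHasher`, `looksLikePermutation` and the fixed seed are hypothetical names and choices, not taken from any of the solutions below, and values outside 1..n are simply rejected here because they have no key.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <random>
#include <vector>

// Hypothetical helper: one random nonzero 64-bit key per value 1..n,
// plus the prefix xor of those keys.
struct PermutationHasher {
    std::vector<std::uint64_t> key;     // key[x] = random key of value x (key[0] unused)
    std::vector<std::uint64_t> prefix;  // prefix[i] = key[1] ^ key[2] ^ ... ^ key[i]

    PermutationHasher(int n, std::uint64_t seed) : key(n + 1, 0), prefix(n + 1, 0) {
        assert(n >= 0);
        std::mt19937_64 rng(seed);
        std::uniform_int_distribution<std::uint64_t> dist(
            1, std::numeric_limits<std::uint64_t>::max());
        for (int x = 1; x <= n; ++x) {
            key[x] = dist(rng);
            prefix[x] = prefix[x - 1] ^ key[x];
        }
    }

    // Compares the xor of the keys of `values` with prefix[values.size()].
    // A "true" answer is correct with high probability, not with certainty.
    bool looksLikePermutation(const std::vector<int>& values) const {
        const std::size_t len = values.size();
        if (len >= prefix.size()) return false;  // more than n values: no prefix to compare with
        std::uint64_t acc = 0;
        for (int x : values) {
            if (x < 1 || static_cast<std::size_t>(x) >= key.size()) return false;  // no key for x
            acc ^= key[x];
        }
        return acc == prefix[len];
    }
};
```

For instance, with `PermutationHasher h(5, 12345)`, the list {3, 1, 2} is accepted, while {1, 2, 2} is rejected unless two keys happen to collide.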

This gives us a way to check whether the values of the nodes on a path form a permutation, provided we can implement a data structure capable of updating the value at a position and finding the xor of the values on a path.

We can use Heavy-Light Decomposition, but there’s a simpler and more efficient solution using the Euler tour idea from Mo’s algorithm on trees.

Since xor is its own inverse, xoring with the same number twice leaves the initial value unchanged.

Consider the sample tree as given below.
[Figure: the sample tree (image "qryland")]

The Euler tour for this Tree shall be

1 2 4 10 10 4 5 6 6 7 7 8 8 5 2 3 9 9 3 1
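As a rough sketch of how such a tour can be produced, the code below writes every node once on entry and once on exit of a DFS. The names `EulerTour`, `buildEulerTour` and `sampleTreeEdges` are hypothetical, and the edge list is inferred from the tour printed above, so treat it as an assumption about the sample tree. Recursion is used for brevity; a very deep tree would need an iterative DFS.

```cpp
#include <cassert>
#include <utility>
#include <vector>

struct EulerTour {
    std::vector<int> order;   // each node appears twice: on entry and on exit
    std::vector<int> st, en;  // 0-based positions of the two occurrences of each node
};

// Recursive DFS that appends u on entry and again on exit.
static void dfsTour(int u, int parent, const std::vector<std::vector<int>>& adj, EulerTour& t) {
    t.st[u] = static_cast<int>(t.order.size());
    t.order.push_back(u);
    for (int v : adj[u]) {
        if (v != parent) dfsTour(v, u, adj, t);
    }
    t.en[u] = static_cast<int>(t.order.size());
    t.order.push_back(u);
}

// Nodes are numbered 1..n; children are visited in the order the edges are given.
EulerTour buildEulerTour(int n, const std::vector<std::pair<int, int>>& edges, int root) {
    assert(1 <= root && root <= n);
    std::vector<std::vector<int>> adj(n + 1);
    for (const auto& e : edges) {
        adj[e.first].push_back(e.second);
        adj[e.second].push_back(e.first);
    }
    EulerTour t;
    t.st.assign(n + 1, -1);
    t.en.assign(n + 1, -1);
    dfsTour(root, 0, adj, t);
    return t;
}

// Assumed edge list of the sample tree, inferred from the tour in the text.
std::vector<std::pair<int, int>> sampleTreeEdges() {
    return {{1, 2}, {2, 4}, {4, 10}, {2, 5}, {5, 6}, {5, 7}, {5, 8}, {1, 3}, {3, 9}};
}
```

Calling `buildEulerTour(10, sampleTreeEdges(), 1)` yields the sequence 1 2 4 10 10 4 5 6 6 7 7 8 8 5 2 3 9 9 3 1 in `order`.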

We can see that for a query (u, v) (assuming ST_u ≤ ST_v, where ST denotes the start time and EN the end time of a node, i.e. the positions of its first and second occurrence in the Euler tour), if the LCA of u and v is u, then the interval [ST_u, ST_v] contains every node on the path from u to v exactly once, while every other node appears either twice or not at all. Since xor cancels itself, a node that appears twice contributes nothing, and we are left with only the nodes on the path from u to v. This way, we get the xor of the values on the path from u to v as the xor of an interval.

Similarly, if the LCA of u and v is not u, the remaining nodes on the path from u to v appear exactly once in the interval [EN_u, ST_v]. The LCA itself does not appear in this interval, so its value has to be xored in separately.
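The two cases can be sketched as follows. To keep the example short it makes simplifying assumptions that the actual solutions do not: values are static (a prefix xor array over the tour stands in for the segment tree) and the LCA is found by naively climbing parents. `TreePathXor` and its members are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

struct TreePathXor {
    std::vector<std::vector<int>> adj;
    std::vector<int> parent, depth, st, en;  // st/en: 0-based tour positions
    std::vector<std::uint64_t> value;        // value[u] = number stored at node u
    std::vector<std::uint64_t> pre;          // pre[i] = xor of tour entries at positions < i

    // Nodes are 1..n; val must have n + 1 entries (index 0 unused).
    TreePathXor(int n, const std::vector<std::pair<int, int>>& edges,
                const std::vector<std::uint64_t>& val, int root)
        : adj(n + 1), parent(n + 1, 0), depth(n + 1, 0), st(n + 1, 0), en(n + 1, 0),
          value(val), pre(1, 0) {
        assert(static_cast<int>(val.size()) == n + 1);
        for (const auto& e : edges) {
            adj[e.first].push_back(e.second);
            adj[e.second].push_back(e.first);
        }
        dfs(root, 0);
    }

    void dfs(int u, int p) {
        parent[u] = p;
        st[u] = static_cast<int>(pre.size()) - 1;  // first occurrence of u
        pre.push_back(pre.back() ^ value[u]);
        for (int v : adj[u]) {
            if (v != p) { depth[v] = depth[u] + 1; dfs(v, u); }
        }
        en[u] = static_cast<int>(pre.size()) - 1;  // second occurrence of u
        pre.push_back(pre.back() ^ value[u]);
    }

    // Naive LCA: repeatedly lift the deeper node (fine for a sketch, slow in general).
    int lca(int u, int v) const {
        while (u != v) {
            if (depth[u] < depth[v]) std::swap(u, v);
            u = parent[u];
        }
        return u;
    }

    // Xor of tour positions l..r inclusive.
    std::uint64_t rangeXor(int l, int r) const { return pre[r + 1] ^ pre[l]; }

    std::uint64_t pathXor(int u, int v) const {
        if (st[u] > st[v]) std::swap(u, v);
        const int w = lca(u, v);
        if (w == u) return rangeXor(st[u], st[v]);  // u is an ancestor of v
        return rangeXor(en[u], st[v]) ^ value[w];   // LCA lies outside the interval
    }

    int pathLength(int u, int v) const {
        return depth[u] + depth[v] - 2 * depth[lca(u, v)] + 1;
    }
};
```

With hashed values stored in `value`, a query would compare `pathXor(u, v)` against the prefix xor at index `pathLength(u, v)`.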

For updates, we just need to write the random number assigned to the new value at both the start and the end position of the node in the segment tree (or Fenwick tree) built over the Euler tour.
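A minimal sketch of such an update, assuming a Fenwick tree over 1-based tour positions, in the spirit of the setter's solution below. `XorFenwick` and `assignKey` are hypothetical names. The trick is to xor in the difference between the old and the new key, so the stored key is replaced rather than accumulated.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Fenwick tree whose combine operation is xor (positions are 1-based).
struct XorFenwick {
    std::vector<std::uint64_t> bit;

    explicit XorFenwick(int n) : bit(n + 1, 0) {}

    // Xor `delta` into position i.
    void toggle(int i, std::uint64_t delta) {
        assert(i >= 1);
        for (; i < static_cast<int>(bit.size()); i += i & -i) bit[i] ^= delta;
    }

    // Xor of positions 1..i.
    std::uint64_t prefix(int i) const {
        std::uint64_t acc = 0;
        for (; i > 0; i -= i & -i) acc ^= bit[i];
        return acc;
    }

    // Xor of positions l..r inclusive.
    std::uint64_t range(int l, int r) const { return prefix(r) ^ prefix(l - 1); }
};

// Replace a node's key at both of its tour positions.
// currentKey is the key stored so far and is updated in place.
void assignKey(XorFenwick& fw, int stPos, int enPos,
               std::uint64_t& currentKey, std::uint64_t newKey) {
    const std::uint64_t delta = currentKey ^ newKey;  // old ^ delta == new
    fw.toggle(stPos, delta);
    fw.toggle(enPos, delta);
    currentKey = newKey;
}
```

Because both occurrences of a node always hold the same key, any interval that contains the node twice still cancels it after an update.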

TIME COMPLEXITY

Time complexity is O((N+Q)*log(N)) per test case.

SOLUTIONS:

Setter's Solution
#include<bits/stdc++.h>
#define STIZE(x) fprintf(stderr, "STIZE%d\n", x);
#define PRINT(x) cerr << #x << ' ' << x << endl;
#define NL(x) printf("%c", " \n"[(x)]);
#define lld long long
#define pll pair<lld,lld>
#define pb push_back
#define fi first
#define se second
#define mid (l+r)/2
#define endl '\n'
#define all(a) begin(a),end(a)
#define sz(a) int((a).size())
#define LINF 2000000000000000000LL
#define INF 1000000000
#define EPS 1e-9
using namespace std;
#define MAXN 500010
#define MAXL 20
mt19937 rng(48201);
vector<int> adj[MAXN];
int N, Q, in[MAXN], out[MAXN], dub[MAXN], anc[MAXN][MAXL], timer;
unsigned lld bit[2*MAXN], prefix[MAXN], val[MAXN];
map<int, unsigned lld> hsh;
map<unsigned lld, bool> used;
void update(int idx, unsigned lld val) {
	while(idx < 2*MAXN) {
	    bit[idx] ^= val;
	    idx += idx&-idx;
	}
}
unsigned lld query(int idx) {
	unsigned lld xorr = 0;
	while(idx) {
	    xorr ^= bit[idx];
	    idx -= idx&-idx;
	}
	return xorr;
}
unsigned lld query(int l, int r) {
	return query(r)^query(l-1);
}
void dfs(int node, int prev, int dubb) {
	dub[node] = dubb;
	in[node] = ++timer;
	anc[node][0] = prev;
	for(auto x : adj[node]) {
	    if(x != prev) dfs(x, node, dubb+1);
	}
	out[node] = ++timer;
}
void initLCA(int node) {
	dfs(1, 1, 0);
	for(int i = 1; i <= N; i++) update(in[i], val[i]), update(out[i], val[i]);
	for(int d = 1; d < MAXL; d++) {
	    for(int i = 1; i <= N; i++) {
	        anc[i][d] = anc[anc[i][d-1]][d-1];
	    }
	}
}
bool inSubtree(int X, int Y) { ///Y in subtree of X
	return (in[X] <= in[Y] && out[Y] <= out[X]);
}
int LCA(int X, int Y) {
	if(inSubtree(X, Y)) return X;
	if(inSubtree(Y, X)) return Y;
	for(int d = MAXL-1; d >= 0; d--) {
	    if(!inSubtree(anc[X][d], Y)) X = anc[X][d];
	}
	return anc[X][0];
}
unsigned long long getRand() {
	unsigned lld x;
	while(1) {
	    x = uniform_int_distribution<unsigned lld> (1, ULLONG_MAX)(rng);
	    if(!used[x]) {
	        used[x] = true;
	        return x;
	    }
	}

}
int main() {
	ios::sync_with_stdio(false); cin.tie(0); cout.tie(0); cerr.tie(0);
	int T; cin >> T;
	while(T--) {
	    hsh.clear(); used.clear(); timer = 0;
	    cin >> N >> Q;
	    for(int i = 1; i <= N; i++) prefix[i] = 0, adj[i].clear();
	    for(int i = 1; i <= 2*N; i++) bit[i] = 0;
	    for(int i = 1; i <= N; i++) {
	        hsh[i] = getRand();
	        prefix[i] = prefix[i-1] ^ hsh[i];
	    }
	    for(int i = 1; i <= N; i++) {
	        cin >> val[i];
	        if(hsh[val[i]] == 0) hsh[val[i]] = getRand();
	        val[i] = hsh[val[i]];
	    }
	    for(int i = 1; i < N; i++) {
	        int x, y; cin >> x >> y;
	        adj[x].pb(y);
	        adj[y].pb(x);
	    }
	    initLCA(1);
	    while(Q--) {
	        int type, X, Y; cin >> type >> X >> Y;
	        if(type == 1) {
	            if(in[Y] < in[X]) swap(X, Y);
	            int lca = LCA(X, Y);
	            int L = dub[X] + dub[Y] - 2*dub[lca] + 1;
	            unsigned lld rez = 0;
	            if(X == lca) rez = query(in[X], in[Y]);
	            else rez = query(out[X], in[Y]) ^ val[lca];
	            if(rez == prefix[L]) cout << "Yes\n";
	            else cout << "No\n";
	        } else {
	            if(hsh[Y] == 0) hsh[Y] = getRand();
	            unsigned lld y = hsh[Y];
	            update(in[X], y^val[X]);
	            update(out[X], y^val[X]);
	            val[X] = y;
	        }
	    }
	}
}
Tester's Solution
#include<bits/stdc++.h>

using namespace std;

typedef unsigned long long ull;
typedef pair<int, int> pii;

#define F first
#define S second
#define tm kljasdf

const int MAXN = 5e5 + 10;
const int B[2] = {690397, 692141};

int n, q, a[MAXN];
vector<int> adj[MAXN];

ull pw(ull a, int b){
	ull ret = 1;
	while (b){
		if (b & 1)
			ret = ret*a;
		b >>= 1;
		a = a*a;
	}
	return ret;
}

int sub[MAXN], depth[MAXN], par[MAXN];
bool cmp(int u, int v){return sub[u] > sub[v];}

void plant(int v, int p = -1, int de = 0){
	if (~p)
		adj[v].erase(find(adj[v].begin(), adj[v].end(), p));
	sub[v] = 1;
	depth[v] = de;
	par[v] = p;
	for (int u:adj[v]) {
		plant(u, v, de+1);
		sub[v] += sub[u];
	}

	sort(adj[v].begin(), adj[v].end(), cmp);
}

int curRt = -1, root[MAXN], st[MAXN], tm, ord[MAXN];
void hld(int v){
	if (curRt == -1)
		curRt = v;
	root[v] = curRt;
	ord[tm] = v;
	st[v] = tm++;

	for (int u:adj[v]){
		hld(u);
		curRt = -1;
	}
}

ull seg[MAXN<<2][2];
void merge(int v){
	for (int w = 0; w < 2; w++)
		seg[v][w] = seg[v<<1][w] + seg[v<<1^1][w];
}

void reCalc(int v, int val){
	for (int w = 0; w < 2; w++)
		seg[v][w] = pw(B[w], val);
}

void plantSeg(int v, int b, int e){
	if (e - b == 1){
		reCalc(v, a[ord[b]]);
		return;
	}

	int mid = b + e >> 1;
	plantSeg(v<<1, b, mid);
	plantSeg(v<<1^1, mid, e);
	merge(v);
}

void upd(int v, int b, int e, int pos){
	if (e - b == 1){
		reCalc(v, a[ord[pos]]);
		return;
	}

	int mid = b + e >> 1;
	if (pos < mid)
		upd(v<<1, b, mid, pos);
	else
		upd(v<<1^1, mid, e, pos);
	merge(v);
}

pair<ull, ull> getSeg(int v, int b, int e, int l, int r){
	if (l <= b && e <= r) return {seg[v][0], seg[v][1]};
	if (r <= b || e <= l) return {0, 0};

	int mid = b + e >> 1;
	auto x = getSeg(v<<1, b, mid, l, r);
	auto y = getSeg(v<<1^1, mid, e, l, r);
	return {x.F+y.F, x.S+y.S};
}

pair<pair<ull, ull>, int> get(int u, int v){
	pair<pair<ull, ull>, int> ret = {{0, 0}, 0};
	while (root[u] ^ root[v]){
		if (depth[root[u]] < depth[root[v]])
			swap(u, v);

		ret.S += depth[u] - depth[root[u]] + 1;
		auto x = getSeg(1, 0, n, st[root[u]], st[u]+1);
		ret.F.F += x.F, ret.F.S += x.S;
		u = par[root[u]];
	}

	if (depth[u] < depth[v])
		swap(u, v);
	ret.S += depth[u] - depth[v] + 1;
	auto x = getSeg(1, 0, n, st[v], st[u]+1);
	ret.F.F += x.F, ret.F.S += x.S;
	return ret;
}

ull sv[MAXN][2];

int main(){
	ios::sync_with_stdio(false);
	cin.tie(0);
	for (int i = 1; i < MAXN; i++)
		for (int w = 0; w < 2; w++){
			sv[i][w] = sv[i-1][w] + pw(B[w], i);
		}
	int te; cin >> te;
	while (te--){
		cin >> n >> q;
		for (int i = 0; i < n; i++) adj[i].clear();
		tm = 0;
		curRt = -1;

		for (int i = 0; i < n; i++) cin >> a[i];
		for (int i = 0; i < n-1; i++){
			int a, b; cin >> a >> b, a--, b--;
			adj[a].push_back(b);
			adj[b].push_back(a);
		}
		plant(0);
		hld(0);
		plantSeg(1, 0, n);
		while (q--){
			int type; cin >> type;
			if (type == 1){
				int u, v; cin >> u >> v, u--, v--;
				auto x = get(u, v);

				if (sv[x.S][0] != x.F.F || sv[x.S][1] != x.F.S)
					cout << "No\n";
				else
					cout << "Yes\n";
			}
			else{
				int v, val; cin >> v >> val, v--;
				a[v] = val;
				upd(1, 0, n, st[v]);
			}
		}
	}
	return 0;
}
Editorialist's Solution
import java.util.*;
import java.io.*;
import java.text.*;
class QRYLAND{
	//SOLUTION BEGIN
	int B = 20;
	void pre() throws Exception{}
	void solve(int TC) throws Exception{
	    int n = ni(), q = ni();
	    int[] a = new int[n];
	    for(int i = 0; i< n; i++){
	        a[i] = ni();
	        if(a[i] > n)a[i] = 0;
	    }
	    int[] rand = new int[1+n];
	    Random r = new Random();
	    for(int i = 1; i<= n; i++)rand[i] = 1+r.nextInt((1<<30)-1);
	    int[] pre = new int[1+n];
	    for(int i = 1; i<= n; i++)pre[i] = pre[i-1]^rand[i];
	    int[][] e = new int[n-1][];
	    for(int i = 0; i< n-1; i++)e[i] = new int[]{ni()-1, ni()-1};
	    int[][] g = makeU(n, e);
	    time = -1;
	    int[] depth = new int[n];
	    int[][] par = new int[B][n];
	    for(int b = 0; b < B; b++)Arrays.fill(par[b], -1);
	    
	    int[] eu = new int[2*n];
	    int[][] ti = new int[n][2];
	    dfs(g, ti, eu, par, depth, 0, -1);
	    
	    for(int i = 0; i< 2*n; i++)eu[i] = rand[a[eu[i]]];
	    SegTree t = new SegTree(eu);
	    while(q-->0){
	        int ty = ni();
	        if(ty == 1){
	            int x = ni()-1, y = ni()-1;
	            int lca = lca(par, depth, x, y);
	            if(ti[x][0] > ti[y][0]){
	                int tt = x;
	                x = y;
	                y = tt;
	            }
	            if(lca == x){
	                int xor = t.q(ti[x][0], ti[y][0]);
	                int length = depth[x]+depth[y]-2*depth[lca]+1;
	                pn(xor == pre[length]?"Yes":"No");
	            }else{
	                int xor = t.q(ti[x][1], ti[y][0])^t.q(ti[lca][0], ti[lca][0]);
	                int length = depth[x]+depth[y]-2*depth[lca]+1;
	                pn(xor == pre[length]?"Yes":"No");
	            }
	        }else{
	            int x = ni()-1, y = ni();
	            int rnd = 0;
	            if(1<= y && y<= n)rnd = rand[y];
	            t.u(ti[x][0], rnd);
	            t.u(ti[x][1], rnd);
	        }
	    }
	}
	int lca(int[][] par, int[] d, int u, int v){
	    if(d[u] > d[v]){
	        int t = u;
	        u = v;
	        v = t;
	    }
	    for(int b = B-1; b>= 0; b--)if((((d[v]-d[u])>>b)&1) == 1)v = par[b][v];
	    if(u == v)return u;
	    for(int b = B-1; b>= 0; b--)
	        if(par[b][u] != par[b][v]){
	            u = par[b][u];
	            v = par[b][v];
	        }
	    return par[0][u];
	}
	int time;
	void dfs(int[][] g, int[][] ti, int[] eu, int[][] par, int[] d, int u, int p){
	    par[0][u] = p;
	    for(int b = 1; b< B; b++)
	        if(par[b-1][u] != -1)
	            par[b][u] = par[b-1][par[b-1][u]];
	    eu[++time] = u;
	    ti[u][0] = time;
	    for(int v:g[u])if(v!= p){
	        d[v] = d[u]+1;
	        dfs(g, ti, eu, par, d, v, u);
	    }
	    eu[++time] = u;
	    ti[u][1] = time;
	}
	class SegTree{
	    int m= 1;
	    int[] t;
	    public SegTree(int[] a){
	        while(m<a.length)m<<=1;
	        t = new int[m<<1];
	        for(int i = 0; i< a.length; i++)t[i+m] = a[i];
	        for(int i = m-1; i>0; i--)t[i] = t[i<<1]^t[i<<1|1];
	    }
	    void u(int p, int value){
	        t[p+=m] = value;
	        for(p>>=1;p>0;p>>=1)t[p] = t[p<<1]^t[p<<1|1];
	    }
	    int q(int l, int r){
	        int ans = 0;
	        for(l+=m,r+=m+1;l<r;l>>=1,r>>=1){
	            if((l&1)==1)ans^=t[l++];
	            if((r&1)==1)ans^=t[--r];
	        }
	        return ans;
	    }
	}
	
	int[][] makeU(int n, int[][] edge){
	    int[][] g = new int[n][];int[] cnt = new int[n];
	    for(int i = 0; i< edge.length; i++){cnt[edge[i][0]]++;cnt[edge[i][1]]++;}
	    for(int i = 0; i< n; i++)g[i] = new int[cnt[i]];
	    for(int i = 0; i< edge.length; i++){
	        g[edge[i][0]][--cnt[edge[i][0]]] = edge[i][1];
	        g[edge[i][1]][--cnt[edge[i][1]]] = edge[i][0];
	    }
	    return g;
	}
	//SOLUTION END
	void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
	DecimalFormat df = new DecimalFormat("0.00000000000");
	static boolean multipleTC = true;
	FastReader in;PrintWriter out;
	void run() throws Exception{
	    in = new FastReader();
	    out = new PrintWriter(System.out);
	    //Solution Credits: Taranpreet Singh
	    int T = (multipleTC)?ni():1;
	    pre();for(int t = 1; t<= T; t++)solve(t);
	    out.flush();
	    out.close();
	}
	public static void main(String[] args) throws Exception{
	    new QRYLAND().run();
	}
	int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
	void p(Object o){out.print(o);}
	void pn(Object o){out.println(o);}
	void pni(Object o){out.println(o);out.flush();}
	String n()throws Exception{return in.next();}
	String nln()throws Exception{return in.nextLine();}
	int ni()throws Exception{return Integer.parseInt(in.next());}
	long nl()throws Exception{return Long.parseLong(in.next());}
	double nd()throws Exception{return Double.parseDouble(in.next());}

	class FastReader{
	    BufferedReader br;
	    StringTokenizer st;
	    public FastReader(){
	        br = new BufferedReader(new InputStreamReader(System.in));
	    }

	    public FastReader(String s) throws Exception{
	        br = new BufferedReader(new FileReader(s));
	    }

	    String next() throws Exception{
	        while (st == null || !st.hasMoreElements()){
	            try{
	                st = new StringTokenizer(br.readLine());
	            }catch (IOException  e){
	                throw new Exception(e.toString());
	            }
	        }
	        return st.nextToken();
	    }

	    String nextLine() throws Exception{
	        String str = "";
	        try{   
	            str = br.readLine();
	        }catch (IOException e){
	            throw new Exception(e.toString());
	        }  
	        return str;
	    }
	}
}

Feel free to share your approach, if you want to (even if it's the same). Suggestions are welcome, as always.

Is there any probabilistic proof of why this would always work? At first look, XOR seems like a terrible hash function here.


Here is how I was thinking: if you take two random bits, their xor is also a random bit (each value with probability 0.5). If we take random 64-bit integers, their xor will also be a random 64-bit integer. Hopefully someone else can post a proof.

Still, using xor is not necessary to solve the problem. Any set hash will work.

Can we solve this using two segment trees and HLD, one storing the minimum value on the path and the other storing the count of distinct values?
Ans = (min = 1 and cnt_distinct = path_length) ? "YES" : "NO";
Is that possible?

Hmm, I don’t think so. How would you handle the count of distinct values? Remember that you have to do segment tree queries over the HLD chains, so you will need to merge the answers of several chains when the nodes of a query belong to different chains.

We could use Mo's algorithm to calculate the number of distinct values, since the queries are offline.

Can we do it by checking that the segment (x - y) has sum = (1 + 2 + 3 + … + L), that the xor of the segment = (1 ^ 2 ^ 3 ^ … ^ L), that the min of the segment is 1, and that the max of the segment is L?
Would that be OK, or is there a counterexample? Please provide any feedback.
[Edit: This method works and my solution passed, but I don’t know why it passed.]

Why are we using XOR but not any other operator or function?