MYPROB2 - Editorial

PROBLEM LINK:

PIZZA LAND!

Author: sastaa_tourist,alpha_1205
Tester: sastaa_tourist, alpha_1205,chef_hamster
Editorialist: sastaa_tourist

DIFFICULTY:

HARD

PROBLEM:

You are given a tree. Each query gives a set of vertices, and you must decide whether there is a path that visits all of them without traversing any edge more than once (that is, whether all the given vertices lie on one simple path). If such a path exists, you must also report how many extra vertices (vertices not in the query) that path has to pass through.

Prerequisites:

Binary lifting (for LCA queries) and standard DFS pre-computations on a tree (depth and entry/exit times).

QUICK EXPLANATION:

For each query, find the two endpoints of the path using the levels (depths) of the query vertices. Then compute the LCA of those two endpoints and check whether every query vertex lies on the path from one of the endpoints up to the LCA.

EXPLANATION:

First, we pre-compute some information about the tree. Run a DFS from the root to compute the level (depth) and the entry/exit times of every vertex. Then build the binary-lifting table: a 2D array where P[v][i] is the 2^i-th ancestor of v.

Next, for each query, store the vertices in a vector of pairs whose first element is the level of the vertex and whose second element is the vertex itself. Sort this vector in descending order of the first element (level).

The first vertex in the sorted vector (call it X) is the deepest query vertex, so it is one endpoint of the path. For every other vertex in the vector, we check whether it is an ancestor of X. This takes O(1) using the entry/exit times: u is an ancestor of v exactly when st[u] ≤ st[v] and en[v] ≤ en[u]. The first vertex that is not an ancestor of X (call it Y) is the other endpoint of the path. If every query vertex is an ancestor of X, they all lie on the single chain from X up to the shallowest query vertex, so the path always exists.

Now we have both endpoints of the path, X and Y. Compute their LCA (call it L) in O(log n) time using binary lifting, and check every query vertex to see whether it lies on the path from X to L or from Y to L. A vertex lies on that path when it is an ancestor of X or of Y and L is an ancestor of it. If any query vertex fails this test, no valid path through the given vertices exists, so print “NO”; otherwise print “YES”.

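To compute the number of extra vertices, first count the vertices on the path between X and Y: Level[X] - Level[L] + Level[Y] - Level[L] + 1. Then subtract the number of vertices k given in the query. In the single-chain case (no Y exists), the count is instead Level[X] - Level[shallowest query vertex] + 1.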

SOLUTIONS:

Setter’s Solution

#include <bits/stdc++.h>

using namespace std;
#define int long long int
#define mp make_pair
#define pb push_back
#define F first
#define S second
// Capacity for vertex ids: vertices are 1..n and index 0 is used as the
// virtual parent of root 1 (solve() calls dfs(1, 0) / pre(1, 0)).
const int N = 200005;
// Not referenced by the code shown here.
#define M 1000000007

// Adjacency lists of the current tree (cleared per test case in solve()).
vector<int> adj[N];

// timer: DFS clock; st/en: DFS entry/exit times; lvl: depth (lvl[0] stays 0,
// so root 1 has depth 1); P[v][i]: 2^i-th ancestor of v. Row P[0] is never
// reset or written, so it stays zero-initialised and jumps past the root land on 0.
int timer = 0, st[N], en[N], lvl[N], P[N][22];

// True when u is an ancestor of v (u == v counts): v's DFS interval
// [st[v], en[v]] must be nested inside u's interval [st[u], en[u]].
bool is_ancestor(int u, int v)
{
	const bool enteredFirst = st[u] <= st[v];
	const bool leftLast = en[v] <= en[u];
	return enteredFirst && leftLast;
}


void dfs(int node, int parent) {
	lvl[node] = 1 + lvl[parent];
	P[node][0] = parent;

	st[node] = timer++;
	for (int i : adj[node]) {
		if (i != parent) {
			dfs(i, node);
		}
	}
	en[node] = timer++;
}

// Fills the binary-lifting table for the subtree of u (whose parent is p):
// P[x][i] = P[P[x][i-1]][i-1], i.e. the 2^i-th ancestor of x.
// Iterative to avoid deep recursion on path-shaped trees. Each vertex is popped
// only after its parent has been processed, which is all the recurrence needs,
// so the table is identical to the recursive version's.
void pre(int u, int p) {
	vector<pair<int, int>> pending;  // (vertex, its parent)
	pending.emplace_back(u, p);
	while (!pending.empty()) {
		const int x = pending.back().first;
		const int px = pending.back().second;
		pending.pop_back();

		P[x][0] = px;
		for (int i = 1; i < 22; i++)
			P[x][i] = P[P[x][i - 1]][i - 1];

		for (int y : adj[x])
			if (y != px) pending.emplace_back(y, x);
	}
}

// Lowest common ancestor of u and v via binary lifting over P[][].
// Requires lvl[] and P[][] to be filled by dfs()/pre() for the current tree.
int lca(int u, int v) {
	// Make u the deeper (or equally deep) vertex.
	if (lvl[u] < lvl[v]) swap(u, v);

	// top = largest exponent with 2^top <= lvl[u] (-1 when lvl[u] == 0).
	int top = -1;
	while ((1 << (top + 1)) <= lvl[u]) ++top;

	// Lift u up to v's level.
	for (int step = top; step >= 0; --step) {
		if (lvl[u] - (1 << step) >= lvl[v]) u = P[u][step];
	}

	if (u == v) return u;

	// Lift both while their ancestors differ; they end as children of the LCA.
	for (int step = top; step >= 0; --step) {
		const int upU = P[u][step];
		if (upU != -1 && upU != P[v][step]) {
			u = upU;
			v = P[v][step];
		}
	}

	return P[u][0];
}



// Answers one query against the tree prepared by dfs()/pre(): prints
// "YES <extra>" when every vertex in `path` lies on a single simple path
// (extra = vertices of that path not listed in the query), otherwise "NO".
static void answer_query(const vector<int> &path) {
	const int k = static_cast<int>(path.size());

	// (depth, vertex) pairs, deepest first; ties put the larger vertex id first.
	vector<pair<int, int>> byDepth;
	byDepth.reserve(path.size());
	for (int x : path) byDepth.emplace_back(lvl[x], x);
	sort(byDepth.rbegin(), byDepth.rend());

	// X, the deepest vertex, is one endpoint. Walking up the sorted list, the
	// vertices stay on the chain from X towards the root while each one is an
	// ancestor of the previous one; the first that is not is the other endpoint Y.
	const int X = byDepth[0].second;
	int Y = 0;
	bool haveY = false;
	for (int i = 1; i < k; i++) {
		if (!is_ancestor(byDepth[i].second, byDepth[i - 1].second)) {
			Y = byDepth[i].second;
			haveY = true;
			break;
		}
	}

	if (!haveY) {
		// All vertices lie on one chain: from X up to the shallowest query vertex.
		const int len = lvl[X] - byDepth.back().first + 1;
		cout << "YES" << " " << len - k << '\n';
		return;
	}

	const int L = lca(X, Y);

	// A vertex lies on the X-Y path iff it is an ancestor of X or of Y and L is
	// an ancestor of it (L and the two endpoints trivially qualify).
	bool ok = true;
	for (int x : path) {
		if (x == L || x == X || x == Y) continue;
		const bool aboveEndpoint = is_ancestor(x, X) || is_ancestor(x, Y);
		if (!aboveEndpoint || !is_ancestor(L, x)) {
			ok = false;
			break;
		}
	}

	if (ok) {
		const int len = lvl[X] - lvl[L] + lvl[Y] - lvl[L] + 1;
		cout << "YES" << " " << len - k << '\n';
	} else {
		cout << "NO\n";
	}
}

// Reads one test case (a tree on vertices 1..n followed by q queries) and
// answers every query. Lines end with '\n' rather than endl so a large q
// does not pay for a stream flush per answer.
void solve() {

	int n;
	cin >> n;

	// Per-vertex arrays are globals shared by all test cases: reset 1..n.
	for (int i = 1; i <= n; i++) {
		adj[i].clear();
		st[i] = en[i] = lvl[i] = 0;
		fill(P[i], P[i] + 22, -1);
	}
	timer = 0;

	for (int i = 0; i < n - 1; i++) {
		int x, y;
		cin >> x >> y;
		adj[x].push_back(y);
		adj[y].push_back(x);
	}

	// Vertex 1 is the root; 0 acts as its virtual parent.
	dfs(1, 0);
	pre(1, 0);

	int q;
	cin >> q;

	vector<int> path;  // reused across queries
	while (q--) {
		int k;
		cin >> k;
		path.assign(k, 0);
		for (int &x : path) cin >> x;
		answer_query(path);
	}
}

#undef int
// Program entry point: reads the number of test cases and runs solve() for each.
int main() {

#define int long long int
	ios_base::sync_with_stdio(false);
	cin.tie(0);
	cout.tie(0);
#ifndef ONLINE_JUDGE
	// Local-testing convenience. freopen() closes the stream it is given
	// before trying to open the file, so redirecting stdin to a missing
	// input.txt would leave the program with no input at all. Redirect only
	// when the local input file is actually present.
	if (FILE *local_input = fopen("input.txt", "r")) {
		fclose(local_input);
		freopen("Error.txt", "w", stderr);
		freopen("input.txt", "r", stdin);
		freopen("output.txt", "w", stdout);
	}
#endif

	int t;
	cin >> t;
	while (t--) {
		solve();
	}

	return 0;

}

Tester’s Solution

import java.util.Scanner;
import java.util.Arrays;
import java.util.Comparator;
import java.util.*;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Scanner;
import java.util.StringTokenizer;

public class Main{

	/** Buffered reader that splits System.in into whitespace-separated tokens. */
	static class FastReader {
		BufferedReader br;
		StringTokenizer st;

		public FastReader()
		{
			br = new BufferedReader(
				new InputStreamReader(System.in));
		}

		/**
		 * Returns the next token, reading further lines as needed.
		 * @throws NoSuchElementException if the input ends before a token is found
		 * @throws RuntimeException wrapping an IOException from the underlying reader
		 */
		String next()
		{
			while (st == null || !st.hasMoreElements()) {
				String line;
				try {
					line = br.readLine();
				}
				catch (IOException e) {
					// Fail loudly: swallowing the error and retrying could loop forever.
					throw new RuntimeException(e);
				}
				if (line == null) {
					// readLine() returns null at end of input; StringTokenizer(null) would NPE.
					throw new NoSuchElementException("unexpected end of input");
				}
				st = new StringTokenizer(line);
			}
			return st.nextToken();
		}

		int nextInt() { return Integer.parseInt(next()); }

		long nextLong() { return Long.parseLong(next()); }

		double nextDouble()
		{
			return Double.parseDouble(next());
		}

		/**
		 * Reads the next raw line directly from the underlying reader (buffered
		 * tokens are not consulted). Returns null at end of input and "" on I/O error.
		 */
		String nextLine()
		{
			String str = "";
			try {
				str = br.readLine();
			}
			catch (IOException e) {
				e.printStackTrace();
			}
			return str;
		}
	}

	static FastReader at  = new FastReader();
	// Buffered output for all answers; flushed once at the end of main().
	static java.io.PrintWriter out = new java.io.PrintWriter(System.out);
	static int N = 200005;
	@SuppressWarnings("unchecked")
	static ArrayList<Integer> []adj= new ArrayList[N];
	static int timer = 0;
	static int [] st = new int[N];
	static int [] en = new int[N];
	static int [] lvl = new int[N];
	static int [][] P = new int[N][22];
	// Explicit traversal stacks shared by dfs() and pre(), which never run at the
	// same time. A tree with fewer than N vertices never needs more than N entries.
	static int [] stackNode = new int[N];
	static int [] stackParent = new int[N];
	static int [] stackNext = new int[N];

	/** True when u is an ancestor of v (u == v counts), by nested DFS intervals. */
	public static boolean is_ancestor(int u, int v){
		return (st[u] <= st[v] && en[u] >= en[v]);
	}

	/**
	 * Records depth, immediate parent and DFS entry/exit times for the subtree of
	 * node (whose parent is parent). Iterative, because recursion as deep as the
	 * tree (up to ~2e5 on a path) overflows the default Java thread stack. The
	 * visiting order, and therefore every timestamp, matches the recursive version.
	 */
	public static void dfs(int node,int parent){
		int top = 0;
		lvl[node] = 1 + lvl[parent];
		P[node][0] = parent;
		st[node] = timer++;
		stackNode[top] = node;
		stackParent[top] = parent;
		stackNext[top] = 0;
		while(top >= 0){
			int u = stackNode[top];
			ArrayList<Integer> nbrs = adj[u];
			if(stackNext[top] == nbrs.size()){
				// All neighbours handled: leave u.
				en[u] = timer++;
				top--;
				continue;
			}
			int c = nbrs.get(stackNext[top]++);
			if(c != stackParent[top]){
				lvl[c] = 1 + lvl[u];
				P[c][0] = u;
				st[c] = timer++;
				top++;
				stackNode[top] = c;
				stackParent[top] = u;
				stackNext[top] = 0;
			}
		}
	}

	/**
	 * Fills the binary-lifting table P[x][i] = P[P[x][i-1]][i-1] (the 2^i-th
	 * ancestor of x) for the subtree of u whose parent is p. Iterative; every
	 * vertex is processed after its parent, which is all the recurrence needs.
	 */
	public static void pre(int u , int p){
		int size = 0;
		stackNode[size] = u;
		stackParent[size] = p;
		size++;
		while(size > 0){
			size--;
			int x = stackNode[size];
			int px = stackParent[size];
			P[x][0] = px;
			for(int i = 1;i<22;i++){
				P[x][i] = P[P[x][i - 1]][i - 1];
			}
			for(int y : adj[x]){
				if(y != px){
					stackNode[size] = y;
					stackParent[size] = x;
					size++;
				}
			}
		}
	}

	/** Lowest common ancestor of u and v by binary lifting over P. */
	public static int lca(int u,int v){
		if(lvl[u] < lvl[v]){
			int temp  = u;
			u = v;
			v = temp;
		}

		// lg = largest exponent with 2^lg <= lvl[u].
		int lg;
		for(lg = 0;(1<<lg) <= lvl[u] ; lg++);
		lg--;

		// Lift u to v's level.
		for(int i = lg;i >= 0;i--){
			if(lvl[u] - (1<<i) >= lvl[v]){
				u = P[u][i];
			}
		}

		if(u == v)return u;

		// Lift both while their ancestors differ; they end as children of the LCA.
		for(int i = lg;i>=0 ;i--){
			if (P[u][i] != -1 && P[u][i] != P[v][i]){
				u = P[u][i]; v = P[v][i];
			}
		}

		return P[u][0];
	}

	/** (level, vertex) record. */
	public static class Pair{
		int first;
		int second;
		Pair(int x,int y){
			this.first = x;
			this.second = y;
		}
	}

	/** Orders pairs by level descending, then by vertex id descending. */
	public static class comp implements Comparator<Pair>{
		public int compare(Pair a , Pair b){
			if(a.first != b.first){
				return Integer.compare(b.first, a.first);
			}
			return Integer.compare(b.second, a.second);
		}
	}

	/**
	 * Reads one test case (a tree on vertices 1..n, then q queries) and writes
	 * "YES <extra>" or "NO" per query to the buffered output.
	 */
	public static void solve(){
		int n = at.nextInt();
		// All per-vertex arrays are static and shared by every test case,
		// so reset entries 0..n / 1..n before reading this tree.
		for(int i = 0;i<=n;i++){
			if(adj[i] == null){
				adj[i] = new ArrayList<Integer>();
			}
			else{
				adj[i].clear();
			}
		}
		for(int i = 1;i<=n;i++){
			st[i] = 0; en[i] = 0; lvl[i] = 0;
			Arrays.fill(P[i], -1);
		}
		timer = 0;

		for(int i = 0;i<n-1;i++){
			int x = at.nextInt();
			int y = at.nextInt();
			adj[x].add(y);
			adj[y].add(x);
		}

		// Vertex 1 is the root; 0 acts as its virtual parent.
		dfs(1,0);
		pre(1,0);

		int q = at.nextInt();

		while(q > 0){
			q--;
			int k = at.nextInt();
			ArrayList<Integer> path = new ArrayList<Integer>(k);
			for(int i = 0;i<k;i++){
				path.add(at.nextInt());
			}

			ArrayList<Pair> v = new ArrayList<Pair>(k);
			for(int x : path){
				v.add(new Pair(lvl[x],x));
			}
			Collections.sort(v,new comp());

			// a (deepest) is one endpoint; b is the first vertex that leaves the
			// chain from a towards the root, i.e. the other endpoint.
			int a = v.get(0).second;
			int b = 0;
			boolean hasB = false;
			for(int i = 1;i<k;i++){
				if(!is_ancestor(v.get(i).second , v.get(i-1).second)){
					b = v.get(i).second;
					hasB = true;
					break;
				}
			}

			if(!hasB){
				// Single chain from a up to the shallowest query vertex.
				int r = lvl[a] - v.get(v.size()-1).first + 1;
				out.println("YES "+(r - k));
				continue;
			}

			int lca_node = lca(a, b);

			// Every vertex must be an ancestor of a or b and a descendant of lca_node.
			boolean ok = true;
			for(int x : path){
				if(x == lca_node || x == a || x == b){
					continue;
				}
				if(!((is_ancestor(x,a) || is_ancestor(x,b)) && is_ancestor(lca_node,x))){
					ok = false;
					break;
				}
			}

			if(ok){
				int r = lvl[a] - lvl[lca_node] + lvl[b] - lvl[lca_node] + 1;
				out.println("YES "+(r-k));
			}
			else{
				out.println("NO");
			}
		}
	}

	public static void main(String[]args){
		int T = at.nextInt();
		while(T>0){
			solve();
			T--;
		}
		out.flush();
	}
}




import sys
# Recursion limit for any recursive helper. Note that a path-shaped tree with
# close to N vertices is far deeper than this limit.
sys.setrecursionlimit(10000)

N = 200005                          # capacity: vertex ids must be < N
adj = [[] for _ in range(N)]        # adjacency lists, one independent list per vertex
timer = 0                           # DFS clock advanced by dfs()
st = [0] * N                        # DFS entry time per vertex
en = [0] * N                        # DFS exit time per vertex
lvl = [0] * N                       # depth; lvl[0] stays 0 for the virtual parent of root 1
P = [[0] * 22 for _ in range(N)]    # P[v][i]: 2**i-th ancestor of v (0 past the root)


def is_ancestor(u, v):
    """Return True if u is an ancestor of v (u == v counts), i.e. v's DFS
    interval [st[v], en[v]] is nested inside u's interval [st[u], en[u]]."""
    entered_first = st[u] <= st[v]
    return entered_first and en[v] <= en[u]


def dfs(node, parent):
    """Record depth, immediate parent and DFS entry/exit times for the subtree
    rooted at ``node`` (whose parent is ``parent``), advancing the global timer.

    Uses an explicit stack instead of recursion: a path-shaped tree is as deep
    as it is large (up to ~2e5 vertices), which far exceeds the recursion limit.
    Vertices are entered and left in the same order as the recursive version,
    so all timestamps are identical.
    """
    global timer
    lvl[node] = 1 + lvl[parent]
    P[node][0] = parent
    st[node] = timer
    timer += 1
    # Each frame: (vertex, its parent, iterator over its remaining neighbours).
    stack = [(node, parent, iter(adj[node]))]
    while stack:
        u, p, neighbours = stack[-1]
        for c in neighbours:
            if c != p:
                lvl[c] = 1 + lvl[u]
                P[c][0] = u
                st[c] = timer
                timer += 1
                stack.append((c, u, iter(adj[c])))
                break
        else:
            # No unvisited neighbour left: leave u.
            en[u] = timer
            timer += 1
            stack.pop()


def pre(u, p):
    """Fill the binary-lifting table for the subtree of ``u`` (parent ``p``):
    P[x][i] = P[P[x][i-1]][i-1], the 2**i-th ancestor of x. Jumps past the root
    land on 0, assuming row P[0] is all zeros.

    Iterative so deep trees do not hit the recursion limit. Every vertex is
    processed after its parent, which is all the recurrence needs, so the
    resulting table matches the recursive version.
    """
    pending = [(u, p)]  # (vertex, its parent)
    while pending:
        x, px = pending.pop()
        row = P[x]
        row[0] = px
        for i in range(1, 22):
            row[i] = P[row[i - 1]][i - 1]
        pending.extend((y, x) for y in adj[x] if y != px)

def lca(u, v):
    """Return the lowest common ancestor of u and v using binary lifting over P.

    Requires lvl and P to be filled (by dfs/pre) for the current tree.
    """
    # Make u the deeper (or equally deep) vertex.
    if lvl[u] < lvl[v]:
        u, v = v, u
    # top = largest exponent with 2**top <= lvl[u] (-1 when lvl[u] == 0).
    top = -1
    while (1 << (top + 1)) <= lvl[u]:
        top += 1
    steps = range(top, -1, -1)

    # Lift u up to v's level.
    for s in steps:
        if lvl[u] - (1 << s) >= lvl[v]:
            u = P[u][s]

    if u == v:
        return u

    # Lift both while their ancestors differ; they end as children of the LCA.
    for s in steps:
        up_u, up_v = P[u][s], P[v][s]
        if up_u != -1 and up_u != up_v:
            u, v = up_u, up_v
    return P[u][0]

class Pair:
    """A (first, second) record, here (level, vertex), ordered
    lexicographically: by first, then by second."""
    first = 0
    second = 0

    def __init__(self, first, second):
        self.first, self.second = first, second

    def __lt__(self, other):
        return (self.first, self.second) < (other.first, other.second)



def solve():
    """Read one test case (tree, then queries) from stdin and answer every query.

    Each query prints "YES <extra>" when all its vertices lie on one simple
    path (extra = vertices of that path not in the query), otherwise "NO".
    """
    global timer
    n = int(input())
    # adj and the per-vertex arrays are module-level and shared by every test
    # case. Empty the adjacency lists of vertices 1..n first, otherwise edges
    # from earlier test cases leak into this tree (and can create cycles).
    for i in range(1, n + 1):
        adj[i].clear()
    timer = 0  # `global` above: reset the shared DFS clock, not a local
    for _ in range(n - 1):
        x, y = [int(j) for j in input().split()]
        adj[x].append(y)
        adj[y].append(x)

    # Vertex 1 is the root; 0 acts as its virtual parent.
    dfs(1, 0)
    pre(1, 0)

    out = []
    q = int(input())
    for _ in range(q):
        k = int(input())
        path = [int(j) for j in input().split()]
        # Deepest first (ties: larger vertex id first).
        v = sorted((Pair(lvl[x], x) for x in path), reverse=True)

        # a (deepest) is one endpoint; b is the first vertex that leaves the
        # chain from a towards the root, i.e. the other endpoint.
        a = v[0].second
        b = None
        for i in range(1, k):
            if not is_ancestor(v[i].second, v[i - 1].second):
                b = v[i].second
                break

        if b is None:
            # Single chain from a up to the shallowest query vertex.
            r = lvl[a] - v[-1].first + 1
            out.append("YES " + str(r - k))
            continue

        anc = lca(a, b)
        ok = all(
            x in (anc, a, b)
            or ((is_ancestor(x, a) or is_ancestor(x, b)) and is_ancestor(anc, x))
            for x in path
        )
        if ok:
            r = lvl[a] - lvl[anc] + lvl[b] - lvl[anc] + 1
            out.append("YES " + str(r - k))
        else:
            out.append("NO")

    if out:
        print("\n".join(out))



# Entry point: read the number of test cases, then solve each one.
# (A stray zero-width-space line that followed here made the script a SyntaxError.)
t = int(input())
for _ in range(t):
    solve()

3 Likes