# MYPROB2 - Editorial

PIZZA LAND!

Author: sastaa_tourist,alpha_1205
Tester: sastaa_tourist, alpha_1205,chef_hamster
Editorialist: sastaa_tourist

HARD

# PROBLEM:

You are given a tree, and each query gives a set of vertices. For each query, determine whether there is a path that visits all the given vertices without traversing any edge more than once. If such a path exists, also find how many extra vertices (vertices not in the query) must be visited to complete that path.

# Prerequisites:

Knowledge of the binary lifting algorithm (for LCA) and of standard precomputations on a tree (node levels and DFS in/out times).

# QUICK EXPLANATION:

For each query, find the two end nodes of the path using the levels of the queried nodes, then find the LCA of those two nodes, and check whether every node of the query lies on the path from one of the end nodes to the LCA.

# EXPLANATION:

First of all, we do some precomputation on the given tree. Run a DFS and calculate the level and the in/out time of every node. Then do the precomputation for the binary lifting algorithm (build the 2D ancestor table used by binary lifting).

Now read the query and store it in a vector of pairs, in which the first element is the level of a node and the second element is the node itself. Then sort this vector in descending order of the first element of the pair.

Now, the first node in this vector (say X) is one end of the path. For every other node in the vector, we check in O(1) time, using the in/out times, whether that node is an ancestor of X. The first node (say Y) that is not an ancestor of X is the second end of the path. If no node in the vector fails this test, every queried node lies on the chain from X towards the root, so the path always exists.

So we get the two ends of the path, X and Y. Now we find the LCA of these two nodes (say L) in O(log n) time using binary lifting, and we check for every node of the query whether it lies on the path from X to L or from Y to L (for this, the node must be an ancestor of X or of Y, and L must be an ancestor of the node). If any node is not on this path, then no path through the given nodes exists and we print "NO" for this case; otherwise we print "YES".

To calculate the fine, find the total number of nodes on the path between X and Y, which is $Level[X] - Level[L] + Level[Y] - Level[L] + 1$, and subtract the number of nodes given in the query.

# SOLUTIONS:

Setter's Solution

#include <bits/stdc++.h>

using namespace std;
// Competitive-programming shorthands used throughout this solution.
// NOTE(review): `int` is redefined as a 64-bit type from here on; main() below
// temporarily undoes this so that its return type stays a real `int`.
#define int long long int
#define mp make_pair
#define pb push_back
#define F first
#define S second
const int N = 200005;   // capacity of every per-node array below
#define M 1000000007

// DFS bookkeeping, indexed by node id:
//   timer  - running counter used to stamp entry/exit times
//   st/en  - entry and exit time of each node
//   lvl    - level of each node (root has level 1, index 0 acts as level 0)
//   P[v][j]- binary-lifting ancestor table of node v
int timer = 0, st[N], en[N], lvl[N], P[N][22];

// Adjacency list of the current tree. dfs() iterates over adj[node], so the
// declaration has to exist for the solution to compile.
vector<int> adj[N];

// Returns true when the DFS interval [st[u], en[u]] encloses [st[v], en[v]],
// i.e. u is an ancestor of v (a node counts as its own ancestor).
bool is_ancestor(int u, int v)
{
    const bool enters_no_later = st[u] <= st[v];
    const bool leaves_no_earlier = en[u] >= en[v];
    return enters_no_later && leaves_no_earlier;
}

// Depth-first walk from `node` (reached from `parent`) that records, for every
// visited node, its level, its direct parent in P[.][0] and its entry/exit
// timestamps in st/en.
void dfs(int node, int parent) {
    P[node][0] = parent;
    lvl[node] = lvl[parent] + 1;

    st[node] = timer++;          // entry stamp
    for (int child : adj[node]) {
        if (child == parent) {
            continue;            // do not walk back up the edge we came from
        }
        dfs(child, node);
    }
    en[node] = timer++;          // exit stamp
}

void pre(int u, int p) {
P[u][0] = p;
for (int i = 1; i < 22; i++)
P[u][i] = P[P[u][i - 1]][i - 1];

if (i != p)
pre(i, u);
}

// Lowest common ancestor of u and v via binary lifting over the table P.
int lca(int u, int v) {
    // Keep u as the node with the larger (or equal) level.
    if (lvl[u] < lvl[v]) swap(u, v);

    // Largest exponent lg with 2^lg <= lvl[u].
    int lg = 0;
    while ((1 << lg) <= lvl[u]) lg++;
    lg--;

    // Raise u until both nodes sit on the same level.
    for (int bit = lg; bit >= 0; bit--) {
        if (lvl[u] - (1 << bit) >= lvl[v])
            u = P[u][bit];
    }

    if (u == v)
        return u;

    // Lift both nodes together while their 2^bit-th ancestors still differ.
    for (int bit = lg; bit >= 0; bit--) {
        if (P[u][bit] == -1 or P[u][bit] == P[v][bit])
            continue;
        u = P[u][bit];
        v = P[v][bit];
    }

    // u and v are now distinct children of the answer.
    return P[u][0];
}

void solve() {

int n;
cin >> n;

for (int i = 1; i <= n; i++) {
st[i] = en[i] = lvl[i] = 0;
for (int j = 0; j < 22; j++) {
P[i][j] = -1;
}
}
timer = 0;

for (int i = 0; i < n - 1; i++) {
int x, y;
cin >> x >> y;
}

dfs(1, 0);
pre(1, 0);

int q;
cin >> q;

while (q--) {

int k;
cin >> k;

vector<int> path(k);

for (int i = 0; i < k; i++) {
cin >> path[i];
}

vector<pair<int, int> > v;

for (auto i : path) {
v.push_back(make_pair(lvl[i], i));
}

sort(v.rbegin(), v.rend());

vector<int>node;

node.push_back(v[0].S);

for (int i = 1; i < k; i++) {

bool got = false;

if (is_ancestor(v[i].S, v[i - 1].S)) {
got = true;

}

if (!got) {

node.push_back(v[i].S);

break;
}
}

if (node.size() == 1) {

int r = lvl[node[0]] - v[v.size() - 1].first + 1;

cout << "YES" << " " << r - k << endl;
continue ;
}

int lca_node = lca(node[0], node[1]);

int ok = 1;

for (auto i : path) {
if (i != lca_node and i != node[0] and i != node[1] and  (is_ancestor(i, node[0]) || is_ancestor(i, node[1])) and is_ancestor(lca_node, i)) {
ok = 1;

}
else if (i != lca_node and i != node[0] and i != node[1]) {
ok = 0; break;

}
}

if (ok) {

int r = lvl[node[0]] - lvl[lca_node] + lvl[node[1]] - lvl[lca_node] + 1;
cout << "YES" << " " << r - k << endl;

}
else {
cout << "NO\n";
}

}

}

#undef int
// Entry point: configures fast I/O, optionally redirects the standard streams
// to local files when not on the judge, then runs every test case.
int main() {

#define int long long int
    ios_base::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
#ifndef ONLINE_JUDGE
    freopen("Error.txt", "w", stderr);
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
#endif

    int t;
    cin >> t;
    for (; t != 0; --t) {
        solve();
    }

    return 0;

}

Tester's Solution

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.*;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Scanner;
import java.util.StringTokenizer;

/**
 * Tester's solution: for every query on a tree, decides whether the queried
 * vertices all lie on one simple path and, if so, how many additional
 * vertices that path contains.
 *
 * <p>The tree is rooted at vertex 1. A DFS records level and entry/exit times,
 * a binary-lifting table answers LCA queries, and each query is resolved by
 * locating the two path endpoints and checking every queried vertex against
 * them.
 */
public class Main {

    /** Buffered, whitespace-token reader over standard input. */
    static class FastReader {
        BufferedReader br;
        StringTokenizer st;

        public FastReader() {
            br = new BufferedReader(new InputStreamReader(System.in));
        }

        /**
         * Returns the next whitespace-delimited token, reading further lines as needed.
         *
         * @throws NoSuchElementException if the input ends before a token is found
         * @throws IllegalStateException if reading standard input fails
         */
        String next() {
            while (st == null || !st.hasMoreElements()) {
                try {
                    String line = br.readLine();
                    if (line == null) {
                        throw new NoSuchElementException("unexpected end of input");
                    }
                    st = new StringTokenizer(line);
                } catch (IOException e) {
                    throw new IllegalStateException("failed to read standard input", e);
                }
            }
            return st.nextToken();
        }

        int nextInt() {
            return Integer.parseInt(next());
        }

        long nextLong() {
            return Long.parseLong(next());
        }

        double nextDouble() {
            return Double.parseDouble(next());
        }

        /** Returns the rest of the input up to the next line break, or {@code null} at end of input. */
        String nextLine() {
            String str = "";
            try {
                str = br.readLine();
            } catch (IOException e) {
                throw new IllegalStateException("failed to read standard input", e);
            }
            return str;
        }
    }

    /** Shared input reader used by {@link #solve()} and {@link #main(String[])}. */
    static FastReader at = new FastReader();

    /** Capacity of every per-vertex array; vertex ids must be smaller than this. */
    static int N = 200005;
    /** Running counter used to stamp DFS entry and exit times. */
    static int timer = 0;
    /** DFS entry time per vertex. */
    static int[] st = new int[N];
    /** DFS exit time per vertex. */
    static int[] en = new int[N];
    /** Level per vertex; the root has level 1 and index 0 stays at level 0. */
    static int[] lvl = new int[N];
    /** Binary-lifting table: {@code P[v][j]} is the 2^j-th ancestor of v (0 above the root). */
    static int[][] P = new int[N][22];
    /** Adjacency list of the current tree, indexed by vertex id. */
    static List<List<Integer>> adj = new ArrayList<List<Integer>>(N);

    // Scratch buffers for the iterative traversals, reused across test cases.
    private static final int[] stack = new int[N];
    private static final int[] cursor = new int[N];

    static {
        for (int i = 0; i < N; i++) {
            adj.add(new ArrayList<Integer>());
        }
    }

    /**
     * Tells whether {@code u} is an ancestor of {@code v}, using DFS entry/exit times.
     * A vertex counts as its own ancestor.
     */
    public static boolean is_ancestor(int u, int v) {
        return st[u] <= st[v] && en[u] >= en[v];
    }

    /**
     * Traverses the subtree of {@code node} (entered from {@code parent}) and fills
     * {@code lvl}, {@code P[.][0]}, {@code st} and {@code en} for every vertex in it.
     *
     * <p>Implemented with an explicit stack so that a long chain of vertices
     * cannot overflow the call stack.
     */
    public static void dfs(int node, int parent) {
        int top = 0;
        lvl[node] = 1 + lvl[parent];
        P[node][0] = parent;
        st[node] = timer++;
        cursor[node] = 0;
        stack[top++] = node;

        while (top > 0) {
            int cur = stack[top - 1];
            List<Integer> neighbours = adj.get(cur);
            if (cursor[cur] < neighbours.size()) {
                int next = neighbours.get(cursor[cur]++);
                if (next == P[cur][0]) {
                    continue; // never walk back along the edge we arrived by
                }
                lvl[next] = 1 + lvl[cur];
                P[next][0] = cur;
                st[next] = timer++;
                cursor[next] = 0;
                stack[top++] = next;
            } else {
                // Every neighbour handled: stamp the exit time and unwind.
                en[cur] = timer++;
                top--;
            }
        }
    }

    /**
     * Fills the binary-lifting rows for the subtree of {@code u}, whose parent is {@code p}.
     * A vertex's row is completed before its children are pushed, so every
     * ancestor row that is read is already final.
     */
    public static void pre(int u, int p) {
        int top = 0;
        P[u][0] = p;
        stack[top++] = u;

        while (top > 0) {
            int cur = stack[--top];
            for (int i = 1; i < 22; i++) {
                P[cur][i] = P[P[cur][i - 1]][i - 1];
            }
            for (int child : adj.get(cur)) {
                if (child != P[cur][0]) {
                    P[child][0] = cur;
                    stack[top++] = child;
                }
            }
        }
    }

    /** Returns the lowest common ancestor of {@code u} and {@code v} using binary lifting. */
    public static int lca(int u, int v) {
        if (lvl[u] < lvl[v]) {
            int temp = u;
            u = v;
            v = temp;
        }

        // Largest exponent lg with 2^lg <= lvl[u]; evaluates to -1 when lvl[u] is 0.
        int lg = 31 - Integer.numberOfLeadingZeros(lvl[u]);

        // Raise u until both vertices are on the same level.
        for (int i = lg; i >= 0; i--) {
            if (lvl[u] - (1 << i) >= lvl[v]) {
                u = P[u][i];
            }
        }

        if (u == v) {
            return u;
        }

        // Lift both together while their 2^i-th ancestors still differ.
        for (int i = lg; i >= 0; i--) {
            if (P[u][i] != -1 && P[u][i] != P[v][i]) {
                u = P[u][i];
                v = P[v][i];
            }
        }

        return P[u][0];
    }

    /** A (level, vertex) pair; {@code first} is the level and {@code second} the vertex id. */
    public static class Pair {
        int first;
        int second;

        Pair(int x, int y) {
            this.first = x;
            this.second = y;
        }
    }

    /** Orders pairs by descending {@code first}, breaking ties by descending {@code second}. */
    public static class comp implements Comparator<Pair> {
        @Override
        public int compare(Pair a, Pair b) {
            if (a.first != b.first) {
                return Integer.compare(b.first, a.first);
            }
            return Integer.compare(b.second, a.second);
        }
    }

    /**
     * Answers one query against the tree prepared by {@link #dfs} and {@link #pre}.
     *
     * @param path the queried vertices; expected to hold at least one vertex
     * @return {@code "YES <extra>"} when a simple path covers all of them, otherwise {@code "NO"}
     */
    private static String answerQuery(List<Integer> path) {
        int k = path.size();

        List<Pair> v = new ArrayList<Pair>(k);
        for (int vertex : path) {
            v.add(new Pair(lvl[vertex], vertex));
        }
        Collections.sort(v, new comp());

        // The deepest vertex is one endpoint. The first vertex in sorted order
        // that breaks the ancestor chain is the other one.
        int deepest = v.get(0).second;
        int other = -1;
        for (int i = 1; i < k; i++) {
            if (!is_ancestor(v.get(i).second, v.get(i - 1).second)) {
                other = v.get(i).second;
                break;
            }
        }

        if (other == -1) {
            // Everything lies on one chain from the deepest to the shallowest vertex.
            int r = lvl[deepest] - v.get(k - 1).first + 1;
            return "YES " + (r - k);
        }

        int lcaNode = lca(deepest, other);

        for (int vertex : path) {
            if (vertex == lcaNode || vertex == deepest || vertex == other) {
                continue;
            }
            boolean onPath = (is_ancestor(vertex, deepest) || is_ancestor(vertex, other))
                    && is_ancestor(lcaNode, vertex);
            if (!onPath) {
                return "NO";
            }
        }

        // Number of vertices on deepest -> lca -> other.
        int r = lvl[deepest] - lvl[lcaNode] + lvl[other] - lvl[lcaNode] + 1;
        return "YES " + (r - k);
    }

    /**
     * Reads one test case (tree, then queries) from standard input and prints
     * one answer line per query.
     */
    public static void solve() {
        int n = at.nextInt();

        // Drop the previous test case's edges and per-vertex state.
        for (int i = 0; i <= n; i++) {
            adj.get(i).clear();
        }
        for (int i = 1; i <= n; i++) {
            st[i] = 0;
            en[i] = 0;
            lvl[i] = 0;
            Arrays.fill(P[i], -1);
        }
        timer = 0;

        for (int i = 0; i < n - 1; i++) {
            int x = at.nextInt();
            int y = at.nextInt();
            adj.get(x).add(y);
            adj.get(y).add(x);
        }

        dfs(1, 0);
        pre(1, 0);

        int q = at.nextInt();
        // Collect the answers and write them once; printing per query is slow.
        StringBuilder out = new StringBuilder();

        while (q-- > 0) {
            int k = at.nextInt();
            List<Integer> path = new ArrayList<Integer>(k);
            for (int i = 0; i < k; i++) {
                path.add(at.nextInt());
            }
            out.append(answerQuery(path)).append('\n');
        }

        System.out.print(out);
    }

    public static void main(String[] args) {
        int T = at.nextInt();
        while (T-- > 0) {
            solve();
        }
    }
}

import sys
sys.setrecursionlimit(10000)

# Capacity of every per-node table below.
N = 200005
# Adjacency list, one independent list per node id.
adj = [[] for _ in range(N)]
# Running counter used to stamp DFS entry/exit times.
timer = 0
# Entry time, exit time and level of each node.
st = [0] * N
en = [0] * N
lvl = [0] * N
# Binary-lifting ancestor table, 22 columns per node.
P = [[0] * 22 for _ in range(N)]

def is_ancestor(u, v):
    """Return True when u is an ancestor of v (a node is its own ancestor).

    Uses the DFS entry/exit times: u is an ancestor of v exactly when u's
    [st, en] interval encloses v's.
    """
    return st[u] <= st[v] and en[u] >= en[v]

def dfs(node, parent):
    """Record level, parent and entry/exit times for the subtree of `node`.

    `parent` is the node we arrived from (0 for the root). The walk uses an
    explicit stack instead of recursion, so a long chain of nodes cannot hit
    the interpreter's recursion limit.
    """
    global timer
    lvl[node] = 1 + lvl[parent]
    P[node][0] = parent
    st[node] = timer
    timer += 1
    # Each stack entry keeps its own neighbour iterator, so a node resumes
    # where it left off once a child has been fully processed.
    stack = [(node, iter(adj[node]))]
    while stack:
        cur, neighbours = stack[-1]
        for nxt in neighbours:
            if nxt != P[cur][0]:
                lvl[nxt] = 1 + lvl[cur]
                P[nxt][0] = cur
                st[nxt] = timer
                timer += 1
                stack.append((nxt, iter(adj[nxt])))
                break
        else:
            # Neighbour list exhausted: stamp the exit time and unwind.
            en[cur] = timer
            timer += 1
            stack.pop()

def pre(u, p):
    """Fill the binary-lifting rows for the subtree of `u`, whose parent is `p`.

    A node's row is completed before its children are pushed, so every
    ancestor row that gets read is already final. Iterative, so deep trees
    cannot hit the recursion limit.
    """
    P[u][0] = p
    stack = [u]
    while stack:
        cur = stack.pop()
        row = P[cur]
        for i in range(1, 22):
            row[i] = P[row[i - 1]][i - 1]
        for child in adj[cur]:
            if child != row[0]:
                P[child][0] = cur
                stack.append(child)

def lca(u, v):
    """Return the lowest common ancestor of u and v using binary lifting."""
    # Keep u as the node with the larger (or equal) level.
    if lvl[u] < lvl[v]:
        u, v = v, u

    # Largest exponent lg with 2**lg <= lvl[u].
    lg = 0
    while (1 << lg) <= lvl[u]:
        lg += 1
    lg -= 1

    # Raise u until both nodes are on the same level.
    for i in range(lg, -1, -1):
        if lvl[u] - (1 << i) >= lvl[v]:
            u = P[u][i]

    if u == v:
        return u

    # Lift both together while their 2**i-th ancestors still differ.
    for i in range(lg, -1, -1):
        if P[u][i] != -1 and P[u][i] != P[v][i]:
            u = P[u][i]
            v = P[v][i]
    return P[u][0]

class Pair:
    """A (first, second) pair ordered by `first`, then by `second`."""

    # Class-level defaults; every instance overrides them in __init__.
    first = 0
    second = 0

    def __init__(self, first, second):
        self.first = first
        self.second = second

    def __lt__(self, other):
        # Ties on `first` fall back to `second`.
        if self.first == other.first:
            return self.second < other.second
        return self.first < other.first

def solve():
    """Read one test case (tree, then queries) from stdin and print the answers.

    Prints "YES <extra>" when the queried nodes all lie on one simple path,
    where <extra> is the number of path nodes that were not queried, and
    "NO" otherwise.
    """
    # Without this declaration the reset below would bind a new local and
    # leave the DFS clock running on from the previous test case.
    global timer
    n = int(input())
    timer = 0

    # Drop the previous test case's edges before adding the new ones.
    for i in range(n + 1):
        adj[i].clear()
    for i in range(n - 1):
        x, y = [int(j) for j in input().split()]
        adj[x].append(y)
        adj[y].append(x)

    dfs(1, 0)
    pre(1, 0)

    q = int(input())
    while q:
        q -= 1
        k = int(input())
        path = [int(j) for j in input().split()]

        # (level, node) pairs ordered from the deepest node upwards.
        v = []
        for i in path:
            v.append(Pair(lvl[i], i))
        v.sort(reverse=True)

        # node[0] is the deepest queried node; node[1], if present, is the
        # first node in sorted order that breaks the ancestor chain.
        node = []
        node.append(v[0].second)
        for i in range(1, k):
            if not is_ancestor(v[i].second, v[i - 1].second):
                node.append(v[i].second)
                break

        if len(node) == 1:
            # Everything lies on one chain from deepest to shallowest node.
            r = lvl[node[0]] - v[len(v) - 1].first + 1
            print("YES " + str(r - k))
            continue

        lca_node = lca(node[0], node[1])
        ok = 1
        for i in path:
            if i == lca_node or i == node[0] or i == node[1]:
                continue
            on_path = ((is_ancestor(i, node[0]) or is_ancestor(i, node[1]))
                       and is_ancestor(lca_node, i))
            if not on_path:
                ok = 0
                break

        if ok == 1:
            # Number of nodes on node[0] -> lca -> node[1].
            r = lvl[node[0]] - lvl[lca_node] + lvl[node[1]] - lvl[lca_node] + 1
            print("YES " + str(r - k))
        else:
            print("NO")

# Number of test cases, followed by one solve() call per test case.
t = int(input())
while t:
    t -= 1
    solve()

â€‹

3 Likes