PROBLEM LINK:
Contest Division 1
Contest Division 2
Contest Division 3
Practice
Setter: Souradeep and Daanish Mahajan
Tester: Aryan
Editorialist: Taranpreet Singh
DIFFICULTY
Easy-Medium
PREREQUISITES
LCA, Euler tour of a tree
PROBLEM
Given a tree with N nodes, we have to answer Q queries. Each query consists of K vertices, and we need to determine whether there exists a simple path in the tree containing all of these vertices.
QUICK EXPLANATION
- Sort the vertices of the query in decreasing order of depth. Let u = V_1, and let v be the first vertex in this order that is not an ancestor of u.
- If no such v exists, then all vertices can lie on a single path.
- Otherwise, the vertices lie on a simple path if and only if all vertices of the query lie on the simple path from u to v, which can be checked easily with some preprocessing.
EXPLANATION
This problem can be solved in several ways. You can use just an Euler tour and some basic observations, or build an auxiliary tree and check whether it is a simple chain. You can even use an overkill idea, such as applying the TALCA problem as a subproblem (just for fun).
I'll explain the approach I find simplest to implement.
The core idea is to find a pair (u, v) such that all vertices lie on a simple path if and only if they lie on the path from u to v.
Let us root the tree at any node R. Consider a query, and let V denote the set of vertices in it. Pick u to be the deepest vertex in V, breaking ties arbitrarily.
Now, among the vertices of V that are not ancestors of u, pick the deepest one and call it v. Every node counts as its own ancestor, so v ≠ u.
If no such v exists, every vertex of V is an ancestor of u. All of them then lie on the path from u to the root, so the answer is YES.
Otherwise, assume such a v exists.
Claim: All vertices in V lie on a common simple path if and only if all of them lie on the path from u to v.
Proof: The tree is rooted at node R. Both u and v belong to V, so any simple path covering V contains the path from u to v. Let L be the LCA of u and v. We have L ≠ v because v is not an ancestor of u. We also have L ≠ u, because otherwise v would be a proper descendant of u and hence deeper than u. So the path from u to v leaves u through u's parent and leaves v through v's parent. Extending it beyond u or beyond v can therefore only add nodes in the subtree of u or in the subtree of v, respectively.
- No node in the subtree of u other than u itself is in V, since u is the deepest node of V.
- No node in the subtree of v other than v itself is in V. Such a node w would be deeper than v and, like v, would not be an ancestor of u, so w would have been chosen instead of v.
Every node strictly between u and v on the path already has both of its path neighbors on the path. A simple path containing the u-v path can therefore only be extended beyond u or v, and by the two points above such an extension picks up no vertex of V. Hence if all of V lies on some simple path, it lies on the path from u to v; the converse is trivial.
Hence, the problem reduces to checking whether every vertex of V lies on the path (u, v).
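Before optimizing, a naive version of this check can be useful as a brute-force reference for stress testing. The sketch below explicitly marks every node of the u-v path by repeatedly climbing from the deeper endpoint until both meet. It costs O(N) per query, which is exactly what the next paragraph wants to avoid. It assumes 0-indexed vertices and an adjacency list, and the names `rootTree` and `allOnPathNaive` are hypothetical.

```cpp
#include <cassert>
#include <queue>
#include <utility>
#include <vector>

// Hypothetical helper: parent and depth of every node for the tree rooted at `root` (BFS).
void rootTree(const std::vector<std::vector<int>>& adj, int root,
              std::vector<int>& parent, std::vector<int>& depth) {
    int n = adj.size();
    parent.assign(n, -1);
    depth.assign(n, 0);
    std::vector<bool> seen(n, false);
    std::queue<int> q;
    q.push(root);
    seen[root] = true;
    while (!q.empty()) {
        int x = q.front();
        q.pop();
        for (int y : adj[x])
            if (!seen[y]) {
                seen[y] = true;
                parent[y] = x;
                depth[y] = depth[x] + 1;
                q.push(y);
            }
    }
}

// Naive O(N) check: mark the u-v path by climbing from the deeper endpoint
// until both endpoints meet at their LCA, then test every query vertex.
bool allOnPathNaive(const std::vector<int>& parent, const std::vector<int>& depth,
                    int u, int v, const std::vector<int>& V) {
    std::vector<bool> onPath(parent.size(), false);
    while (u != v) {
        if (depth[u] < depth[v]) std::swap(u, v);  // always climb from the deeper one
        onPath[u] = true;
        u = parent[u];
    }
    onPath[u] = true;  // the LCA itself
    for (int x : V)
        if (!onPath[x]) return false;
    return true;
}
```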
Unfortunately, the path (u, v) can contain O(N) vertices, so we cannot afford to generate the whole path for every query. We need a faster way to check whether a node w lies on the path (u, v).
Observation: Node w lies on the path from u to v if and only if dist(u, v) = dist(u, w) + dist(w, v).
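The observation can be checked directly with BFS distances. The sketch below does this for illustration only: each call runs two BFS traversals, so it is O(N). The editorial replaces these BFS distances with LCA-based distances. The adjacency-list representation and the names `bfsDist` and `onPathByDistance` are assumptions of this sketch.

```cpp
#include <cassert>
#include <queue>
#include <vector>

// Hypothetical helper: distance (number of edges) from s to every node, by BFS.
std::vector<int> bfsDist(const std::vector<std::vector<int>>& adj, int s) {
    std::vector<int> dist(adj.size(), -1);
    std::queue<int> q;
    dist[s] = 0;
    q.push(s);
    while (!q.empty()) {
        int x = q.front();
        q.pop();
        for (int y : adj[x])
            if (dist[y] == -1) {
                dist[y] = dist[x] + 1;
                q.push(y);
            }
    }
    return dist;
}

// w lies on the unique u-v path iff dist(u, w) + dist(w, v) == dist(u, v).
bool onPathByDistance(const std::vector<std::vector<int>>& adj, int u, int v, int w) {
    std::vector<int> du = bfsDist(adj, u), dv = bfsDist(adj, v);
    return du[w] + dv[w] == du[v];
}
```

For example, in the tree with edges 0-1, 0-2, 1-3, 1-4, 2-5, 5-6, node 0 lies on the path from 3 to 6 (2 + 3 = 5). Node 4 does not, because 2 + 5 ≠ 5.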
So all we need is to answer dist(a, b) queries quickly. If L is the LCA of a and b, then dist(a, b) = depth_a + depth_b - 2*depth_L. It is therefore enough to precompute depths and answer LCA queries, which can be done easily by building an Euler tour of the tree.
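Here is a minimal sketch of LCA via an Euler tour and a sparse table, along the lines of the editorialist's `LCA`/`RMQ` classes but written independently in C++. It assumes 0-indexed vertices, an adjacency list, and a chosen root. The name `EulerLCA` is hypothetical. The tour stores (depth, node) pairs, so the range minimum between the first occurrences of a and b is exactly their LCA.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical helper: LCA via Euler tour + sparse table; O(N log N) build, O(1) query.
struct EulerLCA {
    std::vector<int> first, depth, lg;                 // first index in tour, depth, floor(log2)
    std::vector<std::vector<std::pair<int, int>>> sp;  // sp[j][i] = min (depth, node) on tour[i, i + 2^j)

    EulerLCA(const std::vector<std::vector<int>>& adj, int root) {
        int n = adj.size();
        first.assign(n, -1);
        depth.assign(n, 0);
        std::vector<std::pair<int, int>> tour;  // (depth, node); length 2n - 1
        std::vector<int> parent(n, -1), idx(n, 0), stk{root};
        first[root] = 0;
        tour.push_back({0, root});
        while (!stk.empty()) {  // iterative DFS; a node is re-appended after each child returns
            int x = stk.back();
            if (idx[x] < (int)adj[x].size()) {
                int y = adj[x][idx[x]++];
                if (y == parent[x]) continue;
                parent[y] = x;
                depth[y] = depth[x] + 1;
                first[y] = tour.size();
                tour.push_back({depth[y], y});
                stk.push_back(y);
            } else {
                stk.pop_back();
                if (!stk.empty()) tour.push_back({depth[stk.back()], stk.back()});
            }
        }
        int m = tour.size();
        lg.assign(m + 1, 0);
        for (int i = 2; i <= m; ++i) lg[i] = lg[i / 2] + 1;
        sp.assign(1, tour);
        for (int j = 1; (1 << j) <= m; ++j) {
            sp.emplace_back(m - (1 << j) + 1);
            for (int i = 0; i + (1 << j) <= m; ++i)
                sp[j][i] = std::min(sp[j - 1][i], sp[j - 1][i + (1 << (j - 1))]);
        }
    }

    int lca(int a, int b) const {
        int l = std::min(first[a], first[b]), r = std::max(first[a], first[b]);
        int j = lg[r - l + 1];
        return std::min(sp[j][l], sp[j][r - (1 << j) + 1]).second;
    }

    // dist(a, b) = depth_a + depth_b - 2 * depth_L, where L = lca(a, b).
    int dist(int a, int b) const { return depth[a] + depth[b] - 2 * depth[lca(a, b)]; }
};
```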
Implementation
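Hence, we build an Euler tour of the tree for two purposes: answering LCA queries, and checking whether one node is an ancestor of another. Node a is an ancestor of node b if and only if a is entered no later than b and exited no earlier than b. We also need the depth of each vertex. LCA queries can be answered either with RMQ (a sparse table over the Euler tour) or with binary lifting.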
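Putting the pieces together, here is a sketch of a full per-query check that uses the binary-lifting option. Entry/exit times serve both the ancestor test and the lifting loop. This is an illustrative implementation under assumptions (0-indexed vertices, adjacency list, root 0), not one of the official solutions, and the name `PathQuery` is hypothetical. Reading input and handling multiple test cases are left out.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical helper: binary lifting for LCA plus entry/exit times for O(1) ancestor checks.
struct PathQuery {
    int LOG = 1;
    std::vector<std::vector<int>> up;  // up[j][x] = 2^j-th ancestor of x (the root maps to itself)
    std::vector<int> tin, tout, depth;

    explicit PathQuery(const std::vector<std::vector<int>>& adj, int root = 0) {
        int n = adj.size();
        while ((1 << LOG) < n) ++LOG;
        up.assign(LOG, std::vector<int>(n, root));
        tin.assign(n, 0);
        tout.assign(n, 0);
        depth.assign(n, 0);
        std::vector<int> idx(n, 0), stk{root};
        int timer = 0;
        tin[root] = timer++;
        while (!stk.empty()) {  // iterative DFS avoids deep recursion on path-like trees
            int x = stk.back();
            if (idx[x] < (int)adj[x].size()) {
                int y = adj[x][idx[x]++];
                if (y == up[0][x]) continue;  // skip the parent (the root's entry is itself)
                up[0][y] = x;
                depth[y] = depth[x] + 1;
                tin[y] = timer++;
                stk.push_back(y);
            } else {
                tout[x] = timer++;
                stk.pop_back();
            }
        }
        for (int j = 1; j < LOG; ++j)
            for (int x = 0; x < n; ++x) up[j][x] = up[j - 1][up[j - 1][x]];
    }

    bool isAnc(int a, int b) const { return tin[a] <= tin[b] && tout[b] <= tout[a]; }

    int lca(int a, int b) const {
        if (isAnc(a, b)) return a;
        if (isAnc(b, a)) return b;
        for (int j = LOG - 1; j >= 0; --j)
            if (!isAnc(up[j][a], b)) a = up[j][a];  // climb while still below the LCA
        return up[0][a];
    }

    int dist(int a, int b) const { return depth[a] + depth[b] - 2 * depth[lca(a, b)]; }

    // One query: u = deepest vertex, v = deepest vertex that is not an ancestor of u.
    bool answer(std::vector<int> V) const {
        std::sort(V.begin(), V.end(), [&](int a, int b) { return depth[a] > depth[b]; });
        int u = V[0], v = -1;
        for (int x : V)
            if (!isAnc(x, u)) { v = x; break; }
        if (v == -1) return true;  // every vertex is an ancestor of u: they form a chain
        int d = dist(u, v);
        for (int x : V)
            if (dist(u, x) + dist(x, v) != d) return false;  // x is off the u-v path
        return true;
    }
};
```

Each query costs O(K log K) for the sort plus O(K log N) for the distance checks, which matches the O((N + ΣK)*log(N)) variant of the complexity.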
Bonus
Solve this problem using the TALCA problem as a subproblem.
Solving this problem using the TALCA problem
Suppose the tree is rooted at node R.
For each query, find the deepest node v among the queried nodes. If the queried nodes lie on a simple path, v must be an endpoint of the shortest such path. Now sort the vertices in V by their distance from v, and check that V_i is an ancestor of V_{i+1} for every i when the tree is rooted at v.
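One way to sketch the bonus approach is to use the standard rerooting identity. The LCA of a and b when the tree is rooted at r is the deepest (with respect to the original root) of lca(a, b), lca(a, r) and lca(b, r). Then a is an ancestor of b under root r exactly when that rerooted LCA equals a. This is an illustration under assumptions: the original root R is taken to be node 0, vertices are 0-indexed, and `Lifting`, `lcaRooted` and `answerByRerooting` are hypothetical names. It is not taken from the TALCA problem's reference solution.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical helper: binary-lifting LCA for the tree rooted at node 0 (playing the role of R).
struct Lifting {
    int LOG = 1;
    std::vector<std::vector<int>> up;  // up[j][x] = 2^j-th ancestor of x (root maps to itself)
    std::vector<int> tin, tout, depth;

    explicit Lifting(const std::vector<std::vector<int>>& adj) {
        int n = adj.size();
        while ((1 << LOG) < n) ++LOG;
        up.assign(LOG, std::vector<int>(n, 0));
        tin.assign(n, 0);
        tout.assign(n, 0);
        depth.assign(n, 0);
        std::vector<int> idx(n, 0), stk{0};
        int timer = 0;
        tin[0] = timer++;
        while (!stk.empty()) {  // iterative DFS from node 0
            int x = stk.back();
            if (idx[x] < (int)adj[x].size()) {
                int y = adj[x][idx[x]++];
                if (y == up[0][x]) continue;  // skip the parent
                up[0][y] = x;
                depth[y] = depth[x] + 1;
                tin[y] = timer++;
                stk.push_back(y);
            } else {
                tout[x] = timer++;
                stk.pop_back();
            }
        }
        for (int j = 1; j < LOG; ++j)
            for (int x = 0; x < n; ++x) up[j][x] = up[j - 1][up[j - 1][x]];
    }

    bool isAnc(int a, int b) const { return tin[a] <= tin[b] && tout[b] <= tout[a]; }

    int lca(int a, int b) const {
        if (isAnc(a, b)) return a;
        if (isAnc(b, a)) return b;
        for (int j = LOG - 1; j >= 0; --j)
            if (!isAnc(up[j][a], b)) a = up[j][a];
        return up[0][a];
    }

    int dist(int a, int b) const { return depth[a] + depth[b] - 2 * depth[lca(a, b)]; }

    // LCA of a and b if the tree were rerooted at r: the deepest of the three pairwise LCAs.
    int lcaRooted(int a, int b, int r) const {
        int best = lca(a, b);
        for (int c : {lca(a, r), lca(b, r)})
            if (depth[c] > depth[best]) best = c;
        return best;
    }
};

// Bonus approach: v = deepest query vertex (w.r.t. node 0); sort by distance from v and
// require each vertex to be an ancestor of the next one when the tree is rooted at v.
bool answerByRerooting(const Lifting& T, std::vector<int> V) {
    int v = *std::max_element(V.begin(), V.end(),
                              [&](int a, int b) { return T.depth[a] < T.depth[b]; });
    std::sort(V.begin(), V.end(), [&](int a, int b) { return T.dist(v, a) < T.dist(v, b); });
    for (size_t i = 0; i + 1 < V.size(); ++i)
        if (T.lcaRooted(V[i], V[i + 1], v) != V[i]) return false;
    return true;
}
```

If two query vertices are at the same distance from v, neither is an ancestor of the other under root v. The check therefore correctly fails regardless of how the sort orders them.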
TIME COMPLEXITY
Depending on the implementation, the time complexity per test case is O(N*log(N) + Σ K*log(K)), O((N + ΣK)*log(N)), or even O(N*log(N) + ΣK), where the sums run over all queries of the test case.
SOLUTIONS
Setter's Solution
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;
#define int long long int
#define ordered_set tree<int, null_type,less<int>, rb_tree_tag,tree_order_statistics_node_update>
mt19937 rng(std::chrono::duration_cast<std::chrono::nanoseconds>(chrono::high_resolution_clock::now().time_since_epoch()).count());
#define mp make_pair
#define pb push_back
#define F first
#define S second
const int N=200005;
#define M 1000000007
#define BINF 1e9
#define init(arr,val) memset(arr,val,sizeof(arr))
#define MAXN 501
#define deb(xx) cout << #xx << " " << xx << "\n";
const int LG = 22;
vector<int> adj[N];
int timer = 0, st[N], en[N], lvl[N], P[N][LG];
void dfs(int node, int parent) {
lvl[node] = 1 + lvl[parent];
P[node][0] = parent;
st[node] = timer++;
for (int i : adj[node]) {
if (i != parent) {
dfs(i, node);
}
}
en[node] = timer++;
}
void pre(int u, int p){
P[u][0] = p;
for(int i = 1; i < LG; i++)
P[u][i] = P[P[u][i - 1]][i - 1];
for(auto i: adj[u])
if(i != p)
pre(i, u);
}
int lca(int u, int v){
int i, lg;
if (lvl[u] < lvl[v]) swap(u, v);
for(lg = 0; (1<<lg) <= lvl[u]; lg++);
lg--;
for(i=lg; i>=0; i--){
if (lvl[u] - (1<<i) >= lvl[v])
u = P[u][i];
}
if (u == v)
return u;
for(i=lg; i>=0; i--){
if (P[u][i] != -1 and P[u][i] != P[v][i])
u = P[u][i], v = P[v][i];
}
return P[u][0];
}
void add_edge(int x, int y){
adj[x].push_back(y);
adj[y].push_back(x);
}
void solve() {
int n;
cin >> n;
for(int i = 1; i <= n; i++){
adj[i].clear();
st[i] = en[i] = lvl[i] = 0;
for(int j = 0; j < LG; j++){
P[i][j] = -1;
}
}
timer = 0;
for(int i = 0; i < n - 1; i++){
int x, y;
cin >> x >> y;
add_edge(x, y);
}
dfs(1, 0);
pre(1, 0);
int q;
cin >> q;
while(q--){
int k;
cin >> k;
assert(k >= 1 and k <= n);
vector<int> path(k);
for(int i = 0; i < k; i++){
cin >> path[i];
}
vector<pair<int, int>> v;
for(auto i : path){
v.push_back(make_pair(st[i], i));
}
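sort(v.begin(), v.end());
// Keep the query vertices that have no other query vertex in their subtree:
// in entry-time order, the next vertex lies in the current one's subtree iff it exits earlier.
// If all query vertices lie on one path, there are at most two such vertices.
vector<int>node;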
for(int i = 0; i < k; i++){
bool got = false;
if(i + 1 < k and en[v[i + 1].S] < en[v[i].S]){
got = true;
}
if(!got){
node.push_back(v[i].S);
}
}
if(node.size() > 2){
cout << "NO\n";
continue ;
}
if(node.size() == 1){
cout << "YES\n";
continue ;
}
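assert(node.size() == 2);
// Every query vertex is an ancestor of one of the two lowest vertices, so the query is valid
// iff no vertex other than their LCA is at the same depth as the LCA or above it.
int lca_node = lca(node[0], node[1]);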
int ok = 1;
for(auto i : path){
if(i != lca_node and lvl[lca_node] >= lvl[i]){
ok = 0;
break;
}
}
if(ok){
cout << "YES\n";
}else{
cout << "NO\n";
}
}
}
#undef int
int main() {
#define int long long int
ios_base::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
#ifndef ONLINE_JUDGE
freopen("input.txt", "r", stdin);
freopen("output.txt", "w", stdout);
#endif
int t;
cin >> t;
while(t--){
solve();
}
return 0;
}
Tester's Solution
/* in the name of Anton */
/*
Compete against Yourself.
Author - Aryan (@aryanc403)
Atcoder library - https://atcoder.github.io/ac-library/production/document_en/
*/
#ifdef ARYANC403
#include <header.h>
#else
#pragma GCC optimize ("Ofast")
#pragma GCC target ("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx")
//#pragma GCC optimize ("-ffloat-store")
#include<bits/stdc++.h>
#define dbg(args...) 42;
#endif
using namespace std;
#define fo(i,n) for(i=0;i<(n);++i)
#define repA(i,j,n) for(i=(j);i<=(n);++i)
#define repD(i,j,n) for(i=(j);i>=(n);--i)
#define all(x) begin(x), end(x)
#define sz(x) ((lli)(x).size())
#define pb push_back
#define mp make_pair
#define X first
#define Y second
#define endl "\n"
typedef long long int lli;
typedef long double mytype;
typedef pair<lli,lli> ii;
typedef vector<ii> vii;
typedef vector<lli> vi;
const auto start_time = std::chrono::high_resolution_clock::now();
void aryanc403()
{
#ifdef ARYANC403
auto end_time = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = end_time-start_time;
cerr<<"Time Taken : "<<diff.count()<<"\n";
#endif
}
long long readInt(long long l, long long r, char endd) {
long long x=0;
cin>>x;return x;
int cnt=0;
int fi=-1;
bool is_neg=false;
while(true) {
char g=getchar();
if(g=='-') {
assert(fi==-1);
is_neg=true;
continue;
}
if('0'<=g&&g<='9') {
x*=10;
x+=g-'0';
if(cnt==0) {
fi=g-'0';
}
cnt++;
assert(fi!=0 || cnt==1);
assert(fi!=0 || is_neg==false);
assert(!(cnt>19 || ( cnt==19 && fi>1) ));
} else if(g==endd) {
if(is_neg) {
x=-x;
}
assert(l<=x&&x<=r);
return x;
} else {
assert(false);
}
}
}
string readString(int l, int r, char endd) {
string ret="";
int cnt=0;
while(true) {
char g=getchar();
assert(g!=-1);
if(g==endd) {
break;
}
cnt++;
ret+=g;
}
assert(l<=cnt&&cnt<=r);
return ret;
}
long long readIntSp(long long l, long long r) {
return readInt(l,r,' ');
}
long long readIntLn(long long l, long long r) {
return readInt(l,r,'\n');
}
string readStringLn(int l, int r) {
return readString(l,r,'\n');
}
string readStringSp(int l, int r) {
return readString(l,r,' ');
}
void readEOF(){
assert(getchar()==EOF);
}
vi readVectorInt(int n,lli l,lli r){
vi a(n);
for(int i=0;i<n-1;++i)
a[i]=readIntSp(l,r);
a[n-1]=readIntLn(l,r);
return a;
}
const lli INF = 0xFFFFFFFFFFFFFFFL;
lli seed;
mt19937 rng(seed=chrono::steady_clock::now().time_since_epoch().count());
inline lli rnd(lli l=0,lli r=INF)
{return uniform_int_distribution<lli>(l,r)(rng);}
class CMP
{public:
bool operator()(ii a , ii b) //For min priority_queue .
{ return ! ( a.X < b.X || ( a.X==b.X && a.Y <= b.Y )); }};
void add( map<lli,lli> &m, lli x,lli cnt=1)
{
auto jt=m.find(x);
if(jt==m.end()) m.insert({x,cnt});
else jt->Y+=cnt;
}
void del( map<lli,lli> &m, lli x,lli cnt=1)
{
auto jt=m.find(x);
if(jt->Y<=cnt) m.erase(jt);
else jt->Y-=cnt;
}
bool cmp(const ii &a,const ii &b)
{
return a.X<b.X||(a.X==b.X&&a.Y<b.Y);
}
const lli mod = 1000000007L;
// const lli maxN = 1000000007L;
#include <algorithm>
#include <cassert>
#include <vector>
namespace atcoder {
struct dsu {
public:
dsu() : _n(0) {}
explicit dsu(int n) : _n(n), parent_or_size(n, -1) {}
int merge(int a, int b) {
assert(0 <= a && a < _n);
assert(0 <= b && b < _n);
int x = leader(a), y = leader(b);
if (x == y) return x;
if (-parent_or_size[x] < -parent_or_size[y]) std::swap(x, y);
parent_or_size[x] += parent_or_size[y];
parent_or_size[y] = x;
return x;
}
bool same(int a, int b) {
assert(0 <= a && a < _n);
assert(0 <= b && b < _n);
return leader(a) == leader(b);
}
int leader(int a) {
assert(0 <= a && a < _n);
if (parent_or_size[a] < 0) return a;
return parent_or_size[a] = leader(parent_or_size[a]);
}
int size(int a) {
assert(0 <= a && a < _n);
return -parent_or_size[leader(a)];
}
std::vector<std::vector<int>> groups() {
std::vector<int> leader_buf(_n), group_size(_n);
for (int i = 0; i < _n; i++) {
leader_buf[i] = leader(i);
group_size[leader_buf[i]]++;
}
std::vector<std::vector<int>> result(_n);
for (int i = 0; i < _n; i++) {
result[i].reserve(group_size[i]);
}
for (int i = 0; i < _n; i++) {
result[leader_buf[i]].push_back(i);
}
result.erase(
std::remove_if(result.begin(), result.end(),
[&](const std::vector<int>& v) { return v.empty(); }),
result.end());
return result;
}
private:
int _n;
std::vector<int> parent_or_size;
};
} // namespace atcoder
#define rep(i, a, b) for(int i = a; i < (b); ++i)
#define trav(a, x) for(auto& a : x)
// #define all(x) begin(x), end(x)
// #define sz(x) (int)(x).size()
typedef long long ll;
typedef pair<int, int> pii;
// typedef vector<int> vi;
typedef vector<pii> vpi;
typedef vector<vpi> graph;
graph readTree(lli n){
graph e(n);
atcoder::dsu d(n);
for(lli i=1;i<n;++i){
const lli u=readIntSp(1,n)-1;
const lli v=readIntLn(1,n)-1;
e[u].pb({v,1});
e[v].pb({u,1});
d.merge(u,v);
}
assert(d.size(0)==n);
return e;
}
template<class T>
struct RMQ {
vector<vector<T>> jmp;
RMQ(const vector<T>& V) {
int N = sz(V), on = 1, depth = 1;
while (on < N) on *= 2, depth++;
jmp.assign(depth, V);
rep(i,0,depth-1) rep(j,0,N)
jmp[i+1][j] = min(jmp[i][j],
jmp[i][min(N - 1, j + (1 << i))]);
}
T query(int a, int b) {
assert(a < b); // or return inf if a == b
int dep = 31 - __builtin_clz(b - a);
return min(jmp[dep][a], jmp[dep][b - (1 << dep)]);
}
};
struct LCA {
vi time;
vector<ll> dist;
RMQ<pii> rmq;
LCA(graph& C) : time(sz(C), -99), dist(sz(C)), rmq(dfs(C)) {}
vpi dfs(graph& C) {
vector<tuple<int, int, int, ll>> q(1);
vpi ret;
int T = 0, v, p, d; ll di;
while (!q.empty()) {
tie(v, p, d, di) = q.back();
q.pop_back();
if (d) ret.emplace_back(d, p);
time[v] = T++;
dist[v] = di;
trav(e, C[v]) if (e.first != p)
q.emplace_back(e.first, v, d+1, di + e.second);
}
return ret;
}
int query(int a, int b) {
if (a == b) return a;
a = time[a], b = time[b];
return rmq.query(min(a, b), max(a, b)).second;
}
ll distance(int a, int b) {
int lca = query(a, b);
return dist[a] + dist[b] - 2 * dist[lca];
}
};
vii buildEulerTour(const graph &e){
const lli n=sz(e);
lli tim=0;
vii tinout(n,mp(-1,-1));
function<void(lli,lli)> dfs2 = [&](lli u,lli p){
tinout[u].X=tim++;
for(auto x:e[u]){
if(p==x.X)
continue;
dfs2(x.X,u);
}
tinout[u].Y=tim++;
};
dfs2(0,-1);
return tinout;
}
int main(void) {
ios_base::sync_with_stdio(false);cin.tie(NULL);
// freopen("txt.in", "r", stdin);
// freopen("txt.out", "w", stdout);
// cout<<std::fixed<<std::setprecision(35);
lli T=readIntLn(1,10);
lli sumN = 0;
while(T--)
{
const lli n=readIntLn(1,1e5);
sumN += n;
dbg(n);
auto g = readTree(n);
dbg(sz(g));
lli q=readIntLn(1,1e5);
dbg(q);
lli sumK = 0;
vector<bool> vis(n);
LCA lca(g);
auto tinout = buildEulerTour(g);
auto solveQuery=[&](vi &a){
if(sz(a)<=2)
return true;
lli lc=a[0];
for(auto x:a)
lc=lca.query(lc,x);
sort(all(a),[&](const lli x,const lli y){
return tinout[x].X<tinout[y].X;
});
reverse(all(a));
if(a.back()==lc)
a.pop_back();
const lli k=sz(a);
dbg(lc,a);
lli cnt=0;
for(lli i=0;i+1<k;++i){
const lli u=a[i],v=a[i+1];
const lli l=lca.query(u,v);
dbg(u,v,l);
if(l==lc){
if(cnt)
return false;
cnt++;
continue;
}
if(l!=v)
return false;
}
return true;
};
while(q--){
const lli k = readIntSp(1,n);
dbg(k);
sumK += k;
auto a = readVectorInt(k,1,n);
dbg(sz(a));
for(auto &x:a)
{
x--;
assert(!vis[x]);
vis[x]=true;
}
for(auto &x:a)
vis[x]=false;
cout<<(solveQuery(a)?"YES":"NO")<<endl;
dbg("query over");
}
// assert(sumK<=1e6);
}
// assert(sumN<=2e5);
aryanc403();
// readEOF();
return 0;
}
Editorialist's Solution
import java.util.*;
import java.io.*;
class KPATHQRY{
//SOLUTION BEGIN
void pre() throws Exception{}
void solve(int TC) throws Exception{
int N = ni();
int[] from = new int[N-1], to = new int[N-1];
for(int i = 0; i< N-1; i++){
from[i] = ni()-1;
to[i] = ni()-1;
}
int[][] tree = tree(N, from, to);
LCA lca = new LCA(tree);
int[] dep = new int[N], st = new int[N], en = new int[N];
count = -1;
dfs(tree, dep, st, en, 0, -1);
for(int Q = ni(); Q> 0; Q--){
int K = ni();
Integer[] V = new Integer[K];
for(int i = 0; i< K; i++)V[i] = ni()-1;
Arrays.sort(V, (Integer i1, Integer i2) -> Integer.compare(dep[i1], dep[i2]));
int u = V[K-1], v = -1;
for(int i = K-2; i>= 0; i--){
if(st[V[i]] <= st[V[K-1]] && en[V[K-1]] <= en[V[i]])continue;
v = V[i];
break;
}
if(v == -1){
pn("YES");
continue;
}
//Only path from u to v can be good, so we check if all vertices lie on this path or not
boolean good = true;
int dist = lca.dist(u, v);
for(int x:V)good &= lca.dist(u, x)+lca.dist(x, v) == dist;
pn(good?"YES":"NO");
}
}
int count;
void dfs(int[][] tree, int[] dep, int[] st, int[] en, int u, int p){
st[u] = ++count;
for(int v:tree[u])if(v != p){
dep[v] = dep[u]+1;
dfs(tree, dep, st, en, v, u);
}
en[u] = count;
}
int[][] tree(int N, int[] from, int[] to){
int[] cnt = new int[N];
for(int x:from)cnt[x]++;
for(int x:to)cnt[x]++;
int[][] g = new int[N][];
for(int i = 0; i< N; i++)g[i] = new int[cnt[i]];
for(int i = 0; i< N-1; i++){
g[from[i]][--cnt[from[i]]] = to[i];
g[to[i]][--cnt[to[i]]] = from[i];
}
return g;
}
class LCA{
int n = 0, ti= -1;
int[] eu, fi, d;
RMQ rmq;
public LCA(int[][] g){
n = g.length;
eu = new int[2*n-1];fi = new int[n];d = new int[n];
Arrays.fill(fi, -1);Arrays.fill(eu, -1);
dfs(g, 0, -1);
rmq = new RMQ(eu, d);
}
public LCA(int[] eu, int[] fi, int[] d){
this.n = eu.length;
this.eu = eu;
this.fi = fi;
this.d = d;
rmq = new RMQ(eu, d);
}
void dfs(int[][] g, int u, int p){
eu[++ti] = u;fi[u] = ti;
for(int v:g[u])if(v!=p){
d[v] = d[u]+1;
dfs(g, v, u);eu[++ti] = u;
}
}
int lca(int u, int v){return rmq.query(Math.min(fi[u], fi[v]), Math.max(fi[u], fi[v]));}
int dist(int u, int v){return d[u]+d[v]-2*d[lca(u,v)];}
class RMQ{
int[] len, d;
int[][] rmq;
public RMQ(int[] ar, int[] weight){
len = new int[ar.length+1];
this.d = weight;
for(int i = 2; i<= ar.length; i++)len[i] = len[i>>1]+1;
rmq = new int[len[ar.length]+1][ar.length];
for(int i = 0; i< rmq.length; i++)
for(int j = 0; j< rmq[i].length; j++)
rmq[i][j] = -1;
for(int i = 0; i< ar.length; i++)rmq[0][i] = ar[i];
for(int b = 1; b<= len[ar.length]; b++)
for(int i = 0; i + (1<<b)-1< ar.length; i++)
if(weight[rmq[b-1][i]]<weight[rmq[b-1][i+(1<<(b-1))]])rmq[b][i] =rmq[b-1][i];
else rmq[b][i] = rmq[b-1][i+(1<<(b-1))];
}
int query(int l, int r){
if(l==r)return rmq[0][l];
int b = len[r-l];
if(d[rmq[b][l]]<d[rmq[b][r-(1<<b)]])return rmq[b][l];
return rmq[b][r-(1<<b)];
}
}
}
//SOLUTION END
void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
static boolean multipleTC = true;
FastReader in;PrintWriter out;
void run() throws Exception{
in = new FastReader();
out = new PrintWriter(System.out);
//Solution Credits: Taranpreet Singh
int T = (multipleTC)?ni():1;
pre();for(int t = 1; t<= T; t++)solve(t);
out.flush();
out.close();
}
public static void main(String[] args) throws Exception{
new KPATHQRY().run();
}
int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
void p(Object o){out.print(o);}
void pn(Object o){out.println(o);}
void pni(Object o){out.println(o);out.flush();}
String n()throws Exception{return in.next();}
String nln()throws Exception{return in.nextLine();}
int ni()throws Exception{return Integer.parseInt(in.next());}
long nl()throws Exception{return Long.parseLong(in.next());}
double nd()throws Exception{return Double.parseDouble(in.next());}
class FastReader{
BufferedReader br;
StringTokenizer st;
public FastReader(){
br = new BufferedReader(new InputStreamReader(System.in));
}
public FastReader(String s) throws Exception{
br = new BufferedReader(new FileReader(s));
}
String next() throws Exception{
while (st == null || !st.hasMoreElements()){
try{
st = new StringTokenizer(br.readLine());
}catch (IOException e){
throw new Exception(e.toString());
}
}
return st.nextToken();
}
String nextLine() throws Exception{
String str = "";
try{
str = br.readLine();
}catch (IOException e){
throw new Exception(e.toString());
}
return str;
}
}
}
Feel free to share your approach. Suggestions are welcome, as always.