PROBLEM LINK:
Author: Jishnu Roychoudhury
Tester: Jatin Yadav
Editorialist: Jishnu Roychoudhury
Please note that the problem author listed on CodeChef is incorrect; it’s due to a quirk in the setting procedure. The author is Jishnu Roychoudhury, as listed here.
DIFFICULTY:
Medium
PREREQUISITES:
Trees, Randomisation, Sorting
PROBLEM:
There is a hidden rooted tree with N nodes. Using at most Q queries, each of which returns the LCA of two nodes of your choice, recover the tree. The degree of each node is at most 3, and the grader is not adaptive.
Subtask 1 [15 points]: 1 \leq N \leq 100, Q = 5000
Subtask 2 [25 points]: Q = 25000. Each junction has at most two roads connected to it (that is, every node has degree at most 2).
Subtask 3 [60 points]: Q = 25000.
QUICK EXPLANATION:
- Consider querying an arbitrary node x against all other nodes in the tree.
- The returned LCAs form the path from the root of the tree to x.
- We can sort the path using std::sort, with an LCA query as the comparator.
- Divide the remaining nodes into groups. If lca(x,y) = v, then place y in the group of v. Each group is also a tree, and can be solved recursively.
EXPLANATION:
Subtask 1:
In this subtask you can query all pairs of nodes: with N \leq 100 there are at most 100 \cdot 99 / 2 = 4950 pairs, which fits within Q = 5000. So, given the LCA of every pair, you want to recover the tree. An O(N^3) algorithm is fast enough.
We can solve the subtask by building the tree down from the root. First find the root (its LCA with every other node is itself). We then do N-1 steps, each adding one node to the found tree. For each node v not yet in the found tree, check whether all of its LCAs other than v itself are in the found tree. If so, all proper ancestors of v are already placed, so v must be the child of a node in the found tree. The parent of v is x, where x is the deepest node in the found tree that satisfies lca(v,x) = x. After attaching v, we go to the next step.
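The top-down construction above can be sketched as follows. This is a minimal illustration rather than the setter's code: it assumes 0-indexed nodes and a precomputed table `lcaTable[u][v]`, and the names `recoverParents` and `buildLcaTable` are hypothetical. `buildLcaTable` only exists so the sketch can be tried offline without a judge.

```cpp
#include <cassert>
#include <vector>

// Offline scaffolding (hypothetical, not part of the problem): build the full
// LCA table of a rooted tree given as a parent array (root has parent -1).
std::vector<std::vector<int>> buildLcaTable(const std::vector<int>& parent) {
    const int n = static_cast<int>(parent.size());
    std::vector<int> depth(n, 0);
    for (int v = 0; v < n; ++v)
        for (int u = v; parent[u] != -1; u = parent[u]) ++depth[v];
    std::vector<std::vector<int>> table(n, std::vector<int>(n));
    for (int u = 0; u < n; ++u) {
        for (int v = 0; v < n; ++v) {
            int a = u, b = v;
            while (depth[a] > depth[b]) a = parent[a];
            while (depth[b] > depth[a]) b = parent[b];
            while (a != b) { a = parent[a]; b = parent[b]; }
            table[u][v] = a;
        }
    }
    return table;
}

// O(N^3) reconstruction: repeatedly attach a node whose LCAs (other than the
// node itself) are all already in the found tree.
std::vector<int> recoverParents(const std::vector<std::vector<int>>& lcaTable) {
    const int n = static_cast<int>(lcaTable.size());
    assert(n > 0);
    int root = -1;
    for (int v = 0; v < n && root == -1; ++v) {
        bool isRoot = true;
        for (int u = 0; u < n; ++u)
            if (lcaTable[v][u] != v) { isRoot = false; break; }
        if (isRoot) root = v;  // the root's LCA with every node is itself
    }
    assert(root != -1);
    std::vector<int> parent(n, -1), depth(n, 0);
    std::vector<char> found(n, 0);
    found[root] = 1;
    for (int step = 1; step < n; ++step) {
        for (int v = 0; v < n; ++v) {
            if (found[v]) continue;
            bool attachable = true;
            int best = root;  // deepest proper ancestor of v seen so far
            for (int u = 0; u < n; ++u) {
                const int l = lcaTable[v][u];
                if (l == v) continue;  // u lies in v's own subtree
                if (!found[l]) { attachable = false; break; }
                if (depth[l] > depth[best]) best = l;
            }
            if (!attachable) continue;
            parent[v] = best;
            depth[v] = depth[best] + 1;
            found[v] = 1;
            break;  // one node added per step
        }
    }
    return parent;
}
```

Calling `recoverParents(buildLcaTable(p))` on a parent array `p` should return `p` again, with -1 marking the root.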
Trivia
This solution can be easily optimised to O(N^2). The previous solution is O(N^3) because every time we add a node to the tree, we may have to iterate through up to N nodes to find one that can be attached (even on average it is still about N/2 candidates when the tree is a line). To fix this, we can use any iteration order which guarantees that, when we add a node, it is the child of a node already in the found tree. There are multiple ways to do this; one is to sort the nodes by their number of non-self LCAs.
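One possible reading of that ordering, as a sketch: count, for each node v, how many entries of its row in the LCA table differ from v. That count equals N minus the size of v's subtree, so every proper ancestor of v has a strictly smaller count and is processed first. The function name `recoverParentsSorted` and the 0-indexed `lcaTable` layout are assumptions made for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

// O(N^2) reconstruction (plus the sort) from a full LCA table.
std::vector<int> recoverParentsSorted(const std::vector<std::vector<int>>& lcaTable) {
    const int n = static_cast<int>(lcaTable.size());
    assert(n > 0);
    // key[v] = number of nodes u with lca(v, u) != v (the "non-self" LCAs).
    std::vector<int> key(n, 0);
    for (int v = 0; v < n; ++v)
        for (int u = 0; u < n; ++u)
            if (lcaTable[v][u] != v) ++key[v];
    std::vector<int> order(n);
    std::iota(order.begin(), order.end(), 0);
    std::sort(order.begin(), order.end(),
              [&](int a, int b) { return key[a] < key[b]; });
    assert(key[order[0]] == 0);  // only the root has no non-self LCA
    std::vector<int> parent(n, -1), depth(n, 0);
    for (int idx = 1; idx < n; ++idx) {
        const int v = order[idx];
        int best = order[0];  // start from the root
        for (int u = 0; u < n; ++u) {
            const int l = lcaTable[v][u];
            // Every l != v is a proper ancestor of v, hence already placed.
            if (l != v && depth[l] > depth[best]) best = l;
        }
        parent[v] = best;
        depth[v] = depth[best] + 1;
    }
    return parent;
}
```

Each node is attached in a single O(N) scan, so no step has to search for an attachable node.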
Subtask 2:
In this subtask, the tree is a line.
Let’s consider a line where the head is also the root. Then an LCA query on two nodes tells us which of them is higher up in the line, so it acts as a comparison. Therefore, we can use an O(N \log N) sorting algorithm such as std::sort in STL to solve the line.
However, the head is not necessarily the root: the root may be anywhere in the line. We can split the line into two separate lines that each start at the root (based on which side of the root the nodes are on) and solve the two lines separately.
Take an arbitrary node x other than the root, and query it against all other nodes. If lca(x,y) = x or lca(x,y) = y, then x and y are on the same side of the root. Otherwise, they must be on different sides of the root. In this way, the line is split into two, and each part can be solved with an O(N \log N) sort.
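A rough sketch of the line case, under some assumptions: nodes are 0-indexed, the judge is modelled by a callable `lca(a, b)`, and the root is found by folding LCA queries over all nodes (as both reference solutions do). `makeOracle` is a hypothetical offline stand-in for the judge, and `solveLine` is an illustrative name.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

using LcaOracle = std::function<int(int, int)>;

// Hypothetical offline judge: answers LCA queries on a known parent array.
LcaOracle makeOracle(std::vector<int> parent) {
    return [parent](int a, int b) {
        auto depth = [&](int v) {
            int d = 0;
            while (parent[v] != -1) { v = parent[v]; ++d; }
            return d;
        };
        int da = depth(a), db = depth(b);
        while (da > db) { a = parent[a]; --da; }
        while (db > da) { b = parent[b]; --db; }
        while (a != b) { a = parent[a]; b = parent[b]; }
        return a;
    };
}

// Recover the parent array of a line-shaped tree (root has parent -1).
std::vector<int> solveLine(int n, const LcaOracle& lca) {
    assert(n >= 1);
    int root = 0;
    for (int v = 1; v < n; ++v)
        if (v != root) root = lca(root, v);
    std::vector<int> parent(n, -1), rest;
    for (int v = 0; v < n; ++v)
        if (v != root) rest.push_back(v);
    if (rest.empty()) return parent;
    const int x = rest.front();  // arbitrary non-root node
    std::vector<int> sameSide{x}, otherSide;
    for (std::size_t i = 1; i < rest.size(); ++i) {
        const int y = rest[i];
        const int l = lca(x, y);
        // l is x or y exactly when both lie on the same side of the root.
        (l == x || l == y ? sameSide : otherSide).push_back(y);
    }
    // Strict comparator: a comes first if it is a proper ancestor of b.
    auto closerToRoot = [&](int a, int b) { return a != b && lca(a, b) == a; };
    for (auto* side : {&sameSide, &otherSide}) {
        std::sort(side->begin(), side->end(), closerToRoot);
        int prev = root;
        for (int v : *side) { parent[v] = prev; prev = v; }
    }
    return parent;
}
```

The `a != b` guard keeps the comparator a strict ordering and avoids asking the judge for lca(u,u).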
Subtask 3:
Choose an arbitrary node x. Now, consider querying x against all other nodes.
If we look at the list of LCAs that we get, we can observe that this list is actually the path from the root to the node x. From the solution in subtask 2, we can simply sort this path.
Now, consider the remaining nodes. Let’s try grouping the nodes based on what lca(v,x) is. That is, we put all nodes with the same lca(v,x) into a single group. Then, what we find is that the group also constitutes a tree (the root of the tree being y, where y = lca(v,x)), and that all the groups are disjoint (obviously). So, we can divide into groups, and recursively solve the groups the same way.
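The divide step and the recursion can be sketched as below. This is an illustrative version in the spirit of the approach described here, not a copy of either reference solution: nodes are 0-indexed, `makeOracle` is a hypothetical offline judge, `solveGroup` and `recoverTree` are made-up names, and C++17 is assumed for structured bindings.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <map>
#include <random>
#include <utility>
#include <vector>

using LcaOracle = std::function<int(int, int)>;

// Hypothetical offline judge: answers LCA queries on a known parent array.
LcaOracle makeOracle(std::vector<int> parent) {
    return [parent](int a, int b) {
        auto depth = [&](int v) {
            int d = 0;
            while (parent[v] != -1) { v = parent[v]; ++d; }
            return d;
        };
        int da = depth(a), db = depth(b);
        while (da > db) { a = parent[a]; --da; }
        while (db > da) { b = parent[b]; --db; }
        while (a != b) { a = parent[a]; b = parent[b]; }
        return a;
    };
}

// nodes: everything that hangs below groupRoot (groupRoot itself excluded).
void solveGroup(int groupRoot, std::vector<int> nodes, const LcaOracle& lca,
                std::mt19937& rng, std::vector<int>& parent) {
    if (nodes.empty()) return;
    std::uniform_int_distribution<std::size_t> pick(0, nodes.size() - 1);
    const int x = nodes[pick(rng)];  // random node of this group
    std::vector<int> path;           // nodes on the path groupRoot -> x
    std::map<int, std::vector<int>> groups;  // lca(v, x) -> nodes v
    for (int v : nodes) {
        const int l = (v == x) ? x : lca(v, x);
        if (l == v) path.push_back(v);  // v is x or an ancestor of x
        else groups[l].push_back(v);    // v hangs off the path at l
    }
    std::sort(path.begin(), path.end(),
              [&](int a, int b) { return a != b && lca(a, b) == a; });
    int prev = groupRoot;
    for (int v : path) { parent[v] = prev; prev = v; }
    for (auto& [top, members] : groups)
        solveGroup(top, std::move(members), lca, rng, parent);
}

std::vector<int> recoverTree(int n, const LcaOracle& lca, unsigned seed) {
    assert(n >= 1);
    int root = 0;
    for (int v = 1; v < n; ++v)
        if (v != root) root = lca(root, v);
    std::vector<int> parent(n, -1), nodes;
    for (int v = 0; v < n; ++v)
        if (v != root) nodes.push_back(v);
    std::mt19937 rng(seed);
    solveGroup(root, std::move(nodes), lca, rng, parent);
    return parent;
}
```

The result does not depend on the seed; only the number of queries does.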
It’s fairly intuitive that this should work fast, but we do have a proof.
Proof:
Let’s treat the chain as one of the groups, and let L be the number of nodes on it. (So, the groups are the chain from the root to x and all the trees hanging from this chain.) Since we have to prove the bound of 2 \cdot N \log N, we can treat the chain as a group of size L / 2 (as sorting the chain needs \approx 2 \cdot (L / 2) \cdot \log(L / 2) queries). This is slightly inaccurate, as sorting takes 2 \cdot (L/2) \cdot \log(L) (not \log(L / 2)) queries, but the inequalities used later are loose enough to absorb the difference.
The chain-group will clearly have a size \leq N / 2. If x lies in the centroid’s subtree, all other groups also have size \leq N / 2. This clearly happens with a probability \geq 1 / 2.
Let the group sizes be L / 2 = x_1, x_2, ..., x_k, and f(n) be the expected number of queries used to solve n nodes.
f(N) \le \sum_{i=1}^{k} E[2 \cdot x_i \log(x_i)] + N
Since all sizes are \le N / 2 with probability \ge 1 / 2,
f(N) \le \sum_{i=1}^{k} E[2 \cdot x_i \log(x_i)] + N \le \frac{1}{2} \cdot (2 \cdot \frac{N}{2} \cdot \log(N / 2) + 2 \cdot \frac{N}{2} \cdot \log(N / 2)) + \frac{1}{2} \cdot (2 N \log(N)) + N = 2 N \cdot \log(N).
Therefore, the expected overall query count is at most 2N\log(N). In fact, in real tests, the solution uses around 12,000 queries in the worst case, which is a line.
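For readers who want to check the last equality of the bound, here is the expansion, assuming the logarithm is base 2 (the base is not stated in the proof, but the equality is exact only for base 2):

```latex
\begin{align*}
\tfrac{1}{2}\cdot 2N\log_2\!\left(\tfrac{N}{2}\right)
  + \tfrac{1}{2}\cdot 2N\log_2 N + N
  &= N\,(\log_2 N - 1) + N\log_2 N + N \\
  &= 2N\log_2 N .
\end{align*}
```

As a sanity check, if N is up to about 1000 (which the array sizes in the solutions below suggest, though the limit is not stated in this editorial), then 2 \cdot 1000 \cdot \log_2(1000) \approx 19{,}930, which is below Q = 25000.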
There are also two optimisations we can use to shave queries. The first is to find a leaf to use as our random node. The second is to use merge sort instead of std::sort, as std::sort may use more than N \log N comparisons, while merge sort uses at most about N \log N. However, the query limit on the problem is very generous, as we did not want to make it harder than it already is.
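A minimal sketch of the merge sort idea, with the comparator standing in for one LCA query per call. The name `mergeSortBy` is hypothetical; in a real submission `less(a, b)` would be something like "lca(a,b) == a". A top-down merge sort like this one makes at most N \lceil \log_2 N \rceil comparator calls.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Top-down merge sort; `less` is called once per comparison, so counting its
// calls counts the queries spent on sorting.
template <class Less>
void mergeSortBy(std::vector<int>& a, Less less) {
    if (a.size() < 2) return;
    const auto mid = a.begin() + static_cast<std::ptrdiff_t>(a.size() / 2);
    std::vector<int> left(a.begin(), mid), right(mid, a.end());
    mergeSortBy(left, less);
    mergeSortBy(right, less);
    std::size_t i = 0, j = 0, k = 0;
    while (i < left.size() && j < right.size()) {
        // Take from the right half only if it is strictly smaller (stable).
        a[k++] = less(right[j], left[i]) ? right[j++] : left[i++];
    }
    while (i < left.size()) a[k++] = left[i++];
    while (j < right.size()) a[k++] = right[j++];
    assert(k == a.size());  // every element was written back exactly once
}
```

Passing a lambda that increments a counter before comparing is an easy way to measure how many queries a given sort would cost.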
By the way, for those curious why there is a “2”, “City Mapping” was Singapore NOI 2018 Q4.
Edit:
There is also a deterministic solution to the problem, which can be seen in the comment by errorgorn on CodeForces: Invitation to UWCOI 2021 on CodeChef — Rated for All - Codeforces
SOLUTIONS:
Setter's Solution
#include "bits/stdc++.h"
using namespace std;
//exit codes used
//504 = LCA bad
//403 = wrong answer
//302 = didn't solve all
const int N=1005;
int n,q;
vector<int> adj[N];
int p[N];
int lca(int x, int y){
cout<<"? "<<x<<' '<<y<<endl;
int v; cin>>v;
if(v==-1) exit(504);
return v;
}
bool comp (int a, int b){
//if a is closer to root, a first. if b is closer to root, b first.
if (a==b) return false; //keeps the ordering strict and avoids querying lca(u,u)
return lca(a,b)==a;
}
void sol(){
cin>>n>>q; //read into the globals so that main() sees the current n
int rt = 1;
for(int i=2; i<=n; i++) rt = lca(rt,i);
p[rt] = -1;
deque<pair<vector<int>, int> > allgrps; //<group's elements, group's root>
vector<int> initgrp;
for (int i=1; i<=n; i++){
if(i==rt) continue;
initgrp.push_back(i);
}
if(!initgrp.empty()) allgrps.push_back({initgrp,rt}); //n==1 leaves nothing to solve
while(!allgrps.empty()){ //terminate when all nodes solved
vector<int> currgrp = allgrps.front().first; //set of nodes in grp
//cout<<"CGSZ"<<currgrp.size()<<endl;
int curr_root = allgrps.front().second; //root of set
int pos_of_rando = rand()%currgrp.size();
int arbnode = currgrp[pos_of_rando]; //arbitrary node
//cout<<"ARB"<<arbnode<<endl;
vector<pair<int,int> > vals; //<node queried, lca>
for (auto i : currgrp){ //go through group
if (i == arbnode) continue; //don't query lca(u,u)
vals.push_back(make_pair(i,lca(i,arbnode))); //query against arbnode
}
unordered_set<int> unique_vals; //all the nodes on the path | must be unique
for (auto i : vals) unique_vals.insert(i.second);
unique_vals.insert(arbnode); //must have "arbnode" added
unique_vals.erase(curr_root); //root has already been solved for
vector<int> path_vals; //transfer to a vector for sorting
for (auto i : unique_vals){ path_vals.push_back(i); } //transfer
//path_vals = mergeSort(path_vals,0,path_vals.size()-1); //sorting nodes by proximity to root
sort(path_vals.begin(),path_vals.end(),comp);
//cout<<"PV0"<<path_vals[0]<<endl;
p[path_vals[0]] = curr_root; //connected to root of group
for (int i=1; i<path_vals.size(); i++) p[path_vals[i]] = path_vals[i-1];
//adding nodes to tree
unordered_map<int,vector<int> > nonon; //jakuzure
for (auto i : vals){
if (unique_vals.find(i.first) != unique_vals.end()) continue;
if (nonon.find(i.second) == nonon.end()){
vector<int> tmp; tmp.push_back(i.first);
nonon[i.second] = tmp;
}
else nonon[i.second].push_back(i.first);
}
for (auto mp : nonon){ //adding new groups to "allgrps"
allgrps.push_back(make_pair(mp.second, mp.first));
}
allgrps.pop_front(); //remove the group we just solved.
}
cout<<"! ";
for(int i=1; i<=n; i++){ if(p[i]==-5) exit(302); cout<<p[i]<<' ';}
cout<<endl;
int correct; cin>>correct;
if(correct == -1) exit(403);
}
int main(){
int t;
cin>>t;
srand(423); //can be anything
while(t--){
fill(p,p+N,-5); //memset writes bytes, so it cannot set each int to -5
sol();
//do some re-initialisation stuff
for(int i=1; i<=n; i++) adj[i].clear();
}
}
Tester's Solution
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define pii pair<int, int>
#define all(c) ((c).begin()), ((c).end())
#define sz(x) ((int)(x).size())
template<class T,class U>
ostream& operator<<(ostream& os,const pair<T,U>& p){
os<<"("<<p.first<<", "<<p.second<<")";
return os;
}
template<class T>
ostream& operator <<(ostream& os,const vector<T>& v){
os<<"{";
for(int i = 0;i < (int)v.size(); i++){
if(i)os<<", ";
os<<v[i];
}
os<<"}";
return os;
}
template<class T>
ostream& operator <<(ostream& os,const set<T>& v){
os<<"{";
for(auto it = v.begin(); it != v.end(); it++){
if(it != v.begin())os<<", ";
os<<*it;
}
os<<"}";
return os;
}
template<class T, class V>
ostream& operator <<(ostream& os,const map<T, V>& v){
os<<"{";
for(auto it = v.begin(); it != v.end(); it++){
if(it != v.begin())os<<", ";
os<<*it;
}
os<<"}";
return os;
}
#ifdef LOCAL
#define cerr cout
#endif
#define TRACE
#ifdef TRACE
#define trace(...) __f(#__VA_ARGS__, __VA_ARGS__)
template <typename Arg1>
void __f(const char* name, Arg1&& arg1){
cerr << name << " : " << arg1 << std::endl;
}
template <typename Arg1, typename... Args>
void __f(const char* names, Arg1&& arg1, Args&&... args){
const char* comma = strchr(names + 1, ',');cerr.write(names, comma - names) << " : " << arg1<<" | ";__f(comma+1, args...);
}
#else
#define trace(...)
#endif
const int N = 1 << 10;
int par[N];
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
inline int getRand(int x, int y){
return uniform_int_distribution<int>(x, y)(rng);
}
map<pair<int, int>, int> asked;
int getLCA(int u, int v){
if(u == v) return u;
if(u > v) swap(u, v);
if(asked.find({u, v}) != asked.end()) return asked[{u, v}];
cout << "? " << u + 1 << " " << v + 1 << endl;
int x; cin >> x;
return asked[{u, v}] = x - 1;
}
void method_1(int root, vector<int> nodes){
if(nodes.empty()) return;
if((int) nodes.size() == 1){
par[nodes[0]] = root;
return;
}
int x = nodes[getRand(0, nodes.size() - 1)];
map<int, vector<int>> groups;
vector<int> roots;
for(int j : nodes){
int l = getLCA(j, x);
if(l == j) roots.push_back(j);
else groups[l].push_back(j);
}
sort(roots.begin(), roots.end(), [&](int a, int b){return getLCA(a, b) == a;});
for(int i = 1; i < (int) roots.size(); i++) par[roots[i]] = roots[i - 1];
par[roots[0]] = root;
for(auto it : groups) method_1(it.first, it.second);
}
void method_2(int root, vector<int> nodes){
if(nodes.empty()) return;
if((int) nodes.size() == 1){
par[nodes[0]] = root;
return;
}
for(int i = 1; i < nodes.size(); i++) swap(nodes[i], nodes[getRand(0, i)]);
// random_shuffle(nodes.begin(), nodes.end());
int curr = -1;
set<int> root_set;
map<int, vector<int>> groups;
for(int x : nodes){
if(curr == -1){
curr = x;
root_set.insert(x);
continue;
}
int l = getLCA(curr, x);
if(l == x){
root_set.insert(x);
} else if(l == curr){
root_set.insert(x);
curr = x;
} else{
groups[l].push_back(x);
root_set.insert(l);
}
}
vector<int> roots(root_set.begin(), root_set.end());
sort(roots.begin(), roots.end(), [&](int a, int b){return getLCA(a, b) == a;});
for(int i = 1; i < (int) roots.size(); i++) par[roots[i]] = roots[i - 1];
par[roots[0]] = root;
for(auto it : groups) method_2(it.first, it.second);
}
struct dsu{
int n;
vector<int> par;
dsu(int n) : n(n), par(n){
iota(par.begin(), par.end(), 0);
}
int root(int x){
return x == par[x] ? x : (par[x] = root(par[x]));
}
bool merge(int x, int y){
x = root(x); y = root(y);
if(x == y) return 0;
par[x] = y;
return 1;
}
};
int main(){
int t; cin >> t;
while(t--){
asked.clear();
memset(par, -1, sizeof par);
int n, Q; cin >> n >> Q;
// cerr << n << " " << Q << endl;
int rt = 0;
for(int i = 1; i < n; i++) rt = getLCA(rt, i);
vector<int>nodes;
for(int i = 0; i < n; i++) if(i != rt) nodes.push_back(i);
method_2(rt, nodes);
// cerr << "answer ready" << endl;
cout << "! ";
vector<int> deg(n);
vector<vector<int>> adj(n);
dsu D(n);
int T = 0;
for(int i = 0; i < n; i++){
if(i == rt){
assert(!T);
T = 1;
cout << -1 << " ";
}
else{
deg[i]++;
deg[par[i]]++;
assert(D.merge(i, par[i]));
cout << par[i] + 1 << " ";
}
}
assert(*max_element(all(deg)) <= 3);
assert(T);
cout << endl;
int z;
cin >> z;
assert(z == 1);
}
}