B_BRANCH - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Authors: zxy090909
Tester: tabr
Editorialist: iceknight1093

DIFFICULTY:

3546

PREREQUISITES:

Binary search, finding centroids of a tree

PROBLEM:

You’re given a tree on N vertices.
Two hidden vertices of this tree are marked as special, and the path between them is called the beautiful branch.
Your task is to find any one of the special nodes.

To achieve this, you can ask up to 300 queries of the following type:

  • Provide a set of vertices v_1, v_2, \ldots, v_k to the judge.
  • The judge will return \sum_{i=1}^k d(v_i), where d(v_i) is the distance between v_i and the beautiful branch.

Also, the total sizes of subsets queried cannot exceed 35N.
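
For local testing it can help to simulate the judge. The following is only a sketch under assumptions: the helper names `distToBranch` and `answerQuery` are hypothetical (they are not part of the problem's interface), the tree is given as a 1-indexed adjacency list, and the special vertices a and b are known to the simulator.

```cpp
#include <cassert>
#include <queue>
#include <vector>

// Hypothetical local judge helper (not part of the problem's interface).
// adj is a 1-indexed adjacency list of a tree; a and b are the special vertices.
// Returns d where d[v] is the distance from v to the a-b path.
std::vector<int> distToBranch(const std::vector<std::vector<int>>& adj, int a, int b) {
    int n = (int)adj.size() - 1;
    assert(1 <= a && a <= n && 1 <= b && b <= n);

    // BFS from a, recording parents so the path can be walked back from b.
    std::vector<int> parent(n + 1, 0);
    std::vector<bool> seen(n + 1, false);
    std::queue<int> q;
    q.push(a);
    seen[a] = true;
    while (!q.empty()) {
        int u = q.front();
        q.pop();
        for (int w : adj[u]) {
            if (!seen[w]) {
                seen[w] = true;
                parent[w] = u;
                q.push(w);
            }
        }
    }

    // Multi-source BFS: every vertex on the path starts at distance 0.
    std::vector<int> d(n + 1, -1);
    for (int v = b; v != 0; v = parent[v]) {  // parent[a] == 0 ends the walk
        d[v] = 0;
        q.push(v);
    }
    while (!q.empty()) {
        int u = q.front();
        q.pop();
        for (int w : adj[u]) {
            if (d[w] == -1) {
                d[w] = d[u] + 1;
                q.push(w);
            }
        }
    }
    return d;
}

// Answer to one query: the sum of d over the queried vertices.
long long answerQuery(const std::vector<int>& d, const std::vector<int>& vs) {
    long long sum = 0;
    for (int v : vs) sum += d[v];
    return sum;
}
```

A solution under test can then call `answerQuery` instead of reading from the judge, while counting the number of calls and the total subset size against the two limits.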

EXPLANATION:

Let’s root the tree at vertex 1.

Now, suppose we query for the single vertex 1 and obtain d(1).
Also, suppose it has k children, say v_1, v_2, \ldots, v_k.
We have the following cases (the first three all have d(1) = 0):

  • If 1 is the unique special vertex, we’ll have d(v_i) = 1 for every v_i.
  • If 1 is one end of the beautiful branch, we’ll have d(v_i) = 0 for exactly one child, and d(v_i) = 1 for all the rest.
  • If 1 lies on the beautiful branch but isn’t an endpoint, we’ll have d(v_i) = 0 for exactly two children and d(v_i) = 1 for the rest.
  • The only remaining case is when the beautiful branch doesn’t contain 1 at all.
    In this case, it’ll be contained in the subtree of one of its children, say v_i.
    Then, we’ll have d(v_i) = d(1)-1, and d(v_j) = d(1)+1 for all other v_j.

If we’re able to quickly distinguish between these cases, we can figure out which subtree to move into and recursively solve for that subtree.
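
To make the four cases concrete, here is a small sketch that classifies the root from d(1) and the individual values d(v_i). The names `RootCase` and `classify` are made up for illustration. Reading every child value separately would cost far too many queries in the real problem, so this only illustrates the case table and is not a usable strategy.

```cpp
#include <cassert>
#include <vector>

// Illustrative labels for the four cases listed above (hypothetical names).
enum class RootCase { OnlyVertex, Endpoint, Interior, Outside };

// dRoot is d(1); childD[i] is d(v_i) for the i-th child of the root.
RootCase classify(long long dRoot, const std::vector<long long>& childD) {
    if (dRoot > 0) return RootCase::Outside;  // case 4: branch avoids the root
    int zeros = 0;
    for (long long x : childD) {
        if (x == 0) ++zeros;
    }
    assert(zeros <= 2);  // a path through the root uses at most two children
    if (zeros == 0) return RootCase::OnlyVertex;  // case 1
    if (zeros == 1) return RootCase::Endpoint;    // case 2
    return RootCase::Interior;                    // case 3
}
```

For example, `classify(0, {1, 0, 1})` falls into case 2, while `classify(2, {3, 1, 3})` falls into case 4.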

The fourth case is the easiest to isolate: it happens exactly when d(1) \gt 0, so it can be treated separately.
We want to find the unique child with a smaller value than the rest.
Let’s binary search on the children from 1 to k, each time querying for the vertices v_1, v_2, \ldots, v_i.
If this sum equals i\cdot (d(1) + 1), the child we’re looking for is after i; otherwise it’s between 1 and i.
This way, we take \mathcal{O}(\log N) queries to find the appropriate child, and we can recurse.
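
The binary search for case 4 can be sketched as follows. This is a minimal illustration, not the setter's implementation: `findCloserChild` is a hypothetical name, and `prefixSum(i)` is an assumed callback that performs one judge query on v_1, \ldots, v_i and returns the answer.

```cpp
#include <cassert>
#include <functional>

// prefixSum(i) is assumed to return d(v_1) + ... + d(v_i) using one query.
// Returns the 1-based index of the unique child with d(v_i) = dRoot - 1.
int findCloserChild(int k, long long dRoot,
                    const std::function<long long(int)>& prefixSum) {
    assert(k >= 1 && dRoot > 0);
    int lo = 1, hi = k;  // invariant: the answer lies in [lo, hi]
    while (lo < hi) {
        int mid = lo + (hi - lo) / 2;
        if (prefixSum(mid) == (long long)mid * (dRoot + 1)) {
            lo = mid + 1;  // v_1..v_mid are all farther away than the root
        } else {
            hi = mid;      // the closer child is among v_1..v_mid
        }
    }
    return lo;
}
```

Each loop iteration issues exactly one query, so the search uses at most \lceil \log_2 k \rceil queries.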

Now we need to distinguish between the other three cases.
Case 1 is easy: it happens if and only if \sum_{i=1}^k d(v_i) = k, so it can be checked for with one query.

Cases 2 and 3 can in fact be combined, and once again solved with binary search.
We just want to find one child that the beautiful branch goes into, so we can for example binary search to find the smallest i such that d(v_1) + d(v_2) + \ldots + d(v_i) \lt i, since all values are either 0 or 1.
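
A sketch of this search, with the case 1 check folded in as the first query. As before, `firstZeroChild` is a hypothetical name and `prefixSum(i)` is an assumed one-query callback. The code relies on every d(v_i) being 0 or 1, which holds when d of the root is 0.

```cpp
#include <cassert>
#include <functional>

// prefixSum(i) is assumed to return d(v_1) + ... + d(v_i) using one query.
// Returns the smallest 1-based i with prefixSum(i) < i, i.e. the first child
// with d(v_i) = 0, or 0 if every child has d(v_i) = 1 (case 1).
int firstZeroChild(int k, const std::function<long long(int)>& prefixSum) {
    assert(k >= 1);
    if (prefixSum(k) == k) return 0;  // no child lies on the branch
    int lo = 1, hi = k;               // invariant: first zero is in [lo, hi]
    while (lo < hi) {
        int mid = lo + (hi - lo) / 2;
        if (prefixSum(mid) < mid) {
            hi = mid;      // some zero already appears in v_1..v_mid
        } else {
            lo = mid + 1;  // v_1..v_mid are all ones
        }
    }
    return lo;
}
```

The predicate "prefix sum is smaller than the prefix length" is monotone in i, which is what makes the binary search valid.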

Once the appropriate child is found, we can recurse into it and once again use this procedure, till we find an endpoint (which corresponds to hitting case 1).

This works … almost.
We need to think about the recursion depth. We use up to \log N queries each time, so if we recurse D times, we’ll use \mathcal{O}(D\log N) queries in total. This number needs to be kept within 300.

Here, we make the observation that we don’t quite need to move directly to a child of vertex 1: we move into its subtree, so we’re free to choose our root within this subtree.

This freedom of choice of root is quite nice, since it allows us to bound the recursion depth nicely.
At each step, choose the centroid of the tree you’re in.
That way, each time we move into a subtree we know its size at least halves, so after \log N steps we’ll reach a tree of size 1 which is trivial to solve.
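This way, we use about \log^2 N queries. With N \le 10^5 (the bound asserted in the tester’s code below), \log_2 N \approx 17, so this is roughly 17^2 = 289, which fits within 300.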
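
For reference, one common way to find a centroid of the current component is sketched below. It is not taken from either solution in this post: `findCentroid` and the `alive` mask (marking which vertices still belong to the current subtree) are assumptions made for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Returns a centroid of the component containing `start`, where only vertices
// with alive[v] == true are considered part of the tree.
int findCentroid(const std::vector<std::vector<int>>& adj,
                 const std::vector<bool>& alive, int start) {
    assert(alive[start]);
    int n = (int)adj.size();
    std::vector<int> order, parent(n, -1), sz(n, 0);

    // BFS order without recursion, so deep trees cannot overflow the stack.
    order.push_back(start);
    for (std::size_t i = 0; i < order.size(); ++i) {
        int u = order[i];
        for (int w : adj[u]) {
            if (alive[w] && w != parent[u]) {
                parent[w] = u;
                order.push_back(w);
            }
        }
    }

    // Reverse BFS order visits children before parents: accumulate sizes.
    int total = (int)order.size();
    for (int i = total - 1; i >= 0; --i) {
        int u = order[i];
        sz[u] += 1;
        if (parent[u] != -1) sz[parent[u]] += sz[u];
    }

    // The centroid minimises the largest piece left after removing it.
    int best = start, bestMax = total;
    for (int u : order) {
        int heaviest = total - sz[u];  // the piece on the parent's side
        for (int w : adj[u]) {
            if (alive[w] && w != parent[u]) heaviest = std::max(heaviest, sz[w]);
        }
        if (heaviest < bestMax) {
            bestMax = heaviest;
            best = u;
        }
    }
    return best;
}
```

Each call is linear in the size of the current component. Since the component at least halves at every step, the total work over all steps is also linear in N.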

However, doing this slightly breaks other parts of our solution!
In particular, cases 1, 2, and 4 from above remain the same.
However, case 3 requires care.

When there are two possible branches to move into, it’s not always correct to move into an arbitrary one.
This is because when we moved into this subtree, it’s possible it contained only one of the two special vertices (for example, if we entered it via cases 2 or 3).
So, while the beautiful branch may be present in both subtrees, only one of them will contain an endpoint of it, and we need to choose this carefully.

One way to resolve this issue is to use distances.
Suppose we solve for centroid c_1, and have decided which subtree to go into.
Let c_2 be the centroid of this subtree.
If we run into an ambiguity at c_2, it means that the branch passes through c_2; in particular, one endpoint is in the subtree of c_2 (when c_1 is considered as the root).

So, let’s compute the distances of all vertices from c_1.
Then, if there happens to be an ambiguity at c_2, simply choose the child that’s further away from c_1.
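
A minimal sketch of this tie-break, assuming a 1-indexed adjacency list. The helper names `bfsDist` and `pickFarther` are hypothetical and do not appear in the solutions below.

```cpp
#include <cassert>
#include <queue>
#include <vector>

// Distances (in edges) from src to every vertex of the tree.
std::vector<int> bfsDist(const std::vector<std::vector<int>>& adj, int src) {
    assert(0 <= src && src < (int)adj.size());
    std::vector<int> dist(adj.size(), -1);
    std::queue<int> q;
    dist[src] = 0;
    q.push(src);
    while (!q.empty()) {
        int u = q.front();
        q.pop();
        for (int w : adj[u]) {
            if (dist[w] == -1) {
                dist[w] = dist[u] + 1;
                q.push(w);
            }
        }
    }
    return dist;
}

// x and y are the two neighbours of c_2 that lie on the branch.
// The one farther from c_1 leads away from c_1, towards an endpoint.
int pickFarther(const std::vector<int>& distFromC1, int x, int y) {
    return distFromC1[x] > distFromC1[y] ? x : y;
}
```

On the path 1-2-3-4-5 with c_1 = 1 and c_2 = 3, the two candidates are 2 and 4, and the rule picks 4.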

Note that this requires us to find both children of c_2 that might satisfy the condition instead of just one, hence needing two binary searches.
However, this still fits within the query limit.
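
The two binary searches can be sketched as one search over prefixes and one over suffixes. This is an illustration under assumptions: `zeroChildren` is a hypothetical name, `rangeSum(l, r)` is an assumed callback that spends one query on v_l, \ldots, v_r, every d(v_i) is 0 or 1, and at least one of them is 0.

```cpp
#include <cassert>
#include <functional>
#include <utility>

// rangeSum(l, r) is assumed to return d(v_l) + ... + d(v_r) (1-based, inclusive).
// Returns the indices of the first and last child with d(v_i) = 0.
// If both indices are equal, only one child lies on the branch.
std::pair<int, int> zeroChildren(int k,
                                 const std::function<long long(int, int)>& rangeSum) {
    assert(k >= 1);
    // First zero: smallest i such that the prefix v_1..v_i contains a zero.
    int lo = 1, hi = k;
    while (lo < hi) {
        int mid = lo + (hi - lo) / 2;
        if (rangeSum(1, mid) < mid) hi = mid;
        else lo = mid + 1;
    }
    int first = lo;

    // Last zero: largest i such that the suffix v_i..v_k contains a zero.
    lo = first;
    hi = k;
    while (lo < hi) {
        int mid = lo + (hi - lo + 1) / 2;  // round up so that lo = mid makes progress
        if (rangeSum(mid, k) < k - mid + 1) lo = mid;
        else hi = mid - 1;
    }
    return {first, lo};
}
```

Both searches take at most \lceil \log_2 k \rceil queries each, so finding both candidates roughly doubles the per-level cost compared with finding just one.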

TIME COMPLEXITY:

\mathcal{O}(N \log N) per test case.

CODE:

Setter's code (C++)
#define ll long long int
#include<bits/stdc++.h>
#define loop(i,a,b) for(ll i=a;i<b;++i)
#define rloop(i,a,b) for(ll i=a;i>=b;i--)
#define in(a,n) for(ll i=0;i<n;++i) cin>>a[i];
#define pb push_back
#define mk make_pair
#define all(v) v.begin(),v.end()
#define dis(v) for(auto i:v)cout<<i<<" ";cout<<endl;
#define display(arr,n) for(int i=0; i<n; i++)cout<<arr[i]<<" ";cout<<endl;
#define fast ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL);srand(time(NULL));
#define l(a) a.length()
#define s(a) (ll)a.size()
#define fr first
#define sc second
#define mod 1000000007
// #define endl '\n'
#define yes cout<<"Yes"<<endl;
#define no cout<<"No"<<endl;
using namespace std;
#define debug(x) cerr << #x<<" "; _print(x); cerr << endl;
void _print(ll t) {cerr << t;}
void _print(int t) {cerr << t;}
void _print(string t) {cerr << t;}
void _print(char t) {cerr << t;}
void _print(double t) {cerr << t;}
template <class T, class V> void _print(pair <T, V> p);
template <class T> void _print(vector <T> v);
template <class T> void _print(set <T> v);
template <class T, class V> void _print(map <T, V> v);
template <class T> void _print(multiset <T> v);
template <class T, class V> void _print(pair <T, V> p) {cerr << "{"; _print(p.fr); cerr << ","; _print(p.sc); cerr << "}";}
template <class T> void _print(vector <T> v) {cerr << "[ "; for (T i : v) {_print(i); cerr << " ";} cerr << "]";}
template <class T> void _print(set <T> v) {cerr << "[ "; for (T i : v) {_print(i); cerr << " ";} cerr << "]";}
template <class T> void _print(multiset <T> v) {cerr << "[ "; for (T i : v) {_print(i); cerr << " ";} cerr << "]";}
template <class T, class V> void _print(map <T, V> v) {cerr << "[ "; for (auto i : v) {_print(i); cerr << " ";} cerr << "]";}

ll add(ll x,ll y)  {ll ans = x+y; return (ans>=mod ? ans - mod : ans);}
ll sub(ll x,ll y)  {ll ans = x-y; return (ans<0 ? ans + mod : ans);}
ll mul(ll x,ll y)  {ll ans = x*y; return (ans>=mod ? ans % mod : ans);}
vector<vector<int>> vec,vec2;
vector<int> mx_sz,tot_sz,dis,far_dist;
int n,a,b,found_that_node;
bool ok = 0,ok1 = 1;
map<vector<int>,ll> queries_found;
ll counter = 0,total_nodes = 0;


/* ---------------------------------------------------------------------------------------- */

pair<int,int> dfs(int i,int par){
    tot_sz[i] = 1;
    for(auto j:vec[i])  {
        if(j!=par)  {
            pair<int,int> p = dfs(j,i);
            mx_sz[i] = max(mx_sz[i],p.fr);
            tot_sz[i]+=p.sc;
        }
    }
    return {tot_sz[i],tot_sz[i]};
}

int dfs2(int i,int par,int vll){
    if(mx_sz[i] <= vll/2 && vll-tot_sz[i] <= vll/2)    return i;
    int ans = -1;
    for(auto j:vec[i]){
        if(j == par)    continue;
        int vl = dfs2(j,i,vll);
        if(vl!=-1)  ans = vl;
    }
    return ans;
}


// finding centroid in O(n), since we have to find it at most log(n) times.
ll find_centroid(){
    mx_sz.assign(n+1,0);
    tot_sz.assign(n+1,0);
    int idx = -1;
    loop(i,1,n+1)   if(vec[i].size())   {idx = i;   break;}
    if(idx == -1)   return 1;
    pair<int,int> p =dfs(idx,-1);
    int n_z = 0;
    loop(i,1,n+1) n_z+=(int)vec[i].size();
    return dfs2(idx,-1,n_z/2+1);
}

/* ---------------------------------------------------------------------------------------- */

// function for the query
ll query(vector<int> &vec,int l,int r){
    vector<int> vv;
    loop(i,l,r+1)   vv.pb(vec[i]);
    // if already asked
    if(queries_found.find(vv)!=queries_found.end()) return queries_found[vv];  
    total_nodes+=(r-l+1);
    cout<<"? "<<r-l+1;
    loop(i,l,r+1)   cout<<' '<<vec[i];
    cout<<endl;
    ll ans;    cin>>ans;
    queries_found[vv] = ans;
    return ans;
}

/* ---------------------------------------------------------------------------------------- */
void fnn3(int i,int par,int dep){
    far_dist[i] = dep;
    for(auto j:vec2[i]) if(j!=par)  fnn3(j,i,dep+1);
}


void dfs3(int i,int par,int not_take,vector<pair<int,int>> &edges){
    if(i == not_take)   return;
    if(par!=-1){
        edges.push_back({i,par});
    }
    for(auto j:vec[i])  if(j!=par)  dfs3(j,i,not_take,edges);
}


// deleting the found edge and reforming the tree
bool clear_tree(int to,int from){
    vector<pair<int,int>> edges;
    dfs3(to,-1,from,edges);
    if(edges.size() == 0)   return 1;
    vec.assign(n+1,{});
    for(auto i:edges)   vec[i.fr].pb(i.sc),vec[i.sc].pb(i.fr);
    return 0;
}
/* ---------------------------------------------------------------------------------------- */


// doing binary search
ll binary_search(ll centroid,ll l,ll r){
    ll found = -1;
    vector<int> qq(1,centroid);
    ll q3 = query(qq,0,0);
    bool okk = 0;
    if(q3 == 0){    // if centroid is on the path
        ll L = 0,R = vec2[centroid].size() - 1;
        ll q = query(vec2[centroid],L,R),qq3 = query(vec[centroid],l,r);
        if(R-L+1 == 1 || (R-L+1 == 2 && !q) || (q+2) == R-L+1)  found_that_node = centroid;
        if((r-l+1 == 2 && !qq3) || (qq3+2 == r-l+1))    okk = 1;
        if(R-L+1 == 1) found = centroid,ok = 1;
        else if(R-L+1 == 2 && q)   found = centroid,ok = 1;
        else if(q+1 == R-L+1)   found = centroid,ok = 1;
        // ok will be one if conditions fall into case-2
        if(ok)  return found;
    }
    // okk is 1 if the centroid lies on the special path and two of its adjacent nodes lie on the path as well.
    while(l<r){
        ll mid = (l+r)/2;
        // if only 3 or 4 elements are left
        if(mid-l+1 == 2){
            set<int> eligi;
            loop(i,l,r+1){
                ll vl = query(vec[centroid],i,i);
                if(vl == 0 || vl == q3-1)   eligi.insert(vec[centroid][i]);
            }
            ll mx = -1,vl = -1;
            for(auto i:eligi)   if(far_dist[i] > mx) mx = far_dist[i],vl = i;
            // no node is present on the path
            if(vl == -1)    return found;
            return vl;
        }
        // if only 1 or 2 elements are left
        if(mid-l+1 == 1){
            ll q1 = query(vec[centroid],l,l);
            ll q2 = query(vec[centroid],r,r);
            if((q2 < q3 || q2 == 0) && (q1 == 0 || q1 < q3)){
                if(far_dist[vec[centroid][r]] > far_dist[vec[centroid][l]])   return vec[centroid][r];
                else return vec[centroid][l];
            }
            else if(q2 < q3 || q2 == 0)    return vec[centroid][r];
            if(q1 < q3 || q1 == 0) r = mid,found = vec[centroid][mid];
            else l = mid+1;
            continue;
        }
        ll q = query(vec[centroid],l,mid);
        if((q+2)%(mid-l+1) == 0)  r = mid,found = vec[centroid][mid];   // if centroid is on the path and two of the child has sum as D(C)-1, Case-1
        else if((q+1)%(mid-l+1) == 0 && okk) {  // if okk is 1 and and one child is present in left to mid and other in mid+1 to r.
            okk = 0;
            ll q1 = binary_search(centroid,l,mid);
            ll q2 = binary_search(centroid,mid+1,r);
            if(q2 == -1)    r = mid,found = vec[centroid][mid];
            else {
                // check which child is farthest
                if(far_dist[q1] < far_dist[q2])   l = mid,found = vec[centroid][mid]; 
                else r = mid,found = vec[centroid][mid];
            }
        }
        else if((q+1)%(mid-l+1) == 0)   r = mid,found = vec[centroid][mid]; // case-3
        else l = mid+1;
    }
    if(l == r){
        ll q = query(vec[centroid],l,l);
        if(q == 0)  found = vec[centroid][l];
    }
    return found;
}

ll tot = 0;

void solve(){
    cin>>n;
    tot+=n;
    assert(n>=1);
    vec.assign(n+1,{});
    loop(i,0,n-1){
        ll a,b; cin>>a>>b;
        vec[a].pb(b);
        vec[b].pb(a);
    }
    vec2 = vec;
    queries_found.clear();
    ok = 0,ok1 = 1,found_that_node = -1;
    int one_node = -1;
    far_dist.assign(n+1,0);
    bool done = 0;
    total_nodes = 0;
    while(1){
        int centroid = find_centroid(); // finding centroid
        ll l = 0,r = (ll)vec[centroid].size() - 1;
        ll found = binary_search(centroid,l,r); // find adjacent node that is on the path
        // if ok is one case-2
        if(ok){
            one_node = found;   break;
        }
        // no adjacent node has value less than D(C)
        if(found == -1){
            one_node = centroid;   break;
        }
        // first time we found a centroid that is on the path with two children on the path as well.
        if(found_that_node == -1 || !done){
            fnn3(centroid,-1,0);
            if(found_that_node != -1)   done = 1;
        }
        // deleting the edge between found node and centroid and reforming the tree
        if(clear_tree(found,centroid)){
            one_node = found;   break;
        }
    }
    assert(total_nodes <= 4*n);
    // assert(total_nodes <= 2*n*(ceil(log2(n))));
    cout<<"! "<<one_node<<endl;
}


int main()
{
    fast
    int t; cin>>t;
    assert(t>=1 && t<=4e3);
    while(t--) solve();
    assert(tot<=1e5);
    // assert(total_nodes <= 2e6);
    return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

struct dsu {
    vector<int> p;
    vector<int> sz;
    int n;

    dsu(int _n) : n(_n) {
        p = vector<int>(n);
        iota(p.begin(), p.end(), 0);
        sz = vector<int>(n, 1);
    }

    inline int get(int x) {
        if (p[x] == x) {
            return x;
        } else {
            return p[x] = get(p[x]);
        }
    }

    inline bool unite(int x, int y) {
        x = get(x);
        y = get(y);
        if (x == y) {
            return false;
        }
        p[x] = y;
        sz[y] += sz[x];
        return true;
    }

    inline bool same(int x, int y) {
        return (get(x) == get(y));
    }

    inline int size(int x) {
        return sz[get(x)];
    }

    inline bool root(int x) {
        return (x == get(x));
    }
};

int main() {
    int tt;
    cin >> tt;
    assert(1 <= tt && tt <= (int) 4e3);
    int sn = 0;
    while (tt--) {
        // input
        int n;
        cin >> n;
        assert(1 <= n && n <= (int) 1e5);
        sn += n;
        dsu uf(n);
        vector<vector<int>> g(n);
        for (int i = 0; i < n - 1; i++) {
            int x, y;
            cin >> x >> y;
            x--;
            y--;
            g[x].emplace_back(y);
            g[y].emplace_back(x);
            assert(uf.unite(x, y));
        }

        // alive
        vector<int> a(n, 1);

        // memo when k = 1
        vector<long long> d(n, -1);
        auto Ask = [&](int i) {
            if (d[i] == -1) {
                cout << "? 1 " << i + 1 << endl;
                cin >> d[i];
            }
            return d[i];
        };

        int w = -1;
        while (accumulate(a.begin(), a.end(), 0) > 2) {
            // find centroid
            int r;
            vector<int> c;
            vector<int> pv(n, -1);
            {
                r = (int) (max_element(a.begin(), a.end()) - a.begin());
                vector<int> sz(n, 1);
                vector<int> b(n);
                {
                    function<void(int, int)> Dfs = [&](int v, int p) {
                        for (int to : g[v]) {
                            if (to == p || !a[to]) {
                                continue;
                            }
                            Dfs(to, v);
                            sz[v] += sz[to];
                            b[v] = max(b[v], sz[to]);
                        }
                    };
                    Dfs(r, -1);
                }
                for (int i = 0; i < n; i++) {
                    if (!a[i]) {
                        b[i] = n + 1;
                    } else if (i != r) {
                        b[i] = max(b[i], sz[r] - sz[i]);
                    }
                }
                r = (int) (min_element(b.begin(), b.end()) - b.begin());
                for (int to : g[r]) {
                    if (a[to]) {
                        c.emplace_back(to);
                    }
                }
                {
                    function<void(int, int)> Dfs = [&](int v, int p) {
                        for (int to : g[v]) {
                            if (to == p || !a[to]) {
                                continue;
                            }
                            Dfs(to, v);
                            pv[to] = v;
                        }
                    };
                    Dfs(r, -1);
                }
            }

            // solve
            vector<int> new_a(n);
            function<void(int, int)> Dfs = [&](int v, int p) {
                new_a[v] = 1;
                for (int to : g[v]) {
                    if (to == p || !a[to]) {
                        continue;
                    }
                    Dfs(to, v);
                }
            };

            long long k = Ask(r) + 1;
            int sz = (int) c.size();
            auto Check = [&](int x, int y) {
                long long t;
                if (x == y) {
                    t = Ask(c[x]);
                } else {
                    cout << "? " << y - x + 1;
                    for (int i = x; i <= y; i++) {
                        cout << " " << c[i] + 1;
                    }
                    cout << endl;
                    cin >> t;
                }
                if (t == (y - x + 1) * k) {
                    return false;
                } else {
                    return true;
                }
            };
            int x, y;
            {
                int low = -1, high = sz - 1;
                while (high - low > 1) {
                    int mid = (high + low) >> 1;
                    if (Check(0, mid)) {
                        high = mid;
                    } else {
                        low = mid;
                    }
                }
                x = high;
            }
            {
                int low = 0, high = sz;
                while (high - low > 1) {
                    int mid = (high + low) >> 1;
                    if (Check(mid, sz - 1)) {
                        low = mid;
                    } else {
                        high = mid;
                    }
                }
                y = low;
            }
            if (Ask(r) == 0) {
                if (x == y) {
                    Dfs(c[x], r);
                    new_a[r] = 1;
                } else if (Ask(c[x]) == 0 && Ask(c[y]) == 0) {
                    if (w == -1) {
                        Dfs(c[x], r);
                        w = r;
                        new_a[r] = 1;
                    } else {
                        while (pv[w] != r) {
                            w = pv[w];
                        }
                        debug(w);
                        assert(w == c[x] || w == c[y]);
                        Dfs(w ^ c[x] ^ c[y], r);
                        w = r;
                        new_a[r] = 1;
                    }
                } else {
                    new_a[r] = 1;
                }
            } else {
                assert(x == y);
                Dfs(c[x], r);
            }
            swap(a, new_a);
            debug(a);
        }
        if (accumulate(a.begin(), a.end(), 0) == 2) {
            int x = (int) (max_element(a.begin(), a.end()) - a.begin());
            int y = (int) (max_element(a.begin() + x + 1, a.end()) - a.begin());
            if (w == -1) {
                // assert((Ask(x) == 0) != (Ask(y) == 0));
                if (Ask(x) == 0) {
                    a[y] = 0;
                    cerr << x << " " << y << endl;
                } else {
                    a[x] = 0;
                    cerr << y << " " << x << endl;
                }
            } else {
                assert(w == x || w == y);
                if (w == x) {
                    a[x] = 0;
                    cerr << x << " " << y << endl;
                } else {
                    a[y] = 0;
                    cerr << y << " " << x << endl;
                }
            }
        }
        cout << "! " << (max_element(a.begin(), a.end()) - a.begin()) + 1 << endl;
    }
    assert(sn <= 1e5);
    return 0;
}