TREELOOP - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: gunpoint_88
Tester: raysh07
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

DFS, rerooting

PROBLEM:

You’re given a tree on N vertices.
Alice will add one edge to this tree that doesn’t already exist; then Bob will remove one edge from the resulting graph such that the final state is again a tree.

Alice wants to minimize the diameter of the final tree, while Bob wants to maximize it.
Find the result if they choose their moves optimally.

EXPLANATION:

The graph formed after Alice’s edge addition looks like a single cycle with trees hanging off of it.
To make the final graph a tree, Bob must remove an edge from this cycle.

Notice that Bob can always just remove the edge added by Alice, so the answer is never smaller than the diameter of the original tree (we’ll call this d_o).
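For small trees, the game can be checked directly by exhaustive search. The sketch below is not part of the intended solution; it only spells out the min-max definition. The names `EdgeList`, `diameterOrMinusOne` and `bruteForce` are hypothetical, and vertices are assumed to be 0-indexed.

```cpp
#include <algorithm>
#include <climits>
#include <cstddef>
#include <queue>
#include <utility>
#include <vector>

using EdgeList = std::vector<std::pair<int, int>>;

// Diameter (in edges) of the graph if it is connected, or -1 otherwise.
// With exactly n-1 edges, "connected" is the same as "is a tree".
int diameterOrMinusOne(int n, const EdgeList& edges) {
    std::vector<std::vector<int>> adj(n);
    for (const auto& e : edges) {
        adj[e.first].push_back(e.second);
        adj[e.second].push_back(e.first);
    }
    auto bfs = [&](int src) {
        std::vector<int> dist(n, -1);
        std::queue<int> q;
        dist[src] = 0;
        q.push(src);
        while (!q.empty()) {
            int u = q.front();
            q.pop();
            for (int v : adj[u]) {
                if (dist[v] == -1) {
                    dist[v] = dist[u] + 1;
                    q.push(v);
                }
            }
        }
        return dist;
    };
    std::vector<int> d0 = bfs(0);
    if (std::find(d0.begin(), d0.end(), -1) != d0.end()) return -1;
    // Standard two-sweep BFS: the farthest vertex from any start is a diameter endpoint.
    int far = static_cast<int>(std::max_element(d0.begin(), d0.end()) - d0.begin());
    std::vector<int> d1 = bfs(far);
    return *std::max_element(d1.begin(), d1.end());
}

// Tries every edge Alice may add and every edge Bob may then remove.
// Roughly O(N^4), so only meant for tiny inputs.
int bruteForce(int n, const EdgeList& tree) {
    std::vector<std::vector<bool>> present(n, std::vector<bool>(n, false));
    for (const auto& e : tree) {
        present[e.first][e.second] = true;
        present[e.second][e.first] = true;
    }
    int best = INT_MAX;  // stays INT_MAX if Alice has no legal move
    for (int a = 0; a < n; ++a) {
        for (int b = a + 1; b < n; ++b) {
            if (present[a][b]) continue;  // Alice must add a new edge
            EdgeList withExtra = tree;
            withExtra.push_back({a, b});
            int worst = -1;  // Bob maximizes over removals that leave a tree
            for (std::size_t i = 0; i < withExtra.size(); ++i) {
                EdgeList remaining = withExtra;
                remaining.erase(remaining.begin() + static_cast<std::ptrdiff_t>(i));
                worst = std::max(worst, diameterOrMinusOne(n, remaining));
            }
            best = std::min(best, worst);  // Alice minimizes
        }
    }
    return best;
}
```

Removals that disconnect the graph return -1 and are therefore never chosen by Bob; removing Alice's own edge is always valid, which is why the result is never below d_o.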

While there are many possible choices for Alice’s move, most of them can actually be discarded.
In particular, we have:
Claim: It’s always optimal for Alice to add an edge between two vertices at distance 2 from each other.

Proof

Suppose Alice adds an edge between vertices u and v, such that dist(u, v) = D \geq 3.
Let the vertices on this path be x_1, x_2, \ldots, x_{D+1}, where x_1 = u and x_{D+1} = v.

Let’s look at Bob’s optimal action here.

  • If Bob chooses to delete the edge between x_i and x_{i+1} for some 1\leq i \leq D, Alice could’ve chosen her edge to be either (x_1, x_D) or (x_2, x_{D+1}) - at least one of them would not increase the diameter of the resulting tree (since Bob’s options would be more limited).
  • If Bob chooses to delete the edge Alice added, it means that deleting any of the intermediate edges was not optimal; so Alice could’ve chosen to connect two vertices of the path that are closer together instead, and the result won’t change.

Performing this repeatedly, we eventually arrive at a state where Alice chooses an edge between two vertices at distance two.


The claim can be rephrased as “there is some vertex u such that Alice joins two neighbors of u”.

So, let’s fix a vertex u, and look at all its neighbors; say there are k of them labelled v_1, v_2, \ldots, v_k.
Suppose we join v_i and v_j with an edge - let’s analyze what Bob’s move should be.

The cycle created has length 3, so Bob has three choices.
One of them is to remove the edge Alice just added; resulting in a diameter of d_o.
The other two options are symmetric, removing the edge between u and one of v_i or v_j; so let’s look at what happens when the v_i \leftrightarrow u edge is removed.

For convenience, we root the tree at u.
Also, let L_i denote the longest path from v_i that goes into the subtree of v_i; i.e., the longest path from v_i to some leaf in its subtree.

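It’s enough to only consider paths passing through the newly added edge: every other path in the final tree uses only original edges, so it already existed in the original tree and has length at most d_o.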
In particular, the longest of them in the final tree involves:

  • The longest path from v_i into its subtree; which is L_i
  • The longest path from v_j into anywhere other than the subtree of v_i.
    There are two choices here:
    • We can go into the subtree of v_j; with length L_j.
    • We can go up to u, then into the subtree of some child of u that’s not v_i or v_j. This is 2 + L_r edges, if we choose v_r.
  • The total length is thus either L_i + L_j + 1, or L_i + 3 + L_r for some r.
    Notice that the first case is never optimal (the original tree already had a path of length L_i + L_j + 2 so Bob will never choose this).

So, if Bob deletes the edge between u and v_i, the best he can do is L_i + 3 + L_r for some r\neq i, j.
By symmetry, deleting the edge between u and v_j instead gives L_j + 3 + L_r; so Bob can get L_r + 3 + \max(L_i, L_j) overall.
The choice of r is completely up to Bob, so he’ll always choose the maximum L_r possible.

So, for a fixed u, it’s optimal for Alice to choose the two vertices with smallest L_i values.
In particular, if we have L_1 \geq L_2 \geq\ldots\geq L_k, then the best Alice can do is limit Bob to attaining L_1 + L_{k-1} + 3.


There are a couple of edge cases to consider though:

  • If k = 1, of course u doesn’t even have two neighbors so ignore it.
  • If k = 2, then Bob doesn’t have another L_r to choose; so the best he can do is remove the edge Alice adds.
  • If k = 3, then Alice can force Bob to attain L_1 + L_3 + 3 (instead of L_1 + L_2 + 3 as the general case would give) by choosing to add an edge between v_1 and v_2.
  • The general solution outlined above works for all k \geq 4.
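The case analysis for a fixed u can be sketched as a small function. The name `bobBestAtVertex` is hypothetical; the input is assumed to be the list of L_i values of the neighbors of u, in any order. For k = 2 the sketch returns L_1 + L_2 + 2 (the length of the original path through u), which is the value the author's code uses; it never exceeds d_o, so after taking the maximum with d_o it agrees with "Bob removes Alice's edge".

```cpp
#include <algorithm>
#include <cstddef>
#include <functional>
#include <vector>

// Best diameter Bob can force through a triangle at u, assuming Alice picks
// the best pair of neighbors of u. Returns -1 if u has fewer than two neighbors.
long long bobBestAtVertex(std::vector<long long> L) {
    // Sort so that L[0] >= L[1] >= ... >= L[k-1] (0-indexed version of L_1..L_k).
    std::sort(L.begin(), L.end(), std::greater<long long>());
    const std::size_t k = L.size();
    if (k < 2) return -1;                 // no two neighbors to join
    if (k == 2) return L[0] + L[1] + 2;   // no third branch L_r for Bob to use
    if (k == 3) return L[0] + L[2] + 3;   // Alice joins the two deepest neighbors
    return L[0] + L[k - 2] + 3;           // Alice joins the two shallowest neighbors
}
```

For example, with L = (3, 2, 1, 0) Alice joins the neighbors with values 1 and 0, and Bob reaches 3 + 1 + 3 = 7, which equals the original path of length 3 + 2 + 2 = 7 through u.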

So, all we really need to know is the L_i values for the neighbors of u - in other words, the longest paths from each of them that don’t include u.
These values can be found with the help of rerooting:

  • First, root the tree arbitrarily, and compute using DFS the maximum path length going down for each node.
  • Then, reroot to find the longest path going up for each node.
    An introduction to this technique can be found here.
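The two passes above can be sketched as follows. This is a minimal illustration rather than the author's or tester's implementation: it uses an iterative DFS and a "two largest branches" trick instead of sorting or prefix/suffix maxima. The name `neighborDepths` and the arrays `down` and `up` are hypothetical, and vertices are assumed to be 0-indexed with the root at 0.

```cpp
#include <algorithm>
#include <utility>
#include <vector>

// For every vertex u, returns the L values of its neighbors: for each neighbor v,
// the longest path that starts at v and does not use the edge (u, v).
std::vector<std::vector<int>> neighborDepths(
        int n, const std::vector<std::pair<int, int>>& edges) {
    std::vector<std::vector<int>> adj(n);
    for (const auto& e : edges) {
        adj[e.first].push_back(e.second);
        adj[e.second].push_back(e.first);
    }

    // Iterative DFS from vertex 0; parents always appear before children in `order`.
    std::vector<int> parent(n, -1), order, stack = {0};
    while (!stack.empty()) {
        int u = stack.back();
        stack.pop_back();
        order.push_back(u);
        for (int v : adj[u]) {
            if (v != parent[u]) {
                parent[v] = u;
                stack.push_back(v);
            }
        }
    }

    // Pass 1: down[u] = longest path from u into its own subtree.
    std::vector<int> down(n, 0);
    for (int i = n - 1; i >= 1; --i) {
        int u = order[i];
        down[parent[u]] = std::max(down[parent[u]], down[u] + 1);
    }

    // Pass 2 (rerooting): up[v] = longest path from parent[v] that avoids v's subtree.
    std::vector<int> up(n, 0);
    for (int u : order) {
        int best1 = 0, best2 = 0;  // two longest branches leaving u
        auto offer = [&](int x) {
            if (x > best1) { best2 = best1; best1 = x; }
            else if (x > best2) { best2 = x; }
        };
        if (parent[u] != -1) offer(up[u] + 1);
        for (int v : adj[u]) if (v != parent[u]) offer(down[v] + 1);
        for (int v : adj[u]) {
            if (v == parent[u]) continue;
            // Exclude v's own branch: if it is the longest, fall back to the second longest.
            up[v] = (down[v] + 1 == best1) ? best2 : best1;
        }
    }

    std::vector<std::vector<int>> L(n);
    for (int u = 0; u < n; ++u)
        for (int v : adj[u])
            L[u].push_back(v == parent[u] ? up[u] : down[v]);
    return L;
}
```

On the path 0-1-2-3, vertex 1 sees values 0 (toward vertex 0) and 1 (toward vertex 2), and the endpoint 0 sees the single value 2.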

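Once all the L_i values are known, solving the problem is simple: compute Bob’s best outcome at each vertex u as outlined above, take the minimum over all u (this is Alice’s choice), and output the maximum of that minimum and d_o.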

TIME COMPLEXITY:

\mathcal{O}(N\log N) per testcase, dominated by sorting the L_i values at each vertex.

CODE:

Author's code (C++)
#include<bits/stdc++.h>
using namespace std;
using ll=long long;
const ll inf=1e9+7;

#ifdef ANI
#include "D:/DUSTBIN/local_inc.h"
#else
#define dbg(...) 0
#endif

array<ll,2> solution(vector<array<ll,2>> tree) {
	ll n=tree.size()+1;
	vector<vector<ll>> e(n);
	for(auto _:tree) {
		e[_[0]].push_back(_[1]);
		e[_[1]].push_back(_[0]);
	}

	vector<ll> d(n,0); // maximum depth in subtree of node
	function<void(ll,ll)> dfs=[&](ll cur,ll par)->void{
		for(ll node:e[cur]) {
			if(node^par) {
				dfs(node,cur);
				d[cur]=max(d[cur],d[node]+1);
			}
		}
	};
	dfs(0,-1);

	ll diam=0,minbreak=inf;

	function<void(ll,ll,ll)> dfs2=[&](ll cur,ll par,ll up)->void{
		vector<array<ll,2>> dd;
		if(up>=0) dd.push_back({up,-1});
		for(ll node:e[cur])
			if(node^par)
				dd.push_back({d[node],node});
		sort(dd.begin(),dd.end(),greater<array<ll,2>>());
		ll k=dd.size();
		if(dd.size()==2)
			minbreak=min(minbreak,dd[0][0]+dd[1][0]+2);
		else if(dd.size()==3) {
			minbreak=min(minbreak,dd[0][0]+dd[2][0]+3);
		} else if(dd.size()>3) {
			minbreak=min(minbreak,dd[0][0]+dd[k-2][0]+3);
		}
		diam=max(diam,dd[0][0]+1);
		if(dd.size()>=2)
			diam=max(diam,dd[0][0]+dd[1][0]+2);

		for(ll node:e[cur]) {
			if(node^par) {
				ll node_up=0;
				if(dd[0][1]==node) {
					node_up=(dd.size()>1?dd[1][0]+1:0);
				} else {
					node_up=dd[0][0]+1;
				}
				dfs2(node,cur,node_up);
			}
		}
	};
	dfs2(0,-1,-inf);

	return {max(diam,minbreak),diam};
}

void work() {
	ll n; cin>>n;
	vector<array<ll,2>> tree(n-1);
	for(ll i=0;i<n-1;i++) {
		ll u,v; cin>>u>>v;
		tree[i]={u-1,v-1};
	}
	auto res=solution(tree);
	cout<<res[0]<<"\n";
}

int main() {
	ll t; cin>>t;
	while(t--) {
		work();
	}
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18
#define f first
#define s second

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

struct input_checker {
    string buffer;
    int pos;
 
    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";
 
    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }
 
    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }
 
    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }
 
    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }
 
    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }
 
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }
 
    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }
 
    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }
 
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }
 
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }
 
    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

input_checker inp;
int sum_n;
int n;
const int N = 2e5 + 69;

int in[N];
int out[N];
bool vis[N];
int par[N];
vector <int> adj[N];
 
void dfs(int u){
    vis[u] = true;
    for (int v: adj[u]){
        if (!vis[v]){
            dfs(v);
            in[u] = max(in[u] , in[v] + 1);
        }
    }
}
 
void dfs2(int u){
    vis[u] = true;
    vector <int> val;
    int sz = 0;
    for (int v: adj[u]){
        if (!vis[v]){
            par[v] = u;
            val.push_back(in[v]);
            sz++;
            out[v] = out[u] + 1;
        }
    }
    int pmax[sz + 1];
    int smax[sz + 1];
    
    pmax[0] = -1;
    smax[sz] = -1;
    for (int i=1; i<=sz; i++){
        pmax[i] = max(pmax[i-1], val[i-1]);
    }
    for (int i=sz-1; i>=0; i--){
        smax[i] = max(smax[i+1], val[i]);
    }
    
    int c = 0;
    for (int v: adj[u]){
        if (!vis[v]){
            out[v] = max(out[v], pmax[c] + 2);
            out[v] = max(out[v], smax[c+1] + 2);
            c++;
        }
    }
    
    for (int v: adj[u]){
        if (!vis[v])
        dfs2(v);
    }
}

void Solve() 
{
    n = inp.readInt(1, N); inp.readEoln();
    sum_n += n;
    
    for (int i = 1; i <= n; i++){
        in[i] = out[i] = 0;
        adj[i].clear();
    }
    
    for (int i = 1; i < n; i++){
        int u = inp.readInt(1, n); inp.readSpace();
        int v = inp.readInt(1, n); inp.readEoln();
        
        adj[u].push_back(v);
        adj[v].push_back(u);
        
        assert(u != v);
    }
    
    for (int i = 1; i <= n; i++){
        vis[i] = false;
    }
    
    dfs(1);
    
    for (int i = 1; i <= n; i++){
        assert(vis[i]);
        vis[i] = false;
    }
    
    dfs2(1);

    int diam = 0;
    for (int i = 1; i <= n; i++){
        diam = max(diam, in[i]);
        diam = max(diam, out[i]);
    }
    
    bool poss = false;
    
    for (int i = 1; i <= n; i++){
        vector <int> ok;
        for (auto x : adj[i]){
            if (x == par[i]){
                ok.push_back(out[i] - 1);
            } else {
                ok.push_back(in[x]);
            }
        }
        
        sort(ok.begin(), ok.end());
        if (ok.size() == 1) continue;
        int lol;
        if (ok.size() == 2) lol = ok[0] + ok[1] + 2;
        else if (ok.size() == 3) lol = ok[0] + ok[2] + 3;
        else lol = ok[1] + ok.back() + 3;
        
        if (lol <= diam) poss = true;
    }
    
    if (poss) cout << diam << "\n";
    else cout << diam + 1 << "\n";
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    
    t = inp.readInt(1, (int)5e4);
    inp.readEoln();
    
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    
    assert(sum_n <= (int)2e5);
    inp.readEof();
    
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}