AJ - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: Danny Boy
Testers: Utkarsh Gupta, Hriday
Editorialist: Nishank Suresh

DIFFICULTY:

3242

PREREQUISITES:

Diameter of a tree

PROBLEM:

You are given a tree on N vertices and Q queries, each of the form (u, v). For each query, find the largest integer k such that there exists a sequence of vertices x_1, x_2, \ldots, x_k satisfying:

  • x_1 = u
  • x_k = v
  • For each 1 \leq i \lt k, dist(x_i, x_{i+1}) = i, where the distance between two vertices is the number of edges on the unique shortest path between them.

A further constraint is that u is never a leaf.
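To make the definition concrete, here is a brute-force sketch for tiny trees. It is not taken from any of the solutions below, and the names `allPairsDist` and `bruteAnswer` are hypothetical. It tracks the set of vertices that can appear as x_k and extends that set with a step of length k, which makes it a convenient reference for checking the formula derived in the explanation on small trees.

```cpp
#include <cassert>
#include <queue>
#include <vector>

// Distances between all pairs of vertices, via one BFS per source: O(n^2) on a tree.
std::vector<std::vector<int>> allPairsDist(const std::vector<std::vector<int>>& g) {
    int n = static_cast<int>(g.size());
    std::vector<std::vector<int>> d(n, std::vector<int>(n, -1));
    for (int s = 0; s < n; ++s) {
        std::queue<int> q;
        q.push(s);
        d[s][s] = 0;
        while (!q.empty()) {
            int a = q.front(); q.pop();
            for (int b : g[a])
                if (d[s][b] == -1) { d[s][b] = d[s][a] + 1; q.push(b); }
        }
    }
    return d;
}

// Hypothetical brute force: largest k such that some x_1 = u, ..., x_k = v is valid.
// cur marks every vertex that can appear as x_k for the current k.
int bruteAnswer(const std::vector<std::vector<int>>& g, int u, int v) {
    int n = static_cast<int>(g.size());
    auto d = allPairsDist(g);
    std::vector<char> cur(n, 0);
    cur[u] = 1;                    // the only possible x_1 is u
    int best = (u == v) ? 1 : 0;   // k = 1 works only when u == v
    for (int k = 1;; ++k) {        // extend x_k to x_{k+1} with a step of length k
        std::vector<char> nxt(n, 0);
        bool any = false;
        for (int a = 0; a < n; ++a) {
            if (!cur[a]) continue;
            for (int b = 0; b < n; ++b)
                if (d[a][b] == k) { nxt[b] = 1; any = true; }
        }
        if (!any) break;           // the step is longer than any path from a candidate
        cur = nxt;
        if (cur[v]) best = k + 1;
    }
    return best;
}
```

For example, on the path 0-1-2-3 (0-indexed), `bruteAnswer(g, 1, 3)` returns 4, realised by the sequence 1, 2, 0, 3.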

EXPLANATION:

As it turns out, the actual solution to a query is rather simple; proving that it works is the hard part.

First, consider things from the perspective of v.

  • The last step ends at v and has length k-1, so some path of length k-1 must have v as an endpoint. Let M be the maximum length of a path that has v as one of its endpoints. Then k-1 \leq M, so M+1 is an upper bound for the answer.
  • Second, the steps from u to v have total length 1 + 2 + \ldots + (k-1) = k(k-1)/2. Since a tree is bipartite, any walk from u to v has the same length parity as the (unique) u-v path, so k(k-1)/2 must have the same parity as dist(u, v).
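The parity argument can be written out with a 2-colouring of the tree. Let c(x) denote the parity of the distance from a fixed root to x (c is notation introduced only for this derivation). Adjacent vertices get different colours, so following a path of length i flips the colour exactly i times:

```latex
c(x_{i+1}) \equiv c(x_i) + i \pmod 2 \quad (1 \le i < k)
\quad\Longrightarrow\quad
c(v) - c(u) \equiv \sum_{i=1}^{k-1} i = \frac{k(k-1)}{2} \pmod 2,
\qquad
dist(u, v) \equiv c(u) + c(v) \equiv c(v) - c(u) \pmod 2.
```

Combining the two congruences gives k(k-1)/2 \equiv dist(u, v) \pmod 2. The same identity is what lets the editorialist's code compute the parity of the u-v path as (d1[u] % 2) ^ (d1[v] % 2).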

These two conditions are in fact sufficient. Let M be as described above, and let p be the parity of the length of the u-v path in the tree. The answer is then the largest integer k such that k \leq M+1 and k(k-1)/2 has parity p.
Once M is known, finding k is easy: k(k-1)/2 \bmod 2 depends only on k \bmod 4 (it is 0, 1, 1, 0 for k \equiv 1, 2, 3, 0 \pmod 4), so any three consecutive values of k contain both parities. Hence the answer is one of \{M+1, M, M-1\}; take the largest of them that satisfies the parity condition.
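The selection step can be sketched as a small helper. `answerFromM` is a hypothetical name, and M (the longest path starting at v) and D = dist(u, v) are assumed to be computed already.

```cpp
#include <cassert>

// Hypothetical helper: largest k <= M + 1 such that k(k-1)/2 has the same parity
// as D = dist(u, v). M is the length of the longest path starting at v.
long long answerFromM(long long M, long long D) {
    // k(k-1)/2 mod 2 depends only on k mod 4, so any three consecutive k contain
    // both parities; for M >= 1 the loop returns within three iterations.
    for (long long k = M + 1; k >= 1; --k)
        if ((k * (k - 1) / 2) % 2 == D % 2) return k;
    return -1;  // not reached for M >= 1
}
```

For example, with M = 3 and an even D it returns 4 (since 4 * 3 / 2 = 6 is even), and with M = 3 and an odd D it returns 3.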

Computing M is also fairly easy. It is well known that, in a tree, the farthest vertex from any given vertex can be taken to be one of the two endpoints of a diameter. So use any standard algorithm (for example, dfs/bfs twice) to compute the endpoints of a diameter of the tree, then compute the distances from these two endpoints to every vertex; M for a vertex is the larger of its two distances. This precomputation can be done in \mathcal{O}(N), after which each query is answered in \mathcal{O}(1): the parity p of the u-v path equals the parity of d(r, u) + d(r, v) for any fixed vertex r.
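A sketch of this precomputation, assuming 0-indexed vertices and a tree given as adjacency lists. `bfsDist` and `eccentricities` are hypothetical names; the steps mirror what the editorialist's code does inline.

```cpp
#include <algorithm>
#include <cassert>
#include <queue>
#include <vector>

// BFS distances from src in an unweighted tree given as adjacency lists.
std::vector<int> bfsDist(const std::vector<std::vector<int>>& g, int src) {
    std::vector<int> dist(g.size(), -1);
    std::queue<int> q;
    dist[src] = 0;
    q.push(src);
    while (!q.empty()) {
        int a = q.front(); q.pop();
        for (int b : g[a])
            if (dist[b] == -1) { dist[b] = dist[a] + 1; q.push(b); }
    }
    return dist;
}

// Hypothetical helper: M[v] = length of the longest path with v as an endpoint.
// BFS from vertex 0 finds a diameter endpoint a, BFS from a finds the other
// endpoint b, and each vertex takes the larger of its distances to a and b.
std::vector<int> eccentricities(const std::vector<std::vector<int>>& g) {
    auto d0 = bfsDist(g, 0);
    int a = static_cast<int>(std::max_element(d0.begin(), d0.end()) - d0.begin());
    auto da = bfsDist(g, a);
    int b = static_cast<int>(std::max_element(da.begin(), da.end()) - da.begin());
    auto db = bfsDist(g, b);
    std::vector<int> M(g.size());
    for (int v = 0; v < static_cast<int>(g.size()); ++v) M[v] = std::max(da[v], db[v]);
    return M;
}
```

On the path 0-1-2-3 it returns {3, 2, 2, 3}, and on the star with centre 0 and leaves 1, 2, 3 it returns {1, 2, 2, 2}.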

Now comes the harder part: proving that the above solution is indeed correct. Note that the proof below relies specifically on the fact that u is not a leaf.

Proof

The proof below is from the author; more detail may be added to it later.

In this proof, let k be one less than the value computed above, i.e., the largest integer such that k \leq M and k(k+1)/2 has the same parity as the length of the u-v path.

Since both conditions are necessary, no valid sequence can use more than k steps. To prove that they are also sufficient, we explicitly construct a valid sequence with exactly k steps.

We make the moves backwards, i.e., start from v and make steps of lengths k, k-1, \ldots, 1 to reach u. This is equivalent, since distances are symmetric.

Define a function f, where f(i, j, x) is true if it is possible to start at i and reach j using steps of lengths x, x-1, \ldots, 1 in that order, and false otherwise.

Claim: If the following three conditions hold, f(i, j, x) is true:

  • x(x+1)/2 \geq dist(i, j)
  • x \leq dist(i, j)
  • dist(i, j) and x(x+1)/2 have the same parity.
Proof

Let c denote the current vertex. Initially, c = i. Repeat the following strategy:

  • If dist(c, j) \geq x, jump directly towards j
  • Otherwise, jump directly towards i (and hence, away from j)
  • Then, decrease x by 1 and continue the process until the step of length 1 has been made

This process does the following:

  • Since x \leq dist(i, j), the first move is guaranteed to be a jump towards j. We then make some more jumps toward j.
  • Consider the first moment we make a jump away from j. After this point, any jump we make of some length L will ensure that we are within a distance of 2L - 1 from j.
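  • In particular, after the jump with L = 1, we are within a distance of 1 from j. Since the tree is bipartite, a step of length L changes the parity of the distance to j by L, so the final distance has the same parity as dist(i, j) - x(x+1)/2, which is even by the parity condition. Hence we end exactly at j.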

The claim in the second point (that we stay within distance 2L - 1 of j) has some counterexamples when the current distance is \leq 3 and j is a leaf. However, the input guarantees that u is not a leaf, and u is what plays the role of j when the claim is applied, so these counterexamples cannot occur.

Now for the actual construction: given u and v, find a vertex g such that dist(g, v) = k; such a vertex exists because k \leq M.
Then, starting from v, alternately jump towards g and towards v until you reach a point where the distance from the current vertex to u is larger than the length of the next step (recall that we are making moves in decreasing order of length, starting from v).

Such a moment always exists. Let L be the length of the current jump. It further holds that the distance from the current vertex to u is \leq L(L+1)/2 (again, there are some counterexamples when u is a leaf, but they cannot occur here).

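The parity condition is also satisfied: the steps made so far have a total length with the same parity as the distance from v to the current vertex, and k was chosen so that all k steps together match the parity of the u-v path, so the remaining sum L(L+1)/2 has the same parity as the distance from the current vertex to u. Thus all three conditions of the claim hold, and the remaining steps, starting with length L, reach u. This completes the construction.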

TIME COMPLEXITY

\mathcal{O}(N + Q) per test case.

CODE:

Setter's code (C++)
#include <bits/stdc++.h>
#define ll long long
#define int long long
#define fi first
#define se second
#define mat vector<vector<ll>> 
using namespace std;
void db() {cout << endl;}
template <typename T, typename ...U> void db(T a, U ...b) {cout << a << ' ', db(b...);}
#ifdef Cloud
#define file freopen("input.txt", "r", stdin), freopen("output.txt", "w", stdout)
#else
#define file ios::sync_with_stdio(false); cin.tie(0)
#endif
const int N = 2e5 + 1, mod = 1e9 + 7;// inf = 1e9;
const int inf = 1e9;
int d[N], anc[N][20];
vector<int> g[N];
pair<int, int> dfs(int u, int p){
    pair<int, int> ans = {0, u};
    for (int i : g[u]){
        if (i == p) continue;
        d[i] = d[u] + 1;
        anc[i][0] = u;
        for (int j = 1; j < 20; j++) anc[i][j] = anc[anc[i][j - 1]][j - 1];
        pair<int, int> tmp = dfs(i, u);
        if (tmp.fi + 1 > ans.fi) ans = {tmp.fi + 1, tmp.se};
    }
    return ans;
}
int lca(int u, int v){
    if (d[u] < d[v]) swap(u, v);
    for (int i = 19; i >= 0; i--) if (d[anc[u][i]] >= d[v]) u = anc[u][i];
    if (u == v) return v;
    for (int i = 19; i >= 0; i--) if (anc[u][i] != anc[v][i]) u = anc[u][i], v = anc[v][i];
    return anc[u][0];
}
int dis(int u, int v){
    return d[u] + d[v] - 2 * d[lca(u, v)];
}
signed main(){
    file;
    int n, q;
    cin >> n >> q;
    for (int i = 1; i <= n; i++) g[i].clear();
    for (int i = 0; i < n - 1; i++){
        int u, v;
        cin >> u >> v;
        g[u].push_back(v);
        g[v].push_back(u);
    }
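    auto x = dfs(1, -1);      // x.se: farthest vertex from 1, one endpoint of a diameter
    d[x.se] = 0;
    auto y = dfs(x.se, -1);   // y.se: farthest vertex from x.se, the other endpoint
    int ans[n + 1]{};         // ans[i] will hold M for vertex i
    for (int i = 1; i <= n; i++) ans[i] = max(ans[i], d[i]);      // distance from x.se
    d[y.se] = 1;              // depths in this pass start at 1, hence the -1 below
    dfs(y.se, -1);            // recompute depths from y.se; dis() below uses them
    for (int i = 1; i <= n; i++) ans[i] = max(ans[i], d[i] - 1);  // distance from y.se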
    while (q--){
        int u, v;
        cin >> u >> v;
        int res = ans[v], D = dis(u, v);
        while ((res * (res + 1) / 2) % 2 != D % 2) res--;
        cout << res + 1 << '\n';
    }
}
Tester's code (C++)
/**
 * the_hyp0cr1t3
 * 07.09.2022 14:37:05
**/
#ifdef W
    #include <k_II.h>
#else
    #include <bits/stdc++.h>
    using namespace std;
#endif

template<class T> class Y {
    T f;
public:
    template<class U> explicit Y(U&& f): f(forward<U>(f)) {}
    template<class... Args> decltype(auto) operator()(Args&&... args) {
        return f(ref(*this), forward<Args>(args)...);
    }
}; template<class T> Y(T) -> Y<T>;

// -------------------- Input Checker Start --------------------

long long readInt(long long l, long long r, char endd)
{
    long long x = 0;
    int cnt = 0, fi = -1;
    bool is_neg = false;
    while(true)
    {
        char g = getchar();
        if(g == '-')
        {
            assert(fi == -1);
            is_neg = true;
            continue;
        }
        if('0' <= g && g <= '9')
        {
            x *= 10;
            x += g - '0';
            if(cnt == 0)
                fi = g - '0';
            cnt++;
            assert(fi != 0 || cnt == 1);
            assert(fi != 0 || is_neg == false);
            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
        }
        else if(g == endd)
        {
            if(is_neg)
                x = -x;
            if(!(l <= x && x <= r))
            {
                cerr << "L: " << l << ", R: " << r << ", Value Found: " << x << '\n';
                assert(false);
            }
            return x;
        }
        else
        {
            assert(false);
        }
    }
}

string readString(int l, int r, char endd)
{
    string ret = "";
    int cnt = 0;
    while(true)
    {
        char g = getchar();
        assert(g != -1);
        if(g == endd)
            break;
        cnt++;
        ret += g;
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}

long long readIntSp(long long l, long long r) { return readInt(l, r, ' '); }
long long readIntLn(long long l, long long r) { return readInt(l, r, '\n'); }
string readStringSp(int l, int r) { return readString(l, r, ' '); }
string readStringLn(int l, int r) { return readString(l, r, '\n'); }
void readEOF() { assert(getchar() == EOF); }

vector<int> readVectorInt(int n, long long l, long long r)
{
    vector<int> a(n);
    for(int i = 0; i < n - 1; i++)
        a[i] = readIntSp(l, r);
    a[n - 1] = readIntLn(l, r);
    return a;
}

// -------------------- Input Checker End --------------------

int main() {
#if __cplusplus > 201703L
    namespace R = ranges;
#endif
    ios_base::sync_with_stdio(false), cin.tie(nullptr);
    const int LG = 18;

    int n = readIntSp(3, 2e5);
    int q = readIntLn(1, 2e5);
    vector<vector<int>> g(n);
    vector<pair<int, int>> edges;
    for(int i = 0; i < n - 1; i++) {
        int u = readIntSp(1, n) - 1;
        int v = readIntLn(1, n) - 1;
        g[u].push_back(v);
        g[v].push_back(u);

        assert(u != v);
        if(u > v) swap(u, v);
        edges.emplace_back(u, v);
    }

    sort(edges.begin(), edges.end());
    for(int i = 1; i < edges.size(); i++)
        assert(edges[i] != edges[i - 1]);

    vector<int> depth(n);
    array<vector<int>, LG> anc;
    for(auto& x: anc) x.assign(n, -1);

    vector<array<array<int, 2>, 2>> best(n);
    Y([&](auto dfs, int v, int p) -> void {
        anc[0][v] = p;
        for(int k = 1; k < LG; k++)
            if(~anc[k - 1][v]) anc[k][v] = anc[k - 1][anc[k - 1][v]];

        best[v].fill({0, v});
        for(auto& x: g[v]) if(x ^ p) {
            depth[x] = depth[v] + 1;
            dfs(x, v);
            array<int, 2> cand = { best[x][0][0] + 1, x };
            best[v][1] = cand >= best[v][0]?
                            exchange(best[v][0], cand)
                                : max(best[v][1], cand);
        }
    })(0, -1);

    Y([&](auto dfs, int v, int p) -> void {
        for(auto& x: g[v]) if(x ^ p) {
            array<int, 2> cand = { best[v][best[v][0][1] == x][0] + 1, v };
            best[x][1] = cand >= best[x][0]?
                            exchange(best[x][0], cand)
                                : max(best[x][1], cand);
            dfs(x, v);
        }
    })(0, -1);

    auto LCA = [&](int u, int v) {
        if(depth[u] < depth[v]) swap(u, v);
        for(int z = 0; z < LG; z++)
            if(depth[u] - depth[v] >> z & 1) u = anc[z][u];

        if(u == v) return u;

        for(int z = LG - 1; ~z; z--)
            if(anc[z][u] ^ anc[z][v])
                u = anc[z][u], v = anc[z][v];

        return anc[0][u];
    };

    auto dist = [&](int u, int v) {
        return depth[u] + depth[v] - 2 * depth[LCA(u, v)];
    };

    while(q--) {
        int u = readIntSp(1, n) - 1;
        int v = readIntLn(1, n) - 1;
        assert(g[u].size() > 1);

        int ans = best[v][0][0];
        while(1LL * ans * (ans + 1) / 2 - dist(u, v) & 1)
            ans--;

        cout << ans + 1 << '\n';
    }

} // ~W
Editorialist's code (C++)
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
	ios::sync_with_stdio(false); cin.tie(0);

	int n, m; cin >> n >> m;
	vector<vector<int>> g(n);
	for (int i = 0; i < n-1; ++i) {
		int u, v; cin >> u >> v;
		g[--u].push_back(--v);
		g[v].push_back(u);
	}
	auto bfs = [&] (int src) {
		vector<int> dist(n, -1);
		dist[src] = 0;
		queue<int> q; q.push(src);
		while (!q.empty()) {
			int u = q.front(); q.pop();
			for (int v : g[u]) {
				if (dist[v] == -1) {
					dist[v] = 1 + dist[u];
					q.push(v);
				}
			}
		}
		return dist;
	};
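	auto tmp = bfs(0);
	int u1 = max_element(begin(tmp), end(tmp)) - begin(tmp); // one endpoint of a diameter
	auto d1 = bfs(u1);
	int u2 = max_element(begin(d1), end(d1)) - begin(d1);    // the other endpoint
	auto d2 = bfs(u2);

	while (m--) {
		int u, v; cin >> u >> v; --u, --v;
		int mx = max(d1[v], d2[v]); // M: length of the longest path starting at v
		// parity of dist(u, v): distances from u1 2-colour the tree
		int par = (d1[u]%2)^(d1[v]%2);
		// largest mx <= M with mx*(mx+1)/2 of parity par; the answer is mx + 1
		while (1) {
			int par2 = ((1LL*mx*(mx+1))/2)%2;
			if (par == par2) break;
			--mx;
		}
		cout << mx+1 << '\n';
	}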
	}
}