TREEQR - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: isheoran
Preparer: mexomerf
Tester & Editorialist: iceknight1093

DIFFICULTY:

2967

PREREQUISITES:

DFS, familiarity with bitwise AND

PROBLEM:

You’re given a tree with weighted edges, and Q queries on it.
For each query:

  • You’re given a vertex u. Find two vertices x and y such that:
    • u lies on the path between x and y
    • The bitwise AND of the weights of all edges on the path between them is \gt 0
    • The distance between x and y (the number of edges on the path) is maximum among all such (x, y) pairs

EXPLANATION:

Let’s solve for a single vertex u first.

Consider some (x, y) such that the path between them has bitwise AND \gt 0.
Then, there must exist some bit b such that b is set in the weight of every edge on this path.

So, suppose we fix a bit b and only consider those edges whose weight has b set in it.
These edges form a forest; consider the component (i.e., tree) that contains u.
In this tree, we now just want to find the longest path that includes u.
The answer for u is then the longest such path across all b.
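A minimal sketch of the filtering step for one bit, assuming 0-indexed vertices and an edge list of (u, v, w) triples. The names `Edge` and `componentOf` are hypothetical and do not come from the reference solutions. It keeps only edges whose weight has bit b set and collects the vertices reachable from a given source.

```cpp
#include <bits/stdc++.h>
using namespace std;

// Hypothetical edge representation: endpoints u, v (0-indexed) and weight w.
struct Edge { int u, v, w; };

// Returns the vertices of the tree (in the forest of bit-b edges) containing src,
// in BFS order. Only edges whose weight has bit b set are traversed.
vector<int> componentOf(int n, const vector<Edge>& edges, int b, int src) {
    vector<vector<int>> adj(n);
    for (const Edge& e : edges) {
        if ((e.w >> b) & 1) {            // the edge survives only if bit b is set
            adj[e.u].push_back(e.v);
            adj[e.v].push_back(e.u);
        }
    }
    vector<int> order{src};
    vector<bool> seen(n, false);
    seen[src] = true;
    for (size_t i = 0; i < order.size(); ++i) {
        for (int nxt : adj[order[i]]) {
            if (!seen[nxt]) {
                seen[nxt] = true;
                order.push_back(nxt);
            }
        }
    }
    return order;
}
```

For example, with edges (0, 1, 3), (1, 2, 1), (1, 3, 2), bit 0 keeps the edges 0-1 and 1-2, while bit 1 keeps 0-1 and 1-3.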

To restate the problem: we’re given an unweighted tree and a vertex u of it. We’d like to find the longest path in this tree passing through u.

There are several ways to find this in linear time; here’s a neat one.

Let \Gamma be the tree we’re considering, and (x, y) be the endpoints of an optimal path.
Then there exists an optimal solution in which at least one of x and y is an endpoint of a diameter of the tree.

Proof

Let d_1 and d_2 be two endpoints of a diameter of \Gamma.
Recall a standard property of the diameter: for any vertex, some vertex farthest from it is an endpoint of the diameter.
In particular, the longest path from u will have either d_1 or d_2 as an endpoint. Without loss of generality, let it be d_1.

Let’s root the tree at u and see what happens.

  • If the path from d_1 to d_2 passes through u, clearly it’s the longest path passing through u and the claim is true.
  • Otherwise, there’s a child c of u whose subtree contains both d_1 and d_2.
  • Since the path from x to y passes through u, we can split it into the edge-disjoint paths x \to u and u \to y.
    • If x lies in the subtree of c, then y does not (otherwise the path would not pass through u). Replacing x with d_1 then gives a path that still passes through u and is no shorter, because d_1 is a farthest vertex from u.
    • If y lies in the subtree of c, we can similarly replace y with d_1 without making the path shorter.
    • Otherwise, x and y both lie outside the subtree of c, and replacing either one of them with d_1 gives a path through u that is no shorter.

In every case, we can make d_1 an endpoint of an optimal path, which proves the claim.
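The claim is easy to sanity-check by brute force on small trees. The sketch below uses hypothetical helper names (`allDistances`, `checkClaim`). It computes all-pairs distances with one BFS per vertex and relies on the fact that u lies on the x-y path exactly when dist(x, u) + dist(u, y) = dist(x, y). For every u, it then compares the true longest path through u with the longest path through u that has a diameter endpoint as one end.

```cpp
#include <bits/stdc++.h>
using namespace std;

// All-pairs distances (in edges) of an unweighted tree, via one BFS per vertex.
vector<vector<int>> allDistances(const vector<vector<int>>& adj) {
    int n = adj.size();
    vector<vector<int>> d(n, vector<int>(n, -1));
    for (int s = 0; s < n; ++s) {
        queue<int> q;
        q.push(s);
        d[s][s] = 0;
        while (!q.empty()) {
            int x = q.front(); q.pop();
            for (int y : adj[x])
                if (d[s][y] == -1) { d[s][y] = d[s][x] + 1; q.push(y); }
        }
    }
    return d;
}

// True if, for every vertex u, the longest path through u can be chosen
// to start at an endpoint of the diameter (d1, d2).
bool checkClaim(const vector<vector<int>>& adj) {
    int n = adj.size();
    vector<vector<int>> d = allDistances(adj);
    int d1 = 0, d2 = 0;                        // brute-force diameter endpoints
    for (int x = 0; x < n; ++x)
        for (int y = 0; y < n; ++y)
            if (d[x][y] > d[d1][d2]) d1 = x, d2 = y;
    for (int u = 0; u < n; ++u) {
        int best = 0, viaDiameter = 0;
        for (int x = 0; x < n; ++x)
            for (int y = 0; y < n; ++y) {
                if (d[x][u] + d[u][y] != d[x][y]) continue;   // u is not on the x-y path
                best = max(best, d[x][y]);
                if (x == d1 || x == d2) viaDiameter = max(viaDiameter, d[x][y]);
            }
        if (best != viaDiameter) return false;
    }
    return true;
}
```

This is only a cubic-time testing aid, not part of the intended solution.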

Now that we have this, a solution in \mathcal{O}(N) is fairly simple:

  • Find a diameter of \Gamma in \mathcal{O}(N).
    There are a few different ways to do this: for example, two runs of DFS/BFS as mentioned in this article, or DP as mentioned here.
  • Let d_1 and d_2 be the endpoints of the diameter.
  • Root \Gamma at d_1 and compute distances to all nodes. Let this array be \text{dist}_1.
    Let mx_1[u] be the maximum of \text{dist}_1[v] over all vertices v in the subtree of u. This is exactly the length of the longest path that starts at d_1 and passes through u, and it can be computed for all u with a simple subtree DP once \text{dist}_1 is known.
  • Similarly, root \Gamma at d_2 to compute \text{dist}_2 and mx_2.
  • The answer for u is then just \max(mx_1[u], mx_2[u]).
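The steps above can be sketched in C++ for a single connected, unweighted tree given as 0-indexed adjacency lists. The function names are hypothetical, and this is one possible implementation rather than either reference solution. It uses BFS instead of recursion: processing vertices in reverse BFS order handles every child before its parent, which is all the subtree maximum needs.

```cpp
#include <bits/stdc++.h>
using namespace std;

// BFS from root: fills dist and par, returns the visiting order.
vector<int> bfsOrder(const vector<vector<int>>& adj, int root,
                     vector<int>& dist, vector<int>& par) {
    int n = adj.size();
    dist.assign(n, -1);
    par.assign(n, -1);
    vector<int> order{root};
    dist[root] = 0;
    for (size_t i = 0; i < order.size(); ++i) {
        int x = order[i];
        for (int y : adj[x])
            if (dist[y] == -1) {
                dist[y] = dist[x] + 1;
                par[y] = x;
                order.push_back(y);
            }
    }
    return order;
}

// mx[u] = max dist[v] over v in the subtree of u (tree rooted at order[0]).
vector<int> subtreeMax(const vector<int>& order, const vector<int>& dist,
                       const vector<int>& par) {
    vector<int> mx(dist);
    for (size_t i = order.size(); i-- > 1;)   // children before parents
        mx[par[order[i]]] = max(mx[par[order[i]]], mx[order[i]]);
    return mx;
}

// ans[u] = length of the longest path in the tree that passes through u.
vector<int> longestPathThrough(const vector<vector<int>>& adj) {
    vector<int> dist0, par0, dist1, par1, dist2, par2;
    // Vertex of `order` with the largest distance in d.
    auto far = [](const vector<int>& order, const vector<int>& d) {
        int best = order[0];
        for (int v : order) if (d[v] > d[best]) best = v;
        return best;
    };
    vector<int> ord0 = bfsOrder(adj, 0, dist0, par0);
    int d1 = far(ord0, dist0);                         // one diameter endpoint
    vector<int> ord1 = bfsOrder(adj, d1, dist1, par1);
    int d2 = far(ord1, dist1);                         // the other endpoint
    vector<int> ord2 = bfsOrder(adj, d2, dist2, par2);
    vector<int> mx1 = subtreeMax(ord1, dist1, par1);   // rooted at d1
    vector<int> mx2 = subtreeMax(ord2, dist2, par2);   // rooted at d2
    vector<int> ans(adj.size());
    for (size_t u = 0; u < adj.size(); ++u) ans[u] = max(mx1[u], mx2[u]);
    return ans;
}
```

On the path 0-1-2-3-4 with an extra leaf 5 attached to 2, every vertex on the path gets 4 (the diameter), while vertex 5 gets 3.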

Repeating this linear algorithm for each of the 30 bits gives an \mathcal{O}(30\cdot N) solution for a single fixed u.


Running this for every query is too slow.
However, it is not hard to optimize.

Notice that for a fixed bit b, the forest we end up with is independent of u.
This means that if we compute the diameter and the mx_1 and mx_2 arrays once for each tree in this forest, the answer for every u in that tree is \max(mx_1[u], mx_2[u]); we don’t need to recompute them per query!

So we can fix a bit b and compute the answer restricted to that bit for every vertex in \mathcal{O}(N) time. Doing this for all 30 bits and taking the maximum precomputes the answer for all vertices in \mathcal{O}(30\cdot N) time.

After this precomputation, answering queries is \mathcal{O}(1) each: just look up the answer.
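Putting the pieces together, here is a sketch of the precomputation. It assumes 0-indexed vertices, weights below 2^{30}, and an edge list of (u, v, w) triples; `Edge` and `precompute` are hypothetical names. For each bit, it handles every tree of the forest separately with the two-BFS approach and keeps, for each vertex, the best length over all bits. It returns only lengths. Recovering the endpoints, as both reference codes do, would also require storing the vertex that attains each subtree maximum.

```cpp
#include <bits/stdc++.h>
using namespace std;

// Hypothetical edge type: endpoints u, v (0-indexed) and weight w (assumed < 2^30).
struct Edge { int u, v, w; };

// best[x] = length (in edges) of the longest path through x whose edge weights
// all share at least one set bit; 0 means no such path exists.
vector<int> precompute(int n, const vector<Edge>& edges) {
    vector<int> best(n, 0), dist(n, -1), par(n, -1), mx(n, 0);
    for (int b = 0; b < 30; ++b) {
        // Forest of edges whose weight has bit b set.
        vector<vector<int>> adj(n);
        for (const Edge& e : edges)
            if ((e.w >> b) & 1) { adj[e.u].push_back(e.v); adj[e.v].push_back(e.u); }

        // BFS inside one tree of the forest; fills dist/par and returns the visit order.
        auto bfs = [&](int root) {
            vector<int> order{root};
            dist[root] = 0;
            par[root] = -1;
            for (size_t i = 0; i < order.size(); ++i) {
                int x = order[i];
                for (int y : adj[x])
                    if (y != par[x]) {
                        dist[y] = dist[x] + 1;
                        par[y] = x;
                        order.push_back(y);
                    }
            }
            return order;
        };
        // Vertex of `order` with the largest dist (valid right after the bfs that produced it).
        auto far = [&](const vector<int>& order) {
            int f = order[0];
            for (int v : order) if (dist[v] > dist[f]) f = v;
            return f;
        };

        vector<bool> done(n, false);
        for (int s = 0; s < n; ++s) {
            if (done[s]) continue;
            vector<int> comp = bfs(s);
            for (int v : comp) done[v] = true;
            int d1 = far(comp);            // one diameter endpoint of this tree
            int d2 = far(bfs(d1));         // the other endpoint
            for (int root : {d1, d2}) {
                vector<int> order = bfs(root);
                for (int v : order) mx[v] = dist[v];
                for (size_t i = order.size(); i-- > 1;)   // children before parents
                    mx[par[order[i]]] = max(mx[par[order[i]]], mx[order[i]]);
                for (int v : order) best[v] = max(best[v], mx[v]);
            }
        }
    }
    return best;
}
```

A query for vertex x then just reads `best[x]`. A value of 0 means no valid pair exists, which is the case where both reference codes print `-1 -1`.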


There are other ways to solve this problem, for example dp with rerooting.
However, the general idea remains the same: precompute the answer for all vertices in \mathcal{O}(30\cdot N) or similar, and use that to answer queries.
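For comparison, here is a sketch of a generic rerooting DP for a single connected, unweighted, 0-indexed tree. The name `rerootLongestThrough` is hypothetical, and this code is not taken from the reference solutions. down[u] is the longest path going down into u's subtree, and up[u] is the longest path leaving u through its parent. The longest path through u is the sum of the two largest "arms" at u. In the full problem, it would be run on every tree of each bit's forest, just like the diameter approach.

```cpp
#include <bits/stdc++.h>
using namespace std;

// ans[u] = length of the longest path through u, via rerooting from root 0.
vector<int> rerootLongestThrough(const vector<vector<int>>& adj) {
    int n = adj.size();
    vector<int> par(n, -1), order{0}, down(n, 0), up(n, 0), ans(n, 0);
    vector<bool> seen(n, false);
    seen[0] = true;
    for (size_t i = 0; i < order.size(); ++i)          // BFS order from root 0
        for (int y : adj[order[i]])
            if (!seen[y]) { seen[y] = true; par[y] = order[i]; order.push_back(y); }

    for (int i = n - 1; i > 0; --i) {                 // children before parents
        int x = order[i];
        down[par[x]] = max(down[par[x]], down[x] + 1);
    }

    for (int x : order) {                             // parents before children
        // Two largest arms at x: the parent arm up[x] and child arms down[c] + 1.
        int best1 = up[x], best2 = 0;
        for (int c : adj[x]) {
            if (c == par[x]) continue;
            int arm = down[c] + 1;
            if (arm > best1) { best2 = best1; best1 = arm; }
            else if (arm > best2) best2 = arm;
        }
        ans[x] = best1 + best2;
        // up[c] = 1 + longest arm at x that does not go into c.
        for (int c : adj[x]) {
            if (c == par[x]) continue;
            int other = (down[c] + 1 == best1) ? best2 : best1;
            up[c] = other + 1;
        }
    }
    return ans;
}
```

Ties are handled by the `best1`/`best2` pair. If two arms share the maximum length, `best2` equals `best1`, so excluding either one still leaves the correct value.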

TIME COMPLEXITY

\mathcal{O}(30\cdot N + Q) per test case.

CODE:

Preparer's code (C++)
// library link: https://github.com/manan-grover/My-CP-Library/blob/main/library.cpp
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;
#define asc(i,a,n) for(I i=a;i<n;i++)
#define dsc(i,a,n) for(I i=n-1;i>=a;i--)
#define forw(it,x) for(A it=(x).begin();it!=(x).end();it++)
#define bacw(it,x) for(A it=(x).rbegin();it!=(x).rend();it++)
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define lb(x) lower_bound(x)
#define ub(x) upper_bound(x)
#define fbo(x) find_by_order(x)
#define ook(x) order_of_key(x)
#define all(x) (x).begin(),(x).end()
#define sz(x) (I)((x).size())
#define clr(x) (x).clear()
#define U unsigned
#define I int
#define S string
#define C char
#define D long double
#define A auto
#define B bool
#define CM(x) complex<x>
#define V(x) vector<x>
#define P(x,y) pair<x,y>
#define OS(x) set<x>
#define US(x) unordered_set<x>
#define OMS(x) multiset<x>
#define UMS(x) unordered_multiset<x>
#define OM(x,y) map<x,y>
#define UM(x,y) unordered_map<x,y>
#define OMM(x,y) multimap<x,y>
#define UMM(x,y) unordered_multimap<x,y>
#define BS(x) bitset<x>
#define L(x) list<x>
#define Q(x) queue<x>
#define PBS(x) tree<x,null_type,less<I>,rb_tree_tag,tree_order_statistics_node_update>
#define PBM(x,y) tree<x,y,less<I>,rb_tree_tag,tree_order_statistics_node_update>
#define pi (D)acos(-1)
#define md 1000000007
#define rnd randGen(rng)
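// comp: insert arm vv (going through neighbour y) into v, which keeps the two longest arms
// at a vertex, each stored as {{length, far endpoint}, neighbour the arm goes through}.
inline void comp(P(P(I,I),I) &vv,P(P(I,I),I) v[2],I y){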
  if(vv.fi.fi>v[0].fi.fi){
    v[1]=v[0];
    v[0]=vv;
    v[0].se=y;
  }else if(vv.fi.fi>v[1].fi.fi){
    v[1]=vv;
    v[1].se=y;
  }
}
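// dfs0: bottom-up pass for bit `temp`; fills v[x] with the two longest arms going down into x's subtree.
// An edge without the bit contributes an arm of length 0 that ends at x itself.
void dfs0(I x,I pr,V(P(I,I)) tr[],P(P(I,I),I) v[][2],I temp){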
  P(P(I,I),I) vv;
  asc(i,0,sz(tr[x])){
    I y=tr[x][i].fi;
    I w=tr[x][i].se;
    if(y!=pr){
      dfs0(y,x,tr,v,temp);
      if(w&temp){
        if(v[y][0].fi.fi!=-1){
          vv=v[y][0];
          vv.fi.fi++;
        }else{
          vv={{1,y},y};
        }
      }else{
        vv={{0,x},y};
      }
      comp(vv,v[x],y);
    }
  }
}
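// dfs1: top-down (rerooting) pass; gives each child y the best arm of its parent x that does not
// go through y, so afterwards v[x] holds the two longest arms at x in every direction.
void dfs1(I x,I pr,V(P(I,I)) tr[],P(P(I,I),I) v[][2],I temp){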
  P(P(I,I),I) vv;
  if(x==1){
    vv={{0,1},0};
    comp(vv,v[1],0);
  }
  asc(i,0,sz(tr[x])){
    I y=tr[x][i].fi;
    I w=tr[x][i].se;
    if(y!=pr){
      if(w&temp){
        if(v[x][0].se!=y){
          vv=v[x][0];
        }else{
          vv=v[x][1];
        }
        vv.fi.fi++;
      }else{
        vv={{0,y},x};
      }
      comp(vv,v[y],x);
      dfs1(y,x,tr,v,temp);
    }
  }
}
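// cal: for bit `temp`, dp[i] = {length, start, end} of the longest path through i using only edges with that bit.
void cal(V(I) dp[],V(P(I,I)) tr[],I n,I temp){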
  P(P(I,I),I) v[n+1][2]; //{len, dest, dir};
  asc(i,1,n+1){
    v[i][0]=v[i][1]={{-1,-1},-1};
  }
  dfs0(1,0,tr,v,temp);
  dfs1(1,0,tr,v,temp);
  asc(i,1,n+1){
    /*if(sz(v[i][0])){
      cout<<i<<" "<<v[i][0][0]<<"\n";
    }*/
    if(v[i][1].fi.fi==-1){
      dp[i]={v[i][0].fi.fi,i,v[i][0].fi.se}; //{len, start, end};
    }else{
      dp[i]={v[i][0].fi.fi+v[i][1].fi.fi,v[i][0].fi.se,v[i][1].fi.se};
    }
  }
}
int main(){
  mt19937_64 rng(chrono::steady_clock::now().time_since_epoch().count());
  uniform_int_distribution<I> randGen;
  ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL);
  #ifndef ONLINE_JUDGE
  freopen("input.txt", "r", stdin);
  freopen("output.txt", "w", stdout);
  #endif
  I t;
  cin>>t;
  while(t--){
    I n;
    cin>>n;
    V(P(I,I)) tr[n+1];
    asc(i,0,n-1){
      I u,v,w;
      cin>>u>>v>>w;
      tr[u].pb({v,w});
      tr[v].pb({u,w});
    }
    V(I) dp[30][n+1];
    I temp=1;
    asc(i,0,30){
      cal(dp[i],tr,n,temp);
      temp*=2;
    }
    I q;
    cin>>q;
    while(q--){
      I x;
      cin>>x;
      I res=-1;
      I u,v;
      asc(i,0,30){
        if(dp[i][x][0]>res){
          res=dp[i][x][0];
          u=dp[i][x][1];
          v=dp[i][x][2];
        }
      }
      //cout<<res<<"\n";
      if(res){
        cout<<u<<" "<<v<<"\n";
      }else{
        cout<<-1<<" "<<-1<<"\n";
      }
    }
  }
  return 0;
}
Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
	ios::sync_with_stdio(false); cin.tie(0);

	int t; cin >> t;
	while (t--) {
		int n; cin >> n;

		vector adj(n, basic_string<array<int, 2>>());
		for (int i = 0; i < n-1; ++i) {
			int u, v, w; cin >> u >> v >> w;
			--u, --v;
			adj[u].push_back({v, w});
			adj[v].push_back({u, w});
		}

		vector<array<int, 3>> ans(n, array{0, -2, -2});

		vector<int> dist1(n, n+5), dist2(n, n+5), dist3(n, n+5), mark(n);
		vector<int> sub1(n), sub2(n), sub3(n);
		vector<int> root1(n), root2(n);
		
		for (int bit = 0; bit < 30; ++bit) {
		    dist1.assign(n, n+5); dist2.assign(n, n+5); dist3.assign(n, n+5);
		    mark.assign(n, 0);
		    sub1.assign(n, 0); sub2.assign(n, 0); sub3.assign(n, 0);
		    root1.assign(n, 0); root2.assign(n, 0);

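			// Iterative DFS from src over edges that have `bit` set. Fills dist[] (depth from src) and
			// sub[u] (deepest vertex in u's subtree), and returns the vertex farthest from src.
			// flag = 1, 2, 3 tells the three passes apart; passes 2 and 3 also record src in root1/root2.
			auto dfs = [&] (int src, auto &dist, auto &sub, int flag) {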
				stack<int> st; st.push(src);
				dist[src] = 0;

				while (!st.empty()) {
					int u = st.top();
					if (flag == 2) root1[u] = src;
					if (flag == 3) root2[u] = src;
					for (auto &[v, w] : adj[u]) {
						if (dist[u] > dist[v]) continue;
						if (~w & (1 << bit)) continue;

						
						if (mark[u] < flag) st.push(v), dist[v] = 1 + dist[u];
						else if (dist[sub[v]] > dist[sub[u]]) sub[u] = sub[v];
					}
					if (mark[u] < flag) mark[u] = flag, sub[u] = u;
					else st.pop();
				}
				return sub[src];
			};
			
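			// For each tree of this bit's forest: u = farthest vertex from i (a diameter endpoint),
			// v = farthest vertex from u (the other endpoint), then root the tree at v as well.
			for (int i = 0; i < n; ++i) {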
				if (dist1[i] < n+5) continue;
				int u = dfs(i, dist1, sub1, 1);
				int v = dfs(u, dist2, sub2, 2);
				dfs(v, dist3, sub3, 3);
			}

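			// sub2[i] (resp. sub3[i]) is the deepest vertex in i's subtree when rooted at u (resp. v),
			// so u..sub2[i] and v..sub3[i] are the longest paths through i starting at each endpoint.
			for (int i = 0; i < n; ++i) {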
				if (dist2[sub2[i]] > ans[i][0]) ans[i] = {dist2[sub2[i]], sub2[i], root1[i]};
				if (dist3[sub3[i]] > ans[i][0]) ans[i] = {dist3[sub3[i]], sub3[i], root2[i]};
			}
		}

		int q; cin >> q;

		while (q--) {
			int x; cin >> x;
			--x;
			cout << ans[x][1]+1 << ' ' << ans[x][2]+1 << '\n';
		}
	}
}