MIN_OR_ST - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: wuhudsm
Tester: jay_1048576
Editorialist: iceknight1093

DIFFICULTY:

2992

PREREQUISITES:

DFS/DSU, Observation

PROBLEM:

You’re given a connected undirected weighted graph with N vertices and M edges.
Answer Q independent queries on it:

  • Add edge (u, v) to the graph with weight 0. Then, compute the weight of the minimum spanning tree of the resulting graph, where the weight of a spanning tree is the bitwise OR of the weights of its edges.

EXPLANATION:

Let Y be the value of the initial OR-MST weight.
Let’s first see how to compute Y.

We’d like to minimize the bitwise OR of the chosen weights, so it’s optimal to iterate across bits in decreasing order and see if we can avoid taking the current bit, even at the cost of all lower bits.

That leads us to the following algorithm:

  • We’ll iterate across bits in descending order from 29 down to 0.
    Let \text{ans} be the current answer; initially zero.
    Let E be the current active subset of edges; initially this equals all the edges.
  • For a fixed bit b, remove any edge of E that has the b'th bit set. Let E' be the remaining set of edges.
    • if E' is a spanning set of edges for all N vertices, \text{ans} needn’t have bit b set.
      In this case, replace E with E' and continue on.
    • Otherwise, \text{ans} must have the b'th bit set. So, increase \text{ans} by 2^b, but continue on to the next iteration without deleting any edges from E.

Checking whether a set of edges spans all vertices can be done in linear time using DFS/BFS, or in almost linear time using DSU.
We run this linear method once for each bit, which is good enough.
At the end of it all, \text{ans} is the answer we’re looking for.
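
The procedure above can be sketched roughly as follows. This is a minimal illustration and not the reference solution: the names `Edge`, `Dsu`, `count_components` and `or_mst` are hypothetical, vertices are assumed to be 0-indexed, and weights are assumed to be below 2^30. Instead of physically deleting edges, the sketch keeps a mask `banned` of bits that the active edge set must avoid, which is equivalent to maintaining E.

```cpp
#include <array>
#include <cassert>
#include <numeric>
#include <vector>

// Hypothetical representation: an edge is {u, v, w} with 0-indexed endpoints.
using Edge = std::array<int, 3>;

struct Dsu {
    std::vector<int> parent;
    explicit Dsu(int n) : parent(n) { std::iota(parent.begin(), parent.end(), 0); }
    int find(int x) {
        while (parent[x] != x) x = parent[x] = parent[parent[x]];  // path halving
        return x;
    }
    bool merge(int a, int b) {
        a = find(a);
        b = find(b);
        if (a == b) return false;
        parent[a] = b;
        return true;
    }
};

// Number of connected components when only edges with no bit of `banned` are used.
int count_components(int n, const std::vector<Edge>& edges, int banned) {
    Dsu dsu(n);
    int comps = n;
    for (const auto& e : edges)
        if ((e[2] & banned) == 0 && dsu.merge(e[0], e[1])) --comps;
    return comps;
}

// Minimum bitwise OR over all spanning trees (assumes a connected graph).
int or_mst(int n, const std::vector<Edge>& edges) {
    assert(count_components(n, edges, 0) == 1);
    int ans = 0, banned = 0;
    for (int b = 29; b >= 0; --b) {
        // Try to avoid bit b on top of everything already avoided.
        if (count_components(n, edges, banned | (1 << b)) == 1) banned |= 1 << b;
        else ans |= 1 << b;  // bit b is unavoidable; the active edge set is unchanged
    }
    return ans;
}
```

For instance, on a triangle with weights 4, 4 and 3, bit 2 cannot be avoided but bits 1 and 0 can, so `or_mst` returns 4.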

Now we need to deal with updates.
Of course, running the above algorithm afresh for each new edge would be way too slow.

Instead, let’s analyze how a new edge (u, v, 0) would change the answer.
For each bit b from 29 down to 0:

  • If b isn’t set in \text{ans}, then we don’t need to use the new edge anyway.
  • If b is set, then we’d like to check if it can improve our answer.
    • Recall that b is set in the answer only when we couldn’t connect the graph without using an edge where b was set; i.e., the remaining edges divided the graph into \gt 1 connected components.
    • If there are \geq 3 connected components, a single edge can’t help us anyway so there’s no change.
    • If there are exactly two connected components, and u and v are in different components, then this edge does help us reduce the total weight!
      Clearly, when we can do this it’s optimal to do so.

This gives us an algorithm:

  • Find the highest bit b such that the graph at this step of our initial algorithm had exactly two connected components, and u and v are in different components.
  • At this stage, we’re definitely using the edge (u, v, 0) to connect these components. So, add it to the current edges (recall that we know which edges have been discarded already) and recompute the OR-MST.

Of course, running the entire OR-MST algorithm for each query is still too slow.
However, here is where we’ll use the fact that queries are independent.

Notice that if we fix the bit b for which we’re using the edge (u, v, 0), the answer is in fact independent of what the actual values of u and v are.
The only thing that matters is that our initial OR-MST algorithm gave us exactly two components, and we compute the OR-MSTs of the union of these two components, along with some 0-weight edge joining them.

So, we only need to run the OR-MST algorithm at most once for each bit, making at most 31 runs in total including the initial one.
The results of these runs can be cached so that each query can be quickly answered later.
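
To sanity-check the independence claim on small inputs, a slow reference that simply adds the query edge and reruns the greedy from scratch can be handy. The sketch below is only a testing aid under the same assumptions as the text (weights below 2^30); `Edge`, `spans` and `brute_query` are hypothetical names, and vertices are 0-indexed.

```cpp
#include <array>
#include <cassert>
#include <functional>
#include <numeric>
#include <vector>

using Edge = std::array<int, 3>;  // {u, v, w}, 0-indexed vertices (illustrative)

// True if the edges that avoid every bit of `banned` connect all n vertices.
bool spans(int n, const std::vector<Edge>& edges, int banned) {
    std::vector<int> parent(n);
    std::iota(parent.begin(), parent.end(), 0);
    std::function<int(int)> find = [&](int x) {
        return parent[x] == x ? x : parent[x] = find(parent[x]);
    };
    int comps = n;
    for (const auto& e : edges) {
        if (e[2] & banned) continue;
        int a = find(e[0]), b = find(e[1]);
        if (a != b) { parent[a] = b; --comps; }
    }
    return comps == 1;
}

// Slow reference: add (u, v, 0) and rerun the bit-by-bit greedy from scratch.
int brute_query(int n, std::vector<Edge> edges, int u, int v) {
    edges.push_back({u, v, 0});
    assert(spans(n, edges, 0));
    int banned = 0;
    for (int b = 29; b >= 0; --b)
        if (spans(n, edges, banned | (1 << b))) banned |= 1 << b;
    return ((1 << 30) - 1) & ~banned;  // the answer consists of the bits we failed to ban
}
```

On the path 0 - 1 - 2 with weights 4 and 1, the queries (0, 2) and (0, 1) both separate the two components seen at bit 2, and both give 1. The query (1, 2) only helps at bit 0 and gives 4.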


Putting everything together, our final solution is as follows:

  • Run the OR-MST algorithm described initially.
  • For each bit such that the resulting graph had exactly two components, run the OR-MST algorithm on the union of these two components (along with some 0-weight edge joining them) as well and store their answers.
  • Then, for each query (u, v):
    • Find the highest bit b such that the graph for this bit has two components; and u and v are in different components.
    • The answer for this query is then simply the precomputed answer for this bit.
    • If no such bit exists, the answer is simply \text{ans}, the initial OR-MST value.

This requires us to quickly check if two vertices lie in the same component, but we can precompute components for each vertex corresponding to each bit, so this can be answered in \mathcal{O}(1) per bit as well.
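
The per-query lookup could look roughly like this. It is a sketch under assumptions, not the reference code: `comp[b]` is assumed to hold a component id per vertex when bit b produced exactly two components and to be empty otherwise, `cached[b]` is assumed to hold the precomputed OR-MST for that bit, and the names `answer_query`, `comp` and `cached` are hypothetical.

```cpp
#include <cassert>
#include <vector>

// Returns the answer for the query (u, v) from precomputed tables.
// comp[b]   : component id of each vertex at bit b, or empty if bit b did not
//             leave exactly two components (assumed layout).
// cached[b] : precomputed OR-MST when a 0-weight edge joins those two components.
int answer_query(const std::vector<std::vector<int>>& comp,
                 const std::vector<int>& cached,
                 int initial_ans, int u, int v) {
    assert(comp.size() == cached.size());
    for (int b = static_cast<int>(comp.size()) - 1; b >= 0; --b) {
        if (comp[b].empty()) continue;                   // bit b is not a candidate
        if (comp[b][u] != comp[b][v]) return cached[b];  // highest useful bit found
    }
    return initial_ans;  // the new edge never helps
}
```

Each query scans the bits once from high to low, which matches the B \cdot Q term in the complexity.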

TIME COMPLEXITY:

\mathcal{O}(B^2 \cdot (N + M) + B\cdot Q) per test case, where B = 30 for this problem.

CODE:

Author's code (C++)
#include <map>
#include <set>
#include <cmath>
#include <ctime>
#include <queue>
#include <stack>
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <cstring>
#include <algorithm>
#include <iostream>
using namespace std;
typedef double db; 
typedef long long ll;
typedef unsigned long long ull;
const int N=200010;
const int LOGN=28;
const ll  TMD=0;
const ll  INF=2147483647;
int T,n,m,q;
int ty[N],tag[N],ans[N];

struct Edge
{
	int u,v,w;
	
	Edge() {}
	
	Edge(int u,int v,int w):u(u),v(v),w(w) {}
};
vector<Edge> E,Q;

int f[N];
int find(int x)
{
	return x==f[x]?x:f[x]=find(f[x]);
}
	
void uni(int x,int y)
{
	f[find(x)]=find(y);
}

int cal_component(vector<Edge> &e)
{
	int cnt=0;
	for(int i=1;i<=n;i++) tag[i]=0;
	for(int i=1;i<=n;i++) f[i]=i;
	for(int i=0;i<e.size();i++) uni(e[i].u,e[i].v);
	for(int i=1;i<=n;i++) 
	    if(!tag[find(i)]) tag[find(i)]=1,cnt++;
    return cnt;
}

int cal_answer(vector<Edge> e)
{
	int ans=0;
	for(int i=29;i>=0;i--)
	{
		vector<Edge> tmp;
		for(int j=0;j<e.size();j++)
		    if(!(e[j].w&(1<<i))) tmp.push_back(e[j]);
	    if(cal_component(tmp)!=1) ans^=(1<<i);
		else e=tmp;	    
	}
	return ans;
}

int main()
{
	scanf("%d%d%d",&n,&m,&q);
	for(int i=1;i<=m;i++)
	{
		int u,v,w;
		scanf("%d%d%d",&u,&v,&w);
		E.push_back(Edge(u,v,w));
	}
	for(int i=1;i<=q;i++)
	{
		int u,v;
		scanf("%d%d",&u,&v);
		Q.push_back(Edge(u,v,0));
	}
	for(int i=29;i>=0;i--)
	{
		vector<Edge> tmp;
		for(int j=0;j<E.size();j++)
		    if(!(E[j].w&(1<<i))) tmp.push_back(E[j]);
	    if(cal_component(tmp)==1) E=tmp;	    
	}
	for(int i=29;i>=0;i--)
	{
		int c;
		vector<Edge> tmp;
		for(int j=0;j<E.size();j++)
		    if(!(E[j].w&(1<<i))) tmp.push_back(E[j]);
        c=cal_component(tmp);
	    if(c==1||c>2) continue;
		else
		{
			for(int j=0;j<Q.size();j++)
				if((!ty[j])&&find(Q[j].u)!=find(Q[j].v)) ty[j]=i+1;
		} 	    
	}
	ans[0]=cal_answer(E);
	for(int i=1;i<=30;i++) ans[i]=-1;
	for(int i=0;i<Q.size();i++)
	{
	    if(ans[ty[i]]==-1)
	    {
			E.push_back(Q[i]);
			ans[ty[i]]=cal_answer(E);
			E.pop_back();
		}
		printf("%d\n",ans[ty[i]]);
	}
		
	return 0;
}
Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

struct DSU {
private:
	std::vector<int> parent_or_size;
public:
	DSU(int n = 1): parent_or_size(n, -1) {}
	int get_root(int u) {
		if (parent_or_size[u] < 0) return u;
		return parent_or_size[u] = get_root(parent_or_size[u]);
	}
	int size(int u) { return -parent_or_size[get_root(u)]; }
	bool same_set(int u, int v) {return get_root(u) == get_root(v); }
	bool merge(int u, int v) {
		u = get_root(u), v = get_root(v);
		if (u == v) return false;
		if (parent_or_size[u] > parent_or_size[v]) std::swap(u, v);
		parent_or_size[u] += parent_or_size[v];
		parent_or_size[v] = u;
		return true;
	}
	std::vector<std::vector<int>> group_up() {
		int n = parent_or_size.size();
		std::vector<std::vector<int>> groups(n);
		for (int i = 0; i < n; ++i) {
			groups[get_root(i)].push_back(i);
		}
		groups.erase(std::remove_if(groups.begin(), groups.end(), [&](auto &s) { return s.empty(); }), groups.end());
		return groups;
	}
};

int main()
{
	ios::sync_with_stdio(false); cin.tie(0);

	int n, m, q; cin >> n >> m >> q;
	vector<array<int, 3>> edges(m);
	for (auto &[u, v, w] : edges) {
		cin >> u >> v >> w;
		--u, --v;
	}

	const int bits = 30;
	vector comp(bits, vector(n, -1));
	vector<int> mst(bits);

	auto or_mst = [&] (bool type = 1, int curb = -1) {
		int ans = 0, bad = 0;
		DSU dsu(n);
		for (int bit = bits-1; bit >= 0; --bit) {
			dsu = DSU(n);
			bad += 1 << bit;
			int comps = n;
			
			for (auto &[u, v, w] : edges) {
				if (type == 0) {
					if (w != 0 and comp[curb][u] != comp[curb][v]) continue;
				}
				if (w & bad) continue;
				comps -= dsu.merge(u, v);
			}
			if (comps == 1) continue;
			bad -= 1 << bit;
			ans += 1 << bit;

			if (type == 1 and comps == 2) {
				int id = 0;
				for (const auto &conn : dsu.group_up()) {
					for (int v : conn) comp[bit][v] = id;
					++id;
				}
			}
		}
		return ans;
	};
	vector<int> all_vertices(n); iota(begin(all_vertices), end(all_vertices), 0);

	int initial_ans = or_mst();
	for (int i = 0; i < bits; ++i) {
		if (comp[i][0] == -1) continue;
		int u = 0, v = 1;
		while (comp[i][u] == comp[i][v]) ++v;
		edges.push_back({u, v, 0});
		mst[i] = or_mst(0, i);
		edges.pop_back();
	}
	
	while (q--) {
		int u, v; cin >> u >> v; --u, --v;
		int ans = initial_ans;
		for (int bit = bits-1; bit >= 0; --bit) {
			if (comp[bit][u] == comp[bit][v]) continue;
			ans = mst[bit];
			break;
		}
		cout << ans << '\n';
	}
}