PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author:
Tester: sushil2006
Editorialist: iceknight1093
DIFFICULTY:
Medium
PREREQUISITES:
Sorting, DSU
PROBLEM:
There’s a graph with N vertices and M edges. Vertex i has a pass value A_i and contains a monster with health B_i.
Suppose your health is H, and you’re standing on vertex u. You can do the following:
- If H \geq B_u, kill the monster at vertex u. This changes H \to H - B_u; or
- For any vertex v adjacent to u such that H \geq \max(A_u, A_v), move to vertex v.
You’re allowed to start at any vertex of your choice. Find the minimum possible initial value of your health H such that all monsters can be killed.
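The rules are easy to get subtly wrong, so a tiny brute force is handy for checking answers by hand. The sketch below is in Python, like the editorialist's solution. It assumes 0-indexed vertices, an edge list of pairs and a connected graph, and it is only meant for very small N. `can_clear` and `min_health_bruteforce` are hypothetical helper names. It explores states (current vertex, set of killed monsters), and the current health is always the starting health minus the health of the monsters killed so far.

```python
from collections import deque


def can_clear(n, A, B, edges, H):
    """True if starting with health H at some vertex can kill every monster.

    Exhaustive search over states (vertex, bitmask of killed monsters); the
    current health is H minus the total health of the killed monsters.
    """
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    full = (1 << n) - 1
    seen = set()  # shared across starts: a state's future does not depend on the start
    for start in range(n):
        if (start, 0) in seen:
            continue
        seen.add((start, 0))
        queue = deque([(start, 0)])
        while queue:
            u, mask = queue.popleft()
            if mask == full:
                return True
            cur = H - sum(B[i] for i in range(n) if mask >> i & 1)
            moves = []
            if not mask >> u & 1 and cur >= B[u]:
                moves.append((u, mask | 1 << u))  # kill the monster at u
            for v in adj[u]:
                if cur >= max(A[u], A[v]):
                    moves.append((v, mask))  # cross the edge (u, v)
            for state in moves:
                if state not in seen:
                    seen.add(state)
                    queue.append(state)
    return False


def min_health_bruteforce(n, A, B, edges):
    """Smallest H for which can_clear succeeds, found by binary search on H."""
    lo, hi = 0, sum(B) + max(A)  # hi always works on a connected graph
    while lo < hi:
        mid = (lo + hi) // 2
        if can_clear(n, A, B, edges, mid):
            hi = mid
        else:
            lo = mid + 1
    return lo


print(min_health_bruteforce(2, [1, 5], [3, 2], [(0, 1)]))  # prints 7
```

Feasibility is monotone in the starting health (extra health never hurts), which is what justifies the binary search. The upper bound \sum B_i + \max A_i is always enough on a connected graph, because health stays at least \max A_i until the last monster dies.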
EXPLANATION:
It helps to think of each edge (u, v) as having a weight of \max(A_u, A_v): this is the minimum health needed to cross it, in either direction.
Suppose we fix our starting health, H.
As monsters are killed, H only decreases; nothing ever increases it.
Since we can only cross an edge if H is at least its weight, killing monsters gradually makes more and more edges impassable: once H falls below the weight of an edge, that edge can never be used again.
So, as monsters are killed, we essentially delete edges from the graph.
Clearly, edges will become impassable in descending order of their weights.
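To see these deletions concretely, the hypothetical helper below lists the edges that are still crossable at a given current health, using the weight \max(A_u, A_v). Evaluating it for decreasing health shows the usable set only ever shrinking, with the heaviest edges dropping out first.

```python
def usable_edges(A, edges, health):
    """Edges (u, v) that can still be crossed with the given current health,
    i.e. those whose weight max(A[u], A[v]) is at most `health`."""
    return [(u, v) for u, v in edges if max(A[u], A[v]) <= health]


A = [1, 2, 3]
edges = [(0, 1), (1, 2), (0, 2)]
# As health drops, the usable set only shrinks; heavier edges disappear first.
for health in (3, 2, 1):
    print(health, usable_edges(A, edges, health))
```

The loop prints all three edges at health 3, only (0, 1) at health 2, and no edges at health 1.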
Let’s analyze what changes when edge (u, v) is deleted.
There are two possibilities:
- After deleting (u, v), it’s still possible to reach u from v.
  - In this case nothing changes: every monster that was previously reachable remains reachable.
- After deleting (u, v), it’s no longer possible to reach u from v, i.e. this edge is a bridge at the time of its deletion.
  - This splits the component containing (u, v) into exactly two components. Since we must defeat every monster, by the time this happens we must already have defeated every monster in one of the two components and be standing in the other; otherwise the monsters left behind can never be reached again.
A common trick when a graph only ever loses edges is to reverse the process, so that we only have to deal with edge insertions instead.
This is useful because edge insertions are handled efficiently by a powerful data structure: the DSU (disjoint set union).
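For reference, here is a minimal DSU in Python with path compression and union by size. It is a generic sketch, not copied from any of the solutions below. The class name `DSU` and its `union` return value are choices made for this illustration.

```python
class DSU:
    """Disjoint set union with path compression and union by size."""

    def __init__(self, n):
        self.parent = list(range(n))
        self.size = [1] * n

    def find(self, x):
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        while self.parent[x] != root:  # path compression
            self.parent[x], x = root, self.parent[x]
        return root

    def union(self, a, b):
        """Merge the sets containing a and b; return False if already merged."""
        a, b = self.find(a), self.find(b)
        if a == b:
            return False
        if self.size[a] < self.size[b]:
            a, b = b, a
        self.parent[b] = a  # attach the smaller tree under the larger one
        self.size[a] += self.size[b]
        return True
```

Each `find` and `union` runs in near-constant amortized time, which is why the whole solution is dominated by sorting the edges.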
So, let’s analyze what happens when the process is reversed.
Initially, there are no edges, and they’ll be added in ascending order of their weights (recall that they were deleted in descending order).
When edge (u, v) is added, if u and v are in the same component nothing really changes, so we only need to analyze what happens when they’re in different components which get merged.
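A minimal Python sketch of the reversed process, assuming 0-indexed vertices, follows. It sorts the edges by weight \max(A_u, A_v) and inserts them in ascending order, skipping any edge whose endpoints are already connected. `merging_edges` is a hypothetical name, and it only records which insertions matter; it does not track any health values yet. Structurally this is Kruskal's algorithm, so the edges it keeps form a minimum spanning tree of the weighted graph (a forest if the graph is disconnected).

```python
def merging_edges(n, A, edges):
    """Insert edges in ascending order of max(A[u], A[v]) and return the
    (w, u, v) triples that join two different components."""
    parent = list(range(n))

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    kept = []
    for w, u, v in sorted((max(A[u], A[v]), u, v) for u, v in edges):
        ru, rv = find(u), find(v)
        if ru == rv:
            continue  # already connected: this insertion changes nothing
        parent[rv] = ru
        kept.append((w, u, v))
    return kept


print(merging_edges(3, [1, 2, 3], [(0, 1), (1, 2), (0, 2)]))
# prints [(2, 0, 1), (3, 0, 2)]
```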
Since these components are being merged now, the only way to defeat every monster in the merged component is to defeat all monsters in the component of u and then cross over to the component of v before the edge disappears, i.e. while our health is still at least \max(A_u, A_v) (or the same with u and v swapped).
With this in mind, let’s define \text{ans}_u to be the minimum starting health required to defeat all monsters in the current component containing u (the component formed by the edges inserted so far).
Also define h_u to be the sum of health of all monsters in the component of u.
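Now suppose we start with health H \geq \text{ans}_u and defeat every monster in the component of u first. Then:
- After defeating all the monsters in the component of u, our health will be H - h_u.
- We need this to be \geq \max(A_u, A_v) to be able to cross to the other component.
- We also need this to be \geq \text{ans}_v, which is the minimum health needed to defeat everything in the component of v.
Walking around inside either component adds no further constraint: every edge inside them was inserted earlier, so its weight is at most \max(A_u, A_v). With that much health we can walk to the endpoint u before crossing, and then to the best starting vertex of v’s component before killing anything there.
So, the minimum health needed to defeat every monster in the merged component, assuming everything in the component of u is defeated first, equals
\max\left(\text{ans}_u,\ h_u + \max(A_u, A_v),\ h_u + \text{ans}_v\right)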
A similar computation can be done for defeating everything in v and then moving to u (just swap u and v in the above expression since they’re symmetric).
The answer for the merged component is then the minimum of these two values.
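The merge rule can be written as a small function. In this sketch with hypothetical names, `w` stands for \max(A_u, A_v), and the other arguments are the \text{ans} and h values of the two components being merged.

```python
def merge_answer(ans_u, h_u, ans_v, h_v, w):
    """Minimum starting health to clear two components joined by an edge of weight w.

    ans_x: minimum health to clear component x on its own
    h_x:   total monster health in component x
    One side is cleared completely, then the edge is crossed and the other side cleared.
    """
    u_first = max(ans_u, h_u + w, h_u + ans_v)
    v_first = max(ans_v, h_v + w, h_v + ans_u)
    return min(u_first, v_first)


print(merge_answer(3, 3, 2, 2, 5))  # prints 7
```

For example, take two single vertices with A = (1, 5) and B = (3, 2), joined by an edge of weight 5. Clearing the monster with health 2 first needs \max(2, 2 + 5, 2 + 3) = 7, while clearing the other one first needs \max(3, 3 + 5, 3 + 2) = 8, so the merged answer is 7.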
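Initially every vertex is its own component, with \text{ans}_u = h_u = B_u (we only need enough health to kill that one monster). A DSU handles merging components and finding the component of a vertex quickly; \text{ans} and h are stored at each component’s root and recomputed on every merge.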
In the end we’ll have a single component so the final answer is just the answer corresponding to that component.
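Putting everything together, here is a compact Python sketch of the whole procedure. It follows the same edge-sorting idea as the tester's and editorialist's solutions below, but it is not either of them. It assumes 0-indexed vertices and a connected graph, and `min_initial_health` is a hypothetical name. For brevity the DSU uses path halving only, without union by size.

```python
def min_initial_health(n, A, B, edges):
    """Minimum starting health to kill every monster (connected graph assumed)."""
    parent = list(range(n))
    ans = B[:]  # ans[r]: min health to clear the component with root r
    tot = B[:]  # tot[r]: total monster health in the component with root r

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    # Insert edges in ascending order of weight max(A[u], A[v]).
    for w, u, v in sorted((max(A[u], A[v]), u, v) for u, v in edges):
        ru, rv = find(u), find(v)
        if ru == rv:
            continue
        u_first = max(ans[ru], tot[ru] + w, tot[ru] + ans[rv])  # clear ru's side first
        v_first = max(ans[rv], tot[rv] + w, tot[rv] + ans[ru])  # clear rv's side first
        parent[rv] = ru
        ans[ru] = min(u_first, v_first)
        tot[ru] += tot[rv]
    return ans[find(0)]


print(min_initial_health(3, [1, 2, 3], [1, 1, 1], [(0, 1), (1, 2)]))  # prints 4
```

On the path with edges (0, 1) and (1, 2), A = (1, 2, 3) and B = (1, 1, 1), the answer 4 corresponds to this route: start at vertex 2 and kill its monster (health 3), then cross the weight-3 edge and kill (health 2), then cross the weight-2 edge and kill the last monster.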
TIME COMPLEXITY:
\mathcal{O}(M\log M + N) per testcase.
CODE:
Author's code (C++)
```cpp
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

struct ufds{
    vector <int> root, sz;
    int n;
    void init(int nn){
        n = nn;
        root.resize(n + 1);
        sz.resize(n + 1, 1);
        for (int i = 1; i <= n; i++) root[i] = i;
    }
    int find(int x){
        if (root[x] == x) return x;
        return root[x] = find(root[x]);
    }
    bool unite(int x, int y){
        x = find(x); y = find(y);
        if (x == y) return false;
        if (sz[y] > sz[x]) swap(x, y);
        sz[x] += sz[y];
        root[y] = x;
        return true;
    }
};

void Solve()
{
    int n, m; cin >> n >> m;
    vector <int> a(n + 1);
    for (int i = 1; i <= n; i++){
        cin >> a[i];
    }
    vector <int> b(n + 1);
    for (int i = 1; i <= n; i++){
        cin >> b[i];
    }
    vector<vector<int>> adj(n + 1);
    for (int i = 1; i <= m; i++){
        int u, v; cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    // Activate vertices in increasing order of a[]: when i is activated, every edge
    // to an already-active neighbour j has weight max(a[i], a[j]) = a[i].
    vector <int> ord(n);
    iota(ord.begin(), ord.end(), 1);
    sort(ord.begin(), ord.end(), [&](int x, int y){
        return a[x] < a[y];
    });
    // uf: components whose dp/sum values are stored at the root.
    // uf1: updated while scanning i's neighbours, to keep one neighbour per adjacent component.
    ufds uf;
    uf.init(n);
    ufds uf1;
    uf1.init(n);
    vector <int> dp(n + 1);      // dp[root]: min health to clear that component
    vector <bool> alive(n + 1, false);
    vector <int> sum(n + 1, 0);  // sum[root]: total monster health in that component
    for (int i : ord){
        vector <int> vec;
        alive[i] = true;
        for (int j : adj[i]){
            if (!alive[j]) continue;
            if (uf1.find(i) != uf1.find(j)){
                vec.push_back(j);
                uf1.unite(i, j);
            }
        }
        if (vec.size() == 0){
            dp[i] = max(a[i], b[i]);
            sum[i] = b[i];
            continue;
        }
        int S = 0;
        for (int j : vec){
            S += sum[uf.find(j)];
        }
        int res = INF;
        for (int j : vec){
            // kill everyone else
            int need = S - sum[uf.find(j)] + b[i];
            // must have >= a[i]
            // and >= dp[uf.find(j)]
            need += max(dp[uf.find(j)], a[i]);
            res = min(res, need);
        }
        // alternatively, clear every neighbouring component first and finish at i
        res = min(res, S + max(a[i], b[i]));
        for (int j : vec){
            uf.unite(i, j);
        }
        sum[uf.find(i)] = S + b[i];
        dp[uf.find(i)] = res;
    }
    int ans = dp[uf.find(1)];
    cout << ans << "\n";
}

int32_t main()
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in", "r", stdin);
    // freopen("out", "w", stdout);
    cin >> t;
    for(int i = 1; i <= t; i++)
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n";
    return 0;
}
```
Tester's code (C++)
```cpp
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

using namespace std;
using namespace __gnu_pbds;

template<typename T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
typedef long long int ll;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;

#define fastio ios_base::sync_with_stdio(false); cin.tie(NULL)
#define pb push_back
#define endl '\n'
#define sz(a) (int)a.size()
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
#define conts continue
#define ceil2(x,y) ((x+y-1)/(y))
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
#define yes cout << "Yes" << endl
#define no cout << "No" << endl
#define rep(i,n) for(int i = 0; i < n; ++i)
#define rep1(i,n) for(int i = 1; i <= n; ++i)
#define rev(i,s,e) for(int i = s; i >= e; --i)
#define trav(i,a) for(auto &i : a)

template<typename T>
void amin(T &a, T b) {
    a = min(a,b);
}

template<typename T>
void amax(T &a, T b) {
    a = max(a,b);
}

#ifdef LOCAL
#include "debug.h"
#else
#define debug(...) 42
#endif

const int MOD = 1e9 + 7;
const int N = 1e5 + 5;
const int inf1 = int(1e9) + 5;
const ll inf2 = ll(1e18) + 5;

struct DSU {
    vector<int> par, rankk, siz;
    vector<ll> sum, dp;

    DSU() {

    }

    DSU(int n) {
        init(n);
    }

    void init(int n) {
        par = vector<int>(n + 1);
        rankk = vector<int>(n + 1);
        siz = vector<int>(n + 1);
        sum = vector<ll>(n + 1);
        dp = vector<ll>(n + 1);
        rep(i, n + 1) create(i);
    }

    void create(int u) {
        par[u] = u;
        rankk[u] = 0;
        siz[u] = 1;
    }

    int find(int u) {
        if (u == par[u]) return u;
        else return par[u] = find(par[u]);
    }

    bool same(int u, int v) {
        return find(u) == find(v);
    }

    void merge(int u, int v, int w) {
        u = find(u), v = find(v);
        if (u == v) return;
        if (rankk[u] == rankk[v]) rankk[u]++;
        if (rankk[u] < rankk[v]) swap(u, v);
        // val1: clear u's component first, then cross the edge (needs >= w) and clear v's
        // val2: the same with the roles of u and v swapped
        ll val1 = max({dp[u],sum[u]+w,dp[v]+sum[u]});
        ll val2 = max({dp[v],sum[v]+w,dp[u]+sum[v]});
        dp[u] = min(val1,val2);
        par[v] = u;
        siz[u] += siz[v];
        sum[u] += sum[v];
    }
};

void solve(int test_case){
    ll n,m; cin >> n >> m;
    vector<ll> a(n+5), b(n+5);
    rep1(i,n) cin >> a[i];
    rep1(i,n) cin >> b[i];
    // edge weight = max(a[u], a[v]); edges are processed in increasing order of weight
    vector<array<ll,3>> edges;
    rep1(i,m){
        ll u,v; cin >> u >> v;
        ll w = max(a[u],a[v]);
        edges.pb({w,u,v});
    }
    sort(all(edges));
    DSU dsu(n);
    rep1(i,n) dsu.sum[i] = dsu.dp[i] = b[i];
    for(auto [w,u,v] : edges){
        if(dsu.same(u,v)) conts;
        dsu.merge(u,v,w);
    }
    ll ans = dsu.dp[dsu.find(1)];
    cout << ans << endl;
}

int main()
{
    fastio;
    int t = 1;
    cin >> t;
    rep1(i, t) {
        solve(i);
    }
    return 0;
}
```
Editorialist's code (PyPy3)
```python
class UnionFind:
    def __init__(self, n):
        self.parent = list(range(n))

    def find(self, a):
        acopy = a
        while a != self.parent[a]:
            a = self.parent[a]
        while acopy != a:
            self.parent[acopy], acopy = a, self.parent[acopy]
        return a

    def union(self, a, b):
        self.parent[self.find(b)] = self.find(a)


for _ in range(int(input())):
    n, m = map(int, input().split())
    a = list(map(int, input().split()))
    b = list(map(int, input().split()))
    edges = []
    for i in range(m):
        u, v = map(int, input().split())
        u, v = u-1, v-1
        edges.append((max(a[u], a[v]), u, v))
    ord = list(range(m))
    ord.sort(key=lambda x: edges[x][0])
    ans = b[:]
    sm = b[:]
    dsu = UnionFind(n)
    for i in ord:
        w, u, v = edges[i]
        u, v = dsu.find(u), dsu.find(v)
        if u == v: continue
        req = 10**18
        # u -> v
        # ans[u] - sm[u] >= w should hold
        # ans[u] - sm[u] >= ans[v] should hold
        inc = max(0, max(w, ans[v]) - (ans[u] - sm[u]))
        req = min(req, ans[u] + inc)
        # v -> u
        inc = max(0, max(w, ans[u]) - (ans[v] - sm[v]))
        req = min(req, ans[v] + inc)
        tot = sm[u] + sm[v]
        dsu.union(u, v)
        u = dsu.find(u)
        ans[u] = req
        sm[u] = tot
    print(ans[dsu.find(0)])
```