WATEREVERY - Editorial

Problem Link: Water Everywhere - Problems - CodeChef
Author: anitm

The problem is based on Kosaraju’s algorithm for Strongly Connected Components (SCCs).

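Condense the graph into its SCCs. We are interested in the *source* SCCs: SCCs that no vertex outside them can reach, i.e. SCCs with in-degree 0 in the condensation graph.

Every other SCC is reachable from some source SCC, so it can get water from there. A source SCC, however, can only get water from a station placed inside it.

Kosaraju's first DFS pushes each vertex onto a stack when its DFS call finishes. Popping the stack therefore yields vertices in decreasing order of finishing time. Suppose a popped vertex $u$ has not yet been explored by the second pass. Then $u$ is not reachable from any source SCC processed earlier, so $u$ itself belongs to a source SCC. Otherwise, the SCC with an edge into $u$'s SCC would contain a vertex with a larger finishing time, and that vertex would have been popped and processed first. We then explore (mark) every vertex reachable from $u$, so the SCCs it feeds are skipped later.

Finally, to obtain the size and total cost of $u$'s SCC, we run one more DFS from $u$ on the reversed graph. Because $u$'s SCC is a source, the vertices that can reach $u$ are exactly the vertices of its SCC.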

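Any system of water stations that uses the minimum number of stations must contain exactly one node from each source SCC. The station in a source SCC $S$ is equally likely to be any of its nodes, so the expected cost contributed by $S$ is $\frac{1}{|S|}\sum_{v \in S} \mathrm{cost}_v$ (TotalCostOfNodes / NumberOfNodes). By linearity of expectation, the answer is the sum of these values over all source SCCs:
$$\text{answer} = \sum_{S\ \text{source SCC}} \frac{\sum_{v \in S} \mathrm{cost}_v}{|S|}.$$
The code reports this value modulo $998244353$. Each division is computed with the modular inverse $|S|^{-1} \equiv |S|^{998244351} \pmod{998244353}$.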

Time Complexity: $O(n+m)$ for the DFS passes, plus $O(\log \text{MOD})$ per source SCC for the modular inverse.

C++ Code
#pragma GCC optimize("Ofast")
#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,avx2,fma")
#pragma GCC optimize("unroll-loops")
#include <bits/stdc++.h> 
#include <complex>
#include <queue>
#include <set>
#include <unordered_set>
#include <list>
#include <chrono>
#include <random>
#include <iostream>
#include <algorithm>
#include <cmath>
#include <string>
#include <vector>
#include <map>
#include <unordered_map>
#include <stack>
#include <iomanip>
#include <fstream>
 
using namespace std;
 
// ---- Competitive-programming template: type aliases, constants, macros ----
// Several of the aliases and macros below are not used by this solution.
typedef long long ll;
typedef long double ld;
typedef pair<int,int> p32;
typedef pair<ll,ll> p64;
typedef pair<double,double> pdd;
typedef vector<ll> v64;
typedef vector<int> v32;
typedef vector<vector<int> > vv32;
typedef vector<vector<ll> > vv64;
typedef vector<vector<p64> > vvp64;
typedef vector<p64> vp64;
typedef vector<p32> vp32;
// Modulus of the printed answer. It must be prime because modinv() computes
// inverses as a^(MOD-2) (Fermat's little theorem); 998244353 is prime.
ll MOD = 998244353;
double eps = 1e-12;
#define forn(i,e) for(ll i = 0; i < e; i++)
#define forsn(i,s,e) for(ll i = s; i < e; i++)
#define rforn(i,s) for(ll i = s; i >= 0; i--)
#define rforsn(i,s,e) for(ll i = s; i >= e; i--)
#define ln "\n"
#define dbg(x) cout<<#x<<" = "<<x<<ln
#define mp make_pair
#define pb push_back
#define fi first
#define se second
#define INF 2e18
// Unties C and C++ streams for faster cin/cout; do not mix with scanf/printf.
#define fast_cin() ios_base::sync_with_stdio(false); cin.tie(NULL); cout.tie(NULL)
#define all(x) (x).begin(), (x).end()
#define sz(x) ((ll)(x).size())

// Capacity of the per-node arrays. Nodes are indexed 1..n (see solve()), so
// n must be < nax. Presumably the problem's maximum n plus slack; there is
// no bounds check on input node ids.
const int nax = 2e5+5;
// adj: edges u -> v exactly as read; revadj: the same edges reversed (v -> u).
vector<int> adj[nax], revadj[nax];
// sccsize[c] / scccost[c]: node count and total cost (mod MOD) of the c-th
// source SCC found in solve().
int sccsize[nax];
// cost[i]: cost of node i as read from input.
ll cost[nax], scccost[nax];
// Visited flags for dfs1 (finishing order), dfs2 (forward reachability)
// and dfs3 (reverse reachability), respectively.
bool vis1[nax], vis2[nax], vis3[nax];
// Nodes in order of DFS completion from dfs1; the top finished last.
stack <int> st;

// Computes a^b mod MOD by binary (square-and-multiply) exponentiation,
// using O(log b) modular multiplications.
//
// a: any 64-bit value. It is first reduced into [0, MOD). Without this step,
//    a*a overflows ll for |a| >= ~3.04e9, and a negative base would make the
//    result negative.
// b: exponent; must be >= 0. Negative exponents are not supported: the loop
//    simply does not run and 1 is returned, instead of spinning forever as
//    `while (b)` would with an arithmetic right shift of a negative value.
// Returns: a^b mod MOD, in [0, MOD).
ll binexpo(ll a, ll b) {
    a %= MOD;
    if (a < 0) a += MOD;
    ll res = 1;
    while (b > 0) {
        if (b & 1) res = res * a % MOD;
        a = a * a % MOD;  // both operands < MOD, so the product fits in ll
        b >>= 1;
    }
    return res;
}

// Modular multiplicative inverse of a modulo MOD.
// By Fermat's little theorem, a^(MOD-2) * a == 1 (mod MOD). This holds only
// because MOD is prime, and only when a is not a multiple of MOD (otherwise
// the result is 0, which is not an inverse).
ll modinv(ll a) {
    const ll fermatExponent = MOD - 2;
    return binexpo(a, fermatExponent);
}
// First Kosaraju pass. Runs a DFS over the forward graph `adj` from u,
// marking nodes in vis1, and pushes each node onto `st` once all of its
// out-neighbours are finished (post-order). The node finishing last is on top.
//
// Uses an explicit frame stack instead of recursion. A path-shaped graph with
// ~2e5 nodes would otherwise recurse ~2e5 levels deep and can overflow the
// call stack. Edges are examined in the same order as the recursive version,
// so the order of nodes pushed onto `st` is identical.
void dfs1(int u) {
    // Each frame: (node, index of the next out-edge of that node to examine).
    std::vector<std::pair<int, std::size_t> > frames;
    vis1[u] = true;
    frames.emplace_back(u, 0);
    while (!frames.empty()) {
        const int node = frames.back().first;
        std::size_t &edge = frames.back().second;
        if (edge < adj[node].size()) {
            const int v = adj[node][edge++];
            if (!vis1[v]) {
                vis1[v] = true;
                // `edge` may dangle after this push; it is not used again.
                frames.emplace_back(v, 0);
            }
        } else {
            // All out-edges handled: node is finished.
            st.push(node);
            frames.pop_back();
        }
    }
}
// Marks in vis2 every node reachable from u in the forward graph `adj`,
// including u itself. Nodes already marked are not expanded again.
// solve() uses this so that SCCs fed by an already-counted source SCC are
// skipped later.
//
// Iterative (explicit work list) to avoid recursion depth of up to n on long
// paths. The resulting set of marked nodes is the same as with recursion.
void dfs2(int u) {
    std::vector<int> pending(1, u);
    vis2[u] = true;
    while (!pending.empty()) {
        const int node = pending.back();
        pending.pop_back();
        for (int v : adj[node]) {
            if (!vis2[v]) {
                vis2[v] = true;  // mark on push so each node is queued once
                pending.push_back(v);
            }
        }
    }
}
// Adds to scccost[c] (mod MOD) and sccsize[c] the cost and count of every
// node that can reach u, found by traversing the reversed graph `revadj`.
// Nodes already marked in vis3 are skipped. solve() calls this on a node of
// a source SCC (one with no incoming edges from other SCCs). In that case the
// nodes that can reach u are exactly the nodes of u's own SCC.
//
// Iterative (explicit work list) to avoid recursion depth of up to n. The
// accumulated sum and count do not depend on visit order, so the results
// match the recursive version.
void dfs3(int u, int c) {
    std::vector<int> pending(1, u);
    vis3[u] = true;
    while (!pending.empty()) {
        const int node = pending.back();
        pending.pop_back();
        scccost[c] = (scccost[c] + cost[node]) % MOD;
        sccsize[c]++;
        for (int v : revadj[node]) {
            if (!vis3[v]) {
                vis3[v] = true;  // mark on push so each node is counted once
                pending.push_back(v);
            }
        }
    }
}

void solve() {
    int n, m;
    cin >> n >> m;
    for (int i = 1; i <= n; i++) cin >> cost[i];
    for (int i = 0; i < m; i++) {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
        revadj[v].push_back(u);
    }
    for (int i = 1; i <= n; i++) if (!vis1[i]) dfs1(i);
    int c = 0;  // Component number
    ll ans = 0;
    while (!st.empty()) {
        int u = st.top();
        st.pop();
        if (!vis2[u]) {
            dfs2(u);    // Mark all nodes reachable from u
            dfs3(u, c); // Get the cost and size of the SCC of u
            ans = (ans+scccost[c]*modinv(sccsize[c])%MOD)%MOD;      // Add contribution of the SCC to answer
            c++;
        }
    }
    cout << ans << ln;
}

// Entry point: enables fast stream I/O and runs solve() exactly once.
int main() {
    fast_cin();
    const int testCases = 1;
    for (int tc = 0; tc < testCases; ++tc) solve();
    return 0;
}
1 Like