ROADAIR - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: dolesh
Preparation: iceknight1093
Tester: mridulahi
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Dynamic programming, finding SCCs

PROBLEM:

You have a directed graph on N vertices with M edges.
It is allowed to move directly from vertex u to vertex v if:

  • The edge u\to v exists, or
  • There exists a path from u to v and a path from v to u.

Find the number of sequences of movements that visit each vertex at most once.
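
For experimenting with the statement on tiny inputs, a brute-force counter can be handy as a checker. The sketch below is not part of the intended solution: the name countSequencesBruteForce, the 0-indexed vertices and the edge-list input are assumptions made for illustration. It also assumes that single-vertex sequences are counted and the empty sequence is not, which is what the dp described in this editorial counts. It returns the exact count with no modulus, so it is only meant for roughly N \leq 8.

```cpp
#include <functional>
#include <utility>
#include <vector>

// Brute-force checker (hypothetical helper, exponential time).
// Vertices are 0-indexed; each pair (u, v) in `edges` is a directed edge u -> v.
long long countSequencesBruteForce(int n, const std::vector<std::pair<int, int>>& edges) {
    std::vector<std::vector<char>> edge(n, std::vector<char>(n, 0));
    std::vector<std::vector<char>> reach(n, std::vector<char>(n, 0));
    for (int v = 0; v < n; ++v) reach[v][v] = 1;
    for (const auto& e : edges) {
        edge[e.first][e.second] = 1;
        reach[e.first][e.second] = 1;
    }
    // Transitive closure: reach[i][j] becomes 1 when a directed path i -> ... -> j exists.
    for (int k = 0; k < n; ++k)
        for (int i = 0; i < n; ++i)
            for (int j = 0; j < n; ++j)
                if (reach[i][k] && reach[k][j]) reach[i][j] = 1;

    // A direct move u -> v is allowed via an edge, or when u and v reach each other.
    auto canMove = [&](int u, int v) {
        return edge[u][v] || (reach[u][v] && reach[v][u]);
    };

    std::vector<char> used(n, 0);
    long long total = 0;
    // Every call corresponds to one distinct sequence ending at u.
    std::function<void(int)> extend = [&](int u) {
        ++total;
        for (int v = 0; v < n; ++v) {
            if (used[v] || !canMove(u, v)) continue;
            used[v] = 1;
            extend(v);
            used[v] = 0;
        }
    };
    for (int s = 0; s < n; ++s) {
        used[s] = 1;
        extend(s);
        used[s] = 0;
    }
    return total;
}
```

For instance, on three vertices with edges 0 \to 1, 1 \to 0 and 1 \to 2, the sequences are (0), (1), (2), (0,1), (1,0), (1,2) and (0,1,2), so the function returns 7.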

EXPLANATION:

First off, the condition that there exists a path from u to v and a path from v to u in a directed graph is the same as saying that u and v lie in the same strongly connected component (SCC).
All strongly connected components of a graph can be found in \mathcal{O}(N+M) time: the link in the prerequisites section contains a tutorial.
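
As a rough illustration of the SCC step, here is a minimal sketch of Kosaraju's two-pass algorithm, written iteratively to avoid deep recursion. The function name sccIds and the adjacency-list input are assumptions for this example; the solutions at the end of this editorial use their own implementations. With this particular ordering, component ids come out in a topological order of the condensation (a component with a smaller id is never reachable from one with a larger id).

```cpp
#include <vector>

// Returns comp[v] = id of the SCC containing v (hypothetical helper).
// adj is a 0-indexed adjacency list of the directed graph.
std::vector<int> sccIds(const std::vector<std::vector<int>>& adj) {
    int n = static_cast<int>(adj.size());
    std::vector<std::vector<int>> radj(n);
    for (int u = 0; u < n; ++u)
        for (int v : adj[u]) radj[v].push_back(u);

    // Pass 1: record vertices in increasing order of DFS finishing time.
    std::vector<int> order, nextEdge(n, 0);
    std::vector<char> seen(n, 0);
    for (int s = 0; s < n; ++s) {
        if (seen[s]) continue;
        std::vector<int> stack = {s};
        seen[s] = 1;
        while (!stack.empty()) {
            int u = stack.back();
            if (nextEdge[u] < static_cast<int>(adj[u].size())) {
                int v = adj[u][nextEdge[u]++];
                if (!seen[v]) {
                    seen[v] = 1;
                    stack.push_back(v);
                }
            } else {
                order.push_back(u);  // all out-edges of u are exhausted
                stack.pop_back();
            }
        }
    }

    // Pass 2: sweep the reversed graph in decreasing finishing time;
    // each sweep collects exactly one SCC.
    std::vector<int> comp(n, -1);
    int ncomps = 0;
    for (int idx = n - 1; idx >= 0; --idx) {
        int s = order[idx];
        if (comp[s] != -1) continue;
        std::vector<int> stack = {s};
        comp[s] = ncomps;
        while (!stack.empty()) {
            int u = stack.back();
            stack.pop_back();
            for (int v : radj[u]) {
                if (comp[v] == -1) {
                    comp[v] = ncomps;
                    stack.push_back(v);
                }
            }
        }
        ++ncomps;
    }
    return comp;
}
```

Two vertices u and v can then be tested for "same SCC" with comp[u] == comp[v].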

Now, let’s look at what a sequence of visits looks like.
We start at some vertex u, then repeat the following:

  • Freely visit vertices in the same SCC as u - as many as we like, and in any order, as long as each vertex is visited at most once.
    Say we end this at vertex v (which is still in the same SCC as u).
  • Then, follow an edge from v to a different SCC, and repeat the process there.

To count the number of ways to perform this multi-stage process, we’ll use dynamic programming.
Let:

  • \text{dp}_1[u] denote the number of paths that end at u, such that we haven’t visited any other vertex in the SCC of u yet.
  • \text{dp}_2[u] denote the number of paths that end at u, such that we have finished visiting vertices in the SCC of u, and are ready to move to a different component now.

Then, we obtain the following transitions.

\text{dp}_1[u] = 1 + \sum_v \text{dp}_2[v] across all v such that the edge v \to u exists, and u and v lie in different components.
This is because we can either start at u (which gives the 1), or reach it from a different component.
By definition, \text{dp}_2[v] means we ended at v and are ready to move to a different component, so any path ending at u which hasn’t visited other vertices within its SCC can be obtained by extending such a path ending at v (where v lies in a different SCC from u).

\text{dp}_2[u] requires a bit of math to compute.
We have a couple of choices:

  • First, we can enter this SCC at u, and then leave it immediately without visiting anything else.
    The number of ways is just \text{dp}_1[u].
  • Otherwise, we enter this SCC at some other vertex v and leave it at u.
    That means we enter it at v, visit some of the other vertices in the SCC in some order, then go to u.
    For a fixed v, the number of ways of doing this is exactly
\sum_{i=0}^{k-2} \text{dp}_1[v] \cdot \binom{k-2}{i} \cdot i!

where k is the size of the SCC containing u.
This is because:

  • There are \text{dp}_1[v] ways to enter the SCC via v.
  • The first vertex is v and the last vertex has to be u, which leaves k-2 other vertices in the SCC.
  • Of them, we can choose any i to visit (in \binom{k-2}{i} ways), then one of i! orders to visit them in.
    Summing this up across all i gets us the expression above.

Now, what we want is the sum of this across all v\neq u that are in the SCC of u.
To compute this quickly, note that the quantity \sum_{i=0}^{k-2} \binom{k-2}{i} \cdot i! is a common multiplier independent of v, so it only needs to be computed once.
Then, we just want the sum of \text{dp}_1[v] for all v other than u in this SCC, which is easy: maintain the sum of all \text{dp}_1[v] of the SCC and subtract out \text{dp}_1[u].
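
The common multiplier can be computed in \mathcal{O}(k) without any binomial tables, because \binom{k-2}{i} \cdot i! = (k-2)(k-3)\cdots(k-1-i), so each term is the previous term times one more factor. Below is a small sketch of that idea; the function name sccMultiplier is made up for this example, and the default modulus 998244353 is the one used in the solutions at the end of this editorial.

```cpp
// Returns sum_{i=0}^{k-2} C(k-2, i) * i!  modulo `mod`, for an SCC with k vertices.
// Hypothetical helper: the name and signature are illustrative only.
long long sccMultiplier(long long k, long long mod = 998244353) {
    if (k < 2) return 0;  // a single-vertex SCC has no other entry point
    long long m = k - 2;
    long long term = 1 % mod;  // C(m, 0) * 0! = 1
    long long total = 0;
    for (long long i = 0; i <= m; ++i) {
        total = (total + term) % mod;
        // C(m, i+1) * (i+1)! = C(m, i) * i! * (m - i)
        term = term * ((m - i) % mod) % mod;
    }
    return total;
}
```

For k = 4 this gives 1 + 2 + 2 = 5: visit no intermediate vertex, one of the two, or both in either order.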

When computing the dp, process the SCCs in topological order, since we need the dp values of “earlier” SCCs to compute the values for later ones.
Within a single SCC, first compute all the \text{dp}_1 values and then the \text{dp}_2 values, since computing \text{dp}_2[u] requires all the \text{dp}_1 values in that SCC.

The answer is the sum of \text{dp}_2[u] across all vertices u, since every valid sequence ends at exactly one vertex.
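
As a hand-worked check of the transitions (an illustrative example, not one of the problem's tests), take vertices 1, 2, 3 with edges 1 \to 2, 2 \to 1 and 2 \to 3. The SCCs are \{1, 2\} and \{3\}, processed in that order, and the final answer is the sum of \text{dp}_2 over all vertices, as in the editorialist's code.

```latex
\begin{align*}
\text{dp}_1[1] &= \text{dp}_1[2] = 1, \qquad \sum_{i=0}^{0} \binom{0}{i} \cdot i! = 1 \\
\text{dp}_2[1] &= \text{dp}_1[1] + \text{dp}_1[2] \cdot 1 = 2, \qquad
\text{dp}_2[2] = \text{dp}_1[2] + \text{dp}_1[1] \cdot 1 = 2 \\
\text{dp}_1[3] &= 1 + \text{dp}_2[2] = 3, \qquad \text{dp}_2[3] = \text{dp}_1[3] = 3 \\
\text{answer} &= \text{dp}_2[1] + \text{dp}_2[2] + \text{dp}_2[3] = 7
\end{align*}
```

The seven sequences are (1), (2), (3), (1,2), (2,1), (2,3) and (1,2,3), which agrees with the total.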

TIME COMPLEXITY:

\mathcal{O}(N + M) per testcase.

CODE:

Tester's code (C++)
#include<bits/stdc++.h>

using namespace std;

#define all(x) begin(x), end(x)
#define sz(x) static_cast<int>((x).size())
#define int long long

void d1 (int x, vector<int> adj[], bool vis[], vector<int> &v) {
        vis[x] = true;
        for (auto u : adj[x]) {
                if (vis[u]) continue;
                d1 (u, adj, vis, v);
        }
        v.push_back(x);
}
void d2 (int xp, int x, vector<int> radj[], bool vis[], int p[]) {
        vis[x] = true; p[x] = xp;
        for (auto u : radj[x]) {
                if (vis[u]) continue;
                d2(xp, u, radj, vis, p);
        }
}
void scc (int n, vector<int> adj[], vector<int> radj[], int p[]) {
        vector<int> v; bool vis[n] = {0};
        for (int i = 0; i < n; i++) {
                if (vis[i]) continue;
                d1 (i, adj, vis, v);
        }
        reverse(v.begin(), v.end());
        fill(vis, vis + n, 0);
        for (auto u : v) {
                if (vis[u]) continue;
                d2 (u, u, radj, vis, p);
        }
}

void dfsfortopo (int i, int state[], vector<int> &topo, vector<int> adj[]) {
        state[i] = 1;
        for (auto u : adj[i]) {
                if (state[u] == 0) dfsfortopo(u, state, topo, adj);
        }
        topo.push_back(i);
        state[i] = 2;
}
 
void toposort (vector<int> &topo, int n, vector<int> adj[]) {
        int state[n] = {0};
        for (int i = 0; i < n; i++) {
                if (state[i] == 0) dfsfortopo(i, state, topo, adj);
        }
        reverse(all(topo));
} 

const int mod = 998244353;


signed main() {

        ios::sync_with_stdio(0);
        cin.tie(0);
        cout.tie(0);

        int t;
        cin >> t;

        while (t--) {

                int n, m;
                cin >> n >> m;
                vector<int> adj[n], radj[n];
                for (int i = 0; i < m; i++) {
                        int u, v;
                        cin >> u >> v;
                        u--; v--;
                        adj[u].push_back(v);
                        radj[v].push_back(u);
                }

                int p[n];
                scc(n, adj, radj, p);
                vector<int> a2[n];
                for (int i = 0; i < n; i++) {
                        for (auto u : adj[i]) {
                                if (p[u] == p[i]) continue;
                                a2[p[i]].push_back(p[u]);
                        }
                }
                vector<int> topo;
                toposort(topo, n, a2);
                int dp1[n], dp2[n] = {0};
                for (int i = 0; i < n; i++) dp1[i] = 1;
                vector<int> sub[n];
                for (int i = 0; i < n; i++) sub[p[i]].push_back(i);

                for (auto x : topo) {
                        if (p[x] != x) continue;
                        int z = sz(sub[x]) - 2;
                        int mul = 0;
                        int cur = 1;
                        for (int i = 0; i <= z; i++) {
                                mul += cur;
                                cur = cur * (z - i) % mod;
                        }
                        mul %= mod;
                        int sm = 0;
                        for (auto u : sub[x]) sm += dp1[u];
                        sm %= mod;
                        for (auto u : sub[x]) {
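                                // sm is already reduced mod p, so add mod before multiplying to keep the value non-negative
                                dp2[u] = (dp1[u] + (sm - dp1[u] + mod) % mod * mul) % mod;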
                                for (auto v : adj[u]) {
                                        if (p[u] == p[v]) continue;
                                        dp1[v] += dp2[u];
                                        if (dp1[v] >= mod) dp1[v] -= mod;
                                }
                        }
                }

                int ans = accumulate(dp2, dp2 + n, 0ll) % mod;

                cout << ans << "\n";

        }
        
        
}
Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

vector<int> val, comp, z, cont;
vector<vector<int>> sccs;
int Time, ncomps;
int dfs(int j, auto& g, auto &f) {
	int low = val[j] = ++Time, x; z.push_back(j);
	for (auto e : g[j]) if (comp[e] < 0)
		low = min(low, val[e] ?: dfs(e,g,f));
	if (low == val[j]) {
		do {
			x = z.back(); z.pop_back();
			comp[x] = ncomps;
			cont.push_back(x);
		} while (x != j);
		f(cont); cont.clear(); // cont contains an SCC
		ncomps++;
	}
	return val[j] = low;
}
void scc(auto& g, auto f) {
	int n = size(g);
	val.assign(n, 0); comp.assign(n, -1);
    sccs.clear();
	Time = ncomps = 0;
	for (int i = 0; i < n; ++i) if (comp[i] < 0) dfs(i, g, f);
}

const int mod = 998'244'353;
int mpow(int a, int n) {
    int res = 1;
    while (n) {
        if (n & 1) res = (1LL * res * a) % mod;
        a = (1LL * a * a) % mod;
        n /= 2;
    }
    return res;
}

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    const int lim = 3e5 + 5;
    vector<int> mul(lim);
    int fac = 1;
    mul[0] = 1;
    for (int i = 1; i < lim; ++i) {
        fac = (1LL * fac * i) % mod;
        mul[i] = (mul[i-1] + mpow(fac, mod-2)) % mod;
    }
    fac = 1;
    for (int i = 1; i < lim; ++i) {
        fac = (1LL * fac * i) % mod;
        mul[i] = (1LL * mul[i] * fac) % mod;
    }

    int t; cin >> t;
    while (t--) {
        int n, m; cin >> n >> m;
        vector adj(n, vector<int>());
        vector radj(n, vector<int>());
        for (int i = 0; i < m; ++i) {
            int u, v; cin >> u >> v;
            adj[--u].push_back(--v);
            radj[v].push_back(u);
        }

        vector<int> dp1(n), dp2(n);
        // dp1[u] -> number of paths that end at u, and haven't touched any other vertex in its scc
        // dp2[u] -> number of paths that end at u, and will never touch any other vertex in its scc
        // ans = sum(dp2[u])
        scc(adj, [&] (const auto &lst) {
            sccs.push_back(lst);
        });
        reverse(begin(sccs), end(sccs));
        for (auto C : sccs) {
            int sumall = 0;
            for (int u : C) {
                dp1[u] = 1;
                for (int v : radj[u]) {
                    if (comp[u] != comp[v]) dp1[u] = (dp1[u] + dp2[v]) % mod;
                }
                sumall = (sumall + dp1[u]) % mod;
            }
            for (int u : C) {
                dp2[u] = dp1[u]; // Enter here, leave immediately

                // Enter elsewhere, leave through u
                // fix entry point v, k vertices in component => dp1[v] * sum((k-2)! / (k-2-x)!) for 0 <= x <= k-2
                // total = sum(dp1[v]) over all v in the scc other than u, times that common multiplier
                if (C.size() >= 2) dp2[u] = (dp2[u] + 1LL * (sumall - dp1[u] + mod) * mul[C.size() - 2]) % mod;
            }
        }
        int ans = 0;
        for (int u = 0; u < n; ++u) ans = (ans + dp2[u]) % mod;
        cout << ans << '\n';
    }
}