EVADEROBOT - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: mathmodel, raysh07
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

Medium

PREREQUISITES:

DFS, Dynamic Programming

PROBLEM:

There’s an undirected graph with N vertices and M edges.
You start at vertex S.

A robot will start at some vertex A_0 \neq S and, every second, move to an adjacent vertex, so that it is at A_i after i seconds. The robot makes K moves, so A = (A_0, A_1, \ldots, A_K) is a walk of length K.
Every second, you can either stay at your current vertex or move to an adjacent one.

The sequence A is called avoidable if there exists a way for you to never meet the robot throughout all of its moves, either on a vertex or along an edge (that is, by traversing the same edge in opposite directions during the same second).
Count the number of avoidable sequences, modulo 998244353.
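To make the meeting rule concrete, here is a minimal C++ sketch of a single second of movement. It assumes vertices are integer ids and reads “meeting along an edge” as the two of you traversing the same edge in opposite directions; the function name `caught` is hypothetical, not taken from the reference solutions.

```cpp
#include <cassert>

// One second of movement: you go from u to u2 and the robot goes from r to r2
// (staying put means u2 == u). Assumes you were not already caught, i.e. u != r.
// You are caught if you end on the same vertex, or if you cross on one edge.
bool caught(int u, int u2, int r, int r2) {
    assert(u != r);
    bool same_vertex = (u2 == r2);
    bool crossed_edge = (u2 == r && r2 == u);  // opposite directions on edge (u, r)
    return same_vertex || crossed_edge;
}
```

For example, `caught(0, 1, 1, 0)` is true (you swap along edge (0, 1)), while `caught(0, 0, 2, 1)` is false (you stay, the robot moves elsewhere).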

EXPLANATION:

First, we need to characterize avoidable sequences.

To start with, there are a couple of simple cases.

  • If S lies on a cycle, any sequence is avoidable.
    This is because, as long as you stay on the cycle, you’ll always be able to avoid the robot.
    Each vertex on the cycle has two neighbours on the cycle, and the robot can occupy at most one of them.
    If the robot is about to move onto your current vertex, move to whichever neighbouring cycle vertex it is not currently on; this can neither land on the robot nor cross it on an edge, since the robot is heading to the vertex you just left. Otherwise, just stay in place.
    Since you always remain on the cycle, this can be repeated indefinitely.
  • Next, if \text{deg}(S) \geq 3, again any sequence is avoidable. Here \text{deg}(S) is the degree of S, i.e. the number of edges incident to it.
    To see why, suppose x, y, z are three vertices adjacent to S, and wait at S.
    Whenever the robot enters S, say A_i = S (note that i \geq 1, since A_0 \neq S), its positions just before and after, A_{i-1} and A_{i+1}, are adjacent to S and occupy at most two of \{x, y, z\}. Without loss of generality, let z be one that is not among them.
    In the i-th second, move to z, and in the (i+1)-th second, move back to S.
    Neither move meets the robot: it is never at z, and it travels along the edges to A_{i-1} and A_{i+1}, never along the edge to z. You survive the robot’s visit and end up back at S.
    Repeat this every time the robot enters S, and you are never caught.
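Both strategies can be sketched in C++. This is an illustration under simplifying assumptions, not code from the reference solutions: for the cycle case the whole graph is taken to be a single cycle with vertices 0 to L-1, and for the degree case the graph is assumed to be simple. The names `cycle_dodge` and `deg3_plan` are hypothetical.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Cycle strategy, on a graph that is a single cycle 0, 1, ..., L-1 (L >= 3).
// You start at s; robot[j] is A_j. Returns true if you are never caught.
bool cycle_dodge(int L, int s, const vector<int>& robot) {
    assert(L >= 3 && !robot.empty() && robot[0] != s);
    int me = s;
    for (size_t j = 0; j + 1 < robot.size(); ++j) {
        int now = robot[j], next = robot[j + 1];
        int old = me;
        if (next == me) {                                // robot is about to step onto us
            int a = (me + 1) % L, b = (me + L - 1) % L;  // our two cycle neighbours
            me = (a == now) ? b : a;                     // take the one the robot is not on
        }
        bool same_vertex = (me == next);
        bool crossed_edge = (me == now && old == next);
        if (same_vertex || crossed_edge) return false;
    }
    return true;
}

// Degree >= 3 strategy: wait at s, and whenever A_i == s, spend second i on a
// neighbour z of s other than A_{i-1} and A_{i+1}, then return in second i+1.
// Assumes a simple graph with deg(s) >= 3. Returns your position at times 0..K.
vector<int> deg3_plan(const vector<vector<int>>& adj, int s, const vector<int>& robot) {
    assert(adj[s].size() >= 3 && !robot.empty() && robot[0] != s);
    int K = (int)robot.size() - 1;
    vector<int> me(K + 1, s);                            // by default, sit at s
    for (int i = 1; i <= K; ++i) {
        if (robot[i] != s) continue;
        int before = robot[i - 1];
        int after = (i < K) ? robot[i + 1] : -1;         // -1: the walk ends at time i
        for (int z : adj[s])
            if (z != before && z != after) { me[i] = z; break; }
    }
    return me;
}
```

For example, on a 4-cycle with S = 0, `cycle_dodge(4, 0, {2, 1, 0, 3, 0, 1})` returns true, and on a star with centre 0 and leaves 1, 2, 3, `deg3_plan` for the walk (1, 0, 2, 0, 3) hides at 3 and then at 1, giving positions (0, 3, 0, 1, 0).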

This leaves us with the case where \text{deg}(S) \leq 2 and S is not on a cycle.

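This means we can only move in two directions; for simplicity, let’s call them “left” and “right”. If \text{deg}(S) = 1, one of these directions is empty: call the non-empty side “left”, and note that S itself then plays the role of the dead end v_R defined below.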
The robot cannot start at S, so it must start either on the left or on the right.
The cases are symmetrical, so suppose it starts on the left.

There are now a couple of possibilities.
First, suppose we repeatedly move right until we first reach a vertex whose degree is not 2.
Let the vertex we reach be v_R. We cannot be caught before reaching v_R, since the robot starts on the left and we keep moving away from it.
Then, if v_R lies on a cycle or \text{deg}(v_R) \geq 3, the initial analysis (applied at v_R instead of S) shows that we can safely avoid the robot from then on. In fact, if v_R lies on a cycle it automatically satisfies \text{deg}(v_R) \geq 3: it has two neighbours on the cycle, plus the neighbour we arrived from, which cannot be on that cycle since S is not on a cycle. So all we really care about is whether \text{deg}(v_R) \geq 3.

That only leaves the case of \text{deg}(v_R) = 1.
Here, note that if the robot never actually visits v_R, we’ll definitely be safe since we can reach v_R and then just stay there.
If the robot does visit v_R however, we have no hope of avoiding the robot by moving towards the right since we’ll end up forced into a corner. So, our only hope is to move left and do something on that side.

Just like we found v_R, let v_L be the closest vertex on the left with degree \neq 2.
If \text{deg}(v_L) = 1, then we have no chance of avoiding the robot at all - recall that we’re in the case where the robot visits v_R, and if \text{deg}(v_R) = \text{deg}(v_L) = 1 the graph looks like a line.
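Finding v_L and v_R only requires walking along degree-2 vertices. A possible C++ sketch, assuming a simple graph stored as 0-indexed adjacency lists and that S is not on a cycle; the hypothetical helper `chain_end` returns the endpoint together with its distance from S.

```cpp
#include <cassert>
#include <utility>
#include <vector>
using namespace std;

// From s, step to its neighbour `first` and keep following degree-2 vertices until
// reaching a vertex whose degree is not 2. Returns {that vertex, its distance from s}.
// Assumes a simple graph in which s does not lie on a cycle.
pair<int, int> chain_end(const vector<vector<int>>& adj, int s, int first) {
    int prev = s, cur = first, dist = 1;
    while (adj[cur].size() == 2) {
        int nxt = (adj[cur][0] == prev) ? adj[cur][1] : adj[cur][0];  // the other neighbour
        prev = cur;
        cur = nxt;
        ++dist;
        assert(cur != s);  // returning to s would mean s lies on a cycle
    }
    return {cur, dist};
}
```

Calling it once per neighbour of S yields both endpoints and their distances. For instance, on the path 0-1-2-3 with extra leaves 4 and 5 attached to 3, starting from S = 1 it returns (3, 2) towards vertex 2 and (0, 1) towards vertex 0.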

If \text{deg}(v_L) \geq 3 however, we might be able to escape - as long as we’re able to reach v_L before getting caught.
Specifically, we have the following strategy:

  • If it’s possible to move left without getting caught, move left.
  • Otherwise, if it’s possible to stay in place without getting caught, stay in place.
  • Otherwise, move right.
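The greedy strategy can be simulated on a simplified line model, which is an assumption made purely for illustration: vertices on the chain are numbered by their distance from v_L, everything strictly beyond v_L is collapsed onto negative positions, and v_R is a dead end at position R. The function `greedy_escapes` is hypothetical.

```cpp
#include <cassert>
#include <cstdlib>
#include <vector>
using namespace std;

// Line model of the chain (an illustrative assumption): v_L is at position 0,
// S at position d > 0, the dead end v_R at position R >= d, and negative positions
// stand for vertices strictly beyond v_L. robot[j] is the robot's position at time j;
// it starts left of S and moves by exactly one position per second.
// Returns true if the greedy strategy reaches v_L, or survives every move.
bool greedy_escapes(const vector<int>& robot, int d, int R) {
    assert(!robot.empty() && 0 < d && d <= R && robot[0] < d);
    int me = d;
    for (size_t j = 0; j + 1 < robot.size(); ++j) {
        int now = robot[j], next = robot[j + 1];
        assert(abs(next - now) == 1 && next <= R);
        bool safe_move = false;
        for (int cand : {me - 1, me, me + 1}) {      // left, then stay, then right
            if (cand > R) continue;                   // nothing to the right of v_R
            bool same_vertex = (cand == next);
            bool crossed_edge = (cand == now && me == next);
            if (!same_vertex && !crossed_edge) { me = cand; safe_move = true; break; }
        }
        if (!safe_move) return false;                 // cornered: caught
        if (me == 0) return true;                     // reached v_L: free from now on
    }
    return true;                                      // the walk ended first
}
```

With d = 2 and R = 4, the walk (1, 0, -1) lets you reach v_L at time 2, while the walk (1, 2, 3, 4) pushes you into v_R and the function returns false.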

If we’re able to reach v_L using this strategy, we’ll be free, otherwise it’s impossible to do so.
The question is, when exactly does this strategy fail?

Let d = \text{dist}(v_L, S) denote the distance from S to v_L.
We need at least d seconds to reach v_L, so a necessary condition is that the robot is strictly to the left of v_L (that is, off the path between v_L and v_R) at some time \geq d.
However, there’s the additional constraint that we must not have lost before that: if t denotes the first time \geq d at which the robot is strictly to the left of v_L, then the robot must not have visited v_R at any time \lt t.

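These conditions are necessary, and they are also sufficient. Because the strategy always makes the leftmost safe move, one can check by induction on time that after j seconds you are either d - j vertices to the right of v_L or exactly one vertex to the right of the robot, whichever is further right. Hence you arrive at v_L exactly at the first time t \geq d at which the robot is strictly to the left of v_L, and you can only be cornered at v_R if the robot gets there before that time.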
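The condition itself is easy to check directly. A sketch in the same kind of simplified line model (v_L at 0, S at d, v_R at R, negative positions standing for vertices beyond v_L); this model and the name `avoidable_by_condition` are illustrative assumptions.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Illustrative line model: v_L at 0, S at d, v_R at R, negative = strictly beyond v_L.
// Transcribes the condition: the walk is avoidable iff it never visits v_R, or it is
// strictly beyond v_L at some time t >= d before its first visit to v_R.
bool avoidable_by_condition(const vector<int>& robot, int d, int R) {
    assert(0 < d && d <= R);
    for (int t = 0; t < (int)robot.size(); ++t) {
        if (robot[t] == R) return false;          // reached v_R first: cornered
        if (t >= d && robot[t] < 0) return true;  // beyond v_L in time: we reach v_L
    }
    return true;                                  // never visited v_R
}
```

With d = 2 and R = 4, the walk (1, 0, 1, 0, -1) is avoidable because it is beyond v_L at time 4, while (-1, 0, 1, 2, 3, 4) is not: it reaches v_R without ever being beyond v_L at a time \geq 2.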


Now that the conditions are known, let’s move to counting valid walks.

Case 1: \text{deg}(S) \geq 3 or S lies on a cycle.
Here, all walks of length K that don’t start at S are valid.
Counting the number of such walks can be done using dynamic programming.
Specifically, define dp_{i, j} to be the number of valid walks of length j ending at vertex i.
dp_{i, j} can be computed as the sum of dp_{x, j-1} across all x that are adjacent to i.
The base conditions are dp_{i, 0} = 1 for all i, except dp_{S, 0} = 0.
This runs in \mathcal{O}(K\cdot (N+M)) time: there are \mathcal{O}(NK) states, and each edge is processed twice (once in each direction) per length.
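The Case 1 DP translates directly into C++. This is a sketch assuming 0-indexed adjacency lists and the modulus 998244353 used by both reference solutions; `count_all_walks` is a hypothetical name. Only two layers of the table are kept, since dp_{\cdot, j} depends only on dp_{\cdot, j-1}.

```cpp
#include <cassert>
#include <vector>
using namespace std;

const long long MOD = 998244353;

// Counts walks of length K over all start vertices except s.
// dp[i] = number of walks of the current length that end at vertex i.
long long count_all_walks(const vector<vector<int>>& adj, int s, int K) {
    int n = adj.size();
    assert(0 <= s && s < n);
    vector<long long> dp(n, 1);
    dp[s] = 0;                                  // the robot cannot start at S
    for (int j = 1; j <= K; ++j) {
        vector<long long> ndp(n, 0);
        for (int x = 0; x < n; ++x)
            for (int i : adj[x])                // each edge is used once per direction
                ndp[i] = (ndp[i] + dp[x]) % MOD;
        dp.swap(ndp);
    }
    long long total = 0;
    for (long long v : dp) total = (total + v) % MOD;
    return total;
}
```

On a triangle with S = 0 it gives 2, 4 and 8 walks for K = 0, 1, 2.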

To check if S lies on a cycle, start a DFS at S and check if there are any back-edges to S.
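One way to implement this test, assuming an edge list with 0-indexed endpoints: the DFS is rooted at S and skips only the tree edge it arrived by (tracked by edge id, so a parallel edge still counts as a cycle). Since S is the root, any other edge that leads back to S closes a cycle through it. The name `on_cycle` is hypothetical.

```cpp
#include <cassert>
#include <functional>
#include <utility>
#include <vector>
using namespace std;

// Returns true if vertex s lies on a cycle of the undirected graph with n vertices.
bool on_cycle(int n, const vector<pair<int, int>>& edges, int s) {
    assert(0 <= s && s < n);
    vector<vector<pair<int, int>>> adj(n);  // (neighbour, edge id)
    for (int id = 0; id < (int)edges.size(); ++id) {
        auto [u, v] = edges[id];
        adj[u].push_back({v, id});
        adj[v].push_back({u, id});
    }
    vector<bool> seen(n, false);
    bool found = false;
    function<void(int, int)> dfs = [&](int u, int parent_edge) {
        seen[u] = true;
        for (auto [v, id] : adj[u]) {
            if (id == parent_edge) continue;          // don't reuse the tree edge we came by
            if (v == s) { found = true; continue; }   // a non-tree edge back to the root S
            if (!seen[v]) dfs(v, id);
        }
    };
    dfs(s, -1);
    return found;
}
```

The recursion depth can reach N, so very large inputs may call for an iterative DFS instead.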

Case 2: \text{deg}(S) \leq 2 and S does not lie on a cycle.
Let’s count only the avoidable walks that start to the left of S; those that start on the right can be handled similarly.

We have sub-cases here.
As defined above, let v_L and v_R be the closest vertices to the left/right of S whose degree is \neq 2.


Case 2.1: If \text{deg}(v_R) \geq 3, then any walk starting on the left is valid.
Counting such walks can be done using the same DP as before, just with the initial conditions being dp_{i, 0} = 1 for only those i on the left.


Case 2.2: If \text{deg}(v_L) = \text{deg}(v_R) = 1, valid walks are exactly those that start on the left and do not reach v_R.
This can again be counted using the same DP: the only difference is that you disallow transitions to v_R.
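Cases 2.1 and 2.2 share one DP that differs only in its start set and banned vertices. A sketch, assuming 0-indexed adjacency lists and the modulus 998244353; `count_walks` and its parameters are hypothetical.

```cpp
#include <cassert>
#include <vector>
using namespace std;

const long long MOD = 998244353;

// Counts walks of length K that start at a vertex i with starts[i] set and never
// step onto a vertex with banned[i] set.
long long count_walks(const vector<vector<int>>& adj, const vector<bool>& starts,
                      const vector<bool>& banned, int K) {
    int n = adj.size();
    assert((int)starts.size() == n && (int)banned.size() == n);
    vector<long long> dp(n, 0);
    for (int i = 0; i < n; ++i) if (starts[i]) dp[i] = 1;
    for (int j = 1; j <= K; ++j) {
        vector<long long> ndp(n, 0);
        for (int x = 0; x < n; ++x)
            for (int i : adj[x])
                if (!banned[i]) ndp[i] = (ndp[i] + dp[x]) % MOD;  // disallow moves onto banned vertices
        dp.swap(ndp);
    }
    long long total = 0;
    for (long long v : dp) total = (total + v) % MOD;
    return total;
}
```

For Case 2.1 pass an all-false `banned`; for Case 2.2 set only `banned[v_R]`. On the path 0-1-2-3 starting from vertex 0 with K = 3, this gives 3 walks without the ban and 2 with vertex 3 banned.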


Case 2.3: If \text{deg}(v_L) \geq 3 and \text{deg}(v_R) = 1, valid walks are those that start on the left and either never reach v_R, or reach some vertex strictly to the left of v_L at a time \geq \text{dist}(S, v_L) before they first reach v_R (what they do afterwards doesn’t matter).

The easiest way to count such walks is to count their complement instead.
That is, start with all walks starting on the left, and subtract the unavoidable ones.
Unavoidable walks are easier to handle here: negating the condition for avoidable walks, they are exactly the walks that start on the left, reach v_R at some point, and stay on the path between v_L and v_R at every time \geq \text{dist}(S, v_L) before their first visit to v_R.

So, we can write the following DP: dp_{i, j, t} is the number of walks ending at i at time j, where t = 1 if v_R has already been visited and t = 0 otherwise.
The transitions are the same as before, except that moving onto v_R sets t = 1, and for j \geq \text{dist}(S, v_L) we set dp_{i, j, 0} = 0 for every vertex i strictly beyond v_L (off the path between v_L and v_R). The number of unavoidable walks is then \sum_i dp_{i, K, 1}.
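The two-layer DP for Case 2.3 can be sketched as follows. It assumes 0-indexed adjacency lists, a precomputed flag `beyond[i]` marking vertices strictly beyond v_L, and the modulus 998244353; the function name and parameters are hypothetical. Its result is subtracted from the count of all walks that start on the left.

```cpp
#include <array>
#include <cassert>
#include <vector>
using namespace std;

const long long MOD = 998244353;

// dp[i][t]: walks of the current length j ending at i; t = 1 iff v_R was visited.
// starts: left-side start vertices; beyond[i]: i lies strictly beyond v_L;
// target: v_R; d = dist(S, v_L). Returns the number of unavoidable walks of length K.
long long count_unavoidable(const vector<vector<int>>& adj, const vector<bool>& starts,
                            const vector<bool>& beyond, int target, int d, int K) {
    int n = adj.size();
    assert(0 <= target && target < n && d >= 1);
    vector<array<long long, 2>> dp(n, array<long long, 2>{0, 0});
    for (int i = 0; i < n; ++i) if (starts[i]) dp[i][0] = 1;
    for (int j = 1; j <= K; ++j) {
        vector<array<long long, 2>> ndp(n, array<long long, 2>{0, 0});
        for (int u = 0; u < n; ++u)
            for (int v : adj[u])
                for (int t = 0; t < 2; ++t) {
                    int nt = t | (v == target ? 1 : 0);  // stepping onto v_R sets t = 1
                    ndp[v][nt] = (ndp[v][nt] + dp[u][t]) % MOD;
                }
        if (j >= d)
            for (int i = 0; i < n; ++i)
                if (beyond[i]) ndp[i][0] = 0;  // escaped past v_L in time: avoidable
        dp.swap(ndp);
    }
    long long total = 0;
    for (int i = 0; i < n; ++i) total = (total + dp[i][1]) % MOD;
    return total;
}
```

For example, on the graph with edges 0-1, 1-2, 2-3, 2-4 (S = 1, v_R = 0, v_L = 2, d = 1, left-side starts 2, 3, 4), it counts 1 unavoidable walk for K = 2, namely (2, 1, 0), and 3 for K = 3.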

In all cases the complexity is \mathcal{O}((N+M)\cdot K) so we’re done.

TIME COMPLEXITY:

\mathcal{O}((N+M)\cdot K) per testcase.

CODE:

Editorialist's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    int t; cin >> t;
    while (t--) {
        int n, m, k, s; cin >> n >> m >> k >> s;
        --s;
        vector adj(n, vector<int>());
        for (int i = 0; i < m; ++i) {
            int u, v; cin >> u >> v;
            adj[--u].push_back(--v);
            adj[v].push_back(u);
        }

        const int mod = 998244353;
        auto count_paths = [&] (vector<int> starts, vector<int> ban, int when = -1, int target = -1, bool mode = 0) {
            vector dp(n, array<int, 2>());
            for (int i = 0; i < n; ++i) {
                if (starts[i]) dp[i][0] = 1;
            }
            
            for (int i = 1; i <= k; ++i) {
                vector ndp(n, array<int, 2>());
                for (int x = 0; x < 2; ++x) for (int u = 0; u < n; ++u) for (int v : adj[u]) {
                    int y = x | (v == target);
                    ndp[v][y] = (ndp[v][y] + dp[u][x]) % mod;
                }

                for (int u = 0; u < n; ++u) {
                    if (ban[u] and i >= when) ndp[u][0] = 0;
                }
                swap(dp, ndp);
            }

            int res = 0;
            for (int i = 0; i < n; ++i) {
                res += dp[i][1]; res %= mod;
                if (mode == 0) res += dp[i][0];
                res %= mod;
            }
            return res;
        };

        vector par(n, -1);
        auto cycle_check = [&] (const auto &self, int u) -> bool {
            bool res = false;
            for (int v : adj[u]) {
                if (par[u] == v) continue;
                if (par[v] >= 0) continue;
                if (v == s) return true;

                par[v] = u;
                res |= self(self, v);
            }
            return res;
        };

        // If on cycle or deg >= 3, all good
        if (cycle_check(cycle_check, s) or adj[s].size() >= 3) {
            vector starts(n, 1), ban(n, 0);
            starts[s] = 0;
            cout << count_paths(starts, ban) << '\n';
            continue;
        }

        // Not on cycle, degree <= 2
        // Go left and right till you reach a degree >= 3 vertex or die - say v1 and v2
        vector<int> mark(n, -1), vs, dist(n, -1);
        auto dfs = [&] (const auto &self, int u) -> void {
            for (int v : adj[u]) {
                if (mark[v] != -1) continue;
                if (u != s) mark[v] = mark[u];
                else mark[v] = v;
                dist[v] = 1 + dist[u];
                
                if (adj[v].size() != 2) {
                    if (vs.empty() or mark[v] != mark[vs.back()]) vs.push_back(v);
                }
                self(self, v);
            }
        };
        mark[s] = s;
        dist[s] = 0;
        dfs(dfs, s);
        if (adj[s].size() == 1) vs.push_back(s);

        int ans = 0;
        for (int border : vs) {
            int other = vs[0] + vs[1] - border;
            if (other == s) continue;
            vector starts(n, 0), ban(n, 0);
            for (int i = 0; i < n; ++i) {
                if (i != s and mark[i] != -1 and mark[i] != mark[border]) starts[i] = 1;
            }
            
            if (adj[border].size() >= 3) {
                // Everything on other side is good
                ans += count_paths(starts, ban); ans %= mod;
            }
            else if (adj[other].size() == 1) {
                // Everything starting other side, except reaching here
                ban[border] = 1;
                ans += count_paths(starts, ban); ans %= mod;
            }
            else {
                ans += count_paths(starts, ban); ans %= mod;
                int d = dist[other];
                for (int i = 0; i < n; ++i) {
                    if (mark[i] == mark[other] and dist[i] > dist[other]) ban[i] = 1;
                }
                ans += mod - count_paths(starts, ban, d, border, true); ans %= mod;
            }
        }
        cout << ans << '\n';
    }
}
Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

const int mod = 998244353;

struct mint{
    int x;

    mint (){ x = 0;}
    mint (int32_t xx){ x = xx % mod; if (x < 0) x += mod;}
    mint (long long xx){ x = xx % mod; if (x < 0) x += mod;}

    int val(){
        return x;
    }
    mint &operator++(){
        x++;
        if (x == mod) x = 0;
        return *this;
    }
    mint &operator--(){
        if (x == 0) x = mod;
        x--;
        return *this;
    }
    mint operator++(int32_t){
        mint result = *this;
        ++*this;
        return result;
    }
    
    mint operator--(int32_t){
        mint result = *this;
        --*this;
        return result;
    }
    mint& operator+=(const mint &b){
        x += b.x;
        if (x >= mod) x -= mod;
        return *this;
    }
    mint& operator-=(const mint &b){
        x -= b.x;
        if (x < 0) x += mod;
        return *this;
    }
    mint& operator*=(const mint &b){
        long long z = x;
        z *= b.x;
        z %= mod;
        x = (int)z;
        return *this;
    }
    mint operator+() const {
        return *this;
    }
    mint operator-() const {
        return mint() - *this;
    }
    mint operator/=(const mint &b){
        return *this = *this * b.inv();
    }
    mint power(long long n) const {
        mint ok = *this, r = 1;
        while (n){
            if (n & 1){
                r *= ok;
            }
            ok *= ok;
            n >>= 1;
        }
        return r;
    }
    mint inv() const {
        return power(mod - 2);
    }
    friend mint operator+(const mint& a, const mint& b){ return mint(a) += b;}
    friend mint operator-(const mint& a, const mint& b){ return mint(a) -= b;}
    friend mint operator*(const mint& a, const mint& b){ return mint(a) *= b;}
    friend mint operator/(const mint& a, const mint& b){ return mint(a) /= b;}
    friend bool operator==(const mint& a, const mint& b){ return a.x == b.x;}
    friend bool operator!=(const mint& a, const mint& b){ return a.x != b.x;}
    mint power(mint a, long long n){
        return a.power(n);
    }
    friend ostream &operator<<(ostream &os, const mint &m) {
        os << m.x;
        return os;
    }
    explicit operator bool() const {
        return x != 0;
    }
};

// Remember to check MOD

void Solve() 
{
    int n, m, k, s; cin >> n >> m >> k >> s;
    
    vector<vector<int>> adj(n + 1);
    vector <int> deg(n + 1);
    
    for (int i = 1; i <= m; i++){
        int u, v; cin >> u >> v;
        
        deg[u]++;
        deg[v]++;
        
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    
    auto count = [&](){
        vector <mint> dp(n + 1);
        for (int i = 1; i <= n; i++){
            if (i != s)
            dp[i] = 1;
        }
        
        for (int t = 1; t <= k; t++){
            vector <mint> ndp(n + 1);
            for (int i = 1; i <= n; i++){
                for (int v : adj[i]){
                    ndp[v] += dp[i];
                }
            }
            
            dp = ndp;
        }
        
        mint ans = 0;
        for (int i = 1; i <= n; i++){
            ans += dp[i];
        }
        return ans;
    };
    
    mint ans = count();
    
    if (deg[s] >= 3){
        cout << ans << "\n";
        return;
    }
    
    vector <int> comp;
    auto dfs = [&](auto self, int u, int par) -> void{
        comp.push_back(u);
        if (u == s || deg[u] >= 3){
            return;
        }
        
        for (int v : adj[u]) if (v != par){
            self(self, v, u);
        }
    };
    
    vector <int> vec;
    bool both = false;
    
    if (deg[s] == 2){
        dfs(dfs, adj[s][0], s);
        auto l = comp; 
        comp.clear();
        dfs(dfs, adj[s][1], s);
        auto r = comp;
        
        if (l.back() == s || r.back() == s || (deg[l.back()] >= 3 && deg[r.back()] >= 3)){
            cout << ans << "\n";
            return;
        }
        
        if (deg[l.back()] >= 3){
            reverse(r.begin(), r.end());
            for (int x : r) vec.push_back(x);
            vec.push_back(s);
            for (int x : l) vec.push_back(x);
        } else if (deg[r.back()] >= 3){
            reverse(l.begin(), l.end());
            for (int x : l) vec.push_back(x);
            vec.push_back(s);
            for (int x : r) vec.push_back(x);
        } else {
            reverse(l.begin(), l.end());
            for (int x : l) vec.push_back(x);
            vec.push_back(s);
            for (int x : r) vec.push_back(x);
            both = true;
        }
    } else {
        dfs(dfs, adj[s][0], s);
        auto l = comp;
        
        vec.push_back(s);
        for (int x : l){
            vec.push_back(x);
        }
        
        if (deg[l.back()] == 1){
            both = true;
        }
    }
    
    vector <int> pos(n + 1, -1);
    for (int i = 0; i < vec.size(); i++){
        pos[vec[i]] = i;
    }
    
    {
        int tt = vec.size() - pos[s] - 1;
        
        vector<vector<mint>> dp(n + 1, vector<mint>(2, 0));
        for (int i = 1; i <= n; i++){
            if (pos[i] == -1 || pos[i] > pos[s]){
                dp[i][0] = 1;
            }
        }
        
        for (int t = 1; t <= k; t++){
            vector<vector<mint>> ndp(n + 1, vector<mint>(2, 0));
            
            for (int u = 1; u <= n; u++){
                for (int i = 0; i < 2; i++){
                    for (int v : adj[u]){
                        if (pos[v] == -1 && t >= tt && i == 0) continue;
                        ndp[v][i | (pos[v] == 0)] += dp[u][i];
                    }
                }
            }
            
            dp = ndp;
        }
        
        for (int i = 1; i <= n; i++){
            ans -= dp[i][1];
        }
    }
    
    if (both){
        vector<vector<mint>> dp(n + 1, vector<mint>(2, 0));
        for (int i = 1; i <= n; i++){
            if (pos[i] < pos[s]){
                dp[i][0] = 1;
            }
        }
        
        for (int t = 1; t <= k; t++){
            vector<vector<mint>> ndp(n + 1, vector<mint>(2, 0));
            
            for (int u = 1; u <= n; u++){
                for (int i = 0; i < 2; i++){
                    for (int v : adj[u]){
                        ndp[v][i | (pos[v] == vec.size() - 1)] += dp[u][i];
                    }
                }
            }
            
            dp = ndp;
        }
        
        for (int i = 1; i <= n; i++){
            ans -= dp[i][1];
        }
    }
    
    cout << ans << "\n";
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}