RANDCOLORING - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: sushil2006
Tester: iceknight1093
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Dynamic programming, elementary combinatorics, DFS/DSU

PROBLEM:

For a graph with N vertices, consider coloring its vertices as follows:

  • For each i from 1 to N in order, let B and R be the numbers of vertices currently colored blue and red, respectively.
    Color vertex i red with probability \frac{R+1}{R+B+2}, and blue otherwise.

Let f(G) be the probability that the coloring obtained this way is proper.

You’re given N and M, along with M edges.
Starting with an empty graph on N vertices, add the edges one at a time and find f(G), modulo 998244353, after each addition.

EXPLANATION:

Let’s first try to find f(G) for a fixed graph G.

As an initial observation, note that when a vertex i is being colored, the denominator of the probability (which is (R+B+2)) will always equal exactly i+1, because (R+B) just equals the total number of vertices colored so far (which is i-1).

Consider a specific coloring of its vertices. What’s the probability that we attain exactly this coloring?
In particular, let’s look at only the vertices colored red: suppose they are x_1, x_2, \ldots, x_k.
Then,

  • When x_1 was colored red, there were no red vertices.
    So, the probability of this happening would’ve been exactly \frac{1}{x_1 + 1}.
  • When x_2 was colored red, the only other red vertex was x_1.
    So, the probability of it being red would’ve been \frac{2}{x_2 + 1}.
  • More generally, for any x_i, the probability that it’s colored red is exactly \frac{i}{x_i + 1}.

Notice that when you multiply all these together, the numerator is exactly k!.
In fact, applying the same logic to the vertices colored blue, we see that the numerator there is (N-k)! (after all, every vertex not colored red will be colored blue).
Further, vertex i contributes a denominator of i+1 regardless of its color, so the product of the denominators of all the probabilities is 2 \cdot 3 \cdots (N+1) = (N+1)!.

So, the probability of getting this specific configuration is

\frac{k! \cdot (N-k)!}{(N+1)!}
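As a quick sanity check of this formula, here is a small brute-force sketch (the function names are our own, not from the reference solutions). It replays the coloring process for every one of the 2^N outcomes on a tiny N, keeps the numerator and denominator as exact integers, and checks that they equal k!(N-k)! and (N+1)!. It assumes N \le 12 so that (N+1)! fits in 64 bits.

```cpp
#include <cassert>
#include <cstdint>

// Exact factorial; fits in uint64_t for n <= 20.
uint64_t factorial(int n) {
    uint64_t r = 1;
    for (int i = 2; i <= n; ++i) r *= i;
    return r;
}

// Replays the coloring process for all 2^n outcomes (bit i of mask set means
// vertex i+1 is red) and checks that each outcome's probability equals
// k!(n-k)!/(n+1)!, comparing numerator and denominator exactly.
// Assumes n <= 12 so that (n+1)! fits in 64 bits.
bool specificColoringFormulaHolds(int n) {
    for (uint32_t mask = 0; mask < (1u << n); ++mask) {
        uint64_t num = 1, den = 1;
        int red = 0, blue = 0;
        for (int i = 0; i < n; ++i) {
            den *= red + blue + 2;  // = i + 2, i.e. (1-based vertex index) + 1
            if ((mask >> i) & 1) { num *= red + 1; ++red; }
            else                 { num *= blue + 1; ++blue; }
        }
        if (num != factorial(red) * factorial(blue)) return false;
        if (den != factorial(n + 1)) return false;
    }
    return true;
}
```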

In particular, observe that this probability doesn’t really depend on which vertices are red and which are blue: it only depends on how many are red and how many are blue.

So, suppose we’re able to count the number of valid colorings with exactly k vertices being red; say c_k.
Then, the required probability would just be

\sum_{k=0}^N c_k \frac{k! (N-k)!}{(N+1)!}
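A minimal sketch of evaluating this sum modulo 998244353 (the modulus both reference solutions use), assuming the counts c_k are already known; `probabilityFromCounts` and `power` are hypothetical helper names. For the empty graph every coloring is proper, so c_k = \binom{N}{k} and the sum collapses to 1, which makes a handy test.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

const uint64_t MOD = 998244353;  // modulus used by both reference solutions

// Fast modular exponentiation.
uint64_t power(uint64_t a, uint64_t e) {
    uint64_t r = 1;
    a %= MOD;
    while (e) {
        if (e & 1) r = r * a % MOD;
        a = a * a % MOD;
        e >>= 1;
    }
    return r;
}

// Given c[k] = number of proper colorings with exactly k red vertices
// (k = 0..N), returns sum_k c[k] * k! * (N-k)! / (N+1)!  (mod MOD).
uint64_t probabilityFromCounts(const std::vector<uint64_t>& c) {
    int n = (int)c.size() - 1;
    std::vector<uint64_t> fact(n + 2, 1);
    for (int i = 1; i <= n + 1; ++i) fact[i] = fact[i - 1] * i % MOD;
    uint64_t sum = 0;
    for (int k = 0; k <= n; ++k)
        sum = (sum + c[k] % MOD * fact[k] % MOD * fact[n - k]) % MOD;
    return sum * power(fact[n + 1], MOD - 2) % MOD;  // divide via Fermat inverse
}
```

For instance, N = 2 with the single edge (1, 2) has c = (0, 2, 0), giving 2 \cdot 1! \cdot 1! / 3! = 1/3.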

Now, we turn our attention to computing c_k.

First, observe that we’re essentially looking for a 2-coloring of our graph.
For such a coloring to exist at all, the graph has to be bipartite.
So, if the graph is not bipartite, we have c_k = 0 for all k.

From now on, we’ll only deal with bipartite graphs.
Consider some connected component of a bipartite graph; say with a vertices on one side and b on the other.
Then, from this component, either a vertices will be red, or b will.

More generally, if there are m components, with the i-th component having a_i vertices on one side and b_i on the other, we have m pairs (a_i, b_i).
c_k then equals the number of ways of picking exactly one element from each of these pairs such that their total sum equals k.

If this setup feels familiar, it should - after all, it’s just a variant on the classical knapsack problem!

Let dp(i, x) denote the number of ways of getting a sum of exactly x from the first i pairs, with dp(0, 0) = 1 and dp(0, x) = 0 for x > 0.
Then, dp(i, x) = dp(i-1, x-a_i) + dp(i-1, x - b_i), depending on which element of the i-th pair we choose, and finally c_k = dp(m, k).
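A direct implementation of this recurrence, as a sketch: it keeps only one layer of the dp at a time and works modulo 998244353; `countPairSums` is a hypothetical name. A pair with a_i = b_i correctly contributes both of its (equal) choices.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

const uint64_t MOD = 998244353;

// dp(i, x): ways to pick one element from each of the first i pairs with
// total x. dp(0, 0) = 1; dp(i, x) = dp(i-1, x - a_i) + dp(i-1, x - b_i).
// Returns c[0..n] = dp(m, .) modulo MOD, where m = number of pairs.
std::vector<uint64_t> countPairSums(const std::vector<std::pair<int, int>>& pairs, int n) {
    std::vector<uint64_t> dp(n + 1, 0), next(n + 1);
    dp[0] = 1;
    for (auto [a, b] : pairs) {
        std::fill(next.begin(), next.end(), 0);
        for (int x = 0; x <= n; ++x) {
            if (x >= a) next[x] = (next[x] + dp[x - a]) % MOD;  // pick a_i
            if (x >= b) next[x] = (next[x] + dp[x - b]) % MOD;  // pick b_i
        }
        std::swap(dp, next);
    }
    return dp;
}
```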

This dynamic programming runs in \mathcal{O}(N^2) time, since there are at most N pairs and every sum lies between 0 and N. So, for a fixed graph on N vertices and M edges, we now know how to solve the problem in \mathcal{O}(N^2 + M) time:

  • Check if the graph is bipartite (say, with BFS/DFS); if not, the answer is 0.
  • If it is bipartite, find all connected components and their respective pairs; and run the above dp.
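The first two steps can be sketched with a plain BFS 2-coloring. The function name and the 0-indexed edge-list input are assumptions for illustration; it returns false on an odd cycle and otherwise produces one (a_i, b_i) pair per component, ready to feed into the dp.

```cpp
#include <cassert>
#include <queue>
#include <utility>
#include <vector>

// Tries to 2-color an undirected graph on vertices 0..n-1.
// On success fills `pairs` with one (a_i, b_i) per connected component:
// a_i = vertices colored like the component's BFS root, b_i = the rest.
// Returns false (leaving `pairs` partial) if the graph is not bipartite.
bool componentPairs(int n, const std::vector<std::pair<int, int>>& edges,
                    std::vector<std::pair<int, int>>& pairs) {
    std::vector<std::vector<int>> adj(n);
    for (auto [u, v] : edges) { adj[u].push_back(v); adj[v].push_back(u); }
    std::vector<int> color(n, -1);
    pairs.clear();
    for (int s = 0; s < n; ++s) {
        if (color[s] != -1) continue;
        int cnt[2] = {0, 0};
        std::queue<int> q;
        color[s] = 0;
        q.push(s);
        while (!q.empty()) {
            int x = q.front(); q.pop();
            ++cnt[color[x]];
            for (int y : adj[x]) {
                if (color[y] == -1) { color[y] = color[x] ^ 1; q.push(y); }
                else if (color[y] == color[x]) return false;  // odd cycle
            }
        }
        pairs.push_back({cnt[0], cnt[1]});
    }
    return true;
}
```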

Of course, this is way too slow since we’d need to rerun it after each edge addition.
Let’s see how it can be optimized.
Consider what happens when a new edge is added:

  • If the graph was already not bipartite, the answer is 0 anyway; so nothing needs to be done.
  • If the graph is bipartite and the new edge is between two different components, it remains bipartite.
    However, this can happen at most N-1 times, since each such edge reduces the number of components by one.
  • If the new edge is between two vertices of the same component:
    • If these two vertices have the same color, the graph stops being bipartite, and the answer is 0.
    • Otherwise, nothing changes at all, so we don’t need to rerun the dp!
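Both reference solutions tell these cases apart by storing each vertex's component label and color, and relabeling the merged component with a DFS over the spanning forest. Since the prerequisites also mention DSU, here is a sketch of a different technique swapped in for illustration: a union-find with parity, where each vertex stores its color relative to its parent. `ParityDSU` and its members are hypothetical names; `side` additionally maintains each component's (a, b) pair, which is what the dp needs. It uses path compression only (no union by size), which is enough for a sketch.

```cpp
#include <cassert>
#include <numeric>
#include <utility>
#include <vector>

// Union-find with parity: par[x] is the color of x relative to parent[x].
// side[r] = {#vertices with parity 0, #vertices with parity 1} w.r.t. root r.
struct ParityDSU {
    std::vector<int> parent, par;
    std::vector<std::pair<int, int>> side;
    explicit ParityDSU(int n) : parent(n), par(n, 0), side(n, {1, 0}) {
        std::iota(parent.begin(), parent.end(), 0);
    }
    // Returns {root, parity of x relative to root}, compressing the path.
    std::pair<int, int> find(int x) {
        if (parent[x] == x) return {x, 0};
        auto [r, p] = find(parent[x]);
        parent[x] = r;
        par[x] ^= p;
        return {r, par[x]};
    }
    enum Result { MERGED, SAME_OK, ODD_CYCLE };
    // Classifies edge (u, v) into the three cases above (the caller handles
    // "already not bipartite").
    Result addEdge(int u, int v) {
        auto [ru, pu] = find(u);
        auto [rv, pv] = find(v);
        if (ru == rv) return pu == pv ? ODD_CYCLE : SAME_OK;
        // u and v need different colors, so rv's parity relative to ru
        // must be pu ^ pv ^ 1.
        int flip = pu ^ pv ^ 1;
        parent[rv] = ru;
        par[rv] = flip;
        auto [z, o] = side[rv];
        if (flip) std::swap(z, o);
        side[ru].first += z;
        side[ru].second += o;
        return MERGED;
    }
};
```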

So, we need to run the dynamic programming only when components change, which, as noted above, happens at most N-1 times.
Further, since edges within components can essentially be ignored, what we’re really maintaining is a spanning forest of the graph, which allows the bipartite check/coloring to be found in linear time whenever it changes.

However, our dp is still quadratic, so the overall complexity remains \mathcal{O}(N^3 + M), which is too slow.


To further optimize this, we use one final trick: removing an element from a knapsack (technique 5 here).

The idea here is that the order in which elements are added to the knapsack doesn’t matter, so to remove any element you can pretend it was the last one added and simply reverse the updates it made.
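For the counting version of subset-sum, adding an item of weight w is the update dp[s] += dp[s-w] for s from high to low, and undoing it is dp[s] -= dp[s-w] for s from low to high (so dp[s-w] already holds its "before" value). A minimal sketch with our own function names, assuming w \ge 1: a weight-0 item doubles every entry and cannot be undone by this subtraction, so it needs separate handling.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

const uint64_t MOD = 998244353;

// dp[s] = number of subsets (of the items added so far) with sum s, mod MOD.
// Adding an item of weight w >= 1: iterate s downward so it is used once.
void addItem(std::vector<uint64_t>& dp, int w) {
    assert(w >= 1);
    for (int s = (int)dp.size() - 1; s >= w; --s)
        dp[s] = (dp[s] + dp[s - w]) % MOD;
}

// Removing it: pretend it was the last item added and reverse the update,
// iterating s upward so dp[s - w] has already been restored.
void removeItem(std::vector<uint64_t>& dp, int w) {
    assert(w >= 1);
    for (int s = w; s < (int)dp.size(); ++s)
        dp[s] = (dp[s] + MOD - dp[s - w]) % MOD;
}
```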

However, this doesn’t apply directly to our situation with pairs: the classical knapsack problem is “you either take this item, or you don’t”, not “take one element from each pair”.
Fortunately, it’s fairly easy to transform our pair version into the classical one.

Let’s first take a_i from every pair, for a total of \sum a_i.
Now, for the i-th pair,

  • If we take b_i instead, the change in sum is (b_i - a_i).
  • If we remain with a_i, there’s no change in sum.

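This is now analogous to the original knapsack (rather, subset-sum) problem: the i-th element has value (b_i - a_i), and we can either choose to take it (meaning we go with b_i), or not take it (meaning we stay with a_i).
So, c_k is exactly the number of subsets of these values with sum k - \sum a_i.
Orienting every pair so that a_i \le b_i keeps all the values non-negative; a pair with a_i = b_i has value 0 and simply doubles every count, so it needs separate handling when it is removed (the author's code keeps such pairs as a separate power of 2, while the editorialist's code divides by 2).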
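The transformation itself, as a sketch: `pairCountsViaSubsetSum` (a hypothetical name) orients each pair, accumulates \sum a_i, runs an ordinary 0/1 subset-sum over the differences, and shifts the result back. On the same input it yields the same c_k as the direct pair dp.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

const uint64_t MOD = 998244353;

// Computes c[0..n] (one element picked from each pair, total = k) via the
// subset-sum view: start from sum(a_i), then optionally add d_i = b_i - a_i.
// Pairs are oriented so that a_i <= b_i, keeping every d_i >= 0.
std::vector<uint64_t> pairCountsViaSubsetSum(std::vector<std::pair<int, int>> pairs, int n) {
    int base = 0;                        // sum of the smaller sides
    std::vector<uint64_t> dp(n + 1, 0);  // dp[s]: subsets of the d_i with sum s
    dp[0] = 1;
    for (auto& [a, b] : pairs) {
        if (a > b) std::swap(a, b);
        base += a;
        int d = b - a;
        // Standard 0/1 update; with d == 0 it doubles every entry, matching
        // the two equal choices when a_i == b_i.
        for (int s = n; s >= d; --s) dp[s] = (dp[s] + dp[s - d]) % MOD;
    }
    std::vector<uint64_t> c(n + 1, 0);
    for (int k = base; k <= n; ++k) c[k] = dp[k - base];  // shift back by sum(a_i)
    return c;
}
```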

This now allows us to use the undo trick to remove one pair!
But wait, how is that useful here again?

Well, recall that we want to recompute the dp only when two components are merged into one.
When that happens, the only change is that the pairs corresponding to those two components are removed, and the pair corresponding to the new component is added in: the other components don’t change at all.
So, instead of recomputing the entire dp, we remove the items of those two components and then add the item of the new one.

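Each removal and insertion takes \mathcal{O}(N) time, and a merge needs only three of them (two removals and one insertion); recoloring the merged component and recomputing the answer sum take \mathcal{O}(N) as well.
Since merges happen at most N-1 times in total and every other edge is handled in \mathcal{O}(1), the overall complexity is \mathcal{O}(N^2 + M), which is now fast enough!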
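Putting the pieces together, here is a sketch of the incremental structure. `IncrementalColorings` is a hypothetical name; it assumes the caller already knows, for every merge, the two old pairs and the new merged pair (from a DFS recoloring as in the reference solutions, or a parity DSU), and it does not handle odd cycles, where the answer is simply 0. It mirrors the author's bookkeeping: `base` plays the role of `forced` and `zeros` that of `zero_cnt`.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

const uint64_t MOD = 998244353;

uint64_t power(uint64_t a, uint64_t e) {
    uint64_t r = 1;
    a %= MOD;
    while (e) {
        if (e & 1) r = r * a % MOD;
        a = a * a % MOD;
        e >>= 1;
    }
    return r;
}

// Subset-sum table over the differences d_i = b_i - a_i (with a_i <= b_i) of
// the current components, plus base = sum of a_i and zeros = number of
// components with d_i = 0 (each of those just doubles every count).
struct IncrementalColorings {
    int n, base = 0, zeros = 0;
    std::vector<uint64_t> dp, fact;

    explicit IncrementalColorings(int n_) : n(n_), dp(n_ + 1, 0), fact(n_ + 2, 1) {
        for (int i = 1; i <= n + 1; ++i) fact[i] = fact[i - 1] * i % MOD;
        dp[0] = 1;
        for (int i = 0; i < n; ++i) insert(0, 1);  // n isolated vertices
    }
    void insert(int a, int b) {
        if (a > b) std::swap(a, b);
        base += a;
        int d = b - a;
        if (d == 0) { ++zeros; return; }
        for (int s = n; s >= d; --s) dp[s] = (dp[s] + dp[s - d]) % MOD;
    }
    // Undo of insert: the same update, run in reverse order.
    void erase(int a, int b) {
        if (a > b) std::swap(a, b);
        base -= a;
        int d = b - a;
        if (d == 0) { --zeros; return; }
        for (int s = d; s <= n; ++s) dp[s] = (dp[s] + MOD - dp[s - d]) % MOD;
    }
    // Components with pairs p and q merge into one with pair r: O(n) total.
    void merge(std::pair<int, int> p, std::pair<int, int> q, std::pair<int, int> r) {
        erase(p.first, p.second);
        erase(q.first, q.second);
        insert(r.first, r.second);
    }
    // f(G) = 2^zeros * sum_s dp[s] * (base+s)! * (n-base-s)! / (n+1)!  (mod MOD)
    uint64_t probability() const {
        uint64_t sum = 0;
        for (int s = 0; base + s <= n; ++s)
            sum = (sum + dp[s] * fact[base + s] % MOD * fact[n - base - s]) % MOD;
        return sum * power(2, zeros) % MOD * power(fact[n + 1], MOD - 2) % MOD;
    }
};
```

For example, with N = 3, joining vertices 1 and 2 gives f(G) = 1/3, and then joining 2 and 3 (forming a path) gives 1/6.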

TIME COMPLEXITY:

\mathcal{O}(N^2 + M) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

using namespace std;
using namespace __gnu_pbds;

template<typename T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
typedef long long int ll;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;

#define fastio ios_base::sync_with_stdio(false); cin.tie(NULL)
#define pb push_back
#define endl '\n'
#define sz(a) (int)a.size()
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
#define conts continue
#define ceil2(x,y) ((x+y-1)/(y))
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
#define yes cout << "Yes" << endl
#define no cout << "No" << endl

#define rep(i,n) for(int i = 0; i < n; ++i)
#define rep1(i,n) for(int i = 1; i <= n; ++i)
#define rev(i,s,e) for(int i = s; i >= e; --i)
#define trav(i,a) for(auto &i : a)

template<typename T>
void amin(T &a, T b) {
    a = min(a,b);
}

template<typename T>
void amax(T &a, T b) {
    a = max(a,b);
}

#ifdef LOCAL
#include "debug.h"
#else
#define debug(x) 42
#endif

/*



*/

const int MOD = 998244353;
const int N = 5e3 + 5;
const int inf1 = int(1e9) + 5;
const ll inf2 = ll(1e18) + 5;

ll fact[N], ifact[N], pow2[N], ipow2[N];

ll bexp(ll a, ll b) {
    a %= MOD;
    if (a == 0) return 0;

    ll res = 1;

    while (b) {
        if (b & 1) res = res * a % MOD;
        a = a * a % MOD;
        b >>= 1;
    }

    return res;
}

ll invmod(ll a) {
    return bexp(a, MOD - 2);
}

ll ncr(ll n, ll r) {
    if (n < 0 or r < 0 or n < r) return 0;
    return fact[n] * ifact[r] % MOD * ifact[n - r] % MOD;
}

ll npr(ll n, ll r) {
    if (n < 0 or r < 0 or n < r) return 0;
    return fact[n] * ifact[n - r] % MOD;
}

void precalc(ll n) {
    fact[0] = 1;
    rep1(i, n) fact[i] = fact[i - 1] * i % MOD;

    ifact[n] = invmod(fact[n]);
    rev(i, n - 1, 0) ifact[i] = ifact[i + 1] * (i + 1) % MOD;

    pow2[0] = 1;
    rep1(i,n) pow2[i] = pow2[i-1]*2%MOD;

    ipow2[n] = invmod(pow2[n]);
    rev(i,n-1,0) ipow2[i] = ipow2[i+1]*2%MOD;
}

vector<ll> adj[N];
vector<ll> comp(N);
vector<ll> col(N);
vector<pll> colp(N);
pll col_cnt;
vector<ll> nodes;

void dfs(ll u, ll p, ll c, ll r){
    nodes.pb(u);
    comp[u] = r;
    col[u] = c;

    if(c == 0){
        col_cnt.ff++;
    }
    else{
        col_cnt.ss++;
    }

    trav(v,adj[u]){
        if(v == p) conts;
        dfs(v,u,c^1,r);
    }
}

void solve(int test_case)
{
    ll n,m; cin >> n >> m;
    rep1(i,n){
        adj[i].clear();
        comp[i] = i;
        col[i] = 0;
        colp[i] = {0,1};
    }

    vector<ll> dp(n+5);
    rep(i,n+1) dp[i] = ncr(n,i);

    ll forced = 0, zero_cnt = 0;
    ll ans = 0;

    auto upd_ans = [&](){
        ans = 0;

        rep(i,n+1){
            ll w = i+forced;
            if(w > n) break;
            ll add = fact[w]*fact[n-w]%MOD*dp[i]%MOD;
            ans = (ans+add)%MOD;
        }

        ans = ans*pow2[zero_cnt]%MOD;
        ans = ans*ifact[n+1]%MOD;
    };

    auto rem = [&](ll x){
        if(x == 0){
            zero_cnt--;
        }
        else{
            for(int i = x; i <= n; ++i){
                dp[i] = (dp[i]-dp[i-x]+MOD)%MOD;
            }
        }
    };

    auto add = [&](ll x){
        if(x == 0){
            zero_cnt++;
        }
        else{
            rev(i,n,x){
                dp[i] = (dp[i]+dp[i-x])%MOD;
            }
        }
    };

    bool bad = false;

    rep1(i,m){
        ll u,v; cin >> u >> v;
        if(bad){
            cout << ans << endl;
            conts;
        }

        if(comp[u] == comp[v]){
            if(col[u] == col[v]){
                bad = true;
                ans = 0;
            }
            cout << ans << endl;
            conts;
        }

        pll p1 = colp[u], p2 = colp[v];
        forced -= p1.ff;
        rem(p1.ss-p1.ff);
        forced -= p2.ff;
        rem(p2.ss-p2.ff);

        adj[u].pb(v), adj[v].pb(u);
        nodes.clear();
        col_cnt = {0,0};
        dfs(u,-1,0,u);
        if(col_cnt.ff > col_cnt.ss){
            swap(col_cnt.ff,col_cnt.ss);
        }

        trav(u,nodes){
            colp[u] = col_cnt;
        }

        forced += col_cnt.ff;
        add(col_cnt.ss-col_cnt.ff);
        upd_ans();

        cout << ans << endl;
    }
}

int main()
{
    precalc(N-1);

    fastio;

    int t = 1;
    cin >> t;

    rep1(i, t) {
        solve(i);
    }

    return 0;
}
Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

/**
 * Integers modulo p, where p is a prime
 * Source: Aeren (modified from tourist?)
 *         Modmul for 64-bit mod from kactl:ModMulLL
 * Works with p < 7.2e18 with x87 80-bit long double, and p < 2^52 ~ 4.5e12 with 64-bit
 */
template<typename T>
struct Z_p{
    using Type = typename decay<decltype(T::value)>::type;
    static vector<Type> MOD_INV;
    constexpr Z_p(): value(){ }
    template<typename U> Z_p(const U &x){ value = normalize(x); }
    template<typename U> static Type normalize(const U &x){
        Type v;
        if(-mod() <= x && x < mod()) v = static_cast<Type>(x);
        else v = static_cast<Type>(x % mod());
        if(v < 0) v += mod();
        return v;
    }
    const Type& operator()() const{ return value; }
    template<typename U> explicit operator U() const{ return static_cast<U>(value); }
    constexpr static Type mod(){ return T::value; }
    Z_p &operator+=(const Z_p &otr){ if((value += otr.value) >= mod()) value -= mod(); return *this; }
    Z_p &operator-=(const Z_p &otr){ if((value -= otr.value) < 0) value += mod(); return *this; }
    template<typename U> Z_p &operator+=(const U &otr){ return *this += Z_p(otr); }
    template<typename U> Z_p &operator-=(const U &otr){ return *this -= Z_p(otr); }
    Z_p &operator++(){ return *this += 1; }
    Z_p &operator--(){ return *this -= 1; }
    Z_p operator++(int){ Z_p result(*this); *this += 1; return result; }
    Z_p operator--(int){ Z_p result(*this); *this -= 1; return result; }
    Z_p operator-() const{ return Z_p(-value); }
    template<typename U = T>
    typename enable_if<is_same<typename Z_p<U>::Type, int>::value, Z_p>::type &operator*=(const Z_p& rhs){
        #ifdef _WIN32
        uint64_t x = static_cast<int64_t>(value) * static_cast<int64_t>(rhs.value);
        uint32_t xh = static_cast<uint32_t>(x >> 32), xl = static_cast<uint32_t>(x), d, m;
        asm(
            "divl %4; \n\t"
            : "=a" (d), "=d" (m)
            : "d" (xh), "a" (xl), "r" (mod())
        );
        value = m;
        #else
        value = normalize(static_cast<int64_t>(value) * static_cast<int64_t>(rhs.value));
        #endif
        return *this;
    }
    template<typename U = T>
    typename enable_if<is_same<typename Z_p<U>::Type, int64_t>::value, Z_p>::type &operator*=(const Z_p &rhs){
        uint64_t ret = static_cast<uint64_t>(value) * static_cast<uint64_t>(rhs.value) - static_cast<uint64_t>(mod()) * static_cast<uint64_t>(1.L / static_cast<uint64_t>(mod()) * static_cast<uint64_t>(value) * static_cast<uint64_t>(rhs.value));
        value = normalize(static_cast<int64_t>(ret + static_cast<uint64_t>(mod()) * (ret < 0) - static_cast<uint64_t>(mod()) * (ret >= static_cast<uint64_t>(mod()))));
        return *this;
    }
    template<typename U = T>
    typename enable_if<!is_integral<typename Z_p<U>::Type>::value, Z_p>::type &operator*=(const Z_p &rhs){
        value = normalize(value * rhs.value);
        return *this;
    }
    template<typename U>
    Z_p &operator^=(U e){
        if(e < 0) *this = 1 / *this, e = -e;
        Z_p res = 1;
        for(; e; *this *= *this, e >>= 1) if(e & 1) res *= *this;
        return *this = res;
    }
    template<typename U>
    Z_p operator^(U e) const{
        return Z_p(*this) ^= e;
    }
    Z_p &operator/=(const Z_p &otr){
        Type a = otr.value, m = mod(), u = 0, v = 1;
        if(a < (int)MOD_INV.size()) return *this *= MOD_INV[a];
        while(a){
            Type t = m / a;
            m -= t * a; swap(a, m);
            u -= t * v; swap(u, v);
        }
        assert(m == 1);
        return *this *= u;
    }
    template<typename U> friend const Z_p<U> &abs(const Z_p<U> &v){ return v; }
    Type value;
};
template<typename T> bool operator==(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value == rhs.value; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator==(const Z_p<T>& lhs, U rhs){ return lhs == Z_p<T>(rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator==(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) == rhs; }
template<typename T> bool operator!=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return !(lhs == rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator!=(const Z_p<T> &lhs, U rhs){ return !(lhs == rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator!=(U lhs, const Z_p<T> &rhs){ return !(lhs == rhs); }
template<typename T> bool operator<(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value < rhs.value; }
template<typename T> bool operator>(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value > rhs.value; }
template<typename T> bool operator<=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value <= rhs.value; }
template<typename T> bool operator>=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value >= rhs.value; }
template<typename T> Z_p<T> operator+(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator+(const Z_p<T> &lhs, U rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator+(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T> Z_p<T> operator-(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator-(const Z_p<T>& lhs, U rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator-(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T> Z_p<T> operator*(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator*(const Z_p<T>& lhs, U rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator*(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T> Z_p<T> operator/(const Z_p<T> &lhs, const Z_p<T> &rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator/(const Z_p<T>& lhs, U rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator/(U lhs, const Z_p<T> &rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T> istream &operator>>(istream &in, Z_p<T> &number){
    typename common_type<typename Z_p<T>::Type, int64_t>::type x;
    in >> x;
    number.value = Z_p<T>::normalize(x);
    return in;
}
template<typename T> ostream &operator<<(ostream &out, const Z_p<T> &number){ return out << number(); }

/*
using ModType = int;
struct VarMod{ static ModType value; };
ModType VarMod::value;
ModType &mod = VarMod::value;
using Zp = Z_p<VarMod>;
*/

// constexpr int mod = 1e9 + 7; // 1000000007
constexpr int mod = (119 << 23) + 1; // 998244353
// constexpr int mod = 1e9 + 9; // 1000000009
using Zp = Z_p<integral_constant<decay<decltype(mod)>::type, mod>>;

template<typename T> vector<typename Z_p<T>::Type> Z_p<T>::MOD_INV;
template<typename T = integral_constant<decay<decltype(mod)>::type, mod>>
void precalc_inverse(int SZ){
    auto &inv = Z_p<T>::MOD_INV;
    if(inv.empty()) inv.assign(2, 1);
    for(; inv.size() <= SZ; ) inv.push_back((mod - 1LL * mod / (int)inv.size() * inv[mod % (int)inv.size()]) % mod);
}

template<typename T>
vector<T> precalc_power(T base, int SZ){
    vector<T> res(SZ + 1, 1);
    for(auto i = 1; i <= SZ; ++ i) res[i] = res[i - 1] * base;
    return res;
}

template<typename T>
vector<T> precalc_factorial(int SZ){
    vector<T> res(SZ + 1, 1); res[0] = 1;
    for(auto i = 1; i <= SZ; ++ i) res[i] = res[i - 1] * i;
    return res;
}

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    auto fac = precalc_factorial<Zp>(5005);
    auto inv = fac;
    for (auto &x : inv) x = 1/x;
    auto C = [&] (int n, int r) {
        if (n < 0 or n < r) return Zp(0);
        return fac[n] * inv[r] * inv[n-r];
    };

    int t; cin >> t;
    while (t--) {
        int n, m; cin >> n >> m;

        vector<int> comp(n), col(n), white(n), black(n);
        vector<vector<int>> adj(n);
        for (int i = 0; i < n; ++i) {
            comp[i] = i;
            white[i] = 1;
            col[i] = 1;
        }
        
        Zp ans = 0;
        vector<Zp> dp(n+1);
        for (int i = 0; i <= n; ++i) dp[i] = C(n, i);
        for (int i = 0; i <= n; ++i) ans += dp[i]*fac[i]*fac[n-i];
        ans /= fac[n+1];

        bool bad = false;
        int tot = n;
        auto upd = [&] (int u, int v) {
            for (int c : {comp[u], comp[v]}) {
                int d = white[c] - black[c];
                for (int i = 0; i <= n; ++i) {
                    if (d == 0) dp[i] /= 2;
                    else if (i >= d) dp[i] -= dp[i-d];
                }
            }

            adj[u].push_back(v);
            adj[v].push_back(u);
            int c = comp[u];
            tot -= white[comp[u]] + white[comp[v]];
            white[c] = black[c] = 0;
            vector<int> stk = {u};
            for (int i = 0; i < n; ++i) if (comp[i] == comp[u] or comp[i] == comp[v])
                col[i] = 0;
            
            col[u] = 1;
            while (!stk.empty()) {
                int x = stk.back();
                stk.pop_back();
                if (col[x] == 1) ++white[c];
                else ++black[c];
                comp[x] = c;
                for (auto y : adj[x]) {
                    if (col[y]) continue;
                    col[y] = 3 - col[x];
                    stk.push_back(y);
                }
            }
            if (white[c] < black[c]) swap(white[c], black[c]);

            int d = white[c] - black[c];
            for (int i = n; i >= 0; --i) {
                if (i >= d) dp[i] += dp[i-d];
            }
            tot += white[c];
            
            ans = 0;
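            // Stop at i = tot: a larger i would make ct negative and index fac out of range.
            for (int i = 0; i <= tot; ++i) {
                int ct = tot - i;  // size of one color class after i flips
                ans += dp[i] * fac[ct] * fac[n-ct];
            }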
            ans /= fac[n+1];
        };

        while (m--) {
            int u, v; cin >> u >> v;
            --u, --v;
            if (comp[u] == comp[v]) {
                if (col[u] == col[v]) bad = true;
            }
            else upd(u, v);
            if (bad) ans = 0;
            cout << ans << '\n';
        }
    }
}