BINSTRRAND - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3

Authors: Jeevan Jyot Singh
Testers: Abhinav Sharma and Lavish Gupta
Editorialist: Nishank Suresh

DIFFICULTY:

Easy-Medium

PREREQUISITES:

Probabilities, Matrix Exponentiation

PROBLEM:

The BSRand machine takes a binary string A of length N as input and produces another string B of length N as follows:
For each i from 1 to N,

  • Choose a random index 1\leq j \leq N
  • Choose a random index 1\leq k \leq N
  • Set B_i := A_j \oplus A_k

You are given two binary strings A and B. The BSRand is applied T times successively on A. What is the probability that the resulting string is B?

QUICK EXPLANATION:

  • Build an (N+1)\times (N+1) matrix M such that for 0\leq i, j \leq N,
M_{i, j} = \binom{N}{j} \cdot p_i^j \cdot (1-p_i)^{N-j}

where p_i = \frac{2i(N-i)}{N^2}.

  • Suppose A contains x ones and B contains y ones. The final answer is then (M^T)_{x, y} divided by \binom{N}{y}.

EXPLANATION:

Let’s analyze what happens when one iteration of the BSRand is run.
Each character of the output string can be treated independently, so let’s focus just on the first one.
What is the probability that this character is a '1'?
For this to happen, there are two possibilities:

  • A_j = 1 and A_k = 0, or
  • A_j = 0 and A_k = 1

If A has x ones, then because j and k are chosen independently and uniformly at random, the probability of the first case is x\cdot (N-x) / N^2. By symmetry, the second case has exactly the same probability.

So, the probability that the first character is a '1' is p = 2 x\cdot (N-x) / N^2.

This applies to every single character of the output string. Thus, given a string B with y ones, the probability that the output string is B is p^y (1-p)^{N-y}.
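The T = 1 computation above can be sketched in a few lines. This is only an illustration in floating point (the actual problem asks for the answer modulo 998244353), and the function names `one_probability` and `single_step_probability` are made up for this sketch.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <string>

// Probability that one output character is '1' when the input has x ones
// among n characters: p = 2 * x * (n - x) / n^2.
double one_probability(int n, int x) {
    assert(n > 0 && 0 <= x && x <= n);
    return 2.0 * x * (n - x) / (static_cast<double>(n) * n);
}

// Probability that a single application of BSRand to `a` yields exactly `b`:
// p^y * (1 - p)^(n - y), where y is the number of ones in b.
double single_step_probability(const std::string& a, const std::string& b) {
    assert(!a.empty() && a.size() == b.size());
    const int n = static_cast<int>(a.size());
    const int x = static_cast<int>(std::count(a.begin(), a.end(), '1'));
    const int y = static_cast<int>(std::count(b.begin(), b.end(), '1'));
    const double p = one_probability(n, x);
    return std::pow(p, y) * std::pow(1.0 - p, n - y);
}
```

For instance, with A = "10" we get p = 2\cdot 1\cdot 1 / 4 = 1/2, so every string B of length 2 has probability 1/4.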

So, if T = 1 we have solved the problem. However, this process can’t be repeated directly for T > 1: once BSRand has been applied even once, we no longer know the exact number of ones in the string.

But we don’t need to! Note that N is small (at most 100, going by the bounds in the tester’s input checker below), so we can use that to our advantage.

ENTER MATRICES

Note that above, when we calculated the probability that the output string had y ones, we didn’t really care about what exactly the input string was — we only used the fact that it had exactly x ones.

Also note that we computed the probability that a specific string with y ones was formed. We can relax this condition a little and compute just the probability that the output string has y ones. There are \binom{N}{y} such strings, each being equally likely.

Under this relaxed condition, suppose we are able to calculate the probability that the final string has y ones. Then we can simply divide this by \binom{N}{y} to get the final answer, so we’ll concentrate on how to do the first part.
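If you want to convince yourself that all strings with the same number of ones really are equally likely, a brute force over every possible sequence of (j, k) choices works for tiny N. The sketch below is not part of any reference solution; the name `brute_force_distribution` and the cap N \leq 4 are assumptions made to keep the enumeration small.

```cpp
#include <cassert>
#include <map>
#include <string>

// Enumerates every way BSRand can pick (j, k) for each of the n output
// positions and returns the exact probability of each output string.
std::map<std::string, double> brute_force_distribution(const std::string& a) {
    const int n = static_cast<int>(a.size());
    assert(n >= 1 && n <= 4);                 // n^(2n) outcomes, keep it tiny
    const long long pairs = 1LL * n * n;      // (j, k) choices per position
    long long total = 1;
    for (int i = 0; i < n; ++i) total *= pairs;

    std::map<std::string, double> dist;
    for (long long code = 0; code < total; ++code) {
        long long rest = code;
        std::string b(n, '0');
        for (int i = 0; i < n; ++i) {
            const int choice = static_cast<int>(rest % pairs);
            rest /= pairs;
            const int j = choice / n, k = choice % n;   // 0-indexed here
            b[i] = (a[j] != a[k]) ? '1' : '0';          // A_j xor A_k
        }
        dist[b] += 1.0 / static_cast<double>(total);
    }
    return dist;
}
```

With A = "110" we have p = 4/9, and the brute force gives the same probability, (4/9)^2 \cdot (5/9) = 80/729, for each of "110", "101" and "011".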

Notice that we have already, in some sense, completely described the working of the BSRand. If it is given an input string with x ones, the probability that the output string has y ones is \binom{N}{y} \cdot p^y (1-p)^{N-y}, where p = 2x\cdot (N-x) / N^2.
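As a rough illustration of this formula, here is a floating-point sketch that tabulates the probability of going from x ones to y ones for all 0 \leq x, y \leq N. The helper names `binom` and `build_transition` are assumptions for this sketch, and doubles are used only for readability; a real submission would work modulo 998244353.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Binomial coefficient C(n, k) as a double (adequate for small n).
double binom(int n, int k) {
    assert(0 <= k && k <= n);
    double r = 1.0;
    for (int i = 1; i <= k; ++i) r = r * (n - k + i) / i;
    return r;
}

// m[x][y] = C(n, y) * p^y * (1 - p)^(n - y), with p = 2 * x * (n - x) / n^2.
Matrix build_transition(int n) {
    assert(n > 0);
    Matrix m(n + 1, std::vector<double>(n + 1, 0.0));
    for (int x = 0; x <= n; ++x) {
        const double p = 2.0 * x * (n - x) / (static_cast<double>(n) * n);
        for (int y = 0; y <= n; ++y)
            m[x][y] = binom(n, y) * std::pow(p, y) * std::pow(1.0 - p, n - y);
    }
    return m;
}
```

For N = 2 the rows come out as (1, 0, 0), (1/4, 1/2, 1/4) and (1, 0, 0). Each row sums to 1, as it must, since each row is a probability distribution over the number of ones in the output.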

Suppose we create an (N+1)\times (N+1) matrix M, where M_{x, y} is the value described above for x and y.
Suppose we also have a 1\times (N+1) row vector v, where v_i is the probability that the input string to BSRand has exactly i ones.
Let w = vM. Notice that w is then a 1\times (N+1) row vector with w_y = \sum_{x} v_x M_{x, y}, so w_y is the probability that the output string has exactly y ones.

But then we can simply use w as the input probability vector, which means that wM = vM^2 gives us the probability distribution of the number of ones when BSRand is applied twice.

Generalizing this, we can see that applying BSRand T times simply corresponds to vM^T (here M^T denotes the T-th power of M, not its transpose).

In our case, the probability vector is such that v_i = 0 when i \neq x, and v_x = 1 (where x is the number of ones in A).
It’s easy to see that multiplying this vector by M^T and taking the y-th entry of the result (where y is the number of ones in B) is the same as (M^T)_{x, y}.
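To make the propagation step concrete, the sketch below pushes a distribution through a transition matrix one application at a time. It assumes the convention that entry (x, y) of the matrix is the probability of going from x ones to y ones, so the distribution is treated as a row vector multiplied on the left. The names `apply_once` and `apply_times` are illustrative, and this naive loop costs \mathcal{O}(N^2 T), so it is only meant for checking small cases.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;
using Mat = std::vector<Vec>;

// One application of the machine: w[y] = sum over x of v[x] * m[x][y].
Vec apply_once(const Vec& v, const Mat& m) {
    assert(v.size() == m.size());
    Vec w(m.empty() ? 0 : m[0].size(), 0.0);
    for (std::size_t x = 0; x < m.size(); ++x)
        for (std::size_t y = 0; y < w.size(); ++y)
            w[y] += v[x] * m[x][y];
    return w;
}

// Distribution of the number of ones after t applications.
Vec apply_times(Vec v, const Mat& m, int t) {
    for (int step = 0; step < t; ++step) v = apply_once(v, m);
    return v;
}
```

For N = 2 the matrix has rows (1, 0, 0), (1/4, 1/2, 1/4), (1, 0, 0). Starting from v = (0, 1, 0), i.e. a string with exactly one 1, two applications give (5/8, 1/4, 1/8), which is exactly row 1 of the squared matrix.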

M^T can be computed in \mathcal{O}(N^3 \log T) by using binary exponentiation, and each entry of M can be computed in \mathcal{O}(1) or \mathcal{O}(\log MOD) depending on how modular inverses are dealt with; either way the exponentiation makes the complexity \mathcal{O}(N^3 \log T).

Finally, don’t forget to divide (M^T)_{x, y} by \binom{N}{y} to account for the fact that we want a specific string.
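Putting the pieces together, here is a compact sketch of the whole computation modulo 998244353. It is written independently of the three reference solutions below and favours brevity over speed; the function names (`power`, `inverse`, `multiply`, `matrix_power`, `bsrand_probability`) are choices made for this sketch, and modular inverses are taken with Fermat's little theorem.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

using u64 = std::uint64_t;
const u64 MOD = 998244353;
using Mat = std::vector<std::vector<u64>>;

u64 power(u64 b, u64 e) {                 // b^e mod MOD
    u64 r = 1; b %= MOD;
    for (; e; e >>= 1, b = b * b % MOD) if (e & 1) r = r * b % MOD;
    return r;
}
u64 inverse(u64 a) { return power(a, MOD - 2); }   // MOD is prime

Mat multiply(const Mat& a, const Mat& b) {
    const std::size_t n = a.size();
    Mat c(n, std::vector<u64>(n, 0));
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t k = 0; k < n; ++k)
            for (std::size_t j = 0; j < n; ++j)
                c[i][j] = (c[i][j] + a[i][k] * b[k][j]) % MOD;
    return c;
}

Mat matrix_power(Mat base, u64 e) {       // binary exponentiation
    Mat result(base.size(), std::vector<u64>(base.size(), 0));
    for (std::size_t i = 0; i < base.size(); ++i) result[i][i] = 1;
    for (; e; e >>= 1, base = multiply(base, base))
        if (e & 1) result = multiply(result, base);
    return result;
}

// Probability (mod MOD) that t applications of BSRand turn a into b.
u64 bsrand_probability(const std::string& a, const std::string& b, u64 t) {
    assert(!a.empty() && a.size() == b.size());
    const int n = static_cast<int>(a.size());
    std::vector<u64> choose(n + 1, 0);    // choose[y] = C(n, y) via Pascal
    choose[0] = 1;
    for (int i = 1; i <= n; ++i)
        for (int j = i; j >= 1; --j) choose[j] = (choose[j] + choose[j - 1]) % MOD;
    const u64 inv_n2 = inverse(static_cast<u64>(n) * n % MOD);
    Mat m(n + 1, std::vector<u64>(n + 1, 0));
    for (int x = 0; x <= n; ++x) {
        const u64 p = 2ULL * x * (n - x) % MOD * inv_n2 % MOD;
        const u64 q = (1 + MOD - p) % MOD;
        for (int y = 0; y <= n; ++y)
            m[x][y] = choose[y] * power(p, y) % MOD * power(q, n - y) % MOD;
    }
    const Mat mt = matrix_power(m, t);
    const int x = static_cast<int>(std::count(a.begin(), a.end(), '1'));
    const int y = static_cast<int>(std::count(b.begin(), b.end(), '1'));
    return mt[x][y] * inverse(choose[y]) % MOD;   // a specific string, not just a count
}
```

As a small check, `bsrand_probability("10", "11", 1)` equals the modular inverse of 4, matching the hand computation p^2 = 1/4 for A = "10".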

TIME COMPLEXITY:

\mathcal{O}(N^3 \log T) per test case.

SOLUTIONS:

Setter's Solution (C++)
#ifdef WTSH
    #include <wtsh.h>
#else
    #include <bits/stdc++.h>
    using namespace std;
    #define dbg(Z...)
#endif

#define IOS ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define endl "\n"
#define sz(w) (int)(w.size())
using pii = pair<int, int>;

// Modint template: https://github.com/the-hyp0cr1t3/CC/blob/master/%E6%9C%AB%20Snippets/ModInt.cpp
template<const int& MOD>
struct Mint {
    using T = typename decay<decltype(MOD)>::type; T v;
    Mint(int64_t v = 0) { if(v < 0) v = v % MOD + MOD; if(v >= MOD) v %= MOD; this->v = T(v); }
    Mint(uint64_t v) { if (v >= MOD) v %= MOD; this->v = T(v); }
    Mint(int v): Mint(int64_t(v)) {}
    Mint(unsigned v): Mint(uint64_t(v)) {}
    explicit operator int() const { return v; }
    explicit operator int64_t() const { return v; }
    explicit operator uint64_t() const { return v; }
    friend istream& operator>>(istream& in, Mint& m) { int64_t v_; in >> v_; m = Mint(v_); return in; } 
    friend ostream& operator<<(ostream& os, const Mint& m) { return os << m.v; }

    static T inv(T a, T m) {
        T g = m, x = 0, y = 1;
        while(a != 0) {
            T q = g / a;
            g %= a; swap(g, a);
            x -= q * y; swap(x, y);
        } return x < 0? x + m : x;
    }

    static unsigned fast_mod(uint64_t x, unsigned m = MOD) {
    #if !defined(_WIN32) || defined(_WIN64)
        return unsigned(x % m);
    #endif // x must be less than 2^32 * m
        unsigned x_high = unsigned(x >> 32), x_low = unsigned(x), quot, rem;
        asm("divl %4\n" : "=a" (quot), "=d" (rem) : "d" (x_high), "a" (x_low), "r" (m));
        return rem;
    }

    Mint inv() const { return Mint(inv(v, MOD)); }
    Mint operator-() const { return Mint(v? MOD-v : 0); }
    Mint& operator++() { v++; if(v == MOD) v = 0; return *this; }
    Mint& operator--() { if(v == 0) v = MOD; v--; return *this; }
    Mint operator++(int) { Mint a = *this; ++*this; return a; }
    Mint operator--(int) { Mint a = *this; --*this; return a; }
    Mint& operator+=(const Mint& o) { v += o.v; if (v >= MOD) v -= MOD; return *this; }
    Mint& operator-=(const Mint& o) { v -= o.v; if (v < 0) v += MOD; return *this; }
    Mint& operator*=(const Mint& o) { v = fast_mod(uint64_t(v) * o.v); return *this; }
    Mint& operator/=(const Mint& o) { return *this *= o.inv(); }
    friend Mint operator+(const Mint& a, const Mint& b) { return Mint(a) += b; }
    friend Mint operator-(const Mint& a, const Mint& b) { return Mint(a) -= b; }
    friend Mint operator*(const Mint& a, const Mint& b) { return Mint(a) *= b; }
    friend Mint operator/(const Mint& a, const Mint& b) { return Mint(a) /= b; }
    friend bool operator==(const Mint& a, const Mint& b) { return a.v == b.v; }
    friend bool operator!=(const Mint& a, const Mint& b) { return a.v != b.v; }
    friend bool operator<(const Mint& a, const Mint& b) { return a.v < b.v; }
    friend bool operator>(const Mint& a, const Mint& b) { return a.v > b.v; }
    friend bool operator<=(const Mint& a, const Mint& b) { return a.v <= b.v; }
    friend bool operator>=(const Mint& a, const Mint& b) { return a.v >= b.v; }
    Mint operator^(int64_t p) {
        if(p < 0) return inv() ^ -p;
        Mint a = *this, res{1}; while(p > 0) {
            if(p & 1) res *= a;
            p >>= 1; if(p > 0) a *= a;
        } return res;
    }
};

const int MOD = 998244353;
using mint = Mint<MOD>;

// Matrix expo: https://github.com/the-hyp0cr1t3/CC/blob/master/%E6%9C%AB%20Snippets/Matrix%20(std::vector).cpp
template<typename T>
struct Matrix {
    int N, M; vector<vector<T>> a;
    
    Matrix(int n, int m): N(n), M(m) , a(n, vector<T>(m)) {}

    explicit Matrix(int n, int m, T x): Matrix(n, m) {
        for(int i = 0; i < min(N, M); i++) a[i][i] = x;
    }
    
    Matrix(initializer_list<vector<T>> x) {
        N = x.size(); M = x.begin()->size(); a.resize(N);
        for(int i = 0; i < x.size(); i++)
            a[i] = *(x.begin() + i), assert(a[i].size() == M);
    }
    
    vector<T>& operator[](size_t x) { return a[x]; }
    const vector<T>& operator[](size_t x) const { return a[x]; }

    friend ostream& operator<<(ostream& out, const Matrix& x) {
        for(int i = 0; i < x.N; i++) 
            for(int j = 0; j < x.M; j++) 
                out << x.a[i][j] << " \n"[j == x.M-1];
        return out;
    }

    Matrix& operator+=(const Matrix& o) {
        assert(N == o.N && M == o.M);
        for(int i = 0; i < N; i++)
            for(int j = 0; j < M; j++)
                a[i][j] += o[i][j];
        return *this;
    }

    Matrix& operator-=(const Matrix& o) {
        assert(N == o.N && M == o.M);
        for(int i = 0; i < N; i++)
            for(int j = 0; j < M; j++)
                a[i][j] -= o[i][j];
        return *this;
    }

    Matrix& operator*=(T x) {
        for(int i = 0; i < N; i++)
            for(int j = 0; j < M; j++)
                a[i][j] *= x;
        return *this;
    }

    Matrix operator*(const Matrix& o) const {
        assert(M == o.N);
        Matrix<T> res(N, o.M);
        for(int i = 0; i < N; i++)
            for(int j = 0; j < M; j++)
                for(int k = 0; k < o.M; k++)
                    res[i][k] += a[i][j] * o[j][k];
        return res;
    }

    Matrix& operator*=(const Matrix& o) {
        assert(N == M && o.N == o.M && N == o.N);
        return *this = move(*this * o);
    }

    template<typename U, typename = enable_if_t<is_integral<U>::value>>
    Matrix& operator^=(U x) {
        assert(x >= 0 && N == M);
        Matrix res(N, N, 1); while(x) {
            if(x & 1) res *= *this;
            x >>= 1; *this *= *this;
        } return *this = move(res);
    }

    Matrix operator+(const Matrix& o) const { Matrix res = *this; res += o; return res; }
    Matrix operator-(const Matrix& o) const { Matrix res = *this; res -= o; return res; }
    Matrix operator*(T o) const { Matrix res = *this; res *= o; return res; }
    template<typename U, typename = enable_if_t<is_integral<U>::value>>
    Matrix operator^(U x) const { Matrix res = *this; res ^= x; return res; }
};

const int N = 1005;

mint fac[N], invfac[N];

void precompute()
{
    fac[0] = fac[1] = 1;
    for(int i = 2; i < N; i++)
        fac[i] = fac[i-1] * i;
    invfac[N-1] = fac[N-1].inv();
    for(int i = N-2; i >= 0; i--)
        invfac[i] = invfac[i+1] * (i+1);
}

mint nCr(int n, int r)
{
    if(n < 0 or r < 0 or n < r)
        return mint(0);
    return fac[n] * invfac[r] * invfac[n-r];
}

int32_t main()
{
    IOS;
    precompute();
    int n, t; cin >> n >> t;
    string a, b; cin >> a >> b;
    Matrix<mint> trans(n + 1, n + 1);
    for(int i = 0; i <= n; i++)
    {
        for(int j = 0; j <= n; j++)
        {
            // going from i zeros to j zeros
            mint p_zero = (mint(i) * mint(i) / (mint(n) * mint(n))) + (mint(n - i) * mint(n - i) / (mint(n) * mint(n)));
            mint p_one = 1 - p_zero;
            assert(p_zero + p_one == 1);
            trans[i][j] = nCr(n, j) * (p_zero ^ j) * (p_one ^ (n - j));
        }
    }
    trans ^= t;
    int init_0 = count(a.begin(), a.end(), '0');
    int ltr_0 = count(b.begin(), b.end(), '0');
    mint ans = trans[init_0][ltr_0];
    // trans[init_0][ltr_0] is the probability of getting a string with ltr_0 number of 0s
    // There exist nCr(n, ltr_0) strings with ltr_0 number of 0s
    ans /= nCr(n, ltr_0);
    cout << ans << endl;
    return 0;
}
Tester's Solution (C++)
#include <bits/stdc++.h>
using namespace std;
#define ll long long
 
 
/*
------------------------Input Checker----------------------------------
*/
 
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
 
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }
 
            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }
 
            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
 
 
/*
------------------------Main code starts here----------------------------------
*/
 
const int MAX_T = 200000;
const int MAX_N = 100000;
const int MAX_SUM_LEN = 1000000;
 
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
 
int sum_len = 0;
int max_n = 0;
int yess = 0;
int nos = 0;
int total_ops = 0;
ll z = 998244353;
 
ll dp[2][101] ;

ll fact[1000] ;
ll power(ll a ,ll b , ll p){if(b == 0) return 1 ; ll c = power(a , b/2 , p) ; if(b%2 == 0) return ((c*c)%p) ; else return ((((c*c)%p)*a)%p) ;}
ll inverse(ll a ,ll n){return power(a , n-2 , n) ;}
ll ncr(ll n , ll r){if(n < r) return 0 ; return ((((fact[n] * inverse(fact[r] , z)) % z) * inverse(fact[n-r] , z)) % z);}

ll get_prob(ll n , ll i , ll val)
{
    ll p = (n-i)*(n-i) + (i*i) ;
    ll den = inverse(n*n , z) ;
    p *= den ;
    p %= z ;
    if(val == 1)
    {
        p = (((1-p)%z)+z)%z ;
    }
    return p ;
}

ll m ; // matrix is of m cross m
ll A[101][101], B[101][101], C[101][101] ; // C will store final answer
#define rep(i , n) for(ll i = 0 ; i < n ; i++)

void matrix_power(ll n)
{
    if(n == 0)
    {
        rep(i , m)
            C[i][i] = 1 ;
        return ;    
    }

    matrix_power(n/2) ;
    rep(i , m)
    {
        rep(j , m)
        {
            B[i][j] = 0 ;
            rep(k , m)
            {
                B[i][j] += (C[i][k] * C[k][j]) ;
                B[i][j] %= z ;
            }
        }
    }

    if(n % 2 == 1)
    {
        rep(i , m)
        {
            rep(j , m)
            {
                C[i][j] = 0 ;
                rep(k , m)
                {
                    C[i][j] += (B[i][k] * A[k][j]) ;
                    C[i][j] %= z ;
                }
            }
        }
    }
    else
    {
        rep(i , m)
        {
            rep(j , m)
            {
                C[i][j] = B[i][j] ;
            }
        }
    }
    return ;
}

void solve()
{   
    int n , t ;
    n = readIntSp(1 , 100);
    t = readIntLn(1 , 1000000);
    m = n+1 ;
    string a , b ;
    a = readStringLn(n , n);
    b = readStringLn(n , n);
    int cnt = 0 ;

    for(int i = 0 ; i < n ; i++)
    {
        assert(a[i] == '0' || a[i] == '1') ;
        assert(b[i] == '0' || b[i] == '1') ;
        cnt += (a[i] == '1') ;
    }
    
    dp[0][cnt] = 1 ;

    for(int i = 0 ; i <= n ; i++)
    {
        for(int j = 0 ; j <= n ; j++)
        {
            A[i][j] = (power(get_prob(n , j , 1) , i , z) * power(get_prob(n , j , 0) , n- i , z))%z ;
            A[i][j] *= ncr(n , i) ;
            A[i][j] %= z ;
        }
    }

    matrix_power(t) ;

    for(int i = 0 ; i <= n ; i++)
    {
        for(int j = 0 ; j <= n ; j++)
        {
            dp[1][i] += ((C[i][j] * dp[0][j])%z) ;
        }
        dp[1][i] %= z ;
    }

    cnt = 0 ;
    for(int i = 0 ; i < n ; i++)
        cnt += (b[i] == '1') ;

    ll ans = dp[1][cnt] * inverse(ncr(n , cnt) , z) ;
    ans %= z ;
    cout << ans << endl ;
    // for(int i = 0 ; i <= n ; i++)
    // {
    //     cout << dp[1][i] << ' ';
    // }
    // cout << endl ;
    // cout << inverse(2 , z) << endl ;
    // cout << inverse(4 , z) << endl ;
    // cout << inverse(8 , z) << endl ;
    // cout << endl ;

    
    // ll p_1 = 0 ;
    // for(int i = 0 ; i <= n ; i++)
    // {
    //     p_1 += (dp[1][i] * get_prob(n , i , 1)) ;
    //     p_1 %= z ;
    // }
    // ll p_0 = (((1 - p_1)%z)+z)%z ;
    // cout << "cnt = " << cnt << " p_1 = " << p_1 << " p_0 = " << p_0 << endl ;

    // ll ans = (power(p_0 , n-cnt , z) * power(p_1 , cnt , z))%z ;
    // cout << ans << endl ;
    // cout << inverse(8 , z) << endl ;

    return ;
}
 
signed main()
{
    //fast;
    #ifndef ONLINE_JUDGE
    freopen("inputf.txt" , "r" , stdin) ;
    freopen("outputf.txt" , "w" , stdout) ;
    freopen("error.txt" , "w" , stderr) ;
    #endif
    
    int t = 1;
    
    //t = readIntLn(1,MAX_T);

    fact[0] = 1 ;
    for(ll i = 1 ; i < 1000 ; i++)
        fact[i] = (fact[i-1] * i)%z ;
    
    for(int i=1;i<=t;i++)
    {    
       solve() ;
    }
    
    assert(getchar() == -1);
 
    cerr<<"SUCCESS\n";
    // cerr<<"Tests : " << t << '\n';
    // cerr<<"Sum of lengths : " << sum_len << '\n';
    // cerr<<"Maximum length : " << max_n << '\n';
    // cerr<<"Total operations : " << total_ops << '\n';
    // cerr<<"Answered yes : " << yess << '\n';
    // cerr<<"Answered no : " << nos << '\n';
}
Editorialist's Solution (C++)
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

/**
 * Integers modulo p, where p is a prime
 * Source: Aeren (modified from tourist?)
 *         Modmul for 64-bit mod from kactl:ModMulLL
 * Works with p < 7.2e18 with x87 80-bit long double, and p < 2^52 ~ 4.5e12 with 64-bit
 */
template<typename T>
struct Z_p{
    using Type = typename decay<decltype(T::value)>::type;
    static vector<Type> MOD_INV;
    constexpr Z_p(): value(){ }
    template<typename U> Z_p(const U &x){ value = normalize(x); }
    template<typename U> static Type normalize(const U &x){
        Type v;
        if(-mod() <= x && x < mod()) v = static_cast<Type>(x);
        else v = static_cast<Type>(x % mod());
        if(v < 0) v += mod();
        return v;
    }
    const Type& operator()() const{ return value; }
    template<typename U> explicit operator U() const{ return static_cast<U>(value); }
    constexpr static Type mod(){ return T::value; }
    Z_p &operator+=(const Z_p &otr){ if((value += otr.value) >= mod()) value -= mod(); return *this; }
    Z_p &operator-=(const Z_p &otr){ if((value -= otr.value) < 0) value += mod(); return *this; }
    template<typename U> Z_p &operator+=(const U &otr){ return *this += Z_p(otr); }
    template<typename U> Z_p &operator-=(const U &otr){ return *this -= Z_p(otr); }
    Z_p &operator++(){ return *this += 1; }
    Z_p &operator--(){ return *this -= 1; }
    Z_p operator++(int){ Z_p result(*this); *this += 1; return result; }
    Z_p operator--(int){ Z_p result(*this); *this -= 1; return result; }
    Z_p operator-() const{ return Z_p(-value); }
    template<typename U = T>
    typename enable_if<is_same<typename Z_p<U>::Type, int>::value, Z_p>::type &operator*=(const Z_p& rhs){
        #ifdef _WIN32
        uint64_t x = static_cast<int64_t>(value) * static_cast<int64_t>(rhs.value);
        uint32_t xh = static_cast<uint32_t>(x >> 32), xl = static_cast<uint32_t>(x), d, m;
        asm(
            "divl %4; \n\t"
            : "=a" (d), "=d" (m)
            : "d" (xh), "a" (xl), "r" (mod())
        );
        value = m;
        #else
        value = normalize(static_cast<int64_t>(value) * static_cast<int64_t>(rhs.value));
        #endif
        return *this;
    }
    template<typename U = T>
    typename enable_if<is_same<typename Z_p<U>::Type, int64_t>::value, Z_p>::type &operator*=(const Z_p &rhs){
        uint64_t ret = static_cast<uint64_t>(value) * static_cast<uint64_t>(rhs.value) - static_cast<uint64_t>(mod()) * static_cast<uint64_t>(1.L / static_cast<uint64_t>(mod()) * static_cast<uint64_t>(value) * static_cast<uint64_t>(rhs.value));
        value = normalize(static_cast<int64_t>(ret + static_cast<uint64_t>(mod()) * (ret < 0) - static_cast<uint64_t>(mod()) * (ret >= static_cast<uint64_t>(mod()))));
        return *this;
    }
    template<typename U = T>
    typename enable_if<!is_integral<typename Z_p<U>::Type>::value, Z_p>::type &operator*=(const Z_p &rhs){
        value = normalize(value * rhs.value);
        return *this;
    }
    template<typename U>
    Z_p &operator^=(U e){
        if(e < 0) *this = 1 / *this, e = -e;
        Z_p res = 1;
        for(; e; *this *= *this, e >>= 1) if(e & 1) res *= *this;
        return *this = res;
    }
    template<typename U>
    Z_p operator^(U e) const{
        return Z_p(*this) ^= e;
    }
    Z_p &operator/=(const Z_p &otr){
        Type a = otr.value, m = mod(), u = 0, v = 1;
        if(a < (int)MOD_INV.size()) return *this *= MOD_INV[a];
        while(a){
            Type t = m / a;
            m -= t * a; swap(a, m);
            u -= t * v; swap(u, v);
        }
        assert(m == 1);
        return *this *= u;
    }
    template<typename U> friend const Z_p<U> &abs(const Z_p<U> &v){ return v; }
    Type value;
};
template<typename T> bool operator==(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value == rhs.value; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator==(const Z_p<T>& lhs, U rhs){ return lhs == Z_p<T>(rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator==(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) == rhs; }
template<typename T> bool operator!=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return !(lhs == rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator!=(const Z_p<T> &lhs, U rhs){ return !(lhs == rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator!=(U lhs, const Z_p<T> &rhs){ return !(lhs == rhs); }
template<typename T> bool operator<(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value < rhs.value; }
template<typename T> bool operator>(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value > rhs.value; }
template<typename T> bool operator<=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value <= rhs.value; }
template<typename T> bool operator>=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value >= rhs.value; }
template<typename T> Z_p<T> operator+(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator+(const Z_p<T> &lhs, U rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator+(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T> Z_p<T> operator-(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator-(const Z_p<T>& lhs, U rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator-(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T> Z_p<T> operator*(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator*(const Z_p<T>& lhs, U rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator*(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T> Z_p<T> operator/(const Z_p<T> &lhs, const Z_p<T> &rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator/(const Z_p<T>& lhs, U rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator/(U lhs, const Z_p<T> &rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T> istream &operator>>(istream &in, Z_p<T> &number){
    typename common_type<typename Z_p<T>::Type, int64_t>::type x;
    in >> x;
    number.value = Z_p<T>::normalize(x);
    return in;
}
template<typename T> ostream &operator<<(ostream &out, const Z_p<T> &number){ return out << number(); }

/*
using ModType = int;
struct VarMod{ static ModType value; };
ModType VarMod::value;
ModType &mod = VarMod::value;
using Zp = Z_p<VarMod>;
*/

// constexpr int mod = 1e9 + 7; // 1000000007
constexpr int mod = (119 << 23) + 1; // 998244353
// constexpr int mod = 1e9 + 9; // 1000000009
using Zp = Z_p<integral_constant<decay<decltype(mod)>::type, mod>>;

template<typename T> vector<typename Z_p<T>::Type> Z_p<T>::MOD_INV;
template<typename T = integral_constant<decay<decltype(mod)>::type, mod>>
void precalc_inverse(int SZ){
    auto &inv = Z_p<T>::MOD_INV;
    if(inv.empty()) inv.assign(2, 1);
    for(; inv.size() <= SZ; ) inv.push_back((mod - 1LL * mod / (int)inv.size() * inv[mod % (int)inv.size()]) % mod);
}

template<typename T>
vector<T> precalc_power(T base, int SZ){
    vector<T> res(SZ + 1, 1);
    for(auto i = 1; i <= SZ; ++ i) res[i] = res[i - 1] * base;
    return res;
}

template<typename T>
vector<T> precalc_factorial(int SZ){
    vector<T> res(SZ + 1, 1); res[0] = 1;
    for(auto i = 1; i <= SZ; ++ i) res[i] = res[i - 1] * i;
    return res;
}


namespace MatrixMult {

    using u32 = std::uint32_t;
    using u64 = std::uint64_t;

    constexpr u32 MAX_SIZE = 105;
    constexpr u64 MOD = 998244353ULL;
    constexpr u64 MULTIPLIER = (1ULL << 47) % MOD - (1ULL << 47);

    static u64 buf0 alignas(32)[MAX_SIZE];
    static u64 buf1 alignas(32)[MAX_SIZE];
    static u64 buf2 alignas(32)[MAX_SIZE];
    static u64 buf3 alignas(32)[MAX_SIZE];

    void _matrix_mult(const u32* __restrict__ A, const u32* __restrict__ B,
                      u32* __restrict__ C, const u32 Cn, const u32 Am,
                      const u32 Cm) {
        u32 j = 0;
        for (; j + 4 <= Cn; j += 4) {
            std::fill(buf0, buf0 + Cm, 0);
            std::fill(buf1, buf1 + Cm, 0);
            std::fill(buf2, buf2 + Cm, 0);
            std::fill(buf3, buf3 + Cm, 0);
            const auto* A_offset0 = A + (j + 0) * Am;
            const auto* A_offset1 = A + (j + 1) * Am;
            const auto* A_offset2 = A + (j + 2) * Am;
            const auto* A_offset3 = A + (j + 3) * Am;
            u32 k = 0;
            for (; k + 18 <= Am; k += 18) {
                for (u32 _k = k; _k < k + 18; ++_k) {
                    const u64 a0 = A_offset0[_k];
                    const u64 a1 = A_offset1[_k];
                    const u64 a2 = A_offset2[_k];
                    const u64 a3 = A_offset3[_k];
                    const auto* B_offset = B + _k * Cm;
#pragma GCC ivdep
                    for (u32 i = 0; i < Cm; ++i) {
                        u64 x = B_offset[i];
                        buf0[i] += a0 * x;
                        buf1[i] += a1 * x;
                        buf2[i] += a2 * x;
                        buf3[i] += a3 * x;
                    }
                }
#pragma GCC ivdep
                for (u32 i = 0; i < Cm; ++i)
                    buf0[i] += (buf0[i] >> 47) * MULTIPLIER;
#pragma GCC ivdep
                for (u32 i = 0; i < Cm; ++i)
                    buf1[i] += (buf1[i] >> 47) * MULTIPLIER;
#pragma GCC ivdep
                for (u32 i = 0; i < Cm; ++i)
                    buf2[i] += (buf2[i] >> 47) * MULTIPLIER;
#pragma GCC ivdep
                for (u32 i = 0; i < Cm; ++i)
                    buf3[i] += (buf3[i] >> 47) * MULTIPLIER;
            }
            for (; k < Am; ++k) {
                const u64 a0 = A[(j + 0) * Am + k];
                const u64 a1 = A[(j + 1) * Am + k];
                const u64 a2 = A[(j + 2) * Am + k];
                const u64 a3 = A[(j + 3) * Am + k];
                const u32 offset = k * Cm;
                const auto* B_offset = B + offset;
#pragma GCC ivdep
                for (u32 i = 0; i < Cm; ++i) {
                    u64 x = B_offset[i];
                    buf0[i] += a0 * x;
                    buf1[i] += a1 * x;
                    buf2[i] += a2 * x;
                    buf3[i] += a3 * x;
                }
            }
            auto* C_offset = C + j * Cm;
#pragma GCC ivdep
            for (u32 i = 0; i < Cm; ++i) C_offset[i] = buf0[i] % MOD;
            C_offset += Cm;
#pragma GCC ivdep
            for (u32 i = 0; i < Cm; ++i) C_offset[i] = buf1[i] % MOD;
            C_offset += Cm;
#pragma GCC ivdep
            for (u32 i = 0; i < Cm; ++i) C_offset[i] = buf2[i] % MOD;
            C_offset += Cm;
#pragma GCC ivdep
            for (u32 i = 0; i < Cm; ++i) C_offset[i] = buf3[i] % MOD;
        }
        auto* buf = buf0;
        for (; j < Cn; ++j) {
            std::fill(buf, buf + Cm, 0);
            const auto* A_offset = A + j * Am;
            u32 k = 0;
            for (; k + 18 <= Am; k += 18) {
                for (u32 _k = k; _k < k + 18; ++_k) {
                    const auto* B_offset = B + _k * Cm;
                    const u64 a = A_offset[_k];
#pragma GCC ivdep
                    for (u32 i = 0; i < Cm; ++i) buf[i] += a * B_offset[i];
                }
                for (u32 i = 0; i < Cm; ++i)
                    buf[i] += (buf[i] >> 47) * MULTIPLIER;
            }
            for (; k < Am; ++k) {
                const u64 a = A[j * Am + k];
                const auto* B_offset = B + k * Cm;
#pragma GCC ivdep
                for (u32 i = 0; i < Cm; ++i) buf[i] += a * B_offset[i];
            }
            auto* C_offset = C + j * Cm;
            for (u32 i = 0; i < Cm; ++i) C_offset[i] = buf[i] % MOD;
        }
    }

    template <class T>
    struct Matrix {
       private:
        std::vector<T> a;
        int n, m;

       public:
        Matrix(int _n, int _m) : a(_n * _m), n(_n), m(_m) {}
        T* operator[](const int& i) { return a.data() + i * m; }
        Matrix operator*(Matrix& B) {
            assert(this->m == B.n);
            Matrix<T> C(this->n, B.m);
            _matrix_mult((*this)[0], B[0], C[0], this->n, this->m, B.m);
            return C;
        }
    };

}  // namespace MatrixMult

using namespace std;
using MatrixMult::Matrix;
using MatrixMult::u32;

Matrix<u32> matpow(Matrix<u32> M, int pw)
{
    Matrix res = M; --pw;
    while (pw) {
        if (pw&1) res = res * M;
        M = M * M;
        pw /= 2;
    }
    return res;
}

int main()
{
    ios::sync_with_stdio(0); cin.tie(0);

    int n, t; cin >> n >> t;
    string a, b; cin >> a >> b;
    auto fac = precalc_factorial<Zp>(n+1);
    Matrix<u32> M(n+1, n+1);
    for (int i = 0; i <= n; ++i) {
        for (int j = 0; j <= n; ++j) {
            // i 1s -> j 1s
            Zp val = Zp(2)*i*(n-i);
            val /= Zp(1)*n*n;
            val = (val ^ j) * ((1 - val)^(n-j)) * fac[n] / (fac[j] * fac[n-j]);
            M[i][j] = val();
        }
    }
    auto res = matpow(M, t);
    int oc_a = count(begin(a), end(a), '1');
    int oc_b = count(begin(b), end(b), '1');
    Zp ans = int(res[oc_a][oc_b]);
    ans *= fac[oc_b] * fac[n - oc_b] / fac[n];
    cout << ans;
}