CHECKPOINT - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: utkarsh_25dec
Tester: tabr
Editorialist: iceknight1093

DIFFICULTY:

2792

PREREQUISITES:

Familiarity with probability, (optional) solving linear recurrences

PROBLEM:

Two players play a competition with N questions, each with M options.
The competition proceeds as follows:

  • On their turn, a player chooses an answer to the current question uniformly at random from among the options that have not yet been tried for it.
  • If their answer is correct, they move to the next question and keep their turn.
  • If their answer is wrong, the turn passes to the opponent.

The winner is whoever correctly answers question N.
Find the probability that player 1 wins, modulo 998244353.

EXPLANATION:

Let’s first look at a slow but correct solution, then see how to speed it up.

Let P_i denote the probability that player 1 is the first to answer question i correctly.
Our aim is to compute P_N.

P_i satisfies a recurrence in terms of P_{i-1}.
Whoever answers question i-1 correctly keeps the turn, so they are the one who guesses first on question i; this is either player 1 or player 2.
Let’s handle these two cases separately.

Suppose player 1 answered question i-1 (this happens with probability P_{i-1}), and hence goes first on question i.
Then, the probability that he is also the first to answer question i correctly is

\frac{1}{M} + \left(\frac{M-1}{M}\cdot\frac{M-2}{M-1}\cdot\frac{1}{M-2}\right) + \left(\frac{M-1}{M}\cdot\frac{M-2}{M-1}\cdot\frac{M-3}{M-2}\cdot\frac{M-4}{M-3}\cdot\frac{1}{M-4} \right) + \ldots

because:

  • There’s a \frac{1}{M} chance of getting it right on the first try.
  • If the first try fails \left(\frac{M-1}{M}\text{ chance}\right), the second player must also get it wrong \left(\frac{M-2}{M-1}\text{ chance, since there's one less option now}\right), and then there’s a \frac{1}{M-2} chance of player 1 getting it right on his second guess.
  • If that also fails, the second player must once again guess wrong, and so on.

Note that each term above telescopes: all the numerators and denominators cancel, leaving exactly \frac{1}{M}.
So the required probability is simply \frac{1}{M} added up several times.
Specifically, player 1 makes the 1st, 3rd, 5th, \ldots guesses on this question, so we obtain one \frac{1}{M} term for every odd number \leq M, of which there are exactly \left \lceil \frac{M}{2} \right\rceil.
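As a sanity check on the telescoping argument, here is a small Python sketch that evaluates the series term by term with exact fractions. The function name `first_mover_wins_question` is ours, not taken from any of the solutions below; for every M the result should equal \left\lceil \frac{M}{2} \right\rceil / M.

```python
from fractions import Fraction


def first_mover_wins_question(M: int) -> Fraction:
    """Probability that the player who guesses first on a question is the first
    to answer it correctly, summed term by term from the series above."""
    total = Fraction(0)
    reach = Fraction(1)  # probability that guesses 1..j-1 were all wrong
    remaining = M        # untried options just before guess j
    for j in range(1, M + 1):
        hit = Fraction(1, remaining)  # guess j picks the correct option
        if j % 2 == 1:                # the first mover makes the odd-numbered guesses
            total += reach * hit
        reach *= 1 - hit              # guess j was wrong, so one option is eliminated
        remaining -= 1
    return total
```

For instance, `first_mover_wins_question(5)` gives `Fraction(3, 5)`: three terms, each equal to \frac{1}{5}.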

Similarly, if player 2 answered question i-1 and hence goes first on question i, the probability that player 1 is the first to answer question i correctly is

\left(\frac{M-1}{M}\cdot\frac{1}{M-1}\right) + \left(\frac{M-1}{M}\cdot\frac{M-2}{M-1}\cdot\frac{M-3}{M-2}\cdot\frac{1}{M-3}\right) + \ldots

which is once again a sum of \frac{1}{M} terms, but this time there are \left\lfloor \frac{M}{2} \right\rfloor of them: one for each even number \leq M, since player 1 now makes the 2nd, 4th, \ldots guesses.
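The same counts can also be obtained by brute force. Drawing uniformly at random from the untried options, one guess after another, produces a uniformly random ordering of all M options, so the correct option is equally likely to be found on any of the M guesses. The hypothetical helper below (a sketch for small M only) enumerates every ordering and counts those where the correct option, labelled 0 here, lands on an even-numbered guess, which is exactly when the player moving second wins the question.

```python
from fractions import Fraction
from itertools import permutations


def second_mover_wins_question(M: int) -> Fraction:
    """Brute force over every guessing order; practical only for small M."""
    orders = list(permutations(range(M)))  # option 0 is taken to be the correct one
    # 1-based position of the correct option; the second mover makes guesses 2, 4, 6, ...
    wins = sum(1 for order in orders if (order.index(0) + 1) % 2 == 0)
    return Fraction(wins, len(orders))
```

For every small M this returns \left\lfloor \frac{M}{2} \right\rfloor / M, e.g. `Fraction(1, 3)` for M = 3.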

Putting the cases together, we obtain

P_i = P_{i-1} \times \left(\left \lceil \frac{M}{2} \right\rceil \cdot \frac{1}{M} \right) + \left(1 - P_{i-1}\right) \times \left( \left\lfloor \frac{M}{2} \right\rfloor \cdot \frac{1}{M}\right)

This holds for all i \gt 1, and it even holds for i = 1 if we define P_0 = 1: player 1 always goes first on question 1, so we can pretend that player 1 answered a fictitious question 0.
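This is the slow but correct solution mentioned at the start, sketched in Python with exact fractions (the name `p_slow` is illustrative). It simply iterates the recurrence from P_0 = 1, so it takes \mathcal{O}(N) steps and is only practical for small N, but it makes a handy reference for testing faster versions.

```python
from fractions import Fraction


def p_slow(N: int, M: int) -> Fraction:
    """Iterate P_i = P_{i-1} * ceil(M/2)/M + (1 - P_{i-1}) * floor(M/2)/M from P_0 = 1."""
    a = Fraction((M + 1) // 2, M)  # ceil(M/2)/M: player 1 guesses first on question i
    b = Fraction(M // 2, M)        # floor(M/2)/M: player 2 guesses first on question i
    p = Fraction(1)                # P_0 = 1
    for _ in range(N):
        p = p * a + (1 - p) * b
    return p
```

For example, `p_slow(2, 3)` returns `Fraction(5, 9)`.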

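The above formula is a linear recurrence!
That is, it’s of the form P_i = x\cdot P_{i-1} + y, with constants x = \frac{1}{M}\left(\left\lceil \frac{M}{2} \right\rceil - \left\lfloor \frac{M}{2} \right\rfloor\right) and y = \frac{1}{M}\left\lfloor \frac{M}{2} \right\rfloor.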

Finding the N-th term of such a recurrence can be done in \mathcal{O}(\log N) using matrix exponentiation, as seen in this blogpost for example.
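One possible way to apply matrix exponentiation is sketched below in Python (the function names are ours). Reading off x = \frac{1}{M}\left(\left\lceil \frac{M}{2} \right\rceil - \left\lfloor \frac{M}{2} \right\rfloor\right) and y = \frac{1}{M}\left\lfloor \frac{M}{2} \right\rfloor from the recurrence, keep the state as the column vector (P_i, 1), so that one step is multiplication by \begin{pmatrix} x & y \\ 0 & 1 \end{pmatrix}. Raising this matrix to the N-th power by repeated squaring and applying it to (P_0, 1) = (1, 1) gives P_N modulo 998244353 in \mathcal{O}(\log N) matrix multiplications.

```python
MOD = 998244353


def mat_mul(A, B):
    """Product of two 2x2 matrices modulo MOD."""
    return [
        [(A[0][0] * B[0][0] + A[0][1] * B[1][0]) % MOD,
         (A[0][0] * B[0][1] + A[0][1] * B[1][1]) % MOD],
        [(A[1][0] * B[0][0] + A[1][1] * B[1][0]) % MOD,
         (A[1][0] * B[0][1] + A[1][1] * B[1][1]) % MOD],
    ]


def mat_pow(A, e):
    """A^e by repeated squaring, using O(log e) multiplications."""
    R = [[1, 0], [0, 1]]
    while e:
        if e & 1:
            R = mat_mul(R, A)
        A = mat_mul(A, A)
        e >>= 1
    return R


def p_matrix(N: int, M: int) -> int:
    """P_N mod MOD, using (P_i, 1)^T = [[x, y], [0, 1]] (P_{i-1}, 1)^T and P_0 = 1."""
    inv_m = pow(M, MOD - 2, MOD)               # 1/M via Fermat's little theorem
    x = ((M + 1) // 2 - M // 2) * inv_m % MOD  # (ceil(M/2) - floor(M/2)) / M
    y = (M // 2) * inv_m % MOD                 # floor(M/2) / M
    T = mat_pow([[x, y], [0, 1]], N)
    # Applying T to the column vector (P_0, 1) = (1, 1) gives P_N = T[0][0] + T[0][1].
    return (T[0][0] + T[0][1]) % MOD
```

The setter's code below builds the same transition matrix, but applies its (N-1)-th power to (P_1, 1) instead.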

This is already enough to solve the problem, but read on if you’d like to see a solution that doesn’t rely on matrices!

Notice that our formula involves \left\lfloor \frac{M}{2} \right\rfloor and \left\lceil \frac{M}{2} \right\rceil, which are equal when M is even and differ by 1 when M is odd.
So let’s look at even and odd M separately.

Even M

Suppose M = 2k. Then, \lfloor \frac{M}{2} \rfloor = \lceil \frac{M}{2} \rceil = k.
So, our formula becomes

P_i = P_{i-1}\cdot k\cdot \frac{1}{M} + (1 - P_{i-1})\cdot k \cdot \frac{1}{M} = \frac{k}{M}

But M = 2k, so \frac{k}{M} = \frac{1}{2}.

That is, if M is even the answer is simply \frac{1}{2}, irrespective of N (and of the exact value of M).

Odd M

Suppose M = 2k+1, so \lfloor \frac{M}{2} \rfloor = k and \lceil \frac{M}{2} \rceil = k+1.
Plugging into the formula,

P_i = P_{i-1}\cdot (k+1)\cdot \frac{1}{M} + (1-P_{i-1})\cdot k\cdot \frac{1}{M} = \frac{P_{i-1}}{M} + \frac{k}{M}

Repeatedly applying this formula, along with the base case P_0 = 1, gives us

P_i = \frac{1}{M^i} + k\cdot \left(\frac{1}{M} + \frac{1}{M^2} + \ldots + \frac{1}{M^i}\right)

The quantity inside the brackets is the sum of a geometric progression with ratio r = \frac{1}{M}, so it equals \frac{r\left(r^i - 1\right)}{r - 1} and can be computed quickly using modular exponentiation and modular inverses.
This allows for easy computation of P_N.
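To double-check the closed form, the sketch below (illustrative names, exact fractions, odd M only) evaluates P_N = r^N + k\cdot\frac{r\left(r^N - 1\right)}{r - 1} with r = \frac{1}{M} and compares it with directly unrolling P_i = \frac{P_{i-1}}{M} + \frac{k}{M}. The editorialist's Python code below computes the same expression modulo 998244353.

```python
from fractions import Fraction


def p_closed_form(N: int, M: int) -> Fraction:
    """P_N for odd M = 2k+1 via P_N = r^N + k * (r + r^2 + ... + r^N), r = 1/M."""
    assert M % 2 == 1 and M >= 3
    k = M // 2
    r = Fraction(1, M)
    gp = r * (r ** N - 1) / (r - 1)  # geometric sum r + r^2 + ... + r^N
    return r ** N + k * gp


def p_unrolled(N: int, M: int) -> Fraction:
    """The same quantity by iterating P_i = P_{i-1}/M + k/M from P_0 = 1."""
    k = M // 2
    p = Fraction(1)
    for _ in range(N):
        p = p / M + Fraction(k, M)
    return p
```

Both give `Fraction(5, 9)` for N = 2, M = 3.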

We’ve thus found a closed-form formula for P_N, and we’re done!

TIME COMPLEXITY

\mathcal{O}(\log N + \log {MOD}) per test case.

CODE:

Setter's code (C++)
//Utkarsh.25dec
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <cmath>
#include <vector>
#include <set>
#include <map>
#include <unordered_set>
#include <unordered_map>
#include <queue>
#include <ctime>
#include <cassert>
#include <complex>
#include <string>
#include <cstring>
#include <chrono>
#include <random>
#include <bitset>
#include <array>
#define ll long long int
#define pb push_back
#define mp make_pair
#define mod 998244353
#define vl vector <ll>
#define all(c) (c).begin(),(c).end()
using namespace std;
typedef vector<vector<ll>> matrix;
ll power(ll a,ll b) {ll res=1;a%=mod; assert(b>=0); for(;b;b>>=1){if(b&1)res=res*a%mod;a=a*a%mod;}return res;}
ll modInverse(ll a){return power(a,mod-2);}
const int N=500023;
bool vis[N];
vector <int> adj[N];
 
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
 
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }
 
            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }
 
            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
 
const int K = 2;
// computes A * B
matrix mul(matrix A, matrix B)
{
    matrix C(K+1, vector<ll>(K+1));
    for(int i=1;i<=K;i++) for(int j=1;j<=K;j++) for(int k=1;k<=K;k++)
        C[i][j] = (C[i][j] + A[i][k] * B[k][j]) % mod;
    return C;
}
 
// computes A ^ p
matrix pow(matrix A, ll p)
{
    if (p == 1)
        return A;
    if (p % 2)
        return mul(A, pow(A, p-1));
    matrix X = pow(A, p/2);
    return mul(X, X);
}
//matrix ans(K+1,vl(K+1));
 
matrix ans(K+1, vl(K+1));
void solve()
{
    ll N, M;
    N = readInt(1, 1000000000, ' ');
    M = readInt(2, 100000, '\n');
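    ll x;
    // x = ceil(M/2)/M = P_1: probability that whoever guesses first on a question wins it
    if(M%2==0)
        x = modInverse(2);
    else
        x = ((M+1)/2 * modInverse(M))%mod;
    // Transition on the column vector (P_i, 1): P_i = (2x-1)*P_{i-1} + (1-x)
    ans[1][1] = (2*x-1+mod)%mod;
    ans[1][2] = (1 - x + mod)%mod;
    ans[2][1] = 0;
    ans[2][2] = 1;
    if (N==1)
    {
        cout<<x<<'\n';
        return;
    }
    // P_N = first row of ans^(N-1) applied to (P_1, 1) = (x, 1)
    ans = pow(ans, N-1);
    ll out = (ans[1][1]*x + ans[1][2])%mod;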
    cout << out << '\n';
}
int main()
{
    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    #endif
    ios_base::sync_with_stdio(false);
    cin.tie(NULL),cout.tie(NULL);
    int T=readInt(1,100000,'\n');
    while(T--)
        solve();
    cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

template <long long mod>
struct modular {
    long long value;
    modular(long long x = 0) {
        value = x % mod;
        if (value < 0) value += mod;
    }
    modular& operator+=(const modular& other) {
        if ((value += other.value) >= mod) value -= mod;
        return *this;
    }
    modular& operator-=(const modular& other) {
        if ((value -= other.value) < 0) value += mod;
        return *this;
    }
    modular& operator*=(const modular& other) {
        value = value * other.value % mod;
        return *this;
    }
    modular& operator/=(const modular& other) {
        long long a = 0, b = 1, c = other.value, m = mod;
        while (c != 0) {
            long long t = m / c;
            m -= t * c;
            swap(c, m);
            a -= t * b;
            swap(a, b);
        }
        a %= mod;
        if (a < 0) a += mod;
        value = value * a % mod;
        return *this;
    }
    friend modular operator+(const modular& lhs, const modular& rhs) { return modular(lhs) += rhs; }
    friend modular operator-(const modular& lhs, const modular& rhs) { return modular(lhs) -= rhs; }
    friend modular operator*(const modular& lhs, const modular& rhs) { return modular(lhs) *= rhs; }
    friend modular operator/(const modular& lhs, const modular& rhs) { return modular(lhs) /= rhs; }
    modular& operator++() { return *this += 1; }
    modular& operator--() { return *this -= 1; }
    modular operator++(int) {
        modular res(*this);
        *this += 1;
        return res;
    }
    modular operator--(int) {
        modular res(*this);
        *this -= 1;
        return res;
    }
    modular operator-() const { return modular(-value); }
    bool operator==(const modular& rhs) const { return value == rhs.value; }
    bool operator!=(const modular& rhs) const { return value != rhs.value; }
    bool operator<(const modular& rhs) const { return value < rhs.value; }
};
template <long long mod>
string to_string(const modular<mod>& x) {
    return to_string(x.value);
}
template <long long mod>
ostream& operator<<(ostream& stream, const modular<mod>& x) {
    return stream << x.value;
}
template <long long mod>
istream& operator>>(istream& stream, modular<mod>& x) {
    stream >> x.value;
    x.value %= mod;
    if (x.value < 0) x.value += mod;
    return stream;
}

constexpr long long mod = 998244353;
using mint = modular<mod>;

mint power(mint a, long long n) {
    mint res = 1;
    while (n > 0) {
        if (n & 1) {
            res *= a;
        }
        a *= a;
        n >>= 1;
    }
    return res;
}

vector<mint> fact(1, 1);
vector<mint> finv(1, 1);

mint C(int n, int k) {
    if (n < k || k < 0) {
        return mint(0);
    }
    while ((int) fact.size() < n + 1) {
        fact.emplace_back(fact.back() * (int) fact.size());
        finv.emplace_back(mint(1) / fact.back());
    }
    return fact[n] * finv[k] * finv[n - k];
}

template <typename T>
vector<vector<T>> operator*(const vector<vector<T>>& a, const vector<vector<T>>& b) {
    vector<vector<T>> c(a.size(), vector<T>(b[0].size()));
    for (int i = 0; i < (int) c.size(); i++) {
        for (int k = 0; k < (int) b.size(); k++) {
            for (int j = 0; j < (int) c[0].size(); j++) {
                c[i][j] += a[i][k] * b[k][j];
            }
        }
    }
    return c;
}

template <typename T>
vector<vector<T>>& operator*=(vector<vector<T>>& a, const vector<vector<T>>& b) {
    return a = a * b;
}

template <typename T>
vector<vector<T>> power(vector<vector<T>> a, long long n) {
    vector<vector<T>> res(a.size(), vector<T>(a.size()));
    for (int i = 0; i < (int) a.size(); i++) {
        res[i][i] = 1;
    }
    while (n > 0) {
        if (n & 1) {
            res *= a;
        }
        a *= a;
        n >>= 1;
    }
    return res;
}

int main() {
    input_checker in;
    int tt = in.readInt(1, 1e5);
    in.readEoln();
    while (tt--) {
        int n = in.readInt(1, 1e9);
        in.readSpace();
        int m = in.readInt(2, 1e5);
        in.readEoln();
        mint p = mint((m + 1) / 2) / m;
        vector<vector<mint>> a(2, vector<mint>(2));
        a[0][0] = a[1][1] = p;
        a[1][0] = a[0][1] = 1 - p;
        a = power(a, n);
        cout << a[0][0] << '\n';
    }
    in.readEof();
    return 0;
}
Editorialist's code (Python)
mod = 998244353

def gpsum(r, n): # r + ... + r^n
	ret = pow(r, n, mod) - 1
	ret *= pow(r-1, mod-2, mod)
	return r * ret % mod

for _ in range(int(input())):
	n, m = map(int, input().split())
	if m%2 == 0: print((mod+1)//2)  # even M: the answer is 1/2
	else:
		inv_m = pow(m, mod-2, mod)  # r = 1/M
		# P_N = r^N + k * (r + ... + r^N), where k = M // 2
		ans = (m // 2) * gpsum(inv_m, n) + pow(inv_m, n, mod)
		print(ans % mod)