SQUIDGAME2 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: yash_daga
Tester: mexomerf
Editorialist: iceknight1093

DIFFICULTY:

easy

PREREQUISITES:

Elementary probability

PROBLEM:

N participants want to cross a glass bridge.
Each step of the bridge has 2 glass pieces, exactly one of which will break when it’s stepped on.

The participants go in order, and each participant has full knowledge of the steps taken by the preceding ones.
For each length from 1 to M, find the probability that all N participants fall off the bridge.

EXPLANATION:

Unlike the easy version, \mathcal{O}(NM) won’t be fast enough here.
Let’s analyze the structure a bit more.

Suppose the N people fall at steps x_1 \lt x_2 \lt\ldots\lt x_N.
Let’s compute the probability of this.

  • For x_1, there’s a 2^{-x_1} chance: the first person faces x_1 unknown steps and must guess right on the first x_1 - 1 of them and wrong on step x_1, each guess being a 50-50.
  • For x_2 after x_1, there’s a 2^{x_1 - x_2} chance, because every step up to x_1 is already known (and hence safe), while each of the x_2 - x_1 steps after it is a 50-50.
  • For x_3 after x_2, there’s a 2^{x_2 - x_3} chance by similar reasoning.
  • More generally, for x_i after x_{i-1}, there’s a 2^{x_{i-1} - x_i} chance.

So, the overall probability of exactly this sequence of falls is

2^{-x_1} \cdot 2^{x_1 - x_2} \cdot 2^{x_2 - x_3} \cdot \ldots \cdot 2^{x_{N-1} - x_N} = 2^{-x_N}

That is, only the last step matters!
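As a quick sanity check of the telescoping product, the Python sketch below multiplies the per-person factors for a given increasing sequence of fall positions using exact fractions. The function name `sequence_probability` is hypothetical, chosen only for this illustration; it assumes 1-indexed steps with x_0 = 0.

```python
from fractions import Fraction

def sequence_probability(xs):
    """Probability that participant j falls exactly at step xs[j].

    xs must be strictly increasing and 1-indexed; x_0 = 0 is implied.
    Each participant walks safely over known steps and makes one 50-50
    guess on each of the steps prev+1 .. x (surviving all but the last).
    """
    prob = Fraction(1)
    prev = 0
    for x in xs:
        prob *= Fraction(1, 2 ** (x - prev))   # factor 2^{x_{j-1} - x_j}
        prev = x
    return prob

print(sequence_probability([1, 4, 6]))  # 1/64, i.e. 2^{-6}: only x_N = 6 matters
```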

An intuitive way to see why this is true: once the last step is fixed, every step up to it requires exactly one person to make a 50-50 choice. Once that choice has been made, everyone behind knows its result and is always safe at that step.


So, let’s fix the last step, say i.
Then, no matter where the first N-1 people fall, the probability is going to be exactly 2^{-i}.

Any set of fall positions for the first N-1 people is valid, as long as it consists of N-1 distinct indices smaller than i, so there are \binom{i-1}{N-1} choices in total.
So, the probability of the N-th person falling at the i-th step is

\binom{i-1}{N-1} 2^{-i}
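This formula can also be checked by brute force on tiny bridges. The sketch below is only an illustration under an assumed model: each of the 2^k layouts (which piece holds at each step) is equally likely, and whoever faces an unknown step tries piece 0, which succeeds with probability 1/2 because the layout is uniform. The helper `last_fall_distribution` is a hypothetical name, and the enumeration is exponential, so it is only meant for very small k.

```python
from fractions import Fraction
from itertools import product
from math import comb

def last_fall_distribution(n, k):
    """Exact distribution of the step where the n-th participant falls,
    over all 2^k equally likely layouts of a length-k bridge (assumed model).

    safe[s] in {0, 1} says which piece holds at step s + 1; a participant
    facing an unknown step always tries piece 0 (hypothetical fixed guess).
    """
    dist = {}
    weight = Fraction(1, 2 ** k)          # probability of one layout
    for safe in product((0, 1), repeat=k):
        fallen = 0
        for step, side in enumerate(safe, start=1):
            # All earlier steps are known, so only the front-most remaining
            # participant has to guess at this step.
            if side != 0:                 # tried piece 0, but it breaks
                fallen += 1
                if fallen == n:           # the n-th (last) participant falls here
                    dist[step] = dist.get(step, 0) + weight
                    break
            # Whatever happened, this step is now known to everyone behind.
    return dist

# Compare with C(i-1, n-1) * 2^{-i} for every step i of a small bridge.
n, k = 3, 6
dist = last_fall_distribution(n, k)
for i in range(1, k + 1):
    assert dist.get(i, 0) == Fraction(comb(i - 1, n - 1), 2 ** i)
```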

Now, for a length-k bridge, the answer is the above value summed over all i \leq k, i.e.,

\boxed{\sum_{i=1}^k \binom{i-1}{N-1} 2^{-i}}
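For small inputs, the boxed sum can be evaluated exactly with fractions. The sketch below (the name `all_fall_probability` is hypothetical) prints the answers for N = 3 and k = 1, ..., 5.

```python
from fractions import Fraction
from math import comb

def all_fall_probability(n, k):
    """Boxed formula: sum over i = 1..k of C(i-1, n-1) * 2^{-i}, exactly."""
    return sum((Fraction(comb(i - 1, n - 1), 2 ** i) for i in range(1, k + 1)),
               Fraction(0))

print([str(all_fall_probability(3, k)) for k in range(1, 6)])
# ['0', '0', '1/8', '5/16', '1/2']
```

The values approach 1 as k grows: on a long enough bridge, everyone eventually falls.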

We want to compute this for each k \leq M.
However, the answer for (k+1) can be obtained by adding a single term to the answer for k, namely \binom{k}{N-1} 2^{-k-1}. That term is a single binomial coefficient times a single power of 2, so it can be computed in \mathcal{O}(1) time with precomputed factorials and powers of 2, or in \mathcal{O}(\log{MOD}) time with fast exponentiation.

This allows us to find the answers for all k = 1, 2, 3, \ldots, M.
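A minimal modular sketch of this incremental idea, assuming the answers are required modulo 10^9 + 7 as in the codes below: each answer is the previous one plus \binom{k-1}{N-1} 2^{-k}, with factorials precomputed so that each binomial costs \mathcal{O}(1). The function name is hypothetical and, unlike the contest codes, this sketch rebuilds the factorial tables on every call for simplicity.

```python
MOD = 10**9 + 7

def all_fall_probabilities(n, m):
    """Answers for k = 1..m modulo MOD, where
    answer(k) = sum over i = 1..k of C(i-1, n-1) * 2^{-i}.
    Each answer extends the previous one by a single term."""
    # Factorials and inverse factorials up to m (binomial arguments stay below m).
    fact = [1] * (m + 1)
    for i in range(1, m + 1):
        fact[i] = fact[i - 1] * i % MOD
    inv_fact = [1] * (m + 1)
    inv_fact[m] = pow(fact[m], MOD - 2, MOD)   # Fermat inverse, MOD is prime
    for i in range(m, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % MOD

    def binom(a, b):
        if b < 0 or b > a:
            return 0
        return fact[a] * inv_fact[b] % MOD * inv_fact[a - b] % MOD

    inv2 = (MOD + 1) // 2
    pw = 1            # holds 2^{-i} mod MOD after the update below
    ans = 0
    out = []
    for i in range(1, m + 1):
        pw = pw * inv2 % MOD
        ans = (ans + binom(i - 1, n - 1) * pw) % MOD   # add the single new term
        out.append(ans)
    return out

print(*all_fall_probabilities(3, 5))  # 0 0 125000001 812500006 500000004
```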

TIME COMPLEXITY:

\mathcal{O}(M) per testcase.

CODE:

Author's code (C++)
//clear adj and visited vector declared globally after each test case
//check for long long overflow   
//In modular-arithmetic problems, add a final check: if ans<0 then ans+=mod;
//In case of a close MLE, change the language to C++17 or C++14
//Check ans for n=1 
// #pragma GCC target ("avx2")    
// #pragma GCC optimization ("O3")  
// #pragma GCC optimization ("unroll-loops")
#include <bits/stdc++.h>                   
#include <ext/pb_ds/assoc_container.hpp>  
#define int long long     
#define IOS std::ios::sync_with_stdio(false); cin.tie(NULL);cout.tie(NULL);cout.precision(dbl::max_digits10);
#define pb push_back 
#define mod 1000000007ll //998244353ll
#define lld long double
#define mii map<int, int> 
#define pii pair<int, int>
#define ll long long 
#define ff first
#define ss second 
#define all(x) (x).begin(), (x).end()
#define rep(i,x,y) for(int i=x; i<y; i++)    
#define fill(a,b) memset(a, b, sizeof(a))
#define vi vector<int>
#define setbits(x) __builtin_popcountll(x)
#define print2d(dp,n,m) for(int i=0;i<=n;i++){for(int j=0;j<=m;j++)cout<<dp[i][j]<<" ";cout<<"\n";}
typedef std::numeric_limits< double > dbl;
using namespace __gnu_pbds;
using namespace std;
typedef tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update> indexed_set;
//member functions :
//1. order_of_key(k) : number of elements strictly lesser than k
//2. find_by_order(k) : k-th element in the set
const long long N=500005, INF=2000000000000000000;
const int inf=2e9 + 5;
lld pi=3.1415926535897932;
int lcm(int a, int b)
{
    int g=__gcd(a, b);
    return a/g*b;
}
int power(int a, int b, int p)
    {
        if(b<0)
        return 0;
        if(a==0)
        return 0;
        int res=1;
        a%=p;
        while(b>0)
        {
            if(b&1)
            res=(1ll*res*a)%p;
            b>>=1;
            a=(1ll*a*a)%p;
        }
        return res;
    }

mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());

int getRand(int l, int r)
{
    uniform_int_distribution<int> uid(l, r);
    return uid(rng);
}

int fact[N],inv[N];
void pre()
{
    fact[0]=1;
    inv[0]=1;
    for(int i=1;i<N;i++)
    fact[i]=(i*fact[i-1])%mod;
    for(int i=1;i<N;i++)
    inv[i]=power(fact[i], mod-2, mod);
}
int nCr(int n, int r, int p) 
{ 
    if(r>n || r<0)
    return 0;
    if(n==r)
    return 1;
    if (r==0) 
    return 1; 
    return (((fact[n]*inv[r]) % p )*inv[n-r])%p;
} 

int32_t main()
{
    IOS;
    pre();
    int t, inv2=power(2, mod-2, mod);
    cin>>t;
    while(t--)
    {
        int n, m;
        cin>>n>>m;
        // Answer for m=i is probability of getting n failures before i-(n-1) successes.
        int req_failures = n, res=0;
        for(int i=1-n;i<=m-n;i++)
        {
            
            // We are calculating probability of getting n-th failure after i successes.
            int prob=(nCr(req_failures+i-1, i, mod)*power(inv2, i+req_failures, mod))%mod;
            res=(res+prob)%mod;
            cout<<res<<" ";
        }
        cout<<"\n";
    }
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define md 1000000007
#define N 1000001
int fac[N];
int pw[N];
long long readInt(long long l, long long r, char endd) {
    long long x = 0;
    int cnt = 0;
    int fi = -1;
    bool is_neg = false;
    while (true) {
        char g = getchar();
        if (g == '-') {
            assert(fi == -1);
            is_neg = true;
            continue;
        }
        if ('0' <= g && g <= '9') {
            x *= 10;
            x += g - '0';
            if (cnt == 0) {
                fi = g - '0';
            }
            cnt++;
            assert(fi != 0 || cnt == 1);
            assert(fi != 0 || is_neg == false);
 
            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
        }
        else if (g == endd) {
            assert(cnt > 0);
            if (is_neg) {
                x = -x;
            }
            assert(l <= x && x <= r);
            return x;
        }
        else {
            assert(false);
        }
    }
}
 
string readString(int l, int r, char endd) {
    string ret = "";
    int cnt = 0;
    while (true) {
        char g = getchar();
        assert(g != -1);
        if (g == endd) {
            break;
        }
        cnt++;
        ret += g;
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}
int modex(int a, int b){
    if(b == 0){
        return 1;
    }
    a %= md;
    int temp = modex(a, b / 2);
    temp *= temp;
    temp %= md;
    if(b % 2){
        temp *= a;
        temp %= md;
    }
    return temp;
}
int mod(int a, int b){
    return ((a % md) * modex(b, md - 2)) % md;
}
int ncr(int n, int r){
    if(n < r || r < 0 || n < 0){
        return 0;
    }
    return mod(fac[n], fac[n - r] * fac[r]);
}
int32_t main() {
    fac[0] = 1;
    pw[0] = 1;
    for(int i = 1; i < N; i++){
        fac[i] = fac[i - 1] * i;
        fac[i] %= md;
        pw[i] = pw[i - 1] * 2;
        pw[i] %= md;
    }
	int t;
	t = readInt(1, 500000, '\n');
	int ns = 0;
	int ms = 0;
	while(t--){
	    int n, m;
	    n = readInt(1, 500000, ' ');
	    m = readInt(1, 500000, '\n');
	    ns += n;
	    ms += m;
	    assert(ns <= 500000 && ms <= 500000);
	    int ans = 0;
	    for(int i = 1; i <= m; i++){
	        ans += mod(ncr(i - 1, n - 1), pw[i]);
	        ans %= md;
	        cout<<ans<<" ";
	    }
	    cout<<"\n";
	}
}

Editorialist's code (PyPy3)
mod = 10**9 + 7
half = (mod + 1) // 2

inv = [pow(i, mod-2, mod) for i in range(500005)]  # index i-n can reach M-1

for _ in range(int(input())):
    n, m = map(int, input().split())
    ans = 0
    pw, ch = 1, 0
    answers = []
    for i in range(1, m+1):
        pw = pw * half % mod
        if i == n: ch = 1
        elif i > n:
            ch = ch * (i-1) % mod * inv[i-n] % mod
        ans += pw * ch
        answers.append(ans % mod)
    print(*answers)

While reading the editorial, I can follow every step, but when I try to solve it myself, I get stuck here.

Let’s compare this problem with the traditional coin-toss problem, where we toss a coin M times.

Let’s write down the basic features of the coin-toss experiment.

=> In total, there are 2^M different outcomes. (Let’s call this U.)
=> Each of these 2^M outcomes has equal probability \frac{1}{2^M}.
=> Let’s define Head as a person falling and Tail as a person not falling.
=> Let the probability of a head be p_{\text{head}} and the probability of a tail be p_{\text{tail}}. Also, p_{\text{head}} + p_{\text{tail}} = 1 (true for any coin toss).

=> Among all the outcomes, we are interested in those with exactly N falls within M coin tosses. We don’t care about the order of these N events, so there are \binom{M}{N} different ways of getting N heads among M coin tosses. (Let’s call this V.)

=> Let’s define P(M,N) as the probability of getting N heads (falls) among M coin tosses.

=> By the definition of probability, P(M,N) = \frac{\text{Number of ways the event occurs}}{\text{Total number of outcomes}}.

=> P(M,N) = \frac{V}{U} = \frac{\binom{M}{N}}{2^M}.

This fails when I submit the code. Can you please help me figure out the error in my logic?
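Evaluating both expressions for the n = 3, m = 5 case makes the mismatch concrete. This is a small sketch with exact fractions; `asker_formula` and `editorial_formula` are hypothetical names for the post's P(M,N) and the editorial's boxed sum.

```python
from fractions import Fraction
from math import comb

def asker_formula(n, m):
    # P(M, N) = C(M, N) / 2^M as proposed in the post (exactly N heads in M tosses)
    return Fraction(comb(m, n), 2 ** m)

def editorial_formula(n, m):
    # Editorial's boxed sum: sum over i = 1..m of C(i-1, n-1) / 2^i
    return sum((Fraction(comb(i - 1, n - 1), 2 ** i) for i in range(1, m + 1)),
               Fraction(0))

print(asker_formula(3, 5), editorial_formula(3, 5))  # 5/16 1/2
```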

In total, there are 2^M different outcomes. (Let’s call this U.)

This is wrong: there are not 2^M outcomes.
Say n = 3 and m = 5. Then you are overcounting with 32 outcomes, which include:

XOXXX
XXOXX
OXXXX
XXXOX
XXXXO
XXXXX
X denotes death, O survival.
These are all impossible (each has more than 3 deaths, but there are only n = 3 people), yet all of them are counted in the 2^M total.
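The list above can be reproduced mechanically. This sketch (the helper name `impossible_outcomes` is hypothetical) enumerates all 2^5 = 32 strings over {X, O} and keeps those with more than n = 3 deaths.

```python
from itertools import product

def impossible_outcomes(n, m):
    """All length-m strings over {X, O} (X = death, O = survival) with more
    than n deaths; none of them can happen with only n participants."""
    return ["".join(s) for s in product("XO", repeat=m) if s.count("X") > n]

bad = impossible_outcomes(3, 5)
print(len(bad))  # 6
```

It returns exactly the six strings listed above.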