PLANUM - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: iiii63027
Tester: jay_1048576
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Basic probability, computing the sum of a geometric progression

PROBLEM:

There are three stacks, initially all empty.
For each 1 \leq i \leq N in order, you place i on top of one of the stacks with equal probability.

Given M, find the probability that the top values of the first two stacks sum to M.
The top value of an empty stack is considered to be 0.

EXPLANATION:

Note that there are 3^N distinct possible final configurations of stacks: each element has 3 choices, independent of all the other elements, so all 3^N configurations are equally likely.
To compute the required probability, we can therefore count the configurations that satisfy the condition and divide this count by 3^N.
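To check the formulas below on tiny inputs, it helps to enumerate all 3^N equally likely placements directly. The following is an illustrative brute force (the name `brute_probability` is hypothetical, not taken from the reference solutions). It runs in exponential time and is only meant for very small N.

```python
from fractions import Fraction
from itertools import product


def brute_probability(n: int, m: int) -> Fraction:
    """Exact probability that the tops of stacks 1 and 2 sum to m (tiny n only)."""
    good = 0
    for choice in product(range(3), repeat=n):  # choice[i-1] = stack receiving i
        tops = [0, 0, 0]  # an empty stack has top value 0
        for value, stack in enumerate(choice, start=1):
            tops[stack] = value  # values arrive in increasing order
        if tops[0] + tops[1] == m:
            good += 1
    return Fraction(good, 3 ** n)


print(brute_probability(2, 3))  # 2/9
```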

Let’s start with a slow solution.

Consider some 0 \leq x \leq N. Let’s count the configurations in which x is the topmost element of the first stack and M-x is the topmost element of the second stack.

Of course, if M-x \gt N there are no such configurations, so we’ll only deal with those x such that M-x \leq N (equivalently, x \geq M-N).

If M-x = x the count is again 0, since each value appears exactly once.
So, suppose M-x \neq x.
Let L = \min(x, M-x) and R = \max(x, M-x), so L \lt R.
Then,

  • For elements 1, 2, 3, \ldots, L-1, it doesn’t really matter which stack they go onto.
    Each of them has 3 options, for 3^{L-1} in total.
    Note that L = 0 (one of the first two stacks is empty) is a special case: there are no such elements and no element 0 to place, so the factor here is just 1 instead of 3^{-1}.
  • L must go to its appropriate stack.
  • Elements L+1, L+2, \ldots, R-1 can’t go onto L's stack, but can go to either of the other two.
    This is 2^{R-L-1} options.
  • R must go to its appropriate stack.
  • Elements R+1, R+2, \ldots, N can only go to the third stack.

So, there are 3^{L-1} \cdot 2^{R-L-1} options in total (or 2^{R-1} when L = 0).
This is easy to compute in \mathcal{O}(\log N) using binary exponentiation.
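The case analysis above can be written as a small counting function that uses exact integers. This sketch assumes M \geq 1, so the degenerate case x = M-x = 0 never arises. The name `count_for_x` is hypothetical, and the `lo == 0` branch implements the L = 0 special case from the list.

```python
def count_for_x(n: int, m: int, x: int) -> int:
    """Placements where stack 1 shows x and stack 2 shows m - x (assumes m >= 1)."""
    y = m - x
    if not (0 <= x <= n and 0 <= y <= n) or x == y:
        return 0  # out of range, or the same value would top two stacks
    lo, hi = min(x, y), max(x, y)  # L and R in the text
    if lo == 0:
        # one of the two stacks stays empty: 1..hi-1 avoid it (2 choices each)
        return 2 ** (hi - 1)
    # 1..lo-1: 3 choices each; lo+1..hi-1: 2 choices each; hi+1..n: third stack only
    return 3 ** (lo - 1) * 2 ** (hi - lo - 1)


print(count_for_x(3, 4, 1))  # 2
```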

Since the total number of valid configurations is obtained by summing this across all valid x, we have a (slow) solution in \mathcal{O}(N\log N).
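Summing those counts over every feasible x gives the slow solution. Here is one possible Python sketch. It assumes the answer is required modulo the prime 998244353 used by the reference codes, and that M \geq 1. The name `slow_answer` is hypothetical.

```python
MOD = 998244353


def slow_answer(n: int, m: int) -> int:
    """O(N log N): add up the per-x counts mod MOD, then divide by 3^n (assumes m >= 1)."""
    total = 0
    # x is the top of stack 1 and y = m - x the top of stack 2; both must lie in [0, n]
    for x in range(max(0, m - n), min(n, m) + 1):
        y = m - x
        if x == y:
            continue  # the same value cannot top two stacks
        lo, hi = min(x, y), max(x, y)
        if lo == 0:
            ways = pow(2, hi - 1, MOD)  # one of the two stacks stays empty
        else:
            ways = pow(3, lo - 1, MOD) * pow(2, hi - lo - 1, MOD) % MOD
        total = (total + ways) % MOD
    # all 3^n placements are equally likely; MOD is prime, so invert via Fermat
    return total * pow(pow(3, n, MOD), MOD - 2, MOD) % MOD
```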


Now, let’s attempt to optimize this.
Suppose 0 \lt x \lt M-x. Then,

3^{x-1} \cdot 2^{(M-x)-x-1} = 3^{x-1} \cdot 2^{M-2x-1} = \frac{2^{M-1}}{3} \cdot 3^x \cdot 2^{-2x} = \frac{2^{M-1}}{3}\cdot\left( \frac{3}{4} \right)^x
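As a sanity check of this rewriting (not part of any solution), exact rational arithmetic with Python's `fractions.Fraction` confirms that both sides agree for small M and every 0 \lt x \lt M-x. The helper names `direct_count` and `rewritten` are illustrative.

```python
from fractions import Fraction


def direct_count(m: int, x: int) -> Fraction:
    """3^(L-1) * 2^(R-L-1) with L = x and R = m - x (valid when 0 < x < m - x)."""
    return Fraction(3) ** (x - 1) * Fraction(2) ** (m - 2 * x - 1)


def rewritten(m: int, x: int) -> Fraction:
    """The same quantity written as (2^(m-1) / 3) * (3/4)^x."""
    return Fraction(2 ** (m - 1), 3) * Fraction(3, 4) ** x


# compare both forms over every 0 < x < m - x for small m
for m in range(2, 40):
    for x in range(1, (m - 1) // 2 + 1):
        assert direct_count(m, x) == rewritten(m, x)

print(direct_count(7, 2), rewritten(7, 2))  # 12 12
```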

Here, \frac{2^{M-1}}{3} is a constant, so we only really need to compute the sum of \left( \frac34 \right)^x for all valid x.
In particular, the constraints on x are:

  • x \gt 0
  • x \geq M-N
  • x \lt M-x, meaning 2x \lt M

This gives a lower bound of \max(1, M-N) and an upper bound of \left\lfloor \frac{M-1}{2} \right\rfloor on x, and every integer between these bounds is valid.
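These bounds can be expressed as a Python `range`. The helper `x_range` is hypothetical. An empty range means that no x falls into the 0 \lt x \lt M-x case, which happens for M \leq 2.

```python
def x_range(n: int, m: int) -> range:
    """All x with 0 < x < m - x and m - x <= n, as a contiguous range."""
    lo = max(1, m - n)  # x > 0 and x >= m - n
    hi = (m - 1) // 2  # largest integer x with 2x < m
    return range(lo, hi + 1)  # empty whenever hi < lo


print(list(x_range(4, 5)))  # [1, 2]
```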

Summing \left( \frac34 \right)^x over a contiguous range of x is exactly the sum of a geometric progression!
This can be computed quickly with the standard closed-form formula.
In particular,

a + ar + ar^2 + \ldots + ar^{k-1} = a\cdot \left(\frac{r^k - 1}{r-1}\right)
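Under a prime modulus, the division by r-1 becomes multiplication by a modular inverse, computed here via Fermat's little theorem. This minimal sketch assumes the modulus 998244353 from the reference codes and r \not\equiv 1. The helper `geometric_sum_mod` and the constant `THREE_QUARTERS` are hypothetical names.

```python
MOD = 998244353


def geometric_sum_mod(a: int, r: int, k: int) -> int:
    """a + a*r + ... + a*r^(k-1) modulo the prime MOD; requires r % MOD != 1."""
    if k <= 0:
        return 0  # empty progression
    numerator = (pow(r, k, MOD) - 1) % MOD
    inv_denominator = pow((r - 1) % MOD, MOD - 2, MOD)  # Fermat inverse of r - 1
    return a % MOD * numerator % MOD * inv_denominator % MOD


THREE_QUARTERS = 3 * pow(4, MOD - 2, MOD) % MOD  # the ratio r = 3/4 used above
print(geometric_sum_mod(1, 2, 5))  # 31
```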

This allows us to solve the x \lt M-x case in \mathcal{O}(\log {MOD}) time (or even \mathcal{O}(1) if you precompute appropriate powers).
Note that x = 0 (i.e., the first stack being empty) is also a valid case, and the number of possibilities for it is 2^{M-1} (assuming M \leq N, of course). This can be handled separately.

The x \gt M-x case is symmetric: swapping x and M-x keeps L and R the same, so the count obtained so far (including the x = 0 term, which mirrors x = M) can simply be multiplied by 2.

As noted at the start, the final answer is this count divided by 3^N; under the prime modulus 998244353 used in the code, that means multiplying by the modular inverse of 3^N.
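Putting the pieces together gives the fast solution. The sketch below follows the same structure as the editorialist's code but is an illustrative version, not the official one. It assumes M \geq 1 and the modulus 998244353, and it returns 0 when M \geq 2N, since two distinct tops of at most N can sum to at most 2N-1. The name `fast_answer` is hypothetical.

```python
MOD = 998244353


def fast_answer(n: int, m: int) -> int:
    """Probability (mod MOD) that the first two tops sum to m; assumes m >= 1."""
    if m >= 2 * n:
        return 0  # the two tops are distinct and at most n, so their sum is <= 2n - 1
    r = 3 * pow(4, MOD - 2, MOD) % MOD  # 3/4 modulo MOD
    lo, hi = max(1, m - n), (m - 1) // 2  # valid x with 0 < x < m - x
    k = max(0, hi - lo + 1)  # number of terms in the progression
    # r^lo + ... + r^hi = r^lo * (r^k - 1) / (r - 1)
    geo = pow(r, lo, MOD) * (pow(r, k, MOD) - 1) % MOD * pow(r - 1, MOD - 2, MOD) % MOD
    # multiply by the constant 2^(m-1) / 3
    half = pow(2, m - 1, MOD) * pow(3, MOD - 2, MOD) % MOD * geo % MOD
    if m <= n:
        half = (half + pow(2, m - 1, MOD)) % MOD  # x = 0: first stack stays empty
    count = 2 * half % MOD  # the x > m - x side (including x = m) is symmetric
    return count * pow(pow(3, n, MOD), MOD - 2, MOD) % MOD  # divide by 3^n


print(fast_answer(2, 3) == 2 * pow(9, MOD - 2, MOD) % MOD)  # True
```

Using `max(0, hi - lo + 1)` keeps the exponent non-negative, so an empty range of x contributes exactly 0.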

TIME COMPLEXITY

\mathcal{O}(\log {MOD}) per testcase.

CODE:

Author's code (C++)
#include<bits/stdc++.h>
#define int long long
#define mod 998244353
using namespace std;

int binpow(int a,int b){
	int res=1;
	while(b>0){
    	if(b&1)res*=a;
    	a*=a;
    	res%=mod;
    	a%=mod;
    	b>>=1;
	}
	return res;
}

signed main()
{
	ios_base::sync_with_stdio(false);
	cin.tie(NULL);
	int t;cin>>t;
	while(t--){
	int n,m;
	cin>>n>>m;
	if(m>=2*n){
    	cout<<-1<<endl;
    	continue;
	}
//Formula is 4*(2^(m-1)/3^(n+1))*((3/4)^max(1,m-n))*(1-(3/4)^((m-1)/2-max(1,m-n)+1)), before adding the x=0 term and doubling
	int ans=4;
	ans=(ans*binpow(2,m-1))%mod;
	ans=(ans*binpow(binpow(3,n+1),mod-2))%mod;
	int minn=max(1ll,m-n);
	ans=(ans*binpow(3,minn))%mod;
	ans=(ans*binpow(binpow(4,minn),mod-2))%mod;
	int len=(m-1)/2-minn+1;
	if(len<0){
	    cout<<-1<<endl;
	    continue;
	}
	ans=(ans*((1+mod-(binpow(3,len)*binpow(binpow(4,len),mod-2))%mod)%mod))%mod;
	if(m<=n)ans+=(binpow(binpow(3,n),mod-2)*binpow(2,m-1))%mod;
	//This is the case when we place nothing on one of the stacks.
	cout<<(ans*2)%mod<<endl;
	}
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF 1000000000000000000
#define MOD 998244353

int power(int a,int b)
{
    if(b==0)
        return 1;
    else
    {
        int x=power(a,b/2);
        int y=(x*x)%MOD;
        if(b%2)
            y=(y*a)%MOD;
        return y;
    }
}

int inverse(int a)
{
    return power(a,MOD-2);
}

void solve(int tc)
{
    int n,m;
    cin >> n >> m;
    int l = max(1ll,m-n), r = min(n,(m-1)/2);
    if(m>=2*n)   // two distinct tops, each at most n, sum to at most 2n-1
    {
        cout << 0 << '\n';
        return;
    }
    int ans = (1-power((3*inverse(4))%MOD,r-l+1)+MOD)%MOD;
    ans = (ans*power(2,m-2*l+2))%MOD;
    ans = (ans*inverse(power(3,n-l+1)))%MOD;
    if(m<=n)
        ans = (ans+power(2,m)*inverse(power(3,n)))%MOD;
    cout << ans << '\n';
}

int32_t main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);
    int tc=1;
    cin >> tc;
    for(int ttc=1;ttc<=tc;ttc++)
        solve(ttc);
    return 0;
}
Editorialist's code (Python)
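mod = 998244353
r = 3 * pow(4, mod-2, mod) % mod  # r = 3/4 modulo mod

for _ in range(int(input())):
    n, m = map(int, input().split())
    # valid x for the case 0 < x < m-x: x >= max(1, m-n) and 2x < m
    lo, hi = max(1, m-n), (m+1)//2 - 1
    k = max(0, hi-lo+1)  # number of terms; 0 if the range is empty

    # r^lo + r^(lo+1) + ... + r^hi, via the geometric progression formula
    ans = pow(r, lo, mod) * (pow(r, k, mod) - 1) % mod * pow(r-1, mod-2, mod) % mod
    # multiply by 2^(m-1)/3, and by 2 for the symmetric x > m-x case: 2^(m+1)/6 in total
    ans = ans * pow(2, m+1, mod) % mod * pow(6, mod-2, mod) % mod
    # x = 0 and x = m (one of the first two stacks empty): 2^(m-1) ways each
    if m <= n: ans += pow(2, m, mod)
    # divide by 3^n: (3^n)^(mod-2) is its modular inverse
    print(ans * pow(3, n*(mod-2), mod) % mod)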

Here r = 3/4 < 1, so how can the summation have r^k - 1?

The formula has \frac{r^k - 1}{r-1}.
When 0 \lt r \lt 1, the numerator and denominator are both negative, so their ratio is positive.

The link I provided here also goes into a derivation of the formula (which is very short, just some simple algebraic manipulation).
You’ll notice that it doesn’t care whether r \lt 1 or r \gt 1, only that r \neq 1 so that dividing by r-1 is valid.


In the line:

For elements $1, 2, 3, \ldots, L-1$, it doesn’t really matter which stack they go onto.
Each of them has $3$ options, for $3^L$ in total

Shouldn’t the options be 3^{L-1} instead? Please explain if I’m missing something.


Isn't that 3^{L-1}?

@rpriydarshi @toomatho
You’re right, I had a typo there.
It should be fixed (and formulas throughout the rest of the editorial updated) now.

The only major difference is that x = 0 needs to be special-cased now, which I’ve made a note of - somehow I’d forgotten I did that in my code.