PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: iiii63027
Tester: jay_1048576
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
Basic probability, computing the sum of a geometric progression
PROBLEM:
There are three stacks, initially all empty.
For each 1 \leq i \leq N in order, you place i on top of one of the stacks with equal probability.
Given M, find the probability that the top values of the first two stacks sum to M (the codes below compute it modulo 998244353).
The top value of an empty stack is considered to be 0.
EXPLANATION:
Note that there are 3^N distinct possible final configurations of the stacks: each element has 3 choices, independent of all the other elements, so all 3^N configurations are equally likely.
To compute the required probability, we can therefore count the configurations that satisfy the required condition and divide this count by 3^N.
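For tiny N, this definition can be checked directly. The following is a minimal Python sketch (the name `brute_force_probability` is our own, not part of the problem) that enumerates all 3^N placements, simulates the stacks, and returns the exact probability as a `Fraction`. It takes \mathcal{O}(3^N \cdot N) time, so it is only useful as a reference for testing faster solutions.

```python
from fractions import Fraction
from itertools import product


def brute_force_probability(n: int, m: int) -> Fraction:
    """Exact probability by enumerating all 3^n placements (tiny n only)."""
    good = 0
    # choice[i] is the stack (0, 1 or 2) that element i + 1 is placed on.
    for choice in product(range(3), repeat=n):
        tops = [0, 0, 0]  # top value of each stack; 0 means empty
        for i, stack in enumerate(choice):
            tops[stack] = i + 1  # elements arrive in increasing order
        if tops[0] + tops[1] == m:
            good += 1
    return Fraction(good, 3 ** n)


print(brute_force_probability(3, 3))  # 10/27
```

For N = 3 and M = 3, the valid top pairs are (1, 2), (2, 1), (0, 3) and (3, 0), giving 1 + 1 + 4 + 4 = 10 of the 27 configurations.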
Let’s start with a slow solution.
Consider some 0 \leq x \leq N. Let’s count the configurations in which x is the topmost element of the first stack and M-x is the topmost element of the second stack (a top value of 0 meaning the stack is empty).
Of course if M-x \gt N the probability is 0, so we’ll only deal with those x such that M-x \leq N (equivalently, x \geq M-N).
If M-x = x the probability is again 0, since we only have one of each element.
So, suppose M-x \neq x.
In particular, let L = \min(x, M-x) and R = \max(x, M-x), so L \lt R.
Then,
- For elements 1, 2, 3, \ldots, L-1, it doesn’t really matter which stack they go onto.
Each of them has 3 options, for 3^{L-1} in total.
Note that L = 0 is a special case: there are no such elements, so instead of 3^{-1} the factor here is just 1 (and "L’s stack" below refers to the stack that must stay empty).
- L must go to its appropriate stack.
- Elements L+1, L+2, \ldots, R-1 can’t go onto L’s stack, but can go to either of the other two.
This is 2^{R-L-1} options.
- R must go to its appropriate stack.
- Elements R+1, R+2, \ldots, N can only go to the third stack.
So, there are 3^{L-1} \cdot 2^{R-L-1} options in total when L \geq 1, and 2^{R-1} options when L = 0.
This is easy to compute in \mathcal{O}(\log N) using binary exponentiation.
Since the total number of valid configurations is obtained by summing this across all valid x, we have a (slow) solution in \mathcal{O}(N\log N).
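A sketch of this slow approach in Python, assuming M \geq 1 (so that x = M-x can only happen with both values positive). The helper names `count_slow` and `probability_slow` are illustrative; the count is taken modulo 998244353 as in the codes below, and the division by 3^N uses a Fermat inverse.

```python
MOD = 998244353


def count_slow(n: int, m: int) -> int:
    """Number of valid configurations mod MOD, summed over x (O(n log n))."""
    total = 0
    for x in range(0, min(n, m) + 1):
        y = m - x  # required top of the second stack
        if y > n or y == x:
            continue  # impossible: value too large, or the same element twice
        lo, hi = min(x, y), max(x, y)
        if lo == 0:
            # One stack stays empty: 1..hi-1 avoid it, hi+1..n go to stack 3.
            ways = pow(2, hi - 1, MOD)
        else:
            ways = pow(3, lo - 1, MOD) * pow(2, hi - lo - 1, MOD) % MOD
        total = (total + ways) % MOD
    return total


def probability_slow(n: int, m: int) -> int:
    """Probability modulo MOD: the count times the inverse of 3^n."""
    return count_slow(n, m) * pow(pow(3, n, MOD), MOD - 2, MOD) % MOD


print(count_slow(3, 3))  # 10
```

For N = 3 and M = 3, the pairs (1, 2) and (2, 1) contribute 3^0 \cdot 2^0 = 1 each, while (0, 3) and (3, 0) contribute 2^2 = 4 each.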
Now, let’s attempt to optimize this.
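Suppose 0 \lt x \lt M-x. Then L = x and R = M-x, so the count from above is
3^{x-1} \cdot 2^{(M-x)-x-1} = 3^{x-1} \cdot 2^{M-2x-1} = \frac{2^{M-1}}{3} \cdot \frac{3^x}{4^x} = \frac{2^{M-1}}{3} \cdot \left( \frac34 \right)^x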
Here, \frac{2^{M-1}}{3} is a constant, so we only really need to compute the sum of \left( \frac34 \right)^x for all valid x.
In particular, the constraints on x are:
- x \gt 0
- x \geq M-N
- x \lt M-x, meaning 2x \lt M
Together these give \max(1, M-N) \leq x \leq \left\lfloor \frac{M-1}{2} \right\rfloor, and every x between these bounds is valid (x \leq N holds automatically, since x \lt M-x \leq N).
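The bounds can be written down directly. In this Python sketch, the hypothetical helper `x_range` returns them, `x_range_by_search` recomputes the same set by testing the three constraints one at a time (a cheap way to double-check the algebra), and `rewrite_holds` checks the identity 3^{x-1} \cdot 2^{M-2x-1} = \frac{2^{M-1}}{3} \cdot \left( \frac34 \right)^x with exact fractions.

```python
from fractions import Fraction
from typing import List, Tuple


def x_range(n: int, m: int) -> Tuple[int, int]:
    """Bounds (lo, hi) on x in the case 0 < x < m - x; empty when lo > hi."""
    lo = max(1, m - n)   # x > 0 and x >= m - n
    hi = (m - 1) // 2    # 2x < m  <=>  x <= floor((m - 1) / 2)
    return lo, hi


def x_range_by_search(n: int, m: int) -> List[int]:
    """The same x values, found by testing each constraint separately."""
    return [x for x in range(m + 1) if x > 0 and x >= m - n and 2 * x < m]


def rewrite_holds(m: int, x: int) -> bool:
    """Check 3^(x-1) * 2^(m-2x-1) == (2^(m-1) / 3) * (3/4)^x exactly."""
    direct = Fraction(3 ** (x - 1) * 2 ** (m - 2 * x - 1))
    rewritten = Fraction(2 ** (m - 1), 3) * Fraction(3, 4) ** x
    return direct == rewritten


print(x_range(10, 15))  # (5, 7)
```

For instance, with N = 10 and M = 15 the valid range is 5 \leq x \leq 7, and with N = 2, M = 2 it is empty (lo = 1 \gt hi = 0).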
Computing the sum of \left( \frac34 \right)^x for a contiguous range of x means we really just want the sum of a geometric progression!
This can be computed quickly using the closed-form formula for the sum of a geometric series.
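In particular, with r = \frac34, lo = \max(1, M-N) and hi = \left\lfloor \frac{M-1}{2} \right\rfloor,
\sum_{x=lo}^{hi} r^x = r^{lo} \cdot \frac{r^{hi-lo+1} - 1}{r - 1}
and the sum is 0 when lo \gt hi.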
This allows us to solve the x \lt M-x case in \mathcal{O}(\log {MOD}) time (or even \mathcal{O}(1) if you precompute appropriate powers).
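A possible Python sketch of the geometric-series step in modular arithmetic, assuming the modulus 998244353 used by the codes below (it is prime, so inverses come from Fermat's little theorem). `geometric_sum_mod` is an illustrative name; it requires r - 1 to be invertible, which holds for r = 3 \cdot 4^{-1} since r - 1 \equiv -4^{-1}.

```python
MOD = 998244353  # prime modulus, as in the codes below


def inv(a: int) -> int:
    """Modular inverse by Fermat's little theorem (valid since MOD is prime)."""
    return pow(a, MOD - 2, MOD)


def geometric_sum_mod(r: int, lo: int, hi: int) -> int:
    """r^lo + r^(lo+1) + ... + r^hi modulo MOD; assumes r - 1 is invertible."""
    if hi < lo:
        return 0  # empty range of x
    terms = hi - lo + 1
    numerator = (pow(r, terms, MOD) - 1) % MOD
    return pow(r, lo, MOD) * numerator % MOD * inv((r - 1) % MOD) % MOD


R = 3 * inv(4) % MOD  # the ratio 3/4, represented modulo MOD
print(geometric_sum_mod(2, 0, 3))  # 15 = 1 + 2 + 4 + 8
```

Comparing against a direct loop such as `sum(pow(R, x, MOD) for x in range(lo, hi + 1)) % MOD` on small ranges is an easy way to test it; the closed form needs only a constant number of modular exponentiations, hence \mathcal{O}(\log {MOD}).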
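Note that x = 0 (i.e. the first stack being empty) is also a valid case, and the number of possibilities for it is 2^{M-1} (assuming M \leq N, of course): elements 1, \ldots, M-1 each go onto the second or third stack, M goes onto the second stack, and every larger element goes onto the third stack. This can be handled separately.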
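The 2^{M-1} count for x = 0 can be sanity-checked by enumeration for small N. The function below (`count_first_empty`, a name chosen here for illustration) counts the placements where the first stack ends empty and the second stack's top is M.

```python
from itertools import product


def count_first_empty(n: int, m: int) -> int:
    """Placements where stack 1 ends empty and stack 2 has top value m."""
    count = 0
    for choice in product(range(3), repeat=n):
        tops = [0, 0, 0]  # 0 means the stack is empty
        for i, stack in enumerate(choice):
            tops[stack] = i + 1  # element i + 1 becomes the new top
        if tops[0] == 0 and tops[1] == m:
            count += 1
    return count


# For 1 <= m <= n the enumeration should match 2^(m-1).
for n in range(1, 6):
    for m in range(1, n + 1):
        assert count_first_empty(n, m) == 2 ** (m - 1)
```

For example, `count_first_empty(4, 3)` is 4: elements 1 and 2 each pick the second or third stack, 3 goes on the second stack and 4 on the third.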
The x \gt M-x case is symmetric, since swapping x and M-x keeps the L and R values the same. So the count obtained so far (including the x = 0 case, whose mirror is x = M) can simply be multiplied by 2.
As noted at the start, the final answer is obtained by dividing this count by 3^N (modulo 998244353, this means multiplying by the modular inverse of 3^N).
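To tie the steps together, here is a Python sketch that evaluates the whole closed form with exact `Fraction`s (rather than modulo 998244353) and compares it with brute-force enumeration for small N. The names `closed_form` and `brute_force` are ours, and M \geq 1 is assumed.

```python
from fractions import Fraction
from itertools import product


def closed_form(n: int, m: int) -> Fraction:
    """Probability via the closed form, using exact rational arithmetic."""
    r = Fraction(3, 4)
    lo, hi = max(1, m - n), (m - 1) // 2
    count = Fraction(0)  # configurations with 0 <= x < m - x
    if lo <= hi:
        # (2^(m-1) / 3) * (r^lo + ... + r^hi), by the geometric-series formula
        geo = r ** lo * (r ** (hi - lo + 1) - 1) / (r - 1)
        count += Fraction(2 ** (m - 1), 3) * geo
    if m <= n:
        count += 2 ** (m - 1)  # x = 0: the first stack stays empty
    count *= 2  # the cases with x > m - x mirror those counted above
    return count / 3 ** n


def brute_force(n: int, m: int) -> Fraction:
    """Exact probability by enumerating all 3^n placements."""
    good = 0
    for choice in product(range(3), repeat=n):
        tops = [0, 0, 0]
        for i, stack in enumerate(choice):
            tops[stack] = i + 1
        if tops[0] + tops[1] == m:
            good += 1
    return Fraction(good, 3 ** n)


# Compare on every m that can occur for small n (m >= 2n gives probability 0).
for n in range(1, 6):
    for m in range(1, 2 * n + 2):
        assert closed_form(n, m) == brute_force(n, m), (n, m)
print("closed form agrees with brute force for n <= 5")
```

Exact fractions avoid any doubt about modular inverses while checking the algebra; a submission would perform the same steps modulo 998244353.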
TIME COMPLEXITY
\mathcal{O}(\log {MOD}) per testcase.
CODE:
Author's code (C++)
#include<bits/stdc++.h>
#define int long long
#define mod 998244353
using namespace std;
// Binary exponentiation: returns a^b mod p.
int binpow(int a,int b){
    int res=1;
    while(b>0){
        if(b&1)res*=a;
        a*=a;
        res%=mod;
        a%=mod;
        b>>=1;
    }
    return res;
}
signed main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    int t;cin>>t;
    while(t--){
        int n,m;
        cin>>n>>m;
        if(m>=2*n){
            // The largest possible sum is N + (N-1) = 2N-1, so the probability is 0.
            cout<<0<<endl;
            continue;
        }
        // Geometric part: 4 * 2^(m-1) / 3^(n+1) * (3/4)^minn * (1 - (3/4)^len),
        // where minn = max(1, m-n) and len = (m-1)/2 - minn + 1 is the number of valid x.
        int ans=4;
        ans=(ans*binpow(2,m-1))%mod;
        ans=(ans*binpow(binpow(3,n+1),mod-2))%mod;
        int minn=max(1ll,m-n);
        ans=(ans*binpow(3,minn))%mod;
        ans=(ans*binpow(binpow(4,minn),mod-2))%mod;
        int len=(m-1)/2-minn+1;
        if(len<0){
            // Not reachable once m < 2n (then len >= 0).
            cout<<-1<<endl;
            continue;
        }
        ans=(ans*((1+mod-(binpow(3,len)*binpow(binpow(4,len),mod-2))%mod)%mod))%mod;
        // This is the case when we place nothing on one of the stacks (x = 0).
        if(m<=n)ans+=(binpow(binpow(3,n),mod-2)*binpow(2,m-1))%mod;
        // Multiply by 2 for the symmetric case x > m - x.
        cout<<(ans*2)%mod<<endl;
    }
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF 1000000000000000000
#define MOD 998244353
// Recursive binary exponentiation: a^b mod MOD.
int power(int a,int b)
{
    if(b==0)
        return 1;
    else
    {
        int x=power(a,b/2);
        int y=(x*x)%MOD;
        if(b%2)
            y=(y*a)%MOD;
        return y;
    }
}
int inverse(int a)
{
    return power(a,MOD-2);
}
void solve(int tc)
{
    int n,m;
    cin >> n >> m;
    // Valid range [l, r] of x for the case 0 < x < m - x.
    int l = max(1ll,m-n), r = min(n,(m-1)/2);
    int ans = 0;
    if(l <= r)
    {
        // 2 * (2^(m-1)/3) * sum_{x=l}^{r} (3/4)^x / 3^n, simplified.
        ans = (1-power((3*inverse(4))%MOD,r-l+1)+MOD)%MOD;
        ans = (ans*power(2,m-2*l+2))%MOD;
        ans = (ans*inverse(power(3,n-l+1)))%MOD;
    }
    // x = 0 and its mirror x = m: 2 * 2^(m-1) configurations.
    if(m<=n)
        ans = (ans+power(2,m)*inverse(power(3,n)))%MOD;
    cout << ans << '\n';
}
int32_t main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);
    int tc=1;
    cin >> tc;
    for(int ttc=1;ttc<=tc;ttc++)
        solve(ttc);
    return 0;
}
Editorialist's code (Python)
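mod = 998244353
r = 3 * pow(4, mod-2, mod) % mod  # the ratio 3/4 modulo mod
for _ in range(int(input())):
    n, m = map(int, input().split())
    lo, hi = max(1, m-n), (m+1)//2 - 1
    # r^lo + r^(lo+1) + ... + r^hi (zero when the range is empty)
    ans = pow(r, lo, mod) * (pow(r, max(0, hi-lo+1), mod) - 1) % mod * pow(r-1, mod-2, mod) % mod
    # multiply by 2 * 2^(m-1) / 3 = 2^(m+1) / 6
    ans = ans * pow(2, m+1, mod) % mod * pow(6, mod-2, mod) % mod
    # x = 0 and its mirror x = m: 2 * 2^(m-1) = 2^m configurations
    if m <= n: ans += pow(2, m, mod)
    # divide by 3^n: (3^(mod-2))^n is the inverse of 3^n modulo mod
    print(ans * pow(3, n*(mod-2), mod) % mod)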