PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: alpha_ashwin, shanmukh29
Tester: raysh07
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
Elementary probability
PROBLEM:
There are B boxes, each containing M cheese balls and N chicken bones.
For each i from 1 to B-1, i random items are moved from box i to box i+1.
At the end of the process, find the expected number of cheese balls in the K-th box.
EXPLANATION:
The process seems like it gets complicated very quickly, but it really doesn’t.
The key here is to look at things from a global perspective, and the final answer turns out to be quite simple!
The main observation to be made here is that, since all the boxes start with an equal ratio of cheese balls to chicken bones, moving around a few items randomly doesn’t really change that ratio (or rather, the expected ratio) at all.
Intuitively, you can see that:
- The expected number of cheese balls moved from box 1 to box 2 is \frac{M}{M+N}, since the single moved item is a cheese ball with probability \frac{M}{M+N}.
  This keeps the expected cheese-chicken ratio of both boxes 1 and 2 at \frac{M}{N}, which can be checked with some simple algebra: for instance, box 1 is left with M - \frac{M}{M+N} = \frac{M(M+N-1)}{M+N} cheese balls and, symmetrically, \frac{N(M+N-1)}{M+N} chicken bones in expectation.
- Next, the expected number of cheese balls moved from box 2 to box 3 is 2\cdot \frac{M}{M+N}, because the ratio hasn’t changed.
  Once again, the expected cheese-chicken ratio of both boxes remains the same: \frac{M}{N}.
- This process continues until the end, and the expected cheese-chicken ratio of every box is still \frac{M}{N}.
Note that this is only true because M, N\geq B, so we can never ‘run out’ of chicken and/or cheese.
This already gives us the answer!
If the expected cheese-chicken ratio is \frac{M}{N}, then the expected number of cheese balls is simply \frac{M}{M+N} multiplied by the number of items.
So,
- If K = B, there are N+M+B-1 items in the last box, and the answer is
  \frac{M}{M+N}\cdot (N+M+B-1).
- If 1 \leq K \lt B, there are N+M-1 items in the box (K-1 items were added to it from the previous box, and K were removed from it and given to the next one, for a net change of -1), and the answer is
  \frac{M}{M+N}\cdot (N+M-1).
TIME COMPLEXITY:
\mathcal{O}(1) per testcase.
CODE:
Author's code (Python)
# Author's solution: read T test cases "B M N K" and print the expected
# number of cheese balls that end up in box K.
t = int(input())
for i in range(t):
    inputline = list(map(int, input().split()))
    x, m, n, k = inputline[0], inputline[1], inputline[2], inputline[3]
    # With a single box nothing is ever moved, so all M cheese balls stay put.
    if x == 1:
        print(str(m))
        continue
    # The last box gains B-1 items; every other box has a net change of -1.
    items = m + n + x - 1 if k == x else m + n - 1
    # Each item is cheese with probability M/(M+N).
    print(str(items * m / (m + n)))
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
// Competitive-programming shorthand: every `int` token after this line is
// rewritten to `long long`. `int32_t main()` below avoids the rewrite for
// main's return type because `int32_t` is a different token.
#define int long long
// NOTE(review): INF, f, s and RNG are not referenced anywhere in this program.
// The one-letter macros `f` / `s` also rewrite ANY later identifier named
// `f` or `s` (e.g. a local variable) into `first` / `second`, so avoid those
// names in the code below.
#define INF (int)1e18
#define f first
#define s second
// 64-bit Mersenne Twister seeded from the current steady_clock tick count.
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());
/// Strict validator for the judge input: reads all of stdin up front and then
/// exposes token-level readers that assert on any deviation from the format.
struct input_checker {
    std::string buffer;  // entire contents of stdin
    int pos;             // index of the next unread character in `buffer`
    const std::string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const std::string number = "0123456789";
    const std::string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const std::string lower = "abcdefghijklmnopqrstuvwxyz";

    /// Reads every character of std::cin into `buffer` and starts parsing at 0.
    input_checker() {
        pos = 0;
        while (true) {
            int c = std::cin.get();
            if (c == std::char_traits<char>::eof()) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    /// @return index of the first whitespace character at or after `pos`,
    ///         or buffer.size() if there is none.
    int nextDelimiter() {
        int now = pos;
        // isspace() on a negative char value is undefined behavior, so the
        // byte is converted to unsigned char first.
        while (now < (int) buffer.size() &&
               !std::isspace(static_cast<unsigned char>(buffer[now]))) {
            now++;
        }
        return now;
    }

    /// Consumes and returns the run of non-whitespace characters at `pos`
    /// (empty if `pos` currently points at whitespace).
    std::string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        std::string res = buffer.substr(pos, nxt - pos);
        pos = nxt;
        return res;
    }

    /// @return true iff `token` is an optional '-' followed by one or more
    ///         decimal digits and nothing else.
    static bool isIntegerToken(const std::string &token) {
        std::size_t start = (!token.empty() && token[0] == '-') ? 1 : 0;
        if (start == token.size()) {
            return false;  // empty token, or a lone "-"
        }
        for (std::size_t i = start; i < token.size(); ++i) {
            if (!std::isdigit(static_cast<unsigned char>(token[i]))) {
                return false;
            }
        }
        return true;
    }

    /// Reads a token whose length is in [minl, maxl]; if `pattern` is
    /// non-empty, every character must occur in it.
    std::string readString(int minl, int maxl, const std::string &pattern = "") {
        assert(minl <= maxl);
        std::string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != std::string::npos);
        }
        return res;
    }

    /// Reads an integer token in [minv, maxv].
    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        std::string token = readOne();
        // std::stoi alone silently accepts "+5" or "12abc" (it parses a
        // prefix); a validator must reject such malformed tokens.
        assert(isIntegerToken(token));
        int res = std::stoi(token);
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    /// Reads a 64-bit integer token in [minv, maxv].
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        std::string token = readOne();
        assert(isIntegerToken(token));  // same strictness as readInt
        long long res = std::stoll(token);
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    /// Reads `n` space-separated integers, each in [minv, maxv].
    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        std::vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    /// Reads `n` space-separated 64-bit integers, each in [minv, maxv].
    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        std::vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    /// Consumes exactly one ' ' character.
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    /// Consumes exactly one '\n' character.
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    /// Asserts that the whole input has been consumed.
    void readEof() {
        assert((int) buffer.size() == pos);
    }
};
// Global checker; being a global, its constructor (which reads all of
// stdin) runs before main() starts.
input_checker inp{};
// Constraint bounds: number of test cases, and the upper limit on B, M, N.
constexpr int T = 100000;
constexpr int N = 1000000;
void Solve()
{
    // One test case: B (number of boxes), M (cheese balls per box),
    // N (chicken bones per box) and K (the box being queried).
    // The checker enforces M, N >= B and 1 <= K <= B.
    int boxes = inp.readInt(1, N);
    inp.readSpace();
    int cheese = inp.readInt(boxes, N);
    inp.readSpace();
    int bones = inp.readInt(boxes, N);
    inp.readSpace();
    int target = inp.readInt(1, boxes);
    inp.readEoln();

    // Final size of box K: the last box gains B-1 items, every other box
    // receives K-1 and gives away K (net change of -1).
    int items = (target == boxes) ? cheese + bones + boxes - 1 : cheese + bones - 1;

    // Each item is a cheese ball with probability M/(M+N); by linearity the
    // expected count is that probability times the number of items.
    double expected = static_cast<double>(cheese) / (cheese + bones);
    expected *= items;
    cout << fixed << setprecision(10) << expected << "\n";
}
int32_t main()
{
    // Wall-clock timer for the whole run; the result goes to stderr only.
    const auto startTime = std::chrono::high_resolution_clock::now();
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    // First line: the number of test cases, alone on its line.
    int t = inp.readInt(1, T);
    inp.readEoln();
    for (int done = 0; done < t; ++done) {
        Solve();
    }
    // Assert that nothing follows the last test case.
    inp.readEof();

    const auto stopTime = std::chrono::high_resolution_clock::now();
    const auto elapsedNs =
        std::chrono::duration_cast<std::chrono::nanoseconds>(stopTime - startTime);
    std::cerr << "Time measured: " << elapsedNs.count() * 1e-9 << " seconds.\n";
    return 0;
}
Editorialist's code (Python)
# Editorialist's solution: for each test "B M N K", print the expected number
# of cheese balls in box K.
for _ in range(int(input())):
    b, m, n, k = map(int, input().split())
    # The last box gains B-1 items; every other box has a net change of -1.
    total = m + n + (b - 1 if k == b else -1)
    # Each item is a cheese ball with probability M/(M+N).
    print(total * m / (m + n))