PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: helloLad
Tester: wasd2401
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
Combinatorics
PROBLEM:
There are N bags, each containing M balls numbered 1 to M.
You will choose exactly one ball from each bag, and arrange them randomly in a row.
Find the number of arrangements such that:
- There exists a ball labelled K appearing in this row, such that the number of balls appearing before it with values \geq K, is \lt K.
EXPLANATION:
There are a couple of different solutions to this task.
Solution 1
Let's look at how the balls can be chosen, in terms of which have "large" labels (values \geq K) and which have "small" ones (values \lt K).
Out of the N balls, suppose x of them have values \geq K.
The positions of these x balls can be chosen in \binom{N}{x} ways.
Then,
- The other N-x positions can take any values less than K.
- This gives a total of (K-1)^{N-x} choices.
- For the x positions with values \geq K, we need to ensure that:
- At least one of them contains the value K.
- The leftmost occurrence of K is somewhere within the first K of these x positions; that ensures this leftmost K has fewer than K elements \geq K before it, as required.
The second condition is a bit annoying to deal with directly, so in this case, we can instead count the number of bad arrangements and subtract them from the total number of arrangements.
That is, we'll count the number of arrangements that either don't contain K at all, or which contain K but not within the first K of these x positions.
You may verify that:
- The total number of choices is (M-K+1)^x, freely assigning something between K and M to each index.
- The number of arrangements in which K isn't present at all is (M-K)^x.
- For K to be present, but not within the first K of these x positions,
- The first K of them should have values \gt K, for (M-K)^K choices in total.
- Then, we again count the total number of choices for all other indices and subtract from that the count of ways where no K is present: namely, (M-K+1)^{x-K} - (M-K)^{x-K}.
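- Note that this case doesn't need to be considered when x \leq K.
All together, the answer we get is
$$\sum_{x=1}^{N} \binom{N}{x} (K-1)^{N-x} \left[(M-K+1)^x - (M-K)^x - [x \gt K]\,(M-K)^K\left((M-K+1)^{x-K} - (M-K)^{x-K}\right)\right],$$
where [x \gt K] is 1 if x \gt K and 0 otherwise (the x = 0 term is zero, so the sum can start at 1). For x \gt K the bracket simplifies to (M-K+1)^x - (M-K)^K (M-K+1)^{x-K}.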
For a fixed x this is just a binomial coefficient and a few powers, so we get a solution in \mathcal{O}(N\log N) time (or even linear, if you precompute powers).
Solution 2
Let's fix the leftmost occurrence of K, say at index i. Then,
- Everything after i can be chosen freely, for M^{N-i} choices.
- K can't occur before i, and we need to ensure that fewer than K elements before i are \gt K.
To deal with the latter, letâs fix x, the number of elements \gt K before i.
Then,
- Their positions can be chosen in \binom{i-1}{x} ways.
- Each of these x indices has M-K choices for its value, for (M-K)^x ways.
- The other indices before i have values \lt K, for (K-1)^{i-1-x} ways in total.
So, our answer is
$$\sum_{i=1}^{N} M^{N-i} \sum_{x=0}^{K-1} \binom{i-1}{x} (M-K)^x (K-1)^{i-1-x}.$$
This is quadratic, of course, but can be sped up surprisingly easily: use Pascal's identity!
Specifically, define S_i = \sum_{x=0}^{K-1} \binom{i-1}{x} (M-K)^x (K-1)^{i-1-x}, so that we're trying to compute
\sum_{i=1}^N M^{N-i} S_i.
Now, apply Pascal's identity, so that \binom{i-1}{x} = \binom{i-2}{x} + \binom{i-2}{x-1}.
Substituting this into the expression for S_i, you can see that S_i can be computed in terms of S_{i-1} instead.
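Specifically, for i \leq K the binomial sum is complete, so S_i = (M-1)^{i-1}; while for i \gt K we have
$$S_i = (K-1)\,S_{i-1} + (M-K)\left(S_{i-1} - \binom{i-2}{K-1}(M-K)^{K-1}(K-1)^{i-1-K}\right).$$
This way, all the S_i values can be computed in \mathcal{O}(N) or \mathcal{O}(N\log N), and we're done.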
TIME COMPLEXITY:
\mathcal{O}(N) per testcase.
CODE:
Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
// Unties C++ streams from C stdio and from each other; expanded as a statement at the top of main.
#define IOS ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0);
// Shorthand for long long, used for every integer below.
#define ll long long
// Size of the factorial tables; ncr(n, r) reads fact[n], so every n passed to it must be < N.
const ll N=1e6+10;
// Prime modulus; inverse_factorial() relies on its primality (Fermat inverse via exponent M-2).
const ll M=1000000007;
// fact[i] = i! mod M, filled by factorial().
ll fact[N];
// fact_inverse[i] = (i!)^{-1} mod M, filled by inverse_factorial() after factorial().
ll fact_inverse[N];
// Computes a^b mod M by binary exponentiation.
// The base is first reduced into [0, M). Without that, a*a overflows ll for
// bases above ~3e9, and a negative base would produce negative residues.
// b should be >= 0; a negative exponent is treated as 0 (returns 1) instead of
// looping forever, which is what `b>>=1` on a negative value would do.
ll binExp(ll a,ll b){
    a%=M;
    if(a<0){
        a+=M;
    }
    ll ans=1;
    while(b>0){
        if(b&1){
            ans=(ans*a)%M;
        }
        a=(a*a)%M;
        b>>=1;
    }
    return ans;
}
void factorial(){
fact[0]=1;
for(ll i=1;i<N;++i){
fact[i]=(fact[i-1]*1LL*i)%M;
}
}
// Fills fact_inverse[i] = (i!)^{-1} mod M for every 0 <= i < N.
// Requires factorial() to have run first.
// Only the largest factorial is inverted with Fermat's little theorem (M is prime,
// and N-1 < M, so fact[N-1] != 0 mod M). The rest follow from
// (i-1)!^{-1} = i!^{-1} * i, which is O(N) instead of one exponentiation per entry.
void inverse_factorial(){
    fact_inverse[N-1]=binExp(fact[N-1],M-2);
    for(ll i=N-1;i>=1;--i){
        fact_inverse[i-1]=(fact_inverse[i]*i)%M;
    }
}
// Binomial coefficient C(n, r) mod M using the precomputed factorial tables.
// Returns 0 when r is outside [0, n], which also covers negative n, so no table
// is indexed out of bounds. Requires n < N.
ll ncr(ll n,ll r){
    if(r<0||r>n){
        return 0;
    }
    return (fact[n]*((fact_inverse[r]*fact_inverse[n-r])%M))%M;
}
int main(){
IOS
factorial();
inverse_factorial();
int t;
cin>>t;
while(t--){
ll n,m,k;
cin>>n>>m>>k;
ll ans=0;
for(ll i=1;i<=n;++i){
ll cnt=1;
// selecting bags from which these i balls are selected
cnt*=ncr(n,i);
cnt%=M;
// selecting balls which are numbered less than k
cnt*=binExp(k-1,n-i);
cnt%=M;
// selecting i positions in the final arrangement
cnt*=ncr(n,i);
cnt%=M;
// arranging the selected n-i balls numbered less than k in remaining n-i positions in the final arrangement
cnt*=fact[n-i];
cnt%=M;
if(i<=k){
// selecting balls which are numbered greater than or equal to k while ensuring at least one k-numbered ball is selected
cnt*=(binExp(m-k+1,i)-binExp(m-k,i)+M);
cnt%=M;
}else{
// selecting balls such that at least one k-numbered ball exists
// so that number of balls numbered greater than k before it is less than ks
cnt*=(binExp(m-k+1,i)-(((binExp(m-k,k)*binExp(m-k+1,i-k))%M))+M);
cnt%=M;
}
// arranging the selected i balls numbered greater than k in selected i positions in the final arrangement
cnt*=fact[i];
cnt%=M;
ans+=cnt;
ans%=M;
}
ans*=binExp(fact[n],M-2);
ans%=M;
cout<<ans<<endl;
}
return 0;
}
Tester's code (C++)
/*
* * * *** * *****
* * * * * * * *
* * * *** ***** *
* * * * * * * * *
* * * * * * **
*
* *
*****
* *
***** * *
_* *_
| * * * * | ***
|_* _ *_| * *
* * *
***** * **
* * ***
{===* *===}
* IS * ***
* IT * * *
* RATED?* *
* * * **
* * ***
* *
***** * *
* *
* *
* *
***
*/
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace __gnu_pbds;
using namespace std;
// Order-statistics set of ll (policy-based red-black tree); not referenced elsewhere in this snippet.
#define osl tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update>
// Shorthand types.
#define ll long long
#define ld long double
// forl: ascending loop over [a, b); rofl: descending loop from a down to b+1; fors: ascending with step c.
#define forl(i, a, b) for(ll i = a; i < b; i++)
#define rofl(i, a, b) for(ll i = a; i > b; i--)
#define fors(i, a, b, c) for(ll i = a; i < b; i += c)
#define fora(x, v) for(auto x : v)
#define vl vector<ll>
#define vb vector<bool>
#define pub push_back
#define pob pop_back
#define fbo find_by_order
#define ook order_of_key
#define yesno(x) cout << ((x) ? "YES" : "NO")
#define all(v) v.begin(), v.end()
// Size of the factorial table; panipuri() asserts n <= 2e5, so fact[n] stays in range.
const ll N = 2e5 + 4;
// Prime modulus for all answers.
const ll mod = 1e9 + 7;
// Alternative modulus, currently disabled:
// const ll mod = 998244353;
// fact[i] = i! mod `mod`; filled in main() before any test case is processed.
vl fact(N,1);
// Modular inverse of a with respect to `mod` via the iterative extended Euclidean
// algorithm. `prev` tracks the Bezout coefficient of a; the result is shifted into
// the non-negative range at the end. Assumes gcd(a, mod) == 1.
ll modinverse(ll a) {
    ll r = mod;
    ll prev = 1;
    ll cur = 0;
    while (a > 1) {
        const ll q = a / r;
        const ll rem = a % r;
        a = r;
        r = rem;
        const ll next = prev - q * cur;
        prev = cur;
        cur = next;
    }
    return prev < 0 ? prev + mod : prev;
}
// Greatest common divisor by Euclid's algorithm, written as a loop.
// Uses C++ truncating %, so the sign of the result follows the original recursion.
ll gcd(ll a, ll b) {
    while (b != 0) {
        const ll rem = a % b;
        a = b;
        b = rem;
    }
    return a;
}
// Least common multiple; divides by the gcd before multiplying to keep the
// intermediate value small.
ll lcm(ll a, ll b) {
    const ll g = gcd(a, b);
    return b * (a / g);
}
// True iff n is a positive power of two (1, 2, 4, ...).
// The bare bit trick !(n & (n-1)) wrongly accepts 0, and for the most negative
// value n - 1 overflows; the n > 0 guard rules out both.
bool poweroftwo(long long n) {
    return n > 0 && (n & (n - 1)) == 0;
}
// Computes a^b mod md by binary exponentiation; returns 0 for a negative exponent.
// The base is normalized into [0, md), so a negative base yields the canonical
// non-negative residue rather than a negative product.
ll power(ll a, ll b, ll md = mod) {
    if (b < 0) return 0;
    a %= md;
    if (a < 0) a += md;
    ll product = 1;
    while (b) {
        if (b & 1) product = (product * a) % md;
        a = (a * a) % md;
        b /= 2;
    }
    return product % md;
}
// Binomial coefficient C(n, r) mod `mod` computed as n! / (r! (n-r)!);
// 0 when r lies outside [0, n].
ll barfi(ll n, ll r){
    if (r < 0 || r > n) return 0;
    const ll denomInv = (modinverse(fact[r]) * modinverse(fact[n - r])) % mod;
    return (fact[n] * denomInv) % mod;
}
void panipuri() {
ll n, m = 0, k = -1, c = 0, sum = 0, q = 0, ans = 0, p = 1;
string s;
bool ch = true;
cin >> n>>m>>k;
assert(n<=2e5 && n>=1);
assert(m<=1e9 && m>=1);
assert(1<=k && k<=m);
c=m-k+1;
forl(i,1,n+1){
ll k1=min(i,k);
p=power(c,i)-power(c-1,k1)*power(c,i-k1);
p%=mod;
p+=mod;
p%=mod;
p*=power(k-1,n-i);
p%=mod;
p*=barfi(n,i);
p%=mod;
ans+=p;
ans%=mod;
// cout<<ans<<' ';
}
cout<<ans;
return;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(NULL);
#ifndef ONLINE_JUDGE
freopen("input.txt", "r", stdin);
freopen("output.txt", "w", stdout);
#endif
int laddu = 1;
cin >> laddu;
forl(i,1,N){
fact[i]=i*fact[i-1];
fact[i]%=mod;
}
forl(i, 1, laddu + 1) {
// cout << "Case #" << i << ": ";
panipuri();
cout << '\n';
}
}
Editorialist's code (Python)
N = 10**6 + 10
mod = 10**9 + 7

# fac[i] = i! mod `mod` for 0 <= i < N.
fac = [1]
for i in range(1, N):
    fac.append(fac[-1] * i % mod)

# inv[i] = (i!)^{-1} mod `mod`: invert the largest factorial once with Fermat's
# little theorem, then walk downward using (i)!^{-1} = (i+1)!^{-1} * (i+1).
inv = [0] * N
inv[-1] = pow(fac[-1], mod - 2, mod)
for i in range(N - 2, -1, -1):
    inv[i] = inv[i + 1] * (i + 1) % mod
def C(n, r):
    """Return the binomial coefficient n choose r modulo `mod` (0 unless 0 <= r <= n)."""
    if not 0 <= r <= n:
        return 0
    return fac[n] * inv[r] * inv[n - r] % mod
# For each test: answer = sum_{i=1..n} M^(n-i) * S_i, where S_i counts the ways
# to fill the i-1 positions before the leftmost K with no K and fewer than K
# values > K. S_i is advanced with Pascal's identity.
for _ in range(int(input())):
    n, m, k = map(int, input().split())
    # (M-K)^(K-1) does not depend on i, so compute it once per test case.
    corr_base = pow(m - k, k - 1, mod)
    ans, before = 0, 1
    for i in range(1, n + 1):
        # Positions after the leftmost K are unrestricted.
        after = pow(m, n - i, mod)
        if i <= k:
            # Fewer than K elements precede i, so any non-K value is allowed.
            before = pow(m - 1, i - 1, mod)
        else:
            # S_i = (K-1) S_{i-1} + (M-K) (S_{i-1} - C(i-2, K-1) (M-K)^(K-1) (K-1)^(i-1-K))
            dropped = C(i - 2, k - 1) * corr_base * pow(k - 1, i - 1 - k, mod) % mod
            before = (before * (k - 1) + (before - dropped) * (m - k)) % mod
        ans += after * before % mod
    print(ans % mod)