# PROBLEM LINK:

Practice

Contest: Division 1

Contest: Division 2

Contest: Division 3

Contest: Division 4

**Author:** helloLad, wasd2401

**Tester:** iceknight1093

**Editorialist:**

# DIFFICULTY:

TBD

# PREREQUISITES:

Combinatorics

# PROBLEM:

There are N bags, each containing M balls numbered 1 to M.

You will choose exactly one ball from each bag and arrange the chosen balls in a row.

Find the number of such arrangements (counted as sequences of labels), modulo 10^9+7, such that:

- There exists a ball labelled K appearing in this row such that the number of balls appearing before it with values \geq K is \lt K.
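For tiny inputs, the condition can be checked directly by enumerating all M^N label sequences. The sketch below uses a hypothetical helper, `brute_force`, which is not part of any of the original solutions. It treats arrangements as sequences of labels, the same way both solutions below count them, and it is only practical for very small N and M.

```python
from itertools import product


def brute_force(n: int, m: int, k: int) -> int:
    """Count label sequences of length n over 1..m in which some ball
    labelled k has fewer than k balls with label >= k before it."""
    count = 0
    for seq in product(range(1, m + 1), repeat=n):
        large_before = 0  # balls seen so far with label >= k
        for value in seq:
            if value == k and large_before < k:
                count += 1  # this ball labelled k satisfies the condition
                break
            if value >= k:
                large_before += 1
    return count


print(brute_force(3, 3, 2))  # 18
```

Such a checker is handy for validating the faster formulas on small cases.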

# EXPLANATION:

There are a couple of different solutions to this task.

## Solution 1

Let’s look at how the balls can be chosen in terms of which have ‘large’ labels (\geq K) and which have ‘small’ ones (\lt K).

Out of the N balls, suppose x of them have values \geq K.

The positions of these x balls can be chosen in \binom{N}{x} ways.

Then,

- The other N-x positions can take *any* values less than K, for a total of (K-1)^{N-x} choices.

- For the x positions with values \geq K, we need to ensure that:
  - At least one of them contains the value K.
  - The leftmost occurrence of K is among the first K of these x positions. If it is the j-th of them, exactly j-1 elements \geq K come before it, so this is the same as having \lt K such elements before it. Checking only the leftmost K is enough, since any later K has strictly more elements \geq K before it.

The second condition is awkward to handle directly, so instead we count the *bad* arrangements of these x positions and subtract them from the total.

That is, we count the arrangements that either don’t contain K at all, or contain K but not within the first K of these indices.

You may verify that:

- The total number of choices is (M-K+1)^x, freely assigning something between K and M to each index.
- The number of arrangements in which K isn’t present at all is (M-K)^x.
- For K to be present but not in the first K indices:
  - The first K indices should have values \gt K, for (M-K)^K choices in total.
  - Then, we again count the total number of choices for all other indices and subtract the count of ways where no K is present: namely, (M-K+1)^{x-K} - (M-K)^{x-K}.
  - Note that this case doesn’t need to be considered when x \leq K.
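The per-x count can be cross-checked for small x by enumerating every assignment of labels K to M to the x large positions. In this sketch, `good_large_assignments` and `brute_large_assignments` are hypothetical helper names: the first follows the total-minus-bad argument above, the second enumerates directly.

```python
from itertools import product


def good_large_assignments(x: int, m: int, k: int) -> int:
    """Assignments of labels in [k, m] to x ordered positions whose
    leftmost k lies among the first k positions (total minus bad)."""
    total = (m - k + 1) ** x
    no_k = (m - k) ** x
    if x <= k:
        return total - no_k  # "k present but too late" is impossible here
    # first k positions all > k, and at least one k somewhere afterwards
    k_too_late = (m - k) ** k * ((m - k + 1) ** (x - k) - (m - k) ** (x - k))
    return total - no_k - k_too_late


def brute_large_assignments(x: int, m: int, k: int) -> int:
    """Direct enumeration of the same quantity, for checking."""
    return sum(
        1
        for seq in product(range(k, m + 1), repeat=x)
        # index j of the leftmost k means j labels >= k precede it
        if k in seq and seq.index(k) < k
    )


for x in range(6):
    assert good_large_assignments(x, 5, 2) == brute_large_assignments(x, 5, 2)
```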

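All together, the answer we get is

\sum_{x=0}^{N} \binom{N}{x} (K-1)^{N-x} \left[(M-K+1)^x - (M-K)^{\min(x,K)}(M-K+1)^{x-\min(x,K)}\right]

Here the subtracted term is the bad count: for x \leq K it is just (M-K)^x, and for x \gt K the two bad cases add up to (M-K)^x + (M-K)^K\left((M-K+1)^{x-K} - (M-K)^{x-K}\right) = (M-K)^K(M-K+1)^{x-K}.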

For a fixed x this is just a binomial coefficient and a few powers, so we get a solution in \mathcal{O}(N\log N) time (or even linear, if you precompute powers).
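A minimal Python sketch of Solution 1 follows, assuming the answer is taken modulo 10^9+7 (the modulus used by all three codes below). The function name `solve_by_large_count` is hypothetical. Binomials come from factorials precomputed per call, and each term uses Python’s three-argument `pow`, so this runs in \mathcal{O}(N\log \text{MOD}) time.

```python
MOD = 10**9 + 7


def solve_by_large_count(n: int, m: int, k: int) -> int:
    """Solution 1: sum over x, the number of balls with label >= k."""
    # factorials and inverse factorials up to n for binomial coefficients
    fac = [1] * (n + 1)
    for i in range(1, n + 1):
        fac[i] = fac[i - 1] * i % MOD
    inv_fac = [1] * (n + 1)
    inv_fac[n] = pow(fac[n], MOD - 2, MOD)
    for i in range(n, 0, -1):
        inv_fac[i - 1] = inv_fac[i] * i % MOD

    ans = 0
    for x in range(n + 1):
        binom = fac[n] * inv_fac[x] % MOD * inv_fac[n - x] % MOD
        small = pow(k - 1, n - x, MOD)  # labels < k on the other n - x positions
        t = min(x, k)
        # all assignments of the x large positions minus the bad ones
        good = (pow(m - k + 1, x, MOD) - pow(m - k, t, MOD) * pow(m - k + 1, x - t, MOD)) % MOD
        ans = (ans + binom * small % MOD * good) % MOD
    return ans


print(solve_by_large_count(4, 4, 2))  # 139
```

Python’s `pow(0, 0, MOD)` returns 1, which matches the convention 0^0 = 1 needed when K = 1 or M = K.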

## Solution 2

Let’s fix the leftmost occurrence of K, say at index i. Then,

- Everything after i can be chosen freely, for M^{N-i} choices.
- K can’t occur before i, and we need to ensure that fewer than K elements before i are \geq K. Since there is no K before i, these are exactly the elements \gt K.

To deal with the latter, let’s fix x, the number of elements \gt K before i.

Then,

- Their positions can be chosen in \binom{i-1}{x} ways.
- Each of these x indices has M-K choices for its value, for (M-K)^x ways.
- The other indices before i have values \lt K, for (K-1)^{i-1-x} ways in total.

So, our answer is

\sum_{i=1}^N M^{N-i} \sum_{x=0}^{K-1} \binom{i-1}{x} (M-K)^x (K-1)^{i-1-x}.
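Before speeding it up, the double sum over i and x can be evaluated directly. This gives a quadratic reference implementation that is handy for checking a faster version on small inputs. `solve_quadratic` is a hypothetical name, and it uses `math.comb` (Python 3.8+) for the binomials, assuming the same modulus 10^9+7.

```python
from math import comb

MOD = 10**9 + 7


def solve_quadratic(n: int, m: int, k: int) -> int:
    """Solution 2 evaluated directly, with O(N * min(N, K)) terms."""
    ans = 0
    for i in range(1, n + 1):  # i = index of the leftmost k
        s_i = 0
        # x = number of labels > k before index i; at most k - 1 allowed
        for x in range(min(k - 1, i - 1) + 1):
            s_i = (s_i + comb(i - 1, x) * pow(m - k, x, MOD) * pow(k - 1, i - 1 - x, MOD)) % MOD
        ans = (ans + pow(m, n - i, MOD) * s_i) % MOD  # labels after i are free
    return ans
```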

This is quadratic, of course, but can be sped up surprisingly easily: use Pascal’s identity!

Specifically, define S_i = \sum_{x=0}^{K-1} \binom{i-1}{x} (M-K)^x (K-1)^{i-1-x}, so that we’re trying to compute

\sum_{i=1}^N M^{N-i} S_i.

Now, apply Pascal’s identity, so that \binom{i-1}{x} = \binom{i-2}{x} + \binom{i-2}{x-1}.

Substituting this into the expression for S_i, you can see that S_i can be computed in terms of S_{i-1} instead.

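Specifically, for i \gt K we have

S_i = (K-1)S_{i-1} + (M-K)\left(S_{i-1} - \binom{i-2}{K-1}(M-K)^{K-1}(K-1)^{i-1-K}\right).

The \binom{i-2}{x} part contributes (K-1)S_{i-1} directly. After shifting x, the \binom{i-2}{x-1} part becomes (M-K) times S_{i-1} without its x = K-1 term, which is exactly the term subtracted above. For i \leq K, every x from 0 to i-1 is allowed, so the binomial theorem gives S_i = (M-1)^{i-1} directly.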

This way, all the S_i values can be computed in \mathcal{O}(N) or \mathcal{O}(N\log N), and we’re done.
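A sketch of the recurrence, assuming the same modulus 10^9+7. Here `s_values` and `solve_linear` are hypothetical names: `s_values` builds S_1, \ldots, S_N with the Pascal-identity update (using (M-1)^{i-1} for i \leq K), and `solve_linear` combines them with the powers M^{N-i}. Powers are computed with `pow`, so this is \mathcal{O}(N \log \text{MOD}); precomputing the powers of M, M-K and K-1 would make it linear.

```python
MOD = 10**9 + 7


def s_values(n: int, m: int, k: int) -> list:
    """S_i for i = 1..n via the Pascal-identity recurrence."""
    fac = [1] * (n + 1)
    for i in range(1, n + 1):
        fac[i] = fac[i - 1] * i % MOD
    inv_fac = [1] * (n + 1)
    inv_fac[n] = pow(fac[n], MOD - 2, MOD)
    for i in range(n, 0, -1):
        inv_fac[i - 1] = inv_fac[i] * i % MOD

    def binom(a: int, b: int) -> int:
        if b < 0 or b > a:
            return 0
        return fac[a] * inv_fac[b] % MOD * inv_fac[a - b] % MOD

    big, small = m - k, k - 1  # label counts above and below k
    s = []
    for i in range(1, n + 1):
        if i <= k:
            s.append(pow(m - 1, i - 1, MOD))  # every x in 0..i-1 is allowed
        else:
            prev = s[-1]
            # the x = k-1 term of S_{i-1}, which the shifted sum must drop
            dropped = binom(i - 2, k - 1) * pow(big, k - 1, MOD) % MOD * pow(small, i - 1 - k, MOD) % MOD
            s.append((small * prev + big * (prev - dropped)) % MOD)
    return s


def solve_linear(n: int, m: int, k: int) -> int:
    """Sum of M^(N-i) * S_i over i = 1..N."""
    s = s_values(n, m, k)
    return sum(pow(m, n - i, MOD) * s[i - 1] for i in range(1, n + 1)) % MOD
```

For example, with N = 5, M = 4, K = 3 the recurrence yields S = [1, 3, 9, 26, 72], the same values as the direct definition.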

# TIME COMPLEXITY:

\mathcal{O}(N) per testcase.

# CODE:

## Author's code (C++)

```cpp
#include <bits/stdc++.h>
using namespace std;
#define IOS ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0);
#define ll long long
const ll N=1e6+10;
const ll M=1000000007;
ll fact[N];
ll fact_inverse[N];
ll binExp(ll a,ll b){
    ll ans=1;
    while(b){
        if(b&1){
            ans=(ans*1LL*a)%M;
        }
        a=(a*1LL*a)%M;
        b>>=1;
    }
    return ans;
}
void factorial(){
    fact[0]=1;
    for(ll i=1;i<N;++i){
        fact[i]=(fact[i-1]*1LL*i)%M;
    }
}
void inverse_factorial(){
    fact_inverse[0]=1;
    for(ll i=1;i<N;++i){
        fact_inverse[i]=binExp(fact[i],M-2);
    }
}
ll ncr(ll n,ll r){
    if(r>n){
        return 0;
    }
    return (fact[n]*((fact_inverse[r]*fact_inverse[n-r])%M))%M;
}
int main(){
    IOS
    factorial();
    inverse_factorial();
    int t;
    cin>>t;
    while(t--){
        ll n,m,k;
        cin>>n>>m>>k;
        ll ans=0;
        for(ll i=1;i<=n;++i){
            ll cnt=1;
            // choose the i bags whose balls are numbered >= k
            cnt*=ncr(n,i);
            cnt%=M;
            // the balls from the other n-i bags are numbered less than k
            cnt*=binExp(k-1,n-i);
            cnt%=M;
            // choose i positions in the final arrangement for the balls numbered >= k
            cnt*=ncr(n,i);
            cnt%=M;
            // arrange the n-i balls numbered less than k in the remaining n-i positions
            cnt*=fact[n-i];
            cnt%=M;
            if(i<=k){
                // balls numbered >= k, with at least one ball numbered k
                cnt*=(binExp(m-k+1,i)-binExp(m-k,i)+M);
                cnt%=M;
            }else{
                // balls numbered >= k such that some ball numbered k has
                // fewer than k balls numbered >= k before it
                cnt*=(binExp(m-k+1,i)-(((binExp(m-k,k)*binExp(m-k+1,i-k))%M))+M);
                cnt%=M;
            }
            // arrange the i balls numbered >= k in their i chosen positions
            cnt*=fact[i];
            cnt%=M;
            ans+=cnt;
            ans%=M;
        }
        // ncr(n,i) * i! * (n-i)! equals n!, so dividing by n! cancels the extra arrangement factor
        ans*=binExp(fact[n],M-2);
        ans%=M;
        cout<<ans<<endl;
    }
    return 0;
}
```

## Tester's code (C++)

```cpp
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace __gnu_pbds;
using namespace std;
#define osl tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update>
#define ll long long
#define ld long double
#define forl(i, a, b) for(ll i = a; i < b; i++)
#define rofl(i, a, b) for(ll i = a; i > b; i--)
#define fors(i, a, b, c) for(ll i = a; i < b; i += c)
#define fora(x, v) for(auto x : v)
#define vl vector<ll>
#define vb vector<bool>
#define pub push_back
#define pob pop_back
#define fbo find_by_order
#define ook order_of_key
#define yesno(x) cout << ((x) ? "YES" : "NO")
#define all(v) v.begin(), v.end()
const ll N = 2e5 + 4;
const ll mod = 1e9 + 7;
// const ll mod = 998244353;
vl fact(N,1);
ll modinverse(ll a) {
    ll m = mod, y = 0, x = 1;
    while (a > 1) {
        ll q = a / m;
        ll t = m;
        m = a % m;
        a = t;
        t = y;
        y = x - q * y;
        x = t;
    }
    if (x < 0) x += mod;
    return x;
}
ll gcd(ll a, ll b) {
    if (b == 0)
        return a;
    return gcd(b, a % b);
}
ll lcm(ll a, ll b) {
    return (a / gcd(a, b)) * b;
}
bool poweroftwo(ll n) {
    return !(n & (n - 1));
}
ll power(ll a, ll b, ll md = mod) {
    if(b<0) return 0;
    ll product = 1;
    a %= md;
    while (b) {
        if (b & 1) product = (product * a) % md;
        a = (a * a) % md;
        b /= 2;
    }
    return product % md;
}
ll barfi(ll n, ll r){
    if(n<r || r<0) return 0;
    ll p=modinverse(fact[r])*modinverse(fact[n-r]);
    p%=mod;
    return (p*fact[n])%mod;
}
void panipuri() {
    ll n, m = 0, k = -1, c = 0, sum = 0, q = 0, ans = 0, p = 1;
    string s;
    bool ch = true;
    cin >> n>>m>>k;
    assert(n<=2e5 && n>=1);
    assert(m<=1e9 && m>=1);
    assert(1<=k && k<=m);
    c=m-k+1; // number of labels >= k
    forl(i,1,n+1){
        // i = number of balls with label >= k
        ll k1=min(i,k);
        // valid labels for the i large balls: all c^i choices, minus those
        // whose first min(i,k) large balls avoid label k
        p=power(c,i)-power(c-1,k1)*power(c,i-k1);
        p%=mod;
        p+=mod;
        p%=mod;
        p*=power(k-1,n-i); // labels < k on the other n-i balls
        p%=mod;
        p*=barfi(n,i); // choose which i positions hold the large balls
        p%=mod;
        ans+=p;
        ans%=mod;
        // cout<<ans<<' ';
    }
    cout<<ans;
    return;
}
int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);
#ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
#endif
    int laddu = 1;
    cin >> laddu;
    forl(i,1,N){
        fact[i]=i*fact[i-1];
        fact[i]%=mod;
    }
    forl(i, 1, laddu + 1) {
        // cout << "Case #" << i << ": ";
        panipuri();
        cout << '\n';
    }
}
```

## Editorialist's code (Python)

```python
N = 10**6 + 10
mod = 10**9 + 7
fac = [1]*N
for i in range(1, N): fac[i] = fac[i-1] * i % mod
inv = fac[:]
inv[-1] = pow(fac[-1], mod-2, mod)
for i in reversed(range(N-1)): inv[i] = inv[i+1] * (i+1) % mod
def C(n, r):
    if n < r or r < 0: return 0
    return fac[n] * inv[r] % mod * inv[n-r] % mod
for _ in range(int(input())):
    n, m, k = map(int, input().split())
    ans, before = 0, 1
    for i in range(1, n+1):
        # after = M^(N-i): labels after the leftmost K are free
        after = pow(m, n-i, mod)
        # before = S_i: prefixes of length i-1 with no K and fewer than K labels > K
        if i <= k: before = pow(m-1, i-1, mod)
        else: before = (before * (k-1) + (before - C(i-2, k-1)*pow(m-k, k-1, mod)*pow(k-1, i-1-k, mod)%mod) * (m-k)) % mod
        ans += after * before % mod
    print(ans % mod)
```