CNTP - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: helloLad
Tester: wasd2401
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Combinatorics

PROBLEM:

There are N bags, each containing M balls numbered 1 to M.
You will choose exactly one ball from each bag, and arrange them randomly in a row.

Find the number of arrangements such that:

  • There exists a ball labelled K appearing in this row such that the number of balls appearing before it with values \geq K is \lt K.
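
To make the condition concrete, here is a minimal brute-force sketch in Python. It assumes that an arrangement means an ordered row of N labels, one per position, which is how both solutions below count. The function name `count_bruteforce` is made up for illustration, and the approach is only usable for tiny inputs.

```python
from itertools import product


def count_bruteforce(n, m, k):
    """Count rows of n labels from 1..m that contain a qualifying ball labelled k."""
    total = 0
    for row in product(range(1, m + 1), repeat=n):
        large_before = 0  # labels >= k seen so far in this row
        for value in row:
            if value == k and large_before < k:
                total += 1  # this k has fewer than k large labels before it
                break
            if value >= k:
                large_before += 1
    return total


print(count_bruteforce(3, 3, 2))  # 18
```

For N = 3, M = 3, K = 2, the 8 rows without a 2 fail, and of the 19 rows that do contain a 2 only (3, 3, 2) fails, which gives 18.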

EXPLANATION:

There are a couple of different solutions to this task.

Solution 1

Let’s look at how the balls can be chosen, in terms of which have ‘large’ labels and which have ‘small’ ones.

Out of the N balls, suppose x of them have values \geq K.
The positions of these x balls can be chosen in \binom{N}{x} ways.
Then,

  • The other N-x positions can take any values less than K.
    • This gives a total of (K-1)^{N-x} choices.
  • For the x positions with values \geq K, we need to ensure that:
    • At least one of them contains the value K.
    • The leftmost occurrence of K is among the first K of these x positions, since that ensures this leftmost K has \lt K elements that are \geq K before it.

The second condition is a bit annoying to deal with directly, so in this case, we can instead count the number of bad arrangements and subtract them from the total number of arrangements.
That is, we’ll count the number of arrangements that either don’t contain K at all, or which contain K but not within the first K indices.
You may verify that:

  • The total number of choices is (M-K+1)^x, freely assigning something between K and M to each index.
  • The number of arrangements in which K isn’t present at all is (M-K)^x.
  • For K to be present but not in the first K indices,
    • The first K indices should have values \gt K, for (M-K)^K choices in total.
    • Then, we again count the total number of choices for all other indices and subtract from that the count of ways where no K is present: namely, (M-K+1)^{x-K} - (M-K)^{x-K}.
    • Note that this case doesn’t need to be considered when x \leq K, since then all x positions are among the first K.
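
The count for the x large positions can be sanity-checked on small values. The sketch below compares the total-minus-bad count described in the bullets above with direct enumeration over labels K..M. Both function names are hypothetical, and exact integers are used rather than a modulus.

```python
from itertools import product


def large_part_formula(x, m, k):
    """Valid labelings of the x large positions, computed as total minus bad."""
    total = (m - k + 1) ** x
    without_k = (m - k) ** x
    k_too_late = 0
    if x > k:
        # first k positions avoid K, the remaining x-k contain at least one K
        k_too_late = (m - k) ** k * ((m - k + 1) ** (x - k) - (m - k) ** (x - k))
    return total - without_k - k_too_late


def large_part_enumerated(x, m, k):
    """Same count, by listing every labeling with values in k..m."""
    return sum(1 for labels in product(range(k, m + 1), repeat=x) if k in labels[:k])


for x in range(0, 6):
    assert large_part_formula(x, 4, 2) == large_part_enumerated(x, 4, 2)
```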

All together, the answer we get is

\begin{align*} &\sum_{x=0}^{K} \binom{N}{x} (K-1)^{N-x} \cdot \left((M-K+1)^x - (M-K)^x\right) \\ {}+{} &\sum_{x=K+1}^{N} \binom{N}{x} (K-1)^{N-x} \cdot \left((M-K+1)^x - (M-K)^x - (M-K)^K \cdot \left((M-K+1)^{x-K} - (M-K)^{x-K}\right)\right) \end{align*}

For a fixed x this is just a binomial coefficient and a few powers, so we get a solution in \mathcal{O}(N\log N) time (or even linear, if you precompute powers).
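
A minimal Python sketch of this sum, assuming the modulus 10^9 + 7 used by the code further down; the function name `solve_by_large_count` is made up. For x \gt K the two (M-K)^x terms cancel, so the bracket reduces to (M-K+1)^x - (M-K)^K (M-K+1)^{x-K}. Both cases can then be written with j = \min(x, K).

```python
MOD = 10**9 + 7


def solve_by_large_count(n, m, k):
    """Sum over x = number of positions holding a label >= k."""
    # factorials and inverse factorials up to n, for binomial coefficients
    fact = [1] * (n + 1)
    for i in range(1, n + 1):
        fact[i] = fact[i - 1] * i % MOD
    inv_fact = [1] * (n + 1)
    inv_fact[n] = pow(fact[n], MOD - 2, MOD)
    for i in range(n, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % MOD

    answer = 0
    for x in range(1, n + 1):  # x = 0 contributes nothing
        choose = fact[n] * inv_fact[x] % MOD * inv_fact[n - x] % MOD
        small = pow(k - 1, n - x, MOD)  # labels below k on the other positions
        j = min(x, k)
        large = (pow(m - k + 1, x, MOD)
                 - pow(m - k, j, MOD) * pow(m - k + 1, x - j, MOD)) % MOD
        answer = (answer + choose * small % MOD * large) % MOD
    return answer


print(solve_by_large_count(3, 3, 2))  # 18
```

Each iteration makes a few `pow` calls, so this runs in \mathcal{O}(N\log N); precomputing the powers would make it linear.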

Solution 2

Let’s fix the leftmost occurrence of K, say at index i. Then,

  • Everything after i can be chosen freely, for M^{N-i} choices.
  • K can’t occur before i, and we need to ensure that fewer than K elements before i are \gt K (with no K before i, the elements \geq K are exactly those \gt K).

To deal with the latter, let’s fix x, the number of elements \gt K before i.
Then,

  • Their positions can be chosen in \binom{i-1}{x} ways.
  • Each of these x indices has M-K choices for its value, for (M-K)^x ways.
  • The other indices before i have values \lt K, for (K-1)^{i-1-x} ways in total.

So, our answer is

\sum_{i=1}^N \left( M^{N-i} \sum_{x=0}^{K-1} \binom{i-1}{x} (M-K)^x (K-1)^{i-1-x} \right)
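
As a reference point, the double sum can be evaluated directly. The sketch below uses exact integers and `math.comb` (available from Python 3.8); the function name `solve_quadratic` is made up, and the inner loop stops at \min(K, i) - 1 because \binom{i-1}{x} = 0 for x \gt i-1.

```python
from math import comb


def solve_quadratic(n, m, k):
    """Direct evaluation of sum_i M^(N-i) * S_i, without any modulus."""
    answer = 0
    for i in range(1, n + 1):  # i = position of the leftmost K
        s_i = 0
        for x in range(0, min(k, i)):  # x = number of labels > K before i
            s_i += comb(i - 1, x) * (m - k) ** x * (k - 1) ** (i - 1 - x)
        answer += m ** (n - i) * s_i
    return answer


print(solve_quadratic(3, 3, 2))  # 18
```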

This is quadratic, of course, but can be sped up surprisingly easily: use Pascal’s identity!
Specifically, define S_i = \sum_{x=0}^{K-1} \binom{i-1}{x} (M-K)^x (K-1)^{i-1-x}, so that we’re trying to compute
\sum_{i=1}^N M^{N-i} S_i.

Now, apply Pascal’s identity, so that \binom{i-1}{x} = \binom{i-2}{x} + \binom{i-2}{x-1}.
Substituting this into the expression for S_i, you can see that S_i can be computed in terms of S_{i-1} instead.

Specifically, we have

S_i = S_{i-1}\cdot (K-1) + \left(S_{i-1} - \binom{i-2}{K-1} (M-K)^{K-1} (K-1)^{i-1-K}\right)\cdot (M-K)

This way, starting from S_1 = 1, all the S_i values can be computed in \mathcal{O}(N) or \mathcal{O}(N\log N), and we’re done.
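
One possible way to code the recurrence is sketched below, again assuming the modulus 10^9 + 7; the function name `solve_linear` is made up. The subtracted term is skipped while i \leq K, where \binom{i-2}{K-1} = 0, which also keeps the exponent i-1-K from going negative.

```python
MOD = 10**9 + 7


def solve_linear(n, m, k):
    """Accumulate sum_i M^(N-i) * S_i, updating S_i from S_(i-1)."""
    fact = [1] * (n + 1)
    for i in range(1, n + 1):
        fact[i] = fact[i - 1] * i % MOD
    inv_fact = [1] * (n + 1)
    inv_fact[n] = pow(fact[n], MOD - 2, MOD)
    for i in range(n, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % MOD

    large_pow = pow(m - k, k - 1, MOD)  # (M-K)^(K-1), the same for every i
    s = 1                               # S_1
    answer = pow(m, n - 1, MOD)         # contribution of i = 1
    for i in range(2, n + 1):
        dropped = 0
        if i > k:  # otherwise binom(i-2, K-1) = 0 and nothing is subtracted
            choose = fact[i - 2] * inv_fact[k - 1] % MOD * inv_fact[i - 1 - k] % MOD
            dropped = choose * large_pow % MOD * pow(k - 1, i - 1 - k, MOD) % MOD
        s = (s * (k - 1) + (s - dropped) * (m - k)) % MOD
        answer = (answer + pow(m, n - i, MOD) * s) % MOD
    return answer


print(solve_linear(3, 3, 2))  # 18
```

The `pow` calls inside the loop make this \mathcal{O}(N\log N) as written; maintaining the powers incrementally would remove the logarithmic factor.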

TIME COMPLEXITY:

\mathcal{O}(N) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define IOS ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0);
#define ll long long

const ll N=1e6+10;
const ll M=1000000007;

ll fact[N];
ll fact_inverse[N];

ll binExp(ll a,ll b){
    ll ans=1;
    while(b){
        if(b&1){
            ans=(ans*1LL*a)%M;
        }
        a=(a*1LL*a)%M;
        b>>=1;
    }
    return ans;
}

void factorial(){
    fact[0]=1;
    for(ll i=1;i<N;++i){
        fact[i]=(fact[i-1]*1LL*i)%M;
    }
}

void inverse_factorial(){
    fact_inverse[0]=1;
    for(ll i=1;i<N;++i){
        fact_inverse[i]=binExp(fact[i],M-2);
    }
}

ll ncr(ll n,ll r){
    if(r>n){
        return 0;
    }
    return (fact[n]*((fact_inverse[r]*fact_inverse[n-r])%M))%M;
}

int main(){
    IOS
    factorial();
    inverse_factorial();
    int t;
    cin>>t;
    while(t--){
        ll n,m,k;
        cin>>n>>m>>k;
        ll ans=0;
        for(ll i=1;i<=n;++i){
            ll cnt=1;
            // selecting bags from which these i balls are selected
            cnt*=ncr(n,i);
            cnt%=M;
            // selecting balls which are numbered less than k
            cnt*=binExp(k-1,n-i);
            cnt%=M;
            // selecting i positions in the final arrangement
            cnt*=ncr(n,i);
            cnt%=M;
            // arranging the selected n-i balls numbered less than k in remaining n-i positions in the final arrangement
            cnt*=fact[n-i];
            cnt%=M;
            if(i<=k){
                // selecting balls which are numbered greater than or equal to k while ensuring at least one k-numbered ball is selected 
                cnt*=(binExp(m-k+1,i)-binExp(m-k,i)+M);
                cnt%=M;
            }else{
                // selecting balls such that at least one k-numbered ball exists
                // whose count of earlier balls numbered greater than k is less than k
                cnt*=(binExp(m-k+1,i)-(((binExp(m-k,k)*binExp(m-k+1,i-k))%M))+M);
                cnt%=M;
            }
            // arranging the selected i balls numbered greater than or equal to k in selected i positions in the final arrangement
            cnt*=fact[i];
            cnt%=M;
            ans+=cnt;
            ans%=M;
        }
        ans*=binExp(fact[n],M-2);
        ans%=M;
        cout<<ans<<endl;
    }
    return 0;
}




Tester's code (C++)
/*

*       *  *  ***       *       *****
 *     *   *  *  *     * *        *
  *   *    *  ***     *****       *
   * *     *  * *    *     *   *  *
    *      *  *  *  *       *   **

                                 *
                                * *
                               *****
                              *     *
        *****                *       *
      _*     *_
     | * * * * |                ***
     |_*  _  *_|               *   *
       *     *                 *  
        *****                  *  **
       *     *                  ***
  {===*       *===}
      *  IS   *                 ***
      *  IT   *                *   *
      * RATED?*                *  
      *       *                *  **
      *       *                 ***
       *     *
        *****                  *   *
                               *   *
                               *   *
                               *   *
                                ***   

*/

#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

using namespace __gnu_pbds;
using namespace std;

#define osl tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update>
#define ll long long
#define ld long double
#define forl(i, a, b) for(ll i = a; i < b; i++)
#define rofl(i, a, b) for(ll i = a; i > b; i--)
#define fors(i, a, b, c) for(ll i = a; i < b; i += c)
#define fora(x, v) for(auto x : v)
#define vl vector<ll>
#define vb vector<bool>
#define pub push_back
#define pob pop_back
#define fbo find_by_order
#define ook order_of_key
#define yesno(x) cout << ((x) ? "YES" : "NO")
#define all(v) v.begin(), v.end()

const ll N = 2e5 + 4;
const ll mod = 1e9 + 7;
// const ll mod = 998244353;

vl fact(N,1);
ll modinverse(ll a) {
	ll m = mod, y = 0, x = 1;
	while (a > 1) {
		ll q = a / m;
		ll t = m;
		m = a % m;
		a = t;
		t = y;
		y = x - q * y;
		x = t;
	}
	if (x < 0) x += mod;
	return x;
}
ll gcd(ll a, ll b) {
	if (b == 0)
		return a;
	return gcd(b, a % b);
}
ll lcm(ll a, ll b) {
	return (a / gcd(a, b)) * b;
}
bool poweroftwo(ll n) {
	return !(n & (n - 1));
}
ll power(ll a, ll b, ll md = mod) {
	if(b<0) return 0;
	ll product = 1;
	a %= md;
	while (b) {
		if (b & 1) product = (product * a) % md;
		a = (a * a) % md;
		b /= 2;
	}
	return product % md;
}
ll barfi(ll n, ll r){
	if(n<r || r<0) return 0;
	ll p=modinverse(fact[r])*modinverse(fact[n-r]);
	p%=mod;
	return (p*fact[n])%mod;
}
void panipuri() {
	ll n, m = 0, k = -1, c = 0, sum = 0, q = 0, ans = 0, p = 1;
	string s;
	bool ch = true;
	cin >> n>>m>>k;
	assert(n<=2e5 && n>=1);
	assert(m<=1e9 && m>=1);
	assert(1<=k && k<=m);
	c=m-k+1;
	forl(i,1,n+1){
		ll k1=min(i,k);
		p=power(c,i)-power(c-1,k1)*power(c,i-k1);
		p%=mod;
		p+=mod;
		p%=mod;
		p*=power(k-1,n-i);
		p%=mod;
		p*=barfi(n,i);
		p%=mod;
		ans+=p;
		ans%=mod;
		// cout<<ans<<' ';
	}
	cout<<ans;
	return;
}
int main() {
	ios::sync_with_stdio(false);
	cin.tie(NULL);
	#ifndef ONLINE_JUDGE
	freopen("input.txt", "r", stdin);
	freopen("output.txt", "w", stdout);
	#endif
	int laddu = 1;
	cin >> laddu;
	forl(i,1,N){
		fact[i]=i*fact[i-1];
		fact[i]%=mod;
	}
	forl(i, 1, laddu + 1) {
		// cout << "Case #" << i << ": ";
		panipuri();
		cout << '\n';
	}
}
Editorialist's code (Python)
N = 10**6 + 10
mod = 10**9 + 7
fac = [1]*N
for i in range(1, N): fac[i] = fac[i-1] * i % mod
inv = fac[:]
inv[-1] = pow(fac[-1], mod-2, mod)
for i in reversed(range(N-1)): inv[i] = inv[i+1] * (i+1) % mod
def C(n, r):
    if n < r or r < 0: return 0
    return fac[n] * inv[r] % mod * inv[n-r] % mod

for _ in range(int(input())):
    n, m, k = map(int, input().split())
    ans, before = 0, 1
    for i in range(1, n+1):
        after = pow(m, n-i, mod)
        if i <= k: before = pow(m-1, i-1, mod)
        else: before = (before * (k-1) + (before - C(i-2, k-1)*pow(m-k, k-1, mod)*pow(k-1, i-1-k, mod)%mod) * (m-k)) % mod

        ans += after * before % mod
    print(ans % mod)

“surprisingly easily”, huh? I spent half an hour rearranging this sum and got nowhere. You really need more testers for a better perception of difficulties.

That line is admittedly a bit biased, because afaik in testing I’m the only one who solved it that way and everyone else had the first solution (which is also the author’s intended solution).

I had an easy time with it because the exact same optimization for basically the same-looking summation was used here quite recently, so it’s one of the first things I tried - if you haven’t seen it before, I suppose it’s not likely to be the first thing one tries.

The first solution is (again, in my opinion) fairly straightforward combinatorial reasoning, and you end up with something that doesn’t need to be optimized from quadratic to linear.

What do you know, I did every Starters in 2024 except that contest 🥲

Hey, could you tell me why it’s m^(n-i) and not (m-1)^(n-i), since we cannot take k, right?

In Solution 2, in the transformation from S_{i-1} to S_i,
why are we multiplying the second part of the equation by (M-K)? Shouldn’t it be (K-1)?

M^{N-i} is the number of ways to fill the positions after i. There are N-i positions after i, and each of them has M options: since we are placing the K-numbered ball at i, nothing after i is restricted. Therefore the total number of ways is M^{N-i}.

@iceknight1093 In the second solution, the relation between S_i and S_{i-1} given in the editorial works fine when i > k. How does it work for i <= k? In that case i - 2 < k - 1, and C(i - 2, k - 1) no longer makes sense. So how do we handle the case i <= k?

You have done some kind of trimming using k1 = min(i, k) in your solution but I am not able to understand how the recurrence will work for i < k.

For i<=k, you can see that the summation simplifies to (M-1)^(i-1)
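
Spelled out: for i \leq K the upper limit K-1 is at least i-1, and \binom{i-1}{x} = 0 whenever x \gt i-1, so the sum is a complete binomial expansion:

\begin{align*} S_i &= \sum_{x=0}^{i-1} \binom{i-1}{x} (M-K)^x (K-1)^{i-1-x} \\ &= \big((M-K) + (K-1)\big)^{i-1} \\ &= (M-1)^{i-1} \end{align*}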

I get that, but we have to handle it explicitly since the recurrence won’t work, right?

Have you tried writing out the expression and seeing what it reduces to?

\begin{align*} S_i &= \sum_{x=0}^{K-1} \binom{i-1}{x} (M-K)^x (K-1)^{i-1-x} \\ &= \sum_{x=0}^{K-1} \left(\binom{i-2}{x} + \binom{i-2}{x-1}\right) (M-K)^x (K-1)^{i-1-x} \\ &= \sum_{x=0}^{K-1} \binom{i-2}{x}(M-K)^x (K-1)^{i-1-x} + \sum_{x=0}^{K-1} \binom{i-2}{x-1}(M-K)^x (K-1)^{i-1-x} \end{align*}

The first summation is exactly S_{i-1} multiplied by (K-1).
The second summation is almost S_{i-1} multiplied by (M-K), the only problem is that you have \binom{i-2}{x-1} instead of \binom{i-2}{x}; and x-1 goes from -1 to K-2 rather than 0 to K-1.
The exact differences are:

  • You have an extra \binom{i-2}{-1}(M-K)^0(K-1)^{i-1} term.
    This doesn’t matter anyway since \binom{i-2}{-1} = 0 for any i.
  • You’re missing the term \binom{i-2}{K-1}(M-K)^{K}(K-1)^{i-1-K}.
    This can be subtracted out since it’s just a single term (and you’ll notice I indeed subtracted it out when giving the recurrence for S_i).

\binom{i-2}{K-1} is actually perfectly well defined even if i-2 \lt K-1; it’s just 0, so you’re basically not performing the subtraction (or subtracting 0, rather).
So no, mathematically you don’t actually need any special care for “small i”.

My code had that casework purely because otherwise, it would do pow(k-1, i-1-k, mod) where i-1-k can be negative, and that will throw a runtime error (even though the entire product evaluates to 0 anyway because \binom{i-2}{K-1} = 0 in such cases).
See this submission for example where I use a different method to avoid that power becoming negative.
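
A quick numeric check of the claim above, as a sketch with exact integers (the helper names `s_direct` and `s_step` are made up here): applying the recurrence for every i \geq 2, with the subtracted term taken as 0 whenever \binom{i-2}{K-1} = 0, reproduces the direct sum for small and large i alike.

```python
from math import comb


def s_direct(i, m, k):
    """S_i straight from its definition (terms with x > i-1 are zero)."""
    return sum(comb(i - 1, x) * (m - k) ** x * (k - 1) ** (i - 1 - x)
               for x in range(0, min(k, i)))


def s_step(prev, i, m, k):
    """S_i computed from prev = S_(i-1) using the recurrence."""
    dropped = 0
    if i - 2 >= k - 1:  # only then is binom(i-2, K-1) nonzero
        dropped = comb(i - 2, k - 1) * (m - k) ** (k - 1) * (k - 1) ** (i - 1 - k)
    return prev * (k - 1) + (prev - dropped) * (m - k)


for i in range(2, 9):
    assert s_step(s_direct(i - 1, 5, 3), i, 5, 3) == s_direct(i, 5, 3)
```

The guard exists only so that the exponent i-1-K is never negative; it does not change the value of the recurrence.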

#include <bits/stdc++.h>
using namespace std;

#define ll long long
#define forl(i, a, b) for (ll i = a; i < b; i++)
#define vl vector<ll>

ll mod = 1e9 + 7;
ll N = 2e5 + 4;
vl fact(N, 1);

ll mul_mod(ll a,ll b, ll mod) {
    return ((a%mod)*(b%mod))%mod;
}


ll modinverse(ll a) {
    ll m = mod, y = 0, x = 1;
    while (a > 1) {
        ll q = a / m;
        ll t = m;
        m = a % m;
        a = t;
        t = y;
        y = x - q * y;
        x = t;
    }
    if (x < 0)
        x += mod;
    return x;
}

ll nCr(ll n, ll r) {
    if (n < r || r < 0)
        return 0;
    ll p = mul_mod(modinverse(fact[r]), modinverse(fact[n - r]),mod);
    p %= mod;
    return mul_mod(p,fact[n],mod);
}
ll power(ll a, ll b, ll md) {
    if(a == 0) return 0;
    if (b < 0)
        return 0;
    ll product = 1;
    a %= md;
    while (b) {
        if (b & 1)
            product = (product * a) % md;
        a = (a * a) % md;
        b /= 2;
    }
    return product % md;
}



ll add_mod(ll a,ll b, ll mod) {
    return ((a%mod)+(b%mod))%mod;
}



int main() {
    forl(i, 1, N) {
        fact[i] = mul_mod(i, fact[i - 1],mod);
        fact[i] %= mod;
    }
    int t;
    cin >> t;
    while (t--) {
        ll n;
        cin >> n;
        ll m;
        cin >> m;
        ll k;
        cin >> k;

        // greater = m-k

        // smaller = k-1

        // k-1 elements >= k
        // dp[i] no of selections when K appears at ith position.
        // dp[i] =
        // (dp[i-1]*(k-1)+(dp[i-1]-ncr(i-2,k-1)*(m-k)^(k-1)*(k-1)^(i-k-1))*(m-k))*m^(n-i)

        // sum of all dp[i] * (m^)
        // dp[1] = m^(n-1)
        // dp[2] = m^(n-1)
        vector<ll> dp(n+5, 1);
        ll ans = 0;
        for (int i = 1; i <= k; i++) {
            ll d = power(m, n - i, mod);
            ll c = power(m - 1, i - 1, mod);
            dp[i] = c % mod;
            ans += mul_mod(d,dp[i],mod);
            ans = ans%mod;
        }

        for (int i = k + 1; i <= n; i++) {
            ll t = dp[i - 1] * (k - 1);
            t = t % mod;
            ll t1 = 0;
            t1 = add_mod(t1,dp[i - 1],mod);
            t1 = t1 % mod;
            ll g = nCr(i - 2, k - 1); g=g%mod;
            g = mul_mod(g,power(m - k, k - 1, mod),mod);
            g = g % mod;
            g = mul_mod(g,power(m - k, k - 1, mod),mod);
            g = g % mod;
            t1= add_mod(mod-g,t1,mod);
            t1 = t1 % mod;
            t1 = mul_mod(t,m-k,mod);
            t1 = t1 % mod;
            t =add_mod(t,t1,mod);
            t = t % mod;
            dp[i] = t;
            t = mul_mod(dp[i],power(m, n - i, mod),mod);
            t = t % mod;
            ans = add_mod(ans,t,mod);
        }
        cout << ans << "\n";
    }
}

I have tried to model my solution on the Editorialist’s code (Python). Could anyone tell me my mistake? Why is it not able to get AC?

Why the >= K condition though, when the = case never comes up?
I re-read the statement several times because of that, to make sure I hadn’t read anything incorrectly.

Thanks a lot! That answers my question.

I thought n\choose r is undefined for r < 0 and that’s why I thought considering it 0 might not make sense and hence I got confused.