ARRDEL7 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: raysh_07
Tester: mridulahi
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Combinatorics

PROBLEM:

Given an array A of length N and a fixed parameter K, you can do the following:

  • Choose any integer X, and delete every element of A that lies between X and X+K-1 (inclusive).

Let f(A) denote the minimum number of such moves needed to delete every element of A.
Given N, M, K, compute the sum of f(A) across all strictly increasing arrays of length N with elements from 1 to M, modulo 10^9 + 7.

EXPLANATION:

First, let’s compute f(A) for a single array.
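This is quite simple. The smallest element of A must be deleted by some move, and any such move has X \leq \min(A), so it deletes only elements in [\min(A), \min(A)+K-1]. Choosing X = \min(A) deletes all of those at once, so the optimal strategy is to keep repeatedly choosing X to be the minimum remaining element of A till everything's gone.

Consecutive chosen minimums are at least K apart, so we'll always need \leq \left\lceil \frac{M}{K} \right\rceil moves to delete any array A with elements between 1 and M.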
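To check small cases by hand, here is a minimal C++ sketch of the greedy together with a brute force over all arrays straight from the problem definition. The function names `greedyMoves` and `sumByBruteForce` are hypothetical (not from the reference codes), and the brute force is exponential in M, so it is only meant as a reference for tiny inputs.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// f(A) via the greedy: repeatedly choose X = smallest remaining element,
// which deletes every element of A in [X, X + K - 1]. A must be strictly increasing.
long long greedyMoves(const std::vector<long long>& a, long long K) {
    assert(K >= 1);
    long long moves = 0;
    std::size_t i = 0;
    while (i < a.size()) {
        long long x = a[i];  // smallest element not yet deleted
        ++moves;
        while (i < a.size() && a[i] <= x + K - 1) ++i;  // skip everything in [x, x + K - 1]
    }
    return moves;
}

// Sum of f(A) over all strictly increasing arrays of length N with values in [1, M].
// Each array is a size-N subset of {1..M}, encoded as a bitmask; tiny M only.
long long sumByBruteForce(int N, int M, long long K) {
    long long total = 0;
    for (long long mask = 0; mask < (1LL << M); ++mask) {
        std::vector<long long> a;
        for (int v = 1; v <= M; ++v)
            if ((mask >> (v - 1)) & 1) a.push_back(v);  // values come out in ascending order
        if (static_cast<int>(a.size()) == N) total += greedyMoves(a, K);
    }
    return total;
}
```

For instance, with N = 2, M = 5, K = 2, the four pairs of adjacent values need one move and the other six pairs need two, so `sumByBruteForce(2, 5, 2)` should be 4 + 12 = 16.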

Let’s fix f(A) = c, and try to find the number of arrays that require exactly c moves to delete.
As observed earlier, this means we pick the minimum remaining element of A exactly c times.

Suppose these minimums are m_1 \lt m_2 \lt \ldots\lt m_c.
Then, the following conditions must hold:

  • 1 \leq m_i \leq M for each 1 \leq i \leq c.
  • m_i+K \leq m_{i+1} for each 1 \leq i \lt c.
    This ensures that m_{i+1} isn't deleted by the move at m_i, which deletes everything in [m_i, m_i+K-1].

It’s not too hard to count the number of ways of choosing these m_i.
If we let d_i = m_i - m_{i-1} (with m_0 = 0 and m_{c+1} = M), choosing the m_i is equivalent to assigning values to the d_i such that:

  • The sum of all d_i is exactly M (we start at 0 and continually add differences till we reach M)
  • d_1 \gt 0
  • d_i \geq K for each 2 \leq i \leq c
  • d_{c+1} \geq 0

Counting the number of solutions to such a system is a classical stars-and-bars task: letting d_1' = d_1-1 and d_i' = d_i - K for 2 \leq i \leq c, we just want the number of ways for c+1 non-negative integers to sum up to M - 1 - (c-1)K, which is \binom{M - 1 - (c-1)K + c}{c}.
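The stars-and-bars count can be sanity-checked against a direct enumeration on small values. The sketch below uses exact (non-modular) arithmetic, and `binom`, `minimumsByFormula` and `minimumsByBruteForce` are hypothetical helper names. For M = 5, K = 2, c = 2 both should give 6, namely (1,3), (1,4), (1,5), (2,4), (2,5), (3,5).

```cpp
#include <cassert>

// Exact binomial coefficient (small arguments only); 0 when n < 0, r < 0 or r > n.
long long binom(long long n, long long r) {
    if (n < 0 || r < 0 || r > n) return 0;
    long long res = 1;
    for (long long i = 1; i <= r; ++i) res = res * (n - r + i) / i;  // res = C(n - r + i, i)
    return res;
}

// Stars and bars: c + 1 non-negative variables summing to S = M - 1 - (c - 1)K
// gives C(S + c, c) ways to choose the minimums m_1 < ... < m_c.
long long minimumsByFormula(long long M, long long K, long long c) {
    assert(c >= 1 && K >= 1);
    return binom(M - 1 - (c - 1) * K + c, c);
}

// Direct enumeration: choose c values from [lo, M] with consecutive gaps of at least K.
long long minimumsByBruteForce(long long lo, long long M, long long K, long long c) {
    if (c == 0) return 1;
    long long total = 0;
    for (long long v = lo; v <= M; ++v) total += minimumsByBruteForce(v + K, M, K, c - 1);
    return total;
}
```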

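What about the other elements, i.e., the ones that aren't chosen as the minimums?
There are N-c of them, but they can't be arbitrary: each of them has to be deleted by one of the c moves.
In particular, these N-c elements should be of the form m_i + x, where 1 \leq x \lt K and m_i + x \leq M; that's how we ensure that they are indeed deleted when the m_i are picked.
Conversely, since m_{i+1} \geq m_i + K, every such element lies strictly between m_i and m_{i+1} (or above m_c), so adding it doesn't change the minimums chosen by the greedy.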

Suppose there are a total of L elements of this form.
We can choose any N-c of them, for a total of \binom{L}{N-c} ways. All we need to do is actually find L.
That isn’t too hard:

  • For each of m_1, m_2, \ldots, m_{c-1} there are (K-1) such choices.
  • For m_c, there are K-1 choices if m_c + K-1 \leq M, and M-m_c choices otherwise.
    • So, if m_c \leq M+1-K, we have L =c\cdot (K-1)
    • Otherwise, we have L = (c-1)\cdot (K-1) + M-m_c
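Here is a small sketch of this case split next to a direct count of the candidate values, assuming the minimums are passed explicitly as a non-empty vector with gaps of at least K and m_c \leq M. Both function names are hypothetical.

```cpp
#include <cassert>
#include <vector>

// L from the case split: K - 1 candidates after each of m_1..m_{c-1}, and after m_c
// either K - 1 (if m_c + K - 1 <= M) or only M - m_c (the block is cut off by M).
long long candidatesByFormula(long long M, long long K, const std::vector<long long>& mins) {
    assert(!mins.empty() && mins.back() <= M);
    long long c = static_cast<long long>(mins.size());
    long long mc = mins.back();
    long long lastBlock = (mc + K - 1 <= M) ? K - 1 : M - mc;
    return (c - 1) * (K - 1) + lastBlock;
}

// Direct count of the values m_i + x with 1 <= x < K and m_i + x <= M.
long long candidatesDirect(long long M, long long K, const std::vector<long long>& mins) {
    long long L = 0;
    for (long long m : mins)
        for (long long x = 1; x < K; ++x)
            if (m + x <= M) ++L;
    return L;
}
```

With M = 10, K = 3 and minimums {2, 9}, the block after 2 contributes 3 and 4, while the block after 9 is cut off by M and contributes only 10, so both functions give L = 3.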

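Note that the stars-and-bars part also needs to be modified slightly to account for these cases.

  • If m_c \leq M+1-K, we're essentially asking for d_{c+1} \geq K-1 rather than d_{c+1} \geq 0, which gives \binom{M - cK + c}{c} ways to choose the minimums; each is combined with \binom{c(K-1)}{N-c} ways to choose the other elements.
  • Otherwise, we force d_{c+1} = M - m_c, so d_1, \ldots, d_c must sum to m_c; this gives \binom{m_c - 2 - (c-1)K + c}{c-1} ways to choose m_1, \ldots, m_{c-1}, each combined with \binom{(c-1)(K-1) + M - m_c}{N-c} ways to choose the other elements.
    However, observe that this happens only for the (at most) K-1 values M+2-K \leq m_c \leq M, so they can be bruteforced: since the number of moves is \leq \left\lceil \frac{M}{K} \right\rceil, the total work is \mathcal{O}\left(\left\lceil \frac{M}{K} \right\rceil \cdot K\right) = \mathcal{O}(M + K), which is \mathcal{O}(M) when K \leq M.

Finally, the answer is the sum of c times the number of arrays with f(A) = c, over all 1 \leq c \leq \min\left(N, \left\lceil \frac{M}{K} \right\rceil\right).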
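Putting the pieces together, a small exact-arithmetic sketch of the whole count might look as follows. It is not the author's code: it omits the modulus and the factorial tables (so it only works for small N, M, K), and `binom` and `sumByFormula` are hypothetical names. The outer loop runs over c, the first term handles m_c \leq M+1-K, and the inner loop brute-forces the values of m_c above M+1-K.

```cpp
#include <algorithm>
#include <cassert>

// Exact binomial coefficient (small arguments only); 0 when n < 0, r < 0 or r > n.
long long binom(long long n, long long r) {
    if (n < 0 || r < 0 || r > n) return 0;
    long long res = 1;
    for (long long i = 1; i <= r; ++i) res = res * (n - r + i) / i;  // res = C(n - r + i, i)
    return res;
}

// Sum of f(A) over all strictly increasing arrays of length N with values in [1, M],
// using the case split on m_c. No modulus, so small inputs only.
long long sumByFormula(long long N, long long M, long long K) {
    assert(N >= 1 && M >= 1 && K >= 1);
    long long total = 0;
    for (long long c = 1; c <= N && (c - 1) * K + 1 <= M; ++c) {  // c <= ceil(M / K)
        // Case m_c <= M + 1 - K: d_{c+1} >= K - 1, and L = c(K - 1).
        long long ways = binom(M - c * K + c, c) * binom(c * (K - 1), N - c);
        // Case m_c > M + 1 - K: at most K - 1 values of m_c, handled one by one.
        for (long long mc = std::max(1LL, M + 2 - K); mc <= M; ++mc) {
            long long placeMins = binom(mc - 2 - (c - 1) * K + c, c - 1);  // d_1..d_c sum to m_c
            long long L = (c - 1) * (K - 1) + (M - mc);
            ways += placeMins * binom(L, N - c);
        }
        total += c * ways;  // each of these arrays has f(A) = c
    }
    return total;
}
```

On small inputs this should agree with a brute force over all arrays; for example, N = 2, M = 5, K = 2 gives 16, and K = 1 gives N \cdot \binom{M}{N}. The reference codes below follow the same case split, but compute binomials modulo 10^9 + 7 from precomputed factorials.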

TIME COMPLEXITY:

\mathcal{O}(M + K) per testcase (which is \mathcal{O}(M) when K \leq M), plus a one-time precomputation of factorials.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18
#define f first
#define s second

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

const int facN = 2e6 + 5;
const int mod = 1e9 + 7; // 998244353
int ff[facN], iff[facN];
bool facinit = false;

int power(int x, int y){
    if (y == 0) return 1;

    int v = power(x, y / 2);
    v = 1LL * v * v % mod;

    if (y & 1) return 1LL * v * x % mod;
    else return v;
}

void factorialinit(){
    facinit = true;
    ff[0] = iff[0] = 1;

    for (int i = 1; i < facN; i++){
        ff[i] = 1LL * ff[i - 1] * i % mod;
    }

    iff[facN - 1] = power(ff[facN - 1], mod - 2);
    for (int i = facN - 2; i >= 1; i--){
        iff[i] = 1LL * iff[i + 1] * (i + 1) % mod;
    }
}

int C(int n, int r){
    if (!facinit) factorialinit();

    if (n == r) return 1;

    if (r < 0 || r > n) return 0;
    return 1LL * ff[n] * iff[r] % mod * iff[n - r] % mod;
}

int P(int n, int r){
    if (!facinit) factorialinit();

    assert(0 <= r && r <= n);
    return 1LL * ff[n] * iff[n - r] % mod;
}

int Solutions(int n, int r){
    //solutions to x1 + ... + xn = r 
    //xi >= 0

    return C(n + r - 1, n - 1);
}

void Solve() 
{
    int n, m, k; cin >> n >> m >> k;
    int ans = 0;

    for (int ops = 1; ops <= (m + k - 1) / k; ops++){
        // two cases : 
        int ways = 0;
        int left = n - ops;
        
        // Case 1 : last element + k - 1 > M 
        // last > M + 1 - k 
        for (int last = m + 2 - k; last <= m; last++){
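            // ways to place the first ops - 1 minimums given last; Solutions() returns 0
            // when last - 1 - (ops - 1) * k < 0 (e.g. last <= 0, possible if k > m + 1)
            int ok = Solutions(ops, last - 1 - (ops - 1) * k);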
            int av = (k - 1) * (ops - 1) + m - last;
            int ok1 = C(av, left);
            
            ways += ok1 * ok % mod;
        }
        
        // Case 2 : last element + k - 1 <= M 
        // last <= M + 1 - k 
        // d1 + d2 + ..... + d(ops + 1) = M + 2 - k 
        // d1 >= 1, d2 >= k, ...., d(ops + 1) >= 1 
        int ok = Solutions(ops + 1, m + 2 - k - 2 - (ops - 1) * k);
        int av = (k - 1) * ops;
        int ok1 = C(av, left);
        
        ways += ok1 * ok % mod;
        
        ways %= mod;
        assert(ways >= 0);
        
      //  cout << ways << "\n";
        ans += ways * ops;
        ans %= mod;
    }

    ans %= mod;

    cout << ans << "\n";
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}
Tester's code (C++)
#include<bits/stdc++.h>

using namespace std;

#define all(x) begin(x), end(x)
#define sz(x) static_cast<int>((x).size())
#define int long long

const int mod = 1e9 + 7;

int norm (int x) {
        if (x < 0) {
                x += mod;
        }
        if (x >= mod) {
                x -= mod;
        }
        return x;
}
template<class T>
T power(T a, int b) {
        T res = 1;
        for (; b; b /= 2, a *= a) {
                if (b % 2) {
                res *= a;
                }
        }
        return res;
}
struct Z {
        int x;
        Z(int x = 0) : x(norm(x)) {}
        int val() const {
                return x;
        }
        Z operator-() const {
                return Z(norm(mod - x));
        }
        Z inv() const {
                assert(x != 0);
                return power(*this, mod - 2);
        }
        Z &operator*=(const Z &rhs) {
                x = x * rhs.x % mod;
                return *this;
        }
        Z &operator+=(const Z &rhs) {
                x = norm(x + rhs.x);
                return *this;
        }
        Z &operator-=(const Z &rhs) {
                x = norm(x - rhs.x);
                return *this;
        }
        Z &operator/=(const Z &rhs) {
                return *this *= rhs.inv();
        }
        friend Z operator*(const Z &lhs, const Z &rhs) {
                Z res = lhs;
                res *= rhs;
                return res;
        }
        friend Z operator+(const Z &lhs, const Z &rhs) {
                Z res = lhs;
                res += rhs;
                return res;
        }
        friend Z operator-(const Z &lhs, const Z &rhs) {
                Z res = lhs;
                res -= rhs;
                return res;
        }
        friend Z operator/(const Z &lhs, const Z &rhs) {
                Z res = lhs;
                res /= rhs;
                return res;
        }
        friend std::istream &operator>>(std::istream &is, Z &a) {
                int v;
                is >> v;
                a = Z(v);
                return is;
        }
        friend std::ostream &operator<<(std::ostream &os, const Z &a) {
                return os << a.val();
        }
};

const int maxn = 3e6;
Z fact[maxn];
Z ifact[maxn];
Z C (int n, int r) {
        if (n < r) return Z(0);
        return fact[n] * ifact[r] * ifact[n - r];
}




signed main() {

        ios::sync_with_stdio(0);
        cin.tie(0);
        cout.tie(0);

        fact[0] = 1;
        for (int i = 1; i < maxn; i++) fact[i] = fact[i - 1] * i;
        ifact[maxn - 1] = Z(1) / fact[maxn - 1]; 
        for (int i = maxn - 2; i >= 0; i--) ifact[i] = ifact[i + 1] * (i + 1); 

        int t;
        cin >> t;

        while (t--) {

                int n, m, k;
                cin >> n >> m >> k;

                Z ans = 0;
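                // Case m_c > M + 1 - K: write m_c = M + 1 - i with 1 <= i < K,
                // and let j = c - 1 be the number of earlier minimums.
                for (int i = 1; i < k; i++) {
                        for (int j = 0; (1 + j) <= n && j * k <= m - i; j++) {
                                // (ways to place m_1..m_{c-1}) * C(L, N - c) * c
                                ans += C(m - i - j * k + j, j) * C(j * k + i - j - 1, n - j - 1) * (j + 1);
                        }
                }
                
                // Case m_c <= M + 1 - K, with j = c moves and L = c(K - 1).
                for (int j = 1; j <= n && j * k <= m; j++) {
                        ans += C(m - j * k + j, j) * C(j * k - j, n - j) * j;
                }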

                cout << ans << "\n";


        }
        
}