Authors: lawliet_p and satyam_343
Tester: tabr
Editorialist: iceknight1093




Combinatorics, specifically stars and bars and the inclusion-exclusion principle


For a fixed parameter D and a multiset S, the cost of a partition of S into several non-empty multisets S_1, S_2, \ldots, S_k equals

\sum_{i=1}^k (D + \max(S_i) - \min(S_i))

Define F(S, D) to be the minimum cost of a partition of S with parameter D.

You’re given N, M, D, and K.
Find the number of distinct multisets A of length N such that:

  • 1 \leq A_i \leq M for each i; and
  • F(A, D) = K


First, let’s see how to compute F(A, D) for a fixed multiset A.

Computing F(A, D)

Let A_1 \leq A_2 \leq \ldots \leq A_N be the multiset.
We’ll call a partition optimal if it attains F(A, D).

Claim: There exists an optimal partition such that every subset will consist of a contiguous segment of the A_i.
Proof: Consider some partition of A into subsets. Suppose there are indices i \lt j \lt k such that i and k belong to the same subset (say S_1), but j doesn’t (say it’s in S_2).
Then, moving j from S_2 to S_1:

  • Doesn’t change the contribution of S_1 to the cost at all, since j is in the middle.
  • Doesn’t increase the contribution of S_2 to the cost: either j was in the middle of S_2 and didn’t matter, or it was an endpoint and its removal brought the endpoints closer, lowering the cost. (If j was the only element of S_2, the subset disappears entirely, saving D.)

So, moving j to S_1 is not worse.
By repeatedly performing this process on the j with minimal index, we see that in at most N moves we reach a state where each subset is a segment, thus proving the claim.

Now, consider some partition of A into segments; say there are k segments [L_i, R_i].
The cost of this partition is

\sum_{i=1}^k (D + A_{R_i} - A_{L_i}) = k\cdot D + \sum_{i=1}^k (A_{R_i} - A_{L_i}) \\ = k\cdot D + \sum_{i=1}^k\left(\sum_{j=L_i}^{R_i - 1} (A_{j+1} - A_j) \right)

This means we’re essentially summing up differences between all pairs of adjacent elements that lie in the same segment; and adding D for each adjacent pair that isn’t in the same segment (along with one extra D term).
Thinking of this differently, for each adjacent pair of elements (A_i, A_{i+1}), we can:

  • Take them into the same segment, for a cost of A_{i+1} - A_i; or
  • Keep them in different segments, for a cost of D.

Together, this tells us that the minimum possible cost is simply

F(A, D) = D + \sum_{i=2}^N \min(A_i - A_{i-1}, D)
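
As a sanity check on the closed form above, here is a minimal sketch that compares it against an exhaustive search over all partitions of a tiny multiset. The function names `f_formula` and `f_bruteforce` are hypothetical helpers introduced for illustration only; they are not part of the reference solution.

```python
from itertools import product

def f_formula(a, d):
    """F(A, D) via the closed form: D plus the adjacent gaps of sorted A, each capped at D."""
    s = sorted(a)
    return d + sum(min(s[i] - s[i - 1], d) for i in range(1, len(s)))

def f_bruteforce(a, d):
    """Minimum cost over every way of labelling elements with a group (tiny inputs only)."""
    n = len(a)
    best = None
    # Give each element a group label in [0, n); labels that go unused form no group.
    for labels in product(range(n), repeat=n):
        groups = {}
        for value, g in zip(a, labels):
            groups.setdefault(g, []).append(value)
        cost = sum(d + max(g) - min(g) for g in groups.values())
        if best is None or cost < best:
            best = cost
    return best

print(f_formula([1, 2, 10], 3), f_bruteforce([1, 2, 10], 3))
```

The brute force tries n^n labellings, so it is only usable for n up to about 5; its purpose is to confirm both the contiguous-segment claim and the formula on small cases.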

To apply the above formula when counting, we also need a convenient model of a multiset.

Consider a multiset A of length N with elements between 1 and M, such that A_i \leq A_{i+1}.
Note that this multiset is determined uniquely by:

  • The value of A_1;
  • The sequence of adjacent differences (A_{i+1} - A_i); and
  • The value A_N

This is a useful characterization because as we saw earlier, we’re interested in adjacent differences.
Further, if we define A_0 := 0 and A_{N+1} := M, then A_1 and A_N are also defined by adjacent differences (to 0 and M respectively).

That is, if we let B_i = A_i - A_{i-1} for 1 \leq i \leq N+1, then A is determined uniquely by the N+1 values [B_1, B_2, \ldots, B_{N+1}].
Note that there are a couple of constraints on the B_i values:

  • B_1 \geq 1, because we want A_1 \geq 1.
  • B_1 + B_2 + \ldots + B_{N+1} = M, since we start at 0 and end at M.
    Considering (B_1 - 1) instead of B_1, we replace the above two constraints by B_1 \geq 0 and sum(B_i) = M-1 instead.
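
The encoding described above can be sketched as a pair of conversion functions. The names `to_gaps` and `from_gaps` are hypothetical, and the sketch assumes the shifted convention in which B_1 has already been reduced by 1.

```python
def to_gaps(a, m):
    """Encode a multiset with values in [1, m] as the N+1 gaps B_1..B_{N+1} (B_1 shifted down by 1)."""
    s = [0] + sorted(a) + [m]          # sentinels A_0 = 0 and A_{N+1} = M
    b = [s[i] - s[i - 1] for i in range(1, len(s))]
    b[0] -= 1                          # B_1 >= 1 becomes B_1 >= 0 after the shift
    return b

def from_gaps(b):
    """Decode the shifted gaps back into the sorted multiset."""
    a, cur = [], 1                     # start at 1 to undo the shift on B_1
    for gap in b[:-1]:                 # the last gap only records the distance to M
        cur += gap
        a.append(cur)
    return a

gaps = to_gaps([2, 2, 5], 6)
print(gaps, sum(gaps), from_gaps(gaps))
```

For A = [2, 2, 5] and M = 6 the gaps are [1, 0, 3, 1], which sum to M - 1 = 5, and decoding them recovers the original multiset.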

Let’s now see how we can count valid multisets.
B_1 and B_{N+1} don’t contribute to the cost at all.
For each 1 \lt i \leq N, we add \min(D, B_i) to the cost.
There’s always an extra D added to the cost; so let’s just work with K' = K-D as the target cost instead.

We’ll treat elements that are \lt D and \geq D differently.
Suppose x of the B_i with 2 \leq i \leq N are \geq D.
We fix their positions in C(N-1, x) ways.
Note that:

  • Each of these x indices contributes D to the cost, and must contain a value that is \geq D.
  • Each of the other N-1-x indices contributes B_i to the cost, and must contain a value that is \lt D.
  • Further, the sum of B_i in the second case must equal exactly K' - x\cdot D for the total cost to equal K'.
  • B_1 and B_{N+1} are mostly unconstrained; they just need to be \geq 0, and as noted above the sum of all B_i should be M-1.

Without loss of generality, suppose B_2, B_3, \ldots, B_{x+1} are to be \geq D.
Then, for each of them, we can write B_i = D + C_i, where C_i \geq 0.
The C_i are otherwise unconstrained.

Putting everything together, we have:

B_1 + B_2 + \ldots + B_{N+1} = M-1 \\ B_1 + B_{N+1} + (B_2 + B_3 + \ldots + B_{x+1}) + (B_{x+2} + \ldots + B_N) = M-1 \\ B_1 + B_{N+1} +x\cdot D + (C_2 + C_3 + \ldots + C_{x+1}) + (K' - x\cdot D) = M-1 \\ B_1 + B_{N+1} + C_2 + C_3 + \ldots + C_{x+1} = M-1-K' \\

Here, each of the C_i, and both B_1 and B_{N+1} only need to be non-negative integers.
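The number of solutions to this equation can thus be found by stars and bars: there are x+2 non-negative unknowns summing to M-1-K', giving \binom{M-K'+x}{x+1} solutions (and none if M-1-K' \lt 0).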
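
For reference, a minimal sketch of the stars-and-bars count used here, without modular arithmetic. The names `stars_and_bars` and `outer_ways` are hypothetical; `outer_ways` assumes the x+2 unknowns are B_1, B_{N+1} and the x offsets C_i.

```python
from math import comb
from itertools import product

def stars_and_bars(k, n):
    """Number of ways to write n as an ordered sum of k non-negative integers."""
    if n < 0:
        return 0
    if k == 0:
        return 1 if n == 0 else 0
    return comb(n + k - 1, k - 1)

def outer_ways(x, m, k_prime):
    """Ways to choose B_1, B_{N+1} and the x offsets C_i so that they sum to M-1-K'."""
    return stars_and_bars(x + 2, m - 1 - k_prime)

# Cross-check against direct enumeration for one small case.
enumerated = sum(1 for t in product(range(3), repeat=3) if sum(t) == 2)
print(stars_and_bars(3, 2), enumerated)
```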

That only leaves the values [B_{x+2}, B_{x+3}, \ldots, B_{N}].
Each of these must be \lt D, and their overall sum must equal K' - x\cdot D.
Counting the number of solutions to this can be done by combining the stars-and-bars method with inclusion-exclusion.


If there were no upper bound D, this would be a direct application of stars-and-bars.

Now, let’s define P_i: the number of arrangements such that exactly i of the indices violate the upper bound, i.e., are \geq D.
We want to compute P_0.
As is usual for inclusion-exclusion tasks, when exactness is hard to deal with, we relax the constraints a bit.
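Define Q_i to be the number of arrangements in which a fixed set of i indices is required to violate the upper bound, with no restriction on the remaining indices (so at least those i indices violate it).

Computing Q_i is not hard: subtract D from each of the i chosen positions, after which all N-1-x values are free non-negative integers with target sum K' - x\cdot D - i\cdot D (the same trick we used earlier when converting B_i to C_i). By symmetry, the count does not depend on which i positions were chosen; that choice is accounted for by a binomial factor.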

Inclusion-exclusion then tells us that

P_0 = \sum_{i=0}^{N-1-x} (-1)^i \binom{N-1-x}{i} Q_i

which can be computed in \mathcal{O}(N) time once all the Q_i are known.
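
The inclusion-exclusion step can be sketched as follows, again without modular arithmetic. The function names `free_count` and `bounded_count` are hypothetical, and the sketch assumes D \geq 1; here `k` is the number of variables (N-1-x in the text) and `n` is their target sum (K' - x\cdot D in the text).

```python
from math import comb
from itertools import product

def free_count(k, n):
    """Ordered ways to write n as a sum of k non-negative integers (stars and bars)."""
    if n < 0:
        return 0
    if k == 0:
        return 1 if n == 0 else 0
    return comb(n + k - 1, k - 1)

def bounded_count(k, n, d):
    """Number of k-tuples with every entry in [0, d-1] summing to n (assumes d >= 1)."""
    total = 0
    for i in range(k + 1):
        if n - i * d < 0:
            break  # forcing i entries to be >= d already overshoots the target
        # Choose which i entries are forced to be >= d, then count freely.
        total += (-1) ** i * comb(k, i) * free_count(k, n - i * d)
    return total

# Pairs with entries in {0, 1, 2} summing to 3: (1, 2) and (2, 1).
print(bounded_count(2, 3, 3))
```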

The problem is thus solved in \mathcal{O}(N^2) time.


\mathcal{O}(N^2) per testcase.


Editorialist's code (Python)
mod = 998244353
lim = 2 * 10**6 + 20
fac = [1] * lim  # factorials modulo mod, filled in by the loop below
for i in range(1, lim):
    fac[i] = fac[i-1] * i % mod
invf = [0]*lim
invf[-1] = pow(fac[-1], mod-2, mod)
for i in reversed(range(lim-1)):
    invf[i] = invf[i+1] * (i+1) % mod

def C(n, r):
    if n < r or r < 0: return 0
    return fac[n] * invf[r] % mod * invf[n-r] % mod

def calc(k, n): # sum k non-negative integers to get n
    if k == 0: return 1 if n == 0 else 0
    return C(n+k-1, n)

for _ in range(int(input())):
    n, m, d, k = map(int, input().split())
    ans = 0
    k -= d

    for i in range(n):
        # Fix the number of positions that are >= d
        if i*d > k: break
        ways = C(n-1, i) * calc(i+2, m-1-k) % mod
        sm, sign = 0, 1
        for j in range(n-i):
            if k - i*d - j*d < 0: break
            sm += C(n-1-i, j) * calc(n-1-i, k - i*d - j*d) % mod * sign % mod
            sign *= -1
        ans += ways * sm % mod
    print(ans % mod)
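
When testing a solution like the one above, a brute-force counter for small inputs is handy. The following is a minimal sketch that enumerates every multiset directly and relies on the closed form for F(A, D) derived earlier; the name `brute_count` is hypothetical, and the function is only practical for small N and M.

```python
from itertools import combinations_with_replacement

def brute_count(n, m, d, k):
    """Count multisets of size n with values in [1, m] whose minimum partition cost equals k."""
    count = 0
    # combinations_with_replacement over a sorted range yields non-decreasing tuples,
    # i.e. each multiset exactly once, already sorted.
    for a in combinations_with_replacement(range(1, m + 1), n):
        cost = d + sum(min(a[i] - a[i - 1], d) for i in range(1, n))
        if cost == k:
            count += 1
    return count

# N = 2, M = 2, D = 1: {1,1} and {2,2} cost 1, while {1,2} costs 2.
print(brute_count(2, 2, 1, 1), brute_count(2, 2, 1, 2))
```

Comparing its output with the fast solution on random small tests is a quick way to catch off-by-one errors in the binomial arguments.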

Extremely satisfying explanation!! Thanks a lot!