PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Authors: lawliet_p and satyam_343
Tester: tabr
Editorialist: iceknight1093
DIFFICULTY:
3096
PREREQUISITES:
Combinatorics: specifically stars and bars, and the inclusion-exclusion principle
PROBLEM:
For a fixed parameter D and a multiset S, the cost of a partition of S into several non-empty multisets S_1, S_2, \ldots, S_k equals
k\cdot D + \sum_{i=1}^{k} \left(\max(S_i) - \min(S_i)\right)
Define F(S, D) to be the minimum cost of a partition of S with parameter D.
You’re given N, M, D, and K.
Find the number, modulo 998244353, of distinct multisets A of length N such that:
- 1 \leq A_i \leq M for each i; and
- F(A, D) = K
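For tiny inputs, F(S, D) can be checked directly from its definition. The sketch below assumes the cost of a partition into k parts is k\cdot D plus the sum of (max - min) over the parts, and enumerates every set partition of S, so it runs in exponential time and is only meant for testing. The helper names `set_partitions` and `brute_f` are our own, not taken from the original solutions.

```python
from typing import Iterator, List


def set_partitions(items: List[int]) -> Iterator[List[List[int]]]:
    """Yield every partition of `items` into non-empty blocks."""
    if not items:
        yield []
        return
    first, rest = items[0], items[1:]
    for partition in set_partitions(rest):
        # put `first` into each existing block in turn ...
        for i in range(len(partition)):
            yield partition[:i] + [[first] + partition[i]] + partition[i + 1:]
        # ... or into a new block of its own
        yield [[first]] + partition


def brute_f(s: List[int], d: int) -> int:
    """Minimum of k*d + sum(max - min) over all partitions of s into k parts."""
    return min(
        len(p) * d + sum(max(block) - min(block) for block in p)
        for p in set_partitions(list(s))
    )


print(brute_f([1, 2, 10], 3))  # 7: parts {1, 2} and {10}
```

For [1, 2, 10] with D = 3, the best choice is {1, 2} and {10}, costing 1 + 0 + 2\cdot 3 = 7.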
EXPLANATION:
First, let’s see how to compute F(A, D) for a fixed multiset A.
Computing F(A, D)
Let A_1 \leq A_2 \leq \ldots \leq A_N be the elements of the multiset in sorted order.
We’ll call a partition optimal if it attains F(A, D).
Claim: There exists an optimal partition in which every part consists of a contiguous segment of the sorted A_i.
Proof: Consider any partition of A. Suppose there are indices i \lt j \lt l such that A_i and A_l belong to the same part (say S_1), but A_j doesn’t (say it’s in S_2).
Then, moving A_j from S_2 to S_1:
- doesn’t change the contribution of S_1 to the cost at all, since A_i \leq A_j \leq A_l means its minimum and maximum stay the same;
- doesn’t increase the contribution of S_2, since removing an element can only keep or narrow the gap between the minimum and maximum of S_2. If S_2 becomes empty, the number of parts drops by one and the cost decreases by D.
So moving A_j to S_1 is never worse.
By repeatedly performing this process on the j with minimal index, we see that in at most N moves we reach a state where each subset is a segment, thus proving the claim.
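The claim means only splits of the sorted array into segments need to be considered. A minimal \mathcal{O}(N^2) DP sketch, assuming the same cost model (k\cdot D plus each part's max minus min); `segment_dp_f` is a hypothetical name used only for illustration:

```python
from typing import List


def segment_dp_f(a: List[int], d: int) -> int:
    """Minimum cost over partitions of sorted A into contiguous segments."""
    a = sorted(a)
    n = len(a)
    # dp[i] = cheapest way to split the first i sorted elements into segments
    dp = [0] + [None] * n
    for i in range(1, n + 1):
        # the last segment is a[j..i-1]: it pays d plus its max minus its min
        dp[i] = min(dp[j] + d + a[i - 1] - a[j] for j in range(i))
    return dp[n]


print(segment_dp_f([10, 1, 2], 3))  # 7
```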
Now, consider some partition of A into segments; say there are k segments [L_i, R_i].
The cost of this partition is
k\cdot D + \sum_{i=1}^{k} (A_{R_i} - A_{L_i}) = k\cdot D + \sum_{i=1}^{k} \sum_{j=L_i}^{R_i - 1} (A_{j+1} - A_j)
In other words, we sum the differences of all adjacent pairs that lie in the same segment, and add D for each of the k-1 adjacent pairs that are split between different segments, plus one extra D.
Thinking of this differently, for each adjacent pair of elements (A_i, A_{i+1}), we can:
- Take them into the same segment, for a cost of A_{i+1} - A_i; or
- Keep them in different segments, for a cost of D.
Together, this tells us that the minimum possible cost is simply
F(A, D) = D + \sum_{i=1}^{N-1} \min(D, A_{i+1} - A_i)
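The closed form F(A, D) = D + \sum \min(D, A_{i+1} - A_i) becomes a one-liner over the sorted array. A sketch (the name `closed_form_f` is ours); on small examples it agrees with the partition-based definition, e.g. [1, 2, 10] with D = 3 gives 3 + 1 + 3 = 7:

```python
from typing import List


def closed_form_f(a: List[int], d: int) -> int:
    """F(A, D) = D + sum over adjacent sorted pairs of min(D, gap)."""
    a = sorted(a)
    # each gap is either kept inside a segment (pay the gap) or cut (pay d)
    return d + sum(min(d, a[i + 1] - a[i]) for i in range(len(a) - 1))


print(closed_form_f([10, 1, 2], 3))  # 7
print(closed_form_f([1, 3, 4, 9], 2))  # 7
```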
To count multisets using this formula, we need a convenient way to represent a multiset.
Consider a multiset A of length N with elements between 1 and M, sorted so that A_i \leq A_{i+1}.
Note that this multiset is determined uniquely by:
- The value of A_1;
- The sequence of adjacent differences (A_{i+1} - A_i); and
- The value A_N
This is a useful characterization because, as we saw earlier, the cost depends only on adjacent differences.
Further, if we define A_0 := 0 and A_{N+1} := M, then A_1 and A_N can also be written in terms of adjacent differences, namely A_1 - A_0 and A_{N+1} - A_N.
That is, if we let B_i = A_i - A_{i-1}, then A is determined uniquely by the N+1 values [B_1, B_2, \ldots, B_{N+1}].
Note that the B_i values satisfy a few constraints:
- B_1 \geq 1, because we want A_1 \geq 1.
- B_i \geq 0 for 2 \leq i \leq N+1, since A is sorted and A_N \leq M.
- B_1 + B_2 + \ldots + B_{N+1} = M, since we start at 0 and end at M.
Replacing B_1 with B_1 - 1, these constraints become simply: every B_i \geq 0, and B_1 + B_2 + \ldots + B_{N+1} = M-1.
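The encoding can be sketched as a pair of helpers, assuming the shifted convention in which the first entry stores B_1 - 1, so all entries are non-negative and sum to M - 1. The names `to_b`, `from_b` and `cost_from_b` are hypothetical; `cost_from_b` applies D + \sum_{i=2}^{N} \min(D, B_i), since B_1 and B_{N+1} do not affect the cost.

```python
from typing import List


def to_b(a: List[int], m: int) -> List[int]:
    """Encode a multiset of values in [1, m] as [B_1 - 1, B_2, ..., B_{N+1}]."""
    ext = [0] + sorted(a) + [m]  # A_0 := 0 and A_{N+1} := M
    b = [ext[i] - ext[i - 1] for i in range(1, len(ext))]
    b[0] -= 1  # shifted B_1, so every entry is >= 0 and sum(b) == m - 1
    return b


def from_b(b: List[int]) -> List[int]:
    """Decode [B_1 - 1, B_2, ..., B_{N+1}] back into the sorted multiset."""
    a, cur = [], 1  # start from 1 to undo the shift of B_1
    for step in b[:-1]:  # the last entry only measures the gap up to M
        cur += step
        a.append(cur)
    return a


def cost_from_b(b: List[int], d: int) -> int:
    """F(A, D) = D + sum of min(D, B_i) for 2 <= i <= N."""
    return d + sum(min(d, step) for step in b[1:-1])


b = to_b([5, 2, 2], 6)
print(b, sum(b), from_b(b), cost_from_b(b, 2))  # [1, 0, 3, 1] 5 [2, 2, 5] 4
```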
Let’s now see how we can count valid multisets.
B_1 and B_{N+1} don’t contribute to the cost at all.
For each 1 \lt i \leq N, we add \min(D, B_i) to the cost.
There’s always an extra D added to the cost; so let’s just work with K' = K-D as the target cost instead.
We’ll treat the differences B_2, B_3, \ldots, B_N that are \lt D and those that are \geq D differently.
Suppose exactly x of them are \geq D.
We can choose their positions in C(N-1, x) ways.
Note that:
- Each of these x indices contributes D to the cost, and must hold a value \geq D.
- Each of the other N-1-x indices contributes B_i to the cost, and must hold a value \lt D.
- Further, the sum of B_i in the second case must equal exactly K' - x\cdot D for the total cost to equal K'.
- B_1 and B_{N+1} are mostly unconstrained; they just need to be \geq 0, and as noted above the sum of all B_i should be M-1.
Without loss of generality, suppose B_2, B_3, \ldots, B_{x+1} are to be \geq D.
Then, for each of them, we can write B_i = D + C_i, where C_i \geq 0.
The C_i are otherwise unconstrained.
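Putting everything together, we have:
B_1 + C_2 + C_3 + \ldots + C_{x+1} + B_{N+1} = (M-1) - x\cdot D - (K' - x\cdot D) = M - 1 - K'
Here, each of the C_i, as well as B_1 and B_{N+1}, only needs to be a non-negative integer.
There are x+2 unknowns, so by stars and bars the number of solutions is C(M - 1 - K' + x + 1, x + 1) = C(M - K' + x, x + 1), which is 0 when M - 1 - K' \lt 0.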
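A small sketch of this stars-and-bars count using Python's `math.comb`: the equation B_1 + C_2 + \ldots + C_{x+1} + B_{N+1} = M - 1 - K' has x + 2 non-negative unknowns, hence C(M - 1 - K' + x + 1, x + 1) solutions. The function name `free_part` is illustrative only.

```python
from math import comb


def free_part(m: int, k_prime: int, x: int) -> int:
    """Non-negative solutions of B_1 + C_2 + ... + C_{x+1} + B_{N+1} = M - 1 - K'."""
    total = m - 1 - k_prime
    if total < 0:
        return 0  # K' > M - 1: the target cost cannot fit inside [1, M]
    # stars and bars: `total` stars among x + 2 variables, i.e. x + 1 bars
    return comb(total + x + 1, x + 1)


print(free_part(5, 2, 1))  # 6 solutions of a + b + c = 2
```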
That only leaves the values [B_{x+2}, B_{x+3}, \ldots, B_{N}].
Each of these must be \lt D, and their overall sum must equal K' - x\cdot D.
Counting the number of solutions to this can be done by combining the stars-and-bars method with inclusion-exclusion.
How?
If there were no upper bound, this would be a direct application of stars and bars.
Now, let P_i be the number of arrangements in which exactly i of the indices violate the upper bound, i.e., are \geq D.
We want to compute P_0.
As is usual for inclusion-exclusion tasks, when exactness is hard to deal with, we relax the constraints a bit.
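Define Q_i to be the sum, over every choice of i indices, of the number of arrangements in which all of the chosen indices violate the upper bound. (An arrangement with exactly j violating indices is counted C(j, i) times.)
Computing Q_i is not hard: choose the i violating positions in C(N-1-x, i) ways, write each chosen value as D plus a non-negative integer (the same trick we used when converting B_i to C_i), and count the now-unconstrained solutions with target sum K' - x\cdot D - i\cdot D by stars and bars.
Writing n = N-1-x and S = K' - x\cdot D, inclusion-exclusion then tells us that
P_0 = \sum_{i=0}^{n} (-1)^i Q_i = \sum_{i=0}^{n} (-1)^i \cdot C(n, i) \cdot C(S - i\cdot D + n - 1, n - 1)
where terms with S - i\cdot D \lt 0 vanish, and for n = 0 the count is simply 1 if S = 0 and 0 otherwise.
Each Q_i takes \mathcal{O}(1) time once factorials are precomputed, so P_0 is computed in \mathcal{O}(N) time.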
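The inclusion-exclusion count can be sketched as follows, computing P_0 = \sum_i (-1)^i C(n, i) C(S - i\cdot D + n - 1, n - 1) with exact integers via `math.comb` (no modulus). Here `bounded_count(n, s, d)` is a hypothetical helper counting ordered n-tuples of integers in [0, d-1] with sum s; the n = 0 case is handled separately.

```python
from math import comb


def bounded_count(n: int, s: int, d: int) -> int:
    """Ordered n-tuples of integers in [0, d - 1] with sum s, by inclusion-exclusion."""
    if n == 0:
        return 1 if s == 0 else 0  # empty tuple: only the sum 0 is reachable
    total = 0
    for i in range(n + 1):
        rest = s - i * d  # force i chosen entries to be >= d by subtracting d from each
        if rest < 0:
            break  # this Q_i and all later ones are zero
        # Q_i = C(n, i) * (unbounded stars and bars for `rest` over n entries)
        total += (-1) ** i * comb(n, i) * comb(rest + n - 1, n - 1)
    return total


print(bounded_count(2, 3, 3))  # 2: (1, 2) and (2, 1)
```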
Summing this over every x from 0 to N-1, the problem is thus solved in \mathcal{O}(N^2) time per test case, after precomputing factorials.
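Putting the pieces together, here is an unoptimized sketch of the whole count for small inputs, using exact integers from `math.comb` instead of factorials modulo 998244353. For each x it multiplies C(N-1, x), the free-part stars-and-bars count, and the bounded inclusion-exclusion count. Function names are our own; this is not the setters' code, and it is only intended for cross-checking against brute-force enumeration of sorted tuples.

```python
from math import comb

MOD = 998244353


def free_part(m: int, k_prime: int, x: int) -> int:
    """Solutions of B_1 + C_2 + ... + C_{x+1} + B_{N+1} = M - 1 - K' (all >= 0)."""
    total = m - 1 - k_prime
    if total < 0:
        return 0
    return comb(total + x + 1, x + 1)


def bounded_count(n: int, s: int, d: int) -> int:
    """Ordered n-tuples of integers in [0, d - 1] summing to s (inclusion-exclusion)."""
    if n == 0:
        return 1 if s == 0 else 0
    total = 0
    for i in range(n + 1):
        rest = s - i * d
        if rest < 0:
            break
        total += (-1) ** i * comb(n, i) * comb(rest + n - 1, n - 1)
    return total


def count_multisets(n: int, m: int, d: int, k: int) -> int:
    """Exact number of multisets of n values in [1, m] with F(A, d) == k."""
    k_prime = k - d  # the extra D is always paid, so the target is K' = K - D
    if k_prime < 0:
        return 0
    ans = 0
    for x in range(n):  # x = how many of B_2..B_N are >= d
        small_sum = k_prime - x * d  # sum required from the N-1-x small differences
        if small_sum < 0:
            break
        ans += comb(n - 1, x) * free_part(m, k_prime, x) * bounded_count(n - 1 - x, small_sum, d)
    return ans


print(count_multisets(2, 3, 1, 2) % MOD)  # 3: {1,2}, {1,3}, {2,3}
```

For the real constraints, the solutions in the CODE section compute the same sum with factorials and inverse factorials precomputed modulo 998244353, so each binomial costs \mathcal{O}(1).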
TIME COMPLEXITY
\mathcal{O}(N^2) per testcase.
CODE:
Author's code (C++)
#pragma GCC optimize("O3,unroll-loops")
#include <bits/stdc++.h>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;
using namespace std;
#define ll long long
#define pb push_back
#define mp make_pair
#define nline "\n"
#define f first
#define s second
#define pll pair<ll,ll>
#define all(x) x.begin(),x.end()
#define vl vector<ll>
#define vvl vector<vector<ll>>
#define vvvl vector<vector<vector<ll>>>
#ifndef ONLINE_JUDGE
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);
#endif
void _print(ll x){cerr<<x;}
void _print(char x){cerr<<x;}
void _print(string x){cerr<<x;}
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
template<class T,class V> void _print(pair<T,V> p) {cerr<<"{"; _print(p.first);cerr<<","; _print(p.second);cerr<<"}";}
template<class T>void _print(vector<T> v) {cerr<<" [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T>void _print(set<T> v) {cerr<<" [ "; for (T i:v){_print(i); cerr<<" ";}cerr<<"]";}
template<class T>void _print(multiset<T> v) {cerr<< " [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T,class V>void _print(map<T, V> v) {cerr<<" [ "; for(auto i:v) {_print(i);cerr<<" ";} cerr<<"]";}
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
//--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
const ll MOD=998244353;
const ll MAX=5000500;
vector<ll> fact(MAX+2,1),inv_fact(MAX+2,1);
ll binpow(ll a,ll b,ll MOD){
    ll ans=1;
    a%=MOD;
    while(b){
        if(b&1)
            ans=(ans*a)%MOD;
        b/=2;
        a=(a*a)%MOD;
    }
    return ans;
}
ll inverse(ll a,ll MOD){
    return binpow(a,MOD-2,MOD);
}
void precompute(ll MOD){
    for(ll i=2;i<MAX;i++){
        fact[i]=(fact[i-1]*i)%MOD;
    }
    inv_fact[MAX-1]=inverse(fact[MAX-1],MOD);
    for(ll i=MAX-2;i>=0;i--){
        inv_fact[i]=(inv_fact[i+1]*(i+1))%MOD;
    }
}
ll nCr(ll a,ll b,ll MOD){
    if(a==b){
        return 1;
    }
    if((a<0)||(a<b)||(b<0))
        return 0;
    ll denom=(inv_fact[b]*inv_fact[a-b])%MOD;
    return (denom*fact[a])%MOD;
}
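// Ways to write `need` as an ordered sum of n values in [0, max_val], by inclusion-exclusion.
// For n = 0 the only term is nCr(need-1, -1), which nCr returns as 1 exactly when need == 0.
ll getv(ll n,ll max_val,ll need){
    ll ans=0,sgn=1;
    for(ll i=0;i<=n;i++){
        ans+=sgn*nCr(n,i,MOD)*nCr(need-i*(max_val+1)+n-1,n-1,MOD);
        ans%=MOD;
        sgn*=-1;
    }
    ans=(ans+MOD)%MOD;
    return ans;
}
ll sum=0;
void solve(){
    ll n,m,d,k; cin>>n>>m>>d>>k;
    sum+=n;
    ll ans=0;
    // i = number of adjacent differences that are >= d
    for(ll i=0;i<=n-1;i++){
        ll cur_cost=i*d;
        if(cur_cost>k){
            break;
        }
        // the other n-1-i differences are < d and must sum to K' - i*d, where K' = k - d
        ll rem_cost=k-cur_cost-d;
        ll now=getv(n-1-i,d-1,rem_cost);
        now=(now*nCr(n-1,i,MOD))%MOD;
        // free part: i+2 non-negative variables summing to M-1-K'
        ll left_sum=m-k+d;
        now=(now*nCr(left_sum+i,i+1,MOD))%MOD;
        ans=(ans+now)%MOD;
    }
    cout<<ans<<nline;
    return;
}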
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
#ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    freopen("error.txt", "w", stderr);
#endif
    ll test_cases=1;
    cin>>test_cases;
    precompute(MOD);
    while(test_cases--){
        solve();
    }
    debug(sum);
    cout<<fixed<<setprecision(10);
    cerr<<"Time:"<<1000*((double)clock())/(double)CLOCKS_PER_SEC<<"ms\n";
}
Tester's code (C++)
// ignore \r
#include <bits/stdc++.h>
using namespace std;
struct input_checker {
    string buffer;
    int pos;
    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";
    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            } else if (c == '\r') {
                continue;
            }
            buffer.push_back((char) c);
        }
    }
    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            res += buffer[pos];
            pos++;
        }
        return res;
    }
    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }
    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }
    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }
    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }
    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }
    void readEof() {
        assert((int) buffer.size() == pos);
    }
};
template <long long mod>
struct modular {
    long long value;
    modular(long long x = 0) {
        value = x % mod;
        if (value < 0) value += mod;
    }
    modular& operator+=(const modular& other) {
        if ((value += other.value) >= mod) value -= mod;
        return *this;
    }
    modular& operator-=(const modular& other) {
        if ((value -= other.value) < 0) value += mod;
        return *this;
    }
    modular& operator*=(const modular& other) {
        value = value * other.value % mod;
        return *this;
    }
    modular& operator/=(const modular& other) {
        long long a = 0, b = 1, c = other.value, m = mod;
        while (c != 0) {
            long long t = m / c;
            m -= t * c;
            swap(c, m);
            a -= t * b;
            swap(a, b);
        }
        a %= mod;
        if (a < 0) a += mod;
        value = value * a % mod;
        return *this;
    }
    friend modular operator+(const modular& lhs, const modular& rhs) { return modular(lhs) += rhs; }
    friend modular operator-(const modular& lhs, const modular& rhs) { return modular(lhs) -= rhs; }
    friend modular operator*(const modular& lhs, const modular& rhs) { return modular(lhs) *= rhs; }
    friend modular operator/(const modular& lhs, const modular& rhs) { return modular(lhs) /= rhs; }
    modular& operator++() { return *this += 1; }
    modular& operator--() { return *this -= 1; }
    modular operator++(int) {
        modular res(*this);
        *this += 1;
        return res;
    }
    modular operator--(int) {
        modular res(*this);
        *this -= 1;
        return res;
    }
    modular operator-() const { return modular(-value); }
    bool operator==(const modular& rhs) const { return value == rhs.value; }
    bool operator!=(const modular& rhs) const { return value != rhs.value; }
    bool operator<(const modular& rhs) const { return value < rhs.value; }
};
template <long long mod>
string to_string(const modular<mod>& x) {
    return to_string(x.value);
}
template <long long mod>
ostream& operator<<(ostream& stream, const modular<mod>& x) {
    return stream << x.value;
}
template <long long mod>
istream& operator>>(istream& stream, modular<mod>& x) {
    stream >> x.value;
    x.value %= mod;
    if (x.value < 0) x.value += mod;
    return stream;
}
constexpr long long mod = 998244353;
using mint = modular<mod>;
mint power(mint a, long long n) {
    mint res = 1;
    while (n > 0) {
        if (n & 1) {
            res *= a;
        }
        a *= a;
        n >>= 1;
    }
    return res;
}
vector<mint> fact(1, 1);
vector<mint> finv(1, 1);
mint C(int n, int k) {
    if (n < k || k < 0) {
        return mint(0);
    }
    while ((int) fact.size() < n + 1) {
        fact.emplace_back(fact.back() * (int) fact.size());
        finv.emplace_back(mint(1) / fact.back());
    }
    return fact[n] * finv[k] * finv[n - k];
}
mint f(int n, int k) {
    if (n == 0 && k == 0) {
        return 1;
    } else {
        return C(n + k - 1, k);
    }
}
int main() {
    input_checker in;
    int tt = in.readInt(1, 1000);
    in.readEoln();
    int sn = 0, sm = 0;
    while (tt--) {
        int n = in.readInt(1, 1000);
        in.readSpace();
        int m = in.readInt(1, 2e6);
        in.readSpace();
        int d = in.readInt(1, 2e6);
        in.readSpace();
        int k = in.readInt(1, 2e6);
        in.readEoln();
        sn += n;
        sm += m;
        mint ans = 0;
        k -= d;
        for (int cnt = 0; cnt <= n - 1; cnt++) {
            if (k - d * cnt < 0 || m - (k - d * cnt) < 0) {
                continue;
            }
            int x = k - d * cnt;
            int y = m - k - 1;
            mint t = 0;
            // put x balls into (n - cnt - 1) boxes, less than d balls for each box
            for (int i = 0; i <= n - cnt - 1; i++) {
                if (d * i > x) {
                    break;
                }
                int z = x - d * i;
                mint s = (i % 2 == 1 ? -1 : +1) * C(n - cnt - 1, i) * f(n - cnt - 1, z);
                t += s;
                // cerr << i << " " << s << endl;
            }
            // put y balls into (cnt + 2) boxes
            t *= C(y + cnt + 1, cnt + 1);
            // choose cnt boxes from n - 1
            t *= C(n - 1, cnt);
            ans += t;
            // cerr << cnt << ": " << t << endl;
        }
        cout << ans << '\n';
    }
    assert(sn <= 5000);
    assert(sm <= 5e6);
    in.readEof();
    return 0;
}
Editorialist's code (Python)
mod = 998244353
# C() is called with n up to about K' + N (K' < 2*10**6, N <= 1000), so leave room above 2*10**6
lim = 2 * 10**6 + 2000
fac = [1] * lim
for i in range(1, lim):
    fac[i] = fac[i-1] * i % mod
invf = [0]*lim
invf[-1] = pow(fac[-1], mod-2, mod)
for i in reversed(range(lim-1)):
    invf[i] = invf[i+1] * (i+1) % mod

def C(n, r):
    if n < r or r < 0: return 0
    return fac[n] * invf[r] % mod * invf[n-r] % mod

def calc(k, n): # sum k non-negative integers to get n
    if k == 0: return 1 if n == 0 else 0
    return C(n+k-1, n)

for _ in range(int(input())):
    n, m, d, k = map(int, input().split())
    ans = 0
    k -= d
    for i in range(n):
        # Fix the number of positions that are >= d
        if i*d > k: break
        ways = C(n-1, i) * calc(i+2, m-1-k) % mod
        sm, sign = 0, 1
        for j in range(n-i):
            if k - i*d - j*d < 0: break
            sm += C(n-1-i, j) * calc(n-1-i, k - i*d - j*d) % mod * sign % mod
            sign *= -1
        ans += ways * sm % mod
    print(ans % mod)