Contest : Division 1
Contest : Division 2
Setter : Ayush Ranjan
Tester : Istvan Nagy
Editorialist : Anand Jaisingh
Medium Hard
Exponential generating functions, NTT, basic dynamic programming
Given a set of N ingredients, each with a cost and a tastiness, every subset of ingredients forms a dish; let V be the total tastiness of a dish. Over all dishes you can make with total cost \le S, find the sum of their V^K.
Most fundamentally, we use the fact that for a sequence of numbers a_1,a_2,...,a_z, the number (a_1+a_2+...+a_z)^k equals k! multiplied by the coefficient of x^k in the power series expansion of e^{a_1 \cdot x + a_2 \cdot x + ... + a_z \cdot x}, which also equals e^{a_1 \cdot x} \cdot e^{a_2 \cdot x} \cdot ... \cdot e^{a_z \cdot x}. We combine this with a dp[prefix][sum] knapsack DP in which each entry of the table is a power series rather than a number.
This problem is not too hard. First, let's consider a simpler version of the given problem:
Suppose K = 1. Then the problem becomes a modified version of the knapsack problem: instead of maximizing the total value, for each cost j we compute the sum of V over all subsets whose total cost is exactly j.
So, the dynamic programming becomes :
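dp[0][0]=0, \hspace{0.2cm} cnt[0][0]=1, and all other entries of row 0 are 0
dp[i][j]=dp[i-1][j]+dp[i-1][j-C_i]+cnt[i-1][j-C_i] \cdot V_i , \hspace{0.2cm} i \ge 1
cnt[i][j] = cnt[i-1][j] + cnt[i-1][j-C_i]
(terms with j-C_i < 0 are taken as 0). Here, dp[i][j] is the sum of V over all subsets of the first i ingredients whose total cost is exactly j, and cnt[i][j] is the number of such subsets. Adding ingredient i to each of the cnt[i-1][j-C_i] subsets of cost j-C_i keeps their old tastiness sum dp[i-1][j-C_i] and adds V_i once per subset, which is why both terms appear.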
Now, let's go through some formulae before proceeding further:
(a_1+a_2+...+a_z)^{k} = \sum_{x_1+x_2+...+x_z=k} \binom{k}{ x_1,x_2...x_z } \cdot a_1^{x_1} \cdot a_2^{x_2} \cdot .... \cdot a_z^{x_z}
This is the multinomial theorem. This can be rewritten as :
(a_1+a_2+...+a_z)^{k} = \sum_{x_1+x_2+...+x_z=k} \frac{ k ! }{x_1 ! \cdot x_2 ! \cdot ... \cdot x_z !} \cdot a_1^{x_1} \cdot a_2^{x_2} \cdot .... \cdot a_z^{x_z}
(a_1+a_2+...+a_z)^{k} =k! \cdot ( \sum_{x_1+x_2+...+x_z=k} \frac{a_1^{x_1}}{x_1!} \cdot \frac{a_2^{x_2}}{x_2!} \cdot ...\cdot \frac{a_z^{x_z}}{x_z!})
Another one :
In the ring of formal power series :
e^{ax} = \sum_{n \ge 0} \frac{a^n \cdot x^n}{n!}
So, we can easily see that if we multiply e^{a_1 \cdot x} \cdot e^{a_2 \cdot x} \cdot .... \cdot e^{a_z \cdot x } , then the coefficient of x^k is \frac{(a_1+a_2+...+a_z)^k}{k!} since it equals the coefficient of x^k in the expansion of e^{(a_1+a_2+....+a_z) \cdot x} .
This also obviously equals the multinomial expansion we saw above.
So if, at each step of the above dynamic programming, we maintain the first k+1 coefficients of the sum of the power series e^{(a_1+a_2+...+a_z) \cdot x} over the subsets, instead of the sum of their V values, then we've got exactly what we want!
Now, let's treat dp[i][j] as a power series rather than a number. Then:
dp[i][j] = dp[i-1][j] + dp[i-1][j-C_i] \cdot e^{V_i \cdot x}
It is easy to see that the coefficient of x^k in dp[i][j] is then the sum of the k^{th} powers of the subset tastiness values, divided by k!. Multiplying by k! recovers the required sum.
To make this concrete, here is a simulation on a modified version of the sample test. The first line is N S K; each following line gives the cost and tastiness of one ingredient:
3 3 2
1 2
2 3
1 4
Initially, dp[0][0]=1 (the empty subset). Since dp[i][0]=1 for every i, those entries are omitted below.
Processing the first ingredient (cost 1, tastiness 2):
dp[1][1]=e^{ 2\cdot x}
Processing the second ingredient (cost 2, tastiness 3):
dp[2][1]=e^{2 \cdot x}
dp[2][2]=e^{3 \cdot x}
dp[2][3] =e^{2 \cdot x} \cdot e^{3 \cdot x} = e^{5 \cdot x}
Processing the third ingredient (cost 1, tastiness 4):
dp[3][1] = e^{2 \cdot x} + e^{4 \cdot x}
dp[3][2] = e^{2 \cdot x} \cdot e^{4 \cdot x} + e^{3 \cdot x}= e^{6 \cdot x} +e^{3 \cdot x}
dp[3][3] =e^{2 \cdot x} \cdot e^{3 \cdot x} + e^{3 \cdot x} \cdot e^{4 \cdot x} = e^{5 \cdot x} + e^{7 \cdot x}
dp[3][4]= e^{2 \cdot x} \cdot e^{3 \cdot x} \cdot e^{4 \cdot x} = e^{9 \cdot x} (cost 4 exceeds S = 3, so this entry does not contribute to the answer)
Note that, for our purposes, we only need to maintain the first K+1 coefficients of each power series, not the entire series.
To multiply these truncated power series, we can use NTT (the number-theoretic transform).
In case you're interested, I set a similar problem back in February, here.
That's it! Thank you!
Your comments are welcome !
Time Complexity : O( N \cdot S \cdot K \cdot \log K )
Space Complexity: O(N \cdot S \cdot K )
using namespace std;
#define ll long long
// mod is the NTT-friendly prime 998244353. N and M size the setter's global
// tables (presumably N > S and M > K -- confirm against the constraints).
const int mod=998244353,N=101,M=2001;
// Modular helpers; operands are assumed to already lie in [0, mod).
inline int mul(int a,int b){return (a*1ll*b)%mod;}
inline int add(int a,int b){a+=b;if(a>=mod)a-=mod;return a;}
inline int sub(int a,int b){a-=b;if(a<0)a+=mod;return a;}
// Binary exponentiation: a^b mod `mod` for b >= 0.
inline int power(int a,int b){int rt=1;while(b>0){if(b&1)rt=mul(rt,a);a=mul(a,a);b>>=1;}return rt;}
// Modular inverse via Fermat's little theorem (mod is prime); inv(0) yields 0.
inline int inv(int a){return power(a,mod-2);}
// In-place a += b (mod). b is taken by value so temporaries such as
// modadd(x, mul(y, z)) are accepted (a non-const reference rejected them).
inline void modadd(int &a,int b){a+=b;if(a>=mod)a-=mod;}
// Lazily-grown NTT tables (extended by ensure_base()).
// base: log2 of the largest transform length the tables currently cover.
int base{1};
// Twiddle table: fft() multiplies by roots[j + k] in its length-2k butterflies.
vector<int> roots{0, 1};
// Bit-reversal table used to permute inputs in fft().
vector<int> rev{0, 1};
// Per the original note: 2^max_base must divide mod-1 and exceed the longest
// product length handled.
const int max_base{14};
// Per the original note: primitive_root^((mod-1)/2^max_base), i.e. intended
// to be a primitive 2^max_base-th root of unity modulo mod.
const int root{666702199};
// Grows the bit-reversal table `rev` and twiddle table `roots` so that
// transforms of length up to 2^nbase can be run by fft(). Does nothing when
// the tables already cover that length.
void ensure_base(int nbase) {
    if (nbase <= base) {
        return;  // tables are already large enough
    }
    assert(nbase <= max_base);  // 2^nbase must divide mod - 1
    rev.resize(1 << nbase);
    for (int i = 0; i < (1 << nbase); i++) {
        // Bit-reverse i over nbase bits, reusing the already computed rev[i >> 1].
        rev[i] = (rev[i >> 1] >> 1) + ((i & 1) << (nbase - 1));
    }
    roots.resize(1 << nbase);
    while (base < nbase) {
        // z = root^(2^(max_base-1-base)): a primitive 2^(base+1)-th root of unity.
        int z = power(root, 1 << (max_base - 1 - base));
        for (int i = 1 << (base - 1); i < (1 << base); i++) {
            roots[i << 1] = roots[i];
            roots[(i << 1) + 1] = mul(roots[i], z);
        }
        base++;  // without this the loop never terminates
    }
}
void fft(vector<int> &a) {
int n = (int) a.size();
assert((n & (n - 1)) == 0);
int zeros = __builtin_ctz(n);
int shift = base - zeros;
for (int i = 0; i < n; i++) {
if (i < (rev[i] >> shift)) {
swap(a[i], a[rev[i] >> shift]);
for (int k = 1; k < n; k <<= 1) {
for (int i = 0; i < n; i += 2 * k) {
for (int j = 0; j < k; j++) {
int x = a[i + j];
int y = mul(a[i + j + k], roots[j + k]);
a[i + j] = x + y - mod;
if (a[i + j] < 0) a[i + j] += mod;
a[i + j + k] = x - y + mod;
if (a[i + j + k] >= mod) a[i + j + k] -= mod;
vector<int> multiply(vector<int> a, vector<int> b, int eq = 0) {
int need = (int) (a.size() + b.size() - 1);
int nbase = 0;
while ((1 << nbase) < need) nbase++;
int sz = 1 << nbase;
if (eq) b = a; else fft(b);
int inv_sz = inv(sz);
for (int i = 0; i < sz; i++) {
a[i] = mul(mul(a[i], b[i]), inv_sz);
reverse(a.begin() + 1, a.end());
return a;
// Returns the polynomial a * a. The third argument tells multiply() that both
// operands are identical, so it copies a's transform instead of redoing it.
vector<int> square(vector<int> a) {
    const int operands_equal = 1;
    return multiply(a, a, operands_equal);
}
// Per-ingredient input arrays (presumably cost and tastiness; filled in main).
vector<int> cost(N);
vector<int> val(N);
// Presumably factorials and inverse factorials mod `mod` (filled in main).
vector<int> fac(M);
vector<int> invfac(M);
// N x M DP table of truncated power-series coefficients, plus a scratch copy.
vector<vector<int>> dp(N, vector<int>(M));
vector<vector<int>> res;
int main(){
for(int i=1;i<M;i++)
int n,s,k;
for(int i=0;i<n;i++)
for(int i=0;i<n;i++){
vector<int> P2(k+1);
int cpow=1;
for(int l=0;l<=k;l++){
for(int j=0;j<=s;j++){
vector<int> P1(k+1);
for(int l=0;l<=k;l++)
vector<int> P=multiply(P1,P2);
for(int l=0;l<=k;l++)
int ans=0;
for(int i=0;i<=s;i++)
#include <bits/stdc++.h>
// Whole-container iterator ranges (forward and reverse).
#define all(c) (c).begin(), (c).end()
#define rall(c) (c).rbegin(), (c).rend()
// Counting loops: [0, cnt), [1, cnt], (cnt-1 down to 0), and [lo, hi].
#define forn(var, cnt) for (int var = 0; var < (int)(cnt); ++var)
#define for1(var, cnt) for (int var = 1; var <= (int)(cnt); ++var)
#define ford(var, cnt) for (int var = (int)(cnt) - 1; var >= 0; --var)
#define fore(var, lo, hi) for (int var = (int)(lo); var <= (int)(hi); ++var)
// Lowers `a` to `b` when `a > b`; reports whether `a` was changed.
template<class T> bool umin(T &a, T b) {
    if (a > b) {
        a = b;
        return true;
    }
    return false;
}
// Raises `a` to `b` when `a < b`; reports whether `a` was changed.
template<class T> bool umax(T &a, T b) {
    if (a < b) {
        a = b;
        return true;
    }
    return false;
}
using namespace std;
// Reads a decimal integer from stdin and validates its format strictly:
//  - an optional single leading '-' is allowed;
//  - no leading zeros and no "-0";
//  - at most 19 digits (a 19-digit value must start with 1);
//  - the value must lie in [l, r];
//  - it must be followed immediately by `endd`, which is consumed.
// Any other character, including EOF, fails an assert.
long long readInt(long long l, long long r, char endd) {
    long long x = 0;
    int cnt = 0;          // digits read so far
    int fi = -1;          // first digit, -1 until one is seen
    bool is_neg = false;
    while (true) {
        int g = getchar();  // int, so EOF stays distinct from every byte
        if (g == '-') {
            assert(fi == -1 && !is_neg);  // one '-', before any digit
            is_neg = true;
            continue;
        }
        if ('0' <= g && g <= '9') {
            if (cnt == 0) {
                fi = g - '0';
            }
            cnt++;
            assert(fi != 0 || cnt == 1);         // no leading zeros
            assert(fi != 0 || is_neg == false);  // no "-0"
            // Checked before accumulating so x (< 2e18) can never overflow.
            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
            x = x * 10 + (g - '0');
        }
        else if (g == static_cast<unsigned char>(endd)) {
            assert(cnt > 0);
            if (is_neg) {
                x = -x;
            }
            assert(l <= x && x <= r);
            return x;
        }
        else {
            assert(false);  // unexpected character or EOF
        }
    }
}
// Reads characters from stdin up to, but not including, the terminator
// `endd`, which is consumed. Asserts that EOF is not reached first and that
// the token length lies in [l, r].
string readString(int l, int r, char endd) {
    string ret = "";
    int cnt = 0;
    while (true) {
        int g = getchar();  // int, so EOF is detectable even where char is unsigned
        assert(g != EOF);
        if (g == static_cast<unsigned char>(endd)) {
            break;  // the terminator is not part of the token
        }
        cnt++;
        ret += static_cast<char>(g);
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}
long long readIntSp(long long l, long long r) {
return readInt(l, r, ' ');
long long readIntLn(long long l, long long r) {
return readInt(l, r, '\n');
string readStringLn(int l, int r) {
return readString(l, r, '\n');
string readStringSp(int l, int r) {
return readString(l, r, ' ');
const uint32_t mod = 998244353;// (2 ^ k)*c + 1 = 2^23 * 7 * 17 + 1

// Computes a^pw mod `mod` by binary exponentiation.
uint32_t powMod(uint32_t a, uint32_t pw)
{
    uint32_t res = 1;
    while (pw)
    {
        if (pw & 1)
            res = (res * static_cast<uint64_t>(a)) % mod;
        pw >>= 1;
        a = (static_cast<uint64_t>(a) * a) % mod;
    }
    return res;
}

// Modular inverse via Fermat's little theorem (mod is prime); inverse(0) == 0.
uint32_t inverse(uint32_t a)
{
    return powMod(a, mod - 2);
}

// Returns true iff `a` generates the multiplicative group modulo the prime
// `mod`. That holds iff a^((mod-1)/q) != 1 for every prime q dividing
// mod - 1 = 2^23 * 7 * 17. The previous version tested only q = 2, which
// accepts every quadratic non-residue (e.g. 3^7) and also 0.
bool isPrimitiveRoot(uint32_t a)
{
    if (a % mod == 0)
        return false;
    const uint32_t primeFactors[] = {2, 7, 17};
    for (uint32_t q : primeFactors)
        if (powMod(a, (mod - 1) / q) == 1)
            return false;
    return true;
}
struct NTT
const uint32_t k = 23;
const uint32_t c = 7 * 17;
uint32_t primitiveRoot;
uint32_t prc;// (primitiveRoot^c)%mod
uint64_t inv2;
//find primitive root
primitiveRoot = 2;
while (!isPrimitiveRoot(primitiveRoot))
//set prc
prc = powMod(primitiveRoot, c);
inv2 = inverse(2);
vector<uint32_t> transform(const vector<uint32_t>& a, bool inv)
size_t len = a.size();
if (len == 1) return a;
vector<uint32_t> f(len / 2), g(len / 2);
for (uint32_t i = 0; i < len; i += 2)
f[i / 2] = a[i];
g[i / 2] = a[i + 1];
vector<uint32_t> F = transform(f, inv), G = transform(g, inv);
vector<uint32_t> ret(len);
uint32_t pw = static_cast<uint32_t>((1ULL << k) / len);
uint32_t w = powMod(prc, pw), wk = 1;
if (inv) w = inverse(w);
for (size_t i = 0; i < len / 2; ++i)
uint32_t u = F[i], v = (G[i] * static_cast<uint64_t>(wk)) % mod;
ret[i] = (u + v) % mod;
ret[i + len / 2] = (mod + u - v) % mod;
if (inv)
ret[i] = (ret[i] * inv2) % mod;
ret[i + len / 2] = (ret[i + len / 2] * inv2) % mod;
wk = (static_cast<uint64_t>(wk) * w) % mod;
return ret;
//without recursion
void transform2(vector<uint32_t>& a, bool inv)
size_t n = a.size();
for (size_t i = 1, j = 0; i < n; ++i)
size_t bit = n >> 1;
while (j >= bit)
j -= bit;
bit >>= 1;
j += bit;
if (i < j)
swap(a[i], a[j]);
for (size_t len = 2; len <= n; len <<= 1)
uint32_t pw = (1ULL << k) / len;
uint32_t wlen = powMod(prc, pw);
if (inv) wlen = inverse(wlen);
for (size_t i = 0; i < n; i += len)
uint32_t w = 1;
for (size_t j = 0; j < len / 2; ++j)
uint32_t u = a[i + j], v = (a[i + j + len / 2] * static_cast<uint64_t>(w)) % mod;
a[i + j] = u + v < mod ? u + v : u + v - mod;
a[i + j + len / 2] = u >= v ? u - v : u - v + mod;
w = (static_cast<uint64_t>(w) * wlen) % mod;
if (inv)
uint32_t nrev = inverse(n);
for (int i = 0; i < n; ++i)
a[i] = (a[i] * static_cast<uint64_t>(nrev)) % mod;
int main(int argc, char** argv)
#ifdef HOME
//if (IsDebuggerPresent())
freopen("../build/in.txt", "rb", stdin);
freopen("../build/out.txt", "wb", stdout);
const int MAXB = 2002;
int N, S, K, KK = 1;
scanf("%d %d %d", &N, &S, &K);
while (KK <= 2 * K)
KK <<= 1;
//KK <<= 1;
vector<uint32_t> invf(KK, 1), fact(KK, 1);
for (uint32_t i = 1; i < KK; ++i)
fact[i] = (static_cast<uint64_t>(i) * fact[i - 1]) % mod;
invf[i] = inverse(fact[i]);
vector<int> C(N);
vector<vector<uint32_t> > V(N, vector<uint32_t>(KK));
for (int i = 0; i < N; ++i)
V[i][0] = 1;
scanf("%d %d", &C[i], &V[i][1]);
for (int j = 2; j < K + 1; ++j)
V[i][j] = (static_cast<uint64_t>(V[i][j - 1]) * V[i][1]) % mod;
vector<vector<uint32_t> > vR(S + 1, vector<uint32_t>(KK));
NTT ntt;
vR[0][0] = 1;
for (int i = 0; i < N; ++i)
vector<uint32_t> vI = V[i];
for (uint32_t o = 0; o <= K; ++o)
vI[o] = (static_cast<uint64_t>(vI[o]) * invf[o]) % mod;
ntt.transform2(vI, false);
for (int j = S - C[i]; j >= 0; --j)
int actj = j + C[i];
//convolution vR[j] , V[i]
vector<uint32_t> vRJ = vR[j];
for (uint32_t o = 0; o <= K; ++o)
vRJ[o] = (static_cast<uint64_t>(vRJ[o]) * invf[o]) % mod;
ntt.transform2(vRJ, false);
vector<uint32_t> fm(KK);
for (uint32_t o = 0; o < KK; ++o)
fm[o] = (static_cast<uint64_t>(vRJ[o]) * vI[o]) % mod;
ntt.transform2(fm, true);
for (uint32_t o = 0; o <= K; ++o)
vR[actj][o] = (vR[actj][o] + fm[o] * static_cast<uint64_t>(fact[o])) % mod;
int64_t res = 0;
for (int i = 0; i <= S; ++i)
res += vR[i][K];
printf("%lld\n", res%mod);
return 0;