PSUM - Editorial

PROBLEM LINKS :

Contest : Division 1

Contest : Division 2

Practice

Setter : Ayush Ranjan

Tester : Istvan Nagy

Editorialist : Anand Jaisingh

DIFFICULTY :

Medium Hard

PREREQUISITES :

Exponential Generating Functions, NTT , Basic Dynamic Programming

PROBLEM :

Given a set of N ingredients, each with a cost and a tastiness, a dish is any subset of the ingredients; let V be its total tastiness. Over all dishes with total cost \le S , you need to find the sum of V^K .

QUICK EXPLANATION :

We rely on the fact that, for a given sequence of numbers a_1,a_2,...,a_z, the number (a_1+a_2+...+a_z)^k equals k! multiplied by the coefficient of x^k in the power series expansion of the function e^{a_1 \cdot x + a_2 \cdot x +....+a_z \cdot x} , which also equals the function e^{a_1 \cdot x} \cdot e^{a_2 \cdot x} \cdot ... \cdot e^{a_z \cdot x} . We combine this with a dp[prefix][sum] type knapsack DP, where each entry of the table is a power series and not a number.

EXPLANATION :

Hello,

This problem is not too hard. First, let's consider a simpler version of the given problem :

Suppose K can only be equal to 1. In that case, the problem becomes a modified version of the Knapsack Problem. Instead of maximizing the total value, we find, for each cost j, the sum of the total tastiness over all subsets having cost exactly j.

So, the dynamic programming becomes :

dp[0][0]=0 ,cnt[0][0]=1

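dp[i][j]=dp[i-1][j]+dp[i-1][j-C_i]+cnt[i-1][j-C_i] \cdot V_i , \hspace{0.2cm} i \ge 1

cnt[i][j] = cnt[i-1][j] + cnt[i-1][j-C_i]

Here, dp[i][j] is the sum of the total tastiness over all subsets of the first i ingredients having total cost exactly j, and cnt[i][j] is the number of such subsets. Terms with j < C_i are taken to be 0. Adding ingredient i to a subset increases its tastiness by V_i, so the subsets counted by cnt[i-1][j-C_i] contribute their old sum dp[i-1][j-C_i] plus V_i once per subset.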
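The K = 1 recurrence can be sketched in C++ as follows. This is only an illustration under stated assumptions: the function name `sumTastinessK1` and its parameters are hypothetical, the first index is dropped by iterating j downwards, and the modulus 998244353 is borrowed from the solutions further down.

```cpp
#include <cstdint>
#include <vector>

// Sum of total tastiness over all subsets with total cost <= S (the K = 1 case).
// costs[i] and values[i] describe ingredient i; all names are illustrative.
std::int64_t sumTastinessK1(const std::vector<int>& costs,
                            const std::vector<std::int64_t>& values, int S) {
    const std::int64_t MOD = 998244353;  // assumed, taken from the posted solutions
    // cnt[j]: number of subsets with cost exactly j.
    // dp[j]:  sum of the total tastiness of those subsets.
    std::vector<std::int64_t> cnt(S + 1, 0), dp(S + 1, 0);
    cnt[0] = 1;  // the empty subset
    for (std::size_t i = 0; i < costs.size(); ++i) {
        // Iterate j downwards so each ingredient is used at most once.
        for (int j = S; j >= costs[i]; --j) {
            int p = j - costs[i];
            // Subsets that take ingredient i: old sum plus V_i once per subset.
            dp[j] = (dp[j] + dp[p] + cnt[p] * (values[i] % MOD)) % MOD;
            cnt[j] = (cnt[j] + cnt[p]) % MOD;
        }
    }
    std::int64_t total = 0;
    for (int j = 0; j <= S; ++j) total = (total + dp[j]) % MOD;
    return total;
}
```

For ingredients with (cost, tastiness) = (1,2), (2,3), (1,4) and S = 3, the six non-empty subsets within budget have tastiness 2, 4, 3, 6, 5 and 7, so the function returns 27.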

Now, let's go through some formulae before proceeding further :

(a_1+a_2+...+a_z)^{k} = \sum_{x_1+x_2+...+x_z=k} \binom{k}{ x_1,x_2...x_z } \cdot a_1^{x_1} \cdot a_2^{x_2} \cdot .... \cdot a_z^{x_z}

This is the multinomial theorem. This can be rewritten as :

(a_1+a_2+...+a_z)^{k} = \sum_{x_1+x_2+...+x_z=k} \frac{ k ! }{x_1 ! \cdot x_2 ! \cdot ... \cdot x_z !} \cdot a_1^{x_1} \cdot a_2^{x_2} \cdot .... \cdot a_z^{x_z}

(a_1+a_2+...+a_z)^{k} =k! \cdot ( \sum_{x_1+x_2+...+x_z=k} \frac{a_1^{x_1}}{x_1!} \cdot \frac{a_2^{x_2}}{x_2!} \cdot ...\cdot \frac{a_z^{x_z}}{x_z!})

Another one :

In the ring of formal power series :

e^{ax} = \sum_{n \ge 0} \frac{a^n \cdot x^n}{n!}

So, we can easily see that if we multiply e^{a_1 \cdot x} \cdot e^{a_2 \cdot x} \cdot .... \cdot e^{a_z \cdot x } , then the coefficient of x^k is \frac{(a_1+a_2+...+a_z)^k}{k!} since it equals the coefficient of x^k in the expansion of e^{(a_1+a_2+....+a_z) \cdot x} .

This also obviously equals the multinomial expansion we saw above.
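As a small sanity check of this identity, here is the case z = 2, k = 2 written out. It is only a worked instance of the formulae above, not an extra result:

```latex
[x^2] \left( e^{a_1 \cdot x} \cdot e^{a_2 \cdot x} \right)
  = \frac{a_1^2}{2!} \cdot \frac{a_2^0}{0!}
  + \frac{a_1^1}{1!} \cdot \frac{a_2^1}{1!}
  + \frac{a_1^0}{0!} \cdot \frac{a_2^2}{2!}
  = \frac{a_1^2 + 2 \cdot a_1 \cdot a_2 + a_2^2}{2!}
  = \frac{(a_1+a_2)^2}{2!}
```

The three terms correspond to the three ways of writing x_1 + x_2 = 2 in the multinomial sum.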

So, at each step of the above dynamic programming, if instead of maintaining the sum of V_i of the subsets we maintain the sum, over those subsets, of the first k+1 coefficients of the power series e^{(a_1+a_2+...+a_z) \cdot x} , where a_1,a_2,...,a_z are the tastiness values in the subset, then we've got exactly what we wanted !

Now, let's assume dp[i][j] is a power series and not a number. Then ,

dp[0][0]=1

dp[i][j] = dp[i-1][j] + dp[i-1][j-C_i] \cdot e^{V_i \cdot x}

The coefficient of x^k in dp[i][j] is then the sum of the k^{th} powers of the total tastiness of the corresponding subsets, but with an extra dividing factor of k! . Multiplying by k! at the end gives the answer. It's not difficult really.

To make this concrete, here is a simulation on a modified version of the sample test :
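A minimal C++ sketch of this power series recurrence, under stated assumptions: it uses a naive O(K^2) truncated multiplication in place of NTT, so it illustrates the idea rather than the intended complexity. The names `kMod`, `powModSketch` and `sumOfKthPowers` are hypothetical, and the modulus is the one used by the posted solutions.

```cpp
#include <cstdint>
#include <vector>

const std::int64_t kMod = 998244353;  // assumed, taken from the posted solutions

// Fast modular exponentiation, used only to invert K!.
std::int64_t powModSketch(std::int64_t b, std::int64_t e) {
    std::int64_t r = 1;
    b %= kMod;
    while (e > 0) {
        if (e & 1) r = r * b % kMod;
        b = b * b % kMod;
        e >>= 1;
    }
    return r;
}

// Sum of V^K over all subsets with total cost <= S, where V is the subset's
// total tastiness. dp[j] holds the first K+1 coefficients of the sum of
// e^{V x} over subsets with cost exactly j (first index dropped, j goes down).
std::int64_t sumOfKthPowers(const std::vector<int>& costs,
                            const std::vector<std::int64_t>& values,
                            int S, int K) {
    std::vector<std::int64_t> fact(K + 1, 1), invFact(K + 1, 1);
    for (int n = 1; n <= K; ++n) fact[n] = fact[n - 1] * n % kMod;
    invFact[K] = powModSketch(fact[K], kMod - 2);
    for (int n = K; n >= 1; --n) invFact[n - 1] = invFact[n] * n % kMod;

    std::vector<std::vector<std::int64_t>> dp(
        S + 1, std::vector<std::int64_t>(K + 1, 0));
    dp[0][0] = 1;  // empty subset: e^{0 x} = 1

    for (std::size_t i = 0; i < costs.size(); ++i) {
        // ex[n] = V_i^n / n!, the truncated series of e^{V_i x}.
        std::vector<std::int64_t> ex(K + 1);
        std::int64_t pw = 1, v = values[i] % kMod;
        for (int n = 0; n <= K; ++n) {
            ex[n] = pw * invFact[n] % kMod;
            pw = pw * v % kMod;
        }
        for (int j = S; j >= costs[i]; --j) {
            // Copy the source row so the update is safe even if C_i == 0.
            const std::vector<std::int64_t> src = dp[j - costs[i]];
            for (int a = 0; a <= K; ++a) {
                if (src[a] == 0) continue;
                for (int b = 0; a + b <= K; ++b)
                    dp[j][a + b] = (dp[j][a + b] + src[a] * ex[b]) % kMod;
            }
        }
    }
    std::int64_t total = 0;
    for (int j = 0; j <= S; ++j) total = (total + dp[j][K]) % kMod;
    return total * fact[K] % kMod;  // undo the 1/K! of the exponential series
}
```

With K = 1 this reduces to the plain knapsack sum; replacing the two inner loops by an NTT-based product is what brings each multiplication down to O(K \log K).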

Test

3 3 2
1 2
2 3
1 4

Now, initially dp[0][0]=1 . The empty subset also keeps dp[i][0]=1 for every i ; it is omitted below.

We process the first ingredient :

dp[1][1]=e^{ 2\cdot x}

We process the second ingredient :

dp[2][1]=e^{2 \cdot x}

dp[2][2]=e^{3 \cdot x}

dp[2][3] =e^{2 \cdot x} \cdot e^{3 \cdot x} = e^{5 \cdot x}

We process the 3^{rd} ingredient :

dp[3][1] = e^{2 \cdot x} + e^{4 \cdot x}

dp[3][2] = e^{2 \cdot x} \cdot e^{4 \cdot x} + e^{3 \cdot x}= e^{6 \cdot x} +e^{3 \cdot x}

dp[3][3] =e^{2 \cdot x} \cdot e^{3 \cdot x} + e^{3 \cdot x} \cdot e^{4 \cdot x} = e^{5 \cdot x} + e^{7 \cdot x}

dp[3][4]= e^{2 \cdot x} \cdot e^{3 \cdot x} \cdot e^{4 \cdot x} = e^{9 \cdot x}
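To finish this example, here is how the answer would be read off the table. This assumes the first line of the test is N = 3, S = 3, K = 2, which is the order in which the setter's code reads its input; dp[3][4] is then left out because its cost exceeds S.

```latex
\sum_{j=1}^{3} dp[3][j]
  = e^{2 \cdot x} + e^{4 \cdot x} + e^{3 \cdot x} + e^{6 \cdot x} + e^{5 \cdot x} + e^{7 \cdot x}

2! \cdot [x^2] \sum_{j=1}^{3} dp[3][j]
  = 2^2 + 4^2 + 3^2 + 6^2 + 5^2 + 7^2
  = 139
```

The entry dp[3][0] = 1 contributes nothing here, since its coefficient of x^2 is 0.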

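Note that for our purposes, we only need to maintain the first K+1 coefficients and not the entire power series.

To multiply these truncated series, we can use NTT. The transform length must be at least 2K+1, so that the cyclic convolution does not wrap around into the first K+1 coefficients.

In case you're interested, I set a similar problem back in February, here.

That's it ! Thank you !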

Your comments are welcome !

COMPLEXITY ANALYSIS :

Time Complexity : O( N \cdot S \cdot K \cdot \log K )

Space Complexity: O(N \cdot S \cdot K )

SOLUTION LINKS :

Setter
#include<bits/stdc++.h>
using namespace std;
#define ll long long
 
const int mod=998244353,N=101,M=2001;
inline int mul(int a,int b){return (a*1ll*b)%mod;}
inline int add(int a,int b){a+=b;if(a>=mod)a-=mod;return a;}
inline int sub(int a,int b){a-=b;if(a<0)a+=mod;return a;}
inline int power(int a,int b){int rt=1;while(b>0){if(b&1)rt=mul(rt,a);a=mul(a,a);b>>=1;}return rt;}
inline int inv(int a){return power(a,mod-2);}
inline void modadd(int &a,int &b){a+=b;if(a>=mod)a-=mod;}
 
int base = 1;
vector<int> roots = {0, 1};
vector<int> rev = {0, 1};
const int max_base=14;  //x such that 2^x|(mod-1) and 2^x>max answer size(=2*n)
const int root=666702199;       //primitive root^((mod-1)/(2^max_base))
void ensure_base(int nbase) {
    if (nbase <= base) {
      return;
    }
    assert(nbase <= max_base);
    rev.resize(1 << nbase);
    for (int i = 0; i < (1 << nbase); i++) {
      rev[i] = (rev[i >> 1] >> 1) + ((i & 1) << (nbase - 1));
    }
    roots.resize(1 << nbase);
    while (base < nbase) {
      int z = power(root, 1 << (max_base - 1 - base));
      for (int i = 1 << (base - 1); i < (1 << base); i++) {
        roots[i << 1] = roots[i];
        roots[(i << 1) + 1] = mul(roots[i], z);
      }
      base++;
    }
}
void fft(vector<int> &a) {
    int n = (int) a.size();
    assert((n & (n - 1)) == 0);
    int zeros = __builtin_ctz(n);
    ensure_base(zeros);
    int shift = base - zeros;
    for (int i = 0; i < n; i++) {
      if (i < (rev[i] >> shift)) {
        swap(a[i], a[rev[i] >> shift]);
      }
    }
    for (int k = 1; k < n; k <<= 1) {
      for (int i = 0; i < n; i += 2 * k) {
        for (int j = 0; j < k; j++) {
          int x = a[i + j];
          int y = mul(a[i + j + k], roots[j + k]);
          a[i + j] = x + y - mod;
          if (a[i + j] < 0) a[i + j] += mod;
          a[i + j + k] = x - y + mod;
          if (a[i + j + k] >= mod) a[i + j + k] -= mod;
        }
      }
    }
}
vector<int> multiply(vector<int> a, vector<int> b, int eq = 0) {
    int need = (int) (a.size() + b.size() - 1);
    int nbase = 0;
    while ((1 << nbase) < need) nbase++;
    ensure_base(nbase);
    int sz = 1 << nbase;
    a.resize(sz);
    b.resize(sz);
    fft(a);
    if (eq) b = a; else fft(b);
    int inv_sz = inv(sz);
    for (int i = 0; i < sz; i++) {
      a[i] = mul(mul(a[i], b[i]), inv_sz);
    }
    reverse(a.begin() + 1, a.end());
    fft(a);
    a.resize(need);
    return a;
}
vector<int> square(vector<int> a) {
    return multiply(a, a, 1);
}
vector<int> cost(N),val(N),fac(M),invfac(M);
vector<vector<int>> dp(N,vector<int>(M)),res;
int main(){
    fac[0]=invfac[0]=1;
    for(int i=1;i<M;i++)
        fac[i]=mul(fac[i-1],i),invfac[i]=inv(fac[i]);
    int n,s,k;
    cin>>n>>s>>k;
    for(int i=0;i<n;i++)
        cin>>cost[i]>>val[i],val[i]%=mod,assert(val[i]>0);
    dp[0][0]=1;
    res=dp;
    for(int i=0;i<n;i++){
        vector<int> P2(k+1);
        int cpow=1;
        for(int l=0;l<=k;l++){
            P2[l]=mul(cpow,invfac[l]);
            cpow=mul(cpow,val[i]);
        }
        for(int j=0;j<=s;j++){
            if(!dp[j][0]||j+cost[i]>s)continue;
            vector<int> P1(k+1);
            for(int l=0;l<=k;l++)
                P1[l]=mul(dp[j][l],invfac[l]);
            vector<int> P=multiply(P1,P2);
            for(int l=0;l<=k;l++)
                res[j+cost[i]][l]=add(res[j+cost[i]][l],mul(fac[l],P[l]));
        }
        dp=res;
    }
    int ans=0;
    for(int i=0;i<=s;i++)
        ans=add(ans,dp[i][k]);
    cout<<ans<<endl;
}
Tester
#include <bits/stdc++.h>

#define all(x) (x).begin(), (x).end()
#define rall(x) (x).rbegin(), (x).rend()
#define forn(i, n) for (int i = 0; i < (int)(n); ++i)
#define for1(i, n) for (int i = 1; i <= (int)(n); ++i)
#define ford(i, n) for (int i = (int)(n) - 1; i >= 0; --i)
#define fore(i, a, b) for (int i = (int)(a); i <= (int)(b); ++i)

template<class T> bool umin(T &a, T b) { return a > b ? (a = b, true) : false; }
template<class T> bool umax(T &a, T b) { return a < b ? (a = b, true) : false; }

using namespace std;

long long readInt(long long l, long long r, char endd) {
    long long x = 0;
    int cnt = 0;
    int fi = -1;
    bool is_neg = false;
    while (true) {
        char g = getchar();
        if (g == '-') {
            assert(fi == -1);
            is_neg = true;
            continue;
        }
        if ('0' <= g && g <= '9') {
            x *= 10;
            x += g - '0';
            if (cnt == 0) {
                fi = g - '0';
            }
            cnt++;
            assert(fi != 0 || cnt == 1);
            assert(fi != 0 || is_neg == false);

            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
        }
        else if (g == endd) {
            assert(cnt > 0);
            if (is_neg) {
                x = -x;
            }
            assert(l <= x && x <= r);
            return x;
        }
        else {
            //assert(false);
        }
    }
}

string readString(int l, int r, char endd) {
    string ret = "";
    int cnt = 0;
    while (true) {
        char g = getchar();
        assert(g != -1);
        if (g == endd) {
            break;
        }
        cnt++;
        ret += g;
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}
long long readIntSp(long long l, long long r) {
    return readInt(l, r, ' ');
}
long long readIntLn(long long l, long long r) {
    return readInt(l, r, '\n');
}
string readStringLn(int l, int r) {
    return readString(l, r, '\n');
}
string readStringSp(int l, int r) {
    return readString(l, r, ' ');
}

const uint32_t mod = 998244353;// (2 ^ k)*c + 1 = 2^23 * 7 * 17

uint32_t powMod(uint32_t a, uint32_t pw)
{
    uint32_t res = 1;
    while (pw)
    {
        if (pw & 1)
        {
            res = (res * static_cast<uint64_t>(a)) % mod;
        }
        pw >>= 1;
        a = (static_cast<uint64_t>(a) * a) % mod;
    }
    return res;
}

uint32_t inverse(uint32_t a)
{
    return powMod(a, mod - 2);
}

bool isPrimitiveRoot(uint32_t a)
{
    return powMod(a, (mod - 1) / 2) != 1;
}

struct NTT
{
    const uint32_t k = 23;
    const uint32_t c = 7 * 17;
    // 
    uint32_t primitiveRoot;
    uint32_t prc;// (primitiveRoot^c)%mod
    uint64_t inv2;

    NTT()
    {
        //find primitive root
        primitiveRoot = 2;
        while (!isPrimitiveRoot(primitiveRoot))
            ++primitiveRoot;
        //set prc
        prc = powMod(primitiveRoot, c);
        inv2 = inverse(2);
    }

    vector<uint32_t> transform(const vector<uint32_t>& a, bool inv)
    {
        size_t len = a.size();
        if (len == 1) return a;
        vector<uint32_t> f(len / 2), g(len / 2);
        for (uint32_t i = 0; i < len; i += 2)
        {
            f[i / 2] = a[i];
            g[i / 2] = a[i + 1];
        }

        vector<uint32_t> F = transform(f, inv), G = transform(g, inv);
        vector<uint32_t> ret(len);

        uint32_t pw = static_cast<uint32_t>((1ULL << k) / len);

        uint32_t w = powMod(prc, pw), wk = 1;
        if (inv) w = inverse(w);

        for (size_t i = 0; i < len / 2; ++i)
        {
            uint32_t u = F[i], v = (G[i] * static_cast<uint64_t>(wk)) % mod;
            ret[i] = (u + v) % mod;
            ret[i + len / 2] = (mod + u - v) % mod;
            if (inv)
            {
                ret[i] = (ret[i] * inv2) % mod;
                ret[i + len / 2] = (ret[i + len / 2] * inv2) % mod;
            }
            wk = (static_cast<uint64_t>(wk) * w) % mod;
        }

        return ret;
    }

    //without recursion
    void transform2(vector<uint32_t>& a, bool inv)
    {
        size_t n = a.size();

        for (size_t i = 1, j = 0; i < n; ++i)
        {
            size_t bit = n >> 1;
            while (j >= bit)
            {
                j -= bit;
                bit >>= 1;
            }
            j += bit;
            if (i < j)
                swap(a[i], a[j]);
        }

        for (size_t len = 2; len <= n; len <<= 1)
        {
            uint32_t pw = (1ULL << k) / len;
            uint32_t wlen = powMod(prc, pw);
            if (inv) wlen = inverse(wlen);
            for (size_t i = 0; i < n; i += len)
            {
                uint32_t w = 1;
                for (size_t j = 0; j < len / 2; ++j)
                {
                    uint32_t u = a[i + j], v = (a[i + j + len / 2] * static_cast<uint64_t>(w)) % mod;
                    a[i + j] = u + v < mod ? u + v : u + v - mod;
                    a[i + j + len / 2] = u >= v ? u - v : u - v + mod;
                    w = (static_cast<uint64_t>(w) * wlen) % mod;
                }
            }
        }
        if (inv)
        {
            uint32_t nrev = inverse(n);
            for (int i = 0; i < n; ++i)
                a[i] = (a[i] * static_cast<uint64_t>(nrev)) % mod;
        }
    }
};

int main(int argc, char** argv)
{
#ifdef HOME
    //if (IsDebuggerPresent())
	{
		freopen("../build/in.txt", "rb", stdin);
		freopen("../build/out.txt", "wb", stdout);
	}
#endif

    const int MAXB = 2002;

    int N, S, K, KK = 1;

    scanf("%d %d %d", &N, &S, &K);

    while (KK <= 2 * K)
        KK <<= 1;
    //KK <<= 1;

    vector<uint32_t> invf(KK, 1), fact(KK, 1);

    for (uint32_t i = 1; i < KK; ++i)
    {
        fact[i] = (static_cast<uint64_t>(i) * fact[i - 1]) % mod;
        invf[i] = inverse(fact[i]);
    }


    vector<int> C(N);
    vector<vector<uint32_t> > V(N, vector<uint32_t>(KK));

    for (int i = 0; i < N; ++i)
    {
        V[i][0] = 1;
        scanf("%d %d", &C[i], &V[i][1]);
        for (int j = 2; j < K + 1; ++j)
        {
            V[i][j] = (static_cast<uint64_t>(V[i][j - 1]) * V[i][1]) % mod;
        }
    }

    vector<vector<uint32_t> > vR(S + 1, vector<uint32_t>(KK));
    NTT ntt;
    vR[0][0] = 1;

    for (int i = 0; i < N; ++i)
    {
        vector<uint32_t> vI = V[i];
        for (uint32_t o = 0; o <= K; ++o)
        {
            vI[o] = (static_cast<uint64_t>(vI[o]) * invf[o]) % mod;
        }
        ntt.transform2(vI, false);

        for (int j = S - C[i]; j >= 0; --j)
        {
            int actj = j + C[i];

            //convolution vR[j] , V[i]

            vector<uint32_t> vRJ = vR[j];

            for (uint32_t o = 0; o <= K; ++o)
            {
                vRJ[o] = (static_cast<uint64_t>(vRJ[o]) * invf[o]) % mod;
            }

            ntt.transform2(vRJ, false);

            vector<uint32_t> fm(KK);
            for (uint32_t o = 0; o < KK; ++o)
            {
                fm[o] = (static_cast<uint64_t>(vRJ[o]) * vI[o]) % mod;
            }
            ntt.transform2(fm, true);
            for (uint32_t o = 0; o <= K; ++o)
            {
                vR[actj][o] = (vR[actj][o] + fm[o] * static_cast<uint64_t>(fact[o])) % mod;
            }
        }
    }
    int64_t res = 0;
    for (int i = 0; i <= S; ++i)
    {
        res += vR[i][K];
    }
    printf("%lld\n", (long long)(res % mod));
    return 0;
}

I'm sorry, but that was incorrect. The coefficients get messed up, as when we multiply 2 polynomials of size K+1, we get the evaluations of their product, not the evaluations of the first K+1 coefficients of their product.

The intended complexity is O( N S K \log K )

Time limits for Python and other languages seriously need to be reconsidered! My implementation of the NTT solution in Python gets TLE:

https://www.codechef.com/viewsolution/26628357

Apparently Java and C# programmers had similar issues. Not fair!

A couple of questions:

  • Why have you multiplied V_i to dp[i-1][j-C_i]? Since dp[i][j] represents the summation of all subsets of first i elements with cost j, shouldn't we be adding V_i to all the applicable subsets? I don't understand how multiplying V_i is right. (sorry for the stupid question, but this would help me understand the solution fully)
  • What does B_i represent in dp[i][j]=dp[i-1][j]+dp[i-1][j-C_i] \cdot e^{B_i \cdot x}? Is there a specific reason why you are using a different notation here?

Even I don't see how multiplying by V_i is correct. However, what you can do instead is maintain another array cnt[i][j] which represents the number of subsets that have sum of C_i = j. Then you can add V_i*cnt[i-1][j-C_i] + dp[i-1][j-C_i] to dp[i][j].

The B_i is probably a typo. It should be V_i.

Thanks for the clarification, this makes sense.

Yes please reconsider the solutions. I faced the same issue with java.

Hi Anand, thanks for the editorial!

It's just that, there are a lot of variables and terms that are not defined before being used. Also, why did you begin with the expansion of (a_1 + a_2 + a_3 +... +a_z)^k ? What is the intuition behind it?

Are there some easy questions that need such polynomial representation and basic algebra(like multi-point evaluation or same basic question that can bring the feel of FFT)? Directly playing with polynomials on this level is not easy. Some easy problems, if available, are requested.

Thanks.

Of course, it was incorrect. I have made changes in both places, thank you.

Some people may not be familiar with power series expansions of functions, so the aim was to show how the power series and the multinomial expansion lead to the same thing. I'm sure you can find a lot of basic FFT problems by just using the fft tag on cf or codechef.

Hello,

For all people that are beginning with FFT, the link to learn it is in the editorial, and for people looking to dive deeper into this theory, here is a wonderful document by Adamant : fft_eng.pdf - Google Drive

I found this during the contest; it is not closely related, but it is still a very good concept.
link for pdf
Also thanks @anand20 for such a simple editorial. If you could provide links to other maths problems, whether like this one or not, but really cool like this, that would be really great.

Can you please elaborate on this? At least you should have explained the test case completely (with output 65) instead of just writing this.
This is just a local work, you didn't even explain how to approach such problems, and just wrote the setter's solution.

Read the link attached to the word NTT. It is a well known algorithm, and there is nothing more to elaborate.

2 polynomials of degree K can be multiplied in O(K \log K), that is obviously in the link.

What I learnt plain and simple over time, is that it's just better to use C++. I quit programming in Java, and what I learnt is that actually C++ is better in every possible way.

I guess you didn't understand my question. My question is which polynomials do I have to multiply. Like in the test case you calculated various dp[i][j], do I need to multiply all these dp[i][j], and what would be the final result?

@anand20 thanks for the editorial. Can you also share/post the editorial/approach for the DOOFST problem?
