KNAPSACK_ - Editorial


Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: wuhudsm
Tester: yash_daga
Editorialist: iceknight1093




Prerequisites: Dynamic Programming


You have N items, the i-th of them has volume X_i and value Y_i.
It is also known that X_i\cdot Y_i \leq 10^9 for each item.
Find the maximum value of a subset of items whose sum of volumes doesn’t exceed V.
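To pin the task down, here is a brute-force sketch in Python. It assumes items are given as `(volume, value)` tuples, and the name `knapsack_brute` is hypothetical, not from the original. It tries every subset, so it is only practical for a handful of items, but it works as a reference when testing a faster solution.

```python
from itertools import combinations


def knapsack_brute(items, V):
    """Max total value of a subset of (volume, value) pairs with total volume <= V.

    Tries all 2^N subsets, so it is only meant as a slow reference checker.
    """
    best = 0  # the empty subset always fits
    for k in range(len(items) + 1):
        for subset in combinations(items, k):
            if sum(x for x, _ in subset) <= V:
                best = max(best, sum(y for _, y in subset))
    return best


print(knapsack_brute([(3, 5), (4, 6), (2, 3)], 6))  # 9: items (4, 6) and (2, 3)
```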


The task of finding the maximum value of items that fit in a certain volume is exactly the well-known 0-1 knapsack problem.
There are several different solutions to it, usually depending on which constraints are small.
The most classical version is when the capacity is small, as seen here.
It can also be solved when all the values are small, as in this task.
Solutions to both the linked problems can be found in this post.

Unfortunately for us, our problem has neither small capacity nor small values.

Instead, let’s use the additional constraint we have: X_i\cdot Y_i \leq 10^9.
This means that either X_i or Y_i must be ‘small’, i.e., \leq \sqrt{10^9}: if both exceeded \sqrt{10^9}, their product would exceed 10^9.

Let’s process the pairs with small X_i separately from those with large X_i (and hence, small Y_i), and combine them later.
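A minimal sketch of the split in Python, assuming items arrive as `(volume, value)` tuples with volume * value <= 10^9; `split_items` and `LIMIT` are hypothetical names, and `math.isqrt` needs Python 3.8+. The threshold is floor(sqrt(10^9)) = 31622: an item with volume above it has volume at least 31623, and since 31623^2 > 10^9 its value must be at most 31622.

```python
import math

LIMIT = math.isqrt(10**9)  # floor(sqrt(10^9)) = 31622


def split_items(items):
    """Partition (volume, value) pairs, assuming volume * value <= 10**9.

    Items with volume <= LIMIT go to the volume-indexed DP; every other item
    then has value <= LIMIT, so it can go to the value-indexed DP.
    """
    small_x, big_x = [], []
    for x, y in items:
        if x <= LIMIT:
            small_x.append((x, y))
        else:
            big_x.append((x, y))
    return small_x, big_x


small, big = split_items([(5, 10**8), (10**6, 1000)])
print(small, big)  # [(5, 100000000)] [(1000000, 1000)]
```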

When X_i is small, the ‘standard’ dp solution can be applied.
That is, define dp_1[c] to be the maximum value of items chosen, with volume exactly c.
Then, when processing the pair (X_i, Y_i) we have:

dp_1[c] = \max(dp_1[c], dp_1[c-X_i] + Y_i)

depending on whether we choose to take the item or not.
In order for this to make sense, remember to process dp_1 in descending order of c.
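A sketch of dp_1 in Python under the same assumptions (hypothetical names `build_dp1` and `NEG`; volumes that no subset reaches exactly are marked with `float("-inf")`). The inner loop runs c downwards, matching the descending order required above.

```python
NEG = float("-inf")  # marks volumes that cannot be formed exactly


def build_dp1(items):
    """dp1[c] = max value of a subset of `items` with total volume exactly c.

    The array length is total volume + 1, at most N * sqrt(10^9) + 1
    when every volume is small.
    """
    total = sum(x for x, _ in items)
    dp1 = [NEG] * (total + 1)
    dp1[0] = 0
    for x, y in items:
        # Descending c: dp1[c - x] does not yet include the current item,
        # so each item is used at most once.
        for c in range(total, x - 1, -1):
            if dp1[c - x] != NEG:
                dp1[c] = max(dp1[c], dp1[c - x] + y)
    return dp1


print(build_dp1([(2, 3), (3, 4)]))  # [0, -inf, 3, 4, -inf, 7]
```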

Since we’re dealing with X_i \leq \sqrt{10^9} here, the total volume of all the items is bounded by N\cdot \sqrt{10^9}.
So, dp_1 has at most N\cdot \sqrt{10^9} states, and each of the N items updates every state once in \mathcal{O}(1), for \mathcal{O}(N^2\cdot \sqrt{10^9}) work in total, which is good enough.

Next, let’s look at the pairs where X_i is large but Y_i is small.
Here, we have Y_i \leq \sqrt{10^9}.

This gives us the idea to define our dp states based on the value instead of the volume.
That is, define dp_2[p] to be the minimum volume required to obtain a value of exactly p.
Then, we have

dp_2[p] = \min(dp_2[p], dp_2[p-Y_i] + X_i)

with the same reasoning as before.
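The value-indexed table dp_2 can be sketched the same way in Python (hypothetical name `build_dp2`; `float("inf")` marks values that no subset reaches exactly). Only the roles of volume and value swap, along with max becoming min.

```python
INF = float("inf")  # marks values that cannot be formed exactly


def build_dp2(items):
    """dp2[p] = min total volume of a subset of `items` with value exactly p."""
    total = sum(y for _, y in items)
    dp2 = [INF] * (total + 1)
    dp2[0] = 0
    for x, y in items:
        # Descending p, so each item is used at most once (as for dp1).
        for p in range(total, y - 1, -1):
            dp2[p] = min(dp2[p], dp2[p - y] + x)
    return dp2


print(build_dp2([(100000, 2), (50000, 3)]))  # [0, inf, 100000, 50000, inf, 150000]
```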

Once again, this has \mathcal{O}(N^2 \cdot \sqrt{10^9}) complexity.

Now that we have dp_1 and dp_2, let’s combine them!

Let’s fix p, the value of the items from the second set.
By definition, the minimum volume we need to do this is dp_2[p]; if dp_2[p] > V, this value of p cannot fit at all and is skipped.
Next, we’d like to pick some items from the first set to fill the remaining volume.
The remaining volume is exactly V - dp_2[p].

If we fix the volume c of items from the first set, again by definition the largest value we can get is dp_1[c].
We’re free to take any c \leq V-dp_2[p].
So, the value we’re looking for is the maximum of dp_1[0], dp_1[1], dp_1[2], \ldots, dp_1[V-dp_2[p]] (an index past the end of dp_1 can be clamped to its last entry).
Since dp_1 is already computed, we can find this in \mathcal{O}(1) time by precomputing its prefix maxima!
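Turning dp_1 into its running maximum takes a single pass. One way to sketch it in Python is `itertools.accumulate` with `max` as the combining function (the name `prefix_max` is hypothetical). After this, entry c holds the best value over all volumes up to c.

```python
from itertools import accumulate


def prefix_max(dp1):
    """best[c] = max(dp1[0..c]), i.e. the max value with volume at most c."""
    return list(accumulate(dp1, max))


print(prefix_max([0, float("-inf"), 3, 4, float("-inf"), 7]))  # [0, 0, 3, 4, 4, 7]
```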

We consider all N\cdot \sqrt{10^9} values of p, each of which is processed in \mathcal{O}(1), so this part is fast enough too.
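A sketch of the final combination step in Python (hypothetical name `combine`). It assumes `best` is the running maximum of dp_1 (so `best[0] = 0`) and `dp2` is indexed by exact value with `float("inf")` for unreachable values. The remaining volume is clamped to the last index of `best`, since no first-set subset is heavier than that anyway.

```python
def combine(best, dp2, V):
    """Max over p with dp2[p] <= V of p + best[min(V - dp2[p], len(best) - 1)]."""
    ans = 0
    for p, vol in enumerate(dp2):
        if vol > V:
            continue  # value p from the second set does not fit at all
        c = min(V - vol, len(best) - 1)  # clamp: best is constant beyond its end
        ans = max(ans, p + best[c])
    return ans


# First set: (2, 3), (3, 4); second set: (100000, 2), (50000, 3); V = 50005.
print(combine([0, 0, 3, 4, 4, 7],
              [0, float("inf"), 100000, 50000, float("inf"), 150000],
              50005))  # 10
```

Here the answer 10 comes from p = 3 (the item (50000, 3)) plus the whole first set, which uses the remaining volume 5 exactly.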


Time complexity: \mathcal{O}(N^2 \sqrt{10^9}) per testcase.
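For a sense of scale, assuming N \leq 100 (the bound the tester's input checker asserts on N and on the sum of N) and using the author's threshold SQR = \lfloor\sqrt{10^9}\rfloor + 1, the arithmetic works out as follows.

```python
import math

N_MAX = 100                   # assumed bound on N, taken from the tester's checker
ROOT = math.isqrt(10**9) + 1  # 31623, the SQR constant in the author's code

states = N_MAX * ROOT         # length of one DP table: N * sqrt(10^9)
updates = N_MAX * states      # each of the N items sweeps at most one full table
print(states, updates)        # 3162300 316230000
```

So roughly 3\cdot 10^8 constant-time updates in the worst case, which is fine in C++.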


Author's code (C++)
#include <map>
#include <set>
#include <cmath>
#include <ctime>
#include <queue>
#include <stack>
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <cstring>
#include <algorithm>
#include <iostream>
using namespace std;
typedef double db; 
typedef long long ll;
typedef unsigned long long ull;
const int N=10000010;
const int LOGN=28;
const ll  TMD=0;
const ll  INF=2147483647LL*2147483647LL;
const int SQR=(int)sqrt(1e9) + 1;
int T,v,n;
ll  ans;
int x[N],y[N];
ll  dp1[N],dp2[N],premx[N];

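int main()
{
	scanf("%d",&T);
	while(T--)
	{
		scanf("%d%d",&n,&v);
		for(int i=1;i<=n;i++) scanf("%d%d",&x[i],&y[i]);
		// dp1[c]: max value with volume exactly c, over items with x[i]<=SQR
		// dp2[p]: min volume with value exactly p, over the other items (their y[i]<SQR)
		for(int i=1;i<=n*40000;i++) dp1[i]=-INF,dp2[i]=INF;
		dp1[0]=dp2[0]=0;
		for(int i=1,c=0;i<=n;i++)
			if(x[i]<=SQR)
			{
				c++; // c small-volume items so far, total volume <= c*SQR
				for(int j=c*SQR;j>=x[i];j--)
					if(dp1[j-x[i]]!=-INF) dp1[j]=max(dp1[j],dp1[j-x[i]]+y[i]);
			}
		for(int i=1,c=0;i<=n;i++)
			if(x[i]>SQR)
			{
				c++; // c small-value items so far, total value <= c*SQR
				for(int j=c*SQR;j>=y[i];j--)
					if(dp2[j-y[i]]!=INF) dp2[j]=min(dp2[j],dp2[j-y[i]]+x[i]);
			}
		// (volume, value) pairs reachable by the second set, sorted by volume;
		// premx[k] = largest value among the first k+1 pairs
		vector<pair<ll,ll> > pr;
		for(int i=0;i<=n*SQR;i++)
			if(dp2[i]!=INF) pr.push_back(make_pair(dp2[i],i));
		sort(pr.begin(),pr.end());
		premx[0]=pr[0].second;
		for(int i=1;i<(int)pr.size();i++) premx[i]=max(premx[i-1],pr[i].second);
		ans=0;
		for(int i=0;i<=min(v,n*SQR);i++)
		{
			if(dp1[i]==-INF) continue;
			// binary search: last pair whose volume still fits next to volume i
			int L=0,R=pr.size(),M;
			while(L+1<R)
			{
				M=(L+R)>>1;
				if(i+pr[M].first<=v) L=M;
				else R=M;
			}
			ans=max(ans,dp1[i]+premx[L]);
		}
		printf("%lld\n",ans);
	}
	return 0;
}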
Tester's code (C++)
//clear adj and visited vector declared globally after each test case
//check for long long overflow   
//For modulo questions, add a check at the end: if ans<0 then ans+=mod;
//In case of a close MLE, switch the language to C++17 or C++14
//Check ans for n=1 
#pragma GCC target ("avx2")    
#pragma GCC optimize ("O3")  
#pragma GCC optimize ("unroll-loops")
#include <bits/stdc++.h>                   
#include <ext/pb_ds/assoc_container.hpp>  
#define int long long      
#define IOS std::ios::sync_with_stdio(false); cin.tie(NULL);cout.tie(NULL);cout.precision(dbl::max_digits10);
#define pb push_back 
#define mod 1000000007ll
#define lld long double
#define mii map<int, int> 
#define pii pair<int, int>
#define ll long long 
#define ff first
#define ss second 
#define all(x) (x).begin(), (x).end()
#define rep(i,x,y) for(int i=x; i<y; i++)    
#define fill(a,b) memset(a, b, sizeof(a))
#define vi vector<int>
#define setbits(x) __builtin_popcountll(x)
#define print2d(dp,n,m) for(int i=0;i<=n;i++){for(int j=0;j<=m;j++)cout<<dp[i][j]<<" ";cout<<"\n";}
typedef std::numeric_limits< double > dbl;
using namespace __gnu_pbds;
using namespace std;
typedef tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update> indexed_set;
//member functions :
//1. order_of_key(k) : number of elements strictly lesser than k
//2. find_by_order(k) : k-th element in the set
const long long N=20000005, INF=2000000000000000000;
const int inf=2e9 + 5;
lld pi=3.1415926535897932;
int lcm(int a, int b)
{
    int g=__gcd(a, b);
    return a/g*b;
}
int power(int a, int b, int p)
{
    // modular exponentiation: a^b mod p
    if(a==0)
    return 0;
    int res=1;
    a%=p;
    while(b>0)
    {
        if(b&1)
        res=(res*a)%p;
        b>>=1;
        a=(a*a)%p;
    }
    return res;
}
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());

int getRand(int l, int r)
{
    uniform_int_distribution<int> uid(l, r);
    return uid(rng);
}
struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        // read the whole input into the buffer
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    auto readIntVec(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
            else readEoln();
        }
        return v;
    }

    auto readLongVec(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
            else readEoln();
        }
        return v;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};
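int dp1[N], dp2[N];
int32_t main()
{
    IOS;
    input_checker inp;
    int t = inp.readInt(1, 100), sum_n=0, lim=4e4;
    inp.readEoln();
    while(t--)
    {
        int n = inp.readInt(1, 100); inp.readSpace();
        sum_n += n;
        assert(sum_n <= 100);
        int cap = inp.readInt(1, 1e9); inp.readEoln();
        int x[n], v[n];
        for(int i=0;i<n;i++)
        {
            x[i] = inp.readInt(1, 1e9); inp.readSpace();
            v[i] = inp.readInt(1, 1e9); inp.readEoln();
            assert(x[i]*v[i] <= 1e9);
        }
        // dp1[j]: max value with volume exactly j (items with x <= lim)
        // dp2[j]: min volume with value exactly j (items with x > lim, hence v < lim)
        for(int i=0;i<=(n*lim)+2;i++)
        {
            dp1[i]=-INF;
            dp2[i]=INF;
        }
        dp1[0]=0;
        dp2[0]=0;
        for(int i=0;i<n;i++)
        {
            if(x[i]<=lim)
            {
                for(int j=(i+1)*lim;j>=x[i];j--)
                dp1[j]=max(dp1[j], dp1[j-x[i]]+v[i]);
            }
            else
            {
                for(int j=(i+1)*lim;j>=v[i];j--)
                dp2[j]=min(dp2[j], dp2[j-v[i]]+x[i]);
            }
        }
        // prefix maximum: dp1[i] = max value with volume at most i
        for(int i=1;i<=(n*lim)+2;i++)
            dp1[i]=max(dp1[i], dp1[i-1]);
        // suffix minimum: dp2[i] = min volume with value at least i
        for(int i=(n*lim);i>=0;i--)
        dp2[i]=min(dp2[i], dp2[i+1]);
        // two pointers: as the volume i given to dp1 grows, the largest value pt
        // that dp2 can reach within cap-i only shrinks
        int pt=(n*lim), ans=0;
        for(int i=0;i<=min(cap, lim*n);i++)
        {
            while(dp2[pt]>cap-i) pt--;
            ans=max(ans, pt+dp1[i]);
        }
        cout<<ans<<"\n";
    }
    inp.readEof();
    return 0;
}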
Editorialist's code (Python)
for _ in range(int(input())):
    n, cap = map(int, input().split())
    # x <= y implies x <= sqrt(1e9); otherwise y < x implies y < sqrt(1e9)
    smallx, bigx = [], []
    lim1, lim2 = 10, 10
    for i in range(n):
        x, y = map(int, input().split())
        if x <= y:
            smallx.append((x, y))
            lim1 += x
        else:
            bigx.append((x, y))
            lim2 += y
    # dp1[v]: max value with volume exactly v (-1 = unreachable)
    dp1 = [-1]*(lim1)
    # dp2[v]: min volume with value exactly v (10**13 = unreachable)
    dp2 = [10**13]*(lim2)
    dp1[0] = 0
    for x, y in smallx:
        for v in reversed(range(lim1)):
            if v - x < 0: break
            if dp1[v-x] >= 0: dp1[v] = max(dp1[v], dp1[v-x] + y)
    dp2[0] = 0
    for x, y in bigx:
        for v in reversed(range(lim2)):
            if v - y < 0: break
            dp2[v] = min(dp2[v], dp2[v-y] + x)
    # prefix maxima: dp1[v] becomes the max value with volume at most v
    for v in range(1, lim1): dp1[v] = max(dp1[v], dp1[v-1])
    ans = 0
    for v in range(lim2):
        if dp2[v] > cap: continue
        rem = min(lim1-1, cap - dp2[v])
        ans = max(ans, v + dp1[rem])
    print(ans)

Sir, shouldn't this be ascending order of c? Like, dp1[10] should be computed before dp1[20].

No, descending order is correct.
Think about it: suppose you have a single item with volume 1 and value 1, and you process in ascending order. What happens?
You’ll get dp_1[x] = x for all x, which is obviously wrong since you only have one item. What you want is dp_1[0] = 0, dp_1[1] = 1, and every other dp_1[x] unreachable.

If you go in ascending order, when computing dp_1[x] you use dp_1[y] for some y \lt x, but dp_1[y] has already been updated using this item, so you implicitly use this item multiple times!
If you go in descending order, dp_1[y] for y\lt x hasn’t been updated yet, so it uses only the previous items, not the current one. This ensures that each item is only used once.

Ascending order would be correct if you could use each item as many times as you like.
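The single-item example from this reply can be checked directly. Below is a small Python sketch (the name `sweep` is hypothetical) that fills dp_1[c], the best value with volume exactly c, for one item of volume 1 and value 1, once in each loop order.

```python
NEG = float("-inf")  # volume not reachable exactly


def sweep(items, cap, descending):
    """dp[c] = best value with volume exactly c, filled in the chosen loop order."""
    dp = [0] + [NEG] * cap
    for x, y in items:
        order = range(cap, x - 1, -1) if descending else range(x, cap + 1)
        for c in order:
            dp[c] = max(dp[c], dp[c - x] + y)
    return dp


print(sweep([(1, 1)], 4, descending=True))   # [0, 1, -inf, -inf, -inf]
print(sweep([(1, 1)], 4, descending=False))  # [0, 1, 2, 3, 4]
```

The ascending sweep reuses the item at every step, which is exactly the unbounded-knapsack recurrence.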


Thanks a lot, I understood.