IMAT - Editorial


Contest Division 1
Contest Division 2
Contest Division 3

Setter: Chahat Agrawal
Tester & Editorialist: Taranpreet Singh




Meet in the middle, Knapsack DP.


Given a matrix with N rows and M columns containing positive integers, you may choose a subset of rows and a subset of columns. Count the ways to choose the two subsets such that the sum of the elements lying in both a selected row and a selected column equals B, which is given in the input.


  • We try all subsets of rows (after transposing the matrix if the number of rows exceeds the number of columns). For a fixed subset of rows, we can calculate the sum each column contributes if that column is selected.
  • The task is now an instance of the subset sum problem. Meet in the middle handles some cases and knapsack DP handles the others; both are needed to solve this problem.


First of all, let’s read the constraints properly. N*M \leq C implies min(N, M) \leq \sqrt C.
Assuming N \leq M (transpose the matrix and swap N and M if N > M), we now have N \leq 14 and N*M \leq 200, since \sqrt{200} < 15.
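The transposition step can be sketched as below. This is a minimal illustration, not code from either solution; the `Matrix` alias and the name `orientWide` are made up for this example.

```cpp
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<int>>;  // hypothetical alias for this sketch

// Returns a copy of `a` with rows <= columns, transposing when needed.
// Choosing row/column subsets of the transpose corresponds one-to-one to
// choosing column/row subsets of the original, so the answer is unchanged.
Matrix orientWide(const Matrix& a) {
    const std::size_t n = a.size();
    const std::size_t m = n ? a[0].size() : 0;
    if (n <= m) return a;
    Matrix t(m, std::vector<int>(n));
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t j = 0; j < m; ++j)
            t[j][i] = a[i][j];
    return t;
}
```

After this call the number of rows is the smaller dimension, so enumerating all row subsets costs at most 2^{14} iterations under the stated constraints.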

Naive approach

In this approach, we try all subsets of rows, and for each subset of rows, we try all subsets of columns. This approach takes O(2^{N+M}) time.
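For reference, the naive enumeration might look like the sketch below. It assumes a tiny matrix and skips the modulo that the solutions further down apply; `countNaive` and `Matrix` are illustrative names only.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<int>>;  // hypothetical alias for this sketch

// Counts pairs (row subset, column subset) whose intersection sums to B by
// enumerating every pair of bitmasks. Only practical for tiny matrices.
long long countNaive(const Matrix& a, long long B) {
    const std::size_t n = a.size();
    const std::size_t m = n ? a[0].size() : 0;
    assert(n + m <= 24);  // 2^(n+m) pairs: keep the demo small
    long long ways = 0;
    for (unsigned rows = 0; rows < (1u << n); ++rows) {
        for (unsigned cols = 0; cols < (1u << m); ++cols) {
            long long sum = 0;
            for (std::size_t i = 0; i < n; ++i) {
                if (!((rows >> i) & 1u)) continue;
                for (std::size_t j = 0; j < m; ++j)
                    if ((cols >> j) & 1u) sum += a[i][j];
            }
            if (sum == B) ++ways;
        }
    }
    return ways;
}
```

A brute force like this is mostly useful as a checker for the faster approaches on small random inputs.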

Optimizing it a bit

Let’s try all subsets of rows.

For a fixed subset of rows, each column is either selected or not. Hence, for each column c, compute A_c, the sum of the elements of the selected rows lying in column c. The problem then reduces to counting the ways to select a subset of a list of M values whose sum is B.

There are two well-known approaches for that.
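Computing the values A_c for one subset of rows can be sketched as follows, with the subset encoded as a bitmask. The helper name `columnSums` is an assumption of this example.

```cpp
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<int>>;  // hypothetical alias for this sketch

// For the rows whose bit is set in rowMask, returns A_c for every column c:
// the sum of the selected rows' entries in that column.
std::vector<long long> columnSums(const Matrix& a, unsigned rowMask) {
    const std::size_t n = a.size();
    const std::size_t m = n ? a[0].size() : 0;
    std::vector<long long> sums(m, 0);
    for (std::size_t i = 0; i < n; ++i) {
        if (!((rowMask >> i) & 1u)) continue;
        for (std::size_t j = 0; j < m; ++j) sums[j] += a[i][j];
    }
    return sums;
}
```

For the matrix with rows (1, 2, 3) and (4, 5, 6), selecting both rows gives the list 5, 7, 9.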

Optimization approach 1: Meet in the middle

In this approach, we divide the list of length M into two halves, try all subsets of the left half, and store in a map the number of subsets leading to each sum of selected values.

Then we iterate over the subsets of the right half, and for each subset with sum s2, we add to the total the number of subsets of the left half with sum B-s2, which is already computed.

This approach has time complexity O(2^{N+M/2}) per test case.

This approach times out when M > 40, since each subset of rows then needs more than 2^{20} operations.
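A minimal sketch of the meet-in-the-middle count for a single list of values, using a hash map as described above. The function name is made up, and the modulo is left out for brevity.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <unordered_map>
#include <vector>

// Counts subsets of `vals` with sum exactly B by splitting the list in two.
long long countSubsetsMitm(const std::vector<long long>& vals, long long B) {
    const std::size_t left = vals.size() / 2;
    const std::size_t right = vals.size() - left;
    assert(right <= 30);  // each half is enumerated with a 32-bit mask

    // Number of left-half subsets for every reachable sum.
    std::unordered_map<long long, long long> leftCount;
    for (std::uint32_t mask = 0; mask < (1u << left); ++mask) {
        long long s = 0;
        for (std::size_t j = 0; j < left; ++j)
            if ((mask >> j) & 1u) s += vals[j];
        ++leftCount[s];
    }

    // Pair every right-half subset of sum s2 with left subsets of sum B - s2.
    long long ways = 0;
    for (std::uint32_t mask = 0; mask < (1u << right); ++mask) {
        long long s2 = 0;
        for (std::size_t j = 0; j < right; ++j)
            if ((mask >> j) & 1u) s2 += vals[left + j];
        auto it = leftCount.find(B - s2);
        if (it != leftCount.end()) ways += it->second;
    }
    return ways;
}
```

Since all values are non-negative and only sums up to B matter, the map could be swapped for an array of length B+1 that ignores larger sums.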

Optimization approach 2: Knapsack DP

Once again, we have a list of M values, and we want to find the number of ways to achieve sum B.

The knapsack DP approach considers the elements one by one and maintains ways_x, the number of ways to make sum x using the elements considered so far.

You can read more on subset sum using DP here

This method takes O(M*B) time per subset of rows, so O(2^N*M*B) overall. It would fail for cases like N = 8, M = 25 and B = 10^5, where that is about 6.4*10^8 operations.
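The DP for one list of values can be sketched like this. The modulus 10^9+7 is taken from the solutions listed further down; the function name is hypothetical.

```cpp
#include <cassert>
#include <vector>

// ways[x] = number of subsets of the values seen so far with sum x (mod 1e9+7).
long long countSubsetsKnapsack(const std::vector<long long>& vals, int B) {
    assert(B >= 0);
    const long long MOD = 1000000007LL;
    std::vector<long long> ways(B + 1, 0);
    ways[0] = 1;  // the empty subset
    for (long long v : vals) {
        // Go downwards so each value is used at most once.
        for (long long x = B; x >= v; --x)
            ways[x] = (ways[x] + ways[x - v]) % MOD;
    }
    return ways[B];
}
```

A value of 0 (which occurs when no row is selected) simply doubles every entry, which is the correct count since that element can be included or left out freely.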

Final approach

We have two approaches, each struggling on a different type of case. The first approach struggles when M gets too high, as 2^N*2^{M/2} grows too large. The second approach struggles when 2^N*M*B gets too high.

We can use the first approach when M is small, and the second approach when M is large.

We can either choose some dividing criterion manually or (as I prefer) estimate the number of operations in both cases and use the method needing fewer operations.

For this problem, the estimated number of operations per subset of rows is 2^{M/2} for the first approach and M*B for the second. Use the approach with the lower estimate.
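Putting the pieces together, a compact sketch of the combined strategy could look as follows. It is not either of the solutions below: all names are illustrative, the cut-off `m <= 40` is an assumption that keeps the bitmasks small, and the meet-in-the-middle half uses an array of length B+1 rather than a map.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Matrix = std::vector<std::vector<int>>;  // hypothetical alias
constexpr long long MOD = 1000000007LL;

// Knapsack DP over one list of column sums: O(M * B).
long long viaKnapsack(const std::vector<long long>& vals, int B) {
    std::vector<long long> ways(B + 1, 0);
    ways[0] = 1;
    for (long long v : vals)
        for (long long x = B; x >= v; --x)
            ways[x] = (ways[x] + ways[x - v]) % MOD;
    return ways[B];
}

// Meet in the middle over one list; sums above B are dropped.
long long viaMitm(const std::vector<long long>& vals, int B) {
    const std::size_t left = vals.size() / 2, right = vals.size() - left;
    std::vector<long long> leftCount(B + 1, 0);
    for (std::uint32_t mask = 0; mask < (1u << left); ++mask) {
        long long s = 0;
        for (std::size_t j = 0; j < left; ++j)
            if ((mask >> j) & 1u) s += vals[j];
        if (s <= B) ++leftCount[s];
    }
    long long ways = 0;
    for (std::uint32_t mask = 0; mask < (1u << right); ++mask) {
        long long s = 0;
        for (std::size_t j = 0; j < right; ++j)
            if ((mask >> j) & 1u) s += vals[left + j];
        if (s <= B) ways = (ways + leftCount[B - s]) % MOD;
    }
    return ways;
}

long long countWays(Matrix a, int B) {
    assert(!a.empty() && B >= 0);
    if (a.size() > a[0].size()) {  // transpose so that rows <= columns
        Matrix t(a[0].size(), std::vector<int>(a.size()));
        for (std::size_t i = 0; i < a.size(); ++i)
            for (std::size_t j = 0; j < a[0].size(); ++j) t[j][i] = a[i][j];
        a = t;
    }
    const std::size_t n = a.size(), m = a[0].size();
    assert(n <= 20);
    long long total = 0;
    for (std::uint32_t mask = 0; mask < (1u << n); ++mask) {
        std::vector<long long> vals(m, 0);  // A_c for this row subset
        for (std::size_t i = 0; i < n; ++i)
            if ((mask >> i) & 1u)
                for (std::size_t j = 0; j < m; ++j) vals[j] += a[i][j];
        // Pick whichever method has the lower operation estimate.
        const bool mitmCheaper =
            m <= 40 && (1LL << ((m + 1) / 2)) < static_cast<long long>(m) * B;
        total = (total + (mitmCheaper ? viaMitm(vals, B) : viaKnapsack(vals, B))) % MOD;
    }
    return total;
}
```

The estimate is the same for every subset of rows, so it could also be evaluated once per test case instead of inside the loop.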


The time complexity is O(2^N*min(M*B, 2^{M/2})) per test case.
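To get a feel for this bound, the sketch below evaluates 2^N*min(M*B, 2^{M/2}) over every shape with N \leq M and N*M within a cell limit. The struct and function names are made up, and M/2 is rounded down as an assumption.

```cpp
#include <algorithm>
#include <cassert>

struct Worst {
    int n;
    int m;
    long long ops;
};

// Finds the shape (n <= m, n * m <= maxCells) maximising
// 2^n * min(m * B, 2^(m/2)), the per-test operation estimate.
Worst worstCase(int maxCells, long long B) {
    assert(maxCells >= 1 && maxCells <= 200 && B >= 1);
    Worst best{0, 0, 0};
    for (int n = 1; n * n <= maxCells; ++n) {
        for (int m = n; n * m <= maxCells; ++m) {
            const long long knap = m * B;
            const int half = m / 2;
            // For half >= 40 the power of two exceeds m * B anyway,
            // so skip the shift to avoid overflow.
            const long long mitm = half < 40 ? (1LL << half) : knap;
            const long long ops = (1LL << n) * std::min(knap, mitm);
            if (ops > best.ops) best = Worst{n, m, ops};
        }
    }
    return best;
}
```

With `maxCells = 200` and `B = 100000`, the maximum is reached at N = 4, M = 50 with 8*10^7 estimated operations.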


Setter's Solution
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <cmath>
#include <vector>
#include <set>
#include <map>
#include <unordered_set>
#include <unordered_map>
#include <queue>
#include <ctime>
#include <cassert>
#include <complex>
#include <string>
#include <cstring>
#include <chrono>
#include <random>
#include <bitset>
using namespace std;
#ifdef LOCAL
    #define eprintf(...) fprintf(stderr, __VA_ARGS__);fflush(stderr);
#else
    #define eprintf(...) 42
#endif
using ll = long long;
using ld = long double;
using uint = unsigned int;
using ull = unsigned long long;
template<typename T>
using pair2 = pair<T, T>;
using pii = pair<int, int>;
using pli = pair<ll, int>;
using pll = pair<ll, ll>;
mt19937_64 rng(chrono::steady_clock::now().time_since_epoch().count());
ll myRand(ll B) {
    return (ull)rng() % B;
}
#define pb push_back
#define mp make_pair
#define all(x) (x).begin(),(x).end()
#define fi first
#define se second
clock_t startTime;
double getCurrentTime() {
    return (double)(clock() - startTime) / CLOCKS_PER_SEC;
}
const int MOD = (int)1e9 + 7;
// Modular addition; assumes 0 <= x, y < MOD.
int add(int x, int y) {
    x += y;
    if (x >= MOD) x -= MOD;
    return x;
}
const int N = 202;
const int M = (int)1e5 + 3;
const int S = (1 << 19) + 5;
int n, m;
int a[N][N];
int b[N];
int B;
int ans;
int dp[M];
int c[2][S];
int sz[2];
// Plain enumeration of all subsets of b[0..n).
void brute2(int p, int s) {
    if (p == n) {
        if (s == B) ans = add(ans, 1);
        return;
    }
    brute2(p + 1, s);
    brute2(p + 1, s + b[p]);
}
void solveBrute() {
    brute2(0, 0);
}
// Knapsack DP where dp[x] counts the ways to still have x left to reach B.
void solveKnapsack() {
    for (int i = 0; i <= B; i++)
        dp[i] = 0;
    dp[B] = 1;
    for (int i = 0; i < n; i++) {
        for (int x = b[i]; x <= B; x++)
            dp[x - b[i]] = add(dp[x - b[i]], dp[x]);
    }
    ans = add(ans, dp[0]);
}
// Collects the sums of all subsets of b[p..en) into c[k].
void brute3(int p, int en, int k, int s) {
    if (p == en) {
        c[k][sz[k]++] = s;
        return;
    }
    brute3(p + 1, en, k, s);
    brute3(p + 1, en, k, s + b[p]);
}
// Meet in the middle: sort both halves and count pairs summing to B
// with two pointers (l = sums below B, r = sums at most B).
void solveMITM() {
    sz[0] = sz[1] = 0;
    brute3(0, n / 2, 0, 0);
    brute3(n / 2, n, 1, 0);
    sort(c[0], c[0] + sz[0]);
    sort(c[1], c[1] + sz[1]);
    int l = sz[1], r = sz[1];
    for (int i = 0; i < sz[0]; i++) {
        while(l > 0 && c[0][i] + c[1][l - 1] >= B) l--;
        while(r > 0 && c[0][i] + c[1][r - 1] > B) r--;
        ans = add(ans, r - l);
    }
}
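// Picks the cheapest method for the current list b[0..n).
// The branch bodies were lost from this listing; the calls below are an
// assumed reconstruction based on the conditions.
void solveLin() {
    if (n <= 38 && ((1 << ((n + 1) / 2)) < B)) {
        solveMITM();
        return;
    }
    if (n < 30 && (1 << n) < n * B) {
        solveBrute();
        return;
    }
    solveKnapsack();
}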
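// Enumerates all subsets of the m columns, keeping the row sums in b.
void brute(int p) {
    if (p == m) {
        solveLin();  // assumed base case: count subsets of b summing to B
        return;
    }
    brute(p + 1);
    for (int i = 0; i < n; i++)
        b[i] += a[i][p];
    brute(p + 1);
    for (int i = 0; i < n; i++)
        b[i] -= a[i][p];
}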
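void solve() {
    ans = 0;
    scanf("%d%d%d", &n, &m, &B);
    assert(1 <= B && B <= 100000);
    // The chained form 1 <= n * m <= 156 is always true in C++; the bound
    // 200 is the constraint quoted in the explanation above.
    assert(1 <= n * m && n * m <= 200);
    for (int i = 0; i < n; i++)
        for (int j = 0; j < m; j++)
            scanf("%d", &a[i][j]);
    if (n < m) {
        // Transpose in place so that m is the smaller dimension.
        for (int i = 0; i < m; i++)
            for (int j = 0; j < i; j++)
                swap(a[i][j], a[j][i]);
        swap(n, m);
    }
    for (int i = 0; i < n; i++)
        b[i] = 0;
    brute(0);  // assumed call; it was lost from this listing
    printf("%d\n", ans);
}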
int main()
{
    startTime = clock();
//  freopen("input.txt", "r", stdin);
//  freopen("output.txt", "w", stdout);
    int t;
    scanf("%d", &t);
    while(t--) solve();
    return 0;
}
Tester's Solution
import java.util.*;
import java.io.*;
class IMAT{
    void pre() throws Exception{}
    void solve(int TC) throws Exception{
        int N = ni(), M = ni(), B = ni();
        int[][] A = new int[N][M];
        for(int i = 0; i< N; i++){
            for(int j = 0; j< M; j++){
                A[i][j] = ni();
            }
        }
        if(N > M){
            //Transpose so that N <= M
            int[][] tmp = new int[M][N];
            for(int i = 0; i< N; i++)
                for(int j = 0; j< M; j++)
                    tmp[j][i] = A[i][j];
            A = tmp;
            int t = N;
            N = M;
            M = t;
        }
        int MOD = (int)1e9+7, ans = 0;
        int[] ways = new int[1+B];
        for(int mask = 0; mask < 1<<N; mask++){
            Arrays.fill(ways, 0);
            //val[j] = sum of column j over the rows selected in mask
            int[] val = new int[M];
            for(int i = 0; i< N; i++){
                for(int j = 0; j< M; j++){
                    val[j] += ((mask>>i)&1)*A[i][j];
                }
            }
            //Estimated operations: knapsack (w1) vs meet in the middle (w2)
            long w1 = (long)B*M, w2 = M*(1L<<((M+1)/2));
            if(M > 40 || w1 < w2){
                //Knapsack DP
                ways[0] = 1;
                for(int x:val){
                    for(int v = B; v >= x; v--){
                        ways[v] += ways[v-x];
                        if(ways[v] >= MOD)ways[v] -= MOD;
                    }
                }
                ans += ways[B];
                if(ans >= MOD)ans -= MOD;
            }else{
                //Meet in the middle, counting left-half sums in ways[]
                int left = M/2, right = M-left;
                for(int i = 0; i< 1<<left; i++){
                    long sum = 0;
                    for(int j = 0; j< left; j++)sum += val[j] * ((i>>j)&1);
                    if(sum <= B)ways[(int)sum]++;
                }
                for(int i = 0; i< 1<<right; i++){
                    long sum = 0;
                    for(int j = 0; j< right; j++)sum += val[j+left] * ((i>>j)&1);
                    if(sum <= B)ans += ways[B-(int)sum];
                    if(ans >= MOD)ans -= MOD;
                }
            }
        }
        pn(ans);//assumed output call; it was lost from this listing
    }
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    static boolean multipleTC = true;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        new IMAT().run();
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}

    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }

        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }

        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException  e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }

        String nextLine() throws Exception{
            String str = "";
            try{
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }
            return str;
        }
    }
}

Feel free to share your approach. Suggestions are welcome as always.


Nice problem on DP optimization.


I knew this was a higher dimension of subset-sum, but I was thinking along the lines of 3D DP and completely missed the meet in the middle approach. Nice question!


If M=200 and B=10^5 for every test case, then min(M * B, 2^(M/2)) will always be M * B. Can we not then say that the complexity is O(2^N * M * B), since that is what happens in the worst case?

Well, in the worst case, I think it should be N = 4, M = 50, B = 10^5.

But since N*M<=200 is given in the constraints, we can take N=1 and M=200.

I said the worst case. The case I mentioned should give maximum value of 2^N*min(M*B, 2^{M/2}).

Let lg=max(n,m) and sm=min(n,m)

My sol is 2^{sm}*min(lg*B*log(sm), 2^{lg}*sm*lg).

Should it pass subtask 3?
Intersection Matrix - Meet In Middle Problem | CodeChef

Solution: 45716132 | CodeChef

Functions used in program
f1 → all subset of sm
f2 → all subset of lg
sum → for calculating sum of elements for given subset of sm and lg
f3 → for storing elements in vector for dp
fun → dp

Edit - How to optimize this? After finding subsets of sm and lg how to find sum of elements?

Firstly, you can kick out the log factor, since we can use an array of length 1+B and the elements are positive.

Also, if you used meet in the middle, shouldn’t the complexity have 2^{lg/2} term?

“How to optimize this? After finding subsets of sm and lg how to find sum of elements?”
Iterate over the subsets of the right part. Say the sum of a right subset is S; then we want the number of subsets of the left part that have sum B-S, which can be retrieved using an array or a map.

Hope that clarifies.

I don’t know why I used a map. I hadn’t read the editorial when I was trying this problem.

My complexity will be 2^{sm} ∗ min(lg∗B + lg*sm, 2^{lg} ∗sm ∗ lg)

After using meet in the middle, the complexity will be 2^{sm} ∗ min(lg∗B + lg*sm, 2^{lg/2} ∗log(2^{lg/2})+2^{lg/2})

But if we have some subset of rows and some subset of columns, how do we find the sum of the elements at the intersection of those rows and columns? That is why my solution has the extra sm*lg factor.

for(int k=0;k<r.size();k++)
        for(int l=0;l<c.size();l++)

I did some optimizations and applied meet in the middle. Now only 2 test cases are giving TLE.

My code’s complexity is 2^{sm} * [sm *lg + min ( H *lg , (2^{lg/2} *log(2^{lg/2}) + 2^{lg/2} )* lg/2) ]

Solution: 45744169 | CodeChef

Functions used in program are
f1 → finding subsets of rows
f3 → storing sum of values present in row in vector
fun → for dp
f2 → meet in the middle for first half and store in map
f4 → meet in the middle for second half

How do I remove the lg/2 factor?


Edit - solved: removed the map and used an array for keeping count
Solution: 45744899 | CodeChef

Thanks a lot