PROBLEM LINK:
Contest Division 1
Contest Division 2
Contest Division 3
Practice
Setter: Chahat Agrawal
Tester & Editorialist: Taranpreet Singh
DIFFICULTY
Easy-Medium
PREREQUISITES
Meet in the middle, Knapsack DP.
PROBLEM
Given a matrix with N rows and M columns containing positive integers, you may choose a subset of rows and a subset of columns. Determine the number of ways to choose these subsets such that the sum of the elements whose row and column are both selected equals B, which is given in the input.
QUICK EXPLANATION
- We try all subsets of rows (after transposing the matrix if the number of rows exceeds the number of columns). For each column, we can then calculate the sum it’ll contribute if that column is selected.
- The task now becomes an instance of the subset sum problem. Solving it within the time limit requires meet in the middle for some cases and knapsack DP for others.
EXPLANATION
First of all, let’s read the constraints carefully. N*M \leq C implies min(N, M) \leq \sqrt C, because if both N and M exceeded \sqrt C, their product would exceed C.
Assuming N \leq M (take the transpose of the matrix and swap N and M if N > M), we now have N \leq 14 and N*M \leq 200.
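The transposition step can be sketched as follows. This is only an illustration: the helper name orientMatrix and the use of std::vector are assumptions, not taken from the reference solutions, which transpose in place or into a temporary array.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper (not from the reference solutions): returns the matrix
// unchanged if it has at most as many rows as columns, otherwise returns its
// transpose. The result always satisfies rows <= columns.
std::vector<std::vector<int>> orientMatrix(const std::vector<std::vector<int>>& a) {
    const std::size_t n = a.size();
    const std::size_t m = n ? a[0].size() : 0;
    if (n <= m) return a;
    std::vector<std::vector<int>> t(m, std::vector<int>(n));
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t j = 0; j < m; ++j)
            t[j][i] = a[i][j];
    return t;
}
```

Transposing does not change the answer, because swapping the roles of rows and columns keeps every (row subset, column subset) pair and its sum intact.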
Naive approach
In this approach, we try all subsets of rows, and for each subset of rows, we try all subsets of columns. This approach takes O(2^{N+M}) time.
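A minimal sketch of this naive enumeration, useful mainly for checking faster solutions on tiny inputs. The function name countNaive is an assumption, and the raw count is returned without the modulus the reference solutions apply.

```cpp
#include <vector>

// Hypothetical brute force: tries every (row subset, column subset) pair and
// counts those whose selected cells sum to B. Only usable when N + M is small.
long long countNaive(const std::vector<std::vector<int>>& a, long long B) {
    const int n = static_cast<int>(a.size());
    const int m = n ? static_cast<int>(a[0].size()) : 0;
    long long ways = 0;
    for (int rowMask = 0; rowMask < (1 << n); ++rowMask) {
        for (int colMask = 0; colMask < (1 << m); ++colMask) {
            long long sum = 0;
            for (int i = 0; i < n; ++i) {
                if (!((rowMask >> i) & 1)) continue;
                for (int j = 0; j < m; ++j)
                    if ((colMask >> j) & 1) sum += a[i][j];
            }
            if (sum == B) ++ways;
        }
    }
    return ways;
}
```

For the matrix {{1, 2}, {3, 4}} and B = 4, the two valid choices are row 2 with column 2, and both rows with column 1.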
Optimizing it a bit
Let’s try all subsets of rows.
For a fixed subset of rows, each column is either selected or not selected. For each column c, compute A_c, the sum of the elements of column c that lie in the selected rows. Selecting column c then adds exactly A_c to the total, so the problem reduces to counting the subsets of a list of M values whose sum is B.
There are two well-known approaches for this.
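Computing the per-column contributions for one row subset could look like the sketch below. The name columnSums and the bitmask encoding of the row subset (bit i set means row i is selected) are assumptions made for illustration.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper: for the row subset encoded in rowMask, returns the
// list A where A[c] is the sum of a[i][c] over all selected rows i.
std::vector<long long> columnSums(const std::vector<std::vector<int>>& a, unsigned rowMask) {
    const std::size_t n = a.size();
    const std::size_t m = n ? a[0].size() : 0;
    std::vector<long long> sums(m, 0);
    for (std::size_t i = 0; i < n; ++i) {
        if (!((rowMask >> i) & 1u)) continue;  // row i is not selected
        for (std::size_t c = 0; c < m; ++c)
            sums[c] += a[i][c];
    }
    return sums;
}
```

With no rows selected, every entry is 0, so no column subset can reach a positive B.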
Optimization approach 1: Meet in the middle
In this approach, we divide the list of length M into two halves, try all subsets of the left half, and store in a map the number of subsets that lead to each sum.
Then we iterate over the subsets of the right half, and for each subset with sum s2, we add to the answer the number of left-half subsets with sum B-s2, which is already computed.
This approach has time complexity O(2^{N+M/2}) per test case.
This approach times out when M > 40.
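The two halves can be combined roughly as in the sketch below. The function name countSubsetsMITM and the use of std::unordered_map are assumptions (the setter's solution sorts both halves and uses two pointers instead); the modulus 10^9 + 7 matches the one used in the reference solutions.

```cpp
#include <unordered_map>
#include <vector>

const long long MOD = 1000000007LL;

// Hypothetical sketch: counts subsets of `values` with sum exactly B,
// modulo MOD, by enumerating each half of the list separately.
long long countSubsetsMITM(const std::vector<long long>& values, long long B) {
    const int m = static_cast<int>(values.size());
    const int left = m / 2, right = m - left;
    // leftCount[s] = number of subsets of the left half with sum s.
    std::unordered_map<long long, long long> leftCount;
    for (long long mask = 0; mask < (1LL << left); ++mask) {
        long long sum = 0;
        for (int j = 0; j < left; ++j)
            if ((mask >> j) & 1) sum += values[j];
        ++leftCount[sum];
    }
    long long ways = 0;
    for (long long mask = 0; mask < (1LL << right); ++mask) {
        long long sum = 0;
        for (int j = 0; j < right; ++j)
            if ((mask >> j) & 1) sum += values[left + j];
        // Every left-half subset with sum B - sum completes this right-half subset.
        auto it = leftCount.find(B - sum);
        if (it != leftCount.end()) ways = (ways + it->second) % MOD;
    }
    return ways;
}
```

For the list {1, 2, 3, 4} and B = 5, the left half {1, 2} yields sums 0, 1, 2, 3 and the right half {3, 4} yields sums 0, 3, 4, 7, so the matches are 2 + 3 and 1 + 4.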
Optimization approach 2: Knapsack DP
Once again, we have a list of M values, and we want to find the number of ways to achieve sum B.
The knapsack DP approach considers the elements one by one and maintains ways_x, the number of ways to make sum x using the elements considered so far.
You can read more on subset sum using DP here
This method takes O(M*B) time for each subset of rows, i.e. O(2^N*M*B) in total, and would fail for cases like N = 8, M = 25 and B = 10^5.
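A minimal sketch of the counting DP for one list of values. The function name countSubsetsKnapsack is an assumption; the recurrence and the modulus 10^9 + 7 follow the tester's solution.

```cpp
#include <cassert>
#include <vector>

const long long MOD = 1000000007LL;

// Hypothetical sketch: counts subsets of non-negative `values` with sum
// exactly B, modulo MOD. ways[s] holds the count for sum s using the
// elements processed so far.
long long countSubsetsKnapsack(const std::vector<long long>& values, long long B) {
    assert(B >= 0);
    std::vector<long long> ways(B + 1, 0);
    ways[0] = 1;  // the empty subset makes sum 0
    for (long long x : values) {
        // Go downwards so each element is used at most once.
        for (long long s = B; s >= x; --s)
            ways[s] = (ways[s] + ways[s - x]) % MOD;
    }
    return ways[B];
}
```

A value of 0 (which appears when no rows are selected) simply doubles every entry, since the element can be taken or skipped without changing the sum.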
Final approach
We have two approaches, each struggling on a different type of case. The first approach struggles when M gets too high, as 2^N*2^{M/2} becomes too large. The second approach struggles when 2^N*M*B becomes too large.
We can use the first approach when M is small, and the second approach when M is large.
We can either choose a dividing criterion manually, or (as I prefer) estimate the number of operations in both cases and use the method that needs fewer operations.
For this problem, the estimated number of operations per subset of rows is 2^{M/2} for the first approach and M*B for the second approach. Use the approach with the lower estimate.
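The selection rule can be sketched as below. The names Method and chooseMethod are assumptions, and the exact estimates differ slightly between the two reference solutions; this version compares 2^{ceil(M/2)} with M*B and always falls back to knapsack when M > 40, the point at which the text says meet in the middle times out.

```cpp
enum class Method { MeetInTheMiddle, Knapsack };

// Hypothetical chooser: picks the approach with the smaller estimated
// operation count for a list of M values and target sum B.
Method chooseMethod(int M, long long B) {
    if (M > 40) return Method::Knapsack;         // also keeps the shift below in range
    const long long mitmOps = 1LL << ((M + 1) / 2);  // subsets of the larger half
    const long long dpOps = static_cast<long long>(M) * B;
    return mitmOps < dpOps ? Method::MeetInTheMiddle : Method::Knapsack;
}
```

For example, M = 25 and B = 10^5 gives 2^13 = 8192 against 2.5 * 10^6, so meet in the middle is chosen, while M = 100 goes straight to knapsack.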
TIME COMPLEXITY
The time complexity is O(2^N*min(M*B, 2^{M/2})) per test case.
SOLUTIONS
Setter's Solution
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <cmath>
#include <vector>
#include <set>
#include <map>
#include <unordered_set>
#include <unordered_map>
#include <queue>
#include <ctime>
#include <cassert>
#include <complex>
#include <string>
#include <cstring>
#include <chrono>
#include <random>
#include <bitset>
using namespace std;
#ifdef LOCAL
#define eprintf(...) fprintf(stderr, __VA_ARGS__);fflush(stderr);
#else
#define eprintf(...) 42
#endif
using ll = long long;
using ld = long double;
using uint = unsigned int;
using ull = unsigned long long;
template<typename T>
using pair2 = pair<T, T>;
using pii = pair<int, int>;
using pli = pair<ll, int>;
using pll = pair<ll, ll>;
mt19937_64 rng(chrono::steady_clock::now().time_since_epoch().count());
ll myRand(ll B) {
return (ull)rng() % B;
}
#define pb push_back
#define mp make_pair
#define all(x) (x).begin(),(x).end()
#define fi first
#define se second
clock_t startTime;
double getCurrentTime() {
return (double)(clock() - startTime) / CLOCKS_PER_SEC;
}
const int MOD = (int)1e9 + 7;
int add(int x, int y) {
x += y;
if (x >= MOD) x -= MOD;
return x;
}
const int N = 202;
const int M = (int)1e5 + 3;
const int S = (1 << 19) + 5;
int n, m;
int a[N][N];
int b[N];
int B;
int ans;
int dp[M];
int c[2][S];
int sz[2];
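// Plain enumeration of all subsets of b[0..n-1]; only referenced from the
// commented-out branch in solveLin().
void brute2(int p, int s) {
    if (p == n) {
        if (s == B) ans = add(ans, 1);
        return;
    }
    brute2(p + 1, s);
    brute2(p + 1, s + b[p]);
}
void solveBrute() {
    brute2(0, 0);
}
// Knapsack DP on the remaining sum: dp[r] counts the ways in which r is
// still left to be covered after the values processed so far.
void solveKnapsack() {
    for (int i = 0; i <= B; i++)
        dp[i] = 0;
    dp[B] = 1;
    for (int i = 0; i < n; i++) {
        for (int x = b[i]; x <= B; x++)
            dp[x - b[i]] = add(dp[x - b[i]], dp[x]);
    }
    ans = add(ans, dp[0]);
}
// Stores the sums of all subsets of b[p..en-1] in c[k].
void brute3(int p, int en, int k, int s) {
    if (p == en) {
        c[k][sz[k]++] = s;
        return;
    }
    brute3(p + 1, en, k, s);
    brute3(p + 1, en, k, s + b[p]);
}
// Meet in the middle: sort the subset sums of both halves and count the
// pairs with total exactly B using two pointers.
void solveMITM() {
    sz[0] = sz[1] = 0;
    brute3(0, n / 2, 0, 0);
    brute3(n / 2, n, 1, 0);
    sort(c[0], c[0] + sz[0]);
    sort(c[1], c[1] + sz[1]);
    int l = sz[1], r = sz[1];
    for (int i = 0; i < sz[0]; i++) {
        while(l > 0 && c[0][i] + c[1][l - 1] >= B) l--;
        while(r > 0 && c[0][i] + c[1][r - 1] > B) r--;
        ans = add(ans, r - l);
    }
}
// Solves the subset sum instance for the current b[], picking the method.
void solveLin() {
    if (n <= 38 && ((1 << ((n + 1) / 2)) < B))
        solveMITM();
    else
        solveKnapsack();
    /*
    if (n < 30 && (1 << n) < n * B)
        solveBrute();
    else
        solveKnapsack();
    */
}
// Enumerates all subsets of the m lines of the shorter dimension,
// maintaining in b[] the contribution of each line of the longer dimension.
void brute(int p) {
    if (p == m) {
        solveLin();
        return;
    }
    brute(p + 1);
    for (int i = 0; i < n; i++)
        b[i] += a[i][p];
    brute(p + 1);
    for (int i = 0; i < n; i++)
        b[i] -= a[i][p];
}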
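void solve() {
    ans = 0;
    scanf("%d%d%d", &n, &m, &B);
    assert(1 <= B && B <= 100000);
    // The chained form `1 <= n * m <= 156` always evaluates to true in C++,
    // so the two bounds are checked separately.
    assert(1 <= n * m && n * m <= 156);
    for (int i = 0; i < n; i++)
        for (int j = 0; j < m; j++)
            scanf("%d", &a[i][j]);
    // Transpose so that n is the longer dimension and m the shorter one.
    if (n < m) {
        for (int i = 0; i < m; i++)
            for (int j = 0; j < i; j++)
                swap(a[i][j], a[j][i]);
        swap(n, m);
    }
    for (int i = 0; i < n; i++)
        b[i] = 0;
    brute(0);
    printf("%d\n", ans);
}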
int main()
{
startTime = clock();
// freopen("input.txt", "r", stdin);
// freopen("output.txt", "w", stdout);
int t;
scanf("%d", &t);
while(t--) solve();
return 0;
}
Tester's Solution
import java.util.*;
import java.io.*;
class IMAT{
//SOLUTION BEGIN
void pre() throws Exception{}
void solve(int TC) throws Exception{
    int N = ni(), M = ni(), B = ni();
    int[][] A = new int[N][M];
    for(int i = 0; i< N; i++){
        for(int j = 0; j< M; j++){
            A[i][j] = ni();
        }
    }
    // Transpose so that N <= M.
    if(N > M){
        int[][] tmp = new int[M][N];
        for(int i = 0; i< N; i++)
            for(int j = 0; j< M; j++)
                tmp[j][i] = A[i][j];
        A = tmp;
        int t = N;
        N = M;
        M = t;
    }
    int MOD = (int)1e9+7, ans = 0;
    int[] ways = new int[1+B];
    for(int mask = 0; mask < 1<<N; mask++){
        Arrays.fill(ways, 0);
        // val[j] = contribution of column j for the rows selected in mask
        int[] val = new int[M];
        for(int i = 0; i< N; i++){
            for(int j = 0; j< M; j++){
                val[j] += ((mask>>i)&1)*A[i][j];
            }
        }
        // Estimated operation counts: w1 for knapsack, w2 for meet in the middle
        long w1 = B*M, w2 = M*(1L<<((M+1)/2));
        if(M > 40 || w1 < w2){
            // Knapsack DP: ways[v] = number of column subsets with sum v
            ways[0] = 1;
            for(int x:val){
                for(int v = B; v >= x; v--){
                    ways[v] += ways[v-x];
                    if(ways[v] >= MOD)ways[v] -= MOD;
                }
            }
            ans += ways[B];
            if(ans >= MOD)ans -= MOD;
        }else{
            // Meet in the middle: ways[s] = number of left-half subsets with sum s
            int left = M/2, right = M-left;
            for(int i = 0; i< 1<<left; i++){
                long sum = 0;
                for(int j = 0; j< left; j++)sum += val[j] * ((i>>j)&1);
                if(sum <= B)ways[(int)sum]++;
            }
            for(int i = 0; i< 1<<right; i++){
                long sum = 0;
                for(int j = 0; j< right; j++)sum += val[j+left] * ((i>>j)&1);
                if(sum <= B)ans += ways[B-(int)sum];
                if(ans >= MOD)ans -= MOD;
            }
        }
    }
    pn(ans);
}
//SOLUTION END
void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
static boolean multipleTC = true;
FastReader in;PrintWriter out;
void run() throws Exception{
in = new FastReader();
out = new PrintWriter(System.out);
//Solution Credits: Taranpreet Singh
int T = (multipleTC)?ni():1;
pre();for(int t = 1; t<= T; t++)solve(t);
out.flush();
out.close();
}
public static void main(String[] args) throws Exception{
new IMAT().run();
}
int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
void p(Object o){out.print(o);}
void pn(Object o){out.println(o);}
void pni(Object o){out.println(o);out.flush();}
String n()throws Exception{return in.next();}
String nln()throws Exception{return in.nextLine();}
int ni()throws Exception{return Integer.parseInt(in.next());}
long nl()throws Exception{return Long.parseLong(in.next());}
double nd()throws Exception{return Double.parseDouble(in.next());}
class FastReader{
BufferedReader br;
StringTokenizer st;
public FastReader(){
br = new BufferedReader(new InputStreamReader(System.in));
}
public FastReader(String s) throws Exception{
br = new BufferedReader(new FileReader(s));
}
String next() throws Exception{
while (st == null || !st.hasMoreElements()){
try{
st = new StringTokenizer(br.readLine());
}catch (IOException e){
throw new Exception(e.toString());
}
}
return st.nextToken();
}
String nextLine() throws Exception{
String str = "";
try{
str = br.readLine();
}catch (IOException e){
throw new Exception(e.toString());
}
return str;
}
}
}
Feel free to share your approach. Suggestions are welcomed as always.