# PROBLEM LINK:

Practice

Contest: Division 1

Contest: Division 2

Contest: Division 3

Contest: Division 4

**Author:** *mehrzad_minaei*

*Hriday, Utkarsh Gupta*

**Testers:** *Nishank Suresh*

**Editorialist:**

# DIFFICULTY:

2763

# PREREQUISITES:

Dynamic programming

# PROBLEM:

You are given integers N and M.

An N\times M table in which every cell is 0 or 1 is said to be *special* if, for each (i, j), A_{i, j} = 1 if and only if the sum of row i equals the sum of column j.

Find the minimum sum of values of a special table with N rows and M columns.
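To make the definition concrete, here is a minimal Python sketch of a checker that tests the condition cell by cell on a few small tables. `is_special` is an illustrative name and is not part of the original solutions.

```python
def is_special(grid):
    # A cell must be 1 exactly when its row sum equals its column sum.
    row_sums = [sum(row) for row in grid]
    col_sums = [sum(col) for col in zip(*grid)]
    return all(
        (grid[i][j] == 1) == (row_sums[i] == col_sums[j])
        for i in range(len(row_sums))
        for j in range(len(col_sums))
    )

# A 1x1 block of ones followed by a 2x2 block of ones.
example = [
    [1, 0, 0],
    [0, 1, 1],
    [0, 1, 1],
]
print(is_special(example))           # True (5 ones)
print(is_special([[1, 1, 1]] * 3))   # True (9 ones)
print(is_special([[0, 0], [0, 0]]))  # False: all sums are 0, so every cell should be 1
```

The last call shows that the all-zero table is never special, so any special table contains at least one 1.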

# EXPLANATION:

First off, what does a special table even look like?

Suppose the row sums are r_1, r_2, \ldots, r_N and the column sums are c_1, c_2, \ldots, c_M.

Say r_1 = k. What information does that give you?

## Answer

There are exactly k ones in the first row.

This means that there are exactly k columns whose sum is k.

First, let’s assume k \neq 0.

Applying the same argument to one of the columns with sum k tells us that there are also exactly k rows whose sum is k.

So we have k rows and k columns, such that:

- The intersection of any one of these rows and columns contains a 1
- All other cells in these rows/columns are 0

This gives us k^2 ones in the grid.

Notice that without loss of generality, we can assume r_1 = r_2 = \ldots = r_k = c_1 = c_2 = \ldots = c_k = k, then delete these rows and columns and apply the same argument to the rest of the grid.
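The block structure can be built explicitly. In this sketch, `block_table` is a hypothetical helper that places an s\times s block of ones on the diagonal for each size s in a list (it assumes the sizes add up to at most \min(n, m)). The final line shows why two blocks of the same size are not allowed: their rows and columns would all share one sum, so the off-block cells between them would also have to be 1.

```python
def block_table(sizes, n, m):
    # Hypothetical helper: an s x s block of ones for each s in `sizes`,
    # stacked along the main diagonal of an n x m table of zeros.
    # Assumes sum(sizes) <= min(n, m).
    grid = [[0] * m for _ in range(n)]
    start = 0
    for s in sizes:
        for i in range(start, start + s):
            for j in range(start, start + s):
                grid[i][j] = 1
        start += s
    return grid

def is_special(grid):
    # A cell must be 1 exactly when its row sum equals its column sum.
    rows = [sum(r) for r in grid]
    cols = [sum(c) for c in zip(*grid)]
    return all((grid[i][j] == 1) == (rows[i] == cols[j])
               for i in range(len(rows)) for j in range(len(cols)))

g = block_table([1, 2, 3], 6, 6)
print(is_special(g), sum(map(sum, g)))        # True 14  (1^2 + 2^2 + 3^2)
print(is_special(block_table([1, 1], 2, 2)))  # False: equal sizes merge into one sum value
```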

Now, what happens if k = 0?

Well, we still have exactly k columns with sum k, i.e., there are no columns with sum 0.

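Deleting this all-zero row leaves every column sum unchanged, so the remaining grid is still special. This means we can simply delete the first row and consider the remaining part of the grid instead.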

The above discussion in fact tells us the following:

- First, the analysis of k = 0 tells us that the grid can contain *either* empty rows or empty columns, but not both. So, we can safely delete empty rows/columns and solve for the reduced N or M appropriately.
- Second, let the *distinct* non-zero row (or column) sums be x_1 \lt x_2 \lt \ldots \lt x_d. Then, the grid contains exactly x_1^2 + x_2^2 + \ldots + x_d^2 ones.
- Finally, notice that the number of non-zero rows equals the number of non-zero columns (both equal x_1 + x_2 + \ldots + x_d). Combining this with the first point: if, say, N \gt M, there must be empty rows, so there are no empty columns; hence all M columns, and therefore exactly M rows, are non-zero. So it’s enough to deal with only square grids, i.e., solving for an N\times M grid is the same as solving for a \min(N, M)\times \min(N, M) grid.
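These observations can be sanity-checked by brute force on tiny tables. The sketch below (function names are illustrative, not from the original solutions) enumerates every 0/1 table with N, M \leq 4 and compares the minimum number of ones against the minimum of x_1^2 + \ldots + x_d^2 over sets of distinct positive integers adding up to \min(N, M). It is exponential in N\cdot M, so it is only a check, not a solution.

```python
from itertools import combinations, product

def is_special(grid):
    # A cell must be 1 exactly when its row sum equals its column sum.
    rows = [sum(r) for r in grid]
    cols = [sum(c) for c in zip(*grid)]
    return all((grid[i][j] == 1) == (rows[i] == cols[j])
               for i in range(len(rows)) for j in range(len(cols)))

def brute_min(n, m):
    # Try all 2^(n*m) tables; only feasible for tiny n*m.
    best = None
    for bits in product((0, 1), repeat=n * m):
        grid = [bits[i * m:(i + 1) * m] for i in range(n)]
        if is_special(grid) and (best is None or sum(bits) < best):
            best = sum(bits)
    return best

def partition_min(k):
    # Minimum sum of squares over sets of distinct positive integers with sum k.
    return min(sum(p * p for p in parts)
               for r in range(1, k + 1)
               for parts in combinations(range(1, k + 1), r)
               if sum(parts) == k)

for n in range(1, 5):
    for m in range(1, 5):
        assert brute_min(n, m) == partition_min(min(n, m)), (n, m)
print("brute force matches the distinct-sum formula for all N, M <= 4")
```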

Now let’s use the above observations to formulate a solution.

From now on, I’ll assume N = M, since we established that this is the only case that matters.

As in the second point above, let x_1, x_2, \ldots, x_d be the distinct row-sums of the grid.

Then, since exactly x_i rows have sum x_i and (with N = M) no row is empty, we must have x_1 + x_2 + \ldots + x_d = N.

So, our objective is to minimize x_1^2 + x_2^2 + \ldots + x_d^2 subject to the above constraint. Note that the x_i are all distinct here.
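As a worked example of this objective (our own illustration, not from the original editorial), take N = 6 and list every way to write 6 as a sum of distinct positive integers:

$$
\begin{aligned}
6 &\to 6^2 = 36\\
5 + 1 &\to 5^2 + 1^2 = 26\\
4 + 2 &\to 4^2 + 2^2 = 20\\
3 + 2 + 1 &\to 3^2 + 2^2 + 1^2 = 14
\end{aligned}
$$

The minimum is 14, attained by 1 + 2 + 3, i.e., one block each of sizes 1, 2 and 3.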

Since N \leq 5000, solving this in \mathcal{O}(N^2) is fast enough, and that is exactly what we will do, using dynamic programming.

Let dp(i, j) denote the minimum possible value of the sum of squares, such that:

- The sum of the elements chosen is i; and
- All chosen elements are distinct and belong to \{1, 2, \ldots, j\}.

Under this formulation, the answer is simply dp(N, N).

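Computing dp(i, j) is not hard: we either choose j or we don’t, so

$$dp(i, j) = \min\big(dp(i, j-1),\ dp(i-j, j-1) + j^2\big)$$

where the second option is only allowed when j \leq i. The base cases are dp(0, j) = 0 for every j (choose nothing) and dp(i, 0) = \infty for i \gt 0.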

This allows us to solve a single testcase in \mathcal{O}(N^2).
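A direct, per-testcase version of this DP might look like the Python sketch below. `min_special_sum` is an illustrative name; it first reduces to k = \min(N, M) and then fills the table with the choose-j-or-skip-j transition described above.

```python
INF = float("inf")

def min_special_sum(n, m):
    # dp[i][j]: min sum of squares of distinct values from {1..j} adding up to i.
    k = min(n, m)
    dp = [[INF] * (k + 1) for _ in range(k + 1)]
    for j in range(k + 1):
        dp[0][j] = 0  # sum 0: choose nothing
    for i in range(1, k + 1):
        for j in range(1, k + 1):
            dp[i][j] = dp[i][j - 1]  # skip value j
            if j <= i:  # take value j (at most once)
                dp[i][j] = min(dp[i][j], dp[i - j][j - 1] + j * j)
    return dp[k][k]

print(min_special_sum(6, 6))   # 14
print(min_special_sum(3, 10))  # 5
```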

However, there can be a large number of testcases, so running this \mathcal{O}(N^2) algorithm separately for each one is too slow.

Fortunately, the dp states themselves don’t change across testcases: the values computed are exactly the same every time.

So, simply precompute the dp table for all possible values before processing any testcase; then each testcase can be answered in \mathcal{O}(1) by looking up dp(\min(N, M), \min(N, M)).
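A minimal Python sketch of the precompute-then-lookup idea follows. It uses a one-dimensional (rolling) version of the table: dp(i, j) depends only on column j-1, so the j dimension can be dropped as long as sums are processed in decreasing order, which ensures each value j is used at most once. The names `precompute` and `answer` and the demonstration bound of 50 are assumptions for illustration; the real bound would be the maximum N from the constraints.

```python
def precompute(limit):
    # best[s]: min sum of squares of distinct positive integers adding up to s.
    INF = float("inf")
    best = [0] + [INF] * limit
    for j in range(1, limit + 1):
        # Go through sums downward so value j is used at most once (0/1 knapsack).
        for s in range(limit, j - 1, -1):
            if best[s - j] + j * j < best[s]:
                best[s] = best[s - j] + j * j
    return best

BEST = precompute(50)  # demonstration bound; assumes min(N, M) <= 50

def answer(n, m):
    return BEST[min(n, m)]  # O(1) per test case

print([answer(k, k) for k in range(1, 8)])  # [1, 4, 5, 10, 13, 14, 21]
```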

# TIME COMPLEXITY

\mathcal{O}(N^2) precomputation followed by \mathcal{O}(1) per test case.

# CODE:

## Setter's code (C++)

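```cpp
#include <bits/stdc++.h>
using namespace std;
const int maxn = 5000 + 10;
// dp[i][j]: minimum sum of squares of distinct values from {1, ..., j} adding up to i
int dp[maxn][maxn];
int main () {
    ios_base::sync_with_stdio(false), cin.tie(0);
    memset(dp, 63, sizeof(dp)); // every byte 0x3f: a large "infinity"
    for (int i = 0; i < maxn; i++) {
        dp[0][i] = 0; // sum 0 is reached by choosing nothing
    }
    for (int i = 1; i < maxn; i++) {
        for (int j = 1; j < maxn; j++) {
            dp[i][j] = dp[i][j-1]; // don't use value j
            if (j <= i) dp[i][j] = min(dp[i][j], dp[i-j][j-1] + j*j); // use value j once
        }
    }
    int t;
    cin >> t;
    while (t--) {
        int n, m;
        cin >> n >> m;
        n = min(n, m); // only the min(N, M) x min(N, M) case matters
        cout << dp[n][n];
        if (t) cout << "\n";
    }
    return 0;
}
```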

## Editorialist's code (Python)

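```python
inf = 10 ** 18
maxn = 5005
dp = [inf]*(maxn+1)   # dp[s]: min sum of squares of distinct values seen so far adding up to s
ans = [0]*(maxn+1)    # ans[k]: answer for a k x k table
dp[0] = 0
for i in range(1, maxn+1):
    # 0/1 knapsack: go through sums in decreasing order so value i is used at most once
    for j in reversed(range(i, maxn+1)):
        dp[j] = min(dp[j], dp[j-i] + i*i)
    # values larger than i cannot appear in a sum equal to i, so dp[i] is final here
    ans[i] = dp[i]
for _ in range(int(input())):
    n, m = map(int, input().split())
    print(ans[min(n, m)])
```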