PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: gunpoint_88
Tester: raysh07
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
Prefix sums (1D and 2D).
PROBLEM:
Given an N\times M binary grid, find the maximum weight of a rectangular subgrid whose borders are all ones.
The weight of a rectangular subgrid equals the number of ones strictly inside it.
EXPLANATION:
Without loss of generality, let N\leq M (if not, transpose the grid).
Call a rectangular subgrid whose borders are all ones a *bag*. Let’s fix rows i and j (1 \leq i \lt j \leq N) and consider all bags that have these rows as their top and bottom borders.
Once the rows are fixed, note that the only ‘useful’ columns are those which are filled with ones between these rows — no other column can be the left or right border of such a bag.
So, for each column from 1 to M, check whether it’s useful or not - this can be done in \mathcal{O}(1) using prefix sums, by finding the sum of the column between those rows and checking if it equals the required length.
Finally, we need to find the heaviest bag whose left and right borders are two of these useful columns.
Suppose we fix column R to be the right border of the bag, and look to find the optimal left column L.
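Moving L further left only adds cells to the interior of the bag, so its weight never decreases; hence it’s best to choose L as small as possible. However, we also need to ensure that the top and bottom borders are filled with 1's.
To do this quickly, note that the columns where rows i and j both contain a 1 form maximal contiguous segments, and the top and bottom borders between L and R are all ones exactly when L and R lie in the same segment. So the best choice is the leftmost useful column L belonging to the same segment as R.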
Finding these segments is quite easy: iterate over the columns from left to right and check whether both rows i and j contain a one in the current column — if they do, extend the current segment; otherwise end it.
Once L and R are fixed, we have all the borders of the bag.
All that remains is to find its weight, i.e., the number of ones inside it.
This is just the sum of a rectangle, and can be done in \mathcal{O}(1) using 2D prefix sums.
The final answer is the maximum weight across all the bags we considered.
The time complexity of this solution is \mathcal{O}(N^2 M), since we fixed two rows and then iterated all columns.
However, recall that we initially assumed N \leq M, so N^2 \leq N\cdot M \leq 10^5, which means that N \leq \sqrt{10^5} \approx 316.
So, the complexity is bounded by \mathcal{O}(NM\sqrt{10^5}), which is good enough.
In particular, we’ll always have \min(N, M) \leq \sqrt{10^5}, so transposing the grid if necessary for a complexity of \mathcal{O}(NM\min(N, M)) is always fast.
TIME COMPLEXITY
\mathcal{O}(N M \min(N, M)) per testcase.
CODE:
Author's code (C++)
#include<bits/stdc++.h>
using namespace std;
// Large sentinel value ("infinity"); not referenced directly by the code below.
const int inf=1e9;
// Debug hook: with ANI defined, a local debugging header is included; otherwise
// dbg(...) expands to the no-op expression 0.
#ifdef ANI
#include "D:/DUSTBIN/local_inc.h"
#else
#define dbg(...) 0
#endif
// Returns g rotated 90 degrees clockwise: cell (i, j) of the n x m input becomes
// cell (j, n-1-i) of the m x n result. Requires a non-empty grid.
vector<string> rotate(const vector<string>& g) {
    int n=g.size(),m=g[0].size();
    vector<string> res(m,string(n,'0'));
    for(int i=0;i<n;i++)
        for(int j=0;j<m;j++)
            res[j][n-1-i]=g[i][j];
    return res;
}

// Maximum number of ones strictly inside an axis-aligned rectangle of g whose whole
// border consists of ones (0 if no such rectangle has a non-empty interior).
//
// Rows are swept top to bottom. For a column pair j < k, dp[j][k] is the best interior
// count of an "inverted cap" ending at the previous row: an earlier row that is all
// ones on [j, k] (the top border), columns j and k all ones from there downwards, plus
// the ones strictly between j and k in the rows below the top. A row that is all ones
// on [j, k] closes such a cap into a rectangle (bottom border) and can also open a new one.
int solution(vector<string> g) {
    if (g.empty() || g[0].empty()) return 0;
    // Make the column count the smaller dimension: the state table is m x m and the
    // total work is O(n * m^2).
    if (g[0].size() > g.size()) g = rotate(g);
    const int n = g.size(), m = g[0].size();
    // "No open cap" marker. It must stay negative after adding every one in the grid;
    // this assumes n*m is far below 1e9.
    const int NEG = -1e9;
    vector<vector<int>> dp(m, vector<int>(m, NEG));
    vector<vector<int>> ndp(m, vector<int>(m, NEG));
    vector<int> p(m + 1, 0);  // p[x] = number of ones among g[i][0..x-1]
    int ans = 0;
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < m; j++)
            p[j + 1] = p[j] + (g[i][j] == '1');
        for (auto &row : ndp) fill(row.begin(), row.end(), NEG);
        for (int j = 0; j < m; j++) {
            // Both transitions below need a one in column j of this row.
            if (g[i][j] != '1') continue;
            for (int k = j + 1; k < m; k++) {
                if (g[i][k] == '1') {
                    // Extend both walls down one row; count the ones strictly between them.
                    ndp[j][k] = dp[j][k] + p[k + 1] - p[j] - 2;
                }
                if (p[k + 1] - p[j] == k - j + 1) {
                    // Row i is all ones on [j, k]: it may be the top of a new cap...
                    ndp[j][k] = max(ndp[j][k], 0);
                    // ...or the bottom border closing the cap that ended on row i-1.
                    ans = max(ans, dp[j][k]);
                }
            }
        }
        // ndp becomes the state for the next row; the old dp buffer is reused.
        swap(dp, ndp);
    }
    return ans;
}
// Reads one test case (the grid dimensions followed by one binary string per row)
// and prints its answer on its own line.
void solve() {
    int rows, cols;
    cin >> rows >> cols;
    vector<string> grid(rows);
    for (auto &line : grid)
        cin >> line;
    cout << solution(grid) << '\n';
}
// Entry point: reads the number of test cases and solves each one.
int main() {
    // The judge build does all I/O through cin/cout (dbg is a no-op there), so the
    // streams can be decoupled from C stdio; this matters when there are many test cases.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int t=1;
    cin>>t;
    while(t--) {
        solve();
    }
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
//#define int long long
// NOTE(review): 1e18 is far outside the range of a 32-bit int, so (int)1e18 is undefined
// behaviour unless the `#define int long long` line above is enabled. INF is not used in
// the visible code.
#define INF (int)1e18
// Short aliases for pair members. NOTE(review): these rewrite every later identifier
// spelled exactly `f` or `s`, so such names must not be used in this file.
#define f first
#define s second
// 64-bit Mersenne Twister seeded from the steady clock; not used in the visible code.
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());
// Strict input validator: holds the whole input in memory and hands it out token by
// token, asserting that every token, separator and bound matches the expected format.
struct input_checker {
    string buffer;  // the complete input text
    int pos;        // index of the next unread character in buffer
    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    // Reads all of standard input into buffer.
    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    // Validates the given text instead of standard input (e.g. to test the checker itself).
    explicit input_checker(string data) : buffer(std::move(data)), pos(0) {}

    // True iff tok is an optional '-' followed by one or more decimal digits and nothing
    // else. stoi/stoll alone would silently accept tokens such as "+5", "12abc" or "7\r".
    static bool isIntegerToken(const string &tok) {
        size_t first = (!tok.empty() && tok[0] == '-') ? 1 : 0;
        if (first == tok.size()) {
            return false;
        }
        for (size_t i = first; i < tok.size(); i++) {
            if (tok[i] < '0' || tok[i] > '9') {
                return false;
            }
        }
        return true;
    }

    // Index of the first ' ' or '\n' at or after pos (buffer.size() if there is none).
    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }

    // Returns the characters from pos up to the next delimiter (possibly empty) and
    // advances pos to that delimiter.
    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    // Reads a token whose length is in [minl, maxl] and, if pattern is non-empty,
    // whose characters all occur in pattern.
    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    // Reads a well-formed integer token in [minv, maxv].
    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        string tok = readOne();
        assert(isIntegerToken(tok));
        int res = stoi(tok);
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    // Reads a well-formed 64-bit integer token in [minv, maxv].
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        string tok = readOne();
        assert(isIntegerToken(tok));
        long long res = stoll(tok);
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    // Reads n integers in [minv, maxv] separated by single spaces.
    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    // Reads n 64-bit integers in [minv, maxv] separated by single spaces.
    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    // Consumes exactly one ' '.
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    // Consumes exactly one '\n'.
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    // Asserts that the entire input has been consumed.
    void readEof() {
        assert((int) buffer.size() == pos);
    }
};
// Global validator. Its default constructor reads all of standard input, and in practice
// this happens during dynamic initialization, i.e. before main() starts.
input_checker inp;
// Upper limit for n, m and for the sum of n*m over all test cases.
const int N = 1e5;
// Running total of n*m over the test cases processed so far. It is a long long so the
// accumulation itself cannot overflow, even on malformed input.
long long sum_nm = 0;
// Solves one test case and prints the answer: the maximum number of ones strictly
// inside a rectangle whose entire border consists of ones (0 if there is none).
//
// The pair of border columns (j1, j2) is enumerated explicitly, which costs O(m^2 * n),
// so the grid is transposed first when m > 700 to keep that term small.
void Solve(int n, int m, vector<string> a)
{
    sum_nm += 1LL * n * m;
    // Fail fast on oversized input rather than only after every test has been solved.
    assert(sum_nm <= N);
    if (m > 700){
        vector <string> b(m);
        for (int i = 0; i < n; i++){
            for (int j = 0; j < m; j++){
                b[j] += a[i][j];
            }
        }
        swap(n, m);
        swap(a, b);
    }
    // p1[i][j]: ones in row i, columns 1..j.  pr[i][j]: ones in rows 1..i, columns 1..j.
    vector<vector<int>> p1(n + 1, vector<int>(m + 1, 0));
    vector<vector<int>> pr(n + 1, vector<int>(m + 1, 0));
    for (int i = 1; i <= n; i++){
        for (int j = 1; j <= m; j++){
            p1[i][j] = p1[i][j - 1] + (a[i - 1][j - 1] == '1');
            pr[i][j] = pr[i][j - 1] + pr[i - 1][j] - pr[i - 1][j - 1] + (a[i - 1][j - 1] == '1');
        }
    }
    // Ones in rows r1..r2 and columns c1..c2 (1-indexed, inclusive); 0 for an empty range.
    auto query = [&](int r1, int r2, int c1, int c2){
        if (r1 > r2 || c1 > c2) return 0;
        return pr[r2][c2] - pr[r1 - 1][c2] - pr[r2][c1 - 1] + pr[r1 - 1][c1 - 1];
    };
    int ans = 0;
    vector <bool> good(n + 1);
    for (int j1 = 1; j1 <= m; j1++){
        for (int j2 = j1; j2 <= m; j2++){
            // good[i]: row i is all ones on columns j1..j2, so it can be a top/bottom border.
            for (int i = 1; i <= n; i++){
                good[i] = p1[i][j2] - p1[i][j1 - 1] == j2 - j1 + 1;
            }
            // For the current top row i: run1/run2 are the 0-based indices of the first zero
            // at or below row i in columns j1/j2 (n if none), so the 1-based rows
            // i..min(run1, run2) have ones in both border columns. `scan` is how far good[]
            // has been examined, and mx is the last good row seen.
            int mx = 0;
            int run1 = 0, run2 = 0, scan = 0;
            for (int i = 1; i <= n; i++){
                run1 = max(run1, i - 1);
                run2 = max(run2, i - 1);
                scan = min(run1, run2);
                while (run1 != n && a[run1][j1 - 1] == '1') run1++;
                while (run2 != n && a[run2][j2 - 1] == '1') run2++;
                while (scan < min(run1, run2)){
                    scan++;
                    if (good[scan]) mx = scan;
                }
                if (good[i])
                    ans = max(ans, query(i + 1, mx - 1, j1 + 1, j2 - 1));
            }
        }
    }
    cout << ans << "\n";
}
// Validates and solves every test case: first T, then for each case a line "n m"
// followed by n binary strings of length m. Afterwards it checks that the input ends
// cleanly and that the total n*m respects the limit, and reports the run time on stderr.
int32_t main()
{
    const auto startTime = std::chrono::high_resolution_clock::now();
    // NOTE(review): the global `inp` has already read stdin by this point, and the standard
    // makes sync_with_stdio's effect implementation-defined once I/O has occurred.
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    const int testCount = inp.readInt(1, 100000);
    inp.readEoln();
    for (int tc = 0; tc < testCount; ++tc)
    {
        const int n = inp.readInt(1, N);
        inp.readSpace();
        const int m = inp.readInt(1, N);
        inp.readEoln();
        vector<string> rows(n);
        for (auto &row : rows)
        {
            row = inp.readString(m, m, "01");
            inp.readEoln();
        }
        Solve(n, m, rows);
    }
    inp.readEof();
    assert(sum_nm <= N);
    const auto stopTime = std::chrono::high_resolution_clock::now();
    const auto elapsedNs = std::chrono::duration_cast<std::chrono::nanoseconds>(stopTime - startTime);
    cerr << "Time measured: " << elapsedNs.count() * 1e-9 << " seconds.\n";
    return 0;
}
Editorialist's code (Python)
# Editorialist's solution. For every pair of rows (top, bot), scan the columns left to
# right. A column can only be part of a bag if both border rows have a one there. Within
# each maximal run of such columns, the leftmost "useful" column (all ones between top
# and bot) is the best left border for every later useful column in the same run.
for _ in range(int(input())):
    n, m = map(int, input().split())
    # Row 0 and column 0 are zero padding, so 1-indexed prefix sums need no bounds checks.
    grid = ['0' * (m + 1)]
    grid.extend(list('0' + input()) for _row in range(n))
    if n > m:
        # Transpose so that the number of rows is the smaller dimension.
        grid = [list(column) for column in zip(*grid)]
    n, m = len(grid), len(grid[0])
    ans = 0

    # pref[x][y] = number of ones in grid[1..x][1..y]
    pref = [[int(row[c]) for c in range(m)] for row in grid]
    for x in range(1, n):
        above, here = pref[x - 1], pref[x]
        for y in range(1, m):
            here[y] = here[y] + above[y] + here[y - 1] - above[y - 1]

    def rect(x1, y1, x2, y2):
        """Number of ones in grid[x1..x2][y1..y2] (inclusive); 0 for an empty range."""
        if x1 > x2 or y1 > y2:
            return 0
        return pref[x2][y2] - pref[x1 - 1][y2] - pref[x2][y1 - 1] + pref[x1 - 1][y1 - 1]

    for top in range(n):
        for bot in range(top + 1, n):
            left = -1  # leftmost useful column of the current run, or -1 if none yet
            for col in range(m):
                if grid[top][col] == '0' or grid[bot][col] == '0':
                    left = -1  # the top or bottom border breaks here: start a new run
                    continue
                if rect(top, col, bot, col) != bot - top + 1:
                    continue  # column has a zero between the border rows
                if left == -1:
                    left = col
                else:
                    ans = max(ans, rect(top + 1, left + 1, bot - 1, col - 1))
    print(ans)