HOLLOW - Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Practice

Setter:
Tester and Editorialist: Taranpreet Singh

DIFFICULTY

Medium

PREREQUISITES

Binary Search, two-dimensional prefix sums

PROBLEM

Given a grid with N rows and M columns filled with zeros and ones, find the largest side length of a square sub-grid that can be filled completely with zeros using at most K operations, where each operation swaps the values of any two different cells.

QUICK EXPLANATION

  • We need to find the largest square containing at most K ones.
  • Counting the number of ones in any sub-square can be sped up using two-dimensional prefix sums.
  • Feasibility is monotone in the side length (every smaller square inside a feasible square is also feasible), so we can binary search over the side of the largest square.
  • The only edge case is when the grid does not contain enough zeros to form a square of side S. It is easily handled by capping the upper bound of the search at \lfloor \sqrt{Z} \rfloor, where Z denotes the number of zeros in the grid.

EXPLANATION

To fill a square of side S with zeros, every one inside it must be swapped with a zero outside it, and each operation moves at most one 1 out of the square. Hence, apart from the edge case discussed below, we are looking for the largest square with at most K ones.

Let’s discuss the solutions one by one, improving them incrementally until we reach the optimal solution.

Approach 1

One solution is to fix the top-left and bottom-right corners and count the number of ones directly. This takes O((N*M)^3) time: there are O((N*M)^2) pairs of cells, and for each pair we may scan O(N*M) cells. This approach will time out.
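For small grids, a naive checker in the spirit of Approach 1 is handy for stress-testing faster solutions. The sketch below is an assumption-laden variant: it enumerates only squares (not every pair of corners) and counts ones by scanning each square, and `bruteForceHollow` is a hypothetical name. It assumes a non-empty grid given as strings of '0'/'1', and it already applies the zero-count condition from the Edge case section below.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical naive reference solver (not from the editorial).
// Returns the largest s such that some s x s square has at most k ones
// and the whole grid has at least s * s zeros. Only suitable for tiny grids.
int bruteForceHollow(const std::vector<std::string>& grid, int k) {
    assert(!grid.empty());
    int n = grid.size(), m = grid[0].size();
    int zeros = 0;
    for (const auto& row : grid)
        zeros += static_cast<int>(std::count(row.begin(), row.end(), '0'));
    int best = 0;
    for (int r = 0; r < n; ++r)
        for (int c = 0; c < m; ++c)
            for (int s = 1; r + s <= n && c + s <= m; ++s) {
                int ones = 0;  // count ones by scanning the whole square
                for (int i = r; i < r + s; ++i)
                    for (int j = c; j < c + s; ++j)
                        ones += grid[i][j] == '1';
                if (ones <= k && s * s <= zeros) best = std::max(best, s);
            }
    return best;
}
```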

Optimization 1

For each pair of cells, counting the number of ones takes a lot of time. So we want a faster way to count the ones in a sub-grid, preferably in O(1).

We have a grid with no updates (cell values never change), and we want to answer queries of the form "how many ones are in this sub-grid?" quickly.

Let f(r, c) denote the number of ones in the sub-grid with top-left corner (1, 1) and bottom-right corner (r, c), with f(0, c) = f(r, 0) = 0. For now, assume every value of f is already computed and available in O(1) time.

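Consider the sub-grid with top-left corner (r1, c1) and bottom-right corner (r2, c2). The rectangle from (1, 1) to (r2, c2) splits into four regions:

  • region 1: rows 1 to r1-1, columns 1 to c1-1
  • region 2: rows 1 to r1-1, columns c1 to c2
  • region 3: rows r1 to r2, columns 1 to c1-1
  • region 4: rows r1 to r2, columns c1 to c2 (our sub-grid)

We want to compute the number of ones in region 4.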

We can see that

  • f(r1-1, c1-1) is the number of ones in region 1
  • f(r1-1, c2) is the number of ones in region 1 and region 2.
  • f(r2, c1-1) is the number of ones in region 1 and region 3.
  • f(r2, c2) is the number of ones in regions 1, 2, 3 and 4.

Note: One is subtracted from r1 and c1 because cell (r1, c1) itself belongs to region 4.
So we can try to express the number of ones in region 4 as a combination of these four values. With basic set theory (inclusion-exclusion), you can convince yourself that

Number of ones in region 4 = f(r2, c2) + f(r1-1, c1-1) - f(r1-1, c2) - f(r2, c1-1). It’s a good exercise to prove why. Hint: subtracting f(r1-1, c2) and f(r2, c1-1) from f(r2, c2) removes regions 2 and 3, but removes region 1 twice, so f(r1-1, c1-1) must be added back once.

Hence, if f is pre-computed for every cell, we can count the number of ones in any sub-grid in O(1), which improves Approach 1 to O((N*M)^2) time.

Computing f is easy: f(r, c) = f(r-1, c) + f(r, c-1) - f(r-1, c-1) + G_{r, c}, where G_{r, c} denotes the value in cell (r, c). This recurrence follows from the formula above with r1 = r2 = r and c1 = c2 = c: region 4 is then the single cell (r, c), so G_{r, c} = f(r, c) + f(r-1, c-1) - f(r-1, c) - f(r, c-1).

The above data structure is known as two-dimensional prefix sums (also called a summed-area table).
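The recurrence and the region-4 formula can be sketched as a small C++ struct. This is a minimal illustration, not the setter's or tester's code: the name `PrefixSum2D` is hypothetical, the grid is assumed to be given as strings of '0'/'1', and indices are 1-based with row 0 and column 0 left as zero, exactly like f(0, c) = f(r, 0) = 0 above.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical helper: 1-indexed 2D prefix sums of ones.
// f[r][c] = number of ones in the sub-grid with corners (1, 1) and (r, c).
struct PrefixSum2D {
    std::vector<std::vector<int>> f;

    explicit PrefixSum2D(const std::vector<std::string>& grid) {
        int n = grid.size(), m = n ? grid[0].size() : 0;
        f.assign(n + 1, std::vector<int>(m + 1, 0));
        for (int r = 1; r <= n; ++r)
            for (int c = 1; c <= m; ++c)
                // f(r, c) = f(r-1, c) + f(r, c-1) - f(r-1, c-1) + G(r, c)
                f[r][c] = f[r - 1][c] + f[r][c - 1] - f[r - 1][c - 1]
                        + (grid[r - 1][c - 1] == '1');
    }

    // Ones in region 4, i.e. the sub-grid with top-left (r1, c1) and
    // bottom-right (r2, c2), both inclusive and 1-indexed.
    int query(int r1, int c1, int r2, int c2) const {
        assert(1 <= r1 && r1 <= r2 && 1 <= c1 && c1 <= c2);
        return f[r2][c2] - f[r1 - 1][c2] - f[r2][c1 - 1] + f[r1 - 1][c1 - 1];
    }
};
```

Each query touches four table entries, so it runs in O(1) after the O(N*M) build.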

Approach 2

Since we only care about squares, let’s try every side length S from 1 to min(N, M). For a given S, we need to check whether some square of side S contains at most K ones. To do this, iterate over all candidate top-left cells and count the ones in the square of side S with that cell as its top-left corner. The answer is the largest S for which at least one such square exists.

Using prefix sums, the time complexity of this approach is O(min(N, M)*N*M)
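Approach 2 can be sketched as follows, assuming a non-empty grid of '0'/'1' strings; `buildPrefix`, `minOnes` and `largestSquareLinear` are hypothetical names (`minOnes` mirrors the idea of the tester's helper of the same name). Like the text, it ignores the "not enough zeros" edge case, so on the grid from the Edge case section it returns 2 even though the correct answer there is 1.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <string>
#include <vector>

using Table = std::vector<std::vector<int>>;

// 1-indexed prefix sums of ones (row 0 and column 0 stay zero).
Table buildPrefix(const std::vector<std::string>& g) {
    int n = g.size(), m = g[0].size();
    Table f(n + 1, std::vector<int>(m + 1, 0));
    for (int r = 1; r <= n; ++r)
        for (int c = 1; c <= m; ++c)
            f[r][c] = f[r-1][c] + f[r][c-1] - f[r-1][c-1] + (g[r-1][c-1] == '1');
    return f;
}

// Minimum number of ones over all s x s squares; (r, c) is the bottom-right corner.
int minOnes(const Table& f, int s) {
    int n = f.size() - 1, m = f[0].size() - 1, best = INT_MAX;
    for (int r = s; r <= n; ++r)
        for (int c = s; c <= m; ++c)
            best = std::min(best, f[r][c] - f[r-s][c] - f[r][c-s] + f[r-s][c-s]);
    return best;
}

// Approach 2: try every side length, O(min(N, M) * N * M) overall.
int largestSquareLinear(const std::vector<std::string>& g, int k) {
    assert(!g.empty());
    Table f = buildPrefix(g);
    int n = g.size(), m = g[0].size(), ans = 0;
    for (int s = 1; s <= std::min(n, m); ++s)
        if (minOnes(f, s) <= k) ans = s;
    return ans;
}
```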

Optimization 2

In the above approach, we try all square sizes. Suppose that for some s there is a square with top-left cell (r, c) and side s containing at most K ones. Then for every s' \leq s, the square with top-left cell (r, c) and side s' lies inside it, so it also contains at most K ones.

Hence, there exists a side length S such that for all sides 1 \leq s \leq S, there’s at least one square with up to K ones, and for all S < s \leq min(N, M), there doesn’t exist any square of size s with up to K ones.

Let’s denote g(s) = 1 if a sub-square with side s and up to K ones exist, otherwise g(s) = 0

Hence g is monotonic (non-increasing in s), which allows us to binary search for the largest s with g(s) = 1. This largest s is exactly the side length S we are looking for.

Time complexity of this approach is O(log_2(min(N, M)) * N*M)
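A minimal sketch of the binary search over g, assuming a non-empty grid of '0'/'1' strings; the function names are hypothetical. It uses the same "round mid up" pattern as the tester's loop (`lo < hi`, `lo = mid`, `hi = mid - 1`) and, like the text at this point, does not yet apply the zero-count bound from the Edge case section.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// g(s): does some s x s square contain at most k ones?
// f holds 1-indexed 2D prefix sums of ones, so each square costs O(1).
bool g(const std::vector<std::vector<int>>& f, int s, int k) {
    int n = f.size() - 1, m = f[0].size() - 1;
    for (int r = s; r <= n; ++r)
        for (int c = s; c <= m; ++c)
            if (f[r][c] - f[r-s][c] - f[r][c-s] + f[r-s][c-s] <= k) return true;
    return false;
}

// Largest s in [0, min(N, M)] with g(s) = 1; g(0) is treated as true.
// Relies on g being non-increasing. O(N * M * log(min(N, M))).
int largestSquareBinary(const std::vector<std::string>& grid, int k) {
    assert(!grid.empty());
    int n = grid.size(), m = grid[0].size();
    std::vector<std::vector<int>> f(n + 1, std::vector<int>(m + 1, 0));
    for (int r = 1; r <= n; ++r)
        for (int c = 1; c <= m; ++c)
            f[r][c] = f[r-1][c] + f[r][c-1] - f[r-1][c-1] + (grid[r-1][c-1] == '1');
    int lo = 0, hi = std::min(n, m);
    while (lo < hi) {
        int mid = lo + (hi - lo + 1) / 2;  // round up so "lo = mid" always makes progress
        if (g(f, mid, k)) lo = mid;
        else hi = mid - 1;
    }
    return lo;
}
```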

Edge case

You thought it was over, didn’t you?

What if there aren’t enough zeros in the grid to make a square of side S? Consider the following case (N = 3, M = 3, K = 1):

3 3 1
101
001
111

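Even though there is a square of side 2 with at most K ones (the top-left one contains a single one), we cannot fill it with zeros, since the grid contains only three zeros.

Hence, the final square side is bounded by \lfloor \sqrt{Z} \rfloor, where Z is the number of zeros in the grid. Capping the upper end of the binary search at this value keeps the predicate monotone.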
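A small sketch of computing this cap with integer arithmetic, assuming Z fits comfortably in a `long long` (here Z is at most N*M). The names `floorSqrt` and `searchUpperBound` are hypothetical; the tester achieves the same effect with `while(hi*hi > count)hi--;`. The floating-point estimate is corrected in both directions so rounding in `std::sqrt` cannot cause an off-by-one. For the example above, Z = 3 gives a cap of 1.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// Largest integer s with s * s <= z, for 0 <= z well below LLONG_MAX.
long long floorSqrt(long long z) {
    assert(z >= 0);
    long long s = static_cast<long long>(std::sqrt(static_cast<double>(z)));
    while (s * s > z) --s;                 // estimate was too large
    while ((s + 1) * (s + 1) <= z) ++s;    // estimate was too small
    return s;
}

// Upper end of the binary search: side <= min(N, M) and side * side <= Z.
int searchUpperBound(int n, int m, long long zeros) {
    return static_cast<int>(std::min<long long>(floorSqrt(zeros), std::min(n, m)));
}
```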

Follow up

Can you solve this problem in three dimensions? Think about prefix sums over a cuboid.

Hint

For n dimensions, there would be 2^n terms. Have fun figuring out the terms.
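A minimal 3D sketch for the follow-up, under stated assumptions: the input is a non-empty box of 0/1 ints stored as nested vectors, and `buildPrefix3D` and `boxSum` are hypothetical names. The box query uses the 2^3 = 8 inclusion-exclusion terms from the hint, with a term's sign flipping each time a lower bound replaces an upper bound.

```cpp
#include <cassert>
#include <vector>

using Cube = std::vector<std::vector<std::vector<int>>>;

// 1-indexed 3D prefix sums: P[x][y][z] = sum of a over [1..x] x [1..y] x [1..z].
Cube buildPrefix3D(const Cube& a) {
    assert(!a.empty() && !a[0].empty() && !a[0][0].empty());
    int nx = a.size(), ny = a[0].size(), nz = a[0][0].size();
    Cube P(nx + 1, std::vector<std::vector<int>>(ny + 1, std::vector<int>(nz + 1, 0)));
    for (int x = 1; x <= nx; ++x)
        for (int y = 1; y <= ny; ++y)
            for (int z = 1; z <= nz; ++z)
                P[x][y][z] = a[x-1][y-1][z-1]
                           + P[x-1][y][z] + P[x][y-1][z] + P[x][y][z-1]
                           - P[x-1][y-1][z] - P[x-1][y][z-1] - P[x][y-1][z-1]
                           + P[x-1][y-1][z-1];
    return P;
}

// Sum over the box [x1..x2] x [y1..y2] x [z1..z2], 1-indexed and inclusive.
int boxSum(const Cube& P, int x1, int y1, int z1, int x2, int y2, int z2) {
    assert(1 <= x1 && x1 <= x2 && 1 <= y1 && y1 <= y2 && 1 <= z1 && z1 <= z2);
    --x1; --y1; --z1;  // switch to the exclusive lower corner
    return P[x2][y2][z2]
         - P[x1][y2][z2] - P[x2][y1][z2] - P[x2][y2][z1]
         + P[x1][y1][z2] + P[x1][y2][z1] + P[x2][y1][z1]
         - P[x1][y1][z1];
}
```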

TIME COMPLEXITY

The time complexity is O(N*M*log_2(min(N, M))) per test case.
The space complexity is O(N*M) per test case.

SOLUTIONS

Setter's Solution
#include <bits/stdc++.h> 

using namespace std;

const int maxn = 1e3, maxm = 1e3;
const int minv = 0, maxv = 1;

int main()
{   
    int t; cin >> t;

    while(t--){
        int n, m, k; cin >> n >> m >> k;
        
        // Prefix sums of ones; a vector avoids a large variable-length array on the stack.
        vector<vector<int>> dp(n + 1, vector<int>(m + 1, 0));
        for(int i = 1; i <= n; i++){
            string s; cin >> s;
            for(int j = 1; j <= m; j++){
                int val = s[j - 1] - '0';
                dp[i][j] = dp[i - 1][j] + dp[i][j - 1] - dp[i - 1][j - 1] + val;
            }
        }
        int t0 = n * m - dp[n][m];
        int ans = 0;
        for(int i = 1; i <= n; i++){
            for(int j = 1; j <= m; j++){
                int l = 0, r = min(n - i, m - j);
                while(l <= r){
                    int mid = (l + r) >> 1;
                    int one = dp[i + mid][j + mid] - dp[i - 1][j + mid] - dp[i + mid][j - 1] + dp[i - 1][j - 1];
                    int zero = (mid + 1) * (mid + 1) - one;
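                    // Feasible iff the square has at most k ones, and the zeros outside it
                    // (t0 - zero) can replace every one inside it ((mid + 1)^2 - zero).
                    if(zero + k >= (mid + 1) * (mid + 1) && t0 - zero >= (mid + 1) * (mid + 1) - zero){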
                        ans = max(ans, mid + 1); l = mid + 1;
                    }else{
                        r = mid - 1;
                    }
                }
            }
        }
        cout << ans << endl;
    }
} 
Tester/Editorialist's Solution
import java.util.*;
import java.io.*;
class HOLLOW{
    //SOLUTION BEGIN
    void pre() throws Exception{}
    void solve(int TC) throws Exception{
        int N = ni(), M = ni(), K = ni();
        int[][] grid = new int[1+N][1+M], pref = new int[1+N][1+M];
        int count = 0;
        for(int i = 1; i<= N; i++){
            String S = " "+n();
            for(int j = 1; j<= M; j++){
                grid[i][j] = S.charAt(j)-'0';
                pref[i][j] = grid[i][j] + pref[i-1][j]+pref[i][j-1]-pref[i-1][j-1];
                count += 1-grid[i][j];
            }
        }
        int lo = 0, hi = Math.min(N, M);
        while(hi*hi > count)hi--;//Taking care of edge case
        while(lo < hi){
            int mid = lo+(hi-lo+1)/2;
            int ones = minOnes(N, M, pref, mid);
            if(ones <= K)lo = mid;
            else hi = mid-1;
        }
        pn(lo);
    }
    int minOnes(int N, int M, int[][] pref, int size){
        int count = Integer.MAX_VALUE;
        for(int i = size; i<= N; i++){
            for(int j = size; j <= M; j++){
                count = Math.min(count, pref[i][j] + pref[i-size][j-size] - pref[i-size][j] - pref[i][j-size]);
            }
        }
        return count;
    }
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    static boolean multipleTC = true;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        new HOLLOW().run();
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}

    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }

        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }

        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException  e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }

        String nextLine() throws Exception{
            String str = "";
            try{   
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }  
            return str;
        }
    }
}

Feel free to share your approach. Suggestions are welcome as always.


I always screw up when implementing binary search. Can anyone please tell me what I did wrong this time? This one doesn’t even pass the samples. What changes should I make to get it working? And how do you guys fix off-by-one issues in binary search problems? Even though I had 1 hour left, I couldn’t fix this.

INCORRECT_CODE
// #include<bits/stdc++.h>
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <vector>
#include <set>
#include <map>
#include <unordered_set>
#include <unordered_map>
#include <queue>
#include <string>
#include <cstring>
#include <random>
#include <bitset>
using namespace std;
#define vt vector 
#define ar array 
#define sz(a) (int)a.size() 
#define ll long long 
// debugger credits: https://codeforces.com/blog/entry/68809 
//#pragma GCC optimize "trapv"
#define F_OR(i, a, b, s) for (int i = (a); ((s) > 0 ? i < (b) : i > (b)); i += (s))
#define F_OR1(e) F_OR(i, 0, e, 1)
#define F_OR2(i, e) F_OR(i, 0, e, 1)
#define F_OR3(i, b, e) F_OR(i, b, e, 1)
#define F_OR4(i, b, e, s) F_OR(i, b, e, s)
#define GET5(a, b, c, d, e, ...) e
#define F_ORC(...) GET5(__VA_ARGS__, F_OR4, F_OR3, F_OR2, F_OR1)
#define FOR(...) F_ORC(__VA_ARGS__)(__VA_ARGS__)
#define EACH(x, a) for (auto& x: a)
template<class T> bool umin(T& a, const T& b) {
	return b<a?a=b, 1:0;
}
template<class T> bool umax(T& a, const T& b) { 
	return a<b?a=b, 1:0;
}
template<class A> void read(vt<A>& v);
template<class A, size_t S> void read(ar<A, S>& a);
template<class T> void read(T& x) {
	cin >> x;
}
void read(double& d) {
	string t;
	read(t);
	d=stod(t);
}
void read(long double& d) {
	string t;
	read(t);
	d=stold(t);
}
template<class H, class... T> void read(H& h, T&... t) {
	read(h);
	read(t...);
}
template<class A> void read(vt<A>& x) {
	EACH(a, x)
		read(a);
}
template<class A, size_t S> void read(array<A, S>& x) {
	EACH(a, x)
		read(a);
}
const int mxN=1e5,di[4]={1,0,-1,0},dj[4]={0,-1,0,1};
void solve(){		
	int n,m,k ;
	read(n,m,k) ;
	vt<string>a(n) ;read(a) ;
	vt<vt<int>>dp(n+2,vt<int>(m+2)),dp2(n+2,vt<int>(m+2));
	int c=0 ;
	FOR(i,1,n+1)
		FOR(j,1,m+1){
			c+=a[i-1][j-1]=='0' ;
			dp[i][j]=dp[i-1][j]+dp[i][j-1]-dp[i-1][j-1]+(a[i-1][j-1]=='1') ;
			dp2[i][j]=dp2[i-1][j]+dp2[i][j-1]-dp2[i-1][j-1]+(a[i-1][j-1]=='0') ;
		}
	int ans = -1 ;
	FOR(i,1,n+1)
		FOR(j,1,m+1){
			int lb=1,rb=min(n-i+1,m-j+1) ;
			while(lb<rb){
				int mb=(lb+rb)/2 ;
				int lx=i,ly=j ;
				int rx=min(i+mb-1,n),ry=min(j+mb-1,m) ;
				int A=dp[rx][ry]-dp[rx][ly-1]-dp[lx-1][ry]+dp[lx-1][ly-1] ;
				int Z=dp2[rx][ry]-dp2[rx][ly-1]-dp2[lx-1][ry]+dp2[lx-1][ly-1] ;
				Z=c-Z ;
				bool ok=1 ;
				if(A>Z)
					ok=0 ;
				if(A>k)
					ok=0 ;
				if(ok)
					lb=mb+1 ;
				else 
					rb=mb ;
			}
			umax(ans,lb-1) ;
		}
	cout<<ans<<'\n';

}
signed main() {
  ios_base::sync_with_stdio(false);
  cin.tie(NULL);
  //cout << setprecision(20) << fixed ;
  int T=1;
	read(T);
	FOR(_,T){
		// pff("Case #", _+1, ": ");
		solve();
	}
	return 0;
}


what do u mean by off-by-one?

thanks for the editorial

Like most of the time, the result I get by binary search differs from the original answer by 1. And if I change the initial bounds or change the while(lb<rb) to while(lb<=rb) then it works, can you tell a concrete approach one should follow while implementing this.

have you solved all the binary search problems from codeforces edu?


No I haven’t, I’ll make it a point to check that


Try storing mid as the answer every time the condition holds, then just cout<<ans at the end. Works every time.

See, most of the binary search problems I do need me to find the leftmost element that is \geq something, or the rightmost element that is \leq something.
Take this problem for example: we need to check whether a square of size k is possible, and we need the maximum such k.
I do this-

int l=1,r=n,mid,ans=0;    //you can take ans to be the default if answer isn't possible for any value(depends on problem)
while(l<=r)
{
    mid=(l+r)>>1;
    if(possible(mid))
    {
        ans=mid;
        l=mid+1;
    }
    else r=mid-1;
}
System.out.print(ans);

If you do this, you will need an extra ans variable, but the problem you are facing shouldn’t exist.

Check line 36 of my code for this problem.
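The same "remember the last success" pattern can be sketched as a reusable C++ helper. This is a hypothetical function (`lastTrue` is not from the thread) that assumes the predicate is monotone, i.e. true for a prefix of [lo, hi] and false afterwards; it returns `fallback` when the predicate is false everywhere.

```cpp
#include <cassert>
#include <functional>

// Largest x in [lo, hi] with ok(x) == true, assuming ok looks like
// true ... true false ... false on that range; otherwise returns fallback.
long long lastTrue(long long lo, long long hi,
                   const std::function<bool(long long)>& ok, long long fallback) {
    assert(ok && "predicate must be callable");
    long long ans = fallback;
    while (lo <= hi) {
        long long mid = lo + (hi - lo) / 2;  // overflow-safe midpoint
        if (ok(mid)) {
            ans = mid;      // mid works; try something larger
            lo = mid + 1;
        } else {
            hi = mid - 1;   // mid fails, so everything above it fails too
        }
    }
    return ans;
}
```

Because every successful mid is recorded in `ans`, the loop condition `lo <= hi` and the updates `lo = mid + 1`, `hi = mid - 1` never have to "land" exactly on the answer, which is what removes the off-by-one worry.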

CODE
// #include<bits/stdc++.h>
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <vector>
#include <set>
#include <map>
#include <unordered_set>
#include <unordered_map>
#include <queue>
#include <string>
#include <cstring>
#include <random>
#include <bitset>
using namespace std;
#define vt vector 
#define ar array 
#define sz(a) (int)a.size() 
#define ll long long 
// debugger credits: https://codeforces.com/blog/entry/68809 
//#pragma GCC optimize "trapv"
#define F_OR(i, a, b, s) for (int i = (a); ((s) > 0 ? i < (b) : i > (b)); i += (s))
#define F_OR1(e) F_OR(i, 0, e, 1)
#define F_OR2(i, e) F_OR(i, 0, e, 1)
#define F_OR3(i, b, e) F_OR(i, b, e, 1)
#define F_OR4(i, b, e, s) F_OR(i, b, e, s)
#define GET5(a, b, c, d, e, ...) e
#define F_ORC(...) GET5(__VA_ARGS__, F_OR4, F_OR3, F_OR2, F_OR1)
#define FOR(...) F_ORC(__VA_ARGS__)(__VA_ARGS__)
#define EACH(x, a) for (auto& x: a)
template<class T> bool umin(T& a, const T& b) {
	return b<a?a=b, 1:0;
}
template<class T> bool umax(T& a, const T& b) { 
	return a<b?a=b, 1:0;
}
template<class A> void read(vt<A>& v);
template<class A, size_t S> void read(ar<A, S>& a);
template<class T> void read(T& x) {
	cin >> x;
}
void read(double& d) {
	string t;
	read(t);
	d=stod(t);
}
void read(long double& d) {
	string t;
	read(t);
	d=stold(t);
}
template<class H, class... T> void read(H& h, T&... t) {
	read(h);
	read(t...);
}
template<class A> void read(vt<A>& x) {
	EACH(a, x)
		read(a);
}
template<class A, size_t S> void read(array<A, S>& x) {
	EACH(a, x)
		read(a);
}
const int mxN=1e5,di[4]={1,0,-1,0},dj[4]={0,-1,0,1};
void solve(){		
	int n,m,k ;
	read(n,m,k) ;
	vt<string>a(n) ;read(a) ;
	vt<vt<int>>dp(n+2,vt<int>(m+2)),dp2(n+2,vt<int>(m+2));
	int c=0 ;
	FOR(i,1,n+1)
		FOR(j,1,m+1){
			c+=a[i-1][j-1]=='0' ;
			dp[i][j]=dp[i-1][j]+dp[i][j-1]-dp[i-1][j-1]+(a[i-1][j-1]=='1') ;
			dp2[i][j]=dp2[i-1][j]+dp2[i][j-1]-dp2[i-1][j-1]+(a[i-1][j-1]=='0') ;
		}
	int ans = -1 ;
	FOR(i,1,n+1)
		FOR(j,1,m+1){
			int lb=1,rb=min(n-i+1,m-j+1) ;
			while(lb<=rb){
				int mb=(lb+rb)/2 ;
				int lx=i,ly=j ;
				int rx=min(i+mb-1,n),ry=min(j+mb-1,m) ;
				int A=dp[rx][ry]-dp[rx][ly-1]-dp[lx-1][ry]+dp[lx-1][ly-1] ;
				int Z=dp2[rx][ry]-dp2[rx][ly-1]-dp2[lx-1][ry]+dp2[lx-1][ly-1] ;
				Z=c-Z ;
				bool ok=1 ;
				if(A>Z)
					ok=0 ;
				if(A>k)
					ok=0 ;
				if(ok)
					lb=mb+1 ;
				else 
					rb=mb-1 ;
			}
			umax(ans,lb-1) ;
		}
	cout<<ans<<'\n';

}
signed main() {
  ios_base::sync_with_stdio(false);
  cin.tie(NULL);
  //cout << setprecision(20) << fixed ;
  int T=1;
	read(T);
	FOR(_,T){
		// pff("Case #", _+1, ": ");
		solve();
	}
	return 0;
}


Now this one got accepted; I only changed < to <= and rb=mb to rb=mb-1. I wish someone had told me this earlier. Thanks a ton!!


Such things happen, pal. I remember running a binary search in the opposite direction (swapping the l=mid+1 and r=mid-1 conditions) on multiple occasions.


So is this the generic binary search template you use for all binary search problems or you manipulate the <= or < ,r=m-1 or r=m etc… according to the problem ?


I don’t use any templates; I write one according to the problem.


Can someone explain the solution to the last problem? I found no editorial for it.

  1. take all points from all possible line segments of A and B
  2. sort all of them
  3. select the unique points
  4. assign them consecutive compressed indices, e.g. set of pts=[1,6,8,10], corresponding values=[1,2,3,4]
  5. do the same thing you do when you have to add +1 to a segment, i.e. mark starting indices with +1 and ending indices with -1, then take the cumulative sum
  6. any point on this newly formed array now indicates number of lines moving from pt(i) to pt(i+1), so the value it can contribute will be [pt(i+1)-pt(i)]*(number of lines)
  7. perform prefix sum
  8. find value for all ranges of B in O(1) using the prefix sum
    The total sum is your answer. Time complexity is O(N log N).

The code.
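Since that problem’s statement is not reproduced here, the following is only a sketch of steps 1 to 8 under an explicit assumption: A and B are lists of closed segments [l, r] with l <= r, and the goal is the sum, over every segment of B, of its overlap length with every segment of A. The names `Seg` and `totalOverlap` are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

using Seg = std::pair<long long, long long>;  // closed segment [l, r], l <= r

// Hypothetical reading of the steps: sum of overlap lengths over all (a, b) pairs.
long long totalOverlap(const std::vector<Seg>& A, const std::vector<Seg>& B) {
    for (const auto& s : A) assert(s.first <= s.second);
    // Steps 1-4: collect, sort and deduplicate endpoints (coordinate compression).
    std::vector<long long> pts;
    for (const auto& s : A) { pts.push_back(s.first); pts.push_back(s.second); }
    for (const auto& s : B) { pts.push_back(s.first); pts.push_back(s.second); }
    std::sort(pts.begin(), pts.end());
    pts.erase(std::unique(pts.begin(), pts.end()), pts.end());
    int k = pts.size();
    if (k == 0) return 0;
    auto id = [&](long long x) {
        return int(std::lower_bound(pts.begin(), pts.end(), x) - pts.begin());
    };
    // Step 5: difference array; after the running sum, cover[i] is the number
    // of A segments containing the elementary interval (pts[i], pts[i+1]).
    std::vector<long long> cover(k + 1, 0);
    for (const auto& s : A) { cover[id(s.first)] += 1; cover[id(s.second)] -= 1; }
    for (int i = 1; i < k; ++i) cover[i] += cover[i - 1];
    // Steps 6-7: W[i] = covered length (with multiplicity) on [pts[0], pts[i]].
    std::vector<long long> W(k, 0);
    for (int i = 0; i + 1 < k; ++i)
        W[i + 1] = W[i] + (pts[i + 1] - pts[i]) * cover[i];
    // Step 8: each B segment is a difference of two prefix values.
    long long total = 0;
    for (const auto& s : B) total += W[id(s.second)] - W[id(s.first)];
    return total;
}
```

Sorting dominates, so the sketch runs in O(N log N), matching the estimate above.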


Oh, so sorry that I was not able to find the editorial.

// I HAVE BEEN TRYING TO DEBUG MY CODE FOR THE LAST 2 HRS BUT CAN'T FIND WHERE I AM WRONG. PLEASE SOMEONE HELP ME! THNX


import java.util.*;
import java.lang.*;


 public class Main
{
static int[][] arr ;
static int[][] dp ;
static int n,m,k,c0,c1 ;
 

public static boolean check(int a)
{
    
    if(a*a > c0)return false  ;
    
    
     for(int i = 1 ; i <= n-a+1 ;i++)
     {
         for(int j = 1 ; j <= m-a+1 ;j++)
         {
           int x = i+a-1 ;int y = j+a-1 ;  
           
             int val1 = dp[x][y] ;
             int val2  = dp[i-1][j-1] ;
             int val3  = dp[i-1][y] ;
             int val4  = dp[x][j-1] ;
             
             
             int val = val1 -val3-val4+val2  ;
             
             val =  (a*a) -val ;
             
             if(val <= k)return true  ;
             
         }
}
    

return false  ;
    
}


public static void solve()
{

Scanner scn = new Scanner(System.in) ; 

int testcase = 1;
 testcase = scn.nextInt() ;
for(int testcases =1  ; testcases <= testcase ;testcases++)
{
   

n= scn.nextInt() ; m= scn.nextInt() ; k= scn.nextInt() ;

arr = new int[n+10][m+10] ;dp = new int[n+10][m+10] ;


for(int i = 1 ; i <= n ;i++)
{
    String s = scn.next() ;
    s= "%" + s ;
    
    for(int j = 1 ; j <= m ;j++)
    {
        if(s.charAt(j) == '1')arr[i][j] = 1;
        else arr[i][j] = 0 ;
        
        if(arr[i][j] == 0)c0++ ;
        else c1++ ;
        
    }
    
    
    
}


if(arr[1][1] == 0)dp[1][1] =1 ;
else dp[1][1] =0;


for(int j = 2 ; j <= m;j++)
{
    dp[1][j] = dp[1][j-1] ;
    
    if(arr[1][j] == 0)dp[1][j]++ ;
}



for(int i = 2 ; i <= n;i++)
{
    dp[i][1] = dp[i-1][1] ;
    
    if(arr[i][1] == 0)dp[i][1]++ ;
}





for(int i = 2 ; i <= n ; i++)
{
    for(int j = 2; j <= m ;j++)
    {
        
        dp[i][j] = dp[i-1][j] + dp[i][j-1] -dp[i-1][j-1] ;
        
         if(arr[i][j] == 0)dp[i][j]++ ;
        
    }
    
}



int s = 0 ;int e  = Math.min(n,m) ;

int ans = 0 ;

while(s <= e)
{
    int m = (s+e)/2 ;
    
    
    if(check(m))
    {
        ans = m ;
        s=m+1 ;
    }
    
    else{
        
        e=m-1  ;
    }
    
    
    
    
}


System.out.println(ans) ;

} 

    
} 


public static void main (String[] args) throws java.lang.Exception
{
  

solve() ;
      
}


}

IT MAY LOOK A BIT LONG AT FIRST GLANCE, BUT IT IS THE SAME IDEA: THE check FUNCTION TESTS THE CONDITION AND THE REST IS JUST PREFIX SUM CALCULATION.

For calculating mid, the best practice is to use mid = (r - l) / 2 + l.
This prevents the overflow that l + r can cause.
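A quick illustration of why the parentheses matter, assuming 0 <= l <= r; `safeMid` and `wrongMid` are hypothetical names. Without parentheses, `r - l / 2 + l` parses as `r - (l / 2) + l`, which is not a midpoint at all. (C++20 also provides `std::midpoint` in `<numeric>`.)

```cpp
#include <cassert>

// Overflow-safe midpoint: r - l never exceeds r, so it fits in int.
int safeMid(int l, int r) {
    assert(0 <= l && l <= r);
    return l + (r - l) / 2;
}

// The unparenthesised form: evaluates as r - (l / 2) + l.
int wrongMid(int l, int r) { return r - l / 2 + l; }
```

For example, with l = 2 and r = 8, `safeMid` gives 5 while `wrongMid` gives 9, which lies outside [l, r].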
