GRIDSQRS - Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Practice

Setter: Nishant Shah
Tester: Abhinav Sharma and Lavish Gupta
Editorialist: Taranpreet Singh

DIFFICULTY

Easy

PREREQUISITES

Precomputation.

PROBLEM

You are given a binary square matrix A of size N \times N. Let the value at cell (i, j) be denoted by A(i, j).

Your task is to count the number of square frames present in the grid. A square frame is defined to be a square submatrix of A whose border elements are all ‘1’.

Formally,

  • A square submatrix of A of size k with top-left corner (i, j) is defined to be the set of cells \{(i+x, j+y) \mid 0 \leq x, y \lt k\}. Note that this requires i+k-1 \leq N and j+k-1 \leq N.
  • A square frame of size k with top-left corner (i, j) is defined to be a square submatrix of size k such that A(i+x, j+y) = 1 whenever x = 0 or y = 0 or x = k-1 or y = k-1. There is no constraint on the values of elements strictly inside the frame.
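To make the definition concrete, here is a naive O(k) check. It is a minimal sketch in C++ (the language of most solutions below). It is 0-indexed, whereas the statement is 1-indexed, and the name isFrameNaive is hypothetical.

```cpp
#include <string>
#include <vector>

// Naive frame test taken straight from the definition, 0-indexed:
// is the k x k square with top-left corner (i, j) fully inside the grid,
// with every border cell equal to '1'? Costs O(k) per call.
bool isFrameNaive(const std::vector<std::string>& A, int i, int j, int k) {
    int n = static_cast<int>(A.size());
    if (k < 1 || i < 0 || j < 0 || i + k > n || j + k > n) return false;
    for (int t = 0; t < k; ++t) {
        if (A[i][j + t] != '1' ||          // top row
            A[i + k - 1][j + t] != '1' ||  // bottom row
            A[i + t][j] != '1' ||          // left column
            A[i + t][j + k - 1] != '1')    // right column
            return false;
    }
    return true;  // interior cells are never inspected
}
```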

Refer to the sample explanation for more details.

QUICK EXPLANATION

  • There are O(N^3) candidates for square frames, so we need a way to check whether a candidate is indeed a frame faster than O(N).
  • For each position, we can precompute the number of consecutive ‘1’s starting from that position in each of the four directions.

EXPLANATION

First, consider the number of possible frames if the whole grid were filled with ‘1’s only. Then every square submatrix is a frame, so the count equals the number of squares whose top-left and bottom-right corners both lie inside the grid.
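For a fixed side length s, the top-left corner can sit in any of N-s+1 rows and N-s+1 columns that keep the square inside the grid, so the all-ones count is:

```latex
\sum_{s=1}^{N} (N-s+1)^2 \;=\; \sum_{m=1}^{N} m^2 \;=\; \frac{N(N+1)(2N+1)}{6}
```

For example, N = 2 gives 4 + 1 = 5 and N = 3 gives 9 + 4 + 1 = 14. This is roughly N^3/3, consistent with the N^3 upper bound; for N = 1000 it is 333,833,500, which still fits in a 32-bit signed integer.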

Each square can be represented by a triplet (r, c, s), denoting a square with the top-left cell at (r, c) and side length s. This triplet must also satisfy max(r, c)+s-1 \leq N. Ignoring this constraint, each of r, c and s can take at most N values, giving an upper bound of N^3 possible frames. If we can check each candidate in O(1), i.e. determine quickly whether the border of square (r, c, s) is filled with ‘1’s, we can solve this problem in O(N^3).
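As a baseline, a hypothetical brute force countFramesBrute enumerates all triplets and checks each border cell by cell, which costs O(N^4) overall. It is a sketch for validating faster solutions on small grids, not something meant to pass. It is 0-indexed, so the bound max(r, c)+s-1 \leq N becomes max(r, c)+s \leq N.

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Reference brute force (0-indexed): enumerate every triplet (r, c, s)
// whose square fits in the grid and scan its border cell by cell.
long long countFramesBrute(const std::vector<std::string>& A) {
    int n = static_cast<int>(A.size());
    long long count = 0;
    for (int r = 0; r < n; ++r)
        for (int c = 0; c < n; ++c)
            for (int s = 1; std::max(r, c) + s <= n; ++s) {
                bool ok = true;
                for (int t = 0; t < s && ok; ++t)  // O(s) border scan
                    ok = A[r][c + t] == '1' && A[r + s - 1][c + t] == '1' &&
                         A[r + t][c] == '1' && A[r + t][c + s - 1] == '1';
                if (ok) ++count;
            }
    return count;
}
```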

Checking if the frame of the square is filled with 1s

Now we have a square represented by triplet (r, c, s). We need to check if the border of this square is filled with 1s or not.

For each cell, let D_{i, j} be the number of consecutive cells containing ‘1’, starting at cell (i, j) and moving downwards, stopping before the first ‘0’ or the border of the grid. Assuming D_{i+1, j} is already computed, D_{i, j} = 0 if A_{i, j} = 0, otherwise D_{i, j} = 1 + D_{i+1, j}.
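The recurrence for D can be sketched as follows (C++, 0-indexed; computeDown is a hypothetical name). An extra all-zero row below the grid plays the role of the border, so the bottom row needs no special case, and rows are visited bottom to top so that D[i+1][j] is ready before D[i][j].

```cpp
#include <string>
#include <vector>

// D[i][j]: number of consecutive '1' cells starting at (i, j) going down,
// stopping before the first '0' or the bottom border (0-indexed).
std::vector<std::vector<int>> computeDown(const std::vector<std::string>& A) {
    int n = static_cast<int>(A.size());
    // Row n is an all-zero sentinel, so D[i + 1][j] is valid for i = n - 1.
    std::vector<std::vector<int>> D(n + 1, std::vector<int>(n, 0));
    for (int i = n - 1; i >= 0; --i)  // bottom to top
        for (int j = 0; j < n; ++j)
            D[i][j] = (A[i][j] == '1') ? 1 + D[i + 1][j] : 0;
    return D;
}
```

U, L and R follow the same pattern with the neighbour above, to the left, and to the right respectively.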

Similarly, we can define U_{i, j} for upward direction, L_{i, j} for left direction and R_{i, j} for right direction.

Now, assuming these are computed for all positions, we need to check whether the border of square (r, c, s) is filled with 1s. We check the top border with R_{r, c} \geq s, the left border with D_{r, c} \geq s, the bottom border with L_{r+s-1, c+s-1} \geq s and the right border with U_{r+s-1, c+s-1} \geq s. If all four conditions are satisfied, the border of square (r, c, s) is filled with 1s.

Hence, by computing D_{i, j}, U_{i, j}, R_{i, j} and L_{i, j} beforehand in O(N^2), we can solve the problem in O(N^3) which is enough to get AC on this problem.

Note that to compute D_{i, j}, D_{i+1, j} must be computed first, which can be ensured by iterating over the rows in reverse order. The same holds for R_{i, j}, which depends on R_{i, j+1}, so the columns are iterated in reverse as well.
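Putting the explanation together, here is a minimal O(N^3) sketch (countFrames is a hypothetical name, not taken from the official solutions below). It pads the four run-length arrays with a zero border and uses 1-indexed coordinates to match the text, so the four border conditions read exactly as above; note the reversed loops for D and R.

```cpp
#include <algorithm>
#include <string>
#include <vector>

// O(N^3) frame counting as described above, 1-indexed.
long long countFrames(const std::vector<std::string>& A) {
    int n = static_cast<int>(A.size());
    using Grid = std::vector<std::vector<int>>;
    // Zero border at indices 0 and n + 1: neighbours outside the grid have run 0.
    Grid U(n + 2, std::vector<int>(n + 2, 0)), L = U, D = U, R = U;
    auto one = [&](int i, int j) { return A[i - 1][j - 1] == '1'; };
    // U and L look up / left: natural loop order.
    for (int i = 1; i <= n; ++i)
        for (int j = 1; j <= n; ++j) {
            U[i][j] = one(i, j) ? 1 + U[i - 1][j] : 0;
            L[i][j] = one(i, j) ? 1 + L[i][j - 1] : 0;
        }
    // D and R look down / right: reversed loops so D[i+1][j], R[i][j+1] are ready.
    for (int i = n; i >= 1; --i)
        for (int j = n; j >= 1; --j) {
            D[i][j] = one(i, j) ? 1 + D[i + 1][j] : 0;
            R[i][j] = one(i, j) ? 1 + R[i][j + 1] : 0;
        }
    long long count = 0;
    for (int r = 1; r <= n; ++r)
        for (int c = 1; c <= n; ++c)
            for (int s = 1; std::max(r, c) + s - 1 <= n; ++s) {
                int br = r + s - 1, bc = c + s - 1;  // bottom-right corner
                if (R[r][c] >= s && D[r][c] >= s &&    // top and left borders
                    L[br][bc] >= s && U[br][bc] >= s)  // bottom and right borders
                    ++count;
            }
    return count;
}
```

For instance, countFrames({"111", "101", "111"}) counts the eight 1 x 1 frames plus the single 3 x 3 frame.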

Bonus

Solve this problem by computing only two matrices beforehand, not four.

Bonus

Solve this problem in O(N^2 \log N). The authors originally intended to disallow O(N^3) solutions, since the model solution ran in O(N^2 \log N). However, because of the model solution's constant factor, an O(N^3) solution with some optimizations was able to beat it, so they decided to allow O(N^3) solutions.

Hint

The solution processes each diagonal in O(N \log N) using a Fenwick tree or segment tree, and there are 2N-1 diagonals.
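One way to realize this hint is sketched below, based on our reading of Tester's Solution 1; names such as countFramesFast and Fenwick are hypothetical. On every diagonal c - r = d, a cell at position p can be the top-left corner for bottom-right corners up to p + \min(R, D) - 1, and the bottom-right corner for sides up to \min(L, U). Sweeping the diagonal, a Fenwick tree keeps a 1 at every still-usable top-left corner, and each bottom-right cell adds the number of marks in [p - \min(L, U) + 1, p].

```cpp
#include <algorithm>
#include <cstdlib>
#include <string>
#include <vector>

// Point-update / prefix-sum Fenwick tree over indices 0..n-1.
struct Fenwick {
    std::vector<int> t;
    explicit Fenwick(int n) : t(n + 1, 0) {}
    void add(int i, int v) {
        for (++i; i < static_cast<int>(t.size()); i += i & -i) t[i] += v;
    }
    int prefix(int i) const {  // sum over [0, i]; prefix(-1) == 0
        int s = 0;
        for (++i; i > 0; i -= i & -i) s += t[i];
        return s;
    }
};

long long countFramesFast(const std::vector<std::string>& A) {
    int n = static_cast<int>(A.size());
    using Grid = std::vector<std::vector<int>>;
    Grid U(n, std::vector<int>(n, 0)), L = U, D = U, R = U;  // 0-indexed run lengths
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j)
            if (A[i][j] == '1') {
                U[i][j] = 1 + (i > 0 ? U[i - 1][j] : 0);
                L[i][j] = 1 + (j > 0 ? L[i][j - 1] : 0);
            }
    for (int i = n - 1; i >= 0; --i)
        for (int j = n - 1; j >= 0; --j)
            if (A[i][j] == '1') {
                D[i][j] = 1 + (i + 1 < n ? D[i + 1][j] : 0);
                R[i][j] = 1 + (j + 1 < n ? R[i][j + 1] : 0);
            }
    long long count = 0;
    for (int d = -(n - 1); d <= n - 1; ++d) {  // 2N-1 diagonals with c - r == d
        int r0 = d >= 0 ? 0 : -d, c0 = d >= 0 ? d : 0, len = n - std::abs(d);
        Fenwick fw(len);
        std::vector<std::vector<int>> expire(len + 1);  // expire[p]: corners unusable from p on
        for (int p = 0; p < len; ++p) {
            int r = r0 + p, c = c0 + p;
            for (int a : expire[p]) fw.add(a, -1);
            int nxt = std::min(R[r][c], D[r][c]);  // max side with (r, c) as top-left
            int prv = std::min(L[r][c], U[r][c]);  // max side with (r, c) as bottom-right
            if (nxt > 0) {
                fw.add(p, 1);
                expire[p + nxt].push_back(p);  // p + nxt <= len always holds here
            }
            if (prv > 0) count += fw.prefix(p) - fw.prefix(p - prv);  // [p - prv + 1, p]
        }
    }
    return count;
}
```

Each cell is inserted, removed and queried at most once, so a diagonal of length m costs O(m \log m) and the total is O(N^2 \log N).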

TIME COMPLEXITY

The time complexity is O(N^3) per test case.

SOLUTIONS

Setter's Solution
#include <bits/stdc++.h>
using namespace std;
   
#define pb push_back
#define S second
#define F first
#define f(i,n) for(int i=0;i<n;i++)
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define vi vector<int>
#define pii pair<int,int>
#define all(x) x.begin(),x.end()

const int MOD = 1e9+7;

int mod_pow(int a,int b,int M = MOD)
{
    if(a == 0) return 0;
    b %= (M - 1);  //M must be prime here
    
    int res = 1;
    
    while(b > 0)
    {
        if(b&1) res=(1LL*res*a)%M;   // widen to long long to avoid int overflow
        a=(1LL*a*a)%M;
        b>>=1;
    }
    
    return res;
}

const int N = 2000 + 10;
string s[N];
int U[N][N],L[N][N],R[N][N],D[N][N];

void solve()
{
   int n;
   cin >> n;
    
   f(i,n) cin >> s[i];
    
   f(i,n) f(j,n)
      U[i][j] = D[i][j] = L[i][j] = R[i][j] = (s[i][j] == '1');
    
   f(i,n) f(j,n)
   {
       if(s[i][j] == '0') continue;
       if(i > 0) U[i][j] = U[i-1][j] + 1;
       if(j > 0) L[i][j] = L[i][j-1] + 1;
   }
    
   for(int i=n-1;i>=0;i--)
       for(int j=n-1;j>=0;j--)
   {
       if(s[i][j] == '0') continue;
       if(i != n-1) D[i][j] = D[i+1][j] + 1;
       if(j != n-1) R[i][j] = R[i][j+1] + 1;
   }
    
   int res = 0;
    
   for(int i=0;i<n+n;i++)
   {
       vector<pii> pts;
       
       for(int j=0;j<n;j++)
       {
           //{j,i-j}
           if(i - j >= 0 && i - j < n)
           {
               pts.pb({j,i-j});
           }
       }
       
       for(auto x : pts)
           for(auto y : pts)
             if(x.F <= y.F)
       {
           int r1 = min(L[x.F][x.S],D[x.F][x.S]);
           int r2 = min(U[y.F][y.S],R[y.F][y.S]);
           int sz = y.F - x.F + 1;
           
           if(r1 >= sz && r2 >= sz) res++;
       }
   }
    
   cout << res << '\n';
}

signed main()
{
    fast;
    
    int t = 1;
    
    cin >> t;
    
    while(t--)
        
    solve();
}
Tester's Solution 1
#include <bits/stdc++.h>
using namespace std;
 
 
/*
------------------------Input Checker----------------------------------
*/
 
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
 
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }
 
            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }
 
            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
 
 
/*
------------------------Main code starts here----------------------------------
*/
 
const int MAX_T = 100000;
const int MAX_N = 1e6+5;
const int MAX_SUM_LEN = 1e5;
 
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define ff first
#define ss second
#define mp make_pair
#define ll long long
 
int sum_len = 0;
int max_n = 0;
int yess = 0;
int nos = 0;
int total_ops = 0;

struct fentree{
    // 0 based indexing
    vector<int> v;
    int _n;
    fentree(int n){
        v.assign(n+5,0);
        _n = n+5;
    }

    void upd(int pos, int val){
        while(pos<_n){
            v[pos]+=val;
            pos|=(pos+1);
        }
    }

    int qr(int pos){
        int ret = 0;
        while(pos>=0){
            ret += v[pos];
            pos&=(pos+1);
            pos--;
        }
        return ret;
    }

    // int bitSearch(int sum){
    //     int ret = -1;
    //     rev(i, 21){
    //         if(ret+(1<<i)>=_n) continue;
    //         if(v[ret+(1<<i)]>=sum) continue;
    //         else{
    //             ret += (1<<i);
    //             sum -= v[ret];
    //         }
    //     }

    //     return ret+1;
    // }
};

    
void solve()
{   
    int n;
    n = readIntLn(1, 1000);

    max_n += n*n;
    sum_len += n;
    string s[n];

    for(int i=0; i<n; i++) s[i] = readStringLn(n,n);

    vector<vector<vector<int> > > z(n, vector<vector<int> >(n, vector<int>(4)));
    
    for(int i=0; i<n; i++){
        z[i][n-1][0] = (s[i][n-1]=='0'?0:1);
        for(int j=n-2; j>=0; j--){
            z[i][j][0] = (s[i][j]=='1'?z[i][j+1][0]+1:0);
        }
    }


    for(int i=0; i<n; i++){
        z[i][0][2] = (s[i][0]=='0'?0:1);
        for(int j=1; j<n; j++){
            z[i][j][2] = (s[i][j]=='1'?z[i][j-1][2]+1:0);
        }
    }

    for(int j=0; j<n; j++){
        z[0][j][3] = (s[0][j]=='0'?0:1);
        for(int i=1; i<n; i++){
            z[i][j][3] = (s[i][j]=='1'?z[i-1][j][3]+1:0);
        }
    }   

    for(int j=0; j<n; j++){
        z[n-1][j][1] = (s[n-1][j]=='0'?0:1);
        for(int i=n-2; i>=0; i--){
            z[i][j][1] = (s[i][j]=='1'?z[i+1][j][1]+1:0);
        }
    }   


    long long ans = 0;
    for(int j=0; j<n; j++){
        struct fentree ft(n);
        vector<vector<int> > dlt(n+2);

        int l=0, r=j;
        while(r<n){
            for(auto h:dlt[r]){
                ft.upd(h, -1);
            }

            if(s[l][r]=='0'){
                l++;
                r++;
                continue;
            }

            ft.upd(r, 1);
            dlt[r+min(z[l][r][0], z[l][r][1])].push_back(r);

            int len = min(z[l][r][2], z[l][r][3]);
            ans += ft.qr(r)-(r-len>=0?ft.qr(r-len):0);

            l++;
            r++;
        }

    }

    for(int j=1; j<n; j++){
        struct fentree ft(n);
        vector<vector<int> > dlt(n+2);

        int l =j, r=0;

        while(l<n){
            for(auto h:dlt[l]){
                ft.upd(h, -1);
            }

            if(s[l][r]=='0'){
                l++;
                r++;
                continue;
            }

            ft.upd(l, 1);
            dlt[l+min(z[l][r][0], z[l][r][1])].push_back(l);

            int len = min(z[l][r][2], z[l][r][3]);
            ans += ft.qr(l)-(l-len>=0?ft.qr(l-len):0);

            l++;
            r++;
        }
    }

    cout<<ans<<'\n';

}
 
signed main()
{

    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r" , stdin);
    freopen("output.txt", "w" , stdout);
    #endif
    fast;
    
    int t = 1;
    
    t = readIntLn(1,MAX_T);
    
    for(int i=1;i<=t;i++)
    {    
       solve();
    }
    
    assert(getchar() == -1);
 
    cerr<<"SUCCESS\n";
    cerr<<"Tests : " << t << '\n';
    cerr<<"Sum of lengths : " << sum_len << '\n';
    // cerr<<"Maximum length : " << max_n << '\n';
    // cerr<<"Total operations : " << total_ops << '\n';
    //cerr<<"Answered yes : " << yess << '\n';
    //cerr<<"Answered no : " << nos << '\n';
}
Tester's Solution 2
#include <bits/stdc++.h>
using namespace std;
 
 
/*
------------------------Input Checker----------------------------------
*/
 
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
 
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }
 
            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }
 
            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
 
 
/*
------------------------Main code starts here----------------------------------
*/
 
const int MAX_T = 100000;
const int MAX_N = 1e6+5;
const int MAX_SUM_LEN = 1e5;
 
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define ff first
#define ss second
#define mp make_pair
#define ll long long
 
int sum_len = 0;
int max_n = 0;
int yess = 0;
int nos = 0;
int total_ops = 0;


    
void solve()
{   
    int n;
    // n = readIntLn(1, 1000);
    cin>>n;

    max_n += n*n;
    sum_len += n;
    string s[n];

    for(int i=0; i<n; i++) cin>>s[i];

    // vector<vector<vector<int> > > z(n, vector<vector<int> >(n, vector<int>(4,0)));
    int z[n][n][4];

    for(int i=0; i<n; i++){
        z[i][n-1][0] = (s[i][n-1]=='0'?0:1);
        for(int j=n-2; j>=0; j--){
            z[i][j][0] = (s[i][j]=='1'?z[i][j+1][0]+1:0);
        }
    }


    for(int i=0; i<n; i++){
        z[i][0][2] = (s[i][0]=='0'?0:1);
        for(int j=1; j<n; j++){
            z[i][j][2] = (s[i][j]=='1'?z[i][j-1][2]+1:0);
        }
    }

    for(int j=0; j<n; j++){
        z[0][j][3] = (s[0][j]=='0'?0:1);
        for(int i=1; i<n; i++){
            z[i][j][3] = (s[i][j]=='1'?z[i-1][j][3]+1:0);
        }
    }   

    for(int j=0; j<n; j++){
        z[n-1][j][1] = (s[n-1][j]=='0'?0:1);
        for(int i=n-2; i>=0; i--){
            z[i][j][1] = (s[i][j]=='1'?z[i+1][j][1]+1:0);
        }
    }   

    int maxi[n][n][2] ;
    for(int i = 0 ; i < n ; i++)
    {
        for(int j = 0 ; j < n ; j++)
        {
            maxi[i][j][0] = min(z[i][j][0] , z[i][j][1]);
            maxi[i][j][1] = min(z[i][j][2] , z[i][j][3]) ;
        }
    }
    int ans = 0 ;
    for(int i = 0 ; i < n ; i++)
    {
        int x = i , y = 0 ;
        for(; x < n ; x++ , y++)
        {
            for(int xd = x , yd = y, xlim = max(-1 , x-maxi[x][y][1]); xd > xlim; xd-- , yd--)
            {
                if(maxi[xd][yd][0] > (x-xd))
                    ans++ ;
            }
        }
    }
    for(int i = 1 ; i < n ; i++)
    {
        int y = i , x = 0 ;
        for(; y < n ; x++ , y++)
        {
            for(int xd = x , yd = y, ylim = max(-1 , y-maxi[x][y][1]); yd > ylim ; xd-- , yd--)
            {
                if(maxi[xd][yd][0] > (x-xd))
                    ans++ ;
            }
        }
    }
    cout << ans << '\n' ;
    return ;

}


signed main()
{

    #ifndef ONLINE_JUDGE
    freopen("inputf.txt", "r" , stdin);
    freopen("outputf.txt", "w" , stdout);
    #endif
    fast;
    
    int t = 1;
    
    // t = readIntLn(1,MAX_T);
    cin>>t;
    
    for(int i=1;i<=t;i++)
    {    
       solve();
    }
    
    assert(getchar() == -1);
 
    cerr<<"SUCCESS\n";
    cerr<<"Tests : " << t << '\n';
    cerr<<"Sum of lengths : " << sum_len << '\n';
    // cerr<<"Maximum length : " << max_n << '\n';
    // cerr<<"Total operations : " << total_ops << '\n';
    //cerr<<"Answered yes : " << yess << '\n';
    //cerr<<"Answered no : " << nos << '\n';
}
Editorialist's Solution
import java.util.*;
import java.io.*;
class GRIDSQRS{
    //SOLUTION BEGIN
    void pre() throws Exception{}
    void solve(int TC) throws Exception{
        int N = ni();
        boolean[][] g = new boolean[2+N][2+N];
        for(int i = 1; i<= N; i++){
            String s = n();
            for(int j = 1; j<= N; j++)g[i][j] = s.charAt(j-1) == '1';
        }
        
        int[][][] sum = new int[4][2+N][2+N];
        
        for(int i = 1; i<= N; i++){
            for(int j = 1; j <= N; j++){
                sum[0][i][j] = (g[i][j]?(1+sum[0][i-1][j]):0);
                sum[1][i][j] = (g[i][j]?(1+sum[1][i][j-1]):0);
            }
        }
        for(int i = N; i >= 1; i--){
            for(int j = N; j >= 1; j--){
                sum[2][i][j] = (g[i][j]?(1+sum[2][i+1][j]):0);
                sum[3][i][j] = (g[i][j]?(1+sum[3][i][j+1]):0);
            }
        }
        
        int[][] f1 = new int[2+N][2+N], f2 = new int[2+N][2+N];
        for(int i = 1; i<= N; i++){
            for(int j = 1; j<= N; j++){
                f1[i][j] = Math.min(sum[2][i][j], sum[3][i][j]);
                f2[i][j] = Math.min(sum[0][i][j], sum[1][i][j]);
            }
        }
        int ans = 0;
        for(int i = 1; i<= N; i++)
            for(int j = 1; j <= N; j++)
                for(int d = 0; Math.max(i, j)+d <= N && d < f1[i][j]; d++)
                    if(f1[i][j] > d && f2[i+d][j+d] > d)
                        ans++;
        pn(ans);
    }
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    static boolean multipleTC = true;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        new GRIDSQRS().run();
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}

    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }

        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }

        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException  e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }

        String nextLine() throws Exception{
            String str = "";
            try{   
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }  
            return str;
        }
    }
}

Feel free to share your approach. Suggestions are welcome as always.

https://www.codechef.com/viewsolution/54799603

N^2 * log(N) solution using segtree.

https://www.codechef.com/viewsolution/54859574

Can anyone please help me out with my solution?
I am getting TLE for the last 3 sub-tasks.

As far as I know, my code runs in O(N^3) at most.

While I was going through the successful submissions, I found the following solution written in PYPY3.
https://www.codechef.com/viewsolution/54742663

My solution (Python 3.6): Solution: 54859574 | CodeChef
This solution also matches the same logic I used. Yet my solution is giving TLE for the last 3 sub-tasks.

Please help me rectify the time complexity.


https://www.codechef.com/viewsolution/54661113

This is also an N^3 solution; you can refer to it as well.

My n^3 submission failed harder than my n^4 solution.
n^3 sol : Solution: 54855720 | CodeChef
n^4 sol : Solution: 54650704 | CodeChef
HELP!!!

My N^2 log N solution, using the bisect module in Python:
https://www.codechef.com/viewsolution/54657457

Help me find the time complexity of my code.
I don't know what happened. @sayan_244 @ayushkumar_10
Plz… Solution: 55001409 | CodeChef

Can anyone tell me what's wrong with this solution?

I tried implementing this in Python (PYTH3.6) but I always get TLE. I saw a solution similar to mine that passed all test cases in PYPY3. I guess it's not possible to pass this with PYTH3.6.
I would really appreciate it if someone could point out any mistake I'm making.
https://www.codechef.com/viewsolution/55066657

You could easily prevent O(N^3) solutions from passing by using only one test case with N <= 2000 and a time limit of around 3-4 seconds (depending on the runtime of the O(N^2 log N) solution).

Can anyone explain this part??

It’s just a different style of implementation. pts now contains pairs (i, j). So treat x and y as positions in the grid, so pair (x, y) is a pair of positions.

Thanks for replying.

Can you tell me why you are iterating n + n times?
Also x & y are positions, which positions??

Please help!!
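Our reading of the setter's code (an interpretation, not a statement from the setter): for a given i, pts holds the cells whose row and column sum to i, i.e. one anti-diagonal. Since row + column ranges over 0..2n-2, a bound of n + n visits every anti-diagonal, and the last value i = 2n-1 simply yields no cells. The hypothetical helper antiDiagonal below only reproduces that enumeration.

```cpp
#include <utility>
#include <vector>

// Hypothetical helper mirroring how the setter fills `pts` for a given i:
// all cells (row, col) with row + col == k, in increasing row order.
// k runs over 0 .. 2n-2; k == 2n-1 (the last value of `i < n + n`) yields nothing.
std::vector<std::pair<int, int>> antiDiagonal(int n, int k) {
    std::vector<std::pair<int, int>> pts;
    for (int row = 0; row < n; ++row) {
        int col = k - row;
        if (col >= 0 && col < n) pts.push_back({row, col});
    }
    return pts;
}
```

For two cells x and y from the same list with x.F <= y.F, x is the top-right corner and y the bottom-left corner of the square with side y.F - x.F + 1. That is why the setter's r1 = min(L, D) at x covers the top and right edges, and r2 = min(U, R) at y covers the left and bottom edges.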

O(N^2 * log(N)) Approach :

Continuing from the hints given in the editorial, let us calculate the answer for one diagonal in O(N * log(N)) time.

Let's precompute for each cell (i, j):
1. The maximum side length of a square for which it can act as a top-left cell, next[i][j]
2. The maximum side length of a square for which it can act as a bottom-right cell, prev[i][j].

Every square frame can be divided into two parts:
Part 1: Top Edge + Left Edge
Part 2: Bottom Edge + Right Edge

Now let us solve for one diagonal. Suppose we represent the diagonal as an array of 0/1 markers: dia[5].

Let us iterate over this array and, at each step, maintain which indices to the left (including the current one) can act as Part 1 of a square whose bottom-right cell is the current one. For example, if at index i = 3 our array is dia[5] = {1, 0, 1, 1, 0}, then indices 0, 2 and 3 can act as the top-left corner when the bottom-right corner is 3. These markers can be maintained using next[][] and point updates on a segment tree.

Second, we must only count those top-left corners for which the current element can act as the bottom-right corner. So we do a range-sum query on the segment tree over the range [i - prev[][] + 1, i].

Code : https://www.codechef.com/viewsolution/55797435
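The per-diagonal step described in this comment can be sketched on its own. Below, a hypothetical countOnDiagonal takes nxt and prv for a single diagonal (the next[][] and prev[][] values of its cells, in order) and keeps the "can act as Part 1" markers in a bottom-up sum segment tree: index k stays marked while i < k + nxt[k], and each i adds the marks in [i - prv[i] + 1, i].

```cpp
#include <vector>

// Frames whose top-left and bottom-right corners lie on one diagonal.
// nxt[k]: max side with k as top-left; prv[k]: max side with k as bottom-right.
long long countOnDiagonal(const std::vector<int>& nxt, const std::vector<int>& prv) {
    int m = static_cast<int>(nxt.size());
    std::vector<int> seg(2 * m, 0);  // bottom-up sum tree, leaves at seg[m .. 2m-1]
    auto update = [&](int pos, int delta) {
        for (pos += m; pos > 0; pos >>= 1) seg[pos] += delta;
    };
    auto query = [&](int l, int r) {  // sum of markers over [l, r)
        int s = 0;
        for (l += m, r += m; l < r; l >>= 1, r >>= 1) {
            if (l & 1) s += seg[l++];
            if (r & 1) s += seg[--r];
        }
        return s;
    };
    std::vector<std::vector<int>> expire(m + 1);  // expire[i]: markers to clear at i
    long long total = 0;
    for (int i = 0; i < m; ++i) {
        for (int k : expire[i]) update(k, -1);
        if (nxt[i] > 0) {  // i itself can start a square
            update(i, 1);
            int end = i + nxt[i];
            if (end <= m) expire[end].push_back(i);
        }
        if (prv[i] > 0) {
            int lo = i - prv[i] + 1;
            if (lo < 0) lo = 0;
            total += query(lo, i + 1);  // range [i - prv + 1, i]
        }
    }
    return total;
}
```

For a 3 x 3 grid of ‘1’s, the main diagonal has nxt = {3, 2, 1} and prv = {1, 2, 3}, giving the three 1 x 1, two 2 x 2 and one 3 x 3 frames centred on it.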