PROBLEM LINK:
Contest Division 1
Contest Division 2
Contest Division 3
Practice
Setter: Nishant Shah
Tester: Abhinav Sharma and Lavish Gupta
Editorialist: Taranpreet Singh
DIFFICULTY
Easy
PREREQUISITES
Precomputation.
PROBLEM
You are given a binary square matrix A of size N \times N. Let the value at cell (i, j) be denoted by A(i, j).
Your task is to count the number of square frames present in the grid. A square frame is defined to be a square submatrix of A whose border elements are all ‘1’.
Formally,
- A square submatrix of A of size k with top-left corner (i, j) is defined to be the set of cells \{(i+x, j+y) \mid 0 \leq x, y \lt k\}. Note that this requires i+k-1 \leq N and j+k-1 \leq N.
- A square frame of size k with top-left corner (i, j) is defined to be a square submatrix of size k such that A(i+x, j+y) = 1 whenever x = 0 or y = 0 or x = k-1 or y = k-1. There is no constraint on the values of elements strictly inside the frame.
Refer to the sample explanation for more details.
QUICK EXPLANATION
- There are at most N^3 candidates for square frames, so we need a way to check whether a candidate is indeed a frame faster than O(N), ideally in O(1).
- For each position, we can precompute the number of consecutive '1's starting from that position in each of the four directions.
EXPLANATION
In this problem, let us first count the possible frames if the whole grid were filled with '1's only. It would be the number of squares having both the top-left and bottom-right corners inside the grid.
Each square can be represented by a triplet (r, c, s), denoting a square with the top-left cell at (r, c) and side length s. This triplet must also satisfy \max(r, c)+s-1 \leq N. Ignoring this constraint, each of r, c and s can take at most N values, providing an upper bound of N^3 possible frames. If we could go through the candidates one by one and determine in O(1) whether the frame of square (r, c, s) is filled with '1's, we could solve this problem in O(N^3).
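To make the bound concrete, here is a small sketch that counts the valid triplets exactly. The function name countCandidates is made up for illustration and is not part of any of the solutions in this editorial. For a fixed side s there are (N-s+1)^2 valid top-left corners, so the exact count is the sum of squares 1^2 + 2^2 + ... + N^2, which never exceeds N^3.

```cpp
#include <cassert>
#include <cstdint>

// Number of triplets (r, c, s), 1-indexed, with max(r, c) + s - 1 <= n.
// For a fixed side s there are (n - s + 1)^2 valid top-left corners.
// countCandidates is a hypothetical helper name used only in this sketch.
std::int64_t countCandidates(int n) {
    assert(n >= 0);
    std::int64_t total = 0;
    for (int s = 1; s <= n; ++s) {
        const std::int64_t corners = n - s + 1;
        total += corners * corners;
    }
    return total;  // equals n(n+1)(2n+1)/6, which is at most n^3
}
```

For N = 3 this gives 9 + 4 + 1 = 14 candidates, and for N = 1000 it gives 333833500, so under that assumed bound on N the answer fits comfortably in a 32-bit signed integer.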
Checking if the frame of the square is filled with 1s
Now we have a square represented by triplet (r, c, s). We need to check if the border of this square is filled with 1s or not.
For each cell, let's compute D_{i, j} as the number of consecutive cells containing '1', starting from cell (i, j) and moving downwards, before the first '0' or the border of the grid. Assuming we have D_{i+1, j} computed (with D_{N+1, j} = 0 just below the grid), we can compute D_{i, j} = 0 if A(i, j) = 0, otherwise D_{i, j} = 1 + D_{i+1, j}.
Similarly, we can define U_{i, j} for upward direction, L_{i, j} for left direction and R_{i, j} for right direction.
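The four tables can be filled in two passes over the grid. The following is a minimal sketch, not the setter's code: the names Runs and computeRuns are assumptions made for this example, and indices are 0-based, unlike the 1-based text above.

```cpp
#include <cassert>
#include <string>
#include <vector>

using Matrix = std::vector<std::vector<int>>;

// Run lengths of consecutive '1's starting at each cell, per direction.
// Runs and computeRuns are hypothetical names for this sketch.
struct Runs {
    Matrix up, down, left, right;
};

Runs computeRuns(const std::vector<std::string>& grid) {
    const int n = static_cast<int>(grid.size());
    for (const auto& row : grid) assert(static_cast<int>(row.size()) == n);

    const Matrix zero(n, std::vector<int>(n, 0));
    Runs runs{zero, zero, zero, zero};

    // up and left depend on the previous row / column: scan forwards.
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j) {
            if (grid[i][j] != '1') continue;  // run length stays 0
            runs.up[i][j] = 1 + (i > 0 ? runs.up[i - 1][j] : 0);
            runs.left[i][j] = 1 + (j > 0 ? runs.left[i][j - 1] : 0);
        }

    // down and right depend on the next row / column: scan backwards.
    for (int i = n - 1; i >= 0; --i)
        for (int j = n - 1; j >= 0; --j) {
            if (grid[i][j] != '1') continue;
            runs.down[i][j] = 1 + (i + 1 < n ? runs.down[i + 1][j] : 0);
            runs.right[i][j] = 1 + (j + 1 < n ? runs.right[i][j + 1] : 0);
        }
    return runs;
}
```

On the grid with rows 110, 111, 011, for instance, the top-left cell gets a down run and a right run of 2 each, while the bottom-right cell gets an up run and a left run of 2 each.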
Now, assuming these values are computed for all positions, we need to check whether the frame of square (r, c, s) is filled with 1s. The top-left corner of the square is (r, c) and its bottom-right corner is (r+s-1, c+s-1). We check the top border by R_{r, c} \geq s, the left border by D_{r, c} \geq s, the bottom border by L_{r+s-1, c+s-1} \geq s and the right border by U_{r+s-1, c+s-1} \geq s. If all four conditions are satisfied, the frame of square (r, c, s) is filled with 1s.
Hence, by computing D_{i, j}, U_{i, j}, R_{i, j} and L_{i, j} beforehand in O(N^2), we can solve the problem in O(N^3) which is enough to get AC on this problem.
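Putting the precomputation and the four border checks together, the whole approach can be sketched as below. This is an illustrative version under the assumption of 0-based indices and a hypothetical function name countFrames; it is not a copy of any solution in this editorial. Bounding s by min(R, D) at the top-left corner covers the top and left checks and also keeps the square inside the grid.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Counts square frames in an n x n grid of '0'/'1' characters in O(n^3).
// countFrames is a hypothetical name; indices are 0-based.
long long countFrames(const std::vector<std::string>& grid) {
    const int n = static_cast<int>(grid.size());
    for (const auto& row : grid) assert(static_cast<int>(row.size()) == n);

    using Matrix = std::vector<std::vector<int>>;
    Matrix up(n, std::vector<int>(n, 0)), left = up, down = up, right = up;

    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j) {
            if (grid[i][j] != '1') continue;
            up[i][j] = 1 + (i > 0 ? up[i - 1][j] : 0);
            left[i][j] = 1 + (j > 0 ? left[i][j - 1] : 0);
        }
    for (int i = n - 1; i >= 0; --i)
        for (int j = n - 1; j >= 0; --j) {
            if (grid[i][j] != '1') continue;
            down[i][j] = 1 + (i + 1 < n ? down[i + 1][j] : 0);
            right[i][j] = 1 + (j + 1 < n ? right[i][j + 1] : 0);
        }

    long long frames = 0;
    for (int r = 0; r < n; ++r)
        for (int c = 0; c < n; ++c) {
            // Top and left borders limit the side length from corner (r, c).
            const int maxSide = std::min(right[r][c], down[r][c]);
            for (int s = 1; s <= maxSide; ++s) {
                const int br = r + s - 1, bc = c + s - 1;  // bottom-right corner
                // Bottom and right borders are checked from that corner.
                if (left[br][bc] >= s && up[br][bc] >= s) ++frames;
            }
        }
    return frames;
}
```

An all-ones 2 x 2 grid yields 5 frames (four of size 1 and one of size 2), and the grid with rows 111, 101, 111 yields 9 (eight of size 1 and the outer frame of size 3).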
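Note that in order to compute D_{i, j}, D_{i+1, j} needs to be computed first, which can be ensured by iterating over the rows in reverse order. The same holds for R_{i, j}, which depends on R_{i, j+1} and needs the columns in reverse order. U_{i, j} and L_{i, j} can be computed in the usual forward order.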
Bonus
Solve this problem by computing only two matrices beforehand, not four.
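One possible way to approach this bonus, offered as a sketch rather than as the intended answer (the editorial does not say which two matrices it has in mind), is to keep only the right and down run lengths. The bottom border is then checked with the right run starting at the bottom-left corner, and the right border with the down run starting at the top-right corner. The name countFramesTwoMatrices is made up for this example, and indices are 0-based.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// O(n^3) frame counting that stores only the right and down run lengths.
// countFramesTwoMatrices is a hypothetical name; indices are 0-based.
long long countFramesTwoMatrices(const std::vector<std::string>& grid) {
    const int n = static_cast<int>(grid.size());
    for (const auto& row : grid) assert(static_cast<int>(row.size()) == n);

    using Matrix = std::vector<std::vector<int>>;
    Matrix down(n, std::vector<int>(n, 0)), right = down;
    for (int i = n - 1; i >= 0; --i)
        for (int j = n - 1; j >= 0; --j) {
            if (grid[i][j] != '1') continue;
            down[i][j] = 1 + (i + 1 < n ? down[i + 1][j] : 0);
            right[i][j] = 1 + (j + 1 < n ? right[i][j + 1] : 0);
        }

    long long frames = 0;
    for (int r = 0; r < n; ++r)
        for (int c = 0; c < n; ++c) {
            // Top border: right run at (r, c). Left border: down run at (r, c).
            const int maxSide = std::min(right[r][c], down[r][c]);
            for (int s = 1; s <= maxSide; ++s) {
                // Bottom border: right run from the bottom-left corner.
                // Right border: down run from the top-right corner.
                if (right[r + s - 1][c] >= s && down[r][c + s - 1] >= s) ++frames;
            }
        }
    return frames;
}
```

On the grid with rows 110, 111, 011 this counts 9 frames: seven of size 1 and two of size 2.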
Bonus
Solve this problem in O(N^2*log(N)). The authors originally intended to disallow O(N^3) solutions, since their model solution runs in O(N^2*log(N)), although with a sizeable constant factor. However, an O(N^3) solution with some optimizations was able to beat the model solution, so they decided to allow O(N^3) solutions.
Hint
The solution processes each diagonal in O(N*log(N)) using Fenwick/Segment tree and there are 2*N-1 diagonals.
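The hint can be sketched as follows. This is one way to realise it under stated assumptions, not the model solution itself: the names Fenwick, countOnDiagonal and countFramesFast are made up for this example, and indices are 0-based. On one diagonal, let a[p] be min(R, D) at the p-th cell (how far it can serve as a top-left corner) and b[q] be min(L, U) at the q-th cell (how far it can serve as a bottom-right corner). A pair p <= q forms a frame when q - p < a[p] and q - p < b[q]. Sweeping q from top to bottom, a Fenwick tree keeps the positions p that can still reach q, and a range-sum query counts those within distance b[q].

```cpp
#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <string>
#include <vector>

// Fenwick (binary indexed) tree over positions 0..n-1 of one diagonal.
struct Fenwick {
    std::vector<int> tree;
    explicit Fenwick(int n) : tree(n + 1, 0) {}
    void add(int pos, int delta) {
        for (int i = pos + 1; i < static_cast<int>(tree.size()); i += i & -i) tree[i] += delta;
    }
    int prefix(int count) const {  // sum over positions 0..count-1
        int sum = 0;
        for (int i = count; i > 0; i -= i & -i) sum += tree[i];
        return sum;
    }
};

// Counts pairs p <= q with q - p < a[p] and q - p < b[q].
long long countOnDiagonal(const std::vector<int>& a, const std::vector<int>& b) {
    const int len = static_cast<int>(a.size());
    assert(static_cast<int>(b.size()) == len);
    Fenwick active(len);
    std::vector<std::vector<int>> expires(len + 1);
    long long pairs = 0;
    for (int q = 0; q < len; ++q) {
        for (int p : expires[q]) active.add(p, -1);  // p + a[p] == q: p no longer reaches q
        if (a[q] == 0) continue;                     // this cell cannot be a corner
        active.add(q, +1);
        expires[std::min(len, q + a[q])].push_back(q);
        const int lo = std::max(0, q - b[q] + 1);    // smallest p allowed by b[q]
        pairs += active.prefix(q + 1) - active.prefix(lo);
    }
    return pairs;
}

long long countFramesFast(const std::vector<std::string>& grid) {
    const int n = static_cast<int>(grid.size());
    for (const auto& row : grid) assert(static_cast<int>(row.size()) == n);
    using Matrix = std::vector<std::vector<int>>;
    Matrix up(n, std::vector<int>(n, 0)), left = up, down = up, right = up;
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j) {
            if (grid[i][j] != '1') continue;
            up[i][j] = 1 + (i > 0 ? up[i - 1][j] : 0);
            left[i][j] = 1 + (j > 0 ? left[i][j - 1] : 0);
        }
    for (int i = n - 1; i >= 0; --i)
        for (int j = n - 1; j >= 0; --j) {
            if (grid[i][j] != '1') continue;
            down[i][j] = 1 + (i + 1 < n ? down[i + 1][j] : 0);
            right[i][j] = 1 + (j + 1 < n ? right[i][j + 1] : 0);
        }
    long long frames = 0;
    for (int k = -(n - 1); k <= n - 1; ++k) {  // diagonal with column - row == k
        const int r0 = std::max(0, -k), c0 = std::max(0, k), len = n - std::abs(k);
        std::vector<int> a(len), b(len);
        for (int p = 0; p < len; ++p) {
            a[p] = std::min(right[r0 + p][c0 + p], down[r0 + p][c0 + p]);
            b[p] = std::min(left[r0 + p][c0 + p], up[r0 + p][c0 + p]);
        }
        frames += countOnDiagonal(a, b);
    }
    return frames;
}
```

Each diagonal costs O(len * log(len)) because every position is added once, removed at most once and queried once, which gives O(N^2*log(N)) over the 2*N-1 diagonals.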
TIME COMPLEXITY
The time complexity is O(N^3) per test case.
SOLUTIONS
Setter's Solution
#include <bits/stdc++.h>
using namespace std;
#define pb push_back
#define S second
#define F first
#define f(i,n) for(int i=0;i<n;i++)
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define vi vector<int>
#define pii pair<int,int>
#define all(x) x.begin(),x.end()
const int MOD = 1e9+7;
int mod_pow(int a,int b,int M = MOD)
{
if(a == 0) return 0;
b %= (M - 1); //M must be prime here
int res = 1;
while(b > 0)
{
if(b&1) res=(res*a)%M;
a=(a*a)%M;
b>>=1;
}
return res;
}
const int N = 2000 + 10;
string s[N];
int U[N][N],L[N][N],R[N][N],D[N][N];
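void solve()
{
    int n;
    cin >> n;
    f(i,n) cin >> s[i];
    // Every '1' cell starts with a run of length 1 in each direction.
    f(i,n) f(j,n)
        U[i][j] = D[i][j] = L[i][j] = R[i][j] = (s[i][j] == '1');
    // U and L extend the run of the cell above / to the left: scan forwards.
    f(i,n) f(j,n)
    {
        if(s[i][j] == '0') continue;
        if(i > 0) U[i][j] = U[i-1][j] + 1;
        if(j > 0) L[i][j] = L[i][j-1] + 1;
    }
    // D and R extend the run of the cell below / to the right: scan backwards.
    for(int i=n-1;i>=0;i--)
        for(int j=n-1;j>=0;j--)
        {
            if(s[i][j] == '0') continue;
            if(i != n-1) D[i][j] = D[i+1][j] + 1;
            if(j != n-1) R[i][j] = R[i][j+1] + 1;
        }
    int res = 0;
    // Each anti-diagonal (row + column == i) holds the top-right and
    // bottom-left corners of the candidate squares.
    for(int i=0;i<n+n;i++)
    {
        vector<pii> pts;
        for(int j=0;j<n;j++)
        {
            //{j,i-j}
            if(i - j >= 0 && i - j < n)
            {
                pts.pb({j,i-j});
            }
        }
        for(auto x : pts)
            for(auto y : pts)
                if(x.F <= y.F)
                {
                    // x is the top-right corner, y is the bottom-left corner.
                    int r1 = min(L[x.F][x.S],D[x.F][x.S]);
                    int r2 = min(U[y.F][y.S],R[y.F][y.S]);
                    int sz = y.F - x.F + 1;
                    if(r1 >= sz && r2 >= sz) res++;
                }
    }
    cout << res << '\n';
}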
signed main()
{
fast;
int t = 1;
cin >> t;
while(t--)
solve();
}
Tester's Solution 1
#include <bits/stdc++.h>
using namespace std;
/*
------------------------Input Checker----------------------------------
*/
long long readInt(long long l,long long r,char endd){
long long x=0;
int cnt=0;
int fi=-1;
bool is_neg=false;
while(true){
char g=getchar();
if(g=='-'){
assert(fi==-1);
is_neg=true;
continue;
}
if('0'<=g && g<='9'){
x*=10;
x+=g-'0';
if(cnt==0){
fi=g-'0';
}
cnt++;
assert(fi!=0 || cnt==1);
assert(fi!=0 || is_neg==false);
assert(!(cnt>19 || ( cnt==19 && fi>1) ));
} else if(g==endd){
if(is_neg){
x= -x;
}
if(!(l <= x && x <= r))
{
cerr << l << ' ' << r << ' ' << x << '\n';
assert(1 == 0);
}
return x;
} else {
assert(false);
}
}
}
string readString(int l,int r,char endd){
string ret="";
int cnt=0;
while(true){
char g=getchar();
assert(g!=-1);
if(g==endd){
break;
}
cnt++;
ret+=g;
}
assert(l<=cnt && cnt<=r);
return ret;
}
long long readIntSp(long long l,long long r){
return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
return readString(l,r,'\n');
}
string readStringSp(int l,int r){
return readString(l,r,' ');
}
/*
------------------------Main code starts here----------------------------------
*/
const int MAX_T = 100000;
const int MAX_N = 1e6+5;
const int MAX_SUM_LEN = 1e5;
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define ff first
#define ss second
#define mp make_pair
#define ll long long
int sum_len = 0;
int max_n = 0;
int yess = 0;
int nos = 0;
int total_ops = 0;
struct fentree{
// 0 based indexing
vector<int> v;
int _n;
fentree(int n){
v.assign(n+5,0);
_n = n+5;
}
void upd(int pos, int val){
while(pos<_n){
v[pos]+=val;
pos|=(pos+1);
}
}
int qr(int pos){
int ret = 0;
while(pos>=0){
ret += v[pos];
pos&=(pos+1);
pos--;
}
return ret;
}
// int bitSearch(int sum){
// int ret = -1;
// rev(i, 21){
// if(ret+(1<<i)>=_n) continue;
// if(v[ret+(1<<i)]>=sum) continue;
// else{
// ret += (1<<i);
// sum -= v[ret];
// }
// }
// return ret+1;
// }
};
void solve()
{
    int n;
    n = readIntLn(1, 1000);
    max_n += n*n;
    sum_len += n;
    string s[n];
    for(int i=0; i<n; i++) s[i] = readStringLn(n,n);
    // z[i][j][d]: run of '1's starting at (i, j); d = 0 right, 1 down, 2 left, 3 up.
    vector<vector<vector<int> > > z(n, vector<vector<int> >(n, vector<int>(4)));
    for(int i=0; i<n; i++){
        z[i][n-1][0] = (s[i][n-1]=='0'?0:1);
        for(int j=n-2; j>=0; j--){
            z[i][j][0] = (s[i][j]=='1'?z[i][j+1][0]+1:0);
        }
    }
    for(int i=0; i<n; i++){
        z[i][0][2] = (s[i][0]=='0'?0:1);
        for(int j=1; j<n; j++){
            z[i][j][2] = (s[i][j]=='1'?z[i][j-1][2]+1:0);
        }
    }
    for(int j=0; j<n; j++){
        z[0][j][3] = (s[0][j]=='0'?0:1);
        for(int i=1; i<n; i++){
            z[i][j][3] = (s[i][j]=='1'?z[i-1][j][3]+1:0);
        }
    }
    for(int j=0; j<n; j++){
        z[n-1][j][1] = (s[n-1][j]=='0'?0:1);
        for(int i=n-2; i>=0; i--){
            z[i][j][1] = (s[i][j]=='1'?z[i+1][j][1]+1:0);
        }
    }
    long long ans = 0;
    // Diagonals starting in the top row at (0, j); cells are indexed by column r.
    for(int j=0; j<n; j++){
        struct fentree ft(n);
        vector<vector<int> > dlt(n+2);
        int l=0, r=j;
        while(r<n){
            // Drop top-left corners whose reach ends before this cell.
            for(auto h:dlt[r]){
                ft.upd(h, -1);
            }
            if(s[l][r]=='0'){
                l++;
                r++;
                continue;
            }
            ft.upd(r, 1);
            dlt[r+min(z[l][r][0], z[l][r][1])].push_back(r);
            // Count active top-left corners within reach of this bottom-right corner.
            int len = min(z[l][r][2], z[l][r][3]);
            ans += ft.qr(r)-(r-len>=0?ft.qr(r-len):0);
            l++;
            r++;
        }
    }
    // Diagonals starting in the left column at (j, 0); cells are indexed by row l.
    for(int j=1; j<n; j++){
        struct fentree ft(n);
        vector<vector<int> > dlt(n+2);
        int l =j, r=0;
        while(l<n){
            for(auto h:dlt[l]){
                ft.upd(h, -1);
            }
            if(s[l][r]=='0'){
                l++;
                r++;
                continue;
            }
            ft.upd(l, 1);
            dlt[l+min(z[l][r][0], z[l][r][1])].push_back(l);
            int len = min(z[l][r][2], z[l][r][3]);
            ans += ft.qr(l)-(l-len>=0?ft.qr(l-len):0);
            l++;
            r++;
        }
    }
    cout<<ans<<'\n';
}
signed main()
{
#ifndef ONLINE_JUDGE
freopen("input.txt", "r" , stdin);
freopen("output.txt", "w" , stdout);
#endif
fast;
int t = 1;
t = readIntLn(1,MAX_T);
for(int i=1;i<=t;i++)
{
solve();
}
assert(getchar() == -1);
cerr<<"SUCCESS\n";
cerr<<"Tests : " << t << '\n';
cerr<<"Sum of lengths : " << sum_len << '\n';
// cerr<<"Maximum length : " << max_n << '\n';
// cerr<<"Total operations : " << total_ops << '\n';
//cerr<<"Answered yes : " << yess << '\n';
//cerr<<"Answered no : " << nos << '\n';
}
Tester's Solution 2
#include <bits/stdc++.h>
using namespace std;
/*
------------------------Input Checker----------------------------------
*/
long long readInt(long long l,long long r,char endd){
long long x=0;
int cnt=0;
int fi=-1;
bool is_neg=false;
while(true){
char g=getchar();
if(g=='-'){
assert(fi==-1);
is_neg=true;
continue;
}
if('0'<=g && g<='9'){
x*=10;
x+=g-'0';
if(cnt==0){
fi=g-'0';
}
cnt++;
assert(fi!=0 || cnt==1);
assert(fi!=0 || is_neg==false);
assert(!(cnt>19 || ( cnt==19 && fi>1) ));
} else if(g==endd){
if(is_neg){
x= -x;
}
if(!(l <= x && x <= r))
{
cerr << l << ' ' << r << ' ' << x << '\n';
assert(1 == 0);
}
return x;
} else {
assert(false);
}
}
}
string readString(int l,int r,char endd){
string ret="";
int cnt=0;
while(true){
char g=getchar();
assert(g!=-1);
if(g==endd){
break;
}
cnt++;
ret+=g;
}
assert(l<=cnt && cnt<=r);
return ret;
}
long long readIntSp(long long l,long long r){
return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
return readString(l,r,'\n');
}
string readStringSp(int l,int r){
return readString(l,r,' ');
}
/*
------------------------Main code starts here----------------------------------
*/
const int MAX_T = 100000;
const int MAX_N = 1e6+5;
const int MAX_SUM_LEN = 1e5;
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define ff first
#define ss second
#define mp make_pair
#define ll long long
int sum_len = 0;
int max_n = 0;
int yess = 0;
int nos = 0;
int total_ops = 0;
void solve()
{
    int n;
    // n = readIntLn(1, 1000);
    cin>>n;
    max_n += n*n;
    sum_len += n;
    string s[n];
    for(int i=0; i<n; i++) cin>>s[i];
    // vector<vector<vector<int> > > z(n, vector<vector<int> >(n, vector<int>(4,0)));
    // z[i][j][d]: run of '1's starting at (i, j); d = 0 right, 1 down, 2 left, 3 up.
    int z[n][n][4];
    for(int i=0; i<n; i++){
        z[i][n-1][0] = (s[i][n-1]=='0'?0:1);
        for(int j=n-2; j>=0; j--){
            z[i][j][0] = (s[i][j]=='1'?z[i][j+1][0]+1:0);
        }
    }
    for(int i=0; i<n; i++){
        z[i][0][2] = (s[i][0]=='0'?0:1);
        for(int j=1; j<n; j++){
            z[i][j][2] = (s[i][j]=='1'?z[i][j-1][2]+1:0);
        }
    }
    for(int j=0; j<n; j++){
        z[0][j][3] = (s[0][j]=='0'?0:1);
        for(int i=1; i<n; i++){
            z[i][j][3] = (s[i][j]=='1'?z[i-1][j][3]+1:0);
        }
    }
    for(int j=0; j<n; j++){
        z[n-1][j][1] = (s[n-1][j]=='0'?0:1);
        for(int i=n-2; i>=0; i--){
            z[i][j][1] = (s[i][j]=='1'?z[i+1][j][1]+1:0);
        }
    }
    // maxi[..][0]: largest side with (i, j) as top-left corner;
    // maxi[..][1]: largest side with (i, j) as bottom-right corner.
    int maxi[n][n][2] ;
    for(int i = 0 ; i < n ; i++)
    {
        for(int j = 0 ; j < n ; j++)
        {
            maxi[i][j][0] = min(z[i][j][0] , z[i][j][1]);
            maxi[i][j][1] = min(z[i][j][2] , z[i][j][3]) ;
        }
    }
    int ans = 0 ;
    // Diagonals starting in the left column: fix the bottom-right corner (x, y)
    // and walk up the diagonal over the reachable top-left corners.
    for(int i = 0 ; i < n ; i++)
    {
        int x = i , y = 0 ;
        for(; x < n ; x++ , y++)
        {
            for(int xd = x , yd = y, xlim = max(-1 , x-maxi[x][y][1]); xd > xlim; xd-- , yd--)
            {
                if(maxi[xd][yd][0] > (x-xd))
                    ans++ ;
            }
        }
    }
    // Diagonals starting in the top row, handled the same way.
    for(int i = 1 ; i < n ; i++)
    {
        int y = i , x = 0 ;
        for(; y < n ; x++ , y++)
        {
            for(int xd = x , yd = y, ylim = max(-1 , y-maxi[x][y][1]); yd > ylim ; xd-- , yd--)
            {
                if(maxi[xd][yd][0] > (x-xd))
                    ans++ ;
            }
        }
    }
    cout << ans << '\n' ;
    return ;
}
signed main()
{
#ifndef ONLINE_JUDGE
freopen("inputf.txt", "r" , stdin);
freopen("outputf.txt", "w" , stdout);
#endif
fast;
int t = 1;
// t = readIntLn(1,MAX_T);
cin>>t;
for(int i=1;i<=t;i++)
{
solve();
}
assert(getchar() == -1);
cerr<<"SUCCESS\n";
cerr<<"Tests : " << t << '\n';
cerr<<"Sum of lengths : " << sum_len << '\n';
// cerr<<"Maximum length : " << max_n << '\n';
// cerr<<"Total operations : " << total_ops << '\n';
//cerr<<"Answered yes : " << yess << '\n';
//cerr<<"Answered no : " << nos << '\n';
}
Editorialist's Solution
import java.util.*;
import java.io.*;
class GRIDSQRS{
//SOLUTION BEGIN
void pre() throws Exception{}
    void solve(int TC) throws Exception{
        int N = ni();
        boolean[][] g = new boolean[2+N][2+N];
        for(int i = 1; i<= N; i++){
            String s = n();
            for(int j = 1; j<= N; j++)g[i][j] = s.charAt(j-1) == '1';
        }
        //sum[0] = up, sum[1] = left, sum[2] = down, sum[3] = right run lengths
        int[][][] sum = new int[4][2+N][2+N];
        for(int i = 1; i<= N; i++){
            for(int j = 1; j <= N; j++){
                sum[0][i][j] = (g[i][j]?(1+sum[0][i-1][j]):0);
                sum[1][i][j] = (g[i][j]?(1+sum[1][i][j-1]):0);
            }
        }
        for(int i = N; i >= 1; i--){
            for(int j = N; j >= 1; j--){
                sum[2][i][j] = (g[i][j]?(1+sum[2][i+1][j]):0);
                sum[3][i][j] = (g[i][j]?(1+sum[3][i][j+1]):0);
            }
        }
        //f1 = reach as top-left corner, f2 = reach as bottom-right corner
        int[][] f1 = new int[2+N][2+N], f2 = new int[2+N][2+N];
        for(int i = 1; i<= N; i++){
            for(int j = 1; j<= N; j++){
                f1[i][j] = Math.min(sum[2][i][j], sum[3][i][j]);
                f2[i][j] = Math.min(sum[0][i][j], sum[1][i][j]);
            }
        }
        int ans = 0;
        //d = side length - 1; (i+d, j+d) is the bottom-right corner
        for(int i = 1; i<= N; i++)
            for(int j = 1; j <= N; j++)
                for(int d = 0; Math.max(i, j)+d <= N && d < f1[i][j]; d++)
                    if(f1[i][j] > d && f2[i+d][j+d] > d)
                        ans++;
        pn(ans);
    }
//SOLUTION END
void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
static boolean multipleTC = true;
FastReader in;PrintWriter out;
void run() throws Exception{
in = new FastReader();
out = new PrintWriter(System.out);
//Solution Credits: Taranpreet Singh
int T = (multipleTC)?ni():1;
pre();for(int t = 1; t<= T; t++)solve(t);
out.flush();
out.close();
}
public static void main(String[] args) throws Exception{
new GRIDSQRS().run();
}
int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
void p(Object o){out.print(o);}
void pn(Object o){out.println(o);}
void pni(Object o){out.println(o);out.flush();}
String n()throws Exception{return in.next();}
String nln()throws Exception{return in.nextLine();}
int ni()throws Exception{return Integer.parseInt(in.next());}
long nl()throws Exception{return Long.parseLong(in.next());}
double nd()throws Exception{return Double.parseDouble(in.next());}
class FastReader{
BufferedReader br;
StringTokenizer st;
public FastReader(){
br = new BufferedReader(new InputStreamReader(System.in));
}
public FastReader(String s) throws Exception{
br = new BufferedReader(new FileReader(s));
}
String next() throws Exception{
while (st == null || !st.hasMoreElements()){
try{
st = new StringTokenizer(br.readLine());
}catch (IOException e){
throw new Exception(e.toString());
}
}
return st.nextToken();
}
String nextLine() throws Exception{
String str = "";
try{
str = br.readLine();
}catch (IOException e){
throw new Exception(e.toString());
}
return str;
}
}
}
Feel free to share your approach. Suggestions are welcomed as always.