TRPTSTIC - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Authors: d_k_7386
Tester: tabr
Editorialist: iceknight1093

DIFFICULTY:

2095

PREREQUISITES:

Binary search, 2D prefix sums

PROBLEM:

K students and one mentor want to stay in a hotel that has N\times M rooms, arranged in N rows and M columns.
The room at the intersection of the i-th row and j-th column can accommodate A_{i, j} people, and the mentor can stay in the same room as a student.

The distance between two rooms (x_1, y_1) and (x_2, y_2) is \max(|x_1-x_2|, |y_1-y_2|).
Find the minimum possible distance between the mentor’s room and the farthest student’s room.

EXPLANATION:

First, if the sum of all A_{i, j} is \leq K, then it’s impossible to accommodate K+1 people in the hotel, so the answer is -1.
Otherwise, a valid answer always exists.

Suppose we fix which room the mentor is staying in, say (x, y).
Note that this is only possible when A_{x, y} \neq 0.
Suppose we also fix the maximum allowed distance D between the mentor and a student.

Notice that with these two constraints, the set of cells where students are allowed to stay forms a rectangular subgrid of A, specifically,

  • We want all cells (i, j) such that \max(|x-i|, |y-j|) \leq D. This means |x-i|\leq D and |y-j|\leq D.
  • From the definition of absolute value, this means
    • -D \leq x-i \leq D
    • -D\leq y-j \leq D.
  • Rearrange this to
    • x-D \leq i \leq x+D
    • y-D\leq j \leq y+D.

This gives us a range of i and j, forming the rectangle [x-D, x+D]\times [y-D, y+D].
In particular, if the rooms in this rectangle can hold at least K+1 people in total, then the mentor can stay at (x, y) with every student within distance D of them.
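The rectangle derived above may stick out past the edges of the hotel, so in practice it is clipped to the grid. Here is a minimal sketch of that clipping, assuming 1-indexed rows and columns; the function names `student_rectangle` and `chebyshev` are illustrative and do not come from the problem statement.

```python
def chebyshev(x1, y1, x2, y2):
    """Distance used in the problem: max(|x1 - x2|, |y1 - y2|)."""
    return max(abs(x1 - x2), abs(y1 - y2))


def student_rectangle(x, y, d, n, m):
    """Cells within distance d of (x, y), clipped to an n x m grid.

    Returns (top, left, bottom, right), all 1-indexed and inclusive.
    """
    top = max(1, x - d)
    bottom = min(n, x + d)
    left = max(1, y - d)
    right = min(m, y + d)
    return top, left, bottom, right


if __name__ == "__main__":
    # Mentor at row 2, column 3 of a 4 x 5 hotel, maximum distance 1.
    print(student_rectangle(2, 3, 1, 4, 5))
```

Clipping does not change which rooms are counted, since cells outside the grid do not exist.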

Checking whether K+1 people fit into this rectangle is simple to do in \mathcal{O}(1) after some precomputation.
Notice that we only want the sum of all values in the specified rectangle. This is doable with 2D prefix sums: a tutorial can be found here.
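For readers who have not used 2D prefix sums before, here is a small sketch of the precomputation and the \mathcal{O}(1) rectangle query. The helper names `build_prefix` and `rect_sum` are made up for this illustration; the grid is passed as a 0-indexed list of lists, while queries use 1-indexed inclusive bounds.

```python
def build_prefix(grid):
    """pref[i][j] = sum of grid over the first i rows and first j columns."""
    n, m = len(grid), len(grid[0])
    pref = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            pref[i][j] = (grid[i - 1][j - 1]
                          + pref[i - 1][j]
                          + pref[i][j - 1]
                          - pref[i - 1][j - 1])
    return pref


def rect_sum(pref, top, left, bottom, right):
    """Sum over rows top..bottom and columns left..right (1-indexed, inclusive)."""
    return (pref[bottom][right]
            - pref[top - 1][right]
            - pref[bottom][left - 1]
            + pref[top - 1][left - 1])


if __name__ == "__main__":
    grid = [[1, 2, 3],
            [4, 5, 6]]
    pref = build_prefix(grid)
    print(rect_sum(pref, 1, 1, 2, 3))  # whole grid: 21
    print(rect_sum(pref, 2, 2, 2, 3))  # 5 + 6 = 11
```

The extra row and column of zeros avoid special cases when the rectangle touches the top or left edge.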


We are now able to quickly check, for a fixed (x, y) and D, whether a maximum distance of \leq D is possible.
However, there are N\times M possible cells (x, y) and up to \max(N, M) values of D for each of them, so going through them all would still be too slow.

Notice, though, that if we’re able to achieve a maximum distance of \leq D, then of course we can achieve a maximum distance of \leq D+1.
So, we only need to find the smallest D such that there exists some cell (x, y) which satisfies the condition.

This is exactly what binary search does!
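As a reminder of the pattern, here is a generic sketch of binary search for the smallest value that satisfies a monotone condition. The name `first_true` and its interface are assumptions made for this illustration; it requires that the condition is false up to some point and true from then on, and that it holds at the upper bound.

```python
def first_true(lo, hi, pred):
    """Smallest d in [lo, hi] with pred(d) true.

    Assumes pred is monotone (once true, it stays true) and pred(hi) is true.
    """
    while lo < hi:
        mid = (lo + hi) // 2
        if pred(mid):
            hi = mid        # mid works, so the answer is mid or smaller
        else:
            lo = mid + 1    # mid fails, so the answer is strictly larger
    return lo


if __name__ == "__main__":
    # Toy monotone condition: "d is at least 7".
    print(first_true(0, 10, lambda d: d >= 7))  # 7
```

In this problem, the condition for a given D would be "some cell (x, y) with A_{x, y} \neq 0 has a rectangle holding at least K+1 people".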

That gives us the following solution:

  • Binary search on the value of D, from 0 to \max(N, M).
  • For a fixed value of D, go through all cells (x, y) such that A_{x, y}\neq 0, and check whether any of them allow for a maximum distance of \leq D, using 2D prefix sums as discussed above.

For a fixed value of D, this takes \mathcal{O}(NM) time.
Since we’re applying binary search, we check only \mathcal{O}(\log\max(N, M)) values of D, for a solution that’s \mathcal{O}(NM\log\max(N, M)).
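Putting the pieces together, the whole approach can be sketched as a single function. This is only an illustration of the outline above, not a replacement for the reference solutions; `min_max_distance` and `feasible` are names chosen here, and input parsing is left out.

```python
def min_max_distance(grid, k):
    """Smallest D such that some mentor cell works, or -1 if k + 1 people never fit.

    grid is a 0-indexed list of lists of room capacities; k is the number of students.
    """
    n, m = len(grid), len(grid[0])

    # 2D prefix sums with a padding row and column of zeros.
    pref = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            pref[i][j] = (grid[i - 1][j - 1] + pref[i - 1][j]
                          + pref[i][j - 1] - pref[i - 1][j - 1])

    if pref[n][m] <= k:
        return -1  # total capacity cannot hold k students plus the mentor

    def feasible(d):
        """True if some non-empty cell has at least k + 1 capacity within distance d."""
        for i in range(1, n + 1):
            for j in range(1, m + 1):
                if grid[i - 1][j - 1] == 0:
                    continue  # the mentor needs a room with capacity
                top, bottom = max(1, i - d), min(n, i + d)
                left, right = max(1, j - d), min(m, j + d)
                total = (pref[bottom][right] - pref[top - 1][right]
                         - pref[bottom][left - 1] + pref[top - 1][left - 1])
                if total > k:
                    return True
        return False

    # feasible(max(n, m)) holds here because the rectangle then covers the whole grid.
    lo, hi = 0, max(n, m)
    while lo < hi:
        mid = (lo + hi) // 2
        if feasible(mid):
            hi = mid
        else:
            lo = mid + 1
    return lo


if __name__ == "__main__":
    print(min_max_distance([[1, 1]], 1))  # 1
```

Returning early from `feasible` as soon as one cell works does not change the worst-case bound, but it tends to help on easy inputs.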

TIME COMPLEXITY:

\mathcal{O}(NM\log\max(N, M)) per test case.

CODE:

Setter's code (C++)
#define ll long long int
#include<bits/stdc++.h>
#define loop(i,a,b) for(ll i=a;i<b;++i)
#define rloop(i,a,b) for(ll i=a;i>=b;i--)
#define in(a,n) for(ll i=0;i<n;++i) cin>>a[i];
#define pb push_back
#define mk make_pair
#define all(v) v.begin(),v.end()
#define dis(v) for(auto i:v)cout<<i<<" ";cout<<endl;
#define display(arr,n) for(int i=0; i<n; i++)cout<<arr[i]<<" ";cout<<endl;
#define fast ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL);srand(time(NULL));
#define l(a) a.length()
#define s(a) (ll)a.size()
#define fr first
#define sc second
#define mod 1000000007
#define endl '\n'
#define yes cout<<"Yes"<<endl;
#define no cout<<"No"<<endl;
using namespace std;
#define debug(x) cerr << #x<<" "; _print(x); cerr << endl;
void _print(ll t) {cerr << t;}
void _print(int t) {cerr << t;}
void _print(string t) {cerr << t;}
void _print(char t) {cerr << t;}
void _print(double t) {cerr << t;}
template <class T, class V> void _print(pair <T, V> p);
template <class T> void _print(vector <T> v);
template <class T> void _print(set <T> v);
template <class T, class V> void _print(map <T, V> v);
template <class T> void _print(multiset <T> v);
template <class T, class V> void _print(pair <T, V> p) {cerr << "{"; _print(p.fr); cerr << ","; _print(p.sc); cerr << "}";}
template <class T> void _print(vector <T> v) {cerr << "[ "; for (T i : v) {_print(i); cerr << " ";} cerr << "]";}
template <class T> void _print(set <T> v) {cerr << "[ "; for (T i : v) {_print(i); cerr << " ";} cerr << "]";}
template <class T> void _print(multiset <T> v) {cerr << "[ "; for (T i : v) {_print(i); cerr << " ";} cerr << "]";}
template <class T, class V> void _print(map <T, V> v) {cerr << "[ "; for (auto i : v) {_print(i); cerr << " ";} cerr << "]";}

ll add(ll x,ll y)  {ll ans = x+y; return (ans>=mod ? ans - mod : ans);}
ll sub(ll x,ll y)  {ll ans = x-y; return (ans<0 ? ans + mod : ans);}
ll mul(ll x,ll y)  {ll ans = x*y; return (ans>=mod ? ans % mod : ans);}


void solve(){
    ll n,m,k;   cin>>n>>m>>k;
    vector<vector<ll>> v(n,vector<ll>(m));
    loop(i,0,n) loop(j,0,m) cin>>v[i][j];
    vector<vector<ll>> vec = v;
    loop(i,1,n) v[i][0]+=v[i-1][0];
    loop(j,1,m) v[0][j]+=v[0][j-1];
    loop(i,1,n) loop(j,1,m) v[i][j]+=v[i-1][j]+v[i][j-1]-v[i-1][j-1];
    ll ans = INT_MAX;
    loop(i,0,n){
        loop(j,0,m){
            if(!vec[i][j])  continue;
            ll l = 0,r = max(n,m);
            while(l<=r){
                ll mid = (l+r)/2;
                ll x = min(n-1,i+mid),y = min(m-1,j+mid);
                ll sum = v[x][y];
                if(i-mid>0) sum-=v[i-mid-1][y];
                if(j-mid>0) sum-=v[x][j-mid-1];
                if(i-mid>0 && j-mid > 0)    sum+=v[i-mid-1][j-mid-1];
                if(sum>=k+1)    r = mid-1,ans = min(ans,mid);
                else l = mid+1;
            }
        }
    }
    if(ans == INT_MAX)  ans = -1;
    cout<<ans<<endl;
}


int main()
{
    fast
    int t; cin>>t;
    
    while(t--) solve();
    return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

int main() {
    input_checker in;
    int tt = in.readInt(1, 1e5);
    in.readEoln();
    int snm = 0;
    while (tt--) {
        int n = in.readInt(1, 1e6);
        in.readSpace();
        int m = in.readInt(1, 1e6);
        in.readSpace();
        int k = in.readInt(1, 1e9);
        in.readEoln();
        snm += n * m;
        vector<vector<int>> a(n);
        for (int i = 0; i < n; i++) {
            a[i] = in.readInts(m, 0, 1e5);
            in.readEoln();
        }
        {
            long long s = 0;
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < m; j++) {
                    s += a[i][j];
                }
            }
            if (s < k + 1) {
                cout << -1 << '\n';
                continue;
            }
        }
        vector<vector<long long>> b(n + 1, vector<long long>(m + 1));
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                b[i + 1][j + 1] = b[i + 1][j] + b[i][j + 1] - b[i][j] + a[i][j];
            }
        }
        int low = -1, high = n + m;
        while (high - low > 1) {
            int mid = (high + low) >> 1;
            int ok = 0;
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < m; j++) {
                    if (a[i][j]) {
                        int i0 = min(i + mid + 1, n);
                        int j0 = min(j + mid + 1, m);
                        int i1 = max(i - mid, 0);
                        int j1 = max(j - mid, 0);
                        if (b[i0][j0] - b[i0][j1] - b[i1][j0] + b[i1][j1] > k) {
                            ok = 1;
                        }
                    }
                }
            }
            if (ok) {
                high = mid;
            } else {
                low = mid;
            }
        }
        cout << high << '\n';
    }
    assert(snm <= 1e6);
    in.readEof();
    return 0;
}
Editorialist's code (Python)
for _ in range(int(input())):
	n, m, k = map(int, input().split())
	grid = [ [0 for i in range(m+1) ] ]
	for i in range(1, n+1): grid.append([0] + list(map(int, input().split())))
	
	pref = [ [0 for i in range(m+1)] for j in range(n+1)]
	for i in range(1, n+1):
		for j in range(1, m+1):
			pref[i][j] = pref[i-1][j] + pref[i][j-1] - pref[i-1][j-1] + grid[i][j]
	if pref[n][m] <= k:
		print(-1)
		continue
	
	def getsum(l1, r1, l2, r2):
		return pref[l2][r2] - pref[l1-1][r2] - pref[l2][r1-1] + pref[l1-1][r1-1]
	
	lo, hi = -1, n+m+1
	while lo+1 < hi:
		mid = (lo + hi)//2
		mxsum = 0
		for i in range(1, n+1):
			for j in range(1, m+1):
				if grid[i][j] == 0: continue
				l1 = max(1, i-mid)
				l2 = min(n, i+mid)
				r1 = max(1, j-mid)
				r2 = min(m, j+mid)
				mxsum = max(mxsum, getsum(l1, r1, l2, r2))
		if mxsum <= k: lo = mid
		else: hi = mid
	print(hi)
def func(b,dp,k,arr):
    n=len(arr)
    m=len(arr[0])
    for i in range(n):
        for j in range(m):
            if arr[i][j]==0:
                continue 
            i1=min(n-1,i+b)
            j1=min(m-1,j+b)
            s=dp[i1+1][j1+1]-dp[max(0,i1+1-(2*b+1))][j1+1]-dp[i1+1][max(0,j1+1-(2*b+1))]+dp[max(0,i1+1-(2*b+1))][max(0,j1+1-(2*b+1))]
            if s>=k:
                return 1 
    return 0
    




for _ in range(int(input())):
    n,m,k=map(int,input().split())
    arr=[]
    for i in range(n):
        arr.append(list(map(int,input().split())))
    dp=[[0]*(m+1) for i in range(n+1)]
    for i in range(1,n+1):
        for j in range(1,m+1):
            dp[i][j]=arr[i-1][j-1]+dp[i-1][j]+dp[i][j-1]-dp[i-1][j-1]
    
    if dp[n][m]<(k+1):
        print(-1)
        continue
    l=0
    h=max(n,m)-1
    while(l<h):
        if h==l+1:
            if func(l,dp,k+1,arr):
                break 
            else:
                l=h 
                break
        mid=(l+h)//2 
        if func(mid,dp,k+1,arr):
            h=mid
        else:
            l=mid+1 
    print(l)

Need help with this code, I am getting TLE for some cases!!

Submit using pypy instead of python, the same code runs much faster. Here’s your code in pypy.

It gets WA now though.

Got it!! Thanks!

Can someone tell me whats wrong in this code ? I have been trying all day

#include <bits/stdc++.h>
#define ll long long
using namespace std;

void solve()
{
    int n, m; ll k; cin>>n>>m>>k;
    ll idk = 0;
    vector<vector<int>>A(n, vector<int>(m));
    for(int i = 0; i<n; i++){
        for(int j = 0; j<m; j++){
            cin>>A[i][j]; //input
            idk+=A[i][j];
        }
    }

    if(idk<=k){ //edge case
        cout<<-1<<endl; return; 
    }

    vector<vector<ll>>pref(n, vector<ll>(m));
    for(int i = 0; i<n; i++){ //2d prefix sum
        for(int j = 0; j<m; j++){
            if(i == 0 and j == 0){
                pref[i][j] = A[i][j];
            }else if(i == 0){
                pref[i][j] = A[i][j] + pref[0][j-1];
            }else if(j == 0){
                pref[i][j] = A[i][j] + pref[i-1][0];
            }else{
                pref[i][j] =  A[i][j] + pref[i-1][j] + pref[i][j-1] - pref[i-1][j-1];
            }
        }
    }
    
    //     for(int i = 0; i<n; i++){
    //     for(int j = 0; j<m; j++){
    //         cout<<pref[i][j]<<" "; //input
    //     }cout<<endl;
    // }

    int ans = 1e9;

    for(int i = 0; i<n; i++){ //traverse 
        for(int j = 0; j<m; j++){
            if(A[i][j] == 0){   continue;   }
            ll l = 0, r = max(n, m);
            while(l<=r){ //do binary search for minMaxDist
                int mid = (l+r)/2;
                int x = min(i+mid, n-1); int y = min(j+mid, m-1);
                //rectangle from i-mid to i+mid
                //j-mid to j+mid
                int sum = pref[x][y];
                if(i - mid >=1){
                    sum-=pref[i-mid-1][y];
                }
                if(j-mid >=1){
                    sum-=pref[x][j-mid-1];
                }
                if(i-mid>=1 and j-mid>=1){
                    sum += pref[i-mid-1][j-mid-1];
                }
                if(sum >= k+1){
                    ans = min(ans, mid);
                    r = mid-1;
                }else{
                    l = mid+1;
                }
            }
        }
    }

    cout<<ans<<endl;


}


int main()
{
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    int t;
    cin >> t;
    while (t--)
    {
        solve();
    }
    return 0;
}

#include <bits/stdc++.h>
using namespace std;
void solve(){
int n,m,k;
cin>>n>>m>>k;
vector<vector<int>> ar(n,vector<int>(m));
int sum = 0;
for(int i = 0;i<n;i++){
for(int j = 0;j<m;j++){
cin>>ar[i][j];
sum += ar[i][j];
}
}
if(sum < k+1){
cout<<-1<<endl;
return;
}
vector<vector<int>> preFix(n+1,vector<int>(m+1));
for(int i = 1;i<n+1;i++){
for(int j = 1;j<m+1;j++){
preFix[i][j] = preFix[i-1][j] + preFix[i][j-1] - preFix[i-1][j-1] + ar[i-1][j-1];
}
}
int low = 0,high = max(n,m);
while(low <= high){
int mid = low + (high - low)/2;
int maxRooms = 0;
for(int i = 0;i<n;i++){
for(int j = 0;j<m;j++){

        if(ar[i][j] == 0) continue;
        
        int i_low = max(i - mid, 0);
        int j_low = max(j - mid, 0);
        int i_high = min(i + mid + 1,n);
        int j_high = min(j + mid + 1,m);
        
       int totalRooms = preFix[i_high][j_high] - preFix[i_low][j_high] - preFix[i_high][j_low] + preFix[i_low][j_low];
       maxRooms = max(maxRooms,totalRooms);
     }
    }
    if(maxRooms >= k+1){
        high = mid-1;
    }
    else{
        low = mid+1;
    }
}
cout<<low<<endl;

}
// 0 1 2 3 4 5 6
// 0 0 0 0 0 0 0 0
// 1 0 0 0 0 0 0 0
// 2 0 0 0 0 0 0 0
// 3 0 0 0 0 0 0 0
// 4 0 0 0 0 0 0 0
//0 1 2 3 4
//1 0 0 0 0
//2 1 1 5 5
//3 1 3 7 10

int main() {
// your code goes here
int t;
cin>>t;
while(t--){
solve();
}
return 0;
}
why my code fails

int sum = pref[x][y]; should be ll I think.

@dineshbabu3017 probably you have the same error, use long long because the prefix sums won’t fit in the range of int.
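To make the size concrete, here is a rough back-of-the-envelope check. The limits used (at most 10^6 cells in total and A_{i, j} \leq 10^5) are read off the tester's input checker above and are assumed to match the real constraints. The `wrap_int32` helper only imitates the usual two's-complement wraparound; signed overflow in C++ is formally undefined behaviour, so treat this as an illustration only.

```python
INT32_MAX = 2**31 - 1
INT64_MAX = 2**63 - 1

# Limits assumed from the tester's checker: total cells <= 1e6, each A[i][j] <= 1e5.
MAX_CELLS = 10**6
MAX_CAPACITY = 10**5

worst_case_total = MAX_CELLS * MAX_CAPACITY  # largest possible prefix sum


def wrap_int32(value):
    """Imitate storing value in a 32-bit signed int with two's-complement wraparound."""
    return (value + 2**31) % 2**32 - 2**31


if __name__ == "__main__":
    print("worst-case prefix sum:", worst_case_total)
    print("fits in 32-bit int:   ", worst_case_total <= INT32_MAX)
    print("fits in 64-bit int:   ", worst_case_total <= INT64_MAX)
    print("after 32-bit wrap:    ", wrap_int32(worst_case_total))
```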

Also, next time please consider just pasting a link to your code instead of directly dumping the code itself here, it’s much harder to read.

Yes that was it. Its working now. I feel so dumb lol. Thanks a lot!!