HRSCPMTR - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2

Setter: Venkata Nikhil Medam
Tester: Rahul Dugar
Editorialist: Taranpreet Singh

DIFFICULTY

Simple

PREREQUISITES

None

PROBLEM

Given a matrix of integers with N rows and M columns, and Q queries on the matrix, each query replacing one element of the matrix with a specific value, determine after each query whether the matrix is good.

A matrix is considered good if, for each diagonal from top left to the bottom right direction, all elements have the same value.

QUICK EXPLANATION

Maintain a multiset of values for each diagonal. A matrix is good if every multiset contains exactly one distinct value.

EXPLANATION

Let’s consider each diagonal as a separate array.
For example, for matrix

1 2 3 4
5 6 7 8
9 10 11 12

We get the following arrays

9
5 10
1 6 11
2 7 12
3 8
4

For the matrix to be good, all elements in each array should be equal. So we can simply represent these elements using a multiset, which allows fast insertion, removal, and checking if all values in the multiset are the same or not.

Constant memory solution
The above solution was boring, nothing interesting, but there exists a far more elegant solution that solves this problem in O(N*M+Q) time and O(1) memory, excluding the input array.

Click here for solution

Let’s denote count as the number of adjacent pairs in the above arrays such that the values differ. Compute it for the given matrix.

Now, it is easy to recalculate this count after each update, since at most two pairs are affected. The matrix is good if and only if the count is 0 after the update. The proof of why this works is left as an exercise.

TIME COMPLEXITY

Time complexity O(N*M*log(min(N,M))+Q*log(N+M))
Memory complexity O(N*M)

SOLUTIONS

Setter's Solution
// Setter: Nikhil_Medam
#include <bits/stdc++.h>
using namespace std;
 
#define IOS ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
#define endl "\n"
 
// Reads T test cases. For each one it tracks how many cells differ from their
// upper-left neighbour; after every update the matrix is reported as good
// ("YES") exactly when that number is zero, otherwise "NO" is printed.
int32_t main()
{
	ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
	int tests;
	cin >> tests;
	while(tests--)
	{
	    int n, m;
	    cin >> n >> m;
	    vector<vector<int>> a(n, vector<int>(m));
	    for(auto &row : a)
	        for(auto &cell : row)
	            cin >> cell;

	    // Number of cells whose value differs from the cell diagonally above-left.
	    int mismatches = 0;
	    for(int i = 1; i < n; i++)
	        for(int j = 1; j < m; j++)
	            if(a[i][j] != a[i - 1][j - 1])
	                mismatches++;

	    // Mismatching pairs that cell (r, c) forms with its upper-left and
	    // lower-right neighbours; these are the only pairs an update can touch.
	    auto around = [&](int r, int c)
	    {
	        int diff = 0;
	        if(r - 1 >= 0 and c - 1 >= 0 and a[r][c] != a[r - 1][c - 1])
	            diff++;
	        if(r + 1 < n and c + 1 < m and a[r][c] != a[r + 1][c + 1])
	            diff++;
	        return diff;
	    };

	    int q;
	    cin >> q;
	    while(q--)
	    {
	        int x, y, v;
	        cin >> x >> y >> v;
	        x--, y--;                      // input coordinates are 1-based

	        mismatches -= around(x, y);    // drop the old contribution
	        a[x][y] = v;
	        mismatches += around(x, y);    // add the new contribution

	        cout << (mismatches == 0 ? "YES" : "NO") << "\n";
	    }
	}
	return 0;
}
Tester's Solution
#pragma GCC optimize("Ofast")
#include <bits/stdc++.h>
using namespace std;
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/rope>
using namespace __gnu_pbds;
using namespace __gnu_cxx;
#ifndef rd
#define trace(...)
#define endl '\n'
#endif
#define pb push_back
#define fi first
#define se second
#define int long long
typedef long long ll;
typedef long double f80;
#define double long double
#define pii pair<int,int>
#define pll pair<ll,ll>
#define sz(x) ((long long)x.size())
#define fr(a,b,c) for(int a=b; a<=c; a++)
#define rep(a,b,c) for(int a=b; a<c; a++)
#define trav(a,x) for(auto &a:x)
#define all(con) con.begin(),con.end()
const ll infl=0x3f3f3f3f3f3f3f3fLL;
const int infi=0x3f3f3f3f;
const int mod=998244353;
//const int mod=1000000007;
typedef vector<int> vi;
typedef vector<ll> vl;

typedef tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update> oset;
auto clk=clock();
mt19937_64 rang(chrono::high_resolution_clock::now().time_since_epoch().count());
// Draws one value from a uniform integer distribution over [0, lim-1],
// using the file-level generator `rang`.
int rng(int lim) {
	return uniform_int_distribution<int>(0,lim-1)(rang);
}

// Returns a raised to the power b, reduced modulo the file-level constant
// `mod`, using square-and-multiply over the bits of b.
int powm(int a, int b) {
	int acc=1;
	for(; b; b>>=1) {
		if(b&1)
			acc=(acc*a)%mod;   // current bit set: fold this power of a into the result
		a=(a*a)%mod;           // advance to the next power of two
	}
	return acc;
}

// Strict integer reader for input validation. Consumes characters from stdin
// up to and including the terminator `endd` and returns the parsed value.
// Asserts on: any character that is neither '-', a digit nor `endd`; a '-'
// appearing after the first digit; a leading zero followed by more digits;
// "-0"; more than 19 digits (or 19 digits starting above 1); and a value
// outside [l, r].
long long readInt(long long l, long long r, char endd) {
	long long value=0;
	int digits=0;        // number of digits consumed so far
	int lead=-1;         // first digit seen, -1 until one arrives
	bool negative=false;
	for(;;) {
		char ch=getchar();
		if(ch=='-') {
			assert(lead==-1);
			negative=true;
		} else if(ch>='0'&&ch<='9') {
			value*=10;
			value+=ch-'0';
			if(digits==0) {
				lead=ch-'0';
			}
			++digits;
			assert(lead!=0 || digits==1);
			assert(lead!=0 || !negative);

			// Keeps the magnitude within what a long long can hold.
			assert(digits<=19 && (digits!=19 || lead<=1));
		} else if(ch==endd) {
			if(negative) {
				value=-value;
			}
			assert(l<=value&&value<=r);
			return value;
		} else {
			assert(false);
		}
	}
}
// Reads characters from stdin up to (and consuming) the terminator `endd` and
// returns them without the terminator.
// Asserts that the input does not end before `endd` is found and that the
// number of characters read lies in [l, r].
string readString(int l, int r, char endd) {
	string ret="";
	int cnt=0;
	while(true) {
		// Keep getchar()'s full return value: narrowing it to char first makes
		// EOF undetectable where char is unsigned, and indistinguishable from
		// the byte 0xFF where char is signed.
		const int g=getchar();
		assert(g!=EOF);
		const char ch=static_cast<char>(g);
		if(ch==endd) {
			break;
		}
		cnt++;
		ret+=ch;
	}
	assert(l<=cnt&&cnt<=r);
	return ret;
}
// Reads an integer in [l, r] that must be terminated by a single space.
long long readIntSp(long long l, long long r) {
	const char terminator=' ';
	return readInt(l,r,terminator);
}
// Reads an integer in [l, r] that must be terminated by a newline.
long long readIntLn(long long l, long long r) {
	const char terminator='\n';
	return readInt(l,r,terminator);
}
// Reads a string whose length is in [l, r], terminated by a newline.
string readStringLn(int l, int r) {
	const char terminator='\n';
	return readString(l,r,terminator);
}
// Reads a string whose length is in [l, r], terminated by a single space.
string readStringSp(int l, int r) {
	const char terminator=' ';
	return readString(l,r,terminator);
}

// Matrix storage; solve() fills rows 1..n and columns 1..m (1-based).
int a[505][505];
// Running totals of n*m and q, accumulated by solve() across test cases.
int sum_nm=0,sum_q=0;
// One multiset of values per diagonal; solve() indexes it by row-column+500.
multiset<int> holo[1005];
// Handles one test case: reads (and range-checks) the matrix and its queries,
// keeps one multiset of values per diagonal, and after every update prints
// "YES" when each of the n+m-1 diagonals holds a single distinct value,
// otherwise "NO".
void solve() {
	int n=readIntSp(1,500);
	int m=readIntLn(1,500);
	for(int r=1; r<=n; r++) {
		// Every value but the last on a row is space-terminated.
		for(int c=1; c<m; c++)
			a[r][c]=readIntSp(-1000'000'000,1000'000'000);
		a[r][m]=readIntLn(-1000'000'000,1000'000'000);
	}
	int q=readIntLn(1,200'000);
	sum_nm+=n*m;
	sum_q+=q;
//	assert(sum_nm<=500000&&sum_q<=200000);
	for(int d=0; d<=1000; d++)
		holo[d].clear();
	for(int r=1; r<=n; r++)
		for(int c=1; c<=m; c++)
			holo[r-c+500].insert(a[r][c]);

	// A diagonal is uniform when its smallest and largest stored values coincide.
	auto uniform=[&](int d) {
		return (*holo[d].begin())==(*holo[d].rbegin());
	};

	// vald = number of diagonals that are currently uniform.
	int vald=0;
	for(int d=500-m+1; d<=500+n-1; d++)
		if(uniform(d))
			vald++;
	trace(vald,n+m-1);
	while(q--) {
		int x=readIntSp(1,n);
		int y=readIntSp(1,m);
		int v=readIntLn(-1000'000'000,1000'000'000);
		const int d=x-y+500;
		multiset<int> &diag=holo[d];
		// Take this diagonal out of the tally, update it, then count it again.
		if(uniform(d))
			vald--;
		diag.erase(diag.lower_bound(a[x][y]));
		a[x][y]=v;
		diag.insert(a[x][y]);
		if(uniform(d))
			vald++;
		cout<<(vald==n+m-1?"YES":"NO")<<endl;
	}
}


signed main() {
	ios_base::sync_with_stdio(0),cin.tie(0);
	srand(chrono::high_resolution_clock::now().time_since_epoch().count());
	cout<<fixed<<setprecision(7);
	int t=readIntLn(1,100);
	fr(i,1,t)
		solve();
	assert(getchar()==EOF);
#ifdef rd
	cerr<<endl<<endl<<endl<<"Time Elapsed: "<<((double)(clock()-clk))/CLOCKS_PER_SEC<<endl;
#endif
}
Editorialist's Solution
import java.util.*;
import java.io.*;
class HRSCPMTR{
	//SOLUTION BEGIN
	/** Hook executed once before the test cases; nothing to precompute here. */
	void pre() throws Exception{}

	/**
	 * Solves one test case. Tracks the number of cells whose value differs from
	 * the cell diagonally above-left, and after each query prints YES exactly
	 * when that number is zero, otherwise NO.
	 */
	void solve(int TC) throws Exception{
	    int rows = ni();
	    int cols = ni();
	    int[][] grid = new int[rows][cols];
	    for(int[] row : grid){
	        for(int j = 0; j < cols; j++){
	            row[j] = ni();
	        }
	    }
	    int differing = 0;
	    for(int i = 1; i < rows; i++){
	        for(int j = 1; j < cols; j++){
	            if(grid[i][j] != grid[i-1][j-1]){
	                differing++;
	            }
	        }
	    }
	    int queries = ni();
	    while(queries-- > 0){
	        int r = ni()-1;
	        int c = ni()-1;
	        int v = ni();
	        // Only the pairs formed with these two neighbours can change.
	        boolean hasUpperLeft = r > 0 && c > 0;
	        boolean hasLowerRight = r+1 < rows && c+1 < cols;
	        if(hasUpperLeft && grid[r][c] != grid[r-1][c-1]){
	            differing--;
	        }
	        if(hasLowerRight && grid[r][c] != grid[r+1][c+1]){
	            differing--;
	        }
	        grid[r][c] = v;
	        if(hasUpperLeft && grid[r][c] != grid[r-1][c-1]){
	            differing++;
	        }
	        if(hasLowerRight && grid[r][c] != grid[r+1][c+1]){
	            differing++;
	        }

	        pn(differing == 0 ? "YES" : "NO");
	    }
	}
	//SOLUTION END
	/** Throws when the given condition is false. */
	void hold(boolean b)throws Exception{
	    if(b){
	        return;
	    }
	    throw new Exception("Hold right there, Sparky!");
	}
	static boolean multipleTC = true;
	FastReader in;PrintWriter out;
	/** Sets up I/O, runs every test case, then flushes and closes the output. */
	void run() throws Exception{
	    in = new FastReader();
	    out = new PrintWriter(System.out);
	    //Solution Credits: Taranpreet Singh
	    final int total = multipleTC ? ni() : 1;
	    pre();
	    int t = 1;
	    while(t <= total){
	        solve(t);
	        t++;
	    }
	    out.flush();
	    out.close();
	}
	public static void main(String[] args) throws Exception{
	    HRSCPMTR solver = new HRSCPMTR();
	    solver.run();
	}
	/** Number of set bits in n, clearing the lowest set bit on each step. */
	int bit(long n){
	    int ones = 0;
	    while(n != 0){
	        n &= n-1;
	        ones++;
	    }
	    return ones;
	}
	void p(Object o){out.print(o);}
	void pn(Object o){out.println(o);}
	void pni(Object o){out.println(o);out.flush();}
	String n()throws Exception{return in.next();}
	String nln()throws Exception{return in.nextLine();}
	int ni()throws Exception{return Integer.parseInt(in.next());}
	long nl()throws Exception{return Long.parseLong(in.next());}
	double nd()throws Exception{return Double.parseDouble(in.next());}

	/** Whitespace-token reader on top of a BufferedReader. */
	class FastReader{
	    BufferedReader br;
	    StringTokenizer st;
	    public FastReader(){
	        br = new BufferedReader(new InputStreamReader(System.in));
	    }

	    public FastReader(String s) throws Exception{
	        br = new BufferedReader(new FileReader(s));
	    }

	    /** Returns the next token, reading further lines while the current one is exhausted. */
	    String next() throws Exception{
	        while (!(st != null && st.hasMoreElements())){
	            try{
	                st = new StringTokenizer(br.readLine());
	            }catch (IOException  e){
	                throw new Exception(e.toString());
	            }
	        }
	        return st.nextToken();
	    }

	    /** Returns the next raw line as delivered by the underlying reader. */
	    String nextLine() throws Exception{
	        try{
	            return br.readLine();
	        }catch (IOException e){
	            throw new Exception(e.toString());
	        }
	    }
	}
}

VIDEO EDITORIAL (English):

VIDEO EDITORIAL (Hindi):

Feel free to share your approach. Suggestions are welcomed as always. :slight_smile:

4 Likes

I think the time constraint was too much , my O(N * M+Q * min(N,M)) even got passed in 1.2 seconds almost. BTW Thanks for a very well written tutorial

3 Likes

A relatively neat/readable solution for anyone interested.

AC_CODE
#include<bits/stdc++.h>
using namespace std ;
// Handles one test case. Value frequencies are kept per diagonal (diagonal id
// = n - row + column, 0-based coordinates) together with the set of diagonals
// that currently hold more than one distinct value. After each update prints
// "Yes" when that set is empty and "No" otherwise.
void solve(){
  int n,m ;
  cin >> n >> m ;
  set<int> mixed ;                               // ids of non-uniform diagonals
  vector<vector<int>> grid(n,vector<int>(m)) ;
  vector<unordered_map<int,int>> freq(n+m) ;     // value -> count, per diagonal
  for(int r=0;r<n;r++)
    for(int c=0;c<m;c++){
      cin >> grid[r][c] ;
      ++freq[n-r+c][grid[r][c]] ;
    }
  for(int d=0;d<n+m;d++)
    if(freq[d].size()>1)
      mixed.insert(d) ;
  int q ;
  cin >> q ;
  while(q--){
    int x,y,v ;
    cin >> x >> y >> v ;
    --x;--y ;
    const int d=n-x+y ;
    auto &counts=freq[d] ;
    const int previous=grid[x][y] ;
    // Remove the old value, dropping its key once no cell on the diagonal holds it.
    if(--counts[previous]==0)
      counts.erase(previous) ;
    ++counts[v] ;
    if(counts.size()==1)
      mixed.erase(d) ;
    else if(counts.size()>1)
      mixed.insert(d) ;
    grid[x][y]=v ;
    cout << (mixed.empty()?"Yes":"No")  << '\n' ;
  }
}
// Reads the number of test cases and runs solve() once for each of them.
signed main(){
  int T;
  cin >> T ;
  for(;T--;)
    solve() ;
}
8 Likes

@taran_1407 @zappelectro can u provide the implementation for this approach?
How do we map every element of the matrix to the corresponding set and position ?

Can someone explain , how in the given example test case(provided in the question) 2nd query is NO but 3rd is YES.

2 Likes

See editorialist solution, it uses the constant memory approach.

1 Like

thnx :slight_smile:

actually i tried to solve it using segment tree , lol now i found that its a simple one.

2 Likes

Setter’s solution is the most elegant solution. It completely symbolizes that we should not grab a fking chainsaw, if the same work could be done by butter cutting knife…

8 Likes

Can you please tell me where i am getting wa
https://www.codechef.com/viewsolution/40851386

Anyone please tell me where i am getting wa

https://www.codechef.com/viewsolution/40851386

HI,
I Am not getting why I am getting SIGSEGV error with my code to please if possible provide any test case ( ANY SUBTASK),
https://www.codechef.com/viewsolution/40864098
Thanks , :innocent: :innocent:

Can Someone tell why I am getting WA …

https://www.codechef.com/viewsolution/40875080

THANKS, I got where it is getting wrong I AM attaching the correct solution for it by removing the error.
https://www.codechef.com/viewsolution/40879407
in line 122 of CodeChef: Practical coding for everyone change a[m-1+x-y] to a[x-y] .

Why does this code give TLE on test 3?
submission - CodeChef: Practical coding for everyone
LOGIC - Keeping a set for the diagonals and a set for the bad diagonals, if at any query point bad size is 0, answer is yes else no.
Shouldn’t this code pass the time constraints as the main overhead is building my diagonal in the query which would be done in -
O(min(n,m)*log(min(n,m))*Q)
The find operation and insert in the bad set would again take O(log(min(n,m))*Q
Shouldn’t this pass?

What i am getting with ur code if all queries have x=500 & y=500 ur solution will be qnn that is sufficient for 2 second but maybe system not taking it , but for sure question not wanted to go with n*n in each query.

why my code fails for All test case?. Please help me!!. :cry:
https://www.codechef.com/viewsolution/40920729

But I am not right, it is n.log(n) in each query.

#include<bits/stdc++.h>
using namespace std;
// Brute-force solution: after every update the whole matrix is re-scanned and
// "Yes" is printed when every cell equals its upper-left neighbour, otherwise
// "No". Costs O(N*M) per query.
//
// Repairs relative to the snippet as it was originally posted:
//  - "t–" / "q–" and the curly quotes were mis-encoded "--" and '"';
//  - only diagonals starting in the first row were examined, and the outer
//    loop condition reused the column index advanced by the inner loop, so
//    diagonals starting in the first column were never checked;
//  - flag was read uninitialised when m == 1;
//  - the variable-length array (not standard C++) is now a std::vector.
int main()
{
    int t;
    cin>>t;
    while(t--)
    {
        int n,m;
        cin>>n>>m;
        // 1-based storage: row 0 and column 0 are unused padding.
        vector<vector<int>> a(n+1,vector<int>(m+1));
        for(int i=1;i<=n;i++)
        {
            for(int j=1;j<=m;j++)
            {
                cin>>a[i][j];
            }
        }
        int q;
        cin>>q;
        while(q--)
        {
            int x,y,v;
            cin>>x>>y>>v;
            a[x][y]=v;
            // Comparing every cell with its upper-left neighbour covers all
            // diagonals, whichever edge of the matrix they start on.
            bool good=true;
            for(int i=2;i<=n && good;i++)
            {
                for(int j=2;j<=m;j++)
                {
                    if(a[i][j]!=a[i-1][j-1])
                    {
                        good=false;
                        break;
                    }
                }
            }
            if(good)
            {
                cout<<"Yes"<<'\n';
            }
            else
            {
                cout<<"No"<<'\n';
            }
        }
    }
    return 0;
}
guys can any one tell me what is wrong in this for clearing subtask of horascope matrix.

Yes, Anyone can explain this. 1st, 2nd example are fine, but unable to understand the 3rd example.

1 Like