MATPAIN80 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: rook_lift
Testers: wuhudsm, satyam_343
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Math

PROBLEM:

There’s an N\times M matrix with the value (i-1)M + j written in cell (i, j).

K updates are performed on this matrix; each update multiplies all values in either a row or a column by a given number.
Find the sum of the matrix after all operations are performed.

EXPLANATION:

This problem can be solved by writing out appropriate equations to simplify things.

First, let’s calculate the sum of the matrix before any operations are performed.
This is just the sum of the first N\times M integers, which is easy to compute (watch out for overflow though!).

Now, let’s look at how updates affect this.
Suppose we multiply row r by x.
Then,

  • The original sum of this row was (r-1)M^2 + \frac{M(M+1)}{2}.
  • The new sum of this row is simply this value multiplied by x, since multiplying everything by x also multiplies their sum by x.
  • So, we can update the sum in \mathcal{O}(1).

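A similar calculation can be done for column updates by finding a formula for column sums: column c initially sums to Nc + M\cdot\frac{N(N-1)}{2}, since its entries are c, c + M, \ldots, c + (N-1)M.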

However, this doesn’t quite take care of overlaps correctly.
For example, consider cell (i, j) (with initial value (i-1)M + j) such that row i was multiplied by value x_i and column j was multiplied by value y_j.
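In our above scheme, we’ve independently increased the value of this cell for each update. That is, the row operation increased its value x_i times, and the column operation increased its value y_j times, for a total increase of (x_i + y_j) times. (Strictly speaking, if each update is applied by adding (multiplier - 1) times the original row/column sum to the running total, which is what the setter’s code does, the cell ends up counted (x_i + y_j - 1) times. The correction idea is identical; only the amount subtracted differs by one copy of the cell’s initial value.)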

However, updates combine multiplicatively, so its real final value should be (x_i \cdot y_j) times its initial value.

If we know i and j (and hence x_i and y_j), this is easy to fix. Unfortunately, there can be \mathcal{O}(K^2) such affected cells (one for each pair of (row update, column update)), and we surely can’t go through them all one by one.

To optimize this process, we can once again use a bit of math.

Suppose the column updates were to columns c_1, c_2, \ldots, c_w, with respective multipliers y_1, y_2, \ldots, y_w.
Let’s look at a row update, say to row r with value x. We can process all the column updates to this row quickly.

How?

The affected cells are (r, c_1), (r, c_2), \ldots, (r, c_w).

Based on our initial scheme, their current values are
((r-1)M + c_1)\cdot(x + y_1)
((r-1)M + c_2)\cdot(x + y_2)
((r-1)M + c_3)\cdot(x + y_3)
\vdots
((r-1)M + c_w)\cdot(x + y_w)

Expanding the parentheses, it can be seen that the sum of these values is
(r-1)Mwx + x\cdot sum(c_i) + (r-1)M\cdot sum(y_i) + sum(c_i\cdot y_i)

Note that sum(c_i), sum(y_i), sum(c_i\cdot y_i) are constants independent of the row we choose, and so their values can be precomputed.
Knowing these values, along with r and x, this sum can be computed in \mathcal{O}(1).

Now, let’s look at what the values of these cells should be:
((r-1)M + c_1)\cdot(x \cdot y_1)
((r-1)M + c_2)\cdot(x \cdot y_2)
((r-1)M + c_3)\cdot(x \cdot y_3)
\vdots
((r-1)M + c_w)\cdot(x \cdot y_w)

Once again, this sum can be expanded, to obtain
((r-1)Mx\cdot sum(y_i)) + x\cdot sum(c_i \cdot y_i)

Just as above, this can be computed in \mathcal{O}(1) time.

Since both the wrong and correct contributions can be computed in \mathcal{O}(1) time, the entire row can be updated in \mathcal{O}(1) time and we’re done.

Do this for each row update to solve the problem in \mathcal{O}(K).

TIME COMPLEXITY

\mathcal{O}(K) per test case.

CODE:

Setter's code (C++)
#include             <bits/stdc++.h>
#include             <ext/pb_ds/assoc_container.hpp>
#include             <ext/pb_ds/tree_policy.hpp>
#define PRE(x,p)     cout<<setprecision(x)<<p; 
#define pb           push_back
#define mp           make_pair
#define f            first
#define s            second
#define pi           3.14159265358979
#define mod          (ll)(1e9 + 7)
#define endl         "\n"
#define high         1e18
#define low          -1e18
#define ll           long long int
#define ld           long double
#define mem(x,val)   memset(x,0,sizeof(x));
#define rep(i,l,r)   for(ll i=l;i<=r;i++)
#define p(a)         for(auto i:a) cout<<i<<' '; cout<<endl;
#define vll          vector<ll> 
#define vb           vector<bool>
#define vpll         vector<pair<ll,ll>>
#define vi           vector<int>
#define vpi          vector<pair<int, int>>
#define vvll         vector<vector<ll>>
#define vvi          vector<vector<int>>
#define vvvll        vector<vector<vector<ll>>>
#define pll          pair<ll,ll>
#define vs           vector<string>
#define vvpll        vector<vector<pair<ll, ll>>>
#define vvpi         vector<vector<pair<int, int>>>
#define vpii         vector<pair<int, int>>
#define sz(a)        (ll)a.size()
#define po(x)        (ll)(1ll<<x)
#define all(x)       begin(x), end(x)
#define speed        ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL);
#define yes          {cout<<"YES"<<endl;return;}
#define no           {cout<<"NO"<<endl; return;}
#define ok           cout<<"ok"<<endl;
#define ordered_set  tree<int, null_type,less<int>, rb_tree_tag,tree_order_statistics_node_update>


using namespace std;
using namespace __gnu_pbds;

// Prints a[1..n] (1-based indexing) separated by spaces, then a newline.
void showa(ll a[],ll n)
{
  ll idx=1;
  while(idx<=n)
  {
    cout<<a[idx]<<' ';
    ++idx;
  }
  cout<<endl;
}
// Returns w masked with bit i, i.e. a non-zero value iff bit i of w is set.
ll ison(ll w ,ll i)
{
  const ll mask=1ll<<i;
  return w&mask;
}
// Raises a to b when b is larger; otherwise leaves a untouched.
void amax(ll &a, ll b)
{
  if(a<b) a=b;
}
// Lowers a to b when b is smaller; otherwise leaves a untouched.
void amin(ll &a, ll b)
{
  if(b<a) a=b;
}
// In-place modular addition: a becomes (a + b) reduced modulo mod.
// Both operands are reduced first, so the intermediate sum cannot overflow.
void modadd(ll &a , ll b)
{
  const ll lhs=a%mod;
  const ll rhs=b%mod;
  a=(lhs+rhs)%mod;
}
// In-place modular subtraction: a becomes (a - b) reduced modulo mod.
// mod is added before the final reduction so that non-negative inputs
// always yield a result in [0, mod).
void modsub(ll &a , ll b)
{
  const ll lhs=a%mod;
  const ll rhs=b%mod;
  a=(lhs-rhs+mod)%mod;
}
// In-place modular multiplication: a becomes (a * b) reduced modulo mod.
// Both factors are reduced first, keeping the product within 64 bits.
void modmul(ll &a , ll b)
{
  const ll lhs=a%mod;
  const ll rhs=b%mod;
  a=(lhs*rhs)%mod;
}

#ifndef ONLINE_JUDGE
#define debug(x) cerr << #x <<" "; _print(x); cerr << endl;
#else
#define debug(x)
#endif

void _print(ll t)     {cerr << t<<' ';}
void _print(int t)    {cerr << t<<' ';}
void _print(string t) {cerr << t<<' ';}
void _print(char t)   {cerr << t<<' ';}
void _print(ld t)     {cerr << t<<' ';}
void _print(double t) {cerr << t<<' ';}
template<class T,class V> void _print(pair <T, V> p);
template<class T>void _print(vector <T> v);
template<class T>void _print(vector <T> v);
template<class T>void _print(set <T> v);
template<class T,class V> void _print(map <T, V> v);
template<class T>void _print(multiset <T> v);
template<class T,class V> void _print(pair <T, V> p) {cerr << "{"; _print(p.f); cerr << ","; _print(p.s); cerr << "}";}
template<class T>void _print(vector <T> v) {cerr << "[ "; for (T i : v) {_print(i); cerr << " ";} cerr << "]";}
template<class T>void _print(set <T> v) {cerr << "[ "; for (T i : v) {_print(i); cerr << " ";} cerr << "]";}
template<class T>void _print(multiset <T> v) {cerr << "[ "; for (T i : v) {_print(i); cerr << " ";} cerr << "]";}
template<class T,class V> void _print(map <T, V> v) {cerr << "[ "; for (auto i : v) {_print(i); cerr << " ";} cerr << "]";}

//const ll l=30;   //log2(n)
//const ll N=200005;

ll n,m,sum_k;

// Returns x*(x+1)/2 modulo mod, i.e. 1 + 2 + ... + x for non-negative x.
// x is reduced first so the product fits in 64 bits.
ll get(ll x)
{
  const ll reduced=x%mod;
  ll triangular=(reduced*(reduced+1))%mod;
  modmul(triangular,500000004);   // 500000004 is the inverse of 2 modulo 1e9+7
  return triangular;
}
// Sum of the initial values of row x (1-based), modulo mod.
// Row x holds the integers (x-1)*m+1 .. x*m, so its sum is the difference of
// two prefix sums of the natural numbers.
ll getrow(ll x)
{
  const ll before_row=get((x-1)*m);
  ll row_total=get(x*m);
  modsub(row_total,before_row);
  return row_total;
}
// Sum of the initial values of column x (1-based), modulo mod.
// Column x holds x, x+m, ..., x+(n-1)*m, an arithmetic progression whose sum is
// n*(2*x + (n-1)*m)/2.
ll getcol(ll x)
{
  ll twice_average=(n-1)*m;            // last term minus first term
  modadd(twice_average,(2*x)%mod);     // + 2*x  => first term + last term
  modmul(twice_average,n);
  modmul(twice_average,500000004);     // divide by 2 (inverse of 2 modulo 1e9+7)
  return twice_average;
}
// Input-validation helper: asserts that x lies in the inclusive range [l, r].
// Like any assert, the test is compiled out when NDEBUG is defined.
void check(ll x , ll l , ll r){
  assert(x>=l && x<=r);
}

// Evaluates a sum of products modulo mod: every inner vector is multiplied out
// (an empty one counts as 1) and the products are added together.
ll rec(vvll vec)
{
  ll total=0;
  for(size_t term=0;term<vec.size();++term)
  {
    ll product=1;
    for(size_t idx=0;idx<vec[term].size();++idx) modmul(product,vec[term][idx]);
    modadd(total,product);
  }
  return total;
}
void solve()
{
  ll k;
  assert(cin>>n>>m>>k);
  sum_k+=k;
  check(n,1,1e9);
  check(m,1,1e9);
  check(k,1,2e5);
  check(sum_k,1,2e5);
  map<ll,ll>row,col;
  rep(i,1,k)
  {
    ll type,x,c;
    assert(cin>>type>>x>>c);
    check(type,0,1);
    check(c,0,1e9);
    if(type==0)
    {
      check(x,1,n);
      assert(row.count(x)==0);
      row[x]=c;
    }
    if(type==1)
    {
      check(x,1,m);
      assert(col.count(x)==0);
      col[x]=c;
    }
  }
  ll yy=0;
  ll cc_yy=0;

  for(auto it:col) 
  {
    modadd(yy,it.s);
    modadd(cc_yy, it.f*it.s);
  }

  ll ans=get(n*m);
  for(auto it:row) modadd(ans,getrow(it.f) * (it.s-1+mod)%mod);
  for(auto it:col) modadd(ans,getcol(it.f) * (it.s-1+mod)%mod);
  
  ll tot_c=col.size()%mod;
  ll col_sum=0;
  for(auto it:col) modadd(col_sum,it.f);
  for(auto it:row)
  {
    ll r=it.f;
    ll x=it.s;

    ll here= ((r-1)*m)%mod ;
    modmul(here,x);
    modmul(here,yy);

    ll toadd=x;
    modmul(toadd, cc_yy);
    modadd(here,toadd);
    modadd(ans,here);
    
    ll tosub1=rec({{m,r-1,x,tot_c},{x,col_sum}});
    ll tosub2=rec({{m,r-1,yy},{cc_yy}});
    modsub(tosub2, m*((r-1)*tot_c)%mod);
    modsub(tosub2, col_sum);
    
    modsub(ans,tosub1);
    modsub(ans,tosub2);

  }   
  cout<<ans<<endl;

}
 
// Entry point: reads the number of test cases and runs solve() for each one.
signed main()
{
#ifndef ONLINE_JUDGE
  // When ONLINE_JUDGE is not defined, stdin/stdout are redirected to fixed files.
  freopen("input_5.in", "r", stdin);
  freopen("output_5.out", "w", stdout);
#endif
  speed

  ll t=1;
  // The read stays outside assert() so it is not compiled away under NDEBUG.
  if(!(cin>>t))
  {
    assert(false && "failed to read the number of test cases");
    return 1;
  }
  check(t,1,100);

  for(ll test=1;test<=t;test++) solve();
  return 0;
}
Tester's code (C++)
#include <map>
#include <set>
#include <cmath>
#include <ctime>
#include <queue>
#include <stack>
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <cstring>
#include <algorithm>
#include <iostream>
using namespace std;
// Type aliases, limits and global state shared by the functions below.
typedef double db;
typedef long long ll;
typedef unsigned long long ull;
const int N=1000010;		// capacity of the q/x/v arrays
const int LOGN=28;
const ll  TMD=1000000007;	// modulus used for all arithmetic
const ll  INF=2147483647;
ll  T;
ll  n;
ll  m;
ll  k;
ll  S1;
ll  S2;
ll  ans;
ll  q[N];
ll  x[N];
ll  v[N];

//-------------------------------------------------
// Variables involved: fac[], inv[], TMD, N (upper bound)
// Note: choose TMD to suit the problem.

// Returns x^p modulo TMD for p >= 0 by binary exponentiation.
ll pw(ll x,ll p)
{
	ll base=x%TMD;	// classic pitfall: a large base must be reduced before multiplying
	ll result=1;
	while(p)
	{
		if(p&1) result=(result*base)%TMD;
		base=(base*base)%TMD;
		p>>=1;
	}
	return result;
}

// Returns x^(TMD-2) modulo TMD, the modular inverse of x when TMD is prime and
// x is not a multiple of TMD (Fermat's little theorem).
ll inv(ll x)
{
	const ll exponent=TMD-2;
	return pw(x,exponent);
}

//-------------------------------------------------

// Reads T test cases; for each one reads n, m, k and k updates, then prints the
// matrix sum modulo TMD.
int main()
{
	// T is a long long, so it must be read with "%lld"; "%d" stores through an
	// int* and is undefined behaviour.
	if(scanf("%lld",&T)!=1) return 0;
	while(T--)
	{
		if(scanf("%lld%lld%lld",&n,&m,&k)!=3) return 0;
		for(int i=1;i<=k;i++)
		{
			// q[i]: update type (non-zero = column, zero = row), x[i]: index, v[i]: multiplier.
			if(scanf("%lld%lld%lld",&q[i],&x[i],&v[i])!=3) return 0;
		}
		// With y_j the multiplier of column j (1 when untouched):
		//   S1 = sum over columns of j*y_j,  S2 = sum over columns of y_j.
		S1=(ll)m*(m+1)/2%TMD;S2=m;
		for(int i=1;i<=k;i++) if(q[i]) S1=(S1+(ll)x[i]*(v[i]-1+TMD))%TMD,S2=(S2+(v[i]-1+TMD))%TMD;
		// Row i sums to S1 + m*(i-1)*S2; over i=1..n that is n*(2*S1 + m*(n-1)*S2)/2.
		ans=(S1+S1+(ll)m*(n-1)%TMD*S2%TMD)*n%TMD*inv(2)%TMD;
		// A row update with multiplier v adds (v-1) times that row's sum.
		for(int i=1;i<=k;i++) if(!q[i]) ans=(ans+(ll)(v[i]-1+TMD)*(S1+(ll)m*(x[i]-1)%TMD*S2%TMD))%TMD;
		printf("%lld\n",ans);
	}
	
	return 0;
}
3 Likes

image

Is there a missing x here?

1 Like

yes

Thanks for noticing, updated.

1 Like

Where can I find the video solution for this problem?

Why are we subtracting 2 times at line 158 and 159. Is it not enough to subtract only once. It is making sense to me that we should subtract 1 but why two times?
Let say element at intersection be ‘a’ then,
a + ((rmul - 1) * a) + ((cmul - 1) * a) = a * (rmul + cmul - 1), but it should be a * (rmul + cmul)
@iceknight1093 ?

#include             <bits/stdc++.h>
#include             <ext/pb_ds/assoc_container.hpp>
#include             <ext/pb_ds/tree_policy.hpp>
#define PRE(x,p)     cout<<setprecision(x)<<p; 
#define pb           push_back
#define mp           make_pair
#define f            first
#define s            second
#define pi           3.14159265358979
#define mod          (ll)(1e9 + 7)
#define endl         "\n"
#define high         1e18
#define low          -1e18
#define ll           long long int
#define ld           long double
#define mem(x,val)   memset(x,0,sizeof(x));
#define rep(i,l,r)   for(ll i=l;i<=r;i++)
#define p(a)         for(auto i:a) cout<<i<<' '; cout<<endl;
#define vll          vector<ll> 
#define vb           vector<bool>
#define vpll         vector<pair<ll,ll>>
#define vi           vector<int>
#define vpi          vector<pair<int, int>>
#define vvll         vector<vector<ll>>
#define vvi          vector<vector<int>>
#define vvvll        vector<vector<vector<ll>>>
#define pll          pair<ll,ll>
#define vs           vector<string>
#define vvpll        vector<vector<pair<ll, ll>>>
#define vvpi         vector<vector<pair<int, int>>>
#define vpii         vector<pair<int, int>>
#define sz(a)        (ll)a.size()
#define po(x)        (ll)(1ll<<x)
#define all(x)       begin(x), end(x)
#define speed        ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL);
#define yes          {cout<<"YES"<<endl;return;}
#define no           {cout<<"NO"<<endl; return;}
#define ok           cout<<"ok"<<endl;
#define ordered_set  tree<int, null_type,less<int>, rb_tree_tag,tree_order_statistics_node_update>


using namespace std;
using namespace __gnu_pbds;

// Prints a[1..n] (1-based indexing) separated by spaces, then a newline.
void showa(ll a[],ll n)
{
  ll idx=1;
  while(idx<=n)
  {
    cout<<a[idx]<<' ';
    ++idx;
  }
  cout<<endl;
}
// Returns w masked with bit i, i.e. a non-zero value iff bit i of w is set.
ll ison(ll w ,ll i)
{
  const ll mask=1ll<<i;
  return w&mask;
}
// Raises a to b when b is larger; otherwise leaves a untouched.
void amax(ll &a, ll b)
{
  if(a<b) a=b;
}
// Lowers a to b when b is smaller; otherwise leaves a untouched.
void amin(ll &a, ll b)
{
  if(b<a) a=b;
}
// In-place modular addition: a becomes (a + b) reduced modulo mod.
// Both operands are reduced first, so the intermediate sum cannot overflow.
void modadd(ll &a , ll b)
{
  const ll lhs=a%mod;
  const ll rhs=b%mod;
  a=(lhs+rhs)%mod;
}
// In-place modular subtraction: a becomes (a - b) reduced modulo mod.
// mod is added before the final reduction so that non-negative inputs
// always yield a result in [0, mod).
void modsub(ll &a , ll b)
{
  const ll lhs=a%mod;
  const ll rhs=b%mod;
  a=(lhs-rhs+mod)%mod;
}
// In-place modular multiplication: a becomes (a * b) reduced modulo mod.
// Both factors are reduced first, keeping the product within 64 bits.
void modmul(ll &a , ll b)
{
  const ll lhs=a%mod;
  const ll rhs=b%mod;
  a=(lhs*rhs)%mod;
}

#ifndef ONLINE_JUDGE
#define debug(x) cerr << #x <<" "; _print(x); cerr << endl;
#else
#define debug(x)
#endif

void _print(ll t)     {cerr << t<<' ';}
void _print(int t)    {cerr << t<<' ';}
void _print(string t) {cerr << t<<' ';}
void _print(char t)   {cerr << t<<' ';}
void _print(ld t)     {cerr << t<<' ';}
void _print(double t) {cerr << t<<' ';}
template<class T,class V> void _print(pair <T, V> p);
template<class T>void _print(vector <T> v);
template<class T>void _print(vector <T> v);
template<class T>void _print(set <T> v);
template<class T,class V> void _print(map <T, V> v);
template<class T>void _print(multiset <T> v);
template<class T,class V> void _print(pair <T, V> p) {cerr << "{"; _print(p.f); cerr << ","; _print(p.s); cerr << "}";}
template<class T>void _print(vector <T> v) {cerr << "[ "; for (T i : v) {_print(i); cerr << " ";} cerr << "]";}
template<class T>void _print(set <T> v) {cerr << "[ "; for (T i : v) {_print(i); cerr << " ";} cerr << "]";}
template<class T>void _print(multiset <T> v) {cerr << "[ "; for (T i : v) {_print(i); cerr << " ";} cerr << "]";}
template<class T,class V> void _print(map <T, V> v) {cerr << "[ "; for (auto i : v) {_print(i); cerr << " ";} cerr << "]";}

//const ll l=30;   //log2(n)
//const ll N=200005;

ll n,m,sum_k;

// Returns x*(x+1)/2 modulo mod, i.e. 1 + 2 + ... + x for non-negative x.
// x is reduced first so the product fits in 64 bits.
ll get(ll x)
{
  const ll reduced=x%mod;
  ll triangular=(reduced*(reduced+1))%mod;
  modmul(triangular,500000004);   // 500000004 is the inverse of 2 modulo 1e9+7
  return triangular;
}
// Sum of the initial values of row x (1-based), modulo mod.
// Row x holds the integers (x-1)*m+1 .. x*m, so its sum is the difference of
// two prefix sums of the natural numbers.
ll getrow(ll x)
{
  const ll before_row=get((x-1)*m);
  ll row_total=get(x*m);
  modsub(row_total,before_row);
  return row_total;
}
// Sum of the initial values of column x (1-based), modulo mod.
// Column x holds x, x+m, ..., x+(n-1)*m, an arithmetic progression whose sum is
// n*(2*x + (n-1)*m)/2.
ll getcol(ll x)
{
  ll twice_average=(n-1)*m;            // last term minus first term
  modadd(twice_average,(2*x)%mod);     // + 2*x  => first term + last term
  modmul(twice_average,n);
  modmul(twice_average,500000004);     // divide by 2 (inverse of 2 modulo 1e9+7)
  return twice_average;
}
// Input-validation helper: asserts that x lies in the inclusive range [l, r].
// Like any assert, the test is compiled out when NDEBUG is defined.
void check(ll x , ll l , ll r){
  assert(x>=l && x<=r);
}

// Evaluates a sum of products modulo mod: every inner vector is multiplied out
// (an empty one counts as 1) and the products are added together.
ll rec(vvll vec)
{
  ll total=0;
  for(size_t term=0;term<vec.size();++term)
  {
    ll product=1;
    for(size_t idx=0;idx<vec[term].size();++idx) modmul(product,vec[term][idx]);
    modadd(total,product);
  }
  return total;
}
void solve()
{
  ll k;
  assert(cin>>n>>m>>k);
  sum_k+=k;
  check(n,1,1e9);
  check(m,1,1e9);
  check(k,1,2e5);
  check(sum_k,1,2e5);
  map<ll,ll>row,col;
  rep(i,1,k)
  {
    ll type,x,c;
    assert(cin>>type>>x>>c);
    check(type,0,1);
    check(c,0,1e9);
    if(type==0)
    {
      check(x,1,n);
      assert(row.count(x)==0);
      row[x]=c;
    }
    if(type==1)
    {
      check(x,1,m);
      assert(col.count(x)==0);
      col[x]=c;
    }
  }
  ll yy=0;
  ll cc_yy=0;

  for(auto it:col) 
  {
    modadd(yy,it.s);
    modadd(cc_yy, it.f*it.s);
  }

  ll ans=get(n*m);
  for(auto it:row) modadd(ans,getrow(it.f) * (it.s-1+mod)%mod);
  for(auto it:col) modadd(ans,getcol(it.f) * (it.s-1+mod)%mod);
  
  ll tot_c=col.size()%mod;
  ll col_sum=0;
  for(auto it:col) modadd(col_sum,it.f);
  for(auto it:row)
  {
    ll r=it.f;
    ll x=it.s;

    ll here= ((r-1)*m)%mod ;
    modmul(here,x);
    modmul(here,yy);

    ll toadd=x;
    modmul(toadd, cc_yy);
    modadd(here,toadd);
    modadd(ans,here);
    
    ll tosub1=rec({{m,r-1,x,tot_c},{x,col_sum}});
    ll tosub2=rec({{m,r-1,yy},{cc_yy}});
    modsub(tosub2, m*((r-1)*tot_c)%mod);
    modsub(tosub2, col_sum);
    
    modsub(ans,tosub1);
    modsub(ans,tosub2);

  }   
  cout<<ans<<endl;

}
 
// Entry point: reads the number of test cases and runs solve() for each one.
signed main()
{
#ifndef ONLINE_JUDGE
  // When ONLINE_JUDGE is not defined, stdin/stdout are redirected to fixed files.
  freopen("input_5.in", "r", stdin);
  freopen("output_5.out", "w", stdout);
#endif
  speed

  ll t=1;
  // The read stays outside assert() so it is not compiled away under NDEBUG.
  if(!(cin>>t))
  {
    assert(false && "failed to read the number of test cases");
    return 1;
  }
  check(t,1,100);

  for(ll test=1;test<=t;test++) solve();
  return 0;
}