PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: rook_lift
Testers: wuhudsm, satyam_343
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
Math
PROBLEM:
There’s an N\times M matrix with the value (i-1)M + j written in cell (i, j).
K updates are performed on this matrix; each update multiplies all values in either a row or a column by a certain number.
Find the sum of the matrix after all operations are performed.
EXPLANATION:
This problem can be solved by writing out appropriate equations to simplify things.
First, let’s calculate the sum of the matrix before any operations are performed.
This is just the sum of the first N\times M integers, which is easy to compute (watch out for overflow though!).
Now, let’s look at how updates affect this.
Suppose we multiply row r by x.
Then,
- The original sum of this row was (r-1)M^2 + \frac{M(M+1)}{2}.
- The new sum of this row is simply this value multiplied by x, since multiplying everything by x also multiplies their sum by x.
- So, we can update the sum in \mathcal{O}(1).
A similar calculation can be done for column updates by finding a formula for column sums.
However, this doesn’t quite take care of overlaps correctly.
For example, consider cell (i, j) (with initial value (i-1)M + j) such that row i was multiplied by value x_i and column j was multiplied by value y_j.
In our above scheme, we’ve independently increased the value of this cell for each update. That is, the row operation increased its value x_i times, and the column operation increased its value y_j times, for a total increase of (x_i + y_j) times.
However, updates combine multiplicatively, so its real final value should be (x_i \cdot y_j) times its initial value.
If we know i and j (and hence x_i and y_j), this is easy to fix. Unfortunately, there can be \mathcal{O}(K^2) such affected cells (one for each pair of (row update, column update)), and we surely can’t go through them all one by one.
To optimize this process, we can once again use a bit of math.
Suppose the column updates were to columns c_1, c_2, \ldots, c_w, with respective multipliers y_1, y_2, \ldots, y_w.
Let’s look at a row update, say to row r with value x. We can process all the column updates to this row quickly.
How?
The affected cells are (r, c_1), (r, c_2), \ldots, (r, c_w).
Based on our initial scheme, their current values are
((r-1)M + c_1)\cdot(x + y_1)
((r-1)M + c_2)\cdot(x + y_2)
((r-1)M + c_3)\cdot(x + y_3)
\vdots
((r-1)M + c_w)\cdot(x + y_w)
Expanding the parentheses, it can be seen that the sum of these values is
(r-1)Mwx + x\cdot sum(c_i) + (r-1)M\cdot sum(y_i) + sum(c_i\cdot y_i)
Note that sum(c_i), sum(y_i), sum(c_i\cdot y_i) are constants independent of the row we choose, and so their values can be precomputed.
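Knowing these values, along with r and x, this sum can be computed in \mathcal{O}(1). (A caveat: if each update is applied by adding (multiplier - 1) times the original row or column sum, which is what the setter's code below does, then the current factor of such a cell is x + y_i - 1 rather than x + y_i, so the sum above must be reduced by (r-1)Mw + sum(c_i).)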
Now, let’s look at what the values of these cells should be:
((r-1)M + c_1)\cdot(x \cdot y_1)
((r-1)M + c_2)\cdot(x \cdot y_2)
((r-1)M + c_3)\cdot(x \cdot y_3)
\vdots
((r-1)M + c_w)\cdot(x \cdot y_w)
Once again, this sum can be expanded, to obtain
((r-1)Mx\cdot sum(y_i)) + x\cdot sum(c_i \cdot y_i)
Just as above, this can be computed in \mathcal{O}(1) time.
Since both the wrong and correct contributions can be computed in \mathcal{O}(1) time, the entire row can be updated in \mathcal{O}(1) time and we’re done.
Do this for each row update to solve the problem in \mathcal{O}(K).
TIME COMPLEXITY
\mathcal{O}(K) per test case.
CODE:
Setter's code (C++)
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
// Competitive-programming shorthand used by the setter's solution.
// These are textual macros, so they rewrite matching tokens everywhere below.
#define PRE(x,p) cout<<setprecision(x)<<p;
#define pb push_back
#define mp make_pair
// NOTE: `f` and `s` are object-like macros; any identifier token spelled `f`
// or `s` after this point is replaced by `first` / `second`.
#define f first
#define s second
#define pi 3.14159265358979
// Modulus used by all modular helpers: 1e9 + 7, cast to ll.
#define mod (ll)(1e9 + 7)
// Replaces `endl` with a plain newline string (no stream flush).
#define endl "\n"
#define high 1e18
#define low -1e18
#define ll long long int
#define ld long double
// NOTE(review): `val` is ignored; the expansion always fills with 0.
#define mem(x,val) memset(x,0,sizeof(x));
// Inclusive loop: i runs from l to r.
#define rep(i,l,r) for(ll i=l;i<=r;i++)
// Expands to two statements (the loop, then the newline output).
#define p(a) for(auto i:a) cout<<i<<' '; cout<<endl;
// Container type aliases.
#define vll vector<ll>
#define vb vector<bool>
#define vpll vector<pair<ll,ll>>
#define vi vector<int>
#define vpi vector<pair<int, int>>
#define vvll vector<vector<ll>>
#define vvi vector<vector<int>>
#define vvvll vector<vector<vector<ll>>>
#define pll pair<ll,ll>
#define vs vector<string>
#define vvpll vector<vector<pair<ll, ll>>>
#define vvpi vector<vector<pair<int, int>>>
#define vpii vector<pair<int, int>>
#define sz(a) (ll)a.size()
#define po(x) (ll)(1ll<<x)
#define all(x) begin(x), end(x)
// Unties the C++ streams from C stdio and from each other.
#define speed ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL);
// Print a verdict and return from the enclosing void function.
#define yes {cout<<"YES"<<endl;return;}
#define no {cout<<"NO"<<endl; return;}
#define ok cout<<"ok"<<endl;
#define ordered_set tree<int, null_type,less<int>, rb_tree_tag,tree_order_statistics_node_update>
using namespace std;
using namespace __gnu_pbds;
// Prints a[1..n] (1-based) to stdout, each value followed by a space, then ends the line.
void showa(ll a[],ll n)
{
    ll idx=1;
    while(idx<=n)
    {
        cout<<a[idx]<<' ';
        ++idx;
    }
    cout<<endl;
}
// Returns w masked with bit i: the result is non-zero exactly when bit i of w is set.
ll ison(ll w ,ll i)
{
    const ll mask=1ll<<i;
    return w&mask;
}
// Raises a to b when b is larger; otherwise leaves a unchanged.
void amax(ll &a, ll b)
{
    if(a<b) a=b;
}
// Lowers a to b when b is smaller; otherwise leaves a unchanged.
void amin(ll &a, ll b)
{
    if(b<a) a=b;
}
// In-place modular addition: a becomes (a + b) reduced modulo mod.
// Both operands are reduced first so the intermediate sum stays small.
void modadd(ll &a , ll b)
{
    const ll lhs=a%mod;
    const ll rhs=b%mod;
    a=(lhs+rhs)%mod;
}
// In-place modular subtraction: a becomes (a - b) modulo mod.
// mod is added before the final reduction so that, for non-negative operands,
// the result is never negative.
void modsub(ll &a , ll b)
{
    const ll lhs=a%mod;
    const ll rhs=b%mod;
    a=(lhs-rhs+mod)%mod;
}
// In-place modular multiplication: a becomes (a * b) modulo mod.
// Both factors are reduced before multiplying to keep the product within ll.
void modmul(ll &a , ll b)
{
    const ll lhs=a%mod;
    const ll rhs=b%mod;
    a=(lhs*rhs)%mod;
}
// debug(x): when ONLINE_JUDGE is not defined, writes the expression text of x
// followed by its value (via _print) and a newline to cerr; the expansion is
// three separate statements. When ONLINE_JUDGE is defined it expands to nothing.
#ifndef ONLINE_JUDGE
#define debug(x) cerr << #x <<" "; _print(x); cerr << endl;
#else
#define debug(x)
#endif
void _print(ll t) {cerr << t<<' ';}
void _print(int t) {cerr << t<<' ';}
void _print(string t) {cerr << t<<' ';}
void _print(char t) {cerr << t<<' ';}
void _print(ld t) {cerr << t<<' ';}
void _print(double t) {cerr << t<<' ';}
template<class T,class V> void _print(pair <T, V> p);
template<class T>void _print(vector <T> v);
template<class T>void _print(vector <T> v);
template<class T>void _print(set <T> v);
template<class T,class V> void _print(map <T, V> v);
template<class T>void _print(multiset <T> v);
template<class T,class V> void _print(pair <T, V> p) {cerr << "{"; _print(p.f); cerr << ","; _print(p.s); cerr << "}";}
template<class T>void _print(vector <T> v) {cerr << "[ "; for (T i : v) {_print(i); cerr << " ";} cerr << "]";}
template<class T>void _print(set <T> v) {cerr << "[ "; for (T i : v) {_print(i); cerr << " ";} cerr << "]";}
template<class T>void _print(multiset <T> v) {cerr << "[ "; for (T i : v) {_print(i); cerr << " ";} cerr << "]";}
template<class T,class V> void _print(map <T, V> v) {cerr << "[ "; for (auto i : v) {_print(i); cerr << " ";} cerr << "]";}
//const ll l=30; //log2(n)
//const ll N=200005;
// n, m: matrix dimensions of the current test case (read in solve()).
// sum_k: running total of k over all test cases; solve() adds to it and
// validates it against the overall limit.
ll n,m,sum_k;
// Returns 1 + 2 + ... + x = x(x+1)/2 modulo mod.
// 500000004 is the modular inverse of 2 (2 * 500000004 = 1e9 + 8).
ll get(ll x)
{
    const ll reduced = x % mod;
    ll triangular = (reduced * (reduced + 1)) % mod;
    modmul(triangular, 500000004);
    return triangular;
}
// Sum (mod) of the original values in row x. Row x holds the consecutive
// integers (x-1)*m + 1 .. x*m, so the sum is get(x*m) - get((x-1)*m).
ll getrow(ll x)
{
    ll through_this_row = get(x * m);
    const ll before_this_row = get((x - 1) * m);
    modsub(through_this_row, before_this_row);
    return through_this_row;
}
// Sum (mod) of the original values in column x. Column x holds
// (i-1)*m + x for i = 1..n, an arithmetic progression whose sum is
// n * (2x + (n-1)*m) / 2; the division is done with the inverse of 2.
ll getcol(ll x)
{
    ll first_plus_last = (2 * x) % mod;
    modadd(first_plus_last, (n - 1) * m);
    ll total = n;
    modmul(total, first_plus_last);
    modmul(total, 500000004);
    return total;
}
// Input validation helper: asserts that x lies in the inclusive range [l, r].
// Being an assert, it does nothing when assertions are disabled (NDEBUG).
void check(ll x , ll l , ll r){
assert(x>=l && x<=r);
}
// Returns, modulo mod, the sum over the inner vectors of the product of their
// elements: sum_{v in vec} prod_{x in v} x. An empty inner vector contributes
// 1 (empty product); an empty vec yields 0.
//
// The argument is taken by const reference and iterated by const reference so
// that neither the outer nor the inner vectors are copied; braced-list call
// sites such as rec({{a,b},{c}}) still bind to the parameter.
ll rec(const vvll& vec)
{
    ll ans = 0;
    for (const vll& term : vec)
    {
        ll here = 1;
        for (const ll factor : term) modmul(here, factor);
        modadd(ans, here);
    }
    return ans;
}
void solve()
{
ll k;
assert(cin>>n>>m>>k);
sum_k+=k;
check(n,1,1e9);
check(m,1,1e9);
check(k,1,2e5);
check(sum_k,1,2e5);
map<ll,ll>row,col;
rep(i,1,k)
{
ll type,x,c;
assert(cin>>type>>x>>c);
check(type,0,1);
check(c,0,1e9);
if(type==0)
{
check(x,1,n);
assert(row.count(x)==0);
row[x]=c;
}
if(type==1)
{
check(x,1,m);
assert(col.count(x)==0);
col[x]=c;
}
}
ll yy=0;
ll cc_yy=0;
for(auto it:col)
{
modadd(yy,it.s);
modadd(cc_yy, it.f*it.s);
}
ll ans=get(n*m);
for(auto it:row) modadd(ans,getrow(it.f) * (it.s-1+mod)%mod);
for(auto it:col) modadd(ans,getcol(it.f) * (it.s-1+mod)%mod);
ll tot_c=col.size()%mod;
ll col_sum=0;
for(auto it:col) modadd(col_sum,it.f);
for(auto it:row)
{
ll r=it.f;
ll x=it.s;
ll here= ((r-1)*m)%mod ;
modmul(here,x);
modmul(here,yy);
ll toadd=x;
modmul(toadd, cc_yy);
modadd(here,toadd);
modadd(ans,here);
ll tosub1=rec({{m,r-1,x,tot_c},{x,col_sum}});
ll tosub2=rec({{m,r-1,yy},{cc_yy}});
modsub(tosub2, m*((r-1)*tot_c)%mod);
modsub(tosub2, col_sum);
modsub(ans,tosub1);
modsub(ans,tosub2);
}
cout<<ans<<endl;
}
// Entry point of the setter's solution: reads the number of test cases,
// validates it, and runs solve() once per test case.
signed main()
{
#ifndef ONLINE_JUDGE
    // Outside the judge, stdin/stdout are redirected to these fixed files.
    freopen("input_5.in", "r", stdin);
    freopen("output_5.out", "w", stdout);
#endif
    speed
    ll t = 1;
    // Read outside assert() so the extraction still happens when assertions
    // are compiled out (NDEBUG).
    const bool count_read = static_cast<bool>(cin >> t);
    assert(count_read);
    (void)count_read;
    check(t, 1, 100);
    for (ll test = 1; test <= t; test++)
    {
        solve();
    }
    return 0;
}
Tester's code (C++)
#include <map>
#include <set>
#include <cmath>
#include <ctime>
#include <queue>
#include <stack>
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <cstring>
#include <algorithm>
#include <iostream>
using namespace std;
// Type shorthands for the tester's solution.
typedef double db;
typedef long long ll;
typedef unsigned long long ull;
// Capacity of the per-update arrays below.
const int N=1000010;
const int LOGN=28;
// Modulus for all arithmetic (1e9 + 7).
const ll TMD=1000000007;
const ll INF=2147483647;
// T: number of test cases; n, m: matrix dimensions; k: number of updates.
ll T,n,m,k;
// S1, S2: column aggregates maintained in main(); ans: answer of the current test.
ll S1,S2,ans;
// Update i (1-based): q[i] is its type flag (non-zero is handled as a column
// update in main(), zero as a row update), x[i] the row/column index, v[i] the
// multiplier.
ll q[N],x[N],v[N];
//-------------------------------------------------
//Variables involved: fac[], inv[], TMD, N (upper bound)
//Note: choose TMD to suit the problem
// Returns x^p modulo TMD using recursive square-and-multiply.
// p is expected to be non-negative; the recursion halves p until it reaches 0.
ll pw(ll x,ll p)
{
    // x^0 is the empty product.
    if(p==0) return 1;
    const ll root=pw(x,p>>1);
    const ll squared=(root*root)%TMD;
    if((p&1)==0) return squared;
    // Classic pitfall: a large base must be reduced modulo TMD before it is
    // multiplied in, otherwise the product can overflow.
    return (squared*(x%TMD))%TMD;
}
// Modular inverse of x modulo TMD, computed as x^(TMD-2) (Fermat's little
// theorem; relies on TMD being prime and x not being a multiple of TMD).
ll inv(ll x)
{
    const ll exponent=TMD-2;
    return pw(x,exponent);
}
//-------------------------------------------------
int main()
{
scanf("%d",&T);
while(T--)
{
scanf("%lld%lld%lld",&n,&m,&k);
for(int i=1;i<=k;i++) scanf("%lld%lld%lld",&q[i],&x[i],&v[i]);
S1=(ll)m*(m+1)/2%TMD;S2=m;
for(int i=1;i<=k;i++) if(q[i]) S1=(S1+(ll)x[i]*(v[i]-1+TMD))%TMD,S2=(S2+(v[i]-1+TMD))%TMD;
ans=(S1+S1+(ll)m*(n-1)%TMD*S2%TMD)*n%TMD*inv(2)%TMD;
for(int i=1;i<=k;i++) if(!q[i]) ans=(ans+(ll)(v[i]-1+TMD)*(S1+(ll)m*(x[i]-1)%TMD*S2%TMD))%TMD;
printf("%lld\n",ans);
}
return 0;
}