PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Setter: Shahadat Hossain Shahin
Tester: Teja Vardhan Reddy
Editorialist: Taranpreet Singh
DIFFICULTY:
Medium
PREREQUISITES:
Segment Tree with Lazy Propagation, basic math.
PROBLEM:
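Given an array A of length N and an integer M, find the value of \sum_{i = 1}^{N-M+1} \sum_{j = 1}^{N-M+1} F(i, j), where F(p, q) = \sum_{k = 0}^{M-1} A_{p+k} * A_{q+k}. There are Q queries, each setting one element A_{pos} to a new value; after every query, print this sum modulo 998244353.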
QUICK EXPLANATION
- A couple of transformations reduce the given expression to \sum_{i = 0}^{M-1} A[i, i+N-M]^2 (0-indexed), where A[i, j] is the sum of the elements from the i-th to the j-th, both inclusive.
- Maintain a segment tree with M leaves, the k-th leaf storing A[k, k+N-M]. Each node stores the sum of the values in its range as well as the sum of their squares.
- Updating position pos changes, by the same delta, every leaf whose interval contains pos. These leaves form a contiguous range, so we need a range update that also keeps the sum of squares correct.
- The answer to each query is the sum of squares stored at the root.
EXPLANATION
The given sum seems a pain, so let's try to simplify it. From here on, positions are 0-indexed.
We have S = \displaystyle\sum_{i = 0}^{N-M} \sum_{j = 0}^{N-M} \sum_{k = 0}^{M-1} A_{i+k}*A_{j+k}
Moving the sum over k outermost and factoring the sum over j (which does not depend on i) out of the sum over i:
S = \displaystyle\sum_{k = 0}^{M-1} \sum_{i = 0}^{N-M} A_{i+k}* \sum_{j = 0}^{N-M} A_{j+k}
Writing the sum of subarray [l, r] as A[l, r], we get
S = \displaystyle\sum_{k = 0}^{M-1} \sum_{i = 0}^{N-M} A_{i+k}* A[k, k+N-M]
S = \displaystyle\sum_{k = 0}^{M-1} A[k, k+N-M] *\sum_{i = 0}^{N-M} A_{i+k}
S = \displaystyle\sum_{k = 0}^{M-1} A[k, k+N-M] * A[k, k+N-M]
S = \displaystyle\sum_{k = 0}^{M-1} A[k, k+N-M]^2
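To sanity-check the derivation, here is a small C++ sketch that evaluates S both directly from the definition and via the closed form, using plain 64-bit integers without the modulus. The function names `bruteForceS` and `closedFormS` are our own, not taken from any of the solutions below, and the brute force is only meant for tiny inputs.

```cpp
#include <vector>

// Direct evaluation of S from the definition (0-indexed windows), no modulus.
// O((N-M+1)^2 * M): only for tiny inputs. Assumes 1 <= m <= a.size().
long long bruteForceS(const std::vector<long long>& a, int m) {
    int n = (int)a.size();
    long long s = 0;
    for (int i = 0; i <= n - m; ++i)
        for (int j = 0; j <= n - m; ++j)
            for (int k = 0; k < m; ++k)
                s += a[i + k] * a[j + k];
    return s;
}

// Closed form: S = sum over k = 0..M-1 of A[k, k+N-M]^2.
long long closedFormS(const std::vector<long long>& a, int m) {
    int n = (int)a.size();
    long long s = 0;
    for (int k = 0; k < m; ++k) {
        long long b = 0;  // B_k = A[k, k+N-M]
        for (int t = k; t <= k + n - m; ++t) b += a[t];
        s += b * b;
    }
    return s;
}
```

For A = {1, 2, 3, 4} and M = 2, both functions return 6^2 + 9^2 = 117.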
We get a nice expression. Hurray! This allows us to solve subtask 1 by computing the above expression with the sliding window technique.
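A minimal sketch of the sliding-window computation, assuming 0-indexed non-negative values and 1 <= M <= N. The names `windowSums` and `staticAnswer` are hypothetical; the result is reduced modulo 998244353, the modulus used by the solutions below.

```cpp
#include <vector>

const long long MOD = 998244353;

// B_k = A[k, k+N-M] for k = 0..M-1 (0-indexed), reduced mod MOD.
// B_0 is summed directly; each later window drops a[k-1] and gains a[k+N-M].
std::vector<long long> windowSums(const std::vector<long long>& a, int m) {
    int n = (int)a.size();
    std::vector<long long> b(m);
    long long cur = 0;
    for (int t = 0; t <= n - m; ++t) cur = (cur + a[t]) % MOD;
    b[0] = cur;
    for (int k = 1; k < m; ++k) {
        cur = (cur - a[k - 1] % MOD + MOD + a[k + n - m]) % MOD;
        b[k] = cur;
    }
    return b;
}

// Answer without updates: sum of B_k^2 mod MOD.
long long staticAnswer(const std::vector<long long>& a, int m) {
    long long s = 0;
    for (long long bk : windowSums(a, m)) s = (s + bk * bk) % MOD;
    return s;
}
```

The window is O(N) in total, since each element enters and leaves it at most once.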
But updates are still a pain: each update affects many terms of the above summation, so we need to handle them better.
Writing B_k = A[k, k+N-M], we need to find \sum_{k = 0}^{M-1} {B_k}^2 while updating B_k for each update.
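Updating position p changes exactly those B_i whose window contains A_p. Since B_i covers positions i to i+N-M, it contains p if and only if i \leq p \leq i+N-M, that is, p-(N-M) \leq i \leq p. Combined with 0 \leq i \leq M-1, A_p is included in exactly the B_i with \max(0, p-(N-M)) \leq i \leq \min(p, M-1).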
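The bound can be checked mechanically. The sketch below is 0-indexed and assumes 1 <= M <= N; `affectedRange` and `checkAffectedRange` are illustrative names. It computes the range and compares it with the definition of B_i for every position.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>

// Inclusive range [lo, hi] of indices i whose window B_i = A[i, i+N-M]
// contains position p. Assumes 1 <= m <= n and 0 <= p < n.
std::pair<int, int> affectedRange(int n, int m, int p) {
    assert(1 <= m && m <= n && 0 <= p && p < n);
    int lo = std::max(0, p - (n - m));
    int hi = std::min(p, m - 1);
    return std::make_pair(lo, hi);
}

// Cross-checks affectedRange against "i <= p <= i+N-M" for every p and i.
bool checkAffectedRange(int n, int m) {
    for (int p = 0; p < n; ++p) {
        std::pair<int, int> range = affectedRange(n, m, p);
        for (int i = 0; i < m; ++i) {
            bool inWindow = (i <= p && p <= i + n - m);
            bool inRange = (range.first <= i && i <= range.second);
            if (inWindow != inRange) return false;
        }
    }
    return true;
}
```

For N = 4 and M = 2, position 1 lies in both B_0 = A[0, 2] and B_1 = A[1, 3], so the range is [0, 1].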
Now, we have reduced the problem to,
Given an array, handle the following operations
- Update range [L, R] by delta.
- Find the sum of squares of values in the array.
Anyone who has heard of segment trees knows what's about to happen.
For the segment tree, only the push function is a bit different, because of the squared terms. The rest is a standard lazy-propagation segment tree.
Let us store the sum of values as well as sum of squares of values at each node. Say the sum of values at node i is S_i and the sum of squares of values at node i is SQ_i.
Suppose node i covers the range [L, R], and every value in that range needs to be increased by D. Let the old values of S_i and SQ_i be sumX and sumX2 respectively.
The new S_i is given as sumX+(R-L+1)*D
The new SQ_i is given as \displaystyle SQ_i = \sum_{p = L}^{R} (B_p+D)^2 = \sum_{p = L}^{R} {B_p}^2 + \sum_{p = L}^{R} D^2 + 2*D*\sum_{p = L}^{R} B_p
\displaystyle SQ_i = sumX2 + D^2*(R-L+1) + 2*D*sumX
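The update rule translates directly into code. This sketch assumes D has already been reduced into [0, MOD), so a negative delta is passed as D + MOD; `Node` and `applyAdd` are hypothetical names. The square sum is updated first because it needs the old S_i.

```cpp
const long long MOD = 998244353;

// Aggregates kept at a segment-tree node (both reduced mod MOD).
struct Node {
    long long sum = 0;  // S_i: sum of the covered values
    long long sq = 0;   // SQ_i: sum of their squares
};

// Adds d to each of the len values covered by the node, in O(1).
// Assumes 0 <= d < MOD (pass d + MOD for a negative delta).
void applyAdd(Node& node, long long len, long long d) {
    // SQ_i += 2*D*S_i + len*D^2, using the old S_i
    node.sq = (node.sq + 2 * d % MOD * node.sum + len % MOD * (d * d % MOD)) % MOD;
    // S_i += len*D
    node.sum = (node.sum + len % MOD * d) % MOD;
}
```

For example, the values {3, 4} give S = 7 and SQ = 25. Adding 2 turns them into {5, 6}, and applyAdd produces S = 11 and SQ = 25 + 2*2*7 + 2*4 = 61, which matches 5^2 + 6^2.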
Hence, a pending update can be pushed to a node in O(1) time, which is all the lazy segment tree needs, and the problem is solved.
After each query, the answer is the sum of squares at the root, SQ_{root}.
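Putting the pieces together, here is a compact sketch of the whole approach. It uses a different lazy convention from the solutions below: a pending addition is applied to the node itself as soon as it is recorded, and `lazy_` only holds what still has to reach the children. All names (`SquareSumTree`, `answerQueries`) and the 0-indexed `(p, v)` query format are assumptions for illustration; values are assumed non-negative and 1 <= M <= N.

```cpp
#include <algorithm>
#include <utility>
#include <vector>

const long long MOD = 998244353;

// Lazy segment tree over B_0..B_{M-1}: range add, global sum of squares (mod MOD).
class SquareSumTree {
public:
    explicit SquareSumTree(const std::vector<long long>& b)
        : n_((int)b.size()), sum_(4 * b.size()), sq_(4 * b.size()), lazy_(4 * b.size()) {
        build(1, 0, n_ - 1, b);
    }
    // Adds d (may be negative) to B_l..B_r, with 0 <= l <= r < M.
    void rangeAdd(int l, int r, long long d) {
        d %= MOD;
        if (d < 0) d += MOD;
        update(1, 0, n_ - 1, l, r, d);
    }
    long long sumOfSquares() const { return sq_[1]; }  // SQ_root

private:
    int n_;
    std::vector<long long> sum_, sq_, lazy_;

    // Adds d to all len values under `node` in O(1); sq_ uses the old sum_.
    void apply(int node, long long len, long long d) {
        sq_[node] = (sq_[node] + 2 * d % MOD * sum_[node] + len * (d * d % MOD)) % MOD;
        sum_[node] = (sum_[node] + len * d) % MOD;
        lazy_[node] = (lazy_[node] + d) % MOD;  // still owed to the children
    }
    void pushDown(int node, int lo, int hi) {
        if (lazy_[node] == 0) return;
        int mid = (lo + hi) / 2;
        apply(2 * node, mid - lo + 1, lazy_[node]);
        apply(2 * node + 1, hi - mid, lazy_[node]);
        lazy_[node] = 0;
    }
    void pull(int node) {
        sum_[node] = (sum_[2 * node] + sum_[2 * node + 1]) % MOD;
        sq_[node] = (sq_[2 * node] + sq_[2 * node + 1]) % MOD;
    }
    void build(int node, int lo, int hi, const std::vector<long long>& b) {
        if (lo == hi) {
            sum_[node] = b[lo] % MOD;
            sq_[node] = sum_[node] * sum_[node] % MOD;
            return;
        }
        int mid = (lo + hi) / 2;
        build(2 * node, lo, mid, b);
        build(2 * node + 1, mid + 1, hi, b);
        pull(node);
    }
    void update(int node, int lo, int hi, int l, int r, long long d) {
        if (r < lo || hi < l) return;
        if (l <= lo && hi <= r) { apply(node, hi - lo + 1, d); return; }
        pushDown(node, lo, hi);
        int mid = (lo + hi) / 2;
        update(2 * node, lo, mid, l, r, d);
        update(2 * node + 1, mid + 1, hi, l, r, d);
        pull(node);
    }
};

// Hypothetical driver: each query (p, v) sets a[p] = v (0-indexed) and records S mod MOD.
std::vector<long long> answerQueries(std::vector<long long> a, int m,
                                     const std::vector<std::pair<int, long long>>& queries) {
    int n = (int)a.size();
    std::vector<long long> prefix(n + 1, 0), b(m), out;
    for (int t = 0; t < n; ++t) prefix[t + 1] = prefix[t] + a[t];
    for (int k = 0; k < m; ++k) b[k] = prefix[k + n - m + 1] - prefix[k];  // B_k
    SquareSumTree tree(b);
    for (const auto& q : queries) {
        int p = q.first;
        long long delta = q.second - a[p];
        a[p] = q.second;
        tree.rangeAdd(std::max(0, p - (n - m)), std::min(p, m - 1), delta);
        out.push_back(tree.sumOfSquares());
    }
    return out;
}
```

For A = {1, 2, 3, 4}, M = 2 and the queries (0, 5), (3, 1), (1, 10), the window sums become {10, 9}, {10, 6} and {18, 14}, so the answers are 181, 136 and 520.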
TIME COMPLEXITY
The time complexity is O(N + M + Q*log(M)) per test case: O(N + M) for the prefix sums and the segment tree build, and O(log(M)) per query.
SOLUTIONS:
Setter's Solution
#include <bits/stdc++.h>
using namespace std;
const int MAX = 500005;
const int MOD = 998244353;
inline int add(int a, int b) { return a + b >= MOD ? a + b - MOD : a + b; }
inline int sub(int a, int b) { return a - b < 0 ? a - b + MOD : a - b; }
inline int mul(int a, int b) { return (a * 1LL * b) % MOD; }
// ara[i] = sum of all A[id] such that A[id] can be in the
// i'th position of a length M sequence
int ara[MAX];
// sum -> sum of the array values of respective range
// sqsum -> sum of the squared array values of respective range
// The segment tree is built over ara[]
struct node {
int sum, sqsum;
} tree[4 * MAX];
int lazy[4 * MAX];
node Merge(node a, node b) {
node ret;
ret.sum = add(a.sum, b.sum);
ret.sqsum = add(a.sqsum, b.sqsum);
return ret;
}
void lazyUpdate(int n, int st, int ed) {
if(lazy[n] != 0){
tree[n].sqsum = add(tree[n].sqsum, mul(add(lazy[n], lazy[n]), tree[n].sum)); // += 2*lazy*sum, using the old sum
tree[n].sqsum = add(tree[n].sqsum, mul(ed - st + 1, mul(lazy[n], lazy[n])));
tree[n].sum = add(tree[n].sum, mul(ed - st + 1, lazy[n]));
if(st != ed){
lazy[n + n] = add(lazy[n + n], lazy[n]);
lazy[n + n + 1] = add(lazy[n + n + 1], lazy[n]);
}
lazy[n] = 0;
}
}
void build(int n, int st, int ed) {
lazy[n] = 0;
if(st == ed){
tree[n].sum = ara[st];
tree[n].sqsum = mul(ara[st], ara[st]);
return;
}
int mid = (st + ed) / 2;
build(n + n, st, mid);
build(n + n + 1, mid + 1, ed);
tree[n] = Merge(tree[n + n], tree[n + n + 1]);
}
// adds v to the range [i, j] of ara
void update(int n, int st, int ed, int i, int j, int v) {
if(i > j) assert(false);
if(i > j) return;
lazyUpdate(n, st, ed);
if(st > j or ed < i) return;
if(st >= i and ed <= j) {
lazy[n] = add(lazy[n], v);
lazyUpdate(n, st, ed);
return;
}
int mid = (st + ed) / 2;
update(n + n, st, mid, i, j, v);
update(n + n + 1, mid + 1, ed, i, j, v);
tree[n] = Merge(tree[n + n], tree[n + n + 1]);
}
int inp[MAX]; // input array A
int cum[MAX]; // cumulative sum array of the input
int L[MAX], R[MAX]; // inp[i] is part of the window sums ara[L[i]] .. ara[R[i]]
int main() {
ios_base::sync_with_stdio(false);
cin.tie(0); cout.tie(0);
// freopen("4.in", "r", stdin);
// freopen("4.out", "w", stdout);
int sum_n = 0;
int sum_q = 0;
int T;
cin >> T;
for(int t=1;t<=T;t++) {
int n, m, q, id, v, w;
cin >> n >> m >> q;
assert(n >= 1 and n <= 5e5);
assert(m >= 1 and m <= 5e5);
assert(q >= 1 and q <= 5e5);
sum_n += n;
sum_q += q;
for(int i=1;i<=n;i++) {
cin >> inp[i];
assert(inp[i] >= 1 and inp[i] <= 5e5);
cum[i] = add(cum[i - 1], inp[i]);
}
int l = 1, r = n - m + 1;
for(int i=1;i<=m;i++) {
ara[i] = sub(cum[r], cum[l - 1]);
l++, r++;
}
build(1, 1, m);
for(int i=1;i<=n;i++) L[i] = 1, R[i] = m;
for(int i=1;i<=m;i++) R[i] = i;
for(int i=m,j=n;i>=1;i--,j--) L[j] = i;
for(int i=1;i<=q;i++) {
cin >> id >> v;
assert(id >= 1 and id <= n);
assert(v >= 1 and v <= 5e5);
w = sub(v, inp[id]);
inp[id] = v;
update(1, 1, m, L[id], R[id], w);
cout << tree[1].sqsum << '\n';
}
}
assert(sum_n <= 1e6);
assert(sum_q <= 1e6);
return 0;
}
Tester's Solution
//teja349
#include <bits/stdc++.h>
#include <vector>
#include <set>
#include <map>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <utility>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <iomanip>
//setbase - cout << setbase (16); cout << 100 << endl; Prints 64
//setfill - cout << setfill ('x') << setw (5); cout << 77 << endl; prints xxx77
//setprecision - cout << setprecision (14) << f << endl; Prints x.xxxx
//cout.precision(x) cout<<fixed<<val; // prints x digits after decimal in val
using namespace std;
#define f(i,a,b) for(i=a;i<b;i++)
#define rep(i,n) f(i,0,n)
#define fd(i,a,b) for(i=a;i>=b;i--)
#define pb push_back
#define mp make_pair
#define vi vector< int >
#define vl vector< ll >
#define ss second
#define ff first
#define ll long long
#define pii pair< int,int >
#define pll pair< ll,ll >
#define inf (1000*1000*1000+5)
#define all(a) a.begin(),a.end()
#define tri pair<int,pii>
#define vii vector<pii>
#define vll vector<pll>
#define viii vector<tri>
#define mod (998244353)
#define pqueue priority_queue< int >
#define pdqueue priority_queue< int,vi ,greater< int > >
#define flush fflush(stdout)
#define primeDEN 727999983
#define int ll
int seg[2123456],lazy[2123456];
int pre[512345];
int build(int node,int s,int e){
lazy[node]=0;
if(s==e){
seg[node]=pre[s];
return 0;
}
int mid=(s+e)/2;
build(2*node,s,mid);
build(2*node+1,mid+1,e);
seg[node]=seg[2*node]+seg[2*node+1];
if(seg[node]>=mod)
seg[node]-=mod;
return 0;
}
int addm(int &a,int b){
a+=b;
a%=mod;
return 0;
}
int update(int node,int s,int e,int l,int r,int val){
if(lazy[node]){
seg[node]+=(e-s+1)*lazy[node];
seg[node]%=mod;
if(s!=e){
addm(lazy[2*node],lazy[node]);
addm(lazy[2*node+1],lazy[node]);
}
lazy[node]=0;
}
if(r<s || e<l){
return 0;
}
if(l<=s && e<=r){
lazy[node]=val;
seg[node]+=(e-s+1)*lazy[node];
seg[node]%=mod;
if(s!=e){
addm(lazy[2*node],lazy[node]);
addm(lazy[2*node+1],lazy[node]);
}
lazy[node]=0;
return 0;
}
int mid=(s+e)/2;
update(2*node,s,mid,l,r,val);
update(2*node+1,mid+1,e,l,r,val);
seg[node]=seg[2*node];
addm(seg[node],seg[2*node+1]);
return 0;
}
int query(int node,int s,int e,int l,int r){
if(lazy[node]){
seg[node]+=(e-s+1)*lazy[node];
seg[node]%=mod;
if(s!=e){
addm(lazy[2*node],lazy[node]);
addm(lazy[2*node+1],lazy[node]);
}
lazy[node]=0;
}
if(r<s || e<l){
return 0;
}
if(l<=s && e<=r){
return seg[node];
}
int mid=(s+e)/2;
int val1=query(2*node,s,mid,l,r);
int val2=query(2*node+1,mid+1,e,l,r);
addm(val1,val2);
return val1;
}
int a[512345];
signed main(){
std::ios::sync_with_stdio(false); cin.tie(NULL);
int t;
cin>>t;
while(t--){
int n,m,q;
cin>>n>>m>>q;
int i;
rep(i,n){
cin>>a[i];
}
pre[0]=0;
rep(i,n-m+1){
pre[0]+=a[i];
}
f(i,1,m){
pre[i]=pre[i-1]-a[i-1]+a[i+n-m];
pre[i]%=mod;
if(pre[i]<0)
pre[i]+=mod;
}
build(1,0,m);
int ans=0;
int st,en,val;
rep(i,n){
st=0;
en=m-1;
st=max(st,i-(n-m));
en=min(en,i);
val=query(1,0,m,st,en);
ans+=a[i]*val;
ans%=mod;
//cout<<st<<" "<<en<<" "<<val<<endl;
}
//cout<<ans<<endl;
int pos,gg;
rep(i,q){
cin>>pos>>gg;
pos--;
st=0;
en=m-1;
st=max(st,pos-(n-m));
en=min(en,pos);
val=query(1,0,m,st,en);
ans-=2*a[pos]*val;
ans+=(en-st+1)*(a[pos]*a[pos]);
update(1,0,m,st,en,gg-a[pos]);
a[pos]=gg;
val=query(1,0,m,st,en);
ans+=2*a[pos]*val;
ans-=(en-st+1)*a[pos]*a[pos];
ans%=mod;
if(ans<0)
ans+=mod;
cout<<ans<<"\n";
}
}
}
Editorialist's Solution
import java.util.*;
import java.io.*;
import java.text.*;
class DOTTIME{
//SOLUTION BEGIN
long MOD = 998244353;
void pre() throws Exception{}
void solve(int TC) throws Exception{
int n = ni(), m = ni(), q = ni();
long[] a = new long[n];
for(int i = 0; i< n; i++)a[i] = nl();
long[] window = new long[m];
long windowSum = 0;
for(int i = 0; i< n-m; i++){
hold(a[i] >= 0);
}
for(int i = 0; i< n-m; i++)windowSum = add(windowSum, a[i]);
for(int i = 0; i< m; i++){
windowSum = add(windowSum, a[i+n-m]);
window[i] = windowSum;
hold(window[i] >= 0);
windowSum = add(windowSum, MOD-a[i]);
}
SegTree t = new SegTree(window);
while(q-->0){
int p = ni()-1;
long delta = nl()-a[p];
a[p] = add(a[p], delta);
int left = Math.max(0, p-(n-m)), right = Math.min(p, m-1);
t.u(left, right, delta);
pn(t.sum());
}
}
long add(long x, long y){
if(x < 0)x += MOD;
if(y < 0)y += MOD;
return x+y>=MOD?(x+y-MOD):(x+y);
}
long mul(long x, long y){
if(x < 0)x += MOD;
if(y < 0)y += MOD;
return (x*y)%MOD;
}
class SegTree{
int m = 1;
long[] t, t2, lazy;
public SegTree(int n){
while(m<n)m<<=1;
t = new long[m<<1];
t2 = new long[m<<1];
lazy = new long[m<<1];
}
public SegTree(long[] a){
while(m< a.length)m<<=1;
t = new long[m<<1];
t2 = new long[m<<1];
lazy = new long[m<<1];
for(int i = 0; i< a.length; i++){
t[i+m] = a[i];
t2[i+m] = mul(a[i], a[i]);
}
for(int i = m-1; i > 0; i--){
t[i] = add(t[i<<1], t[i<<1|1]);
t2[i] = add(t2[i<<1], t2[i<<1|1]);
}
}
void push(int i, int ll, int rr){
if(lazy[i] != 0){
long sumX = t[i];
long sumX2 = t2[i];
long delta = lazy[i], sz = rr-ll+1;
t[i] = add(sumX, mul(delta, sz));
t2[i] = add(sumX2, add(mul(sz, mul(delta, delta)), mul(2, mul(delta, sumX))));
if(i < m){
lazy[i<<1] = add(lazy[i<<1], lazy[i]);
lazy[i<<1|1] = add(lazy[i<<1|1], lazy[i]);
}
lazy[i] = 0;
}
}
void u(int l, int r, long x){u(l, r, x, 0, m-1, 1);}
void u(int l, int r, long x, int ll, int rr, int i){
push(i, ll, rr);
if(l == ll && r == rr){
lazy[i] = add(lazy[i], x);
push(i, ll, rr);
return;
}
int mid = (ll+rr)/2;
if(r <= mid){
u(l, r, x, ll, mid, i<<1);
push(i<<1|1, mid+1, rr);
}
else if(l > mid){
push(i<<1, ll, mid);
u(l, r, x, mid+1, rr, i<<1|1);
}
else{
u(l, mid, x, ll, mid, i<<1);
u(mid+1, r, x, mid+1, rr, i<<1|1);
}
t[i] = add(t[i<<1], t[i<<1|1]);
t2[i] = add(t2[i<<1], t2[i<<1|1]);
}
long sum(){return t2[1];}
}
//SOLUTION END
void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
DecimalFormat df = new DecimalFormat("0.00000000000");
static boolean multipleTC = true;
FastReader in;PrintWriter out;
void run() throws Exception{
in = new FastReader();
out = new PrintWriter(System.out);
//Solution Credits: Taranpreet Singh
int T = (multipleTC)?ni():1;
pre();for(int t = 1; t<= T; t++)solve(t);
out.flush();
out.close();
}
public static void main(String[] args) throws Exception{
new DOTTIME().run();
}
int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
void p(Object o){out.print(o);}
void pn(Object o){out.println(o);}
void pni(Object o){out.println(o);out.flush();}
String n()throws Exception{return in.next();}
String nln()throws Exception{return in.nextLine();}
int ni()throws Exception{return Integer.parseInt(in.next());}
long nl()throws Exception{return Long.parseLong(in.next());}
double nd()throws Exception{return Double.parseDouble(in.next());}
class FastReader{
BufferedReader br;
StringTokenizer st;
public FastReader(){
br = new BufferedReader(new InputStreamReader(System.in));
}
public FastReader(String s) throws Exception{
br = new BufferedReader(new FileReader(s));
}
String next() throws Exception{
while (st == null || !st.hasMoreElements()){
try{
st = new StringTokenizer(br.readLine());
}catch (IOException e){
throw new Exception(e.toString());
}
}
return st.nextToken();
}
String nextLine() throws Exception{
String str = "";
try{
str = br.readLine();
}catch (IOException e){
throw new Exception(e.toString());
}
return str;
}
}
}
Feel free to share your approach. Suggestions are welcome as always.