DOTTIME - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2

Setter: Shahadat Hossain Shahin
Tester: Teja Vardhan Reddy
Editorialist: Taranpreet Singh

DIFFICULTY:

Medium

PREREQUISITES:

Segment Tree with Lazy Propagation, basic math.

PROBLEM:

Given an array A of length N and an integer M, find the value of \sum_{i = 1}^{N-M+1} \sum_{j = 1}^{N-M+1} F(i, j) where F(p, q) = \sum_{i = 0}^{M-1} A_{p+i} * A_{q+i}

QUICK EXPLANATION

  • By a couple of transformations, we can reduce the given expression to \sum_{i = 0}^{M-1} A[i, i+N-M] ^2 where A[i, j] is the sum of elements from i-th element to j-th element both inclusive.
  • We can maintain a segment tree with M leaves, each leaf storing values A[k, k+N-M]. Each non-leaf node shall store the sum of values as well as the sum of square of values.
  • A point update at position pos becomes a range update on all leaves whose interval contains pos, so the tree must keep the sum of squares correct under range additions.
  • The answer for each query shall be the sum of squares of the root node.

EXPLANATION

The given sum seems a pain, so let’s try to simplify it.

Using 0-based indices, we have S = \displaystyle\sum_{i = 0}^{N-M} \sum_{j = 0}^{N-M} \sum_{k = 0}^{M-1} A_{i+k}*A_{j+k}
Reordering the summations and factoring out the terms that do not depend on the inner index:
S = \displaystyle\sum_{k = 0}^{M-1} \left(\sum_{i = 0}^{N-M} A_{i+k}\right) * \left(\sum_{j = 0}^{N-M} A_{j+k}\right)

Writing the sum of subarray [l, r] as A[l, r], we get
S = \displaystyle\sum_{k = 0}^{M-1} \sum_{i = 0}^{N-M} A_{i+k}* A[k, k+N-M]

S = \displaystyle\sum_{k = 0}^{M-1} A[k, k+N-M] *\sum_{i = 0}^{N-M} A_{i+k}
S = \displaystyle\sum_{k = 0}^{M-1} A[k, k+N-M] * A[k, k+N-M]

S = \displaystyle\sum_{k = 0}^{M-1} A[k, k+N-M]^2

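We get a nice expression via transformations. Hurray! This allows us to solve subtask 1 by computing the above expression using the sliding window technique: the window A[k, k+N-M] has fixed length N-M+1, so moving from k to k+1 drops one element and adds one.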
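To make the simplification concrete, here is a minimal C++ sketch that evaluates both the original triple sum and the reduced sliding-window form, so the two can be compared on small inputs. The function names bruteForce and slidingWindow are made up for illustration, indices are 0-based, and the array values are assumed to be non-negative.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

const int64_t MOD = 998244353;

// Direct evaluation of the triple sum, O((N-M+1)^2 * M). Only for tiny inputs.
int64_t bruteForce(const std::vector<int64_t>& a, int m) {
    int n = (int)a.size();
    assert(1 <= m && m <= n);
    int64_t s = 0;
    for (int i = 0; i + m <= n; ++i)
        for (int j = 0; j + m <= n; ++j)
            for (int k = 0; k < m; ++k)
                s = (s + (a[i + k] % MOD) * (a[j + k] % MOD)) % MOD;
    return s;
}

// Sum over k of A[k, k+N-M]^2 using a window of length N-M+1, O(N).
int64_t slidingWindow(const std::vector<int64_t>& a, int m) {
    int n = (int)a.size();
    assert(1 <= m && m <= n);
    int len = n - m + 1;
    int64_t window = 0;
    for (int i = 0; i < len; ++i) window = (window + a[i]) % MOD;
    int64_t s = 0;
    for (int k = 0; k < m; ++k) {
        s = (s + window * window) % MOD;
        if (k + 1 < m) {
            // Slide the window: drop a[k], take a[k + len].
            window = ((window - a[k] % MOD + a[k + len] % MOD) % MOD + MOD) % MOD;
        }
    }
    return s;
}
```

For A = {1, 2, 3} and M = 2, the windows are 1+2 = 3 and 2+3 = 5, so both functions return 9 + 25 = 34.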

But updates are still a pain. Each update affects many terms of the above summation, so we need to handle it better.

Writing B_k = A[k, k+N-M], we need to find \sum_{k = 0}^{M-1} {B_k}^2 while updating B_k for each update.

When position pos is updated, every B_i that contains A_{pos} is affected. Since B_i covers positions i through i+N-M, A_{pos} is included in exactly those B_i with \max(0, pos-(N-M)) \leq i \leq \min(pos, M-1).
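The range of affected windows can be sketched as a small helper. The name affectedRange is hypothetical, and all indices are 0-based as in the derivation above.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>

// Returns the inclusive range [lo, hi] of window indices k such that
// B_k = A[k, k+N-M] contains position p (0-indexed).
std::pair<int, int> affectedRange(int n, int m, int p) {
    assert(1 <= m && m <= n && 0 <= p && p < n);
    int lo = std::max(0, p - (n - m));  // first window that still reaches p
    int hi = std::min(p, m - 1);        // last window that starts at or before p
    return {lo, hi};
}
```

For N = 5 and M = 3 the windows are [0,2], [1,3] and [2,4], so position 3 lies in windows 1 and 2, while position 2 lies in all three.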

Now, we have reduced the problem to,
Given an array, handle the following operations

  • Update range [L, R] by delta.
  • Find the sum of squares of values in the array.
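A naive reference for these two operations is handy for checking a faster implementation on small inputs. The following is only a sketch with made-up function names (rangeAdd, sumSquares); it assumes the initial values are non-negative and works modulo 998244353 like the solutions below.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

const int64_t MOD = 998244353;

// Adds d to every b[p] with l <= p <= r (0-indexed, inclusive). O(r - l + 1).
void rangeAdd(std::vector<int64_t>& b, int l, int r, int64_t d) {
    assert(0 <= l && l <= r && r < (int)b.size());
    d = (d % MOD + MOD) % MOD;  // d may be negative when a value decreases
    for (int p = l; p <= r; ++p) b[p] = (b[p] + d) % MOD;
}

// Returns the sum of b[p]^2 modulo MOD. O(size of b).
int64_t sumSquares(const std::vector<int64_t>& b) {
    int64_t s = 0;
    for (int64_t x : b) s = (s + (x % MOD) * (x % MOD)) % MOD;
    return s;
}
```

Each operation costs time linear in the range or array length, which is what the segment tree improves to O(log(M)).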

Anyone who has heard of segment trees knows what's about to happen :stuck_out_tongue:

For the segment tree, only the push function is a bit different because of the squared terms. The remaining details are standard lazy propagation.

Let us store the sum of values as well as sum of squares of values at each node. Say the sum of values at node i is S_i and the sum of squares of values at node i is SQ_i.

Suppose every element B_p in the range [L, R] covered by node i needs to be increased by D. Suppose the old values of S_i and SQ_i are sumX and sumX2 respectively.
The new S_i is given as sumX+(R-L+1)*D
The new SQ_i is given as \displaystyle SQ_i = \sum_{p = L}^{R} (B_p+D)^2 = \sum_{p = L}^{R} {B_p}^2 + \sum_{p = L}^{R} D^2 + 2*D*\sum_{p = L}^{R} B_p

\displaystyle SQ_i = sumX2 + D^2*(R-L+1) + 2*D*sumX

Hence, we can push updates in O(1) time, thus solving the problem.
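The O(1) push step can be sketched as follows. This is an illustrative fragment, not taken from any of the solutions below: the names Node and applyAdd are assumptions, len stands for R-L+1, and everything is reduced modulo 998244353.

```cpp
#include <cassert>
#include <cstdint>

const int64_t MOD = 998244353;

struct Node {
    int64_t sum;  // sum of B_p over the node's range, mod MOD
    int64_t sq;   // sum of B_p^2 over the node's range, mod MOD
};

// Returns the node's aggregates after adding d to each of its len elements.
Node applyAdd(Node x, int64_t len, int64_t d) {
    assert(len >= 1);
    d = (d % MOD + MOD) % MOD;  // normalise a possibly negative delta
    int64_t lenMod = len % MOD;
    Node r;
    // sumX2 + D^2 * (R-L+1) + 2 * D * sumX, using the old sum
    r.sq = (x.sq + d * d % MOD * lenMod + 2 * d % MOD * x.sum) % MOD;
    // sumX + (R-L+1) * D
    r.sum = (x.sum + d * lenMod) % MOD;
    return r;
}
```

For the values {1, 2, 3} (sum 6, sum of squares 14), adding 2 gives {3, 4, 5}, and the formula yields 14 + 4*3 + 2*2*6 = 50 = 9 + 16 + 25. Note that the sum of squares must be updated from the old sum, before the sum itself is changed.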

The final sum of squares after each query shall be stored as SQ_{root}

TIME COMPLEXITY

The time complexity is O(N+Q*log(M)) per test case: O(N) to compute the window sums and build the tree, and O(log(M)) for each of the Q updates.

SOLUTIONS:

Setter's Solution
#include <bits/stdc++.h>

using namespace std;

const int MAX = 500005;
const int MOD = 998244353;

inline int add(int a, int b) { return a + b >= MOD ? a + b - MOD : a + b; }
inline int sub(int a, int b) { return a - b < 0 ? a - b + MOD : a - b; }
inline int mul(int a, int b) { return (a * 1LL * b) % MOD; }

// ara[i] = sum of all A[id] such that A[id] can be in the 
// i'th position of a length M sequence
int ara[MAX];


// sum -> sum of the array values of respective range
// sqsum -> sum of the squared array values of respective range
// The segment tree is built over ara[]
struct node {
	int sum, sqsum;
} tree[4 * MAX];

int lazy[4 * MAX];

node Merge(node a, node b) {
	node ret;
	ret.sum = add(a.sum, b.sum);
	ret.sqsum = add(a.sqsum, b.sqsum);
	return ret;
}

void lazyUpdate(int n, int st, int ed) {
	if(lazy[n] != 0){
	    tree[n].sqsum = add(tree[n].sqsum, mul(add(lazy[n], lazy[n]), tree[n].sum));
	    tree[n].sqsum = add(tree[n].sqsum, mul(ed - st + 1, mul(lazy[n], lazy[n])));
	    tree[n].sum = add(tree[n].sum, mul(ed - st + 1, lazy[n]));
	    if(st != ed){
	        lazy[n + n] = add(lazy[n + n], lazy[n]);
	        lazy[n + n + 1] = add(lazy[n + n + 1], lazy[n]);
	    }
	    lazy[n] = 0;
	}
}

void build(int n, int st, int ed) {
	lazy[n] = 0;
	if(st == ed){
	    tree[n].sum = ara[st];
	    tree[n].sqsum = mul(ara[st], ara[st]);
	    return;
	}
	int mid = (st + ed) / 2;
	build(n + n, st, mid);
	build(n + n + 1, mid + 1, ed);
	tree[n] = Merge(tree[n + n], tree[n + n + 1]);
}

// adds v to the range [i, j] of ara
void update(int n, int st, int ed, int i, int j, int v) {
	assert(i <= j);
	if(i > j) return;
	lazyUpdate(n, st, ed);
	if(st > j or ed < i) return;
	if(st >= i and ed <= j) {
	    lazy[n] = add(lazy[n], v);
	    lazyUpdate(n, st, ed);
	    return;
	}
	int mid = (st + ed) / 2;
	update(n + n, st, mid, i, j, v);
	update(n + n + 1, mid + 1, ed, i, j, v);
	tree[n] = Merge(tree[n + n], tree[n + n + 1]);
}

int inp[MAX]; // input array A
int cum[MAX]; // cumulative sum array of the input

int L[MAX], R[MAX]; // input index i contributes to ara[L[i]] through ara[R[i]] when taking the dot products

int main() {
	ios_base::sync_with_stdio(false);
	cin.tie(0); cout.tie(0);

	// freopen("4.in", "r", stdin); 
	// freopen("4.out", "w", stdout);

	int sum_n = 0;
	int sum_q = 0;

	int T;
	cin >> T;
	for(int t=1;t<=T;t++) {
	    int n, m, q, id, v, w;
	    cin >> n >> m >> q;

	    assert(n >= 1 and n <= 5e5);
	    assert(m >= 1 and m <= 5e5);
	    assert(q >= 1 and q <= 5e5);


	    sum_n += n;
	    sum_q += q;
	    for(int i=1;i<=n;i++) {
	        cin >> inp[i];
	        assert(inp[i] >= 1 and inp[i] <= 5e5);
	        cum[i] = add(cum[i - 1], inp[i]);
	    }
	    int l = 1, r = n - m + 1;
	    for(int i=1;i<=m;i++) {
	        ara[i] = sub(cum[r], cum[l - 1]);
	        l++, r++;
	    }

	    build(1, 1, m);
	    
	    for(int i=1;i<=n;i++) L[i] = 1, R[i] = m;
	    for(int i=1;i<=m;i++) R[i] = i;
	    for(int i=m,j=n;i>=1;i--,j--) L[j] = i;

	    for(int i=1;i<=q;i++) {
	        cin >> id >> v;
	        assert(id >= 1 and id <= n);
	        assert(v >= 1 and v <= 5e5);
	        w = sub(v, inp[id]);
	        inp[id] = v;
	        update(1, 1, m, L[id], R[id], w);
	        cout << tree[1].sqsum << '\n';
	    }
	}
	assert(sum_n <= 1e6);
	assert(sum_q <= 1e6);
	return 0;
}
Tester's Solution
//teja349
#include <bits/stdc++.h>
#include <vector>
#include <set>
#include <map>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <utility>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <iomanip>
//setbase - cout << setbase (16); cout << 100 << endl; Prints 64
//setfill -   cout << setfill ('x') << setw (5); cout << 77 << endl; prints xxx77
//setprecision - cout << setprecision (14) << f << endl; Prints x.xxxx
//cout.precision(x)  cout<<fixed<<val;  // prints x digits after decimal in val
 
using namespace std; 
#define f(i,a,b) for(i=a;i<b;i++)
#define rep(i,n) f(i,0,n)
#define fd(i,a,b) for(i=a;i>=b;i--)
#define pb push_back
#define mp make_pair
#define vi vector< int >
#define vl vector< ll >
#define ss second
#define ff first
#define ll long long
#define pii pair< int,int >
#define pll pair< ll,ll >
#define inf (1000*1000*1000+5)
#define all(a) a.begin(),a.end()
#define tri pair<int,pii>
#define vii vector<pii>
#define vll vector<pll>
#define viii vector<tri>
#define mod (998244353)
#define pqueue priority_queue< int >
#define pdqueue priority_queue< int,vi ,greater< int > >
#define flush fflush(stdout) 
#define primeDEN 727999983
#define int ll

int seg[2123456],lazy[2123456];
int pre[512345]; 
int build(int node,int s,int e){
	lazy[node]=0;
	if(s==e){
		seg[node]=pre[s];
		return 0;
	}
	int mid=(s+e)/2;
	build(2*node,s,mid);
	build(2*node+1,mid+1,e);
	seg[node]=seg[2*node]+seg[2*node+1];
	if(seg[node]>=mod)
		seg[node]-=mod;
	return 0;
} 
int addm(int &a,int b){
	a+=b;
	a%=mod;
	return 0;
}
int update(int node,int s,int e,int l,int r,int val){
	if(lazy[node]){
		seg[node]+=(e-s+1)*lazy[node];
		seg[node]%=mod;
		if(s!=e){
			addm(lazy[2*node],lazy[node]);
			addm(lazy[2*node+1],lazy[node]);
		}
		lazy[node]=0;
	}
	if(r<s || e<l){
		return 0;
	}
	if(l<=s && e<=r){
		lazy[node]=val;
		seg[node]+=(e-s+1)*lazy[node];
		seg[node]%=mod;
		if(s!=e){
			addm(lazy[2*node],lazy[node]);
			addm(lazy[2*node+1],lazy[node]);
		}
		lazy[node]=0;
		return 0;
	}
	int mid=(s+e)/2;
	update(2*node,s,mid,l,r,val);
	update(2*node+1,mid+1,e,l,r,val);
	seg[node]=seg[2*node];
	addm(seg[node],seg[2*node+1]);
	return 0;
}

int query(int node,int s,int e,int l,int r){
	if(lazy[node]){
		seg[node]+=(e-s+1)*lazy[node];
		seg[node]%=mod;
		if(s!=e){
			addm(lazy[2*node],lazy[node]);
			addm(lazy[2*node+1],lazy[node]);
		}
		lazy[node]=0;
	}
	if(r<s || e<l){
		return 0;
	}
	if(l<=s && e<=r){
		return seg[node];
	}
	int mid=(s+e)/2;
	int val1=query(2*node,s,mid,l,r);
	int val2=query(2*node+1,mid+1,e,l,r);
	addm(val1,val2);
	return val1;
}
int a[512345];
signed main(){
	std::ios::sync_with_stdio(false); cin.tie(NULL);
	int t;
	cin>>t;
	while(t--){
		int n,m,q;
		cin>>n>>m>>q;
		int i;
		rep(i,n){
			cin>>a[i];
		}
		pre[0]=0;
		rep(i,n-m+1){
			pre[0]+=a[i];
		}
		f(i,1,m){
			pre[i]=pre[i-1]-a[i-1]+a[i+n-m];
			pre[i]%=mod;
			if(pre[i]<0)
				pre[i]+=mod;
		}
		build(1,0,m);
		int ans=0;
		int st,en,val;
		rep(i,n){
			st=0;
			en=m-1;
			st=max(st,i-(n-m));
			en=min(en,i);
			val=query(1,0,m,st,en);
			ans+=a[i]*val;
			ans%=mod;
			//cout<<st<<" "<<en<<" "<<val<<endl;
		}
		//cout<<ans<<endl;
		int pos,gg;
		rep(i,q){
			cin>>pos>>gg;
			pos--;
			st=0;
			en=m-1;
			st=max(st,pos-(n-m));
			en=min(en,pos);
			val=query(1,0,m,st,en);
			ans-=2*a[pos]*val;
			ans+=(en-st+1)*(a[pos]*a[pos]);
			update(1,0,m,st,en,gg-a[pos]);
			a[pos]=gg;
			val=query(1,0,m,st,en);
			ans+=2*a[pos]*val;
			ans-=(en-st+1)*a[pos]*a[pos];
			ans%=mod;
			if(ans<0)
				ans+=mod;
			cout<<ans<<"\n";
		}
	}
}
Editorialist's Solution
import java.util.*;
import java.io.*;
import java.text.*;
class DOTTIME{
	//SOLUTION BEGIN
	long MOD = 998244353;
	void pre() throws Exception{}
	void solve(int TC) throws Exception{
	    int n = ni(), m = ni(), q = ni();
	    long[] a = new long[n];
	    for(int i = 0; i< n; i++)a[i] = nl();
	    long[] window = new long[m];
	    long windowSum = 0;
	    for(int i = 0; i< n-m; i++){
	        hold(a[i] >= 0);
	    }
	    for(int i = 0; i< n-m; i++)windowSum = add(windowSum, a[i]);
	    for(int i = 0; i< m; i++){
	        windowSum = add(windowSum, a[i+n-m]);
	        window[i] = windowSum;
	        hold(window[i] >= 0);
	        windowSum = add(windowSum, MOD-a[i]);
	    }
	    SegTree t = new SegTree(window);
	    while(q-->0){
	        int p = ni()-1;
	        long delta = nl()-a[p];
	        a[p] = add(a[p], delta);
	        int left = Math.max(0, p-(n-m)), right = Math.min(p, m-1);
	        t.u(left, right, delta);
	        pn(t.sum());
	    }
	}
	long add(long x, long y){
	    if(x < 0)x += MOD;
	    if(y < 0)y += MOD;
	    return x+y>=MOD?(x+y-MOD):(x+y);
	}
	long mul(long x, long y){
	    if(x < 0)x += MOD;
	    if(y < 0)y += MOD;
	    return (x*y)%MOD;
	}
	class SegTree{
	    int m = 1;
	    long[] t, t2, lazy;
	    public SegTree(int n){
	        while(m<n)m<<=1;
	        t = new long[m<<1];
	        t2 = new long[m<<1];
	        lazy = new long[m<<1];
	    }
	    public SegTree(long[] a){
	        while(m< a.length)m<<=1;
	        t = new long[m<<1];
	        t2 = new long[m<<1];
	        lazy = new long[m<<1];
	        for(int i = 0; i< a.length; i++){
	            t[i+m] = a[i];
	            t2[i+m] = mul(a[i], a[i]);
	        }
	        for(int i = m-1; i > 0; i--){
	            t[i] = add(t[i<<1], t[i<<1|1]);
	            t2[i] = add(t2[i<<1], t2[i<<1|1]);
	        }
	    }
	    void push(int i, int ll, int rr){
	        if(lazy[i] != 0){
	            long sumX = t[i];
	            long sumX2 = t2[i];
	            long delta = lazy[i], sz = rr-ll+1;
	            t[i] = add(sumX, mul(delta, sz));
	            t2[i] = add(sumX2, add(mul(sz, mul(delta, delta)), mul(2, mul(delta, sumX))));
	            if(i < m){
	                lazy[i<<1] = add(lazy[i<<1], lazy[i]);
	                lazy[i<<1|1] = add(lazy[i<<1|1], lazy[i]);
	            }
	            lazy[i] = 0;
	        }
	    }
	    void u(int l, int r, long x){u(l, r, x, 0, m-1, 1);}
	    void u(int l, int r, long x, int ll, int rr, int i){
	        push(i, ll, rr);
	        if(l == ll && r == rr){
	            lazy[i] = add(lazy[i], x);
	            push(i, ll, rr);
	            return;
	        }
	        int mid = (ll+rr)/2;
	        if(r <= mid){
	            u(l, r, x, ll, mid, i<<1);
	            push(i<<1|1, mid+1, rr);
	        }
	        else if(l > mid){
	            push(i<<1, ll, mid);
	            u(l, r, x, mid+1, rr, i<<1|1);
	        }
	        else{
	            u(l, mid, x, ll, mid, i<<1);
	            u(mid+1, r, x, mid+1, rr, i<<1|1);
	        }
	        t[i] = add(t[i<<1], t[i<<1|1]);
	        t2[i] = add(t2[i<<1], t2[i<<1|1]);
	    }
	    long sum(){return t2[1];}
	}
	//SOLUTION END
	void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
	DecimalFormat df = new DecimalFormat("0.00000000000");
	static boolean multipleTC = true;
	FastReader in;PrintWriter out;
	void run() throws Exception{
	    in = new FastReader();
	    out = new PrintWriter(System.out);
	    //Solution Credits: Taranpreet Singh
	    int T = (multipleTC)?ni():1;
	    pre();for(int t = 1; t<= T; t++)solve(t);
	    out.flush();
	    out.close();
	}
	public static void main(String[] args) throws Exception{
	    new DOTTIME().run();
	}
	int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
	void p(Object o){out.print(o);}
	void pn(Object o){out.println(o);}
	void pni(Object o){out.println(o);out.flush();}
	String n()throws Exception{return in.next();}
	String nln()throws Exception{return in.nextLine();}
	int ni()throws Exception{return Integer.parseInt(in.next());}
	long nl()throws Exception{return Long.parseLong(in.next());}
	double nd()throws Exception{return Double.parseDouble(in.next());}

	class FastReader{
	    BufferedReader br;
	    StringTokenizer st;
	    public FastReader(){
	        br = new BufferedReader(new InputStreamReader(System.in));
	    }

	    public FastReader(String s) throws Exception{
	        br = new BufferedReader(new FileReader(s));
	    }

	    String next() throws Exception{
	        while (st == null || !st.hasMoreElements()){
	            try{
	                st = new StringTokenizer(br.readLine());
	            }catch (IOException  e){
	                throw new Exception(e.toString());
	            }
	        }
	        return st.nextToken();
	    }

	    String nextLine() throws Exception{
	        String str = "";
	        try{   
	            str = br.readLine();
	        }catch (IOException e){
	            throw new Exception(e.toString());
	        }  
	        return str;
	    }
	}
}

Feel free to share your approach. Suggestions are welcomed as always. :slight_smile:


If you look at the change in answer after each operation, it can be calculated using a simple sum based segment tree without maintaining sum of squares. https://www.codechef.com/viewsolution/32285096


@taran_1407 @teja349 @s_h_shahin @admin

Can you please check this answer. This is completely wrong but still it got 100 points. It should complete only sub task 1. It should under no circumstances complete subtask 2 for which it is returning 0. Pls check it, this is some sort of bug.

Even in the contest ratings, he got 100 points.

https://www.codechef.com/viewsolution/32241239

this is some serious bug in codechef(verdict of the link mentioned above)

he is doing if (q != 1)return 0;
it should get only 50 points

@admin

On his profile, he has submitted three answers with WA, 100 and 50 respectively, whereas the list of all submissions for this question shows only WA and 100.

Moreover since he got only task 1, hence he took only 0.25 seconds, hence he is ranked one in the correct submissions.

And this is clearly a bug since there is no chart below the code showing the time and verdict (WA, AC, SIGSEGV, SIGABRT, etc.) for each test case, as you see on other submissions.

Hi,

Thanks for reporting. We found this issue a day after the contest, and this affected (in a fortunate way for them) a few other users in that contest as well. This was related to the migration to the new checkers that we were experimenting with at that time, and was fixed soon after. But we decided not to change the verdicts as changing the verdict after the contest would be unfair to them because they did not have the opportunity to fix it during the contest. As it gave an unfair advantage only to a few users, we’ve decided not to change the ranklist and keep it as it is. This does affect all the other users, in the sense that their rank should have actually been slightly lower. Apologies for this issue.

https://www.codechef.com/viewsolution/33165291

can somebody help me pls