BANGER - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: satyam_343
Testers: apoorv_me, tabr
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Square-root decomposition

PROBLEM:

For an array B, define \text{score}(B) to be the maximum possible value of \text{dist}(C) across all non-decreasing integer arrays C that satisfy 1 \leq C_i \leq B_i for every i.
Here, \text{dist}(C) denotes the number of distinct elements in C.

You’re given an array A.
Process Q point updates on it; after each one, compute the sum of scores across all subarrays of A.

EXPLANATION:

Let’s forget about updates initially, and focus on just computing the sum of scores across all subarrays.
This is best built one step at a time.

Step 1: A single array

An obvious first step is to compute \text{score}(B) for a fixed array B of length N.
This turns out to be not that hard, because the natural greedy works. Set C_1 = 1, and then C_i = \min(C_{i-1} + 1, B_i) for i\gt 1. Whenever this forces us to place B_i \lt C_{i-1}, also lower every earlier value that exceeds B_i down to B_i so that C stays sorted. It isn’t too hard to see why this is optimal.

However, this greedy idea isn’t all that useful by itself, so let’s look at the process a little closer.
Going from left to right,

  • If i \leq B_i, we just place i.
  • When we first hit an index with i \gt B_i, we place B_i here (and, to maintain sortedness, lower every earlier value that is \gt B_i down to B_i).
    Let this first ‘bad’ index be x_1.
  • Note that for indices i \gt x_1, we no longer want to check whether i \leq B_i.
    Instead, we want to check whether B_{x_1} + (i - x_1) \leq B_i.
    Rewriting this, the check is i - B_i \leq x_1 - B_{x_1}.
    As long as this check is satisfied, we can repeatedly place the values B_{x_1} + 1, B_{x_1} + 2, \ldots
  • Let x_2 be the first index where this check is violated. Once again, observe that after being forced to place B_{x_2} here, further checks will be whether i - B_i \leq x_2 - B_{x_2}, and so on.

This tells us that the important quantity here is really the value (i - B_i).
In fact, if we set B_0 = 0 and pretend our array starts with 0 always, the initial i \leq B_i check also transforms into i - B_i \leq 0 - B_0.
Define d_i := i - B_i.

Observe that each time the inequality is violated, we move to a strictly larger value of d_i.
So, let K = \max_{i=0}^N(d_i), and m be the index at which this maximum occurs.
If there are multiple occurrences of K, choose the leftmost one as m.

Then, we know that we’re forced to set C_m = B_m, and all values from 1 to B_m - 1 appear before it.
Further, since d_m is the maximum, there are no further restrictions to its right, meaning we place the values B_m+1, B_m+2, \ldots, B_m + (N - m).

In other words, we simply have \text{score}(B) = N + (B_m - m) = N - K, since K = m - B_m by definition.
This is a rather useful criterion to have!

Step 2: All suffixes

Next, let’s figure out how to compute the scores of all suffixes of B.
Once again, let’s define d_i := i - B_i, with K being the maximum of all their values (including d_0 = 0) and m being the leftmost occurrence of this maximum.

Note that one way of creating the array C is as follows:

  • Set C_m := B_m.
  • For i \gt m, set C_i := C_m + (i - m).
  • For i \lt m, set C_i := \max(1, C_m - (m - i))

That is, have a long prefix of 1's, followed by 2, 3, 4, \ldots in such a way that C_m = B_m.
Given how the index m was computed, it’s not hard to prove that this construction is always possible.

In particular, with this construction, the last occurrence of 1 is at index m - B_m + 1 = K+1.
So,

  • For all indices i \geq K + 1, the suffix B[i, N] has a score of N-i+1, since it can be made to consist of distinct elements.
  • For all indices i \leq K, the suffix B[i, N] has a score equal to that of B itself, that being N - K.
    • Our construction already gives such suffixes a score of at least N - K. For the matching upper bound, note that \text{score}(B[i, N]) \geq \text{score}(B[i+1, N]) for every index i, meaning a longer suffix can never have a smaller score than a shorter one (quick proof: take any optimal solution for B[i+1, N] and insert 1 at its beginning). In particular, no suffix can score more than B itself, which is N - K.

So, the sum of scores across all suffixes is just

1+2+3+\ldots + (N-K) + K\cdot (N-K) = \frac{(N-K)(N-K+1)}{2} + K\cdot (N-K)

Step 3: All subarrays

Now that we know how to find the sum of scores across all suffixes of an array, it’s easy to extend the idea to summing the score across all subarrays: simply sum up the answers for the suffixes of each prefix of the array.

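The nice thing about this is that the array d stays exactly the same across the entire process. The subarrays ending at index r are exactly the suffixes of the prefix A[1, r], so their total score is \frac{(r-K_r)(r-K_r+1)}{2} + K_r\cdot(r-K_r), where K_r := \max_{0 \leq i \leq r} d_i is the prefix maximum of d up to index r.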

With this, we’re now ready for…

Step 4: Handling updates

With our array A in hand, let’s define d_i := i - A_i, as before.
Suppose we set A_x = y.
The only change to the array d is at index x, since we need to set d_x := x - y.

However, our answer computation requires us to care about prefix maximums of d, and those change for (possibly) every index \geq x.
There also isn’t really a nice structure on the prefix maximums: if d_x decreases, potentially a lot of recomputation is needed, beginning from index x+1.

Instead, we use square-root decomposition.

Let’s partition A into blocks of size B (a parameter, unrelated to the array B from earlier, that will be chosen later).
For each block, maintain prefix maximums of the d_i values.
Note that these prefix maximums are within the block only, and not global prefix maximums.
Also, for each index r, maintain the sum of scores of all subarrays ending at r, computed as if the in-block prefix maximum at r were the global prefix maximum, together with running sums of these values within the block. Note that this counts every subarray ending at r, even those that start outside the block; the reason for this should become clear soon.

This allows us to process both updates and queries at a reasonable complexity.

Updates

Each block maintains only the prefix maximums of d_i values within it.
When updating position x, simply recompute these values for the entire block containing x, which takes \mathcal{O}(B) time.

Queries

Process the blocks in order from left to right.

Suppose we’re processing the i-th block.
Let M_i denote the prefix maximum of d up to just before the i-th block’s start (this includes d_0 = 0, so M_i \geq 0).
Then, since we know the prefix maximums within the i-th block, and those are non-decreasing:

  • Binary search for the last index in the block whose in-block prefix maximum is \lt M_i.
    Every index up to and including this one has a global prefix maximum of exactly M_i.
    Compute the sum of answers for them all, which requires a bit of simple algebra.
  • For all remaining indices, their in-block prefix maximums will in fact be their global prefix maximums, so we can just sum up their already-computed scores!
  • Then, set M_{i+1} := \max(M_i, \text{maximum of } d \text{ over the } i\text{-th block}) and continue on.
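Here is one way to carry out the algebra for the first bullet, sketched in the notation of Steps 2 and 3. The symbols f, a, b, t, S_1, S_2 are introduced here only for illustration. By the suffix formula, when the global prefix maximum of d at r is K, the subarrays ending at r contribute f(r, K) = \frac{(r-K)(r-K+1)}{2} + K(r-K). Consider a run of indices a \leq r \leq b that all have global prefix maximum M, and substitute t = r - M:

\sum_{r=a}^{b} f(r, M) = \sum_{t=a-M}^{b-M} \left( \frac{t^2 + t}{2} + M t \right) = \frac{S_2 + S_1}{2} + M \cdot S_1

Here S_1 = \sum_{t=a-M}^{b-M} t and S_2 = \sum_{t=a-M}^{b-M} t^2. Each is a difference of the standard closed forms \sum_{t=1}^{n} t = \frac{n(n+1)}{2} and \sum_{t=1}^{n} t^2 = \frac{n(n+1)(2n+1)}{6}. The lower limit a - M is always at least 1, since d_i = i - A_i \leq i - 1 forces M \leq a - 1.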

There are \mathcal{O}(\frac{N}{B}) blocks, and in each one we perform one binary search in \mathcal{O}(\log B) to find the partition point.

So, each update + answer recomputation can be done in \mathcal{O}(B + \frac{N}{B}\log B) time.
Choosing B = \sqrt{N\log N} gives a complexity of \mathcal{O}(\sqrt{N\log N}) per update/query, which is fast enough to get AC.

TIME COMPLEXITY:

\mathcal{O}(N + Q\sqrt{N\log N}) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h> 
using namespace std;
#define ll long long
#define nline "\n"
#define all(x) x.begin(),x.end()
const ll MAX=200200;
const ll till=20;
const ll MOD=998244353;
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count()); 
ll sum[2][MAX];
ll getv(ll p,ll l,ll r){
	return sum[p][r]-sum[p][l-1];
}
ll get_sum(ll nax,ll l,ll r){
	if(l>r){
		return 0ll;
	}
	ll low=l-nax+1,high=r-nax+1;
	return getv(0,low,high)*(nax-1)+getv(1,low,high);
}
ll do_recomp=0;
const ll MAX_BLOCK=1005;
struct magic{
	ll len,lft,nax;
	ll a[MAX_BLOCK],pref[MAX_BLOCK];
	magic(ll len,ll lft){
		this->len=len;
		this->lft=lft;
		nax=0;
		for(ll i=1;i<=len+1;i++){
			a[i]=pref[i]=0;
		}
	}
	void recompute(){
		nax=0;
		for(ll i=1;i<=len;i++){
			ll pos=lft+i;
			nax=max(nax,pos-a[i]+1);
			pref[i]=pref[i-1]+get_sum(nax,pos,pos);
		}
	}
	void upd(ll ind,ll v){
		ind-=lft;
		a[ind]=v;
		if(do_recomp){
		    recompute();
		}
	}
	ll sum(ll l,ll r){
		l-=lft,r-=lft;
		if(l>r){
			return 0ll;
		}
		return pref[r]-pref[l-1];
	}
	ll sum_all(){
		return pref[len];
	}
};
struct get_index{
	ll len,lft,external_min;
	ll pos[MAX_BLOCK],suffix_min[MAX_BLOCK];
	get_index(ll len,ll lft){
		this->len=len;
		this->lft=lft;
		external_min=MAX;
		for(ll i=1;i<=len+1;i++){
			pos[i]=suffix_min[i]=MAX;
		}
	}
	void recompute(){
		for(ll i=len;i>=1;i--){
			suffix_min[i]=min(suffix_min[i+1],pos[i]);
		}
	}
	void upd(ll v,ll ind){
		v-=lft;
		pos[v]=ind;
		if(do_recomp){
		    recompute();
		}
	}
	ll min_val(ll v){
		v-=lft;
		if(v>len){
			return MAX;
		}
		return min(external_min,suffix_min[v]);
	}
};
void solve(){
	ll n,q; cin>>n>>q;
	vector<ll> a(n+5),block_number(n+5),b(n+5);
	vector<set<ll>> track(n+5);
	ll block_size=min(1000ll,n);
	do_recomp=0;
	for(ll i=1;i<=n;i++){
		cin>>a[i];
		b[i]=max(1ll,i-a[i]+1);
		track[i].insert(n+1);
		track[b[i]].insert(i);
		block_number[i]=(i+block_size-1)/block_size-1;
	}
	vector<magic> to_get_sum; 
	vector<get_index> to_get_min; 
	auto init=[&](){
		ll l=1,r=block_size;
		while(l<=n){
			r=min(r,n);
			magic for_sum(r-l+1,l-1);
			get_index for_index(r-l+1,l-1);
			for(ll i=l;i<=r;i++){
				for_sum.upd(i,a[i]);
				for_index.upd(i,*track[i].begin());
			}
			for_sum.recompute();
			for_index.recompute();
			to_get_sum.push_back(for_sum);
			to_get_min.push_back(for_index);
			l=r+1,r+=block_size;
		}
		ll nin=n+1;
		for(ll i=block_number[n];i>=0;i--){
			to_get_min[i].external_min=n+1;
			nin=min(nin,to_get_min[i].suffix_min[1]);
		}
	};
	auto find_ans=[&](){
		ll l=1,r=block_size,nax=1,ans=0;
		while(l<=n){
			r=min(r,n); ll till=n;
			if(nax!=n){
				till=to_get_min[block_number[nax+1]].min_val(nax+1);
			}
			if(till>r){
				ans+=get_sum(nax,l,r);
			}
			else{
				ans+=get_sum(nax,l,till-1)+to_get_sum[block_number[l]].sum(till,r);
			}
			nax=max(nax,to_get_sum[block_number[l]].nax);
			l=r+1,r+=block_size;
			
		}
		return ans;
	};
	init();
	do_recomp=1;
	auto update=[&](ll ind,ll v){
		track[b[ind]].erase(ind);
		to_get_min[block_number[b[ind]]].upd(b[ind],*track[b[ind]].begin());
		a[ind]=v; b[ind]=max(1ll,ind-a[ind]+1);
		track[b[ind]].insert(ind); 
		to_get_min[block_number[b[ind]]].upd(b[ind],*track[b[ind]].begin());
		to_get_sum[block_number[ind]].upd(ind,a[ind]);
		ll nin=n+1;
		for(ll i=block_number[n];i>=0;i--){
			to_get_min[i].external_min=nin;
			nin=min(nin,to_get_min[i].suffix_min[1]);
		}
	};
	while(q--){
		ll ind,v; cin>>ind>>v;
		update(ind,v);
		cout<<find_ans()<<nline;
	}
}
int main()                                                                                 
{         
  ios_base::sync_with_stdio(false);                         
  cin.tie(NULL);                                  
  ll test_cases=1;                 
  cin>>test_cases;
  for(ll i=1;i<MAX;i++){
  	 sum[0][i]=sum[0][i-1]+i;
  }
  for(ll i=1;i<MAX;i++){
  	 sum[1][i]=sum[1][i-1]+(i*(i+1))/2;
  }
  while(test_cases--){
  	 solve();
  }
  cerr<<"Time:"<<1000*((double)clock())/(double)CLOCKS_PER_SEC<<"ms\n"; 
} 
Tester's code (apoorv_me, C++)
#include<bits/stdc++.h>
using namespace std;

#ifdef LOCAL
#include "../debug.h"
#else
#define dbg(...)
#endif

int32_t main() {
  ios_base::sync_with_stdio(0);
  cin.tie(0);

  constexpr int SQRT = 1000;
  auto __solve_testcase = [&](int test) {
    int N, NQ;  cin >> N >> NQ;
    vector<int> A(N);
    for(auto &a: A)
      cin >> a;

    for(int i = 0 ; i < N ; ++i)
      A[i] = i + 1 - A[i];

    int M = (N + SQRT - 1) / SQRT;
    vector<int> pf(N);
    vector<int> L(M), R(M), pos(N);
    vector<int64_t> ans(N);
    for(int i = 0 ; i < N ; i += SQRT) {
      L[i / SQRT] = i;
      R[i / SQRT] = min(i + SQRT, N);
    }
    for(int i = 0 ; i < M ; ++i) for(int j = L[i] ; j < R[i] ; ++j)
      pos[j] = i;

    auto aps = [&](int64_t a) { // sum of 1 + 2 + 3 - - -
      return a * (a + 1) / 2;
    };
    auto sqs = [&](int64_t b) { // sum of 1^2 + 2^2 - - - 
      return b * (b + 1) * (2 * b + 1) / 6;
    };

    auto get = [&](int64_t x, int64_t s) {
      if(x < 0) x = 0;
      return aps(s - x) + x * (s - x);
    };

    auto gett = [&](int64_t x, int64_t s1, int64_t s2) {
      if(x < 0) x = 0;
      // 1 / 2 s ^ 2 - (2x - 1) * s + x * (x - 1);
      int64_t res = sqs(s2) - sqs(s1 - 1);
      res -= (2 * x - 1) * (aps(s2) - aps(s1 - 1));
      res += x * (x - 1) * (s2 - s1 + 1);
      res /= 2;

      // x * s - x * x
      res += x * (aps(s2) - aps(s1 - 1));
      res -= x * x * (s2 - s1 + 1);
      return res;
    };

    auto doit = [&](int blok) {
      ans[blok] = 0;
      for(int j = L[blok] ; j < R[blok] ; ++j) {
        pf[j] = A[j];
        if(j > L[blok])
          pf[j] = max(pf[j], pf[j - 1]);
        ans[j] = get(pf[j], j + 1);
        if(j > L[blok])
          ans[j] += ans[j - 1];
      }
    };

    for(int i = M - 1 ; i >= 0 ; --i)
      doit(i);

    for(int q = 0 ; q < NQ ; ++q) {
      int p, val; cin >> p >> val;
      --p; 
      A[p] = p - val + 1;
      doit(pos[p]);
      int64_t res = ans[R[0] - 1], bst = pf[R[0] - 1];
      for(int i = 1 ; i < M ; ++i) {
        if(pf[L[i]] >= bst) {
          res += ans[R[i] - 1];
          bst = pf[R[i] - 1];
          continue;
        }
        if(pf[R[i] - 1] < bst) {
          res += gett(bst, L[i] + 1, R[i]);
          continue;
        }
        int p = lower_bound(pf.begin() + L[i], pf.begin() + R[i], bst) - pf.begin();
        res += ans[R[i] - 1] - ans[p - 1];
        res += gett(bst, L[i] + 1, p);
        bst = pf[R[i] - 1];
      }
      cout << res << '\n';
    }

  };
  
  int NumTest = 1;
  cin >> NumTest;
  for(int testno = 1; testno <= NumTest ; ++testno) {
    __solve_testcase(testno);
  }
  
  return 0;
}
Tester's code (tabr, C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

void solve(int n, int q, vector<int>& a, vector<int>& p, vector<int>& x) {
    for (int i = 0; i < n; i++) {
        a[i] = max(i - a[i] + 1, 0);
    }

    int b = (int) floor(sqrt(n + 1));
    vector f((n - 1) / b + 1, vector<pair<int, int>>());
    vector g(n, 0LL);
    auto update = [&](int t) {
        int low = b * t;
        int high = min(n, b * (t + 1));
        int mx = -1;
        f[t].clear();
        for (int i = low; i < high; i++) {
            if (mx < a[i]) {
                mx = a[i];
                f[t].emplace_back(a[i], i);
            }
            g[i] = (i + mx + 2) * 1LL * (i - mx + 1);
        }
        f[t].emplace_back(n, high);
        for (int i = high - 2; i >= low; i--) {
            g[i] += g[i + 1];
        }
    };

    for (int t = 0; t <= (n - 1) / b; t++) {
        update(t);
    }

    auto get_sum = [&](long long t, long long mx) {
        long long res = 0;
        res += t * t * t;
        res += 3 * t * t;
        res += 2 * t;
        res -= 3 * mx * t;
        res -= 3 * mx * mx * t;
        return res / 3;
    };

    for (int i = 0; i < q; i++) {
        p[i]--;
        a[p[i]] = max(p[i] - x[i] + 1, 0);
        update(p[i] / b);
        long long res = 0;
        int mx = -1;
        for (int t = 0; t <= (n - 1) / b; t++) {
            auto it = lower_bound(f[t].begin(), f[t].end(), make_pair(mx + 1, -1));
            if (it->first < n) {
                res += g[it->second];
            }
            res += get_sum(it->second, mx) - get_sum(b * t, mx);
            mx = max(mx, f[t].rbegin()[1].first);
        }
        cout << res / 2 << "\n";
    }
}

////////////////////////////////////////

#define IGNORE_CR

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
#ifdef IGNORE_CR
            if (c == '\r') {
                continue;
            }
#endif
            buffer.push_back((char) c);
        }
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            assert(!isspace(buffer[pos]));
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

int main() {
    input_checker in;
    int tt = in.readInt(1, 2e5);
    in.readEoln();
    int sn = 0;
    int sq = 0;
    while (tt--) {
        int n = in.readInt(1, 2e5);
        in.readSpace();
        sn += n;
        int q = in.readInt(1, 2e5);
        in.readEoln();
        sq += q;
        auto a = in.readInts(n, 1, n);
        in.readEoln();
        vector<int> p(q), x(q);
        for (int i = 0; i < q; i++) {
            p[i] = in.readInt(1, n);
            in.readSpace();
            x[i] = in.readInt(1, n);
            in.readEoln();
        }
        solve(n, q, a, p, x);
    }
    cerr << sn << endl;
    cerr << sq << endl;
    assert(sn <= 2e5);
    assert(sq <= 2e5);
    in.readEof();
    return 0;
}