WATERBUCKETS - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: mshcherba
Tester: raysh_07
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Square-root decomposition, segment trees, binary search

PROBLEM:

You have an array A of length N, all of its elements being \leq M.
Process Q updates and queries on it:

  • Point updates: given x and y, set A_x := y
  • Range queries: given L and R, greedily partition the array [A_L, A_{L+1}, \ldots, A_R] into subarrays, each time picking the largest prefix whose sum doesn’t exceed M.
    Find the number of subarrays in this partition.

EXPLANATION:

Let’s first try to solve a simpler version of this task: one where there are no point updates.
Let \text{jump}[i] denote the index we end up at if we take the longest possible subarray starting from i with sum \leq M.
Finding all the \text{jump}[i] values isn’t too hard, for instance with binary search or a two-pointer algorithm.
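As a minimal sketch of the two-pointer option: the function name `computeJump`, the 0-indexed storage, and the assumption that all values are non-negative are ours, not taken from the problem statement.

```cpp
#include <cassert>
#include <vector>

// jump[i] = last index j >= i such that A[i] + ... + A[j] <= M (0-indexed).
// Assumes 0 <= A[i] <= M, so every window contains at least one element.
std::vector<int> computeJump(const std::vector<long long>& A, long long M) {
    int n = (int)A.size();
    std::vector<int> jump(n);
    long long sum = 0;  // sum of A[i..r-1]
    int r = 0;          // one past the right end of the current window
    for (int i = 0; i < n; ++i) {
        // Extend the window while the next element still fits.
        while (r < n && sum + A[r] <= M) sum += A[r++];
        assert(r > i);  // guaranteed because A[i] <= M
        jump[i] = r - 1;
        sum -= A[i];    // drop A[i] before moving the left end to i + 1
    }
    return jump;
}
```

The right pointer only moves forward, so the whole table is built in \mathcal{O}(N).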

Then, to answer a range query (L, R), we must repeat the following process:

  • Let our current index be x. Initially, x = L.
  • Go from x to \text{jump}[x] + 1, and increment the answer by 1.
    This represents the entire operation of taking a bucket and filling it as much as possible, then going to the next starting point.
  • Repeat the above process as long as x \leq R.
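The loop above can be sketched as follows, assuming a 0-indexed `jump` array as defined earlier (the name `countBuckets` is hypothetical):

```cpp
#include <cassert>
#include <vector>

// Counts the subarrays of the greedy partition of A[L..R] (0-indexed, inclusive),
// given jump[x] = last index of the longest valid subarray starting at x.
int countBuckets(const std::vector<int>& jump, int L, int R) {
    assert(0 <= L && L <= R && R < (int)jump.size());
    int answer = 0;
    // Each iteration is one bucket: it covers x..jump[x], possibly cut short by R.
    for (int x = L; x <= R; x = jump[x] + 1) ++answer;
    return answer;
}
```

For A = [3, 2, 4, 1, 5] and M = 5 the jump array is [1, 1, 3, 3, 4], and the query (0, 4) visits the starting points 0, 2, 4.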

This still takes \mathcal{O}(N) time (per query), of course.
However, it’s easily improved to \mathcal{O}(\log N): since we only care about the number of jumps until we move past R, we can use binary lifting.
That is, for each index i and each integer j, compute the position you reach if you jump 2^j times. Note that we only care about j \leq \log_2 N, so this is \mathcal{O}(N\log N) memory and time (since the jumps of length 2^j can be computed as two jumps of length 2^{j-1}).
Then, multiple jumps can be quickly simulated, allowing for a query to be answered in \mathcal{O}(\log N) time.
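A sketch of such a table for the static version, assuming a 0-indexed `jump` array and using index N as a sentinel meaning "past the end" (the struct name `JumpTable` is ours):

```cpp
#include <cassert>
#include <vector>

struct JumpTable {
    int n, levels;
    std::vector<std::vector<int>> up;  // up[j][i] = start index after 2^j jumps from i

    explicit JumpTable(const std::vector<int>& jump) : n((int)jump.size()), levels(1) {
        while ((1 << levels) <= n) ++levels;            // enough levels for up to n jumps
        up.assign(levels, std::vector<int>(n + 1, n));  // sentinel n maps to itself
        for (int i = 0; i < n; ++i) up[0][i] = jump[i] + 1;
        for (int j = 1; j < levels; ++j)
            for (int i = 0; i <= n; ++i)
                up[j][i] = up[j - 1][up[j - 1][i]];     // two jumps of length 2^(j-1)
    }

    // Number of subarrays in the greedy partition of A[L..R].
    int query(int L, int R) const {
        assert(0 <= L && L <= R && R < n);
        int x = L, jumps = 0;
        for (int j = levels - 1; j >= 0; --j)
            if (up[j][x] <= R) {  // still lands on a start inside [L, R]
                x = up[j][x];
                jumps += 1 << j;
            }
        return jumps + 1;         // the subarray starting at x is the last one
    }
};
```

The answer is the number of starting points that are \leq R, which is why one is added after the last successful jump.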

Unfortunately, there’s no simple way to quickly update such a binary lifting table when point updates are involved.
Nevertheless, we’ll use the idea of trying to perform multiple jumps at the same time — we just need to apply it to some information that’s a bit easier to keep updated.


Let’s divide the array into contiguous blocks of B elements each (for some constant B that will be decided later).

For each index i, let’s compute two values:

  • X_i: the number of subarrays we obtain if we start the process at i and end it at the rightmost element of the block containing i.
  • Y_i: the sum of the X_i-th (that is, the last) subarray in this process. This subarray is cut off at the block boundary, so it may still have room left.
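One way to fill X_i and Y_i for a single block is a right-to-left sweep with two pointers. This is only a sketch under our own assumptions: 0-indexed arrays, values with 0 \leq A_i \leq M, and a hypothetical helper name `buildBlock`.

```cpp
#include <cassert>
#include <vector>

// Fills X[i] and Y[i] for every i in the block [lo, hi] (inclusive, 0-indexed).
// X and Y must already have size A.size().
void buildBlock(const std::vector<long long>& A, long long M, int lo, int hi,
                std::vector<int>& X, std::vector<long long>& Y) {
    long long sum = 0;  // sum of A[i..r]
    int r = hi;         // right end of the greedy subarray starting at i, clipped to hi
    for (int i = hi; i >= lo; --i) {
        sum += A[i];
        while (sum > M) sum -= A[r--];  // shrink from the right until it fits
        assert(r >= i);                 // guaranteed because A[i] <= M
        if (r == hi) {                  // the subarray reaches the block boundary
            X[i] = 1;
            Y[i] = sum;
        } else {                        // continue from r + 1, already computed
            X[i] = X[r + 1] + 1;
            Y[i] = Y[r + 1];
        }
    }
}
```

For A = [3, 2, 4, 1, 5], M = 5 and a single block covering the whole array, this gives X = [3, 3, 2, 2, 1] and every Y_i equal to 5.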

Suppose we knew X_i and Y_i for each index.
Then, a query (L, R) can be answered as follows:

  • If L and R lie in the same block, directly simulate the process in \mathcal{O}(B) (recall that each block has size B).
  • Otherwise, we start at L, and use X_L steps to jump directly to the endpoint of the block containing it.
  • Now, we enter the next block; but this time, with only M - Y_L space remaining, because the last of those X_L subarrays already holds a sum of Y_L.
    Binary search lets us find how many more elements can be taken, after which we end up at some index i.
  • From i, again we use X_i to go directly to the end of its block, and so on.
  • Finally, when we reach the block that contains R, process the last part directly in \mathcal{O}(B).

This way, we perform at most one binary search for each block between L and R; and otherwise do \mathcal{O}(1) work per block since we have the precomputed X_i and Y_i values.
Apart from that, we do \mathcal{O}(B) work in at most one block.
There are \frac{N}{B} blocks, so each query is answered in \mathcal{O}(B + \frac{N}{B} \log B).
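The query procedure can be sketched as below for the version without updates, where a plain prefix-sum array is enough for the binary search. The names `prefixSums` and `queryBlocks`, the 0-indexing, and the layout where block b covers indices bB to (b+1)B - 1 are all assumptions of this sketch; X and Y are taken as already computed.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// pre[i] = A[0] + ... + A[i-1]
std::vector<long long> prefixSums(const std::vector<long long>& A) {
    std::vector<long long> pre(A.size() + 1, 0);
    for (int i = 0; i < (int)A.size(); ++i) pre[i + 1] = pre[i] + A[i];
    return pre;
}

int queryBlocks(const std::vector<long long>& A, const std::vector<long long>& pre,
                const std::vector<int>& X, const std::vector<long long>& Y,
                long long M, int B, int L, int R) {
    assert(0 <= L && L <= R && R < (int)A.size());
    int ans = 0, x = L;
    // While x lies in a block strictly before the block of R, skip to the block end.
    while (x <= R && x / B != R / B) {
        ans += X[x];
        long long room = M - Y[x];   // space left in the last counted subarray
        int next = (x / B + 1) * B;  // first index of the following block
        // Largest e with A[next] + ... + A[e-1] <= room; e becomes the new start.
        auto it = std::upper_bound(pre.begin() + next, pre.end(), pre[next] + room);
        x = (int)(it - pre.begin()) - 1;
    }
    if (x <= R) {                    // finish inside the block of R directly
        ++ans;
        long long cur = 0;
        for (int i = x; i <= R; ++i) {
            if (cur + A[i] > M) { ++ans; cur = A[i]; }
            else cur += A[i];
        }
    }
    return ans;
}
```

If the extended subarray runs past R, the loop simply stops: that subarray was already counted through X, so nothing more needs to be added.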

Now, we need to deal with point updates, and keep X_i and Y_i updated while we’re at it.
To do that, we can use a segment tree.
Specifically, if A_k is updated, we do the following:

  • X_i and Y_i values for indices not in the block of k don’t change at all.
    Even within the block containing k, only the values for i \leq k change.
  • For each i = k, k-1, k-2, \ldots in order, till the leftmost index of the block, do the following:
    • Find the largest subarray with sum \leq M starting at index i. This can be done in \mathcal{O}(\log B) with binary search on a segment tree.
      Let this subarray end at index j.
    • If j is the last element of the block, set X_i = 1 and Y_i to be the sum of the elements from i to the end of its block (including A_i itself).
    • Otherwise, set X_i to X_{j+1} + 1 and Y_i to Y_{j+1}.

Each update is thus processed in \mathcal{O}(B\log B) time.
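The update steps can be sketched as follows. Note the substitution: instead of a segment tree restricted to one block, this sketch uses a single Fenwick tree over the whole array, so each search costs \mathcal{O}(\log N) rather than \mathcal{O}(\log B). The names `Fenwick` and `pointUpdate`, the 0-indexing, and the assumption 0 \leq A_i \leq M are ours.

```cpp
#include <algorithm>
#include <vector>

struct Fenwick {
    int n;
    std::vector<long long> t;
    explicit Fenwick(int n_) : n(n_), t(n_ + 1, 0) {}
    void add(int i, long long d) { for (++i; i <= n; i += i & -i) t[i] += d; }
    // Sum of the first cnt elements.
    long long prefix(int cnt) const {
        long long s = 0;
        for (int i = cnt; i > 0; i -= i & -i) s += t[i];
        return s;
    }
    // Largest cnt with prefix(cnt) <= limit (requires non-negative values).
    int maxPrefix(long long limit) const {
        int pos = 0, step = 1;
        while (step * 2 <= n) step *= 2;
        for (; step > 0; step /= 2)
            if (pos + step <= n && t[pos + step] <= limit) {
                pos += step;
                limit -= t[pos];
            }
        return pos;
    }
};

// Sets A[k] = v and refreshes X, Y for the indices lo..k of the block containing k.
// X and Y must be valid for the rest of that block before the call.
void pointUpdate(std::vector<long long>& A, Fenwick& fw, long long M, int B,
                 std::vector<int>& X, std::vector<long long>& Y, int k, long long v) {
    int n = (int)A.size();
    fw.add(k, v - A[k]);
    A[k] = v;
    int lo = (k / B) * B, hi = std::min(n, lo + B) - 1;
    for (int i = k; i >= lo; --i) {
        long long before = fw.prefix(i);  // A[0] + ... + A[i-1]
        // Last index of the greedy subarray starting at i, clipped to the block.
        int j = std::min(hi, fw.maxPrefix(before + M) - 1);
        if (j == hi) {
            X[i] = 1;
            Y[i] = fw.prefix(hi + 1) - before;
        } else {
            X[i] = X[j + 1] + 1;
            Y[i] = Y[j + 1];
        }
    }
}
```

Calling `pointUpdate` on the last index of a block with its current value recomputes the whole block, which gives a simple (if not the fastest) way to build the initial tables.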

Choosing B = \sqrt{N} gets us a complexity of \mathcal{O}(\sqrt{N}\log N) for both updates and queries, which is fast enough.
It’s possible to choose slightly better B than this, as seen by the computation here, but in practice choosing a constant B that’s reasonable will work just fine.

TIME COMPLEXITY:

\mathcal{O}(N + Q\cdot (B + \frac{N}{B}\log B + B\log B)) per testcase, for a constant B.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;

const int MAX = 100'007;
const int LOG = 17;
const int K = 300;

int n, c;
int a[MAX];
long long s[MAX];
int lift[MAX][LOG];

void update()
{
	int ptr = 0, sum = 0;
	for(int i = 0; i < n; i++)
	{
		s[i] = (i == 0 ? 0 : s[i - 1]) + a[i];
		while(ptr < n && sum + a[ptr] <= c)
			sum += a[ptr++];
		lift[i][0] = ptr;
		sum -= a[i];
	}
	lift[n][0] = n;
	for(int i = n; i >= 0; i--)
		for(int j = 1; j < LOG; j++)
			lift[i][j] = lift[lift[i][j - 1]][j - 1];
}

int doLift(int& l, int r)
{
	int res = 0;
	for(int j = LOG - 1; j >= 0; j--)
	{
		if(lift[l][j] <= r)
		{
			l = lift[l][j];
			res += 1 << j;
		}
	}
	return res;
}

int main()
{
	cin >> n >> c;
	for(int i = 0; i < n; i++)
		cin >> a[i];
	update();
	int q;
	cin >> q;
	set<int> changed{n};
	while(q--)
	{
		int t;
		cin >> t;
		if(t == 1)
		{
			int l, r, res = 0;
			cin >> l >> r;
			l--;
			r--;
			auto it = changed.lower_bound(l);
			while(l <= r)
			{
				if(*it > r)
				{
					res += doLift(l, r) + 1;
					break;
				}
				res += doLift(l, *it - 1);
				long long sum = (*it == 0 ? 0 : s[*it - 1]) - (l == 0 ? 0 : s[l - 1]);
				assert(sum <= c);
				auto itR = it;
				itR++;
				while(true)
				{
					if(sum + a[*it] > c)
					{
						l = *it;
						res++;
						break;
					}
					sum += a[*it];
					long long sIt = s[*it];
					int pos = upper_bound(s + *it + 1, s + *itR, c - sum + sIt) - s;
					if(pos < *itR)
					{
						l = pos;
						it++;
						res++;
						break;
					}
					sum += s[*itR - 1] - sIt;
					it++;
					itR++;
					if(*it == n)
					{
						l = n;
						res++;
						break;
					}
				}
			}
			cout << res << "\n";
		}
		else
		{
			int i;
			cin >> i;
			i--;
			cin >> a[i];
			changed.insert(i);
			if(changed.size() > K)
			{
				update();
				changed = {n};
			}
		}
	}
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18
#define f first
#define s second

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

int n, m, q;
const int N = 1e5 + 69;
const int B = 600;
int block[N], a[N], seg[4 * N], p[N];
pair <int, int> dp[N];
int go[N];

void Print(int l, int r, int pos){
    cout << l << " " << r << " " << seg[pos] << "\n";
    if (l == r) return;
    int mid = (l + r) / 2;
    Print(l, mid, pos * 2);
    Print(mid + 1, r, pos * 2 + 1);
}

void upd(int l, int r, int pos, int qp){
    if (l == r){
        seg[pos] = a[l];
        return;
    }

    int mid = (l + r) / 2;
    if (qp <= mid) upd(l, mid, pos * 2, qp);
    else upd(mid + 1, r, pos * 2 + 1, qp);

    seg[pos] = seg[pos * 2] + seg[pos * 2 + 1];
}

int range_sum(int l, int r, int pos, int ql, int qr){
   // cout << "RANGE SUM " << l << " " << r << " " << pos << " " << ql << " " << qr << "\n";
    if (l >= ql && r <= qr){
       // cout << "ADDED " << seg[pos] << "\n";
        return seg[pos];
    } else if (l > qr || r < ql){
       // cout << "RETURNED\n";
        return 0;
    } else {
       // cout << "GOING DEEPER\n";
        int mid = (l + r) / 2;
        return range_sum(l, mid, pos * 2, ql, qr) + range_sum(mid + 1, r, pos * 2 + 1, ql, qr);
    }
}

int find_last(int l, int r, int pos, int v){
    if (l == r && v >= seg[pos]) return n + 1;
    if (l == r) return l;

    int mid = (l + r) / 2;
    if (v < seg[pos * 2]) return find_last(l, mid, pos * 2, v);
    else return find_last(mid + 1, r, pos * 2 + 1, v - seg[pos * 2]);
}

void updbl(int l, int r){
    p[l] = a[l];
    go[l] = l;
    for (int i = l + 1; i <= r; i++){
        p[i] = p[i - 1] + a[i];
        if (p[i] <= m) go[l] = i;
    }

    int x = l + 1;
    for (int i = l + 1; i <= r; i++){
        while (x != r && p[x + 1] - p[i - 1] <= m){
            x++;
        }
        go[i] = x;
    }

    for (int i = r; i >= l; i--){
        if (go[i] == r){
            dp[i] = {0, i};
        } else {
            dp[i].f = dp[go[i] + 1].f + 1;
            dp[i].s = dp[go[i] + 1].s;
        }
    }
}

void Solve() 
{
    // Holes Logic : dp[i] = maximum moves you can do while staying inside this block 
    cin >> n >> m;

    for (int i = 1; i <= n; i++){
        cin >> a[i];
        block[i] = i / B;
        upd(1, n, 1, i);
    }
    
    // cout << "PRINTING SEGTREE\n";
    // Print(1, n, 1);

    for (int i = 0; i <= n / B; i++){
        updbl(max(1LL, i * B), min(n, (i + 1) * B - 1));
    }
    // for (int i = 1; i <= n; i++){
    //   //  cout << go[i] << " \n"[i == n];
    //      cout << dp[i].f << " " << dp[i].s << "\n";
    // }

    cin >> q;

    for (int _ = 1; _ <= q; _++){
        int t; cin >> t;

        if (t == 1){
            int l, r; cin >> l >> r;
            int ans = 0;
            while (l / B != r / B){
                ans += dp[l].f;
                l = dp[l].s;
                int pref_sum = range_sum(1, n, 1, 1, l - 1);
                int ok = find_last(1, n, 1, pref_sum + m) - 1;
                ans++;
                if (ok >= r){
                    l = r + 1;
                    break;
                }
                l = ok + 1;
            }
            
            int curr = 0;
            if (l <= r)
            ans++;
            for (int i = l; i <= r; i++){
                curr += a[i];
                if (curr > m) curr = a[i], ans++;
            }

            cout << ans << "\n";
        } else {
            int p, x; cin >> p >> x;
            a[p] = x;
            int i = p / B;
            updbl(max(1LL, i * B), min(n, (i + 1) * B - 1));
            upd(1, n, 1, p);
        }
    }
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
  //  cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}

an altogether evil complexity that should be avoided at all costs.