MAXSEG - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3

Authors: tle_99
Tester: Danny Mittal
Editorialist: Nishank Suresh

DIFFICULTY:

Medium

PREREQUISITES:

Binary lifting, binary search, prefix sums

PROBLEM:

You are given two arrays A and B of length N. You have to process Q queries of the following form:

Each query is a pair of integers (L, R) with L \leq R. Consider a partition of [L, R] into subintervals [l_1, r_1], [l_2, r_2], \ldots, [l_k, r_k] such that B_{l_i} + B_{l_i + 1} + \ldots + B_{r_i} = 1 for each 1 \leq i \leq k.

The goodness of such a partition is the number of subintervals [l_i, r_i] such that A_{l_i} + A_{l_i + 1} + \ldots + A_{r_i} \geq 0. For each query, find the maximum possible goodness across all partitions satisfying the given condition.

QUICK EXPLANATION:

  • For each index i, compute jump_i — the optimal right endpoint of an interval starting at i.
  • This is done greedily, by binary searching on prefix sums of all those indices j such that B_1 + B_2 + \ldots + B_j = B_1 + B_2 + \ldots + B_{i-1} + 1.
  • Given a query [L, R], the number of intervals it is divided into is always B_L + \ldots + B_R.
  • Use binary lifting on the jump array to compute the answer for the first B_L + \ldots + B_R - 1 intervals in \mathcal{O}(\log N).
  • The last interval’s contribution to the answer can then be computed in \mathcal{O}(1) using prefix sums.

EXPLANATION:

SUBTASK 1

The first subtask has N \leq 1000 and only one query. For convenience, let’s assume L = 1 and R = N, because the other elements will not affect the answer anyway.

Because N \leq 1000, this subtask can be solved by a rather simple \mathcal{O}(N^2) dynamic programming approach.

Let dp_i denote the maximum goodness of a partition of the prefix [1, i]. If there is no way to partition [1, i] into subarrays satisfying the condition on B, we set dp_i = -1.
Our answer is dp_N (or 0, if dp_N = -1), with the base case being dp_0 = 0.
Suppose we fix the last subarray of the partition, i.e., fix the last subarray to be [j, i] where j \leq i.
If B_j + B_{j+1} + \ldots + B_i \neq 1, this cannot be chosen as a subarray in any partition. Similarly, if dp_{j-1} = -1, this cannot be chosen as the last subarray (since doing so would leave no way to partition the remaining portion).
If both dp_{j-1} \geq 0 and B_j + B_{j+1} + \ldots + B_i = 1, we have

dp_i = \begin{cases} \max(dp_i, dp_{j-1} + 1), & \text{if } sum(j, i) \geq 0 \\ \max(dp_i, dp_{j-1}), & \text{otherwise} \end{cases}

Iterating over all j once we fix i gives us an \mathcal{O}(N^2) solution, assuming we are able to calculate B_j + B_{j+1} + \ldots + B_i and sum(j, i) = A_j + A_{j+1} + \ldots + A_i for a given i and j in \mathcal{O}(1).

This can be done in several ways - either by iterating j downwards from i and maintaining the current sum of the A and B arrays, or simply by using prefix sums.

Code for this subtask.
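A minimal C++ sketch of this \mathcal{O}(N^2) DP, assuming the single query covers the whole array. The function name maxGoodnessQuadratic and the 0-indexed input vectors are illustrative choices, not taken from the linked code; dp[i] = -1 encodes an impossible prefix exactly as in the text, and prefix sums give the \mathcal{O}(1) range sums.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>
using namespace std;

// O(N^2) DP for one query covering the whole array (Subtask 1).
// A and B are 0-indexed; dp[i] refers to the prefix [1, i] in 1-indexed terms.
int maxGoodnessQuadratic(const vector<long long>& A, const vector<int>& B) {
    int n = A.size();
    vector<long long> pref(n + 1, 0);  // pref[i] = A_1 + ... + A_i
    vector<int> prefB(n + 1, 0);       // prefB[i] = B_1 + ... + B_i
    for (int i = 1; i <= n; ++i) {
        pref[i] = pref[i - 1] + A[i - 1];
        prefB[i] = prefB[i - 1] + B[i - 1];
    }
    vector<int> dp(n + 1, -1);  // -1: the prefix cannot be partitioned
    dp[0] = 0;
    for (int i = 1; i <= n; ++i) {
        for (int j = 1; j <= i; ++j) {                   // last subarray is [j, i]
            if (dp[j - 1] < 0) continue;                 // [1, j-1] has no valid partition
            if (prefB[i] - prefB[j - 1] != 1) continue;  // [j, i] must contain exactly one 1
            int gain = (pref[i] - pref[j - 1] >= 0) ? 1 : 0;
            dp[i] = max(dp[i], dp[j - 1] + gain);
        }
    }
    return max(dp[n], 0);
}
```

For A = [3, -2, -2, 4, -1, 1] and B = [1, 0, 1, 0, 0, 1] this returns 3.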

SUBTASK 2

Once again we have Q = 1, so let’s assume L = 1 and R = N. N is too large for \mathcal{O}(N^2) to work here, so we need to do better.

Below are two different solutions to this subtask — one which speeds up subtask 1 using data structures but doesn’t generalize to subtask 3, and one which does generalize to subtask 3.

Solution 1 (Speeding up Subtask 1)

Let’s look at the algorithm for subtask 1 again: we fix a right endpoint i, and then iterate across all j \leq i to compute dp_i.

Let’s ignore the constraint on B for now.
Let pref_i = A_1 + A_2 + \ldots + A_i.
Note that we don’t care about the exact value of A_j + A_{j+1} + \ldots + A_i here — we only care whether it is non-negative or not.
So, if we write out the sum in terms of prefix sums as A_j + \ldots + A_i = pref_i - pref_{j-1}, all indices j such that pref_{j-1} \leq pref_i are equivalent, and all indices such that pref_{j-1} > pref_i are equivalent.

Thus, if we were able to compute the following two values:

  • ans_1 = \max dp_{j-1} across all 1 \leq j \leq i such that pref_{j-1} \leq pref_i
  • ans_2 = \max dp_{j-1} across all 1 \leq j \leq i such that pref_{j-1} > pref_i

we would simply have dp_i = \max(ans_1 + 1, ans_2).
Notice that we have reduced the problem of computing dp_i to two range queries!

Note that our ranges are on prefix sums, so we need to build a range query structure on those. This can be accomplished by either using an implicit segment tree or coordinate compression.

Now we take care of the constraint on B.
We want B_j + B_{j+1} + \ldots + B_i = 1, which means that the range should contain exactly one index k with B_k = 1.
Note that this gives us a contiguous range of acceptable indices j (in the original array). All we need to do then is make sure that our segment tree only contains the values dp_{j-1} for indices j in this range, and we are done.

The complexity of this solution is \mathcal{O}(N\log N) — each index requires two segment tree queries, and is added to/removed from the segment tree at most once.
Please see the attached (commented) code for implementation details.

Code for this solution.
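One possible way to implement Solution 1 is sketched below in C++, assuming coordinate compression of pref_0, \ldots, pref_N and an iterative (bottom-up) max segment tree; MaxSegTree and maxGoodnessSegTree are hypothetical names, and this is not the attached code. Rather than removing indices one at a time, the sketch clears the whole previous block of indices whenever a new 1 is met in B, so every index is still inserted and removed at most once.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>
using namespace std;

// Iterative max segment tree over compressed prefix-sum values; -1 means "empty".
struct MaxSegTree {
    int size;
    vector<int> t;
    explicit MaxSegTree(int n) : size(n), t(2 * n, -1) {}
    void pull(int p) {  // recompute all ancestors of tree node p
        for (p >>= 1; p >= 1; p >>= 1) t[p] = max(t[2 * p], t[2 * p + 1]);
    }
    void chmax(int p, int v) { t[p + size] = max(t[p + size], v); pull(p + size); }
    void clear(int p) { t[p + size] = -1; pull(p + size); }
    int query(int l, int r) const {  // max over leaves [l, r); -1 if empty
        int res = -1;
        for (l += size, r += size; l < r; l >>= 1, r >>= 1) {
            if (l & 1) res = max(res, t[l++]);
            if (r & 1) res = max(res, t[--r]);
        }
        return res;
    }
};

// O(N log N) version of the Subtask 1 DP for the single query [1, N].
int maxGoodnessSegTree(const vector<long long>& A, const vector<int>& B) {
    int n = A.size();
    vector<long long> pref(n + 1, 0);
    for (int i = 1; i <= n; ++i) pref[i] = pref[i - 1] + A[i - 1];
    vector<long long> vals(pref);  // coordinate compression of pref_0..pref_n
    sort(vals.begin(), vals.end());
    vals.erase(unique(vals.begin(), vals.end()), vals.end());
    vector<int> id(n + 1);
    for (int i = 0; i <= n; ++i)
        id[i] = lower_bound(vals.begin(), vals.end(), pref[i]) - vals.begin();

    int m = vals.size();
    MaxSegTree tree(m);
    vector<int> dp(n + 1, -1);
    dp[0] = 0;
    vector<int> pending = {0};  // indices k in the current block of prefB values
    vector<int> active;         // indices k stored in the tree (the previous block)
    for (int i = 1; i <= n; ++i) {
        if (B[i - 1] == 1) {
            // A new 1: the valid choices of j-1 are now exactly the pending indices.
            for (int k : active) tree.clear(id[k]);
            active.clear();
            for (int k : pending)
                if (dp[k] >= 0) { tree.chmax(id[k], dp[k]); active.push_back(k); }
            pending.clear();
        }
        int ans1 = tree.query(0, id[i] + 1);  // pref_{j-1} <= pref_i: sum(j, i) >= 0
        int ans2 = tree.query(id[i] + 1, m);  // pref_{j-1} >  pref_i: sum(j, i) <  0
        if (ans1 >= 0) dp[i] = ans1 + 1;
        dp[i] = max(dp[i], ans2);
        pending.push_back(i);
    }
    return max(dp[n], 0);
}
```

On A = [3, -2, -2, 4, -1, 1], B = [1, 0, 1, 0, 0, 1] it returns 3, matching the quadratic DP.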

Solution 2 (A different approach)

Let’s forget about dynamic programming entirely and try something else.

For convenience, let pref_i = A_1 + A_2 + \ldots + A_i.
A partition of [1, N] which satisfies the given condition on B will be called a good partition.

We have the following observation, which is fairly easy to see:

  • The number of intervals in any good partition is a constant; in particular, it equals the number of 1s in B.

In other words, if there are k occurrences of 1 in B, any good partition will have exactly k subintervals, each of which covers exactly one of the 1s.

This leads us to think of the following greedy solution:

  • Let the positions of the 1s in B be x_1 < x_2 < \ldots < x_k.
  • Consider the first interval [l_1, r_1] in a good partition.
  • If k = 0, no partition is possible and we output 0.
  • If k = 1, we are forced to have l_1 = 1 and r_1 = N, so just check if the array has a non-negative sum or not.
  • Otherwise, if k \geq 2, we must have l_1 = 1 and x_1 \leq r_1 < x_2.
  • Choose r_1 such that pref_{r_1} \geq 0 but pref_{r_1} is as low as possible.
  • If no such index r_1 exists (i.e., pref_{r_1} < 0 for every x_1 \leq r_1 < x_2), choose r_1 to be the index in that range with the lowest prefix sum.
  • Repeat this process on the suffix [r_1 + 1, N].

The idea behind this greedy solution is simple: we (greedily) attempt to maximize the answer by making the first subinterval non-negative. If we are able to do so, we try to make its sum as small as possible while still remaining non-negative so that, intuitively, later intervals have a larger sum to work with.
If there is no way to make it non-negative, we still make it as small as possible with the same logic.
It turns out that this greedy is correct!

Proof of correctness

Consider an optimal good partition [1, r_1], [l_2, r_2], \ldots, [l_k, N].
Suppose [1, r_1] is not chosen as per our greedy above. Then,

  • If there is no way to make the first interval non-negative, moving r_1 to the index with smallest prefix sum does not change the contribution of the first interval, while the sum of the second interval either remains the same or increases. Thus, the goodness of the partition doesn’t decrease.
  • Otherwise, there is some index x_1 \leq j < x_2 such that pref_j \geq 0.
    • If pref_{r_1} < 0, changing the first two intervals to [1, j] and [j+1, r_2] does not decrease the answer — the first interval now adds 1 to the answer, while the second adds either 0 or 1 depending on how much its sum changes by; either way the overall goodness doesn’t decrease.
    • If pref_{r_1} \geq 0, then, similarly to the first case, moving r_1 to the index with the smallest non-negative prefix sum keeps the contribution of the first interval the same, while not decreasing the sum of the second interval.

So, [1, r_1] can be replaced by the interval found by our greedy scheme without affecting the optimality of the answer. Repeating this argument with the remaining intervals shows that the answer constructed by the greedy is optimal.

Implementing this solution is fairly easy (especially compared to the one above), and its complexity is \mathcal{O}(N).

Code for this solution
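The greedy translates almost line by line into code. The C++ sketch below (maxGoodnessGreedy is a hypothetical name, inputs are 0-indexed) handles the single query [1, N]: each block [x_t, x_{t+1} - 1] of candidate right endpoints is scanned once, so the total work is \mathcal{O}(N). Ties between equal prefix sums do not matter, since interval sums depend only on the prefix sums at the endpoints.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Greedy for the single query [1, N] (Subtask 2, Solution 2).
int maxGoodnessGreedy(const vector<long long>& A, const vector<int>& B) {
    int n = A.size();
    vector<long long> pref(n + 1, 0);
    vector<int> ones;  // 1-indexed positions x_1 < x_2 < ... < x_k
    for (int i = 1; i <= n; ++i) {
        pref[i] = pref[i - 1] + A[i - 1];
        if (B[i - 1] == 1) ones.push_back(i);
    }
    int k = ones.size();
    if (k == 0) return 0;  // no good partition exists
    int ans = 0;
    int start = 1;  // left endpoint of the current interval
    for (int t = 0; t + 1 < k; ++t) {
        // The right endpoint must lie in [x_t, x_{t+1} - 1].
        int best = -1;          // smallest pref_j with pref_j >= pref_{start-1}
        int lowest = ones[t];   // smallest pref_j overall
        for (int j = ones[t]; j < ones[t + 1]; ++j) {
            if (pref[j] < pref[lowest]) lowest = j;
            if (pref[j] >= pref[start - 1] && (best == -1 || pref[j] < pref[best])) best = j;
        }
        if (best != -1) ++ans;  // this interval has a non-negative sum
        start = (best != -1 ? best : lowest) + 1;
    }
    // The last interval is forced to be [start, N].
    if (pref[n] - pref[start - 1] >= 0) ++ans;
    return ans;
}
```

For A = [3, -2, -2, 4, -1, 1] and B = [1, 0, 1, 0, 0, 1], the greedy picks [1, 2], [3, 5] and [6, 6], all with non-negative sums, for a goodness of 3.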

SUBTASK 3

We now need to deal with both a large number of queries and a large array size. The first solution to subtask 2 doesn’t generalize well to multiple queries, but the second one does, so please have a look at it if you haven’t yet.

Notice that in the greedy solution, we essentially found optimal ‘jumps’ for some indices.
That is, we started at index L, then found the optimal right endpoint r_1. This process was repeated with r_1 + 1 as left endpoint to obtain r_2, and so on.
Denote the optimal right endpoint of an interval starting at index i by jump_i.

Suppose we are able to find jump_i for every 1 \leq i \leq N with some preprocessing. That information can be used to answer queries as follows:
Start at L, go to jump_L, then go to jump_{jump_L + 1}, and so on. If the query range contains k = B_L + \ldots + B_R ones, exactly k - 1 such jumps are taken, since the final interval is forced to end at R. For each of these jumps, we also know whether the subarray corresponding to it has non-negative sum or not, which allows us to compute the answer as we go.
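Before optimizing, the jump-following procedure itself can be sketched in C++ as below. It assumes jump has already been computed, with jump[i] = (j, s) describing the greedy interval [i, j] and s = 1 when its sum is non-negative, and that pref and prefB are prefix sums of A and B with a leading 0. The name answerQueryNaive is hypothetical. Each query costs \mathcal{O}(k), which is exactly what binary lifting removes.

```cpp
#include <cassert>
#include <utility>
#include <vector>
using namespace std;

// Answers one query [L, R] (1-indexed) by walking the jumps one at a time: O(k) per query.
int answerQueryNaive(int L, int R, const vector<pair<int, int>>& jump,
                     const vector<long long>& pref, const vector<int>& prefB) {
    int k = prefB[R] - prefB[L - 1];  // number of intervals in any valid partition
    if (k == 0) return 0;
    int ans = 0, cur = L;
    for (int step = 0; step + 1 < k; ++step) {  // the first k - 1 intervals are greedy
        ans += jump[cur].second;
        cur = jump[cur].first + 1;
    }
    // The last interval is forced to be [cur, R].
    if (pref[R] - pref[cur - 1] >= 0) ++ans;
    return ans;
}
```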

However, this is still too slow — we might require \mathcal{O}(N) jumps for a given range.
This is where binary lifting (or binary jumping) will be used to optimize the jumps in order to answer each query in \mathcal{O}(\log N).

Thus, we have to do two things: compute jump_i for each i, and figure out how to use binary lifting to optimize actually jumping.

Computing jumps

Recall how we computed jumps in subtask 2: for an index i, we would like to find an index j such that j \geq i and B_i + B_{i+1} + \ldots + B_j = 1, while

  • pref_j - pref_{i-1} \geq 0 and pref_j is as small as possible satisfying this condition, or
  • if no such j exists, simply choose j such that pref_j is as small as possible.

Define prefB_i = B_1 + B_2 + \ldots + B_i. Note that the condition B_i + B_{i+1} + \ldots + B_j = 1 tells us that prefB_j = prefB_{i-1} + 1.
So, we can create buckets of indices corresponding to each value of prefB_i. The k-th bucket (say V_k) is a list of pairs (pref_i, i) corresponding to all those indices i such that prefB_i = k.

Now, for a fixed index i, computing jump_i is easy:

  • We know that we only need to look at indices among V_{prefB_{i-1} + 1}.
  • Among all these indices, we would like to find
    • the pair (pref_j, j) such that pref_j \geq pref_{i-1} and pref_j is as small as possible, or
    • if no such pair exists, the pair (pref_j, j) where pref_j is as small as possible.
  • Then set jump_i = (j, s), where s = 1 if pref_j - pref_{i-1} \geq 0 and s = 0 otherwise.

This is easy to do with binary search if each bucket is kept sorted.
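A C++ sketch of this step, assuming 0-indexed input vectors and returning jump as a vector of pairs indexed from 1 (computeJumps is a hypothetical name). Each bucket V_k is sorted by (pref, index), so std::lower_bound on (pref_{i-1}, -1) finds the smallest pref_j \geq pref_{i-1}, and the first element of the bucket is the smallest pref_j overall. Start indices with no 1 at or after them get the sentinel (N + 1, 0), an assumption of this sketch.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>
using namespace std;

// jump[i] = (j, s): greedy right endpoint j for an interval starting at i (1-indexed),
// with s = 1 if A_i + ... + A_j >= 0. If no valid j exists, jump[i] = (n + 1, 0).
vector<pair<int, int>> computeJumps(const vector<long long>& A, const vector<int>& B) {
    int n = A.size();
    vector<long long> pref(n + 1, 0);
    vector<int> prefB(n + 1, 0);
    for (int i = 1; i <= n; ++i) {
        pref[i] = pref[i - 1] + A[i - 1];
        prefB[i] = prefB[i - 1] + B[i - 1];
    }
    // Bucket V[k] holds (pref_i, i) for every i in [0, n] with prefB_i = k, sorted.
    vector<vector<pair<long long, int>>> V(prefB[n] + 2);
    for (int i = 0; i <= n; ++i) V[prefB[i]].emplace_back(pref[i], i);
    for (auto& bucket : V) sort(bucket.begin(), bucket.end());

    vector<pair<int, int>> jump(n + 2, {n + 1, 0});
    for (int i = 1; i <= n; ++i) {
        const auto& bucket = V[prefB[i - 1] + 1];  // all j with prefB_j = prefB_{i-1} + 1
        if (bucket.empty()) continue;              // no 1 in B at or after i
        // First pair with pref_j >= pref_{i-1}, i.e. the smallest non-negative interval sum.
        auto it = lower_bound(bucket.begin(), bucket.end(), make_pair(pref[i - 1], -1));
        if (it != bucket.end()) jump[i] = {it->second, 1};
        else jump[i] = {bucket.front().second, 0};  // every sum is negative: take the smallest
    }
    return jump;
}
```

On A = [2, -3, 1, -1, 4], B = [1, 0, 1, 0, 1], the interval starting at 2 has no non-negative choice, so jump_2 = (4, 0): the sketch picks [2, 4], whose sum -3 is the smallest available.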

Binary lifting

Read about binary lifting here if the concept is new to you.

Suppose we know jump_i for every i. We will use binary lifting to speed up the jumping process.
Create a table up of size N \times (1 + \log N), where up_{i, j} holds the result of jumping 2^j times when starting from i.
This table is populated as follows:

  • up_{i, 0} = jump_i
  • For j \geq 1, let up_{i, j-1} = (x_1, y_1) and up_{x_1 + 1, j-1} = (x_2, y_2). Then, up_{i, j} = (x_2, y_1 + y_2).

Once we have this information, a query [L, R] can be answered as follows:

  • If prefB_R = prefB_{L-1}, no partition is possible so output 0.
  • Otherwise, we know any partition must have exactly k = prefB_R - prefB_{L-1} intervals.
  • Of these, the first k-1 will be ‘optimal’ intervals, i.e., chosen by our greedy process. These intervals can be found using our binary lifting table by iterating over the powers of 2 and jumping 2^j steps for each j that is a set bit in k-1.
  • The final interval is not necessarily optimal because it must end at R, but there is only one such interval so directly compute its sum and check if it is non-negative or not.

Thus, each query has been answered in \mathcal{O}(\log N), and we are done.

IMPLEMENTATION

Note that you need to be somewhat careful while implementing, to make sure that there are no unexpected out-of-bounds errors.
A simple way to accomplish this is to redefine jump_i to mean jumping from index i+1 instead.
Then,

  • up_{i, 0} = jump_i
  • For j \geq 1, let up_{i, j-1} = (x_1, y_1) and up_{x_1, j-1} = (x_2, y_2). Then, up_{i, j} = (x_2, y_1 + y_2).

This avoids needing to increment the index by 1, and by defining jump_{N+1} = (N+1, 0) we get a rather clean implementation.
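The binary lifting part can be sketched in C++ as follows, using the shifted convention just described: jump[i] = (j, s) describes the interval [i + 1, j], and jump[N + 1] = (N + 1, 0) is a sentinel that maps to itself. JumpTable is a hypothetical name, the jump array is assumed to be precomputed with the bucket method, the table is stored level-major as up[t][i], and LOG = 19 assumes fewer than 2^{19} intervals per query (the same bound the editorialist's solution uses).

```cpp
#include <cassert>
#include <utility>
#include <vector>
using namespace std;

struct JumpTable {
    static constexpr int LOG = 19;      // assumes k - 1 < 2^19
    vector<vector<pair<int, int>>> up;  // up[t][i] = (endpoint, #non-negative) after 2^t jumps

    explicit JumpTable(const vector<pair<int, int>>& jump) : up(LOG, jump) {
        int m = jump.size();
        for (int t = 1; t < LOG; ++t)
            for (int i = 0; i < m; ++i) {
                auto [x1, y1] = up[t - 1][i];
                auto [x2, y2] = up[t - 1][x1];  // no +1 needed in the shifted convention
                up[t][i] = {x2, y1 + y2};
            }
    }

    // Query [L, R], 1-indexed; pref and prefB are prefix sums with index 0 equal to 0.
    int query(int L, int R, const vector<long long>& pref, const vector<int>& prefB) const {
        int k = prefB[R] - prefB[L - 1];
        if (k == 0) return 0;
        int steps = k - 1, ans = 0, u = L - 1;
        for (int t = LOG - 1; t >= 0; --t)
            if (steps >> t & 1) {
                ans += up[t][u].second;
                u = up[t][u].first;
            }
        // The last interval [u + 1, R] is forced to end at R.
        if (pref[R] - pref[u] >= 0) ++ans;
        return ans;
    }
};
```

With the jumps of A = [3, -2, -2, 4, -1, 1], B = [1, 0, 1, 0, 0, 1], the query [1, 6] needs k - 1 = 2 jumps: a single lookup in up[1] takes it from 0 to 5 with two non-negative intervals, and the forced last interval [6, 6] adds one more, for an answer of 3.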

TIME COMPLEXITY:

\mathcal{O}((N+Q)\log N) per test case.

SOLUTIONS:

Setter's Solution
#ifdef DEBUG
#define _GLIBCXX_DEBUG
#endif
//#pragma GCC optimize("O3")
#include <bits/stdc++.h>
using namespace std;
typedef long double ld;
typedef long long ll;
const int maxN = 1e6 + 10;
ll prefA[maxN];
int prefB[maxN];
int n;
vector<pair<ll,int>> by[maxN];
const int LOG = 20;
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
//    freopen("input.txt", "r", stdin);
    int t;  cin>>t;
    while(t--)
    {
        static pair<int,int> up[maxN][LOG];  // static: ~160 MB is far too large for the stack
        cin >> n;
        int q;
        cin >> q;
        for (int i = 1; i <= n; i++) {
            ll a;
            cin >> a;
            prefA[i] = prefA[i - 1] + a;
        }
        for (int i = 1; i <= n; i++) {
            int b;
            cin >> b;
            assert(0 <= b && b <= 1);
            prefB[i] = prefB[i - 1] + b;
        }
        for (int i = 0; i <= n; i++) {
            by[prefB[i]].emplace_back(prefA[i], i);
        }
        int lim = prefB[n];
        for (int i = 0; i <= n; i++) {
            sort(by[i].begin(), by[i].end());
        }
        for (int i = 0; i <= n; i++) {
            if (prefB[i] < lim) {
                int pos = lower_bound(by[prefB[i] + 1].begin(), by[prefB[i] + 1].end(), make_pair(prefA[i], -1)) - by[prefB[i] + 1].begin();
                if (pos != by[prefB[i] + 1].size()) {
                    up[i][0] = make_pair(by[prefB[i] + 1][pos].second, 1);
                }
                else {
                    up[i][0] = make_pair(by[prefB[i] + 1][0].second, 0);
                }
            }
        }
        for (int z = 0; z + 1 < LOG; z++) {
            for (int i = 0; i <= n; i++) {
                if (prefB[i] + (1 << (z + 1)) <= lim) {
                    up[i][z + 1] = make_pair(up[up[i][z].first][z].first, up[i][z].second + up[up[i][z].first][z].second);
                }
            }
        }
        while (q--) {
            int l, r;
            cin >> l >> r;
            l--;
            if (prefB[l] == prefB[r]) {
                cout << 0 << '\n';
                continue;
            }
            int ans = 0;
            int steps = prefB[r] - prefB[l] - 1;
            int where = l;
            for (int t = LOG - 1; t >= 0; t--) {
                if (steps & (1 << t)) {
                    ans += up[where][t].second;
                    where = up[where][t].first;
                }
            }
            assert(prefB[where] + 1 == prefB[r]);
            ans += (prefA[where] <= prefA[r]);
            cout << ans << '\n';
        }
        
        for (int i = 0; i <= n; i++) {
            by[i].clear();
        }
    }
    return 0;
}
Tester's Solution (Kotlin)
import java.io.BufferedInputStream
import java.util.*

const val DIM_LIM = 500000
const val BILLION = 1000000000L

fun main(omkar: Array<String>) {
    val jin = FastScanner(false)
    val out = StringBuilder()
    var nSum = 0
    var qSum = 0
    repeat(jin.nextInt(1000)) {
        val n = jin.nextInt(DIM_LIM, false)
        nSum += n
        if (nSum > DIM_LIM) {
            throw InvalidInputException("constraint on sum n violated")
        }
        val q = jin.nextInt(DIM_LIM)
        qSum += q
        if (qSum > DIM_LIM) {
            throw InvalidInputException("constraint on sum q violated")
        }
        val ay = longArrayOf(0L) + LongArray(n) { jin.nextLong(-BILLION, BILLION, it == n - 1) }
        val by = intArrayOf(0) + IntArray(n) { jin.nextInt(0, 1, it == n - 1) }
        val sums = LongArray(n + 1)
        for (j in 1..n) {
            sums[j] = sums[j - 1] + ay[j]
        }
        val treeMap = TreeMap<Long, Int>()
        var j = 0
        while (j <= n && by[j] != 1) {
            treeMap[sums[j]] = j
            j++
        }
        val left = Array(n + 2) { IntArray(19) { -1 } }
        val amt = IntArray(n + 2)
        while (j <= n) {
            var k = j
            do {
                var x = treeMap.floorEntry(sums[k])?.value
                if (x == null) {
                    x = treeMap.lastEntry()!!.value
                }
                left[k][0] = x!!
                for (d in 1..18) {
                    left[k][d] = left[left[k][d - 1]][d - 1]
                    if (left[k][d] == -1) {
                        break
                    }
                }
                amt[k] = amt[x] + (if (sums[k] - sums[x] >= 0L) 1 else 0)
                k++
            } while (k <= n && by[k] == 0)
            treeMap.clear()
            while (j < k) {
                treeMap[sums[j]] = j
                j++
            }
        }
        val ones = TreeSet<Int>()
        ones.addAll((1..n).filter { by[it] == 1 })
        repeat(q) {
            val from = jin.nextInt(n, false)
            val to = jin.nextInt(from, n)
            val leftmostOne = ones.ceiling(from)
            if (leftmostOne == null || leftmostOne > to) {
                out.appendln(0)
            } else {
                var j = to
                for (d in 18 downTo 0) {
                    if (left[j][d] >= leftmostOne) {
                        j = left[j][d]
                    }
                }
                val res = (amt[to] - amt[j]) + (if (sums[j] - sums[from - 1] >= 0L) 1 else 0)
                out.appendln(res)
            }
        }
    }
    print(out)
    jin.endOfInput()
}

class InvalidInputException(message: String): Exception(message)

class FastScanner {
    private val BS = 1 shl 16
    private val NC = 0.toChar()
    private val buf = ByteArray(BS)
    private var bId = 0
    private var size = 0
    private var c = NC
    private var `in`: BufferedInputStream? = null
    private val validation: Boolean

    constructor(validation: Boolean) {
        this.validation = validation
        `in` = BufferedInputStream(System.`in`, BS)
    }

    constructor() : this(true)

    private val char: Char
        private get() {
            while (bId == size) {
                size = try {
                    `in`!!.read(buf)
                } catch (e: Exception) {
                    return NC
                }
                if (size == -1) return NC
                bId = 0
            }
            return buf[bId++].toChar()
        }

    fun validationFail(message: String) {
        if (validation) {
            throw InvalidInputException(message)
        }
    }

    fun endOfInput() {
        if (char != NC) {
            validationFail("excessive input")
        }
        if (validation) {
            System.err.println("input validated")
        }
    }

    fun nextInt(from: Int, to: Int, endsLine: Boolean = true) = nextLong(from.toLong(), to.toLong(), endsLine).toInt()

    fun nextInt(to: Int, endsLine: Boolean = true) = nextInt(1, to, endsLine)

    fun nextLong(endsLine: Boolean): Long {
        var neg = false
        c = char
        if (c !in '0'..'9' && c != '-' && c != ' ' && c != '\n') {
            validationFail("found character other than digit, negative sign, space, and newline, character code = ${c.toInt()}")
        }
        if (c == '-') {
            neg = true
            c = char
        }
        var res = 0L
        while (c in '0'..'9') {
            res = (res shl 3) + (res shl 1) + (c - '0').toLong()
            c = char
        }
        if (endsLine) {
            if (c != '\n') {
                validationFail("found character other than newline, character code = ${c.toInt()}")
            }
        } else {
            if (c != ' ') {
                validationFail("found character other than space, character code = ${c.toInt()}")
            }
        }
        return if (neg) -res else res
    }

    fun nextLong(from: Long, to: Long, endsLine: Boolean = true): Long {
        val res = nextLong(endsLine)
        if (res !in from..to) {
            validationFail("$res not in range $from..$to")
        }
        return res
    }

    fun nextLong(to: Long, endsLine: Boolean = true) = nextLong(1L, to, endsLine)
}
Editorialist's Solution
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(0); cin.tie(0);

    int t; cin >> t;
    while (t--) {
        int n, q; cin >> n >> q;
        vector<int> a(n), b(n), bct(n+1);
        vector<ll> pref(n+1);
        for (int &x : a)
            cin >> x;
        for (int &x : b)
            cin >> x;

        vector<vector<array<ll, 2>>> pos(n+2);
        vector<array<array<int, 2>, 19>> jump(n+2);
        pos[0] = {{0, 0}};
        for (int i = 0; i < n; ++i) {
            pref[i+1] = a[i] + pref[i];
            bct[i+1] = b[i] + bct[i];
            pos[bct[i+1]].push_back({pref[i+1], i+1});
        }
        for (auto &vec : pos)
            sort(begin(vec), end(vec));

        for (int i = 0; i <= n; ++i) {
            auto ptr = lower_bound(begin(pos[bct[i]+1]), end(pos[bct[i]+1]), array<ll, 2>{pref[i], -1});
            if (ptr == end(pos[bct[i]+1])) {
                if (empty(pos[bct[i]+1])) jump[i][0] = array<int, 2>{n+1, 0};
                else jump[i][0] = array<int, 2>{static_cast<int>(pos[bct[i]+1][0][1]), 0};
            }
            else jump[i][0] = array<int, 2>{static_cast<int>((*ptr)[1]), 1};
        }
        jump[n+1][0] = {n+1, 0};

        for (int level = 1; level < 19; ++level) {
            for (int i = 0; i <= n+1; ++i) {
                auto [x1, y1] = jump[i][level-1];
                auto [x2, y2] = jump[x1][level-1];
                jump[i][level] = {x2, y1 + y2};
            }
        }

        while (q--) {
            int L, R; cin >> L >> R;
            int segments = bct[R] - bct[L-1];
            if (segments == 0) {
                cout << 0 << '\n';
                continue;
            }

            --segments;
            int ans = 0, u = L-1;
            for (int i = 18; i >= 0; --i) {
                if (segments&(1<<i)) {
                    ans += jump[u][i][1];
                    u = jump[u][i][0];
                }
            }

            if (pref[R] - pref[u] >= 0) ++ans;
            cout << ans << '\n';
        }
    }
}