### PROBLEM LINK:

**Author:** Lalit Kundu

**Tester:** Hiroto Sekido

**Editorialist:** Kevin Atienza

### DIFFICULTY:

SIMPLE

### PREREQUISITES:

Dynamic programming, preprocessing, binary search, cumulative sums

### PROBLEM:

You are given a string S of length N consisting only of $0$s and $1$s. You are also given an integer K.

You have to answer Q queries. In the $i$th query, two integers L and R are given. Then you should print the number of substrings of S[L, R] which contain at most K $0$s and at most K $1$s, where S[L, R] denotes the substring from the $L$th to the $R$th character of the string S.

In other words, you have to count the number of pairs of integers $(i, j)$ with $L \le i \le j \le R$ such that no character in the substring S[i, j] occurs more than K times.

### QUICK EXPLANATION:

Let $\text{far}_i$ be the last index $j$ such that $S[i, j]$ has at most $K$ $0$s and at most $K$ $1$s. Then for a given query $(L, R)$, the number of *valid* substrings starting at index $i$ is $\min(R, \text{far}_i)-i+1$ (for $L \le i \le R$). Therefore, the answer for the query $(L, R)$ is the following:

$$\sum_{i=L}^{R} \left(\min(R, \text{far}_i) - i + 1\right)$$

Note that $\text{far}_i$ never decreases, so we can use binary search to find the maximum index $k$ such that $\text{far}_k \le R$. The sum then becomes:

$$\sum_{i=L}^{k} (\text{far}_i - i + 1) + \sum_{i=k+1}^{R} (R - i + 1)$$

Each of these is solvable in closed form, except possibly for the range sums of $\text{far}_i$. But for that, we can simply compute *cumulative sums* of $\text{far}$.

The algorithm runs in $O(N + Q \log N)$, but can be sped up to $O(N + Q)$ by getting rid of the binary search. We’ll describe this below.

### EXPLANATION:

We will begin by describing a slow solution, and then work on improving it incrementally. We will call a string **valid** if it contains at most $K$ $0$s and at most $K$ $1$s.
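As a reference for the solutions below, this validity check can be written directly. (A small Python helper, not part of the editorial's pseudocode; the name `is_valid` is illustrative.)

```python
def is_valid(t, K):
    # a binary string is valid iff it has at most K zeros and at most K ones
    return t.count('0') <= K and t.count('1') <= K
```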

# Slow solution

First, we can simply test each substring and count all those that are valid. This can be accomplished, for example, with the following:

```
def answer_query(L, R):
    answer = 0
    for i in L...R
        for j in i...R
            # try the substring S[i, j]
            count0 = count1 = 0
            for k in i...j
                if S[k] == '1'
                    count1 += 1
                else
                    count0 += 1
            if count0 <= K and count1 <= K
                answer += 1
    return answer
```

How fast does this go? The running time is proportional to the sum of the lengths of all substrings of $S[L, R]$, which one can easily compute to be approximately $\frac{(R-L+1)^3}{6}$. Since $R-L+1$ is at most $N$, in the worst case the running time is approximately $\frac{N^3}{6}$ steps. This won’t pass any of the subtasks!

# Breaking early

The above brute-force solution can be optimized to pass subtask 2, using the observation that valid strings are at most $2K$ in length. Therefore, we can simply check all substrings that are at most $2K$ in length, and there are $O(NK)$ of them:

```
def answer_query(L, R):
    answer = 0
    for i in L...R
        # only consider substrings of length at most 2K
        for j in i...min(i+2*K-1, R)
            # try the substring S[i, j]
            count0 = count1 = 0
            for k in i...j
                if S[k] == '1'
                    count1 += 1
                else
                    count0 += 1
                if count0 > K or count1 > K:
                    break
            if count0 <= K and count1 <= K
                answer += 1
    return answer
```

Notice that aside from reducing the upper limit of $j$ to $\min(i+2K-1, R)$, we have also added a `break` statement: once either `count0` or `count1` exceeds $K$, there is no point in inspecting the rest of the characters, since we already know that the string cannot be valid.

Thus the algorithm now runs in $O(NK\min(N,K))$ time per query and $O(QNK\min(N,K))$ time overall, and passes subtask 2!

# Dynamic programming

In fact, we can extend the argument we used for the `break` statement! Notice that any substring of a valid string is also valid, and any superstring of an invalid string is also invalid! Therefore, we can break out of the $j$ loop as soon as we encounter an invalid string! However, this alone won’t improve the worst case, because we might not encounter an invalid string at all, or only encounter one late.

However, we can exploit even more properties of the substrings we are inspecting! Notice that to check whether a substring $S[i, j]$ is valid, we only need to count the $0$s and $1$s in it, and these counts are easy to obtain once we know the counts for $S[i, j-1]$! Specifically, as we process the substrings $S[i, j]$ for a fixed $i$ and increasing $j$, we only need to increment `count0` or `count1` depending on $S[j]$:

```
def answer_query(L, R):
    answer = 0
    for i in L...R
        count0 = count1 = 0
        for j in i...min(i+2*K-1, R)
            # try the substring S[i, j]
            if S[j] == '1'
                count1 += 1
            else
                count0 += 1
            if count0 <= K and count1 <= K
                answer += 1
            else
                break
    return answer
```

Now, this runs in O(N\min(N,K)) time per query, and O(QN\min(N,K)) time overall!

However, using this might still not pass subtask 1, because there are up to $10^5$ queries. Thankfully, in subtask 1, $N$ is at most $100$, so there are only at most $N(N+1)/2$ substrings $S[L, R]$. Thus, we can simply precompute the answer for all those substrings, and then answer the $Q$ queries with simple lookups. This approach runs in $O(N^3\min(N,K)+Q)$ and should be able to pass subtask 1!
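This precompute-then-lookup idea can be sketched in Python as follows (a hypothetical 0-indexed sketch: `ans[L][R]` holds the answer for the query $(L, R)$, and the inner loops reuse the incremental-count check with early break from above):

```python
def precompute_all_answers(S, K):
    """ans[L][R] = answer for query (L, R), 0-indexed inclusive."""
    N = len(S)
    ans = [[0] * N for _ in range(N)]
    for L in range(N):
        for R in range(L, N):
            total = 0
            for i in range(L, R + 1):
                count0 = count1 = 0
                for j in range(i, R + 1):
                    if S[j] == '1':
                        count1 += 1
                    else:
                        count0 += 1
                    if count0 <= K and count1 <= K:
                        total += 1
                    else:
                        break  # extending j further can only stay invalid
            ans[L][R] = total
    return ans
```

Each query $(L, R)$ is then answered by reading `ans[L][R]` in $O(1)$.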

# Walking algorithm

There is still a way to optimize the above! Remember what we said above: a substring of a valid string is also valid, and a superstring of an invalid string is also invalid. Therefore, if $S[i, j]$ is valid, then $S[i+1, j]$ is also valid! So when we process the next $i$, we don’t have to start $j$ from $i$ any more, because we already know many strings are valid from the previous $i$. Specifically, if $j'$ is the last $j$ such that $S[i, j]$ is valid, then $S[i+1, j']$ is also valid, so we can simply start iterating $j$ from $j'+1$ onwards. What’s more, we can compute the `count0` and `count1` of $S[i+1, j'+1]$ from those of $S[i, j'+1]$ by simply decrementing one of them depending on $S[i]$! This is illustrated in the following code:

```
def answer_query(L, R):
    answer = 0
    j = L
    count0 = count1 = 0
    if S[L] == '1'
        count1 += 1
    else
        count0 += 1
    for i in L...R
        while j <= R and count0 <= K and count1 <= K:
            j += 1
            if j > R
                break
            if S[j] == '1'
                count1 += 1
            else
                count0 += 1
        # at this point, we know S[i, j-1] is valid, and S[i, j] is invalid or j > R
        answer += j - i
        # remove S[i] from the counts before moving to the next i
        if S[i] == '1'
            count1 -= 1
        else
            count0 -= 1
    return answer
```

Now, how fast does this go? Notice that there are still nested loops. However, every time the inner loop iterates, $j$ increases by $1$, and $j$ never decreases. Therefore, across the whole query the inner loop runs at most $R-L+1$ times in total, or $O(N)$. Therefore, the whole algorithm runs in $O(N)$ time per query, and $O(QN)$ time overall! This should be able to pass subtasks 1, 2 and 3.
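The walking (two-pointer) algorithm above can be transcribed into runnable Python like this (a 0-indexed sketch, assuming $K \ge 1$ so that every single character is valid; variable names mirror the pseudocode):

```python
def answer_query_walk(S, K, L, R):
    """Count valid substrings of S[L..R] (0-indexed, inclusive) in O(R-L+1)."""
    answer = 0
    j = L
    count0 = count1 = 0
    if S[L] == '1':
        count1 += 1
    else:
        count0 += 1
    for i in range(L, R + 1):
        # extend the right end while S[i..j] stays valid
        while j <= R and count0 <= K and count1 <= K:
            j += 1
            if j > R:
                break
            if S[j] == '1':
                count1 += 1
            else:
                count0 += 1
        # here S[i..j-1] is valid, and either S[i..j] is invalid or j > R
        answer += j - i
        # remove S[i] from the counts before advancing the left end
        if S[i] == '1':
            count1 -= 1
        else:
            count0 -= 1
    return answer
```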

# Preprocessing and binary search

The previous algorithm is too slow for subtask 4, because it still takes $O(N)$ time per query. In fact, this is probably the fastest we can do without some sort of *preprocessing*, because at the very least we have to *read* the string $S[L, R]$ to compute the answer, and this already takes $O(N)$ time. Thus, we will try to speed up the algorithm by preprocessing.

First, note that the crucial part of the previous solution is finding, for each $i$, the first $j$ such that $S[i, j]$ is invalid or $j > R$. By ignoring the "or $j > R$" part for now, we see that the first such $j$ for each $i$ depends only on the string $S$! For a given $i$, let’s denote that $j$ by $\text{far}_i$. (If you read the quick explanation above, note that this $\text{far}$ is different from the $\text{far}$ there; specifically, this one is larger by exactly $1$.) Thus, we can try to precompute $\text{far}$ at the beginning:

```
far[1...N]
def precompute():
    count0 = count1 = 0
    j = 1
    if S[1] == '1'
        count1 += 1
    else
        count0 += 1
    for i in 1...N
        while j <= N and count0 <= K and count1 <= K:
            j += 1
            if j > N
                break
            if S[j] == '1'
                count1 += 1
            else
                count0 += 1
        # at this point, we know S[i, j-1] is valid, and S[i, j] is invalid or j > N
        far[i] = j
        # remove S[i] from the counts before moving to the next i
        if S[i] == '1'
            count1 -= 1
        else
            count0 -= 1
```

and use that for our queries:

```
def answer_query(L, R):
    answer = 0
    for i in L...R
        j = min(far[i], R+1) # R+1 is the first j such that j > R
        answer += j - i
    return answer
```

The precomputation works similarly to the previous code and runs in $O(N)$, but each query still takes $O(N)$ time. However, the `answer_query` code is now much simpler, and in fact can be expressed by the following mathematical expression:

$$\sum_{i=L}^{R} \left(\min(\text{far}_i, R+1) - i\right)$$

We can now compute such a sum using some simple manipulations:

$$\sum_{i=L}^{R} \min(\text{far}_i, R+1) - \sum_{i=L}^{R} i$$

However, we still have this nasty $\min$ term to take care of. Thankfully, we can use the fact that $\text{far}_i$ **never decreases**: $\text{far}_i$ will be $\le R$ at the beginning, and as $i$ increases it will eventually exceed $R$; once it does, it stays greater than $R$. Thus, it makes sense to find the last $k$ such that $\text{far}_k \le R$, so the above expression becomes

$$\sum_{i=L}^{k} \text{far}_i + (R-k)(R+1) - \left(\frac{R(R+1)}{2} - \frac{L(L-1)}{2}\right)$$

Now, we are almost down to an $O(1)$ computation, aside from two things: finding $k$, and computing a range sum of $\text{far}$. But these are simple. First, $k$ can be computed with binary search, as the index such that $\text{far}_k \le R < \text{far}_{k+1}$, because $\text{far}$ is monotonically nondecreasing. Second, to compute range sums of $\text{far}$, one can simply use *cumulative sums* or *prefix sums*: let $\text{sumfar}_i$ be the sum of the $\text{far}$ values up to the $i$th index. Then the sum $\text{far}_i + \text{far}_{i+1} + \cdots + \text{far}_j$ is simply $\text{sumfar}_j - \text{sumfar}_{i-1}$!

These are illustrated in the following code:

```
far[1...N]
sumfar[0...N]
def precompute():
    # precompute far
    count0 = count1 = 0
    j = 1
    if S[1] == '1'
        count1 += 1
    else
        count0 += 1
    for i in 1...N
        while j <= N and count0 <= K and count1 <= K:
            j += 1
            if j > N
                break
            if S[j] == '1'
                count1 += 1
            else
                count0 += 1
        far[i] = j
        # remove S[i] from the counts
        if S[i] == '1'
            count1 -= 1
        else
            count0 -= 1
    # precompute sumfar
    sumfar[0] = 0
    for i in 1...N
        sumfar[i] = sumfar[i-1] + far[i]

def answer_query(L, R):
    # binary search to find k such that far[k] <= R < far[k+1]
    # we maintain the invariant far[k1] <= R < far[k2]
    k1 = L-1
    k2 = R+1
    while k2 - k1 > 1:
        km = (k1 + k2) / 2 # here, "/" denotes floor division
        if far[km] <= R
            k1 = km
        else
            k2 = km
    k = k1 # k is now equal to k1, because k2 - k1 = 1 and far[k1] <= R < far[k2]
    answer = sumfar[k] - sumfar[L-1] + (R-k)*(R+1) - (R*(R+1)/2 - L*(L-1)/2)
    return answer
```

Using this, one can see that precomputation still runs in O(N) time, and queries now run in O(\log N) time each, due to the binary search. Thus, the overall algorithm runs in O(N + Q \log N) time, which comfortably passes all the subtasks!

Be careful with overflow! The answer can be quadratic in $N$, which exceeds the range of a 32-bit integer, so use a 64-bit type (e.g. `long long` in C++).
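A runnable Python sketch of this solution follows (0-indexed, so `far[i]` is the first $j$ with $S[i..j]$ invalid, or $N$ if none; it assumes $K \ge 1$, and uses the standard `bisect` module in place of the hand-rolled binary search, which is valid because `far` is nondecreasing):

```python
import bisect

def precompute(S, K):
    """Return far and its prefix sums; far[i] = first j with S[i..j] invalid (or N)."""
    N = len(S)
    far = [0] * N
    count0 = count1 = 0
    j = 0
    if S[0] == '1':
        count1 += 1
    else:
        count0 += 1
    for i in range(N):
        # extend j while S[i..j] is still valid
        while j < N and count0 <= K and count1 <= K:
            j += 1
            if j == N:
                break
            if S[j] == '1':
                count1 += 1
            else:
                count0 += 1
        far[i] = j
        # remove S[i] before moving the left endpoint
        if S[i] == '1':
            count1 -= 1
        else:
            count0 -= 1
    # prefix sums: sumfar[i] = far[0] + ... + far[i-1]
    sumfar = [0] * (N + 1)
    for i in range(N):
        sumfar[i + 1] = sumfar[i] + far[i]
    return far, sumfar

def answer_query(far, sumfar, L, R):
    """Count valid substrings of S[L..R] (inclusive) in O(log N)."""
    # m = first index in [L, R] with far[m] > R
    m = bisect.bisect_right(far, R, L, R + 1)
    # i in [L, m): contributes far[i] - i;  i in [m, R]: contributes (R+1) - i
    return (sumfar[m] - sumfar[L]) + (R - m + 1) * (R + 1) \
        - (L + R) * (R - L + 1) // 2
```

Python integers do not overflow, but the same formula in C++ needs 64-bit arithmetic.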

# Bonus: $\text{far}$ and $\text{raf}$

The above solution already works, but we will introduce a final optimization here. Specifically, we will improve the way we compute the $k$ in each query.

Note that $k$ is the largest index such that $\text{far}_k \le R$, or $k = L-1$ if there is no such index in $[L, R]$. The key idea is that, by ignoring the lower bound $L$ for now, we see that $k$ depends only on $R$! Thus it would be nice to precompute the $k$s for all possible $R$s, and in fact it is easy to do so.

Let’s define a similar array $\text{raf}$, where $\text{raf}[R]$ is the smallest index $i$ such that $\text{far}_i > R$. Using this array, one can compute $k$ simply as $\max(L, \text{raf}[R])-1$ (we’ll leave it to the reader to see why). Now, how do we compute all the $\text{raf}$ values? The idea is that $\text{raf}$ is essentially the *reverse* of $\text{far}$, except that it points to the left rather than to the right. Thus, a similar $O(N)$-time walking algorithm can be used to compute it. We will leave this as an exercise for the reader, however, because we will show here a different way to compute it in $O(N)$ time, by exploiting the relationship between $\text{far}$ and $\text{raf}$. See the following pseudocode for details:

```
far[1...N]
raf[1...N]
sumfar[0...N]
def precompute():
    # precompute far
    .....
    # precompute sumfar
    ....
    # precompute raf
    # initialize
    for i in 1...N
        raf[i] = -1
    # we know that far[i] > far[i]-1, so if j = far[i]-1,
    # then raf[j] must be <= i (because raf[j] is the least such i)
    # we process each i in decreasing order to guarantee that we assign the least such i
    for i in N...1 by -1
        raf[far[i]-1] = i
    # for all the raf[i]'s that we didn't encounter, set the value to raf[i+1],
    # because raf[i] <= raf[i+1]
    # (raf[N] is always set above, since far[N] = N+1)
    for i in N-1...1 by -1
        if raf[i] == -1
            raf[i] = raf[i+1]

def answer_query(L, R):
    k = max(L, raf[R])-1
    answer = sumfar[k] - sumfar[L-1] + (R-k)*(R+1) - (R*(R+1)/2 - L*(L-1)/2)
    return answer
```

Finally, one can now see that it runs in $O(1)$ time per query, and $O(N + Q)$ time overall!
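Putting everything together, the whole $O(N+Q)$ pipeline can be sketched as runnable Python (0-indexed and assuming $K \ge 1$; `solve` is an illustrative name, and queries are given as 0-indexed inclusive pairs):

```python
def solve(S, K, queries):
    """Full O(N + Q) solution: far + prefix sums + raf, all 0-indexed."""
    N = len(S)
    # far[i] = first j with S[i..j] invalid, or N if none (walking algorithm)
    far = [0] * N
    cnt = [0, 0]                      # cnt[c] = occurrences of digit c in S[i..j]
    j = 0
    cnt[int(S[0])] += 1
    for i in range(N):
        while j < N and cnt[0] <= K and cnt[1] <= K:
            j += 1
            if j == N:
                break
            cnt[int(S[j])] += 1
        far[i] = j
        cnt[int(S[i])] -= 1
    # prefix sums of far
    sumfar = [0] * (N + 1)
    for i in range(N):
        sumfar[i + 1] = sumfar[i] + far[i]
    # raf[R] = least i with far[i] > R, derived from far in O(N)
    raf = [-1] * N
    for i in range(N - 1, -1, -1):    # smaller i overwrites: keeps the least i
        raf[far[i] - 1] = i
    for R in range(N - 2, -1, -1):    # raf[N-1] is always set, since far[N-1] = N
        if raf[R] == -1:
            raf[R] = raf[R + 1]
    out = []
    for L, R in queries:
        m = max(L, raf[R])            # first index in [L, R] with far > R
        out.append(sumfar[m] - sumfar[L] + (R - m + 1) * (R + 1)
                   - (L + R) * (R - L + 1) // 2)
    return out
```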

### Time Complexity:

$O(N + Q)$