ABCC2 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: iceknight1093
Tester: tabr
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Prefix sums, stacks (or) divide & conquer

PROBLEM:

For a string S, let f(S) denote the minimum number of operations needed to reach a state where no further operation is possible, where an operation consists of choosing a subsequence of S equal to "abc" and deleting either its a or its b.

Given S, compute the sum of f(S[L\ldots R]) across all substrings of S.

EXPLANATION:

First, we need to figure out how exactly f(S) is computed for a single string S.

Note that since each operation chooses a subsequence that’s "abc" and deletes either the a or the b, the number of operations we perform will equal the number of characters deleted from S.
So, in order to minimize the number of operations, we can instead try to maximize the length of the remaining string.
Further, the final string shouldn’t have "abc" as a subsequence, so let’s try to classify such strings.

Strings without abc

Clearly, only the rightmost occurrence of c in the string matters - if there’s an abc subsequence, there will certainly be one involving the last c.
Then, before this final c, "ab" shouldn’t appear as a subsequence, which means that the a’s and b’s there should form a sequence that looks like bbb...bbbaaa...aaa, i.e., several occurrences of b followed by several occurrences of a.

To recap, a string S doesn’t contain "abc" as a subsequence if and only if it’s of the following form:

  • Let k be the index of the last occurrence of c in S.
    If c doesn’t exist in S, we choose k to be -1.
  • Before index k, every a should appear after every b, and there can be several c’s.
  • After index k, there will be no c’s, but the a’s and b’s can appear in any order.
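This characterization can be sanity-checked by brute force. The Python sketch below (the function names are our own, not taken from the original solutions) compares a direct subsequence test with the rule above on every string over {a, b, c} of length at most 7.

```python
from itertools import product


def contains_abc(s):
    """Direct check: does s contain "abc" as a subsequence?"""
    it = iter(s)
    # `ch in it` consumes the iterator up to and including the first match
    return all(ch in it for ch in "abc")


def abc_free_by_form(s):
    """Characterization: before the last 'c', no 'a' may appear before a 'b'."""
    k = s.rfind('c')  # -1 if there is no 'c'
    seen_a = False
    for ch in s[:max(k, 0)]:
        if ch == 'a':
            seen_a = True
        elif ch == 'b' and seen_a:
            return False  # an "ab" before the last 'c' gives "abc"
    return True


if __name__ == "__main__":
    for length in range(8):
        for chars in product("abc", repeat=length):
            s = "".join(chars)
            assert contains_abc(s) != abc_free_by_form(s)
    print("characterization verified up to length 7")
```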

Now that we know what S should look like finally, let’s see how many characters we can keep while still bringing it to this form.
Let k denote the index of the last c in S (if S doesn’t contain a c, we say k = -1, as before).
Then,

  • Everything after index k can be kept.
  • Every c in S can be kept.
  • Before index k, we can keep some subsequence of a’s and b’s that’s of the form bbb...baaa...a.
    Since we want to maximize the number of characters we keep, it’s best to choose the longest such subsequence.

This allows us to compute f(S) in \mathcal{O}(|S|) time: finding the last occurrence of c in S can be done in a single pass, and finding the longest subsequence of b’s followed by a’s before it can also be done in linear time.

How?

Let’s define some variables:

  • k is the index of the last c seen so far.
  • \text{ans} denotes the maximum length of a subsequence of the form bbb...baaa...a before k.
  • \text{mx} denotes the maximum length of a subsequence of the form bbb...baaa...a overall (i.e., it can include indices after k as well).
  • \text{ct}_x denotes the number of occurrences of character x seen so far.

Initially, k = -1 and all the other variables are 0.
Then, for each i from 0 to |S|-1:

  • If S_i is a, increase \text{mx} by 1, since we can always append a to the end of a subsequence of the form bbb...baaa...a.
    Also increase \text{ct}_a by 1.
  • If S_i is b, increase \text{ct}_b by 1, and then set \text{mx} = \max(\text{mx}, \text{ct}_b).
    This is because we can either keep the current maximal subsequence, or choose the subsequence consisting of all b’s, which might be longer (for example, for S = \texttt{babb} the best choice is \texttt{bbb}).
  • If S_i is c, set \text{ans} = \text{mx}, k = i, and increase \text{ct}_c by 1.
    This is because i is now the last c we’ve seen, and \text{mx} denotes the longest subsequence of the desired form before it, as we’ve maintained it so.

In the end, the number of characters we keep is \text{ans} + \text{ct}_c + (|S|-1-k), and f(S) is |S| minus this quantity.
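A minimal Python sketch of this scan, using 0-indexed positions with k = -1 when there is no c; the function name `kept_and_ops` is our own. It returns both the number of kept characters and f(S).

```python
def kept_and_ops(s):
    """Return (kept, f) for a single string s, where f = len(s) - kept."""
    k = -1           # index of the last 'c' seen so far (-1: none yet)
    ans = 0          # longest b..ba..a subsequence before index k
    mx = 0           # longest b..ba..a subsequence of the whole prefix
    ct_b = ct_c = 0  # counts of 'b' and 'c' seen so far (ct_a is not needed)
    for i, ch in enumerate(s):
        if ch == 'a':
            mx += 1                 # an 'a' can always be appended
        elif ch == 'b':
            ct_b += 1
            mx = max(mx, ct_b)      # keep mx, or switch to "all b's so far"
        else:                       # ch == 'c': it becomes the last 'c'
            ans, k = mx, i
            ct_c += 1
    kept = ans + ct_c + (len(s) - 1 - k)
    return kept, len(s) - kept


print(kept_and_ops("abc"))     # (2, 1)
print(kept_and_ops("aabbcc"))  # (4, 2)
print(kept_and_ops("babb"))    # (4, 0)
```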


This gives us a solution in \mathcal{O}(N^3) by directly applying this to every substring of S.
However, this is still too slow, and needs further optimization.

Subtask 1

For the first subtask, the constraints allow for a solution in \mathcal{O}(N^2).

Let’s look back at the algorithm to compute f(S) for a single string.
Notice that we computed it from left to right and maintained the current answer at all times while doing so.
More specifically, whenever a single character is appended to S, our algorithm updates the answer in constant time.

So, if you already know the answer for the substring S[L\ldots R], the answer for the substring S[L\ldots R+1] can be computed from it in constant time, instead of having to redo the whole process from the start!
This immediately brings the complexity down to \mathcal{O}(N^2), solving the first subtask.


Subtask 2

There are a few different ways to optimize the above solution.
One of them is divide and conquer, which I’ll detail below. When dealing with tasks that ask for some quantity summed across all subarrays, like this one, it’s often possible to use divide and conquer to get a solution without too much thought (though perhaps with an extra \log factor, and a potentially not-so-nice implementation).
This task can also be solved in \mathcal{O}(N) using a stack, which can be seen in the tester’s code (and is fairly simple to implement as well).
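The editorial doesn’t describe the stack approach, so what follows is our own reading of the tester’s code, written as a Python sketch (the name `total_deletions_stack` is ours). Let A_p be the number of a’s in S[0..p-1] and B_p the number of b’s in S[p..N-1]. Choosing a split point p, for l \leq k the longest bb...baa...a subsequence of S[l..k-1] equals B_l + A_k - \min_{l \leq p \leq k}(A_p + B_p), so a substring whose last c is at k loses \min_{l \leq p \leq k}(A_p + B_p) - A_l - B_k characters. Summing over l then needs sums of range minima, which a monotonic stack maintains in amortized constant time per step.

```python
def total_deletions_stack(s):
    """Sum of f over all substrings of s, in O(len(s)) time."""
    n = len(s)
    A = [0] * (n + 1)  # A[p] = number of 'a' in s[:p]
    for i, ch in enumerate(s):
        A[i + 1] = A[i] + (ch == 'a')
    B = [0] * (n + 1)  # B[p] = number of 'b' in s[p:]
    for i in range(n - 1, -1, -1):
        B[i] = B[i + 1] + (s[i] == 'b')
    key = [A[p] + B[p] for p in range(n + 1)]

    stack = [(-1, -1)]  # (index, key) pairs, keys strictly increasing; sentinel at the bottom
    range_min_sum = 0   # sum over l in [0, i] of min(key[l..i])
    a_pref = 0          # sum over l in [0, i] of A[l]
    f = [0] * n         # f[k] = deletions summed over s[l..k], l <= k (used when s[k] == 'c')
    last = -1           # index of the last 'c' seen so far
    total = 0
    for i in range(n):
        while stack[-1][1] >= key[i]:
            idx, val = stack.pop()
            range_min_sum -= val * (idx - stack[-1][0])
        range_min_sum += key[i] * (i - stack[-1][0])
        stack.append((i, key[i]))
        a_pref += A[i]
        f[i] = range_min_sum - a_pref - B[i] * (i + 1)
        if s[i] == 'c':
            last = i
        # substrings s[l..i] with l <= last have their last 'c' at `last` and cost
        # exactly as much as s[l..last]; substrings with l > last cost nothing
        if last != -1:
            total += f[last]
    return total


print(total_deletions_stack("abc"))  # 1
```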

Let’s define \text{ans}(L, R) to be the sum of answers of all substrings whose indices lie in [L, R].
(Here, the “answer” of a substring is the maximum number of characters in it that we keep.)

Let M = \left\lfloor\frac{L+R}{2}\right\rfloor.
\text{ans}(L, M) and \text{ans}(M+1, R) can be computed recursively.
This only leaves substrings S[i\ldots j] such that L \leq i \leq M and M \lt j \leq R.
Note that such a substring is the concatenation of S[i\ldots M] and S[M+1\ldots j], so let’s analyze what the answer of such a concatenation will be.

  • If the left part contains a c and the right part doesn’t, the answer for the concatenation is the answer of the left part, plus the length of the right part.
  • If neither part contains a c, the answer is just the sum of their lengths.
    This can be merged with the first case, since if the left part doesn’t contain a c, its answer is just its length.
  • If the right part does contain a c, we need some more care.

Let’s leave the last case aside for now; we’ll return to it.
Suppose we fix an index j such that there’s no c in S[M+1\ldots j].
Let a_i denote the answer for the substring S[i\ldots M].
Then, across all i \in [L, M], the sum of answers of all S[i\ldots j] for this fixed j is exactly the sum of all a_i, plus (j-M) multiplied by (M-L+1).

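All the a_i values can be computed in \mathcal{O}(M-L) time by moving i from M down to L: each step prepends a single character, and the answer can be updated in constant time by mirroring the left-to-right algorithm (a b prepended before the last c extends the best bb...baa...a subsequence by 1, while a prepended a makes it the maximum of its current length and the number of a’s before the last c).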
So, once all the a_i (and hence, their sum) are known, a single index j on the right side such that S[M+1\ldots j] doesn’t contain a c can be processed in constant time.
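A Python sketch of this right-to-left pass follows; `left_kept_values` is a hypothetical helper name, indices are 0-based, and the returned list holds the a_i values in order of i.

```python
def left_kept_values(s, L, M):
    """out[i - L] = maximum number of characters kept in s[i..M], for L <= i <= M."""
    out = [0] * (M - L + 1)
    seen_c = False
    after_c = 0   # characters after the last 'c' (every character while no 'c' is seen)
    cnt_c = 0     # number of 'c' characters
    cnt_a = 0     # number of 'a' characters before the last 'c'
    best = 0      # longest b..ba..a subsequence before the last 'c'
    for i in range(M, L - 1, -1):  # prepend s[i]
        ch = s[i]
        if ch == 'c':
            seen_c = True
            cnt_c += 1
        elif not seen_c:
            after_c += 1
        elif ch == 'a':
            cnt_a += 1
            best = max(best, cnt_a)  # a prepended 'a' can only start an all-'a' run
        else:
            best += 1                # a prepended 'b' extends any b..ba..a subsequence
        out[i - L] = best + cnt_c + after_c
    return out


vals = left_kept_values("abca", 0, 3)
print(vals)  # [3, 3, 2, 1]
```

With these values, a right endpoint j such that S[M+1..j] has no c contributes `sum(vals) + (j - M) * len(vals)`, matching the formula above.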

That leaves us with the only remaining case: when the right part does contain a c.
Here, we have two possibilities:

  • We can take the longest possible subsequence of the form bb...baa...a from the left part, and then append to it every a in the right part that appears before the last c; or
  • We can take every b from the left part, and then append to it the longest possible subsequence of the form bb...baa...a from the right part (before the last c, of course).

Either one could be optimal, so let’s see when each one is.
Suppose the longest subsequence of the form bb...baa...a has length x_1 in the left part, and length x_2 in the right part before its last c.
Let the number of b’s in the left part be y_1, and the number of a’s in the right part before its last c be y_2.
Then, we’re looking at \max(x_1+y_2, x_2+y_1), plus, of course, the number of c’s and the number of characters after the last c.

Notice that x_1 + y_2 \geq x_2 + y_1 \iff x_1 - y_1 \geq x_2 - y_2.
Since we’ve fixed j, x_2 - y_2 is a constant, and x_1 - y_1 depends only on the index i on the left!

So, compute the values of x_1 and y_1 for every index L \leq i \leq M; this can be done in linear time.
Sort all the indices by x_1 - y_1.
Now, when processing index j, you can binary search on the sorted list to find the indices where x_1-y_1 \geq x_2-y_2, which form some suffix.
Each index in this suffix contributes x_1 + y_2, and each index in the complementary prefix contributes x_2 + y_1. So, if you know the sum of x_1 over the suffix, the sum of y_1 over the prefix, and the sizes of both, you can compute the sum of answers across all i.
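The sort-plus-binary-search step can be sketched in Python as follows (the name `make_max_summer` is ours). It preprocesses the left-side pairs (x_1, y_1) once, then answers each query (x_2, y_2) with one `bisect_right` call and prefix sums. Using `bisect_right` puts ties into the prefix rather than the suffix, which is harmless since x_1 + y_2 = x_2 + y_1 exactly when x_1 - y_1 = x_2 - y_2.

```python
from bisect import bisect_right
from itertools import accumulate


def make_max_summer(pairs):
    """Preprocess the left-side pairs (x1, y1) and return a query function.

    query(x2, y2) returns the sum over all pairs of max(x1 + y2, x2 + y1).
    """
    items = sorted(pairs, key=lambda p: p[0] - p[1])
    keys = [x - y for x, y in items]
    pre_x = [0, *accumulate(x for x, _ in items)]  # pre_x[k] = sum of x1 over the first k items
    pre_y = [0, *accumulate(y for _, y in items)]  # pre_y[k] = sum of y1 over the first k items
    n = len(items)

    def query(x2, y2):
        # items[:k] have x1 - y1 <= x2 - y2, so x2 + y1 is the larger term there
        k = bisect_right(keys, x2 - y2)
        prefix_part = pre_y[k] + k * x2
        suffix_part = (pre_x[n] - pre_x[k]) + (n - k) * y2
        return prefix_part + suffix_part

    return query


q = make_max_summer([(3, 1), (2, 2), (5, 0)])
print(q(1, 1))  # 13
```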

So, recursive steps aside, the entire range can be solved for in \mathcal{O}((R-L)\log(R-L)) time, since we do a single sort and a single binary search at each step.
The \log can be optimized out using the fact that the values are “small”, but it isn’t necessary to do so.

Our time complexity is T(N) = 2T(\frac{N}{2}) + \mathcal{O}(N\log N), which is bounded above by \mathcal{O}(N\log ^2 N) and so is fast enough.
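The bound follows by summing the work level by level in the recursion tree, writing the non-recursive cost of a range of size n as at most C\,n\log n for some constant C:

```latex
T(N) \le \sum_{d=0}^{\lceil \log_2 N \rceil} 2^d \cdot C\,\frac{N}{2^d}\log\frac{N}{2^d}
     \le \sum_{d=0}^{\lceil \log_2 N \rceil} C\,N\log N
     = \mathcal{O}(N\log^2 N)
```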

TIME COMPLEXITY:

\mathcal{O}(N) or \mathcal{O}(N\log^2 N) per testcase, depending on the choice of algorithm.

CODE:

Subtask 1 (Python)
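for _ in range(int(input())):
    n = int(input())
    s = input()
    ans = 0
    for i in range(n):  # left end of the substring
        # cs = number of c's, mx = longest b..ba..a subsequence so far, bct = number of b's
        cs, mx, bct = 0, 0, 0
        curans = 0   # longest b..ba..a subsequence before the last c
        lastc = i-1  # index of the last c (i-1 means there is none yet)
        for j in range(i, n):  # extend the substring s[i..j] one character at a time
            if s[j] == 'c':
                curans = mx
                lastc = j
                cs += 1
            elif s[j] == 'b':
                bct += 1
                mx = max(mx, bct)
            else:
                mx += 1
            keep = curans + cs + j - lastc  # characters kept in s[i..j]
            ans += j-i+1 - keep             # f(s[i..j]) = number of deleted characters
    print(ans)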
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;

// #define IGNORE_CR

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
#ifdef IGNORE_CR
            if (c == '\r') {
                continue;
            }
#endif
            buffer.push_back((char) c);
        }
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            assert(!isspace(buffer[pos]));
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

int main() {
    input_checker in;
    int tt = in.readInt(1, 1e5);
    in.readEoln();
    int sn = 0;
    while (tt--) {
        int n = in.readInt(1, 2e5);
        in.readEoln();
        sn += n;
        string s = in.readString(n, n, "abc");
        in.readEoln();
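        // a[i] = number of 'a' in s[0..i-1], b[i] = number of 'b' in s[i..n-1]
        vector<int> a(n + 1), b(n + 1);
        for (int i = 0; i < n; i++) {
            a[i + 1] = a[i] + (s[i] == 'a');
        }
        for (int i = n - 1; i >= 0; i--) {
            b[i] = b[i + 1] + (s[i] == 'b');
        }
        // longest bb..baa..a subsequence of s[l..k-1] = b[l] + a[k] - min(ab[p]) over p in [l, k]
        vector<int> ab(n + 1);
        for (int i = 0; i < n + 1; i++) {
            ab[i] = a[i] + b[i];
        }
        long long ans = 0;
        int last = -1;  // index of the last 'c' seen so far
        long long pref = 0;  // sum of a[l] for l in [0, i]
        // monotonic stack of (index, ab value); sum = sum over l in [0, i] of min(ab[l..i])
        stack<pair<int, int>> st;
        st.emplace(-1, -1);
        long long sum = 0;
        // f[k] = total deletions over the substrings s[l..k], l <= k (meaningful when s[k] == 'c')
        vector<long long> f(n + 1);
        for (int i = 0; i < n; i++) {
            if (s[i] == 'c') {
                last = i;
            }
            while (st.top().second >= ab[i]) {
                int k = st.top().second;
                sum -= k * 1LL * st.top().first;
                st.pop();
                sum += k * 1LL * st.top().first;
            }
            sum += (i - st.top().first) * 1LL * ab[i];
            st.emplace(i, ab[i]);
            f[i] += sum;
            pref += a[i];
            f[i] -= pref;
            f[i] -= b[i + 1] * 1LL * (i + 1);
            // substrings ending at i whose last 'c' is at `last` cost as much as s[l..last]
            if (last != -1) {
                ans += f[last];
            }
        }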
        cout << ans << '\n';
    }
    in.readEof();
    assert(sn <= 2e5);
    return 0;
}
Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    
    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        string s; cin >> s;

        auto solve = [&] (const auto &self, int L, int R) -> ll {
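            // solve(L, R) = sum, over all subarrays contained in [L, R), of the longest bb..baa..a subsequence
            // before the subarray's last 'c' (0 if it has no 'c'); the c's and the characters after
            // the last 'c' are accounted for separately in main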
            if (L+1 == R) return 0ll;
            int M = (L+R)/2;
            ll ret = self(self, L, M) + self(self, M, R);
            
            bool c_seen = false;
            int act = 0, bct = 0, mxlen = 0, curmx = 0, cura = 0;
            ll leftsum = 0;
            vector<array<int, 2>> vals;
            for (int i = M-1; i >= L; --i) {
                c_seen |= s[i] == 'c';
                bct += s[i] == 'b';
                if (c_seen) {
                    act += s[i] == 'a';
                    mxlen = max(mxlen, act);
                    mxlen += s[i] == 'b';
                }
                leftsum += mxlen;
                cura += s[i] == 'a';
                curmx = max(curmx, cura);
                curmx += s[i] == 'b';
                vals.push_back({curmx, bct});
            }
            sort(begin(vals), end(vals), [] (auto a, auto b) {
                return a[0] - a[1] < b[0] - b[1];
            });

            int sz = vals.size();
            vector<ll> pref(sz), suf(sz);
            for (int i = 0; i < sz; ++i) {
                pref[i] = vals[i][1];
                suf[i] = vals[i][0];
                if (i) pref[i] += pref[i-1];
            }
            for (int i = sz-2; i >= 0; --i) {
                suf[i] += suf[i+1];
            }

            c_seen = false;
            act = bct = mxlen = cura = curmx = 0;
            for (int i = M; i < R; ++i) {
                c_seen |= s[i] == 'c';
                bct += s[i] == 'b';
                act += s[i] == 'a';
                mxlen = max(mxlen, bct);
                mxlen += s[i] == 'a';
                if (!c_seen) {
                    ret += leftsum;
                    continue;
                }
                if (s[i] == 'c') curmx = mxlen, cura = act;

                // Solve for (curmx, cura) = (x2, y2)
                auto pos = upper_bound(begin(vals), end(vals), array{curmx, cura}, [] (auto a, auto b) {
                    return a[0] - a[1] < b[0] - b[1];
                }) - begin(vals);

                if (pos != sz) ret += suf[pos] + 1ll*(sz-pos)*cura;
                if (pos) ret += 1ll*pos*curmx + pref[pos-1];
            }
            return ret;
        };
        ll ans = -solve(solve, 0, n);
        for (int i = 0; i < n; ++i) ans += 1ll*(i+1)*(n-i);
        for (int i = 0; i < n; ++i) if (s[i] == 'c')
            ans -= 1ll*(i+1)*(n-i);
        int lastc = -1;
        for (int i = 0; i < n; ++i) {
            if (s[i] == 'c') lastc = i;
            ans -= 1ll*(i-lastc)*(lastc+1);
            ans -= 1ll*(i-lastc)*(i-lastc+1)/2;
        }
        cout << ans << '\n';
    }
}

/**
 * (x1, y1) on left
 * (x2, y2) on right
 * max(x1+y2, x2+y1)
 * 
 * fix (x2, y2)
 * across all (x1, y1), 
 * 
 * x1+y2 > x2+y1 -> x1-y1 > x2-y2
 * so, sum(x1) and count(x1) for some suffix of (x1-y1)
 * 
 * x1+y2 <= x2+y1 -> x1-y1 <= x2-y2
 * sum(y1) and count(y1) for some prefix of (x1-y1)
 * 
 */