P5BARH - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: pols_agyi_pols
Tester: kingmessi
Editorialist: iceknight1093

DIFFICULTY:

Medium

PREREQUISITES:

Dynamic Programming, (optionally) divide and conquer

PROBLEM:

You have a binary array A of length N.
In one move, you can increase any element of A by 1.
Define \text{cost}(A) to be the minimum number of moves needed to ensure that A_i \ne A_{i-1} for all 2 \le i \le N.
Compute the sum of costs across all subarrays of A.

EXPLANATION:

As seen in the solution to the easy version (where we only had to compute \text{cost}(A) for a single array), the main observation is that some optimal solution keeps every resulting element \le 2, after which a fairly simple DP over the final value of each element suffices.
We now need to extend this to summing across all subarrays.
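For reference, here is a minimal sketch of such an easy-version DP. It relies on the \le 2 observation above and is not necessarily the code from the easy version's editorial; the name easyCost is hypothetical. best[v] holds the minimum number of moves for the prefix processed so far, given that its last element ends up as v.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <vector>
using namespace std;

// cost(A) for one binary array, assuming some optimal solution keeps all final values in {0, 1, 2}.
// best[v] = min moves for the processed prefix, given that its last element ends as v.
long long easyCost(const vector<int>& a) {
    const long long INF = 1LL << 60;
    for (int x : a) assert(x == 0 || x == 1);   // binary input
    if (a.empty()) return 0;
    array<long long, 3> best;
    for (int v = 0; v < 3; ++v) best[v] = (v >= a[0]) ? v - a[0] : INF;
    for (size_t i = 1; i < a.size(); ++i) {
        array<long long, 3> nxt;
        nxt.fill(INF);
        for (int v = a[i]; v < 3; ++v)          // values can only increase
            for (int u = 0; u < 3; ++u)
                if (u != v && best[u] < INF)     // adjacent final values must differ
                    nxt[v] = min(nxt[v], best[u] + (v - a[i]));
        best = nxt;
    }
    return *min_element(best.begin(), best.end());
}
```

For example, easyCost({1, 0, 0, 1}) returns 2: incrementing either middle element once just creates a new equal pair with its other neighbor.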

There are a couple of different ways to do this; here’s one of them.

When summing up some quantity across all subarrays, an often-helpful technique is divide and conquer.
That is, let’s define a function f(L, R) that returns the sum of costs of all subarrays A[l, r] such that L \le l \le r \le R, i.e. all subarrays with both endpoints contained in [L, R].
To compute this, we do the following:

  • Let M = \frac{L+R}{2} denote the midpoint of the range.
  • Recursively call f(L, M) and f(M+1, R) to solve for subarrays that lie entirely in the left/right halves.
  • That leaves only the subarrays that ‘cross’ the middle, which have to be handled separately.

If we’re able to do this, the answer is of course just f(1, N).
The hard part here is clearly the last step, where we find the answer for all subarrays crossing the middle. If we can do this in \mathcal{O}(T(R-L+1)) time, then the overall time complexity is \mathcal{O}(T(N)\log N), since the recursion has \mathcal{O}(\log N) levels and the ranges on each level are disjoint. This is easily fast enough if we can manage T(N) = \mathcal{O}(N) or T(N) = \mathcal{O}(N\log N) or similar.
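A minimal sketch of the recursion f(L, R), assuming a hypothetical callback cost(l, r) that returns the cost of A[l..r]. The crossing step here is plain brute force, so this only shows the structure (and can serve as a checker on small inputs); the rest of the editorial is about replacing that double loop with a linear pass.

```cpp
#include <cassert>
#include <functional>

// Sum of cost(l, r) over all L <= l <= r <= R, following the recursion f(L, R).
// cost is a hypothetical callback; the crossing step is brute force in this sketch.
long long sumOverSubarrays(int L, int R, const std::function<long long(int, int)>& cost) {
    assert(L <= R);
    if (L == R) return cost(L, L);                 // the single subarray [L, L]
    int M = (L + R) / 2;
    long long total = sumOverSubarrays(L, M, cost) + sumOverSubarrays(M + 1, R, cost);
    for (int l = L; l <= M; ++l)                   // subarrays crossing the middle
        for (int r = M + 1; r <= R; ++r)
            total += cost(l, r);
    return total;
}
```

With cost(l, r) = 1 this simply counts subarrays, e.g. 10 for the range [0, 3].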

So, figuring out how to solve for subarrays crossing the middle quickly is the only thing we need to do now.


Observe that every subarray crossing the middle can be obtained by pasting together a subarray ending at M and a subarray starting at M+1.

Ideally, we’d obtain a solution for [l, r] by simply combining solutions for [l, M] and [M+1, r]. This almost works, but fails in exactly one spot: we have no idea whether the final values of A_M and A_{M+1} are equal or not, and if they are, extra operations might be needed.

Luckily, it’s possible for us to get around this limitation by just storing that as extra information!
That is, let’s define dp_{l, x, y} to be the minimum number of operations needed on the elements A_l, A_{l+1}, \ldots, A_M such that:

  1. No two adjacent elements in the range [l, M] are equal after the operations.
  2. The final value of A_l is x.
  3. The final value of A_M is y.

Note that this is essentially the same DP we used to solve the easy version, just with an additional parameter.
Transitions are also very similar; in fact, different values of y don’t interact with each other at all, so it’s more like three separate DP tables.
Either way, it’s easy to compute this DP for all L \le l \le M in linear time.
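Concretely, the left-half tables might be built as follows. This is a sketch with hypothetical names (leftTables, Table, INF), restricted to final values in {0, 1, 2}; it uses the same transition as the editorialist's code, storing the table for l at index l - L. The right half is symmetric, scanning from M+1 towards R.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <vector>
using namespace std;

const int INF = 1'000'000'000;           // marks impossible states
using Table = array<array<int, 3>, 3>;   // Table[x][y]

// For each l in [L, M], tables[l - L][x][y] = min increments on A[l..M] so that
// adjacent elements differ, with final A_l = x and final A_M = y (values in {0, 1, 2}).
vector<Table> leftTables(const vector<int>& a, int L, int M) {
    assert(0 <= L && L <= M && M < (int)a.size());
    vector<Table> tables(M - L + 1);
    for (auto& t : tables)
        for (auto& row : t) row.fill(INF);
    for (int y = a[M]; y < 3; ++y) {           // each y is an independent DP
        tables[M - L][y][y] = y - a[M];
        for (int l = M - 1; l >= L; --l)
            for (int x = a[l]; x < 3; ++x)     // values can only increase
                for (int x2 = 0; x2 < 3; ++x2) {
                    const int prev = tables[l + 1 - L][x2][y];
                    if (x != x2 && prev < INF)
                        tables[l - L][x][y] = min(tables[l - L][x][y], prev + (x - a[l]));
                }
    }
    return tables;
}
```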

A similar DP can be computed for the right half: let dp_{r, x, y} be the answer for the segment [M+1, r] such that A_r = x and A_{M+1} = y.

Let’s also define, for each L \le i \le R, the value \text{opt}_i to be the minimum of dp_{i, x, y} across all choices of x and y.
Essentially, this is the answer for just the segment from i to the middle if we don’t care about anything else.
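For the case analysis that follows, each endpoint only needs a small summary of its table: \text{opt}, which boundary values y attain it, and which attain \text{opt} + 1. A sketch with hypothetical names (Summary, summarize, hasTwoMinimizers, canDodge), using bitmasks over y:

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <climits>
using namespace std;

using Table = array<array<int, 3>, 3>;  // Table[x][y]; large values mean "impossible"

struct Summary {
    int opt;          // minimum of Table[x][y] over all x, y
    int minMask;      // bit y set iff some x has Table[x][y] == opt
    int plusOneMask;  // bit y set iff some x has Table[x][y] == opt + 1
};

Summary summarize(const Table& t) {
    Summary s{INT_MAX, 0, 0};
    for (const auto& row : t)
        for (int v : row) s.opt = min(s.opt, v);
    for (int x = 0; x < 3; ++x)
        for (int y = 0; y < 3; ++y) {
            long long v = t[x][y];
            if (v == s.opt) s.minMask |= 1 << y;
            else if (v == 1LL * s.opt + 1) s.plusOneMask |= 1 << y;
        }
    return s;
}

// Case 1: at least two different y attain opt.
bool hasTwoMinimizers(const Summary& s) { return (s.minMask & (s.minMask - 1)) != 0; }

// Case 2 check: can one extra move make the boundary value differ from y0?
bool canDodge(const Summary& s, int y0) { return (s.plusOneMask & ~(1 << y0)) != 0; }
```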


With these DP tables in hand, let’s try to compute the answer.
Fix a left endpoint L \le l \le M, and let’s find the sum of costs across all M+1 \le r \le R.
There are then two possibilities.

Case 1: dp_{l, x, y} is minimized for (at least) two different values of y.
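Then for any choice of r, we can always obtain a cost of \text{opt}_l + \text{opt}_r (which is also a lower bound, since an optimal solution for [l, r] restricts to valid solutions on [l, M] and [M+1, r]).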
This is because we can choose whichever value of A_{M+1} achieves \text{opt}_r; and then take whichever value of A_M achieves \text{opt}_l that’s not equal to A_{M+1} (which will exist since we have at least two such choices).
So, in this case we just need to sum up \text{opt}_l + \text{opt}_r across all r, which is easy to do in constant time if we know the sum of \text{opt}_r and the count of r (which can both be precomputed and stored).
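Written out for a fixed l, the Case 1 sum is as follows; the last term does not depend on l, so it is computed once per call (it is sm_right in the editorialist’s code, where R - M appears as R-mid).

```latex
\sum_{r=M+1}^{R} \left(\text{opt}_l + \text{opt}_r\right) = (R - M)\,\text{opt}_l + \sum_{r=M+1}^{R} \text{opt}_r
```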

Case 2: dp_{l, x, y} is minimized for only one value of y.
Let y_0 be this unique choice that achieves \text{opt}_l.
Observe that the only ‘bad’ choices of r are those for which dp_{r, x, y} is also minimized only by y = y_0: for any other r, some optimal choice of A_{M+1} differs from y_0, so \text{opt}_l + \text{opt}_r is achievable.

For each ‘bad’ r, the answer cannot equal \text{opt}_l + \text{opt}_r (that would force A_M = A_{M+1} = y_0), so it must be larger.
However, the answer also won’t be too far off from this value: in particular, it will never exceed \text{opt}_l + \text{opt}_r + 2.
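As a sanity check on these claims, the exact cost of a crossing subarray can be read directly off the two tables (relying on the \le 2 observation): glue any left state to any right state whose boundary values differ. The sketch below does this by brute force over the 81 state pairs; crossCost and Table are hypothetical names. It is \mathcal{O}(1) per pair (l, r), but applying it to every pair would be quadratic per level, which is exactly what the case analysis avoids.

```cpp
#include <array>
#include <cassert>
using namespace std;

using Table = array<array<int, 3>, 3>;  // Table[x][y]; large values mean "impossible"

// Exact cost of a crossing subarray A[l..r], given
//   leftTab[x][y]  = min cost on [l, M]   with final A_l = x and final A_M = y,
//   rightTab[x][y] = min cost on [M+1, r] with final A_r = x and final A_{M+1} = y.
// Gluing the halves only adds the constraint final A_M != final A_{M+1}.
long long crossCost(const Table& leftTab, const Table& rightTab) {
    long long best = -1;
    for (int xl = 0; xl < 3; ++xl)
        for (int yl = 0; yl < 3; ++yl)
            for (int xr = 0; xr < 3; ++xr)
                for (int yr = 0; yr < 3; ++yr) {
                    if (yl == yr) continue;
                    long long c = 1LL * leftTab[xl][yl] + rightTab[xr][yr];
                    if (best < 0 || c < best) best = c;
                }
    return best;
}
```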

Proof

Let’s make \text{opt}_l + \text{opt}_r moves first, to obtain a situation where A_M = A_{M+1} but all other adjacent pairs differ.

Further, recall that all elements are currently \le 2.
Now,

  • If A_M = 1 or A_M = 2, simply increase A_M twice.
    This will make it \gt 2, at which point it’s guaranteed to not equal either A_{M-1} or A_{M+1}.
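  • If A_M = 0 (so A_{M+1} = 0 as well), then:
    • If l = M, there is no A_{M-1} to worry about; make A_M equal to 1, utilizing one increase.
    • If A_{M-1} = 1, make A_M equal to 2, utilizing two increases.
    • If A_{M-1} = 2, make A_M equal to 1, utilizing one increase.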

In any case, we’re able to obtain an array with adjacent pairs differing using no more than \text{opt}_l + \text{opt}_r + 2 operations, as claimed.
(Note that the above constructions may not be optimal: they only prove that the answer is no more than \text{opt}_l + \text{opt}_r + 2.)
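The repair in the proof is mechanical enough to write down. This sketch (repairMiddle is a hypothetical name) assumes b already holds the final values of A[l..r] after the \text{opt}_l + \text{opt}_r moves, with index m playing the role of M; it also covers m = 0 (that is, l = M), where A_M has no left neighbor inside the subarray.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// b: final values of A[l..r] (all <= 2) where b[m] == b[m+1] is the only equal adjacent pair.
// Modifies b[m] so all adjacent pairs differ; returns the extra increments used (1 or 2).
int repairMiddle(vector<int>& b, int m) {
    assert(m + 1 < (int)b.size() && b[m] == b[m + 1]);
    if (b[m] >= 1) {                  // b[m] is 1 or 2: push it above 2
        b[m] += 2;
        return 2;
    }
    // b[m] == 0, hence b[m+1] == 0 as well.
    if (m == 0 || b[m - 1] == 2) {    // no left neighbor, or the left neighbor is 2
        b[m] = 1;
        return 1;
    }
    b[m] = 2;                         // the left neighbor is 1
    return 2;
}
```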

So, summing up across all choices of r can be done as follows: start with the sum of \text{opt}_l + \text{opt}_r across all r (just as in the first case); then, if c_1 is the number of ‘bad’ r for which the answer is \text{opt}_l + \text{opt}_r + 1 and c_2 is the number of remaining ‘bad’ r (whose answer is \text{opt}_l + \text{opt}_r + 2), add c_1 + 2c_2 to the answer.

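The question now is: how do we compute c_1?
To do that, observe that an answer of \text{opt}_l + \text{opt}_r + 1 means exactly one extra move compared to the optimal, so the cost splits as either (\text{opt}_l + 1) + \text{opt}_r or \text{opt}_l + (\text{opt}_r + 1); that is, the extra move lies entirely on the left side or entirely on the right side.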
So,

  • If there exist y \ne y_0 and some x such that dp_{l, x, y} = \text{opt}_l + 1, then we can always spend one extra operation on the left and end up with A_M \ne y_0.
    This means that for every ‘bad’ r, the answer is \text{opt}_l + \text{opt}_r + 1.
    So, in this case, c_1 equals the count of ‘bad’ r, and c_2 = 0.
  • If no such y exists, then we’re at the mercy of the right side.
    In this case, c_1 equals the number of ‘bad’ r for which dp_{r, x, y} = \text{opt}_r + 1 for some x and some y \ne y_0 (that is, one extra operation allows A_{M+1} \ne y_0), while c_2 counts all the other ‘bad’ r.
    These counts (for each possible y_0) can be precomputed and stored as well.
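The bookkeeping for Case 2 fits in a few lines. In this sketch (all names hypothetical), badCount[y] counts the right endpoints r whose only optimal A_{M+1} is y, and badPlusOne[y] counts those among them that can reach some A_{M+1} \ne y at cost \text{opt}_r + 1; both are filled while scanning the right half, and they correspond to allct and goodct in the editorialist’s code. The function returns c_1 + 2c_2 for one left endpoint.

```cpp
#include <array>
#include <cassert>
using namespace std;

// Extra cost c_1 + 2*c_2 for one left endpoint l in Case 2 (unique optimal A_M = y0).
//   leftPlusOne   : some x and y != y0 have dp_{l,x,y} == opt_l + 1
//   badCount[y]   : right endpoints whose only optimal A_{M+1} is y
//   badPlusOne[y] : those among them that reach A_{M+1} != y at cost opt_r + 1
long long case2Extra(int y0, bool leftPlusOne,
                     const array<long long, 3>& badCount,
                     const array<long long, 3>& badPlusOne) {
    assert(0 <= y0 && y0 < 3);
    long long bad = badCount[y0];                       // the 'bad' r for this l
    long long c1 = leftPlusOne ? bad : badPlusOne[y0];  // one extra move suffices
    long long c2 = bad - c1;                            // two extra moves needed
    return c1 + 2 * c2;
}
```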

Thus, with appropriate precomputation (that takes linear time), it’s possible to take a single l and process all M+1 \le r \le R in constant time.

Putting this into the divide-and-conquer, we obtain a solution in \mathcal{O}(N\log N). The constant factor is somewhat high (fairly close to an extra log factor), but it’s fast enough nonetheless.


There is also an alternate solution; the rough idea is as follows.
Use a greedy algorithm to compute the cost of a fixed array: break the array into blocks of equal elements, and then solve for each block from left to right.

  • Consider the leftmost block. At least half of its elements (rounded down) must be incremented.
  • If the block has even length, it’s always possible to increment exactly half the elements, while ensuring no conflicts with the next block.
    The remaining blocks can then be solved recursively.
  • If the block has odd length and its element is 0, again only half of its elements (rounded down) need to be incremented (the 2nd, 4th, 6th, \ldots elements of the block), and there will be no conflict.
    Again, the remaining blocks can be solved recursively.
  • If the block has odd length but its element is 1, incrementing the 2nd, 4th, 6th, \ldots elements of the block might lead to a conflict with the next block, in case the next block requires its first element to be incremented.
  • Further analyzing when this is forced, it can be observed that whether one extra move is needed or not depends entirely on the nearest block (after the first) that has odd length - specifically, whether it contains a 0 or a 1.
    (If it’s a 0, no extra increment is needed; if it’s a 1, one increment is needed.)
    After this odd block, the remaining ones can be solved recursively again.
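The first step of this greedy, splitting into blocks, is a plain run-length encoding. The sketch below (toBlocks and blockBaseline are hypothetical names) also computes the \lfloor \text{length}/2 \rfloor per-block bound from the first bullet. It deliberately leaves out the possible extra move for odd blocks of 1s, so it gives a lower bound on the cost, not the cost itself.

```cpp
#include <cassert>
#include <utility>
#include <vector>
using namespace std;

// Run-length encoding: maximal blocks of equal elements as (value, length) pairs.
vector<pair<int, int>> toBlocks(const vector<int>& a) {
    vector<pair<int, int>> blocks;
    for (int v : a) {
        if (!blocks.empty() && blocks.back().first == v) ++blocks.back().second;
        else blocks.push_back({v, 1});
    }
    return blocks;
}

// Lower bound: a block of length len needs at least floor(len / 2) increments.
// The possible extra move for odd blocks of 1s is NOT included here.
long long blockBaseline(const vector<pair<int, int>>& blocks) {
    long long total = 0;
    for (const auto& block : blocks) total += block.second / 2;
    return total;
}
```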

The above observations allow for a dynamic programming solution.
Define dp_i to be the sum of answers over all subarrays starting at i; the answer is then dp_1 + \ldots + dp_N.
To compute dp_i quickly, consider the block of equal elements starting at i.
Depending on the parity of this block’s length, and on whether A_i equals 0 or 1, dp_i can be computed from some dp_j with j \gt i, plus a bit of algebra.
For details, see the tester’s code below, which keeps the \lfloor \text{length}/2 \rfloor baseline of every block in a separate array dp1 and the extra moves in dp.
This solution runs in \mathcal{O}(N) time.
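The baseline part of this DP is easy to isolate. The sketch below (a hypothetical standalone function, 0-indexed) mirrors the dp1 array of the tester’s code: g[i] is the sum, over subarrays starting at i, of \lfloor \text{length}/2 \rfloor across their blocks. If A_i = A_{i+1}, the pair (i, i+1) adds 1 to each of the N-1-i subarrays containing it, and the rest of [i, j] behaves like [i+2, j]; otherwise A_i forms a block of length 1 and contributes nothing. The extra moves for odd blocks of 1s (the tester’s dp array) are not included.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Sum over all subarrays of sum_{blocks} floor(len / 2), in O(N).
// g[i] = the same quantity restricted to subarrays starting at i (0-indexed).
long long baselineOverAllSubarrays(const vector<int>& a) {
    int n = a.size();
    vector<long long> g(n + 1, 0);     // g[n-1] = g[n] = 0
    for (int i = n - 2; i >= 0; --i) {
        if (a[i] == a[i + 1]) g[i] = (n - 1 - i) + g[i + 2];  // pair (i, i+1), then [i+2..j]
        else g[i] = g[i + 1];                                 // block of length 1 at i
    }
    long long total = 0;
    for (int i = 0; i < n; ++i) total += g[i];
    return total;
}
```

For example, on {1, 1, 0, 0} the per-start sums are 4, 1, 1, 0, for a total of 6.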

TIME COMPLEXITY:

\mathcal{O}(N \log N) or \mathcal{O}(N) per testcase.

CODE:

Editorialist's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        vector a(n, 0);
        for (int &x : a) cin >> x;

        ll ans = 0;
        vector dp(n, array<array<int, 3>, 3>());
        vector opt(n, 0);
        // solve(L, R) adds the costs of all subarrays lying inside [L, R] to ans.
        auto solve = [&] (const auto &self, int L, int R) -> void {
            if (L == R) return;
            int mid = (L+R)/2;
            self(self, L, mid);
            self(self, mid+1, R);

            // Reset the tables for this range.
            for (int i = L; i <= R; ++i) {
                opt[i] = 5*n + 10;
                for (int x = 0; x < 3; ++x) for (int y = 0; y < 3; ++y) dp[i][x][y] = 5*n + 10;
            }

            // Left half: dp[i][x][y] = min cost on [i, mid] with final A_i = x, A_mid = y.
            for (int y = a[mid]; y < 3; ++y) {
                dp[mid][y][y] = y - a[mid];
                for (int i = mid-1; i >= L; --i) {
                    for (int x = a[i]; x < 3; ++x) {
                        for (int x2 = 0; x2 < 3; ++x2) if (x != x2) {
                            dp[i][x][y] = min(dp[i][x][y], dp[i+1][x2][y] + x - a[i]);
                        }
                    }
                }
            }
            // Right half: dp[i][x][y] = min cost on [mid+1, i] with final A_i = x, A_{mid+1} = y.
            for (int y = a[mid+1]; y < 3; ++y) {
                dp[mid+1][y][y] = y - a[mid+1];
                for (int i = mid+2; i <= R; ++i) {
                    for (int x = a[i]; x < 3; ++x) {
                        for (int x2 = 0; x2 < 3; ++x2) if (x != x2) {
                            dp[i][x][y] = min(dp[i][x][y], dp[i-1][x2][y] + x - a[i]);
                        }
                    }
                }
            }

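            // Sweep from R down to L, so every right endpoint is summarized before any left one.
            ll sm_right = 0;  // sum of opt[r] over right endpoints
            // allct[y]: right endpoints whose only optimal A_{mid+1} is y;
            // goodct[y]: those among them that can avoid y with one extra move.
            array<int, 3> goodct{}, allct{};
            for (int i = R; i >= L; --i) {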
                array<int, 3> who{};
                for (int x = 0; x < 3; ++x) for (int y = 0; y < 3; ++y) {
                    if (dp[i][x][y] < opt[i]) {
                        opt[i] = dp[i][x][y];
                        who = {0, 0, 0};
                        who[y] = 1;
                    }
                    if (dp[i][x][y] == opt[i]) who[y] = 1;
                }

                if (i > mid) sm_right += opt[i];
                else ans += sm_right + 1ll*opt[i]*(R-mid);  // sum of opt_l + opt_r over all r
                
                // Case 2: a unique y0 attains opt[i].
                if (who[0] + who[1] + who[2] == 1) {
                    int y0 = 0;
                    while (who[y0] == 0) ++y0;
                    
                    bool plusone = false;
                    for (int x = 0; x < 3; ++x) for (int y = 0; y < 3; ++y) {
                        if (y == y0) continue;
                        plusone |= dp[i][x][y] == opt[i] + 1;
                    }
                    
                    if (i > mid) {
                        allct[y0] += 1;
                        if (plusone) goodct[y0] += 1;
                    }
                    else {
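                        // c_1 = goodct, c_2 = allct - goodct, so c_1 + 2*c_2 = 2*allct - goodct.
                        if (!plusone) ans += 2*allct[y0] - goodct[y0];
                        else ans += allct[y0];  // every 'bad' r costs exactly one extra move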
                    }
                }
            }
        };
        solve(solve, 0, n-1);
        cout << ans << '\n';
    }
}
Tester's code (C++)
// Har Har Mahadev
#include<bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp> // Common file
#include <ext/pb_ds/tree_policy.hpp>
#define ll long long
#define int long long
#define rep(i,a,b) for(int i=a;i<b;i++)
#define rrep(i,a,b) for(int i=a;i>=b;i--)
#define repin rep(i,0,n)
using namespace std;

void solve()
{
    int n; cin >> n;
    vector<int> a(n);
    for(int i = 0;i < n;i++){
        cin >> a[i];
    }
    
    vector<int> fr(n,1); // fr[i]: number of equal elements from i to the end of its block

    auto pf = a;
    rep(i,1,n){
        pf[i] += pf[i-1];
    }

    
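    vector<int> nx(n,n); // nx[i]: start of the first odd-length block after i's block (n if none)
    vector<int> nd(n,n); // nd[i]: start of the next block, i.e. next position with a different element (n if none)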
    for(int i = n-2;i >= 0;i--){
        if(a[i] == a[i+1])fr[i] = fr[i+1]+1;
        if(a[i] != a[i+1]){
            if(fr[i+1]&1)nx[i] = i+1;
            else nx[i] = nx[i+1];
            nd[i] = i+1;
        }
        else nx[i] = nx[i+1],nd[i] = nd[i+1];
    }
    
    // dp1[i]: over all subarrays starting at i, the sum of floor(len/2) across their blocks
    vector<int> dp1(n+1);
    for(int i = n-2;i >= 0;i--){
        if(a[i] == a[i+1]){
            dp1[i] = 1 + dp1[i+2] + (n-(i+2));
        }
        else dp1[i] = dp1[i+1];
    }
    
    // dp[i]: over all subarrays starting at i, the extra moves needed on top of the dp1 baseline
    vector<int> dp(n+1);
    
    for(int i = n-2;i >= 0;i--){
        if(nd[i] == n || nd[nd[i]] == n)continue;
        if(a[i] == 0){
            dp[i] = dp[i+1];
            continue;
        }
        if(nd[i] == nx[i]){
            dp[i] = dp[nd[i]];
            continue;
        }
        if(!((nd[i]-i)&1)){
            dp[i] = dp[nd[i]];
            continue;
        }



        if(nx[i] == n){
            int res = nd[nd[i]];
            int sm = pf.back()-pf[res-1];
            dp[i] = (sm+1)/2;
        }
        else if(a[nx[i]] == 1){
            int res = nd[nd[i]];
            int res1 = nd[nx[i]];
            int sm = pf[res1-1]-pf[res-1];
            dp[i] = (sm+1)/2 + dp[res1] + (n-res1);
        }
        else{
            int res = nd[nd[i]];
            int sm = pf[nx[i]] - pf[res-1];
            dp[i] = dp[nx[i]] + (sm+1)/2;
        }   
    }

    
    long long ans = 0;
    for(auto &x : dp){
        ans += x;
    }
    for(auto &x : dp1){
        ans += x;
    }
    
    cout << ans << "\n";
    
}

signed main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    #ifdef NCR
        init();
    #endif
    #ifdef SIEVE
        sieve();
    #endif
    int t; cin >> t;
    while(t--)
        solve();
    return 0;
}