ONETOTHREE - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: iceknight1093
Tester: tabr
Editorialist: iceknight1093

DIFFICULTY:

Simple

PREREQUISITES:

None

PROBLEM:

You’re given an array A of length N whose elements are between 1 and 3.
In one move, you can choose an index i with 1 < i < N such that A_{i-1} + A_{i+1} = 4, and set A_i := 4 - A_i.

Find the minimum possible sum of A after performing this operation any number of times (possibly zero).
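To check the reasoning below on small inputs, here is a brute-force sketch in Python (the editorialist’s language). It is our own addition, not part of the official solutions, and `brute_min_sum` is a hypothetical helper name. It runs a BFS over every array reachable from A, so it is exponential and only meant for tiny N. Indices are 0-based here, so a move is allowed at positions 1 through N-2.

```python
from collections import deque


def brute_min_sum(a):
    """Explore every array reachable from `a` with BFS and return the smallest sum seen."""
    start = tuple(a)
    seen = {start}
    queue = deque([start])
    best = sum(start)
    while queue:
        cur = queue.popleft()
        best = min(best, sum(cur))
        # A move at 0-indexed position i needs both neighbours, so 1 <= i <= n-2.
        for i in range(1, len(cur) - 1):
            if cur[i - 1] + cur[i + 1] == 4:
                nxt = cur[:i] + (4 - cur[i],) + cur[i + 1:]
                if nxt not in seen:
                    seen.add(nxt)
                    queue.append(nxt)
    return best


print(brute_min_sum([3, 3, 3, 1]))  # 6
```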

EXPLANATION:

The given operation converts 1 to 3 and vice versa, and doesn’t change 2's.
So, if A_i = 2 initially, it will always remain that way; otherwise, since we want to minimize the sum, we’d like to convert as many 3's into 1's as possible.

Since the 2's never change, they split the array into smaller blocks between adjacent 2's (plus possibly one block before the first 2 and one after the last 2), where each block contains only 1's and 3's. These blocks are independent of each other, so we can solve each block separately and add up the results.
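The block decomposition can be sketched as follows. `split_blocks` is a hypothetical helper (not from the original code) that returns 0-indexed, end-exclusive `(start, end)` pairs for the maximal runs of non-2 elements.

```python
def split_blocks(a):
    """Return (start, end) pairs, end exclusive, of the maximal runs of elements != 2."""
    blocks = []
    i, n = 0, len(a)
    while i < n:
        if a[i] == 2:  # 2's never change, so they only act as separators
            i += 1
            continue
        j = i
        while j < n and a[j] != 2:
            j += 1
        blocks.append((i, j))
        i = j
    return blocks


print(split_blocks([3, 1, 2, 3, 2, 2, 1, 3, 3]))  # [(0, 2), (3, 4), (6, 9)]
```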

So, suppose we have a block of elements that are all 1 or 3, and there’s a 2 at both its left end and right end.

Since 2+1 \neq 4 and 2+3 \neq 4, as long as the block has size \geq 2, each of its end elements (a 2 on one side, a 1 or 3 on the other) can never be changed, and the twos never take part in any move.
For now, let’s deal only with blocks of size \geq 2.

Since we have ones and threes, the only way to obtain a sum of 4 is 1 + 3.
In particular, if there are no ones in the block, it’s not possible to perform any moves at all; so we can just ignore the block.
This means we only need to care about blocks that contain both threes and ones.

Now, in general our moves look like [1, x, 3] \to [1, 4-x, 3] or [3, x, 1] \to [3, 4-x, 1].
In particular, to turn a 3 into a 1, there must be both another 3 and a 1 adjacent to it.

Now, let’s look at a contiguous segment of 3's, something like [\ldots, x, 3, 3, 3, 3, y, \ldots], where each of x and y is either 1 or 2 (or doesn’t exist, if the segment touches an end of the array).
We know that at least one of x and y must be 1 (since we’re in the case where the block contains both threes and ones).

Suppose x = 1. Then, by repeatedly performing the operation [1, 3, 3] \to [1, 1, 3], this segment of threes can be brought down to a single three.
Similarly, if y = 1 we can repeat [3, 3, 1] \to [3, 1, 1] instead to achieve the same thing.
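The two sweeps can be traced directly. This is a small sketch with made-up example arrays; `flip` is a hypothetical helper that applies one move at a 0-indexed position and asserts that the move is legal.

```python
def flip(a, i):
    """Apply the move at 0-indexed position i, asserting that it is legal."""
    assert 1 <= i <= len(a) - 2 and a[i - 1] + a[i + 1] == 4
    a[i] = 4 - a[i]


a = [1, 3, 3, 3, 3, 2]
# x = 1 on the left: sweep left-to-right using [1, 3, 3] -> [1, 1, 3].
for i in range(1, 4):
    flip(a, i)
print(a)  # [1, 1, 1, 1, 3, 2]

b = [2, 3, 3, 3, 3, 1]
# y = 1 on the right: sweep right-to-left using [3, 3, 1] -> [3, 1, 1].
for i in range(4, 1, -1):
    flip(b, i)
print(b)  # [2, 3, 1, 1, 1, 1]
```

In both cases exactly one 3 survives, and it cannot be flipped further (its neighbours sum to 1 + 2 = 3).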

However, it’s not possible to eliminate the final three: as we noted above, turning a 3 into a 1 requires another 3 adjacent to it, and the last remaining three of a segment has none.

Further, it’s not possible to “merge” two segments of threes into a single segment. Two different segments are separated by at least one 1, and the last step of any merge would have to turn a 1 with a 3 on both sides into a 3, i.e. [3, 1, 3] \to [3, 3, 3], which isn’t allowed since 3+3 \neq 4.

So, the best we can do is leave exactly one three in each contiguous segment of threes, and turn everything else into a 1.
This is optimal because, as noted above, a segment of threes can neither be eliminated nor merged with another, so the number of segments of threes can never decrease.
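The invariant can be checked exhaustively for small lengths. This sketch (our own, with hypothetical helper names) only considers arrays made of 1's and 3's, i.e. the situation inside a block of size \geq 2. Around a single element, [2, 3, 2] \to [2, 1, 2] does remove a segment, which is exactly why length-1 blocks are handled separately.

```python
from itertools import product


def count_three_runs(a):
    """Number of maximal contiguous segments of 3's."""
    return sum(1 for i, v in enumerate(a) if v == 3 and (i == 0 or a[i - 1] != 3))


def runs_invariant_holds(max_len=10):
    """Check that no legal move changes the number of 3-segments,
    for every array over {1, 3} with length 3..max_len."""
    for n in range(3, max_len + 1):
        for arr in product((1, 3), repeat=n):
            before = count_three_runs(arr)
            for i in range(1, n - 1):
                if arr[i - 1] + arr[i + 1] == 4:
                    moved = arr[:i] + (4 - arr[i],) + arr[i + 1:]
                    if count_three_runs(moved) != before:
                        return False
    return True


print(runs_invariant_holds())  # True
```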

Finally, let’s look at blocks of length 1.
In this case, the block must be either [2, 1, 2] or [2, 3, 2]; the latter can be transformed into [2, 1, 2] (since 2+2 = 4), so it’s optimal to do so.
Note that a length-1 block at the start or end of the array doesn’t have this option, i.e., if A = [3, 2, \ldots] you can’t transform it into [1, 2, \ldots].
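Putting the whole case analysis together gives a closed form: 2 for every 2; 1 for a length-1 block with a 2 on both sides; the block length plus 2 times its number of segments of 3's for a longer block containing both 1's and 3's; and the unchanged sum for any other block. Below is a sketch under those rules. `formula_min_sum` is a hypothetical name; the approach is close in spirit to the tester’s C++ solution rather than the greedy.

```python
def formula_min_sum(a):
    """Closed-form answer from the block-by-block case analysis."""
    n = len(a)
    total = 2 * a.count(2)  # 2's never change
    i = 0
    while i < n:
        if a[i] == 2:
            i += 1
            continue
        j = i
        while j < n and a[j] != 2:
            j += 1
        block = a[i:j]
        if len(block) == 1:
            # [2, x, 2] can always end as [2, 1, 2]; a block touching an array end cannot move.
            total += 1 if (i > 0 and j < n) else block[0]
        elif 1 in block and 3 in block:
            # Each segment of 3's keeps exactly one 3; everything else becomes 1.
            runs = sum(1 for k in range(len(block))
                       if block[k] == 3 and (k == 0 or block[k - 1] != 3))
            total += len(block) + 2 * runs
        else:
            total += sum(block)  # all 1's or all 3's: no move is possible
        i = j
    return total


print(formula_min_sum([3, 3, 2, 1, 3, 3, 2, 3, 2, 3, 1, 3]))  # 25
```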


While the above solution might seem somewhat caseworky, there’s in fact a very simple implementation.
Observe that all our moves only turned 3's into 1's, and each segment of threes was shrunk to length 1 by a single sweep: left-to-right when it has a 1 on its left, or right-to-left when it has a 1 on its right. (A [2, 3, 2] block is handled by either sweep, since 2+2 = 4.)

So, the following greedy algorithm always gives the correct answer:

  • First, for each i from 2 to N-1:
    • If A_{i-1} + A_{i+1} = 4, set A_i := \min(A_i, 4 - A_i).
  • Then, for each i from N-1 to 2 in decreasing order:
    • If A_{i-1} + A_{i+1} = 4, set A_i := \min(A_i, 4 - A_i).

That is, quite simply, greedily replace elements front-to-back and then back-to-front, then print the sum of the final array!
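As a sanity check (our own addition, not from the original editorial), the two-pass greedy can be compared against a brute force on every array over {1, 2, 3} of small length. `greedy_min_sum` mirrors the editorialist’s PyPy3 loop with 0-based indices; `brute_min_sum` and `find_counterexample` are hypothetical helper names. The check is exhaustive rather than random, and `max_n = 7` keeps it to a few thousand arrays.

```python
from collections import deque
from itertools import product


def greedy_min_sum(a):
    """Two-pass greedy: front-to-back, then back-to-front (0-indexed)."""
    a = list(a)
    n = len(a)
    for i in range(1, n - 1):          # front-to-back pass
        if a[i - 1] + a[i + 1] == 4:
            a[i] = min(a[i], 4 - a[i])
    for i in range(n - 2, 0, -1):      # back-to-front pass
        if a[i - 1] + a[i + 1] == 4:
            a[i] = min(a[i], 4 - a[i])
    return sum(a)


def brute_min_sum(a):
    """BFS over every reachable array; exponential, tiny inputs only."""
    start = tuple(a)
    seen, queue, best = {start}, deque([start]), sum(start)
    while queue:
        cur = queue.popleft()
        best = min(best, sum(cur))
        for i in range(1, len(cur) - 1):
            if cur[i - 1] + cur[i + 1] == 4:
                nxt = cur[:i] + (4 - cur[i],) + cur[i + 1:]
                if nxt not in seen:
                    seen.add(nxt)
                    queue.append(nxt)
    return best


def find_counterexample(max_n=7):
    """Return the first array where greedy and brute force disagree, or None."""
    for n in range(2, max_n + 1):
        for arr in product((1, 2, 3), repeat=n):
            if greedy_min_sum(arr) != brute_min_sum(arr):
                return arr
    return None


print(find_counterexample())  # None
```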

TIME COMPLEXITY:

\mathcal{O}(N) per testcase.

CODE:

Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;

void solve(istringstream cin) {
    int n;
    cin >> n;
    vector<int> a(n);
    for (int i = 0; i < n; i++) {
        cin >> a[i];
    }
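    // Length-1 blocks: [2, 3, 2] -> [2, 1, 2].
    for (int i = 1; i < n - 1; i++) {
        if (a[i - 1] == 2 && a[i] == 3 && a[i + 1] == 2) {
            a[i] = 1;
        }
    }
    int ans = accumulate(a.begin(), a.end(), 0);
    // For each maximal run a[i..j) of 3's that touches a 1 on either side,
    // every 3 except one can become a 1, saving 2 per converted element.
    for (int i = 0, j = 0; i < n; i = j) {
        if (a[i] <= 2) {  // not a 3: move on
            j = i + 1;
            continue;
        }
        while (j < n && a[j] == 3) {
            j++;
        }
        if ((i - 1 >= 0 && a[i - 1] == 1) || (j < n && a[j] == 1)) {
            ans -= 2 * (j - i - 1);
        }
    }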
    cout << ans << '\n';
}

////////////////////////////////////////

// #define IGNORE_CR

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
#ifdef IGNORE_CR
            if (c == '\r') {
                continue;
            }
#endif
            buffer.push_back((char) c);
        }
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            assert(!isspace(buffer[pos]));
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

int main() {
    input_checker in;
    int tt = in.readInt(1, 1e5);
    in.readEoln();
    int sn = 0;
    while (tt--) {
        int n = in.readInt(2, 3e5);
        sn += n;
        in.readEoln();
        auto a = in.readInts(n, 1, 3);
        in.readEoln();
        ostringstream sout;
        sout << n << '\n';
        for (int i = 0; i < n; i++) {
            sout << a[i] << " \n"[i == n - 1];
        }
        solve(istringstream(sout.str()));
    }
    cerr << sn << endl;
    assert(sn <= 3e5);
    in.readEof();
    return 0;
}
Editorialist's code (PyPy3)
for _ in range(int(input())):
    n = int(input())
    a = list(map(int, input().split()))
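    # front-to-back pass: flip 3 -> 1 whenever the neighbours sum to 4
    for i in range(1, n-1):
        if a[i-1] + a[i+1] == 4:
            a[i] = min(a[i], 4 - a[i])
    # back-to-front pass: handles segments of 3's whose 1 is on the right
    for i in reversed(range(1, n-1)):
        if a[i-1] + a[i+1] == 4:
            a[i] = min(a[i], 4 - a[i])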
    print(sum(a))