PRIMEPREFIX - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: sushil2006
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

Easy

PREREQUISITES:

Dynamic programming

PROBLEM:

You’re given an array A consisting of ones and twos.
Define f(A) to be the number of prefix sums of A that are primes.
Let M = \max f(R) across all rearrangements R of A.

Find the minimum number of adjacent swaps in A needed to attain a rearrangement with value M.

EXPLANATION:

First, let’s find the value M, i.e. the maximum possible number of prime prefix sums.

Let S = A_1 + A_2 + \ldots + A_N denote the sum of the elements of A.
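Since all elements are positive, the prefix sums are strictly increasing and never exceed S. So each prime can be hit at most once, and only primes \le S can be hit at all. The absolute best we can do is therefore to reach every prime that is \le S.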

It turns out that this is (almost) always possible.
For example, one construction that attains this maximum is as follows:

  • If every element of A equals 1, then N = S and we trivially reach the values 1, 2, 3, \ldots, N, so the claim holds.
  • If A contains both a 1 and a 2, we have the following construction:
    [2, 1, 2, 2, \ldots, 2, 2, 1, 1, \ldots, 1]
    • That is, place a single 2, then a single 1, then all remaining 2’s, and finally all remaining 1’s.
    • This gives 2 and 3 as the first two prefix sums, after which each remaining 2 steps through the consecutive odd numbers 5, 7, 9, \ldots
      Once the twos are used up, we place the remaining ones, which go through every remaining integer up to S without skipping anything.
    • Because every prime other than 2 is odd, and we don’t skip any odd number \le S (except 1), every prime \le S will appear as a prefix sum.

The only exceptional case is when A contains only twos: then there’s only one distinct rearrangement, its prefix sums are 2, 4, 6, \ldots, and the only prime it can reach is 2 (so no swaps are needed).

So, the single edge case aside, we never have to skip a prime number \le S = \text{sum}(A).


Now, we shift our focus to computing the minimum swaps needed.

To do that, the key observation is that because we have only ones and twos, we can focus on just one ‘type’ of value.

For example, let’s look at only the positions of the 1’s.
Note that if these are fixed, then the positions of the 2’s are automatically fixed too.

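Further, suppose the 1’s in A are initially at indices x_1 < x_2 < \ldots < x_k and must finally be at positions y_1 < y_2 < \ldots < y_k. Swapping two equal elements accomplishes nothing, so the ones keep their relative order and the i-th one travels from x_i to y_i. Each adjacent swap of a 1 with a 2 moves exactly one 1 by exactly one position, which gives a lower bound; moving the ones one step at a time toward their targets shows that this bound is achievable. Hence we need exactly

\sum_{i=1}^k |x_i - y_i|

swaps to achieve this configuration.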


This method of representing the swap cost only in terms of the positions of the ones is quite helpful, and allows us to write a solution using dynamic programming.

Specifically, let’s define dp(i, j) to be the minimum cost such that:

  • We’ve placed i ones and j twos in the prefix of length (i+j).
    • Note that the sum of this prefix now equals i+2j.
  • While doing so, we did not skip any prime number that’s \le i+2j.

Transitions are as follows.
The last element of the prefix can be either a one or a two.

First, consider the case that the last element in the prefix is a 1.
Then, we need to place i-1 ones and j twos in the prefix before this.
Further, this guarantees that the i-th one will end up at index (i+j), which as we saw above incurs a cost of |x_i - (i+j)| (given that its initial position is x_i).
Thus, the minimum cost of doing this is

dp(i-1, j) + |x_i - (i+j)|

Note that since we’re placing a 1, no value gets skipped, so we don’t need to worry about skipping a prime.

Next, consider the case that the last element in the prefix is a 2.
Note that this is only valid when (i+2j-1) is not a prime, since we’re skipping that value as a prefix sum.
In this case, the minimum cost of the previous prefix is dp(i, j-1), and that’s all: since we’ve rewritten the cost to depend purely on the positions of the 1’s, placing a 2 here doesn’t affect the cost in any way (yet).

Putting the cases together,

  • If i+2j-1 is not a prime, we have
    dp(i, j) = \min(dp(i, j-1), dp(i-1, j) + |x_i - (i+j)|)
  • If i+2j-1 is a prime, we have dp(i, j) = dp(i-1, j) + |x_i - (i+j)|.
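As a hand-worked illustration, take A = [1, 1, 2]. This array was chosen for the example and is not one of the problem’s tests. Here x_1 = 1 and x_2 = 2 (1-indexed), c_1 = 2 and c_2 = 1, and the recurrence unfolds as:

```latex
\begin{aligned}
dp(0,0) &= 0 \\
dp(1,0) &= dp(0,0) + |x_1 - 1| = 0 \\
dp(2,0) &= dp(1,0) + |x_2 - 2| = 0 \\
dp(0,1) &= dp(0,0) = 0 && (i+2j-1 = 1 \text{ is not prime}) \\
dp(1,1) &= dp(0,1) + |x_1 - 2| = 1 && (i+2j-1 = 2 \text{ is prime, so a 2 cannot be last}) \\
dp(2,1) &= dp(1,1) + |x_2 - 3| = 2 && (i+2j-1 = 3 \text{ is prime, so a 2 cannot be last})
\end{aligned}
```

The answer dp(2, 1) = 2 corresponds to the rearrangement [2, 1, 1], whose prefix sums 2, 3, 4 hit both primes \le 4.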

Take care of the boundaries appropriately: dp(0, 0) = 0, the first transition doesn’t exist when i = 0 (there’s no 1 to place last), and the second doesn’t exist when j = 0 (there’s no 2 to place last).

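The final answer is dp(c_1, c_2), where c_1, c_2 denote the counts of 1 and 2 in A respectively. The all-twos case (c_1 = 0) must be handled separately with answer 0. For c_2 \ge 2, the DP would reject that case, because the prefix [2, 2] skips the prime 3.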

There are \mathcal{O}(N^2) states and \mathcal{O}(1) transitions from each one, so this is fast enough for the constraints.

TIME COMPLEXITY:

\mathcal{O}(N^2) per testcase.

CODE:

[details = Editorialist’s code (C++)]

// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

const int N = 5005;
int dp[N][N];

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    // prime[v] = 1 iff v is prime (trial division); every prefix sum is at most 2n.
    const int M = 10010;
    vector prime(M, 0);
    for (int i = 2; i < M; ++i) {
        prime[i] = 1;
        for (int j = 2; j < i; ++j) {
            if (i%j == 0) {
                prime[i] = 0;
                break;
            }
        }
    }

    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        vector a(n, 0);
        for (int &x : a) cin >> x;

        vector<int> ones;
        for (int i = 0; i < n; ++i) {
            if (a[i] == 1) ones.push_back(i);
        }
        
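        // All twos: only one arrangement exists, so no swaps are needed.
        if (size(ones) == 0) {
            cout << 0 << '\n';
            continue;
        }

        int p = size(ones);
        int q = n - p;

        // dp[i][j]: min cost to place the first i ones and j twos as a prefix
        // without skipping any prime <= i + 2j (1e9 marks an unreachable state).
        for (int i = 0; i <= p; ++i) {
            for (int j = 0; j <= q; ++j) {
                dp[i][j] = 1e9;

                if (i == 0 and j == 0) dp[i][j] = 0;
                if (i > 0) {
                    // Last element is the i-th one: it moves from ones[i-1] to index i+j-1 (0-indexed).
                    int from = ones[i-1];
                    int to = i+j-1;
                    dp[i][j] = min(dp[i][j], dp[i-1][j] + abs(from - to));
                }
                if (j > 0 and !prime[i + 2*j - 1]) {
                    // Last element is a two: allowed only if the skipped sum i+2j-1 is not prime.
                    dp[i][j] = min(dp[i][j], dp[i][j-1]);
                }
            }
        }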
        cout << dp[p][q] << '\n';
    }
}

[/details]


This is a typo. It should be ‘2’. Please fix it asap.

Fixed, thank you.

Here’s a linear solution

#include <bits/stdc++.h>
using namespace std;

/*
IDEA (Prime Prefix + Minimum Swaps)

We want to reorder an array of 1s and 2s to:
1) Maximize the number of prefix sums that are prime
2) Among those, minimize the number of adjacent swaps

Key observations:
- Prefix sum after i elements = 2*i - (# of 1s used so far)
- After prefix sum 3, all primes are odd ⇒ gaps between primes are even
- Each gap corresponds to a segment of only 2s
- Therefore, remaining 1s must be used in PAIRS to fit into gaps

Strategy:
1) Reduce total sum to the largest prime ≤ sum by trimming a suffix
   → may incur some swap cost ("extra")

2) Fix the prefix:
   - Either make first 3 elements = 1,1,1 (sum = 3)
   - Or make first 2 elements = 2,1 (sum = 3)
   → try both, take minimum cost

3) Remaining 1s are processed in consecutive pairs:
   - For each pair, compute minimum swap cost to fit into a valid prime gap
   - Use prefix sums + prime gaps to evaluate best placement

4) Total cost = prefix cost + suffix trimming cost + pairing cost

Time complexity: ~O(N + #pairs * small constant)
*/

const int INF = 1e9;

vector<int> primes, gap;

void sieve(int n) {
    vector<int> spf(n + 1);
    for (int i = 2; i <= n; i++) {
        if (!spf[i]) spf[i] = i, primes.push_back(i);
        for (int p : primes) {
            if (p > spf[i] || 1LL * p * i > n) break;
            spf[p * i] = p;
        }
    }
}

int run(vector<int>& a, int n, deque<int>& pos) {
    vector<int> ps(n);
    ps[0] = a[0];
    for (int i = 1; i < n; i++) ps[i] = ps[i - 1] + a[i];

    int ans = 0;
    for (int i = 0; i < (int)pos.size(); i += 2) {
        int j = i + 1;
        int pi = lower_bound(primes.begin(), primes.end(), ps[pos[i]]) - primes.begin();
        int pj = lower_bound(primes.begin(), primes.end(), ps[pos[j]]) - primes.begin() - 1;

        if (pi > pj) continue;

        if (pi == pj) {
            int x = pos[i];
            while (ps[x] <= primes[pi]) x++;
            ans += min(x - pos[i], pos[j] - x);
        } else {
            int l = pos[i], r = pos[j];
            while (ps[l] <= primes[pi]) l++;
            while (ps[r] >= primes[pj]) r--;
            r++;

            int best = min(r - pos[i], pos[j] - l);
            for (int k = pi + 1; k <= pj; k++) {
                best = min(best, pos[j] - pos[i] - gap[k] / 2);
            }
            ans += best;
        }
    }
    return ans;
}

void solve() {
    int n;
    cin >> n;

    vector<int> a(n), cnt(3);
    deque<int> pos;
    for (int i = 0; i < n; i++) {
        cin >> a[i];
        cnt[a[i]]++;
        if (a[i] == 1) pos.push_back(i);
    }

    if (cnt[1] == 0 || cnt[2] == 0) {
        cout << 0 << '\n';
        return;
    }

    int sum = accumulate(a.begin(), a.end(), 0);
    int mx = 0;
    for (int p : primes) {
        if (p <= sum) mx = p;
        else break;
    }

    int over = sum - mx;
    if (mx == 3) {
        if (over == 0) cout << (a[0] == 2 ? 0 : 1) << '\n';
        else if (a[0] == 2) cout << 0 << '\n';
        else if (a[1] == 2) cout << 1 << '\n';
        else cout << 2 << '\n';
        return;
    }

    int extra = 0;
    if (over > 0) {
        for (int i = n - 1, cur = 0; i >= 0; i--) {
            cur += a[i];
            if (cur == over) { n = i; break; }
            if (cur > over) {
                int idx = lower_bound(pos.begin(), pos.end(), i) - pos.begin();
                if (idx == (int)pos.size()) {
                    extra = i - pos.back();
                    swap(a[i], a[pos.back()]);
                    pos.pop_back();
                    n = i;
                } else {
                    n = pos[idx] + 1;
                }
                break;
            }
        }
    }

    while (!pos.empty() && pos.back() >= n) pos.pop_back();
    if (pos.empty() || (int)pos.size() == n) {
        cout << extra << '\n';
        return;
    }

    vector<int> a0 = a;
    deque<int> pos0 = pos;
    int ans = INF;

    // Case 1: prefix = 1,1,1
    if ((int)pos.size() >= 3) {
        int cur = 0;
        swap(a[0], a[pos[0]]); cur += pos[0];
        swap(a[1], a[pos[1]]); cur += pos[1] - 1;
        swap(a[2], a[pos[2]]); cur += pos[2] - 2;
        pos.pop_front(); pos.pop_front(); pos.pop_front();
        ans = min(ans, cur + extra + run(a, n, pos));
    }

    a = a0;
    pos = pos0;

    // Case 2: prefix = 2,1
    {
        int cur = 0;
        if (a[0] == 2) {
            if (a[1] == 1) {
                pos.pop_front();
            } else {
                cur += pos.front() - 1;
                swap(a[1], a[pos.front()]);
                pos.pop_front();
            }
        } else {
            int i = 0;
            while (a[i] != 2) i++;
            swap(a[0], a[i]);
            cur += i;
            pos.pop_front();
            for (int& p : pos) {
                if (p < i) p++;
                else break;
            }
        }
        ans = min(ans, cur + extra + run(a, n, pos));
    }

    cout << ans << '\n';
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    sieve(20000);
    gap.push_back(0);
    for (int i = 1; i < (int)primes.size(); i++) {
        gap.push_back(primes[i] - primes[i - 1]);
    }

    int T;
    cin >> T;
    while (T--) solve();
    return 0;
}