PRIMEQUERY - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: souradeep1999
Tester: raysh07
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Prefix sums

PROBLEM:

You’re given an array A of length N. Answer Q queries on it.
Each query gives you L, R, K.
Consider the subarray A[L\ldots R], and change at most K elements of it to any non-negative integer you like.
Find the maximum possible number of good pairs, where (i, j) is a good pair if L \leq i \lt j \leq R, and A_i + A_j and A_i\cdot A_j are both primes.

EXPLANATION:

First, we analyze good pairs.
Note that for the product A_i\cdot A_j to be prime, one of A_i and A_j must itself be prime and the other must be 1: otherwise the product is 0, 1, or has two factors greater than 1, and in none of these cases is it prime.
The sum A_i + A_j = p + 1 must then also be prime. For odd p, p + 1 is an even number greater than 2 and hence composite, so the only prime p for which p+1 is also prime is p = 2.

So, the pair (i, j) can only be good if A_i = 1 and A_j = 2, or vice versa.
In particular, if there are x_1 occurrences of 1 and x_2 occurrences of 2, the number of good pairs is x_1\cdot x_2.
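As a quick sanity check of this characterization, the sketch below brute-forces small values (the bound 200 is an arbitrary choice, so this is evidence, not a proof) and compares a direct pair count with x_1\cdot x_2. The helper names are illustrative only.

```python
def is_prime(n: int) -> bool:
    """Trial-division primality test; fine for the small values used here."""
    if n < 2:
        return False
    d = 2
    while d * d <= n:
        if n % d == 0:
            return False
        d += 1
    return True


def good_value_pairs(limit: int) -> set:
    """All unordered value pairs (a, b), 0 <= a <= b < limit, where both
    a + b and a * b are prime."""
    return {(a, b) for a in range(limit) for b in range(a, limit)
            if is_prime(a + b) and is_prime(a * b)}


def count_good_pairs(arr: list) -> int:
    """Directly count index pairs i < j whose values form a good pair."""
    n = len(arr)
    return sum(1 for i in range(n) for j in range(i + 1, n)
               if is_prime(arr[i] + arr[j]) and is_prime(arr[i] * arr[j]))


print(good_value_pairs(200))  # {(1, 2)}
arr = [1, 2, 1, 5, 2, 1]
print(count_good_pairs(arr), arr.count(1) * arr.count(2))  # 6 6
```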

Now, let’s look at answering a query.
Consider the range [L, R], of which at most K elements can be changed.
Changing an element to anything other than 1 or 2 is useless, since only 1's and 2's can take part in good pairs. So the question now is: which elements should be changed, and what should we change them to?

Let x_1 and x_2 be the frequencies of 1's and 2's in this range. Their initial values can be found in constant time using prefix sums.
Let x_3 = R-L+1 - x_1 - x_2 be the count of all other elements.
Our aim is to maximize x_1 \cdot x_2.
Without loss of generality, let x_1 \leq x_2 (otherwise swap the roles of 1 and 2).
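A minimal sketch of the prefix-sum setup, using 1-indexed inclusive queries as in the statement; the function names build_prefix and range_counts are hypothetical.

```python
def build_prefix(arr):
    """p1[i] / p2[i] = number of 1s / 2s among arr[0..i-1]."""
    p1, p2 = [0], [0]
    for v in arr:
        p1.append(p1[-1] + (v == 1))  # bool adds as 0 or 1
        p2.append(p2[-1] + (v == 2))
    return p1, p2


def range_counts(p1, p2, l, r):
    """Return (x1, x2, x3) for the 1-indexed inclusive range [l, r]."""
    x1 = p1[r] - p1[l - 1]
    x2 = p2[r] - p2[l - 1]
    x3 = (r - l + 1) - x1 - x2  # everything that is neither 1 nor 2
    return x1, x2, x3


p1, p2 = build_prefix([1, 1, 1, 1, 2, 0])
print(range_counts(p1, p2, 1, 6))  # (4, 1, 1)
```

Swapping x1 and x2 when x1 > x2 is left to the caller, matching the "without loss of generality" step.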

It turns out that the following greedy strategy is optimal. Each step uses up one of the K allowed changes, and we stop as soon as all K have been used:

  • As long as x_3 \gt 0,
    • If x_1 \lt x_2, add 1 to x_1 and subtract 1 from x_3.
      That is, convert one “other” number to 1.
    • If x_1 = x_2, increase x_2 by 1 and decrease x_3 by 1.
      It doesn’t really matter whether we create a 1 or a 2 here, so we choose the 2 just to keep x_1 \lt x_2.
  • If x_3 = 0,
    • If x_1 \lt x_2 - 1, convert one occurrence of 2 to 1.
    • Otherwise, do nothing.
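The steps above can be transcribed literally as an \mathcal{O}(K) loop. This is a sketch for checking small cases, not a solution for the real constraints; greedy_simulate is an illustrative name and assumes the counts x_1, x_2, x_3 are already known.

```python
def greedy_simulate(x1, x2, x3, k):
    """Literal O(K) simulation of the greedy; each iteration spends one change.
    Returns the final product x1 * x2."""
    if x1 > x2:
        x1, x2 = x2, x1  # maintain x1 <= x2
    while k > 0:
        if x3 > 0:
            if x1 < x2:
                x1 += 1  # convert an "other" element to 1
            else:
                x2 += 1  # x1 == x2: convert an "other" element to 2
            x3 -= 1
        elif x1 < x2 - 1:
            x1 += 1  # convert one 2 into a 1
            x2 -= 1
        else:
            break  # counts are balanced; further changes cannot help
        k -= 1
    return x1 * x2


print(greedy_simulate(4, 1, 1, 2))  # 9
```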

Of course, implementing this directly will take \mathcal{O}(K) time per query, which is too slow.
However, it can easily be sped up by splitting the process into phases:

  • When x_3 \gt 0, x_1 is first increased until it reaches x_2 (or until x_3 or K runs out), after which you may observe that we just alternately add 1 to x_2 and x_1.
    This can easily be simulated in \mathcal{O}(1) time since initial counts are known.
  • When x_3 = 0, we repeatedly convert one 2 into a 1, which brings x_1 and x_2 closer together by reducing their difference by 2.
    With K now denoting the number of changes still available, this is done exactly \min\left(K, \left\lfloor \frac{x_2 - x_1}{2} \right\rfloor\right) times, so again it can be simulated in constant time.
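A sketch of the constant-time version, evaluating the phases described above one after another. It follows the same structure as the editorialist's code further down; the function name greedy_fast is illustrative.

```python
def greedy_fast(x1, x2, x3, k):
    """O(1) evaluation of the greedy, phase by phase. Returns x1 * x2."""
    if x1 > x2:
        x1, x2 = x2, x1
    # Phase 1a: raise x1 towards x2 using "other" elements.
    step = min(k, x3, x2 - x1)
    x1 += step
    k -= step
    x3 -= step
    if x1 == x2:
        # Phase 1b: hand out the remaining "other" elements alternately (2 first).
        use = min(k, x3)
        x2 += (use + 1) // 2
        x1 += use // 2
    else:
        # Phase 2: x3 or k ran out first; each 2 -> 1 flip closes the gap by 2.
        flips = min(k, (x2 - x1) // 2)
        x1 += flips
        x2 -= flips
    return x1 * x2


print(greedy_fast(4, 1, 1, 2))  # 9
```

If phase 1a stopped because k ran out, the flip count in phase 2 is simply 0, so no separate case is needed.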
Proof of correctness

We never need to operate on the same element more than once, so the order in which the operations are applied does not affect the final counts.

So, we can always rearrange any sequence of operations we perform so that conversions from 1 to 2 (or vice versa) are performed only after conversions from other elements to 1 or 2.
Now, the remainder of our greedy choice follows from simple exchange arguments:

  • You can show with some simple algebra that if a 1 is converted to a 2 yet there still exists some other element (not 1 or 2), it’d be strictly better to convert that other element to 2 instead.
    This is what gives us the first ‘phase’ of the algorithm, where we distribute x_3 to x_1 and x_2.
  • We’re now left with two situations: we either never reach x_3 = 0, or we do (and need to start converting between 1's and 2's).
    • In the case where we do reach x_3 = 0, note that x_1 + x_2 will remain constant (and equal to the length of the subarray).
      To maximize their product, it is hence optimal for them to be as close to each other as possible, which our greedy choice achieves.
    • When we don’t reach x_3 = 0, it can again be shown that our greedy choice is optimal: we end up with x_1 + x_2 + \min(x_3, K) ones and twos in total, and again it is best for their counts to be as close as possible, which the greedy achieves.
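The "simple algebra" behind these exchange steps can be written out explicitly. The first line compares converting an "other" element to 2 against converting a 1 to 2; the second shows that, with x_1 + x_2 = T fixed and x_1 \leq x_2 - 2, moving one unit from x_2 to x_1 strictly increases the product, so for a fixed total the best value is \lfloor T/2 \rfloor \cdot \lceil T/2 \rceil.

```latex
\begin{aligned}
x_1 (x_2 + 1) - (x_1 - 1)(x_2 + 1) &= x_2 + 1 > 0, \\
(x_1 + 1)(x_2 - 1) - x_1 x_2 &= x_2 - x_1 - 1 \geq 1 \quad \text{when } x_1 \leq x_2 - 2.
\end{aligned}
```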

TIME COMPLEXITY:

\mathcal{O}(N + Q) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp> 
#include <ext/pb_ds/tree_policy.hpp> 
using namespace std;
using namespace __gnu_pbds;
#define int long long int
#define ordered_set tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update>
mt19937 rng(std::chrono::duration_cast<std::chrono::nanoseconds>(chrono::high_resolution_clock::now().time_since_epoch()).count());
#define mp make_pair
#define pb push_back
#define F first
#define S second
const int N=1000005;
#define M 1000000007
#define BINF 1e16
#define init(arr,val) memset(arr,val,sizeof(arr))
#define MAXN 10000005
#define deb(xx) cout << #xx << " " << xx << "\n";
const int LG = 22;


void solve() {

    int n;
    cin >> n;
    vector<int> a(n), one(n, 0), two(n, 0);
    for(int i = 0; i < n; i++) {
        cin >> a[i];
        if(a[i] == 1) {
            one[i] = 1;
        }
        if(a[i] == 2) {
            two[i] = 1;
        }
        if(i > 0) {
            one[i] = one[i - 1] + one[i];
            two[i] = two[i - 1] + two[i];
        }
    }

    int q;
    cin >> q;
    while(q--) {
        int l, r, k;
        cin >> l >> r >> k;
        l = l - 1;
        r = r - 1;

        int x = one[r];
        if(l > 0) {
            x = x - one[l - 1];
        }
        int y = two[r];
        if(l > 0) {
            y = y - two[l - 1];
        }

        if(x > y) {
            swap(x, y);
        }

        int can = r - l + 1 - x - y;
        int len = min(min(can, y - x), k);
        can = can - len;
        k = k - len;
        x = x + len;

        can = min(k, can);
        k = k - can;

        if(can > 0) {
            x = x + can / 2;
            y = y + can / 2;
            if((can % 2) == 1) {
                x = x + 1;
            }
        }

        if(x > y) {
            swap(x, y);
        }

        if((y - x) > 1 and k > 0) {
            int diff = min((y - x) / 2, k);
            x = x + diff;
            y = y - diff;
        }

        int c = x * y;
        cout << c << endl;
    }

}


#undef int 
int main() {
#define int long long int
ios_base::sync_with_stdio(false); 
cin.tie(0); 
cout.tie(0);
#ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
#endif

    
    int T;
    cin >> T;

    for(int tc = 1; tc <= T; tc++){
        // cout << "Case #" << tc << ": ";
        solve();
    }

return 0;  
 
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18
#define f first
#define s second

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

void Solve() 
{
    // a + b, ab are prime 
    // if ab odd => a odd and b odd => (a + b) even => (a + b) = 2, => a = 1 = b not ok
    // ab = even => ab = 2 => (a, b) = (1, 2)
    // answer is f[1] * f[2]
    
    int n; cin >> n;
    
    vector <int> a(n);
    for (auto &x : a) cin >> x;
    
    vector <int> p1(n + 1, 0), p2(n + 1, 0);
    for (int i = 1; i <= n; i++){
        p1[i] = p1[i - 1] + (a[i - 1] == 1);
        p2[i] = p2[i - 1] + (a[i - 1] == 2);
    }
    
    int q; cin >> q;
    
    while (q--){
        int l, r, k; cin >> l >> r >> k;
        
        int x = p1[r] - p1[l - 1];
        int y = p2[r] - p2[l - 1];
        
        if (x > y) swap(x, y);
        
        int reach = min({y, r - l + 1 - y, x + k});
        k -= reach - x;
        x = reach;
        
     //   cout << x << " " << y << "\n";
        
        int ans;
        
        if (x == y){
            // increase both
            k = min(k, r - l + 1 - x - y);
            x += (k + 1) / 2;
            y += k / 2;
            
            ans = x * y;
        } else if (k > 0){
            // change y to x 
            int nx = x + 2 * k;
            if (nx > y){
                int len = r - l + 1;
                ans = (len / 2) * ((len + 1) / 2);
            } else {
                x += k;
                y -= k;
                ans = x * y;
            }
        } else {
            ans = x * y;
        }
        
        
        cout << ans << "\n";
    }
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}
Editorialist's code (Python)
for _ in range(int(input())):
    n = int(input())
    a = list(map(int, input().split()))
    p1, p2 = [0], [0]
    for x in a:
        p1.append(p1[-1])
        p2.append(p2[-1])
        if x == 1: p1[-1] += 1
        if x == 2: p2[-1] += 1
    
    q = int(input())
    for i in range(q):
        l, r, k = map(int, input().split())
        ones, twos = p1[r] - p1[l-1], p2[r] - p2[l-1]
        other = r-l+1 - ones - twos
        
        if ones > twos: ones, twos = twos, ones
        change = min(k, other, twos - ones)
        ones += change
        k -= change
        other -= change
        
        if ones == twos:
            k = min(k, other)
            ones += k//2
            twos += (k+1)//2
        else:
            d = twos - ones
            k = min(k, d//2)
            ones += k
            twos -= k
        print(ones *  twos)

Why does this not work?

import sys
#sys.stdin = open('in.txt', 'r')
#sys.stdout = open('out.txt', 'w')
read = sys.stdin.readline
write = sys.stdout.write
for _ in range(int(read())):
    n = int(read())
    a = list(map(int, read().split()))
    preone = [0]
    pretwo = [0]
    for i in range(n):
        preone.append(preone[-1]+int((a[i]==1)))
        pretwo.append(pretwo[-1]+int((a[i]==2)))
    q = int(read())
    for i in range(q):
        l, r, k = map(int, read().split())
        l -= 1
        r -= 1
        ones = preone[r+1]-preone[l]
        twos = pretwo[r+1]-pretwo[l]
        other = r-l+1-ones-twos
        k = min(other, k)
        if abs(ones-twos) >= k:
            if ones > twos:
                twos += k
            else:
                ones += k
        else:
            k -= abs(ones-twos)
            ones, twos = max(ones, twos), max(ones, twos)
            ones += k//2
            twos += k-k//2
        write(str(ones*twos)+'\n')

Try this test case:
6
1 1 1 1 2 0
1
1 6 2

I think this test case covers all possible cases. The correct answer should be 9.
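To double-check expected answers like this one on tiny inputs, an exhaustive search can help. The sketch below (brute_force is a hypothetical helper, exponential in K) tries every way to change at most K positions in the range, relying on the editorial's observation that only 1 and 2 are useful target values.

```python
from itertools import combinations, product


def brute_force(arr, l, r, k):
    """Try every way to change at most k positions of arr[l-1..r-1] into 1 or 2
    and return the best (#1s) * (#2s). Only for tiny hand-made tests."""
    sub = arr[l - 1:r]
    best = 0
    for m in range(min(k, len(sub)) + 1):
        for positions in combinations(range(len(sub)), m):
            for values in product((1, 2), repeat=m):
                cur = sub[:]
                for pos, val in zip(positions, values):
                    cur[pos] = val
                best = max(best, cur.count(1) * cur.count(2))
    return best


print(brute_force([1, 1, 1, 1, 2, 0], 1, 6, 2))  # 9
```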


Oh OK, I see why it doesn't work now.

thank you so much for the help <3

Someone please help: I can't find the test cases on which my code fails.

#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define py cout << "YES" << endl;
#define pn cout << "NO" << endl;
#define nl cout << endl;
#define all(p) p.begin(), p.end()
#define pyr                    \
    {                          \
        cout << "YES" << endl; \
        return;                \
    }
#define pnr                   \
    {                         \
        cout << "NO" << endl; \
        return;               \
    }

void breathe()
{
    ll n;
    cin >> n;
    vector<ll> arr(n);
    for (auto &&i : arr)
    {
        cin >> i;
    }

    vector<ll> ones(n, 0);
    vector<ll> twos(n, 0);
    if (arr[0] == 1)
    {
        ones[0] = 1;
    }
    if (arr[0] == 2)
    {
        twos[0] = 1;
    }

    for (ll i = 1; i < n; i++)
    {
        ones[i] = ones[i - 1];
        twos[i] = twos[i - 1];
        if (arr[i] == 1)
        {
            ones[i] += 1;
        }
        if (arr[i] == 2)
        {
            twos[i] += 1;
        }
    }

    // for (auto &&i : ones)
    // {
    //     cout << i << ' ';
    // }
    // nl;

    ll q;
    cin >> q;

    for (ll i = 0; i < q; i++)
    {
        ll l, r, k;
        cin >> l >> r >> k;
        l -= 1;
        r -= 1;

        ll a = ones[r] - ones[l];
        ll b = twos[r] - twos[l];

        if (arr[l] == 1)
        {
            a += 1;
        }
        else if (arr[l] == 2)
        {
            b += 1;
        }

        if (a > b)
        {
            swap(a, b);
        }

        ll left = r - l + 1 - a - b;
        ll cover = b - a;

        if (left >= cover)
        {
            if (k <= cover)
            {
                a += k;
                k = 0;
            }
            else
            {
                left -= cover;
                k -= cover;
                a += cover;
                cover = 0;
            }

            // cout << a << " " << b << ' ' << left << endl;
            a += (min(k, left)) / 2;
            b += (min(k, left)) / 2;
            if ((min(k, left)) % 2 != 0)
            {
                a += 1;
            }
        }
        else
        {
            if (k >= left)
            {
                a += left;
                k -= left;
                left = 0;
                cover -= left;

                a += min(k, (cover) / 2);
                b -= min(k, (cover) / 2);
            }
            else
            {
                a += k;
                k = 0;
                left -= k;
            }
        }

        cout << a * b << endl;
    }

    return;
}

int main()
{
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    ll tt = 1;
    cin >> tt;
    while (tt--)
    {
        breathe();
    }
    return 0;
}

You assigned 0 to left before subtracting left from cover, so cover is never actually reduced. You should do it in the following order:

cover -= left;
left = 0;

Also, the assignment left = 0 is completely redundant in your implementation, so you could eliminate it anyway.
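A hypothetical Python port of just that `left < cover`, `k >= left` branch makes the effect concrete. The input below is one I constructed for illustration (it is not from the thread): on the array 2 2 2 2 2 2 0 0 with query 1 8 8, the code computes a = 0, b = 6, left = 2, so this branch runs and the stale cover yields 15 instead of the correct 16.

```python
def else_branch(a, b, left, k, fixed):
    """Hypothetical port of the `left < cover` branch of the C++ code above
    (assumes a <= b and k >= left). `fixed` selects the corrected statement order."""
    cover = b - a
    a += left
    k -= left
    if fixed:
        cover -= left  # subtract first ...
        left = 0       # ... then clear
    else:
        left = 0
        cover -= left  # subtracts 0: cover keeps its stale value
    a += min(k, cover // 2)
    b -= min(k, cover // 2)
    return a * b


print(else_branch(0, 6, 2, 8, fixed=False))  # 15
print(else_branch(0, 6, 2, 8, fixed=True))   # 16
```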

yessss, got it. thank you soo much. <3

Thanks so much!