MAXSUMOPS - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: raysh07
Tester: apoorv_me
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

None

PROBLEM:

You are given an array A of N integers. Perform exactly K moves of the following kind:

  • Choose x\in A, and replace it with \frac{x}{2} if it’s even or x+1 if it’s odd.

Find the maximum possible sum of the array after applying all K operations.
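An exhaustive reference is handy for testing the claims below on small cases. The following is a minimal sketch, not the intended solution. It tries every sequence of K moves, so it is exponential in K and only usable on tiny arrays. The helper names `move` and `brute_max_sum` are hypothetical, and the array is assumed to be non-empty with positive values.

```python
from functools import lru_cache


def move(x: int) -> int:
    """One operation: halve an even value, add 1 to an odd value."""
    return x // 2 if x % 2 == 0 else x + 1


def brute_max_sum(a: list, k: int) -> int:
    """Try every sequence of k moves; exponential, so tiny inputs only."""

    @lru_cache(maxsize=None)
    def best(state: tuple, left: int) -> int:
        if left == 0:
            return sum(state)
        result = -1  # sums are non-negative for positive inputs
        for i in range(len(state)):
            nxt = list(state)
            nxt[i] = move(nxt[i])
            # Sorting merges states that differ only in element order.
            result = max(result, best(tuple(sorted(nxt)), left - 1))
        return result

    return best(tuple(sorted(a)), k)


print(brute_max_sum([6, 6], 4))  # 8
print(brute_max_sum([3, 5], 2))  # 10
```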

EXPLANATION:

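First, note that a move on an odd element increases the sum by 1 (and makes that element even), while a move on an even element always decreases the sum. So each initially odd element offers one free +1, and it’s optimal to take these first.
If K is at most the number of odd elements, the answer is simply the initial sum plus K; otherwise, spend one move on each odd element.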

After that, if moves remain, every element is even, so we only need to deal with arrays whose elements are all even.
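The odd-fixing step can be sketched as follows. This is a minimal illustration and `fix_odds` is a hypothetical helper name, not taken from the reference codes. It returns the modified array together with the number of moves still left.

```python
def fix_odds(a: list, k: int):
    """Spend moves on odd elements first: each gains +1 and becomes even.

    Returns (new_array, moves_left). If moves_left is 0, the answer is
    sum(new_array); otherwise every element of new_array is even.
    """
    b = list(a)
    for i, x in enumerate(b):
        if k == 0:
            break
        if x % 2 == 1:
            b[i] = x + 1
            k -= 1
    return b, k


print(fix_odds([3, 4, 7], 1))  # ([4, 4, 7], 0): answer is 14 + 1 = 15
print(fix_odds([3, 4, 7], 5))  # ([4, 4, 8], 3): all even, 3 moves left
```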

The main observation to be made here is that almost every element will be left untouched.
Indeed, in almost every case, it’s optimal to use all our operations on a single element.

Proof

Let A be sorted, so A_1 \leq A_2 \leq A_3 \leq \ldots

If we use all our operations on A_1, we “lose” at most A_1 - 1 from the overall sum, since its final value is at least 1.

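On the other hand, suppose we use \geq 1 operation on each of two different indices i and j.
The first operation on an even value x halves it, and from then on its value never exceeds \frac{x}{2} + 1, so that element loses at least \frac{x}{2} - 1.
Hence the total “loss” is at least \frac{A_i}{2} + \frac{A_j}{2} - 2 \geq \frac{A_1}{2} + \frac{A_2}{2} - 2.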
Notice that when A_1 \lt A_2, we have

\frac{A_1}{2} + \frac{A_2}{2} - 2 \gt 2\cdot \frac{A_1}{2} - 2 = A_1 - 2

The loss is strictly greater than A_1-2, meaning it’s at least A_1-1 (which we could’ve obtained by using all our operations on A_1 anyway).

So, when A_1 \lt A_2, there exists an optimal solution where we use all our operations on a single element.
When A_1 = A_2, we might want to use operations on both of them (but again, we never need to operate on two elements with different values).
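The bounds used in the proof follow from two facts about an even value x: after its first operation it never exceeds \frac{x}{2} + 1, and it never drops below 1. The snippet below checks both empirically for small even x. It is a sanity check under the stated bounds, not a proof, and `trajectory` and `check_lemma` are hypothetical names.

```python
def trajectory(x: int, steps: int) -> list:
    """Values of x after 0, 1, ..., steps operations applied to it alone."""
    vals = [x]
    for _ in range(steps):
        x = x // 2 if x % 2 == 0 else x + 1
        vals.append(x)
    return vals


def check_lemma(max_x: int = 200, steps: int = 40) -> bool:
    """Empirically check both facts for every even x in [2, max_x]."""
    for x in range(2, max_x + 1, 2):
        after = trajectory(x, steps)[1:]  # values after >= 1 operation
        assert max(after) <= x // 2 + 1   # so the loss is >= x/2 - 1
        assert min(after) >= 1            # so the loss is <= x - 1
    return True


print(trajectory(6, 5))  # [6, 3, 4, 2, 1, 2]
print(check_lemma())     # True
```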

Finding what happens to a single element when all the operations are applied to it isn’t hard: every two operations take a value x to at most \lfloor x/2 \rfloor + 1, so after \mathcal{O}(\log A_i) operations it reaches 1.
After this, it’ll simply keep cycling between 1 and 2, with its final value only depending on the parity of the number of operations remaining.
So, all K operations can be directly simulated on a single element in \mathcal{O}(\min(K, \log A_i)) time.
Doing this for every element is fast enough.
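A sketch of the single-element simulation, assuming the odd elements have already been fixed and k moves remain (`final_value` and `best_single_element` are illustrative names). The loop stops as soon as x reaches 1, so it runs \mathcal{O}(\min(K, \log A_i)) times even for very large K.

```python
def final_value(x: int, k: int) -> int:
    """Value of x after all k operations are applied to it.

    Simulates until x reaches 1, which takes O(log x) steps; from there the
    value alternates 1, 2, 1, ..., so only the parity of the rest matters.
    """
    while k > 0 and x != 1:
        x = x // 2 if x % 2 == 0 else x + 1
        k -= 1
    if x == 1 and k % 2 == 1:
        x = 2
    return x


def best_single_element(a: list, k: int) -> int:
    """Max sum if all k remaining moves go to one element of a."""
    total = sum(a)
    return max(total - x + final_value(x, k) for x in a)


print(final_value(6, 4))                   # 1  (6 -> 3 -> 4 -> 2 -> 1)
print(final_value(6, 10**18 + 1))          # 2
print(best_single_element([4, 6, 10], 3))  # 18
```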


As outlined in the proof above, the only edge case is A_1 = A_2, where we might want to split operations between the two.
This edge case can’t be ignored: for example, with A = [6, 6] and K = 4, the optimal final array is [4, 4] (sum 8) rather than [1, 6] (sum 7).
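The [6, 6], K = 4 example can be checked by listing every split. Because the two elements evolve independently, only the number of moves each receives matters, not their order. This is a small sketch with hypothetical helper names; `final_value` repeats the single-element simulation.

```python
def final_value(x: int, k: int) -> int:
    """Value of x after k operations are all applied to it."""
    while k > 0 and x != 1:
        x = x // 2 if x % 2 == 0 else x + 1
        k -= 1
    return 2 if (x == 1 and k % 2 == 1) else x


def split_sums(x: int, k: int) -> list:
    """Sum of [x, x] for each way of giving i moves to the first copy."""
    return [final_value(x, i) + final_value(x, k - i) for i in range(k + 1)]


print(split_sums(6, 4))  # [7, 5, 8, 5, 7]
```

Only the even split (2, 2) reaches 8; both all-on-one-element choices give 7.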

It can further be seen that once K is “large enough” (say, \geq 10), splitting the operations between both elements is again not optimal, since pushing both of them below half their value is strictly worse than reducing just one of them to 1.

So, when A_1 = A_2 and K is small enough, you can additionally simulate every way of distributing the K operations between the two of them, and consider those results for the answer too.
In fact, it turns out that the only case in which splitting is strictly better is A_1 = A_2 = 6 with K = 4 (counting the moves left after the odd elements are fixed).
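This claim can be spot-checked by brute force over small values. For an array [x, x] with x even, the sketch below compares the best split, where both copies receive at least one move, against giving all K moves to one copy. The search bounds (x \leq 100, K \leq 40) and helper names are arbitrary choices, so this supports the claim rather than proving it.

```python
def final_value(x: int, k: int) -> int:
    """Value of x after k operations are all applied to it."""
    while k > 0 and x != 1:
        x = x // 2 if x % 2 == 0 else x + 1
        k -= 1
    return 2 if (x == 1 and k % 2 == 1) else x


def split_beats_single(x: int, k: int) -> bool:
    """For [x, x], is some split (both copies touched) strictly better?"""
    single = x + final_value(x, k)
    split = max(final_value(x, i) + final_value(x, k - i) for i in range(1, k))
    return split > single


wins = [(x, k) for x in range(2, 101, 2) for k in range(2, 41)
        if split_beats_single(x, k)]
print(wins)  # [(6, 4)]
```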

TIME COMPLEXITY:

\mathcal{O}(N\cdot\min (K, \log\max A)) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18
#define f first
#define s second

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

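// Value of x after all k operations are applied to it.
int f(int x, int k){
    if (k == 0) return x;
    // From 1 the value alternates 1, 2, 1, ..., so only the parity of k matters.
    if (x == 1){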
        if (k % 2 == 0) return 1;
        else return 2;
    }

    if (x % 2 == 0) x /= 2;
    else x++;

    return f(x, k - 1);
}

void Solve() 
{
    int n, k; cin >> n >> k;
    vector <int> a(n);
    int sum = 0, odd = 0;
    for (auto &x : a) cin >> x, sum += x, odd += x % 2;

    // Each odd element can absorb one move that adds 1 to the sum.
    if (k <= odd){
        cout << sum + k << "\n";
        return;
    }

    for (auto &x : a) if (x & 1) x++;

    k -= odd;
    sum += odd; 

    int loss = INF;
    for (auto x : a){
        loss = min(loss, x - f(x, k));
    }
    
    sort(a.begin(), a.end());

    // Splitting 4 moves between two equal smallest 6's gives [4, 4]: loss 4.
    if (n >= 2 && k == 4 && a[0] == 6 && a[1] == 6){
        loss = min(loss, 4LL);
    }

    sum -= loss;

    cout << sum << "\n";
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}
Tester's code (C++)
#ifndef LOCAL
#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx,avx2,sse,sse2,sse3,sse4,popcnt,fma")
#endif
#include <bits/stdc++.h>
using namespace std;

int32_t main() {
    ios_base::sync_with_stdio(false);   cin.tie(nullptr);

    int T;  cin >> T;
    while(T-- > 0) {
        int N;  cin >> N;
        int64_t K;  cin >> K;
        vector<int64_t> A(N);
        int64_t sum = 0, count = 0;
        for(auto &a: A) {
            cin >> a;
            if((a & 1) && K > 0) {
                ++a, --K;
            }
            sum += a;
            if(a == 6)
                ++count;
        }
        int64_t res = sum - A[0] + 1;
        if(count > 1 && K == 4)
            res = max(res, sum - 4);

        int ops = 0;
        for(auto &a: A) {
            auto here = sum - a;
            ops = 0;
            while(ops < K && a > 1) {
                if(a & 1)   ++a;
                else    a >>= 1;
                ++ops;
            }
            res = max(res, here + a + ((K - ops) & 1));
        }
        cout << res << '\n';
    }

    return 0;
}
Editorialist's code (Python)
import sys
input = sys.stdin.readline
for _ in range(int(input())):
    n, k = map(int, input().split())
    a = sorted(map(int, input().split()))
    for i in range(n):
        if a[i]%2 == 1 and k > 0:
            a[i] += 1
            k -= 1
    tot = sum(a)
    ans = 0
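    # Candidate: split moves evenly between the two equal smallest elements.
    if n >= 2 and a[0] == a[1]:
        cur, rem, x = sum(a[2:]), k, a[0]
        while rem > 1 and x > 1:  # one operation on each copy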
            rem -= 2
            if x%2 == 0: x //= 2
            else: x += 1
        rem %= 4
        if rem == 0: ans = cur + 2*x
        elif rem == 1: ans = (cur + x + x//2 if x%2 == 0 else cur + 2*x + 1)
        elif rem == 2: ans = cur + 4
        else: ans = cur + 3
    # Candidate: all k moves on a[i] alone.
    for i in range(n):
        rem = k
        cur = tot - a[i]
        while rem > 0 and a[i] > 1:
            if a[i]%2 == 0: a[i] //= 2
            else: a[i] += 1
            rem -= 1
        if rem%2 == 1: a[i] += 1
        ans = max(ans, cur + a[i])
    print(ans)