AVGAPP - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: iceknight1093
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

Simple

PREREQUISITES:

Greedy

PROBLEM:

You have an array A. You must perform the following operation exactly K times:

  • Choose two elements of A at different positions (their values may be equal), and append their average (rounded up) to A.

Find the minimum possible sum of A in the end.

EXPLANATION:

Define \text{avg}(x, y) to be \frac{x+y}{2} if x+y is even, and \frac{x+y+1}{2} otherwise.
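In code, the rounded-up average needs no floating point. A minimal Python sketch (the name `avg` simply mirrors the notation above): Python's `//` floors, so adding 1 before halving rounds up for every integer sum.

```python
def avg(x: int, y: int) -> int:
    """Average of x and y, rounded up: (x+y)/2 if x+y is even, else (x+y+1)/2."""
    # // is floor division, so floor((s + 1) / 2) == ceil(s / 2) for any integer s.
    return (x + y + 1) // 2


print(avg(3, 5))   # even sum 8: exactly 4
print(avg(1, 10))  # odd sum 11: rounds up to 6
```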
Our operation is to choose A_i and A_j, and append \text{avg}(A_i, A_j) to A.

Since the objective is to minimize the sum, it makes sense to append elements that are as small as possible.

Let m_1 and m_2 be the two smallest elements of A.
The absolute smallest element we can append on the first move is \text{avg}(m_1, m_2); so if K = 1 doing this would be optimal.

What about K \gt 1?
The natural greedy approach is to simply repeat this process: at each step, let m_1 and m_2 be the two smallest elements of A, and append \text{avg}(m_1, m_2) to A.
It’s not hard to see that this also gives the global optimum. Intuitively, every operation should involve the minimum value, and once the minimum is fixed, it's optimal to pair it with the second minimum. Moreover, since \text{avg} is non-decreasing in each argument, keeping the second minimum as small as possible never makes later operations worse.
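The optimality claim is easy to sanity-check by exhaustive search on tiny inputs. The sketch below is a hypothetical checker, not part of the original solution: `brute_force_min_sum` tries every sequence of operations (exponential time), and `greedy_min_sum` is the direct greedy, which scans the whole array for the two smallest values at every step.

```python
import heapq
import random
from itertools import combinations


def avg(x, y):
    # Average rounded up.
    return (x + y + 1) // 2


def brute_force_min_sum(a, k):
    """Minimum final sum over every possible sequence of k operations (tiny inputs only)."""
    if k == 0:
        return sum(a)
    # Try every pair of distinct positions i < j.
    return min(
        brute_force_min_sum(a + [avg(a[i], a[j])], k - 1)
        for i, j in combinations(range(len(a)), 2)
    )


def greedy_min_sum(a, k):
    """Greedy: always append avg of the two smallest current elements."""
    a = list(a)
    for _ in range(k):
        m1, m2 = heapq.nsmallest(2, a)  # one pass over the (growing) array
        a.append(avg(m1, m2))
    return sum(a)


random.seed(1)
for _ in range(300):
    arr = [random.randint(1, 30) for _ in range(random.randint(2, 4))]
    ops = random.randint(1, 3)
    assert brute_force_min_sum(arr, ops) == greedy_min_sum(arr, ops), (arr, ops)
print("greedy matched brute force on all random cases")
```

The exhaustive search branches over every pair at every step, so keep N and K very small when using it.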


However, this is too slow to simulate directly: computing m_1 and m_2 takes \mathcal{O}(N) time, so K operations cost \mathcal{O}(NK), which is far too much. (In fact, since the array's length keeps increasing, the true complexity is slightly worse at \mathcal{O}((N+K)\cdot K).)

To optimize this, we make a couple of observations.
First, note that for any integers x, y we have \min(x, y) \leq \text{avg}(x, y) \leq \max(x, y), since the average of two integers lies between them, and rounding up cannot push it above \max(x, y) because \max(x, y) is itself an integer.

So, if m_1 is the smallest element and m_2 is the second smallest element, we have
m_1 \leq \text{avg}(m_1, m_2) \leq m_2.
This tells us that after appending the average, the minimum element remains m_1, while the new second smallest element is \text{avg}(m_1, m_2).

This immediately improves our complexity to \mathcal{O}(N + K), since after computing m_1 and m_2 for the first time, m_1 remains constant while the change in m_2 can be computed in constant time.
However, since K \leq 10^9 this is still not fast enough.
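A minimal sketch of this \mathcal{O}(N + K) version, assuming N \geq 2 (the function name is hypothetical). It tracks only m_1 and m_2; `heapq.nsmallest` finds them in one pass instead of sorting.

```python
import heapq


def min_sum_linear_k(a, k):
    """O(N + K): m1 stays fixed, and each new average becomes the new m2."""
    m1, m2 = heapq.nsmallest(2, a)
    total = sum(a)
    for _ in range(k):
        new = (m1 + m2 + 1) // 2  # m1 <= new <= m2
        total += new
        m2 = new                  # the appended value is the new second smallest
    return total


print(min_sum_linear_k([1, 10], 5))  # 11 + 6 + 4 + 3 + 2 + 2 = 28
```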

To further improve this, note that each operation roughly halves the distance between m_1 and m_2, since \text{avg}(m_1, m_2) sits at the midpoint of m_1 and m_2 (rounded up).
Precisely, if d = m_2 - m_1, then after one operation the new distance is \left\lceil \frac{d}{2} \right\rceil.
So, after about \log_2 d operations, d will become at most 1, and it never changes again after that.
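The halving claim can be made exact. Writing d = m_2 - m_1 \geq 0, a short derivation using only the definition of \text{avg} above:

```latex
\begin{aligned}
\text{avg}(m_1, m_2) &= \left\lceil \frac{m_1 + (m_1 + d)}{2} \right\rceil
                      = m_1 + \left\lceil \frac{d}{2} \right\rceil, \\
d' &= \text{avg}(m_1, m_2) - m_1 = \left\lceil \frac{d}{2} \right\rceil, \\
d \leq 1 &\implies d' = d, \qquad d \geq 2 \implies d' < d, \\
2^{j-1} < d \leq 2^{j} &\implies 2^{j-2} < d' \leq 2^{j-1}.
\end{aligned}
```

By the last line, a gap with 2^{j-1} < d \leq 2^j reaches 1 after exactly j = \lceil \log_2 d \rceil operations.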

That is, after very few operations, neither m_1 nor m_2 will ever change again; this happens exactly once m_2 \leq m_1 + 1.
Once this stage is reached, we don’t need to simulate further: \text{avg}(m_1, m_2) will be a constant.
So, if there are X operations left to do, we can just add X\cdot\text{avg}(m_1, m_2) to the answer.
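For illustration, take a small made-up input A = [1, 10] with K = 5:

```latex
\begin{aligned}
&m_1 = 1,\ m_2 = 10: && \text{avg}(1, 10) = 6,\ \text{avg}(1, 6) = 4,\ \text{avg}(1, 4) = 3,\ \text{avg}(1, 3) = 2, \\
&m_2 = 2 \leq m_1 + 1: && X = 5 - 4 = 1 \text{ operation left, each adding } \text{avg}(1, 2) = 2, \\
&\text{final sum:} && (1 + 10) + (6 + 4 + 3 + 2) + 1 \cdot 2 = 28.
\end{aligned}
```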

Analyzing the complexity, we see that:

  • Computing m_1 and m_2 is done once, and takes \mathcal{O}(N) time (or \mathcal{O}(N\log N) if you sort).
  • Then, we directly simulate some operations till m_2 \leq m_1 + 1 happens.
    This will not take much more than \log_2 10^9 moves: specifically, if the initial value of (m_2 - m_1) is d, then it requires no more than 1 + \log_2 d moves, and since d \leq 10^9 this is bounded by about 32.
    So, we do at most \min(32, K) operations here.
  • The remaining part is constant time since we add a single term of the form X\cdot \text{avg}(m_1, m_2) to the answer for appropriate X (i.e. K minus the number of operations done so far).
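To check the bound numerically, here is a small hypothetical helper that counts how many operations are simulated before the gap d = m_2 - m_1 drops to at most 1, using the fact that each operation turns d into \lceil d/2 \rceil.

```python
def simulated_steps(d: int) -> int:
    """Operations simulated before the gap d = m2 - m1 becomes <= 1."""
    steps = 0
    while d > 1:
        d = (d + 1) // 2  # new gap after one operation: ceil(d / 2)
        steps += 1
    return steps


print(simulated_steps(10**9))  # 30, comfortably within the bound of about 32
```

Since \lceil d/2 \rceil is non-decreasing in d, the count is largest at d = 10^9, so 30 iterations is the worst case under the d \leq 10^9 assumption.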

So, the overall complexity is \mathcal{O}(N + \min(K, \log{10^9})) which is easily fast enough.

TIME COMPLEXITY:

\mathcal{O}(N + \min(K, \log{10^9})) per testcase.

CODE:

Editorialist's code (PyPy3)
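```python
for _ in range(int(input())):
    n, k = map(int, input().split())
    a = list(map(int, input().split()))

    # Only the two smallest values matter; m_1 = x never changes.
    a.sort()
    x, y = a[0], a[1]
    ans = sum(a)
    # Simulate while the gap y - x is at least 2 (at most ~32 iterations).
    while k > 0 and x+1 < y:
        m = (x + y + 1) // 2  # avg(x, y), rounded up
        ans += m
        y = m                 # the appended value is the new second minimum
        k -= 1
    # Now y <= x + 1, so every remaining operation appends the same value.
    ans += k * ((x + y + 1) // 2)
    print(ans)
```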
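For local testing, the same logic can be wrapped in a pure function and stress-tested against the plain \mathcal{O}(N + K) simulation on random inputs. This is a hypothetical harness (the names `min_final_sum` and `min_sum_slow` are made up here), not part of the editorialist's submission.

```python
import heapq
import random


def min_final_sum(a, k):
    """Same logic as the editorialist's loop, as a function of (a, k)."""
    x, y = heapq.nsmallest(2, a)
    ans = sum(a)
    while k > 0 and x + 1 < y:  # simulate only while the gap is at least 2
        y = (x + y + 1) // 2
        ans += y
        k -= 1
    return ans + k * ((x + y + 1) // 2)  # all remaining operations append the same value


def min_sum_slow(a, k):
    """Reference: track m1, m2 for all k operations (O(N + K))."""
    x, y = heapq.nsmallest(2, a)
    ans = sum(a)
    for _ in range(k):
        y = (x + y + 1) // 2
        ans += y
    return ans


random.seed(7)
for _ in range(1000):
    arr = [random.randint(1, 10**6) for _ in range(random.randint(2, 6))]
    ops = random.randint(0, 60)
    assert min_final_sum(arr, ops) == min_sum_slow(arr, ops), (arr, ops)
print(min_final_sum([1, 2], 10**9))  # gap already 1: 3 + 2 * 10**9 = 2000000003
```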