PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: iceknight1093
Tester: sushil2006
Editorialist: iceknight1093
DIFFICULTY:
Simple
PREREQUISITES:
Greedy
PROBLEM:
You have an array A. K times, you must do the following:
- Choose two different elements of A, and append their average (rounded up) to A.
Find the minimum possible sum of A in the end.
EXPLANATION:
Define \text{avg}(x, y) to be \frac{x+y}{2} if x+y is even, and \frac{x+y+1}{2} otherwise; equivalently, \text{avg}(x, y) = \left\lceil \frac{x+y}{2} \right\rceil.
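In Python the rounded-up average is a single integer expression, the same one the editorialist's code uses. The helper name `avg` is just an illustrative choice; the loop is a quick sanity check against the case-by-case definition for small non-negative values.

```python
def avg(x, y):
    # Average of x and y, rounded up: the value appended by one operation.
    return (x + y + 1) // 2


# Cross-check against the case-by-case definition (even sum / odd sum).
for x in range(0, 20):
    for y in range(0, 20):
        s = x + y
        expected = s // 2 if s % 2 == 0 else (s + 1) // 2
        assert avg(x, y) == expected

print(avg(3, 4), avg(2, 4), avg(1, 10))  # 4 3 6
```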
Our operation is to choose A_i and A_j, and append \text{avg}(A_i, A_j) to A.
Since the objective is to minimize the sum, it makes sense to append elements that are as small as possible.
Let m_1 and m_2 be the two smallest elements of A.
The smallest value we can possibly append on the first move is \text{avg}(m_1, m_2), since \text{avg} is non-decreasing in each argument; so if K = 1, doing this is optimal.
What about K \gt 1?
The natural greedy approach now is to just repeat this process: that is, at each step, compute m_1 and m_2 to be the two smallest elements of A, and append \text{avg}(m_1, m_2) to A.
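This greedy also gives the global optimum. Intuitively, every operation should involve the minimum value, and once that is fixed it is best to pair it with the second minimum. More carefully: \text{avg} is non-decreasing in each argument, so by induction the two smallest elements under the greedy are never larger than the two smallest elements under any other sequence of choices, and hence the value the greedy appends at each step is never larger either.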
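A minimal sketch for checking the greedy on tiny inputs: `greedy_naive` simulates it directly, and `exhaustive` tries every sequence of index pairs (exponential, so only usable for very small N and K). Both function names and the random input ranges are assumptions made for this illustration, not part of the original solution.

```python
import heapq
import random
from itertools import combinations


def avg(x, y):
    # Average of x and y, rounded up.
    return (x + y + 1) // 2


def greedy_naive(a, k):
    # Direct greedy: K times, append avg of the two current smallest elements.
    # heapq.nsmallest scans the whole (growing) array, so this is O((N + K) * K).
    a = list(a)
    for _ in range(k):
        m1, m2 = heapq.nsmallest(2, a)
        a.append(avg(m1, m2))
    return sum(a)


def exhaustive(a, k):
    # Try every choice of two distinct indices at every step (tiny inputs only).
    if k == 0:
        return sum(a)
    return min(
        exhaustive(a + [avg(a[i], a[j])], k - 1)
        for i, j in combinations(range(len(a)), 2)
    )


random.seed(1)
for _ in range(200):
    n = random.randint(2, 4)
    k = random.randint(1, 3)
    a = [random.randint(1, 12) for _ in range(n)]
    assert greedy_naive(a, k) == exhaustive(a, k), (a, k)
print("greedy matched exhaustive search on all random cases")
```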
However, this is slow: computing m_1 and m_2 takes \mathcal{O}(N) time, so doing it for all K operations costs \mathcal{O}(NK), which is far too much. (In fact, since the array keeps growing, the true complexity is slightly worse: \mathcal{O}((N+K)\cdot K).)
To optimize this, we make a couple of observations.
First, note that for any integers x and y we have \min(x, y) \leq \text{avg}(x, y) \leq \max(x, y): the exact average \frac{x+y}{2} lies between them, and rounding it up cannot carry it past \max(x, y), which is itself an integer.
So, if m_1 is the smallest element and m_2 is the second smallest element, we have
m_1 \leq \text{avg}(m_1, m_2) \leq m_2.
This tells us that after appending the average, the minimum element remains m_1, while the new second smallest element is \text{avg}(m_1, m_2).
This immediately improves our complexity to \mathcal{O}(N + K), since after computing m_1 and m_2 for the first time, m_1 remains constant while the change in m_2 can be computed in constant time.
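A minimal sketch of this \mathcal{O}(N + K) version, assuming the input is a Python list with at least two elements; `solve_linear` is a hypothetical name. Only m_2 is updated, because each appended value becomes the new second minimum while m_1 stays fixed.

```python
import heapq


def solve_linear(a, k):
    # O(N + K): the minimum m1 never changes, and each appended value
    # becomes the new second smallest element, so only m2 is updated.
    m1, m2 = heapq.nsmallest(2, a)  # a single pass over A
    total = sum(a)
    for _ in range(k):
        m2 = (m1 + m2 + 1) // 2  # avg(m1, m2), rounded up
        total += m2
    return total


print(solve_linear([1, 10], 5))  # 28
```

This is still linear in K, which is exactly the remaining bottleneck discussed next.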
However, since K \leq 10^9 this is still not fast enough.
To improve this further, note that each operation roughly halves the distance between m_1 and m_2, since \text{avg}(m_1, m_2) lies (up to rounding) in the middle of m_1 and m_2.
Precisely, if d = m_2 - m_1, then after one operation the new gap is \text{avg}(m_1, m_2) - m_1 = \left\lceil \frac{d}{2} \right\rceil.
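Using \text{avg}(x, y) = \lceil (x+y)/2 \rceil and the fact that an integer can be moved inside a ceiling, the new gap d' after one operation works out exactly; in particular d' = d only when d \leq 1:

```latex
d' = \text{avg}(m_1, m_2) - m_1
   = \left\lceil \frac{m_1 + m_2}{2} \right\rceil - m_1
   = \left\lceil \frac{m_1 + m_2 - 2m_1}{2} \right\rceil
   = \left\lceil \frac{d}{2} \right\rceil
```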
So, after about \log_2 d operations d becomes at most 1, and from then on it never changes (since \lceil 1/2 \rceil = 1 and \lceil 0/2 \rceil = 0).
In other words, after very few operations neither m_1 nor m_2 changes any more, and this happens exactly when m_2 \leq m_1 + 1.
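To see the stabilization concretely, here is a small sketch that records the values taken by the second minimum under the greedy; `second_min_sequence` is a hypothetical helper and the starting pair (3, 12) is an arbitrary example.

```python
def second_min_sequence(m1, m2):
    # Values taken by the second minimum under the greedy, until m2 <= m1 + 1.
    seq = [m2]
    while m2 > m1 + 1:
        m2 = (m1 + m2 + 1) // 2  # avg(m1, m2)
        seq.append(m2)
    return seq


seq = second_min_sequence(3, 12)
print(seq)                   # [12, 8, 6, 5, 4]
print([v - 3 for v in seq])  # gaps: [9, 5, 3, 2, 1]
# Once m2 <= m1 + 1, another operation leaves m2 unchanged:
print((3 + 4 + 1) // 2)      # 4
```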
Once this stage is reached, we don’t need to simulate further: \text{avg}(m_1, m_2) will be a constant.
So, if there are X operations left to do, we can just add X\cdot\text{avg}(m_1, m_2) to the answer.
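As a small worked example (the input is made up for illustration), take A = [1, 10] and K = 1000. Here m_1 = 1 and the second minimum goes 10 \to 6 \to 4 \to 3 \to 2, after which m_2 = m_1 + 1, so the remaining X = 996 operations each append \text{avg}(1, 2) = 2:

```latex
\text{answer} = \underbrace{(1 + 10)}_{\text{initial}}
              + \underbrace{(6 + 4 + 3 + 2)}_{\text{simulated}}
              + \underbrace{996 \cdot 2}_{X \cdot \text{avg}(m_1, m_2)}
              = 11 + 15 + 1992 = 2018
```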
Analyzing the complexity, we see that:
- Computing m_1 and m_2 is done once, and takes \mathcal{O}(N) time (or \mathcal{O}(N\log N) if you sort).
- Then, we directly simulate operations until m_2 \leq m_1 + 1 holds. If the initial value of m_2 - m_1 is d, this takes at most \lceil \log_2 d \rceil moves (none at all if d \leq 1), and since d \leq 10^9 this is at most 30. So, we do at most \min(30, K) operations here.
- The remaining part takes constant time, since we add a single term of the form X\cdot \text{avg}(m_1, m_2) to the answer, where X is K minus the number of operations done so far.
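A quick sketch to check the bound on simulated steps, using the gap update d \to \lceil d/2 \rceil; `steps_until_stable` is a hypothetical name. For d \geq 1 the count equals \lceil \log_2 d \rceil, which Python's `int.bit_length` gives as `(d - 1).bit_length()`.

```python
def steps_until_stable(d):
    # Greedy operations needed before the gap d = m2 - m1 drops to at most 1,
    # applying d -> ceil(d / 2) once per operation.
    count = 0
    while d > 1:
        d = (d + 1) // 2
        count += 1
    return count


# For d >= 1 the count is ceil(log2(d)) == (d - 1).bit_length().
for d in range(1, 10_000):
    assert steps_until_stable(d) == (d - 1).bit_length()

print(steps_until_stable(10**9))  # 30
```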
So, the overall complexity is \mathcal{O}(N + \min(K, \log{10^9})), which is easily fast enough.
TIME COMPLEXITY:
\mathcal{O}(N + \min(K, \log{10^9})) per testcase.
CODE:
Editorialist's code (PyPy3)
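```python
for _ in range(int(input())):
    n, k = map(int, input().split())
    a = list(map(int, input().split()))
    a.sort()
    # x = smallest element (never changes), y = current second smallest
    x, y = a[0], a[1]
    ans = sum(a)
    # Simulate only while the gap y - x is at least 2: O(log(y - x)) steps
    while k > 0 and x+1 < y:
        m = (x + y + 1) // 2  # avg(x, y), rounded up
        ans += m
        y = m                 # the appended value is the new second smallest
        k -= 1
    # Now y <= x + 1, so every remaining operation appends the same value
    ans += k * ((x + y + 1) // 2)
    print(ans)
```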
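To gain some confidence in the shortcut, here is a minimal randomized stress test, offered as a sketch rather than part of the original editorial. `fast` wraps the editorialist's logic as a function and `naive` is the direct greedy simulation; both names and the random input ranges are assumptions for illustration.

```python
import heapq
import random


def fast(a, k):
    # Same logic as the editorialist's code, wrapped as a function.
    b = sorted(a)
    x, y = b[0], b[1]
    ans = sum(b)
    while k > 0 and x + 1 < y:
        y = (x + y + 1) // 2
        ans += y
        k -= 1
    return ans + k * ((x + y + 1) // 2)


def naive(a, k):
    # Direct O((N + K) * K) greedy simulation.
    a = list(a)
    for _ in range(k):
        m1, m2 = heapq.nsmallest(2, a)
        a.append((m1 + m2 + 1) // 2)
    return sum(a)


random.seed(0)
for _ in range(1000):
    n = random.randint(2, 6)
    k = random.randint(1, 40)
    a = [random.randint(1, 10**9) for _ in range(n)]
    assert fast(a, k) == naive(a, k), (a, k)
print("all random tests passed")
```

Keeping K at most 40 leaves room for both the simulated phase (at most 30 steps here) and the bulk phase to be exercised.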