TWOAVG - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: frtransform
Tester & Editorialist: iceknight1093

DIFFICULTY:

2353

PREREQUISITES:

Math, (maybe) binary search

PROBLEM:

You’re given two arrays A and B, of lengths N and M respectively, each containing integers between 1 and K.
In one move, you can insert an element between 1 and K into either A or B.
Find the minimum number of moves needed to make the average of A strictly larger than the average of B, or report -1 if this is impossible.

EXPLANATION:

First, if K = 1, then A and B can only ever contain 1's, so both will always have an average of exactly 1.
Making the average of A strictly larger is therefore impossible, and the answer is -1.

If K \gt 1, it’s always possible, as we’ll prove below.

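As a preliminary observation, it’s never optimal to insert anything less than K into A, or anything more than 1 into B: replacing an inserted element of A by K can only increase the average of A, and replacing an inserted element of B by 1 can only decrease the average of B, while the number of moves stays the same.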

So, suppose we insert x copies of K into A and y copies of 1 into B.
The average of A is then strictly larger than the average of B exactly when

\frac{A_1 + A_2 + \ldots + A_N + x\cdot K}{N+x} \gt \frac{B_1 + B_2 + \ldots + B_M + y}{M+y}

Our aim is to minimize x+y, while ensuring that the above inequality holds.
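A minimal Python sketch of this check, assuming we only track the sums and lengths of the two arrays (the function name a_beats_b is hypothetical, not from the official solutions). Both denominators are positive, so cross-multiplying keeps everything in exact integer arithmetic and avoids floating-point trouble with ties:

```python
def a_beats_b(sa, n, sb, m, k, x, y):
    """Return True if inserting x copies of k into A and y ones into B makes
    avg(A) strictly larger than avg(B).

    sa, sb: sums of A and B; n, m: their lengths (assumed n, m >= 1).
    """
    new_sum_a, new_len_a = sa + x * k, n + x
    new_sum_b, new_len_b = sb + y, m + y
    # Both lengths are positive, so a/b > c/d  <=>  a*d > c*b.
    return new_sum_a * new_len_b > new_sum_b * new_len_a


# A = [1, 2], B = [2, 2], K = 2  ->  sa = 3, n = 2, sb = 4, m = 2
print(a_beats_b(3, 2, 4, 2, 2, 1, 1))  # False: both averages are 5/3
print(a_beats_b(3, 2, 4, 2, 2, 1, 2))  # True: 5/3 > 6/4
```

In this small example the pair (x, y) = (1, 1) only ties the averages, so at least one more move is needed.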

Suppose we fix the value of x.
Then the left side of the inequality becomes a constant, say C_x.
We’d then like to find the smallest y \geq 0 such that C_x \gt \frac{B_1 + \ldots + B_M + y}{M+y}.

This can be found directly by rearranging the inequality, or by binary search: with the average of A fixed, increasing y means adding more 1's to B, which can only decrease its average, so the condition is monotone in y.
Either way, a fixed x is handled in \mathcal{O}(1) time (direct math) or in logarithmic time over the searched range of y (binary search; we bound this range below).
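For the direct route, a sketch of the algebra, writing S_x = A_1 + \ldots + A_N + x\cdot K, L_x = N + x and s_B = B_1 + \ldots + B_M (these names are introduced here only for brevity). Every A_i is at least 1, so S_x \geq L_x:

```latex
\begin{aligned}
\frac{S_x}{L_x} > \frac{s_B + y}{M + y}
  &\iff S_x (M + y) > L_x (s_B + y) \\
  &\iff S_x M - L_x s_B > y\,(L_x - S_x) \\
  &\iff y\,(S_x - L_x) > L_x s_B - S_x M \\
y_{\min} &= \max\left(0,\ \left\lfloor \frac{L_x s_B - S_x M}{S_x - L_x} \right\rfloor + 1\right) \quad \text{when } S_x > L_x
\end{aligned}
```

If S_x = L_x (only possible when x = 0 and every A_i is 1), the average of A is 1 and no y works. These are the same quantities the editorialist’s Python code below calls num = S_x M - L_x s_B and den = L_x - S_x.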

In fact, we only need to check a limited number of values of x.
Note that if we choose x = M+1 and y = N+1, then:

  • The average of A is \displaystyle \frac{A_1 + \ldots + A_N + K\cdot (M+1)}{N+M+1}

  • The average of B is \displaystyle \frac{B_1 + \ldots + B_M + N+1}{N+M+1}

  • Both arrays now have N+M+1 elements, so the denominators are equal and it’s enough to compare the numerators.

  • We have:

    • A_i \geq 1 for each i, so A_1 + \ldots + A_N \geq N
    • B_i \leq K for each i, so B_1 + \ldots + B_M \leq K\cdot M
    • K \gt 1
  • Putting all three together, A_1 + \ldots + A_N + K\cdot (M+1) \geq N + K\cdot M + K \gt K\cdot M + N + 1 \geq B_1 + \ldots + B_M + N+1, where the strict step uses K \gt 1. This is exactly what we wanted.

So, x = M+1 and y = N+1 already form a valid solution.
In particular, the answer is at most N+M+2.
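As a quick sanity check of this construction (not a replacement for the proof above), the following Python sketch draws small random arrays with K \gt 1 and verifies that x = M+1, y = N+1 always works. The function names, value ranges and trial count are arbitrary choices for illustration:

```python
import random


def bound_pair_works(a, b, k):
    """Insert M+1 copies of k into A and N+1 ones into B, then compare averages."""
    n, m = len(a), len(b)
    x, y = m + 1, n + 1
    # Both arrays now have N+M+1 elements, so comparing sums is enough.
    assert n + x == m + y
    return sum(a) + x * k > sum(b) + y


def random_trials(trials=1000, seed=0):
    """Return True if the construction works on every random instance tried."""
    rng = random.Random(seed)
    for _ in range(trials):
        k = rng.randint(2, 6)  # the construction needs K > 1
        a = [rng.randint(1, k) for _ in range(rng.randint(1, 5))]
        b = [rng.randint(1, k) for _ in range(rng.randint(1, 5))]
        if not bound_pair_works(a, b, k):
            return False
    return True


print(random_trials())  # True
```

The tightest case is A consisting only of 1's and B only of K's, where the two sums differ by exactly K - 1.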

An optimal solution has x + y \leq N+M+2, and hence x \leq N+M+2, so it suffices to check each x from 0 to N+M+2, find the best y for it, and take the minimum of x+y across them all.
For the same reason, y never needs to exceed N+M+2, so the binary search for y runs in \mathcal{O}(\log(N+M)).
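A sketch of the binary-search variant in Python, assuming N, M \geq 1; min_y_binary and solve are hypothetical names, not taken from the official solutions. For each x it searches y in [0, N+M+2] and relies on the condition being monotone in y:

```python
def min_y_binary(sa, n, sb, m, k, x, hi):
    """Smallest y in [0, hi] with (sa + x*k)/(n + x) > (sb + y)/(m + y),
    or None if no y in that range works."""
    s, length = sa + x * k, n + x

    def ok(y):
        # Cross-multiplied comparison; monotone in y because adding
        # 1's to B never raises its average.
        return s * (m + y) > (sb + y) * length

    if not ok(hi):
        return None
    lo = 0
    while lo < hi:  # invariant: ok(hi) is True
        mid = (lo + hi) // 2
        if ok(mid):
            hi = mid
        else:
            lo = mid + 1
    return lo


def solve(a, b, k):
    if k == 1:
        return -1
    n, m = len(a), len(b)
    sa, sb = sum(a), sum(b)
    bound = n + m + 2  # x = M+1, y = N+1 is always valid
    best = bound
    for x in range(bound + 1):
        y = min_y_binary(sa, n, sb, m, k, x, bound)
        if y is not None:
            best = min(best, x + y)
    return best


print(solve([1, 2], [2, 2], 2))  # 3
```

For A = [1, 2], B = [2, 2], K = 2, the best y for x = 0, 1, 2 is 3, 2, 1 respectively, so every choice costs at least 3 moves.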

There are only \mathcal{O}(N+M) values of x, and each is processed in constant or logarithmic time, so this is fast enough.

TIME COMPLEXITY:

\mathcal{O}(N+M) per test case when y is computed directly, or \mathcal{O}((N+M)\log(N+M)) with binary search.

CODE:

Author's code (C++)
#include <bits/stdc++.h>

using namespace std;

void test_case(){
    int n, m, k;
    cin >> n >> m >> k;

    vector<int> a(n), b(m);
    for (int i = 0; i < n; i++) cin >> a[i];
    for (int i = 0; i < m; i++) cin >> b[i];

    if (k == 1){
        cout << -1 << endl;
        return;
    }

    long long sumA = accumulate(a.begin(), a.end(), 0LL);
    long long sumB = accumulate(b.begin(), b.end(), 0LL);

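    // x = M+1, y = N+1 is always valid, so the answer is at most n + m + 2.
    int ans = n + m + 2;
    // Two pointers: the smallest valid Y never increases as X grows,
    // since adding more K's to A can only raise its average.
    int Y = n + m + 2;
    for (int X = 0; X <= n + m + 2; X++){
        // Lower Y while (X, Y) still gives avg(A) > avg(B) (cross-multiplied).
        while (Y >= 0 && (sumA + (long long) X * k) * (m + Y) > (sumB + Y) * (n + X)) Y--;
        // Step back to the smallest valid Y for this X. If no Y was valid,
        // X + Y exceeds n + m + 2 here, so ans is unaffected.
        Y++;
        ans = min(ans, X + Y);
    }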

    cout << ans << endl;
}

int main(){
    ios_base::sync_with_stdio(false);

#ifdef LOCAL
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
#endif

    int T;
    cin >> T;

    while (T--){
        test_case();
    }

    return 0;
}
Editorialist's code (Python)
for _ in range(int(input())):
	n, m, k = map(int, input().split())
	a = list(map(int, input().split()))
	b = list(map(int, input().split()))
	if k == 1:
		print(-1)
		continue
	
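	sa, sb = sum(a), sum(b)
	ans = n+m+2  # x = m+1, y = n+1 always works
	for x in range(n+m+3):
		# avg(A) > avg(B) after the moves  <=>  num > y * den  (cross-multiplied)
		num = (sa + k*x)*m - (n+x)*sb
		den = (n+x) - (sa + k*x)  # never positive, since every A_i >= 1
		if den == 0: continue  # avg(A) stays 1, so no y can work
		y = 1 + (num // den)  # den < 0, so this is the smallest y with y > num / den
		if num > 0: y = 0  # avg(A) already beats avg(B) without inserting into B
		ans = min(ans, x + y)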
	print(ans)
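To gain some confidence in the formula, here’s a hypothetical randomized cross-check in Python. closed_form repackages the editorialist’s loop as a function, and bfs_any_values is an exhaustive BFS that may insert any value from 1 to K into either array, so it also exercises the "insert only K into A and only 1 into B" observation. The array sizes, range of K and trial count are small arbitrary choices:

```python
import random
from collections import deque


def closed_form(n, m, k, a, b):
    """The editorialist's per-x formula, wrapped in a function."""
    if k == 1:
        return -1
    sa, sb = sum(a), sum(b)
    ans = n + m + 2
    for x in range(n + m + 3):
        num = (sa + k * x) * m - (n + x) * sb
        den = (n + x) - (sa + k * x)
        if den == 0:
            continue
        y = 0 if num > 0 else 1 + num // den
        ans = min(ans, x + y)
    return ans


def bfs_any_values(a, b, k):
    """Fewest insertions of ANY value in 1..k (into A or B) making avg(A) > avg(B)."""
    if k == 1:
        return -1
    max_moves = len(a) + len(b) + 2  # the editorial's upper bound
    start = (sum(a), len(a), sum(b), len(b))  # state: (sumA, lenA, sumB, lenB)
    seen = {start}
    queue = deque([(start, 0)])
    while queue:
        (sa, la, sb, lb), moves = queue.popleft()
        if sa * lb > sb * la:  # avg(A) > avg(B), cross-multiplied
            return moves
        if moves == max_moves:
            continue
        for v in range(1, k + 1):
            for nxt in ((sa + v, la + 1, sb, lb), (sa, la, sb + v, lb + 1)):
                if nxt not in seen:
                    seen.add(nxt)
                    queue.append((nxt, moves + 1))
    return None  # unreachable if the N+M+2 bound holds


def cross_check(trials=200, seed=1):
    """Return None if both methods agree on every random case, else the first mismatch."""
    rng = random.Random(seed)
    for _ in range(trials):
        k = rng.randint(1, 4)
        n, m = rng.randint(1, 3), rng.randint(1, 3)
        a = [rng.randint(1, k) for _ in range(n)]
        b = [rng.randint(1, k) for _ in range(m)]
        if closed_form(n, m, k, a, b) != bfs_any_values(a, b, k):
            return (n, m, k, a, b)
    return None


print(cross_check())  # None when every trial agrees
```

BFS visits states in order of move count, so the first state satisfying the condition gives the true minimum without assuming which values are inserted.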