DIDE - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: iceknight1093
Tester: tabr
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

None

PROBLEM:

Alice rolls N standard 6-sided dice. The i-th of them lands on A_i.
At most K times, Alice can flip a die to its opposite (where (1, 6), (2, 5), (3, 4) are opposites).

What’s the maximum sum she can achieve?

EXPLANATION:

First, note that the opposite of face x is 7 - x, so flipping a die showing x changes the sum by (7 - x) - x = 7 - 2x. If A_i is one of \{1, 2, 3\}, this change is positive, so flipping always increases Alice’s sum; if A_i is one of \{4, 5, 6\}, it is negative, so flipping always decreases it.
So Alice will only ever flip dice that originally show 1, 2, or 3.
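The observation above can be checked directly by tabulating the change 7 - 2x for every face. This is a tiny sketch; the helper names `opposite` and `gain` are only for illustration. It shows gains of 5, 3, 1 for faces 1, 2, 3 and negative gains for faces 4, 5, 6.

```python
def opposite(x: int) -> int:
    # Opposite faces of a standard die sum to 7: (1, 6), (2, 5), (3, 4).
    return 7 - x


def gain(x: int) -> int:
    # Change in the total sum when a die showing x is flipped: (7 - x) - x.
    return opposite(x) - x


# Print each face, its opposite, and the gain from flipping it.
for face in range(1, 7):
    print(face, opposite(face), gain(face))
```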

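Next, notice that converting a 1 to a 6 gives the largest “profit” (+5), converting a 2 to a 5 the second largest (+3), and converting a 3 to a 4 the smallest (+1). Every flip uses up one of the K operations no matter which die it touches, so Alice should spend her flips on the most profitable conversions first.
So, Alice’s optimal strategy is quite simple: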

  • Flip as many 1's to 6 as possible.
  • When no 1's remain, flip as many 2's to 5 as possible.
  • When no more 1's and 2's remain, flip as many 3's to 4 as possible.
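Since each phase of this strategy simply drains one face value before moving to the next, the flips can also be counted in bulk instead of one at a time. A minimal sketch, assuming the input is already parsed into an integer `k` and a list `a` of die values; the function name `max_sum` is hypothetical:

```python
from collections import Counter


def max_sum(k: int, a: list[int]) -> int:
    counts = Counter(a)  # counts[x] = number of dice showing x (0 if absent)
    total = sum(a)
    # Only faces 1, 2, 3 are worth flipping; visit them in decreasing order of gain 7 - 2x.
    for face in (1, 2, 3):
        flips = min(k, counts[face])  # flip as many dice of this face as the budget allows
        total += flips * (7 - 2 * face)
        k -= flips
    return total
```

For example, `max_sum(3, [1, 2, 3, 3, 6])` flips the 1, the 2, and one of the 3's, giving 24. After counting, this does only three bulk steps, which mirrors the three bullet points directly.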

This can easily be simulated with a loop.
Let c_x denote the number of occurrences of x in A.
Then, while K \gt 0, repeat the following (decrementing K after each flip):

  • If c_1 \gt 0, decrement it by 1 and increment c_6 by 1.
  • Else, if c_2 \gt 0, decrement it by 1 and increment c_5 by 1.
  • Else, if c_3 \gt 0, decrement it by 1 and increment c_4 by 1.
  • Else, stop.

Finally, the answer is

c_1 + 2c_2 + 3c_3 + 4c_4 + 5c_5 + 6c_6
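As a small worked example (the input is made up for illustration): with A = [1, 2, 3, 3, 6] and K = 3, the loop flips the 1, then the 2, then one of the 3's, leaving c_3 = c_4 = c_5 = 1 and c_6 = 2, so the answer is 3 + 4 + 5 + 2 \cdot 6 = 24. The sketch below traces the same counts, writing the three cases of the loop as one loop over (low, high) face pairs:

```python
from collections import Counter

# Illustrative input, not taken from the problem's tests.
a = [1, 2, 3, 3, 6]
k = 3

c = Counter(a)  # c[x] = number of dice showing x; missing faces count as 0
# Drain 1's into 6's, then 2's into 5's, then 3's into 4's, one flip at a time.
for low, high in ((1, 6), (2, 5), (3, 4)):
    while k > 0 and c[low] > 0:
        c[low] -= 1
        c[high] += 1
        k -= 1

# Evaluate c_1 + 2c_2 + 3c_3 + 4c_4 + 5c_5 + 6c_6.
answer = sum(x * c[x] for x in range(1, 7))
print(answer)  # 24
```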

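An alternate implementation is to sort A in ascending order and then flip elements from left to right as long as K \gt 0, never flipping anything \geq 4 (that would decrease the sum). Sorting places all the 1's first, then the 2's, then the 3's, so this flips dice in exactly the greedy order described above.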
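One way to gain confidence in the sorting approach (and the greedy argument behind it) is to compare it against an exhaustive search on small random inputs. A minimal sketch; `greedy_sorted` follows the sorting idea and `brute_force` tries every set of at most K dice to flip. Both names are hypothetical. Flipping the same die twice restores it, so only sets of distinct dice need to be tried.

```python
import itertools
import random


def greedy_sorted(k: int, a: list[int]) -> int:
    total = 0
    for x in sorted(a):
        if k > 0 and x <= 3:
            x = 7 - x  # flip a small face to its opposite
            k -= 1
        total += x
    return total


def brute_force(k: int, a: list[int]) -> int:
    n = len(a)
    best = sum(a)  # flipping nothing is always allowed
    # Try every subset of at most k distinct dice to flip.
    for size in range(1, min(k, n) + 1):
        for idx in itertools.combinations(range(n), size):
            flipped = set(idx)
            value = sum(7 - a[i] if i in flipped else a[i] for i in range(n))
            best = max(best, value)
    return best


random.seed(0)
for _ in range(500):
    n = random.randint(1, 7)
    a = [random.randint(1, 6) for _ in range(n)]
    k = random.randint(0, n + 1)
    assert greedy_sorted(k, a) == brute_force(k, a), (k, a)
print("all random tests passed")
```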

TIME COMPLEXITY:

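\mathcal{O}(N) per testcase for the counting approach (the loop performs at most N flips, since each one removes a 1, 2, or 3), or \mathcal{O}(N\log N) with the sorting approach.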

CODE:

Editorialist's code (Python)
for _ in range(int(input())):
    n, k = map(int, input().split())
    a = list(map(int, input().split()))
    # ct[x] = number of dice showing face x (index 0 is unused)
    ct = [0]*7
    for i in range(n):
        ct[a[i]] += 1
    # Greedy: flip 1 -> 6 first, then 2 -> 5, then 3 -> 4, one flip per iteration
    while k > 0:
        if ct[1] > 0:
            ct[1] -= 1
            ct[6] += 1
        elif ct[2] > 0:
            ct[2] -= 1
            ct[5] += 1
        elif ct[3] > 0:
            ct[3] -= 1
            ct[4] += 1
        else:
            break
        k -= 1
    # Answer = sum of x * c_x over all faces x
    ans = 0
    for i in range(1, 7):
        ans += i*ct[i]
    print(ans)
Alternate implementation (Python)
for _ in range(int(input())):
    n, k = map(int, input().split())
    a = list(map(int, input().split()))
    # Ascending order puts the 1's first, then the 2's, then the 3's
    a = sorted(a)
    
    ans = 0
    for i in range(n):
        # No flips left, or flipping would not increase the sum: keep the value
        if k == 0 or a[i] >= 4:
            ans += a[i]
            continue
        # Flip to the opposite face (opposite faces sum to 7)
        a[i] = 7 - a[i]
        k -= 1
        ans += a[i]
    print(ans)