SPMISS - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: raysh07
Tester: iceknight1093
Editorialist: iceknight1093

DIFFICULTY:

Cakewalk

PREREQUISITES:

None

PROBLEM:

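A game has N missions. Each mission is either a normal mission or a special mission; the types are given as a binary string S, where S_i = 0 marks a normal mission and S_i = 1 marks a special one.
The i-th mission gives you A_i coins for completing it.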

All normal missions are accessible from the beginning.
However, to access the special missions, you must pay a one-time cost of C coins, which is only possible when you currently hold at least C coins.

Find the maximum number of coins you can end up with.
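The rules can be pinned down with a small simulator. This is a hypothetical sketch, not part of the original editorial: it assumes you start with 0 coins, that each mission can be completed at most once, and that the mission types arrive as a string `s` in which `'1'` marks a special mission. The function name `play` and the action encoding (an index to complete a mission, the string `"unlock"` to pay C) are illustrative only.

```python
from typing import List, Union

# Hypothetical action encoding: an int i means "complete mission i",
# the string "unlock" means "pay c to access the special missions".
Action = Union[int, str]


def play(a: List[int], s: str, c: int, actions: List[Action]) -> int:
    """Simulate one play-through and return the final coin count.

    Assumptions: start with 0 coins, each mission is completed at most
    once, and s[i] == '1' marks mission i as special.
    """
    coins = 0
    unlocked = False
    done = set()
    for act in actions:
        if act == "unlock":
            # The one-time payment needs at least c coins in hand.
            if unlocked or coins < c:
                raise ValueError("cannot unlock: already unlocked or too few coins")
            coins -= c
            unlocked = True
        else:
            i = act
            if i in done:
                raise ValueError(f"mission {i} already completed")
            if s[i] == '1' and not unlocked:
                raise ValueError(f"mission {i} is special and still locked")
            coins += a[i]
            done.add(i)
    return coins
```

For example, with a = [3, 2, 4], s = "010" and C = 4, the plan [0, 2, "unlock", 1] ends with 3 + 4 - 4 + 2 = 5 coins, while trying to unlock right after mission 0 fails because 3 < 4.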

EXPLANATION:

Since A_i \ge 1 for every i, completing an available mission can only increase our coin count, so it is always optimal to complete every mission we have access to.

This leaves only two candidate strategies.
The first is to complete all normal missions and stop there.
The second is to complete all normal missions, pay C to unlock the special missions, and then complete all special missions too.
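The claim that only these two strategies matter can be checked on tiny inputs with an exhaustive search. The sketch below (the name `brute_force` is hypothetical) tries every subset of missions, with and without unlocking. It relies on one observation: a plan that unlocks is feasible exactly when the normal missions it includes sum to at least C, because special missions earn nothing before the unlock and all chosen normal missions can be finished first.

```python
from itertools import product
from typing import List


def brute_force(a: List[int], s: str, c: int) -> int:
    """Best final coin count over every subset of missions (exponential).

    A plan that unlocks is valid when its chosen normal missions sum to
    at least c, since all of them can be completed before paying.
    """
    n = len(a)
    best = 0
    for chosen in product((False, True), repeat=n):
        normal = sum(a[i] for i in range(n) if chosen[i] and s[i] == '0')
        special = sum(a[i] for i in range(n) if chosen[i] and s[i] == '1')
        uses_special = any(chosen[i] and s[i] == '1' for i in range(n))
        if not uses_special:
            best = max(best, normal)                # never unlock
        if normal >= c:
            best = max(best, normal - c + special)  # unlock after the normals
    return best
```

Comparing it against \max(S_1, S_1 - C + S_2) (when C \le S_1) or S_1 (otherwise) on all small inputs is a quick way to gain confidence in the case analysis.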

Let S_1 be the total income from normal missions, and S_2 be the total income from special missions.
Then,

  1. If C \le S_1, we can always unlock the special missions: clear all normal missions first to collect S_1 \ge C coins, and then pay C.
    So the answer is \max(S_1, S_1 - C + S_2).
    Here, S_1 is the income from clearing only the normal missions, and S_1 - C + S_2 is the income from additionally unlocking and clearing all special missions; the second option is worse exactly when S_2 \lt C.
  2. If C \gt S_1, we can never unlock the special missions, because the normal missions are the only source of coins before unlocking and they give at most S_1 \lt C coins.
    So, in this case the answer is simply S_1 itself.
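Once S_1 and S_2 are known, the two cases translate directly into a few lines. A minimal sketch, assuming the mission types come as a string `s` with `'0'` for normal and `'1'` for special missions (as in the editorialist's code below); the name `max_coins` is illustrative.

```python
from typing import List


def max_coins(a: List[int], s: str, c: int) -> int:
    """Case analysis from the editorial; s[i] == '1' marks a special mission."""
    s1 = sum(x for x, t in zip(a, s) if t == '0')  # S_1: normal missions
    s2 = sum(x for x, t in zip(a, s) if t == '1')  # S_2: special missions
    if c <= s1:
        # Case 1: unlocking is affordable, take the better of the two plans.
        return max(s1, s1 - c + s2)
    # Case 2: C > S_1, so the special missions can never be unlocked.
    return s1
```

For instance, with A = [3, 2, 4], S = "010" and C = 4 we get S_1 = 7 and S_2 = 2, so unlocking is affordable but not worth it (S_2 < C) and the answer is 7. Changing the middle value to 9 gives S_2 = 9 and the answer 7 - 4 + 9 = 12. With A = [1, 100], S = "01" and C = 5 the unlock is unaffordable, so the answer is S_1 = 1.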

TIME COMPLEXITY:

\mathcal{O}(N) per testcase.

CODE:

Editorialist's code (PyPy3)
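for _ in range(int(input())):
    n, c = map(int, input().split())
    a = list(map(int, input().split()))
    s = input()  # s[i] == '0': normal mission, otherwise special

    # s1 = total from normal missions (S_1), s2 = total from special missions (S_2)
    s1, s2 = 0, 0
    for i in range(n):
        if s[i] == '0': s1 += a[i]
        else: s2 += a[i]

    # Case 1: C <= S_1, unlocking is affordable, so take the better option.
    # Case 2: C > S_1, the special missions stay locked.
    if c <= s1: print(max(s1, s1-c+s2))
    else: print(s1)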
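Because `input()` is called several times per test case, a common alternative in Python for large inputs is to read everything in one call and walk over the tokens. The sketch below keeps the editorialist's logic; the helper names `solve` and `main` are hypothetical, and the input layout (T, then for each test case N and C, the values A_1, ..., A_N, and the string S) is inferred from the code above rather than quoted from the statement.

```python
import sys
from typing import List


def solve(data: str) -> str:
    """Answer every test case in `data` using the editorial's case analysis.

    Assumed layout (inferred from the editorialist's code): T, then per test
    case "N C", the N values A_i, and the type string S ('0' = normal).
    """
    it = iter(data.split())
    t = int(next(it))
    answers: List[str] = []
    for _ in range(t):
        n, c = int(next(it)), int(next(it))
        a = [int(next(it)) for _ in range(n)]
        s = next(it)
        s1 = sum(x for x, kind in zip(a, s) if kind == '0')  # S_1
        s2 = sum(x for x, kind in zip(a, s) if kind != '0')  # S_2
        answers.append(str(max(s1, s1 - c + s2) if c <= s1 else s1))
    return "\n".join(answers)


def main() -> None:
    # One read of all of stdin instead of several input() calls per test.
    sys.stdout.write(solve(sys.stdin.read()) + "\n")
```

On a judge, a call to `main()` at the bottom of the file runs it on standard input. Since `solve` takes and returns plain strings, it can also be checked directly: for the input "2\n3 4\n3 2 4\n010\n2 5\n1 100\n01\n" it returns "7\n1".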