MNUSE - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: iceknight1093
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

Simple

PREREQUISITES:

Basic Math

PROBLEM:

Given N, K, S, M, find the minimum number of times M must be used in a choice of K values from [1, N] that sum to S.

EXPLANATION:

It turns out that there are several different ways to approach this problem.
Here is one of them.

First, note that M+M = 2M = (M-1) + (M+1).
This means that if we ever have two copies of M, we can replace them with two elements that are not equal to M while keeping the sum the same.
The only time this is not possible is when M = 1 or M = N, in which case the values M-1 and M+1 respectively are not valid choices.
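
As a quick illustration of this swap, here is a minimal Python sketch. The helper name `reduce_copies` is hypothetical (it is not part of the intended solution), and it assumes 1 < M < N so that both M-1 and M+1 are valid choices.

```python
def reduce_copies(values, m, n):
    """Replace pairs of m by (m-1, m+1) until at most one copy of m remains.

    Hypothetical helper for illustration only; assumes 1 < m < n so that
    both m-1 and m+1 lie in [1, n].
    """
    assert 1 < m < n
    others = [v for v in values if v != m]
    copies = len(values) - len(others)
    while copies >= 2:
        # m + m == (m - 1) + (m + 1), so the total is unchanged
        others += [m - 1, m + 1]
        copies -= 2
    return sorted(others + [m] * copies)


before = [3, 3, 3, 5, 1]
after = reduce_copies(before, m=3, n=5)
print(after, sum(before) == sum(after))
```

Each swap removes two copies of M without changing the length or the sum, which is why at most one copy can ever be necessary in this case.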

So, let’s set aside M = 1 and M = N for now, and work with just 1 \lt M \lt N.


When 1 \lt M \lt N, we know that we’ll never need two or more copies.
So, we only need to check if using 0 copies is possible - if it’s not, the answer is 1.

To check if using 0 copies is possible, we want to know whether we can choose K values from among [1, M-1] and [M+1, N] such that their sum equals S.

One way to perform this check is as follows.
Suppose we pick L elements from [1, M-1] and R = (K-L) elements from [M+1, N].
Then,

  • The L elements from the first interval can be chosen to have any sum in the range
    [L, L\cdot (M-1)].
    This should be easy to see: L and L\cdot (M-1) are the lowest and highest possible sums, and any value between them can be reached by starting with L ones and then repeatedly incrementing a value by 1 (though not exceeding M-1) till we reach a state of all (M-1).
  • Similarly, the R elements from the second interval can be chosen to have any sum in the range
    [R\cdot (M+1), R\cdot N].
  • Exactly the same logic applies to merging the sums formed by these two intervals as well!
    That is, the lowest possible sum we can obtain is L + R\cdot (M+1), the highest is L\cdot (M-1) + R\cdot N, and everything in between these two values is attainable as well.
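
The three claims above can be sanity-checked by brute force on small inputs. The sketch below is only a verification aid under the assumption 1 < M < N; the function names `reachable_sums` and `check_contiguous` are made up for this illustration.

```python
from itertools import combinations_with_replacement


def reachable_sums(count, low, high):
    """All sums of `count` values (repetition allowed) taken from [low, high]."""
    pool = range(low, high + 1)
    return {sum(c) for c in combinations_with_replacement(pool, count)}


def check_contiguous(n, m, k):
    """Brute-force check of the interval claims for every split L + R = k.

    Assumes 1 < m < n, so both [1, m-1] and [m+1, n] are non-empty.
    """
    for L in range(k + 1):
        R = k - L
        left = reachable_sums(L, 1, m - 1)
        right = reachable_sums(R, m + 1, n)
        combined = {a + b for a in left for b in right}
        lo = L + R * (m + 1)
        hi = L * (m - 1) + R * n
        # every integer in [lo, hi] must be attainable, and nothing else
        if combined != set(range(lo, hi + 1)):
            return False
    return True


print(check_contiguous(n=6, m=3, k=4))
```

This is exponential in K and only meant for tiny values; the actual solution never enumerates the choices.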

So, let lo_L = L + (K-L)\cdot (M+1) and hi_L = L\cdot (M-1) + (K-L)\cdot N denote the minimum and maximum possible sums given that we pick exactly L elements from [1, M-1]. For this fixed L, it’s possible to form the sum S if and only if lo_L \le S \le hi_L.

We can now try every choice of L from 0 to K, and if S satisfies the requisite condition for any one of them the answer is 0; otherwise the answer is 1.
This gives us a solution in \mathcal{O}(K) time.
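
The check for 0 copies can be sketched as a small standalone function. The name `zero_copies_possible` is an assumption made for this example; the logic mirrors the loop over L described above and assumes 1 < M < N.

```python
def zero_copies_possible(n, k, s, m):
    """Return True if k values from [1, m-1] and [m+1, n] can sum to s.

    Sketch only; assumes 1 < m < n.
    """
    for L in range(k + 1):
        R = k - L
        lo = L + R * (m + 1)        # L ones and R copies of m+1
        hi = L * (m - 1) + R * n    # L copies of m-1 and R copies of n
        if lo <= s <= hi:
            return True
    return False


# N = 3, K = 2, M = 2: S = 4 can be written as 1 + 3, avoiding M entirely,
# while S = 3 can only be written as 1 + 2.
print(zero_copies_possible(3, 2, 4, 2))
print(zero_copies_possible(3, 2, 3, 2))
```

When this function returns False for a valid input, the answer in the 1 < M < N case is 1.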


Now, we return to the two cases we set aside: M = 1 and M = N.

Let’s look at M = 1 first.
Suppose we have X occurrences of M.
Then, the remaining K-X elements must all lie in [2, N].
As we’ve seen above, this means the sum of these elements must lie in [2\cdot (K-X), N\cdot (K-X)].
After X occurrences of 1, the remaining sum we want is simply S-X.
So, this X is valid if and only if

2\cdot (K-X) \le S - X \le N\cdot (K-X)

This now allows us to check every possible value of X from 0 to K in increasing order; the first valid X is the answer, giving a solution in \mathcal{O}(K) time.
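
As a small worked example of the M = 1 case, the sketch below tries X = 0, 1, 2, ... and returns the first value satisfying the inequality. The function name `min_ones` is hypothetical, and returning None for an unreachable sum is a choice made for this illustration.

```python
def min_ones(n, k, s):
    """Smallest X such that X ones plus (k - X) values from [2, n] sum to s.

    Sketch for the M = 1 case; returns None if no X works
    (which should not happen when k <= s <= n * k).
    """
    for x in range(k + 1):
        rest = k - x
        if 2 * rest <= s - x <= n * rest:
            return x
    return None


# N = 3, K = 3, S = 4:
#   X = 0 needs 6 <= 4        -> invalid
#   X = 1 needs 4 <= 3        -> invalid
#   X = 2 needs 2 <= 2 <= 3   -> valid, e.g. 1 + 1 + 2
print(min_ones(3, 3, 4))
```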

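M = N can be solved similarly: the interval for the other elements is [1, N-1] instead of [2, N], and the remaining sum after X occurrences of N is S - X\cdot N. So, X is valid if and only if (K-X) \le S - X\cdot N \le (N-1)\cdot (K-X).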


It is in fact possible to solve the problem in \mathcal{O}(1) time using a bit of math and/or casework, but this wasn’t required in this problem: since the sum of K across tests is bounded, the above solution is good enough (and much cleaner than having to do any casework).
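
For the curious, here is one possible \mathcal{O}(1) sketch. It is derived from the inequalities above and is not necessarily the derivation the author had in mind. It assumes a valid input, i.e. K \le S \le N\cdot K and 1 \le M \le N. For M = 1, the condition 2\cdot (K-X) \le S-X rearranges to X \ge 2K - S; for M = N, the condition S - X\cdot N \le (N-1)\cdot (K-X) rearranges to X \ge S - (N-1)\cdot K. For 1 \lt M \lt N, both lo_L and hi_L decrease as L grows, so it is enough to test the smallest L with lo_L \le S. The names `min_uses_closed_form` and `brute_force` are made up for this example; the brute force is included only to cross-check the formula on tiny inputs.

```python
from itertools import combinations_with_replacement


def min_uses_closed_form(n, k, s, m):
    """O(1) sketch; assumes k <= s <= n * k and 1 <= m <= n."""
    if m == 1:
        return max(0, 2 * k - s)
    if m == n:
        return max(0, s - (n - 1) * k)
    # lo_L = (m + 1) * k - m * L, so lo_L <= s needs L >= ((m + 1) * k - s) / m
    need = (m + 1) * k - s
    L = max(0, -(-need // m))          # ceiling division, clamped at 0
    hi = n * k - (n - m + 1) * L       # hi_L for that smallest L
    return 0 if s <= hi else 1


def brute_force(n, k, s, m):
    """Exponential reference answer, for cross-checking on tiny inputs only."""
    return min(
        c.count(m)
        for c in combinations_with_replacement(range(1, n + 1), k)
        if sum(c) == s
    )


mismatches = [
    (n, k, s, m)
    for n in range(1, 6)
    for m in range(1, n + 1)
    for k in range(1, 5)
    for s in range(k, n * k + 1)
    if min_uses_closed_form(n, k, s, m) != brute_force(n, k, s, m)
]
print(len(mismatches))
```

The cross-check covers only N \le 5 and K \le 4, so treat the closed form as a sketch to be verified rather than as a reference solution.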

TIME COMPLEXITY:

\mathcal{O}(K) per testcase.

CODE:

Editorialist's code (PyPy3)
for _ in range(int(input())):
    n, k, s, m = map(int, input().split())

    if m == 1:
        for i in range(k+1):
            # [2, n] k-i times
            lo = 2*(k-i)
            hi = n*(k-i)
            if lo <= s - i <= hi:
                print(i)
                break
    elif m == n:
        for i in range(k+1):
            # [1, n-1] k-i times
            lo = (k-i)
            hi = (n-1)*(k-i)
            if lo <= s - i*n <= hi:
                print(i)
                break
    else:
        # ans is either 0 or 1
        ans = 1

        # check 0
        for L in range(k+1):
            R = k-L

            lo = L + (m+1)*R
            hi = L*(m-1) + n*R

            if lo <= s <= hi:
                ans = 0
        print(ans)


C was much tougher than D; I don’t know how people got it. Could you please share the O(1) solution?
