CANDIES3 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: frtransform
Testers: nishant403, satyam_343
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Harmonic series/sieve-like methods, prefix sums/binary search

PROBLEM:

You have an array A of length N, where A_i (1 \leq A_i \leq M) denotes the budget of the i-th customer.

There are infinitely many candies, and you choose a single price between 1 and M (inclusive) that applies to every candy.
The i-th customer will buy as many candies as their budget allows.

You also have an array C of length M, where C_i denotes the bonus you receive for each candy of price i bought.
Find the maximum possible bonus if the price is chosen appropriately.

EXPLANATION:

Suppose we fix the price of the candy, P.

Of course, we could look at each element of A and compute how much each person buys, but that takes \mathcal{O}(N) per price and \mathcal{O}(NM) overall, which is too slow.
Let’s look at it from a different perspective.

How many people will buy exactly k candies when the price is P?
It’s not hard to see that this is exactly the number of people whose A_i lies in the range [kP, (k+1)\cdot P-1].
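As a quick sanity check of this claim, here is a minimal sketch that compares the range test against the direct definition. The function name and the sample budgets are made up for illustration and are not part of the problem's tests.

```python
def buyers_of_exactly_k(budgets, price, k):
    """Count customers whose budget lies in [k*price, (k+1)*price - 1]."""
    lo, hi = k * price, (k + 1) * price - 1
    return sum(1 for a in budgets if lo <= a <= hi)


budgets = [3, 7, 8, 10]  # hypothetical sample budgets
price = 4
for k in range(0, 3):
    # Direct definition: a customer buys exactly floor(a / price) candies.
    direct = sum(1 for a in budgets if a // price == k)
    assert buyers_of_exactly_k(budgets, price, k) == direct
```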

If we were able to quickly count the number of people whose A_i lies in this range (say, in \mathcal{O}(1)), then we could iterate over every possible value of k and add k\cdot\text{count} to the number of candies we sell.

Computing this count quickly is fairly simple.

How?

If we sort the A_i, finding the number of elements in a given range is a simple exercise in binary searching, and can be done in \mathcal{O}(\log N).
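A minimal sketch of the binary search version, using Python's standard `bisect` module. The helper name `count_in_range_sorted` and the sample data are assumptions for illustration only.

```python
from bisect import bisect_left, bisect_right


def count_in_range_sorted(sorted_budgets, lo, hi):
    """Number of values v in the sorted list with lo <= v <= hi."""
    # bisect_right gives the count of values <= hi,
    # bisect_left gives the count of values < lo.
    return bisect_right(sorted_budgets, hi) - bisect_left(sorted_budgets, lo)


sorted_budgets = sorted([10, 3, 8, 7])  # hypothetical sample budgets
print(count_in_range_sorted(sorted_budgets, 8, 11))  # budgets 8 and 10
```

Each query costs two binary searches, which is where the \mathcal{O}(\log N) comes from.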

In fact, we can utilize the constraints to do even better.
Notice that A_i \leq M.

So, let \text{freq} be an array of length M, where \text{freq}[r] denotes the number of people with A_i = r.
Then, what we want is \text{freq}[kP] + \text{freq}[kP+1] + \ldots + \text{freq}[kP+P-1] (ignoring any indices larger than M).
This is a range sum on \text{freq}, which can be computed in \mathcal{O}(1) using prefix sums.
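The prefix sum idea can be sketched as follows. The helper names are hypothetical, and the sketch assumes 1 \leq lo, which always holds here because the smallest range starts at P \geq 1.

```python
def build_prefix(budgets, m):
    """pref[r] = number of customers with budget <= r, for 0 <= r <= m."""
    pref = [0] * (m + 1)
    for a in budgets:
        pref[a] += 1          # frequency of each budget value
    for r in range(1, m + 1):
        pref[r] += pref[r - 1]  # turn frequencies into prefix sums
    return pref


def count_in_range_prefix(pref, lo, hi):
    """Number of budgets in [lo, hi], assuming lo >= 1. Runs in O(1)."""
    m = len(pref) - 1
    hi = min(hi, m)           # no budget exceeds m
    if lo > hi:
        return 0
    return pref[hi] - pref[lo - 1]


pref = build_prefix([3, 7, 8, 10], 10)  # hypothetical sample budgets
print(count_in_range_prefix(pref, 8, 11))
```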

Now, notice that for a fixed P, we don’t need to check too many values of k.
In particular, we can stop as soon as M\lt kP, because A_i \leq M anyway.
This means we only need to check each k from 1 to \left\lfloor \frac{M}{P}\right\rfloor.
Each check is done in \mathcal{O}(1).
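Putting the pieces together for one fixed price, a sketch of the inner loop might look like this. The function names and the sample budgets are illustrative assumptions; the structure mirrors the solutions in the CODE section.

```python
def prefix_counts(budgets, m):
    """pref[r] = number of customers with budget <= r."""
    pref = [0] * (m + 1)
    for a in budgets:
        pref[a] += 1
    for r in range(1, m + 1):
        pref[r] += pref[r - 1]
    return pref


def candies_sold(pref, price):
    """Total candies bought at the given price, in O(m / price) steps."""
    m = len(pref) - 1
    total = 0
    for k in range(1, m // price + 1):
        lo = k * price
        hi = min(m, lo + price - 1)   # clamp the last block to m
        total += k * (pref[hi] - pref[lo - 1])
    return total


budgets = [3, 7, 8, 10]               # hypothetical sample budgets
pref = prefix_counts(budgets, 10)
print(candies_sold(pref, 4))          # prints 5 (0 + 1 + 2 + 2)
```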

Doing this for every P from 1 to M brings our overall time complexity to

\mathcal{O}\left (\sum_{P=1}^M \left\lfloor \frac{M}{P}\right\rfloor\right )

which is, rather famously, equal to \mathcal{O}(M\log M).
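The reason is that \left\lfloor \frac{M}{P}\right\rfloor \leq \frac{M}{P}, so the sum is at most M times the harmonic number H_M, and H_M \leq 1 + \ln M. The short sketch below counts the inner-loop iterations directly and checks them against that bound; the function name is made up for illustration.

```python
import math


def total_inner_iterations(m):
    """Number of (P, k) pairs visited: sum of floor(m / P) for P = 1..m."""
    return sum(m // p for p in range(1, m + 1))


for m in (10, 1000, 100000):
    steps = total_inner_iterations(m)
    bound = m * (1 + math.log(m))  # M * H_M <= M * (1 + ln M)
    assert steps <= bound
    print(m, steps, round(bound))
```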

TIME COMPLEXITY:

\mathcal{O}(M\log M) per testcase.

CODE:

Setter's code (C++)
#include <bits/stdc++.h>
#include "stdio.h"

using namespace std;

#define SZ(s) ((int)s.size())
#define all(x) (x).begin(), (x).end()
#define lla(x) (x).rbegin(), (x).rend()
#define bpc(x) __builtin_popcount(x)
#define bpcll(x) __builtin_popcountll(x)
#define MP make_pair
#define endl '\n'

mt19937 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

typedef long long ll;
const int MOD = 1e9 + 7;
const int N = 1e6 + 3e2;

int sumn = 0, summ = 0;
void solve(){
    int n, m;
    cin >> n >> m;
    
    sumn += n;
    summ += m;

    vector<int> c(m + 1), cnt(m + 1, 0);
    while (n--){
        int x;
        cin >> x;
        assert(1 <= x && x <= m);
        cnt[x]++;
    }

    for (int i = 1; i <= m; i++){
        cin >> c[i];
        assert(1 <= c[i] && c[i] <= 1000000);
    }

    for (int i = 2; i <= m; i++) cnt[i] += cnt[i - 1];

    long long ans = 0;
    for (int p = 1; p <= m; p++){
        long long candies = 0;
        for (int x = 1; x <= m / p; x++){
            int l = x * p, r = min(m, (x + 1) * p - 1);
            candies += (ll)(cnt[r] - cnt[l - 1]) * x;
        }
        ans = max(ans, candies * c[p]);
    }

    cout << ans << endl;
}

int main(){
    clock_t startTime = clock();
    ios_base::sync_with_stdio(false);

#ifdef LOCAL
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    freopen("error.txt", "w", stderr);
#endif

    int test_cases = 1;
    cin >> test_cases;
    
    assert(1 <= test_cases && test_cases <= 10000);
    
    for (int test = 1; test <= test_cases; test++){
        // cout << (solve() ? "YES" : "NO") << endl;
        solve();
    }
    
    assert(sumn <= 100000);
    assert(summ <= 100000);

    cerr << "Time: " << int((double) (clock() - startTime) / CLOCKS_PER_SEC * 1000) << " ms" << endl;

    return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
 
 
/*
------------------------Input Checker----------------------------------
*/
 
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
 
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }
 
            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }
 
            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
 
 
/*
------------------------Main code starts here----------------------------------
*/
 
#define int long long 

const int MAX_T = 1e4;
const int MAX_N = 1e5;
const int MAX_SUM_N = 1e5;
const int MAX_C = 1e6;
 
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
 
int sum_n = 0;
int sum_m = 0;
int max_n = 0;
int max_m = 0;
int max_ans = 0;

void solve()
{
    int n,m;
    n = readIntSp(1,MAX_N);
    max_n = max(max_n, n);
    sum_n += n;
    assert(sum_n <= MAX_SUM_N);

    m = readIntLn(1,MAX_N);
    max_m = max(max_m, m);
    sum_m += m;
    assert(sum_m <= MAX_SUM_N);
    
    int a[n];
    for(int i=0;i<n;i++) {
        if(i != n - 1) {
            a[i] = readIntSp(1 , m);
        } else {
            a[i] = readIntLn(1 , m);
        }
    }
    
    int c[m];
    for(int i=0;i<m;i++) {
        if(i != m - 1) {
            c[i] = readIntSp(1 , MAX_C);
        } else {
            c[i] = readIntLn(1 , MAX_C);
        }
    }
    
    
    vector<int> fre(m + 1 , 0);
    for(int i=0;i<n;i++) {
        fre[a[i]]++;
    }
    
    vector<int> fre_pref_sum(m + 1 , 0);
    for(int i=1;i<=m;i++) {
        fre_pref_sum[i] = fre[i] + fre_pref_sum[i - 1];
    }
    
    int ans = 0;
    int ans_ind = -1;
    
    //iterate over P 
    for(int p=1;p<=m;p++) {
        
        // A[i]/p is bounded by M/p 
        int cur_sum = 0;
        
        for(int i=1;i<=(m/p);i++) {
            //how many values in a provide answer i (i.e. (A[j]/p) = i)
            int min_val = (p * i);
            int max_val = min(m , (p * (i + 1)) - 1);
            
            int val_count = fre_pref_sum[max_val] - fre_pref_sum[min_val - 1];
            cur_sum += i * val_count;
        }
        
        int cur_ans = cur_sum * c[p - 1];
        ans = max(ans , cur_ans);
    
        if(ans == cur_ans) {
            ans_ind = p;
        }
    }
    
    cerr << "Optimal p : " << ans_ind << " for given m : " << m << '\n'; 
    
    max_ans = max(max_ans , ans);
    
    cout << ans << '\n';
}
 
signed main()
{
    int t = 1;
    t = readIntLn(1,MAX_T);
    
    for(int i=1;i<=t;i++)
    {     
       solve();
    }
    
    assert(getchar() == -1);
 
    cerr<<"SUCCESS\n";
    cerr<<"Tests : " << t << '\n';
    cerr<<"Sum of lengths A : " << sum_n << '\n';
    cerr<<"Maximum length A : " << max_n << '\n';
    cerr<<"Sum of lengths B : " << sum_m << '\n';
    cerr<<"Maximum length B : " << max_m << '\n';
    cerr<<"Maximum answer : " << max_ans << '\n';
}
Editorialist's code (Python)
for _ in range(int(input())):
	n, m = map(int, input().split())
	a = list(map(int, input().split()))
	c = list(map(int, input().split()))
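	# pref[r] first counts customers with budget exactly r, then becomes a prefix sum
	pref = [0]*(m+1)
	for x in a: pref[x] += 1
	for i in range(1, m+1): pref[i] += pref[i-1]
	ans = 0
	for x in range(1, m+1):  # x is the candidate price P
		val = 0  # candies sold at price x
		for y in range(x, m+1, x):  # y = k*x, the smallest budget that buys k = y//x candies
			R = min(m, y+x-1)  # largest budget that still buys exactly k candies
			val += y//x * (pref[R] - pref[y-1])
		ans = max(ans, val*c[x-1])  # c is 0-indexed, so the bonus for price x is c[x-1]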
	print(ans)

I managed to get an O(N \sqrt{N}) solution to pass (it barely fit under the TL though, at 0.96s). Basically, I computed, for each A_i, the values of P at which \left\lfloor \frac{A_i}{P} \right\rfloor changes (there are only O(\sqrt{A_i}) such values). Then I iterated over P, applying the required changes while going from P to P + 1, and maximising the answer at each step.
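One possible reading of this approach, sketched under assumptions: the linked submission is not reproduced here, so the function names and the use of a difference array to apply the "changes" are guesses at the idea, not the commenter's actual code.

```python
def quotient_blocks(a):
    """Yield (first_p, last_p, q) such that a // p == q for first_p <= p <= last_p.

    Only prices p <= a are covered; for p > a the quotient is 0.
    There are O(sqrt(a)) blocks.
    """
    p = 1
    while p <= a:
        q = a // p
        last = a // q         # largest p with the same quotient q
        yield p, last, q
        p = last + 1


def best_bonus_sqrt(budgets, bonus):
    """bonus[p - 1] is the bonus per candy at price p; every budget is <= len(bonus)."""
    m = len(bonus)
    diff = [0] * (m + 2)      # difference array over prices 1..m
    for a in budgets:
        for first, last, q in quotient_blocks(a):
            diff[first] += q
            diff[last + 1] -= q
    best = 0
    sold = 0
    for p in range(1, m + 1):
        sold += diff[p]       # total candies sold at price p
        best = max(best, sold * bonus[p - 1])
    return best


print(best_bonus_sqrt([3, 7, 8, 10], [1, 5, 2, 9, 1, 1, 1, 1, 1, 1]))
```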

My Submission


Yes, I did the same, but my submission passes comfortably in 0.38s:
https://www.codechef.com/viewsolution/86018624

Turns out that removing the #define int long long from your code makes it run in 0.32s: submission.

Maybe this is a reason to consider not using it :slight_smile:


Wow, I’ve never really had TL issues with #define int long long before, I used to think it just doubled the memory usage and that was that. Thanks!

I have a slightly different approach.
I noticed that
C_k\sum_{i=1}^{n}\left \lfloor \frac{A_i}{k}\right \rfloor=C_k\sum_{i=1}^{n}\frac{A_i-A_i\%k}{k}=\frac{C_k}{k}\left(\sum_{i=1}^{n}A_i-\sum_{i=1}^{n}(A_i\%k)\right)
for each k (1\leq k \leq m).

You can find \sum_{i=1}^{n}(A_i\%k) in O(m \log^2(m\log n)) for each k, using prefix sums and binary search.
https://www.codechef.com/viewsolution/86078566
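The identity itself is easy to check numerically. The sketch below only verifies the rewriting of the floor sum in terms of remainders; it computes the remainder sum naively and does not attempt the faster method from the linked submission. The function name and sample budgets are made up for illustration.

```python
def candies_via_remainders(budgets, k):
    """Sum of floor(a / k) over all budgets, computed as (sum a - sum (a % k)) / k."""
    total = sum(budgets)
    remainders = sum(a % k for a in budgets)
    # total - remainders is always divisible by k, so integer division is exact.
    return (total - remainders) // k


budgets = [3, 7, 8, 10]  # hypothetical sample budgets
for k in range(1, 11):
    assert candies_via_remainders(budgets, k) == sum(a // k for a in budgets)
```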


My weird solution passed 6/7 test cases! I just computed the initial sum divided by P, multiplied by C_P, for every 1 <= P <= M. Whenever this value was greater than the best so far, I stored P in a set. This gave me a rough estimate of the value I was looking for. Then I computed the exact answer only for these indices. The solution ran in 0.1s but wasn't correct by any means :rofl:

Is that by any chance the same trick as the one used in floor sum from atcoder library?

Upd: hold on, that’s not O(N log N), that’s O(N log^2 N), with the complexity of the two for loops being the same as that of a sieve. Very misleading claim about logarithmic time…

Yup, my bad.

Can you please explain your approach in more detail?

Can anyone explain what is wrong with the approach used below?
It passed 4/7 test cases.
https://www.codechef.com/viewsolution/86752980