MMA - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: pols_agyi_pols
Tester: kingmessi
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Binary search

PROBLEM:

There are N monsters, the i-th has health H_i.
You start with a strength of X, and can kill the i-th monster only if your current strength is \gt H_i.
After killing the i-th monster, your strength gets set to H_i.

At most once, you can multiply your strength by K.
Find the maximum number of monsters you can kill.

EXPLANATION:

First, let’s solve the problem as if the multiplication operation didn’t exist.

After killing a monster with health H_i, our strength gets set to H_i.
This means the next monster we kill must have health strictly less than H_i.
The monsters we kill must thus have strictly decreasing healths.

The very first monster we kill must have health less than X.
So the maximum number of monsters that can be killed is simply the number of distinct healths less than X: kill one monster of each such health, in decreasing order.
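A minimal Python sketch of this greedy argument; `kills_without_multiply` is a hypothetical helper name, not part of the original solutions. It kills one monster per distinct health below X, largest first, and checks that every kill is legal.

```python
def kills_without_multiply(x, healths):
    """Kill one monster per distinct health below x, largest first.

    Returns the kill order; its length is the answer when the
    multiplication is never used.
    """
    order = sorted({h for h in healths if h < x}, reverse=True)
    strength = x
    for h in order:
        assert strength > h  # each kill is legal: healths strictly decrease
        strength = h         # strength becomes the killed monster's health
    return order


order = kills_without_multiply(4, [1, 3, 3, 5])
print(order, len(order))  # [3, 1] 2
```

Only one of the two monsters with health 3 can be killed, because after the first kill the strength is exactly 3.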


Now, let’s think about what the multiplication operation lets us do.
There are two possibilities: either we multiply right at the start, or we kill a few monsters and then multiply.

The first case is trivial: multiplying right at the start gives us a strength of X\cdot K, after which we can kill one monster for each distinct health that’s less than X\cdot K.
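A sketch of the first case (the name `kills_multiply_at_start` is hypothetical): on the sorted list of distinct healths, a single `bisect_left` from Python's standard `bisect` module returns how many of them are below X\cdot K.

```python
from bisect import bisect_left


def kills_multiply_at_start(x, k, healths):
    """Count distinct healths strictly below x * k (multiply before any kill)."""
    distinct = sorted(set(healths))
    # bisect_left returns how many sorted entries are < x * k
    return bisect_left(distinct, x * k)


print(kills_multiply_at_start(4, 2, [1, 3, 3, 5]))  # 3
```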

The second case requires a bit more thought.
Suppose we multiply right after killing a monster with health H_i, so that our strength is now K\cdot H_i.
Then,

  • We can certainly kill one monster for each distinct health \lt K\cdot H_i.
  • We can also kill one monster for each distinct health that’s \geq H_i and \lt X, on our initial path down to H_i (before multiplying).
  • There is one caveat, however: if there’s exactly one monster with health Y, where both Y\lt K\cdot H_i and H_i \leq Y \lt X, then it should be counted once rather than twice (since we can only kill it once).
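To make the three bullets concrete, here is a deliberately direct set-based count (no binary search yet) for one fixed choice of H_i. The function name `kills_if_multiply_after` and the small test case (healths [1, 3, 3, 5], X = 4, K = 2) are illustrative assumptions, not taken from the original problem.

```python
from collections import Counter


def kills_if_multiply_after(x, k, healths, hi):
    """Kills when `hi` (< x) is the last monster killed before multiplying.

    Direct set-based count of the three bullets; no binary search.
    """
    cnt = Counter(healths)
    # Before multiplying: one kill per distinct health in [hi, x).
    before = {y for y in cnt if hi <= y < x}
    # After multiplying: one kill per distinct health below k * hi.
    after = {y for y in cnt if y < k * hi}
    # A value counted in both passes but present only once is killed only once.
    double_counted = {y for y in before & after if cnt[y] == 1}
    return len(before) + len(after) - len(double_counted)


h = [1, 3, 3, 5]
print(kills_if_multiply_after(4, 2, h, 3))  # 4: kill 3, multiply to 6, kill 5, 3, 1
print(kills_if_multiply_after(4, 2, h, 1))  # 2: kill 3, 1, multiply to 2, nothing left below 2
```

In the first call both copies of 3 die (one in each pass); in the second, the single monster with health 1 lies in both ranges and is subtracted once.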

So, for a fixed H_i being the last kill before multiplication, we want to know the following three quantities; the number of kills is the first plus the second, minus the third:

  1. The number of distinct elements that are \lt K\cdot H_i.
  2. The number of distinct elements that are \lt X and \geq H_i.
  3. The number of elements that appear exactly once and lie in the range [H_i, \min(X, K\cdot H_i)), i.e. are \geq H_i and \lt \min(X, K\cdot H_i).
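Written as a single formula, with D denoting the set of distinct healths and S the set of healths that occur exactly once (notation introduced here for compactness), the number of kills when H_i is the last kill before multiplying is:

$$
\text{kills}(H_i) = \bigl|\{y \in D : y \lt K\cdot H_i\}\bigr| + \bigl|\{y \in D : H_i \leq y \lt X\}\bigr| - \bigl|\{y \in S : H_i \leq y \lt \min(X, K\cdot H_i)\}\bigr|
$$

For example, with healths [1, 3, 3, 5], X = 4, K = 2 and H_i = 3, we have D = \{1, 3, 5\} and S = \{1, 5\}, so the formula gives 3 + 1 - 0 = 4.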

For the hard version, all of these can be found quite easily using binary search. For the first two, store a sorted list of all distinct elements and binary search on it; for the third, store a sorted list of all elements that appear exactly once and binary search on that.
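A minimal sketch of the binary-search step, assuming integer healths and K \geq 1; `build_lists`, `count_in` and `kills_for` are hypothetical helper names. It stores the values that occur exactly once, as described above, whereas the editorialist's code stores the values that occur more than once; both bookkeeping choices yield the same totals.

```python
from bisect import bisect_left
from collections import Counter


def build_lists(healths):
    """Sorted distinct healths, and sorted healths that occur exactly once."""
    cnt = Counter(healths)
    distinct = sorted(cnt)
    singles = sorted(y for y, c in cnt.items() if c == 1)
    return distinct, singles


def count_in(sorted_vals, lo, hi):
    """Number of entries v with lo <= v < hi (0 if the range is empty)."""
    return max(0, bisect_left(sorted_vals, hi) - bisect_left(sorted_vals, lo))


def kills_for(hi, x, k, distinct, singles):
    """Kills when `hi` (assumed < x) is the last kill before multiplying."""
    q1 = bisect_left(distinct, k * hi)          # distinct values < K*H_i
    q2 = count_in(distinct, hi, x)              # distinct values in [H_i, X)
    q3 = count_in(singles, hi, min(x, k * hi))  # singles in [H_i, min(X, K*H_i))
    return q1 + q2 - q3


distinct, singles = build_lists([1, 3, 3, 5])
print(distinct, singles)                      # [1, 3, 5] [1, 5]
print(kills_for(3, 4, 2, distinct, singles))  # 4
print(kills_for(1, 4, 2, distinct, singles))  # 2
```

Each call to `kills_for` performs a constant number of `bisect_left` calls, which is where the \mathcal{O}(\log N) per H_i comes from.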

So, for a fixed H_i the answer can be found in \mathcal{O}(\log N) time, since we only perform a few binary searches.
Repeat this for every possible H_i \lt X and take the best answer among them.
Don’t forget to also consider multiplying right at the start, which kills one monster for each distinct element \lt X\cdot K.
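Putting the pieces together gives the full routine. The sketch below (function names hypothetical, K \geq 1 assumed) packs it into `solve_fast` and cross-checks it against an exhaustive search `solve_brute` on small random inputs; the random ranges are arbitrary testing choices, not the problem’s constraints.

```python
import random
from bisect import bisect_left
from collections import Counter
from functools import lru_cache


def solve_fast(x, k, healths):
    """O(N log N) count following the editorial (assumes K >= 1)."""
    cnt = Counter(healths)
    distinct = sorted(cnt)
    singles = sorted(y for y, c in cnt.items() if c == 1)
    best = bisect_left(distinct, x * k)  # multiply right at the start
    for hi in distinct:
        if hi >= x:
            break  # healths >= x can never be killed before multiplying
        end = min(x, k * hi)
        q1 = bisect_left(distinct, k * hi)                         # distinct < K*H_i
        q2 = bisect_left(distinct, x) - bisect_left(distinct, hi)  # distinct in [H_i, X)
        q3 = max(0, bisect_left(singles, end) - bisect_left(singles, hi))  # singles in [H_i, end)
        best = max(best, q1 + q2 - q3)
    return best


def solve_brute(x, k, healths):
    """Exhaustive search over every kill order and every moment to multiply."""

    @lru_cache(maxsize=None)
    def go(remaining, strength, used):
        best = 0
        if not used:
            best = go(remaining, strength * k, True)  # multiply now
        for i, h in enumerate(remaining):
            # remaining is sorted, so skip duplicates of the previous value
            if h < strength and (i == 0 or remaining[i - 1] != h):
                rest = remaining[:i] + remaining[i + 1:]
                best = max(best, 1 + go(rest, h, used))
        return best

    return go(tuple(sorted(healths)), x, False)


random.seed(1)
for _ in range(500):
    n = random.randint(1, 6)
    h = [random.randint(1, 8) for _ in range(n)]
    x, k = random.randint(1, 9), random.randint(1, 4)
    assert solve_fast(x, k, h) == solve_brute(x, k, h), (x, k, h)
print("all random tests passed")
```

The brute force explores every kill order, so it is only practical for a handful of monsters; it is meant for testing the formula, not for submission.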

TIME COMPLEXITY:

\mathcal{O}(N\log N) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define ll long long

int main(){
	ios_base::sync_with_stdio(false);
	cin.tie(NULL);
    ll kitne_cases_hain;
    kitne_cases_hain=1;
    cin>>kitne_cases_hain;
    while(kitne_cases_hain--){          
        ll n,x,k;
        cin>>n>>x>>k;
        ll y;
        map <ll,ll> m;
        for(int i=0;i<n;i++){
            cin>>y;
            m[y]++;
        }
        vector <ll> v;
        for(auto it:m){
            v.push_back(it.first);
        }
        vector <ll> u;
        ll f=v.size();
        ll ans=0,cnt,z=0,maxi;
        for(int i=f-1;i>=0;i--){
            if(v[i]>=x){
                cnt=(lower_bound(v.begin(),v.end(),x*k)-v.begin());
            }else{
                z++;
                cnt=i+z;
                maxi=v[i]*k;
                if(maxi>x){
                    cnt+=(lower_bound(v.begin(),v.end(),maxi)-lower_bound(v.begin(),v.end(),x));
                    maxi=x;
                }
                if(m[v[i]]>1){
                    u.push_back(-1*v[i]);
                }
                y=u.size();
                if(y){
                    cnt+=(y-(upper_bound(u.begin(),u.end(),-1*maxi)-u.begin()));    
                }
            }
            ans=max(ans,cnt);
        }
        cout<<ans<<"\n";
    }
	return 0;
}

Tester's code (C++)
//Har Har Mahadev
#include<bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp> // Common file
#include <ext/pb_ds/tree_policy.hpp>
#define ll long long
#define int unsigned int
#define rep(i,a,b) for(int i=a;i<b;i++)
#define rrep(i,a,b) for(int i=a;i>=b;i--)
#define repin rep(i,0,n)
#define precise(i) cout<<fixed<<setprecision(i)
#define vi vector<int>
#define si set<int>
#define mii map<int,int>
#define take(a,n) for(int j=0;j<n;j++) cin>>a[j];
#define give(a,n) for(int j=0;j<n;j++) cout<<a[j]<<' ';
#define vpii vector<pair<int,int>>
#define db double
#define be(x) x.begin(),x.end()
#define pii pair<int,int>
#define pb push_back
#define pob pop_back
#define ff first
#define ss second
#define lb lower_bound
#define ub upper_bound
#define bpc(x) __builtin_popcountll(x) 
#define btz(x) __builtin_ctz(x)
using namespace std;

using namespace __gnu_pbds;

typedef tree<int, null_type, less<int>, rb_tree_tag,tree_order_statistics_node_update> ordered_set;
typedef tree<pair<int, int>, null_type,less<pair<int, int> >, rb_tree_tag,tree_order_statistics_node_update> ordered_multiset;

const long long INF=1e18;
const long long M=1e9+7;
const long long MM=998244353;
  
int power( int N, int M){
    int power = N, sum = 1;
    if(N == 0) sum = 0;
    while(M > 0){if((M & 1) == 1){sum *= power;}
    power = power * power;M = M >> 1;}
    return sum;
}

int smn = 0;
 
void solve()
{
    int n,d,k;
    cin >> n >> d >> k;

    smn += n;
    vi a(n);
    take(a,n);

    set<int> s;
    mii m;
    for(auto x : a)s.insert(x),m[x]++;
    vi v;
    for(auto x : s)v.pb(x);
    n = v.size();
    vi pf(n);
    int ans = 0;
    rep(i,0,n){
        if(m[v[i]] > 1 || v[i] >= d)pf[i] = 1;
        if(i)pf[i] += pf[i-1];
        if(d*k > v[i])ans++;
    }
    int mx = 0;
    int tot = 0;
    rep(i,0,n){
        if(v[i] >= d)break;
        tot++;
        int l = i,r = n;
        while(r > l+1){
            int m = (l+r)/2;
            if(v[m] < k*v[i])l = m;
            else r = m;
        }
        mx = max(mx,pf[l]-(i?pf[i-1]:0));
    }
    ans = max(ans,tot + mx);
    cout << ans << "\n";

}

signed main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    #ifdef NCR
        init();
    #endif
    #ifdef SIEVE
        sieve();
    #endif
    int t;
    cin >> t;
    // t = inp.readInt(1,100'000);
    // inp.readEoln();
    while(t--)
        solve();
    // inp.readEof();
    return 0;
}

Editorialist's code (Python)
import bisect, collections

for _ in range(int(input())):
    n, x, k = map(int, input().split())
    a = list(map(int, input().split()))
    
    freq = collections.Counter(a)
    distinct, twice = [], []
    for y in freq.keys():
        distinct.append(y)
        if freq[y] > 1: twice.append(y)
    distinct = sorted(distinct)
    twice = sorted(twice)
    
    ans = 0
    for y in distinct:
        if y >= x: break

        # take till y -> multiply -> take remaining
        # one copy of everything that's < max(y*k, x)
        # another copy of every value in twice (appears more than once) that is >= y and < min(y*k, x)

        mn, mx = min(y*k, x), max(y*k, x)
        cur = bisect.bisect_left(distinct, mx)
        cur += bisect.bisect_left(twice, mn) - bisect.bisect_left(twice, y)
        ans = max(ans, cur)
    
    # also consider multiplying right at the start
    ans = max(ans, bisect.bisect_left(distinct, x*k))
    print(ans)