COSS - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: pols_agyi_pols
Tester: watoac2001
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Data structures

PROBLEM:

You have an array W.
For a cost of S coins, you can swap any two elements of W.

After the swaps, you must take as many elements of W from the left as possible, such that their total sum doesn’t exceed M.
Each element taken costs C coins.

Find the minimum possible number of coins that you can spend.

EXPLANATION:

If the sum of W doesn’t exceed M, it’s always possible to take every item.
In this case, it’s optimal to just do no swaps, and the answer is N\cdot C.

Otherwise, it’s not possible to take every item.
Suppose we decide to take exactly i items, for 0 \leq i \lt N.
For this to be possible, we need to perform swaps to ensure that W_1 + W_2 + \ldots + W_{i+1} \gt M holds (so we can’t choose the (i+1)-th item), and also W_1 + W_2 + \ldots + W_i \leq M.

Notice that the second condition can in fact be relaxed: it doesn’t really matter if W_1 + W_2 + \ldots + W_i \leq M or not, since if it was \gt M we wouldn’t be able to take i items at all (and such a case will be considered when we’re looking at lower i anyway).

So, all we need to do now is find the minimum number of swaps needed to make W_1 + W_2 + \ldots + W_{i+1} \gt M.
Let this be \text{opt}_i.
If we know the array \text{opt}, the final answer is just the minimum of i\cdot C + \text{opt}_i\cdot S across all i.
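As a small illustration of this formula, here is a hypothetical helper (not taken from any of the solutions below). It assumes \text{opt} has already been computed, and it uses -1 to mark an index i where no number of swaps can make the first i+1 elements exceed M. That marker is a convention of this sketch only.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// opt[i] = minimum number of swaps that make W_1 + ... + W_{i+1} > M, or -1 if no
// number of swaps can do it (a convention assumed by this sketch). Returns the
// minimum of i*C + opt[i]*S over all valid i, or -1 if there is none.
long long combineAnswer(const std::vector<long long>& opt, long long C, long long S) {
    long long best = -1;
    for (std::size_t i = 0; i < opt.size(); ++i) {
        if (opt[i] < 0) continue;  // the first i+1 elements can never exceed M
        long long cost = static_cast<long long>(i) * C + opt[i] * S;
        if (best < 0 || cost < best) best = cost;
    }
    return best;
}
```

For example, W = [1, 2, 3, 10] with M = 5 gives \text{opt} = [1, 1, 0, 0]; with C = 10 and S = 1 the candidates are 1, 11, 20, 30, so the answer is 1.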


First, let’s look at a slow method of computing \text{opt}_i.
Clearly, the only useful swaps are to swap an element at indices 1\ldots i+1 with another element outside this range, since once i is fixed we only care about the sum of elements and not their order.
This means we never need more than i+1 swaps, at which point we’d have replaced everything in this range anyway.

So, we can simply try every count of swaps between 0 and i+1.
For a fixed count k, it’s obviously best to pick the k smallest elements among [W_1, \ldots, W_{i+1}], and swap them with the k largest elements of [W_{i+2}, \ldots, W_N].

This trivially gives a solution in \tilde{\mathcal{O}}(N^3) by trying every pair of (i, k) and finding the smallest/largest k elements in linearithmic time.
It’s not hard to improve this to something like \mathcal{O}(N^2 \log N): for a fixed i, moving from k to k+1 performs just one more swap, so the sum changes by the difference of a single pair of elements (easily found in constant time if the corresponding prefix and suffix are kept sorted).
This is still too slow, but it’s a reasonable start.
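A minimal sketch of the \mathcal{O}(N^2 \log N) brute force, useful as a reference to test a faster solution against on small arrays. The function name and signature are hypothetical: W, M, C and S are passed as arguments instead of being read from input.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <numeric>
#include <vector>

// Reference brute force: for each i, sort the prefix W[0..i] ascending and the
// suffix W[i+1..N-1] descending, then do k = 0, 1, ... greedy swaps (smallest
// prefix element <-> largest suffix element) until the prefix sum exceeds M.
long long bruteForce(const std::vector<long long>& W, long long M, long long C, long long S) {
    const int n = static_cast<int>(W.size());
    long long total = std::accumulate(W.begin(), W.end(), 0LL);
    if (total <= M) return n * C;  // everything can be taken without swaps

    long long best = -1;
    for (int i = 0; i < n; ++i) {  // try to take exactly i items
        std::vector<long long> pre(W.begin(), W.begin() + i + 1);
        std::vector<long long> suf(W.begin() + i + 1, W.end());
        std::sort(pre.begin(), pre.end());                             // smallest first
        std::sort(suf.begin(), suf.end(), std::greater<long long>());  // largest first
        long long sum = std::accumulate(pre.begin(), pre.end(), 0LL);
        int limit = std::min<int>(pre.size(), suf.size());
        for (int k = 0;; ++k) {
            if (sum > M) {  // k swaps are enough for this i
                long long cost = i * C + k * S;
                if (best < 0 || cost < best) best = cost;
                break;
            }
            if (k == limit) break;  // no more swaps are possible
            sum += suf[k] - pre[k];
        }
    }
    return best;  // i = n-1 always succeeds here, since total > M
}
```

For instance, with W = [1, 2, 3, 10], M = 5, C = 10, S = 1 it returns 1: swap the 1 and the 10, after which not even the first element fits.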


As we noted in the previous section, the only useful swaps are ones that swap elements of [W_1, \ldots, W_{i+1}] with ones outside of that range.
So, finding the minimum number of swaps is equivalent to finding the maximum number of elements that aren’t swapped.

Let’s redefine \text{opt}_i to be the maximum number of elements of [W_1, \ldots, W_{i+1}] that don’t need to be swapped, while the sum can still be made to exceed M.
This redefinition gives rise to a rather nice property: the array \text{opt} is non-decreasing!
That is, \text{opt}_i \leq \text{opt}_{i+1} for every i.
This should be fairly easy to see: if you can leave x elements among the first i+1 elements unswapped and still have a sum that’s \gt M, you can also leave x elements among the first i+2 elements unswapped and achieve the same thing (indeed, you can perform the same swaps, skipping any swap whose partner was W_{i+2}, since that element is now inside the prefix anyway).
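To spot-check this property on small arrays, here is a brute-force sketch of the redefined \text{opt}_i (0-indexed, hypothetical name). It assumes that the i+1-k slots not occupied by kept elements are filled with the largest of all remaining elements: remaining suffix elements get swapped in, while leftover prefix elements simply stay in place.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <vector>

// Brute-force value of the redefined opt_i (0-indexed i): the largest k such that
// keeping the k largest elements of W[0..i] in place and filling the other i+1-k
// slots with the largest of all remaining elements gives a sum > M; -1 if no k works.
int maxKept(const std::vector<long long>& W, long long M, int i) {
    assert(0 <= i && i < static_cast<int>(W.size()));
    std::vector<long long> pre(W.begin(), W.begin() + i + 1);
    std::sort(pre.begin(), pre.end(), std::greater<long long>());
    int best = -1;
    for (int k = 0; k <= i + 1; ++k) {
        // Remaining elements: the prefix elements that are not kept, plus the suffix.
        std::vector<long long> pool(pre.begin() + k, pre.end());
        pool.insert(pool.end(), W.begin() + i + 1, W.end());
        std::sort(pool.begin(), pool.end(), std::greater<long long>());
        long long sum = 0;
        for (int j = 0; j < k; ++j) sum += pre[j];           // kept elements
        for (int j = 0; j < i + 1 - k; ++j) sum += pool[j];  // best fillers
        if (sum > M) best = k;
    }
    return best;
}
```

On W = [1, 1, 4, 4] with M = 4 it returns -1, 1, 3, 4 for i = 0, 1, 2, 3; ignoring the infeasible first index, the values never decrease.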

So, suppose we’ve computed \text{opt}_i.
Then, when computing \text{opt}_{i+1}, we don’t need to try every possibility: we can instead start with \text{opt}_{i+1} := \text{opt}_i, and then keep increasing it as necessary!
Since \text{opt}_{N-1} \leq N, the total number of increases across all indices is bounded by N.

The only remaining part now is actually checking for those increases quickly.
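That is, given i and k, we need to quickly check whether k elements can be kept: take the sum of the k largest elements of the prefix [W_1, \ldots, W_{i+1}], add the sum of the i+1-k largest elements among all the other elements of W, and compare the total against M.
Those other elements either come from the suffix (and are swapped in) or already lie in the prefix (and simply stay put); for the largest valid k only suffix elements are used, since otherwise k+1 would be valid as well.
This is a data structure problem, and can be tackled in several ways, though we again exploit the fact that the process is incremental, i.e., k changes by 1 at a time.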

Solution 1 (sets)

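Store a few sets of elements (an element is "kept" if it stays in place inside the prefix):

  • S_1, which is the set of largest \text{opt}_i elements of the prefix till i+1; these are the kept elements.
  • S_2, which is everything else in this prefix.
  • S_3, which is the set of largest i+1-\text{opt}_i elements among all elements not in S_1 (that is, S_2 together with the suffix from i+2).
  • S_4, which is every other element not in S_1.

Also store the sums of S_1 and S_3, since those are what we care about: the current value of \text{opt} works exactly when \text{sum}(S_1) + \text{sum}(S_3) \gt M.

When moving from i to i+1 (starting from \text{opt}_{i+1} := \text{opt}_i),

  • W_{i+2} must now enter either S_1 or S_2.
    The easiest way to do this is to insert it into S_1, then move the smallest element of S_1 to S_2.
    If W_{i+2} stays in S_1, it leaves S_3 \cup S_4 and the element displaced from S_1 rejoins S_3 \cup S_4 in its place; if W_{i+2} itself ends up in S_2, nothing changes there.
  • S_3 must now contain i+2-\text{opt}_i elements, one more than before.
    Move the largest element of S_4 into S_3, and rebalance so that every element of S_3 is still at least as large as every element of S_4.

When trying to increment \text{opt}_{i+1}, move the largest element of S_2 into S_1 (removing it from S_3 \cup S_4), and move the smallest element of S_3 into S_4, since one fewer slot now needs to be filled.
If \text{sum}(S_1) + \text{sum}(S_3) \leq M afterwards, undo these moves and stop.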
Since we’re dealing with sets, finding the smallest/largest element, insertion, and deletion are all done in \mathcal{O}(\log N) time.

As noted earlier, there are at most N successful increments in total, each of which costs a constant number of set operations; each index also has at most one failed attempt, which is undone with another constant number of operations.
We also perform an additional constant number of set operations when moving i to i+1.
So the overall complexity is \mathcal{O}(N\log N).
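Below is a minimal C++ sketch of this approach; it is not the author’s code. The names `solveWithSets` and `TopKSplit` are hypothetical, W, M, C and S are passed as arguments instead of being read, and, like the rest of the argument, it assumes non-negative elements. `kept` and `rest` hold the kept and non-kept prefix elements, while `TopKSplit` holds every element that is not kept (prefix or suffix) and maintains the sum of its `need` largest values.

```cpp
#include <cassert>
#include <cstddef>
#include <iterator>
#include <numeric>
#include <set>
#include <vector>

// A multiset split into `top` (its `need` largest values) and `low` (the rest),
// with the sum of `top` maintained. Every operation is O(log N).
struct TopKSplit {
    std::multiset<long long> top, low;
    long long topSum = 0;
    std::size_t need = 0;

    void fix() {
        // Order invariant: every value in `top` is >= every value in `low`.
        while (!top.empty() && !low.empty() && *low.rbegin() > *top.begin()) {
            long long a = *top.begin(), b = *low.rbegin();
            top.erase(top.begin()); low.erase(std::prev(low.end()));
            top.insert(b); low.insert(a); topSum += b - a;
        }
        // Size invariant: `top` holds min(need, total size) values.
        while (top.size() < need && !low.empty()) {
            long long b = *low.rbegin();
            low.erase(std::prev(low.end())); top.insert(b); topSum += b;
        }
        while (top.size() > need) {
            long long a = *top.begin();
            top.erase(top.begin()); low.insert(a); topSum -= a;
        }
    }
    void insert(long long v) { low.insert(v); fix(); }
    void erase(long long v) {  // v must currently be stored
        auto it = low.find(v);
        if (it != low.end()) {
            low.erase(it);
        } else {
            auto jt = top.find(v);
            assert(jt != top.end());
            top.erase(jt); topSum -= v;
        }
        fix();
    }
    void setNeed(std::size_t t) { need = t; fix(); }
};

// Assumes non-negative elements (see the lead-in).
long long solveWithSets(const std::vector<long long>& W, long long M, long long C, long long S) {
    const std::size_t n = W.size();
    if (std::accumulate(W.begin(), W.end(), 0LL) <= M) return static_cast<long long>(n) * C;

    std::multiset<long long> kept, rest;  // prefix: largest `opt` values / the others
    long long keptSum = 0;
    TopKSplit pool;                       // every element of W that is not kept
    for (long long w : W) pool.insert(w);

    std::size_t opt = 0;
    long long best = -1;
    for (std::size_t i = 0; i < n; ++i) {
        long long w = W[i];               // W[i] joins the prefix W[0..i]
        if (!kept.empty() && w > *kept.begin()) {
            // w replaces the smallest kept value, which is no longer kept.
            long long out = *kept.begin();
            kept.erase(kept.begin()); kept.insert(w); keptSum += w - out;
            rest.insert(out);
            pool.erase(w); pool.insert(out);
        } else {
            rest.insert(w);
        }
        pool.setNeed(i + 1 - opt);        // slots not occupied by kept elements
        if (keptSum + pool.topSum <= M) continue;  // only before the first feasible i

        // Try to keep one more prefix element (the largest non-kept one).
        while (!rest.empty()) {
            long long p = *rest.rbegin();
            pool.erase(p);
            pool.setNeed(i - opt);
            if (keptSum + p + pool.topSum > M) {
                rest.erase(std::prev(rest.end()));
                kept.insert(p); keptSum += p; ++opt;
            } else {
                pool.insert(p);           // undo the trial and stop
                pool.setNeed(i + 1 - opt);
                break;
            }
        }
        long long cost = static_cast<long long>(i) * C + static_cast<long long>(i + 1 - opt) * S;
        if (best < 0 || cost < best) best = cost;
    }
    return best;
}
```

For instance, with W = [1, 2, 3, 10], M = 5, C = 10, S = 1 it returns 1 (one swap of 1 and 10, then nothing is taken). Keeping all non-kept elements in a single pool, instead of only the suffix, means the check also works when the suffix is too short to fill the free slots.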

Note that elements can be duplicates, so you’ll need a multiset rather than a set (or work with pairs of (value, index)).
This can also be done using priority queues with lazy deletion (since they don’t really support arbitrary deletion).

The author’s code implements this.
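The priority-queue alternative mentioned above relies on lazy deletion. A minimal sketch, assuming every erased value is actually present in the heap (the struct name `LazyMaxHeap` is hypothetical): deletions are recorded in a second heap and only applied once they reach the top.

```cpp
#include <cassert>
#include <cstddef>
#include <queue>
#include <vector>

// Max-heap with lazy deletion: erase(x) only records x in `dead`, and top()/pop()
// discard recorded values once they reach the top of `live`.
struct LazyMaxHeap {
    std::priority_queue<long long> live, dead;

    void push(long long x) { live.push(x); }
    void erase(long long x) { dead.push(x); }  // x must currently be present
    void clean() {
        while (!live.empty() && !dead.empty() && live.top() == dead.top()) {
            live.pop();
            dead.pop();
        }
    }
    long long top() { clean(); return live.top(); }
    void pop() { clean(); live.pop(); }
    std::size_t size() const { return live.size() - dead.size(); }
};
```

The author’s code uses a slightly different lazy scheme: it skips heap entries whose index has already entered the prefix.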

Solution 2 (segment trees)

The idea here is basically the same as the set idea above, except that we use segment trees to simulate multisets instead.
Build two segment trees: one for the elements of the prefix that aren’t kept, and one for all elements of W that aren’t kept.
The segment trees should be built on value, i.e., \text{seg}[x] should denote the number of occurrences of x (along with their total sum).
Adding a new element to the prefix, or deciding to keep an element, corresponds to a constant number of point additions and subtractions, all of which are easily handled.
Finding the sum of the k largest/smallest elements in a segment tree can be done in \mathcal{O}(\log^2 N) with binary search, or optimized to \mathcal{O}(\log N) using a segtree walk.
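A minimal sketch of the \mathcal{O}(\log N) walk, assuming the values have been coordinate-compressed into ranks 0, \ldots, V-1 beforehand (the editorialist’s code instead allocates nodes dynamically over the raw value range). `CountSumTree` and its methods are hypothetical names. At each node, the walk descends right if the right child alone holds at least k elements, and otherwise takes the whole right child and continues left.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Value-indexed segment tree over compressed ranks 0..V-1: each node stores how
// many elements fall in its rank range and their sum; vals[r] is the value of rank r.
struct CountSumTree {
    int V;
    std::vector<long long> vals;  // rank -> value, sorted ascending
    std::vector<int> cnt;         // number of elements in each node's range
    std::vector<long long> sum;   // sum of elements in each node's range

    explicit CountSumTree(std::vector<long long> sortedVals)
        : V(static_cast<int>(sortedVals.size())), vals(std::move(sortedVals)),
          cnt(4 * V), sum(4 * V) {
        assert(V > 0);
    }

    // Add `delta` copies (+1 or -1) of the value with rank r.
    void update(int r, int delta) { update(1, 0, V - 1, r, delta); }

    // Sum of the k largest stored elements, via a single root-to-leaf walk.
    long long sumLargest(int k) const {
        assert(k >= 0 && k <= cnt[1]);
        return sumLargest(1, 0, V - 1, k);
    }

private:
    void update(int node, int lo, int hi, int r, int delta) {
        cnt[node] += delta;
        sum[node] += delta * vals[r];
        if (lo == hi) return;
        int mid = (lo + hi) / 2;
        if (r <= mid) update(2 * node, lo, mid, r, delta);
        else update(2 * node + 1, mid + 1, hi, r, delta);
    }
    long long sumLargest(int node, int lo, int hi, int k) const {
        if (k <= 0) return 0;
        if (lo == hi) return k * vals[lo];  // k copies of this single value
        int mid = (lo + hi) / 2;
        int right = cnt[2 * node + 1];
        if (right >= k) return sumLargest(2 * node + 1, mid + 1, hi, k);
        return sum[2 * node + 1] + sumLargest(2 * node, lo, mid, k - right);
    }
};
```

The sum of the k smallest elements follows as the total sum minus the sum of the (\text{count} - k) largest.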

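Note that if you don’t coordinate compress, and instead allocate segment tree nodes dynamically over the whole value range (as the editorialist’s code does), the complexity is \mathcal{O}(\log\max W) per operation instead, which is marginally worse though still fast enough to get AC.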
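If you do want to coordinate compress, the usual approach is sort, unique, then binary search. A minimal sketch (the helper name `compressValues` is hypothetical): it returns the sorted distinct values and replaces each element by its rank, so a value-indexed tree needs only as many leaves as there are distinct values.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Coordinate compression: returns the sorted distinct values of `a` and rewrites
// each element of `a` as its rank (0-based index) in that list.
std::vector<long long> compressValues(std::vector<long long>& a) {
    std::vector<long long> vals(a);
    std::sort(vals.begin(), vals.end());
    vals.erase(std::unique(vals.begin(), vals.end()), vals.end());
    for (long long& x : a)
        x = std::lower_bound(vals.begin(), vals.end(), x) - vals.begin();
    return vals;
}
```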

The editorialist’s code implements this.

TIME COMPLEXITY:

\mathcal{O}(N\log N) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define ll long long

int main(){
  ios_base::sync_with_stdio(false);
  cin.tie(NULL);
  // freopen("input.txt","r",stdin);freopen("output.txt","w",stdout);
  ll kitne_cases_hain;  // number of test cases
  kitne_cases_hain=1;
  cin>>kitne_cases_hain;
  while(kitne_cases_hain--){
    ll n,H,S,T;
    cin>>n>>H>>S>>T;
    ll a[n];
    ll sum=0;
    ll ans=H*n;
    pair<ll,ll> p;
    ll x,y;
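    multiset <ll> pinc;   // prefix elements currently swapped out
    multiset <ll> prem;   // prefix elements left in place
    set <pair<ll,ll>> sinc;   // (value, index) of suffix elements swapped into the prefix
    priority_queue <pair<ll,ll>> srem;   // suffix elements not swapped in; stale entries (index <= i) are skipped lazily
    ll total=0;   // net gain of the current swaps: sum(sinc) - sum(pinc)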
    for(int i=0;i<n;i++){
      cin>>a[i];
      srem.push({a[i],i+1});
    }
    for(int i=1;i<=n;i++){
      sum+=a[i-1];
      if(sum>T){
        ans=min(ans,(i-1)*H);
        break;
      }
      if(sinc.find({a[i-1],i})!=sinc.end()){
        total-=a[i-1];
        total+=(*pinc.rbegin());
        prem.insert(*pinc.rbegin());
        prem.insert(a[i-1]);
        pinc.erase(pinc.find(*pinc.rbegin()));
        sinc.erase(sinc.find({a[i-1],i}));
      }else{
        if(pinc.size() && (*pinc.rbegin())>a[i-1]){
          total+=(*pinc.rbegin());
          total-=a[i-1];
          prem.insert((*pinc.rbegin()));
          pinc.insert(a[i-1]);
          pinc.erase(pinc.find(*pinc.rbegin()));
        }else{
          prem.insert(a[i-1]);
        }
      }
      while(sum+total>T){
        p=*sinc.begin();x=p.first;
        y=*pinc.rbegin();
        total-=(x-y);
        srem.push(p);
        prem.insert(y);
        sinc.erase(sinc.begin());
        pinc.erase(pinc.find(y));
      }
      while((sum+total)<=T && srem.size() && prem.size()){
        p=srem.top();
        if(p.second<=i){
          srem.pop();
          continue;
        }
        y=*prem.begin();
        x=p.first;
        if(x>y){
          total+=(x-y);
          srem.pop();
          prem.erase(prem.begin());
          sinc.insert(p);
          pinc.insert(y);
        }else{
          break;
        }
      }
      x=sinc.size();
      if((sum+total)>T){
        ans=min(ans,(i-1)*H+x*S);
      }
    }
    cout<<ans<<"\n";
  }
  return 0;
}

Tester's code (C++)
#include<bits/stdc++.h>

using namespace std;

#define all(x) begin(x), end(x)
#define sz(x) static_cast<int>((x).size())
#define int long long
typedef long long ll;


signed main() {

        ios::sync_with_stdio(0);
        cin.tie(0);

        int t;
        cin >> t;

        while (t--) {

                int n, h, s, m;
                cin >> n >> h >> s >> m;
                int a[n];
                for (auto &x : a) cin >> x;
                int sm = accumulate(a, a + n, 0ll);
                if (sm <= m) {
                        cout << n * h << "\n";
                        continue;
                }

                multiset<int> s1, s2, s3, s4;
                for (auto u : a) s1.insert(u);
                int cur = sm;
                int swp = 0;
                int ans = (n - 1) * h;
                
                for (int i = n - 1; i > 0; i--) {
                        int x = a[i];
                        if (s2.find(x) != s2.end()) {
                                s2.erase(s2.find(x));
                                swp--;
                                int y = *s4.begin();
                                cur -= y;
                                s4.erase(s4.begin());
                                s3.insert(y);
                                s3.insert(x);
                        }
                        else {
                                s1.erase(s1.find(x));
                                if (sz(s4) && *s4.begin() < x) {
                                        int y = *s4.begin();
                                        s4.erase(s4.begin());
                                        s3.insert(y);
                                        s4.insert(x);
                                        cur -= y;
                                }
                                else {
                                        s3.insert(x);
                                        cur -= x;
                                }
                        }
                        while (cur <= m && sz(s1) && sz(s3)) {
                                int x1 = *s1.begin();
                                int x2 = *s3.rbegin();
                                s1.erase(s1.find(x1));
                                s3.erase(s3.find(x2));
                                cur += x2 - x1;
                                swp++;
                                s2.insert(x1);
                                s4.insert(x2);
                        }
                        if (cur <= m) break;
                        ans = min(ans, (i - 1) * h + swp * s);
                }

                cout << ans << "\n";

        }
}

Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

struct Node {
    using T = pair<ll, int>;
    T unit = {0, 0};
    T f(T a, T b) {
        return T{
            a.first + b.first,
            a.second + b.second
        };
    }
 
    Node *l = 0, *r = 0;
    int lo, hi;
    T val = unit;
    Node(int _lo,int _hi):lo(_lo),hi(_hi){}
    T query(int L, int R) {
        if (R <= lo || hi <= L) return unit;
        if (L <= lo && hi <= R) return val;
        push();
        return f(l->query(L, R), r->query(L, R));
    }
    void upd(int pos, int type) {
        if (pos >= hi or pos < lo) return;
        if (lo+1 == hi) {
            val.first += type*lo;
            val.second += type;
            return;
        }
        push();
        l->upd(pos, type), r->upd(pos, type);
        val = f(l->val, r->val);
    }
    int getkth(int k) {
        if (lo+1 == hi) return lo;
        push();
        if (r->val.second >= k) return r->getkth(k);
        else return l->getkth(k - r->val.second);
    }
    ll getksum(int k) {
        if (lo+1 == hi) return 1ll*lo*k;
        push();
        if (r->val.second >= k) return r->getksum(k);
        else return r->val.first + l->getksum(k - r->val.second);
    }
    void push() {
        if (!l) {
            int mid = lo + (hi - lo)/2;
            l = new Node(lo, mid); r = new Node(mid, hi);
        }
    }
};

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    const int lim = 1e9 + 10;

    int t; cin >> t;
    while (t--) {
        int n, h, s, m; cin >> n >> h >> s >> m;
        vector<int> a(n);
        for (int &x : a) cin >> x;

        // unused: every element that is not kept; used: prefix elements that are not kept
        Node *unused = new Node(0, lim);
        Node *used = new Node(0, lim);
        for (int i = 0; i < n; ++i) unused -> upd(a[i], 1);

        auto b = a;
        sort(rbegin(b), rend(b));
        ll ans = 1ll*n*h, mx = 0, opt = 0, optsum = 0;
        for (int i = 0; i < n; ++i) {
            used -> upd(a[i], 1);
            mx += b[i];  // sum of the i+1 largest elements of a
            if (mx <= m) continue;  // index i is infeasible

            ll cur = 1ll*i*h;
            // Now, keep as many elements at indices 0...i as possible
            // Clearly best to keep the larger ones among them

            while (opt <= i) {
                // Try opt+1
                int what = used -> getkth(1);
                unused -> upd(what, -1);
                ll outsum = unused -> getksum(i+1-(opt+1));

                if (optsum + outsum + what > m) {
                    used -> upd(what, -1);
                    optsum += what;
                    ++opt;
                }
                else {
                    unused -> upd(what, 1);
                    break;
                }
            }

            ans = min(ans, cur + (i+1-opt)*s);
        }

        cout << ans << '\n';
    }
}