STACKSTONE - EDITORIAL

PROBLEM LINK:

Contest
Practice

Setter: Raj Tiwari
Tester: Nishit Patel

DIFFICULTY

Hard

PROBLEM

Jayesh recently went on a trip to Japan and fell in love with the ancient Japanese art of stone stacking (rock balancing). Inspired by it, he now wants to try it himself. He bought n rocks, the i-th having width W_i and cost C_i. On his way back he met an ancient priest who told him the ancient trick of balancing stones: ‘an arrangement with k stones is balanced only if more than half of the stones in that arrangement have the maximum width.’
Examples:
If the arrangement consists of 5 stones, at least 3 of them must have the maximum width of this arrangement.
Similarly, if it has 4 stones, at least 3 of them must have the maximum width of this arrangement.
Similarly, if it has 2 stones, both stones must have the maximum width of this arrangement.

Jayesh realizes that he will have to throw away all the stones that are not part of the arrangement.
Jayesh feels sad when his money is wasted. Help Jayesh make a balanced arrangement such that the total cost of the thrown stones is minimized.

EXPLANATION

An arrangement is balanced only if more than half of its stones have the maximum width, so the answer is determined by which width plays the role of the maximum. We therefore sort the stones by increasing width and iterate over the distinct widths, each time building the best arrangement in which the current width is the maximum width.
If there are k stones with a width equal to the current width and we add m stones of smaller width, the arrangement is balanced exactly when k > (k + m)/2, i.e. m ≤ k − 1. So we can keep at most k−1 stones of smaller width, while all the other narrower stones, and every stone wider than the current width, are thrown away. In order to minimize the final cost, we keep the k−1 narrower stones with the largest costs. Since the cost of a stone lies in the range 0 to 500, we keep a count array indexed by cost, initialized to 0, and increment the entry for cost c each time a stone of cost c becomes narrower than the current width. Thus, for any width with k stones, we can pick the k−1 most expensive narrower stones by traversing this array from the highest cost downward.
For each width, we compare the current cost with the global minimum and update the global minimum cost in case the current cost is smaller, thus getting the minimum cost in the end.

TIME COMPLEXITY

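The time complexity is O(n log n + n·500) per test case: sorting takes O(n log n), and for each distinct width (at most n of them) we scan the cost-count array of about 500 entries.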

SOLUTION

Setter's Solution
//Author : Raj Tiwari

#include <bits/stdc++.h>
using namespace std;

// Competitive-programming shorthands for the setter's solution.
// Type aliases.
typedef vector<int> vi;
typedef vector<long long> vl;
typedef pair<int, int> pii;
// Makes `endl` emit a plain newline without flushing.
// NOTE(review): [macro.names] forbids #defining a name declared in a standard
// library header (std::endl lives in <ostream>); it works with GCC in practice,
// but writing '\n' directly gives the same output without the macro.
#define endl "\n"
// Debug print to stdout.
// NOTE(review): "%d" requires an int argument; passing a long long is undefined behavior.
#define debug(val) printf("check%d\n", val)
#define all(v) v.begin(), v.end()
#define pb push_back
#define mp make_pair
#define FF first
#define SS second
#define ll long long
#define ull unsigned long long
// Loop shorthands. FOR names its loop variable after its first argument, but
// forr/For always declare `i` and forrr always declares `j`, whatever is passed.
#define FOR(i, j, k, in) for (int i = j; i < k; i += in)
#define forr(k) for (int i = 0; i < k; i += 1)
#define forrr(l) for (int j = 0; j < l; j += 1)
#define For(j, k) for (int i = j; i < k; i += 1)
#define MOD 1000000007
// Byte-wise zero fill. Clears the whole object only when `val` is an array or
// object; for a pointer, sizeof(val) is just the pointer size.
#define clr(val) memset(val, 0, sizeof(val))
#define what_is(x) cerr << #x << " is " << x << endl;
// Local-testing helper: redirects stdin from test_5.txt and stdout to temp.txt.
#define OJ                             \
    freopen("test_5.txt", "r", stdin); \
    freopen("temp.txt", "w", stdout);
// Fast I/O: unsyncs C++ streams from C stdio and unties cin from cout.
#define FIO                           \
    ios_base::sync_with_stdio(false); \
    cin.tie(NULL);                    \
    cout.tie(NULL);

// Returns the minimum total cost of the stones that must be thrown away so that
// the kept stones form a balanced arrangement (strictly more than half of the
// kept stones have the arrangement's maximum width).
//
// stones: (width, cost) pairs. Costs are assumed non-negative (the statement
// bounds them to [0, 500]) because they index a count array sized by the
// largest cost seen, so no fixed upper bound is baked in.
// Returns 0 for an empty input (nothing to throw away).
static long long minThrownCost(vector<pair<long long, long long>> stones)
{
    const long long n = static_cast<long long>(stones.size());
    if (n == 0)
        return 0;
    sort(stones.begin(), stones.end());

    long long totalCost = 0;
    long long maxCost = 0;
    for (const auto &s : stones)
    {
        totalCost += s.second;
        maxCost = max(maxCost, s.second);
    }

    // costCount[c] = number of strictly narrower (already processed) stones of cost c.
    vector<long long> costCount(maxCost + 1);
    long long narrowerSum = 0; // total cost of all strictly narrower stones
    long long best = LLONG_MAX;

    long long groupStart = 0;
    while (groupStart < n)
    {
        // [groupStart, groupEnd) is the run of stones sharing the current width.
        long long groupEnd = groupStart;
        long long groupCost = 0;
        while (groupEnd < n && stones[groupEnd].first == stones[groupStart].first)
        {
            groupCost += stones[groupEnd].second;
            ++groupEnd;
        }
        const long long groupSize = groupEnd - groupStart;

        // groupSize stones of the maximum width allow at most groupSize - 1
        // narrower stones; greedily keep the most expensive ones. Cost-0 stones
        // are skipped since keeping them saves nothing.
        long long canKeep = groupSize - 1;
        long long keptNarrower = 0;
        for (long long c = maxCost; c > 0 && canKeep > 0; --c)
        {
            const long long take = min(canKeep, costCount[c]);
            keptNarrower += take * c;
            canKeep -= take;
        }

        // Every wider stone is thrown, plus the narrower stones we could not keep.
        const long long widerSum = totalCost - narrowerSum - groupCost;
        best = min(best, widerSum + narrowerSum - keptNarrower);

        // The current group becomes "narrower" for all later widths.
        for (long long k = groupStart; k < groupEnd; ++k)
            ++costCount[stones[k].second];
        narrowerSum += groupCost;
        groupStart = groupEnd;
    }
    return best;
}

// Reads t test cases; each has n, then n widths, then n costs. Prints the
// minimum cost of thrown stones for each test case on its own line.
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    // OJ;  // uncomment to read test_5.txt / write temp.txt while testing locally
    long long t;
    cin >> t;
    while (t--)
    {
        long long n;
        cin >> n;
        vector<pair<long long, long long>> stones(n);
        for (auto &s : stones)
            cin >> s.first; // widths
        for (auto &s : stones)
            cin >> s.second; // costs
        cout << minThrownCost(std::move(stones)) << '\n';
    }
    return 0;
}
Tester's Solution
//Author : Nishit Patel

#include <bits/stdc++.h>
#include <chrono>
using namespace std;
// Competitive-programming shorthands for the tester's solution.
// Type aliases.
typedef long double ld;
typedef double db;
typedef long long ll;
#define pb push_back
// Fast I/O. Expands to two statements, so use it as a standalone statement.
#define FAST cin.sync_with_stdio(0); cin.tie(0);
// Loop shorthands: rep is [0, n), arep is [a, n], drep counts down from a to n, trav is range-for.
#define rep(i, n)      for(int i = 0; i < (n); ++i)
#define arep(i, a, n)  for(int i = a; i <= (n); ++i)
#define drep(i, a, n)  for(int i = a; i >= (n); --i)
#define trav(a, x)     for(auto& a : x)
#define all(x) x.begin(), x.end()
typedef pair<int, int> pii;
typedef vector<ll> vl;
typedef map<ll, ll> ml;
// NOTE(review): with `using namespace std;` in effect, any unqualified use of
// `hash` would be ambiguous between this alias and std::hash; consider renaming.
typedef unordered_map<ll, ll> hash;
// Debug printers: each writes "[expr] : value" lines to stderr.
// deb(...) picks deb_1..deb_4 by argument count via deb_X. This relies on the
// GNU `,##__VA_ARGS__` extension (GCC/Clang).
#define deb_1(A)                     {cerr << "[" << #A << "] : " << A << endl;}
#define deb_2(A,B)                   {cerr << "[" << #A << "] : " << A << endl<< "[" << #B << "] : " << B << endl;}
#define deb_3(A,B,C)                 {cerr << "[" << #A << "] : " << A << endl<< "[" << #B << "] : " << B << endl\
                                    << "[" << #C << "] : " << C << endl;}
#define deb_4(A,B,C,D)               {cerr << "[" << #A << "] : " << A << endl<< "[" << #B << "] : " << B << endl\
                                    << "[" << #C << "] : " << C << endl<< "[" << #D << "] : " << D << endl;}
#define deb_X(x,A,B,C,D,FUNC, ...)  FUNC  
#define deb(...)                    deb_X(,##__VA_ARGS__,\
                                        deb_4(__VA_ARGS__),\
                                        deb_3(__VA_ARGS__),\
                                        deb_2(__VA_ARGS__),\
                                        deb_1(__VA_ARGS__),\
                                        )
// NOTE(review): M and the std::chrono using-directive appear to be template
// leftovers; nothing visible here needs them.
#define M 100000
using namespace std::chrono;


// Returns the largest total cost of stones that can be kept in a balanced
// arrangement, i.e. one in which strictly more than half of the kept stones
// have the arrangement's maximum width. Every distinct width, including the
// smallest one, is tried as the maximum width.
//
// stones: (width, cost) pairs. Costs are assumed non-negative (the statement
// bounds them to [0, 500]) because they index a histogram sized by the largest
// cost seen. Widths are only compared, so there is no bound on them.
// Returns 0 for an empty input.
static long long maxKeptCost(vector<pair<long long, long long>> stones)
{
    sort(stones.begin(), stones.end());

    long long maxCost = 0;
    for (const auto &s : stones)
        maxCost = max(maxCost, s.second);

    // histogram[c] = how many strictly narrower (already processed) stones cost c.
    vector<long long> histogram(maxCost + 1);
    long long best = 0;
    const size_t n = stones.size();

    for (size_t lo = 0; lo < n;)
    {
        // [lo, hi) is the run of stones sharing the current width.
        size_t hi = lo;
        long long runCost = 0;
        while (hi < n && stones[hi].first == stones[lo].first)
            runCost += stones[hi++].second;

        // k stones of the maximum width allow at most k - 1 narrower stones;
        // keep the most expensive ones (cost-0 stones add nothing).
        long long slots = static_cast<long long>(hi - lo) - 1;
        long long kept = runCost;
        for (long long c = maxCost; c > 0 && slots > 0; --c)
        {
            const long long take = min(slots, histogram[c]);
            kept += take * c;
            slots -= take;
        }
        best = max(best, kept);

        // The current run becomes "narrower" for all later widths.
        for (size_t j = lo; j < hi; ++j)
            ++histogram[stones[j].second];
        lo = hi;
    }
    return best;
}

// Reads t test cases; each has n, then n widths, then n costs. Prints
// (total cost - maximum keepable cost), the minimum cost of thrown stones.
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    int t;
    cin >> t;
    while (t--)
    {
        long long n;
        cin >> n;
        vector<pair<long long, long long>> stones(n);
        long long total = 0;
        for (auto &s : stones)
            cin >> s.first; // widths
        for (auto &s : stones)
        {
            cin >> s.second; // costs
            total += s.second;
        }
        // Every stone that is not kept is thrown away.
        cout << total - maxKeptCost(std::move(stones)) << '\n';
    }
    return 0;
}