RIP - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: khaab_2004
Tester: raysh07
Editorialist: iceknight1093

DIFFICULTY:

3201

PREREQUISITES:

Sorting, sets/priority queues

PROBLEM:

There are N houses; the i-th house has strength A_i (also referred to as its safety) and contains B_i people.
The reaper will visit these houses in some order, but an array C of length M specifies that C_i must be visited before C_{i+1} for each 1 \leq i \lt M.

The reaper’s ability equals the strength of the previously visited house, and is 0 initially.
If the ability is at least the current house’s strength, the reaper can reap the souls of all people in the house.

What’s the maximum number of souls that can be reaped?

EXPLANATION:

Let’s first solve a less restrictive version of the problem: suppose the array C didn’t exist (or rather, M = 0).

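In this case, note that if we visit houses in descending order of their A_i values, every house after the first is preceded by a house of at least its strength, so souls from all but the first house can be reaped.
If several houses share the maximum A_i value, visit the one among them with minimum B_i first, so that it is the house that gets skipped.
Further, no matter what order we go in, at least one house with maximum A_i must be skipped: the first such house visited is preceded either by nothing (ability 0) or by a strictly weaker house. So the above solution is optimal.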

Now, let’s attempt to extend this to our more general version.
Clearly, it’d be best to visit houses in descending order; but now the C array constrains us and might make that impossible.

Note that if A_{C_i} \geq A_{C_{i+1}}, then we really don’t have a problem: C_{i+1} can be visited after C_i without any loss, as long as we keep going in descending order.
However, if A_{C_i} \lt A_{C_{i+1}}, we have to move from a lower safety to a higher one at some point, which requires us to skip a house.

Let’s call some C_i bad if A_{C_{i-1}} \lt A_{C_i}.
In particular, C_1 is also bad (it can be thought of as following a house of strength 0).

Rather than figure out which houses we’ll be able to reap from, let’s figure out which houses we’ll skip optimally.

Each bad index C_i needs to be matched to a distinct house that’ll be skipped in its place.
This house must have a safety that’s \geq A_{C_i}, since it should be placed before C_i in the order.
Further, this house can be either C_i itself, or some index that’s not in C.

This leads us to the following greedy algorithm:

  • Consider only the bad indices and the indices not in C.
  • Sort them in descending order of A_i values.
    To break ties, do the following:
    • Place elements of C after elements that aren’t in C.
    • If there’s still a tie, break ties by increasing B_i.
  • Now, process these elements in this sorted order, using the following algorithm.
    • Maintain a (multi)set S, initially empty.
    • When you meet an index i, insert B_i into S.
    • Next, if you’re at a bad index, extract the minimum element from S; this will be one of the houses we skip.

Finally, notice that we have to skip at least one house with the maximum A_i value, just as in the M = 0 case.
So, if none of them has been skipped by our choices, skip the one among them with the lowest B_i, and un-skip the chosen house with the highest B_i (for the minimum possible penalty). This is valid because a house of maximum strength can take the place of any bad index.
S can be maintained as a multiset or a priority queue, for an \mathcal{O}(N\log N) algorithm.

Proof of correctness

As observed earlier, each ‘bad’ index must be matched to a distinct index (possibly itself) with at least the same strength as it.
Let match[i] denote the index matched to i by our greedy solution.
Let opt[i] denote the index matched to i in an optimal solution.

Claim: There exists an optimal solution such that match[i] = opt[i] for all i.
Proof: Suppose this isn’t the case.
Pick the leftmost i (leftmost in terms of the order we sorted in) such that match[i] \neq opt[i].

Note that since match[j] = opt[j] for all j \lt i and match[i] was chosen greedily from the remaining, we will certainly have B[match[i]] \leq B[opt[i]].
Now, suppose B[match[i]] \lt B[opt[i]].

There are two cases:

  • First, suppose that there’s no k\gt i such that opt[k] = match[i].
    That is, match[i] is unused in opt.
    Then we can instead set opt[i] \to match[i] and change nothing else, which gives a strictly smaller total of skipped B values, contradicting optimality.
  • Otherwise, suppose opt[k] = match[i] for some k.
    Since opt[j] = match[j] for all j \lt i, we have k \gt i, so simply swap opt[i] and opt[k].
    Both assignments stay valid: match[i] was available to i, and the old opt[i] was available when processing i, which comes before k in the sorted order.
    The set of skipped houses doesn’t change, so neither does the cost; hence the solution remains optimal and we now have opt[i] = match[i].

Finally, if we instead had B[match[i]] = B[opt[i]], a similar argument applies:

  • If opt[k] \neq match[i] for every k, set opt[i] = match[i] instead; the cost doesn’t change, so opt remains optimal.
  • If opt[k] = match[i] for some k\gt i, swap opt[i] and opt[k].

Repeating this at the leftmost mismatch each time makes opt agree with match everywhere, so the greedy choice is optimal.

TIME COMPLEXITY:

\mathcal{O}(N\log N) per testcase.

CODE:

Setter's code (C++)
#ifndef LOCAL
#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx,avx2,sse,sse2,sse3,sse4,popcnt,fma")
#endif

#include <bits/stdc++.h>
using namespace std;

#ifdef LOCAL
#include "../debug.h"
#else
#define dbg(...) "11-111"
#endif

int32_t main() {
    ios_base::sync_with_stdio(0);   cin.tie(0);

    
    auto __solve_testcase = [&](int testcase) -> void {
        int n, m;  cin >> n >> m;
        vector<int> a(n), b(n), c(m), vis(n), compress(1);
        for(auto &i : a)    cin >> i, compress.push_back(i);
        for(auto &i : b)    cin >> i;
        for(auto &i : c)    cin >> i, vis[i - 1] = 1;

        sort(compress.begin(), compress.end());
        compress.resize(unique(compress.begin(), compress.end()) - compress.begin());
        for(auto &i : a)    i = lower_bound(compress.begin(), compress.end(), i) - compress.begin();
        int S = compress.size();
        vector<multiset<int>> upd(S + 1);

        for(int i = 0 ; i < n ; i++) if(!vis[i])
            upd[a[i]].insert(b[i]);

        vector<int> cnt(S + 1);
        int lst = 0;
        long long res = 0;

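        // Walk through C. A bad index (strictly stronger than the previous one in C,
        // or the first one, since lst starts at 0) becomes a candidate and adds one
        // required skip at its strength level; every other index of C is always reaped.
        for(int &i : c) {
            --i;
            if(a[i] > lst) {
                cnt[a[i]]++;
                upd[a[i]].insert(b[i]);
            } else {
                res += b[i];
            }
            lst = a[i];
        }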

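        // ms holds the B values of candidate houses currently planned to be reaped.
        // s counts the skips made so far; it starts at 1 because the cheapest house
        // of maximum strength is always skipped (erased right below).
        multiset<int> ms;
        int s = 1, rem = 0;
        upd[S - 1].erase(upd[S - 1].begin());

        // Sweep strength levels from highest to lowest. After adding a level's
        // candidates, drop the cheapest ones until the skips cover the rem bad
        // indices seen so far.
        for(int i = S - 1 ; i >= 0 ; i--) {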
            for(auto &u : upd[i])
                ms.insert(u);

            rem += cnt[i];
            while(s < rem) {
                ms.erase(ms.begin());   s++;
            }
        }

        cout << res + accumulate(ms.begin(), ms.end(), 0ll) << '\n';
    };

    int no_of_tests;   cin >> no_of_tests;
    for(int test_no = 1 ; test_no <= no_of_tests ; test_no++)
        __solve_testcase(test_no);
    

    return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18
#define f first
#define s second

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

struct input_checker {
    string buffer;
    int pos;
 
    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";
 
    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }
 
    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }
 
    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }
 
    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }
 
    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }
 
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }
 
    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }
 
    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }
 
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }
 
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }
 
    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

input_checker inp;
int sum_n = 0;

void Solve() 
{
    int n = inp.readInt(1, (int)1e5);
    inp.readSpace();
    int m = inp.readInt(0, n);
    inp.readEoln();

    auto a = inp.readInts(n, 1, (int)1e5);
    inp.readEoln();
    auto b = inp.readInts(n, 0, (int)1e4);
    inp.readEoln();
    auto c = inp.readInts(m, 1, n);
    inp.readEoln();
    
    int mx = 0; for (auto x : a) mx = max(mx, x);

    vector <bool> mark(n, false);
    for (auto &x : c){
        x--;
        assert(!mark[x]);
        mark[x] = true;
    }

    int ans = accumulate(b.begin(), b.end(), 0);
    
    if (m == 0){
        int mn = INF;
        for (int i = 0; i < n; i++){
            if (a[i] == mx) mn = min(mn, b[i]);
        }
        
        cout << ans - mn << "\n";
        return;
    }

    vector <pair<int, int>> v;
    for (int i = 0; i < n; i++) if (!mark[i]) v.push_back({a[i], i});

    for (int i = 1; i < m; i++) if (a[c[i]] > a[c[i - 1]]) v.push_back({a[c[i]], c[i]});
    v.push_back({a[c[0]], c[0]});

    sort(v.begin(), v.end(), [&](pair<int, int> x1, pair<int, int> y1){
        int x = x1.second;
        int y = y1.second;
        if (a[x] != a[y]){
            return a[x] > a[y];
        }

        if (mark[x] && !mark[y]) return false;
        if (mark[y] && !mark[x]) return true;

        return b[x] < b[y];
    });

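    // ptr = the first index of C in sorted order, i.e. the bad index of highest strength.
    int ptr = -1;
    for (auto &x : v) if (mark[x.second]){
        ptr = x.second;
        break;
    }

    // If no bad index has maximum strength, move the first bad index's skip up to the
    // maximum-strength level: unmark ptr and mark the last maximum-strength entry of v,
    // so that the pop made there removes the cheapest maximum-strength house.
    if (a[ptr] != mx){
        mark[ptr] = false;
        reverse(v.begin(), v.end());
        for (auto &x : v){
            if (x.first == mx){
                mark[x.second] = true;
                break;
            }
        }
        reverse(v.begin(), v.end());
    }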

    priority_queue <int> pq;
    for (auto x : v){
        pq.push(-b[x.second]);
        if (mark[x.second]){
            ans += pq.top();
            pq.pop();
        }
    }

    cout << ans << "\n";
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
   // cin >> t;
    t = inp.readInt(1, (int)1e5);
    inp.readEoln();

    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }

    inp.readEof();

    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}