ADVITIYA10 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: mehul_g2874
Tester: apoorv_me
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Sets and multisets

PROBLEM:

There are N cities, numbered 1 to N, and city i has an associated value A_i.
These cities will be destroyed one at a time, in a given order. Before each destruction, find the minimum total number of days needed to supply electricity to all cities that are still alive, if your options are as follows:

  • Build a power station in city i, needing A_i days.
  • Connect city i to city j (where j already has electricity), needing (i+j) days.

EXPLANATION:

Suppose we have some subset of “alive” cities. Let’s try to compute its minimum cost.
We’ll call a city electrified if it has an electricity supply.
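To test the reasoning below on small inputs, a brute force can help. It uses a swapped-in checking technique, not the intended solution: add a virtual vertex 0 joined to each alive city i with weight A_i, and join every pair of alive cities i, j with weight i+j. Every valid way of electrifying the cities is then a spanning tree of this graph, so the minimum cost equals its minimum spanning tree weight. The sketch below runs an \mathcal{O}(k^2) Prim's algorithm for k alive cities. The name bruteForceCost and the convention that A is 1-indexed with A[0] unused are hypothetical choices for illustration.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical brute-force checker: the answer is the weight of a minimum spanning
// tree on {virtual vertex 0} + alive cities, where edge 0-i costs A_i (build a
// station) and edge i-j costs i + j (connect two cities).
// labels: 1-indexed labels of the alive cities; A: 1-indexed costs, A[0] unused.
int64_t bruteForceCost(const std::vector<int>& labels, const std::vector<int64_t>& A) {
    int k = (int)labels.size();
    std::vector<int64_t> dist(k);
    std::vector<bool> inTree(k, false);
    // Prim's algorithm starting from vertex 0: initially each city can only be
    // reached by building its own power station.
    for (int v = 0; v < k; ++v) dist[v] = A[labels[v]];
    int64_t total = 0;
    for (int step = 0; step < k; ++step) {
        // Pick the cheapest city not yet electrified.
        int best = -1;
        for (int v = 0; v < k; ++v)
            if (!inTree[v] && (best == -1 || dist[v] < dist[best])) best = v;
        inTree[best] = true;
        total += dist[best];
        // Relax connection edges from the newly electrified city.
        for (int v = 0; v < k; ++v) {
            int64_t w = (int64_t)labels[best] + labels[v];
            if (!inTree[v] && w < dist[v]) dist[v] = w;
        }
    }
    return total;
}
```

For example, with cities {1, 2, 3} and A = (100, 100, 5), the cheapest plan builds a station at city 3, connects 1 to 3, then 2 to 1, for 5 + 4 + 3 = 12 days.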

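Let m denote the smallest label among the alive cities.
Then, for any other city i such that A_i \leq i+m (equivalently, A_i - i \leq m), it’s always optimal to just build a power station at city i: every alive city has label at least m, so any connection involving city i costs at least (i+m) days.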

Now, look at the remaining cities, which all satisfy A_i \gt i+m.
Ideally, we’d like to electrify them by connecting them all to city m. However, that requires us to electrify city m in the first place, so let’s figure out how to do that.
We have a few options.

  1. Build a power station directly at city m, needing A_m days.
  2. Connect m to some already electrified city. For that, once again there are two options:
    • Recall that we already have several power stations built (everything for which A_i-i \leq m).
      One option is to connect m to the minimum labelled city among these, say m_2, needing (m+m_2) days.
      If no such power station has been built, we say m_2 = \infty.
    • Alternatively, we can electrify some other city (or cities), and then use it to connect to m.

The first two options are straightforward; only the last needs some care: which new cities should be electrified?
It turns out that if we do choose this option, it’s always optimal to electrify exactly one new city, by building a power station there!

Proof

Suppose m is connected to city x, at a cost of (m+x) days.
Let’s look at how x was electrified. There are three options:

  • First, a power station was built at x, using A_x days.
    In this case, clearly there’s no need to electrify anything else: it’s best to electrify x, then electrify m using it, then use m for everything else.
  • Second, x could’ve been connected to one of the cities with A_i - i \leq m.
    In this case, the optimal choice is of course to use m_2, with the total cost being (m_2 + x) + (x + m) days.
    However, we could’ve instead used m_2 to electrify m, then used m to electrify x, needing (m_2 + m) + (m + x) days, which is strictly better since m \lt x.
    So, this case is never optimal.
  • Finally, there might be some other city y that was electrified before x, and used to electrify x.
    Once again, we can instead connect m to y directly and then x to m, paying (y + m) + (m + x) days instead of (y + x) + (x + m), which is strictly better since m \lt x.

So, if m is electrified through a newly electrified city at all, the minimum is achieved by choosing some city x, building a power station there in A_x days, and then connecting m to it in (x+m) days.
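Written out as a formula, let S_1 be the alive cities i \neq m with A_i - i \leq m, and S_2 the alive cities i \neq m with A_i - i \gt m. If every city in S_1 builds a station and every city in S_2 connects to m, everything except m costs

$$B = \sum_{i \in S_1} A_i + \sum_{i \in S_2} (i + m).$$

Using some x \in S_2 for the third option replaces x’s own cost (x+m) by A_x and spends (x+m) on connecting m, so the total becomes

$$B - (x + m) + A_x + (x + m) = B + A_x,$$

and overall

$$\text{answer} = B + \min\left(A_m,\ m + m_2,\ \min_{x \in S_2} A_x\right).$$

In particular, for the third option only the minimum of A_x over S_2 matters.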

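All of this is easy to compute in \mathcal{O}(N) time, since we only need to know the following (every sum and minimum ranges over alive cities i \neq m):

  • The minimum label, m.
  • The sum of all A_i such that A_i - i \leq m.
  • The minimum label m_2 across all i such that A_i - i \leq m.
  • The sum of labels i and the number of such cities, across all i such that A_i - i \gt m. These cities are connected to m, for a total of \sum (i+m) days.
  • The minimum value of A_i across all i such that A_i - i \gt m. Choosing such a city x for the third option changes its own cost from (x+m) to A_x and adds (x+m) for connecting m, so the net extra cost is exactly A_x.

The answer is the total cost of both groups, plus \min(A_m,\ m+m_2,\ \min A_x) for electrifying m itself.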
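As a sketch of this formula for one fixed set of alive cities, the quantities above can be collected in a single pass. The name staticCost and the convention that A is 1-indexed with A[0] unused are hypothetical. This costs \mathcal{O}(k) per call, so re-running it after every deletion would take \mathcal{O}(N^2) overall, which is why the quantities are maintained incrementally instead.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical helper: minimum cost for one fixed set of alive cities, in O(k).
// labels: 1-indexed labels of the alive cities; A: 1-indexed costs, A[0] unused.
int64_t staticCost(const std::vector<int>& labels, const std::vector<int64_t>& A) {
    const int64_t INF = std::numeric_limits<int64_t>::max();
    int64_t m = *std::min_element(labels.begin(), labels.end());
    int64_t sumA1 = 0;       // sum of A_i over S_1 (i != m, A_i - i <= m)
    int64_t m2 = INF;        // minimum label in S_1
    int64_t sumLabels2 = 0;  // sum of labels over S_2 (i != m, A_i - i > m)
    int64_t count2 = 0;      // |S_2|
    int64_t minA2 = INF;     // minimum A_i over S_2
    for (int i : labels) {
        if (i == m) continue;
        if (A[i] - i <= m) {
            sumA1 += A[i];
            m2 = std::min<int64_t>(m2, i);
        } else {
            sumLabels2 += i;
            ++count2;
            minA2 = std::min(minA2, A[i]);
        }
    }
    // S_1 builds its own stations; every city of S_2 is connected to m at cost (i + m).
    int64_t base = sumA1 + sumLabels2 + m * count2;
    // Cheapest way to electrify m itself: own station, join m_2, or build at some x in S_2.
    int64_t costM = A[m];
    if (m2 != INF) costM = std::min(costM, m + m2);
    if (minA2 != INF) costM = std::min(costM, minA2);
    return base + costM;
}
```

For cities {1, 2, 3} with A = (50, 2, 50), city 2 lands in S_1 and city 3 in S_2, and the cheapest way to power m = 1 is joining m_2 = 2, giving 2 + (3 + 1) + (1 + 2) = 9.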

Finally, we need to handle deletions, and maintain these quantities across them.
This is not especially hard: observe that the minimum label m only increases as cities get deleted.
Let’s partition the existing cities into two sets: S_1, which consists of all cities (other than m) for which A_i-i \leq m, and S_2 which consists of everything else (but still not m).
When city x is deleted, the changes are as follows:

  • If x \neq m, then x belongs to either S_1 or S_2, and is simply deleted from that set. Nothing else changes.
  • If x = m, the minimum label increases to the next alive label, say m'. City m' is removed from whichever set contains it, and then some elements move from S_2 to S_1, because their A_i - i values might now be \leq m'.

Note that in the latter case, while \mathcal{O}(N) elements can move in a single step, every element moves at most once (there is never any movement from S_1 back into S_2), so the total number of movements is bounded by N.
All of this is easily simulated with a structure that supports fast insertion, deletion and minimum queries, such as std::set in C++. Keeping S_2 ordered by A_i - i makes the elements that must move easy to find, since they form a prefix of that order.

Of course, you also need to keep the aggregate quantities for S_1 and S_2 up to date (the sum of A_i and the minimum label over S_1; the sum of labels, the size and the minimum A_i over S_2), but that’s easy to maintain on every update and boils down to some bookkeeping with sets/multisets.
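A possible forward simulation of the deletions, written as a sketch under these assumptions: A is 1-indexed with A[0] unused, order is the destruction order as a permutation of 1..N, and the function returns the answer for the alive set just before each destruction (which is what both reference solutions print). Input parsing is omitted, and the names solveDeletions, addS1, addS2 and removeCity are hypothetical. Unlike the author’s code, which processes the deletions in reverse as insertions, this version follows the deletion steps described above directly.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <set>
#include <utility>
#include <vector>

// Hypothetical sketch: answers[j] is the minimum cost just before order[j] is destroyed.
std::vector<int64_t> solveDeletions(const std::vector<int64_t>& A, const std::vector<int>& order) {
    int n = (int)order.size();
    std::set<int> alive;
    for (int i = 1; i <= n; ++i) alive.insert(i);
    int m = 1;                                   // current minimum alive label
    std::vector<char> inS1(n + 1, 0);
    std::set<int> s1Labels;                      // S_1: i != m with A_i - i <= m
    int64_t s1SumA = 0;
    std::set<std::pair<int64_t, int>> s2ByKey;   // S_2 ordered by (A_i - i, i)
    std::set<std::pair<int64_t, int>> s2ByA;     // S_2 ordered by (A_i, i)
    int64_t s2SumLabels = 0;

    auto addS1 = [&](int i) { inS1[i] = 1; s1Labels.insert(i); s1SumA += A[i]; };
    auto addS2 = [&](int i) {
        inS1[i] = 0;
        s2ByKey.insert(std::make_pair(A[i] - i, i));
        s2ByA.insert(std::make_pair(A[i], i));
        s2SumLabels += i;
    };
    auto removeCity = [&](int i) {  // remove i from whichever set holds it
        if (inS1[i]) { s1Labels.erase(i); s1SumA -= A[i]; return; }
        s2ByKey.erase(std::make_pair(A[i] - i, i));
        s2ByA.erase(std::make_pair(A[i], i));
        s2SumLabels -= i;
    };

    for (int i = 2; i <= n; ++i) {
        if (A[i] - i <= m) addS1(i); else addS2(i);
    }

    std::vector<int64_t> answers;
    for (int x : order) {
        // Answer for the current alive set.
        int64_t base = s1SumA + s2SumLabels + (int64_t)m * (int64_t)s2ByKey.size();
        int64_t costM = A[m];
        if (!s1Labels.empty()) costM = std::min(costM, (int64_t)m + *s1Labels.begin());
        if (!s2ByA.empty()) costM = std::min(costM, s2ByA.begin()->first);
        answers.push_back(base + costM);

        alive.erase(x);
        if (x != m) { removeCity(x); continue; }
        if (alive.empty()) break;
        // The minimum label increases: take the new minimum out of its set, then
        // move every S_2 city whose A_i - i is now <= m into S_1 (each moves once).
        m = *alive.begin();
        removeCity(m);
        while (!s2ByKey.empty() && s2ByKey.begin()->first <= m) {
            int i = s2ByKey.begin()->second;
            removeCity(i);
            addS1(i);
        }
    }
    return answers;
}
```

S_2 is stored twice, ordered by A_i - i to find the cities that must move to S_1 and ordered by A_i to read off the minimum A_i. Pairing each key with its label keeps the entries distinct, so plain std::set suffices instead of a multiset.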

TIME COMPLEXITY:

\mathcal{O}(N\log N) per testcase.

CODE:

Author's code (C++)
#ifndef LOCAL
#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx,avx2,sse,sse2,sse3,sse4,popcnt,fma")
#endif

#include <bits/stdc++.h>
using namespace std;

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && !isspace(buffer[now])) {
            now++;
        }
        return now;
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

int32_t main() {
    ios_base::sync_with_stdio(0);   cin.tie(0);

    input_checker input;

    int T = input.readInt(1, (int)1e5);  input.readEoln();
    int NN = 0;
    while(T-- > 0) {
        int N = input.readInt(1, (int)2e5);  input.readEoln();
        NN += N;
        vector<int> A = input.readInts(N, 1, (int)1e9);     input.readEoln();
        vector<int> B = input.readInts(N, 1, N);    input.readEoln();

        set<pair<int, int>> above, below;
        set<int> abi, bbi;
        multiset<int> aba, bba;

        reverse(B.begin(), B.end());
        int x = N + 1;
        int64_t suma = 0, sumb = 0, sumai = 0, sumbi = 0;
        vector<bool> vis(N);
        vector<int64_t> sol;
        for(auto &i: B) {
            --i;
            assert(i >= 0 && i < N && !vis[i]);
            vis[i] = 1;

            below.insert({A[i] - i - 1, i});
            bbi.insert(i);
            bba.insert(A[i]);
            sumb += A[i];
            x = min(x, i + 1);
            while(!below.empty() && below.rbegin() -> first > x) {
                int ind = below.rbegin() -> second;
                above.insert({A[ind] - ind - 1, ind});
                abi.insert(ind);
                aba.insert(A[ind]);
                sumai += ind + 1;

                sumb -= A[ind];
                bba.erase(bba.find(A[ind]));
                bbi.erase(ind);
                below.erase({A[ind] - ind - 1, ind});
            }
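            // Cities in `below` build a station (A_i <= i + x); cities in `above` connect to x.
            int64_t here = sumb + sumai + x * (int64_t)above.size();
            // City x itself was counted in one of the two groups; remove its contribution.
            here -= min(2 * x, A[x - 1]);
            int64_t res = here + A[x - 1];
            if(!bbi.empty()) // connect x to the smallest label with A_i <= i + x
                res = min(res, here + x + *bbi.begin() + 1);
            if(!aba.empty()) // build a station at the cheapest city with A_i > i + x
                res = min(res, here + *aba.begin());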
            sol.push_back(res);
        }
        reverse(sol.begin(), sol.end());

        for(int i = 0 ; i < N ; ++i)
            cout << sol[i] << " \n"[i == N - 1];
    }
    assert(NN <= (int)2e5);
    input.readEof();

    return 0;
}

Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

void Solve() 
{
    int n; cin >> n;
    
    vector <int> ord(n + 1), a(n + 1);
    for (int i = 1; i <= n; i++) cin >> a[i];
    for (int i = 1; i <= n; i++) cin >> ord[i];
    
    vector<vector<int>> adj(n + 1);
    int S = 0;
    multiset <int> s1, s2; 
    set <int> alive;
    for (int i = 1; i <= n; i++) alive.insert(i);
    
    for (int i = 2; i <= n; i++){
        if (i + 1 >= a[i]){
            S += a[i];
            s2.insert(i);
        } else {
            S += i;
            s1.insert(a[i]);
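            // Thresholds above n are never reached by the minimum label (and would
            // index past the end of adj), so such cities never move.
            if (a[i] - i <= n) adj[a[i] - i].push_back(i);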
        }
    }
    
    for (int i = 1; i <= n; i++){
        int x = ord[i];
        // just calculate answer 
        int m = *alive.begin();
        int sum = S + m * s1.size();
        int ans = sum + a[m];
        if (s1.size()) ans = min(ans, sum + *s1.begin());
        if (s2.size()) ans = min(ans, sum + m + *s2.begin());
        
        cout << ans << " ";
        
        if (x == m){
            alive.erase(x);
            
            if (!alive.size()) continue;
            int nm = *alive.begin();
            
            if (nm + m < a[nm]){
                S -= nm;
                s1.erase(s1.find(a[nm]));
            } else {
                S -= a[nm];
                s2.erase(s2.find(nm));
            }
            
            // remove a lot of people 
            for (int i = m + 1; i <= nm; i++){
                // remove adj[i] guys 
                for (auto x : adj[i]){
                    if (alive.find(x) == alive.end() || x == nm) continue;
                    S -= x;
                    s1.erase(s1.find(a[x]));
                    S += a[x];
                    s2.insert(x);
                }
            }
        } else {
            alive.erase(x);
            if (x + m < a[x]){
                S -= x;
                s1.erase(s1.find(a[x]));
            } else {
                S -= a[x];
                s2.erase(s2.find(x));
            }
        }
    }
    cout << "\n";
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}