ASSIGNTASKS - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author:
Tester: apoorv_me
Editorialist: iceknight1093

DIFFICULTY:

Easy - Medium

PREREQUISITES:

Binary search, sets

PROBLEM:

There are N tasks. The i-th task can be started at time S_i or later, and requires A_i time.
There are M workers. The i-th worker has a skill deficiency of F_i, meaning a task that normally takes x time will need F_i\cdot x time if done by him.

Let P_i denote the worker assigned to the i-th task. P must be non-decreasing, so each worker handles a contiguous block of tasks.
A worker can work on at most one task at a time.
Over all valid (non-decreasing) arrays P, find the minimum possible time at which all tasks can be completed.

EXPLANATION:

Let’s fix X, the maximum allowed finish time, and try to check whether it’s possible to distribute tasks to workers such that they’re all done by time X.

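Since P is non-decreasing, each worker receives a contiguous block of tasks, and the blocks appear in order of workers. This suggests a greedy: give as many tasks as possible to worker 1, then as many of the remaining tasks as possible to worker 2, then to worker 3, and so on.
Taking as many tasks as possible early never hurts. If a worker can finish some set of tasks by X, it can also finish any subset of them by X, so leaving fewer tasks for later workers only makes their job easier.
If every task gets assigned this way, finishing by X is possible; otherwise it isn’t.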


Our algorithm for a fixed X is as follows:

  • Find the largest prefix of tasks that worker 1 can do.
  • Then find the largest prefix of the remaining tasks that worker 2 can do.
  • Then do the same for worker 3, and so on, until every task has been assigned or the workers run out.

So, we need to be able to decide whether a worker can actually finish a certain segment of tasks by time X.

Suppose we’re considering worker 1, and tasks 1 to r.
The j-th task can be started no earlier than time S_j, and for this worker, will take F_1\cdot A_j time.

To check whether all of these tasks can be completed in time, a straightforward greedy works: the worker processes tasks in ascending order of their start times. After finishing a task, it immediately starts the next one if that task’s start time has already passed; otherwise, it waits until that start time.
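The single-worker greedy can be sketched as follows. This is a minimal illustration, not any of the reference solutions: the name `finishTime` and the `{S, A}` pair representation are assumptions made here, and a task started at time t with length F·A is taken to finish at time t + F·A.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Finish time for one worker with deficiency F who must do all the given
// tasks. Each task is {S, A}: it may start at time S or later and takes F * A
// time for this worker. Greedy from the text: handle tasks in ascending order
// of start time and never idle while a released task is still waiting.
long long finishTime(std::vector<std::pair<long long, long long>> tasks, long long F) {
    std::sort(tasks.begin(), tasks.end());   // ascending start time
    long long now = 0;                       // time at which the worker becomes free
    for (auto [S, A] : tasks) {
        now = std::max(now, S);              // wait if the task isn't released yet
        now += F * A;                        // work on it without interruption
    }
    return now;
}
```

For example, `finishTime({{1, 2}, {2, 1}}, 1)` is 4: the first task occupies [1, 3) and the second [3, 4).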

Simulating this takes \mathcal{O}(r\log r) time, but rerunning it from scratch is too slow for our purpose: we want the maximum valid r for each worker, and even a linear-time check per attempt leads to an algorithm that is \mathcal{O}(N\cdot M) or worse.
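When testing a faster solution, it can help to have this slow approach written out as a reference. The sketch below uses hypothetical names (`finishTime`, `canFinishNaive`) and assumes 0-indexed arrays; it reruns the per-worker greedy from scratch every time one more task is tried, so it is only suitable for tiny inputs.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

using ll = long long;

// Single-worker finish time via the sort-by-start greedy.
ll finishTime(std::vector<std::pair<ll, ll>> tasks, ll F) {
    std::sort(tasks.begin(), tasks.end());
    ll now = 0;
    for (auto [S, A] : tasks) now = std::max(now, S) + F * A;
    return now;
}

// Slow reference check for a fixed X: each worker takes the longest prefix of
// the remaining tasks it can finish by X, re-simulating from scratch every time
// one more task is tried. Quadratic or worse per X, so tiny inputs only.
bool canFinishNaive(const std::vector<ll>& S, const std::vector<ll>& A,
                    const std::vector<ll>& F, ll X) {
    std::size_t p = 0;                               // first unassigned task
    for (ll f : F) {
        std::vector<std::pair<ll, ll>> block;        // tasks given to this worker
        while (p < S.size()) {
            block.push_back({S[p], A[p]});
            if (finishTime(block, f) > X) break;     // task p doesn't fit; next worker
            ++p;
        }
        if (p == S.size()) return true;
    }
    return false;
}
```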

Instead, we observe how things change when one more task is added.
That is, suppose we know that the worker can complete the first r-1 tasks, and try to see what happens when the r-th task is added.

The r-th task can be started at time S_r, and needs A_r\cdot F_1 time.
To fit it in, we find the first free instant at a time \geq S_r, say t, and take up A_r\cdot F_1 units of time starting from t, which can be thought of as placing the interval [t, t + F_1\cdot A_r).

If some already-placed intervals intersect this one, they would have to be “moved” rightward to make room, which seems slow since up to r-1 intervals might need to move.

However, there’s a simpler way to look at it: starting from time t, the next F_1\cdot A_r free time units simply become occupied. After all, we don’t care exactly which job is done when, only which time units are free and which aren’t.
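To make this observation concrete, here is a toy unit-slot version, assuming integer times and a small horizon. The struct `SlotBoard` is a hypothetical helper, not part of any reference solution: adding a task just marks the first `len` free slots at or after `S`, without recording which task runs where.

```cpp
#include <cassert>
#include <vector>

// Toy model: slot k stands for the time unit [k, k + 1). Adding a task with
// start S and length len marks the first len free slots at or after S as busy;
// which task runs in which slot is never tracked. Only for small horizons.
struct SlotBoard {
    std::vector<bool> busy;
    explicit SlotBoard(int horizon) : busy(horizon, false) {}

    // Returns false if fewer than len free slots exist at or after S
    // (the board is then partially filled and should be discarded).
    bool add(int S, int len) {
        for (int k = S; k < (int)busy.size() && len > 0; ++k) {
            if (!busy[k]) {
                busy[k] = true;
                --len;
            }
        }
        return len == 0;
    }

    // Finish time: end of the last busy slot, or 0 if nothing is busy.
    int finish() const {
        for (int k = (int)busy.size() - 1; k >= 0; --k)
            if (busy[k]) return k + 1;
        return 0;
    }
};
```

Adding the task (start 2, length 1) and then (start 1, length 2) gives finish time 4, the same as the sort-by-start greedy gives for these two tasks, even though they were added out of order.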

This allows us to quickly simulate the process by storing the free timeslots as intervals.
That is, start with the single interval [1, X), which is the entire range.
Then, for each start point S_r and length F_1\cdot A_r,

  1. Find the first free interval that contains a point \geq S_r.
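  2. Starting from this interval, consume free intervals from left to right (ignoring any part before S_r) until F_1\cdot A_r units of time have been covered.
  3. Keep the set of free intervals updated: reinsert the part of the first interval that lies before S_r, and the unused tail of the last interval, if any.
  4. If we ever run out of space (i.e. would have to go beyond X), this worker can’t take on this task. Move to the next worker, reset the free set to [1, X), and retry the same task.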
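A minimal sketch of steps 1 to 4 for a single task, assuming free time is kept in a `std::map` from the left endpoint to the right endpoint of each free half-open interval. The editorialist’s code uses a `std::set` of pairs instead, and the function name `place` is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>
#include <map>

using ll = long long;

// Free time is a set of disjoint half-open intervals [l, r), stored as l -> r.
// place() consumes the first len free time units at or after S.
// Returns false if free time runs out (i.e. the task would end after X);
// the map is then partially consumed, and the caller resets it to {[1, X)}
// before trying the next worker.
bool place(std::map<ll, ll>& freeSet, ll S, ll len) {
    // Step 1: first free interval containing a point >= S.
    auto it = freeSet.upper_bound(S);                              // first interval with l > S
    if (it != freeSet.begin() && std::prev(it)->second > S) --it;  // one containing S
    while (it != freeSet.end() && len > 0) {
        ll l = it->first, r = it->second;
        ll from = std::max(l, S);                  // usable part is [from, r)
        ll take = std::min(len, r - from);
        len -= take;
        it = freeSet.erase(it);                    // step 2: consume this interval
        if (l < from) freeSet[l] = from;           // step 3: keep the part before S
        if (from + take < r) freeSet[from + take] = r;  // ...and any unused tail
    }
    return len == 0;                               // step 4: false means no room before X
}
```

Starting from {[1, 10)}, placing (S = 2, length 1) and then (S = 1, length 2) leaves only [4, 10) free, so this worker would finish at time 4.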

Deleting several free intervals for one task may look slow, but only \mathcal{O}(N + M) set operations are performed in total, across all tasks.
This is because each attempt to place a task deletes some existing free intervals and creates at most two new ones: one when the interval containing S_r is split at S_r itself, and one for the last interval considered, which may not be used up entirely.

Apart from that, moving to a new worker resets the set: all existing free intervals are deleted and one new interval [1, X) is created.

An interval can only be deleted after it has been created. There are at most N + M placement attempts (each failed attempt moves on to the next worker), so at most 2(N + M) + M intervals are ever created, and the overall number of interval insertions and deletions amortizes to \mathcal{O}(N + M).

If the intervals are stored in an appropriate data structure (such as std::set in C++), insertion, deletion, and finding the next free interval can all be done in \mathcal{O}(\log N) time.


So, in \mathcal{O}((N+M)\log N) time, we can decide whether all jobs can be completed by time X.
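Putting the pieces together, the full check for a fixed X might look like the sketch below. The function names `place` and `canFinishBy` are hypothetical, and times are `long long` since X can reach about 10^{11}. It follows the procedure described above rather than reproducing any reference solution line by line.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <map>
#include <vector>

using ll = long long;

// Consume the first len free time units at or after S from the free
// half-open intervals [l, r) stored as l -> r. False if they run out.
bool place(std::map<ll, ll>& freeSet, ll S, ll len) {
    auto it = freeSet.upper_bound(S);                              // first interval with l > S
    if (it != freeSet.begin() && std::prev(it)->second > S) --it;  // one containing S
    while (it != freeSet.end() && len > 0) {
        ll l = it->first, r = it->second;
        ll from = std::max(l, S);
        ll take = std::min(len, r - from);
        len -= take;
        it = freeSet.erase(it);
        if (l < from) freeSet[l] = from;                // part before S stays free
        if (from + take < r) freeSet[from + take] = r;  // unused tail stays free
    }
    return len == 0;
}

// Can all tasks be finished by time X? Tasks are 0-indexed: S[i], A[i];
// F[j] is worker j's deficiency. Each worker greedily takes the longest
// prefix of the remaining tasks that fits into [1, X).
bool canFinishBy(const std::vector<ll>& S, const std::vector<ll>& A,
                 const std::vector<ll>& F, ll X) {
    std::size_t p = 0;                                  // next unassigned task
    for (ll f : F) {
        std::map<ll, ll> freeSet{{1, X}};               // new worker: [1, X) is free
        while (p < S.size() && place(freeSet, S[p], f * A[p])) ++p;
        if (p == S.size()) return true;                 // every task assigned
    }
    return false;                                       // ran out of workers
}
```

A failed `place` leaves `p` unchanged, so the next worker retries the same task on a fresh free set, as in step 4.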

If all jobs can be completed by X, they can certainly be completed by X+1 (with the exact same assignment to workers).
This allows us to binary search on the smallest possible time X at which completion is possible, which is our answer.
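The binary search itself is a standard “smallest value for which a monotone predicate holds” search. In the sketch below, `smallestFeasible` is a hypothetical helper; in the actual solution, `ok` would be the fixed-X check. With an upper bound of 10^{12} it needs about 40 iterations.

```cpp
#include <cassert>
#include <functional>

using ll = long long;

// Smallest x in [lo, hi] with ok(x) == true, assuming ok is monotone
// (false, ..., false, true, ..., true) and ok(hi) is true.
ll smallestFeasible(ll lo, ll hi, const std::function<bool(ll)>& ok) {
    while (lo < hi) {
        ll mid = lo + (hi - lo) / 2;
        if (ok(mid)) hi = mid;   // mid works, so the answer is at most mid
        else lo = mid + 1;       // mid fails, so the answer is larger
    }
    return lo;
}
```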

The upper bound for the binary search can safely be set to a bit above 10^{11}, since in the absolute worst case, there’s only one worker with a deficiency of 10, and 10^5 tasks that all need 10^5 time to complete (so 10^6 for this worker) and can’t be started before time 10^5, leading to a finish time of about 10^5 + 10^6\cdot 10^5 = 10^{11} + 10^5.

As a result, 10^{12} is a safe upper bound.
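The same bound can be written for any input by giving every task to worker 1, which is always a valid non-decreasing assignment. A worker that never idles while a released task is waiting is never idle after the last start time, so it finishes by \max_i S_i plus its total work. Plugging in the limits used above (S_i, A_i \le 10^5, F_1 \le 10, N \le 10^5):

$$
\text{answer} \;\le\; \max_i S_i + F_1 \sum_{i=1}^{N} A_i \;\le\; 10^5 + 10 \cdot 10^5 \cdot 10^5 \;=\; 10^{11} + 10^5 \;<\; 10^{12}
$$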

TIME COMPLEXITY:

\mathcal{O}((N + M)\log N\log{10^{12}}) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

struct segtree{
    struct node{
        int x = 0;
        int lz = 0;
 
        void apply(int l, int r, int y){
            x += y;
            lz += y;
        }
    };
 
    int n;
    vector <node> seg;
 
    node unite(node a, node b){
        node res;
        res.x = max(a.x, b.x);
        return res;
    }
 
    void push(int l, int r, int pos){
        if (l != r){
            int mid = (l + r) / 2;
            seg[pos * 2].apply(l, mid, seg[pos].lz);
            seg[pos * 2 + 1].apply(mid + 1, r, seg[pos].lz);
        }
        
        seg[pos].lz = 0;
    }
 
    void pull(int pos){
        seg[pos] = unite(seg[pos * 2], seg[pos * 2 + 1]);
    }
 
    void build(int l, int r, int pos){
        if (l == r){
            return;
        }
 
        int mid = (l + r) / 2;
        build(l, mid, pos * 2);
        build(mid + 1, r, pos * 2 + 1);
        pull(pos);
    }
 
    template<typename M>
    void build(int l, int r, int pos, vector<M> &v){
        if (l == r){
            seg[pos].apply(l, r, v[l]);
            return;
        }
 
        int mid = (l + r) / 2;
        build(l, mid, pos * 2, v);
        build(mid + 1, r, pos * 2 + 1, v);
        pull(pos);
    }
 
    node query(int l, int r, int pos, int ql, int qr){
        push(l, r, pos);
        if (l >= ql && r <= qr){
            return seg[pos];
        }
        
        int mid = (l + r) / 2;
        node res{};
        if (qr <= mid) res = query(l, mid, pos * 2, ql, qr);
        else if (ql > mid) res = query(mid + 1, r, pos * 2 + 1, ql, qr);
        else res = unite(query(l, mid, pos * 2, ql, qr), query(mid + 1, r, pos * 2 + 1, ql, qr));
        
        pull(pos);
        return res;
    }
 
    template <typename... M>
    void modify(int l, int r, int pos, int ql, int qr, M&... v){
        push(l, r, pos);
        if (l >= ql && r <= qr){
            seg[pos].apply(l, r, v...);
            return;
        }
 
        int mid = (l + r) / 2;
        if (ql <= mid) modify(l, mid, pos * 2, ql, qr, v...);
        if (qr > mid) modify(mid + 1, r, pos * 2 + 1, ql, qr, v...);
 
        pull(pos);
    }
 
    segtree (int _n){
        n = _n;
        seg.resize(4 * n + 1);
        build(1, n, 1);
    }
 
    template <typename M>
    segtree (int _n, vector<M> &v){
        n = _n;
        seg.resize(4 * n + 1);
        if (v.size() == n){
            v.insert(v.begin(), M());
        }
        build(1, n, 1, v);
    }
 
    node query(int l, int r){
        return query(1, n, 1, l, r);
    }
 
    node query(int x){
        return query(1, n, 1, x, x);
    }
 
    template <typename... M>
    void modify(int ql, int qr, M&...v){
        modify(1, n, 1, ql, qr, v...);
    }
    
    template <typename... M>
    void modify(int ql, M&...v){
        modify(1, n, 1, ql, ql, v...);
    }
};

void Solve() 
{
    // consider only things >= time t, then the ending time is at least t + things starting for >= t
    // lazy add + range max 
    
    int n, m; cin >> n >> m;
    // int n = 1e5;
    // int m = 1e5;
    
    vector <int> t(n), s(n), f(m);
    for (int i = 0; i < n; i++){
        cin >> t[i];
        // t[i] = 1 + RNG() % n;
    }
    for (int i = 0; i < n; i++){
        cin >> s[i];
        // s[i] = 1 + RNG() % n;
    }
    for (int i = 0; i < m; i++){
        cin >> f[i];
        // f[i] = 1 + RNG() % 10;
    }
    
    segtree seg(n);
    for (int i = 1; i <= n; i++){
        seg.modify(i, i, i);
    }
    
    auto check = [&](int x){
        int p = 0;
        
        for (int i = 0; i < m; i++){
            vector <pair<int, int>> add;
            int mx = 0;
            while (p < n){
                int ok = t[p] * f[i];
                seg.modify(1, s[p], ok);
                add.push_back({s[p], ok});
                mx = max(mx, s[p]);
                
                if (seg.query(1, mx).x > x){
                    break;
                } else {
                    p++;
                }
            }
            
            for (auto pi : add){
                int x = pi.first;
                int y = pi.second;
                y = -y;
                seg.modify(1, x, y);
            }
        }
        
        if (p == n){
            return true;
        } else {
            return false;
        }
    };
    
    int lo = 0, hi = 1e12;
    while (hi != lo){
        int mid = (lo + hi) / 2;
        
        if (check(mid)){
            hi = mid;
        } else {
            lo = mid + 1;
        }
    }
    
    cout << lo << "\n";
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}
Tester's code (C++)
#include<bits/stdc++.h>
using namespace std;

#ifdef LOCAL
#include "../debug.h"
#else
#define dbg(...)
#endif

#define int int64_t

#ifdef LOCAL
struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};
#else

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
    }

    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int minl, int maxl, const string &pattern = "") {
      string X; cin >> X;
      return X;
    }

    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res;  cin >> res;
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res;  cin >> res;
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    void readSpace() {
    }

    void readEoln() {
    }

    void readEof() {
    }
};
#endif


int32_t main() {
  ios_base::sync_with_stdio(0);
  cin.tie(0);

  input_checker inp;
  int T = inp.readInt(1, (int)1e4), NN = 0, MM = 0; inp.readEoln();
  while(T-- > 0) {
    int N = inp.readInt(1, (int)1e5); inp.readSpace();
    int M = inp.readInt(1, (int)1e5); inp.readEoln();
    NN += N, MM += M;

    vector<int> A = inp.readInts(N, 1, N);  inp.readEoln();
    vector<int> S = inp.readInts(N, 1, N);  inp.readEoln();
    vector<int> F = inp.readInts(M, 1, 10); inp.readEoln();

    set<pair<int, int>> os;
    auto insert = [&](int l, int size, int limit) -> bool {
      if(os.empty()) {
        if(limit < l + size - 1) {
          return false;
        }
        os.insert({l, l + size - 1});
        return true;
      }

      while(size > 0) {
        auto it = os.lower_bound(make_pair(l + 1, 0));
        if(it == os.begin()) {
          int add = it -> first - l;
          if(add > size) {
            os.insert({l, l + size - 1});
            return true;
          }
          size -= add;
          add += it -> second - it -> first;
          os.erase(os.begin());
          os.insert({l, l + add});
          continue;
        }
        auto pv = prev(it);
        if(it == os.end()) {
          auto st = max(pv -> second + 1, l);
          int en = st + size - 1;
          if(st == pv -> second + 1) {
            st = pv -> first;
            os.erase(pv);
          }
          size = 0;
          os.insert({st, en});
          break;
        }
        int st = max(pv -> second + 1, l);
        int add = min(size, it -> first - st);
        int en = st + add - 1;
        if(en == it -> first - 1) {
          en = it -> second;
          os.erase(it);
        }
        if(st == pv -> second + 1) {
          st = pv -> first;
          os.erase(pv);
        }
        os.insert({st, en});
        size -= add;
      }
      if(os.rbegin() -> second > limit) return false;
      return true;
    };

    auto solve = [&](int mid) -> bool {
      int p = 0;
      os.clear();
      for(int i = 0 ; i < N ; ++i) {
        if(p == M)  break;
        if(!insert(S[i], A[i] * F[p], mid)) {
          --i;
          ++p;
          os.clear();
          continue;
        }
      }
      return p == M;
    };

    // solve(12);
    // break;

    int64_t lo = 0, hi = 1e12;
    while(hi - lo > 1) {
      int64_t mid = (lo + hi) >> 1;
      if(solve(mid)) {
        lo = mid;
      } else {
        hi = mid;
      }
    }
    cout << 1 + hi << '\n';
  }
  assert(NN <= (int)1e5 && MM <= (int)1e5);
  inp.readEof();
  
  return 0;
}
Editorialist's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    int t; cin >> t;
    while (t--) {
        int n, m; cin >> n >> m;

        vector<int> a(n), s(n), f(m);
        for (int &x : a) cin >> x;
        for (int &x : s) cin >> x;
        for (int &x : f) cin >> x;
        reverse(begin(f), end(f));

        ll lo = 0, hi = 1.5e11;
        while (lo < hi) {
            ll mid = (lo + hi) / 2;
            // Check: all work must fit in time units [1, mid), i.e. every task finishes by time mid

            set<array<ll, 2>> active = {{1, mid}}; // Segments are [l, r)
            int p = m - 1;
            for (int i = 0; p >= 0 and i < n; ++i) {
                vector<array<ll, 2>> to_insert;
                auto update = [&] () {
                    ll L = s[i], len = a[i] * f[p];
                    to_insert.clear();
                    
                    auto it = active.lower_bound({L+1, -1});
                    if (it != begin(active) and prev(it)->at(1) > L) --it;
                    
                    while (it != end(active) and len > 0) {
                        auto [l, r] = *it;
                        ll reduce = r - max(l, L);
                        reduce = min(reduce, len);

                        len -= reduce;
                        if (l < L) to_insert.push_back({l, L});
                        if (len == 0 and reduce < r - max(l, L)) to_insert.push_back({max(l, L) + reduce, r});
                        it = active.erase(it);
                    }
                    return len == 0;
                };

                while (p >= 0) {
                    bool done = update();
                    if (done) break;

                    active = {{1, mid}};
                    --p;
                }

                for (auto x : to_insert) active.insert(x);
            }

            if (p == -1) lo = mid + 1;
            else hi = mid;
        }
        cout << lo << '\n';
    }
}