GAMEPROF - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3

Author: Nguyen Ha Duy
Testers: Shubham Anand Jain, Aryan
Editorialist: Nishank Suresh

DIFFICULTY:

Easy-medium

PREREQUISITES:

Segment Tree

PROBLEM:

You are playing a game, where entering at time l and exiting at time r costs k\cdot (r-l) coins. You can enter and exit the game at most once.
There are n items in the game; the i-th of them is valued at v_i coins and can only be collected if l\leq x_i\leq y_i\leq r.
What is the maximum profit you can make?

QUICK EXPLANATION:

  • It is optimal to enter only at some x_i, and exit at some y_j.
  • Consider all possible y_j in increasing order, and for each of them compute the maximum profit. The final answer is the maximum over all of these.
  • Calculating the maximum profit can be done with any data structure supporting range max queries and range add updates, for example a segment tree with lazy propagation.

EXPLANATION:

First, note that it is optimal to exit only at some y_i - staying any longer would incur a cost of k coins for each unit of time, which is not optimal.
The same applies to entering - it is optimal to do so at some x_i.

Subtask 1
n \leq 20, so an \mathcal{O}(n\cdot 2^n) brute force over all possible subsets of items to choose will work fast enough.
Given a fixed non-empty set of items (x_{i_1}, y_{i_1}, v_{i_1}), \dotsc, (x_{i_m}, y_{i_m}, v_{i_m}), let L = \min(x_{i_1}, x_{i_2}, \dotsc, x_{i_m}) and R = \max(y_{i_1}, y_{i_2}, \dotsc, y_{i_m}). The maximum profit obtained while collecting all of them is v_{i_1} + v_{i_2} + \dotsc + v_{i_m} - k\cdot (R-L).
Compute this value over all possible subsets and print the maximum of this and 0 (which corresponds to entering and exiting at the same point and picking up nothing).
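
The subset enumeration above can be sketched as follows. This is only an illustration, not the setter's code: the `Item` struct and the function name `bruteForceProfit` are made up for this example, and the input is assumed to be already parsed into a vector.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <vector>

struct Item { long long x, y, v; };  // hypothetical container for (x_i, y_i, v_i)

// Brute force for Subtask 1: try every non-empty subset of items.
// The result is never below 0, which corresponds to picking up nothing.
long long bruteForceProfit(const std::vector<Item>& items, long long k) {
    const int n = static_cast<int>(items.size());
    assert(n <= 20);  // 2^n subsets, so this is only feasible for small n
    long long best = 0;
    for (int mask = 1; mask < (1 << n); ++mask) {
        long long lo = LLONG_MAX, hi = LLONG_MIN, sum = 0;
        for (int i = 0; i < n; ++i) {
            if ((mask >> i) & 1) {
                lo = std::min(lo, items[i].x);  // L = smallest x in the subset
                hi = std::max(hi, items[i].y);  // R = largest y in the subset
                sum += items[i].v;
            }
        }
        best = std::max(best, sum - k * (hi - lo));
    }
    return best;
}
```

A function like this is also handy as a reference answer when stress testing a faster solution on small random inputs.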

Subtask 2
Suppose, for the moment, that all the y_i are distinct.
Let us sort the items by y.
Now, for each 1\leq i\leq n, suppose we want to exit at y_i. What’s the maximum profit possible?
Given that the items are sorted by increasing y, exiting at y_i means that it’s impossible to pick up any of items i+1, i+2, \cdots, n, so we only need to worry about items 1, 2, \cdots, i.
Now suppose we enter at time x. What’s the maximum profit?

  • The cost of entering and exiting is k\cdot(y_i-x)
  • For each 1\leq j\leq i, we can add v_j to the profit if x_j \geq x (because we know that y_j \leq y_i).

But, as mentioned before, it’s enough to consider x to be one of the x_j where j\leq i, because entering at any other time cannot be optimal.
This gives an \mathcal{O}(n^3) solution: for each 1\leq i\leq n, iterate over 1\leq j\leq i and find out in \mathcal{O}(n) the maximum profit if we enter at x_j and leave at y_i. However, this is not fast enough to pass this subtask yet.

One can observe that this solution can be improved to \mathcal{O}(n^2) by noting that, when we fix y_i, it’s enough to traverse the x's we have in descending order. Doing that allows us to keep a running sum of the v_j which are possible to get, without needing a third loop to do so.
This can be implemented by, for example, keeping a sorted (multi)set of pairs (x_i, v_i). Iterating over this in descending order is easy, and insertion is \mathcal{O}(\log n), leading to an overall \mathcal{O}(n^2) solution.
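
A minimal sketch of this quadratic idea is given below. The names `Item` and `quadraticProfit` are hypothetical, and the sketch assumes all item values are non-negative, so that taking every available item with x_j \geq x is never worse than skipping one.

```cpp
#include <algorithm>
#include <cassert>
#include <set>
#include <utility>
#include <vector>

struct Item { long long x, y, v; };  // hypothetical container for (x_i, y_i, v_i)

// O(n^2) sketch: process items by increasing y, and for each exit time y_i
// walk the already-seen (x_j, v_j) pairs in descending x with a running sum.
long long quadraticProfit(std::vector<Item> items, long long k) {
    std::sort(items.begin(), items.end(),
              [](const Item& a, const Item& b) { return a.y < b.y; });
    std::multiset<std::pair<long long, long long>> seen;  // (x_j, v_j) for j <= i
    long long best = 0;  // entering and exiting at the same time earns 0
    for (const Item& cur : items) {
        assert(cur.x <= cur.y);
        seen.insert({cur.x, cur.v});
        long long sum = 0;  // total value of items with x_j >= current entry point
        for (auto it = seen.rbegin(); it != seen.rend(); ++it) {
            sum += it->second;
            // Enter at x = it->first, exit at y = cur.y.
            best = std::max(best, sum + k * it->first - k * cur.y);
        }
    }
    return best;
}
```

When several items share the same x, the candidate computed at the last of them in the descending walk includes all of them, and the earlier partial candidates can only be smaller under the non-negativity assumption.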

How to deal with the fact that the y_i might not be distinct? Simply ignore it!
Even if they’re not all distinct, the above process will give the correct answer, simply because the maximum profit possible when exiting at any y will be calculated correctly when the last item with y_i = y is being processed.

Subtask 3
Quadratic is too slow here. Let’s look at what we’re actually computing.
The profit for entering at x and leaving at y_i is given by -k\cdot (y_i-x) + \displaystyle\sum_{\substack{1\leq j\leq i \\ x_j \geq x}} v_j.
Rewrite this as -k\cdot y_i + k\cdot x + \displaystyle\sum_{\substack{1\leq j\leq i \\ x_j \geq x}} v_j.
k\cdot y_i is a constant depending only on y_i, so we can ignore it for now.
What about the rest? It depends purely on x, so if we were able to update this value for each x after processing an item, maybe we could do something to get the maximum of them all.
And that’s exactly what we will do!

Suppose we had a really large array A, indexed by the integers -10^9 to 10^9.
Initially, let A[i] = k\cdot i.
Whenever we process an item (x, y, v), do a range update and add v to the range [-10^9, x].
Now when we fix y_i and a starting point x, what’s the profit of entering at x? It is exactly A[x] - k\cdot y_i.
So, our problem reduces to two operations: adding a value to a range, and finding the maximum A[x] where x\leq y_i (equivalently, the max on the range [-10^9, y_i]).
A segment tree with lazy propagation supports both these operations in logarithmic time, so that is our data structure of choice.
(Don’t know how to perform these operations using a segment tree? Codeforces EDU is for you)
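
For readers who want something concrete, here is one possible sketch of a lazy segment tree with range add and range max over indices 0..n-1. The struct name `LazyMaxTree` and its methods `add` and `maxOn` are chosen for this example only and do not come from any of the solutions below.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <vector>

// Range add / range max segment tree over indices 0..n-1 (illustrative sketch).
struct LazyMaxTree {
    int n;
    std::vector<long long> mx, lazy;  // mx: subtree maximum, lazy: pending addition

    explicit LazyMaxTree(const std::vector<long long>& init)
        : n(static_cast<int>(init.size())), mx(4 * init.size()), lazy(4 * init.size(), 0) {
        assert(n > 0);
        build(1, 0, n - 1, init);
    }

    void build(int node, int lo, int hi, const std::vector<long long>& init) {
        if (lo == hi) { mx[node] = init[lo]; return; }
        int mid = lo + (hi - lo) / 2;
        build(2 * node, lo, mid, init);
        build(2 * node + 1, mid + 1, hi, init);
        mx[node] = std::max(mx[2 * node], mx[2 * node + 1]);
    }

    void apply(int node, long long v) { mx[node] += v; lazy[node] += v; }

    // Hand the pending addition down to both children.
    void push(int node) {
        if (lazy[node] != 0) {
            apply(2 * node, lazy[node]);
            apply(2 * node + 1, lazy[node]);
            lazy[node] = 0;
        }
    }

    void add(int l, int r, long long v, int node, int lo, int hi) {
        if (r < lo || hi < l) return;
        if (l <= lo && hi <= r) { apply(node, v); return; }
        push(node);
        int mid = lo + (hi - lo) / 2;
        add(l, r, v, 2 * node, lo, mid);
        add(l, r, v, 2 * node + 1, mid + 1, hi);
        mx[node] = std::max(mx[2 * node], mx[2 * node + 1]);
    }

    long long maxOn(int l, int r, int node, int lo, int hi) {
        if (r < lo || hi < l) return LLONG_MIN;  // empty overlap
        if (l <= lo && hi <= r) return mx[node];
        push(node);
        int mid = lo + (hi - lo) / 2;
        return std::max(maxOn(l, r, 2 * node, lo, mid),
                        maxOn(l, r, 2 * node + 1, mid + 1, hi));
    }

    // Public entry points: add v on [l, r], and maximum on [l, r].
    void add(int l, int r, long long v) { add(l, r, v, 1, 0, n - 1); }
    long long maxOn(int l, int r) { return maxOn(l, r, 1, 0, n - 1); }
};
```

For instance, starting from A = {0, 2, 4, 6} (that is, A[i] = k\cdot i with k = 2 on a tiny index range) and adding 10 to the prefix [0, 1] gives {10, 12, 4, 6}, so the maximum over the whole array becomes 12.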

We are left with only one issue now: the initial array. It is way too big for us to explicitly create it, even if all subsequent operations on it are fast.
How to overcome this?

Editorialist's Approach

The editorialist used a dynamic (implicit) segment tree, i.e., a segment tree where the only nodes created are the ones used in some query or update.
Since the only points we care about querying are the x_i, create an implicit segment tree with the values of all points set to -\infty. Then, for all x_i, set the value at x_i to be k\cdot x_i.
After this, update and query the tree as mentioned above, and the problem is done.

Setter's Approach

The setter chose a different approach, and instead used coordinate compression. Map the smallest x_i to 1, the second smallest to 2, and so on. This maps all relevant points to [1, 10^5], and now we can build a normal segment tree on this array and update/query it as necessary. Finding the range to query/update requires binary search on the compressed coordinates.
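
The compression step might look roughly like the sketch below. The helper names `compressCoords` and `lastIndexAtMost` are invented for this illustration, and the indices here are 0-based rather than the 1-based mapping described above.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Sort the coordinates and drop duplicates; position i in the result is the
// compressed index of the i-th smallest distinct coordinate.
std::vector<long long> compressCoords(std::vector<long long> xs) {
    std::sort(xs.begin(), xs.end());
    xs.erase(std::unique(xs.begin(), xs.end()), xs.end());
    return xs;
}

// Index of the last compressed coordinate that is <= value, or -1 if none.
// A prefix [-10^9, value] of the huge array then maps to indices [0, result].
int lastIndexAtMost(const std::vector<long long>& sortedXs, long long value) {
    assert(std::is_sorted(sortedXs.begin(), sortedXs.end()));
    auto it = std::upper_bound(sortedXs.begin(), sortedXs.end(), value);
    return static_cast<int>(it - sortedXs.begin()) - 1;
}
```

With coordinates {5, -3, 5, 100} the compressed list is {-3, 5, 100}, and a prefix ending at 7 covers compressed indices 0 and 1.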

Both methods use the fact that we only care about \mathcal{O}(N) points.

TIME COMPLEXITY

\mathcal{O}(N\log{M}) or \mathcal{O}(N\log{N}) depending on implementation, where M = 2\cdot 10^9.

SOLUTIONS:

Setter's Solution
#include<bits/stdc++.h>
using namespace std;
 
const long long INF = 2e18;
 
const int maxN = 1e5 + 69;
 
struct items {
    int x, y, v;
    bool operator < (const items& oth) const {
        return y < oth.y;
    }
};
 
struct SegmentTree {
    int n;
    vector<long long> st;
    vector<long long> lazy;
 
    SegmentTree(int n) : n(n), st(n * 4, -INF), lazy(n * 4) {}
    SegmentTree() {}
 
    void change(long long v, int id) {
        st[id] += v;
        lazy[id] += v;
    }
 
    void push(int id) {
        change(lazy[id], id * 2);
        change(lazy[id], id * 2 + 1);
 
        lazy[id] = 0;
    }
 
    void update(int L, int R, long long v, int id, int l, int r) {
        if (R < l || r < L) return;
 
        if (L <= l && r <= R) {
            change(v, id);
            return;
        }
 
        push(id);
 
        int mid = (l + r) / 2;
        update(L, R, v, id * 2, l, mid);
        update(L, R, v, id * 2 + 1, mid + 1, r);
 
        st[id] = max(st[id * 2], st[id * 2 + 1]);
    }
 
    void update(int L, int R, long long v) {
        update(L, R, v, 1, 0, n);
    }
};
 
int n, k;
int m;
items a[maxN];
vector<int> vals;
SegmentTree st;
 
int main() {
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
 
    cin >> n >> k;
    for (int i = 1; i <= n; i++) {
        cin >> a[i].x >> a[i].y >> a[i].v;
 
        vals.push_back(a[i].x);
        vals.push_back(a[i].y);
    }
 
    sort(vals.begin(), vals.end());
    vals.resize(unique(vals.begin(), vals.end()) - vals.begin());
    m = vals.size();
 
    sort(a + 1, a + 1 + n);
 
    st = SegmentTree(m);
 
    long long ans = 0;
    int cur = 0;
 
    for (int r = 0; r < m; r++) {
        st.update(r, r, INF + 1ll * k * vals[r]);
 
        while (cur + 1 <= n && a[cur + 1].y <= vals[r]) {
            cur++;
 
            if (a[cur].v > 0) {
                int id = lower_bound(vals.begin(), vals.end(), a[cur].x) - vals.begin();
                st.update(0, id, a[cur].v);
            }
        }
 
        ans = max(ans, st.st[1] - 1ll * k * vals[r]);
    }
 
    cout << ans<<"\n";
}
Tester's Solution
//By TheOneYouWant
#pragma GCC optimize ("-O2")
#include <bits/stdc++.h>
using namespace std;
#define fastio ios_base::sync_with_stdio(0);cin.tie(0)
#define all(x) x.begin(),x.end()
#define forstl(i,v) for(auto &i: v)
#define forn(i,e) for(int i=0;i<e;++i)
#define ln '\n'
typedef long long ll;
typedef pair<int,int> p32;
typedef vector<long long int> v64; 
 
struct Segtree{
 
    v64 t, lazy;
 
    Segtree(int n) {
        t.assign(4*n, -1e18);
        lazy.assign(4*n, 0);
    }
 
    void build(long long int a[], int v, int tl, int tr){
        if(tl == tr){
            t[v] = a[tl];
        }
        else{
            int tm = (tl + tr)/2;
            build(a, v*2, tl, tm);
            build(a, v*2+1, tm+1, tr);
            t[v] = max(t[v*2], t[v*2+1]);
        }
    }
 
    void push(int v){
        t[v*2] += lazy[v];
        lazy[v*2] += lazy[v];
        t[v*2+1] += lazy[v];
        lazy[v*2+1] += lazy[v];
        lazy[v] = 0;
    }
 
    void update(int v, int tl, int tr, int l, int r, long long int addend){
        if(l > r){
            return;
        }
        if(l <= tl && tr <= r){
            t[v] += addend;
            lazy[v] += addend;
        } else {
            push(v);
            int tm = (tl + tr)/2;
            update(v*2, tl, tm, l, min(r, tm), addend);
            update(v*2+1, tm+1, tr, max(l, tm+1), r, addend);
            t[v] = max(t[v*2], t[v*2+1]);
        }
    }
 
    long long int query(int v, int tl, int tr, int l, int r) {
        if(l>r) return -1e18;
        if(l <= tl && tr <= r){
            return t[v];
        }
        push(v);
        int tm = (tl + tr)/2;
        return max(query(v*2, tl, tm, l, min(r, tm)), query(v*2+1, tm+1, tr, max(l, tm+1), r));
    }
};
 
 
signed main(){
    fastio;
 
    long long int n, k;
    cin>>n>>k;
 
    vector<tuple<int,int>> events;
    set<int> val;
 
    vector<tuple<int,int,int>> things(n);
 
    forn(i,n){
        int x, y, v;
        cin>>x>>y>>v;
        events.push_back({y, i});   
        things[i] = {x, y, v};
        val.insert(x);
        val.insert(y);
    }
    sort(all(events));
 
    map<int,int> to;
    int cnt = 0;
    forstl(r, val){
        to[r] = cnt;
        cnt++;
    }
 
    Segtree s(cnt);
    vector<long long int> a(cnt);
 
    forstl(r, val){
        a[to[r]] = k * r;
    }
 
    s.build(a.data(), 1, 0, cnt-1);
    long long int ans = 0;
 
    forstl(r, events){
        long long int pt, ind;
        tie(pt, ind) = r;
        long long int x, y, v;
        tie(x, y, v) = things[ind];
        long long int get = s.query(1, 0, cnt-1, 0, to[x]) + v;
        get -= k * y;
        ans = max(ans, get);
        s.update(1, 0, cnt-1, 0, to[x], v);
    }
    cout<<ans<<ln;
 
    return 0;
}
Editorialist's Solution
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,mmx,avx,avx2")
using namespace std;
using ll = long long;
 
struct Node {
    typedef ll T;
    static constexpr T unit = -2e18;
    T f(T a, T b) { return max(a, b); }
 
    Node *l = 0, *r = 0;
    ll lo, hi, mset = unit, madd = 0;
    T val = unit;
    Node(int _lo,int _hi):lo(_lo),hi(_hi){}
    T query(int L, int R) {
        if (R <= lo || hi <= L) return unit;
        if (L <= lo && hi <= R) return val;
        push();
        return f(l->query(L, R), r->query(L, R));
    }
    void set(int L, int R, T x) {
        if (R <= lo || hi <= L) return;
        if (L <= lo && hi <= R) mset = val = x, madd = 0;
        else {
            push(), l->set(L, R, x), r->set(L, R, x);
            val = f(l->val, r->val);
        }
    }
    void add(int L, int R, T x) {
        if (R <= lo || hi <= L) return;
        if (L <= lo && hi <= R) {
            if (mset != unit) mset += x;
            else madd += x;
            val += x;
        }
        else {
            push(), l->add(L, R, x), r->add(L, R, x);
            val = f(l->val, r->val);
        }
    }
    void push() {
        if (!l) {
            int mid = lo + (hi - lo)/2;
            l = new Node(lo, mid); r = new Node(mid, hi);
        }
        if (mset != unit)
            l->set(lo,hi,mset), r->set(lo,hi,mset), mset = unit;
        else if (madd)
            l->add(lo,hi,madd), r->add(lo,hi,madd), madd = 0;
    }
};
 
int main()
{
    ios::sync_with_stdio(0); cin.tie(0);
    mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());
 
    int n, k; cin >> n >> k;
    Node *tr = new Node(-1000000007, 1000000007);
    vector<array<int, 3>> a(n);
    for (auto &[x, y, v] : a) {
        cin >> x >> y >> v;
        tr -> set(x, x+1, 1LL*k*x);
        tr -> set(y, y+1, 1LL*k*y);
    }
    sort(begin(a), end(a), [](auto x, auto y){return x[1] < y[1];});
    ll ans = 0;
    for (auto [x, y, v] : a) {
        // Best answer ending at y, with everything considered so far
        tr -> add(-1000000007, x+1, v);
        ans = max(ans, tr->query(-1000000007, y+1)-1LL*k*y);
    }
    cout << ans << '\n';
}

https://www.codechef.com/viewsolution/44228893
I was trying to get subtask 2. I implemented almost the same approach as in the editorial: I sorted the intervals, fixed the entry or exit, and then found the maximum profit for the intervals before the fixed position. But I am getting a wrong answer. Can anybody help me find where I am going wrong?

@iceknight1093
I was trying to implement the approach explained in the editorial, but I couldn’t figure out what’s wrong in my code. I have stress tested my solution a lot but still couldn’t find a counter test case.
Can you please help me with this?

Code : CodeChef: Practical coding for everyone

Hi guys, can anyone help me find the issue with my code? I tried an n^2 approach almost identical to the editorial’s and expected it to pass up to subtask 2, but it didn’t.

Submission link for reference.

https://www.codechef.com/viewsolution/44225250

@basant07 @p1god
Here’s a testcase where both of your solutions fail (the answer is 191: enter at 1 and exit at 10).

3 1
1 5 100
6 10 100
3 20 1

Thanks, man. I understand now why we should choose r as the starting point and iterate over l. However, I also tried fixing r and ran into an issue. Can you provide a counter case for that as well? It would be really helpful for identifying the mistakes I have been making.
https://www.codechef.com/viewsolution/44247585

Found the issue. I was expecting l to be sorted, but that may or may not be true.

I tried the segment tree approach, but it’s giving SIGSEGV. Any idea where I’m accessing a wrong memory location?

Not entirely sure what you’re doing, but there are several issues with that code. Here are at least three of them:

  1. It RTE’s on inputs with large indices.
Input
3 2
1 1000000000 3
1 2 5
4 7 1
  2. All your variables are int, but the answer doesn’t necessarily fit in a 32-bit integer.
  3. You don’t seem to be handling negative indices properly.
Input
1 2
-2 -1 10