BDAYPARTY - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: sushil2006
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

Medium

PREREQUISITES:

Segment trees

PROBLEM:

There are K cakes.
For each i = 1, 2, \ldots, N, on the i-th second, cake A_i is eaten if it hasn't been eaten already.
There are also M intervals [L_i, R_i].
Exactly once, at the start of some second t (1 \leq t \leq N), you choose an index i and replenish the cakes L_i, L_i + 1, \ldots, R_i, so that all of them are uneaten again.

Your score is (N - t + 1) \cdot C, where C is the number of cakes that remain uneaten after all N seconds.
Find the maximum possible score.
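To make the statement concrete, here is a brute-force reference sketch that tries every pair (t, interval) and simulates all N seconds directly. It is only meant for tiny inputs or for cross-checking a faster solution. The function name bruteForce and the input conventions are assumptions for illustration: cakes are numbered 1 to K, A is a 0-indexed std::vector<int> where A[s - 1] is the cake eaten on second s, and intervals are (L, R) pairs.

```cpp
#include <algorithm>
#include <utility>
#include <vector>

// Brute-force reference for tiny inputs: tries every (t, interval) pair and
// simulates all N seconds directly, in O(N * M * (N + K)) time.
// Assumed conventions (hypothetical): cakes are numbered 1..k,
// A[s - 1] is the cake eaten on second s, intervals are (L, R) pairs.
long long bruteForce(int n, int k, const std::vector<int>& A,
                     const std::vector<std::pair<int, int>>& intervals) {
    long long best = 0;
    for (int t = 1; t <= n; ++t) {
        for (auto [L, R] : intervals) {
            std::vector<bool> eaten(k + 1, false);
            for (int s = 1; s <= n; ++s) {
                if (s == t)  // replenish [L, R] at the start of second t
                    for (int j = L; j <= R; ++j) eaten[j] = false;
                eaten[A[s - 1]] = true;  // cake A_s is eaten (no-op if already eaten)
            }
            long long C = std::count(eaten.begin() + 1, eaten.end(), false);
            best = std::max(best, (long long)(n - t + 1) * C);
        }
    }
    return best;
}
```

For example, with N = 6, K = 4, A = (3, 4, 2, 1, 1, 1) and the intervals [1, 2], [3, 4], [1, 4], replenishing [1, 4] at t = 4 leaves cakes 2, 3 and 4 uneaten, and the best score is 3 \cdot 3 = 9.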

EXPLANATION:

First, observe that for each cake, only the last time at which it is eaten matters.
If a cake is replenished, it remains at the end if and only if the last time it was eaten is before the replenishment. If it isn't replenished, it remains if and only if it is never eaten at all, which the last time also tells us.

Let l_j denote the last second at which cake j is eaten, with l_j = 0 if cake j is never eaten.
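All the l_j values come from a single pass over A, since a later second simply overwrites an earlier one. A small sketch, assuming cakes are numbered 1 to K (index 0 unused) and A is a 0-indexed std::vector<int> with A[s - 1] eaten on second s; the helper names are hypothetical.

```cpp
#include <vector>

// l[j] = last second at which cake j is eaten, or 0 if it is never eaten.
// Assumed conventions: cakes are numbered 1..k (l[0] is unused) and
// A[s - 1] is the cake eaten on second s.
std::vector<int> lastEatenTimes(int k, const std::vector<int>& A) {
    std::vector<int> l(k + 1, 0);
    for (int s = 1; s <= (int)A.size(); ++s)
        l[A[s - 1]] = s;  // a later second overwrites an earlier one
    return l;
}

// Number of cakes that are never eaten, i.e. those with l[j] == 0.
int countNeverEaten(const std::vector<int>& l) {
    int cnt = 0;
    for (int j = 1; j < (int)l.size(); ++j) cnt += (l[j] == 0);
    return cnt;
}
```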

Next, note that if we have intervals [L_1, R_1] and [L_2, R_2] such that L_1 \leq L_2 \leq R_2 \leq R_1, meaning [L_2, R_2] is completely contained inside [L_1, R_1], then using [L_2, R_2] is never better than using [L_1, R_1] at the same time t: replenishing a superset of the cakes can never leave fewer cakes uneaten.
So, let's discard all such useless intervals (keeping a single copy of any interval that appears more than once).

Let the remaining intervals be [L_1, R_1], [L_2, R_2], \ldots, [L_m, R_m].
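No two of them share a left endpoint (one would contain the other), and once they are sorted in ascending order of L_i, they are automatically sorted in ascending order of R_i as well: if L_i < L_{i+1} but R_{i+1} \leq R_i, the (i+1)-th interval would be contained in the i-th one and would have been discarded.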
This is a rather useful property.
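A minimal sketch of this filtering step, following the same idea as both reference codes: sort by L ascending with ties broken by R descending, then keep an interval only if its R exceeds every R kept so far. The function name removeNested and the std::pair representation are assumptions.

```cpp
#include <algorithm>
#include <utility>
#include <vector>

// Keeps only intervals not contained in another one (one copy of duplicates).
// Sort by L ascending, ties by R descending; an interval survives only if its
// R exceeds every R kept so far, otherwise an earlier kept interval contains it.
// The result has strictly increasing L and strictly increasing R.
std::vector<std::pair<int, int>> removeNested(std::vector<std::pair<int, int>> iv) {
    std::sort(iv.begin(), iv.end(), [](const std::pair<int, int>& p,
                                       const std::pair<int, int>& q) {
        if (p.first != q.first) return p.first < q.first;
        return p.second > q.second;
    });
    std::vector<std::pair<int, int>> kept;
    for (auto [L, R] : iv)
        if (kept.empty() || R > kept.back().second) kept.push_back({L, R});
    return kept;
}
```

Since the kept R values only grow, comparing against kept.back().second is the same as tracking a running maximum of R.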

Let’s fix a time instant t, and try to figure out which interval is the optimal one to choose here.
To do this, we’d like to compute for the i-th interval the value c_i: the number of cakes that will remain if the i-th interval is used at time t.

Looking at interval [L_i, R_i],

  • A cake j with j \lt L_i or j \gt R_i is not replenished, so it remains if and only if it is never eaten (l_j = 0). This does not depend on t.
  • A cake j in [L_i, R_i] remains if and only if it is not eaten at or after second t, so it adds 1 to c_i exactly when l_j \lt t.
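The two cases above translate directly into an \mathcal{O}(K) count for a single interval and a single time t. This sketch only restates the definition of c_i (the actual solution never evaluates it this way); the function name is hypothetical and l is assumed to be indexed 1 to K with l[0] unused.

```cpp
#include <vector>

// Direct O(K) evaluation of c_i for one interval [L, R] replenished at time t.
// l[j] is the last second cake j is eaten (0 = never); l[0] is unused.
int countRemaining(const std::vector<int>& l, int L, int R, int t) {
    int c = 0;
    for (int j = 1; j < (int)l.size(); ++j) {
        bool inside = (L <= j && j <= R);
        // inside: survives iff not eaten at or after second t;
        // outside: survives iff never eaten
        if (inside ? l[j] < t : l[j] == 0) ++c;
    }
    return c;
}
```

For instance, with l_1 = 6, l_2 = 3, l_3 = 1, l_4 = 2, replenishing [1, 4] at t = 4 gives c = 3 (cakes 2, 3 and 4).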

Suppose we’re able to (somehow) compute all the c_i values.
Let’s analyze how they change when moving from time t to time t+1.

First, if t is not the last time at which cake A_{t} is eaten, nothing changes at all.
Otherwise (that is, if l_{A_t} = t), we need to add 1 to c_i for every interval [L_i, R_i] that contains A_t.
This is because if we replenish at time t or earlier, cake A_t is still eaten at second t; but if we replenish at time t+1 or later, the replenished cake A_t is never eaten again and so remains at the end.
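Before bringing in a segment tree, the update rule can be checked with a plain array. The sketch below is an \mathcal{O}(K + NM) version: at t = 1 only never-eaten cakes survive, so every c_i starts at that count; at each t it evaluates (N - t + 1) \cdot c_i for every interval and then, when t = l_{A_t}, adds 1 to every interval containing A_t. The name sweepNaive and the input conventions (cakes 1 to K, A[s - 1] eaten on second s, (L, R) pairs) are assumptions; intervals need not be filtered here.

```cpp
#include <algorithm>
#include <utility>
#include <vector>

// O(K + N * M) sweep that applies the t -> t+1 update rule to a plain array c.
// Assumed conventions: cakes 1..k, A[s - 1] eaten on second s, (L, R) pairs;
// intervals do not need to be filtered or sorted for this version.
long long sweepNaive(int n, int k, const std::vector<int>& A,
                     const std::vector<std::pair<int, int>>& iv) {
    std::vector<int> l(k + 1, 0);  // last eating time, 0 = never eaten
    for (int s = 1; s <= n; ++s) l[A[s - 1]] = s;
    long long never = 0;
    for (int j = 1; j <= k; ++j) never += (l[j] == 0);

    // At t = 1 only never-eaten cakes survive, whichever interval is used.
    std::vector<long long> c(iv.size(), never);
    long long best = 0;
    for (int t = 1; t <= n; ++t) {
        for (long long ci : c) best = std::max(best, (long long)(n - t + 1) * ci);
        int x = A[t - 1];
        if (l[x] != t) continue;  // x is eaten again later: no c_i changes
        for (int i = 0; i < (int)iv.size(); ++i)  // x survives from t+1 onwards
            if (iv[i].first <= x && x <= iv[i].second) ++c[i];
    }
    return best;
}
```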

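Here's where the intervals being sorted becomes useful: the set of intervals containing A_t forms a contiguous range.
Interval i contains x exactly when L_i \leq x and R_i \geq x; since both endpoints are increasing, the first condition holds on a prefix of the intervals and the second on a suffix, so their intersection is contiguous and its ends can be found by binary search.
So each step adds 1 to some range of the array c and then asks for the maximum of c, which a lazy segment tree (range add, range max) handles in \mathcal{O}(\log M) per operation.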
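Locating that contiguous range takes two binary searches. This sketch assumes the filtered intervals are stored as two 0-indexed arrays Ls and Rs, both strictly increasing; it mirrors the lower_bound/upper_bound calls in the tester's code, and the function name is hypothetical.

```cpp
#include <algorithm>
#include <utility>
#include <vector>

// Ls and Rs hold the filtered intervals' endpoints, both strictly increasing
// (0-indexed). Returns [lo, hi] = indices of intervals containing x; the range
// is empty when lo > hi.
std::pair<int, int> containingRange(const std::vector<int>& Ls,
                                    const std::vector<int>& Rs, int x) {
    // R_i >= x holds on a suffix: it starts at the first R_i >= x.
    int lo = int(std::lower_bound(Rs.begin(), Rs.end(), x) - Rs.begin());
    // L_i <= x holds on a prefix: it ends just before the first L_i > x.
    int hi = int(std::upper_bound(Ls.begin(), Ls.end(), x) - Ls.begin()) - 1;
    return {lo, hi};
}
```

With intervals [1, 3], [2, 5], [4, 6], cake 4 is contained in intervals 1 and 2 (0-indexed), while with [1, 2], [5, 6] cake 3 is in none and the returned range is empty.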

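We now have a fairly straightforward solution. At t = 1, only the cakes that are never eaten remain, whichever interval is chosen, so initialize every c_i to the number of cakes with l_j = 0. Then iterate over t from 1 to N: query the maximum of c and update the answer with (N - t + 1) \cdot \max(c); then, if t = l_{A_t}, add 1 to the range of intervals containing A_t before moving on to t+1.
(The author's code runs the same sweep backwards, from t = N down to 1, subtracting 1 instead of adding.)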
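Putting the steps together, here is a compact sketch of the whole algorithm for one test case, with I/O omitted. It assumes M \geq 1 and the same conventions (cakes 1 to K, A[s - 1] eaten on second s, (L, R) pairs); RangeAddMax and maxScore are hypothetical names. Instead of a lazy segment tree with push-down as in the reference codes, it uses a range-add segment tree that keeps each add tag at the node where it was applied, which is enough here because only the maximum of the whole array c is ever queried.

```cpp
#include <algorithm>
#include <utility>
#include <vector>

// Range add + global max. An add tag stays at the node where it was applied;
// mx[v] is the maximum over v's range, counting tags at v and below (not above).
struct RangeAddMax {
    int n;
    std::vector<long long> mx, add;
    RangeAddMax(int n_, long long init) : n(n_), mx(4 * n_, init), add(4 * n_, 0LL) {}
    void update(int v, int lo, int hi, int l, int r, long long d) {
        if (r < lo || hi < l) return;
        if (l <= lo && hi <= r) { mx[v] += d; add[v] += d; return; }
        int mid = (lo + hi) / 2;
        update(2 * v, lo, mid, l, r, d);
        update(2 * v + 1, mid + 1, hi, l, r, d);
        mx[v] = std::max(mx[2 * v], mx[2 * v + 1]) + add[v];
    }
    void update(int l, int r, long long d) { update(1, 0, n - 1, l, r, d); }
    long long globalMax() const { return mx[1]; }  // root has no tags above it
};

// Whole algorithm for one test case (I/O omitted). Assumes M >= 1, cakes 1..k,
// A[s - 1] eaten on second s, intervals as (L, R) pairs.
long long maxScore(int n, int k, const std::vector<int>& A,
                   std::vector<std::pair<int, int>> iv) {
    // 1. Drop nested intervals; survivors have increasing L and R.
    std::sort(iv.begin(), iv.end(), [](const std::pair<int, int>& p,
                                       const std::pair<int, int>& q) {
        return p.first != q.first ? p.first < q.first : p.second > q.second;
    });
    std::vector<int> Ls, Rs;
    for (auto [L, R] : iv)
        if (Rs.empty() || R > Rs.back()) { Ls.push_back(L); Rs.push_back(R); }

    // 2. Last eating times; every c_i starts at the never-eaten count.
    std::vector<int> l(k + 1, 0);
    for (int s = 1; s <= n; ++s) l[A[s - 1]] = s;
    long long never = 0;
    for (int j = 1; j <= k; ++j) never += (l[j] == 0);
    RangeAddMax c((int)Ls.size(), never);

    // 3. Sweep t: evaluate first, then apply the t -> t+1 update.
    long long best = 0;
    for (int t = 1; t <= n; ++t) {
        best = std::max(best, (long long)(n - t + 1) * c.globalMax());
        int x = A[t - 1];
        if (l[x] != t) continue;
        int lo = int(std::lower_bound(Rs.begin(), Rs.end(), x) - Rs.begin());
        int hi = int(std::upper_bound(Ls.begin(), Ls.end(), x) - Ls.begin()) - 1;
        if (lo <= hi) c.update(lo, hi, 1);
    }
    return best;
}
```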

TIME COMPLEXITY:

\mathcal{O}((K + N + M)\log M) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

using namespace std;
using namespace __gnu_pbds;

template<typename T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
typedef long long int ll;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;

#define fastio ios_base::sync_with_stdio(false); cin.tie(NULL)
#define pb push_back
#define endl '\n'
#define sz(a) (int)a.size()
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
#define conts continue
#define ceil2(x,y) ((x+y-1)/(y))
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
#define yes cout << "YES" << endl
#define no cout << "NO" << endl

#define rep(i,n) for(int i = 0; i < n; ++i)
#define rep1(i,n) for(int i = 1; i <= n; ++i)
#define rev(i,s,e) for(int i = s; i >= e; --i)
#define trav(i,a) for(auto &i : a)

template<typename T>
void amin(T &a, T b) {
    a = min(a,b);
}

template<typename T>
void amax(T &a, T b) {
    a = max(a,b);
}

#ifdef LOCAL
#include "debug.h"
#else
#define debug(...) 42
#endif

/*



*/

const int MOD = 1e9 + 7;
const int N = 1e5 + 5;
const int inf1 = int(1e9) + 5;
const ll inf2 = ll(1e18) + 5;

// range add, range max
template<typename T>
struct lazysegtree {
    /*=======================================================*/

    struct data {
        ll a;
    };

    struct lazy {
        ll a;
    };

    data d_neutral = {-inf2};
    lazy l_neutral = {0};

    void merge(data &curr, data &left, data &right) {
        curr.a = max(left.a,right.a);
    }

    void create(int x, int lx, int rx, T v) {
        tr[x].a = v;
    }

    void modify(int x, int lx, int rx, T v) {
        lz[x].a = v;
    }

    void propagate(int x, int lx, int rx) {
        ll v = lz[x].a;
        if(!v) return;

        tr[x].a += v;

        if(rx-lx > 1){
            lz[x<<1].a += v;
            lz[x<<1|1].a += v;
        }

        lz[x] = l_neutral;
    }

    /*=======================================================*/

    int siz = 1;
    vector<data> tr;
    vector<lazy> lz;

    lazysegtree() {

    }

    lazysegtree(int n) {
        while (siz < n) siz *= 2;
        tr.assign(2 * siz, d_neutral);
        lz.assign(2 * siz, l_neutral);
    }

    void build(vector<T> &a, int n, int x, int lx, int rx) {
        if (rx - lx == 1) {
            if (lx < n) {
                create(x, lx, rx, a[lx]);
            }

            return;
        }

        int mid = (lx + rx) >> 1;

        build(a, n, x<<1, lx, mid);
        build(a, n, x<<1|1, mid, rx);

        merge(tr[x], tr[x<<1], tr[x<<1|1]);
    }

    void build(vector<T> &a, int n) {
        build(a, n, 1, 0, siz);
    }

    void rupd(int l, int r, T v, int x, int lx, int rx) {
        propagate(x, lx, rx);

        if (lx >= r or rx <= l) return;
        if (lx >= l and rx <= r) {
            modify(x, lx, rx, v);
            propagate(x, lx, rx);
            return;
        }

        int mid = (lx + rx) >> 1;

        rupd(l, r, v, x<<1, lx, mid);
        rupd(l, r, v, x<<1|1, mid, rx);

        merge(tr[x], tr[x<<1], tr[x<<1|1]);
    }

    void rupd(int l, int r, T v) {
        rupd(l, r + 1, v, 1, 0, siz);
    }

    data query(int l, int r, int x, int lx, int rx) {
        propagate(x, lx, rx);

        if (lx >= r or rx <= l) return d_neutral;
        if (lx >= l and rx <= r) return tr[x];

        int mid = (lx + rx) >> 1;

        data curr;
        data left = query(l, r, x<<1, lx, mid);
        data right = query(l, r, x<<1|1, mid, rx);

        merge(curr, left, right);
        return curr;
    }

    data query(int l, int r) {
        return query(l, r + 1, 1, 0, siz);
    }
};

void solve(int test_case){
    ll n,m,k; cin >> n >> m >> k;
    vector<ll> a(n+5);
    rep1(i,n) cin >> a[i];
    vector<pll> b(m+5);
    rep1(i,m) cin >> b[i].ff >> b[i].ss;
        
    auto cmp = [&](pll p1, pll p2){
        if(p1.ff != p2.ff) return p1.ff < p2.ff;
        return p1.ss > p2.ss;
    };

    sort(b.begin()+1,b.begin()+m+1,cmp);

    ll mxr = -1;
    vector<pll> b2;

    rep1(i,m){
        auto [l,r] = b[i];
        if(r > mxr){
            b2.pb({l,r});
            mxr = r;
        }
    }

    ll sum_len = 0;
    rep(i,sz(b2)){
        auto [l,r] = b2[i];
        sum_len += r-l+1;
    }

    // #of cakes that would anyways be there
    vector<bool> cakes(k+5,1);
    rep1(i,n) cakes[a[i]] = 0;
    ll untouched = 0;
    rep1(i,k) untouched += cakes[i];

    vector<ll> pc(k+5);
    rep1(i,k) pc[i] = pc[i-1]+!cakes[i];

    ll siz = sz(b2);
    vector<ll> ini(siz);
    rep(i,siz) ini[i] = pc[b2[i].ss]-pc[b2[i].ff-1];
    lazysegtree<ll> st(siz+5);
    st.build(ini,siz);

    fill(all(cakes),1);

    ll ans = 0;

    rev(i,n,1){
        ll x = a[i];
        if(cakes[x]){
            // first seg that contains x
            ll first = -1;
            
            {
                ll lo = 0, hi = siz-1;
                while(lo <= hi){
                    ll mid = (lo+hi) >> 1;
                    auto [l,r] = b2[mid];
                    if(l <= x and x <= r){
                        first = mid;
                        hi = mid-1;
                    }
                    else{
                        if(l > x){
                            hi = mid-1;
                        }
                        else if(r < x){
                            lo = mid+1;
                        }
                        else{
                            assert(0);
                        }
                    }
                }
            }

            // last seg that contains x
            ll last = -1;

            {
                ll lo = 0, hi = siz-1;
                while(lo <= hi){
                    ll mid = (lo+hi) >> 1;
                    auto [l,r] = b2[mid];
                    if(l <= x and x <= r){
                        last = mid;
                        lo = mid+1;
                    }
                    else{
                        if(l > x){
                            hi = mid-1;
                        }
                        else if(r < x){
                            lo = mid+1;
                        }
                        else{
                            assert(0);
                        }
                    }
                }
            }

            if(first != -1){
                st.rupd(first,last,-1);
            }

            cakes[x] = 0;
        }

        ll mx = st.query(0,siz-1).a;
        ll val = (mx+untouched)*(n-i+1);
        amax(ans,val);
    }

    cout << ans << endl;
}

int main()
{
    fastio;

    int t = 1;
    cin >> t;

    rep1(i, t) {
        solve(i);
    }

    return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

struct segtree{
    struct node{
        int x = 0;
        int lz = 0;
 
        void apply(int l, int r, int y){
            x += y;
            lz += y;
        }
    };
 
    int n;
    vector <node> seg;
 
    node unite(node a, node b){
        node res;
        res.x = max(a.x, b.x);
        return res;
    }
 
    void push(int l, int r, int pos){
        if (l != r){
            int mid = (l + r) / 2;
            seg[pos * 2].apply(l, mid, seg[pos].lz);
            seg[pos * 2 + 1].apply(mid + 1, r, seg[pos].lz);
        }
        
        seg[pos].lz = 0;
    }
 
    void pull(int pos){
        seg[pos] = unite(seg[pos * 2], seg[pos * 2 + 1]);
    }
 
    void build(int l, int r, int pos){
        if (l == r){
            return;
        }
 
        int mid = (l + r) / 2;
        build(l, mid, pos * 2);
        build(mid + 1, r, pos * 2 + 1);
        pull(pos);
    }
 
    template<typename M>
    void build(int l, int r, int pos, vector<M> &v){
        if (l == r){
            seg[pos].apply(l, r, v[l]);
            return;
        }
 
        int mid = (l + r) / 2;
        build(l, mid, pos * 2, v);
        build(mid + 1, r, pos * 2 + 1, v);
        pull(pos);
    }
 
    node query(int l, int r, int pos, int ql, int qr){
        push(l, r, pos);
        if (l >= ql && r <= qr){
            return seg[pos];
        }
        
        int mid = (l + r) / 2;
        node res{};
        if (qr <= mid) res = query(l, mid, pos * 2, ql, qr);
        else if (ql > mid) res = query(mid + 1, r, pos * 2 + 1, ql, qr);
        else res = unite(query(l, mid, pos * 2, ql, qr), query(mid + 1, r, pos * 2 + 1, ql, qr));
        
        pull(pos);
        return res;
    }
 
    template <typename... M>
    void modify(int l, int r, int pos, int ql, int qr, M&... v){
        push(l, r, pos);
        if (l >= ql && r <= qr){
            seg[pos].apply(l, r, v...);
            return;
        }
 
        int mid = (l + r) / 2;
        if (ql <= mid) modify(l, mid, pos * 2, ql, qr, v...);
        if (qr > mid) modify(mid + 1, r, pos * 2 + 1, ql, qr, v...);
 
        pull(pos);
    }
 
    segtree (int _n){
        n = _n;
        seg.resize(4 * n + 1);
        build(1, n, 1);
    }
 
    template <typename M>
    segtree (int _n, vector<M> &v){
        n = _n;
        seg.resize(4 * n + 1);
        if (v.size() == n){
            v.insert(v.begin(), M());
        }
        build(1, n, 1, v);
    }
 
    node query(int l, int r){
        return query(1, n, 1, l, r);
    }
 
    node query(int x){
        return query(1, n, 1, x, x);
    }
 
    template <typename... M>
    void modify(int ql, int qr, M&...v){
        modify(1, n, 1, ql, qr, v...);
    }
};


void Solve() 
{
    int n, m, k; cin >> n >> m >> k;
    
    vector <int> a(n + 1);
    vector <int> l(k + 1, 0);
    
    for (int i = 1; i <= n; i++){
        cin >> a[i];
        l[a[i]] = i;
    }
    
    int fre = 0;
    for (int i = 1; i <= k; i++){
        fre += l[i] == 0;
    }
    
    vector <pair<int, int>> b;
    for (int i = 1; i <= m; i++){
        int l, r; cin >> l >> r;
        b.push_back({l, r});
    }
    
    sort(b.begin(), b.end(), [](pair <int, int> x, pair <int, int> y){
        if (x.first != y.first) return x.first < y.first;
        return x.second > y.second;
    });
    
    vector <pair<int, int>> c;
    int mxr = -1;
    for (auto [l, r] : b){
        if (r > mxr){
            mxr = r;
            c.push_back({l, r});
        }
    }
    b = c;
    
    vector<vector<int>> at(n + 1);
    for (int i = 1; i <= k; i++) if (l[i]){
        at[l[i]].push_back(i);
    }
    
    m = b.size();
    segtree seg(m);
    int ans = 0;
    
    vector <int> ls, rs;
    for (auto [l, r] : b){
        ls.push_back(l);
        rs.push_back(r);
      //  cout << l << " " << r << "\n";
    }
    
  //  cout << fre << "\n";
    for (int i = 1; i <= n; i++){
        // calculate answer first 
        int sv = fre + seg.query(1, m).x;
        ans = max(ans, sv * (n + 1 - i));
        
        for (int j : at[i]){
         //   cout << "HERE " << i << " " << j << "\n";
            // j is now saved for intervals containing j 
            // binary search
            
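            // first interval containing j: the first one with R >= j
            auto id1 = lower_bound(rs.begin(), rs.end(), j) - rs.begin();
            // first interval to the right of those containing j: the first one with L > j
            auto id2 = upper_bound(ls.begin(), ls.end(), j) - ls.begin();
            
            id2--; // now the last interval containing j (the last one with L <= j)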
            if (id1 <= id2){
                id1++;
                id2++;
                int pp = 1;
                seg.modify(id1, id2, pp);
            }
        }
    }
    
    cout << ans << "\n";
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}