PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: sushil2006
Tester: sushil2006
Editorialist: iceknight1093
DIFFICULTY:
Medium
PREREQUISITES:
Segment trees
PROBLEM:
There are K cakes.
For each i = 1, 2, \ldots, N, on the i-th second, cake A_i will be eaten if it isn’t yet.
There are also M intervals [L_i, R_i].
Exactly once, at the start of some second t, you can choose an i and replenish the cakes [L_i, R_i].
Your score is (N - t + 1) \cdot C, where C is the number of cakes that remain uneaten in the end.
Find the maximum possible score.
EXPLANATION:
First, observe that only the last time at which a cake is eaten matters.
This is easy to see: if a cake is replenished, then it will remain if and only if the last time it was eaten was before the replenishment; and if it wasn’t replenished, then only whether it was ever eaten matters - so again storing the last time is enough.
Let l_i denote the last second at which cake i is eaten (take l_i = 0 if cake i is never eaten).
Next, note that if we have intervals [L_1, R_1] and [L_2, R_2] such that L_1 \leq L_2 \leq R_2 \leq R_1, meaning [L_2, R_2] is completely contained inside [L_1, R_1], it’s never optimal to use [L_2, R_2] - using [L_1, R_1] instead is not worse.
So, let’s discard all such useless intervals.
Let the remaining intervals be [L_1, R_1], [L_2, R_2], \ldots, [L_m, R_m].
Note that if these are sorted in ascending order of L_i, then they’ll also be sorted in ascending order of R_i automatically.
This is a rather useful property.
Let’s fix a time instant t, and try to figure out which interval is the optimal one to choose here.
To do this, we’d like to compute for the i-th interval the value c_i: the number of cakes that will remain if the i-th interval is used at time t.
Looking at interval [L_i, R_i],
- Cakes \lt L_i or \gt R_i are not replenished by this interval, so they remain if and only if they are never eaten - this is independent of t.
- For cakes in [L_i, R_i], whether they contribute to c_i or not depends on their l_j values: only those with l_j \lt t will add 1 to c_i.
Suppose we’re able to (somehow) compute all the c_i values.
Let’s analyze how they change when moving from time t to time t+1.
First, if t is not the last time at which cake A_{t} is eaten, nothing changes at all.
Otherwise, we need to add 1 to c_i for every interval [L_i, R_i] that contains A_t.
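This is because if the operation happens at the start of second t or earlier, A_t is still eaten at second t afterwards, so replenishing it is wasted; if the operation happens at the start of second t+1 or later, A_t is replenished after its last consumption and can’t be eaten again.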
Here’s where the intervals being sorted becomes useful: the set of intervals containing A_t will form a contiguous range.
This means what we really want to do is add 1 to some range of the array c, and then compute its maximum - which is easily done quickly using a lazy segment tree!
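We now have a fairly straightforward solution: initialize every c_i to the number of cakes that are never eaten at all (they remain no matter which interval is used, or when). Then iterate over t from 1 to N. At each step, first query the maximum of c to get the candidate score (N - t + 1) \cdot \max_i c_i. Afterwards, if t is the last time A_t is eaten, add 1 to the range of intervals containing A_t.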
TIME COMPLEXITY:
\mathcal{O}((K + N + M)\log M) per testcase.
CODE:
Author's code (C++)
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;
// Ordered set with order statistics (GNU pb_ds red-black tree):
// find_by_order(k) returns an iterator to the k-th smallest key (0-based),
// order_of_key(x) counts the keys strictly less than x.
template <class Key>
using Tree = __gnu_pbds::tree<Key,
                              __gnu_pbds::null_type,
                              std::less<Key>,
                              __gnu_pbds::rb_tree_tag,
                              __gnu_pbds::tree_order_statistics_node_update>;
// Short aliases for frequently used types.
using ll = long long int;
using ld = long double;
using pii = std::pair<int, int>;
using pll = std::pair<ll, ll>;

// Utility macros. Every macro parameter is parenthesized so that argument
// expressions (e.g. `a << 1`, `x ? p : q`) keep their meaning after expansion.

// Unsync C++ streams from C stdio and untie cin from cout.
#define fastio ios_base::sync_with_stdio(false); cin.tie(nullptr)
#define pb push_back
// Every later `endl` token becomes '\n', so lines are not flushed one by one.
#define endl '\n'
#define sz(a) static_cast<int>((a).size())
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
#define conts continue
// Ceiling of x / y for x >= 0 and y > 0.
#define ceil2(x,y) (((x) + (y) - 1) / (y))
#define all(a) (a).begin(), (a).end()
#define rall(a) (a).rbegin(), (a).rend()
#define yes cout << "YES" << endl
#define no cout << "NO" << endl
// i = 0 .. n-1
#define rep(i,n) for(int i = 0; i < (n); ++i)
// i = 1 .. n
#define rep1(i,n) for(int i = 1; i <= (n); ++i)
// i = s down to e
#define rev(i,s,e) for(int i = (s); i >= (e); --i)
#define trav(i,a) for(auto &i : (a))
// Lower `a` to `b` when `b` is strictly smaller (same result as a = min(a, b)).
template<typename T>
void amin(T &a, T b) {
    if (b < a) {
        a = b;
    }
}
// Raise `a` to `b` when `b` is strictly larger (same result as a = max(a, b)).
template<typename T>
void amax(T &a, T b) {
    if (a < b) {
        a = b;
    }
}
// Debug hook: when compiled with LOCAL defined, include a local "debug.h"
// (presumably it provides debug(...); that header is not part of this file).
// Otherwise debug(...) expands to the constant 42, so a debug(...) call is an
// expression statement with no effect.
#ifdef LOCAL
#include "debug.h"
#else
#define debug(...) 42
#endif
/*
*/
// Global constants, spelled as exact integer literals
// (1e9 + 7, 1e5 + 5, int(1e9) + 5 and ll(1e18) + 5 respectively).
constexpr int MOD = 1'000'000'007;
constexpr int N = 100'005;
constexpr int inf1 = 1'000'000'005;
// Large sentinel; the segment tree below uses -inf2 as its neutral maximum.
constexpr ll inf2 = 1'000'000'000'000'000'005LL;
// range add, range max
// Lazy segment tree over positions [0, siz).
//   rupd(l, r, v): add v to every position in the closed range [l, r].
//   query(l, r):   maximum over the closed range [l, r].
// Leaves that build() does not set (index >= n, i.e. power-of-two padding) keep
// d_neutral = -inf2, so they never win a maximum.
// Invariant: tr[x].a is the maximum of node x's segment EXCLUDING lz[x].a, an
// addition still pending on x itself; propagate(x) folds it into tr[x] and
// hands it to the children.
template<typename T>
struct lazysegtree {
    /*=======================================================*/
    struct data {
        ll a;  // segment maximum
    };
    struct lazy {
        ll a;  // pending addition
    };
    data d_neutral = {-inf2};
    lazy l_neutral = {0};
    // curr <- combination (maximum) of two child aggregates.
    void merge(data &curr, const data &left, const data &right) {
        curr.a = max(left.a, right.a);
    }
    // Leaf initialisation from an input value.
    void create(int x, int lx, int rx, T v) {
        tr[x].a = v;
    }
    // Tag node x with "+v". Accumulates instead of overwriting, so a pending,
    // not-yet-propagated tag on x can never be lost, whatever the caller did first.
    void modify(int x, int lx, int rx, T v) {
        lz[x].a += v;
    }
    // Apply node x's pending addition to tr[x] and push it to the children's tags.
    void propagate(int x, int lx, int rx) {
        ll v = lz[x].a;
        if(!v) return;
        tr[x].a += v;
        if(rx-lx > 1){
            lz[x<<1].a += v;
            lz[x<<1|1].a += v;
        }
        lz[x] = l_neutral;
    }
    /*=======================================================*/
    int siz = 1;  // number of leaves: smallest power of two >= n
    vector<data> tr;
    vector<lazy> lz;
    lazysegtree() {
    }
    lazysegtree(int n) {
        while (siz < n) siz *= 2;
        tr.assign(2 * siz, d_neutral);
        lz.assign(2 * siz, l_neutral);
    }
    // Set leaves [0, n) from a[0 .. n-1] and recompute internal maxima (node x covers [lx, rx)).
    void build(vector<T> &a, int n, int x, int lx, int rx) {
        if (rx - lx == 1) {
            if (lx < n) {
                create(x, lx, rx, a[lx]);
            }
            return;
        }
        int mid = (lx + rx) >> 1;
        build(a, n, x<<1, lx, mid);
        build(a, n, x<<1|1, mid, rx);
        merge(tr[x], tr[x<<1], tr[x<<1|1]);
    }
    void build(vector<T> &a, int n) {
        build(a, n, 1, 0, siz);
    }
    // Add v on the half-open range [l, r) within node x covering [lx, rx).
    void rupd(int l, int r, T v, int x, int lx, int rx) {
        propagate(x, lx, rx);
        if (lx >= r or rx <= l) return;
        if (lx >= l and rx <= r) {
            modify(x, lx, rx, v);
            propagate(x, lx, rx);
            return;
        }
        int mid = (lx + rx) >> 1;
        rupd(l, r, v, x<<1, lx, mid);
        rupd(l, r, v, x<<1|1, mid, rx);
        merge(tr[x], tr[x<<1], tr[x<<1|1]);
    }
    // Public entry point: closed range [l, r].
    void rupd(int l, int r, T v) {
        rupd(l, r + 1, v, 1, 0, siz);
    }
    // Maximum on the half-open range [l, r) within node x covering [lx, rx).
    data query(int l, int r, int x, int lx, int rx) {
        propagate(x, lx, rx);
        if (lx >= r or rx <= l) return d_neutral;
        if (lx >= l and rx <= r) return tr[x];
        int mid = (lx + rx) >> 1;
        data curr = d_neutral;
        data left = query(l, r, x<<1, lx, mid);
        data right = query(l, r, x<<1|1, mid, rx);
        merge(curr, left, right);
        return curr;
    }
    // Public entry point: closed range [l, r]; returns d_neutral for an empty range.
    data query(int l, int r) {
        return query(l, r + 1, 1, 0, siz);
    }
};
void solve(int test_case){
ll n,m,k; cin >> n >> m >> k;
vector<ll> a(n+5);
rep1(i,n) cin >> a[i];
vector<pll> b(m+5);
rep1(i,m) cin >> b[i].ff >> b[i].ss;
auto cmp = [&](pll p1, pll p2){
if(p1.ff != p2.ff) return p1.ff < p2.ff;
return p1.ss > p2.ss;
};
sort(b.begin()+1,b.begin()+m+1,cmp);
ll mxr = -1;
vector<pll> b2;
rep1(i,m){
auto [l,r] = b[i];
if(r > mxr){
b2.pb({l,r});
mxr = r;
}
}
ll sum_len = 0;
rep(i,sz(b2)){
auto [l,r] = b2[i];
sum_len += r-l+1;
}
// #of cakes that would anyways be there
vector<bool> cakes(k+5,1);
rep1(i,n) cakes[a[i]] = 0;
ll untouched = 0;
rep1(i,k) untouched += cakes[i];
vector<ll> pc(k+5);
rep1(i,k) pc[i] = pc[i-1]+!cakes[i];
ll siz = sz(b2);
vector<ll> ini(siz);
rep(i,siz) ini[i] = pc[b2[i].ss]-pc[b2[i].ff-1];
lazysegtree<ll> st(siz+5);
st.build(ini,siz);
fill(all(cakes),1);
ll ans = 0;
rev(i,n,1){
ll x = a[i];
if(cakes[x]){
// first seg that contains x
ll first = -1;
{
ll lo = 0, hi = siz-1;
while(lo <= hi){
ll mid = (lo+hi) >> 1;
auto [l,r] = b2[mid];
if(l <= x and x <= r){
first = mid;
hi = mid-1;
}
else{
if(l > x){
hi = mid-1;
}
else if(r < x){
lo = mid+1;
}
else{
assert(0);
}
}
}
}
// last seg that contains x
ll last = -1;
{
ll lo = 0, hi = siz-1;
while(lo <= hi){
ll mid = (lo+hi) >> 1;
auto [l,r] = b2[mid];
if(l <= x and x <= r){
last = mid;
lo = mid+1;
}
else{
if(l > x){
hi = mid-1;
}
else if(r < x){
lo = mid+1;
}
else{
assert(0);
}
}
}
}
if(first != -1){
st.rupd(first,last,-1);
}
cakes[x] = 0;
}
ll mx = st.query(0,siz-1).a;
ll val = (mx+untouched)*(n-i+1);
amax(ans,val);
}
cout << ans << endl;
}
int main()
{
fastio;
int t = 1;
cin >> t;
rep1(i, t) {
solve(i);
}
return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
// Every later `int` token is replaced by `long long` (presumably why the entry
// point below is declared `int32_t main`, since `long long main` is ill-formed).
#define int long long
// Large sentinel; because of the macro above it expands to (long long)1e18.
// Not referenced by the code below.
#define INF (int)1e18
// 64-bit Mersenne Twister seeded from the steady clock; not referenced by the code below.
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());
// Lazy segment tree over 1-indexed positions [1, n] supporting range add
// (modify) and range max (query). Positions start at 0 unless the tree is
// built from a vector of initial values.
struct segtree{
    struct node{
        int x = 0;   // max over the node's segment, including every add applied to this node
        int lz = 0;  // add already counted in x but not yet pushed to the children
        void apply(int l, int r, int y){
            x += y;
            lz += y;
        }
    };
    int n;
    vector <node> seg;
    node unite(node a, node b){
        node res;
        res.x = max(a.x, b.x);
        return res;
    }
    // Hand node pos's pending add down to its children (if any) and clear it.
    void push(int l, int r, int pos){
        if (l != r){
            int mid = (l + r) / 2;
            seg[pos * 2].apply(l, mid, seg[pos].lz);
            seg[pos * 2 + 1].apply(mid + 1, r, seg[pos].lz);
        }
        seg[pos].lz = 0;
    }
    void pull(int pos){
        seg[pos] = unite(seg[pos * 2], seg[pos * 2 + 1]);
    }
    // All-zero build. Stopping on l >= r (not only l == r) also covers the empty
    // tree (n == 0 gives l = 1 > r = 0), which would otherwise recurse forever.
    void build(int l, int r, int pos){
        if (l >= r){
            return;
        }
        int mid = (l + r) / 2;
        build(l, mid, pos * 2);
        build(mid + 1, r, pos * 2 + 1);
        pull(pos);
    }
    // Build from v, where leaf l takes v[l] (v is 1-indexed here).
    template<typename M>
    void build(int l, int r, int pos, const vector<M> &v){
        if (l > r){
            return;
        }
        if (l == r){
            seg[pos].apply(l, r, v[l]);
            return;
        }
        int mid = (l + r) / 2;
        build(l, mid, pos * 2, v);
        build(mid + 1, r, pos * 2 + 1, v);
        pull(pos);
    }
    // Maximum over [ql, qr] inside node pos covering [l, r].
    node query(int l, int r, int pos, int ql, int qr){
        push(l, r, pos);
        if (l >= ql && r <= qr){
            return seg[pos];
        }
        int mid = (l + r) / 2;
        node res{};
        if (qr <= mid) res = query(l, mid, pos * 2, ql, qr);
        else if (ql > mid) res = query(mid + 1, r, pos * 2 + 1, ql, qr);
        else res = unite(query(l, mid, pos * 2, ql, qr), query(mid + 1, r, pos * 2 + 1, ql, qr));
        pull(pos);
        return res;
    }
    // Apply v... to every fully covered node inside [ql, qr].
    template <typename... M>
    void modify(int l, int r, int pos, int ql, int qr, M&... v){
        push(l, r, pos);
        if (l >= ql && r <= qr){
            seg[pos].apply(l, r, v...);
            return;
        }
        int mid = (l + r) / 2;
        if (ql <= mid) modify(l, mid, pos * 2, ql, qr, v...);
        if (qr > mid) modify(mid + 1, r, pos * 2 + 1, ql, qr, v...);
        pull(pos);
    }
    segtree (int _n){
        n = _n;
        seg.resize(4 * n + 1);
        build(1, n, 1);
    }
    // Build from initial values. v may be 0-indexed (size n) or 1-indexed
    // (size n + 1, v[0] ignored). A 0-indexed v is copied behind a dummy
    // element; the caller's vector is never modified.
    template <typename M>
    segtree (int _n, const vector<M> &v){
        n = _n;
        seg.resize(4 * n + 1);
        if (static_cast<int>(v.size()) == n){
            vector<M> shifted;
            shifted.reserve(v.size() + 1);
            shifted.push_back(M());
            shifted.insert(shifted.end(), v.begin(), v.end());
            build(1, n, 1, shifted);
        } else {
            build(1, n, 1, v);
        }
    }
    node query(int l, int r){
        return query(1, n, 1, l, r);
    }
    node query(int x){
        return query(1, n, 1, x, x);
    }
    // Add v to every position in [ql, qr]; v must be an lvalue.
    template <typename... M>
    void modify(int ql, int qr, M&...v){
        modify(1, n, 1, ql, qr, v...);
    }
};
void Solve()
{
int n, m, k; cin >> n >> m >> k;
vector <int> a(n + 1);
vector <int> l(k + 1, 0);
for (int i = 1; i <= n; i++){
cin >> a[i];
l[a[i]] = i;
}
int fre = 0;
for (int i = 1; i <= k; i++){
fre += l[i] == 0;
}
vector <pair<int, int>> b;
for (int i = 1; i <= m; i++){
int l, r; cin >> l >> r;
b.push_back({l, r});
}
sort(b.begin(), b.end(), [](pair <int, int> x, pair <int, int> y){
if (x.first != y.first) return x.first < y.first;
return x.second > y.second;
});
vector <pair<int, int>> c;
int mxr = -1;
for (auto [l, r] : b){
if (r > mxr){
mxr = r;
c.push_back({l, r});
}
}
b = c;
vector<vector<int>> at(n + 1);
for (int i = 1; i <= k; i++) if (l[i]){
at[l[i]].push_back(i);
}
m = b.size();
segtree seg(m);
int ans = 0;
vector <int> ls, rs;
for (auto [l, r] : b){
ls.push_back(l);
rs.push_back(r);
// cout << l << " " << r << "\n";
}
// cout << fre << "\n";
for (int i = 1; i <= n; i++){
// calculate answer first
int sv = fre + seg.query(1, m).x;
ans = max(ans, sv * (n + 1 - i));
for (int j : at[i]){
// cout << "HERE " << i << " " << j << "\n";
// j is now saved for intervals containing j
// binary search
// first interval containing
auto id1 = lower_bound(rs.begin(), rs.end(), j) - rs.begin();
// first interval not containing
auto id2 = upper_bound(ls.begin(), ls.end(), j) - ls.begin();
id2--;
if (id1 <= id2){
id1++;
id2++;
int pp = 1;
seg.modify(id1, id2, pp);
}
}
}
cout << ans << "\n";
}
// Entry point (spelled int32_t because `int` is macro-replaced by `long long`).
// Runs every test case and reports the total wall time on stderr.
int32_t main()
{
    using Clock = std::chrono::high_resolution_clock;
    const auto started = Clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);

    int t = 1;
    cin >> t;
    while (t-- > 0) {
        Solve();
    }

    const auto finished = Clock::now();
    const auto spent = std::chrono::duration_cast<std::chrono::nanoseconds>(finished - started);
    cerr << "Time measured: " << spent.count() * 1e-9 << " seconds.\n";
    return 0;
}