SAVJEW - Editorial

PROBLEM LINK:

Practice

Contest: Division 1

Contest: Division 2

Setter: Rami

Tester: Roman Bilyi

Editorialist: Taranpreet Singh

DIFFICULTY:

Medium

PREREQUISITES:

Segment Tree, Queues, Observations.

PROBLEM:

There are N jewels with pairwise distinct values P_i. On each of the next M days, a thief selects an interval [L_i, R_i] and steals the most expensive jewel still present in that interval, if there is one.

You are allowed to stop the thief from stealing on one day. Assuming we choose that day optimally, what is the maximum sum of the values of the jewels that are not stolen?
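As a reference point, the statement can be simulated directly. The sketch below is not part of any of the solutions in this editorial; the names `Days`, `simulate` and `bruteForce` are made up for illustration, and it assumes 1-based inclusive intervals with L ≤ R. It is far too slow for the real limits, but handy for checking a faster solution on small inputs.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

using Days = std::vector<std::pair<int, int>>;  // 1-based inclusive [L, R] per day

// Sum of the jewels left after simulating every day except `skipped`
// (pass -1 to skip no day).
long long simulate(const std::vector<long long>& value, const Days& days, int skipped) {
    int n = (int)value.size();
    std::vector<bool> stolen(n, false);
    for (int d = 0; d < (int)days.size(); ++d) {
        if (d == skipped) continue;
        int l = days[d].first, r = days[d].second;
        assert(1 <= l && l <= r && r <= n);  // assumption of this sketch
        int best = -1;  // position of the most valuable jewel still present
        for (int p = l - 1; p <= r - 1; ++p)
            if (!stolen[p] && (best == -1 || value[p] > value[best])) best = p;
        if (best != -1) stolen[best] = true;
    }
    long long left = 0;
    for (int p = 0; p < n; ++p)
        if (!stolen[p]) left += value[p];
    return left;
}

// Tries every day as the restricted one and keeps the best outcome.
// Skipping no day is included so the function also works when there are no days.
long long bruteForce(const std::vector<long long>& value, const Days& days) {
    long long best = simulate(value, days, -1);
    for (int d = 0; d < (int)days.size(); ++d)
        best = std::max(best, simulate(value, days, d));
    return best;
}
```

For example, with values {1, 2, 4, 3} and days [1,3], [3,4], [1,2], [2,3], only the jewel of value 1 survives without a restriction, while restricting the first or the second day leaves a total of 4.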

QUICK EXPLANATION

  • We can save at most one jewel from being stolen, so we shall process the jewels in decreasing order of value.
  • While processing a jewel, there can be three cases.
    • If the jewel is not covered by any interval, it shall never be stolen, so it gets added to the final answer.
    • If the jewel is covered in exactly one interval and we haven’t yet removed any interval, we can remove this interval and save the jewel.
    • If the jewel is covered in more than one interval, we cannot save this jewel.
  • After a jewel is processed, if it is covered in any interval not already removed, we need to remove the minimum index interval which contains this position from our structure before moving to the next jewel.
  • So, we need a data structure, which supports the following operations.
    • Insert an interval [L, R] with index x
    • Remove the interval indexed x
    • For position p, find the number of intervals containing p, and the lowest indexed interval which contains position p
  • Segment Tree with a queue at each node can be used to achieve this, maintaining the number of intervals covering each node.

EXPLANATION

First of all, let us consider the jewels one by one, in decreasing order of value. The idea behind this ordering is that, within every unused interval containing its position, the current jewel is the most valuable one remaining, since all jewels with a higher value have already been handled.

When we consider jewels in decreasing order of values, only one of the following cases can happen.

  • If the jewel is not covered by any interval, it shall never be stolen.
  • If the jewel is covered by exactly one interval, we can choose to save this jewel by restricting that single interval. We can save at most one such jewel.
  • If the jewel is covered by more than one interval, then we cannot save this jewel, whichever interval we restrict.

In the first case, the jewel always contributes to the answer. In the third case, the jewel never contributes to the answer. So only the second case involves a choice: we restrict the interval that saves the most valuable such jewel, which is the first one we meet in decreasing order of value.
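The three cases can be written down literally before worrying about speed. The following is a quadratic sketch of that greedy, scanning all intervals for every jewel. `greedyNaive` and `Days` are hypothetical names, and intervals are assumed to be 1-based and inclusive. The data structure described later only replaces the inner loop.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <utility>
#include <vector>

using Days = std::vector<std::pair<int, int>>;  // 1-based inclusive [L, R] per day

long long greedyNaive(const std::vector<long long>& value, const Days& days) {
    int n = (int)value.size(), m = (int)days.size();
    std::vector<int> order(n);
    std::iota(order.begin(), order.end(), 0);
    // Jewel positions in decreasing order of value.
    std::sort(order.begin(), order.end(),
              [&](int a, int b) { return value[a] > value[b]; });

    std::vector<bool> used(m, false);  // interval already stole something or was restricted
    bool restricted = false;           // have we spent our single restriction?
    long long kept = 0;
    for (int p : order) {
        int count = 0, first = -1;     // unused intervals covering p, and the earliest one
        for (int d = 0; d < m; ++d) {
            assert(1 <= days[d].first && days[d].second <= n);
            if (used[d] || days[d].first - 1 > p || days[d].second - 1 < p) continue;
            if (first == -1) first = d;
            ++count;
        }
        if (count == 0) {              // case 1: never stolen
            kept += value[p];
            continue;
        }
        if (count == 1 && !restricted) {  // case 2: restrict the only interval
            restricted = true;
            kept += value[p];
        }
        used[first] = true;            // cases 2 and 3: the earliest interval is consumed
    }
    return kept;
}
```

On values {1, 2, 4, 3} with days [1,3], [3,4], [1,2], [2,3], the jewel of value 4 falls under case 3, the jewel of value 3 is saved under case 2, and the jewel of value 1 ends up uncovered, giving 4.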

Why restricting an interval in the third case is never better

Suppose it is optimal to restrict an interval that steals a jewel with value x which is also covered by another interval. That other interval originally steals some jewel with value y, where y < x. After the first interval is restricted, the second interval steals the jewel with value x instead, so we only manage to save the jewel with value y. We could have achieved the same by leaving the first interval as it is and restricting the second interval, which is exactly what we do in case 2.

This shows that if restricting an interval under case 3 gives some gain g, we can always achieve at least that much gain by restricting an interval that covers a jewel under case 2.

So, all that is needed is to consider the jewels one by one and maintain a data structure which supports the following operations over the range [1, N]:

  • Insert an interval [L, R] with index x
  • Remove the interval indexed x
  • For position p, find the number of intervals containing p, and the index of the earliest inserted interval which contains position p

With such a data structure, we insert all intervals at the beginning, and whenever we find the interval that steals a jewel, we remove it from the data structure.

We can see that “the index of the earliest inserted interval which contains position p” is actually the index of the interval by which the current jewel shall be stolen. The other part of the query, the number of intervals containing p, tells us whether we can save this jewel or not.

Once we have the index of this interval, we remove it from our data structure so that it is not considered for subsequent jewels, since it is either used or restricted.

We can achieve all these operations using a segment tree, by keeping at each node a queue storing indices of intervals and an integer denoting the number of intervals stored at that node.

Implementation Hints

For insertion,

  • we split the interval [L, R] into O(log(N)) disjoint node ranges, and for each node corresponding to such a range, we push the index of this interval into its queue and increase the number of intervals covering this node.
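The splitting step can be sketched on its own. The function below (a hypothetical helper, `canonicalNodes`, not taken from the solutions) only collects the ids of the nodes that tile [a, b], assuming the usual numbering where the root is node 1 over [1, n] and node p has children 2p and 2p+1. An insertion would push the interval index into the queue of each collected node.

```cpp
#include <cassert>
#include <vector>

// Appends to `out` the ids of the segment tree nodes whose ranges exactly tile [a, b].
// Node `p` covers [lo, hi]; call with lo = 1, hi = n, p = 1 for the whole tree.
void canonicalNodes(int a, int b, int lo, int hi, int p, std::vector<int>& out) {
    assert(lo <= hi);
    if (a > hi || b < lo) return;  // no overlap with this node
    if (a <= lo && hi <= b) {      // node range lies fully inside [a, b]
        out.push_back(p);
        return;
    }
    int mid = (lo + hi) / 2;
    canonicalNodes(a, b, lo, mid, 2 * p, out);
    canonicalNodes(a, b, mid + 1, hi, 2 * p + 1, out);
}
```

For n = 8 and the interval [2, 7], this yields the nodes for [2,2], [3,4], [5,6] and [7,7], in that order.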

For deletion, we have two things to handle: the queue and the number of intervals for each node.

  • For the queue, we maintain a boolean array which records whether an interval is used or not. Whenever we read the first element of a queue, we first pop elements from its front for as long as the front holds a used interval index.
  • For the counts, we repeat the same process as insertion and decrease the number of intervals covering each node.
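The lazy clean-up of a single queue might look like the sketch below. `liveFront` is a hypothetical helper name. It relies on intervals having been pushed in increasing index order, so the first live element is also the smallest live index in that queue.

```cpp
#include <cassert>
#include <queue>
#include <vector>

// Returns the first index in `q` that is not marked as used, or -1 if none is left.
// Used indices found at the front are popped, so each index is popped at most once.
int liveFront(std::queue<int>& q, const std::vector<bool>& used) {
    while (!q.empty()) {
        int x = q.front();
        assert(0 <= x && x < (int)used.size());  // every stored index must be known
        if (!used[x]) return x;
        q.pop();
    }
    return -1;
}
```

Used indices that sit behind a live front stay in the queue until they reach the front, which is harmless because they are never reported.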

For query,

  • We consider all nodes on the path from the root to the leaf corresponding to position p and take the minimum among the front indices of their queues.
  • Similarly, the number of intervals covering this position is the sum of the number of intervals covering each node on the path from the root to the leaf for position p.
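Putting the hints together, a minimal sketch of the whole structure could look as follows. It mirrors the idea of the solutions below but is not copied from them; the names `CoverTree`, `insert`, `erase` and `query` are assumptions of this sketch, positions are 1-based, and interval indices are assumed to be inserted in increasing order.

```cpp
#include <algorithm>
#include <cassert>
#include <queue>
#include <utility>
#include <vector>

struct CoverTree {
    int n;
    std::vector<std::queue<int>> ids;  // per node: interval indices in insertion order
    std::vector<int> cnt;              // per node: live intervals stored at this node
    std::vector<bool> used;            // used[x]: interval x has been removed

    CoverTree(int n, int m) : n(n), ids(4 * n), cnt(4 * n, 0), used(m + 1, false) {}

    // delta = +1 stores interval x on [a, b]; delta = -1 only lowers the counters.
    void update(int a, int b, int x, int delta, int lo, int hi, int p) {
        if (a > hi || b < lo) return;
        if (a <= lo && hi <= b) {
            if (delta > 0) ids[p].push(x);
            cnt[p] += delta;
            return;
        }
        int mid = (lo + hi) / 2;
        update(a, b, x, delta, lo, mid, 2 * p);
        update(a, b, x, delta, mid + 1, hi, 2 * p + 1);
    }
    void insert(int a, int b, int x) { update(a, b, x, +1, 1, n, 1); }
    // [a, b] must be the same range that interval x was inserted with.
    void erase(int a, int b, int x) { used[x] = true; update(a, b, x, -1, 1, n, 1); }

    // Returns {number of live intervals covering pos, smallest live index or 1e9}.
    std::pair<int, int> query(int pos) {
        assert(1 <= pos && pos <= n);
        const int INF = 1000000000;
        int total = 0, first = INF, lo = 1, hi = n, p = 1;
        while (true) {
            // Lazily drop removed intervals from the front of this node's queue.
            while (!ids[p].empty() && used[ids[p].front()]) ids[p].pop();
            total += cnt[p];
            if (!ids[p].empty()) first = std::min(first, ids[p].front());
            if (lo == hi) break;
            int mid = (lo + hi) / 2;
            if (pos <= mid) { hi = mid; p = 2 * p; }
            else { lo = mid + 1; p = 2 * p + 1; }
        }
        return {total, first};
    }
};
```

With N = 4 and intervals [1,3], [3,4], [1,2], [2,3] inserted as indices 1 to 4, `query(3)` reports three covering intervals with earliest index 1. After `erase(1, 3, 1)`, the same query reports two intervals with earliest index 2.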

This blog is good to refer to for the segment tree, especially this problem.

TIME COMPLEXITY

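The time complexity is O((N+M)*log(N)) per test case: sorting the jewels takes O(N*log(N)), each of the M intervals is inserted into and removed from O(log(N)) nodes, and each of the N queries walks a single root-to-leaf path.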

SOLUTIONS:

Setter's Solution
#include <bits/stdc++.h>
#define ll  long long
#define pii pair<int,int>
#define pll pair<ll,ll>
#define sc second
#define fr first
using namespace std;

const int N = 2e5+100;
const int M = 2e5+100;

int n,m;
pii a[N];
int l[M],r[M];
bool take[M];

queue<int>pers[N*4];
int sz[N*4];

void addOrDel(int p,int x){
	if(x > 0)
	    pers[p].push(x);
	sz[p] += x/abs(x);
}

void add(int a,int b,int x,int l =1,int h = n,int p =1){
	if(a == l && b == h){
	    addOrDel(p,x);
	    return ;
	}
	int m = (l+h)/2;
	if(b <= m)
	    add(a,b,x,l,m,2*p);
	else if(a > m)
	    add(a,b,x,m+1,h,2*p+1);
	else{
	    add(m+1,b,x,m+1,h,2*p+1);
	    add(a,m,x,l,m,2*p);
	}
}

void clean(int p){
	while(!pers[p].empty() && take[pers[p].front()])
	    pers[p].pop();
}

pii cal(int in,int l =1,int h = n,int p =1){
	clean(p);
	if(l == h){
	    int x = 1e9;
	    if(pers[p].size())
	        x = pers[p].front();
	    return {sz[p],x};
	}
	int m = (l+h)/2;

	pii res;
	res.fr = sz[p];
	res.sc = 1e9;
	if(pers[p].size())
	   res.sc = pers[p].front();

	if(in <= m){
	    pii x = cal(in,l,m,2*p);
	    return {res.fr+x.fr,min(res.sc,x.sc)};
	}
	pii x= cal(in,m+1,h,2*p+1);
	return {res.fr+x.fr,min(res.sc,x.sc)};
}

void read(){
	scanf("%d%d",&n,&m);
	for(int i=1; i<=n ; i ++){
	    scanf("%d",&a[i].fr);
	    a[i].sc = i;
	}
	for(int i=1 ;i <=m ;i ++)
	    scanf("%d%d",&l[i],&r[i]);
}

void init(){
	for(int i=0 ;i <= 4*n ;i++){
	    while(!pers[i].empty())pers[i].pop();
	    sz[i] = 0;
	}
	for(int i=0 ;i <=m; i++)take[i] = 0;
}


int main()  {
	int t;
	cin>>t;
	while(t--){
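	    init(); // n and m still hold the previous test's sizes here, so this clears the state that test left behind
	    read();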

	    sort(a+1,a+n+1);
	    reverse(a+1,a+n+1);

	    for(int i=1; i<=m ;i ++){
	        add(l[i],r[i],i);
	    }
	    bool ch = 0;
	    ll res = 0;
	    for(int i=1 ;i <=n ;i ++){
	        pii x = cal(a[i].sc);
	        if(x.fr == 1 && !ch){
	            ch = 1;
	            take[x.sc] = 1;
	            add(l[x.sc],r[x.sc],-x.sc);
	            res += a[i].fr;
	            continue;
	        }
	        if(!x.fr)
	            res += a[i].fr;
	        if(x.sc <= m){
	            take[x.sc] = 1;
	            add(l[x.sc],r[x.sc],-x.sc);
	        }
	    }
	    printf("%lld\n",res);
	}

	return 0;
}
Tester's Solution
#include "bits/stdc++.h"
#pragma GCC optimize("Ofast")
#pragma GCC target("sse,sse2,sse3,ssse3,sse4,avx,avx2")
using namespace std;

#define FOR(i,a,b) for (int i = (a); i < (b); i++)
#define RFOR(i,b,a) for (int i = (b) - 1; i >= (a); i--)
#define ITER(it,a) for (__typeof(a.begin()) it = a.begin(); it != a.end(); it++)
#define FILL(a,value) memset(a, value, sizeof(a))

#define SZ(a) (int)a.size()
#define ALL(a) a.begin(), a.end()
#define PB push_back
#define MP make_pair

typedef long long Int;
typedef vector<int> VI;
typedef pair<int, int> PII;

const double PI = acos(-1.0);
const int INF = 1000 * 1000 * 1000;
const Int LINF = INF * (Int) INF;
const int MAX = 100007;

const int MOD = 1000000007;

const double Pi = acos(-1.0);

PII t[4 * MAX];
int A[MAX];

void build(int v, int l, int r)
{
	if (l == r)
	{
	    t[v] = MP(A[l], l);
	    return;
	}
	int m = (l + r) / 2;
	build(2 * v + 1, l, m);
	build(2 * v + 2, m + 1, r);
	t[v] = max(t[2 * v + 1], t[2 * v + 2]);
}

void Set(int v, int l, int r, int pos, int val)
{
	if (l == r)
	{
	    t[v] = MP(val, l);
	    return;
	}
	int m = (l + r) / 2;
	if (pos <= m)
	    Set(2 * v + 1, l, m, pos, val);
	else
	    Set(2 * v + 2, m + 1 , r , pos, val);
	t[v] = max(t[2 * v + 1], t[2 * v + 2]);
}

PII Get(int v, int l, int r, int L, int R)
{
	if (L > R)
	    return MP(-1, -1);
	if (l == L && r == R)
	    return t[v];
	int m = (l + r) / 2;
	return max(Get(2 * v + 1, l, m, L, min(R, m)), Get(2 * v + 2, m + 1, r, max(L, m + 1), R));
}

int L[MAX];
int R[MAX];
int C[MAX];

int main(int argc, char* argv[])
{
	// freopen("in.txt", "r", stdin);
	//ios::sync_with_stdio(false); cin.tie(0);

	int t;
	cin >> t;
	FOR(tt,0,t)
	{
	    int n, m;
	    cin >> n >> m;

	    Int res = 0;
	    FOR(i,0,n)
	    {
	        cin >> A[i];
	        res += A[i];
	    }
	    build(0, 0, n - 1);
	    FOR(i,0,m)
	    {
	        cin >> L[i] >> R[i];
	        --L[i]; --R[i];
	        PII p = Get(0, 0, n - 1, L[i], R[i]);
	        C[i] = p.first;
	        if (p.first >= 0)
	        {
	            res -= p.first;
	            Set(0, 0, n - 1, p.second, -1);
	        }   
	    }
	    int add = 0;

	    build(0, 0, n - 1);
	    set<int> B;
	    int block_id = m;

	    RFOR(i, m, 0)
	    {
	        if (B.find(C[i]) == B.end()) {
	            if (C[i] > add)
	                block_id = i;
	            add = max(add, C[i]);
	        }
	        while(true) {
	            PII p = Get(0, 0, n - 1, L[i], R[i]);
	            if (p.first <= C[i]) break;
	            B.insert(p.first);
	            Set(0, 0, n - 1, p.second, -1);
	        }
	    }
	    res += add;
	    cout << res << endl;
	    cerr << block_id << ' ' << m << endl;
	}

	cerr << 1.0 * clock() / CLOCKS_PER_SEC << endl;

	
}
Editorialist's Solution
import java.util.*;
import java.io.*;
import java.text.*;
class SAVJEW{
	//SOLUTION BEGIN
	void pre() throws Exception{}
	void solve(int TC) throws Exception{
	    int n = ni(), m = ni();
	    Jewel[] a = new Jewel[n];
	    for(int i = 0; i< n; i++)a[i] = new Jewel(i, nl());
	    Arrays.sort(a, (Jewel j1, Jewel j2) -> Long.compare(j2.val, j1.val));
	    int[] le = new int[1+m], ri = new int[1+m];
	    for(int i = 1; i<= m; i++){
	        le[i] = ni()-1;
	        ri[i] = ni()-1;
	    }
	    
	    sz = 1;
	    while(sz < n)sz<<=1;
	    //q holds queue for each node of segtree, storing indices of segments which cover the node range in increasing order of indices
	    q = new ArrayDeque[sz<<1];
	    queueSize = new int[sz<<1];
	    for(int i = 0; i< sz+sz; i++)q[i] = new ArrayDeque<>();
	    for(int i = 1; i<= m; i++)add(le[i], ri[i], i, 0, sz-1, 1);
	    
	    used = new boolean[1+m];
	    long ans = 0;
	    boolean taken = false;//Have we restricted some segment yet?
	    for(int i = 0; i< n; i++){
	        int[] p = getPair(a[i].ind, 0, sz-1, 1);
	        //p[0] -> number of segments taking this jewel
	        //p[1] -> index of first segment taking this jewel
	        if(p[0] == 0){
	            //No segment cover this jewel
	            ans += a[i].val;
	        }else if(p[0] == 1 && !taken){
	            //Saving this jewel by restricting the only segment
	            ans += a[i].val;
	            used[p[1]] = true;
	            add(le[p[1]], ri[p[1]], -p[1], 0, sz-1, 1);
	            taken = true;
	            continue;
	        }else{
	            //Either more than one segment covers this jewel, or we have already used our restriction: the first covering segment steals it
	            used[p[1]] = true;
	            add(le[p[1]], ri[p[1]], -p[1], 0, sz-1, 1);
	        }
	    }
	    pn(ans);
	}
	int sz, INF = (int)1e9;
	Queue<Integer>[] q;
	int[] queueSize;
	boolean[] used;
	//Removing already used segments from queue front end
	void update(int i){
	    while(!q[i].isEmpty() && used[q[i].peek()])q[i].poll();
	}
	//Returns pair(count, firstInd) - count -> number of segments covering jewel at position pos, firstInd -> index of first segment covering jewel at position pos
	int[] getPair(int pos, int ll, int rr, int i){
	    int[] pair = new int[]{queueSize[i], INF};
	    update(i);
	    if(!q[i].isEmpty())pair[1] = q[i].peek();
	    
	    if(ll == rr)return pair;
	    int mid = (ll+rr)/2;
	    int[] childPair;
	    if(pos <= mid)childPair = getPair(pos, ll, mid, i<<1);
	    else childPair = getPair(pos, mid+1, rr, i<<1|1);
	    return new int[]{pair[0]+childPair[0], Math.min(pair[1], childPair[1])};
	}
	void add(int l, int r, int x, int ll, int rr, int i){
	    if(l == ll && r == rr){
	        if(x > 0){queueSize[i]++;q[i].add(x);}
	        else queueSize[i]--;
	        return;
	    }
	    int mid = (ll+rr)/2;
	    if(r <= mid)add(l, r, x, ll, mid, i<<1);
	    else if(l > mid)add(l, r, x, mid+1, rr, i<<1|1);
	    else{
	        add(l, mid, x, ll, mid, i<<1);
	        add(mid+1, r, x, mid+1, rr, i<<1|1);
	    }
	}
	class Jewel{
	    int ind;
	    long val;
	    public Jewel(int i, long v){
	        ind = i;val = v;
	    }
	}
	//SOLUTION END
	void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
	DecimalFormat df = new DecimalFormat("0.00000000000");
	static boolean multipleTC = true;
	FastReader in;PrintWriter out;
	void run() throws Exception{
	    in = new FastReader();
	    out = new PrintWriter(System.out);
	    //Solution Credits: Taranpreet Singh
	    int T = (multipleTC)?ni():1;
	    pre();for(int t = 1; t<= T; t++)solve(t);
	    out.flush();
	    out.close();
	}
	public static void main(String[] args) throws Exception{
	    new SAVJEW().run();
	}
	int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
	void p(Object o){out.print(o);}
	void pn(Object o){out.println(o);}
	void pni(Object o){out.println(o);out.flush();}
	String n()throws Exception{return in.next();}
	String nln()throws Exception{return in.nextLine();}
	int ni()throws Exception{return Integer.parseInt(in.next());}
	long nl()throws Exception{return Long.parseLong(in.next());}
	double nd()throws Exception{return Double.parseDouble(in.next());}

	class FastReader{
	    BufferedReader br;
	    StringTokenizer st;
	    public FastReader(){
	        br = new BufferedReader(new InputStreamReader(System.in));
	    }

	    public FastReader(String s) throws Exception{
	        br = new BufferedReader(new FileReader(s));
	    }

	    String next() throws Exception{
	        while (st == null || !st.hasMoreElements()){
	            try{
	                st = new StringTokenizer(br.readLine());
	            }catch (IOException  e){
	                throw new Exception(e.toString());
	            }
	        }
	        return st.nextToken();
	    }

	    String nextLine() throws Exception{
	        String str = "";
	        try{   
	            str = br.readLine();
	        }catch (IOException e){
	            throw new Exception(e.toString());
	        }  
	        return str;
	    }
	}
}

Feel free to share your approach, even if it is the same. Suggestions are welcome, as always.
