PROBLEM LINK:
Contest Division 1
Contest Division 2
Contest Division 3
Practice
Setter: Akshit Monga
Tester: Nishank Suresh
Editorialist: Taranpreet Singh
DIFFICULTY
Easy
PREREQUISITES
Sorting
PROBLEM
Given a sequence A containing N integers, you can perform the following operation any number of times.
- Select any continuous segment A_{l \ldots r} and sort the segment in non-decreasing order. The cost to perform this operation is max(A_{l \ldots r})-min(A_{l \ldots r}).
Determine the minimum cost to sort the sequence A in non-decreasing order.
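To make the operation concrete, here is a small Python sketch (Python because the setter's solution uses it). The helper name `apply_operation` is hypothetical, and it uses 0-indexed inclusive bounds instead of the 1-indexed notation above.

```python
def apply_operation(a, l, r):
    """Sort a[l..r] (0-indexed, inclusive); return (new_list, cost).

    The cost is max - min of the segment before sorting, as in the statement.
    """
    segment = a[l:r + 1]
    cost = max(segment) - min(segment)
    return a[:l] + sorted(segment) + a[r + 1:], cost


a = [1, 2, 5, 6, 4, 3, 7, 9, 8]
# A single operation on the whole array always sorts it, at cost max(A) - min(A).
b, cost = apply_operation(a, 0, len(a) - 1)
print(b, cost)  # [1, 2, 3, 4, 5, 6, 7, 8, 9] 8
```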
QUICK EXPLANATION
- The simplest solution uses a single operation covering the whole array, at cost max(A)-min(A). But we can do better.
- If two operations cover a common position, replacing them with a single operation on the union of their segments is never worse.
- So we only need one operation per block, where blocks are cut at every position x such that the first x elements of A form the same multiset as the first x elements of sorted A.
EXPLANATION
The most basic solution is a single operation covering the whole array. It costs max(A)-min(A) and sorts the array immediately, so a valid solution always exists.
Next, observe that if two operations A_{l_1 \ldots r_1} and A_{l_2 \ldots r_2} are applied where the intervals [l_1, r_1] and [l_2, r_2] overlap (even if they share only a single position), then one operation on the range [min(l_1, l_2), max(r_1, r_2)] is sufficient to sort both segments, at no greater cost.
Hence, we may assume that no position is covered by more than one operation.
Why overlapping segments aren't useful
Let’s assume the two intervals in question are [l_1, r_1] and [l_2, r_2]. WLOG assume l_1 \leq l_2; since the intervals overlap, r_1 \geq l_2.
- Case 1: r_2 \leq r_1, i.e. the second interval lies completely inside the first. Applying the operation on [l_1, r_1] sorts the whole range, including [l_2, r_2], so paying extra to sort [l_2, r_2] separately cannot help.
- Case 2: r_2 \gt r_1. The cost of applying these two operations is max(A_{l_1 \ldots r_1}) - min(A_{l_1 \ldots r_1}) + max(A_{l_2 \ldots r_2}) - min(A_{l_2 \ldots r_2}). The cost to apply one operation on segment [l_1, r_2] is max(A_{l_1 \ldots r_2}) - min(A_{l_1 \ldots r_2}).
We can see that max(A_{l_1 \ldots r_2}) is the larger of the two segment maxima, and min(A_{l_1 \ldots r_2}) is the smaller of the two segment minima. Since x + y - max(x, y) = min(x, y) and x + y - min(x, y) = max(x, y), the additional cost incurred by using two operations is D = min(max(A_{l_1 \ldots r_1}), max(A_{l_2 \ldots r_2})) - max(min(A_{l_1 \ldots r_1}), min(A_{l_2 \ldots r_2})).
Since the intervals overlap, a value at a shared position belongs to both segments, so it is at most both maxima and at least both minima; hence D is non-negative. Performing two operations instead of a single operation on the union of the intervals therefore costs D \geq 0 extra, so merging them is never worse.
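The identity for D can be checked numerically. The sketch below (hypothetical helper `merge_penalty`, 0-indexed inclusive bounds, assuming l_1 \leq l_2 \leq r_1) applies the first operation, measures the second segment on the updated array, and returns both the extra cost of two operations over one operation on the union and the closed-form D.

```python
def seg_cost(seg):
    """Cost of sorting one segment: max - min."""
    return max(seg) - min(seg)


def merge_penalty(a, l1, r1, l2, r2):
    """Return (two_ops - one_op, D) for overlapping intervals, l1 <= l2 <= r1."""
    first = a[l1:r1 + 1]
    after = a[:l1] + sorted(first) + a[r1 + 1:]  # array after the first operation
    second = after[l2:r2 + 1]                    # second segment, measured at its own time
    two_ops = seg_cost(first) + seg_cost(second)
    one_op = seg_cost(a[min(l1, l2):max(r1, r2) + 1])  # single operation on the union
    d = min(max(first), max(second)) - max(min(first), min(second))
    return two_ops - one_op, d


print(merge_penalty([5, 1, 4, 2, 3], 0, 2, 2, 4))  # (3, 3)
```

Both numbers agree, and D is non-negative, as the argument above predicts.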
Now the task becomes partitioning A into contiguous blocks and applying the operation once on each block. The blocks should be chosen to maximize their number, while ensuring that once every block is sorted, A becomes sorted.
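This claim can be sanity-checked by brute force on tiny inputs. The sketch below (hypothetical helpers `brute_force` and `partition_cost`) runs Dijkstra's algorithm over all arrangements of a small array, treating every "sort one segment" move as an edge weighted by its cost. It then compares the result with the cost of cutting A at every x where the first x elements are exactly the x smallest. The search is exponential and is only meant as a test on arrays of length 6 or less.

```python
import heapq
from itertools import product


def brute_force(a):
    """Exact minimum cost via Dijkstra over array states (tiny inputs only)."""
    start, goal = tuple(a), tuple(sorted(a))
    dist = {start: 0}
    heap = [(0, start)]
    while heap:
        d, state = heapq.heappop(heap)
        if state == goal:
            return d
        if d > dist[state]:
            continue  # stale heap entry
        n = len(state)
        for l in range(n):
            for r in range(l + 1, n):
                seg = state[l:r + 1]
                nxt = state[:l] + tuple(sorted(seg)) + state[r + 1:]
                nd = d + max(seg) - min(seg)
                if nd < dist.get(nxt, float('inf')):
                    dist[nxt] = nd
                    heapq.heappush(heap, (nd, nxt))
    return dist[goal]


def partition_cost(a):
    """Cut after every x where sorted(a[:x]) equals the x smallest values."""
    b = sorted(a)
    total, start = 0, 0
    for x in range(1, len(a) + 1):
        if sorted(a[:x]) == b[:x]:
            total += b[x - 1] - b[start]  # block holds b[start..x-1]
            start = x
    return total


# Exhaustive comparison over all arrays of length <= 5 with values in {1, 2, 3}.
for n in range(1, 6):
    for arr in product(range(1, 4), repeat=n):
        assert brute_force(list(arr)) == partition_cost(list(arr)), arr
print("all small cases agree")
```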
For example, for A = [1,2,5,6,4,3,7,9,8], we can partition it as \{[1],[2],[5,6,4,3],[7],[9,8]\} and apply the operation on each block, obtaining \{[1],[2],[3,4,5,6],[7],[8,9]\} at a cost of (6-3)+(9-8) = 4.
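The arithmetic of this example can be reproduced with a short sketch (hypothetical helper names). It sorts each block independently and sums max - min per block; single-element blocks contribute 0.

```python
def blocks_cost(blocks):
    """Cost of sorting each block once: sum of (max - min) per block."""
    return sum(max(block) - min(block) for block in blocks)


def sort_blocks(blocks):
    """Concatenate the blocks after sorting each one individually."""
    return [x for block in blocks for x in sorted(block)]


blocks = [[1], [2], [5, 6, 4, 3], [7], [9, 8]]
print(sort_blocks(blocks))  # [1, 2, 3, 4, 5, 6, 7, 8, 9]
print(blocks_cost(blocks))  # 4
```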
To find these blocks, we need to determine, for each x, whether the first x elements of A form the same multiset as the smallest x elements of A. In the above example, the positions satisfying this condition are 1, 2, 6, 7 and 9, and each of them ends a block. Let’s call these good positions.
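A direct reference implementation of this definition compares prefix multisets with `collections.Counter`. It is a naive O(N^2) sketch with a hypothetical function name, useful for testing faster versions rather than for submission.

```python
from collections import Counter


def good_positions(a):
    """1-indexed x where the first x elements of a equal, as a multiset,
    the x smallest elements of a. Naive O(N^2) reference."""
    b = sorted(a)
    return [x for x in range(1, len(a) + 1) if Counter(a[:x]) == Counter(b[:x])]


print(good_positions([1, 2, 5, 6, 4, 3, 7, 9, 8]))  # [1, 2, 6, 7, 9]
```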
Implementation
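Let B be A sorted in non-decreasing order. We maintain two multisets, A-B and B-A. After considering the first i positions, A-B stores the elements of A_{1 \ldots i} that are not matched in B_{1 \ldots i}, while B-A stores the elements of B_{1 \ldots i} that are not matched in A_{1 \ldots i} (multiset differences). By definition, position i is good exactly when both A-B and B-A are empty.
To add the (i+1)-th position, add A_{i+1} to A-B and B_{i+1} (the (i+1)-th smallest element) to B-A, then remove their intersection from both sets. In practice: if A_{i+1} is present in B-A, remove one copy of it from B-A, otherwise insert it into A-B; handle B_{i+1} symmetrically. When position i is good and the current block started at position s, the block holds exactly B_s, \ldots, B_i, so its cost is B_i - B_s.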
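Here is a Python sketch of this bookkeeping, mirroring the editorialist's Java solution but using `collections.Counter` as the multiset. The function and helper names are hypothetical.

```python
from collections import Counter


def min_sort_cost(a):
    """Minimum total cost, cutting a block at every good position."""
    b = sorted(a)
    a_minus_b = Counter()  # elements of the A-prefix not matched in the B-prefix
    b_minus_a = Counter()  # elements of the B-prefix not matched in the A-prefix

    def push(mine, other, x):
        # Cancel x against one copy on the other side if possible, else record it.
        if other[x] > 0:
            other[x] -= 1
            if other[x] == 0:
                del other[x]  # keep only positive counts so emptiness checks work
        else:
            mine[x] += 1

    ans, start = 0, 0
    for i in range(len(a)):
        push(a_minus_b, b_minus_a, a[i])
        push(b_minus_a, a_minus_b, b[i])
        if not a_minus_b and not b_minus_a:  # position i + 1 (1-indexed) is good
            ans += b[i] - b[start]           # block holds b[start..i]
            start = i + 1
    return ans


print(min_sort_cost([1, 2, 5, 6, 4, 3, 7, 9, 8]))  # 4
```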
TIME COMPLEXITY
The time complexity is O(N \log N) per test case: sorting takes O(N \log N), and each of the O(N) multiset operations takes O(\log N).
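The setter's and tester's solutions work with ranks instead of values: ties are broken by index, and an ordered set tracks the ranks not seen yet. A possible simplification is to keep only the running maximum rank; this is an assumption of this sketch, not part of either solution. Since ranks are distinct, the first x elements have ranks exactly 0, \ldots, x-1 if and only if their maximum rank is x-1. Sorting still dominates, so the bound stays O(N \log N). The function name is hypothetical.

```python
def min_sort_cost_ranks(a):
    """Rank elements (ties broken by index); x + 1 elements form a good prefix
    exactly when the largest rank among them equals x (0-indexed)."""
    n = len(a)
    order = sorted(range(n), key=lambda i: (a[i], i))  # order[r] = index with rank r
    rank = [0] * n
    for r, i in enumerate(order):
        rank[i] = r
    ans, start, best = 0, 0, -1
    for x in range(n):
        best = max(best, rank[x])
        if best == x:  # ranks 0..x all appear among the first x + 1 positions
            ans += a[order[x]] - a[order[start]]
            start = x + 1
    return ans


print(min_sort_cost_ranks([1, 2, 5, 6, 4, 3, 7, 9, 8]))  # 4
```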
SOLUTIONS
Setter's Solution
'''Author- Akshit Monga'''
from sys import stdin, stdout
input = stdin.readline


class SortedList:
    def __init__(self, iterable=[], _load=200):
        """Initialize sorted list instance."""
        values = sorted(iterable)
        self._len = _len = len(values)
        self._load = _load
        self._lists = _lists = [values[i:i + _load] for i in range(0, _len, _load)]
        self._list_lens = [len(_list) for _list in _lists]
        self._mins = [_list[0] for _list in _lists]
        self._fen_tree = []
        self._rebuild = True

    def _fen_build(self):
        """Build a fenwick tree instance."""
        self._fen_tree[:] = self._list_lens
        _fen_tree = self._fen_tree
        for i in range(len(_fen_tree)):
            if i | i + 1 < len(_fen_tree):
                _fen_tree[i | i + 1] += _fen_tree[i]
        self._rebuild = False

    def _fen_update(self, index, value):
        """Update `fen_tree[index] += value`."""
        if not self._rebuild:
            _fen_tree = self._fen_tree
            while index < len(_fen_tree):
                _fen_tree[index] += value
                index |= index + 1

    def _fen_query(self, end):
        """Return `sum(_fen_tree[:end])`."""
        if self._rebuild:
            self._fen_build()
        _fen_tree = self._fen_tree
        x = 0
        while end:
            x += _fen_tree[end - 1]
            end &= end - 1
        return x

    def _fen_findkth(self, k):
        """Return a pair of (the largest `idx` such that `sum(_fen_tree[:idx]) <= k`, `k - sum(_fen_tree[:idx])`)."""
        _list_lens = self._list_lens
        if k < _list_lens[0]:
            return 0, k
        if k >= self._len - _list_lens[-1]:
            return len(_list_lens) - 1, k + _list_lens[-1] - self._len
        if self._rebuild:
            self._fen_build()
        _fen_tree = self._fen_tree
        idx = -1
        for d in reversed(range(len(_fen_tree).bit_length())):
            right_idx = idx + (1 << d)
            if right_idx < len(_fen_tree) and k >= _fen_tree[right_idx]:
                idx = right_idx
                k -= _fen_tree[idx]
        return idx + 1, k

    def _delete(self, pos, idx):
        """Delete value at the given `(pos, idx)`."""
        _lists = self._lists
        _mins = self._mins
        _list_lens = self._list_lens
        self._len -= 1
        self._fen_update(pos, -1)
        del _lists[pos][idx]
        _list_lens[pos] -= 1
        if _list_lens[pos]:
            _mins[pos] = _lists[pos][0]
        else:
            del _lists[pos]
            del _list_lens[pos]
            del _mins[pos]
            self._rebuild = True

    def _loc_left(self, value):
        """Return an index pair that corresponds to the first position of `value` in the sorted list."""
        if not self._len:
            return 0, 0
        _lists = self._lists
        _mins = self._mins
        lo, pos = -1, len(_lists) - 1
        while lo + 1 < pos:
            mi = (lo + pos) >> 1
            if value <= _mins[mi]:
                pos = mi
            else:
                lo = mi
        if pos and value <= _lists[pos - 1][-1]:
            pos -= 1
        _list = _lists[pos]
        lo, idx = -1, len(_list)
        while lo + 1 < idx:
            mi = (lo + idx) >> 1
            if value <= _list[mi]:
                idx = mi
            else:
                lo = mi
        return pos, idx

    def _loc_right(self, value):
        """Return an index pair that corresponds to the last position of `value` in the sorted list."""
        if not self._len:
            return 0, 0
        _lists = self._lists
        _mins = self._mins
        pos, hi = 0, len(_lists)
        while pos + 1 < hi:
            mi = (pos + hi) >> 1
            if value < _mins[mi]:
                hi = mi
            else:
                pos = mi
        _list = _lists[pos]
        lo, idx = -1, len(_list)
        while lo + 1 < idx:
            mi = (lo + idx) >> 1
            if value < _list[mi]:
                idx = mi
            else:
                lo = mi
        return pos, idx

    def add(self, value):
        """Add `value` to sorted list."""
        _load = self._load
        _lists = self._lists
        _mins = self._mins
        _list_lens = self._list_lens
        self._len += 1
        if _lists:
            pos, idx = self._loc_right(value)
            self._fen_update(pos, 1)
            _list = _lists[pos]
            _list.insert(idx, value)
            _list_lens[pos] += 1
            _mins[pos] = _list[0]
            if _load + _load < len(_list):
                _lists.insert(pos + 1, _list[_load:])
                _list_lens.insert(pos + 1, len(_list) - _load)
                _mins.insert(pos + 1, _list[_load])
                _list_lens[pos] = _load
                del _list[_load:]
                self._rebuild = True
        else:
            _lists.append([value])
            _mins.append(value)
            _list_lens.append(1)
            self._rebuild = True

    def discard(self, value):
        """Remove `value` from sorted list if it is a member."""
        _lists = self._lists
        if _lists:
            pos, idx = self._loc_right(value)
            if idx and _lists[pos][idx - 1] == value:
                self._delete(pos, idx - 1)

    def remove(self, value):
        """Remove `value` from sorted list; `value` must be a member."""
        _len = self._len
        self.discard(value)
        if _len == self._len:
            raise ValueError('{0!r} not in list'.format(value))

    def pop(self, index=-1):
        """Remove and return value at `index` in sorted list."""
        pos, idx = self._fen_findkth(self._len + index if index < 0 else index)
        value = self._lists[pos][idx]
        self._delete(pos, idx)
        return value

    def bisect_left(self, value):
        """Return the first index to insert `value` in the sorted list."""
        pos, idx = self._loc_left(value)
        return self._fen_query(pos) + idx

    def bisect_right(self, value):
        """Return the last index to insert `value` in the sorted list."""
        pos, idx = self._loc_right(value)
        return self._fen_query(pos) + idx

    def count(self, value):
        """Return number of occurrences of `value` in the sorted list."""
        return self.bisect_right(value) - self.bisect_left(value)

    def __len__(self):
        """Return the size of the sorted list."""
        return self._len

    def __getitem__(self, index):
        """Lookup value at `index` in sorted list."""
        pos, idx = self._fen_findkth(self._len + index if index < 0 else index)
        return self._lists[pos][idx]

    def __delitem__(self, index):
        """Remove value at `index` from sorted list."""
        pos, idx = self._fen_findkth(self._len + index if index < 0 else index)
        self._delete(pos, idx)

    def __contains__(self, value):
        """Return true if `value` is an element of the sorted list."""
        _lists = self._lists
        if _lists:
            pos, idx = self._loc_left(value)
            return idx < len(_lists[pos]) and _lists[pos][idx] == value
        return False

    def __iter__(self):
        """Return an iterator over the sorted list."""
        return (value for _list in self._lists for value in _list)

    def __reversed__(self):
        """Return a reverse iterator over the sorted list."""
        return (value for _list in reversed(self._lists) for value in reversed(_list))

    def __repr__(self):
        """Return string representation of sorted list."""
        return 'SortedList({0})'.format(list(self))


t = int(input())
for _ in range(t):
    n=int(input())
    arr=[int(x) for x in input().split()]
    p=[-1 for i in range(n)]
    d=sorted([(arr[i],i) for i in range(n)])
    c=1
    for i in d:
        p[i[1]]=c
        c+=1
    eles=SortedList([i for i in range(1,n+1)])
    mini=float('inf')
    maxi=-float('inf')
    ans=0
    for i in p:
        mini=min(mini,i)
        maxi=max(maxi,i)
        eles.remove(i)
        if not len(eles) or eles[0]==maxi+1:
            ans+=d[maxi-1][0]-d[mini-1][0]
            mini=float('inf')
            maxi=-float('inf')
    print(ans)
Tester's Solution
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(0); cin.tie(0);

    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        vector<array<int, 2>> a(n);
        for (int i = 0; i < n; ++i) {
            cin >> a[i][0];
            a[i][1] = i;
        }
        sort(begin(a), end(a));
        vector<int> p(n);
        for (int i = 0; i < n; ++i)
            p[a[i][1]] = i;
        ll ans = 0;
        set<int> active;
        for (int i = 0; i < n; ++i)
            active.insert(i);
        int mn = INT_MAX, mx = INT_MIN;
        for (auto i : p) {
            active.erase(i);
            mn = min(mn, i);
            mx = max(mx, i);
            if (active.empty() or *active.begin() > mx) {
                ans += a[mx][0] - a[mn][0];
                mx = INT_MIN;
                mn = INT_MAX;
            }
        }
        cout << ans << '\n';
    }
}
Editorialist's Solution
import java.util.*;
import java.io.*;
class Main{
    //SOLUTION BEGIN
    void pre() throws Exception{}
    void solve(int TC) throws Exception{
        int N = ni();
        int[] A = new int[N], B = new int[N];
        for(int i = 0; i< N; i++)A[i] = B[i] = ni();
        Arrays.sort(B);
        MyTreeSet<Integer> AB = new MyTreeSet<>(), BA = new MyTreeSet<>();
        //AB -> multiset of elements present in A not present in B
        //BA -> multiset of elements present in B not present in A
        int st = 0;
        int ans = 0;
        for(int i = 0; i< N; i++){
            add(AB, BA, A[i]);
            add(BA, AB, B[i]);
            if(same(AB, BA)){
                ans += B[i]-B[st];
                st = i+1;
            }
        }
        pn(ans);
    }
    boolean same(MyTreeSet<Integer> s1, MyTreeSet<Integer> s2){
        return s1.isEmpty() && s2.isEmpty();
    }
    //Remove x from s2 if present, otherwise add to s1
    void add(MyTreeSet<Integer> s1, MyTreeSet<Integer> s2, int x){
        if(s2.contains(x))s2.remove(x);
        else s1.add(x);
    }
    //Multiset in java, equivalent of c++ multiset
    class MyTreeSet<T> implements Iterable<T>{
        private int size;
        private TreeMap<T, Integer> map;
        public MyTreeSet(){
            size = 0;
            map = new TreeMap<>();
        }
        public int size(){return size;}
        public int dsize(){return map.size();}
        public boolean isEmpty(){return size==0;}
        public void add(T t){
            size++;
            map.put(t, map.getOrDefault(t, 0)+1);
        }
        public boolean remove(T t){
            if(!map.containsKey(t))return false;
            size--;
            int c = map.get(t);
            if(c==1)map.remove(t);
            else map.put(t, c-1);
            return true;
        }
        public int freq(T t){return map.getOrDefault(t, 0);}
        public boolean contains(T t){return map.getOrDefault(t,0)>0;}
        public T ceiling(T t){return map.ceilingKey(t);}
        public T floor(T t){return map.floorKey(t);}
        public T first(){return map.firstKey();}
        public T last(){return map.lastKey();}
        public Iterator<T> iterator() {
            return new MyTreeSetIterator<>(this);
        }
        class MyTreeSetIterator<T> implements Iterator<T>{
            TreeMap<T, Integer> mp;
            T element = null;
            int cur = 0, freq = 0;
            MyTreeSetIterator(MyTreeSet<T> obj){
                this.mp = obj.map;
                if(!this.mp.isEmpty()){
                    element = mp.firstKey();
                    freq = mp.firstEntry().getValue();
                }
            }
            public boolean hasNext(){
                return element != null;
            }
            public T next(){
                T ret = element;
                cur++;
                if(cur == freq){
                    Map.Entry<T, Integer> e = mp.higherEntry(element);
                    if(e == null){
                        element = null;
                        freq = 0;
                    }else{
                        element = e.getKey();
                        freq = e.getValue();
                        cur = 0;
                    }
                }
                return ret;
            }
        }
    }
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    static boolean multipleTC = true;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        new Main().run();
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}
    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }
        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }
        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }
        String nextLine() throws Exception{
            String str = "";
            try{
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }
            return str;
        }
    }
}
Feel free to share your approach. Suggestions are welcome as always.