OPTSORT - Editorial


Contest Division 1
Contest Division 2
Contest Division 3

Setter: Akshit Monga
Tester: Nishank Suresh
Editorialist: Taranpreet Singh






Given a sequence A containing N integers, you can perform the following operation any number of times.

  • Select any continuous segment A_{l \ldots r} and sort the segment in non-decreasing order. The cost to perform this operation is max(A_{l \ldots r})-min(A_{l \ldots r}).

Determine the minimum cost to sort the sequence A in non-decreasing order.


  • The simplest solution uses a single operation covering the whole array, at cost max(A)-min(A). But we can do better.
  • If two operations cover a common element, it is never worse to replace them with a single operation applied to the union of their two segments.
  • So the array splits into segments ending at every position x where the first x elements of A form the same multiset as the first x elements of sorted A, and each segment is sorted with one operation.


The most basic solution would be a single operation where the whole array is covered. It’d cost max(A)-min(A) and sorts the array immediately. So we always have a solution to sort the array.

Secondly, if there are two operations A_{l_1 \ldots r_1} and A_{l_2 \ldots r_2} whose intervals [l_1, r_1] and [l_2, r_2] overlap, even in just a single shared endpoint, we can prove that one operation on the range [min(l_1, l_2), max(r_1, r_2)] is sufficient to sort both segments, at no greater cost.

Hence, in an optimal solution, no position needs to be covered by more than one operation.

Why overlapping segments aren't useful

Let’s assume two intervals in question are [l_1, r_1] and [l_2, r_2]. WLOG assume l_1 \leq l_2 and r_1 \geq l_2, since we are talking about overlapping intervals.

  • Case 1: r_2 \leq r_1, which implies the second interval lies completely inside the first. Applying the operation on segment [l_1, r_1] sorts the whole range, so the range [l_2, r_2] is already sorted and it isn’t optimal to incur cost sorting an already sorted segment.
  • Case 2: r_2 \gt r_1. The cost of applying these two operations is max(A_{l_1 \ldots r_1}) - min(A_{l_1 \ldots r_1}) + max(A_{l_2 \ldots r_2}) - min(A_{l_2 \ldots r_2}). The cost to apply one operation on segment [l_1, r_2] is max(A_{l_1 \ldots r_2}) - min(A_{l_1 \ldots r_2}).
    We can see that max(A_{l_1 \ldots r_2}) is the maximum of the maxima of the two segments, and min(A_{l_1 \ldots r_2}) is the minimum of their minima. Hence, the additional cost we incur by using two operations is D = min(max(A_{l_1 \ldots r_1}), max(A_{l_2 \ldots r_2})) - max(min(A_{l_1 \ldots r_1}), min(A_{l_2 \ldots r_2})).
    Since the intervals overlap, they share at least one position, whose value lies between the minimum and the maximum of both segments, so D is guaranteed to be non-negative. Hence, by performing two operations instead of a single operation on the union of the intervals, we incur a non-negative additional cost, which is not optimal.
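
The algebra behind D can be spelled out as follows. The symbols M_1, m_1, M_2, m_2 are shorthand introduced here (they are not in the original text) for the maximum and minimum of the first and second segment at the moment each operation is applied, and v is the value at any shared position when the second operation is applied. The derivation only uses the identity max(x, y) + min(x, y) = x + y.

```latex
\begin{align*}
D &= (M_1 - m_1) + (M_2 - m_2) - \bigl(\max(M_1, M_2) - \min(m_1, m_2)\bigr) \\
  &= \bigl(M_1 + M_2 - \max(M_1, M_2)\bigr) - \bigl(m_1 + m_2 - \min(m_1, m_2)\bigr) \\
  &= \min(M_1, M_2) - \max(m_1, m_2) \\
  &\geq v - v = 0, \qquad \text{because } m_1, m_2 \leq v \leq M_1, M_2 .
\end{align*}
```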

Now, the task becomes partitioning A into contiguous segments, with the operation applied once to each segment. The segments are chosen so as to maximize their number while ensuring that once every segment is sorted, A becomes sorted.

For example, for A = [1,2,5,6,4,3,7,9,8], we can partition it as \{[1],[2],[5,6,4,3],[7],[9,8]\} and apply the operation on each segment, leading to \{[1],[2],[3,4,5,6],[7],[8,9]\} at cost (6-3)+(9-8) = 4. Single-element segments cost 0.
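
On tiny inputs, the claim that sorting disjoint segments is enough can be sanity-checked against an exhaustive search. The sketch below is not part of any solution in this editorial: `brute_force_min_cost` is a hypothetical helper that runs Dijkstra over every array state reachable through the operation, so it is only practical for roughly N ≤ 7.

```python
import heapq


def brute_force_min_cost(a):
    """Exhaustive Dijkstra over array states (hypothetical checker, tiny N only)."""
    start, goal = tuple(a), tuple(sorted(a))
    dist = {start: 0}
    heap = [(0, start)]
    while heap:
        cost, state = heapq.heappop(heap)
        if state == goal:
            return cost
        if cost > dist[state]:
            continue  # stale heap entry
        n = len(state)
        for l in range(n):
            for r in range(l + 1, n):
                seg = state[l:r + 1]
                # Sorting A[l..r] costs max(segment) - min(segment).
                nxt = state[:l] + tuple(sorted(seg)) + state[r + 1:]
                ncost = cost + max(seg) - min(seg)
                if ncost < dist.get(nxt, float("inf")):
                    dist[nxt] = ncost
                    heapq.heappush(heap, (ncost, nxt))
    return dist[goal]
```

For instance, `brute_force_min_cost([1, 2, 5, 6, 4, 3])` returns 3, which matches sorting only the segment [5, 6, 4, 3].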

In order to find these partitions, we need to find for each x, whether the first x elements in A form the same multiset as the smallest x elements of A. In the above example, positions satisfying that condition are 1,2,6,7,9, each of which forms a partition ending at that position. Let’s call these good positions.
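
The definition of good positions can be checked directly on the example. This is a deliberately naive sketch (it re-sorts every prefix, so it is far slower than the intended solution); `good_positions` and `partition_cost` are illustrative names, not taken from the solutions in this editorial.

```python
def good_positions(a):
    """1-based positions x where sorted(a[:x]) equals the x smallest elements of a."""
    b = sorted(a)
    return [x for x in range(1, len(a) + 1) if sorted(a[:x]) == b[:x]]


def partition_cost(a):
    """Sum of (max - min) over the segments that end at each good position."""
    b = sorted(a)
    cost, start = 0, 0
    for x in good_positions(a):
        # After sorting, the segment occupies b[start .. x-1],
        # so its minimum is b[start] and its maximum is b[x-1].
        cost += b[x - 1] - b[start]
        start = x
    return cost
```

With A = [1,2,5,6,4,3,7,9,8], `good_positions` gives [1, 2, 6, 7, 9] and `partition_cost` gives 4, in line with the example above.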


All we need to do is maintain two multisets, A-B and B-A, where B denotes the sorted copy of A. After considering the first i positions, A-B stores the elements among the first i elements of A that are not matched among the first i elements of B, while B-A stores the elements among the first i elements of B that are not matched among the first i elements of A. At each good position, both A-B and B-A are empty by definition.

To add the (i+1)-th element, add A_{i+1} to A-B and B_{i+1} (the (i+1)-th smallest element) to B-A, then remove the intersection of A-B and B-A from both multisets.
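
A minimal sketch of the two-multiset bookkeeping, assuming `collections.Counter` as the multiset. The names `only_a`, `only_b` and `pending` are illustrative. Instead of intersecting the multisets after each step, an incoming element cancels a matching element on the other side immediately, which has the same effect.

```python
from collections import Counter


def min_sort_cost(a):
    b = sorted(a)
    only_a, only_b = Counter(), Counter()  # the multisets A-B and B-A
    pending = 0                            # combined size of both multisets
    cost, start = 0, 0

    def add(own, other, x):
        """Cancel x against `other` if present there, otherwise store it in `own`."""
        nonlocal pending
        if other[x] > 0:
            other[x] -= 1
            pending -= 1
        else:
            own[x] += 1
            pending += 1

    for i, (x, y) in enumerate(zip(a, b)):
        add(only_a, only_b, x)
        add(only_b, only_a, y)
        if pending == 0:                   # both multisets empty: good position
            cost += b[i] - b[start]        # max - min of the segment just closed
            start = i + 1
    return cost
```

After the initial sort, the loop does a constant amount of dictionary work per element, so the sort dominates the running time.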


The time complexity is O(N*log(N)) per test case due to sorting.


Setter's Solution
'''Author- Akshit Monga'''
from sys import stdin, stdout
input = stdin.readline
class SortedList:
    def __init__(self, iterable=[], _load=200):
        """Initialize sorted list instance."""
        values = sorted(iterable)
        self._len = _len = len(values)
        self._load = _load
        self._lists = _lists = [values[i:i + _load] for i in range(0, _len, _load)]
        self._list_lens = [len(_list) for _list in _lists]
        self._mins = [_list[0] for _list in _lists]
        self._fen_tree = []
        self._rebuild = True
    def _fen_build(self):
        """Build a fenwick tree instance."""
        self._fen_tree[:] = self._list_lens
        _fen_tree = self._fen_tree
        for i in range(len(_fen_tree)):
            if i | i + 1 < len(_fen_tree):
                _fen_tree[i | i + 1] += _fen_tree[i]
        self._rebuild = False
    def _fen_update(self, index, value):
        """Update `fen_tree[index] += value`."""
        if not self._rebuild:
            _fen_tree = self._fen_tree
            while index < len(_fen_tree):
                _fen_tree[index] += value
                index |= index + 1
    def _fen_query(self, end):
        """Return `sum(_fen_tree[:end])`."""
        if self._rebuild:
            self._fen_build()
        _fen_tree = self._fen_tree
        x = 0
        while end:
            x += _fen_tree[end - 1]
            end &= end - 1
        return x
    def _fen_findkth(self, k):
        """Return a pair of (the largest `idx` such that `sum(_fen_tree[:idx]) <= k`, `k - sum(_fen_tree[:idx])`)."""
        _list_lens = self._list_lens
        if k < _list_lens[0]:
            return 0, k
        if k >= self._len - _list_lens[-1]:
            return len(_list_lens) - 1, k + _list_lens[-1] - self._len
        if self._rebuild:
            self._fen_build()
        _fen_tree = self._fen_tree
        idx = -1
        for d in reversed(range(len(_fen_tree).bit_length())):
            right_idx = idx + (1 << d)
            if right_idx < len(_fen_tree) and k >= _fen_tree[right_idx]:
                idx = right_idx
                k -= _fen_tree[idx]
        return idx + 1, k
    def _delete(self, pos, idx):
        """Delete value at the given `(pos, idx)`."""
        _lists = self._lists
        _mins = self._mins
        _list_lens = self._list_lens
        self._len -= 1
        self._fen_update(pos, -1)
        del _lists[pos][idx]
        _list_lens[pos] -= 1
        if _list_lens[pos]:
            _mins[pos] = _lists[pos][0]
        else:
            # The bucket became empty: drop it and rebuild the Fenwick tree lazily.
            del _lists[pos]
            del _list_lens[pos]
            del _mins[pos]
            self._rebuild = True
    def _loc_left(self, value):
        """Return an index pair that corresponds to the first position of `value` in the sorted list."""
        if not self._len:
            return 0, 0
        _lists = self._lists
        _mins = self._mins
        lo, pos = -1, len(_lists) - 1
        while lo + 1 < pos:
            mi = (lo + pos) >> 1
            if value <= _mins[mi]:
                pos = mi
            else:
                lo = mi
        if pos and value <= _lists[pos - 1][-1]:
            pos -= 1
        _list = _lists[pos]
        lo, idx = -1, len(_list)
        while lo + 1 < idx:
            mi = (lo + idx) >> 1
            if value <= _list[mi]:
                idx = mi
            else:
                lo = mi
        return pos, idx
    def _loc_right(self, value):
        """Return an index pair that corresponds to the last position of `value` in the sorted list."""
        if not self._len:
            return 0, 0
        _lists = self._lists
        _mins = self._mins
        pos, hi = 0, len(_lists)
        while pos + 1 < hi:
            mi = (pos + hi) >> 1
            if value < _mins[mi]:
                hi = mi
            else:
                pos = mi
        _list = _lists[pos]
        lo, idx = -1, len(_list)
        while lo + 1 < idx:
            mi = (lo + idx) >> 1
            if value < _list[mi]:
                idx = mi
            else:
                lo = mi
        return pos, idx
    def add(self, value):
        """Add `value` to sorted list."""
        _load = self._load
        _lists = self._lists
        _mins = self._mins
        _list_lens = self._list_lens
        self._len += 1
        if _lists:
            pos, idx = self._loc_right(value)
            self._fen_update(pos, 1)
            _list = _lists[pos]
            _list.insert(idx, value)
            _list_lens[pos] += 1
            _mins[pos] = _list[0]
            if _load + _load < len(_list):
                _lists.insert(pos + 1, _list[_load:])
                _list_lens.insert(pos + 1, len(_list) - _load)
                _mins.insert(pos + 1, _list[_load])
                _list_lens[pos] = _load
                del _list[_load:]
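                self._rebuild = True
        else:
            # First element ever inserted: create the first bucket.
            _lists.append([value])
            _mins.append(value)
            _list_lens.append(1)
            self._rebuild = True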
    def discard(self, value):
        """Remove `value` from sorted list if it is a member."""
        _lists = self._lists
        if _lists:
            pos, idx = self._loc_right(value)
            if idx and _lists[pos][idx - 1] == value:
                self._delete(pos, idx - 1)
    def remove(self, value):
        """Remove `value` from sorted list; `value` must be a member."""
        _len = self._len
        self.discard(value)
        if _len == self._len:
            raise ValueError('{0!r} not in list'.format(value))
    def pop(self, index=-1):
        """Remove and return value at `index` in sorted list."""
        pos, idx = self._fen_findkth(self._len + index if index < 0 else index)
        value = self._lists[pos][idx]
        self._delete(pos, idx)
        return value
    def bisect_left(self, value):
        """Return the first index to insert `value` in the sorted list."""
        pos, idx = self._loc_left(value)
        return self._fen_query(pos) + idx
    def bisect_right(self, value):
        """Return the last index to insert `value` in the sorted list."""
        pos, idx = self._loc_right(value)
        return self._fen_query(pos) + idx
    def count(self, value):
        """Return number of occurrences of `value` in the sorted list."""
        return self.bisect_right(value) - self.bisect_left(value)
    def __len__(self):
        """Return the size of the sorted list."""
        return self._len
    def __getitem__(self, index):
        """Lookup value at `index` in sorted list."""
        pos, idx = self._fen_findkth(self._len + index if index < 0 else index)
        return self._lists[pos][idx]
    def __delitem__(self, index):
        """Remove value at `index` from sorted list."""
        pos, idx = self._fen_findkth(self._len + index if index < 0 else index)
        self._delete(pos, idx)
    def __contains__(self, value):
        """Return true if `value` is an element of the sorted list."""
        _lists = self._lists
        if _lists:
            pos, idx = self._loc_left(value)
            return idx < len(_lists[pos]) and _lists[pos][idx] == value
        return False
    def __iter__(self):
        """Return an iterator over the sorted list."""
        return (value for _list in self._lists for value in _list)
    def __reversed__(self):
        """Return a reverse iterator over the sorted list."""
        return (value for _list in reversed(self._lists) for value in reversed(_list))
    def __repr__(self):
        """Return string representation of sorted list."""
        return 'SortedList({0})'.format(list(self))
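t = int(input())
for _ in range(t):
    n = int(input())
    arr=[int(x) for x in input().split()]
    p=[-1 for i in range(n)]
    d=sorted([(arr[i],i) for i in range(n)])
    # NOTE: the loop bodies here are reconstructed by analogy with the Tester's solution (an assumption).
    for rank, i in enumerate(d, 1):
        p[i[1]]=rank                      # 1-based rank of position i[1] in sorted order
    eles=SortedList([i for i in range(1,n+1)])
    ans, mini, maxi = 0, n+1, 0
    for i in p:
        eles.remove(i)                    # rank i has now been seen
        mini, maxi = min(mini,i), max(maxi,i)
        if not len(eles) or eles[0]==maxi+1:
            ans += d[maxi-1][0]-d[mini-1][0]
            mini, maxi = n+1, 0
    print(ans)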
Tester's Solution
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(0); cin.tie(0);

    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        vector<array<int, 2>> a(n);
        for (int i = 0; i < n; ++i) {
            cin >> a[i][0];
            a[i][1] = i;
        }
        sort(begin(a), end(a));
        // p[i] = rank (0-based) of the element originally at position i
        vector<int> p(n);
        for (int i = 0; i < n; ++i)
            p[a[i][1]] = i;
        ll ans = 0;
        // active = ranks that have not appeared in the prefix scanned so far
        set<int> active;
        for (int i = 0; i < n; ++i)
            active.insert(i);
        int mn = INT_MAX, mx = INT_MIN;
        for (auto i : p) {
            active.erase(i);
            mn = min(mn, i);
            mx = max(mx, i);
            // Every rank up to mx has been seen: the current segment can be closed.
            if (active.empty() or *active.begin() > mx) {
                ans += a[mx][0] - a[mn][0];
                mx = INT_MIN;
                mn = INT_MAX;
            }
        }
        cout << ans << '\n';
    }
}
Editorialist's Solution
import java.util.*;
import java.io.*;
class Main{
    void pre() throws Exception{}
    void solve(int TC) throws Exception{
        int N = ni();
        int[] A = new int[N], B = new int[N];
        for(int i = 0; i< N; i++)A[i] = B[i] = ni();
        Arrays.sort(B);
        MyTreeSet<Integer> AB = new MyTreeSet<>(), BA = new MyTreeSet<>();
        //AB -> multiset of elements present in A not present in B
        //BA -> multiset of elements present in B not present in A
        int st = 0;
        int ans = 0;
        for(int i = 0; i< N; i++){
            add(AB, BA, A[i]);
            add(BA, AB, B[i]);
            if(same(AB, BA)){
                ans += B[i]-B[st];
                st = i+1;
            }
        }
        pn(ans);
    }
    boolean same(MyTreeSet<Integer> s1, MyTreeSet<Integer> s2){
        return s1.isEmpty() && s2.isEmpty();
    }
    //Remove x from s2 if present, otherwise add to s1
    void add(MyTreeSet<Integer> s1, MyTreeSet<Integer> s2, int x){
        if(s2.contains(x))s2.remove(x);
        else s1.add(x);
    }
    //Multiset in java, equivalent of c++ multiset
    class MyTreeSet<T> implements Iterable<T>{
        private int size;
        private TreeMap<T, Integer> map;
        public MyTreeSet(){
            size = 0;
            map = new TreeMap<>();
        }
        public int size(){return size;}
        public int dsize(){return map.size();}
        public boolean isEmpty(){return size==0;}
        public void add(T t){
            size++;
            map.put(t, map.getOrDefault(t, 0)+1);
        }
        public boolean remove(T t){
            if(!map.containsKey(t))return false;
            size--;
            int c = map.get(t);
            if(c==1)map.remove(t);
            else map.put(t, c-1);
            return true;
        }
        public int freq(T t){return map.getOrDefault(t, 0);}
        public boolean contains(T t){return map.getOrDefault(t,0)>0;}
        public T ceiling(T t){return map.ceilingKey(t);}
        public T floor(T t){return map.floorKey(t);}
        public T first(){return map.firstKey();}
        public T last(){return map.lastKey();}
        public Iterator<T> iterator() { 
            return new MyTreeSetIterator<>(this); 
        }
        class MyTreeSetIterator<T> implements Iterator<T>{
            TreeMap<T, Integer> mp;
            T element = null;
            int cur = 0, freq = 0;
            MyTreeSetIterator(MyTreeSet<T> obj){
                this.mp = obj.map;
                if(!mp.isEmpty()){
                    element = mp.firstKey();
                    freq = mp.firstEntry().getValue();
                }
            }
            public boolean hasNext(){
                return element != null;
            }
            public T next(){
                T ret = element;
                cur++;
                if(cur == freq){
                    Map.Entry<T, Integer> e = mp.higherEntry(element);
                    if(e == null){
                        element = null;
                        freq = 0;
                    }else{
                        element = e.getKey();
                        freq = e.getValue();
                    }
                    cur = 0;
                }
                return ret;
            }
        }
    }
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    static boolean multipleTC = true;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        new Main().run();
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}

    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }

        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }

        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException  e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }

        String nextLine() throws Exception{
            String str = "";
            try{
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }
            return str;
        }
    }
}

Feel free to share your approach. Suggestions are welcomed as always. :slight_smile:


OMG i cant even comprehend how the Tester’s solution works … i solved a self made test case by this code and magically it just came to the right solution , and i don’t even understand where did that happen… someone explain the logic behind it


Cool problem. I struggled a lot to implement the idea :/. Anyways, it can also be done in O(n) - Since the partitions are disjoint this should hold: max of ith partition < min of (i + 1)th partition. So, we iterate from n - 1 to 0 and check if the current suffix min is strictly greater than the prefix max. If it is then we add max - min of current suffix to the answer. Here is my submission
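
A minimal sketch of this linear idea, written as a left-to-right scan rather than the right-to-left pass described in the comment (the linked submission is not reproduced here, so the details are an assumption). `min_sort_cost_linear` is an illustrative name. A segment is closed as soon as its maximum is strictly smaller than the minimum of everything to its right.

```python
import math


def min_sort_cost_linear(a):
    n = len(a)
    # suffix_min[i] = min(a[i:]), with +inf as a sentinel past the end
    suffix_min = [math.inf] * (n + 1)
    for i in range(n - 1, -1, -1):
        suffix_min[i] = min(a[i], suffix_min[i + 1])

    cost = 0
    seg_min, seg_max = math.inf, -math.inf
    for i, x in enumerate(a):
        seg_min = min(seg_min, x)
        seg_max = max(seg_max, x)
        if seg_max < suffix_min[i + 1]:   # nothing to the right belongs in this segment
            cost += seg_max - seg_min
            seg_min, seg_max = math.inf, -math.inf
    return cost
```

Using a strict comparison merges neighbouring segments whose boundary values are equal. That does not change the total, because the merged segment costs exactly the sum of the two parts when the left maximum equals the right minimum.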


Trust me you wouldn’t wanna know…

But I really wanna know … I even ran it by pen and paper …seems like they are first storing the initial positions before sorting … and then I think they are using vector p to put them in ascending order … but how does this come towards the solution is … kinda weird

I have seen the formal proof and it’s pretty long. I am trying to grasp an easier way of visualizing it instead (given by the same person who gave the formal proof).

My biggest problem was why taking the earliest matching prefix always is optimal… it may have been intuitive for some but it’s not for me.

Editorial has omitted a ton of details.

  1. No proof as to why taking overlapping segments is not optimal.
  2. No proof as to why breaking two matching segments / prefixes is better i.e taking the earliest matching segment.

I like the problem and indeed I thought in terms of breaking into earliest matching prefixes but it was too huge and absurd an assumption for me which I couldn’t prove even visually therefore I never went after it.


I want to know how the N*log N solution can work. According to the constraints, the number of test cases can go up to 10e5 and N can also go up to 10e5. Shouldn’t the time complexity be less than O(N) in the worst case?

I am not able to figure out the problem in my solution. Can someone help me?

ll n; read(n);
vector<ll> a(n), b(n);
rep(i, 0, n) {
    ll x; read(x); a[i] = x; b[i] = x;
}
ll l = -1, r = -1;
vector<pair<ll, ll>> p;
rep(i, 0, n) {
    if (a[i] == b[i]) {
        if (l != -1) {r = i; p.pb({l, r}); l = -1; r = -1;}
    }
    else {
        if (l == -1)
            l = i;
    }
}
ll total = 0;
if (l != -1) {r = n; p.pb({l, r});}
for (auto I : p) {
    // cout << I.ff << " " << I.ss << ed;
    ll start = I.ff;
    ll end = I.ss;
    ll mi = 1e10;
    ll mx = -1;
    if (end - start <= 1) {continue;}
    for (ll i = start; i < end; i++) {
        mi = min(a[start] , mi);
        mx = max(a[start] , mx);
    }
    total += (mx - mi);
}
ll v = b[n - 1] - b[0];
if (p.size() >= 1 and total == 0) {cout << v << ed; return;}
else {
    cout << min(total, v) << ed;
}


Sum of all the N over all test cases will not exceed 2*10^5

Does that mean N*T would not be more than 2*10^5?

lol even though I didn’t participate I was also struggling while implementing the logic as it was a bit uneasy yet simple. I had done in NlogN finally. Btw can you elaborate the idea of yours a bit more? I would love to know in detail.

I feel the same way ,it was really unintuitive for me . In a normal codechef round I can reach solution to a problem with 100+ submissions . But the submissions for this one didn’t even made sense. 1000+ lol :joy:. Either I have gone stupid in last 2 days since the infinity round or the codechef audience has become super smart.


The Only Observation, You Need To Solve The Problem
Let’s suppose you would Sort A[L…R],
Now, this Sort is only suitable if L to R contains every value that it should contain in the Sorted version and Obviously, A[L] and A[R] must not contain their actual sorted values because if they are in their correct sorted position, then messing up with them is just gonna increase your cost.
Why ? Lets Discuss →

say you had decided to sort a segment from A[i…j] and calculated the minimum cost. And let’s say iterating further from j and now you are at “k” and you found a number A[k] that actually belongs to index “p” which belongs to the segment of A[i…j] according to the sorted order.
→ Now you see Does it make sense now to sort the section from A[p…k] just to move A[k] in the right position? Think about it :innocent:

Gave it a thought? Exactly! Now you can see that if you perform the above, you incur extra cost. (How? You can figure it out yourself, because anyone can understand this if he/she tries.)

Yep, especially my college has a cheater who writes in their bio something like
“My motto is to run O(n!) codes in O(1)”. I am like, seriously bro :rofl:
I even have posted 2-3 times in the discussion section about cheating but deleted them because of unhealthy blaming and allegations publicly.

And I seriously don’t know what the point of giving 3 hours in the first place. People will only do cheating in the last hour.


I mean cheating is very common these days in almost every contest , the only way to beat them is to get better. You can post blogs about them all you want but in the end nothing will happen.

For me , I want to reach 5 :star: only through my hardwork , If I cheat and do that , I know I won’t be satisfied because of the shame.


Yes, brother, We will reach 5 star one day eventually!!


never mind got the bug : )

if everyone thinks like you then there will be so much fun. cheating is like very very common in 3 hour contest. I even talked to um_nik regarding this topic to decrease the time from 3 hour to 2 hour. i once reached rating 1927 very close to 5 star but due to massive cheating i am now 3 star as i was not able to solve 3 question. i also want to reach 5 star but only from my hardwork.
reaching 5 star with cheating <<< reaching 4 star with hardwork.

yep, I left coding for many days. I even once posted a blog about what are the suggestions for reaching 4 stars. Now when I actually become 4 star I was so proud but when I checked my college’s rank list there are even cheaters with 5 stars, I was so frustrated that I decided not to compare myself to anyone so that I don’t get demotivated by cheaters anymore.