MININV - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: jeevanjyot
Testers: nishant403, satyam_343
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Frequency tables

PROBLEM:

You have an array A. Exactly once, you can choose a subarray [L, R] and increase all its elements by 1.

Suppose the final array is B. Find the maximum value of inv(A) - inv(B).

EXPLANATION:

Let’s analyze what choosing subarray [L, R] does for the number of inversions, and whether this tells us anything about what choices of L and/or R can possibly be optimal.

Analysis

Let’s denote the prefix [1, L-1] by P, the subarray [L, R] by M, and the suffix [R+1, N] by S.

Now, consider some pair (i, j). Let’s do a bit of casework.

  • If i\not\in M and j\not\in M, A_i and A_j both don’t change, so this pair’s contribution to the number of inversions doesn’t change.
  • If i\in M and j\in M, both values increase by 1 so once again, its contribution to the number of inversions doesn’t change.
  • If i\in M and j\in S, then only A_i increases by 1.
    • If (i, j) was already an inversion (i.e., initially A_i \gt A_j), it continues to remain one.
    • If A_i = A_j initially, this pair creates a new inversion.
    • If A_i \lt A_j initially, this pair continues to not be an inversion.
  • If i\in P and j \in M, then only A_j increases by 1.
    • If A_i \leq A_j, this pair continues to not be an inversion.
    • If A_i \gt A_j+1, this pair continues to be an inversion.
    • If A_i = A_j+1, this pair stops being an inversion.
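The casework above can be sanity-checked by brute force. The following is only a small sketch (the function names are made up for illustration, and indices are 0-based): it computes inv(A) - inv(B) directly, and compares it with the value predicted by counting P-M pairs with A_i = A_j+1 minus M-S pairs with A_i = A_j.

```python
def count_inversions(arr):
    """Count pairs (i, j) with i < j and arr[i] > arr[j] by brute force."""
    n = len(arr)
    return sum(1 for i in range(n) for j in range(i + 1, n) if arr[i] > arr[j])


def reduction(a, l, r):
    """inv(A) - inv(B) when a[l..r] (0-indexed, inclusive) is increased by 1."""
    b = a[:l] + [x + 1 for x in a[l:r + 1]] + a[r + 1:]
    return count_inversions(a) - count_inversions(b)


def reduction_by_casework(a, l, r):
    """The same quantity, predicted from the casework alone."""
    n = len(a)
    # Pairs (i in P, j in M) with A_i = A_j + 1 stop being inversions.
    removed = sum(1 for i in range(l) for j in range(l, r + 1) if a[i] == a[j] + 1)
    # Pairs (i in M, j in S) with A_i = A_j become new inversions.
    created = sum(1 for i in range(l, r + 1) for j in range(r + 1, n) if a[i] == a[j])
    return removed - created


print(reduction([2, 1, 1], 1, 2))              # 2
print(reduction_by_casework([2, 1, 1], 1, 2))  # 2
```

Here [2, 1, 1] becomes [2, 2, 2], so both inversions disappear, which matches the two P-M pairs with A_i = A_j+1.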

From this, we see that the only way to reduce inversions is between P and M; while interactions between M and S are bad because they can increase the number of inversions.
In particular, it’d be nice if P and M were as large as possible, while S was as small as possible.

This is easy to achieve: simply choose R = N always, so the suffix S will be empty!
However, we can’t yet say anything about L.
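If you'd like to convince yourself that fixing R = N loses nothing, one option is an exhaustive check on tiny inputs. This is just an illustrative sketch (the helper names are hypothetical, and arrays are restricted to values in [1, N] as in the checker code below): it compares the best reduction over all subarrays with the best reduction over suffixes only.

```python
from itertools import product


def count_inversions(arr):
    """Count pairs (i, j) with i < j and arr[i] > arr[j] by brute force."""
    n = len(arr)
    return sum(1 for i in range(n) for j in range(i + 1, n) if arr[i] > arr[j])


def reduction(a, l, r):
    """inv(A) - inv(B) when a[l..r] (0-indexed, inclusive) is increased by 1."""
    b = a[:l] + [x + 1 for x in a[l:r + 1]] + a[r + 1:]
    return count_inversions(a) - count_inversions(b)


def suffix_is_always_enough(max_n):
    """Check every array of length <= max_n with values in [1, n]."""
    for n in range(1, max_n + 1):
        for values in product(range(1, n + 1), repeat=n):
            a = list(values)
            best_any = max(reduction(a, l, r) for l in range(n) for r in range(l, n))
            best_suffix = max(reduction(a, l, n - 1) for l in range(n))
            if best_any != best_suffix:
                return False
    return True


print(suffix_is_always_enough(4))  # True
```

This is of course not a proof; the argument above (extending R only adds P-M pairs and removes M-S pairs) is what actually justifies the choice.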

Now that we’ve fixed R to always be N, we need to find which value of L is optimal.
Checking each one in \mathcal{O}(N) (or worse) leads to \mathcal{O}(N^2) overall, which is too slow.

Instead, let’s be a bit smarter.
Suppose we (somehow) knew the answer for [L, N] (that is, we know how many inversions it reduces).
Can we then compute the answer for [L+1, N]?

Yes we can!

Recall from our previous analysis that the only reductions in inversions come from pairs (i, j) such that A_i = A_j+1.

When moving from L to L+1, we’re essentially moving the element A_L from the subarray M to the subarray P. So,

  • if P contains x occurrences of A_L+1, these x positions were originally reduced inversions with position L, but they are no longer reduced. So, decrease the current answer by x.
  • On the other hand, if there are y occurrences of A_L-1 in M, these y positions now are reduced inversions with position L, so increase the current answer by y.

We need to be able to quickly compute x and y. Note that they’re both frequencies.

So, maintain two frequency tables: one corresponding to P and one corresponding to M.
When moving from L to L+1, update the frequencies appropriately: this takes one operation in each table, after which both x and y can be obtained in \mathcal{O}(1) by just looking at the appropriate frequency table.

This allows us to move from [L, N] to [L+1, N] in \mathcal{O}(1) time; updating the answer along the way.
So, start from L = 1 (for which computing the answer is trivial: P is empty, so nothing is reduced and the answer is 0), and then increase L till N; each time computing the answer for that suffix.

The final answer is the maximum among everything computed.
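The left-to-right sweep described above can be sketched as follows. This is a minimal illustration rather than any of the reference solutions: the function name and the table names `in_p` and `in_m` are made up, and it assumes 1 \leq A_i \leq N (as the checker code below enforces) so that plain lists can serve as frequency tables.

```python
def max_inversion_reduction(a):
    """Return the maximum of inv(A) - inv(B) over all suffixes [L, N].

    Assumes every value satisfies 1 <= a[i] <= len(a).
    """
    n = len(a)
    in_p = [0] * (n + 2)  # frequency table of the prefix P = [1, L-1]
    in_m = [0] * (n + 2)  # frequency table of the chosen part M = [L, N]
    for x in a:
        in_m[x] += 1      # L = 1: everything is in M, P is empty

    cur = 0               # reduction for the current L (0 when L = 1)
    best = 0
    # Move A_L from M to P for L = 1..N-1, so that M never becomes empty.
    for v in a[:-1]:
        cur -= in_p[v + 1]  # x: pairs (i in P, L) with A_i = A_L + 1 are lost
        in_m[v] -= 1
        in_p[v] += 1
        cur += in_m[v - 1]  # y: pairs (L, j in M) with A_L = A_j + 1 are gained
        best = max(best, cur)
    return best


print(max_inversion_reduction([2, 1, 1]))  # 2
print(max_inversion_reduction([3, 2, 1]))  # 1
print(max_inversion_reduction([1, 2, 3]))  # 0
```

Each step touches a constant number of table entries, so the whole sweep is linear in N. Iterating from right to left instead (as the solutions below do) works equally well; only the signs of the two updates swap.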

TIME COMPLEXITY:

\mathcal{O}(N) per testcase.

CODE:

Setter's code (C++)
#ifdef WTSH
    #include <wtsh.h>
#else
    #include <bits/stdc++.h>
    using namespace std;
    #define dbg(...)
#endif

#define int long long
#define endl "\n"
#define sz(w) (int)(w.size())
using pii = pair<int, int>;

// -------------------- Input Checker Start --------------------

long long readInt(long long l, long long r, char endd)
{
    long long x = 0;
    int cnt = 0, fi = -1;
    bool is_neg = false;
    while(true)
    {
        char g = getchar();
        if(g == '-')
        {
            assert(fi == -1);
            is_neg = true;
            continue;
        }
        if('0' <= g && g <= '9')
        {
            x *= 10;
            x += g - '0';
            if(cnt == 0)
                fi = g - '0';
            cnt++;
            assert(fi != 0 || cnt == 1);
            assert(fi != 0 || is_neg == false);
            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
        }
        else if(g == endd)
        {
            if(is_neg)
                x = -x;
            if(!(l <= x && x <= r))
            {
                cerr << "L: " << l << ", R: " << r << ", Value Found: " << x << '\n';
                assert(false);
            }
            return x;
        }
        else
        {
            assert(false);
        }
    }
}

string readString(int l, int r, char endd)
{
    string ret = "";
    int cnt = 0;
    while(true)
    {
        char g = getchar();
        assert(g != -1);
        if(g == endd)
            break;
        cnt++;
        ret += g;
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}

long long readIntSp(long long l, long long r) { return readInt(l, r, ' '); }
long long readIntLn(long long l, long long r) { return readInt(l, r, '\n'); }
string readStringSp(int l, int r) { return readString(l, r, ' '); }
string readStringLn(int l, int r) { return readString(l, r, '\n'); }
void readEOF() { assert(getchar() == EOF); }

vector<int> readVectorInt(int n, long long l, long long r)
{
    vector<int> a(n);
    for(int i = 0; i < n - 1; i++)
        a[i] = readIntSp(l, r);
    a[n - 1] = readIntLn(l, r);
    return a;
}

// -------------------- Input Checker End --------------------

int sumN = 0;

void solve()
{
    int n = readIntLn(1, 1e5);
    sumN += n; // accumulate N so that the sum-of-N assertion in main() is meaningful
    vector<int> a = readVectorInt(n, 1, n);
    vector<int> pfreq(n + 2), sfreq(n + 2);
    for(int i = 0; i < n; i++)
        pfreq[a[i]]++;
    int ans = 0, cur = 0;
    for(int i = n - 1; i >= 0; i--)
    {
        // changing a[i] to a[i] + 1
        cur -= sfreq[a[i] - 1];
        cur += pfreq[a[i] + 1];
        ans = max(ans, cur);
        sfreq[a[i]]++;
        pfreq[a[i]]--;
    }
    cout << ans << endl;
}

int32_t main()
{
    ios::sync_with_stdio(0); 
    cin.tie(0);
    int T = readIntLn(1, 1e5);
    for(int tc = 1; tc <= T; tc++)
    {
        // cout << "Case #" << tc << ": ";
        solve();
    }
    assert(sumN <= 2e5);
    readEOF();
    return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
 
 
/*
------------------------Input Checker----------------------------------
*/
 
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
 
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }
 
            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }
 
            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
 
 
/*
------------------------Main code starts here----------------------------------
*/

#define int long long 
 
const int MAX_T = 1e5;
const int MAX_N = 1e5;
const int MAX_SUM_N = 2e5;

#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)

int sum_n = 0;
int max_n = 0;
int sum_ans = 0;

void solve()
{
    int n;
    n = readIntLn(1, MAX_N);
    
    sum_n += n;
    assert(sum_n <= MAX_SUM_N);
    max_n = max(max_n,n);
    
    int a[n];
    for(int i=0;i<n;i++) {
        if(i != n - 1) {
            a[i] = readIntSp(1 , n);
        } else {
            a[i] = readIntLn(1 , n);
        }
    }
    
   vector<int> before(n+2,0),after(n+2,0);
    
    for(int i=0;i<n;i++) {
        before[a[i]]++;
    }
    
   int result = 0;
   int cur_change = 0;
   int best_ind = -1;
    
    for(int i=n-1;i>=0;i--) {
        cur_change -= after[a[i]];
        
        after[a[i] + 1]++;
        before[a[i]]--;
    
        cur_change += before[a[i] + 1];
    
        result = max(result , cur_change);
        
        if(result == cur_change) {
            best_ind = i;
        }
    }
    
    sum_ans += result;
    
    cerr << "N : " << n << " best ind : " << best_ind << '\n';
    
    cout << result << '\n';
}
 
signed main()
{
    int t = 1;
    t = readIntLn(1,MAX_T);
    
    for(int i=1;i<=t;i++)
    {     
       solve();
    }
    
    assert(getchar() == -1);
 
    cerr<<"SUCCESS\n";
    cerr<<"Tests : " << t << '\n';
    cerr<<"Maximum N : " << max_n << '\n';
    cerr<<"Sum of N : " << sum_n << '\n';
    cerr<<"Sum of answer : " << sum_ans << '\n';
}
Editorialist's code (Python)
for _ in range(int(input())):
	n = int(input())
	a = list(map(int, input().split()))
	pref_freq = [0]*(n+2)
	suf_freq = [0]*(n+2)
	for x in a: pref_freq[x] += 1
	ans = 0
	cur = 0
	for i in reversed(range(1, n)):
		cur -= pref_freq[a[i]] * suf_freq[a[i]-1] + pref_freq[a[i]+1] * suf_freq[a[i]]
		suf_freq[a[i]] += 1
		pref_freq[a[i]] -= 1
		cur += pref_freq[a[i]] * suf_freq[a[i]-1] + pref_freq[a[i]+1] * suf_freq[a[i]]
		ans = max(ans, cur)
	print(ans)

Why does the for loop go from n-1 to 1, rather than 1 to n?
We are first finding for [L,R] and then [L+1,R], so why are we not traversing from left to right?

The editorialist explains their own code. Problem setter’s code can differ arbitrarily from what’s in the editorial.

It gives WA on the 6th test!

If you think about it a bit, is there really any difference between iterating from 1 to N or N to 1?
If you’ve understood what’s written in the editorial, then the idea is exactly the same: only the contribution of one element changes, and keeping prefix and suffix frequency arrays allows us to process this change in constant time.

Literally the only thing that changes is the formula used to update the answer, and that’s a minor change at best because the idea is exactly the same; I’m sure you can derive it yourself quite quickly.

Mostly true, but there are some cases where I don’t do this; usually if I solve the problem one way and find a simpler (to explain) solution afterwards.
Sometimes I’m just too lazy/don’t have time to rewrite my code to match the editorial’s explanation, so if it’s similar enough (like in this problem) I leave it in. That seems like a reasonable compromise to me.