MAXMINLEN - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: raysh07
Tester: apoorv_me
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Sorting, prefix sums

PROBLEM:

Given an array A containing distinct elements, count the number of its good subsequences, i.e., subsequences B for which \max(B) - \min(B) = |B|.
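Translating the definition directly gives a brute-force counter. The sketch below enumerates every non-empty subsequence by index, so it is exponential in N. It is only useful for checking faster solutions on tiny inputs. The name `count_good_brute` is ours, not from the problem.

```python
from itertools import combinations


def count_good_brute(a):
    """Count subsequences B of a with max(B) - min(B) == len(B).

    Tries every non-empty set of indices, so it is exponential in len(a).
    """
    total = 0
    n = len(a)
    for k in range(1, n + 1):
        for idx in combinations(range(n), k):
            b = [a[i] for i in idx]
            if max(b) - min(b) == len(b):
                total += 1
    return total


print(count_good_brute([3, 1, 6, 2, 5]))  # 7
```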

EXPLANATION:

Since A contains distinct elements, so will any subsequence of A.
Further, whether a subsequence is good or not depends only on its maximum, minimum, and length - so the order of elements within the subsequence is irrelevant.
This means the answer doesn’t change if A is sorted, so we do that first. We now only need to reason about sorted arrays.

Now, consider a sorted array B = [B_1, B_2, \ldots, B_K]. Let’s ascertain what it means for it to be good.
First off, the maximum and minimum are B_K and B_1 respectively, while the length is K.
So, we want B_K - B_1 = K.

Now, let’s use the fact that the elements are distinct, meaning B_i \gt B_{i-1} for every 2 \leq i \leq K.
Since the elements are integers, this means B_i \geq B_{i-1} + 1, or B_i - B_{i-1} \geq 1.
Now,

B_K - B_1 = (B_K - B_{K-1}) + (B_{K-1} - B_{K-2}) + \ldots + (B_3 - B_2) + (B_2 - B_1)

Each individual term on the right side is \geq 1, and there are K-1 of them.
So, we obtain B_K - B_1 \geq K-1.

We want it to equal K instead. The K-1 gaps are integers, each at least 1, so their sum must exceed its minimum possible value K-1 by exactly 1.
This happens exactly when B_i - B_{i-1} = 2 for one index i, while all the other gaps equal 1.


We now have a nice criterion for when a sorted array of distinct elements is good: all the elements should be consecutive, except exactly one adjacent pair which should differ by 2.
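This criterion is easy to sanity-check numerically. The sketch below compares the definition \max(B) - \min(B) = |B| against the gap rule on random sets of distinct small integers. It is a spot check, not a proof. The function names are ours.

```python
import random


def good_by_definition(b):
    return max(b) - min(b) == len(b)


def good_by_gaps(b):
    """Gap rule: after sorting, every adjacent gap is 1 except exactly one gap of 2.

    Assumes the elements of b are distinct.
    """
    s = sorted(b)
    gaps = [y - x for x, y in zip(s, s[1:])]
    return gaps.count(2) == 1 and all(g in (1, 2) for g in gaps)


# Compare both checks on random sets of distinct integers.
for _ in range(2000):
    b = random.sample(range(1, 15), random.randint(1, 6))
    assert good_by_definition(b) == good_by_gaps(b)
print("criterion agrees on all samples")
```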

With this in hand, counting valid subsequences becomes fairly simple.

  • Let’s fix i, and say that A_i is the element of the subsequence that is followed by A_i + 2 (the single gap of 2).
  • If A_i + 2 doesn’t exist in A, no valid subsequence can have its gap at A_i.
  • Otherwise, we can choose any segment of consecutive values in A ending at A_i, and any segment of consecutive values in A starting at A_i + 2. Joining the two always gives a good subsequence.
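Here is a worked illustration using our own example, A = [1, 2, 3, 5, 6]. The sketch lists every good subsequence whose gap of 2 sits between a fixed value v and v + 2, by pairing every left run with every right run. The helper `subsequences_with_gap_at` is hypothetical. It builds the lists explicitly, so it only illustrates the pairing and is not meant to be efficient.

```python
def subsequences_with_gap_at(a_sorted, v):
    """List the good subsequences whose single gap of 2 is between v and v + 2.

    a_sorted: sorted list of distinct integers.
    Left part: a run of consecutive values ending at v.
    Right part: a run of consecutive values starting at v + 2.
    """
    present = set(a_sorted)
    if v not in present or v + 2 not in present:
        return []
    # Runs of consecutive values ending at v: [v], [v-1, v], ...
    lefts = []
    lo = v
    while lo in present:
        lefts.append(list(range(lo, v + 1)))
        lo -= 1
    # Runs of consecutive values starting at v + 2: [v+2], [v+2, v+3], ...
    rights = []
    hi = v + 2
    while hi in present:
        rights.append(list(range(v + 2, hi + 1)))
        hi += 1
    return [left + right for left in lefts for right in rights]


for b in subsequences_with_gap_at([1, 2, 3, 5, 6], 3):
    print(b)
```

For v = 3 there are 3 left runs ([3], [2, 3], [1, 2, 3]) and 2 right runs ([5], [5, 6]), giving 6 subsequences. That product form is what P_i and S_j capture.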

To perform the last calculation, let’s define P_i to be the length of the longest segment of consecutive values in A ending at A_i.
Similarly, let S_i be the length of the longest segment of consecutive values in A starting at A_i.

It’s easy to see that if A_{i-1} + 1 = A_i we have P_i = P_{i-1} + 1, otherwise P_i = 1 (in particular P_1 = 1).
Symmetrically, if A_i + 1 = A_{i+1} we have S_i = S_{i+1} + 1, otherwise S_i = 1 (in particular S_N = 1).
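The two recurrences become one forward pass and one backward pass over the sorted array. The sketch below uses 0-based indices, and the name `run_lengths` is ours.

```python
def run_lengths(a_sorted):
    """Return (P, S) for a sorted list of distinct integers.

    P[i]: length of the longest run of consecutive values ending at a_sorted[i].
    S[i]: length of the longest run of consecutive values starting at a_sorted[i].
    """
    n = len(a_sorted)
    P = [1] * n
    S = [1] * n
    for i in range(1, n):  # forward pass: extend the run ending at i - 1
        if a_sorted[i - 1] + 1 == a_sorted[i]:
            P[i] = P[i - 1] + 1
    for i in range(n - 2, -1, -1):  # backward pass: extend the run starting at i + 1
        if a_sorted[i] + 1 == a_sorted[i + 1]:
            S[i] = S[i + 1] + 1
    return P, S


print(run_lengths([1, 2, 3, 5, 6]))  # ([1, 2, 3, 1, 2], [3, 2, 1, 2, 1])
```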

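Now, suppose A_i is the element we fixed to have a difference of 2 with its neighbor, and j is the index of A_i + 2 in A. The left segment can have any length from 1 to P_i, and the right segment can have any length from 1 to S_j. So the number of subsequences we can choose is P_i \times S_j.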

Add this up across all i to obtain the final answer.
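The sketch below puts the pieces together into a full count. It finds the index j of A_i + 2 by binary search (`bisect_left` from Python's standard library), which is one way to stay within O(N log N). The editorialist's code below uses value-keyed dictionaries instead. The function name `count_good` is ours, and the input elements are assumed to be distinct.

```python
from bisect import bisect_left


def count_good(a):
    """Sum of P[i] * S[j] over all i such that a[j] == a[i] + 2 exists.

    Assumes the elements of a are distinct.
    """
    a = sorted(a)
    n = len(a)
    P = [1] * n
    S = [1] * n
    for i in range(1, n):
        if a[i - 1] + 1 == a[i]:
            P[i] = P[i - 1] + 1
    for i in range(n - 2, -1, -1):
        if a[i] + 1 == a[i + 1]:
            S[i] = S[i + 1] + 1
    ans = 0
    for i in range(n):
        j = bisect_left(a, a[i] + 2)  # position where a[i] + 2 would be
        if j < n and a[j] == a[i] + 2:
            ans += P[i] * S[j]
    return ans


print(count_good([3, 1, 6, 2, 5]))  # 7
```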

TIME COMPLEXITY:

\mathcal{O}(N\log N) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

void Solve() 
{
    int n; cin >> n;

    vector <int> a(n);
    for (auto &x : a) cin >> x;

    sort(a.begin(), a.end());
    
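    // b[i] = a[i] - i is non-decreasing because a is sorted and distinct;
    // a block of equal b values is a run of consecutive integers.
    vector <int> b(n);
    for (int i = 0; i < n; i++){
        b[i] = a[i] - i;
    }

    // f[x] = number of indices i with b[i] == x
    map <int, int> f;
    for (auto x : b){
        f[x]++;
    }

    int ans = 0;

    // Pairs i < j with b[j] == b[i] + 1: a[j] - a[i] == j - i + 1,
    // so taking every element from i to j gives exactly one good subsequence.
    for (auto x : b){
        ans += f[x + 1];
    }

    // Inside a block of y equal b values: fix the dropped element at position i
    // (1-indexed), then pick a left endpoint before it and a right endpoint after it.
    for (auto [x, y] : f){
        for (int i = 1; i <= y; i++){
            ans += (i - 1) * (y - i);
        }
    }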

    cout << ans << "\n";
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}
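For readers who find the C++ hard to follow, here is a Python sketch of the same counting as we read it. The name `count_good_by_offsets` is ours. We made one deliberate change: the inner loop summing (i - 1) \times (y - i) for i = 1..y is replaced by the closed form \binom{y}{3}. Each term counts the triples (left endpoint, dropped element, right endpoint) whose middle position is i.

```python
from collections import Counter


def count_good_by_offsets(a):
    """Count good subsequences via b[i] = a[i] - i on the sorted array.

    Assumes the elements of a are distinct.
    """
    a = sorted(a)
    b = [x - i for i, x in enumerate(a)]  # non-decreasing
    f = Counter(b)
    ans = 0
    # b[j] == b[i] + 1: a[j] - a[i] == j - i + 1, so the whole range i..j
    # (which contains exactly one gap of 2) is the only good choice.
    for x in b:
        ans += f[x + 1]  # Counter returns 0 for missing keys
    # b[j] == b[i] inside a block of size y: drop exactly one interior element.
    # Sum over positions of (i - 1) * (y - i) equals C(y, 3).
    for y in f.values():
        ans += y * (y - 1) * (y - 2) // 6
    return ans


print(count_good_by_offsets([3, 1, 6, 2, 5]))  # 7
```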
Tester's code (C++)
#include<bits/stdc++.h>
using namespace std;

#ifdef LOCAL
#include "../debug.h"
#else
#define dbg(...)
#endif

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

int32_t main() {
  ios_base::sync_with_stdio(0);
  cin.tie(0);

  input_checker inp;
  int T = inp.readInt(1, (int)1e4), NN = 0; inp.readEoln();
  while(T-- > 0) {
    int N = inp.readInt(1, (int)2e5); inp.readEoln();
    NN += N;
    vector<int> A = inp.readInts(N, 1, (int)1e9); inp.readEoln();
    vector<int> ord(N);
    iota(ord.begin(), ord.end(), 0);
    sort(ord.begin(), ord.end(), [&](int i, int j) {
      return A[i] < A[j];
    });

    int p1 = 0, p2 = 0;
    int64_t res = 0;
    auto get = [&](int64_t x) {
        return x * (x - 1) / 2;
    };
    for(int i = 0 ; i < N ; ++i) {
      while(A[ord[i]] - i > A[ord[p1]] - p1)
        ++p1;
      while(A[ord[i]] - i - 1 > A[ord[p2]] - p2)
        ++p2;
      res += p1 - p2 + get(i - p1);
    }
    cout << res << '\n';
  }
  
  return 0;
}
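The tester's main loop is compact, so below is a Python transliteration of it under our reading. The name `count_good_two_pointers` is ours. Let b_k = a_k - k on the sorted array. Then p1 is the first index whose b value is at least b_i, and p2 is the first index whose b value is at least b_i - 1. Both pointers only move forward, so the loop is linear after sorting.

```python
def count_good_two_pointers(a):
    """Two-pointer count over b[k] = a[k] - k on the sorted array.

    p1 - p2 counts earlier indices with b == b[i] - 1 (take the whole range);
    C(i - p1, 2) counts (left endpoint, dropped element) pairs inside b[i]'s block.
    Assumes the elements of a are distinct.
    """
    a = sorted(a)
    b = [x - k for k, x in enumerate(a)]
    p1 = p2 = 0
    res = 0
    for i in range(len(a)):
        while b[i] > b[p1]:
            p1 += 1
        while b[i] - 1 > b[p2]:
            p2 += 1
        m = i - p1
        res += (p1 - p2) + m * (m - 1) // 2
    return res


print(count_good_two_pointers([3, 1, 6, 2, 5]))  # 7
```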

Editorialist's code (Python)
from collections import defaultdict
for _ in range(int(input())):
    n = int(input())
    a = sorted(map(int, input().split()))
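    # pref[x]: length of the longest run of consecutive values in a ending at x
    # suf[x]: length of the longest run of consecutive values in a starting at x
    pref, suf = defaultdict(int), defaultdict(int)
    for x in a: pref[x] = pref[x-1] + 1
    for x in reversed(a): suf[x] = suf[x+1] + 1
    
    ans = 0
    # suf[x+2] is 0 when x+2 is not in a, so such x contribute nothing
    for x in a: ans += pref[x] * suf[x+2]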
    print(ans)

I thought of maintaining a set to store all elements and then iterating through all starting points: for each starting point, remove it from the set and then check whether (ele+2 … n) or (ele-2…0) exist in the set.

My Submission
Where did I go wrong?

What could be the solution if elements in A can repeat? Is there a solution faster than O(N^2)?

Ah, very good. I see the trick now.


This really is a good one!

I don’t think there is any solution faster than O(n^2) for repeated elements. This solution itself relies heavily on a fluke observation. Pretty bad problem, in my opinion.


I take it back; this problem is beautiful.

In the author’s code, why is this line written?
for (int i = 0; i < n; i++){ b[i] = a[i] - i; }

Scanner sc=new Scanner(System.in);
int t=sc.nextInt();
while(t-->0){
    int n=sc.nextInt();
    long a[]=new long[n];
    for(int i=0;i<n;i++){
        a[i]=sc.nextLong();
    }
    Arrays.sort(a);
    long ans=0;
    long max=Integer.MIN_VALUE;
    long min=Integer.MAX_VALUE;
    for(int i=n-1;i>0;i--){
        max=a[i];
        for(int j=i-1;j>=0;j--){
            min=a[j];
            long diff=max-min;
            long inBwt=i-j-1;
            long m=diff-2;
            long nn=inBwt;
            if(m==nn){
                ans++;
            }else if(nn>m){
                // ways to select m element from nn element
                long way=1;
                m=Math.min(m,(m-nn));
                for(int r=0;r<m;r++){
                    way*=(nn-r);
                    way/=(r+1);
                }
                ans+=way;
            }
        }
    }
    System.out.println(ans);
}

What is wrong with this? I’m not able to figure it out.

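Let’s boil down the intuition. First, sort the elements (the answer is independent of their order).

After sorting:
Choose an index i and iterate over the indices j (> i), checking whether some subset of this subarray is good:
a_i (min), a_{i+1}, \ldots, a_j (max)
In this subarray we choose a_i, a_j, and some other elements. Whichever other elements we choose, the max and min don’t change.
So diff = a_j - a_i, and if we choose a subset of length len (\geq 2), we need len = a_j - a_i, where len \leq j-i+1.
Then a_j - a_i \leq j-i+1.
=> A good subset with minimum a_i and maximum a_j exists iff 2 \leq a_j - a_i \leq j-i+1 (the lower bound is needed because the subset contains both a_i and a_j).
Tip: when tackling any problem, try to find invariants.
=> a_j - j \leq a_i - i + 1
=> b_j \leq b_i + 1, where b_k = a_k - k
Now you may see why the author chose the array b. After sorting, b is non-decreasing, so only b_j = b_i and b_j = b_i + 1 remain, and these are exactly the two cases the author’s code counts.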
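The pair criterion can be checked numerically. This sketch compares a brute force against the condition 2 \leq a_j - a_i and b_j \leq b_i + 1 (with b_k = a_k - k) over all pairs i < j of small random sorted arrays. The function names are ours, and it is a spot check rather than a proof.

```python
import random
from itertools import combinations


def pair_has_good_subset(a, i, j):
    """Brute force: is some subset containing a[i] (min) and a[j] (max) good?

    a is sorted with distinct elements and i < j.
    """
    inner = a[i + 1:j]
    for k in range(len(inner) + 1):
        for mid in combinations(inner, k):
            sub = [a[i], *mid, a[j]]
            if max(sub) - min(sub) == len(sub):
                return True
    return False


def pair_criterion(a, i, j):
    """Condition a[j] - a[i] >= 2 and b[j] <= b[i] + 1, where b[k] = a[k] - k."""
    b_i, b_j = a[i] - i, a[j] - j
    return a[j] - a[i] >= 2 and b_j <= b_i + 1


for _ in range(300):
    a = sorted(random.sample(range(1, 20), random.randint(2, 8)))
    for i in range(len(a)):
        for j in range(i + 1, len(a)):
            assert pair_has_good_subset(a, i, j) == pair_criterion(a, i, j)
print("criterion matches brute force")
```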

Can anybody tell me where I am going wrong?
https://www.codechef.com/viewsolution/1073407746