PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: raysh07
Tester: apoorv_me
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
Sorting, prefix sums
PROBLEM:
Given an array A containing distinct elements, count the number of its good subsequences, i.e., subsequences B for which \max(B) - \min(B) = |B|.
EXPLANATION:
Since A contains distinct elements, so will any subsequence of A.
Further, whether a subsequence is good or not depends only on its maximum, minimum, and length - so the order of elements within the subsequence is irrelevant.
This means the answer doesn’t change if A is sorted, so we do that first. We now only need to reason about sorted arrays.
Now, consider a sorted array B = [B_1, B_2, \ldots, B_K]. Let’s ascertain what it means for it to be good.
First off, the maximum and minimum are B_K and B_1 respectively, while the length is K.
So, we want B_K - B_1 = K.
Now, let’s use the fact that the elements are distinct - meaning B_i \gt B_{i-1} for every i.
That is, B_i \geq B_{i-1} + 1, or B_i - B_{i-1} \geq 1.
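Now, write B_K - B_1 as a telescoping sum: B_K - B_1 = (B_K - B_{K-1}) + (B_{K-1} - B_{K-2}) + \ldots + (B_2 - B_1).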
Each individual term on the right side is \geq 1, and there are K-1 of them.
So, we obtain B_K - B_1 \geq K-1.
Of course, we want it to be equal to K.
Since the K-1 terms on the right side are each at least 1 and must sum to exactly K, exactly one of them equals 2 (i.e. B_i - B_{i-1} = 2 for exactly one i), while all the other terms equal 1.
We now have a nice criterion for when a sorted array of distinct elements is good: all the elements should be consecutive, except exactly one adjacent pair which should differ by 2.
With this in hand, counting valid subsequences becomes fairly simple.
- Let’s fix i, and say that A_i is the element that differs by 2 from its next element.
- If A_i + 2 doesn’t exist in A, then no good subsequence can have its gap of 2 immediately after A_i.
- Otherwise, we can combine any run of consecutive integers present in A that ends at A_i with any run of consecutive integers present in A that starts at A_i + 2.
To perform the last calculation, let’s define P_i to be the longest possible segment of contiguous values ending at A_i.
Similarly, let S_i be the longest possible segment of contiguous values starting from A_i.
It’s easy to see that if A_{i-1} + 1 = A_i we have P_i = P_{i-1} + 1, otherwise P_i = 1.
S_i can be similarly computed from S_{i+1}.
Now, if A_i is the element we fixed to have a difference of 2 with its neighbor, and j is the index of A_i + 2 in A, the number of subsequences we can choose is simply P_i \times S_j.
Add this up across all i to obtain the final answer.
TIME COMPLEXITY:
\mathcal{O}(N\log N) per testcase.
CODE:
Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
// Template macro: every later `int` token is compiled as `long long`, so counters
// declared as `int` in this solution are 64-bit. `int32_t` (used for main) is a
// different token and is unaffected.
#define int long long
// Large sentinel from the author's template; not referenced in the solution shown.
#define INF (int)1e18
// 64-bit Mersenne Twister seeded from the steady clock; not referenced in the
// solution shown.
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());
// Reads one test case (N, then the N distinct values of A) from stdin and prints the
// number of good subsequences, i.e. subsequences B with max(B) - min(B) = |B|.
//
// After sorting, let key[i] = a[i] - i. Because the values are distinct, key is
// non-decreasing, and a[l..r] is a run of consecutive integers iff key[l] == key[r].
// A good subsequence is a run of consecutive values with exactly one gap of 2, so it
// is either
//   * a contiguous block a[l..r] with key[r] == key[l] + 1 (the skipped value is
//     absent from A), or
//   * a block a[l..r] of consecutive integers with one interior element a[m] removed
//     (l < m < r, all three indices sharing the same key).
void Solve()
{
    long long n;
    std::cin >> n;
    std::vector<long long> a(n);
    for (auto &x : a) std::cin >> x;
    std::sort(a.begin(), a.end());

    // Multiplicity of each key a[i] - i.
    std::map<long long, long long> freq;
    for (long long i = 0; i < n; i++) freq[a[i] - i]++;

    long long ans = 0;
    for (const auto &[key, cnt] : freq) {
        // Case 1: every index with key k pairs with every index with key k + 1.
        // find() rather than operator[] so absent keys are not inserted into the map.
        const auto succ = freq.find(key + 1);
        if (succ != freq.end()) ans += cnt * succ->second;
        // Case 2: pick l < m < r inside one block of equal keys -> C(cnt, 3) choices.
        // cnt <= 2e5, so the product stays well inside 64 bits.
        ans += cnt * (cnt - 1) * (cnt - 2) / 6;
    }
    std::cout << ans << "\n";
}
// Entry point: reads the number of test cases and calls Solve() once per case, then
// reports the total wall-clock time on stderr so stdout carries only the answers.
int32_t main()
{
    const auto start = std::chrono::high_resolution_clock::now();

    // Unsynchronized iostreams, and cin untied so reads do not flush cout.
    ios_base::sync_with_stdio(0);
    cin.tie(0);

    int testCases = 1;
    cin >> testCases;
    for (int tc = 0; tc < testCases; ++tc)
        Solve();

    const auto stop = std::chrono::high_resolution_clock::now();
    const auto elapsedNs = std::chrono::duration_cast<std::chrono::nanoseconds>(stop - start);
    cerr << "Time measured: " << elapsedNs.count() * 1e-9 << " seconds.\n";
    return 0;
}
Tester's code (C++)
#include<bits/stdc++.h>
using namespace std;
#ifdef LOCAL
#include "../debug.h"
#else
#define dbg(...)
#endif
// Strict validator for the test input: reads all of stdin once, then lets the caller
// consume it token by token, asserting on every format or range violation.
struct input_checker {
    // Entire standard input, read once by the constructor.
    std::string buffer;
    // Index in `buffer` of the next character not yet consumed.
    int pos;
    // Character sets usable as the `pattern` argument of readString().
    const std::string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const std::string number = "0123456789";
    const std::string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const std::string lower = "abcdefghijklmnopqrstuvwxyz";

    // Slurps the whole of std::cin into `buffer`.
    input_checker()
        : buffer(std::istreambuf_iterator<char>(std::cin), std::istreambuf_iterator<char>()),
          pos(0) {}

    // Index of the first ' ' or '\n' at or after `pos` (or buffer.size()).
    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }

    // Consumes and returns the characters up to (not including) the next delimiter.
    std::string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        std::string res = buffer.substr(pos, nxt - pos);
        pos = nxt;
        return res;
    }

    // Asserts that `token` is a plain decimal integer: an optional '-' followed by one
    // or more digits and nothing else. std::stoi/std::stoll alone parse only a numeric
    // prefix, so they would silently accept "12x", "+5", "7\r" or a leading tab.
    static void checkIntegerToken(const std::string &token) {
        std::size_t i = (!token.empty() && token[0] == '-') ? 1 : 0;
        assert(i < token.size());
        for (; i < token.size(); ++i) {
            assert('0' <= token[i] && token[i] <= '9');
        }
    }

    // Reads a token of length in [minl, maxl] whose characters all occur in `pattern`
    // (an empty pattern allows any character).
    std::string readString(int minl, int maxl, const std::string &pattern = "") {
        assert(minl <= maxl);
        std::string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != std::string::npos);
        }
        return res;
    }

    // Reads an integer token and asserts minv <= value <= maxv.
    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        const std::string token = readOne();
        checkIntegerToken(token);
        int res = std::stoi(token);
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    // 64-bit counterpart of readInt().
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        const std::string token = readOne();
        checkIntegerToken(token);
        long long res = std::stoll(token);
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    // Reads n space-separated ints in [minv, maxv]; no trailing separator is consumed.
    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        std::vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    // Reads n space-separated long longs in [minv, maxv].
    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        std::vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    // Consumes exactly one ' '.
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    // Consumes exactly one '\n'.
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    // Asserts the whole input has been consumed.
    void readEof() {
        assert((int) buffer.size() == pos);
    }
};
// Validates the whole input file (format, ranges, distinctness) and prints, for each
// test case, the number of subsequences B with max(B) - min(B) = |B|.
int32_t main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    input_checker inp;
    int T = inp.readInt(1, (int)1e4), NN = 0; inp.readEoln();
    while(T-- > 0) {
        int N = inp.readInt(1, (int)2e5); inp.readEoln();
        NN += N;
        // NN was previously accumulated but never checked.
        // NOTE(review): assumes the statement bounds the sum of N over all test cases by
        // 2e5 (the per-test maximum); confirm against the official constraints. Checking
        // on every iteration also keeps NN far from int overflow.
        assert(NN <= (int)2e5);
        vector<int> A = inp.readInts(N, 1, (int)1e9); inp.readEoln();

        // ord lists the indices of A in increasing order of value.
        vector<int> ord(N);
        iota(ord.begin(), ord.end(), 0);
        sort(ord.begin(), ord.end(), [&](int i, int j) {
            return A[i] < A[j];
        });
        // The statement guarantees distinct elements; reject inputs that violate it.
        for (int i = 1; i < N; ++i) {
            assert(A[ord[i - 1]] < A[ord[i]]);
        }

        // Let s_i = A[ord[i]] and key_i = s_i - i. Distinctness makes key non-decreasing,
        // and s_l..s_r are consecutive integers exactly when key_l == key_r.
        //   p1 = first index whose key is >= key_i     -> indices [p1, i) share key_i
        //   p2 = first index whose key is >= key_i - 1 -> indices [p2, p1) have key_i - 1
        // Both pointers only move forward because key is non-decreasing.
        int p1 = 0, p2 = 0;
        int64_t res = 0;
        // x choose 2.
        auto get = [&](int64_t x) {
            return x * (x - 1) / 2;
        };
        for(int i = 0 ; i < N ; ++i) {
            while(A[ord[i]] - i > A[ord[p1]] - p1)
                ++p1;
            while(A[ord[i]] - i - 1 > A[ord[p2]] - p2)
                ++p2;
            // p1 - p2: contiguous blocks s_l..s_i whose single gap is a value missing
            // from A. get(i - p1): pick l < m < i inside i's key block and drop s_m.
            res += p1 - p2 + get(i - p1);
        }
        cout << res << '\n';
    }
    // Nothing may follow the last test case.
    inp.readEof();
    return 0;
}
Editorialist's code (Python)
from collections import defaultdict
def count_good(values):
    """Count subsequences whose max minus min equals their length.

    `values` must be sorted ascending and pairwise distinct. run_end[v] is the
    length of the longest run of consecutive integers in `values` ending at v, and
    run_start[v] is the longest such run starting at v (0 when v is absent). A good
    subsequence is a run ending at some v joined to a run starting at v + 2, so the
    answer is the sum of run_end[v] * run_start[v + 2] over all v.
    """
    run_end = defaultdict(int)
    run_start = defaultdict(int)
    for v in values:
        run_end[v] = run_end[v - 1] + 1
    for v in values[::-1]:
        run_start[v] = run_start[v + 1] + 1
    return sum(run_end[v] * run_start[v + 2] for v in values)


for _ in range(int(input())):
    int(input())  # N itself is not needed; the next line holds the values
    print(count_good(sorted(map(int, input().split()))))