PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: ro27
Tester: jay_1048576
Editorialist: iceknight1093
DIFFICULTY:
1545
PREREQUISITES:
Knowledge of maps/dictionaries OR sorting
PROBLEM:
You’re given an array A. Define the function f as
with f(i, i) = 0.
Count the number of unstable subarrays, i.e., subarrays (L, R) with 1 \leq L \leq R \leq N such that f(L, R) \neq A_R - A_L.
EXPLANATION:
Notice that if you expand it out, f(L, R) is a telescoping sum. That is,
f(L, R) = (A_L - A_{L+1}) + (A_{L+1} - A_{L+2}) + \dots + (A_{R-1} - A_R) = A_L - A_R.
So, f(L, R) = A_L - A_R, which means a subarray is unstable exactly when A_L - A_R \neq A_R - A_L.
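To sanity-check the telescoping step, here is a minimal sketch. It assumes that f(L, R) expands to the sum of consecutive differences (A_L - A_{L+1}) + \dots + (A_{R-1} - A_R), using 0-indexed positions; the helper name `f_expanded` is hypothetical and not part of the problem statement.

```python
import random

def f_expanded(a, l, r):
    # Assumed expansion (hypothetical helper): sum of a[i] - a[i+1] for l <= i < r.
    # For l == r the sum is empty, which matches f(i, i) = 0.
    return sum(a[i] - a[i + 1] for i in range(l, r))

# Compare against the telescoped form f(L, R) = A_L - A_R on random small arrays.
for _ in range(100):
    a = [random.randint(-10, 10) for _ in range(random.randint(1, 8))]
    for l in range(len(a)):
        for r in range(l, len(a)):
            assert f_expanded(a, l, r) == a[l] - a[r]
print("telescoping identity holds on all tested arrays")
```

Only the first and last terms survive the sum, which is why the check never depends on the middle elements.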
Let's instead count the subarrays that aren't unstable (call them stable), i.e., those satisfying A_L - A_R = A_R - A_L, and subtract this count from the total number of subarrays, which equals \frac{N\cdot (N+1)}{2}.
This equality rearranges to 2A_L = 2A_R, which simply means A_L = A_R.
So, our objective is to count the number of subarrays whose endpoints are equal.
This is the same as counting the pairs of indices (L, R) with L \leq R and A_L = A_R. Note that the N single-element subarrays (L = R) are always counted.
Checking every pair directly takes \mathcal{O}(N^2) time, which is too slow, so we need a faster method.
There are several ways of doing this faster; here's one:
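For small inputs, for example to cross-check a faster solution, the direct \mathcal{O}(N^2) check can be sketched as follows. The function name `count_unstable_brute` is hypothetical, and the code relies on the telescoped form f(L, R) = A_L - A_R derived earlier.

```python
def count_unstable_brute(a):
    # O(N^2) reference (hypothetical helper): try every subarray (L, R) with L <= R.
    # Since f(L, R) = A_L - A_R, a subarray is unstable iff A_L - A_R != A_R - A_L.
    n = len(a)
    count = 0
    for l in range(n):
        for r in range(l, n):
            if a[l] - a[r] != a[r] - a[l]:
                count += 1
    return count

print(count_unstable_brute([1, 2, 1]))  # 2: subarrays (1, 2) and (2, 3), 1-indexed
```

For [1, 2, 1] there are 6 subarrays; (1, 1), (2, 2), (3, 3) and (1, 3) have equal endpoints, leaving 2 unstable ones.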
- Let \text{freq} be a map/dict, such that \text{freq}[x] denotes the number of times we've seen x so far. Initially, \text{freq} is empty.
- Iterate i from 1 to N. When processing index i:
  - Increase \text{freq}[A_i] by 1.
  - Then, add \text{freq}[A_i] to the count of stable subarrays, since that's now the number of indices j \leq i (including i itself) with A_j = A_i, i.e., the number of stable subarrays ending at i.
Depending on the kind of map used, this takes \mathcal{O}(N\log N) time (a balanced-tree map, such as C++ std::map) or expected \mathcal{O}(N) time (a hash map, such as Python's dict), which is fast enough.
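Equivalently, if a value occurs c times, the running-count loop adds 1 + 2 + \dots + c = \frac{c\cdot(c+1)}{2} for it, so the stable subarrays can also be counted from the final frequencies alone. A minimal sketch using Python's `collections.Counter`; the function name `count_unstable` is hypothetical.

```python
from collections import Counter

def count_unstable(a):
    n = len(a)
    total = n * (n + 1) // 2
    # A value seen c times gives c*(c+1)//2 stable subarrays:
    # one per pair of positions (L, R) with L <= R holding that value.
    stable = sum(c * (c + 1) // 2 for c in Counter(a).values())
    return total - stable

print(count_unstable([1, 2, 1]))  # 2
```

This does the same work as the running loop, just grouped by value instead of by index.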
There are other implementations that don't need maps; for example, sort A so that equal values form contiguous runs, then count pairs within each run.
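A sorting-based sketch of the same idea: after sorting, a run of c equal values contributes \frac{c\cdot(c+1)}{2} stable subarrays. This version uses `itertools.groupby` to find the runs, which is a swapped-in convenience rather than the editorialist's approach; `count_unstable_sorted` is a hypothetical name.

```python
from itertools import groupby

def count_unstable_sorted(a):
    n = len(a)
    total = n * (n + 1) // 2
    stable = 0
    # groupby yields one group per run of equal consecutive values in the sorted list.
    for _, run in groupby(sorted(a)):
        c = sum(1 for _ in run)  # length of this run
        stable += c * (c + 1) // 2
    return total - stable

print(count_unstable_sorted([3, 1, 3, 2, 1]))  # 8
```

Sorting dominates the running time, giving \mathcal{O}(N\log N) overall.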
TIME COMPLEXITY:
\mathcal{O}(N) or \mathcal{O}(N\log N) per test case.
CODE:
Editorialist's code (Python, dict)
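```python
for _ in range(int(input())):
    n = int(input())
    a = list(map(int, input().split()))
    freq = {}  # freq[x] = occurrences of x seen so far
    ans = n*(n+1)//2  # start from the total number of subarrays
    for x in a:
        if x not in freq: freq[x] = 0
        freq[x] += 1
        # freq[x] stable subarrays (A_L = A_R = x) end at this index
        ans -= freq[x]
    print(ans)
```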
Editorialist's code (Python, sorting)
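```python
for _ in range(int(input())):
    n = int(input())
    a = list(map(int, input().split()))
    a.sort()  # equal values become contiguous runs
    ans = n*(n+1)//2  # start from the total number of subarrays
    cur = 0  # length of the current run of equal values, up to index i
    for i in range(n):
        if i == 0 or a[i] != a[i-1]: cur = 1
        else: cur += 1
        # cur = pairs (j, i) with j <= i inside this run of equal values
        ans -= cur
    print(ans)
```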