COUNTSUB - Editorial


Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: ro27
Tester: jay_1048576
Editorialist: iceknight1093




Knowledge of maps/dictionaries OR sorting


You’re given an array A. Define the function f as

f(L, R) = \sum_{i=L}^{R-1} (A_i - A_{i+1})

with f(i, i) = 0.

Count the number of unstable subarrays, i.e., subarrays (L, R) such that f(L, R) \neq A_R - A_L.


Notice that if you expand it out, f(L, R) is a telescoping sum. That is,

\begin{align*} f(L, R) &= \sum_{i=L}^{R-1} (A_i - A_{i+1}) \\ &= (A_L - A_{L+1}) + (A_{L+1} - A_{L+2}) + \ldots + (A_{R-1} - A_R) \\ &= A_L + (-A_{L+1} + A_{L+1}) + (-A_{L+2} + A_{L+2}) + \ldots + (-A_{R-1} + A_{R-1}) - A_R \\ &= A_L - A_R \end{align*}

So, f(L, R) = A_L - A_R, which means a subarray is unstable if A_L - A_R \neq A_R - A_L.
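The telescoping identity is easy to sanity-check numerically. The sketch below is only an illustration: it uses 0-based indices, a helper named f that evaluates the sum from the statement directly, and a small random array, none of which come from the editorial's code.

```python
import random

def f(a, l, r):
    """Evaluate the sum from the statement directly (0-based, l <= r)."""
    return sum(a[i] - a[i + 1] for i in range(l, r))

random.seed(0)  # fixed seed so the check is reproducible
a = [random.randint(-10, 10) for _ in range(8)]

# Compare the direct sum with the closed form A_L - A_R for every L <= R.
telescopes = all(
    f(a, l, r) == a[l] - a[r]
    for l in range(len(a))
    for r in range(l, len(a))
)
print(telescopes)
```

The empty sum for l == r gives 0, matching f(i, i) = 0 in the statement.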

Let’s instead count the number of subarrays that aren’t unstable, i.e., those that satisfy A_L - A_R = A_R - A_L. We can then subtract this from the total number of subarrays, which equals \frac{N\cdot (N+1)}{2}.
Notice that this equality rearranges to 2A_L = 2A_R, which is simply A_L = A_R.

So, our objective is to count the number of subarrays whose endpoints are equal.
This is the same as counting the number of pairs of indices (L, R) with L \leq R and A_L = A_R. Every single-element subarray (L = R) is included in this count.
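For testing faster solutions on small inputs, a direct quadratic count can serve as a reference. This is a minimal sketch, assuming 0-based indexing and that single-element subarrays are included in the total (consistent with the N(N+1)/2 figure above); the function name count_unstable_bruteforce is made up for this example.

```python
def count_unstable_bruteforce(a):
    """O(N^2) reference: total subarrays minus those with equal endpoints."""
    n = len(a)
    stable = 0
    for l in range(n):
        for r in range(l, n):  # r starts at l, so single elements count
            if a[l] == a[r]:
                stable += 1
    return n * (n + 1) // 2 - stable

print(count_unstable_bruteforce([1, 2, 1]))
```

For [1, 2, 1] the stable subarrays are the three single elements plus the whole array, leaving 2 unstable ones out of 6.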

We do need an algorithm to do this quickly, since \mathcal{O}(N^2) won’t cut it.
There are several ways of doing this faster; here’s one:

  • Let \text{freq} be a map/dict, such that \text{freq}[x] denotes the number of times we’ve seen x so far.
    Initially, \text{freq} is empty.
  • Iterate i from 1 to N. When processing index i:
    • Increase \text{freq}[A_i] by 1.
    • Then, add \text{freq}[A_i] to the answer, since that’s the number of indices j \leq i whose value equals A_i (this includes i itself, which accounts for the single-element subarray).

Depending on the kind of map used, this is \mathcal{O}(N\log N) or (expected) \mathcal{O}(N) time, which is fast enough.
There are other implementations that don’t need maps; for example using sorting.
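As a further variant (not taken from the editorial's code), the count can also be computed from the final frequencies alone: a value that occurs c times contributes c(c+1)/2 pairs with L \leq R. The sketch below assumes the array is already read into a list; the function name count_unstable is hypothetical.

```python
from collections import Counter

def count_unstable(a):
    """Total subarrays minus pairs (L, R), L <= R, with equal endpoints."""
    n = len(a)
    # A value seen c times yields c single-element pairs plus c*(c-1)/2
    # pairs of distinct positions, i.e. c*(c+1)/2 in total.
    stable = sum(c * (c + 1) // 2 for c in Counter(a).values())
    return n * (n + 1) // 2 - stable

print(count_unstable([1, 2, 1]))
```

With a hash-based Counter this runs in expected \mathcal{O}(N) time per test case.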


\mathcal{O}(N) or \mathcal{O}(N\log N) per test case.


Editorialist's code (Python, dict)
for _ in range(int(input())):
    n = int(input())
    a = list(map(int, input().split()))
    freq = {}
    ans = n*(n+1)//2  # total number of subarrays
    for x in a:
        if x not in freq: freq[x] = 0
        freq[x] += 1
        ans -= freq[x]  # stable subarrays ending at this element
    print(ans)
Editorialist's code (Python, sorting)
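for _ in range(int(input())):
    n = int(input())
    a = sorted(map(int, input().split()))  # equal values become adjacent
    ans = n*(n+1)//2  # total number of subarrays
    cur = 0
    for i in range(n):
        if i == 0 or a[i] != a[i-1]: cur = 1
        else: cur += 1
        ans -= cur  # cur = occurrences of a[i] seen so far
    print(ans)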
import java.util.*;
import java.lang.*;

class Codechef
{
	public static void main (String[] args) throws java.lang.Exception
	{
		Scanner scn = new Scanner(System.in);
		int t = scn.nextInt();
		while(t-- > 0){
		    int n = scn.nextInt();
		    long arr[] = new long[n];
		    HashMap<Long,Long> fre = new HashMap<Long,Long>();
		    for(int i = 0; i < n; i++){
		        arr[i] = scn.nextInt();
		        fre.put(arr[i], fre.getOrDefault(arr[i], 0L) + 1);
		    }
		    long sim = (n*(n-1))/2;
		    for(Map.Entry<Long,Long> i : fre.entrySet()){
		        long c = i.getValue();
		        sim -= (c*(c-1))/2;
		    }
		    System.out.println(sim);
		}
	}
}

Why is this code giving a wrong answer on the last test case?

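The line "long sim = (n*(n-1))/2;" overflows because n is an int: the product n*(n-1) is computed in 32-bit arithmetic before the result is assigned to the long.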


Why is this submission of mine failing on the last test case?

Same issue as I pointed out in my above comment, n * (n-1)/2 will overflow if n is an int.
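To see the effect concretely, the wraparound can be simulated. Python integers do not overflow, so the sketch below mimics Java's signed 32-bit int with a hypothetical helper to_int32. The value n = 100000 is an assumption for illustration only, since the constraints are not quoted in this thread. In Java itself, the usual fix is to widen before multiplying, e.g. (long) n * (n - 1) / 2.

```python
def to_int32(x):
    """Wrap x to a signed 32-bit value, mimicking Java int arithmetic."""
    x &= 0xFFFFFFFF
    return x - (1 << 32) if x >= (1 << 31) else x

def pairs_with_int_math(n):
    """n*(n-1)/2 as Java evaluates it when n is an int."""
    product = to_int32(n * (n - 1))  # the 32-bit multiply wraps silently
    half = abs(product) // 2         # Java's / truncates toward zero
    return half if product >= 0 else -half

def pairs_with_long_math(n):
    """n*(n-1)/2 evaluated without overflow."""
    return n * (n - 1) // 2

n = 100000  # assumed size, large enough for n*(n-1) to exceed 2^31 - 1
print(pairs_with_int_math(n), pairs_with_long_math(n))
```

The two results agree while n*(n-1) fits in 32 bits (up to n = 46341) and diverge from n = 46342 onward.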


for _ in range(int(input())):
    n = int(input())
    a = list(map(int,input().split()))
    c = 0
    for i in range(n):
        for j in range(i+1, n):
            f = a[i] - a[j]
            if f != a[j] - a[i]:
                c += 1
    print(c)

Why is this code giving Time Limit Exceeded on 2 test cases?