NORMAL - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: iceknight1093
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

Easy

PREREQUISITES:

Prefix sums

PROBLEM:

An array is called normal if its mean and median are equal.
Given an array with elements between 1 and 3, count the number of its normal subarrays.

EXPLANATION:

The elements being between 1 and 3 is key to solving this task.

Since the mean must equal the median, and the median is an element of the array, there are really only three possibilities: the mean and median can equal 1, 2, or 3.
Let’s look at each one in turn.

First, suppose mean = median = 1.
Now, since the elements are all \geq 1, the only way a subarray can have a mean of 1 is if all its elements are 1 - the instant a subarray contains an element larger than 1, its mean will be larger than 1 too.
But if all the elements are 1, the median is also surely 1, so every such subarray satisfies the conditions.

The exact same analysis applies to mean = median = 3, in that this is only possible if every element is 3.

Counting the number of subarrays whose elements are all 1 is fairly straightforward: a maximal block of ones of length k contains \frac{k\cdot (k+1)}{2} valid subarrays, and summing this over all maximal blocks gives the count. Threes are handled identically.
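A minimal Python sketch of this count, covering both the all-1 and all-3 cases. The function name `count_constant_1_or_3` is ours, chosen for illustration only.

```python
def count_constant_1_or_3(a):
    """Count subarrays made entirely of 1s or entirely of 3s.

    Walks over maximal blocks of equal values; a block of length k
    contributes k * (k + 1) // 2 subarrays if its value is 1 or 3.
    """
    total = 0
    i, n = 0, len(a)
    while i < n:
        j = i
        while j < n and a[j] == a[i]:
            j += 1
        k = j - i  # length of the maximal block a[i:j]
        if a[i] != 2:
            total += k * (k + 1) // 2
        i = j
    return total


print(count_constant_1_or_3([1, 1, 2, 3, 3, 3, 1]))  # 10 = 3 + 6 + 1
```

Blocks of 2s are skipped here because they belong to the mean = median = 2 case.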

That leaves us with mean = median = 2.
First, let’s analyze what mean = 2 means for a subarray.
Suppose the subarray contains x_1 occurrences of 1, x_2 of 2, and x_3 of 3.
Then, for the mean to equal 2, we want

\begin{aligned} \frac{x_1 + 2x_2 + 3x_3}{x_1 + x_2 + x_3} &= 2 \\ x_1 + 2x_2 + 3x_3 &= 2x_1 + 2x_2 + 2x_3 \\ x_3 &= x_1 \end{aligned}

That is, the number of ones must equal the number of threes; equivalently, their difference must be 0.


Let b_i = A_i - 2. Note that this maps [1, 2, 3] to [-1, 0, 1].

Observe that for the subarray [l, r], b_l + b_{l+1} + \ldots + b_r equals x_3 - x_1, the number of threes minus the number of ones in it.
We’re looking for subarrays where this difference is 0.

Let p_0 = 0 and p_i = b_1 + b_2 + \ldots + b_i, so p is the prefix sum array of b.
Then, we’re looking for pairs (l, r) with l \leq r such that p_r = p_{l-1}, since that’s exactly when the subarray sum between l and r is 0.
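To make the prefix-sum step concrete, here is a small Python sketch (function name hypothetical) that counts pairs with p_{l-1} = p_r using a dictionary of previously seen prefix values. Note that it counts every subarray with mean exactly 2, including ones like [1, 3] that still fail the median condition discussed next.

```python
from collections import defaultdict


def count_mean_two(a):
    """Count subarrays whose mean is exactly 2, i.e. with as many 1s as 3s.

    p is the running prefix sum of b_i = a_i - 2 (with p_0 = 0); every
    earlier prefix with the same value as the current one closes a
    zero-sum subarray ending here.
    """
    seen = defaultdict(int)
    seen[0] = 1  # p_0 = 0
    p = 0
    total = 0
    for x in a:
        p += x - 2
        total += seen[p]  # earlier indices l-1 with p_{l-1} == p_r
        seen[p] += 1
    return total


print(count_mean_two([1, 2, 3]))  # 2: [2] and [1, 2, 3]
```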

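However, there is one more condition to take care of here: the median being 2.
This turns out to be quite simple. We already have an equal number of ones and threes, so if the subarray contains at least one 2, its sorted form is x_1 ones, then x_2 \geq 1 twos, then x_1 threes, and the middle position(s) fall inside the block of twos. If it contains no 2, the sorted form is x_1 ones followed by x_1 threes, and the median is 1 or 3. So the only condition for the median to be 2 is for 2 to exist in the subarray in the first place!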

So, the subarray [l, r] has mean = median = 2 if and only if:

  1. p_{l-1} = p_r, and
  2. There’s at least one occurrence of 2 in the range [l, r] of A.
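The full characterization (all 1s, all 3s, or equal counts of 1s and 3s with at least one 2) can be sanity-checked by brute force on small arrays. This Python sketch assumes the median of a length-k array is its \lceil k/2 \rceil-th smallest element, so that it is always an element of the array; the helper names are hypothetical.

```python
from itertools import product


def is_normal(sub):
    """Direct check: mean equals median.

    ASSUMPTION: the median of a length-k array is its ceil(k/2)-th
    smallest element, i.e. index (k - 1) // 2 after sorting.
    """
    s = sorted(sub)
    k = len(s)
    median = s[(k - 1) // 2]
    # mean == median  <=>  sum == median * k (avoids floating point)
    return sum(s) == median * k


def satisfies_characterization(sub):
    """All 1s, all 3s, or (#1s == #3s and at least one 2 present)."""
    if all(x == 1 for x in sub) or all(x == 3 for x in sub):
        return True
    return sub.count(1) == sub.count(3) and 2 in sub


# Exhaustively compare on every array over {1, 2, 3} of length 1..7.
for length in range(1, 8):
    for arr in product((1, 2, 3), repeat=length):
        assert is_normal(list(arr)) == satisfies_characterization(list(arr)), arr
print("characterization agrees with brute force")
```

Using the upper middle element `s[k // 2]` as the median instead should give the same result, since the argument above only needs every candidate middle position to land in the block of twos.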

Let’s now try to count such subarrays.

Suppose we fix r, and try to count all valid l.
Let x \leq r be the largest index such that A_x = 2; if no such index exists, no l works for this r.
Then, we must have l \leq x, so that [l, r] contains a 2 at all.

But beyond that, any l \leq x that satisfies p_{l-1} = p_r will do.
This is fairly easy to count, since l \leq x means l - 1 \lt x: for example, you can store a sorted list of indices corresponding to each value of p_i, and then binary search in the list for p_r to count how many of its indices are \lt x.
Alternatively, a two-pointer method works too: as r increases, x never decreases, so you can maintain a frequency table of p_j over all j \lt x, adding new prefix values whenever x moves forward.
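Putting both cases together, here is a Python sketch of the binary-search variant (the name `count_normal_subarrays` is hypothetical). It takes a 0-indexed list but uses the editorial's 1-indexed p and x internally. Because of the binary search it runs in \mathcal{O}(N \log N), unlike the editorialist's two-pointer code below.

```python
from bisect import bisect_left
from collections import defaultdict


def count_normal_subarrays(a):
    """Count normal subarrays of a (values in {1, 2, 3}).

    Uses the editorial's characterization; the mean = median = 2 part
    uses sorted position lists plus binary search (O(N log N)).
    """
    n = len(a)
    total = 0

    # Part 1: subarrays consisting only of 1s or only of 3s.
    run = 0
    for i in range(n):
        run = run + 1 if i > 0 and a[i] == a[i - 1] else 1
        if a[i] != 2:
            total += run  # subarrays of the current block ending at i

    # Part 2: prefix sums p[0..n] of b_i = A_i - 2, with p[0] = 0.
    p = [0] * (n + 1)
    for i, x in enumerate(a):
        p[i + 1] = p[i] + (x - 2)

    # positions[v] = sorted list of indices j with p[j] == v.
    positions = defaultdict(list)
    for j, v in enumerate(p):
        positions[v].append(j)

    last_two = 0  # largest 1-indexed x <= r with A_x = 2 (0 = none yet)
    for r in range(1, n + 1):
        if a[r - 1] == 2:
            last_two = r
        if last_two:
            # Valid left ends l satisfy l - 1 < last_two and p[l-1] == p[r].
            total += bisect_left(positions[p[r]], last_two)
    return total


print(count_normal_subarrays([3, 2, 1, 2]))  # 6
```

For [3, 2, 1, 2] the six normal subarrays are the four single elements, [3, 2, 1] and [3, 2, 1, 2].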

TIME COMPLEXITY:

\mathcal{O}(N) per testcase with the two-pointer method (\mathcal{O}(N \log N) if binary search is used instead).

CODE:

Tester's code (C++)
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

using namespace std;
using namespace __gnu_pbds;

template<typename T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
typedef long long int ll;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;

#define fastio ios_base::sync_with_stdio(false); cin.tie(NULL)
#define pb push_back
#define endl '\n'
#define sz(a) (int)a.size()
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
#define conts continue
#define ceil2(x,y) ((x+y-1)/(y))
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
#define yes cout << "Yes" << endl
#define no cout << "No" << endl

#define rep(i,n) for(int i = 0; i < n; ++i)
#define rep1(i,n) for(int i = 1; i <= n; ++i)
#define rev(i,s,e) for(int i = s; i >= e; --i)
#define trav(i,a) for(auto &i : a)

template<typename T>
void amin(T &a, T b) {
    a = min(a,b);
}

template<typename T>
void amax(T &a, T b) {
    a = max(a,b);
}

#ifdef LOCAL
#include "debug.h"
#else
#define debug(...) 42
#endif

const int MOD = 1e9 + 7;
const int N = 1e5 + 5;
const int inf1 = int(1e9) + 5;
const ll inf2 = ll(1e18) + 5;

void solve(int test_case)
{
    ll n; cin >> n;
    vector<ll> a(n+5);
    rep1(i,n) cin >> a[i];

    // only 1s/only 3s/equal #of 1s and 3s with a 2
    vector<ll> dp(n+5);
    rep1(i,n){
        if(a[i-1] == a[i]) dp[i] = dp[i-1]+1;
        else dp[i] = 1;
    }

    ll ans = 0;

    rep1(i,n){
        if(a[i] != 2){
            ans += dp[i];
        }
    }

    vector<ll> p(n+5);
    rep1(i,n){
        p[i] = p[i-1];
        if(a[i] == 1) p[i]--;
        if(a[i] == 3) p[i]++;
    }

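    // go(l, r): counts pairs l <= j <= i <= r with p[j-1] == p[i],
    // i.e. subarrays of [l, r] with as many 1s as 3s
    auto go = [&](ll l, ll r){
        if(l == -1 or l > r) return 0ll;
        map<ll,ll> mp;
        ll res = 0;
        for(int i = l; i <= r; ++i){
            mp[p[i-1]]++;
            res += mp[p[i]];
        }
        return res;
    };

    // every balanced subarray (equal #1s and #3s) of the whole array
    ans += go(1,n);
    // remove balanced subarrays lying inside a maximal 2-free segment,
    // since they contain no 2; l = start of the open segment, or -1
    ll sub = 0;
    ll l = -1;

    rep1(i,n){
        if(a[i] == 2){
            sub += go(l,i-1);
            l = -1;
        }
        else{
            if(l == -1){
                l = i;
            }
        }
    }

    sub += go(l,n);
    ans -= sub;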

    cout << ans << endl;
}

int main()
{
    fastio;

    int t = 1;
    cin >> t;

    rep1(i, t) {
        solve(i);
    }

    return 0;
}
Editorialist's code (PyPy3)
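for _ in range(int(input())):
    n = int(input())
    a = list(map(int, input().split()))

    ans = 0

    # pref[i] = (#1s) - (#3s) among the first i elements
    pref = [0]*(n+1)
    for i in range(n):
        pref[i+1] = pref[i]
        if a[i] == 1: pref[i+1] += 1
        if a[i] == 3: pref[i+1] -= 1

    block = 0              # length of the run of equal values ending at i
    freq = [0]*(2*n + 10)  # freq[v + n] = #indices j <= (latest 2) with pref[j] == v
    prv = 0                # next prefix index not yet added to freq
    for i in range(n):
        if i > 0 and a[i] != a[i-1]: block = 0
        block += 1

        # all-1 or all-3 subarrays ending at i
        if a[i] != 2: ans += block
        else:
            # a 2 at index i makes every prefix index j <= i usable as l-1
            while prv <= i:
                freq[pref[prv] + n] += 1
                prv += 1
        # mean = median = 2 subarrays ending at i
        ans += freq[pref[i+1] + n]
    print(ans)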

Given the amount of analysis required, this at least doesn’t fall in the easy category.


This is NOT an easy problem.


This problem is 2270 rated according to the c-list. It may be easier for the editorialist since he is a 6-star coder.

But it is easy, as the analysis is pretty straightforward; it just takes some time. Also, if you check the Div 1 submission count, you will understand: it was solved by 60%+ of the participants.

Can someone please explain why my O(n) solution gets TLE in both Java and C++?

code:

#include <iostream>
#include <vector>
using namespace std;

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    int t;
    cin >> t;
    while (t--) {
        int n;
        cin >> n;
        vector<int> arr(n);
        for (int i = 0; i < n; i++) cin >> arr[i];

        vector<long> left(n), right(n);
        int mx = 3e5;
        vector<int> fre(mx + n + 1, 0);
        int cou = 0;

        for (int i = 0; i < n; i++) {
            if (arr[i] == 2) {
                left[i] = (cou != 0) ? fre[cou + mx] - 1 : fre[cou + mx];
            } else {
                cou += (arr[i] == 1) ? 1 : -1;
                fre[mx + cou]++;
            }
        }

        cou = 0;
        fre.assign(mx + n + 1, 0);
        long ans = 0;
        for (int i = n - 1; i >= 0; i--) {
            if (arr[i] == 2) {
                right[i] = (cou != 0) ? fre[cou + mx] - 1 : fre[cou + mx];
            } else {
                cou += (arr[i] == 1) ? 1 : -1;
                fre[mx + cou]++;
            }
            ans += ((left[i] + 1) * (right[i] + 1) - 1);
        }

        int last = arr[0]; int co = 1;

        for (int i = 1; i < n; i++) {
            if (arr[i] == arr[i-1]) co++;
            else {
                ans += ((long) co * (co + 1) / 2);
                last = arr[i]; co = 1;
            }
        }
        ans += ((long) co * (co + 1) / 2);

        cout << ans << '\n';
    }
    return 0;
}