ALL3 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: iceknight1093
Tester: nskybytskyi
Editorialist: iceknight1093

DIFFICULTY:

Simple

PREREQUISITES:

None

PROBLEM:

You’re given an array A with elements between 1 and 3, which contains each of 1, 2, 3 at least once.
Find the number of subarrays such that, when each element within it is increased by 1 (with 3 wrapping around to 1, i.e. 1 \to 2, 2 \to 3, 3 \to 1), the resulting array still contains all of 1, 2, 3 at least once.
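To make the statement concrete, here is a brute-force reference. It is a sketch for checking answers on small inputs, not the intended solution. It tries every subarray and uses prefix counts to check in O(1) whether each of 1, 2, 3 survives the operation, for O(N^2) overall. The name `countGoodBrute` is illustrative, and the input is assumed to already satisfy the constraints.

```cpp
#include <array>
#include <cassert>
#include <vector>

// Brute-force reference (hypothetical helper): for every subarray [l, r],
// check whether 1, 2 and 3 all survive the operation 1->2, 2->3, 3->1.
long long countGoodBrute(const std::vector<int>& a) {
    const int n = static_cast<int>(a.size());
    // pre[i][v] = number of occurrences of value v among a[0..i-1]
    std::vector<std::array<int, 4>> pre(n + 1, std::array<int, 4>{});
    for (int i = 0; i < n; ++i) {
        assert(1 <= a[i] && a[i] <= 3);
        pre[i + 1] = pre[i];
        ++pre[i + 1][a[i]];
    }
    long long good = 0;
    for (int l = 0; l < n; ++l) {
        for (int r = l; r < n; ++r) {
            bool allPresent = true;
            for (int w = 1; w <= 3; ++w) {
                const int src = (w == 1) ? 3 : w - 1;  // value that turns into w
                const int insideW = pre[r + 1][w] - pre[l][w];
                const int insideSrc = pre[r + 1][src] - pre[l][src];
                // copies of w afterwards: untouched ones outside + shifted ones inside
                const int after = (pre[n][w] - insideW) + insideSrc;
                if (after == 0) allPresent = false;
            }
            if (allPresent) ++good;
        }
    }
    return good;
}
```

For example, on A = [2, 3, 1, 3] it returns 6. A checker like this is handy for stress-testing a faster solution on small arrays.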

EXPLANATION:

Rather than count “good” subarrays which give us the result we want, let’s instead count “bad” subarrays, on which performing the operation leaves us with strictly fewer than 3 distinct elements.
This count can then be subtracted from the total number of subarrays, which equals \frac{N\cdot (N+1)}{2}.

In a bad subarray, one of the numbers must be missing after the operation. Let’s fix this missing number to be 1, and count how many subarrays [L, R] cause this to happen.

First, since 1 is missing in the final array, every occurrence of 1 in A must lie within the subarray we choose - all of these occurrences will turn into 2 (and so every 1 will disappear).
Let l_1 and r_1 denote the leftmost and rightmost indices that contain a 1.
This observation tells us that we must have L \leq l_1 and R \geq r_1.

On the other hand, the subarray we choose also shouldn’t contain any occurrence of 3: if it did, the 3 would become 1 after the operation, which we don’t want.
So, there should be no 3's at indices between L and R.
In particular, observe that if there is a 3 between l_1 and r_1, no valid pair of (L, R) can ever exist.

Finally, we’re left with the case where there isn’t a 3 between l_1 and r_1.
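Here, let x \lt l_1 be the closest occurrence of 3 to the left of l_1, and y \gt r_1 be the same to the right of r_1. (With 1-indexing, take x = 0 if there's no 3 to the left of l_1, and y = N+1 if there's no 3 to the right of r_1.)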
We can then choose any L such that x \lt L \leq l_1, and any R such that r_1 \leq R \lt y.
So, we obtain (l_1 - x) \cdot (y - r_1) subarrays in total.
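A minimal sketch of this count for the missing value 1, following the notation above: 1-indexed positions, with the sentinels x = 0 and y = N+1 when no 3 exists on that side. `countEliminating1` is a hypothetical helper name, and it assumes A contains at least one 1.

```cpp
#include <cassert>
#include <vector>

// Counts subarrays [L, R] (1-indexed, as in the text) whose operation removes
// every 1 from the array. Hypothetical helper; assumes a contains a 1.
long long countEliminating1(const std::vector<int>& a) {
    const int n = static_cast<int>(a.size());
    auto at = [&](int i) { return a[i - 1]; };  // 1-indexed access

    // l1, r1: leftmost and rightmost positions holding a 1
    int l1 = 0, r1 = 0;
    for (int i = 1; i <= n; ++i) {
        if (at(i) == 1) {
            if (l1 == 0) l1 = i;
            r1 = i;
        }
    }
    assert(l1 != 0 && "the array must contain a 1");

    // A 3 inside [l1, r1] would become a 1, so no valid subarray exists.
    for (int i = l1; i <= r1; ++i) {
        if (at(i) == 3) return 0;
    }

    // x: closest 3 left of l1 (sentinel 0); y: closest 3 right of r1 (sentinel n + 1)
    int x = l1 - 1;
    while (x >= 1 && at(x) != 3) --x;
    int y = r1 + 1;
    while (y <= n && at(y) != 3) ++y;

    // Any L in (x, l1] combined with any R in [r1, y) works.
    return 1LL * (l1 - x) * (y - r1);
}
```

On A = [2, 3, 1, 3] we get l_1 = r_1 = 3, x = 2 and y = 4, so the function returns (3 - 2) \cdot (4 - 3) = 1.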


The above discussion was for subarrays that eliminate 1 after the operation.
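Repeat the process to count subarrays that eliminate 2 and 3, and add all three counts up to get the required number of “bad” subarrays. In general, to eliminate a value v, the subarray must contain every occurrence of v and no occurrence of the value that turns into v: that is 3 when v = 1, and v-1 otherwise.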
Note that it’s impossible for more than one value to be eliminated (do you see why?), so there’s no overcounting going on here.
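If the claim isn't obvious, one way to convince yourself is to check it by brute force on small arrays. The sketch below (the name `minDistinctAfterOperation` is illustrative) applies the operation to every subarray and returns the smallest number of distinct values left over. For arrays that contain all of 1, 2, 3 it should never drop below 2.

```cpp
#include <algorithm>
#include <cassert>
#include <set>
#include <vector>

// Brute-force check of the "at most one value can disappear" claim:
// apply the operation to every subarray [l, r] and return the smallest
// number of distinct values that remain. Hypothetical helper, O(N^3).
int minDistinctAfterOperation(const std::vector<int>& a) {
    const int n = static_cast<int>(a.size());
    int best = 3;
    for (int l = 0; l < n; ++l) {
        for (int r = l; r < n; ++r) {
            std::set<int> values;
            for (int i = 0; i < n; ++i) {
                int v = a[i];
                assert(1 <= v && v <= 3);
                if (l <= i && i <= r) v = (v == 3) ? 1 : v + 1;  // 1->2, 2->3, 3->1
                values.insert(v);
            }
            best = std::min(best, static_cast<int>(values.size()));
        }
    }
    return best;
}
```

It returns 2 for [2, 3, 1, 3] and for [1, 2, 3], but 1 for [1, 1], where the whole array turns into [2, 2]. So without the guarantee that A contains each of 1, 2, 3, a single subarray could eliminate two values and be counted twice.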

All necessary indices can be found in linear time by just iterating through the array, and we do three passes of the algorithm so the overall complexity remains linear.
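Putting the pieces together, here is a compact sketch of the whole linear solution. It follows the same idea as the editorialist's code but is written as two small functions with 0-indexing, so the sentinels become -1 and n instead of 0 and N+1. The names `countEliminating` and `countGood` are illustrative.

```cpp
#include <cassert>
#include <vector>

// Number of subarrays whose operation removes every copy of value v.
// The forbidden value is the one that turns into v: 3 for v = 1, else v - 1.
// 0-indexed, so the "not found" sentinels are -1 and n.
long long countEliminating(const std::vector<int>& a, int v) {
    const int n = static_cast<int>(a.size());
    const int forbidden = (v == 1) ? 3 : v - 1;

    int lo = -1, hi = -1;  // leftmost / rightmost occurrence of v
    for (int i = 0; i < n; ++i) {
        if (a[i] == v) {
            if (lo == -1) lo = i;
            hi = i;
        }
    }
    assert(lo != -1 && "each of 1, 2, 3 must appear in the array");

    // A forbidden value between lo and hi can never be avoided.
    for (int i = lo; i <= hi; ++i) {
        if (a[i] == forbidden) return 0;
    }

    int x = lo - 1;
    while (x >= 0 && a[x] != forbidden) --x;
    int y = hi + 1;
    while (y < n && a[y] != forbidden) ++y;

    return 1LL * (lo - x) * (y - hi);
}

// Good subarrays = all subarrays minus the bad ones. The three bad sets are
// disjoint (at most one value can vanish), so their sizes simply add up.
long long countGood(const std::vector<int>& a) {
    const long long n = static_cast<long long>(a.size());
    long long bad = 0;
    for (int v = 1; v <= 3; ++v) bad += countEliminating(a, v);
    return n * (n + 1) / 2 - bad;
}
```

For A = [2, 3, 1, 3], the bad counts for v = 1, 2, 3 are 1, 2 and 1, so `countGood` returns 10 - 4 = 6. The total is computed in 64-bit arithmetic because N(N+1)/2 overflows a 32-bit int once N exceeds roughly 65\,000.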

TIME COMPLEXITY:

\mathcal{O}(N) per testcase.

CODE:

Editorialist's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        vector<int> a(n);
        for (int &x : a) cin >> x;
        ll ans = 0;
        for (int x : {1, 2, 3}) {
            int bad = x-1;
            if (bad == 0) bad = 3;

            int L = n, R = 0;
            for (int i = 0; i < n; ++i) {
                if (a[i] == x) {
                    L = min(i, L);
                    R = i;
                }
            }

            int L2 = L, R2 = R, good = 1;
            for (int i = L; i <= R; ++i) {
                good &= a[i] != bad;
            }
            
            while (L2 >= 0) {
                if (a[L2] == bad) break;
                --L2;
            }
            while (R2 < n) {
                if (a[R2] == bad) break;
                ++R2;
            }

            ans += 1ll * good * (R2 - R) * (L - L2);
        }
        cout << 1ll*n*(n+1)/2 - ans << '\n';
    }
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;

int64_t count_missing(int n, const vector<int>& a) {
    // Let's count how many operations result in all values being {1, 2}:
    // - all values 3 from the initial array must belong to the subarray
    int first_three = -1, last_three = -1;
    for (int i = 0; i < n; ++i) {
        if (a[i] == 3) {
            if (first_three == -1) {
                first_three = i;
            }
            last_three = i;
        }
    }

    // - no values 2 from the initial array can belong to the subarray
    for (int i = first_three; i < last_three; ++i) {
        if (a[i] == 2) {
            return 0;
        }
    }

    // - several consecutive values 1 before the first 3 or after the last 3
    // may belong to the subarray, giving us several choices according to the product rule
    int left = first_three, right = last_three;
    while (left > 0 && a[left - 1] == 1) {
        --left;
    }
    while (right + 1 < n && a[right + 1] == 1) {
        ++right;
    }
    return (first_three - left + 1ll) * (right - last_three + 1ll); 
}

int64_t solve(int n, vector<int> a) {
    // Lemma: the resulting array contains at least two distinct values
    // Proof: if the selected subarray only contained one value,
    // then the untouched part contains at least two distinct values
    // If the selected subarray contained at least two distinct values,
    // then their images under the operation are also distinct

    auto all = ((n + 1ll) * n) / 2;
    // Hence the only possible bad results are {1, 2}, {2, 3}, and {3, 1}
    // These subproblems are equivalent to {1, 2}
    // if we apply the operation to the entire array
    for (int i = 0; i < 3; ++i) {
        all -= count_missing(n, a);
        for (auto& ai : a) {
            if (++ai > 3) {
                ai -= 3;
            }
        }
    }
    return all;
}

int main() {
    cin.tie(0)->sync_with_stdio(0);
    int t; cin >> t; while (t--) {
        int n;
        cin >> n;
        vector<int> a(n);
        for (auto& ai : a) {
            cin >> ai;
        }
        cout << solve(n, a) << '\n';
    }
}

I am trying to calculate the number of valid subarrays starting at each index, skipping the ones that cause problems. It’s not getting AC and I think I am missing something. Please let me know.
#include <bits/stdc++.h>
using namespace std;
#include <time.h>
#include <stdlib.h>

//----------------------------------------------------------------
#ifndef ONLINE_JUDGE
#include "template.cpp"
#else
#define debug(...)
#define debugArr(...)
#endif
//----------------------------------------------------------------

#define int long long int
#define pb push_back
#define mp make_pair
#define all(v) v.begin(),v.end()
#define endl '\n'
#define getunique(v) {sort(v.begin(),v.end());v.erase(unique(v.begin(),v.end()),v.end());}
#define getunique1(v) {v.erase(unique(v.begin(),v.end()),v.end());}

const int mod = 1000000007;
//const int mod = 998244353;

long double tick(){static clock_t oldt; clock_t newt=clock();
    long double diff = 1.0L*(newt-oldt)/CLOCKS_PER_SEC;oldt = newt; return diff;}

void fast_io(){
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);
    return;
}

signed main() {
    int t;
    cin>>t;
    while (t--) {
        int n;
        cin>>n;
        vector<int> a(n);
        vector<int> ind1,ind2,ind3;
        for(int i=0;i<n;i++){
            cin>>a[i];
            if(a[i]==1) ind1.pb(i);
            if(a[i]==2) ind2.pb(i);
            if(a[i]==3) ind3.pb(i);
        }
        int ans=0;
        for(int i=0;i<a.size();i++){
            int r1,r2,r3;
            if(ind1[0]<i) r1=a.size()-1;
            else{
                r1=ind1[ind1.size()-1]-1;
            }
            if(ind2[0]<i) r2=a.size()-1;
            else{
                r2=ind2[ind2.size()-1]-1;
            }
            if(ind3[0]<i) r3=a.size()-1;
            else{
                r3=ind3[ind3.size()-1]-1;
            }
            int range = min(r1,min(r2,r3));
            int r;
            //check if all of 1,2,3 appear atleast once in indexes i to range using upper_bound on ind1,ind2,ind3
            auto it1 = lower_bound(all(ind1),i);
            auto it2 = lower_bound(all(ind2),i);
            auto it3 = lower_bound(all(ind3),i);
            if(it1!=ind1.end() && *it1<=range){
                //ignore 2s range
                r2=a.size()-1;
            }
            if(it2!=ind2.end() && *it2<=range){
                //ignore 3s range
                r3=a.size()-1;
            }
            if(it3!=ind3.end() && *it3<=range){
                //ignore 1s range
                r1=a.size()-1;
            }
            range = min(r1,min(r2,r3));
            if(it1!=ind1.end() && it2!=ind2.end() && it3!=ind3.end()){
                int start = max(*it1,max(*it2,*it3));
                if(range>=start) ans+=a.size()-i;
                if(range<start) ans+=a.size()-start + range-i+1;
            }
            else{
                ans+=range-i+1;
            }
        }
        cout<<ans<<endl;
    }
    return 0;
}


I haven’t read your code, but I ran it through my stress testing (I implemented a similar solution and had a lot of bugs). Here is a test for which your code fails:

1
4
2 3 1 3 

Expected answer is 6, but you output 5.
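The stress testing mentioned above can be sketched like this: compare a naive O(N^3) count against the editorial's formula on many small random arrays (each required to contain 1, 2 and 3) and report the first mismatch. The function names, the length bound, and the fixed seed are illustrative assumptions; to check your own code, swap your function in for `fastGood`.

```cpp
#include <cassert>
#include <random>
#include <set>
#include <vector>

// Naive answer: try every subarray and count the distinct values afterwards.
long long naiveGood(const std::vector<int>& a) {
    const int n = static_cast<int>(a.size());
    long long good = 0;
    for (int l = 0; l < n; ++l)
        for (int r = l; r < n; ++r) {
            std::set<int> values;
            for (int i = 0; i < n; ++i) {
                int v = a[i];
                if (l <= i && i <= r) v = (v == 3) ? 1 : v + 1;  // 1->2, 2->3, 3->1
                values.insert(v);
            }
            if (values.size() == 3) ++good;
        }
    return good;
}

// Fast answer from the editorial's formula (0-indexed, sentinels -1 and n).
long long fastGood(const std::vector<int>& a) {
    const int n = static_cast<int>(a.size());
    long long bad = 0;
    for (int v = 1; v <= 3; ++v) {
        const int forbidden = (v == 1) ? 3 : v - 1;
        int lo = -1, hi = -1;
        for (int i = 0; i < n; ++i)
            if (a[i] == v) { if (lo == -1) lo = i; hi = i; }
        assert(lo != -1);
        bool ok = true;
        for (int i = lo; i <= hi; ++i) ok = ok && a[i] != forbidden;
        if (!ok) continue;
        int x = lo - 1, y = hi + 1;
        while (x >= 0 && a[x] != forbidden) --x;
        while (y < n && a[y] != forbidden) ++y;
        bad += 1LL * (lo - x) * (y - hi);
    }
    return 1LL * n * (n + 1) / 2 - bad;
}

// Returns the first random array on which the two answers differ,
// or an empty vector if none is found. Requires maxLen >= 3.
std::vector<int> findCounterexample(int trials, int maxLen, unsigned seed) {
    std::mt19937 rng(seed);
    for (int t = 0; t < trials; ++t) {
        const int n = std::uniform_int_distribution<int>(3, maxLen)(rng);
        std::vector<int> a(n);
        for (int& v : a) v = std::uniform_int_distribution<int>(1, 3)(rng);
        if (std::set<int>(a.begin(), a.end()).size() < 3) continue;  // invalid input
        if (naiveGood(a) != fastGood(a)) return a;
    }
    return {};
}
```

Small lengths keep the naive side fast and also keep any failing array short enough to debug by hand, as with [2, 3, 1, 3] here.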

Thanks! I get it now.

Nice solution, I totally understand it. Is there any other way to solve this question?

I guess you can also use DP

Note that it’s impossible for more than one value to be eliminated (do you see why?), so there’s no overcounting going on here.

This doesn’t seem tremendously obvious, might you elaborate?


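I included the proof in the comments of my code, here it is:

Lemma: the resulting array contains at least two distinct values.
Proof: if the selected subarray only contained one value, then the untouched part contains at least two distinct values. If the selected subarray contained at least two distinct values, then their images under the operation are also distinct.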


Loved this problem! I thoroughly enjoyed solving it. Kudos to the setter :clap:


To add to this: the constraint “A contains each of 1, 2, 3 at least once” was important for exactly this reason.