TRICOUNT2 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: raysh07
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

Easy

PREREQUISITES:

Binary search (on sets)

PROBLEM:

You’re given an array A.
For each prefix of A, count the number of integers X such that (A_i, A_j, X) can be the side lengths of a non-degenerate triangle, for some distinct indices i, j.

EXPLANATION:

From the easy version, we already know how to solve the problem for a fixed array: each pair of adjacent (in sorted order) elements x \leq y defines a range of valid X — the triangle inequality requires y - x < X < x + y, i.e. X \in [y-x+1, x+y-1] — and we want the size of the union of these ranges.

Since we now have to solve the problem for each prefix of A, let’s see how the answer changes when moving from one prefix to the next, i.e., how the answer for [A_1, A_2, \ldots, A_i] changes when A_{i+1} is added.

We care only about intervals defined by adjacent elements, so inserting A_{i+1} will create at most two new pairs of elements we care about.
Specifically, let L be the largest element not exceeding A_{i+1} (this is A_{i+1} itself if that value is already present), and R be the smallest element larger than A_{i+1}. We only care about the ranges defined by (L, A_{i+1}) and (A_{i+1}, R).

So, we need to be able to quickly add these two intervals into our set, and then recompute the total number of integers covered by all added intervals.
In detail,

  • Maintain a sorted set S that stores the union of all intervals added so far.
    In particular, S will store pairs of elements corresponding to disjoint maximal intervals.
  • When adding a new interval, say [x, y], to S:
    1. First, look at the first interval in S that starts \geq x.
      If this interval intersects [x, y], delete it from S, extend [x, y] to include it, and repeat; otherwise stop.
    2. Then, there will be at most one interval that starts before x but intersects [x, y]; if such an interval exists, delete it and extend [x, y] again.
    3. In the end, insert [x, y] into S.
      Note that x and y are not necessarily the same values they started out as: they now represent a union of several intervals.

Each new element contributes at most two intervals, so at most 2N intervals are processed; the algorithm inserts each of them into S once, and deletes each one at most once.
So, in total we perform \mathcal{O}(N) insertions and deletions.
We also need to, several times, find the first interval that starts \geq x.
The appropriate data structure for this is a sorted set (std::set in C++, TreeSet in Java), which allows for quick insertion/deletion/binary search.

The overall complexity is \mathcal{O}(N\log N).

TIME COMPLEXITY:

\mathcal{O}(N\log N) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
// Every `int` below is 64-bit; this is why main is declared as `int32_t main`.
#define int long long
// Large sentinel value; not referenced in the visible code.
#define INF (int)1e18

// 64-bit Mersenne Twister seeded from the clock; not referenced in the visible code.
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

// Solves one test case.
// Reads N and the array A, then, for every prefix of length >= 2, prints the
// number of integers X such that (A_i, A_j, X) are the side lengths of a
// non-degenerate triangle for some distinct indices i, j of that prefix.
// Answers are space separated and the line ends with '\n'; nothing is printed
// when N < 2.
void Solve() 
{
    int n; cin >> n;
    
    vector <int> a(n);
    for (auto &x : a) cin >> x;
    
    // A single element forms no pair, so there is nothing to print; returning
    // here also avoids touching a[0] when n == 0.
    if (n < 2) return;
    
    set <int> st;   // values of the current prefix
    int ans = 0;    // number of integers covered by the intervals in b
    st.insert(a[0]);
    
    // Disjoint closed intervals [first, second] whose union is the set of
    // valid third sides X for the current prefix.
    set <pair<int, int>> b;
    
    // Adds the range of valid X for the pair of side lengths (x, y) to b,
    // merging it with every stored interval it overlaps, and updates ans.
    auto add = [&](int x, int y){
        if (x > y) swap(x, y);
        
        // Sides x <= y and X form a non-degenerate triangle iff y - x < X < x + y.
        int lo = y - x + 1, hi = y + x - 1;
        // Empty range (only when x <= 0): inserting it would corrupt ans.
        if (lo > hi) return;
        
        // Merge with the last interval that starts before lo, if it reaches lo.
        // Stored ends are >= 1, so {lo, 0} sorts before any interval starting at lo.
        auto id = b.upper_bound({lo, 0});
        if (id != b.begin()){
            --id;
            if (id->second >= lo){
                ans -= id->second - id->first + 1;
                lo = id->first;
                hi = max(hi, id->second);
                b.erase(id);
            }
        }
        
        // Absorb every interval that starts inside [lo, hi].
        while (true){
            auto it = b.lower_bound({lo, 0});
            if (it == b.end() || it->first > hi) break;
            
            hi = max(hi, it->second);
            ans -= it->second - it->first + 1;
            b.erase(it);
        }
        
        ans += hi - lo + 1;
        b.insert({lo, hi});
    };
    
    for (int i = 1; i < n; i++){
        // id -> smallest value > a[i]; its predecessor is the largest value
        // <= a[i] (a[i] itself if that value was already present).
        auto id = st.upper_bound(a[i]);
        if (id != st.end()){
            add(*id, a[i]);
        }
        if (id != st.begin()){
            --id;
            add(*id, a[i]);
        }
        
        st.insert(a[i]);
        
        cout << ans << " \n"[i + 1 == n];
    } 
}

// Program entry: reads the number of test cases, runs Solve() for each one,
// and reports the elapsed time (high_resolution_clock) on stderr.
int32_t main() 
{
    const auto started = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);

    int t = 1;
    cin >> t;
    while (t-- > 0) {
        Solve();
    }

    const auto finished = std::chrono::high_resolution_clock::now();
    const auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(finished - started);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

using namespace std;
using namespace __gnu_pbds;

// GNU pb_ds red-black tree with order-statistics node updates (not used in the visible code).
template<typename T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
// Short type names.
typedef long long int ll;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;

// Shorthand macros. Note that `endl` is redefined as '\n', so it does not flush.
#define fastio ios_base::sync_with_stdio(false); cin.tie(NULL)
#define pb push_back
#define endl '\n'
// NOTE(review): `a` is not parenthesized, so e.g. sz(*p) would not expand as intended.
#define sz(a) (int)a.size()
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
#define conts continue
// Ceiling of x / y for non-negative x and positive y.
#define ceil2(x,y) ((x+y-1)/(y))
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
#define yes cout << "Yes" << endl
#define no cout << "No" << endl

// Loop shorthands: rep -> i in [0, n), rep1 -> i in [1, n],
// rev -> i from s down to e, trav -> range-for by reference.
#define rep(i,n) for(int i = 0; i < n; ++i)
#define rep1(i,n) for(int i = 1; i <= n; ++i)
#define rev(i,s,e) for(int i = s; i >= e; --i)
#define trav(i,a) for(auto &i : a)

// Lowers `a` to `b` when `b` is strictly smaller (equivalent to a = min(a, b)).
template<typename T>
void amin(T &a, T b) {
    if (b < a) {
        a = b;
    }
}

// Raises `a` to `b` when `b` is strictly larger (equivalent to a = max(a, b)).
template<typename T>
void amax(T &a, T b) {
    if (a < b) {
        a = b;
    }
}

// With LOCAL defined, debug(...) comes from a local helper header; otherwise it
// expands to the constant 42, so debug calls have no effect.
#ifdef LOCAL
#include "debug.h"
#else
#define debug(...) 42
#endif

/*



*/

// Template constants; none of them is referenced in the visible code.
const int MOD = 1e9 + 7;
const int N = 1e5 + 5;
const int inf1 = int(1e9) + 5;
const ll inf2 = ll(1e18) + 5;

// Solves one test case: reads N and A_1..A_N and, for every prefix of length
// >= 2, prints the number of integers X such that (A_i, A_j, X) are the sides
// of a non-degenerate triangle for some distinct indices i, j of the prefix.
// Answers are space separated (with a trailing space) and followed by '\n'.
// test_case is not used.
void solve(int test_case)
{
    (void)test_case;
    long long n; cin >> n;
    vector<long long> a(n + 5);
    for (int i = 1; i <= n; ++i) cin >> a[i];

    // Pairwise disjoint closed intervals [l, r] of valid X, ordered by l.
    set<pair<long long, long long>> segs;
    set<long long> st;   // distinct values of the current prefix
    long long ans = 0;   // number of integers covered by segs

    // Inserts [l, r] into segs, merging it with every stored interval it overlaps.
    auto ins = [&](long long l, long long r) {
        if (l > r) return;
        // First interval starting strictly after l, whatever its end.
        auto it = segs.upper_bound({l, numeric_limits<long long>::max()});

        // The last interval starting at or before l is the only earlier one that can reach l.
        if (it != segs.begin()) {
            auto p = prev(it);
            if (p->second >= l) {
                l = p->first;
                r = max(r, p->second);
                ans -= p->second - p->first + 1;
                segs.erase(p);   // does not invalidate `it`
            }
        }

        // Absorb every following interval that starts inside [l, r].
        while (it != segs.end() && it->first <= r) {
            r = max(r, it->second);
            ans -= it->second - it->first + 1;
            it = segs.erase(it);
        }

        segs.insert({l, r});
        ans += r - l + 1;
    };

    // Adds the valid third sides for a pair x <= y: y - x < X < x + y.
    auto go = [&](long long x, long long y) {
        ins(y - x + 1, x + y - 1);
    };

    for (int i = 1; i <= n; ++i) {
        const bool seen = st.count(a[i]) > 0;
        auto it = st.insert(a[i]).first;
        // An equal earlier value pairs with a[i]: X in [1, 2*a[i] - 1].
        if (seen) go(a[i], a[i]);
        if (it != st.begin()) go(*prev(it), a[i]);
        if (next(it) != st.end()) go(a[i], *next(it));
        if (i >= 2) cout << ans << " ";
    }

    cout << '\n';
}

int main()
{
    fastio;

    int t = 1;
    cin >> t;

    rep1(i, t) {
        solve(i);
    }

    return 0;
}
Editorialist's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
// 64-bit integer alias; not referenced in the visible code.
using ll = long long int;
// 64-bit Mersenne Twister seeded from the clock; not referenced in the visible code.
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

// Reads T test cases. For each one, reads N and the array, and for every prefix
// of length >= 2 prints the number of integers X such that (A_i, A_j, X) form a
// non-degenerate triangle for some distinct indices i, j of the prefix.
int main()
{
    ios::sync_with_stdio(false); cin.tie(0);
    
    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        // Pairwise disjoint half-open intervals [L, R) of valid X. Touching
        // intervals are merged as well, so consecutive stored intervals have a gap.
        set<array<long long, 2>> intervals;
        // 64-bit: r + l and the union size exceed INT_MAX once A_i > INT_MAX / 2.
        long long ans = 0;
        // Inserts [L, R), merging every stored interval that overlaps or touches it.
        auto ins = [&] (long long L, long long R) {
            if (L >= R) return;
            // Absorb the intervals at or after {L, R} that start no later than R.
            auto it = intervals.lower_bound({L, R});
            while (it != intervals.end()) {
                auto [l, r] = *it;
                if (l > R) break;
                R = max(R, r);
                ans -= r-l;
                it = intervals.erase(it);
            }
            // Only the interval just before that position can still reach L.
            if (it != intervals.begin() && (--it)->at(1) >= L) {
                auto [l, r] = *it;
                L = min(L, l);
                R = max(R, r);
                ans -= r-l;
                intervals.erase(it);
            }
            ans += R-L;
            intervals.insert({L, R});
        };
        
        set<long long> pts;
        // Sides l <= r and X form a non-degenerate triangle iff r - l < X < r + l,
        // i.e. X in [r - l + 1, r + l).
        auto process = [&] (long long l, long long r) {
            ins(r-l+1, r+l);
        };
        for (int i = 0; i < n; ++i) {
            long long x; cin >> x;
            const bool seen = pts.count(x) > 0;
            auto it = pts.insert(x).first;

            // Two equal values x allow X in [1, 2x).
            if (seen) process(x, x);
            if (it != begin(pts)) process(*prev(it), x);
            if (next(it) != end(pts)) process(x, *next(it));
            if (i) cout << ans << ' ';
        }
        cout << '\n';
    }
}
1 Like