PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: raysh07
Tester: sushil2006
Editorialist: iceknight1093
DIFFICULTY:
Simple
PREREQUISITES:
Sorting
PROBLEM:
You’re given an array A of length N.
Count the number of integers X such that (A_i, A_j, X) can be the side lengths of a non-degenerate triangle, for some distinct indices i, j.
EXPLANATION:
Recall the triangle inequality: (A_i, A_j, X) can be the sides of a non-degenerate triangle if and only if the sum of any two of them is strictly larger than the third.
In fact, it suffices to check that the sum of the two smaller values is larger than the maximum one (for positive lengths, the other two inequalities then hold automatically).
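Here is a minimal Python sketch of this check. The function name `is_nondegenerate_triangle` is illustrative and does not appear in the reference solutions. It sorts the three lengths and compares the two smaller ones against the largest.

```python
def is_nondegenerate_triangle(a: int, b: int, c: int) -> bool:
    """Return True if positive lengths a, b, c form a non-degenerate triangle."""
    x, y, z = sorted((a, b, c))
    # Only the largest side can violate the inequality; for positive
    # lengths the other two inequalities hold automatically.
    return x + y > z


print(is_nondegenerate_triangle(3, 4, 5))  # True
print(is_nondegenerate_triangle(1, 2, 3))  # False: degenerate, since 1 + 2 == 3
```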
Let’s fix a value of X, and see which A_i and A_j should be chosen.
For convenience, let’s sort the array A in ascending order, and consider only pairs with i \lt j, so that A_i \leq A_j.
There are three possibilities:
- A_i \leq A_j \leq X, i.e., X is the maximum.
  Here, we want A_i + A_j to be as large as possible.
  For that, it’s clearly best to choose A_i and A_j to be the two largest elements that do not exceed X.
  Note that this means j = i+1 will hold.
- A_i \lt X \lt A_j.
  Here, we want A_i + X \gt A_j.
  It’s best for A_j to be as small as possible, and A_i to be as large as possible.
  So, we choose A_j to be the smallest element larger than X, and A_i to be the largest element smaller than X.
  Again, note that j = i+1, unless some element A_k equals X. In that case, pairing A_k with its predecessor A_{k-1} in sorted order already gives a valid triangle (A_{k-1}, X, X). The predecessor exists because A_i \lt X.
- X \leq A_i \leq A_j.
  This case is a bit trickier: it’s not necessary that we choose A_i and A_j as close to X as possible.
  Nonetheless, the only non-trivial condition here is X + A_i \gt A_j, i.e., A_j - A_i \lt X. So if i is fixed, A_j should be as small as possible, and choosing j = i+1 is optimal.
Observe the common thread across all three cases: if it is possible to make X the side of a triangle, it will definitely be possible to do so using some consecutive pair of elements (A_i, A_{i+1}).
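This observation is easy to sanity-check by brute force. The sketch below enumerates every candidate X for small random arrays. It confirms that trying all pairs and trying only adjacent pairs of the sorted array produce the same set of valid X. The helper names `valid_x_all_pairs` and `valid_x_adjacent_pairs` are illustrative and do not come from the original solutions.

```python
import random
from itertools import combinations


def valid_x_all_pairs(a, limit):
    """Brute force: every X in [1, limit] forming a triangle with SOME pair i < j."""
    return {x for x in range(1, limit + 1)
            for p, q in combinations(a, 2)
            if x < p + q and p < x + q and q < x + p}


def valid_x_adjacent_pairs(a, limit):
    """Same check, but only with adjacent pairs of the sorted array."""
    s = sorted(a)
    return {x for x in range(1, limit + 1)
            for p, q in zip(s, s[1:])
            if x < p + q and p < x + q and q < x + p}


random.seed(1)
for _ in range(500):
    arr = [random.randint(1, 20) for _ in range(random.randint(2, 6))]
    limit = 2 * max(arr)  # any valid X is below A_i + A_j <= 2 * max(A)
    assert valid_x_all_pairs(arr, limit) == valid_x_adjacent_pairs(arr, limit)
print("adjacent pairs suffice on all random tests")
```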
So, let’s switch our perspective.
Let’s fix the pair (A_i, A_{i+1}), and see which values of X can be triangle sides.
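The three inequalities that need to be satisfied by X are:
- A_i + A_{i+1} \gt X
- A_i + X \gt A_{i+1}
- A_{i+1} + X \gt A_i
Among them, the third is trivially true since A_{i+1} \geq A_i and X \geq 1, so let’s focus only on the first two.
Combined, they tell us that
A_{i+1} - A_i \lt X \lt A_i + A_{i+1}
That is, we obtain a range of valid X. Since X is an integer, it must lie in [A_{i+1} - A_i + 1, A_i + A_{i+1} - 1].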
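As a small illustration, the range for one adjacent pair can be computed directly. The function name `valid_range` is hypothetical, and the sketch assumes positive integers with p \leq q.

```python
def valid_range(p, q):
    """Inclusive range of integers X with q - p < X < p + q.

    Assumes 1 <= p <= q, i.e. (p, q) is an adjacent pair (A_i, A_{i+1})
    of the sorted array.
    """
    return q - p + 1, p + q - 1


print(valid_range(3, 5))  # (3, 7)
print(valid_range(4, 4))  # (1, 7): equal sides allow any X below 2 * 4
```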
By considering each adjacent pair of elements (in sorted order), we obtain N-1 ranges of X.
Any X that lies in any of these ranges is valid.
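A minimal sketch of building these N-1 ranges, assuming A holds positive integers. The name `adjacent_ranges` is illustrative.

```python
def adjacent_ranges(a):
    """One inclusive range (lo, hi) of valid X per adjacent pair of sorted(a)."""
    s = sorted(a)
    # zip(s, s[1:]) yields the N - 1 adjacent pairs (A_i, A_{i+1}).
    return [(q - p + 1, p + q - 1) for p, q in zip(s, s[1:])]


print(adjacent_ranges([7, 2, 5]))  # [(4, 6), (3, 11)]
```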
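So, the answer we’re looking for is simply the number of integers in the union of all these ranges.
Computing this is a fairly standard task, and is easily done with the help of sorting.
For instance, sort all the ranges by their left endpoints, then process them in this order while maintaining a current merged segment.
If the next range starts no later than the current segment’s right endpoint, extend that right endpoint to the larger of the two. Otherwise, add the current segment’s length to the answer and start a new segment from the next range.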
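Here is a minimal Python sketch of this sweep. The function name `count_union` is ours. It counts the integers covered by a list of inclusive ranges, and feeding it the N-1 adjacent-pair ranges gives the answer.

```python
def count_union(ranges):
    """Number of integers covered by a list of inclusive ranges (l, r) with l <= r."""
    total = 0
    cur_l, cur_r = None, None
    for l, r in sorted(ranges):  # sorted by left endpoint (ties by right endpoint)
        if cur_r is not None and l <= cur_r:
            # Overlaps the current merged segment: extend it.
            cur_r = max(cur_r, r)
        else:
            # Disjoint: close the current segment and start a new one.
            if cur_r is not None:
                total += cur_r - cur_l + 1
            cur_l, cur_r = l, r
    if cur_r is not None:
        total += cur_r - cur_l + 1
    return total


print(count_union([(4, 6), (3, 11)]))          # 9
print(count_union([(1, 3), (5, 5), (2, 4)]))   # 5
```

Ranges that touch but do not overlap, such as [1, 4] and [5, 5], stay as separate segments. This does not affect the count, because their lengths simply add up.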
An alternate algorithm can also be found here.
TIME COMPLEXITY:
\mathcal{O}(N\log N) per testcase.
CODE:
Author's code (C++)
```cpp
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

void Solve()
{
    int n; cin >> n;
    vector <int> a(n);
    for (auto &x : a) cin >> x;

    set <int> st;
    int ans = 0;
    const int A = 2.1e9;
    st.insert(a[0]);

    set <pair<int, int>> b;
    auto add = [&](int x, int y){
        if (x > y) swap(x, y);
        int lo = y - x + 1, hi = y + x - 1;
        // need to add this interval
        // check last interval in b
        auto id = b.upper_bound({lo, 0});
        if (id != b.begin()){
            --id;
            if ((*id).second >= lo){
                ans -= (*id).second - (*id).first + 1;
                lo = (*id).first;
                hi = max(hi, (*id).second);
                b.erase(id);
            }
        }
        while (true){
            auto id = b.lower_bound({lo, 0});
            if (id == b.end()) break;
            if ((*id).first > hi){
                break;
            }
            hi = max(hi, (*id).second);
            ans -= (*id).second - (*id).first + 1;
            b.erase(id);
        }
        ans += hi - lo + 1;
        b.insert({lo, hi});
    };

    for (int i = 1; i < n; i++){
        auto id = st.upper_bound(a[i]);
        if (id != st.end()){
            add(*id, a[i]);
        }
        if (id != st.begin()){
            --id;
            add(*id, a[i]);
        }
        st.insert(a[i]);
    }
    cout << ans << "\n";
}

int32_t main()
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in", "r", stdin);
    // freopen("out", "w", stdout);
    cin >> t;
    for(int i = 1; i <= t; i++)
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n";
    return 0;
}
```
Tester's code (C++)
```cpp
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

using namespace std;
using namespace __gnu_pbds;

template<typename T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
typedef long long int ll;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;

#define fastio ios_base::sync_with_stdio(false); cin.tie(NULL)
#define pb push_back
#define endl '\n'
#define sz(a) (int)a.size()
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
#define conts continue
#define ceil2(x,y) ((x+y-1)/(y))
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
#define yes cout << "Yes" << endl
#define no cout << "No" << endl
#define rep(i,n) for(int i = 0; i < n; ++i)
#define rep1(i,n) for(int i = 1; i <= n; ++i)
#define rev(i,s,e) for(int i = s; i >= e; --i)
#define trav(i,a) for(auto &i : a)

template<typename T>
void amin(T &a, T b) {
    a = min(a,b);
}

template<typename T>
void amax(T &a, T b) {
    a = max(a,b);
}

#ifdef LOCAL
#include "debug.h"
#else
#define debug(...) 42
#endif

/*
*/

const int MOD = 1e9 + 7;
const int N = 1e5 + 5;
const int inf1 = int(1e9) + 5;
const ll inf2 = ll(1e18) + 5;

void solve(int test_case)
{
    ll n; cin >> n;
    vector<ll> a(n+5);
    rep1(i,n) cin >> a[i];
    sort(a.begin()+1,a.begin()+n+1);

    vector<pll> segs;
    rep1(i,n-1){
        ll x = a[i], y = a[i+1];
        segs.pb({y,x+y-1});
        segs.pb({max(x,y-x+1),y});
        segs.pb({max(y-x+1,1ll),x});
    }

    sort(all(segs));
    ll ans = 0;
    ll mxr = -1;
    for(auto [l,r] : segs){
        if(l > r) conts;
        if(r > mxr){
            ans += r-max(l,mxr+1)+1;
            mxr = r;
        }
    }

    cout << ans << endl;
}

int main()
{
    fastio;

    int t = 1;
    cin >> t;

    rep1(i, t) {
        solve(i);
    }

    return 0;
}
```
Editorialist's code (PyPy3)
```python
for _ in range(int(input())):
    n = int(input())
    a = sorted(list(map(int, input().split())))

    # Each adjacent pair (l, r) allows X in [r - l + 1, l + r - 1];
    # it is stored as two pieces split at r, whose union is that range.
    intervals = []
    def calc(l, r):
        d = r-l
        intervals.append((d+1, r))
        intervals.append((r, l+r-1))
    for i in range(n-1):
        calc(a[i], a[i+1])

    # Merge intervals sorted by left endpoint and sum the merged lengths.
    intervals.sort()
    ans = 0
    L, R = -1, -1
    for l, r in intervals:
        if l > R:
            ans += max(0, R-max(L, 1)+1)
            L, R = l, r
        else: R = max(R, r)
    ans += max(0, R-max(L, 1)+1)
    print(ans)

# a[i] + x > a[i+1]
# x > a[i+1] - a[i]
# a[i] + a[i+1] > x
```