PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Setter: Vinit
Tester: Radoslav Dimitrov
Editorialist: Taranpreet Singh
DIFFICULTY
Medium
PREREQUISITES
Stack and Fenwick Tree or Segment Tree
PROBLEM
Given a sequence A of length N, Chef defines a contiguous subsequence A_{l \ldots r} as fruitful if |A_l-A_r| = max(A_{l \ldots r})-min(A_{l \ldots r}).
Find the number of fruitful contiguous subsequences.
Note: A_{l \ldots r} denotes the contiguous subsequence A_l, A_{l+1} \ldots A_{r-1}, A_r.
QUICK EXPLANATION
- A contiguous subsequence A_{l \ldots r} is fruitful if and only if A_l and A_r are its minimum and maximum, in either order: either A_l = min(A_{l \ldots r}) and A_r = max(A_{l \ldots r}), or A_l = max(A_{l \ldots r}) and A_r = min(A_{l \ldots r}).
- Assuming WLOG that A_l \leq A_r, we need to count the contiguous subsequences in which A_l is the minimum and A_r is the maximum.
- Let prev_i = max(j : j < i \land A_j > A_i) and nxt_i = min(j : j > i \land A_j < A_i). We need the number of pairs (l, r) such that prev_r < l \leq r < nxt_l. Both prev and nxt can be computed using a stack.
- To count the valid pairs (l, r), we iterate over l from left to right, insert into a BIT every r with prev_r < l, and query the number of inserted r in the range [l, nxt_l).
- For the A_l > A_r case, we reverse the array and repeat the above. Subsequences with A_l = A_r are counted in both passes, so we subtract them once.
EXPLANATION
The first subtask is trivial, so we jump directly to the final subtask.
First off, let’s assume l \leq r and A_l \leq A_r. We can handle the case where A_l > A_r by reversing the sequence, repeating the following process, and excluding the double-counted subsequences.
We have min(A_{l \ldots r}) \leq A_l \leq A_r \leq max(A_{l \ldots r}). Hence, max(A_{l \ldots r})-min(A_{l \ldots r}) = (max(A_{l \ldots r}) - A_r) + (A_r-A_l) + (A_l - min(A_{l \ldots r})), where each of the three terms is non-negative.
For A_r-A_l = max(A_{l \ldots r})-min(A_{l \ldots r}), we need max(A_{l \ldots r})-A_r = 0 and A_l -min(A_{l \ldots r}) = 0, which implies that A_l = min(A_{l \ldots r}) and A_r = max(A_{l \ldots r}).
Hence, we need to compute the number of pairs (l, r) with l \leq r and A_l \leq A_r such that A_l = min(A_{l \ldots r}) and A_r = max(A_{l \ldots r}).
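The definition can be evaluated directly on small inputs, which is handy for checking a faster solution. The following is a minimal brute-force sketch in C++; the function name countFruitfulBrute is an assumption made for illustration and does not appear in the solutions below. It runs in O(N^2), so it is only suitable for small N.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <vector>

// Hypothetical helper: counts fruitful subarrays straight from the definition
// |A_l - A_r| = max(A_l..A_r) - min(A_l..A_r), in O(N^2) time.
std::int64_t countFruitfulBrute(const std::vector<long long>& a) {
    std::int64_t count = 0;
    for (std::size_t l = 0; l < a.size(); ++l) {
        long long lo = a[l], hi = a[l];  // running min and max of a[l..r]
        for (std::size_t r = l; r < a.size(); ++r) {
            lo = std::min(lo, a[r]);
            hi = std::max(hi, a[r]);
            if (std::llabs(a[l] - a[r]) == hi - lo) ++count;
        }
    }
    return count;
}
```

For A = [2, 1, 3, 3] this returns 8: the four single elements, plus [2, 1], [1, 3], [3, 3] and [1, 3, 3].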
Let’s count the number of fruitful subsequences starting at l, for each l. Let nxt > l denote the first position after l with A_{nxt} < A_l. From position nxt onwards, A_l is no longer the minimum, so we have the constraint r < nxt. By the same reasoning, let prev < r denote the last position before r with A_{prev} > A_r. We have the constraint prev < l.
We compute prev and nxt for each position. This can be done using a stack in O(N) time (search for Next Greater Element/Previous Greater Element for details).
We have prev_i = max(j : j < i \land A_j > A_i) and nxt_i = min(j: j > i \land A_j < A_i). If no such position exists, prev_i is taken as one before the first index and nxt_i as one past the last index, as the solutions below do.
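The stack computation of prev and nxt can be sketched as follows. This is a minimal illustration assuming 0-based indices with sentinels -1 (no previous greater element) and n (no next smaller element); the function name computePrevNext is hypothetical.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Returns {prev, nxt} where
//   prev[i] = last  j < i with a[j] > a[i], or -1 if there is none,
//   nxt[i]  = first j > i with a[j] < a[i], or  n if there is none.
std::pair<std::vector<int>, std::vector<int>>
computePrevNext(const std::vector<long long>& a) {
    const int n = static_cast<int>(a.size());
    std::vector<int> prev(n, -1), nxt(n, n), st;

    // Left to right: the stack keeps indices with strictly decreasing values.
    for (int i = 0; i < n; ++i) {
        while (!st.empty() && a[st.back()] <= a[i]) st.pop_back();
        if (!st.empty()) prev[i] = st.back();
        st.push_back(i);
    }

    st.clear();
    // Right to left: the stack keeps indices whose values strictly increase
    // from the top of the stack towards the bottom being smaller.
    for (int i = n - 1; i >= 0; --i) {
        while (!st.empty() && a[st.back()] >= a[i]) st.pop_back();
        if (!st.empty()) nxt[i] = st.back();
        st.push_back(i);
    }
    return {std::move(prev), std::move(nxt)};
}
```

For A = [2, 1, 3, 3] this gives prev = [-1, 0, -1, -1] and nxt = [1, 4, 4, 4]. Each index is pushed and popped at most once per pass, which is where the O(N) bound comes from.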
For a fruitful contiguous subsequence (l, r), we have prev_r < l \leq r < nxt_l.
Here, we can simply use a Merge Sort Tree to solve the problem in O(N*log^2(N)) time: build the Merge Sort Tree on nxt and, for each r, query the index range (prev_r, r] for the number of values greater than r.
But there is a more efficient approach. Let’s fix l and consider all r with prev_r < l. Among them, we need the number of r with l \leq r < nxt_l. For this we need a range-query structure (a Fenwick Tree is the best fit here).
Let’s sort all pairs (prev_r, r) by prev_r. Considering l from left to right, we keep a pointer moving to the right over the sorted pairs, which covers all r with prev_r < l. We insert each such r into the Fenwick Tree and query the number of inserted values in the range [l, nxt_l).
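The sweep described above can be sketched as follows. This is a minimal illustration, not one of the reference solutions: it takes prev and nxt as inputs (0-based, with sentinels -1 and n, which is an assumption), and it groups each r into a bucket indexed by prev_r + 1 instead of sorting the pairs, which has the same effect. The names Fenwick and countPairs are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Minimal Fenwick tree over positions 0..n-1 (point add, prefix sum).
struct Fenwick {
    std::vector<int> tree;
    explicit Fenwick(int n) : tree(n + 1, 0) {}
    void add(int pos, int delta) {
        for (int i = pos + 1; i < static_cast<int>(tree.size()); i += i & -i)
            tree[i] += delta;
    }
    // Sum over positions [0, pos]; returns 0 when pos < 0.
    int prefix(int pos) const {
        int sum = 0;
        for (int i = pos + 1; i > 0; i -= i & -i) sum += tree[i];
        return sum;
    }
};

// Counts pairs (l, r) with prev[r] < l <= r < nxt[l].
// Assumes -1 <= prev[r] < r and l < nxt[l] <= n for every index.
std::int64_t countPairs(const std::vector<int>& prev,
                        const std::vector<int>& nxt) {
    assert(prev.size() == nxt.size());
    const int n = static_cast<int>(prev.size());

    // activateAt[l] lists every r that becomes valid once l exceeds prev[r].
    std::vector<std::vector<int>> activateAt(n);
    for (int r = 0; r < n; ++r) activateAt[prev[r] + 1].push_back(r);

    Fenwick active(n);
    std::int64_t total = 0;
    for (int l = 0; l < n; ++l) {
        for (int r : activateAt[l]) active.add(r, 1);
        // Number of active r inside [l, nxt[l] - 1].
        total += active.prefix(nxt[l] - 1) - active.prefix(l - 1);
    }
    return total;
}
```

For A = [2, 1, 3, 3], with prev = [-1, 0, -1, -1] and nxt = [1, 4, 4, 4], this returns 7. An r inserted earlier may lie to the left of the current l, which is why the query subtracts the prefix up to l - 1 rather than using a plain prefix count.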
Excluding double-counted subsequences
Consider pairs (l, r) which have min(A_{l \ldots r}) = A_l = A_r = max(A_{l \ldots r}). These subsequences are counted in both passes, so we need to subtract them once. They are exactly the subarrays whose values are all equal, so they are easy to count: a maximal run of k equal values contributes k(k+1)/2 of them.
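Counting the subarrays with all values equal takes a single pass. The sketch below is an illustration with a hypothetical function name, countConstantSubarrays; it adds, for every position, the number of constant subarrays ending there, which sums to k(k+1)/2 over each maximal run of length k.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Counts subarrays whose elements are all equal (single elements included).
std::int64_t countConstantSubarrays(const std::vector<long long>& a) {
    std::int64_t total = 0;
    std::int64_t run = 0;  // length of the run of equal values ending at i
    for (std::size_t i = 0; i < a.size(); ++i) {
        run = (i > 0 && a[i] == a[i - 1]) ? run + 1 : 1;
        total += run;  // exactly `run` constant subarrays end at position i
    }
    return total;
}
```

For A = [2, 1, 3, 3] this returns 5, so with 7 pairs from the forward pass and 6 from the reversed pass the final answer is 7 + 6 - 5 = 8.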
Refer to implementations below in case anything is unclear.
TIME COMPLEXITY
The time complexity is O(N*log(N)) per test case.
SOLUTIONS
Setter's Solution
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
#define LG 20
#define N 300005
ll mod = 1000;
vector<ll> a;
ll n;
ll dp[LG][N];
ll solve_full()
{
//initializing dp
for(ll j=0;j<LG;j++)
{
for(ll i=0;i<n;i++)
{
dp[j][i]=-1;
}
}
vector<ll> nxt(n);
stack<ll> s;
s.push(0);
for(ll i=1;i<n;i++)
{
if (s.empty())
{
s.push(i);
continue;
}
while(!s.empty() and a[s.top()]>=a[i])
{
nxt[s.top()] = i;
s.pop();
}
s.push(i);
}
while(!s.empty())
{
nxt[s.top()]=-1;
s.pop();
}
for(ll i=0;i<n;i++)
{
dp[0][i] = nxt[i];
}
for(ll j=1;j<LG;j++)
{
for(ll i=0;i<n;i++)
{
if(dp[j-1][i]==-1)
dp[j][i]=-1;
else
dp[j][i] = dp[j-1][dp[j-1][i]];
}
}
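nxt.assign(n, n); // reset for the second pass; clear() would leave size 0 and make the writes below out of bounds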
s.push(0);
for(ll i=1;i<n;i++)
{
if (s.empty())
{
s.push(i);
continue;
}
while(!s.empty() and a[s.top()]<a[i])
{
nxt[s.top()] = i;
s.pop();
}
s.push(i);
}
while(!s.empty())
{
nxt[s.top()]=n;
s.pop();
}
long long ans = 0;
for(ll i=0;i<n;i++)
{
ll val = nxt[i];
ll t = i;
for(ll j=LG-1;j>=0;j--)
{
if (dp[j][t]!=-1)
{
if (dp[j][t]<val)
{
ans+=(1<<j);
t=dp[j][t];
}
}
}
}
return ans;
}
ll efficientmethod()
{
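vector<ll> p(n); // p[i] = length of the run of equal values ending at i (avoids a non-standard variable-length array)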
long long ans = 0;
for(ll i=0;i<n;i++)
{
p[i]=1;
if (i>0 and a[i]==a[i-1])
p[i]=p[i-1]+1;
ans-=p[i]-1;
}
//in this method, every subarray of equal values with more than one element is counted twice.
// e.g. [12,12,12]: [1,2],[1,3],[2,3] will be counted twice,
// so the loop above subtracts them once.
//all subarrays of size 1 are added to the answer here.
ans+=n;
ans+=solve_full();
reverse(a.begin(),a.end());
ans+=solve_full();
return ans;
}
int main()
{
ios_base::sync_with_stdio(false);
cin.tie(NULL);
cout.tie(NULL);
ll t;
cin>>t;
while(t--)
{
cin>>n;
a.clear();
for(ll i=0;i<n;i++)
{
ll v;
cin>>v;
a.push_back(v);
}
ll ans = efficientmethod();
cout<<ans<<"\n";
}
}
Tester's Solution
#include <bits/stdc++.h>
#define endl '\n'
#define SZ(x) ((int)x.size())
#define ALL(V) V.begin(), V.end()
#define L_B lower_bound
#define U_B upper_bound
#define pb push_back
using namespace std;
template<class T, class T1> int chkmin(T &x, const T1 &y) { return x > y ? x = y, 1 : 0; }
template<class T, class T1> int chkmax(T &x, const T1 &y) { return x < y ? x = y, 1 : 0; }
const int MAXN = (1 << 20);
// Main observation is that {minimum(a[l..r]), maximum(a[l..r])} = {a[l], a[r]}
// i.e. the minimum and maximum will be at the corners of the array
template <class T>
struct fenwick {
int sz;
T tr[MAXN];
void init(int n) {
sz = n + 2;
memset(tr, 0, sizeof(tr));
}
T query(int idx) {
idx += 1;
T ans = 0;
for(; idx >= 1; idx -= (idx & -idx))
ans += tr[idx];
return ans;
}
void update(int idx, T val) {
idx += 1;
if(idx <= 0) return;
for(; idx <= sz; idx += (idx & -idx))
tr[idx] += val;
}
T query(int l, int r) { return query(r) - query(l - 1); }
};
int n;
int a[MAXN];
void read() {
cin >> n;
for(int i = 0; i < n; i++) {
cin >> a[i];
}
}
int nxt[MAXN]; // nxt[i] = min j such that i < j and a[i] > a[j]
fenwick<int> t;
vector<int> li[MAXN];
int64_t solve_() {
for(int i = 0; i < n; i++) {
li[i].clear();
}
t.init(n + 2);
vector<int> st;
for(int i = 0; i < n; i++) {
while(!st.empty() && a[st.back()] <= a[i]) {
st.pop_back();
}
if(st.empty()) {
t.update(i, 1);
} else {
li[st.back()].pb(i);
}
st.pb(i);
}
st.clear();
for(int i = n - 1; i >= 0; i--) {
while(!st.empty() && a[st.back()] >= a[i]) {
st.pop_back();
}
if(st.empty()) {
nxt[i] = n;
} else {
nxt[i] = st.back();
}
st.pb(i);
}
int64_t ret = 0;
for(int i = 0; i < n; i++) {
for(int j: li[i]) {
t.update(j, 1);
}
ret += t.query(i, nxt[i] - 1);
t.update(i, -1);
}
return ret;
}
void solve() {
// A neat way to implement it is to count the number of segments, where a[l] <= a[r]. Then reverse the array and find this count again.
// However, there will be a problem that we count the subarrays with a[l] = a[r] twice, but those should have all of their values same, so
// we can find this count easily and then subtract it from the answer.
int64_t answer = 0;
answer += solve_();
reverse(a, a + n);
answer += solve_();
// Subtract overcounted subarrays.
int cnt = 1;
for(int i = 1; i < n; i++) {
if(a[i] != a[i - 1]) {
answer -= cnt * 1ll * (cnt + 1) / 2ll;
cnt = 1;
} else cnt++;
}
answer -= cnt * 1ll * (cnt + 1) / 2ll;
cout << answer << endl;
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int T;
cin >> T;
while(T--) {
read();
solve();
}
return 0;
}
Editorialist's Solution (using Merge Sort Tree)
import java.util.*;
import java.io.*;
import java.text.*;
class ABSNX{
//SOLUTION BEGIN
void pre() throws Exception{}
void solve(int TC) throws Exception{
int N = ni();
long[] A = new long[N];
for(int i = 0; i< N; i++)A[i] = nl();
long ans = 0;
{
//Fixing A[l] as minimum, A[r] as maximum
int[] nxt = new int[N];
Stack<Integer> s = new Stack<>();
for(int i = N-1; i>= 0; i--){
while(!s.isEmpty() && A[s.peek()] >= A[i])s.pop();
nxt[i] = s.isEmpty()?N:s.peek();
s.push(i);
}
s.clear();
int[] prev = new int[N];
for(int i = 0; i< N; i++){
while(!s.isEmpty() && A[s.peek()] <= A[i])s.pop();
prev[i] = s.isEmpty()?-1:s.peek();
s.push(i);
}
s.clear();
MergeSortTree MST = new MergeSortTree(nxt);
for(int i = 0; i< N; i++)
ans += MST.countGreater(prev[i]+1, i, i);
}
{
//Fixing A[l] as maximum, A[r] as minimum
int[] prev = new int[N], nxt = new int[N];
Stack<Integer> s = new Stack<>();
for(int i = N-1; i>= 0; i--){
while(!s.isEmpty() && A[s.peek()] <= A[i])s.pop();
nxt[i] = s.isEmpty()?N:s.peek();
s.push(i);
}
s.clear();
for(int i = 0; i< N; i++){
while(!s.isEmpty() && A[s.peek()] >= A[i])s.pop();
prev[i] = s.isEmpty()?-1:s.peek();
s.push(i);
}
s.clear();
MergeSortTree MST = new MergeSortTree(nxt);
for(int i = 0; i< N; i++)
ans += MST.countGreater(prev[i]+1, i, i);
}
//Removing double counted subarrays (subarrays with max(A[l..r]) == min(A[l..r]))
for(int i = 0, j = 0; i< A.length; i = j){
while(j< A.length && A[i] == A[j])j++;
long len = j-i;
ans -= (len*len+len)/2;
}
pn(ans);
}
class MergeSortTree{
int m = 1;
int[][] t;
public MergeSortTree(int[] A){
while(m<A.length)m<<=1;
t = new int[m<<1][];
for(int i = 0; i< A.length; i++)
t[i+m] = new int[]{A[i]};
for(int i = A.length; i< m; i++)t[i+m] = new int[0];
for(int i = m-1; i> 0; i--){
int p1 = 0, p2 = 0, p = 0;
t[i] = new int[t[i<<1].length+t[i<<1|1].length];
while(p1 < t[i<<1].length && p2 < t[i<<1|1].length){
if(t[i<<1][p1] <= t[i<<1|1][p2])t[i][p++] = t[i<<1][p1++];
else t[i][p++] = t[i<<1|1][p2++];
}
while(p1 < t[i<<1].length)t[i][p++] = t[i<<1][p1++];
while(p2 < t[i<<1|1].length)t[i][p++] = t[i<<1|1][p2++];
}
}
int countSmaller(int l, int r, int x){
return countSmaller(l, r, 0, m-1, 1, x);
}
int countGreater(int l, int r, int x){
return countGreater(l, r, 0, m-1, 1, x);
}
int countGreater(int l, int r, int ll, int rr, int i, int x){
if(l == ll && r == rr){
int lo = 0, hi = t[i].length-1;
if(t[i][hi] <= x)return 0;
while(lo+1 < hi){
int mid = (lo+hi)/2;
if(t[i][mid] > x)hi = mid;
else lo = mid;
}
if(t[i][lo] > x)hi = lo;
return t[i].length-hi;
}
int mid = (ll+rr)/2;
if(r <= mid)return countGreater(l, r, ll, mid, i<<1, x);
else if(l > mid)return countGreater(l, r, mid+1, rr, i<<1|1, x);
else return countGreater(l, mid, ll, mid, i<<1, x)+countGreater(mid+1, r, mid+1, rr, i<<1|1, x);
}
int countSmaller(int l, int r, int ll, int rr, int i, int x){
if(l == ll && r == rr){
int lo = 0, hi = t[i].length-1;
if(t[i][lo] >= x)return 0;
while(lo+1 < hi){
int mid = (lo+hi)/2;
if(t[i][mid] < x)lo = mid;
else hi = mid;
}
if(t[i][hi] < x)lo = hi;
return lo+1;
}
int mid = (ll+rr)/2;
if(r <= mid)return countSmaller(l, r, ll, mid, i<<1, x);
else if(l > mid)return countSmaller(l, r, mid+1, rr, i<<1|1, x);
else return countSmaller(l, mid, ll, mid, i<<1, x)+countSmaller(mid+1, r, mid+1, rr, i<<1|1, x);
}
}
//SOLUTION END
void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
DecimalFormat df = new DecimalFormat("0.00000000000");
static boolean multipleTC = true;
FastReader in;PrintWriter out;
void run() throws Exception{
in = new FastReader();
out = new PrintWriter(System.out);
//Solution Credits: Taranpreet Singh
int T = (multipleTC)?ni():1;
pre();for(int t = 1; t<= T; t++)solve(t);
out.flush();
out.close();
}
public static void main(String[] args) throws Exception{
new ABSNX().run();
}
int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
void p(Object o){out.print(o);}
void pn(Object o){out.println(o);}
void pni(Object o){out.println(o);out.flush();}
String n()throws Exception{return in.next();}
String nln()throws Exception{return in.nextLine();}
int ni()throws Exception{return Integer.parseInt(in.next());}
long nl()throws Exception{return Long.parseLong(in.next());}
double nd()throws Exception{return Double.parseDouble(in.next());}
class FastReader{
BufferedReader br;
StringTokenizer st;
public FastReader(){
br = new BufferedReader(new InputStreamReader(System.in));
}
public FastReader(String s) throws Exception{
br = new BufferedReader(new FileReader(s));
}
String next() throws Exception{
while (st == null || !st.hasMoreElements()){
try{
st = new StringTokenizer(br.readLine());
}catch (IOException e){
throw new Exception(e.toString());
}
}
return st.nextToken();
}
String nextLine() throws Exception{
String str = "";
try{
str = br.readLine();
}catch (IOException e){
throw new Exception(e.toString());
}
return str;
}
}
}
Feel free to share your approach. Suggestions are welcomed as always.