# MEXSEG - Editorial

Author: piyush_2007
Tester: yash_daga
Editorialist: iceknight1093

# PREREQUISITES:

Familiarity with mex, the inclusion-exclusion principle, basic combinatorics

# PROBLEM:

You’re given a permutation P of \{0, 1, 2, \ldots, N-1\}.
Answer Q queries on it, each one as follows:

• You’re given L_1, L_2, M_1, M_2
• Find the number of subarrays of P whose lengths lie in [L_1, L_2] and whose mex lies in [M_1, M_2]
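A brute-force reference is handy for checking the faster approach derived below. The following is a minimal Python sketch, not taken from any of the solutions. `mex` and `brute_force` are hypothetical helper names, and P is assumed to be a 0-indexed Python list. It runs in roughly \mathcal{O}(N^3), so it only suits tiny inputs.

```python
def mex(values):
    """Smallest non-negative integer that does not appear in `values`."""
    seen = set(values)
    m = 0
    while m in seen:
        m += 1
    return m


def brute_force(p, l1, l2, m1, m2):
    """Count subarrays of p with length in [l1, l2] and mex in [m1, m2]."""
    n = len(p)
    count = 0
    for i in range(n):
        for j in range(i, n):
            if l1 <= j - i + 1 <= l2 and m1 <= mex(p[i:j + 1]) <= m2:
                count += 1
    return count


p = [1, 0, 2]
print(brute_force(p, 1, 3, 1, 3))  # 4: [1, 0], [1, 0, 2], [0], [0, 2]
```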

# EXPLANATION:

First, let’s look at what a subarray having a mex of M actually entails.
M is the smallest non-negative integer not present in the subarray, which in particular means the subarray must contain all of 0, 1, 2, \ldots, M-1, and must not contain M.

Given that P is a permutation, each value appears exactly once in it. Let \text{pos}_i denote the position of element i in P.
Notice that there’s a unique ‘smallest’ subarray that contains all the values \{0, 1, 2, \ldots, M-1\}; namely the subarray whose left endpoint is \min(\text{pos}_0, \text{pos}_1, \ldots, \text{pos}_{M-1}) and right endpoint is \max(\text{pos}_0, \text{pos}_1, \ldots, \text{pos}_{M-1}).
Let this subarray be [l, r].

So, [l, r] is contained in every subarray whose mex is \geq M; conversely, any subarray containing [l, r] has mex \geq M. (This needs M \geq 1, so that [l, r] is defined.)
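This claim is easy to sanity-check exhaustively on small permutations. Below is a minimal Python sketch using 0-indexed positions; `mex` and `claim_holds` are hypothetical names. For every M \geq 1 and every subarray [i, j], it compares "mex \geq M" with "[i, j] contains [l, r]".

```python
from itertools import permutations


def mex(values):
    """Smallest non-negative integer that does not appear in `values`."""
    seen = set(values)
    m = 0
    while m in seen:
        m += 1
    return m


def claim_holds(p):
    """For M >= 1: mex(p[i..j]) >= M  <=>  i <= l and r <= j,
    where l, r are the min/max positions (0-indexed) of the values 0..M-1."""
    n = len(p)
    pos = [0] * n
    for i, v in enumerate(p):
        pos[v] = i
    for M in range(1, n + 1):
        l, r = min(pos[:M]), max(pos[:M])
        for i in range(n):
            for j in range(i, n):
                if (mex(p[i:j + 1]) >= M) != (i <= l and r <= j):
                    return False
    return True


# Exhaustively check every permutation of size 5.
print(all(claim_holds(list(q)) for q in permutations(range(5))))  # True
```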

Now let’s move on to answering queries.
Handling both upper and lower bounds on the length and the mex at the same time is awkward, so let’s first solve a simpler version of the problem that keeps only the lower bounds.

That is, given M and L, we’ll attempt to count the number of subarrays whose length is \geq L and whose mex is \geq M. Let’s denote this value by f(M, L).

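First, let’s get a few edge cases out of the way.

• If L \gt N, then obviously f(M, L) = 0, since no subarray can have length \gt N.
• Similarly, if M \gt N, once again f(M, L) = 0, since no subarray can have mex \gt N.
• If M = 0, every subarray has mex \geq 0, so f(0, L) is simply the number of subarrays with length \geq L, i.e., 1 + 2 + \ldots + (N-L+1).

This leaves us with L \leq N and 1 \leq M \leq N.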

Recall from earlier that we have in fact characterized all subarrays whose mex is \geq M: they are exactly the subarrays containing [l, r], where l = \min(\text{pos}_0, \text{pos}_1, \ldots, \text{pos}_{M-1}) and r = \max(\text{pos}_0, \text{pos}_1, \ldots, \text{pos}_{M-1}).
After an \mathcal{O}(N) precomputation, l and r can be found in \mathcal{O}(1) for any M, since they’re just prefix minimums/maximums of the \text{pos} array.
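A minimal Python sketch of this precomputation, assuming 1-indexed positions as in the text; `prefix_bounds`, `left` and `right` are illustrative names. Afterwards, the window [l, r] for any M \geq 1 is just (left[M], right[M]).

```python
def prefix_bounds(p):
    """Return lists left, right of length N+1 such that, for 1 <= M <= N,
    [left[M], right[M]] (1-indexed) is the smallest window containing 0..M-1.
    Index 0 holds sentinels (N+1 and 0) so the prefix min/max starts cleanly."""
    n = len(p)
    pos = [0] * n
    for i, v in enumerate(p):
        pos[v] = i + 1  # 1-indexed position of value v
    left = [n + 1] * (n + 1)
    right = [0] * (n + 1)
    for m in range(1, n + 1):
        left[m] = min(left[m - 1], pos[m - 1])
        right[m] = max(right[m - 1], pos[m - 1])
    return left, right


left, right = prefix_bounds([2, 0, 3, 1])
print(left[1:], right[1:])  # [2, 2, 1, 1] [2, 4, 4, 4]
```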

Now that we know l and r, notice that any valid subarray [x, y] (with 1-indexed positions) must satisfy 1 \leq x \leq l and r \leq y \leq N. Our task is to count the number of pairs (x, y) that satisfy this condition and also y-x+1 \geq L (since we want length \geq L).
This is now a combinatorics problem, and can be solved in \mathcal{O}(1).

How?

Suppose we fix 1 \leq x \leq l. Let’s count the number of valid y.

Note that y must satisfy:

• y-x+1 \geq L, i.e., y \geq x+L-1
• r \leq y \leq N

The smallest valid y is thus y_0 = \max(r, x+L-1), and the number of valid y is N-y_0+1 (assuming y_0 \leq N, of course; otherwise the number of valid y is zero).
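Summing this count over every x gives f(M, L) directly. The following is a minimal Python sketch of that loop, assuming 1-indexed positions and L \geq 1; `f_linear` is a hypothetical name. For M = 0 nothing needs to be contained, so the sketch simply uses l = N and r = 1.

```python
def f_linear(p, M, L):
    """f(M, L): number of subarrays with mex >= M and length >= L.
    Iterates over the left endpoint x, so it takes O(N) time (assumes L >= 1)."""
    n = len(p)
    if L > n or M > n:
        return 0
    pos = [0] * n
    for i, v in enumerate(p):
        pos[v] = i + 1  # 1-indexed position of value v
    if M == 0:
        l, r = n, 1  # no value is required: any x in 1..N works
    else:
        l, r = min(pos[:M]), max(pos[:M])
    total = 0
    for x in range(1, l + 1):
        y0 = max(r, x + L - 1)  # smallest valid right endpoint
        if y0 <= n:
            total += n - y0 + 1  # y can be y0, y0+1, ..., N
    return total


print(f_linear([1, 0, 2], 1, 2))  # 3: [1, 0], [1, 0, 2], [0, 2]
```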

This gives us an \mathcal{O}(N) solution per query by iterating over x, but we need to do better to answer all the queries quickly.
So, let’s handle the two possible values of \max(r, x+L-1) separately; i.e., treat the case when r is the maximum separately from the case when x+L-1 is the maximum.

r is the maximum

When r = \max(r, x+L-1), this means the number of valid y for this x is simply N-r+1, a constant (since we can pick y=r, r+1, r+2, \ldots, N).
So, we only need to find the number of x that satisfy this.

That’s not hard. We have two inequalities:

• 1 \leq x \leq l
• r \geq x+L-1, or x \leq r-L+1

So, x_0 = \min(l, r-L+1) is the largest x for which this holds, and there are x_0 valid positions, namely x = 1, 2, \ldots, x_0 (if x_0 \leq 0, there are none).

This adds (N-r+1) \cdot \max(0, \min(l, r-L+1)) to f(M, L).
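A minimal Python sketch of this case's contribution, with a brute-force comparison that adds N-r+1 for each qualifying x one at a time. It assumes 1-indexed positions and l \leq r (as produced by the prefix min/max); `case_one` is a hypothetical name.

```python
def case_one(n, l, r, L):
    """Contribution of the x in [1, l] with x + L - 1 <= r.
    Each such x pairs with every y in r..N, i.e. N - r + 1 choices."""
    x0 = min(l, r - L + 1)  # largest x in this case (may be <= 0)
    return (n - r + 1) * max(0, x0)


# Compare against adding N - r + 1 for each qualifying x individually.
ok = all(
    case_one(n, l, r, L)
    == sum(n - r + 1 for x in range(1, l + 1) if x + L - 1 <= r)
    for n in range(1, 8)
    for r in range(1, n + 1)
    for l in range(1, r + 1)
    for L in range(1, n + 1)
)
print(ok)  # True
```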

x+L-1 is the maximum

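Let x_0 = \max(0, \min(l, r-L+1)) be the number of x handled in the previous case.
Now, we need to deal with x = x_0+1, x_0+2, \ldots, l.

For each such x we have x+L-1 \gt r, so the smallest valid y is x+L-1 and the number of valid y is N-x-L+2 (or zero, if that is negative). So, if x_0+1 has k valid y-positions, then x_0+2 will have k-1 valid positions, x_0+3 will have k-2 valid positions, and so on till l (once the count reaches zero, it stays zero).

So, we’d like to compute the sum of some consecutive range of integers, which is easy to do in \mathcal{O}(1).
The top of the range is the count at x_0+1, namely k = N-x_0-L+1; the bottom is the count at l, namely N-l-L+2, clamped to \max(1, N-l-L+2) so that only positive counts are summed.
If x_0 = l, or if the top is below the bottom, this case contributes nothing.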
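One way to fill in these details, sketched in Python under the same assumptions (1-indexed, L \geq 1, l \leq r); `tri` and `case_two` are hypothetical helper names. The counts for x = x_0+1, \ldots, l run from N-x_0-L+1 down to N-l-L+2, and only the positive ones are summed, using 1 + 2 + \ldots + k = k(k+1)/2.

```python
def tri(k):
    """1 + 2 + ... + k, or 0 when k <= 0."""
    return k * (k + 1) // 2 if k > 0 else 0


def case_two(n, l, r, L):
    """Contribution of the x in (x0, l], where the smallest valid y is x + L - 1."""
    x0 = max(0, min(l, r - L + 1))  # x handled by the first case
    if x0 >= l:
        return 0
    top = n - x0 - L + 1  # number of valid y at x = x0 + 1
    bottom = max(1, n - l - L + 2)  # number at x = l, clamped to stay positive
    return tri(top) - tri(bottom - 1) if top >= bottom else 0


# Compare against adding max(0, N - (x + L - 1) + 1) for each x individually.
ok = all(
    case_two(n, l, r, L)
    == sum(max(0, n - (x + L - 1) + 1) for x in range(1, l + 1) if x + L - 1 > r)
    for n in range(1, 8)
    for r in range(1, n + 1)
    for l in range(1, r + 1)
    for L in range(1, n + 1)
)
print(ok)  # True
```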

You may also see the code linked below.

Now that we know how to compute f(M, L) in \mathcal{O}(1), how do we solve the original problem?

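That’s simple: apply inclusion-exclusion!
For the query (L_1, L_2, M_1, M_2), the answer is

f(M_1, L_1) - f(M_1, L_2+1) - f(M_2+1, L_1) + f(M_2+1, L_2+1)

The first term counts subarrays with mex \geq M_1 and length \geq L_1; the next two remove those whose length exceeds L_2 or whose mex exceeds M_2, and the last term adds back the subarrays removed twice. Each term is computed in \mathcal{O}(1), so we’re done.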
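Putting everything together, here is a minimal Python sketch of the whole \mathcal{O}(N + Q) approach. It uses 1-indexed positions and assumes 1 \leq L_1 \leq L_2 \leq N and 0 \leq M_1 \leq M_2 \leq N; the names `build`, `tri`, `f` and `answer` are hypothetical. It is a reconstruction from the derivation above, not a copy of any of the solutions below.

```python
def build(p):
    """Prefix min/max of 1-indexed positions: [lft[M], rgt[M]] spans 0..M-1."""
    n = len(p)
    pos = [0] * n
    for i, v in enumerate(p):
        pos[v] = i + 1
    lft, rgt = [n + 1] * (n + 1), [0] * (n + 1)
    for m in range(1, n + 1):
        lft[m] = min(lft[m - 1], pos[m - 1])
        rgt[m] = max(rgt[m - 1], pos[m - 1])
    return lft, rgt


def tri(k):
    """1 + 2 + ... + k, or 0 when k <= 0."""
    return k * (k + 1) // 2 if k > 0 else 0


def f(n, lft, rgt, M, L):
    """Subarrays with mex >= M and length >= L, in O(1) (assumes L >= 1)."""
    if M > n or L > n:
        return 0
    if M == 0:  # every subarray qualifies: 1 + 2 + ... + (N - L + 1)
        return tri(n - L + 1)
    l, r = lft[M], rgt[M]
    x0 = max(0, min(l, r - L + 1))
    total = (n - r + 1) * x0  # case 1: y ranges over r..N
    if x0 < l:  # case 2: the counts form a consecutive run
        top = n - x0 - L + 1
        bottom = max(1, n - l - L + 2)
        if top >= bottom:
            total += tri(top) - tri(bottom - 1)
    return total


def answer(n, lft, rgt, L1, L2, M1, M2):
    """Inclusion-exclusion over four lower-bounded counts."""
    return (f(n, lft, rgt, M1, L1) - f(n, lft, rgt, M1, L2 + 1)
            - f(n, lft, rgt, M2 + 1, L1) + f(n, lft, rgt, M2 + 1, L2 + 1))


p = [1, 0, 2]
lft, rgt = build(p)
print(answer(len(p), lft, rgt, 1, 3, 1, 3))  # 4
```

Each query touches only the precomputed arrays, so after the \mathcal{O}(N) call to `build` every query costs \mathcal{O}(1).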

# TIME COMPLEXITY

\mathcal{O}(N + Q) per test case.

# CODE:

Setter's code (C++)
```cpp
//  ॐ
#include <bits/stdc++.h>
using namespace std;
#define PI 3.14159265358979323846
#define ll long long int

const int N=1e6+5;
int pos[N];
int l_m[N],r_m[N];

// f(m, len): number of subarrays with mex >= m and length <= len
inline ll f(ll m,ll len,ll n){

    if(m>n || len<=0)
        return 0;

    if(m==0)
        return (1LL*len*(2*n-len+1))/2;

    int r=n-r_m[m];          // free positions to the right of [l_m[m], r_m[m]]
    int l=l_m[m]-1;          // free positions to the left
    int sz=r_m[m]-l_m[m]+1;  // length of the minimal window

    if(sz>len){
        return 0;
    }

    int left=len-sz;  // how far the window may still be extended
    l=min(l,left);
    r=min(r,left);

    int z=min(l,r)+1;
    ll ret=1LL*z*(z+1);
    ret/=2;

    ret+=max(0LL,1LL*z*(min(max(l,r),left)-z+1));
    z=max(l,r);
    int num=min(l+r,left)-z;
    z=min(l,r);
    ret+=(1LL*(num)*(z-num+1+z))/2;

    return ret;
}

int main(){

    ios_base::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);

    int test = 1;
    cin>>test;

    while(test--){

        int n,q;
        cin>>n>>q;
        int p[n];

        for(int i=0;i<n;i++){
            cin>>p[i];
            pos[p[i]]=i+1;
        }

        l_m[0]=1e9;
        r_m[0]=-1;

        for(int i=1;i<=n;i++){
            r_m[i]=max(r_m[i-1],pos[i-1]);
            l_m[i]=min(l_m[i-1],pos[i-1]);
        }

        while(q--){
            int l1,l2,m1,m2;
            cin>>l1>>l2>>m1>>m2;
            cout<<f(m1,l2,n)-f(m1,l1-1,n)-(f(m2+1,l2,n)-f(m2+1,l1-1,n))<<'\n';
        }

        // cout<<'\n';

    }
    return 0;
}
```

Tester's code (C++)
```cpp
//clear adj and visited vector declared globally after each test case
//check for long long overflow
//For modulo questions, add a final check: if ans<0 then ans+=mod;
//In case of a close MLE, change language to C++17 or C++14
//Check ans for n=1
// #pragma GCC target ("avx2")
// #pragma GCC optimize ("O3")
// #pragma GCC optimize ("unroll-loops")
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#define int long long
#define IOS std::ios::sync_with_stdio(false); cin.tie(NULL);cout.tie(NULL);cout.precision(dbl::max_digits10);
#define pb push_back
#define mod 1000000007ll //998244353ll
#define lld long double
#define mii map<int, int>
#define pii pair<int, int>
#define ll long long
#define ff first
#define ss second
#define all(x) (x).begin(), (x).end()
#define rep(i,x,y) for(int i=x; i<y; i++)
#define fill(a,b) memset(a, b, sizeof(a))
#define vi vector<int>
#define setbits(x) __builtin_popcountll(x)
#define print2d(dp,n,m) for(int i=0;i<=n;i++){for(int j=0;j<=m;j++)cout<<dp[i][j]<<" ";cout<<"\n";}
typedef std::numeric_limits< double > dbl;
using namespace __gnu_pbds;
using namespace std;
typedef tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update> indexed_set;
//member functions :
//1. order_of_key(k) : number of elements strictly lesser than k
//2. find_by_order(k) : k-th element in the set
const long long N=200005, INF=2000000000000000000;
const int inf=2e9 + 5;
lld pi=3.1415926535897932;
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
int lcm(int a, int b)
{
    int g=__gcd(a, b);
    return a/g*b;
}
int power(int a, int b, int p)
{
    if(a==0)
        return 0;
    int res=1;
    a%=p;
    while(b>0)
    {
        if(b&1)
            res=(1ll*res*a)%p;
        b>>=1;
        a=(1ll*a*a)%p;
    }
    return res;
}

int getRand(int l, int r)
{
    uniform_int_distribution<int> uid(l, r);
    return uid(rng);
}
int sum(int n)
{
    return (n*(n+1))/2;
}

int32_t main()
{
    // IOS;
    int t;
    cin>>t;
    while(t--)
    {
        int n, q;
        cin>>n>>q;
        int a[n], pos[n];
        rep(i,0,n)
        {
            cin>>a[i];
            pos[a[i]]=i;
        }
        // [l[m], r[m]]: smallest window (0-indexed) containing 0..m-1
        int l[n+1], r[n+1];
        l[1]=r[1]=pos[0];
        int x=pos[0], y=pos[0];
        rep(i,2,n+1)
        {
            x=min(x, pos[i-1]);
            y=max(y, pos[i-1]);
            l[i]=x;
            r[i]=y;
        }
        // cal(mx, len): number of subarrays with mex >= mx and length <= len
        auto cal = [&](int mx, int len)
        {
            if(mx>n || len<=0)
                return 0ll;
            if(mx==0)
                return sum(n)-sum(n-len);

            int l1=l[mx], r1=r[mx];
            int base=r1-l1+1;

            if(len<base)
                return 0ll;
            // Count extensions (a, b) with a <= l1 to the left, b <= n-1-r1 to the
            // right and a+b <= k; g(t) counts pairs with a+b <= t and no upper bounds.
            int k=len-base;
            auto g=[](int t){ return t<0 ? 0ll : sum(t+1); };
            return g(k)-g(k-l1-1)-g(k-(n-r1))+g(k-l1-1-(n-r1));
        };
        while(q--)
        {
            int l1, l2, m1, m2;
            cin>>l1>>l2>>m1>>m2;
            cout<<cal(m1, l2)-cal(m1, l1-1)-(cal(m2+1, l2)-cal(m2+1, l1-1))<<"\n";
        }
    }
}
```

Editorialist's code (Python)
```python
import sys

for _ in range(int(input())):
    n, q = map(int, input().split())
    p = list(map(int, input().split()))

    mnpos, mxpos = [0]*n, [0]*n
    for i in range(n):
        mnpos[p[i]] = mxpos[p[i]] = i
    for i in range(1, n):
        mnpos[i] = min(mnpos[i], mnpos[i-1])
        mxpos[i] = max(mxpos[i], mxpos[i-1])

    def calc(M, L): # number of subarrays with mex >= M, length >= L
        if M > n or L > n: return 0
        if M == 0: # 1 + 2 + ... + n-L+1
            return (n-L+1)*(n-L+2)//2

        lo, hi = mnpos[M-1], mxpos[M-1]
        ret = max(0, min(lo+1, hi-L+2)) * (n - hi)
        if hi+1 < n and hi-L+1 < lo:
            mx = min(n - L + 1, n - hi - 1)
            mn = 0
            if lo+L < n: mn = n-lo-L+1

            # mn + mn+1 + ... + mx
            ret += mx*(mx+1)//2 - mn*(mn-1)//2
        return ret

    for i in range(q):
        l, r, x, y = map(int, input().split())
        ans = calc(x, l) - calc(x, r+1) - calc(y+1, l) + calc(y+1, r+1)
        print(ans)
```

