MEXSEG - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: piyush_2007
Tester: yash_daga
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Familiarity with mex, the inclusion-exclusion principle, basic combinatorics

PROBLEM:

You’re given a permutation P of \{0, 1, 2, \ldots, N-1\}.
Answer Q queries on it, each one as follows:

  • You’re given L_1, L_2, M_1, M_2
  • Find the number of subarrays of P whose lengths lie in [L_1, L_2] and whose mex lies in [M_1, M_2]

EXPLANATION:

First, let’s look at what a subarray having a mex of M actually entails.
M is the smallest integer not present in the subarray, which in particular means the subarray should contain [0, 1, 2, \ldots, M-1]; and should not contain M.
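Here is a small Python illustration. The helpers `mex` and `has_mex` are hypothetical names and do not come from the setter's or tester's code. `mex` computes the mex directly, and `has_mex` checks the equivalent condition just stated: every value 0, \ldots, M-1 is present and M is absent.

```python
def mex(values):
    """Smallest non-negative integer that does not occur in values."""
    present = set(values)
    m = 0
    while m in present:
        m += 1
    return m


def has_mex(values, m):
    """True iff values contain every integer 0..m-1 and do not contain m,
    which is exactly the condition mex(values) == m."""
    present = set(values)
    return all(v in present for v in range(m)) and m not in present
```

For example, `mex([3, 0, 1])` is 2, and `has_mex([3, 0, 1], 2)` is `True`.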

Given that P is a permutation, each value appears exactly once in it. Let \text{pos}_i denote the position of element i in P.
Notice that there’s a unique ‘smallest’ subarray that contains all the values \{0, 1, 2, \ldots, M-1\}; namely the subarray whose left endpoint is \min(\text{pos}_0, \text{pos}_1, \ldots, \text{pos}_{M-1}) and right endpoint is \max(\text{pos}_0, \text{pos}_1, \ldots, \text{pos}_{M-1}).
Let this subarray be [l, r].

Notice that [l, r] will be contained in every subarray whose mex is \geq M; and conversely, any subarray containing [l, r] will have a mex that’s \geq M.
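A brute-force check on small permutations is an easy way to confirm this claim. The sketch below uses 0-indexed positions, and its function names are made up for illustration. For every M \geq 1 and every subarray, it checks whether "mex \geq M" agrees with "the subarray contains [l, r]".

```python
from itertools import permutations


def mex(values):
    present = set(values)
    m = 0
    while m in present:
        m += 1
    return m


def window_characterization_holds(p):
    """Check the claim for one permutation p (0-indexed positions):
    for every M >= 1, mex(p[x..y]) >= M exactly when x <= l and r <= y,
    where [l, r] is the smallest window containing 0, 1, ..., M-1."""
    n = len(p)
    pos = [0] * n
    for i, v in enumerate(p):
        pos[v] = i
    for m in range(1, n + 1):
        l, r = min(pos[:m]), max(pos[:m])
        for x in range(n):
            for y in range(x, n):
                if (mex(p[x:y + 1]) >= m) != (x <= l and r <= y):
                    return False
    return True


def holds_for_all_permutations(n):
    """Run the check on every permutation of {0, ..., n-1}."""
    return all(window_characterization_holds(list(q)) for q in permutations(range(n)))
```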


Now let’s move on to answering queries.
Handling upper and lower bounds on both the length and the mex at the same time is awkward. So let’s first solve a simpler version of the problem that keeps only the lower bounds.

That is, given M and L, we’ll attempt to count the number of subarrays whose length is \geq L and whose mex is \geq M. Let’s denote this value by f(M, L).
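For reference, f(M, L) can be computed directly from its definition in \mathcal{O}(N^2). The sketch below uses the hypothetical name `f_naive` and is only meant as ground truth for testing faster formulas. It extends each subarray to the right and updates the mex incrementally.

```python
def f_naive(p, M, L):
    """Number of subarrays of p with length >= L and mex >= M, by brute force.
    Runs in O(N^2); only meant as a reference for checking faster versions."""
    n = len(p)
    count = 0
    for x in range(n):
        seen = [False] * (n + 1)   # seen[v]: does v occur in p[x..y]?
        cur_mex = 0
        for y in range(x, n):
            seen[p[y]] = True
            while seen[cur_mex]:   # the mex can only grow as y moves right
                cur_mex += 1
            if y - x + 1 >= L and cur_mex >= M:
                count += 1
    return count
```

For instance, with P = [1, 0, 2], `f_naive([1, 0, 2], 1, 1)` is 4. The subarrays containing 0 are [1, 0], [1, 0, 2], [0] and [0, 2].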

First, let’s get a couple of edge cases out of the way.

  • If L \gt N, then f(M, L) = 0, since no subarray can have length \gt N.
  • Similarly, if M \gt N, then f(M, L) = 0, since no subarray can have mex \gt N (even the whole permutation has mex exactly N).

This leaves us with L, M \leq N.

Recall from earlier that we in fact categorized all subarrays whose mex is \geq M: it’s all subarrays containing [l, r], where l = \min(\text{pos}_0, \text{pos}_1, \ldots, \text{pos}_{M-1}) and r = \max(\text{pos}_0, \text{pos}_1, \ldots, \text{pos}_{M-1}).
l and r can be computed in \mathcal{O}(1) per query after an \mathcal{O}(N) precomputation, since they’re just prefix minimums/maximums of the \text{pos} array.
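Here is one possible way to do this precomputation. It assumes 1-indexed positions, as in the rest of the derivation, and the names `prefix_windows`, `lo` and `hi` are illustrative.

```python
def prefix_windows(p):
    """For a permutation p of 0..n-1 (n >= 1), return lists lo, hi with
    lo[M] = min(pos_0, ..., pos_{M-1}) and hi[M] = max(pos_0, ..., pos_{M-1})
    for M = 1..n, using 1-indexed positions. Index 0 is unused."""
    n = len(p)
    pos = [0] * n
    for i, v in enumerate(p):
        pos[v] = i + 1
    lo = [0] * (n + 1)
    hi = [0] * (n + 1)
    lo[1] = hi[1] = pos[0]
    for m in range(2, n + 1):          # prefix min / max over pos
        lo[m] = min(lo[m - 1], pos[m - 1])
        hi[m] = max(hi[m - 1], pos[m - 1])
    return lo, hi
```

For P = [1, 0, 2] this yields lo = [0, 2, 1, 1] and hi = [0, 2, 2, 3], so for M = 2 the window is [1, 2]. Each query then reads lo[M] and hi[M] in \mathcal{O}(1).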

Now that we know l and r, notice that any valid subarray [x, y] (using 1-indexed positions) must satisfy 1 \leq x \leq l and r \leq y \leq N. Our task is to count the pairs (x, y) that satisfy this condition and also satisfy y-x+1 \geq L (since we want length \geq L).
This is now a combinatorics problem, and can be solved in \mathcal{O}(1).

How?

Suppose we fix 1 \leq x \leq l. Let’s count the number of valid y.

Note that y must satisfy:

  • y-x+1 \geq L, i.e., y \geq x+L-1
  • r \leq y \leq N

The smallest valid y is thus y_0 = \max(r, x+L-1), and the number of valid y is N-y_0+1 (assuming y_0 \leq N, of course; otherwise the number of valid y is zero).
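The counting for a fixed x translates directly into a loop. The sketch below uses the hypothetical name `count_pairs_linear`, assumes 1-indexed positions and L \geq 1, and evaluates f(M, L) in \mathcal{O}(l) time once l and r are known.

```python
def count_pairs_linear(n, l, r, L):
    """Number of (x, y) with 1 <= x <= l, r <= y <= n and y - x + 1 >= L.
    Fixes x and counts y directly, so it takes O(l) time."""
    total = 0
    for x in range(1, l + 1):
        y0 = max(r, x + L - 1)   # smallest valid right endpoint for this x
        if y0 <= n:
            total += n - y0 + 1  # y can be y0, y0+1, ..., n
    return total
```

With P = [1, 0, 2] and M = 1, the window is [2, 2], and `count_pairs_linear(3, 2, 2, 1)` returns 4. This matches the four subarrays that contain 0.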

Iterating over x gives an \mathcal{O}(N) way to compute a single value of f, but that’s too slow when there are many queries, so we need to do better.
To do so, let’s handle the two possible values of \max(r, x+L-1) separately; i.e., treat the case when r is the maximum separately from the case when x+L-1 is the maximum.

r is the maximum

When r = \max(r, x+L-1), this means the number of valid y for this x is simply N-r+1, a constant (since we can pick y=r, r+1, r+2, \ldots, N).
So, we only need to find the number of x that satisfy this.

That’s not hard. We have two inequalities:

  • 1 \leq x \leq l
  • r \geq x+L-1, or x \leq r-L+1

So, x_0 = \min(l, r-L+1) is the maximum x for which this holds. Since x ranges over 1, 2, \ldots, x_0, there are x_0 valid positions (of course, if x_0 \leq 0 there are 0 valid positions).

This adds (N-r+1) \cdot \max(0, \min(l, r-L+1)) to f(M, L).
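Here is the first case as a Python sketch, assuming 1-indexed positions; the function name is illustrative. Since x runs over 1, \ldots, x_0, the number of such x is \max(0, x_0), and each of them pairs with all N-r+1 choices of y.

```python
def case_r_is_max(n, l, r, L):
    """Pairs (x, y) whose x satisfies r >= x + L - 1 (1-indexed positions).
    Valid x are 1..x0 with x0 = min(l, r - L + 1), i.e. max(0, x0) of them,
    and each one pairs with all n - r + 1 choices y = r, ..., n."""
    x0 = min(l, r - L + 1)
    return max(0, x0) * (n - r + 1)
```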

x+L-1 is the maximum

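Let x_0 be as in the previous case.
Now we need to deal with the remaining x, namely \max(1, x_0+1), \ldots, l. For each of them x+L-1 \gt r, so the smallest valid y is x+L-1.

Notice that in this case, if the first such x has k valid y-positions, the next one has k-1 valid positions, the one after it has k-2, and so on until l (once the count reaches 0, it stays at 0).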

So we’d like to compute the sum of a range of consecutive integers, which is easy to do in \mathcal{O}(1).
The left and right ends of this range can be found by computing the number of valid y for the first such x and for l (clamping at zero).
The exact details are left as an exercise to the reader.
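Here is one way to work out those details, sketched in Python under the same 1-indexed assumptions; the function name is hypothetical. The first x of this case is \max(1, x_0+1). The last x with at least one valid y is \min(l, N-L+1), because x+L-1 \leq N requires x \leq N-L+1. The counts between these two endpoints are consecutive integers, so they are summed with the usual (first term + last term) \cdot (number of terms) / 2 formula.

```python
def case_len_is_max(n, l, r, L):
    """Pairs (x, y) whose x satisfies x + L - 1 > r (1-indexed positions).
    For such x the smallest valid y is x + L - 1, so x has
    n - x - L + 2 valid y (or 0 if that is negative); the count drops by one
    per step of x, so the total is a sum of consecutive integers."""
    x0 = min(l, r - L + 1)
    first = max(1, x0 + 1)       # first x in this case
    last = min(l, n - L + 1)     # last x that still has at least one valid y
    if first > last:
        return 0
    k_first = n - first - L + 2  # valid y for x = first (largest term)
    k_last = n - last - L + 2    # valid y for x = last (smallest term, >= 1)
    return (k_first + k_last) * (last - first + 1) // 2
```

Adding this to the first case gives f(M, L). For example, with N = 5, l = 3, r = 4, L = 3, the first case contributes 4 and this one contributes 1, for a total of 5.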

You may also see the code linked below.


Now that we know how to compute f(M, L) in \mathcal{O}(1), how do we solve the original problem?

That’s simple: apply inclusion-exclusion!
For the query L_1, L_2, M_1, M_2, the answer is

f(M_1, L_1) - f(M_1, L_2+1) - f(M_2+1, L_1) + f(M_2+1, L_2+1)

Each of these four terms is computed in \mathcal{O}(1), so we’re done.
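Putting the pieces together, here is a Python sketch of the whole query pipeline, with an exhaustive brute-force cross-check for small N. It follows the structure described above but is not the setter’s, tester’s or editorialist’s code. The names `build`, `f`, `answer`, `brute` and `self_check` are hypothetical. It assumes 1 \leq L_1 \leq L_2 and 0 \leq M_1 \leq M_2, and uses 1-indexed positions.

```python
from itertools import permutations


def build(p):
    """1-indexed windows: [lo[M], hi[M]] is the smallest window holding 0..M-1."""
    n = len(p)
    pos = [0] * n
    for i, v in enumerate(p):
        pos[v] = i + 1
    lo, hi = [0] * (n + 1), [0] * (n + 1)
    lo[1] = hi[1] = pos[0]
    for m in range(2, n + 1):
        lo[m] = min(lo[m - 1], pos[m - 1])
        hi[m] = max(hi[m - 1], pos[m - 1])
    return lo, hi


def f(n, lo, hi, M, L):
    """Subarrays with mex >= M and length >= L, in O(1); assumes L >= 1."""
    if M > n or L > n:
        return 0
    if M == 0:  # every subarray qualifies: 1 + 2 + ... + (n - L + 1)
        return (n - L + 1) * (n - L + 2) // 2
    l, r = lo[M], hi[M]
    x0 = min(l, r - L + 1)
    total = max(0, x0) * (n - r + 1)            # case: r is the maximum
    first, last = max(1, x0 + 1), min(l, n - L + 1)
    if first <= last:                           # case: x + L - 1 is the maximum
        total += ((n - first - L + 2) + (n - last - L + 2)) * (last - first + 1) // 2
    return total


def answer(n, lo, hi, l1, l2, m1, m2):
    """Inclusion-exclusion over the four lower-bound counts."""
    return (f(n, lo, hi, m1, l1) - f(n, lo, hi, m1, l2 + 1)
            - f(n, lo, hi, m2 + 1, l1) + f(n, lo, hi, m2 + 1, l2 + 1))


def brute(p, l1, l2, m1, m2):
    """O(N^2) reference: count subarrays directly from the definition."""
    n, count = len(p), 0
    for x in range(n):
        seen, cur = [False] * (n + 1), 0
        for y in range(x, n):
            seen[p[y]] = True
            while seen[cur]:
                cur += 1
            if l1 <= y - x + 1 <= l2 and m1 <= cur <= m2:
                count += 1
    return count


def self_check(max_n):
    """Compare answer() with brute() on every permutation and query up to max_n."""
    for n in range(1, max_n + 1):
        for q in permutations(range(n)):
            p = list(q)
            lo, hi = build(p)
            for l1 in range(1, n + 1):
                for l2 in range(l1, n + 1):
                    for m1 in range(n + 1):
                        for m2 in range(m1, n + 1):
                            if answer(n, lo, hi, l1, l2, m1, m2) != brute(p, l1, l2, m1, m2):
                                return False
    return True
```

For P = [1, 0, 2], `answer(3, *build([1, 0, 2]), 2, 3, 1, 2)` returns 2, counting [1, 0] (mex 2) and [0, 2] (mex 1).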

TIME COMPLEXITY

\mathcal{O}(N + Q) per test case.

CODE:

Setter's code (C++)
                                 //  ॐ
#include <bits/stdc++.h>
using namespace std;
#define PI 3.14159265358979323846
#define ll long long int

const int N=1e6+5;
int pos[N];
int l_m[N],r_m[N];

inline ll f(ll m,ll len,ll n){

   if(m>n || len<=0)
    return 0;

   if(m==0)
      return (1LL*len*(2*n-len+1))/2;

     int r=n-r_m[m];
     int l=l_m[m]-1;
     int sz=r_m[m]-l_m[m]+1;

     if(sz>len){
        return 0;
     }

     int left=len-sz;
     l=min(l,left);
     r=min(r,left);

     int z=min(l,r)+1;
     ll ret=1LL*z*(z+1);
     ret/=2;

     ret+=max(0LL,1LL*z*(min(max(l,r),left)-z+1));
     z=max(l,r);
     int num=min(l+r,left)-z;
     z=min(l,r);
     ret+=(1LL*(num)*(z-num+1+z))/2;

     return ret;
}

int main(){
   
    ios_base::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
        

   int test = 1;
   cin>>test;

  
   while(test--){
                     
                      
                     int n,q;
                     cin>>n>>q;
                     int p[n];

                     for(int i=0;i<n;i++){
                         cin>>p[i];
                         pos[p[i]]=i+1;
                     } 
                     
                     l_m[0]=1e9;
                     r_m[0]=-1;

                     for(int i=1;i<=n;i++){
                        r_m[i]=max(r_m[i-1],pos[i-1]);
                        l_m[i]=min(l_m[i-1],pos[i-1]);
                     }

                     while(q--){
                           int l1,l2,m1,m2;
                           cin>>l1>>l2>>m1>>m2;
                           cout<<f(m1,l2,n)-f(m1,l1-1,n)-(f(m2+1,l2,n)-f(m2+1,l1-1,n))<<'\n';
                     }

                     // cout<<'\n';  
                
   }
        return 0;
}
Tester's code (C++)
//clear adj and visited vector declared globally after each test case
//check for long long overflow   
//For modular-arithmetic questions, add this at the end: if ans<0 then ans+=mod;
//In case of a close MLE, change language to C++17 or C++14
//Check ans for n=1 
// #pragma GCC target ("avx2")    
// #pragma GCC optimize ("O3")  
// #pragma GCC optimize ("unroll-loops")
#include <bits/stdc++.h>                   
#include <ext/pb_ds/assoc_container.hpp>  
#define int long long     
#define IOS std::ios::sync_with_stdio(false); cin.tie(NULL);cout.tie(NULL);cout.precision(dbl::max_digits10);
#define pb push_back 
#define mod 1000000007ll //998244353ll
#define lld long double
#define mii map<int, int> 
#define pii pair<int, int>
#define ll long long 
#define ff first
#define ss second 
#define all(x) (x).begin(), (x).end()
#define rep(i,x,y) for(int i=x; i<y; i++)    
#define fill(a,b) memset(a, b, sizeof(a))
#define vi vector<int>
#define setbits(x) __builtin_popcountll(x)
#define print2d(dp,n,m) for(int i=0;i<=n;i++){for(int j=0;j<=m;j++)cout<<dp[i][j]<<" ";cout<<"\n";}
typedef std::numeric_limits< double > dbl;
using namespace __gnu_pbds;
using namespace std;
typedef tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update> indexed_set;
//member functions :
//1. order_of_key(k) : number of elements strictly lesser than k
//2. find_by_order(k) : k-th element in the set
const long long N=200005, INF=2000000000000000000;
const int inf=2e9 + 5;
lld pi=3.1415926535897932;
int lcm(int a, int b)
{
    int g=__gcd(a, b);
    return a/g*b;
}
int power(int a, int b, int p)
    {
        if(a==0)
        return 0;
        int res=1;
        a%=p;
        while(b>0)
        {
            if(b&1)
            res=(1ll*res*a)%p;
            b>>=1;
            a=(1ll*a*a)%p;
        }
        return res;
    }
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());

int getRand(int l, int r)
{
    uniform_int_distribution<int> uid(l, r);
    return uid(rng);
}
int sum(int n)
{
    return (n*(n+1))/2;
}

int32_t main()
{
    // IOS;
    int t;
    cin>>t;
    while(t--)
    {
        int n, q;
        cin>>n>>q;
        int a[n], pos[n];
        rep(i,0,n)
        {
            cin>>a[i];
            pos[a[i]]=i;
        }
        int l[n+1], r[n+1];
        l[1]=r[1]=pos[0];
        int x=pos[0], y=pos[0];
        rep(i,2,n+1)
        {
            x=min(x, pos[i-1]);
            y=max(y, pos[i-1]);
            l[i]=x;
            r[i]=y;
        }
        auto cal = [&](int mx, int len)
        {
            if(mx>n || len<=0)
                return 0ll;
            if(mx==0)
                return sum(n)-sum(n-len);

            int l1=l[mx], r1=r[mx];
            int base=r1-l1+1;
            int ad1=min(r1+1, n-l1);
            int ad2=max(r1+1, n-l1);

            if(len<base)
                return 0ll;
            if(len<ad1)
                return sum(len-base+1);
            if(len<ad2)
                return sum(ad1-base+1) + ((len-ad1)*(ad1-base+1));
            return sum(ad1-base+1)+sum(ad1-base) + ((ad2-ad1)*(ad1-base+1)) - sum(n-len);
        };
        while(q--)
        {
            int l1, l2, m1, m2;
            cin>>l1>>l2>>m1>>m2;
            cout<<cal(m1, l2)-cal(m1, l1-1)-(cal(m2+1, l2)-cal(m2+1, l1-1))<<"\n";
        }
    }
}
Editorialist's code (Python)
import sys
input = sys.stdin.readline

for _ in range(int(input())):
	n, q = map(int, input().split())
	p = list(map(int, input().split()))
	
	mnpos, mxpos = [0]*n, [0]*n
	for i in range(n):
		mnpos[p[i]] = mxpos[p[i]] = i
	for i in range(1, n):
		mnpos[i] = min(mnpos[i], mnpos[i-1])
		mxpos[i] = max(mxpos[i], mxpos[i-1])
	
	def calc(M, L): # number of subarrays with mex >= M, length >= L
		if M > n or L > n: return 0
		if M == 0: # 1 + 2 + ... + n-L+1
			return (n-L+1)*(n-L+2)//2
		
		lo, hi = mnpos[M-1], mxpos[M-1]
		ret = max(0, min(lo+1, hi-L+2)) * (n - hi)
		if hi+1 < n and hi-L+1 < lo:
			mx = min(n - L + 1, n - hi - 1)
			mn = 0
			if lo+L < n: mn = n-lo-L+1
			
			# mn + mn+1 + ... + mx
			ret += mx*(mx+1)//2 - mn*(mn-1)//2
		return ret
	
	for i in range(q):
		l, r, x, y = map(int, input().split())
		ans = calc(x, l) - calc(x, r+1) - calc(y+1, l) + calc(y+1, r+1)
		print(ans)

Just saw the video editorial; that guy Madhav didn’t explain the solution well. I don’t know whether he was time-bound or not, but the solution could have been explained better.
@admin please look into it.


@iceknight1093 the code is missing, I think.

Oh you’re right, I’ll add them in a couple of minutes.
Edit: done!

Can this problem be solved with 2D BIT?