PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: piyush_2007
Tester: yash_daga
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
Familiarity with mex, the inclusion-exclusion principle, basic combinatorics
PROBLEM:
You’re given a permutation P of \{0, 1, 2, \ldots, N-1\}.
Answer Q queries on it, each one as follows:
- You’re given L_1, L_2, M_1, M_2
- Find the number of subarrays of P whose lengths lie in [L_1, L_2] and whose mex lies in [M_1, M_2]
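For small inputs it helps to have a direct reference implementation of the query. The Python sketch below is not from the original solutions (the name `brute_query` is illustrative). It fixes each left endpoint, extends the right endpoint, and updates the mex incrementally, so it runs in \mathcal{O}(N^2) per query. It is only meant for cross-checking faster code.

```python
def brute_query(p, l1, l2, m1, m2):
    """Count subarrays of permutation p with length in [l1, l2] and mex in [m1, m2]."""
    n = len(p)
    count = 0
    for x in range(n):
        seen = [False] * (n + 1)  # seen[n] stays False, so the mex loop always stops
        mex = 0
        for y in range(x, n):
            seen[p[y]] = True
            # Extending a subarray can only increase its mex.
            while seen[mex]:
                mex += 1
            if l1 <= y - x + 1 <= l2 and m1 <= mex <= m2:
                count += 1
    return count
```

For example, with P = [1, 0, 2] the subarrays of length 2 or 3 with mex in [1, 2] are [1, 0] (mex 2) and [0, 2] (mex 1), so `brute_query([1, 0, 2], 2, 3, 1, 2)` returns 2.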
EXPLANATION:
First, let’s look at what a subarray having a mex of M actually entails.
M is the smallest non-negative integer not present in the subarray. In particular, the subarray must contain all of \{0, 1, 2, \ldots, M-1\} and must not contain M.
Given that P is a permutation, each value appears exactly once in it. Let \text{pos}_i denote the position of element i in P.
Notice that there’s a unique ‘smallest’ subarray that contains all the values \{0, 1, 2, \ldots, M-1\}; namely the subarray whose left endpoint is \min(\text{pos}_0, \text{pos}_1, \ldots, \text{pos}_{M-1}) and right endpoint is \max(\text{pos}_0, \text{pos}_1, \ldots, \text{pos}_{M-1}).
Let this subarray be [l, r].
Notice that [l, r] will be contained in every subarray whose mex is \geq M; and conversely, any subarray containing [l, r] will have a mex that’s \geq M.
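This characterization is easy to verify exhaustively on small permutations. The following Python sketch uses hypothetical helper names and 0-indexed positions. For every 1 \leq M \leq N, it checks that a subarray has mex \geq M exactly when it contains the window [l, r]. It skips M = 0 because every subarray trivially has mex \geq 0.

```python
from itertools import permutations


def mex_of(values):
    """Smallest non-negative integer not present in values."""
    present = set(values)
    m = 0
    while m in present:
        m += 1
    return m


def check_characterization(p):
    """True if, for every 1 <= M <= N, mex(p[x..y]) >= M exactly when the
    0-indexed subarray [x, y] contains [l, r] = [min pos, max pos] of 0..M-1."""
    n = len(p)
    pos = [0] * n
    for i, v in enumerate(p):
        pos[v] = i
    for m in range(1, n + 1):
        l, r = min(pos[:m]), max(pos[:m])
        for x in range(n):
            for y in range(x, n):
                big_mex = mex_of(p[x:y + 1]) >= m
                contains_window = x <= l and r <= y
                if big_mex != contains_window:
                    return False
    return True


def check_all(max_n):
    """Run the check on every permutation of size 1..max_n."""
    return all(check_characterization(list(q))
               for n in range(1, max_n + 1)
               for q in permutations(range(n)))
```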
Now let’s move on to answering queries.
Handling upper and lower bounds on both the length and the mex at the same time is unwieldy. So, let’s first solve a simpler version of this problem that keeps only the lower bounds.
That is, given M and L, we’ll attempt to count the number of subarrays whose length is \geq L and whose mex is \geq M. Let’s denote this value by f(M, L).
First, let’s get a couple of edge cases out of the way.
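- If L \gt N, then obviously f(M, L) = 0, since no subarray can have length \gt N.
- Similarly, if M \gt N, once again f(M, L) = 0, since no subarray can have mex \gt N.
- If M = 0, every subarray qualifies. So f(0, L) just counts the subarrays of length \geq L, which is 1 + 2 + \ldots + (N-L+1) = \frac{(N-L+1)(N-L+2)}{2} for 1 \leq L \leq N.
This leaves us with 1 \leq M \leq N and L \leq N.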
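The M = 0 case (every subarray has mex \geq 0) reduces to counting subarrays of length \geq L. That count is the sum 1 + 2 + \ldots + (N-L+1), which the editorialist’s `calc` also uses. Below is a small Python sketch that compares this closed form with direct enumeration. Both function names are illustrative.

```python
def count_len_at_least(n, length):
    """Closed form: subarrays of an n-element array with length >= length (length >= 1)."""
    if length > n:
        return 0
    k = n - length + 1  # number of subarrays of exactly this length
    return k * (k + 1) // 2


def count_len_at_least_direct(n, length):
    """Same count by enumerating all (x, y) pairs."""
    return sum(1 for x in range(n) for y in range(x, n) if y - x + 1 >= length)
```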
Recall that we’ve already characterized all subarrays whose mex is \geq M: they’re exactly the subarrays containing [l, r], where l = \min(\text{pos}_0, \text{pos}_1, \ldots, \text{pos}_{M-1}) and r = \max(\text{pos}_0, \text{pos}_1, \ldots, \text{pos}_{M-1}).
After an \mathcal{O}(N) precomputation, l and r can be found in \mathcal{O}(1) for any M, since they’re just the prefix minimum and prefix maximum of the \text{pos} array.
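Here is one way to do that precomputation in Python. This sketch assumes 1-indexed positions, as in the text, and the function name `window_bounds` is illustrative. Index M of the returned lists holds l and r for that M. Index 0 is unused.

```python
def window_bounds(p):
    """Return lists lo, hi where [lo[M], hi[M]] is the smallest 1-indexed window
    containing 0..M-1, for 1 <= M <= N (index 0 is unused and left as 0)."""
    n = len(p)
    pos = [0] * n
    for i, v in enumerate(p):
        pos[v] = i + 1  # 1-indexed, matching the editorial
    lo = [0] * (n + 1)
    hi = [0] * (n + 1)
    lo[1] = hi[1] = pos[0]
    for m in range(2, n + 1):
        # Prefix minimum / maximum of pos[0..m-1]
        lo[m] = min(lo[m - 1], pos[m - 1])
        hi[m] = max(hi[m - 1], pos[m - 1])
    return lo, hi
```

For P = [2, 0, 3, 1], this gives l = 2, 2, 1, 1 and r = 2, 4, 4, 4 for M = 1, 2, 3, 4.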
Now that we know l and r, notice that any valid subarray [x, y] must satisfy 1 \leq x \leq l and r \leq y \leq N. Our task is to count the number of pairs (x, y) that satisfy this condition, and also y-x+1 \geq L (since we want length \geq L).
This is now a combinatorics problem, and can be solved in \mathcal{O}(1).
How?
Suppose we fix 1 \leq x \leq l. Let’s count the number of valid y.
Note that y must satisfy:
- y-x+1 \geq L, i.e., y \geq x+L-1
- r \leq y \leq N
The smallest valid y is thus y_0 = \max(r, x+L-1), and the number of valid y is N-y_0+1 (assuming y_0 \leq N, of course; otherwise the number of valid y is zero).
Iterating over x gives an \mathcal{O}(N) computation of f(M, L). That is too slow with many queries, so we need to do better.
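Below is a direct transcription of this per-x count into Python. It is a sketch: the name `f_linear` is illustrative, and l, r are the 1-indexed window bounds for the given M.

```python
def f_linear(n, l, r, length):
    """Count pairs (x, y) with 1 <= x <= l, r <= y <= n and y - x + 1 >= length.

    Runs in O(l) time by iterating over x.
    """
    total = 0
    for x in range(1, l + 1):
        y0 = max(r, x + length - 1)  # smallest valid right endpoint for this x
        if y0 <= n:
            total += n - y0 + 1
    return total
```

With N = 4, l = 2, r = 4 and L = 3, only [1, 4] and [2, 4] qualify, and `f_linear(4, 2, 4, 3)` returns 2.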
So, let’s split on which term attains \max(r, x+L-1). That is, we treat the case where r is the maximum separately from the case where x+L-1 is the maximum.
r is the maximum
When r = \max(r, x+L-1), this means the number of valid y for this x is simply N-r+1, a constant (since we can pick y=r, r+1, r+2, \ldots, N).
So, we only need to find the number of x that satisfy this.
That’s not hard. We have two inequalities:
- 1 \leq x \leq l
- r \geq x+L-1, or x \leq r-L+1
So, x_0 = \min(l, r-L+1) is the maximum x for which this holds. The valid x are 1, 2, \ldots, x_0, which gives x_0 valid positions (of course, if x_0 \leq 0 there are 0 valid positions).
This adds (N-r+1) \cdot \max(0, \min(l, r-L+1)) to f(M, L).
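In code, this first case is a single product. The Python sketch below computes it with 1-indexed positions; the name `r_is_max_part` is illustrative. The count of x is clamped at 0 to handle r-L+1 \leq 0.

```python
def r_is_max_part(n, l, r, length):
    """Contribution of the x (1 <= x <= l) with r >= x + length - 1.

    Valid x are 1..min(l, r - length + 1); each has n - r + 1 choices of y.
    """
    x0 = min(l, r - length + 1)
    return (n - r + 1) * max(0, x0)
```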
x+L-1 is the maximum
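Let x_0 be as in the previous case.
Now, we need to deal with the remaining x, namely \max(x_0, 0)+1, \ldots, l (this range may be empty).
For each such x, the number of valid y is N-(x+L-1)+1 = N-x-L+2, or 0 if that is negative. So, if the first x in the range has k valid y-positions, the next one has k-1 valid positions, the one after that has k-2, and so on till l, with the counts stopping at 0.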
So, we’d like to compute the sum of some consecutive range of integers, which is easy to do in \mathcal{O}(1).
The two ends of this range are the valid-position counts for the first x in the range and for x = l, with any negative count treated as 0.
The exact details here are left as an exercise to the reader.
You may also see the code linked below.
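Here is one possible way to fill in that exercise, as a Python sketch with 1-indexed positions. The names `tri` and `xl_is_max_part` are illustrative, and the linked solutions may organize this differently. The sketch finds the first x of the second case, computes the valid-y counts at both ends, and sums the positive part of the arithmetic sequence using the triangular-number formula.

```python
def tri(k):
    """1 + 2 + ... + k, or 0 when k <= 0."""
    return k * (k + 1) // 2 if k > 0 else 0


def xl_is_max_part(n, l, r, length):
    """Sum of valid-y counts over the x (1 <= x <= l) with x + length - 1 > r.

    Such x have n - x - length + 2 valid y (clamped at 0); 1-indexed positions.
    """
    x0 = min(l, r - length + 1)
    a = max(x0, 0) + 1              # first x where x + length - 1 exceeds r
    if a > l:
        return 0
    first = n - a - length + 2      # valid y count at x = a
    last = n - l - length + 2       # valid y count at x = l (may be <= 0)
    if first <= 0:
        return 0
    # Counts drop by one per step of x and stop at 0, so sum max(last, 1) .. first.
    return tri(first) - tri(max(last, 1) - 1)
```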
Now that we know how to compute f(M, L) in \mathcal{O}(1), how do we solve the original problem?
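That’s simple: apply inclusion-exclusion!
The subarrays with length in [L_1, L_2] are those with length \geq L_1 minus those with length \geq L_2+1, and the same holds for the mex. So, for the query L_1, L_2, M_1, M_2, the answer is
f(M_1, L_1) - f(M_1, L_2+1) - f(M_2+1, L_1) + f(M_2+1, L_2+1)
Each of these terms is computed in \mathcal{O}(1), so we’re done.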
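Putting everything together, here is a Python sketch of the full approach with \mathcal{O}(1) work per query. It follows the same two-case split as the editorialist’s `calc`, but uses 1-indexed positions. The names `make_f` and `answer` are illustrative. Note that `answer` rebuilds the precomputation on every call, so a real solution would build `f` once per test case.

```python
def make_f(p):
    """Return f(M, L) = number of subarrays of p with mex >= M and length >= L.

    After O(N) precomputation each call is O(1). Positions are 1-indexed.
    """
    n = len(p)
    pos = [0] * n
    for i, v in enumerate(p):
        pos[v] = i + 1
    # lo[M], hi[M]: smallest window containing 0..M-1 (index 0 is a sentinel).
    lo, hi = [n + 1] * (n + 1), [0] * (n + 1)
    for m in range(1, n + 1):
        lo[m] = min(lo[m - 1], pos[m - 1])
        hi[m] = max(hi[m - 1], pos[m - 1])

    def tri(k):
        # 1 + 2 + ... + k, or 0 when k <= 0
        return k * (k + 1) // 2 if k > 0 else 0

    def f(M, L):
        # Assumes L >= 1, as in the queries.
        if M > n or L > n:
            return 0
        if M == 0:
            return tri(n - L + 1)
        l, r = lo[M], hi[M]
        x0 = min(l, r - L + 1)
        total = (n - r + 1) * max(0, x0)      # case: r is the maximum
        a = max(x0, 0) + 1                    # case: x + L - 1 is the maximum
        if a <= l:
            first = n - a - L + 2             # valid y count at x = a
            last = n - l - L + 2              # valid y count at x = l
            if first > 0:
                total += tri(first) - tri(max(last, 1) - 1)
        return total

    return f


def answer(p, l1, l2, m1, m2):
    """Inclusion-exclusion over the four lower-bounded counts."""
    f = make_f(p)
    return f(m1, l1) - f(m1, l2 + 1) - f(m2 + 1, l1) + f(m2 + 1, l2 + 1)
```

For P = [1, 0, 2], `answer([1, 0, 2], 2, 3, 1, 2)` returns 2. This matches a direct count: [1, 0] has mex 2 and [0, 2] has mex 1.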
TIME COMPLEXITY:
\mathcal{O}(N + Q) per test case.
CODE:
Setter's code (C++)
// ॐ
#include <bits/stdc++.h>
using namespace std;
#define PI 3.14159265358979323846
#define ll long long int
const int N=1e6+5;
int pos[N];
int l_m[N],r_m[N];
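// f(m, len, n): number of subarrays with mex >= m and length <= len
// (len is an upper bound here, unlike the f(M, L) defined in the editorial)
inline ll f(ll m,ll len,ll n){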
if(m>n || len<=0)
return 0;
if(m==0)
return (1LL*len*(2*n-len+1))/2;
int r=n-r_m[m];
int l=l_m[m]-1;
int sz=r_m[m]-l_m[m]+1;
if(sz>len){
return 0;
}
int left=len-sz;
l=min(l,left);
r=min(r,left);
int z=min(l,r)+1;
ll ret=1LL*z*(z+1);
ret/=2;
ret+=max(0LL,1LL*z*(min(max(l,r),left)-z+1));
z=max(l,r);
int num=min(l+r,left)-z;
z=min(l,r);
ret+=(1LL*(num)*(z-num+1+z))/2;
return ret;
}
int main(){
ios_base::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
int test = 1;
cin>>test;
while(test--){
int n,q;
cin>>n>>q;
int p[n];
for(int i=0;i<n;i++){
cin>>p[i];
pos[p[i]]=i+1;
}
l_m[0]=1e9;
r_m[0]=-1;
for(int i=1;i<=n;i++){
r_m[i]=max(r_m[i-1],pos[i-1]);
l_m[i]=min(l_m[i-1],pos[i-1]);
}
while(q--){
int l1,l2,m1,m2;
cin>>l1>>l2>>m1>>m2;
cout<<f(m1,l2,n)-f(m1,l1-1,n)-(f(m2+1,l2,n)-f(m2+1,l1-1,n))<<'\n';
}
// cout<<'\n';
}
return 0;
}
Tester's code (C++)
//clear adj and visited vector declared globally after each test case
//check for long long overflow
//For mod questions, add a check at the end, i.e. if ans<0 then ans+=mod;
//In case of a close MLE, change language to c++17 or c++14
//Check ans for n=1
// #pragma GCC target ("avx2")
// #pragma GCC optimize ("O3")
// #pragma GCC optimize ("unroll-loops")
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#define int long long
#define IOS std::ios::sync_with_stdio(false); cin.tie(NULL);cout.tie(NULL);cout.precision(dbl::max_digits10);
#define pb push_back
#define mod 1000000007ll //998244353ll
#define lld long double
#define mii map<int, int>
#define pii pair<int, int>
#define ll long long
#define ff first
#define ss second
#define all(x) (x).begin(), (x).end()
#define rep(i,x,y) for(int i=x; i<y; i++)
#define fill(a,b) memset(a, b, sizeof(a))
#define vi vector<int>
#define setbits(x) __builtin_popcountll(x)
#define print2d(dp,n,m) for(int i=0;i<=n;i++){for(int j=0;j<=m;j++)cout<<dp[i][j]<<" ";cout<<"\n";}
typedef std::numeric_limits< double > dbl;
using namespace __gnu_pbds;
using namespace std;
typedef tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update> indexed_set;
//member functions :
//1. order_of_key(k) : number of elements strictly lesser than k
//2. find_by_order(k) : k-th element in the set
const long long N=200005, INF=2000000000000000000;
const int inf=2e9 + 5;
lld pi=3.1415926535897932;
int lcm(int a, int b)
{
int g=__gcd(a, b);
return a/g*b;
}
int power(int a, int b, int p)
{
if(a==0)
return 0;
int res=1;
a%=p;
while(b>0)
{
if(b&1)
res=(1ll*res*a)%p;
b>>=1;
a=(1ll*a*a)%p;
}
return res;
}
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
int getRand(int l, int r)
{
uniform_int_distribution<int> uid(l, r);
return uid(rng);
}
int sum(int n)
{
return (n*(n+1))/2;
}
int32_t main()
{
// IOS;
int t;
cin>>t;
while(t--)
{
int n, q;
cin>>n>>q;
int a[n], pos[n];
rep(i,0,n)
{
cin>>a[i];
pos[a[i]]=i;
}
int l[n+1], r[n+1];
l[1]=r[1]=pos[0];
int x=pos[0], y=pos[0];
rep(i,2,n+1)
{
x=min(x, pos[i-1]);
y=max(y, pos[i-1]);
l[i]=x;
r[i]=y;
}
// cal(mx, len): number of subarrays with mex >= mx and length <= len
auto cal = [&](int mx, int len)
{
if(mx>n || len<=0)
return 0ll;
if(mx==0)
return sum(n)-sum(n-len);
int l1=l[mx], r1=r[mx];
int base=r1-l1+1;
int ad1=min(r1+1, n-l1);
int ad2=max(r1+1, n-l1);
if(len<base)
return 0ll;
if(len<ad1)
return sum(len-base+1);
if(len<ad2)
return sum(ad1-base+1) + ((len-ad1)*(ad1-base+1));
return sum(ad1-base+1)+sum(ad1-base) + ((ad2-ad1)*(ad1-base+1)) - sum(n-len);
};
while(q--)
{
int l1, l2, m1, m2;
cin>>l1>>l2>>m1>>m2;
cout<<cal(m1, l2)-cal(m1, l1-1)-(cal(m2+1, l2)-cal(m2+1, l1-1))<<"\n";
}
}
}
Editorialist's code (Python)
import sys
input = sys.stdin.readline
for _ in range(int(input())):
    n, q = map(int, input().split())
    p = list(map(int, input().split()))
    mnpos, mxpos = [0]*n, [0]*n
    for i in range(n):
        mnpos[p[i]] = mxpos[p[i]] = i
    for i in range(1, n):
        mnpos[i] = min(mnpos[i], mnpos[i-1])
        mxpos[i] = max(mxpos[i], mxpos[i-1])
    def calc(M, L): # number of subarrays with mex >= M, length >= L
        if M > n or L > n: return 0
        if M == 0: # 1 + 2 + ... + n-L+1
            return (n-L+1)*(n-L+2)//2
        lo, hi = mnpos[M-1], mxpos[M-1]
        ret = max(0, min(lo+1, hi-L+2)) * (n - hi)
        if hi+1 < n and hi-L+1 < lo:
            mx = min(n - L + 1, n - hi - 1)
            mn = 0
            if lo+L < n: mn = n-lo-L+1
            # mn + mn+1 + ... + mx
            ret += mx*(mx+1)//2 - mn*(mn-1)//2
        return ret
    for i in range(q):
        l, r, x, y = map(int, input().split())
        ans = calc(x, l) - calc(x, r+1) - calc(y+1, l) + calc(y+1, r+1)
        print(ans)