CONSTMEX - Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Contest Division 4

Setter: Utkarsh Gupta
Tester: Harris Leung
Editorialist: Trung Dang

DIFFICULTY:

Easy-Medium

PREREQUISITES:

None

PROBLEM:

Utkarsh has a permutation P of \{0, 1, 2, \dots, N-1\}.

He wants to find the number of pairs (L, R) (1 \leq L \lt R \leq N) such that swapping the elements P_L and P_R does not change the MEX of any subarray.

EXPLANATION:

Instead of tracking the mex of every subarray directly, let's find a simpler property of the permutation that is equivalent to the mexes of all subarrays (i.e. the mex of some subarray changes if and only if this property changes).

Write A for the permutation P. Let [l_i, r_i] be the smallest range such that A_{[l_i, r_i]} contains all values from 0 to i (call this the i-th characteristic range). I claim that the list of ranges [l_0, r_0], [l_1, r_1], \dots, [l_{N-1}, r_{N-1}] is equivalent to the mexes of all subarrays:

  • From the list of such ranges, we can reconstruct the mex of any subarray: the mex of A_{[L, R]} is the smallest index i such that the subarray does not cover the whole characteristic range [l_i, r_i] (or N if it covers all of them). Therefore, two permutations with the same list of characteristic ranges have the same mex on every subarray.
  • Conversely, changing at least one characteristic range changes the mex of a subarray relevant to that characteristic range. Therefore, two permutations with different lists of characteristic ranges differ in the mex of some subarray.
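The first claim can be checked on small inputs. The following is a minimal sketch, not part of the reference solutions: the function names `characteristicRanges`, `mexFromRanges` and `mexBrute` are hypothetical, and positions are assumed to be 0-indexed.

```cpp
#include <algorithm>
#include <utility>
#include <vector>

// Assumes a is a permutation of {0, ..., n-1}.
// ranges[i] = {l_i, r_i}: the smallest 0-indexed window containing 0..i.
std::vector<std::pair<int, int>> characteristicRanges(const std::vector<int>& a) {
    int n = static_cast<int>(a.size());
    std::vector<int> pos(n);
    for (int i = 0; i < n; i++) pos[a[i]] = i;
    std::vector<std::pair<int, int>> ranges(n);
    int l = n, r = -1;
    for (int v = 0; v < n; v++) {
        l = std::min(l, pos[v]);
        r = std::max(r, pos[v]);
        ranges[v] = {l, r};
    }
    return ranges;
}

// mex of a[L..R] read off the ranges: the first i whose characteristic
// range is not fully inside [L, R], or n if every range is covered.
int mexFromRanges(const std::vector<std::pair<int, int>>& ranges, int L, int R) {
    int n = static_cast<int>(ranges.size());
    for (int i = 0; i < n; i++) {
        if (ranges[i].first < L || ranges[i].second > R) return i;
    }
    return n;
}

// Direct mex of a[L..R], used only for comparison.
int mexBrute(const std::vector<int>& a, int L, int R) {
    std::vector<bool> seen(a.size() + 1, false);
    for (int i = L; i <= R; i++) seen[a[i]] = true;
    int m = 0;
    while (seen[m]) m++;
    return m;
}
```

For A = [0, 2, 1, 4, 3] the characteristic ranges are [0,0], [0,2], [0,2], [0,4], [0,4], and both functions give mex 3 for the subarray at positions 0..2.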

The problem now becomes: find the number of pairs (i, j) such that A_i < A_j, and swapping A_i and A_j does not change any characteristic range.

For any element A_i = u, we have the following observations:

  • If [l_{u - 1}, r_{u - 1}] \neq [l_u, r_u], we cannot swap A_i with any A_j > A_i, because that will change [l_u, r_u].
  • If [l_{u - 1}, r_{u - 1}] = [l_u, r_u], we can swap A_i with any other A_j > A_i that is also within the range [l_u, r_u]. The range holds r_u - l_u + 1 elements, of which u + 1 are the values 0, 1, \dots, u, so there are r_u - l_u - u such values.

Summing r_u - l_u - u over every u whose characteristic range equals the previous one directly gives the answer.
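As a small worked illustration of the counting rule, here is a sketch under the same assumptions as the editorial (A is a permutation of 0..N-1). The function name `countSafeSwaps` is hypothetical, and positions are 0-indexed.

```cpp
#include <algorithm>
#include <vector>

// Counts pairs whose swap keeps every characteristic range unchanged.
// Assumes a is a permutation of {0, ..., n-1}.
long long countSafeSwaps(const std::vector<int>& a) {
    int n = static_cast<int>(a.size());
    if (n == 0) return 0;
    std::vector<int> pos(n);
    for (int i = 0; i < n; i++) pos[a[i]] = i;

    int l = pos[0], r = pos[0];  // current characteristic range
    long long ans = 0;
    for (int u = 1; u < n; u++) {
        if (l <= pos[u] && pos[u] <= r) {
            // Range unchanged by u: u may swap with any larger value inside it.
            ans += r - l - u;
        } else {
            // u extends the range, so it cannot be swapped with a larger value.
            l = std::min(l, pos[u]);
            r = std::max(r, pos[u]);
        }
    }
    return ans;
}
```

For A = [1, 4, 3, 2, 0], value 1 extends the range to the whole array, after which u = 2, 3, 4 contribute 2, 1 and 0, for a total of 3: exactly the pairs among the values 2, 3, 4.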

TIME COMPLEXITY:

Time complexity is O(N) per test case.

SOLUTION:

Setter's Solution
//Utkarsh.25dec
#include "bits/stdc++.h"
#define ll long long int
#define pb push_back
#define mp make_pair
#define mod 1000000007
#define vl vector <ll>
#define all(c) (c).begin(),(c).end()
using namespace std;
ll power(ll a,ll b) {ll res=1;a%=mod; assert(b>=0); for(;b;b>>=1){if(b&1)res=res*a%mod;a=a*a%mod;}return res;}
ll modInverse(ll a){return power(a,mod-2);}
const int N=500023;
bool vis[N];
vector <int> adj[N];
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);

            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }

            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }

            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
int sumN=0;
void solve()
{
    int n=readInt(2,100000,'\n');
    sumN+=n;
    assert(sumN<=200000);
    // std::vector instead of a variable-length array, which is not standard C++
    vector<int> indx(n+1,0);
    vector<int> arr(n+1,0);
    for(int i=1;i<=n;i++)
    {
        if(i!=n)
            arr[i]=readInt(0,n-1,' ');
        else
            arr[i]=readInt(0,n-1,'\n');
        assert(indx[arr[i]]==0);
        indx[arr[i]]=i;
    }
    ll largebef=0;
    int l=indx[0],r=indx[0];
    ll ans=0;
    for(int i=1;i<n;i++)
    {
        int x=indx[i];
        if(x>=l && x<=r)
        {
            largebef--;
            continue;
        }
        else
        {
            l=min(l,x);
            r=max(r,x);
            ll largenow=(r-l+1)-(i+1);
            ll newlar=largenow-largebef;
            ans+=((newlar*(newlar-1))/2+(newlar*largebef));
            largebef=largenow;
        }
    }
    cout<<ans<<'\n';
}
int main()
{
    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    #endif
    ios_base::sync_with_stdio(false);
    cin.tie(NULL),cout.tie(NULL);
    int T=readInt(1,10000,'\n');
    while(T--)
        solve();
    assert(getchar()==-1);
    cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
}
Tester's Solution
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define fi first
#define se second
const ll mod=998244353;
const int N=1e5+1;
int n;
int a[N],b[N];
void solve(){
	cin >> n;
	for(int i=1; i<=n ;i++){
		cin >> a[i];
		b[a[i]]=i;
	}
	int l=b[0],r=b[0];
	ll ans=0;
	for(int i=1; i<n ;i++){
		if(l<=b[i] && b[i]<=r){
			ans+=(r-l-i);
		}
		else{
			l=min(l,b[i]);
			r=max(r,b[i]);
		}
	}
	cout << ans << '\n';
}
int main(){
	ios::sync_with_stdio(false);
	int t;cin >> t;while(t--) solve();
}
Editorialist's Solution
#include <bits/stdc++.h>
using namespace std;

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        vector<int> pos(n);
        for (int i = 0; i < n; i++) {
            int u; cin >> u;
            pos[u] = i;
        }
        int l = n, r = -1;
        long long ans = 0;
        for (int i = 0; i < n; i++) {
            if (pos[i] < l || pos[i] > r) {
                l = min(pos[i], l);
                r = max(pos[i], r);
            } else {
                ans += r - l - i;
            }
        }
        cout << ans << '\n';
    }
}

O(N log N) approach:

  • Let p[i] and s[i] denote the prefix MEX and suffix MEX at index i (both can be computed in O(N)).
  • The valid pairs (i, j) are those such that
  • a[j] > s[i+1] and s[i+1] != a[i], and a[i] > p[j-1] and p[j-1] != a[j].

So iterate backwards, and for a given i do the following:

  • If s[i+1] == a[i], the condition is violated; skip this i.
  • Otherwise, keep popping {a[j], p[j-1]} from a priority queue to get rid of every j with a[j] <= s[i+1].
  • Simultaneously remove the corresponding p[j-1] from a set S.
  • Now three conditions are met: s[i+1] != a[i]; S holds only those p[j-1] for which a[j] > s[i+1]; and only those p[j-1] with p[j-1] != a[j] were ever put in S.
  • For the fourth condition, a[i] > p[j-1], we want the rank of a[i] in S, which can be found in O(log N) using an order-statistics tree or a simple BIT, since the elements of S lie in [0, N-1].
  • Finally, put p[i-1] in S and {a[i], p[i-1]} in the priority queue if and only if this i can later serve as a j, i.e. p[i-1] != a[i].
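The O(N) precomputation mentioned in the first bullet could look like the sketch below. It assumes p[i] is the mex of a[0..i] and s[i] is the mex of a[i..N-1] (both inclusive, 0-indexed), which the comment does not state explicitly; the function names are hypothetical.

```cpp
#include <vector>

// Assumes a is a permutation of {0, ..., n-1}.
// p[i] = mex of a[0..i]. The mex pointer only moves forward, so this is O(n).
std::vector<int> prefixMex(const std::vector<int>& a) {
    int n = static_cast<int>(a.size());
    std::vector<bool> seen(n + 1, false);
    std::vector<int> p(n);
    int mex = 0;
    for (int i = 0; i < n; i++) {
        seen[a[i]] = true;
        while (seen[mex]) mex++;
        p[i] = mex;
    }
    return p;
}

// s[i] = mex of a[i..n-1], computed the same way from the right.
std::vector<int> suffixMex(const std::vector<int>& a) {
    int n = static_cast<int>(a.size());
    std::vector<bool> seen(n + 1, false);
    std::vector<int> s(n);
    int mex = 0;
    for (int i = n - 1; i >= 0; i--) {
        seen[a[i]] = true;
        while (seen[mex]) mex++;
        s[i] = mex;
    }
    return s;
}
```

For a = [0, 2, 1, 4, 3] this gives p = [1, 1, 3, 3, 5] and s = [5, 0, 0, 0, 0].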

Submission: CodeChef: Practical coding for everyone

Can someone give me a counterexample for my code:

t=int(input())
for _ in range(t):
    n=int(input())
    a=list(map(int,input().split()))
    ans=0
    left=[0 for i in range(n)]
    right=[0 for i in range(n)]
    mini=a[0]
    for i in range(1,n):
        if mini<a[i]:
            left[i]=1
        else:
            mini=a[i]
    mini=a[-1]
    for i in range(n-2,-1,-1):
        if mini<a[i]:
            right[i]=1
        else:
            mini=a[i]
    count=0
    for i in range(n):
        if left[i]==1 and right[i]==1:
            count+=1
    ans=count*(count-1)
    ans=ans//2
    print(ans)

For each p[i] in the given permutation p, I check whether there is an element smaller than p[i] on both the left and the right of its position i in the array.
If so, I increase count by 1.
count is then the number of swappable elements.
The answer is count * (count - 1) / 2, i.e. every pair among the swappable elements.
I can share my thought process if you want, but can someone please give me a counterexample?

Hey @abhay_x007,
Your code fails on this test case:
1
5
0 2 1 4 3
Your output : 1
Correct Output : 0
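The expected output can be confirmed independently with a brute force that tries every swap and compares the mex of every subarray before and after. This is only a checking sketch for small N (roughly O(N^4)); the function names are hypothetical and the input is assumed to be a permutation of 0..N-1.

```cpp
#include <utility>
#include <vector>

// Mex of every subarray a[L..R], listed in (L, R) order.
std::vector<int> allSubarrayMex(const std::vector<int>& a) {
    int n = static_cast<int>(a.size());
    std::vector<int> result;
    for (int L = 0; L < n; L++) {
        std::vector<bool> seen(n + 1, false);
        int mex = 0;
        for (int R = L; R < n; R++) {
            seen[a[R]] = true;
            while (seen[mex]) mex++;
            result.push_back(mex);
        }
    }
    return result;
}

// Counts pairs (i, j), i < j, whose swap leaves every subarray mex unchanged.
long long countSafeSwapsBrute(std::vector<int> a) {
    const std::vector<int> original = allSubarrayMex(a);
    int n = static_cast<int>(a.size());
    long long count = 0;
    for (int i = 0; i < n; i++) {
        for (int j = i + 1; j < n; j++) {
            std::swap(a[i], a[j]);
            if (allSubarrayMex(a) == original) count++;
            std::swap(a[i], a[j]);  // restore before trying the next pair
        }
    }
    return count;
}
```

On the test case above, [0, 2, 1, 4, 3], this returns 0: every swap moves at least one of the smallest windows containing 0..i.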

Could anyone please explain this approach?