MEX_SEQ - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: notsoloud
Tester: abhidot
Editorialist: iceknight1093

DIFFICULTY:

2838

PREREQUISITES:

Combinatorics, specifically stars and bars

PROBLEM:

Given N and M, count the number of sequences A of length N such that:

  • 0 \leq A_i \leq M
  • |A_i - \text{MEX}(A_1, A_2, \ldots, A_i)| \leq 1
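For small inputs, the two conditions can be checked directly. The following is a minimal brute-force sketch, not part of any reference solution; the names `mex`, `is_valid`, and `brute_force_count` are illustrative, and it counts with exact integers rather than under a modulus. It is only practical for tiny N and M, since it enumerates all (M+1)^N sequences.

```python
from itertools import product

def mex(values):
    """Smallest non-negative integer that does not appear in values."""
    seen = set(values)
    m = 0
    while m in seen:
        m += 1
    return m

def is_valid(seq):
    """Check |A_i - MEX(A_1..A_i)| <= 1 for every prefix of seq."""
    return all(abs(seq[i] - mex(seq[:i + 1])) <= 1 for i in range(len(seq)))

def brute_force_count(n, m):
    """Count valid sequences of length n with entries in [0, m] by enumeration."""
    return sum(1 for seq in product(range(m + 1), repeat=n) if is_valid(seq))
```

For instance, `brute_force_count(3, 5)` gives 9, which agrees with the value quoted in the discussion below, and `is_valid((1, 1, 2))` is `False`.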

EXPLANATION:

Call a sequence satisfying both conditions above an absolute mex sequence. Let’s first analyze its structure.

Suppose we’ve placed A_1, A_2, \ldots, A_{i-1}, and want to decide on A_i.
Let m_{i-1} = \text{MEX}(A_1, \ldots, A_{i-1}).
There are now two possibilities: we set A_i = m_{i-1}, or we don’t.

  • If we don’t set A_i = m_{i-1}, then \text{MEX}(A_1, \ldots, A_{i-1}, A_i) = m_{i-1}.
    So, we want |A_i - m_{i-1}| \leq 1 and A_i \neq m_{i-1}, which (for the most part) gives us two choices; namely m_{i-1}-1 and m_{i-1}+1.
  • If we set A_i = m_{i-1}, then \text{MEX}(A_1, \ldots, A_{i-1}, A_i) \gt m_{i-1}. In particular,
    • If m_{i-1} + 1 has occurred in the sequence before this, the new MEX is at least m_{i-1}+2, which is bad for us since then |A_i - \text{MEX}| \geq |m_{i-1} - (m_{i-1}+2)| = 2.
    • If m_{i-1} + 1 hasn’t occurred yet, the new MEX is m_{i-1}+1, which is fine.

Together, these observations give us a somewhat rigid structure on what an absolute mex sequence looks like.

  • Let’s say the sequence goes ‘bad’ at index i if we place m_{i-1}+1 at it.
  • Notice that as soon as a sequence goes bad, every following element must be either m_{i-1}+1 or m_{i-1}-1; since as we noted above, placing m_{i-1} would cause the difference to be 2 which is unacceptable.

So, it’s enough to count the following:

  • The number of sequences that never go bad; let’s call them good sequences
  • For each i, the number of sequences that go bad for the first time at index i.

These can be done separately.

Counting good sequences

It’s not hard to observe that a good sequence will be of the form

[0, 0, 0, \ldots, 0, 1, 1, \ldots, 1, 2, 2, \ldots, 2, 3, 3, \ldots, k]

for some 0 \leq k \leq M.
That is, some zeros, followed by some ones, followed by some twos, and so on.

Counting the number of such sequences is fairly easy.
Let’s fix k, the largest element of the good sequence.
Then, we only need to decide how many times each of 0, 1, 2, \ldots, k occur in the sequence, since it’ll be sorted. Also, each of them needs to occur at least once.

This is the classical stars-and-bars problem in combinatorics, and simply equals \binom{N-1}{k}.

So, the number of good sequences is simply

\sum_{k=0}^M \binom{N-1}{k}

which can be computed in \mathcal{O}(\min(N, M)).
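As a sanity check on the stars-and-bars count, here is a small sketch that compares the binomial sum against an explicit enumeration of good sequences. The function names are illustrative, and `math.comb` with exact integers stands in for the modular binomials a real submission would use.

```python
from math import comb

def count_good(n, m):
    """Sum of C(n-1, k) for k = 0..m; terms with k > n-1 are zero, so stop there."""
    return sum(comb(n - 1, k) for k in range(min(m, n - 1) + 1))

def enumerate_good(n, m):
    """Build every good sequence: start at 0, then repeat the last value or add 1."""
    seqs = [[0]]
    for _ in range(n - 1):
        extended = []
        for s in seqs:
            extended.append(s + [s[-1]])          # repeat the current maximum
            if s[-1] + 1 <= m:
                extended.append(s + [s[-1] + 1])  # move on to the next value
        seqs = extended
    return seqs
```

For N = 4 and M = 2, both approaches give 7, i.e. \binom{3}{0} + \binom{3}{1} + \binom{3}{2}.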

Counting bad sequences

Now, let’s count bad sequences.
If the sequence goes bad at i = 1, the only possibility is A = [1, 1, 1, 1, \ldots, 1].

For i \geq 2, if the sequence goes bad at index i,

  • We’re placing m_{i-1}+1 at this index; so m_{i-1}+1 \leq M must hold, i.e., m_{i-1} \leq M-1.
  • In particular, indices 1, 2, \ldots, i-1 must form a good sequence whose mex is m_{i-1}. From above, we know there are \binom{i-2}{m_{i-1} - 1} such sequences.
  • Positions i+1, i+2, \ldots, N have two choices each: they can be m_{i-1}-1 or m_{i-1}+1. This is 2^{N-i} possibilities.

So, for a fixed i, by considering all possible values of m_{i-1}, we see that the number of sequences is

2^{N-i}\times\left (\binom{i-2}{0} + \binom{i-2}{1} + \ldots + \binom{i-2}{M-2} \right )
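The formula above can be evaluated directly for small values. This is a naive sketch with exact integers and no modulus, with an illustrative function name; it recomputes the inner sum from scratch each time and is meant only for checking.

```python
from math import comb

def count_first_bad_at(n, m, i):
    """Sequences of length n that first go bad at 1-based index i (2 <= i <= n)."""
    # Inner sum over j = 0..m-2; range(m - 1) is empty when m <= 1.
    prefix_choices = sum(comb(i - 2, j) for j in range(m - 1))
    return 2 ** (n - i) * prefix_choices
```

For N = 3 and M = 5 this gives 2 for i = 2 and 2 for i = 3. Together with the 4 good sequences and the single all-ones sequence, that adds up to 9.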

2^{N-i} is easy to compute, but what about the inner summation?

For that, we can use a small trick.
Suppose we know \left (\binom{i-2}{0} + \binom{i-2}{1} + \ldots + \binom{i-2}{M-2} \right ) for some index i.
When we move to i+1, we’d like to know \left (\binom{i-1}{0} + \binom{i-1}{1} + \ldots + \binom{i-1}{M-2} \right ).

Recall Pascal’s identity: \binom{n}{k} = \binom{n-1}{k-1} + \binom{n-1}{k}

Applying this to our summation, we have

\sum_{j=0}^{M-2} \binom{i-1}{j} = \sum_{j=0}^{M-2} \left (\binom{i-2}{j-1} + \binom{i-2}{j}\right ) \\
= \binom{i-2}{-1} + 2\times\sum_{j=0}^{M-3} \binom{i-2}{j} + \binom{i-2}{M-2} \\
= 0 + \left (2\times\sum_{j=0}^{M-3} \binom{i-2}{j} + \binom{i-2}{M-2} + \binom{i-2}{M-2}\right ) - \binom{i-2}{M-2} \\
= 2\times\sum_{j=0}^{M-2} \binom{i-2}{j} - \binom{i-2}{M-2}

Notice that we already have \sum_{j=0}^{M-2} \binom{i-2}{j}, so we just need to multiply it by 2 and subtract \binom{i-2}{M-2}.
With factorials and inverse factorials precomputed, each binomial costs \mathcal{O}(1), so this allows us to move from i to i+1 in \mathcal{O}(1) time.
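The update step can be sketched as follows. This is a minimal illustration assuming M \geq 2, using `math.comb` and exact integers in place of precomputed modular factorials; the function name is made up for this example. The running value starts at 1 for i = 2, since only \binom{0}{0} contributes there.

```python
from math import comb

def prefix_sums_by_recurrence(n, m):
    """Return {i: sum_{j=0}^{m-2} C(i-2, j)} for i = 2..n, assuming m >= 2.

    Uses S(i+1) = 2 * S(i) - C(i-2, m-2), as derived from Pascal's identity.
    """
    sums = {}
    val = 1  # i = 2: the sum is C(0, 0) = 1
    for i in range(2, n + 1):
        sums[i] = val
        val = 2 * val - comb(i - 2, m - 2)  # move from i to i + 1
    return sums
```

With M = 3, for example, the values for i = 2, 3, 4 come out as 1, 2, 3, matching \binom{i-2}{0} + \binom{i-2}{1} computed directly.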

Note that depending on your implementation, M = 0 and M = 1 might be special cases (with answers 1 and N+1 respectively).
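Putting the pieces together, one possible way to organize the whole count is sketched below. It uses exact integers and `math.comb` instead of arithmetic modulo 10^9+7, and the function name is illustrative. Here M = 0 is handled explicitly, while M = 1 falls out naturally because the bad-sequence loop is skipped.

```python
from math import comb

def count_sequences(n, m):
    """Count absolute mex sequences of length n with entries in [0, m]."""
    if m == 0:
        return 1  # only [0, 0, ..., 0]
    # Good sequences: sum_{k=0}^{m} C(n-1, k); terms with k > n-1 vanish.
    total = sum(comb(n - 1, k) for k in range(min(m, n - 1) + 1))
    total += 1  # [1, 1, ..., 1], which goes bad at index 1
    if m >= 2:
        val = 1  # sum_{j=0}^{m-2} C(i-2, j) at i = 2
        for i in range(2, n + 1):
            total += val * 2 ** (n - i)         # sequences first going bad at i
            val = 2 * val - comb(i - 2, m - 2)  # advance the inner sum to i + 1
    return total
```

For M = 1 this returns N+1 (for example, 5 when N = 4), and for N = 3, M = 5 it returns 9.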

TIME COMPLEXITY:

\mathcal{O}(N) per test case.

CODE:

Setter's code (C++)
#include<bits/stdc++.h>
#include<ext/pb_ds/assoc_container.hpp>
#include<ext/pb_ds/tree_policy.hpp>
#pragma GCC optimize "trapv"
#define F first
#define S second
// #define endl "\n"
#define Endl "\n"
#define fbo find_by_order
#define ook order_of_key
#define ll long long
#define ld long double
#define vl vector<long long>
#define pll pair<long long,long long>
#define sl set<long long>
#define uset unordered_set
#define umap unordered_map
#define prq priority_queue
#define pqll priority_queue<ll> 
#define pb push_back
#define ppb pop_back
#define mp make_pair
#define bpc(x) __builtin_popcount(x)
#define sz(v) (int)(v.size())
#define all(v) (v).begin(),(v).end() 
#define mem(a, val) memset(a, val, sizeof(a))
#define mem0(a) memset(a,0,sizeof(a))
#define mem1(a) memset(a,-1,sizeof(a))
#define N 1000000
#define N2 2000000
 
const long double EPS = 0.0001;
const long double PI = 3.141592653589793238;
const long long hell = 1000000007;
const long long mod = 998244353;
const long long INF = 1e16;
using namespace std;
using namespace __gnu_pbds;
typedef tree<ll, null_type, less<ll>, rb_tree_tag,tree_order_statistics_node_update> ordered_set;
mt19937 rng ((unsigned int) chrono::steady_clock::now().time_since_epoch().count());
 
template<typename T, typename U> static inline void amin(T &x, U y){ if(y < x) x = y; }
template<typename T, typename U> static inline void amax(T &x, U y){ if(x < y) x = y; }
ll power(ll x, ll y, ll p=hell) 
{ 
    ll res = 1;  
    x = x % p;  
    while (y > 0) 
    { 
        if (y & 1) 
            res = (res*x) % p; 
        y = y>>1; // y = y/2 
        x = (x*x) % p; 
    } 
    return res; 
} 
  
// Returns n^(-1) mod p 
ll modInverse(ll n, ll p=hell) 
{ 
    return power(n, p-2, p); 
} 
  
// Returns nCr % p using Fermat's little theorem. 
ll fac[N+1];
ll power2[N+1];
ll mInv[N+1];
ll facInv[N+1];
void pre(ll p=hell){
    fac[0] = 1; 
    power2[0]=1;
    mInv[0]=1;
    mInv[1]=1;
    facInv[0]=1;
    facInv[1]=1;
    for (ll i=1 ; i<=N; i++) {
        fac[i] = (fac[i-1]*i)%p;
        power2[i] = (power2[i-1]*2)%p;
        if(i>1){
            mInv[i]=(mInv[p%i]*(p-p/i))%p;
            facInv[i]=(facInv[i-1]*mInv[i])%p;
        }
    }
}
ll nCrModPFermat(ll n, ll r, ll p=hell) 
{ 
    if(r>n)
      return 0;
    if (r==0) 
      return 1; 
    
    return (fac[n]* facInv[r] % p * facInv[n-r] % p) % p; 
    // return (fac[n]* modInverse(fac[r],p) % p * modInverse(fac[n-r],p) % p) % p; 
}
void solve(ll n,ll m){

    m=min(n,m);
    if(m==0){
        cout<<1<<endl;
        return;
    }
    ll s=1;
    ll ans=0;

    if(m>1){
        ans=s*power2[n-2];
        for(ll i=2;i<=n-1;i++){
            // cout<<ans<<" ";
            ll temp=0;
            if(m>2){
                temp=(2*(s-1-nCrModPFermat(i-2,m-2)+hell))%hell;
                temp=(temp+1+nCrModPFermat(i-1,m-2))%hell;
            }
            else{
                temp=1;
            }
            s=temp;
            ans=(ans+(temp*power2[n-i-1])%hell)%hell;
        }
    }
    // cout<<ans;
    for(ll i=0;i<=m;i++){
        ans=(ans+nCrModPFermat(n-1,i))%hell;
    }
    cout<<(ans+1)%hell<<endl;

}
int main(){
   
    ios_base::sync_with_stdio(0);
    cin.tie(NULL);
    cout.tie(NULL);
    #ifndef ONLINE_JUDGE
        freopen("tests/output_10.in", "r", stdin);
        freopen("tests/output_10.out", "w", stdout);
    #endif

    ll t,n,m;
    cin>>t;
    pre();
    ll sum_n=0;
    ll sum_m=0;
    while(t--){
        cin>>n>>m;
        sum_n+=n;
        sum_m+=m;
        solve(n,m);
    }
    cerr<<sum_n<<" "<<sum_m;
    assert(sum_m<=N2);
    assert(sum_n<=N2);

}
Tester's code (C++)


// Problem: MEX Sequences
// Contest: CodeChef - STR84TST
// Memory Limit: 256 MB
// Time Limit: 3000 ms
// Author: abhidot

// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp> 
#include <ext/pb_ds/tree_policy.hpp>
#define int long long
#define ll long long
#define IOS std::ios::sync_with_stdio(false); cin.tie(NULL);cout.tie(NULL);
#define pb push_back
#define mod 1000000007
#define mod2 998244353
#define lld long double
#define pii pair<int, int>
#define ff first
#define ss second
#define all(x) (x).begin(), (x).end()
#define uniq(v) (v).erase(unique(all(v)),(v).end())
#define rep(i,x,y) for(int i=x; i<y; i++)
#define fill(a,b) memset(a, b, sizeof(a))
#define vi vector<int>
#define V vector
#define setbits(x) __builtin_popcountll(x)
#define w(x)  int x; cin>>x; while(x--)
using namespace std;
using namespace __gnu_pbds; 
template <typename num_t> using ordered_set = tree<num_t, null_type, less<num_t>, rb_tree_tag, tree_order_statistics_node_update>;
const long long N=1000005, INF=2000000000000000000, inf = 2e9+5;
 
int power(int a, int b, int p){
	if(a==0)
	return 0;
	int res=1;
	a%=p;
	while(b>0)
	{
		if(b&1)
		res=(res*a)%p;
		b>>=1;
		a=(a*a)%p;
	}
	return res;
}
 
 
void print(bool n){
    if(n){
        cout<<"YES\n";
    }else{
        cout<<"NO\n";
    }
}


int f[N], in[N]; 

int ncr(int n, int r){
	if(n<0) return 0; 
	if(n<r) return 0;
	return f[n]*in[r]%mod*in[n-r]%mod;
}

int ff(int n, int m){
	int ans=0;
	for(int i=0;i<=m;i++){
		ans = (ans + ncr(n-1, i))%mod;
	}
	return ans;
}
 
int32_t main()
{
    IOS;
    f[0]=1, in[0]=1;
    for(int i=1;i<N;i++){
    	f[i]=i*f[i-1]%mod;
    }
    in[N-1]=power(f[N-1], mod-2, mod);
    for(int i=N-2;i>=0;i--){
    	in[i]=(i+1)*in[i+1]%mod;
    }
    
    int in2 = power(2, mod-2, mod);
		w(T){
			int n, m;
			cin>>n>>m;
			m = min(m, n);
			{
				int typ1 = ff(n, m);
				int typ2 = 0, prv = 0;
				for(int i=1;i<n&&m>=2;i++){
					if(i==1) prv = 1;
					else{
						int here = (2*prv%mod - ncr(i-2, min(i-2, m-2)))%mod;
						if(here<0) here+=mod;
						if(i-1<=m-2) here = (here + ncr(i-1, i-1))%mod;
						prv = here;
					}
					// cout<<prv<<" ";
					typ2 = (typ2 + prv*power(2, n-1-i, mod)%mod)%mod;
				}
				int typ3 = (m>0);
				// cout<<typ1<<" "<<typ2<<" "<<typ3<<"\n";
				int ans = (typ1 + typ2 + typ3)%mod;
				cout<<ans<<"\n";
			}
		}
}

Editorialist's code (Python)
mod = 10**9 + 7
maxn = 2 * 10**6 + 10
fac = [i for i in range(maxn)]
fac[0] = 1
for i in range(1, maxn):
	fac[i] *= fac[i-1]
	fac[i] %= mod

ifac = [0]*maxn
ifac[-1] = pow(fac[-1], mod-2, mod)
for i in reversed(range(maxn-1)):
    ifac[i] = (i+1) * ifac[i+1] % mod

def C(n, r):
	if n < r or r < 0: return 0
	return fac[n] * ifac[r] * ifac[n-r]  % mod

for _ in range(int(input())):
	n, m = map(int, input().split())
	if m == 0:
		print(1)
		continue
	if m == 1:
		print(n+1)
		continue
	
	ans = 1
	for i in range(m+1): ans += C(n-1, i)
	ans %= mod
	val = 1
	for i in range(2, n+1):
		# i is the first position where it goes bad
		# 2^(n-i) * sum(C(i-2, k) for k <= m-2)
		ans += val * pow(2, n-i, mod)
		ans %= mod
		
		val *= 2
		val -= C(i-2, m-2)
		val %= mod
	print(ans)

When N = 3 and M = 2 (The 1st sample test case), why is [1, 1, 2] a disallowed array?

A_3 = 2
\text{MEX}(A_1, A_2, A_3) = \text{MEX}(1, 1, 2) = 0.
|2 - 0| = 2 \gt 1

Where is the “WA Failed Testcases” feature? I would love to see where my code fails.

Looking at your latest submission, at the very least I can say that if (n-1 <= m) ans = pow(2, n) is wrong.
For N = 3 and M = 5 the answer is 9.

Edit: the other part looks wrong too, fwiw.


Oh, understood.
Thanks!

Can someone explain the implementation for counting the bad sequences? I understood the math equation for number of sequences going bad at i.