MEXPERMHD - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: satyam_343
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

Medium

PREREQUISITES:

Sorting, ordered set/segment tree/fenwick tree

PROBLEM:

You’re given an array A of length N.
Count the number of permutations P of [0, 1, \ldots, N-1] such that, for each i,
A_i = \text{MEX}(P_1, P_2, \ldots, P_i) or A_i = \text{MEX}(P_i, P_{i+1}, \ldots, P_N).
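
To make the condition concrete, here is a minimal brute-force sketch that checks it for every permutation. It is only meant as a reference for tiny N (say, for stress-testing a faster solution); the names `mex_of_range` and `brute_count` are made up for this illustration, and indices are 0-based.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

// MEX of the values p[lo..hi] (inclusive, 0-indexed).
int mex_of_range(const std::vector<int>& p, int lo, int hi) {
    std::vector<bool> seen(p.size() + 1, false);
    for (int i = lo; i <= hi; ++i) seen[p[i]] = true;
    int m = 0;
    while (seen[m]) ++m;  // seen[p.size()] is always false, so this stops
    return m;
}

// Counts permutations p of {0, ..., n-1} such that every a[i] equals
// the MEX of p[0..i] or the MEX of p[i..n-1]. Factorial time.
long long brute_count(const std::vector<int>& a) {
    int n = static_cast<int>(a.size());
    assert(n <= 10);  // anything larger is hopeless for brute force
    std::vector<int> p(n);
    std::iota(p.begin(), p.end(), 0);
    long long count = 0;
    do {
        bool ok = true;
        for (int i = 0; i < n && ok; ++i) {
            ok = a[i] == mex_of_range(p, 0, i) ||
                 a[i] == mex_of_range(p, i, n - 1);
        }
        if (ok) ++count;
    } while (std::next_permutation(p.begin(), p.end()));
    return count;
}
```

For instance, with A = [1, 2] the only valid permutation is [0, 1], since A_1 = 1 forces P_1 = 0.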

EXPLANATION:

Let’s recall the solution to the easy version, where we only needed to find any valid permutation.
Barring a couple of edge cases, the main idea was:

  • Fix the split between prefix and suffix MEX-es for the non-zero elements of A.
    There are only two options for this split.
  • Once the split is fixed, each element will have a certain interval of positions it can appear in.
  • These intervals are special, in that every pair of them is either disjoint, or will have one contained inside the other.
  • This allowed us to greedily fill in elements in ascending order of interval size.
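
The "disjoint or nested" property from the list above can be checked directly on small inputs. The following is a quadratic sketch intended for debugging only; `is_laminar` is a hypothetical helper name, and intervals are closed pairs (l, r).

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Returns true if every pair of closed intervals is either disjoint,
// or has one interval contained in the other.
bool is_laminar(const std::vector<std::pair<int, int>>& iv) {
    for (std::size_t i = 0; i < iv.size(); ++i) {
        for (std::size_t j = i + 1; j < iv.size(); ++j) {
            auto [l1, r1] = iv[i];
            auto [l2, r2] = iv[j];
            assert(l1 <= r1 && l2 <= r2);  // intervals must be non-empty
            bool disjoint = r1 < l2 || r2 < l1;
            bool nested = (l1 <= l2 && r2 <= r1) || (l2 <= l1 && r1 <= r2);
            if (!disjoint && !nested) return false;
        }
    }
    return true;
}
```

If this check fails on the intervals your implementation produces, the greedy (and the counting argument that follows) should not be expected to work.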

This solution ends up being extremely easy to extend to the counting version!

First, observe that the two different prefix/suffix splits will always give us different sets of permutations. This allows us to just try both of them, count the number of permutations, and add them up to obtain the answer.

Now, let’s look at a single split.
Suppose we’ve computed the intervals [l_i, r_i] for the elements. Without loss of generality, let’s assume they’re sorted in ascending order of length, i.e. r_i - l_i \leq r_{i+1} - l_{i+1}.
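Then, to construct an answer we iterate through these intervals, and each time try to pick some empty position within them.
However, it really doesn’t matter which empty position is chosen: an earlier (shorter) interval is either contained in the current one or disjoint from it, so the number of positions already taken inside the current interval is the same no matter what choices were made earlier.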

So,

  • For [l_0, r_0], we have r_0 - l_0 + 1 choices of where to place 0.
  • For [l_1, r_1], we have r_1 - l_1 + 1 choices of where to place 1. However, if 0 was already placed at one of these positions, that’s one less option.
  • In general, for [l_i, r_i], we have r_i - l_i + 1 options, minus the number of elements already placed inside it.

To put it simply, the number of options we have for each interval equals the length of the interval, minus the number of previously processed intervals that lie inside of it.
Computing this is a classical problem (see CSES: Nested Ranges Count), and can be done in \mathcal{O}(N\log N).
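
As an illustration, here is one way this count and the resulting product could be computed. This sketch uses a Fenwick tree over right endpoints rather than the ordered set used in the editorialist's code below; it assumes the intervals are closed, lie within [0, n), and form a nested-or-disjoint family, so that a previously processed interval lies inside [l, r] exactly when its right endpoint does. The names `Fenwick` and `count_fillings` are made up here, and the modulus 998244353 is taken from the editorialist's code.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

constexpr long long MOD = 998244353;

// Fenwick tree that counts inserted indices.
struct Fenwick {
    std::vector<int> t;
    explicit Fenwick(int n) : t(n + 1, 0) {}
    void add(int i) {
        for (++i; i < static_cast<int>(t.size()); i += i & -i) ++t[i];
    }
    // Number of inserted indices in [0, i]; prefix(-1) is 0.
    int prefix(int i) const {
        int s = 0;
        for (++i; i > 0; i -= i & -i) s += t[i];
        return s;
    }
};

// Product over all intervals (shortest first) of
// (length - number of earlier intervals inside), modulo MOD.
long long count_fillings(std::vector<std::pair<int, int>> iv, int n) {
    std::sort(iv.begin(), iv.end(), [](const auto& x, const auto& y) {
        return x.second - x.first < y.second - y.first;
    });
    Fenwick fw(n);
    long long ways = 1;
    for (auto [l, r] : iv) {
        assert(0 <= l && l <= r && r < n);
        // Earlier intervals are no longer than this one, so (by the nesting
        // assumption) they are inside [l, r] iff their right endpoint is.
        int inside = fw.prefix(r) - fw.prefix(l - 1);
        long long options = (r - l + 1) - inside;
        if (options <= 0) return 0;
        ways = ways * options % MOD;
        fw.add(r);
    }
    return ways;
}
```

For the intervals [0, 0], [1, 2], [0, 2] this gives 1 · 2 · (3 − 2) = 2, and for three copies of [0, 2] it gives 3 · 2 · 1 = 6, as expected.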

Once all these values are known, simply multiplying them gives us the number of permutations.
Each split is thus processed in \mathcal{O}(N\log N) time, and there are only two splits so that’s \mathcal{O}(N\log N) overall, with only minor modifications to the solution from the construction version!

There are a couple of details to think about here: mostly, the position of 0 itself.
We obtain a range where 0 can be present, but it can also only be present at a non-zero value of A.
So, for 0 alone, its count is not the length of the interval, but the number of non-zero values of A inside the interval (which is at most 2, so it’s feasible to try each of them separately if you wish to).
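
A small sketch of collecting the candidate positions for 0, assuming its interval [l, r] has already been computed (0-indexed, inclusive). The helper name `zero_candidates` is hypothetical.

```cpp
#include <cassert>
#include <vector>

// Indices i in [l, r] with a[i] != 0, i.e. the only places where the
// value 0 may be put once its interval [l, r] is known.
std::vector<int> zero_candidates(const std::vector<int>& a, int l, int r) {
    assert(0 <= l && r < static_cast<int>(a.size()));
    std::vector<int> res;
    for (int i = l; i <= r; ++i) {
        if (a[i] != 0) res.push_back(i);
    }
    return res;
}
```

Each candidate can then be handled by temporarily shrinking the interval of 0 to that single position and running the usual count, which is what the editorialist's code does.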


Note that depending on your implementation, you might have some edge cases to deal with.
One of them is when the minimum non-zero element is N, in which case you might have to be careful about overcounting.
However, this is not too hard to handle separately.
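
For reference, here is a sketch of that special case, mirroring how the editorialist's code handles it. It assumes every A_i is either 0 or N. A value of N can only be matched at the first or last index (the whole array is the only prefix/suffix with MEX N), and A_i = 0 is satisfied exactly when P_i ≠ 0, so 0 must go to an index with non-zero A_i while the remaining N-1 values can be arranged freely. The function name is made up for this illustration.

```cpp
#include <cassert>
#include <vector>

constexpr long long MOD = 998244353;

// Answer when every a[i] is either 0 or n (n = a.size()).
long long count_when_min_nonzero_is_n(const std::vector<int>& a) {
    int n = static_cast<int>(a.size());
    long long nonzero = 0;
    for (int i = 0; i < n; ++i) {
        assert(a[i] == 0 || a[i] == n);  // precondition of this special case
        if (a[i] == n && i != 0 && i != n - 1) return 0;  // N strictly inside
        if (a[i] != 0) ++nonzero;
    }
    long long ways = nonzero % MOD;  // choices for the position of 0
    for (int k = 1; k <= n - 1; ++k) ways = ways * k % MOD;  // (n-1)!
    return ways;
}
```

For example, A = [3, 0, 3] gives 2 · 2! = 4: all permutations except the two with 0 in the middle.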

TIME COMPLEXITY:

\mathcal{O}(N\log N) per testcase.

CODE:

Editorialist's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

#include <bits/extc++.h>
using namespace __gnu_pbds;
template<class T>
using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;

/**
 * Integers modulo p, where p is a prime
 * Source: Aeren (modified from tourist?)
 *         Modmul for 64-bit mod from kactl:ModMulLL
 * Works with p < 7.2e18 with x87 80-bit long double, and p < 2^52 ~ 4.5e15 with 64-bit
 */
template<typename T>
struct Z_p{
	using Type = typename decay<decltype(T::value)>::type;
	static vector<Type> MOD_INV;
	constexpr Z_p(): value(){ }
	template<typename U> Z_p(const U &x){ value = normalize(x); }
	template<typename U> static Type normalize(const U &x){
		Type v;
		if(-mod() <= x && x < mod()) v = static_cast<Type>(x);
		else v = static_cast<Type>(x % mod());
		if(v < 0) v += mod();
		return v;
	}
	const Type& operator()() const{ return value; }
	template<typename U> explicit operator U() const{ return static_cast<U>(value); }
	constexpr static Type mod(){ return T::value; }
	Z_p &operator+=(const Z_p &otr){ if((value += otr.value) >= mod()) value -= mod(); return *this; }
	Z_p &operator-=(const Z_p &otr){ if((value -= otr.value) < 0) value += mod(); return *this; }
	template<typename U> Z_p &operator+=(const U &otr){ return *this += Z_p(otr); }
	template<typename U> Z_p &operator-=(const U &otr){ return *this -= Z_p(otr); }
	Z_p &operator++(){ return *this += 1; }
	Z_p &operator--(){ return *this -= 1; }
	Z_p operator++(int){ Z_p result(*this); *this += 1; return result; }
	Z_p operator--(int){ Z_p result(*this); *this -= 1; return result; }
	Z_p operator-() const{ return Z_p(-value); }
	template<typename U = T>
	typename enable_if<is_same<typename Z_p<U>::Type, int>::value, Z_p>::type &operator*=(const Z_p& rhs){
		#ifdef _WIN32
		uint64_t x = static_cast<int64_t>(value) * static_cast<int64_t>(rhs.value);
		uint32_t xh = static_cast<uint32_t>(x >> 32), xl = static_cast<uint32_t>(x), d, m;
		asm(
			"divl %4; \n\t"
			: "=a" (d), "=d" (m)
			: "d" (xh), "a" (xl), "r" (mod())
		);
		value = m;
		#else
		value = normalize(static_cast<int64_t>(value) * static_cast<int64_t>(rhs.value));
		#endif
		return *this;
	}
	template<typename U = T>
	typename enable_if<is_same<typename Z_p<U>::Type, int64_t>::value, Z_p>::type &operator*=(const Z_p &rhs){
		uint64_t ret = static_cast<uint64_t>(value) * static_cast<uint64_t>(rhs.value) - static_cast<uint64_t>(mod()) * static_cast<uint64_t>(1.L / static_cast<uint64_t>(mod()) * static_cast<uint64_t>(value) * static_cast<uint64_t>(rhs.value));
		value = normalize(static_cast<int64_t>(ret + static_cast<uint64_t>(mod()) * (ret < 0) - static_cast<uint64_t>(mod()) * (ret >= static_cast<uint64_t>(mod()))));
		return *this;
	}
	template<typename U = T>
	typename enable_if<!is_integral<typename Z_p<U>::Type>::value, Z_p>::type &operator*=(const Z_p &rhs){
		value = normalize(value * rhs.value);
		return *this;
	}
	template<typename U>
	Z_p &operator^=(U e){
		if(e < 0) *this = 1 / *this, e = -e;
		Z_p res = 1;
		for(; e; *this *= *this, e >>= 1) if(e & 1) res *= *this;
		return *this = res;
	}
	template<typename U>
	Z_p operator^(U e) const{
		return Z_p(*this) ^= e;
	}
	Z_p &operator/=(const Z_p &otr){
		Type a = otr.value, m = mod(), u = 0, v = 1;
		if(a < (int)MOD_INV.size()) return *this *= MOD_INV[a];
		while(a){
			Type t = m / a;
			m -= t * a; swap(a, m);
			u -= t * v; swap(u, v);
		}
		assert(m == 1);
		return *this *= u;
	}
	template<typename U> friend const Z_p<U> &abs(const Z_p<U> &v){ return v; }
	Type value;
};
template<typename T> bool operator==(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value == rhs.value; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator==(const Z_p<T>& lhs, U rhs){ return lhs == Z_p<T>(rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator==(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) == rhs; }
template<typename T> bool operator!=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return !(lhs == rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator!=(const Z_p<T> &lhs, U rhs){ return !(lhs == rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator!=(U lhs, const Z_p<T> &rhs){ return !(lhs == rhs); }
template<typename T> bool operator<(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value < rhs.value; }
template<typename T> bool operator>(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value > rhs.value; }
template<typename T> bool operator<=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value <= rhs.value; }
template<typename T> bool operator>=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value >= rhs.value; }
template<typename T> Z_p<T> operator+(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator+(const Z_p<T> &lhs, U rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator+(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T> Z_p<T> operator-(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator-(const Z_p<T>& lhs, U rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator-(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T> Z_p<T> operator*(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator*(const Z_p<T>& lhs, U rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator*(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T> Z_p<T> operator/(const Z_p<T> &lhs, const Z_p<T> &rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator/(const Z_p<T>& lhs, U rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator/(U lhs, const Z_p<T> &rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T> istream &operator>>(istream &in, Z_p<T> &number){
	typename common_type<typename Z_p<T>::Type, int64_t>::type x;
	in >> x;
	number.value = Z_p<T>::normalize(x);
	return in;
}
template<typename T> ostream &operator<<(ostream &out, const Z_p<T> &number){ return out << number(); }

/*
using ModType = int;
struct VarMod{ static ModType value; };
ModType VarMod::value;
ModType &mod = VarMod::value;
using Zp = Z_p<VarMod>;
*/

// constexpr int mod = 1e9 + 7; // 1000000007
constexpr int mod = (119 << 23) + 1; // 998244353
// constexpr int mod = 1e9 + 9; // 1000000009
using Zp = Z_p<integral_constant<decay<decltype(mod)>::type, mod>>;

template<typename T> vector<typename Z_p<T>::Type> Z_p<T>::MOD_INV;
template<typename T = integral_constant<decay<decltype(mod)>::type, mod>>
void precalc_inverse(int SZ){
	auto &inv = Z_p<T>::MOD_INV;
	if(inv.empty()) inv.assign(2, 1);
	for(; inv.size() <= SZ; ) inv.push_back((mod - 1LL * mod / (int)inv.size() * inv[mod % (int)inv.size()]) % mod);
}

template<typename T>
vector<T> precalc_power(T base, int SZ){
	vector<T> res(SZ + 1, 1);
	for(auto i = 1; i <= SZ; ++ i) res[i] = res[i - 1] * base;
	return res;
}

template<typename T>
vector<T> precalc_factorial(int SZ){
	vector<T> res(SZ + 1, 1); res[0] = 1;
	for(auto i = 1; i <= SZ; ++ i) res[i] = res[i - 1] * i;
	return res;
}

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    auto fac = precalc_factorial<Zp>(1000010);

    auto solve = [&] (vector<int> a) -> Zp {
        int n = size(a);
        Zp ans = 0;
        if (ranges::max(a) == 0) return ans;

        int mn = n;
        for (int x : a) {
            if (x > 0) mn = min(mn, x);
        }

        for (int i = 1; i+1 < n; ++i)
            if (a[i] == n) return ans;
        if (a[0] > 1 and a[0] < n) return ans;
        if (a[n-1] > 1 and a[n-1] < n) return ans;
        
        int L = 0;
        while (a[L] != mn) ++L;
        vector<array<int, 2>> where(n+1, array{0, n-1});

        where[mn][0] = L+1;
        // Prefix
        int p = mn;
        for (int i = L+1; i < n; ++i) {
            if (a[i] == 0) continue;
            if (a[i] < p) return ans;
            if (a[i] == n) continue;

            // a[i] should occur > i
            // everything in [p, a[i] - 1] should occur <= i
            while (p < a[i]) {
                where[p][1] = min(where[p][1], i);
                ++p;
            }
            where[a[i]][0] = max(where[a[i]][0], i+1);
        }

        // Suffix
        p = mn;
        for (int i = L-1; i >= 0; --i) {
            if (a[i] == 0) continue;
            if (a[i] < p) return ans;
            if (a[i] == n) continue;

            while (p < a[i]) {
                where[p][0] = max(where[p][0], i);
                ++p;
            }
            where[a[i]][1] = min(where[a[i]][1], i-1);
        }

        // [0, mn-1] only to the left
        for (int i = L; ; --i) {
            if (i == 0 or (a[i] < n and a[i] > mn)) {
                for (int x = 0; x < mn; ++x)
                    where[x] = {i, L};
                break;
            }
        }

        // 0 can either be at L, or at first non-zero index before L
        // Try both
        auto calc = [&] () {
            vector<array<int, 3>> segs;
            for (int i = 0; i < n; ++i) {
                if (where[i][1] < 0 or where[i][0] >= n or where[i][0] > where[i][1]) return Zp(0);
                segs.push_back({where[i][0], where[i][1], i});
            }
            ranges::sort(segs, [] (auto x, auto y) {return x[1] - x[0] < y[1] - y[0];});

            Tree<array<int, 2>> inside;
            Zp ways = 1;
            for (auto [l, r, i] : segs) {
                Zp cur = r - l + 1;
                cur -= inside.order_of_key({r, n});
                cur += inside.order_of_key({l, -1});
                ways *= cur;
                inside.insert({r, i});
            }

            return ways;
        };
        
        where[0] = {L, L};
        ans += calc();
        for (int i = L-1; i >= 0; --i) {
            if (a[i] > 0) {
                where[0] = {i, i};
                ans += calc();
                break;
            }
        }
        return ans;
    };
    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        vector a(n, 0);
        for (int &x : a) cin >> x;

        int mn = n;
        for (int x : a) {
            if (x > 0) mn = min(mn, x);
        }
        if (mn == n) {
            // 0 can't be placed at an index where a[i] == 0, everything else is ok
            int zeros = ranges::count(a, 0);
            Zp ans = fac[n-1] * (n - zeros);
            for (int i = 1; i+1 < n; ++i)
                if (a[i] == n) ans = 0;
            cout << ans << '\n';
            continue;
        }

        auto ans = solve(a);
        ranges::reverse(a);
        ans += solve(a);
        cout << ans << '\n';
    }
}