PERMDEL - Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3

Author: Danny Boy
Tester: Takuki Kurokawa
Editorialist: Daanish Mahajan

DIFFICULTY:

Medium

PREREQUISITES:

Segment tree, Prefix Sums, Combinatorics

PROBLEM:

You are given an array A of length N, which is initially a permutation of the integers from 1 to N.
Find the number of ways to perform N operations, where in each operation we choose any index i and remove A_i from A (which decreases the size of A by 1), such that after every operation the array does not contain any consecutive monotone triple (three adjacent elements in strictly increasing or strictly decreasing order). Print the answer modulo 10^9+7.
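For very small inputs the answer can be checked by brute force. The sketch below is a hypothetical helper (`countWays` is not part of any official solution): it tries every removal order and, as an assumption, also rejects an initial array that already contains a consecutive monotone triple, since the explanation below treats the current state as always valid. Its running time is exponential, so it is only meant for cross-checking tiny cases.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// True if no three consecutive elements are strictly increasing or strictly decreasing.
bool isValidState(const std::vector<int>& v) {
    for (std::size_t i = 2; i < v.size(); ++i) {
        bool inc = v[i - 2] < v[i - 1] && v[i - 1] < v[i];
        bool dec = v[i - 2] > v[i - 1] && v[i - 1] > v[i];
        if (inc || dec) return false;
    }
    return true;
}

// Number of orders in which all elements can be removed one at a time so that
// every array along the way (assumption: including the initial one) is valid.
// No modulus is applied, so this is only meant for tiny inputs.
std::uint64_t countWays(const std::vector<int>& v) {
    assert(v.size() <= 10);  // exponential running time: keep N tiny
    if (!isValidState(v)) return 0;
    if (v.empty()) return 1;
    std::uint64_t total = 0;
    for (std::size_t i = 0; i < v.size(); ++i) {
        std::vector<int> rest(v);
        rest.erase(rest.begin() + static_cast<std::ptrdiff_t>(i));
        total += countWays(rest);
    }
    return total;
}
```

For example, countWays({1, 3, 2}) is 3! = 6 and countWays({2, 1, 4, 3, 5}) is 48.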

EXPLANATION:

Hint 1

At any point, we can remove only one of at most 4 elements: the first two and the last two.

Hint 2

Try to express the answer in terms of the last 3 remaining elements and the N - 3 elements removed before them.

Hint 3

Suppose the indices of the last 3 remaining elements are L, M and R. Then the number of ways to end with:
L as the left of the 3 remaining elements is independent of the indices in \{M + 1, M + 2, \ldots, N\} \setminus \{R\}, and
R as the right of the 3 remaining elements is independent of the indices in \{1, 2, \ldots, M - 1\} \setminus \{L\}.

Solution

Suppose that after removing x elements the array is A_x, of size sz_x. To reach A_{x + 1}, we can remove only one of at most 4 elements:

  1. the first one, (A_x)_1 (easy to see);
  2. the second one, (A_x)_2;
  3. the last one, (A_x)_{sz_x} (easy to see);
  4. the second-to-last one, (A_x)_{sz_x - 1}.
Proof

Let the current state of the array be B.
Note that B must not contain any consecutive monotone triple, so its elements alternate up and down. In particular, the first 5 elements of B satisfy one of:

  1. B_1\gt B_2 \lt B_3 \gt B_4 \lt B_5 or,
  2. B_1\lt B_2 \gt B_3 \lt B_4 \gt B_5.
Why other elements can't be removed

Suppose the first relation holds (the proof for the second is similar).
After removing B_3, one of the following holds:

  1. \color{red}{B_1\gt B_2 \gt B_4} \lt B_5 or,
  2. B_1\gt \color{red}{B_2 \lt B_4 \lt B_5}.
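Both contain a consecutive monotone triple (highlighted in red). The same argument applies to any element with at least two elements on each side: its two neighbours are both smaller or both larger than it, so removing it always creates a monotone triple together with the element two positions to its left or two positions to its right.
Hence nothing except the 4 elements mentioned above can be removed.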
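This claim can be spot-checked mechanically. The hypothetical helper below (`removablePositions` is an illustrative name) tries each single removal from a valid array B and returns the 0-based positions after whose removal the array is still valid; by the argument above, every returned position is among the first two or the last two.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// True if no three consecutive elements are strictly increasing or strictly decreasing.
bool hasNoMonotoneTriple(const std::vector<int>& v) {
    for (std::size_t i = 2; i < v.size(); ++i) {
        bool inc = v[i - 2] < v[i - 1] && v[i - 1] < v[i];
        bool dec = v[i - 2] > v[i - 1] && v[i - 1] > v[i];
        if (inc || dec) return false;
    }
    return true;
}

// 0-based positions whose removal leaves the valid array b still valid.
std::vector<std::size_t> removablePositions(const std::vector<int>& b) {
    assert(hasNoMonotoneTriple(b));  // the proof starts from a valid state B
    std::vector<std::size_t> result;
    for (std::size_t i = 0; i < b.size(); ++i) {
        std::vector<int> rest(b);
        rest.erase(rest.begin() + static_cast<std::ptrdiff_t>(i));
        if (hasNoMonotoneTriple(rest)) result.push_back(i);
    }
    return result;
}
```

For B = (3, 1, 5, 2, 6, 4, 7) it returns {0, 1, 6}: the middle elements never qualify, and here the second-to-last element (4) does not either, because removing it makes 2, 6, 7 adjacent.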
When the second and second-to-last elements can be removed

We discuss the second element; the second-to-last element is handled symmetrically.
The second element can be removed when, for relation 1 and relation 2 respectively:

  1. B_1 \lt B_3
  2. B_1 \gt B_3

After that, B_3 becomes the second element, and it can in turn be removed when, for relation 1 and relation 2 respectively:

  1. B_1 \gt B_4
  2. B_1 \lt B_4

\ldots

Terminating Condition

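This process continues up to the first index x at which, for relation 1 and relation 2 respectively,

  1. B_1 \lt B_{x + 2}
  2. B_1 \gt B_{x + 2}

when x is even, and

  1. B_1 \gt B_{x + 2}
  2. B_1 \lt B_{x + 2}

when x is odd,
i.e., B_{x + 1} cannot be removed without removing B_1 first. Equivalently, once B_2, \ldots, B_x have been removed, B_{x + 1} can be removed as the second element if and only if B_1 and B_{x + 1} lie on the same side of B_{x + 2}.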
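The terminating condition can be turned directly into a quadratic simulation of the range array defined below (range[j] is the last index that can be removed as the second element while A_j is the first). In this sketch, `rangeBySimulation` is a hypothetical helper, the array is 1-indexed with a[0] unused, and the same-side test decides whether A_k can go; to match the editorial's range formula, the result is N when the chain never breaks.

```cpp
#include <cassert>
#include <vector>

// a is 1-indexed: a[1..n] holds the array, a[0] is unused.
// Returns rng[1..n], where rng[j] is the largest k such that, with a[j] as the
// first element, a[j+1], ..., a[k] can be removed one by one as the second element.
// rng[j] == j means not even a[j+1] can be removed this way.
std::vector<int> rangeBySimulation(const std::vector<int>& a) {
    assert(!a.empty());
    const int n = static_cast<int>(a.size()) - 1;
    std::vector<int> rng(n + 1, 0);
    for (int j = 1; j <= n; ++j) {
        rng[j] = n;  // chain never breaks unless the test below fails
        for (int k = j + 1; k <= n - 1; ++k) {
            // a[k] is removable iff a[j] and a[k] lie on the same side of a[k+1]
            if ((a[k] < a[k + 1]) != (a[j] < a[k + 1])) {
                rng[j] = k - 1;
                break;
            }
        }
    }
    return rng;
}
```

For A = (2, 1, 4, 3, 5) it returns range[1..5] = (2, 2, 5, 5, 5).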

As in Hint 3, let L, M and R be the indices of the last 3 elements left after the other N - 3 removals. The number of ways to end with L on the left is independent of the indices in \{M + 1, M + 2, \ldots, N\} \setminus \{R\}, and the number of ways to end with R on the right is independent of the indices in \{1, 2, \ldots, M - 1\} \setminus \{L\}.
Hence the removals before M and the removals after M contribute independently.

So we can fix M and count the ways to end with (L, M, R) as the last 3 elements as follows.
Iterate over i from 2 to N - 1 and take M = i.
For every index j \lt i, count the ways to remove all elements before i except A_j, so that j ends up as L. Sum these counts over all valid j:

  1. over j with A_j \lt A_i, and call the sum mulL[i][0];
  2. over j with A_j \gt A_i, and call the sum mulL[i][1].

Similarly, for every index k \gt i, count the ways to remove all elements after i except A_k, so that k ends up as R. Sum these counts over all valid k:

  1. over k with A_k \lt A_i, and call the sum mulR[i][0];
  2. over k with A_k \gt A_i, and call the sum mulR[i][1].

Finally, the triple (j, i, k) is valid if either both A_j and A_k are smaller than A_i, or both are larger.

Our answer = (\sum_{i = 2}^{N - 1} ((mulL[i][0] \cdot mulR[i][0] + mulL[i][1] \cdot mulR[i][1]) \cdot \binom{N - 3}{i - 2})) \cdot 3!.

where \binom{N - 3}{i - 2} counts the ways to choose which i - 2 of the N - 3 removals delete an index smaller than M (the prefix and suffix removals interleave freely), and 3! counts the orders in which the last 3 elements can be removed, since any order is valid.
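A quadratic reference implementation of this formula is handy for testing a faster solution. The sketch below is hypothetical (`countRemovals` and `prefixSide` are illustrative names, not from the official solutions): it evaluates range (via the same-side test from the terminating condition), ways and mulL with plain loops straight from their definitions in this editorial, and obtains mulR by running the same routine on the reversed array, where index i becomes N - i + 1. It assumes N \ge 3.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

using i64 = std::int64_t;
constexpr i64 MOD = 1'000'000'007;

i64 power(i64 b, i64 e) {
    i64 r = 1;
    for (b %= MOD; e > 0; e >>= 1, b = b * b % MOD)
        if (e & 1) r = r * b % MOD;
    return r;
}

i64 binom(int n, int k) {
    if (k < 0 || k > n) return 0;
    i64 num = 1, den = 1;
    for (int i = 0; i < k; ++i) {
        num = num * (n - i) % MOD;
        den = den * (i + 1) % MOD;
    }
    return num * power(den, MOD - 2) % MOD;  // Fermat inverse, MOD is prime
}

// a is 1-indexed (a[0] unused). Returns mul[i] = {mulL[i][0], mulL[i][1]},
// evaluated straight from the definitions in O(n^2).
std::vector<std::array<i64, 2>> prefixSide(const std::vector<int>& a) {
    const int n = static_cast<int>(a.size()) - 1;
    std::vector<int> rng(n + 1, n);  // range[j]
    for (int j = 1; j <= n; ++j)
        for (int k = j + 1; k <= n - 1; ++k)
            if ((a[k] < a[k + 1]) != (a[j] < a[k + 1])) { rng[j] = k - 1; break; }
    std::vector<i64> ways(n, 0);
    ways[0] = 1;
    for (int i = 1; i < n; ++i) {
        ways[i] = ways[i - 1];
        for (int j = 1; j < i; ++j)
            if (rng[j] >= i) ways[i] = (ways[i] + ways[j - 1]) % MOD;
    }
    std::vector<std::array<i64, 2>> mul(n + 1);  // zero-initialised
    for (int i = 2; i <= n - 1; ++i)
        for (int j = 1; j < i; ++j)
            if (rng[j] >= i - 1) {
                const int side = a[j] > a[i] ? 1 : 0;
                mul[i][side] = (mul[i][side] + ways[j - 1]) % MOD;
            }
    return mul;
}

// a0: the array as a 0-indexed permutation of 1..n, n >= 3.
i64 countRemovals(const std::vector<int>& a0) {
    const int n = static_cast<int>(a0.size());
    assert(n >= 3);
    std::vector<int> a(n + 1, 0), rev(n + 1, 0);
    for (int i = 1; i <= n; ++i) {
        a[i] = a0[i - 1];
        rev[i] = a0[n - i];  // index i of a is index n - i + 1 of rev
    }
    const auto mulL = prefixSide(a), mulR = prefixSide(rev);
    i64 ans = 0;
    for (int i = 2; i <= n - 1; ++i) {
        const auto& l = mulL[i];
        const auto& r = mulR[n - i + 1];
        const i64 match = (l[0] * r[0] + l[1] * r[1]) % MOD;
        ans = (ans + match * binom(n - 3, i - 2)) % MOD;
    }
    return ans * 6 % MOD;  // 3! orders for the last three elements
}
```

For example, it returns 48 for (2, 1, 4, 3, 5), which agrees with enumerating every removal order by hand.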

We explain the computation only for prefixes; suffixes are handled in the same way on the reversed array.

Calculating mulL array

Let op1 denote removing the first element and op2 removing the second element.
Define ways[i] as the number of ways to remove the prefix of length i.
Define range[i] as the maximum index j (\ge i) such that, while index i is the first element, the elements at indices i + 1, i + 2, \ldots, j can all be removed as the second element using only op2 (never op1); range[i] = i means that not even index i + 1 can be removed this way.
So:
ways[0] = 1
ways[i] = ways[i - 1] + \sum\limits_{\substack{j = 1 \\ range[j] \ge i}}^{i - 1} ways[j - 1] \quad (i \ge 1)

  • First term: after removing the prefix of length i - 1, perform op1.
  • Second term: after removing the prefix of length j - 1, perform i - j op2 operations (removing indices j + 1, \ldots, i) and then one op1 operation (removing index j).
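The recurrence can be evaluated literally in O(N^2), which gives a reference for the faster evaluation described later. In this sketch `waysDirect` is a hypothetical name, and `rng` is assumed to be the 1-indexed range array (rng[0] unused).

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

constexpr std::int64_t MOD = 1'000'000'007;

// rng is 1-indexed (rng[0] unused) and holds range[1..].
// Returns ways[0..n-1] evaluated straight from the recurrence in O(n^2).
std::vector<std::int64_t> waysDirect(const std::vector<int>& rng, int n) {
    assert(n >= 1 && static_cast<int>(rng.size()) >= n);
    std::vector<std::int64_t> ways(n, 0);
    ways[0] = 1;
    for (int i = 1; i < n; ++i) {
        std::int64_t w = ways[i - 1];  // first term: clear i - 1 elements, then op1
        for (int j = 1; j < i; ++j) {
            // second term: clear j - 1 elements, op2 on j + 1..i, then op1 on j
            if (rng[j] >= i) w = (w + ways[j - 1]) % MOD;
        }
        ways[i] = w % MOD;
    }
    return ways;
}
```

With range = (2, 2, 5, 5, 5), the range array of (2, 1, 4, 3, 5), it yields ways[0..4] = (1, 1, 2, 2, 4).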
Calculating range array

Define:
forward[i][0][0] as the smallest even index j \gt i such that A_j \lt A_i.
forward[i][1][0] as the smallest odd index j \gt i such that A_j \lt A_i.
forward[i][0][1] as the smallest even index j \gt i such that A_j \gt A_i.
forward[i][1][1] as the smallest odd index j \gt i such that A_j \gt A_i.

If no such index exists, the value is INF.

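range[i] = min(min(forward[i][p][o], forward[i][1 - p][1 - o]) - 2, N), where p = i & 1, and o = 1 if A_i \lt A_{i + 1}, otherwise o = 0. (For A_i \lt A_{i + 1} this reads min(min(forward[i][i & 1][1], forward[i][1 - (i & 1)][0]) - 2, N).)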

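The expression follows from the terminating condition above: with A_i as the first element, the chain of op2 removals first breaks at the smallest index t \gt i + 1 such that A_i and A_{t - 1} lie on opposite sides of A_t, so the last removable index is t - 2. Because the array alternates, these t are exactly the indices found by the two forward lookups, one for each parity.
We can compute the forward array with a segment tree in \mathcal{O}(N\log_{2}N), sweeping i from N down to 1 and querying, separately for each parity, the smallest index among the values below or above A_i, and then the range array in \mathcal{O}(N).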
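One way to realise this is sketched below; `MinTree` and `rangeByForward` are hypothetical names. It keeps one min segment tree per index parity, keyed by value and storing the index that holds that value, and sweeps i from N down to 1 so that the trees contain exactly the indices greater than i. It then picks the two lookups with p = i & 1 and o = [A_i \lt A_{i + 1}] (for A_i \lt A_{i + 1} these are forward[i][i & 1][1] and forward[i][1 - (i & 1)][0]). INF is taken as N + 10, as an assumption; any value of at least N + 2 behaves the same after the final min with N.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <vector>

// Iterative min segment tree over positions 0..size (only 1..size are used).
struct MinTree {
    int n, inf;
    std::vector<int> t;
    MinTree(int size, int infinity) : n(size + 1), inf(infinity), t(2 * (size + 1), infinity) {}
    void set(int pos, int value) {
        pos += n;
        t[pos] = value;
        for (pos >>= 1; pos >= 1; pos >>= 1) t[pos] = std::min(t[2 * pos], t[2 * pos + 1]);
    }
    int query(int l, int r) const {  // min over [l, r]; inf if the range is empty
        int res = inf;
        for (l += n, r += n + 1; l < r; l >>= 1, r >>= 1) {
            if (l & 1) res = std::min(res, t[l++]);
            if (r & 1) res = std::min(res, t[--r]);
        }
        return res;
    }
};

// a is a 1-indexed permutation a[1..n] (a[0] unused). Returns range[1..n].
std::vector<int> rangeByForward(const std::vector<int>& a) {
    assert(a.size() >= 2);
    const int n = static_cast<int>(a.size()) - 1;
    const int INF = n + 10;
    // fw[i][p][big]: smallest j > i with j % 2 == p and a[j] > a[i] (big = 1)
    // or a[j] < a[i] (big = 0); INF if there is none.
    std::vector<std::array<std::array<int, 2>, 2>> fw(n + 1);
    std::vector<MinTree> tree(2, MinTree(n, INF));  // one tree per index parity
    for (int i = n; i >= 1; --i) {  // trees hold exactly the indices > i
        for (int p = 0; p < 2; ++p) {
            fw[i][p][0] = tree[p].query(1, a[i] - 1);
            fw[i][p][1] = tree[p].query(a[i] + 1, n);
        }
        tree[i & 1].set(a[i], i);
    }
    std::vector<int> rng(n + 1, 0);
    for (int i = 1; i < n; ++i) {
        const int p = i & 1, o = a[i] < a[i + 1] ? 1 : 0;
        rng[i] = std::min(std::min(fw[i][p][o], fw[i][1 - p][1 - o]) - 2, n);
    }
    rng[n] = n;  // a[n] has no second element after it
    return rng;
}
```

For A = (2, 1, 4, 3, 5) it gives range[1..5] = (2, 2, 5, 5, 5), the same as simulating the op2 removals one by one.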

Calculating ways array

The ways array can be computed in \mathcal{O}(N) with prefix sums and range updates (a difference array): each index j adds ways[j - 1] to every ways[i] with j \lt i \le range[j].
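A sketch of that difference-array evaluation follows; `waysLinear` is a hypothetical name and `rng` is assumed to be the 1-indexed range array. Once ways[j - 1] is known, index j schedules its contribution to the targets j + 1, \ldots, range[j] by adding at the start of that interval and subtracting just past its end, and a running sum picks the contributions up as i advances.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

constexpr std::int64_t MOD = 1'000'000'007;

// O(n) evaluation of ways[0..n-1]: index j adds ways[j - 1] to every
// ways[i] with j < i <= range[j], recorded in a difference array.
std::vector<std::int64_t> waysLinear(const std::vector<int>& rng, int n) {
    assert(n >= 1 && static_cast<int>(rng.size()) >= n);
    std::vector<std::int64_t> ways(n, 0), diff(n + 1, 0);
    ways[0] = 1;
    std::int64_t run = 0;  // running sum of diff: the second term of the recurrence
    for (int i = 1; i < n; ++i) {
        run = (run + diff[i]) % MOD;
        ways[i] = (ways[i - 1] + run) % MOD;
        // j = i contributes ways[i - 1] to targets i + 1..min(range[i], n - 1)
        const int hi = std::min(rng[i], n - 1);
        if (hi >= i + 1) {
            diff[i + 1] = (diff[i + 1] + ways[i - 1]) % MOD;
            diff[hi + 1] = (diff[hi + 1] - ways[i - 1] + MOD) % MOD;
        }
    }
    return ways;
}
```

It agrees with evaluating the recurrence term by term; for range = (4, 2, 7, 4, 7, 7, 7), the range array of (3, 1, 5, 2, 6, 4, 7), it gives ways[0..6] = (1, 1, 2, 3, 6, 8, 16).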

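Finally,
mulL[i][0] = \sum\limits_{\substack{1 \le j \lt i \\ A_j \lt A_i \\ range[j] \ge i - 1}} ways[j - 1]

mulL[i][1] = \sum\limits_{\substack{1 \le j \lt i \\ A_j \gt A_i \\ range[j] \ge i - 1}} ways[j - 1]

where the condition range[j] \ge i - 1 ensures that indices j + 1, \ldots, i - 1 can all be removed by op2 while j is the first element. These sums can be computed in \mathcal{O}(N \log_{2}N) with a segment tree once the ways array is known, or in \mathcal{O}(N) using prefix sums and range updates.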
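The sketch below realises the \mathcal{O}(N \log_{2}N) bound with a Fenwick tree over values in place of the segment tree (a substitution: both support point updates and prefix sums here). `computeMulL` is a hypothetical name. It sweeps i upward, inserts j = i - 1 with weight ways[j - 1], deletes each j once i exceeds range[j] + 1, and reads off the total weight below and above A_i.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

constexpr std::int64_t MOD = 1'000'000'007;

// Fenwick tree over values 1..n: point add and prefix sum, both modulo MOD.
struct Fenwick {
    std::vector<std::int64_t> t;
    explicit Fenwick(int n) : t(n + 1, 0) {}
    void add(int pos, std::int64_t v) {
        v = ((v % MOD) + MOD) % MOD;
        for (; pos < static_cast<int>(t.size()); pos += pos & -pos) t[pos] = (t[pos] + v) % MOD;
    }
    std::int64_t prefix(int pos) const {  // sum over values 1..pos
        std::int64_t s = 0;
        for (; pos > 0; pos -= pos & -pos) s = (s + t[pos]) % MOD;
        return s;
    }
};

// a: 1-indexed a[1..n]; ways[0..]; rng[1..n] (range). For 2 <= i <= n - 1,
// mul[i][0] / mul[i][1] = sum of ways[j - 1] over j < i with rng[j] >= i - 1
// and a[j] < a[i] / a[j] > a[i].
std::vector<std::array<std::int64_t, 2>> computeMulL(const std::vector<int>& a,
                                                     const std::vector<std::int64_t>& ways,
                                                     const std::vector<int>& rng) {
    const int n = static_cast<int>(a.size()) - 1;
    assert(n >= 1 && static_cast<int>(rng.size()) == n + 1);
    std::vector<std::array<std::int64_t, 2>> mul(n + 1);  // zero-initialised
    std::vector<std::vector<int>> expire(n + 3);  // expire[t]: j leaves once i reaches t
    Fenwick fw(n);
    std::int64_t total = 0;  // sum of the weights currently in the tree
    for (int i = 2; i <= n - 1; ++i) {
        const int j = i - 1;  // always valid, since rng[j] >= j
        fw.add(a[j], ways[j - 1]);
        total = (total + ways[j - 1]) % MOD;
        expire[rng[j] + 2].push_back(j);  // j counts while i <= rng[j] + 1
        for (int e : expire[i]) {
            fw.add(a[e], -ways[e - 1]);
            total = (total - ways[e - 1] % MOD + MOD) % MOD;
        }
        const std::int64_t below = fw.prefix(a[i] - 1);
        mul[i] = {below, (total - below + MOD) % MOD};
    }
    return mul;
}
```

For A = (2, 1, 4, 3, 5) with ways = (1, 1, 2, 2, 4) and range = (2, 2, 5, 5, 5), it gives mulL[2] = (0, 1), mulL[3] = (2, 0) and mulL[4] = (0, 2).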

ALTERNATE IMPLEMENTATION:

Check the tester’s code for reference.

COMPLEXITY ANALYSIS:

The final complexity is \mathcal{O}(N \log_{2}N).

SOLUTIONS:

Setter's Solution
#include <bits/stdc++.h>
#define ll long long
#define fi first
#define se second
#define mp make_pair
using namespace std;
void db() {cout << endl;}
template <typename T, typename ...U> void db(T a, U ...b) {cout << a << ' ', db(b...);}
#ifdef Cloud
#define file freopen("input.txt", "r", stdin), freopen("output.txt", "w", stdout)
#else
#define file ios::sync_with_stdio(false); cin.tie(0)
#endif
#define pii pair<int, int>
const int inf = 1e9 + 1, N = 3e5 + 1, mod = 1e9 + 7;
ll fac[N]{1};
ll modpow(ll x, int p){
    ll ans = 1;
    for (int i = p; i; i >>= 1, x = x * x % mod) if (i & 1) ans = ans * x % mod;
    return ans;
}
ll C(int n, int m){
    if (n < m) return 0;
    return fac[n] * modpow(fac[m] * fac[n - m] % mod, mod - 2) % mod;
}
struct seg{
    int l, r, mid, sum = 0, lz = 0;
    seg *ch[2]{};
    void push(){
        if (!lz) return;
        ch[0]->sum = ch[1]->sum = 0;
        ch[0]->lz = ch[1]->lz = 1;
        lz = 0;
    }
    void add(int pos, int val){
        if (l == r) return void(sum = (sum + val) % mod);
        push();
        if (pos <= mid) ch[0]->add(pos, val);
        else ch[1]->add(pos, val);
        sum = (ch[0]->sum + ch[1]->sum) % mod;
    }
    void reset(int _l, int _r){
        if (_l <= l and _r >= r){
            sum = 0;
            lz = 1;
            return;
        }
        if (_l > r or _r < l) return;
        push();
        ch[0]->reset(_l, _r);
        ch[1]->reset(_l, _r);
        sum = (ch[0]->sum + ch[1]->sum) % mod;
    }
    seg (int _l, int _r) : l(_l), r(_r), mid(l + r >> 1){
        if (l != r) ch[0] = new seg(l, mid), ch[1] = new seg(mid + 1, r);
    }
};
void solve(){
    int n;
    cin >> n;
    vector<int> a(n), b(n);
    for (int &i : a) cin >> i;
    for (int i = 1; i < n; i++) b[i] = a[i] > a[i - 1];
    b[0] = b[1] ^ 1;
    ll pre[n], suf[n], ans = 0;
    seg *rt = new seg(1, n);
    rt->add(a[0], 1);
    for (int i = 1; i < n - 1; i++){
        pre[i] = rt->sum;
        rt->add(a[i], rt->sum);
        if (b[i + 1] == 1) rt->reset(a[i + 1], n);
        else rt->reset(1, a[i + 1]);
    }
    rt->reset(1, n);
    rt->add(a[n - 1], 1);
    for (int i = n - 2; i > 0; i--){
        suf[i] = rt->sum;
        rt->add(a[i], rt->sum);
        if (b[i - 1] == 1) rt->reset(a[i - 1], n);
        else rt->reset(1, a[i - 1]);
    }
    for (int i = 1; i < n - 1; i++){
        ans += (pre[i] * suf[i] % mod) * (C(n - 3, i - 1) * 6 % mod);
        ans %= mod;
    }
    cout << ans << '\n';
}
int main(){
    for (int i = 1; i < N; i++) fac[i] = fac[i - 1] * i % mod;
    file;
    int t;
    cin >> t;
    while (t--) solve();
}
Tester's Solution
#include <bits/stdc++.h>
using namespace std;

template <long long mod>
struct modular {
    long long value;
    modular(long long x = 0) {
        value = x % mod;
        if (value < 0) value += mod;
    }
    modular& operator+=(const modular& other) {
        if ((value += other.value) >= mod) value -= mod;
        return *this;
    }
    modular& operator-=(const modular& other) {
        if ((value -= other.value) < 0) value += mod;
        return *this;
    }
    modular& operator*=(const modular& other) {
        value = value * other.value % mod;
        return *this;
    }
    modular& operator/=(const modular& other) {
        long long a = 0, b = 1, c = other.value, m = mod;
        while (c != 0) {
            long long t = m / c;
            m -= t * c;
            swap(c, m);
            a -= t * b;
            swap(a, b);
        }
        a %= mod;
        if (a < 0) a += mod;
        value = value * a % mod;
        return *this;
    }
    friend modular operator+(const modular& lhs, const modular& rhs) { return modular(lhs) += rhs; }
    friend modular operator-(const modular& lhs, const modular& rhs) { return modular(lhs) -= rhs; }
    friend modular operator*(const modular& lhs, const modular& rhs) { return modular(lhs) *= rhs; }
    friend modular operator/(const modular& lhs, const modular& rhs) { return modular(lhs) /= rhs; }
    modular& operator++() { return *this += 1; }
    modular& operator--() { return *this -= 1; }
    modular operator++(int) {
        modular res(*this);
        *this += 1;
        return res;
    }
    modular operator--(int) {
        modular res(*this);
        *this -= 1;
        return res;
    }
    modular operator-() const { return modular(-value); }
    bool operator==(const modular& rhs) const { return value == rhs.value; }
    bool operator!=(const modular& rhs) const { return value != rhs.value; }
    bool operator<(const modular& rhs) const { return value < rhs.value; }
};
template <long long mod>
string to_string(const modular<mod>& x) {
    return to_string(x.value);
}
template <long long mod>
ostream& operator<<(ostream& stream, const modular<mod>& x) {
    return stream << x.value;
}
template <long long mod>
istream& operator>>(istream& stream, modular<mod>& x) {
    stream >> x.value;
    x.value %= mod;
    if (x.value < 0) x.value += mod;
    return stream;
}

constexpr long long mod = (long long) 1e9 + 7;
using mint = modular<mod>;

mint power(mint a, long long n) {
    mint res = 1;
    while (n > 0) {
        if (n & 1) {
            res *= a;
        }
        a *= a;
        n >>= 1;
    }
    return res;
}

vector<mint> fact(1, 1);
vector<mint> finv(1, 1);

mint C(int n, int k) {
    if (n < k || k < 0) {
        return mint(0);
    }
    while ((int) fact.size() < n + 1) {
        fact.emplace_back(fact.back() * (int) fact.size());
        finv.emplace_back(mint(1) / fact.back());
    }
    return fact[n] * finv[k] * finv[n - k];
}

template <typename T>
struct fenwick {
    int n;
    vector<T> node;

    fenwick(int _n) : n(_n) {
        node.resize(n);
    }

    void add(int x, T v) {
        while (x < n) {
            node[x] += v;
            x |= (x + 1);
        }
    }

    T get(int x) {  // [0, x]
        T v = 0;
        while (x >= 0) {
            v += node[x];
            x = (x & (x + 1)) - 1;
        }
        return v;
    }

    T get(int x, int y) {  // [x, y]
        return (get(y) - (x ? get(x - 1) : 0));
    }

    int lower_bound(T v) {
        int x = 0;
        int h = 1;
        while (n >= (h << 1)) {
            h <<= 1;
        }
        for (int k = h; k > 0; k >>= 1) {
            if (x + k <= n && node[x + k - 1] < v) {
                v -= node[x + k - 1];
                x += k;
            }
        }
        return x;
    }
};

int main() {
    int tt;
    cin >> tt;
    while (tt--) {
        int n;
        cin >> n;
        vector<int> a(n);
        for (int i = 0; i < n; i++) {
            cin >> a[i];
        }
        vector<vector<mint>> f(2, vector<mint>(n + 1));
        for (int z = 0; z < 2; z++) {
            vector<mint> dp(n + 1);
            fenwick<mint> ft(n + 1);
            set<int> st;
            dp[a[0]] = 1;
            ft.add(a[0], 1);
            st.emplace(a[0]);
            for (int i = 1; i < n - 1; i++) {
                if (a[i] > a[i + 1]) {
                    while (!st.empty() && *st.rbegin() > a[i]) {
                        int j = *st.rbegin();
                        ft.add(j, -dp[j]);
                        st.erase(j);
                    }
                } else {
                    while (!st.empty() && *st.begin() < a[i]) {
                        int j = *st.begin();
                        ft.add(j, -dp[j]);
                        st.erase(j);
                    }
                }
                dp[a[i]] += ft.get(n);
                ft.add(a[i], dp[a[i]]);
                st.emplace(a[i]);
            }
            for (int i = 0; i < n; i++) {
                f[z][i] = dp[a[i]];
            }
            reverse(a.begin(), a.end());
        }
        mint ans = 0;
        for (int k = 1; k <= n - 2; k++) {
            ans += C(n - 3, k - 1) * f[0][k] * f[1][n - k - 1];
        }
        ans *= 6;
        cout << ans << '\n';
    }
    return 0;
}
Editorialist's Solution
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define pii pair<int, int>
#define pb push_back
#define mp make_pair
#define F first
#define S second
const int maxN = 200020, mod = 1e9 + 7;
int n;
int a[maxN];
int forw[maxN][2][2], back[maxN][2][2]; // index, parity, small/ big
int seg[4 * maxN][2]; // index, parity
ll prefix[maxN];
ll waysL[maxN], waysR[maxN];
ll mulL[maxN][2][2], mulR[maxN][2][2]; // index, parity, small/ big
ll ifac[maxN], fac[maxN];

void build(int n, int s, int e, int val){
	for(int i = 0; i < 2; i++){
		seg[n][i] = val;
	}
	if(s == e){
		return;
	}
	int m = (s + e) >> 1;
	build(2 * n + 1, s, m, val);
	build(2 * n + 2, m + 1, e, val);
}
void upd(int t, int n, int s, int e, int id, int val){
	if(s == e){
		seg[n][t] = val;
		return;
	}
	int m = (s + e) >> 1;
	if(id <= m)upd(t, 2 * n + 1, s, m, id, val);
	else upd(t, 2 * n + 2, m + 1, e, id, val);
	seg[n][t] = min(seg[2 * n + 1][t], seg[2 * n + 2][t]);
}
int query(int t, int n, int s, int e, int l, int r, int end){
	if(l > e || r < s || l > r)return end;
	if(l <= s && e <= r){
		return seg[n][t];
	}
	int m = (s + e) >> 1;
	return min(query(t, 2 * n + 1, s, m, l, r, end), query(t, 2 * n + 2, m + 1, e, l, r, end));
}
ll rpe(ll a, int b){
	ll ans = 1;
	while(b != 0){
		if(b & 1)ans = ans * a % mod;
		a = a * a % mod; b >>= 1;
	}
	return ans;
}
ll nCr(int n, int r){
	if(r > n || n < 0)return 0;
	return fac[n] * ifac[r] % mod * ifac[n - r] % mod;
}
void cal(int (*forw)[2][2], ll *waysL, ll (*mulL)[2][2]){
	build(0, 1, n, n + 10);
	for(int i = n; i >= 1; i--){
		for(int j = 0; j < 2; j++){
			forw[i][j][0] = query(j, 0, 1, n, 1, a[i] - 1, n + 10);
			forw[i][j][1] = query(j, 0, 1, n, a[i] + 1, n, n + 10);
			// cout << forw[i][j][0] << " " << forw[i][j][1] << endl;
		}
		upd(i & 1, 0, 1, n, a[i], i);
	}

	for(int i = 0; i <= n + 1; i++){
		prefix[i] = 0; 
		for(int j = 0; j < 2; j++){
			for(int k = 0; k < 2; k++){
				mulL[i][j][k] = 0;
			}
		}
	}
	waysL[0] = 1;
	for(int i = 1; i < n; i++){
		prefix[i] = (prefix[i] + prefix[i - 1]) % mod;
		waysL[i] = (prefix[i] + waysL[i - 1]) % mod;
		for(int j = 0; j < 2; j++){
			for(int k = 0; k < 2; k++){
				mulL[i][j][k] = (mulL[i][j][k] + mulL[i - 1][j][k]) % mod;
			}
		}
		int order = a[i] < a[i + 1], par = i & 1;
		int idx = min(min(forw[i][par][order], forw[i][1 - par][1 - order]) - 2, n);
		if(idx >= i + 1){
			prefix[i + 1] = (prefix[i + 1] + waysL[i - 1]) % mod;
			prefix[idx + 1] = (prefix[idx + 1] - waysL[i - 1] + mod) % mod;

			mulL[i + 2][par][order] = (mulL[i + 2][par][order] + waysL[i - 1]) % mod;
			mulL[idx + 2][par][order] = (mulL[idx + 2][par][order] - waysL[i - 1] + mod) % mod;
		}
		if(idx >= i){
			mulL[i + 1][1 - par][1 - order] = (mulL[i + 1][1 - par][1 - order] + waysL[i - 1]) % mod;
			mulL[idx + 2][1 - par][1 - order] = (mulL[idx + 2][1 - par][1 - order] - waysL[i - 1] + mod) % mod;
		}
	}
}

int main() {
	// freopen("0.in", "r", stdin);
	// freopen("0.txt", "w", stdout);
	ios_base::sync_with_stdio(false); cin.tie(NULL); 
	fac[0] = ifac[0] = 1;
	for(int i = 1; i < maxN; i++){
		fac[i] = fac[i - 1] * i % mod;
	}
	ifac[maxN - 1] = rpe(fac[maxN - 1], mod - 2);
	for(int i = maxN - 2; i >= 1; i--){
		ifac[i] = ifac[i + 1] * (i + 1) % mod;
	}

	int t; cin >> t;
	while(t--){
		cin >> n;
		for(int i = 1; i <= n; i++)cin >> a[i];

		cal(forw, waysL, mulL);

		for(int i = 1; i <= n / 2; i++)swap(a[i], a[n - i + 1]);

		cal(back, waysR, mulR);

		// comp
		ll ans = 0;
		for(int i = 2; i < n; i++){
			ll comb = nCr(n - 3, i - 2);
			ll small = mulL[i][i & 1][0] * mulR[n - i + 1][(n - i + 1) & 1][0] % mod;
			ll large = mulL[i][i & 1][1] * mulR[n - i + 1][(n - i + 1) & 1][1] % mod;
			ll ways = (small + large) % mod;
			ans = (ans + comb * ways % mod) % mod;
		}
		ans = ans * 6 % mod;

		cout << ans << endl;
	}
	return 0;
}