PERMRED - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: iceknight1093
Tester: apoorv_me
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Dynamic programming

PROBLEM:

Given an array A of length M, you may repeatedly perform the following move on it:

  • Choose an index i such that 1 \lt i \lt M, and A_i \gt \max(A_{i-1}, A_{i+1}).
    Then, delete \max(A_{i-1}, A_{i+1}) from the array. M reduces by 1.

f(A) denotes the minimum possible length of the final array if you choose the moves optimally.
Given N, for each K from 1 to N find the number of permutations P of length N such that f(P) = K, modulo 998244353.
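You can compute f on tiny inputs by exhaustive search over all move sequences, which is useful for checking the claims below. The following is a minimal sketch. It runs in exponential time, so it only suits very short arrays. The name `f_bruteforce` is our own, and indices are 0-based, so the interior condition 1 < i < M becomes 0 < i < len(A) - 1.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def f_bruteforce(A: tuple) -> int:
    """Minimum final length reachable from the tuple A, by exhaustive search.

    A move picks an interior index i (0 < i < len(A) - 1, 0-based) with
    A[i] > max(A[i - 1], A[i + 1]) and deletes the larger of the two neighbors.
    Values are assumed distinct, as in a permutation.
    """
    best = len(A)  # making no further moves is always allowed
    for i in range(1, len(A) - 1):
        left, right = A[i - 1], A[i + 1]
        if A[i] > max(left, right):
            j = i - 1 if left > right else i + 1  # index of the larger neighbor
            best = min(best, f_bruteforce(A[:j] + A[j + 1:]))
    return best


print(f_bruteforce((1, 3, 4, 2)))  # 2
```

On (1, 3, 4, 2) the only first move is at the 4, which deletes the 3. Then the 4 deletes the 2, leaving (1, 4).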

EXPLANATION:

First, let’s try and find f(P) for a single permutation P.

Note that no matter what, 1 cannot be deleted from P, since it’ll never be the maximum of two elements.
Further, if P_i = 1, we can never perform an operation on index i (since 1 isn’t going to be larger than its neighbors).
So, the problem essentially splits into two independent parts: the prefix ending at 1, and the suffix starting from 1.
Let’s look at what happens to the suffix. The prefix can be handled the same way after reversing it, since the move is symmetric under reversal.

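Only relative order matters for the moves. So after replacing each value by its rank within the suffix, we have a permutation of [M] (M being the length of the suffix) whose first element is 1.
Observe that for any integer x, if x is ever deleted by a move, then at the moment of deletion one of the following conditions holds (positions refer to the current array):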

  • The element to the immediate left of x is greater than it, and the one two steps to the left is less than it; or
  • The element immediately to the right of x is greater than it, and the one two steps to the right of x is less than it.

Now, moves never change the relative order of the remaining elements. So the first case is impossible if x is initially greater than every element to its left, and the second case is impossible if every element \lt x occurs before every element \gt x.
In other words, if x is a prefix maximum and every element of \{1, 2, \ldots, x-1\} occurs before everything in \{x+1, x+2, \ldots, M\}, it’s impossible to delete x.
This can be written in a more succinct form: the first x elements of the permutation should themselves form a permutation of \{1, 2, \ldots, x\}.

Every x that satisfies this condition certainly cannot be deleted. It turns out that all the x that don't satisfy it can be deleted by a single sequence of moves.
That is, for a permutation P whose first element is 1, f(P) equals the number of prefixes of P that are themselves permutations!
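The characterization can be sketched as follows, assuming P is a permutation of 1..N. The helper names `count_prefix_permutations`, `relabel` and `f_characterized` are our own. The prefix ending at 1 is reversed so that it also starts with 1, and both sides are relabeled by rank. Element 1 lies in both sides, so it is counted only once.

```python
def count_prefix_permutations(Q) -> int:
    """Count lengths L such that Q[:L] is a permutation of {1, ..., L}.

    Q must be a permutation of 1..len(Q); because its values are distinct,
    Q[:L] equals {1, ..., L} exactly when max(Q[:L]) == L.
    """
    count = running_max = 0
    for L, q in enumerate(Q, start=1):
        running_max = max(running_max, q)
        if running_max == L:
            count += 1
    return count


def relabel(seq):
    """Replace each value by its rank (1 = smallest), keeping the order."""
    rank = {v: r for r, v in enumerate(sorted(seq), start=1)}
    return [rank[v] for v in seq]


def f_characterized(P) -> int:
    """f(P) from the characterization, applied to both sides of the element 1."""
    pos = P.index(1)
    left = relabel(P[:pos + 1][::-1])  # prefix ending at 1, read right to left
    right = relabel(P[pos:])           # suffix starting at 1
    # Element 1 survives and belongs to both sides, so subtract it once.
    return count_prefix_permutations(left) + count_prefix_permutations(right) - 1


print(f_characterized([3, 2, 1]), f_characterized([1, 3, 4, 2]))  # 3 2
```

For [3, 2, 1], the reversed prefix [1, 2, 3] contributes 3 and the suffix [1] contributes 1, giving 3 + 1 - 1 = 3.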

Proof

Let S be an empty stack.
Consider the following algorithm.
For each i = 1, 2, 3, \ldots, M,

  • If P_i \gt top(S) (or S is empty), push P_i onto S.
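  • Otherwise, let t = top(S). While the second element from the top of S is \gt P_i, delete it from S (this is a move at t, whose neighbors are that element and P_i).
    Then delete P_i itself with one more move at t. In particular, P_i is never pushed onto S.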

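Every step above is a valid move. While P_i is being processed, the current array is S followed by P_i, and the top t is larger than both of its neighbors. The elements remaining in S at the end of this process are exactly the elements satisfying the condition above.
This is because: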

  • Only prefix maximums of P are ever pushed onto S, since the top of S is always the largest element seen so far.
    That is, any non-prefix maximum is certainly deleted.
  • Suppose there are indices i \lt j \lt k such that P_k \lt P_i \lt P_j.
    Then, when processing P_k, the top of the stack will be at least P_j, so P_i, if it is still in the stack, is not the top.
    This means that when P_k is processed, every element of S other than the top that is \gt P_k gets deleted, which includes P_i if it’s still in S.

So every element that violates the condition is deleted by the end of the process, as required.
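The stack process from the proof can be simulated directly. In this sketch, every deletion is checked against the move rule with an `assert`. This relies on the fact that the current array is always the stack followed by the element being processed. The name `stack_survivors` is our own. The sketch assumes P has distinct values and starts with its minimum, as for the suffix starting at 1.

```python
def stack_survivors(P):
    """Run the stack process from the proof and return the surviving elements.

    Assumes P has distinct values and P[0] is its minimum (e.g. a permutation
    starting with 1). While p is processed, the current array is the stack S
    followed by p, so each deletion can be checked against the move rule.
    """
    assert P[0] == min(P)
    S = []
    for p in P:
        if not S or p > S[-1]:
            S.append(p)  # new prefix maximum
            continue
        t = S[-1]  # top of S: t > p, and S is increasing from bottom to top
        # Move at t with neighbors S[-2] and p: the larger one, S[-2], is deleted.
        while S[-2] > p:
            assert t > max(S[-2], p)
            del S[-2]
        # Now S[-2] < p, so the same move at t deletes p itself.
        assert t > max(S[-2], p)
    return S


print(stack_survivors([1, 2, 4, 3, 5]))  # [1, 2, 4, 5]
```

On [1, 2, 4, 3, 5] the survivors are [1, 2, 4, 5]. These are exactly the x whose length-x prefix is a permutation of \{1, \ldots, x\}. The loop always finds a second stack element because the bottom of S is the minimum, which is never popped.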


Let’s use this characterization to compute the answer.
For a fixed N and K (1 \leq K \leq N), a rather simple algorithm is as follows:

  • Fix the position of 1 in the permutation, which gives us N choices.
    Suppose we fix P_i = 1.
  • Choose which i-1 of the other N-1 values appear before 1: \binom{N-1}{i-1} choices.
  • Also fix x, the number of surviving elements strictly to the left of 1. Since 1 itself always survives, this fixes y = K-x-1 as the number of surviving elements strictly to the right of 1.
  • Now, we can choose any permutation of length i whose answer is x+1 and whose first element is 1 (the prefix ending at 1, read right to left). We combine it with any permutation of length N-i+1 whose answer is K-x and whose first element is 1 (the suffix starting at 1). Since 1 is shared by both parts, the result is a valid permutation of length N with answer (x+1) + (K-x) - 1 = K.
    So, all we really need to know is the number of permutations of a fixed length and answer, but with first element 1.

Let dp[n][i] denote the number of permutations of [n] such that exactly i of its nonempty prefixes are subpermutations. That is, for exactly i lengths x, the first x elements form a permutation of \{1, \ldots, x\}. Note that we aren’t fixing the first element to be 1 here.
Then, we have the following recurrence:

dp[n][i] = \begin{cases} n! - \displaystyle\sum_{x \lt n} dp[x][1] \cdot (n-x)!, & \text{if } i = 1, \\ \\ \displaystyle\sum_{x\lt n} dp[x][i-1] \cdot dp[n-x][1], & \text{if } i \gt 1. \end{cases}

The base cases here are dp[0][0] = dp[1][1] = 1.

The reasoning is:

  • If i = 1, we want to count the number of permutations such that no proper prefix is itself a permutation.
    To do this, we subtract the number of ‘bad’ permutations from the total (n!). A bad permutation is counted by fixing its first proper prefix that’s a permutation (dp[x][1] ways, for prefix length x), and then arranging the remaining n-x elements arbitrarily ((n-x)! ways).
  • If i \gt 1, the idea is similar: fix the length x of the longest proper prefix that is a permutation (the (i-1)-th one). The first x elements form any permutation of [x] with answer i-1, and the last n-x elements form any permutation of \{x+1, \ldots, n\} with answer 1.
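Below is a minimal Python sketch of the recurrence, cross-checked against direct enumeration of all permutations for n ≤ 6. The function names are our own. Unlike the editorialist's code, it builds the table only up to a given `limit`.

```python
from itertools import permutations


def build_dp(limit: int, mod: int = 998244353):
    """dp[n][i] = #permutations of [n] with exactly i prefix subpermutations (mod `mod`)."""
    fac = [1] * (limit + 1)
    for n in range(1, limit + 1):
        fac[n] = fac[n - 1] * n % mod
    dp = [[0] * (limit + 1) for _ in range(limit + 1)]
    dp[0][0] = 1
    for n in range(1, limit + 1):
        # i = 1: all n! permutations minus those whose first proper
        # prefix subpermutation has length x (dp[x][1] * (n - x)! of them).
        dp[n][1] = fac[n]
        for x in range(1, n):
            dp[n][1] = (dp[n][1] - dp[x][1] * fac[n - x]) % mod
        # i > 1: the longest proper prefix subpermutation has length x,
        # followed by a block of length n - x with exactly one (itself).
        for i in range(2, n + 1):
            for x in range(1, n):
                dp[n][i] = (dp[n][i] + dp[x][i - 1] * dp[n - x][1]) % mod
    return dp


def histogram_by_enumeration(n: int):
    """hist[i] = #permutations of [n] with exactly i prefix subpermutations (brute force)."""
    hist = [0] * (n + 1)
    for Q in permutations(range(1, n + 1)):
        running_max = hits = 0
        for L, q in enumerate(Q, start=1):
            running_max = max(running_max, q)
            if running_max == L:
                hits += 1
        hist[hits] += 1
    return hist


dp = build_dp(6)
for n in range(7):
    assert dp[n][: n + 1] == histogram_by_enumeration(n)
print(dp[3][:4], dp[5][1])  # [0, 3, 2, 1] 71
```

For these small n the counts are below the modulus, so the modded values coincide with the raw counts and can be compared directly.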

This can be computed for all N \leq 500 in \mathcal{O}(N^3) time, which is fast enough.
It’s possible to optimize the transitions using NTT to bring the complexity down to \mathcal{O}(N^2 \log N), though this isn’t needed to get AC.

Now that we have all the dp[n][i] values, simply plug them into the original counting part to get the answer.
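Note that we need these counts for permutations whose first element is 1. The number of permutations of [n] with first element 1 and i prefix subpermutations equals the number of permutations of [n-1] with i-1 prefix subpermutations, i.e., dp[n-1][i-1]. To see this, remove the leading 1 and decrease every other value by 1. The prefix [1] is always a subpermutation, and each longer prefix of the original is a subpermutation exactly when the corresponding prefix of the reduced permutation is.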
So, for (N, K), the answer is

\sum_{i=1}^N \sum_{x=0}^{i-1} \binom{N-1}{i-1} dp[i-1][x] \cdot dp[N-i][K-x-1]
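As a worked instance, take N = 3. The recurrence gives dp[0] = [1], dp[1] = [0, 1] and dp[2] = [0, 1, 1] (indexed by i). The nonzero terms of the sum are:

```latex
\begin{aligned}
\text{ans}(3,1) &= 0 \quad \text{(every term contains a factor } dp[n][0] \text{ with } n \geq 1\text{)},\\
\text{ans}(3,2) &= \binom{2}{0} dp[0][0]\, dp[2][1] + \binom{2}{2} dp[2][1]\, dp[0][0] = 1 + 1 = 2,\\
\text{ans}(3,3) &= \binom{2}{0} dp[0][0]\, dp[2][2] + \binom{2}{1} dp[1][1]\, dp[1][1] + \binom{2}{2} dp[2][2]\, dp[0][0] = 1 + 2 + 1 = 4.
\end{aligned}
```

This agrees with checking the six permutations of length 3 directly: [1, 3, 2] and [2, 3, 1] have f = 2, and the other four have f = 3.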

Computing this in \mathcal{O}(N^2) per K gives a solution in \mathcal{O}(N^3) overall, which is fast enough for AC.
Again, though unnecessary, it’s possible to optimize this to \mathcal{O}(N^2 \log N) per test using NTT.

TIME COMPLEXITY:

\mathcal{O}(N^3) per testcase, plus an \mathcal{O}(N^3) precomputation of the dp table.

CODE:

Tester's code (C++)
#include<bits/stdc++.h>

using namespace std;

#define all(x) begin(x), end(x)
#define sz(x) static_cast<int>((x).size())
#define int long long
typedef long long ll;


const int mod = 998244353;
int dp[501][501];
int fact[501];
int C[501][501];


signed main() {

        ios::sync_with_stdio(0);
        cin.tie(0);

        fact[0] = 1;
        for (int i = 1; i <= 500; i++) fact[i] = fact[i - 1] * i % mod;
        dp[0][0] = 1;
        for (int i = 0; i <= 500; i++) {
                C[i][0] = 1; C[i][i] = 1;
                for (int j = 1; j < i; j++) C[i][j] = (C[i - 1][j - 1] + C[i - 1][j]) % mod;
        }

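        // dp[i][j] first holds the number of permutations of [i] with AT LEAST j
        // prefix subpermutations: the (j-1)-th one ends at some k < i (dp[k][j-1]
        // ways, already exact since k < i) and the other i-k elements are free.
        for (int i = 1; i <= 500; i++) {
                for (int j = 1; j <= 500; j++) {
                        for (int k = 0; k < i; k++) dp[i][j] += dp[k][j - 1] * fact[i - k] % mod;
                        dp[i][j] %= mod;
                }
                // Convert "at least j" into "exactly j".
                for (int j = 1; j < 500; j++) {
                        dp[i][j] -= dp[i][j + 1];
                        if (dp[i][j] < 0) dp[i][j] += mod;
                }
        }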
        
        int t;
        cin >> t;

        while (t--) {

                int n;
                cin >> n;
                vector<int> ans(n + 1, 0);  // avoids a non-standard variable-length array
                for (int i = 1; i <= n; i++) {
                        for (int l1 = 0; l1 < n; l1++) {
                                int l2 = n - l1 - 1;
                                for (int j1 = 0; j1 < i; j1++) {
                                        int j2 = i - j1 - 1;
                                        ans[i] += C[n - 1][l1] * dp[l1][j1] % mod * dp[l2][j2] % mod;
                                }
                        }
                        ans[i] %= mod;
                }

                for (int i = 1; i <= n; i++) cout << ans[i] << " "; 
                cout << "\n";


        }
        
}
Editorialist's code (Python)
mod = 998244353
N = 505
C = [ [0 for _ in range(N)] for _ in range(N)]
for i in range(N):
    C[i][0] = 1
    for j in range(1, i+1):
        C[i][j] = (C[i-1][j] + C[i-1][j-1]) % mod

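# dp[n][i]: permutations of [n] with exactly i prefix subpermutations
dp = [ [0 for _ in range(N)] for _ in range(N)]
dp[0][0] = dp[1][1] = 1
fac = [1]*N
for i in range(2, N): fac[i] = fac[i-1] * i % mod

for n in range(2, N):
    # i = 1: all n! permutations minus those with a proper prefix subpermutation
    dp[n][1] = fac[n]
    for i in range(1, n):
        dp[n][1] = (dp[n][1] - dp[i][1] * fac[n-i]) % mod
    # i > 1: a prefix of length j with i-1 subpermutations, then a block with exactly 1
    for i in range(2, n+1):
        for j in range(1, n):
            dp[n][i] = (dp[n][i] + dp[j][i-1] * dp[n-j][1]) % mod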

for _ in range(int(input())):
    n = int(input())
    ans = []
    for k in range(1, n+1):
        cur = 0
        for i in range(1, n+1):      # 1 is at position i
            for x in range(i):        # x survivors strictly left of 1
                if k-x-1 < 0: break
                cur = (cur + C[n-1][i-1] * dp[i-1][x] % mod * dp[n-i][k-x-1]) % mod
        ans.append(cur)
    print(*ans)