PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: iceknight1093
Tester: raysh07
Editorialist: iceknight1093
DIFFICULTY:
Medium
PREREQUISITES:
Dynamic programming
PROBLEM:
With a fixed parameter K, a permutation P has a valley if there exist K indices i_1 \lt i_2 \lt \ldots \lt i_K such that \min(P_{i_1}, P_{i_K}) \gt \max(P_{i_2}, \ldots, P_{i_{K-1}}).
You’re given an array A of length N with elements in [0, N], where A_i = 0 marks a blank position and every nonzero value is fixed.
Count the number of ways to fill in the blanks of A so that the result is a permutation P of 1, \ldots, N with no valleys.
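For checking small cases, the definition can be tested directly. The sketch below is a hypothetical brute force (the names `hasValley` and `bruteCount` are illustrative, not from the original solution): it enumerates every K-subset of indices as a bitmask and every permutation consistent with A. It assumes K \ge 3, 0-indexed vectors, and tiny N, since it is exponential.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Definition check: does some set of k indices i_1 < ... < i_k satisfy
// min(p[i_1], p[i_k]) > max(p[i_2], ..., p[i_{k-1}])? Exponential in n.
bool hasValley(const std::vector<int>& p, int k) {
    int n = (int)p.size();
    for (int mask = 0; mask < (1 << n); mask++) {
        std::vector<int> idx;
        for (int i = 0; i < n; i++)
            if (mask >> i & 1) idx.push_back(i);
        if ((int)idx.size() != k) continue;
        int ends = std::min(p[idx.front()], p[idx.back()]);
        int inner = 0;  // max over the k - 2 middle values
        for (int j = 1; j + 1 < k; j++) inner = std::max(inner, p[idx[j]]);
        if (ends > inner) return true;
    }
    return false;
}

// Counts fillings of a (0 = blank) into a valley-free permutation of 1..n.
long long bruteCount(const std::vector<int>& a, int k) {
    int n = (int)a.size();
    std::vector<int> p(n);
    for (int i = 0; i < n; i++) p[i] = i + 1;
    long long cnt = 0;
    do {
        bool fits = true;  // p must agree with every fixed entry of a
        for (int i = 0; i < n; i++)
            if (a[i] != 0 && a[i] != p[i]) { fits = false; break; }
        if (fits && !hasValley(p, k)) cnt++;
    } while (std::next_permutation(p.begin(), p.end()));
    return cnt;
}
```

For example, with K = 3 the valley-free permutations are exactly the unimodal ones, so `bruteCount({0, 0, 0, 0}, 3)` is 2^{3} = 8.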
EXPLANATION:
First, let’s assume A contains only 0’s, i.e. there are no “fixed” elements.
Let’s try to fill in the permutation by placing elements one at a time, from large to small, i.e. first we’ll place N, then we’ll place N-1, and so on.
Suppose we’ve placed elements N, N-1, \ldots, x.
Among the placed elements, let L denote the index of the leftmost one and R denote the index of the rightmost one.
Consider the range [L, R] of indices. It contains all the elements \ge x, along with possibly some zeros representing currently blank positions.
Note that every blank position will eventually be filled with a value strictly smaller than x.
So, if there are at least K-2 blank positions in the range [L, R], we cannot avoid creating a valley: indices L and R hold values \ge x, and any K-2 of the blanks between them will receive values \lt x, so together these K indices form a valley.
Thus, a necessary condition for the final permutation to not have any valleys is that for every integer x, after placing all elements \ge x, there must be fewer than K-2 blank spaces between the leftmost and rightmost placed elements.
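This condition is also sufficient. Suppose P had a valley at indices i_1 \lt \ldots \lt i_K, and let x = \min(P_{i_1}, P_{i_K}). After all elements \ge x are placed, both i_1 and i_K are occupied, so [i_1, i_K] \subseteq [L, R], while indices i_2, \ldots, i_{K-1} hold values \lt x and are therefore still blank. That gives at least K-2 blanks in [L, R], contradicting the condition.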
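The characterization turns into a linear scan: place values N, N-1, \ldots, 1, keep the leftmost and rightmost occupied indices, and reject as soon as K-2 blanks appear. A minimal sketch (the function names are hypothetical; a 0-indexed permutation of 1..N and K \ge 3 are assumed):

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Scan-based test from the characterization: after placing all values >= x,
// the occupied span [lo, hi] must contain fewer than k - 2 blanks.
bool noValleyByScan(const std::vector<int>& p, int k) {
    int n = (int)p.size();
    std::vector<int> where(n + 1);
    for (int i = 0; i < n; i++) where[p[i]] = i;
    int lo = n, hi = -1;
    for (int x = n; x >= 1; x--) {
        lo = std::min(lo, where[x]);
        hi = std::max(hi, where[x]);
        int blanks = (hi - lo + 1) - (n - x + 1);
        if (blanks >= k - 2) return false;
    }
    return true;
}

// Counts permutations of 1..n accepted by the scan (tiny n only).
long long countByScan(int n, int k) {
    std::vector<int> p(n);
    for (int i = 0; i < n; i++) p[i] = i + 1;
    long long cnt = 0;
    do {
        if (noValleyByScan(p, k)) cnt++;
    } while (std::next_permutation(p.begin(), p.end()));
    return cnt;
}
```

For N = 4 this accepts 8 permutations when K = 3 and 20 when K = 4; the 4 rejected ones for K = 4 are exactly those with \{P_2, P_3\} = \{1, 2\}, which agrees with the definition.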
With the above characterization, let’s try to count the number of valid permutations.
Note that we’re still in the case of all zeros in A.
We care about a few different parameters, as seen earlier:
- The value x, denoting the latest number placed.
- The range [L, R] denoting the bounds of the placed values.
- The number of zeros in [L, R].
Note that the number of zeros can actually be inferred from the first two pieces of information, because [L, R] has a length of R-L+1 and it contains all elements \ge x (of which there are N-x+1) and only zeros otherwise.
So, if x, L, R are fixed, the number of zeros equals (R-L+1) - (N-x+1).
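As a small worked instance (values chosen purely for illustration), take N = 7 with 6, 7, 5 placed at indices 2, 4, 6 respectively:

```latex
\begin{aligned}
&N = 7,\ x = 5,\ P_2 = 6,\ P_4 = 7,\ P_6 = 5 \ \Rightarrow\ [L, R] = [2, 6],\\
&(R-L+1) - (N-x+1) = 5 - 3 = 2 \quad (\text{blanks at indices } 3 \text{ and } 5).
\end{aligned}
```

With K = 4 this state already has K-2 = 2 blanks, and indeed indices 2, 3, 5, 6 will form a valley however the blanks are filled.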
With this in mind, let’s define a function dp(x, L, R) to denote the number of valid ways of placing the values \ge x such that the bounds of placed elements are [L, R].
To compute this value, let’s look at where x is being placed.
First, it’s possible that x is placed neither at index L nor at index R, i.e. it doesn’t expand the active segment.
For this to happen, the previous active segment must have been [L, R] as well, and x can then go to any blank position inside it.
There are dp(x+1, L, R) ways to reach that segment, and it contains (R-L+1) - (N-x) blank positions before x is placed, so this case contributes ((R-L+1) - (N-x)) \cdot dp(x+1, L, R).
Secondly, it’s possible that x is placed at a boundary and hence expands it.
Let’s consider the case of placing x at L (placing it at R is symmetric.)
The previous segment must then have been some [L_0, R] with L \lt L_0 \le R.
For each of them there’s only one way to place x.
So, we obtain dp(x+1, L+1, R) + dp(x+1, L+2, R) + \ldots + dp(x+1, R, R) possibilities here.
This can be computed in constant time with the help of prefix sums.
The case of placing x at R can similarly be computed in constant time with prefix sums, giving dp(x+1, L, L) + dp(x+1, L, L+1) + \ldots + dp(x+1, L, R-1).
Note that we need to ensure the “not too many empty spaces” condition, i.e. the number of zeros never reaches K-2 or more.
This means (R-L+1) - (N-x+1) \le K-3 must hold.
If this condition fails, we can simply define dp(x, L, R) = 0 and not bother having to compute transitions at all.
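Putting the three cases and the blank-count condition together, the all-zeros recurrence reads as follows (base case dp(N, i, i) = 1 for every index i; the answer is dp(1, 1, N)):

```latex
dp(x, L, R) =
\begin{cases}
0, & (R-L+1) - (N-x+1) > K-3, \\[4pt]
\big((R-L+1) - (N-x)\big)\, dp(x+1, L, R)
  + \displaystyle\sum_{L_0 = L+1}^{R} dp(x+1, L_0, R)
  + \displaystyle\sum_{R_0 = L}^{R-1} dp(x+1, L, R_0), & \text{otherwise.}
\end{cases}
```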
Since there are \mathcal{O}(N^3) states, and each one’s computed in constant time, the overall complexity is \mathcal{O}(N^3) as well, which is fast enough.
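A minimal sketch of the all-zeros DP, assuming the modulus 998244353 from the tester's code and K \ge 3 (`countAllBlank` is a hypothetical name). It keeps only two N \times N layers and uses running sums for the two boundary transitions.

```cpp
#include <cassert>
#include <vector>

const long long MOD = 998244353;  // assumed, as in the tester's code

// prev[l][r] holds dp(x+1, l, r); cur[l][r] becomes dp(x, l, r). 1-indexed.
long long countAllBlank(int n, int k) {
    using Table = std::vector<std::vector<long long>>;
    Table prev(n + 2, std::vector<long long>(n + 2, 0));
    for (int i = 1; i <= n; i++) prev[i][i] = 1;  // dp(N, i, i) = 1
    for (int x = n - 1; x >= 1; x--) {
        Table cur(n + 2, std::vector<long long>(n + 2, 0));
        // x goes to one of the blanks strictly inside [l, r]
        for (int l = 1; l <= n; l++)
            for (int r = l; r <= n; r++) {
                long long zeros = (r - l + 1) - (n - x);  // blanks before x
                if (zeros > 0) cur[l][r] = prev[l][r] * zeros % MOD;
            }
        // x becomes the new left end: sum of prev[l0][r] for l < l0 <= r
        for (int r = 1; r <= n; r++) {
            long long sum = 0;
            for (int l = r - 1; l >= 1; l--) {
                sum = (sum + prev[l + 1][r]) % MOD;
                cur[l][r] = (cur[l][r] + sum) % MOD;
            }
        }
        // x becomes the new right end: sum of prev[l][r0] for l <= r0 < r
        for (int l = 1; l <= n; l++) {
            long long sum = 0;
            for (int r = l + 1; r <= n; r++) {
                sum = (sum + prev[l][r - 1]) % MOD;
                cur[l][r] = (cur[l][r] + sum) % MOD;
            }
        }
        // states with K-2 or more blanks are invalid
        for (int l = 1; l <= n; l++)
            for (int r = l; r <= n; r++)
                if ((r - l + 1) - (n - x + 1) > k - 3) cur[l][r] = 0;
        prev.swap(cur);
    }
    return prev[1][n];
}
```

With K = 3 no blank is ever allowed, so the segment grows one step left or right each time and the count is 2^{N-1}; when K \gt N nothing is pruned and all N! permutations are counted.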
Now let’s move to the case where A is partially filled.
It turns out that nothing much changes!
Once again, let’s place elements in descending order.
Now however, when placing x, we’ll treat all elements smaller than x as just not being present in the array at all (i.e. we treat their positions as containing zeros.)
Then, when placing x:
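- If x does not appear in A, the transitions are essentially the same, except that x may only go to a position that is blank in A. In the interior case, the number of choices is the number of blank positions of A inside [L, R] minus the number of already-placed values that were not fixed in A (so positions holding fixed elements \lt x are excluded). In the two boundary cases, the new endpoint must itself be a blank position of A.
- If x already appears in A at index p, there is exactly one way to place it, so the multiplier for choosing a position disappears: p must lie in [L, R], and we use the interior transition if L \lt p \lt R, or the corresponding boundary transition if p = L or p = R.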
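The two rules above can be sketched in the same pull form dp(x, L, R) as the explanation; the tester's code instead pushes states forward when x is fixed, so this is an alternative sketch rather than the reference solution. Assumptions: `countNoValley` is a hypothetical name, the input is a 0-indexed array with 0 for blanks (positions are 1-indexed internally), the modulus is 998244353, K \ge 3, and A is a valid partial permutation.

```cpp
#include <cassert>
#include <vector>

const long long MOD = 998244353;  // assumed, as in the tester's code

long long countNoValley(const std::vector<int>& a, int k) {
    int n = (int)a.size();
    std::vector<int> pos(n + 1, 0), freePre(n + 1, 0);
    for (int i = 1; i <= n; i++) {
        if (a[i - 1] != 0) pos[a[i - 1]] = i;  // fixed value -> its index
        freePre[i] = freePre[i - 1] + (a[i - 1] == 0 ? 1 : 0);
    }
    auto isFree = [&](int i) { return a[i - 1] == 0; };
    // blanks in [l, r] once all values >= x are placed must stay <= k - 3
    auto valid = [&](int l, int r, int x) { return (r - l + 1) - (n - x + 1) <= k - 3; };
    using Table = std::vector<std::vector<long long>>;
    Table prev(n + 2, std::vector<long long>(n + 2, 0));
    for (int i = 1; i <= n; i++)  // base case x = n
        if ((pos[n] == 0 && isFree(i)) || pos[n] == i) prev[i][i] = 1;
    long long usedFree = (pos[n] == 0) ? 1 : 0;  // placed values not fixed in A
    for (int x = n - 1; x >= 1; x--) {
        Table cur(n + 2, std::vector<long long>(n + 2, 0));
        int p = pos[x];  // 0 if x is not fixed in A
        // x strictly inside [l, r]: unused blanks of A, or exactly p
        for (int l = 1; l <= n; l++)
            for (int r = l; r <= n; r++) {
                long long ways = (p == 0) ? freePre[r] - freePre[l - 1] - usedFree
                                          : ((l < p && p < r) ? 1 : 0);
                if (ways > 0) cur[l][r] = prev[l][r] * ways % MOD;
            }
        // x becomes the new left end l (must be blank in A, or equal p)
        for (int r = 1; r <= n; r++) {
            long long sum = 0;  // prev[l+1][r] + ... + prev[r][r]
            for (int l = r - 1; l >= 1; l--) {
                sum = (sum + prev[l + 1][r]) % MOD;
                bool ok = (p == 0) ? isFree(l) : (p == l);
                if (ok && valid(l, r, x)) cur[l][r] = (cur[l][r] + sum) % MOD;
            }
        }
        // x becomes the new right end r (must be blank in A, or equal p)
        for (int l = 1; l <= n; l++) {
            long long sum = 0;  // prev[l][l] + ... + prev[l][r-1]
            for (int r = l + 1; r <= n; r++) {
                sum = (sum + prev[l][r - 1]) % MOD;
                bool ok = (p == 0) ? isFree(r) : (p == r);
                if (ok && valid(l, r, x)) cur[l][r] = (cur[l][r] + sum) % MOD;
            }
        }
        if (p == 0) usedFree++;
        prev.swap(cur);
    }
    return prev[1][n];
}
```

For example, with N = 4, K = 4 and A = [0, 1, 0, 0], the only fillings with a valley are those with P_3 = 2, leaving 4 valid fillings out of 6.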
This gives us a solution in \mathcal{O}(N^3) overall, so we’re done.
It’s possible to implement this using \mathcal{O}(N^2) memory, since dp(x, \cdot) depends only on dp(x+1, \cdot); this may be necessary for speed if your implementation has a large constant factor.
TIME COMPLEXITY:
\mathcal{O}(N^3) per testcase.
CODE:
Tester's code (C++)
```cpp
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());
const int mod = 998244353;
struct mint{
    int x;
    mint (){ x = 0;}
    mint (int32_t xx){ x = xx % mod; if (x < 0) x += mod;}
    mint (long long xx){ x = xx % mod; if (x < 0) x += mod;}
    int val(){
        return x;
    }
    mint &operator++(){
        x++;
        if (x == mod) x = 0;
        return *this;
    }
    mint &operator--(){
        if (x == 0) x = mod;
        x--;
        return *this;
    }
    mint operator++(int32_t){
        mint result = *this;
        ++*this;
        return result;
    }
    mint operator--(int32_t){
        mint result = *this;
        --*this;
        return result;
    }
    mint& operator+=(const mint &b){
        x += b.x;
        if (x >= mod) x -= mod;
        return *this;
    }
    mint& operator-=(const mint &b){
        x -= b.x;
        if (x < 0) x += mod;
        return *this;
    }
    mint& operator*=(const mint &b){
        long long z = x;
        z *= b.x;
        z %= mod;
        x = (int)z;
        return *this;
    }
    mint operator+() const {
        return *this;
    }
    mint operator-() const {
        return mint() - *this;
    }
    mint operator/=(const mint &b){
        return *this = *this * b.inv();
    }
    mint power(long long n) const {
        mint ok = *this, r = 1;
        while (n){
            if (n & 1){
                r *= ok;
            }
            ok *= ok;
            n >>= 1;
        }
        return r;
    }
    mint inv() const {
        return power(mod - 2);
    }
    friend mint operator+(const mint& a, const mint& b){ return mint(a) += b;}
    friend mint operator-(const mint& a, const mint& b){ return mint(a) -= b;}
    friend mint operator*(const mint& a, const mint& b){ return mint(a) *= b;}
    friend mint operator/(const mint& a, const mint& b){ return mint(a) /= b;}
    friend bool operator==(const mint& a, const mint& b){ return a.x == b.x;}
    friend bool operator!=(const mint& a, const mint& b){ return a.x != b.x;}
    mint power(mint a, long long n){
        return a.power(n);
    }
    friend ostream &operator<<(ostream &os, const mint &m) {
        os << m.x;
        return os;
    }
    explicit operator bool() const {
        return x != 0;
    }
};
// Remember to check MOD
void Solve()
{
    int n, k; cin >> n >> k;
    // pos[v] = index of value v in A (0 if v is not fixed); fr[i] = 1 if A[i] is blank
    vector <int> pos(n + 1), fr(n + 1, 1);
    for (int i = 1; i <= n; i++){
        int x; cin >> x;
        pos[x] = i;
        if (x) fr[i] = 0;
    }
    // dp[l][r] = ways to place all values >= v so that they span exactly [l, r]
    vector<vector<mint>> dp(n + 1, vector<mint>(n + 1));
    for (int i = 1; i <= n; i++) if ((pos[n] == 0 && fr[i] == 1)|| pos[n] == i){
        dp[i][i] = 1;
    }
    vector <int> pfr(n + 1);
    for (int i = 1; i <= n; i++){
        pfr[i] = pfr[i - 1] + fr[i];
    }
    // number of already-placed values that were not fixed in A
    int done = pos[n] == 0;
    // [l, r] must contain fewer than k - 2 blanks after placing all values >= v
    auto valid = [&](int l, int r, int v){
        int left = (r - l + 1) - (n - v + 1);
        if (left >= k - 2) return false;
        return true;
    };
    for (int v = n - 1; v >= 1; v--){
        vector<vector<mint>> ndp(n + 1, vector<mint>(n + 1));
        if (pos[v] != 0){
            // v is fixed at pos[v]: push every state to its (possibly extended) segment
            for (int l = 1; l <= n; l++){
                for (int r = l; r <= n; r++){
                    int nl = min(l, pos[v]);
                    int nr = max(r, pos[v]);
                    if (valid(nl, nr, v)){
                        ndp[nl][nr] += dp[l][r];
                    }
                }
            }
            dp = ndp;
            continue;
        }
        // v goes to one of the unused blank positions strictly inside [l, r]
        for (int l = 1; l <= n; l++){
            for (int r = l; r <= n; r++){
                int choices = pfr[r] - pfr[l - 1] - done;
                if (choices > 0){
                    ndp[l][r] += dp[l][r] * choices;
                }
            }
        }
        // Naive O(n)-per-state version of the two boundary transitions below:
        // for (int l = 1; l <= n; l++){
        //     for (int r = l; r <= n; r++){
        //         for (int i = 1; i <= n; i++) if (!(l <= i && i <= r) && fr[i]){
        //             int nl = min(l, i);
        //             int nr = max(r, i);
        //             if (valid(nl, nr, v)){
        //                 ndp[nl][nr] += dp[l][r];
        //             }
        //         }
        //     }
        // }
        // v becomes the new left end i; sum = dp[i + 1][r] + ... + dp[r][r]
        for (int r = 1; r <= n; r++){
            mint sum = 0;
            for (int i = r - 1; i >= 1; i--){
                // [i, r] is new interval
                sum += dp[i + 1][r];
                if (valid(i, r, v) && fr[i]) ndp[i][r] += sum;
            }
        }
        // v becomes the new right end i; sum = dp[l][l] + ... + dp[l][i - 1]
        for (int l = 1; l <= n; l++){
            mint sum = 0;
            for (int i = l + 1; i <= n; i++){
                sum += dp[l][i - 1];
                if (valid(l, i, v) && fr[i]) ndp[l][i] += sum;
            }
        }
        dp = ndp;
        done++;
    }
    cout << dp[1][n] << "\n";
}
int32_t main()
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in", "r", stdin);
    // freopen("out", "w", stdout);
    cin >> t;
    for(int i = 1; i <= t; i++)
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n";
    return 0;
}
```