PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: hellolad
Tester: kingmessi
Editorialist: iceknight1093
DIFFICULTY:
Medium
PREREQUISITES:
Dynamic programming
PROBLEM:
You’re given a random permutation of the integers 1 to N.
In one move, you can select an index i (1 \leq i \lt N) and set P_i and P_{i+1} both to \min(P_i, P_{i+1}).
Count the number of distinct arrays that can be obtained by performing any number of moves (possibly zero), modulo 998244353.
EXPLANATION:
The fact that the permutation is generated randomly is of course important (why else would it be mentioned?), but it’s not immediately obvious why.
Let’s just try to solve the problem normally, and see where we get stuck.
It’s always useful to make some observations about the process.
Two relatively straightforward ones here are:
- For any integer x, the set of positions containing x will always be a contiguous segment.
- For two indices i \lt j, if both P_i and P_j are present in the final array, then the segment containing P_i will be to the left of the one containing P_j.
Essentially, when one element starts to the left of another, it’s not possible for them to “cross over” without having to delete an element entirely.
These are both fairly easy to prove, but nonetheless good to keep in mind.
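To sanity-check these observations (and any DP written later) on tiny inputs, one option is to enumerate every reachable array by breadth-first search. The helpers below, `reachable_arrays` and `check_observations`, are hypothetical names. This is a sketch that is exponential in N, so it is only meant for very small permutations.

```python
from collections import deque


def reachable_arrays(p):
    """Brute force: BFS over every array reachable from p (tiny N only)."""
    start = tuple(p)
    seen = {start}
    queue = deque([start])
    while queue:
        cur = queue.popleft()
        for i in range(len(cur) - 1):
            m = min(cur[i], cur[i + 1])
            nxt = cur[:i] + (m, m) + cur[i + 2:]  # one move on (i, i+1)
            if nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return seen


def check_observations(p):
    """Check both observations on every array reachable from p."""
    pos = {v: k for k, v in enumerate(p)}  # original index of each value
    for arr in reachable_arrays(p):
        # Collapse equal neighbours into runs.
        runs = [arr[0]] + [arr[k] for k in range(1, len(arr)) if arr[k] != arr[k - 1]]
        # Observation 1: every value occupies a single contiguous run.
        if len(runs) != len(set(runs)):
            return False
        # Observation 2: surviving values keep their original left-to-right order.
        if any(pos[a] > pos[b] for a, b in zip(runs, runs[1:])):
            return False
    return True


print(len(reachable_arrays([2, 1, 3])))  # 4
```

For P = [2, 1, 3], the four arrays are [2, 1, 3], [1, 1, 3], [2, 1, 1] and [1, 1, 1].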
We’ll say that index i can reach index j if it’s possible for index j to receive the value P_i via some sequence of operations.
Observe that i can reach j if and only if there’s no element smaller than P_i on the segment from i to j.
So, if L_i \lt i and R_i \gt i are the closest indices to i containing elements smaller than P_i (with L_i = 0 if there's no such index on the left, and R_i = N+1 if there's none on the right), the set of indices reachable by i is exactly the range [L_i + 1, R_i - 1].
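The values L_i and R_i are the usual previous and next smaller elements. Below is a minimal sketch that computes them with a monotonic stack, using 1-indexed positions and the sentinels 0 and N+1. The names `nearest_smaller` and `reachable_ranges` are illustrative, not from the original code.

```python
def nearest_smaller(p):
    """Nearest smaller indices via a monotonic stack (positions are 1-indexed).

    Returns lists L, R of length n + 1 (index 0 unused) where
    L[i] is the closest index < i holding a smaller value (0 if none) and
    R[i] is the closest index > i holding a smaller value (n + 1 if none).
    """
    n = len(p)
    val = [0] + list(p)            # val[i] = P_i
    L = [0] * (n + 1)
    R = [n + 1] * (n + 1)
    stack = []                     # indices whose values increase bottom to top
    for i in range(1, n + 1):
        # Every larger value on the stack has found its next smaller element: i.
        while stack and val[stack[-1]] > val[i]:
            R[stack.pop()] = i
        # Whatever remains on top is the previous smaller element.
        L[i] = stack[-1] if stack else 0
        stack.append(i)
    return L, R


def reachable_ranges(p):
    """For each index i, the range [L_i + 1, R_i - 1] that P_i can spread over."""
    L, R = nearest_smaller(p)
    return [(L[i] + 1, R[i] - 1) for i in range(1, len(p) + 1)]


print(reachable_ranges([2, 1, 3]))  # [(1, 1), (1, 3), (3, 3)]
```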
Let’s also define \text{reach}_i to be the set of indices that can reach i.
A common idea when dealing with counting problems on an array is to use dynamic programming (usually on prefixes), so let’s give it a shot.
Define dp(i, j) to be the number of distinct arrangements of the prefix of length i, such that index i contains the element P_j.
Note that we don’t restrict ourselves to working with only elements that are initially among the first i elements - it’s possible for a later element to reach us too, so we keep the possibility open.
For example, if P = [2, 1, 3], then dp(1, 1) = 1 (the prefix [2] is possible - just don’t perform an operation) and dp(1, 2) = 1 ([1] is possible by operating on i = 1), but dp(1, 3) = 0 since it’s not possible to obtain [3].
How can dp(i, j) be computed?
Well, first things first, if j cannot reach i (i.e. i \not\in [L_j+1, R_j-1]) then dp(i, j) is definitely 0, so it’s enough to only look at those j which can indeed reach i.
So, let the indices which can reach i be, in ascending order, j_1, j_2, \ldots, j_m.
One of them is definitely i; suppose j_k = i.
Observe that since these are indices that can reach i, they have a rather special structure:
- Index j_{k-1} must be the closest index to the left of i holding an element smaller than P_i, i.e. j_{k-1} = L_i.
- Index j_{k-2} must be the closest index to the left of j_{k-1} holding an element smaller than P_{j_{k-1}}, i.e. j_{k-2} = L_{j_{k-1}}.
- In general, for each 1 \lt s \leq k, we must have j_{s-1} = L_{j_s}.
- Similarly, on the other side, for each k \leq s \lt m we have j_{s+1} = R_{j_s}.
That is, starting from i and moving outward in either direction, each index that can reach i is the nearest smaller element of the previous one.
This is a useful property to keep in mind for everything that follows.
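A small sketch can check the chain structure. It builds \text{reach}_i in two ways: directly from the ranges [L_j + 1, R_j - 1], and by walking the nearest-smaller links outward from i. To keep the block self-contained, L and R are found here by brute-force scanning, and the values are assumed to be a permutation of 1 to N, so the sentinel 0 is smaller than everything. All function names are hypothetical.

```python
def smaller_neighbours(p):
    """Brute-force nearest smaller indices, 1-indexed, with sentinels 0 and n+1."""
    n = len(p)
    val = [0] + list(p) + [0]  # values are 1..n, so the sentinels are smaller
    L = [0] * (n + 1)
    R = [0] * (n + 1)
    for i in range(1, n + 1):
        j = i - 1
        while val[j] > val[i]:
            j -= 1
        L[i] = j
        j = i + 1
        while val[j] > val[i]:
            j += 1
        R[i] = j
    return L, R


def reach_by_definition(p):
    """reach[i] = sorted indices j with L_j < i < R_j."""
    n = len(p)
    L, R = smaller_neighbours(p)
    reach = [[] for _ in range(n + 1)]
    for j in range(1, n + 1):
        for i in range(L[j] + 1, R[j]):
            reach[i].append(j)
    return reach


def reach_by_chain(p, i):
    """Follow j_{s-1} = L[j_s] to the left and j_{s+1} = R[j_s] to the right."""
    n = len(p)
    L, R = smaller_neighbours(p)
    left, j = [], L[i]
    while j >= 1:
        left.append(j)
        j = L[j]
    right, j = [], R[i]
    while j <= n:
        right.append(j)
        j = R[j]
    return left[::-1] + [i] + right


def chain_property_holds(p):
    reach = reach_by_definition(p)
    return all(reach_by_chain(p, i) == reach[i] for i in range(1, len(p) + 1))


print(reach_by_chain([3, 1, 4, 2, 5], 5))  # [2, 4, 5]
```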
Let’s now move to computing the values of dp.
Suppose we want to compute dp(i, j_s) for some s.
Let’s look at j_s \leq i first, i.e. an element that’s initially to the left of i.
To transport this value to index i, we need to go through indices j_s+1, j_s+2, \ldots, i, so all of them must certainly be set to P_{j_s}.
However, after doing so, we can still modify these indices further - but only using elements that are initially to the left of j_s.
This follows from the observation we made about the elements forming segments, right at the start: the segment containing P_{j_s} contains index i, so every position to the left of that segment can only hold elements that were originally at indices before j_s.
Now that index i has been fixed, we need to deal with the remaining i-1 elements.
It can be seen that any achievable configuration of them works, as long as the element finally at index i-1 comes from an index that's \leq j_s.
So, we obtain
dp(i, j_s) = \sum_{y \in \text{reach}_{i-1},\ y \leq j_s} dp(i-1, y)
That is, we can just sum up dp(i-1, y) across all “valid” indices y, where y is valid if it can reach i-1 and is not larger than j_s.
Next, let’s look at j_s \gt i.
As it turns out, the transition for this case is exactly the same!
For index i to contain the element P_{j_s}, we must go through the indices [i, j_s].
Now,
- Any index in \text{reach}_{i-1} that's \leq i-1 can certainly have its value appear at i-1 now, since nothing about the prefix before i was changed.
- Indices \gt j_s can never have their values appear at i-1, since values cannot “cross over”.
- That leaves indices in [i, j_s] that are in \text{reach}_{i-1}.
These can indeed appear at i-1 (even though they get overwritten later).
This is because we can simply perform operations with these values first, doing whatever we want with them to the left of i, and only then move P_{j_s} to index i.
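So, it turns out that for any j_s, we obtain the same recurrence:
dp(i, j_s) = \sum_{y \in \text{reach}_{i-1},\ y \leq j_s} dp(i-1, y)
The base case is dp(1, j) = 1 for every j \in \text{reach}_1, and the answer is the sum of dp(N, j) over all j \in \text{reach}_N.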
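Here is a direct transcription of this recurrence over reachable states only. It is a sketch with a hypothetical function name and 1-indexed positions, and it sums each transition naively, so it is meant for small inputs or for cross-checking against a brute force. The modulus 998244353 matches the reference solutions.

```python
MOD = 998244353


def count_final_arrays_dp(p):
    """dp over reachable states; positions are 1-indexed.

    dp[j] holds dp(i, j) for every j in reach_i.
    Transition: dp(i, j) = sum of dp(i-1, y) over y in reach_{i-1} with y <= j.
    Base case dp(1, j) = 1; the answer is the sum of dp(N, j).
    """
    n = len(p)
    val = [0] + list(p) + [0]   # sentinel 0s are smaller than every P_i
    reach = [[] for _ in range(n + 1)]
    for j in range(1, n + 1):
        lo, hi = j, j
        while val[lo - 1] > val[j]:  # extend left while values stay larger
            lo -= 1
        while val[hi + 1] > val[j]:  # extend right while values stay larger
            hi += 1
        for i in range(lo, hi + 1):
            reach[i].append(j)       # j ascending, so each reach[i] is sorted
    dp = {j: 1 for j in reach[1]}
    for i in range(2, n + 1):
        prev = dp
        dp = {j: sum(v for y, v in prev.items() if y <= j) % MOD for j in reach[i]}
    return sum(dp.values()) % MOD


print(count_final_arrays_dp([2, 1, 3]))  # 4
```

For P = [2, 1, 3] this reproduces the values from the earlier example: dp(1, 1) = dp(1, 2) = 1, then dp(2, 2) = 2, and finally dp(3, 2) = dp(3, 3) = 2.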
This dynamic programming solution, while correct, appears quite slow: we have up to \mathcal{O}(N^2) states, and \mathcal{O}(N) transitions from each of them for \mathcal{O}(N^3) overall.
The transitions are easy enough to optimize to constant time: note that we only want some prefix sum of dp(i-1, \cdot) so just computing these prefix sums in advance (or on-the-fly with two pointers) will optimize that.
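The two-pointer version of one DP layer can be sketched as a merge of two sorted index lists. Because both \text{reach}_{i-1} and \text{reach}_i are sorted, a running prefix sum moves forward monotonically. The helper name `transition` and its list-based interface are assumptions made for illustration.

```python
MOD = 998244353


def transition(prev_idx, prev_dp, cur_idx):
    """One dp layer using a running prefix sum (two pointers).

    prev_idx: sorted indices in reach_{i-1}; prev_dp[k] = dp(i-1, prev_idx[k])
    cur_idx:  sorted indices in reach_i
    Returns [dp(i, j) for j in cur_idx].
    """
    out = []
    ptr = 0
    running = 0  # sum of prev_dp[k] over all prev_idx[k] <= current j
    for j in cur_idx:
        while ptr < len(prev_idx) and prev_idx[ptr] <= j:
            running = (running + prev_dp[ptr]) % MOD
            ptr += 1
        out.append(running)
    return out


# P = [3, 2, 1]: reach_1 = [1, 2, 3], reach_2 = [2, 3], reach_3 = [3]
layer2 = transition([1, 2, 3], [1, 1, 1], [2, 3])  # [2, 3]
layer3 = transition([2, 3], layer2, [3])           # [5]
print(layer2, layer3)
```

Each layer costs \mathcal{O}(|\text{reach}_{i-1}| + |\text{reach}_i|), so the total work is linear in the number of states.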
However, this still leaves the issue of too many states in the DP.
This is where we use the fact that the permutation is randomly generated: it turns out that the expected number of states is quite small!
In particular, it can be proved that the expected number of indices that can reach a given index is \mathcal{O}(\log N), so the expected number of states is \mathcal{O}(N\log N).
To see why, fix an index i: the indices to its left that can reach it are exactly the positions where a new minimum appears when scanning leftward from i (starting with P_i as the current minimum), and similarly on the right. It's well-known that the number of such prefix minimums in a random permutation is \mathcal{O}(\log N) - see Blogewoosh #6 (a comment there contains an actual proof).
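As a cross-check of this bound, the expected number of states can also be computed exactly with a standard argument, which is different from the one referenced above. For |i - j| = d, index j can reach index i exactly when P_j is the smallest of the d+1 values on the segment between them. In a random permutation this happens with probability 1/(d+1), and there are N - d ordered pairs at distance d in each direction:

```latex
\mathbb{E}\left[\sum_{i=1}^{N} |\text{reach}_i|\right]
  = \sum_{i=1}^{N}\sum_{j=1}^{N} \Pr[\, j \text{ can reach } i \,]
  = N + 2\sum_{d=1}^{N-1} \frac{N-d}{d+1}
  = 2(N+1)(H_N - 1) - N + 2
  = \mathcal{O}(N \log N)
```

Here H_N is the N-th harmonic number. For N = 2\cdot 10^5 this evaluates to roughly 4.51\cdot 10^6.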
Empirically, you can also generate a random permutation yourself and check that this is fast enough: for N = 2\cdot 10^5 there are about 4.5\cdot 10^6 states on average.
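A quick experiment along these lines is sketched below. `count_states` finds each reachable range by brute-force outward scanning, which is the same shortcut suggested below for random inputs. For tiny N it compares the exact average over all permutations with the closed form N + 2\sum_{d=1}^{N-1} (N-d)/(d+1). The function names and the seed are arbitrary choices.

```python
import itertools
import random
from fractions import Fraction


def count_states(p):
    """Number of dp states, i.e. the sum of |reach_i| over all indices i.

    Equivalently, the sum over j of the length of the range P_j can spread
    over, found by scanning outward from j until a smaller value appears.
    """
    n = len(p)
    total = 0
    for j in range(n):
        lo = j
        while lo > 0 and p[lo - 1] > p[j]:
            lo -= 1
        hi = j
        while hi < n - 1 and p[hi + 1] > p[j]:
            hi += 1
        total += hi - lo + 1
    return total


def expected_states(n):
    """Exact expectation n + 2 * sum_{d=1}^{n-1} (n - d) / (d + 1)."""
    return n + 2 * sum((Fraction(n - d, d + 1) for d in range(1, n)), Fraction(0))


def average_states_exact(n):
    """Average of count_states over all n! permutations (tiny n only)."""
    perms = list(itertools.permutations(range(1, n + 1)))
    return Fraction(sum(count_states(q) for q in perms), len(perms))


if __name__ == "__main__":
    rng = random.Random(2024)
    n = 20_000  # raise to 2 * 10**5 to compare with the figure quoted above
    p = list(range(1, n + 1))
    rng.shuffle(p)
    approx = n + 2 * sum((n - d) / (d + 1) for d in range(1, n))
    print(count_states(p), round(approx))
```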
So, simply running the DP on just reachable states will result in a solution that’s fast enough.
Finally, note that the number of states being small also simplifies parts of the implementation:
- Rather than have to use a stack to compute previous/next smaller elements, you can just bruteforce this and it’ll be fast enough (not that the stack approach is complicated, of course).
- Similarly, rather than needing prefix sums to optimize DP transitions, they can be performed in a brute-force manner - after all, we’ll only iterate through \mathcal{O}(\log N) values on average anyway.
Doing this will make the complexity \mathcal{O}(N\log^2 N), which is still quite fine for the limits of N \leq 2\cdot 10^5.
TIME COMPLEXITY:
Expected \mathcal{O}(N\log N) or \mathcal{O}(N\log^2 N) per testcase.
CODE:
Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define IOS ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0);
#define ll long long
const int mod=998244353;
int main(){
    IOS
    int t;
    cin>>t;
    while(t--){
        int n;
        cin>>n;
        vector<int> p(n+2);
        for(int i=1;i<=n;++i){
            cin>>p[i];
        }
        vector<int> l(n+1);
        for(int i=1;i<=n;++i){
            int j=i-1;
            while(p[i]<p[j]){
                j=l[j];
            }
            l[i]=j;
        }
        vector<int> r(n+1);
        for(int i=n;i>=1;--i){
            int j=i+1;
            while(p[i]<p[j]){
                j=r[j];
            }
            r[i]=j;
        }
        vector<vector<int>> a(n+1);
        a[0].push_back(0);
        for(int i=1;i<=n;++i){
            for(int j=l[i]+1;j<r[i];++j){
                a[j].push_back(i);
            }
        }
        vector<int> dp;
        dp.push_back(1);
        for(int i=1;i<=n;++i){
            a[i-1].push_back(n+1);
            int p=0;
            int cur=0;
            vector<int> ndp;
            for(auto &x:a[i]){
                while(a[i-1][p]<=x){
                    cur=(cur+dp[p])%mod;
                    ++p;
                }
                ndp.push_back(cur);
            }
            dp=ndp;
        }
        int ans=0;
        for(auto &x:dp){
            ans=(ans+x)%mod;
        }
        cout<<ans<<'\n';
    }
    return 0;
}
Tester's code (C++)
#include<bits/stdc++.h>
#define ll long long
#define rep(i,a,b) for(int i=a;i<b;i++)
#define rrep(i,a,b) for(int i=a;i>=b;i--)
#define repin rep(i,0,n)
#define vi vector<int>
using namespace std;
struct input_checker {
    string buffer;
    int pos;
    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";
    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }
    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && !isspace(buffer[now])) {
            now++;
        }
        return now;
    }
    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }
    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }
    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }
    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }
    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }
    void readEof() {
        assert((int) buffer.size() == pos);
    }
}inp;
const long long MM=998244353;
int smn = 0;
void solve()
{
    int n;
    // cin >> n;
    n = inp.readInt(1,200'000);
    inp.readEoln();
    smn += n;
    vi a(n);
    for(int i = 0;i < n;i++){
        a[i] = inp.readInt(1,n);
        if(i == n-1)inp.readEoln();
        else inp.readSpace();
    }
    set<int> uq(a.begin(),a.end());
    assert(uq.size() == n);
    vector<int> ns(n,n);
    vector<int> ls(n,-1);
    vector<int> v;
    v.reserve(n);
    repin{
        while(v.size() && a[v.back()] > a[i]){
            ns[v.back()] = i;
            v.pop_back();
        }
        v.push_back(i);
    }
    v.clear();
    rrep(i,n-1,0){
        while(v.size() && a[v.back()] > a[i]){
            ls[v.back()] = i;
            v.pop_back();
        }
        v.push_back(i);
    }
    vector<vector<int>> pos(n);//what positions can come at each index
    repin{
        rep(j,ls[i]+1,ns[i])pos[j].push_back(i);
    }
    repin{
        reverse(pos[i].begin(),pos[i].end());
    }
    //make positions increasing;
    vector<array<int,2>> dp;
    for(int i = 0;i < pos[0].size();i++)dp.push_back({pos[0][i],1});
    for(int i = 1;i < n;i++){
        vector<array<int,2>> dp1;
        int sm = 0;
        while(pos[i].size()){
            while(dp.size() && dp.back()[0] <= pos[i].back()){
                sm += dp.back()[1];
                sm %= MM;
                dp.pop_back();
            }
            dp1.push_back({pos[i].back(),sm});
            pos[i].pop_back();
        }
        dp = dp1;
        reverse(dp.begin(),dp.end());
    }
    int ans = 0;
    for(auto &x : dp)ans += x[1],ans %= MM;
    ans += MM;ans %= MM;
    cout << ans << "\n";
}
signed main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    // int t; cin >> t;
    int t = inp.readInt(1,100'000);
    inp.readEoln();
    while(t--)
        solve();
    inp.readEof();
    assert(smn <= 200'000);
    return 0;
}
Editorialist's code (PyPy3)
mod = 998244353
for _ in range(int(input())):
    n = int(input())
    p = list(map(int, input().split()))
    reach = [ [] for _ in range(n) ]
    for i in range(n):
        for j in range(i, n):
            if p[j] < p[i]: break
            reach[j].append(i)
        for j in reversed(range(i)):
            if p[j] < p[i]: break
            reach[j].append(i)
    dp = [1]*len(reach[0])
    for i in range(1, n):
        ndp = []
        for j in reach[i]:
            s = 0
            for ii in range(len(reach[i-1])):
                if reach[i-1][ii] <= j: s += dp[ii]
            ndp.append(s % mod)
        dp = ndp[:]
    print(sum(dp) % mod)