LEXMATCH - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: q_ed
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

Medium

PREREQUISITES:

Dynamic Programming

PROBLEM:

You’re given N, K, and array B of length N.
Count the number of arrays A of length N with elements in [1, K] such that, for each i, the lexicographically largest subsequence of the first i elements of A has length B_i.

EXPLANATION:

Let’s first understand how to find the lexicographically largest subsequence of a given array A - or more generally, how to compute it for every prefix of A, since that’s what we care about.

Suppose we’ve found the subsequence for the prefix of length i; say it is (s_1, s_2, \ldots, s_k), of length k. Note that the s_j here are the elements themselves, not their indices.
Observe that s must be a non-increasing sequence, because if s_j \lt s_{j+1} then simply discarding s_j will give a (lexicographically) larger subsequence.

Now, when considering A_{i+1}, let j be the largest index such that s_j \geq A_{i+1} (or j = 0 if no such index exists).
The lexicographically largest subsequence of the length-(i+1) prefix is then simply
(s_1, s_2, \ldots, s_j, A_{i+1}); that is, we discard all elements strictly smaller than A_{i+1} (since s is non-increasing, these are exactly s_{j+1}, \ldots, s_k) and then append A_{i+1}.
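A minimal C++ sketch of this update rule (the function name prefixLexMaxLengths is hypothetical). The vector s plays the role of (s_1, \ldots, s_k); because it is non-increasing, the elements smaller than the new value are exactly the ones popped from its end, so the whole array of lengths B is produced in amortized \mathcal{O}(N).

```cpp
#include <cassert>
#include <vector>

// lengths[i] = length of the lexicographically largest subsequence of a[0..i],
// i.e. B_{i+1} in the editorial's 1-indexed notation.
std::vector<int> prefixLexMaxLengths(const std::vector<int>& a) {
    std::vector<int> s;  // current lexicographically largest subsequence
    std::vector<int> lengths;
    lengths.reserve(a.size());
    for (int x : a) {
        // Discard every element strictly smaller than x; they all sit at the end
        // because s is non-increasing. Then append x.
        while (!s.empty() && s.back() < x) s.pop_back();
        s.push_back(x);
        // Invariant: s stays non-increasing.
        assert(s.size() < 2 || s[s.size() - 2] >= s.back());
        lengths.push_back(static_cast<int>(s.size()));
    }
    return lengths;
}
```

For example, A = (1, 3, 2, 2, 3) gives B = (1, 1, 2, 3, 2): the final subsequence is (3, 3).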

Proof

This is fairly easy to prove. One way is to note that the lexicographically largest subsequence can be found using the following greedy algorithm: pick the leftmost maximum of the array, discard it along with all elements to its left, and repeat on the remaining suffix (the picked elements, in order, form the answer).

Consider applying this algorithm to the prefixes A[1, i] and A[1, i+1].
It behaves exactly the same way on both until the value s_j has been picked (if j = 0, there is no such common part).
At that point, the remaining elements of A[1, i] are all \lt A_{i+1}.
So, for A[1, i+1] the algorithm picks A_{i+1} itself and then stops because nothing remains, while for A[1, i] it continues and ends up picking s_{j+1}, \ldots, s_k.
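The greedy algorithm from the proof and the incremental stack rule can be cross-checked by exhaustive search on tiny inputs. This is a sketch with hypothetical helper names; it relies on std::max_element, which returns the first (leftmost) of several equal maxima, exactly as the greedy requires.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Greedy from the proof: take the leftmost maximum of the remaining suffix,
// then continue strictly to its right.
std::vector<int> lexLargestGreedy(const std::vector<int>& a) {
    std::vector<int> result;
    auto it = a.begin();
    while (it != a.end()) {
        auto best = std::max_element(it, a.end());  // first largest element
        result.push_back(*best);
        it = best + 1;
    }
    return result;
}

// Same subsequence built incrementally with the stack rule.
std::vector<int> lexLargestStack(const std::vector<int>& a) {
    std::vector<int> s;
    for (int x : a) {
        while (!s.empty() && s.back() < x) s.pop_back();
        s.push_back(x);
    }
    return s;
}

// Compares both methods on every array of length n with values in [1, k].
bool greedyMatchesStack(int n, int k) {
    assert(n >= 0 && k >= 1);
    std::vector<int> a(n, 1);
    while (true) {
        if (lexLargestGreedy(a) != lexLargestStack(a)) return false;
        // Advance a like an odometer with digits 1..k.
        int pos = 0;
        while (pos < n && a[pos] == k) a[pos++] = 1;
        if (pos == n) return true;
        ++a[pos];
    }
}
```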

In particular, if B_i denotes the length of the lexicographically largest subsequence of the first i elements, the above tells us that 1 \le B_{i+1} \le B_i + 1 must hold, since we always append exactly one new element but can discard any number of existing ones.
Also, B_1 = 1 must obviously hold.

If the given array B doesn’t satisfy these two conditions, no valid array can exist, so just output 0.
Otherwise, we need to count the number of valid A.
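The two necessary conditions translate directly into a check. A sketch with the hypothetical name passesNecessaryChecks, storing B 0-indexed and assuming N \geq 1:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// True iff B satisfies B_1 = 1 and 1 <= B_{i+1} <= B_i + 1 for every i.
bool passesNecessaryChecks(const std::vector<int>& b) {
    assert(!b.empty());  // assumes N >= 1
    if (b[0] != 1) return false;
    for (std::size_t i = 1; i < b.size(); ++i)
        if (b[i] < 1 || b[i] > b[i - 1] + 1) return false;
    return true;
}
```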


Let i_1 denote the index of the leftmost maximum element in A.
Observe that we definitely must have B_{i_1} = 1, and for all j \gt i_1, B_j \gt 1 must hold (since B_j = 1 can happen if and only if A_j is strictly larger than all previous elements).

So, the index i_1 is in fact fixed: it’s the rightmost occurrence of 1 in B.
Let’s find this index i_1.
Now, suppose we decide that A_{i_1} = x (where 1 \leq x \leq K).
Then,

  • All elements at indices \lt i_1 must take values \lt x, since i_1 must be the leftmost maximum in A.
  • All elements at indices \gt i_1 must take values \leq x (in particular, they are allowed to equal x).
  • The parts [1, i_1-1] and [i_1+1, N] are completely independent of each other, because once we reach index i_1 we’ll end up discarding all previous elements anyway.

So, once A_{i_1} = x is fixed, we can simply multiply the number of ways of filling in [1, i_1-1] with elements in [1, x-1], and the number of ways of filling in [i_1+1, N] with elements in [1, x].
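As a small illustration (values chosen here for the example, not taken from the problem's samples), take N = 2, K = 2 and B = (1, 2). The rightmost 1 in B is at index 1, so i_1 = 1; the part before it is empty, and A_2 may be any value in [1, x]:

$$
\sum_{x=1}^{2} \underbrace{1}_{[1,\,0]\ \text{empty}} \cdot \underbrace{x}_{A_2 \in [1,\,x]} \;=\; 1 + 2 \;=\; 3
$$

This matches the three arrays (1, 1), (2, 1) and (2, 2); the fourth candidate (1, 2) has B_2 = 1 and is correctly excluded.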

Now, the number of ways of filling in [1, i_1-1] is basically the exact same process: find the rightmost occurrence of 1 in it, choose the value there, and recursively break it into smaller parts.
However, [i_1+1, N] no longer has any occurrences of 1, so what do we do?
Choose the rightmost occurrence of 2, of course! The same logic that showed i_1 must be the rightmost occurrence of 1 shows that the next ‘fixed’ element (i.e. the leftmost maximum of this range) must be at the rightmost occurrence of 2 in this range.

More generally, if we’re considering the range [L, R], the index we’re interested in is the rightmost occurrence of \min(B_L, B_{L+1}, \ldots, B_R) in this range - this is where the leftmost maximum of the range will be.
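Both reference codes locate each pivot by scanning (the editorialist precomputes an N\times N table). As an alternative that is not used in those codes, the whole tree of splits can be built in \mathcal{O}(N) with a monotonic stack: it is a Cartesian tree of B in which ties go to the rightmost index. A sketch with hypothetical names (PivotTree, buildPivotTree, checkRange, pivotTreeIsCorrect); the last two compare the tree against a brute-force rightmost-minimum scan.

```cpp
#include <cassert>
#include <vector>

// left[m] / right[m] are the pivots of [L, m-1] and [m+1, R] (-1 if empty).
struct PivotTree {
    int root = -1;
    std::vector<int> left, right;
};

PivotTree buildPivotTree(const std::vector<int>& b) {
    const int n = static_cast<int>(b.size());
    PivotTree t;
    t.left.assign(n, -1);
    t.right.assign(n, -1);
    std::vector<int> stk;  // indices with strictly increasing B, bottom to top
    for (int i = 0; i < n; ++i) {
        int last = -1;
        // Pop indices with B >= b[i]: on ties the later index i becomes the
        // ancestor, which implements the "rightmost minimum" rule.
        while (!stk.empty() && b[stk.back()] >= b[i]) {
            last = stk.back();
            stk.pop_back();
        }
        t.left[i] = last;
        if (!stk.empty()) t.right[stk.back()] = i;
        stk.push_back(i);
    }
    assert(stk.empty() == (n == 0));  // the stack is empty only for empty input
    t.root = stk.empty() ? -1 : stk.front();
    return t;
}

// Brute force: is `node` the rightmost minimum of b[lo..hi], recursively?
bool checkRange(const std::vector<int>& b, const PivotTree& t, int node, int lo, int hi) {
    if (lo > hi) return node == -1;
    int best = lo;
    for (int i = lo; i <= hi; ++i)
        if (b[i] <= b[best]) best = i;  // <= keeps the rightmost minimum
    return node == best &&
           checkRange(b, t, t.left[node], lo, node - 1) &&
           checkRange(b, t, t.right[node], node + 1, hi);
}

bool pivotTreeIsCorrect(const std::vector<int>& b) {
    PivotTree t = buildPivotTree(b);
    return checkRange(b, t, t.root, 0, static_cast<int>(b.size()) - 1);
}
```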

This allows us to write a recursion as follows:
Let f(L, R, K) be the number of ways to assign values to A_L, \ldots, A_R such that all these values are \leq K, while respecting the constraint on B.
Then,

  • Let m be the index of the rightmost minimum of [B_L, \ldots, B_R].
  • Then, trying all choices for A_m, we have:
f(L, R, K) = \sum_{x=1}^{K} f(L, m-1, x-1) \cdot f(m+1, R, x)
  • A simpler way to write this is to note that it equals f(L, R, K-1) + f(L, m-1, K-1)\cdot f(m+1, R, K).
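    The first term corresponds to having A_m \lt K, while the second is for A_m = K.
  • The base cases are f(L, R, K) = 1 when L \gt R (an empty range has exactly one assignment), and f(L, R, K) = 0 when L \le R but K \le 0.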
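For tiny inputs, both forms of the recurrence can be checked against brute-force enumeration of all K^N arrays. This sketch is deliberately unmemoized (exponential time), its names (bruteCount, fSum, fStep) are hypothetical, and B is 0-indexed and assumed to already pass the validity checks.

```cpp
#include <cassert>
#include <vector>

// Enumerates every A in [1, k]^n and counts those whose prefix
// lexicographically-largest-subsequence lengths equal b.
long long bruteCount(const std::vector<int>& b, int k) {
    assert(k >= 1);
    const int n = static_cast<int>(b.size());
    std::vector<int> a(n, 1);
    long long count = 0;
    while (true) {
        std::vector<int> s;
        bool ok = true;
        for (int i = 0; i < n && ok; ++i) {
            while (!s.empty() && s.back() < a[i]) s.pop_back();
            s.push_back(a[i]);
            ok = static_cast<int>(s.size()) == b[i];
        }
        if (ok) ++count;
        int pos = 0;  // odometer increment
        while (pos < n && a[pos] == k) a[pos++] = 1;
        if (pos == n) return count;
        ++a[pos];
    }
}

// Index of the rightmost minimum of b[L..R].
int pivotOf(const std::vector<int>& b, int L, int R) {
    int m = L;
    for (int i = L; i <= R; ++i)
        if (b[i] <= b[m]) m = i;
    return m;
}

// f(L, R, K) = sum_{x=1}^{K} f(L, m-1, x-1) * f(m+1, R, x)
long long fSum(const std::vector<int>& b, int L, int R, int K) {
    if (L > R) return 1;
    if (K <= 0) return 0;
    int m = pivotOf(b, L, R);
    long long total = 0;
    for (int x = 1; x <= K; ++x)
        total += fSum(b, L, m - 1, x - 1) * fSum(b, m + 1, R, x);
    return total;
}

// f(L, R, K) = f(L, R, K-1) + f(L, m-1, K-1) * f(m+1, R, K)
long long fStep(const std::vector<int>& b, int L, int R, int K) {
    if (L > R) return 1;
    if (K <= 0) return 0;
    int m = pivotOf(b, L, R);
    return fStep(b, L, R, K - 1) + fStep(b, L, m - 1, K - 1) * fStep(b, m + 1, R, K);
}
```

For example, with B = (1, 2) and K = 2, all three functions return 3.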

The recursion can be optimized by storing already computed values, but what exactly is its complexity?

At a glance, there are \mathcal{O}(N^2 K) states, because there are \mathcal{O}(N^2) choices for the range [L, R] and then the upper bound on values is K.
However, we don’t actually visit most of these states.
In fact, it can be seen that the number of visited states is only \mathcal{O}(NK), a whole factor of N smaller!

Why?

We start with the range [1, N], and each time split a range [L, R] at its pivot m into the two disjoint ranges [L, m-1] and [m+1, R], neither of which contains m.

Hence every index is the pivot of exactly one range, so exactly N non-empty ranges are ever visited (empty ranges are base cases and cost nothing).
Each of them is paired only with value bounds in \{0, 1, \ldots, K\}, so there are at most N(K+1) states, and each takes \mathcal{O}(1) time apart from its recursive calls.
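The state count can also be checked empirically. The following hypothetical instrumentation replays the splitting on a given B with an explicit stack, asserts that no index is ever used as a pivot twice, and returns the number of non-empty ranges visited, which should equal N.

```cpp
#include <cassert>
#include <utility>
#include <vector>

int countVisitedRanges(const std::vector<int>& b) {
    const int n = static_cast<int>(b.size());
    std::vector<bool> usedAsPivot(n, false);
    std::vector<std::pair<int, int>> todo = {{0, n - 1}};
    int visited = 0;
    while (!todo.empty()) {
        auto [L, R] = todo.back();
        todo.pop_back();
        if (L > R) continue;  // empty range: base case, not counted
        int m = L;            // rightmost minimum of b[L..R]
        for (int i = L; i <= R; ++i)
            if (b[i] <= b[m]) m = i;
        assert(!usedAsPivot[m]);  // no two ranges share a pivot
        usedAsPivot[m] = true;
        ++visited;
        todo.push_back({L, m - 1});
        todo.push_back({m + 1, R});
    }
    return visited;
}
```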

\mathcal{O}(NK) is indeed fast enough for N, K \leq 3000, so all we have to do is output f(1, N, K) and we’re done!

One thing to be careful about here is how previously computed states are cached.
With N, K \leq 3000 there can be around 9\cdot 10^6 states, so using a map (or a similar associative container) will end up being too slow.
Instead, there are a couple of ways to implement this without using any heavy structures:

  1. When computing f(L, R, K), let m be the ‘pivot’ element in the range.
    Assign the result of this calculation to dp[m][K].
    Two different ranges cannot have the same pivot index, so we only need a single N\times K 2D array which is perfectly fine (and fast).
  2. Alternatively, make the function f(L, R) return a vector of length K+1 storing the answers for all 0 \leq k \leq K.
    This allows us to process each segment exactly once, since knowing the results for its two children allows for merging in \mathcal{O}(K).
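A minimal sketch of approach 2 done bottom-up instead of recursively, assuming the answer is reported modulo 998244353 as in both reference codes. Rather than scanning for pivots, it builds the split tree in \mathcal{O}(N) with a monotonic stack (a Cartesian tree of B with ties going to the rightmost index), a technique swapped in here and not taken from the reference codes; countArrays is a hypothetical name.

```cpp
#include <cassert>
#include <utility>
#include <vector>

long long countArrays(const std::vector<int>& b, int K) {
    assert(K >= 1);
    const long long MOD = 998244353;
    const int n = static_cast<int>(b.size());
    // Necessary conditions (assumes N >= 1): B_1 = 1 and 1 <= B_{i+1} <= B_i + 1.
    if (n == 0 || b[0] != 1) return 0;
    for (int i = 1; i < n; ++i)
        if (b[i] < 1 || b[i] > b[i - 1] + 1) return 0;

    // Split tree: every node is the rightmost minimum of its range.
    std::vector<int> left(n, -1), right(n, -1), stk;
    for (int i = 0; i < n; ++i) {
        int last = -1;
        while (!stk.empty() && b[stk.back()] >= b[i]) { last = stk.back(); stk.pop_back(); }
        left[i] = last;
        if (!stk.empty()) right[stk.back()] = i;
        stk.push_back(i);
    }
    const int root = stk.front();

    // Each node is listed before its descendants; reversing gives children first.
    std::vector<int> order, todo = {root};
    while (!todo.empty()) {
        int v = todo.back();
        todo.pop_back();
        order.push_back(v);
        if (left[v] != -1) todo.push_back(left[v]);
        if (right[v] != -1) todo.push_back(right[v]);
    }

    const std::vector<long long> emptyRange(K + 1, 1);  // f(empty, k) = 1
    std::vector<std::vector<long long>> F(n);           // F[v][k] = f(range of v, k)
    for (int idx = n - 1; idx >= 0; --idx) {
        int v = order[idx];
        const auto& lf = left[v] == -1 ? emptyRange : F[left[v]];
        const auto& rf = right[v] == -1 ? emptyRange : F[right[v]];
        std::vector<long long> cur(K + 1, 0);  // cur[0] = 0: needs a value >= 1
        for (int k = 1; k <= K; ++k)
            // f(range, k) = f(range, k-1) + f(left, k-1) * f(right, k)
            cur[k] = (cur[k - 1] + lf[k - 1] * rf[k]) % MOD;
        F[v] = std::move(cur);
        // The children's vectors are no longer needed; release their memory.
        if (left[v] != -1) std::vector<long long>().swap(F[left[v]]);
        if (right[v] != -1) std::vector<long long>().swap(F[right[v]]);
    }
    return F[root][K];
}
```

For instance, B = (1, 2, 3) with K = 2 counts the non-increasing arrays of length 3 over [1, 2], of which there are 4.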

TIME COMPLEXITY:

\mathcal{O}(NK) per testcase.

CODE:

Editorialist's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    int t; cin >> t;
    while (t--) {
        int n, k; cin >> n >> k;
        vector b(n, 0);
        for (int &x : b) cin >> x;

        // mn[i][j] = index of the rightmost minimum of b[i..j], i.e. the pivot of [i, j].
        vector mn(n, vector(n, n));
        for (int i = 0; i < n; ++i) {
            mn[i][i] = i;
            for (int j = i+1; j < n; ++j) {
                if (b[j] <= b[mn[i][j-1]]) mn[i][j] = j;
                else mn[i][j] = mn[i][j-1];
            }
        }

        // Necessary conditions: B_1 = 1 and B_{i+1} <= B_i + 1.
        {
            bool good = b[0] == 1;
            for (int i = 1; i < n; ++i)
                good &= b[i] <= b[i-1] + 1;
            if (!good) {
                cout << 0 << '\n';
                continue;
            }
        }

        const int mod = 998244353;
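        // dp[pivot][K] caches f(L, R, K); a range is identified by its pivot.
        vector dp(n, vector(k+1, 0));
        vector mark(n, vector(k+1, 0));
        auto solve = [&] (const auto &self, int L, int R, int K) -> int {
            if (L > R) return 1;   // empty range: one (empty) assignment
            if (K <= 0) return 0;  // non-empty range with no allowed values
            if (L == R) return K;  // single element: any value in [1, K]

            int pivot = mn[L][R];
            if (mark[pivot][K]) return dp[pivot][K];
            // A[pivot] = K (left part < K, right part <= K), plus every case with A[pivot] < K.
            ll cur = (1ll * self(self, L, pivot-1, K-1) * self(self, pivot+1, R, K) + self(self, L, R, K-1) ) % mod;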
            mark[pivot][K] = 1;
            return dp[pivot][K] = cur;
        };
        cout << solve(solve, 0, n-1, k) << '\n';
    }
}
Author's code (C++)
#include <bits/stdc++.h>
template<typename T1, typename T2>
std::ostream& operator<<(std::ostream& os, const std::pair<T1, T2>& p) {
    os << "(" << p.first << ", " << p.second << ")";
    return os;
}
template <typename T, std::size_t N>
std::ostream& operator<<(std::ostream& os, const std::array<T, N>& arr) {
    os << "[";
    for (std::size_t i = 0; i < N; ++i) {
        os << arr[i];
        if (i < N - 1) {
            os << ", ";
        }
    }
    os << "]";
    return os;
}
template<typename T> std::ostream& operator<<(std::ostream& os, const std::set<T>& s) {
    os << "{ ";
    for(const auto& elem : s) {
        os << elem << " ";
    }
    os << "}";
    return os;
}
template<typename T> std::ostream& operator<<(std::ostream& os, const std::multiset<T>& s) {
    os << "{ ";
    for(const auto& elem : s) {
        os << elem << " ";
    }
    os << "}";
    return os;
}

template<typename T> std::ostream& operator<<(std::ostream& os, std::queue<T> q) {
    // Print each element in the queue
    os << "{ ";
    while (!q.empty()) {
        os << q.front() << " ";
        q.pop();
    }
    os << "}";
    return os;
}
template<typename T> std::ostream& operator<<(std::ostream& os, std::deque<T> q) {
    // Print each element in the queue
    os << "{ ";
    while (!q.empty()) {
        os << q.front() << " ";
        q.pop_front();
    }
    os << "}";
    return os;
}
template<typename T> std::ostream& operator<<(std::ostream& os, std::stack<T> q) {
    // Print each element in the queue
    os << "{ ";
    while (!q.empty()) {
        os << q.top() << " ";
        q.pop();
    }
    os << "}";
    return os;
}
template<typename T> std::ostream& operator<<(std::ostream& os, std::priority_queue<T> q) {
    // Print each element in the queue
    os << "{ ";
    while (!q.empty()) {
        os << q.top() << " ";
        q.pop();
    }
    os << "}";
    return os;
}

template<typename T> std::ostream& operator<<(std::ostream& os, const std::vector<T>& vec) {
    os << "[ ";
    for(const auto& elem : vec) {
        os << elem << " ";
    }
    os << "]";
    return os;
}
template<typename K, typename V> std::ostream& operator<<(std::ostream& os, const std::map<K, V>& m) {
    os << "{ ";
    for(const auto& pair : m) {
        os << pair.first << " : " << pair.second << ", ";
    }
    os << "}";
    return os;
}

template<typename T>
using min_pq = std::priority_queue<T, std::vector<T>, std::greater<T>>;
template<typename T> std::ostream& operator<<(std::ostream& os, min_pq<T> q) {
    // Print each element in the queue
    os << "{ ";
    while (!q.empty()) {
        os << q.top() << " ";
        q.pop();
    }
    os << "}";
    return os;
}
using namespace std;
using ll = long long;
#define add push_back 
#define FOR(i,a,b) for (int i = (a); i < (b); ++i)
#define F0R(i,a) FOR(i,0,a)
#define ROF(i,a,b) for (int i = (b)-1; i >= (a); --i)
#define R0F(i,a) ROF(i,0,a)
#define f first
#define s second
#define trav(a,x) for (auto& a: x)
#define int long long
#define vt vector
#define endl "\n"
#define enld "\n"
#define double long double
const ll mod = 998244353;
ll inf = 1e18;
mt19937_64 rnd(chrono::steady_clock::now().time_since_epoch().count());
int n,k;
vt<int> b;
vt<int> solve(int start, int end) {
    vt<int> ans(k+1, 1);
    if(start>end) return ans;
    vt<int> dp(k+1);
    int best = start;
    FOR(i, start, end+1) {
        if(b[i]<=b[best]) best=i;
    }
    auto left = solve(start, best-1), right = solve(best+1, end);
    FOR(i, 1, k+1) {
        dp[i]=left[i-1]*right[i]%mod;
    }
    ans[0]=0;
    FOR(i, 1, k+1) ans[i]=(ans[i-1]+dp[i])%mod;
    return ans;
}
signed main() {
    ios_base::sync_with_stdio(false); 
    cin.tie(0);
    // freopen("input.txt" , "r" , stdin);
    // freopen("output.txt" , "w", stdout);
    int t;cin >> t;
    while(t--) {
        cin >> n >> k;
        b.resize(n);
        F0R(i, n) cin >> b[i];
        bool good = true;
        FOR(i, 1, n) if(b[i]>b[i-1]+1) good=false;
        if(!good) {
            cout << 0 << endl;
            continue;
        }
        cout << solve(0,n-1).back() << endl;
    }
    return 0;
}

Excellent editorial!

Another way to look at it is to build a weighted tree whose nodes are the indices, each holding a value val[i], where an edge (u, v, w) means val[u] >= val[v] if w == 1 and val[u] > val[v] if w == 0. Once the tree is built, the problem reduces to counting the ways to choose val (with val[i] <= k) that satisfy every edge, which can be done easily with a DP over a DFS of the tree.
