SPLITSORT - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: raysh07
Tester: apoorv_me
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Dynamic programming

PROBLEM:

Given N and K, count the permutations P of length N for which the choices below can be made so that f(P) is sorted, modulo 10^9 + 7.
f(P) is defined as follows:

  1. If |P|=1, f(P)=P.
  2. Otherwise, choose an index i such that 1 \leq i \leq \min(K, |P|-1).
  3. Then, let P_1 = f(P[1\ldots i]) and P_2 = f(P[i+1\ldots|P|]).
  4. f(P) is either P_1 + P_2 or P_2 + P_1 (where + denotes concatenation); you may decide which one.
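For checking small cases, the definition can be turned directly into a brute-force count. The sketch below (Python; the function name is our own, not from the problem) enumerates every array that f(P) can produce and checks whether the sorted order is among them. It is exponential and only meant as a reference for tiny N.

```python
from functools import lru_cache
from itertools import permutations


def count_sortable_bruteforce(n: int, k: int) -> int:
    """Count permutations P of 1..n for which some run of the process gives a sorted f(P).

    Follows the definition literally by enumerating every array f(P) can produce,
    so it is exponential and only meant for tiny n (say n <= 6).
    """

    @lru_cache(maxsize=None)
    def outputs(p: tuple) -> frozenset:
        # Rule 1: a single element is returned unchanged.
        if len(p) == 1:
            return frozenset([p])
        results = set()
        # Rule 2: try every allowed cut, 1 <= i <= min(k, |p| - 1).
        for i in range(1, min(k, len(p) - 1) + 1):
            # Rule 3: recurse on both parts.
            for left in outputs(p[:i]):
                for right in outputs(p[i:]):
                    # Rule 4: either concatenation order is allowed.
                    results.add(left + right)
                    results.add(right + left)
        return frozenset(results)

    target = tuple(range(1, n + 1))
    return sum(target in outputs(p) for p in permutations(target))
```

For example, `count_sortable_bruteforce(4, 3)` returns 22: every permutation of length 4 except [2, 4, 1, 3] and [3, 1, 4, 2].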

EXPLANATION:

First, let’s understand when exactly a permutation P can be sorted via the given process.

P is broken up into some prefix A (say of length m \leq K) and some suffix B (of length N-m); the process is applied recursively to A and B, and the results are joined in some order.
In particular, A and B must themselves be sortable via the process; otherwise, the joined array cannot be sorted.

Further, A will form either the first m or the last m elements of the sorted permutation - and so must contain either the elements \{1, 2, \ldots, m\} or \{N, N-1, \ldots, N-m+1\}.
Note that the first case means the prefix is itself a permutation of length m, while the second means the corresponding suffix is a permutation of length N-m.
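This condition on the first cut can be checked directly. A minimal sketch, assuming P is given as a Python list of distinct values (the helper name `valid_cuts` is hypothetical):

```python
def valid_cuts(p, k):
    """Cut lengths m (1 <= m <= min(k, len(p) - 1)) whose prefix could become
    a contiguous block at the start or end of the sorted output.

    Assumes p holds distinct values. A cut is valid when the prefix consists of
    the m smallest values (the prefix is a permutation of 1..m) or of the m
    largest values (the suffix is then a permutation of 1..N-m).
    """
    ordered = sorted(p)
    cuts = []
    for m in range(1, min(k, len(p) - 1) + 1):
        prefix = sorted(p[:m])
        if prefix == ordered[:m] or prefix == ordered[-m:]:
            cuts.append(m)
    return cuts
```

For instance, `valid_cuts([2, 1, 3, 4], 3)` gives `[2, 3]`, while `[2, 4, 1, 3]` has no valid cut at all, so no sequence of moves can sort it.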

In fact, we can prove something even stronger: if it’s possible to sort P, then it’s possible to do so by making the first cut at the smallest valid m, that is, at the smallest prefix that is either itself a permutation or whose corresponding suffix is a permutation.

Proof

Let x denote the position of 1 and y denote the position of N in P.
Without loss of generality, let x \lt y.
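A valid prefix holds either the smallest or the largest values of P, so it cannot contain both 1 and N; since 1 comes first, it must contain 1 but not N.
Hence any valid prefix to cut at satisfies x \leq i \lt y, and it is a permutation of \{1, \ldots, i\}.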

Let L be the index of the smallest valid prefix.
Suppose there exists a way to sort P whose first split is at index i (where i \gt L).

Then, for the permutation [P_1, P_2, \ldots, P_i], L is still the index of the smallest valid prefix.
Further, it can be seen that every element among the first L indices is smaller than every element from indices L+1 to i, which are themselves smaller than every element after index i.

Now, by induction (applying the claim to the shorter permutation [P_1, P_2, \ldots, P_i], which is sortable), there is a way to sort [P_1, P_2, \ldots, P_i] whose first move is to break at index L.
In particular, the part [P_{L+1}, P_{L+2}, \ldots, P_i] is then sorted independently within itself, so it is sortable on its own.

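However, we could have split at index L on the very first move instead, and then split the resulting suffix [P_{L+1}, P_{L+2}, \ldots, P_N] at its index i-L, breaking off [P_{L+1}, P_{L+2}, \ldots, P_i] (which is now a prefix of that suffix). Both splits are allowed, since L \lt i \leq K and i-L \lt i \leq K.
All three parts obtained are sortable (by the argument above, and since breaking at i gave a solution), and they hold the values \{1, \ldots, L\}, \{L+1, \ldots, i\} and \{i+1, \ldots, N\} respectively, so combining them into a sorted array is also possible, and we’re done!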
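The claim can also be sanity-checked empirically. The sketch below (Python; all names are our own) implements a greedy check that only ever cuts at the smallest valid prefix, next to a reference check that tries every valid cut; if the claim holds, the two always agree. It assumes each (sub)array holds a contiguous range of values, which is true for a permutation and is preserved by valid cuts.

```python
from itertools import permutations


def smallest_valid_cut(p, k):
    """Smallest valid cut length m <= min(k, len(p) - 1), or None if there is none.

    Assumes the values of p are distinct integers forming a contiguous range
    [lo, hi]: the first m elements are the m smallest values exactly when their
    maximum is lo + m - 1, and the m largest exactly when their minimum is hi - m + 1.
    """
    lo, hi = min(p), max(p)
    run_min, run_max = hi, lo
    for m in range(1, min(k, len(p) - 1) + 1):
        run_min = min(run_min, p[m - 1])
        run_max = max(run_max, p[m - 1])
        if run_max == lo + m - 1 or run_min == hi - m + 1:
            return m
    return None


def sortable_greedy(p, k):
    """Check sortability using only the smallest valid cut at every level."""
    if len(p) == 1:
        return True
    m = smallest_valid_cut(p, k)
    return m is not None and sortable_greedy(p[:m], k) and sortable_greedy(p[m:], k)


def sortable_any_cut(p, k):
    """Reference check: try every valid cut, recursing into both parts."""
    if len(p) == 1:
        return True
    lo, hi = min(p), max(p)
    for m in range(1, min(k, len(p) - 1) + 1):
        head = p[:m]
        if max(head) == lo + m - 1 or min(head) == hi - m + 1:
            if sortable_any_cut(head, k) and sortable_any_cut(p[m:], k):
                return True
    return False


def count_sortable(n, k):
    """Count sortable permutations of 1..n with the greedy check."""
    return sum(sortable_greedy(p, k) for p in permutations(range(1, n + 1)))
```

Comparing the two checks over all permutations with N \leq 6 and every K \leq N is a quick empirical confirmation of the claim; the greedy version also avoids the branching of the reference check.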


With this in mind, we can formulate a solution using dynamic programming.
Let dp_N denote the number of permutations of length N that are sortable using the given process.

Let’s fix i, the length of the smallest valid prefix of P, i.e. the smallest prefix whose elements are either \{1, \ldots, i\} or \{N-i+1, \ldots, N\}.
Then,

  • We first have two choices: the first i elements of P can either be small
    (i.e. a permutation of 1 to i) or large (a permutation of N to N-i+1).
  • Once this choice is made, we need to fix the order of elements in the prefix and suffix.
  • It’s tempting to say that there are dp_i ways to choose the prefix and dp_{N-i} for the suffix, but this isn’t quite true for the prefix: we also need to ensure that the prefix we choose has no smaller valid prefix, since we’re fixing the smallest valid prefix.

It’s surprisingly easy to resolve this issue, however:

  • If i = 1, the number of valid orderings of the prefix is indeed dp_1 (which is just 1).
  • If i \gt 1, the number of valid orderings of the prefix is exactly \frac{dp_i}{2}.
    • Note that this division by 2 cancels out with the multiplication by 2 from the first step.
Why?

Without loss of generality, suppose the first i elements form a permutation of [1\ldots i].

Let x denote the position of 1, and y denote the position of i.
If x \gt y, no smaller prefix can itself be a permutation, so there’s no issue: any prefix containing 1 also contains i, and hence, to be a permutation, must have length at least i.

On the other hand, if x \lt y, there will exist a smaller prefix that’s a permutation, so such orderings must not be counted.
This is because [P_1, \ldots, P_i] must itself be sortable, and hence has a valid first cut at some j \lt i.
Since 1 appears before i, a valid cut must have 1 on its left and i on its right, which makes [P_1, \ldots, P_j] a permutation of [1\ldots j]: a smaller valid prefix of P, contradicting the minimality of i.

So, what we really want is the number of sortable permutations of [1\ldots i] in which 1 appears after i.
This is exactly half of them: the mapping [P_1, P_2, \ldots, P_i] \to [i+1-P_1, i+1-P_2, \ldots, i+1-P_i] preserves sortability (use the same splits, with every join order swapped) and exchanges the values 1 and i, so it is a bijection between sortable permutations with 1 before i and those with 1 after i.
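The halving argument can likewise be checked by brute force for small i. In this sketch (Python; the names are hypothetical), the sortable permutations of 1..i with no proper prefix equal to \{1, \ldots, j\} are compared with those in which i comes before 1, and their number with half the total, for every K \leq i.

```python
from itertools import permutations


def is_sortable(p, k):
    """True if the process can sort p; tries every valid cut.

    Assumes p holds distinct integers forming a contiguous range of values.
    """
    if len(p) == 1:
        return True
    lo, hi = min(p), max(p)
    for m in range(1, min(k, len(p) - 1) + 1):
        head = p[:m]
        if max(head) == lo + m - 1 or min(head) == hi - m + 1:
            if is_sortable(head, k) and is_sortable(p[m:], k):
                return True
    return False


def halving_check(i, k):
    """Return (#sortable permutations of 1..i, #those with no proper prefix {1..j}).

    Also asserts that the second group is exactly the set of sortable
    permutations in which i appears before 1.
    """
    sortable = [p for p in permutations(range(1, i + 1)) if is_sortable(p, k)]
    # A length-j prefix of a permutation of 1..i equals {1..j} iff its maximum is j.
    minimal = {p for p in sortable if all(max(p[:j]) != j for j in range(1, i))}
    i_before_1 = {p for p in sortable if p.index(i) < p.index(1)}
    assert minimal == i_before_1
    return len(sortable), len(minimal)
```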


Since we’re only allowed to choose i \leq K (and a proper prefix has i \leq N-1), we thus obtain, for N \geq 2 with dp_1 = 1,

dp_N = 2dp_{N-1} + \sum_{i=2}^{\min(K, N-1)} dp_i\cdot dp_{N-i}

This is easily computed in \mathcal{O}(NK) time, which is fast enough here.
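As a small worked instance of the recurrence (values computed by hand), take K = 2: the sum then has only the i = 2 term, and for N = 2 only the first term applies, since i \leq N-1.

```latex
\begin{aligned}
dp_1 &= 1\\
dp_2 &= 2dp_1 = 2\\
dp_3 &= 2dp_2 + dp_2\cdot dp_1 = 4 + 2 = 6\\
dp_4 &= 2dp_3 + dp_2\cdot dp_2 = 12 + 4 = 16\\
dp_5 &= 2dp_4 + dp_2\cdot dp_3 = 32 + 12 = 44
\end{aligned}
```

With K = 1 only the first term survives, giving dp_N = 2^{N-1}.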

TIME COMPLEXITY:

\mathcal{O}(NK) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

void Solve() 
{
    int n, k; cin >> n >> k;

    vector <int> dp(n + 1, 0);
    dp[1] = 1;
    const int mod = 1e9 + 7;

    for (int i = 1; i <= n; i++){
        dp[i] += dp[i - 1] * 2;
        dp[i] %= mod;
        for (int j = 2; j <= min(k, i - 1); j++){ // j <= i - 1 keeps the index i - j in bounds
            dp[i] += dp[j] * dp[i - j];
            dp[i] %= mod;
        }
    }

    cout << dp[n] << "\n";
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}
Tester's code (C++)
#include<bits/stdc++.h>
using namespace std;

#ifdef LOCAL
#include "../debug.h"
#else
#define dbg(...)
#endif

namespace mint_ns {
template<auto P>
struct Modular {
    using value_type = decltype(P);
    value_type value;
 
    Modular(long long k = 0) : value(norm(k)) {}
 
    friend Modular<P>& operator += (      Modular<P>& n, const Modular<P>& m) { n.value += m.value; if (n.value >= P) n.value -= P; return n; }
    friend Modular<P>  operator +  (const Modular<P>& n, const Modular<P>& m) { Modular<P> r = n; return r += m; }
 
    friend Modular<P>& operator -= (      Modular<P>& n, const Modular<P>& m) { n.value -= m.value; if (n.value < 0)  n.value += P; return n; }
    friend Modular<P>  operator -  (const Modular<P>& n, const Modular<P>& m) { Modular<P> r = n; return r -= m; }
    friend Modular<P>  operator -  (const Modular<P>& n)                      { return Modular<P>(-n.value); }
 
    friend Modular<P>& operator *= (      Modular<P>& n, const Modular<P>& m) { n.value = n.value * 1ll * m.value % P; return n; }
    friend Modular<P>  operator *  (const Modular<P>& n, const Modular<P>& m) { Modular<P> r = n; return r *= m; }
 
    friend Modular<P>& operator /= (      Modular<P>& n, const Modular<P>& m) { return n *= m.inv(); }
    friend Modular<P>  operator /  (const Modular<P>& n, const Modular<P>& m) { Modular<P> r = n; return r /= m; }
 
    Modular<P>& operator ++ (   ) { return *this += 1; }
    Modular<P>& operator -- (   ) { return *this -= 1; }
    Modular<P>  operator ++ (int) { Modular<P> r = *this; *this += 1; return r; }
    Modular<P>  operator -- (int) { Modular<P> r = *this; *this -= 1; return r; }
 
    friend bool operator == (const Modular<P>& n, const Modular<P>& m) { return n.value == m.value; }
    friend bool operator != (const Modular<P>& n, const Modular<P>& m) { return n.value != m.value; }
 
    explicit    operator       int() const { return value; }
    explicit    operator      bool() const { return value; }
    explicit    operator long long() const { return value; }
 
    constexpr static value_type mod()      { return     P; }
 
    value_type norm(long long k) {
        if (!(-P <= k && k < P)) k %= P;
        if (k < 0) k += P;
        return k;
    }
 
    Modular<P> inv() const {
        value_type a = value, b = P, x = 0, y = 1;
        while (a != 0) { value_type k = b / a; b -= k * a; x -= k * y; swap(a, b); swap(x, y); }
        return Modular<P>(x);
    }
 
    friend void __print(Modular<P> v) {
        cerr << v.value;
    }
};
template<auto P> Modular<P> pow(Modular<P> m, long long p) {
    Modular<P> r(1);
    while (p) {
        if (p & 1) r *= m;
        m *= m;
        p >>= 1;
    }
    return r;
}
 
template<auto P> ostream& operator << (ostream& o, const Modular<P>& m) { return o << m.value; }
template<auto P> istream& operator >> (istream& i,       Modular<P>& m) { long long k; i >> k; m.value = m.norm(k); return i; }
template<auto P> string   to_string(const Modular<P>& m) { return to_string(m.value); }
 
}
constexpr int mod = (int)1e9 + 7;
using mod_int = mint_ns::Modular<mod>;
 
int32_t main() {
  ios_base::sync_with_stdio(0);
  cin.tie(0);

  int T;     cin >> T;
  while(T--) {
    int N, K;  cin >> N >> K;
    vector<mod_int> dp(N + 1);
    // dph[y] equals dp[y] / 2 for y >= 2 (and 1 for y = 1), so no modular division is needed
    vector<mod_int> dph(N + 1);
    dp[1] = 1, dph[1] = 1;

    for(int x = 2 ; x <= N ; ++x) {
      for(int y = 1 ; y <= min(x - 1, K) ; ++y) {
        dp[x] += dph[y] * dp[x - y] * 2;
        dph[x] += dph[y] * dp[x - y];
      }
    }
    cout << dp[N] << '\n';
  }
  return 0;
}
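The tester’s code keeps a second array dph. Our reading (not stated by the tester): since dp[x] is built as twice dph[x] for x \geq 2, dph[y] plays the role of dp_y / 2 for y \geq 2, with dph[1] = 1, so the factor 2 for the small/large choice is applied explicitly and no modular inverse of 2 is needed. The sketch below (Python, hypothetical function names) re-implements both recurrences with plain integers reduced modulo 10^9 + 7 so they can be compared.

```python
MOD = 10**9 + 7


def count_two_arrays(n, k, mod=MOD):
    """Tester-style recurrence: dph[y] stands in for dp[y] / 2 (y >= 2), dph[1] = 1."""
    dp = [0] * (n + 1)
    dph = [0] * (n + 1)
    dp[1] = dph[1] = 1
    for x in range(2, n + 1):
        for y in range(1, min(x - 1, k) + 1):
            term = dph[y] * dp[x - y] % mod
            dp[x] = (dp[x] + 2 * term) % mod  # small or large prefix: factor 2
            dph[x] = (dph[x] + term) % mod    # only the small-prefix half
    return dp[n]


def count_one_array(n, k, mod=MOD):
    """Editorial recurrence: dp_N = 2 dp_{N-1} + sum_{i=2}^{min(K, N-1)} dp_i dp_{N-i}."""
    dp = [0] * (n + 1)
    dp[1] = 1
    for m in range(2, n + 1):
        dp[m] = 2 * dp[m - 1]
        for i in range(2, min(k, m - 1) + 1):
            dp[m] += dp[i] * dp[m - i]
        dp[m] %= mod
    return dp[n]
```

The y = 1 iteration of the two-array loop contributes 2 \cdot dp_{x-1}, which is the first term of the editorial’s recurrence.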

Editorialist's code (Python)
mod = 10**9 + 7
for _ in range(int(input())):
    n, k = map(int, input().split())
    dp = [0]*(n+1)
    dp[1] = 1
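    for i in range(1, n+1):
        # Sum over the length j of the smallest valid prefix, j <= min(k, i-1).
        # The j = 1 term is dp[1]*dp[i-1] = dp[i-1]; together with the extra
        # dp[i-1] added below, it gives the 2*dp_{i-1} term of the recurrence.
        for j in range(1, min(k+1, i)):
            dp[i] += dp[j] * dp[i-j]
        dp[i] += dp[i-1]
        dp[i] %= mod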
    print(dp[n])