COLOSSEUM - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: iceknight1093
Tester: raysh07
Editorialist: iceknight1093

DIFFICULTY:

Medium

PREREQUISITES:

Dynamic programming, elementary combinatorics

PROBLEM:

For a permutation P, define \text{win}(P, K) as follows:

  • Initialize x = 1.
  • For i from 2 to N in order, if P_i + K \geq P_x, do x \gets i.

\text{win}(P, K) is the final value of x.

Given N and K, for each i count the number of permutations of length N for which \text{win}(P, K) = i.
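To make the process concrete, here is a minimal C++ sketch that simulates \text{win}(P, K) literally and brute-forces the counts over all N! permutations. It is only practical for tiny N, but it is handy for checking a faster solution. The names winOf and bruteCounts are illustrative and do not come from the original code.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

// Literal simulation of win(P, K). P is stored 0-indexed; the returned index is 1-indexed.
int winOf(const std::vector<int>& p, int k) {
    assert(!p.empty());
    int x = 0;                                  // current winner (0-indexed)
    for (int i = 1; i < (int)p.size(); ++i)
        if (p[i] + k >= p[x]) x = i;            // P_i + K >= P_x  =>  x <- i
    return x + 1;
}

// Brute force over all N! permutations: counts[i] = #{P : win(P, K) = i}, index 0 unused.
std::vector<long long> bruteCounts(int n, int k) {
    std::vector<int> p(n);
    std::iota(p.begin(), p.end(), 1);           // start from the identity permutation 1..n
    std::vector<long long> counts(n + 1, 0);
    do {
        ++counts[winOf(p, k)];
    } while (std::next_permutation(p.begin(), p.end()));
    return counts;
}
```

For example, bruteCounts(3, 1) gives 0, 1, 5 for i = 1, 2, 3: only (2, 3, 1) is won by index 2, and no permutation is won by index 1.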

EXPLANATION:

Let’s fix an index i and count the permutations in which it wins.

Observe that if \text{win}(P, K) = i, then i has to be a suffix maximum of P, i.e. there can’t be any index \gt i that contains a value \gt P_i.
In fact, the condition is stricter: the values at indices \gt i must be \lt P_i - K.
This is because, if there’s a value that’s \geq P_i - K, then the leftmost such value will replace i as x; since x only ever moves to the right, i cannot win.

Let’s now fix the value at index i, say x, and count the permutations in which i wins with P_i = x.
As noted earlier, everything after i must be \lt x - K. Exactly \max(0, x-K-1) values satisfy this, so there are \binom{\max(0,\, x-K-1)}{N-i} ways to choose these elements, and (N-i)! ways to arrange them once chosen. That leaves only the elements before index i.
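A quick sketch of the suffix count, assuming (as argued above) that the values after index i are drawn from 1, \ldots, x-K-1. Choosing N-i of these a = \max(0, x-K-1) values and ordering them gives \binom{a}{N-i} \cdot (N-i)!, which is the falling factorial a (a-1) \cdots (a-(N-i)+1). The helper name suffixWays is hypothetical, and it uses exact 64-bit arithmetic rather than a modulus.

```cpp
#include <algorithm>
#include <cassert>

// Ways to fill the N - i positions after index i when P_i = x:
// choose N - i of the a = max(0, x - K - 1) values below x - K and arrange them,
// i.e. C(a, N - i) * (N - i)!, computed as the falling factorial a (a-1) ... (a - (N-i) + 1).
long long suffixWays(int n, int k, int i, int x) {
    assert(1 <= i && i <= n);
    long long a = std::max(0, x - k - 1);   // number of values v with v < x - K
    long long ways = 1;
    for (int t = 0; t < n - i; ++t)
        ways *= a - t;                      // hits a factor 0 when N - i > a, as it should
    return ways;
}
```

For N = 3, K = 1, i = 2, x = 3, only the value 1 may follow index i, so suffixWays(3, 1, 2, 3) is 1.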

Let’s look at the nearest suffix maximum to the left of index i. Suppose this is value m, at index j.
Then,

  • Elements at indices j+1, j+2, \ldots, i-1 must all be \lt x; otherwise one of them would be the next suffix maximum before x instead.
  • m \leq x+K must hold, because x must be able to “beat” m.
    Note that it doesn’t really matter if something between m and x beats m: because all these values are \lt x, x will be able to beat it anyway.

Now, observe that the same conditions hold for m: it should be able to beat the next suffix maximum, and elements between them should be \lt m. This then repeats for the next suffix maximum, and so on.
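The argument so far amounts to a characterization that can be checked mechanically: i wins if and only if every value after i is \lt P_i - K, and each suffix maximum to the left of i is at most K larger than the suffix maximum immediately to its right. Below is a sketch (helper names are hypothetical) that compares this test against a direct simulation on every permutation of small length; the simulation is repeated so the block compiles on its own.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

// Literal simulation of win(P, K) (0-indexed P, 1-indexed result).
int winOf(const std::vector<int>& p, int k) {
    int x = 0;
    for (int i = 1; i < (int)p.size(); ++i)
        if (p[i] + k >= p[x]) x = i;
    return x + 1;
}

// Characterization: index i (1-indexed) wins iff
//  (a) every value after index i is < P_i - K, and
//  (b) every suffix maximum left of i is at most K larger than the suffix maximum
//      immediately to its right (so that the right one can beat it).
bool winsByChain(const std::vector<int>& p, int k, int i) {
    assert(1 <= i && i <= (int)p.size());
    int n = (int)p.size();
    int x = p[i - 1];
    for (int t = i; t < n; ++t)
        if (p[t] >= x - k) return false;        // (a) fails: this value would replace i
    int cur = x;                                // max of p[t + 1 .. n - 1], the latest suffix maximum
    for (int t = i - 2; t >= 0; --t)
        if (p[t] > cur) {                       // p[t] is a new suffix maximum
            if (p[t] > cur + k) return false;   // (b) fails: cur cannot beat p[t]
            cur = p[t];
        }
    return true;
}

// True iff the characterization agrees with the simulation on all permutations of 1..n.
bool characterizationHolds(int n, int k) {
    std::vector<int> p(n);
    std::iota(p.begin(), p.end(), 1);
    do {
        int w = winOf(p, k);
        for (int i = 1; i <= n; ++i)
            if (winsByChain(p, k, i) != (w == i)) return false;
    } while (std::next_permutation(p.begin(), p.end()));
    return true;
}
```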

So, let’s try to build the permutation from right to left, starting at i.
When we’re at index j \lt i, we have two options: we can either place a new suffix maximum here, or we can place a “small” element and not change the suffix maximum.
Observe that in the second case, it doesn’t really matter which small element is placed. Further, once an element becomes small, it continues to be small in the future, i.e. even if the suffix maximum is increased.

This tells us what information must be maintained as we go.
Let f(j, m, y) denote the number of ways to fill in indices [j\ldots N] (with P_i = x), such that the current suffix maximum is m and we have y small elements (unused values \lt m) available.
Then,

  1. Suppose we set P_j = m, so this is the new suffix maximum.
    The previous suffix maximum (the one to the right of j) is some m' that must be able to beat m, so m - K \leq m' \lt m.
    Once m becomes the new maximum, everything \lt m counts as “small” for the future.
    However, everything \lt m' already counted as small, so we really just get (m - m' - 1) new small elements.
    So, we add f(j+1, m', y - (m - m' - 1)) for all valid m'.
  2. Suppose we don’t set P_j = m. Then we need to place a small element here.
    After this step y small elements remain, so there were y+1 available at index j+1 and we chose one of them.
    For this case, we add f(j+1, m, y+1) \cdot (y+1) to the value.

The base case is f(i, x, x - 1 - (N-i)) = \binom{\max(0,\, x-K-1)}{N-i} \cdot (N-i)!, accounting for the elements placed to the right of i (of the x - 1 values below x, exactly N - i are used there).
For a fixed pair (i, x), the count is f(1, N, 0), since N will always be the largest suffix maximum, and there should be no small elements left to place once we reach index 1. Summing this over all x gives the answer for i.
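For reference, here is an unoptimized C++ sketch of this DP for a single pair (i, x), summed over x. It follows the two transitions literally and uses plain 64-bit counts without any modulus, so it is only meant for small N; the function names are illustrative assumptions. The base case uses \binom{\max(0,\, x-K-1)}{N-i} \cdot (N-i)!, the number of ways to place values strictly below x - K after index i.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// f(j, m, y) for one fixed pair (i, P_i = x), following the two transitions literally.
// Exact 64-bit counts with no modulus: intended only for small N.
long long countWinsWithValue(int n, int k, int i, int x) {
    assert(1 <= i && i <= n && 1 <= x && x <= n);
    using Table = std::vector<std::vector<long long>>;
    // cur[m][y] holds f(j + 1, m, y); the extra column lets us read y + 1 = n + 1 safely.
    Table cur(n + 1, std::vector<long long>(n + 2, 0));
    int a = std::max(0, x - k - 1);             // number of values below x - K
    long long base = 1;                         // C(a, N - i) * (N - i)! as a falling factorial
    for (int t = 0; t < n - i; ++t) base *= a - t;
    int y0 = x - 1 - (n - i);                   // values below x still unplaced
    if (y0 < 0 || base == 0) return 0;
    cur[x][y0] = base;                          // base case f(i, x, y0)
    for (int j = i - 1; j >= 1; --j) {
        Table nxt(n + 1, std::vector<long long>(n + 2, 0));
        for (int m = 1; m <= n; ++m)
            for (int y = 0; y <= n; ++y) {
                // Case 2: P_j is a small element; y + 1 were available before.
                long long v = cur[m][y + 1] * (y + 1);
                // Case 1: P_j = m is a new suffix maximum that m' in [m - K, m - 1] can beat.
                for (int mp = std::max(1, m - k); mp < m; ++mp) {
                    int yp = y - (m - mp - 1);
                    if (yp >= 0) v += cur[mp][yp];
                }
                nxt[m][y] = v;
            }
        cur = nxt;
    }
    return cur[n][0];                           // f(1, N, 0)
}

// counts[i] = sum over x of f(1, N, 0); index 0 unused.
std::vector<long long> dpCounts(int n, int k) {
    std::vector<long long> counts(n + 1, 0);
    for (int i = 1; i <= n; ++i)
        for (int x = 1; x <= n; ++x) counts[i] += countWinsWithValue(n, k, i, x);
    return counts;
}
```

Since every permutation has exactly one winner, the counts over all i should add up to N!, which makes a convenient sanity check.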


This DP has \mathcal{O}(N^3) states, with \mathcal{O}(K) transitions from each one.
It’s run once for each pair of (i, x) (recall that we set P_i = x), so the overall complexity is currently \mathcal{O}(N^6) which is too slow.

To optimize it, we can do two things:

  1. Note that the DP doesn’t need to be run separately for each (i, x) pair.
    Instead, fix i, and then initialize the DP with the values for all x simultaneously before running it once.
    This saves a factor of N.
  2. For the DP transitions, note that the slow part of computing f(j, m, y) is when we add f(j+1, m', y - (m - m' - 1)) for some range of m'.
    Here, the third argument minus the second is the constant y - m + 1, so we’re essentially summing the values along a diagonal. This can thus be optimized to constant time by building prefix sums along each such diagonal.
    This saves another factor of N.

Applying both optimizations leads to a complexity of \mathcal{O}(N^4), which is fast enough because we have N \leq 100.
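Putting both optimizations together, one possible \mathcal{O}(N^4) implementation is sketched below. It runs the DP once per i with every x seeded at the start, and replaces the inner loop over m' by prefix sums S[d][t] taken along the diagonals y' - m' = d. Working modulo 998244353 is an assumption borrowed from the tester's code, and solveFast is a hypothetical name; unlike the tester's code, it follows the f(j, m, y) formulation of this editorial.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

const int64_t MOD = 998244353;   // assumed modulus, as in the tester's code

// ans[i] = #permutations with win(P, K) = i (mod MOD), i = 1..n; index 0 unused.
std::vector<int64_t> solveFast(int n, int k) {
    assert(n >= 1 && k >= 0);
    using Table = std::vector<std::vector<int64_t>>;
    Table C(n + 1, std::vector<int64_t>(n + 1, 0));     // Pascal's triangle, C[a][b] = 0 for b > a
    for (int a = 0; a <= n; ++a) {
        C[a][0] = 1;
        for (int b = 1; b <= a; ++b) C[a][b] = (C[a - 1][b - 1] + C[a - 1][b]) % MOD;
    }
    std::vector<int64_t> fact(n + 1, 1);
    for (int t = 1; t <= n; ++t) fact[t] = fact[t - 1] * t % MOD;

    std::vector<int64_t> ans(n + 1, 0);
    for (int i = 1; i <= n; ++i) {
        Table cur(n + 1, std::vector<int64_t>(n + 2, 0)); // cur[m][y] = f(j + 1, m, y)
        for (int x = 1; x <= n; ++x) {                     // optimization 1: seed every x at once
            int a = std::max(0, x - k - 1), r = n - i, y0 = x - 1 - r;
            if (y0 >= 0) cur[x][y0] = C[a][r] * fact[r] % MOD;
        }
        for (int j = i - 1; j >= 1; --j) {
            // Optimization 2: S[d + n][t] = sum over m' in [1, t] of f(j + 1, m', m' + d).
            Table S(2 * n + 1, std::vector<int64_t>(n + 1, 0));
            for (int d = -n; d <= n; ++d)
                for (int t = 1; t <= n; ++t) {
                    int64_t term = (t + d >= 0 && t + d <= n) ? cur[t][t + d] : 0;
                    S[d + n][t] = (S[d + n][t - 1] + term) % MOD;
                }
            Table nxt(n + 1, std::vector<int64_t>(n + 2, 0));
            for (int m = 1; m <= n; ++m)
                for (int y = 0; y <= n; ++y) {
                    int64_t v = cur[m][y + 1] * (y + 1) % MOD;  // case 2: small element
                    int lo = std::max(1, m - k), d = y - m + 1; // case 1: m' in [lo, m - 1]
                    if (lo <= m - 1) v += S[d + n][m - 1] - S[d + n][lo - 1] + MOD;
                    nxt[m][y] = v % MOD;
                }
            cur = std::move(nxt);
        }
        ans[i] = cur[n][0];                                // f(1, N, 0)
    }
    return ans;
}
```

Each layer j costs \mathcal{O}(N^2) for the prefix sums plus \mathcal{O}(N^2) for the states, and there are \mathcal{O}(N) layers for each of the N choices of i, giving \mathcal{O}(N^4) overall.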

TIME COMPLEXITY:

\mathcal{O}(N^4) per testcase.

CODE:

Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

const int mod = 998244353;

struct mint{
    int x;

    mint (){ x = 0;}
    mint (int32_t xx){ x = xx % mod; if (x < 0) x += mod;}
    mint (long long xx){ x = xx % mod; if (x < 0) x += mod;}

    int val(){
        return x;
    }
    mint &operator++(){
        x++;
        if (x == mod) x = 0;
        return *this;
    }
    mint &operator--(){
        if (x == 0) x = mod;
        x--;
        return *this;
    }
    mint operator++(int32_t){
        mint result = *this;
        ++*this;
        return result;
    }
    
    mint operator--(int32_t){
        mint result = *this;
        --*this;
        return result;
    }
    mint& operator+=(const mint &b){
        x += b.x;
        if (x >= mod) x -= mod;
        return *this;
    }
    mint& operator-=(const mint &b){
        x -= b.x;
        if (x < 0) x += mod;
        return *this;
    }
    mint& operator*=(const mint &b){
        long long z = x;
        z *= b.x;
        z %= mod;
        x = (int)z;
        return *this;
    }
    mint operator+() const {
        return *this;
    }
    mint operator-() const {
        return mint() - *this;
    }
    mint operator/=(const mint &b){
        return *this = *this * b.inv();
    }
    mint power(long long n) const {
        mint ok = *this, r = 1;
        while (n){
            if (n & 1){
                r *= ok;
            }
            ok *= ok;
            n >>= 1;
        }
        return r;
    }
    mint inv() const {
        return power(mod - 2);
    }
    friend mint operator+(const mint& a, const mint& b){ return mint(a) += b;}
    friend mint operator-(const mint& a, const mint& b){ return mint(a) -= b;}
    friend mint operator*(const mint& a, const mint& b){ return mint(a) *= b;}
    friend mint operator/(const mint& a, const mint& b){ return mint(a) /= b;}
    friend bool operator==(const mint& a, const mint& b){ return a.x == b.x;}
    friend bool operator!=(const mint& a, const mint& b){ return a.x != b.x;}
    mint power(mint a, long long n){
        return a.power(n);
    }
    friend ostream &operator<<(ostream &os, const mint &m) {
        os << m.x;
        return os;
    }
    explicit operator bool() const {
        return x != 0;
    }
};

// Remember to check MOD

struct factorials{
    int n;
    vector <mint> ff, iff;
    
    factorials(int nn){
        n = nn;
        ff.resize(n + 1);
        iff.resize(n + 1);
        
        ff[0] = 1;
        for (int i = 1; i <= n; i++){
            ff[i] = ff[i - 1] * i;
        }
        
        iff[n] = ff[n].inv();
        for (int i = n - 1; i >= 0; i--){
            iff[i] = iff[i + 1] * (i + 1);
        }
    }
    
    mint C(int n, int r){
        if (n == r) return mint(1);
        if (n < 0 || r < 0 || r > n) return mint(0);
        return ff[n] * iff[r] * iff[n - r];
    }
    
    mint P(int n, int r){
        if (n < 0 || r < 0 || r > n) return mint(0);
        return ff[n] * iff[n - r];
    }
    
    mint solutions(int n, int r){
        // Solutions to x1 + x2 + ... + xn = r, xi >= 0 
        return C(n + r - 1, n - 1);
    }
    
    mint catalan(int n){
        return ff[2 * n] * iff[n] * iff[n + 1];
    }
};

const int PRECOMP = 3e6 + 69;
factorials F(PRECOMP);

// REMEMBER To check MOD and PRECOMP

void Solve() 
{
    int n, k; cin >> n >> k;
    
    // 3 state mint dp 
    // fix value that wins 
    // then we can do a dp 
    // compute for all lengths 
    // O(n^3) per winning value x, so O(n^4) overall
    // dp[length][current maximum][number of < x - k used] 
    // options at each step are :
    // i) improve current maxima, should be <= mx + k
    // ii) use element in [x - k, mx)
    // iii) use element < x - k 
    // improve needs prefix sums 
    
    vector dp(n + 1, vector(n + 1, vector<mint>(n + 1, 0)));
    vector <mint> ps(n + 1, 0);
    
    vector <mint> ans(n + 1, 0);
    
    for (int x = 1; x <= n; x++){
        // x wins 
        
        // reset the DP table
        for (int a = 0; a <= n; a++){
            for (int b = 0; b <= n; b++){
                for (int c = 0; c <= n; c++){
                    dp[a][b][c] = 0;
                }
            }
        }
        
        dp[1][x][0] = 1;
        int have = max(0LL, x - k - 1);
        
        for (int len = 1; len < n; len++){
            for (int mx = 1; mx <= n; mx++){
                for (int used = 0; used <= n; used++) if (dp[len][mx][used]){
                    // place < x - k 
                    mint ways = have - used;
                    dp[len + 1][mx][used + 1] += ways * dp[len][mx][used];
                    
                    // place an unused value in [x - k, mx); mx itself is already placed
                    int bound = max(1LL, x - k);
                    mint region = mx - bound + 1;
                    region -= (len - used);
                    
                    dp[len + 1][mx][used] += region * dp[len][mx][used];
                    
                    // for (int place = mx + 1; place <= min(n, mx + k); place++){
                    //     dp[len + 1][place][used] += dp[len][mx][used];
                    // }
                }
            }
            
            // prefix sums for new maximum 
            
            for (int used = 0; used <= n; used++){
                for (int i = 0; i <= n; i++){
                    ps[i] = 0;
                }
                
                for (int mx = 1; mx <= n; mx++){
                    ps[mx] += dp[len][mx][used];
                    ps[mx] += ps[mx - 1];
                }
                
                for (int mx = 1; mx <= n; mx++){
                    int bound = max(1LL, mx - k);
                    dp[len + 1][mx][used] += ps[mx - 1] - ps[bound - 1];
                }
            }
        }
        
        for (int len = 1; len <= n; len++){
            // maximum must be n 
            int left = n - len;
            int had = max(0LL, x - k - 1);
            int used = had - left;
            
            if (used < 0) continue;
            
            mint ways = dp[len][n][used];
            // the `left` positions after the winner take the remaining values (all < x - k) in any order
            ways *= F.ff[left];
            
            ans[len] += ways;
        }
    }
    
    for (int i = 1; i <= n; i++){
        cout << ans[i] << " \n"[i == n];
    }
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}