CNTFIN - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: raysh07
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

Medium

PREREQUISITES:

Dynamic programming, combinatorics

PROBLEM:

For an array A of length N, containing integers in [1, K], define f(A) to be the maximum integer L such that there exists an array B of length L satisfying:

  • B_i \in [1, K]
  • \text{LCS}(A, B) \leq 1
    Here, \text{LCS} refers to the length of the longest common subsequence.

If L can be made arbitrarily large, f(A) = 0 instead.

Given N and K, compute the sum of f(A) across all arrays A of length N with integers in [1, K].

EXPLANATION:

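We already know from the infinite version that f(A) \gt 0 requires every element of [1, K] to appear at least twice in A: if some value v appears at most once, then B = [v, v, \ldots, v] satisfies \text{LCS}(A, B) \leq 1 at any length.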

So, let A be such an array. What exactly will the value of f(A) be?
To answer this, let’s look at some array B that satisfies \text{LCS}(A, B) \leq 1.
First, since A contains every element (at least) twice, B cannot contain any repeated elements at all - otherwise the LCS would be of length 2 or more.
This limits the length of B to K.

On the other hand, it’s not true that all K elements can always be chosen: for example, if A = [1, 2, 1, 2], it’s not possible to choose either B = [1, 2] or B = [2, 1], since in either case the LCS would be 2.
This example can be generalized a bit: if there are two elements x and y such that both [x, y] and [y, x] are present as subsequences in A, then any B satisfying \text{LCS}(A, B) \leq 1 can contain at most one of x and y - not both.
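To make the example concrete, here is a minimal sketch of the standard quadratic LCS dynamic program; the function name `lcsLength` is a hypothetical helper, not something from the problem or the author's code. It can be used to confirm that both [1, 2] and [2, 1] have an LCS of length 2 with A = [1, 2, 1, 2].

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Classic O(|a| * |b|) dynamic program for the length of the
// longest common subsequence of two integer arrays.
int lcsLength(const std::vector<int>& a, const std::vector<int>& b) {
    const std::size_t n = a.size(), m = b.size();
    // best[i][j] = LCS length of the first i elements of a and the first j of b.
    std::vector<std::vector<int>> best(n + 1, std::vector<int>(m + 1, 0));
    for (std::size_t i = 1; i <= n; ++i) {
        for (std::size_t j = 1; j <= m; ++j) {
            if (a[i - 1] == b[j - 1]) {
                best[i][j] = best[i - 1][j - 1] + 1;
            } else {
                best[i][j] = std::max(best[i - 1][j], best[i][j - 1]);
            }
        }
    }
    return best[n][m];
}
```

With A = [1, 2, 1, 2], calling `lcsLength(A, {1, 2})` and `lcsLength(A, {2, 1})` both give 2, while a single-element B gives 1.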

Now, to check for subsequences of the form [x, y] and [y, x], only the leftmost/rightmost occurrences of x and y matter (since if we can’t pair the leftmost occurrence of x with the rightmost occurrence of y, [x, y] won’t exist as a subsequence).
With this in mind, let’s define l_x and r_x to be the indices of the leftmost/rightmost occurrences of x in A.

Consider two elements x and y such that l_x \lt l_y. Then,

  • If l_y \gt r_x, so the intervals [l_x, r_x] and [l_y, r_y] are disjoint, it’s possible for B to contain both x and y without any issue.
  • If l_y \lt r_x instead, so the intervals [l_x, r_x] and [l_y, r_y] do intersect, then we run into trouble: l_x \lt r_y so [x, y] exists as a subsequence, and l_y \lt r_x so [y, x] exists as a subsequence.
    So, in this case, at most one of x and y can be chosen.
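The two cases above can be sketched in code as follows. This is only an illustration: `Span`, `occurrenceSpans` and `conflict` are hypothetical names, positions are 0-indexed, and `conflict` assumes x and y are distinct values that both appear in A.

```cpp
#include <vector>

// Leftmost and rightmost positions (0-indexed) of one value in A.
// A value that never appears keeps left = right = -1.
struct Span {
    int left = -1;
    int right = -1;
};

// span[v] holds [l_v, r_v] for every value v in 1..K.
std::vector<Span> occurrenceSpans(const std::vector<int>& A, int K) {
    std::vector<Span> span(K + 1);
    for (int i = 0; i < static_cast<int>(A.size()); ++i) {
        Span& s = span[A[i]];
        if (s.left == -1) s.left = i;  // first time we see this value
        s.right = i;                   // always update the last position
    }
    return span;
}

// True when both [x, y] and [y, x] occur as subsequences of A:
// [x, y] needs l_x < r_y, and [y, x] needs l_y < r_x.
// For distinct values this is exactly "the two spans overlap".
bool conflict(const Span& x, const Span& y) {
    return x.left < y.right && y.left < x.right;
}
```

For A = [1, 2, 1, 2] the spans of 1 and 2 conflict, while for A = [1, 1, 2, 2] they do not.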

More generally, observe that for any array B that satisfies \text{LCS}(A, B) \leq 1, we must have [l_{B_i}, r_{B_i}] and [l_{B_j}, r_{B_j}] be disjoint intervals, for any pair of distinct indices i, j.

So, if we create the intervals [l_x, r_x] from the indices of A, f(A) simply equals the maximum size of a subset of disjoint intervals among them!


Our task now is to sum up this value across all valid arrays A.
Just as in the previous problem, note that if 2K \gt N there are no arrays that contain every element at least twice, so the answer is trivially 0.
We work with 2K \leq N now.

First, let’s recall how to compute the maximum number of mutually disjoint intervals, if we’re given a set of intervals.
This can be solved using a simple greedy algorithm, as follows:

  1. Among all intervals, choose the interval that ends the earliest.
  2. Then, discard all intervals that intersect the chosen interval, and repeat till the set becomes empty.

Optimality of this can be proved via an exchange argument: if the earliest-ending interval is not chosen, either it can be included for free (improving the solution), or it intersects at most one chosen interval (any chosen interval that intersects it must contain its right endpoint, and the chosen intervals are pairwise disjoint), so that interval can be replaced by this one to obtain a not-worse solution.
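As a rough sketch of this greedy, the code below sorts intervals by right endpoint and keeps each interval that starts strictly after the last kept one ends. The helper names `maxDisjoint` and `fOfArray` are assumptions made for illustration; `fOfArray` combines the greedy with the observation that f(A) = 0 whenever some value appears fewer than twice. It needs C++17 for structured bindings.

```cpp
#include <algorithm>
#include <utility>
#include <vector>

// Maximum number of pairwise disjoint closed intervals [first, second],
// using the earliest-ending-first greedy.
int maxDisjoint(std::vector<std::pair<int, int>> intervals) {
    std::sort(intervals.begin(), intervals.end(),
              [](const auto& a, const auto& b) { return a.second < b.second; });
    int chosen = 0;
    bool any = false;  // whether an interval has been chosen yet
    int lastEnd = 0;   // right endpoint of the most recently chosen interval
    for (const auto& [l, r] : intervals) {
        if (!any || l > lastEnd) {  // disjoint from everything chosen so far
            ++chosen;
            lastEnd = r;
            any = true;
        }
    }
    return chosen;
}

// f(A) for an array with values in [1, K]: 0 if some value appears
// fewer than twice, otherwise the maximum number of disjoint [l_x, r_x].
int fOfArray(const std::vector<int>& A, int K) {
    std::vector<int> left(K + 1, -1), right(K + 1, -1), count(K + 1, 0);
    for (int i = 0; i < static_cast<int>(A.size()); ++i) {
        int v = A[i];
        if (left[v] == -1) left[v] = i;
        right[v] = i;
        ++count[v];
    }
    std::vector<std::pair<int, int>> intervals;
    for (int v = 1; v <= K; ++v) {
        if (count[v] < 2) return 0;  // B can be made arbitrarily long
        intervals.emplace_back(left[v], right[v]);
    }
    return maxDisjoint(intervals);
}
```

For instance, `fOfArray({1, 2, 1, 2}, 2)` gives 1 and `fOfArray({1, 1, 2, 2}, 2)` gives 2. Summing `fOfArray` over all K^N arrays is a simple way to cross-check the dynamic programming solution on tiny inputs.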

Keeping this in mind, let’s try to build the array A from left to right.
When placing the next element, there’s some information we need to know: is the element we’re placing extending an existing interval, ending an existing interval, or starting a new one?

This, along with the low constraints on N, lends itself to a solution using dynamic programming.
Define dp(i, x, y, z) to be the sum of answers across all ways of placing the first i elements such that:

  1. There are x “open” intervals (i.e. their last element hasn’t appeared yet).
  2. There are y “closed” intervals, which cannot be extended any further.
  3. There are z (z \leq x) “useless” intervals, i.e. open intervals that intersect an interval already chosen by the greedy algorithm (and hence will not contribute to the answer any more: these are the intervals that get discarded after each choice).

Also let ct(i, x, y, z) denote the number of arrays satisfying the above definitions of (i, x, y, z). We’ll need this information for transitions.

Suppose we’re at state (i, x, y, z). Then,

  1. We can close a useful open interval.
    There are x-z “useful” open intervals so any of them can be closed, and all the rest become useless in the future.
    We thus transition to the state (i+1, x-1, y+1, x-1).
    As for the transition itself, closing an open interval adds 1 to the answer, so for each existing state with answer s, we want to add s+1 to the new state.
    The sum of all s is exactly what’s stored in dp(i, x, y, z), and the sum of all the 1's is just the number of arrays reaching this state, which is stored in ct(i, x, y, z) - so we add
    (x-z)\cdot (dp(i, x, y, z) + ct(i, x, y, z)) to dp(i+1, x-1, y+1, x-1), and also add
    (x-z)\cdot ct(i, x, y, z) to ct(i+1, x-1, y+1, x-1) to keep the counts updated.
  2. We can close a useless open interval.
    There are z choices here, and we transition to (i+1, x-1, y+1, z-1).
    The transition is simple: add z\cdot dp(i, x, y, z) to dp(i+1, x-1, y+1, z-1) and
    z\cdot ct(i, x, y, z) to ct(i+1, x-1, y+1, z-1).
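  3. We can extend an existing open interval, without closing it.
    There are x choices here, and we transition to (i+1, x, y, z).
    The answer does not change, so add x\cdot dp(i, x, y, z) to dp(i+1, x, y, z) and x\cdot ct(i, x, y, z) to ct(i+1, x, y, z).
  4. We can start a new interval.
    There are (K-x-y) choices for which element to use, and we transition to (i+1, x+1, y, z).
    The new interval is useful, so z is unchanged: add (K-x-y)\cdot dp(i, x, y, z) to dp(i+1, x+1, y, z) and (K-x-y)\cdot ct(i, x, y, z) to ct(i+1, x+1, y, z).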

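The base case is ct(0, 0, 0, 0) = 1 (with every dp value equal to 0), and the answer is dp(N, 0, K, 0): all K intervals must be closed by the end, which also forces every element to appear at least twice.
There are \mathcal{O}(NK^3) states here, and \mathcal{O}(1) transitions from each.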
For N \leq 100 and K \leq \frac N 2, this is easily fast enough.

It’s possible to use \mathcal{O}(K^3) memory since dp(i, \ldots) updates only dp(i+1, \ldots), so all previous tables can be forgotten - this doesn’t change the time complexity but will help improve runtime a bit nonetheless.
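The transitions and the rolling-table idea can be put together as in the sketch below. It follows the explanation's convention (z counts the useless open intervals), so it is not a copy of the author's code. The function name `sumOfF` is hypothetical, and the modulus 998244353 is taken from the author's code rather than from the problem statement reproduced here.

```cpp
#include <vector>

// Sum of f(A) over all arrays A of length N with values in [1, K].
// State (x, y, z): x open intervals, y closed intervals, z useless open ones.
long long sumOfF(int N, int K) {
    const long long MOD = 998244353;  // assumption: modulus from the author's code
    if (2 * K > N) return 0;          // no array has every value at least twice
    using Table = std::vector<std::vector<std::vector<long long>>>;
    auto makeTable = [&]() {
        return Table(K + 1, std::vector<std::vector<long long>>(
                                K + 1, std::vector<long long>(K + 1, 0)));
    };
    Table dp = makeTable(), ct = makeTable();
    ct[0][0][0] = 1;  // the empty prefix
    for (int i = 0; i < N; ++i) {
        Table ndp = makeTable(), nct = makeTable();
        for (int x = 0; x <= K; ++x)
            for (int y = 0; x + y <= K; ++y)
                for (int z = 0; z <= x; ++z) {
                    const long long c = ct[x][y][z], d = dp[x][y][z];
                    if (c == 0 && d == 0) continue;
                    // Move to (nx, ny, nz) in `ways` ways; `gain` adds 1 per array.
                    auto push = [&](int nx, int ny, int nz, long long ways, bool gain) {
                        if (ways == 0) return;
                        nct[nx][ny][nz] = (nct[nx][ny][nz] + ways * c) % MOD;
                        const long long add = gain ? (d + c) % MOD : d;
                        ndp[nx][ny][nz] = (ndp[nx][ny][nz] + ways * add) % MOD;
                    };
                    if (x > z) push(x - 1, y + 1, x - 1, x - z, true);  // close a useful interval
                    if (z > 0) push(x - 1, y + 1, z - 1, z, false);     // close a useless interval
                    push(x, y, z, x, false);                            // extend an open interval
                    if (x + y < K) push(x + 1, y, z, K - x - y, false); // start a new interval
                }
        dp.swap(ndp);
        ct.swap(nct);
    }
    return dp[0][K][0];  // everything closed, nothing open
}
```

As a small sanity check, N = 4 and K = 2 has six valid arrays: [1, 1, 2, 2] and [2, 2, 1, 1] contribute 2 each and the other four contribute 1 each, so `sumOfF(4, 2)` is 8.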

TIME COMPLEXITY:

\mathcal{O}(NK^3) per testcase when K \leq \frac N 2, and \mathcal{O}(1) otherwise.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

const int mod = 998244353;

struct mint{
    int x;

    mint (){ x = 0;}
    mint (int32_t xx){ x = xx % mod; if (x < 0) x += mod;}
    mint (long long xx){ x = xx % mod; if (x < 0) x += mod;}

    int val(){
        return x;
    }
    mint &operator++(){
        x++;
        if (x == mod) x = 0;
        return *this;
    }
    mint &operator--(){
        if (x == 0) x = mod;
        x--;
        return *this;
    }
    mint operator++(int32_t){
        mint result = *this;
        ++*this;
        return result;
    }
    
    mint operator--(int32_t){
        mint result = *this;
        --*this;
        return result;
    }
    mint& operator+=(const mint &b){
        x += b.x;
        if (x >= mod) x -= mod;
        return *this;
    }
    mint& operator-=(const mint &b){
        x -= b.x;
        if (x < 0) x += mod;
        return *this;
    }
    mint& operator*=(const mint &b){
        long long z = x;
        z *= b.x;
        z %= mod;
        x = (int)z;
        return *this;
    }
    mint operator+() const {
        return *this;
    }
    mint operator-() const {
        return mint() - *this;
    }
    mint operator/=(const mint &b){
        return *this = *this * b.inv();
    }
    mint power(long long n) const {
        mint ok = *this, r = 1;
        while (n){
            if (n & 1){
                r *= ok;
            }
            ok *= ok;
            n >>= 1;
        }
        return r;
    }
    mint inv() const {
        return power(mod - 2);
    }
    friend mint operator+(const mint& a, const mint& b){ return mint(a) += b;}
    friend mint operator-(const mint& a, const mint& b){ return mint(a) -= b;}
    friend mint operator*(const mint& a, const mint& b){ return mint(a) *= b;}
    friend mint operator/(const mint& a, const mint& b){ return mint(a) /= b;}
    friend bool operator==(const mint& a, const mint& b){ return a.x == b.x;}
    friend bool operator!=(const mint& a, const mint& b){ return a.x != b.x;}
    mint power(mint a, long long n){
        return a.power(n);
    }
    friend ostream &operator<<(ostream &os, const mint &m) {
        os << m.x;
        return os;
    }
    explicit operator bool() const {
        return x != 0;
    }
};

// Remember to check MOD

struct factorials{
    int n;
    vector <mint> ff, iff;
    
    factorials(int nn){
        n = nn;
        ff.resize(n + 1);
        iff.resize(n + 1);
        
        ff[0] = 1;
        for (int i = 1; i <= n; i++){
            ff[i] = ff[i - 1] * i;
        }
        
        iff[n] = ff[n].inv();
        for (int i = n - 1; i >= 0; i--){
            iff[i] = iff[i + 1] * (i + 1);
        }
    }
    
    mint C(int n, int r){
        if (n == r) return mint(1);
        if (n < 0 || r < 0 || r > n) return mint(0);
        return ff[n] * iff[r] * iff[n - r];
    }
    
    mint P(int n, int r){
        if (n < 0 || r < 0 || r > n) return mint(0);
        return ff[n] * iff[n - r];
    }
    
    mint solutions(int n, int r){
        // Solutions to x1 + x2 + ... + xn = r, xi >= 0 
        return C(n + r - 1, n - 1);
    }
    
    mint catalan(int n){
        return ff[2 * n] * iff[n] * iff[n + 1];
    }
};

const int PRECOMP = 3e6 + 69;
factorials F(PRECOMP);

// REMEMBER To check MOD and PRECOMP

void Solve() 
{
    int n, k; cin >> n >> k;
    
    if (n < 2 * k){
        cout << 0 << "\n";
        return;
    }
    
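    // dp[number of currently open][number of already closed][number of useful open]
    // (x, y, z) satisfying x + y <= k, z <= x; the index i is rolled away.
    // Note: z here counts the useful open intervals (those the greedy can still pick),
    // i.e. x - z in the explanation's notation, where z counts the useless ones.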
    
    // dp = sum of answers over all ways 
    // f = count of ways 
    
    vector<vector<vector<mint>>> dp(k + 1, vector<vector<mint>>(k + 1, vector<mint>(k + 1)));
    auto f = dp;
    f[0][0][0] = 1;
    
    for (int i = 1; i <= n; i++){
        vector<vector<vector<mint>>> ndp(k + 1, vector<vector<mint>>(k + 1, vector<mint>(k + 1)));
        auto nf = ndp;
        
        for (int x = 0; x <= k; x++){
            for (int y = 0; y <= k; y++){
                for (int z = 0; z <= k; z++){
                    if (f[x][y][z]){
                        mint ways;
                        
                        // start a new thing 
                        if (x + 1 <= k && z + 1 <= k){
                            ways = (k - x - y);
                            nf[x + 1][y][z + 1] += f[x][y][z] * ways;
                            ndp[x + 1][y][z + 1] += dp[x][y][z] * ways;
                        }
                        
                        // continue an ongoing thing
                        {
                            ways = x;
                            nf[x][y][z] += f[x][y][z] * ways;
                            ndp[x][y][z] += dp[x][y][z] * ways;
                        }
                        
                        // end an ongoing thing, useful 
                        if (x > 0){
                            ways = z;
                            nf[x - 1][y + 1][0] += f[x][y][z] * ways;
                            ndp[x - 1][y + 1][0] += dp[x][y][z] * ways + f[x][y][z] * ways;
                        }
                        
                        // end an ongoing thing, not useful 
                        if (x > 0){
                            ways = x - z;
                            nf[x - 1][y + 1][z] += f[x][y][z] * ways;
                            ndp[x - 1][y + 1][z] += dp[x][y][z] * ways;
                        }
                    }
                }
            }
        }
        
        f = nf;
        dp = ndp;
    }
    
    cout << dp[0][k][0] << "\n";
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}