CNTINF - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: raysh07
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

Medium

PREREQUISITES:

The inclusion-exclusion principle, combinatorics

PROBLEM:

For an array A of length N, containing integers in [1, K], define f(A) to be the maximum integer L such that there exists an array B of length L satisfying:

  • B_i \in [1, K]
  • \text{LCS}(A, B) \leq 1
    Here, \text{LCS} refers to the length of the longest common subsequence.

If L can be made arbitrarily large, f(A) = 0 instead.

Given N and K, count the number of arrays A of length N with integers in [1, K] for which f(A) = 0.

EXPLANATION:

Clearly, the first thing to do is understand which types of arrays can have f(A) = 0.

One obvious case is if some integer x \in [1, K] doesn’t appear in A at all.
If this happens, we can always take B = [x, x, x, \ldots, x] to form an array of any length that satisfies \text{LCS}(A, B) = 0.
So, every array that doesn’t contain some element of [1, K] trivially has a value of 0.

Since the condition we care about is \text{LCS}(A, B) \leq 1, the above observation extends a bit further: if some integer x \in [1, K] appears exactly once in A, we still have f(A) = 0.
The exact same construction B = [x, x, x, \ldots, x] works here; the only difference is that \text{LCS}(A, B) = 1 this time, which is still allowed.


We’ve now identified a couple of classes of arrays that satisfy f(A) = 0, namely all those arrays that either have some element of [1, K] missing, or that have some element of [1, K] appearing only once.

That leaves arrays where every element of [1, K] appears at least two times.
However, all such arrays will definitely have f(A) \gt 0, i.e. a finite maximum length.
It’s easy to see why: B can only contain elements in [1, K] itself, and if it contains any element twice, its \text{LCS} with A will surely be at least 2 (since A is known to contain this element twice).
So, any valid B can contain each element at most once, which limits its length to be no more than K.

We now know exactly what needs to be counted: arrays of length N that contain some element of [1, K] at most once.
Alternately, we can count the number of arrays that contain every element at least twice, and subtract this from the total count of arrays which is K^N.
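
As a quick sanity check of this characterization (not part of the intended solution), here’s a small brute force that enumerates all K^N arrays for tiny N and K and counts those in which some element of [1, K] appears at most once; the function name brute and the values used in main are purely illustrative.

#include <bits/stdc++.h>
using namespace std;

// Brute force: count arrays of length n over [1, k] in which some value
// appears at most once (equivalently, those with f(A) = 0).
// Only feasible for tiny n and k, since it enumerates all k^n arrays.
long long brute(int n, int k) {
    long long total = 1;
    for (int i = 0; i < n; i++) total *= k;

    long long count = 0;
    for (long long mask = 0; mask < total; mask++) {
        vector<int> cnt(k + 1, 0);
        long long cur = mask;
        for (int i = 0; i < n; i++) {      // decode mask as an array in base k
            cnt[cur % k + 1]++;
            cur /= k;
        }
        bool rare = false;
        for (int x = 1; x <= k; x++)
            if (cnt[x] <= 1) rare = true;  // some value appears 0 or 1 times
        if (rare) count++;
    }
    return count;
}

int main() {
    // For example, N = 4, K = 2 should print 10 = 2^4 - 4!/(2!2!).
    cout << brute(4, 2) << "\n";
}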


First off, observe that if 2K \gt N then trivially any array will have f(A) = 0, because it’s impossible for an array of length N to contain each element of [1, K] at least twice.
So, if 2K \gt N, the answer is just K^N, the total number of arrays.

This leaves us with K \leq \frac N 2, and since N \leq 10^4, complexities like \mathcal{O}(NK) will now work.

There are now a few different ways of doing the requisite counting.
The intended solution, running in \mathcal{O}(K^2), is to use the inclusion-exclusion principle.

Inclusion-Exclusion

Let’s try to directly count the number of arrays that have some element appearing \leq 1 time.

For a subset S of \{1, 2, \ldots, K\}, let f(S) denote the number of arrays in which every element of S appears at most once, while every element not in S appears at least twice.
For example, if K = 4, f(\{1, 3\}) will denote the number of arrays that contain 1 and 3 at most once each, while 2 and 4 appear at least twice each.

Note that with this definition, the value we’re interested in is f(\emptyset) (where \emptyset denotes the empty set).
Specifically, f(\emptyset) will denote the number of arrays in which every element appears at least twice, so our answer is K^N - f(\emptyset).

It’s not immediately obvious how to compute the function f for a given set, so we use a common trick: we relax the conditions a bit.
Define g(S) to be the number of arrays in which every element of S appears at most once, but we don’t care what happens to elements outside of S (so some of them may also occur \leq 1 time).

Note that with this definition, g(S) equals the sum of f(T) across all subsets such that S\subseteq T.

Computing g(S) for a given subset S is not too hard.
Let M = |S| denote the size of the subset.
Each of the M elements in S has two choices: it either does not appear in A, or it appears exactly once in A.
Suppose we fix i (0 \leq i \leq M) to be the number of elements that don’t appear in A at all.
Then,

  1. There are \binom{M}{i} ways to choose which i elements among these M don’t appear.
  2. The remaining M-i elements appear once each in A.
    There are \binom{N}{M-i} ways of choosing their positions, and then (M-i)! ways of arranging them in these positions.
  3. For the remaining N - (M-i) positions of the array, anything will work - as long as it’s not one of the M elements of S.
    So, there are K - M choices for each index, for a total of (K-M)^{(N - M + i)} ways.

So, we quite simply have \displaystyle g(S) = \sum_{i=0}^M \binom{M}{i}\binom{N}{M-i}\cdot (M-i)! \cdot (K-M)^{(N - M + i)}
where M = |S|.
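
For concreteness, here’s a minimal sketch of how this could be evaluated for a fixed size M = |S|, working modulo 998244353. The helper names (modpow, build_fact, C, g_of_size) are illustrative and not taken from the author’s code; build_fact(N) must be called before g_of_size is used.

#include <bits/stdc++.h>
using namespace std;
const long long MOD = 998244353;

long long modpow(long long b, long long e) {
    b %= MOD; long long r = 1;
    while (e > 0) { if (e & 1) r = r * b % MOD; b = b * b % MOD; e >>= 1; }
    return r;
}

vector<long long> fact;
void build_fact(int n) {                       // factorials 0! .. n! modulo MOD
    fact.assign(n + 1, 1);
    for (int i = 1; i <= n; i++) fact[i] = fact[i - 1] * i % MOD;
}
long long C(int a, int b) {                    // binomial coefficient via Fermat inverses
    if (b < 0 || b > a) return 0;
    return fact[a] * modpow(fact[b], MOD - 2) % MOD * modpow(fact[a - b], MOD - 2) % MOD;
}

// g_of_size(M, N, K): number of arrays of length N over [1, K] in which every
// element of a fixed M-element subset S appears at most once.
long long g_of_size(int M, int N, int K) {
    long long res = 0;
    for (int i = 0; i <= M; i++) {
        long long term = C(M, i) * C(N, M - i) % MOD;    // i elements of S absent, positions for the rest
        term = term * fact[M - i] % MOD;                 // order the M - i single occurrences
        term = term * modpow(K - M, N - M + i) % MOD;    // remaining positions avoid S entirely
        res = (res + term) % MOD;
    }
    return res;
}

With inverse factorials and the powers of K - M precomputed (as the author’s code does), one such evaluation costs \mathcal{O}(M).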

Now, because \displaystyle g(S) = \sum_{S\subseteq T} f(T), the inclusion-exclusion principle tells us that we have
\displaystyle f(S) = \sum_{S\subseteq T} (-1)^{|T| - |S|} g(T)

We’re only interested in S = \emptyset, for which |S| = 0.
So, what we want to compute is

\sum_{T} (-1)^{|T|} g(T)

This doesn’t seem immediately useful, because while each g(T) is easy to compute, there are 2^K choices of T to go through in total.

To optimize this, note that the value of g(T) depends only on the size |T| of the subset.
So, if we fix a size M and compute the quantity
\displaystyle \sum_{i=0}^M \binom{M}{i}\binom{N}{M-i}\cdot (M-i)! \cdot (K-M)^{(N - M + i)}
once (in \mathcal{O}(M) time), this value will be the exact same for every subset of size M.

There are \binom{K}{M} subsets of size M, so we can simply multiply the above value by \binom{K}{M} and then add it to or subtract it from the answer, depending on the parity of M.

This allows us to only iterate through all values 0 \leq M \leq K, with the final summation being

\sum_{M=0}^K (-1)^M \binom{K}{M} \left(\sum_{i=0}^M \binom{M}{i}\binom{N}{M-i}\cdot (M-i)! \cdot (K-M)^{N - M + i}\right)

This is computed in \mathcal{O}(K^2) time, which is more than fast enough for us since 2K \leq N means K \leq 5000.
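
Putting the pieces together, a minimal sketch of the whole computation could look like the following (reusing modpow, build_fact, C, and g_of_size from the earlier sketch, with illustrative names). With binomials computed via Fermat inverses this runs in \mathcal{O}(K^2 \log{\text{MOD}}); precomputing inverse factorials, as the author’s code does, removes the log factor.

// Number of arrays A of length N over [1, K] with f(A) = 0.
// Assumes build_fact(N) has been called beforehand.
long long count_f_zero(int N, int K) {
    long long total = modpow(K, N);            // all K^N arrays
    if (2LL * K > N) return total;             // no array can contain every value twice

    long long every_twice = 0;                 // this is f(empty set) from the editorial
    for (int M = 0; M <= K; M++) {
        long long term = C(K, M) * g_of_size(M, N, K) % MOD;
        if (M % 2 == 0) every_twice = (every_twice + term) % MOD;
        else            every_twice = (every_twice - term + MOD) % MOD;
    }
    return (total - every_twice + MOD) % MOD;
}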

Alternately, there exists a solution in \mathcal{O}(N\log N\log K) time using generating functions.

Genfuncs

Let x_i denote the number of times element i appears in the array.
Suppose we fix the values of all of x_1, x_2, \ldots, x_K.
There are then

\binom{N}{x_1, x_2, \cdots , x_K} = \frac{N!}{x_1! x_2! \cdots x_K!}

ways to arrange them into a valid array.

So, the value we’re looking for is the sum of this quantity across all valid choices of the x_i.
The only constraint on the x_i values is that they should all be \geq 2, and their sum should be N.
So, we’re looking for

\sum_{\substack{x_1 + \ldots + x_K = N \\ x_i \geq 2}} \frac{N!}{x_1! x_2! \cdots x_K!}
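
For example, with N = 4 and K = 2 the only valid choice is x_1 = x_2 = 2, contributing \frac{4!}{2! \, 2!} = 6 arrays (so 2^4 - 6 = 10 arrays have f(A) = 0).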

N! is a constant multiplier here so it can be factored out.
Once that’s done, note that we’re really left with a sum of products over several indices, where the indices sum to a fixed constant (namely N).
Such a summation can be modeled by polynomial multiplication.

Specifically, consider the generating function \displaystyle p(x) = \sum_{i=2}^{\infty} \frac{x^i}{i!}.
Then, the sum above (with the N! factored out) is exactly the coefficient of x^N in p^K(x), so the number of arrays in which every element appears at least twice is N! times this coefficient.

This can be computed in \mathcal{O}(N\log N\log K) time as follows:

  • First, restrict the generating function to a polynomial of degree N, since terms of degree \gt N don’t matter to us.
  • Two polynomials can be multiplied in \mathcal{O}(N\log N) time using NTT.
  • To find the K-th power of a polynomial quickly, combine NTT with binary exponentiation, so that only \mathcal{O}(\log K) multiplications need to be done.
  • Note that after each multiplication, the degree of the product can be as large as 2N.
    Truncate it back to degree N by discarding all terms of degree greater than N, to avoid blow-up of the degree (those terms don’t affect the coefficient of x^N anyway).
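
For completeness, here is a rough sketch of this polynomial-power idea. To keep it short it uses naive \mathcal{O}(N^2) multiplication rather than NTT, so it demonstrates the truncation and binary exponentiation but not the stated \mathcal{O}(N\log N\log K) running time; the names mul_trunc and all_at_least_twice are illustrative.

#include <bits/stdc++.h>
using namespace std;
const long long MOD = 998244353;

long long modpow(long long b, long long e) {
    b %= MOD; long long r = 1;
    while (e > 0) { if (e & 1) r = r * b % MOD; b = b * b % MOD; e >>= 1; }
    return r;
}

// Multiply two polynomials modulo MOD, keeping only terms of degree <= N.
// (Replace this with an NTT-based multiplication to get O(N log N).)
vector<long long> mul_trunc(const vector<long long>& a, const vector<long long>& b, int N) {
    vector<long long> c(min(a.size() + b.size() - 1, (size_t)N + 1), 0);
    for (size_t i = 0; i < a.size(); i++)
        for (size_t j = 0; j < b.size() && i + j <= (size_t)N; j++)
            c[i + j] = (c[i + j] + a[i] * b[j]) % MOD;
    return c;
}

// Number of arrays of length N over [1, K] in which every value appears at least twice:
// N! * [x^N] (sum_{i >= 2} x^i / i!)^K.
long long all_at_least_twice(int N, int K) {
    // Factorials and inverse factorials up to N.
    vector<long long> fact(N + 1, 1), inv_fact(N + 1, 1);
    for (int i = 1; i <= N; i++) fact[i] = fact[i - 1] * i % MOD;
    inv_fact[N] = modpow(fact[N], MOD - 2);
    for (int i = N; i >= 1; i--) inv_fact[i - 1] = inv_fact[i] * i % MOD;

    // p(x) = x^2/2! + x^3/3! + ... + x^N/N!, already truncated at degree N.
    vector<long long> p(N + 1, 0);
    for (int i = 2; i <= N; i++) p[i] = inv_fact[i];

    // res = p(x)^K via binary exponentiation, truncating after every multiplication.
    vector<long long> res(1, 1);                 // the constant polynomial 1
    int e = K;
    while (e > 0) {
        if (e & 1) res = mul_trunc(res, p, N);
        p = mul_trunc(p, p, N);
        e >>= 1;
    }
    long long coeff = (int)res.size() > N ? res[N] : 0;
    return coeff * fact[N] % MOD;
}

The answer is then K^N minus this value, just as in the inclusion-exclusion approach.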

TIME COMPLEXITY:

\mathcal{O}(K^2) or \mathcal{O}(N\log N\log K) per testcase when K \leq \frac N 2, and \mathcal{O}(\log N) otherwise.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

const int mod = 998244353;

struct mint{
    int x;

    mint (){ x = 0;}
    mint (int32_t xx){ x = xx % mod; if (x < 0) x += mod;}
    mint (long long xx){ x = xx % mod; if (x < 0) x += mod;}

    int val(){
        return x;
    }
    mint &operator++(){
        x++;
        if (x == mod) x = 0;
        return *this;
    }
    mint &operator--(){
        if (x == 0) x = mod;
        x--;
        return *this;
    }
    mint operator++(int32_t){
        mint result = *this;
        ++*this;
        return result;
    }
    
    mint operator--(int32_t){
        mint result = *this;
        --*this;
        return result;
    }
    mint& operator+=(const mint &b){
        x += b.x;
        if (x >= mod) x -= mod;
        return *this;
    }
    mint& operator-=(const mint &b){
        x -= b.x;
        if (x < 0) x += mod;
        return *this;
    }
    mint& operator*=(const mint &b){
        long long z = x;
        z *= b.x;
        z %= mod;
        x = (int)z;
        return *this;
    }
    mint operator+() const {
        return *this;
    }
    mint operator-() const {
        return mint() - *this;
    }
    mint operator/=(const mint &b){
        return *this = *this * b.inv();
    }
    mint power(long long n) const {
        mint ok = *this, r = 1;
        while (n){
            if (n & 1){
                r *= ok;
            }
            ok *= ok;
            n >>= 1;
        }
        return r;
    }
    mint inv() const {
        return power(mod - 2);
    }
    friend mint operator+(const mint& a, const mint& b){ return mint(a) += b;}
    friend mint operator-(const mint& a, const mint& b){ return mint(a) -= b;}
    friend mint operator*(const mint& a, const mint& b){ return mint(a) *= b;}
    friend mint operator/(const mint& a, const mint& b){ return mint(a) /= b;}
    friend bool operator==(const mint& a, const mint& b){ return a.x == b.x;}
    friend bool operator!=(const mint& a, const mint& b){ return a.x != b.x;}
    mint power(mint a, long long n){
        return a.power(n);
    }
    friend ostream &operator<<(ostream &os, const mint &m) {
        os << m.x;
        return os;
    }
    explicit operator bool() const {
        return x != 0;
    }
};

// Remember to check MOD

struct factorials{
    int n;
    vector <mint> ff, iff;
    
    factorials(int nn){
        n = nn;
        ff.resize(n + 1);
        iff.resize(n + 1);
        
        ff[0] = 1;
        for (int i = 1; i <= n; i++){
            ff[i] = ff[i - 1] * i;
        }
        
        iff[n] = ff[n].inv();
        for (int i = n - 1; i >= 0; i--){
            iff[i] = iff[i + 1] * (i + 1);
        }
    }
    
    mint C(int n, int r){
        if (n == r) return mint(1);
        if (n < 0 || r < 0 || r > n) return mint(0);
        return ff[n] * iff[r] * iff[n - r];
    }
    
    mint P(int n, int r){
        if (n < 0 || r < 0 || r > n) return mint(0);
        return ff[n] * iff[n - r];
    }
    
    mint solutions(int n, int r){
        // Solutions to x1 + x2 + ... + xn = r, xi >= 0 
        return C(n + r - 1, n - 1);
    }
    
    mint catalan(int n){
        return ff[2 * n] * iff[n] * iff[n + 1];
    }
};

const int PRECOMP = 3e6 + 69;
factorials F(PRECOMP);

// REMEMBER To check MOD and PRECOMP

void Solve() 
{
    int n, k; cin >> n >> k;
    
    // If 2K > N, no array of length N can contain every element of [1, K]
    // at least twice, so every array has f(A) = 0 and the answer is K^N.
    if (n < 2 * k){
        cout << mint(k).power(n) << "\n";
        return;
    }
    
    mint ans = 0;
    
    // got[s][i] = (k - s)^(n - i) for 0 <= s < k, 0 <= i <= min(n, s).
    // The row s = k is left as 0, which matches 0^(n - i) since n - i >= 1 here.
    vector<vector<mint>> got(k + 1, vector<mint>(k + 1));
    for (int s = 0; s < k; s++){
        mint w = mint(k - s).power(n);
        mint ik = mint(1) / (k - s);
        for (int i = 0; i <= min(n, s); i++){
            if (i > 0) w *= ik;
            got[s][i] = w;
        }
    }
    
    // Inclusion-exclusion over subsets of size M = i + j, where i elements
    // appear exactly once and j elements don't appear at all:
    //   C(k, i) * C(k - i, j)  -- choose which elements,
    //   C(n, i) * i!           -- place and order the single occurrences,
    //   got[i + j][i]          -- fill the remaining n - i positions with the other k - M values.
    for (int i = 0; i <= min(k, n); i++){
        for (int j = 0; i + j <= k; j++){
            mint w = F.C(k, i) * F.C(k - i, j) * F.C(n, i) * F.ff[i] * got[i + j][i];
            if ((i + j) % 2 == 1){
                ans -= w;
            } else {
                ans += w;
            }
        }
    }
    
    // ans now counts arrays in which every element appears at least twice;
    // the arrays with f(A) = 0 are all the others.
    ans = mint(k).power(n) - ans;
    
    cout << ans << "\n";
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}