CTF - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: q_ed
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

Medium

PREREQUISITES:

Functional graphs, counting labeled trees, NTT

PROBLEM:

You’re given N and K.
Count the number of functions f on \{1, 2, \ldots, N+K\} to itself such that:

  • f(x) = x for x \gt N.
  • f^N(x) \gt N for all x.
  • f^N(x) \leq f^N(x+1) for 1 \leq x \lt N.

f^i(x) refers to iterating the function f, i times.

Answer modulo 998244353.

EXPLANATION:

A function that maps a set into itself can be modeled as a directed graph by taking the domain as the vertex set and adding a directed edge x \to f(x) for each x.
Such a graph is called a functional graph.

One important property of functional graphs, which arises from the fact that each vertex has an outdegree of 1, is that every (weakly) connected component has exactly one cycle.
Here, we count a self-loop as a cycle.

Equivalently, each connected component of a functional graph looks like a single cycle with a tree hanging off each vertex of the cycle.
The tree edges are directed towards the cycle.
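This one-cycle-per-component property can be observed with a short walk-and-mark routine. The sketch below is illustrative only: it assumes a 0-indexed function stored as a std::vector<int> f with every f[x] in [0, n), and countCycles is a hypothetical helper, not part of the solution.

```cpp
#include <cassert>
#include <vector>

// Counts the cycles of the functional graph x -> f[x] (0-indexed, self-loops
// included). Every weakly connected component contains exactly one cycle,
// so this also equals the number of components.
int countCycles(const std::vector<int>& f) {
    int n = static_cast<int>(f.size());
    // 0 = unvisited, 1 = on the walk currently being traced, 2 = finished
    std::vector<int> state(n, 0);
    int cycles = 0;
    for (int start = 0; start < n; ++start) {
        if (state[start] != 0) continue;
        int v = start;
        // Follow outgoing edges until we hit a vertex seen before.
        while (state[v] == 0) {
            state[v] = 1;
            v = f[v];
        }
        // Hitting the current walk closes a new cycle; hitting a finished
        // vertex means this walk merged into an already-counted component.
        if (state[v] == 1) ++cycles;
        // Mark every vertex of this walk as finished.
        for (int u = start; state[u] == 1; u = f[u]) state[u] = 2;
    }
    return cycles;
}
```

For example, f = \{1, 2, 0, 3\} has the cycle 0 \to 1 \to 2 \to 0 and the self-loop at 3, so countCycles returns 2, matching its two components.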


Let’s look at what the functional graph must look like, if the function f satisfies the properties we want.

The first condition, f(x) = x for all x \gt N, tells us that there are self-loops at each of these vertices.
Since every connected component contains exactly one cycle, these K self-loops lie in K different connected components, i.e. we have at least K connected components.

Next, we have f^N(x) \gt N for all x.
This is trivially satisfied for x \gt N, so only smaller x need to be considered.

Iterated functions play quite nicely with functional graphs: it’s not hard to see that the value of f^k(x) can be obtained by simply starting at the vertex x and following its outgoing edge k times.
So, f^N(x) \gt N tells us that if we start at any vertex x \leq N and follow its outgoing edge N times, we end up at a value \gt N.

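Note that f(x) = x for x \gt N means that as soon as we reach a value \gt N, we stay there forever.
So, we only need to make sure each value \leq N is able to reach some value \gt N at all. If it can, it does so within N steps: until then the walk visits distinct vertices of [1, N], since repeating one would trap it in a cycle.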
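Checking the second condition literally just means following edges. A minimal sketch, assuming f is stored 1-indexed in a std::vector<int> of size N+K+1 (index 0 unused); iterateF and allReachAboveN are hypothetical helper names:

```cpp
#include <cassert>
#include <vector>

// f is 1-indexed: f[x] for x in [1, N+K], with f[0] unused.
// Returns f^k(x) by following the outgoing edge of x exactly k times.
int iterateF(const std::vector<int>& f, int x, long long k) {
    for (long long i = 0; i < k; ++i) x = f[x];
    return x;
}

// Checks the second condition directly: f^N(x) > N for every x <= N.
// Values > N are fixed points of f, so they need no check.
bool allReachAboveN(const std::vector<int>& f, int N) {
    for (int x = 1; x <= N; ++x)
        if (iterateF(f, x, N) <= N) return false;
    return true;
}
```

With N = 2, K = 1, the function stored as \{0, 3, 1, 3\} (so 2 \to 1 \to 3) passes, while \{0, 2, 1, 3\} fails because 1 \to 2 \to 1 is a cycle inside [1, 2].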

In particular, observe that this means that there cannot be any cycles involving vertices in [1, N].
If such a cycle existed, vertices in the cycle would be stuck in it and unable to reach anything \gt N.

This gives us a more concrete idea of the structure of the graph: each connected component consists of some vertices \leq N along with a single vertex \gt N, and there’s a self-loop on the vertex \gt N.
Another way of looking at this is that each connected component is essentially a rooted tree, with edges directed towards the root (and a self-loop on the root).
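Viewing each component as a tree rooted at some N+i, a vertex’s root can be found by climbing edges until the value exceeds N. A minimal sketch, assuming f is 1-indexed (index 0 unused) and already satisfies the first two conditions, since otherwise the climb may never end; componentSizes is a hypothetical name:

```cpp
#include <cassert>
#include <vector>

// Precondition: f satisfies f(x) = x for x > N and f^N(x) > N for all x,
// so every x <= N climbs to a root N+i. Returns sizes where sizes[i]
// (1 <= i <= K) counts the vertices of [1, N] whose root is N+i.
std::vector<int> componentSizes(const std::vector<int>& f, int N, int K) {
    std::vector<int> sizes(K + 1, 0);
    for (int x = 1; x <= N; ++x) {
        int v = x;
        while (v <= N) v = f[v];   // climb towards the root
        ++sizes[v - N];
    }
    return sizes;
}
```

For N = 3, K = 2 and the function 1 \to 4, 2 \to 1, 3 \to 5, the roots of 1, 2, 3 are 4, 4, 5, so the component sizes are 2 and 1.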


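Finally, we have the condition f^N(x) \leq f^N(x+1) for x \lt N.
Every vertex of [1, N] reaches its root within N steps and then stays there, so f^N(x) is exactly the root of the component containing x. The condition therefore says that the roots of 1, 2, \ldots, N form a non-decreasing sequence, which tells us that for some 0 \leq x_1 \leq x_2 \leq \ldots \leq x_{K-1} \leq N: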

  • The connected component of N+1 will include the vertices 1, 2, \ldots, x_1.
  • The connected component of N+2 will include the vertices x_1+1, x_1+2, \ldots, x_2.
  • The connected component of N+3 will include the vertices x_2+1, x_2+2, \ldots, x_3.
    \vdots
  • The connected component of N+K will include the vertices x_{K-1}+1, x_{K-1}+2, \ldots, N.

In other words, we must break the range [1, N] into K contiguous subarrays (which may be empty), and the vertices from [1, N] in the connected component of N+i are exactly the elements of the i-th subarray.
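For tiny inputs, this characterization can be cross-checked against a brute force that tests the original conditions directly. A sketch under the assumption that (N+K)^N is small enough to enumerate; countValidBrute is a hypothetical name:

```cpp
#include <cassert>
#include <vector>

// Brute force for tiny N and K (it tries all (N+K)^N functions): fixes
// f(x) = x for x > N, enumerates f(1..N), and checks the remaining two
// conditions literally by following edges N times.
long long countValidBrute(int N, int K) {
    assert(N >= 1 && K >= 1);
    const int M = N + K;
    std::vector<int> f(M + 1, 1);          // f[0] unused; f(1..N) start at 1
    for (int x = N + 1; x <= M; ++x) f[x] = x;
    long long count = 0;
    while (true) {
        bool ok = true;
        int prev = 0;                       // f^N(x - 1) from the previous x
        for (int x = 1; x <= N && ok; ++x) {
            int v = x;
            for (int step = 0; step < N; ++step) v = f[v];
            // Need f^N(x) > N and f^N(x - 1) <= f^N(x).
            if (v <= N || prev > v) ok = false;
            prev = v;
        }
        if (ok) ++count;
        // Advance (f(1), ..., f(N)) like an odometer with digits in [1, M].
        int pos = 1;
        while (pos <= N && f[pos] == M) f[pos++] = 1;
        if (pos > N) break;
        ++f[pos];
    }
    return count;
}
```

For instance, it should report 7 for N = K = 2: three functions with all of [1, 2] under root 3, three with all of it under root 4, and one with 1 under root 3 and 2 under root 4.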

The question now is: how do we count valid functions across all ways of splitting?


Let’s solve a simple version first: what if K = 1?
In this case, we are simply counting rooted labeled trees on N+1 vertices whose root is vertex N+1.
Any such tree corresponds to a valid function: direct all edges towards the root and add a self-loop at the root. Conversely, by the structure derived above, every valid function arises from exactly one such tree.

The number of labeled trees on N+1 vertices rooted at a fixed vertex is simply the number of labeled trees on N+1 vertices, since each tree can be rooted at that vertex in exactly one way.
This is well known: Cayley’s formula says there are n^{n-2} labeled trees on n vertices, so the count is (N+1)^{N-1}.
K = 1 is hence trivially solved.
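As a quick check of the K = 1 case, take N = 2, so f(3) = 3 and only the pair (f(1), f(2)) is free. Of its 9 possible values, those with f(1) = 1 or f(2) = 2 put a self-loop inside [1, 2], and (2, 1) creates the cycle 1 \to 2 \to 1, leaving exactly Cayley’s count:

(f(1), f(2)) \in \{(3, 3),\ (2, 3),\ (3, 1)\}, \qquad (2+1)^{2-1} = 3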

Now, we must extend this to larger K.
To do that, let’s look at what’s actually being computed.

Define C_i = (i+1)^{i-1}; in particular C_0 = 1, corresponding to a root with nothing attached.
This is the number of ways to form a connected component from a fixed root and a fixed set of i vertices from [1, N].

Now, for each partition of [1, N] into K subarrays, say with sizes s_1, s_2, \ldots, s_K, the number of configurations will equal

\prod_{i=1}^K C_{s_i}

since the different components are independent.

So, the actual value we’re looking for is

\sum_{\substack{s_1 + \ldots + s_K = N \\ s_i \geq 0}} \left(\prod_{i=1}^K C_{s_i} \right)

Summing a product over all s_1 + \ldots + s_K = N is exactly what a K-fold convolution computes.
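Before speeding anything up, the sum can be evaluated directly by adding one component at a time; each step is a plain convolution with the sequence C. A minimal sketch, assuming K and N are small enough for an \mathcal{O}(KN^2) loop; power, cayleyTable and countByDP are hypothetical names:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

const int64_t MOD = 998244353;

// b^e mod MOD by binary exponentiation.
int64_t power(int64_t b, int64_t e) {
    int64_t r = 1;
    b %= MOD;
    for (; e > 0; e >>= 1, b = b * b % MOD)
        if (e & 1) r = r * b % MOD;
    return r;
}

// C[i] = (i+1)^(i-1) mod MOD for 1 <= i <= N, and C[0] = 1.
std::vector<int64_t> cayleyTable(int N) {
    std::vector<int64_t> C(N + 1, 1);
    for (int i = 1; i <= N; ++i) C[i] = power(i + 1, i - 1);
    return C;
}

// After j rounds, dp[m] is the sum of C[s_1] * ... * C[s_j] over all
// s_1 + ... + s_j = m; the answer is dp[N] after K rounds.
int64_t countByDP(int N, int K) {
    std::vector<int64_t> C = cayleyTable(N), dp(N + 1, 0);
    dp[0] = 1;
    for (int j = 0; j < K; ++j) {
        std::vector<int64_t> next(N + 1, 0);
        for (int m = 0; m <= N; ++m)
            for (int s = 0; m + s <= N; ++s)
                next[m + s] = (next[m + s] + dp[m] * C[s]) % MOD;
        dp = next;
    }
    return dp[N];
}
```

This is handy as a reference when testing a faster implementation; for example, N = 3, K = 2 gives C_0C_3 + C_1C_2 + C_2C_1 + C_3C_0 = 16 + 3 + 3 + 16 = 38.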

Indeed, let’s define the generating function p(x) = \sum_{i\geq 0} C_ix^i.
Then, the value we’re looking for is exactly the coefficient of x^N in p(x)^K (the K-th power of p, not K-fold iteration as with f).
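A small worked instance, with N = K = 2: here C_0 = 1, C_1 = 1 and C_2 = 3, and the three terms below correspond to splitting [1, 2] into subarrays of sizes (0, 2), (1, 1) and (2, 0):

[x^2]\, p(x)^2 = C_0C_2 + C_1C_1 + C_2C_0 = 3 + 1 + 3 = 7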

To compute this quickly, observe that terms with powers of x larger than N never contribute to the coefficient of x^N.
So, let’s truncate the generating function to a polynomial, p(x) = \sum_{i=0}^N C_ix^i.
Then, use binary exponentiation to compute its K-th power, which needs only \mathcal{O}(\log K) polynomial multiplications.
After each multiplication, truncate the result to degree N so that it never grows; every multiplication then involves polynomials of degree at most N and can be done in \mathcal{O}(N\log N) using NTT.
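The truncated binary exponentiation can be sketched as follows. To keep it short, this version swaps NTT for schoolbook \mathcal{O}(N^2) multiplication, so it illustrates the structure rather than the intended complexity; mulTrunc and powTrunc are hypothetical names:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

const int64_t MOD = 998244353;
using Poly = std::vector<int64_t>;

// a * b mod x^(N+1), coefficients mod MOD. Schoolbook multiplication stands
// in for NTT here purely for readability.
Poly mulTrunc(const Poly& a, const Poly& b, int N) {
    Poly c(N + 1, 0);
    for (int i = 0; i < (int)a.size() && i <= N; ++i)
        for (int j = 0; j < (int)b.size() && i + j <= N; ++j)
            c[i + j] = (c[i + j] + a[i] * b[j]) % MOD;
    return c;
}

// p^K mod x^(N+1) using O(log K) truncated multiplications.
Poly powTrunc(Poly p, long long K, int N) {
    Poly result(N + 1, 0);
    result[0] = 1;                 // start from the constant polynomial 1
    while (K > 0) {
        if (K & 1) result = mulTrunc(result, p, N);
        p = mulTrunc(p, p, N);     // p, p^2, p^4, ... (all truncated)
        K >>= 1;
    }
    return result;
}
```

Calling powTrunc on the coefficients C_0, \ldots, C_N with exponent K and reading index N gives the answer. Because every product is cut back to N+1 coefficients, the polynomials never exceed degree N, however large K is.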

The overall complexity is \mathcal{O}(N\log N\log K), which is fast enough since N \leq 10^5 and K \leq 10^9.

TIME COMPLEXITY:

\mathcal{O}(N\log N\log K) per testcase.

CODE:

Author's code (C++)


```cpp
#include <bits/stdc++.h>

using namespace std;

#define rep(i, a, b) for(int i = a; i < (b); ++i)
#define trav(a, x) for(auto& a : x)
#define all(x) x.begin(), x.end()
#define sz(x) (int)(x).size()
typedef long long ll;
typedef pair<int, int> pii;
typedef vector<long long> vl;
typedef vector<int> vi;
template<typename T> std::ostream& operator<<(std::ostream& os, const std::vector<T>& vec) {
    os << "[ ";
    for(const auto& elem : vec) {
        os << elem << " ";
    }
    os << "]";
    return os;
}
const ll mod = (119 << 23) + 1, root = 62; // = 998244353
// b^e modulo mod, by binary exponentiation
ll modpow(ll b, ll e) {
	ll ans = 1;
	for (; e; b = b * b % mod, e /= 2)
		if (e & 1) ans = ans * b % mod;
	return ans;
}


// For p < 2^30 there is also e.g. 5 << 25, 7 << 26, 479 << 21
// and 483 << 21 (same root). The last two are > 10^9.

// In-place number-theoretic transform; sz(a) must be a power of two.
void ntt(vl &a) {
	int n = sz(a), L = 31 - __builtin_clz(n);
	static vl rt(2, 1);
	for (static int k = 2, s = 2; k < n; k *= 2, s++) {
		rt.resize(n);
		ll z[] = {1, modpow(root, mod >> s)};
		rep(i,k,2*k) rt[i] = rt[i / 2] * z[i & 1] % mod;
	}
	vi rev(n);
	rep(i,0,n) rev[i] = (rev[i / 2] | (i & 1) << L) / 2;
	rep(i,0,n) if (i < rev[i]) swap(a[i], a[rev[i]]);
	for (int k = 1; k < n; k *= 2)
		for (int i = 0; i < n; i += 2 * k) rep(j,0,k) {
			ll z = rt[j + k] * a[i + j + k] % mod, &ai = a[i + j];
			a[i + j + k] = ai - z + (z > ai ? mod : 0);
			ai += (ai + z >= mod ? z - mod : z);
		}
}
// Product of polynomials a and b modulo mod (sz(a) + sz(b) - 1 coefficients).
vl conv(const vl &a, const vl &b) {
	if (a.empty() || b.empty()) return {};
	int s = sz(a) + sz(b) - 1, B = 32 - __builtin_clz(s),
	    n = 1 << B;
	int inv = modpow(n, mod - 2);
	vl L(a), R(b), out(n);
	L.resize(n), R.resize(n);
	ntt(L), ntt(R);
	// Pointwise product scaled by 1/n; storing it at index -i mod n lets the
	// forward transform below act as the inverse transform.
	rep(i,0,n)
		out[-i & (n - 1)] = (ll)L[i] * R[i] % mod * inv % mod;
	ntt(out);
	return {out.begin(), out.begin() + s};
}
int n,k;
// base^expo, truncated to degree n after every multiplication.
vl binExp(vector<ll> &base, ll expo) {
    vl ans = {1};
    while(expo) {
        if(expo%2) {
            ans = conv(ans,base);
            ans.resize(n+1);
        }
        expo/=2;
        base=conv(base,base);
        base.resize(n+1);
    }
    return ans;
}
// Coefficients of p(x): C_0 = 1 and C_i = (i+1)^(i-1) for 1 <= i <= n.
vl cayley() {
    vl ret(n+1, 1);
    for(int i = 1; i < n+1; i++) {
        ret[i]=modpow(i+1, i-1);
    }
    return ret;
}
int main() {
    int t;
    cin >> t;
    while(t--) {
        cin >> n >> k;
        auto genfunc = cayley();
        auto genpow = binExp(genfunc, k); // p(x)^k mod x^(n+1)
        cout << genpow[n] << endl;        // coefficient of x^n
    }
}
```