INDEXCOMPH - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: thescrasse
Preparation: raysh07
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

Easy-Medium

PREREQUISITES:

Combinatorics, NTT

PROBLEM:

The score of an array A is defined as follows:

  1. First, coordinate compress the array A to obtain the array B.
  2. The score of A is then \sum_{i=1}^N B_i^M.

Given N, M, and K, compute the sum of scores of all arrays of length N with elements between 1 and K.

EXPLANATION:

If you haven’t, read the solution to the easy version first - this will continue from there.

To recap, our answer is

\sum_{d=1}^N \left(\frac{N}{d} \cdot\binom{K}{d} \cdot f(N, d)\cdot \left(1^M + 2^M + \ldots + d^M \right) \right)

Of these terms, when d is fixed, \frac{N}{d} and \left(1^M + 2^M + \ldots + d^M \right) are easy to compute quickly: the former requires one division, the latter is a prefix sum and can be precomputed.

However, \binom{K}{d} and f(N, d) were found slowly: both in \mathcal{O}(d) time.
Let’s fix both issues in turn.


One is quite simple to speed up: \binom{K}{d}.
Note that

\binom{K}{d} = \binom{K}{d-1} \cdot \frac{K-d+1}{d}

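which is fairly easy to prove either algebraically or with a combinatorial argument.
So, \binom{K}{d} can be computed from \binom{K}{d-1} in constant time if the modular inverse of d is precomputed, or in \mathcal{O}(\log {MOD}) time if the inverse is found by modular exponentiation.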


Next, let’s look at f(N, d).

Recall that we had

f(N, d) = \sum_{i=0}^d (-1)^i \binom{d}{i} (d-i)^N

which if expanded is

f(N, d) = \sum_{i=0}^d (-1)^i \cdot\frac{d!}{i! (d-i)!}\cdot (d-i)^N

This looks suspiciously like a convolution. Specifically, if we define two polynomials p and q where:

  • The coefficient of x^i in p is
(-1)^i \cdot \frac{1}{i!}
  • The coefficient of x^i in q is
i^N \cdot \frac{1}{i!}

Then the product polynomial r = p\cdot q has its d-th coefficient equal to

\sum_{i=0}^d p_i \cdot q_{d-i} = \sum_{i=0}^d (-1)^i \frac{1}{i!} \cdot (d-i)^N \cdot \frac{1}{(d-i)!}

which is exactly the value f(N,d) once it is multiplied by d!.

The convolution of p and q can be computed in \mathcal{O}(N\log N) time using NTT, so we obtain every f(N, d) value, and the problem is solved.


As an additional note, some participants may be familiar with the Stirling numbers of the second kind.
In fact, the coefficients of r that we computed are exactly these Stirling numbers, and we indeed do have f(N, d) = d! \cdot {N\brace d} (which is quite easy to see if you know what the Stirling numbers count).

TIME COMPLEXITY:

\mathcal{O}(N\log N) per testcase.

CODE:

Preparer's code (C++)
#include <bits/stdc++.h>

// Residue modulo the compile-time constant P, always kept in [0, P).
// Division is delegated to the free function pow(Z, int64_t), which is found
// by argument-dependent lookup when operator/= is instantiated.
template <uint32_t P>
struct Z {
  uint32_t value;

  // The zero residue.
  constexpr Z() : value(0) {}

  // Converts any integral value; the "+ P" folds negative remainders back
  // into the canonical range.
  template <typename T, typename = std::enable_if_t<std::is_integral<T>::value>>
  constexpr Z(T a) : value(((int64_t(a) % P) + P) % P) {}

  Z& operator+=(Z other) {
    const uint32_t sum = value + other.value;
    value = sum >= P ? sum - P : sum;
    return *this;
  }

  Z& operator-=(Z other) {
    // Add the additive inverse so the intermediate never goes negative.
    const uint32_t shifted = value + (P - other.value);
    value = shifted >= P ? shifted - P : shifted;
    return *this;
  }

  Z& operator*=(Z other) {
    const uint64_t wide = uint64_t(value) * other.value;
    value = wide % P;
    return *this;
  }

  Z& operator/=(Z other) {
    // Multiply by other^(-1).
    *this *= pow(other, -1);
    return *this;
  }

  Z operator-() const {
    Z negated;
    negated -= *this;
    return negated;
  }

  bool operator==(Z other) const { return value == other.value; }
  bool operator!=(Z other) const { return !(value == other.value); }

  friend Z operator+(Z a, Z b) {
    a += b;
    return a;
  }
  friend Z operator-(Z a, Z b) {
    a -= b;
    return a;
  }
  friend Z operator*(Z a, Z b) {
    a *= b;
    return a;
  }
  friend Z operator/(Z a, Z b) {
    a /= b;
    return a;
  }

  // Prints the canonical representative.
  friend std::ostream& operator<<(std::ostream& out, Z a) {
    out << a.value;
    return out;
  }

  // Reads a signed 64-bit integer and reduces it modulo P.
  friend std::istream& operator>>(std::istream& in, Z& a) {
    int64_t raw;
    in >> raw;
    a = Z(raw);
    return in;
  }
};

// Returns x^p in Z<P> by binary exponentiation.
//
// The exponent is reduced modulo P - 1 (Fermat's little theorem), which also
// makes negative exponents work: pow(x, -1) is the modular inverse of x.
// NOTE(review): the reduction assumes P is prime — confirm for every P used.
//
// A positive exponent that is a multiple of P - 1 is mapped to P - 1 rather
// than 0, so that 0^p stays 0 for p > 0 (0^0 is still defined as 1).
template <uint32_t P>
Z<P> pow(Z<P> x, int64_t p) {
  const bool positive_exponent = p > 0;
  p %= P - 1;
  if (p < 0) p += P - 1;
  if (positive_exponent && p == 0) p = P - 1;
  Z<P> res = 1;
  while (p) {
    if (p & 1) {
      res *= x;
    }
    x *= x;
    p >>= 1;
  }
  return res;
}

// Trait that supplies primitive roots of unity for the transform in fft_t.
// The primary template is intentionally empty; only specializations are usable.
template <typename T>
struct root_of_unity_t {};

// Complex specialization: roots of unity on the unit circle.
template <typename T>
struct root_of_unity_t<std::complex<T>> {
  // Spelled out as a long double literal so that (a) the initializer is a
  // genuine constant expression on every compiler (std::acos is not required
  // to be constexpr) and (b) no precision is lost when T is long double.
  static constexpr T PI = T(3.141592653589793238462643383279502884L);

  // Returns exp(2*pi*i / N); a negative N yields the conjugate (inverse) root.
  static std::complex<T> root_of_unity(int N) {
    return std::polar<T>(1, 2 * PI / N);
  }
};

// Modulus used for the number-theoretic transform.
constexpr int ntt_mod = 998244353;

// Modular specialization: roots of unity are powers of the generator g.
template <>
struct root_of_unity_t<Z<ntt_mod>> {
  static constexpr Z<ntt_mod> g = Z<ntt_mod>(3);

  // Returns g^((ntt_mod - 1) / N). A negative N gives a negative exponent,
  // which pow() resolves to the inverse root.
  static Z<ntt_mod> root_of_unity(int N) {
    const int exponent = int(ntt_mod - 1) / N;
    return pow(g, exponent);
  }
};

// Iterative radix-2 transform of fixed length N over element type T; the
// roots of unity come from root_of_unity_t<T>.
// NOTE(review): the bit-reversal table and butterfly loops presume N is a
// power of two — confirm at every construction site.
template <typename T>
struct fft_t {
  int N;
  std::vector<int> rev;  // rev[i]: i with its low bits reversed
  std::vector<T> rs;     // twiddles: all forward levels first, then all inverse levels
  fft_t(int N) : N(N), rev(N) {
    for (int i = 0; i < N; ++i) {
      int r = 0;
      for (int b = 1, j = i; b < N; b <<= 1, j >>= 1) {
        // `&` binds tighter than `|`: appends the lowest bit of j to r.
        r = (r << 1) | j & 1;
      }
      rev[i] = r;
    }
    // For each sign and each level b = 1, 2, 4, ..., store the b powers
    // w^0 .. w^(b-1) of the (sgn * 2b)-th root of unity.
    for (auto sgn : {+1, -1}) {
      for (int b = 1; b < N; b <<= 1) {
        T w = root_of_unity_t<T>::root_of_unity(sgn * 2 * b);
        rs.push_back(1);
        for (int i = 0; i + 1 < b; ++i) {
          rs.push_back(rs.back() * w);
        }
      }
    }
  }
  // Transforms p (resized to length N, so shorter inputs are zero-padded).
  // With inverse == true the inverse twiddles are used and the result is
  // scaled by 1/N.
  std::vector<T> operator()(std::vector<T> p, bool inverse) {
    p.resize(N);
    // Bit-reversal permutation; the i < rev[i] test swaps each pair once.
    for (int i = 0; i < N; ++i) {
      if (i < rev[i]) {
        std::swap(p[i], p[rev[i]]);
      }
    }
    T* r = rs.data();
    if (inverse) {
      // Skip the forward half of the twiddle table.
      r += rs.size() / 2;
    }
    for (int b = 1; b < N; b <<= 1) {
      for (int s = 0; s < N; s += 2 * b) {
        for (int i = 0; i < b; ++i) {
          // Butterfly on the pair (u, v) that differ only in bit b.
          int u = s | i, v = u | b;
          T x = p[u], y = r[i] * p[v];
          p[u] = x + y;
          p[v] = x - y;
        }
      }
      // Advance to the twiddles of the next level.
      r += b;
    }
    if (inverse) {
      T inv = T(1) / T(N);
      for (int i = 0; i < N; ++i) p[i] *= inv;
    }
    return p;
  }
};

constexpr int naive_threshold = 64;

template <typename T>
std::vector<T> operator*(const std::vector<T>& p, const std::vector<T>& q) {
  int N = p.size(), M = q.size();
  if (N == 0 || M == 0) {
    return {};
  } else if (std::min(N, M) <= naive_threshold) {
    std::vector<T> res(N + M - 1);
    for (int i = 0; i < N; ++i) {
      for (int j = 0; j < M; ++j) {
        res[i + j] += p[i] * q[j];
      }
    }
    return res;
  } else {
    int R = N + M - 1, K = 1;
    while (K < R) K <<= 1;
    fft_t<T> fft(K);
    auto phat = fft(p, false), qhat = fft(q, false);
    for (int i = 0; i < K; ++i) {
      phat[i] *= qhat[i];
    }
    auto res = fft(std::move(phat), true);
    res.resize(R);
    return res;
  }
}

// Precomputed factorial tables; only the Z<P> specialization is implemented.
template <typename T>
struct Combinatorics {};

template <uint32_t P>
struct Combinatorics<Z<P>> {
  // fact[i] = i!, rfact[i] = 1 / i!, r[i] = 1 / i (r[0] is left at zero).
  std::vector<Z<P>> fact, rfact, r;

  // Builds tables for indices 0 .. N-1. The tables always hold at least two
  // entries because indices 0 and 1 are seeded unconditionally; without the
  // clamp, N < 2 would write past the end of the vectors.
  Combinatorics(int N)
      : fact(std::max(N, 2)), rfact(std::max(N, 2)), r(std::max(N, 2)) {
    fact[0] = fact[1] = rfact[0] = rfact[1] = r[1] = 1;
    const int size = fact.size();
    for (int i = 2; i < size; ++i) {
      // Inverse recurrence: from P = (P / i) * i + (P % i) it follows that
      // 1/i = -(P / i) * 1/(P % i)  (mod P).
      r[i] = -(P / i * r[P % i]);
      rfact[i] = r[i] * rfact[i - 1];
      fact[i] = i * fact[i - 1];
    }
  }

  // Binomial coefficient C(n, k); zero when k < 0 or k > n.
  // n must be below the table size (not checked).
  Z<P> C(int n, int k) const {
    return k < 0 || n < k ? 0 : fact[n] * rfact[k] * rfact[n - k];
  }

  // Evaluates to C(n + k - 1, k - 1) for k > 0; for k == 0 it is 1 when
  // n == 0 and 0 otherwise.
  Z<P> S(int n, int k) const {
    return k == 0 ? n == 0 : C(n + k - 1, k - 1);
  }
};

// Shared table with 2^20 entries per element type, instantiated on first use.
template <typename T>
const Combinatorics<T> combinatorics(1 << 20);

// Coefficient vector with polynomial / truncated power series arithmetic.
// Index i holds the coefficient of x^i. Products and inverses are delegated
// to the free functions operator*(vector, vector) and inv(), which are found
// at instantiation time.
template <typename T>
struct FormalPowerSeries : public std::vector<T> {
  using F = FormalPowerSeries;
  using std::vector<T>::vector;
  // Explicit forwarding constructor, e.g. for building from a std::vector<T>.
  template <typename... Args>
  explicit FormalPowerSeries(Args&&... args) : std::vector<T>(std::forward<Args>(args)...) {}

  // Coefficient-wise sum / difference; the result takes the longer length.
  F operator+(const F& rhs) const {
    return F(*this) += rhs;
  }
  F& operator+=(const F& rhs) {
    if (this->size() < rhs.size()) {
      this->resize(rhs.size());
    }
    for (int i = 0; i < rhs.size(); ++i) {
      (*this)[i] += rhs[i];
    }
    return *this;
  }
  F operator-(const F& rhs) const {
    return F(*this) -= rhs;
  }
  F& operator-=(const F& rhs) {
    if (this->size() < rhs.size()) {
      this->resize(rhs.size());
    }
    for (int i = 0; i < rhs.size(); ++i) {
      (*this)[i] -= rhs[i];
    }
    return *this;
  }

  // Scalar multiplication / division.
  F& operator*=(T alpha) {
    for (auto& x : *this) {
      x *= alpha;
    }
    return *this;
  }
  F operator*(T alpha) const {
    return F(*this) *= alpha;
  }
  F operator/(T alpha) const {
    return F(*this) *= 1 / alpha;
  }
  friend F operator*(T alpha, F rhs) {
    return rhs *= alpha;
  }
  F operator-() const {
    return F() -= *this;
  }

  // Full (untruncated) product via the vector operator*.
  F operator*(const F& rhs) {
    return F(static_cast<std::vector<T>>(*this) * rhs);
  }
  F& operator*=(const F& rhs) {
    return *this = F(static_cast<std::vector<T>>(*this) * rhs);
  }

  // Drops trailing zero coefficients.
  void trim_zeros() {
    while (!this->empty() && this->back() == 0) {
      this->pop_back();
    }
  }

  // Polynomial quotient (remainder discarded). Large divisors are handled by
  // reversing both polynomials and multiplying by the series inverse.
  F operator/(F rhs) const {
    int N = this->size(), M = rhs.size();
    if (N < M) {
      return {};
    } else if (M <= naive_threshold) {
      return naive_division(rhs).first;
    } else {
      int K = N - M + 1;  // number of quotient coefficients
      std::reverse(rhs.begin(), rhs.end());
      rhs.resize(K);
      auto res = F(this->rbegin(), this->rbegin() + K) * inv(rhs);
      res.resize(K);
      std::reverse(res.begin(), res.end());
      res.trim_zeros();
      return res;
    }
  }
  F& operator/=(const F& rhs) {
    return *this = *this / rhs;
  }

  // Polynomial remainder. These previously referred to a non-existent
  // `divided_by` helper (and used `->` on a pair), so they could not be
  // instantiated; they now use euclidean_division.
  F operator%(const F& rhs) const {
    return euclidean_division(rhs).second;
  }
  F& operator%=(const F& rhs) {
    return *this = euclidean_division(rhs).second;
  }

  // Schoolbook long division; returns {quotient, remainder}, both trimmed.
  std::pair<F, F> naive_division(const F& d) const {
    F q, r = *this;
    while (r.size() >= d.size()) {
      T c = r.back() / d.back();
      q.push_back(c);
      for (int i = 0; i < d.size(); ++i) {
        r.rbegin()[i] -= c * d.rbegin()[i];
      }
      r.pop_back();
    }
    // Quotient coefficients were produced from the highest degree down.
    std::reverse(q.begin(), q.end());
    q.trim_zeros();
    r.trim_zeros();
    return std::pair(q, r);
  }

  // Returns {quotient, remainder}. For large divisors the quotient comes from
  // operator/ and the remainder is rebuilt from the low coefficients only.
  std::pair<F, F> euclidean_division(F d) const {
    if (d.size() <= naive_threshold) {
      return naive_division(d);
    } else {
      auto q = *this / d;
      // The leading term of d cannot influence coefficients below deg(d),
      // so it is dropped before multiplying back.
      if (d.size() > 1) {
        d.pop_back();
      }
      auto q0 = F(q.begin(), q.begin() + std::min(q.size(), d.size()));
      auto r = *this - d * q0;
      r.resize(d.size());
      r.trim_zeros();
      return std::pair(q, r);
    }
  }

  // Evaluates the polynomial at the scalar x.
  T operator()(T x) const {
    T pow = 1, y = 0;
    for (auto& c : *this) {
      y += c * pow;
      pow *= x;
    }
    return y;
  }
  // returns composition modulo x^M
  // O(sqrt(N) * M * log(M))
  F operator()(const F& g) const {
    int N = this->size(), M = g.size();
    int block_size = 1;
    while ((block_size + 1) * (block_size + 1) <= N) ++block_size;
    // pow[k] = g^k mod x^M for k < block_size ("baby steps").
    std::vector<F> pow(block_size);
    pow[0] = {1};
    for (int k = 0; k + 1 < block_size; ++k) {
      pow[k + 1] = pow[k] * g;
      pow[k + 1].resize(M);
    }
    // h = g^block_size mod x^M ("giant step").
    F h = pow.back() * g;
    h.resize(M);
    F offset = {1}, res;
    for (int i = 0; i < N; i += block_size) {
      // p = sum over the current block of coefficient * g^k.
      F p;
      for (int k = 0; k < block_size && i + k < N; ++k) {
        p += (*this)[i + k] * pow[k];
      }
      p.resize(M);
      res += offset * p;
      offset *= h;
      offset.resize(M);
    }
    res.resize(M);
    return res;
  }
};

// Multiplies the N series stored at p[0 .. N-1] by divide and conquer.
// An empty range yields the constant series 1.
template <typename T>
FormalPowerSeries<T> product(FormalPowerSeries<T>* p, int N) {
  if (N == 0) {
    return {1};
  }
  if (N == 1) {
    return *p;
  }
  const int half = N / 2;
  return product(p, half) * product(p + half, N - half);
}

// Multiplicative inverse of the series P modulo x^N, where N = P.size().
// Requires a non-empty P with non-zero constant term (asserted).
template <typename T>
FormalPowerSeries<T> inv(const FormalPowerSeries<T>& P) {
  using F = FormalPowerSeries<T>;
  assert(!P.empty() && P[0] != 0);
  F Q = {1 / P[0]};
  int N = P.size(), K = 1;
  // Newton iteration Q <- Q * (2 - P * Q); each round doubles the number of
  // coefficients K that are kept.
  while (K < N) {
    K *= 2;
    fft_t<T> fft(2 * K);
    auto Qhat = fft(Q, false);
    auto Phat = fft(F(P.begin(), P.begin() + std::min(K, N)), false);
    // The update is applied pointwise in the transformed domain.
    for (int i = 0; i < 2 * K; ++i) {
      Qhat[i] *= 2 - Phat[i] * Qhat[i];
    }
    auto nQ = fft(Qhat, true);
    Q.swap(nQ);
    Q.resize(K);
  }
  Q.resize(N);
  return Q;
}

// Formal derivative: coefficient i of the result is (i + 1) * P[i + 1], and
// the result is one coefficient shorter than P.
// An empty series is returned unchanged (pop_back on it would be undefined).
template <typename T>
FormalPowerSeries<T> D(FormalPowerSeries<T> P) {
  if (P.empty()) {
    return P;
  }
  const int N = P.size();
  for (int i = 0; i + 1 < N; ++i) {
    P[i] = (i + 1) * P[i + 1];
  }
  P.pop_back();
  return P;
}

// Formal antiderivative with zero constant term: coefficient i + 1 of the
// result is P[i] / (i + 1), and the result is one coefficient longer than P.
template <typename T>
FormalPowerSeries<T> I(FormalPowerSeries<T> P) {
  const int N = P.size();
  // Shift every coefficient up by one degree, then divide in place.
  P.insert(P.begin(), T(0));
  for (int degree = 1; degree <= N; ++degree) {
    P[degree] = P[degree] / degree;
  }
  return P;
}

// Series logarithm modulo x^N (N = P.size()), computed as the integral of
// P' / P. Requires a non-empty P with constant term 1 (asserted).
template <typename T>
FormalPowerSeries<T> log(const FormalPowerSeries<T>& P) {
  assert(!P.empty() && P[0] == 1);
  const int length = P.size();
  FormalPowerSeries<T> quotient = D(P) * inv(P);
  // Integration adds one coefficient back, so keep length - 1 here.
  quotient.resize(length - 1);
  return I(std::move(quotient));
}

// Series exponential modulo x^N, where N = P.size().
// Requires P to be empty or to have zero constant term (asserted).
template <typename T>
FormalPowerSeries<T> exp(const FormalPowerSeries<T>& P) {
  assert(P.empty() || P[0] == 0);
  FormalPowerSeries<T> Q = {1};
  int N = P.size(), K = 1;
  // Newton iteration Q <- Q * (1 + P - log Q); K doubles every round.
  while (K < N) {
    K *= 2;
    Q.resize(K);
    auto B = -log(Q);
    B[0] += 1;
    // Only the first min(N, K) coefficients of P exist / matter this round.
    for (int i = 0; i < std::min(N, K); ++i) {
      B[i] += P[i];
    }
    Q *= B;
    Q.resize(K);
  }
  Q.resize(N);
  return Q;
}

// k-th power of the series P modulo x^N (N = P.size()), via exp(k * log).
template <typename T>
FormalPowerSeries<T> pow(FormalPowerSeries<T> P, int64_t k) {
  int N = P.size();
  // t = index of the first non-zero coefficient (P = x^t * P').
  int t = 0;
  while (t < N && P[t] == 0) ++t;
  // All-zero input, or k * t already reaches degree N: the result is zero.
  if (t == N || (t > 0 && k >= (N + t - 1) / t)) {
    return FormalPowerSeries<T>(N, 0);
  }
  // Number of coefficients of P'^k that survive below x^N.
  int max = N - k * t;
  P.erase(P.begin(), P.begin() + t);
  P.resize(max);
  // Normalise the constant term to 1 so that log() is applicable.
  T alpha = P[0];
  P *= 1 / alpha;
  P = pow(alpha, k) * exp(k * log(P));
  // Re-attach the factor x^(k * t).
  P.insert(P.begin(), k * t, 0);
  return P;
}

namespace flags {
  // Set by sqrt(FormalPowerSeries) when it gives up, cleared when it succeeds.
  bool fps_sqrt_failed;
};

// Square root of the series P, via exp(log(P) / 2) on the normalised series.
// On failure it sets flags::fps_sqrt_failed and returns an empty series.
// NOTE(review): relies on a scalar sqrt(T) overload that is not visible in
// this file — confirm it exists for the element type used.
template <typename T>
FormalPowerSeries<T> sqrt(FormalPowerSeries<T> P) {
  int N = P.size();
  // t = index of the first non-zero coefficient.
  int t = 0;
  while (t < N && P[t] == 0) ++t;
  if (t == N) {
    // The zero series is its own square root (the flag is left untouched).
    return P;
  }
  auto x = sqrt(P[t]);
  // The leading power must be even and the leading coefficient a square.
  if (t % 2 || x * x != P[t]) {
    flags::fps_sqrt_failed = true;
    return {};
  }
  P.erase(P.begin(), P.begin() + t);
  // After re-inserting t / 2 zeros below, the result has N coefficients.
  P.resize(N - t / 2);
  P *= 1 / P[0];
  P = x * exp(log(P) / 2);
  P.insert(P.begin(), t / 2, 0);
  flags::fps_sqrt_failed = false;
  return P;
}

// Product tree over a fixed set of points, supporting multipoint evaluation
// and interpolation at those points.
template <typename T>
struct Interpolator {
  using F = FormalPowerSeries<T>;
  struct Node {
    F P;  // product of (x - point) over the points below this node
    T y;  // leaves only: filled in by the first interpolate() call
    Node* left = nullptr;
    Node* right = nullptr;
  };
  // Node storage; build() keeps raw pointers to elements while appending
  // more at the back. deq[0] is the root.
  std::deque<Node> deq;
  // Builds the tree over the points in [first, last).
  // NOTE(review): an empty range is not handled by build() — confirm callers
  // always pass at least one point.
  template <typename Iterator>
  Interpolator(Iterator first, Iterator last) {
    Node* root = &deq.emplace_back();
    build(root, first, last);
  }
  template <typename Iterator>
  void build(Node* node, Iterator first, Iterator last) {
    int len = last - first;
    if (len == 1) {
      // Leaf: the linear factor x - first[0].
      node->P = {-first[0], 1};
    } else {
      node->left = &deq.emplace_back();
      node->right = &deq.emplace_back();
      Iterator middle = first + len / 2;
      build(node->left, first, middle);
      build(node->right, middle, last);
      node->P = node->left->P * node->right->P;
    }
  }
  std::vector<T> res;  // scratch output of evaluate()
  // Evaluates Q at every point, in the order the leaves are visited.
  std::vector<T> evaluate(const F& Q) {
    res.clear();
    evaluate(&deq[0], Q % deq[0].P);
    return std::move(res);
  }
  // Recursive step: reduce Q modulo each child's polynomial and descend.
  void evaluate(Node* node, F Q) {
    if (node->left) {
      for (auto next : {node->left, node->right}) {
        evaluate(next, Q % next->P);
      }
    } else {
      // Modulo a linear factor the remainder is expected to be a constant.
      assert(Q.size() == 1);
      res.push_back(Q[0]);
    }
  }
  bool flag = false;  // true once the leaf values y have been computed
  // Interpolates through the values in [first, last), one per point.
  template <typename Iterator>
  F interpolate(Iterator first, Iterator last) {
    if (!flag) {
      flag = true;
      // Evaluate the derivative of the root polynomial at every point and
      // store the results on the leaves, in deque order.
      auto y = evaluate(D(deq[0].P));
      auto iter = y.begin();
      for (auto& node : deq) {
        if (node.left) continue;
        node.y = *iter;
        ++iter;
      }
    }
    return interpolate(&deq[0], first, last);
  }
  // Recursive step: combine the two halves, each multiplied by the sibling's
  // product polynomial.
  template <typename Iterator>
  F interpolate(Node* node, Iterator first, Iterator last) {
    int len = last - first;
    if (len == 1) {
      return {first[0] / node->y};
    } else {
      Iterator middle = first + len / 2;
      return node->right->P * interpolate(node->left, first, middle) +
        node->left->P * interpolate(node->right, middle, last);
    }
  }
};

// computes P(D)A
// i.e. applies the differential operator sum_j P[j] * D^j to A. The result
// has A.size() coefficients.
// NOTE(review): f[0] is written unconditionally, so an empty A looks
// unsupported — confirm callers never pass one.
template <typename T>
FormalPowerSeries<T> apply_polynomial_of_derivative(
    FormalPowerSeries<T> P, FormalPowerSeries<T> A) {
  int N = A.size();
  // Derivatives of order >= N annihilate A, so those terms can be dropped.
  if (P.size() > N) {
    P.resize(N);
  }
  // f[k] = k!
  std::vector<T> f(N);
  f[0] = 1;
  for (int k = 0; k + 1 < N; ++k) {
    f[k + 1] = (k + 1) * f[k];
  }
  // Scale A[k] by k! so that differentiation becomes a plain index shift.
  for (int k = 0; k < N; ++k) {
    A[k] *= f[k];
  }
  // Reversing P turns the required correlation into an ordinary product.
  std::reverse(P.begin(), P.end());
  auto res = P * A;
  // Discard the low part of the product that precedes the aligned window.
  res.erase(res.begin(), res.begin() + P.size() - 1);
  // Undo the factorial scaling.
  for (int k = 0; k < N; ++k) {
    res[k] /= f[k];
  }
  return res;
}

// First N coefficients of e^(alpha * x): coefficient k is alpha^k / k!.
template <typename T>
FormalPowerSeries<T> exp(T alpha, int N) {
  const auto& inverse_factorials = combinatorics<T>.rfact;
  FormalPowerSeries<T> series(N);
  T alpha_power = 1;
  int k = 0;
  for (auto& coefficient : series) {
    coefficient = alpha_power * inverse_factorials[k];
    alpha_power *= alpha;
    ++k;
  }
  return series;
}

// Coefficients of the shifted polynomial x -> P(x + c), obtained by applying
// the operator e^(c * D), truncated to P.size() terms, to P.
template <typename T>
FormalPowerSeries<T> taylor_shift(FormalPowerSeries<T> P, T c) {
  auto shift_operator = exp(c, P.size());
  return apply_polynomial_of_derivative(std::move(shift_operator), std::move(P));
}

// Given the samples y[i] = P(i) for i = 0 .. N-1 of the unique polynomial P of
// degree < N, returns the coefficients of P in the falling-factorial basis.
// Implemented as the product of y[k] / k! with e^(-x), truncated to N terms.
template <typename T>
FormalPowerSeries<T> interpolate_to_falling_factorials(FormalPowerSeries<T> y) {
  const int N = y.size();
  const auto& inverse_factorials = combinatorics<T>.rfact;
  int k = 0;
  for (auto& sample : y) {
    sample *= inverse_factorials[k];
    ++k;
  }
  auto coefficients = exp(T(-1), N) * y;
  coefficients.resize(N);
  return coefficients;
}

// p-th index of the result is the sum of k^p for k = 0, ..., N - 1
// The result has P entries (p = 0 .. P - 1).
template <typename T>
FormalPowerSeries<T> sum_of_powers(int N, int P) {
  // A = e^(N x) - 1 and B = (e^x - 1) / x, both truncated.
  auto A = exp<T>(N, P + 1), B = exp<T>(1, P + 2);
  A[0] -= 1;
  B.erase(B.begin());
  auto res = A * inv(std::move(B));
  // Drop the constant term, i.e. divide by x (A has zero constant term).
  res.erase(res.begin());
  res.resize(P);
  // Convert from exponential-generating-function coefficients to values.
  for (int p = 0; p < P; ++p) {
    res[p] *= combinatorics<T>.fact[p];
  }
  return res;
}

// Coefficients of the product x (x - 1) (x - 2) ... (x - (N - 1)), i.e. the
// signed Stirling numbers of the first kind for N.
template <typename T>
FormalPowerSeries<T> stirling_numbers_1(int N) {
  std::vector<FormalPowerSeries<T>> factors(N);
  int root = 0;
  for (auto& factor : factors) {
    factor = {-root, 1};
    ++root;
  }
  return product(factors.data(), N);
}

// Stirling numbers of the second kind S(N, d) for d = 0 .. N: the monomial
// x^N is sampled at 0 .. N and converted to the falling-factorial basis.
template <typename T>
FormalPowerSeries<T> stirling_numbers_2(int N) {
  FormalPowerSeries<T> samples(N + 1);
  int base = 0;
  for (auto& sample : samples) {
    sample = pow(T(base), N);
    ++base;
  }
  return interpolate_to_falling_factorials(samples);
}

using Zp = Z<998244353>;
using namespace std;
#define int long long

void Solve(){
  int n, m, k; cin >> n >> m >> k;
  auto S = stirling_numbers_2<Zp>(n);
  
  vector <Zp> b(n + 1);
  vector <Zp> p(n + 1);
  for (int i = 1; i <= n; i++){
    p[i] = pow(Zp(i), m);
    p[i] += p[i - 1];
  }
  for (int i = 1; i <= n; i++){
    Zp v = p[i] * pow(Zp(i), 998244351);
    b[i] = v * n;
  }
  
  Zp ans = 0;
  Zp curr = 1;
  Zp mul = k;
  for (int i = 1; i <= n; i++){
    Zp ways = S[i];
    curr *= mul;
    mul -= 1;
    
    ways *= curr;
    ans += ways * b[i];
  }
  
  cout << ans << "\n";
}

// Entry point: reads the number of test cases and solves each in turn.
int32_t main() {
  std::ios_base::sync_with_stdio(false);
  std::cin.tie(nullptr);

  int t;
  cin >> t;
  for (; t != 0; --t){
    Solve();
  }
}

Tester's code (C++)
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

using namespace std;
using namespace __gnu_pbds;

template<typename T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
typedef long long int ll;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;

#define fastio ios_base::sync_with_stdio(false); cin.tie(NULL)
#define pb push_back
#define endl '\n'
#define sz(a) (int)a.size()
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
#define conts continue
#define ceil2(x,y) ((x+y-1)/(y))
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
#define yes cout << "YES" << endl
#define no cout << "NO" << endl

#define rep(i,n) for(int i = 0; i < n; ++i)
#define rep1(i,n) for(int i = 1; i <= n; ++i)
#define rev(i,s,e) for(int i = s; i >= e; --i)
#define trav(i,a) for(auto &i : a)

// Lowers a to b when b is strictly smaller.
template<typename T>
void amin(T &a, T b) {
    if (b < a) a = b;
}

// Raises a to b when b is strictly larger.
template<typename T>
void amax(T &a, T b) {
    if (a < b) a = b;
}

#ifdef LOCAL
#include "debug.h"
#else
#define debug(...) 42
#endif

/*



*/

const int MOD = 998244353;
const int N = 2e5 + 5;
const int inf1 = int(1e9) + 5;
const ll inf2 = ll(1e18) + 5;

// Montgomery-form modular integers and an in-place NTT / convolution.
namespace ntt {
    // https://judge.yosupo.jp/submission/69896
    // Generic binary exponentiation: combines n copies of a with op, starting
    // from the identity e. n is expected to be non-negative.
    template <class T, class F = multiplies<T>>
    T power(T a, long long n, F op = multiplies<T>(), T e = {1}) {
        // assert(n >= 0);
        T res = e;
        while (n) {
            if (n & 1) res = op(res, a);
            if (n >>= 1) a = op(a, a);
        }
        return res;
    }

    constexpr int mod = int(1e9) + 7;
    constexpr int nttmod = 998'244'353;

    // Integer modulo the odd constant P < 2^30, stored in Montgomery form
    // (the internal value v is kept in [0, 2P)).
    template <std::uint32_t P>
    struct ModInt32 {
       public:
        using i32 = std::int32_t;
        using u32 = std::uint32_t;
        using i64 = std::int64_t;
        using u64 = std::uint64_t;
        using m32 = ModInt32;
        using internal_value_type = u32;

       private:
        u32 v;
        // Newton iteration for the inverse of P modulo 2^32, negated.
        static constexpr u32 get_r() {
            u32 iv = P;
            for (u32 i = 0; i != 4; ++i) iv *= 2U - P * iv;
            return -iv;
        }
        // r = -P^(-1) mod 2^32, r2 = 2^64 mod P.
        static constexpr u32 r = get_r(), r2 = -u64(P) % P;
        static_assert((P & 1) == 1);
        static_assert(-r * P == 1);
        static_assert(P < (1 << 30));
        static constexpr u32 pow_mod(u32 x, u64 y) {
            u32 res = 1;
            for (; y != 0; y >>= 1, x = u64(x) * x % P)
                if (y & 1) res = u64(res) * x % P;
            return res;
        }
        // Montgomery reduction of x.
        static constexpr u32 reduce(u64 x) {
            return (x + u64(u32(x) * r) * P) >> 32;
        }
        // Branch-free "subtract P if x >= P".
        static constexpr u32 norm(u32 x) { return x - (P & -(x >= P)); }

       public:
        // Smallest primitive root modulo P, found by trial (0 if none found).
        static constexpr u32 get_pr() {
            u32 tmp[32] = {}, cnt = 0;
            const u64 phi = P - 1;
            u64 m = phi;
            for (u64 i = 2; i * i <= m; ++i)
                if (m % i == 0) {
                    tmp[cnt++] = i;
                    while (m % i == 0) m /= i;
                }
            if (m != 1) tmp[cnt++] = m;
            for (u64 res = 2; res != P; ++res) {
                bool flag = true;
                for (u32 i = 0; i != cnt && flag; ++i)
                    flag &= pow_mod(res, phi / tmp[i]) != 1;
                if (flag) return res;
            }
            return 0;
        }
        constexpr ModInt32() : v(0){};
        ~ModInt32() = default;
        constexpr ModInt32(u32 _v) : v(reduce(u64(_v) * r2)) {}
        // The remainder is taken in 64-bit signed arithmetic: `_v % P` on an
        // i32 and a u32 would convert a negative _v to unsigned first and
        // produce the wrong residue (e.g. -1 would not map to P - 1).
        constexpr ModInt32(i32 _v)
            : v(reduce(u64(i64(_v) % i64(P) + i64(P)) * r2)) {}
        constexpr ModInt32(u64 _v) : v(reduce((_v % P) * r2)) {}
        constexpr ModInt32(i64 _v) : v(reduce(u64(_v % P + P) * r2)) {}
        constexpr ModInt32(const m32& rhs) : v(rhs.v) {}
        // Canonical representative in [0, P).
        constexpr u32 get() const { return norm(reduce(v)); }
        explicit constexpr operator u32() const { return get(); }
        explicit constexpr operator i32() const { return i32(get()); }
        constexpr m32& operator=(const m32& rhs) { return v = rhs.v, *this; }
        constexpr m32 operator-() const {
            m32 res;
            return res.v = (P << 1 & -(v != 0)) - v, res;
        }
        constexpr m32 inv() const { return pow(P - 2); }
        constexpr m32& operator+=(const m32& rhs) {
            return v += rhs.v - (P << 1), v += P << 1 & -(v >> 31), *this;
        }
        constexpr m32& operator-=(const m32& rhs) {
            return v -= rhs.v, v += P << 1 & -(v >> 31), *this;
        }
        constexpr m32& operator*=(const m32& rhs) {
            return v = reduce(u64(v) * rhs.v), *this;
        }
        constexpr m32& operator/=(const m32& rhs) {
            return this->operator*=(rhs.inv());
        }
        friend m32 operator+(const m32& lhs, const m32& rhs) {
            return m32(lhs) += rhs;
        }
        friend m32 operator-(const m32& lhs, const m32& rhs) {
            return m32(lhs) -= rhs;
        }
        friend m32 operator*(const m32& lhs, const m32& rhs) {
            return m32(lhs) *= rhs;
        }
        friend m32 operator/(const m32& lhs, const m32& rhs) {
            return m32(lhs) /= rhs;
        }
        friend bool operator==(const m32& lhs, const m32& rhs) {
            return norm(lhs.v) == norm(rhs.v);
        }
        friend bool operator!=(const m32& lhs, const m32& rhs) {
            return norm(lhs.v) != norm(rhs.v);
        }
        // Reads an unsigned 32-bit value and converts it to Montgomery form.
        friend std::istream& operator>>(std::istream& is, m32& rhs) {
            return is >> rhs.v, rhs.v = reduce(u64(rhs.v) * r2), is;
        }
        friend std::ostream& operator<<(std::ostream& os, const m32& rhs) {
            return os << rhs.get();
        }
        constexpr m32 pow(i64 y) const {
            // assumes P is a prime
            // A positive multiple of P - 1 is mapped to P - 1 (not 0) so that
            // zero raised to it stays zero.
            i64 rem = y % (P - 1);
            if (y > 0 && rem == 0)
                y = P - 1;
            else
                y = rem;
            m32 res(1), x(*this);
            for (; y != 0; y >>= 1, x *= x)
                if (y & 1) res *= x;
            return res;
        }
    };

    using mint = ModInt32<nttmod>;

    // In-place transform of a, whose size must be a power of two (asserted).
    // The forward pass leaves the output in the order the inverse pass
    // expects, so no explicit bit-reversal step is performed.
    void ntt(vector<mint>& a, bool inverse) {
        // Twiddle step tables, filled on the first call.
        static array<mint, 30> dw{}, idw{};
        if (dw[0] == 0) {
            // Find a quadratic non-residue to derive the roots from.
            mint root = 2;
            while (power(root, (nttmod - 1) / 2) == 1) root += 1;
            for (int i = 0; i < 30; ++i)
                dw[i] = -power(root, (nttmod - 1) >> (i + 2)),
                idw[i] = 1 / dw[i];
        }
        int n = (int)a.size();
        assert((n & (n - 1)) == 0);
        if (not inverse) {
            for (int m = n; m >>= 1;) {
                mint w = 1;
                for (int s = 0, k = 0; s < n; s += 2 * m) {
                    for (int i = s, j = s + m; i < s + m; ++i, ++j) {
                        auto x = a[i], y = a[j] * w;
                        a[i] = x + y;
                        a[j] = x - y;
                    }
                    w *= dw[__builtin_ctz(++k)];
                }
            }
        } else {
            for (int m = 1; m < n; m *= 2) {
                mint w = 1;
                for (int s = 0, k = 0; s < n; s += 2 * m) {
                    for (int i = s, j = s + m; i < s + m; ++i, ++j) {
                        auto x = a[i], y = a[j];
                        a[i] = x + y;
                        a[j] = (x - y) * w;
                    }
                    w *= idw[__builtin_ctz(++k)];
                }
            }
            // Final scaling by 1/n.
            auto inv = 1 / mint(n);
            for (auto&& e : a) e *= inv;
        }
    }
    // Polynomial product; an empty operand gives an empty result. Short
    // operands (fewer than 30 coefficients) use the schoolbook product.
    vector<mint> operator*(vector<mint> l, vector<mint> r) {
        if (l.empty() or r.empty()) return {};
        // sz = smallest power of two >= n + m - 1.
        int n = (int)l.size(), m = (int)r.size(),
            sz = 1 << __lg(2 * (n + m - 1) - 1);
        if (min(n, m) < 30) {
            vector<mint> res(n + m - 1);
            for (int i = 0; i < n; ++i)
                for (int j = 0; j < m; ++j) res[i + j] += (l[i] * r[j]);
            return {begin(res), end(res)};
        }
        // Squaring needs only one forward transform.
        bool eq = l == r;
        l.resize(sz), ntt(l, false);
        if (eq)
            r = l;
        else
            r.resize(sz), ntt(r, false);
        for (int i = 0; i < sz; ++i) l[i] *= r[i];
        ntt(l, true), l.resize(n + m - 1);
        return l;
    }
    // In-place variant of the product; the equality shortcut is only taken
    // when the template flag `check` is set.
    template <bool check = false>
    vector<mint>& operator*=(vector<mint>& l, vector<mint> r) {
        if (l.empty() or r.empty()) {
            l.clear();
            return l;
        }
        int n = (int)l.size(), m = (int)r.size(),
            sz = 1 << __lg(2 * (n + m - 1) - 1);
        if (min(n, m) < 30) {
            vector<mint> res(n + m - 1);
            for (int i = 0; i < n; ++i)
                for (int j = 0; j < m; ++j) res[i + j] += (l[i] * r[j]);
            l = {begin(res), end(res)};
            return l;
        }
        bool eq = false;
        if constexpr (check) eq = l == r;
        l.resize(sz), ntt(l, false);
        if (eq)
            r = l;
        else
            r.resize(sz), ntt(r, false);
        for (int i = 0; i < sz; ++i) l[i] *= r[i];
        ntt(l, true), l.resize(n + m - 1);
        return l;
    }

}  // namespace ntt

typedef ntt::mint mint;

ll fact[N], ifact[N];

// Binary exponentiation: a^b modulo MOD for b >= 0.
// a is expected to be reduced already, so that a * a fits in 64 bits.
ll bexp(ll a, ll b) {
    ll result = 1;
    for (; b; b >>= 1) {
        if (b & 1) result = result * a % MOD;
        a = a * a % MOD;
    }
    return result;
}

// Modular inverse of a via Fermat's little theorem (a^(MOD - 2)).
ll invmod(ll a) {
    const ll exponent = MOD - 2;
    return bexp(a, exponent);
}

// Binomial coefficient C(n, r) modulo MOD from the factorial tables;
// 0 for arguments outside 0 <= r <= n.
ll ncr(ll n, ll r) {
    const bool in_range = n >= 0 && r >= 0 && r <= n;
    if (!in_range) return 0;
    const ll partial = fact[n] * ifact[r] % MOD;
    return partial * ifact[n - r] % MOD;
}

// Number of ordered selections n! / (n - r)! modulo MOD;
// 0 for arguments outside 0 <= r <= n.
ll npr(ll n, ll r) {
    const bool in_range = n >= 0 && r >= 0 && r <= n;
    if (!in_range) return 0;
    const ll numerator = fact[n];
    return numerator * ifact[n - r] % MOD;
}

// Fills fact[0..n] and ifact[0..n] (factorials and their inverses mod MOD).
void precalc(ll n) {
    fact[0] = 1;
    for (int i = 1; i <= n; ++i) {
        fact[i] = fact[i - 1] * i % MOD;
    }

    // One modular inverse at the top, then walk down: 1/i! = (i + 1) / (i + 1)!.
    ifact[n] = invmod(fact[n]);
    for (int i = n - 1; i >= 0; --i) {
        ifact[i] = ifact[i + 1] * (i + 1) % MOD;
    }
}

// Solves one test case: reads n, m, k and prints the answer modulo MOD.
// The test_case index is not used.
void solve(int test_case){
    ll n,m,k; cin >> n >> m >> k;

    // Quadratic reference computation of dp, kept for comparison with the
    // convolution below.
    /*

    vector<ll> dp(n+5);
    rep(i,n){
        rep(j,i+1){
            ll c = 1;
            if(j&1) c = -1;
            dp[i] += ncr(i,j)*bexp(i-j,n-1)*c;
            dp[i] = (dp[i]%MOD+MOD)%MOD;
        }
    }

    */

    // p1[i] = (-1)^i / i!  and  p2[j] = j^(n-1) / j!, both reduced to
    // [0, MOD) before the conversion to mint.
    vector<mint> p1(n+5), p2(n+5);
    rep(i,n+1){
        ll c = 1;
        if(i&1) c = -1;
        ll val = (ifact[i]*c%MOD+MOD)%MOD;
        p1[i] = mint((int)val);
    }
    rep(j,n+1){
        ll val = ifact[j]*bexp(j,n-1)%MOD;
        p2[j] = mint((int)val);
    }

    // Convolution; the vector operator* of namespace ntt is found by ADL.
    auto p3 = p1*p2;
    
    // dp[i] = i! * p3[i] = sum_j (-1)^j * C(i, j) * (i - j)^(n-1), the same
    // inclusion-exclusion sum as in the reference loop above.
    vector<ll> dp(n+5);
    rep(i,n+1){
        ll val = (int)p3[i]*fact[i]%MOD;
        dp[i] = val;
    }

    // choose[c] = k (k-1) ... (k-c+1) / c!  for c = 1 .. n, built from a
    // running product because k is not used as an index into fact[].
    vector<ll> choose(n+5);
    choose[0] = 1;
    ll prod = 1;
    ll cr = 0;
    rev(j,k,k-n+1){
        cr++;
        prod = prod*j%MOD;
        choose[cr] = prod*ifact[cr]%MOD;
    }

    ll ans = 0;
    ll suff_sum = 0;

    // Walk x downwards, keeping a suffix sum of dp[y] * (choose[y] + choose[y+1])
    // over y >= x; choose[n+1] is zero from the vector's initialisation.
    // NOTE(review): presumably `ways` counts the arrays whose first element
    // compresses to x (before the final factor n) — confirm the derivation.
    rev(x,n,1){
        suff_sum += dp[x]*(choose[x]+choose[x+1]);
        suff_sum %= MOD;
        ll ways = dp[x-1]*choose[x]+suff_sum;
        ways %= MOD;
        ans += ways*bexp(x,m);
        ans %= MOD;
    }

    ans = ans*n%MOD;
    cout << ans << endl;
}

int main()
{
    precalc(N-1);

    fastio;

    int t = 1;
    cin >> t;

    rep1(i, t) {
        solve(i);
    }

    return 0;
}
Editorialist's code (PyPy3)
# NTT implementation based on https://codeforces.com/blog/entry/117947

# NTT prime: 119 * 2^23 + 1 = 998244353
MOD = (119 << 23) + 1
assert MOD & 1

# Smallest quadratic non-residue modulo MOD, found with Euler's criterion
# (MOD // 2 equals (MOD - 1) / 2 because MOD is odd).
non_quad_res = 2
while pow(non_quad_res, MOD//2, MOD) != MOD - 1:
    non_quad_res += 1
# Shared twiddle table; ntt() extends it on demand.
rt = [1]

def ntt(P):
    """Return the number-theoretic transform of P as a new list.

    len(P) must be a non-zero power of two (asserted). The input is copied,
    so the caller's sequence is not modified. The module-level table `rt` is
    extended as a side effect when it is too short for this length.
    """
    n = len(P)
    P = list(P)
    assert n and (n - 1) & n == 0
    
    # Grow the twiddle table by doubling until it holds at least n/2 entries.
    while 2 * len(rt) < n:
        # 4*len(rt)-th root of unity
        root = pow(non_quad_res, MOD // (4*len(rt)), MOD)
        rt.extend([r * root % MOD for r in rt])

    # Butterfly passes over blocks of size k = n, n/2, ..., 2.
    k = n
    while k > 1:
        for i in range(n//k):
            # One twiddle factor per block.
            r = rt[i]
            for j1 in range(i*k, i*k + k//2):
                j2 = j1 + k//2
                z = r * P[j2]
                P[j2] = (P[j1] - z) % MOD
                P[j1] = (P[j1] + z) % MOD
        k //= 2
    
    # Bit-reversal permutation, built incrementally from rev[i // 2].
    rev = [0] * n
    for i in range(1, n):
        rev[i] = rev[i // 2] // 2 + (i & 1) * n // 2
    return [P[r] for r in rev]

def intt(P):
    """Inverse transform: a forward ntt() of the index-reflected, 1/n-scaled input."""
    size = len(P)
    scale = pow(size, MOD - 2, MOD)
    # P[-0] is P[0]; the remaining entries are taken in reverse order.
    reflected = [P[-idx] * scale % MOD for idx in range(size)]
    return ntt(reflected)

def ntt_conv(P, Q):
    """Return the len(P) + len(Q) - 1 coefficients of the product of lists P and Q modulo MOD."""
    out_len = len(P) + len(Q) - 1
    size = 1 << out_len.bit_length()

    # Zero-pad both operands to the transform length.
    padded_p = P + [0] * (size - len(P))
    padded_q = Q + [0] * (size - len(Q))
    fp = ntt(padded_p)
    fq = ntt(padded_q)

    pointwise = [a * b % MOD for a, b in zip(fp, fq)]
    return intt(pointwise)[:out_len]

# fac[i] = i! modulo MOD for 0 <= i < mxN
mxN = 200005
fac = [1] * mxN
for n in range(1, mxN):
    fac[n] = fac[n - 1] * n % MOD

# One iteration per test case: read n, m, k and print the answer modulo MOD.
for _ in range(int(input())):
    n, m, k = map(int, input().split())

    # pref[x] = 1^m + 2^m + ... + x^m (pref[0] = 0^m is included as computed)
    pref = [pow(i, m, MOD) for i in range(n+1)]
    for i in range(1, n+1): pref[i] = (pref[i] + pref[i-1]) % MOD

    # p[i] = i^n / i!  and  q[i] = (-1)^i / i!; their convolution S holds the
    # Stirling numbers of the second kind S2(n, x).
    p = [ pow(i, n, MOD) * pow(fac[i], MOD-2, MOD) % MOD for i in range(n+1) ]
    q = [ pow(fac[i], MOD-2, MOD) * (-1)**(i%2) % MOD for i in range(n+1) ]
    S = ntt_conv(p, q)
    # choices is C(k, x) after the update inside the loop
    choices = 1
    ans = 0

    for x in range(1, min(n, k) + 1):
        # exactly x distinct elements
        # C(k, x) ways to choose the elements
        # S2(n, x) * x! orderings
        
        # C(k, x) = C(k, x-1) * (k - x + 1) / x
        choices = choices * (k+1-x) * pow(x, MOD-2, MOD) % MOD
        arrays = choices * S[x] % MOD * fac[x] % MOD
        # print(x, choices, arrays)

        # arrays/x of a[1] being 1, 2, 3, ..., x
        # the factor n then accounts for all n positions
        ans = (ans + arrays * pow(x, MOD-2, MOD) * n * pref[x]) % MOD
    print(ans)
3 Likes

CodeChef’s problem editorials have always been the best I have seen in all programming competitions. Thank you all.

1 Like