Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: thescrasse
Preparation: raysh07
Tester: sushil2006
Editorialist: iceknight1093
Combinatorics, dynamic programming or inclusion-exclusion
The score of an array A is defined as follows:
- First, coordinate compress the array A to obtain the array B.
- The score of A is then \sum_{i=1}^N B_i^M.
Given N, M, and K, compute the sum of scores of all arrays of length N with elements between 1 and K.
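To make the definition concrete, here is a small sketch of the score function together with a brute-force total. The names `score` and `brute_total` are illustrative and do not come from the problem statement; the brute force enumerates all K^N arrays, so it is only usable for tiny inputs.

```python
from itertools import product

def score(a, m):
    """Coordinate-compress a to ranks 1..d, then sum rank**m over all positions."""
    rank = {v: r for r, v in enumerate(sorted(set(a)), start=1)}
    return sum(rank[v] ** m for v in a)

def brute_total(n, m, k):
    """Sum of scores over all k**n arrays (exponential, for checking only)."""
    return sum(score(a, m) for a in product(range(1, k + 1), repeat=n))

print(score([10, 3, 10, 7], 2))  # compresses to [3, 1, 3, 2], so 9 + 1 + 9 + 4 = 23
print(brute_total(2, 1, 2))      # arrays [1,1], [1,2], [2,1], [2,2] give 2 + 3 + 3 + 2 = 10
```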
Consider an array A with d distinct elements. Its elements will be compressed to [1, d].
Let’s try to fix d, and compute the sum of scores across all arrays with d distinct elements.
First, there are \binom{K}{d} ways to choose which d elements the array will contain; then we need to arrange them.
Let f(N, d) denote the number of arrays of length N over a fixed set of d values in which each of the d values appears at least once.
How to compute this?
One way is to use the inclusion-exclusion principle.
There are d choices for every index, leading to d^N arrays initially.
However, they’re not all valid: some of them might have some elements not appear at all, since we didn’t constrain that in any way.
For a fixed element, there are (d-1)^N arrays with it missing (and maybe missing other elements too).
There are d ways of choosing the missing element, so we subtract out d\cdot (d-1)^N from the total.
However, arrays with two missing elements have been subtracted out twice, so we’d need to add them back in.
There are \binom{d}{2} ways to fix which two elements are missing, and then (d-2)^N arrays with them missing, which we add back in.
But then arrays with three missing elements are counted once again, so we need to subtract them out; and so on and so forth.
This is a classical case of inclusion-exclusion, and we end up with
f(N, d) = \sum_{i=0}^d (-1)^i \binom{d}{i} (d-i)^N
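As a sanity check, the inclusion-exclusion sum can be compared against direct enumeration for small parameters. This is only a sketch: the function names `f` and `f_brute` are chosen here for illustration, and exact integers are used instead of modular arithmetic.

```python
from itertools import product
from math import comb

def f(n, d):
    """Arrays of length n over d fixed values that use every value (inclusion-exclusion)."""
    return sum((-1) ** i * comb(d, i) * (d - i) ** n for i in range(d + 1))

def f_brute(n, d):
    """Same count by enumerating all d**n arrays."""
    return sum(1 for a in product(range(d), repeat=n) if len(set(a)) == d)

print(f(4, 3), f_brute(4, 3))  # 81 - 3*16 + 3*1 - 0 = 36 on both sides
```

Note that f(N, d) = 0 whenever d > N, which the formula reproduces without special-casing.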
Since we want a sum of sums, we can look at contributions of elements at each index separately.
So, let’s look at just the first index, which when compressed takes values between 1 and d.
Let x_i denote the number of arrays in which B_1 = i.
If we can compute all the x_i values, then the contribution of this index is simply
\sum_{i=1}^d x_i \cdot i^M
This computation can then be repeated for each index from 1 to N, and the answers can all be added up.
Here’s the neat part: it turns out that x_1 = x_2 = \ldots = x_d, meaning each of them equals \frac{1}{d} of the total number of such arrays, which is \binom{K}{d} \cdot f(N, d).
Consider the mapping y \to (y\bmod d) + 1 which cyclically shifts all values modulo d.
It’s easy to see that this is a bijection on the set of arrays we have: after all, the inverse operation is to simply shift all values in the other direction.
This bijection maps all arrays with B_1 = y to arrays with B_1 = (y\bmod d) + 1, and only these arrays, meaning their counts must be equal.
This means x_1 = x_2, x_2 = x_3, \ldots, x_{d-1} = x_d, x_d = x_1.
Putting all the equalities together gives us the original claim of all x_i being equal.
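The symmetry claim is easy to check empirically for small cases. The sketch below uses hypothetical helper names: it counts, among arrays over {1, ..., d} that use every value (so the array equals its own compression), how many start with each value, and it also spells out the cyclic shift used in the argument.

```python
from itertools import product

def first_value_counts(n, d):
    """x_1..x_d: how many arrays using all of 1..d have each value at index 1."""
    counts = {v: 0 for v in range(1, d + 1)}
    for a in product(range(1, d + 1), repeat=n):
        if len(set(a)) == d:  # all d values present, so B equals a
            counts[a[0]] += 1
    return counts

def shift(a, d):
    """The map y -> (y mod d) + 1 applied to every element."""
    return tuple(v % d + 1 for v in a)

print(first_value_counts(4, 3))  # {1: 12, 2: 12, 3: 12}, i.e. 36 / 3 each
print(shift((1, 3, 2, 3), 3))    # (2, 1, 3, 1)
```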
So, for index 1, the overall contribution is
\binom{K}{d} \cdot \frac{f(N, d)}{d} \cdot (1^M + 2^M + \ldots + d^M)
Further, note that this computation didn’t depend on the fact that we were dealing with index 1 at all - if we fix any index, the result will be the same.
So, the sum of scores across all arrays with exactly d distinct elements, is simply the above value multiplied by N.
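Putting the pieces together for a fixed d gives N \cdot \binom{K}{d} \cdot \frac{f(N, d)}{d} \cdot (1^M + \ldots + d^M). The following sketch compares that closed form with brute force on tiny inputs; the function names are made up for this check, and exact integers are used rather than arithmetic modulo 998244353.

```python
from itertools import product
from math import comb

def surjections(n, d):
    """f(n, d) from the inclusion-exclusion formula."""
    return sum((-1) ** i * comb(d, i) * (d - i) ** n for i in range(d + 1))

def exact_d_formula(n, m, k, d):
    # Arrays with a given compressed value at one index; the division is exact
    # because surjections(n, d) is a multiple of d! for d >= 1.
    per_value = comb(k, d) * surjections(n, d) // d
    return n * per_value * sum(i ** m for i in range(1, d + 1))

def exact_d_brute(n, m, k, d):
    """Sum of scores over arrays from [1, k] with exactly d distinct values."""
    total = 0
    for a in product(range(1, k + 1), repeat=n):
        values = sorted(set(a))
        if len(values) == d:
            rank = {v: r for r, v in enumerate(values, start=1)}
            total += sum(rank[v] ** m for v in a)
    return total

print(exact_d_formula(2, 1, 2, 1), exact_d_formula(2, 1, 2, 2))  # 4 and 6
```

Summing the two printed values gives 10, the total over all arrays for N = 2, M = 1, K = 2.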
Now, note that d can be anything between 1 and \min(N, K).
For a fixed d, we want to know:
- The sum (1^M + 2^M + \ldots + d^M), which is just a prefix sum.
- f(N, d), which can be computed in \mathcal{O}(d) time with the inclusion-exclusion summation we derived for it.
- \binom{K}{d}, which can be computed in \mathcal{O}(d) time even though K is very large, using the formula
\binom{K}{d} = \frac{K \cdot (K-1) \cdots (K-d+1)}{d!}
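A minimal sketch of the last item, assuming the modulus is the prime 998244353 (as in the solutions below) and that d is smaller than it: multiply the d factors K, K-1, ..., K-d+1 and divide by d! through a modular inverse. The name `binom_large_k` is hypothetical.

```python
MOD = 998244353

def binom_large_k(k, d, mod=MOD):
    """C(k, d) modulo a prime, for possibly huge k and small d (d < mod)."""
    num, den = 1, 1
    for i in range(d):
        num = num * ((k - i) % mod) % mod  # falling factorial k (k-1) ... (k-d+1)
        den = den * (i + 1) % mod          # d!
    # Fermat's little theorem gives the inverse of d! modulo a prime
    return num * pow(den, mod - 2, mod) % mod

print(binom_large_k(10, 3))  # 120
```

When all values of d from 1 to \min(N, K) are processed in order, the same quantity can also be updated incrementally, multiplying the previous value by (K - d + 1) and by the inverse of d.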
So, processing a single d takes \mathcal{O}(d) time; trying all d \leq N and adding up the answers gives us a quadratic solution.
\mathcal{O}(N^2) per testcase.
Preparer's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());
const int facN = 1e6 + 5;
const int mod = 998244353;
int ff[facN], iff[facN];
bool facinit = false;
int power(int x, int y){
    if (y == 0) return 1;

    int v = power(x, y / 2);
    v = 1LL * v * v % mod;

    if (y & 1) return 1LL * v * x % mod;
    else return v;
}

void factorialinit(){
    facinit = true;
    ff[0] = iff[0] = 1;

    for (int i = 1; i < facN; i++){
        ff[i] = 1LL * ff[i - 1] * i % mod;
    }

    iff[facN - 1] = power(ff[facN - 1], mod - 2);
    for (int i = facN - 2; i >= 1; i--){
        iff[i] = 1LL * iff[i + 1] * (i + 1) % mod;
    }
}

int C(int n, int r){
    if (!facinit) factorialinit();

    if (n == r) return 1;

    if (r < 0 || r > n) return 0;
    return 1LL * ff[n] * iff[r] % mod * iff[n - r] % mod;
}

int P(int n, int r){
    if (!facinit) factorialinit();

    assert(0 <= r && r <= n);
    return 1LL * ff[n] * iff[n - r] % mod;
}

int Solutions(int n, int r){
    //solutions to x1 + ... + xn = r
    //xi >= 0

    return C(n + r - 1, n - 1);
}
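void Solve()
{
    int n, m, k; cin >> n >> m >> k;

    // dp[j] = number of arrays of the current length with exactly j distinct values from [1, k]
    vector <int> dp(n + 1, 0);
    dp[0] = 1;

    for (int i = 1; i <= n; i++){
        vector <int> ndp(n + 1, 0);
        for (int j = 0; j < i; j++){
            ndp[j] += dp[j] * j;             // reuse one of the j values already present
            ndp[j + 1] += dp[j] * (k - j);   // introduce one of the k - j unused values
        }
        dp = ndp;
        for (auto &x : dp) x %= mod;
    }

    int ans = 0;
    // p[i] = 1^m + 2^m + ... + i^m
    vector <int> p(n + 1);
    for (int i = 1; i <= n; i++){
        p[i] = power(i, m);
        p[i] += p[i - 1];
        p[i] %= mod;
    }

    for (int i = 1; i <= n; i++){
        // each compressed value appears at a fixed index in dp[i] / i arrays; there are n indices
        p[i] *= power(i, mod - 2);
        p[i] %= mod;
        p[i] *= n;
        p[i] %= mod;

        ans += dp[i] * p[i];
        ans %= mod;
    }

    cout << ans << "\n";
}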
int32_t main()
{
    auto begin = std::chrono::high_resolution_clock::now();
    int t = 1;
    // freopen("in", "r", stdin);
    // freopen("out", "w", stdout);

    cin >> t;
    for(int i = 1; i <= t; i++)
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }

    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n";
    return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;
template<typename T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
typedef long long int ll;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
#define fastio ios_base::sync_with_stdio(false); cin.tie(NULL)
#define pb push_back
#define endl '\n'
#define sz(a) (int)a.size()
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
#define conts continue
#define ceil2(x,y) ((x+y-1)/(y))
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
#define yes cout << "YES" << endl
#define no cout << "NO" << endl
#define rep(i,n) for(int i = 0; i < n; ++i)
#define rep1(i,n) for(int i = 1; i <= n; ++i)
#define rev(i,s,e) for(int i = s; i >= e; --i)
#define trav(i,a) for(auto &i : a)
template<typename T>
void amin(T &a, T b) {
    a = min(a,b);
}

template<typename T>
void amax(T &a, T b) {
    a = max(a,b);
}

#ifdef LOCAL
#include "debug.h"
#else
#define debug(...) 42
#endif

const int MOD = 998244353;
const int N = 5e3 + 5;
const int inf1 = int(1e9) + 5;
const ll inf2 = ll(1e18) + 5;

ll fact[N], ifact[N];

ll bexp(ll a, ll b) {
    a %= MOD;
    if (a == 0) return 0;

    ll res = 1;

    while (b) {
        if (b & 1) res = res * a % MOD;
        a = a * a % MOD;
        b >>= 1;
    }

    return res;
}

ll invmod(ll a) {
    return bexp(a, MOD - 2);
}

ll ncr(ll n, ll r) {
    if (n < 0 or r < 0 or n < r) return 0;
    return fact[n] * ifact[r] % MOD * ifact[n - r] % MOD;
}

ll npr(ll n, ll r) {
    if (n < 0 or r < 0 or n < r) return 0;
    return fact[n] * ifact[n - r] % MOD;
}

void precalc(ll n) {
    fact[0] = 1;
    rep1(i, n) fact[i] = fact[i - 1] * i % MOD;

    ifact[n] = invmod(fact[n]);
    rev(i, n - 1, 0) ifact[i] = ifact[i + 1] * (i + 1) % MOD;
}
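void solve(int test_case){
    ll n,m,k; cin >> n >> m >> k;
    ll dp[n+5][n+5];
    memset(dp,0,sizeof dp);
    dp[0][0] = 1;
    // The three outer loop headers below are missing from this listing;
    // their bounds are inferred from the indices used in the loop bodies.
    // dp[i][j] = number of arrays of length i over j fixed values using every value
    rep(i,n){
        rep(j,i+1){
            dp[i+1][j] += dp[i][j]*j;
            dp[i+1][j+1] += dp[i][j]*(j+1);
            dp[i+1][j] %= MOD;
            dp[i+1][j+1] %= MOD;
        }
    }

    vector<ll> choose(n+5);
    // choose[i] = ncr(k,i)
    rep(i,n+2){
        ll res = 1;
        for(int j = k-i+1; j <= k; ++j){
            res = res*j%MOD;
        }
        res = res*ifact[i]%MOD;
        choose[i] = res;
    }

    ll ans = 0;
    // x = compressed value at one fixed index; y = number of distinct values among the other n-1
    rep1(x,n){
        ll ways = 0;
        ways += dp[n-1][x-1]*choose[x];
        ways %= MOD;
        for(int y = x; y <= n; ++y){
            ways += dp[n-1][y]*(choose[y]+choose[y+1]);
            ways %= MOD;
        }
        ans += ways*bexp(x,m);
        ans %= MOD;
    }

    ans = ans*n%MOD;
    cout << ans << endl;
}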
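int main()
{
    // The calls to precalc and solve are missing from this listing; they are
    // restored here on the assumption that factorials are precomputed once.
    precalc(N - 1);

    int t = 1;
    cin >> t;

    rep1(i, t) {
        solve(i);
    }

    return 0;
}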
Editorialist's code (PyPy3)
mod = 998244353
mxN = 5005
fac = [1]
for n in range(1, mxN):
    fac.append(fac[-1] * n % mod)
invf = fac[:]
for i in range(mxN): invf[i] = pow(invf[i], mod-2, mod)

for _ in range(int(input())):
    n, m, k = map(int, input().split())
    pref = [pow(i, m, mod) for i in range(n+1)]
    for i in range(1, n+1): pref[i] = (pref[i] + pref[i-1]) % mod
    pw = [pow(i, n, mod) for i in range(n+1)]

    choices, ans = 1, 0
    for x in range(1, min(n, k) + 1):
        # exactly x distinct elements
        # C(k, x) ways to choose the elements
        # inc-exc for the arrays
        choices = choices * (k+1-x) * pow(x, mod-2, mod) % mod
        arrays = 0
        for i in range(x+1):
            arrays += (-1)**(i%2) * pw[x-i] % mod * fac[x] % mod * invf[i] % mod * invf[x-i] % mod
        arrays = choices * arrays % mod

        # arrays/x of a[1] being 1, 2, 3, ..., x
        # applies to every index, so multiply by n
        ans = (ans + arrays * pow(x, mod-2, mod) * n * pref[x]) % mod
    print(ans)