INDEXCOMP - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: thescrasse
Preparation: raysh07
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Combinatorics, dynamic programming or inclusion-exclusion

PROBLEM:

The score of an array A is defined as follows:

  1. First, coordinate compress the array A to obtain the array B.
  2. The score of A is then \sum_{i=1}^N B_i^M.

Given N, M, and K, compute the sum of scores of all arrays of length N with elements between 1 and K.

EXPLANATION:

Consider an array A with d distinct elements. Its elements will be compressed to [1, d].

Let’s try to fix d, and compute the sum of scores across all arrays with d distinct elements.
First, there are \binom{K}{d} ways to choose which d elements the array will contain; then we need to arrange them.

Let f(N, d) denote the number of arrays of length N that use each of d fixed values at least once; by symmetry, this count doesn't depend on which d values were chosen.

How to compute this?

One way is to use the inclusion-exclusion principle.

There are d choices for every index, leading to d^N arrays initially.
However, not all of them are valid - in some of them, one or more of the d elements might not appear at all, since we didn't constrain that in any way.

For a fixed element, there are (d-1)^N arrays with it missing (and maybe missing other elements too).
There are d ways of choosing the missing element, so we subtract out d\cdot (d-1)^N from the total.

However, arrays with two missing elements have been subtracted out twice, so we’d need to add them back in.
There are \binom{d}{2} ways to choose which two elements are missing, and (d-2)^N arrays with both of them missing; these we add back in.

But then arrays with three missing elements are now counted, so we need to subtract them out; and so on and so forth.

This is a classical case of inclusion-exclusion, and we end up with

f(N, d) = \sum_{i=0}^d (-1)^i \binom{d}{i} (d-i)^N
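For small inputs this formula is easy to sanity-check directly. Below is a quick Python sketch (the helper names are my own) that evaluates the summation and compares it against brute force:

from math import comb
from itertools import product

def f(n, d):
    # Inclusion-exclusion: arrays of length n using all of d fixed values.
    return sum((-1)**i * comb(d, i) * (d - i)**n for i in range(d + 1))

def f_brute(n, d):
    # Count length-n tuples over {1..d} that use every value at least once.
    return sum(1 for a in product(range(1, d + 1), repeat=n)
               if len(set(a)) == d)

assert all(f(n, d) == f_brute(n, d)
           for n in range(1, 6) for d in range(1, n + 1))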

Since we want a sum of sums, we can look at contributions of elements at each index separately.
So, let’s look at just the first index, which when compressed takes values between 1 and d.

Let x_i denote the number of arrays in which B_1 = i.
If we can compute all the x_i values, then the contribution of this index is simply

x_1 \cdot 1^M + x_2 \cdot 2^M + \ldots + x_d \cdot d^M

This computation can then be repeated for each index from 1 to N, and the answers can all be added up.

Here’s the neat part: it turns out that x_1 = x_2 = \ldots = x_d, meaning each of them will equal
\frac{1}{d} of the total.

Why?

Consider the mapping y \to (y\bmod d) + 1 which cyclically shifts all values modulo d.
It’s easy to see that this is a bijection on the set of arrays we have: after all, the inverse operation is to simply shift all values in the other direction.

This bijection maps all arrays with B_1 = y to arrays with B_1 = (y\bmod d) + 1, and only these arrays - meaning their counts must be equal.
This means x_1 = x_2, x_2 = x_3, \ldots, x_{d-1} = x_d, x_d = x_1.
Putting all the equalities together gives us the original claim of all x_i being equal.
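If the bijection argument feels abstract, the claim is easy to verify by brute force on small cases; this Python sketch (function names are mine) tallies how often B_1 equals each compressed value:

from itertools import product

def compress(a):
    # Map the j-th smallest distinct value to j (1-indexed).
    rank = {v: j + 1 for j, v in enumerate(sorted(set(a)))}
    return [rank[v] for v in a]

n, d = 4, 3
counts = [0] * (d + 1)  # counts[v] = number of arrays with B_1 = v
for a in product(range(1, d + 1), repeat=n):
    if len(set(a)) == d:
        counts[compress(a)[0]] += 1
print(counts[1:])  # [12, 12, 12]: each value is equally frequent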

So, for index 1, the overall contribution is

\boxed{\frac{1}{d} \cdot \binom{K}{d} \cdot f(N, d) \cdot \left(1^M + 2^M + \ldots + d^M \right)}

Further, note that this computation didn’t depend on the fact that we were dealing with index 1 at all - if we fix any index, the result will be the same.
So, the sum of scores across all arrays with exactly d distinct elements, is simply the above value multiplied by N.


Now, note that d can be anything between 1 and \min(N, K).
For a fixed d, we want to know:

  1. The sum (1^M + 2^M + \ldots + d^M), which is just a prefix sum.
  2. f(N, d), which can be computed in \mathcal{O}(d) time with the inclusion-exclusion summation we derived for it.
  3. \binom{K}{d}, which can be computed in \mathcal{O}(d) time even though K is very large, using the formula
\binom{K}{d} = \frac{K\cdot (K-1) \cdot (K-2) \cdot\ldots\cdot (K-d+1)}{d!}

So, processing a single d takes \mathcal{O}(d) time; trying all d \leq N and adding up the answers gives us a quadratic solution.
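As a concrete reference, here is a direct Python transcription of the approach above - an unoptimized sketch of my own, not the official implementations below (those precompute factorials to avoid the extra modular inverses):

MOD = 998244353

def solve(n, m, k):
    ans = 0
    binom = 1  # C(k, d), maintained incrementally
    pref = 0   # 1^m + 2^m + ... + d^m (mod MOD)
    for d in range(1, min(n, k) + 1):
        # C(k, d) = C(k, d-1) * (k - d + 1) / d, valid even for huge k
        binom = binom * ((k - d + 1) % MOD) % MOD * pow(d, MOD - 2, MOD) % MOD
        pref = (pref + pow(d, m, MOD)) % MOD
        # f(n, d) via the inclusion-exclusion summation
        f = 0
        c = 1  # C(d, i), maintained incrementally
        for i in range(d + 1):
            term = c * pow(d - i, n, MOD) % MOD
            f = (f + (term if i % 2 == 0 else MOD - term)) % MOD
            c = c * (d - i) % MOD * pow(i + 1, MOD - 2, MOD) % MOD
        # boxed formula, multiplied by n for the n indices
        ans = (ans + n % MOD * binom % MOD * f % MOD
               * pow(d, MOD - 2, MOD) % MOD * pref) % MOD
    return ans

print(solve(2, 1, 2))  # 10: the four arrays score 2 + 3 + 3 + 2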

TIME COMPLEXITY:

\mathcal{O}(N^2) per testcase.

CODE:

Preparer's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

const int facN = 1e6 + 5;
const int mod = 998244353;
int ff[facN], iff[facN];
bool facinit = false;

int power(int x, int y){
    if (y == 0) return 1;

    int v = power(x, y / 2);
    v = 1LL * v * v % mod;

    if (y & 1) return 1LL * v * x % mod;
    else return v;
}

void factorialinit(){
    facinit = true;
    ff[0] = iff[0] = 1;

    for (int i = 1; i < facN; i++){
        ff[i] = 1LL * ff[i - 1] * i % mod;
    }

    iff[facN - 1] = power(ff[facN - 1], mod - 2);
    for (int i = facN - 2; i >= 1; i--){
        iff[i] = 1LL * iff[i + 1] * (i + 1) % mod;
    }
}

int C(int n, int r){
    if (!facinit) factorialinit();

    if (r < 0 || r > n) return 0;

    if (n == r) return 1;
    return 1LL * ff[n] * iff[r] % mod * iff[n - r] % mod;
}

int P(int n, int r){
    if (!facinit) factorialinit();

    assert(0 <= r && r <= n);
    return 1LL * ff[n] * iff[n - r] % mod;
}

int Solutions(int n, int r){
    //solutions to x1 + ... + xn = r 
    //xi >= 0

    return C(n + r - 1, n - 1);
}

void Solve() 
{
    int n, m, k; cin >> n >> m >> k;
    
    vector <int> dp(n + 1, 0);
    dp[0] = 1;
    
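    // dp[j] after i iterations = number of length-i arrays over [1, k]
    // with exactly j distinct values: either reuse one of the j values
    // seen so far, or introduce one of the (k - j) unseen ones.
    // (This equals C(k, j) * f(i, j) from the editorial.)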
    for (int i = 1; i <= n; i++){
        vector <int> ndp(n + 1, 0);
        for (int j = 0; j < i; j++){
            ndp[j] += dp[j] * j;
            ndp[j + 1] += dp[j] * (k - j);
        }
        
        dp = ndp;
        for (auto &x : dp) x %= mod;
    }
    
    int ans = 0;
    vector <int> p(n + 1);
    for (int i = 1; i <= n; i++){
        p[i] = power(i, m);
        p[i] += p[i - 1];
        p[i] %= mod;
    }
    
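    // Per the boxed formula: arrays with exactly i distinct values
    // contribute dp[i] * (1^m + ... + i^m) / i, times n for all indices.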
    for (int i = 1; i <= n; i++){
        p[i] *= power(i, mod - 2);
        p[i] %= mod;
        p[i] *= n;
        p[i] %= mod;
        
        ans += dp[i] * p[i];
        ans %= mod;
    }
    
    cout << ans << "\n";
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

using namespace std;
using namespace __gnu_pbds;

template<typename T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
typedef long long int ll;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;

#define fastio ios_base::sync_with_stdio(false); cin.tie(NULL)
#define pb push_back
#define endl '\n'
#define sz(a) (int)a.size()
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
#define conts continue
#define ceil2(x,y) ((x+y-1)/(y))
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
#define yes cout << "YES" << endl
#define no cout << "NO" << endl

#define rep(i,n) for(int i = 0; i < n; ++i)
#define rep1(i,n) for(int i = 1; i <= n; ++i)
#define rev(i,s,e) for(int i = s; i >= e; --i)
#define trav(i,a) for(auto &i : a)

template<typename T>
void amin(T &a, T b) {
    a = min(a,b);
}

template<typename T>
void amax(T &a, T b) {
    a = max(a,b);
}

#ifdef LOCAL
#include "debug.h"
#else
#define debug(...) 42
#endif

/*



*/

const int MOD = 998244353;
const int N = 5e3 + 5;
const int inf1 = int(1e9) + 5;
const ll inf2 = ll(1e18) + 5;

ll fact[N], ifact[N];

ll bexp(ll a, ll b) {
    a %= MOD;
    if (a == 0) return 0;

    ll res = 1;

    while (b) {
        if (b & 1) res = res * a % MOD;
        a = a * a % MOD;
        b >>= 1;
    }

    return res;
}

ll invmod(ll a) {
    return bexp(a, MOD - 2);
}

ll ncr(ll n, ll r) {
    if (n < 0 or r < 0 or n < r) return 0;
    return fact[n] * ifact[r] % MOD * ifact[n - r] % MOD;
}

ll npr(ll n, ll r) {
    if (n < 0 or r < 0 or n < r) return 0;
    return fact[n] * ifact[n - r] % MOD;
}

void precalc(ll n) {
    fact[0] = 1;
    rep1(i, n) fact[i] = fact[i - 1] * i % MOD;

    ifact[n] = invmod(fact[n]);
    rev(i, n - 1, 0) ifact[i] = ifact[i + 1] * (i + 1) % MOD;
}

void solve(int test_case){
    ll n,m,k; cin >> n >> m >> k;

    ll dp[n+5][n+5];
    memset(dp,0,sizeof dp);
    dp[0][0] = 1;
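    // dp[i][j] = f(i, j): arrays of length i using all of j fixed values
    // (surjections). Extend with a repeated value (j ways into dp[i+1][j])
    // or with a new value (j + 1 label choices into dp[i+1][j+1]).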
    rep(i,n){
        rep(j,n){
            dp[i+1][j] += dp[i][j]*j;
            dp[i+1][j+1] += dp[i][j]*(j+1);
            dp[i+1][j] %= MOD;
            dp[i+1][j+1] %= MOD;
        }
    }

    vector<ll> choose(n+5);
    rep(i,n+1){
        // choose[i] = ncr(k,i)
        ll res = 1;
        for(int j = k-i+1; j <= k; ++j){
            res = res*j%MOD;
        }
        res = res*ifact[i]%MOD;
        choose[i] = res;
    }

    ll ans = 0;

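    // Fix one index and its compressed value x: the other n - 1 positions
    // use y distinct values; the fixed index either reuses one of them
    // (choose[y] value sets) or introduces a new value (choose[y + 1]
    // value sets); y = x - 1 admits only the new-value case.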
    rep1(x,n){
        ll ways = 0;
        ways += dp[n-1][x-1]*choose[x];
        ways %= MOD;

        for(int y = x; y <= n; ++y){
            ways += dp[n-1][y]*(choose[y]+choose[y+1]);
            ways %= MOD;
        }
        
        ans += ways*bexp(x,m);
        ans %= MOD;
    }

    ans = ans*n%MOD;
    cout << ans << endl;
}

int main()
{
    precalc(N-1);

    fastio;

    int t = 1;
    cin >> t;

    rep1(i, t) {
        solve(i);
    }

    return 0;
}
Editorialist's code (PyPy3)
mod = 998244353

mxN = 5005
fac = [1]
for n in range(1, mxN):
    fac.append(fac[-1] * n % mod)
invf = fac[:]
for i in range(mxN): invf[i] = pow(invf[i], mod-2, mod)

for _ in range(int(input())):
    n, m, k = map(int, input().split())

    pref = [pow(i, m, mod) for i in range(n+1)]
    for i in range(1, n+1): pref[i] = (pref[i] + pref[i-1]) % mod

    pw = [pow(i, n, mod) for i in range(n+1)]

    choices, ans = 1, 0
    for x in range(1, min(n, k) + 1):
        # exactly x distinct elements
        # C(k, x) ways to choose the elements
        # inc-exc for the arrays
        
        choices = choices * (k+1-x) * pow(x, mod-2, mod) % mod
        arrays = 0
        for i in range(x+1):
            arrays += (-1)**(i%2) * pw[x-i] % mod * fac[x] % mod * invf[i] % mod * invf[x-i] % mod
        arrays = choices * arrays % mod

        # arrays/x arrays have b[1] = v, for each v = 1, 2, ..., x
        # applies to every index, so multiply by n
        ans = (ans + arrays * pow(x, mod-2, mod) * n * pref[x]) % mod
    print(ans)

Hi @iceknight1093, I would like to know your opinion on this line of thinking.

Can we eliminate the inclusion-exclusion principle by doing this?

Let f(N,D) denote the number of arrays of length N containing exactly D distinct elements.

Step 1:
For simplicity, we can assume all the numbers range over 1, 2, 3, …, D.
Let's define freq_i as the frequency of the i'th number. So,

freq_1 + freq_2 + … + freq_D = N

(since the sum of all the frequencies is supposed to be N).

Step 2 :
Since we know that the size of the set has to be D, the frequency of each number must be greater than 0:

freq_i >= 1, for 1 <= i <= D.

Therefore, we can define freq'_i = freq_i - 1 (because each frequency must be at least 1).

Step 3 :
We are reserving one occurrence of each number for sure.

=> (freq'_1 + freq'_2 + … + freq'_D) + D = N
=> (freq'_1 + freq'_2 + … + freq'_D) = N - D

So now we again have the standard stars-and-bars problem, where we place D-1 bars among (N-D) + (D-1) positions.

Would this work? Is there any mistake in this logic? Do we still need the inclusion-exclusion principle to calculate f(N, d)?

Just want to learn, please help.

I see the problem here. This way, we are only counting which frequency distributions are possible.

For example, for N = 6:

We are assuming that if freq_1 = 3, freq_2 = 2, freq_3 = 1, then this will be counted only once, as [1 1 1 2 2 3]. But ideally, we want to count all the permutations of the array [1 1 1 2 2 3].

Exactly, the difference is between counting ordered and unordered objects.

When dealing with unordered objects, the only thing that matters is the number of times each object appears - i.e. its frequency.
Ordered objects don’t have that nice reduction - though sometimes you can count ordered objects from unordered ones.
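To make the gap concrete, here is a small Python check (of my own) for N = 6, D = 3: stars and bars counts only the 10 frequency vectors, but weighting each one by its multinomial coefficient recovers the ordered count f(6, 3) = 540:

from math import comb, factorial
from itertools import product

N, D = 6, 3
print(comb(N - 1, D - 1))  # 10 frequency vectors (stars and bars)

# Each frequency vector (f_1, ..., f_D) with all f_i >= 1 corresponds
# to N! / (f_1! * ... * f_D!) distinct arrays.
total = 0
for freq in product(range(1, N + 1), repeat=D):
    if sum(freq) == N:
        ways = factorial(N)
        for fi in freq:
            ways //= factorial(fi)
        total += ways
print(total)  # 540 = 3^6 - 3 * 2^6 + 3, the inclusion-exclusion f(6, 3)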

This problem is actually an example of that, which leads to another solution using dynamic programming.
To find f(N, d), one way is to look at it as partitioning the set \{1, 2, \ldots, N\} into d ordered non-empty subsets, then assigning the same value to each index within a subset - everything in the first subset gets the value 1, everything in the second gets 2, …
For example if N = 4 and you partition into \{\{1\}, \{2, 4\}, \{3\}\}, the corresponding array is [1, 2, 3, 2], whereas if you partition into \{\{1\}, \{3\}, \{2, 4\}\} the array is [1, 3, 2, 3].

Instead of counting ordered partitions, we can count unordered partitions, and then multiply their count by d!, since for any partition there are d! ways to arrange its subsets.
The number of unordered partitions of \{1, 2, \ldots, N\} into d non-empty subsets is given by S(N, d), which has a pretty simple recurrence relation:

S(N, d) = S(N-1, d-1) + d\cdot S(N-1, d)

because either \{N\} is its own subset in the partition (and the remaining N-1 elements form d-1 subsets), or the remaining N-1 elements form d subsets and N gets added to any one of them (d choices). This gives an easy quadratic DP where you can precompute all values.

The numbers S(N, d) are known as the Stirling numbers of the second kind.
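A minimal sketch of this recurrence (variable names are my own), checking that d! \cdot S(N, d) matches the inclusion-exclusion formula for f:

from math import comb, factorial

def stirling2(n_max):
    # S[n][d] = number of ways to partition an n-element set
    # into d unordered non-empty subsets
    S = [[0] * (n_max + 1) for _ in range(n_max + 1)]
    S[0][0] = 1
    for n in range(1, n_max + 1):
        for d in range(1, n + 1):
            S[n][d] = S[n - 1][d - 1] + d * S[n - 1][d]
    return S

S = stirling2(8)
f = lambda n, d: sum((-1)**i * comb(d, i) * (d - i)**n for i in range(d + 1))
assert all(factorial(d) * S[n][d] == f(n, d)
           for n in range(1, 9) for d in range(1, n + 1))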

I didn’t present this in the editorial because as far as I’m aware, the DP solution doesn’t really generalize to the harder version - you need to use inclusion-exclusion to get a convolution-like expression to optimize (the Stirling numbers can also be computed using inclusion-exclusion, in fact they’re exactly what the hard version ends up computing as I’ve mentioned at the bottom of that editorial).


Thanks a lot for such a detailed explanation.