BNBX - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: hellolad
Tester: kingmessi
Editorialist: iceknight1093

DIFFICULTY:

Medium

PREREQUISITES:

Dynamic programming, combinatorics, the inclusion-exclusion principle

PROBLEM:

You have N+1 types of balls, specifically A_i balls of type i for 1 \leq i \leq N+1; balls of the same type are indistinguishable.
Find the number of ways to arrange all the balls in a line such that the first ball of type N+1 appears after the first ball of every other type, modulo 998244353.
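A brute force that follows the statement literally is useful for checking a faster solution on tiny inputs. The sketch below is not part of the editorial's solution. The function name bruteForce and the 0-indexed layout are assumptions: A[0..N-1] holds A_1..A_N and A[N] holds A_{N+1}. It enumerates every distinct sequence with std::next_permutation, so it is only usable when the total number of balls is small.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Brute force for tiny inputs (hypothetical helper, not from the editorial).
// A[0..N-1] hold A_1..A_N and A[N] holds A_{N+1}; type N+1 is index N here.
// Balls of one type are identical, so we enumerate distinct sequences with
// std::next_permutation over the sorted multiset of ball types.
long long bruteForce(const std::vector<int>& A) {
    assert(A.size() >= 2 && A.back() >= 1);
    const int special = static_cast<int>(A.size()) - 1;
    std::vector<int> balls;
    for (int t = 0; t <= special; ++t)
        for (int c = 0; c < A[t]; ++c) balls.push_back(t);
    std::sort(balls.begin(), balls.end());

    long long count = 0;
    do {
        // first[t] = position of the first ball of type t (-1 if absent)
        std::vector<int> first(special + 1, -1);
        for (int i = 0; i < static_cast<int>(balls.size()); ++i)
            if (first[balls[i]] == -1) first[balls[i]] = i;
        bool ok = true;
        for (int t = 0; t < special; ++t)
            if (first[t] > first[special]) ok = false;  // type t starts too late
        if (ok) ++count;
    } while (std::next_permutation(balls.begin(), balls.end()));
    return count;
}
```

For instance, with A = (2, 1, 1) there are 12 distinct sequences of the balls 1, 1, 2, 3. Exactly 5 of them place the first 3 after both the first 1 and the 2.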

EXPLANATION:

To solve this task, we’ll use the inclusion-exclusion principle.

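Define c_k to be the sum, over all subsets of k types among [1, N], of the number of arrangements in which every type of the subset has its first ball after the first ball of type N+1.
In other words, c_k counts arrangements in which at least k types start after type N+1, where an arrangement with b such types is counted \binom{b}{k} times.
Suppose we’re able to compute all these values.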
Then, the final answer will simply be

c_0 - c_1 + c_2 - c_3 + \ldots = \sum_{k=0}^N (-1)^k c_k

This is standard inclusion-exclusion: start with all configurations (c_0), subtract out bad ones (c_1), add back in anything that was subtracted twice (c_2, since anything with two types after N+1 would’ve been counted twice), subtract out anything that was added back in one too many times (c_3), and so on.
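To see the identity concretely, one can compute every c_k by enumeration on a tiny input. An arrangement in which b types start after type N+1 contributes \binom{b}{k} to c_k, and \sum_k (-1)^k \binom{b}{k} is 1 when b = 0 and 0 otherwise. The following is a sketch for small inputs only. The function names inclusionExclusionTerms and alternatingSum and the 0-indexed input layout (A[N] = A_{N+1}) are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Returns c_0..c_N by enumerating all distinct arrangements (tiny inputs only).
// An arrangement with `late` types starting after the special type adds
// binom(late, k) to c_k, one for each k-subset of those late types.
std::vector<long long> inclusionExclusionTerms(const std::vector<int>& A) {
    assert(A.size() >= 2 && A.back() >= 1);
    const int special = static_cast<int>(A.size()) - 1;
    std::vector<int> balls;
    for (int t = 0; t <= special; ++t)
        for (int c = 0; c < A[t]; ++c) balls.push_back(t);
    std::sort(balls.begin(), balls.end());

    // Pascal's triangle: C[i][j] = binom(i, j) for 0 <= j <= i <= N
    std::vector<std::vector<long long>> C(special + 1, std::vector<long long>(special + 1, 0));
    for (int i = 0; i <= special; ++i) {
        C[i][0] = 1;
        for (int j = 1; j <= i; ++j) C[i][j] = C[i - 1][j - 1] + C[i - 1][j];
    }

    std::vector<long long> c(special + 1, 0);
    do {
        std::vector<int> first(special + 1, -1);
        for (int i = 0; i < static_cast<int>(balls.size()); ++i)
            if (first[balls[i]] == -1) first[balls[i]] = i;
        int late = 0;  // number of types among [1, N] starting after type N+1
        for (int t = 0; t < special; ++t)
            if (first[t] > first[special]) ++late;
        for (int k = 0; k <= late; ++k) c[k] += C[late][k];
    } while (std::next_permutation(balls.begin(), balls.end()));
    return c;
}

// c_0 - c_1 + c_2 - ...
long long alternatingSum(const std::vector<long long>& c) {
    long long s = 0;
    for (int k = 0; k < static_cast<int>(c.size()); ++k) s += (k % 2 == 0 ? c[k] : -c[k]);
    return s;
}
```

With A = (2, 1, 1) this gives c_0 = 12, c_1 = 10 and c_2 = 3. The alternating sum 12 - 10 + 3 = 5 is exactly the number of valid arrangements.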

The question now is, how exactly to compute c_k?

To answer that, suppose we fix a subset of k types of balls, say with counts p_1, p_2, \ldots, p_k.
Let’s try to count the number of arrangements in which these k types of balls all appear only after N+1 appears.
Note that we don’t care about the balls outside the subset: they can appear before or after type N+1. This is fine, since c_k counts arrangements in which at least k types (not exactly k) appear after N+1.

Counting this is now a straightforward combinatorics task.
First, let’s focus on just the k types that appear after N+1, along with N+1 itself.
The first ball among them must be of type N+1, but everything after it can be arranged in any order.
Let’s define S = p_1 + p_2 + \ldots + p_k.
There are then

\binom{A_{N+1}-1 + S}{p_1, p_2, \ldots, p_k, (A_{N+1}-1)} = \frac{(A_{N+1}-1+S)!}{p_1! \cdot p_2! \cdot \ldots \cdot p_k! \cdot (A_{N+1}-1)!}

arrangements.

Now, let’s bring in the balls of other types.
Let their counts be q_1, q_2, \ldots, q_{N-k}, and let T = A_1 + A_2 + \ldots + A_N be the total number of balls other than type N+1.
We already have an arrangement of these (S + A_{N+1}) balls. Since their relative order is now fixed, for the current purpose they behave like S + A_{N+1} balls of a single type.
So, including the remaining N-k types gives us

\binom{T + A_{N+1}}{q_1, q_2, \ldots, q_{N-k}, (S+A_{N+1})} = \frac{(T+A_{N+1})!}{q_1! \cdot q_2! \cdot \ldots\cdot q_{N-k}! \cdot (S+A_{N+1})!}

So, with these k types fixed to appear after N+1, there are a total of

\frac{(A_{N+1}-1+S)! \cdot (T+A_{N+1})!}{p_1! \cdot p_2! \cdot \ldots \cdot p_k! \cdot (A_{N+1}-1)! \cdot q_1! \cdot q_2! \cdot \ldots\cdot q_{N-k}! \cdot (S+A_{N+1})!}

arrangements.
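The count for one fixed subset can be checked numerically with exact integer factorials, following the two multinomial steps directly. This is a hedged sketch for small inputs only, since factorials up to 20! fit in unsigned long long. The names countForSubset and factorial and the boolean mask inSubset (inSubset[i] marks type i+1) are illustrative, not taken from the original codes.

```cpp
#include <cassert>
#include <vector>

// Exact factorial for small n (n <= 20 fits in unsigned long long).
unsigned long long factorial(int n) {
    assert(0 <= n && n <= 20);
    unsigned long long r = 1;
    for (int i = 2; i <= n; ++i) r *= i;
    return r;
}

// Arrangements in which every type marked in inSubset starts after type N+1.
// A[0..N-1] = A_1..A_N, A[N] = A_{N+1}. Every intermediate division is exact
// because each partial quotient is itself a multinomial coefficient.
unsigned long long countForSubset(const std::vector<int>& A, const std::vector<bool>& inSubset) {
    const int N = static_cast<int>(A.size()) - 1;
    const int a = A[N];
    assert(a >= 1 && static_cast<int>(inSubset.size()) == N);
    int S = 0, T = 0;
    for (int i = 0; i < N; ++i) {
        T += A[i];
        if (inSubset[i]) S += A[i];
    }
    // Step 1: type N+1 first, then the other (a - 1) + S balls in any order:
    // (a - 1 + S)! / ((a - 1)! * p_1! * ... * p_k!)
    unsigned long long inner = factorial(a - 1 + S) / factorial(a - 1);
    for (int i = 0; i < N; ++i)
        if (inSubset[i]) inner /= factorial(A[i]);
    // Step 2: merge with the other N - k types; the S + a balls act as one type:
    // (T + a)! / (q_1! * ... * q_{N-k}! * (S + a)!)
    unsigned long long outer = factorial(T + a) / factorial(S + a);
    for (int i = 0; i < N; ++i)
        if (!inSubset[i]) outer /= factorial(A[i]);
    return inner * outer;
}
```

With A = (2, 1, 1), the subsets \{\}, \{1\}, \{2\}, \{1, 2\} give 12, 4, 6 and 3 arrangements. Multiplying each by S + A_{N+1} (1, 3, 2 and 4 respectively) gives 12 every time, which is the simplification that follows.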

This quantity looks rather menacing at first glance.
However, it can be simplified a lot!

  1. The (T+A_{N+1})! term in the numerator is a constant.
  2. The (A_{N+1}-1)! term in the denominator is a constant.
  3. The product p_1! \cdot \ldots \cdot p_k! \cdot q_1! \cdot \ldots \cdot q_{N-k}! is simply A_1! \cdot A_2! \cdot \ldots \cdot A_N!, since every type in [1, N] is either one of the p’s or one of the q’s. So that’s a constant too!

Taking out the constants, we’re left with just

\frac{(A_{N+1}-1+S)!}{(S+A_{N+1})!} = \frac{1}{S + A_{N+1}}

which is extremely simple!
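Putting the constants back, the whole answer can be written in one line. This restates the derivation above, with U ranging over subsets of \{1, \ldots, N\} and S_U = \sum_{i \in U} A_i:

\text{answer} = \frac{(T + A_{N+1})!}{(A_{N+1}-1)! \cdot A_1! \cdot A_2! \cdot \ldots \cdot A_N!} \sum_{U \subseteq \{1, \ldots, N\}} \frac{(-1)^{|U|}}{S_U + A_{N+1}}

Worked example with N = 2 and (A_1, A_2, A_3) = (2, 1, 1), so T = 3. The constant is \frac{4!}{0! \cdot 2! \cdot 1!} = 12. The sum over U = \{\}, \{1\}, \{2\}, \{1, 2\} is 1 - \frac{1}{3} - \frac{1}{2} + \frac{1}{4} = \frac{5}{12}. Together they give an answer of 12 \cdot \frac{5}{12} = 5.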


We now know what we need to do: for each subset of \{A_1, A_2, \ldots, A_N\} with size k and sum S, we need to add \frac{1}{S + A_{N+1}} to c_k (the constant factor is multiplied back in at the end).

This turns into a subset sum task, since all that needs to be done is count the number of subsets with a given size and sum.
The classical approach is to define dp(i, j, s) as the number of subsets of the first i elements with size j and sum s.
Transitions are simple: dp(i, j, s) = dp(i-1, j, s) + dp(i-1, j-1, s-A_i). The first term corresponds to not choosing A_i and the second to choosing it.
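The classical DP can be sketched as follows. The first dimension i is rolled away by iterating sizes downward, the usual 0/1-knapsack trick. This is an illustration for small inputs only, and the name subsetCountBySizeAndSum is hypothetical. Its table has (N+1) \cdot (\text{sum}(A)+1) entries and every element sweeps all of them, which is where the cost discussed next comes from.

```cpp
#include <cassert>
#include <vector>

// cnt[j][s] = number of subsets of the processed elements with size j and sum s.
// Going through j from high to low means cnt[j - 1] still holds the values
// from before the current element, so one 2D table replaces dp(i, ., .).
std::vector<std::vector<long long>> subsetCountBySizeAndSum(const std::vector<int>& A) {
    const int n = static_cast<int>(A.size());
    int total = 0;
    for (int x : A) total += x;
    std::vector<std::vector<long long>> cnt(n + 1, std::vector<long long>(total + 1, 0));
    cnt[0][0] = 1;  // the empty subset
    for (int x : A) {
        assert(x >= 1);
        for (int j = n; j >= 1; --j)
            for (int s = total; s >= x; --s)
                cnt[j][s] += cnt[j - 1][s - x];  // choose x; "skip x" is the old value
    }
    return cnt;
}
```

For A = (1, 1, 2) there are two subsets of size 1 with sum 1, one of size 2 with sum 2, two of size 2 with sum 3, and so on.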

However, this approach is too slow here: there are up to \mathcal{O}(N^2 \cdot \text{sum}(A)) states. With N \leq 500 and A_i \leq 500 (for i \leq N), this can be as large as 500^4.

To optimize this, observe that we don’t actually care about the size of the subset itself, only whether it’s even or odd. That alone determines whether its contribution is added or subtracted in the end.
This lets us replace the size dimension of the DP with just its parity.
We now have \mathcal{O}(N\cdot \text{sum}(A)) states with constant-time transitions for each, which is fast enough.
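A minimal sketch of the parity version follows. It assumes A_i \geq 1, so that the in-place downward sweep over sums only reads values from before the current element. Here par[p][s] counts subsets with sum s whose size has parity p (0 = even, 1 = odd). The name subsetCountByParityAndSum is illustrative.

```cpp
#include <cassert>
#include <vector>

// par[p][s] = number of subsets with sum s and size parity p (0 = even, 1 = odd).
// Only the parity of the size matters for the sign (-1)^k, so the size
// dimension collapses to 2.
std::vector<std::vector<long long>> subsetCountByParityAndSum(const std::vector<int>& A) {
    int total = 0;
    for (int x : A) total += x;
    std::vector<std::vector<long long>> par(2, std::vector<long long>(total + 1, 0));
    par[0][0] = 1;  // the empty subset has even size
    for (int x : A) {
        assert(x >= 1);  // needed for the in-place downward update below
        for (int s = total; s >= x; --s) {
            long long even = par[0][s - x], odd = par[1][s - x];  // values before adding x
            par[0][s] += odd;   // odd-size subset plus x becomes even-size
            par[1][s] += even;  // even-size subset plus x becomes odd-size
        }
    }
    return par;
}
```

Only the difference between the even and odd counts is ever used, so the two rows can also be merged into a single signed array. The author's code does this with dp[0]=mod-1 and dp[j]-=dp[j-a[i]], storing odd minus even instead of even minus odd.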

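Once the DP is computed, do the following for every sum S: take the number of even-size subsets with sum S minus the number of odd-size subsets with sum S, multiply it by \frac{1}{S + A_{N+1}}, and add it to a running total. Finally, multiply the total by the constant factored out earlier, \frac{(T+A_{N+1})!}{(A_{N+1}-1)! \cdot A_1! \cdot \ldots \cdot A_N!}, to obtain the answer.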

TIME COMPLEXITY:

\mathcal{O}(N\cdot \text{sum}(A)) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define IOS ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0);
#define int long long

const int N=2000001;
const int mod=998244353;

int fact[N], ifact[N];

int bpow(int b, int e){
    int ans = 1;
    for (; e; b = b * b % mod, e /= 2)
        if (e & 1)
            ans = ans * b % mod;
    return ans;
}

void pre(){
    fact[0] = 1;
    for (int i = 1; i < N; i++)
        fact[i] = fact[i - 1] * i % mod;
    ifact[N - 1] = bpow(fact[N - 1], mod - 2);
    for (int i = N - 2; i >= 0; i--)
        ifact[i] = ifact[i + 1] * (i + 1) % mod;
}

int ncr(int n, int r){
    if (r > n){
        return 0;
    }
    return fact[n] * ifact[r] % mod * ifact[n-r] % mod;
}

int32_t main(){
    IOS
    pre();
    int t;
    cin>>t;
    while(t--){
        int n;
        cin>>n;
        int s=0;
        vector<int> a(n+1);
        for(int i=0;i<=n;++i){
            cin>>a[i];
            if(i<n){
                s+=a[i];
            }
        }
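        // ways = (s+a[n])! / (a[0]! * ... * a[n]!): all arrangements, i.e. c_0
        int ways=1;
        int tot=s+a[n];
        for(int i=0;i<n;++i){
            ways=ways*ncr(tot,a[i])%mod;
            tot-=a[i];
        }
        // dp[j] = sum over subsets with sum j of (-1)^(size+1),
        // so the empty subset contributes dp[0] = -1
        vector<int> dp(s+1);
        dp[0]=mod-1;
        for(int i=0;i<n;++i){
            for(int j=s;j>=a[i];--j){
                dp[j]=(dp[j]-dp[j-a[i]]+mod)%mod;
            }
        }
        // ways * a[n] / (j + a[n]) is the count for one subset with sum j
        int ans=0;
        for(int i=1;i<=s;++i){
            ans=(ans+a[n]*bpow(i+a[n],mod-2)%mod*dp[i]%mod)%mod;
        }
        ans=ans*ways%mod;
        // nonempty subsets were accumulated with sign (-1)^(size+1), so this is c_0 - c_1 + c_2 - ...
        ans=(ways-ans+mod)%mod;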
        cout<<ans<<'\n';
    }
    return 0;
}
Tester's code (C++)
#include<bits/stdc++.h>

#include <cassert>
#include <numeric>
#include <type_traits>

#ifdef _MSC_VER
#include <intrin.h>
#endif


#include <utility>

#ifdef _MSC_VER
#include <intrin.h>
#endif

namespace atcoder {

namespace internal {

constexpr long long safe_mod(long long x, long long m) {
    x %= m;
    if (x < 0) x += m;
    return x;
}

struct barrett {
    unsigned int _m;
    unsigned long long im;

    explicit barrett(unsigned int m) : _m(m), im((unsigned long long)(-1) / m + 1) {}

    unsigned int umod() const { return _m; }

    unsigned int mul(unsigned int a, unsigned int b) const {

        unsigned long long z = a;
        z *= b;
#ifdef _MSC_VER
        unsigned long long x;
        _umul128(z, im, &x);
#else
        unsigned long long x =
            (unsigned long long)(((unsigned __int128)(z)*im) >> 64);
#endif
        unsigned long long y = x * _m;
        return (unsigned int)(z - y + (z < y ? _m : 0));
    }
};

constexpr long long pow_mod_constexpr(long long x, long long n, int m) {
    if (m == 1) return 0;
    unsigned int _m = (unsigned int)(m);
    unsigned long long r = 1;
    unsigned long long y = safe_mod(x, m);
    while (n) {
        if (n & 1) r = (r * y) % _m;
        y = (y * y) % _m;
        n >>= 1;
    }
    return r;
}

constexpr bool is_prime_constexpr(int n) {
    if (n <= 1) return false;
    if (n == 2 || n == 7 || n == 61) return true;
    if (n % 2 == 0) return false;
    long long d = n - 1;
    while (d % 2 == 0) d /= 2;
    constexpr long long bases[3] = {2, 7, 61};
    for (long long a : bases) {
        long long t = d;
        long long y = pow_mod_constexpr(a, t, n);
        while (t != n - 1 && y != 1 && y != n - 1) {
            y = y * y % n;
            t <<= 1;
        }
        if (y != n - 1 && t % 2 == 0) {
            return false;
        }
    }
    return true;
}
template <int n> constexpr bool is_prime = is_prime_constexpr(n);

constexpr std::pair<long long, long long> inv_gcd(long long a, long long b) {
    a = safe_mod(a, b);
    if (a == 0) return {b, 0};

    long long s = b, t = a;
    long long m0 = 0, m1 = 1;

    while (t) {
        long long u = s / t;
        s -= t * u;
        m0 -= m1 * u;  // |m1 * u| <= |m1| * s <= b


        auto tmp = s;
        s = t;
        t = tmp;
        tmp = m0;
        m0 = m1;
        m1 = tmp;
    }
    if (m0 < 0) m0 += b / s;
    return {s, m0};
}

constexpr int primitive_root_constexpr(int m) {
    if (m == 2) return 1;
    if (m == 167772161) return 3;
    if (m == 469762049) return 3;
    if (m == 754974721) return 11;
    if (m == 998244353) return 3;
    int divs[20] = {};
    divs[0] = 2;
    int cnt = 1;
    int x = (m - 1) / 2;
    while (x % 2 == 0) x /= 2;
    for (int i = 3; (long long)(i)*i <= x; i += 2) {
        if (x % i == 0) {
            divs[cnt++] = i;
            while (x % i == 0) {
                x /= i;
            }
        }
    }
    if (x > 1) {
        divs[cnt++] = x;
    }
    for (int g = 2;; g++) {
        bool ok = true;
        for (int i = 0; i < cnt; i++) {
            if (pow_mod_constexpr(g, (m - 1) / divs[i], m) == 1) {
                ok = false;
                break;
            }
        }
        if (ok) return g;
    }
}
template <int m> constexpr int primitive_root = primitive_root_constexpr(m);

unsigned long long floor_sum_unsigned(unsigned long long n,
                                      unsigned long long m,
                                      unsigned long long a,
                                      unsigned long long b) {
    unsigned long long ans = 0;
    while (true) {
        if (a >= m) {
            ans += n * (n - 1) / 2 * (a / m);
            a %= m;
        }
        if (b >= m) {
            ans += n * (b / m);
            b %= m;
        }

        unsigned long long y_max = a * n + b;
        if (y_max < m) break;
        n = (unsigned long long)(y_max / m);
        b = (unsigned long long)(y_max % m);
        std::swap(m, a);
    }
    return ans;
}

}  // namespace internal

}  // namespace atcoder


#include <cassert>
#include <numeric>
#include <type_traits>

namespace atcoder {

namespace internal {

#ifndef _MSC_VER
template <class T>
using is_signed_int128 =
    typename std::conditional<std::is_same<T, __int128_t>::value ||
                                  std::is_same<T, __int128>::value,
                              std::true_type,
                              std::false_type>::type;

template <class T>
using is_unsigned_int128 =
    typename std::conditional<std::is_same<T, __uint128_t>::value ||
                                  std::is_same<T, unsigned __int128>::value,
                              std::true_type,
                              std::false_type>::type;

template <class T>
using make_unsigned_int128 =
    typename std::conditional<std::is_same<T, __int128_t>::value,
                              __uint128_t,
                              unsigned __int128>;

template <class T>
using is_integral = typename std::conditional<std::is_integral<T>::value ||
                                                  is_signed_int128<T>::value ||
                                                  is_unsigned_int128<T>::value,
                                              std::true_type,
                                              std::false_type>::type;

template <class T>
using is_signed_int = typename std::conditional<(is_integral<T>::value &&
                                                 std::is_signed<T>::value) ||
                                                    is_signed_int128<T>::value,
                                                std::true_type,
                                                std::false_type>::type;

template <class T>
using is_unsigned_int =
    typename std::conditional<(is_integral<T>::value &&
                               std::is_unsigned<T>::value) ||
                                  is_unsigned_int128<T>::value,
                              std::true_type,
                              std::false_type>::type;

template <class T>
using to_unsigned = typename std::conditional<
    is_signed_int128<T>::value,
    make_unsigned_int128<T>,
    typename std::conditional<std::is_signed<T>::value,
                              std::make_unsigned<T>,
                              std::common_type<T>>::type>::type;

#else

template <class T> using is_integral = typename std::is_integral<T>;

template <class T>
using is_signed_int =
    typename std::conditional<is_integral<T>::value && std::is_signed<T>::value,
                              std::true_type,
                              std::false_type>::type;

template <class T>
using is_unsigned_int =
    typename std::conditional<is_integral<T>::value &&
                                  std::is_unsigned<T>::value,
                              std::true_type,
                              std::false_type>::type;

template <class T>
using to_unsigned = typename std::conditional<is_signed_int<T>::value,
                                              std::make_unsigned<T>,
                                              std::common_type<T>>::type;

#endif

template <class T>
using is_signed_int_t = std::enable_if_t<is_signed_int<T>::value>;

template <class T>
using is_unsigned_int_t = std::enable_if_t<is_unsigned_int<T>::value>;

template <class T> using to_unsigned_t = typename to_unsigned<T>::type;

}  // namespace internal

}  // namespace atcoder


namespace atcoder {

namespace internal {

struct modint_base {};
struct static_modint_base : modint_base {};

template <class T> using is_modint = std::is_base_of<modint_base, T>;
template <class T> using is_modint_t = std::enable_if_t<is_modint<T>::value>;

}  // namespace internal

template <int m, std::enable_if_t<(1 <= m)>* = nullptr>
struct static_modint : internal::static_modint_base {
    using mint = static_modint;

  public:
    static constexpr int mod() { return m; }
    static mint raw(int v) {
        mint x;
        x._v = v;
        return x;
    }

    static_modint() : _v(0) {}
    template <class T, internal::is_signed_int_t<T>* = nullptr>
    static_modint(T v) {
        long long x = (long long)(v % (long long)(umod()));
        if (x < 0) x += umod();
        _v = (unsigned int)(x);
    }
    template <class T, internal::is_unsigned_int_t<T>* = nullptr>
    static_modint(T v) {
        _v = (unsigned int)(v % umod());
    }

    unsigned int val() const { return _v; }

    mint& operator++() {
        _v++;
        if (_v == umod()) _v = 0;
        return *this;
    }
    mint& operator--() {
        if (_v == 0) _v = umod();
        _v--;
        return *this;
    }
    mint operator++(int) {
        mint result = *this;
        ++*this;
        return result;
    }
    mint operator--(int) {
        mint result = *this;
        --*this;
        return result;
    }

    mint& operator+=(const mint& rhs) {
        _v += rhs._v;
        if (_v >= umod()) _v -= umod();
        return *this;
    }
    mint& operator-=(const mint& rhs) {
        _v -= rhs._v;
        if (_v >= umod()) _v += umod();
        return *this;
    }
    mint& operator*=(const mint& rhs) {
        unsigned long long z = _v;
        z *= rhs._v;
        _v = (unsigned int)(z % umod());
        return *this;
    }
    mint& operator/=(const mint& rhs) { return *this = *this * rhs.inv(); }

    mint operator+() const { return *this; }
    mint operator-() const { return mint() - *this; }

    mint pow(long long n) const {
        assert(0 <= n);
        mint x = *this, r = 1;
        while (n) {
            if (n & 1) r *= x;
            x *= x;
            n >>= 1;
        }
        return r;
    }
    mint inv() const {
        if (prime) {
            assert(_v);
            return pow(umod() - 2);
        } else {
            auto eg = internal::inv_gcd(_v, m);
            assert(eg.first == 1);
            return eg.second;
        }
    }

    friend mint operator+(const mint& lhs, const mint& rhs) {
        return mint(lhs) += rhs;
    }
    friend mint operator-(const mint& lhs, const mint& rhs) {
        return mint(lhs) -= rhs;
    }
    friend mint operator*(const mint& lhs, const mint& rhs) {
        return mint(lhs) *= rhs;
    }
    friend mint operator/(const mint& lhs, const mint& rhs) {
        return mint(lhs) /= rhs;
    }
    friend bool operator==(const mint& lhs, const mint& rhs) {
        return lhs._v == rhs._v;
    }
    friend bool operator!=(const mint& lhs, const mint& rhs) {
        return lhs._v != rhs._v;
    }

  private:
    unsigned int _v;
    static constexpr unsigned int umod() { return m; }
    static constexpr bool prime = internal::is_prime<m>;
};

template <int id> struct dynamic_modint : internal::modint_base {
    using mint = dynamic_modint;

  public:
    static int mod() { return (int)(bt.umod()); }
    static void set_mod(int m) {
        assert(1 <= m);
        bt = internal::barrett(m);
    }
    static mint raw(int v) {
        mint x;
        x._v = v;
        return x;
    }

    dynamic_modint() : _v(0) {}
    template <class T, internal::is_signed_int_t<T>* = nullptr>
    dynamic_modint(T v) {
        long long x = (long long)(v % (long long)(mod()));
        if (x < 0) x += mod();
        _v = (unsigned int)(x);
    }
    template <class T, internal::is_unsigned_int_t<T>* = nullptr>
    dynamic_modint(T v) {
        _v = (unsigned int)(v % mod());
    }

    unsigned int val() const { return _v; }

    mint& operator++() {
        _v++;
        if (_v == umod()) _v = 0;
        return *this;
    }
    mint& operator--() {
        if (_v == 0) _v = umod();
        _v--;
        return *this;
    }
    mint operator++(int) {
        mint result = *this;
        ++*this;
        return result;
    }
    mint operator--(int) {
        mint result = *this;
        --*this;
        return result;
    }

    mint& operator+=(const mint& rhs) {
        _v += rhs._v;
        if (_v >= umod()) _v -= umod();
        return *this;
    }
    mint& operator-=(const mint& rhs) {
        _v += mod() - rhs._v;
        if (_v >= umod()) _v -= umod();
        return *this;
    }
    mint& operator*=(const mint& rhs) {
        _v = bt.mul(_v, rhs._v);
        return *this;
    }
    mint& operator/=(const mint& rhs) { return *this = *this * rhs.inv(); }

    mint operator+() const { return *this; }
    mint operator-() const { return mint() - *this; }

    mint pow(long long n) const {
        assert(0 <= n);
        mint x = *this, r = 1;
        while (n) {
            if (n & 1) r *= x;
            x *= x;
            n >>= 1;
        }
        return r;
    }
    mint inv() const {
        auto eg = internal::inv_gcd(_v, mod());
        assert(eg.first == 1);
        return eg.second;
    }

    friend mint operator+(const mint& lhs, const mint& rhs) {
        return mint(lhs) += rhs;
    }
    friend mint operator-(const mint& lhs, const mint& rhs) {
        return mint(lhs) -= rhs;
    }
    friend mint operator*(const mint& lhs, const mint& rhs) {
        return mint(lhs) *= rhs;
    }
    friend mint operator/(const mint& lhs, const mint& rhs) {
        return mint(lhs) /= rhs;
    }
    friend bool operator==(const mint& lhs, const mint& rhs) {
        return lhs._v == rhs._v;
    }
    friend bool operator!=(const mint& lhs, const mint& rhs) {
        return lhs._v != rhs._v;
    }

  private:
    unsigned int _v;
    static internal::barrett bt;
    static unsigned int umod() { return bt.umod(); }
};
template <int id> internal::barrett dynamic_modint<id>::bt(998244353);

using modint998244353 = static_modint<998244353>;
using modint1000000007 = static_modint<1000000007>;
using modint = dynamic_modint<-1>;

namespace internal {

template <class T>
using is_static_modint = std::is_base_of<internal::static_modint_base, T>;

template <class T>
using is_static_modint_t = std::enable_if_t<is_static_modint<T>::value>;

template <class> struct is_dynamic_modint : public std::false_type {};
template <int id>
struct is_dynamic_modint<dynamic_modint<id>> : public std::true_type {};

template <class T>
using is_dynamic_modint_t = std::enable_if_t<is_dynamic_modint<T>::value>;

}  // namespace internal

}  // namespace atcoder

#define ll long long
// #define int long long
#define rep(i,a,b) for(int i=a;i<b;i++)
#define rrep(i,a,b) for(int i=a;i>=b;i--)
#define vi vector<int>
#define repin rep(i,0,n)
using namespace std;
using namespace atcoder;

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && !isspace(buffer[now])) {
            now++;
        }
        return now;
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
}inp;



const long long MM=998244353;
using mint = static_modint<MM>;

mint fact[2000005],invfact[2000005];
void init(){
    fact[0]=1;
    int i;
    for(i=1;i<2000005;i++){
        fact[i]=fact[i-1]*i;
    }
    i--;
}

int smn = 0;
 
void solve()
{
    int n;
    // cin >> n;
    n = inp.readInt(1,500);
    smn += n;
    inp.readEoln();
    vi a(n);
    repin{
        // cin >> a[i];
        a[i] = inp.readInt(1,500);
        inp.readSpace();
    }
    int k;
    // cin >> k;
    k = inp.readInt(1,1000'000);
    inp.readEoln();
    int sm = 0;
    for(auto &x : a)sm += x;
    vector<vector<mint>>dp(2,vector<mint>(sm+1,0));
    dp[0][0] = 1;
    repin{
        vector<vector<mint>>tp(2,vector<mint>(sm+1,0));
        rep(k,a[i],sm+1){
            tp[0][k] += dp[1][k-a[i]];
            tp[1][k] += dp[0][k-a[i]];
        }
        rep(k,0,sm+1){
            tp[0][k] += dp[0][k];
            tp[1][k] += dp[1][k];
        }
        dp = tp;
    }


    mint ans = fact[sm+k];
    for(auto &x : a)ans /= fact[x];
    ans /= fact[k];
    mint pr = 0;
    rep(i,0,sm+1){
        pr += dp[0][i]/(i+k);
        pr -= dp[1][i]/(i+k);
    }
    pr *= k;
    ans *= pr;
    cout << ans.val() << "\n";



}

signed main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    init();
        // int t; cin >> t; 
        int t = inp.readInt(1,500);
        inp.readEoln();
        while(t--)
        solve();
        inp.readEof();
        assert(smn <= 500);
    return 0;
}
Editorialist's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

/**
 * Integers modulo p, where p is a prime
 * Source: Aeren (modified from tourist?)
 *         Modmul for 64-bit mod from kactl:ModMulLL
 * Works with p < 7.2e18 with x87 80-bit long double, and p < 2^52 ~ 4.5e12 with 64-bit
 */
template<typename T>
struct Z_p{
	using Type = typename decay<decltype(T::value)>::type;
	static vector<Type> MOD_INV;
	constexpr Z_p(): value(){ }
	template<typename U> Z_p(const U &x){ value = normalize(x); }
	template<typename U> static Type normalize(const U &x){
		Type v;
		if(-mod() <= x && x < mod()) v = static_cast<Type>(x);
		else v = static_cast<Type>(x % mod());
		if(v < 0) v += mod();
		return v;
	}
	const Type& operator()() const{ return value; }
	template<typename U> explicit operator U() const{ return static_cast<U>(value); }
	constexpr static Type mod(){ return T::value; }
	Z_p &operator+=(const Z_p &otr){ if((value += otr.value) >= mod()) value -= mod(); return *this; }
	Z_p &operator-=(const Z_p &otr){ if((value -= otr.value) < 0) value += mod(); return *this; }
	template<typename U> Z_p &operator+=(const U &otr){ return *this += Z_p(otr); }
	template<typename U> Z_p &operator-=(const U &otr){ return *this -= Z_p(otr); }
	Z_p &operator++(){ return *this += 1; }
	Z_p &operator--(){ return *this -= 1; }
	Z_p operator++(int){ Z_p result(*this); *this += 1; return result; }
	Z_p operator--(int){ Z_p result(*this); *this -= 1; return result; }
	Z_p operator-() const{ return Z_p(-value); }
	template<typename U = T>
	typename enable_if<is_same<typename Z_p<U>::Type, int>::value, Z_p>::type &operator*=(const Z_p& rhs){
		#ifdef _WIN32
		uint64_t x = static_cast<int64_t>(value) * static_cast<int64_t>(rhs.value);
		uint32_t xh = static_cast<uint32_t>(x >> 32), xl = static_cast<uint32_t>(x), d, m;
		asm(
			"divl %4; \n\t"
			: "=a" (d), "=d" (m)
			: "d" (xh), "a" (xl), "r" (mod())
		);
		value = m;
		#else
		value = normalize(static_cast<int64_t>(value) * static_cast<int64_t>(rhs.value));
		#endif
		return *this;
	}
	template<typename U = T>
	typename enable_if<is_same<typename Z_p<U>::Type, int64_t>::value, Z_p>::type &operator*=(const Z_p &rhs){
		uint64_t ret = static_cast<uint64_t>(value) * static_cast<uint64_t>(rhs.value) - static_cast<uint64_t>(mod()) * static_cast<uint64_t>(1.L / static_cast<uint64_t>(mod()) * static_cast<uint64_t>(value) * static_cast<uint64_t>(rhs.value));
		value = normalize(static_cast<int64_t>(ret + static_cast<uint64_t>(mod()) * (ret < 0) - static_cast<uint64_t>(mod()) * (ret >= static_cast<uint64_t>(mod()))));
		return *this;
	}
	template<typename U = T>
	typename enable_if<!is_integral<typename Z_p<U>::Type>::value, Z_p>::type &operator*=(const Z_p &rhs){
		value = normalize(value * rhs.value);
		return *this;
	}
	template<typename U>
	Z_p &operator^=(U e){
		if(e < 0) *this = 1 / *this, e = -e;
		Z_p res = 1;
		for(; e; *this *= *this, e >>= 1) if(e & 1) res *= *this;
		return *this = res;
	}
	template<typename U>
	Z_p operator^(U e) const{
		return Z_p(*this) ^= e;
	}
	Z_p &operator/=(const Z_p &otr){
		Type a = otr.value, m = mod(), u = 0, v = 1;
		if(a < (int)MOD_INV.size()) return *this *= MOD_INV[a];
		while(a){
			Type t = m / a;
			m -= t * a; swap(a, m);
			u -= t * v; swap(u, v);
		}
		assert(m == 1);
		return *this *= u;
	}
	template<typename U> friend const Z_p<U> &abs(const Z_p<U> &v){ return v; }
	Type value;
};
template<typename T> bool operator==(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value == rhs.value; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator==(const Z_p<T>& lhs, U rhs){ return lhs == Z_p<T>(rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator==(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) == rhs; }
template<typename T> bool operator!=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return !(lhs == rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator!=(const Z_p<T> &lhs, U rhs){ return !(lhs == rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator!=(U lhs, const Z_p<T> &rhs){ return !(lhs == rhs); }
template<typename T> bool operator<(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value < rhs.value; }
template<typename T> bool operator>(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value > rhs.value; }
template<typename T> bool operator<=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value <= rhs.value; }
template<typename T> bool operator>=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value >= rhs.value; }
template<typename T> Z_p<T> operator+(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator+(const Z_p<T> &lhs, U rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator+(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T> Z_p<T> operator-(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator-(const Z_p<T>& lhs, U rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator-(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T> Z_p<T> operator*(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator*(const Z_p<T>& lhs, U rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator*(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T> Z_p<T> operator/(const Z_p<T> &lhs, const Z_p<T> &rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator/(const Z_p<T>& lhs, U rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator/(U lhs, const Z_p<T> &rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T> istream &operator>>(istream &in, Z_p<T> &number){
	typename common_type<typename Z_p<T>::Type, int64_t>::type x;
	in >> x;
	number.value = Z_p<T>::normalize(x);
	return in;
}
template<typename T> ostream &operator<<(ostream &out, const Z_p<T> &number){ return out << number(); }

/*
using ModType = int;
struct VarMod{ static ModType value; };
ModType VarMod::value;
ModType &mod = VarMod::value;
using Zp = Z_p<VarMod>;
*/

// constexpr int mod = 1e9 + 7; // 1000000007
constexpr int mod = (119 << 23) + 1; // 998244353
// constexpr int mod = 1e9 + 9; // 1000000009
using Zp = Z_p<integral_constant<decay<decltype(mod)>::type, mod>>;

template<typename T> vector<typename Z_p<T>::Type> Z_p<T>::MOD_INV;
template<typename T = integral_constant<decay<decltype(mod)>::type, mod>>
void precalc_inverse(int SZ){
	auto &inv = Z_p<T>::MOD_INV;
	if(inv.empty()) inv.assign(2, 1);
	for(; inv.size() <= SZ; ) inv.push_back((mod - 1LL * mod / (int)inv.size() * inv[mod % (int)inv.size()]) % mod);
}

template<typename T>
vector<T> precalc_power(T base, int SZ){
	vector<T> res(SZ + 1, 1);
	for(auto i = 1; i <= SZ; ++ i) res[i] = res[i - 1] * base;
	return res;
}

template<typename T>
vector<T> precalc_factorial(int SZ){
	vector<T> res(SZ + 1, 1); res[0] = 1;
	for(auto i = 1; i <= SZ; ++ i) res[i] = res[i - 1] * i;
	return res;
}

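int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    const int N = 2e6 + 5;
    auto fac = precalc_factorial<Zp>(N);

    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        vector a(n+1, 0);
        for (int &x : a) cin >> x;

        // sm = total number of balls of types 1..N (a.back() is type N+1)
        int sm = accumulate(begin(a), end(a), 0) - a.back();
        // dp[p][s] = number of subsets of types 1..N whose size has parity p
        // and whose ball counts sum to s (0/1 knapsack, j runs downwards so
        // each type is used at most once)
        vector dp(2, vector(sm + 1, Zp(0)));
        dp[0][0] = 1;
        for (int i = 0; i < n; ++i) {
            for (int j = sm; j >= a[i]; --j) {
                dp[0][j] += dp[1][j - a[i]];
                dp[1][j] += dp[0][j - a[i]];
            }
        }

        // A subset with S balls contributes (-1)^{|subset|} * (sm + A)! /
        // (A_1! * ... * A_N! * (A - 1)! * (S + A)), where A = a[n].
        Zp ans = 0;
        for (int i = 0; i <= sm; ++i) {
            ans += dp[0][i] / (i + a[n]);
            ans -= dp[1][i] / (i + a[n]);
        }
        ans *= fac[sm + a[n]];
        ans /= fac[a[n] - 1];
        for (int i = 0; i < n; ++i)
            ans /= fac[a[i]];
        cout << ans << '\n';
    }
}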

Compute P(x) = \prod_{i=1}^n P_i(x) in O(k \log k \log n), where P_i(x) = \sum_{j=1}^{a_i} \frac{x^j}{j! (a_i-j)!}, k = A_1 + \ldots + A_n is the number of balls of the first n types, and m = A_{n+1}. The answer is then \sum_{j=1}^{k} [x^j]P(x) \, j! \, (k-j)! \binom{m - 1 + k - j}{m - 1} [my submission].
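A minimal sketch of the generating-function approach above, intended only for checking small inputs. It assumes the modulus 998244353 (the one the editorial's code uses), reads k as A_1 + \ldots + A_N and m as A_{N+1}, and multiplies the P_i one at a time with schoolbook multiplication rather than the divide-and-conquer NTT that gives the O(k \log k \log n) bound. The names `gfAnswer`, `powMod` and `kGfMod` are hypothetical. The term j! (m-1+k-j)! / (m-1)! is the same quantity as j!(k-j)!\binom{m-1+k-j}{m-1}, with the binomial expanded.

```cpp
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

using u64 = std::uint64_t;
const u64 kGfMod = 998244353;  // assumed modulus (same as the editorial's code)

// Fast modular exponentiation, used for one modular inverse.
u64 powMod(u64 b, u64 e) {
    u64 r = 1;
    b %= kGfMod;
    while (e > 0) {
        if (e & 1) r = r * b % kGfMod;
        b = b * b % kGfMod;
        e >>= 1;
    }
    return r;
}

// a[0..n-1] = counts of types 1..N, a[n] = m = count of type N+1 (m >= 1).
u64 gfAnswer(const std::vector<int>& a) {
    const int n = (int)a.size() - 1;
    const int m = a[n];
    int k = 0;
    for (int i = 0; i < n; ++i) k += a[i];

    // Factorials and inverse factorials up to k + m.
    const int lim = k + m;
    std::vector<u64> fact(lim + 1, 1), ifact(lim + 1, 1);
    for (int i = 1; i <= lim; ++i) fact[i] = fact[i - 1] * i % kGfMod;
    ifact[lim] = powMod(fact[lim], kGfMod - 2);
    for (int i = lim; i > 0; --i) ifact[i - 1] = ifact[i] * i % kGfMod;

    // P(x) = prod_i P_i(x), with P_i(x) = sum_{j=1}^{a_i} x^j / (j! (a_i - j)!).
    // Schoolbook multiplication; an NTT-based product would replace this loop.
    std::vector<u64> P{1};
    for (int i = 0; i < n; ++i) {
        std::vector<u64> Q(P.size() + a[i], 0);
        for (std::size_t d = 0; d < P.size(); ++d)
            for (int j = 1; j <= a[i]; ++j) {
                const u64 coef = ifact[j] * ifact[a[i] - j] % kGfMod;
                Q[d + j] = (Q[d + j] + P[d] * coef) % kGfMod;
            }
        P = std::move(Q);
    }

    // answer = sum_j [x^j]P(x) * j! * (m - 1 + k - j)! / (m - 1)!
    u64 ans = 0;
    for (int j = 1; j <= k; ++j)
        ans = (ans + P[j] * fact[j] % kGfMod * fact[m - 1 + k - j] % kGfMod * ifact[m - 1]) % kGfMod;
    return ans;
}
```

For example, `gfAnswer({2, 1, 1})` (two balls of type 1, one ball each of types 2 and 3) returns 5, which matches a hand count of the valid arrangements.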

My GF solution: [CodeChef submission]

I was curious whether it should pass, since it looks like O(n^3 \log^2 n).

Much faster than mine. But don’t you think 0.08 s is too fast for these constraints?

That’s a question for the contest admin @pols_agyi_pols.

The difference is probably the sorting, though any divide-and-conquer order should work. Check this Codeforces comment for more details.
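A rough sketch of a divide-and-conquer product order, assuming the factors are given as coefficient vectors modulo 998244353: the list is split in half recursively and the two partial products are multiplied. With an NTT-based multiply, the polynomials multiplied on each level of the recursion have degrees summing to at most k, and there are about \log_2 n levels, which is the usual argument for the O(k \log k \log n) bound. Here `mulNaive` is a hypothetical schoolbook stand-in for that NTT multiply, and `productDC`, `Poly` and `kMod` are likewise hypothetical names.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

using Poly = std::vector<std::uint64_t>;
const std::uint64_t kMod = 998244353;  // assumed modulus

// Schoolbook product mod kMod (both inputs non-empty); stands in for an NTT multiply.
Poly mulNaive(const Poly& p, const Poly& q) {
    Poly r(p.size() + q.size() - 1, 0);
    for (std::size_t i = 0; i < p.size(); ++i)
        for (std::size_t j = 0; j < q.size(); ++j)
            r[i + j] = (r[i + j] + p[i] * q[j]) % kMod;
    return r;
}

// Product of polys[lo], ..., polys[hi - 1] (requires lo < hi). Splitting the
// range in half means each input polynomial takes part in O(log n) multiplications.
Poly productDC(const std::vector<Poly>& polys, int lo, int hi) {
    if (hi - lo == 1) return polys[lo];
    const int mid = lo + (hi - lo) / 2;
    return mulNaive(productDC(polys, lo, mid), productDC(polys, mid, hi));
}
```

Multiplying the factors strictly left to right can instead multiply an ever-growing partial product by each new factor, which is what the balanced split avoids.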

How does the inclusion-exclusion sum end up counting exactly the arrangements in which 0 types of balls among [1, N] start after type N+1?
If we look at it the following way:
let b_i be the number of arrangements such that exactly i types of balls among [1, N] start after type N+1.
Then, by the way c_i is defined,
c_i = b_0 + b_1 + \ldots + b_i.
So if you compute c_0 - c_1 + c_2 - c_3 + \ldots,
you end up with b_0 + b_2 + b_4 + \ldots
So the expression given for the final answer does not match the expected value (b_0); it also counts b_2, b_4, etc.
Please correct me if anything is wrong in this approach.
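A small brute-force check of the editorial's formula may help with this question. As the editorial computes it, c_k fixes a subset of k types and sums over all such subsets, so an arrangement in which exactly j types start after N+1 is counted once for every k-subset of those j types, that is \binom{j}{k} times. This gives c_k = \sum_{j \ge k} \binom{j}{k} b_j, and since \sum_{k=0}^{j} (-1)^k \binom{j}{k} equals 1 for j = 0 and 0 otherwise, the alternating sum leaves exactly b_0. The sketch below (function names are hypothetical) enumerates every arrangement directly and compares the count with the alternating sum using exact integer arithmetic, so it only works for a handful of balls.

```cpp
#include <algorithm>
#include <vector>

// Exact binomial coefficient C(n, r); after step k, c == C(n - r + k, k).
long long binom(int n, int r) {
    long long c = 1;
    for (int k = 1; k <= r; ++k) c = c * (n - r + k) / k;
    return c;
}

// Brute force over all distinct arrangements. a[0..n-1] are the counts of
// types 1..N and a[n] >= 1 is the count of type N+1.
long long bruteForce(const std::vector<int>& a) {
    const int last = (int)a.size() - 1;
    std::vector<int> balls;
    for (int t = 0; t <= last; ++t)
        for (int c = 0; c < a[t]; ++c) balls.push_back(t);
    // balls starts sorted, so next_permutation visits each distinct arrangement once
    long long good = 0;
    do {
        std::vector<bool> seen(a.size(), false);
        bool ok = true;
        for (int b : balls) {
            if (b == last) {  // first ball of type N+1: all other types must be seen
                for (int t = 0; t < last; ++t)
                    if (!seen[t]) ok = false;
                break;
            }
            seen[b] = true;
        }
        if (ok) ++good;
    } while (std::next_permutation(balls.begin(), balls.end()));
    return good;
}

// Alternating sum over every subset of the first N types. For a subset with S
// balls, the number of arrangements in which type N+1 appears before all of
// them is M * A / (S + A): M is the total number of arrangements, and each of
// those S + A balls is equally likely to be the first among them.
long long inclusionExclusion(const std::vector<int>& a) {
    const int n = (int)a.size() - 1;
    const long long A = a[n];
    long long M = 1;
    int run = 0;
    for (int x : a) { run += x; M *= binom(run, x); }  // multinomial coefficient
    long long ans = 0;
    for (int mask = 0; mask < (1 << n); ++mask) {
        long long S = 0;
        int size = 0;
        for (int t = 0; t < n; ++t)
            if ((mask >> t) & 1) { S += a[t]; ++size; }
        const long long term = M * A / (S + A);
        ans += (size % 2 == 0) ? term : -term;
    }
    return ans;
}
```

For A = [2, 1, 1], both functions return 5; if the alternating sum really picked up b_2, b_4, \ldots as well, the two counts would disagree on inputs where those terms are nonzero.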

Why does this code give TLE?

void solve(){
   ll n;
   cin>>n;
   vector<ll> a(n+1);
   rep(i,0,n+1)cin>>a[i];
   ll sum=accumulate(all(a),0ll);
   // all possible ways to make a row 
   ll total=fact[sum];
   ll den=1;
   for(auto &i:a){
      total=mod_div(total,fact[i],MOD);
   }
   // count of invalid ways
   a[n]--;
   for(auto &i:a){
      den=mul(den,fact[i]);
   }
   vector<vector<vector<ll>>> dp(n+5,vector<vector<ll>>(sum-a[n]+5,vector<ll>(2,-1)));
   auto func=[&](ll i,ll cnt,ll parity,auto &&func)->ll{
      if(i==n){
        if(cnt==0)return 0;
        ll ways=mul(fact[cnt+a[n]],fact[sum-1-(cnt+a[n])]);
        ways=mul(ways,combination(sum,cnt+1+a[n],MOD,fact,ifact));
        return (!parity?sub(0,ways):ways);
      }
      if(dp[i][cnt][parity]!=-1)return dp[i][cnt][parity];
      ll ways=func(i+1,cnt,parity,func);
      ways=add(ways,func(i+1,cnt+a[i],!parity,func));
      return dp[i][cnt][parity]=ways;
   };
   ll invalid=func(0,0,0,func);
   invalid=mod_div(invalid,den,MOD);
   // finalans = total - invalid
   cout<<sub(total,invalid)<<endl;
}

int main() {
    // #ifndef ONLINE_JUDGE
    // freopen("Error.txt", "w", stderr);
    // #endif
    // timeval tp;gettimeofday(&tp, 0);C = (int)tp.tv_usec; // (less than modulo)
    // assert((ull)(H(1)*2+1-3) == 0);
    fastio();
    ll t;
    cin >> t;
    precompute_factorials();
    while (t--) {
        solve();
    }

    return 0;
}