CHEFSSM - Editorial

PROBLEM LINK:

Practice
Div-1 Contest
Div-2 Contest

Author: Anmol Choudhary
Tester: Istvan Nagy
Editorialist: Anmol Choudhary

DIFFICULTY:

Medium-Hard

PREREQUISITES:

Math, Polynomials, Generating functions, Lagrange Interpolation

PROBLEM:

We are given a combination lock with N wheels, where the i^{th} wheel carries the integer values from 0 to A_i in ascending order. The lock opens when at least one wheel shows either 0 or the largest value on that wheel (A_i).
In one operation you can rotate any wheel cyclically, clockwise or counter-clockwise, by 1 unit. The lock is always opened using the minimum possible number of operations.
Initially each wheel shows a value chosen uniformly at random. Find the expected number of operations needed to open the lock.

EXPLANATION:

Suppose the i^{th} wheel initially shows the value X_i, where 0 \leq X_i \leq A_i for every 1 \leq i \leq N.
The lock opens when there exists an i such that X_i=0 or X_i=A_i.

The minimum number of operations needed to bring wheel i to 0 or A_i is min(X_i,A_i-X_i): moving X_i steps towards 0 reaches 0, and moving A_i-X_i steps in the other direction reaches A_i.

Only one wheel needs to reach an end value, so the minimum number of operations required to open the lock is min(min(X_1,A_1-X_1),min(X_2,A_2-X_2),...,min(X_N,A_N-X_N)).
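To sanity-check the formulas that follow on tiny inputs, here is a minimal brute-force sketch that enumerates every initial configuration X and adds up F(X). The name `bruteSumF` is hypothetical (not from the setter's or tester's code), and the enumeration is exponential, so it is only meant for cross-checking. For example, A=(4,5) gives a total of 14 over the 30 configurations.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <cstdint>
#include <vector>

// Hypothetical helper for illustration: brute force over every configuration X (tiny inputs only).
// Returns the sum of F(X); dividing by prod(A_i + 1) gives E(F(X)).
int64_t bruteSumF(const std::vector<int>& A) {
    assert(!A.empty());
    const int n = static_cast<int>(A.size());
    std::vector<int> X(n, 0);  // current configuration, advanced like an odometer
    int64_t total = 0;
    while (true) {
        // F(X) = min over all wheels of min(X_i, A_i - X_i)
        int f = INT_MAX;
        for (int i = 0; i < n; ++i)
            f = std::min(f, std::min(X[i], A[i] - X[i]));
        total += f;
        // move to the next configuration; stop after the last one
        int i = 0;
        while (i < n && X[i] == A[i]) X[i++] = 0;
        if (i == n) break;
        ++X[i];
    }
    return total;
}
```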

Let's restate the problem in mathematical terms:

Let X be a sequence (X_1,X_2,...,X_N) such that 0 \leq X_i \leq A_i for every 1 \leq i \leq N. Let S be the set of all such sequences X.
F(X)=min(X_1,A_1-X_1,X_2,A_2-X_2,...,X_N,A_N-X_N).
Find E(F(X)).

We know E(F(X))=\sum\limits_{X \in S}{\left(P(X) \cdot F(X)\right)} where P(X) is the probability of getting sequence X initially.

Each wheel shows a value chosen uniformly at random, so P(X)=P(Y) for all X,Y \in S.

There are \prod\limits_{i=1}^{N}(A_i+1) sequences in set S.

Hence P(X)= \frac{1}{\prod\limits_{i=1}^{N}(A_i+1)} for every X \in S.

It remains to compute \sum\limits_{X \in S}{F(X)}.

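Let M=min(A_1,A_2,...,A_N). Clearly the minimum value of F(X) is 0. Its maximum value is m=\left\lfloor \frac{M}{2} \right\rfloor, since min(X_i,A_i-X_i) \leq \left\lfloor \frac{A_i}{2} \right\rfloor, and choosing X_i=\left\lfloor \frac{A_i}{2} \right\rfloor for every i attains m.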

Let T(x) be the number of sequences X such that F(X)=x.

We can write \sum\limits_{X \in S}{F(X)} in terms of T(x) as follows:

\sum\limits_{X \in S}{F(X)}=\sum\limits_{x=0}^{m}\left(x\cdot T(x)\right).

Now let’s calculate T(x).

Let G(x) be the number of sequences X such that F(X) \geq x.
F(X) \geq x holds if and only if min(X_i,A_i-X_i) \geq x for every 1 \leq i \leq N,
and min(X_i,A_i-X_i) \geq x holds if and only if X_i \in [x,A_i-x].

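For 0 \leq x \leq m every interval [x,A_i-x] is non-empty and contains A_i-2 \cdot x+1 integers, so G(x)=\prod\limits_{i=1}^{N}(A_i-2 \cdot x+1) for 0 \leq x \leq m, and T(x)=G(x)-G(x+1).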

\sum\limits_{x=0}^{m}\left(x \cdot T(x)\right)=\sum\limits_{x=0}^{m}\left(x \cdot (G(x)-G(x+1))\right)=\sum\limits_{x=1}^{m}G(x)-m \cdot G\left(m+1\right).
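A small exact-integer sketch of the telescoping step, assuming tiny inputs so that the counts fit in `int64_t`. It computes T(x)=G(x)-G(x+1) with G(m+1)=0 and returns both sides of the identity. `countAtLeast` and `bothSides` are illustrative names, not taken from the setter's or tester's code.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical helper: G(x) = number of configurations with F(X) >= x, as an exact integer.
// The product formula is used only for 0 <= x <= m; beyond m the count is 0.
int64_t countAtLeast(const std::vector<int>& A, int x, int m) {
    if (x > m) return 0;
    int64_t g = 1;
    for (int a : A) g *= (a - 2 * x + 1);  // number of integers in [x, a - x]
    return g;
}

// Hypothetical helper: returns {sum_{x=0}^{m} x*T(x), sum_{x=1}^{m} G(x)}.
// The telescoping identity says the two components are equal.
std::pair<int64_t, int64_t> bothSides(const std::vector<int>& A) {
    assert(!A.empty());
    const int m = *std::min_element(A.begin(), A.end()) / 2;
    int64_t lhs = 0, rhs = 0;
    for (int x = 0; x <= m; ++x) {
        const int64_t T = countAtLeast(A, x, m) - countAtLeast(A, x + 1, m);  // T(x) = G(x) - G(x+1)
        lhs += x * T;
    }
    for (int x = 1; x <= m; ++x) rhs += countAtLeast(A, x, m);
    return {lhs, rhs};
}
```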

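G\left(m+1\right)=0 because no sequence has F(X)>m. (Here G(m+1) is the count; the product formula is only used for 0 \leq x \leq m, and it need not vanish at x=m+1 when M is even.)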

Finally E(F(X))=\frac{\sum\limits_{x=1}^{m}{G(x)}}{\prod\limits_{i=1}^{N}(A_i+1)} where G(x)=\prod\limits_{i=1}^{N}(A_i-2 \cdot x+1)
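The final formula can be evaluated directly in O(N \cdot m), which is only practical when m is small. The sketch below assumes the answer is reported as P \cdot Q^{-1} modulo 998244353, as both the setter's and the tester's solutions do; the helper names are hypothetical. Every factor A_i-2x+1 is at least 1 for 1 \leq x \leq m, so no sign handling is needed.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

const int64_t MOD = 998244353;

// Fast modular exponentiation: b^e mod MOD.
int64_t power(int64_t b, int64_t e) {
    int64_t r = 1;
    b %= MOD;
    while (e > 0) {
        if (e & 1) r = r * b % MOD;
        b = b * b % MOD;
        e >>= 1;
    }
    return r;
}

// Hypothetical helper: E(F(X)) = sum_{x=1}^{m} prod_i (A_i - 2x + 1) / prod_i (A_i + 1), mod MOD.
// Direct evaluation in O(N * m).
int64_t expectedDirect(const std::vector<int64_t>& A) {
    assert(!A.empty());
    const int64_t m = *std::min_element(A.begin(), A.end()) / 2;
    int64_t num = 0, den = 1;
    for (int64_t a : A) den = den * ((a + 1) % MOD) % MOD;
    for (int64_t x = 1; x <= m; ++x) {
        int64_t g = 1;
        for (int64_t a : A) g = g * ((a - 2 * x + 1) % MOD) % MOD;  // factor >= 1 because x <= m
        num = (num + g) % MOD;
    }
    return num * power(den, MOD - 2) % MOD;  // P * Q^{-1} via Fermat (MOD is prime)
}
```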

G(x) is a polynomial of degree N. Let G(x)=\sum\limits_{i=0}^{N}{(G_i \cdot x^i)}. We can compute its coefficients (G_0,G_1,...,G_N) using divide and conquer + NTT in O(N \log^2 N).
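A simple way to get the coefficients G_0,...,G_N is to multiply the linear factors one at a time. This schoolbook version is O(N^2), a stand-in for the divide and conquer + NTT product (which the setter implements as `mult(1, n)`); the function name here is hypothetical. Each factor is (A_i+1)-2x, and -2 is stored as `MOD - 2`, the same trick the setter's code uses.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

const int64_t MOD = 998244353;

// Hypothetical helper: coefficients G_0..G_N of G(x) = prod_i ((A_i + 1) - 2x), mod MOD.
// Schoolbook O(N^2) product of linear factors instead of divide and conquer + NTT.
std::vector<int64_t> coefficientsOfG(const std::vector<int64_t>& A) {
    std::vector<int64_t> g = {1};  // the empty product
    for (int64_t a : A) {
        assert(a >= 0);
        const int64_t c0 = (a + 1) % MOD;  // constant term of the factor
        const int64_t c1 = MOD - 2;        // x-coefficient of the factor, i.e. -2
        std::vector<int64_t> next(g.size() + 1, 0);
        for (std::size_t i = 0; i < g.size(); ++i) {
            next[i] = (next[i] + g[i] * c0) % MOD;
            next[i + 1] = (next[i + 1] + g[i] * c1) % MOD;
        }
        g.swap(next);
    }
    return g;
}
```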

\sum\limits_{x=1}^{m}{G(x)}=\sum\limits_{x=1}^{m}{\sum\limits_{i=0}^{N}{(G_i \cdot x^i)}}=\sum\limits_{i=0}^{N} \left(G_i \cdot {\sum\limits_{x=1}^{m}{x^i}}\right)

Let A(k,t)=\sum\limits_{x=1}^{t}{x^k}

\sum\limits_{x=1}^{m}{G(x)}=\sum\limits_{i=0}^{N} G_i \cdot {A(i,m)}
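Once both vectors are known, the numerator is just a dot product. A minimal sketch, assuming both vectors are already reduced modulo 998244353; the name `numeratorFromCoefficients` is hypothetical. For A=(4,5) we have m=2, G(x)=30-22x+4x^2 and (A(0,2),A(1,2),A(2,2))=(2,3,5), giving 30 \cdot 2-22 \cdot 3+4 \cdot 5=14.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

const int64_t MOD = 998244353;

// Hypothetical helper: sum_{x=1}^{m} G(x) = sum_i G_i * A(i, m),
// given G[i] = G_i and S[i] = A(i, m), both reduced mod MOD.
int64_t numeratorFromCoefficients(const std::vector<int64_t>& G, const std::vector<int64_t>& S) {
    assert(G.size() == S.size());
    int64_t total = 0;
    for (std::size_t i = 0; i < G.size(); ++i)
        total = (total + G[i] % MOD * (S[i] % MOD)) % MOD;
    return total;
}
```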

The problem now reduces to calculating A(i,m) for every 0 \leq i \leq N.

SUBTASK 1

Compute each A(i,m) by brute force in O(m \cdot N).
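A sketch of the Subtask 1 approach: for each x, build x^0,x^1,...,x^N incrementally and add them into the running sums. The function name is hypothetical and all values are taken modulo 998244353.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

const int64_t MOD = 998244353;

// Hypothetical helper: S[i] = A(i, m) = sum_{x=1}^{m} x^i mod MOD for every 0 <= i <= N, in O(m * N).
std::vector<int64_t> powerSumsBrute(int N, int64_t m) {
    assert(N >= 0 && m >= 0);
    std::vector<int64_t> S(N + 1, 0);
    for (int64_t x = 1; x <= m; ++x) {
        const int64_t xm = x % MOD;
        int64_t p = 1;  // holds x^i for the current i
        for (int i = 0; i <= N; ++i) {
            S[i] = (S[i] + p) % MOD;
            p = p * xm % MOD;
        }
    }
    return S;
}
```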

SUBTASK 2

We know (x+1)^{k+1}=\sum\limits_{i=0}^{k+1}{{k+1 \choose i} \cdot x^i}

\qquad \quad (x+1)^{k+1}-x^{k+1}=\sum\limits_{i=0}^{k}{{k+1 \choose i} \cdot x^i}

Substitute x=1,2,...,m into the above equation and add the results. The left-hand side telescopes:

(m+1)^{k+1}-1^{k+1}=\sum\limits_{i=0}^{k}{\left({k+1 \choose i} \cdot \sum\limits_{x=1}^{m}x^i \right)}

\qquad \qquad \qquad \qquad=\sum\limits_{i=0}^{k}{{k+1 \choose i} \cdot A(i,m)}

Since {k+1 \choose k}=k+1, moving the i=k term to the left-hand side gives

A(k,m)=\frac{\left((m+1)^{k+1}-1-\sum\limits_{i=0}^{k-1}{{k+1 \choose i} \cdot A(i,m)}\right)}{k+1}

We can compute A(0,m),A(1,m), \ldots ,A(N,m) in O(N^2).
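The recurrence translates into the following O(N^2) sketch (hypothetical function name, values modulo 998244353). It builds the binomials with Pascal's triangle, which also costs O(N^2) memory; precomputed factorials would avoid that at the cost of a slightly longer sketch.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

const int64_t MOD = 998244353;

// Fast modular exponentiation: b^e mod MOD.
int64_t power(int64_t b, int64_t e) {
    int64_t r = 1;
    b %= MOD;
    while (e > 0) {
        if (e & 1) r = r * b % MOD;
        b = b * b % MOD;
        e >>= 1;
    }
    return r;
}

// Hypothetical helper: A(k, m) for all 0 <= k <= N via
//   A(k,m) = ((m+1)^{k+1} - 1 - sum_{i<k} C(k+1, i) A(i,m)) / (k+1),   O(N^2) overall.
std::vector<int64_t> powerSumsRecurrence(int N, int64_t m) {
    assert(N >= 0 && m >= 0);
    // Pascal's triangle, rows 0..N+1, entries mod MOD
    std::vector<std::vector<int64_t>> C(N + 2, std::vector<int64_t>(N + 2, 0));
    for (int r = 0; r <= N + 1; ++r) {
        C[r][0] = 1;
        for (int c = 1; c <= r; ++c) C[r][c] = (C[r - 1][c - 1] + C[r - 1][c]) % MOD;
    }
    std::vector<int64_t> S(N + 1, 0);
    const int64_t base = (m + 1) % MOD;
    int64_t pw = base;  // (m+1)^{k+1} for the current k
    for (int k = 0; k <= N; ++k) {
        int64_t rhs = (pw - 1 + MOD) % MOD;
        for (int i = 0; i < k; ++i)
            rhs = (rhs - C[k + 1][i] * S[i] % MOD + MOD) % MOD;
        S[k] = rhs * power(k + 1, MOD - 2) % MOD;  // divide by k+1
        pw = pw * base % MOD;
    }
    return S;
}
```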

SUBTASK 3

Consider the exponential generating function (EGF) of the sequence A(r,m).

Q(x) = \sum\limits_{r = 0}^{\infty} \frac{A(r,m)}{r!} x^r = \sum\limits_{r = 0}^{\infty} \sum\limits_{i = 1}^{m}\frac{i^r}{r!}x^r = \sum\limits_{i = 1}^{m} \sum\limits_{r = 0}^{\infty} \frac{(ix)^r}{r!} = \sum\limits_{i= 1}^{m} e^{ix} = \frac{e^{(m+1)x} - e^x}{e^x - 1}.

The formal power series of e^{(m+1)x}-e^x is \sum\limits_{i=1}^{\infty}((m+1)^i-1) \frac{x^i}{i!}.

Similarly, e^x-1=\sum\limits_{i=1}^{\infty}\frac{x^i}{i!}.

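Both series have zero constant term, so we can divide the numerator and the denominator by x. Since we only need the first N+1 coefficients (degrees 0 to N), we can write
Q(x)=\frac{\sum\limits_{i=0}^{N}((m+1)^{i+1}-1) \frac{x^i}{(i+1)!}}{\sum\limits_{i=0}^{N} \frac{x^i}{(i+1)!}}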

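To evaluate Q(x), compute the inverse of the denominator modulo x^{N+1} (its constant term is 1, so the inverse exists) and multiply it by the numerator. With NTT-based polynomial inversion and multiplication this takes O(N \log N).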

Q(x)=\sum\limits_{r=0}^{N}Q_r x^r= \sum\limits_{r=0}^{N}{\frac{A(r,m)}{r!}x^r}

A(i,m)=Q_i \cdot i! for every 0 \leq i \leq N
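A compact sketch of the EGF method. To keep it short, the series inverse is computed with the O(N^2) schoolbook recurrence instead of the O(N \log N) Newton iteration used in the setter's `poly::inverse`, and the final product is also quadratic; swapping in NTT-based routines would give the intended complexity. The function name is hypothetical. Since the denominator has constant term 1, the inverse recurrence never has to divide.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

const int64_t MOD = 998244353;

// Fast modular exponentiation: b^e mod MOD.
int64_t power(int64_t b, int64_t e) {
    int64_t r = 1;
    b %= MOD;
    while (e > 0) {
        if (e & 1) r = r * b % MOD;
        b = b * b % MOD;
        e >>= 1;
    }
    return r;
}

// Hypothetical helper: A(r, m) for 0 <= r <= N from Q(x) = num(x) / den(x), where
//   num_i = ((m+1)^{i+1} - 1) / (i+1)!   and   den_i = 1 / (i+1)!.
// Quadratic series inverse and product (sketch), not the NTT version.
std::vector<int64_t> powerSumsEGF(int N, int64_t m) {
    assert(N >= 0 && m >= 0);
    std::vector<int64_t> fact(N + 2, 1), invFact(N + 2, 1);
    for (int i = 1; i <= N + 1; ++i) fact[i] = fact[i - 1] * i % MOD;
    for (int i = 0; i <= N + 1; ++i) invFact[i] = power(fact[i], MOD - 2);

    std::vector<int64_t> num(N + 1), den(N + 1);
    const int64_t base = (m + 1) % MOD;
    int64_t pw = base;  // (m+1)^{i+1}
    for (int i = 0; i <= N; ++i) {
        num[i] = (pw - 1 + MOD) % MOD * invFact[i + 1] % MOD;
        den[i] = invFact[i + 1];
        pw = pw * base % MOD;
    }
    // inv = 1 / den mod x^{N+1}; den[0] = 1, so inv[0] = 1
    std::vector<int64_t> inv(N + 1, 0);
    inv[0] = 1;
    for (int i = 1; i <= N; ++i) {
        int64_t s = 0;
        for (int j = 1; j <= i; ++j) s = (s + den[j] * inv[i - j]) % MOD;
        inv[i] = (MOD - s) % MOD;
    }
    // Q = num * inv (truncated), then A(r, m) = Q_r * r!
    std::vector<int64_t> A(N + 1, 0);
    for (int r = 0; r <= N; ++r) {
        int64_t q = 0;
        for (int j = 0; j <= r; ++j) q = (q + num[j] * inv[r - j]) % MOD;
        A[r] = q * fact[r] % MOD;
    }
    return A;
}
```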

To learn more about operations on polynomials, refer to the following links:
Link
Link

Alternative Solution (Lagrange Interpolation)

Let H(t)=\sum\limits_{x=1}^{t}G(x).

We have to find H(m).

Claim: H(t) is a polynomial of degree N+1.

Proof

H(t)=\sum\limits_{x=1}^{t}G(x)=\sum\limits_{i=0}^{N}{G_i \cdot A(i,t)}
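Claim: A(y,z)=\sum\limits_{x=1}^{z}{x^y} is a polynomial of degree y+1 in z. Proof attached here. (It also follows by induction on y from the recurrence in Subtask 2: (m+1)^{y+1} has degree y+1 in m, while every other term has degree at most y.)
By the above claim, A(i,t) is a polynomial of degree i+1 in t.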
Therefore H(t) is a polynomial of degree N+1.

Given an unknown polynomial f of degree at most n and n+1 pairs (x_i,y_i) with distinct x_i such that f(x_i)=y_i, we can recover f exactly using Lagrange interpolation.

To find the polynomial H(t), let's first find N+2 pairs (x_i,y_i) such that H(x_i)=y_i.

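Let x_i=i for all 1 \leq i \leq N+2. Calculate G(1),G(2),...,G(N+2) using multipoint evaluation of the polynomial G. Then y_i is simply a prefix sum of these values, i.e. y_i=\sum\limits_{j=1}^{x_i}G(j). Finally, interpolate and evaluate at t=m to get H(m) (if m \leq N+2, H(m)=y_m directly).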

Lagrange interpolation tutorial
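A minimal sketch of the interpolation step, assuming the values y_1,...,y_{N+2} (prefix sums of G(1),...,G(N+2), reduced modulo 998244353) are already known, for example from the multipoint evaluation above. Because the nodes are the consecutive integers 1,...,K, the denominator \prod_{j \neq i}(i-j) equals (-1)^{K-i}(i-1)!(K-i)!, and prefix/suffix products give the numerators. The name `interpolateConsecutive` is hypothetical. For A=(4,5), G(1..4)=12,2,0,6, so y=(12,14,14,20) and H(2)=14.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

const int64_t MOD = 998244353;

// Fast modular exponentiation: b^e mod MOD.
int64_t power(int64_t b, int64_t e) {
    int64_t r = 1;
    b %= MOD;
    while (e > 0) {
        if (e & 1) r = r * b % MOD;
        b = b * b % MOD;
        e >>= 1;
    }
    return r;
}

// Hypothetical helper: evaluates at t the unique polynomial of degree <= K-1 whose value
// at node i is y[i-1], for i = 1..K. y must be reduced mod MOD. One inverse per term: O(K log MOD).
int64_t interpolateConsecutive(const std::vector<int64_t>& y, int64_t t) {
    assert(!y.empty());
    const int K = static_cast<int>(y.size());
    if (t >= 1 && t <= K) return y[t - 1];  // t is itself a node (e.g. m <= N + 2)
    const int64_t tm = (t % MOD + MOD) % MOD;
    // pre[i] = prod_{j=1}^{i-1} (t - j),  suf[i] = prod_{j=i+1}^{K} (t - j)
    std::vector<int64_t> pre(K + 2, 1), suf(K + 2, 1);
    for (int i = 1; i <= K; ++i) pre[i + 1] = pre[i] * ((tm - i + MOD) % MOD) % MOD;
    for (int i = K; i >= 1; --i) suf[i - 1] = suf[i] * ((tm - i + MOD) % MOD) % MOD;
    std::vector<int64_t> fact(K + 1, 1);
    for (int i = 1; i <= K; ++i) fact[i] = fact[i - 1] * i % MOD;
    int64_t result = 0;
    for (int i = 1; i <= K; ++i) {
        // prod_{j != i} (i - j) = (i-1)! * (K-i)! * (-1)^{K-i}
        const int64_t denom = fact[i - 1] * fact[K - i] % MOD;
        int64_t term = y[i - 1] * (pre[i] * suf[i] % MOD) % MOD * power(denom, MOD - 2) % MOD;
        if ((K - i) % 2 == 1) term = (MOD - term) % MOD;
        result = (result + term) % MOD;
    }
    return result;
}
```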

SOLUTIONS:

Setter's Solution
#include <bits/stdc++.h>
using namespace std;

//https://judge.yosupo.jp/submission/3207
const int N = 100005;
const int mod = 998244353;
struct base
{
    double x, y;
    base() { x = y = 0; }
    base(double x, double y) : x(x), y(y) {}
};
inline base operator+(base a, base b) { return base(a.x + b.x, a.y + b.y); }
inline base operator-(base a, base b) { return base(a.x - b.x, a.y - b.y); }
inline base operator*(base a, base b) { return base(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x); }
inline base conj(base a) { return base(a.x, -a.y); }
int lim = 1;
vector<base> roots = {{0, 0}, {1, 0}};
vector<int> rev = {0, 1};
const double PI = acosl(-1.0);
void ensure_base(int p)
{
    if (p <= lim)
        return;
    rev.resize(1 << p);
    for (int i = 0; i < (1 << p); i++)
        rev[i] = (rev[i >> 1] >> 1) + ((i & 1) << (p - 1));
    roots.resize(1 << p);
    while (lim < p)
    {
        double angle = 2 * PI / (1 << (lim + 1));
        for (int i = 1 << (lim - 1); i < (1 << lim); i++)
        {
            roots[i << 1] = roots[i];
            double angle_i = angle * (2 * i + 1 - (1 << lim));
            roots[(i << 1) + 1] = base(cos(angle_i), sin(angle_i));
        }
        lim++;
    }
}
void fft(vector<base> &a, int n = -1)
{
    if (n == -1)
        n = a.size();
    assert((n & (n - 1)) == 0);
    int zeros = __builtin_ctz(n);
    ensure_base(zeros);
    int shift = lim - zeros;
    for (int i = 0; i < n; i++)
        if (i < (rev[i] >> shift))
            swap(a[i], a[rev[i] >> shift]);
    for (int k = 1; k < n; k <<= 1)
    {
        for (int i = 0; i < n; i += 2 * k)
        {
            for (int j = 0; j < k; j++)
            {
                base z = a[i + j + k] * roots[j + k];
                a[i + j + k] = a[i + j] - z;
                a[i + j] = a[i + j] + z;
            }
        }
    }
}
//eq = 0: 4 FFTs in total
//eq = 1: 3 FFTs in total
vector<int> multiply(vector<int> &a, vector<int> &b, int eq = 0)
{
    int need = a.size() + b.size() - 1;
    int p = 0;
    while ((1 << p) < need)
        p++;
    ensure_base(p);
    int sz = 1 << p;
    vector<base> A, B;
    if (sz > (int)A.size())
        A.resize(sz);
    for (int i = 0; i < (int)a.size(); i++)
    {
        int x = (a[i] % mod + mod) % mod;
        A[i] = base(x & ((1 << 15) - 1), x >> 15);
    }
    fill(A.begin() + a.size(), A.begin() + sz, base{0, 0});
    fft(A, sz);
    if (sz > (int)B.size())
        B.resize(sz);
    if (eq)
        copy(A.begin(), A.begin() + sz, B.begin());
    else
    {
        for (int i = 0; i < (int)b.size(); i++)
        {
            int x = (b[i] % mod + mod) % mod;
            B[i] = base(x & ((1 << 15) - 1), x >> 15);
        }
        fill(B.begin() + b.size(), B.begin() + sz, base{0, 0});
        fft(B, sz);
    }
    double ratio = 0.25 / sz;
    base r2(0, -1), r3(ratio, 0), r4(0, -ratio), r5(0, 1);
    for (int i = 0; i <= (sz >> 1); i++)
    {
        int j = (sz - i) & (sz - 1);
        base a1 = (A[i] + conj(A[j])), a2 = (A[i] - conj(A[j])) * r2;
        base b1 = (B[i] + conj(B[j])) * r3, b2 = (B[i] - conj(B[j])) * r4;
        if (i != j)
        {
            base c1 = (A[j] + conj(A[i])), c2 = (A[j] - conj(A[i])) * r2;
            base d1 = (B[j] + conj(B[i])) * r3, d2 = (B[j] - conj(B[i])) * r4;
            A[i] = c1 * d1 + c2 * d2 * r5;
            B[i] = c1 * d2 + c2 * d1;
        }
        A[j] = a1 * b1 + a2 * b2 * r5;
        B[j] = a1 * b2 + a2 * b1;
    }
    fft(A, sz);
    fft(B, sz);
    vector<int> res(need);
    for (int i = 0; i < need; i++)
    {
        long long aa = A[i].x + 0.5;
        long long bb = B[i].x + 0.5;
        long long cc = A[i].y + 0.5;
        res[i] = (aa + ((bb % mod) << 15) + ((cc % mod) << 30)) % mod;
    }
    return res;
}
template <int32_t MOD>
struct modint
{
    int32_t value;
    modint() = default;
    modint(int32_t value_) : value(value_) {}
    inline modint<MOD> operator+(modint<MOD> other) const
    {
        int32_t c = this->value + other.value;
        return modint<MOD>(c >= MOD ? c - MOD : c);
    }
    inline modint<MOD> operator-(modint<MOD> other) const
    {
        int32_t c = this->value - other.value;
        return modint<MOD>(c < 0 ? c + MOD : c);
    }
    inline modint<MOD> operator*(modint<MOD> other) const
    {
        int32_t c = (int64_t)this->value * other.value % MOD;
        return modint<MOD>(c < 0 ? c + MOD : c);
    }
    inline modint<MOD> &operator+=(modint<MOD> other)
    {
        this->value += other.value;
        if (this->value >= MOD)
            this->value -= MOD;
        return *this;
    }
    inline modint<MOD> &operator-=(modint<MOD> other)
    {
        this->value -= other.value;
        if (this->value < 0)
            this->value += MOD;
        return *this;
    }
    inline modint<MOD> &operator*=(modint<MOD> other)
    {
        this->value = (int64_t)this->value * other.value % MOD;
        if (this->value < 0)
            this->value += MOD;
        return *this;
    }
    inline modint<MOD> operator-() const { return modint<MOD>(this->value ? MOD - this->value : 0); }
    modint<MOD> pow(uint64_t k) const
    {
        modint<MOD> x = *this, y = 1;
        for (; k; k >>= 1)
        {
            if (k & 1)
                y *= x;
            x *= x;
        }
        return y;
    }
    modint<MOD> inv() const { return pow(MOD - 2); } // MOD must be a prime
    inline modint<MOD> operator/(modint<MOD> other) const { return *this * other.inv(); }
    inline modint<MOD> operator/=(modint<MOD> other) { return *this *= other.inv(); }
    inline bool operator==(modint<MOD> other) const { return value == other.value; }
    inline bool operator!=(modint<MOD> other) const { return value != other.value; }
    inline bool operator<(modint<MOD> other) const { return value < other.value; }
    inline bool operator>(modint<MOD> other) const { return value > other.value; }
};
template <int32_t MOD>
modint<MOD> operator*(int64_t value, modint<MOD> n) { return modint<MOD>(value) * n; }
template <int32_t MOD>
modint<MOD> operator*(int32_t value, modint<MOD> n) { return modint<MOD>(value % MOD) * n; }
template <int32_t MOD>
ostream &operator<<(ostream &out, modint<MOD> n) { return out << n.value; }

using mint = modint<mod>;
struct poly
{
    vector<mint> a;
    inline void normalize()
    {
        while ((int)a.size() && a.back() == 0)
            a.pop_back();
    }
    template <class... Args>
    poly(Args... args) : a(args...) {}
    poly(const initializer_list<mint> &x) : a(x.begin(), x.end()) {}
    int size() const { return (int)a.size(); }
    inline mint coef(const int i) const { return (i < a.size() && i >= 0) ? a[i] : mint(0); }
    mint operator[](const int i) const { return (i < a.size() && i >= 0) ? a[i] : mint(0); } //Beware!! p[i] = k won't change the value of p.a[i]
    bool is_zero() const
    {
        for (int i = 0; i < size(); i++)
            if (a[i] != 0)
                return 0;
        return 1;
    }
    poly operator+(const poly &x) const
    {
        int n = max(size(), x.size());
        vector<mint> ans(n);
        for (int i = 0; i < n; i++)
            ans[i] = coef(i) + x.coef(i);
        while ((int)ans.size() && ans.back() == 0)
            ans.pop_back();
        return ans;
    }
    poly operator-(const poly &x) const
    {
        int n = max(size(), x.size());
        vector<mint> ans(n);
        for (int i = 0; i < n; i++)
            ans[i] = coef(i) - x.coef(i);
        while ((int)ans.size() && ans.back() == 0)
            ans.pop_back();
        return ans;
    }
    poly operator*(const poly &b) const
    {
        if (is_zero() || b.is_zero())
            return {};
        vector<int> A, B;
        for (auto x : a)
            A.push_back(x.value);
        for (auto x : b.a)
            B.push_back(x.value);
        auto res = multiply(A, B, (A == B));
        vector<mint> ans;
        for (auto x : res)
            ans.push_back(mint(x));
        while ((int)ans.size() && ans.back() == 0)
            ans.pop_back();
        return ans;
    }
    poly operator*(const mint &x) const
    {
        int n = size();
        vector<mint> ans(n);
        for (int i = 0; i < n; i++)
            ans[i] = a[i] * x;
        return ans;
    }
    poly operator/(const mint &x) const { return (*this) * x.inv(); }
    poly &operator+=(const poly &x) { return *this = (*this) + x; }
    poly &operator-=(const poly &x) { return *this = (*this) - x; }
    poly &operator*=(const poly &x) { return *this = (*this) * x; }
    poly &operator*=(const mint &x) { return *this = (*this) * x; }
    poly &operator/=(const mint &x) { return *this = (*this) / x; }
    poly mod_xk(int k) const { return {a.begin(), a.begin() + min(k, size())}; } //modulo by x^k
    poly mul_xk(int k) const
    { // multiply by x^k
        poly ans(*this);
        ans.a.insert(ans.a.begin(), k, 0);
        return ans;
    }
    poly div_xk(int k) const
    { // divide by x^k
        return vector<mint>(a.begin() + min(k, (int)a.size()), a.end());
    }
    poly substr(int l, int r) const
    { // return mod_xk(r).div_xk(l)
        l = min(l, size());
        r = min(r, size());
        return vector<mint>(a.begin() + l, a.begin() + r);
    }
    poly reverse_it(int n, bool rev = 0) const
    { // reverses and leaves only n terms
        poly ans(*this);
        if (rev)
        { // if rev = 1 then tail goes to head
            ans.a.resize(max(n, (int)ans.a.size()));
        }
        reverse(ans.a.begin(), ans.a.end());
        return ans.mod_xk(n);
    }
    poly differentiate() const
    {
        int n = size();
        vector<mint> ans(n);
        for (int i = 1; i < size(); i++)
            ans[i - 1] = coef(i) * i;
        return ans;
    }
    poly inverse(int n) const
    { // 1 / p(x) % x^n, O(nlogn)
        assert(!is_zero());
        assert(a[0] != 0);
        poly ans{mint(1) / a[0]};
        for (int i = 1; i < n; i *= 2)
        {
            ans = (ans * mint(2) - ans * ans * mod_xk(2 * i)).mod_xk(2 * i);
        }
        return ans.mod_xk(n);
    }
    pair<poly, poly> divmod_slow(const poly &b) const
    { // when divisor or quotient is small
        vector<mint> A(a);
        vector<mint> ans;
        while (A.size() >= b.a.size())
        {
            ans.push_back(A.back() / b.a.back());
            if (ans.back() != mint(0))
            {
                for (size_t i = 0; i < b.a.size(); i++)
                {
                    A[A.size() - i - 1] -= ans.back() * b.a[b.a.size() - i - 1];
                }
            }
            A.pop_back();
        }
        reverse(ans.begin(), ans.end());
        return {ans, A};
    }
    pair<poly, poly> divmod(const poly &b) const
    { // returns quotient and remainder of a mod b
        if (size() < b.size())
            return {poly{0}, *this};
        int d = size() - b.size();
        if (min(d, b.size()) < 250)
            return divmod_slow(b);
        poly D = (reverse_it(d + 1) * b.reverse_it(d + 1).inverse(d + 1)).mod_xk(d + 1).reverse_it(d + 1, 1);
        return {D, *this - (D * b)};
    }
    poly operator/(const poly &t) const { return divmod(t).first; }
    poly operator%(const poly &t) const { return divmod(t).second; }
    poly &operator/=(const poly &t) { return *this = divmod(t).first; }
    poly &operator%=(const poly &t) { return *this = divmod(t).second; }
    mint eval(mint x)
    { // evaluates in single point x
        mint ans(0);
        for (int i = (int)size() - 1; i >= 0; i--)
        {
            ans *= x;
            ans += a[i];
        }
        return ans;
    }
    poly build(vector<poly> &ans, int v, int l, int r, vector<mint> &vec)
    { //builds evaluation tree for (x-a1)(x-a2)...(x-an)
        if (l == r)
            return ans[v] = poly({-vec[l], 1});
        int mid = l + r >> 1;
        return ans[v] = build(ans, 2 * v, l, mid, vec) * build(ans, 2 * v + 1, mid + 1, r, vec);
    }
    vector<mint> eval(vector<poly> &tree, int v, int l, int r, vector<mint> &vec)
    { // auxiliary evaluation function
        if (l == r)
            return {eval(vec[l])};
        if (size() < 100)
        {
            vector<mint> ans(r - l + 1, 0);
            for (int i = l; i <= r; i++)
                ans[i - l] = eval(vec[i]);
            return ans;
        }
        int mid = l + r >> 1;
        auto A = (*this % tree[2 * v]).eval(tree, 2 * v, l, mid, vec);
        auto B = (*this % tree[2 * v + 1]).eval(tree, 2 * v + 1, mid + 1, r, vec);
        A.insert(A.end(), B.begin(), B.end());
        return A;
    }
    //O(nlog^2n)
    vector<mint> eval(vector<mint> x)
    { // evaluate polynomial in (x_0, ..., x_n-1)
        int n = x.size();
        if (is_zero())
            return vector<mint>(n, mint(0));
        vector<poly> tree(4 * n);
        build(tree, 1, 0, n - 1, x);
        return eval(tree, 1, 0, n - 1, x);
    }
};

mint a[N];
mint fact[N];

void preprocess()
{
    fact[0] = 1;
    for (int i = 1; i < N; i++)
        fact[i] = (fact[i - 1] * i);
}

poly mult(int l, int r)
{
    if (l == r)
    {
        return poly({a[l] + 1, mod - 2});
    }
    int mid = (l + r) >> 1;
    return mult(l, mid) * mult(mid + 1, r);
}

int main()
{
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    preprocess(); //For calculating factorial values % mod

    int t;
    cin >> t;
    while (t--)
    {
        int n;
        cin >> n;
        mint P = 0, Q = 1; //P = numerator, Q = denominator, Final_ans = P/Q
        int m = INT_MAX;
        for (int i = 1; i <= n; i++)
        {
            int num;
            cin >> num;
            a[i] = num;
            Q *= (a[i] + 1);
            m = min(num, m);
        }
        if (m < 2)
        {
            cout << 0 << '\n';
            continue;
        }
        m /= 2;

        // (a1 - 2x + 1)(a2 - 2x + 1)...(an - 2x + 1)
        poly G = mult(1, n); //Using Divide and conquer, FFT  : O(n.(log^2)(n))

        //Calculation of A(i,m) using exponential generating function :

        vector<mint> f1(n + 1); //     (e^((m+1)x)-e^x)/x
        mint temp = m + 1;
        for (int i = 0; i <= n; i++)
        {
            f1[i] = ((temp - 1) / fact[i + 1]);
            temp *= (m + 1);
        }

        vector<mint> f2(n + 1); //  (e^x-1)/x
        for (int i = 0; i <= n; i++)
            f2[i] = mint(1) / fact[i + 1];

        poly P1(f1);
        poly P2(f2);

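        vector<mint> A = (P1 * (P2.inverse(n + 1))).a; //A[i] = Q_i = A(i,m)/i! (scaled by i! below)
        A.resize(n + 1, mint(0)); // operator* trims trailing zero coefficients; keep exactly n+1 terms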

        for (int i = 0; i <= n; i++)
        {
            A[i] *= fact[i];
        }

        // P = numerator = Σ G[i] * A(i,m) = Σ G[i] * A[i]
        for (int i = 0; i <= n; i++)
        {
            P += (G[i] * A[i]);
        }
        cout << P / Q << "\n";
    }
    return 0;
}

Tester's Solution
#include <bits/stdc++.h>
#define all(x) (x).begin(), (x).end()
#define forn(i, n) for (int i = 0; i < (int)(n); ++i)
#define fore(i, a, b) for (int i = (int)(a); i <= (int)(b); ++i)

using namespace std;

const uint64_t mod = 998244353;

uint64_t powMod(uint64_t a, uint64_t pw)
{
	uint64_t res(1);
	while (pw)
	{
		if (pw & 1)
		{
			res = (res * a) % mod;
		}
		pw >>= 1;
		a = (a*a) % mod;
	}
	return res;
}

uint64_t inverse(uint64_t a)
{
	return powMod(a, mod - 2);
}

struct NTT
{
	const uint32_t k = 23;
	const uint64_t c = 7 * 17;
	const uint64_t mod = 998244353;// = (2^k)*c+1
	const uint64_t primitiveRoot = 3;
	const uint64_t prc = 15311432;// (primitiveRoot^c)%mod
	vector<uint64_t> wl;
	vector<uint64_t> wlInv;

	NTT()
	{
		wl.resize(k);
		wlInv.resize(k);
		fore(i, 1, k - 1)
		{
			uint64_t pw = 1 << (k - i);
			wl[i] = powMod(prc, pw);
			wlInv[i] = inverse(wl[i]);
		}
	}

	void transform(vector<uint64_t> & a, bool inv)
	{
		size_t n = a.size();

		for (size_t i = 1, j = 0; i < n; ++i)
		{
			size_t bit = n >> 1;
			while (j >= bit)
			{
				j -= bit;
				bit >>= 1;
			}
			j += bit;
			if (i < j)
				swap(a[i], a[j]);
		}

		for (size_t len = 2, pw = 1; len <= n; len <<= 1, ++pw)
		{
			uint64_t wlen = inv ? wlInv[pw] : wl[pw];
			for (size_t i = 0; i < n; i += len)
			{
				uint64_t w = 1;
				for (size_t j = 0; j < len / 2; ++j)
				{
					uint64_t u = a[i + j], v = (a[i + j + len / 2] * w) % mod;
					a[i + j] = u + v < mod ? u + v : u + v - mod;
					a[i + j + len / 2] = u >= v ? u - v : u - v + mod;
					w = (w * wlen) % mod;
				}
			}
		}
		if (inv)
		{
			uint64_t nrev = inverse(static_cast<uint64_t>(n));
			for (int i = 0; i < n; ++i)
				a[i] = (a[i] * nrev) % mod;
		}
	}

	void multiply(vector<uint64_t>& a, const vector<uint64_t>& b)
	{
		size_t n = a.size();
		auto bc = b;
		a.resize(2 * n);
		bc.resize(2 * n);
		transform(a, false);
		transform(bc, false);
		for (size_t i = 0; i < a.size(); ++i)
		{
			a[i] = (a[i] * bc[i]) % mod;
		}
		transform(a, true);
	}

	void  multiply(vector<vector<uint64_t>>& pol)
	{
		while (pol.size() > 1)
		{
			size_t n = pol[0].size();
			size_t d = pol.size() / 2;
			for (size_t i = 0; i < d; ++i)
			{
				size_t opp = pol.size() - 1 - i;
				pol[i].resize(2 * n);
				pol[opp].resize(2 * n);
				transform(pol[i], false);
				transform(pol[opp], false);
				for (size_t j = 0; j < 2 * n; ++j)
				{
					pol[i][j] = (pol[i][j] * pol[opp][j]) % mod;
				}
				transform(pol[i], true);
			}
			pol.resize((pol.size() + 1) / 2);
		}
	}

	vector<uint64_t> inverseTransform(const vector<uint64_t>& a)
	{
		size_t N = a.size();
		vector<uint64_t> b(1);
		b[0] = inverse(a[0]);
		while (2 * b.size() <= N)
		{
			size_t K = 2 * b.size();
			vector<uint64_t> ac(a.begin(), a.begin() + K), bc = b;
			multiply(bc, b);
			multiply(bc, ac);
			bc.resize(K);

			auto newb = vector<uint64_t>(K);
			forn(i, K)
			{
				newb[i] = (mod + (i < b.size() ? 2 * b[i] : 0) - bc[i]) % mod;
			}
			b.swap(newb);
		}
		return b;
	}

	vector<uint64_t> powerSumSeries(int n, int t)
	{
		++t;
		//at least power of n, and sum 1..t
		size_t N = 1;
		while (N <= n)
		{
			N *= 2;
		}
		vector<uint64_t> w(N), u(N);
		uint64_t fact = 1, tp = 1;
		forn(i, N)
		{
			fact = (fact * (i + 1)) % mod;
			tp = (tp * t) % mod;
			uint64_t invFact = inverse(fact);
			w[i] = invFact;
			u[i] = (invFact * tp) % mod;
		}
		auto res = inverseTransform(w);
		multiply(res, u);
		fact = 1;
		forn(i, N)
		{
			res[i] = (res[i] * fact) % mod;
			fact = (fact * (i + 1)) % mod;
		}
		res[0]--;
		return res;
	}

	vector<uint64_t> powerSumSeriesOdd(int n, int t)
	{
		int t2 = t / 2 + 1;
		++t;
		size_t N = 1;
		while (N < n)
		{
			N *= 2;
		}
		vector<uint64_t> w(N), u(N), u2(N);
		uint64_t fact = 1, tp = 1, tp2 = 1;
		forn(i, N)
		{
			fact = (fact * (i + 1)) % mod;
			tp = (tp * t) % mod;
			tp2 = (tp2 * t2) % mod;
			uint64_t invFact = inverse(fact);
			w[i] = invFact;
			u[i] = (invFact * tp) % mod;
			u2[i] = (invFact * tp2) % mod;
		}
		auto res = inverseTransform(w);
		multiply(u, res);
		multiply(u2, res);
		fact = 1;
		forn(i, N)
		{
			u[i] = (u[i] * fact) % mod;
			u2[i] = (((u2[i] * fact) % mod) * powMod(2, i)) % mod;
			fact = (fact * (i + 1)) % mod;
		}
		u[0] = t % mod;
		u2[0] = t2 % mod;
		forn(i, N)
		{
			if (u[i] >= u2[i])
				u[i] -= u2[i];
			else
				u[i] = mod + u[i] - u2[i];
		}
		return u;
	}
};

int main(int argc, char** argv)
{
	NTT ntt;
	ios::sync_with_stdio(false);
	int T;
	cin >> T;

	forn(tc, T)
	{
		int N;
		cin >> N;
		vector<int> A(N);
		vector<vector<uint64_t> > pol, pol2;

		for (auto& ai : A)
		{
			cin >> ai;
			pol.push_back({ static_cast<uint64_t>(ai), mod - 1 });
		}
		sort(all(A));

		uint64_t res = 0;

		ntt.multiply(pol);
		const auto& pr = pol[0];

		vector<uint64_t> pw = ntt.powerSumSeriesOdd(pr.size(), A[0]);

		forn(j, pr.size())
		{
			uint64_t tmp = pr[j];
			tmp = (tmp * pw[j]) % mod;
			res = (res + tmp) % mod;
		}
		forn(i, N)
			res = (inverse(A[i] + 1) * res) % mod;
		cout << res << endl;
	}
	return 0;
}

VIDEO EDITORIAL:


I used brute force but got TLE even for 15 points. Any idea why? CodeChef: Practical coding for everyone

A similar question: Kick Start - Google’s Coding Competitions

Also, it can be solved using Faulhaber’s formula, computing the convolution as explained in the "Computing L" part of the editorial at the following link: CodeChef: Practical coding for everyone
