EXPDIF-Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Contest Division 4

Setter: Jeevan Jyot Singh
Tester: Satyam, Jatin Garg
Editorialist: Devendra Singh

DIFFICULTY:

2706

PREREQUISITES:

Fast Fourier transform, Dynamic Programming, Divide and Conquer, Combinatorics, Modular multiplication

PROBLEM:

Alice initially had X stones. She divides the X stones into N piles where the i^{th} pile contains A_i stones.

Now, she randomly selects a subset of piles S such that every subset has an equal probability of getting selected.
After selecting the subset, for every 1 \le i \le N,

  • If i \in S, Alice colours all the stones in the i^{th} pile RED.
  • If i \notin S, Alice colours all the stones in the i^{th} pile BLUE.

Let the total number of RED coloured stones be R and the number of BLUE coloured stones be B.
Alice now wonders what the expected value of |R - B| is. Can you help Alice? Output the answer modulo 998244353.

Note that, |x| denotes the absolute value of x. For example, |-4| = 4 and |7|=7.

EXPLANATION:

The problem can be rephrased as follows:
Given an array A of length N, we choose a subsequence P uniformly at random and put all the remaining elements into a second subsequence Q. Find the expected value of |sum_P - sum_Q| modulo 998244353 (where sum_X denotes the sum of all the elements of a subsequence X, and |i| denotes the absolute value of a number i).

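Let S be the sum of all elements of A (i.e. S = \sum_{i=1}^{N} A_i). Since sum_Q = S - sum_P, we have |sum_P - sum_Q| = |S - 2\cdot sum_P|. Each of the 2^N subsequences is selected with equal probability, so we add up |S - 2\cdot sum_P| over all subsequences P and, at the end, divide the sum by 2^N (under the modulus, this means multiplying by the modular inverse of 2^N) to get the final answer.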
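Written out as a formula, this is a restatement of the paragraph above. The symbol c_i is introduced here only as shorthand for the number of subsequences whose sum equals i:

```latex
\mathbb{E}\bigl[\,|R - B|\,\bigr]
  = \frac{1}{2^N} \sum_{P \subseteq \{1,\dots,N\}} \bigl| S - 2\cdot sum_P \bigr|
  = \frac{1}{2^N} \sum_{i=0}^{S} c_i \cdot |S - 2\cdot i|
```

Modulo 998244353, the factor 1/2^N is (2^N)^{-1}, which exists because 998244353 is an odd prime.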
The solution in steps:

Method 1: Brute force approach. Iterate over all subsequences P of array A and, for each one, calculate sum_P and sum_Q = S - sum_P. There are 2^N subsequences, so the time complexity of this approach is at least O(2^N), which is far too slow for the given constraints.
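As a reference for checking faster solutions on tiny inputs, the brute force can be sketched as below. This is only an illustrative sketch: the function names `power_mod` and `expected_diff_bruteforce` and the cap of 20 piles are assumptions made for this example, not part of the attached solutions.

```cpp
#include <cassert>
#include <cstdlib>
#include <vector>

const long long MOD = 998244353;

// Computes base^exp modulo MOD by binary exponentiation.
long long power_mod(long long base, long long exp) {
    long long result = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return result;
}

// Expected |R - B| modulo MOD, found by enumerating all 2^N subsets.
// Hypothetical helper: only usable for very small N.
long long expected_diff_bruteforce(const std::vector<int>& a) {
    int n = (int)a.size();
    assert(n <= 20);  // assumed cap for this sketch; 2^N grows too quickly beyond it
    long long total = 0;
    for (int x : a) total += x;

    long long sum_abs = 0;  // exact sum of |S - 2 * sum_P| over all subsets P
    for (int mask = 0; mask < (1 << n); mask++) {
        long long red = 0;
        for (int i = 0; i < n; i++)
            if ((mask >> i) & 1) red += a[i];
        sum_abs += std::llabs(total - 2 * red);
    }

    // Dividing by 2^N means multiplying by its inverse (Fermat's little theorem).
    long long inv_den = power_mod(power_mod(2, n), MOD - 2);
    return sum_abs % MOD * inv_den % MOD;
}
```

For A = [1, 2] the four subsets give |R - B| = 3, 1, 1, 3, so the expected value is 8 / 4 = 2.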

Method 2: Dynamic Programming

The dp state is defined as follows: dp_i represents the number of ways to get a sum equal to i by picking some subset of the elements of array A. This dp can be calculated in O(N \cdot S) time using the standard knapsack approach.

Then the final answer can be calculated as follows: ans = \sum_{i=0}^{S} dp_i \cdot |S - 2\cdot i|, divided by 2^N at the end.
But since N and S can both be as large as 3 \cdot 10^5, O(N \cdot S) is not fast enough to pass the test cases for the given constraints.
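The knapsack described above can be sketched as follows. The helper names `count_subset_sums` and `expected_diff_dp` are made up for this illustration, and the counts are kept modulo 998244353 since only the final answer modulo that prime is needed.

```cpp
#include <cassert>
#include <cstdlib>
#include <vector>

const long long MOD = 998244353;

// Computes base^exp modulo MOD by binary exponentiation.
long long power_mod(long long base, long long exp) {
    long long result = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return result;
}

// dp[s] = number of subsets of a with sum s, modulo MOD (0/1 knapsack).
std::vector<long long> count_subset_sums(const std::vector<int>& a) {
    long long total = 0;
    for (int x : a) {
        assert(x >= 1);  // every pile holds at least one stone
        total += x;
    }
    std::vector<long long> dp(total + 1, 0);
    dp[0] = 1;  // the empty subset
    for (int x : a)
        for (long long s = total; s >= x; s--)  // downwards, so each pile is used once
            dp[s] = (dp[s] + dp[s - x]) % MOD;
    return dp;
}

// Expected |R - B| modulo MOD in O(N * S) time.
long long expected_diff_dp(const std::vector<int>& a) {
    std::vector<long long> dp = count_subset_sums(a);
    long long total = (long long)dp.size() - 1;  // S, the sum of all piles
    long long sum_abs = 0;
    for (long long s = 0; s <= total; s++)
        sum_abs = (sum_abs + dp[s] * (std::llabs(total - 2 * s) % MOD)) % MOD;
    long long inv_den = power_mod(power_mod(2, (long long)a.size()), MOD - 2);
    return sum_abs * inv_den % MOD;
}
```

For A = [1, 2] the table is dp = [1, 1, 1, 1], which gives (3 + 1 + 1 + 3) / 4 = 2.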
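Speeding up Method 2: FFT
We can use FFT to speed up the computation of the dp states. FFT is an algorithm that allows us to multiply two polynomials of degree at most d in O(d \log d) time. Since the answer is required modulo 998244353, the attached solutions use its modular variant, the number theoretic transform (NTT).
Let the polynomial P(x) = (1 + x^{A_1}) \cdot (1 + x^{A_2}) \cdot (1 + x^{A_3}) \cdots (1 + x^{A_N}). Expanding the product amounts to choosing, for each pile, either 1 (the pile is skipped) or x^{A_i} (the pile is taken), so the coefficient of x^i in P is the number of ways to get a sum equal to i, which is exactly dp_i. We can compute P by multiplying the factors in a divide and conquer fashion using FFT: split the factors into two halves, compute the product of each half recursively, and multiply the two results.
The time complexity of this solution is O(S \log(N) \log(S)), which is fast enough to pass the test cases.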
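The divide and conquer product can be sketched compactly as below. This uses a textbook iterative NTT with primitive root 3 rather than the setter's or tester's implementations, and the names `ntt`, `multiply`, `product` and `expected_diff` are illustrative assumptions.

```cpp
#include <cassert>
#include <cstdlib>
#include <utility>
#include <vector>

typedef long long ll;
const ll MOD = 998244353;  // 119 * 2^23 + 1, with primitive root 3

// Computes b^e modulo MOD by binary exponentiation.
ll power_mod(ll b, ll e) {
    ll r = 1;
    b %= MOD;
    for (; e > 0; e >>= 1, b = b * b % MOD)
        if (e & 1) r = r * b % MOD;
    return r;
}

// In-place iterative NTT; a.size() must be a power of two (at most 2^23).
void ntt(std::vector<ll>& a, bool invert) {
    int n = (int)a.size();
    for (int i = 1, j = 0; i < n; i++) {  // bit-reversal permutation
        int bit = n >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) std::swap(a[i], a[j]);
    }
    for (int len = 2; len <= n; len <<= 1) {
        ll w_len = power_mod(3, (MOD - 1) / len);  // primitive len-th root of unity
        if (invert) w_len = power_mod(w_len, MOD - 2);
        for (int i = 0; i < n; i += len) {
            ll w = 1;
            for (int j = 0; j < len / 2; j++) {
                ll u = a[i + j], v = a[i + j + len / 2] * w % MOD;
                a[i + j] = (u + v) % MOD;
                a[i + j + len / 2] = (u - v + MOD) % MOD;
                w = w * w_len % MOD;
            }
        }
    }
    if (invert) {
        ll n_inv = power_mod(n, MOD - 2);
        for (ll& x : a) x = x * n_inv % MOD;
    }
}

// Product of two non-empty polynomials modulo MOD.
std::vector<ll> multiply(std::vector<ll> a, std::vector<ll> b) {
    size_t need = a.size() + b.size() - 1, n = 1;
    while (n < need) n <<= 1;
    a.resize(n);
    b.resize(n);
    ntt(a, false);
    ntt(b, false);
    for (size_t i = 0; i < n; i++) a[i] = a[i] * b[i] % MOD;
    ntt(a, true);
    a.resize(need);
    return a;
}

// Product of (1 + x^{a[i]}) for i in [l, r], by divide and conquer.
std::vector<ll> product(const std::vector<int>& a, int l, int r) {
    if (l == r) {
        std::vector<ll> p(a[l] + 1, 0);  // assumes a[l] >= 1
        p[0] = 1;
        p[a[l]] = 1;
        return p;
    }
    int mid = (l + r) / 2;
    return multiply(product(a, l, mid), product(a, mid + 1, r));
}

// Expected |R - B| modulo MOD.
ll expected_diff(const std::vector<int>& a) {
    assert(!a.empty());
    std::vector<ll> ways = product(a, 0, (int)a.size() - 1);
    ll total = (ll)ways.size() - 1;  // degree of the product = sum of all piles
    ll sum_abs = 0;
    for (ll s = 0; s <= total; s++)
        sum_abs = (sum_abs + ways[s] * (std::llabs(total - 2 * s) % MOD)) % MOD;
    return sum_abs * power_mod(power_mod(2, (ll)a.size()), MOD - 2) % MOD;
}
```

For A = [1, 2] the product is (1 + x)(1 + x^2) = 1 + x + x^2 + x^3, so the weighted sum is 3 + 1 + 1 + 3 = 8 and the expected value is 8 / 4 = 2. Unlike the attached solutions, this sketch does not fall back to schoolbook multiplication for small inputs, which a tuned implementation would likely want.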

For details of implementation please refer to the solutions attached.

TIME COMPLEXITY:

O(X \log(X) \log(N)) for each test case, where X = S is the total number of stones.
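One rough way to see where this bound comes from is sketched below. Here d_j denotes the degree of the j-th polynomial produced at a given level of the recursion, a name introduced only for this derivation. The degrees at one level add up to at most X, because the polynomials at that level cover disjoint sets of piles.

```latex
\text{cost of one level}
  = \sum_{j} O\bigl(d_j \log d_j\bigr)
  \le O\Bigl(\log X \cdot \sum_{j} d_j\Bigr)
  = O(X \log X),
\qquad
\text{number of levels} = \lceil \log_2 N \rceil
\;\Longrightarrow\;
\text{total} = O\bigl(X \log(X) \log(N)\bigr)
```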

SOLUTION:

Setter's solution
#ifdef WTSH
    #include <wtsh.h>
#else
    #include <bits/stdc++.h>
    using namespace std;
    #define dbg(...)
#endif

#define endl "\n"
#define sz(w) (int)(w.size())
using pii = pair<int, int>;

const long long INF = 1e18;
const int N = 1e6 + 5; 

// -------------------- Input Checker Start --------------------

long long readInt(long long l, long long r, char endd)
{
    long long x = 0;
    int cnt = 0, fi = -1;
    bool is_neg = false;
    while(true)
    {
        char g = getchar();
        if(g == '-')
        {
            assert(fi == -1);
            is_neg = true;
            continue;
        }
        if('0' <= g && g <= '9')
        {
            x *= 10;
            x += g - '0';
            if(cnt == 0)
                fi = g - '0';
            cnt++;
            assert(fi != 0 || cnt == 1);
            assert(fi != 0 || is_neg == false);
            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
        }
        else if(g == endd)
        {
            if(is_neg)
                x = -x;
            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(false);
            }
            return x;
        }
        else
        {
            assert(false);
        }
    }
}

string readString(int l, int r, char endd)
{
    string ret = "";
    int cnt = 0;
    while(true)
    {
        char g = getchar();
        assert(g != -1);
        if(g == endd)
            break;
        cnt++;
        ret += g;
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}

vector<int> readVectorInt(int n, long long l, long long r)
{
    vector<int> a(n);
    for(int i = 0; i < n - 1; i++)
        a[i] = readInt(l, r, ' ');
    a[n - 1] = readInt(l, r, '\n');
    return a;
}

void readEOF() { assert(getchar() == EOF); }

// -------------------- Input Checker End --------------------

template<const int& MOD>
struct Mint {
    using T = typename decay<decltype(MOD)>::type; T v;
    Mint(int64_t v = 0) { if(v < 0) v = v % MOD + MOD; if(v >= MOD) v %= MOD; this->v = T(v); }
    Mint(uint64_t v) { if (v >= MOD) v %= MOD; this->v = T(v); }
    Mint(int v): Mint(int64_t(v)) {}
    Mint(unsigned v): Mint(uint64_t(v)) {}
    explicit operator int() const { return v; }
    explicit operator int64_t() const { return v; }
    explicit operator uint64_t() const { return v; }
    friend istream& operator>>(istream& in, Mint& m) { int64_t v_; in >> v_; m = Mint(v_); return in; } 
    friend ostream& operator<<(ostream& os, const Mint& m) { return os << m.v; }

    static T inv(T a, T m) {
        T g = m, x = 0, y = 1;
        while(a != 0) {
            T q = g / a;
            g %= a; swap(g, a);
            x -= q * y; swap(x, y);
        } return x < 0? x + m : x;
    }

    static unsigned fast_mod(uint64_t x, unsigned m = MOD) {
    #if !defined(_WIN32) || defined(_WIN64)
            return unsigned(x % m);
    #endif // x must be less than 2^32 * m
        unsigned x_high = unsigned(x >> 32), x_low = unsigned(x), quot, rem;
        asm("divl %4\n" : "=a" (quot), "=d" (rem) : "d" (x_high), "a" (x_low), "r" (m));
        return rem;
    }

    Mint inv() const { return Mint(inv(v, MOD)); }
    Mint operator-() const { return Mint(v? MOD-v : 0); }
    Mint& operator++() { v++; if(v == MOD) v = 0; return *this; }
    Mint& operator--() { if(v == 0) v = MOD; v--; return *this; }
    Mint operator++(int) { Mint a = *this; ++*this; return a; }
    Mint operator--(int) { Mint a = *this; --*this; return a; }
    Mint& operator+=(const Mint& o) { v += o.v; if (v >= MOD) v -= MOD; return *this; }
    Mint& operator-=(const Mint& o) { v -= o.v; if (v < 0) v += MOD; return *this; }
    Mint& operator*=(const Mint& o) { v = fast_mod(uint64_t(v) * o.v); return *this; }
    Mint& operator/=(const Mint& o) { return *this *= o.inv(); }
    friend Mint operator+(const Mint& a, const Mint& b) { return Mint(a) += b; }
    friend Mint operator-(const Mint& a, const Mint& b) { return Mint(a) -= b; }
    friend Mint operator*(const Mint& a, const Mint& b) { return Mint(a) *= b; }
    friend Mint operator/(const Mint& a, const Mint& b) { return Mint(a) /= b; }
    friend bool operator==(const Mint& a, const Mint& b) { return a.v == b.v; }
    friend bool operator!=(const Mint& a, const Mint& b) { return a.v != b.v; }
    friend bool operator<(const Mint& a, const Mint& b) { return a.v < b.v; }
    friend bool operator>(const Mint& a, const Mint& b) { return a.v > b.v; }
    friend bool operator<=(const Mint& a, const Mint& b) { return a.v <= b.v; }
    friend bool operator>=(const Mint& a, const Mint& b) { return a.v >= b.v; }
    Mint operator^(int64_t p) {
        if(p < 0) return inv() ^ -p;
        Mint a = *this, res{1}; while(p > 0) {
            if(p & 1) res *= a;
            p >>= 1; if(p > 0) a *= a;
        } return res;
    }
};

const int MOD = 998244353;
using mint = Mint<MOD>;

namespace NTT {
    const int FFT_CUTOFF = 150;
    vector<mint> roots = {0, 1};
    vector<int> bit_order;
    int max_size = 1;       // MOD = c∙2^k + 1, maxsize = 2^k
    mint root = 2;          // n must not exceed maxsize

    auto find_root = []() {
        int order = MOD-1;
        while(~order & 1) order >>= 1, max_size <<= 1;
        while((root ^ max_size) != 1 or (root ^ max_size/2) == 1)
            root++;
        return 'W';
    }();

    void prepare_roots(int n) {
        if(roots.size() >= n) return;
        int len = __builtin_ctz(roots.size());
        roots.resize(n);
        while(n > 1 << len) {
            mint z = root ^ max_size >> len + 1;
            for(int i = 1 << len-1; i < 1 << len; i++) {
                roots[i << 1] = roots[i];
                roots[i << 1 | 1] = roots[i] * z;
            } len++;
        }
    }

    void reorder_bits(int n, vector<mint>& a) {
        if(bit_order.size() != n) {
            bit_order.assign(n, 0);
            int len = __builtin_ctz(n);
            for(int i = 0; i < n; i++)
                bit_order[i] = (bit_order[i >> 1] >> 1) + ((i & 1) << len-1);
        }
        for(int i = 0; i < n; i++)
            if(i < bit_order[i]) swap(a[i], a[bit_order[i]]);
    }

    void fft(int n, vector<mint>& a) {
        assert(n <= max_size);
        prepare_roots(n); reorder_bits(n, a);
        for(int len = 1; len < n; len <<= 1)
            for(int i = 0; i < n; i += len << 1)
                for(int j = 0; j < len; j++) {
                    const mint& even = a[i + j];
                    mint odd = a[i + len + j] * roots[len + j];
                    a[i + len + j] = even - odd; a[i + j] = even + odd;
                }
    }

    template<typename T>
    vector<mint> multiply(const vector<T>& a, const vector<T>& b, bool trim = false) {
        int n = a.size(), m = b.size();
        if(n == 0 or m == 0)
            return vector<mint>{0};

        if(min(n, m) < FFT_CUTOFF) {
            vector<mint> res(n + m - 1);
            for(int i = 0; i < n; i++)
                for(int j = 0; j < m; j++)
                    res[i + j] += mint(a[i]) * mint(b[j]);
            
            if(trim) {
                while(!res.empty() && res.back() == 0)
                    res.pop_back();
            }

            return res;
        }

        int N = [](int x) { while(x & x-1) x = (x | x-1) + 1; return x; }(n + m - 1);
        vector<mint> fa(a.begin(), a.end()), fb(b.begin(), b.end());
        fa.resize(N); fb.resize(N);

        bool equal = fa == fb;
        fft(N, fa);

        if(equal) fb = fa;
        else fft(N, fb);

        mint inv_N = mint(N) ^ -1;
        for(int i = 0; i < N; i++)
            fa[i] *= fb[i] * inv_N;
        reverse(fa.begin() + 1, fa.end());

        fft(N, fa);
        fa.resize(n + m - 1);

        if(trim) {
            while(!fa.empty() && fa.back() == 0)
                fa.pop_back();
        }

        return fa;
    }

    template<typename T>
    vector<mint> expo(const vector<T>& a, int e, bool trim = false) {
        int n = a.size();
        int N = [](int x) { while(x & x-1) x = (x | x-1) + 1; return x; }((n-1) * e + 1);
        vector<mint> fa(a.begin(), a.end()); fa.resize(N);

        fft(N, fa);

        mint inv_N = mint(N) ^ -1;
        for(int i = 0; i < N; i++)
            fa[i] = (fa[i] ^ e) * inv_N;
        reverse(fa.begin() + 1, fa.end());

        fft(N, fa);
        fa.resize((n-1) * e + 1);

        if(trim) {
            while(!fa.empty() && fa.back() == 0)
                fa.pop_back();
        }

        return fa;
    }

} // namespace NTT 

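// Multiplies the polynomials P[l..r] by divide and conquer:
// each half is multiplied recursively, then the two products are combined with NTT.
vector<mint> recur(int l, int r, vector<vector<mint>> &P)
{
    if(l == r)
        return P[l];
    int mid = (l + r) / 2;
    auto res = NTT::multiply(recur(l, mid, P), recur(mid + 1, r, P));
    return res;
}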

int sumX = 0;

void solve()
{
    int x, n;
    x = readInt(1, 3e5, ' ');
    n = readInt(1, x, '\n');
    vector<int> a = readVectorInt(n, 1, x);
    int sum = accumulate(a.begin(), a.end(), 0);
    sumX += x;
    assert(sum == x);
    assert(sumX <= 3e5);

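    // P[i] = 1 + x^{a[i]}: pile i is either skipped (x^0) or taken (x^{a[i]}).
    vector<vector<mint>> P(n);
    for(int i = 0; i < n; i++)
    {
        vector<mint> poly(a[i] + 1);
        poly[0] = poly[a[i]] = 1;
        P[i] = poly;
    }
    auto res = recur(0, n - 1, P);
    // res[i] = number of subsets with sum i; weight each by |x - 2i|.
    mint ans = 0;
    for(int i = 0; i < sz(res); i++)
    {
        mint ways = res[i];
        ans += ways * abs(x - 2 * i);
    }
    // Divide by 2^n, the total number of subsets.
    mint den = mint(2) ^ n;
    ans /= den;
    cout << ans << endl;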
}

int32_t main()
{
    ios::sync_with_stdio(0); 
    cin.tie(0);
    int T = readInt(1, 1000, '\n');
    for(int tc = 1; tc <= T; tc++)
    {
        // cout << "Case #" << tc << ": ";
        solve();
    }
    readEOF();
    assert(sumX <= 3e5);
    return 0;
}
Tester-1's Solution


#include <bits/stdc++.h>
 
#include <cassert>
#include <numeric>
#include <type_traits>
 
#ifdef _MSC_VER
#include <intrin.h>
#endif

#define ll long long
 
 
#include <utility>
 
#ifdef _MSC_VER
#include <intrin.h>
#endif
 
namespace atcoder {
 
namespace internal {
 
constexpr long long safe_mod(long long x, long long m) {
    x %= m;
    if (x < 0) x += m;
    return x;
}
 
struct barrett {
    unsigned int _m;
    unsigned long long im;
 
    barrett(unsigned int m) : _m(m), im((unsigned long long)(-1) / m + 1) {}
 
    unsigned int umod() const { return _m; }
 
    unsigned int mul(unsigned int a, unsigned int b) const {
 
        unsigned long long z = a;
        z *= b;
#ifdef _MSC_VER
        unsigned long long x;
        _umul128(z, im, &x);
#else
        unsigned long long x =
            (unsigned long long)(((unsigned __int128)(z)*im) >> 64);
#endif
        unsigned int v = (unsigned int)(z - x * _m);
        if (_m <= v) v += _m;
        return v;
    }
};
 
constexpr long long pow_mod_constexpr(long long x, long long n, int m) {
    if (m == 1) return 0;
    unsigned int _m = (unsigned int)(m);
    unsigned long long r = 1;
    unsigned long long y = safe_mod(x, m);
    while (n) {
        if (n & 1) r = (r * y) % _m;
        y = (y * y) % _m;
        n >>= 1;
    }
    return r;
}
 
constexpr bool is_prime_constexpr(int n) {
    if (n <= 1) return false;
    if (n == 2 || n == 7 || n == 61) return true;
    if (n % 2 == 0) return false;
    long long d = n - 1;
    while (d % 2 == 0) d /= 2;
    constexpr long long bases[3] = {2, 7, 61};
    for (long long a : bases) {
        long long t = d;
        long long y = pow_mod_constexpr(a, t, n);
        while (t != n - 1 && y != 1 && y != n - 1) {
            y = y * y % n;
            t <<= 1;
        }
        if (y != n - 1 && t % 2 == 0) {
            return false;
        }
    }
    return true;
}
template <int n> constexpr bool is_prime = is_prime_constexpr(n);
 
constexpr std::pair<long long, long long> inv_gcd(long long a, long long b) {
    a = safe_mod(a, b);
    if (a == 0) return {b, 0};
 
    long long s = b, t = a;
    long long m0 = 0, m1 = 1;
 
    while (t) {
        long long u = s / t;
        s -= t * u;
        m0 -= m1 * u;  // |m1 * u| <= |m1| * s <= b
 
 
        auto tmp = s;
        s = t;
        t = tmp;
        tmp = m0;
        m0 = m1;
        m1 = tmp;
    }
    if (m0 < 0) m0 += b / s;
    return {s, m0};
}
 
constexpr int primitive_root_constexpr(int m) {
    if (m == 2) return 1;
    if (m == 167772161) return 3;
    if (m == 469762049) return 3;
    if (m == 754974721) return 11;
    if (m == 998244353) return 3;
    int divs[20] = {};
    divs[0] = 2;
    int cnt = 1;
    int x = (m - 1) / 2;
    while (x % 2 == 0) x /= 2;
    for (int i = 3; (long long)(i)*i <= x; i += 2) {
        if (x % i == 0) {
            divs[cnt++] = i;
            while (x % i == 0) {
                x /= i;
            }
        }
    }
    if (x > 1) {
        divs[cnt++] = x;
    }
    for (int g = 2;; g++) {
        bool ok = true;
        for (int i = 0; i < cnt; i++) {
            if (pow_mod_constexpr(g, (m - 1) / divs[i], m) == 1) {
                ok = false;
                break;
            }
        }
        if (ok) return g;
    }
}
template <int m> constexpr int primitive_root = primitive_root_constexpr(m);
 
}  // namespace internal
 
}  // namespace atcoder
 
 
#include <cassert>
#include <numeric>
#include <type_traits>
 
namespace atcoder {
 
namespace internal {
 
#ifndef _MSC_VER
template <class T>
using is_signed_int128 =
    typename std::conditional<std::is_same<T, __int128_t>::value ||
                                  std::is_same<T, __int128>::value,
                              std::true_type,
                              std::false_type>::type;
 
template <class T>
using is_unsigned_int128 =
    typename std::conditional<std::is_same<T, __uint128_t>::value ||
                                  std::is_same<T, unsigned __int128>::value,
                              std::true_type,
                              std::false_type>::type;
 
template <class T>
using make_unsigned_int128 =
    typename std::conditional<std::is_same<T, __int128_t>::value,
                              __uint128_t,
                              unsigned __int128>;
 
template <class T>
using is_integral = typename std::conditional<std::is_integral<T>::value ||
                                                  is_signed_int128<T>::value ||
                                                  is_unsigned_int128<T>::value,
                                              std::true_type,
                                              std::false_type>::type;
 
template <class T>
using is_signed_int = typename std::conditional<(is_integral<T>::value &&
                                                 std::is_signed<T>::value) ||
                                                    is_signed_int128<T>::value,
                                                std::true_type,
                                                std::false_type>::type;
 
template <class T>
using is_unsigned_int =
    typename std::conditional<(is_integral<T>::value &&
                               std::is_unsigned<T>::value) ||
                                  is_unsigned_int128<T>::value,
                              std::true_type,
                              std::false_type>::type;
 
template <class T>
using to_unsigned = typename std::conditional<
    is_signed_int128<T>::value,
    make_unsigned_int128<T>,
    typename std::conditional<std::is_signed<T>::value,
                              std::make_unsigned<T>,
                              std::common_type<T>>::type>::type;
 
#else
 
template <class T> using is_integral = typename std::is_integral<T>;
 
template <class T>
using is_signed_int =
    typename std::conditional<is_integral<T>::value && std::is_signed<T>::value,
                              std::true_type,
                              std::false_type>::type;
 
template <class T>
using is_unsigned_int =
    typename std::conditional<is_integral<T>::value &&
                                  std::is_unsigned<T>::value,
                              std::true_type,
                              std::false_type>::type;
 
template <class T>
using to_unsigned = typename std::conditional<is_signed_int<T>::value,
                                              std::make_unsigned<T>,
                                              std::common_type<T>>::type;
 
#endif
 
template <class T>
using is_signed_int_t = std::enable_if_t<is_signed_int<T>::value>;
 
template <class T>
using is_unsigned_int_t = std::enable_if_t<is_unsigned_int<T>::value>;
 
template <class T> using to_unsigned_t = typename to_unsigned<T>::type;
 
}  // namespace internal
 
}  // namespace atcoder
 
 
namespace atcoder {
 
namespace internal {
 
struct modint_base {};
struct static_modint_base : modint_base {};
 
template <class T> using is_modint = std::is_base_of<modint_base, T>;
template <class T> using is_modint_t = std::enable_if_t<is_modint<T>::value>;
 
}  // namespace internal
 
template <int m, std::enable_if_t<(1 <= m)>* = nullptr>
struct static_modint : internal::static_modint_base {
    using mint = static_modint;
 
  public:
    static constexpr int mod() { return m; }
    static mint raw(int v) {
        mint x;
        x._v = v;
        return x;
    }
 
    static_modint() : _v(0) {}
    template <class T, internal::is_signed_int_t<T>* = nullptr>
    static_modint(T v) {
        long long x = (long long)(v % (long long)(umod()));
        if (x < 0) x += umod();
        _v = (unsigned int)(x);
    }
    template <class T, internal::is_unsigned_int_t<T>* = nullptr>
    static_modint(T v) {
        _v = (unsigned int)(v % umod());
    }
 
    unsigned int val() const { return _v; }
 
    mint& operator++() {
        _v++;
        if (_v == umod()) _v = 0;
        return *this;
    }
    mint& operator--() {
        if (_v == 0) _v = umod();
        _v--;
        return *this;
    }
    mint operator++(int) {
        mint result = *this;
        ++*this;
        return result;
    }
    mint operator--(int) {
        mint result = *this;
        --*this;
        return result;
    }
 
    mint& operator+=(const mint& rhs) {
        _v += rhs._v;
        if (_v >= umod()) _v -= umod();
        return *this;
    }
    mint& operator-=(const mint& rhs) {
        _v -= rhs._v;
        if (_v >= umod()) _v += umod();
        return *this;
    }
    mint& operator*=(const mint& rhs) {
        unsigned long long z = _v;
        z *= rhs._v;
        _v = (unsigned int)(z % umod());
        return *this;
    }
    mint& operator/=(const mint& rhs) { return *this = *this * rhs.inv(); }
 
    mint operator+() const { return *this; }
    mint operator-() const { return mint() - *this; }
 
    mint pow(long long n) const {
        assert(0 <= n);
        mint x = *this, r = 1;
        while (n) {
            if (n & 1) r *= x;
            x *= x;
            n >>= 1;
        }
        return r;
    }
    mint inv() const {
        if (prime) {
            assert(_v);
            return pow(umod() - 2);
        } else {
            auto eg = internal::inv_gcd(_v, m);
            assert(eg.first == 1);
            return eg.second;
        }
    }
 
    friend mint operator+(const mint& lhs, const mint& rhs) {
        return mint(lhs) += rhs;
    }
    friend mint operator-(const mint& lhs, const mint& rhs) {
        return mint(lhs) -= rhs;
    }
    friend mint operator*(const mint& lhs, const mint& rhs) {
        return mint(lhs) *= rhs;
    }
    friend mint operator/(const mint& lhs, const mint& rhs) {
        return mint(lhs) /= rhs;
    }
    friend bool operator==(const mint& lhs, const mint& rhs) {
        return lhs._v == rhs._v;
    }
    friend bool operator!=(const mint& lhs, const mint& rhs) {
        return lhs._v != rhs._v;
    }
 
  private:
    unsigned int _v;
    static constexpr unsigned int umod() { return m; }
    static constexpr bool prime = internal::is_prime<m>;
};
 
template <int id> struct dynamic_modint : internal::modint_base {
    using mint = dynamic_modint;
 
  public:
    static int mod() { return (int)(bt.umod()); }
    static void set_mod(int m) {
        assert(1 <= m);
        bt = internal::barrett(m);
    }
    static mint raw(int v) {
        mint x;
        x._v = v;
        return x;
    }
 
    dynamic_modint() : _v(0) {}
    template <class T, internal::is_signed_int_t<T>* = nullptr>
    dynamic_modint(T v) {
        long long x = (long long)(v % (long long)(mod()));
        if (x < 0) x += mod();
        _v = (unsigned int)(x);
    }
    template <class T, internal::is_unsigned_int_t<T>* = nullptr>
    dynamic_modint(T v) {
        _v = (unsigned int)(v % mod());
    }
 
    unsigned int val() const { return _v; }
 
    mint& operator++() {
        _v++;
        if (_v == umod()) _v = 0;
        return *this;
    }
    mint& operator--() {
        if (_v == 0) _v = umod();
        _v--;
        return *this;
    }
    mint operator++(int) {
        mint result = *this;
        ++*this;
        return result;
    }
    mint operator--(int) {
        mint result = *this;
        --*this;
        return result;
    }
 
    mint& operator+=(const mint& rhs) {
        _v += rhs._v;
        if (_v >= umod()) _v -= umod();
        return *this;
    }
    mint& operator-=(const mint& rhs) {
        _v += mod() - rhs._v;
        if (_v >= umod()) _v -= umod();
        return *this;
    }
    mint& operator*=(const mint& rhs) {
        _v = bt.mul(_v, rhs._v);
        return *this;
    }
    mint& operator/=(const mint& rhs) { return *this = *this * rhs.inv(); }
 
    mint operator+() const { return *this; }
    mint operator-() const { return mint() - *this; }
 
    mint pow(long long n) const {
        assert(0 <= n);
        mint x = *this, r = 1;
        while (n) {
            if (n & 1) r *= x;
            x *= x;
            n >>= 1;
        }
        return r;
    }
    mint inv() const {
        auto eg = internal::inv_gcd(_v, mod());
        assert(eg.first == 1);
        return eg.second;
    }
 
    friend mint operator+(const mint& lhs, const mint& rhs) {
        return mint(lhs) += rhs;
    }
    friend mint operator-(const mint& lhs, const mint& rhs) {
        return mint(lhs) -= rhs;
    }
    friend mint operator*(const mint& lhs, const mint& rhs) {
        return mint(lhs) *= rhs;
    }
    friend mint operator/(const mint& lhs, const mint& rhs) {
        return mint(lhs) /= rhs;
    }
    friend bool operator==(const mint& lhs, const mint& rhs) {
        return lhs._v == rhs._v;
    }
    friend bool operator!=(const mint& lhs, const mint& rhs) {
        return lhs._v != rhs._v;
    }
 
  private:
    unsigned int _v;
    static internal::barrett bt;
    static unsigned int umod() { return bt.umod(); }
};
template <int id> internal::barrett dynamic_modint<id>::bt = 998244353;
 
using modint998244353 = static_modint<998244353>;
using modint754974721 = static_modint<754974721>;
using modint1000000007 = static_modint<1000000007>;
using modint = dynamic_modint<-1>;
 
namespace internal {
 
template <class T>
using is_static_modint = std::is_base_of<internal::static_modint_base, T>;
 
template <class T>
using is_static_modint_t = std::enable_if_t<is_static_modint<T>::value>;
 
template <class> struct is_dynamic_modint : public std::false_type {};
template <int id>
struct is_dynamic_modint<dynamic_modint<id>> : public std::true_type {};
 
template <class T>
using is_dynamic_modint_t = std::enable_if_t<is_dynamic_modint<T>::value>;
 
}  // namespace internal
 
}  // namespace atcoder
 
 
#include <algorithm>
#include <array>
#include <cassert>
#include <type_traits>
#include <vector>
 
 
#ifdef _MSC_VER
#include <intrin.h>
#endif
 
namespace atcoder {
 
namespace internal {
 
int ceil_pow2(int n) {
    int x = 0;
    while ((1U << x) < (unsigned int)(n)) x++;
    return x;
}
 
int bsf(unsigned int n) {
#ifdef _MSC_VER
    unsigned long index;
    _BitScanForward(&index, n);
    return index;
#else
    return __builtin_ctz(n);
#endif
}
 
}  // namespace internal
 
}  // namespace atcoder
 
 
namespace atcoder {
 
namespace internal {
 
template <class mint, internal::is_static_modint_t<mint>* = nullptr>
void butterfly(std::vector<mint>& a) {
    static constexpr int g = internal::primitive_root<mint::mod()>;
    int n = int(a.size());
    int h = internal::ceil_pow2(n);
 
    static bool first = true;
    static mint sum_e[30];  // sum_e[i] = ies[0] * ... * ies[i - 1] * es[i]
    if (first) {
        first = false;
        mint es[30], ies[30];  // es[i]^(2^(2+i)) == 1
        int cnt2 = bsf(mint::mod() - 1);
        mint e = mint(g).pow((mint::mod() - 1) >> cnt2), ie = e.inv();
        for (int i = cnt2; i >= 2; i--) {
            es[i - 2] = e;
            ies[i - 2] = ie;
            e *= e;
            ie *= ie;
        }
        mint now = 1;
        for (int i = 0; i <= cnt2 - 2; i++) {
            sum_e[i] = es[i] * now;
            now *= ies[i];
        }
    }
    for (int ph = 1; ph <= h; ph++) {
        int w = 1 << (ph - 1), p = 1 << (h - ph);
        mint now = 1;
        for (int s = 0; s < w; s++) {
            int offset = s << (h - ph + 1);
            for (int i = 0; i < p; i++) {
                auto l = a[i + offset];
                auto r = a[i + offset + p] * now;
                a[i + offset] = l + r;
                a[i + offset + p] = l - r;
            }
            now *= sum_e[bsf(~(unsigned int)(s))];
        }
    }
}
 
template <class mint, internal::is_static_modint_t<mint>* = nullptr>
void butterfly_inv(std::vector<mint>& a) {
    static constexpr int g = internal::primitive_root<mint::mod()>;
    int n = int(a.size());
    int h = internal::ceil_pow2(n);
 
    static bool first = true;
    static mint sum_ie[30];  // sum_ie[i] = es[0] * ... * es[i - 1] * ies[i]
    if (first) {
        first = false;
        mint es[30], ies[30];  // es[i]^(2^(2+i)) == 1
        int cnt2 = bsf(mint::mod() - 1);
        mint e = mint(g).pow((mint::mod() - 1) >> cnt2), ie = e.inv();
        for (int i = cnt2; i >= 2; i--) {
            es[i - 2] = e;
            ies[i - 2] = ie;
            e *= e;
            ie *= ie;
        }
        mint now = 1;
        for (int i = 0; i <= cnt2 - 2; i++) {
            sum_ie[i] = ies[i] * now;
            now *= es[i];
        }
    }
 
    for (int ph = h; ph >= 1; ph--) {
        int w = 1 << (ph - 1), p = 1 << (h - ph);
        mint inow = 1;
        for (int s = 0; s < w; s++) {
            int offset = s << (h - ph + 1);
            for (int i = 0; i < p; i++) {
                auto l = a[i + offset];
                auto r = a[i + offset + p];
                a[i + offset] = l + r;
                a[i + offset + p] =
                    (unsigned long long)(mint::mod() + l.val() - r.val()) *
                    inow.val();
            }
            inow *= sum_ie[bsf(~(unsigned int)(s))];
        }
    }
}
 
}  // namespace internal
 
template <class mint, internal::is_static_modint_t<mint>* = nullptr>
std::vector<mint> convolution(std::vector<mint> a, std::vector<mint> b) {
    int n = int(a.size()), m = int(b.size());
    if (!n || !m) return {};
    if (std::min(n, m) <= 60) {
        if (n < m) {
            std::swap(n, m);
            std::swap(a, b);
        }
        std::vector<mint> ans(n + m - 1);
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                ans[i + j] += a[i] * b[j];
            }
        }
        return ans;
    }
    int z = 1 << internal::ceil_pow2(n + m - 1);
    a.resize(z);
    internal::butterfly(a);
    b.resize(z);
    internal::butterfly(b);
    for (int i = 0; i < z; i++) {
        a[i] *= b[i];
    }
    internal::butterfly_inv(a);
    a.resize(n + m - 1);
    mint iz = mint(z).inv();
    for (int i = 0; i < n + m - 1; i++) a[i] *= iz;
    return a;
}
 
template <unsigned int mod = 998244353,
          class T,
          std::enable_if_t<internal::is_integral<T>::value>* = nullptr>
std::vector<T> convolution(const std::vector<T>& a, const std::vector<T>& b) {
    int n = int(a.size()), m = int(b.size());
    if (!n || !m) return {};
 
    using mint = static_modint<mod>;
    std::vector<mint> a2(n), b2(m);
    for (int i = 0; i < n; i++) {
        a2[i] = mint(a[i]);
    }
    for (int i = 0; i < m; i++) {
        b2[i] = mint(b[i]);
    }
    auto c2 = convolution(move(a2), move(b2));
    std::vector<T> c(n + m - 1);
    for (int i = 0; i < n + m - 1; i++) {
        c[i] = c2[i].val();
    }
    return c;
}
 
std::vector<long long> convolution_ll(const std::vector<long long>& a,
                                      const std::vector<long long>& b) {
    int n = int(a.size()), m = int(b.size());
    if (!n || !m) return {};
 
    static constexpr unsigned long long MOD1 = 754974721;  // 2^24
    static constexpr unsigned long long MOD2 = 167772161;  // 2^25
    static constexpr unsigned long long MOD3 = 469762049;  // 2^26
    static constexpr unsigned long long M2M3 = MOD2 * MOD3;
    static constexpr unsigned long long M1M3 = MOD1 * MOD3;
    static constexpr unsigned long long M1M2 = MOD1 * MOD2;
    static constexpr unsigned long long M1M2M3 = MOD1 * MOD2 * MOD3;
 
    static constexpr unsigned long long i1 =
        internal::inv_gcd(MOD2 * MOD3, MOD1).second;
    static constexpr unsigned long long i2 =
        internal::inv_gcd(MOD1 * MOD3, MOD2).second;
    static constexpr unsigned long long i3 =
        internal::inv_gcd(MOD1 * MOD2, MOD3).second;
 
    auto c1 = convolution<MOD1>(a, b);
    auto c2 = convolution<MOD2>(a, b);
    auto c3 = convolution<MOD3>(a, b);
 
    std::vector<long long> c(n + m - 1);
    for (int i = 0; i < n + m - 1; i++) {
        unsigned long long x = 0;
        x += (c1[i] * i1) % MOD1 * M2M3;
        x += (c2[i] * i2) % MOD2 * M1M3;
        x += (c3[i] * i3) % MOD3 * M1M2;
        long long diff =
            c1[i] - internal::safe_mod((long long)(x), (long long)(MOD1));
        if (diff < 0) diff += MOD1;
        static constexpr unsigned long long offset[5] = {
            0, 0, M1M2M3, 2 * M1M2M3, 3 * M1M2M3};
        x -= offset[diff % 5];
        c[i] = x;
    }
 
    return c;
}
 
}  // namespace atcoder

// Computes a^b mod MOD by binary exponentiation.
ll binpow(ll a,ll b,ll MOD){
    ll ans=1;
    a%=MOD;
    while(b){
        if(b&1)
            ans=(ans*a)%MOD;
        b/=2;
        a=(a*a)%MOD;
    }
    return ans;
}
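// Modular inverse via Fermat's little theorem; assumes MOD is prime and a is not a multiple of MOD.
ll inverse(ll a,ll MOD){
    return binpow(a,MOD-2,MOD);
}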
 
using namespace std;
using mint = atcoder::modint998244353;
const ll MOD = 998244353;
const ll MAX=100100; 
int main(){
  ios::sync_with_stdio(false);
  cin.tie(0);
  
  int tc=1; cin>>tc;
  while(tc--){
      ll n,x; cin>>x>>n;
      vector<ll> a(n);
      ll sum=0,z=n;
      for(ll i=0;i<n;i++){
        cin>>a[i]; 
        sum+=a[i]; 
      }
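      // Represent pile i by the polynomial 1 + x^{a[i]}: coefficient 1 at
      // exponent 0 (pile not chosen) and at exponent a[i] (pile chosen).
      vector<vector<ll>> track;
      for(ll i=0;i<n;i++){
        vector<ll> now;
        now.push_back(1);
        for(ll j=1;j<a[i];j++){
            now.push_back(0);
        }
        now.push_back(1);
        track.push_back(now);
      }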
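      // Multiply the polynomials pairwise, round by round, until one product remains.
      while(1){
        vector<vector<ll>> anot;
        ll len=track.size();
        if(n==1){
            break;
        }
        // Pad with the constant polynomial 1 so that every polynomial has a partner.
        if(n&1){
            n++;
            vector<ll> z={1};
            track.push_back(z);
        }
        for(ll i=1;i<n;i+=2){
            auto cur=atcoder::convolution(track[i-1],track[i]);
            anot.push_back(cur);
        }
        track=anot;
        n=n/2;
      }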
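      // track[0][pos] = number of subsets whose chosen piles sum to pos (mod MOD);
      // each such subset contributes |sum - 2*pos| to the total.
      ll pos=0,ans=0;
      for(auto it:track[0]){
        ans=(ans+abs(sum-2*pos)*it)%MOD;
        pos++;
      }
      // Divide by 2^N (z holds the original n) to turn the total into an expectation.
      ll y=binpow(2,z,MOD);
      ans=(ans*inverse(y,MOD))%MOD;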
      cout<<ans<<"\n";

  }
} 
Tester-2's Solution
// Jai Shree Ram  
  
#include<bits/stdc++.h>
using namespace std;

#define rep(i,a,n)     for(int i=a;i<n;i++)
#define ll             long long
#define int            long long
#define pb             push_back
#define all(v)         v.begin(),v.end()
#define endl           "\n"
#define x              first
#define y              second
#define gcd(a,b)       __gcd(a,b)
#define mem1(a)        memset(a,-1,sizeof(a))
#define mem0(a)        memset(a,0,sizeof(a))
#define sz(a)          (int)a.size()
#define pii            pair<int,int>
#define hell           1000000007
#define elasped_time   1.0 * clock() / CLOCKS_PER_SEC



template<typename T1,typename T2>istream& operator>>(istream& in,pair<T1,T2> &a){in>>a.x>>a.y;return in;}
template<typename T1,typename T2>ostream& operator<<(ostream& out,pair<T1,T2> a){out<<a.x<<" "<<a.y;return out;}
template<typename T,typename T1>T maxs(T &a,T1 b){if(b>a)a=b;return a;}
template<typename T,typename T1>T mins(T &a,T1 b){if(b<a)a=b;return a;}




const int MOD = 998244353;

 
struct mod_int {
    int val;
 
    mod_int(long long v = 0) {
        if (v < 0)
            v = v % MOD + MOD;
 
        if (v >= MOD)
            v %= MOD;
 
        val = v;
    }
 
    static int mod_inv(int a, int m = MOD) {
        int g = m, r = a, x = 0, y = 1;
 
        while (r != 0) {
            int q = g / r;
            g %= r; swap(g, r);
            x -= q * y; swap(x, y);
        }
 
        return x < 0 ? x + m : x;
    }
 
    explicit operator int() const {
        return val;
    }
 
    mod_int& operator+=(const mod_int &other) {
        val += other.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
 
    mod_int& operator-=(const mod_int &other) {
        val -= other.val;
        if (val < 0) val += MOD;
        return *this;
    }
 
    static unsigned fast_mod(uint64_t x, unsigned m = MOD) {
           #if !defined(_WIN32) || defined(_WIN64)
                return x % m;
           #endif
           unsigned x_high = x >> 32, x_low = (unsigned) x;
           unsigned quot, rem;
           asm("divl %4\n"
            : "=a" (quot), "=d" (rem)
            : "d" (x_high), "a" (x_low), "r" (m));
           return rem;
    }
 
    mod_int& operator*=(const mod_int &other) {
        val = fast_mod((uint64_t) val * other.val);
        return *this;
    }
 
    mod_int& operator/=(const mod_int &other) {
        return *this *= other.inv();
    }
 
    friend mod_int operator+(const mod_int &a, const mod_int &b) { return mod_int(a) += b; }
    friend mod_int operator-(const mod_int &a, const mod_int &b) { return mod_int(a) -= b; }
    friend mod_int operator*(const mod_int &a, const mod_int &b) { return mod_int(a) *= b; }
    friend mod_int operator/(const mod_int &a, const mod_int &b) { return mod_int(a) /= b; }
 
    mod_int& operator++() {
        val = val == MOD - 1 ? 0 : val + 1;
        return *this;
    }
 
    mod_int& operator--() {
        val = val == 0 ? MOD - 1 : val - 1;
        return *this;
    }
 
    mod_int operator++(int32_t) { mod_int before = *this; ++*this; return before; }
    mod_int operator--(int32_t) { mod_int before = *this; --*this; return before; }
 	
    mod_int operator-() const {
        return val == 0 ? 0 : MOD - val;
    }
    bool operator < (const mod_int &other) const {return this->val < other.val;}
 
    bool operator==(const mod_int &other) const { return val == other.val; }
    bool operator!=(const mod_int &other) const { return val != other.val; }
 
    mod_int inv() const {
        return mod_inv(val);
    }
 
    mod_int pow(long long p) const {
        assert(p >= 0);
        mod_int a = *this, result = 1;
 
        while (p > 0) {
            if (p & 1)
                result *= a;
 
            a *= a;
            p >>= 1;
        }
 
        return result;
    }
 
    friend ostream& operator<<(ostream &stream, const mod_int &m) {
        return stream << m.val;
    }
    friend istream& operator >> (istream &stream, mod_int &m) {
        return stream>>m.val;   
    }
};


typedef mod_int base;


namespace algebra {
	const int inf = 1e9;
	const int magic = 200; // threshold for sizes to run the naive algo
	
	namespace NTT {
		vector<mod_int> roots = {0, 1};
		vector<int> bit_reverse;
		int max_size = -1;
		mod_int root;
	 
		bool is_power_of_two(int n) {
			return (n & (n - 1)) == 0;
		}
	 
		int round_up_power_two(int n) {
			while (n & (n - 1))
				n = (n | (n - 1)) + 1;
	 
			return max(n, 1LL);
		}
	 
		// Given n (a power of two), finds k such that n == 1 << k.
		int get_length(int n) {
			assert(is_power_of_two(n));
			return __builtin_ctz(n);
		}
	 
		// Rearranges the indices to be sorted by lowest bit first, then second lowest, etc., rather than highest bit first.
		// This makes even-odd div-conquer much easier.
		void bit_reorder(int n, vector<mod_int> &values) {
			if ((int) bit_reverse.size() != n) {
				bit_reverse.assign(n, 0);
				int length = get_length(n);
	 
				for (int i = 0; i < n; i++)
					bit_reverse[i] = (bit_reverse[i >> 1] >> 1) + ((i & 1) << (length - 1));
			}
	 
			for (int i = 0; i < n; i++)
				if (i < bit_reverse[i])
					swap(values[i], values[bit_reverse[i]]);
		}
	 
		void find_root() {
			max_size = 1 << __builtin_ctz(MOD - 1);
			root = 2;
	 
			// Find a max_size-th primitive root of MOD.
			while (!(root.pow(max_size) == 1 && root.pow(max_size / 2) != 1))
				root++;
		}
	 
		void prepare_roots(int n) {
			if (max_size < 0)
				find_root();
	 
			assert(n <= max_size);
	 
			if ((int) roots.size() >= n)
				return;
	 
			int length = get_length(roots.size());
			roots.resize(n);
	 
			// The roots array is set up such that for a given power of two n >= 2, roots[n / 2] through roots[n - 1] are
			// the first half of the n-th primitive roots of MOD.
			while (1 << length < n) {
				// z is a 2^(length + 1)-th primitive root of MOD.
				mod_int z = root.pow(max_size >> (length + 1));
	 
				for (int i = 1 << (length - 1); i < 1 << length; i++) {
					roots[2 * i] = roots[i];
					roots[2 * i + 1] = roots[i] * z;
				}
	 
				length++;
			}
		}
	 
		void fft_iterative(int N, vector<mod_int> &values) {
			assert(is_power_of_two(N));
			prepare_roots(N);
			bit_reorder(N, values);
	 
			for (int n = 1; n < N; n *= 2)
				for (int start = 0; start < N; start += 2 * n)
					for (int i = 0; i < n; i++) {
						mod_int even = values[start + i];
						mod_int odd = values[start + n + i] * roots[n + i];
						values[start + n + i] = even - odd;
						values[start + i] = even + odd;
					}
		}
	 
		const int FFT_CUTOFF = magic;
	 
		vector<mod_int> mod_multiply(vector<mod_int> left, vector<mod_int> right) {
			int n = left.size();
			int m = right.size();
	 
			// Brute force when either n or m is small enough.
			if (min(n, m) < FFT_CUTOFF) {
				const uint64_t ULL_BOUND = numeric_limits<uint64_t>::max() - (uint64_t) MOD * MOD;
				vector<uint64_t> result(n + m - 1, 0);
	 
				for (int i = 0; i < n; i++)
					for (int j = 0; j < m; j++) {
						result[i + j] += (uint64_t) ((int) left[i]) * ((int) right[j]);
	 
						if (result[i + j] > ULL_BOUND)
							result[i + j] %= MOD;
					}
	 
				for (uint64_t &x : result)
					if (x >= MOD)
						x %= MOD;
	 
				return vector<mod_int>(result.begin(), result.end());
			}
	 
			int N = round_up_power_two(n + m - 1);
			left.resize(N);
			right.resize(N);
	 
			bool equal = left == right;
			fft_iterative(N, left);
	 
			if (equal)
				right = left;
			else
				fft_iterative(N, right);
	 
			mod_int inv_N = mod_int(N).inv();
	 
			for (int i = 0; i < N; i++)
				left[i] *= right[i] * inv_N;
	 
			reverse(left.begin() + 1, left.end());
			fft_iterative(N, left);
			left.resize(n + m - 1);
			return left;
		}
		template<typename T>
		void mul(vector<T> &a,const vector<T> &b){
			a = mod_multiply(a,b);
		}
	}
	
	
	template<typename T>
	struct poly {
		vector<T> a;
		
		void normalize() { // get rid of leading zeroes
			while(!a.empty() && a.back() == T(0)) {
				a.pop_back();
			}
		}
		
		poly(){}
		poly(T a0) : a{a0}{normalize();}
		poly(vector<T> t) : a(t){normalize();}
		
		poly operator += (const poly &t) {
			a.resize(max(a.size(), t.a.size()));
			for(size_t i = 0; i < t.a.size(); i++) {
				a[i] += t.a[i];
			}
			normalize();
			return *this;
		}
		poly operator -= (const poly &t) {
			a.resize(max(a.size(), t.a.size()));
			for(size_t i = 0; i < t.a.size(); i++) {
				a[i] -= t.a[i];
			}
			normalize();
			return *this;
		}

		bool operator < (const poly &rhs)const {return this->a.size() < rhs.a.size();}
		poly operator + (const poly &t) const {return poly(*this) += t;}
		poly operator - (const poly &t) const {return poly(*this) -= t;}
		
		poly mod_xk(size_t k) const { // get same polynomial mod x^k
			k = min(k, a.size());
			return vector<T>(begin(a), begin(a) + k);
		}
		poly mul_xk(size_t k) const { // multiply by x^k
			poly res(*this);
			res.a.insert(begin(res.a), k, 0);
			return res;
		}
		poly div_xk(size_t k) const { // divide by x^k, dropping coefficients
			k = min(k, a.size());
			return vector<T>(begin(a) + k, end(a));
		}
		poly substr(size_t l, size_t r) const { // return mod_xk(r).div_xk(l)
			l = min(l, a.size());
			r = min(r, a.size());
			return vector<T>(begin(a) + l, begin(a) + r);
		}
		poly inv(size_t n) const { // get inverse series mod x^n
			assert(!is_zero());
			poly ans = a[0].inv();
			size_t a = 1;
			while(a < n) {
				poly C = (ans * mod_xk(2 * a)).substr(a, 2 * a);
				ans -= (ans * C).mod_xk(a).mul_xk(a);
				a *= 2;
			}
			ans.a.resize(n);
			return ans;
		}
		poly sqrt(size_t n) const { 
			int m = a.size();
			poly b ={1};
			size_t s = 1;
			while(s<n){
				vector<T> x(a.begin(),a.begin()+min(a.size(),b.a.size()<<1));
				b.a.resize(b.a.size()<<1);
				poly c(x);
				c *=b.inv(b.a.size());
				T inv2 = static_cast<T>(1)/static_cast<T>(2);
				for(int i = b.a.size()>>1;i<min(c.a.size(),b.a.size());i++){
					b.a[i] = c.a[i]*inv2;
				}
				s *= 2;
			}
			b.a.resize(n);
			return b;
		}
		
		poly operator *= (const poly &t) {NTT::mul(a, t.a); normalize(); return *this;}
		poly operator * (const poly &t) const {return poly(*this) *= t;}
		
		poly reverse(size_t n, bool rev = 0) const { // reverses and leaves only n terms
			poly res(*this);
			if(rev) { // If rev = 1 then tail goes to head
				res.a.resize(max(n, res.a.size()));
			}
			std::reverse(res.a.begin(), res.a.end());
			return res.mod_xk(n);
		}
		
		pair<poly, poly> divmod_slow(const poly &b) const { // when divisor or quotient is small
			vector<T> A(a);
			vector<T> res;
			while(A.size() >= b.a.size()) {
				res.push_back(A.back() / b.a.back());
				if(res.back() != T(0)) {
					for(size_t i = 0; i < b.a.size(); i++) {
						A[A.size() - i - 1] -= res.back() * b.a[b.a.size() - i - 1];
					}
				}
				A.pop_back();
			}
			std::reverse(begin(res), end(res));
			return {res, A};
		}
		
		pair<poly, poly> divmod(const poly &b) const { // returns quotient and remainder of a divided by b
			if(deg() < b.deg()) {
				return {poly{0}, *this};
			}
			int d = deg() - b.deg();
			if(min(d, b.deg()) < magic) {
				return divmod_slow(b);
			}
			poly D = (reverse(d + 1) * b.reverse(d + 1).inv(d + 1)).mod_xk(d + 1).reverse(d + 1, 1);
			return {D, *this - D * b};
		}
		
		poly operator / (const poly &t) const {return divmod(t).first;}
		poly operator % (const poly &t) const {return divmod(t).second;}
		poly operator /= (const poly &t) {return *this = divmod(t).first;}
		poly operator %= (const poly &t) {return *this = divmod(t).second;}
		poly operator *= (const T &x) {
			for(auto &it: a) {
				it *= x;
			}
			normalize();
			return *this;
		}
		poly operator /= (const T &x) {
			for(auto &it: a) {
				it /= x;
			}
			normalize();
			return *this;
		}
		poly operator * (const T &x) const {return poly(*this) *= x;}
		poly operator / (const T &x) const {return poly(*this) /= x;}
		
		void print() const {
			for(auto it: a) {
				cout << it << ' ';
			}
			cout << endl;
		}
		T eval(T x) const { // evaluates in single point x
			T res(0);
			for(int i = sz(a) - 1; i >= 0; i--) {
				res *= x;
				res += a[i];
			}
			return res;
		}
		
		T& lead() { // leading coefficient
			return a.back();
		}
		int deg() const { // degree
			return a.empty() ? -inf : a.size() - 1;
		}
		bool is_zero() const { // is polynomial zero
			return a.empty();
		}
		T operator [](int idx) const {
			return idx >= (int)a.size() || idx < 0 ? T(0) : a[idx];
		}
		
		T& coef(size_t idx) { // mutable reference at coefficient
			return a[idx];
		}
		bool operator == (const poly &t) const {return a == t.a;}
		bool operator != (const poly &t) const {return a != t.a;}
		
		poly deriv() { // calculate derivative
			vector<T> res;
			for(int i = 1; i <= deg(); i++) {
				res.push_back(T(i) * a[i]);
			}
			return res;
		}
		poly integr() { // calculate integral with C = 0
			vector<T> res = {0};
			for(int i = 0; i <= deg(); i++) {
				res.push_back(a[i] / T(i + 1));
			}
			return res;
		}
		size_t leading_xk() const { // Let p(x) = x^k * t(x), return k
			if(is_zero()) {
				return inf;
			}
			int res = 0;
			while(a[res] == T(0)) {
				res++;
			}
			return res;
		}
		poly log(size_t n) { // calculate log p(x) mod x^n
			assert(a[0] == T(1));
			return (deriv().mod_xk(n) * inv(n)).integr().mod_xk(n);
		}
		poly exp(size_t n) { // calculate exp p(x) mod x^n
			if(is_zero()) {
				return T(1);
			}
			assert(a[0] == T(0));
			poly ans = T(1);
			size_t a = 1;
			while(a < n) {
				poly C = ans.log(2 * a).div_xk(a) - substr(a, 2 * a);
				ans -= (ans * C).mod_xk(a).mul_xk(a);
				a *= 2;
			}
			return ans.mod_xk(n);
			
		}
		poly pow_slow(size_t k, size_t n) { // if k is small
			return k ? k % 2 ? (*this * pow_slow(k - 1, n)).mod_xk(n) : (*this * *this).mod_xk(n).pow_slow(k / 2, n) : T(1);
		}
		poly pow(size_t k, size_t n) { // calculate p(x)^k mod x^n
			if(is_zero()) {
				return *this;
			}
			if(k < magic) {
				return pow_slow(k, n);
			}
			int i = leading_xk();
			T j = a[i];
			poly t = div_xk(i) / j;
			return bpow(j, k) * (t.log(n) * T(k)).exp(n).mul_xk(i * k).mod_xk(n);
		}
		poly mulx(T x) { // component-wise multiplication with x^k
			T cur = 1;
			poly res(*this);
			for(int i = 0; i <= deg(); i++) {
				res.coef(i) *= cur;
				cur *= x;
			}
			return res;
		}
		poly mulx_sq(T x) { // component-wise multiplication with x^{k^2}
			T cur = x;
			T total = 1;
			T xx = x * x;
			poly res(*this);
			for(int i = 0; i <= deg(); i++) {
				res.coef(i) *= total;
				total *= cur;
				cur *= xx;
			}
			return res;
		}
		vector<T> chirpz_even(T z, int n) { // P(1), P(z^2), P(z^4), ..., P(z^2(n-1))
			int m = deg();
			if(is_zero()) {
				return vector<T>(n, 0);
			}
			vector<T> vv(m + n);
			T zi = z.inv();
			T zz = zi * zi;
			T cur = zi;
			T total = 1;
			for(int i = 0; i <= max(n - 1, m); i++) {
				if(i <= m) {vv[m - i] = total;}
				if(i < n) {vv[m + i] = total;}
				total *= cur;
				cur *= zz;
			}
			poly w = (mulx_sq(z) * vv).substr(m, m + n).mulx_sq(z);
			vector<T> res(n);
			for(int i = 0; i < n; i++) {
				res[i] = w[i];
			}
			return res;
		}
		vector<T> chirpz(T z, int n) { // P(1), P(z), P(z^2), ..., P(z^(n-1))
			auto even = chirpz_even(z, (n + 1) / 2);
			auto odd = mulx(z).chirpz_even(z, n / 2);
			vector<T> ans(n);
			for(int i = 0; i < n / 2; i++) {
				ans[2 * i] = even[i];
				ans[2 * i + 1] = odd[i];
			}
			if(n % 2 == 1) {
				ans[n - 1] = even.back();
			}
			return ans;
		}
		template<typename iter>
		vector<T> eval(vector<poly> &tree, int v, iter l, iter r) { // auxiliary evaluation function
			if(r - l == 1) {
				return {eval(*l)};
			} else {
				auto m = l + (r - l) / 2;
				auto A = (*this % tree[2 * v]).eval(tree, 2 * v, l, m);
				auto B = (*this % tree[2 * v + 1]).eval(tree, 2 * v + 1, m, r);
				A.insert(end(A), begin(B), end(B));
				return A;
			}
		}
		vector<T> eval(vector<T> x) { // evaluate polynomial in (x1, ..., xn)
			int n = x.size();
			if(is_zero()) {
				return vector<T>(n, T(0));
			}
			vector<poly> tree(4 * n);
			build(tree, 1, begin(x), end(x));
			return eval(tree, 1, begin(x), end(x));
		}
		template<typename iter>
		poly inter(vector<poly> &tree, int v, iter l, iter r, iter ly, iter ry) { // auxiliary interpolation function
			if(r - l == 1) {
				return {*ly / a[0]};
			} else {
				auto m = l + (r - l) / 2;
				auto my = ly + (ry - ly) / 2;
				auto A = (*this % tree[2 * v]).inter(tree, 2 * v, l, m, ly, my);
				auto B = (*this % tree[2 * v + 1]).inter(tree, 2 * v + 1, m, r, my, ry);
				return A * tree[2 * v + 1] + B * tree[2 * v];
			}
		}
	};
	template<typename T>
	poly<T> operator * (const T& a, const poly<T>& b) {
		return b * a;
	}
	
	template<typename T>
	poly<T> xk(int k) { // return x^k
		return poly<T>{1}.mul_xk(k);
	}

	template<typename T>
	T resultant(poly<T> a, poly<T> b) { // computes resultant of a and b
		if(b.is_zero()) {
			return 0;
		} else if(b.deg() == 0) {
			return bpow(b.lead(), a.deg());
		} else {
			int pw = a.deg();
			a %= b;
			pw -= a.deg();
			T mul = bpow(b.lead(), pw) * T((b.deg() & a.deg() & 1) ? -1 : 1);
			T ans = resultant(b, a);
			return ans * mul;
		}
	}
	template<typename iter>
	poly<typename iter::value_type> kmul(iter L, iter R) { // computes (x-a1)(x-a2)...(x-an) without building tree
		if(R - L == 1) {
			return vector<typename iter::value_type>{-*L, 1};
		} else {
			iter M = L + (R - L) / 2;
			return kmul(L, M) * kmul(M, R);
		}
	}
	template<typename T, typename iter>
	poly<T> build(vector<poly<T>> &res, int v, iter L, iter R) { // builds evaluation tree for (x-a1)(x-a2)...(x-an)
		if(R - L == 1) {
			return res[v] = vector<T>{-*L, 1};
		} else {
			iter M = L + (R - L) / 2;
			return res[v] = build(res, 2 * v, L, M) * build(res, 2 * v + 1, M, R);
		}
	}
	template<typename T>
	poly<T> inter(vector<T> x, vector<T> y) { // interpolates minimum polynomial from (xi, yi) pairs
		int n = x.size();
		vector<poly<T>> tree(4 * n);
		return build(tree, 1, begin(x), end(x)).deriv().inter(tree, 1, begin(x), end(x), begin(y), end(y));
	}
};

using namespace algebra;

typedef poly<base> polyn;

using namespace algebra;



// -------------------- Input Checker Start --------------------
 
long long readInt(long long l, long long r, char endd)
{
    long long x = 0;
    int cnt = 0, fi = -1;
    bool is_neg = false;
    while(true)
    {
        char g = getchar();
        if(g == '-')
        {
            assert(fi == -1);
            is_neg = true;
            continue;
        }
        if('0' <= g && g <= '9')
        {
            x *= 10;
            x += g - '0';
            if(cnt == 0)
                fi = g - '0';
            cnt++;
            assert(fi != 0 || cnt == 1);
            assert(fi != 0 || is_neg == false);
            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
        }
        else if(g == endd)
        {
            if(is_neg)
                x = -x;
            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(false);
            }
            return x;
        }
        else
        {
            assert(false);
        }
    }
}
 
string readString(int l, int r, char endd)
{
    string ret = "";
    int cnt = 0;
    while(true)
    {
        char g = getchar();
        assert(g != -1);
        if(g == endd)
            break;
        cnt++;
        ret += g;
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}
 
long long readIntSp(long long l, long long r) { return readInt(l, r, ' '); }
long long readIntLn(long long l, long long r) { return readInt(l, r, '\n'); }
string readStringLn(int l, int r) { return readString(l, r, '\n'); }
string readStringSp(int l, int r) { return readString(l, r, ' '); }
void readEOF() { assert(getchar() == EOF); }
 
vector<int> readVectorInt(int n, long long l, long long r)
{
    vector<int> a(n);
    for(int i = 0; i < n - 1; i++)
        a[i] = readIntSp(l, r);
    a[n - 1] = readIntLn(l, r);
    return a;
}
 
// -------------------- Input Checker End --------------------

int sum_x = 0;
int sum_n = 0;


int solve(){
 		int x = readIntSp(1,3e5);
 		int n = readIntLn(1,x);
 		multiset<pair<int,polyn>> ms;
 		vector<int> st = readVectorInt(n,1,x);
 		int sum = 0;
 		sum_n += n;
 		sum_x += x;
 		for(int i = 0; i < n; i++){
 			int c = st[i];
 			sum += c;
 			polyn p;
 			p.a.resize(c + 1);
 			p.a[c] = 1;
 			p.a[0] = 1;
 			ms.insert(make_pair(c + 1,p));
 		}
 		assert(sum == x);
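 		// Repeatedly multiply the two smallest polynomials (the multiset is keyed
 		// by an upper bound on their length) until a single product remains.
 		while(ms.size() > 1){
 			auto p1 = *ms.begin();
 			ms.erase(ms.find(p1));
 			auto p2 = *ms.begin();
 			ms.erase(ms.find(p2));
 			ms.insert({p1.x + p2.x,p1.y*p2.y});
 		}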
 		auto p = (*ms.begin()).y;
 		mod_int ans = 0;
 		mod_int val = 1/mod_int(2).pow(n);
 		for(int i = 0; i < p.a.size(); i++){
 			if(2*i < x)ans += p.a[i]*(x - 2*i);
 			else ans += p.a[i]*(2*i - x);
 		}
 		cout << ans*val << endl;

 return 0;
}
signed main(){
    ios_base::sync_with_stdio(0);cin.tie(0);cout.tie(0);
    //freopen("input.txt", "r", stdin);
    //freopen("output.txt", "w", stdout);
    #ifdef SIEVE
    sieve();
    #endif
    #ifdef NCR
    init();
    #endif
    int t = readIntLn(1,1e3);
    while(t--){
        solve();
    }
    assert(max(sum_x,sum_n) <= 3e5);
    return 0;
}
 

Can you please give a detailed proof of the time complexity (TC) of the algorithm?

  • I understand that multiplying P(x) (degree p) and Q(x) (degree q) takes max(p, q) * log(max(p, q)) time.
  • As per divide and conquer, T(n) = 2 * T(n/2) + max(sum(a_1 + … + a_{n/2}), sum(a_{n/2+1} + … + a_n)) * log(…)
  • I am not able to formally prove that it is S log S log N.
  • The question is just how the max() part changes as we keep going deeper into the divide.
  • Let d_i be the max degree of the polynomials being multiplied at division depth i.
  • Final TC naively: T(n) = 2 * T(n/2) + d_1 log d_1
  • T(n) = 2 * (2 * T(n/2^2) + d_2 log d_2) + d_1 log d_1
  • T(n) = 2^h * T(n/2^h) + d_1 log d_1 + 2 * d_2 log d_2 + 2^2 * d_3 log d_3 + … + 2^(h-1) * d_h log d_h
  • Let's approximate log(d_i) ~ log(X); also 2^h ~ n, or h ~ log n.
  • T(n) ~ (d_1 + 2 * d_2 + … + 2^(h-1) * d_h) * log X
  • How is d_1 + 2 * d_2 + … + 2^(h-1) * d_h ~ n * X? That is what I am not able to prove formally.
  • Note that d_i is the max degree of the two polynomials at depth i of the divide and conquer. Sum(a_i) = X, but the same does not hold for the d_i; the d_i are different from the a_i.
    Kindly help!

As I understand it, FFT gives you the S log S part. You just ignore coefficients with exponent greater than S, since they won’t help. The log N comes from doing log N products of two polynomials. Every time you do a product, you again get rid of coefficients with exponent greater than S, since they are useless.
This way you have log N FFTs, and you get O(S log S log N).

@lbm364dl We are not doing log n products. Treat it like merge sort: in total we do O(n) merges; it is just that the size of each merge depends on the depth in the “divide” step. So we get n log n TC instead of the n * n that naively multiplying the merge count by n would give.

  • Merge sort: T(n) = 2 * T(n/2) + O(n) → T(n) = n log n
  • This is not because log n merges are done. The total merge count is O(n), but the size merged at depth i is (just for example) n/2^i, and there are 2^i such merges.
  • Hence TC: sum of 2^i * (n/2^i) until 2^i ~ n, i.e. i ~ log n → sum of n over i up to ~ log n
  • That's how we get TC n log n; the merge count is still O(n).
  • Similarly, the product count is PC(n) = 2 * PC(n/2) + 1
  • PC(n) = 2 * (2 * PC(n/2^2) + 1) + 1
  • PC(n) = 2^h * PC(n/2^h) + 1 + 2 + 2^2 + … + 2^h
  • PC ~ 2^(h+1) - 1 ~ O(n). So the product is not called log n times but O(n) times (at most about 2*n by this bound), so your logic is not correct!
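
The counting argument above can be checked on a small input. The following is only a sketch, not code from the editorial: `MergeStats` and `simulate` are hypothetical names, and the simulation tracks polynomial lengths only (a pile of size a gives 1 + x^a, which has length a + 1), assuming the round-by-round pairwise merging used in the first tester's solution. An unpaired polynomial is simply carried over, which has the same length as multiplying it by the constant 1.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not from the editorial): statistics of the pairwise
// merge rounds, computed on polynomial lengths only.
struct MergeStats {
    long long products = 0;       // total number of polynomial products
    int rounds = 0;               // number of merge rounds
    long long max_round_len = 0;  // largest total output length of one round
};

MergeStats simulate(const std::vector<long long>& piles) {
    std::vector<long long> len;
    for (long long a : piles) {
        assert(a >= 1);        // pile sizes are positive
        len.push_back(a + 1);  // 1 + x^a has a + 1 coefficients
    }
    MergeStats st;
    while (len.size() > 1) {
        std::vector<long long> next;
        long long round_len = 0;
        for (std::size_t i = 0; i + 1 < len.size(); i += 2) {
            // A product of lengths p and q has length p + q - 1.
            long long out = len[i] + len[i + 1] - 1;
            next.push_back(out);
            round_len += out;
            st.products++;
        }
        // An unpaired polynomial is carried over to the next round unchanged.
        if (len.size() % 2 == 1) next.push_back(len.back());
        if (round_len > st.max_round_len) st.max_round_len = round_len;
        st.rounds++;
        len = next;
    }
    return st;
}
```

For the piles {3, 1, 4, 1, 5, 9, 2, 6} (S = 31, N = 8), `simulate` reports 7 products over 3 rounds, and no single round produces more than 35 = S + N/2 coefficients. So the number of products is indeed linear in N, while the total size handled per round stays close to S.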

Sorry you’re right, my logic was wrong. I’m also interested in knowing how they got that complexity then.
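
One possible way to settle the question in this thread is to bound the work per level of the merge tree by the sum of the degrees at that level, rather than by the maximum degree times the number of products. The derivation below is a sketch under stated assumptions, not the setter's published proof: it assumes round-by-round pairwise merging, an NTT product of degrees p and q costing O((p + q) log(p + q)), and A_i ≥ 1 (so N ≤ S). The symbols p_{k,j} and q_{k,j} are introduced here only for illustration.

```latex
% Notation (assumed for this sketch): at level k of the merge tree, the j-th
% product multiplies polynomials of degrees p_{k,j} and q_{k,j}. The groups of
% piles behind different products on the same level are disjoint.
\begin{align*}
\sum_{j} \bigl(p_{k,j} + q_{k,j}\bigr) &\le \sum_{i=1}^{N} A_i = S
  && \text{(each pile lies in at most one product per level)} \\
\text{cost of level } k
  &= \sum_{j} O\bigl((p_{k,j} + q_{k,j} + 1)\,\log(p_{k,j} + q_{k,j} + 1)\bigr) \\
  &\le O\bigl(\log(S + 1)\bigr) \sum_{j} \bigl(p_{k,j} + q_{k,j} + 1\bigr) \\
  &\le O\bigl((S + N)\log(S + 1)\bigr) = O(S \log S)
  && \text{(since } N \le S\text{)} \\
T &= \sum_{k=1}^{\lceil \log_2 N \rceil} O(S \log S) = O(S \log S \log N).
\end{align*}
```

In the notation of the question above, the quantity that matters at depth i is the sum of the degrees being multiplied, which is at most X, rather than 2^(i-1) * d_i. That sum does not grow with depth, which is where the single log N factor comes from.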