MAT - Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Practice

Setter: Vivek mishra
Tester: Aryan
Editorialist: Taranpreet Singh

DIFFICULTY

Medium

PREREQUISITES

Number Theoretic Transform, Probability, and Expected Value.

PROBLEM

There are K schools numbered from 1 to K. For a certain competition, each school i is invited with probability P_i, independently of the other schools. If invited, the school must send exactly r students chosen from the A_i students studying there.

Find the expected number of ways the teams can be formed. The probabilities are given modulo M = 998244353, and the answer must be reported modulo M.

QUICK EXPLANATION

  • Each school i is selected with probability P_i, in which case there are \displaystyle\binom{A_i}{r} ways to select its students. If it is not selected, there is only 1 way. So the expected number of ways for school i is \displaystyle P_i* \binom{A_i}{r} + (1-P_i)*1
  • Since the events are independent, the expected total number of ways is \displaystyle \prod_{i = 1}^K \bigg[ P_i * \binom{A_i}{r} + (1-P_i)\bigg]
  • To compute \displaystyle\binom{A_i}{r} for each A_i, we define the polynomial \displaystyle Q(x) = \prod_{i = 1}^r (x+i). Then \displaystyle\binom{A_i}{r} = \frac{Q(A_i - r)}{r!}
  • So we need to evaluate a polynomial of degree r at K points, which is the well-known multipoint evaluation problem.

EXPLANATION

Mathematics Section

Solving for K = 1

There's only one school. It is invited with probability P_1 and has A_1 students. There are two possibilities:

  • If the school is selected, there are \displaystyle\binom{A_1}{r} ways to select students. This happens with probability P_1.
  • If the school is not selected, no students are sent, which happens in exactly 1 way. This happens with probability 1-P_1.

Hence, the expected number of ways is given by \displaystyle P_1 * \binom{A_1}{r} + (1-P_1)
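As a quick illustration, take A_1 = 4, r = 2 and P_1 = 1/3. These values are chosen only for the example, and P_1 is written here as an ordinary fraction; the actual input gives it as a residue modulo M:

```latex
E[X_1] = \frac{1}{3}\binom{4}{2} + \left(1 - \frac{1}{3}\right)\cdot 1
       = \frac{1}{3}\cdot 6 + \frac{2}{3}
       = \frac{8}{3} \equiv 8 \cdot 3^{-1} \equiv 665496238 \pmod{M}
```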

Solving for K > 1

Let X_i denote the (random) number of ways to select students from the i-th school. By the Fundamental Principle of Counting, the total number of ways is \displaystyle\prod_{i = 1}^K X_i, so we need to compute \displaystyle E \bigg[\prod_{i = 1}^K X_i \bigg]

Since the schools are invited independently of each other, the X_i are independent random variables. For independent random variables X and Y, we know that E[X*Y] = E[X]*E[Y], which implies that
S = \displaystyle E \bigg[\prod_{i = 1}^K X_i \bigg] = \prod_{i = 1}^K E[X_i] = \prod_{i = 1}^K \bigg[P_i*\binom{A_i}{r} + (1-P_i)\bigg]
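The product formula translates directly into code. Below is a minimal sketch. It assumes each P_i is already given as a residue modulo M, and that the values \binom{A_i}{r} \bmod M have been computed beforehand by either of the methods described below. The function name `expected_ways` is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr std::uint64_t MOD = 998244353;  // M from the problem statement

// Hypothetical helper: P[i] = invitation probability of school i (mod MOD),
// C[i] = binom(A_i, r) mod MOD. Returns prod_i (P_i * C_i + (1 - P_i)) mod MOD.
std::uint64_t expected_ways(const std::vector<std::uint64_t>& P,
                            const std::vector<std::uint64_t>& C) {
    assert(P.size() == C.size());
    std::uint64_t answer = 1;
    for (std::size_t i = 0; i < P.size(); ++i) {
        const std::uint64_t p = P[i] % MOD;
        // E[X_i] = p * C_i + (1 - p), with 1 - p kept non-negative mod MOD.
        const std::uint64_t term = (p * (C[i] % MOD) + (1 + MOD - p)) % MOD;
        answer = answer * term % MOD;  // independence: multiply expectations
    }
    return answer;
}
```

For example, a single school with P_1 = 1/2 (that is, 499122177 modulo M) and \binom{A_1}{r} = 3 gives (3 + 1)/2 = 2.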

Hence, if we can compute \displaystyle \binom{A_i}{r} \bmod M quickly, we have solved the problem.

Computing \displaystyle \binom{N}{r} \bmod M

It is worth reading this blog and this answer, which I believe are among the best material written on the topic.

Subtask 1, r \leq 100

Since r is small, we can write \displaystyle \binom{N}{r} = \frac{\prod_{i = 1}^r (N-r+i)}{r!}. The benefit of writing \displaystyle \binom{N}{r} this way is that both the numerator and the denominator are products of r terms, so we can compute \displaystyle\binom{N}{r} in O(r) time, plus one modular inverse for r!. This makes the overall solution O(K*r), which is sufficient for subtask 1 but not for subtask 2.
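Here is a minimal sketch of the subtask 1 approach. It assumes N \geq 0 and r < M, so that r! is invertible modulo the prime M and its inverse can be obtained with Fermat's little theorem. The names `binom_small_r` and `power_mod` are illustrative; they are not taken from the setter's code.

```cpp
#include <cassert>
#include <cstdint>

constexpr std::int64_t MOD = 998244353;

// Fast modular exponentiation: base^exp mod MOD.
std::int64_t power_mod(std::int64_t base, std::int64_t exp) {
    std::int64_t result = 1;
    base %= MOD;
    if (base < 0) base += MOD;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return result;
}

// binom(N, r) mod MOD via prod_{i=1}^{r} (N - r + i) / r!, in O(r) multiplications.
// If N < r, one factor of the numerator is 0, so the result is correctly 0.
std::int64_t binom_small_r(std::int64_t N, std::int64_t r) {
    assert(N >= 0 && r >= 0 && r < MOD);
    std::int64_t numerator = 1, denominator = 1;
    for (std::int64_t i = 1; i <= r; ++i) {
        const std::int64_t factor = ((N - r + i) % MOD + MOD) % MOD;
        numerator = numerator * factor % MOD;
        denominator = denominator * i % MOD;
    }
    // MOD is prime, so (r!)^{-1} = (r!)^{MOD-2} mod MOD.
    return numerator * power_mod(denominator, MOD - 2) % MOD;
}
```

Calling this once per school gives the O(K*r) bound mentioned above.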

The final subtask

Now r \leq 10^5, so we cannot use the above solution. One fact we haven't used yet is that we need to compute \displaystyle \binom{A_i}{r} for each A_i with the same fixed r.

If we follow the method from subtask 1, the denominator r! is fixed and easily calculated, and the numerator is the product of the r consecutive integers ending at A_i.

Let's define the polynomial \displaystyle Q(x) = \prod_{i = 1}^r (x+i), the product of the r consecutive integers greater than x.

It is easy to see that, for N = A_i, the numerator in \displaystyle \binom{N}{r} = \frac{\prod_{i = 1}^r (N-r+i)}{r!} is nothing but Q(A_i - r), and the denominator is r!. So we get \displaystyle \binom{A_i}{r} = \frac{Q(A_i - r)}{r!}

Hence, we have a polynomial Q(x) of degree r, and we need to evaluate it at the points A_i - r for each 1 \leq i \leq K.
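A small check of this identity, with r = 3 and A_i = 5 chosen only for illustration:

```latex
Q(x) = (x+1)(x+2)(x+3), \qquad
\binom{5}{3} = \frac{Q(5-3)}{3!} = \frac{3 \cdot 4 \cdot 5}{6} = \frac{60}{6} = 10
```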

Computing Q(x)

It is worth trying the second subtask of the problem LUCASTH, which uses the same type of polynomial.

If we compute Q(x) by multiplying in the factors (x+i) one at a time, the complexity is O(r^2), which is too slow.

But we can use Divide and Conquer here. Initially, there are r polynomials of degree 1. Divide them into pairs and multiply each pair. Now we get \lceil r/2 \rceil polynomials of degree up to 2. Repeating this, we get \lceil r/4 \rceil polynomials of degree up to 4. This process is repeated \lceil \log_2(r) \rceil times. Each layer takes O(r*log(r)) time, as shown below; alternatively, the master theorem gives the overall bound directly.

In the i-th step, \lceil r / 2^i \rceil pairs of polynomials of degree up to 2^{i-1} are multiplied. Using NTT, each multiplication takes O(i*2^i) time, so the whole step takes \lceil r / 2^i \rceil * O(i*2^i) = O(r*i), which is O(r*log(r)).
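Summing this per-layer bound over all \lceil \log_2 r \rceil layers gives the total cost:

```latex
\sum_{i=1}^{\lceil \log_2 r \rceil} \left\lceil \frac{r}{2^i} \right\rceil \cdot O\left(i \cdot 2^i\right)
  = \sum_{i=1}^{\lceil \log_2 r \rceil} O(r \cdot i)
  = O\left(r \log^2 r\right)
```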

Hence, we can compute Q(x) in O(r*log^2(r)) time.
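Below is a sketch of the divide-and-conquer product. It assumes an NTT modulo 998244353 with primitive root 3, which are the modulus and generator the setter's code uses. The sketch splits the range of factors in half recursively, which builds the same balanced tree as pairing polynomials layer by layer. All names here are hypothetical. The NTT is a plain textbook version and is much slower than the setter's optimized one.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

using u64 = std::uint64_t;
constexpr u64 MOD = 998244353, ROOT = 3;  // 3 is a primitive root of MOD

u64 power_mod(u64 b, u64 e) {
    u64 r = 1;
    for (b %= MOD; e; e >>= 1, b = b * b % MOD)
        if (e & 1) r = r * b % MOD;
    return r;
}

// In-place iterative NTT; invert = true computes the inverse transform.
void ntt(std::vector<u64>& a, bool invert) {
    const std::size_t n = a.size();
    for (std::size_t i = 1, j = 0; i < n; ++i) {  // bit-reversal permutation
        std::size_t bit = n >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) std::swap(a[i], a[j]);
    }
    for (std::size_t len = 2; len <= n; len <<= 1) {
        u64 w = power_mod(ROOT, (MOD - 1) / len);
        if (invert) w = power_mod(w, MOD - 2);
        for (std::size_t i = 0; i < n; i += len) {
            u64 wn = 1;
            for (std::size_t k = 0; k < len / 2; ++k, wn = wn * w % MOD) {
                const u64 u = a[i + k], v = a[i + k + len / 2] * wn % MOD;
                a[i + k] = (u + v) % MOD;
                a[i + k + len / 2] = (u + MOD - v) % MOD;
            }
        }
    }
    if (invert) {
        const u64 n_inv = power_mod(n, MOD - 2);
        for (u64& x : a) x = x * n_inv % MOD;
    }
}

// Polynomial product via NTT (coefficients in increasing degree order).
std::vector<u64> multiply(std::vector<u64> a, std::vector<u64> b) {
    assert(!a.empty() && !b.empty());
    const std::size_t need = a.size() + b.size() - 1;
    std::size_t n = 1;
    while (n < need) n <<= 1;
    a.resize(n);
    b.resize(n);
    ntt(a, false);
    ntt(b, false);
    for (std::size_t i = 0; i < n; ++i) a[i] = a[i] * b[i] % MOD;
    ntt(a, true);
    a.resize(need);
    return a;
}

// Product of (x + i) for i in [lo, hi], splitting the range in half.
std::vector<u64> product_range(u64 lo, u64 hi) {
    if (lo == hi) return {lo % MOD, 1};  // the linear factor x + lo
    const u64 mid = lo + (hi - lo) / 2;
    return multiply(product_range(lo, mid), product_range(mid + 1, hi));
}

// Coefficients of Q(x) = prod_{i=1}^{r} (x + i); Q(x) = 1 when r = 0.
std::vector<u64> build_Q(u64 r) {
    return r == 0 ? std::vector<u64>{1} : product_range(1, r);
}
```

For instance, `build_Q(3)` returns {6, 11, 6, 1}, the coefficients of x^3 + 6x^2 + 11x + 6.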

Alternatively, a commonly used approach is to maintain a heap of polynomials ordered by degree, with the smallest-degree polynomial at the top. Repeatedly multiplying the top two polynomials while at least 2 polynomials remain also yields a similar time complexity.
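The heap-based variant can be sketched with `std::priority_queue`. For brevity, this sketch swaps in schoolbook O(n^2) multiplication through a hypothetical helper `multiply_naive`. To reach the complexity discussed above, that helper would have to be replaced by an NTT-based product.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <queue>
#include <vector>

using u64 = std::uint64_t;
constexpr u64 MOD = 998244353;
using Poly = std::vector<u64>;  // coefficients, lowest degree first

// Schoolbook product, used here only to keep the sketch short.
Poly multiply_naive(const Poly& a, const Poly& b) {
    assert(!a.empty() && !b.empty());
    Poly c(a.size() + b.size() - 1);
    for (std::size_t i = 0; i < a.size(); ++i)
        for (std::size_t j = 0; j < b.size(); ++j)
            c[i + j] = (c[i + j] + a[i] * b[j]) % MOD;
    return c;
}

// Comparator that puts the polynomial of smallest degree at the top of the heap.
struct LargerDegree {
    bool operator()(const Poly& a, const Poly& b) const { return a.size() > b.size(); }
};

// Q(x) = prod_{i=1}^{r} (x + i), always multiplying the two smallest polynomials.
Poly build_Q_heap(u64 r) {
    std::priority_queue<Poly, std::vector<Poly>, LargerDegree> heap;
    heap.push(Poly{1});  // neutral element, so r = 0 yields Q(x) = 1
    for (u64 i = 1; i <= r; ++i) heap.push(Poly{i % MOD, 1});
    while (heap.size() >= 2) {
        Poly a = heap.top();
        heap.pop();
        Poly b = heap.top();
        heap.pop();
        heap.push(multiply_naive(a, b));
    }
    return heap.top();
}
```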

Computing Q(x) at K points.

This is a well-known problem called Multipoint Evaluation of a polynomial, explained here. This answer does a good job of summarizing the process concisely.

Let B_i = A_i-r for each i, so we need to evaluate Q(x) at each value in B.
The core idea is to build a segment-tree-like structure, where the node corresponding to the range [L, R] stores the product P_{L, R}(x) = \displaystyle\prod_{i = L}^R (x-B_i).

Starting with Q(x) at the root node, we descend the tree, passing Q(x) \bmod P_{L, R}(x) down to the node with range [L, R]. This works because P_{L, R} divides the polynomial stored at its parent, so reducing modulo the parent first does not change the remainder modulo the child. At the leaf for B_i we are left with Q(x) \bmod (x-B_i) = Q(B_i). Polynomial operations like division and modulo are required, and they are also described here.
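The remainder-tree descent can be sketched as follows. To keep the sketch short, it swaps in schoolbook multiplication and long division, so it runs in roughly quadratic time rather than the bound stated below. A fast version would need NTT multiplication and a Newton-iteration power-series inverse for division, as covered in the linked material. Every node polynomial \prod (x - B_i) is monic, so the long division needs no modular inverses. All names (`multipoint_eval`, `mod_monic`, `build`, `descend`) are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using u64 = std::uint64_t;
constexpr u64 MOD = 998244353;
using Poly = std::vector<u64>;  // coefficients, lowest degree first

// Schoolbook product (NTT-based multiplication is needed for the real bound).
Poly mul(const Poly& a, const Poly& b) {
    Poly c(a.size() + b.size() - 1);
    for (std::size_t i = 0; i < a.size(); ++i)
        for (std::size_t j = 0; j < b.size(); ++j)
            c[i + j] = (c[i + j] + a[i] * b[j]) % MOD;
    return c;
}

// Remainder of a modulo a monic polynomial m, by schoolbook long division.
Poly mod_monic(Poly a, const Poly& m) {
    assert(!m.empty() && m.back() == 1);
    const std::size_t dm = m.size() - 1;        // degree of m
    for (std::size_t i = a.size(); i-- > dm;) {  // cancel x^i for every i >= dm
        const u64 coef = a[i];
        if (coef == 0) continue;
        for (std::size_t j = 0; j <= dm; ++j)
            a[i - dm + j] = (a[i - dm + j] + (MOD - coef) * m[j]) % MOD;
    }
    a.resize(std::min(a.size(), dm));
    if (a.empty()) a.push_back(0);
    return a;
}

// Node v covering pts[l..r] stores prod_{i=l}^{r} (x - pts[i]).
void build(std::vector<Poly>& tree, const std::vector<u64>& pts,
           std::size_t v, std::size_t l, std::size_t r) {
    if (l == r) { tree[v] = {(MOD - pts[l] % MOD) % MOD, 1}; return; }
    const std::size_t mid = (l + r) / 2;
    build(tree, pts, 2 * v, l, mid);
    build(tree, pts, 2 * v + 1, mid + 1, r);
    tree[v] = mul(tree[2 * v], tree[2 * v + 1]);
}

// Pass f mod (node polynomial) down the tree; at a leaf the remainder is f(point).
void descend(const std::vector<Poly>& tree, const Poly& f, std::size_t v,
             std::size_t l, std::size_t r, std::vector<u64>& out) {
    const Poly rem = mod_monic(f, tree[v]);
    if (l == r) { out[l] = rem[0]; return; }
    const std::size_t mid = (l + r) / 2;
    descend(tree, rem, 2 * v, l, mid, out);
    descend(tree, rem, 2 * v + 1, mid + 1, r, out);
}

// Returns f(pts[i]) mod MOD for every i.
std::vector<u64> multipoint_eval(const Poly& f, const std::vector<u64>& pts) {
    std::vector<u64> out(pts.size());
    if (pts.empty()) return out;
    std::vector<Poly> tree(4 * pts.size());
    build(tree, pts, 1, 0, pts.size() - 1);
    descend(tree, f, 1, 0, pts.size() - 1, out);
    return out;
}
```

To get \binom{A_i}{r}, one would call this with f = Q and the points B_i = (A_i - r) \bmod M, and then multiply each result by (r!)^{-1}.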

Additionally, you can test your implementation here and compare it with the fastest implementations.

Feel free to refer to the implementations below. A warning from the tester: "code may be too long, you have been warned."

TIME COMPLEXITY

The time complexity is O(r*log^2(r) + K*log^2(K)).

SOLUTIONS

Setter's Solution
#include <cstdio>
#include <cctype>
#include <algorithm>
#include <tuple>
#include <cstring>
#include <cmath>
#include <vector>
#include<iostream>
#include <cstdint>
#include <assert.h> 
#include <bits/stdc++.h>
typedef unsigned int uint;
typedef long long unsigned int uint64;
 
constexpr uint Max_size = 1 << 22 | 32;
constexpr uint g = 3, Mod = 998244353;
static constexpr uint map[4] = {0, Mod, Mod * 2, Mod * 3};
inline uint norm_2(const uint x)
{
    return x - map[x >> 30];
}
inline uint norm(const uint x)
{
    return x < Mod ? x : x - Mod;
}
inline uint mult_2(const uint x1, const uint x2)
{
    uint64 x = static_cast<uint64>(x1) * x2;
    constexpr uint64 base = (-1ULL) / Mod;
    return static_cast<uint>(x) - Mod * static_cast<uint>((static_cast<__uint128_t>(x) * base) >> 64);
}
 
struct Z
{
    uint v;
    Z() { }
    Z(const uint _v) : v(_v) { }
};
 
inline Z operator+(const Z x1, const Z x2) { return norm(x1.v + x2.v); }
inline Z operator-(const Z x1, const Z x2) { return norm(x1.v + Mod - x2.v); }
inline Z operator-(const Z x) { return x.v ? Mod - x.v : 0; }
inline Z operator*(const Z x1, const Z x2) { return static_cast<uint64>(x1.v) * x2.v % Mod; }
inline Z &operator+=(Z &x1, const Z x2) { return x1 = x1 + x2; }
inline Z &operator-=(Z &x1, const Z x2) { return x1 = x1 - x2; }
inline Z &operator*=(Z &x1, const Z x2) { return x1 = x1 * x2; }
 
inline Z mult_2(const Z x1, const Z x2)
{
    return mult_2(x1.v, x2.v);
}
 
inline Z Power(Z Base, int Exp)
{
    Z res = 1;
    for (; Exp; Base *= Base, Exp >>= 1)
	    if (Exp & 1)
		    res *= Base;
    return res;
}
 
inline Z Rec(const Z x)
{
    return Power(x, Mod - 2);
}
 
Z _Rec[Max_size];
 
void init_Rec(const int n)
{
    _Rec[1] = 1;
    for (int i = 2; i != n; ++i)
	    _Rec[i] = _Rec[i - 1] * (i - 1);
    Z R = Rec(_Rec[n - 1] * (n - 1));
    for (int i = n - 1; i != 1; --i)
	    _Rec[i] *= R, R *= i;
}
 
struct Shoup
{
    uint v, q;
    Shoup() = default;
    explicit Shoup(const uint _v) : v(_v), q((static_cast<uint64>(_v) << 32) / Mod) { }
};
 
int size;
Shoup w[Max_size], rw[Max_size];
 
inline uint mult_Shoup_2(const uint x, const Shoup y)
{
    uint q = static_cast<uint64>(x) * y.q >> 32;
    return x * y.v - q * Mod;
}
 
inline uint mult_Shoup(const uint x, const Shoup y)
{
    return norm(mult_Shoup_2(x, y));
}
 
inline uint mult_Shoup_q(const uint x, const Shoup y)
{
    uint q = static_cast<uint64>(x) * y.q >> 32;
    return q + (x * y.v - q * Mod >= Mod);
}
 
void init_w(const int n)
{
    for (size = 2; size < n; size <<= 1)
	    ;
    Shoup pr = Shoup(Power(g, (Mod - 1) / size).v);
    Shoup rpr = Shoup(Rec(pr.v).v);
    size >>= 1;
    w[size] = Shoup(1);
    rw[size] = Shoup(1);
    for (int i = 1; i < size; ++i)
    {
	    w[size + i] = Shoup(mult_Shoup(w[size + i - 1].v, pr));
	    rw[size + i] = Shoup(mult_Shoup(rw[size + i - 1].v, rpr));
    }
    for (int i = size - 1; i; --i)
    {
	    w[i] = w[i * 2];
	    rw[i] = rw[i * 2];
    }
    size <<= 1;
}
void DFT_fr_2(Z _A[], const int L)
{
    if (L == 1)
	    return;
    uint *A = reinterpret_cast<uint *>(_A);
#define butterfly1(a, b)\
    do\
    {\
	    uint _a = a, _b = b;\
	    uint x = norm_2(_a + _b), y = norm_2(_a + Mod * 2 - _b);\
	    a = x, b = y;\
    } while (0)
    if (L == 2)
    {
	    butterfly1(A[0], A[1]);
	    return;
    }
#define butterfly(a, b, _w)\
    do\
    {\
	    uint _a = a, _b = b;\
	    uint x = norm_2(_a + _b), y = mult_Shoup_2(_a + Mod * 2 - _b, _w);\
	    a = x, b = y;\
    } while (0)
    if (L == 4)
    {
	    butterfly1(A[0], A[2]);
	    butterfly(A[1], A[3], w[3]);
	    butterfly1(A[0], A[1]);
	    butterfly1(A[2], A[3]);
	    return;
    }
    for (int d = L >> 1; d != 4; d >>= 1)
	    for (int i = 0; i != L; i += d << 1)
		    for (int j = 0; j != d; j += 4)
		    {
			    butterfly(A[i + j + 0], A[i + d + j + 0], w[d + j + 0]);
			    butterfly(A[i + j + 1], A[i + d + j + 1], w[d + j + 1]);
			    butterfly(A[i + j + 2], A[i + d + j + 2], w[d + j + 2]);
			    butterfly(A[i + j + 3], A[i + d + j + 3], w[d + j + 3]);
		    }
    for (int i = 0; i != L; i += 8)
    {
	    butterfly1(A[i + 0], A[i + 4]);
	    butterfly(A[i + 1], A[i + 5], w[5]);
	    butterfly(A[i + 2], A[i + 6], w[6]);
	    butterfly(A[i + 3], A[i + 7], w[7]);
 
	    butterfly1(A[i + 0], A[i + 2]);
	    butterfly(A[i + 1], A[i + 3], w[3]);
	    butterfly1(A[i + 4], A[i + 6]);
	    butterfly(A[i + 5], A[i + 7], w[3]);
 
	    butterfly1(A[i + 0], A[i + 1]);
	    butterfly1(A[i + 2], A[i + 3]);
	    butterfly1(A[i + 4], A[i + 5]);
	    butterfly1(A[i + 6], A[i + 7]);
    }
#undef butterfly1
#undef butterfly
}
 
void DFT_fr(Z _A[], const int L)
{
    DFT_fr_2(_A, L);
    for (int i = 0; i != L; ++i)
	    _A[i] = norm(_A[i].v);
}
 
void IDFT_fr(Z _A[], const int L)
{
    if (L == 1)
	    return;
    uint *A = reinterpret_cast<uint *>(_A);
 
#define butterfly1(a, b)\
    do\
    {\
	    uint _a = a, _b = b;\
	    uint x = norm_2(_a), t = norm_2(_b);\
	    a = x + t, b = x + Mod * 2 - t;\
    } while (0)
    if (L == 2)
    {
	    butterfly1(A[0], A[1]);
	    A[0] = norm(norm_2(A[0])), A[0] = A[0] & 1 ? A[0] + Mod : A[0], A[0] /= 2;
	    A[1] = norm(norm_2(A[1])), A[1] = A[1] & 1 ? A[1] + Mod : A[1], A[1] /= 2;
	    return;
    }
 
#define butterfly(a, b, _w)\
    do\
    {\
	    uint _a = a, _b = b;\
	    uint x = norm_2(_a), t = mult_Shoup_2(_b, _w);\
	    a = x + t, b = x + Mod * 2 - t;\
    } while (0)
    if (L == 4)
    {
	    butterfly1(A[0], A[1]);
	    butterfly1(A[2], A[3]);
	    butterfly1(A[0], A[2]);
	    butterfly(A[1], A[3], rw[3]);
	    for (int i = 0; i != L; ++i)
	    {
		    uint64 m = -A[i] & 3;
		    A[i] = norm((A[i] + m * Mod) >> 2);
	    }
	    return;
    }
    for (int i = 0; i != L; i += 8)
    {
	    butterfly1(A[i + 0], A[i + 1]);
	    butterfly1(A[i + 2], A[i + 3]);
	    butterfly1(A[i + 4], A[i + 5]);
	    butterfly1(A[i + 6], A[i + 7]);
 
	    butterfly1(A[i + 0], A[i + 2]);
	    butterfly(A[i + 1], A[i + 3], rw[3]);
	    butterfly1(A[i + 4], A[i + 6]);
	    butterfly(A[i + 5], A[i + 7], rw[3]);
 
	    butterfly1(A[i + 0], A[i + 4]);
	    butterfly(A[i + 1], A[i + 5], rw[5]);
	    butterfly(A[i + 2], A[i + 6], rw[6]);
	    butterfly(A[i + 3], A[i + 7], rw[7]);
    }
    for (int d = 8; d != L; d <<= 1)
	    for (int i = 0; i != L; i += d << 1)
		    for (int j = 0; j != d; j += 4)
		    {
			    butterfly(A[i + j + 0], A[i + d + j + 0], rw[d + j + 0]);
			    butterfly(A[i + j + 1], A[i + d + j + 1], rw[d + j + 1]);
			    butterfly(A[i + j + 2], A[i + d + j + 2], rw[d + j + 2]);
			    butterfly(A[i + j + 3], A[i + d + j + 3], rw[d + j + 3]);
		    }
#undef butterfly1
#undef butterfly
    int k = __builtin_ctz(L);
    for (int i = 0; i != L; ++i)
    {
	    uint64 m = -A[i] & (L - 1);
	    A[i] = norm((A[i] + m * Mod) >> k);
    }
}
 
void IDFT_fr_core_2(Z _A[], const int L)
{
    if (L == 1)
	    return;
    uint *A = reinterpret_cast<uint *>(_A);
#define butterfly1(a, b)\
    do\
    {\
	    uint _a = a, _b = b;\
	    uint x = norm_2(_a), t = norm_2(_b);\
	    a = x + t, b = x + Mod * 2 - t;\
    } while (0)
    if (L == 2)
    {
	    butterfly1(A[0], A[1]);
	    A[0] = norm_2(A[0]);
	    A[1] = norm_2(A[1]);
	    return;
    }
#define butterfly(a, b, _w)\
    do\
    {\
	    uint _a = a, _b = b;\
	    uint x = norm_2(_a), t = mult_Shoup_2(_b, _w);\
	    a = x + t, b = x + Mod * 2 - t;\
    } while (0)
    if (L == 4)
    {
	    butterfly1(A[0], A[1]);
	    butterfly1(A[2], A[3]);
	    butterfly1(A[0], A[2]);
	    butterfly(A[1], A[3], rw[3]);
	    A[0] = norm_2(A[0]);
	    A[1] = norm_2(A[1]);
	    A[2] = norm_2(A[2]);
	    A[3] = norm_2(A[3]);
	    return;
    }
    for (int i = 0; i != L; i += 8)
    {
	    butterfly1(A[i + 0], A[i + 1]);
	    butterfly1(A[i + 2], A[i + 3]);
	    butterfly1(A[i + 4], A[i + 5]);
	    butterfly1(A[i + 6], A[i + 7]);
 
	    butterfly1(A[i + 0], A[i + 2]);
	    butterfly(A[i + 1], A[i + 3], rw[3]);
	    butterfly1(A[i + 4], A[i + 6]);
	    butterfly(A[i + 5], A[i + 7], rw[3]);
 
	    butterfly1(A[i + 0], A[i + 4]);
	    butterfly(A[i + 1], A[i + 5], rw[5]);
	    butterfly(A[i + 2], A[i + 6], rw[6]);
	    butterfly(A[i + 3], A[i + 7], rw[7]);
    }
    for (int d = 8; d != L; d <<= 1)
	    for (int i = 0; i != L; i += d << 1)
		    for (int j = 0; j != d; j += 4)
		    {
			    butterfly(A[i + j + 0], A[i + d + j + 0], rw[d + j + 0]);
			    butterfly(A[i + j + 1], A[i + d + j + 1], rw[d + j + 1]);
			    butterfly(A[i + j + 2], A[i + d + j + 2], rw[d + j + 2]);
			    butterfly(A[i + j + 3], A[i + d + j + 3], rw[d + j + 3]);
		    }
#undef butterfly1
#undef butterfly
    for (int i = 0; i != L; ++i)
	    A[i] = norm_2(A[i]);
}
 
inline void copy_n_a(const Z A[], const int n, Z B[])
{
    std::memmove(B, A, n * sizeof(Z));
}
 
inline void reset(Z A[], const int n)
{
    std::memset(reinterpret_cast<uint *>(A), 0, n * sizeof(Z));
}
 
const int Brute_size_l = 4;
const int Brute_size = 1 << Brute_size_l;
 
inline void mult_2(Z A[], const Z B[], const Z C[], const int N)
{
    int s = N & ~3;
    for (int i = 0; i != s; i += 4)
    {
	    A[i + 0] = mult_2(B[i + 0], C[i + 0]);
	    A[i + 1] = mult_2(B[i + 1], C[i + 1]);
	    A[i + 2] = mult_2(B[i + 2], C[i + 2]);
	    A[i + 3] = mult_2(B[i + 3], C[i + 3]);
    }
    for (int i = s; i != N; ++i)
	    A[i] = mult_2(B[i], C[i]);
}
 
inline void mult(Z A[], const Z B[], const Z C[], const int N)
{
    int s = N & ~3;
    for (int i = 0; i != s; i += 4)
    {
	    A[i + 0] = B[i + 0] * C[i + 0];
	    A[i + 1] = B[i + 1] * C[i + 1];
	    A[i + 2] = B[i + 2] * C[i + 2];
	    A[i + 3] = B[i + 3] * C[i + 3];
    }
    for (int i = s; i != N; ++i)
	    A[i] = B[i] * C[i];
}
 
const uint ME_M_size = Max_size * 6;
Z ME_M[ME_M_size], *ME_M_pos = ME_M + ME_M_size;
 
int ME_Pre(const Z x[], const int N, Z _x[], int id[], Z tmp[], const int LM)
{
    Z *p = ME_M_pos -= (N + 3) & ~3;
    int cnt = 0, rcnt = N, l = __builtin_ctz(LM);
    copy_n_a(x, N, p);
    if (l)
    {
	    for (int j = 0; j != l - 1; ++j)
		    mult_2(p, p, p, N);
	    mult(p, p, p, N);
    }
    for (int i = 0; i != N; ++i)
    {
	    if (p[i].v != 1)
		    _x[cnt] = x[i], id[cnt] = i, tmp[cnt++] = p[i];
	    else
	    {
		    id[--rcnt] = i;
		    Z _p[20];
		    int j = 0;
		    for (_p[0] = x[i]; _p[j].v != 1; ++j)
			    _p[j + 1] = _p[j] * _p[j];
		    int k = 0;
		    while (j)
		    {
			    k >>= 1;
			    if (_p[--j].v != w[(LM >> 1) + k].v)
				    k += LM >> 1;
		    }
		    tmp[rcnt].v = k;
	    }
    }
    ME_M_pos += (N + 3) & ~3;
    return cnt;
}
 
inline void ST_DFT_fr_2(const Z A[], const int N, Z dfta[], const int L)
{
    dfta[0] = 1, copy_n_a(A + 1, N - 1, dfta + 1), reset(dfta + N, L - N);
    DFT_fr_2(dfta, L);
}
 
inline void ST_DFT_fr_2_extra(const Z A[], Z dfta[], const int L)
{
    dfta[0] = 1, copy_n_a(A + 1, L - 1, dfta + 1);
    dfta[0] += A[L];
    DFT_fr_2(dfta, L);
}
 
inline void ST_DFT_fr_2_tbd(const Z A[], const int N, Z dfta[], const int L)
{
    if (N != L + 1)
	    ST_DFT_fr_2(A, N, dfta, L);
    else
	    ST_DFT_fr_2_extra(A, dfta, L);
}
 
inline void ST_extend_DFT_fr_2(const Z A[], const int N, Z dfta[], const int L)
{
    dfta[L] = 1;
    for (int j = 1; j != N; ++j)
	    dfta[L + j] = mult_Shoup_2(A[j].v, w[L + j]);
    reset(dfta + L + N, L - N);
    DFT_fr_2(dfta + L, L);
}
 
inline void ST_extend_DFT_fr_2_extra(const Z A[], Z dfta[], const int L)
{
    dfta[L] = 1;
    for (int j = 1; j != L; ++j)
	    dfta[L + j] = mult_Shoup_2(A[j].v, w[L + j]);
    dfta[L] -= A[L];
    DFT_fr_2(dfta + L, L);
}
 
inline void ST_extend_DFT_fr_2_tbd(const Z A[], const int N, Z dfta[], const int L)
{
    if (N != L + 1)
	    ST_extend_DFT_fr_2(A, N, dfta, L);
    else
	    ST_extend_DFT_fr_2_extra(A, dfta, L);
}
 
inline void ST_extend_2_DFT_fr_2(const Z A[], const int N, Z dfta[], const int L)
{
    dfta[L] = 1;
    dfta[L + L] = 1;
    dfta[L + L + L] = 1;
    for (int j = 1; j != N; ++j)
    {
	    dfta[L + j] = mult_Shoup_2(A[j].v, w[L + j]);
	    dfta[L + L + j] = mult_Shoup_2(A[j].v, w[L + L + j]);
	    dfta[L + L + L + j] = mult_Shoup_2(dfta[L + L + j].v, w[L + j]);
    }
    reset(dfta + L + N, L - N);
    reset(dfta + L + L + N, L - N);
    reset(dfta + L + L + L + N, L - N);
    DFT_fr_2(dfta + L, L);
    DFT_fr_2(dfta + L + L, L);
    DFT_fr_2(dfta + L + L + L, L);
}
 
inline void ST_extend_2_DFT_fr_2_extra(const Z A[], Z dfta[], const int L)
{
    for (int j = 1; j != L; ++j)
    {
	    dfta[L + j] = mult_Shoup_2(A[j].v, w[L + j]);
	    dfta[L + L + j] = mult_Shoup_2(A[j].v, w[L + L + j]);
	    dfta[L + L + L + j] = mult_Shoup_2(dfta[L + L + j].v, w[L + j]);
    }
    int tmp = mult_Shoup(A[L].v, w[L + L + L]);
    dfta[L] = 1 + Mod - A[L].v;
    DFT_fr_2(dfta + L, L);
    dfta[L + L] = 1 + tmp;
    DFT_fr_2(dfta + L + L, L);
    dfta[L + L + L] = 1 + Mod - tmp;
    DFT_fr_2(dfta + L + L + L, L);
}
 
inline void ST_extend_2_DFT_fr_2_tbd(const Z A[], const int N, Z dfta[], const int L)
{
    if (N != L + 1)
	    ST_extend_2_DFT_fr_2(A, N, dfta, L);
    else
	    ST_extend_2_DFT_fr_2_extra(A, dfta, L);
}
 
template<const size_t _size0, const size_t _size1>
void Subproduct_Tree_Rev_(const Z x[], const int N, Z T0[], Z T[], Z dftt[][_size0], Z dfttx[][_size1], const int L, const int LM)
{
    if (std::min(L, LM) <= Brute_size)
	    return;
    int _l = __builtin_ctz(std::min(L, LM)) & 1;
    while (_l + 2 <= Brute_size_l)
	    _l += 2;
    for (int i = 0; i < N; i += 1 << _l)
    {
	    auto _T = T + i;
	    {
		    const Z v0 = Mod - x[i].v, v1 = Mod - x[i + 1].v, v2 = Mod - x[i + 2].v, v3 = Mod - x[i + 3].v;
		    const uint64 v01 = mult_2(v0.v, v1.v), v23 = mult_2(v2.v, v3.v);
		    _T[4] = v01 * v23 % Mod;
		    _T[3] = (v01 * (v2.v + v3.v) + v23 * (v0.v + v1.v)) % Mod;
		    _T[2] = (v01 + v23 + static_cast<uint64>(v0.v + v1.v) * (v2.v + v3.v)) % Mod;
		    _T[1] = (v0.v + v1.v + v2.v + v3.v) % Mod;
	    }
	    for (int j = 4; j != (1 << _l) && j < N - i; j += 4)
	    {
		    const Z v0 = Mod - x[i + j].v, v1 = Mod - x[i + j + 1].v, v2 = Mod - x[i + j + 2].v, v3 = Mod - x[i + j + 3].v;
		    const uint64 v01 = mult_2(v0.v, v1.v), v23 = mult_2(v2.v, v3.v);
		    const uint64 a4 = v01 * v23 % Mod;
		    const uint64 a3 = (v01 * (v2.v + v3.v) + v23 * (v0.v + v1.v)) % Mod;
		    const uint64 a2 = (v01 + v23 + static_cast<uint64>(v0.v + v1.v) * (v2.v + v3.v)) % Mod;
		    const uint64 a1 = (v0.v + v1.v + v2.v + v3.v) % Mod;
		    _T[j + 4] = (_T[j].v * a4) % Mod;
		    _T[j + 3] = (_T[j].v * a3 + _T[j - 1].v * a4) % Mod;
		    _T[j + 2] = (_T[j].v * a2 + _T[j - 1].v * a3 + _T[j - 2].v * a4) % Mod;
		    _T[j + 1] = (_T[j].v * a1 + _T[j - 1].v * a2 + _T[j - 2].v * a3 + _T[j - 3].v * a4) % Mod;
		    for (int k = j; k > 4; --k)
			    _T[k] = (_T[k].v + _T[k - 1].v * a1 + _T[k - 2].v * a2 + _T[k - 3].v * a3 + _T[k - 4].v * a4) % Mod;
		    _T[4] = (_T[4].v + _T[3].v * a1 + _T[2].v * a2 + _T[1].v * a3 + a4) % Mod;
		    _T[3] = (_T[3].v + _T[2].v * a1 + _T[1].v * a2 + a3) % Mod;
		    _T[2] = (_T[2].v + _T[1].v * a1 + a2) % Mod;
		    _T[1] = a1 + _T[1];
	    }
    }
    copy_n_a(T, N + 1, T0);
    for (int l = 0, d; d = 1 << (_l + l * 2), d != std::min(L, LM); ++l)
    {
	    const int s = N & ~(d * 4 - 1);
	    Z p = 1;
	    for (int i = 0; i != s; i += d * 4)
	    {
		    auto _T0 = T + i, _T1 = _T0 + d, _T2 = _T1 + d, _T3 = _T2 + d;
		    auto _dftt0 = dftt[l] + i * 4, _dftt1 = _dftt0 + d * 4, _dftt2 = _dftt1 + d * 4, _dftt3 = _dftt2 + d * 4;
		    auto _dftta = dfttx[l] + i * 2, _dfttb = _dftta + d * 4;
		    if (d <= Brute_size)
		    {
			    ST_DFT_fr_2_extra(_T0, _dftt0, d);
			    ST_DFT_fr_2_extra(_T1, _dftt1, d);
			    ST_DFT_fr_2_extra(_T2, _dftt2, d);
			    ST_DFT_fr_2_extra(_T3, _dftt3, d);
		    }
		    ST_extend_2_DFT_fr_2_extra(_T0, _dftt0, d);
		    ST_extend_2_DFT_fr_2_extra(_T1, _dftt1, d);
		    ST_extend_2_DFT_fr_2_extra(_T2, _dftt2, d);
		    ST_extend_2_DFT_fr_2_extra(_T3, _dftt3, d);
		    mult_2(_dftta, _dftt0, _dftt1, d * 4);
		    mult_2(_dfttb, _dftt2, _dftt3, d * 4);
		    mult_2(_T0, _dftta, _dfttb, d * 4);
		    copy_n_a(_T0, d * 4, dftt[l + 1] + i * 4);
		    IDFT_fr(_T0, d * 4);
		    std::swap(_T0[0] -= 1, p);
	    }
	    if (N != s)
	    {
		    auto _T0 = T + s, _T1 = _T0 + d, _T2 = _T1 + d, _T3 = _T2 + d;
		    auto _dftt0 = dftt[l] + s * 4, _dftt1 = _dftt0 + d * 4, _dftt2 = _dftt1 + d * 4, _dftt3 = _dftt2 + d * 4;
		    auto _dftta = dfttx[l] + s * 2, _dfttb = _dftta + d * 4;
		    if (N - s <= d)
		    {
			    const int t = N - s;
			    if (d <= Brute_size)
				    ST_DFT_fr_2_tbd(_T0, t + 1, _dftt0, d);
			    ST_extend_2_DFT_fr_2_tbd(_T0, t + 1, _dftt0, d);
			    copy_n_a(_dftt0, d * 4, dftt[l + 1] + s * 4);
		    }
		    else
		    {
			    if (N - s <= d * 2)
			    {
				    const int t = N - s - d;
				    if (d <= Brute_size)
				    {
					    ST_DFT_fr_2_extra(_T0, _dftt0, d);
					    ST_DFT_fr_2_tbd(_T1, t + 1, _dftt1, d);
				    }
				    ST_extend_DFT_fr_2_extra(_T0, _dftt0, d);
				    ST_extend_DFT_fr_2_tbd(_T1, t + 1, _dftt1, d);
				    mult_2(_T0, _dftt0, _dftt1, d * 2);
				    copy_n_a(_T0, d * 2, dftt[l + 1] + s * 4);
				    IDFT_fr(_T0, d * 2);
				    if (t == d)
					    _T0[d * 2] = _T0[0] - 1, _T0[0] = 1;
				    ST_extend_DFT_fr_2_tbd(_T0, d + t + 1, dftt[l + 1] + s * 4, d * 2);
			    }
			    else
			    {
				    if (d <= Brute_size)
				    {
					    ST_DFT_fr_2_extra(_T0, _dftt0, d);
					    ST_DFT_fr_2_extra(_T1, _dftt1, d);
				    }
				    ST_extend_2_DFT_fr_2_extra(_T0, _dftt0, d);
				    ST_extend_2_DFT_fr_2_extra(_T1, _dftt1, d);
				    mult_2(_dftta, _dftt0, _dftt1, d * 4);
				    if (N - s <= d * 3) // can be faster
				    {
					    const int t = N - s - d * 2;
					    if (d <= Brute_size)
						    ST_DFT_fr_2_tbd(_T2, t + 1, _dftt2, d);
					    ST_extend_2_DFT_fr_2_tbd(_T2, t + 1, _dftt2, d);
					    copy_n_a(_dftt2, d * 4, _dfttb);
				    }
				    else
				    {
					    const int t = N - s - d * 3;
					    if (d <= Brute_size)
					    {
						    ST_DFT_fr_2_extra(_T2, _dftt2, d);
						    ST_DFT_fr_2(_T3, t + 1, _dftt3, d);
					    }
					    ST_extend_2_DFT_fr_2_extra(_T2, _dftt2, d);
					    ST_extend_2_DFT_fr_2(_T3, t + 1, _dftt3, d);
					    mult_2(_dfttb, _dftt2, _dftt3, d * 4);
				    }
				    mult_2(_T0, _dftta, _dfttb, d * 4);
				    copy_n_a(_T0, d * 4, dftt[l + 1] + s * 4);
				    IDFT_fr(_T0, d * 4);
			    }
		    }
	    }
	    T[s] = p;
    }
}
 
void IDFT_fr_core_b4_2(Z _A[], const int L)
{
    uint *A = reinterpret_cast<uint *>(_A);
#define butterfly1(a, b)\
    do\
    {\
	    uint _a = a, _b = b;\
	    uint x = norm_2(_a), t = norm_2(_b);\
	    a = x + t, b = x + Mod * 2 - t;\
    } while (0)
#define semi_butterfly1(a, b)\
    do\
    {\
	    uint _a = a, _b = b;\
	    uint x = norm_2(_a), t = norm_2(_b);\
	    b = x + Mod * 2 - t;\
    } while (0)
//	auto butterfly = [](uint &a, uint &b, const Shoup _w)
//	{
//		uint x = norm_2(a), t = mult_Shoup_2(b, _w);
//		a = x + t, b = x + Mod * 2 - t;
//	};
#define butterfly(a, b, _w)\
    do\
    {\
	    uint _a = a, _b = b;\
	    uint x = norm_2(_a), t = mult_Shoup_2(_b, _w);\
	    a = x + t, b = x + Mod * 2 - t;\
    } while (0)
#define semi_butterfly(a, b, _w)\
    do\
    {\
	    uint _a = a, _b = b;\
	    uint x = norm_2(_a), t = mult_Shoup_2(_b, _w);\
	    b = x + Mod * 2 - t;\
    } while (0)
    if (L == 4)
    {
	    semi_butterfly1(A[0], A[1]);
	    semi_butterfly1(A[2], A[3]);
	    semi_butterfly(A[1], A[3], rw[3]);
	    A[3] = norm_2(A[3]);
	    return;
    }
    if (L == 8)
    {
	    butterfly1(A[0], A[1]);
	    butterfly1(A[2], A[3]);
	    butterfly1(A[4], A[5]);
	    butterfly1(A[6], A[7]);
	    semi_butterfly1(A[0], A[2]);
	    semi_butterfly(A[1], A[3], rw[3]);
	    semi_butterfly1(A[4], A[6]);
	    semi_butterfly(A[5], A[7], rw[3]);
	    semi_butterfly(A[2], A[6], rw[6]);
	    semi_butterfly(A[3], A[7], rw[7]);
	    A[6] = norm_2(A[6]);
	    A[7] = norm_2(A[7]);
	    return;
    }
    if (L == 16)
    {
	    butterfly1(A[0], A[1]);
	    butterfly1(A[2], A[3]);
	    butterfly1(A[4], A[5]);
	    butterfly1(A[6], A[7]);
	    butterfly1(A[8], A[9]);
	    butterfly1(A[10], A[11]);
	    butterfly1(A[12], A[13]);
	    butterfly1(A[14], A[15]);
	    butterfly1(A[0], A[2]);
	    butterfly(A[1], A[3], rw[3]);
	    butterfly1(A[4], A[6]);
	    butterfly(A[5], A[7], rw[3]);
	    butterfly1(A[8], A[10]);
	    butterfly(A[9], A[11], rw[3]);
	    butterfly1(A[12], A[14]);
	    butterfly(A[13], A[15], rw[3]);
	    semi_butterfly1(A[0], A[4]);
	    semi_butterfly(A[1], A[5], rw[5]);
	    semi_butterfly(A[2], A[6], rw[6]);
	    semi_butterfly(A[3], A[7], rw[7]);
	    semi_butterfly1(A[8], A[12]);
	    semi_butterfly(A[9], A[13], rw[5]);
	    semi_butterfly(A[10], A[14], rw[6]);
	    semi_butterfly(A[11], A[15], rw[7]);
	    semi_butterfly(A[4], A[12], rw[12]);
	    semi_butterfly(A[5], A[13], rw[13]);
	    semi_butterfly(A[6], A[14], rw[14]);
	    semi_butterfly(A[7], A[15], rw[15]);
	    A[12] = norm_2(A[12]);
	    A[13] = norm_2(A[13]);
	    A[14] = norm_2(A[14]);
	    A[15] = norm_2(A[15]);
	    return;
    }
    for (int i = 0; i != L; i += 8)
    {
	    butterfly1(A[i + 0], A[i + 1]);
	    butterfly1(A[i + 2], A[i + 3]);
	    butterfly1(A[i + 4], A[i + 5]);
	    butterfly1(A[i + 6], A[i + 7]);
 
	    butterfly1(A[i + 0], A[i + 2]);
	    butterfly(A[i + 1], A[i + 3], rw[3]);
	    butterfly1(A[i + 4], A[i + 6]);
	    butterfly(A[i + 5], A[i + 7], rw[3]);
 
	    butterfly1(A[i + 0], A[i + 4]);
	    butterfly(A[i + 1], A[i + 5], rw[5]);
	    butterfly(A[i + 2], A[i + 6], rw[6]);
	    butterfly(A[i + 3], A[i + 7], rw[7]);
    }
    for (int d = 8; d != L >> 2; d <<= 1)
	    for (int i = 0; i != L; i += d << 1)
		    for (int j = 0; j != d; j += 4)
		    {
			    butterfly(A[i + j + 0], A[i + d + j + 0], rw[d + j + 0]);
			    butterfly(A[i + j + 1], A[i + d + j + 1], rw[d + j + 1]);
			    butterfly(A[i + j + 2], A[i + d + j + 2], rw[d + j + 2]);
			    butterfly(A[i + j + 3], A[i + d + j + 3], rw[d + j + 3]);
		    }
    int d = L >> 2;
    for (int j = 0; j != d; j += 4)
    {
	    semi_butterfly(A[j + 0], A[d + j + 0], rw[d + j + 0]);
	    semi_butterfly(A[j + 1], A[d + j + 1], rw[d + j + 1]);
	    semi_butterfly(A[j + 2], A[d + j + 2], rw[d + j + 2]);
	    semi_butterfly(A[j + 3], A[d + j + 3], rw[d + j + 3]);
    }
    for (int j = 0; j != d; j += 4)
    {
	    semi_butterfly(A[(L >> 1) + j + 0], A[(L >> 1) + d + j + 0], rw[d + j + 0]);
	    semi_butterfly(A[(L >> 1) + j + 1], A[(L >> 1) + d + j + 1], rw[d + j + 1]);
	    semi_butterfly(A[(L >> 1) + j + 2], A[(L >> 1) + d + j + 2], rw[d + j + 2]);
	    semi_butterfly(A[(L >> 1) + j + 3], A[(L >> 1) + d + j + 3], rw[d + j + 3]);
    }
    d = L >> 1;
    for (int j = L >> 2; j != d; j += 4)
    {
	    semi_butterfly(A[j + 0], A[d + j + 0], rw[d + j + 0]);
	    semi_butterfly(A[j + 1], A[d + j + 1], rw[d + j + 1]);
	    semi_butterfly(A[j + 2], A[d + j + 2], rw[d + j + 2]);
	    semi_butterfly(A[j + 3], A[d + j + 3], rw[d + j + 3]);
    }
#undef butterfly1
#undef semi_butterfly1
#undef butterfly
#undef semi_butterfly
    for (int i = (L >> 4) * 3; i != L; ++i)
	    A[i] = norm_2(A[i]);
}
 
template<const size_t _size0, const size_t _size1>
void Multipoint_Evaluation_core(const Z x[], const Z P0[], const Z P[], const Z dftt[][_size0], const Z dfttx[][_size1], const Z p[], const int N, const Z F[], const Z dftf[], const int M, Z res[], int L, int LM)
{
    if (std::min(L, LM) <= Brute_size)
    {
	    for (int j = 0; j != N; ++j)
	    {
		    Z y = 0;
		    for (int k = M - 1; k >= 0; --k)
			    y = mult_2(y, x[j]).v + F[k].v;
		    res[j].v = norm(norm_2(y.v));
	    }
	    return;
    }
    int _l = __builtin_ctz(std::min(L, LM)) & 1;
    while (_l + 2 <= Brute_size_l)
	    _l += 2;
    int l = (__builtin_ctz(std::min(L, LM)) - _l) >> 1;
    Z *C = ME_M_pos -= std::max(L * 4, LM);
    Z *tmp = ME_M_pos -= std::max(L * 2, LM);
    for (int i = 0; i < N; i += LM)
    {
	    auto _P = P + i;
	    auto _dftt = dftt[l] + i * 4;
	    auto _C = C + i;
	    copy_n_a(_dftt, std::min(L, LM), _C);
	    for (int d = L; d < LM; d <<= 1)
		    ST_extend_DFT_fr_2_tbd(_P, N + 1, _C, d);
    }
    {
	    const int s = (N + LM - 1) & ~(LM - 1);
	    tmp[0] = 1;
	    for (int j = 1; j != s; ++j)
		    tmp[j] = mult_2(tmp[j - 1], C[j - 1]);
	    Z r = Rec(tmp[s - 1] * C[s - 1]);
	    for (int j = s - 1; j >= 0; --j)
		    tmp[j] = mult_2(tmp[j], r), r = mult_2(r, C[j]);
    }
    for (int i = 0; i < N; i += LM)
    {
	    auto _dftt = dftt[l] + i * 4;
	    auto _C = C + i * 4;
	    auto _tmp = tmp + i;
	    mult_2(_C, dftf, _tmp, LM);
	    IDFT_fr(_C, LM);
	    const int t = std::min(N - i, LM);
	    if (t <= M)
		    copy_n_a(_C + M - t, t, _C + std::min(L, LM) * 4 - t);
	    else
		    copy_n_a(_C + LM - (t - M), t - M, _C + std::min(L, LM) * 4 - t), copy_n_a(_C, M, _C + std::min(L, LM) * 4 - M);
	    reset(_C + std::min(L, LM) * 3, std::min(L, LM) - t);
    }
    --l;
    uint idft_c = 1;
    for (int d; d = 1 << (_l + l * 2), l >= 0; --l)
    {
	    const int s = N & ~(d * 4 - 1);
	    for (int i = 0; i != s; i += d * 4)
	    {
		    auto _C0 = C + i * 4, _C1 = _C0 + d * 4, _C2 = _C1 + d * 4, _C3 = _C2 + d * 4;
		    auto _dftt0 = dftt[l] + i * 4, _dftt1 = _dftt0 + d * 4, _dftt2 = _dftt1 + d * 4, _dftt3 = _dftt2 + d * 4;
		    auto _dftta = dfttx[l] + i * 2, _dfttb = _dftta + d * 4;
		    DFT_fr_2(_C3, d * 4);
		    mult_2(_C1, _C3, _dfttb, d * 4);
		    mult_2(_C0, _C1, _dftt1, d * 4);
		    mult_2(_C1, _C1, _dftt0, d * 4);
		    mult_2(_C3, _C3, _dftta, d * 4);
		    mult_2(_C2, _C3, _dftt3, d * 4);
		    mult_2(_C3, _C3, _dftt2, d * 4);
		    IDFT_fr_core_b4_2(_C0, d * 4), IDFT_fr_core_b4_2(_C1, d * 4), IDFT_fr_core_b4_2(_C2, d * 4), IDFT_fr_core_b4_2(_C3, d * 4);
	    }
	    if (N != s)
	    {
		    auto _C0 = C + s * 4, _C1 = _C0 + d * 4, _C2 = _C1 + d * 4, _C3 = _C2 + d * 4;
		    auto _dftt0 = dftt[l] + s * 4, _dftt1 = _dftt0 + d * 4, _dftt2 = _dftt1 + d * 4, _dftt3 = _dftt2 + d * 4;
		    auto _dftta = dfttx[l] + s * 2, _dfttb = _dftta + d * 4;
		    if (N - s <= d)
		    {
			    copy_n_a(_C3 + d * 3, d, _C0 + d * 3);
			    for (int i = 0; i != d; ++i)
				    _C0[d * 3 + i] *= d * 4;
		    }
		    else
		    {
			    if (N - s <= d * 2)
			    {
				    const int t = N - s - d;
				    DFT_fr_2(_C3 + d * 2, d * 2);
				    mult_2(_C0 + d * 2, _C3 + d * 2, _dftt1, d * 2);
				    mult_2(_C1 + d * 2, _C3 + d * 2, _dftt0, d * 2);
				    IDFT_fr_core_2(_C0 + d * 2, d * 2), IDFT_fr_core_2(_C1 + d * 2, d * 2);
				    for (int i = 0; i != d * 3; ++i)
					    _C0[d * 3 + i] = norm_2(_C0[d * 3 + i].v * 2);
				    for (int i = 0; i != t; ++i)
					    _C1[d * 4 - t + i] = norm_2(_C1[d * 4 - t + i].v * 2);
				    reset(_C1 + d * 3, d - t);
			    }
			    else
			    {
				    DFT_fr_2(_C3, d * 4);
				    mult_2(_C1, _C3, _dfttb, d * 4);
				    mult_2(_C0, _C1, _dftt1, d * 4);
				    mult_2(_C1, _C1, _dftt0, d * 4);
				    if (N - s <= d * 3) // can be faster
				    {
					    const int t = N - s - d * 2;
					    mult_2(_C2, _C3, _dftta, d * 4);
					    IDFT_fr_core_b4_2(_C0, d * 4), IDFT_fr_core_b4_2(_C1, d * 4), IDFT_fr_core_b4_2(_C2, d * 4);
					    reset(_C2 + d * 3, d - t);
				    }
				    else
				    {
					    const int t = N - s - d * 3;
					    mult_2(_C3, _C3, _dftta, d * 4);
					    mult_2(_C2, _C3, _dftt3, d * 4);
					    mult_2(_C3, _C3, _dftt2, d * 4);
					    IDFT_fr_core_b4_2(_C0, d * 4), IDFT_fr_core_b4_2(_C1, d * 4), IDFT_fr_core_b4_2(_C2, d * 4), IDFT_fr_core_b4_2(_C3, d * 4);
					    reset(_C3 + d * 3, d - t);
				    }
			    }
		    }
	    }
	    int k = __builtin_ctz(d * 4);
	    uint64 m = -idft_c & (d * 4 - 1);
	    idft_c = norm((idft_c + m * Mod) >> k);
    }
    if (_l)
    {
	    const int d = 1 << _l;
	    for (int i = 0; i < N; i += d)
	    {
		    const int t = std::min(d, N - i);
		    const int o = -t & 3;
		    auto _C = C + i * 4 + d * 4 - t;
		    auto _F = C + i * 4 + d * 2;
		    auto _P0 = P0 + i;
		    for (int j = 0; j != t; ++j)
			    _C[j].v = norm(_C[j].v);
		    for (int j = 0; j != o; ++j)
			    _F[j] = 0;
		    for (int j = 0; j != t; ++j) // Brute_size <= 16
		    {
			    uint64 c = _C[j].v;
//				for (int k = 1; k <= j; ++k)
//					c += static_cast<uint64>(_C[j - k].v) * _P0[k].v;
			    int s = j & ~3;
#define calc(k)\
			    (c += static_cast<uint64>(_C[j - (k)].v) * _P0[k].v + static_cast<uint64>(_C[j - (k) - 1].v) * _P0[(k) + 1].v + static_cast<uint64>(_C[j - (k) - 2].v) * _P0[(k) + 2].v + static_cast<uint64>(_C[j - (k) - 3].v) * _P0[(k) + 3].v)
			    if (0 < s)
				    if (calc(1), 4 < s)
					    if (calc(5), 8 < s)
						    if (calc(9), 12 < s)
							    calc(13);
#undef calc
			    for (int k = s + 1; k <= j; ++k)
				    c += static_cast<uint64>(_C[j - k].v) * _P0[k].v;
			    _F[o + j] = c % Mod;
		    }
		    for (int j = 0; j != t; ++j)
		    {
			    const uint _x = x[i + j].v, _x2 = mult_2(_x, _x), _x3 = mult_2(_x2, _x), _x4 = mult_2(_x3, _x);
			    Z y = 0;
//				for (int k = 0; k < t; k += 4)
//					y = (y.v * static_cast<uint64>(_x4) + _F[k].v * static_cast<uint64>(_x3) + _F[k + 1].v * static_cast<uint64>(_x2) + _F[k + 2].v * static_cast<uint64>(_x) + _F[k + 3].v) % Mod;
#define calc(k)\
			    (y = (y.v * static_cast<uint64>(_x4) + _F[k].v * static_cast<uint64>(_x3) + _F[(k) + 1].v * static_cast<uint64>(_x2) + _F[(k) + 2].v * static_cast<uint64>(_x) + _F[(k) + 3].v) % Mod)
			    if (calc(0), 4 < t)
				    if (calc(4), 8 < t)
					    if (calc(8), 12 < t)
						    calc(12);
#undef calc
			    res[i + j] = y * idft_c * (Mod + 1 - p[i + j].v);
		    }
	    }
    }
    else
	    for (int j = 0; j != N; ++j)
		    res[j] = C[j * 4 + 3] * idft_c * (Mod + 1 - p[j].v);
    ME_M_pos += std::max(L * 4, LM) + std::max(L * 2, LM);
}
 
int revbin(int x, const int LM)
{
    int y = 0;
    for (int i = __builtin_ctz(LM) - 1; i >= 0; --i)
	    y |= (x & 1) << i, x >>= 1;
    return y;
}
 
void Multipoint_Evaluation(const Z x[], const int N, const Z F[], const int M, Z res[], int LM)
{
    static Z _x[Max_size];
    static int id[Max_size];
    static Z tmp[Max_size];
    static Z dftf[Max_size];
    std::reverse_copy(F, F + M, dftf), reset(dftf + M, LM - M);
    DFT_fr(dftf, LM);
    int cnt = ME_Pre(x, N, _x, id, tmp, LM);
    for (int i = cnt; i != N; ++i)
    {
	    int j = (tmp[i].v * (M - 1)) & (LM - 1);
	    uint y = dftf[revbin((LM - tmp[i].v) & (LM - 1), LM)].v;
	    tmp[i].v = j < (LM >> 1) ? mult_Shoup(y, w[(LM >> 1) + j]) : mult_Shoup(Mod - y, w[j]);
    }
    int L;
    for (L = 1; L < cnt; L <<= 1)
	    ;
    static Z P0[Max_size];
    static Z P[Max_size];
    static Z dftt[10][Max_size * 4];
    static Z dfttx[10][Max_size * 2];
    _x[cnt] = _x[cnt + 1] = _x[cnt + 2] = 0;
    Subproduct_Tree_Rev_(_x, cnt, P0, P, dftt, dfttx, L, LM);
    Multipoint_Evaluation_core(_x, P0, P, dftt, dfttx, tmp, cnt, F, dftf, M, tmp, L, LM);
    for (int i = 0; i != N; ++i)
	    res[id[i]] = tmp[i];
}
typedef long long ll;
ll mod=998244353;
int N, Q, LM;
Z q0, qx, qy;
Z A[Max_size], q[Max_size];
ll fac[200000],infac[200000];
ll exp(ll a,ll b) 
{
    a%=mod;
    ll res=1;
    while(b>0) 
    {
        if(b&1)
        res=res*a%mod;
        a=a*a%mod;
        b>>=1;
    }
    return res;
}
ll mod_in(ll a)
{
    return exp(a,mod-2);
}
// Factorials fac[i] and inverse factorials infac[i] (via Fermat's little theorem) for i <= 100000
void pre()
{
    fac[0]=1;
    infac[0]=1;
    for(int i=1;i<=100000;i++)
    {
        fac[i]=(fac[i-1]*i)%mod;
        infac[i]=mod_in(fac[i]);
    }
}
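// NTT parameters for mod = 998244353 = 119 * 2^23 + 1:
// root is a primitive 2^23-th root of unity, root_1 is its inverse, root_pw = 2^23.
const ll root=15311432;
const ll root_1=469870224;
const ll root_pw=1<<23;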
void fft(std::vector<ll> & a, bool invert)
{
    ll n = a.size();
    for (ll i = 1, j = 0; i < n; i++) 
    {
        ll bit = n >> 1;
        for (; j & bit; bit >>= 1)
            j ^= bit;
        j ^= bit;
 
        if (i < j)
            std::swap(a[i], a[j]);
    }
 
    for (ll len = 2; len <= n; len <<= 1) 
    {
        ll wlen = invert ? root_1 : root;
        for (ll i = len; i < root_pw; i <<= 1)
            wlen=(1LL * wlen * wlen % mod);
 
        for (ll i = 0; i < n; i += len)
         {
            ll w = 1;
            for (ll j = 0; j < len / 2; j++) {
                ll u = a[i+j], v = (a[i+j+len/2] * w % mod);
                a[i+j] = u + v < mod ? u + v : u + v - mod;
                a[i+j+len/2] = u - v >= 0 ? u - v : u - v + mod;
                w = w * wlen % mod;
            }
        }
    }
 
    if (invert) {
        ll n_1 = mod_in(n);
        for (ll & x : a)
            x =  x * n_1 % mod;
    }
}
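// Multiplies two polynomials modulo 998244353 with the NTT and keeps only
// the coefficients of degree <= `degree`.
std::vector<ll> multi(std::vector<ll> const& a,std::vector<ll> const& b,ll degree) 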
{
    std::vector<ll> fa(a.begin(), a.end()), fb(b.begin(), b.end());
    ll n = 1;
    while (n < a.size() + b.size())
    n <<= 1;
    fa.resize(n);
    fb.resize(n);
    fft(fa, false);
    fft(fb, false);
    for (ll i = 0; i < n; i++)
    {fa[i] *= fb[i];fa[i]%=mod;}
    fft(fa, true);
    fa.resize(degree+1);
    return fa;
}
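// Coefficients (lowest degree first) of (x+l)(x+l+1)...(x+r), computed by
// splitting [l, r] in half and multiplying the two halves with multi().
std::vector<ll> bangbang(ll l,ll r)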
{
    if(r<l)
    return {1};
    if(l==r)
    return {l,1};
    ll mid=(l+r)/2;
    return multi(bangbang(l,mid),bangbang(mid+1,r),r-l+1);
}
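int main()
{
    pre();
    ll r;
    // r: number of students each invited school sends; Q: number of schools (K in the statement)
    std::cin>>r>>Q;
    std::vector<ll> p(Q),ans(Q);   // std::vector instead of non-standard variable-length arrays
    for(int i=0;i<Q;i++)
    {
        std::cin>>q[i].v;
        // Shift the evaluation point to A_i - r, so that
        // Q(A_i - r) = (A_i - r + 1)(A_i - r + 2)...(A_i) = r! * C(A_i, r)
        q[i].v=(q[i].v-r+mod)%mod;
    }
    for(int j=0;j<Q;j++)
        std::cin>>p[j];
    for(int i=0;i<Q;i++)
        ans[i]=(infac[r]*p[i])%mod;          // P_i / r!
    std::vector<ll> beg=bangbang(1,r);       // coefficients of Q(x) = (x+1)(x+2)...(x+r)
    N=r+1;
    for(int i=0;i<N;i++)
        A[i].v=beg[i];
    for (LM = 1; LM < N; LM <<= 1);
    init_w(LM);
    Multipoint_Evaluation(q,Q, A, N,q,LM);   // overwrites q[i] with Q(A_i - r)
    ll o=1;
    for(int i=0;i<Q;i++)
    {
        // multiply in the factor P_i * C(A_i, r) + (1 - P_i)
        o=(o*((((q[i].v)*ans[i])%mod)+(1-p[i]+mod))%mod)%mod;
    }
    std::cout<<o<<std::endl;
}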
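To test the NTT-based solution above on small inputs, it can help to compute the same product directly. The following is a minimal brute-force sketch, not the setter's or tester's code. The helper names (`rising_poly`, `eval_poly`, `expected_ways`) are hypothetical. It assumes the P_i are already given modulo 998244353 and that r < 998244353. It spends O(r^2) building Q(x) = (x+1)(x+2)...(x+r) one linear factor at a time, plus O(r) Horner steps per school, so it only suits small r and K.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical brute-force checker (not the setter's code): O(r^2 + K * r) time.
using u64 = std::uint64_t;
constexpr u64 MOD = 998244353;

u64 pow_mod(u64 a, u64 e) {
    u64 res = 1;
    a %= MOD;
    while (e) {
        if (e & 1) res = res * a % MOD;
        a = a * a % MOD;
        e >>= 1;
    }
    return res;
}

// Coefficients (lowest degree first) of Q(x) = (x+1)(x+2)...(x+r),
// multiplying in one linear factor (x + i) at a time.
std::vector<u64> rising_poly(u64 r) {
    std::vector<u64> c{1};
    for (u64 i = 1; i <= r; ++i) {
        std::vector<u64> nxt(c.size() + 1, 0);
        for (std::size_t k = 0; k < c.size(); ++k) {
            nxt[k] = (nxt[k] + c[k] * (i % MOD)) % MOD;  // c[k] * i
            nxt[k + 1] = (nxt[k + 1] + c[k]) % MOD;      // c[k] * x
        }
        c.swap(nxt);
    }
    return c;
}

// Horner evaluation of the polynomial with coefficients c at point x (mod MOD).
u64 eval_poly(const std::vector<u64>& c, u64 x) {
    u64 y = 0;
    for (std::size_t k = c.size(); k-- > 0;) y = (y * x + c[k]) % MOD;
    return y;
}

// prod_i [ P_i * C(A_i, r) + (1 - P_i) ] mod MOD, using C(A_i, r) = Q(A_i - r) / r!.
// When A_i < r, one factor of Q(A_i - r) is zero, so C(A_i, r) correctly becomes 0.
u64 expected_ways(u64 r, const std::vector<u64>& A, const std::vector<u64>& P) {
    assert(A.size() == P.size());
    std::vector<u64> q = rising_poly(r);
    u64 fact = 1;
    for (u64 i = 1; i <= r; ++i) fact = fact * (i % MOD) % MOD;
    u64 inv_fact = pow_mod(fact, MOD - 2);  // (r!)^{-1} by Fermat's little theorem
    u64 ans = 1;
    for (std::size_t i = 0; i < A.size(); ++i) {
        u64 x = (A[i] % MOD + MOD - r % MOD) % MOD;  // A_i - r (mod MOD)
        u64 binom = eval_poly(q, x) * inv_fact % MOD;  // C(A_i, r)
        u64 p = P[i] % MOD;
        u64 term = (p * binom + (1 + MOD - p)) % MOD;  // P_i * C(A_i, r) + (1 - P_i)
        ans = ans * term % MOD;
    }
    return ans;
}
```

For example, with r = 2, A = (4, 3) and both P_i = 1/2 (499122177 modulo 998244353), the factors are (6 + 1)/2 and (3 + 1)/2, so `expected_ways(2, {4, 3}, {499122177, 499122177})` returns 7. Comparing its output with the setter's program on random small tests is a quick way to catch mistakes in the multipoint evaluation.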
Tester's Solution
/* in the name of Anton */
 
/*
  Compete against Yourself.
  Author - Aryan (@aryanc403)
  Atcoder library - https://atcoder.github.io/ac-library/production/document_en/
*/
 
#ifdef ARYANC403
    #include <header.h>
#else
    #pragma GCC optimize ("Ofast")
    #pragma GCC target ("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx")
    //#pragma GCC optimize ("-ffloat-store")
    #include<bits/stdc++.h>
    #define dbg(args...) 42;
#endif
 
using namespace std;
#define fo(i,n)   for(i=0;i<(n);++i)
#define repA(i,j,n)   for(i=(j);i<=(n);++i)
#define repD(i,j,n)   for(i=(j);i>=(n);--i)
#define all(x) begin(x), end(x)
#define sz(x) ((lli)(x).size())
#define pb push_back
#define mp make_pair
#define X first
#define Y second
#define endl "\n"
 
typedef long long int lli;
typedef long double mytype;
typedef pair<lli,lli> ii;
typedef vector<ii> vii;
typedef vector<lli> vi;
 
const auto start_time = std::chrono::high_resolution_clock::now();
void aryanc403()
{
#ifdef ARYANC403
auto end_time = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = end_time-start_time;
    cerr<<"Time Taken : "<<diff.count()<<"\n";
#endif
}
 
long long readInt(long long l, long long r, char endd) {
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true) {
        char g=getchar();
        if(g=='-') {
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g&&g<='9') {
            x*=10;
            x+=g-'0';
            if(cnt==0) {
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
 
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd) {
            if(is_neg) {
                x=-x;
            }
            assert(l<=x&&x<=r);
            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l, int r, char endd) {
    string ret="";
    int cnt=0;
    while(true) {
        char g=getchar();
        assert(g!=-1);
        if(g==endd) {
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt&&cnt<=r);
    return ret;
}
long long readIntSp(long long l, long long r) {
    return readInt(l,r,' ');
}
long long readIntLn(long long l, long long r) {
    return readInt(l,r,'\n');
}
string readStringLn(int l, int r) {
    return readString(l,r,'\n');
}
string readStringSp(int l, int r) {
    return readString(l,r,' ');
}
 
void readEOF(){
    assert(getchar()==EOF);
}
 
vi readVectorInt(int n,lli l,lli r){
    vi a(n);
    for(int i=0;i<n-1;++i)
        a[i]=readIntSp(l,r);
    a[n-1]=readIntLn(l,r);
    return a;
}
 
const lli INF = 0xFFFFFFFFFFFFFFFL;
 
lli seed;
mt19937 rng(seed=chrono::steady_clock::now().time_since_epoch().count());
inline lli rnd(lli l=0,lli r=INF)
{return uniform_int_distribution<lli>(l,r)(rng);}
 
class CMP
{public:
bool operator()(ii a , ii b) //For min priority_queue .
{    return ! ( a.X < b.X || ( a.X==b.X && a.Y <= b.Y ));   }};
 
void add( map<lli,lli> &m, lli x,lli cnt=1)
{
    auto jt=m.find(x);
    if(jt==m.end())         m.insert({x,cnt});
    else                    jt->Y+=cnt;
}
 
void del( map<lli,lli> &m, lli x,lli cnt=1)
{
    auto jt=m.find(x);
    if(jt->Y<=cnt)            m.erase(jt);
    else                      jt->Y-=cnt;
}
 
bool cmp(const ii &a,const ii &b)
{
    return a.X<b.X||(a.X==b.X&&a.Y<b.Y);
}
 
 
 
#include <algorithm>
#include <array>
#include <cassert>
#include <type_traits>
#include <vector>
 
 
#ifdef _MSC_VER
#include <intrin.h>
#endif
 
namespace atcoder {
 
namespace internal {
 
int ceil_pow2(int n) {
    int x = 0;
    while ((1U << x) < (unsigned int)(n)) x++;
    return x;
}
 
int bsf(unsigned int n) {
#ifdef _MSC_VER
    unsigned long index;
    _BitScanForward(&index, n);
    return index;
#else
    return __builtin_ctz(n);
#endif
}
 
}  // namespace internal
 
}  // namespace atcoder
 
 
#include <cassert>
#include <numeric>
#include <type_traits>
 
#ifdef _MSC_VER
#include <intrin.h>
#endif
 
 
#include <utility>
 
#ifdef _MSC_VER
#include <intrin.h>
#endif
 
namespace atcoder {
 
namespace internal {
 
constexpr long long safe_mod(long long x, long long m) {
    x %= m;
    if (x < 0) x += m;
    return x;
}
 
struct barrett {
    unsigned int _m;
    unsigned long long im;
 
    barrett(unsigned int m) : _m(m), im((unsigned long long)(-1) / m + 1) {}
 
    unsigned int umod() const { return _m; }
 
    unsigned int mul(unsigned int a, unsigned int b) const {
 
        unsigned long long z = a;
        z *= b;
#ifdef _MSC_VER
        unsigned long long x;
        _umul128(z, im, &x);
#else
        unsigned long long x =
            (unsigned long long)(((unsigned __int128)(z)*im) >> 64);
#endif
        unsigned int v = (unsigned int)(z - x * _m);
        if (_m <= v) v += _m;
        return v;
    }
};
 
constexpr long long pow_mod_constexpr(long long x, long long n, int m) {
    if (m == 1) return 0;
    unsigned int _m = (unsigned int)(m);
    unsigned long long r = 1;
    unsigned long long y = safe_mod(x, m);
    while (n) {
        if (n & 1) r = (r * y) % _m;
        y = (y * y) % _m;
        n >>= 1;
    }
    return r;
}
 
constexpr bool is_prime_constexpr(int n) {
    if (n <= 1) return false;
    if (n == 2 || n == 7 || n == 61) return true;
    if (n % 2 == 0) return false;
    long long d = n - 1;
    while (d % 2 == 0) d /= 2;
    constexpr long long bases[3] = {2, 7, 61};
    for (long long a : bases) {
        long long t = d;
        long long y = pow_mod_constexpr(a, t, n);
        while (t != n - 1 && y != 1 && y != n - 1) {
            y = y * y % n;
            t <<= 1;
        }
        if (y != n - 1 && t % 2 == 0) {
            return false;
        }
    }
    return true;
}
template <int n> constexpr bool is_prime = is_prime_constexpr(n);
 
constexpr std::pair<long long, long long> inv_gcd(long long a, long long b) {
    a = safe_mod(a, b);
    if (a == 0) return {b, 0};
 
    long long s = b, t = a;
    long long m0 = 0, m1 = 1;
 
    while (t) {
        long long u = s / t;
        s -= t * u;
        m0 -= m1 * u;  // |m1 * u| <= |m1| * s <= b
 
 
        auto tmp = s;
        s = t;
        t = tmp;
        tmp = m0;
        m0 = m1;
        m1 = tmp;
    }
    if (m0 < 0) m0 += b / s;
    return {s, m0};
}
 
constexpr int primitive_root_constexpr(int m) {
    if (m == 2) return 1;
    if (m == 167772161) return 3;
    if (m == 469762049) return 3;
    if (m == 754974721) return 11;
    if (m == 998244353) return 3;
    int divs[20] = {};
    divs[0] = 2;
    int cnt = 1;
    int x = (m - 1) / 2;
    while (x % 2 == 0) x /= 2;
    for (int i = 3; (long long)(i)*i <= x; i += 2) {
        if (x % i == 0) {
            divs[cnt++] = i;
            while (x % i == 0) {
                x /= i;
            }
        }
    }
    if (x > 1) {
        divs[cnt++] = x;
    }
    for (int g = 2;; g++) {
        bool ok = true;
        for (int i = 0; i < cnt; i++) {
            if (pow_mod_constexpr(g, (m - 1) / divs[i], m) == 1) {
                ok = false;
                break;
            }
        }
        if (ok) return g;
    }
}
template <int m> constexpr int primitive_root = primitive_root_constexpr(m);
 
}  // namespace internal
 
}  // namespace atcoder
 
 
#include <cassert>
#include <numeric>
#include <type_traits>
 
namespace atcoder {
 
namespace internal {
 
#ifndef _MSC_VER
template <class T>
using is_signed_int128 =
    typename std::conditional<std::is_same<T, __int128_t>::value ||
                                  std::is_same<T, __int128>::value,
                              std::true_type,
                              std::false_type>::type;
 
template <class T>
using is_unsigned_int128 =
    typename std::conditional<std::is_same<T, __uint128_t>::value ||
                                  std::is_same<T, unsigned __int128>::value,
                              std::true_type,
                              std::false_type>::type;
 
template <class T>
using make_unsigned_int128 =
    typename std::conditional<std::is_same<T, __int128_t>::value,
                              __uint128_t,
                              unsigned __int128>;
 
template <class T>
using is_integral = typename std::conditional<std::is_integral<T>::value ||
                                                  is_signed_int128<T>::value ||
                                                  is_unsigned_int128<T>::value,
                                              std::true_type,
                                              std::false_type>::type;
 
template <class T>
using is_signed_int = typename std::conditional<(is_integral<T>::value &&
                                                 std::is_signed<T>::value) ||
                                                    is_signed_int128<T>::value,
                                                std::true_type,
                                                std::false_type>::type;
 
template <class T>
using is_unsigned_int =
    typename std::conditional<(is_integral<T>::value &&
                               std::is_unsigned<T>::value) ||
                                  is_unsigned_int128<T>::value,
                              std::true_type,
                              std::false_type>::type;
 
template <class T>
using to_unsigned = typename std::conditional<
    is_signed_int128<T>::value,
    make_unsigned_int128<T>,
    typename std::conditional<std::is_signed<T>::value,
                              std::make_unsigned<T>,
                              std::common_type<T>>::type>::type;
 
#else
 
template <class T> using is_integral = typename std::is_integral<T>;
 
template <class T>
using is_signed_int =
    typename std::conditional<is_integral<T>::value && std::is_signed<T>::value,
                              std::true_type,
                              std::false_type>::type;
 
template <class T>
using is_unsigned_int =
    typename std::conditional<is_integral<T>::value &&
                                  std::is_unsigned<T>::value,
                              std::true_type,
                              std::false_type>::type;
 
template <class T>
using to_unsigned = typename std::conditional<is_signed_int<T>::value,
                                              std::make_unsigned<T>,
                                              std::common_type<T>>::type;
 
#endif
 
template <class T>
using is_signed_int_t = std::enable_if_t<is_signed_int<T>::value>;
 
template <class T>
using is_unsigned_int_t = std::enable_if_t<is_unsigned_int<T>::value>;
 
template <class T> using to_unsigned_t = typename to_unsigned<T>::type;
 
}  // namespace internal
 
}  // namespace atcoder
 
 
namespace atcoder {
 
namespace internal {
 
struct modint_base {};
struct static_modint_base : modint_base {};
 
template <class T> using is_modint = std::is_base_of<modint_base, T>;
template <class T> using is_modint_t = std::enable_if_t<is_modint<T>::value>;
 
}  // namespace internal
 
template <int m, std::enable_if_t<(1 <= m)>* = nullptr>
struct static_modint : internal::static_modint_base {
    using mint = static_modint;
 
  public:
    static constexpr int mod() { return m; }
    static mint raw(int v) {
        mint x;
        x._v = v;
        return x;
    }
 
    static_modint() : _v(0) {}
    template <class T, internal::is_signed_int_t<T>* = nullptr>
    static_modint(T v) {
        long long x = (long long)(v % (long long)(umod()));
        if (x < 0) x += umod();
        _v = (unsigned int)(x);
    }
    template <class T, internal::is_unsigned_int_t<T>* = nullptr>
    static_modint(T v) {
        _v = (unsigned int)(v % umod());
    }
 
    unsigned int val() const { return _v; }
 
    mint& operator++() {
        _v++;
        if (_v == umod()) _v = 0;
        return *this;
    }
    mint& operator--() {
        if (_v == 0) _v = umod();
        _v--;
        return *this;
    }
    mint operator++(int) {
        mint result = *this;
        ++*this;
        return result;
    }
    mint operator--(int) {
        mint result = *this;
        --*this;
        return result;
    }
 
    mint& operator+=(const mint& rhs) {
        _v += rhs._v;
        if (_v >= umod()) _v -= umod();
        return *this;
    }
    mint& operator-=(const mint& rhs) {
        _v -= rhs._v;
        if (_v >= umod()) _v += umod();
        return *this;
    }
    mint& operator*=(const mint& rhs) {
        unsigned long long z = _v;
        z *= rhs._v;
        _v = (unsigned int)(z % umod());
        return *this;
    }
    mint& operator/=(const mint& rhs) { return *this = *this * rhs.inv(); }
 
    mint operator+() const { return *this; }
    mint operator-() const { return mint() - *this; }
 
    mint pow(long long n) const {
        assert(0 <= n);
        mint x = *this, r = 1;
        while (n) {
            if (n & 1) r *= x;
            x *= x;
            n >>= 1;
        }
        return r;
    }
    mint inv() const {
        if (prime) {
            assert(_v);
            return pow(umod() - 2);
        } else {
            auto eg = internal::inv_gcd(_v, m);
            assert(eg.first == 1);
            return eg.second;
        }
    }
 
    friend mint operator+(const mint& lhs, const mint& rhs) {
        return mint(lhs) += rhs;
    }
    friend mint operator-(const mint& lhs, const mint& rhs) {
        return mint(lhs) -= rhs;
    }
    friend mint operator*(const mint& lhs, const mint& rhs) {
        return mint(lhs) *= rhs;
    }
    friend mint operator/(const mint& lhs, const mint& rhs) {
        return mint(lhs) /= rhs;
    }
    friend bool operator==(const mint& lhs, const mint& rhs) {
        return lhs._v == rhs._v;
    }
    friend bool operator!=(const mint& lhs, const mint& rhs) {
        return lhs._v != rhs._v;
    }
 
  private:
    unsigned int _v;
    static constexpr unsigned int umod() { return m; }
    static constexpr bool prime = internal::is_prime<m>;
};
 
template <int id> struct dynamic_modint : internal::modint_base {
    using mint = dynamic_modint;
 
  public:
    static int mod() { return (int)(bt.umod()); }
    static void set_mod(int m) {
        assert(1 <= m);
        bt = internal::barrett(m);
    }
    static mint raw(int v) {
        mint x;
        x._v = v;
        return x;
    }
 
    dynamic_modint() : _v(0) {}
    template <class T, internal::is_signed_int_t<T>* = nullptr>
    dynamic_modint(T v) {
        long long x = (long long)(v % (long long)(mod()));
        if (x < 0) x += mod();
        _v = (unsigned int)(x);
    }
    template <class T, internal::is_unsigned_int_t<T>* = nullptr>
    dynamic_modint(T v) {
        _v = (unsigned int)(v % mod());
    }
 
    unsigned int val() const { return _v; }
 
    mint& operator++() {
        _v++;
        if (_v == umod()) _v = 0;
        return *this;
    }
    mint& operator--() {
        if (_v == 0) _v = umod();
        _v--;
        return *this;
    }
    mint operator++(int) {
        mint result = *this;
        ++*this;
        return result;
    }
    mint operator--(int) {
        mint result = *this;
        --*this;
        return result;
    }
 
    mint& operator+=(const mint& rhs) {
        _v += rhs._v;
        if (_v >= umod()) _v -= umod();
        return *this;
    }
    mint& operator-=(const mint& rhs) {
        _v += mod() - rhs._v;
        if (_v >= umod()) _v -= umod();
        return *this;
    }
    mint& operator*=(const mint& rhs) {
        _v = bt.mul(_v, rhs._v);
        return *this;
    }
    mint& operator/=(const mint& rhs) { return *this = *this * rhs.inv(); }
 
    mint operator+() const { return *this; }
    mint operator-() const { return mint() - *this; }
 
    mint pow(long long n) const {
        assert(0 <= n);
        mint x = *this, r = 1;
        while (n) {
            if (n & 1) r *= x;
            x *= x;
            n >>= 1;
        }
        return r;
    }
    mint inv() const {
        auto eg = internal::inv_gcd(_v, mod());
        assert(eg.first == 1);
        return eg.second;
    }
 
    friend mint operator+(const mint& lhs, const mint& rhs) {
        return mint(lhs) += rhs;
    }
    friend mint operator-(const mint& lhs, const mint& rhs) {
        return mint(lhs) -= rhs;
    }
    friend mint operator*(const mint& lhs, const mint& rhs) {
        return mint(lhs) *= rhs;
    }
    friend mint operator/(const mint& lhs, const mint& rhs) {
        return mint(lhs) /= rhs;
    }
    friend bool operator==(const mint& lhs, const mint& rhs) {
        return lhs._v == rhs._v;
    }
    friend bool operator!=(const mint& lhs, const mint& rhs) {
        return lhs._v != rhs._v;
    }
 
  private:
    unsigned int _v;
    static internal::barrett bt;
    static unsigned int umod() { return bt.umod(); }
};
template <int id> internal::barrett dynamic_modint<id>::bt = 998244353;
 
using modint998244353 = static_modint<998244353>;
using modint1000000007 = static_modint<1000000007>;
using modint = dynamic_modint<-1>;
 
namespace internal {
 
template <class T>
using is_static_modint = std::is_base_of<internal::static_modint_base, T>;
 
template <class T>
using is_static_modint_t = std::enable_if_t<is_static_modint<T>::value>;
 
template <class> struct is_dynamic_modint : public std::false_type {};
template <int id>
struct is_dynamic_modint<dynamic_modint<id>> : public std::true_type {};
 
template <class T>
using is_dynamic_modint_t = std::enable_if_t<is_dynamic_modint<T>::value>;
 
}  // namespace internal
 
}  // namespace atcoder
 
 
namespace atcoder {
 
namespace internal {
 
template <class mint, internal::is_static_modint_t<mint>* = nullptr>
void butterfly(std::vector<mint>& a) {
    static constexpr int g = internal::primitive_root<mint::mod()>;
    int n = int(a.size());
    int h = internal::ceil_pow2(n);
 
    static bool first = true;
    static mint sum_e[30];  // sum_e[i] = ies[0] * ... * ies[i - 1] * es[i]
    if (first) {
        first = false;
        mint es[30], ies[30];  // es[i]^(2^(2+i)) == 1
        int cnt2 = bsf(mint::mod() - 1);
        mint e = mint(g).pow((mint::mod() - 1) >> cnt2), ie = e.inv();
        for (int i = cnt2; i >= 2; i--) {
            es[i - 2] = e;
            ies[i - 2] = ie;
            e *= e;
            ie *= ie;
        }
        mint now = 1;
        for (int i = 0; i <= cnt2 - 2; i++) {
            sum_e[i] = es[i] * now;
            now *= ies[i];
        }
    }
    for (int ph = 1; ph <= h; ph++) {
        int w = 1 << (ph - 1), p = 1 << (h - ph);
        mint now = 1;
        for (int s = 0; s < w; s++) {
            int offset = s << (h - ph + 1);
            for (int i = 0; i < p; i++) {
                auto l = a[i + offset];
                auto r = a[i + offset + p] * now;
                a[i + offset] = l + r;
                a[i + offset + p] = l - r;
            }
            now *= sum_e[bsf(~(unsigned int)(s))];
        }
    }
}
 
template <class mint, internal::is_static_modint_t<mint>* = nullptr>
void butterfly_inv(std::vector<mint>& a) {
    static constexpr int g = internal::primitive_root<mint::mod()>;
    int n = int(a.size());
    int h = internal::ceil_pow2(n);
 
    static bool first = true;
    static mint sum_ie[30];  // sum_ie[i] = es[0] * ... * es[i - 1] * ies[i]
    if (first) {
        first = false;
        mint es[30], ies[30];  // es[i]^(2^(2+i)) == 1
        int cnt2 = bsf(mint::mod() - 1);
        mint e = mint(g).pow((mint::mod() - 1) >> cnt2), ie = e.inv();
        for (int i = cnt2; i >= 2; i--) {
            es[i - 2] = e;
            ies[i - 2] = ie;
            e *= e;
            ie *= ie;
        }
        mint now = 1;
        for (int i = 0; i <= cnt2 - 2; i++) {
            sum_ie[i] = ies[i] * now;
            now *= es[i];
        }
    }
 
    for (int ph = h; ph >= 1; ph--) {
        int w = 1 << (ph - 1), p = 1 << (h - ph);
        mint inow = 1;
        for (int s = 0; s < w; s++) {
            int offset = s << (h - ph + 1);
            for (int i = 0; i < p; i++) {
                auto l = a[i + offset];
                auto r = a[i + offset + p];
                a[i + offset] = l + r;
                a[i + offset + p] =
                    (unsigned long long)(mint::mod() + l.val() - r.val()) *
                    inow.val();
            }
            inow *= sum_ie[bsf(~(unsigned int)(s))];
        }
    }
}
 
}  // namespace internal
 
template <class mint, internal::is_static_modint_t<mint>* = nullptr>
std::vector<mint> convolution(std::vector<mint> a, std::vector<mint> b) {
    int n = int(a.size()), m = int(b.size());
    if (!n || !m) return {};
    if (std::min(n, m) <= 60) {
        if (n < m) {
            std::swap(n, m);
            std::swap(a, b);
        }
        std::vector<mint> ans(n + m - 1);
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                ans[i + j] += a[i] * b[j];
            }
        }
        return ans;
    }
    int z = 1 << internal::ceil_pow2(n + m - 1);
    a.resize(z);
    internal::butterfly(a);
    b.resize(z);
    internal::butterfly(b);
    for (int i = 0; i < z; i++) {
        a[i] *= b[i];
    }
    internal::butterfly_inv(a);
    a.resize(n + m - 1);
    mint iz = mint(z).inv();
    for (int i = 0; i < n + m - 1; i++) a[i] *= iz;
    return a;
}
 
template <unsigned int mod = 998244353,
          class T,
          std::enable_if_t<internal::is_integral<T>::value>* = nullptr>
std::vector<T> convolution(const std::vector<T>& a, const std::vector<T>& b) {
    int n = int(a.size()), m = int(b.size());
    if (!n || !m) return {};
 
    using mint = static_modint<mod>;
    std::vector<mint> a2(n), b2(m);
    for (int i = 0; i < n; i++) {
        a2[i] = mint(a[i]);
    }
    for (int i = 0; i < m; i++) {
        b2[i] = mint(b[i]);
    }
    auto c2 = convolution(move(a2), move(b2));
    std::vector<T> c(n + m - 1);
    for (int i = 0; i < n + m - 1; i++) {
        c[i] = c2[i].val();
    }
    return c;
}
 
std::vector<long long> convolution_ll(const std::vector<long long>& a,
                                      const std::vector<long long>& b) {
    int n = int(a.size()), m = int(b.size());
    if (!n || !m) return {};
 
    static constexpr unsigned long long MOD1 = 754974721;  // 2^24
    static constexpr unsigned long long MOD2 = 167772161;  // 2^25
    static constexpr unsigned long long MOD3 = 469762049;  // 2^26
    static constexpr unsigned long long M2M3 = MOD2 * MOD3;
    static constexpr unsigned long long M1M3 = MOD1 * MOD3;
    static constexpr unsigned long long M1M2 = MOD1 * MOD2;
    static constexpr unsigned long long M1M2M3 = MOD1 * MOD2 * MOD3;
 
    static constexpr unsigned long long i1 =
        internal::inv_gcd(MOD2 * MOD3, MOD1).second;
    static constexpr unsigned long long i2 =
        internal::inv_gcd(MOD1 * MOD3, MOD2).second;
    static constexpr unsigned long long i3 =
        internal::inv_gcd(MOD1 * MOD2, MOD3).second;
 
    auto c1 = convolution<MOD1>(a, b);
    auto c2 = convolution<MOD2>(a, b);
    auto c3 = convolution<MOD3>(a, b);
 
    std::vector<long long> c(n + m - 1);
    for (int i = 0; i < n + m - 1; i++) {
        unsigned long long x = 0;
        x += (c1[i] * i1) % MOD1 * M2M3;
        x += (c2[i] * i2) % MOD2 * M1M3;
        x += (c3[i] * i3) % MOD3 * M1M2;
        long long diff =
            c1[i] - internal::safe_mod((long long)(x), (long long)(MOD1));
        if (diff < 0) diff += MOD1;
        static constexpr unsigned long long offset[5] = {
            0, 0, M1M2M3, 2 * M1M2M3, 3 * M1M2M3};
        x -= offset[diff % 5];
        c[i] = x;
    }
 
    return c;
}
 
}  // namespace atcoder
 
 
using namespace atcoder;
using mint = modint998244353;
std::ostream& operator << (std::ostream& out, const mint& rhs) {
        return out<<rhs.val();
    }
 
// https://atcoder.jp/contests/arc113/submissions/20423265
namespace algebra {
    template <typename T>
vector<T>& operator+=(vector<T>& a, const vector<T>& b) {
  if (a.size() < b.size()) {
    a.resize(b.size());
  }
  for (int i = 0; i < (int) b.size(); i++) {
    a[i] += b[i];
  }
  return a;
}
 
template <typename T>
vector<T> operator+(const vector<T>& a, const vector<T>& b) {
  vector<T> c = a;
  return c += b;
}
 
template <typename T>
vector<T>& operator-=(vector<T>& a, const vector<T>& b) {
  if (a.size() < b.size()) {
    a.resize(b.size());
  }
  for (int i = 0; i < (int) b.size(); i++) {
    a[i] -= b[i];
  }
  return a;
}
 
template <typename T>
vector<T> operator-(const vector<T>& a, const vector<T>& b) {
  vector<T> c = a;
  return c -= b;
}
 
template <typename T>
vector<T> operator-(const vector<T>& a) {
  vector<T> c = a;
  for (int i = 0; i < (int) c.size(); i++) {
    c[i] = -c[i];
  }
  return c;
}
 
template <typename T>
vector<T> operator*(const vector<T>& a, const vector<T>& b) {
  if (a.empty() || b.empty()) {
    return {};
  }
  return convolution<T>(a,b);
  // vector<T> c(a.size() + b.size() - 1, 0);
  // for (int i = 0; i < (int) a.size(); i++) {
  //   for (int j = 0; j < (int) b.size(); j++) {
  //     c[i + j] += a[i] * b[j];
  //   }
  // }
  // return c;
}
 
template <typename T>
vector<T>& operator*=(vector<T>& a, const vector<T>& b) {
  return a = a * b;
}
 
template <typename T>
vector<T> inverse(const vector<T>& a) {
  assert(!a.empty());
  int n = (int) a.size();
  vector<T> b = {1 / a[0]};
  while ((int) b.size() < n) {
    vector<T> a_cut(a.begin(), a.begin() + min(a.size(), b.size() << 1));
    vector<T> x = b * b * a_cut;
    b.resize(b.size() << 1);
    for (int i = (int) b.size() >> 1; i < (int) min(x.size(), b.size()); i++) {
      b[i] = -x[i];
    }
  }
  b.resize(n);
  return b;
}
 
template <typename T>
vector<T>& operator/=(vector<T>& a, const vector<T>& b) {
  int n = (int) a.size();
  int m = (int) b.size();
  if (n < m) {
    a.clear();
  } else {
    vector<T> d = b;
    reverse(a.begin(), a.end());
    reverse(d.begin(), d.end());
    d.resize(n - m + 1);
    a *= inverse(d);
    a.erase(a.begin() + n - m + 1, a.end());
    reverse(a.begin(), a.end());
  }
  return a;
}
 
template <typename T>
vector<T> operator/(const vector<T>& a, const vector<T>& b) {
  vector<T> c = a;
  return c /= b;
}
 
template <typename T>
vector<T> operator*(const vector<T>& a, T b) {
  vector<T> c = a;
  for(auto &x:c)
    x*=b;
  return c;
}
 
template <typename T>
vector<T>& operator%=(vector<T>& a, const vector<T>& b) {
  int n = (int) a.size();
  int m = (int) b.size();
  if (n >= m) {
    vector<T> c = (a / b) * b;
    a.resize(m - 1);
    for (int i = 0; i < m - 1; i++) {
      a[i] -= c[i];
    }
  }
  return a;
}
 
template <typename T>
vector<T> operator%(const vector<T>& a, const vector<T>& b) {
  vector<T> c = a;
  return c %= b;
}
 
template <typename T, typename U>
vector<T> power(const vector<T>& a, const U& b, const vector<T>& c) {
  assert(b >= 0);
  vector<U> binary;
  U bb = b;
  while (bb > 0) {
    binary.push_back(bb & 1);
    bb >>= 1;
  }
  vector<T> res = vector<T>{1} % c;
  for (int j = (int) binary.size() - 1; j >= 0; j--) {
    res = res * res % c;
    if (binary[j] == 1) {
      res = res * a % c;
    }
  }
  return res;
}
 
template <typename T>
vector<T> derivative(const vector<T>& a) {
  vector<T> c = a;
  for (int i = 0; i < (int) c.size(); i++) {
    c[i] *= i;
  }
  if (!c.empty()) {
    c.erase(c.begin());
  }
  return c;
}
 
template <typename T>
vector<T> integrate(const vector<T>& a) {
  vector<T> c = {0};
  for (int i = 0; i < (int) a.size(); i++) {
    c.push_back(a[i]/(i+1));
  }
  return c;
}
 
template <typename T>
vector<T> primitive(const vector<T>& a) {
  vector<T> c = a;
  c.insert(c.begin(), 0);
  for (int i = 1; i < (int) c.size(); i++) {
    c[i] /= i;
  }
  return c;
}
 
template <typename T>
vector<T> logarithm(const vector<T>& a) {
  assert(!a.empty() && a[0] == 1);
  vector<T> u = primitive(derivative(a) * inverse(a));
  u.resize(a.size());
  return u;
}
 
template <typename T>
vector<T> exponent(const vector<T>& a) {
  assert(!a.empty() && a[0] == 0);
  int n = (int) a.size();
  vector<T> b = {1};
  while ((int) b.size() < n) {
    vector<T> x(a.begin(), a.begin() + min(a.size(), b.size() << 1));
    x[0] += 1;
    vector<T> old_b = b;
    b.resize(b.size() << 1);
    x -= logarithm(b);
    x *= old_b;
    for (int i = (int) b.size() >> 1; i < (int) min(x.size(), b.size()); i++) {
      b[i] = x[i];
    }
  }
  b.resize(n);
  return b;
}
 
template <typename T>
vector<T> sqrt(const vector<T>& a) {
  assert(!a.empty() && a[0] == 1);
  int n = (int) a.size();
  vector<T> b = {1};
  while ((int) b.size() < n) {
    vector<T> x(a.begin(), a.begin() + min(a.size(), b.size() << 1));
    b.resize(b.size() << 1);
    x *= inverse(b);
    T inv2 = 1 / static_cast<T>(2);
    for (int i = (int) b.size() >> 1; i < (int) min(x.size(), b.size()); i++) {
      b[i] = x[i] * inv2;
    }
  }
  b.resize(n);
  return b;
}
 
template <typename T>
vector<T> multiply(const vector<vector<T>>& a) {
  if (a.empty()) {
    return {0};
  }
  function<vector<T>(int, int)> mult = [&](int l, int r) {
    if (l == r) {
      return a[l];
    }
    int y = (l + r) >> 1;
    return mult(l, y) * mult(y + 1, r);
  };
  return mult(0, (int) a.size() - 1);
}
 
template <typename T>
T evaluate(const vector<T>& a, const T& x) {
  T res = 0;
  for (int i = (int) a.size() - 1; i >= 0; i--) {
    res = res * x + a[i];
  }
  return res;
}
 
template <typename T>
vector<T> evaluate(const vector<T>& a, const vector<T>& x) {
  if (x.empty()) {
    return {};
  }
  if (a.empty()) {
    return vector<T>(x.size(), 0);
  }
  int n = (int) x.size();
  vector<vector<T>> st((n << 1) - 1);
  function<void(int, int, int)> build = [&](int v, int l, int r) {
    if (l == r) {
      st[v] = vector<T>{-x[l], 1};
    } else {
      int y = (l + r) >> 1;
      int z = v + ((y - l + 1) << 1);
      build(v + 1, l, y);
      build(z, y + 1, r);
      st[v] = st[v + 1] * st[z];
    }
  };
  build(0, 0, n - 1);
  vector<T> res(n);
  function<void(int, int, int, vector<T>)> eval = [&](int v, int l, int r, vector<T> f) {
    f %= st[v];
    if ((int) f.size() < 150) {
      for (int i = l; i <= r; i++) {
        res[i] = evaluate(f, x[i]);
      }
      return;
    }
    if (l == r) {
      res[l] = f[0];
    } else {
      int y = (l + r) >> 1;
      int z = v + ((y - l + 1) << 1);
      eval(v + 1, l, y, f);
      eval(z, y + 1, r, f);
    }
  };
  eval(0, 0, n - 1, a);
  return res;
}
 
template <typename T>
vector<T> interpolate(const vector<T>& x, const vector<T>& y) {
  if (x.empty()) {
    return {};
  }
  assert(x.size() == y.size());
  int n = (int) x.size();
  vector<vector<T>> st((n << 1) - 1);
  function<void(int, int, int)> build = [&](int v, int l, int r) {
    if (l == r) {
      st[v] = vector<T>{-x[l], 1};
    } else {
      int w = (l + r) >> 1;
      int z = v + ((w - l + 1) << 1);
      build(v + 1, l, w);
      build(z, w + 1, r);
      st[v] = st[v + 1] * st[z];
    }
  };
  build(0, 0, n - 1);
  vector<T> m = st[0];
  vector<T> dm = derivative(m);
  vector<T> val(n);
  function<void(int, int, int, vector<T>)> eval = [&](int v, int l, int r, vector<T> f) {
    f %= st[v];
    if ((int) f.size() < 150) {
      for (int i = l; i <= r; i++) {
        val[i] = evaluate(f, x[i]);
      }
      return;
    }
    if (l == r) {
      val[l] = f[0];
    } else {
      int w = (l + r) >> 1;
      int z = v + ((w - l + 1) << 1);
      eval(v + 1, l, w, f);
      eval(z, w + 1, r, f);
    }
  };
  eval(0, 0, n - 1, dm);
  for (int i = 0; i < n; i++) {
    val[i] = y[i] / val[i];
  }
  function<vector<T>(int, int, int)> calc = [&](int v, int l, int r) {
    if (l == r) {
      return vector<T>{val[l]};
    }
    int w = (l + r) >> 1;
    int z = v + ((w - l + 1) << 1);
    return calc(v + 1, l, w) * st[z] + calc(z, w + 1, r) * st[v + 1];
  };
  return calc(0, 0, n - 1);
}
 
// f[i] = 1^i + 2^i + ... + up^i
template <typename T>
vector<T> faulhaber(const T& up, int n) {
  vector<T> ex(n + 1);
  T e = 1;
  for (int i = 0; i <= n; i++) {
    ex[i] = e;
    e /= i + 1;
  }
  vector<T> den = ex;
  den.erase(den.begin());
  for (auto& d : den) {
    d = -d;
  }
  vector<T> num(n);
  T p = 1;
  for (int i = 0; i < n; i++) {
    p *= up + 1;
    num[i] = ex[i + 1] * (1 - p);
  }
  vector<T> res = num * inverse(den);
  res.resize(n);
  T f = 1;
  for (int i = 0; i < n; i++) {
    res[i] *= f;
    f *= i + 1;
  }
  return res;
}
 
// (x + 1) * (x + 2) * ... * (x + n)
// (can be optimized with precomputed inverses)
template <typename T>
vector<T> sequence(int n) {
  if (n == 0) {
    return {1};
  }
  if (n % 2 == 1) {
    return sequence<T>(n - 1) * vector<T>{n, 1};
  }
  vector<T> c = sequence<T>(n / 2);
  vector<T> a = c;
  reverse(a.begin(), a.end());
  T f = 1;
  for (int i = n / 2 - 1; i >= 0; i--) {
    f *= n / 2 - i;
    a[i] *= f;
  }
  vector<T> b(n / 2 + 1);
  b[0] = 1;
  for (int i = 1; i <= n / 2; i++) {
    b[i] = b[i - 1] * (n / 2) / i;
  }
  vector<T> h = a * b;
  h.resize(n / 2 + 1);
  reverse(h.begin(), h.end());
  f = 1;
  for (int i = 1; i <= n / 2; i++) {
    f /= i;
    h[i] *= f;
  }
  vector<T> res = c * h;
  return res;
}
 
template <typename T>
class OnlineProduct {
 public:
  const vector<T> a;
  vector<T> b;
  vector<T> c;
 
  OnlineProduct(const vector<T>& a_) : a(a_) {}
 
  T add(const T& val) {
    int i = (int) b.size();
    b.push_back(val);
    if ((int) c.size() <= i) {
      c.resize(i + 1);
    }
    c[i] += a[0] * b[i];
    int z = 1;
    while ((i & (z - 1)) == z - 1 && (int) a.size() > z) {
      vector<T> a_mul(a.begin() + z, a.begin() + min(z << 1, (int) a.size()));
      vector<T> b_mul(b.end() - z, b.end());
      vector<T> c_mul = a_mul * b_mul;
      if ((int) c.size() <= i + (int) c_mul.size()) {
        c.resize(i + c_mul.size() + 1);
      }
      for (int j = 0; j < (int) c_mul.size(); j++) {
        c[i + 1 + j] += c_mul[j];
      }
      z <<= 1;
    }
    return c[i];
  }
};
};
 
using namespace algebra;
 
typedef vector<mint> polyn;
 
// Ref - https://www.codechef.com/viewsolution/41909444 Line 827 - 850.
const int N=2e6+5;
array<mint,N+1> fac,inv;
 
mint nCr(lli n,lli r)
{
    if(n<0||r<0||r>n)
        return 0;
    return fac[n]*inv[r]*inv[n-r];
}
 
void pre(lli n)
{
    fac[0]=1;
    for(int i=1;i<=n;++i)
        fac[i]=i*fac[i-1];
    inv[n]=fac[n].pow(mint(-2).val());  // Fermat: fac[n]^(M-2) is the inverse of fac[n]
    for(int i=n;i>0;--i)
        inv[i-1]=i*inv[i];
    assert(inv[0]==mint(1));
}
 
polyn nCr(lli r){
  polyn x(r+1),y(r+1);
  y[r]=1;
  for(int i=0;i<=r;++i)
    x[i]=i;
  return interpolate(x,y);
}
 
int main(void) {
    ios_base::sync_with_stdio(false);cin.tie(NULL);
    // freopen("txt.in", "r", stdin);
    // freopen("txt.out", "w", stdout);
// cout<<std::fixed<<std::setprecision(35);
// T=readIntLn(1,1);
// while(T--)
{
    pre(N);
     lli r,k;
     cin>>r>>k;
     vi a(k),p(k);
     for(auto &x:a)  cin>>x;
     for(auto &x:p)  cin>>x;
    //const lli r=readIntLn(1,100000);
    //const lli k=readIntLn(1,200000);
    //auto a=readVectorInt(k,1,1e9);
    //auto p=readVectorInt(k,1,998244352);
 
    vector<mint> xcord;
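    // sequence<mint>(r) = (x+1)(x+2)...(x+r); after scaling by 1/r! below,
    // polnCr(x) = C(x+r, r), so evaluating it at x = A_i - r gives C(A_i, r).
    auto polnCr=sequence<mint>(r);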
    for(auto &x:polnCr)
      x*=inv[r];
    for(auto x:a)
        xcord.pb(x-r);
 
    auto ycord=evaluate(polnCr,xcord);
 
    mint ans=1;
    for(int i=0;i<k;++i){
        ans*=p[i]*ycord[i]+1-p[i];
    }
    cout<<ans<<endl;
}   aryanc403();
    readEOF();
    return 0;
}
Editorialist's Solution
import java.util.*;
import java.io.*;
class MAT{
    //SOLUTION BEGIN
    void pre() throws Exception{}
    void solve(int TC) throws Exception{
        int R = ni(), K = ni();
        long[] A = new long[K];
        long[] P = new long[K];
        for(int i = 0; i< K; i++)A[i] = ((long)ni()-R+mod)%mod; // A_i - R reduced into [0, mod), since A_i can exceed mod
        for(int i = 0; i< K; i++)P[i] = nl();
        PriorityQueue<long[]> pq = new PriorityQueue<>((long[] l1, long[] l2) -> Long.compare(l1.length, l2.length));
        for(int i = 1; i<= R; i++)pq.add(new long[]{i, 1});
        while(pq.size() > 1)pq.add(mul(pq.poll(), pq.poll()));
        long[] poly = pq.poll();
        
        long[] eval = substitute(poly, A);
        long ans = 1;
        long factR = 1;
        for(int i = 1; i<= R; i++)factR = factR*i%mod;
        long invR = pow(factR, mod-2);
        for(int i = 0; i< K; i++){
            long w = (P[i]*eval[i]%mod*invR%mod + 1-P[i]+mod)%mod;
            ans = ans*w%mod;
        }
        pn(ans);
    }
    long pow(long a, long p){
        long o = 1;
        while(p > 0){
            if(p%2 == 1)o = o*a%mod;
            a = a*a%mod;
            p>>=1;
        }
        return o;
    }
    //uwi
    public static int mod = 998244353;
    public static int G = 3;
    
    public static long[] substitute(long[] p, long[] xs)
    {
	    return descendProductTree(p, buildProductTree(xs));
    }
    
    public static long[][] buildProductTree(long[] xs)
    {
	    int m = Integer.highestOneBit(xs.length)*4;
	    long[][] ms = new long[m][];
	    for(int i = 0;i < xs.length;i++){
		    ms[m/2+i] = new long[]{mod-xs[i], 1};
	    }
	    for(int i = m/2-1;i >= 1;i--){
		    if(ms[2*i] == null){
			    ms[i] = null;
		    }else if(ms[2*i+1] == null){
			    ms[i] = ms[2*i];
		    }else{
			    ms[i] = mul(ms[2*i], ms[2*i+1]);
		    }
	    }
	    return ms;
    }
    
    public static long[] mul(long[] a, long[] b)
    {
	    return Arrays.copyOf(convoluteSimply(a, b, mod, G), a.length+b.length-1);
    }
    
    public static long[] mul(long[] a, long[] b, int lim)
    {
	    return Arrays.copyOf(convoluteSimply(a, b, mod, G), lim);
    }
    
    public static long[] descendProductTree(long[] p, long[][] pt)
    {
	    long[] rets = new long[pt[1].length-1];
	    dfs(p, pt, 1, rets);
	    return rets;
    }
    
    private static void dfs(long[] p, long[][] pt, int cur, long[] rets)
    {
	    if(pt[cur] == null)return;
	    if(cur >= pt.length/2){
		    rets[cur-pt.length/2] = p[0];
	    }else{
		    // F = q1X+r1
		    // F = q2Y+r2
		    
		    if(p.length >= 800){
			    if(pt[2*cur+1] != null){
				    long[][] qr0 = div(p, pt[2*cur]);
				    dfs(qr0[1], pt, cur*2, rets);
				    long[][] qr1 = div(p, pt[2*cur+1]);
				    dfs(qr1[1], pt, cur*2+1, rets);
			    }else if(pt[2*cur] != null){
				    long[] nex = cur == 1 ? div(p, pt[2*cur])[1] : p;
				    dfs(nex, pt, cur*2, rets);
			    }
		    }else{
			    if(pt[2*cur+1] != null){
				    dfs(modnaive(p, pt[2*cur]), pt, cur*2, rets);
				    dfs(modnaive(p, pt[2*cur+1]), pt, cur*2+1, rets);
			    }else if(pt[2*cur] != null){
				    long[] nex = cur == 1 ? modnaive(p, pt[2*cur]) : p;
				    dfs(nex, pt, cur*2, rets);
			    }
		    }
	    }
    }
    
    public static long[][] div(long[] f, long[] g)
    {
	    int n = f.length, m = g.length;
	    if(n < m)return new long[][]{new long[0], Arrays.copyOf(f, n)};
	    long[] rf = reverse(f, n-m+1);
	    long[] rg = reverse(g, n-m+1);
	    long[] rq = mul(rf, inv(rg), n-m+1);
	    long[] q = reverse(rq, n-m+1);
	    long[] r = sub(f, mul(q, g, m-1), m-1);
	    return new long[][]{q, r};
    }
    
    public static long[] reverse(long[] p)
    {
	    long[] ret = new long[p.length];
	    for(int i = 0;i < p.length;i++){
		    ret[i] = p[p.length-1-i];
	    }
	    return ret;
    }
    
    public static long[] reverse(long[] p, int lim)
    {
	    long[] ret = new long[lim];
	    for(int i = 0;i < lim && i < p.length;i++){
		    ret[i] = p[p.length-1-i];
	    }
	    return ret;
    }
    
    public static long[] sub(long[] a, long[] b)
    {
	    long[] c = new long[Math.max(a.length, b.length)];
	    for(int i = 0;i < a.length;i++)c[i] += a[i];
	    for(int i = 0;i < b.length;i++)c[i] -= b[i];
	    for(int i = 0;i < c.length;i++)if(c[i] < 0)c[i] += mod;
	    return c;
    }
    
    public static long[] sub(long[] a, long[] b, int lim)
    {
	    long[] c = new long[lim];
	    for(int i = 0;i < a.length && i < lim;i++)c[i] += a[i];
	    for(int i = 0;i < b.length && i < lim;i++)c[i] -= b[i];
	    for(int i = 0;i < c.length;i++)if(c[i] < 0)c[i] += mod;
	    return c;
    }
    
    // F_{t+1}(x) = -F_t(x)^2*P(x) + 2F_t(x)
    // if want p-destructive, comment out flipping p just before returning.
    public static long[] inv(long[] p)
    {
	    int n = p.length;
	    long[] f = {invl(p[0], mod)};
	    for(int i = 0;i < p.length;i++){
		    if(p[i] == 0)continue;
		    p[i] = mod-p[i];
	    }
	    for(int i = 1;i < 2*n;i*=2){
		    long[] f2 = mul(f, f, Math.min(n, 2*i));
		    long[] f2p = mul(f2, Arrays.copyOf(p, i), Math.min(n, 2*i));
		    for(int j = 0;j < f.length;j++){
			    f2p[j] += 2L*f[j];
			    if(f2p[j] >= mod)f2p[j] -= mod;
			    if(f2p[j] >= mod)f2p[j] -= mod;
		    }
		    f = f2p;
	    }
	    for(int i = 0;i < p.length;i++){
		    if(p[i] == 0)continue;
		    p[i] = mod-p[i];
	    }
	    return f;
    }



    
    public static long[] modnaive(long[] a, long[] b)
    {
	    int n = a.length, m = b.length;
	    if(n-m+1 <= 0)return a;
	    long[] r = Arrays.copyOf(a, n);
	    long ib = invl(b[m-1], mod);
	    for(int i = n-1;i >= m-1;i--){
		    long x = ib * r[i] % mod;
		    for(int j = m-1;j >= 0;j--){
			    r[i+j-(m-1)] -= b[j]*x;
			    r[i+j-(m-1)] %= mod;
			    if(r[i+j-(m-1)] < 0)r[i+j-(m-1)] += mod;
//				r[i+j-(m-1)] = modh(r[i+j-(m-1)]+(long)mod*mod - b[j]*x, MM, HH, mod);
		    }
	    }
	    return Arrays.copyOf(r, m-1);
    }
    
    public static final int[] NTTPrimes = {1053818881, 1051721729, 1045430273, 1012924417, 1007681537, 1004535809, 998244353, 985661441, 976224257, 975175681};
    public static final int[] NTTPrimitiveRoots = {7, 6, 3, 5, 3, 3, 3, 3, 3, 17};
//	public static final int[] NTTPrimes = {1012924417, 1004535809, 998244353, 985661441, 975175681, 962592769, 950009857, 943718401, 935329793, 924844033};
//	public static final int[] NTTPrimitiveRoots = {5, 3, 3, 3, 17, 7, 7, 7, 3, 5};
    
    public static long[] convoluteSimply(long[] a, long[] b, int P, int g)
    {
	    int m = Math.max(2, Integer.highestOneBit(Math.max(a.length, b.length)-1)<<2);
	    long[] fa = nttmb(a, m, false, P, g);
	    long[] fb = a == b ? fa : nttmb(b, m, false, P, g);
	    for(int i = 0;i < m;i++){
		    fa[i] = fa[i]*fb[i]%P;
	    }
	    return nttmb(fa, m, true, P, g);
    }
    
    public static long[] convolute(long[] a, long[] b)
    {
	    int USE = 2;
	    int m = Math.max(2, Integer.highestOneBit(Math.max(a.length, b.length)-1)<<2);
	    long[][] fs = new long[USE][];
	    for(int k = 0;k < USE;k++){
		    int P = NTTPrimes[k], g = NTTPrimitiveRoots[k];
		    long[] fa = nttmb(a, m, false, P, g);
		    long[] fb = a == b ? fa : nttmb(b, m, false, P, g);
		    for(int i = 0;i < m;i++){
			    fa[i] = fa[i]*fb[i]%P;
		    }
		    fs[k] = nttmb(fa, m, true, P, g);
	    }
	    
	    int[] mods = Arrays.copyOf(NTTPrimes, USE);
	    long[] gammas = garnerPrepare(mods);
	    int[] buf = new int[USE];
	    for(int i = 0;i < fs[0].length;i++){
		    for(int j = 0;j < USE;j++)buf[j] = (int)fs[j][i];
		    long[] res = garnerBatch(buf, mods, gammas);
		    long ret = 0;
		    for(int j = res.length-1;j >= 0;j--)ret = ret * mods[j] + res[j];
		    fs[0][i] = ret;
	    }
	    return fs[0];
    }
    
    public static long[] convolute(long[] a, long[] b, int USE, int mod)
    {
	    int m = Math.max(2, Integer.highestOneBit(Math.max(a.length, b.length)-1)<<2);
	    long[][] fs = new long[USE][];
	    for(int k = 0;k < USE;k++){
		    int P = NTTPrimes[k], g = NTTPrimitiveRoots[k];
		    long[] fa = nttmb(a, m, false, P, g);
		    long[] fb = a == b ? fa : nttmb(b, m, false, P, g);
		    for(int i = 0;i < m;i++){
			    fa[i] = fa[i]*fb[i]%P;
		    }
		    fs[k] = nttmb(fa, m, true, P, g);
	    }
	    
	    int[] mods = Arrays.copyOf(NTTPrimes, USE);
	    long[] gammas = garnerPrepare(mods);
	    int[] buf = new int[USE];
	    for(int i = 0;i < fs[0].length;i++){
		    for(int j = 0;j < USE;j++)buf[j] = (int)fs[j][i];
		    long[] res = garnerBatch(buf, mods, gammas);
		    long ret = 0;
		    for(int j = res.length-1;j >= 0;j--)ret = (ret * mods[j] + res[j]) % mod;
		    fs[0][i] = ret;
	    }
	    return fs[0];
    }
    
    // static int[] wws = new int[270000]; // outer faster
    
    // Modified Montgomery + Barrett
    private static long[] nttmb(long[] src, int n, boolean inverse, int P, int g)
    {
	    long[] dst = Arrays.copyOf(src, n);
	    
	    int h = Integer.numberOfTrailingZeros(n);
	    long K = Integer.highestOneBit(P)<<1;
	    int H = Long.numberOfTrailingZeros(K)*2;
	    long M = K*K/P;
	    
	    int[] wws = new int[1<<h-1];
	    long dw = inverse ? pow(g, P-1-(P-1)/n, P) : pow(g, (P-1)/n, P);
	    long w = (1L<<32)%P;
	    for(int k = 0;k < 1<<h-1;k++){
		    wws[k] = (int)w;
		    w = modh(w*dw, M, H, P);
	    }
	    long J = invl(P, 1L<<32);
	    for(int i = 0;i < h;i++){
		    for(int j = 0;j < 1<<i;j++){
			    for(int k = 0, s = j<<h-i, t = s|1<<h-i-1;k < 1<<h-i-1;k++,s++,t++){
				    long u = (dst[s] - dst[t] + 2*P)*wws[k];
				    dst[s] += dst[t];
				    if(dst[s] >= 2*P)dst[s] -= 2*P;
//					long Q = (u&(1L<<32)-1)*J&(1L<<32)-1;
				    long Q = (u<<32)*J>>>32;
				    dst[t] = (u>>>32)-(Q*P>>>32)+P;
			    }
		    }
		    if(i < h-1){
			    for(int k = 0;k < 1<<h-i-2;k++)wws[k] = wws[k*2];
		    }
	    }
	    for(int i = 0;i < n;i++){
		    if(dst[i] >= P)dst[i] -= P;
	    }
	    for(int i = 0;i < n;i++){
		    int rev = Integer.reverse(i)>>>-h;
		    if(i < rev){
			    long d = dst[i]; dst[i] = dst[rev]; dst[rev] = d;
		    }
	    }
	    
	    if(inverse){
		    long in = invl(n, P);
		    for(int i = 0;i < n;i++)dst[i] = modh(dst[i]*in, M, H, P);
	    }
	    
	    return dst;
    }
    
    // Modified Shoup + Barrett
    private static long[] nttsb(long[] src, int n, boolean inverse, int P, int g)
    {
	    long[] dst = Arrays.copyOf(src, n);
	    
	    int h = Integer.numberOfTrailingZeros(n);
	    long K = Integer.highestOneBit(P)<<1;
	    int H = Long.numberOfTrailingZeros(K)*2;
	    long M = K*K/P;
	    
	    long dw = inverse ? pow(g, P-1-(P-1)/n, P) : pow(g, (P-1)/n, P);
	    long[] wws = new long[1<<h-1];
	    long[] ws = new long[1<<h-1];
	    long w = 1;
	    for(int k = 0;k < 1<<h-1;k++){
		    wws[k] = (w<<32)/P;
		    ws[k] = w;
		    w = modh(w*dw, M, H, P);
	    }
	    for(int i = 0;i < h;i++){
		    for(int j = 0;j < 1<<i;j++){
			    for(int k = 0, s = j<<h-i, t = s|1<<h-i-1;k < 1<<h-i-1;k++,s++,t++){
				    long ndsts = dst[s] + dst[t];
				    if(ndsts >= 2*P)ndsts -= 2*P;
				    long T = dst[s] - dst[t] + 2*P;
				    long Q = wws[k]*T>>>32;
				    dst[s] = ndsts;
				    dst[t] = ws[k]*T-Q*P&(1L<<32)-1;
			    }
		    }
//			dw = dw * dw % P;
		    if(i < h-1){
			    for(int k = 0;k < 1<<h-i-2;k++){
				    wws[k] = wws[k*2];
				    ws[k] = ws[k*2];
			    }
		    }
	    }
	    for(int i = 0;i < n;i++){
		    if(dst[i] >= P)dst[i] -= P;
	    }
	    for(int i = 0;i < n;i++){
		    int rev = Integer.reverse(i)>>>-h;
		    if(i < rev){
			    long d = dst[i]; dst[i] = dst[rev]; dst[rev] = d;
		    }
	    }
	    
	    if(inverse){
		    long in = invl(n, P);
		    for(int i = 0;i < n;i++){
			    dst[i] = modh(dst[i] * in, M, H, P);
		    }
	    }
	    
	    return dst;
    }
    
    static final long mask = (1L<<31)-1;
    
    public static long modh(long a, long M, int h, int mod)
    {
	    long r = a-((M*(a&mask)>>>31)+M*(a>>>31)>>>h-31)*mod;
	    return r < mod ? r : r-mod;
    }
    
    private static long[] garnerPrepare(int[] m)
    {
	    int n = m.length;
	    assert n == m.length;
	    if(n == 0)return new long[0];
	    long[] gamma = new long[n];
	    for(int k = 1;k < n;k++){
		    long prod = 1;
		    for(int i = 0;i < k;i++){
			    prod = prod * m[i] % m[k];
		    }
		    gamma[k] = invl(prod, m[k]);
	    }
	    return gamma;
    }
    
    private static long[] garnerBatch(int[] u, int[] m, long[] gamma)
    {
	    int n = u.length;
	    assert n == m.length;
	    long[] v = new long[n];
	    v[0] = u[0];
	    for(int k = 1;k < n;k++){
		    long temp = v[k-1];
		    for(int j = k-2;j >= 0;j--){
			    temp = (temp * m[j] + v[j]) % m[k];
		    }
		    v[k] = (u[k] - temp) * gamma[k] % m[k];
		    if(v[k] < 0)v[k] += m[k];
	    }
	    return v;
    }
    
    private static long pow(long a, long n, long mod) {
	    //		a %= mod;
	    long ret = 1;
	    int x = 63 - Long.numberOfLeadingZeros(n);
	    for (; x >= 0; x--) {
		    ret = ret * ret % mod;
		    if (n << 63 - x < 0)
			    ret = ret * a % mod;
	    }
	    return ret;
    }
    
    private static long invl(long a, long mod) {
	    long b = mod;
	    long p = 1, q = 0;
	    while (b > 0) {
		    long c = a / b;
		    long d;
		    d = a;
		    a = b;
		    b = d % b;
		    d = p;
		    p = q;
		    q = d - c * q;
	    }
	    return p < 0 ? p + mod : p;
    }
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    static boolean multipleTC = false;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        new MAT().run();
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}

    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }

        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }

        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException  e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }

        String nextLine() throws Exception{
            String str = "";
            try{   
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }  
            return str;
        }
    }
}

Feel free to share your approach. Suggestions are welcome as always.


Why is my PyPy3 solution giving WA?
The time complexity is O(10^9), and I was expecting a TLE, but it received a WA verdict.

I simply stored, in a list, the integers whose factorials I'll need while calculating nCr. Then I performed an O(10^9) loop and filled in the required factorials in a map.

The time limit was probably too lenient. My O(10^9) solution passed in 3.6 seconds.

Hi yat_00,
Could you please explain your solution? I also tried solving this problem this way (noticing the lenient Time Limit), but couldn't succeed - it was giving me Wrong Answer.
Thanks!

I maintained sorted arrays of the values A[i] and A[i]-r separately, iterated from 0 to 998244352, and stored a factorial only when it was required (using pointers into the two arrays above).

I precomputed the values of n! till 5*10^8 and found all ^{A_i}C_r in O(1) each, though I had a hunch that it wasn't the expected solution. It didn't even give MLE.

This is a good editorial, but the language in the original problem is terribly imprecise. Probabilities have to be numbers between 0 and 1, but this problem has probabilities as integers between 1 and M-1, where M is the modulus. The author meant to say that the probability for school i is P_i/M.

Furthermore, the expected value will be a fraction, but the output is supposed to be an int. What the author meant was that the expected value can be written as a fraction N/M for some integer N, and you should output N\pmod M. At least, this is the only way I can see to sensibly interpret things, I am happy to be proved wrong.


You need to learn some modular arithmetic

There's no need to be condescending. I know modular arithmetic plenty well, but I still found something about the wording confusing.

Can you help me alleviate some of my confusion; specifically, how can a probability be a number which is greater than one?

Sorry if I sounded harsh, but it is one of the most basic modular arithmetic operations; that's why I got the impression that you're not familiar with modular arithmetic.

Okay, I found the answer in another comment. The probability is a fraction a/b for integers a,b which are coprime to M, and the integer given is a\cdot b^{-1} \pmod M. I understand this now, but to me, this is not clear from the problem statement. It may be obvious to someone who has done lots of coding competitions, but I think problem statements should be worded precisely enough that even a newcomer can understand them. The only reason I expect this standard from CodeChef is that all of their previous problem statements I have seen regularly uphold it.



It's clearly mentioned that probabilities are given as P_i \% M.

Right, but the notation x\pmod m is only standard when x is an integer. I suppose I am in the wrong here for not knowing that x\pmod m means a\cdot b^{-1}\pmod m where x=a/b.


I see. What makes you say that it is standard only for integers? I don't recall learning a standard and a non-standard modulo operation.

What makes me say it is standard for integers? The fact that that is the only way that I learned it in all of the math education I had (up to a master's in combinatorics). Also, the only context in which x\pmod m appears in the Wikipedia article on modular arithmetic is when x is an integer.

If you can produce a single non-programming site which refers to x\pmod m where x is a fraction, then I will concede I am wrong. Otherwise, I stipulate that it is purely a programming contest idiom that needs to be explained when questions are written.

For example, the problem TREDEG doesn't simply say "find the expected value of A \pmod{998,244,353}"; they take care to say "the expected value of A can be expressed as a fraction P/Q, where P and Q are coprime positive integers and Q is coprime to 998,244,353. You should compute the value of P \cdot Q^{-1} modulo 998,244,353 where Q^{-1} denotes the multiplicative inverse of Q modulo 998,244,353."

Sorry, but to me it still seems like a matter of common sense to infer that we have to take AB^{-1}\%M, since P_i is obviously a fraction (from basic math we know a probability is at most 1). What other meaning could you make of an operation asking you to find a fraction mod M?