PROBLEM LINK:
Contest Division 1
Contest Division 2
Contest Division 3
Practice
Setter: Vivek mishra
Tester: Aryan
Editorialist: Taranpreet Singh
DIFFICULTY
Medium
PREREQUISITES
Number Theoretic Transform, Probability, and Expected Value.
PROBLEM
There are K schools numbered from 1 to K. For a competition, school i is invited with probability P_i, independently of the other schools. If invited, the school must send exactly r students chosen from its A_i students.
Find the expected number of ways the teams can be formed. The probabilities are given modulo M = 998244353, and the answer must also be computed modulo M.
QUICK EXPLANATION
- School i is invited with probability P_i, in which case there are \displaystyle\binom{A_i}{r} ways to select its students; if it is not invited, there is exactly 1 way (it sends nobody). So the expected number of ways for school i is \displaystyle P_i* \binom{A_i}{r} + (1-P_i)*1
- Since schools are invited independently, the expected number of ways is \displaystyle \prod_{i = 1}^K \bigg[ P_i * \binom{A_i}{r} + (1-P_i)\bigg]
- To compute \displaystyle\binom{A_i}{r} for each A_i, we define the polynomial \displaystyle Q(x) = \prod_{i = 1}^r (x+i); then \displaystyle\binom{A_i}{r} = \frac{Q(A_i - r)}{r!}
- So we need to evaluate a polynomial of degree r at the K points A_i - r, which is the well-known multipoint evaluation problem.
EXPLANATION
Mathematics Section
Solving for K = 1
There's only one school, which is invited with probability P_1 and has A_1 students. There are two possibilities:
- If the school is invited, there are \displaystyle\binom{A_1}{r} ways to select students, which happens with probability P_1.
- If the school is not invited, no students are sent, which happens in exactly 1 way. This happens with probability 1-P_1.
Hence, the expected number of ways is given by \displaystyle P_1 * \binom{A_1}{r} + (1-P_1)
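As a small worked example (values made up for illustration, not taken from the problem's tests), take A_1 = 4, r = 2 and P_1 = 1/2, which would be given modulo M as 2^{-1} \equiv 499122177:

```latex
E = \tfrac{1}{2}\binom{4}{2} + \left(1 - \tfrac{1}{2}\right)
  = 3 + \tfrac{1}{2} = \tfrac{7}{2}
  \equiv 7 \cdot 499122177 \equiv 499122180 \pmod{998244353}
```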
Solving for K > 1
Let X_i be the (random) number of ways to select students from school i; by the K = 1 case, E[X_i] = P_i*\binom{A_i}{r} + (1-P_i). By the Fundamental Principle of Counting, the total number of ways is \prod_{i=1}^K X_i, so we need to compute \displaystyle E \bigg[\prod_{i = 1}^K X_i \bigg]
Since schools are invited independently of each other, the X_i are independent random variables, and for independent random variables X and Y we have E[X*Y] = E[X]*E[Y], which implies that
S = \displaystyle E \bigg[\prod_{i = 1}^K X_i \bigg] = \prod_{i = 1}^K E[X_i] = \prod_{i = 1}^K \bigg[P_i*\binom{A_i}{r} + (1-P_i)\bigg]
Hence, if we can compute \displaystyle \binom{A_i}{r} \bmod M quickly, we have solved the problem.
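Assuming the values \binom{A_i}{r} \bmod M are already available, the product formula is a single loop. A minimal sketch (the name `expected_ways` is ours, not from the reference solutions):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Modulus from the problem statement.
const uint64_t MOD = 998244353;

// S = prod_i [ P_i * C_i + (1 - P_i) ] mod M, where P_i is the invitation
// probability already given modulo M and C_i = binom(A_i, r) mod M.
uint64_t expected_ways(const std::vector<uint64_t>& P, const std::vector<uint64_t>& C) {
    uint64_t S = 1;
    for (std::size_t i = 0; i < P.size(); ++i) {
        uint64_t invited = P[i] * C[i] % MOD;             // P_i * binom(A_i, r)
        uint64_t not_invited = (1 + MOD - P[i]) % MOD;    // (1 - P_i) * 1
        S = S * ((invited + not_invited) % MOD) % MOD;    // independence: multiply
    }
    return S;
}
```

With every P_i = 1 it reduces to the plain product of the binomials, and with every P_i = 0 it returns 1.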
Computing \displaystyle \binom{N}{r} \bmod M
It is worth reading this blog and this answer, which I believe are among the best material written on the topic.
Subtask 1, r \leq 100
Since r is small, we can write \displaystyle \binom{N}{r} = \frac{\prod_{i = 1}^r (N-r+i)}{r!}. The benefit of writing \displaystyle \binom{N}{r} this way is that both the numerator and the denominator are products of r terms, so we can compute \displaystyle\binom{N}{r} in O(r) time (plus one modular inverse for the denominator), making the overall solution O(K*r), which is sufficient for subtask 1 but not for subtask 2.
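A minimal sketch of the subtask-1 computation, assuming r < M so that r! is invertible (the names `power` and `binom_small_r` are illustrative, not from the reference solutions):

```cpp
#include <cstdint>

const uint64_t MOD = 998244353;

// Fast modular exponentiation; power(x, MOD - 2) is the inverse of x (MOD is prime).
uint64_t power(uint64_t base, uint64_t exp) {
    uint64_t result = 1;
    base %= MOD;
    for (; exp > 0; exp >>= 1, base = base * base % MOD)
        if (exp & 1) result = result * base % MOD;
    return result;
}

// binom(N, r) mod M in O(r): numerator prod_{i=1}^{r} (N - r + i), denominator r!.
// Assumes r < MOD; N itself may exceed MOD (each factor is reduced first).
uint64_t binom_small_r(uint64_t N, uint64_t r) {
    if (r > N) return 0;  // cannot choose more students than exist
    uint64_t num = 1, den = 1;
    for (uint64_t i = 1; i <= r; ++i) {
        num = num * ((N - r + i) % MOD) % MOD;
        den = den * (i % MOD) % MOD;
    }
    return num * power(den, MOD - 2) % MOD;
}
```

Calling this once per school and multiplying the per-school terms gives the O(K*r) bound.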
The final subtask
Now r \leq 10^5, so we cannot use the above solution. One fact we haven't used yet is that we need to compute \displaystyle \binom{A_i}{r} for every A_i with the same fixed r.
If we apply the method from subtask 1, the denominator r! is fixed and easily computed, and the numerator is the product of the r consecutive integers ending at A_i.
Let's define the polynomial \displaystyle Q(x) = \prod_{i = 1}^r (x+i), the product of the r consecutive integers greater than x.
It is easy to see that the numerator in \displaystyle \binom{A_i}{r} = \frac{\prod_{j = 1}^r (A_i-r+j)}{r!} is nothing but Q(A_i - r), and the denominator is r!, so we get \displaystyle \binom{A_i}{r} = \frac{Q(A_i - r)}{r!}
Hence, we have a polynomial Q(x) of degree r, and we need to evaluate it at the points A_i - r for each 1 \leq i \leq K.
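To sanity-check the identity on small inputs, here is a naive O(r^2) sketch that expands Q(x), evaluates it with Horner's rule and divides by r! (all names are hypothetical; this is a testing aid, not the intended fast method):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

const uint64_t MOD = 998244353;

uint64_t power(uint64_t base, uint64_t exp) {
    uint64_t result = 1;
    base %= MOD;
    for (; exp > 0; exp >>= 1, base = base * base % MOD)
        if (exp & 1) result = result * base % MOD;
    return result;
}

// Coefficients of Q(x) = prod_{i=1}^{r} (x + i), lowest degree first,
// multiplying in one factor at a time: O(r^2), for small-r checks only.
std::vector<uint64_t> build_Q(int r) {
    std::vector<uint64_t> q = {1};
    for (int i = 1; i <= r; ++i) {
        std::vector<uint64_t> next(q.size() + 1, 0);
        for (std::size_t j = 0; j < q.size(); ++j) {
            next[j] = (next[j] + q[j] * i) % MOD;      // q(x) * i
            next[j + 1] = (next[j + 1] + q[j]) % MOD;  // q(x) * x
        }
        q = next;
    }
    return q;
}

// Horner's rule: q(x) mod M.
uint64_t eval(const std::vector<uint64_t>& q, uint64_t x) {
    uint64_t y = 0;
    for (std::size_t j = q.size(); j-- > 0;)
        y = (y * x + q[j]) % MOD;
    return y;
}

// binom(A, r) = Q(A - r) / r!  (assumes A >= r).
uint64_t binom_via_Q(const std::vector<uint64_t>& q, uint64_t A, int r) {
    uint64_t fact = 1;
    for (int i = 1; i <= r; ++i) fact = fact * i % MOD;
    return eval(q, (A - r) % MOD) * power(fact, MOD - 2) % MOD;
}
```

For r = 2, Q(x) = x^2 + 3x + 2, and Q(4 - 2) / 2! = 12 / 2 = 6 = \binom{4}{2}.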
Computing Q(x)
It is worth trying the second subtask of the problem LUCASTH, which uses the same type of polynomial.
If we compute Q(x) by multiplying in the factors (x+i) one at a time, the complexity is O(r^2), which is too slow.
But we can use Divide and Conquer here. Initially, there are r polynomials of degree 1. Divide them into pairs and multiply each pair (using NTT). Now we get \lceil r/2 \rceil polynomials of degree up to 2. Repeating this, we get \lceil r/4 \rceil polynomials of degree up to 4, and so on. This process is repeated \lceil \log_2 r \rceil times. We can prove that each layer takes O(r*log(r)) time, either using the master theorem or as follows.
In the ith step, about \lceil r / 2^i \rceil pairs of polynomials of degree up to 2^{i-1} are multiplied. Each multiplication uses an NTT of size O(2^i) and takes O(i*2^i) time, so the whole layer takes \lceil r / 2^i \rceil * O(i*2^i) = O(r*i) \subseteq O(r*log(r)).
Hence, with O(log(r)) layers, we can compute Q(x) in O(r*log^2(r)) time.
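A compact sketch of the divide and conquer, written recursively (splitting the range of factors in half builds the same pairing tree as described above). It assumes a textbook iterative NTT modulo 998244353 with primitive root 3, not the setter's optimized transform, and the function names are ours:

```cpp
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

const uint64_t MOD = 998244353, G = 3;  // G = 3 is a primitive root modulo MOD

uint64_t power(uint64_t base, uint64_t exp) {
    uint64_t result = 1;
    base %= MOD;
    for (; exp > 0; exp >>= 1, base = base * base % MOD)
        if (exp & 1) result = result * base % MOD;
    return result;
}

// Textbook in-place iterative NTT; a.size() must be a power of two dividing 2^23.
void ntt(std::vector<uint64_t>& a, bool invert) {
    std::size_t n = a.size();
    for (std::size_t i = 1, j = 0; i < n; ++i) {  // bit-reversal permutation
        std::size_t bit = n >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) std::swap(a[i], a[j]);
    }
    for (std::size_t len = 2; len <= n; len <<= 1) {
        uint64_t wlen = power(G, (MOD - 1) / len);  // primitive len-th root of unity
        if (invert) wlen = power(wlen, MOD - 2);
        for (std::size_t i = 0; i < n; i += len) {
            uint64_t w = 1;
            for (std::size_t j = 0; j < len / 2; ++j) {
                uint64_t u = a[i + j], v = a[i + j + len / 2] * w % MOD;
                a[i + j] = (u + v) % MOD;
                a[i + j + len / 2] = (u + MOD - v) % MOD;
                w = w * wlen % MOD;
            }
        }
    }
    if (invert) {
        uint64_t n_inv = power(n, MOD - 2);
        for (uint64_t& x : a) x = x * n_inv % MOD;
    }
}

// Product of two polynomials (coefficients lowest degree first) via NTT.
std::vector<uint64_t> multiply(std::vector<uint64_t> a, std::vector<uint64_t> b) {
    std::size_t need = a.size() + b.size() - 1, n = 1;
    while (n < need) n <<= 1;
    a.resize(n);
    b.resize(n);
    ntt(a, false);
    ntt(b, false);
    for (std::size_t i = 0; i < n; ++i) a[i] = a[i] * b[i] % MOD;
    ntt(a, true);
    a.resize(need);
    return a;
}

// (x + lo)(x + lo + 1)...(x + hi): split the factors in half, recurse, multiply.
// Q(x) from the editorial is product_range(1, r), assuming r >= 1.
std::vector<uint64_t> product_range(int lo, int hi) {
    if (lo == hi) return {static_cast<uint64_t>(lo) % MOD, 1};
    int mid = lo + (hi - lo) / 2;
    return multiply(product_range(lo, mid), product_range(mid + 1, hi));
}
```

For example, product_range(1, 4) yields the coefficients of x^4 + 10x^3 + 35x^2 + 50x + 24.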
Alternatively, a commonly used approach is to maintain a heap of polynomials ordered by degree, with the smallest-degree polynomial at the top. Repeatedly multiplying the top two polynomials while there are at least 2 polynomials in the heap yields a similar time complexity.
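The heap variant can be sketched with `std::priority_queue`. To keep the sketch short it uses schoolbook multiplication, which is an assumption for illustration only; an NTT-based multiply has to be substituted to obtain the O(r*log^2(r)) bound (names are hypothetical):

```cpp
#include <cstddef>
#include <cstdint>
#include <queue>
#include <vector>

const uint64_t MOD = 998244353;
using Poly = std::vector<uint64_t>;  // coefficients, lowest degree first

// Schoolbook O(deg^2) product: a stand-in so the sketch stays short.
Poly multiply(const Poly& a, const Poly& b) {
    Poly c(a.size() + b.size() - 1, 0);
    for (std::size_t i = 0; i < a.size(); ++i)
        for (std::size_t j = 0; j < b.size(); ++j)
            c[i + j] = (c[i + j] + a[i] * b[j]) % MOD;
    return c;
}

// Orders the heap so that the polynomial with the smallest degree is on top.
struct LargerDegree {
    bool operator()(const Poly& a, const Poly& b) const { return a.size() > b.size(); }
};

// Q(x) = prod_{i=1}^{r} (x + i): repeatedly multiply the two lowest-degree
// polynomials while at least two remain. Assumes r >= 1.
Poly build_Q_heap(int r) {
    std::priority_queue<Poly, std::vector<Poly>, LargerDegree> heap;
    for (int i = 1; i <= r; ++i)
        heap.push(Poly{static_cast<uint64_t>(i) % MOD, 1});
    while (heap.size() >= 2) {
        Poly a = heap.top();
        heap.pop();
        Poly b = heap.top();
        heap.pop();
        heap.push(multiply(a, b));
    }
    return heap.top();
}
```

Always combining the two smallest polynomials keeps the operands balanced, which is what makes the heap order behave like the pairwise divide and conquer.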
Computing Q(x) at K points.
This is the well-known problem of Multipoint Evaluation of a polynomial, explained here; this answer does a good job of summarizing the process concisely.
Let's say B_i = A_i-r for each i, so we need to evaluate Q(x) at each value in B.
The core idea is to build a segment-tree-like structure, where the node corresponding to the range [L, R] stores the product P_{L, R}(x) = \displaystyle\prod_{i = L}^R (x-B_i).
Starting with Q(x) at the root node, we descend the tree, passing down Q(x) \bmod P_{L, R}(x) when moving to the node with range [L, R]. This is valid because P_{L, R} divides the product stored at its parent, so reducing the parent's remainder modulo P_{L, R} gives exactly Q(x) \bmod P_{L, R}(x). At the leaf for index i, we are left with Q(x) \bmod (x-B_i) = Q(B_i). Polynomial operations like division and modulo are required, which are also described here.
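A sketch of the remainder tree, assuming K \geq 1 and using naive long division and schoolbook multiplication at every node (our simplification, so it runs in quadratic rather than quasi-linear time); the fast version replaces both with NTT-based multiplication and division. All names are hypothetical:

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

const uint64_t MOD = 998244353;
using Poly = std::vector<uint64_t>;  // coefficients, lowest degree first

// Schoolbook product (stand-in for an NTT-based multiply).
Poly multiply(const Poly& a, const Poly& b) {
    Poly c(a.size() + b.size() - 1, 0);
    for (std::size_t i = 0; i < a.size(); ++i)
        for (std::size_t j = 0; j < b.size(); ++j)
            c[i + j] = (c[i + j] + a[i] * b[j]) % MOD;
    return c;
}

// Remainder of a modulo a monic polynomial m, by naive long division.
Poly poly_mod(Poly a, const Poly& m) {
    const std::size_t dm = m.size() - 1;         // degree of m
    for (std::size_t i = a.size(); i-- > dm;) {  // cancel the x^i term
        const uint64_t c = a[i];
        if (c == 0) continue;
        for (std::size_t j = 0; j <= dm; ++j)
            a[i - dm + j] = (a[i - dm + j] + (MOD - c) * m[j]) % MOD;
    }
    a.resize(std::min(a.size(), dm));
    return a;
}

// tree[node] = P_{L,R}(x) = prod_{i=L}^{R} (x - B_i).
void build(std::vector<Poly>& tree, const std::vector<uint64_t>& B, int node, int L, int R) {
    if (L == R) {
        tree[node] = Poly{(MOD - B[L] % MOD) % MOD, 1};
        return;
    }
    int mid = (L + R) / 2;
    build(tree, B, 2 * node, L, mid);
    build(tree, B, 2 * node + 1, mid + 1, R);
    tree[node] = multiply(tree[2 * node], tree[2 * node + 1]);
}

// Pass Q mod P_{L,R} down; at leaf i, Q mod (x - B_i) is the constant Q(B_i).
void descend(const std::vector<Poly>& tree, const Poly& q, int node, int L, int R,
             std::vector<uint64_t>& out) {
    Poly rem = poly_mod(q, tree[node]);
    if (L == R) {
        out[L] = rem.empty() ? 0 : rem[0];
        return;
    }
    int mid = (L + R) / 2;
    descend(tree, rem, 2 * node, L, mid, out);
    descend(tree, rem, 2 * node + 1, mid + 1, R, out);
}

// Evaluates q at every B_i; assumes B is non-empty.
std::vector<uint64_t> multipoint_eval(const Poly& q, const std::vector<uint64_t>& B) {
    const int K = static_cast<int>(B.size());
    std::vector<Poly> tree(4 * K);
    std::vector<uint64_t> out(K);
    build(tree, B, 1, 0, K - 1);
    descend(tree, q, 1, 0, K - 1, out);
    return out;
}
```

For example, with r = 2, Q(x) = x^2 + 3x + 2 and A = (4, 10), the points are B = (2, 8); the sketch returns (12, 90), and dividing by 2! gives \binom{4}{2} = 6 and \binom{10}{2} = 45.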
Additionally, you can test your implementation here, or refer to the fastest implementations as well.
Feel free to refer to the implementations below, with a warning from the tester: "code may be too long, you have been warned."
TIME COMPLEXITY
The time complexity is O(r*log^2(r) + K*log^2(K)).
SOLUTIONS
Setter's Solution
#include <cstdio>
#include <cctype>
#include <algorithm>
#include <tuple>
#include <cstring>
#include <cmath>
#include <vector>
#include<iostream>
#include <cstdint>
#include <assert.h>
#include <bits/stdc++.h>
typedef unsigned int uint;
typedef long long unsigned int uint64;
constexpr uint Max_size = 1 << 22 | 32;
constexpr uint g = 3, Mod = 998244353;
static constexpr uint map[4] = {0, Mod, Mod * 2, Mod * 3};
inline uint norm_2(const uint x)
{
return x - map[x >> 30];
}
inline uint norm(const uint x)
{
return x < Mod ? x : x - Mod;
}
inline uint mult_2(const uint x1, const uint x2)
{
uint64 x = static_cast<uint64>(x1) * x2;
constexpr uint64 base = (-1ULL) / Mod;
return static_cast<uint>(x) - Mod * static_cast<uint>((static_cast<__uint128_t>(x) * base) >> 64);
}
struct Z
{
uint v;
Z() { }
Z(const uint _v) : v(_v) { }
};
inline Z operator+(const Z x1, const Z x2) { return norm(x1.v + x2.v); }
inline Z operator-(const Z x1, const Z x2) { return norm(x1.v + Mod - x2.v); }
inline Z operator-(const Z x) { return x.v ? Mod - x.v : 0; }
inline Z operator*(const Z x1, const Z x2) { return static_cast<uint64>(x1.v) * x2.v % Mod; }
inline Z &operator+=(Z &x1, const Z x2) { return x1 = x1 + x2; }
inline Z &operator-=(Z &x1, const Z x2) { return x1 = x1 - x2; }
inline Z &operator*=(Z &x1, const Z x2) { return x1 = x1 * x2; }
inline Z mult_2(const Z x1, const Z x2)
{
return mult_2(x1.v, x2.v);
}
inline Z Power(Z Base, int Exp)
{
Z res = 1;
for (; Exp; Base *= Base, Exp >>= 1)
if (Exp & 1)
res *= Base;
return res;
}
inline Z Rec(const Z x)
{
return Power(x, Mod - 2);
}
Z _Rec[Max_size];
void init_Rec(const int n)
{
_Rec[1] = 1;
for (int i = 2; i != n; ++i)
_Rec[i] = _Rec[i - 1] * (i - 1);
Z R = Rec(_Rec[n - 1] * (n - 1));
for (int i = n - 1; i != 1; --i)
_Rec[i] *= R, R *= i;
}
struct Shoup
{
uint v, q;
Shoup() = default;
explicit Shoup(const uint _v) : v(_v), q((static_cast<uint64>(_v) << 32) / Mod) { }
};
int size;
Shoup w[Max_size], rw[Max_size];
inline uint mult_Shoup_2(const uint x, const Shoup y)
{
uint q = static_cast<uint64>(x) * y.q >> 32;
return x * y.v - q * Mod;
}
inline uint mult_Shoup(const uint x, const Shoup y)
{
return norm(mult_Shoup_2(x, y));
}
inline uint mult_Shoup_q(const uint x, const Shoup y)
{
uint q = static_cast<uint64>(x) * y.q >> 32;
return q + (x * y.v - q * Mod >= Mod);
}
void init_w(const int n)
{
for (size = 2; size < n; size <<= 1)
;
Shoup pr = Shoup(Power(g, (Mod - 1) / size).v);
Shoup rpr = Shoup(Rec(pr.v).v);
size >>= 1;
w[size] = Shoup(1);
rw[size] = Shoup(1);
for (int i = 1; i < size; ++i)
{
w[size + i] = Shoup(mult_Shoup(w[size + i - 1].v, pr));
rw[size + i] = Shoup(mult_Shoup(rw[size + i - 1].v, rpr));
}
for (int i = size - 1; i; --i)
{
w[i] = w[i * 2];
rw[i] = rw[i * 2];
}
size <<= 1;
}
void DFT_fr_2(Z _A[], const int L)
{
if (L == 1)
return;
uint *A = reinterpret_cast<uint *>(_A);
#define butterfly1(a, b)\
do\
{\
uint _a = a, _b = b;\
uint x = norm_2(_a + _b), y = norm_2(_a + Mod * 2 - _b);\
a = x, b = y;\
} while (0)
if (L == 2)
{
butterfly1(A[0], A[1]);
return;
}
#define butterfly(a, b, _w)\
do\
{\
uint _a = a, _b = b;\
uint x = norm_2(_a + _b), y = mult_Shoup_2(_a + Mod * 2 - _b, _w);\
a = x, b = y;\
} while (0)
if (L == 4)
{
butterfly1(A[0], A[2]);
butterfly(A[1], A[3], w[3]);
butterfly1(A[0], A[1]);
butterfly1(A[2], A[3]);
return;
}
for (int d = L >> 1; d != 4; d >>= 1)
for (int i = 0; i != L; i += d << 1)
for (int j = 0; j != d; j += 4)
{
butterfly(A[i + j + 0], A[i + d + j + 0], w[d + j + 0]);
butterfly(A[i + j + 1], A[i + d + j + 1], w[d + j + 1]);
butterfly(A[i + j + 2], A[i + d + j + 2], w[d + j + 2]);
butterfly(A[i + j + 3], A[i + d + j + 3], w[d + j + 3]);
}
for (int i = 0; i != L; i += 8)
{
butterfly1(A[i + 0], A[i + 4]);
butterfly(A[i + 1], A[i + 5], w[5]);
butterfly(A[i + 2], A[i + 6], w[6]);
butterfly(A[i + 3], A[i + 7], w[7]);
butterfly1(A[i + 0], A[i + 2]);
butterfly(A[i + 1], A[i + 3], w[3]);
butterfly1(A[i + 4], A[i + 6]);
butterfly(A[i + 5], A[i + 7], w[3]);
butterfly1(A[i + 0], A[i + 1]);
butterfly1(A[i + 2], A[i + 3]);
butterfly1(A[i + 4], A[i + 5]);
butterfly1(A[i + 6], A[i + 7]);
}
#undef butterfly1
#undef butterfly
}
void DFT_fr(Z _A[], const int L)
{
DFT_fr_2(_A, L);
for (int i = 0; i != L; ++i)
_A[i] = norm(_A[i].v);
}
void IDFT_fr(Z _A[], const int L)
{
if (L == 1)
return;
uint *A = reinterpret_cast<uint *>(_A);
#define butterfly1(a, b)\
do\
{\
uint _a = a, _b = b;\
uint x = norm_2(_a), t = norm_2(_b);\
a = x + t, b = x + Mod * 2 - t;\
} while (0)
if (L == 2)
{
butterfly1(A[0], A[1]);
A[0] = norm(norm_2(A[0])), A[0] = A[0] & 1 ? A[0] + Mod : A[0], A[0] /= 2;
A[1] = norm(norm_2(A[1])), A[1] = A[1] & 1 ? A[1] + Mod : A[1], A[1] /= 2;
return;
}
#define butterfly(a, b, _w)\
do\
{\
uint _a = a, _b = b;\
uint x = norm_2(_a), t = mult_Shoup_2(_b, _w);\
a = x + t, b = x + Mod * 2 - t;\
} while (0)
if (L == 4)
{
butterfly1(A[0], A[1]);
butterfly1(A[2], A[3]);
butterfly1(A[0], A[2]);
butterfly(A[1], A[3], rw[3]);
for (int i = 0; i != L; ++i)
{
uint64 m = -A[i] & 3;
A[i] = norm((A[i] + m * Mod) >> 2);
}
return;
}
for (int i = 0; i != L; i += 8)
{
butterfly1(A[i + 0], A[i + 1]);
butterfly1(A[i + 2], A[i + 3]);
butterfly1(A[i + 4], A[i + 5]);
butterfly1(A[i + 6], A[i + 7]);
butterfly1(A[i + 0], A[i + 2]);
butterfly(A[i + 1], A[i + 3], rw[3]);
butterfly1(A[i + 4], A[i + 6]);
butterfly(A[i + 5], A[i + 7], rw[3]);
butterfly1(A[i + 0], A[i + 4]);
butterfly(A[i + 1], A[i + 5], rw[5]);
butterfly(A[i + 2], A[i + 6], rw[6]);
butterfly(A[i + 3], A[i + 7], rw[7]);
}
for (int d = 8; d != L; d <<= 1)
for (int i = 0; i != L; i += d << 1)
for (int j = 0; j != d; j += 4)
{
butterfly(A[i + j + 0], A[i + d + j + 0], rw[d + j + 0]);
butterfly(A[i + j + 1], A[i + d + j + 1], rw[d + j + 1]);
butterfly(A[i + j + 2], A[i + d + j + 2], rw[d + j + 2]);
butterfly(A[i + j + 3], A[i + d + j + 3], rw[d + j + 3]);
}
#undef butterfly1
#undef butterfly
int k = __builtin_ctz(L);
for (int i = 0; i != L; ++i)
{
uint64 m = -A[i] & (L - 1);
A[i] = norm((A[i] + m * Mod) >> k);
}
}
void IDFT_fr_core_2(Z _A[], const int L)
{
if (L == 1)
return;
uint *A = reinterpret_cast<uint *>(_A);
#define butterfly1(a, b)\
do\
{\
uint _a = a, _b = b;\
uint x = norm_2(_a), t = norm_2(_b);\
a = x + t, b = x + Mod * 2 - t;\
} while (0)
if (L == 2)
{
butterfly1(A[0], A[1]);
A[0] = norm_2(A[0]);
A[1] = norm_2(A[1]);
return;
}
#define butterfly(a, b, _w)\
do\
{\
uint _a = a, _b = b;\
uint x = norm_2(_a), t = mult_Shoup_2(_b, _w);\
a = x + t, b = x + Mod * 2 - t;\
} while (0)
if (L == 4)
{
butterfly1(A[0], A[1]);
butterfly1(A[2], A[3]);
butterfly1(A[0], A[2]);
butterfly(A[1], A[3], rw[3]);
A[0] = norm_2(A[0]);
A[1] = norm_2(A[1]);
A[2] = norm_2(A[2]);
A[3] = norm_2(A[3]);
return;
}
for (int i = 0; i != L; i += 8)
{
butterfly1(A[i + 0], A[i + 1]);
butterfly1(A[i + 2], A[i + 3]);
butterfly1(A[i + 4], A[i + 5]);
butterfly1(A[i + 6], A[i + 7]);
butterfly1(A[i + 0], A[i + 2]);
butterfly(A[i + 1], A[i + 3], rw[3]);
butterfly1(A[i + 4], A[i + 6]);
butterfly(A[i + 5], A[i + 7], rw[3]);
butterfly1(A[i + 0], A[i + 4]);
butterfly(A[i + 1], A[i + 5], rw[5]);
butterfly(A[i + 2], A[i + 6], rw[6]);
butterfly(A[i + 3], A[i + 7], rw[7]);
}
for (int d = 8; d != L; d <<= 1)
for (int i = 0; i != L; i += d << 1)
for (int j = 0; j != d; j += 4)
{
butterfly(A[i + j + 0], A[i + d + j + 0], rw[d + j + 0]);
butterfly(A[i + j + 1], A[i + d + j + 1], rw[d + j + 1]);
butterfly(A[i + j + 2], A[i + d + j + 2], rw[d + j + 2]);
butterfly(A[i + j + 3], A[i + d + j + 3], rw[d + j + 3]);
}
#undef butterfly1
#undef butterfly
for (int i = 0; i != L; ++i)
A[i] = norm_2(A[i]);
}
inline void copy_n_a(const Z A[], const int n, Z B[])
{
std::memmove(B, A, n * sizeof(Z));
}
inline void reset(Z A[], const int n)
{
std::memset(reinterpret_cast<uint *>(A), 0, n * sizeof(Z));
}
const int Brute_size_l = 4;
const int Brute_size = 1 << Brute_size_l;
inline void mult_2(Z A[], const Z B[], const Z C[], const int N)
{
int s = N & ~3;
for (int i = 0; i != s; i += 4)
{
A[i + 0] = mult_2(B[i + 0], C[i + 0]);
A[i + 1] = mult_2(B[i + 1], C[i + 1]);
A[i + 2] = mult_2(B[i + 2], C[i + 2]);
A[i + 3] = mult_2(B[i + 3], C[i + 3]);
}
for (int i = s; i != N; ++i)
A[i] = mult_2(B[i], C[i]);
}
inline void mult(Z A[], const Z B[], const Z C[], const int N)
{
int s = N & ~3;
for (int i = 0; i != s; i += 4)
{
A[i + 0] = B[i + 0] * C[i + 0];
A[i + 1] = B[i + 1] * C[i + 1];
A[i + 2] = B[i + 2] * C[i + 2];
A[i + 3] = B[i + 3] * C[i + 3];
}
for (int i = s; i != N; ++i)
A[i] = B[i] * C[i];
}
const uint ME_M_size = Max_size * 6;
Z ME_M[ME_M_size], *ME_M_pos = ME_M + ME_M_size;
int ME_Pre(const Z x[], const int N, Z _x[], int id[], Z tmp[], const int LM)
{
Z *p = ME_M_pos -= (N + 3) & ~3;
int cnt = 0, rcnt = N, l = __builtin_ctz(LM);
copy_n_a(x, N, p);
if (l)
{
for (int j = 0; j != l - 1; ++j)
mult_2(p, p, p, N);
mult(p, p, p, N);
}
for (int i = 0; i != N; ++i)
{
if (p[i].v != 1)
_x[cnt] = x[i], id[cnt] = i, tmp[cnt++] = p[i];
else
{
id[--rcnt] = i;
Z _p[20];
int j = 0;
for (_p[0] = x[i]; _p[j].v != 1; ++j)
_p[j + 1] = _p[j] * _p[j];
int k = 0;
while (j)
{
k >>= 1;
if (_p[--j].v != w[(LM >> 1) + k].v)
k += LM >> 1;
}
tmp[rcnt].v = k;
}
}
ME_M_pos += (N + 3) & ~3;
return cnt;
}
inline void ST_DFT_fr_2(const Z A[], const int N, Z dfta[], const int L)
{
dfta[0] = 1, copy_n_a(A + 1, N - 1, dfta + 1), reset(dfta + N, L - N);
DFT_fr_2(dfta, L);
}
inline void ST_DFT_fr_2_extra(const Z A[], Z dfta[], const int L)
{
dfta[0] = 1, copy_n_a(A + 1, L - 1, dfta + 1);
dfta[0] += A[L];
DFT_fr_2(dfta, L);
}
inline void ST_DFT_fr_2_tbd(const Z A[], const int N, Z dfta[], const int L)
{
if (N != L + 1)
ST_DFT_fr_2(A, N, dfta, L);
else
ST_DFT_fr_2_extra(A, dfta, L);
}
inline void ST_extend_DFT_fr_2(const Z A[], const int N, Z dfta[], const int L)
{
dfta[L] = 1;
for (int j = 1; j != N; ++j)
dfta[L + j] = mult_Shoup_2(A[j].v, w[L + j]);
reset(dfta + L + N, L - N);
DFT_fr_2(dfta + L, L);
}
inline void ST_extend_DFT_fr_2_extra(const Z A[], Z dfta[], const int L)
{
dfta[L] = 1;
for (int j = 1; j != L; ++j)
dfta[L + j] = mult_Shoup_2(A[j].v, w[L + j]);
dfta[L] -= A[L];
DFT_fr_2(dfta + L, L);
}
inline void ST_extend_DFT_fr_2_tbd(const Z A[], const int N, Z dfta[], const int L)
{
if (N != L + 1)
ST_extend_DFT_fr_2(A, N, dfta, L);
else
ST_extend_DFT_fr_2_extra(A, dfta, L);
}
inline void ST_extend_2_DFT_fr_2(const Z A[], const int N, Z dfta[], const int L)
{
dfta[L] = 1;
dfta[L + L] = 1;
dfta[L + L + L] = 1;
for (int j = 1; j != N; ++j)
{
dfta[L + j] = mult_Shoup_2(A[j].v, w[L + j]);
dfta[L + L + j] = mult_Shoup_2(A[j].v, w[L + L + j]);
dfta[L + L + L + j] = mult_Shoup_2(dfta[L + L + j].v, w[L + j]);
}
reset(dfta + L + N, L - N);
reset(dfta + L + L + N, L - N);
reset(dfta + L + L + L + N, L - N);
DFT_fr_2(dfta + L, L);
DFT_fr_2(dfta + L + L, L);
DFT_fr_2(dfta + L + L + L, L);
}
inline void ST_extend_2_DFT_fr_2_extra(const Z A[], Z dfta[], const int L)
{
for (int j = 1; j != L; ++j)
{
dfta[L + j] = mult_Shoup_2(A[j].v, w[L + j]);
dfta[L + L + j] = mult_Shoup_2(A[j].v, w[L + L + j]);
dfta[L + L + L + j] = mult_Shoup_2(dfta[L + L + j].v, w[L + j]);
}
int tmp = mult_Shoup(A[L].v, w[L + L + L]);
dfta[L] = 1 + Mod - A[L].v;
DFT_fr_2(dfta + L, L);
dfta[L + L] = 1 + tmp;
DFT_fr_2(dfta + L + L, L);
dfta[L + L + L] = 1 + Mod - tmp;
DFT_fr_2(dfta + L + L + L, L);
}
inline void ST_extend_2_DFT_fr_2_tbd(const Z A[], const int N, Z dfta[], const int L)
{
if (N != L + 1)
ST_extend_2_DFT_fr_2(A, N, dfta, L);
else
ST_extend_2_DFT_fr_2_extra(A, dfta, L);
}
template<const size_t _size0, const size_t _size1>
void Subproduct_Tree_Rev_(const Z x[], const int N, Z T0[], Z T[], Z dftt[][_size0], Z dfttx[][_size1], const int L, const int LM)
{
if (std::min(L, LM) <= Brute_size)
return;
int _l = __builtin_ctz(std::min(L, LM)) & 1;
while (_l + 2 <= Brute_size_l)
_l += 2;
for (int i = 0; i < N; i += 1 << _l)
{
auto _T = T + i;
{
const Z v0 = Mod - x[i].v, v1 = Mod - x[i + 1].v, v2 = Mod - x[i + 2].v, v3 = Mod - x[i + 3].v;
const uint64 v01 = mult_2(v0.v, v1.v), v23 = mult_2(v2.v, v3.v);
_T[4] = v01 * v23 % Mod;
_T[3] = (v01 * (v2.v + v3.v) + v23 * (v0.v + v1.v)) % Mod;
_T[2] = (v01 + v23 + static_cast<uint64>(v0.v + v1.v) * (v2.v + v3.v)) % Mod;
_T[1] = (v0.v + v1.v + v2.v + v3.v) % Mod;
}
for (int j = 4; j != (1 << _l) && j < N - i; j += 4)
{
const Z v0 = Mod - x[i + j].v, v1 = Mod - x[i + j + 1].v, v2 = Mod - x[i + j + 2].v, v3 = Mod - x[i + j + 3].v;
const uint64 v01 = mult_2(v0.v, v1.v), v23 = mult_2(v2.v, v3.v);
const uint64 a4 = v01 * v23 % Mod;
const uint64 a3 = (v01 * (v2.v + v3.v) + v23 * (v0.v + v1.v)) % Mod;
const uint64 a2 = (v01 + v23 + static_cast<uint64>(v0.v + v1.v) * (v2.v + v3.v)) % Mod;
const uint64 a1 = (v0.v + v1.v + v2.v + v3.v) % Mod;
_T[j + 4] = (_T[j].v * a4) % Mod;
_T[j + 3] = (_T[j].v * a3 + _T[j - 1].v * a4) % Mod;
_T[j + 2] = (_T[j].v * a2 + _T[j - 1].v * a3 + _T[j - 2].v * a4) % Mod;
_T[j + 1] = (_T[j].v * a1 + _T[j - 1].v * a2 + _T[j - 2].v * a3 + _T[j - 3].v * a4) % Mod;
for (int k = j; k > 4; --k)
_T[k] = (_T[k].v + _T[k - 1].v * a1 + _T[k - 2].v * a2 + _T[k - 3].v * a3 + _T[k - 4].v * a4) % Mod;
_T[4] = (_T[4].v + _T[3].v * a1 + _T[2].v * a2 + _T[1].v * a3 + a4) % Mod;
_T[3] = (_T[3].v + _T[2].v * a1 + _T[1].v * a2 + a3) % Mod;
_T[2] = (_T[2].v + _T[1].v * a1 + a2) % Mod;
_T[1] = a1 + _T[1];
}
}
copy_n_a(T, N + 1, T0);
for (int l = 0, d; d = 1 << (_l + l * 2), d != std::min(L, LM); ++l)
{
const int s = N & ~(d * 4 - 1);
Z p = 1;
for (int i = 0; i != s; i += d * 4)
{
auto _T0 = T + i, _T1 = _T0 + d, _T2 = _T1 + d, _T3 = _T2 + d;
auto _dftt0 = dftt[l] + i * 4, _dftt1 = _dftt0 + d * 4, _dftt2 = _dftt1 + d * 4, _dftt3 = _dftt2 + d * 4;
auto _dftta = dfttx[l] + i * 2, _dfttb = _dftta + d * 4;
if (d <= Brute_size)
{
ST_DFT_fr_2_extra(_T0, _dftt0, d);
ST_DFT_fr_2_extra(_T1, _dftt1, d);
ST_DFT_fr_2_extra(_T2, _dftt2, d);
ST_DFT_fr_2_extra(_T3, _dftt3, d);
}
ST_extend_2_DFT_fr_2_extra(_T0, _dftt0, d);
ST_extend_2_DFT_fr_2_extra(_T1, _dftt1, d);
ST_extend_2_DFT_fr_2_extra(_T2, _dftt2, d);
ST_extend_2_DFT_fr_2_extra(_T3, _dftt3, d);
mult_2(_dftta, _dftt0, _dftt1, d * 4);
mult_2(_dfttb, _dftt2, _dftt3, d * 4);
mult_2(_T0, _dftta, _dfttb, d * 4);
copy_n_a(_T0, d * 4, dftt[l + 1] + i * 4);
IDFT_fr(_T0, d * 4);
std::swap(_T0[0] -= 1, p);
}
if (N != s)
{
auto _T0 = T + s, _T1 = _T0 + d, _T2 = _T1 + d, _T3 = _T2 + d;
auto _dftt0 = dftt[l] + s * 4, _dftt1 = _dftt0 + d * 4, _dftt2 = _dftt1 + d * 4, _dftt3 = _dftt2 + d * 4;
auto _dftta = dfttx[l] + s * 2, _dfttb = _dftta + d * 4;
if (N - s <= d)
{
const int t = N - s;
if (d <= Brute_size)
ST_DFT_fr_2_tbd(_T0, t + 1, _dftt0, d);
ST_extend_2_DFT_fr_2_tbd(_T0, t + 1, _dftt0, d);
copy_n_a(_dftt0, d * 4, dftt[l + 1] + s * 4);
}
else
{
if (N - s <= d * 2)
{
const int t = N - s - d;
if (d <= Brute_size)
{
ST_DFT_fr_2_extra(_T0, _dftt0, d);
ST_DFT_fr_2_tbd(_T1, t + 1, _dftt1, d);
}
ST_extend_DFT_fr_2_extra(_T0, _dftt0, d);
ST_extend_DFT_fr_2_tbd(_T1, t + 1, _dftt1, d);
mult_2(_T0, _dftt0, _dftt1, d * 2);
copy_n_a(_T0, d * 2, dftt[l + 1] + s * 4);
IDFT_fr(_T0, d * 2);
if (t == d)
_T0[d * 2] = _T0[0] - 1, _T0[0] = 1;
ST_extend_DFT_fr_2_tbd(_T0, d + t + 1, dftt[l + 1] + s * 4, d * 2);
}
else
{
if (d <= Brute_size)
{
ST_DFT_fr_2_extra(_T0, _dftt0, d);
ST_DFT_fr_2_extra(_T1, _dftt1, d);
}
ST_extend_2_DFT_fr_2_extra(_T0, _dftt0, d);
ST_extend_2_DFT_fr_2_extra(_T1, _dftt1, d);
mult_2(_dftta, _dftt0, _dftt1, d * 4);
if (N - s <= d * 3) // can be faster
{
const int t = N - s - d * 2;
if (d <= Brute_size)
ST_DFT_fr_2_tbd(_T2, t + 1, _dftt2, d);
ST_extend_2_DFT_fr_2_tbd(_T2, t + 1, _dftt2, d);
copy_n_a(_dftt2, d * 4, _dfttb);
}
else
{
const int t = N - s - d * 3;
if (d <= Brute_size)
{
ST_DFT_fr_2_extra(_T2, _dftt2, d);
ST_DFT_fr_2(_T3, t + 1, _dftt3, d);
}
ST_extend_2_DFT_fr_2_extra(_T2, _dftt2, d);
ST_extend_2_DFT_fr_2(_T3, t + 1, _dftt3, d);
mult_2(_dfttb, _dftt2, _dftt3, d * 4);
}
mult_2(_T0, _dftta, _dfttb, d * 4);
copy_n_a(_T0, d * 4, dftt[l + 1] + s * 4);
IDFT_fr(_T0, d * 4);
}
}
}
T[s] = p;
}
}
void IDFT_fr_core_b4_2(Z _A[], const int L)
{
uint *A = reinterpret_cast<uint *>(_A);
#define butterfly1(a, b)\
do\
{\
uint _a = a, _b = b;\
uint x = norm_2(_a), t = norm_2(_b);\
a = x + t, b = x + Mod * 2 - t;\
} while (0)
#define semi_butterfly1(a, b)\
do\
{\
uint _a = a, _b = b;\
uint x = norm_2(_a), t = norm_2(_b);\
b = x + Mod * 2 - t;\
} while (0)
// auto butterfly = [](uint &a, uint &b, const Shoup _w)
// {
// uint x = norm_2(a), t = mult_Shoup_2(b, _w);
// a = x + t, b = x + Mod * 2 - t;
// };
#define butterfly(a, b, _w)\
do\
{\
uint _a = a, _b = b;\
uint x = norm_2(_a), t = mult_Shoup_2(_b, _w);\
a = x + t, b = x + Mod * 2 - t;\
} while (0)
#define semi_butterfly(a, b, _w)\
do\
{\
uint _a = a, _b = b;\
uint x = norm_2(_a), t = mult_Shoup_2(_b, _w);\
b = x + Mod * 2 - t;\
} while (0)
if (L == 4)
{
semi_butterfly1(A[0], A[1]);
semi_butterfly1(A[2], A[3]);
semi_butterfly(A[1], A[3], rw[3]);
A[3] = norm_2(A[3]);
return;
}
if (L == 8)
{
butterfly1(A[0], A[1]);
butterfly1(A[2], A[3]);
butterfly1(A[4], A[5]);
butterfly1(A[6], A[7]);
semi_butterfly1(A[0], A[2]);
semi_butterfly(A[1], A[3], rw[3]);
semi_butterfly1(A[4], A[6]);
semi_butterfly(A[5], A[7], rw[3]);
semi_butterfly(A[2], A[6], rw[6]);
semi_butterfly(A[3], A[7], rw[7]);
A[6] = norm_2(A[6]);
A[7] = norm_2(A[7]);
return;
}
if (L == 16)
{
butterfly1(A[0], A[1]);
butterfly1(A[2], A[3]);
butterfly1(A[4], A[5]);
butterfly1(A[6], A[7]);
butterfly1(A[8], A[9]);
butterfly1(A[10], A[11]);
butterfly1(A[12], A[13]);
butterfly1(A[14], A[15]);
butterfly1(A[0], A[2]);
butterfly(A[1], A[3], rw[3]);
butterfly1(A[4], A[6]);
butterfly(A[5], A[7], rw[3]);
butterfly1(A[8], A[10]);
butterfly(A[9], A[11], rw[3]);
butterfly1(A[12], A[14]);
butterfly(A[13], A[15], rw[3]);
semi_butterfly1(A[0], A[4]);
semi_butterfly(A[1], A[5], rw[5]);
semi_butterfly(A[2], A[6], rw[6]);
semi_butterfly(A[3], A[7], rw[7]);
semi_butterfly1(A[8], A[12]);
semi_butterfly(A[9], A[13], rw[5]);
semi_butterfly(A[10], A[14], rw[6]);
semi_butterfly(A[11], A[15], rw[7]);
semi_butterfly(A[4], A[12], rw[12]);
semi_butterfly(A[5], A[13], rw[13]);
semi_butterfly(A[6], A[14], rw[14]);
semi_butterfly(A[7], A[15], rw[15]);
A[12] = norm_2(A[12]);
A[13] = norm_2(A[13]);
A[14] = norm_2(A[14]);
A[15] = norm_2(A[15]);
return;
}
for (int i = 0; i != L; i += 8)
{
butterfly1(A[i + 0], A[i + 1]);
butterfly1(A[i + 2], A[i + 3]);
butterfly1(A[i + 4], A[i + 5]);
butterfly1(A[i + 6], A[i + 7]);
butterfly1(A[i + 0], A[i + 2]);
butterfly(A[i + 1], A[i + 3], rw[3]);
butterfly1(A[i + 4], A[i + 6]);
butterfly(A[i + 5], A[i + 7], rw[3]);
butterfly1(A[i + 0], A[i + 4]);
butterfly(A[i + 1], A[i + 5], rw[5]);
butterfly(A[i + 2], A[i + 6], rw[6]);
butterfly(A[i + 3], A[i + 7], rw[7]);
}
for (int d = 8; d != L >> 2; d <<= 1)
for (int i = 0; i != L; i += d << 1)
for (int j = 0; j != d; j += 4)
{
butterfly(A[i + j + 0], A[i + d + j + 0], rw[d + j + 0]);
butterfly(A[i + j + 1], A[i + d + j + 1], rw[d + j + 1]);
butterfly(A[i + j + 2], A[i + d + j + 2], rw[d + j + 2]);
butterfly(A[i + j + 3], A[i + d + j + 3], rw[d + j + 3]);
}
int d = L >> 2;
for (int j = 0; j != d; j += 4)
{
semi_butterfly(A[j + 0], A[d + j + 0], rw[d + j + 0]);
semi_butterfly(A[j + 1], A[d + j + 1], rw[d + j + 1]);
semi_butterfly(A[j + 2], A[d + j + 2], rw[d + j + 2]);
semi_butterfly(A[j + 3], A[d + j + 3], rw[d + j + 3]);
}
for (int j = 0; j != d; j += 4)
{
semi_butterfly(A[(L >> 1) + j + 0], A[(L >> 1) + d + j + 0], rw[d + j + 0]);
semi_butterfly(A[(L >> 1) + j + 1], A[(L >> 1) + d + j + 1], rw[d + j + 1]);
semi_butterfly(A[(L >> 1) + j + 2], A[(L >> 1) + d + j + 2], rw[d + j + 2]);
semi_butterfly(A[(L >> 1) + j + 3], A[(L >> 1) + d + j + 3], rw[d + j + 3]);
}
d = L >> 1;
for (int j = L >> 2; j != d; j += 4)
{
semi_butterfly(A[j + 0], A[d + j + 0], rw[d + j + 0]);
semi_butterfly(A[j + 1], A[d + j + 1], rw[d + j + 1]);
semi_butterfly(A[j + 2], A[d + j + 2], rw[d + j + 2]);
semi_butterfly(A[j + 3], A[d + j + 3], rw[d + j + 3]);
}
#undef butterfly1
#undef semi_butterfly1
#undef butterfly
#undef semi_butterfly
for (int i = (L >> 4) * 3; i != L; ++i)
A[i] = norm_2(A[i]);
}
template<const size_t _size0, const size_t _size1>
void Multipoint_Evaluation_core(const Z x[], const Z P0[], const Z P[], const Z dftt[][_size0], const Z dfttx[][_size1], const Z p[], const int N, const Z F[], const Z dftf[], const int M, Z res[], int L, int LM)
{
if (std::min(L, LM) <= Brute_size)
{
for (int j = 0; j != N; ++j)
{
Z y = 0;
for (int k = M - 1; k >= 0; --k)
y = mult_2(y, x[j]).v + F[k].v;
res[j].v = norm(norm_2(y.v));
}
return;
}
int _l = __builtin_ctz(std::min(L, LM)) & 1;
while (_l + 2 <= Brute_size_l)
_l += 2;
int l = (__builtin_ctz(std::min(L, LM)) - _l) >> 1;
Z *C = ME_M_pos -= std::max(L * 4, LM);
Z *tmp = ME_M_pos -= std::max(L * 2, LM);
for (int i = 0; i < N; i += LM)
{
auto _P = P + i;
auto _dftt = dftt[l] + i * 4;
auto _C = C + i;
copy_n_a(_dftt, std::min(L, LM), _C);
for (int d = L; d < LM; d <<= 1)
ST_extend_DFT_fr_2_tbd(_P, N + 1, _C, d);
}
{
const int s = (N + LM - 1) & ~(LM - 1);
tmp[0] = 1;
for (int j = 1; j != s; ++j)
tmp[j] = mult_2(tmp[j - 1], C[j - 1]);
Z r = Rec(tmp[s - 1] * C[s - 1]);
for (int j = s - 1; j >= 0; --j)
tmp[j] = mult_2(tmp[j], r), r = mult_2(r, C[j]);
}
for (int i = 0; i < N; i += LM)
{
auto _dftt = dftt[l] + i * 4;
auto _C = C + i * 4;
auto _tmp = tmp + i;
mult_2(_C, dftf, _tmp, LM);
IDFT_fr(_C, LM);
const int t = std::min(N - i, LM);
if (t <= M)
copy_n_a(_C + M - t, t, _C + std::min(L, LM) * 4 - t);
else
copy_n_a(_C + LM - (t - M), t - M, _C + std::min(L, LM) * 4 - t), copy_n_a(_C, M, _C + std::min(L, LM) * 4 - M);
reset(_C + std::min(L, LM) * 3, std::min(L, LM) - t);
}
--l;
uint idft_c = 1;
for (int d; d = 1 << (_l + l * 2), l >= 0; --l)
{
const int s = N & ~(d * 4 - 1);
for (int i = 0; i != s; i += d * 4)
{
auto _C0 = C + i * 4, _C1 = _C0 + d * 4, _C2 = _C1 + d * 4, _C3 = _C2 + d * 4;
auto _dftt0 = dftt[l] + i * 4, _dftt1 = _dftt0 + d * 4, _dftt2 = _dftt1 + d * 4, _dftt3 = _dftt2 + d * 4;
auto _dftta = dfttx[l] + i * 2, _dfttb = _dftta + d * 4;
DFT_fr_2(_C3, d * 4);
mult_2(_C1, _C3, _dfttb, d * 4);
mult_2(_C0, _C1, _dftt1, d * 4);
mult_2(_C1, _C1, _dftt0, d * 4);
mult_2(_C3, _C3, _dftta, d * 4);
mult_2(_C2, _C3, _dftt3, d * 4);
mult_2(_C3, _C3, _dftt2, d * 4);
IDFT_fr_core_b4_2(_C0, d * 4), IDFT_fr_core_b4_2(_C1, d * 4), IDFT_fr_core_b4_2(_C2, d * 4), IDFT_fr_core_b4_2(_C3, d * 4);
}
if (N != s)
{
auto _C0 = C + s * 4, _C1 = _C0 + d * 4, _C2 = _C1 + d * 4, _C3 = _C2 + d * 4;
auto _dftt0 = dftt[l] + s * 4, _dftt1 = _dftt0 + d * 4, _dftt2 = _dftt1 + d * 4, _dftt3 = _dftt2 + d * 4;
auto _dftta = dfttx[l] + s * 2, _dfttb = _dftta + d * 4;
if (N - s <= d)
{
copy_n_a(_C3 + d * 3, d, _C0 + d * 3);
for (int i = 0; i != d; ++i)
_C0[d * 3 + i] *= d * 4;
}
else
{
if (N - s <= d * 2)
{
const int t = N - s - d;
DFT_fr_2(_C3 + d * 2, d * 2);
mult_2(_C0 + d * 2, _C3 + d * 2, _dftt1, d * 2);
mult_2(_C1 + d * 2, _C3 + d * 2, _dftt0, d * 2);
IDFT_fr_core_2(_C0 + d * 2, d * 2), IDFT_fr_core_2(_C1 + d * 2, d * 2);
for (int i = 0; i != d * 3; ++i)
_C0[d * 3 + i] = norm_2(_C0[d * 3 + i].v * 2);
for (int i = 0; i != t; ++i)
_C1[d * 4 - t + i] = norm_2(_C1[d * 4 - t + i].v * 2);
reset(_C1 + d * 3, d - t);
}
else
{
DFT_fr_2(_C3, d * 4);
mult_2(_C1, _C3, _dfttb, d * 4);
mult_2(_C0, _C1, _dftt1, d * 4);
mult_2(_C1, _C1, _dftt0, d * 4);
if (N - s <= d * 3) // can be faster
{
const int t = N - s - d * 2;
mult_2(_C2, _C3, _dftta, d * 4);
IDFT_fr_core_b4_2(_C0, d * 4), IDFT_fr_core_b4_2(_C1, d * 4), IDFT_fr_core_b4_2(_C2, d * 4);
reset(_C2 + d * 3, d - t);
}
else
{
const int t = N - s - d * 3;
mult_2(_C3, _C3, _dftta, d * 4);
mult_2(_C2, _C3, _dftt3, d * 4);
mult_2(_C3, _C3, _dftt2, d * 4);
IDFT_fr_core_b4_2(_C0, d * 4), IDFT_fr_core_b4_2(_C1, d * 4), IDFT_fr_core_b4_2(_C2, d * 4), IDFT_fr_core_b4_2(_C3, d * 4);
reset(_C3 + d * 3, d - t);
}
}
}
}
int k = __builtin_ctz(d * 4);
uint64 m = -idft_c & (d * 4 - 1);
idft_c = norm((idft_c + m * Mod) >> k);
}
if (_l)
{
const int d = 1 << _l;
for (int i = 0; i < N; i += d)
{
const int t = std::min(d, N - i);
const int o = -t & 3;
auto _C = C + i * 4 + d * 4 - t;
auto _F = C + i * 4 + d * 2;
auto _P0 = P0 + i;
for (int j = 0; j != t; ++j)
_C[j].v = norm(_C[j].v);
for (int j = 0; j != o; ++j)
_F[j] = 0;
for (int j = 0; j != t; ++j) // Brute_size <= 16
{
uint64 c = _C[j].v;
// for (int k = 1; k <= j; ++k)
// c += static_cast<uint64>(_C[j - k].v) * _P0[k].v;
int s = j & ~3;
#define calc(k)\
(c += static_cast<uint64>(_C[j - (k)].v) * _P0[k].v + static_cast<uint64>(_C[j - (k) - 1].v) * _P0[(k) + 1].v + static_cast<uint64>(_C[j - (k) - 2].v) * _P0[(k) + 2].v + static_cast<uint64>(_C[j - (k) - 3].v) * _P0[(k) + 3].v)
if (0 < s)
if (calc(1), 4 < s)
if (calc(5), 8 < s)
if (calc(9), 12 < s)
calc(13);
#undef calc
for (int k = s + 1; k <= j; ++k)
c += static_cast<uint64>(_C[j - k].v) * _P0[k].v;
_F[o + j] = c % Mod;
}
for (int j = 0; j != t; ++j)
{
const uint _x = x[i + j].v, _x2 = mult_2(_x, _x), _x3 = mult_2(_x2, _x), _x4 = mult_2(_x3, _x);
Z y = 0;
// for (int k = 0; k < t; k += 4)
// y = (y.v * static_cast<uint64>(_x4) + _F[k].v * static_cast<uint64>(_x3) + _F[k + 1].v * static_cast<uint64>(_x2) + _F[k + 2].v * static_cast<uint64>(_x) + _F[k + 3].v) % Mod;
#define calc(k)\
(y = (y.v * static_cast<uint64>(_x4) + _F[k].v * static_cast<uint64>(_x3) + _F[(k) + 1].v * static_cast<uint64>(_x2) + _F[(k) + 2].v * static_cast<uint64>(_x) + _F[(k) + 3].v) % Mod)
if (calc(0), 4 < t)
if (calc(4), 8 < t)
if (calc(8), 12 < t)
calc(12);
#undef calc
res[i + j] = y * idft_c * (Mod + 1 - p[i + j].v);
}
}
}
else
for (int j = 0; j != N; ++j)
res[j] = C[j * 4 + 3] * idft_c * (Mod + 1 - p[j].v);
ME_M_pos += std::max(L * 4, LM) + std::max(L * 2, LM);
}
int revbin(int x, const int LM)
{
int y = 0;
for (int i = __builtin_ctz(LM) - 1; i >= 0; --i)
y |= (x & 1) << i, x >>= 1;
return y;
}
void Multipoint_Evaluation(const Z x[], const int N, const Z F[], const int M, Z res[], int LM)
{
static Z _x[Max_size];
static int id[Max_size];
static Z tmp[Max_size];
static Z dftf[Max_size];
std::reverse_copy(F, F + M, dftf), reset(dftf + M, LM - M);
DFT_fr(dftf, LM);
int cnt = ME_Pre(x, N, _x, id, tmp, LM);
for (int i = cnt; i != N; ++i)
{
int j = (tmp[i].v * (M - 1)) & (LM - 1);
uint y = dftf[revbin((LM - tmp[i].v) & (LM - 1), LM)].v;
tmp[i].v = j < (LM >> 1) ? mult_Shoup(y, w[(LM >> 1) + j]) : mult_Shoup(Mod - y, w[j]);
}
int L;
for (L = 1; L < cnt; L <<= 1)
;
static Z P0[Max_size];
static Z P[Max_size];
static Z dftt[10][Max_size * 4];
static Z dfttx[10][Max_size * 2];
_x[cnt] = _x[cnt + 1] = _x[cnt + 2] = 0;
Subproduct_Tree_Rev_(_x, cnt, P0, P, dftt, dfttx, L, LM);
Multipoint_Evaluation_core(_x, P0, P, dftt, dfttx, tmp, cnt, F, dftf, M, tmp, L, LM);
for (int i = 0; i != N; ++i)
res[id[i]] = tmp[i];
}
typedef long long ll;
ll mod=998244353;
int N, Q, LM;
Z q0, qx, qy;
Z A[Max_size], q[Max_size];
ll fac[200000],infac[200000];
ll exp(ll a,ll b)
{
a%=mod;
ll res=1;
while(b>0)
{
if(b&1)
res=res*a%mod;
a=a*a%mod;
b>>=1;
}
return res;
}
ll mod_in(ll a)
{
return exp(a,mod-2);
}
void pre()
{
fac[0]=1;
infac[0]=1;
for(int i=1;i<=100000;i++)
{
fac[i]=(fac[i-1]*i)%mod;
infac[i]=mod_in(fac[i]);
}
}
const ll root=15311432;
const ll root_1=469870224;
const ll root_pw=1<<23;
void fft(std::vector<ll> & a, bool invert)
{
ll n = a.size();
for (ll i = 1, j = 0; i < n; i++)
{
ll bit = n >> 1;
for (; j & bit; bit >>= 1)
j ^= bit;
j ^= bit;
if (i < j)
std::swap(a[i], a[j]);
}
for (ll len = 2; len <= n; len <<= 1)
{
ll wlen = invert ? root_1 : root;
for (ll i = len; i < root_pw; i <<= 1)
wlen=(1LL * wlen * wlen % mod);
for (ll i = 0; i < n; i += len)
{
ll w = 1;
for (ll j = 0; j < len / 2; j++) {
ll u = a[i+j], v = (a[i+j+len/2] * w % mod);
a[i+j] = u + v < mod ? u + v : u + v - mod;
a[i+j+len/2] = u - v >= 0 ? u - v : u - v + mod;
w = w * wlen % mod;
}
}
}
if (invert) {
ll n_1 = mod_in(n);
for (ll & x : a)
x = x * n_1 % mod;
}
}
std::vector<ll> multi(std::vector<ll> const& a,std::vector<ll> const& b,ll degree)
{
std::vector<ll> fa(a.begin(), a.end()), fb(b.begin(), b.end());
ll n = 1;
while (n < a.size() + b.size())
n <<= 1;
fa.resize(n);
fb.resize(n);
fft(fa, false);
fft(fb, false);
for (ll i = 0; i < n; i++)
{fa[i] *= fb[i];fa[i]%=mod;}
fft(fa, true);
fa.resize(degree+1);
return fa;
}
std::vector<ll> bangbang(ll l,ll r)
{
if(r<l)
return {1};
if(l==r)
return {l,1};
ll mid=(l+r)/2;
return multi(bangbang(l,mid),bangbang(mid+1,r),r-l+1);
}
int main()
{
pre();
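ll r;
std::cin>>r>>Q;
std::vector<ll> p(Q),ans(Q); // standard containers instead of non-standard variable-length arrays
// Shift each A_i by -r so that Q(A_i - r) = A_i (A_i - 1) ... (A_i - r + 1)
for(int i=0;i<Q;i++)
{
std::cin>>q[i].v;
q[i].v=(q[i].v-r+mod)%mod;
}
for(int j=0;j<Q;j++)
std::cin>>p[j];
// ans[i] = P_i / r!
for(int i=0;i<Q;i++)
ans[i]=(infac[r]*p[i])%mod;
// Coefficients (lowest degree first) of Q(x) = (x+1)(x+2)...(x+r), built by divide and conquer with NTT
std::vector<ll> beg=bangbang(1,r);
N=r+1;
for(int i=0;i<N;i++)
A[i].v=beg[i];
for (LM = 1; LM < N; LM <<= 1);
init_w(LM);
// Overwrite q[i] with Q(A_i - r), evaluating Q at all K points at once
Multipoint_Evaluation(q,Q, A, N,q,LM);
ll o=1;
// Independent schools: multiply P_i * C(A_i, r) + (1 - P_i) over all of them
for(int i=0;i<Q;i++)
{
o=(o*((((q[i].v)*ans[i])%mod)+(1-p[i]+mod))%mod)%mod;
}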
std::cout<<o<<std::endl;
}
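To sanity-check either fast solution on small inputs, a brute-force reference can evaluate the closed form \displaystyle \prod_{i=1}^K \bigg[ P_i * \binom{A_i}{r} + (1-P_i)\bigg] directly in O(K \cdot r) time. The sketch below is an illustration, not part of the setter's code: the names `binom_mod` and `expected_ways` are hypothetical, and it assumes every P_i is already given as a residue in [0, M).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using u64 = std::uint64_t;
const u64 MOD = 998244353;

// Fast modular exponentiation: b^e mod MOD.
u64 power(u64 b, u64 e) {
    u64 r = 1;
    b %= MOD;
    while (e > 0) {
        if (e & 1) r = r * b % MOD;
        b = b * b % MOD;
        e >>= 1;
    }
    return r;
}

// C(a, r) mod MOD via the falling factorial a(a-1)...(a-r+1) / r!.
// Valid because r < MOD, so r! is invertible modulo the prime MOD.
u64 binom_mod(u64 a, u64 r) {
    if (a < r) return 0;
    u64 num = 1, den = 1;
    for (u64 j = 0; j < r; ++j) {
        num = num * ((a - j) % MOD) % MOD;
        den = den * ((j + 1) % MOD) % MOD;
    }
    return num * power(den, MOD - 2) % MOD;
}

// Brute-force answer: product over schools of P_i * C(A_i, r) + (1 - P_i).
// Assumes each p[i] is a residue in [0, MOD) (hypothetical helper).
u64 expected_ways(u64 r, const std::vector<u64>& a, const std::vector<u64>& p) {
    u64 ans = 1;
    for (std::size_t i = 0; i < a.size(); ++i) {
        u64 term = (p[i] * binom_mod(a[i], r) + (MOD + 1 - p[i])) % MOD;
        ans = ans * term % MOD;
    }
    return ans;
}
```

For example, `expected_ways(2, {4, 3}, {499122177, 1})` models one school with 4 students invited with probability 1/2 (499122177 is the inverse of 2 modulo M) and one school with 3 students that is always invited; the expected count is 7/2 * 3 = 21/2, which is 499122187 modulo M.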
Tester's Solution
/* in the name of Anton */
/*
Compete against Yourself.
Author - Aryan (@aryanc403)
Atcoder library - https://atcoder.github.io/ac-library/production/document_en/
*/
#ifdef ARYANC403
#include <header.h>
#else
#pragma GCC optimize ("Ofast")
#pragma GCC target ("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx")
//#pragma GCC optimize ("-ffloat-store")
#include<bits/stdc++.h>
#define dbg(args...) 42;
#endif
using namespace std;
#define fo(i,n) for(i=0;i<(n);++i)
#define repA(i,j,n) for(i=(j);i<=(n);++i)
#define repD(i,j,n) for(i=(j);i>=(n);--i)
#define all(x) begin(x), end(x)
#define sz(x) ((lli)(x).size())
#define pb push_back
#define mp make_pair
#define X first
#define Y second
#define endl "\n"
typedef long long int lli;
typedef long double mytype;
typedef pair<lli,lli> ii;
typedef vector<ii> vii;
typedef vector<lli> vi;
const auto start_time = std::chrono::high_resolution_clock::now();
void aryanc403()
{
#ifdef ARYANC403
auto end_time = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = end_time-start_time;
cerr<<"Time Taken : "<<diff.count()<<"\n";
#endif
}
long long readInt(long long l, long long r, char endd) {
long long x=0;
int cnt=0;
int fi=-1;
bool is_neg=false;
while(true) {
char g=getchar();
if(g=='-') {
assert(fi==-1);
is_neg=true;
continue;
}
if('0'<=g&&g<='9') {
x*=10;
x+=g-'0';
if(cnt==0) {
fi=g-'0';
}
cnt++;
assert(fi!=0 || cnt==1);
assert(fi!=0 || is_neg==false);
assert(!(cnt>19 || ( cnt==19 && fi>1) ));
} else if(g==endd) {
if(is_neg) {
x=-x;
}
assert(l<=x&&x<=r);
return x;
} else {
assert(false);
}
}
}
string readString(int l, int r, char endd) {
string ret="";
int cnt=0;
while(true) {
char g=getchar();
assert(g!=-1);
if(g==endd) {
break;
}
cnt++;
ret+=g;
}
assert(l<=cnt&&cnt<=r);
return ret;
}
long long readIntSp(long long l, long long r) {
return readInt(l,r,' ');
}
long long readIntLn(long long l, long long r) {
return readInt(l,r,'\n');
}
string readStringLn(int l, int r) {
return readString(l,r,'\n');
}
string readStringSp(int l, int r) {
return readString(l,r,' ');
}
void readEOF(){
assert(getchar()==EOF);
}
vi readVectorInt(int n,lli l,lli r){
vi a(n);
for(int i=0;i<n-1;++i)
a[i]=readIntSp(l,r);
a[n-1]=readIntLn(l,r);
return a;
}
const lli INF = 0xFFFFFFFFFFFFFFFL;
lli seed;
mt19937 rng(seed=chrono::steady_clock::now().time_since_epoch().count());
inline lli rnd(lli l=0,lli r=INF)
{return uniform_int_distribution<lli>(l,r)(rng);}
class CMP
{public:
bool operator()(ii a , ii b) //For min priority_queue .
{ return ! ( a.X < b.X || ( a.X==b.X && a.Y <= b.Y )); }};
void add( map<lli,lli> &m, lli x,lli cnt=1)
{
auto jt=m.find(x);
if(jt==m.end()) m.insert({x,cnt});
else jt->Y+=cnt;
}
void del( map<lli,lli> &m, lli x,lli cnt=1)
{
auto jt=m.find(x);
if(jt->Y<=cnt) m.erase(jt);
else jt->Y-=cnt;
}
bool cmp(const ii &a,const ii &b)
{
return a.X<b.X||(a.X==b.X&&a.Y<b.Y);
}
#include <algorithm>
#include <array>
#include <cassert>
#include <type_traits>
#include <vector>
#ifdef _MSC_VER
#include <intrin.h>
#endif
namespace atcoder {
namespace internal {
int ceil_pow2(int n) {
int x = 0;
while ((1U << x) < (unsigned int)(n)) x++;
return x;
}
int bsf(unsigned int n) {
#ifdef _MSC_VER
unsigned long index;
_BitScanForward(&index, n);
return index;
#else
return __builtin_ctz(n);
#endif
}
} // namespace internal
} // namespace atcoder
#include <cassert>
#include <numeric>
#include <type_traits>
#ifdef _MSC_VER
#include <intrin.h>
#endif
#include <utility>
#ifdef _MSC_VER
#include <intrin.h>
#endif
namespace atcoder {
namespace internal {
constexpr long long safe_mod(long long x, long long m) {
x %= m;
if (x < 0) x += m;
return x;
}
struct barrett {
unsigned int _m;
unsigned long long im;
barrett(unsigned int m) : _m(m), im((unsigned long long)(-1) / m + 1) {}
unsigned int umod() const { return _m; }
unsigned int mul(unsigned int a, unsigned int b) const {
unsigned long long z = a;
z *= b;
#ifdef _MSC_VER
unsigned long long x;
_umul128(z, im, &x);
#else
unsigned long long x =
(unsigned long long)(((unsigned __int128)(z)*im) >> 64);
#endif
unsigned int v = (unsigned int)(z - x * _m);
if (_m <= v) v += _m;
return v;
}
};
constexpr long long pow_mod_constexpr(long long x, long long n, int m) {
if (m == 1) return 0;
unsigned int _m = (unsigned int)(m);
unsigned long long r = 1;
unsigned long long y = safe_mod(x, m);
while (n) {
if (n & 1) r = (r * y) % _m;
y = (y * y) % _m;
n >>= 1;
}
return r;
}
constexpr bool is_prime_constexpr(int n) {
if (n <= 1) return false;
if (n == 2 || n == 7 || n == 61) return true;
if (n % 2 == 0) return false;
long long d = n - 1;
while (d % 2 == 0) d /= 2;
constexpr long long bases[3] = {2, 7, 61};
for (long long a : bases) {
long long t = d;
long long y = pow_mod_constexpr(a, t, n);
while (t != n - 1 && y != 1 && y != n - 1) {
y = y * y % n;
t <<= 1;
}
if (y != n - 1 && t % 2 == 0) {
return false;
}
}
return true;
}
template <int n> constexpr bool is_prime = is_prime_constexpr(n);
constexpr std::pair<long long, long long> inv_gcd(long long a, long long b) {
a = safe_mod(a, b);
if (a == 0) return {b, 0};
long long s = b, t = a;
long long m0 = 0, m1 = 1;
while (t) {
long long u = s / t;
s -= t * u;
m0 -= m1 * u; // |m1 * u| <= |m1| * s <= b
auto tmp = s;
s = t;
t = tmp;
tmp = m0;
m0 = m1;
m1 = tmp;
}
if (m0 < 0) m0 += b / s;
return {s, m0};
}
constexpr int primitive_root_constexpr(int m) {
if (m == 2) return 1;
if (m == 167772161) return 3;
if (m == 469762049) return 3;
if (m == 754974721) return 11;
if (m == 998244353) return 3;
int divs[20] = {};
divs[0] = 2;
int cnt = 1;
int x = (m - 1) / 2;
while (x % 2 == 0) x /= 2;
for (int i = 3; (long long)(i)*i <= x; i += 2) {
if (x % i == 0) {
divs[cnt++] = i;
while (x % i == 0) {
x /= i;
}
}
}
if (x > 1) {
divs[cnt++] = x;
}
for (int g = 2;; g++) {
bool ok = true;
for (int i = 0; i < cnt; i++) {
if (pow_mod_constexpr(g, (m - 1) / divs[i], m) == 1) {
ok = false;
break;
}
}
if (ok) return g;
}
}
template <int m> constexpr int primitive_root = primitive_root_constexpr(m);
} // namespace internal
} // namespace atcoder
#include <cassert>
#include <numeric>
#include <type_traits>
namespace atcoder {
namespace internal {
#ifndef _MSC_VER
template <class T>
using is_signed_int128 =
typename std::conditional<std::is_same<T, __int128_t>::value ||
std::is_same<T, __int128>::value,
std::true_type,
std::false_type>::type;
template <class T>
using is_unsigned_int128 =
typename std::conditional<std::is_same<T, __uint128_t>::value ||
std::is_same<T, unsigned __int128>::value,
std::true_type,
std::false_type>::type;
template <class T>
using make_unsigned_int128 =
typename std::conditional<std::is_same<T, __int128_t>::value,
__uint128_t,
unsigned __int128>;
template <class T>
using is_integral = typename std::conditional<std::is_integral<T>::value ||
is_signed_int128<T>::value ||
is_unsigned_int128<T>::value,
std::true_type,
std::false_type>::type;
template <class T>
using is_signed_int = typename std::conditional<(is_integral<T>::value &&
std::is_signed<T>::value) ||
is_signed_int128<T>::value,
std::true_type,
std::false_type>::type;
template <class T>
using is_unsigned_int =
typename std::conditional<(is_integral<T>::value &&
std::is_unsigned<T>::value) ||
is_unsigned_int128<T>::value,
std::true_type,
std::false_type>::type;
template <class T>
using to_unsigned = typename std::conditional<
is_signed_int128<T>::value,
make_unsigned_int128<T>,
typename std::conditional<std::is_signed<T>::value,
std::make_unsigned<T>,
std::common_type<T>>::type>::type;
#else
template <class T> using is_integral = typename std::is_integral<T>;
template <class T>
using is_signed_int =
typename std::conditional<is_integral<T>::value && std::is_signed<T>::value,
std::true_type,
std::false_type>::type;
template <class T>
using is_unsigned_int =
typename std::conditional<is_integral<T>::value &&
std::is_unsigned<T>::value,
std::true_type,
std::false_type>::type;
template <class T>
using to_unsigned = typename std::conditional<is_signed_int<T>::value,
std::make_unsigned<T>,
std::common_type<T>>::type;
#endif
template <class T>
using is_signed_int_t = std::enable_if_t<is_signed_int<T>::value>;
template <class T>
using is_unsigned_int_t = std::enable_if_t<is_unsigned_int<T>::value>;
template <class T> using to_unsigned_t = typename to_unsigned<T>::type;
} // namespace internal
} // namespace atcoder
namespace atcoder {
namespace internal {
struct modint_base {};
struct static_modint_base : modint_base {};
template <class T> using is_modint = std::is_base_of<modint_base, T>;
template <class T> using is_modint_t = std::enable_if_t<is_modint<T>::value>;
} // namespace internal
template <int m, std::enable_if_t<(1 <= m)>* = nullptr>
struct static_modint : internal::static_modint_base {
using mint = static_modint;
public:
static constexpr int mod() { return m; }
static mint raw(int v) {
mint x;
x._v = v;
return x;
}
static_modint() : _v(0) {}
template <class T, internal::is_signed_int_t<T>* = nullptr>
static_modint(T v) {
long long x = (long long)(v % (long long)(umod()));
if (x < 0) x += umod();
_v = (unsigned int)(x);
}
template <class T, internal::is_unsigned_int_t<T>* = nullptr>
static_modint(T v) {
_v = (unsigned int)(v % umod());
}
unsigned int val() const { return _v; }
mint& operator++() {
_v++;
if (_v == umod()) _v = 0;
return *this;
}
mint& operator--() {
if (_v == 0) _v = umod();
_v--;
return *this;
}
mint operator++(int) {
mint result = *this;
++*this;
return result;
}
mint operator--(int) {
mint result = *this;
--*this;
return result;
}
mint& operator+=(const mint& rhs) {
_v += rhs._v;
if (_v >= umod()) _v -= umod();
return *this;
}
mint& operator-=(const mint& rhs) {
_v -= rhs._v;
if (_v >= umod()) _v += umod();
return *this;
}
mint& operator*=(const mint& rhs) {
unsigned long long z = _v;
z *= rhs._v;
_v = (unsigned int)(z % umod());
return *this;
}
mint& operator/=(const mint& rhs) { return *this = *this * rhs.inv(); }
mint operator+() const { return *this; }
mint operator-() const { return mint() - *this; }
mint pow(long long n) const {
assert(0 <= n);
mint x = *this, r = 1;
while (n) {
if (n & 1) r *= x;
x *= x;
n >>= 1;
}
return r;
}
mint inv() const {
if (prime) {
assert(_v);
return pow(umod() - 2);
} else {
auto eg = internal::inv_gcd(_v, m);
assert(eg.first == 1);
return eg.second;
}
}
friend mint operator+(const mint& lhs, const mint& rhs) {
return mint(lhs) += rhs;
}
friend mint operator-(const mint& lhs, const mint& rhs) {
return mint(lhs) -= rhs;
}
friend mint operator*(const mint& lhs, const mint& rhs) {
return mint(lhs) *= rhs;
}
friend mint operator/(const mint& lhs, const mint& rhs) {
return mint(lhs) /= rhs;
}
friend bool operator==(const mint& lhs, const mint& rhs) {
return lhs._v == rhs._v;
}
friend bool operator!=(const mint& lhs, const mint& rhs) {
return lhs._v != rhs._v;
}
private:
unsigned int _v;
static constexpr unsigned int umod() { return m; }
static constexpr bool prime = internal::is_prime<m>;
};
template <int id> struct dynamic_modint : internal::modint_base {
using mint = dynamic_modint;
public:
static int mod() { return (int)(bt.umod()); }
static void set_mod(int m) {
assert(1 <= m);
bt = internal::barrett(m);
}
static mint raw(int v) {
mint x;
x._v = v;
return x;
}
dynamic_modint() : _v(0) {}
template <class T, internal::is_signed_int_t<T>* = nullptr>
dynamic_modint(T v) {
long long x = (long long)(v % (long long)(mod()));
if (x < 0) x += mod();
_v = (unsigned int)(x);
}
template <class T, internal::is_unsigned_int_t<T>* = nullptr>
dynamic_modint(T v) {
_v = (unsigned int)(v % mod());
}
unsigned int val() const { return _v; }
mint& operator++() {
_v++;
if (_v == umod()) _v = 0;
return *this;
}
mint& operator--() {
if (_v == 0) _v = umod();
_v--;
return *this;
}
mint operator++(int) {
mint result = *this;
++*this;
return result;
}
mint operator--(int) {
mint result = *this;
--*this;
return result;
}
mint& operator+=(const mint& rhs) {
_v += rhs._v;
if (_v >= umod()) _v -= umod();
return *this;
}
mint& operator-=(const mint& rhs) {
_v += mod() - rhs._v;
if (_v >= umod()) _v -= umod();
return *this;
}
mint& operator*=(const mint& rhs) {
_v = bt.mul(_v, rhs._v);
return *this;
}
mint& operator/=(const mint& rhs) { return *this = *this * rhs.inv(); }
mint operator+() const { return *this; }
mint operator-() const { return mint() - *this; }
mint pow(long long n) const {
assert(0 <= n);
mint x = *this, r = 1;
while (n) {
if (n & 1) r *= x;
x *= x;
n >>= 1;
}
return r;
}
mint inv() const {
auto eg = internal::inv_gcd(_v, mod());
assert(eg.first == 1);
return eg.second;
}
friend mint operator+(const mint& lhs, const mint& rhs) {
return mint(lhs) += rhs;
}
friend mint operator-(const mint& lhs, const mint& rhs) {
return mint(lhs) -= rhs;
}
friend mint operator*(const mint& lhs, const mint& rhs) {
return mint(lhs) *= rhs;
}
friend mint operator/(const mint& lhs, const mint& rhs) {
return mint(lhs) /= rhs;
}
friend bool operator==(const mint& lhs, const mint& rhs) {
return lhs._v == rhs._v;
}
friend bool operator!=(const mint& lhs, const mint& rhs) {
return lhs._v != rhs._v;
}
private:
unsigned int _v;
static internal::barrett bt;
static unsigned int umod() { return bt.umod(); }
};
template <int id> internal::barrett dynamic_modint<id>::bt = 998244353;
using modint998244353 = static_modint<998244353>;
using modint1000000007 = static_modint<1000000007>;
using modint = dynamic_modint<-1>;
namespace internal {
template <class T>
using is_static_modint = std::is_base_of<internal::static_modint_base, T>;
template <class T>
using is_static_modint_t = std::enable_if_t<is_static_modint<T>::value>;
template <class> struct is_dynamic_modint : public std::false_type {};
template <int id>
struct is_dynamic_modint<dynamic_modint<id>> : public std::true_type {};
template <class T>
using is_dynamic_modint_t = std::enable_if_t<is_dynamic_modint<T>::value>;
} // namespace internal
} // namespace atcoder
namespace atcoder {
namespace internal {
template <class mint, internal::is_static_modint_t<mint>* = nullptr>
void butterfly(std::vector<mint>& a) {
static constexpr int g = internal::primitive_root<mint::mod()>;
int n = int(a.size());
int h = internal::ceil_pow2(n);
static bool first = true;
static mint sum_e[30]; // sum_e[i] = ies[0] * ... * ies[i - 1] * es[i]
if (first) {
first = false;
mint es[30], ies[30]; // es[i]^(2^(2+i)) == 1
int cnt2 = bsf(mint::mod() - 1);
mint e = mint(g).pow((mint::mod() - 1) >> cnt2), ie = e.inv();
for (int i = cnt2; i >= 2; i--) {
es[i - 2] = e;
ies[i - 2] = ie;
e *= e;
ie *= ie;
}
mint now = 1;
for (int i = 0; i <= cnt2 - 2; i++) {
sum_e[i] = es[i] * now;
now *= ies[i];
}
}
for (int ph = 1; ph <= h; ph++) {
int w = 1 << (ph - 1), p = 1 << (h - ph);
mint now = 1;
for (int s = 0; s < w; s++) {
int offset = s << (h - ph + 1);
for (int i = 0; i < p; i++) {
auto l = a[i + offset];
auto r = a[i + offset + p] * now;
a[i + offset] = l + r;
a[i + offset + p] = l - r;
}
now *= sum_e[bsf(~(unsigned int)(s))];
}
}
}
template <class mint, internal::is_static_modint_t<mint>* = nullptr>
void butterfly_inv(std::vector<mint>& a) {
static constexpr int g = internal::primitive_root<mint::mod()>;
int n = int(a.size());
int h = internal::ceil_pow2(n);
static bool first = true;
static mint sum_ie[30]; // sum_ie[i] = es[0] * ... * es[i - 1] * ies[i]
if (first) {
first = false;
mint es[30], ies[30]; // es[i]^(2^(2+i)) == 1
int cnt2 = bsf(mint::mod() - 1);
mint e = mint(g).pow((mint::mod() - 1) >> cnt2), ie = e.inv();
for (int i = cnt2; i >= 2; i--) {
es[i - 2] = e;
ies[i - 2] = ie;
e *= e;
ie *= ie;
}
mint now = 1;
for (int i = 0; i <= cnt2 - 2; i++) {
sum_ie[i] = ies[i] * now;
now *= es[i];
}
}
for (int ph = h; ph >= 1; ph--) {
int w = 1 << (ph - 1), p = 1 << (h - ph);
mint inow = 1;
for (int s = 0; s < w; s++) {
int offset = s << (h - ph + 1);
for (int i = 0; i < p; i++) {
auto l = a[i + offset];
auto r = a[i + offset + p];
a[i + offset] = l + r;
a[i + offset + p] =
(unsigned long long)(mint::mod() + l.val() - r.val()) *
inow.val();
}
inow *= sum_ie[bsf(~(unsigned int)(s))];
}
}
}
} // namespace internal
template <class mint, internal::is_static_modint_t<mint>* = nullptr>
std::vector<mint> convolution(std::vector<mint> a, std::vector<mint> b) {
int n = int(a.size()), m = int(b.size());
if (!n || !m) return {};
if (std::min(n, m) <= 60) {
if (n < m) {
std::swap(n, m);
std::swap(a, b);
}
std::vector<mint> ans(n + m - 1);
for (int i = 0; i < n; i++) {
for (int j = 0; j < m; j++) {
ans[i + j] += a[i] * b[j];
}
}
return ans;
}
int z = 1 << internal::ceil_pow2(n + m - 1);
a.resize(z);
internal::butterfly(a);
b.resize(z);
internal::butterfly(b);
for (int i = 0; i < z; i++) {
a[i] *= b[i];
}
internal::butterfly_inv(a);
a.resize(n + m - 1);
mint iz = mint(z).inv();
for (int i = 0; i < n + m - 1; i++) a[i] *= iz;
return a;
}
template <unsigned int mod = 998244353,
class T,
std::enable_if_t<internal::is_integral<T>::value>* = nullptr>
std::vector<T> convolution(const std::vector<T>& a, const std::vector<T>& b) {
int n = int(a.size()), m = int(b.size());
if (!n || !m) return {};
using mint = static_modint<mod>;
std::vector<mint> a2(n), b2(m);
for (int i = 0; i < n; i++) {
a2[i] = mint(a[i]);
}
for (int i = 0; i < m; i++) {
b2[i] = mint(b[i]);
}
auto c2 = convolution(move(a2), move(b2));
std::vector<T> c(n + m - 1);
for (int i = 0; i < n + m - 1; i++) {
c[i] = c2[i].val();
}
return c;
}
std::vector<long long> convolution_ll(const std::vector<long long>& a,
const std::vector<long long>& b) {
int n = int(a.size()), m = int(b.size());
if (!n || !m) return {};
static constexpr unsigned long long MOD1 = 754974721; // 2^24
static constexpr unsigned long long MOD2 = 167772161; // 2^25
static constexpr unsigned long long MOD3 = 469762049; // 2^26
static constexpr unsigned long long M2M3 = MOD2 * MOD3;
static constexpr unsigned long long M1M3 = MOD1 * MOD3;
static constexpr unsigned long long M1M2 = MOD1 * MOD2;
static constexpr unsigned long long M1M2M3 = MOD1 * MOD2 * MOD3;
static constexpr unsigned long long i1 =
internal::inv_gcd(MOD2 * MOD3, MOD1).second;
static constexpr unsigned long long i2 =
internal::inv_gcd(MOD1 * MOD3, MOD2).second;
static constexpr unsigned long long i3 =
internal::inv_gcd(MOD1 * MOD2, MOD3).second;
auto c1 = convolution<MOD1>(a, b);
auto c2 = convolution<MOD2>(a, b);
auto c3 = convolution<MOD3>(a, b);
std::vector<long long> c(n + m - 1);
for (int i = 0; i < n + m - 1; i++) {
unsigned long long x = 0;
x += (c1[i] * i1) % MOD1 * M2M3;
x += (c2[i] * i2) % MOD2 * M1M3;
x += (c3[i] * i3) % MOD3 * M1M2;
long long diff =
c1[i] - internal::safe_mod((long long)(x), (long long)(MOD1));
if (diff < 0) diff += MOD1;
static constexpr unsigned long long offset[5] = {
0, 0, M1M2M3, 2 * M1M2M3, 3 * M1M2M3};
x -= offset[diff % 5];
c[i] = x;
}
return c;
}
} // namespace atcoder
using namespace atcoder;
using mint = modint998244353;
std::ostream& operator << (std::ostream& out, const mint& rhs) {
return out<<rhs.val();
}
// https://atcoder.jp/contests/arc113/submissions/20423265
namespace algebra {
template <typename T>
vector<T>& operator+=(vector<T>& a, const vector<T>& b) {
if (a.size() < b.size()) {
a.resize(b.size());
}
for (int i = 0; i < (int) b.size(); i++) {
a[i] += b[i];
}
return a;
}
template <typename T>
vector<T> operator+(const vector<T>& a, const vector<T>& b) {
vector<T> c = a;
return c += b;
}
template <typename T>
vector<T>& operator-=(vector<T>& a, const vector<T>& b) {
if (a.size() < b.size()) {
a.resize(b.size());
}
for (int i = 0; i < (int) b.size(); i++) {
a[i] -= b[i];
}
return a;
}
template <typename T>
vector<T> operator-(const vector<T>& a, const vector<T>& b) {
vector<T> c = a;
return c -= b;
}
template <typename T>
vector<T> operator-(const vector<T>& a) {
vector<T> c = a;
for (int i = 0; i < (int) c.size(); i++) {
c[i] = -c[i];
}
return c;
}
template <typename T>
vector<T> operator*(const vector<T>& a, const vector<T>& b) {
if (a.empty() || b.empty()) {
return {};
}
return convolution<T>(a,b);
// vector<T> c(a.size() + b.size() - 1, 0);
// for (int i = 0; i < (int) a.size(); i++) {
// for (int j = 0; j < (int) b.size(); j++) {
// c[i + j] += a[i] * b[j];
// }
// }
// return c;
}
template <typename T>
vector<T>& operator*=(vector<T>& a, const vector<T>& b) {
return a = a * b;
}
template <typename T>
vector<T> inverse(const vector<T>& a) {
assert(!a.empty());
int n = (int) a.size();
vector<T> b = {1 / a[0]};
while ((int) b.size() < n) {
vector<T> a_cut(a.begin(), a.begin() + min(a.size(), b.size() << 1));
vector<T> x = b * b * a_cut;
b.resize(b.size() << 1);
for (int i = (int) b.size() >> 1; i < (int) min(x.size(), b.size()); i++) {
b[i] = -x[i];
}
}
b.resize(n);
return b;
}
template <typename T>
vector<T>& operator/=(vector<T>& a, const vector<T>& b) {
int n = (int) a.size();
int m = (int) b.size();
if (n < m) {
a.clear();
} else {
vector<T> d = b;
reverse(a.begin(), a.end());
reverse(d.begin(), d.end());
d.resize(n - m + 1);
a *= inverse(d);
a.erase(a.begin() + n - m + 1, a.end());
reverse(a.begin(), a.end());
}
return a;
}
template <typename T>
vector<T> operator/(const vector<T>& a, const vector<T>& b) {
vector<T> c = a;
return c /= b;
}
template <typename T>
vector<T> operator*(const vector<T>& a, T b) {
vector<T> c = a;
for(auto &x:c)
x*=b;
return c;
}
template <typename T>
vector<T>& operator%=(vector<T>& a, const vector<T>& b) {
int n = (int) a.size();
int m = (int) b.size();
if (n >= m) {
vector<T> c = (a / b) * b;
a.resize(m - 1);
for (int i = 0; i < m - 1; i++) {
a[i] -= c[i];
}
}
return a;
}
template <typename T>
vector<T> operator%(const vector<T>& a, const vector<T>& b) {
vector<T> c = a;
return c %= b;
}
template <typename T, typename U>
vector<T> power(const vector<T>& a, const U& b, const vector<T>& c) {
assert(b >= 0);
vector<U> binary;
U bb = b;
while (bb > 0) {
binary.push_back(bb & 1);
bb >>= 1;
}
vector<T> res = vector<T>{1} % c;
for (int j = (int) binary.size() - 1; j >= 0; j--) {
res = res * res % c;
if (binary[j] == 1) {
res = res * a % c;
}
}
return res;
}
template <typename T>
vector<T> derivative(const vector<T>& a) {
vector<T> c = a;
for (int i = 0; i < (int) c.size(); i++) {
c[i] *= i;
}
if (!c.empty()) {
c.erase(c.begin());
}
return c;
}
template <typename T>
vector<T> integrate(const vector<T>& a) {
vector<T> c = {0};
for (int i = 0; i < (int) a.size(); i++) {
c.push_back(a[i]/(i+1));
}
return c;
}
template <typename T>
vector<T> primitive(const vector<T>& a) {
vector<T> c = a;
c.insert(c.begin(), 0);
for (int i = 1; i < (int) c.size(); i++) {
c[i] /= i;
}
return c;
}
template <typename T>
vector<T> logarithm(const vector<T>& a) {
assert(!a.empty() && a[0] == 1);
vector<T> u = primitive(derivative(a) * inverse(a));
u.resize(a.size());
return u;
}
template <typename T>
vector<T> exponent(const vector<T>& a) {
assert(!a.empty() && a[0] == 0);
int n = (int) a.size();
vector<T> b = {1};
while ((int) b.size() < n) {
vector<T> x(a.begin(), a.begin() + min(a.size(), b.size() << 1));
x[0] += 1;
vector<T> old_b = b;
b.resize(b.size() << 1);
x -= logarithm(b);
x *= old_b;
for (int i = (int) b.size() >> 1; i < (int) min(x.size(), b.size()); i++) {
b[i] = x[i];
}
}
b.resize(n);
return b;
}
template <typename T>
vector<T> sqrt(const vector<T>& a) {
assert(!a.empty() && a[0] == 1);
int n = (int) a.size();
vector<T> b = {1};
while ((int) b.size() < n) {
vector<T> x(a.begin(), a.begin() + min(a.size(), b.size() << 1));
b.resize(b.size() << 1);
x *= inverse(b);
T inv2 = 1 / static_cast<T>(2);
for (int i = (int) b.size() >> 1; i < (int) min(x.size(), b.size()); i++) {
b[i] = x[i] * inv2;
}
}
b.resize(n);
return b;
}
template <typename T>
vector<T> multiply(const vector<vector<T>>& a) {
if (a.empty()) {
return {0};
}
function<vector<T>(int, int)> mult = [&](int l, int r) {
if (l == r) {
return a[l];
}
int y = (l + r) >> 1;
return mult(l, y) * mult(y + 1, r);
};
return mult(0, (int) a.size() - 1);
}
template <typename T>
T evaluate(const vector<T>& a, const T& x) {
T res = 0;
for (int i = (int) a.size() - 1; i >= 0; i--) {
res = res * x + a[i];
}
return res;
}
template <typename T>
vector<T> evaluate(const vector<T>& a, const vector<T>& x) {
if (x.empty()) {
return {};
}
if (a.empty()) {
return vector<T>(x.size(), 0);
}
int n = (int) x.size();
vector<vector<T>> st((n << 1) - 1);
function<void(int, int, int)> build = [&](int v, int l, int r) {
if (l == r) {
st[v] = vector<T>{-x[l], 1};
} else {
int y = (l + r) >> 1;
int z = v + ((y - l + 1) << 1);
build(v + 1, l, y);
build(z, y + 1, r);
st[v] = st[v + 1] * st[z];
}
};
build(0, 0, n - 1);
vector<T> res(n);
function<void(int, int, int, vector<T>)> eval = [&](int v, int l, int r, vector<T> f) {
f %= st[v];
if ((int) f.size() < 150) {
for (int i = l; i <= r; i++) {
res[i] = evaluate(f, x[i]);
}
return;
}
if (l == r) {
res[l] = f[0];
} else {
int y = (l + r) >> 1;
int z = v + ((y - l + 1) << 1);
eval(v + 1, l, y, f);
eval(z, y + 1, r, f);
}
};
eval(0, 0, n - 1, a);
return res;
}
template <typename T>
vector<T> interpolate(const vector<T>& x, const vector<T>& y) {
if (x.empty()) {
return {};
}
assert(x.size() == y.size());
int n = (int) x.size();
vector<vector<T>> st((n << 1) - 1);
function<void(int, int, int)> build = [&](int v, int l, int r) {
if (l == r) {
st[v] = vector<T>{-x[l], 1};
} else {
int w = (l + r) >> 1;
int z = v + ((w - l + 1) << 1);
build(v + 1, l, w);
build(z, w + 1, r);
st[v] = st[v + 1] * st[z];
}
};
build(0, 0, n - 1);
vector<T> m = st[0];
vector<T> dm = derivative(m);
vector<T> val(n);
function<void(int, int, int, vector<T>)> eval = [&](int v, int l, int r, vector<T> f) {
f %= st[v];
if ((int) f.size() < 150) {
for (int i = l; i <= r; i++) {
val[i] = evaluate(f, x[i]);
}
return;
}
if (l == r) {
val[l] = f[0];
} else {
int w = (l + r) >> 1;
int z = v + ((w - l + 1) << 1);
eval(v + 1, l, w, f);
eval(z, w + 1, r, f);
}
};
eval(0, 0, n - 1, dm);
for (int i = 0; i < n; i++) {
val[i] = y[i] / val[i];
}
function<vector<T>(int, int, int)> calc = [&](int v, int l, int r) {
if (l == r) {
return vector<T>{val[l]};
}
int w = (l + r) >> 1;
int z = v + ((w - l + 1) << 1);
return calc(v + 1, l, w) * st[z] + calc(z, w + 1, r) * st[v + 1];
};
return calc(0, 0, n - 1);
}
// f[i] = 1^i + 2^i + ... + up^i
template <typename T>
vector<T> faulhaber(const T& up, int n) {
vector<T> ex(n + 1);
T e = 1;
for (int i = 0; i <= n; i++) {
ex[i] = e;
e /= i + 1;
}
vector<T> den = ex;
den.erase(den.begin());
for (auto& d : den) {
d = -d;
}
vector<T> num(n);
T p = 1;
for (int i = 0; i < n; i++) {
p *= up + 1;
num[i] = ex[i + 1] * (1 - p);
}
vector<T> res = num * inverse(den);
res.resize(n);
T f = 1;
for (int i = 0; i < n; i++) {
res[i] *= f;
f *= i + 1;
}
return res;
}
// (x + 1) * (x + 2) * ... * (x + n)
// (can be optimized with precomputed inverses)
template <typename T>
vector<T> sequence(int n) {
if (n == 0) {
return {1};
}
if (n % 2 == 1) {
return sequence<T>(n - 1) * vector<T>{n, 1};
}
vector<T> c = sequence<T>(n / 2);
vector<T> a = c;
reverse(a.begin(), a.end());
T f = 1;
for (int i = n / 2 - 1; i >= 0; i--) {
f *= n / 2 - i;
a[i] *= f;
}
vector<T> b(n / 2 + 1);
b[0] = 1;
for (int i = 1; i <= n / 2; i++) {
b[i] = b[i - 1] * (n / 2) / i;
}
vector<T> h = a * b;
h.resize(n / 2 + 1);
reverse(h.begin(), h.end());
f = 1;
for (int i = 1; i <= n / 2; i++) {
f /= i;
h[i] *= f;
}
vector<T> res = c * h;
return res;
}
template <typename T>
class OnlineProduct {
public:
const vector<T> a;
vector<T> b;
vector<T> c;
OnlineProduct(const vector<T>& a_) : a(a_) {}
T add(const T& val) {
int i = (int) b.size();
b.push_back(val);
if ((int) c.size() <= i) {
c.resize(i + 1);
}
c[i] += a[0] * b[i];
int z = 1;
while ((i & (z - 1)) == z - 1 && (int) a.size() > z) {
vector<T> a_mul(a.begin() + z, a.begin() + min(z << 1, (int) a.size()));
vector<T> b_mul(b.end() - z, b.end());
vector<T> c_mul = a_mul * b_mul;
if ((int) c.size() <= i + (int) c_mul.size()) {
c.resize(i + c_mul.size() + 1);
}
for (int j = 0; j < (int) c_mul.size(); j++) {
c[i + 1 + j] += c_mul[j];
}
z <<= 1;
}
return c[i];
}
};
};
using namespace algebra;
typedef vector<mint> polyn;
// Ref - https://www.codechef.com/viewsolution/41909444 Line 827 - 850.
const int N=2e6+5;
array<mint,N+1> fac,inv;
mint nCr(lli n,lli r)
{
if(n<0||r<0||r>n)
return 0;
return fac[n]*inv[r]*inv[n-r];
}
void pre(lli n)
{
fac[0]=1;
for(int i=1;i<=n;++i)
fac[i]=i*fac[i-1];
inv[n]=fac[n].pow(mint(-2).val());
for(int i=n;i>0;--i)
inv[i-1]=i*inv[i];
assert(inv[0]==mint(1));
}
polyn nCr(lli r){
polyn x(r+1),y(r+1);
y[r]=1;
for(int i=0;i<=r;++i)
x[i]=i;
return interpolate(x,y);
}
int main(void) {
ios_base::sync_with_stdio(false);cin.tie(NULL);
// freopen("txt.in", "r", stdin);
// freopen("txt.out", "w", stdout);
// cout<<std::fixed<<std::setprecision(35);
// T=readIntLn(1,1);
// while(T--)
{
pre(N);
lli r,k;
cin>>r>>k;
vi a(k),p(k);
for(auto &x:a) cin>>x;
for(auto &x:p) cin>>x;
//const lli r=readIntLn(1,100000);
//const lli k=readIntLn(1,200000);
//auto a=readVectorInt(k,1,1e9);
//auto p=readVectorInt(k,1,998244352);
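vector<mint> xcord;
// polnCr(x) = (x+1)(x+2)...(x+r) / r!, so polnCr(A_i - r) = C(A_i, r)
auto polnCr=sequence<mint>(r);
for(auto &x:polnCr)
x*=inv[r];
for(auto x:a)
xcord.pb(x-r);
// Multipoint evaluation: ycord[i] = C(A_i, r) mod 998244353
auto ycord=evaluate(polnCr,xcord);
mint ans=1;
// Independent schools: multiply P_i * C(A_i, r) + (1 - P_i)
for(int i=0;i<k;++i){
ans*=p[i]*ycord[i]+1-p[i];
}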
cout<<ans<<endl;
} aryanc403();
readEOF();
return 0;
}
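Both the setter and the tester obtain \displaystyle\binom{A_i}{r} by evaluating Q(x) = (x+1)(x+2)\cdots(x+r) at x = A_i - r and dividing by r!. The sketch below reproduces that pipeline naively (an O(r^2) coefficient build and one Horner evaluation per point instead of NTT and fast multipoint evaluation), so it is only meant for checking small cases. The helper names `rising_poly`, `horner`, `inv_factorial` and `binom_via_poly` are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

using u64 = std::uint64_t;
const u64 MOD = 998244353;

// Fast modular exponentiation: b^e mod MOD.
u64 power(u64 b, u64 e) {
    u64 r = 1;
    b %= MOD;
    while (e > 0) {
        if (e & 1) r = r * b % MOD;
        b = b * b % MOD;
        e >>= 1;
    }
    return r;
}

// Coefficients (lowest degree first) of Q(x) = (x+1)(x+2)...(x+r) mod MOD,
// built by multiplying in one linear factor at a time: O(r^2).
std::vector<u64> rising_poly(u64 r) {
    std::vector<u64> c{1};
    for (u64 i = 1; i <= r; ++i) {
        std::vector<u64> nc(c.size() + 1, 0);
        for (std::size_t k = 0; k < c.size(); ++k) {
            nc[k + 1] = (nc[k + 1] + c[k]) % MOD;      // x * (c[k] x^k)
            nc[k] = (nc[k] + i % MOD * c[k]) % MOD;    // i * (c[k] x^k)
        }
        c = std::move(nc);
    }
    return c;
}

// Horner's rule: evaluates sum of c[k] * x^k mod MOD, for x in [0, MOD).
u64 horner(const std::vector<u64>& c, u64 x) {
    u64 res = 0;
    for (std::size_t k = c.size(); k-- > 0;)
        res = (res * x + c[k]) % MOD;
    return res;
}

// Inverse of r! modulo MOD (requires r < MOD).
u64 inv_factorial(u64 r) {
    u64 f = 1;
    for (u64 i = 2; i <= r; ++i) f = f * (i % MOD) % MOD;
    return power(f, MOD - 2);
}

// C(a, r) = Q(a - r) / r!, with a - r reduced into [0, MOD) first.
u64 binom_via_poly(long long a, u64 r, const std::vector<u64>& q, u64 inv_fact_r) {
    const long long m = static_cast<long long>(MOD);
    long long x = (a - static_cast<long long>(r)) % m;
    if (x < 0) x += m;
    return horner(q, static_cast<u64>(x)) * inv_fact_r % MOD;
}
```

For r = 3 the coefficients are {6, 11, 6, 1}, i.e. x^3 + 6x^2 + 11x + 6, and evaluating at x = 5 - 3 = 2 gives 60, so 60 / 3! = 10 = C(5, 3). The coefficient layout (lowest degree first) matches `beg`, `polnCr` and `poly` in the three solutions, which makes this a convenient cross-check before switching to the NTT-based versions.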
Editorialist's Solution
import java.util.*;
import java.io.*;
class MAT{
//SOLUTION BEGIN
void pre() throws Exception{}
void solve(int TC) throws Exception{
int R = ni(), K = ni();
long[] A = new long[K];
long[] P = new long[K];
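for(int i = 0; i< K; i++)A[i] = ((ni()-R)%mod+mod)%mod; // store A_i - R reduced into [0, mod), so mod - xs[i] in buildProductTree stays non-negative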
for(int i = 0; i< K; i++)P[i] = nl();
PriorityQueue<long[]> pq = new PriorityQueue<>((long[] l1, long[] l2) -> Long.compare(l1.length, l2.length));
for(int i = 1; i<= R; i++)pq.add(new long[]{i, 1});
while(pq.size() > 1)pq.add(mul(pq.poll(), pq.poll()));
long[] poly = pq.poll();
long[] eval = substitute(poly, A);
long ans = 1;
long factR = 1;
for(int i = 1; i<= R; i++)factR = factR*i%mod;
long invR = pow(factR, mod-2);
for(int i = 0; i< K; i++){
long w = (P[i]*eval[i]%mod*invR%mod + 1-P[i]+mod)%mod;
ans = ans*w%mod;
}
pn(ans);
}
long pow(long a, long p){
long o = 1;
while(p > 0){
if(p%2 == 1)o = o*a%mod;
a = a*a%mod;
p>>=1;
}
return o;
}
//uwi
public static int mod = 998244353;
public static int G = 3;
public static long[] substitute(long[] p, long[] xs)
{
return descendProductTree(p, buildProductTree(xs));
}
public static long[][] buildProductTree(long[] xs)
{
int m = Integer.highestOneBit(xs.length)*4;
long[][] ms = new long[m][];
for(int i = 0;i < xs.length;i++){
ms[m/2+i] = new long[]{mod-xs[i], 1};
}
for(int i = m/2-1;i >= 1;i--){
if(ms[2*i] == null){
ms[i] = null;
}else if(ms[2*i+1] == null){
ms[i] = ms[2*i];
}else{
ms[i] = mul(ms[2*i], ms[2*i+1]);
}
}
return ms;
}
public static long[] mul(long[] a, long[] b)
{
return Arrays.copyOf(convoluteSimply(a, b, mod, G), a.length+b.length-1);
}
public static long[] mul(long[] a, long[] b, int lim)
{
return Arrays.copyOf(convoluteSimply(a, b, mod, G), lim);
}
public static long[] descendProductTree(long[] p, long[][] pt)
{
long[] rets = new long[pt[1].length-1];
dfs(p, pt, 1, rets);
return rets;
}
private static void dfs(long[] p, long[][] pt, int cur, long[] rets)
{
if(pt[cur] == null)return;
if(cur >= pt.length/2){
rets[cur-pt.length/2] = p[0];
}else{
// F = q1X+r1
// F = q2Y+r2
if(p.length >= 800){
if(pt[2*cur+1] != null){
long[][] qr0 = div(p, pt[2*cur]);
dfs(qr0[1], pt, cur*2, rets);
long[][] qr1 = div(p, pt[2*cur+1]);
dfs(qr1[1], pt, cur*2+1, rets);
}else if(pt[2*cur] != null){
long[] nex = cur == 1 ? div(p, pt[2*cur])[1] : p;
dfs(nex, pt, cur*2, rets);
}
}else{
if(pt[2*cur+1] != null){
dfs(modnaive(p, pt[2*cur]), pt, cur*2, rets);
dfs(modnaive(p, pt[2*cur+1]), pt, cur*2+1, rets);
}else if(pt[2*cur] != null){
long[] nex = cur == 1 ? modnaive(p, pt[2*cur]) : p;
dfs(nex, pt, cur*2, rets);
}
}
}
}
public static long[][] div(long[] f, long[] g)
{
int n = f.length, m = g.length;
if(n < m)return new long[][]{new long[0], Arrays.copyOf(f, n)};
long[] rf = reverse(f, n-m+1);
long[] rg = reverse(g, n-m+1);
long[] rq = mul(rf, inv(rg), n-m+1);
long[] q = reverse(rq, n-m+1);
long[] r = sub(f, mul(q, g, m-1), m-1);
return new long[][]{q, r};
}
public static long[] reverse(long[] p)
{
long[] ret = new long[p.length];
for(int i = 0;i < p.length;i++){
ret[i] = p[p.length-1-i];
}
return ret;
}
public static long[] reverse(long[] p, int lim)
{
long[] ret = new long[lim];
for(int i = 0;i < lim && i < p.length;i++){
ret[i] = p[p.length-1-i];
}
return ret;
}
public static long[] sub(long[] a, long[] b)
{
long[] c = new long[Math.max(a.length, b.length)];
for(int i = 0;i < a.length;i++)c[i] += a[i];
for(int i = 0;i < b.length;i++)c[i] -= b[i];
for(int i = 0;i < c.length;i++)if(c[i] < 0)c[i] += mod;
return c;
}
public static long[] sub(long[] a, long[] b, int lim)
{
long[] c = new long[lim];
for(int i = 0;i < a.length && i < lim;i++)c[i] += a[i];
for(int i = 0;i < b.length && i < lim;i++)c[i] -= b[i];
for(int i = 0;i < c.length;i++)if(c[i] < 0)c[i] += mod;
return c;
}
// F_{t+1}(x) = -F_t(x)^2*P(x) + 2F_t(x)
// if want p-destructive, comment out flipping p just before returning.
public static long[] inv(long[] p)
{
int n = p.length;
long[] f = {invl(p[0], mod)};
for(int i = 0;i < p.length;i++){
if(p[i] == 0)continue;
p[i] = mod-p[i];
}
for(int i = 1;i < 2*n;i*=2){
long[] f2 = mul(f, f, Math.min(n, 2*i));
long[] f2p = mul(f2, Arrays.copyOf(p, i), Math.min(n, 2*i));
for(int j = 0;j < f.length;j++){
f2p[j] += 2L*f[j];
if(f2p[j] >= mod)f2p[j] -= mod;
if(f2p[j] >= mod)f2p[j] -= mod;
}
f = f2p;
}
for(int i = 0;i < p.length;i++){
if(p[i] == 0)continue;
p[i] = mod-p[i];
}
return f;
}
public static long[] modnaive(long[] a, long[] b)
{
int n = a.length, m = b.length;
if(n-m+1 <= 0)return a;
long[] r = Arrays.copyOf(a, n);
long ib = invl(b[m-1], mod);
for(int i = n-1;i >= m-1;i--){
long x = ib * r[i] % mod;
for(int j = m-1;j >= 0;j--){
r[i+j-(m-1)] -= b[j]*x;
r[i+j-(m-1)] %= mod;
if(r[i+j-(m-1)] < 0)r[i+j-(m-1)] += mod;
// r[i+j-(m-1)] = modh(r[i+j-(m-1)]+(long)mod*mod - b[j]*x, MM, HH, mod);
}
}
return Arrays.copyOf(r, m-1);
}
public static final int[] NTTPrimes = {1053818881, 1051721729, 1045430273, 1012924417, 1007681537, 1004535809, 998244353, 985661441, 976224257, 975175681};
public static final int[] NTTPrimitiveRoots = {7, 6, 3, 5, 3, 3, 3, 3, 3, 17};
// public static final int[] NTTPrimes = {1012924417, 1004535809, 998244353, 985661441, 975175681, 962592769, 950009857, 943718401, 935329793, 924844033};
// public static final int[] NTTPrimitiveRoots = {5, 3, 3, 3, 17, 7, 7, 7, 3, 5};
public static long[] convoluteSimply(long[] a, long[] b, int P, int g)
{
int m = Math.max(2, Integer.highestOneBit(Math.max(a.length, b.length)-1)<<2);
long[] fa = nttmb(a, m, false, P, g);
long[] fb = a == b ? fa : nttmb(b, m, false, P, g);
for(int i = 0;i < m;i++){
fa[i] = fa[i]*fb[i]%P;
}
return nttmb(fa, m, true, P, g);
}
public static long[] convolute(long[] a, long[] b)
{
int USE = 2;
int m = Math.max(2, Integer.highestOneBit(Math.max(a.length, b.length)-1)<<2);
long[][] fs = new long[USE][];
for(int k = 0;k < USE;k++){
int P = NTTPrimes[k], g = NTTPrimitiveRoots[k];
long[] fa = nttmb(a, m, false, P, g);
long[] fb = a == b ? fa : nttmb(b, m, false, P, g);
for(int i = 0;i < m;i++){
fa[i] = fa[i]*fb[i]%P;
}
fs[k] = nttmb(fa, m, true, P, g);
}
int[] mods = Arrays.copyOf(NTTPrimes, USE);
long[] gammas = garnerPrepare(mods);
int[] buf = new int[USE];
for(int i = 0;i < fs[0].length;i++){
for(int j = 0;j < USE;j++)buf[j] = (int)fs[j][i];
long[] res = garnerBatch(buf, mods, gammas);
long ret = 0;
for(int j = res.length-1;j >= 0;j--)ret = ret * mods[j] + res[j];
fs[0][i] = ret;
}
return fs[0];
}
public static long[] convolute(long[] a, long[] b, int USE, int mod)
{
int m = Math.max(2, Integer.highestOneBit(Math.max(a.length, b.length)-1)<<2);
long[][] fs = new long[USE][];
for(int k = 0;k < USE;k++){
int P = NTTPrimes[k], g = NTTPrimitiveRoots[k];
long[] fa = nttmb(a, m, false, P, g);
long[] fb = a == b ? fa : nttmb(b, m, false, P, g);
for(int i = 0;i < m;i++){
fa[i] = fa[i]*fb[i]%P;
}
fs[k] = nttmb(fa, m, true, P, g);
}
int[] mods = Arrays.copyOf(NTTPrimes, USE);
long[] gammas = garnerPrepare(mods);
int[] buf = new int[USE];
for(int i = 0;i < fs[0].length;i++){
for(int j = 0;j < USE;j++)buf[j] = (int)fs[j][i];
long[] res = garnerBatch(buf, mods, gammas);
long ret = 0;
for(int j = res.length-1;j >= 0;j--)ret = (ret * mods[j] + res[j]) % mod;
fs[0][i] = ret;
}
return fs[0];
}
// static int[] wws = new int[270000]; // outer faster
// Modified Montgomery + Barrett
private static long[] nttmb(long[] src, int n, boolean inverse, int P, int g)
{
long[] dst = Arrays.copyOf(src, n);
int h = Integer.numberOfTrailingZeros(n);
long K = Integer.highestOneBit(P)<<1;
int H = Long.numberOfTrailingZeros(K)*2;
long M = K*K/P;
int[] wws = new int[1<<h-1];
long dw = inverse ? pow(g, P-1-(P-1)/n, P) : pow(g, (P-1)/n, P);
long w = (1L<<32)%P;
for(int k = 0;k < 1<<h-1;k++){
wws[k] = (int)w;
w = modh(w*dw, M, H, P);
}
long J = invl(P, 1L<<32);
for(int i = 0;i < h;i++){
for(int j = 0;j < 1<<i;j++){
for(int k = 0, s = j<<h-i, t = s|1<<h-i-1;k < 1<<h-i-1;k++,s++,t++){
long u = (dst[s] - dst[t] + 2*P)*wws[k];
dst[s] += dst[t];
if(dst[s] >= 2*P)dst[s] -= 2*P;
// long Q = (u&(1L<<32)-1)*J&(1L<<32)-1;
long Q = (u<<32)*J>>>32;
dst[t] = (u>>>32)-(Q*P>>>32)+P;
}
}
if(i < h-1){
for(int k = 0;k < 1<<h-i-2;k++)wws[k] = wws[k*2];
}
}
for(int i = 0;i < n;i++){
if(dst[i] >= P)dst[i] -= P;
}
for(int i = 0;i < n;i++){
int rev = Integer.reverse(i)>>>-h;
if(i < rev){
long d = dst[i]; dst[i] = dst[rev]; dst[rev] = d;
}
}
if(inverse){
long in = invl(n, P);
for(int i = 0;i < n;i++)dst[i] = modh(dst[i]*in, M, H, P);
}
return dst;
}
// Modified Shoup + Barrett
private static long[] nttsb(long[] src, int n, boolean inverse, int P, int g)
{
long[] dst = Arrays.copyOf(src, n);
int h = Integer.numberOfTrailingZeros(n);
long K = Integer.highestOneBit(P)<<1;
int H = Long.numberOfTrailingZeros(K)*2;
long M = K*K/P;
long dw = inverse ? pow(g, P-1-(P-1)/n, P) : pow(g, (P-1)/n, P);
long[] wws = new long[1<<h-1];
long[] ws = new long[1<<h-1];
long w = 1;
for(int k = 0;k < 1<<h-1;k++){
wws[k] = (w<<32)/P;
ws[k] = w;
w = modh(w*dw, M, H, P);
}
for(int i = 0;i < h;i++){
for(int j = 0;j < 1<<i;j++){
for(int k = 0, s = j<<h-i, t = s|1<<h-i-1;k < 1<<h-i-1;k++,s++,t++){
long ndsts = dst[s] + dst[t];
if(ndsts >= 2*P)ndsts -= 2*P;
long T = dst[s] - dst[t] + 2*P;
long Q = wws[k]*T>>>32;
dst[s] = ndsts;
dst[t] = ws[k]*T-Q*P&(1L<<32)-1;
}
}
// dw = dw * dw % P;
if(i < h-1){
for(int k = 0;k < 1<<h-i-2;k++){
wws[k] = wws[k*2];
ws[k] = ws[k*2];
}
}
}
for(int i = 0;i < n;i++){
if(dst[i] >= P)dst[i] -= P;
}
for(int i = 0;i < n;i++){
int rev = Integer.reverse(i)>>>-h;
if(i < rev){
long d = dst[i]; dst[i] = dst[rev]; dst[rev] = d;
}
}
if(inverse){
long in = invl(n, P);
for(int i = 0;i < n;i++){
dst[i] = modh(dst[i] * in, M, H, P);
}
}
return dst;
}
static final long mask = (1L<<31)-1;
public static long modh(long a, long M, int h, int mod)
{
long r = a-((M*(a&mask)>>>31)+M*(a>>>31)>>>h-31)*mod;
return r < mod ? r : r-mod;
}
private static long[] garnerPrepare(int[] m)
{
int n = m.length;
assert n == m.length;
if(n == 0)return new long[0];
long[] gamma = new long[n];
for(int k = 1;k < n;k++){
long prod = 1;
for(int i = 0;i < k;i++){
prod = prod * m[i] % m[k];
}
gamma[k] = invl(prod, m[k]);
}
return gamma;
}
private static long[] garnerBatch(int[] u, int[] m, long[] gamma)
{
int n = u.length;
assert n == m.length;
long[] v = new long[n];
v[0] = u[0];
for(int k = 1;k < n;k++){
long temp = v[k-1];
for(int j = k-2;j >= 0;j--){
temp = (temp * m[j] + v[j]) % m[k];
}
v[k] = (u[k] - temp) * gamma[k] % m[k];
if(v[k] < 0)v[k] += m[k];
}
return v;
}
private static long pow(long a, long n, long mod) {
// a %= mod;
long ret = 1;
int x = 63 - Long.numberOfLeadingZeros(n);
for (; x >= 0; x--) {
ret = ret * ret % mod;
if (n << 63 - x < 0)
ret = ret * a % mod;
}
return ret;
}
private static long invl(long a, long mod) {
long b = mod;
long p = 1, q = 0;
while (b > 0) {
long c = a / b;
long d;
d = a;
a = b;
b = d % b;
d = p;
p = q;
q = d - c * q;
}
return p < 0 ? p + mod : p;
}
//SOLUTION END
void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
static boolean multipleTC = false;
FastReader in;PrintWriter out;
void run() throws Exception{
in = new FastReader();
out = new PrintWriter(System.out);
//Solution Credits: Taranpreet Singh
int T = (multipleTC)?ni():1;
pre();for(int t = 1; t<= T; t++)solve(t);
out.flush();
out.close();
}
public static void main(String[] args) throws Exception{
new MAT().run();
}
int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
void p(Object o){out.print(o);}
void pn(Object o){out.println(o);}
void pni(Object o){out.println(o);out.flush();}
String n()throws Exception{return in.next();}
String nln()throws Exception{return in.nextLine();}
int ni()throws Exception{return Integer.parseInt(in.next());}
long nl()throws Exception{return Long.parseLong(in.next());}
double nd()throws Exception{return Double.parseDouble(in.next());}
class FastReader{
BufferedReader br;
StringTokenizer st;
public FastReader(){
br = new BufferedReader(new InputStreamReader(System.in));
}
public FastReader(String s) throws Exception{
br = new BufferedReader(new FileReader(s));
}
String next() throws Exception{
while (st == null || !st.hasMoreElements()){
try{
st = new StringTokenizer(br.readLine());
}catch (IOException e){
throw new Exception(e.toString());
}
}
return st.nextToken();
}
String nextLine() throws Exception{
String str = "";
try{
str = br.readLine();
}catch (IOException e){
throw new Exception(e.toString());
}
return str;
}
}
}
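In the editorialist's code, `div` divides f by g by reversing both polynomials and multiplying by `inv(rg)`, and `inv` runs the Newton step written in its comment, F_{t+1}(x) = -F_t(x)^2 P(x) + 2F_t(x). The precision doubles each round: if F is correct modulo x^k, then 1 - PF is divisible by x^k, and 1 - P·F(2 - PF) = (1 - PF)^2 is divisible by x^{2k}. Below is a minimal C++ sketch of that iteration; the helper names are hypothetical, and schoolbook multiplication replaces the NTT-based `mul`, so each round costs O(n^2) instead of O(n log n).

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

const int64_t MOD = 998244353;

int64_t power_mod(int64_t base, int64_t exp) {
    int64_t result = 1;
    base %= MOD;
    if (base < 0) base += MOD;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return result;
}

// First n coefficients of a(x) * b(x), schoolbook O(n^2) (the Java code uses NTT here).
std::vector<int64_t> mul_trunc(const std::vector<int64_t>& a, const std::vector<int64_t>& b,
                               std::size_t n) {
    std::vector<int64_t> c(n, 0);
    for (std::size_t i = 0; i < a.size() && i < n; ++i)
        for (std::size_t j = 0; j < b.size() && i + j < n; ++j)
            c[i + j] = (c[i + j] + a[i] * b[j]) % MOD;
    return c;
}

// Inverse of p(x) modulo x^n by Newton's iteration F <- F * (2 - P*F) = 2F - F^2 P.
// Coefficients of p are assumed to lie in [0, MOD), and p[0] must be nonzero.
std::vector<int64_t> series_inverse(const std::vector<int64_t>& p, std::size_t n) {
    assert(!p.empty() && p[0] % MOD != 0);
    std::vector<int64_t> f{power_mod(p[0], MOD - 2)};  // correct modulo x^1
    for (std::size_t len = 1; len < n;) {
        len = std::min(2 * len, n);
        std::vector<int64_t> t = mul_trunc(p, f, len);  // P*F mod x^len
        for (auto& v : t) v = (MOD - v) % MOD;          // -P*F
        t[0] = (t[0] + 2) % MOD;                        // 2 - P*F
        f = mul_trunc(f, t, len);                       // F*(2 - P*F) mod x^len
    }
    f.resize(n);
    return f;
}
```

For instance, `series_inverse({1, MOD - 1}, 5)` inverts 1 - x and returns `{1, 1, 1, 1, 1}`, the first terms of the geometric series.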
Feel free to share your approach. Suggestions are welcome, as always.