PROBLEM LINK:
Contest Division 1
Contest Division 2
Contest Division 3
Practice
Setter: Srikkanth R and Daanish Mahajan
Tester: Aryan
Editorialist: Taranpreet Singh
DIFFICULTY
Medium-Hard
PREREQUISITES
Berlekamp Massey Algorithm, Linear Recurrences, Dynamic Programming
PROBLEM
Let f(n) be the number of ways to partition the array [1, 2, 3, \ldots, n] into contiguous subarrays such that every pair of adjacent subarrays in the partition have sums of different parity.
Let S_0(n) = f(n) and S_{k+1}(n) = S_k(1) + S_k(2) + S_k(3) + \cdots + S_k(n) for k \geq 0.
Given n and k, find S_k(n) \bmod 998\ 244\ 353.
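For small inputs it helps to have a reference answer. The sketch below is a brute force written directly from the definitions above. It is not the setter's or tester's code, and `brute_f` / `brute_S` are hypothetical helper names. It enumerates all 2^{n-1} cut patterns, so it is only usable for n up to about 20, and it does not reduce modulo 998244353.

```cpp
#include <cassert>
#include <vector>

// Brute-force f(n): try every way to cut [1..n] into contiguous blocks
// (bit i-1 of mask set => cut after element i) and count the cuttings where
// adjacent blocks have sums of different parity.
long long brute_f(int n) {
    long long count = 0;
    for (int mask = 0; mask < (1 << (n - 1)); ++mask) {
        bool ok = true;
        int prev_parity = -1;   // parity of the previous block, -1 = none yet
        long long block_sum = 0;
        for (int i = 1; i <= n; ++i) {
            block_sum += i;
            bool cut_here = (i == n) || ((mask >> (i - 1)) & 1);
            if (cut_here) {
                int parity = static_cast<int>(block_sum % 2);
                if (parity == prev_parity) { ok = false; break; }
                prev_parity = parity;
                block_sum = 0;
            }
        }
        if (ok) ++count;
    }
    return count;
}

// S_k(n) straight from the definition: S_0 = f, S_{k+1}(n) = S_k(1) + ... + S_k(n).
long long brute_S(int k, int n) {
    std::vector<long long> s(n + 1, 0);
    for (int i = 1; i <= n; ++i) s[i] = brute_f(i);
    for (int level = 0; level < k; ++level)
        for (int i = 1; i <= n; ++i) s[i] += s[i - 1];   // in-place prefix sums
    return s[n];
}
```

For example, brute_f(3) is 2: only [1,2,3] and [1],[2],[3] are valid.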
QUICK EXPLANATION
We compute the first M terms of f, use them to compute the first M terms of S_k, find the characteristic polynomial of the linear recurrence satisfied by S_k, and then use that characteristic polynomial to compute S_k(n).
EXPLANATION
As mentioned in the quick explanation, we first compute M terms of f, then the first M terms of S_k, then obtain the recurrence for S_k (by hand or with the Berlekamp-Massey algorithm), and finally use that recurrence to compute the N-th term.
Computing f(i)
Let’s figure out a way to compute f(i) for all 1 \leq i \leq M, where M is the order of the recurrence relation.
To compute f(i), we write f(i) = f_0(i) + f_1(i), where f_b(x) denotes the number of partitions of the first x numbers into contiguous subarrays such that every pair of adjacent subarrays has sums of different parity and the last subarray has sum of parity b.
We can write \displaystyle f_b(i) = \sum_{j \in I_{b,i}} f_{1-b}(j-1), where the set I_{b,i} contains all j such that 1 \leq j \leq i and \displaystyle\sum_{k = j}^i k has parity b. Here j is the start of the last subarray. For j = 1 the last subarray is the whole prefix, so we use the base case f_0(0) = f_1(0) = 1 (the empty prefix may be followed by a subarray of either parity).
The above recurrence lets us compute f(i) in O(M^2) time by iterating over all j for each i.
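The O(M^2) DP can be sketched as follows. This is an illustration, not the setter's code: `f_quadratic` is a hypothetical name, values are reduced modulo 998244353, and the base case f_0(0) = f_1(0) = 1 is the convention that makes the j = 1 term count the single-subarray partition.

```cpp
#include <cassert>
#include <vector>

const long long MOD = 998244353;

// O(M^2) DP: fb[b][i] = f_b(i) mod MOD, where the last subarray [j..i] has
// sum parity b and the prefix before it ends with parity 1-b.
std::vector<long long> f_quadratic(int M) {
    std::vector<std::vector<long long>> fb(2, std::vector<long long>(M + 1, 0));
    fb[0][0] = fb[1][0] = 1;   // base case for the empty prefix (not part of f)
    for (int i = 1; i <= M; ++i) {
        for (int j = 1; j <= i; ++j) {
            // parity of j + (j+1) + ... + i = (i - j + 1)(i + j) / 2
            long long len = i - j + 1, ends = i + j;
            int b = static_cast<int>((len * ends / 2) % 2);
            fb[b][i] = (fb[b][i] + fb[1 - b][j - 1]) % MOD;
        }
    }
    std::vector<long long> f(M + 1, 0);   // f[0] is left as 0
    for (int i = 1; i <= M; ++i) f[i] = (fb[0][i] + fb[1][i]) % MOD;
    return f;
}
```

Calling f_quadratic(9) gives f(1..9) = 1, 2, 2, 3, 6, 10, 12, 21, 29.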
We can do better. If I_{b,i} contains x and x-4 \gt 0, then I_{b,i} contains x-4 as well, since (x-4)+(x-3)+(x-2)+(x-1) = 4x-10 is even, so extending the last subarray by four elements keeps the parity of its sum.
Hence, define T_b(t) = T_b(t-4) + f_b(t) for t \geq 0, with T_b(t) = 0 for t < 0.
Then we can write \displaystyle f_b(i) = \sum_{j = \max(1, i-3),\ j \in I_{b,i}}^{i} T_{1-b}(j-1).
This way, we only have to try 4 values of j for each i, so the recurrence is computed in O(M) time.
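A sketch of the O(M) version, again not taken from the setter's code (`f_linear` is a hypothetical name). T[b][t] stores the strided sum f_b(t) + f_b(t-4) + f_b(t-8) + ..., so each f_b(i) only needs the four start positions j = i-3..i, one per residue modulo 4.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

const long long MOD = 998244353;

// O(M) DP using T_b(t) = T_b(t-4) + f_b(t), with the same base case
// f_0(0) = f_1(0) = 1 as the O(M^2) DP.
std::vector<long long> f_linear(int M) {
    std::vector<std::vector<long long>> fb(2, std::vector<long long>(M + 1, 0));
    std::vector<std::vector<long long>> T(2, std::vector<long long>(M + 1, 0));
    fb[0][0] = fb[1][0] = 1;
    T[0][0] = T[1][0] = 1;
    for (int i = 1; i <= M; ++i) {
        for (int j = std::max(1, i - 3); j <= i; ++j) {
            long long len = i - j + 1, ends = i + j;
            int b = static_cast<int>((len * ends / 2) % 2);   // parity of sum j..i
            // T[1-b][j-1] covers every start j' <= j with j' = j (mod 4)
            fb[b][i] = (fb[b][i] + T[1 - b][j - 1]) % MOD;
        }
        for (int b = 0; b < 2; ++b)
            T[b][i] = (fb[b][i] + (i >= 4 ? T[b][i - 4] : 0)) % MOD;
    }
    std::vector<long long> f(M + 1, 0);
    for (int i = 1; i <= M; ++i) f[i] = (fb[0][i] + fb[1][i]) % MOD;
    return f;
}
```

The values agree with the O(M^2) DP, and for i \geq 9 they also satisfy the order-8 recurrence stated next.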
Moreover, this shows that the recurrence relation for f can be written as a sum of at most 8 terms. This fact will be useful later.
Working out the coefficients, the recurrence relation is f(i) = f(i-2) + 2f(i-4) + 3f(i-6) - f(i-8) for i \geq 9, with the first 8 terms of f being 1, 2, 2, 3, 6, 10, 12, 21 for 1 \leq i \leq 8.
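We can write \displaystyle P(x) = x + 2x^3 + 3x^5 - x^7 as the polynomial representing the recurrence relation, so that 1 - xP(x) = 1 - x^2 - 2x^4 - 3x^6 + x^8.
Let’s define A_i = f(i+1) (0-based indexing), so \displaystyle A(x) = \sum_{i \geq 0} A_i x^i.
Since A_n = A_{n-2} + 2A_{n-4} + 3A_{n-6} - A_{n-8} for n \geq 8, the product (1 - xP(x))A(x) has no terms of degree 8 or more. Hence \displaystyle A(x) = \frac{H(x)}{1 - xP(x)}, where H(x) = (1 - xP(x))A(x) \bmod x^8 is a polynomial of degree at most 7 determined by A_0, \ldots, A_7.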
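As a sketch with hypothetical helper names, the numerator H(x) = (1 - xP(x))A(x) \bmod x^8 can be computed from the first 8 terms, and expanding H(x)/(1 - xP(x)) as a power series regenerates every later term. The resulting H(x) = 1 + 2x + x^2 + x^3 + 2x^4 + 3x^5 - x^6 - x^7 matches, shifted by one power of x, the vector P = {0, 1, 2, 1, 1, 2, 3, -1, -1} in the setter's solve().

```cpp
#include <cassert>
#include <vector>

const long long MOD = 998244353;

// D(x) = 1 - x P(x) = 1 - x^2 - 2x^4 - 3x^6 + x^8, coefficients mod MOD.
std::vector<long long> denominator_D() {
    return {1, 0, MOD - 1, 0, MOD - 2, 0, MOD - 3, 0, 1};
}

// H(x) = (D(x) * A(x)) mod x^8, from the first 8 terms A_0..A_7 (reduced mod MOD).
std::vector<long long> numerator_H(const std::vector<long long>& A) {
    std::vector<long long> D = denominator_D(), H(8, 0);
    for (int n = 0; n < 8; ++n)
        for (int j = 0; j <= n; ++j) H[n] = (H[n] + D[j] * A[n - j]) % MOD;
    return H;
}

// Power-series expansion of H(x) / D(x) up to x^{len-1}.
// Since D[0] = 1: A_n = H_n - sum_{j=1..8} D_j * A_{n-j}.
std::vector<long long> expand_series(const std::vector<long long>& H, int len) {
    std::vector<long long> D = denominator_D(), A(len, 0);
    for (int n = 0; n < len; ++n) {
        long long v = n < static_cast<int>(H.size()) ? H[n] : 0;
        for (int j = 1; j <= 8 && j <= n; ++j)
            v = (v - D[j] * A[n - j] % MOD + MOD) % MOD;
        A[n] = v;
    }
    return A;
}
```

With A = {1, 2, 2, 3, 6, 10, 12, 21}, expanding the series to 10 terms also yields A_8 = f(9) = 29 and A_9 = f(10) = 48.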
Computing S_k(i)
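This part is easier. Just compute S_k(i) = S_k(i-1) + S_{k-1}(i), which gives the first M terms of S_k in O(K \cdot M) time.
Each prefix-sum step divides the generating function by (1-x), so (with the same 0-based shift) the generating function of S_k is A(x)/(1-x)^k. Hence the recurrence relation for S_k is represented by the denominator (1-x)^k (1 - xP(x)), a polynomial of degree k+8.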
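The prefix-sum step can be sketched as follows, assuming the first 8 values of f and the order-8 recurrence quoted earlier. `first_f` and `first_S` are hypothetical names, and index 0 of each returned vector holds the value at n = 1.

```cpp
#include <cassert>
#include <vector>

const long long MOD = 998244353;

// First M values of f via f(i) = f(i-2) + 2f(i-4) + 3f(i-6) - f(i-8).
std::vector<long long> first_f(int M) {
    std::vector<long long> f = {1, 2, 2, 3, 6, 10, 12, 21};
    for (int i = 8; i < M; ++i)
        f.push_back(((f[i - 2] + 2 * f[i - 4] + 3 * f[i - 6] - f[i - 8]) % MOD + MOD) % MOD);
    f.resize(M);
    return f;
}

// First M values of S_k: k rounds of in-place prefix sums, each round applying
// S_k(i) = S_k(i-1) + S_{k-1}(i). Total O(k * M) time.
std::vector<long long> first_S(int k, int M) {
    std::vector<long long> s = first_f(M);
    for (int level = 1; level <= k; ++level)
        for (int i = 1; i < M; ++i) s[i] = (s[i] + s[i - 1]) % MOD;
    return s;
}
```

For instance, first_S(1, 8) is 1, 3, 5, 8, 14, 24, 36, 57.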
Computing the characteristic Polynomial
There are two ways to do this.
Method 1
The first is to derive the characteristic polynomial for this specific problem by hand, as the setter did and as we did above.
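We found that the recurrence relation for S_k is represented by Q(x) = (1-x)^k (1 - xP(x)), the denominator of the generating function of S_k.
Then the characteristic polynomial of the recurrence relation is C(x) = x^M Q(1/x), where M = K+8 is the order of the recurrence relation. Equivalently, the coefficients of C(x) are those of Q(x) in reverse order.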
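A minimal sketch of Method 1, assuming the denominator derived above: it multiplies 1 - x^2 - 2x^4 - 3x^6 + x^8 by (1-x) K times and reverses the result to get C(x). The setter's solution builds the same product from binomial coefficients with NTT; the O(K^2) loop here is only meant for checking, and the function names are hypothetical.

```cpp
#include <cassert>
#include <vector>

const long long MOD = 998244353;

// Q(x) = (1 - x)^k * (1 - x^2 - 2x^4 - 3x^6 + x^8), coefficients mod MOD,
// built by multiplying by (1 - x) k times. Returns k + 9 coefficients.
std::vector<long long> denominator_Q(int k) {
    std::vector<long long> q = {1, 0, MOD - 1, 0, MOD - 2, 0, MOD - 3, 0, 1};
    for (int t = 0; t < k; ++t) {
        q.push_back(0);
        // new q[i] = q[i] - q[i-1]; go from high to low so q[i-1] is still old
        for (int i = static_cast<int>(q.size()) - 1; i >= 1; --i)
            q[i] = (q[i] - q[i - 1] + MOD) % MOD;
    }
    return q;
}

// C(x) = x^M Q(1/x) with M = k + 8: the coefficients of Q in reverse order.
// C is monic because Q(0) = 1.
std::vector<long long> characteristic_C(int k) {
    std::vector<long long> q = denominator_Q(k);
    return std::vector<long long>(q.rbegin(), q.rend());
}
```

In the form a_i = c_0 a_{i-1} + \ldots + c_{M-1} a_{i-M}, the recurrence coefficients are c_j = -Q_{j+1}. For K = 1, Q(x) = 1 - x - x^2 + x^3 - 2x^4 + 2x^5 - 3x^6 + 3x^7 + x^8 - x^9.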
Method 2
The second is to apply the Berlekamp-Massey algorithm, which, given the first 2M terms of a sequence satisfying a linear recurrence of order at most M, finds the shortest linear recurrence and hence the characteristic polynomial.
Since f satisfies a recurrence with 8 terms and each of the K prefix-sum steps adds one more term, S_K satisfies a recurrence relation with at most K+8 terms.
So we apply the Berlekamp-Massey algorithm to compute the shortest linear recurrence for S_k. Since that recurrence has at most K+8 terms, we only need to feed the algorithm 2(K+8) terms.
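A compact O(n^2) Berlekamp-Massey over integers modulo 998244353 is sketched below. It is modelled on the widely used KACTL notebook formulation rather than on the blog linked below, and `berlekamp_massey` / `power_mod` are hypothetical names. Input values must already be reduced to [0, MOD). It returns c such that s_i = c_0 s_{i-1} + c_1 s_{i-2} + \ldots + c_{L-1} s_{i-L} for all i \geq L.

```cpp
#include <cassert>
#include <vector>

const long long MOD = 998244353;

long long power_mod(long long b, long long e) {
    long long r = 1;
    b %= MOD;
    if (b < 0) b += MOD;
    while (e) {
        if (e & 1) r = r * b % MOD;
        b = b * b % MOD;
        e >>= 1;
    }
    return r;
}

// Shortest linear recurrence of s (values in [0, MOD)), O(n^2).
std::vector<long long> berlekamp_massey(const std::vector<long long>& s) {
    int n = static_cast<int>(s.size()), L = 0, m = 0;
    std::vector<long long> C(n, 0), B(n, 0), T;   // C = current connection poly
    C[0] = B[0] = 1;
    long long b = 1;                              // discrepancy at last length change
    for (int i = 0; i < n; ++i) {
        ++m;
        // discrepancy d = s[i] + sum_{j=1..L} C[j] * s[i-j]
        long long d = s[i] % MOD;
        for (int j = 1; j <= L; ++j) d = (d + C[j] * s[i - j]) % MOD;
        if (d == 0) continue;
        T = C;
        long long coef = d * power_mod(b, MOD - 2) % MOD;
        for (int j = m; j < n; ++j) C[j] = (C[j] - coef * B[j - m]) % MOD;
        if (2 * L > i) continue;
        L = i + 1 - L;
        B = T;
        b = d;
        m = 0;
    }
    C.resize(L + 1);
    C.erase(C.begin());
    for (auto& x : C) x = (MOD - x) % MOD;   // C(x) = 1 - c_0 x - ..., flip signs
    return C;
}
```

For this problem, call it on the first 2(K+8) values of S_k; applied to the first 16 values of f it returns a recurrence of at most 8 terms that reproduces all later values.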
The details on Berlekamp Massey Algorithm can be found here.
Computing S_k(N)
The above-mentioned blog also explains how to compute the N-th term; here it is in brief.
We now have the characteristic polynomial C(x) of the recurrence relation, a polynomial of degree M = K+8, and the initial M values of S_k. From here on, let A_i = S_k(i+1) (shifting from 1-based to 0-based indexing), so we want to compute A_{N-1}.
A_{N-1} is the coefficient of x^{N-1} in A(x), and it can be written as a linear combination of A_0, A_1, \ldots, A_{M-1}.
Specifically, we can obtain constants r_0, r_1, \ldots, r_{M-1} such that \displaystyle A_{N-1} = \sum_{i = 0}^{M-1} r_i A_i.
Let’s denote \displaystyle G(B(x)) = \sum_{i} B_i A_i, where B_i is the coefficient of x^i in B(x).
We aim to compute G(x^{N-1}). G is linear, i.e. G(U(x) \pm V(x)) = G(U(x)) \pm G(V(x)), and G(x^j C(x)) = 0 for every j \geq 0 because A satisfies the recurrence; hence G(D(x) C(x)) = 0 for every polynomial D(x). This allows us to write G(x^{N-1}) = G(x^{N-1} - D(x) C(x)). Choosing D(x) as the quotient of x^{N-1} divided by C(x), we get G(x^{N-1}) = G(x^{N-1} \bmod C(x)), so r_i is the coefficient of x^i in x^{N-1} \bmod C(x).
R(x) = x^{N-1} \bmod C(x) can be computed by binary exponentiation, reducing modulo C(x) after every multiplication so that the degree always stays below M (each product has degree below 2M).
Polynomial division with remainder can be performed as explained here, which completes the solution.
Bonus
The Cayley-Hamilton theorem can also be used to solve this problem. Here’s an excellent tutorial on the theorem.
The bonus is to solve the above problem using the Cayley-Hamilton theorem.
TIME COMPLEXITY
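Computing the first M terms of S_k takes O(K \cdot M) time, where M = K+8. Berlekamp-Massey takes O(M^2). Computing x^{N-1} \bmod C(x) by binary exponentiation takes O(M^2 \log N) with naive polynomial multiplication, or O(M \log M \log N) with FFT (NTT) based multiplication and division, which is needed to pass the last subtask.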
SOLUTIONS
Setter's Solution
// codechef RNG (Random Number Generator)
// BOJ 13725
#include <bits/stdc++.h>
#define x first
#define y second
#define all(v) v.begin(), v.end()
#define compress(v) sort(all(v)), v.erase(unique(all(v)), v.end())
#define IDX(v, x) (lower_bound(all(v), x) - v.begin())
using namespace std;
using uint = unsigned;
using ll = long long;
using ull = unsigned long long;
template<int M>
struct MINT{
int v;
MINT() : v(0) {}
MINT(ll val){
v = (-M <= val && val < M) ? val : val % M;
if(v < 0) v += M;
}
friend istream& operator >> (istream &is, MINT &a) { ll t; is >> t; a = MINT(t); return is; }
friend ostream& operator << (ostream &os, const MINT &a) { return os << a.v; }
friend bool operator == (const MINT &a, const MINT &b) { return a.v == b.v; }
friend bool operator != (const MINT &a, const MINT &b) { return a.v != b.v; }
friend MINT pw(MINT a, ll b){
MINT ret= 1;
while(b){
if(b & 1) ret *= a;
b >>= 1; a *= a;
}
return ret;
}
friend MINT inv(const MINT a) { return pw(a, M-2); }
MINT operator - () const { return MINT(-v); }
MINT& operator += (const MINT m) { if((v += m.v) >= M) v -= M; return *this; }
MINT& operator -= (const MINT m) { if((v -= m.v) < 0) v += M; return *this; }
MINT& operator *= (const MINT m) { v = (ll)v*m.v%M; return *this; }
MINT& operator /= (const MINT m) { *this *= inv(m); return *this; }
friend MINT operator + (MINT a, MINT b) { a += b; return a; }
friend MINT operator - (MINT a, MINT b) { a -= b; return a; }
friend MINT operator * (MINT a, MINT b) { a *= b; return a; }
friend MINT operator / (MINT a, MINT b) { a /= b; return a; }
operator int32_t() const { return v; }
operator int64_t() const { return v; }
};
namespace fft{
template<int W, int M>
static void NTT(vector<MINT<M>> &f, bool inv_fft = false){
using T = MINT<M>;
int N = f.size();
vector<T> root(N >> 1);
for(int i=1, j=0; i<N; i++){
int bit = N >> 1;
while(j >= bit) j -= bit, bit >>= 1;
j += bit;
if(i < j) swap(f[i], f[j]);
}
T ang = pw(T(W), (M-1)/N); if(inv_fft) ang = inv(ang);
root[0] = 1; for(int i=1; i<N>>1; i++) root[i] = root[i-1] * ang;
for(int i=2; i<=N; i<<=1){
int step = N / i;
for(int j=0; j<N; j+=i){
for(int k=0; k<i/2; k++){
T u = f[j+k], v = f[j+k+(i>>1)] * root[k*step];
f[j+k] = u + v;
f[j+k+(i>>1)] = u - v;
}
}
}
if(inv_fft){
T rev = inv(T(N));
for(int i=0; i<N; i++) f[i] *= rev;
}
}
template<int W, int M>
vector<MINT<M>> multiply_ntt(vector<MINT<M>> a, vector<MINT<M>> b){
int N = 2; while(N < (int)a.size() + (int)b.size()) N <<= 1;
a.resize(N); b.resize(N);
NTT<W, M>(a); NTT<W, M>(b);
for(int i=0; i<N; i++) a[i] *= b[i];
NTT<W, M>(a, true);
return a;
}
}
template<int W, int M>
struct PolyMod{
using T = MINT<M>;
vector<T> a;
// constructor
PolyMod(){}
PolyMod(T a0) : a(1, a0) { normalize(); }
PolyMod(const vector<T> a) : a(a) { normalize(); }
// method from vector<T>
int size() const { return a.size(); }
int deg() const { return a.size() - 1; }
void normalize(){ while(a.size() && a.back() == T(0)) a.pop_back(); }
T operator [] (int idx) const { return a[idx]; }
typename vector<T>::const_iterator begin() const { return a.begin(); }
typename vector<T>::const_iterator end() const { return a.end(); }
void push_back(const T val) { a.push_back(val); }
void pop_back() { a.pop_back(); }
friend ostream& operator << (ostream &os, const PolyMod &a) { for (auto x : a.a) os << x << " "; return os;}
// basic manipulation
PolyMod reversed() const {
vector<T> b = a;
reverse(b.begin(), b.end());
return b;
}
PolyMod trim(int n) const {
return vector<T>(a.begin(), a.begin() + min(n, size()));
}
PolyMod inv(int n){
PolyMod q(T(1) / a[0]);
for(int i=1; i<n; i<<=1){
PolyMod p = PolyMod(2) - q * trim(i * 2);
q = (p * q).trim(i * 2);
}
return q.trim(n);
}
// operation with scala value
PolyMod operator *= (const T x){
for(auto &i : a) i *= x;
normalize();
return *this;
}
PolyMod operator /= (const T x){
return *this *= (T(1) / T(x));
}
// operation with poly
PolyMod operator += (const PolyMod &b){
a.resize(max(size(), b.size()));
for(int i=0; i<b.size(); i++) a[i] += b.a[i];
normalize();
return *this;
}
PolyMod operator -= (const PolyMod &b){
a.resize(max(size(), b.size()));
for(int i=0; i<b.size(); i++) a[i] -= b.a[i];
normalize();
return *this;
}
PolyMod operator *= (const PolyMod &b){
*this = fft::multiply_ntt<W, M>(a, b.a);
normalize();
return *this;
}
PolyMod operator /= (const PolyMod &b){
if(deg() < b.deg()) return *this = PolyMod();
int sz = deg() - b.deg() + 1;
PolyMod ra = reversed().trim(sz), rb = b.reversed().trim(sz).inv(sz);
*this = (ra * rb).trim(sz);
for(int i=sz-size(); i; i--) push_back(T(0));
reverse(all(a));
normalize();
return *this;
}
PolyMod operator %= (const PolyMod &b){
if(deg() < b.deg()) return *this;
PolyMod tmp = *this; tmp /= b; tmp *= b;
*this -= tmp;
normalize();
return *this;
}
// operator
PolyMod operator * (const T x) const { return PolyMod(*this) *= x; }
PolyMod operator / (const T x) const { return PolyMod(*this) /= x; }
PolyMod operator + (const PolyMod &b) const { return PolyMod(*this) += b; }
PolyMod operator - (const PolyMod &b) const { return PolyMod(*this) -= b; }
PolyMod operator * (const PolyMod &b) const { return PolyMod(*this) *= b; }
PolyMod operator / (const PolyMod &b) const { return PolyMod(*this) /= b; }
PolyMod operator % (const PolyMod &b) const { return PolyMod(*this) %= b; }
};
constexpr int W = 3, MOD = 998244353;
using mint = MINT<MOD>;
using poly = PolyMod<W, MOD>;
mint kitamasa(vector<mint> c, vector<mint> a, ll n){
poly d = vector<mint>{1};
poly xn = vector<mint>{0, 1};
poly f(c);
// cout << "f : " << f << '\n';
while(n){
if(n & 1) d = d * xn % f;
n >>= 1; xn = xn * xn % f;
}
// cout << "d : " << d << '\n';
// cout << "a : " << poly(a) << '\n';
mint ret = 0;
for(int i=0; i<a.size()&&i<=d.deg(); i++) ret += a[i] * d[i];
return ret;
}
const int K_MAX = 50005;
mint fac[K_MAX], invfac[K_MAX];
void pre() {
fac[0] = 1;
for (int i=1;i<K_MAX;++i) fac[i] = fac[i-1] * mint(i);
invfac[K_MAX-1] = mint(1) / fac[K_MAX-1];
for (int i=K_MAX-2;i>=0;--i) invfac[i] = invfac[i+1] * mint(i+1);
}
mint ncr(int n, int r) {
if (r < 0 || r > n) return mint(0);
return fac[n] / (fac[r] * fac[n-r]);
}
mint solve(ll N, int K) {
// cout << N << " " << K << " solving\n";
poly P(vector<mint>({0, 1, 2, 1, 1, 2, 3, -1, -1}));
poly Q(vector<mint>{1, 0, -1, 0, -2, 0, -3, 0, 1});
// cin >> N >> K;
vector<mint> R(K+1);
mint kcr = 1;
for (int r=0;r<=K;++r) {
R[r] = ((r & 1) ? -kcr : kcr);
kcr = kcr * mint(K - r) / mint(r + 1);
}
Q *= poly(R);
poly Qinv = Q.inv(K + 8 + 1);
poly ret = P * Qinv;
// cout << "first k coefficients : ";
// for (auto x : ret.a) cout << x << " "; cout << '\n';
// cout << "Qinv : ";
// for (auto x : Qinv.a) cout << x << " "; cout << '\n';
// cout << "Q : " << Q << "\n";
vector<mint> v_dp(K + 8), v_rec(K + 9);
// cout << "sequence : ";
for(int i=0; i<K+8; i++){
v_dp[i] = ret.a[i+1];
// cout << v_dp[i] << " ";
}
// cout << '\n';
// cout << "recurrence : ";
for(int i=0; i<K+8; i++){
v_rec[K+8-i-1] = Q[i+1];
// cout << v_rec[i] << " ";
}
v_rec[K+8] = mint(1);
// cout << '\n';
return kitamasa(v_rec, v_dp, N-1);
}
int main() {
clock_t start = clock();
ios_base::sync_with_stdio(false); cin.tie(nullptr);
int T;
cin >> T;
while (T--) {
ll n; int k;
cin >> n >> k;
cout << solve(n, k) << '\n';
}
cerr << fixed << setprecision(10);
cerr << "Time taken = " << (clock() - start) / ((double)CLOCKS_PER_SEC) << " s\n";
return 0;
}
Tester's Solution
/* in the name of Anton */
/*
Compete against Yourself.
Author - Aryan (@aryanc403)
Atcoder library - https://atcoder.github.io/ac-library/production/document_en/
*/
#ifdef ARYANC403
#include <header.h>
#else
#pragma GCC optimize ("Ofast")
#pragma GCC target ("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx")
//#pragma GCC optimize ("-ffloat-store")
#include<bits/stdc++.h>
#define dbg(args...) 42;
#endif
using namespace std;
#define fo(i,n) for(i=0;i<(n);++i)
#define repA(i,j,n) for(i=(j);i<=(n);++i)
#define repD(i,j,n) for(i=(j);i>=(n);--i)
#define all(x) begin(x), end(x)
#define sz(x) ((lli)(x).size())
#define pb push_back
#define mp make_pair
#define X first
#define Y second
#define endl "\n"
typedef long long int lli;
typedef long double mytype;
typedef pair<lli,lli> ii;
typedef vector<ii> vii;
typedef vector<lli> vi;
const auto start_time = std::chrono::high_resolution_clock::now();
void aryanc403()
{
#ifdef ARYANC403
auto end_time = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = end_time-start_time;
cerr<<"Time Taken : "<<diff.count()<<"\n";
#endif
}
#include <cassert>
#include <numeric>
#include <type_traits>
#ifdef _MSC_VER
#include <intrin.h>
#endif
#include <utility>
#ifdef _MSC_VER
#include <intrin.h>
#endif
namespace atcoder {
namespace internal {
constexpr long long safe_mod(long long x, long long m) {
x %= m;
if (x < 0) x += m;
return x;
}
struct barrett {
unsigned int _m;
unsigned long long im;
explicit barrett(unsigned int m) : _m(m), im((unsigned long long)(-1) / m + 1) {}
unsigned int umod() const { return _m; }
unsigned int mul(unsigned int a, unsigned int b) const {
unsigned long long z = a;
z *= b;
#ifdef _MSC_VER
unsigned long long x;
_umul128(z, im, &x);
#else
unsigned long long x =
(unsigned long long)(((unsigned __int128)(z)*im) >> 64);
#endif
unsigned int v = (unsigned int)(z - x * _m);
if (_m <= v) v += _m;
return v;
}
};
constexpr long long pow_mod_constexpr(long long x, long long n, int m) {
if (m == 1) return 0;
unsigned int _m = (unsigned int)(m);
unsigned long long r = 1;
unsigned long long y = safe_mod(x, m);
while (n) {
if (n & 1) r = (r * y) % _m;
y = (y * y) % _m;
n >>= 1;
}
return r;
}
constexpr bool is_prime_constexpr(int n) {
if (n <= 1) return false;
if (n == 2 || n == 7 || n == 61) return true;
if (n % 2 == 0) return false;
long long d = n - 1;
while (d % 2 == 0) d /= 2;
constexpr long long bases[3] = {2, 7, 61};
for (long long a : bases) {
long long t = d;
long long y = pow_mod_constexpr(a, t, n);
while (t != n - 1 && y != 1 && y != n - 1) {
y = y * y % n;
t <<= 1;
}
if (y != n - 1 && t % 2 == 0) {
return false;
}
}
return true;
}
template <int n> constexpr bool is_prime = is_prime_constexpr(n);
constexpr std::pair<long long, long long> inv_gcd(long long a, long long b) {
a = safe_mod(a, b);
if (a == 0) return {b, 0};
long long s = b, t = a;
long long m0 = 0, m1 = 1;
while (t) {
long long u = s / t;
s -= t * u;
m0 -= m1 * u; // |m1 * u| <= |m1| * s <= b
auto tmp = s;
s = t;
t = tmp;
tmp = m0;
m0 = m1;
m1 = tmp;
}
if (m0 < 0) m0 += b / s;
return {s, m0};
}
constexpr int primitive_root_constexpr(int m) {
if (m == 2) return 1;
if (m == 167772161) return 3;
if (m == 469762049) return 3;
if (m == 754974721) return 11;
if (m == 998244353) return 3;
int divs[20] = {};
divs[0] = 2;
int cnt = 1;
int x = (m - 1) / 2;
while (x % 2 == 0) x /= 2;
for (int i = 3; (long long)(i)*i <= x; i += 2) {
if (x % i == 0) {
divs[cnt++] = i;
while (x % i == 0) {
x /= i;
}
}
}
if (x > 1) {
divs[cnt++] = x;
}
for (int g = 2;; g++) {
bool ok = true;
for (int i = 0; i < cnt; i++) {
if (pow_mod_constexpr(g, (m - 1) / divs[i], m) == 1) {
ok = false;
break;
}
}
if (ok) return g;
}
}
template <int m> constexpr int primitive_root = primitive_root_constexpr(m);
unsigned long long floor_sum_unsigned(unsigned long long n,
unsigned long long m,
unsigned long long a,
unsigned long long b) {
unsigned long long ans = 0;
while (true) {
if (a >= m) {
ans += n * (n - 1) / 2 * (a / m);
a %= m;
}
if (b >= m) {
ans += n * (b / m);
b %= m;
}
unsigned long long y_max = a * n + b;
if (y_max < m) break;
n = (unsigned long long)(y_max / m);
b = (unsigned long long)(y_max % m);
std::swap(m, a);
}
return ans;
}
} // namespace internal
} // namespace atcoder
#include <cassert>
#include <numeric>
#include <type_traits>
namespace atcoder {
namespace internal {
#ifndef _MSC_VER
template <class T>
using is_signed_int128 =
typename std::conditional<std::is_same<T, __int128_t>::value ||
std::is_same<T, __int128>::value,
std::true_type,
std::false_type>::type;
template <class T>
using is_unsigned_int128 =
typename std::conditional<std::is_same<T, __uint128_t>::value ||
std::is_same<T, unsigned __int128>::value,
std::true_type,
std::false_type>::type;
template <class T>
using make_unsigned_int128 =
typename std::conditional<std::is_same<T, __int128_t>::value,
__uint128_t,
unsigned __int128>;
template <class T>
using is_integral = typename std::conditional<std::is_integral<T>::value ||
is_signed_int128<T>::value ||
is_unsigned_int128<T>::value,
std::true_type,
std::false_type>::type;
template <class T>
using is_signed_int = typename std::conditional<(is_integral<T>::value &&
std::is_signed<T>::value) ||
is_signed_int128<T>::value,
std::true_type,
std::false_type>::type;
template <class T>
using is_unsigned_int =
typename std::conditional<(is_integral<T>::value &&
std::is_unsigned<T>::value) ||
is_unsigned_int128<T>::value,
std::true_type,
std::false_type>::type;
template <class T>
using to_unsigned = typename std::conditional<
is_signed_int128<T>::value,
make_unsigned_int128<T>,
typename std::conditional<std::is_signed<T>::value,
std::make_unsigned<T>,
std::common_type<T>>::type>::type;
#else
template <class T> using is_integral = typename std::is_integral<T>;
template <class T>
using is_signed_int =
typename std::conditional<is_integral<T>::value && std::is_signed<T>::value,
std::true_type,
std::false_type>::type;
template <class T>
using is_unsigned_int =
typename std::conditional<is_integral<T>::value &&
std::is_unsigned<T>::value,
std::true_type,
std::false_type>::type;
template <class T>
using to_unsigned = typename std::conditional<is_signed_int<T>::value,
std::make_unsigned<T>,
std::common_type<T>>::type;
#endif
template <class T>
using is_signed_int_t = std::enable_if_t<is_signed_int<T>::value>;
template <class T>
using is_unsigned_int_t = std::enable_if_t<is_unsigned_int<T>::value>;
template <class T> using to_unsigned_t = typename to_unsigned<T>::type;
} // namespace internal
} // namespace atcoder
namespace atcoder {
namespace internal {
struct modint_base {};
struct static_modint_base : modint_base {};
template <class T> using is_modint = std::is_base_of<modint_base, T>;
template <class T> using is_modint_t = std::enable_if_t<is_modint<T>::value>;
} // namespace internal
template <int m, std::enable_if_t<(1 <= m)>* = nullptr>
struct static_modint : internal::static_modint_base {
using mint = static_modint;
public:
static constexpr int mod() { return m; }
static mint raw(int v) {
mint x;
x._v = v;
return x;
}
static_modint() : _v(0) {}
template <class T, internal::is_signed_int_t<T>* = nullptr>
static_modint(T v) {
long long x = (long long)(v % (long long)(umod()));
if (x < 0) x += umod();
_v = (unsigned int)(x);
}
template <class T, internal::is_unsigned_int_t<T>* = nullptr>
static_modint(T v) {
_v = (unsigned int)(v % umod());
}
unsigned int val() const { return _v; }
mint& operator++() {
_v++;
if (_v == umod()) _v = 0;
return *this;
}
mint& operator--() {
if (_v == 0) _v = umod();
_v--;
return *this;
}
mint operator++(int) {
mint result = *this;
++*this;
return result;
}
mint operator--(int) {
mint result = *this;
--*this;
return result;
}
mint& operator+=(const mint& rhs) {
_v += rhs._v;
if (_v >= umod()) _v -= umod();
return *this;
}
mint& operator-=(const mint& rhs) {
_v -= rhs._v;
if (_v >= umod()) _v += umod();
return *this;
}
mint& operator*=(const mint& rhs) {
unsigned long long z = _v;
z *= rhs._v;
_v = (unsigned int)(z % umod());
return *this;
}
mint& operator/=(const mint& rhs) { return *this = *this * rhs.inv(); }
mint operator+() const { return *this; }
mint operator-() const { return mint() - *this; }
mint pow(long long n) const {
assert(0 <= n);
mint x = *this, r = 1;
while (n) {
if (n & 1) r *= x;
x *= x;
n >>= 1;
}
return r;
}
mint inv() const {
if (prime) {
assert(_v);
return pow(umod() - 2);
} else {
auto eg = internal::inv_gcd(_v, m);
assert(eg.first == 1);
return eg.second;
}
}
friend mint operator+(const mint& lhs, const mint& rhs) {
return mint(lhs) += rhs;
}
friend mint operator-(const mint& lhs, const mint& rhs) {
return mint(lhs) -= rhs;
}
friend mint operator*(const mint& lhs, const mint& rhs) {
return mint(lhs) *= rhs;
}
friend mint operator/(const mint& lhs, const mint& rhs) {
return mint(lhs) /= rhs;
}
friend bool operator==(const mint& lhs, const mint& rhs) {
return lhs._v == rhs._v;
}
friend bool operator!=(const mint& lhs, const mint& rhs) {
return lhs._v != rhs._v;
}
private:
unsigned int _v;
static constexpr unsigned int umod() { return m; }
static constexpr bool prime = internal::is_prime<m>;
};
template <int id> struct dynamic_modint : internal::modint_base {
using mint = dynamic_modint;
public:
static int mod() { return (int)(bt.umod()); }
static void set_mod(int m) {
assert(1 <= m);
bt = internal::barrett(m);
}
static mint raw(int v) {
mint x;
x._v = v;
return x;
}
dynamic_modint() : _v(0) {}
template <class T, internal::is_signed_int_t<T>* = nullptr>
dynamic_modint(T v) {
long long x = (long long)(v % (long long)(mod()));
if (x < 0) x += mod();
_v = (unsigned int)(x);
}
template <class T, internal::is_unsigned_int_t<T>* = nullptr>
dynamic_modint(T v) {
_v = (unsigned int)(v % mod());
}
unsigned int val() const { return _v; }
mint& operator++() {
_v++;
if (_v == umod()) _v = 0;
return *this;
}
mint& operator--() {
if (_v == 0) _v = umod();
_v--;
return *this;
}
mint operator++(int) {
mint result = *this;
++*this;
return result;
}
mint operator--(int) {
mint result = *this;
--*this;
return result;
}
mint& operator+=(const mint& rhs) {
_v += rhs._v;
if (_v >= umod()) _v -= umod();
return *this;
}
mint& operator-=(const mint& rhs) {
_v += mod() - rhs._v;
if (_v >= umod()) _v -= umod();
return *this;
}
mint& operator*=(const mint& rhs) {
_v = bt.mul(_v, rhs._v);
return *this;
}
mint& operator/=(const mint& rhs) { return *this = *this * rhs.inv(); }
mint operator+() const { return *this; }
mint operator-() const { return mint() - *this; }
mint pow(long long n) const {
assert(0 <= n);
mint x = *this, r = 1;
while (n) {
if (n & 1) r *= x;
x *= x;
n >>= 1;
}
return r;
}
mint inv() const {
auto eg = internal::inv_gcd(_v, mod());
assert(eg.first == 1);
return eg.second;
}
friend mint operator+(const mint& lhs, const mint& rhs) {
return mint(lhs) += rhs;
}
friend mint operator-(const mint& lhs, const mint& rhs) {
return mint(lhs) -= rhs;
}
friend mint operator*(const mint& lhs, const mint& rhs) {
return mint(lhs) *= rhs;
}
friend mint operator/(const mint& lhs, const mint& rhs) {
return mint(lhs) /= rhs;
}
friend bool operator==(const mint& lhs, const mint& rhs) {
return lhs._v == rhs._v;
}
friend bool operator!=(const mint& lhs, const mint& rhs) {
return lhs._v != rhs._v;
}
private:
unsigned int _v;
static internal::barrett bt;
static unsigned int umod() { return bt.umod(); }
};
template <int id> internal::barrett dynamic_modint<id>::bt(998244353);
using modint998244353 = static_modint<998244353>;
using modint1000000007 = static_modint<1000000007>;
using modint = dynamic_modint<-1>;
namespace internal {
template <class T>
using is_static_modint = std::is_base_of<internal::static_modint_base, T>;
template <class T>
using is_static_modint_t = std::enable_if_t<is_static_modint<T>::value>;
template <class> struct is_dynamic_modint : public std::false_type {};
template <int id>
struct is_dynamic_modint<dynamic_modint<id>> : public std::true_type {};
template <class T>
using is_dynamic_modint_t = std::enable_if_t<is_dynamic_modint<T>::value>;
} // namespace internal
} // namespace atcoder
#include <algorithm>
#include <array>
#include <cassert>
#include <type_traits>
#include <vector>
#ifdef _MSC_VER
#include <intrin.h>
#endif
namespace atcoder {
namespace internal {
int ceil_pow2(int n) {
int x = 0;
while ((1U << x) < (unsigned int)(n)) x++;
return x;
}
int bsf(unsigned int n) {
#ifdef _MSC_VER
unsigned long index;
_BitScanForward(&index, n);
return index;
#else
return __builtin_ctz(n);
#endif
}
} // namespace internal
} // namespace atcoder
namespace atcoder {
namespace internal {
template <class mint, internal::is_static_modint_t<mint>* = nullptr>
void butterfly(std::vector<mint>& a) {
static constexpr int g = internal::primitive_root<mint::mod()>;
int n = int(a.size());
int h = internal::ceil_pow2(n);
static bool first = true;
static mint sum_e[30]; // sum_e[i] = ies[0] * ... * ies[i - 1] * es[i]
if (first) {
first = false;
mint es[30], ies[30]; // es[i]^(2^(2+i)) == 1
int cnt2 = bsf(mint::mod() - 1);
mint e = mint(g).pow((mint::mod() - 1) >> cnt2), ie = e.inv();
for (int i = cnt2; i >= 2; i--) {
es[i - 2] = e;
ies[i - 2] = ie;
e *= e;
ie *= ie;
}
mint now = 1;
for (int i = 0; i <= cnt2 - 2; i++) {
sum_e[i] = es[i] * now;
now *= ies[i];
}
}
for (int ph = 1; ph <= h; ph++) {
int w = 1 << (ph - 1), p = 1 << (h - ph);
mint now = 1;
for (int s = 0; s < w; s++) {
int offset = s << (h - ph + 1);
for (int i = 0; i < p; i++) {
auto l = a[i + offset];
auto r = a[i + offset + p] * now;
a[i + offset] = l + r;
a[i + offset + p] = l - r;
}
now *= sum_e[bsf(~(unsigned int)(s))];
}
}
}
template <class mint, internal::is_static_modint_t<mint>* = nullptr>
void butterfly_inv(std::vector<mint>& a) {
static constexpr int g = internal::primitive_root<mint::mod()>;
int n = int(a.size());
int h = internal::ceil_pow2(n);
static bool first = true;
static mint sum_ie[30]; // sum_ie[i] = es[0] * ... * es[i - 1] * ies[i]
if (first) {
first = false;
mint es[30], ies[30]; // es[i]^(2^(2+i)) == 1
int cnt2 = bsf(mint::mod() - 1);
mint e = mint(g).pow((mint::mod() - 1) >> cnt2), ie = e.inv();
for (int i = cnt2; i >= 2; i--) {
es[i - 2] = e;
ies[i - 2] = ie;
e *= e;
ie *= ie;
}
mint now = 1;
for (int i = 0; i <= cnt2 - 2; i++) {
sum_ie[i] = ies[i] * now;
now *= es[i];
}
}
for (int ph = h; ph >= 1; ph--) {
int w = 1 << (ph - 1), p = 1 << (h - ph);
mint inow = 1;
for (int s = 0; s < w; s++) {
int offset = s << (h - ph + 1);
for (int i = 0; i < p; i++) {
auto l = a[i + offset];
auto r = a[i + offset + p];
a[i + offset] = l + r;
a[i + offset + p] =
(unsigned long long)(mint::mod() + l.val() - r.val()) *
inow.val();
}
inow *= sum_ie[bsf(~(unsigned int)(s))];
}
}
}
template <class mint, internal::is_static_modint_t<mint>* = nullptr>
std::vector<mint> convolution_naive(const std::vector<mint>& a, const std::vector<mint>& b) {
int n = int(a.size()), m = int(b.size());
std::vector<mint> ans(n + m - 1);
if (n < m) {
for (int j = 0; j < m; j++) {
for (int i = 0; i < n; i++) {
ans[i + j] += a[i] * b[j];
}
}
} else {
for (int i = 0; i < n; i++) {
for (int j = 0; j < m; j++) {
ans[i + j] += a[i] * b[j];
}
}
}
return ans;
}
template <class mint, internal::is_static_modint_t<mint>* = nullptr>
std::vector<mint> convolution_fft(std::vector<mint> a, std::vector<mint> b) {
int n = int(a.size()), m = int(b.size());
int z = 1 << internal::ceil_pow2(n + m - 1);
a.resize(z);
internal::butterfly(a);
b.resize(z);
internal::butterfly(b);
for (int i = 0; i < z; i++) {
a[i] *= b[i];
}
internal::butterfly_inv(a);
a.resize(n + m - 1);
mint iz = mint(z).inv();
for (int i = 0; i < n + m - 1; i++) a[i] *= iz;
return a;
}
} // namespace internal
template <class mint, internal::is_static_modint_t<mint>* = nullptr>
std::vector<mint> convolution(std::vector<mint>&& a, std::vector<mint>&& b) {
int n = int(a.size()), m = int(b.size());
if (!n || !m) return {};
if (std::min(n, m) <= 60) return convolution_naive(a, b);
return internal::convolution_fft(a, b);
}
template <class mint, internal::is_static_modint_t<mint>* = nullptr>
std::vector<mint> convolution(const std::vector<mint>& a, const std::vector<mint>& b) {
int n = int(a.size()), m = int(b.size());
if (!n || !m) return {};
if (std::min(n, m) <= 60) return convolution_naive(a, b);
return internal::convolution_fft(a, b);
}
template <unsigned int mod = 998244353,
class T,
std::enable_if_t<internal::is_integral<T>::value>* = nullptr>
std::vector<T> convolution(const std::vector<T>& a, const std::vector<T>& b) {
int n = int(a.size()), m = int(b.size());
if (!n || !m) return {};
using mint = static_modint<mod>;
std::vector<mint> a2(n), b2(m);
for (int i = 0; i < n; i++) {
a2[i] = mint(a[i]);
}
for (int i = 0; i < m; i++) {
b2[i] = mint(b[i]);
}
auto c2 = convolution(move(a2), move(b2));
std::vector<T> c(n + m - 1);
for (int i = 0; i < n + m - 1; i++) {
c[i] = c2[i].val();
}
return c;
}
std::vector<long long> convolution_ll(const std::vector<long long>& a,
const std::vector<long long>& b) {
int n = int(a.size()), m = int(b.size());
if (!n || !m) return {};
static constexpr unsigned long long MOD1 = 754974721; // 2^24
static constexpr unsigned long long MOD2 = 167772161; // 2^25
static constexpr unsigned long long MOD3 = 469762049; // 2^26
static constexpr unsigned long long M2M3 = MOD2 * MOD3;
static constexpr unsigned long long M1M3 = MOD1 * MOD3;
static constexpr unsigned long long M1M2 = MOD1 * MOD2;
static constexpr unsigned long long M1M2M3 = MOD1 * MOD2 * MOD3;
static constexpr unsigned long long i1 =
internal::inv_gcd(MOD2 * MOD3, MOD1).second;
static constexpr unsigned long long i2 =
internal::inv_gcd(MOD1 * MOD3, MOD2).second;
static constexpr unsigned long long i3 =
internal::inv_gcd(MOD1 * MOD2, MOD3).second;
auto c1 = convolution<MOD1>(a, b);
auto c2 = convolution<MOD2>(a, b);
auto c3 = convolution<MOD3>(a, b);
std::vector<long long> c(n + m - 1);
for (int i = 0; i < n + m - 1; i++) {
unsigned long long x = 0;
x += (c1[i] * i1) % MOD1 * M2M3;
x += (c2[i] * i2) % MOD2 * M1M3;
x += (c3[i] * i3) % MOD3 * M1M2;
long long diff =
c1[i] - internal::safe_mod((long long)(x), (long long)(MOD1));
if (diff < 0) diff += MOD1;
static constexpr unsigned long long offset[5] = {
0, 0, M1M2M3, 2 * M1M2M3, 3 * M1M2M3};
x -= offset[diff % 5];
c[i] = x;
}
return c;
}
} // namespace atcoder
// https://atcoder.jp/contests/arc113/submissions/20423265
// Originally O(n^2); operator* below now calls atcoder::convolution (NTT) instead
// Credits - tourist
namespace algebra {
template <typename T>
vector<T>& operator+=(vector<T>& a, const vector<T>& b) {
if (a.size() < b.size()) {
a.resize(b.size());
}
for (int i = 0; i < (int) b.size(); i++) {
a[i] += b[i];
}
return a;
}
template <typename T>
vector<T> operator+(const vector<T>& a, const vector<T>& b) {
vector<T> c = a;
return c += b;
}
template <typename T>
vector<T>& operator-=(vector<T>& a, const vector<T>& b) {
if (a.size() < b.size()) {
a.resize(b.size());
}
for (int i = 0; i < (int) b.size(); i++) {
a[i] -= b[i];
}
return a;
}
template <typename T>
vector<T> operator-(const vector<T>& a, const vector<T>& b) {
vector<T> c = a;
return c -= b;
}
template <typename T>
vector<T> operator-(const vector<T>& a) {
vector<T> c = a;
for (int i = 0; i < (int) c.size(); i++) {
c[i] = -c[i];
}
return c;
}
template <typename T>
vector<T> operator*(const vector<T>& a, const vector<T>& b) {
if (a.empty() || b.empty()) {
return {};
}
return convolution(a,b);
// vector<T> c(a.size() + b.size() - 1, 0);
// for (int i = 0; i < (int) a.size(); i++) {
// for (int j = 0; j < (int) b.size(); j++) {
// c[i + j] += a[i] * b[j];
// }
// }
// return c;
}
template <typename T>
vector<T>& operator*=(vector<T>& a, const vector<T>& b) {
return a = a * b;
}
template <typename T>
vector<T> inverse(const vector<T>& a) {
assert(!a.empty());
int n = (int) a.size();
vector<T> b = {1 / a[0]};
while ((int) b.size() < n) {
vector<T> a_cut(a.begin(), a.begin() + min(a.size(), b.size() << 1));
vector<T> x = b * b * a_cut;
b.resize(b.size() << 1);
for (int i = (int) b.size() >> 1; i < (int) min(x.size(), b.size()); i++) {
b[i] = -x[i];
}
}
b.resize(n);
return b;
}
template <typename T>
vector<T>& operator/=(vector<T>& a, const vector<T>& b) {
int n = (int) a.size();
int m = (int) b.size();
if (n < m) {
a.clear();
} else {
vector<T> d = b;
reverse(a.begin(), a.end());
reverse(d.begin(), d.end());
d.resize(n - m + 1);
a *= inverse(d);
a.erase(a.begin() + n - m + 1, a.end());
reverse(a.begin(), a.end());
}
return a;
}
template <typename T>
vector<T> operator/(const vector<T>& a, const vector<T>& b) {
vector<T> c = a;
return c /= b;
}
template <typename T>
vector<T> operator*(const vector<T>& a, T b) {
vector<T> c = a;
for(auto &x:c)
x*=b;
return c;
}
template <typename T>
vector<T>& operator%=(vector<T>& a, const vector<T>& b) {
int n = (int) a.size();
int m = (int) b.size();
if (n >= m) {
vector<T> c = (a / b) * b;
a.resize(m - 1);
for (int i = 0; i < m - 1; i++) {
a[i] -= c[i];
}
}
return a;
}
template <typename T>
vector<T> operator%(const vector<T>& a, const vector<T>& b) {
vector<T> c = a;
return c %= b;
}
template <typename T, typename U>
vector<T> power(const vector<T>& a, const U& b, const vector<T>& c) {
assert(b >= 0);
vector<U> binary;
U bb = b;
while (bb > 0) {
binary.push_back(bb & 1);
bb >>= 1;
}
vector<T> res = vector<T>{1} % c;
for (int j = (int) binary.size() - 1; j >= 0; j--) {
res = res * res % c;
if (binary[j] == 1) {
res = res * a % c;
}
}
return res;
}
template <typename T>
vector<T> derivative(const vector<T>& a) {
vector<T> c = a;
for (int i = 0; i < (int) c.size(); i++) {
c[i] *= i;
}
if (!c.empty()) {
c.erase(c.begin());
}
return c;
}
template <typename T>
vector<T> integrate(const vector<T>& a) {
vector<T> c = {0};
for (int i = 0; i < (int) a.size(); i++) {
c.push_back(a[i]/(i+1));
}
return c;
}
template <typename T>
vector<T> primitive(const vector<T>& a) {
vector<T> c = a;
c.insert(c.begin(), 0);
for (int i = 1; i < (int) c.size(); i++) {
c[i] /= i;
}
return c;
}
template <typename T>
vector<T> logarithm(const vector<T>& a) {
assert(!a.empty() && a[0] == 1);
vector<T> u = primitive(derivative(a) * inverse(a));
u.resize(a.size());
return u;
}
template <typename T>
vector<T> exponent(const vector<T>& a) {
assert(!a.empty() && a[0] == 0);
int n = (int) a.size();
vector<T> b = {1};
while ((int) b.size() < n) {
vector<T> x(a.begin(), a.begin() + min(a.size(), b.size() << 1));
x[0] += 1;
vector<T> old_b = b;
b.resize(b.size() << 1);
x -= logarithm(b);
x *= old_b;
for (int i = (int) b.size() >> 1; i < (int) min(x.size(), b.size()); i++) {
b[i] = x[i];
}
}
b.resize(n);
return b;
}
template <typename T>
vector<T> sqrt(const vector<T>& a) {
assert(!a.empty() && a[0] == 1);
int n = (int) a.size();
vector<T> b = {1};
while ((int) b.size() < n) {
vector<T> x(a.begin(), a.begin() + min(a.size(), b.size() << 1));
b.resize(b.size() << 1);
x *= inverse(b);
T inv2 = 1 / static_cast<T>(2);
for (int i = (int) b.size() >> 1; i < (int) min(x.size(), b.size()); i++) {
b[i] = x[i] * inv2;
}
}
b.resize(n);
return b;
}
template <typename T>
vector<T> multiply(const vector<vector<T>>& a) {
if (a.empty()) {
return {0};
}
function<vector<T>(int, int)> mult = [&](int l, int r) {
if (l == r) {
return a[l];
}
int y = (l + r) >> 1;
return mult(l, y) * mult(y + 1, r);
};
return mult(0, (int) a.size() - 1);
}
template <typename T>
T evaluate(const vector<T>& a, const T& x) {
T res = 0;
for (int i = (int) a.size() - 1; i >= 0; i--) {
res = res * x + a[i];
}
return res;
}
template <typename T>
vector<T> evaluate(const vector<T>& a, const vector<T>& x) {
if (x.empty()) {
return {};
}
if (a.empty()) {
return vector<T>(x.size(), 0);
}
int n = (int) x.size();
vector<vector<T>> st((n << 1) - 1);
function<void(int, int, int)> build = [&](int v, int l, int r) {
if (l == r) {
st[v] = vector<T>{-x[l], 1};
} else {
int y = (l + r) >> 1;
int z = v + ((y - l + 1) << 1);
build(v + 1, l, y);
build(z, y + 1, r);
st[v] = st[v + 1] * st[z];
}
};
build(0, 0, n - 1);
vector<T> res(n);
function<void(int, int, int, vector<T>)> eval = [&](int v, int l, int r, vector<T> f) {
f %= st[v];
if ((int) f.size() < 150) {
for (int i = l; i <= r; i++) {
res[i] = evaluate(f, x[i]);
}
return;
}
if (l == r) {
res[l] = f[0];
} else {
int y = (l + r) >> 1;
int z = v + ((y - l + 1) << 1);
eval(v + 1, l, y, f);
eval(z, y + 1, r, f);
}
};
eval(0, 0, n - 1, a);
return res;
}
template <typename T>
vector<T> interpolate(const vector<T>& x, const vector<T>& y) {
if (x.empty()) {
return {};
}
assert(x.size() == y.size());
int n = (int) x.size();
vector<vector<T>> st((n << 1) - 1);
function<void(int, int, int)> build = [&](int v, int l, int r) {
if (l == r) {
st[v] = vector<T>{-x[l], 1};
} else {
int w = (l + r) >> 1;
int z = v + ((w - l + 1) << 1);
build(v + 1, l, w);
build(z, w + 1, r);
st[v] = st[v + 1] * st[z];
}
};
build(0, 0, n - 1);
vector<T> m = st[0];
vector<T> dm = derivative(m);
vector<T> val(n);
function<void(int, int, int, vector<T>)> eval = [&](int v, int l, int r, vector<T> f) {
f %= st[v];
if ((int) f.size() < 150) {
for (int i = l; i <= r; i++) {
val[i] = evaluate(f, x[i]);
}
return;
}
if (l == r) {
val[l] = f[0];
} else {
int w = (l + r) >> 1;
int z = v + ((w - l + 1) << 1);
eval(v + 1, l, w, f);
eval(z, w + 1, r, f);
}
};
eval(0, 0, n - 1, dm);
for (int i = 0; i < n; i++) {
val[i] = y[i] / val[i];
}
function<vector<T>(int, int, int)> calc = [&](int v, int l, int r) {
if (l == r) {
return vector<T>{val[l]};
}
int w = (l + r) >> 1;
int z = v + ((w - l + 1) << 1);
return calc(v + 1, l, w) * st[z] + calc(z, w + 1, r) * st[v + 1];
};
return calc(0, 0, n - 1);
}
// f[i] = 1^i + 2^i + ... + up^i
template <typename T>
vector<T> faulhaber(const T& up, int n) {
vector<T> ex(n + 1);
T e = 1;
for (int i = 0; i <= n; i++) {
ex[i] = e;
e /= i + 1;
}
vector<T> den = ex;
den.erase(den.begin());
for (auto& d : den) {
d = -d;
}
vector<T> num(n);
T p = 1;
for (int i = 0; i < n; i++) {
p *= up + 1;
num[i] = ex[i + 1] * (1 - p);
}
vector<T> res = num * inverse(den);
res.resize(n);
T f = 1;
for (int i = 0; i < n; i++) {
res[i] *= f;
f *= i + 1;
}
return res;
}
// (x + 1) * (x + 2) * ... * (x + n)
// (can be optimized with precomputed inverses)
template <typename T>
vector<T> sequence(int n) {
if (n == 0) {
return {1};
}
if (n % 2 == 1) {
return sequence<T>(n - 1) * vector<T>{n, 1};
}
vector<T> c = sequence<T>(n / 2);
vector<T> a = c;
reverse(a.begin(), a.end());
T f = 1;
for (int i = n / 2 - 1; i >= 0; i--) {
f *= n / 2 - i;
a[i] *= f;
}
vector<T> b(n / 2 + 1);
b[0] = 1;
for (int i = 1; i <= n / 2; i++) {
b[i] = b[i - 1] * (n / 2) / i;
}
vector<T> h = a * b;
h.resize(n / 2 + 1);
reverse(h.begin(), h.end());
f = 1;
for (int i = 1; i <= n / 2; i++) {
f /= i;
h[i] *= f;
}
vector<T> res = c * h;
return res;
}
template <typename T>
class OnlineProduct {
public:
const vector<T> a;
vector<T> b;
vector<T> c;
OnlineProduct(const vector<T>& a_) : a(a_) {}
T add(const T& val) {
int i = (int) b.size();
b.push_back(val);
if ((int) c.size() <= i) {
c.resize(i + 1);
}
c[i] += a[0] * b[i];
int z = 1;
while ((i & (z - 1)) == z - 1 && (int) a.size() > z) {
vector<T> a_mul(a.begin() + z, a.begin() + min(z << 1, (int) a.size()));
vector<T> b_mul(b.end() - z, b.end());
vector<T> c_mul = a_mul * b_mul;
if ((int) c.size() <= i + (int) c_mul.size()) {
c.resize(i + c_mul.size() + 1);
}
for (int j = 0; j < (int) c_mul.size(); j++) {
c[i + 1 + j] += c_mul[j];
}
z <<= 1;
}
return c[i];
}
};
};
using namespace algebra;
using namespace std;
using namespace atcoder;
using mint = modint998244353;
// using mint = lli;
using vm = vector<mint>;
std::ostream& operator << (std::ostream& out, const mint& rhs) {
return out<<rhs.val();
}
typedef vector<mint> polyn;
long long readInt(long long l, long long r, char endd) {
long long x=0;
int cnt=0;
int fi=-1;
bool is_neg=false;
while(true) {
char g=getchar();
if(g=='-') {
assert(fi==-1);
is_neg=true;
continue;
}
if('0'<=g&&g<='9') {
x*=10;
x+=g-'0';
if(cnt==0) {
fi=g-'0';
}
cnt++;
assert(fi!=0 || cnt==1);
assert(fi!=0 || is_neg==false);
assert(!(cnt>19 || ( cnt==19 && fi>1) ));
} else if(g==endd) {
if(is_neg) {
x=-x;
}
assert(l<=x&&x<=r);
return x;
} else {
assert(false);
}
}
}
string readString(int l, int r, char endd) {
string ret="";
int cnt=0;
while(true) {
char g=getchar();
assert(g!=-1);
if(g==endd) {
break;
}
cnt++;
ret+=g;
}
assert(l<=cnt&&cnt<=r);
return ret;
}
long long readIntSp(long long l, long long r) {
return readInt(l,r,' ');
}
long long readIntLn(long long l, long long r) {
return readInt(l,r,'\n');
}
string readStringLn(int l, int r) {
return readString(l,r,'\n');
}
string readStringSp(int l, int r) {
return readString(l,r,' ');
}
void readEOF(){
assert(getchar()==EOF);
}
vi readVectorInt(int n,lli l,lli r){
vi a(n);
for(int i=0;i<n-1;++i)
a[i]=readIntSp(l,r);
a[n-1]=readIntLn(l,r);
return a;
}
const lli INF = 0xFFFFFFFFFFFFFFFL;
lli seed;
mt19937 rng(seed=chrono::steady_clock::now().time_since_epoch().count());
inline lli rnd(lli l=0,lli r=INF)
{return uniform_int_distribution<lli>(l,r)(rng);}
class CMP
{public:
bool operator()(ii a , ii b) //For min priority_queue .
{ return ! ( a.X < b.X || ( a.X==b.X && a.Y <= b.Y )); }};
void add( map<lli,lli> &m, lli x,lli cnt=1)
{
auto jt=m.find(x);
if(jt==m.end()) m.insert({x,cnt});
else jt->Y+=cnt;
}
void del( map<lli,lli> &m, lli x,lli cnt=1)
{
auto jt=m.find(x);
if(jt->Y<=cnt) m.erase(jt);
else jt->Y-=cnt;
}
bool cmp(const ii &a,const ii &b)
{
return a.X<b.X||(a.X==b.X&&a.Y<b.Y);
}
const lli mod = 998244353LL;
// const lli maxN = 1000000007L;
// #include <atcoder/modint>
// using namespace atcoder;
// using mint = modint998244353;
//using mint = modint1000000007;
// std::ostream& operator << (std::ostream& out, const mint& rhs) {
// return out<<rhs.val();
// }
// using mint = lli;
using vm = vector<mint>;
struct LinearRecurrence {
using ll = lli;
using vec = vector<ll>;
static void extand(vec &a, ll d, ll value = 0) {
if (d <= a.size()) return;
a.resize(d, value);
}
static vec BerlekampMassey(const vec &s, ll mod) {
std::function<ll(ll)> inverse = [&](ll a) {
return a == 1 ? 1 : (ll)(mod - mod / a) * inverse(mod % a) % mod;
};
vec A = {1}, B = {1};
ll b = s[0];
for (size_t i = 1, m = 1; i < s.size(); ++i, m++) {
ll d = 0;
for (size_t j = 0; j < A.size(); ++j) {
d += A[j] * s[i - j] % mod;
}
if (!(d %= mod)) continue;
if (2 * (A.size() - 1) <= i) {
auto temp = A;
extand(A, B.size() + m);
ll coef = d * inverse(b) % mod;
for (size_t j = 0; j < B.size(); ++j) {
A[j + m] -= coef * B[j] % mod;
if (A[j + m] < 0) A[j + m] += mod;
}
B = temp, b = d, m = 0;
} else {
extand(A, B.size() + m);
ll coef = d * inverse(b) % mod;
for (size_t j = 0; j < B.size(); ++j) {
A[j + m] -= coef * B[j] % mod;
if (A[j + m] < 0) A[j + m] += mod;
}
}
}
return A;
}
static void exgcd(ll a, ll b, ll &g, ll &x, ll &y) {
if (!b) x = 1, y = 0, g = a;
else {
exgcd(b, a % b, g, y, x);
y -= x * (a / b);
}
}
static ll crt(const vec &c, const vec &m) {
ll n = c.size();
ll M = 1, ans = 0;
for (ll i = 0; i < n; ++i) M *= m[i];
for (ll i = 0; i < n; ++i) {
ll x, y, g, tm = M / m[i];
exgcd(tm, m[i], g, x, y);
ans = (ans + tm * x * c[i] % M) % M;
}
return (ans + M) % M;
}
static vec ReedsSloane(const vec &s, ll mod) {
auto inverse = [] (ll a, ll m) {
ll d, x, y;
exgcd(a, m, d, x, y);
return d == 1 ? (x % m + m) % m : -1;
};
auto L = [] (const vec &a, const vec &b) {
ll da = (a.size() > 1 || (a.size() == 1 && a[0])) ? (ll)a.size() - 1 : -1000;
ll db = (b.size() > 1 || (b.size() == 1 && b[0])) ? (ll)b.size() - 1 : -1000;
return max(da, db + 1);
};
auto prime_power = [&] (const vec &s, ll mod, ll p, ll e) {
// linear feedback shift register mod p^e, p is prime
vector<vec> a(e), b(e), an(e), bn(e), ao(e), bo(e);
vec t(e), u(e), r(e), to(e, 1), uo(e), pw(e + 1);
pw[0] = 1;
for (ll i = pw[0] = 1; i <= e; ++i) pw[i] = pw[i - 1] * p;
for (ll i = 0; i < e; ++i) {
a[i] = {pw[i]}, an[i] = {pw[i]};
b[i] = {0}, bn[i] = {s[0] * pw[i] % mod};
t[i] = s[0] * pw[i] % mod;
if (t[i] == 0) {
t[i] = 1, u[i] = e;
} else {
for (u[i] = 0; t[i] % p == 0; t[i] /= p, ++u[i]);
}
}
for (ll k = 1; k < s.size(); ++k) {
for (ll g = 0; g < e; ++g) {
if (L(an[g], bn[g]) > L(a[g], b[g])) {
ao[g] = a[e - 1 - u[g]];
bo[g] = b[e - 1 - u[g]];
to[g] = t[e - 1 - u[g]];
uo[g] = u[e - 1 - u[g]];
r[g] = k - 1;
}
}
a = an, b = bn;
for (ll o = 0; o < e; ++o) {
ll d = 0;
for (ll i = 0; i < a[o].size() && i <= k; ++i) {
d = (d + a[o][i] * s[k - i]) % mod;
}
if (d == 0) {
t[o] = 1, u[o] = e;
} else {
for (u[o] = 0, t[o] = d; t[o] % p == 0; t[o] /= p, ++u[o]);
ll g = e - 1 - u[o];
if (L(a[g], b[g]) == 0) {
extand(bn[o], k + 1);
bn[o][k] = (bn[o][k] + d) % mod;
} else {
ll coef = t[o] * inverse(to[g], mod) % mod * pw[u[o] - uo[g]] % mod;
ll m = k - r[g];
extand(an[o], ao[g].size() + m);
extand(bn[o], bo[g].size() + m);
for (ll i = 0; i < ao[g].size(); ++i) {
an[o][i + m] -= coef * ao[g][i] % mod;
if (an[o][i + m] < 0) an[o][i + m] += mod;
}
while (an[o].size() && an[o].back() == 0) an[o].pop_back();
for (ll i = 0; i < bo[g].size(); ++i) {
bn[o][i + m] -= coef * bo[g][i] % mod;
if (bn[o][i + m] < 0) bn[o][i + m] += mod;
}
while (bn[o].size() && bn[o].back() == 0) bn[o].pop_back();
}
}
}
}
return make_pair(an[0], bn[0]);
};
vector<tuple<ll, ll, ll>> fac;
for (ll i = 2; i * i <= mod; ++i) if (mod % i == 0) {
ll cnt = 0, pw = 1;
while (mod % i == 0) mod /= i, ++cnt, pw *= i;
fac.emplace_back(pw, i, cnt);
}
if (mod > 1) fac.emplace_back(mod, mod, 1);
vector<vec> as;
ll n = 0;
for (auto &&x: fac) {
ll mod, p, e;
vec a, b;
tie(mod, p, e) = x;
auto ss = s;
for (auto &&x: ss) x %= mod;
tie(a, b) = prime_power(ss, mod, p, e);
as.emplace_back(a);
n = max(n, (ll) a.size());
}
vec a(n), c(as.size()), m(as.size());
for (ll i = 0; i < n; ++i) {
for (ll j = 0; j < as.size(); ++j) {
m[j] = get<0>(fac[j]);
c[j] = i < as[j].size() ? as[j][i] : 0;
}
a[i] = crt(c, m);
}
return a;
}
LinearRecurrence(const vec &s, const vec &c, ll mod):
init(s), trans(c), mod(mod), m(s.size()) {}
LinearRecurrence(const vec &s, ll mod, bool is_prime = true): mod(mod) {
vec A;
if(is_prime) A = BerlekampMassey(s,mod);
else A = ReedsSloane(s, mod);
if (A.empty()) A = {0};
m = A.size() - 1;
trans.resize(m);
for (ll i = 0; i < m; ++i) {
trans[i] = (mod - A[i + 1]) % mod;
}
reverse(trans.begin(), trans.end());
init = {s.begin(), s.begin() + m};
}
ll calcOG(ll n) {
if (mod == 1) return 0;
if (n < m) return init[n];
vec v(m), u(m << 1);
ll msk = !!n;
for (ll m = n; m > 1; m >>= 1LL) msk <<= 1LL;
v[0] = 1 % mod;
for (ll x = 0; msk; msk >>= 1LL, x <<= 1LL) {
fill_n(u.begin(), m * 2, 0);
x |= !!(n & msk);
if (x < m) u[x] = 1 % mod;
else { // can be optimized by fft/ntt
for (ll i = 0; i < m; ++i) {
for (ll j = 0, t = i + (x & 1); j < m; ++j, ++t) {
u[t] = (u[t] + v[i] * v[j]) % mod;
}
}
for (ll i = m * 2 - 1; i >= m; --i) {
for (ll j = 0, t = i - m; j < m; ++j, ++t) {
u[t] = (u[t] + trans[j] * u[i]) % mod;
}
}
}
v = {u.begin(), u.begin() + m};
}
ll ret = 0;
for (ll i = 0; i < m; ++i) {
ret = (ret + v[i] * init[i]) % mod;
}
return ret;
}
ll calc(ll n) {
// dbg(trans);
// return calcOG(n);
if (mod == 1) return 0;
if (n < m) return init[n];
vm pmod;
for(auto x:trans)
pmod.pb(-x);
pmod.pb(1);
polyn a={0,1},ans={1};
while(n){
if(n&1)
ans=(ans*a)%pmod;
n/=2;
a=(a*a)%pmod;
}
mint val=0;
for(lli i=0;i<sz(ans);++i){
val+=ans[i]*init[i];
}
return val.val();
}
vec init, trans;
ll mod;
ll m;
};
lli T,n,i,j,k,in,cnt,l,r,u,v,x,y;
lli m;
string s;
vector<vm> dp;
//priority_queue < ii , vector < ii > , CMP > pq;// min priority_queue .
vm pre(lli n,lli k){
dp.clear();
dp.resize(k+1,vm(n+1));
vector<pair<mint,mint>> sum(2);
lli c=0;
for(lli i=1;i<=n;++i){
c^=i&1;
mint odd=0,even=0;
if(c){
odd=sum[0].X+1;
even=sum[1].Y;
} else {
odd=sum[1].X;
even=sum[0].Y+1;
}
dp[0][i]=(odd+even);
sum[c].X+=even;sum[c].Y+=odd;
// dp[0][i]%=mod;
// sum[c].X%=mod;sum[c].Y%=mod;
}
for(lli j=1;j<=k;++j)
{
for(lli i=1;i<=n;++i){
dp[j][i]=(dp[j-1][i]+dp[j][i-1]);
// dp[j][i]=(dp[j-1][i]+dp[j][i-1])%mod;
}
}
return dp[k];
}
int main(void) {
ios_base::sync_with_stdio(false);cin.tie(NULL);
// freopen("txt.in", "r", stdin);
// freopen("txt.out", "w", stdout);
// cout<<std::fixed<<std::setprecision(35);
const lli N=6e3+16;
T=readIntLn(1,3e3);
lli sumK = 0;
while(T--)
{
const lli n=readIntSp(1,1e18);
const lli k=readIntLn(0,3e3);
sumK+=k+1;
auto curm=pre(N,k);
vi init(N);
for(lli i=1;i<=N;++i)
init[i-1]=curm[i].val();
// dbg(init);
LinearRecurrence lr(init, mod, true);
dbg(k,sz(lr.trans),lr.trans);
cout<<lr.calc(n-1)<<endl;
} aryanc403();
assert(sumK<=3e3);
readEOF();
return 0;
}
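A cheap way to gain confidence in the f(n) DP (the tester's `pre()` uses a running-parity formulation, the editorialist's `pre()` uses the S_b(i) = S_b(i-4) + f_b(i) window trick from the explanation) is to compare it against brute force for small n. The C++ sketch below is not part of either submitted solution; the names `bruteF`, `dpF` and `windowF` are illustrative only. It uses exact 64-bit counts with no modulus, so it is meant for small n. `dpF` is the O(n^2) recurrence over the parity of the last block, and `windowF` follows the same loop shape as the editorialist's `pre()`.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using u64 = std::uint64_t;

// Exhaustive count: bit (i-1) of `mask` set means "cut between i and i+1".
// Counts partitions of [1..n] whose adjacent blocks have sums of different
// parity. Exponential in n, so only usable for small n.
u64 bruteF(int n) {
    u64 count = 0;
    for (std::uint32_t mask = 0; mask < (1u << (n - 1)); ++mask) {
        int prevParity = -1;  // parity of the previous block (-1: none yet)
        int curParity = 0;    // parity of the block being built
        bool ok = true;
        for (int i = 1; i <= n && ok; ++i) {
            curParity ^= (i & 1);
            bool blockEnds = (i == n) || ((mask >> (i - 1)) & 1u);
            if (blockEnds) {
                if (curParity == prevParity) ok = false;
                prevParity = curParity;
                curParity = 0;
            }
        }
        if (ok) ++count;
    }
    return count;
}

// O(n^2): last[b][i] = valid partitions of [1..i] whose final block sum has
// parity b. last[0][0] = last[1][0] = 1 lets the first block take either parity.
u64 dpF(int n) {
    std::vector<std::vector<u64>> last(2, std::vector<u64>(n + 1, 0));
    last[0][0] = last[1][0] = 1;
    for (int i = 1; i <= n; ++i) {
        int parity = 0;  // parity of j + (j+1) + ... + i
        for (int j = i; j >= 1; --j) {
            parity ^= (j & 1);
            last[parity][i] += last[parity ^ 1][j - 1];
        }
    }
    return last[0][n] + last[1][n];
}

// O(n): the blocks j..i and (j-4)..i differ by (j-4)+(j-3)+(j-2)+(j-1) = 4j-10,
// which is even, so s[b][i] = s[b][i-4] + f[b][i] lets us try only 4 values of j.
u64 windowF(int n) {
    std::vector<std::vector<u64>> f(2, std::vector<u64>(n + 1, 0));
    std::vector<std::vector<u64>> s(2, std::vector<u64>(n + 1, 0));
    f[0][0] = f[1][0] = 1;
    s[0][0] = s[1][0] = 1;
    for (int i = 1; i <= n; ++i) {
        int parity = 0;
        for (int j = i; j >= 1 && j > i - 4; --j) {
            parity ^= (j & 1);
            f[parity][i] += s[parity ^ 1][j - 1];
        }
        for (int b = 0; b < 2; ++b)
            s[b][i] = (i >= 4 ? s[b][i - 4] : 0) + f[b][i];
    }
    return f[0][n] + f[1][n];
}
```

All three give f(1..4) = 1, 2, 2, 3; for n = 4 the valid partitions are [1,2,3,4], [1][2][3,4] and [1][2][3][4].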
Editorialist's Solution
import java.util.*;
import java.io.*;
class PARTN01{
//SOLUTION BEGIN
final long MOD = 998244353;
final long BIG = 8*MOD*(long)MOD;
int UPTO = 6050;
long[] k0Coeff;
void pre() throws Exception{
long[][] F = new long[2][1+UPTO];
long[][] S = new long[2][1+UPTO];
F[0][0] = F[1][0] = 1;
S[0][0] = S[1][0] = 1;
for(int i = 1; i<= UPTO; i++){
int sum = 0;
for(int j = i; j > Math.max(0, i-4); j--){
sum ^= (j&1);
F[sum][i] += S[sum^1][j-1];
if(F[sum][i] >= MOD)F[sum][i] -= MOD;
}
for(int c = 0; c< 2; c++){
if(i >= 4)S[c][i] = S[c][i-4];
S[c][i] += F[c][i];
if(S[c][i] >= MOD)S[c][i] -= MOD;
}
}
k0Coeff = new long[1+UPTO];
for(int i = 1; i<= UPTO; i++){
k0Coeff[i] = F[0][i]+F[1][i];
if(k0Coeff[i] >= MOD)k0Coeff[i] -= MOD;
}
}
void solve(int TC) throws Exception{
long N = nl();
int K = ni();
int MAX = 2*K+16;
long[] coeff = Arrays.copyOf(k0Coeff, 1+MAX);
for(int i = 1; i<= K; i++){
for(int j = 1; j<= MAX; j++){
coeff[j] += coeff[j-1];
if(coeff[j] >= MOD)coeff[j] -= MOD;
}
}
if(N <= MAX){
pn(coeff[(int)N]);
return;
}
long[] tmp = new long[MAX];
for(int i = 0; i< MAX; i++)tmp[i] = coeff[i+1];
long[] rec = bm(tmp);
long[] mod = new long[1+rec.length];
for(int i = 0; i< rec.length; i++)mod[i] = rec[rec.length-1-i] == 0?0:(MOD-rec[rec.length-1-i]);
mod[rec.length] = 1;
long[] Px = pow(mod, N-1);
long ans = 0;
for(int i = 0; i< rec.length; i++){
ans += tmp[i] * Px[i]%MOD;
if(ans >= BIG)ans -= BIG;
}
pn(ans%MOD);
}
//Computes x^p mod R
long[] pow(long[] R, long p){
long[] ans = new long[]{1};
long[] mul = new long[]{0, 1};
for(; p>0; p>>=1){
if((p&1)==1){
ans = mul(ans, mul);
ans = remainder(ans, R);
}
mul = mul(mul, mul);
mul = remainder(mul, R);
}
return ans;
}
void dbg(Object... o){System.err.println(Arrays.deepToString(o));}
long[] remainder(long[] A, long[] B){
int N = A.length-1, M = B.length-1;
if(N < M)return A;
long[] Ar = new long[1+N], Br = new long[1+M];
for(int i = 0; i< A.length; i++)Ar[N-i] = A[i];
for(int i = 0; i< B.length; i++)Br[M-i] = B[i];
long[] invBr = inv(Br);
long[] D = Arrays.copyOf(mul(Ar, invBr), N-M+1);
for(int i = 0, j = D.length-1; i< j; i++, j--){
long tmp = D[i];
D[i] = D[j];
D[j] = tmp;
}
long[] BD = mul(B, D);
long[] R = new long[A.length];
for(int i = 0; i< A.length; i++)R[i] = (A[i] + MOD - BD[i]%MOD)%MOD;
for(int i = M; i< A.length; i++)assert (R[i] == 0);
return Arrays.copyOf(R, M);
}
int MAGIC = 50;
long[] mul(long[] A, long[] B){
if(Math.max(A.length, B.length) < MAGIC)return mulNaive(A, B);
return Convolution.convolution(A, B, (int)MOD);
}
long[] mulNaive(long[] A, long[] B){
long[] C = new long[A.length+B.length-1];
for(int i = 0; i< A.length; i++)
for(int j = 0; j< B.length; j++){
C[i+j] += A[i]*B[j]%MOD;
if(C[i+j] >= BIG)C[i+j] -= BIG;
}
for(int i = 0; i< C.length; i++)C[i] %= MOD;
return C;
}
long[] inv(long[] t){
long[] a = Arrays.copyOf(t, Integer.highestOneBit(t.length)<<1);
int n = a.length;
assert (n & (n - 1)) == 0;
long r[] = new long[]{inv(a[0])};
for(int len = 2; len <= n; len *= 2) {
long nr[] = new long[len];
System.arraycopy(a, 0, nr, 0, len);
nr = mul(nr, mul(r, r));
for(int i = 0; i < len; ++i) nr[i] = (MOD - nr[i]) % MOD;
for(int i = 0; i < len / 2; ++i) nr[i] = (nr[i] + 2 * r[i]) % MOD;
r = new long[len];
System.arraycopy(nr, 0, r, 0, len);
}
return r;
}
long[] bm(long[] S) throws Exception{
long[] ls = new long[0], cur = new long[0];
int lf = 0;
long ld = 0;
for(int i = 0; i< S.length; ++i){
long t = 0;
for(int j = 0; j< cur.length; ++j){
t += S[i-j-1]*(long)cur[j]%MOD;
if(t >= BIG)t -= BIG;
}
t %= MOD;
if(t == S[i])continue;//Recurrence works
if(cur.length == 0){
//First non-zero
cur = new long[i+1];
lf = i;
ld = (t+MOD-S[i])%MOD;
continue;
}
int k = (int)((MOD-S[i]+t)*inv(ld)%MOD);
long[] c = new long[Math.max(cur.length, i-lf-1+1+ls.length)];
int pos = i-lf-1;
c[pos++] = k;
for(int j = 0; j< ls.length; j++, pos++){
c[pos] += MOD-ls[j]*(long)k%MOD;
if(c[pos] >= MOD)c[pos] -= MOD;
}
for(int j = 0; j< cur.length; j++){
c[j] += cur[j];
if(c[j] >= MOD)c[j] -= MOD;
}
if(i-lf+ls.length >= cur.length){
ls = cur;
lf = i;
ld = (t+MOD-S[i])%MOD;
}
cur = c;
}
return cur;
}
long inv(long a){return pow(a, MOD-2);}
long pow(long a, long p){
long o = 1;
for(; p>0; p>>=1){
if((p&1)==1)o = o*a%MOD;
a = a*a%MOD;
}
return o;
}
//SOLUTION END
void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
static boolean multipleTC = true;
FastReader in;PrintWriter out;
void run() throws Exception{
in = new FastReader();
out = new PrintWriter(System.out);
//Solution Credits: Taranpreet Singh
int T = (multipleTC)?ni():1;
pre();for(int t = 1; t<= T; t++)solve(t);
out.flush();
out.close();
}
public static void main(String[] args) throws Exception{
new PARTN01().run();
}
int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
void p(Object o){out.print(o);}
void pn(Object o){out.println(o);}
void pni(Object o){out.println(o);out.flush();}
String n()throws Exception{return in.next();}
String nln()throws Exception{return in.nextLine();}
int ni()throws Exception{return Integer.parseInt(in.next());}
long nl()throws Exception{return Long.parseLong(in.next());}
double nd()throws Exception{return Double.parseDouble(in.next());}
class FastReader{
BufferedReader br;
StringTokenizer st;
public FastReader(){
br = new BufferedReader(new InputStreamReader(System.in));
}
public FastReader(String s) throws Exception{
br = new BufferedReader(new FileReader(s));
}
String next() throws Exception{
while (st == null || !st.hasMoreElements()){
try{
st = new StringTokenizer(br.readLine());
}catch (IOException e){
throw new Exception(e.toString());
}
}
return st.nextToken();
}
String nextLine() throws Exception{
String str = "";
try{
str = br.readLine();
}catch (IOException e){
throw new Exception(e.toString());
}
return str;
}
}
}
/**
* Convolution.
*
* @verified https://atcoder.jp/contests/practice2/tasks/practice2_f
* @verified https://judge.yosupo.jp/problem/convolution_mod_1000000007
*/
class Convolution {
/**
* Find a primitive root.
*
* @param m A prime number.
* @return Primitive root.
*/
private static int primitiveRoot(int m) {
if (m == 2) return 1;
if (m == 167772161) return 3;
if (m == 469762049) return 3;
if (m == 754974721) return 11;
if (m == 998244353) return 3;
int[] divs = new int[20];
divs[0] = 2;
int cnt = 1;
int x = (m - 1) / 2;
while (x % 2 == 0) x /= 2;
for (int i = 3; (long) (i) * i <= x; i += 2) {
if (x % i == 0) {
divs[cnt++] = i;
while (x % i == 0) {
x /= i;
}
}
}
if (x > 1) {
divs[cnt++] = x;
}
for (int g = 2; ; g++) {
boolean ok = true;
for (int i = 0; i < cnt; i++) {
if (pow(g, (m - 1) / divs[i], m) == 1) {
ok = false;
break;
}
}
if (ok) return g;
}
}
/**
* Power.
*
* @param x Parameter x.
* @param n Parameter n.
* @param m Mod.
* @return n-th power of x mod m.
*/
private static long pow(long x, long n, int m) {
if (m == 1) return 0;
long r = 1;
long y = x % m;
while (n > 0) {
if ((n & 1) != 0) r = (r * y) % m;
y = (y * y) % m;
n >>= 1;
}
return r;
}
/**
* Exponent of the smallest power of two that is at least n.
*
* @param n Value.
* @return Smallest x such that 2^x >= n.
*/
private static int ceilPow2(int n) {
int x = 0;
while ((1L << x) < n) x++;
return x;
}
/**
* Garner's algorithm.
*
* @param c Mod convolution results.
* @param mods Mods.
* @return Result.
*/
private static long garner(long[] c, int[] mods) {
int n = c.length + 1;
long[] cnst = new long[n];
long[] coef = new long[n];
java.util.Arrays.fill(coef, 1);
for (int i = 0; i < n - 1; i++) {
int m1 = mods[i];
long v = (c[i] - cnst[i] + m1) % m1;
v = v * pow(coef[i], m1 - 2, m1) % m1;
for (int j = i + 1; j < n; j++) {
long m2 = mods[j];
cnst[j] = (cnst[j] + coef[j] * v) % m2;
coef[j] = (coef[j] * m1) % m2;
}
}
return cnst[n - 1];
}
/**
* Pre-calculation for NTT.
*
* @param mod NTT Prime.
* @param g Primitive root of mod.
* @return Pre-calculation table.
*/
private static long[] sumE(int mod, int g) {
long[] sum_e = new long[30];
long[] es = new long[30];
long[] ies = new long[30];
int cnt2 = Integer.numberOfTrailingZeros(mod - 1);
long e = pow(g, (mod - 1) >> cnt2, mod);
long ie = pow(e, mod - 2, mod);
for (int i = cnt2; i >= 2; i--) {
es[i - 2] = e;
ies[i - 2] = ie;
e = e * e % mod;
ie = ie * ie % mod;
}
long now = 1;
for (int i = 0; i < cnt2 - 2; i++) {
sum_e[i] = es[i] * now % mod;
now = now * ies[i] % mod;
}
return sum_e;
}
/**
* Pre-calculation for inverse NTT.
*
* @param mod Mod.
* @param g Primitive root of mod.
* @return Pre-calculation table.
*/
private static long[] sumIE(int mod, int g) {
long[] sum_ie = new long[30];
long[] es = new long[30];
long[] ies = new long[30];
int cnt2 = Integer.numberOfTrailingZeros(mod - 1);
long e = pow(g, (mod - 1) >> cnt2, mod);
long ie = pow(e, mod - 2, mod);
for (int i = cnt2; i >= 2; i--) {
es[i - 2] = e;
ies[i - 2] = ie;
e = e * e % mod;
ie = ie * ie % mod;
}
long now = 1;
for (int i = 0; i < cnt2 - 2; i++) {
sum_ie[i] = ies[i] * now % mod;
now = now * es[i] % mod;
}
return sum_ie;
}
/**
* Inverse NTT.
*
* @param a Target array.
* @param sumIE Pre-calculation table.
* @param mod NTT Prime.
*/
private static void butterflyInv(long[] a, long[] sumIE, int mod) {
int n = a.length;
int h = ceilPow2(n);
for (int ph = h; ph >= 1; ph--) {
int w = 1 << (ph - 1), p = 1 << (h - ph);
long inow = 1;
for (int s = 0; s < w; s++) {
int offset = s << (h - ph + 1);
for (int i = 0; i < p; i++) {
long l = a[i + offset];
long r = a[i + offset + p];
a[i + offset] = (l + r) % mod;
a[i + offset + p] = (mod + l - r) * inow % mod;
}
int x = Integer.numberOfTrailingZeros(~s);
inow = inow * sumIE[x] % mod;
}
}
}
/**
* Forward NTT.
*
* @param a Target array.
* @param sumE Pre-calculation table.
* @param mod NTT Prime.
*/
private static void butterfly(long[] a, long[] sumE, int mod) {
int n = a.length;
int h = ceilPow2(n);
for (int ph = 1; ph <= h; ph++) {
int w = 1 << (ph - 1), p = 1 << (h - ph);
long now = 1;
for (int s = 0; s < w; s++) {
int offset = s << (h - ph + 1);
for (int i = 0; i < p; i++) {
long l = a[i + offset];
long r = a[i + offset + p] * now % mod;
a[i + offset] = (l + r) % mod;
a[i + offset + p] = (l - r + mod) % mod;
}
int x = Integer.numberOfTrailingZeros(~s);
now = now * sumE[x] % mod;
}
}
}
/**
* Convolution.
*
* @param a Target array 1.
* @param b Target array 2.
* @param mod NTT Prime.
* @return Answer.
*/
public static long[] convolution(long[] a, long[] b, int mod) {
int n = a.length;
int m = b.length;
if (n == 0 || m == 0) return new long[0];
int z = 1 << ceilPow2(n + m - 1);
{
long[] na = new long[z];
long[] nb = new long[z];
System.arraycopy(a, 0, na, 0, n);
System.arraycopy(b, 0, nb, 0, m);
a = na;
b = nb;
}
int g = primitiveRoot(mod);
long[] sume = sumE(mod, g);
long[] sumie = sumIE(mod, g);
butterfly(a, sume, mod);
butterfly(b, sume, mod);
for (int i = 0; i < z; i++) {
a[i] = a[i] * b[i] % mod;
}
butterflyInv(a, sumie, mod);
a = java.util.Arrays.copyOf(a, n + m - 1);
long iz = pow(z, mod - 2, mod);
for (int i = 0; i < n + m - 1; i++) a[i] = a[i] * iz % mod;
return a;
}
/**
* Convolution.
*
* @param a Target array 1.
* @param b Target array 2.
* @param mod Any mod.
* @return Answer.
*/
public static long[] convolutionLL(long[] a, long[] b, int mod) {
int n = a.length;
int m = b.length;
if (n == 0 || m == 0) return new long[0];
int mod1 = 754974721;
int mod2 = 167772161;
int mod3 = 469762049;
long[] c1 = convolution(a, b, mod1);
long[] c2 = convolution(a, b, mod2);
long[] c3 = convolution(a, b, mod3);
int retSize = c1.length;
long[] ret = new long[retSize];
int[] mods = {mod1, mod2, mod3, mod};
for (int i = 0; i < retSize; ++i) {
ret[i] = garner(new long[]{c1[i], c2[i], c3[i]}, mods);
}
return ret;
}
/**
* Naive convolution. (Complexity is O(N^2)!!)
*
* @param a Target array 1.
* @param b Target array 2.
* @param mod Mod.
* @return Answer.
*/
public static long[] convolutionNaive(long[] a, long[] b, int mod) {
int n = a.length;
int m = b.length;
int k = n + m - 1;
long[] ret = new long[k];
for (int i = 0; i < n; i++) {
for (int j = 0; j < m; j++) {
ret[i + j] += a[i] * b[j] % mod;
ret[i + j] %= mod;
}
}
return ret;
}
}
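Both solutions finish the same way: run Berlekamp-Massey on a prefix of S_K, then get the N-th term by computing x^{N-1} modulo the characteristic polynomial (the tester's `calc`, the editorialist's `pow`/`remainder`). The C++ sketch below shows that pipeline with naive O(L^2) polynomial multiplication instead of NTT, so it is only practical for a small recurrence order L; `berlekampMassey`, `mulMod` and `nthTerm` are illustrative names, not taken from either solution. It also hints at why few terms are needed: a prefix sum multiplies the characteristic polynomial by (x - 1), so each of the K prefix-sum passes raises the order by at most one, and BM needs about twice the order in terms. The editorialist's MAX = 2K + 16 therefore appears to budget for f having order at most 8.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using u64 = std::uint64_t;
const u64 MOD = 998244353;

u64 powMod(u64 a, u64 e) {
    u64 r = 1;
    a %= MOD;
    for (; e > 0; e >>= 1) {
        if (e & 1) r = r * a % MOD;
        a = a * a % MOD;
    }
    return r;
}

// Berlekamp-Massey over Z_MOD. Returns c[0..L-1] with
// s[i] = c[0]*s[i-1] + ... + c[L-1]*s[i-L] for every L <= i < s.size().
std::vector<u64> berlekampMassey(const std::vector<u64>& s) {
    std::vector<u64> cur, last;  // current recurrence, last one that failed
    std::size_t lastFail = 0;    // index at which `last` failed
    u64 lastDelta = 0;           // discrepancy observed at that index
    for (std::size_t i = 0; i < s.size(); ++i) {
        u64 t = 0;  // prediction of s[i] by `cur`
        for (std::size_t j = 0; j < cur.size(); ++j)
            t = (t + cur[j] * s[i - j - 1]) % MOD;
        u64 delta = (s[i] + MOD - t) % MOD;
        if (delta == 0) continue;
        if (cur.empty()) {  // first non-zero term
            cur.assign(i + 1, 0);
            lastFail = i;
            lastDelta = delta;
            continue;
        }
        // c = cur + (delta / lastDelta) * (shifted copy of 1, -last[0], -last[1], ...)
        u64 coef = delta * powMod(lastDelta, MOD - 2) % MOD;
        std::vector<u64> c(i - lastFail - 1, 0);
        c.push_back(coef);
        for (u64 v : last) c.push_back((MOD - v) % MOD * coef % MOD);
        if (c.size() < cur.size()) c.resize(cur.size(), 0);
        for (std::size_t j = 0; j < cur.size(); ++j) c[j] = (c[j] + cur[j]) % MOD;
        if (i - lastFail + last.size() >= cur.size()) {
            last = cur;
            lastFail = i;
            lastDelta = delta;
        }
        cur = c;
    }
    return cur;
}

// (a * b) mod (x^L - c[0] x^(L-1) - ... - c[L-1]); needs L >= 1, returns L coefficients.
std::vector<u64> mulMod(const std::vector<u64>& a, const std::vector<u64>& b,
                        const std::vector<u64>& c) {
    const int L = static_cast<int>(c.size());
    std::vector<u64> prod(std::max(a.size() + b.size(), c.size() + 1), 0);
    for (std::size_t i = 0; i < a.size(); ++i)
        for (std::size_t j = 0; j < b.size(); ++j)
            prod[i + j] = (prod[i + j] + a[i] * b[j]) % MOD;
    // Fold high powers down: x^k = sum_j c[j] * x^(k-1-j) for k >= L.
    for (int k = static_cast<int>(prod.size()) - 1; k >= L; --k)
        for (int j = 0; j < L; ++j)
            prod[k - 1 - j] = (prod[k - 1 - j] + prod[k] * c[j]) % MOD;
    prod.resize(L);
    return prod;
}

// s[n] (0-indexed) from the first L terms of s and the recurrence c,
// via x^n mod the characteristic polynomial.
u64 nthTerm(const std::vector<u64>& s, const std::vector<u64>& c, u64 n) {
    const std::size_t L = c.size();
    if (L == 0) return 0;  // BM found the all-zero sequence
    std::vector<u64> result{1}, base{0, 1};  // x^0 and x^1
    for (u64 e = n; e > 0; e >>= 1) {
        if (e & 1) result = mulMod(result, base, c);
        base = mulMod(base, base, c);
    }
    result.resize(L, 0);
    u64 ans = 0;
    for (std::size_t i = 0; i < L; ++i) ans = (ans + result[i] * s[i]) % MOD;
    return ans;
}
```

For instance, the prefix sums b_n = 2^{n+1} - 1 of the order-1 sequence 2^n come out of `berlekampMassey` with order 2, namely b_n = 3b_{n-1} - 2b_{n-2}, matching the (x - 1) factor described above.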
Feel free to share your approach. Suggestions are welcome, as always.