# XX - Editorial

Setter: Karanjeet Talwar
Tester & Editorialist: Taranpreet Singh

# DIFFICULTY

Medium-Hard

# PREREQUISITES

Expectation and Generating Functions

# PROBLEM

There are N magical stones available in a store. Each stone is put in the bag with probability p, independently of the other stones. Since the store is magical, the cost of buying exactly r stones is (A*r+B)^K for some fixed integers A, B and K.

Given N, p, A and B, you want to compute the expected cost for every K with 1 \leq K \leq K_{max}, where K_{max} is given in the input.

# EXPLANATION

Let’s first understand the expected value here. The expected value is the sum, over all possible outcomes, of the cost of each outcome weighted by its probability.

Focusing on the possible outcomes, each stone is either picked or not picked, so there are 2^N possible outcomes in total. Let’s start with an approach that tries all of them and computes the expected cost directly.

### Brute Force

In this approach, for each K we try all subsets of the N stones and sum the cost of each outcome multiplied by its probability. Let C_K denote the cost for a fixed K, and let bit(mask) denote the number of set bits in mask (the number of stones bought). It is easy to see that

\displaystyle E[C_K] = (1-p)^N*\sum_{mask = 0}^{2^N-1} (A*bit(mask)+B)^K* \bigg (\frac{p}{1-p}\bigg )^ {bit(mask)}

We can iterate over all bitmasks for each K and compute this sum in O(K_{max}*N*2^N). This won’t get any points, but it is a good start.
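
A minimal C++ sketch of the brute force, assuming (as both reference solutions below appear to do) that p is given as a residue modulo 998244353 and answers are printed modulo that prime. The names `power` and `bruteForce` are hypothetical. It multiplies p^{bit(mask)}*(1-p)^{N-bit(mask)} directly, which equals the formula above but avoids dividing by 1-p:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

const int64_t MOD = 998244353;

// Modular exponentiation b^e (mod MOD); b may be negative.
int64_t power(int64_t b, int64_t e) {
    int64_t r = 1;
    b = (b % MOD + MOD) % MOD;
    for (; e > 0; e >>= 1, b = b * b % MOD)
        if (e & 1) r = r * b % MOD;
    return r;
}

// ans[k] = E[C_k] for 1 <= k <= Kmax, enumerating all 2^N subsets of stones.
// Assumption: p is the probability as a residue modulo MOD.
std::vector<int64_t> bruteForce(int N, int Kmax, int64_t A, int64_t B, int64_t p) {
    assert(N >= 0 && N <= 20);                      // keep 2^N enumerable
    int64_t q = (1 + MOD - p % MOD) % MOD;          // 1 - p
    std::vector<int64_t> ans(Kmax + 1, 0);
    for (int k = 1; k <= Kmax; k++) {
        for (int mask = 0; mask < (1 << N); mask++) {
            int r = 0;                              // bit(mask): stones bought
            for (int i = 0; i < N; i++) r += (mask >> i) & 1;
            int64_t prob = power(p, r) * power(q, N - r) % MOD;
            int64_t cost = power(A * r + B, k);     // (A*r+B)^k
            ans[k] = (ans[k] + prob * cost) % MOD;
        }
    }
    return ans;
}
```

For example, with p = 1 every stone is bought, so `bruteForce(3, 3, 2, 1, 1)` gives 7^k for each k.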

Notice that in the above approach, the only property of a mask we care about is its number of set bits. There are exactly \binom{N}{r} masks with exactly r bits set, so we can rewrite the summation as

\displaystyle E[C_K] = (1-p)^N* \sum_{r = 0}^N \binom{N}{r}*(A*r+B)^K*\bigg (\frac{p}{1-p}\bigg )^ r

Let’s replace \displaystyle \bigg (\frac{p}{1-p}\bigg ) with g. We have

\displaystyle E[C_K] = (1-p)^N* \sum_{r = 0}^N \binom{N}{r}*(A*r+B)^K*g^ r

For each K, we can iterate over all r and compute the required expected value. This takes O(N*K_{max}) in total if the powers (A*r+B)^K are built incrementally over K, which is sufficient for subtask 1.
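
A sketch of this O(N*K_{max}) evaluation under the same assumptions (p as a residue modulo 998244353; `power` and `sumOverCounts` are hypothetical names). It loops over r on the outside so that (A*r+B)^k costs one multiplication per k, and it uses the weight \binom{N}{r}*p^r*(1-p)^{N-r}, which equals (1-p)^N*\binom{N}{r}*g^r without requiring p \neq 1:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

const int64_t MOD = 998244353;

int64_t power(int64_t b, int64_t e) {
    int64_t r = 1;
    b = (b % MOD + MOD) % MOD;
    for (; e > 0; e >>= 1, b = b * b % MOD)
        if (e & 1) r = r * b % MOD;
    return r;
}

// ans[k] = sum_r C(N,r) p^r (1-p)^(N-r) (A*r+B)^k for 1 <= k <= Kmax.
std::vector<int64_t> sumOverCounts(int N, int Kmax, int64_t A, int64_t B, int64_t p) {
    assert(N >= 0 && Kmax >= 0);
    p %= MOD;
    int64_t q = (1 + MOD - p) % MOD;                     // 1 - p
    // Factorials and inverse factorials up to N for the binomial coefficients.
    std::vector<int64_t> fact(N + 1, 1), invFact(N + 1, 1);
    for (int i = 1; i <= N; i++) fact[i] = fact[i - 1] * i % MOD;
    invFact[N] = power(fact[N], MOD - 2);
    for (int i = N; i > 0; i--) invFact[i - 1] = invFact[i] * i % MOD;
    std::vector<int64_t> ans(Kmax + 1, 0);
    for (int r = 0; r <= N; r++) {
        int64_t weight = fact[N] * invFact[r] % MOD * invFact[N - r] % MOD
                         * power(p, r) % MOD * power(q, N - r) % MOD;
        int64_t base = ((A * r + B) % MOD + MOD) % MOD;
        int64_t pw = 1;
        for (int k = 1; k <= Kmax; k++) {
            pw = pw * base % MOD;                        // (A*r+B)^k, one multiply per k
            ans[k] = (ans[k] + weight * pw) % MOD;
        }
    }
    return ans;
}
```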

### Core Idea

Now it is time for generating functions. Consider the sequence E[C_k] for all k \geq 0 and encode it using an exponential generating function.

\displaystyle P(x) = \sum_{k = 0}^{\infin} E[C_k] * \frac{x^k}{k!}
\displaystyle P(x) = \sum_{k = 0}^{\infin} (1-p)^N * \sum_{r = 0}^N \binom{N}{r}*(A*r+B)^k*g^ r * \frac{x^k}{k!}

\displaystyle P(x) = (1-p)^N * \sum_{r = 0}^N \binom{N}{r}*g^r * \sum_{k = 0}^{\infin} (A*r+B)^k* \frac{x^k}{k!}
\displaystyle P(x) = (1-p)^N * \sum_{r = 0}^N \binom{N}{r}*g^r * \sum_{k = 0}^{\infin} \frac{((A*r+B)*x)^k}{k!}

The inner summation is exactly the Taylor expansion of e^{(A*r+B)*x}, so
\displaystyle P(x) = (1-p)^N * \sum_{r = 0}^N \binom{N}{r}*g^r * e^{(A*r+B)*x}
\displaystyle P(x) = (1-p)^N *e^{B*x}* \sum_{r = 0}^N \binom{N}{r}*g^r * e^{A*r*x}
\displaystyle P(x) = (1-p)^N *e^{B*x}* \sum_{r = 0}^N \binom{N}{r}*(g*e^{A*x})^r

Now, applying the binomial theorem \sum_{r = 0}^N \binom{N}{r}*y^r = (1+y)^N with y = g*e^{A*x}, we can write
\displaystyle P(x) = (1-p)^N *e^{B*x}* (1+g*e^{A*x})^N

Since P(x) is an exponential generating function, E[C_K] is K! times the coefficient of x^K in the expansion of the above power series.
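
As a quick sanity check of the closed form, take N = 1 (a single stone, bought with probability p):

\displaystyle P(x) = (1-p)*e^{B*x}*(1+g*e^{A*x}) = (1-p)*e^{B*x} + p*e^{(A+B)*x}
\displaystyle E[C_K] = K!*[x^K]P(x) = (1-p)*B^K + p*(A+B)^K

This matches the direct computation: with probability 1-p no stone is bought (cost B^K), and with probability p one stone is bought (cost (A+B)^K).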

Hence, all we need is a way to expand e^{C*x}, and a way to find the first K_{max}+1 coefficients of the N-th power of a power series.

The expansion of e^{C*x} is well known: it is \displaystyle \sum_{k = 0}^{\infin} C^k * \frac{x^k}{k!}, and its first K_{max}+1 coefficients can be computed in O(K_{max}) time after precomputing factorials and their inverses.
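
A sketch of that precomputation, with arithmetic modulo 998244353 as in the reference solutions (`power` and `expLinear` are hypothetical names). Only one modular inverse is needed; the remaining inverse factorials follow by multiplying downward:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

const int64_t MOD = 998244353;

int64_t power(int64_t b, int64_t e) {
    int64_t r = 1;
    b = (b % MOD + MOD) % MOD;
    for (; e > 0; e >>= 1, b = b * b % MOD)
        if (e & 1) r = r * b % MOD;
    return r;
}

// Returns c[0..n] with c[k] = C^k / k! (mod MOD): the first n+1 coefficients of e^{C*x}.
std::vector<int64_t> expLinear(int64_t C, int n) {
    assert(n >= 0);
    std::vector<int64_t> fact(n + 1, 1), invFact(n + 1, 1);
    for (int i = 1; i <= n; i++) fact[i] = fact[i - 1] * i % MOD;
    invFact[n] = power(fact[n], MOD - 2);              // single modular inverse
    for (int i = n; i > 0; i--) invFact[i - 1] = invFact[i] * i % MOD;
    C = (C % MOD + MOD) % MOD;
    std::vector<int64_t> c(n + 1);
    int64_t cPow = 1;                                  // C^k
    for (int k = 0; k <= n; k++) {
        c[k] = cPow * invFact[k] % MOD;
        cPow = cPow * C % MOD;
    }
    return c;
}
```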

To find the N-th power of a polynomial, we can use the same idea as binary exponentiation:

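```
o = {1}                    // the constant polynomial 1
a = the polynomial to be raised to the N-th power
while N > 0:
    if N mod 2 == 1:
        o = mul(o, a)      // truncated to degree K_max
    a = mul(a, a)          // truncated to degree K_max
    N = N / 2              // integer division
```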


For subtask 2, naive polynomial multiplication suffices, solving the problem in O(K_{max}^2*log(N)) time.

Instead of naive polynomial multiplication, we can use the Number Theoretic Transform, as beautifully explained here, to multiply polynomials in O(K*log(K)), solving the problem in O(K*log(K)*log(N)).

Computing the N-th power now dominates the time complexity, so that is the part we need to optimize.

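Let Q(x) = 1+g*e^{A*x} denote the base. Since Q(x)^N = exp(N*ln(Q(x))), we now need to compute exp(F(x)) and ln(F(x)) modulo x^{K_{max}+1} for a power series F(x). The formal logarithm is defined only when the constant term is 1, so Q(x) is first divided by its constant term 1+g, and the result is multiplied by (1+g)^N at the end (the tester's newtonPow does exactly this). Both exp and ln can be computed using Newton’s method, as explained here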
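
Since 1+g = \frac{1}{1-p}, the (1-p)^N factor in P(x) already normalizes the base to constant term 1, so the answer series can be written as a single exp of a log:

\displaystyle (1-p)*(1+g*e^{A*x}) = (1-p) + p*e^{A*x} = 1 + p*(e^{A*x}-1)
\displaystyle P(x) = e^{B*x}*(1 + p*(e^{A*x}-1))^N = exp(B*x + N*ln(1 + p*(e^{A*x}-1)))

This appears to be what the setter's solution below does: it builds the series 1 + p*(e^{a*x}-1), takes its log, multiplies by n, adds b to the coefficient of x^1, and takes exp.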

Since both exp(F(x)) and ln(F(x)) can be computed in O(K*log(K)), this is sufficient to solve the whole problem.
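
A sketch of the Newton iterations for inverse, logarithm and exponential, plus the power a^N = c^N*exp(N*ln(a/c)) with c = a[0]. To keep it short and self-contained, products use naive truncated multiplication, so this sketch runs in roughly O(K^2) rather than O(K*log(K)); swapping in an NTT-based product gives the complexity claimed above. All names are hypothetical:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

using Poly = std::vector<int64_t>;
const int64_t MOD = 998244353;

int64_t power(int64_t b, int64_t e) {
    int64_t r = 1;
    b = (b % MOD + MOD) % MOD;
    for (; e > 0; e >>= 1, b = b * b % MOD)
        if (e & 1) r = r * b % MOD;
    return r;
}

// Naive product truncated to n coefficients (stand-in for an NTT product).
Poly mul(const Poly& a, const Poly& b, size_t n) {
    Poly c(n, 0);
    for (size_t i = 0; i < a.size() && i < n; i++)
        for (size_t j = 0; j < b.size() && i + j < n; j++)
            c[i + j] = (c[i + j] + a[i] * b[j]) % MOD;
    return c;
}

// 1/a mod x^n via Newton: b <- b*(2 - a*b), doubling the precision each step.
Poly inverse(const Poly& a, size_t n) {
    assert(!a.empty() && a[0] != 0);
    Poly b{power(a[0], MOD - 2)};
    for (size_t m = 1; m < n;) {
        m = std::min(2 * m, n);
        Poly t = mul(a, b, m);
        for (auto& x : t) x = (MOD - x) % MOD;        // -a*b
        t[0] = (t[0] + 2) % MOD;                      // 2 - a*b
        b = mul(b, t, m);
    }
    b.resize(n);
    return b;
}

// ln a mod x^n for a[0] == 1, as the integral of a'/a.
Poly logarithm(const Poly& a, size_t n) {
    assert(!a.empty() && a[0] == 1);
    Poly da(n, 0);                                    // a'
    for (size_t i = 1; i < a.size() && i <= n; i++)
        da[i - 1] = a[i] * static_cast<int64_t>(i) % MOD;
    Poly q = mul(da, inverse(a, n), n);
    Poly res(n, 0);
    for (size_t i = 1; i < n; i++)                    // integrate term by term
        res[i] = q[i - 1] * power(static_cast<int64_t>(i), MOD - 2) % MOD;
    return res;
}

// exp a mod x^n for a[0] == 0 via Newton: b <- b*(1 - ln b + a).
Poly exponential(const Poly& a, size_t n) {
    assert(a.empty() || a[0] == 0);
    Poly b{1};
    for (size_t m = 1; m < n;) {
        m = std::min(2 * m, n);
        Poly t = logarithm(b, m);
        for (size_t i = 0; i < m; i++) {
            int64_t ai = i < a.size() ? a[i] : 0;
            t[i] = (ai - t[i] + MOD) % MOD;           // a - ln b
        }
        t[0] = (t[0] + 1) % MOD;                      // 1 - ln b + a
        b = mul(b, t, m);
    }
    b.resize(n);
    return b;
}

// a^N mod x^n as c^N * exp(N * ln(a/c)) with c = a[0] != 0.
Poly powerSeriesPow(const Poly& a, int64_t N, size_t n) {
    assert(!a.empty() && a[0] != 0 && N >= 0);
    int64_t c = a[0], cInv = power(c, MOD - 2);
    Poly normalized(a.size());
    for (size_t i = 0; i < a.size(); i++)             // constant term becomes 1
        normalized[i] = a[i] * cInv % MOD;
    Poly l = logarithm(normalized, n);
    for (auto& x : l) x = x * (N % MOD) % MOD;
    Poly e = exponential(l, n);
    int64_t cN = power(c, N);
    for (auto& x : e) x = x * cN % MOD;
    return e;
}
```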

# TIME COMPLEXITY

The time complexity is O(K_{max}*log(K_{max})) per test case.

# SOLUTIONS

Setter's Solution
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <cmath>
#include <vector>
#include <set>
#include <map>
#include <unordered_set>
#include <unordered_map>
#include <queue>
#include <ctime>
#include <cassert>
#include <complex>
#include <string>
#include <cstring>
#include <chrono>
#include <random>
#include <bitset>
using namespace std;

#ifdef LOCAL
#define eprintf(...) fprintf(stderr, __VA_ARGS__);fflush(stderr);
#else
#define eprintf(...) 42
#endif

using ll = long long;
using ld = long double;
using uint = unsigned int;
using ull = unsigned long long;
template<typename T>
using pair2 = pair<T, T>;
using pii = pair<int, int>;
using pli = pair<ll, int>;
using pll = pair<ll, ll>;
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
ll myRand(ll B) {
return (ull)rng() % B;
}

#define pb push_back
#define mp make_pair
#define all(x) (x).begin(),(x).end()
#define fi first
#define se second

clock_t startTime;
double getCurrentTime() {
return (double)(clock() - startTime) / CLOCKS_PER_SEC;
}

//template
#define rep(i,a,b) for(int i=(a);i<(b);i++)
#define rrep(i,a,b) for(int i=(a);i>(b);i--)
#define ALL(v) (v).begin(),(v).end()
typedef long long int ll;
const int inf = 0x3fffffff; const ll INF = 0x1fffffffffffffff; const double eps=1e-12;
void tostr(ll x,string& res){while(x)res+=('0'+(x%10)),x/=10; reverse(ALL(res)); return;}
template<class T> inline bool chmax(T& a,T b){ if(a<b){a=b;return 1;}return 0; }
template<class T> inline bool chmin(T& a,T b){ if(a>b){a=b;return 1;}return 0; }
//end

template<unsigned mod=998244353>struct mint {
unsigned val;
static unsigned get_mod(){return mod;}
unsigned inv() const{
int tmp,a=val,b=mod,x=1,y=0;
while(b)tmp=a/b,a-=tmp*b,swap(a,b),x-=tmp*y,swap(x,y);
if(x<0)x+=mod; return x;
}
mint():val(0){}
mint(ll x):val(x>=0?x%mod:mod+(x%mod)){}
mint pow(ll t){mint res=1,b=*this; while(t){if(t&1)res*=b;b*=b;t>>=1;}return res;}
mint& operator+=(const mint& x){if((val+=x.val)>=mod)val-=mod;return *this;}
mint& operator-=(const mint& x){if((val+=mod-x.val)>=mod)val-=mod; return *this;}
mint& operator*=(const mint& x){val=ll(val)*x.val%mod; return *this;}
mint& operator/=(const mint& x){val=ll(val)*x.inv()%mod; return *this;}
mint operator+(const mint& x)const{return mint(*this)+=x;}
mint operator-(const mint& x)const{return mint(*this)-=x;}
mint operator*(const mint& x)const{return mint(*this)*=x;}
mint operator/(const mint& x)const{return mint(*this)/=x;}
bool operator==(const mint& x)const{return val==x.val;}
bool operator!=(const mint& x)const{return val!=x.val;}
};
template<unsigned mod=998244353>struct factorial {
using Mint=mint<mod>;
vector<Mint> Fact, Finv;
public:
factorial(int maxx){
Fact.resize(maxx+1),Finv.resize(maxx+1); Fact[0]=Mint(1); rep(i,0,maxx)Fact[i+1]=Fact[i]*(i+1);
Finv[maxx]=Mint(1)/Fact[maxx]; rrep(i,maxx,0)Finv[i-1]=Finv[i]*i;
}
Mint fact(int n,bool inv=0){if(inv)return Finv[n];else return Fact[n];}
Mint nPr(int n,int r){if(n<0||n<r||r<0)return Mint(0);else return Fact[n]*Finv[n-r];}
Mint nCr(int n,int r){if(n<0||n<r||r<0)return Mint(0);else return Fact[n]*Finv[r]*Finv[n-r];}
};
using Mint=mint<>;

vector<int> rt,irt;
template<unsigned mod=998244353>void init(int lg=21){
using Mint=mint<mod>; Mint prt=2;
while(prt.pow(mod>>1).val==1)prt+=1;
rt.resize(1<<lg,1); irt.resize(1<<lg,1);
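rep(w,0,lg){
int mask=(1<<w)-1,t=Mint(-1).val>>w;
Mint g=prt.pow(t),ig=prt.pow(mod-1-t);
rep(i,0,mask){
rt[mask+i+1]=(g*rt[mask+i]).val;
irt[mask+i+1]=(ig*irt[mask+i]).val;
}
}
}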

template<unsigned mod=998244353>struct FPS{
using Mint=mint<mod>; vector<Mint> f;
FPS():f({1}){}
FPS(int _n):f(_n){}
FPS(vector<Mint> _f):f(_f){}
Mint& operator[](const int i){return f[i];}
Mint eval(Mint x){
Mint res,w=1;
for(Mint v:f)res+=w*v,w*=x; return res;
}
FPS inv()const{
assert(f[0]!=0); int n=f.size();
FPS res(n); res.f[0]=f[0].inv();
for(int k=1;k<n;k<<=1){
FPS g(k*2),h(k*2);
rep(i,0,min(k*2,n))g[i]=f[i]; rep(i,0,k)h[i]=res[i];
g.ntt(); h.ntt(); rep(i,0,k*2)g[i]*=h[i]; g.ntt(1);
rep(i,0,k)g[i]=0,g[i+k]*=-1;
g.ntt(); rep(i,0,k*2)g[i]*=h[i]; g.ntt(1);
rep(i,k,min(k*2,n))res[i]=g[i];
} return res;
}
void ntt(bool inv=0){
int n=f.size(); if(n==1)return;
if(inv){
for(int i=1;i<n;i<<=1){
for(int j=0;j<n;j+=i*2){
rep(k,0,i){
f[i+j+k]*=irt[i*2-1+k];
const Mint tmp=f[j+k]-f[i+j+k];
f[j+k]+=f[i+j+k]; f[i+j+k]=tmp;
}
}
}
Mint mul=Mint(n).inv(); rep(i,0,n)f[i]*=mul;
}else{
for(int i=n>>1;i;i>>=1){
for(int j=0;j<n;j+=i*2){
rep(k,0,i){
const Mint tmp=f[j+k]-f[i+j+k];
f[j+k]+=f[i+j+k]; f[i+j+k]=tmp*rt[i*2-1+k];
}
}
}
}
}
FPS operator+(const FPS& g)const{return FPS(*this)+=g;}
FPS operator-(const FPS& g)const{return FPS(*this)-=g;}
FPS operator*(const FPS& g)const{return FPS(*this)*=g;}
template<class T>FPS operator*(T t)const{return FPS(*this)*=t;}
FPS operator/(const FPS& g)const{return FPS(*this)/=g;}
template<class T>FPS operator/(T t)const{return FPS(*this)/=t;}
FPS operator%(const FPS& g)const{return FPS(*this)%=g;}
FPS& operator+=(FPS g){
if(g.f.size()>f.size())f.resize(g.f.size());
rep(i,0,g.f.size())f[i]+=g[i]; return *this;
}
FPS& operator-=(FPS g){
if(g.f.size()>f.size())f.resize(g.f.size());
rep(i,0,g.f.size())f[i]-=g[i]; return *this;
}
FPS& operator*=(FPS g){
int m=f.size()+g.f.size()-1,n=1; while(n<m)n<<=1;
f.resize(n); g.f.resize(n);
ntt(); g.ntt(); rep(i,0,n)f[i]*=g[i];
ntt(1); f.resize(m); return *this;
}
template<class T>FPS& operator*=(T t){for(Mint& x:f)x*=t; return *this;}
FPS& operator/=(FPS g){
if(g.f.size()>f.size())return *this=FPS({0});
reverse(ALL(f)); reverse(ALL(g.f));
int n=f.size()-g.f.size()+1;
f.resize(n); g.f.resize(n); FPS mul=g.inv();
*this*=mul; f.resize(n); reverse(ALL(f)); return *this;
}
template<class T>FPS& operator/=(T t){for(Mint& x:f)x/=t; return *this;}
FPS& operator%=(FPS g){
*this-=*this/g*g;
while(!f.empty()&&f.back()==0)f.pop_back();
return *this;
}
FPS sqrt(){
int n=f.size(); FPS res(1); res[0]=1;
for(int k=1;k<n;k<<=1){
FPS ff=*this; res.f.resize(k*2);
res+=ff/res; res/=2;
} res.f.resize(n); return res;
}
FPS diff(){
FPS res=*this; rep(i,0,res.f.size()-1)res[i]=res[i+1]*(i+1);
res.f.pop_back(); return res;
}
FPS inte(){
FPS res=*this; res.f.push_back(0);
rrep(i,res.f.size()-1,0)res[i]=res[i-1]/i;
res[0]=0; return res;
}
FPS log(){
assert(f[0]==1); FPS res=diff()*inv();
res.f.resize(f.size()-1); res=res.inte(); return res;
}
FPS exp(){
assert(f[0]==0); int m=f.size(),n=1; while(n<m)n<<=1;
f.resize(n); FPS d=diff(),res(n); vector<FPS> pre;
for(int k=n;k;k>>=1){
FPS g=d; g.f.resize(k);
g.ntt(); pre.push_back(g);
}
auto dfs=[&](auto dfs,int l,int r,int dep)->void{
if(r-l==1){if(l>0)res[l]/=l; return;}
int m=(l+r)>>1; dfs(dfs,l,m,dep+1);
FPS g(r-l); rep(i,0,m-l)g[i]=res[l+i];
g.ntt(); rep(i,0,r-l)g[i]*=pre[dep][i]; g.ntt(1);
rep(i,m,r)res[i]+=g[i-l-1]; dfs(dfs,m,r,dep+1);
}; res[0]=1; dfs(dfs,0,n,0); res.f.resize(m); return res;
}
};//need to initialize

int n,t; int a[1010000];
int cnt[501000]={};
factorial<> fact(501000);
Mint inv[501000];

FPS<> substituteXplus(FPS<> A, Mint w) {
int n = (int)A.f.size();
for (int i = 0; i < n; i++)
A[i] *= fact.fact(i, 0);
vector<Mint> B(n);
Mint pw = 1;
for (int i = 0; i < n; i++) {
B[i] = pw * fact.fact(i, 1);
pw *= w;
}
reverse(all(B));
FPS<> C = A * FPS<>(B);
B = C.f;
rotate(B.begin(), B.begin() + n - 1, B.end());
B.resize(n);
for (int i = 0; i < n; i++)
B[i] *= fact.fact(i, 1);
return FPS<>(B);
}
/*
CALL INIT !!!
*/

int main()
{
startTime = clock();
//	freopen("input.txt", "r", stdin);
//	freopen("output.txt", "w", stdout);

init();
vector<Mint> A;
int n, k, a, b, p;
scanf("%d%d%d%d%d", &n, &k, &a, &b, &p);
A.resize(k + 1);
A[0] = Mint(1);
Mint cur = p;
for (int i = 1; i <= k; i++) {
cur *= Mint(a);
A[i] = cur * fact.fact(i, 1);
}
auto B = FPS<>(A);
B = B.log();
for (int i = 0; i <= k; i++)
B[i] *= Mint(n);
B[1] += Mint(b);
B = B.exp();
for (int i = 1; i <= k; i++) {
Mint x = B[i] * fact.fact(i);
printf("%u ", x.val);
}
printf("\n");

return 0;
}

Tester's Solution
import java.util.*;
import java.io.*;
class XX{
//SOLUTION BEGIN
int MOD = 998244353, GEN = 3;
int MAX = (int)1e6;
long[][] fif;
void pre() throws Exception{fif = fif(MAX);}
void solve(int TC) throws Exception{
int N = ni(), K = ni();
long a = nl(), b = nl(), p = nl();
long[] ans = subtask4(N, K, a, b, p);
StringBuilder o = new StringBuilder();
for(int i = 1; i<= K; i++)o.append(ans[i]+" ");
pn(o.toString());
}

//O(K*N*2^N) brute force
long[] brute(int N, int K, long A, long B, long P){
long[] ans = new long[1+K];
for(int k = 1; k <= K; k++){
long num = 0;
for(int mask = 0; mask < (1<<N); mask++){
int count = 0;
for(int i = 0; i< N; i++)count += (mask>>i)&1;
long cur = 1;
for(int j = 0; j<k; j++)cur = cur*(A*count+B)%MOD;
for(int i = 0; i< count; i++)cur = cur*P%MOD;
for(int j = 0; j< N-count; j++)cur = cur*(1+MOD-P)%MOD;
num += cur;
if(num >= MOD)num -= MOD;
}
ans[k] = num;
}
return ans;
}
//O(N*K*log(MOD))
long[] subtask1(int N, int K, long A, long B, long P){
long G = P*pow(1+MOD-P, MOD-2, MOD)%MOD;
long coeff = pow(1+MOD-P, N, MOD);
long[] ans = new long[1+K];
for(int k = 1; k <= K; k++){
long total = 0;
for(int r = 0; r <= N; r++){
total += C(fif, N, r)*pow((A*r+B)%MOD, k, MOD)%MOD * pow(G, r, MOD)%MOD;
if(total >= MOD)total -= MOD;
}
ans[k] = (total*coeff)%MOD;
}
return ans;
}
//O(K*log(K)*log(N))
long[] subtask3(long N, int K, long A, long B, long P){
long G = P*pow(1+MOD-P, MOD-2, MOD)%MOD;
long coeff = pow(1+MOD-P, N, MOD);

long[] C = e(fif, B, 1+K), D = e(fif, A, 1+K);
for(int i = 0; i< D.length; i++)D[i] = D[i]*G%MOD;
D[0] = (D[0]+1)%MOD;
long[] E = binaryPow(D, N, 1+K);

long[] F = mul(C, E);
for(int i = 0; i<= 1+K; i++)F[i] = coeff*F[i]%MOD*fif[0][i]%MOD;
return Arrays.copyOf(F, 1+K);
}
//O(K*log(K))
long[] subtask4(long N, int K, long A, long B, long P){
long G = P*pow(1+MOD-P, MOD-2, MOD)%MOD;
long coeff = pow(1+MOD-P, N, MOD);
long[] C = e(fif, B, 1+K), D = e(fif, A, 1+K);
for(int i = 0; i< D.length; i++)D[i] = D[i]*G%MOD;
D[0] = (D[0]+1)%MOD;
long[] E = newtonPow(D, N, 1+K);
long[] F = mul(C, E);
for(int i = 0; i<= 1+K; i++)F[i] = coeff*F[i]%MOD*fif[0][i]%MOD;
return Arrays.copyOf(F, 1+K);
}
//Polynomial Exponentiation using Newton's method
long[] newtonPow(long[] a, long p, int max){
long v = a[0];long vn = pow(v, p, MOD);
long vInv = pow(v, MOD-2, MOD);
long[] b = Arrays.copyOf(a, a.length);
for(int i = 0; i< a.length; i++)b[i] = a[i]*vInv%MOD;
long[] bLog = ln(b);
for(int i = 0; i< max; i++)bLog[i] = bLog[i]*p%MOD;
long[] bExp = exp(bLog);
for(int i = 0; i<max; i++)bExp[i] = bExp[i]*vn%MOD;
return bExp;
}
//Polynomial inverse
long[] inv(long[] a){
a = Arrays.copyOf(a, Integer.highestOneBit(a.length)<<1);
int n = a.length;
assert (n & (n - 1)) == 0;
long r[] = new long[]{pow(a[0], MOD-2, MOD)};
for(int len = 2; len <= n; len *= 2) {
long nr[] = new long[len];
System.arraycopy(a, 0, nr, 0, len);
nr = mul(nr, mul(r, r));
for(int i = 0; i < len; ++i) nr[i] = (MOD - nr[i]) % MOD;
for(int i = 0; i < len / 2; ++i) nr[i] = (nr[i] + 2 * r[i]) % MOD;
r = new long[len];
System.arraycopy(nr, 0, r, 0, len);
}
return r;//clean(r);
}
//Derivative
long[] dx(long[] a){
if(a.length == 1)return new long[]{0};
long[] b = new long[a.length-1];
for(int i = 0; i< b.length; i++)b[i] = a[i+1]*(i+1)%MOD;
return b;
}
//Integral
long[] inte(long[] a){
long[] b = new long[1+a.length];
for(int i = 1; i< b.length; i++)b[i] = a[i-1] * fif[1][i]%MOD * fif[0][i-1]%MOD;///(i);
return b;
}
//Logarithm
long[] ln(long[] a){
assert(a[0] == 1);
long[] b = mul(dx(a), inv(a));
b = Arrays.copyOf(b, a.length-1);
return inte(b);
}
//Anti-logarithm
long[] exp(long[] a){
assert(a.length == 0 || a[0] == 0);
long[] b = new long[]{1};
while(b.length < a.length){
long[] x = new long[2*b.length];
System.arraycopy(a, 0, x, 0, Math.min(a.length, 2*b.length));
x[0]++;
b = Arrays.copyOf(b, 2*b.length);
long[] p = ln(b);
for(int i = 0; i< x.length; i++)
if(i < p.length)
x[i] = (x[i]+MOD-p[i])%MOD;

x = mul(x, b);
for(int i = b.length/2; i< b.length; i++)b[i] = x[i];
}
return b;
}
//Multiply two polynomials
long[] mul(long[] a, long[] b){
return Convolution.convolution(a, b, MOD);
}
//For subtask 2 and 3
long[] binaryPow(long[] a, long n, int max){
long[] o = new long[]{1};
while(n > 0){
if((n&1)==1){
o = mul(o, a);
if(o.length > max)o = Arrays.copyOf(o, max);
}
a = mul(a, a);
if(a.length > max)a = Arrays.copyOf(a, max);
n>>=1;
}
return o;
}
//Returns the first K+1 coefficients of expansion of e^{Cx}
long[] e(long[][] fif, long C, int K){
long[] A = new long[1+K];
long cur = 1;
for(int i = 0; i<= K; i++){
A[i] = cur*fif[1][i]%MOD;
cur = cur*C%MOD;
}
return A;
}
long pow(long a, long n, long mod) {
long ret = 1;
int x = 63 - Long.numberOfLeadingZeros(n);
for (; x >= 0; x--){
ret = ret * ret % mod;
if (n << 63 - x < 0)ret = ret * a % mod;
}
return ret;
}
long inv(long x){return pow(x, MOD-2, MOD);}
long[][] fif(int mx){
mx++;
long[] F = new long[mx], IF = new long[mx];
F[0] = 1;
for(int i = 1; i< mx; i++)F[i] = (F[i-1]*i)%MOD;
//GFG
long M = MOD;
long y = 0, x = 1;
long a = F[mx-1];
while(a> 1){
long q = a/M;
long t = M;
M = a%M;
a = t;
t = y;
y = x-q*y;
x = t;
}
if(x<0)x+=MOD;
IF[mx-1] = x;
for(int i = mx-2; i>= 0; i--)IF[i] = (IF[i+1]*(i+1))%MOD;
return new long[][]{F, IF};
}
long C(long[][] fif, int n, int r){
if(n< 0 || n<r || r<0)return 0;
return (fif[0][n]*((fif[1][r]*fif[1][n-r])%MOD))%MOD;
}
static void dbg(Object... o){System.err.println(Arrays.deepToString(o));}
//SOLUTION END
void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
static boolean multipleTC = false;
FastReader in;PrintWriter out;
void run() throws Exception{
in = new FastReader();
out = new PrintWriter(System.out);
//Solution Credits: Taranpreet Singh
int T = (multipleTC)?ni():1;
pre();for(int t = 1; t<= T; t++)solve(t);
out.flush();
out.close();
}
public static void main(String[] args) throws Exception{
new XX().run();
}
int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
void p(Object o){out.print(o);}
void pn(Object o){out.println(o);}
void pni(Object o){out.println(o);out.flush();}
String n()throws Exception{return in.next();}
String nln()throws Exception{return in.nextLine();}
int ni()throws Exception{return Integer.parseInt(in.next());}
long nl()throws Exception{return Long.parseLong(in.next());}
double nd()throws Exception{return Double.parseDouble(in.next());}

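class FastReader{
BufferedReader br;
StringTokenizer st;
public FastReader(){
br = new BufferedReader(new InputStreamReader(System.in));
}

public FastReader(String s) throws Exception{
br = new BufferedReader(new FileReader(s));
}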

String next() throws Exception{
while (st == null || !st.hasMoreElements()){
try{
st = new StringTokenizer(br.readLine());
}catch (IOException  e){
throw new Exception(e.toString());
}
}
return st.nextToken();
}

String nextLine() throws Exception{
String str = "";
try{
str = br.readLine();
}catch (IOException e){
throw new Exception(e.toString());
}
return str;
}
}
}
//https://github.com/NASU41/AtCoderLibraryForJava/blob/master/Convolution/Convolution.java
/**
* Convolution.
*
* @verified https://judge.yosupo.jp/problem/convolution_mod_1000000007
*/
class Convolution {
/**
* Find a primitive root.
*
* @param m A prime number.
* @return Primitive root.
*/
private static int primitiveRoot(int m) {
if (m == 2) return 1;
if (m == 167772161) return 3;
if (m == 469762049) return 3;
if (m == 754974721) return 11;
if (m == 998244353) return 3;

int[] divs = new int[20];
divs[0] = 2;
int cnt = 1;
int x = (m - 1) / 2;
while (x % 2 == 0) x /= 2;
for (int i = 3; (long) (i) * i <= x; i += 2) {
if (x % i == 0) {
divs[cnt++] = i;
while (x % i == 0) {
x /= i;
}
}
}
if (x > 1) {
divs[cnt++] = x;
}
for (int g = 2; ; g++) {
boolean ok = true;
for (int i = 0; i < cnt; i++) {
if (pow(g, (m - 1) / divs[i], m) == 1) {
ok = false;
break;
}
}
if (ok) return g;
}
}

/**
* Power.
*
* @param x Parameter x.
* @param n Parameter n.
* @param m Mod.
* @return n-th power of x mod m.
*/
private static long pow(long x, long n, int m) {
if (m == 1) return 0;
long r = 1;
long y = x % m;
while (n > 0) {
if ((n & 1) != 0) r = (r * y) % m;
y = (y * y) % m;
n >>= 1;
}
return r;
}

/**
* Ceil of power 2.
*
* @param n Value.
* @return Ceil of power 2.
*/
private static int ceilPow2(int n) {
int x = 0;
while ((1L << x) < n) x++;
return x;
}

/**
* Garner's algorithm.
*
* @param c    Mod convolution results.
* @param mods Mods.
* @return Result.
*/
private static long garner(long[] c, int[] mods) {
int n = c.length + 1;
long[] cnst = new long[n];
long[] coef = new long[n];
java.util.Arrays.fill(coef, 1);
for (int i = 0; i < n - 1; i++) {
int m1 = mods[i];
long v = (c[i] - cnst[i] + m1) % m1;
v = v * pow(coef[i], m1 - 2, m1) % m1;

for (int j = i + 1; j < n; j++) {
long m2 = mods[j];
cnst[j] = (cnst[j] + coef[j] * v) % m2;
coef[j] = (coef[j] * m1) % m2;
}
}
return cnst[n - 1];
}

/**
* Pre-calculation for NTT.
*
* @param mod NTT Prime.
* @param g   Primitive root of mod.
* @return Pre-calculation table.
*/
private static long[] sumE(int mod, int g) {
long[] sum_e = new long[30];
long[] es = new long[30];
long[] ies = new long[30];
int cnt2 = Integer.numberOfTrailingZeros(mod - 1);
long e = pow(g, (mod - 1) >> cnt2, mod);
long ie = pow(e, mod - 2, mod);
for (int i = cnt2; i >= 2; i--) {
es[i - 2] = e;
ies[i - 2] = ie;
e = e * e % mod;
ie = ie * ie % mod;
}
long now = 1;
for (int i = 0; i < cnt2 - 2; i++) {
sum_e[i] = es[i] * now % mod;
now = now * ies[i] % mod;
}
return sum_e;
}

/**
* Pre-calculation for inverse NTT.
*
* @param mod Mod.
* @param g   Primitive root of mod.
* @return Pre-calculation table.
*/
private static long[] sumIE(int mod, int g) {
long[] sum_ie = new long[30];
long[] es = new long[30];
long[] ies = new long[30];

int cnt2 = Integer.numberOfTrailingZeros(mod - 1);
long e = pow(g, (mod - 1) >> cnt2, mod);
long ie = pow(e, mod - 2, mod);
for (int i = cnt2; i >= 2; i--) {
es[i - 2] = e;
ies[i - 2] = ie;
e = e * e % mod;
ie = ie * ie % mod;
}
long now = 1;
for (int i = 0; i < cnt2 - 2; i++) {
sum_ie[i] = ies[i] * now % mod;
now = now * es[i] % mod;
}
return sum_ie;
}

/**
* Inverse NTT.
*
* @param a     Target array.
* @param sumIE Pre-calculation table.
* @param mod   NTT Prime.
*/
private static void butterflyInv(long[] a, long[] sumIE, int mod) {
int n = a.length;
int h = ceilPow2(n);

for (int ph = h; ph >= 1; ph--) {
int w = 1 << (ph - 1), p = 1 << (h - ph);
long inow = 1;
for (int s = 0; s < w; s++) {
int offset = s << (h - ph + 1);
for (int i = 0; i < p; i++) {
long l = a[i + offset];
long r = a[i + offset + p];
a[i + offset] = (l + r) % mod;
a[i + offset + p] = (mod + l - r) * inow % mod;
}
int x = Integer.numberOfTrailingZeros(~s);
inow = inow * sumIE[x] % mod;
}
}
}

/**
* Forward NTT.
*
* @param a    Target array.
* @param sumE Pre-calculation table.
* @param mod  NTT Prime.
*/
private static void butterfly(long[] a, long[] sumE, int mod) {
int n = a.length;
int h = ceilPow2(n);

for (int ph = 1; ph <= h; ph++) {
int w = 1 << (ph - 1), p = 1 << (h - ph);
long now = 1;
for (int s = 0; s < w; s++) {
int offset = s << (h - ph + 1);
for (int i = 0; i < p; i++) {
long l = a[i + offset];
long r = a[i + offset + p] * now % mod;
a[i + offset] = (l + r) % mod;
a[i + offset + p] = (l - r + mod) % mod;
}
int x = Integer.numberOfTrailingZeros(~s);
now = now * sumE[x] % mod;
}
}
}

/**
* Convolution.
*
* @param a   Target array 1.
* @param b   Target array 2.
* @param mod NTT Prime.
*/
public static long[] convolution(long[] a, long[] b, int mod) {
int n = a.length;
int m = b.length;
if (n == 0 || m == 0) return new long[0];

int z = 1 << ceilPow2(n + m - 1);
{
long[] na = new long[z];
long[] nb = new long[z];
System.arraycopy(a, 0, na, 0, n);
System.arraycopy(b, 0, nb, 0, m);
a = na;
b = nb;
}

int g = primitiveRoot(mod);
long[] sume = sumE(mod, g);
long[] sumie = sumIE(mod, g);

butterfly(a, sume, mod);
butterfly(b, sume, mod);
for (int i = 0; i < z; i++) {
a[i] = a[i] * b[i] % mod;
}
butterflyInv(a, sumie, mod);
a = java.util.Arrays.copyOf(a, n + m - 1);

long iz = pow(z, mod - 2, mod);
for (int i = 0; i < n + m - 1; i++) a[i] = a[i] * iz % mod;
return a;
}

/**
* Convolution.
*
* @param a   Target array 1.
* @param b   Target array 2.
* @param mod Any mod.
*/
public static long[] convolutionLL(long[] a, long[] b, int mod) {
int n = a.length;
int m = b.length;
if (n == 0 || m == 0) return new long[0];

int mod1 = 754974721;
int mod2 = 167772161;
int mod3 = 469762049;

long[] c1 = convolution(a, b, mod1);
long[] c2 = convolution(a, b, mod2);
long[] c3 = convolution(a, b, mod3);

int retSize = c1.length;
long[] ret = new long[retSize];
int[] mods = {mod1, mod2, mod3, mod};
for (int i = 0; i < retSize; ++i) {
ret[i] = garner(new long[]{c1[i], c2[i], c3[i]}, mods);
}
return ret;
}

/**
* Naive convolution. (Complexity is O(N^2)!!)
*
* @param a   Target array 1.
* @param b   Target array 2.
* @param mod Mod.
*/
public static long[] convolutionNaive(long[] a, long[] b, int mod) {
int n = a.length;
int m = b.length;
int k = n + m - 1;
long[] ret = new long[k];
for (int i = 0; i < n; i++) {
for (int j = 0; j < m; j++) {
ret[i + j] += a[i] * b[j] % mod;
ret[i + j] %= mod;
}
}
return ret;
}
}


Feel free to share your approach. Suggestions are welcomed as always.


Really nice problem! Can you suggest some more problems that can be solved using generating functions?

You can try this.


Actually O(K*log(N)*log(K)) was also getting 100 points.

Yes.

I saw a couple of those pass too. The AtCoder template is quite fast. In the editorial, I listed the approaches we intended to pass each subtask.
