XX - Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Practice

Setter: Karanjeet Talwar
Tester & Editorialist: Taranpreet Singh

DIFFICULTY

Medium-Hard

PREREQUISITES

Expectation and Generating Functions

PROBLEM

There are N magical stones available in a store. For each stone, there is a probability p that you put it in the bag, independently of the other stones. Since the store is magical, the cost of buying exactly r stones is (A*r+B)^K for some fixed integers A, B and K.

Given N, p, A and B, you want to compute the expected cost for each K such that 1 \leq K \leq K_{max}, where K_{max} is given in the input.

EXPLANATION

Let’s first understand the expected value here. The expected value is the sum, over all possible outcomes, of the cost of each outcome weighted by its probability.

Looking at the possible outcomes, each stone is either picked or not picked, so there are 2^N possible outcomes in total. Let’s start with an approach that tries all of them and computes the expected cost directly.

Brute Force

In this approach, for each K we try all subsets of the N stones and sum the cost of each outcome multiplied by its probability. Let C_K denote the cost for a fixed K. It is easy to see that

\displaystyle E[C_K] = \sum_{mask = 0}^{2^N-1} (A*bit(mask)+B)^K*p^{bit(mask)}*(1-p)^{N-bit(mask)}, where bit(mask) denotes the number of ones in mask.

Factoring out (1-p)^N (this assumes p \neq 1, so that 1-p is invertible), we get

\displaystyle E[C_K] = (1-p)^N*\sum_{mask = 0}^{2^N-1} (A*bit(mask)+B)^K* \bigg (\frac{p}{1-p}\bigg )^ {bit(mask)}

We can iterate over all bitmasks for each K and compute this sum in O(K_{max}*N*2^N). This won’t get any points, but it is a good start.
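The brute force above can be sketched in C++ as follows. This is a minimal illustration, not either reference solution. It assumes, as both reference solutions do, that all arithmetic is modulo 998244353 and that p is supplied as a residue modulo that prime. The name bruteForce and its signature are hypothetical. It reuses (A*r+B)^K incrementally across K, so it runs in roughly O(2^N*(N+K_{max})) instead of O(K_{max}*N*2^N).

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using i64 = std::int64_t;
const i64 MOD = 998244353;  // modulus used by both reference solutions

// Reduce x into [0, MOD).
i64 norm(i64 x) { x %= MOD; return x < 0 ? x + MOD : x; }

// b^e mod MOD by repeated squaring (e >= 0).
i64 power(i64 b, i64 e) {
    i64 r = 1;
    b = norm(b);
    while (e > 0) {
        if (e & 1) r = r * b % MOD;
        b = b * b % MOD;
        e >>= 1;
    }
    return r;
}

// Hypothetical helper: ans[K] = E[C_K] mod MOD for 1 <= K <= Kmax (ans[0] unused).
// p is assumed to be the probability given as a residue mod MOD.
std::vector<i64> bruteForce(int N, int Kmax, i64 A, i64 B, i64 p) {
    assert(N >= 0 && N <= 20);  // 2^N outcomes must stay enumerable
    p = norm(p);
    const i64 q = norm(1 - p);  // probability that a stone is NOT picked
    std::vector<i64> ans(Kmax + 1, 0);
    for (int mask = 0; mask < (1 << N); mask++) {
        int r = 0;  // bit(mask): number of picked stones
        for (int i = 0; i < N; i++) r += (mask >> i) & 1;
        // Probability of this exact outcome: p^r * (1-p)^(N-r).
        const i64 prob = power(p, r) * power(q, N - r) % MOD;
        const i64 base = norm(norm(A) * r + norm(B));  // A*r + B
        i64 term = prob;
        for (int K = 1; K <= Kmax; K++) {
            term = term * base % MOD;  // prob * (A*r+B)^K
            ans[K] = (ans[K] + term) % MOD;
        }
    }
    return ans;
}
```

For example, with N = 2, A = B = 1 and p = 1/2 (the residue 499122177), the outcomes cost 1, 2, 2, 3 with probability 1/4 each. The function therefore returns E[C_1] = 2 and E[C_2] = 9/2, which is 499122181 modulo 998244353.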

Subtask 1

In the above approach, all we care about for each mask is the number of 1s in it, and there are exactly \binom{N}{r} masks with exactly r bits set. Hence, we can rewrite the summation as

\displaystyle E[C_K] = (1-p)^N* \sum_{r = 0}^N \binom{N}{r}*(A*r+B)^K*\bigg (\frac{p}{1-p}\bigg )^ r

Let’s replace \displaystyle \bigg (\frac{p}{1-p}\bigg ) with g. We have

\displaystyle E[C_K] = (1-p)^N* \sum_{r = 0}^N \binom{N}{r}*(A*r+B)^K*g^ r

We can iterate over all r and, for each r, over all K, updating (A*r+B)^K incrementally as K grows. This computes all the required expected values in O(N*K_{max}), which is sufficient for subtask 1.
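A sketch of the O(N*K_{max}) idea for subtask 1, again assuming arithmetic modulo 998244353 with p given as a residue. The name subtask1 is hypothetical. There is one small deviation from the formula above: each r is weighted by \binom{N}{r}*p^r*(1-p)^{N-r} directly. This equals (1-p)^N*\binom{N}{r}*g^r whenever p \neq 1, but it needs no inverse of 1-p, so p = 1 also works.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using i64 = std::int64_t;
const i64 MOD = 998244353;

// Reduce x into [0, MOD).
i64 norm(i64 x) { x %= MOD; return x < 0 ? x + MOD : x; }

// b^e mod MOD by repeated squaring (e >= 0).
i64 power(i64 b, i64 e) {
    i64 r = 1;
    b = norm(b);
    while (e > 0) {
        if (e & 1) r = r * b % MOD;
        b = b * b % MOD;
        e >>= 1;
    }
    return r;
}

// Hypothetical helper: ans[K] = E[C_K] mod MOD for 1 <= K <= Kmax, grouping masks by r.
std::vector<i64> subtask1(int N, int Kmax, i64 A, i64 B, i64 p) {
    assert(N >= 0 && Kmax >= 0);
    p = norm(p);
    const i64 q = norm(1 - p);
    // Factorials and inverse factorials up to N, for C(N, r).
    std::vector<i64> fact(N + 1, 1), invFact(N + 1, 1);
    for (int i = 1; i <= N; i++) fact[i] = fact[i - 1] * i % MOD;
    invFact[N] = power(fact[N], MOD - 2);
    for (int i = N; i > 0; i--) invFact[i - 1] = invFact[i] * i % MOD;

    std::vector<i64> ans(Kmax + 1, 0);
    for (int r = 0; r <= N; r++) {
        // C(N, r) * p^r * (1-p)^(N-r), i.e. (1-p)^N * C(N, r) * g^r when p != 1.
        const i64 binom = fact[N] * invFact[r] % MOD * invFact[N - r] % MOD;
        const i64 weight = binom * power(p, r) % MOD * power(q, N - r) % MOD;
        const i64 base = norm(norm(A) * r + norm(B));  // A*r + B
        i64 term = weight;
        for (int K = 1; K <= Kmax; K++) {
            term = term * base % MOD;  // weight * (A*r+B)^K
            ans[K] = (ans[K] + term) % MOD;
        }
    }
    return ans;
}
```

With p = 1 (residue 1), every stone is bought. So subtask1(4, 2, 1, 1, 1) yields E[C_1] = 5 and E[C_2] = 25.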

Core Idea

Now it is time for generating functions. Let’s consider the sequence E[C_k] for all k \geq 0 and encode it using an exponential generating function:

\displaystyle P(x) = \sum_{k = 0}^{\infin} E[C_k] * \frac{x^k}{k!}
\displaystyle P(x) = \sum_{k = 0}^{\infin} (1-p)^N * \sum_{r = 0}^N \binom{N}{r}*(A*r+B)^k*g^ r * \frac{x^k}{k!}

\displaystyle P(x) = (1-p)^N * \sum_{r = 0}^N \binom{N}{r}*g^r * \sum_{k = 0}^{\infin} (A*r+B)^k* \frac{x^k}{k!}
\displaystyle P(x) = (1-p)^N * \sum_{r = 0}^N \binom{N}{r}*g^r * \sum_{k = 0}^{\infin} \frac{((A*r+B)*x)^k}{k!}

The inner summation is just the expansion of e^{(A*r+B)*x}, so
\displaystyle P(x) = (1-p)^N * \sum_{r = 0}^N \binom{N}{r}*g^r * e^{(A*r+B)*x}
\displaystyle P(x) = (1-p)^N *e^{B*x}* \sum_{r = 0}^N \binom{N}{r}*g^r * e^{A*r*x}
\displaystyle P(x) = (1-p)^N *e^{B*x}* \sum_{r = 0}^N \binom{N}{r}*(g*e^{A*x})^r

Now, using the binomial theorem \sum_{r=0}^{N} \binom{N}{r}*y^r = (1+y)^N with y = g*e^{A*x}, we can write
\displaystyle P(x) = (1-p)^N *e^{B*x}* (1+g*e^{A*x})^N

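Since P(x) is an exponential generating function, E[C_K] is K! times the coefficient of x^K in the expansion of the above power series. This is why both solutions below multiply the extracted coefficients by K! at the end.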
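As a quick check of the closed form, take a single stone (N = 1). Because P(x) is an exponential generating function, E[C_K] is read off as K! times the coefficient of x^K.

```latex
% Sanity check with a single stone (N = 1), where g = p/(1-p):
P(x) = (1-p)\,e^{Bx}\,\bigl(1 + g\,e^{Ax}\bigr) = (1-p)\,e^{Bx} + p\,e^{(A+B)x}
% Reading off K! times the coefficient of x^K:
E[C_K] = K!\,[x^K]\,P(x) = (1-p)\,B^K + p\,(A+B)^K
```

This matches the direct computation. With probability 1-p no stone is bought, at cost B^K. With probability p one stone is bought, at cost (A+B)^K.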

Hence, all we need is a way to expand e^{C*x}, and a way to find the first K_{max}+1 coefficients of the N-th power of a power series.

The expansion of e^{C*x} is well known: it is \displaystyle \sum_{k = 0}^{\infin} C^k * \frac{x^k}{k!}, and its first K_{max}+1 coefficients can be computed in O(K_{max}) time after precomputing factorials and their inverses.
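A sketch of computing the first K+1 coefficients C^k/k! of e^{C*x}, assuming arithmetic modulo 998244353 and K smaller than the modulus. The name expCx is hypothetical. The inverse factorials come from a single modular inverse of K! followed by a downward pass, so only one O(log MOD) exponentiation is needed on top of the O(K) loops.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using i64 = std::int64_t;
const i64 MOD = 998244353;

// b^e mod MOD by repeated squaring (e >= 0).
i64 power(i64 b, i64 e) {
    i64 r = 1;
    b %= MOD;
    if (b < 0) b += MOD;
    while (e > 0) {
        if (e & 1) r = r * b % MOD;
        b = b * b % MOD;
        e >>= 1;
    }
    return r;
}

// Hypothetical helper: coef[k] = C^k / k! (mod MOD) for 0 <= k <= K.
std::vector<i64> expCx(i64 C, int K) {
    assert(K >= 0);
    C %= MOD;
    if (C < 0) C += MOD;
    // Inverse factorials: invert K! once, then invFact[i-1] = invFact[i] * i.
    std::vector<i64> invFact(K + 1, 1);
    i64 f = 1;
    for (int i = 1; i <= K; i++) f = f * i % MOD;
    invFact[K] = power(f, MOD - 2);
    for (int i = K; i > 0; i--) invFact[i - 1] = invFact[i] * i % MOD;

    std::vector<i64> coef(K + 1);
    i64 pw = 1;  // C^k
    for (int k = 0; k <= K; k++) {
        coef[k] = pw * invFact[k] % MOD;
        pw = pw * C % MOD;
    }
    return coef;
}
```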

Subtask 2

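To find the N-th power of a polynomial, we can use the same idea as binary exponentiation. Every product is truncated to its first K_{max}+1 coefficients, since higher powers of x never influence the lower coefficients: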

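o = {1}                       // the constant polynomial 1
a = the polynomial to raise to the N-th power
while N > 0:
    if N mod 2 == 1:
        o = mul(o, a)         // keep only the first K_max+1 coefficients
    a = mul(a, a)             // keep only the first K_max+1 coefficients
    N = N div 2               // integer division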

For this subtask, naive polynomial multiplication (truncated to K+1 terms) solves the problem in O(K^2*log(N)) time.
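Putting the pieces together for subtask 2, here is a naive sketch. The names mulTrunc, powTrunc and subtask2 are hypothetical, and arithmetic modulo 998244353 is assumed. It follows the setter's construction rather than the g form: it raises (1-p) + p*e^{A*x} to the N-th power. Since (1-p)*(1+g*e^{A*x}) = (1-p) + p*e^{A*x}, this gives the same P(x), and the base has constant term exactly 1. Every product is truncated to K_{max}+1 terms, giving O(K_{max}^2*log(N)).

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using i64 = std::int64_t;
using Poly = std::vector<i64>;
const i64 MOD = 998244353;

i64 norm(i64 x) { x %= MOD; return x < 0 ? x + MOD : x; }

i64 power(i64 b, i64 e) {  // b^e mod MOD, e >= 0
    i64 r = 1;
    b = norm(b);
    for (; e > 0; e >>= 1, b = b * b % MOD)
        if (e & 1) r = r * b % MOD;
    return r;
}

// Naive product of a and b, keeping only the first len coefficients: O(len^2).
Poly mulTrunc(const Poly& a, const Poly& b, int len) {
    Poly c(len, 0);
    for (int i = 0; i < (int)a.size() && i < len; i++)
        for (int j = 0; j < (int)b.size() && i + j < len; j++)
            c[i + j] = (c[i + j] + a[i] * b[j]) % MOD;
    return c;
}

// a^n truncated to len coefficients, by binary exponentiation.
Poly powTrunc(Poly a, i64 n, int len) {
    Poly o(len, 0);
    o[0] = 1;  // the constant polynomial 1
    while (n > 0) {
        if (n & 1) o = mulTrunc(o, a, len);
        a = mulTrunc(a, a, len);
        n >>= 1;
    }
    return o;
}

// Hypothetical helper: ans[K] = K! * [x^K] ((1-p) + p*e^{Ax})^N * e^{Bx}, 1 <= K <= Kmax.
std::vector<i64> subtask2(i64 N, int Kmax, i64 A, i64 B, i64 p) {
    assert(N >= 0 && Kmax >= 0);
    const int len = Kmax + 1;
    std::vector<i64> fact(len, 1), invFact(len, 1);
    for (int i = 1; i < len; i++) fact[i] = fact[i - 1] * i % MOD;
    invFact[len - 1] = power(fact[len - 1], MOD - 2);
    for (int i = len - 1; i > 0; i--) invFact[i - 1] = invFact[i] * i % MOD;

    p = norm(p);
    Poly base(len), eB(len);
    i64 powA = 1, powB = 1;  // A^k and B^k
    for (int k = 0; k < len; k++) {
        base[k] = p * powA % MOD * invFact[k] % MOD;  // p * A^k / k!
        eB[k] = powB * invFact[k] % MOD;              // B^k / k!
        powA = powA * norm(A) % MOD;
        powB = powB * norm(B) % MOD;
    }
    base[0] = 1;  // constant term: (1 - p) + p * A^0 / 0! = 1

    const Poly P = mulTrunc(powTrunc(base, N, len), eB, len);
    std::vector<i64> ans(len, 0);
    for (int K = 1; K < len; K++) ans[K] = P[K] * fact[K] % MOD;  // K! * [x^K]
    return ans;
}
```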

Subtask 3

Instead of naive polynomial multiplication, we can use the Number Theoretic Transform, as beautifully explained here, to multiply polynomials in O(K*log(K)), solving the problem in O(K*log(K)*log(N)).

Final subtask

Computing the N-th power dominates the time complexity, so we need to optimize it.

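Since P^N(x) = exp(N*ln(P(x))) for a power series P(x) with constant term 1, we now need ways to compute exp(P(x)) and ln(P(x)) for a power series P(x). Both can be done using Newton’s method, as explained here. If the constant term is some nonzero c \neq 1, we first divide by it and use P^N(x) = c^N*exp(N*ln(P(x)/c)); the tester’s newtonPow does exactly this. The setter avoids the division by working with ((1-p)+p*e^{A*x})^N*e^{B*x}, which equals the P(x) derived above and whose base already has constant term 1. The setter takes the logarithm of that base, multiplies it by N, adds B*x and applies a single exp.

Since both exp(P(x)) and ln(P(x)) can be computed in O(K*log(K)), this is sufficient to solve the whole problem.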

TIME COMPLEXITY

The time complexity is O(K_{max}*log(K_{max})) per test case.

SOLUTIONS

Setter's Solution
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <cmath>
#include <vector>
#include <set>
#include <map>
#include <unordered_set>
#include <unordered_map>
#include <queue>
#include <ctime>
#include <cassert>
#include <complex>
#include <string>
#include <cstring>
#include <chrono>
#include <random>
#include <bitset>
using namespace std;
 
#ifdef LOCAL
    #define eprintf(...) fprintf(stderr, __VA_ARGS__);fflush(stderr);
#else
    #define eprintf(...) 42
#endif
 
using ll = long long;
using ld = long double;
using uint = unsigned int;
using ull = unsigned long long;
template<typename T>
using pair2 = pair<T, T>;
using pii = pair<int, int>;
using pli = pair<ll, int>;
using pll = pair<ll, ll>;
mt19937_64 rng(chrono::steady_clock::now().time_since_epoch().count());
ll myRand(ll B) {
    return (ull)rng() % B;
}
 
#define pb push_back
#define mp make_pair
#define all(x) (x).begin(),(x).end()
#define fi first
#define se second
 
clock_t startTime;
double getCurrentTime() {
    return (double)(clock() - startTime) / CLOCKS_PER_SEC;
}
 
 
//template
#define rep(i,a,b) for(int i=(a);i<(b);i++)
#define rrep(i,a,b) for(int i=(a);i>(b);i--)
#define ALL(v) (v).begin(),(v).end()
typedef long long int ll;
const int inf = 0x3fffffff; const ll INF = 0x1fffffffffffffff; const double eps=1e-12;
void tostr(ll x,string& res){while(x)res+=('0'+(x%10)),x/=10; reverse(ALL(res)); return;}
template<class T> inline bool chmax(T& a,T b){ if(a<b){a=b;return 1;}return 0; }
template<class T> inline bool chmin(T& a,T b){ if(a>b){a=b;return 1;}return 0; }
//end
 
template<unsigned mod=998244353>struct mint {
   unsigned val;
   static unsigned get_mod(){return mod;}
   unsigned inv() const{
      int tmp,a=val,b=mod,x=1,y=0;
      while(b)tmp=a/b,a-=tmp*b,swap(a,b),x-=tmp*y,swap(x,y);
      if(x<0)x+=mod; return x;
   }
   mint():val(0){}
   mint(ll x):val((x%(ll)mod+mod)%mod){}
   mint pow(ll t){mint res=1,b=*this; while(t){if(t&1)res*=b;b*=b;t>>=1;}return res;}
   mint& operator+=(const mint& x){if((val+=x.val)>=mod)val-=mod;return *this;}
   mint& operator-=(const mint& x){if((val+=mod-x.val)>=mod)val-=mod; return *this;}
   mint& operator*=(const mint& x){val=ll(val)*x.val%mod; return *this;}
   mint& operator/=(const mint& x){val=ll(val)*x.inv()%mod; return *this;}
   mint operator+(const mint& x)const{return mint(*this)+=x;}
   mint operator-(const mint& x)const{return mint(*this)-=x;}
   mint operator*(const mint& x)const{return mint(*this)*=x;}
   mint operator/(const mint& x)const{return mint(*this)/=x;}
   bool operator==(const mint& x)const{return val==x.val;}
   bool operator!=(const mint& x)const{return val!=x.val;}
};
template<unsigned mod=998244353>struct factorial {
   using Mint=mint<mod>;
   vector<Mint> Fact, Finv;
public:
   factorial(int maxx){
      Fact.resize(maxx+1),Finv.resize(maxx+1); Fact[0]=Mint(1); rep(i,0,maxx)Fact[i+1]=Fact[i]*(i+1);
      Finv[maxx]=Mint(1)/Fact[maxx]; rrep(i,maxx,0)Finv[i-1]=Finv[i]*i;
   }
   Mint fact(int n,bool inv=0){if(inv)return Finv[n];else return Fact[n];}
   Mint nPr(int n,int r){if(n<0||n<r||r<0)return Mint(0);else return Fact[n]*Finv[n-r];}
   Mint nCr(int n,int r){if(n<0||n<r||r<0)return Mint(0);else return Fact[n]*Finv[r]*Finv[n-r];}
};
using Mint=mint<>;
 
vector<int> rt,irt;
template<unsigned mod=998244353>void init(int lg=21){
   using Mint=mint<mod>; Mint prt=2;
   while(prt.pow(mod>>1).val==1)prt+=1;
   rt.resize(1<<lg,1); irt.resize(1<<lg,1);
   rep(w,0,lg){
      int mask=(1<<w)-1,t=Mint(-1).val>>w;
      Mint g=prt.pow(t),ig=prt.pow(mod-1-t);
      rep(i,0,mask){
         rt[mask+i+1]=(g*rt[mask+i]).val;
         irt[mask+i+1]=(ig*irt[mask+i]).val;
      }
   }
}
 
template<unsigned mod=998244353>struct FPS{
   using Mint=mint<mod>; vector<Mint> f;
   FPS():f({1}){}
   FPS(int _n):f(_n){}
   FPS(vector<Mint> _f):f(_f){}
   Mint& operator[](const int i){return f[i];}
   Mint eval(Mint x){
      Mint res,w=1;
      for(Mint v:f)res+=w*v,w*=x; return res;
   }
   FPS inv()const{
      assert(f[0]!=0); int n=f.size();
      FPS res(n); res.f[0]=f[0].inv();
      for(int k=1;k<n;k<<=1){
         FPS g(k*2),h(k*2);
         rep(i,0,min(k*2,n))g[i]=f[i]; rep(i,0,k)h[i]=res[i];
         g.ntt(); h.ntt(); rep(i,0,k*2)g[i]*=h[i]; g.ntt(1);
         rep(i,0,k)g[i]=0,g[i+k]*=-1;
         g.ntt(); rep(i,0,k*2)g[i]*=h[i]; g.ntt(1);
         rep(i,k,min(k*2,n))res[i]=g[i];
      } return res;
   }
   void ntt(bool inv=0){
        int n=f.size(); if(n==1)return;
        if(inv){
            for(int i=1;i<n;i<<=1){
                for(int j=0;j<n;j+=i*2){
                    rep(k,0,i){
                        f[i+j+k]*=irt[i*2-1+k];
                        const Mint tmp=f[j+k]-f[i+j+k];
                        f[j+k]+=f[i+j+k]; f[i+j+k]=tmp;
                    }
                }
            }
            Mint mul=Mint(n).inv(); rep(i,0,n)f[i]*=mul;
        }else{
            for(int i=n>>1;i;i>>=1){
                for(int j=0;j<n;j+=i*2){
                    rep(k,0,i){
                        const Mint tmp=f[j+k]-f[i+j+k];
                        f[j+k]+=f[i+j+k]; f[i+j+k]=tmp*rt[i*2-1+k];
                    }
                }
            }
        }
   }
   FPS operator+(const FPS& g)const{return FPS(*this)+=g;}
   FPS operator-(const FPS& g)const{return FPS(*this)-=g;}
   FPS operator*(const FPS& g)const{return FPS(*this)*=g;}
   template<class T>FPS operator*(T t)const{return FPS(*this)*=t;}
   FPS operator/(const FPS& g)const{return FPS(*this)/=g;}
   template<class T>FPS operator/(T t)const{return FPS(*this)/=t;}
   FPS operator%(const FPS& g)const{return FPS(*this)%=g;}
   FPS& operator+=(FPS g){
      if(g.f.size()>f.size())f.resize(g.f.size());
      rep(i,0,g.f.size())f[i]+=g[i]; return *this;
   }
   FPS& operator-=(FPS g){
      if(g.f.size()>f.size())f.resize(g.f.size());
      rep(i,0,g.f.size())f[i]-=g[i]; return *this;
   }
   FPS& operator*=(FPS g){
      int m=f.size()+g.f.size()-1,n=1; while(n<m)n<<=1;
      f.resize(n); g.f.resize(n);
      ntt(); g.ntt(); rep(i,0,n)f[i]*=g[i]; 
      ntt(1); f.resize(m); return *this;
   }
   template<class T>FPS& operator*=(T t){for(Mint& x:f)x*=t; return *this;}
   FPS& operator/=(FPS g){
      if(g.f.size()>f.size())return *this=FPS({0});
      reverse(ALL(f)); reverse(ALL(g.f));
      int n=f.size()-g.f.size()+1;
      f.resize(n); g.f.resize(n); FPS mul=g.inv();
      *this*=mul; f.resize(n); reverse(ALL(f)); return *this;
   }
   template<class T>FPS& operator/=(T t){for(Mint& x:f)x/=t; return *this;}
   FPS& operator%=(FPS g){
      *this-=*this/g*g;
      while(!f.empty()&&f.back()==0)f.pop_back();
      return *this;
   }
   FPS sqrt(){
      int n=f.size(); FPS res(1); res[0]=1;
      for(int k=1;k<n;k<<=1){
         FPS ff=*this; res.f.resize(k*2);
         res+=ff/res; res/=2;
      } res.f.resize(n); return res;
   }
   FPS diff(){
      FPS res=*this; rep(i,0,res.f.size()-1)res[i]=res[i+1]*(i+1);
      res.f.pop_back(); return res;
   }
   FPS inte(){
      FPS res=*this; res.f.push_back(0);
      rrep(i,res.f.size()-1,0)res[i]=res[i-1]/i;
      res[0]=0; return res;
   }
   FPS log(){
      assert(f[0]==1); FPS res=diff()*inv(); 
      res.f.resize(f.size()-1); res=res.inte(); return res;
   }
   FPS exp(){
      assert(f[0]==0); int m=f.size(),n=1; while(n<m)n<<=1;
      f.resize(n); FPS d=diff(),res(n); vector<FPS> pre;
      for(int k=n;k;k>>=1){
         FPS g=d; g.f.resize(k);
         g.ntt(); pre.push_back(g);
      }
      auto dfs=[&](auto dfs,int l,int r,int dep)->void{
         if(r-l==1){if(l>0)res[l]/=l; return;}
         int m=(l+r)>>1; dfs(dfs,l,m,dep+1);
         FPS g(r-l); rep(i,0,m-l)g[i]=res[l+i];
         g.ntt(); rep(i,0,r-l)g[i]*=pre[dep][i]; g.ntt(1);
         rep(i,m,r)res[i]+=g[i-l-1]; dfs(dfs,m,r,dep+1);
      }; res[0]=1; dfs(dfs,0,n,0); res.f.resize(m); return res;
   }
};//need to initialize
 
int n,t; int a[1010000];
int cnt[501000]={};
factorial<> fact(501000);
Mint inv[501000];
 
FPS<> substituteXplus(FPS<> A, Mint w) {
   int n = (int)A.f.size();
   for (int i = 0; i < n; i++)
      A[i] *= fact.fact(i, 0);
   vector<Mint> B(n);
   Mint pw = 1;
   for (int i = 0; i < n; i++) {
      B[i] = pw * fact.fact(i, 1);
      pw *= w;
   }
   reverse(all(B));
   FPS<> C = A * FPS<>(B);
   B = C.f;
   rotate(B.begin(), B.begin() + n - 1, B.end());
   B.resize(n);
   for (int i = 0; i < n; i++)
      B[i] *= fact.fact(i, 1);
   return FPS<>(B);
}
/*
   CALL INIT !!!
*/
 
 
int main()
{
    startTime = clock();
//	freopen("input.txt", "r", stdin);
//	freopen("output.txt", "w", stdout);
 
    init();
    vector<Mint> A;
    int n, k, a, b, p;
    scanf("%d%d%d%d%d", &n, &k, &a, &b, &p);
    A.resize(k + 1);
    A[0] = Mint(1);
    Mint cur = p;
    for (int i = 1; i <= k; i++) {
	    cur *= Mint(a);
	    A[i] = cur * fact.fact(i, 1);
    }
    auto B = FPS<>(A);
    B = B.log();
    for (int i = 0; i <= k; i++)
	    B[i] *= Mint(n);
    B[1] += Mint(b);
    B = B.exp();
    for (int i = 1; i <= k; i++) {
	    Mint x = B[i] * fact.fact(i);
	    printf("%u ", x.val);
    }
    printf("\n");
 
    return 0;
}
Tester's Solution
import java.util.*;
import java.io.*;
class XX{
    //SOLUTION BEGIN
    int MOD = 998244353, GEN = 3;
    int MAX = (int)1e6;
    long[][] fif;
    void pre() throws Exception{fif = fif(MAX);}
    void solve(int TC) throws Exception{
        int N = ni(), K = ni();
        long a = nl(), b = nl(), p = nl();
        long[] ans = subtask4(N, K, a, b, p);
        StringBuilder o = new StringBuilder();
        for(int i = 1; i<= K; i++)o.append(ans[i]+" ");
        pn(o.toString());
    }
    
    //O(K*N*2^N) brute force
    long[] brute(int N, int K, long A, long B, long P){
        long[] ans = new long[1+K];
        for(int k = 1; k <= K; k++){
            long num = 0;
            for(int mask = 0; mask< 1<<N; mask++){
                int count = 0;
                for(int i = 0; i< N; i++)count += (mask>>i)&1;
                long cur = 1;
                for(int j = 0; j<k; j++)cur = cur*(A*count+B)%MOD;
                for(int i = 0; i< count; i++)cur = cur*P%MOD;
                for(int j = 0; j< N-count; j++)cur = cur*(1+MOD-P)%MOD;
                num += cur;
                if(num >= MOD)num -= MOD;
            }
            ans[k] = num;
        }
        return ans;
    }
    //O(N*K*log(MOD))
    long[] subtask1(int N, int K, long A, long B, long P){
        long G = P*pow(1+MOD-P, MOD-2, MOD)%MOD;
        long coeff = pow(1+MOD-P, N, MOD);
        long[] ans = new long[1+K];
        for(int k = 1; k <= K; k++){
            long total = 0;
            for(int r = 0; r <= N; r++){
                total += C(fif, N, r)*pow((A*r+B)%MOD, k, MOD)%MOD * pow(G, r, MOD)%MOD;
                if(total >= MOD)total -= MOD;
            }
            ans[k] = (total*coeff)%MOD;
        }
        return ans;
    }
    //O(K*log(K)*log(N))
    long[] subtask3(long N, int K, long A, long B, long P){
        long G = P*pow(1+MOD-P, MOD-2, MOD)%MOD;
        long coeff = pow(1+MOD-P, N, MOD);
        
        long[] C = e(fif, B, 1+K), D = e(fif, A, 1+K);
        for(int i = 0; i< D.length; i++)D[i] = D[i]*G%MOD;
        D[0] = (D[0]+1)%MOD;
        long[] E = binaryPow(D, N, 1+K);
        
        long[] F = mul(C, E);
        for(int i = 0; i<= 1+K; i++)F[i] = coeff*F[i]%MOD*fif[0][i]%MOD;
        return Arrays.copyOf(F, 1+K);
    }
    //O(K*log(K))
    long[] subtask4(long N, int K, long A, long B, long P){
        long G = P*pow(1+MOD-P, MOD-2, MOD)%MOD;
        long coeff = pow(1+MOD-P, N, MOD);
        long[] C = e(fif, B, 1+K), D = e(fif, A, 1+K);
        for(int i = 0; i< D.length; i++)D[i] = D[i]*G%MOD;
        D[0] = (D[0]+1)%MOD;
        long[] E = newtonPow(D, N, 1+K);
        long[] F = mul(C, E);
        for(int i = 0; i<= 1+K; i++)F[i] = coeff*F[i]%MOD*fif[0][i]%MOD;
        return Arrays.copyOf(F, 1+K);
    }
    //Polynomial Exponentiation using Newton's method
    long[] newtonPow(long[] a, long p, int max){
        long v = a[0];long vn = pow(v, p, MOD);
        long vInv = pow(v, MOD-2, MOD);
        long[] b = Arrays.copyOf(a, a.length);
        for(int i = 0; i< a.length; i++)b[i] = a[i]*vInv%MOD;
        long[] bLog = ln(b);
        for(int i = 0; i< max; i++)bLog[i] = bLog[i]*p%MOD;
        long[] bExp = exp(bLog);
        for(int i = 0; i<max; i++)bExp[i] = bExp[i]*vn%MOD;
        return bExp;
    }
    //Polynomial inverse
    long[] inv(long[] a){
        a = Arrays.copyOf(a, Integer.highestOneBit(a.length)<<1);
        int n = a.length;
        assert (n & (n - 1)) == 0;
        long r[] = new long[]{pow(a[0], MOD-2, MOD)};
        for(int len = 2; len <= n; len *= 2) {
            long nr[] = new long[len];
            System.arraycopy(a, 0, nr, 0, len);
            nr = mul(nr, mul(r, r));
            for(int i = 0; i < len; ++i) nr[i] = (MOD - nr[i]) % MOD;
            for(int i = 0; i < len / 2; ++i) nr[i] = (nr[i] + 2 * r[i]) % MOD;
            r = new long[len];
            System.arraycopy(nr, 0, r, 0, len);
        }
        return r;//clean(r);
    }
    //Derivative
    long[] dx(long[] a){
        if(a.length == 1)return new long[]{0};
        long[] b = new long[a.length-1];
        for(int i = 0; i< b.length; i++)b[i] = a[i+1]*(i+1)%MOD;
        return b;
    }
    //Integral
    long[] inte(long[] a){
        long[] b = new long[1+a.length];
        for(int i = 1; i< b.length; i++)b[i] = a[i-1] * fif[1][i]%MOD * fif[0][i-1]%MOD;///(i);
        return b;
    }
    //Logarithm
    long[] ln(long[] a){
        assert(a[0] == 1);
        long[] b = mul(dx(a), inv(a));
        b = Arrays.copyOf(b, a.length-1);
        return inte(b);
    }
    //Anti-logarithm
    long[] exp(long[] a){
        assert(a.length == 0 || a[0] == 0);
        long[] b = new long[]{1};
        while(b.length < a.length){
            long[] x = new long[2*b.length];
            System.arraycopy(a, 0, x, 0, Math.min(a.length, 2*b.length));
            x[0]++;
            b = Arrays.copyOf(b, 2*b.length);
            long[] p = ln(b);
            for(int i = 0; i< x.length; i++)
                if(i < p.length)
                    x[i] = (x[i]+MOD-p[i])%MOD;
            
            x = mul(x, b);
            for(int i = b.length/2; i< b.length; i++)b[i] = x[i];
        }
        return b;
    }
    //Multiply two polynomials
    long[] mul(long[] a, long[] b){
        return Convolution.convolution(a, b, MOD);
    }
    //For subtask 2 and 3
    long[] binaryPow(long[] a, long n, int max){
        long[] o = new long[]{1};
        while(n > 0){
            if((n&1)==1){
                o = mul(o, a);
                if(o.length > max)o = Arrays.copyOf(o, max);
            }
            a = mul(a, a);
            if(a.length > max)a = Arrays.copyOf(a, max);
            n>>=1;
        }
        return o;
    }
    //Returns the first K+1 coefficients of the expansion of e^{Cx}
    long[] e(long[][] fif, long C, int K){
        long[] A = new long[1+K];
        long cur = 1;
        for(int i = 0; i<= K; i++){
            A[i] = cur*fif[1][i]%MOD;
            cur = cur*C%MOD;
        }
        return A;
    }
    long pow(long a, long n, long mod) {
        long ret = 1;
        int x = 63 - Long.numberOfLeadingZeros(n);
        for (; x >= 0; x--){
            ret = ret * ret % mod;
            if (n << 63 - x < 0)ret = ret * a % mod;
        }
        return ret;
    }
    long inv(long x){return pow(x, MOD-2, MOD);}
    long[][] fif(int mx){
        mx++;
        long[] F = new long[mx], IF = new long[mx];
        F[0] = 1;
        for(int i = 1; i< mx; i++)F[i] = (F[i-1]*i)%MOD;
        //GFG
        long M = MOD; 
        long y = 0, x = 1; 
        long a = F[mx-1];
        while(a> 1){ 
            long q = a/M;
            long t = M; 
            M = a%M;
            a = t;
            t = y;
            y = x-q*y;
            x = t;
        } 
        if(x<0)x+=MOD;
        IF[mx-1] = x;
        for(int i = mx-2; i>= 0; i--)IF[i] = (IF[i+1]*(i+1))%MOD;
        return new long[][]{F, IF};
    }
    long C(long[][] fif, int n, int r){
        if(n< 0 || n<r || r<0)return 0;
        return (fif[0][n]*((fif[1][r]*fif[1][n-r])%MOD))%MOD;
    }
    static void dbg(Object... o){System.err.println(Arrays.deepToString(o));}
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    static boolean multipleTC = false;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        new XX().run();
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}

    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }

        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }

        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException  e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }

        String nextLine() throws Exception{
            String str = "";
            try{   
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }  
            return str;
        }
    }
}
//https://github.com/NASU41/AtCoderLibraryForJava/blob/master/Convolution/Convolution.java
/**
 * Convolution.
 *
 * @verified https://atcoder.jp/contests/practice2/tasks/practice2_f
 * @verified https://judge.yosupo.jp/problem/convolution_mod_1000000007
 */
class Convolution {
    /**
     * Find a primitive root.
     *
     * @param m A prime number.
     * @return Primitive root.
     */
    private static int primitiveRoot(int m) {
        if (m == 2) return 1;
        if (m == 167772161) return 3;
        if (m == 469762049) return 3;
        if (m == 754974721) return 11;
        if (m == 998244353) return 3;

        int[] divs = new int[20];
        divs[0] = 2;
        int cnt = 1;
        int x = (m - 1) / 2;
        while (x % 2 == 0) x /= 2;
        for (int i = 3; (long) (i) * i <= x; i += 2) {
            if (x % i == 0) {
                divs[cnt++] = i;
                while (x % i == 0) {
                    x /= i;
                }
            }
        }
        if (x > 1) {
            divs[cnt++] = x;
        }
        for (int g = 2; ; g++) {
            boolean ok = true;
            for (int i = 0; i < cnt; i++) {
                if (pow(g, (m - 1) / divs[i], m) == 1) {
                    ok = false;
                    break;
                }
            }
            if (ok) return g;
        }
    }

    /**
     * Power.
     *
     * @param x Parameter x.
     * @param n Parameter n.
     * @param m Mod.
     * @return n-th power of x mod m.
     */
    private static long pow(long x, long n, int m) {
        if (m == 1) return 0;
        long r = 1;
        long y = x % m;
        while (n > 0) {
            if ((n & 1) != 0) r = (r * y) % m;
            y = (y * y) % m;
            n >>= 1;
        }
        return r;
    }

    /**
     * Ceil of power 2.
     *
     * @param n Value.
     * @return Ceil of power 2.
     */
    private static int ceilPow2(int n) {
        int x = 0;
        while ((1L << x) < n) x++;
        return x;
    }

    /**
     * Garner's algorithm.
     *
     * @param c    Mod convolution results.
     * @param mods Mods.
     * @return Result.
     */
    private static long garner(long[] c, int[] mods) {
        int n = c.length + 1;
        long[] cnst = new long[n];
        long[] coef = new long[n];
        java.util.Arrays.fill(coef, 1);
        for (int i = 0; i < n - 1; i++) {
            int m1 = mods[i];
            long v = (c[i] - cnst[i] + m1) % m1;
            v = v * pow(coef[i], m1 - 2, m1) % m1;

            for (int j = i + 1; j < n; j++) {
                long m2 = mods[j];
                cnst[j] = (cnst[j] + coef[j] * v) % m2;
                coef[j] = (coef[j] * m1) % m2;
            }
        }
        return cnst[n - 1];
    }

    /**
     * Pre-calculation for NTT.
     *
     * @param mod NTT Prime.
     * @param g   Primitive root of mod.
     * @return Pre-calculation table.
     */
    private static long[] sumE(int mod, int g) {
        long[] sum_e = new long[30];
        long[] es = new long[30];
        long[] ies = new long[30];
        int cnt2 = Integer.numberOfTrailingZeros(mod - 1);
        long e = pow(g, (mod - 1) >> cnt2, mod);
        long ie = pow(e, mod - 2, mod);
        for (int i = cnt2; i >= 2; i--) {
            es[i - 2] = e;
            ies[i - 2] = ie;
            e = e * e % mod;
            ie = ie * ie % mod;
        }
        long now = 1;
        for (int i = 0; i < cnt2 - 2; i++) {
            sum_e[i] = es[i] * now % mod;
            now = now * ies[i] % mod;
        }
        return sum_e;
    }

    /**
     * Pre-calculation for inverse NTT.
     *
     * @param mod Mod.
     * @param g   Primitive root of mod.
     * @return Pre-calculation table.
     */
    private static long[] sumIE(int mod, int g) {
        long[] sum_ie = new long[30];
        long[] es = new long[30];
        long[] ies = new long[30];

        int cnt2 = Integer.numberOfTrailingZeros(mod - 1);
        long e = pow(g, (mod - 1) >> cnt2, mod);
        long ie = pow(e, mod - 2, mod);
        for (int i = cnt2; i >= 2; i--) {
            es[i - 2] = e;
            ies[i - 2] = ie;
            e = e * e % mod;
            ie = ie * ie % mod;
        }
        long now = 1;
        for (int i = 0; i < cnt2 - 2; i++) {
            sum_ie[i] = ies[i] * now % mod;
            now = now * es[i] % mod;
        }
        return sum_ie;
    }

    /**
     * Inverse NTT.
     *
     * @param a     Target array.
     * @param sumIE Pre-calculation table.
     * @param mod   NTT Prime.
     */
    private static void butterflyInv(long[] a, long[] sumIE, int mod) {
        int n = a.length;
        int h = ceilPow2(n);

        for (int ph = h; ph >= 1; ph--) {
            int w = 1 << (ph - 1), p = 1 << (h - ph);
            long inow = 1;
            for (int s = 0; s < w; s++) {
                int offset = s << (h - ph + 1);
                for (int i = 0; i < p; i++) {
                    long l = a[i + offset];
                    long r = a[i + offset + p];
                    a[i + offset] = (l + r) % mod;
                    a[i + offset + p] = (mod + l - r) * inow % mod;
                }
                int x = Integer.numberOfTrailingZeros(~s);
                inow = inow * sumIE[x] % mod;
            }
        }
    }

    /**
     * NTT (forward transform).
     *
     * @param a    Target array.
     * @param sumE Pre-calculation table.
     * @param mod  NTT Prime.
     */
    private static void butterfly(long[] a, long[] sumE, int mod) {
        int n = a.length;
        int h = ceilPow2(n);

        for (int ph = 1; ph <= h; ph++) {
            int w = 1 << (ph - 1), p = 1 << (h - ph);
            long now = 1;
            for (int s = 0; s < w; s++) {
                int offset = s << (h - ph + 1);
                for (int i = 0; i < p; i++) {
                    long l = a[i + offset];
                    long r = a[i + offset + p] * now % mod;
                    a[i + offset] = (l + r) % mod;
                    a[i + offset + p] = (l - r + mod) % mod;
                }
                int x = Integer.numberOfTrailingZeros(~s);
                now = now * sumE[x] % mod;
            }
        }
    }

    /**
     * Convolution.
     *
     * @param a   Target array 1.
     * @param b   Target array 2.
     * @param mod NTT Prime.
     * @return Answer.
     */
    public static long[] convolution(long[] a, long[] b, int mod) {
        int n = a.length;
        int m = b.length;
        if (n == 0 || m == 0) return new long[0];

        int z = 1 << ceilPow2(n + m - 1);
        {
            long[] na = new long[z];
            long[] nb = new long[z];
            System.arraycopy(a, 0, na, 0, n);
            System.arraycopy(b, 0, nb, 0, m);
            a = na;
            b = nb;
        }

        int g = primitiveRoot(mod);
        long[] sume = sumE(mod, g);
        long[] sumie = sumIE(mod, g);

        butterfly(a, sume, mod);
        butterfly(b, sume, mod);
        for (int i = 0; i < z; i++) {
            a[i] = a[i] * b[i] % mod;
        }
        butterflyInv(a, sumie, mod);
        a = java.util.Arrays.copyOf(a, n + m - 1);

        long iz = pow(z, mod - 2, mod);
        for (int i = 0; i < n + m - 1; i++) a[i] = a[i] * iz % mod;
        return a;
    }

    /**
     * Convolution.
     *
     * @param a   Target array 1.
     * @param b   Target array 2.
     * @param mod Any mod.
     * @return Answer.
     */
    public static long[] convolutionLL(long[] a, long[] b, int mod) {
        int n = a.length;
        int m = b.length;
        if (n == 0 || m == 0) return new long[0];

        int mod1 = 754974721;
        int mod2 = 167772161;
        int mod3 = 469762049;

        long[] c1 = convolution(a, b, mod1);
        long[] c2 = convolution(a, b, mod2);
        long[] c3 = convolution(a, b, mod3);

        int retSize = c1.length;
        long[] ret = new long[retSize];
        int[] mods = {mod1, mod2, mod3, mod};
        for (int i = 0; i < retSize; ++i) {
            ret[i] = garner(new long[]{c1[i], c2[i], c3[i]}, mods);
        }
        return ret;
    }
    
    /**
     * Naive convolution. (Complexity is O(N^2)!!)
     *
     * @param a   Target array 1.
     * @param b   Target array 2.
     * @param mod Mod.
     * @return Answer.
     */
    public static long[] convolutionNaive(long[] a, long[] b, int mod) {
        int n = a.length;
        int m = b.length;
        int k = n + m - 1;
        long[] ret = new long[k];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                ret[i + j] += a[i] * b[j] % mod;
                ret[i + j] %= mod;
            }
        }
        return ret;
    }
}

Feel free to share your approach. Suggestions are welcome as always. :slight_smile:


Really nice problem! Can you suggest some more problems that can be solved using generating functions?

You can try this.


Actually O(K*log(N)*log(K)) was also getting 100 points.

Yes. :frowning:

I saw a couple of those pass too. The AtCoder template is quite fast. In the editorial, I listed the approaches we intended to pass each subtask.

Try from here
https://discuss.codechef.com/tag/generating-functions
