PROBLEM LINK:
Author: Kritagya Agarwal
Tester: Rahul Dugar
Editorialist: Aman Dwivedi
DIFFICULTY:
Medium-Hard
PREREQUISITES:
Dynamic Programming, NTT
PROBLEM:
A function P(A) for a sequence A is defined as the number of ways to divide A into contiguous subsequences such that each element of A belongs to exactly one of these subsequences and the sum of elements of each subsequence is between L and R inclusive.
You are given an integer N. For every integer n between 1 and N inclusive, solve the following problem:
Consider a random sequence A with length n where each element is an integer between 1 and K chosen uniformly randomly and independently. Find the expected value of P(A).
EXPLANATION:
We have a random sequence A, and we want the expected value of P(A). All K^n sequences of length n are equally likely. It therefore suffices to compute the total number of valid partitions, summed over all sequences of length n with elements in [1, K], and then divide by K^n.
The count for length n depends on the counts for shorter lengths: fixing the last segment of a partition leaves a valid partition of a shorter prefix.
Let DP[n] denote this total for sequences of length n, with DP[0] = 1 (the empty sequence has exactly one partition). Then,
DP[n]=\displaystyle\sum_{j=0}^{n-1} DP[j] \cdot f[n-j], where f[x] denotes the number of ways of filling x positions with elements in the range [1,K] such that their sum is in the range [L,R].
Now, lets calculate f[x]:
The first observation is that the upper bound K on elements does not matter, because L \le R \le K. Any element greater than R would by itself push the segment sum above R, so such an element can never appear in a valid segment. We can therefore redefine f[x] as the number of ways of filling x positions with positive integers (no upper bound) such that the sum is in the range [L,R].
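Consider a segment of length m such that:
a_1+a_2+\dots+a_m=X, where L \le X \le R and every a_i \ge 1.
If the a_i were only required to be non-negative, stars and bars would give \binom{X+m-1}{m-1} solutions.
Every element here is at least 1, so substitute b_i = a_i - 1 \ge 0:
b_1+b_2+\dots+b_m=X-m. This has \binom{(X-m)+m-1}{m-1} = \binom{X-1}{m-1} solutions.
Summing over X and applying the hockey-stick identity \sum_{t=0}^{R-1}\binom{t}{m-1}=\binom{R}{m} gives a closed form:
f[m] = \displaystyle\sum_{X=L}^{R}\binom{X-1}{m-1} = \binom{R}{m} - \binom{L-1}{m}. Both solutions below compute exactly this, as ncr(r, i) - ncr(l-1, i).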
The expected value for length i is then DP[i] / K^i, computed modulo 998244353 as DP[i] \cdot (K^{-1})^i. Evaluating the recurrence directly takes O(N^2) time. With online (relaxed) FFT/NTT it takes O(N \log^2 N).
Let us optimize further, to O(N \log N). As before, DP[i] denotes the total number of valid partitions over all sequences of length i:
DP[n]=\displaystyle\sum_{j=0}^{n-1} DP[j] \cdot f[n-j], \quad DP[0] = 1,
where f[x] denotes the number of ways of filling x positions with elements in the range [1,K] such that their sum is in the range [L,R].
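Consider the power series C(z) = f(1)z + f(2)z^2 + \dots + f(R)z^R. Here f(m) = 0 for m > R, because m positive elements already sum to more than R, and only terms up to z^N are needed.
Let D(z) = \sum_i DP[i] z^i. The recurrence says D(z) = 1 + C(z)\,D(z), so DP[i] is the coefficient of z^i in the expansion of 1 / (1 - C(z)). The expected value for length i is that coefficient divided by K^i.
So the remaining question is how to compute the inverse of the power series 1 - C(z).
This can be done with Newton's iteration, which doubles the number of correct coefficients at each step. Each step multiplies polynomials using NTT (998244353 is an NTT-friendly prime), for a total of O(N \log N).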
TIME COMPLEXITY:
O(N \log N) per test case.
SOLUTIONS:
Setter
#include<bits/stdc++.h>
#define int long long
using namespace std;
// Reads one decimal integer from stdin using getchar().
// Every character before the first digit is skipped; any '-' seen among them
// makes the result negative. Reading stops at the first non-digit after the number.
int get() {
    char ch = getchar();
    int sign = 1;
    for (; !isdigit(ch); ch = getchar())
        if (ch == '-') sign = -1;
    int value = 0;
    for (; isdigit(ch); ch = getchar())
        value = value * 10 + (ch - '0');
    return sign * value;
}
// P = 998244353 = 119 * 2^23 + 1 is the NTT-friendly prime modulus; G = 3 is the
// root used by the forward NTT and iG = 332748118 is its inverse (3 * iG == P + 1).
// N bounds every array below (factorial tables and NTT buffers).
const int N = 2e6 + 10, P = 998244353, G = 3, iG = 332748118;
// Per-test input read in main(): n (max length), k (element range), l and rr (segment-sum bounds).
int n, k, l, rr;
// a: output series of GetInv; b: input series 1 - C(z); r: bit-reversal table used by NTT.
int a[N], b[N], r[N];
// fact[i] = i! mod P and invfact[i] = (i!)^{-1} mod P, filled at the start of main().
int fact[N], invfact[N];
// Modular exponentiation by repeated squaring: returns x^y mod P.
// Intended for y >= 0 and x in [0, P), so that x * x fits in the 64-bit `int`
// (this file has `#define int long long`).
int qpow(int x, int y) {
    int result = 1;
    for (; y != 0; y >>= 1) {
        if (y & 1) result = result * x % P;
        x = x * x % P;
    }
    return result;
}
// In-place iterative number-theoretic transform of length lim (a power of two) mod P.
// type == 1 runs the forward transform with root G; any other type uses iG, and
// type == -1 additionally divides every entry by lim (inverse transform).
// The global bit-reversal table r[] must already describe the permutation for lim.
void NTT(int *A, int lim, int type) {
    // Move every element to its bit-reversed slot so butterflies can run in place.
    for (int i = 0; i < lim; i++)
        if (i < r[i]) swap(A[i], A[r[i]]);
    const int root = (type == 1) ? G : iG;
    for (int half = 1; half < lim; half *= 2) {
        const int span = half * 2;
        // Primitive span-th root of unity (or its inverse) for this stage.
        const int step = qpow(root, (P - 1) / span);
        for (int base = 0; base < lim; base += span) {
            int w = 1;
            for (int j = 0; j < half; j++) {
                const int u = A[base + j];
                const int v = w * A[base + half + j] % P;
                A[base + j] = (u + v) % P;
                A[base + half + j] = (u - v + P) % P;
                w = w * step % P;
            }
        }
    }
    if (type == -1) {
        const int invLim = qpow(lim, P - 2);
        for (int i = 0; i < lim; i++) A[i] = A[i] * invLim % P;
    }
}
// Scratch buffer for GetInv: holds the NTT of the truncated input series F.
int tmp[N];
// Computes the first `deg` coefficients of 1 / F(z) mod P into G, by Newton's
// iteration: with H = F^{-1} mod z^ceil(deg/2), G = H * (2 - F * H) mod z^deg.
// Preconditions (not checked): F[0] != 0 mod P, and G[1 .. lim) is zero on entry,
// where lim is the smallest power of two >= 2*deg. Each level only zeroes the tail
// it produced itself (line `G[i] = 0` below), so the caller must clear the rest.
// On return G[0 .. deg) holds the result and G[deg .. lim) is zero.
// Clobbers the globals r[] and tmp[] on [0, lim).
// Note: the parameter G shadows the global root constant G, and the local l below
// shadows the global l.
void GetInv(int deg, int *F, int *G) {
if(deg == 1) { G[0] = qpow(F[0], P - 2); return; }
// First obtain the inverse modulo z^ceil(deg/2) in G[0 .. ceil(deg/2)).
GetInv((deg + 1) >> 1, F, G);
// Transform length: G*G*F has degree <= 2*deg - 2, so lim >= 2*deg avoids wraparound.
int lim = 1, l = 0;
while(lim < 2 * deg) lim <<= 1, l++;
// Bit-reversal table for this lim (r[i >> 1] is always computed before r[i]).
for(int i = 0; i < lim; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
// tmp = F mod z^deg, zero-padded to lim.
for(int i = 0; i < deg; i++) tmp[i] = F[i];
for(int i = deg; i < lim; i++) tmp[i] = 0;
NTT(tmp, lim, 1), NTT(G, lim, 1);
// Pointwise Newton step: G <- 2G - G^2 * F.
for(int i = 0; i < lim; i++) G[i] = (G[i] * 2 % P - G[i] * G[i] % P * tmp[i] % P + P) % P;
NTT(G, lim, -1);
// Keep only the first deg coefficients; zero the rest so the parent level sees a clean tail.
for(int i = deg; i < lim; i++) G[i] = 0;
}
// Binomial coefficient C(n, r) mod P from the factorial tables;
// returns 0 when r is negative or larger than n.
int ncr(int n, int r)
{
    if (r < 0 || r > n) return 0;
    return fact[n] * invfact[r] % P * invfact[n - r] % P;
}
// Setter's driver. Builds the factorial tables once. For each test case it reads
// n, k, l, rr, forms the series 1 - C(z) with [z^i] C = C(rr, i) - C(l-1, i),
// inverts it with GetInv, and prints DP[i] / k^i (mod P) for i = 1..n.
signed main() {
    // Factorials mod P for every index < N.
    fact[0] = invfact[0] = 1;
    for (int i = 1; i < N; i++) fact[i] = fact[i - 1] * i % P;
    // Inverse factorials, derived downward from the single inverse of fact[N-1].
    invfact[N - 1] = qpow(fact[N - 1], P - 2);
    for (int i = N - 2; i >= 1; i--) invfact[i] = invfact[i + 1] * (i + 1) % P;

    int t;
    cin >> t;
    while (t--) {
        cin >> n >> k >> l >> rr;
        const int d = n + 1;  // coefficients z^0 .. z^n are needed
        // b(z) = 1 - C(z), where [z^i] C = number of length-i fillings with sum in [l, rr].
        for (int i = 1; i <= n; i++) {
            b[i] = (ncr(rr, i) - ncr(l - 1, i) + P) % P;
            b[i] = (P - b[i]) % P;
        }
        b[0] = 1;
        // GetInv needs a[1 .. lim) to be zero, where lim is the NTT length of its top
        // level (smallest power of two >= 2*d). Clearing only that prefix replaces the
        // former full-array memset and the per-test reset of tmp[] and r[]. GetInv
        // rewrites every tmp/r entry it reads, and those resets cost O(N) = 2e6 work
        // per test case regardless of n.
        int lim = 1;
        while (lim < 2 * d) lim <<= 1;
        fill(a, a + min(lim, N), 0);
        GetInv(d, b, a);
        // Expected value = DP[i] / k^i. Accumulate k^{-i} incrementally instead of
        // doing two modular exponentiations per index.
        const int invk = qpow(k, P - 2);
        int invkPow = 1;
        for (int i = 1; i <= n; i++) {
            invkPow = invkPow * invk % P;
            a[i] = a[i] * invkPow % P;
            printf("%lld ", a[i]);
        }
        printf("\n");
    }
    return 0;
}
Tester
#include <bits/stdc++.h>
using namespace std;
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/rope>
using namespace __gnu_pbds;
using namespace __gnu_cxx;
#ifndef rd
#define trace(...)
#define endl '\n'
#endif
#define pb push_back
#define fi first
#define se second
#define int long long
typedef long long ll;
typedef long double f80;
#define double long double
#define pii pair<int,int>
#define pll pair<ll,ll>
#define sz(x) ((long long)x.size())
#define fr(a,b,c) for(int a=b; a<=c; a++)
#define rep(a,b,c) for(int a=b; a<c; a++)
#define trav(a,x) for(auto &a:x)
#define all(con) con.begin(),con.end()
// Large sentinel constants; not referenced by the code shown here.
const ll infl=0x3f3f3f3f3f3f3f3fLL;
const int infi=0x3f3f3f3f;
// Prime modulus for all arithmetic: 998244353 = 119 * 2^23 + 1, which supports NTT.
const int mod=998244353;
//const int mod=1000000007;
// Note: because of `#define int long long`, vi is actually vector<long long>.
typedef vector<int> vi;
typedef vector<ll> vl;
// Order-statistics tree from pb_ds; not referenced by the code shown here.
typedef tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update> oset;
// Program start time, reported at the end of main() when compiled with -Drd.
auto clk=clock();
// 64-bit Mersenne Twister seeded from the high-resolution clock; used by rng().
mt19937_64 rang(chrono::high_resolution_clock::now().time_since_epoch().count());
// Returns a uniformly random integer in [0, lim) drawn from the global engine `rang`.
int rng(int lim) {
    return uniform_int_distribution<int>(0, lim - 1)(rang);
}
// Returns a^b mod `mod` by binary exponentiation (intended for b >= 0).
int powm(ll a, int b) {
    ll acc = 1;
    for (; b; b >>= 1) {
        if (b & 1) acc = acc * a % mod;
        a = a * a % mod;
    }
    return acc;
}
// Reads one integer token from stdin (via getchar) that must be terminated by
// `endd`, validating the input format with asserts:
//   - a '-' is accepted only before the first digit;
//   - no leading zeros and no "-0" (a first digit 0 must be the only digit and unsigned);
//   - at most 19 digits, and a 19-digit number must start with 1 (so x fits in long long);
//   - any character other than a digit, '-', or `endd` (e.g. EOF) fails the assert;
//   - the final value must lie in [l, r].
// NOTE(review): a repeated leading '-' or an empty token (delimiter read first) is
// not rejected; the latter yields 0 when 0 is in [l, r].
long long readInt(long long l, long long r, char endd) {
long long x=0;
int cnt=0;
// First digit of the token, or -1 while no digit has been read yet.
int fi=-1;
bool is_neg=false;
while(true) {
char g=getchar();
if(g=='-') {
assert(fi==-1);
is_neg=true;
continue;
}
if('0'<=g&&g<='9') {
x*=10;
x+=g-'0';
if(cnt==0) {
fi=g-'0';
}
cnt++;
// Reject leading zeros and "-0".
assert(fi!=0 || cnt==1);
assert(fi!=0 || is_neg==false);
// Bound the digit count so x cannot overflow.
assert(!(cnt>19 || ( cnt==19 && fi>1) ));
} else if(g==endd) {
if(is_neg) {
x=-x;
}
assert(l<=x&&x<=r);
return x;
} else {
assert(false);
}
}
}
// Reads characters from stdin up to (not including) the delimiter `endd`.
// Asserts that EOF is not hit before the delimiter and that the resulting length
// lies in [l, r].
string readString(int l, int r, char endd) {
    string out;
    for (;;) {
        const char ch = getchar();
        assert(ch != -1);
        if (ch == endd) break;
        out += ch;
    }
    const int len = out.size();
    assert(l <= len && len <= r);
    return out;
}
// Delimiter-specific wrappers around readInt / readString:
// the "Sp" variants expect the token to be followed by a space,
// the "Ln" variants expect it to be followed by a newline.
long long readIntSp(long long l, long long r) { return readInt(l, r, ' '); }
long long readIntLn(long long l, long long r) { return readInt(l, r, '\n'); }
string readStringLn(int l, int r) { return readString(l, r, '\n'); }
string readStringSp(int l, int r) { return readString(l, r, ' '); }
// Transform capacity: fft() supports lengths up to maxn = 2^MAXN.
const int MAXN = 20;
const int maxn = 1 << MAXN;
// Generator used by precompute() to derive the maxn-th root of unity.
const int root = 3;
// Scratch buffers shared by mul() and inv().
int A[maxn], B[maxn];
// W / iW: powers of the maxn-th root of unity and of its inverse (filled by precompute());
// I: table of modular inverses filled by ensureINV().
int W[maxn], iW[maxn], I[maxn];
// Largest index for which I[] is currently filled (0 = empty).
int nn;
// mul() falls back to schoolbook multiplication when the smaller operand has at
// most this many coefficients.
const int threshold = 100;
// Modular arithmetic helpers for MOD = 998244353.
// Operands are expected to be already reduced into [0, MOD).
namespace modulo{
const int MOD = 998244353;
// (a + b) mod MOD with a single conditional subtraction.
int add(const int &a,const int &b){
    const int s = a + b;
    return s >= MOD ? s - MOD : s;
}
// (a - b) mod MOD with a single conditional addition.
int sub(const int &a,const int &b){
    const int d = a - b;
    return d < 0 ? d + MOD : d;
}
// (a * b) mod MOD, widening to 64 bits before multiplying.
int mul(const int &a, const int &b){
    const long long prod = 1ll * a * b;
    return prod % MOD;
}
}
using namespace modulo;
void ensureINV(int n) {
if(n <= nn) return;
if(!nn){
I[1] = 1;
nn = 1;
}
fr(i, nn + 1, n)
I[i] = (mod - mul((mod / i), I[mod % i]));
nn = n;
}
// Returns a^b mod MOD by square-and-multiply (b >= 0, a assumed reduced).
int pwr(int a,int b){
    int result = 1;
    for (; b; b >>= 1, a = mul(a, a))
        if (b & 1) result = mul(result, a);
    return result;
}
void precompute(){
W[0] = iW[0] = 1;
int g = pwr(root,(mod - 1) / maxn), ig = pwr(g, mod - 2);
fr(i, 1, maxn / 2 - 1){
W[i] = mul(W[i - 1], g);
iW[i] = mul(iW[i - 1], ig);
}
}
// Reverses the low log2(n) bits of i (n is expected to be a power of two);
// returns 0 when n is 0 or 1.
int rev(int i, int n){
    int reversed = 0;
    for (int bits = n >> 1; bits; bits >>= 1) {
        reversed = (reversed << 1) | (i & 1);
        i >>= 1;
    }
    return reversed;
}
void go(int a[], int n){
fr(i, 0, n - 1){
int r = rev(i, n);
if(i < r)
swap(a[i], a[r]);
}
}
// In-place iterative NTT of length n modulo `mod` (n a power of two, n <= maxn).
// inv == false: forward transform using the twiddles W[]. inv == true: inverse
// transform using iW[], followed by multiplication by n^{-1}.
// Requires precompute() to have filled W / iW. (The parameter `inv` shadows the
// function inv() defined later in the file.)
void fft(int a[], int n, bool inv = 0){
// Bit-reversal permutation so the butterflies can run in place.
go(a, n);
// Note: `u` is unused, and `j` is redeclared by fr() below. The local `add` (twiddle
// stride) shadows modulo::add, which is why the butterfly calls ::add explicitly.
int len, i, j, *p, *q, u, v, ind, add;
for(len = 2; len <= n; len <<= 1){
for(i = 0; i < n; i += len){
// Stepping through W with stride maxn/len yields powers of a len-th root of unity.
ind = 0, add = maxn / len;
p = &a[i], q = &a[i + len / 2];
fr(j, 0, len / 2 - 1){
v = mul((*q), (inv ? iW[ind] : W[ind]));
(*q) = sub((*p), v);
(*p) = ::add((*p), v);
ind += add;
p++, q++;
}
}
}
if(inv) {
// Scale by n^{-1}; this `p` is a new int that shadows the pointer declared above.
int p = pwr(n, mod - 2);
fr(i, 0, n - 1)
a[i] = mul(a[i], p);
}
}
// Schoolbook product of two coefficient vectors modulo `mod`. The result has
// a.size() + b.size() - 1 coefficients. An empty operand denotes the zero
// polynomial and yields an empty result. Previously two empty operands made the
// size expression wrap around to SIZE_MAX and throw.
vi brute(const vi &a, const vi &b){ // brute multiplication
    if (a.empty() || b.empty()) return {};
    vi c(a.size() + b.size() - 1, 0);
    for (size_t i = 0; i < a.size(); i++)
        for (size_t j = 0; j < b.size(); j++)
            c[i + j] = (c[i + j] + a[i] * b[j]) % mod;
    return c;
}
// Product of two polynomials modulo `mod`.
// Small inputs (min size <= threshold) use brute(). Otherwise an NTT of length n
// is used (n = smallest power of two >= result length), via the global scratch
// buffers A and B. The result is trimmed to exactly a.size() + b.size() - 1
// coefficients, so both paths return the same shape. The dropped entries are
// always zero because n is large enough to avoid cyclic wraparound.
vi mul(vi a, vi b){
    if(min(a.size(),b.size()) <= threshold)
        return brute(a, b);
    const int need = sz(a) + sz(b) - 1;  // exact product length
    int n = 1;
    while (n < need)
        n <<= 1;
    assert(n <= maxn);  // A, B, W and iW only cover maxn entries
    a.resize(n, 0);
    b.resize(n, 0);
    copy(all(a), A);
    fft(A, n);
    if (a == b) {
        // Squaring: reuse the forward transform of a.
        copy(A, A + n, B);
    } else {
        copy(all(b), B);
        fft(B, n);
    }
    for (int i = 0; i < n; i++)
        A[i] = mul(A[i], B[i]);
    fft(A, n, 1);
    return vi(A, A + need);
}
// Global series operated on by online_fft() (and also written by solve()):
// a accumulates results, b is the input series. v1 / v2 are scratch copies of the
// slices being multiplied.
vector<int> v1,v2,a,b;
// Block step of the online convolution: multiplies the slices a[l1..r1] and
// b[l2..r2] (inclusive) and adds the product into a starting at index l1 + l2,
// dropping terms that fall past the end of a.
void go(int l1, int r1, int l2, int r2) {
v1.assign(a.begin()+l1,a.begin()+r1+1);
v2.assign(b.begin()+l2,b.begin()+r2+1);
v1=mul(v1,v2);
for(int i=0; i<v1.size()&&l1+l2+i<a.size(); i++) {
a[l1+l2+i]+=v1[i];
if(a[l1+l2+i]>=mod)
a[l1+l2+i]-=mod;
}
}
// Returns the first m coefficients of 1 / a(z) modulo `mod`, by Newton's iteration:
// with ia = a^{-1} mod z^{sz/2}, set ia <- ia * (2 - a * ia) mod z^{sz}.
// Requires a non-empty a with a[0] != 0. Internally works up to the next power of
// two >= m, and needs twice that length to fit in the scratch buffers (<= maxn).
vi inv(vi a, int m){ // get m terms
    // Check emptiness first: reading a[0] of an empty vector would be UB.
    assert(!a.empty() && a[0] != 0);
    int tot = 1;
    while(tot < m)
        tot <<= 1;
    swap(tot, m);  // now m = padded power-of-two length, tot = requested length
    a.resize(m, 0);
    vi ia(m, 0);
    ia[0] = pwr(a[0], mod - 2);
    for(int sz = 2; sz <= m; sz <<= 1){
        // A = current inverse (sz/2 terms), B = a mod z^sz, both zero-padded to 2*sz.
        copy(ia.begin(), ia.begin() + sz / 2, A);
        copy(a.begin(), a.begin() + sz, B);
        fill(A + sz / 2, A + (sz << 1), 0);
        fill(B + sz, B + (sz << 1), 0);
        fft(A, sz << 1);
        fft(B, sz << 1);
        // Pointwise Newton step: A <- 2A - A^2 * B.
        fr(j, 0, (sz << 1) - 1)
            A[j] = add(A[j], sub(A[j], mul(mul(A[j], A[j]), B[j])));
        fft(A, sz << 1, 1);
        copy(A, A + sz, ia.begin());
    }
    ia.resize(tot);
    return ia;
}
// Online (relaxed) convolution, an O(n log^2 n) alternative to inv().
// Its only call site (in solve()) is commented out.
// NOTE(review): after the call, a[i] for 1 <= i < n appears to satisfy
// a[i] = (initial a[i]) + b[i] + sum_{j=1}^{i-1} a[j] * b[i-j], i.e. the DP recurrence
// with f = b. Confirm before relying on it.
// Caution: the helper go(l1, r1, l2, r2) reads and writes the GLOBAL vectors a and b,
// not these parameters, so this only works when called with those globals as arguments.
void online_fft(vector<int> &a, vector<int> &b) { // a and b are 1-indexed
int n=1;
while(n<b.size())
n<<=1;
// Padding so the a[i+2] updates and the go() slices below stay in range.
a.resize(n+2,0),b.resize(n+1,0);
for(int i=1; i<n; i++) {
// Add the b[i] term (the j = 0 term of the recurrence); a[i] is final after this.
a[i]=(a[i]+b[i])%mod;;
// Push a[i]'s contributions with the first two kernel entries immediately.
a[i+1]=(a[i+1]+((ll)a[i])*b[1])%mod,a[i+2]=(a[i+2]+a[i]*((ll)b[2]))%mod;
// For every power of two pw dividing i (starting at 2), add the contribution of the
// block a[i-pw+1 .. i] times b[pw+1 .. 2*pw].
int ind=i,pw=2;
while(!(ind&1)) {
go(i-pw+1,i,pw+1,2*pw);
ind>>=1;
pw<<=1;
}
}
}
ll fact[1000005];
ll ifact[1000005];
int ncr(int n, int r) {
if(n<r||r<0)
return 0;
return (((fact[n]*ifact[r])%mod)*ifact[n-r])%mod;
}
void solve() {
int n=readIntSp(1,500000),k=readIntSp(1,1000000),l=readIntSp(1,k),r=readIntLn(l,k);
b.resize(n+1);
fr(i,1,n)
b[i]=(ncr(r,i)-ncr(l-1,i)+mod)%mod;
a={1};
vi bb=b;
for(int &i:bb)
i=(mod-i)%mod;
bb[0]=1;
vi c=inv(bb,n+1);
int iol=powm(k,mod-2);
ll pp=iol;
fr(i,1,n) {
cout<<(c[i]*pp)%mod<<" \n"[i==n];
pp=(pp*iol)%mod;
}
// online_fft(a,b);
// int iol=powm(k,mod-2);
// ll pp=iol;
// fr(i,1,n) {
// cout<<(a[i]*pp)%mod<<" \n"[i==n];
// pp=(pp*iol)%mod;
// }
// cout<<a[n]<<endl;
}
signed main() {
precompute();
fact[0]=1;
fr(i,1,1000000)
fact[i]=(fact[i-1]*i)%mod;
ifact[1000000]=powm(fact[1000000],mod-2);
for(int i=999999; i>=0; i--)
ifact[i]=(ifact[i+1]*(i+1))%mod;
ios_base::sync_with_stdio(0),cin.tie(0);
srand(chrono::high_resolution_clock::now().time_since_epoch().count());
cout<<fixed<<setprecision(10);
int t=readIntLn(1,100000);
// int t=1;
// cin>>t;
fr(i,1,t)
solve();
// assert(getchar()==EOF);
#ifdef rd
cerr<<endl<<endl<<endl<<"Time Elapsed: "<<((double)(clock()-clk))/CLOCKS_PER_SEC<<endl;
#endif
}