PART - Editorial

PROBLEM LINK:

Practice
Contest: Division 1

Author: Kritagya Agarwal
Tester: Rahul Dugar
Editorialist: Aman Dwivedi

DIFFICULTY:

Medium-Hard

PREREQUISITES:

Dynamic Programming, NTT

PROBLEM:

A function P(A) for a sequence A is defined as the number of ways to divide A into contiguous subsequences such that each element of A belongs to exactly one of these subsequences and the sum of elements of each subsequence is between L and R inclusive.

You are given an integer N. For every integer n between 1 and N inclusive, solve the following problem:

Consider a random sequence A with length n where each element is an integer between 1 and K chosen uniformly randomly and independently. Find the expected value of P(A).

EXPLANATION:

We have a random sequence A and want the expected value of P(A). All K^n sequences of length n are equally likely. So we first compute the sum of the number of valid partitions over all sequences of length n with elements in the range [1, K], and then divide that sum by K^n.

The count for length n depends on the counts for shorter lengths. The last part of a partition is some suffix A[j+1..n] with 0 \le j < n, and the remaining prefix A[1..j] is partitioned independently.

Let DP[n] denote the sum of P(A) over all K^n sequences A of length n. Set DP[0] = 1, since the empty sequence has exactly one (empty) partition. Then,

DP[n]=\displaystyle\sum_{j=0}^{n-1} DP[j] * f[n-j]

where f[x] denotes the number of ways to fill x positions with integers in the range [1, K] such that their sum lies in [L, R].

Now, lets calculate f[x]:

The first observation is that the value of K does not matter, because L \le R \le K. Any element greater than R would already make the sum exceed R, so such an element can never appear in a valid part. Hence we can re-frame f[x] as the number of ways to fill x positions with integers in [1, \infty) such that their sum lies in [L, R].

Consider a part of length m whose elements satisfy:

a_1 + a_2 + a_3 + \dots + a_m = X
where L \le X \le R.

If the a_i were allowed to be 0, the number of solutions for a fixed X would be \binom{X+m-1}{m-1} (stars and bars).

However, every a_i is at least 1, so we subtract 1 from each element to obtain non-negative variables:

(a_1-1) + (a_2-1) + (a_3-1) + \dots + (a_m-1) = X - m
Hence, the number of ways for a fixed X is \binom{(X-m)+m-1}{m-1} = \binom{X-1}{m-1}.

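Summing over X and applying the hockey-stick identity \sum_{X=L}^{R}\binom{X-1}{m-1} = \binom{R}{m} - \binom{L-1}{m} gives the closed form f[m] = \binom{R}{m} - \binom{L-1}{m}. This is exactly what both solutions below compute.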

Now, the answer for length i is DP[i] / K^i, computed as DP[i] \cdot K^{-i} modulo 998244353. The recurrence can be evaluated in O(N^2) time. It can be sped up to O(N \log^2 N) using online FFT/NTT (divide-and-conquer convolution).

Let us optimize it further, down to O(N \log N) time. Again we need

DP[i], the sum of the number of valid partitions over all sequences of length i, given by

DP[n]=\displaystyle\sum_{j=0}^{n-1} DP[j] * f[n-j]

where f[x] denotes the number of ways to fill x positions with integers in the range [1, K] such that their sum lies in [L, R].

Consider the power series C(z) = f(1)z + f(2)z^2 + \dots + f(R)z^R. Here f(x) = 0 for x > R, since every element is at least 1.

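Write DP(z) = \sum_{i \ge 0} DP[i] z^i. Together with DP[0] = 1, the recurrence says DP(z) = 1 + C(z)\,DP(z). So DP[i] is the coefficient of z^i in the expansion of 1 / (1 - C(z)), and the answer for length i is DP[i] / K^i.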

So the remaining question is how to compute the first N+1 coefficients of the power-series inverse 1 / (1 - C(z)).

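This can be done with Newton's iteration for power-series inversion, using NTT modulo 998244353 for the multiplications. Each step doubles the number of correct coefficients, giving O(N \log N) in total.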

TIME COMPLEXITY:

O(N*log(N)), per testcase.

SOLUTIONS:

Setter
#include<bits/stdc++.h>
#define int long long
using namespace std;
 
// Fast reader for one (optionally negative) decimal integer from stdin.
// Any non-digit characters before the number are skipped; a '-' among them
// makes the result negative. The character that terminates the number is
// consumed. Returns 0 if stdin ends before any digit appears.
int get() {
    int x = 0, f = 1;
    // Keep getchar()'s int result: this keeps EOF distinguishable (the old
    // `char` version looped forever at end of input) and only passes values
    // that are valid for isdigit() (unsigned char range or EOF).
    int c = getchar();
    while(c != EOF && !isdigit(c)) { if(c == '-') f = -1; c = getchar(); }
    while(c != EOF && isdigit(c)) { x = x * 10 + c - '0'; c = getchar(); }
    return x * f;
}
 
// N: capacity of every global array (transform buffers and factorial tables).
// P = 998244353 is the NTT-friendly prime. G = 3 is the root used by NTT(),
// and iG = 332748118 is its inverse (3 * 332748118 = P + 1).
const int N = 2e6 + 10, P = 998244353, G = 3, iG = 332748118;
// Per-test input: length n, value bound k, and the segment-sum range [l, rr]
// (named rr because r[] below is the bit-reversal table).
int n, k, l, rr;
// a: result of GetInv (series inverse), b: coefficients of 1 - C(z),
// r: bit-reversal permutation consumed by NTT().
int a[N], b[N], r[N];
// Factorials and inverse factorials modulo P, used by ncr().
int fact[N], invfact[N];
 
// Binary exponentiation: returns x^y mod P (expects y >= 0).
int qpow(int x, int y) {
    int result = 1;
    for (; y != 0; y >>= 1) {
        if (y & 1) result = result * x % P;
        x = x * x % P;
    }
    return result;
}
 
// In-place iterative number-theoretic transform of A[0..lim) modulo P,
// where lim is a power of two.
// type == 1  : forward transform using root G.
// type == -1 : inverse transform using root iG, then every entry is divided
//              by lim.
// The caller must have filled r[0..lim) with the bit-reversal permutation
// for this length (GetInv does this right before calling).
void NTT(int *A, int lim, int type) {
    for (int i = 0; i < lim; i++) {
        const int j = r[i];
        if (i < j) swap(A[i], A[j]);
    }
    const int base = (type == 1) ? G : iG;
    for (int half = 1; half < lim; half <<= 1) {
        const int span = half << 1;
        // span-th root of unity (or its inverse) for this butterfly level.
        const int step = qpow(base, (P - 1) / span);
        for (int start = 0; start < lim; start += span) {
            int tw = 1;
            for (int j = 0; j < half; j++) {
                const int u = A[start + j];
                const int v = tw * A[start + half + j] % P;
                A[start + j] = (u + v) % P;
                A[start + half + j] = (u - v + P) % P;
                tw = tw * step % P;
            }
        }
    }
    if (type != -1) return;
    const int scale = qpow(lim, P - 2);
    for (int i = 0; i < lim; i++) A[i] = A[i] * scale % P;
}
 
// Scratch buffer for GetInv: holds the first `deg` coefficients of F
// (zero-padded), transformed in place by NTT().
int tmp[N];
// Power-series inverse by Newton iteration.
// On return, G[0..deg) holds the first deg coefficients of 1/F(z) mod P and
// G[deg..lim) is zero, where lim is this level's transform length.
// Requires F[0] != 0 (mod P). Expects G to be zero on entry up to the
// top-level transform length; main() clears it before the call.
// Note: inside this function the parameter G shadows the global root constant.
void GetInv(int deg, int *F, int *G) {
    if (deg == 1) {
        G[0] = qpow(F[0], P - 2);
        return;
    }
    // Get H = F^{-1} mod z^ceil(deg/2) first, then refine it to
    // 2H - H^2 F mod z^deg.
    GetInv((deg + 1) >> 1, F, G);
    int lim = 1, bits = 0;
    while (lim < 2 * deg) {
        lim <<= 1;
        bits++;
    }
    // Bit-reversal table consumed by NTT() for this length.
    for (int i = 0; i < lim; i++)
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (bits - 1));
    for (int i = 0; i < lim; i++)
        tmp[i] = i < deg ? F[i] : 0;
    NTT(tmp, lim, 1);
    NTT(G, lim, 1);
    for (int i = 0; i < lim; i++) {
        const int h = G[i];
        G[i] = (h * 2 % P - h * h % P * tmp[i] % P + P) % P;
    }
    NTT(G, lim, -1);
    // Drop the garbage above degree deg - 1 so higher levels see zeros.
    for (int i = deg; i < lim; i++)
        G[i] = 0;
}
 
// Binomial coefficient C(n, r) mod P from the precomputed factorial tables;
// 0 when r lies outside [0, n].
int ncr(int n, int r)
{
    if (r < 0 || r > n) return 0;
    return fact[n] * invfact[r] % P * invfact[n - r] % P;
}
 
// Setter driver: precompute factorials, then for each test case build
// 1 - C(z) with [z^i]C = C(rr, i) - C(l - 1, i), invert it, and print
// DP[i] / k^i for i = 1..n.
signed main() {
    // fact[i] = i! and invfact[i] = (i!)^{-1} modulo P for 0 <= i < N.
    fact[0] = invfact[0] = 1;
    for (int i = 1; i < N; i++)
        fact[i] = fact[i - 1] * i % P;
    invfact[N - 1] = qpow(fact[N - 1], P - 2);
    for (int i = N - 2; i >= 1; i--)
        invfact[i] = invfact[i + 1] * (i + 1) % P;

    int t;
    cin >> t;

    while (t--) {
        cin >> n >> k >> l >> rr;
        // Coefficients z^0 .. z^n are needed.
        const int d = n + 1;

        // b = 1 - C(z). f(i) = C(rr, i) - C(l - 1, i) counts length-i blocks
        // whose sum lies in [l, rr].
        b[0] = 1;
        for (int i = 1; i <= n; i++) {
            const int ways = (ncr(rr, i) - ncr(l - 1, i) + P) % P;
            b[i] = (P - ways) % P;
        }

        // GetInv only touches a[0, lim), where lim is the smallest power of
        // two >= 2*d (deeper recursion levels use shorter lengths). Clearing
        // just that prefix replaces the old full-array memset, which cost
        // O(N_max) per test case. tmp[] and r[] are fully rewritten inside
        // GetInv, so they need no per-test clearing at all.
        int lim = 1;
        while (lim < 2 * d) lim <<= 1;
        for (int i = 0; i < lim; i++) a[i] = 0;
        GetInv(d, b, a);

        // Expected value = DP[i] * k^{-i}. Keep a running power of k^{-1}
        // instead of two modular exponentiations per index.
        const int invK = qpow(k, P - 2);
        int invKPow = 1;
        for (int i = 1; i <= n; i++) {
            invKPow = invKPow * invK % P;
            a[i] = a[i] * invKPow % P;
            printf("%lld ", a[i]);
        }
        printf("\n");
    }
    return 0;
}
Tester
#include <bits/stdc++.h>
using namespace std;
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/rope>
using namespace __gnu_pbds;
using namespace __gnu_cxx;
// Unless compiled with -Drd (local debugging), tracing is disabled and
// `endl` becomes a plain newline so output is not flushed on every line.
#ifndef rd
#define trace(...)
#define endl '\n'
#endif
// Competitive-programming shorthands.
// NOTE(review): `fi`/`se` are object-like macros, so any identifier named
// `fi` later in this file (e.g. the local in readInt) is silently renamed
// to `first`.
#define pb push_back
#define fi first
#define se second
// From here on every `int` is really `long long`.
#define int long long
typedef long long ll;
typedef long double f80;
#define double long double
#define pii pair<int,int>
#define pll pair<ll,ll>
#define sz(x) ((long long)x.size())
// fr: inclusive range loop [b, c]; rep: half-open range loop [b, c).
#define fr(a,b,c) for(int a=b; a<=c; a++)
#define rep(a,b,c) for(int a=b; a<c; a++)
#define trav(a,x) for(auto &a:x)
#define all(con) con.begin(),con.end()
const ll infl=0x3f3f3f3f3f3f3f3fLL;
const int infi=0x3f3f3f3f;
// Modulus for all arithmetic in this solution.
const int mod=998244353;
//const int mod=1000000007;
typedef vector<int> vi;
typedef vector<ll> vl;
 
// Order-statistics tree (pb_ds); not used by the active code path.
typedef tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update> oset;
// Start time, reported by main() when compiled with -Drd.
auto clk=clock();
// 64-bit Mersenne Twister seeded from the clock; drawn from by rng().
mt19937_64 rang(chrono::high_resolution_clock::now().time_since_epoch().count());
int rng(int lim) {
	uniform_int_distribution<int> uid(0,lim-1);
	return uid(rang);
}
// a^b modulo `mod` by square-and-multiply (expects b >= 0).
int powm(ll a, int b) {
	ll acc = 1;
	for (; b != 0; b >>= 1) {
		if (b & 1)
			acc = acc * a % mod;
		a = a * a % mod;
	}
	return acc;
}
 
// Strict input-validator reader. Reads a decimal integer from stdin that
// must be terminated by exactly the character `endd` (which is consumed) and
// asserts that the value lies in [l, r]. Any other character aborts via
// assert(false). Leading zeros and "-0" are rejected. At most 19 digits are
// accepted, and a 19-digit number must start with 1, so the value fits in a
// long long.
// NOTE(review): several '-' before the first digit are accepted ("--5"), and
// a lone "-" directly followed by `endd` yields 0.
long long readInt(long long l, long long r, char endd) {
	long long x=0;
	// Number of digits read so far.
	int cnt=0;
	// First digit seen (-1 until a digit has been read).
	int fi=-1;
	bool is_neg=false;
	while(true) {
		char g=getchar();
		if(g=='-') {
			// A minus sign is only allowed before the first digit.
			assert(fi==-1);
			is_neg=true;
			continue;
		}
		if('0'<=g&&g<='9') {
			x*=10;
			x+=g-'0';
			if(cnt==0) {
				fi=g-'0';
			}
			cnt++;
			// No leading zeros, and no negative zero.
			assert(fi!=0 || cnt==1);
			assert(fi!=0 || is_neg==false);
 
			// Overflow guard: reject more than 19 digits, or 19 digits starting above 1.
			assert(!(cnt>19 || ( cnt==19 && fi>1) ));
		} else if(g==endd) {
			if(is_neg) {
				x=-x;
			}
			assert(l<=x&&x<=r);
			return x;
		} else {
			// Unexpected character (including EOF).
			assert(false);
		}
	}
}
// Strict input-validator reader. Reads raw characters from stdin up to (but
// not including) the delimiter `endd`, which is consumed. Asserts that input
// does not end before the delimiter and that the token length is in [l, r].
string readString(int l, int r, char endd) {
	string ret="";
	int cnt=0;
	while(true) {
		// Keep getchar()'s int result. Narrowing to char first made the EOF
		// check (`g != -1`) always pass where char is unsigned, so truncated
		// input looped forever instead of failing the assertion.
		const int ch=getchar();
		assert(ch!=EOF);
		const char g=static_cast<char>(ch);
		if(g==endd) {
			break;
		}
		cnt++;
		ret+=g;
	}
	assert(l<=cnt&&cnt<=r);
	return ret;
}
// Thin wrappers over readInt/readString that fix the terminator:
// the ...Sp variants expect a single space after the token, the ...Ln
// variants a newline.
long long readIntSp(long long l, long long r) {
	const char terminator = ' ';
	return readInt(l, r, terminator);
}
long long readIntLn(long long l, long long r) {
	const char terminator = '\n';
	return readInt(l, r, terminator);
}
string readStringLn(int l, int r) {
	const char terminator = '\n';
	return readString(l, r, terminator);
}
string readStringSp(int l, int r) {
	const char terminator = ' ';
	return readString(l, r, terminator);
}
 
// Largest supported transform length is maxn = 2^MAXN.
const int MAXN = 20;
const int maxn = 1 << MAXN;
// Generator used to build the twiddle tables in precompute().
const int root = 3;
// Scratch transform buffers used by mul() and inv().
int A[maxn], B[maxn];
// W / iW: powers of the maxn-th root of unity and of its inverse (filled by
// precompute()). I: modular inverses 1..nn (filled by ensureINV()).
int W[maxn], iW[maxn], I[maxn];
// Largest index for which I[] is currently valid (0 = none yet).
int nn;
// mul() switches to schoolbook multiplication when either operand has at
// most this many coefficients.
const int threshold = 100;
 
// Modular helpers for operands already reduced into [0, MOD).
namespace modulo{
    const int MOD = 998244353;
    // (a + b) mod MOD.
    int add(const int &a,const int &b){
        const int s = a + b;
        return s >= MOD ? s - MOD : s;
    }
    // (a - b) mod MOD.
    int sub(const int &a,const int &b){
        const int d = a - b;
        return d < 0 ? d + MOD : d;
    }
    // (a * b) mod MOD, widened to 64 bits before reducing.
    int mul(const int &a, const int &b){
        const long long prod = 1ll * a * b;
        return prod % MOD;
    }
}
using namespace modulo;
 
void ensureINV(int n) {
    if(n <= nn) return;
    if(!nn){
        I[1] = 1;
        nn = 1;
    }
    fr(i, nn + 1, n)
        I[i] = (mod - mul((mod / i), I[mod % i]));
    nn = n;
}
// a^b modulo MOD by square-and-multiply (expects b >= 0).
int pwr(int a,int b){
    int result = 1;
    for (int base = a; b != 0; b >>= 1) {
        if (b & 1)
            result = mul(result, base);
        base = mul(base, base);
    }
    return result;
}
void precompute(){
    W[0] = iW[0] = 1;
    int g = pwr(root,(mod - 1) / maxn), ig = pwr(g, mod - 2);
    fr(i, 1, maxn / 2 - 1){
        W[i] = mul(W[i - 1], g);
        iW[i] = mul(iW[i - 1], ig);
    }
}
// Reverse the low log2(n) bits of i (n is a power of two); e.g. rev(1, 8) == 4.
int rev(int i, int n){
    int result = 0;
    for (int remaining = n >> 1; remaining != 0; remaining >>= 1) {
        result = (result << 1) | (i & 1);
        i >>= 1;
    }
    return result;
}
void go(int a[], int n){
    fr(i, 0, n - 1){
        int r = rev(i, n);
        if(i < r)
            swap(a[i], a[r]);
    }
}
void fft(int a[], int n, bool inv = 0){
    go(a, n);
    int len, i, j, *p, *q, u, v, ind, add;
    for(len = 2; len <= n; len <<= 1){
        for(i = 0; i < n; i += len){
            ind = 0, add = maxn / len;
            p = &a[i], q = &a[i + len / 2];
            fr(j, 0, len / 2 - 1){
                v = mul((*q), (inv ? iW[ind] : W[ind]));
                (*q) = sub((*p), v);
                (*p) = ::add((*p), v);
                ind += add;
                p++, q++;
            }
        }
    }
    if(inv) {
        int p = pwr(n, mod - 2);
        fr(i, 0, n - 1)
            a[i] = mul(a[i], p);
    }
}
// Schoolbook O(|a|*|b|) product of two non-empty coefficient vectors modulo
// `mod`; the result has exactly |a| + |b| - 1 coefficients.
vi brute(const vi &a, const vi &b){
    vi c(a.size() + b.size() - 1, 0);
    for (size_t i = 0; i < a.size(); ++i)
        for (size_t j = 0; j < b.size(); ++j)
            c[i + j] = (c[i + j] + a[i] * b[j]) % mod;
    return c;
}
// Polynomial product a*b modulo `mod`.
// If either operand has at most `threshold` coefficients, brute() is used and
// the result has exactly |a| + |b| - 1 entries. Otherwise both operands are
// zero-padded to the smallest power of two len >= |a| + |b| - 1 and
// multiplied with fft(); that result has len entries, with zeros past the
// true degree.
vi mul(vi a, vi b){
    const size_t smaller = min(a.size(), b.size());
    if (smaller <= threshold)
        return brute(a, b);
    int len = 1;
    while (len < sz(a) + sz(b) - 1)
        len <<= 1;
    a.resize(len, 0);
    b.resize(len, 0);
    copy(all(a), A);
    fft(A, len);
    // Squaring shortcut: reuse the transform of a instead of transforming b.
    const bool same = (a == b);
    if (same) {
        copy(A, A + len, B);
    } else {
        copy(all(b), B);
        fft(B, len);
    }
    for (int i = 0; i < len; i++)
        A[i] = mul(A[i], B[i]);
    fft(A, len, 1);
    return vi(A, A + len);
}
 
// Shared state: v1/v2 are scratch operands for go(l1,r1,l2,r2); a and b are
// the sequences used by online_fft(). solve() also fills b and resets a.
vector<int> v1,v2,a,b;
void go(int l1, int r1, int l2, int r2) {
	v1.assign(a.begin()+l1,a.begin()+r1+1);
	v2.assign(b.begin()+l2,b.begin()+r2+1);
	v1=mul(v1,v2);
	for(int i=0; i<v1.size()&&l1+l2+i<a.size(); i++) {
		a[l1+l2+i]+=v1[i];
		if(a[l1+l2+i]>=mod)
			a[l1+l2+i]-=mod;
	}
}
// First m coefficients of the power-series inverse 1/a(x) modulo `mod`.
// Newton iteration ia <- 2*ia - ia^2 * a doubles the precision each round,
// working up to cap = the smallest power of two >= m. Requires a[0] != 0.
vi inv(vi a, int m){
    assert(a[0] != 0);
    int cap = 1;
    while (cap < m)
        cap <<= 1;
    a.resize(cap, 0);
    vi ia(cap, 0);
    ia[0] = pwr(a[0], mod - 2);
    for (int len = 2; len <= cap; len <<= 1) {
        const int ntt = len << 1;
        // A <- current inverse (len/2 terms), B <- first len terms of a;
        // both zero-padded to the transform length 2*len.
        copy(ia.begin(), ia.begin() + len / 2, A);
        copy(a.begin(), a.begin() + len, B);
        fill(A + len / 2, A + ntt, 0);
        fill(B + len, B + ntt, 0);
        fft(A, ntt);
        fft(B, ntt);
        for (int j = 0; j < ntt; j++)
            A[j] = add(A[j], sub(A[j], mul(mul(A[j], A[j]), B[j])));
        fft(A, ntt, 1);
        copy(A, A + len, ia.begin());
    }
    ia.resize(m);
    return ia;
}
 
// Online (relaxed) convolution via divide-and-conquer blocks.
// NOTE(review): appears to compute, for 1 <= i < n,
//   a[i] = (initial a[i]) + b[i] + sum_{1 <= j < i} a[j] * b[i - j]  (mod `mod`),
// i.e. the DP recurrence with DP[0] = 1 and f = b. Confirm before relying on it.
// n is the smallest power of two >= b.size(); a is resized to n + 2 and b to
// n + 1. Not called at present: the only call site in solve() is commented
// out in favour of the power-series inverse.
void online_fft(vector<int> &a, vector<int> &b) { // a and b are 1-indexed
	int n=1;
	while(n<b.size())
		n<<=1;
	a.resize(n+2,0),b.resize(n+1,0);
	for(int i=1; i<n; i++) {
		// a[i] is final once its own b[i] term is added.
		a[i]=(a[i]+b[i])%mod;;
		// Push a[i]'s contribution to the next two indices directly.
		a[i+1]=(a[i+1]+((ll)a[i])*b[1])%mod,a[i+2]=(a[i+2]+a[i]*((ll)b[2]))%mod;
		// For each power of two pw dividing i (starting at 2), add the block
		// a[i-pw+1..i] * b[pw+1..2pw] into later entries of a.
		int ind=i,pw=2;
		while(!(ind&1)) {
			go(i-pw+1,i,pw+1,2*pw);
			ind>>=1;
			pw<<=1;
		}
	}
}
 
// fact[i] = i! and ifact[i] = (i!)^{-1} modulo `mod` for 0 <= i <= 1000000
// (filled in main()).
ll fact[1000005];
ll ifact[1000005];
int ncr(int n, int r) {
	if(n<r||r<0)
		return 0;
	return (((fact[n]*ifact[r])%mod)*ifact[n-r])%mod;
}
void solve() {
	int n=readIntSp(1,500000),k=readIntSp(1,1000000),l=readIntSp(1,k),r=readIntLn(l,k);
	b.resize(n+1);
	fr(i,1,n)
		b[i]=(ncr(r,i)-ncr(l-1,i)+mod)%mod;
	a={1};
	vi bb=b;
	for(int &i:bb)
		i=(mod-i)%mod;
	bb[0]=1;
	vi c=inv(bb,n+1);
	int iol=powm(k,mod-2);
	ll pp=iol;
	fr(i,1,n) {
		cout<<(c[i]*pp)%mod<<" \n"[i==n];
		pp=(pp*iol)%mod;
	}
 
 
//	online_fft(a,b);
//	int iol=powm(k,mod-2);
//	ll pp=iol;
//	fr(i,1,n) {
//		cout<<(a[i]*pp)%mod<<" \n"[i==n];
//		pp=(pp*iol)%mod;
//	}
//	cout<<a[n]<<endl;
}
 
signed main() {
	precompute();
	fact[0]=1;
	fr(i,1,1000000)
		fact[i]=(fact[i-1]*i)%mod;
	ifact[1000000]=powm(fact[1000000],mod-2);
	for(int i=999999; i>=0; i--)
		ifact[i]=(ifact[i+1]*(i+1))%mod;
	ios_base::sync_with_stdio(0),cin.tie(0);
	srand(chrono::high_resolution_clock::now().time_since_epoch().count());
	cout<<fixed<<setprecision(10);
	int t=readIntLn(1,100000);
//	int t=1;
//	cin>>t;
	fr(i,1,t)
		solve();
//	assert(getchar()==EOF);
#ifdef rd
	cerr<<endl<<endl<<endl<<"Time Elapsed: "<<((double)(clock()-clk))/CLOCKS_PER_SEC<<endl;
#endif
}
 

VIDEO EDITORIAL:

1 Like