# BOXES - Editorial

Setter: Alex
Tester: Harris Leung
Editorialist: Trung Dang

3633

# PREREQUISITES:

FFT with Divide and Conquer

# PROBLEM:

There are N boxes. Initially, each box contains one ball.

For each K from 1 to N, solve the following problem:

Choose a subset of boxes of size K equiprobably at random, and place one additional ball in each chosen box.

Then again choose a subset of boxes of size K equiprobably at random, and place one additional ball in each chosen box. This second choice is independent of the first.

What is the expected value of the product of the numbers of balls over all boxes? It can be shown that the answer can be expressed as \frac{P}{Q} for some integers P, Q, where Q is not divisible by 998\,244\,353. Output the value R with 0 \le R < 998\,244\,353 such that P \equiv Q \cdot R \pmod{998\,244\,353}.

# EXPLANATION:

Let's first derive a formula for each choice of K. Instead of computing the expected value directly, we compute the sum of the product of ball counts over every ordered pair of chosen subsets. We then divide this sum by the number of such pairs, which is \binom{N}{K}^2.

We can calculate the sum as follows:

• There are \binom{N}{K} ways to choose the first subset.
• We iterate over the size of the overlap between the two subsets. Suppose the overlap contains i boxes; then:
• There are \binom{K}{i} \cdot \binom{N - K}{K - i} ways to choose the second subset (i boxes from the first subset and K - i boxes outside it).
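• The product in this case is 2^{2K - 2i} \cdot 3^i = 4^K \cdot (\frac{3}{4})^i. The i shared boxes hold 3 balls each, the 2K - 2i boxes chosen exactly once hold 2 balls each, and every other box holds 1 ball.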

Therefore, for each K the sum of the products is \binom{N}{K} \cdot 4^K \cdot \sum_{i = 0}^K \binom{K}{i} \cdot \binom{N-K}{K - i} \cdot (\frac{3}{4})^i. The inner sum equals [x^K]\,(1 + \frac{3}{4}x)^K (1 + x)^{N-K}, where [x^K] denotes the coefficient of x^K in the polynomial. To see this, take x from i of the K factors (1 + \frac{3}{4}x) and from K - i of the N - K factors (1 + x); these choices contribute exactly \binom{K}{i} \binom{N-K}{K-i} (\frac{3}{4})^i. After dividing by \binom{N}{K}^2, the answer for K is \frac{4^K}{\binom{N}{K}} \cdot [x^K]\,(1 + \frac{3}{4}x)^K (1 + x)^{N-K}. The factor \frac{4^K}{\binom{N}{K}} is easily applied at the end, so the task reduces to computing [x^K]\,(1 + \frac{3}{4}x)^K (1 + x)^{N-K} for every K from 1 to N. The author's blog post discusses exactly this problem form and shows how to solve it in O(N \log^2 N) time using divide and conquer with FFT.

# TIME COMPLEXITY:

Time complexity is O(N \log^2 N).

# SOLUTION:

Tester's Solution
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
#define fi first
#define se second
// NTT-friendly prime: 998244353 = 119 * 2^23 + 1.
constexpr ll mod = 998244353;
// Polynomial as a coefficient vector, lowest degree first. The arithmetic
// operators below assume every coefficient is already reduced to [0, mod).
using poly = vector<ll>;
// log2 of the largest supported transform length. Must satisfy mb <= 23, and
// 2^mb must be >= the padded length of every product computed. Can change.
constexpr int mb = 19;
ll roots[1 << mb];  // powers of a primitive 2^mb-th root of unity, filled by pre()
int rev[1 << mb];   // bit-reversal scratch table written by fft()
// Modular exponentiation: returns x^y modulo `mod` by iterative binary
// exponentiation. The base is first reduced into [0, mod), so any ll value
// (including negative ones or values >= mod) is accepted, and every
// intermediate product is at most (mod-1)^2, which fits in ll.
// y must be non-negative; for y <= 0 the loop does not run and 1 is returned.
// pw(x, 0) == 1 for every x, including x == 0.
ll pw(ll x,ll y){
    x %= mod;
    if(x < 0) x += mod;
    ll res = 1;
    while(y > 0){
        if(y & 1) res = res * x % mod;
        x = x * x % mod;
        y >>= 1;
    }
    return res;
}
void operator<<(ostream& out,poly y){
for(auto c:y) out << c << ' ';
}
// Coefficient-wise sum of two polynomials modulo `mod`. The shorter operand is
// zero-padded to the longer length. A single conditional subtraction is used,
// so inputs are expected to be reduced to [0, mod) already.
poly operator+(poly x,poly y){
    const size_t len = max(x.size(), y.size());
    x.resize(len);
    y.resize(len);
    for(size_t idx = 0; idx < len; ++idx){
        const ll s = x[idx] + y[idx];
        x[idx] = (s >= mod) ? s - mod : s;
    }
    return x;
}
// Coefficient-wise difference x - y modulo `mod`. The shorter operand is
// zero-padded to the longer length. For inputs in [0, mod) the result also
// stays in [0, mod): a negative difference is lifted by one `mod`.
poly operator-(poly x,poly y){
    const size_t len = max(x.size(), y.size());
    x.resize(len);
    y.resize(len);
    for(size_t idx = 0; idx < len; ++idx){
        const ll d = x[idx] - y[idx];
        x[idx] = (d < 0) ? d + mod : d;
    }
    return x;
}

void pre(){
roots[0]=1;
roots[1]=pw(15311432,1<<(23-mb));
for(int i=1; i<(1<<mb) ;i++) roots[i]=roots[i-1]*roots[1]%mod;
}
// In-place forward number-theoretic transform of p (iterative radix-2
// Cooley-Tukey, decimation in time). Afterwards p[j] = sum_i p_old[i] * w^(i*j),
// where w = roots[(1<<mb)/n] is a primitive n-th root of unity.
// Preconditions:
//   - pre() has filled roots[].
//   - n = p.size() is a power of two with n <= 2^mb. Otherwise m below is not
//     a positive exact stride.
//   - Entries are in [0, mod).
// There is no inverse mode here; operator* builds the inverse by reversing
// entries 1..n-1 and transforming again, then scaling by 1/n.
void fft(poly &p){
int n=p.size();
// Already set by pre(); kept as a safeguard.
roots[0]=1;
// Stride into roots[] so that roots[t*m] runs over powers of an n-th root.
int m=(1<<mb)/n;
// Bit-reversal permutation of the input. rev[0] is never written here; it
// relies on the global array being zero-initialized.
for(int i=1; i<n ;i++){
rev[i]=rev[i/2]/2+((i&1)*n/2);
if(i<rev[i]) swap(p[i],p[rev[i]]);
}
// Butterfly passes: each pass merges pairs of length-k transforms into
// length-2k transforms.
for(int k=1; k<n ;k*=2){
for(int i=0; i<n ;i+=2*k){
// cur*m = j * 2^mb/(2k), so roots[cur*m] is the j-th power of a primitive
// (2k)-th root of unity.
int cur=0,step=n/(2*k);
for(int j=0; j<k;j++,cur+=step){
ll x=p[i+j];
ll y=p[i+j+k]*roots[cur*m]%mod;
// (x, y) -> (x + y, x - y) mod `mod`, using conditional add/subtract only.
p[i+j]=(x+y>=mod?x+y-mod:x+y);
p[i+j+k]=(x>=y?x-y:x+mod-y);
}
}
}
}
// Multiplies two polynomials modulo 998244353 using the NTT in fft().
// Inputs are coefficient vectors, lowest degree first, with entries in [0, mod).
// An empty operand is treated as the zero polynomial.
// The result has trailing zero coefficients trimmed but always keeps at least
// one entry, so a zero product is returned as {0}.
// Requires pre() to have run, and the padded length (the next power of two
// >= x.size()+y.size()-1) must not exceed 2^mb.
poly operator*(poly x,poly y){
    // Without this guard x.size()+y.size()-1 underflows and the sizing loop never ends.
    if(x.empty() || y.empty()) return poly(1, 0);
    const size_t need = x.size() + y.size() - 1;  // number of product coefficients
    size_t n = 1;
    while(n < need) n *= 2;
    assert(n <= (size_t(1) << mb));  // roots[] only supports transforms up to 2^mb
    x.resize(n, 0); y.resize(n, 0);
    fft(x); fft(y);
    for(size_t i = 0; i < n; i++) x[i] = x[i] * y[i] % mod;
    // Reversing entries 1..n-1 before a forward transform evaluates at w^(-j),
    // i.e. performs the inverse transform up to the 1/n factor applied next.
    reverse(x.begin() + 1, x.end());
    fft(x);
    const ll inv = pw(n, mod - 2);
    for(size_t i = 0; i < n; i++) x[i] = x[i] * inv % mod;
    while(x.size() > 1 && x.back() == 0) x.pop_back();
    return x;
}/*
// Naive O(|x| * |y|) reference multiplication, handy for cross-checking operator*.
vector<ll>multiply2(vector<ll>x,vector<ll>y){
vector<ll>z(x.size()+y.size()-1, 0);
for(size_t i=0; i<x.size() ;i++){
for(size_t j=0; j<y.size() ;j++){
z[i+j]=(z[i+j]+x[i]*y[j])%mod;
}
}
return z;
}*/

// Capacity of the factorial tables: main fills indices 0..n, so n <= N - 1.
constexpr int N = 250001;
ll f[N];    // f[i] = i! mod 998244353
ll inf[N];  // inf[i] = 1 / i! mod 998244353

//my code starts here
int n, m;  // n is read in main; m is not referenced by main (fft() has its own local m)
int main(){
ios::sync_with_stdio(false);cin.tie(0);
pre();/*
{
poly f={1,mod-1};
poly g={1,1};
f.resize(12);
auto h=inverse(f);
cout << h;
return 0;
}*/
cin >> n;
f[0]=1;
for(int i=1; i<=n ;i++) f[i]=f[i-1]*i%mod;
inf[n]=pw(f[n],mod-2);
for(int i=n; i>=1 ;i--) inf[i-1]=inf[i]*i%mod;
poly p(n+1),q(n+1);
for(int i=0; i<=n ;i++){
p[i]=f[n-i]*inf[i]%mod*pw(mod-4,n-i)%mod;
q[i]=inf[i]*inf[i]%mod;
}
p=p*q;
for(int k=1; k<=n ;k++){
ll cur=p[n-k];
ll num=pw(9,k)*f[n-k]%mod*f[n-k]%mod;
ll den=pw(4,k)*pw(mod-3,n)%mod*f[n]%mod;
ll ans=cur*num%mod*pw(den,mod-2)%mod;
cout << ans << ' ';
}
cout << '\n';
}

Editorialist's Solution
#include <bits/stdc++.h>
#include <atcoder/convolution> // https://atcoder.github.io/ac-library/production/document_en/
using namespace std;
using namespace atcoder;

using mint = modint998244353;

// The problem size n lives only in main() and is passed to helpers explicitly.
// No global n is kept, because main's local would shadow it and any helper
// reading a global would silently see 0.
vector<mint> fct;  // fct[i] = i!, filled in main() for i <= n
vector<mint> ans;  // ans[k] = [x^k] P^k * Q^(n - k), filled by solve()
// P = 1 + (3/4)x and Q = 1 + x, stored as coefficient vectors {c0, c1}.
vector<mint> p = {mint(1), mint(3) / 4}, q = {mint(1), mint(1)};

// Binomial coefficient C(n, k) modulo 998244353; returns 0 when k lies outside
// [0, n]. Requires fct to hold factorials up to index n.
mint C(int n, int k) {
    if (k < 0 || k > n) {
        return mint(0);
    }
    const mint denom = fct[k] * fct[n - k];
    return fct[n] / denom;
}

// Returns the x + 1 coefficients (lowest degree first) of (p[0] + p[1] * X)^x,
// expanded with the binomial theorem. Only p[0] and p[1] are read, so p is
// expected to hold exactly a linear polynomial.
vector<mint> pow(vector<mint> p, int x) {
    const mint a = p[0];
    const mint b = p[1];
    vector<mint> coef(x + 1);
    for (int i = 0; i <= x; ++i) {
        // Term i of the expansion: C(x, i) * b^i * a^(x - i).
        coef[i] = C(x, i) * b.pow(i) * a.pow(x - i);
    }
    return coef;
}

// Divide and conquer over k in [l, r]: fills ans[j] = [x^j] P^j * Q^(n - j) for
// every j in [l, r], where P = p = 1 + (3/4)x and Q = q = 1 + x.
// On entry, poly holds coefficients of P^l * Q^(n - r) starting at x^k
// (poly[t] is the coefficient of x^(k + t)). It is truncated to the
// coefficients that can still affect some ans[j], j in [l, r].
void solve(int l, int r, vector<mint> poly, int k) {
// poly is P^l * Q^(n - r), truncated to the needed range and shifted by x^k
if (l == r) {
// Here poly describes P^l * Q^(n - l), and k == l: each child's lower focus
// bound equals its own l when it is a single point, so poly[0] is [x^l].
ans[l] = poly[0];
} else {
int m = (l + r) / 2;
{
// Left half [l, m]: multiply by Q^(r - m) to obtain P^l * Q^(n - m).
// Reaching any leaf j in [l, m] later multiplies by P^(j - l) * Q^(m - j),
// which has total degree m - l. So only x^(2l - m) .. x^m are still needed.
int l_focus = 2 * l - m, r_focus = m;
vector<mint> to_left = convolution(poly, pow(q, r - m));
// to_left[t] is the coefficient of x^(k + t). Keep x^new_k .. x^r_focus.
int new_k = max(k, l_focus), new_sz = min((int)to_left.size(), r_focus - k + 1);
solve(l, m, vector<mint>(to_left.begin() + new_k - k, to_left.begin() + new_sz), new_k);
}
{
// Right half [m + 1, r]: multiply by P^(m + 1 - l) to obtain P^(m + 1) * Q^(n - r).
// The remaining growth to leaf j is P^(j - m - 1) * Q^(r - j), of degree
// r - m - 1. So only x^(2(m + 1) - r) .. x^r are still needed.
int l_focus = 2 * (m + 1) - r, r_focus = r;
vector<mint> to_right = convolution(poly, pow(p, m + 1 - l));
// to_right[t] is the coefficient of x^(k + t). Keep x^new_k .. x^r_focus.
int new_k = max(k, l_focus), new_sz = min((int)to_right.size(), r_focus - k + 1);
solve(m + 1, r, vector<mint>(to_right.begin() + new_k - k, to_right.begin() + new_sz), new_k);
}
}
}

int main() {
ios_base::sync_with_stdio(false);
cin.tie(nullptr);
int n; cin >> n;
fct.resize(n + 1); ans.resize(n + 1);
fct[0] = 1;
for (int i = 1; i <= n; i++) {
fct[i] = fct[i - 1] * i;
}
// ans_k = 2^(2k) * C(n, k) * sum(i = 0 -> k, (k choose i) * (n - k choose k - i) * (3/4)^i)
//       = 4^k * C(n, k) * ((1 + 3/4x)^k * (1 + x)^(n - k))[x^k]
solve(0, n, {mint(1)}, 0);
for (int i = 1; i <= n; i++) {
cout << (ans[i] * mint(4).pow(i) / C(n, i)).val() << " ";
}
}