# PROBLEM LINK:

Author and Editorialist: still_me
Tester: Jay_1048576

EASY-MEDIUM

# PREREQUISITES:

Combinatorics , FFT/NTT

# PROBLEM:

Given an array of size N and two integers X and Y, answer Q queries. For each query, find the number of ways (modulo 998244353) to choose a subsequence of size K of the original array such that the number of elements in the subsequence that are divisible by X is itself divisible by Y.

# QUICK EXPLANATION:

Let’s count the number of elements that are divisible by X, and call this number d.
Then there will be N-d elements that are not divisible by X.

Now for a given K:

Let S be the set of all non-negative integers s \leq K that are divisible by Y (only values with s \leq d can contribute a non-zero term).

For example, if Y = 7 and K = 35, then
S = \{0, 7, 14, 21, 28, 35\}.

For each s \in S, count the ways to choose s distinct elements from the d elements divisible by X and K - s distinct elements from the N - d elements not divisible by X, then sum these counts over all s \in S.

It will be of the form:

^dC_{s_1} \cdot ^{N-d}C_{K-s_1} + ^dC_{s_2} \cdot ^{N-d}C_{K-s_2} + \dots, where s_1, s_2, \dots are the elements of S.

which is exactly the coefficient of x^K in the product of the polynomials

1. A(x) = ^dC_{0} \cdot x^{0} + ^dC_{Y} \cdot x^{Y} + ^dC_{2Y} \cdot x^{2Y} + \dots (one term for every multiple of Y up to d)

2. B(x) = ^{N-d}C_0 \cdot x^0 + ^{N-d}C_1 \cdot x^1 + \dots + ^{N-d}C_{N-d} \cdot x^{N-d}

Multiplying the polynomials naively takes O(N^2) time, but FFT/NTT (modulo 998244353) does it in O(N \log N) time.
Since neither polynomial depends on K, the product is computed once per test case, and each query is then answered in O(1) by reading the coefficient of x^K.

# SOLUTIONS:

Setter and Editorialist's Solution
//	Code by Sahil Tiwari (still_me)

#include <bits/stdc++.h>
// Shorthand integer types used by the setter's solution.
#define ll long long
#define ull unsigned ll
// For: inclusive ascending loop i = j..k; Rep: inclusive descending loop i = j..k.
#define For(i, j, k) for (int i = (int)(j); i <= (int)(k); i++)
#define Rep(i, j, k) for (int i = (int)(j); i >= (int)(k); i--)
using namespace std;
// NTT-friendly prime modulus (998244353 = 119 * 2^23 + 1).
const int mo = 998244353;
// Maximum transform length; FFT::Mul's static buffers are sized from this.
const int FFTN = 1 << 18;
#define poly vector<int>
// X: divisor tested against each array element; Y: required divisor of that count.
int X , Y;
namespace FFT
{
// w: twiddle factors for the butterfly level currently being processed.
// W: W[i] = g^i for g = 3^((mo-1)/FFTN), a primitive FFTN-th root of unity
//    (3 is a primitive root of 998244353); filled by FFTinit().
// R: bit-reversal permutation for the current transform length; filled by FFTinit(n).
int w[FFTN + 5], W[FFTN + 5], R[FFTN + 5];

// Returns x^y mod mo by binary exponentiation (y >= 0).
int power(int x, int y)
{
    int s = 1;
    for (; y; y /= 2, x = 1ll * x * x % mo)
        if (y & 1)
            s = 1ll * s * x % mo;
    return s;
}

// Precomputes W[0..FFTN]. Must run once before any Mul() that takes the NTT path.
void FFTinit()
{
    W[0] = 1;
    W[1] = power(3, (mo - 1) / FFTN);
    For(i, 2, FFTN) W[i] = 1ll * W[i - 1] * W[1] % mo;
}

// Picks the transform length L (smallest power of two strictly greater than n),
// fills R[0..L-1] with the bit-reversal permutation and returns L.
// Requires n < FFTN so that L <= FFTN fits the static buffers.
int FFTinit(int n)
{
    int L = 1;
    for (; L <= n; L <<= 1)
        ;
    For(i, 0, L - 1) R[i] = (R[i >> 1] >> 1) | ((i & 1) ? (L >> 1) : 0);
    return L;
}

int A[FFTN + 5], B[FFTN + 5];
// Work buffer holding partially reduced values, so additions in the
// butterflies can skip the `% mo`.
ull p[FFTN + 5];

// Forward NTT of a[0..n-1] in place (n a power of two, n <= FFTN).
void DFT(int *a, int n)
{
    For(i, 0, n - 1) p[R[i]] = a[i];
    for (int d = 1; d < n; d <<= 1)
    {
        // Twiddles for this level: w[i] = W[i * FFTN / (2d)].
        int len = FFTN / (d << 1);
        for (int i = 0, j = 0; i < d; i++, j += len)
            w[i] = W[j];
        for (int i = 0; i < n; i += (d << 1))
            for (int j = 0; j < d; j++)
            {
                int y = p[i + j + d] * w[j] % mo;
                p[i + j + d] = p[i + j] + mo - y;
                p[i + j] += y;
            }
        // Each level can grow an entry by up to mo (y < mo above); reducing
        // once at this level keeps p * w within 64 bits for n <= FFTN.
        if (d == 1 << 15)
            For(i, 0, n - 1) p[i] %= mo;
    }
    For(i, 0, n - 1) a[i] = p[i] % mo;
}

// Inverse NTT of a[0..n-1] in place, including the final division by n.
void IDFT(int *a, int n)
{
    For(i, 0, n - 1) p[R[i]] = a[i];
    for (int d = 1; d < n; d <<= 1)
    {
        // Inverse twiddles: walk W downward from W[FFTN] = g^FFTN = 1.
        int len = FFTN / (d << 1);
        for (int i = 0, j = FFTN; i < d; i++, j -= len)
            w[i] = W[j];
        for (int i = 0; i < n; i += (d << 1))
            for (int j = 0; j < d; j++)
            {
                int y = p[i + j + d] * w[j] % mo;
                p[i + j + d] = p[i + j] + mo - y;
                p[i + j] += y;
            }
        // Same lazy-reduction point as in DFT.
        if (d == 1 << 15)
            For(i, 0, n - 1) p[i] %= mo;
    }
    int val = power(n, mo - 2);
    For(i, 0, n - 1) a[i] = p[i] * val % mo;
}

// Returns the product of polynomials a and b (coefficients mod mo).
// Uses schoolbook multiplication when either factor has degree <= 30,
// otherwise the NTT above.
// Throws std::length_error when the product is too long for the FFTN buffers.
poly Mul(const poly &a, const poly &b)
{
    // The product with an empty (zero) polynomial is empty; this also avoids
    // the size() - 1 underflow below.
    if (a.empty() || b.empty())
        return poly();
    int sza = a.size() - 1, szb = b.size() - 1;
    poly ans(sza + szb + 1);
    if (sza <= 30 || szb <= 30)
    {
        For(i, 0, sza) For(j, 0, szb)
            ans[i + j] = (ans[i + j] + 1ll * a[i] * b[j]) % mo;
        return ans;
    }
    // FFTinit(n) returns L > n, and L must not exceed FFTN. Otherwise R/A/B/p
    // would be written out of bounds and the twiddle stride FFTN / (2d) would
    // reach 0, so fail loudly instead of silently.
    if (sza + szb >= FFTN)
        throw length_error("FFT::Mul: product degree exceeds FFTN");
    int L = FFTinit(sza + szb);
    For(i, 0, L - 1) A[i] = (i <= sza ? a[i] : 0);
    For(i, 0, L - 1) B[i] = (i <= szb ? b[i] : 0);
    DFT(A, L);
    DFT(B, L);
    For(i, 0, L - 1) A[i] = 1ll * A[i] * B[i] % mo;
    IDFT(A, L);
    For(i, 0, sza + szb) ans[i] = A[i];
    return ans;
}
} // namespace FFT
using FFT::Mul;
using FFT::power;
// Factorial tables cover arguments 0..mxn-1.
#define mxn 200005
// fact[i] = i! mod mo; inv[i] = (i!)^{-1} mod mo. Both are filled by pre().
int fact[mxn];
int inv[mxn];

void pre()
{
fact[0] = 1;
int i;
For(i, 1, mxn - 1)
fact[i] = fact[i - 1] * 1ll * i % mo;
inv[mxn - 1] = power(fact[mxn - 1], mo - 2);
Rep(i, mxn - 2, 0)
inv[i] = inv[i + 1] * 1ll * (i + 1) % mo;
}
// Binomial coefficient C(n, r) mod mo, read from the fact / inv tables.
// Yields 0 whenever r lies outside [0, n] (this also covers every n < 0).
int nCr(int n, int r)
{
    const bool in_range = 0 <= r && r <= n;
    if (!in_range)
        return 0;
    const long long numerator = fact[n];
    return numerator * inv[r] % mo * inv[n - r] % mo;
}
// Builds the answer table for one test case. With d = number of arr elements
// divisible by X, it multiplies
//   A(x) = sum over i in [0, d], i % Y == 0, of C(d, i)   * x^i
//   B(x) = sum over i in [0, n - d]          of C(n-d, i) * x^i
// Coefficient k of A*B is the number (mod mo) of size-k subsequences whose
// count of X-multiples is divisible by Y. The result is zero-padded to length
// 2n + 1 so that callers indexing 0..2n behave exactly as before.
poly compute(int n, const poly &arr)
{
    int d = 0;
    for (int i = 0; i < n; i++)
        if (arr[i] % X == 0)
            d++;
    // A only needs degree d and B only degree n - d. Sizing them tightly keeps
    // the product degree at n instead of 2n, which halves the FFT length and
    // keeps larger n within FFT::Mul's FFTN limit.
    poly a(d + 1);
    poly b(n - d + 1);
    for (int i = 0; i <= d; i++)
        if (i % Y == 0)
            a[i] = nCr(d, i);
    for (int i = 0; i <= n - d; i++)
        b[i] = nCr(n - d, i);
    poly ans = Mul(a, b);
    // Higher coefficients are all zero; pad to the original table length.
    ans.resize(2 * n + 1, 0);
    return ans;
}
// Setter's driver: for each test case, reads n, X, Y and the array, builds the
// answer table once, then answers each query k in O(1).
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(0);

    FFT::FFTinit();
    pre();
    int t;
    cin >> t;
    while (t--)
    {
        int n;
        cin >> n >> X >> Y;
        poly arr(n);
        For(i, 0, n - 1)
            cin >> arr[i];
        poly ans = compute(n, arr);
        int q;
        cin >> q;
        while (q--)
        {
            int k;
            cin >> k;
            // No size-k subsequence exists for k < 0 or k beyond the table,
            // so print 0 there instead of reading out of bounds.
            cout << ((k >= 0 && k < (int)ans.size()) ? ans[k] : 0) << '\n';
        }
    }
    cerr << "time taken : " << (float)clock() / CLOCKS_PER_SEC << " secs" << endl;
    return 0;
}

Tester's Solution
/*...................................................................*
*............___..................___.....____...______......___....*
*.../|....../...\........./|...../...\...|.............|..../...\...*
*../.|...../.....\......./.|....|.....|..|.............|.../........*
*....|....|.......|...../..|....|.....|..|............/...|.........*
*....|....|.......|..../...|.....\___/...|___......../....|..___....*
*....|....|.......|.../....|...../...\.......\....../.....|./...\...*
*....|....|.......|../_____|__..|.....|.......|..../......|/.....\..*
*....|.....\...../.........|....|.....|.......|.../........\...../..*
*..__|__....\___/..........|.....\___/...\___/.../..........\___/...*
*...................................................................*
*/

#include <bits/stdc++.h>
using namespace std;
// Every `int` below is really `long long`, so a product of two residues mod MOD fits.
#define int long long

const int MOD = 998244353;
// Factorial tables cover arguments 0..N.
const int N = 200005;
int f[N+1],invf[N+1];
// a, b: input polynomials for polyMult(); v: receives their product.
vector<int> a,b,v;

// Computes a^b mod MOD by recursive squaring:
// a^b = (a^(b/2))^2, times one extra factor of a when b is odd.
int power(int a,int b)
{
    if (b == 0)
        return 1;
    const int half = power(a, b / 2);
    const int squared = (half * half) % MOD;
    return (b % 2) ? (squared * a) % MOD : squared;
}

// Modular inverse of a modulo the prime MOD, via Fermat's little theorem.
int inverse(int a)
{
    const int fermat_exponent = MOD - 2;
    return power(a, fermat_exponent);
}

// Binomial coefficient C(n, r) mod MOD from the factorial tables f / invf.
// Returns 0 when r is outside [0, n], matching the setter's nCr. This also
// keeps a negative r (or n) from indexing the tables out of bounds.
int nCr(int n,int r)
{
    if (r < 0 || n < r)
        return 0;
    int ans = f[n];
    ans = (ans * invf[r]) % MOD;
    ans = (ans * invf[n - r]) % MOD;
    return ans;
}

// In-place iterative number-theoretic transform of `a` modulo MOD.
// Assumes a.size() is a power of two (polyMult guarantees this) and at most root_pw.
// invert == false: forward transform; invert == true: inverse transform,
// including the final multiplication by n^{-1}.
void fft(vector<int> &a,bool invert)
{
// root is assumed to be a primitive root_pw-th (2^23-th) root of unity mod MOD
// (NOTE(review): confirm 31 has order exactly 2^23 mod 998244353); root_1 is its inverse.
int root = 31;
int root_pw = 1<<23;
int root_1 = inverse(root);
int n = a.size();
// Bit-reversal permutation of the input order.
for(int i=1,j=0;i<n;i++)
{
int bit=n>>1;
for(;j&bit;bit>>=1)
j^=bit;
j^=bit;
if(i<j)
swap(a[i],a[j]);
}
for(int len=2;len<=n;len<<=1)
{
// Square the base root log2(root_pw/len) times: wlen = root^(root_pw/len).
int wlen = invert? root_1:root;
for(int i=len;i<root_pw;i<<=1)
wlen = (wlen*wlen)%MOD;
for(int i=0;i<n;i+=len)
{
int w = 1;
// Butterfly over one length-len segment (local `v` shadows the global v).
for(int j=0;j<len/2;j++)
{
int u=a[i+j], v= (a[i+j+len/2]*w)%MOD;
a[i+j] = u+v<MOD? u+v:u+v-MOD;
a[i+j+len/2] = u-v>=0? u-v:u-v+MOD;
w = (w*wlen)%MOD;
}
}
}
// Inverse transform: scale every entry by n^{-1} mod MOD.
if(invert)
{
int n_1=inverse(n);
for(int &x:a)
x = (x*n_1)%MOD;
}
}

void polyMult()
{
int n=1;
while(n<(a.size()+b.size()))
{
n<<=1;
}
vector<int> fa(a.begin(),a.end());
fa.resize(n,0);
vector<int> fb(b.begin(),b.end());
fb.resize(n,0);
fft(fa,false);
fft(fb,false);
v.resize(n,0);
for(int i=0;i<n;i++)
{
v[i]=(fa[i]*fb[i])%MOD;
}
fft(v,true);
return;
}

int32_t main()
{
ios_base::sync_with_stdio(false);
cin.tie(NULL);
cout.tie(NULL);
f[0]=1;
for(int i=1;i<=N;i++)
f[i]=(f[i-1]*i)%MOD;
for(int i=0;i<=N;i++)
invf[i]=inverse(f[i]);
int tt=1;
cin >> tt;
while(tt--)
{
int n,x,y;
cin >> n >> x >> y;
int arr[n];
for(int i=0;i<n;i++)
cin >> arr[i];
int p = 0;
for(int i=0;i<n;i++)
{
if(arr[i]%x==0)
p++;
}
a.clear();
b.clear();
v.clear();
for(int i=0;i<=p;i++)
{
if(i%y==0)
a.push_back(nCr(p,i));
else
a.push_back(0);
}
for(int i=0;i<=n-p;i++)
{
b.push_back(nCr(n-p,i));
}
polyMult();
int q;
cin >> q;
for(int i=0;i<q;i++)
{
int k;
cin >> k;
cout << v[k] << '\n';
}
}
return 0;
}

2 Likes