FNLPRBLM - Editorial

PROBLEM LINK:

Practice
Contest

Author and Editorialist: still_me
Tester: Jay_1048576

DIFFICULTY:

EASY-MEDIUM

PREREQUISITES:

Combinatorics, FFT/NTT

PROBLEM:

Given an array of size N and two integers X and Y, answer Q queries. For each query, given K, find the number of ways to choose a subsequence of size K of the original array such that the number of its elements divisible by X is itself divisible by Y. (Both solutions below print the answer modulo 998244353.)

QUICK EXPLANATION:

Let’s count the number of elements that are divisible by X, and call this number d.
Then there will be N-d elements that are not divisible by X.

Now for a given K:

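Let S be the set of all non-negative integers \leq K that are divisible by Y, and denote its elements X_1, X_2, \ldots. These are the allowed counts of chosen elements that are divisible by X.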

For example, if Y=7 and K=35, then
S = \{0, 7, 14, 21, 28, 35\}

For each X_i \in S, we count the ways to simultaneously select X_i of the d elements divisible by X and K-X_i of the N-d elements not divisible by X. We then sum these counts over all values in S.

It will be of the form:

{}^{d}C_{X_1} \cdot {}^{N-d}C_{K-X_1} + {}^{d}C_{X_2} \cdot {}^{N-d}C_{K-X_2} + \ldots

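which is exactly the coefficient of x^K in the product of the two polynomials below. Terms whose exponent exceeds K cannot contribute to x^K. Therefore the same two polynomials serve every query: the first taken over all multiples of Y up to d, the second over all exponents up to N-d.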

1. {}^{d}C_{0} \cdot x^{0} + {}^{d}C_{Y} \cdot x^{Y} + {}^{d}C_{2Y} \cdot x^{2Y} + \ldots (one term for every multiple of Y up to d)

2. {}^{N-d}C_0 \cdot x^0 + {}^{N-d}C_1 \cdot x^1 + \ldots + {}^{N-d}C_{N-d} \cdot x^{N-d}

Multiplying the polynomials by brute force takes O(N^2) time, but FFT/NTT (modulo 998244353) does it in O(N \log N) time.
After that, each query is answered in O(1) time by reading the coefficient of x^K.

SOLUTIONS:

Setter and Editorialist's Solution
//	Code by Sahil Tiwari (still_me)

#include <bits/stdc++.h>
// Short-hand type names.
#define ll long long
#define ull unsigned ll
// Inclusive ascending / descending loops: For(i, 0, n - 1) visits 0..n-1.
#define For(i, j, k) for (int i = (int)(j); i <= (int)(k); i++)
#define Rep(i, j, k) for (int i = (int)(j); i >= (int)(k); i--)
using namespace std;
// NTT-friendly prime: mo - 1 = 119 * 2^23; 3 is used as the generator in FFT::FFTinit.
const int mo = 998244353;
// Size of the precomputed twiddle table; every transform length must be <= FFTN.
const int FFTN = 1 << 18;
// A polynomial is stored as its coefficient vector (index = exponent).
#define poly vector<int>
// Per-test parameters read in main(): elements divisible by X are counted,
// and only counts divisible by Y are kept.
int X , Y;
namespace FFT
{
    // w: twiddles for the current butterfly level; W: powers of the primitive
    // FFTN-th root of unity; R: bit-reversal permutation for the current length.
    int w[FFTN + 5], W[FFTN + 5], R[FFTN + 5];
    // Modular exponentiation: x^y mod mo.
    int power(int x, int y)
    {
        int s = 1;
        for (; y; y /= 2, x = 1ll * x * x % mo)
            if (y & 1)
                s = 1ll * s * x % mo;
        return s;
    }
    // Precompute W[i] = g^i for i in [0, FFTN], with g = 3^((mo - 1) / FFTN).
    // Must run once before any call to Mul.
    void FFTinit()
    {
        W[0] = 1;
        W[1] = power(3, (mo - 1) / FFTN);
        For(i, 2, FFTN) W[i] = 1ll * W[i - 1] * W[1] % mo;
    }
    // Build the bit-reversal table for the smallest power of two L > n and
    // return L. The tables only hold FFTN entries, so callers must keep L <= FFTN.
    int FFTinit(int n)
    {
        int L = 1;
        for (; L <= n; L <<= 1)
            ;
        For(i, 0, L - 1) R[i] = (R[i >> 1] >> 1) | ((i & 1) ? (L >> 1) : 0);
        return L;
    }
    int A[FFTN + 5], B[FFTN + 5];
    // 64-bit working buffer. Butterflies skip the modular reduction of the sums:
    // each level raises the bound by less than mo, so after k levels values are
    // below (k + 1) * mo. They are reduced once at d == 1 << 15 so that
    // p[...] * w[j] stays within 64 bits (assumes inputs are in [0, mo)).
    ull p[FFTN + 5];
    // In-place forward transform of a[0..n-1]; n is a power of two <= FFTN
    // whose bit-reversal table R was prepared by FFTinit.
    void DFT(int *a, int n)
    {
        For(i, 0, n - 1) p[R[i]] = a[i];
        for (int d = 1; d < n; d <<= 1)
        {
            // Twiddles for this level are every len-th entry of W.
            int len = FFTN / (d << 1);
            for (int i = 0, j = 0; i < d; i++, j += len)
                w[i] = W[j];
            for (int i = 0; i < n; i += (d << 1))
                for (int j = 0; j < d; j++)
                {
                    int y = p[i + j + d] * w[j] % mo;
                    p[i + j + d] = p[i + j] + mo - y;
                    p[i + j] += y;
                }
            if (d == 1 << 15)
                For(i, 0, n - 1) p[i] %= mo;
        }
        For(i, 0, n - 1) a[i] = p[i] % mo;
    }
    // In-place inverse transform: same butterflies with inverse roots
    // (W read backwards from W[FFTN] = 1), then scaling by n^{-1}.
    void IDFT(int *a, int n)
    {
        For(i, 0, n - 1) p[R[i]] = a[i];
        for (int d = 1; d < n; d <<= 1)
        {
            int len = FFTN / (d << 1);
            for (int i = 0, j = FFTN; i < d; i++, j -= len)
                w[i] = W[j];
            for (int i = 0; i < n; i += (d << 1))
                for (int j = 0; j < d; j++)
                {
                    int y = p[i + j + d] * w[j] % mo;
                    p[i + j + d] = p[i + j] + mo - y;
                    p[i + j] += y;
                }
            if (d == 1 << 15)
                For(i, 0, n - 1) p[i] %= mo;
        }
        int val = power(n, mo - 2);
        For(i, 0, n - 1) a[i] = p[i] * val % mo;
    }
    // Multiply two polynomials whose coefficients are in [0, mo). Returns the
    // product's coefficient vector of size a.size() + b.size() - 1, or an empty
    // vector (the zero polynomial) if either input is empty.
    poly Mul(const poly &a, const poly &b)
    {
        // Empty means the zero polynomial; this also avoids size() - 1 underflow.
        if (a.empty() || b.empty())
            return poly();
        int sza = a.size() - 1, szb = b.size() - 1;
        poly ans(sza + szb + 1);
        // Use schoolbook multiplication for small operands. It is also the safe
        // fallback when the transform length (smallest power of two > sza + szb)
        // would exceed FFTN: the NTT would then overflow A/B/p/R and run out of
        // twiddle factors (len would become 0).
        if (sza <= 30 || szb <= 30 || sza + szb >= FFTN)
        {
            For(i, 0, sza) For(j, 0, szb)
                ans[i + j] = (ans[i + j] + 1ll * a[i] * b[j]) % mo;
            return ans;
        }
        int L = FFTinit(sza + szb);
        For(i, 0, L - 1) A[i] = (i <= sza ? a[i] : 0);
        For(i, 0, L - 1) B[i] = (i <= szb ? b[i] : 0);
        DFT(A, L);
        DFT(B, L);
        For(i, 0, L - 1) A[i] = 1ll * A[i] * B[i] % mo;
        IDFT(A, L);
        For(i, 0, sza + szb) ans[i] = A[i];
        return ans;
    }
}
using FFT::Mul;
using FFT::power;
// Factorial tables cover arguments 0..mxn-1.
#define mxn 200005
// fact[i] = i! mod mo (filled by pre()).
int fact[mxn];
// inv[i] = (i!)^{-1} mod mo, i.e. inverse factorials (filled by pre()).
int inv[mxn];

void pre()
{
    fact[0] = 1;
    int i;
    For(i, 1, mxn - 1)
        fact[i] = fact[i - 1] * 1ll * i % mo;
    inv[mxn - 1] = power(fact[mxn - 1], mo - 2);
    Rep(i, mxn - 2, 0)
        inv[i] = inv[i + 1] * 1ll * (i + 1) % mo;
}
// Binomial coefficient C(n, r) mod mo, or 0 whenever r lies outside [0, n].
// Assumes n < mxn so that the factorial tables cover it.
int nCr(int n, int r)
{
    // 0 <= r <= n also implies n >= 0.
    const bool inRange = (0 <= r) && (r <= n);
    if (!inRange)
        return 0;
    const long long top = fact[n];
    const long long partial = top * inv[r] % mo;
    return static_cast<int>(partial * inv[n - r] % mo);
}
// Build the polynomial whose coefficient of x^K is the answer for query K.
// It is the product of
//   a) sum over multiples j of Y with 0 <= j <= d of dCj x^j
//   b) sum over 0 <= j <= n-d of (n-d)Cj x^j
// where d is the number of elements of arr divisible by X.
poly compute(int n, poly &arr)
{
    int d = 0;
    for (int i = 0; i < n; i++)
        if (arr[i] % X == 0)
            d++;
    // Only degrees 0..d of `a` and 0..n-d of `b` can be non-zero, so the factors
    // are sized tightly and the product has degree n. Padding both to n + 1
    // would give degree 2n and an NTT length above FFTN once n >= FFTN / 2.
    poly a(d + 1);
    poly b(n - d + 1);
    for (int i = 0; i <= d; i++)
        if (i % Y == 0)
            a[i] = nCr(d, i);
    for (int i = 0; i <= n - d; i++)
        b[i] = nCr(n - d, i);
    poly res = Mul(a, b);
    // Coefficients above n are zero. Keep a length of 2n + 1 so that callers
    // indexing any k <= 2n read 0 instead of going out of bounds.
    res.resize(2 * n + 1, 0);
    return res;
}
// Reads t test cases. Each has n, X, Y, the array, then q queries of K;
// prints the answer for each K on its own line.
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(0);

    FFT::FFTinit();
    pre();
    int t;
    cin >> t;
    while (t--)
    {
        int n;
        cin >> n >> X >> Y;
        poly arr(n);
        For(i, 0, n - 1)
            cin >> arr[i];
        poly ans = compute(n, arr);
        int q;
        cin >> q;
        while (q--)
        {
            int k;
            cin >> k;
            // A size-k subsequence exists only for 0 <= k <= n. Anything the
            // coefficient vector does not cover has answer 0, and this guard
            // prevents reading past the end of `ans`.
            cout << ((k >= 0 && k < static_cast<int>(ans.size())) ? ans[k] : 0) << '\n';
        }
    }
    cerr << "time taken : " << (float)clock() / CLOCKS_PER_SEC << " secs" << endl;
    return 0;
}
Tester's Solution
/*...................................................................*
 *............___..................___.....____...______......___....*
 *.../|....../...\........./|...../...\...|.............|..../...\...*
 *../.|...../.....\......./.|....|.....|..|.............|.../........*
 *....|....|.......|...../..|....|.....|..|............/...|.........*
 *....|....|.......|..../...|.....\___/...|___......../....|..___....*
 *....|....|.......|.../....|...../...\.......\....../.....|./...\...*
 *....|....|.......|../_____|__..|.....|.......|..../......|/.....\..*
 *....|.....\...../.........|....|.....|.......|.../........\...../..*
 *..__|__....\___/..........|.....\___/...\___/.../..........\___/...*
 *...................................................................*
 */
 
#include <bits/stdc++.h>
using namespace std;
// Every `int` below is 64-bit, so a product of two residues < MOD fits without casts.
#define int long long

const int MOD = 998244353;
// Largest factorial index; the tables have N + 1 entries (indices 0..N).
const int N = 200005;
// f[i] = i! mod MOD, invf[i] = (i!)^{-1} mod MOD (filled in main()).
int f[N+1],invf[N+1];
// a, b: the two factor polynomials built in main(); v: their product, written by polyMult().
vector<int> a,b,v;

// Computes a^b mod MOD by recursively squaring a^(b/2) (intended for b >= 0).
int power(int a,int b)
{
    if (b == 0)
        return 1;
    const int half = power(a, b / 2);
    const int squared = (half * half) % MOD;
    return (b % 2 != 0) ? (squared * a) % MOD : squared;
}

// Modular inverse via Fermat's little theorem (MOD is prime): a^(MOD-2).
int inverse(int a)
{
    const int exponent = MOD - 2;
    return power(a, exponent);
}

// Binomial coefficient C(n, r) mod MOD from the factorial tables. Returns 0
// when r lies outside [0, n]; a negative r would otherwise index invf[-1].
int nCr(int n,int r)
{
	if(r<0 || n<r)
		return 0;
	int ans=f[n];
	ans=(ans*invf[r])%MOD;
	ans=(ans*invf[n-r])%MOD;
	return ans;
}

// In-place iterative number-theoretic transform modulo MOD: a bit-reversal
// permutation followed by butterfly passes of doubling length. a.size() must
// be a power of two not exceeding root_pw. With invert = true it computes the
// inverse transform, including the final multiplication by n^{-1}.
void fft(vector<int> &a,bool invert) 
{
	// root is assumed to be a primitive root_pw-th (2^23-th) root of unity
	// modulo MOD (NOTE(review): confirm); root_1 is its inverse.
	int root = 31;
	int root_pw = 1<<23;
	int root_1 = inverse(root);
    int n = a.size();
    // Reorder into bit-reversed index order; j holds the bit-reversal of i.
    for(int i=1,j=0;i<n;i++) 
    {
        int bit=n>>1;
        for(;j&bit;bit>>=1)
            j^=bit;
        j^=bit;
        if(i<j)
            swap(a[i],a[j]);
    }
    for(int len=2;len<=n;len<<=1) 
    {
        // Square the root_pw-th root down to a len-th root of unity.
        int wlen = invert? root_1:root;
        for(int i=len;i<root_pw;i<<=1)
            wlen = (wlen*wlen)%MOD;
        for(int i=0;i<n;i+=len) 
        {
            int w = 1;
            for(int j=0;j<len/2;j++) 
            {
                // Butterfly (u, v*w) -> (u + v*w, u - v*w), kept in [0, MOD).
                int u=a[i+j], v= (a[i+j+len/2]*w)%MOD;
                a[i+j] = u+v<MOD? u+v:u+v-MOD;
                a[i+j+len/2] = u-v>=0? u-v:u-v+MOD;
                w = (w*wlen)%MOD;
            }
        }
    }
    if(invert) 
    {
        // Inverse transform: divide every entry by n.
        int n_1=inverse(n);
        for(int &x:a)
            x = (x*n_1)%MOD;
    }
}

// Multiplies the global polynomials a and b and stores the result in v.
// The transform length is the smallest power of two >= a.size() + b.size().
// That length exceeds the product's degree, so the cyclic convolution equals
// the exact product padded with zeros.
void polyMult()
{
    const size_t need = a.size() + b.size();
    size_t len = 1;
    while (len < need)
        len *= 2;

    vector<int> fa(a);
    vector<int> fb(b);
    fa.resize(len, 0);
    fb.resize(len, 0);
    fft(fa, false);
    fft(fb, false);

    // Pointwise product in the transformed domain, then transform back.
    v.resize(len, 0);
    for (size_t i = 0; i < len; ++i)
        v[i] = (fa[i] * fb[i]) % MOD;
    fft(v, true);
}

int32_t main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);
    f[0]=1;
	for(int i=1;i<=N;i++)
		f[i]=(f[i-1]*i)%MOD;
	for(int i=0;i<=N;i++)
		invf[i]=inverse(f[i]);
    int tt=1;
    cin >> tt;
    while(tt--)
    {
        int n,x,y;
        cin >> n >> x >> y;
        int arr[n];
        for(int i=0;i<n;i++)
            cin >> arr[i];
        int p = 0;
        for(int i=0;i<n;i++)
        {
            if(arr[i]%x==0)
                p++;
        }
        a.clear();
        b.clear();
        v.clear();
        for(int i=0;i<=p;i++)
        {
            if(i%y==0)
                a.push_back(nCr(p,i));
            else
                a.push_back(0);
        }
        for(int i=0;i<=n-p;i++)
        {
            b.push_back(nCr(n-p,i));
        }
		polyMult();
		int q;
		cin >> q;
		for(int i=0;i<q;i++)
		{
		    int k;
		    cin >> k;
		    cout << v[k] << '\n';
		}
    }
    return 0;
}
2 Likes