FIZZBUZZ2311 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Authors: naisheel, jalp1428
Tester: tabr
Editorialist: iceknight1093

DIFFICULTY:

2833

PREREQUISITES:

Sieve of Eratosthenes, FFT (specifically, NTT)

PROBLEM:

Given N and M, find the number of arrays of length N containing distinct integers between 1 and M, such that if i \lt j, then A_i shouldn’t have more distinct prime factors than A_j.

EXPLANATION:

Let’s first compute the number of distinct prime factors for each integer between 1 and M.
This can be easily done with a sieve in \mathcal{O}(M\log\log M).

Let \text{freq}[x] be the number of integers in [1, M] with exactly x distinct prime factors.
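As a minimal sketch of this precomputation (the function names `distinct_prime_counts` and `class_sizes` are illustrative, not taken from any of the solutions below), the sieve adds 1 to every multiple of each prime, and the frequencies are then tallied directly:

```python
def distinct_prime_counts(m):
    """Return a list pr where pr[v] is the number of distinct primes dividing v."""
    pr = [0] * (m + 1)
    for p in range(2, m + 1):
        if pr[p] == 0:  # no smaller prime divides p, so p itself is prime
            for multiple in range(p, m + 1, p):
                pr[multiple] += 1
    return pr


def class_sizes(m):
    """Return freq, where freq[x] counts the v in [1, m] with pr[v] == x."""
    pr = distinct_prime_counts(m)
    freq = [0] * (max(pr[1:]) + 1)
    for v in range(1, m + 1):
        freq[pr[v]] += 1
    return freq


# For m = 10: {1} has 0 primes, {2,3,4,5,7,8,9} have 1, {6,10} have 2.
print(class_sizes(10))  # [1, 7, 2]
```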

Suppose we fix which N elements out of [1, M] are going to be placed in A.
Then, since we want \text{pr}(A_i) \leq \text{pr}(A_j) whenever i \lt j (where \text{pr}(v) denotes the number of distinct prime factors of v), the order of these elements is almost fixed.
The only leeway we have is when two elements have the same prime count, in which case they can appear in any order.
More generally, we can order the elements with a fixed prime count however we like, but that’s also the only freedom we have.
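To make the counting concrete, here is a small brute-force checker, only usable for tiny N and M. It is a sketch for sanity-checking faster solutions; the name `brute_force` and the trial-division helper `omega` are assumptions made for illustration.

```python
from itertools import permutations


def omega(v):
    """Number of distinct prime factors of v, by trial division."""
    count, p = 0, 2
    while p * p <= v:
        if v % p == 0:
            count += 1
            while v % p == 0:
                v //= p
        p += 1
    return count + (1 if v > 1 else 0)


def brute_force(n, m):
    """Count arrays of n distinct values from [1, m] whose prime counts never decrease."""
    pr = {v: omega(v) for v in range(1, m + 1)}
    total = 0
    for a in permutations(range(1, m + 1), n):
        if all(pr[a[i]] <= pr[a[i + 1]] for i in range(n - 1)):
            total += 1
    return total


# N = 2, M = 3: the valid arrays are (1,2), (1,3), (2,3), (3,2).
print(brute_force(2, 3))  # 4
```

Note that (2,3) and (3,2) are both counted: 2 and 3 have the same prime count, so they may appear in either order.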

This leads to a somewhat natural dynamic programming solution.
Let \text{dp}(i, j) denote the number of valid arrays of length j such that we’ve only used elements with prime counts \leq i.
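Then, with the base case \text{dp}(-1, 0) = 1 (the empty array) and treating \text{dp}(i-1, j-k) as 0 whenever k \gt j, we have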

\text{dp}(i, j) = \sum_{k=0}^{\text{freq}[i]} \text{dp}(i-1, j-k) \cdot \binom{\text{freq}[i]}{k}\cdot k!

That is,

  • Fix k, the number of elements with prime count = i that our array contains.
  • Fix which k integers with this prime count we’re choosing, and their order: \binom{\text{freq}[i]}{k}\cdot k! possibilities.
  • Finally, the remaining array has j-k elements with prime count \lt i, which there are \text{dp}(i-1, j-k) ways to form.

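Now, since M \leq 10^5, any integer has at most 6 distinct prime factors: the product of the 6 smallest primes is 2\cdot 3\cdot 5\cdot 7\cdot 11\cdot 13 = 30030 \leq 10^5, while multiplying by the next prime 17 gives 510510 \gt 10^5. So, we’re only concerned with 0 \leq i \leq 6.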
This makes our dp have 7N states, with an \mathcal{O}(N) transition from each one, for \mathcal{O}(N^2) overall.
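The quadratic dp can be sketched as follows. This is an illustrative version, not one of the reference solutions: the function name `count_arrays_quadratic` is made up here, and `freq` is assumed to be the already-computed list of class sizes. It uses the fact that \binom{f}{k}\cdot k! = f(f-1)\cdots(f-k+1).

```python
def count_arrays_quadratic(n, freq, mod=998244353):
    """O(len(freq) * n^2) dp; freq[i] = how many values have prime count i."""
    dp = [1] + [0] * n  # dp(-1, j): only the empty array exists
    for f in freq:
        # ways[k] = C(f, k) * k! = f * (f - 1) * ... * (f - k + 1)
        ways = [1] * (min(f, n) + 1)
        for k in range(1, len(ways)):
            ways[k] = ways[k - 1] * (f - k + 1) % mod
        new_dp = [0] * (n + 1)
        for j in range(n + 1):
            total = 0
            # k = number of elements taken from the current class
            for k in range(min(j, len(ways) - 1) + 1):
                total += dp[j - k] * ways[k]
            new_dp[j] = total % mod
        dp = new_dp
    return dp[n]


# M = 6 gives freq = [1, 4, 1]: {1}, {2,3,4,5}, {6}.
print(count_arrays_quadratic(2, [1, 4, 1]))  # 21
```

The value 21 can be checked by hand: of the 30 ordered pairs of distinct values from [1, 6], exactly 9 place a larger prime count before a smaller one.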

To optimize this, observe that the transition structure is somewhat special: it’s the sum of a product of terms that are based on j-k and k (meaning they sum up to j).
This should remind you of polynomial multiplication!

That is, if we had the polynomials

p(x) = \sum_{k=0}^N \text{dp}(i-1, k) x^k \\ q(x) = \sum_{k=0}^N\left(\binom{\text{freq}[i]}{k}\cdot k! \right)x^k

Then the j-th coefficient of their product would be exactly \text{dp}(i, j).
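A short sketch of this correspondence, using a naive O(N^2) product in place of NTT so that it stays self-contained (the helper names `poly_mul_naive`, `class_poly` and `count_arrays_by_products` are hypothetical): each dp layer is obtained by multiplying the previous layer by q(x) and keeping the coefficients up to x^N.

```python
MOD = 998244353


def poly_mul_naive(p, q):
    """Schoolbook polynomial product modulo MOD."""
    out = [0] * (len(p) + len(q) - 1)
    for a, pa in enumerate(p):
        for b, qb in enumerate(q):
            out[a + b] = (out[a + b] + pa * qb) % MOD
    return out


def class_poly(f, n):
    """Coefficients C(f, k) * k! for k = 0..n; they become 0 once k > f."""
    coeffs = [1] + [0] * n
    for k in range(1, n + 1):
        coeffs[k] = coeffs[k - 1] * (f - k + 1) % MOD
    return coeffs


def count_arrays_by_products(n, freq):
    dp = [1] + [0] * n  # polynomial for the empty prefix of classes
    for f in freq:
        dp = poly_mul_naive(dp, class_poly(f, n))[: n + 1]  # drop terms above x^n
    return dp[n]


# M = 6 gives freq = [1, 4, 1]; the layers are [1,1,0], [1,5,16], [1,6,21].
print(count_arrays_by_products(2, [1, 4, 1]))  # 21
```

Swapping `poly_mul_naive` for an NTT-based product leaves the results unchanged and only affects the running time.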

The product of two N-degree polynomials can be computed in \mathcal{O}(N\log N) with the help of the Fast Fourier Transform.
In our case, we want the coefficients modulo 998244353, which is achieved by a slight modification of FFT called the Number Theoretic Transform.
You may read about them here.

So, simply replacing the dp transition computation with NTT speeds up our algorithm to \mathcal{O}(7N\log N), which is fast enough to get AC.

TIME COMPLEXITY:

\mathcal{O}(M\log\log M + N\log N) per testcase: the sieve, plus at most 7 polynomial multiplications of size N.

CODE:

Author's code (C++)
#include<bits/stdc++.h>
using namespace std;

// -------------------- Input Checker Start --------------------

// This function reads a long long, character by character, and returns it as a whole long long. It makes sure that it lies in the range [l, r], and the character after the long long is endd. l and r should be in [-1e18, 1e18].
long long readInt(long long l, long long r, char endd)
{
    long long x = 0;
    int cnt = 0, fi = -1;
    bool is_neg = false;
    while(true)
    {
        char g = getchar();
        if(g == '-')
        {
            if(!(fi == -1))
                cerr << "- in between integer\n";
            assert(fi == -1);
            is_neg = true; // It's a negative integer
            continue;
        }
        if('0' <= g && g <= '9')
        {
            x *= 10;
            x += g - '0';
            if(cnt == 0)
                fi = g - '0'; // fi is the first digit
            cnt++;
            
            // There shouldn't be leading zeroes. eg. "02" is not valid and assert will fail here.
            if(!(fi != 0 || cnt == 1))
                cerr << "Leading zeroes found\n";
            assert(fi != 0 || cnt == 1); 
            
            // "-0" is invalid
            if(!(fi != 0 || is_neg == false))
                cerr << "-0 found\n";
            assert(fi != 0 || is_neg == false); 
            
            // The maximum number of digits should be 19, and if it is 19 digits long, then the first digit should be a '1'.
            if(!(!(cnt > 19 || (cnt == 19 && fi > 1))))
                cerr << "Value greater than 1e18 found\n";
            assert(!(cnt > 19 || (cnt == 19 && fi > 1))); 
        }
        else if(g == endd)
        {
            if(is_neg)
                x = -x;
            if(!(l <= x && x <= r))
            {
                // We've reached the end, but the long long isn't in the right range.
                cerr << "Constraint violated: Lower Bound = " << l << " Upper Bound = " << r << " Violating Value = " << x << '\n'; 
                assert(false); 
            }
            return x;
        }
        else if((g == ' ') && (endd == '\n'))
        {
            cerr << "Extra space found. It should instead have been a new line.\n";
            assert(false);
        }
        else if((g == '\n') && (endd == ' '))
        {
            cerr << "A new line found where it should have been a space.\n";
            assert(false);
        }
        else
        {
            cerr << "Something weird has happened.\n";
            assert(false);
        }
    }
}

string readString(int l, int r, char endd)
{
    string ret = "";
    int cnt = 0;
    while(true)
    {
        char g = getchar();
        
        assert(g != -1);
        if(g == endd)
            break;
        cnt++;
        ret += g;
    }
    if(!(l <= cnt && cnt <= r))
        cerr << "String length not within constraints\n";
    assert(l <= cnt && cnt <= r);
    return ret;
}

long long readIntSp(long long l, long long r) { return readInt(l, r, ' '); }
long long readIntLn(long long l, long long r) { return readInt(l, r, '\n'); }
string readStringLn(int l, int r) { return readString(l, r, '\n'); }
string readStringSp(int l, int r) { return readString(l, r, ' '); }
void readEOF() 
{ 
    int g = getchar(); // int, so that EOF is distinguishable from every valid char
    if(g != EOF)
    {
        if(g == ' ')
            cerr << "Extra space found where the file should have ended\n";
        else if(g == '\n')
            cerr << "Extra newline found where the file should have ended\n";
        else
            cerr << "File didn't end where expected\n";
    }
    assert(g == EOF); 
}

vector<int> readVectorInt(int n, long long l, long long r)
{
    vector<int> a(n);
    for(int i = 0; i < n - 1; i++)
        a[i] = readIntSp(l, r);
    a[n - 1] = readIntLn(l, r);
    return a;
}

bool checkStringContents(string &s, char l, char r) {
    for(char x: s) {
        if (x < l || x > r) {
            cerr << "String is not valid\n";
            return false;
        }
    }
    return true;
}

bool isStringBinary(string &s) {
    return checkStringContents(s, '0', '1');
}

bool isStringLowerCase(string &s) {
    return checkStringContents(s, 'a', 'z');
}
bool isStringUpperCase(string &s) {
    return checkStringContents(s, 'A', 'Z');
}

bool isArrayDistinct(vector<int> a) {
    sort(a.begin(), a.end());
    for(int i = 1 ; i < a.size() ; ++i) {
        if (a[i] == a[i-1])
        return false;
    }
    return 1;
}

bool isPermutation(vector<int> &a) {
    int n = a.size();
    vector<int> done(n);
    for(int x: a) {
      if (x <= 0 || x > n || done[x-1]) {
        cerr << "Not a valid permutation\n";
        return false;
      }
      done[x-1]=1;
    }
    return true;
}

// -------------------- Input Checker End --------------------

typedef long long int ll;
const int MOD=998244353;
ll modpower(ll n,ll a,ll p){ ll res=1; while(a){ if(a%2) res= ((res*n)%p) ,a--; else n=((n*n)%p),a/=2;} return res;}


#define sz(x) int(x.size())
typedef pair<int, int> ii;
typedef vector<int> vi;
typedef long double ld;
const ld PI = acos((ld)-1);
namespace FFT {
	struct com {
		ld x, y;
 
		com(ld _x = 0, ld _y = 0) : x(_x), y(_y) {}
 
		inline com operator + (const com &c) const {
			return com(x + c.x, y + c.y);
		}
		inline com operator - (const com &c) const {
			return com(x - c.x, y - c.y);
		}
		inline com operator * (const com &c) const {
			return com(x * c.x - y * c.y, x * c.y + y * c.x);
		}
		inline com conj() const {
			return com(x, -y);
		}
	};
 
	const static int maxk = 19, maxn = (1 << maxk) + 1;
	com ws[maxn];
	int dp[maxn];
	com rs[maxn];
	int n, k;
	int lastk = -1;
 
	void fft(com *a, bool torev = 0) {
		if (lastk != k) {
			lastk = k;
			dp[0] = 0;
 
			for (int i = 1, g = -1; i < n; ++i) {
				if (!(i & (i - 1))) {
					++g;
				}
				dp[i] = dp[i ^ (1 << g)] ^ (1 << (k - 1 - g));
			}
 
			ws[1] = com(1, 0);
			for (int two = 0; two < k - 1; ++two) {
				ld alf = PI / n * (1 << (k - 1 - two));
				com cur = com(cos(alf), sin(alf));
 
				int p2 = (1 << two), p3 = p2 * 2;
				for (int j = p2; j < p3; ++j) {
					ws[j * 2 + 1] = (ws[j * 2] = ws[j]) * cur;
				}
			}
		}
		for (int i = 0; i < n; ++i) {
			if (i < dp[i]) {
				swap(a[i], a[dp[i]]);
			}
		}
		if (torev) {
			for (int i = 0; i < n; ++i) {
				a[i].y = -a[i].y;
			}
		}
		for (int len = 1; len < n; len <<= 1) {
			for (int i = 0; i < n; i += len) {
				int wit = len;
				for (int it = 0, j = i + len; it < len; ++it, ++i, ++j) {
					com tmp = a[j] * ws[wit++];
					a[j] = a[i] - tmp;
					a[i] = a[i] + tmp;
				}
			}
		}
	}
 
	com a[maxn];
	vector<ll> mult(vector<ll> &_a, vector<ll> &_b) {
		int na = sz(_a), nb = sz(_b);
		
		for (k = 0, n = 1; n < na + nb - 1; n <<= 1, ++k);
		//assert(n < maxn);
		for (int i = 0; i < n; ++i) {
			a[i] = com(i < na ? _a[i] : 0, i < nb ? _b[i] : 0);
		}
		fft(a);
		a[n] = a[0];
		for (int i = 0; i <= n - i; ++i) {
			a[i] = (a[i] * a[i] - (a[n - i] * a[n - i]).conj()) * com(0, (ld)-1 / n / 4);
			a[n - i] = a[i].conj();
		}
		fft(a, 1);
		int res = 0;
 
		vector<ll> ans(n);
		for (int i = 0; i < n; ++i) {
			ll val = (ll) round(a[i].x);
			ans[i] = val; // for multiplying polynomials
		}
		return ans;
	}
};

const int N=1e5+1;
bool isp[N];
int divsr[N];
int diffdiv[N];
vector<int> pr;
void build_primes() // linear sieve: divsr[i] = smallest prime factor of i, in O(N)
{
    divsr[1]=1;
    for(int i=2;i<N;i++)
    {
        if(divsr[i]==0)
        {
            divsr[i]=i;
            isp[i]=1;
            pr.push_back(i);
        }
        for (int j=0;j<(int)pr.size() && pr[j]<=divsr[i] && i*pr[j]<N;j++)
        divsr[i*pr[j]]=pr[j];
    }
    diffdiv[0]=0;
    diffdiv[1]=0;
    for(int i=2;i<N;i++){
        set<int> st;
        int tmp=i;
        while(tmp!=1){
            st.insert(divsr[tmp]);
            tmp/=divsr[tmp];
        }
        diffdiv[i]=st.size();
    }
}

ll fact[N];
ll ifact[N];
void build_fact()
{
    fact[0]=1;
    for(ll i=1;i<N;i++)   fact[i]=(fact[i-1]*i)%MOD;
    
    ifact[N-1]=modpower(fact[N-1],MOD-2,MOD);
    for(ll i=N-2;i>=0;i--) ifact[i]=(ifact[i+1]*(i+1))%MOD;
}
ll npr(ll n,ll r)
{
    if(r>n) return 0;
    if(n<0 || r<0) return 0;
    
    ll res=fact[n];
    res=(res*ifact[n-r])%MOD;
    
    return res;
}

void solve(){
    int n,m;
    n=readIntSp(1,1e5);
    m=readIntLn(n,1e5);
    int cnt[8]={0};
    for(int i=1;i<=m;i++){
        cnt[diffdiv[i]]++;
    }
    vector<ll> ans(n+1,0);
    ans[0]=1;
    for(int i=0;i<8;i++){
        vector<ll> mul(n+1,0);
        for(int j=0;j<=n;j++){
            mul[j]=npr(cnt[i],j);
        }
        ans=FFT::mult(mul,ans);
        ans.resize(n+1);
        for(int i=0;i<=n;i++){
            ans[i]%=MOD;
        }
    }
    cout<<ans[n]<<endl;
}

int main(){
    int tc;
    tc=readIntLn(1,1000);
    build_primes();
    build_fact();
    while(tc--){
        solve();
    }
    readEOF();
    return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            res += buffer[pos];
            assert(!isspace(buffer[pos]));
            pos++;
        }
        return res;
    }

    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

template <long long mod>
struct modular {
    long long value;
    modular(long long x = 0) {
        value = x % mod;
        if (value < 0) value += mod;
    }
    modular& operator+=(const modular& other) {
        if ((value += other.value) >= mod) value -= mod;
        return *this;
    }
    modular& operator-=(const modular& other) {
        if ((value -= other.value) < 0) value += mod;
        return *this;
    }
    modular& operator*=(const modular& other) {
        value = value * other.value % mod;
        return *this;
    }
    modular& operator/=(const modular& other) {
        long long a = 0, b = 1, c = other.value, m = mod;
        while (c != 0) {
            long long t = m / c;
            m -= t * c;
            swap(c, m);
            a -= t * b;
            swap(a, b);
        }
        a %= mod;
        if (a < 0) a += mod;
        value = value * a % mod;
        return *this;
    }
    friend modular operator+(const modular& lhs, const modular& rhs) { return modular(lhs) += rhs; }
    friend modular operator-(const modular& lhs, const modular& rhs) { return modular(lhs) -= rhs; }
    friend modular operator*(const modular& lhs, const modular& rhs) { return modular(lhs) *= rhs; }
    friend modular operator/(const modular& lhs, const modular& rhs) { return modular(lhs) /= rhs; }
    modular& operator++() { return *this += 1; }
    modular& operator--() { return *this -= 1; }
    modular operator++(int) {
        modular res(*this);
        *this += 1;
        return res;
    }
    modular operator--(int) {
        modular res(*this);
        *this -= 1;
        return res;
    }
    modular operator-() const { return modular(-value); }
    bool operator==(const modular& rhs) const { return value == rhs.value; }
    bool operator!=(const modular& rhs) const { return value != rhs.value; }
    bool operator<(const modular& rhs) const { return value < rhs.value; }
};
template <long long mod>
string to_string(const modular<mod>& x) {
    return to_string(x.value);
}
template <long long mod>
ostream& operator<<(ostream& stream, const modular<mod>& x) {
    return stream << x.value;
}
template <long long mod>
istream& operator>>(istream& stream, modular<mod>& x) {
    stream >> x.value;
    x.value %= mod;
    if (x.value < 0) x.value += mod;
    return stream;
}

constexpr long long mod = 998244353;
using mint = modular<mod>;

mint power(mint a, long long n) {
    mint res = 1;
    while (n > 0) {
        if (n & 1) {
            res *= a;
        }
        a *= a;
        n >>= 1;
    }
    return res;
}

vector<mint> fact(1, 1);
vector<mint> finv(1, 1);

mint C(int n, int k) {
    if (n < k || k < 0) {
        return mint(0);
    }
    while ((int) fact.size() < n + 1) {
        fact.emplace_back(fact.back() * (int) fact.size());
        finv.emplace_back(mint(1) / fact.back());
    }
    return fact[n] * finv[k] * finv[n - k];
}

namespace NTT {
mint root;
int base;
int max_base;
vector<mint> roots;
vector<int> rev;

void ensure_base(int nbase) {
    if (roots.empty()) {
        auto tmp = mod - 1;
        max_base = 0;
        while (tmp % 2 == 0) {
            tmp /= 2;
            max_base++;
        }
        root = 2;
        while (power(root, (mod - 1) >> 1) == 1) {
            root += 1;
        }
        root = power(root, (mod - 1) >> max_base);
        base = 1;
        rev = {0, 1};
        roots = {0, 1};
    }
    if (nbase <= base) {
        return;
    }
    rev.resize(1 << nbase);
    for (int i = 0; i < (1 << nbase); i++) {
        rev[i] = (rev[i >> 1] >> 1) + ((i & 1) << (nbase - 1));
    }
    roots.resize(1 << nbase);
    while (base < nbase) {
        mint z = power(root, 1 << (max_base - 1 - base));
        for (int i = 1 << (base - 1); i < (1 << base); i++) {
            roots[i << 1] = roots[i];
            roots[(i << 1) + 1] = roots[i] * z;
        }
        base++;
    }
}

void ntt(vector<mint>& a) {
    int n = (int) a.size();
    int zeros = __builtin_ctz(n);
    ensure_base(zeros);
    int shift = base - zeros;
    for (int i = 0; i < n; i++) {
        if (i < (rev[i] >> shift)) {
            swap(a[i], a[rev[i] >> shift]);
        }
    }
    for (int k = 1; k < n; k <<= 1) {
        for (int i = 0; i < n; i += 2 * k) {
            for (int j = 0; j < k; j++) {
                mint x = a[i + j];
                mint y = a[i + j + k] * roots[j + k];
                a[i + j] = x + y;
                a[i + j + k] = x - y;
            }
        }
    }
}

vector<mint> multiply(vector<mint> a, vector<mint> b) {
    int need = (int) a.size() + (int) b.size() - 1;
    int nbase = 0;
    while ((1 << nbase) < need) {
        nbase++;
    }
    ensure_base(nbase);
    int sz = 1 << nbase;
    a.resize(sz);
    b.resize(sz);
    ntt(a);
    ntt(b);
    mint inv = mint(1) / mint(sz);
    for (int i = 0; i < sz; i++) {
        a[i] *= b[i] * inv;
    }
    reverse(a.begin() + 1, a.end());
    ntt(a);
    a.resize(need);
    return a;
}
}  // namespace NTT

vector<mint> operator*(const vector<mint>& a, const vector<mint>& b) {
    if (a.empty() || b.empty()) {
        return {};
    } else if (min(a.size(), b.size()) < 150) {
        vector<mint> c(a.size() + b.size() - 1);
        for (int i = 0; i < (int) a.size(); i++) {
            for (int j = 0; j < (int) b.size(); j++) {
                c[i + j] += a[i] * b[j];
            }
        }
        return c;
    } else {
        return NTT::multiply(a, b);
    }
}

vector<mint>& operator*=(vector<mint>& a, const vector<mint>& b) {
    return a = a * b;
}

int main() {
    input_checker in;
    int tt = in.readInt(1, 1e3);
    in.readEoln();
    int sm = 0;
    while (tt--) {
        int n = in.readInt(1, 1e5);
        in.readSpace();
        int m = in.readInt(n, 1e5);
        in.readEoln();
        sm += m;
        vector<int> p(m + 1);
        for (int i = 2; i <= m; i++) {
            if (p[i] != 0) {
                continue;
            }
            for (int j = i; j <= m; j += i) {
                p[j]++;
            }
        }
        map<int, int> cnt;
        for (int i = 1; i <= m; i++) {
            cnt[p[i]]++;
        }
        vector<vector<mint>> a;
        for (auto [_, k] : cnt) {
            vector<mint> b(k + 1);
            for (int i = 0; i <= k; i++) {
                mint c = C(k, i);
                c *= fact[i];
                b[i] = c;
            }
            a.emplace_back(b);
        }
        while (a.size() > 1) {
            vector<vector<mint>> new_a;
            int k = (int) a.size();
            for (int i = 0; i < k; i += 2) {
                if (i + 1 < k) {
                    auto c = a[i] * a[i + 1];
                    new_a.emplace_back(c);
                } else {
                    new_a.emplace_back(a[i]);
                }
            }
            swap(a, new_a);
        }
        cout << a[0][n] << '\n';
    }
    assert(sm <= 1e5);
    in.readEof();
    return 0;
}
Editorialist's code (Python)
from collections import Counter
# NTT implementation based on https://codeforces.com/blog/entry/117947

# NTT prime
MOD = (119 << 23) + 1

non_quad_res = 2
while pow(non_quad_res, MOD//2, MOD) != MOD - 1:
    non_quad_res += 1
rt = [1]

def ntt(P):
    n = len(P)
    P = list(P)
    assert n and (n - 1) & n == 0
    
    while 2 * len(rt) < n:
        # 4*len(rt)-th root of unity
        root = pow(non_quad_res, MOD // (4*len(rt)), MOD)
        rt.extend([r * root % MOD for r in rt])

    k = n
    while k > 1:
        for i in range(n//k):
            r = rt[i]
            for j1 in range(i*k, i*k + k//2):
                j2 = j1 + k//2
                z = r * P[j2]
                P[j2] = (P[j1] - z) % MOD
                P[j1] = (P[j1] + z) % MOD
        k //= 2
    
    rev = [0] * n
    for i in range(1, n):
        rev[i] = rev[i // 2] // 2 + (i & 1) * n // 2
    return [P[r] for r in rev]

def intt(P):
    n = len(P)
    ninv = pow(n, MOD - 2, MOD)
    return ntt([P[-i] * ninv % MOD for i in range(n)])

def ntt_conv(P, Q):
    m = len(P) + len(Q) - 1
    n = 1 << m.bit_length()

    P = P + [0] * (n - len(P))
    Q = Q + [0] * (n - len(Q))
    P, Q = ntt(P), ntt(Q)

    return intt([p * q % MOD for p,q in zip(P, Q)])[:m]

for _ in range(int(input())):
    n, m = map(int, input().split())
    prms = [0]*(m+1)
    for i in range(2, m+1):
        if prms[i] == 0:
            for j in range(i, m+1, i): prms[j] += 1
    freq = Counter(prms[1:])
    ans = [1]
    for x in freq:
        have = freq[x]
        poly = [1]*(have+1)
        for i in range(1, have+1): poly[i] = poly[i-1] * (have+1-i) % MOD
        ans = ntt_conv(ans, poly)
    print(ans[n])