Cnt the Array:
PROBLEM LINK:
Division 1
Division 2
Division 3
Author: Nitin Gupta
Testers: Lavish Gupta, Takuki Kurokawa
DIFFICULTY:
EASY-MEDIUM
PREREQUISITES:
Inclusion-Exclusion
PROBLEM:
You are given two integers N and M, along with an array B of length N containing distinct integers. Find the number of arrays A of length N that satisfy all three of the following restrictions (the answer is printed modulo 10^9+7):
1. 1 \leq A[i] \leq M for every i.
2. A[i] \neq A[j] for all i \lt j.
3. A[i] \neq B[i] for every i.
QUICK EXPLANATION:
There are two main restrictions (the 2nd and the 3rd): we count arrays that satisfy the 2nd restriction directly and apply inclusion-exclusion over the 3rd.
EXPLANATION:
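According to the problem, the two restrictions that need care are the 2nd and the 3rd.
We satisfy the 2nd restriction by construction and apply inclusion-exclusion over the 3rd. This works because the elements of B are distinct: forcing A[i] = B[i] at any set of x positions uses x distinct values, so the remaining n-x positions can be filled from the other m-x values. If B contained duplicates, forcing two positions to the same value would contradict the 2nd restriction, and the counts below would be wrong.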
First, count the arrays that satisfy the 1st and 2nd restrictions: choose n of the m values and arrange them, which gives \binom{m}{n}\,n!.
Next, count the arrays that satisfy the 1st and 2nd restrictions while violating the 3rd at one chosen position. Since A[i] \leq M, position i can only be violated if B[i] \leq M; let P be the number of such positions. Choose the violated position in \binom{P}{1} ways and set A[i] = B[i] there. Then fill the remaining n-1 positions with distinct values from the remaining m-1. This gives \binom{P}{1}\binom{m-1}{n-1}(n-1)! arrays.
More generally, to violate the 3rd restriction at a chosen set of x positions:
- choose the x positions in \binom{P}{x} ways;
- fix A[i] = B[i] at those positions;
- choose n-x values from the remaining m-x and permute them.

This gives \binom{P}{x}\binom{m-x}{n-x}(n-x)! arrays.
By the inclusion-exclusion principle, the total number of valid arrays is:
\binom{m}{n}\,n! - \binom{P}{1}\binom{m-1}{n-1}(n-1)! + \binom{P}{2}\binom{m-2}{n-2}(n-2)! - \cdots + (-1)^P\binom{P}{P}\binom{m-P}{n-P}(n-P)! = \sum_{x=0}^{P} (-1)^x \binom{P}{x}\binom{m-x}{n-x}(n-x)!.
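If M < N, no array can have N distinct values from [1, M], so the answer is 0. Otherwise, with factorials and inverse factorials precomputed, each term takes O(1), so the sum is computed in O(N).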
SOLUTIONS:
Setter's Solution
#include <bits/stdc++.h>
#define int long long
#define endl "\n"
#define mod 1000000007
#define nitin ios_base::sync_with_stdio(false); cin.tie(nullptr)
using namespace std;
// Size of the factorial tables. It must exceed every n passed to ncr(); the
// largest is m, which solve() asserts to be at most 300000.
constexpr int N = 400001;

// fact[i] = i! and fact_inv[i] = (i!)^{-1}, both modulo `mod`; filled by pre().
int fact[N];
int fact_inv[N];
// Returns x^y mod p by square-and-multiply over the bits of y.
// Returns 1 when y <= 0, because the loop body never runs.
int power(int x, int y, int p) {
    int result = 1;
    int base = x % p;
    for (int e = y; e > 0; e >>= 1) {
        // The lowest remaining bit of the exponent decides whether this power of base is used.
        if (e & 1) {
            result = result * base % p;
        }
        base = base * base % p;
    }
    return result;
}
// Modular inverse of a using Fermat's little theorem: a^(m-2) == a^(-1) (mod m).
// Only valid when m is prime and a is not a multiple of m.
int modi(int a, int m) {
    const int fermat_exponent = m - 2;
    return power(a, fermat_exponent, m);
}
// Precomputes fact[i] = i! and fact_inv[i] = (i!)^{-1} (mod `mod`) for 0 <= i < N.
// Only the largest factorial is inverted with modi(). The rest follow from
// (k-1)!^{-1} = k!^{-1} * k, walking the table downwards.
void pre() {
    fact[0] = 1;
    for (int k = 1; k < N; ++k)
        fact[k] = fact[k - 1] * k % mod;

    fact_inv[N - 1] = modi(fact[N - 1], mod);
    for (int k = N - 1; k > 0; --k)
        fact_inv[k - 1] = fact_inv[k] * k % mod;
}
// Binomial coefficient C(n, r) modulo `mod`, read from the tables filled by pre().
// Returns 0 when r lies outside [0, n] (the usual combinatorial convention).
// The caller must keep n < N, the table size.
int ncr(int n, int r) {
    // Also reject r < 0; otherwise fact_inv[r] would index before the array start.
    if (r < 0 || r > n)
        return 0;
    return fact[n] * (fact_inv[r] * fact_inv[n - r] % mod) % mod;
}
// Reads one test case (N, M, then the N values of B) and prints the number of
// arrays A of length N that satisfy:
//   - 1 <= A[i] <= M;
//   - all entries are pairwise distinct;
//   - A[i] != B[i] for every i.
// The count is printed modulo `mod`.
//
// Inclusion-exclusion over the positions where A[i] == B[i] is forced:
//   ans = sum_{j=0}^{p} (-1)^j * C(p, j) * C(m - j, n - j) * (n - j)!
// where p counts positions with B[i] <= m (only those can ever collide).
void solve() {
    int n, m;
    cin >> n >> m;
    assert(n >= 1 && n <= 200000);
    assert(m >= 1 && m <= 300000);

    // Only the number of collidable positions matters, so B is not stored.
    int p = 0;
    for (int i = 0; i < n; i++) {
        int val;
        cin >> val;
        assert(val >= 1 && val <= 1000000000);
        if (val >= 1 && val <= m) {
            p++;
        }
    }

    // Fewer values than positions: no array with pairwise-distinct entries exists.
    if (m < n) {
        cout << 0 << endl;
        return;
    }

    int ans = 0;
    for (int j = 0; j <= p; j++) {
        // Every factor is already reduced below mod, so each product fits in the
        // 64-bit `int` (this file #defines int as long long).
        int term = ncr(p, j) * ncr(m - j, n - j) % mod;
        term = term * fact[n - j] % mod;
        if (j & 1) {
            ans = (ans - term + mod) % mod;
        } else {
            ans = (ans + term) % mod;
        }
    }
    cout << ans << endl;
}
int32_t main() {
nitin;
pre();
solve();
}
Tester's Solution
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif
// Integer modulo the compile-time constant `mod`, stored as `value`.
// The constructor and the +=, -=, *= operators keep `value` in [0, mod),
// assuming mod > 0 and that operands are already reduced.
// Division multiplies by an inverse found with the extended Euclidean
// algorithm. That is only correct when gcd(other.value, mod) == 1.
template <long long mod>
struct modular {
// The stored residue; public, so code that writes it directly must keep it reduced.
long long value;
// Reduces any (possibly negative) x into [0, mod).
// Not explicit: this lets plain integers mix with modular values (e.g. `*this += 1` below).
modular(long long x = 0) {
value = x % mod;
// C++ `%` keeps the sign of x, so negatives are shifted up by one modulus.
if (value < 0) value += mod;
}
// Both operands are below mod, so the sum is below 2*mod and one subtraction suffices.
modular& operator+=(const modular& other) {
if ((value += other.value) >= mod) value -= mod;
return *this;
}
// The difference is above -mod, so one addition of mod restores the range.
modular& operator-=(const modular& other) {
if ((value -= other.value) < 0) value += mod;
return *this;
}
// The product of two reduced values is below mod^2; this fits in long long
// only while mod is below roughly 3.03e9.
modular& operator*=(const modular& other) {
value = value * other.value % mod;
return *this;
}
// Divides by `other` by multiplying with its modular inverse.
// Extended Euclid on (mod, other.value) keeps the invariants
//   m == a * other.value  and  c == b * other.value  (mod `mod`).
// When c reaches 0, m is the gcd. If that gcd is 1, `a` is the inverse.
// NOTE(review): no check for a non-invertible divisor. If other.value == 0,
// the loop never runs and the result is silently 0. If gcd > 1, the result is
// not a true quotient.
modular& operator/=(const modular& other) {
long long a = 0, b = 1, c = other.value, m = mod;
while (c != 0) {
long long t = m / c;
m -= t * c;
swap(c, m);
a -= t * b;
swap(a, b);
}
// The coefficient can be negative; bring it into [0, mod) before multiplying.
a %= mod;
if (a < 0) a += mod;
value = value * a % mod;
return *this;
}
// Binary operators work on a copy of lhs and reuse the compound assignments above.
friend modular operator+(const modular& lhs, const modular& rhs) { return modular(lhs) += rhs; }
friend modular operator-(const modular& lhs, const modular& rhs) { return modular(lhs) -= rhs; }
friend modular operator*(const modular& lhs, const modular& rhs) { return modular(lhs) *= rhs; }
friend modular operator/(const modular& lhs, const modular& rhs) { return modular(lhs) /= rhs; }
// Prefix increment/decrement: modify in place and return *this.
modular& operator++() { return *this += 1; }
modular& operator--() { return *this -= 1; }
// Postfix forms return the value held before the update.
modular operator++(int) {
modular res(*this);
*this += 1;
return res;
}
modular operator--(int) {
modular res(*this);
*this -= 1;
return res;
}
// Negation goes through the constructor, which maps -value back into [0, mod).
modular operator-() const { return modular(-value); }
// Comparisons act on the stored residue. operator< only gives a consistent
// ordering (presumably for ordered containers), not a numeric "less than".
bool operator==(const modular& rhs) const { return value == rhs.value; }
bool operator!=(const modular& rhs) const { return value != rhs.value; }
bool operator<(const modular& rhs) const { return value < rhs.value; }
};
// Decimal text of the stored residue of a modular value.
template <long long mod>
string to_string(const modular<mod>& x) {
    const long long residue = x.value;
    return std::to_string(residue);
}
// Streams the stored residue of a modular value.
template <long long mod>
ostream& operator<<(ostream& stream, const modular<mod>& x) {
    stream << x.value;
    return stream;
}
// Reads a raw integer into a modular value and normalises it into [0, mod).
template <long long mod>
istream& operator>>(istream& stream, modular<mod>& x) {
    stream >> x.value;
    // The converting constructor applies `% mod` and lifts negative remainders.
    x = modular<mod>(x.value);
    return stream;
}
// Modulus for every answer: 10^9 + 7, written as an exact integer literal.
constexpr long long mod = 1000000007LL;
using mint = modular<mod>;
// Computes a^n with binary exponentiation; returns 1 when n <= 0.
mint power(mint a, long long n) {
    mint result = 1;
    // After each step the exponent loses its lowest bit and the base is squared.
    for (; n > 0; n >>= 1, a *= a) {
        if (n & 1) {
            result *= a;
        }
    }
    return result;
}
// Lazily grown tables used by C() and main(): fact[i] = i!, finv[i] = 1 / i!.
// Both start with the single entry for 0! = 1.
vector<mint> fact{mint(1)};
vector<mint> finv{mint(1)};
// Binomial coefficient C(n, k); returns 0 when k is outside [0, n].
// Extends fact/finv on demand so they cover every index up to n. Each new
// inverse factorial costs one modular division.
mint C(int n, int k) {
    if (k < 0 || k > n) {
        return mint(0);
    }
    while ((int) fact.size() <= n) {
        const int next = (int) fact.size();
        fact.emplace_back(fact.back() * next);
        finv.emplace_back(mint(1) / fact.back());
    }
    return fact[n] * finv[k] * finv[n - k];
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
int n, m;
cin >> n >> m;
vector<int> b(n);
for (int i = 0; i < n; i++) {
cin >> b[i];
}
if (m < n) {
cout << 0 << '\n';
return 0;
}
sort(b.begin(), b.end());
while (!b.empty() && b.back() > m) {
b.pop_back();
}
int k = (int) b.size();
mint ans = 0;
C(n + m, 0);
for (int i = 0; i <= k; i++) {
mint t = C(k, i) * C(m - i, n - i) * fact[n - i];
if (i % 2 == 0) {
ans += t;
} else {
ans -= t;
}
}
cout << ans << '\n';
return 0;
}