CNTARRAY - Editorial

Cnt the Array:

PROBLEM LINK:

Practice

Division 1
Division 2
Division 3

Author: Nitin Gupta

Testers: Lavish Gupta, Takuki Kurokawa

DIFFICULTY:

EASY-MEDIUM

PREREQUISITES:

Inclusion-Exclusion

PROBLEM:

You are given two integers N and M, along with an array B of length N whose elements are pairwise distinct. Find the number of arrays A of length N that satisfy the following three restrictions (the solutions below print this count modulo 10^9+7):

1. 1 \leq A[i] \leq M.

2. A[i] \neq A[j] for all i \lt j (all elements of A are distinct).

3. A[i] \neq B[i] for all i.

QUICK EXPLANATION:

There are two non-trivial restrictions (the 2nd and the 3rd). We count the arrays satisfying the 2nd restriction directly and handle the 3rd with inclusion-exclusion.

EXPLANATION:

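Restriction 1 only bounds the values, so the real work lies in restrictions 2 and 3. We always enforce the 2nd restriction and apply inclusion-exclusion to the 3rd. This works because B contains distinct elements. Forcing A[i] = B[i] on any set of positions assigns distinct values to those positions, so it never conflicts with restriction 2. As a result, the number of ways to complete the array depends only on how many positions were forced, not on which ones. (If B had repeated values, forcing two such positions would already break restriction 2, and the counts below would be wrong.)
First, count the arrays that satisfy restrictions 1 and 2: choose n distinct values out of m and arrange them, which gives \binom{m}{n} \cdot n! arrays.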

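Next, count the arrays that satisfy restrictions 1 and 2 and break restriction 3 at one chosen position. Only positions with B[i] \leq M can break it, because A[i] \leq M always holds. Let P be the number of such positions. Choose the position to violate (\binom{P}{1} ways) and set A[i] = B[i]. Then fill the remaining n-1 positions with distinct values taken from the other m-1 values. This gives \binom{P}{1} \cdot \binom{m-1}{n-1} \cdot (n-1)! arrays. The other positions are unconstrained here, so an array that breaks restriction 3 at several positions is counted several times. This overcounting is exactly what inclusion-exclusion corrects.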

In general, to force a violation at x chosen positions, first pick the positions (\binom{P}{x} ways) and fix A[i] = B[i] at each of them. Then fill the other n-x positions with distinct values taken from the remaining m-x values. This gives \binom{P}{x} \cdot \binom{m-x}{n-x} \cdot (n-x)! arrays.

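By the inclusion-exclusion principle, the number of valid arrays is:

\binom{m}{n} n! - \binom{P}{1}\binom{m-1}{n-1}(n-1)! + \binom{P}{2}\binom{m-2}{n-2}(n-2)! - \dots + (-1)^P\binom{P}{P}\binom{m-P}{n-P}(n-P)!, that is, \sum_{x=0}^{P} (-1)^x \binom{P}{x}\binom{m-x}{n-x}(n-x)!.
If M < N, no array satisfies restriction 2 and the answer is 0. After precomputing factorials and inverse factorials modulo 10^9+7 (up to max(N, M)), each term costs O(1), so the sum is computed in O(N).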

SOLUTIONS:

Setter’s Solution
    #include <bits/stdc++.h>
    
    // NOTE: from here on `int` means `long long`, so the product of two values
    // below `mod` does not overflow. `int32_t main` keeps main's real return type.
    #define int long long
    
    // Plain newline: avoids the flush that std::endl would perform.
    #define endl           "\n"
    // Modulus for every answer. It must be prime: modi() inverts via Fermat.
    #define mod            1000000007
    // Fast, unsynchronised iostreams (intended for the top of main).
    #define nitin          ios_base::sync_with_stdio(false); cin.tie(nullptr)
    using namespace std;
    
    
    // Table size. Indices [0, N) must cover m, which solve() asserts is <= 300000.
    const int N = 400001;
    // fact[i] = i! mod `mod` (filled by pre()).
    int fact[N];
    // fact_inv[i] = (i!)^{-1} mod `mod` (filled by pre()).
    int fact_inv[N];
    
    // Computes x^y mod p by binary exponentiation.
    //   x : base; may be negative, and is reduced into [0, p) first
    //   y : exponent; any y <= 0 yields 1 % p
    //   p : modulus, expected >= 1
    // Returns a value in [0, p). Relies on the file-wide `#define int long long`,
    // so that the product of two residues below `mod` fits without overflow.
    int power(int x, int y, int p) {
        // `1 % p` rather than `1`, so that p == 1 correctly yields 0.
        int res = 1 % p;
        x %= p;
        if (x < 0)
            x += p;  // C++ `%` keeps the sign of the dividend
        while (y > 0) {
            if (y & 1)
                res = (res * x) % p;
            y >>= 1;
            x = (x * x) % p;
        }
        return res;
    }
    
    // Modular inverse of `a` modulo the prime `m`, using Fermat's little theorem:
    // a^(m-2) is congruent to a^(-1) (mod m) when m is prime and m does not divide a.
    int modi(int a, int m) {
        const int fermat_exponent = m - 2;
        return power(a, fermat_exponent, m);
    }
    
    // Fills fact[0..N-1] with i! mod `mod`, then fact_inv[0..N-1] with their inverses.
    // Only one modular inverse is computed (for the largest factorial). The rest
    // follow downwards from (i!)^{-1} = ((i+1)!)^{-1} * (i+1).
    void pre() {
        fact[0] = 1;
        for (int k = 1; k < N; ++k)
            fact[k] = fact[k - 1] * k % mod;

        const int last = N - 1;
        fact_inv[last] = modi(fact[last], mod);
        for (int k = last; k > 0; --k)
            fact_inv[k - 1] = fact_inv[k] * k % mod;
    }
    
    // Binomial coefficient C(n, r) mod `mod`, read from the tables built by pre().
    // Returns 0 when r lies outside [0, n]; this also covers negative arguments,
    // which would otherwise index fact_inv out of bounds. Requires n < N.
    int ncr(int n, int r) {
        if (r < 0 || r > n)
            return 0;
        return (fact[n] * ((fact_inv[r] * fact_inv[n - r]) % mod)) % mod;
    }
    // Reads one test case: n, m and the n values of B. Prints the number of
    // arrays A of length n with 1 <= A[i] <= m, pairwise-distinct entries and
    // A[i] != B[i] for every i, modulo `mod`.
    //
    // Only p = #{i : 1 <= B[i] <= m} matters, because a position whose B value
    // lies outside [1, m] can never have A[i] == B[i]. B's values are distinct,
    // so forcing A[i] == B[i] on j of those p positions leaves
    // C(m-j, n-j) * (n-j)! ways to fill the rest. By inclusion-exclusion:
    //     answer = sum_{j=0..p} (-1)^j * C(p, j) * C(m-j, n-j) * (n-j)!
    void solve() {
        int n, m;
        cin >> n >> m;
        assert(n >= 1 && n <= 200000);
        assert(m >= 1 && m <= 300000);
        int p = 0;
        for (int i = 0; i < n; i++) {
            int val;
            cin >> val;
            assert(val >= 1 && val <= 1000000000);
            // B's values are only needed to count p, so they are not stored.
            if (val >= 1 && val <= m) {
                p++;
            }
        }
        if (m < n) {
            // Pigeonhole: n pairwise-distinct values cannot all fit in [1, m].
            cout << 0 << endl;
            return;
        }
        int ans = 0;
        for (int j = 0; j <= p; j++) {
            // Every factor is < mod, and `int` is `long long` in this file,
            // so each pairwise product fits before it is reduced.
            int term = ncr(p, j) * ncr(m - j, n - j) % mod * fact[n - j] % mod;
            // Even j adds, odd j subtracts; ans stays in [0, mod).
            ans = (j % 2 == 0) ? (ans + term) % mod : (ans - term + mod) % mod;
        }
        cout << ans << endl;
    }
    
    // Entry point: enable fast I/O, build the factorial tables, answer the single test case.
    int32_t main() {
        std::ios_base::sync_with_stdio(false);
        std::cin.tie(nullptr);
        pre();
        solve();
        return 0;
    }
Tester's Solution
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

// Integer modulo the compile-time constant `mod`, always stored as a canonical
// residue in [0, mod). Products are formed in long long, so (mod - 1)^2 must fit.
// Division multiplies by an inverse found with the extended Euclidean algorithm.
// That inverse exists only when the divisor is coprime to mod; a zero divisor
// silently yields 0.
template <long long mod>
struct modular {
    long long value;

    // Deliberately implicit, so that plain integers mix freely with modular values.
    modular(long long x = 0) : value(x % mod) {
        // `%` keeps the sign of x, so shift negative remainders into range.
        if (value < 0) {
            value += mod;
        }
    }

    modular& operator+=(const modular& other) {
        value += other.value;
        if (value >= mod) {
            value -= mod;
        }
        return *this;
    }
    modular& operator-=(const modular& other) {
        value -= other.value;
        if (value < 0) {
            value += mod;
        }
        return *this;
    }
    modular& operator*=(const modular& other) {
        value = (value * other.value) % mod;
        return *this;
    }
    modular& operator/=(const modular& other) {
        // Extended Euclid on (mod, other.value). Invariant:
        //   prev_coef * other.value == prev_rem (mod mod), and likewise for coef/rem.
        long long prev_rem = mod, rem = other.value;
        long long prev_coef = 0, coef = 1;
        while (rem != 0) {
            const long long q = prev_rem / rem;
            const long long next_rem = prev_rem - q * rem;
            const long long next_coef = prev_coef - q * coef;
            prev_rem = rem;
            rem = next_rem;
            prev_coef = coef;
            coef = next_coef;
        }
        // prev_rem is now gcd(mod, other.value); when it is 1, prev_coef is the inverse.
        long long inverse = prev_coef % mod;
        if (inverse < 0) {
            inverse += mod;
        }
        value = value * inverse % mod;
        return *this;
    }

    friend modular operator+(const modular& lhs, const modular& rhs) {
        modular out(lhs);
        out += rhs;
        return out;
    }
    friend modular operator-(const modular& lhs, const modular& rhs) {
        modular out(lhs);
        out -= rhs;
        return out;
    }
    friend modular operator*(const modular& lhs, const modular& rhs) {
        modular out(lhs);
        out *= rhs;
        return out;
    }
    friend modular operator/(const modular& lhs, const modular& rhs) {
        modular out(lhs);
        out /= rhs;
        return out;
    }

    modular& operator++() {
        *this += modular(1);
        return *this;
    }
    modular& operator--() {
        *this -= modular(1);
        return *this;
    }
    modular operator++(int) {
        const modular before = *this;
        ++*this;
        return before;
    }
    modular operator--(int) {
        const modular before = *this;
        --*this;
        return before;
    }

    modular operator-() const { return modular(-value); }
    bool operator==(const modular& rhs) const { return value == rhs.value; }
    bool operator!=(const modular& rhs) const { return !(*this == rhs); }
    bool operator<(const modular& rhs) const { return value < rhs.value; }
};
// String form of a modular value: its canonical residue in [0, mod).
template <long long mod>
string to_string(const modular<mod>& x) {
    const long long residue = x.value;
    return std::to_string(residue);
}
// Writes the canonical residue of `x` to `stream` and returns the stream for chaining.
template <long long mod>
ostream& operator<<(ostream& stream, const modular<mod>& x) {
    stream << x.value;
    return stream;
}
// Reads a raw integer into `x` and normalises it into [0, mod).
template <long long mod>
istream& operator>>(istream& stream, modular<mod>& x) {
    long long& raw = x.value;
    stream >> raw;
    raw %= mod;
    if (raw < 0) {
        raw += mod;  // `%` keeps the sign of a negative input
    }
    return stream;
}

// Modulus used throughout the tester's solution: 10^9 + 7.
constexpr long long mod = 1'000'000'007LL;
// Shorthand for an integer modulo `mod`.
using mint = modular<mod>;

// Binary exponentiation: returns a^n as a mint. Any n <= 0 yields 1.
mint power(mint a, long long n) {
    mint result = 1;
    for (mint base = a; n > 0; n >>= 1) {
        if (n & 1) {
            result *= base;
        }
        base *= base;
    }
    return result;
}

// Lazily grown tables used by C(): fact[i] = i! and finv[i] = 1 / i! (mod `mod`).
// Each starts with the single entry 1 for i = 0.
vector<mint> fact(1, 1);
vector<mint> finv(1, 1);

// Binomial coefficient C(n, k) modulo `mod`; returns 0 when k is outside [0, n].
// Side effect: extends the global fact/finv tables so that they cover index n.
mint C(int n, int k) {
    if (k < 0 || k > n) {
        return mint(0);
    }
    for (int i = (int) fact.size(); i <= n; i++) {
        fact.emplace_back(fact.back() * i);
        finv.emplace_back(mint(1) / fact.back());
    }
    return fact[n] * finv[k] * finv[n - k];
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    int n, m;
    cin >> n >> m;
    vector<int> b(n);
    for (int i = 0; i < n; i++) {
        cin >> b[i];
    }
    if (m < n) {
        cout << 0 << '\n';
        return 0;
    }
    sort(b.begin(), b.end());
    while (!b.empty() && b.back() > m) {
        b.pop_back();
    }
    int k = (int) b.size();
    mint ans = 0;
    C(n + m, 0);
    for (int i = 0; i <= k; i++) {
        mint t = C(k, i) * C(m - i, n - i) * fact[n - i];
        if (i % 2 == 0) {
            ans += t;
        } else {
            ans -= t;
        }
    }
    cout << ans << '\n';
    return 0;
}

6 Likes

Easy to Understand! Thanks:)

2 Likes

Can you please explain this line in more detail?

From where to learn inclusion - exclusion for c++?

https://cp-algorithms.com/combinatorics/inclusion-exclusion.html

I have a doubt on why and how exactly do we use inclusion-exclusion here?

Because there are intersections among the sets.

It would be hard to explain inclusion-exclusion in a comment; I would recommend you go through this.

1 Like

When you want to calculate union of sets and they overlap then we apply inclusion-exclusion. It would be more clear if you read this.

1 Like

Hey all,

Can somebody point out what is wrong here? I have implemented the same logic, but somehow I still get a wrong answer.
Thanks in advance.
Only subtask 7 fails, rest all are passed. Now I have no clue how to debug the code. If someone can help, it would be great.

link = Solution: 55701341 | CodeChef