COUNT_PERM - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: applepi216
Tester: tabr
Editorialist: iceknight1093

DIFFICULTY:

2532

PREREQUISITES:

Combinatorics

PROBLEM:

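Given a strictly increasing array A = [A_1, A_2, \ldots, A_K] of integers between 1 and N with A_K = N, count the number of permutations of \{1, 2, \ldots, N\} whose distinct prefix maximums, in order, are exactly A_1, A_2, \ldots, A_K. Output the count modulo 998244353.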

EXPLANATION:

There are a couple of different ways to solve this, but the right model of what’s going on makes it quite simple.

Let’s try to create a permutation P by placing the integers 1, 2, \ldots, N in order.

First, note that the first element of P must be A_1, since it’s the first maximum.
So, placing the elements smaller than A_1 (there are none if A_1 = 1), we have the following:

  • 1 can’t be placed at the first position, but anything else is ok since it can’t affect the prefix maximums. So, there are N-1 options.
  • 2 can’t be placed at the first position or where 1 is, but again anywhere else is fine. N-2 options.
  • Similarly, 3 has N-3 positions where it can be placed.
    \vdots
  • A_1 - 1 has N - (A_1 - 1) choices.

Once these have been placed, we place A_1 at position 1, as noted earlier.

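After this, notice that A_2 must be placed in the leftmost empty position for it to be the second prefix maximum: every position to the left of it already holds an element \le A_1, and any element between A_1 and A_2 placed there would itself become a prefix maximum.
So that position is reserved for A_2, which together with the A_1 occupied positions makes A_1+1 positions unavailable, so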

  • Element A_1+1 has N - (A_1+1) options.
  • Element A_1+2 has N-(A_1+2) options.
    \vdots
  • Element A_2-1 has N-(A_2-1) options.

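It’s easy to see that this process continues until we place A_K = N: before placing the elements strictly between A_i and A_{i+1}, the leftmost empty position is reserved for A_{i+1}.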

That is, every integer x that is not one of the A_i has exactly (N-x) valid positions: the x-1 smaller elements and the one reserved position are unavailable to it.

So, the final answer is simply the product of all of these, i.e.,

\prod_{\substack{x=1 \\ x \neq A_i}}^N (N-x)

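which is easily computed in \mathcal{O}(N), reducing modulo 998244353 as we go. The x = N term would be a factor of 0, but it never appears, since N = A_K is always one of the excluded values.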

TIME COMPLEXITY

\mathcal{O}(N) per test case.

CODE:

Author's code (C++)


#include <bits/stdc++.h>

using namespace std;

using ll = long long;
const ll MOD = 998244353;

// Reads T test cases. Each gives N, K and the prefix maximums A_1 < ... < A_K,
// and prints prod_{1 <= x <= N, x not in A} (N - x) modulo MOD.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int t; cin >> t;
    while (t--) {
        int n, k; cin >> n >> k;

        // is_max[x] is nonzero iff x is one of the prefix maximums A_i.
        // vector<char> rather than vector<bool> avoids the packed-bit proxy type.
        vector<char> is_max(n + 1, 0);
        for (int i = 0; i < k; i++) {
            int a; cin >> a;
            is_max[a] = 1;
        }

        // Every non-maximum x can be placed in exactly (N - x) positions.
        // The range deliberately includes x = N: if N were not a listed maximum,
        // its factor is 0, matching the fact that no valid permutation exists.
        ll ans = 1;
        for (int x = 1; x <= n; x++) {
            if (!is_max[x]) ans = ans * (n - x) % MOD;
        }
        cout << ans << "\n";
    }
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

struct input_checker {
    string buffer;  // entire contents of standard input
    int pos;        // index of the next unread character in `buffer`

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    // Reads all of standard input up front; tokens are then consumed from `pos`.
    input_checker() {
        pos = 0;
        buffer.assign(istreambuf_iterator<char>(cin), istreambuf_iterator<char>());
    }

    // Returns the characters from `pos` up to (not including) the next ' ' or '\n'.
    // The separator itself is left for readSpace()/readEoln() to consume.
    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    // True iff `s` is an optional leading '-' followed by one or more decimal digits.
    static bool isIntegerToken(const string& s) {
        size_t start = (!s.empty() && s[0] == '-') ? 1 : 0;
        if (start == s.size()) {
            return false;
        }
        for (size_t i = start; i < s.size(); i++) {
            if (!isdigit(static_cast<unsigned char>(s[i]))) {
                return false;
            }
        }
        return true;
    }

    // Reads a token of length in [min_len, max_len] whose characters all occur in
    // `pattern` (any character is allowed when `pattern` is empty).
    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    // Reads an integer token in [min_val, max_val]. The whole token must be an
    // integer: stoi alone would silently accept "12abc", "+5", "\t5" or "5\r".
    // Values outside int range still make stoi throw std::out_of_range.
    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        string token = readOne();
        assert(isIntegerToken(token));
        int res = stoi(token);
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    // long long counterpart of readInt, with the same whole-token check.
    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        string token = readOne();
        assert(isIntegerToken(token));
        long long res = stoll(token);
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    // Reads `size` space-separated ints in [min_val, max_val] (no trailing space).
    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    // Reads `size` space-separated long longs in [min_val, max_val] (no trailing space).
    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    // Consumes exactly one ' '.
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    // Consumes exactly one '\n'.
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    // Asserts that the whole input has been consumed.
    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

// Integer modulo the compile-time constant `mod`; `value` is kept in [0, mod).
template <long long mod>
struct modular {
    long long value;

    // Implicit so that plain integers mix freely with modular values;
    // negative inputs are normalized into [0, mod).
    modular(long long x = 0) {
        const long long r = x % mod;
        value = r < 0 ? r + mod : r;
    }

    modular& operator+=(const modular& other) {
        value += other.value;
        if (value >= mod) value -= mod;
        return *this;
    }

    modular& operator-=(const modular& other) {
        value -= other.value;
        if (value < 0) value += mod;
        return *this;
    }

    modular& operator*=(const modular& other) {
        value = (value * other.value) % mod;
        return *this;
    }

    // Multiplies by the inverse of `other`, found with the extended Euclidean
    // algorithm on (mod, other.value). Dividing by 0 leaves inv = 0, so the
    // result becomes 0 rather than failing.
    modular& operator/=(const modular& other) {
        long long r_prev = mod, r_cur = other.value;
        long long inv = 0, inv_next = 1;
        while (r_cur != 0) {
            const long long q = r_prev / r_cur;
            r_prev -= q * r_cur;
            std::swap(r_cur, r_prev);
            inv -= q * inv_next;
            std::swap(inv, inv_next);
        }
        inv %= mod;
        if (inv < 0) inv += mod;
        value = value * inv % mod;
        return *this;
    }

    friend modular operator+(const modular& lhs, const modular& rhs) {
        modular res(lhs);
        res += rhs;
        return res;
    }
    friend modular operator-(const modular& lhs, const modular& rhs) {
        modular res(lhs);
        res -= rhs;
        return res;
    }
    friend modular operator*(const modular& lhs, const modular& rhs) {
        modular res(lhs);
        res *= rhs;
        return res;
    }
    friend modular operator/(const modular& lhs, const modular& rhs) {
        modular res(lhs);
        res /= rhs;
        return res;
    }

    modular& operator++() {
        *this += modular(1);
        return *this;
    }
    modular& operator--() {
        *this -= modular(1);
        return *this;
    }
    modular operator++(int) {
        modular before = *this;
        ++*this;
        return before;
    }
    modular operator--(int) {
        modular before = *this;
        --*this;
        return before;
    }

    modular operator-() const { return modular(-value); }

    bool operator==(const modular& rhs) const { return value == rhs.value; }
    bool operator!=(const modular& rhs) const { return !(value == rhs.value); }
    bool operator<(const modular& rhs) const { return value < rhs.value; }
};

// Decimal representation of the reduced value.
template <long long mod>
string to_string(const modular<mod>& x) {
    return std::to_string(x.value);
}

// Writes the reduced value.
template <long long mod>
ostream& operator<<(ostream& stream, const modular<mod>& x) {
    stream << x.value;
    return stream;
}

// Reads a (possibly negative) integer and reduces it into [0, mod).
template <long long mod>
istream& operator>>(istream& stream, modular<mod>& x) {
    stream >> x.value;
    x = modular<mod>(x.value);
    return stream;
}

constexpr long long mod = 998244353;
using mint = modular<mod>;

// Binary exponentiation: a^n for n >= 0; returns 1 when n <= 0.
mint power(mint a, long long n) {
    mint result = 1;
    for (; n > 0; n >>= 1, a *= a) {
        if (n & 1) result *= a;
    }
    return result;
}

// Lazily grown tables used by C(): fact[i] = i!, finv[i] = 1 / i!.
vector<mint> fact(1, 1);
vector<mint> finv(1, 1);

// Binomial coefficient n choose k; 0 when k < 0 or k > n.
// Extends the factorial tables up to n on demand.
mint C(int n, int k) {
    if (k < 0 || k > n) {
        return mint(0);
    }
    for (int i = (int) fact.size(); i <= n; i++) {
        fact.emplace_back(fact.back() * i);
        finv.emplace_back(mint(1) / fact.back());
    }
    return fact[n] * finv[k] * finv[n - k];
}

// Validates the input format and constraints while computing each answer:
// the product over non-maxima x of (N - x), modulo `mod` (via mint).
int main() {
    input_checker in;
    int tt = in.readInt(1, 10000);
    in.readEoln();
    // Sum of N over all test cases. Kept as long long and checked after every
    // test case: an int could overflow (up to 10^4 * 10^6) before a final check.
    long long sn = 0;
    while (tt--) {
        int n = in.readInt(1, 1000000);
        in.readSpace();
        int k = in.readInt(1, n);
        in.readEoln();
        sn += n;
        assert(sn <= 1000000);
        auto a = in.readInts(k, 1, n);
        in.readEoln();
        // A must be strictly increasing and end at N (the last prefix maximum).
        for (int i = 1; i < k; i++) {
            assert(a[i - 1] < a[i]);
        }
        assert(a[k - 1] == n);
        // is_max[x - 1] marks the values that appear in A.
        vector<char> is_max(n, 0);
        for (int x : a) {
            is_max[x - 1] = 1;
        }
        // Index i stands for x = i + 1, so the factor N - x is n - 1 - i.
        mint ans = 1;
        for (int i = n - 1; i >= 0; i--) {
            if (!is_max[i]) {
                ans *= n - 1 - i;
            }
        }
        cout << ans << '\n';
    }
    in.readEof();
    return 0;
}
Editorialist's code (Python)
mod = 998244353

for _ in range(int(input())):
    n, k = map(int, input().split())
    # A is increasing, so walk it with a pointer; the trailing -1 is a sentinel
    # that never equals any x in 1..n once every element of A has been matched.
    maxima = [*map(int, input().split()), -1]
    idx = 0
    result = 1
    for x in range(1, n + 1):
        if x == maxima[idx]:
            idx += 1
        else:
            # x is not a prefix maximum: it has exactly n - x valid positions.
            result = result * (n - x) % mod
    print(result)
1 Like

really nice solution!!