CNTSTILLFUN - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: apoorv_me
Tester: apoorv_me
Editorialist: iceknight1093

DIFFICULTY:

Easy

PREREQUISITES:

Elementary combinatorics

PROBLEM:

You are given an array A of length N.
Count the number of pairs of permutations (P, Q) of \{1, 2, \ldots, N\} such that A_i = \max(P_i, Q_i) for every i. As in the solutions below, the count is reported modulo 10^9 + 7.

EXPLANATION:

A_i = \max(P_i, Q_i) means that either P_i and Q_i are both equal to A_i, or one of them is equal to A_i and the other is strictly smaller than A_i.

Let’s iterate over the elements of A from smallest to largest.
Suppose we’re currently at x, and there are f_x occurrences of x in A. Then,

  • If f_x \gt 2, the answer is immediately 0: P and Q are permutations, so each will contain exactly one copy of x. This means x can be the maximum at no more than two indices.
  • Suppose f_x = 2, and the occurrences of x are at indices i and j.
    Then, either we have P_i = Q_j = x with P_j, Q_i being \lt x, or vice versa.
    • Note that it doesn’t really matter what the smaller elements are: anything valid will work (since we’re iterating from small to large, anything that’s smaller now will remain smaller in the future).
  • If f_x = 1, we have a couple of options (say x appears at index i):
    • We can have P_i = Q_i = x; or
    • We can have P_i = x and Q_i \lt x (or vice versa).
      Once again, note that it doesn’t matter what the smaller element is: just that it’s smaller.
  • If f_x = 0, we can ignore x (though it does become available as a smaller element for future values).

To count valid permutation pairs, the only thing we really need to know is how many smaller elements are available whenever we’re processing x.
So, let’s just store that count.
Let \text{ans} denote the answer (initially, this is 1), and s denote the number of smaller elements available (initially, 0).
Then, for each x from 1 to N in order:

  • If f_x \gt 2, set \text{ans} to 0.
  • If f_x = 0, increase s by 1.
  • If f_x = 2, multiply \text{ans} by 2\cdot s\cdot s, and then reduce s by 1.
    • The factor of 2 comes from choosing which of P_i and P_j should contain x, and once this is fixed, we have s choices in each of P and Q for the smaller elements.
    • This uses up one smaller element in each array, and doesn’t add any more; so s reduces by 1.
  • If f_x = 1, multiply \text{ans} by 2\cdot s + 1, and don’t change s.
    • 2\cdot s choices for placing x at either P_i or Q_i and placing a smaller element at the other one.
      1 choice for setting P_i = Q_i.
    • In both cases, the number of smaller elements available doesn’t change: if P_i = Q_i this is obvious, and if P_i = x, Q_i \lt x then one smaller element gets used up in Q, but as compensation x becomes available as a smaller element instead.

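The frequency array f can be computed in linear time, and then the counting part is also linear time once f is known - all we do is a few multiplications at each step.

Note that the C++ solutions below perform the same count in the opposite direction, processing x from N down to 1. There, the counter m stores the number of positions in each of P and Q that are still waiting for a smaller value: f_x = 2 multiplies the answer by 2 and increases m by 1, f_x = 1 multiplies it by 2\cdot m + 1, and f_x = 0 multiplies it by m\cdot m and decreases m by 1. Both directions give the same result.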

TIME COMPLEXITY:

\mathcal{O}(N) per testcase.

CODE:

Author's code (C++)
#include<bits/stdc++.h>
using namespace std;

#ifdef LOCAL
#include "../debug.h"
#else
#define dbg(...)
#endif

#ifdef LOCAL
// Strict, validator-style reader. The constructor slurps all of stdin into
// `buffer`; every read* call then consumes exactly one token or one whitespace
// character from it and asserts that the input has the expected shape.
struct input_checker {
    string buffer;  // everything that was available on stdin
    int pos;        // index of the next unread character of `buffer`

    // Character classes that can be handed to readString() as `pattern`.
    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    // Reads stdin character by character until extraction fails.
    input_checker() : pos(0) {
        for (char ch; cin.get(ch); ) {
            buffer.push_back(ch);
        }
    }

    // Returns the index of the first ' ' or '\n' at or after `pos`, or the
    // buffer length when there is none. Does not move `pos`.
    int nextDelimiter() {
        const int len = (int) buffer.size();
        int idx = pos;
        for (; idx < len; ++idx) {
            const char ch = buffer[idx];
            if (ch == ' ' || ch == '\n') {
                break;
            }
        }
        return idx;
    }

    // Consumes and returns the characters up to (not including) the next
    // delimiter. Asserts that the buffer is not exhausted.
    string readOne() {
        assert(pos < (int) buffer.size());
        const int stop = nextDelimiter();
        string token;
        for (; pos < stop; ++pos) {
            token.push_back(buffer[pos]);
        }
        return token;
    }

    // Reads one token and asserts that its length lies in [minl, maxl] and,
    // when `pattern` is non-empty, that every character occurs in `pattern`.
    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string token = readOne();
        const int len = (int) token.size();
        assert(minl <= len);
        assert(len <= maxl);
        for (char ch : token) {
            assert(pattern.empty() || pattern.find(ch) != string::npos);
        }
        return token;
    }

    // Reads one token, converts it with stoi and asserts it lies in [minv, maxv].
    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        const int value = stoi(readOne());
        assert(minv <= value);
        assert(value <= maxv);
        return value;
    }

    // 64-bit counterpart of readInt (conversion via stoll).
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        const long long value = stoll(readOne());
        assert(minv <= value);
        assert(value <= maxv);
        return value;
    }

    // Reads n integers in [minv, maxv] separated by exactly one space each.
    // The character following the last integer is left unread.
    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> out(n);
        int taken = 0;
        for (int &slot : out) {
            slot = readInt(minv, maxv);
            if (++taken < n) readSpace();
        }
        return out;
    }

    // 64-bit counterpart of readInts.
    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> out(n);
        int taken = 0;
        for (long long &slot : out) {
            slot = readLong(minv, maxv);
            if (++taken < n) readSpace();
        }
        return out;
    }

    // Asserts that the next character is a single space and consumes it.
    void readSpace() {
        assert(pos < (int) buffer.size());
        assert(buffer[pos] == ' ');
        ++pos;
    }

    // Asserts that the next character is a newline and consumes it.
    void readEoln() {
        assert(pos < (int) buffer.size());
        assert(buffer[pos] == '\n');
        ++pos;
    }

    // Asserts that the whole buffer has been consumed.
    void readEof() {
        assert(pos == (int) buffer.size());
    }
};
#else

// Lightweight stand-in for the strict checker, compiled when LOCAL is not
// defined. Tokens are pulled straight from cin with operator>>, so value
// ranges and string shape are still asserted, but the exact whitespace layout
// is not: readSpace/readEoln/readEof do nothing.
struct input_checker {
    string buffer;  // never filled by this variant; kept for interface parity
    int pos = 0;    // zero-initialised so the buffer helpers below never read garbage

    // Character classes that can be handed to readString() as `pattern`.
    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    // Unlike the strict variant, nothing is read up front.
    input_checker() {
    }

    // Returns the index of the first ' ' or '\n' in `buffer` at or after `pos`
    // (or the buffer length). `buffer` stays empty unless the caller fills it.
    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }

    // Consumes one token from `buffer`. Because this variant never fills
    // `buffer`, the assertion fails unless the caller populated it manually.
    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    // Reads one whitespace-delimited token from cin and asserts that its
    // length lies in [minl, maxl] and, when `pattern` is non-empty, that every
    // character occurs in `pattern` (the same checks the strict variant makes).
    string readString(int minl, int maxl, const string &pattern = "") {
      assert(minl <= maxl);
      string X; cin >> X;
      assert(minl <= (int) X.size());
      assert((int) X.size() <= maxl);
      for (char ch : X) {
        assert(pattern.empty() || pattern.find(ch) != string::npos);
      }
      return X;
    }

    // Reads an int from cin and asserts it lies in [minv, maxv].
    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res;  cin >> res;
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    // Reads a long long from cin and asserts it lies in [minv, maxv].
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res;  cin >> res;
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    // Reads n ints, each asserted to lie in [minv, maxv].
    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    // Reads n long longs, each asserted to lie in [minv, maxv].
    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    // No-op: operator>> already skips whitespace.
    void readSpace() {
    }

    // No-op: operator>> already skips whitespace.
    void readEoln() {
    }

    // No-op: trailing input is not checked in this variant.
    void readEof() {
    }
};
#endif

namespace mint_ns {
// Integer modulo the compile-time constant P. `value` is kept in [0, P) by
// every operation. The arithmetic assumes a signed value_type (e.g. int) for
// which 2 * (P - 1) does not overflow, which holds for P = 1e9 + 7 in int.
template<auto P>
struct Modular {
    using value_type = decltype(P);
    value_type value;

    // Implicit on purpose so that plain integers mix with Modular in expressions.
    Modular(long long k = 0) : value(norm(k)) {}

    friend Modular<P>& operator += (      Modular<P>& n, const Modular<P>& m) { n.value += m.value; if (n.value >= P) n.value -= P; return n; }
    friend Modular<P>  operator +  (const Modular<P>& n, const Modular<P>& m) { Modular<P> r = n; return r += m; }

    friend Modular<P>& operator -= (      Modular<P>& n, const Modular<P>& m) { n.value -= m.value; if (n.value < 0)  n.value += P; return n; }
    friend Modular<P>  operator -  (const Modular<P>& n, const Modular<P>& m) { Modular<P> r = n; return r -= m; }
    friend Modular<P>  operator -  (const Modular<P>& n)                      { return Modular<P>(-n.value); }

    // The product is formed in long long (1ll) so it cannot overflow value_type.
    friend Modular<P>& operator *= (      Modular<P>& n, const Modular<P>& m) { n.value = n.value * 1ll * m.value % P; return n; }
    friend Modular<P>  operator *  (const Modular<P>& n, const Modular<P>& m) { Modular<P> r = n; return r *= m; }

    // Division multiplies by the modular inverse of the divisor.
    friend Modular<P>& operator /= (      Modular<P>& n, const Modular<P>& m) { return n *= m.inv(); }
    friend Modular<P>  operator /  (const Modular<P>& n, const Modular<P>& m) { Modular<P> r = n; return r /= m; }

    Modular<P>& operator ++ (   ) { return *this += 1; }
    Modular<P>& operator -- (   ) { return *this -= 1; }
    Modular<P>  operator ++ (int) { Modular<P> r = *this; *this += 1; return r; }
    Modular<P>  operator -- (int) { Modular<P> r = *this; *this -= 1; return r; }

    friend bool operator == (const Modular<P>& n, const Modular<P>& m) { return n.value == m.value; }
    friend bool operator != (const Modular<P>& n, const Modular<P>& m) { return n.value != m.value; }

    explicit    operator       int() const { return value; }
    explicit    operator      bool() const { return value; }
    explicit    operator long long() const { return value; }

    constexpr static value_type mod()      { return     P; }

    // Reduces an arbitrary 64-bit integer into [0, P). Static because it
    // touches no member state; it is also called from the constructor's
    // initialiser list.
    static value_type norm(long long k) {
        if (!(-P <= k && k < P)) k %= P;
        if (k < 0) k += P;
        return k;
    }

    // Modular inverse via the extended Euclidean algorithm.
    // Assumes gcd(value, P) == 1; the result is not meaningful otherwise.
    Modular<P> inv() const {
        value_type a = value, b = P, x = 0, y = 1;
        while (a != 0) { value_type k = b / a; b -= k * a; x -= k * y; std::swap(a, b); std::swap(x, y); }
        return Modular<P>(x);
    }

    // Debug-printing hook: writes the raw residue to stderr.
    friend void __print(Modular<P> v) {
        std::cerr << v.value;
    }
};

// Raises m to the power p by binary exponentiation.
// A negative p yields the matching power of m's inverse. Previously a negative
// exponent never terminated, because shifting a negative value right does not
// reach zero on arithmetic-shift implementations.
template<auto P> Modular<P> pow(Modular<P> m, long long p) {
    unsigned long long e = static_cast<unsigned long long>(p);
    if (p < 0) {
        m = m.inv();
        e = 0ULL - e;  // magnitude of p; well defined even for LLONG_MIN
    }
    Modular<P> r(1);
    while (e) {
        if (e & 1) r *= m;
        m *= m;
        e >>= 1;
    }
    return r;
}

template<auto P> std::ostream& operator << (std::ostream& o, const Modular<P>& m) { return o << m.value; }
template<auto P> std::istream& operator >> (std::istream& i,       Modular<P>& m) { long long k; i >> k; m.value = Modular<P>::norm(k); return i; }
// Must call std::to_string explicitly: inside this namespace the unqualified
// name `to_string` resolves to this very template, which hides the std
// overloads, so the unqualified call failed to compile when instantiated.
template<auto P> std::string   to_string(const Modular<P>& m) { return std::to_string(m.value); }

}
// Modulus (10^9 + 7) under which the answer is computed and printed.
constexpr int mod = (int)1e9 + 7;
// Modular-integer type specialised for `mod`; used for all arithmetic in main().
using mod_int = mint_ns::Modular<mod>;

// Reads T test cases through the input checker and prints, for each, the
// number of permutation pairs (P, Q) with max(P_i, Q_i) = A_i, modulo `mod`.
int32_t main() {
  ios_base::sync_with_stdio(0);
  cin.tie(0);

  input_checker inp;
  int T = inp.readInt(1, (int)1e5); inp.readEoln();
  // Sum of N over all tests. Kept in 64 bits and checked after every test,
  // because T * N can exceed the range of int before a final check would run.
  long long NN = 0;
  while(T-- > 0) {
    int N = inp.readInt(1, (int)3e5); inp.readEoln();
    NN += N;
    assert(NN <= (int)3e5);

    // B[v] = number of occurrences of v in A. The values go through the
    // checker (not raw cin) so that they are validated to lie in [1, N] before
    // being used as indices, and so that the strict checker build, which has
    // already consumed all of stdin, sees them at all.
    vector<int> B(N + 1);
    for(int a: inp.readInts(N, 1, N))
      B[a]++;
    inp.readEoln();

    // Sweep the values from largest to smallest. m is the number of positions
    // in each of P and Q that still wait for a smaller value.
    mod_int m = 0;
    mod_int res = 1;
    for(int i = N ; i >= 1 ; --i) {
      if(B[i] == 0) {
        // i must fill one waiting position in P and one in Q.
        // If m was 0 the product is already 0, so the wrap-around of --m is harmless.
        res *= m * m; --m; continue;
      }
      if(B[i] == 1) {
        // Either P_i = Q_i = i, or i sits on one side of index i and fills one
        // of the m waiting positions on the other side (2 * m ways).
        res *= (2 * m + 1); continue;
      }
      if(B[i] == 2) {
        // Choose which of the two indices gets i in P; the other gets i in Q.
        // This opens one new waiting position in each of P and Q.
        res *= 2; ++m; continue;
      }
      // A value occurring three or more times can never be a maximum that often.
      res = 0;
      break;
    }
    cout << res << '\n';
  }
  inp.readEof();
  
  return 0;
}

Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
// Contest shortcut: from here on every `int` token is really a 64-bit `long long`.
#define int long long
// Large sentinel (10^18); not referenced by the code shown in this solution.
#define INF (int)1e18

// 64-bit Mersenne Twister seeded from the steady clock; not referenced by the code shown in this solution.
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

// Solves one test case: reads n and the n array values from stdin and prints
// the number of valid (P, Q) pairs modulo 1e9 + 7.
// Types are spelled `long long` explicitly; in this file `int` is a macro for it.
void Solve()
{
    long long n;
    cin >> n;

    // cnt[v - 1] = number of occurrences of the value v.
    vector<long long> cnt(n, 0);
    for (long long seen = 0; seen < n; ++seen) {
        long long val;
        cin >> val;
        ++cnt[val - 1];
    }

    const long long MOD = 1e9 + 7;
    long long ways = 1;
    long long waiting = 0;  // positions in each of P and Q still needing a smaller value

    // Walk the values from largest to smallest.
    for (long long v = n; v >= 1; --v) {
        switch (cnt[v - 1]) {
            case 0:
                // v fills one waiting position in P and one in Q.
                ways = ways * waiting % MOD * waiting % MOD;
                --waiting;
                break;
            case 1:
                // Both sides equal v, or v on one side and a waiting slot filled on the other.
                ways = ways * (2 * waiting + 1) % MOD;
                break;
            case 2:
                // Two ways to split the two indices between P and Q; opens one slot each.
                ++waiting;
                ways *= 2;
                ways %= MOD;
                break;
            default:
                // Three or more occurrences are impossible.
                ways = 0;
                break;
        }
    }

    cout << ways << "\n";
}

// Entry point of the tester's solution: runs Solve() once per test case and
// reports the elapsed wall-clock time on stderr.
int32_t main()
{
    const auto started = std::chrono::high_resolution_clock::now();

    ios_base::sync_with_stdio(0);
    cin.tie(0);
    // Optional file redirection for local runs:
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);

    long long t = 1;  // number of test cases (`int` is a macro for long long in this file)
    cin >> t;
    for (long long remaining = t; remaining > 0; --remaining) {
        Solve();
    }

    const auto finished = std::chrono::high_resolution_clock::now();
    const auto spent = std::chrono::duration_cast<std::chrono::nanoseconds>(finished - started);
    cerr << "Time measured: " << spent.count() * 1e-9 << " seconds.\n";
    return 0;
}
Editorialist's code (Python)
mod = 10**9 + 7
for _ in range(int(input())):
    n = int(input())
    values = [int(tok) for tok in input().split()]

    # count[v] = number of occurrences of v in the array (index 0 is unused).
    count = [0] * (n + 1)
    for v in values:
        count[v] += 1

    # Sweep the values in increasing order; `spare` is the number of smaller
    # values that are still unused in each of P and Q.
    ways, spare = 1, 0
    for c in count[1:]:
        if c == 0:
            # Unused value: becomes available as a smaller element later on.
            spare += 1
        elif c == 1:
            # Both sides equal, or one side holds the value and the other a spare one.
            ways = (ways * (2 * spare + 1)) % mod
        elif c == 2:
            # Two ways to assign the indices, then one spare value on each side.
            ways = (ways * 2 * spare * spare) % mod
            spare -= 1
        else:
            # A value cannot be the maximum at three or more indices.
            ways = 0
    print(ways)

I think this exact problem was asked in a Google OA (online assessment).