PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: still_me
Testers: the_hyp0cr1t3, rivalq
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
Observation, computing binomial coefficients
PROBLEM:
An array B is called good if B_i is either the maximum or the minimum of B[1\ldots i] for each i.
Given an array A, count the number of its rearrangements that are good.
EXPLANATION:
First, let’s sort A since it won’t change the answer. From now on, A_i \leq A_{i+1}.
Let B be a good rearrangement of A. Let’s analyze the structure of B.
Let S_L = [B_{i_1}, B_{i_2}, \ldots, B_{i_r}] be the subsequence of elements that are prefix minimums, and S_H = [B_{j_1}, B_{j_2}, \ldots, B_{j_s}] be the subsequence of elements that are prefix maximums. If an element can be both (for example, B_1) then we treat it as a minimum.
S_L and S_H partition B, so r+s = N.
For example, if B = [2, 3, 3, 1, 4, 1] then S_L = [2, 1, 1] and S_H = [3, 3, 4].
Note that the following conditions must hold:
- B_{i_1} \geq B_{i_2} \geq \ldots \geq B_{i_r} (the prefix minimums are non-increasing)
- B_{j_1} \leq B_{j_2} \leq \ldots \leq B_{j_s} (the prefix maximums are non-decreasing)
- B_{j_1} \geq B_{i_1}
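Notice that this is an ordering that encompasses every element: chaining the three conditions gives B_{i_r} \leq \ldots \leq B_{i_2} \leq B_{i_1} \leq B_{j_1} \leq B_{j_2} \leq \ldots \leq B_{j_s}, which is a sorted order of all N elements!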
In particular, this means that S_L contains the smallest r elements of A, and S_H contains the largest s elements of A.
So, let’s try to fix the value of r and see how many rearrangements exist such that S_L contains the smallest r elements.
Of course, the first element must be A_r: B_1 is always a prefix minimum, so it is B_{i_1}, the largest element of S_L. After that, there are two cases:
- If A_{r+1} \gt A_r (or r = N), we can choose any r-1 positions out of the remaining N-1 and place the elements A_1, A_2, \ldots, A_{r-1} in these positions (in descending order) and the other elements in the remaining positions (in ascending order). So, there are \binom{N-1}{r-1} ways for this to happen.
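- If A_{r+1} = A_r, we need to be a bit more careful. We cannot arbitrarily choose r-1 positions to place the first r-1 elements, because we need to ensure that A_{r+1} is strictly the maximum when we place it (if it merely tied with everything before it, it would also be a prefix minimum and would be counted in S_L, which is supposed to contain exactly r elements). This means we need to place at least one element less than A_{r+1} before placing it.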
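So, let x \lt r be the largest index such that A_x \lt A_r (if no such index exists, every element of S_L equals A_{r+1}, so no arrangement is possible for this r and it contributes 0). We first place the elements A_r, A_{r-1}, \ldots, A_x, in this order, in the first r-x+1 positions, then we can choose x-1 positions from the remaining N-(r-x+1) to place the other elements of S_L, giving us \binom{N-r+x-1}{x-1} possible ways.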
x can be maintained as you iterate through the array: update it whenever you move from one value to the next.
Adding this up across all positions gives the final answer.
Each binomial coefficient needs to be computed in \mathcal{O}(\log{MOD}) or \mathcal{O}(1). If you don’t know how to do this, read through this article.
TIME COMPLEXITY:
\mathcal{O}(N\log N) per testcase.
CODE:
Setter's code (C++)
// Code by Sahil Tiwari (still_me)
#include<bits/stdc++.h>
// Renames the entry point: the `signed still_me()` definition in this file becomes `signed main()`.
#define still_me main
// Every later `endl` token expands to the string literal "\n".
#define endl "\n"
// Every later `int` token expands to `long long int`, so all `int` variables,
// parameters and return types below are actually long long.
#define int long long int
// Expands to the begin/end iterator pair of container `a`.
#define all(a) (a).begin() , (a).end()
// Prints every element of `a` followed by a space, then a newline.
// NOTE(review): expands to two statements, so it is unsafe as the body of an unbraced if/else.
#define print(a) for(auto TEMPORARY: a) cout<<TEMPORARY<<" ";cout<<endl;
// Declares TESTCASE, reads it from cin and loops that many times over the statement that follows.
#define tt int TESTCASE;cin>>TESTCASE;while(TESTCASE--)
// Reads a[0..n-1] from cin.
#define arrin(a,n) for(int INPUT=0;INPUT<n;INPUT++)cin>>a[INPUT]
using namespace std;
// Modulus used for every result in this solution.
const int mod = 1e9+7;
// Large sentinel value; fits only because `int` is long long here.
const int inf = 1e18;
// Largest index of the factorial table below.
const int N = 1e6;
// fact[k] holds k! modulo `mod` once factorial() has run.
int fact[N+1];
// Fills the global table so that fact[k] == k! modulo `mod` for every k in [0, N].
// (`int` is this file's macro for long long, so the products below do not overflow.)
void factorial(){
    fact[0] = 1;
    int running = 1;                    // value of (k-1)! mod `mod` on entry to each iteration
    for(int k = 1; k <= N; ++k){
        running = running * k % mod;
        fact[k] = running;
    }
}
// Returns a^b modulo `mod`, always as a value in [0, mod).
//
// a   : base; any long long, including negative values and values >= mod
//       (it is reduced into [0, mod) before any multiplication).
// b   : exponent; expected to be >= 0. A non-positive exponent yields 1 % mod.
// mod : modulus; must be positive and small enough that (mod-1)^2 fits in a long long.
//
// Uses iterative binary exponentiation, so it performs O(log b) multiplications.
long long power(long long a , long long b , long long mod){
    // Normalise the base first: without this, a large `a` overflows in result * a
    // and a negative `a` produces a negative "residue".
    a %= mod;
    if(a < 0)
        a += mod;
    long long result = 1 % mod;         // 1 % mod so that mod == 1 correctly gives 0
    while(b > 0){
        if(b & 1)
            result = result * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return result;
}
// Modular inverse of `a` with respect to the global `mod`, computed as a^(mod-2) via power().
// This is the Fermat's-little-theorem form, so it relies on `mod` being prime and on
// `a` not being a multiple of `mod`.
int inverse(int a){
    const int fermat_exponent = mod - 2;
    return power(a , fermat_exponent , mod);
}
// Binomial coefficient C(n, r) modulo `mod`, read from the global factorial table.
// Returns 0 when r lies outside [0, n]. Requires factorial() to have been called
// and n <= N (the table size).
//
// The two factorials in the denominator are multiplied first and inverted once,
// so each call costs a single modular exponentiation instead of two; the result
// is the same residue.
// (`int` is this file's macro for long long, so the products do not overflow.)
int nCr(int n , int r){
    if(r < 0 || r > n)
        return 0;
    const int denominator = fact[r] * fact[n-r] % mod;
    return fact[n] * inverse(denominator) % mod;
}
// Solves one test case: reads n and the n array values from cin and prints one count modulo `mod`.
// The count starts at 2^(n-1); then, for every run of two or more equal values in the sorted
// array, binomial-weighted correction terms are subtracted from it.
// NOTE(review): presumably 2^(n-1) is the number of good arrangements when all values are
// distinct and the subtractions remove arrangements that coincide because of equal values —
// the derivation is not shown here, confirm before modifying the formulas.
void chal_bsdk() {
int n;
cin>>n;
vector<int> a(n);
arrin(a , n);
sort(all(a));
// l is always the first index of the current run of equal values.
int l = 0;
int ans = power(2 , n-1 , mod);
while(l < n) {
// r is the last index holding the same value as a[l].
int r = upper_bound(all(a) , a[l]) - a.begin() - 1;
// cout<<r<<endl;
// A run of length one needs no correction.
if(r == l) {
l++;
continue;
}
{
// k counts the elements outside the run: l before it and n-r-1 after it.
int k = l + n - r - 1;
// Subtract C(k, l) * (2^(r-l) - 1); the "+ mod" keeps each difference non-negative.
int x = (nCr(k , l) * (power(2 , r-l , mod) - 1 + mod)) % mod;
ans = (ans - x + mod) % mod;
}
// One pair of further corrections for every i in [2, r-l].
for(int i=2;i<(r-l+1);i++) {
// L = number of elements smaller than the run's value, R = number larger.
int L = l;
int R = n-r-1;
int g = L+R + (r-l-i);
// Left side
// nCr returns 0 when L-1 is negative, so this term vanishes when nothing is smaller.
int x = nCr(g , L-1) * (power(2 , i-1 , mod) - 1 + mod) % mod;
ans = (ans - x + mod) % mod;
// Right side
// Symmetric term using the count of larger elements.
x = nCr(g , R-1) * (power(2 , i-1 , mod) - 1 + mod) % mod;
ans = (ans - x + mod) % mod;
}
// Jump to the first index of the next run.
l = r+1;
}
cout<<ans<<endl;
}
signed still_me()
{
ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL);
// freopen("4.in" , "r" , stdin);
// freopen("4.out" , "w" , stdout);
factorial();
tt{
chal_bsdk();
}
return 0;
}
Tester's code (C++)
/**
* the_hyp0cr1t3
* 02.01.2023 23:38:00
**/
#ifdef W
#include <k_II.h>
#else
#include <bits/stdc++.h>
using namespace std;
#endif
// -------------------- Input Validator Start --------------------
// Convenience wrappers around a validator object that must be named `val` at the call site.
// Each forwards its arguments and appends the caller's __LINE__ and the stringized first
// argument (#x / #vec), which the validator prints in its error report.
// The *_sp variants require the token to be followed by a space, the *_ln variants by '\n'.
#define read_int_sp(x, L, R) val.read_int(x, L, R, ' ', __LINE__, #x)
#define read_int_ln(x, L, R) val.read_int(x, L, R, '\n', __LINE__, #x)
// Reads N integers in [L, R] into `vec` through validator::read_vector.
#define read_vec(vec, N, L, R) val.read_vector(vec, N, L, R, __LINE__, #vec)
// Read a string whose length is in [L, R] and whose characters belong to `chset`.
#define read_str_sp(x, L, R, chset) val.read_string(x, L, R, chset, ' ', __LINE__, #x)
#define read_str_ln(x, L, R, chset) val.read_string(x, L, R, chset, '\n', __LINE__, #x)
// Longest digit string read_int accepts; a 19-digit value is accepted only when its
// leading digit is 1, which keeps the accumulated value inside int64_t.
constexpr int max_digits = 19;
// Selects which validator constructor is available.
enum test_type { single_test, multi_tests };
// Character classes understood by validator::read_string.
enum char_set { alpha, binary, digit, gridwalls };

// Strict reader for judge input on stdin.
//
// Every read checks the exact token format (no leading zeroes, no "-0", exact delimiter,
// value/length within the given range). On any violation a diagnostic naming the symbol,
// the source line of the call, the current test and the input position is written to
// std::cerr and the process is aborted. Unless the macro W is defined, the destructor
// additionally aborts if stdin has not been read to EOF.
//
// validator<single_test>  : default-constructed, exactly one test.
// validator<multi_tests>  : constructed with (lb, ub); reads the test count, which must
//                           lie in [lb, ub], from the first input line.
template <test_type T = single_test> class validator {
    int tests {0}, current_test {0}, input_line_no {1}, input_col_no {0};

    // Prints `msg` followed by the standard location report, then aborts.
    [[noreturn]] void fail(const char *msg, int line, const char *label) const {
        std::cerr << msg << "while reading\n"
                  << "> symbol \"" << label << "\" (line " << line << ")\n"
                  << "> in test " << current_test << '\n'
                  << "> at pos " << input_line_no << ':' << input_col_no << '\n';
        std::abort();
    }

    // Reports a character that is not allowed at the current position, then aborts.
    // `c` is the raw getchar() result, so EOF is reported as such rather than as a character.
    [[noreturn]] void fail_invalid_char(int c, int line, const char *label) const {
        if (c == EOF)
            fail("Unexpected EOF\n", line, label);
        std::cerr << "error: found invalid character \'" << static_cast<char>(c) << "\'\n";
        fail("", line, label);
    }

    // True when `c` belongs to `chset`; an unrecognised set accepts everything.
    static bool in_char_set(char_set chset, int c) {
        switch (chset) {
            case binary:    return c == '0' or c == '1';
            case alpha:     return 'a' <= c and c <= 'z';
            case digit:     return '0' <= c and c <= '9';
            case gridwalls: return c == '.' or c == '#';
        }
        return true;
    }

public:
    template <test_type U = T,
        std::enable_if_t<
            std::is_same<validator<U>, validator<single_test>>::value> * = nullptr>
    validator() : tests {1} {}

    template <test_type U = T,
        std::enable_if_t<
            std::is_same<validator<U>, validator<multi_tests>>::value> * = nullptr>
    validator(int tests_lb, int tests_ub) {
        read_int(tests, tests_lb, tests_ub, '\n', -1, "tests");
    }

    // Reads one integer terminated by `delim` into `x` and checks L <= value <= R.
    // `line` and `label` are only used in the error report.
    template <typename U = int, typename = std::enable_if_t<std::is_integral<U>::value>>
    void read_int(
        U &x, std::int64_t L, std::int64_t R, char delim, int line = -1, const char *label = "") {
        std::int64_t res = 0;
        int len = 0, leading = -1;
        bool is_negative = false;
        while (true) {
            // Keep the int returned by getchar so EOF stays distinguishable from byte 0xFF.
            const int c = std::getchar();
            ++input_col_no;
            if (c == '-') {
                // A sign is only valid once, before the first digit.
                if (len > 0 or is_negative)
                    fail("error: found invalid symbol \'-\'\n", line, label);
                is_negative = true;
            } else if ('0' <= c and c <= '9') {
                const int value = c - '0';
                if (++len == 1)
                    leading = value;
                if (leading == 0 and len > 1)
                    fail("error: found leading zeroes\n", line, label);
                if (leading == 0 and is_negative)
                    fail("error: found negative zero\n", line, label);
                // Checked before accumulating so the multiplication below cannot overflow.
                if (len > max_digits or (len == max_digits and leading > 1))
                    fail("error: value will overflow in64_t\n", line, label);
                res = res * 10 + value;
            } else if (c == static_cast<unsigned char>(delim)) {
                // An empty token or a lone '-' is not an integer.
                if (len == 0)
                    fail("error: expected an integer but found no digits\n", line, label);
                if (is_negative)
                    res = -res;
                if (res < L or R < res) {
                    std::cerr << "error: found value " << res
                              << " expected to be in range [" << L << ", " << R << "]\n";
                    fail("", line, label);
                }
                x = static_cast<U>(res);
                if (delim == '\n') ++input_line_no, input_col_no = 0;
                return;
            } else {
                fail_invalid_char(c, line, label);
            }
        }
    }

    // Reads characters up to `delim` into `s`, requiring every character to belong to
    // `chset` and the length to lie in [L, R]. Returns the length of the string read.
    int read_string(std::string &s,
        int L, int R, char_set chset, char delim, int line = -1, const char *label = "") {
        std::string res;
        int c = 0;
        // Reads at most R + 1 characters: one past the limit is enough to detect overlength.
        while (static_cast<long long>(res.size()) <= R) {
            c = std::getchar();
            ++input_col_no;
            if (c == EOF or c == static_cast<unsigned char>(delim))
                break;
            if (not in_char_set(chset, c))
                fail_invalid_char(c, line, label);
            res += static_cast<char>(c);
        }
        if (c == EOF)
            fail("Unexpected EOF\n", line, label);
        const long long length = static_cast<long long>(res.length());
        if (length < L or R < length) {
            std::cerr << "error: found string of length " << res.length()
                      << " expected to be in range [" << L << ", " << R << "]\n";
            fail("", line, label);
        }
        if (delim == '\n') ++input_line_no, input_col_no = 0;
        s = res;
        return static_cast<int>(length);
    }

    // Reads N space-separated integers in [L, R], the last one terminated by '\n'.
    // For N <= 0 the vector is emptied and no input is consumed.
    template <typename U = int, typename = std::enable_if_t<std::is_integral<U>::value>>
    void read_vector(
        std::vector<U> &vec, int N, int L, int R, int line = -1, const char *label = "") {
        if (N <= 0) {
            vec.clear();
            return;
        }
        vec.resize(N);
        for (int i = 0; i < N - 1; i++)
            read_int(vec[i], L, R, ' ', line, label);
        read_int(vec[N - 1], L, R, '\n', line, label);
    }

    // Advances to the next test; returns false once all announced tests have been handed out.
    bool do_test() { return ++current_test <= tests; }

    ~validator() {
#ifndef W
        if (std::getchar() != EOF) {
            std::cerr << "error: expected EOF\n";
            std::abort();
        }
#endif
    }
};
// -------------------- Input Validator End --------------------
// Integer modulo the compile-time constant MOD (MOD must be positive).
// The stored value `v` is always the canonical residue in [0, MOD).
// Division and inv() use the extended Euclidean algorithm, so they are only meaningful
// for values coprime to MOD.
template<int MOD>
struct Modint {
    using T = typename std::decay<decltype(MOD)>::type; T v;

    Modint(): v(0) {}

    // Converts any integral value to its residue. The reduction is done in a 64-bit type
    // of matching signedness, so narrow argument types (short, char, ...) are not
    // truncated while the value is being normalised.
    template<typename U, typename = std::enable_if_t<std::is_integral<U>::value>>
    Modint(U value) {
        using Wide = typename std::conditional<std::is_signed<U>::value,
                                               long long, unsigned long long>::type;
        Wide reduced = static_cast<Wide>(value) % static_cast<Wide>(MOD);
        if(reduced < 0) reduced += static_cast<Wide>(MOD);
        v = static_cast<T>(reduced);
    }

    template<typename U, typename = std::enable_if_t<std::is_integral<U>::value>>
    explicit operator U() const { return static_cast<U>(v); }

    friend std::istream& operator>>(std::istream& in, Modint& m) {
        std::int64_t v_ = 0; in >> v_; m = Modint(v_); return in;
    }
    friend std::ostream& operator<<(std::ostream& os, const Modint& m) { return os << m.v; }

    // Inverse of `a` modulo `m` by the extended Euclidean algorithm, returned in [0, m).
    static T inv(T a, T m) {
        T g = m, x = 0, y = 1;
        while(a != 0) {
            T q = g / a;
            g %= a; std::swap(g, a);
            x -= q * y; std::swap(x, y);
        }
        return x < 0? x + m : x;
    }

    // x mod m. The hand-written divl path is compiled only for 32-bit Windows, the one
    // configuration that ever executed it; there x must be less than 2^32 * m.
    static unsigned fast_mod(std::uint64_t x, unsigned m = MOD) {
#if !defined(_WIN32) || defined(_WIN64)
        return unsigned(x % m);
#else
        unsigned x_high = unsigned(x >> 32), x_low = unsigned(x), quot, rem;
        asm("divl %4\n" : "=a" (quot), "=d" (rem) : "d" (x_high), "a" (x_low), "r" (m));
        return rem;
#endif
    }

    Modint inv() const { return Modint(inv(v, MOD)); }
    Modint operator-() const { return Modint(v? MOD-v : 0); }
    Modint& operator++() { v++; if(v == MOD) v = 0; return *this; }
    Modint& operator--() { if(v == 0) v = MOD; v--; return *this; }
    Modint operator++(int) { Modint a = *this; ++*this; return a; }
    Modint operator--(int) { Modint a = *this; --*this; return a; }
    // Written as a subtraction of (MOD - o.v) so the intermediate never exceeds MOD,
    // which keeps the addition overflow-free even for MOD close to INT_MAX.
    Modint& operator+=(const Modint& o) { v -= MOD - o.v; if (v < 0) v += MOD; return *this; }
    Modint& operator-=(const Modint& o) { v -= o.v; if (v < 0) v += MOD; return *this; }
    Modint& operator*=(const Modint& o) { v = fast_mod(std::uint64_t(v) * o.v); return *this; }
    Modint& operator/=(const Modint& o) { return *this *= o.inv(); }

    friend Modint operator+(const Modint& a, const Modint& b) { Modint res = a; res += b; return res; }
    friend Modint operator-(const Modint& a, const Modint& b) { Modint res = a; res -= b; return res; }
    friend Modint operator*(const Modint& a, const Modint& b) { Modint res = a; res *= b; return res; }
    friend Modint operator/(const Modint& a, const Modint& b) { Modint res = a; res /= b; return res; }

    friend bool operator==(const Modint& a, const Modint& b) { return a.v == b.v; }
    friend bool operator!=(const Modint& a, const Modint& b) { return a.v != b.v; }
    friend bool operator<(const Modint& a, const Modint& b) { return a.v < b.v; }
    friend bool operator>(const Modint& a, const Modint& b) { return a.v > b.v; }
    friend bool operator<=(const Modint& a, const Modint& b) { return a.v <= b.v; }
    friend bool operator>=(const Modint& a, const Modint& b) { return a.v >= b.v; }

    // Exponentiation by squaring; a negative exponent raises the inverse instead.
    Modint operator^(std::int64_t p) const {
        if(p < 0) return inv() ^ -p;
        Modint a = *this, res{1};
        while(p > 0) {
            if(p & 1) res *= a;
            p >>= 1; if(p > 0) a *= a;
        }
        return res;
    }
};
int main() {
#if __cplusplus > 201703L
namespace R = ranges;
#endif
ios_base::sync_with_stdio(false), cin.tie(nullptr);
constexpr int MOD = 1'000'000'000 + 7;
using mint = Modint<MOD>;
int64_t sum_n = 0, sum_n2 = 0;
const int N = 1e6 + 5;
static vector<mint> fact{1, 1}, factinv{1, 1}, inv{0, 1};
[](int N) {
fact.reserve(N); factinv.reserve(N); inv.reserve(N);
for(int z = fact.size(); z < N; z++) {
inv.push_back(inv[MOD % z] * (MOD - MOD / z));
fact.push_back(z * fact[z-1]);
factinv.push_back(inv[z] * factinv[z-1]);
}
}(N);
auto nCr = [](int n, int r) {
return r < 0 or r > n? 0 : fact[n] * factinv[r] * factinv[n-r];
};
validator<multi_tests> val(1, 1000);
while (val.do_test()) {
int n, m, i, j, k;
read_int_ln(n, 1, 1e6);
vector<int> a;
read_vec(a, n, 0, 1e9);
sum_n += n;
sum_n2 += 1LL * n * n;
mint ans = 0;
sort(a.begin(), a.end());
for(i = 0, j = 0; i < n; i = j) {
while(j < n and a[j] == a[i]) j++;
ans += nCr(n - 1, i); // end on the left
ans += nCr(n - 1, n - j); // end on the right
ans -= nCr(n - j + i, i); // counted twice
}
cout << ans << '\n';
}
assert(sum_n <= 1e6);
cerr << "Sum N: " << sum_n << '\n';
cerr << "Sum N^2: " << sum_n2 << '\n';
} // ~W
Editorialist's code (Python)
# Factorials and inverse factorials modulo `mod` for every index in [0, maxn).
mod = 10**9 + 7
maxn = 2*10**6 + 10

# fac[i] == i! % mod. The list is allocated once instead of being grown by append.
fac = [1] * maxn
for i in range(1, maxn):
    fac[i] = fac[i-1] * i % mod

# ifac[i] == (i!)^(-1) % mod. Only the last entry needs a modular exponentiation
# (Fermat: x^(mod-2)); the others follow from 1/i! == (i+1) * 1/(i+1)!.
ifac = [1] * maxn
ifac[maxn-1] = pow(fac[maxn-1], mod-2, mod)
for i in range(maxn-2, 0, -1):
    ifac[i] = (i+1) * ifac[i+1] % mod
def C(n, r):
    """Binomial coefficient n-choose-r modulo `mod`; 0 when r lies outside [0, n]."""
    if 0 <= r <= n:
        return fac[n] * ifac[r] * ifac[n-r] % mod
    return 0
# Driver: for every test case read the array, sort it and print the number of good
# rearrangements modulo `mod`.
for _ in range(int(input())):
    n = int(input())
    a = sorted(map(int, input().split()))
    ans = 0
    i = 0
    while i < n:
        # [i, j) is the maximal run of values equal to a[i].
        j = i
        while j < n and a[j] == a[i]:
            j += 1
        for k in range(j - i):
            # n - k - 2 positions remain in both cases below.
            remaining = n - k - 2
            if i > 0:
                # next element is smaller than a[i]:
                # choose i-1 of the remaining positions for the minimums
                ans += C(remaining, i - 1)
            if j < n:
                # next element is larger than a[i]:
                # choose n - j - 1 of the remaining positions for the maximums
                ans += C(remaining, n - j - 1)
        # All elements equal: exactly one arrangement.
        if i == 0 and j == n:
            ans = 1
        i = j
    print(ans % mod)