GARRANGE - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: still_me
Testers: the_hyp0cr1t3, rivalq
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Observation, computing binomial coefficients

PROBLEM:

An array B is called good if B_i is either the maximum or the minimum of B[1\ldots i] for each i.

Given an array A, count the number of its rearrangements that are good.
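
For small inputs, the definition can be checked directly. The sketch below is a brute-force reference for testing, not the intended solution; the names `is_good` and `count_good_bruteforce` are illustrative and do not come from the problem statement. It assumes a non-empty array.

```python
from itertools import permutations


def is_good(b):
    """Return True if every b[i] is the maximum or the minimum of b[0..i]."""
    lo = hi = b[0]
    for v in b:
        if v <= lo:
            # v is a prefix minimum
            lo = v
        elif v >= hi:
            # v is a prefix maximum
            hi = v
        else:
            # strictly between the current minimum and maximum
            return False
    return True


def count_good_bruteforce(a):
    """Count distinct good rearrangements of a by trying every permutation."""
    return sum(is_good(p) for p in set(permutations(a)))
```

This runs in factorial time, so it is only useful for arrays of up to about 8 elements, for example to cross-check a faster solution.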

EXPLANATION:

First, let’s sort A since it won’t change the answer. From now on, A_i \leq A_{i+1}.

Let B be a good rearrangement of A. Let’s analyze the structure of B.

Let S_L = [B_{i_1}, B_{i_2}, \ldots, B_{i_r}] be the subsequence of elements that are prefix minimums, and S_H = [B_{j_1}, B_{j_2}, \ldots, B_{j_s}] be the subsequence of elements that are prefix maximums. If an element can be both (for example, B_1) then we treat it as a minimum.
S_L and S_H partition B, so r+s = N.

For example, if B = [2, 3, 3, 1, 4, 1] then S_L = [2, 1, 1] and S_H = [3, 3, 4].
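
The split into S_L and S_H can be computed in one pass. This is a minimal sketch assuming B is non-empty; `split_min_max` is a hypothetical helper name. Testing the minimum condition first implements the tie rule (an element that is both goes to S_L).

```python
def split_min_max(b):
    """Split a good array b into (S_L, S_H); raise ValueError if b is not good."""
    lo = hi = b[0]
    s_l, s_h = [], []
    for v in b:
        if v <= lo:
            # prefix minimum; also covers elements that are both min and max
            lo = v
            s_l.append(v)
        elif v >= hi:
            # prefix maximum only
            hi = v
            s_h.append(v)
        else:
            raise ValueError("b is not good")
    return s_l, s_h


print(split_min_max([2, 3, 3, 1, 4, 1]))  # ([2, 1, 1], [3, 3, 4])
```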

Note that the following conditions must hold:

  • B_{i_1} \geq B_{i_2} \geq \ldots \geq B_{i_r} (the minimums decrease)
  • B_{j_1} \leq B_{j_2} \leq \ldots \leq B_{j_s} (the maximums increase)
  • B_{j_1} \geq B_{i_1}

Chaining these conditions gives a single non-decreasing order over every element of B: the minimums read backwards, followed by the maximums read forwards.
In particular, this means that S_L contains the smallest r elements of A, and S_H contains the largest s elements of A.
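
Written as one chain (assuming s \geq 1; when s = 0 only the minimums appear), the three conditions read:

```latex
B_{i_r} \leq B_{i_{r-1}} \leq \ldots \leq B_{i_1} \leq B_{j_1} \leq B_{j_2} \leq \ldots \leq B_{j_s}
```

For B = [2, 3, 3, 1, 4, 1] this is 1 \leq 1 \leq 2 \leq 3 \leq 3 \leq 4, which is sorted A with r = 3.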

So, let’s try to fix the value of r and see how many rearrangements exist such that S_L contains the smallest r elements.

The first element B_1 is the first, and hence largest, element of S_L, so it must be A_r. After that, there are two cases:

  • If A_{r+1} \gt A_r (or r = N), we can choose any r-1 positions out of the remaining N-1 and place the elements A_1, A_2, \ldots, A_{r-1} in these positions (in descending order) and the other elements in the remaining positions (in ascending order). So, there are \binom{N-1}{r-1} ways for this to happen.
  • If A_{r+1} = A_r, we need to be a bit more careful. We cannot arbitrarily choose r-1 positions to place the first r-1 elements, because we need to ensure that A_{r+1} is strictly the maximum when we place it. This means we need to place at least one element less than A_{r+1} before placing it.
    So, let x \lt r be the largest index such that A_x \lt A_r. We first place the elements A_r, A_{r-1}, \ldots, A_x in the first r-x+1 positions, then we can choose x-1 positions from the remaining N-(r-x+1) to place the other minimums A_1, \ldots, A_{x-1}, giving us \binom{N-r+x-1}{x-1} possible ways. If no such x exists (every element up to A_r equals A_r), no good rearrangement has this r, so it contributes 0.
    x can be maintained as you iterate through the array: update it whenever you move from one value to the next.

Summing these counts over all r from 1 to N gives the final answer.
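
The case analysis above can be turned into a short loop. This is a minimal sketch that follows the explanation literally rather than any of the reference solutions below. It uses Python's exact `math.comb` for brevity and reduces modulo 10^9+7 (the modulus used in the solutions below) only at the end, so it is meant for checking small cases and not for the full constraints. `count_good` is a hypothetical name.

```python
from math import comb

MOD = 10**9 + 7


def count_good(a):
    """Count good rearrangements of a by summing over r = |S_L|."""
    a = sorted(a)
    n = len(a)
    total = 0
    x = 0  # 1-indexed largest index with A_x < A_r, or 0 if there is none
    for r in range(1, n + 1):
        # a[r - 1] is A_r in the editorial's 1-indexed notation
        if r >= 2 and a[r - 1] != a[r - 2]:
            x = r - 1  # moved to a new value: all earlier elements are smaller
        if r == n or a[r] > a[r - 1]:
            # A_{r+1} > A_r (or r = N): any r-1 of the other N-1 positions
            total += comb(n - 1, r - 1)
        elif x > 0:
            # A_{r+1} = A_r: A_r..A_x are forced into the first r-x+1 positions
            total += comb(n - r + x - 1, x - 1)
        # if x == 0 here, this r contributes nothing
    return total % MOD


print(count_good([1, 1, 2, 2]))  # 6
```

On [1, 1, 2, 2] the terms are 0, \binom{3}{1} = 3, \binom{2}{1} = 2 and \binom{3}{3} = 1 for r = 1, 2, 3, 4, which matches the fact that all 6 distinct rearrangements of that array are good.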

Each binomial coefficient needs to be computed in \mathcal{O}(\log{MOD}) or \mathcal{O}(1). If you don’t know how to do this, read through this article.
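
One common way to get \mathcal{O}(1) binomial coefficients is to precompute factorials and inverse factorials modulo the prime 10^9+7, using Fermat's little theorem for a single modular inverse. A minimal sketch, with the helper names `build_tables` and `binom` chosen here for illustration:

```python
MOD = 10**9 + 7  # prime modulus, as used in the solutions below


def build_tables(limit):
    """Return (fact, inv_fact) with factorials and their inverses up to limit."""
    fact = [1] * (limit + 1)
    for i in range(1, limit + 1):
        fact[i] = fact[i - 1] * i % MOD
    inv_fact = [1] * (limit + 1)
    # Fermat's little theorem: a^(MOD-2) is the inverse of a modulo a prime MOD
    inv_fact[limit] = pow(fact[limit], MOD - 2, MOD)
    for i in range(limit, 0, -1):
        # 1/(i-1)! = i * 1/i!
        inv_fact[i - 1] = inv_fact[i] * i % MOD
    return fact, inv_fact


def binom(n, r, fact, inv_fact):
    """Return C(n, r) modulo MOD, or 0 when r is out of range."""
    if r < 0 or r > n:
        return 0
    return fact[n] * inv_fact[r] % MOD * inv_fact[n - r] % MOD


fact, inv_fact = build_tables(20)
print(binom(5, 2, fact, inv_fact))  # 10
```

The tables cost \mathcal{O}(limit + \log{MOD}) to build once, after which every query is three multiplications.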

TIME COMPLEXITY:

\mathcal{O}(N\log N) per testcase.

CODE:

Setter's code (C++)
//	Code by Sahil Tiwari (still_me)

#include<bits/stdc++.h>
#define still_me main
#define endl "\n"
#define int long long int
#define all(a) (a).begin() , (a).end()
#define print(a) for(auto TEMPORARY: a) cout<<TEMPORARY<<" ";cout<<endl;
#define tt int TESTCASE;cin>>TESTCASE;while(TESTCASE--)
#define arrin(a,n) for(int INPUT=0;INPUT<n;INPUT++)cin>>a[INPUT]

using namespace std;
const int mod = 1e9+7;
const int inf = 1e18;

const int N = 1e6;
int fact[N+1];
void factorial(){
    fact[0] = fact[1] = 1;
    for(int i=2;i<=N;i++){
        fact[i] = (fact[i-1] * i) % mod;
    }
}

long long power(long long a , long long b , long long mod){
    if(b==0)
        return 1;
    long long res = power(a , b/2 , mod);
    res = res*res%mod;
    if(b%2)
        res = res*a % mod;
    return res;
}

int inverse(int a){
    return power(a , mod-2 , mod);
}

int nCr(int n , int r){
    if(r>n)
        return 0;
    if(r < 0)
        return 0;
    return fact[n] * (inverse(fact[r]) * inverse(fact[n-r]) % mod) % mod;
}


void chal_bsdk() {
    int n;
    cin>>n;
    vector<int> a(n);
    arrin(a , n);
    sort(all(a));
    
    int l = 0;
    int ans = power(2 , n-1 , mod);
    while(l < n) {
        int r = upper_bound(all(a) , a[l]) - a.begin() - 1;
        // cout<<r<<endl;
        if(r == l) {
            l++;
            continue;
        }
        {
            int k = l + n - r - 1;
            int x = (nCr(k , l) * (power(2 , r-l , mod) - 1 + mod)) % mod;
            ans = (ans - x + mod) % mod;
        }
        for(int i=2;i<(r-l+1);i++) {
            int L = l;
            int R = n-r-1;
            int g = L+R + (r-l-i);

            // Left side
            int x = nCr(g , L-1) * (power(2 , i-1 , mod) - 1 + mod) % mod;
            ans = (ans - x + mod) % mod;

            // Right side
            x = nCr(g , R-1) * (power(2 , i-1 , mod) - 1 + mod) % mod;
            ans = (ans - x + mod) % mod;
        }
        l = r+1;
    }
    cout<<ans<<endl;
}

signed still_me()
{
    ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL);

    // freopen("4.in" , "r" , stdin);
    // freopen("4.out" , "w" , stdout);
    factorial();
    tt{
        chal_bsdk();
    }
    return 0;
}

Tester's code (C++)
/**
 * the_hyp0cr1t3
 * 02.01.2023 23:38:00
**/
#ifdef W
    #include <k_II.h>
#else
    #include <bits/stdc++.h>
    using namespace std;
#endif

// -------------------- Input Validator Start --------------------

#define read_int_sp(x, L, R) val.read_int(x, L, R, ' ', __LINE__, #x)
#define read_int_ln(x, L, R) val.read_int(x, L, R, '\n', __LINE__, #x)

#define read_vec(vec, N, L, R) val.read_vector(vec, N, L, R, __LINE__, #vec)

#define read_str_sp(x, L, R, chset) val.read_string(x, L, R, chset, ' ', __LINE__, #x)
#define read_str_ln(x, L, R, chset) val.read_string(x, L, R, chset, '\n', __LINE__, #x)

constexpr int max_digits = 19;

enum test_type { single_test, multi_tests };
enum char_set { alpha, binary, digit, gridwalls };

template <test_type T = single_test> class validator {
    int tests, current_test {0}, input_line_no {1}, input_col_no {0};

   public:
    template <test_type U = T,
      std::enable_if_t<
        std::is_same<validator<U>, validator<single_test>>::value> * = nullptr>
    validator() : tests {1} {}

    template <test_type U = T,
      std::enable_if_t<
        std::is_same<validator<U>, validator<multi_tests>>::value> * = nullptr>
    validator(int tests_lb, int tests_ub) {
        read_int(tests, tests_lb, tests_ub, '\n', -1, "tests");
    }

#define FAIL(cond, msg)                                                             \
    if (cond) {                                                                     \
        std::cerr << msg "while reading\n"                                          \
                  << "> symbol \"" << label << "\" (line " << line << ")\n"         \
                  << "> in test " << current_test << '\n'                           \
                  << "> at pos " << input_line_no << ':' << input_col_no << '\n';   \
        abort();                                                                    \
    }

    template <typename U = int, typename = std::enable_if_t<std::is_integral<U>::value>>
    void read_int(
      U &x, int64_t L, int64_t R, char delim, int line = -1, const char *label = "") {
        int64_t res = 0;
        int len = 0, leading = -1;
        bool is_negative = false;
        while (true) {
            char c = std::getchar();
            ++input_col_no;

            if (c == '-') {
                FAIL(len > 0, "error: found invalid symbol \'-\'\n")
                is_negative = true;

            } else if ('0' <= c and c <= '9') {
                res = res * 10 + c - '0';
                if (++len == 1)
                    leading = c - '0';

                FAIL(leading == 0 and len > 1, "error: found leading zeroes\n")
                FAIL(leading == 0 and is_negative, "error: found negative zero\n")
                FAIL(len > max_digits or len == max_digits and leading > 1,
                  "error: value will overflow int64_t\n")
            } else if (c == delim) {
                if (is_negative)
                    res *= -1;

                if (res < L or R < res) {
                    std::cerr << "error: found value " << res
                              << " expected to be in range [" << L << ", " << R << "]\n";
                    FAIL(true, "")
                }

                x = res;

                if (delim == '\n') ++input_line_no, input_col_no = 0;

                return;

            } else {
                const char *cc = c == '\n' ? "\n" : &c;
                std::cerr << "error: found invalid character \'" << cc << "\'\n";
                FAIL(true, "")
            }
        }
    }

    int read_string(std::string &s,
      int L, int R, char_set chset, char delim, int line = -1, const char *label = "") {
        std::string res;
        char c;
        while (res.size() <= R) {
            c = std::getchar();
            if (c == EOF or c == delim)
                break;
            res += c;
            if (chset == binary) {
                if (c != '0' and c != '1') {
                    const char *cc = c == '\n' ? "\n" : &c;
                    std::cerr << "error: found invalid character \'" << cc << "\'\n";
                    FAIL(true, "")
                }
            } else if (chset == alpha) {
                if (c < 'a' or 'z' < c) {
                    const char *cc = c == '\n' ? "\n" : &c;
                    std::cerr << "error: found invalid character \'" << cc << "\'\n";
                    FAIL(true, "")
                }
            } else if (chset == digit) {
                if (c < '0' or '9' < c) {
                    const char *cc = c == '\n' ? "\n" : &c;
                    std::cerr << "error: found invalid character \'" << cc << "\'\n";
                    FAIL(true, "")
                }
            } else if (chset == gridwalls) {
                if (c != '.' and c != '#') {
                    const char *cc = c == '\n' ? "\n" : &c;
                    std::cerr << "error: found invalid character \'" << cc << "\'\n";
                    FAIL(true, "")
                }
            }
        }

        FAIL(c == EOF, "Unexpected EOF\n")

        if (res.length() < L or R < res.length()) {
            std::cerr << "error: found string of length " << res.length()
                      << " expected to be in range [" << L << ", " << R << "]\n";
            FAIL(true, "")
        }

        s = res;
        return res.length();
    }

    template <typename U = int, typename = std::enable_if_t<std::is_integral<U>::value>>
    void read_vector(
      std::vector<U> &vec, int N, int L, int R, int line = -1, const char *label = "") {
        vec.resize(N);
        for (int i = 0; i < N - 1; i++)
            read_int(vec[i], L, R, ' ', line, label);
        read_int(vec[N - 1], L, R, '\n', line, label);
    }

    bool do_test() { return ++current_test <= tests; }

    ~validator() {
#ifndef W
        if (std::getchar() != EOF) {
            std::cerr << "error: expected EOF\n";
            abort();
        }
#endif
    }
};

// -------------------- Input Validator End --------------------

template<int MOD>
struct Modint {
    using T = typename decay<decltype(MOD)>::type; T v;
    Modint(): v(0) {}
    template<typename U, typename = enable_if_t<is_integral<U>::value>>
    Modint(U v) { if(v < 0) v = v % MOD + MOD; if(v >= MOD) v %= MOD; this->v = static_cast<T>(v); }
    template<typename U, typename = enable_if_t<is_integral<U>::value>>
    explicit operator U() const { return static_cast<U>(v); }
    friend istream& operator>>(istream& in, Modint& m) { int64_t v_; in >> v_; m = Modint(v_); return in; }
    friend ostream& operator<<(ostream& os, const Modint& m) { return os << m.v; }

    static T inv(T a, T m) {
        T g = m, x = 0, y = 1;
        while(a != 0) {
            T q = g / a;
            g %= a; swap(g, a);
            x -= q * y; swap(x, y);
        } return x < 0? x + m : x;
    }

    static unsigned fast_mod(uint64_t x, unsigned m = MOD) {
#if !defined(_WIN32) || defined(_WIN64)
        return unsigned(x % m);
#endif // x must be less than 2^32 * m
        unsigned x_high = unsigned(x >> 32), x_low = unsigned(x), quot, rem;
        asm("divl %4\n" : "=a" (quot), "=d" (rem) : "d" (x_high), "a" (x_low), "r" (m));
        return rem;
    }

    Modint inv() const { return Modint(inv(v, MOD)); }
    Modint operator-() const { return Modint(v? MOD-v : 0); }
    Modint& operator++() { v++; if(v == MOD) v = 0; return *this; }
    Modint& operator--() { if(v == 0) v = MOD; v--; return *this; }
    Modint operator++(int) { Modint a = *this; ++*this; return a; }
    Modint operator--(int) { Modint a = *this; --*this; return a; }
    Modint& operator+=(const Modint& o) { v += o.v; if (v >= MOD) v -= MOD; return *this; }
    Modint& operator-=(const Modint& o) { v -= o.v; if (v < 0) v += MOD; return *this; }
    Modint& operator*=(const Modint& o) { v = fast_mod(uint64_t(v) * o.v); return *this; }
    Modint& operator/=(const Modint& o) { return *this *= o.inv(); }
    friend Modint operator+(const Modint& a, const Modint& b) { Modint res = a; res += b; return res; }
    friend Modint operator-(const Modint& a, const Modint& b) { Modint res = a; res -= b; return res; }
    friend Modint operator*(const Modint& a, const Modint& b) { Modint res = a; res *= b; return res; }
    friend Modint operator/(const Modint& a, const Modint& b) { Modint res = a; res /= b; return res; }
    friend bool operator==(const Modint& a, const Modint& b) { return a.v == b.v; }
    friend bool operator!=(const Modint& a, const Modint& b) { return a.v != b.v; }
    friend bool operator<(const Modint& a, const Modint& b) { return a.v < b.v; }
    friend bool operator>(const Modint& a, const Modint& b) { return a.v > b.v; }
    friend bool operator<=(const Modint& a, const Modint& b) { return a.v <= b.v; }
    friend bool operator>=(const Modint& a, const Modint& b) { return a.v >= b.v; }
    Modint operator^(int64_t p) {
        if(p < 0) return inv() ^ -p;
        Modint a = *this, res{1}; while(p > 0) {
            if(p & 1) res *= a;
            p >>= 1; if(p > 0) a *= a;
        } return res;
    }
};

int main() {
#if __cplusplus > 201703L
    namespace R = ranges;
#endif
    ios_base::sync_with_stdio(false), cin.tie(nullptr);
    constexpr int MOD = 1'000'000'000 + 7;
    using mint = Modint<MOD>;
    int64_t sum_n = 0, sum_n2 = 0;

    const int N = 1e6 + 5;
    static vector<mint> fact{1, 1}, factinv{1, 1}, inv{0, 1};
    [](int N) {
        fact.reserve(N); factinv.reserve(N); inv.reserve(N);
        for(int z = fact.size(); z < N; z++) {
            inv.push_back(inv[MOD % z] * (MOD - MOD / z));
            fact.push_back(z * fact[z-1]);
            factinv.push_back(inv[z] * factinv[z-1]);
        }
    }(N);

    auto nCr = [](int n, int r) {
        return r < 0 or r > n? 0 : fact[n] * factinv[r] * factinv[n-r];
    };

    validator<multi_tests> val(1, 1000);

    while (val.do_test()) {
        int n, m, i, j, k;
        read_int_ln(n, 1, 1e6);

        vector<int> a;
        read_vec(a, n, 0, 1e9);

        sum_n += n;
        sum_n2 += 1LL * n * n;

        mint ans = 0;
        sort(a.begin(), a.end());
        for(i = 0, j = 0; i < n; i = j) {
            while(j < n and a[j] == a[i]) j++;
            ans += nCr(n - 1, i);       // end on the left
            ans += nCr(n - 1, n - j);   // end on the right
            ans -= nCr(n - j + i, i);   // counted twice
        }

        cout << ans << '\n';
    }

    assert(sum_n <= 1e6);
    cerr << "Sum N: " << sum_n << '\n';
    cerr << "Sum N^2: " << sum_n2 << '\n';
} // ~W

Editorialist's code (Python)
mod = 10**9 + 7
maxn = 2*10**6 + 10
fac, ifac = [1], [1]*(maxn)
for i in range(1, maxn):
	fac.append(fac[-1] * i % mod)

ifac[-1] = pow(fac[-1], mod-2, mod)
for i in reversed(range(1, maxn-1)):
	ifac[i] = (i+1) * ifac[i+1] % mod

def C(n, r):
	if n < r or r < 0: return 0
	return fac[n] * ifac[r] * ifac[n-r] % mod

for _ in range(int(input())):
	n = int(input())
	a = sorted(list(map(int, input().split())))
	ans = i = 0
	while i < n:
		j = i
		while j < n and a[i] == a[j]: j += 1
		
		for k in range(j-i):
			if i > 0:
				# next element is smaller than a[i]
				# n - k - 2 positions remaining, choose i-1 of them for the minimums
				ans += C(n - k - 2, i-1)
			if j < n:
				# next element is larger than a[i]
				# n - k - 2 positions remaining, choose n - j - 1 of them for maximums
				ans += C(n - k - 2, n - j - 1)
		if i == 0 and j == n: ans = 1
		i = j
	print(ans % mod)

Hey, firstly thanks for this beautiful problem. I was not able to solve it yet, but I guess I am on the right track.

Looking at the tester’s code, it seems I am doing exactly the same thing. If anyone can help me find the error, it would be appreciated.

My submission.

Edit: Got it corrected now.

I need help understanding a TLE verdict. My approach counts in a different way, but it also uses binomial coefficients. At some point I needed to count the frequency of each element. I first did this with a map in O(N\log N), which gives TLE. I then changed it to sort the array and iterate over it to count, also in O(N \log N), and this gives AC by a wide margin.

The two submissions are identical except from line 56 onward.

Any idea why this happened?

TLE
AC