COUNTISFUN10 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: satyam_343
Tester: mridulahi
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Elementary combinatorics, working with disjoint segments (e.g. with a set or DSU)

PROBLEM:

For a permutation P of \{1, 2, \ldots, N\}, you can perform the following operation exactly once:

  • Partition P into some number of subarrays, say K subarrays.
    Let M_i denote the maximum element of the i-th subarray.
  • Delete the \left\lfloor \frac{K}{2} \right\rfloor subarrays with smallest M_i values.
    In particular, if K = 1 nothing gets deleted.

Find the number of distinct resulting arrays.
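
For small N, the statement can be checked directly by trying every partition. The sketch below is only a brute-force reference (exponential in N) and is not taken from the setters' solutions; the function name `count_results_bruteforce` is made up for illustration.

```python
def count_results_bruteforce(p):
    """Try all 2^(N-1) partitions of p and collect the distinct results."""
    n = len(p)  # assumes n >= 1
    results = set()
    for mask in range(1 << (n - 1)):
        # Bit i of mask set => cut between p[i] and p[i+1].
        parts = []
        start = 0
        for i in range(n - 1):
            if (mask >> i) & 1:
                parts.append((start, i + 1))
                start = i + 1
        parts.append((start, n))
        k = len(parts)
        # Sort subarray indices by their maximum (maxima are distinct).
        order = sorted(range(k), key=lambda j: max(p[parts[j][0]:parts[j][1]]))
        deleted = set(order[:k // 2])
        kept = tuple(
            value
            for j, (a, b) in enumerate(parts)
            if j not in deleted
            for value in p[a:b]
        )
        results.add(kept)
    return len(results)
```

For example, P = [1, 3, 2] gives the three arrays [1, 3, 2], [3, 2] and [1, 3].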

EXPLANATION:

Let’s count the number of subsets of elements that can be deleted via the above process. Since the elements of P are distinct, different deleted subsets give different final arrays, so this is equivalent to counting the number of final arrays.

Let M denote the largest deleted element, meaning we can only delete elements \leq M.
Further, since our process only allows us to delete subarrays, we can only delete some disjoint subarrays of P that all contain only elements \leq M.

Let the positions of elements \gt M in P be i_1 \lt i_2 \lt i_3 \lt\ldots\lt i_{N-M}.
Note that:

  • Any deleted subarray must lie fully between i_j+1 and i_{j+1}-1 for some j, or be a prefix/suffix of P.
  • Further, from a fixed range [i_j+1, i_{j+1}-1], at most one subarray can be chosen for deletion.

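So, we can pick up to N-M+1 subarrays to be deleted: one from each maximal segment between consecutive i_j's, plus one prefix and one suffix.
Observe that if we pick N-M+1 subarrays to be deleted, it isn’t actually possible to delete them all. Every non-deleted subarray must have a maximum \gt M, and each element \gt M is the maximum of at most one subarray, so there are at most N-M non-deleted subarrays. Since only \left\lfloor \frac{K}{2} \right\rfloor of the K subarrays get deleted, the deleted subarrays can never outnumber the non-deleted ones, so deleting N-M+1 isn’t possible.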

On the other hand, if we pick K\leq N-M subarrays, it’s always possible to find a partition that deletes exactly them: it’s not hard to see how (start with N-M separate subarrays, each containing one of the elements \gt M, after which some of them can be merged if necessary).

So, all we need to do is find the number of ways of picking \leq N-M non-empty subarrays, with at most one from each maximal segment between i_j's, and with M included in one of them.

To do that, we’ll compute all ways of choosing N-M+1 subarrays (some of which are possibly empty), and from them subtract the ways of choosing N-M+1 non-empty subarrays.

Counting all ways

Let L_j = i_j - i_{j-1} - 1 be the length of the j-th maximal segment for 2 \leq j \leq N-M, along with L_1 = i_1 - 1 and L_{N-M+1} = N - i_{N-M}. Then,

  • As noted earlier, from L_1 we can only choose some prefix.
    This gives us L_1 + 1 options (recall that we’re allowing empty choices here).
  • Similarly, we have L_{N-M+1} + 1 choices for the suffix.
  • Then, for every other L_j, we have \frac{L_j\cdot (L_j + 1)}{2} + 1 ways: all possible non-empty subarrays, and the empty subarray.

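There is one exception: one of these segments will contain M, and we need to ensure that the subarray chosen in this segment includes M.
If M is at position u inside the segment [l, r], the chosen subarray has u-l+1 possible left endpoints and r-u+1 possible right endpoints, giving (u-l+1)\cdot(r-u+1) choices. For the prefix segment the left endpoint is forced to be 1, and for the suffix segment the right endpoint is forced to be N, so the corresponding factor is dropped.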

The total number of choices is then just the product of all these numbers.

Counting non-empty ways

Exactly the same as above, just that you don’t include the +1 for choice of empty subarray anymore.
That is,

  • L_1 choices for the prefix.
  • L_{N-M+1} choices for the suffix.
  • \frac{L_j\cdot (L_j + 1)}{2} choices for everything else.

The computation for subarrays including M doesn’t change, since that’s non-empty anyway.

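Once these numbers are known, say X_M and Y_M respectively, simply add (X_M - Y_M) to the answer.
Finally, add 1 for the case where nothing is deleted (K = 1), which isn’t covered by any value of M.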


The above discussion gives us a solution in \mathcal{O}(N^2): fix a value of M, then find X_M and Y_M by computing the lengths of empty segments and multiplying appropriate numbers.
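
A minimal sketch of this quadratic approach, assuming 0-indexed positions and the modulus 998244353 used in the solutions below. The name `count_results_quadratic` and the sentinel separators at -1 and n are choices made here for illustration, not part of the setters' code.

```python
MOD = 998244353

def count_results_quadratic(p):
    """Sum of (X_M - Y_M) over all M, plus 1 for the empty deletion."""
    n = len(p)
    pos = {value: i for i, value in enumerate(p)}  # 0-indexed positions
    answer = 1  # the case where nothing is deleted
    for m in range(1, n + 1):
        u = pos[m]
        # Separators: positions of elements > m, with sentinels at -1 and n.
        seps = [-1] + [i for i in range(n) if p[i] > m] + [n]
        x = y = 1
        for a, b in zip(seps, seps[1:]):
            l, r = a + 1, b - 1  # segment [l, r]; empty when l > r
            length = r - l + 1
            if l <= u <= r:
                # Segment containing m: the chosen subarray must cover u.
                left = 1 if l == 0 else u - l + 1
                right = 1 if r == n - 1 else r - u + 1
                x = x * left * right % MOD
                y = y * left * right % MOD
            else:
                border = (l == 0 or r == n - 1)
                nonempty = length if border else length * (length + 1) // 2
                x = x * (nonempty + 1) % MOD  # empty choice allowed
                y = y * nonempty % MOD        # empty choice not allowed
        answer = (answer + x - y) % MOD
    return answer
```

For M = N the only segment is the whole array, so X_M = Y_M and the contribution is 0, as expected.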

To speed it up, consider what happens when M is processed in descending order.
When moving from M to M-1, the segments stay largely the same: the only change is that M itself becomes a separator, so the segment containing M breaks into two smaller segments.

Since the values we’re looking for are products depending only on the lengths of the segments, they don’t change much either: you can simply divide out the old value of this segment, then multiply in the values of the new segments.
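
The "divide out, multiply in" step can be sketched as follows. The helper name `replace_factor` is hypothetical, and the sketch assumes the old factor is not a multiple of the modulus, so that its inverse exists (computed here with Fermat's little theorem, which is where the \log{MOD} factor comes from).

```python
MOD = 998244353

def replace_factor(product, old, new_left, new_right):
    """Swap one factor `old` of `product` for two new factors, modulo MOD."""
    inv_old = pow(old, MOD - 2, MOD)  # modular inverse, needs old % MOD != 0
    return product * inv_old % MOD * new_left % MOD * new_right % MOD
```

For instance, a product of 70 that contains the factor 10 becomes 7 * 3 * 4 = 84 after `replace_factor(70, 10, 3, 4)`.
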
There are various ways of actually maintaining the segments. For instance:

  • Keep the segments themselves in a set, and find which one to break with binary search; or
  • Note that when a segment (say with maximum M) breaks, the new segments depend only on the position of M, and the next/previous greater elements of M in P, which can be precomputed.

Alternately, you can process M in increasing order, merging segments as you go - for instance with a DSU (or, again, a set).
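
A rough sketch of the increasing-order variant. Instead of a full DSU it stores, for each active segment, its other endpoint at both ends (`lo` and `hi`), which is enough because segments only ever merge at the newly activated position. Since zero-length segments contribute a factor of 0 to the non-empty product, that product is tracked as a non-zero part plus a count of zero factors. All names are made up here, and the sketch assumes no factor is a multiple of the modulus.

```python
MOD = 998244353

def count_results_increasing(p):
    n = len(p)
    pos = [0] * (n + 1)
    for i, value in enumerate(p):
        pos[value] = i
    active = [False] * n          # active[i]: p[i] <= current M
    lo = list(range(n))           # lo[r] = left end of the segment ending at r
    hi = list(range(n))           # hi[l] = right end of the segment starting at l

    def nonempty(l, r):
        """Number of non-empty choices in segment [l, r] (0 if empty)."""
        length = r - l + 1
        if length <= 0:
            return 0
        if l == 0 or r == n - 1:
            return length
        return length * (length + 1) // 2

    x = 1            # product of (nonempty + 1) over all segments
    y_nonzero = 1    # product of the non-zero `nonempty` factors
    zeros = n + 1    # segments whose `nonempty` factor is 0
    answer = 1       # the case where nothing is deleted
    for m in range(1, n + 1):
        u = pos[m]
        l = lo[u - 1] if u > 0 and active[u - 1] else u
        r = hi[u + 1] if u + 1 < n and active[u + 1] else u
        # Divide out the two segments adjacent to u (either may be empty).
        for a, b in ((l, u - 1), (u + 1, r)):
            f = nonempty(a, b)
            x = x * pow(f + 1, MOD - 2, MOD) % MOD
            if f == 0:
                zeros -= 1
            else:
                y_nonzero = y_nonzero * pow(f, MOD - 2, MOD) % MOD
        y = y_nonzero if zeros == 0 else 0
        mul = (1 if l == 0 else u - l + 1) * (1 if r == n - 1 else r - u + 1)
        answer = (answer + (x - y) * mul) % MOD
        # Merge into [l, r] and multiply the merged segment back in.
        active[u] = True
        lo[r], hi[l] = l, r
        f = nonempty(l, r)  # f >= 1 since the segment contains u
        x = x * (f + 1) % MOD
        y_nonzero = y_nonzero * f % MOD
    return answer
```

On small inputs this agrees with the hand-computed answers, e.g. 3 for P = [1, 3, 2] and 4 for P = [2, 1, 3].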

TIME COMPLEXITY:

\mathcal{O}(N\log{MOD}) per testcase.

CODE:

Author's code (C++)
#pragma GCC optimize("O3,unroll-loops")
#include <bits/stdc++.h>   
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;   
using namespace std;
#define ll long long
const ll INF_ADD=1e18;
#define pb push_back                  
#define mp make_pair          
#define nline "\n"                            
#define f first                                            
#define s second                                             
#define pll pair<ll,ll> 
#define all(x) x.begin(),x.end()     
#define vl vector<ll>           
#define vvl vector<vector<ll>>      
#define vvvl vector<vector<vector<ll>>>          
#ifndef ONLINE_JUDGE    
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);  
#endif     
void _print(ll x){cerr<<x;}  
void _print(char x){cerr<<x;}   
void _print(string x){cerr<<x;}    
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());   
template<class T,class V> void _print(pair<T,V> p) {cerr<<"{"; _print(p.first);cerr<<","; _print(p.second);cerr<<"}";}
template<class T>void _print(vector<T> v) {cerr<<" [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T>void _print(set<T> v) {cerr<<" [ "; for (T i:v){_print(i); cerr<<" ";}cerr<<"]";}
template<class T>void _print(multiset<T> v) {cerr<< " [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T,class V>void _print(map<T, V> v) {cerr<<" [ "; for(auto i:v) {_print(i);cerr<<" ";} cerr<<"]";} 
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
//--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------  
const ll MOD=998244353;
const ll MAX=500500; 
template<const int &MOD>
struct _m_int {
    int val;
 
    _m_int(int64_t v = 0) {
        if (v < 0) v = v % MOD + MOD;
        if (v >= MOD) v %= MOD;
        val = int(v);
    }
 
    _m_int(uint64_t v) {
        if (v >= MOD) v %= MOD;
        val = int(v);
    }
    _m_int(ll v) : _m_int(int64_t(v)) {}
    _m_int(int v) : _m_int(int64_t(v)) {}
    _m_int(unsigned v) : _m_int(uint64_t(v)) {}
 
    explicit operator int() const { return val; }
    explicit operator unsigned() const { return val; }
    explicit operator int64_t() const { return val; }
    explicit operator uint64_t() const { return val; }
    explicit operator double() const { return val; }
    explicit operator long double() const { return val; }
 
    _m_int& operator+=(const _m_int &other) {
        val -= MOD - other.val;
        if (val < 0) val += MOD;
        return *this;
    }
 
    _m_int& operator-=(const _m_int &other) {
        val -= other.val;
        if (val < 0) val += MOD;
        return *this;
    }
 
    static unsigned fast_mod(uint64_t x, unsigned m = MOD) {
#if !defined(_WIN32) || defined(_WIN64)
        return unsigned(x % m);
#endif
        // Optimized mod for Codeforces 32-bit machines.
        // x must be less than 2^32 * m for this to work, so that x / m fits in an unsigned 32-bit int.
        unsigned x_high = unsigned(x >> 32), x_low = unsigned(x);
        unsigned quot, rem;
        asm("divl %4\n"
            : "=a" (quot), "=d" (rem)
            : "d" (x_high), "a" (x_low), "r" (m));
        return rem;
    }
 
    _m_int& operator*=(const _m_int &other) {
        val = fast_mod(uint64_t(val) * other.val);
        return *this;
    }
 
    _m_int& operator/=(const _m_int &other) {
        return *this *= other.inv();
    }
 
    friend _m_int operator+(const _m_int &a, const _m_int &b) { return _m_int(a) += b; }
    friend _m_int operator-(const _m_int &a, const _m_int &b) { return _m_int(a) -= b; }
    friend _m_int operator*(const _m_int &a, const _m_int &b) { return _m_int(a) *= b; }
    friend _m_int operator/(const _m_int &a, const _m_int &b) { return _m_int(a) /= b; }
 
    _m_int& operator++() {
        val = val == MOD - 1 ? 0 : val + 1;
        return *this;
    }
 
    _m_int& operator--() {
        val = val == 0 ? MOD - 1 : val - 1;
        return *this;
    }
 
    _m_int operator++(int) { _m_int before = *this; ++*this; return before; }
    _m_int operator--(int) { _m_int before = *this; --*this; return before; }
 
    _m_int operator-() const {
        return val == 0 ? 0 : MOD - val;
    }
 
    friend bool operator==(const _m_int &a, const _m_int &b) { return a.val == b.val; }
    friend bool operator!=(const _m_int &a, const _m_int &b) { return a.val != b.val; }
    friend bool operator<(const _m_int &a, const _m_int &b) { return a.val < b.val; }
    friend bool operator>(const _m_int &a, const _m_int &b) { return a.val > b.val; }
    friend bool operator<=(const _m_int &a, const _m_int &b) { return a.val <= b.val; }
    friend bool operator>=(const _m_int &a, const _m_int &b) { return a.val >= b.val; }
 
    static const int SAVE_INV = int(1e6) + 5;
    static _m_int save_inv[SAVE_INV];
 
    static void prepare_inv() {
        // Ensures that MOD is prime, which is necessary for the inverse algorithm below.
        for (int64_t p = 2; p * p <= MOD; p += p % 2 + 1)
            assert(MOD % p != 0);
 
        save_inv[0] = 0;
        save_inv[1] = 1;
 
        for (int i = 2; i < SAVE_INV; i++)
            save_inv[i] = save_inv[MOD % i] * (MOD - MOD / i);
    }
 
    _m_int inv() const {
        if (save_inv[1] == 0)
            prepare_inv();
 
        if (val < SAVE_INV)
            return save_inv[val];
 
        _m_int product = 1;
        int v = val;
 
        do {
            product *= MOD - MOD / v;
            v = MOD % v;
        } while (v >= SAVE_INV);
 
        return product * save_inv[v];
    }
 
    _m_int pow(int64_t p) const {
        if (p < 0)
            return inv().pow(-p);
 
        _m_int a = *this, result = 1;
 
        while (p > 0) {
            if (p & 1)
                result *= a;
 
            p >>= 1;
 
            if (p > 0)
                a *= a;
        }
 
        return result;
    }
 
    friend ostream& operator<<(ostream &os, const _m_int &m) {
        return os << m.val;
    }
};
 
template<const int &MOD> _m_int<MOD> _m_int<MOD>::save_inv[_m_int<MOD>::SAVE_INV];
const int MOD_INT=998244353;
using mod_int = _m_int<MOD_INT>;
void solve(){
    ll n; cin>>n;
    vector<ll> a(n);
    for(auto &it:a){
        cin>>it;
    }
    vector<ll> track(n);
    iota(all(track),0);
    sort(all(track),[&](ll l,ll r){
        return a[l]>a[r];
    });  
    set<ll> found;
    found.insert(-1); found.insert(n);
    auto getv=[&](ll l,ll r){
        mod_int len=r-l+1;
        if(l==0 or r==n-1){
            return len;
        }
        return len+(len*(len-1))/2;
    };
    mod_int ans=1,total=getv(0,n-1)+1,bad=getv(0,n-1);
    for(auto it:track){
        ll pos=it;
        found.insert(pos);
        ll l=*(--found.lower_bound(pos))+1;
        ll r=*(found.upper_bound(pos))-1;
        total/=(getv(l,r)+1);
        bad/=getv(l,r);
        mod_int mul=1;
        if(l!=0){
            mul=pos-l+1;
        }
        if(r!=n-1){
            mul*=(r-pos+1);
        } 
        ans+=(total-bad)*mul;  
        total*=(getv(l,pos-1)+1)*(getv(pos+1,r)+1);
        bad*=getv(l,pos-1)*getv(pos+1,r);
    }
    cout<<ans<<nline;
    return;       
}                                       
int main()                                                                                 
{       
    ios_base::sync_with_stdio(false);                           
    cin.tie(NULL);                               
    #ifndef ONLINE_JUDGE                 
    freopen("input.txt", "r", stdin);                                           
    freopen("output.txt", "w", stdout);      
    freopen("error.txt", "w", stderr);                        
    #endif     
    ll test_cases=1;                 
    cin>>test_cases;
    while(test_cases--){
        solve();
    }
    cout<<fixed<<setprecision(10);
    cerr<<"Time:"<<1000*((double)clock())/(double)CLOCKS_PER_SEC<<"ms\n"; 
}  

Tester's code (C++)
#include<bits/stdc++.h>

using namespace std;

#define all(x) begin(x), end(x)
#define sz(x) static_cast<int>((x).size())
#define int long long

int mod = 998244353;

int norm (int x) {
        if (x < 0) {
                x += mod;
        }
        if (x >= mod) {
                x -= mod;
        }
        return x;
}
template<class T>
T power(T a, int b) {
        T res = 1;
        for (; b; b /= 2, a *= a) {
                if (b % 2) {
                        res *= a;
                }
        }
        return res;
}
struct Z {
        int x;
        Z(int x = 0) : x(norm(x)) {}
        int val() const {
                return x;
        }
        Z operator-() const {
                return Z(norm(mod - x));
        }
        Z inv() const {
                assert(x != 0);
                return power(*this, mod - 2);
        }
        Z &operator*=(const Z &rhs) {
                x = x * rhs.x % mod;
                return *this;
        }
        Z &operator+=(const Z &rhs) {
                x = norm(x + rhs.x);
                return *this;
        }
        Z &operator-=(const Z &rhs) {
                x = norm(x - rhs.x);
                return *this;
        }
        Z &operator/=(const Z &rhs) {
                return *this *= rhs.inv();
        }
        friend Z operator*(const Z &lhs, const Z &rhs) {
                Z res = lhs;
                res *= rhs;
                return res;
        }
        friend Z operator+(const Z &lhs, const Z &rhs) {
                Z res = lhs;
                res += rhs;
                return res;
        }
        friend Z operator-(const Z &lhs, const Z &rhs) {
                Z res = lhs;
                res -= rhs;
                return res;
        }
        friend Z operator/(const Z &lhs, const Z &rhs) {
                Z res = lhs;
                res /= rhs;
                return res;
        }
        friend std::istream &operator>>(std::istream &is, Z &a) {
                int v;
                is >> v;
                a = Z(v);
                return is;
        }
        friend std::ostream &operator<<(std::ostream &os, const Z &a) {
                return os << a.val();
        }
};



signed main() {

        ios::sync_with_stdio(0);
        cin.tie(0);
        cout.tie(0);


        int t;
        cin >> t;

        while (t--) {

                int n;
                cin >> n;
                int a[n];
                for (auto &x : a) cin >> x;

                set<int> s;
                s.insert(-1);
                s.insert(n);
                int pos[n];
                for (int i = 0; i < n; i++) pos[a[i] - 1] = i;

                s.insert(pos[n - 1]);
                Z v1 = ((pos[n - 1] + 1) * (n - pos[n - 1]) % mod);
                Z v2 = (pos[n - 1] * (n - pos[n - 1] - 1) % mod);
                Z ans = 1;

                for (int i = n - 2; i >= 0; i--) {

                        int j = pos[i];
                        auto it = s.lower_bound(j);
                        int r = *it;
                        it--;
                        int l = *it;
                        int x = r - l - 1;
                        
                        if (r == n || l == -1) {
                                v1 /= x + 1;
                                v2 /= x;
                        }
                        else {
                                v1 /= (x * (x + 1) / 2 + 1) % mod;
                                v2 /= (x * (x + 1) / 2) % mod; 
                        }

                        int x1 = j - l - 1, x2 = r - j - 1;

                        if (r == n) {
                                ans += (v1 - v2) * (j - l);
                                v1 *= ((x1 * (x1 + 1) / 2 + 1) % mod);
                                v1 *= x2 + 1;
                                v2 *= ((x1 * (x1 + 1) / 2) % mod); 
                                v2 *= x2;
                        }
                        else if (l == -1) {
                                ans += (v1 - v2) * (r - j);
                                v1 *= ((x2 * (x2 + 1) / 2 + 1) % mod);
                                v1 *= x1 + 1;
                                v2 *= ((x2 * (x2 + 1) / 2) % mod); 
                                v2 *= x1;
                        }
                        else {
                                ans += (v1 - v2) * (r - j) * (j - l);
                                v1 *= ((x1 * (x1 + 1) / 2 + 1) % mod);
                                v2 *= ((x1 * (x1 + 1) / 2) % mod); 
                                v1 *= ((x2 * (x2 + 1) / 2 + 1) % mod);
                                v2 *= ((x2 * (x2 + 1) / 2) % mod); 
                        }
                        s.insert(j);

                }

                cout << ans << "\n";


        }

        
}

Editorialist's code (Python)
mod = 998244353
for _ in range(int(input())):
    n = int(input())
    p = list(map(int, input().split()))

    pos = [0]*(n+1)
    left, right = [-1]*n, [n]*n
    st = []
    for i in range(n):
        while len(st):
            u = st[-1]
            if p[u] > p[i]: break
            st.pop()
        if len(st): left[i] = st[-1]
        st.append(i)
        pos[p[i]] = i
    st = []
    for i in reversed(range(n)):
        while len(st):
            u = st[-1]
            if p[u] > p[i]: break
            st.pop()
        if len(st): right[i] = st[-1]
        st.append(i)
    
    def get(L, R):
        if L > R: return 1
        if L == 0 and R == n-1: return n*(n+1)//2 + 1
        if L == 0: return R+2
        if R == n-1: return n-L+1
        
        lt = R-L+1
        return lt*(lt+1)//2 + 1
    
    ans = 0
    ways1 = n*(n+1)//2 + 1 # allowing empty
    ways2 = n*(n+1)//2 # empty not allowed
    for i in reversed(range(1, n+1)):
        u = pos[i]
        L, R = left[u]+1, right[u]-1

        ways1 = ways1*pow(get(L, R), mod-2, mod) % mod
        ways2 = ways2*pow(get(L, R) - 1, mod-2, mod) % mod
        mul = 1
        if L > 0: mul = mul*(u-L+1) % mod
        if R+1 < n: mul = mul*(R-u+1) % mod
        ans += mul * (ways1 - ways2) % mod
        ans %= mod

        ways1 = ways1 * get(L, u-1) % mod * get(u+1, R) % mod
        ways2 = ways2 * (get(L, u-1) - 1) % mod * (get(u+1, R) - 1) % mod

    print(ans + 1)

The Editorialist's code (the Python code provided above) is giving RE.

Submit it in Python 3; for some unknown reason, PyPy doesn't like this code.