COUNTDIST - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: satyam_343
Tester: tabr
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Monotonic stack

PROBLEM:

You have a permutation P of \{1, 2, 3, \ldots, N\}.
For each prefix of this permutation, solve the following problem:

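  • Starting with the array A equal to this prefix, you can choose three indices i \lt j \lt k such that A_i \gt A_j \lt A_k, and delete A_j.
    Find the number of distinct arrays that can be reached by performing this operation any number of times (possibly zero). The solutions below print this count modulo 998244353.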
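
For small inputs, the statement can be checked directly by exhaustive search. The following is only a sketch for experimentation, not part of any reference solution; the function names `reachable_arrays` and `brute_prefix_answers` are made up for this illustration, and the search is exponential, so it is only usable for very short arrays.

```python
def reachable_arrays(p):
    """Return the set of all arrays (as tuples) reachable from p."""
    start = tuple(p)
    seen = {start}
    todo = [start]
    while todo:
        cur = todo.pop()
        # Only interior positions can have a neighbour on both sides.
        for j in range(1, len(cur) - 1):
            # cur[j] can be deleted if something larger exists on each side.
            if max(cur[:j]) > cur[j] and max(cur[j + 1:]) > cur[j]:
                nxt = cur[:j] + cur[j + 1:]
                if nxt not in seen:
                    seen.add(nxt)
                    todo.append(nxt)
    return seen


def brute_prefix_answers(p):
    """Answer for every prefix of p, by exhaustive search."""
    return [len(reachable_arrays(p[:i])) for i in range(1, len(p) + 1)]


print(brute_prefix_answers([4, 1, 2, 3]))  # [1, 1, 2, 4]
```

For P = [4, 1, 2, 3], the elements 1 and 2 can each be deleted independently, which gives the 4 arrays counted for the full prefix.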

EXPLANATION:

Let’s use dynamic programming — let \text{dp}_i denote the answer for the prefix of length i.

Clearly, \text{dp}_1 = \text{dp}_2 = 1, since no operations can be performed on these prefixes.
Let’s now look at how to solve for prefix i, assuming we’ve solved for 1, 2, 3, \ldots, i-1 already.

Let M denote the index such that P_M = \max(P_1, P_2, \ldots, P_i).
Observe that this index can never be deleted, since no element of the prefix is larger than P_M.
There are two possibilities: M = i, or M \lt i. We look at them both separately.

Case 1: M < i

Note that the prefix 1, 2, \ldots, M can be thought of as an independent subproblem, since any deletion done in it can always use M as the third index.
So, we can combine any of the \text{dp}_M arrays available in that prefix, with whatever arrays are possible from indices M to i.
The former is a value we already know, since index M was processed earlier.

To compute the latter, we can use a similar observation.
Let M_2 denote the index of the maximum element among M+1, M+2, \ldots, i.
Note that M_2 can also never be deleted by our operation, since it’s impossible to find two elements surrounding it that are both larger than it (after all, nothing to its right is greater than it).
Further, every index strictly between M and M_2 can be freely deleted or not by utilizing M and M_2, giving 2^{M_2-M-1} possibilities.

This process then continues: if M_3 denotes the index of the maximum after M_2, we obtain another 2^{M_3-M_2-1} possibilities, and so on.

More generally, let S = [M, M_2, M_3, M_4, \ldots, M_k=i] denote the sequence of suffix maximums of the i-th prefix.
Then, the number of ways is exactly

\text{dp}_M \times \left(2^{M_2-M-1}\cdot 2^{M_3-M_2-1}\cdot\ldots\cdot 2^{i- M_{k-1}-1}\right)

The product of the powers of 2 simply reduces to 2^{i-M-k+1}.
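
To spell out that reduction: the product has k-1 factors, one per consecutive pair in S, and the positions in the exponents telescope. Written out (using only the symbols defined above):

```latex
\begin{aligned}
(M_2 - M - 1) + (M_3 - M_2 - 1) + \cdots + (i - M_{k-1} - 1)
  &= (i - M) - (k - 1) \\
  &= i - M - k + 1
\end{aligned}
```

As a small check, take P = [4, 1, 2, 3] and i = 4: here M = 1, S = [1, 4], so k = 2 and the exponent is 4 - 1 - 2 + 1 = 2, giving \text{dp}_1 \cdot 2^2 = 4.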

In other words, all we really need to know are the values M and k: the position of the maximum of the first i elements, and the number of suffix maximums of the first i elements.

Note that the latter quantity can be thought of in reverse: we start at index i, and keep moving to the closest element to its left that’s greater than it, till we can’t anymore.
This “previous greater element” is exactly what a monotonic stack computes during its process!

In short, keep a stack of all suffix maximums, sorted in descending order from bottom to top.
Before pushing P_i onto it, pop all elements of the stack that are less than P_i.
This keeps the stack sorted, and the top of the stack just before pushing P_i is the nearest element to the left of i that’s greater than it.
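
The stack maintenance described above can be sketched as follows. This is a minimal illustration rather than any of the reference solutions; the function name `suffix_max_counts` is hypothetical, and the stack stores values here (storing indices works equally well).

```python
def suffix_max_counts(p):
    """For each prefix of p, return the number of suffix maximums (the value k)."""
    stack = []   # suffix maximums of the current prefix, decreasing from bottom to top
    counts = []
    for x in p:
        # Everything smaller than x stops being a suffix maximum once x is appended.
        while stack and stack[-1] < x:
            stack.pop()
        stack.append(x)
        counts.append(len(stack))
    return counts


print(suffix_max_counts([4, 1, 2, 3]))  # [1, 2, 2, 2]
```

Each element is pushed once and popped at most once, so the whole loop takes linear time even though a single step may pop many elements.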

The value k is then just the size of this stack after pushing P_i, and now that we know both M and k (and hence \text{dp}_M), \text{dp}_i can be computed as

\text{dp}_i = \text{dp}_M \cdot 2^{i-M-k+1}

as mentioned earlier.

Case 2: M = i

When M = i, the earlier analysis fails since \text{dp}_M isn’t computed yet.
However, we can do something quite similar.

Let M_2 denote the index of the second maximum of P_1, P_2, \ldots, P_i; that is, the maximum of P_1, P_2, \ldots, P_{i-1}.
Again, observe that M_2 can’t be deleted (since nothing to its left is larger than it).
Further, everything between M_2 and M can be freely deleted (or not) by utilizing these two indices, and once again everything before M_2 is an independent subproblem, so we just get

\text{dp}_i = \text{dp}_{M_2}\cdot 2^{M-M_2-1}

M_2 is easily known, since it’s just the position of the maximum of the first i-1 elements.

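Together, both cases solve the problem in \mathcal{O}(N) time: every index is pushed to and popped from the stack at most once, and the required powers of 2 can be precomputed.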
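
Putting the two cases together, the recurrence could be sketched as below. This is an illustrative version written for readability, not the editorialist's implementation; the name `dp_prefix_answers` is hypothetical, positions are 1-indexed to match the text, and it calls `pow` for each power of 2 instead of precomputing them (an extra logarithmic factor that a table of powers would remove).

```python
MOD = 998244353


def dp_prefix_answers(p):
    """dp over prefixes: dp[i] is the answer for the prefix of length i."""
    n = len(p)
    dp = [0] * (n + 1)
    stack = []   # 1-indexed positions of the suffix maximums of the current prefix
    M = 0        # 1-indexed position of the maximum seen so far (0 = none yet)
    for i in range(1, n + 1):
        x = p[i - 1]
        while stack and p[stack[-1] - 1] < x:
            stack.pop()
        stack.append(i)
        if M == 0:
            dp[i] = 1                      # a single element: nothing can be deleted
            M = i
        elif x > p[M - 1]:
            # Case 2: i is the new maximum; the old M plays the role of M_2.
            dp[i] = dp[M] * pow(2, i - M - 1, MOD) % MOD
            M = i
        else:
            # Case 1: the maximum is at M < i and k suffix maximums are on the stack.
            k = len(stack)
            dp[i] = dp[M] * pow(2, i - M - k + 1, MOD) % MOD
    return dp[1:]


print(*dp_prefix_answers([1, 3, 2, 5, 4]))  # 1 1 1 2 2
```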


You might observe that we don’t even really need dynamic programming here: the same logic we used to show that suffix maximums are undeletable also shows that prefix maximums are undeletable, and everything else is deletable.
So, if we know the sequence of suffix maximums till i (suppose it’s of length k_1), and the sequence of prefix maximums till i (suppose of length k_2), then exactly k_1 + k_2 - 1 elements are undeletable (the overall maximum belongs to both sequences), and the number of sequences is just 2^{i+1-k_1-k_2}.
For example, i = 1 gives k_1 = k_2 = 1 and hence 2^0 = 1.
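
This DP-free view can be sketched in a few lines: count the prefix maximums and the suffix maximums of each prefix, and raise 2 to the number of remaining elements. The function name `closed_form_answers` is hypothetical. The exponent is computed as the prefix length minus the number of undeletable elements, where the overall maximum is counted once even though it is both a prefix maximum and a suffix maximum.

```python
MOD = 998244353


def closed_form_answers(p):
    """Answer for every prefix of p, modulo MOD, without dynamic programming."""
    stack = []          # suffix maximums of the current prefix (values)
    prefix_maxes = 0    # number of prefix maximums seen so far
    best = 0            # largest value so far (values of a permutation are >= 1)
    out = []
    for length, x in enumerate(p, start=1):
        while stack and stack[-1] < x:
            stack.pop()
        stack.append(x)
        if x > best:
            best = x
            prefix_maxes += 1
        # The overall maximum is in both sequences, so subtract 1 to count it once.
        undeletable = len(stack) + prefix_maxes - 1
        out.append(pow(2, length - undeletable, MOD))
    return out


print(*closed_form_answers([4, 1, 2, 3]))  # 1 1 2 4
```

This is the same counting that the tester's code performs with its two vectors `st` and `que`.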

TIME COMPLEXITY:

\mathcal{O}(N) per testcase.

CODE:

Author's code (C++)
#pragma GCC optimize("O3,unroll-loops")
#include <bits/stdc++.h>   
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;   
using namespace std;
#define ll long long
#define pb push_back                  
#define mp make_pair          
#define nline "\n"                            
#define f first                                            
#define s second                                             
#define pll pair<ll,ll> 
#define all(x) x.begin(),x.end()     
#define vl vector<ll>           
#define vvl vector<vector<ll>>      
#define vvvl vector<vector<vector<ll>>>          
#ifndef ONLINE_JUDGE    
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);  
#endif     
void _print(ll x){cerr<<x;}  
void _print(char x){cerr<<x;}   
void _print(string x){cerr<<x;}    
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());   
template<class T,class V> void _print(pair<T,V> p) {cerr<<"{"; _print(p.first);cerr<<","; _print(p.second);cerr<<"}";}
template<class T>void _print(vector<T> v) {cerr<<" [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T>void _print(set<T> v) {cerr<<" [ "; for (T i:v){_print(i); cerr<<" ";}cerr<<"]";}
template<class T>void _print(multiset<T> v) {cerr<< " [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T,class V>void _print(map<T, V> v) {cerr<<" [ "; for(auto i:v) {_print(i);cerr<<" ";} cerr<<"]";} 
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
//--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------  
const ll MOD=998244353;
const ll MAX=2000200;
vector<ll> fact(MAX+2,1),inv_fact(MAX+2,1);
ll binpow(ll a,ll b,ll MOD){
    ll ans=1;
    a%=MOD;  
    while(b){
        if(b&1)
            ans=(ans*a)%MOD;
        b/=2;
        a=(a*a)%MOD;
    }
    return ans;
}
ll inverse(ll a,ll MOD){
    return binpow(a,MOD-2,MOD);
} 
void precompute(ll MOD){
    for(ll i=2;i<MAX;i++){
        fact[i]=(fact[i-1]*i)%MOD;
    }
    inv_fact[MAX-1]=inverse(fact[MAX-1],MOD);
    for(ll i=MAX-2;i>=0;i--){
        inv_fact[i]=(inv_fact[i+1]*(i+1))%MOD;
    }
}
ll nCr(ll a,ll b,ll MOD){
    if(a==b){
        return 1;
    }
    if((a<0)||(a<b)||(b<0))
        return 0;   
    ll denom=(inv_fact[b]*inv_fact[a-b])%MOD;
    return (denom*fact[a])%MOD;  
} 
ll sum=0;
void solve(){  
    ll n; cin>>n;
    vector<ll> track,p(n+5); 
    ll nax=0,last=-1;
    sum+=n;
    for(ll i=1;i<=n;i++){
        cin>>p[i];
        nax=max(nax,p[i]);
        if(nax!=p[i]){   
            while(p[track.back()] < p[i]){
                track.pop_back();
            }
        }
        else{
            while(!track.empty()){
                auto it=track.back();
                if(p[it]!=last){
                    track.pop_back();
                }
                else{
                    break;
                }
            }
        }
        track.push_back(i);
        last=nax;
        cout<<binpow(2,i-track.size(),MOD)<<" ";
    }
    cout<<nline;
    return;    
}                                       
int main()                                                                               
{       
    ios_base::sync_with_stdio(false);                          
    cin.tie(NULL);                               
    #ifndef ONLINE_JUDGE                 
    freopen("input.txt", "r", stdin);                                           
    freopen("output.txt", "w", stdout);      
    freopen("error.txt", "w", stderr);                        
    #endif     
    ll test_cases=1;                   
    cin>>test_cases;
    precompute(MOD);
    while(test_cases--){
        solve();
    }
    debug(sum);
    cout<<fixed<<setprecision(10);
    cerr<<"Time:"<<1000*((double)clock())/(double)CLOCKS_PER_SEC<<"ms\n"; 
}  
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

#define IGNORE_CR

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
#ifdef IGNORE_CR
            if (buffer[pos] == '\r') {
                pos++;
                continue;
            }
#endif
            assert(!isspace(buffer[pos]));
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

template <long long mod>
struct modular {
    long long value;
    modular(long long x = 0) {
        value = x % mod;
        if (value < 0) value += mod;
    }
    modular& operator+=(const modular& other) {
        if ((value += other.value) >= mod) value -= mod;
        return *this;
    }
    modular& operator-=(const modular& other) {
        if ((value -= other.value) < 0) value += mod;
        return *this;
    }
    modular& operator*=(const modular& other) {
        value = value * other.value % mod;
        return *this;
    }
    modular& operator/=(const modular& other) {
        long long a = 0, b = 1, c = other.value, m = mod;
        while (c != 0) {
            long long t = m / c;
            m -= t * c;
            swap(c, m);
            a -= t * b;
            swap(a, b);
        }
        a %= mod;
        if (a < 0) a += mod;
        value = value * a % mod;
        return *this;
    }
    friend modular operator+(const modular& lhs, const modular& rhs) { return modular(lhs) += rhs; }
    friend modular operator-(const modular& lhs, const modular& rhs) { return modular(lhs) -= rhs; }
    friend modular operator*(const modular& lhs, const modular& rhs) { return modular(lhs) *= rhs; }
    friend modular operator/(const modular& lhs, const modular& rhs) { return modular(lhs) /= rhs; }
    modular& operator++() { return *this += 1; }
    modular& operator--() { return *this -= 1; }
    modular operator++(int) {
        modular res(*this);
        *this += 1;
        return res;
    }
    modular operator--(int) {
        modular res(*this);
        *this -= 1;
        return res;
    }
    modular operator-() const { return modular(-value); }
    bool operator==(const modular& rhs) const { return value == rhs.value; }
    bool operator!=(const modular& rhs) const { return value != rhs.value; }
    bool operator<(const modular& rhs) const { return value < rhs.value; }
};
template <long long mod>
string to_string(const modular<mod>& x) {
    return to_string(x.value);
}
template <long long mod>
ostream& operator<<(ostream& stream, const modular<mod>& x) {
    return stream << x.value;
}
template <long long mod>
istream& operator>>(istream& stream, modular<mod>& x) {
    stream >> x.value;
    x.value %= mod;
    if (x.value < 0) x.value += mod;
    return stream;
}

constexpr long long mod = 998244353;
using mint = modular<mod>;

mint power(mint a, long long n) {
    mint res = 1;
    while (n > 0) {
        if (n & 1) {
            res *= a;
        }
        a *= a;
        n >>= 1;
    }
    return res;
}

vector<mint> fact(1, 1);
vector<mint> finv(1, 1);

mint C(int n, int k) {
    if (n < k || k < 0) {
        return mint(0);
    }
    while ((int) fact.size() < n + 1) {
        fact.emplace_back(fact.back() * (int) fact.size());
        finv.emplace_back(mint(1) / fact.back());
    }
    return fact[n] * finv[k] * finv[n - k];
}

int main() {
    input_checker in;
    int tt = in.readInt(1, 1e5);
    in.readEoln();
    int sn = 0;
    while (tt--) {
        int n = in.readInt(1, 5e5);
        in.readEoln();
        sn += n;
        vector<mint> p2(n + 1);
        p2[0] = 1;
        for (int i = 1; i <= n; i++) {
            p2[i] = p2[i - 1] + p2[i - 1];
        }
        auto p = in.readInts(n, 1, n);
        in.readEoln();
        {
            auto q = p;
            sort(q.begin(), q.end());
            q.resize(unique(q.begin(), q.end()) - q.begin());
            assert((int) q.size() == n);
            assert(q[0] == 1);
            assert(q[n - 1] == n);
        }
        vector<int> st;
        vector<int> que;
        for (int i = 0; i < n; i++) {
            while (!st.empty() && st.back() < p[i]) {
                st.pop_back();
            }
            st.emplace_back(p[i]);
            if (que.empty() || que.back() < p[i]) {
                que.emplace_back(p[i]);
            }
            int t = (i + 2) - (int) (st.size() + que.size());
            cout << p2[t] << " \n"[i == n - 1];
        }
    }
    assert(sn <= 5e5);
    in.readEof();
    return 0;
}
Editorialist's code (Python)
mod = 998244353
for _ in range(int(input())):
    n = int(input())
    p = list(map(int, input().split()))
    pmax = [0]*n
    pmax[0] = 0
    for i in range(1, n):
        pmax[i] = i
        if p[pmax[i]] < p[pmax[i-1]]: pmax[i] = pmax[i-1]
    
    pow2 = [1]*n
    for i in range(1, n): pow2[i] = 2 * pow2[i-1] % mod
    inv2 = [pow(x, mod-2, mod) for x in pow2]
    
    dp = [0]*n
    stk = []
    cur = 1
    for i in range(n):
        while len(stk) > 0:
            u = stk[-1]
            if p[u] > p[i]: break
            if len(stk) > 1:
                v = stk[-2]
                cur = cur * inv2[u-v-1] % mod
            stk.pop()
        if len(stk) > 0:
            cur = cur * pow2[i - stk[-1] - 1] % mod
        else:
            if i > 0:
                what = pmax[i-1]
                cur = dp[what] * pow2[i-what-1] % mod
        dp[i] = cur
        stk.append(i)
    print(*dp)