RMVPK - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: practice_track
Tester: jay_1048576
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Stacks, dynamic programming

PROBLEM:

An element in an array is called a peak if it’s strictly greater than both its neighbors.
In particular, the border elements are never peaks.

You’re given a permutation P of \{1, 2, \ldots, N\}. You have an empty set S.
In one move, you can delete a peak from P and insert it into S.
Find the sum of elements across all distinct sets S that can be obtained via this process.
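
To make the process concrete, here is a minimal brute-force sketch in Python that enumerates every reachable set S by trying all peak deletions. It is exponential and only meant for checking tiny inputs; the function names (`is_peak`, `reachable_sets`, `brute_force_answer`) are illustrative and not taken from any reference solution.

```python
def is_peak(arr, idx):
    """True if arr[idx] is strictly greater than both of its neighbours."""
    return 0 < idx < len(arr) - 1 and arr[idx - 1] < arr[idx] > arr[idx + 1]


def reachable_sets(p):
    """Enumerate every set S obtainable by repeatedly deleting peaks.

    Since p is a permutation, the remaining array determines S uniquely,
    so it is enough to explore the distinct remaining arrays.
    """
    start = tuple(p)
    seen = {start}
    todo = [start]
    while todo:
        cur = todo.pop()
        for idx in range(len(cur)):
            if is_peak(cur, idx):
                nxt = cur[:idx] + cur[idx + 1:]
                if nxt not in seen:
                    seen.add(nxt)
                    todo.append(nxt)
    all_values = set(p)
    return {frozenset(all_values - set(state)) for state in seen}


def brute_force_answer(p):
    """Sum of elements over all distinct reachable sets (no modulus applied)."""
    return sum(sum(s) for s in reachable_sets(p))
```

For example, on P = [1, 4, 3, 2] the reachable sets are {}, {4} and {3, 4}, so the answer is 0 + 4 + 7 = 11.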

EXPLANATION:

Rather than thinking about which elements are removed (and inserted into S), let’s look at the elements which remain in P.

Observe that P_N will surely remain: after all, by definition, the last element isn’t a peak, so it can’t be removed no matter what.
Let P_i be the second-to-last remaining element - that is, i is the largest index \lt N that isn’t removed.
Observe that this is possible if and only if every element between P_i and P_N can be deleted.
It’s easy to see that this is equivalent to saying that every element at indices i+1 to N-1 should be strictly greater than both P_i and P_N.

Proof

If every element at those indices is greater than P_i and P_N, all of them can be deleted by simply choosing the maximum element each time, which is guaranteed to be greater than both its neighbors.

On the other hand, suppose some element at indices i+1 to N-1 is \lt \max(P_i, P_N).
Let M be the smallest element at those indices, so M \lt \max(P_i, P_N) as well.
Then, M can never be deleted: at any moment, its neighbors are elements from \{P_i, P_{i+1}, \ldots, P_N\}, and it’s smaller than all of them except possibly \min(P_i, P_N). A peak must exceed both of its neighbors, so M can never become one.

As for elements at indices \lt i, notice that since P_i is not deleted, P_N is never going to affect whether they’re peaks or not; so we essentially just reduced the problem to the prefix ending at i.
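
The criterion for clearing a whole segment is easy to sanity-check on small cases. The sketch below (0-indexed, with hypothetical helper names) compares it against an exhaustive simulation of peak deletions on the segment between two kept positions i and k; it is a checking aid under those assumptions, not part of the intended solution.

```python
def can_clear_between(p, i, k):
    """Criterion from the text: every value strictly between positions
    i and k exceeds both p[i] and p[k]."""
    bound = max(p[i], p[k])
    return all(p[t] > bound for t in range(i + 1, k))


def can_clear_by_simulation(p, i, k):
    """Try all peak deletions on p[i..k]; True if only the two ends can be left.

    The two end elements are borders of the segment, so they are never deleted.
    """
    start = tuple(p[i:k + 1])
    seen, todo = {start}, [start]
    while todo:
        cur = todo.pop()
        if len(cur) == 2:
            return True
        for idx in range(1, len(cur) - 1):
            if cur[idx - 1] < cur[idx] > cur[idx + 1]:
                nxt = cur[:idx] + cur[idx + 1:]
                if nxt not in seen:
                    seen.add(nxt)
                    todo.append(nxt)
    return False
```

Running both functions over all permutations of a small length and all pairs i \lt k is a cheap way to convince yourself that the two always agree.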

This gives us a rather straightforward dynamic programming solution.
Let dp_k denote the answer when considering only the first k elements.
Let ct_k denote the number of distinct sets of deleted elements, again considering only the first k elements.
Then, we have:

  • As a base case, dp_1 = 0 and ct_1 = 1, since nothing can be deleted from a single element (this also gives dp_2 = 0 and ct_2 = 1).
  • For k\gt 1, let’s fix the index of the previous non-deleted element to be i.
    Then, after checking that this is indeed valid (i.e., every element in between is larger than P_i and P_k),
    • ct_k increases by ct_i, since each subset ending at i is extended uniquely to end at k by deleting everything in between.
    • dp_k increases by dp_i + ct_i \cdot (P_{i+1} + P_{i+2} + \ldots + P_{k-1}), since the sum of elements in between is added once for each chosen subset of the first i elements.

This is easily implemented to run in \mathcal{O}(N^2), which solves subtask 1.
The final answer is just dp_N.
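
A minimal sketch of this quadratic DP in Python, assuming 0-indexed arrays and a modulus of 10^9 + 7 (taken from the reference solutions in the CODE section, since the statement above doesn’t mention one). It seeds the recurrence at the first element with ct = 1 and dp = 0, and tracks the minimum of the in-between values while scanning i leftwards, so each validity check is O(1). The name `solve_quadratic` is illustrative.

```python
MOD = 10**9 + 7  # assumption: same modulus as the reference solutions


def solve_quadratic(p):
    """O(N^2) DP over 'previous kept index'; returns dp for the full array."""
    n = len(p)
    pref = [0] * (n + 1)          # pref[t] = p[0] + ... + p[t-1]
    for idx, value in enumerate(p):
        pref[idx + 1] = pref[idx] + value

    ct = [0] * n                  # ct[k]: number of distinct deleted sets in p[0..k]
    dp = [0] * n                  # dp[k]: sum of elements over those sets
    ct[0] = 1                     # single element: only the empty set
    for k in range(1, n):
        smallest_between = float("inf")   # min of p[i+1..k-1], empty at first
        for i in range(k - 1, -1, -1):
            # i is valid iff everything strictly between exceeds both ends.
            if smallest_between > max(p[i], p[k]):
                between = pref[k] - pref[i + 1]   # sum of p[i+1..k-1]
                ct[k] = (ct[k] + ct[i]) % MOD
                dp[k] = (dp[k] + dp[i] + ct[i] * between) % MOD
            smallest_between = min(smallest_between, p[i])
    return dp[n - 1]
```

On P = [1, 4, 3, 2] this returns 11: the valid previous indices for the last position are the third element (contributing 4) and the first element (contributing 4 + 3 = 7).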


For the full solution, we clearly need to optimize something in our previous approach.
The number of states is already \mathcal{O}(N), so that can remain: only the transitions are slow, so let’s try to speed those up instead.

Let’s fix k, the index we’re computing the DP for.
Let all the good indices for k be i_1 \gt i_2 \gt \ldots \gt i_m, where ‘good’ means the index can be the previous undeleted element.
Notice that i_1 = k-1 always holds, since nothing lies between indices k-1 and k.

Now, observe that the values at these indices will be a bit special: they’ll be in descending order!
That is, P_{i_1} \gt P_{i_2} \gt\ldots \gt P_{i_m} must hold.

Proof

This should be fairly obvious: if P_{i_1} \lt P_{i_2}, index i_2 can’t be good, because P_{i_1} lies between it and index k and is less than P_{i_2}, which isn’t allowed.
The same analysis applies to all the pairs of adjacent good indices.

In fact, we can make a slightly stronger statement.
As noted earlier, we have i_1 = k-1.
Then, for each j \gt 1, i_j will be the index of the closest smaller element to the left of P_{i_{j-1}}.
Further, P_{i_m} will either be less than P_k, or have no previous smaller element.

Why?

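If i_j is a ‘good’ index, then it’s easy to see that the index of the previous smaller element of P_{i_j} will also be a good index: everything between it and i_j is larger than P_{i_j} by definition, and since i_j was good, everything from there to k-1 is also greater than P_{i_j}, which in turn is greater than that previous smaller element.
Conversely, no index strictly between the two can be good, since P_{i_j} is smaller than its value and lies between it and k.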

The only exception to this is if P_{i_j} \lt P_k, in which case no previous index can be good (since the range will now include something smaller than P_k).
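
The chain structure can be checked directly. The following sketch (0-indexed, helper names are hypothetical) builds the good indices for k once straight from the definition and once by jumping along previous smaller elements, stopping as soon as the current value drops below P_k. The previous smaller element is found by a plain scan here, so this is for illustration rather than speed.

```python
def good_indices_by_definition(p, k):
    """All i < k such that every value strictly between i and k exceeds
    both p[i] and p[k]; listed from right to left."""
    return [i for i in range(k - 1, -1, -1)
            if all(p[t] > max(p[i], p[k]) for t in range(i + 1, k))]


def previous_smaller(p, i):
    """Index of the nearest element to the left of i smaller than p[i], or -1."""
    j = i - 1
    while j >= 0 and p[j] > p[i]:
        j -= 1
    return j


def good_indices_by_chain(p, k):
    """Start at k-1 and keep jumping to the previous smaller element,
    stopping once the current value is below p[k]. Requires k >= 1."""
    chain = []
    i = k - 1
    while i >= 0:
        chain.append(i)
        if p[i] < p[k]:
            break            # anything further left would span a value below p[k]
        i = previous_smaller(p, i)
    return chain
```

For P = [1, 4, 3, 2] and the last position (k = 3 in 0-indexed terms), both functions give [2, 0]: index 1 is skipped because the value 3 between it and k is smaller than 4.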

This nice structure allows us to optimize our DP to run in linear time!
Notice that the ‘good’ indices for k are exactly those indices popped from the stack when running the standard algorithm to compute previous smaller elements (along with the top of the stack just before index k is pushed onto it); and we know that algorithm runs in linear time.
So, simply compute the DP at the same time, and it becomes linear overall!
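
To see where the linear bound comes from, here is a small instrumented sketch of the monotonic stack (0-indexed; `good_indices_via_stack` is an illustrative name). For each k it records the indices the stack touches: everything popped, plus the surviving top. Each index is popped at most once and each k looks at most one surviving top, so the total number of touched indices is at most 2N.

```python
def good_indices_via_stack(p):
    """For each position k, list the indices touched by the
    previous-smaller-element stack while processing k."""
    stack = []                      # indices; their values increase bottom to top
    touched = [[] for _ in p]
    for k, value in enumerate(p):
        # Popped indices hold values larger than p[k].
        while stack and p[stack[-1]] > value:
            touched[k].append(stack.pop())
        # The surviving top (if any) is smaller than p[k]: the last good index.
        if stack:
            touched[k].append(stack[-1])
        stack.append(k)
    return touched
```

For P = [1, 4, 3, 2] this yields [[], [0], [1, 0], [2, 0]]; adding the ct/dp updates at each touched index turns the trace into the full DP.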

TIME COMPLEXITY:

\mathcal{O}(N) per testcase.

CODE:

Author's code (C++)
//Shortcuts

#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;
typedef long long int ll;
typedef long double ld;
typedef pair<int,int> pi;
typedef pair<ll,ll> pll;
typedef vector<int> vi;
typedef vector<ll> vll;
#define pb push_back
#define fi first
#define se second
#define endl "\n"
#define all(x) x.begin(),x.end()
template <typename T>
using ordered_set = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;

//Macros
#ifndef ONLINE_JUDGE
#include "debug.h"
#define debug(x...) cerr << "[" << #x << "] = ["; _print(x)
#else
#define debug(x...)
#endif

//IO functions

void inp(vi& a)
{
    for(int i = 0; i < a.size(); i++)
        cin >> a[i];
}

void inp(vll& a)
{
    for(int i = 0; i < a.size(); i++)
        cin >> a[i];
}

ostream& operator << (ostream& s, pi& a)
{
    return s << a.fi << " " << a.se;
}

ostream& operator << (ostream& s, pll& a)
{
    return s << a.fi << " " << a.se;
}

ostream& operator << (ostream& s, vi& a)
{
    for(int i : a)
        s << i << " ";
    return s;
}

ostream& operator << (ostream& s, vll& a)
{
    for(ll i : a)
        s << i << " ";
    return s;
}

void yes(bool b)
{
    if(b)
        cout << "YES" << endl;
    else
        cout << "NO" << endl;
}

//EXTRA FUNCTIONS*******************************************************
const int MOD = 1e9+7;

struct mint 
{
    int val;
    mint(ll v = 0)
    {
        if(abs(v) < MOD)
            val = v;
        else
            val = v%MOD;
        if(val < 0)
            val += MOD;
    }
};

mint operator+(const mint& a, const mint& b)
{
    int res = a.val + b.val;
    if(res >= MOD)
        return res-MOD;
    else
        return res;
}
mint operator-(const mint& a, const mint& b)
{
    int res = a.val - b.val;
    if(res < 0)
        return res+MOD;
    else
        return res;
}
mint operator*(const mint& a, const mint& b)
{
    ll res = (ll)a.val * (ll)b.val;
    if(res >= MOD)
        return res%MOD;
    else
        return res;
}
mint operator^(const mint& a, ll b)
{
    ll res = 1;
    ll prod = a.val;
    while(b > 0)
    {
        if(b&1)
            res = (res*prod)%MOD;
        prod = (prod*prod)%MOD;
        b >>= 1;
    }
    return res;
}
mint operator/(const mint& a, const mint& b)
{
    return a*(b^(MOD-2));
}

void operator+=(mint& a, const mint& b)
{
    a = a+b;
}
void operator-(mint& a)
{
    a = 0-a;
}
void operator-=(mint& a, const mint& b)
{
    a = a-b;
}
void operator*=(mint& a, const mint& b)
{
    a = a*b;
}
void operator^=(mint& a, ll b)
{
    a = a^b;
}
void operator/=(mint& a, const mint& b)
{
    a = a/b;
}

bool operator==(const mint& a, const mint& b)
{
    return a.val == b.val;
}

ostream& operator << (ostream& s, mint a)
{
    return s << a.val;
}

//COMBINATORICS AND FACTORIAL

vector<mint> fact;
vector<mint> ifact;

void cfact(int n)
{
    fact.pb(1);
    ifact.pb(1);
    for(int i = 1; i <= n; i++)
    {
        fact.pb(fact.back()*i);
        ifact.pb(fact.back()^(MOD-2));
    }
}

mint comb(int n, int r)
{
    return fact[n] * ifact[n-r] * ifact[r];
}
//END OF EXTRA FUNCTIONS************************************************

//Write code here

void run(int test)
{
    int n; cin >> n;
    vector<int> a(n+1);
    for(int i = 1; i <= n; i++)
        cin >> a[i];
    vector<ll> pre(n+1, 0);
    for(int i = 1; i <= n; i++)
        pre[i] = pre[i-1] + a[i];
    vector<mint> cnt(n+1, 0), sum(n+1, 0);
    stack<int> st;
    st.push(1);
    cnt[1] = 1;
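    for(int i = 2; i <= n; i++)
    {
        // Walk down the stack: every index visited here is a valid
        // "previous undeleted element" for i.
        while(!st.empty())
        {
            int j = st.top();
            cnt[i] += cnt[j];
            // pre[i-1] - pre[j] is the sum of the deleted values a[j+1..i-1].
            sum[i] += sum[j] + cnt[j] * (pre[i-1] - pre[j]);
            // Once a[j] < a[i], nothing further left can be valid; j stays on the stack.
            if(a[j] < a[i])
                break;
            st.pop();
        }
        st.push(i);
    }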
    cout << sum[n] << endl;
}

//Main function

int32_t main()
{
    //Fast IO
    ios_base::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    int t = 1;
    cin >> t;
    for(int i = 1; i <= t; i++)
        run(i);
    return 0;
}
Tester's code (C++)
/*...................................................................*
 *............___..................___.....____...______......___....*
 *.../|....../...\........./|...../...\...|.............|..../...\...*
 *../.|...../.....\......./.|....|.....|..|.............|.../........*
 *....|....|.......|...../..|....|.....|..|............/...|.........*
 *....|....|.......|..../...|.....\___/...|___......../....|..___....*
 *....|....|.......|.../....|...../...\.......\....../.....|./...\...*
 *....|....|.......|../_____|__..|.....|.......|..../......|/.....\..*
 *....|.....\...../.........|....|.....|.......|.../........\...../..*
 *..__|__....\___/..........|.....\___/...\___/.../..........\___/...*
 *...................................................................*
 */
 
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define MOD 1000000007


void solve(int tc)
{
    int n;
    cin >> n;
    vector<int> p(n);      // std::vector instead of a variable-length array (a compiler extension in C++)
    for(int i=0;i<n;i++)
        cin >> p[i];
    vector<int> pre(n);    // pre[i] = p[0] + ... + p[i]
    pre[0]=p[0];
    for(int i=1;i<n;i++)
        pre[i]=pre[i-1]+p[i];
    vector<int> cnt(n,0),sum(n,0);
    cnt[0]=1;
    stack<int> s;
    s.push(0);
    for(int i=1;i<n;i++)
    {
        while(!s.empty() && p[s.top()]>p[i])
        {
            int j = s.top();
            s.pop();
            cnt[i] = (cnt[i]+cnt[j])%MOD;
            sum[i] = (sum[i]+sum[j]+cnt[j]*((pre[i-1]-pre[j])%MOD))%MOD;
        }
        if(!s.empty())
        {
            int j=s.top();
            cnt[i] = (cnt[i]+cnt[j])%MOD;
            sum[i] = (sum[i]+sum[j]+cnt[j]*((pre[i-1]-pre[j])%MOD))%MOD;
        }
        s.push(i);
    }
    cout << sum[n-1] << '\n';
}

int32_t main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);
    int tc=1;
    cin >> tc;
    for(int ttc=1;ttc<=tc;ttc++)
        solve(ttc);
    return 0;
}
Editorialist's code (Python)
mod = 10**9 + 7
for _ in range(int(input())):
    n = int(input())
    p = list(map(int, input().split()))
    pref = [0]*(n+1)
    for i in range(n): pref[i+1] = pref[i] + p[i]
    stack = [0]*n
    pos = -1
    
    dp = [0]*n
    ct = [0]*n
    ct[0] = 1
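    for i in range(n):
        # Pop every stack index whose value exceeds p[i]; each one is a valid
        # previous undeleted element for i.
        while pos != -1:
            j = stack[pos]
            if p[j] > p[i]:
                ct[i] = (ct[i] + ct[j]) % mod
                # pref[i] - pref[j+1] is the sum of the deleted values p[j+1..i-1].
                dp[i] = (dp[i] + dp[j] + ct[j]*(pref[i] - pref[j+1])) % mod
                pos -= 1
            else: break
        # The surviving top (smaller than p[i]) is the last valid choice.
        if pos != -1:
            j = stack[pos]
            ct[i] = (ct[i] + ct[j]) % mod
            dp[i] = (dp[i] + dp[j] + ct[j]*(pref[i] - pref[j+1])) % mod
        pos += 1
        stack[pos] = i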
    print(dp[-1])