COUNTPART - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: Satyam
Testers: Takuki Kurokawa, Utkarsh Gupta
Editorialist: Nishank Suresh

DIFFICULTY:

2605

PREREQUISITES:

Dynamic programming, next greater element

PROBLEM:

An array is said to be good if its prefix up to (and including) its maximum element is sorted in increasing order.

Given a permutation P of length N, find the number of ways to partition it into good subarrays.

EXPLANATION:

This is a counting problem, and most of the time that means the solution is going to involve either combinatorics or dynamic programming.
Combinatorics doesn’t seem to be very helpful here, so let’s try DP.

Let P[L:R] denote the subarray of P starting at L and ending at R.

Let’s try to make a couple of observations. Consider a subarray P[L:R]. Then,

  • If L = R, then P[L:R] is a good subarray.
  • If P[L:R] is a good subarray and L \lt R, then P[L:R-1] is also a good subarray.
  • If P[L:R] is not a good subarray, then P[L:R+1] is also not a good subarray.

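All three facts are easy to verify. A single element is its own maximum, which gives the first. For the second, deleting the last element either leaves the maximum and the sorted prefix before it untouched, or deletes the maximum itself, in which case the whole subarray was sorted and so is what remains. The third is just the contrapositive of the second.
Notice that this gives us a useful piece of information: if we fix the left endpoint L, then the set of R such that P[L:R] is a good subarray forms a contiguous range starting at L.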
That is, the good subarrays starting at L are P[L:L], P[L:L+1], P[L:L+2], \ldots, P[L:L+K] for some K \geq 0.
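
The definition and the contiguous-range observation can be checked directly with a small brute force. This is only a sketch for experimenting on small inputs, not part of the intended solution; the names isGood and lastGoodEnd are hypothetical helpers, and indices are 0-based here.

```cpp
#include <algorithm>
#include <vector>

// Returns true if p[l..r] (0-indexed, inclusive) is good: the part of the
// subarray from l up to the position of its maximum is sorted increasingly.
bool isGood(const std::vector<int>& p, int l, int r) {
    int maxPos = static_cast<int>(
        std::max_element(p.begin() + l, p.begin() + r + 1) - p.begin());
    return std::is_sorted(p.begin() + l, p.begin() + maxPos + 1);
}

// Brute force: the largest r such that p[l..r] is good. Because the good
// right endpoints form a contiguous range starting at l, we can stop at the
// first r that fails.
int lastGoodEnd(const std::vector<int>& p, int l) {
    int r = l;
    while (r + 1 < static_cast<int>(p.size()) && isGood(p, l, r + 1)) ++r;
    return r;
}
```

For P = [2, 1, 3], the subarrays [2] and [2, 1] are good but [2, 1, 3] is not, so lastGoodEnd(p, 0) is 1.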

This immediately gives us an idea for a (slow) dynamic programming solution.

Initial idea

Let dp_i denote the number of ways to partition the suffix starting from i into good subarrays. Our final answer is dp_1, and the base case is dp_{N+1} = 1: the empty suffix has exactly one (empty) partition. In particular, this gives dp_N = 1.

By fixing the length of the good subarray starting at i, it’s easy to see that our transitions are simply

dp_i = dp_{i+1} + dp_{i+2} + \ldots + dp_{j+1}

where j is the largest integer such that P[i:j] is good.
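
As a rough illustration, the slow recurrence might be written as follows. This is a sketch under a few assumptions: indices are 0-based, so dp[n] plays the role of the empty suffix; the answer is taken modulo 998244353 as in the solutions further down; and the function name countPartitionsQuadratic is hypothetical. The right endpoint is found by a plain scan: P[i:r] is good exactly when every element after the initial increasing run is smaller than the top of that run.

```cpp
#include <vector>

const long long MOD = 998244353;

// O(N^2) reference DP. dp[i] = number of ways to partition p[i..n-1]
// into good subarrays, with dp[n] = 1 for the empty suffix.
long long countPartitionsQuadratic(const std::vector<int>& p) {
    int n = static_cast<int>(p.size());
    std::vector<long long> dp(n + 1, 0);
    dp[n] = 1;
    for (int i = n - 1; i >= 0; --i) {
        // peak = last index of the increasing run starting at i.
        int peak = i;
        while (peak + 1 < n && p[peak] < p[peak + 1]) ++peak;
        // p[i..r] is good while r stays inside the run, or while every
        // element past the run is still smaller than p[peak].
        for (int r = i; r < n && (r <= peak || p[r] < p[peak]); ++r) {
            dp[i] = (dp[i] + dp[r + 1]) % MOD;
        }
    }
    return dp[0];
}
```

On P = [2, 1, 3] this returns 3: the partitions [2][1][3], [2, 1][3] and [2][1, 3] are valid, while the whole array is not good.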

However, there are two problems with this solution:

  • The first issue is that we need to find the right endpoint j (and find it quickly, at that)
  • The second is that this is too slow. A single transition can involve up to \mathcal{O}(N) terms, so this approach runs in \mathcal{O}(N^2) in the worst case.

Let’s fix the problems individually.

Speeding up the DP

The speed issue is easy to fix using suffix sums.
Suppose we maintained the suffix sum array suf, where suf_i = dp_i + dp_{i+1} + \ldots + dp_{N+1} (the last term, dp_{N+1} = 1, counts the empty suffix), with suf_{N+2} = 0.
Then,

  • dp_i = dp_{i+1} + dp_{i+2} + \ldots + dp_{j+1} = suf_{i+1} - suf_{j+2}
  • suf_i = suf_{i+1} + dp_i

So, suf can be computed as we keep computing dp, and this allows us to optimize transitions to \mathcal{O}(1).
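
A minimal sketch of this speedup, assuming the right endpoints are already known: the array last (a hypothetical name) holds, for each 0-based index i, the largest j such that P[i:j] is good. With 0-based indices, dp[n] = 1 stands for the empty suffix and suf[n+1] = 0.

```cpp
#include <vector>

const long long MOD = 998244353;

// last[i] = largest j (0-indexed) such that p[i..j] is good; assumed given.
// Returns the number of partitions into good subarrays, modulo MOD.
long long countWithSuffixSums(const std::vector<int>& last) {
    int n = static_cast<int>(last.size());
    // suf[i] = dp[i] + dp[i+1] + ... + dp[n], and suf[n+1] = 0.
    std::vector<long long> dp(n + 1, 0), suf(n + 2, 0);
    dp[n] = 1;
    suf[n] = 1;
    for (int i = n - 1; i >= 0; --i) {
        // dp[i] = dp[i+1] + ... + dp[last[i]+1], read off in O(1).
        dp[i] = (suf[i + 1] - suf[last[i] + 2] + MOD) % MOD;
        suf[i] = (suf[i + 1] + dp[i]) % MOD;
    }
    return dp[0];
}
```

For P = [2, 1, 3] the right endpoints are last = [1, 2, 2], and the function returns 3.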

Finding the right endpoint

Suppose we know the value of j for a left endpoint i. Can we find its value for the left endpoint i-1?
It turns out we can! Here’s how:

  • Suppose P_{i-1} \lt P_i. Then, the exact same value of j works for i-1 as well.
  • Otherwise, P_{i-1} \gt P_i. In this case, let x \gt i-1 be the smallest integer such that P_x \gt P_{i-1}. Then, we choose j = x-1. If no such x exists, then j = N.

Proof

Both cases are not hard to prove.

  • If P_{i-1} \lt P_i, any good subarray of length at least 2 starting at i-1 can be turned into a good subarray starting at i by deleting the first element. Conversely, any good subarray starting at i still remains good when we insert P_{i-1} at the beginning. This implies that they have the same value of j.
  • If P_{i-1} \gt P_i, the only sorted prefix is the one with length 1. So, the subarray starting here can only be good if P_{i-1} itself is the maximum element. So, the instant we hit something larger than P_{i-1}, the subarray is no longer good.

Dealing with the first case is trivial. However, dealing with the second needs us to find, for a given element, the first element to its right that is greater than it.

In fact, it is possible to find this next greater element for every index of the array in \mathcal{O}(N), using a stack. The method of doing this is linked in the prerequisites.
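
The two cases, together with the stack-based next greater element, could be sketched as follows. The function names nextGreater and lastGoodEnds are hypothetical, indices are 0-based, and n is used as the sentinel for "no greater element to the right".

```cpp
#include <vector>

// nxt[i] = smallest x > i with p[x] > p[i], or n if no such x exists.
// Each index is pushed and popped at most once, so this is O(N) overall.
std::vector<int> nextGreater(const std::vector<int>& p) {
    int n = static_cast<int>(p.size());
    std::vector<int> nxt(n, n), st;  // st holds indices, values decreasing
    for (int i = n - 1; i >= 0; --i) {
        while (!st.empty() && p[st.back()] < p[i]) st.pop_back();
        if (!st.empty()) nxt[i] = st.back();
        st.push_back(i);
    }
    return nxt;
}

// last[i] = largest j such that p[i..j] is good, computed right to left.
std::vector<int> lastGoodEnds(const std::vector<int>& p) {
    int n = static_cast<int>(p.size());
    std::vector<int> nxt = nextGreater(p), last(n);
    for (int i = n - 1; i >= 0; --i) {
        if (i + 1 < n && p[i] < p[i + 1]) {
            last[i] = last[i + 1];   // same right endpoint as i + 1
        } else {
            last[i] = nxt[i] - 1;    // stop just before the next greater
        }
    }
    return last;
}
```

For P = [1, 3, 2, 4] this gives nxt = [1, 3, 3, 4] and last = [2, 2, 3, 3]: for example, [1, 3, 2] is good, but appending 4 makes the maximum the last element while the array is not sorted.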

Putting together the computation to find the right endpoint for each index, along with the DP speedup using suffix sums, gives us a solution in \mathcal{O}(N).

TIME COMPLEXITY:

\mathcal{O}(N) per test case.

CODE:

Setter's code (C++)
#pragma GCC optimize("O3")
#pragma GCC target("popcnt")
#pragma GCC target("avx,avx2,fma")
#pragma GCC optimize("Ofast,unroll-loops")
#include <bits/stdc++.h>   
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;
using namespace std;
#define ll long long  
const ll INF_MUL=1e13;
const ll INF_ADD=1e18;  
#define pb push_back               
#define mp make_pair        
#define nline "\n"                           
#define f first                                          
#define s second                                             
#define pll pair<ll,ll> 
#define all(x) x.begin(),x.end()   
#define vl vector<ll>         
#define vvl vector<vector<ll>>    
#define vvvl vector<vector<vector<ll>>>          
#ifndef ONLINE_JUDGE    
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);  
#endif     
void _print(ll x){cerr<<x;}  
void _print(int x){cerr<<x;} 
void _print(char x){cerr<<x;} 
void _print(string x){cerr<<x;}     
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count()); 
template<class T,class V> void _print(pair<T,V> p) {cerr<<"{"; _print(p.first);cerr<<","; _print(p.second);cerr<<"}";}
template<class T>void _print(vector<T> v) {cerr<<" [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T>void _print(set<T> v) {cerr<<" [ "; for (T i:v){_print(i); cerr<<" ";}cerr<<"]";}
template<class T>void _print(multiset<T> v) {cerr<< " [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T,class V>void _print(map<T, V> v) {cerr<<" [ "; for(auto i:v) {_print(i);cerr<<" ";} cerr<<"]";} 
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
//--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
const ll MOD=998244353;   
const ll MAX=100100; 
void solve(){     
    ll n; cin>>n;
    vector<ll> dp(n+5,1);  
    vector<ll> pref(n+5,1);
    vector<ll> a(n+5,0);  
    vector<ll> track;
    ll till=1;
    for(ll i=1;i<=n;i++){  
        cin>>a[i]; 
        while(!track.empty()){
            auto it=track.back();
            if(a[i]>a[it]){
                track.pop_back();
            }
            else{
                break;   
            }
        }
        if(a[i]<a[i-1]){
            till=i; 
        }
        dp[i]=pref[i-1];
        if(till!=1){
            dp[i]-=pref[till-2]; 
        }
        if(!track.empty()){
            dp[i]+=dp[track.back()];
        }
        track.push_back(i);
        dp[i]%=MOD;
        dp[i]=(dp[i]+MOD)%MOD;
        pref[i]=(pref[i-1]+dp[i])%MOD;
    }
    cout<<dp[n]<<nline;
    return; 
}                            
int main()                                                                            
{
    ios_base::sync_with_stdio(false);                           
    cin.tie(NULL);                         
    #ifndef ONLINE_JUDGE                 
    freopen("input.txt", "r", stdin);                                                
    freopen("output.txt", "w", stdout);  
    freopen("error.txt", "w", stderr);                        
    #endif         
    ll test_cases=1;                   
    cin>>test_cases;
    while(test_cases--){ 
        solve();       
    }
    cout<<fixed<<setprecision(10);
    cerr<<"Time:"<<1000*((double)clock())/(double)CLOCKS_PER_SEC<<"ms\n"; 
} 

Tester's code (C++)

//Utkarsh.25dec
#include <bits/stdc++.h>
#define ll long long int
#define pb push_back
#define mp make_pair
#define mod 998244353
#define vl vector <ll>
#define all(c) (c).begin(),(c).end()
using namespace std;
ll power(ll a,ll b) {ll res=1;a%=mod; assert(b>=0); for(;b;b>>=1){if(b&1)res=res*a%mod;a=a*a%mod;}return res;}
ll modInverse(ll a){return power(a,mod-2);}
const int N=500023;
bool vis[N];
vector <int> adj[N];
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);

            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }

            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }

            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
void solve()
{
    int n=readInt(1,1000000,'\n');
    int A[n+2]={0};
    int vis[n+1]={0};
    for(int i=1;i<=n;i++)
    {
        if(i==n)
            A[i]=readInt(1,n,'\n');
        else
            A[i]=readInt(1,n,' ');
        assert(vis[A[i]]==0);
        vis[A[i]]=1;
    }
    int sortedtill[n+1]={0};
    sortedtill[n]=n;
    for(int i=n-1;i>=1;i--)
    {
        if(A[i]>A[i+1])
            sortedtill[i]=i;
        else
            sortedtill[i]=sortedtill[i+1];
    }
    stack <int> s;
    s.push(1);
    int NGE[n+1]={0};
    for(int i=2;i<=n;i++)
    {
        if(s.empty())
        {
            s.push(i);
            continue;
        }
        while(s.empty()==false && A[s.top()]<A[i])
        {
            NGE[s.top()]=i;
            s.pop();
        }
        s.push(i);
    }
    ll dp[n+1]={0};
    dp[n]=1;
    ll anssum[n+3]={0};
    anssum[n]=1;
    for(int i=1;i<=n;i++)
        if(NGE[i]==0)
            NGE[i]=n+1;
    for(int i=n-1;i>=1;i--)
    {
        int j=sortedtill[i];
        int validtill=NGE[j];
        // Sum of dp[i+1] to dp[validtill]
        dp[i]=anssum[i+1]+(mod-anssum[validtill+1]);
        if(validtill==n+1)
            dp[i]++;
        dp[i]%=mod;
        anssum[i]=anssum[i+1]+dp[i];
        anssum[i]%=mod;
    }
    cout<<dp[1]<<'\n';
}
int main()
{
    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    #endif
    ios_base::sync_with_stdio(false);
    cin.tie(NULL),cout.tie(NULL);
    int T=readInt(1,100000,'\n');
    while(T--)
        solve();
    assert(getchar()==-1);
    cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
}

Editorialist's code (C++)

#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

const int mod = 998244353;

int main()
{
	ios::sync_with_stdio(false); cin.tie(0);

	int t; cin >> t;
	while (t--) {
		int n; cin >> n;
		vector<int> p(n);
		for (int &x : p) cin >> x;
		vector<int> dp(n), suf(n+2), nxt(n, n);
		stack<int> st; st.push(n-1);
		dp[n-1] = suf[n] = 1;
		suf[n-1] = 2;
		int peak = n-1;
		for (int i = n-2; i >= 0; --i) {
			while (!st.empty()) {
				if (p[i] > p[st.top()]) st.pop();
				else break;
			}
			if (!st.empty()) nxt[i] = st.top();
			st.push(i);

			if (p[i] > p[i+1]) peak = i;
			dp[i] = (suf[i+1] - suf[nxt[peak] + 1] + mod) % mod;
			suf[i] = (suf[i+1] + dp[i]) % mod;
		}
		cout << dp[0] << '\n';
	}
}

The editorial is great and very nicely explained. Thanks for creating such high-quality editorials, @iceknight1093.

Yet another Broken Stiched question

Awesome editorials… keep it up!

How come the time complexity is O(N)? You are finding the next greater element for each index in O(N), and you are iterating over the whole array, so it should be O(N^2). Please correct me if I am wrong.

The next greater element for every index can be found at the same time in \mathcal{O}(N), we don’t do N separate calculations.
The prerequisites section has a link for next greater element, please go through that.

Why is suf[n-1] = 2?

Kudos to your efforts.

suf_i = dp_i + dp_{i+1} + \ldots

dp_N = 1 because you have only one element, P_N.
dp_{N+1} = 1 because there are no elements, and so vacuously there’s only one way to partition them.
Add them up and you get suf_N = 2 (the code uses 0-indexing, so this becomes suf[n-1] = 2).