GRAND - Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Practice

Setter: Eugene
Tester & Editorialist: Taranpreet Singh

DIFFICULTY

Medium-Hard

PREREQUISITES

Dynamic Programming, Segment Tree, Randomization

PROBLEM

Given a randomly generated permutation P of length N and an integer K, find the length of the longest subsequence of P that has at most K pairs of neighboring elements in which the left element is greater than the right one. We refer to these pairs as special pairs.
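To pin down the definition, here is a small exhaustive checker. It is only a sketch for tiny inputs, meant for cross-checking faster solutions; the name `bruteForce` is our own and does not come from the setter or tester code.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>
using namespace std;

// Exhaustive reference answer for tiny inputs: try every non-empty
// subsequence, count its special pairs (left neighbour > right neighbour),
// and keep the longest one with at most k of them.
int bruteForce(const vector<int>& p, int k) {
    int n = (int)p.size();
    assert(n <= 20);  // 2^n subsets, only meant for cross-checking
    int best = 0;
    for (int mask = 1; mask < (1 << n); mask++) {
        int prev = -1, special = 0, len = 0;
        for (int i = 0; i < n; i++) {
            if (!(mask >> i & 1)) continue;
            if (prev >= 0 && p[prev] > p[i]) special++;
            prev = i;
            len++;
        }
        if (special <= k) best = max(best, len);
    }
    return best;
}
```

For the permutation 2 7 6 3 4 1 8 5 used later in this editorial, this gives 4 for K = 0 (the longest increasing subsequence) and 5 for K = 1.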

QUICK EXPLANATION

  • A trivial O(K*N^2) DP is easy: maintain state (i, k) storing the length of the longest subsequence ending at position i with at most k special pairs.
  • We don’t need to check the transition from position j to position i, where j \lt i, if
    • P_j \lt P_i and there exists p such that j \lt p \lt i and P_j \lt P_p \lt P_i
    • P_j \gt P_i and there exists p such that j \lt p \lt i and either P_p \gt P_j or P_p \lt P_i
      In both cases, we can include the element at position p to obtain a longer subsequence with the same number of special pairs. For a random permutation, the expected number of transitions per state drops to O(log(N)), leading to complexity O(K*N*log(N)).
  • If the whole permutation contains at most K special pairs, the whole permutation is the answer. Otherwise, we expect the special pairs of the optimal subsequence to be spread almost evenly, so a subsequence ending at position i should have close to i*K/N special pairs.
  • So, instead of the whole DP table, we maintain only a wide band around its diagonal, since the probability of the special pairs being concentrated towards either end is exponentially low.

EXPLANATION

Slow solution

Let us start with a simple dynamic programming approach, where f(i, k) denotes the length of the longest subsequence ending at position i with at most k special pairs.

The transitions are straightforward: \displaystyle f(i, k) = 1+\max_{j < i} f(j, k'), where k' = k if P_j \lt P_i and k' = k-1 if P_j \gt P_i (the latter only when k \geq 1). If no j qualifies, f(i, k) = 1. This recurrence corresponds to trying to extend the subsequence ending at each position j \lt i with the element at position i.

The answer is the maximum over the whole DP table. Initially, every entry is 1, since a single element forms a subsequence with no special pairs.

Computing each state requires O(N) transitions, and there are O(N*K) states. This solution has time complexity O(K*N^2) and space complexity O(K*N), both of which are too high for the given constraints.
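A minimal C++ sketch of this slow DP, assuming 0-indexed positions and the "at most k special pairs" meaning of f(i, k); `slowDP` is an illustrative name, not a function from the solutions below.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>
using namespace std;

// O(K * N^2) time, O(K * N) memory reference DP.
// f[i][c] = length of the longest subsequence ending at position i with
// at most c special pairs. A single element has no special pairs, so
// every state starts at 1.
int slowDP(const vector<int>& p, int k) {
    int n = (int)p.size();
    assert(n >= 1 && k >= 0);
    vector<vector<int>> f(n, vector<int>(k + 1, 1));
    int best = 1;
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < i; j++) {
            int d = p[j] > p[i] ? 1 : 0;  // does j -> i create a special pair?
            for (int c = d; c <= k; c++)
                f[i][c] = max(f[i][c], f[j][c - d] + 1);
        }
        for (int c = 0; c <= k; c++) best = max(best, f[i][c]);
    }
    return best;
}
```

Row c of `f` (read across positions) matches row k = c of the full example table shown later for 2 7 6 3 4 1 8 5.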

Reducing the transitions

Let us try to reduce the number of transitions required per state. While computing the values for position i, we don’t need to check the transition from position j \lt i if either of the following holds:

  • P_j \lt P_i and there exists a p such that j \lt p \lt i where P_j \lt P_p \lt P_i
    OR
  • P_j \gt P_i and there exists a p such that j \lt p \lt i where P_j \lt P_p or P_p \lt P_i

In both cases, the direct transition from j is never needed: when extending the subsequence ending at position j with the element at position i, we can also include the element at position p, which gives a longer subsequence with the same number of special pairs. Hence, for position i, it suffices to check only the positions that satisfy neither condition.

From the first condition, the positions with P_j \lt P_i that still need to be checked form a subsequence i_1 \lt i_2 \lt \ldots \lt i_l such that P_{i_t} \gt P_{i_{t+1}} for 1 \leq t \lt l.

From the second condition, let p be the largest position with p \lt i and P_p \lt P_i. No position j \lt p needs to be checked. If p \lt i_1 \lt i_2 \lt \ldots \lt i_l \lt i are the positions left to consider, then again P_{i_t} \gt P_{i_{t+1}} for 1 \leq t \lt l.

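In both cases, the positions to check are exactly the suffix maximums of a subsequence of P_1, \ldots, P_{i-1}: in the first case, the subsequence of elements smaller than P_i; in the second, the elements after position p (all of which are larger than P_i). Here, a suffix maximum of a sequence is an element that is greater than every element after it.

Reversing the sequence turns suffix maximums into prefix maximums. Since the permutation is random, the relative order of the elements in each of these subsequences is also uniformly random, so the number of transitions is bounded by the number of prefix maximums of a random permutation.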
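To make the pruning rule concrete, here is a naive O(N^2) sketch that builds the pruned transition lists by scanning j = i-1, ..., 0 and applying both conditions directly. It is only a reference for testing, not the method used in the solutions; `naiveCandidates` is a hypothetical name, positions are 0-indexed, and values are assumed distinct.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <vector>
using namespace std;

// For each i, keep j < i only if neither pruning condition holds.
vector<vector<int>> naiveCandidates(const vector<int>& p) {
    int n = (int)p.size();
    vector<vector<int>> cand(n);
    for (int i = 0; i < n; i++) {
        int maxLess = INT_MIN;     // max P_q < P_i over q in (j, i)
        int maxGreater = INT_MIN;  // max P_q > P_i over q in (j, i)
        bool seenLess = false;     // some q in (j, i) has P_q < P_i
        for (int j = i - 1; j >= 0; j--) {
            if (p[j] < p[i]) {
                // Condition 1: pruned if some q has P_j < P_q < P_i.
                if (p[j] > maxLess) cand[i].push_back(j);
                maxLess = max(maxLess, p[j]);
                seenLess = true;
            } else {
                assert(p[j] > p[i]);  // values must be distinct
                // Condition 2: pruned if some q has P_q > P_j or P_q < P_i.
                if (!seenLess && p[j] > maxGreater) cand[i].push_back(j);
                maxGreater = max(maxGreater, p[j]);
            }
        }
    }
    return cand;
}
```

Each kept j is a running maximum of the scanned values in its group, which is exactly the suffix-maximum structure described above.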

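This is a well-known result: position t of a uniformly random permutation holds a prefix maximum with probability 1/t, so the expected number of prefix maximums of a random permutation of length n is the harmonic number H_n = 1 + 1/2 + \ldots + 1/n \approx \ln(n). Hence, the expected number of transitions processed for a given position is O(log(N)).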
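A small experiment, offered as a sanity check rather than a proof: count prefix maximums over many random permutations and compare the average with H_n. The function names and the fixed seed are our own choices for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <random>
#include <vector>
using namespace std;

// Number of prefix maximums (left-to-right records) of a sequence of
// positive values, e.g. a permutation of 1..n.
int countPrefixMaxima(const vector<int>& a) {
    int cnt = 0, best = 0;
    for (int x : a)
        if (x > best) {
            best = x;
            cnt++;
        }
    return cnt;
}

// H_n = 1 + 1/2 + ... + 1/n.
double harmonic(int n) {
    double h = 0;
    for (int t = 1; t <= n; t++) h += 1.0 / t;
    return h;
}

// Average number of prefix maximums over `trials` uniformly random
// permutations of 1..n; the fixed seed makes the experiment reproducible.
double averagePrefixMaxima(int n, int trials, unsigned seed) {
    assert(n >= 1 && trials >= 1);
    mt19937 rng(seed);
    vector<int> a(n);
    iota(a.begin(), a.end(), 1);
    long long total = 0;
    for (int t = 0; t < trials; t++) {
        shuffle(a.begin(), a.end(), rng);
        total += countPrefixMaxima(a);
    }
    return (double)total / trials;
}
```

For n = 1000, H_n is about 7.49, so even a large permutation yields only a handful of transitions per position on average.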

Hence, the solution now has time complexity O(K*N*log(N)) and space complexity O(K*N)

Reducing the size of DP table

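Let C be the number of special pairs in the whole permutation. If C \leq K, the whole permutation is a valid subsequence and the answer is N.

Otherwise, an optimal subsequence has exactly K special pairs. Inserting one more element into a subsequence (between two neighbors or at either end) increases the number of special pairs by at most one, so any subsequence with fewer than K special pairs, which cannot be the whole permutation, can be extended to a longer valid one.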
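A short sketch of the two ingredients of this step: counting C for the early exit, and an exhaustive check that inserting x between neighbors a and b changes the number of special pairs by at most one. The helper names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>
using namespace std;

// C = number of special pairs (adjacent descents) of the whole sequence.
int countSpecialPairs(const vector<int>& p) {
    int c = 0;
    for (size_t i = 1; i < p.size(); i++)
        if (p[i - 1] > p[i]) c++;
    return c;
}

// Change in the number of special pairs when x is inserted between the
// neighbours a and b: pairs (a, x) and (x, b) replace the pair (a, b).
int insertionDelta(int a, int x, int b) {
    assert(a != x && x != b && a != b);
    return (a > x) + (x > b) - (a > b);
}

// Worst case of insertionDelta over all relative orders of three distinct
// values. (Inserting at either end adds a single pair, so at most one.)
int maxInsertionDelta() {
    int v[3] = {1, 2, 3};
    int worst = -1;
    do {
        worst = max(worst, insertionDelta(v[0], v[1], v[2]));
    } while (next_permutation(v, v + 3));
    return worst;
}
```

Both solutions below perform the same early exit: they count the descents and print N when the count is at most K.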

Core Observation
The crucial observation here is that since the permutation is random, we can expect the K pairs to be evenly distributed in the range from 1 to N.

So, based on the above observation, the optimal subsequence should have approximately i*K/N special pairs among its elements up to position i.

We maintain only the DP values we expect to be relevant: for position i, state (i, x) is more likely to matter the closer x is to i*K/N. So, let’s choose a bound B and discard states (i, x) with |x-\lfloor i*K/N \rfloor| \gt B. This keeps a band of width 2B+1 around the diagonal running from (1, 0) to (N, K); in the tables below, where k grows downwards, it runs from the top-left corner to the bottom-right corner.
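The kept range of k for a position can be computed as below, assuming 1-indexed positions. This mirrors the `lo[i]`/`hi[i]` computation in the tester's solution; the setter instead shifts the band so that it always has 2B+1 entries. `bandRange` is an illustrative name.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
using namespace std;

// Range [lo, hi] of k values kept for 1-indexed position i: centre
// floor(i*K/N), half-width B, clipped to [0, K]. The 64-bit product
// avoids int overflow for large N and K.
pair<int, int> bandRange(int i, int n, int K, int B) {
    assert(1 <= i && i <= n && K >= 0 && B >= 0);
    int c = (int)((long long)i * K / n);
    return {max(0, c - B), min(K, c + B)};
}
```

With N = 8, K = 5 and B = 2, position 1 keeps k in [0, 2] and position 8 keeps k in [3, 5], matching the example below.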

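For example, for the permutation 2 7 6 3 4 1 8 5 and K = 5, the full DP table looks as follows. Column i corresponds to position i, and the row labeled k holds f(i, k):

i:   1 2 3 4 5 6 7 8
k=0: 1 2 2 2 3 1 4 4
k=1: 1 2 3 3 4 4 5 5
k=2: 1 2 3 4 5 5 6 6
k=3: 1 2 3 4 5 6 7 7
k=4: 1 2 3 4 5 6 7 8
k=5: 1 2 3 4 5 6 7 8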

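With B = 2, only states (i, k) with |k-\lfloor i*K/N \rfloor| \leq 2 are kept, and discarded states are shown as -. Some kept values, such as f(6, 1), are smaller than in the full table because their best predecessors were discarded:

i:   1 2 3 4 5 6 7 8
k=0: 1 2 2 2 - - - -
k=1: 1 2 3 3 4 3 - -
k=2: 1 2 3 4 5 5 6 -
k=3: - 1 3 4 5 6 7 7
k=4: - - - 4 5 6 7 8
k=5: - - - - 1 6 7 8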
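The banded table above can be reproduced with the following sketch. For brevity it still scans every j \lt i instead of the pruned transition lists, and it returns the whole (K+1) x N table with -1 marking discarded states; `bandedTable` is a hypothetical name.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>
using namespace std;

// Banded slow DP: only states (i, k) with |k - floor(i*K/N)| <= B are kept
// (i is 1-indexed in that formula). t[k][pos] holds f for 0-indexed pos,
// or -1 if the state was discarded.
vector<vector<int>> bandedTable(const vector<int>& p, int K, int B) {
    int n = (int)p.size();
    assert(n >= 1 && K >= 0 && B >= 0);
    auto kept = [&](int pos, int k) {  // pos is 0-indexed
        long long c = (long long)(pos + 1) * K / n;
        return k >= 0 && k <= K && k >= c - B && k <= c + B;
    };
    vector<vector<int>> t(K + 1, vector<int>(n, -1));
    for (int i = 0; i < n; i++)
        for (int k = 0; k <= K; k++) {
            if (!kept(i, k)) continue;
            int best = 1;  // the single element p[i]
            for (int j = 0; j < i; j++) {
                int kj = p[j] > p[i] ? k - 1 : k;  // state needed at j
                if (kept(j, kj)) best = max(best, t[kj][j] + 1);
            }
            t[k][i] = best;
        }
    return t;
}
```

Even with the narrow band B = 2, the optimum 8 survives in this example, at state (8, 4).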

This way, we intend to store only O(B) states per position, reducing time complexity to O(B*N*log(N)) and space complexity to O(B*N)

If B is chosen too large, the solution gets TLE or MLE; if it is chosen too small, the optimal subsequence for some permutation P may pass through a discarded state, giving a wrong answer. For a sufficiently large B, the probability of this happening is exponentially small, and it decreases further as N grows, by an argument similar to the Chernoff bound.

For this problem, choosing B = 70 was sufficient, and the time limit was kept wide enough to allow even B = 300 to get accepted. During test generation, some random tests requiring B = 138 were found as well.

Implementation

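The DP itself is straightforward. The transition lists can be computed with a segment tree indexed by value: while computing transitions for position i, leaf x stores the position p with P_p = x if p \lt i (and -1 otherwise), and every internal node stores the maximum of its children.

For the case P_j \lt P_i, we repeatedly query the maximum stored position (that is, the rightmost element) among values in [lo, P_i-1], where lo = 0 initially; if q is the returned position, we record the transition from q and update lo = P_q+1, stopping when the query returns -1. For the case P_j \gt P_i, we first find p, the rightmost position with a value smaller than P_i, and then run the same loop over values greater than or equal to lo, starting from lo = P_i+1 and stopping once the returned position is at most p. Each query yields at most one transition, so the number of queries is proportional to the number of transitions, which is O(N*log(N)) in expectation.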
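A compact C++ sketch of this procedure, modeled on the loops in both solutions below but with values assumed to be 1..N and an iterative bottom-up segment tree; `MaxPosTree` and `candidates` are our own names. Unlike the solutions, it inserts P_i after the queries, which is equivalent because the value P_i itself is never queried.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>
using namespace std;

// Segment tree over values 1..n; leaf v holds the position of value v if
// it was inserted, else -1. query(l, r) returns the largest stored
// position among values in [l, r], i.e. the rightmost such element.
struct MaxPosTree {
    int m = 1;
    vector<int> t;
    explicit MaxPosTree(int n) {
        while (m < n + 1) m <<= 1;
        t.assign(2 * m, -1);
    }
    void set(int value, int pos) {
        int v = value + m;
        t[v] = pos;
        for (v >>= 1; v > 0; v >>= 1) t[v] = max(t[2 * v], t[2 * v + 1]);
    }
    int query(int l, int r) const {  // inclusive; an empty range gives -1
        int res = -1;
        for (l += m, r += m + 1; l < r; l >>= 1, r >>= 1) {
            if (l & 1) res = max(res, t[l++]);
            if (r & 1) res = max(res, t[--r]);
        }
        return res;
    }
};

// Pruned transition lists for a permutation p of 1..n (0-indexed positions).
vector<vector<int>> candidates(const vector<int>& p) {
    int n = (int)p.size();
    MaxPosTree st(n);
    vector<vector<int>> cand(n);
    for (int i = 0; i < n; i++) {
        assert(1 <= p[i] && p[i] <= n);
        // P_j < P_i: rightmost value below P_i, then rightmost in (P_q, P_i), ...
        for (int lo = 1;;) {
            int q = st.query(lo, p[i] - 1);
            if (q < 0) break;
            cand[i].push_back(q);
            lo = p[q] + 1;
        }
        // P_j > P_i: only positions after x, the rightmost j with P_j < P_i.
        int x = st.query(1, p[i] - 1);
        for (int lo = p[i] + 1;;) {
            int q = st.query(lo, n);
            if (q <= x) break;
            cand[i].push_back(q);
            lo = p[q] + 1;
        }
        st.set(p[i], i);  // every stored position is now < the next i
    }
    return cand;
}
```

The lists contain the same positions as a direct application of the two pruning conditions, only in a different order.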

TIME COMPLEXITY

To generate the lists of transitions, we find O(N*log(N)) transitions in expectation, each taking O(log(N)) operations, leading to O(N*log^2(N)).

The DP takes time O(B*N*log(N)) for chosen B.

Hence, the time complexity is O(N*log^2(N)+B*N*log(N)) and the space complexity is O(N*log(N)+B*N)

SOLUTIONS

Setter's Solution
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <cmath>
#include <vector>
#include <set>
#include <map>
#include <unordered_set>
#include <unordered_map>
#include <queue>
#include <ctime>
#include <cassert>
#include <complex>
#include <string>
#include <cstring>
#include <chrono>
#include <random>
#include <bitset>
using namespace std;

#ifdef LOCAL
    #define eprintf(...) fprintf(stderr, __VA_ARGS__);fflush(stderr);
#else
    #define eprintf(...) 42
#endif

using ll = long long;
using ld = long double;
using uint = unsigned int;
using ull = unsigned long long;
template<typename T>
using pair2 = pair<T, T>;
using pii = pair<int, int>;
using pli = pair<ll, int>;
using pll = pair<ll, ll>;
mt19937_64 rng(chrono::steady_clock::now().time_since_epoch().count());
ll myRand(ll B) {
    return (ull)rng() % B;
}

#define pb push_back
#define mp make_pair
#define all(x) (x).begin(),(x).end()
#define fi first
#define se second

clock_t startTime;
double getCurrentTime() {
    return (double)(clock() - startTime) / CLOCKS_PER_SEC;
}

const int M = (1 << 16);
struct Node {
    int val;
    int l, r;

    Node() : val(-1), l(), r() {}
    Node(int _l, int _r) : val(-1), l(_l), r(_r) {}
};
Node tree[2 * M + 5];
void build() {
    for (int i = 0; i < M; i++)
	    tree[M + i] = Node(i, i + 1);
    for (int i = M - 1; i > 0; i--)
	    tree[i] = Node(tree[2 * i].l, tree[2 * i + 1].r);
}

void setPoint(int v, int x) {
    v += M;
    tree[v].val = x;
    while(v > 1) {
	    v >>= 1;
	    tree[v].val = max(tree[2 * v].val, tree[2 * v + 1].val);
    }
}
int getMax(int v, int l, int r) {
    if (l <= tree[v].l && tree[v].r <= r) return tree[v].val;
    if (l >= tree[v].r || tree[v].l >= r) return -1;
    return max(getMax(2 * v, l, r), getMax(2 * v + 1, l, r));
}

const int N = 50100;
const int L = 100;
int n, k;
int a[N];
int b[N];
int dp[N][2 * L + 1];

void makeTrans(int v, int u) {
    for (int i = 0; i <= 2 * L; i++) {
	    int x = i + b[v] - b[u];
	    x += (int)(a[v] > a[u]);
	    if (x < 0 || x > 2 * L) continue;
	    dp[u][x] = max(dp[u][x], dp[v][i] + 1);
    }
}

int main()
{
    startTime = clock();
//	freopen("input.txt", "r", stdin);
//	freopen("output.txt", "w", stdout);

    build();

    scanf("%d%d", &n, &k);
    for (int i = 0; i < n; i++)
	    scanf("%d", &a[i]);
    int tot = 0;
    for (int i = 1; i < n; i++) {
	    tot += (int)(a[i] < a[i - 1]);
    }
    if (tot <= k) {
	    printf("%d\n", n);
	    return 0;
    }
    for (int i = 0; i < n; i++) {
	    b[i] = (ll)(i + 1) * k / n;  // widen before multiplying to avoid int overflow
	    b[i] = min(b[i], k - L);
	    b[i] = max(b[i], L);
	    for (int j = 0; j <= 2 * L; j++)
		    dp[i][j] = 1;
    }
    for (int i = 0; i < n; i++) {
	    setPoint(a[i], i);
	    int l = 0;
	    while(true) {
		    int p = getMax(1, l + 1, a[i]);
		    if (p < 0) break;
		    l = a[p];
		    makeTrans(p, i);
	    }
	    int q = getMax(1, 1, a[i]);
	    l = a[i];
	    while(true) {
		    int p = getMax(1, l + 1, M);
		    if (p <= q) break;
		    l = a[p];
		    makeTrans(p, i);
	    }
    }
    int ans = 1;
    for (int i = 0; i < n; i++)
	    for (int j = 0; j <= 2 * L; j++) {
		    if (b[i] - L + j > k) break;
		    ans = max(ans, dp[i][j]);
	    }
    printf("%d\n", ans);

    return 0;
}
Tester's Solution
import java.util.*;
import java.io.*;
class GRAND{
    //SOLUTION BEGIN
    void pre() throws Exception{}
    void solve(int TC) throws Exception{
        int N = ni(), K = ni();
        int[] A = new int[N];
        for(int i = 0; i< N; i++)A[i] = ni()-1;
        int count = 0;
        for(int i = 1; i< N; i++)if(A[i] < A[i-1])count++;
        if(count <= K){
            pn(N);
            return;
        }
        
        SegmentTree st = new SegmentTree(N);
        List<Integer>[] choices = new ArrayList[N];
        for(int i = 0; i< N; i++){
            choices[i] = new ArrayList<>();
            st.update(A[i], i);
            //A[j] < A[i]
            int lo = 0;
            while(true){
                int p = st.query(lo, A[i]-1);
                if(p < 0)break;
                choices[i].add(p);
                lo = A[p]+1;
            }
            //A[j] > A[i]
            int x = st.query(0, A[i]-1);
            lo = A[i]+1;
            while(true){
                int p = st.query(lo, N-1);
                if(p <= x)break;
                choices[i].add(p);
                lo = A[p]+1;
            }
        }
        
        int ans = 1;
        int B = 100;//Width of Diagonal
        int[][] DP = new int[N][];
        int[] lo = new int[N], hi = new int[N];
        for(int i = 0; i< N; i++){
            int x = (int)(((i+1)*(long)K)/N);
            lo[i] = Math.max(0, x-B);
            hi[i] = Math.min(K, x+B);
            DP[i] = new int[hi[i]-lo[i]+1];
            Arrays.fill(DP[i], 1);
        }
        
        for(int i = 1; i< N; i++){
            for(int j:choices[i]){
                int op = A[j] > A[i]?1:0;
                for(int o2 = 0; o2 < DP[i].length; o2++){
                    int o1 = o2+lo[i]-op-lo[j];
                    if(o1 < 0 || o1 > hi[j]-lo[j])continue;
                    DP[i][o2] = Math.max(DP[i][o2], 1+DP[j][o1]);
                    ans = Math.max(ans, DP[i][o2]);
                }
            }
        }
        pn(ans);
//        int[][] DP = new int[N][1+K];
//        for(int i = 0; i< N; i++)Arrays.fill(DP[i], 1);
//        int ans = 0;
//        for(int i = 1; i< N; i++){
//            for(int x:choices[i]){
//                int d = A[x] > A[i]?1:0;
//                for(int k = d; k <= K; k++){
//                    DP[i][k] = Math.max(DP[i][k], 1+DP[x][k-d]);
//                    ans = Math.max(ans, DP[i][k]);
//                }
//            }
//        }
//        pn(ans);
    }
    void dbg(Object... o){System.err.println(Arrays.deepToString(o));}
    class SegmentTree{
        private int initValue(){return -1;}
        private int u(int oldValue, int newValue){return newValue;}
        private int merge(int le, int ri){return Math.max(le, ri);}
        private int initQuery(){return -1;}

        private int m= 1;
        private int[] t;
        public SegmentTree(int n){
            while(m<n)m<<=1;
            t = new int[m<<1];
            Arrays.fill(t, initValue());
        }
        public void update(int i, int val){
            t[i += m]  = u(t[i], val);
            for(i>>=1;i>0;i>>=1)t[i] = merge(t[i<<1], t[i<<1|1]);
        }
        public int query(int l, int r){
            int lans = initQuery(), rans = initQuery();
            for(l+=m,r+=m+1;l<r;l>>=1,r>>=1){
                if((l&1)==1)lans = merge(lans, t[l++]);
                if((r&1)==1)rans = merge(t[--r], rans);
            }
            return merge(lans, rans);
        }
    }
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    static boolean multipleTC = false;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        new GRAND().run();
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}

    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }

        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }

        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException  e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }

        String nextLine() throws Exception{
            String str = "";
            try{   
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }  
            return str;
        }
    }
}

Feel free to share your approach. Suggestions are welcome, as always. :slight_smile:

What is the problem in this code?
I am getting Wrong Answer.
Which case am I getting wrong?
Please help.