PIARQ - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: apoorv_me
Tester: apoorv_me
Editorialist: iceknight1093

DIFFICULTY:

2806

PREREQUISITES:

Stacks, segment trees, answering queries offline

PROBLEM:

Given an array A of N elements, answer Q queries on it:

  • Given L and R, find the length of the longest subsequence of [A_L, A_{L+1}, \ldots, A_R] that’s partially increasing, or -1 if no such subsequence exists.

A sequence is partially increasing if it has three consecutive elements that are non-decreasing, i.e. some three adjacent elements x, y, z with x \leq y \leq z.
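When debugging, a brute-force reference for this definition can be handy. The sketch below uses hypothetical helper names (`is_partially_increasing`, `longest_pi_brute`). It tries every subsequence from largest to smallest and returns -1 when none qualifies, which is the convention used later in this editorial. It is exponential, so only use it on tiny arrays.

```python
from itertools import combinations

def is_partially_increasing(seq):
    """True if some three consecutive elements of seq are non-decreasing."""
    return any(seq[t] <= seq[t + 1] <= seq[t + 2] for t in range(len(seq) - 2))

def longest_pi_brute(arr):
    """Length of the longest partially increasing subsequence of arr, or -1.

    Exponential: tries every index subset, largest size first. Tiny inputs only.
    """
    n = len(arr)
    for size in range(n, 2, -1):
        for idx in combinations(range(n), size):
            if is_partially_increasing([arr[t] for t in idx]):
                return size
    return -1

print(longest_pi_brute([3, 3, 14, 21]))  # 4
print(longest_pi_brute([23, 21, 2, 3]))  # -1
```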

EXPLANATION:

First, we need to figure out how to answer a single query.
That is, given an array A, find its longest partially increasing subsequence.

To do that, note that the subsequence needs to include three indices i, j, k such that i \lt j \lt k and A_i \leq A_j \leq A_k.
Let’s call such a triplet (i, j, k) a good triplet.
Further, these three must be adjacent in the subsequence, so it can’t contain anything strictly between i and j, or strictly between j and k. However, anything to the left of i and anything to the right of k can be included freely, so we might as well include all of it.
This means that, for a fixed good triplet (i, j, k), the resulting partially increasing subsequence has length N - (k - i) + 2. The only discarded elements are the k - i - 2 indices strictly between i and k other than j.
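The reduction to good triplets can be checked directly with a cubic-time sketch. It enumerates every good triplet (i, j, k) and scores it as N - (k - i) + 2. The function name is hypothetical and indices are 0-based. It is only meant to confirm the formula on small inputs, not to serve as a solution.

```python
def longest_via_triplets(arr):
    """Max of N - (k - i) + 2 over good triplets i < j < k with
    arr[i] <= arr[j] <= arr[k]; -1 if there is none. O(N^3)."""
    n = len(arr)
    best = -1
    for i in range(n):
        for j in range(i + 1, n):
            if arr[i] > arr[j]:
                continue
            for k in range(j + 1, n):
                if arr[j] <= arr[k]:
                    # keep everything outside (i, k), plus j itself
                    best = max(best, n - (k - i) + 2)
    return best

print(longest_via_triplets([5, 1, 6, 2, 7]))  # 4
```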

So, our task is really to find a good triplet (i, j, k) for which (k-i) is as small as possible, since that is exactly what maximizes the length.
To do this, notice that once j is fixed, the best choices of i and k are uniquely determined, independently of each other:

  • k is the smallest index \gt j such that A_j \leq A_k.
  • i is the largest index \lt j such that A_i \leq A_j.

So, if we find for each index the closest element to its right that’s \geq it, and the closest element to its left that’s \leq it, we can answer a single query in \mathcal{O}(N) time by just iterating across all j.
Finding these quantities in linear time is a well-known problem and can be done using a stack.
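Here is a minimal sketch of the stack technique. It assumes 0-based indices and uses -1 and N as "no such index" sentinels, which is the same convention as the author's code. The names `prev_next`, `prev` and `nxt` are illustrative.

The pop conditions are deliberately asymmetric. The left pass pops strictly greater values and the right pass pops strictly smaller ones, so equal values remain valid neighbours. Every index is pushed and popped at most once, which gives \mathcal{O}(N) overall.

```python
def prev_next(a):
    """Return (prev, nxt) for a 0-indexed list a:
    prev[j] = largest i < j with a[i] <= a[j], or -1 if none;
    nxt[j]  = smallest k > j with a[k] >= a[j], or len(a) if none.
    """
    n = len(a)
    prev, nxt = [-1] * n, [n] * n
    stack = []
    # Left to right: an index whose value exceeds a[j] is never the answer
    # for j or for anything after j (j is closer and not larger), so pop it.
    for j in range(n):
        while stack and a[stack[-1]] > a[j]:
            stack.pop()
        if stack:
            prev[j] = stack[-1]
        stack.append(j)
    stack.clear()
    # Right to left, symmetric: pop indices whose value is smaller than a[j].
    for j in range(n - 1, -1, -1):
        while stack and a[stack[-1]] < a[j]:
            stack.pop()
        if stack:
            nxt[j] = stack[-1]
        stack.append(j)
    return prev, nxt

print(prev_next([3, 1, 2, 2]))  # ([-1, -1, 1, 2], [4, 2, 3, 4])
```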


Now that we have a solution in \mathcal{O}(N) per query, let’s try to optimize it.
Clearly, the “previous smaller-or-equal” and “next greater-or-equal” indices are what matter, so let’s precompute them in \mathcal{O}(N) time and store them in arrays \text{prev} and \text{next}.

Now, when answering a single query (L, R), we’d like to do the following:

  • For each L \lt j \lt R, look at \text{prev}[j] and \text{next}[j].
  • If \text{prev}[j] \lt L or \text{next}[j] \gt R, then j can’t be the ‘midpoint’ of a valid triplet inside [L, R] at all, so ignore it.
  • Otherwise, L \leq \text{prev}[j] \lt \text{next}[j] \leq R. As discussed earlier, j then gives a subsequence of length (R-L+1) - (\text{next}[j] - \text{prev}[j]) + 2, and the answer is the maximum of this over all such j.

In particular, for range [L, R], we only care about indices j such that \text{prev}[j] \geq L and \text{next}[j] \leq R.
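The per-query rule above can be sketched as follows. This assumes the `prev`/`nxt` arrays were already computed with 0-based indices and -1 / N sentinels; the names are illustrative. Each query costs \mathcal{O}(R - L), and that cost is exactly the bottleneck the offline approach removes.

```python
def answer_query_naive(prev, nxt, L, R):
    """Longest partially increasing subsequence of a[L..R] (0-based, inclusive), or -1.

    prev[j]: largest i < j with a[i] <= a[j] (or -1);
    nxt[j]: smallest k > j with a[k] >= a[j] (or n). Runs in O(R - L).
    """
    best = -1
    for j in range(L + 1, R):
        # j is a usable midpoint only if both of its neighbours lie inside [L, R]
        if prev[j] >= L and nxt[j] <= R:
            best = max(best, (R - L + 1) - (nxt[j] - prev[j]) + 2)
    return best

# a = [3, 1, 2, 2] gives prev = [-1, -1, 1, 2], nxt = [4, 2, 3, 4]
print(answer_query_naive([-1, -1, 1, 2], [4, 2, 3, 4], 0, 3))  # 4
```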

Answering these queries online is hard, so let’s answer them offline instead.
We’ll iterate the right endpoint R from 1 to N. For a fixed R, let’s try to answer all queries (L, R).
That can be done as follows:

  • First, “activate” all indices j for which \text{next}[j] = R.
    Since we’re going in increasing order of R, after this step the active indices are exactly those with \text{next}[j] \leq R; everything else is inactive.
    This takes care of the constraint on \text{next}. It remains to decide what activating an index actually does, so that the constraint on \text{prev} is also respected when querying.
  • Look at some index j that we’re activating, and in particular at x = \text{prev}[j].
    This index is only usable for queries (L, R) with L \leq x, as per our constraint on \text{prev}.
    Also, its contribution to the answer is determined by \text{next}[j] - \text{prev}[j], since that’s the only part of the formula that depends on j.
    So, let’s make an array B of length N, with B_i = \infty initially.
    Then, when activating j, we set B_{\text{prev}[j]} to \text{next}[j] - \text{prev}[j] if it hasn’t been set yet. Indices are activated in increasing order of \text{next}[j], so the first value written to a position is already the smallest it will ever receive. This is equivalent to taking the minimum.
  • Finally, to answer a query (L, R) ending at R, we just query for the minimum element of B in the range [L, R]; say \text{mn}.
    The answer to this query is just (R-L+1) - \text{mn} + 2.
    If \text{mn} is \infty, the answer is -1.
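To illustrate the sweep before any data structure is involved, here is a sketch that follows the steps above literally. It takes the minimum of B with a plain slice, so each query still costs \mathcal{O}(N). It assumes 0-based indices, `prev`/`nxt` arrays with -1 / N sentinels, and queries given as 0-based inclusive (L, R) pairs; all names are hypothetical. When activating, it uses min(B_x, …), which is equivalent to "set only if unset" for the reason given above.

```python
import math

def answer_offline_naive_min(prev, nxt, queries):
    """Offline sweep over R; queries are 0-based inclusive (L, R) pairs.

    prev/nxt: 0-based, with -1 / n meaning "no such index".
    The range minimum over B is a linear slice here (O(N) per query).
    """
    n = len(prev)
    INF = math.inf
    # Midpoints j grouped by next[j]; j is usable only if both neighbours exist.
    by_next = [[] for _ in range(n)]
    for j in range(n):
        if prev[j] >= 0 and nxt[j] < n:
            by_next[nxt[j]].append(j)
    # Queries grouped by their right endpoint.
    by_right = [[] for _ in range(n)]
    for qi, (L, R) in enumerate(queries):
        by_right[R].append((L, qi))

    B = [INF] * n
    ans = [-1] * len(queries)
    for R in range(n):
        # Activate every j with next[j] == R, recording its width at x = prev[j].
        for j in by_next[R]:
            x = prev[j]
            B[x] = min(B[x], nxt[j] - prev[j])
        # Finite entries of B[L..R] are exactly the triplets lying inside [L, R].
        for L, qi in by_right[R]:
            mn = min(B[L:R + 1])
            if mn != INF:
                ans[qi] = (R - L + 1) - mn + 2
    return ans

# a = [3, 1, 2, 2] has prev = [-1, -1, 1, 2] and nxt = [4, 2, 3, 4]
print(answer_offline_naive_min([-1, -1, 1, 2], [4, 2, 3, 4], [(0, 3), (0, 2), (1, 3)]))  # [4, -1, 3]
```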

So, we need a data structure that supports point updates and range-minimum queries, which a segment tree handles. (The editorialist's code below equivalently stores the negated values \text{prev}[j] - \text{next}[j] and takes range maxima.)
Each element causes at most one segment tree update, and each query requires one segment tree query, so the overall complexity is \mathcal{O}((N+Q)\log N).
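One possible range-minimum structure is an iterative (bottom-up) segment tree. This sketch assumes every slot starts at +\infty and only ever decreases, which is all the sweep needs. `MinSegmentTree`, `chmin` and `query` are illustrative names. In the sweep described above, each activation becomes one `chmin` call and the minimum over B_L..B_R becomes `query(L, R)`, giving the \mathcal{O}((N+Q)\log N) bound.

```python
import math

class MinSegmentTree:
    """Point chmin-update / range-min query over n slots, all starting at +inf."""

    def __init__(self, n):
        self.size = 1
        while self.size < n:
            self.size *= 2
        self.tree = [math.inf] * (2 * self.size)

    def chmin(self, pos, value):
        """B[pos] = min(B[pos], value), then refresh ancestors. O(log n)."""
        i = pos + self.size
        if value >= self.tree[i]:
            return  # nothing changes; ancestors are already <= this leaf
        self.tree[i] = value
        i //= 2
        while i:
            self.tree[i] = min(self.tree[2 * i], self.tree[2 * i + 1])
            i //= 2

    def query(self, lo, hi):
        """Minimum of B[lo..hi] (inclusive). O(log n)."""
        res = math.inf
        lo += self.size
        hi += self.size + 1  # switch to the half-open range [lo, hi)
        while lo < hi:
            if lo & 1:
                res = min(res, self.tree[lo])
                lo += 1
            if hi & 1:
                hi -= 1
                res = min(res, self.tree[hi])
            lo //= 2
            hi //= 2
        return res

st = MinSegmentTree(5)
st.chmin(1, 4)
st.chmin(3, 2)
print(st.query(0, 2))  # 4
print(st.query(2, 4))  # 2
```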

TIME COMPLEXITY

\mathcal{O}((N+Q)\log N) per testcase.

CODE:

Author's code (C++)
#ifndef LOCAL
#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx,avx2,sse,sse2,sse3,sse4,popcnt,fma")
#endif

#include <bits/stdc++.h>
using namespace std;

#ifdef LOCAL
#include "../../debug.h"
#else
#define dbg(...) "11-111"
#endif

#define sz(a) (int)(a).size()

// Despite its name, `sum` stores the minimum over a segment.
struct node{
    int sum;
    node(int x = 0) {
        sum = x;
    }
    void add(int x) {
        sum = min(x, sum);
    }
    friend node merge(const node &a, const node &b) {
        node tmp;
        tmp.sum = min(a.sum, b.sum);
        return tmp;
    }
};

struct segtree{
    int n;
    vector<node> tree;
    node __default;
    segtree(const vector<int> &a) {
        __default = 1e9 + 2;
        this -> n = sz(a);
        tree.resize(sz(a) << 2);
        build(a, 0, sz(a) - 1, 1);
    }

    void build(const vector<int> &a, int l, int r, int v) {
        if(l == r) {
            tree[v] = a[l];
            return;
        }
        int m = l + r >> 1;
        build(a, l, m, 2 * v);
        build(a, m + 1, r, 2 * v + 1);
        tree[v] = merge(tree[2 * v], tree[2 * v + 1]);
    }

    void update(int ind, int val, int v = 1, int l = 0, int r = -1) {
        if(r < 0)   r += n;
        if(l == r) {
            assert(l == ind);
            tree[v].add(val);
            return;
        }
        int m = l + r >> 1;
        if(ind > m)
            update(ind, val, 2 * v + 1, m + 1, r);
        else
            update(ind, val, 2 * v, l, m);

        tree[v] = merge(tree[2 * v], tree[2 * v + 1]);
    }

    node query(int l, int r, int v = 1, int lmost = 0, int rmost = -1) {
        if(rmost < 0)   rmost += n;
        if(r < lmost || l > rmost) {
            return __default;
        }
        if(l <= lmost && r >= rmost)
            return tree[v];
        int mid = lmost + rmost >> 1;
        return merge(query(l, r, 2 * v, lmost, mid), query(l, r, 2 * v + 1, mid + 1, rmost));
    }
};

int32_t main() {
    ios_base::sync_with_stdio(0);   cin.tie(0);

    
    auto __solve_testcase = [&](int testcase) -> void {
        int n, nq;  cin >> n >> nq;
        vector<int> a(n);
        for(auto &i: a)     cin >> i;

        vector<int> nx(n, n), pv(n, -1), stk;
        for(int i = 0 ; i < n ; i++) {
            while(!stk.empty() && a[stk.back()] > a[i])
                stk.pop_back();

            if(!stk.empty())    pv[i] = stk.back();
            stk.push_back(i);
        }
        stk.clear();
        for(int i = n - 1 ; i >= 0 ; i--) {
            while(!stk.empty() && a[stk.back()] < a[i])
                stk.pop_back();
            if(!stk.empty())    nx[i] = stk.back();
            stk.push_back(i);
        }

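        // For each left index x, keep only the smallest right index k = nx[j]
        // over all midpoints j with pv[j] == x: for a fixed x, a closer k is always better.
        vector<int> upd(n, n);
        for(int i = 0 ; i < n ; i++)    if(pv[i] != -1 && nx[i] < n) {
            upd[pv[i]] = min(upd[pv[i]], nx[i]);
        }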

        vector<vector<pair<int, int>>> que(n);
        vector<int> ans(nq);
        for(int i = 0 ; i < nq ; i++) {
            int l, r;   cin >> l >> r;
            que[l - 1].emplace_back(r - 1, i);
        }

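        // Mirror image of the editorial's sweep: the left end L = i moves from right to left.
        // When x = i is reached, its best triplet is stored at position k = upd[i] with value (k - x) - 2,
        // so querying positions [i, r] covers the best triplet of every x >= i whose k is <= r.
        segtree sg(vector<int>(n, 1e9));

        for(int i = n - 1 ; i >= 0 ; i--) {
            if(upd[i] < n) {
                sg.update(upd[i], upd[i] - i - 2);
            }
            // If nothing fits, the 1e9 sentinel makes the expression negative and it is clamped to -1.
            for(auto &[r, ind] : que[i]) {
                ans[ind] = max(-1, (r - i + 1) - (sg.query(i, r).sum));
            }
        }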

        for(int i = 0 ; i < nq ; i++)   cout << ans[i] << '\n';
    };
    
    int no_of_tests;   cin >> no_of_tests;
    for(int test_no = 1 ; test_no <= no_of_tests ; test_no++)
        __solve_testcase(test_no);
    

    return 0;
}
Editorialist's code (Python)
class SegmentTree:
    def __init__(self, data, default=0, func=max):
        """initialize the segment tree with data"""
        self._default = default
        self._func = func
        self._len = len(data)
        self._size = _size = 1 << (self._len - 1).bit_length()

        self.data = [default] * (2 * _size)
        self.data[_size:_size + self._len] = data
        for i in reversed(range(_size)):
            self.data[i] = func(self.data[i + i], self.data[i + i + 1])

    def __delitem__(self, idx):
        self[idx] = self._default

    def __getitem__(self, idx):
        return self.data[idx + self._size]

    def __setitem__(self, idx, value):
        idx += self._size
        self.data[idx] = value
        idx >>= 1
        while idx:
            self.data[idx] = self._func(self.data[2 * idx], self.data[2 * idx + 1])
            idx >>= 1

    def __len__(self):
        return self._len

    def query(self, start, stop):
        """func of data[start, stop)"""
        start += self._size
        stop += self._size

        res_left = res_right = self._default
        while start < stop:
            if start & 1:
                res_left = self._func(res_left, self.data[start])
                start += 1
            if stop & 1:
                stop -= 1
                res_right = self._func(self.data[stop], res_right)
            start >>= 1
            stop >>= 1

        return self._func(res_left, res_right)

    def __repr__(self):
        return "SegmentTree({0})".format(self.data)

sumn, sumq = 0, 0
tests = int(input())
if tests > 50000: exit(0)

for _ in range(tests):
    n, q = map(int, input().split())
    sumn += n
    sumq += q
    if n < 3 or sumn > 200000 or sumq > 200000: exit(0)
    if n <= 0 or q <= 0: exit(0)
    
    a = [0] + list(map(int, input().split())) + [10**9]
    if min(a[1:]) < 1 or max(a[1:n+1]) > 10**9: exit(0)
    
    left, right = [0]*(n+2), [n+1]*(n+2)
    stk = []
    queries = [ [] for _ in range(n+2) ]
    updates = [ [] for _ in range(n+2) ]
    for i in range(1, n+1):
        while stk:
            if a[stk[-1]] > a[i]: stk.pop()
            else: break
        if stk: left[i] = stk[-1]
        stk.append(i)
    stk = []
    for i in reversed(range(1, n+1)):
        while stk:
            if a[stk[-1]] < a[i]: stk.pop()
            else: break
        if stk: right[i] = stk[-1]
        stk.append(i)
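        if left[i] >= 1 and right[i] <= n:
            # i is a valid midpoint: store prev - next (the negated width) at the
            # left index, to be applied once the sweep reaches R = right[i]
            updates[right[i]].append((left[i], left[i] - right[i]))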
    
    for i in range(q):
        L, R = map(int, input().split())
        if L <= 0 or R <= 0 or L > n or R > n or L > R: exit(0)
        queries[R].append((L, i))
    ans = [-1]*q
    seg = SegmentTree([-10**9]*(n+2), -10**9)

    for i in range(1, n+1):
        for pos, val in updates[i]:
            if seg[pos] < val: seg[pos] = val
        for L, id in queries[i]:
            mx = seg.query(L, i+1)
            if mx < -10**6: continue
            ans[id] = i-L+3 + mx
    for x in ans: print(x)

Sample test case, if anyone is getting WA:
1
12 7
19 25 24 23 21 2 3 3 14 21 1 23
4 7
8 10
9 9
7 10
9 10
10 10
11 11

I solved it online using a merge-sort tree, if anyone wants to see it: source code