PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: kugo
Tester: yash_daga
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
Dynamic programming, range-max queries
PROBLEM:
You’re given an array S. For Q queries of (L, R), answer the following:
- For a fixed value of D, you start from L with a score of D and move towards R. At each index, S_i \leq D must hold. Further, if S_i = D, reduce D by 1.
Find the smallest value of D that allows you to reach R from L.
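For small inputs, the process can be checked by direct simulation. The following is a minimal Python sketch, assuming 0-indexed positions and an inclusive range [l, r]. `can_reach` and `smallest_D_bruteforce` are illustrative names, not part of any of the solutions below. It tries D upwards until the walk succeeds, so it is only meant as a reference for testing.

```python
def can_reach(S, l, r, D):
    """Walk from index l to index r (0-indexed, inclusive) starting with value D."""
    for i in range(l, r + 1):
        if S[i] > D:
            return False  # the rule S_i <= D is violated at index i
        if S[i] == D:
            D -= 1  # meeting a value equal to D lowers D by one
    return True


def smallest_D_bruteforce(S, l, r):
    """Smallest starting D that reaches r from l, found by trying D upwards."""
    # D never increases, so every S_i in the range must be <= the starting D;
    # starting the search at the range maximum therefore skips only failures.
    D = max(S[l:r + 1])
    while not can_reach(S, l, r, D):
        D += 1
    return D


print(smallest_D_bruteforce([2, 1, 2], 0, 2))  # 3
```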
EXPLANATION:
Let M = \max(S_L, S_{L+1}, \ldots, S_R).
Clearly, the answer to query (L, R) must be at least M, since we will encounter an M along our way, and D must be \geq M when we do.
Further, it’s also easy to see that D = M+1 always allows us to reach R: it’s strictly larger than everything in this range, so it will never be reduced.
So, our task reduces to checking whether D = M allows us to reach R or not.
Let’s see how this process would go.
- Let i_1 \geq L be the first index such that S_{i_1} = M. We will definitely reach this index, after which we’d be left with D = M-1.
- From here, we move right until we:
  - Reach R, at which point we’re done;
  - Reach an occurrence of something \gt D, at which point we can’t continue; or
  - Reach another occurrence of D, at which point we move to D-1; and then this process repeats.
This is easy to simulate in \mathcal{O}(N) per query, but with Q queries that is too slow, so we need to do better.
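A Python sketch of this \mathcal{O}(N)-per-query walk, under the same assumptions (0-indexed, inclusive range; `answer_linear` is an illustrative name). It only tries D = M and falls back to M+1 as soon as the walk gets stuck, which is justified by the two bounds above.

```python
def answer_linear(S, l, r):
    """Answer query (l, r) (0-indexed, inclusive) in O(r - l + 1) time."""
    M = max(S[l:r + 1])
    i1 = S.index(M, l, r + 1)  # first occurrence of the maximum in [l, r]
    # Everything before i1 is < M, so D = M is untouched until i1,
    # where it drops to M - 1.
    D = M - 1
    for i in range(i1 + 1, r + 1):
        if S[i] > D:
            return M + 1  # stuck with D = M; M + 1 always works
        if S[i] == D:
            D -= 1  # another occurrence of the current D
    return M  # reached r starting with D = M


print(answer_linear([1, 2, 1], 0, 2))  # 2
print(answer_linear([3, 2, 3], 0, 2))  # 4
```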
Notice that once we reach the first position where S_i = M, the rest of the process is uniquely determined by that position, because we can pretend we start there with D = M.
So, for each index i, let dp_i be the furthest index we can reach if we start from i with D = S_i.
dp_i can be computed as follows:
- Let j \gt i be the first index such that S_j \geq S_i (if there is none, take j = N+1).
- Let k \gt i be the first index such that S_k = S_i - 1 (if there is none, drop the dp_k term below).
- Then, dp_i = \min(j-1, dp_k). After passing index i we have D = S_i - 1, and D never increases, so we can’t cross index j no matter what; and if we reach k we’re essentially starting from there instead, so we can’t go beyond dp_k either.
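Before optimizing, the recurrence can be transcribed directly. This Python sketch is 0-indexed and deliberately \mathcal{O}(N^2); `dp_from_definition` is an illustrative name. A missing j is treated as index n (one past the end, so j-1 is the last index), and a missing k drops the dp_k term.

```python
def dp_from_definition(S):
    """dp[i] = furthest index reachable starting at i with D = S[i] (0-indexed).

    Direct O(N^2) transcription of dp_i = min(j - 1, dp_k).
    """
    n = len(S)
    dp = [0] * n
    for i in reversed(range(n)):  # dp[k] is final for every k > i
        # j: first index after i with S[j] >= S[i]; n if none, so j - 1 = n - 1
        j = next((t for t in range(i + 1, n) if S[t] >= S[i]), n)
        best = j - 1
        # k: first index after i with S[k] == S[i] - 1; None if there is none
        k = next((t for t in range(i + 1, n) if S[t] == S[i] - 1), None)
        if k is not None:
            best = min(best, dp[k])
        dp[i] = best
    return dp


print(dp_from_definition([3, 2, 3]))  # [1, 1, 2]
```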
dp_i can be precomputed quite easily for all indices by iterating i from N down to 1. The only things that need to be taken care of are quickly computing indices j and k.
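- Computing k is easy. Since we’re iterating in reverse anyway, simply keep a record (using, say, a map) of the most recently seen occurrence of every value, i.e., its nearest occurrence to the right of i. Let this be \text{last}[x]; then k = \text{last}[S_i - 1] (if S_i - 1 hasn’t been seen yet, there is no k).
Note that after processing index i, you should set \text{last}[S_i] = i.
- Computing j needs a bit more effort in general, but for this specific problem it can actually be done more easily.
Note that rather than finding the next element that’s \geq S_i, it suffices to find the next element that’s equal to S_i, i.e., \text{last}[S_i] itself. Do you see why? (Hint: dp is only ever read at the leftmost maximum of a query range, so no element larger than S_i appears before R; and any blocking element that is \leq S_i equals the starting value of some earlier step of the chain i \to k \to \ldots, so that step’s "next equal value" bound already stops the walk there.)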
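At any rate, by maintaining the \text{last} map, all the dp values can be computed in a single right-to-left pass: \mathcal{O}(N \log N) with an ordered map, or expected \mathcal{O}(N) with a hash map.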
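A Python sketch of this pass (0-indexed; all function names here are illustrative). `dp_with_last` uses \text{last}[S_i] in place of j, as suggested above; the helper functions then compare, on every range of every small array, the check dp_i \geq R at the leftmost maximum against a direct simulation. The exhaustive check over small arrays is evidence for the relaxation, not a proof.

```python
from itertools import product


def dp_with_last(S):
    """dp via the `last` map (0-indexed), using last[S[i]] in place of j."""
    n = len(S)
    dp = [0] * n
    last = {}  # last[x] = nearest index to the right of i holding value x
    for i in reversed(range(n)):
        dp[i] = n - 1  # no bound found yet: the walk could reach the end
        if S[i] in last:  # next equal value, standing in for j
            dp[i] = min(dp[i], last[S[i]] - 1)
        if S[i] - 1 in last:  # k = last[S[i] - 1]
            dp[i] = min(dp[i], dp[last[S[i] - 1]])
        last[S[i]] = i  # record i only after dp[i] is done
    return dp


def reachable_with_max(S, l, r):
    """Direct simulation of the walk l..r starting with D = max(S[l..r])."""
    D = max(S[l:r + 1])
    for i in range(l, r + 1):
        if S[i] > D:
            return False
        if S[i] == D:
            D -= 1
    return True


def relaxation_is_safe(S):
    """True if reading dp at the leftmost maximum matches simulation on all ranges."""
    dp = dp_with_last(S)
    for l in range(len(S)):
        for r in range(l, len(S)):
            M = max(S[l:r + 1])
            i = S.index(M, l, r + 1)  # leftmost maximum of [l, r]
            if (dp[i] >= r) != reachable_with_max(S, l, r):
                return False
    return True


def exhaustive_check(max_len, num_values):
    """Run relaxation_is_safe on every array of length 1..max_len over 0..num_values-1."""
    return all(relaxation_is_safe(list(arr))
               for n in range(1, max_len + 1)
               for arr in product(range(num_values), repeat=n))


# Index 1 differs from the exact dp [1, 1, 2], but index 1 is never the
# leftmost maximum of a range that includes index 2.
print(dp_with_last([3, 2, 3]))  # [1, 2, 2]
print(exhaustive_check(5, 4))   # True
```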
After this, for a query (L, R) we need to:
- Find the maximum M on the range [L, R], which is a standard data structure task (use a segment tree/sparse table/whatever).
- Find the leftmost occurrence of this maximum, which is also not too hard: for example, the segment tree/sparse table merge function can be modified slightly to return this, by maintaining pairs of (value, index).
- Let i be the index we found. Then, if dp_i \geq R the answer is M, otherwise the answer is M+1.
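The query step can be sketched in Python as well. This is an illustrative sketch, not any of the solutions below: it builds a sparse table over pairs (S_i, -i), so that Python’s tuple comparison prefers the larger value and, on ties, the smaller index, and it takes the dp array as an input. `LeftmostMaxTable` and `answer_query` are illustrative names; positions are 0-indexed and ranges inclusive.

```python
class LeftmostMaxTable:
    """Sparse table over pairs (S[i], -i): max() prefers larger values and,
    among equal values, the smaller index (the leftmost occurrence)."""

    def __init__(self, S):
        self.levels = [[(v, -i) for i, v in enumerate(S)]]
        width = 1
        while 2 * width <= len(S):
            prev = self.levels[-1]
            # entry j of the new level covers S[j : j + 2 * width]
            self.levels.append([max(prev[j], prev[j + width])
                                for j in range(len(S) - 2 * width + 1)])
            width *= 2

    def query(self, l, r):
        """Return (M, i): the maximum of S[l..r] and its leftmost index."""
        depth = (r - l + 1).bit_length() - 1
        row = self.levels[depth]
        # two overlapping blocks of length 2**depth cover [l, r]
        value, neg_index = max(row[l], row[r - (1 << depth) + 1])
        return value, -neg_index


def answer_query(table, dp, l, r):
    """M if the walk from the leftmost maximum reaches r with D = M, else M + 1."""
    M, i = table.query(l, r)
    return M if dp[i] >= r else M + 1


table = LeftmostMaxTable([3, 2, 3])
dp = [1, 2, 2]  # the last-map recurrence applied to S = [3, 2, 3]
print(answer_query(table, dp, 0, 2))  # 4
print(answer_query(table, dp, 1, 2))  # 3
```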
TIME COMPLEXITY
\mathcal{O}((N + Q)\log N) per test case.
CODE:
Setter's code (C++)
#include "bits/stdc++.h"
using namespace std;
typedef long long lol;
typedef std::pair<int,int> pii;
#define pb push_back
#define ub upper_bound
#define lb lower_bound
#define fo(i,l,r,d) for (auto i=(l); (d)<0?i>(r):((d)>0?i<(r):0); i+=(d))
#define all(x) x.begin(), x.end()
#define ff first
#define ss second
std::mt19937 rng (std::chrono::high_resolution_clock::now().time_since_epoch().count());
template <typename A, typename B> std::ostream& operator<< (std::ostream &cout, const std::pair<A, B> &p) { return cout << p.first << ' ' << p.second; } template <typename A, size_t n> std::ostream& operator<< (std::ostream &cout, const std::array<A, n> &v) { for (int i = 0; i < n - 1; ++i) cout << v[i] << ' '; return (n ? cout << v.back(): cout << '\n'); } template <typename A> std::ostream& operator<< (std::ostream &cout, const std::vector<A> &v) { for (int i = 0; i < v.size() - 1; ++i) cout << v[i] << ' '; return (v.size() ? cout << v.back(): cout << '\n'); }
template <typename A, typename B> std::istream& operator>> (std::istream &cin, std::pair<A, B> &p) { cin >> p.first; return cin >> p.second; } template <typename A, size_t n> std::istream& operator>> (std::istream &cin, std::array<A, n> &v) { assert(n); for (int i = 0; i < n - 1; i++) cin >> v[i]; return cin >> v.back(); } template <typename A> std::istream& operator>> (std::istream &cin, std::vector<A> &v) { assert(v.size()); for (int i = 0; i < v.size() - 1; i++) cin >> v[i]; return cin >> v.back(); }
template <typename A, typename B> auto amax (A &a, const B b){ if (b > a) a = b ; return a; }
template <typename A, typename B> auto amin (A &a, const B b){ if (b < a) a = b ; return a; }
template <
class Node,
class Calc,
bool kNearestPowOf2 = false
>
class Segtree {
public:
explicit Segtree (const int n, const Node id, const Calc& F)
: sz(n), N(kNearestPowOf2 ? 1 << 32 - __builtin_clz(std::max(1, sz - 1)) : sz), a(N << 1, id), id(id), F(F)
{
}
explicit Segtree (const std::vector<Node>& x, const Node id, const Calc& F)
: sz(x.size()), N(kNearestPowOf2 ? 1 << 32 - __builtin_clz(std::max(1, sz - 1)) : sz), id(id), F(F)
{
a.resize(N << 1, id);
std::copy(x.begin(), x.end(), a.begin() + N);
for (int i = N; --i; )
a[i] = F(a[i << 1], a[i << 1 | 1]);
}
void set (int i, const Node x) {
// assert(0 <= i and i < sz);
for (a[i += N] = x; i >>= 1; )
a[i] = F(a[i << 1], a[i << 1 | 1]);
}
Node qu (int l, int r) const {
// assert(0 <= l and l <= r and r <= sz);
Node x = id, y = id;
for (l += N, r += N; l < r; l >>= 1, r >>= 1) {
if (l & 1) x = F(x, a[l++]);
if (r & 1) y = F(a[--r], y);
}
return F(x, y);
}
// First j in [l, N] such that pred(F[l, j)) is FALSE, if pred is monotonic
template<class Predicate>
int max_right (int l, const Predicate& pred) const {
// assert(0 <= l and l <= N and pred(id));
if (l == N) return l;
Node prev = id, t = id;
l += N;
do {
l >>= __builtin_ctz(l);
if (!pred(F(prev, a[l]))) {
while (l < N)
if (pred(t = F(prev, a[l <<= 1])))
prev = t, l++;
return l - N;
}
prev = F(prev, a[l++]);
} while ((l & -l) != l);
return N;
}
// First j in [0, r] such that pred(F[j, r)) is TRUE, if pred is monotonic
template<class Predicate>
int min_left (int r, const Predicate& pred) const {
// assert(r > -1 and r <= N and pred(id));
if(r == 0) return r;
Node last = id, t = id;
r += N;
do {
r--, r >>= __builtin_ctz(~r);
if (r == 0) r = 1;
if (!pred(F(a[r], last))){
while (r < N)
if (pred(t = F(a[(r <<= 1) += 1], last)))
last = t, r--;
return r + 1 - N;
}
last = F(a[r], last);
} while((r & -r) != r);
return 0;
}
private:
const int sz;
const int N;
std::vector<Node> a;
const Node id;
const Calc F;
};
void darling (const int kase) {
int n, q; cin >> n >> q;
vector a(n, 0); cin >> a;
a.pb(-1);
stack<pair<int, int>> s;
vector geq(n, n); // default to the sentinel index n (a[n] == -1 was appended above); a[-1] would be out of bounds
s.push(pair(0, a[0]));
for (int i = 1; i < n; i++) {
while (s.size() and s.top().ss < a[i])
s.pop();
if (s.size())
geq[i] = s.top().ff;
s.push(pair(i, a[i]));
}
vector dp(n, n);
fo(i,n-1,-1,-1) {
int j = geq[i];
if (a[j] == a[i])
dp[j] = i;
else if (a[j] == a[i] + 1)
amin(dp[j], dp[i]);
}
// cout << a << '\n' << dp << '\n';
vector it(n, 0);
iota(all(it), 0);
Segtree mx(it, n, [&](int x, int y){
if (a[x] > a[y])
return x;
else if (a[x] < a[y])
return y;
else
return min(x, y);
});
while (q--) {
int l, r; cin >> l >> r, l--;
auto j = mx.qu(l, r);
if (dp[j] < r)
cout << a[j] + 1 << '\n';
else
cout << a[j] << '\n';
}
}
int main () {
ios_base::sync_with_stdio(0), cin.tie(0);
int t; cin >> t, assert(t >= 0);
for (int i = 0; t--; )
darling(++i);
}
Tester's code (C++)
#include <bits/stdc++.h>
#define IOS std::ios::sync_with_stdio(false); cin.tie(NULL);cout.tie(NULL);
#define pii pair<int, int>
#define ll long long
#define ff first
#define ss second
#define rep(i,x,y) for(int i=x; i<y; i++)
using namespace std;
const long long N=500005, INF=2000000000000000000;
int a[N];
pii st[4*N];
void build(int v, int l, int r)
{
if(l==r)
{
st[v]={a[l], -l};
return;
}
int m=(l+r)/2;
build(v*2, l, m);
build((v*2)+1, m+1, r);
st[v]=max(st[v*2], st[(v*2)+1]);
//cout<<l<<" "<<r<<" "<<t[v].o<<" "<<t[v].c<<" "<<t[v].ans<<"\n";
}
pii query(int v, int tl, int tr, int l, int r)
{
if(l>r)
return {-INF, 0};
if(tl==l&&tr==r)
return st[v];
int tm=(tl+tr)/2;
return max(query((2*v), tl, tm, l, min(tm, r)), query((2*v)+1, tm+1, tr, max(tm+1, l), r));
}
int32_t main()
{
IOS;
int t;
cin>>t;
while(t--)
{
int n, q;
cin>>n>>q;
rep(i,0,n)
cin>>a[i];
build(1, 0, n-1);
int dp[n];
stack <int> s;
for(int i=n-1;i>=0;i--)
{
dp[i]=n;
while(!s.empty() && a[s.top()]<a[i]-1)
s.pop();
if(!s.empty() && a[s.top()]==a[i]-1)
{
dp[i]=min(dp[i], dp[s.top()]);
s.pop();
}
if(!s.empty())
dp[i]=min(dp[i], s.top());
s.push(i);
}
while(q--)
{
int l, r;
cin>>l>>r;
l--, r--;
pii p=query(1, 0, n-1, l, r);
if(dp[-p.ss]<=r)
cout<<p.ff+1<<"\n";
else
cout<<p.ff<<"\n";
}
}
}
Editorialist's code (Python)
class RangeQuery:
def __init__(self, data, func=max):
self.func = func
self._data = _data = [list(data)]
i, n = 1, len(_data[0])
while 2 * i <= n:
prev = _data[-1]
_data.append([func(prev[j], prev[j + i]) for j in range(n - 2 * i + 1)])
i <<= 1
def query(self, start, stop):
"""func of data[start, stop)"""
depth = (stop - start).bit_length() - 1
return self.func(self._data[depth][start], self._data[depth][stop - (1 << depth)])
def __getitem__(self, idx):
return self._data[0][idx]
for _ in range(int(input())):
n, q = map(int, input().split())
s = list(map(int, input().split()))
queries = [ [] for _ in range(n)]
for i in range(q):
l, r = map(int, input().split())
queries[l-1].append((r-1, i))
RMQ = RangeQuery(s)
ans = [0]*q
dp = [0]*n
pos = {}
for i in reversed(range(n)):
dp[i] = n-1
if s[i] in pos: dp[i] = min(dp[i], pos[s[i]] - 1)
if s[i]-1 in pos: dp[i] = min(dp[i], dp[pos[s[i] - 1]])
pos[s[i]] = i
for r, id in queries[i]:
mx = RMQ.query(i, r+1)
ans[id] = mx+1
where = pos[mx]
if dp[where] >= r: ans[id] = mx
print(*ans, sep = '\n')