PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: raysh07
Tester: apoorv_me
Editorialist: iceknight1093
DIFFICULTY:
Easy-Medium
PREREQUISITES:
Binary lifting
PROBLEM:
You’re given an array A of length N.
Answer Q queries on it:
- Given X and Y, find the minimum number of moves needed to reach Y if you start at X (or report -1 if Y cannot be reached).
In one move, you can move from index i to any index j such that i \lt j \leq i + A_i.
EXPLANATION:
Let’s try to answer a single query (X, Y) first.
If Y \leq X + A_X, then we can reach Y with a single jump, which is clearly optimal.
Otherwise, it’s not hard to see that it’s optimal to jump to an index i such that (i + A_i) is maximal (and then repeat the process starting from i, till we either reach Y or fail to do so.)
Proof
Suppose we make more than one jump.
Let the first two jumps in an optimal sequence be X \to u \to v.
Let k be the index such that (k + A_k) is maximum across all indices from X+1 to X+A_X.
Then,
- If v \leq X + A_X, we could’ve jumped directly to v on the first move, reducing the length of the sequence by 1 (meaning the sequence we started with couldn’t have been optimal in the first place).
- Otherwise, we can always replace X \to u \to v with X \to k \to v: since k + A_k \geq u + A_u \geq v and k \leq X + A_X \lt v, if u can reach v then so can k.
It’s thus not worse to make the first jump be from X to k; now repeat this argument for the rest of the sequence (starting from k this time).
In performing this process, it’s easy to see that there are only \leq N “important” edges - namely the ones from each index to the position in its corresponding range that it is optimal to jump to.
Let’s find all these important edges: this is a fairly standard data structure task, and reduces to computing a range maximum quickly.
Use a sparse table/segment tree for this.
Let \text{link}[i] denote the position we end up at if we follow the important edge from i.
Now, to answer the query (X, Y), we want to follow the path
X \to \text{link}[X] \to \text{link}[\text{link}[X]] \to \ldots
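until we reach some index v such that v + A_v \geq Y, and then make the last jump be v \to Y. If the path gets stuck before that (no index in range extends the reach any further), Y is unreachable and the answer is -1.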
Finding the number of steps in such a path is, yet again, a classical task: and can be computed quickly using binary lifting.
That is, precompute \text{jump}[i][j] to be where you end up if you start at index i and jump 2^j times. In particular, \text{jump}[i][0] = \text{link}[i].
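With this jump table known, you can find the maximum number of link-jumps after which the current index v still cannot reach Y (that is, v + A_v \lt Y). Consider decreasing powers of 2, and take a block of 2^j jumps only when it preserves that property; one or two further moves then finish the path to Y.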
Precomputation takes \mathcal{O}(N\log N) time, after which each query is answered in \mathcal{O}(\log N) time.
TIME COMPLEXITY:
\mathcal{O}((N+Q)\log N) per testcase.
CODE:
Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
// Competitive-programming shorthand: with this macro every `int` in the author's solution is a
// 64-bit long long, which is why the entry point is declared as int32_t main().
#define int long long
// Large sentinel value; not referenced elsewhere in this solution.
#define INF (int)1e18
// Short names for pair members (Solve reads the index half of a segment-tree result via .s).
#define f first
#define s second
// 64-bit Mersenne Twister seeded from the clock; not used elsewhere in this solution.
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());
// Static array capacity; n is assumed to be at most 5e5 (plus some slack).
const int N = 5e5 + 69;
// Array length and number of queries of the current test case.
int n, q;
// a[1..n]: the input array (1-indexed).
// lift[i][j]: index reached from i after 2^j greedy jumps; lift[i][0] is the index in
// [i, i + a[i]] with the largest reach j + a[j] (filled in Solve). 21 levels (j = 0..20).
int a[N], lift[N][21];
// Segment tree over indices 1..n: leaves hold (i + a[i], i) and internal nodes hold the pair
// maximum, i.e. the largest reach with ties broken toward the larger index.
pair <int, int> seg[4 * N];
// Builds segment-tree node `pos`, which covers a[l..r].
// Each leaf stores the pair (i + a[i], i); pair comparison therefore prefers the farthest
// reach and, between equal reaches, the larger index.
void Build(int l, int r, int pos){
    if (l != r) {
        const int mid = (l + r) / 2;
        const int leftChild = pos * 2;
        const int rightChild = leftChild + 1;
        Build(l, mid, leftChild);
        Build(mid + 1, r, rightChild);
        seg[pos] = max(seg[leftChild], seg[rightChild]);
        return;
    }
    seg[pos] = {a[l] + l, l};
}
// Returns the maximum (reach, index) pair over a[ql..qr], searching inside node `pos`, which
// covers a[l..r]. Nodes that miss the query range contribute the sentinel {0, -1}; this loses
// to any real leaf as long as reaches are positive (presumably a[i] >= 0 with 1-based indices).
pair<int, int> query(int l, int r, int pos, int ql, int qr){
    const bool inside = ql <= l && r <= qr;
    if (inside) return seg[pos];
    const bool outside = qr < l || r < ql;
    if (outside) return {0, -1};
    const int mid = (l + r) / 2;
    const pair<int, int> leftBest = query(l, mid, 2 * pos, ql, qr);
    const pair<int, int> rightBest = query(mid + 1, r, 2 * pos + 1, ql, qr);
    return max(leftBest, rightBest);
}
// Handles one test case: reads n, q, the array a[1..n] and q queries (l, r), then prints for each
// query the minimum number of moves from l to r, 0 when l == r, and -1 when r < l or r is unreachable.
void Solve()
{
cin >> n >> q;
// Reset the lifting rows 1..n (every one of these entries is also rewritten below).
for (int i = 1; i <= n; i++){
for (int j = 0; j <= 20; j++){
lift[i][j] = 0;
}
}
// Reset the segment-tree storage before rebuilding it for this test case.
for (int i = 0; i <= 4 * n; i++){
seg[i] = {0, 0};
}
for (int i = 1; i <= n; i++) cin >> a[i];
Build(1, n, 1);
// lift[i][0]: index in [i, i + a[i]] with the largest reach (ties -> larger index). The range
// includes i itself, so when a[i] == 0 the range is just [i, i] and lift[i][0] == i (chain stops).
// NOTE(review): assumes i + a[i] <= n so the range stays inside [1, n].
for (int i = 1; i <= n; i++){
auto get = query(1, n, 1, i, i + a[i]);
lift[i][0] = get.s;
}
// lift[i][j] = lift[i][j-1] applied twice, i.e. the index after 2^j greedy jumps.
for (int j = 1; j <= 20; j++){
for (int i = 1; i <= n; i++){
lift[i][j] = lift[lift[i][j - 1]][j - 1];
}
}
while (q--){
int l, r; cin >> l >> r;
// Trivial cases: moving left is impossible, zero moves to stay, one move if r is in range.
if (r < l){
cout << -1 << "\n";
continue;
} else if (l == r) {
cout << 0 << "\n";
continue;
} else if (a[l] + l >= r){
cout << 1 << "\n";
continue;
}
// Binary lifting: take a block of 2^i greedy jumps whenever the index it lands on still has its
// own greedy successor strictly before r. ans counts the jumps taken.
int ans = 0;
for (int i = 20; i >= 0; i--){
if (lift[lift[l][i]][0] < r){
// cout << l << " ";
l = lift[l][i];
// cout << (1 << i ) << " ";
// cout << l << "\n";
ans += 1 << i;
}
}
// Finish: one more move if l already covers r; otherwise one move to lift[l][0] and one more to r
// if that index covers r; otherwise r cannot be reached.
if (l + a[l] >= r) cout << ans + 1 << "\n";
else if (lift[l][0] + a[lift[l][0]] >= r) cout << ans + 2 << "\n";
else cout << -1 << "\n";
}
}
// Entry point: reads the number of test cases, runs Solve() for each one, and reports the total
// wall-clock time on stderr.
int32_t main()
{
    using Clock = std::chrono::high_resolution_clock;
    const auto startedAt = Clock::now();

    ios_base::sync_with_stdio(0);
    cin.tie(0);

    int testCases = 1;
    cin >> testCases;
    for (int caseNo = 1; caseNo <= testCases; caseNo++) {
        Solve();
    }

    const auto elapsedNs = std::chrono::duration_cast<std::chrono::nanoseconds>(Clock::now() - startedAt);
    cerr << "Time measured: " << elapsedNs.count() * 1e-9 << " seconds.\n";
    return 0;
}
Tester's code (C++)
#include<bits/stdc++.h>
using namespace std;
#ifdef LOCAL
#include "../debug.h"
#else
#define dbg(...)
#endif
#ifdef LOCAL
// Strict input validator (LOCAL builds): the constructor slurps all of stdin into `buffer`, and
// every read then checks token values and the exact separator bytes between tokens.
struct input_checker {
// Entire input and the current read position inside it.
string buffer;
int pos;
// Character sets intended for the `pattern` argument of readString.
const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
const string number = "0123456789";
const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
const string lower = "abcdefghijklmnopqrstuvwxyz";
// Reads stdin character by character until EOF into `buffer`.
input_checker() {
pos = 0;
while (true) {
int c = cin.get();
if (c == -1) {
break;
}
buffer.push_back((char) c);
}
}
// Returns the index of the first ' ' or '\n' at or after pos (or buffer.size() if none).
int nextDelimiter() {
int now = pos;
while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
now++;
}
return now;
}
// Consumes and returns the token starting at pos (empty if pos is already on a delimiter).
// The delimiter itself is left for readSpace/readEoln.
string readOne() {
assert(pos < (int) buffer.size());
int nxt = nextDelimiter();
string res;
while (pos < nxt) {
res += buffer[pos];
pos++;
}
return res;
}
// Reads a token whose length is in [minl, maxl] and whose characters all belong to
// `pattern` (any character is allowed when pattern is empty).
string readString(int minl, int maxl, const string &pattern = "") {
assert(minl <= maxl);
string res = readOne();
assert(minl <= (int) res.size());
assert((int) res.size() <= maxl);
for (int i = 0; i < (int) res.size(); i++) {
assert(pattern.empty() || pattern.find(res[i]) != string::npos);
}
return res;
}
// Reads an integer in [minv, maxv]. Malformed text makes std::stoi throw instead of asserting.
int readInt(int minv, int maxv) {
assert(minv <= maxv);
int res = stoi(readOne());
assert(minv <= res);
assert(res <= maxv);
return res;
}
// 64-bit counterpart of readInt (uses std::stoll).
long long readLong(long long minv, long long maxv) {
assert(minv <= maxv);
long long res = stoll(readOne());
assert(minv <= res);
assert(res <= maxv);
return res;
}
// Reads n integers separated by single spaces; the separator after the last one is not consumed.
auto readInts(int n, int minv, int maxv) {
assert(n >= 0);
vector<int> v(n);
for (int i = 0; i < n; ++i) {
v[i] = readInt(minv, maxv);
if (i+1 < n) readSpace();
}
return v;
}
// 64-bit counterpart of readInts.
auto readLongs(int n, long long minv, long long maxv) {
assert(n >= 0);
vector<long long> v(n);
for (int i = 0; i < n; ++i) {
v[i] = readLong(minv, maxv);
if (i+1 < n) readSpace();
}
return v;
}
// Consumes exactly one ' '.
void readSpace() {
assert((int) buffer.size() > pos);
assert(buffer[pos] == ' ');
pos++;
}
// Consumes exactly one '\n'.
void readEoln() {
assert((int) buffer.size() > pos);
assert(buffer[pos] == '\n');
pos++;
}
// Asserts that the entire input has been consumed.
void readEof() {
assert((int) buffer.size() == pos);
}
};
#else
// Fast input checker (non-LOCAL builds): values are read directly from std::cin, so whitespace
// layout is not verified (readSpace / readEoln / readEof are no-ops). Numeric ranges, string
// lengths and string patterns are still asserted.
struct input_checker {
    // Unused by the cin-based readers; kept so the buffer helpers below stay well-defined.
    string buffer;
    // Read position for the buffer helpers; initialised so they never read an indeterminate value.
    int pos = 0;
    // Character sets intended for the `pattern` argument of readString.
    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";
    input_checker() {
    }
    // Index of the first ' ' or '\n' at or after pos in `buffer` (or buffer.size() if none).
    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }
    // Consumes and returns the token at pos in `buffer`.
    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }
    // Reads a whitespace-delimited token whose length is in [minl, maxl] and whose characters all
    // belong to `pattern` (any character is allowed when pattern is empty).
    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string X; cin >> X;
        assert(minl <= (int) X.size());
        assert((int) X.size() <= maxl);
        for (char c : X) {
            assert(pattern.empty() || pattern.find(c) != string::npos);
        }
        return X;
    }
    // Reads an integer from cin and asserts it lies in [minv, maxv].
    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res; cin >> res;
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }
    // 64-bit counterpart of readInt.
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res; cin >> res;
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }
    // Reads n integers in [minv, maxv].
    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }
    // Reads n 64-bit integers in [minv, maxv].
    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }
    // Layout checks are intentionally skipped in this fast mode.
    void readSpace() {
    }
    void readEoln() {
    }
    void readEof() {
    }
};
#endif
// Sparse table answering range-MAXIMUM queries in O(1) after an O(n log n) build.
// NOTE: despite their names, rangeMin() and minIndx() return the maximum and the index of the
// maximum; the names are kept for existing callers.
template<class T>
struct RMQ{
    int n = 0, logn = 0;
    // b[i][j]: index of the maximum of A[j .. min(j + 2^i - 1, n - 1)]; ties go to the larger index.
    vector<vector<int>> b;
    vector<T> A;
    // Builds the table over a copy of `a`. Safe to call repeatedly and with an empty vector.
    void build(const vector<T> &a) {
        A = a, n = (int)a.size();
        b.clear();
        if (n == 0) {  // __builtin_clz(0) is undefined; nothing to index anyway
            logn = 0;
            return;
        }
        logn = 32 - __builtin_clz(n);
        // assign (not resize): every row must have exactly n entries even after an earlier,
        // smaller build, otherwise the loops below write past the end of the old rows.
        b.assign(logn, vector<int>(n));
        iota(b[0].begin(), b[0].end(), 0);
        for (int i = 1; i < logn; i++) {
            const int half = 1 << (i - 1);
            for (int j = 0; j < n; j++) {
                b[i][j] = b[i - 1][j];
                if (j + half < n && A[b[i - 1][j + half]] >= A[b[i][j]])
                    b[i][j] = b[i - 1][j + half];
            }
        }
    }
    // Maximum value of A[x..y]; requires 0 <= x <= y < n.
    int rangeMin(int x, int y){
        int k = 31 - __builtin_clz(y - x + 1);
        return max(A[b[k][x]], A[b[k][y - (1 << k) + 1]]);
    }
    // Index of the maximum of A[x..y] (the largest such index on ties); requires 0 <= x <= y < n.
    int minIndx(int x, int y){
        int k = 31 - __builtin_clz(y - x + 1);
        return A[b[k][x]] > A[b[k][y - (1 << k) + 1]] ? b[k][x] : b[k][y - (1 << k) + 1];
    }
};
// Tester's solution: validates the input and answers each query offline with a DFS over the forest
// of greedy links, keeping the reaches along the current root-to-node path in a sorted vector.
int32_t main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    input_checker inp;
    int T = inp.readInt(1, (int)1e4), NN = 0, NQ = 0; inp.readEoln();
    while (T-- > 0) {
        int N = inp.readInt(1, (int)5e5); inp.readSpace();
        int Q = inp.readInt(1, (int)5e5); inp.readEoln();
        NN += N, NQ += Q;
        vector<int> B(N), A = inp.readInts(N, 0, N); inp.readEoln();
        // B[i] = furthest index reachable from i in one move (0-indexed).
        for (int i = 0; i < N; ++i) {
            assert(i + A[i] < N);
            B[i] = i + A[i];
        }
        RMQ<int> rmq; rmq.build(B);

        // Queries grouped by start index: Query[x] holds (y, query id), both 0-indexed.
        vector<vector<pair<int, int>>> Query(N);
        for (int i = 0; i < Q; ++i) {
            // Read through the checker: in LOCAL builds its constructor has already consumed all
            // of stdin, so a plain `cin >> l >> r` would read nothing. This also range-checks l, r.
            int l = inp.readInt(1, N); inp.readSpace();
            int r = inp.readInt(1, N); inp.readEoln();
            Query[l - 1].emplace_back(r - 1, i);
        }

        // par[i]: index in (i, B[i]] with the largest reach, kept only if that reach strictly
        // exceeds B[i]. Since par[i] > i, this is a forest whose roots are visited from the right.
        vector<int> par(N, -1);
        vector<vector<int>> adj(N);
        for (int i = N - 1; i >= 0; --i) {
            if (A[i] == 0) continue;
            int parent = rmq.minIndx(i + 1, i + A[i]);
            if (B[parent] > i + A[i]) {
                par[i] = parent;
                adj[parent].push_back(i);
            }
        }

        vector<bool> vis(N);
        // sol holds -B[v] for the nodes on the current root-to-node path (root first). B strictly
        // decreases from parent to child, so sol stays sorted ascending and can be binary-searched.
        vector<int> sol, ans(Q, -1);
        // Visits `node`: pushes it onto the path and answers every query that starts there.
        auto enter = [&](int node) {
            vis[node] = 1;
            sol.push_back(-B[node]);
            for (auto &[r, in] : Query[node]) {
                if (r > -sol.front()) continue;  // beyond the root's reach: answer stays -1
                // 1 + number of path nodes whose reach is still below r.
                ans[in] = 1 + (sol.end() - upper_bound(sol.begin(), sol.end(), -r));
            }
        };
        // Iterative DFS with an explicit stack of (node, next child position). A chain such as
        // A = [1, 1, ..., 1, 0] makes the tree about N deep, and recursion that deep risks
        // overflowing the call stack. The visiting order matches the recursive version.
        vector<pair<int, size_t>> stk;
        for (int i = N - 1; i >= 0; --i) {
            if (vis[i]) continue;
            enter(i);
            stk.emplace_back(i, 0);
            while (!stk.empty()) {
                const int node = stk.back().first;
                size_t &next = stk.back().second;
                if (next < adj[node].size()) {
                    const int u = adj[node][next++];  // read before emplace_back may reallocate
                    if (!vis[u]) {
                        enter(u);
                        stk.emplace_back(u, 0);
                    }
                } else {
                    sol.pop_back();
                    stk.pop_back();
                }
            }
        }
        for (int i = 0; i < Q; ++i)
            cout << ans[i] << "\n";
    }
    assert(max(NN, NQ) <= (int)5e5);
    inp.readEof();
    return 0;
}
Editorialist's code (Python)
import sys

# Number of binary-lifting levels: up to 2**20 - 1 link steps. Every link strictly increases the
# index, so a greedy chain has at most n links (n is at most 5 * 10**5 here).
LOG = 20


def build_links(n, a):
    """Compute the greedy link of every index.

    `a` must have length n + 1 with a[n] == 0 as a sentinel.  jump[i] is the index j in
    (i, i + a[i]] with the largest reach j + a[j], provided that reach is strictly greater
    than i + a[i]; otherwise jump[i] == n, meaning the chain from i cannot be extended.
    jump[n] == n.
    """
    jump = [n] * (n + 1)
    # Monotonic stack of candidates, sentinel n at the bottom.  From bottom to top the
    # indices decrease and the reaches strictly decrease, so the deepest entry that is
    # <= i + a[i] has the largest reach among the candidates inside i's range.
    stk = [n]
    for i in reversed(range(n)):
        reach = i + a[i]
        if a[i] > 0:
            # First stack position (from the bottom) whose index lies within the range.
            lo, hi = 0, len(stk) - 1
            while lo < hi:
                mid = (lo + hi) // 2
                if stk[mid] <= reach:
                    hi = mid
                else:
                    lo = mid + 1
            best = stk[lo]
            if best + a[best] > reach:
                jump[i] = best
        # Candidates that i dominates (reach no farther than i) can never be chosen again.
        while len(stk) > 1 and stk[-1] + a[stk[-1]] <= reach:
            stk.pop()
        stk.append(i)
    return jump


def build_lift(n, jump):
    """Return up[k][i]: the index reached from i after 2**k links (n is absorbing)."""
    up = [jump]
    for _ in range(1, LOG):
        prev = up[-1]
        up.append([prev[prev[i]] for i in range(n + 1)])
    return up


def answer_query(x, y, a, jump, up):
    """Minimum number of moves from x to y (0-indexed, x < y expected), or -1 if impossible."""
    if y <= x + a[x]:
        return 1
    ans = 0
    # Take the longest run of links whose endpoint still cannot reach y directly.
    for k in reversed(range(LOG)):
        u = up[k][x]
        if y > u + a[u]:
            ans += 1 << k
            x = u
    # One more link, then (unless that link lands exactly on y) a final move to y.
    x = jump[x]
    if x <= y <= x + a[x]:
        return ans + 1 + (x < y)
    return -1


def solve(n, a, queries):
    """Answer 1-indexed (x, y) queries on the array `a` of length n; returns a list of ints."""
    a = list(a) + [0]
    jump = build_links(n, a)
    up = build_lift(n, jump)
    return [answer_query(x - 1, y - 1, a, jump, up) for x, y in queries]


def main():
    tokens = map(int, sys.stdin.buffer.read().split())
    out = []
    for _ in range(next(tokens)):
        n, q = next(tokens), next(tokens)
        a = [next(tokens) for _ in range(n)]
        queries = [(next(tokens), next(tokens)) for _ in range(q)]
        out.extend(solve(n, a, queries))
    # One buffered write instead of one print() per query (up to 5 * 10**5 lines).
    sys.stdout.write("".join(f"{v}\n" for v in out))


if __name__ == "__main__":
    main()