DATA101 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: raysh07
Tester: apoorv_me
Editorialist: iceknight1093

DIFFICULTY:

Easy-Medium

PREREQUISITES:

Binary lifting

PROBLEM:

You’re given an array A of length N.
Answer Q queries on it:

  • Given X and Y, find the minimum number of moves needed to reach Y if you start at X (or report -1 if Y cannot be reached).
    In one move, you can move from index i to any index j such that i \lt j \leq i + A_i.

EXPLANATION:

Let’s try to answer a single query (X, Y) first.

If Y \leq X + A_X, then we can reach Y with a single jump, which is clearly optimal.
Otherwise, it’s not hard to see that it’s optimal to jump to the index i with X \lt i \leq X + A_X for which (i + A_i) is maximal (and then repeat the process starting from i, until we either reach Y or fail to do so).

Proof

Suppose we make more than one jump.
Let the first two jumps in an optimal sequence be X \to u \to v.
Let k be the index such that (k + A_k) is maximum across all indices from X+1 to X+A_X.

Then,

  • If v \leq X + A_X, we could’ve jumped directly to v on the first move, reducing the length of the sequence by 1 (meaning the sequence we started with couldn’t have been optimal in the first place).
  • Otherwise, we can always replace X \to u \to v with X \to k \to v: since u \leq X + A_X, the choice of k gives k + A_k \geq u + A_u \geq v, and k \leq X + A_X \lt v, so if u can reach v then so can k.

It’s thus not worse to make the first jump be from X to k; now repeat this argument for the rest of the sequence (starting from k this time).

In performing this process, it’s easy to see that at most N edges are ever “important”: one from each index to the position in its range that it is optimal to jump to.

Let’s find all these important edges: this is a fairly standard data structure task, and reduces to computing a range maximum quickly.
Use a sparse table/segment tree for this.

Let \text{link}[i] denote the position we end up at if we follow the important edge from i.

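Now, to answer the query (X, Y), we want to follow the path
X \to \text{link}[X] \to \text{link}[\text{link}[X]] \to \ldots
until we reach the first index v such that v + A_v \geq Y, and then make the last jump be v \to Y.
If the path stops advancing (no index in range reaches any farther) before such a v appears, Y is unreachable and the answer is -1.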

Finding the number of steps on such a path is, yet again, a classical task that can be solved quickly using binary lifting.
That is, precompute \text{jump}[i][j] to be where you end up if you start at index i and jump 2^j times. In particular, \text{jump}[i][0] = \text{link}[i].

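With this jump table known, you can find the maximum number of jumps along the path that still land on an index v with v + A_v \lt Y, by considering decreasing powers of 2 and jumping only when the landing index satisfies this. One more \text{link} jump must then bring Y within reach (otherwise the answer is -1), so the answer is that count plus 2 (the single-jump case Y \leq X + A_X is handled separately).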

Precomputation takes \mathcal{O}(N\log N) time, after which each query is answered in \mathcal{O}(\log N) time.

TIME COMPLEXITY:

\mathcal{O}((N+Q)\log N) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18
#define f first
#define s second

// Seeded random generator; not used anywhere in the visible solution.
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());
// Capacity of the global arrays (maximum n plus a little slack).
const int N = 5e5 + 69;
// Array length and number of queries of the current test case.
int n, q;
// a[1..n]: the input array (1-indexed).
// lift[i][j]: index reached from i after 2^j greedy jumps; lift[i][0] is the index
// in [i, i + a[i]] with the largest a[k] + k (ties broken toward the larger index).
int a[N], lift[N][21];
// seg: max segment tree over pairs (a[k] + k, k), filled by Build().
pair <int, int> seg[4 * N];

// Builds the max segment tree over positions [l, r] rooted at node `pos`.
// Each node stores the pair (a[k] + k, k) that is lexicographically largest in
// its range: the farthest reach wins, ties go to the larger index.
void Build(int l, int r, int pos){
    if (l != r){
        const int half = (l + r) / 2;
        const int left_child = 2 * pos;
        const int right_child = 2 * pos + 1;
        Build(l, half, left_child);
        Build(half + 1, r, right_child);
        seg[pos] = max(seg[left_child], seg[right_child]);
        return;
    }
    // Leaf: a single position and its reach.
    seg[pos] = {a[l] + l, l};
}

// Returns the largest (reach, index) pair over positions [ql, qr] ∩ [l, r],
// searching the subtree rooted at node `pos` which covers [l, r].
// A node disjoint from the query contributes the neutral pair {0, -1}.
pair<int, int> query(int l, int r, int pos, int ql, int qr){
    const bool covered = ql <= l && r <= qr;
    if (covered) return seg[pos];

    const bool disjoint = qr < l || ql > r;
    if (disjoint) return {0, -1};

    const int half = (l + r) / 2;
    const pair<int, int> from_left = query(l, half, 2 * pos, ql, qr);
    const pair<int, int> from_right = query(half + 1, r, 2 * pos + 1, ql, qr);
    return max(from_left, from_right);
}

// Solves one test case: reads n, q, the array a[1..n] and q queries (l, r),
// and prints for each query the minimum number of moves from l to r, or -1.
void Solve() 
{
    cin >> n >> q;

    // Clear the lifting table rows used by this test case.
    for (int i = 1; i <= n; i++){
        for (int j = 0; j <= 20; j++){
            lift[i][j] = 0;
        }
    }

    // Clear the segment tree (Build() below overwrites every node it visits anyway).
    for (int i = 0; i <= 4 * n; i++){
        seg[i] = {0, 0};
    }
    
    for (int i = 1; i <= n; i++) cin >> a[i];
    
    Build(1, n, 1);
    
    // lift[i][0]: index in [i, i + a[i]] (note: i itself included) with maximum
    // reach a[k] + k, ties toward the larger index. If a[i] == 0 this is i itself.
    for (int i = 1; i <= n; i++){
        auto get = query(1, n, 1, i, i + a[i]);
        lift[i][0] = get.s;
    }
    
    // Standard doubling: 2^j jumps = two rounds of 2^(j-1) jumps.
    for (int j = 1; j <= 20; j++){
        for (int i = 1; i <= n; i++){
            lift[i][j] = lift[lift[i][j - 1]][j - 1];
        }
    }
    
    while (q--){
        int l, r; cin >> l >> r;
        
        // Trivial cases: moves only go right, zero moves, or a single move.
        if (r < l){
            cout << -1 << "\n";
            continue;
        } else if (l == r) {
            cout << 0 << "\n";
            continue;
        } else if (a[l] + l >= r){
            cout << 1 << "\n";
            continue;
        }
        
        // Binary lifting: advance l by 2^i greedy jumps whenever the index one
        // greedy jump beyond the landing point is still < r. Note that this
        // compares an INDEX (lift[..][0]) with r, not a reach. ans counts the
        // greedy jumps taken.
        int ans = 0;
        for (int i = 20; i >= 0; i--){
            if (lift[lift[l][i]][0] < r){
               // cout << l << " ";
                l = lift[l][i];
              //  cout << (1 << i ) << " ";
              //  cout << l << "\n";
                ans += 1 << i;
            }
        }
        
        // r within one move of l: ans greedy jumps plus the final jump.
        if (l + a[l] >= r) cout << ans + 1 << "\n";
        // Otherwise one more greedy jump, then the final jump to r.
        else if (lift[l][0] + a[lift[l][0]] >= r) cout << ans + 2 << "\n";
        // The greedy chain cannot get within reach of r: unreachable.
        else cout << -1 << "\n";
    }
}

// Entry point: reads the number of test cases, solves each one, and reports
// the total wall-clock time on stderr.
int32_t main() 
{
    const auto started = std::chrono::high_resolution_clock::now();

    // Fast, untied stream I/O.
    ios_base::sync_with_stdio(0);
    cin.tie(0);

    int tests = 1;
    cin >> tests;
    int done = 0;
    while (done < tests) {
        Solve();
        ++done;
    }

    const auto finished = std::chrono::high_resolution_clock::now();
    const auto nanos = std::chrono::duration_cast<std::chrono::nanoseconds>(finished - started).count();
    cerr << "Time measured: " << nanos * 1e-9 << " seconds.\n"; 
    return 0;
}
Tester's code (C++)
#include<bits/stdc++.h>
using namespace std;

#ifdef LOCAL
#include "../debug.h"
#else
#define dbg(...)
#endif

#ifdef LOCAL
struct input_checker {
    // All of stdin, slurped by the constructor, and the read cursor into it.
    string buffer;
    int pos;

    // Character classes that callers may pass as readString patterns.
    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    // Reads every remaining character of cin into buffer; the cursor starts at 0.
    input_checker() {
        pos = 0;
        for (int ch = cin.get(); ch != -1; ch = cin.get()) {
            buffer += static_cast<char>(ch);
        }
    }

    // Index of the first ' ' or '\n' at or after pos (buffer.size() if none).
    int nextDelimiter() {
        const int size = (int) buffer.size();
        int end = pos;
        for (; end < size; ++end) {
            if (buffer[end] == ' ' || buffer[end] == '\n') break;
        }
        return end;
    }

    // Consumes and returns the token starting at pos (empty if pos is on a delimiter).
    string readOne() {
        assert(pos < (int) buffer.size());
        const int stop = nextDelimiter();
        string token = buffer.substr(pos, stop - pos);
        pos = stop;
        return token;
    }

    // Reads a token whose length is in [minl, maxl] and whose characters all
    // belong to `pattern` (any character when pattern is empty).
    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string token = readOne();
        const int len = (int) token.size();
        assert(minl <= len);
        assert(len <= maxl);
        for (char ch : token) {
            assert(pattern.empty() || pattern.find(ch) != string::npos);
        }
        return token;
    }

    // Reads an int token and checks that it lies in [minv, maxv].
    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        const int value = stoi(readOne());
        assert(minv <= value);
        assert(value <= maxv);
        return value;
    }

    // Reads a long long token and checks that it lies in [minv, maxv].
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        const long long value = stoll(readOne());
        assert(minv <= value);
        assert(value <= maxv);
        return value;
    }

    // Reads n ints separated by single spaces.
    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> values;
        values.reserve(n);
        for (int i = 0; i < n; ++i) {
            if (i > 0) readSpace();
            values.push_back(readInt(minv, maxv));
        }
        return values;
    }

    // Reads n long longs separated by single spaces.
    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> values;
        values.reserve(n);
        for (int i = 0; i < n; ++i) {
            if (i > 0) readSpace();
            values.push_back(readLong(minv, maxv));
        }
        return values;
    }

    // Consumes exactly one ' '.
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        ++pos;
    }

    // Consumes exactly one '\n'.
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        ++pos;
    }

    // Checks that the whole input has been consumed.
    void readEof() {
        assert((int) buffer.size() == pos);
    }
};
#else

struct input_checker {
    // Never filled by this variant (the constructor is empty).
    // NOTE(review): pos is left uninitialized here, so nextDelimiter/readOne
    // must not be called on this variant.
    string buffer;
    int pos;

    // Character classes that callers may pass as readString patterns.
    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    // Fast variant: nothing is preloaded, values stream from cin on demand.
    input_checker() {
    }

    // Index of the first ' ' or '\n' at or after pos within buffer.
    int nextDelimiter() {
        const int size = (int) buffer.size();
        int end = pos;
        while (end < size) {
            const char ch = buffer[end];
            if (ch == ' ' || ch == '\n') break;
            ++end;
        }
        return end;
    }

    // Consumes and returns the buffered token starting at pos.
    string readOne() {
        assert(pos < (int) buffer.size());
        const int stop = nextDelimiter();
        string token = buffer.substr(pos, stop - pos);
        pos = stop;
        return token;
    }

    // Reads one whitespace-separated token; length and pattern are not checked.
    string readString(int minl, int maxl, const string &pattern = "") {
        string token;
        cin >> token;
        return token;
    }

    // Reads an int from cin and checks that it lies in [minv, maxv].
    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int value;
        cin >> value;
        assert(minv <= value);
        assert(value <= maxv);
        return value;
    }

    // Reads a long long from cin and checks that it lies in [minv, maxv].
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long value;
        cin >> value;
        assert(minv <= value);
        assert(value <= maxv);
        return value;
    }

    // Reads n ints (separators are skipped by cin).
    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> values;
        values.reserve(n);
        for (int i = 0; i < n; ++i) {
            if (i > 0) readSpace();
            values.push_back(readInt(minv, maxv));
        }
        return values;
    }

    // Reads n long longs (separators are skipped by cin).
    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> values;
        values.reserve(n);
        for (int i = 0; i < n; ++i) {
            if (i > 0) readSpace();
            values.push_back(readLong(minv, maxv));
        }
        return values;
    }

    // Whitespace and end-of-input checks are no-ops in the fast variant.
    void readSpace() {
    }

    void readEoln() {
    }

    void readEof() {
    }
};
#endif

// Sparse table over an array A answering range-MAXIMUM queries in O(1)
// after O(n log n) preprocessing. Despite their names, rangeMin() returns the
// maximum value and minIndx() returns an index of the maximum.
template<class T>
struct RMQ{
    int n, logn;
    // b[k][j]: index of a maximum of A over [j, j + 2^k) (clipped to n);
    // on ties the index from the right half is kept.
    vector<vector<int>> b;
    vector<T> A;

    // Builds the table over a copy of `a`. Requires a non-empty array.
    void build(const vector<T> &a) {
        A = a;
        n = (int)a.size();
        logn = 32 - __builtin_clz(n);
        b.resize(logn, vector<int>(n));
        iota(b[0].begin(), b[0].end(), 0);
        for (int lvl = 1; lvl < logn; ++lvl) {
            const int half = 1 << (lvl - 1);
            const vector<int> &prev = b[lvl - 1];
            vector<int> &cur = b[lvl];
            for (int j = 0; j < n; ++j) {
                int best = prev[j];
                const int other = j + half;
                if (other < n && A[prev[other]] >= A[best]) best = prev[other];
                cur[j] = best;
            }
        }
    }

    // Maximum value of A over [x, y] (inclusive).
    int rangeMin(int x, int y){
        const int k = 31 - __builtin_clz(y - x + 1);
        const T &lhs = A[b[k][x]];
        const T &rhs = A[b[k][y - (1 << k) + 1]];
        return max(lhs, rhs);
    }

    // Index of a maximum of A over [x, y]; the right window wins ties.
    int minIndx(int x, int y){
        const int k = 31 - __builtin_clz(y - x + 1);
        const int left = b[k][x];
        const int right = b[k][y - (1 << k) + 1];
        if (A[left] > A[right]) return left;
        return right;
    }
};

// Tester: validates the input format and answers every query offline by a DFS
// over the forest of greedy links, binary-searching the root-to-node path.
int32_t main() {
  ios_base::sync_with_stdio(0);
  cin.tie(0);

  input_checker inp;
  int T = inp.readInt(1, (int)1e4), NN = 0, NQ = 0; inp.readEoln();
  while(T-- > 0) {
    int N = inp.readInt(1, (int)5e5); inp.readSpace();
    int Q = inp.readInt(1, (int)5e5); inp.readEoln();
    NN += N, NQ += Q;

    // B[i] = i + A[i]: the farthest index reachable from i in one move.
    vector<int> B(N), A = inp.readInts(N, 0, N);  inp.readEoln();
    for(int i = 0 ; i < N ; ++i) {
      assert(i + A[i] < N);
      B[i] = i + A[i];
    }
    RMQ<int> rmq; rmq.build(B);

    // Queries grouped by 0-based start index as (0-based target, query id).
    // They are read through the checker: in LOCAL builds the checker has already
    // consumed all of stdin, so reading them with `cin` would see nothing.
    // This also range-checks both endpoints.
    vector<vector<pair<int, int>>> Query(N);
    for(int i = 0 ; i < Q ; ++i) {
      int l = inp.readInt(1, N); inp.readSpace();
      int r = inp.readInt(1, N); inp.readEoln();
      Query[l - 1].emplace_back(r - 1, i);
    }

    // Forest of greedy links: the parent of i is the index in (i, i + A[i]] with the
    // largest reach (RMQ::minIndx returns the arg-max, ties to the right). The link is
    // kept only if the parent reaches strictly farther than i, so parents have larger
    // indices and strictly larger reach.
    vector<vector<int>> adj(N);
    for(int i = N - 1 ; i >= 0 ; --i) {
      if(A[i] == 0)  continue;
      int parent = rmq.minIndx(i + 1, i + A[i]);
      if(B[parent] > i + A[i]) {
        adj[parent].push_back(i);
      }
    }

    vector<bool> vis(N);
    // sol holds -B[v] for every v on the current root-to-node path. It is sorted
    // ascending because reach strictly grows toward the root. ans defaults to -1.
    vector<int> sol, ans(Q, -1);
    // Entering a node: push it onto the path and answer the queries starting there.
    auto enter = [&](int node) {
      vis[node] = 1;
      sol.push_back(-B[node]);
      for(auto &[r, in]: Query[node]) {
        // Even the tree root (largest reach on the path) cannot reach r.
        if(r > -sol.front()) continue;
        // 1 + number of path nodes whose reach B is < r.
        ans[in] = 1 + (sol.end() - upper_bound(sol.begin(), sol.end(), -r));
      }
    };
    // Iterative DFS: a tree can be a single path of length N (e.g. A = [1, ..., 1, 0]),
    // far too deep for recursion. Each frame is (node, position of next child).
    vector<pair<int, int>> stk;
    for(int i = N - 1 ; i >= 0 ; --i) if(!vis[i]) {
      enter(i);
      stk.emplace_back(i, 0);
      while(!stk.empty()) {
        const int node = stk.back().first;
        if(stk.back().second < (int)adj[node].size()) {
          const int u = adj[node][stk.back().second++];
          if(!vis[u]) {
            enter(u);
            stk.emplace_back(u, 0);
          }
        } else {
          // All children done: leave the node.
          stk.pop_back();
          sol.pop_back();
        }
      }
    }

    for(int i = 0 ; i < Q ; ++i)
      cout << ans[i] << "\n";
  }
  assert(max(NN, NQ) <= (int)5e5);
  inp.readEof();
  
  return 0;
}

Editorialist's code (Python)
import sys
input = sys.stdin.readline
# Editorialist's solution: per test case, compute each index's greedy successor with a
# monotonic stack, build a binary-lifting table over it, and answer each query in O(log n).
for _ in range(int(input())):
    n, q = map(int, input().split())
    # 0-indexed array; a[n] = 0 adds a sentinel position n that can never move.
    a = list(map(int, input().split())) + [0]

    # jump[i]: greedy successor of i, i.e. the index in (i, i + a[i]] with the largest
    # reach k + a[k], or n (sentinel, "stuck") when no index in range reaches beyond i + a[i].
    # Computed right-to-left with a monotonic stack stk[0..ptr-1] (stk[0] = n is a sentinel
    # that is never popped). Going from the top (ptr-1) to the bottom, indices increase
    # and reaches strictly increase.
    stk = [0]*(n+1)
    jump = [0]*(n+1)
    stk[0], ptr = n, 1
    jump[n] = n
    for i in reversed(range(n)):
        if a[i] == 0:
            jump[i] = n
        else:
            # Binary search for the lowest stack slot whose index is <= i + a[i]: the
            # largest in-range index on the stack, hence the largest reach in range.
            lo, hi = 0, ptr-1
            while lo < hi:
                mid = (lo + hi)//2
                if stk[mid] <= i + a[i]: hi = mid
                else: lo = mid + 1
            # Nothing in range reaches farther than i itself: treat i as stuck.
            if i + a[i] >= stk[lo] + a[stk[lo]]: jump[i] = n
            else: jump[i] = stk[lo]
        # Pop entries whose reach does not exceed i's (i is smaller and reaches as far).
        while ptr > 1:
            x = stk[ptr-1]
            if x + a[x] <= i + a[i]: ptr -= 1
            else: break
        stk[ptr] = i
        ptr += 1
    
    # lift[i][j]: index reached from i after 2^j greedy jumps (n is absorbing).
    # Rows are filled right-to-left: lift[i][j-1] is n or an index > i,
    # whose row is therefore already complete.
    lift = [ [0 for _ in range(20)] for _ in range(n+1)]
    for i in range(n+1): lift[i][0] = jump[i]
    for i in reversed(range(n+1)):
        for j in range(1, 20):
            lift[i][j] = lift[lift[i][j-1]][j-1]
    
    for i in range(q):
        x, y = map(int, input().split())
        x, y = x-1, y-1
        # y is within one move of x.
        if y <= x + a[x]:
            print(1)
            continue

        # Binary lifting: take 2^k greedy jumps whenever the landing index u still cannot
        # reach y in one move (y > u + a[u]). ans counts the greedy jumps taken.
        ans = 0
        for k in reversed(range(20)):
            # Defensive: x is only ever moved to indices that cannot reach y, so this
            # break does not fire.
            if x + a[x] >= y: break
            
            u = lift[x][k]
            if y > u + a[u]:
                ans += 2**k
                x = u
        # One more greedy jump, then a final jump to y unless we already landed on it.
        # A stuck chain (jump[x] == n) or one that still cannot reach y gives -1.
        x = jump[x]
        if x <= y <= x + a[x]: print(ans + 1 + (x < y))
        else: print(-1)

Never mind, I think I misunderstood what link was, it’s clear now.

Where is the proof? Please add a proof of the following claim: "Otherwise, it’s not hard to see that it’s optimal to jump to an index i such that (i + A_i) is maximal (and then repeat the process starting from i, till we either reach Y or fail to do so.)"

Oops, didn’t realize I missed writing that - I’ve added it now.
It’s a really simple proof though.

@iceknight1093
I was solving this problem with the same approach: starting from X, find the index j within reach for which j + arr[j] is maximum.
I'm using a segment tree to find the index of the range maximum.
Here is my code, but it's failing and I don't know where the error is. Can you tell me a test case where this logic doesn't work?

#include<bits/stdc++.h>
using namespace std;

// Segment tree over positions 0..n-1: each node stores the index with the largest
// reach in its range. Children of node ind are 2*ind+1 and 2*ind+2; 2000005 slots
// cover the ~4n nodes needed for n up to 5e5.
int segtree[2000005];
// reach[i] = i + arr[i]. Slot 500001 belongs to the sentinel index that query()
// returns for ranges disjoint from the query.
int reach[500002];

// Builds the arg-max segment tree over positions [l, r] rooted at node `ind`:
// each node keeps the index with the larger reach, ties going to the right child.
void build(int l,int r ,int ind)
{
	if (l == r) {
		// Leaf: the only candidate index is l itself.
		segtree[ind] = l;
		return;
	}

	const int mid = (l + r) / 2;
	const int lc = 2 * ind + 1;
	const int rc = 2 * ind + 2;
	build(l, mid, lc);
	build(mid + 1, r, rc);

	const int fromLeft = segtree[lc];
	const int fromRight = segtree[rc];
	segtree[ind] = (reach[fromLeft] > reach[fromRight]) ? fromLeft : fromRight;
}

// Returns the index in [st, en] ∩ [l, r] with the largest reach (ties go to the
// larger index), searching the subtree rooted at node `ind` that covers [l, r].
// Returns the sentinel 500001 only when the intersection is empty.
int query(int l,int r,int ind,int st,int en)
{
	// Disjoint from the query range: empty sentinel.
	if(r<st || l>en)
	{
		return 500001;
	}

	if(l>=st && r<=en)
	{
		return segtree[ind];
	}

	int mid=(l+r)/2;
	int lans=query(l,mid,(ind*2)+1,st,en);
	int rans=query(mid+1,r,(ind*2)+2,st,en);

	// The sentinel must never beat a real index. Previously reach[500001] == 0 tied
	// with a real index whose reach is 0 (index 0 with arr[0] == 0), and the tie went
	// to the sentinel, which the caller then used as a jump target.
	if(lans==500001)
		return rans;
	if(rans==500001)
		return lans;
	return reach[lans]>reach[rans] ? lans : rans;
}

// Reads the test cases and answers each query (x, y) with the minimum number of
// moves from x to y (or -1), using binary lifting over the greedy successor of
// every index instead of simulating every jump per query.
int main() 
{
    #ifndef ONLINE_JUDGE
        freopen("input.txt","r",stdin);
        freopen("output.txt","w",stdout);
    #endif
    // Fast stream I/O: up to 5e5 queries are read and answered.
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    // 2^LOG - 1 exceeds the longest possible chain of greedy jumps (n <= 5e5).
    const int LOG = 20;

    int t;
    cin>>t;
    while(t--)
    {
        int n,q;
        cin>>n>>q;
        vector<int> arr(n);
        for(int i=0;i<n;i++)
            cin>>arr[i];

        for(int i=0;i<n;i++)
            reach[i]=i+arr[i];
        // Sentinel returned by query() for empty ranges: strictly below every real
        // reach (>= 0), so it can never win a comparison against a real index.
        reach[500001]=-1;
        build(0,n-1,0);

        // up[i][k]: index reached from i after 2^k greedy jumps. One greedy jump goes
        // to the index in [i, reach[i]] with the largest reach; up[i][0] == i means
        // that i is stuck.
        vector<array<int, LOG>> up(n);
        for(int i=0;i<n;i++)
            up[i][0]=query(0,n-1,0,i,reach[i]);
        for(int k=1;k<LOG;k++)
            for(int i=0;i<n;i++)
                up[i][k]=up[up[i][k-1]][k-1];

        while(q--)
        {
            int x,y;
            cin>>x>>y;
            x--;y--;

            // One move suffices (this also covers x >= y, since reach[x] >= x).
            if(reach[x]>=y)
            {
                cout<<1<<"\n";
                continue;
            }

            // Take 2^k greedy jumps whenever the landing index still cannot reach y.
            int jumps=0;
            for(int k=LOG-1;k>=0;k--)
            {
                int u=up[x][k];
                if(reach[u]<y)
                {
                    x=u;
                    jumps+=1<<k;
                }
            }

            // One more greedy jump must bring y within reach; otherwise the chain is stuck.
            int nxt=up[x][0];
            if(reach[nxt]>=y)
                cout<<jumps+2<<"\n";
            else
                cout<<-1<<"\n";
        }
    }
    return 0;
}

Your code fails on

1
2 1
0 0
1 2

Once you fix the WA, you can also expect it to receive TLE - manually performing every jump for every query is not fast enough.

2 Likes

This is my submission below for this problem.
Shortest Distance 101
I keep getting TLE even after optimizing it in every way I know.
I don’t know what I am missing.
Please help me with this