PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: raysh07
Tester: sushil2006
Editorialist: iceknight1093
DIFFICULTY:
Easy
PREREQUISITES:
Dynamic programming, binary search/two pointers
PROBLEM:
You’re given an array A containing integers between 1 and K.
For a starting index S, a sequence i_1, i_2, \ldots, i_{K+1} of K+1 indices is said to be good if i_1 = S and A_{i_j} = j-1 for each j \gt 1.
The cost of a good sequence is defined to be \sum_{j=1}^K |i_j - i_{j+1}|.
For each starting index, compute the minimum possible cost of a good sequence.
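As a tiny illustrative instance (my own example, not from the problem statement): take A = [2, 1, 2] with K = 2. For starting index S = 1, the sequence (1, 2, 3) is good, since i_1 = S, A_2 = 1, and A_3 = 2, and its cost can be computed directly:

```python
# Hypothetical tiny instance (not from the problem statement):
# A = [2, 1, 2] (1-indexed), K = 2, starting index S = 1.
# The sequence (1, 2, 3) is good: i_1 = S, A_2 = 1, A_3 = 2.
seq = [1, 2, 3]
cost = sum(abs(seq[j] - seq[j + 1]) for j in range(len(seq) - 1))
print(cost)  # |1-2| + |2-3| = 2
```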
EXPLANATION:
Let’s solve the problem for a single starting index S first.
Starting at S, we need to move to an occurrence of 1, then an occurrence of 2, then a 3, and so on till we end up at a K.
Which occurrence of 1 should we choose?
It’s not hard to see that there are only two valid options: either move to the closest 1 to the left of S, or the closest 1 to the right of S.
Though it’s not immediately obvious which one of these to choose, observe that once we’ve made the choice, we once again have two choices: move to the nearest 2 that’s to either the left or the right of the chosen index.
In general, if we’ve moved to value x at index i, the next move is going to be to move to the nearest occurrence of x+1 that’s either to the left or the right of index i.
Proof of optimality
Suppose A_i = x, and in an optimal sequence the next move is to index j. Without loss of generality, assume j \lt i.
Let k be the nearest index to the left of i with value x+1; in particular, this means j \leq k.
Our claim is that moving to k is never worse than moving to j. To see why, consider what happens on the move after j.
Suppose the move after j is j \to m.
Then,
- If m \lt j, the sequences i \to k \to m and i \to j \to m have the same cost (namely i-m), since we're moving leftward the whole time.
- If m \geq k, the sequence i \to k \to m is cheaper than i \to j \to m; in particular, the reduction in cost is exactly 2\cdot (k-j).
- If j \leq m \lt k, once again i \to k \to m is cheaper: here j \leq m \lt k \lt i, so the single move i \to j already passes through both k and m, costing (i - m) + (m - j), and the move j \to m then adds another (m - j); meanwhile i \to k \to m costs only i - m, making it cheaper by 2\cdot (m-j).
So, in every case it's at least as good to move to k instead.
This proves that if the optimal move is to the left, then it’s best to choose the nearest element.
The same applies to moving right.
Since there are at most two options at each step, and the future cost depends only on the current index (not on how we reached it), we can now solve the problem using dynamic programming.
Define dp_i to be the minimum cost of reaching value K, assuming we’re starting from index i with value A_i.
To compute dp_i, first find indices L_i and R_i, where L_i \lt i is the nearest index to the left of i containing an occurrence of A_i + 1, and R_i \gt i is the same but to the right.
We then have dp_i = \min(i - L_i + dp_{L_i}, R_i - i + dp_{R_i}) because we can choose to move either left or right with the appropriate cost, and then solve recursively from there.
The base case is, of course, dp_i = 0 if A_i = K.
To make this run quickly, either implement it recursively and add memoization, or just process indices in decreasing order of value.
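For illustration, the memoized top-down version might look like the following (a sketch in Python; the function names are my own, and it assumes every value from 1 to K appears in the array):

```python
import bisect
from functools import lru_cache

def min_costs(a, k):
    """Sketch of the top-down DP; `a` is 0-indexed, every value 1..k occurs."""
    n = len(a)
    pos = [[] for _ in range(k + 1)]
    for i, v in enumerate(a):
        pos[v].append(i)  # indices are appended in increasing order

    @lru_cache(maxsize=None)
    def dp(i):
        # Minimum cost to reach value k, starting at index i (holding value a[i]).
        if a[i] == k:
            return 0
        nxt = pos[a[i] + 1]
        loc = bisect.bisect_left(nxt, i)  # first occurrence of a[i]+1 at or after i
        best = float('inf')
        if loc < len(nxt):                # move right to R_i
            best = min(best, nxt[loc] - i + dp(nxt[loc]))
        if loc > 0:                       # move left to L_i
            best = min(best, i - nxt[loc - 1] + dp(nxt[loc - 1]))
        return best

    # Answer for each starting index: first hop to the nearest 1 on either side.
    ans = []
    for s in range(n):
        loc = bisect.bisect_left(pos[1], s)
        best = float('inf')
        if loc < len(pos[1]):
            j = pos[1][loc]
            best = min(best, j - s + dp(j))
        if loc > 0:
            j = pos[1][loc - 1]
            best = min(best, s - j + dp(j))
        ans.append(best)
    return ans
```

For A = [2, 1, 2] and K = 2 this returns [2, 1, 2]. Note the recursion depth can reach K, so for large K the bottom-up order used in the codes below is safer.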
The only part of the above algorithm that needs optimization is actually computing the indices L_i and R_i.
To do this quickly, create a sequence of K lists \text{pos}, where \text{pos}[x] is a sorted list of all indices containing the value x.
Then, L_i and R_i are simply the closest elements of \text{pos}[A_i + 1] to i, and can be found quickly using binary search (or a two-pointer algorithm, if you process all indices corresponding to a value from left to right).
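For example, here is one way the binary-search lookup could be sketched with Python's bisect module (the helper name is my own):

```python
import bisect

def nearest_occurrences(sorted_positions, i):
    # Given the sorted list of indices holding some value, return the nearest
    # such index to the left of i and to the right of i (None when no such
    # index exists). Assumes i itself is not in the list.
    loc = bisect.bisect_left(sorted_positions, i)
    left = sorted_positions[loc - 1] if loc > 0 else None
    right = sorted_positions[loc] if loc < len(sorted_positions) else None
    return left, right
```

For instance, nearest_occurrences([2, 5, 9], 6) gives (5, 9).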
Now that we know all the dp_i values, we need to compute the final answer.
This is just more of the same: for a starting index S, the first move is to go either left or right to an occurrence of 1, and then proceed from there.
So, if L and R are the nearest occurrences of 1 to the left and right of S respectively, the answer for S is simply

\min(S - L + dp_{L},\ R - S + dp_{R})

with the corresponding option dropped if L or R doesn't exist (and if A_S = 1, then S itself counts as such an occurrence, making the first move free).
Of course, finding L and R can be done quickly just as before, by looking at \text{pos}[1].
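For small inputs, the whole solution can be sanity-checked against an exhaustive search over all good sequences (a sketch; the function name is my own):

```python
from itertools import product

def brute_force(a, k):
    # Exhaustive check for tiny inputs: for each start s, try every way of
    # picking one occurrence of each value 1..k and take the cheapest.
    n = len(a)  # a is 0-indexed here
    pos = [[i for i in range(n) if a[i] == v] for v in range(k + 1)]
    ans = []
    for s in range(n):
        best = min(
            sum(abs(seq[t] - seq[t + 1]) for t in range(k))
            for choice in product(*pos[1:])
            for seq in [(s,) + choice]
        )
        ans.append(best)
    return ans
```

This tries all \prod_x |\text{pos}[x]| candidate sequences per start, so it's only usable for verification on very small arrays.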
TIME COMPLEXITY:
\mathcal{O}(N) or \mathcal{O}(N\log N) per testcase.
CODE:
Editorialist's code (PyPy3, binary search)
import bisect

for _ in range(int(input())):
    n, k = map(int, input().split())
    a = list(map(int, input().split()))
    pos = [[] for _ in range(k + 1)]
    for i in range(n):
        pos[a[i]].append(i)

    dp = [n**2] * n
    for i in pos[k]:
        dp[i] = 0
    for x in reversed(range(1, k)):
        for i in pos[x]:
            loc = bisect.bisect_left(pos[x+1], i)
            if loc != len(pos[x+1]):
                j = pos[x+1][loc]
                dp[i] = dp[j] + abs(i - j)
            if loc != 0:
                j = pos[x+1][loc-1]
                dp[i] = min(dp[i], dp[j] + abs(i - j))

    ans = [n**2] * n
    for i in range(n):
        loc = bisect.bisect_left(pos[1], i)
        if loc != len(pos[1]):
            j = pos[1][loc]
            ans[i] = dp[j] + abs(i - j)
        if loc != 0:
            j = pos[1][loc-1]
            ans[i] = min(ans[i], dp[j] + abs(i - j))
    print(*ans)
Author's code (C++, two pointers)
#include <bits/stdc++.h>
using namespace std;

#define int long long
#define INF (int)1e18

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

void Solve()
{
    int n, k; cin >> n >> k;

    vector <int> a(n + 1);
    for (int i = 1; i <= n; i++){
        cin >> a[i];
        assert(1 <= a[i] && a[i] <= k);
    }

    vector<vector<int>> pos(n + 1);
    vector <int> dp(n + 1), ans(n + 1);
    for (int i = 1; i <= n; i++){
        pos[a[i]].push_back(i);
        pos[0].push_back(i);
    }

    for (int i = 1; i <= k; i++){
        assert(pos[i].size());
    }

    for (int i = k - 1; i >= 0; i--){
        int ptr = 0;
        for (int x : pos[i]){
            while (ptr + 1 < pos[i + 1].size() && pos[i + 1][ptr + 1] < x){
                ptr++;
            }

            int y = pos[i + 1][ptr];
            int res = dp[y] + abs(y - x);
            if (ptr + 1 < pos[i + 1].size()){
                y = pos[i + 1][ptr + 1];
                res = min(res, dp[y] + abs(y - x));
            }

            if (i == 0){
                ans[x] = res;
            } else {
                dp[x] = res;
            }
        }
    }

    for (int i = 1; i <= n; i++){
        cout << ans[i] << " \n"[i == n];
    }
}

int32_t main()
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in", "r", stdin);
    // freopen("out", "w", stdout);
    cin >> t;
    for(int i = 1; i <= t; i++)
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n";
    return 0;
}