PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: sushil2006
Tester: sushil2006
Editorialist: iceknight1093
DIFFICULTY:
Medium
PREREQUISITES:
Segment trees, binary search
PROBLEM:
There’s a 2\times N binary grid.
In one operation, you can swap two adjacent elements within the same row.
Find the minimum number of operations needed to obtain a path of ones from (1, 1) to (2, N) that moves only right or down at every step.
Also process Q point updates to the grid, each of which flips one character.
Compute the answer after each update.
EXPLANATION:
Let’s first recap the solution for a single, fixed grid: try each value of k from 1 to N, and compute the minimum cost of making the first k elements of the top row and the last N+1-k elements of the bottom row all ones.
For a fixed k, this could be done greedily: we make the leftmost k ones in the top row end up at positions 1, 2, \ldots, k, and the rightmost N+1-k ones in the bottom row end up at positions k, k+1, \ldots, N.
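For testing faster solutions on small inputs, this greedy can be written as an O(N^2) brute force that tries every k independently. This is a minimal sketch, assuming the rows are given as '0'/'1' strings of equal length N and that -1 means no path can be formed (the author's code below prints -1 in that case); the name `bruteForce` is hypothetical.

```cpp
#include <cassert>
#include <string>
#include <vector>
using namespace std;

// Brute force for one fixed grid: for every k, move the leftmost k ones of the
// top row to columns 1..k and the rightmost N+1-k ones of the bottom row to
// columns k..N. Returns -1 if no k is feasible.
long long bruteForce(const string& top, const string& bottom) {
    int n = top.size();
    vector<int> x, y;  // 1-indexed columns of the ones in each row
    for (int j = 0; j < n; j++) {
        if (top[j] == '1') x.push_back(j + 1);
        if (bottom[j] == '1') y.push_back(j + 1);
    }
    long long best = -1;
    for (int k = 1; k <= n; k++) {
        int need = n + 1 - k;  // ones needed in the bottom row
        if ((int)x.size() < k || (int)y.size() < need) continue;
        long long cost = 0;
        // The i-th one of the top row (1-indexed) moves left to column i.
        for (int i = 0; i < k; i++) cost += x[i] - (i + 1);
        // The rightmost `need` ones of the bottom row move right to columns k..N.
        int first = y.size() - need;
        for (int t = 0; t < need; t++) cost += (k + t) - y[first + t];
        if (best < 0 || cost < best) best = cost;
    }
    return best;
}
```

For example, on the grid with rows 101 and 011 the only feasible choice is k = 2, which needs a single swap in the top row.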
This algorithm isn’t really conducive to supporting updates, so we need a different perspective.
First, for simplicity, let’s reverse the second row.
Now, our aim in both rows is to make some prefix all ones.
In particular, observe that as long as the total number of ones across the prefixes of the two rows is at least N+1, we’ll have a valid path in the original problem.
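As a sanity check of this reformulation, the sketch below decides whether a grid already contains a valid path (no swaps) in two ways: directly, by searching for a column k where the top prefix and bottom suffix are all ones, and via the reversed-row criterion, where the all-ones prefixes of the top row and of the reversed bottom row must have total length at least N+1. The helper names are hypothetical.

```cpp
#include <cassert>
#include <string>
using namespace std;

// Direct check: is there a column k (1-indexed) such that top[1..k] and
// bottom[k..N] are all ones?
bool hasPathDirect(const string& top, const string& bottom) {
    int n = top.size();
    for (int k = 1; k <= n; k++) {
        bool ok = true;
        for (int j = 0; j < k && ok; j++) ok = top[j] == '1';
        for (int j = k - 1; j < n && ok; j++) ok = bottom[j] == '1';
        if (ok) return true;
    }
    return false;
}

// Length of the longest all-ones prefix of s.
int onesPrefix(const string& s) {
    int p = 0;
    while (p < (int)s.size() && s[p] == '1') p++;
    return p;
}

// Reversed-row criterion: with the bottom row reversed, a path exists
// iff the two all-ones prefixes have total length at least N + 1.
bool hasPathPrefixes(const string& top, const string& bottom) {
    string rev(bottom.rbegin(), bottom.rend());
    return onesPrefix(top) + onesPrefix(rev) >= (int)top.size() + 1;
}
```

Both lengths are at most N, so a total of at least N+1 automatically forces each prefix to be non-empty.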
Now, define C_i to be the cost of moving the i-th one in the top row to position i (with C_i = \infty if there aren’t i ones).
An important observation here is that C_i \leq C_{i+1} for any i. This is because the (i+1)-th one lies at least one position to the right of the i-th one, while its target is exactly one position to the right of the i-th one’s target; so it must travel at least as far.
Similarly, if we define D_i to be the cost of moving the i-th one in the bottom row to position i, we have D_i \leq D_{i+1}.
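Since the i-th one of a row sits at some index x_i \geq i, the cost of moving it to position i is x_i - i. A minimal sketch (hypothetical helper names) that builds C from the top row and D from the reversed bottom row, storing only the finite entries in a 0-indexed vector, together with a check of the monotonicity claim:

```cpp
#include <cassert>
#include <string>
#include <vector>
using namespace std;

// costs[i-1] = x_i - i, where x_i is the 1-indexed column of the i-th one in
// `row`. Only as many entries as there are ones are stored; the missing
// entries play the role of infinity.
vector<long long> costArray(const string& row) {
    vector<long long> costs;
    for (int j = 0; j < (int)row.size(); j++)
        if (row[j] == '1') costs.push_back((j + 1) - (long long)(costs.size() + 1));
    return costs;
}

// D is built the same way, from the bottom row read right to left.
vector<long long> reversedCostArray(const string& bottom) {
    return costArray(string(bottom.rbegin(), bottom.rend()));
}

// Checks C_i <= C_{i+1} for all i.
bool nonDecreasing(const vector<long long>& v) {
    for (size_t i = 0; i + 1 < v.size(); i++)
        if (v[i] > v[i + 1]) return false;
    return true;
}
```

For instance, the top row 0111 gives C = (1, 1, 1), and the bottom row 1101 (reversed: 1011) gives D = (0, 1, 1).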
Using these arrays C and D, let’s compute the answer.
An initial candidate is to have a single 1 in the top row and N ones in the bottom row.
This has a cost of C_1 + (D_1 + D_2 + \ldots + D_N).
Alternately, we could have two ones in the top row and N-1 in the bottom row.
This incurs an additional cost of C_2 - D_N over the above, since the only change is including C_2 and excluding D_N.
Next, we could have three ones in the top row and N-2 in the bottom row; this would add a further cost of C_3 - D_{N-1}, and so on.
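In general, the cost of having k ones in the top row and N+1-k in the bottom row is
\text{cost}(k) = \sum_{i=1}^{k} C_i + \sum_{i=1}^{N+1-k} D_i = C_1 + \sum_{i=1}^{N} D_i + \sum_{i=2}^{k} \left(C_i - D_{N+2-i}\right)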
The objective is to minimize this across all k.
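The incremental argument (each extra one in the top row adds C_k and drops D_{N+2-k}) translates directly into an O(N) scan for a single grid. A minimal sketch, assuming C and D are 0-indexed vectors holding only the finite costs, so their sizes are the numbers of ones in the top row and the reversed bottom row; `scanCosts` is a hypothetical name. It starts from the smallest feasible k rather than k = 1, since small k may need more than the available bottom-row ones.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>
using namespace std;

// C[i-1] = C_i and D[i-1] = D_i; only finite costs are stored, so
// C.size() and D.size() are the numbers of ones in the two rows.
// Returns the minimum over feasible k of cost(k), or -1 if no k is feasible.
long long scanCosts(const vector<long long>& C, const vector<long long>& D, int n) {
    int a = C.size(), b = D.size();
    int lo = max(1, n + 1 - b), hi = min(a, n);  // feasible range of k
    if (lo > hi) return -1;
    // Cost of the smallest feasible k, computed directly.
    long long cost = 0;
    for (int i = 0; i < lo; i++) cost += C[i];
    for (int i = 0; i < n + 1 - lo; i++) cost += D[i];
    long long best = cost;
    // Each extra top-row one adds C_k and drops D_{N+2-k} (0-indexed: D[n+1-k]).
    for (int k = lo + 1; k <= hi; k++) {
        cost += C[k - 1] - D[n + 1 - k];
        best = min(best, cost);
    }
    return best;
}
```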
Here, we can use the fact that the arrays C and D are non-decreasing: for any i, C_i - D_{N+2-i} \leq C_{i+1} - D_{N+1-i}, because C_i \leq C_{i+1} and D_{N+2-i} \geq D_{N+1-i}.
So the increments are non-decreasing in k. If C_i - D_{N+2-i} \gt 0 for some i, then going from i-1 to i ones in the top row strictly increases the cost, and every later increment is positive as well; so having i or more ones in the top row is never optimal.
This means finding the optimal k is in fact quite easy: among the feasible values of k (the top row has at least k ones and the bottom row at least N+1-k ones), we only need to find the largest k such that C_k - D_{N+2-k} \leq 0, or the smallest feasible k if there is no such k.
Assuming we’re able to maintain the arrays C and D, finding this k quickly is not hard: since the function C_k - D_{N+2-k} is monotonic in k, we can simply binary search to find the breakpoint we’re looking for.
Once k is known, the answer is (C_1 + C_2 + \ldots + C_k) + (D_1 + D_2 + \ldots + D_{N+1-k}), i.e. one prefix sum of C plus one prefix sum of D.
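Putting the binary search and the final prefix sums together for a single grid, a minimal sketch under the same assumptions as a plain-array version (C and D as 0-indexed vectors of finite costs; `optimalCost` is a hypothetical name). Here the two sums are plain loops; in the full solution they come from the segment tree instead.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>
using namespace std;

// C[i-1] = C_i, D[i-1] = D_i (finite entries only).
// Returns the minimum total cost, or -1 if no k is feasible.
long long optimalCost(const vector<long long>& C, const vector<long long>& D, int n) {
    int a = C.size(), b = D.size();
    int lo = max(1, n + 1 - b), hi = min(a, n);  // feasible range of k
    if (lo > hi) return -1;
    // Increment C_k - D_{N+2-k}, only queried for k >= (smallest feasible k) + 1.
    auto inc = [&](int k) { return C[k - 1] - D[n + 1 - k]; };
    // inc is non-decreasing in k: find the largest k with inc(k) <= 0,
    // falling back to the smallest feasible k if there is none.
    while (lo < hi) {
        int mid = (lo + hi + 1) / 2;
        if (inc(mid) <= 0) lo = mid;
        else hi = mid - 1;
    }
    int k = lo;
    // Answer = (C_1 + ... + C_k) + (D_1 + ... + D_{N+1-k}).
    long long ans = 0;
    for (int i = 0; i < k; i++) ans += C[i];
    for (int i = 0; i < n + 1 - k; i++) ans += D[i];
    return ans;
}
```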
Everything we did above hinges on being able to maintain information about the arrays C and D across updates.
Let’s look at exactly what we need and how to maintain it across updates.
First, we need to find the optimal value of k, i.e. the largest k for which C_k - D_{N+2-k} \leq 0.
For this, we’re binary searching on k, and only need to be able to look up individual values of C and D given their indices.
One way of doing that is as follows.
Let A denote the top row and B the reversed bottom row, and let S_A = \{x_1, x_2, \ldots, x_m\} with x_1 \lt x_2 \lt \ldots \lt x_m denote the indices of the ones in A.
Note that we have C_i = x_i - i, so looking up the k-th element of C is equivalent to looking up the k-th element of S_A and subtracting k from it.
Each update either inserts or deletes one element from S_A, so we need a data structure that supports quick insertion/deletion and looking up the k-th element.
This is a classical problem, and can be handled by for example a segment tree.
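If a ready-made structure is preferred, GCC's policy-based data structures (a GCC/libstdc++ extension, not standard C++) provide an order-statistics tree with `find_by_order`. This swaps in a different structure from the segment tree used in the author's code; the sketch below maintains S_A under flips and reads C_k = x_k - k, and the helper names are hypothetical.

```cpp
#include <cassert>
#include <functional>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace __gnu_pbds;

// Ordered set of the 1-indexed columns holding a one (S_A).
typedef tree<int, null_type, std::less<int>, rb_tree_tag,
             tree_order_statistics_node_update> OrderedSet;

// Flip column `col`: insert it if it held a zero, erase it if it held a one.
void flipColumn(OrderedSet& s, int col) {
    if (s.find(col) != s.end()) s.erase(col);
    else s.insert(col);
}

// C_k = x_k - k, where x_k is the k-th smallest element.
// find_by_order is 0-indexed; requires 1 <= k <= s.size().
long long costAt(OrderedSet& s, int k) {
    return *s.find_by_order(k - 1) - (long long)k;
}
```

Both operations take O(\log N), but the segment tree is still needed later for the sums of the smallest indices, which this container does not provide directly.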
Specifically, suppose we build a segment tree on the string A itself, where each node stores the sum of the bits in its range, i.e. the number of ones there.
Then, finding the k-th element is equivalent to finding the first index with a prefix sum that’s \geq k (which can be done in \mathcal{O}(\log^2 N) with binary search, or \mathcal{O}(\log N) by “walking on the segment tree”).
Insertion and deletion are just point set updates and trivially handled in \mathcal{O}(\log N) by the segment tree.
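A minimal sketch of this count tree with the "walking" k-th query, assuming a non-empty row given as a '0'/'1' string with columns 1..N; `CountTree` and its members are hypothetical names. Combined with C_k = x_k - k, a lookup of C_k is just `kth(k) - k`.

```cpp
#include <cassert>
#include <string>
#include <vector>
using namespace std;

// Segment tree over a 0/1 row (columns 1..n, n >= 1);
// each node stores how many ones lie in its range.
struct CountTree {
    int n;
    vector<int> cnt;
    explicit CountTree(const string& row) : n(row.size()), cnt(4 * row.size() + 4, 0) {
        build(1, 1, n, row);
    }
    void build(int node, int l, int r, const string& row) {
        if (l == r) { cnt[node] = row[l - 1] == '1'; return; }
        int mid = (l + r) / 2;
        build(2 * node, l, mid, row);
        build(2 * node + 1, mid + 1, r, row);
        cnt[node] = cnt[2 * node] + cnt[2 * node + 1];
    }
    // Point update: set column `pos` to `bit` (0 or 1).
    void update(int pos, int bit) { update(1, 1, n, pos, bit); }
    void update(int node, int l, int r, int pos, int bit) {
        if (l == r) { cnt[node] = bit; return; }
        int mid = (l + r) / 2;
        if (pos <= mid) update(2 * node, l, mid, pos, bit);
        else update(2 * node + 1, mid + 1, r, pos, bit);
        cnt[node] = cnt[2 * node] + cnt[2 * node + 1];
    }
    // Column of the k-th one (1 <= k <= cnt[1]), found by walking down from
    // the root: go left if the left child has at least k ones.
    int kth(int k) const {
        int node = 1, l = 1, r = n;
        while (l < r) {
            int mid = (l + r) / 2;
            if (k <= cnt[2 * node]) { node = 2 * node; r = mid; }
            else { k -= cnt[2 * node]; node = 2 * node + 1; l = mid + 1; }
        }
        return l;
    }
};
```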
The exact same thing can be done on B to compute values of D quickly.
With this structure, we’re able to find the optimal k in \mathcal{O}(\log^2 N) time.
It’s also possible to optimize this to \mathcal{O}(\log N) by walking on both segment trees simultaneously, though this shouldn’t be necessary to receive AC.
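Once we know k, we want to find the answer, which equals
\sum_{i=1}^{k} C_i + \sum_{i=1}^{N+1-k} D_i
Again, we use the fact that C_i = x_i - i, to see that
\sum_{i=1}^{k} C_i = (x_1 + x_2 + \ldots + x_k) - \frac{k(k+1)}{2}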
We know k, so the second term on the right is a constant.
This means we only need to find x_1 + x_2 + \ldots + x_k, i.e. the sum of the smallest k indices containing ones (the D part is handled identically on B, with N+1-k in place of k).
This can be done by slightly extending the segment tree from above: in each node, also store the sum of all “active” indices (those holding a one) within it. The required sum is then a single prefix-sum query up to the position of the k-th one.
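The same two queries (k-th one, and the sum of the first k active indices) can also be served by a pair of Fenwick trees, one over the indicator bits and one over the active indices. This swaps in a different structure from the author's segment tree; a minimal sketch with hypothetical names, which finds the k-th one by binary lifting and then reads the index-sum prefix up to it. Subtracting k(k+1)/2 from `sumFirst(k)` gives C_1 + \ldots + C_k.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Two Fenwick trees over columns 1..n: cnt counts ones, idx sums their columns.
struct OnesFenwick {
    int n, LOG;
    vector<long long> cnt, idx;
    explicit OnesFenwick(int n) : n(n), LOG(1), cnt(n + 1, 0), idx(n + 1, 0) {
        while ((1 << LOG) <= n) LOG++;
    }
    // Add (delta = +1) or remove (delta = -1) a one at column pos.
    void add(int pos, int delta) {
        for (int i = pos; i <= n; i += i & -i) {
            cnt[i] += delta;
            idx[i] += (long long)delta * pos;
        }
    }
    // Sum of the active columns in [1, pos].
    long long prefixIdx(int pos) const {
        long long s = 0;
        for (int i = pos; i > 0; i -= i & -i) s += idx[i];
        return s;
    }
    // Column of the k-th one (1 <= k <= number of ones), via binary lifting.
    int kth(long long k) const {
        int pos = 0;
        for (int step = 1 << LOG; step > 0; step >>= 1) {
            if (pos + step <= n && cnt[pos + step] < k) {
                pos += step;
                k -= cnt[pos];
            }
        }
        return pos + 1;
    }
    // x_1 + ... + x_k: the sum of the k smallest columns holding a one.
    long long sumFirst(int k) const { return k == 0 ? 0 : prefixIdx(kth(k)); }
};
```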
As a side note, it’s also possible to maintain the arrays C and D themselves directly without having to resort to index manipulation like we did above, though doing so doesn’t really improve the complexity in any way.
It requires a bit more machinery: for example you can use a treap (that supports lazy propagation).
TIME COMPLEXITY:
\mathcal{O}(N + Q\log^2 N) per testcase.
CODE:
Author's code (C++)
```cpp
#include <bits/stdc++.h>
using namespace std;

#define int long long
#define INF (int)1e18

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

const int N = 2e5 + 69;

// Segment tree node: number of ones (cnt) and sum of their indices (sum) in its range.
struct node{
    int sum, cnt;
};

int n;
string s[2];        // s[0]: top row, s[1]: reversed bottom row (1-indexed after padding)
node seg[2][4 * N]; // one segment tree per row

node combine(node a, node b){
    node c;
    c.sum = a.sum + b.sum;
    c.cnt = a.cnt + b.cnt;
    return c;
}

// Leaf value for column pos holding character ch.
node init(char ch, int pos){
    node c;
    if (ch == '1'){
        c.sum = pos;
        c.cnt = 1;
    } else {
        c.sum = c.cnt = 0;
    }
    return c;
}

void build(int l, int r, int pos, int i){
    if (l == r){
        seg[i][pos] = init(s[i][l], l);
        return;
    }
    int mid = (l + r) / 2;
    build(l, mid, pos * 2, i);
    build(mid + 1, r, pos * 2 + 1, i);
    seg[i][pos] = combine(seg[i][pos * 2], seg[i][pos * 2 + 1]);
}

// Refresh leaf qp of tree i from s[i][qp] and recombine its ancestors.
void upd(int l, int r, int pos, int qp, int i){
    if (l == r){
        seg[i][pos] = init(s[i][l], l);
        return;
    }
    int mid = (l + r) / 2;
    if (qp <= mid) upd(l, mid, pos * 2, qp, i);
    else upd(mid + 1, r, pos * 2 + 1, qp, i);
    seg[i][pos] = combine(seg[i][pos * 2], seg[i][pos * 2 + 1]);
}

// Index of the k-th one in row i, found by walking down the tree.
int find_kth(int l, int r, int pos, int k, int i){
    if (l == r){
        return l;
    }
    int mid = (l + r) / 2;
    if (k <= seg[i][pos * 2].cnt){
        return find_kth(l, mid, pos * 2, k, i);
    } else {
        return find_kth(mid + 1, r, pos * 2 + 1, k - seg[i][pos * 2].cnt, i);
    }
}

// Accumulates the nodes covering [ql, qr] of tree i into the global res.
node res;
void query(int l, int r, int pos, int ql, int qr, int i){
    if (l >= ql && r <= qr){
        res = combine(res, seg[i][pos]);
    } else if (l > qr || r < ql){
    } else {
        int mid = (l + r) / 2;
        query(l, mid, pos * 2, ql, qr, i);
        query(mid + 1, r, pos * 2 + 1, ql, qr, i);
    }
}

void Solve()
{
    int q;
    cin >> n >> q;
    cin >> s[0] >> s[1];
    reverse(s[1].begin(), s[1].end());
    s[0] = "0" + s[0];
    s[1] = "0" + s[1];
    // c0, c1: number of ones in each row.
    int c0 = 0, c1 = 0;
    for (auto x : s[0]){
        c0 += x == '1';
    }
    for (auto x : s[1]){
        c1 += x == '1';
    }
    build(1, n, 1, 0);
    build(1, n, 1, 1);
    auto calc = [&](){
        // At most N ones in total: no path can be formed.
        if (c0 + c1 <= n){
            cout << -1 << "\n";
            return;
        }
        // Feasible k: at least k ones on top and N + 1 - k on the bottom.
        int lo = n + 1 - c1;
        int hi = c0;
        // Binary search for the largest k with C_k <= D_{N+1-k} (p1 = C_k, p2 = D_{N+1-k});
        // the optimal k is this value or the next one, so both are evaluated below.
        while (lo != hi){
            int mid = (lo + hi + 1) / 2;
            int v1 = mid;
            int v2 = n + 1 - mid;
            int p1 = find_kth(1, n, 1, v1, 0) - v1;
            int p2 = find_kth(1, n, 1, v2, 1) - v2;
            if (p1 > p2){
                hi = mid - 1;
            } else {
                lo = mid;
            }
        }
        int ans = INF;
        // Total cost with x ones in the top row: for each row, the sum of the
        // indices of its first v ones minus v(v+1)/2.
        auto work = [&](int x){
            int v1 = x;
            int v2 = n + 1 - x;
            int p1 = find_kth(1, n, 1, v1, 0);
            int p2 = find_kth(1, n, 1, v2, 1);
            int val = 0;
            res.sum = res.cnt = 0;
            query(1, n, 1, 1, p1, 0);
            val += res.sum;
            res.sum = res.cnt = 0;
            query(1, n, 1, 1, p2, 1);
            val += res.sum;
            val -= v1 * (v1 + 1) / 2;
            val -= v2 * (v2 + 1) / 2;
            ans = min(ans, val);
        };
        work(lo);
        if (lo + 1 <= c0){
            work(lo + 1);
        }
        cout << ans << "\n";
    };
    calc();
    while (q--){
        int x, y; cin >> x >> y;
        x--;
        // The bottom row is stored reversed, so mirror the column.
        if (x == 1){
            y = n + 1 - y;
        }
        // Flip the character, keep the count of ones in sync, and update the tree.
        if (x == 0){
            c0 -= s[0][y] == '1';
            c0 += !(s[0][y] == '1');
            s[0][y] ^= '0' ^ '1';
            upd(1, n, 1, y, 0);
        } else {
            c1 -= s[1][y] == '1';
            c1 += !(s[1][y] == '1');
            s[1][y] ^= '0' ^ '1';
            upd(1, n, 1, y, 1);
        }
        calc();
    }
}

int32_t main()
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    cin >> t;
    for(int i = 1; i <= t; i++)
    {
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n";
    return 0;
}
```