PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: satyam_343
Testers: apoorv_me, tabr
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
None
PROBLEM:
You’re given a binary string S of length N. You may flip at most K of its characters (changing a 0 to 1 or a 1 to 0).
Minimize the number of inversions in the final string, i.e. the number of pairs i \lt j with S_i = 1 and S_j = 0.
EXPLANATION:
Suppose we choose some elements to flip, of which x are initially 1's and y are initially 0's. (Of course, x+y \leq K.)
Then, it can be observed that:
- It’s optimal to choose the leftmost x 1's in S.
- It’s optimal to choose the rightmost y 0's in S.
- Every chosen 1 should appear before every chosen 0.
Proof
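Suppose the 1 at position i is flipped and the 1 at position j isn’t flipped, but i \gt j.
Then, choosing to flip the 1 at index j instead will strictly reduce the number of inversions in the resulting binary string.
To see this, compare the two resulting strings, which differ only at indices j and i. If index i is flipped, indices j \lt i hold 1 and 0 respectively, so (j, i) is an inversion; if index j is flipped instead, they hold 0 and 1, so it isn’t. Every index strictly between j and i also loses exactly one inversion (with j if it holds a 0, with i if it holds a 1), and the number of inversions involving an index outside [j, i] is unchanged. So no non-inversion turns into an inversion, and the count drops by exactly i - j.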
This proves that it’s optimal to choose the first x 1's in S to flip.
A similar proof shows that flipping the last y 0's in S is optimal.
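Finally, note that if we flip a 0 at index i and a 1 at index j with i \lt j, then we’d have strictly fewer inversions by simply not performing either of these two flips (which is allowed, since we can perform \leq K flips and not exactly K). Indices i and j go from holding 1 and 0 to holding 0 and 1 while every other character is unchanged. This removes the inversion (i, j), plus one inversion for every index strictly between them, and creates no new inversion. (Undoing only the flip at index i is not enough in general: for S = 11101, flipping indices 3 and 4 gives 11110 with 4 inversions, while flipping only index 4 gives 11100 with 6.)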
This tells us that there are really only \mathcal{O}(K) candidate strings to check.
These candidate strings are obtained by flipping the first x 1's in S (for each 0 \leq x \leq K), and then flipping the last (K-x) 0's in S (or every 0 after the x'th 1, if there are fewer than (K-x) of them).
Further, each candidate can be checked in linear time - after all, the inversion count of a binary string can be computed in linear time.
How?
Since the string is binary, the number of inversions is just the number of times \texttt{10} appears as a subsequence in it.
This can be computed with a simple loop: for each i from 1 to N, if S_i = 0, add to the answer the number of ones before index i.
We have \mathcal{O}(K) candidates, and each is checked in \mathcal{O}(N) time, giving us an \mathcal{O}(N\cdot K) algorithm.
Since K \leq N and the constraints guarantee that the sum of N^2 is bounded, this is fast enough.
TIME COMPLEXITY:
\mathcal{O}(N\cdot K) per testcase.
CODE:
Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
// Shorthand for the 64-bit signed integer type used throughout this solution.
#define ll long long
// Newline string literal used when printing answers.
#define nline "\n"
// Expands to a container's full iterator range, e.g. reverse(all(v)).
#define all(x) x.begin(),x.end()
// NOTE(review): never used as a modulus here; solve() only uses MOD*MOD as an
// "infinity" starting value for the running minimum.
const ll MOD=998244353;
// NOTE(review): not referenced anywhere in this program.
const ll MAX=500500;
// Counts the inversions of a 1-indexed binary string: the number of pairs
// i < j (1 <= i, j <= n) with s[i] == '1' and s[j] == '0'.
// Only s[1..n] is read; s[0] is a padding character (solve() prepends " ").
// Any character other than '0' is counted as a '1'.
// The string is taken by const reference so evaluating each candidate does
// not pay for an extra O(n) copy.
long long inv(const string& s, long long n){
    long long ones_seen = 0, ans = 0;
    for (long long i = 1; i <= n; i++) {
        if (s[i] == '0') {
            ans += ones_seen;  // every earlier '1' forms an inversion with this '0'
        } else {
            ones_seen++;
        }
    }
    return ans;
}
ll sum_n=0;
void solve(){
ll n,k; cin>>n>>k;
string s; cin>>s;
s=" "+s;
sum_n += n;
assert(sum_n <= (ll)2e5);
vector<ll> track[2];
for(ll i=1;i<=n;i++){
track[s[i]-'0'].push_back(i);
}
ll ans=MOD*MOD;
reverse(all(track[0]));
auto do_op=[&](ll l,ll rem,string t){
for(auto it:track[0]){
if(it < l){
break;
}
if(rem==0){
break;
}
t[it]='1';
rem--;
}
ans=min(ans,inv(t,n));
};
do_op(1,k,s);
for(auto it:track[1]){
s[it]='0';
k--;
if(k<=-1){
break;
}
do_op(it,k,s);
}
cout<<ans<<nline;
}
// Reads the number of test cases and solves each one, then reports the
// elapsed CPU time on stderr (which does not affect the judged output).
// The former `cout<<fixed<<setprecision(10);` ran after all output had been
// written and printed no floating-point values, so it was dropped as dead code.
int main()
{
    // The program only uses C++ streams, so unsync from C stdio and untie cin.
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    ll test_cases=1;
    cin>>test_cases;
    while(test_cases--){
        solve();
    }
    cerr<<"Time:"<<1000*(static_cast<double>(clock()))/static_cast<double>(CLOCKS_PER_SEC)<<"ms\n";
}
Tester's code (apoorv_me, C++)
#include<bits/stdc++.h>
using namespace std;
#ifdef LOCAL
#include "../debug.h"
#else
#define dbg(...)
#endif
#ifndef LOCAL
struct input_checker {
    string buffer;  // all of stdin, read once by the constructor
    int pos;        // index of the next unread character in buffer
    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";
    // Reads the whole of standard input into `buffer` so that later reads can
    // check exact whitespace placement.
    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }
    // Index of the first ' ' or '\n' at or after pos (buffer.size() if none).
    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }
    // Consumes and returns the token starting at pos; it is empty when pos
    // already sits on a delimiter.
    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }
    // Reads a token of length in [minl, maxl] whose characters all occur in
    // `pattern` (any character is allowed when pattern is empty).
    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }
    // Reads an int in [minv, maxv]; the entire token must be a number.
    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        string token = readOne();
        size_t used = 0;
        int res = stoi(token, &used);
        // stoi stops at the first non-numeric character, so without this check
        // a token such as "12abc" would be accepted as 12.
        assert(used == token.size());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }
    // Reads a long long in [minv, maxv]; the entire token must be a number.
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        string token = readOne();
        size_t used = 0;
        long long res = stoll(token, &used);
        // Same prefix-parsing caveat as readInt.
        assert(used == token.size());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }
    // Reads n single-space-separated ints in [minv, maxv] (no trailing space).
    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }
    // Reads n single-space-separated long longs in [minv, maxv].
    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }
    // Consumes exactly one ' '.
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }
    // Consumes exactly one '\n'.
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }
    // Asserts that the whole input has been consumed.
    void readEof() {
        assert((int) buffer.size() == pos);
    }
};
#else
struct input_checker {
    string buffer;  // only used by nextDelimiter()/readOne(); the constructor leaves it empty
    int pos;        // index of the next unread character in buffer
    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";
    // LOCAL build: stdin is not read up front; the readers below use cin directly.
    input_checker() {
        // pos used to be left uninitialized here, so nextDelimiter()/readOne()
        // started from an indeterminate index.
        pos = 0;
    }
    // Index of the first ' ' or '\n' at or after pos (buffer.size() if none).
    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }
    // Consumes and returns the buffer token starting at pos.
    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }
    // LOCAL build: reads a whitespace-delimited token from cin without checking
    // its length or characters (the parameters are kept for interface parity).
    string readString(int minl, int maxl, const string &pattern = "") {
        string X; cin >> X;
        return X;
    }
    // Reads an int from cin and asserts it lies in [minv, maxv].
    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res; cin >> res;
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }
    // Reads a long long from cin and asserts it lies in [minv, maxv].
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res; cin >> res;
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }
    // Reads n ints in [minv, maxv].
    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }
    // Reads n long longs in [minv, maxv].
    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }
    // Whitespace is skipped by cin's formatted extraction, so these are no-ops.
    void readSpace() {
    }
    void readEoln() {
    }
    void readEof() {
    }
};
#endif
int32_t main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0);

    int NumTest = 1;
    cin >> NumTest;
    for (int testno = 1; testno <= NumTest; ++testno) {
        int N, K;
        cin >> N >> K;
        string S;
        cin >> S;

        // Inversion count of a copy of S after:
        // - turning its first x '1's into '0';
        // - then turning its last y '0's (scanning from the right) into '1'.
        auto doit = [&](int x, int y, string S) {
            dbg(x, y, S);
            for (int i = 0; i < N && x > 0; ++i) {
                if (S[i] == '1') {
                    S[i] = '0';
                    --x;
                }
            }
            for (int i = N - 1; i >= 0 && y > 0; --i) {
                if (S[i] == '0') {
                    S[i] = '1';
                    --y;
                }
            }
            dbg(S);
            int inv = 0, one = 0;  // one = number of '1's seen so far
            for (int i = 0; i < N; ++i) {
                if (S[i] != '1') {
                    inv += one;
                } else {
                    one += 1;
                }
                dbg(S[i], one, inv);
            }
            return inv;
        };

        // Baseline (no flips), then every split of the K flips between 1s and 0s.
        int res = doit(0, 0, S);
        for (int x = 0; x <= K; ++x) {
            res = min(res, doit(x, K - x, S));
        }
        cout << res << '\n';
    }
    return 0;
}
Tester's code (tabr, C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif
// Counts the inversions of a 0-indexed binary string: the number of pairs
// i < j with s[i] == '1' and s[j] == '0'. Any character other than '0' is
// counted as a '1'.
// Takes a const reference: the string is only read, and const also lets
// callers pass temporaries.
long long inverse(const string &s) {
    long long res = 0, sum = 0;  // sum = number of '1's seen so far
    for (char c : s) {
        if (c == '0') {
            res += sum;
        } else {
            sum++;
        }
    }
    return res;
}
// Prints the minimum number of inversions of s reachable with at most k flips.
// For every split flipOnes + flipZeros = k, it builds the candidate as follows:
// - turn the leftmost flipOnes '1's into '0';
// - then turn the rightmost flipZeros '0's into '1'.
// s itself is not modified.
void solve(int n, int k, string &s) {
    long long best = inverse(s);  // doing nothing is always allowed
    for (int flipOnes = 0; flipOnes <= k; ++flipOnes) {
        string cand = s;

        int left = flipOnes;
        for (int i = 0; i < n && left > 0; ++i) {
            if (cand[i] == '1') {
                cand[i] = '0';
                --left;
            }
        }

        int right = k - flipOnes;
        for (int i = n - 1; i >= 0 && right > 0; --i) {
            if (cand[i] == '0') {
                cand[i] = '1';
                --right;
            }
        }

        best = min(best, inverse(cand));
    }
    cout << best << "\n";
}
////////////////////////////////////////
#define IGNORE_CR
// Strict validator for the judge input: reads all of stdin up front and
// asserts on any deviation (token content, numeric range, exact whitespace,
// trailing data).
struct input_checker {
    string buffer;  // all of stdin ('\r' dropped when IGNORE_CR is defined)
    int pos;        // index of the next unread character in buffer
    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";
    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
#ifdef IGNORE_CR
            // Tolerate Windows line endings by discarding carriage returns.
            if (c == '\r') {
                continue;
            }
#endif
            buffer.push_back((char) c);
        }
    }
    // Consumes characters up to the next ' ' or '\n'; any other whitespace
    // inside a token fails the check.
    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            assert(!isspace(buffer[pos]));
            res += buffer[pos];
            pos++;
        }
        return res;
    }
    // Reads a token of length in [min_len, max_len] whose characters all occur
    // in `pattern` (any character is allowed when pattern is empty).
    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }
    // Reads an int in [min_val, max_val]; the entire token must be a number.
    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        string token = readOne();
        size_t used = 0;
        int res = stoi(token, &used);
        // stoi parses only a numeric prefix; require the whole token so that
        // input such as "12abc" is rejected instead of read as 12.
        assert(used == token.size());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }
    // Reads a long long in [min_val, max_val]; the entire token must be a number.
    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        string token = readOne();
        size_t used = 0;
        long long res = stoll(token, &used);
        // Same prefix-parsing caveat as readInt.
        assert(used == token.size());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }
    // Reads `size` single-space-separated ints in [min_val, max_val].
    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }
    // Reads `size` single-space-separated long longs in [min_val, max_val].
    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }
    // Consumes exactly one ' '.
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }
    // Consumes exactly one '\n'.
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }
    // Asserts that the whole input has been consumed.
    void readEof() {
        assert((int) buffer.size() == pos);
    }
};
// Validates the whole input while solving each test case.
// Global constraints checked:
// - T <= 1e4;
// - sum of N <= 2e5;
// - sum of N^2 <= 5e7.
int main() {
    input_checker in;
    int tt = in.readInt(1, 1e4);
    in.readEoln();
    cerr << tt << endl;
    // 64-bit accumulators: n is read up to 5e5, so n * n alone can exceed
    // INT_MAX (signed overflow is UB) before any bound could be checked.
    long long sn = 0, sn2 = 0;
    while (tt--) {
        int n = in.readInt(1, 5e5);
        in.readSpace();
        int k = in.readInt(0, n);
        in.readEoln();
        auto s = in.readString(n, n, "01");
        in.readEoln();
        sn += n;
        sn2 += 1LL * n * n;
        // Enforce the sum constraints before running the O(n * k) solver, so
        // oversized input fails here instead of after a long computation.
        assert(sn <= 2e5);
        assert(sn2 <= 5e7);
        solve(n, k, s);
    }
    cerr << sn << " " << sn2 << endl;
    in.readEof();
    return 0;
}
Editorialist's code (Python)
def min_inversions(n, k, s):
    """Return the fewest inversions of s reachable with at most k flips.

    For every i in 0..k the candidate is built as follows:
    - flip the leftmost i '1's of s to '0';
    - then flip up to k - i of the rightmost '0's that lie after the last
      flipped '1'.
    The best inversion count over all candidates is returned.
    """
    best = n * (n - 1) // 2  # upper bound: no length-n string has more inversions
    for i in range(k + 1):
        cur = list(s)

        ones_left, last = i, -1
        for j in range(n):
            if cur[j] == '1' and ones_left > 0:
                cur[j] = '0'
                ones_left -= 1
                last = j

        zeros_left = k - i
        for j in range(n - 1, last, -1):
            if cur[j] == '0' and zeros_left > 0:
                cur[j] = '1'
                zeros_left -= 1

        invs = ones = 0
        for j in range(n):
            if cur[j] == '0':
                invs += ones
            else:
                ones += 1
        best = min(best, invs)
    return best


def main():
    """Read every test case from stdin and print its answer."""
    for _ in range(int(input())):
        n, k = map(int, input().split())
        s = input()
        print(min_inversions(n, k, s))


main()