FLIPINV-Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Contest Division 4

Setter: Jeevan Jyot Singh
Tester: Aryan, Satyam
Editorialist: Devendra Singh

DIFFICULTY:

2692

PREREQUISITES:

Greedy algorithms, prefix sums

PROBLEM:

JJ has a binary string S of length N. JJ can perform the following operation on S:

  • Select an i such that 1 \le i \le N, and flip S_i (i.e. change 0 to 1 and 1 to 0)

JJ wants to minimize the number of inversions in S by performing the above operation at most K times. Can you help JJ do so?

Recall that a pair of indices (i, j) in S is called an inversion if i \lt j and S_i \gt S_j.

EXPLANATION:

In a binary string S, all inversions are of the form (i,j) where i \lt j and S_i=1 and S_j=0.

For a flip of type 1 to 0, we should choose the smallest index i such that S_i=1. Flipping it removes one inversion for every 0 after i and creates none, because no 1 precedes it. Any other index j with S_j=1 has at most as many zeroes after it, and flipping it additionally creates an inversion with every 1 before j. Hence flipping the leftmost 1 reduces the number of inversions by at least as much as flipping any other 1.
Similarly, for a flip of type 0 to 1, we should choose the greatest index i such that S_i=0: it has the most ones before it, and flipping it creates no new inversions, because no 0 follows it.

If the count of zeroes in S is at most k, or the count of ones is at most k, the answer is 0. In that case we can flip all zeroes to ones (respectively, all ones to zeroes) and obtain a string with no inversions.
Otherwise, let x be the number of flips of type 1 to 0 (0 \le x \le k). Then at most k-x flips of type 0 to 1 remain. We may as well use all of them, since flipping the current leftmost 1 or rightmost 0 never increases the number of inversions. Let left be the index of the x^{th} 1 from the beginning (left = 0 if x = 0), and let right be the index of the (k-x)^{th} 0 from the end (right = N+1 if x = k).
After flipping these indices, the string looks like 00\ldots0 \, S' \, 11\ldots1, where S' = S[left+1, right-1] is an untouched substring of S[1, N]. Positions 1 to left are all zeroes, and positions right to N are all ones. The leading zeroes and trailing ones take part in no inversion, so the number of inversions in this string equals the number of inversions in S'.

Thus, if left \ge right, the answer is 0. Otherwise, take the minimum number of inversions over all strings formed by flipping the first x ones and the last k-x zeroes, for x = 0, 1, \ldots, k.
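To compute the inversions of each candidate S' in O(1), precompute prefix counts of zeroes, ones and inversions. P_0[i] and P_1[i] are the numbers of zeroes and ones in S[1, i], and P_{inv}[i] is the number of inversions in S[1, i] (all three are 0 for i = 0). The inversions inside S[l, r] are the inversions of S[1, r], minus those lying entirely in S[1, l-1], minus the pairs formed by a 1 in S[1, l-1] and a 0 in S[l, r]. With l = left+1 and r = right-1, this gives:
inv(S') = P_{inv}[right-1] - P_{inv}[left] - P_1[left] \cdot (P_0[right-1] - P_0[left])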

Then output the minimum of these values.

TIME COMPLEXITY:

O(N) for each test case.

SOLUTION:

Setter's Solution
#ifdef WTSH
    #include <wtsh.h>
#else
    #include <bits/stdc++.h>
    using namespace std;
    #define dbg(...)
#endif

#define int long long
#define endl "\n"
#define sz(w) (int)(w.size())
using pii = pair<int, int>;

const long long INF = 1e18;

const int N = 1e6 + 5; 

void solve()
{
    int n, k; cin >> n >> k;
    string s; cin >> s;
    vector<int> ones, zeros, prefones(n), prefzeros(n), prefinv(n);
    for(int i = 0; i < n; i++)
    {
        if(s[i] == '0')
        {
            zeros.push_back(i), prefzeros[i]++;
            prefinv[i] += sz(ones);
        }
        else
        {
            ones.push_back(i), prefones[i]++;
        }
        if(i > 0)
        {
            prefones[i] += prefones[i - 1];
            prefzeros[i] += prefzeros[i - 1];
            prefinv[i] += prefinv[i - 1];
        }
    }
    reverse(zeros.begin(), zeros.end());
    if(sz(zeros) <= k or sz(ones) <= k)
    {
        cout << 0 << endl;
        return;
    }

    auto inversions = [&](int l, int r)
    {
        if(r <= l)
            return 0LL;
        int res = prefinv[r];
        if(l > 0) res -= prefinv[l - 1];
        int cnt0 = prefzeros[r], cnt1 = 0;
        if(l > 0)
            cnt1 += prefones[l - 1], cnt0 -= prefzeros[l - 1];
        res -= cnt1 * cnt0;
        return res;
    };

    int ans = prefinv.back();
    for(int take0 = 0; take0 <= k; take0++)
    {
        int take1 = k - take0;
        int L = -1, R = n;
        if(take1 != 0)
            L = ones[take1 - 1];
        if(take0 != 0)
            R = zeros[take0 - 1];
        ans = min(ans, inversions(L + 1, R - 1));
    }
    cout << ans << endl;
}

int32_t main()
{
    ios::sync_with_stdio(0); 
    cin.tie(0);
    int T; cin >> T;
    for(int tc = 1; tc <= T; tc++)
    {
        // cout << "Case #" << tc << ": ";
        solve();
    }
    return 0;
}
Tester-1's Solution
/* in the name of Anton */

/*
  Compete against Yourself.
  Author - Aryan (@aryanc403)
  Atcoder library - https://atcoder.github.io/ac-library/production/document_en/
*/

#include <algorithm>
#include <cassert>
#include <vector>


#ifdef _MSC_VER
#include <intrin.h>
#endif

namespace atcoder {

namespace internal {

int ceil_pow2(int n) {
    int x = 0;
    while ((1U << x) < (unsigned int)(n)) x++;
    return x;
}

int bsf(unsigned int n) {
#ifdef _MSC_VER
    unsigned long index;
    _BitScanForward(&index, n);
    return index;
#else
    return __builtin_ctz(n);
#endif
}

}  // namespace internal

}  // namespace atcoder


namespace atcoder {

template <class S, S (*op)(S, S), S (*e)()> struct segtree {
  public:
    segtree() : segtree(0) {}
    explicit segtree(int n) : segtree(std::vector<S>(n, e())) {}
    explicit segtree(const std::vector<S>& v) : _n(int(v.size())) {
        log = internal::ceil_pow2(_n);
        size = 1 << log;
        d = std::vector<S>(2 * size, e());
        for (int i = 0; i < _n; i++) d[size + i] = v[i];
        for (int i = size - 1; i >= 1; i--) {
            update(i);
        }
    }

    void set(int p, S x) {
        assert(0 <= p && p < _n);
        p += size;
        d[p] = x;
        for (int i = 1; i <= log; i++) update(p >> i);
    }

    S get(int p) const {
        assert(0 <= p && p < _n);
        return d[p + size];
    }

    S prod(int l, int r) const {
        assert(0 <= l && l <= r && r <= _n);
        S sml = e(), smr = e();
        l += size;
        r += size;

        while (l < r) {
            if (l & 1) sml = op(sml, d[l++]);
            if (r & 1) smr = op(d[--r], smr);
            l >>= 1;
            r >>= 1;
        }
        return op(sml, smr);
    }

    S all_prod() const { return d[1]; }

    template <bool (*f)(S)> int max_right(int l) const {
        return max_right(l, [](S x) { return f(x); });
    }
    template <class F> int max_right(int l, F f) const {
        assert(0 <= l && l <= _n);
        assert(f(e()));
        if (l == _n) return _n;
        l += size;
        S sm = e();
        do {
            while (l % 2 == 0) l >>= 1;
            if (!f(op(sm, d[l]))) {
                while (l < size) {
                    l = (2 * l);
                    if (f(op(sm, d[l]))) {
                        sm = op(sm, d[l]);
                        l++;
                    }
                }
                return l - size;
            }
            sm = op(sm, d[l]);
            l++;
        } while ((l & -l) != l);
        return _n;
    }

    template <bool (*f)(S)> int min_left(int r) const {
        return min_left(r, [](S x) { return f(x); });
    }
    template <class F> int min_left(int r, F f) const {
        assert(0 <= r && r <= _n);
        assert(f(e()));
        if (r == 0) return 0;
        r += size;
        S sm = e();
        do {
            r--;
            while (r > 1 && (r % 2)) r >>= 1;
            if (!f(op(d[r], sm))) {
                while (r < size) {
                    r = (2 * r + 1);
                    if (f(op(d[r], sm))) {
                        sm = op(d[r], sm);
                        r--;
                    }
                }
                return r + 1 - size;
            }
            sm = op(d[r], sm);
        } while ((r & -r) != r);
        return 0;
    }

  private:
    int _n, size, log;
    std::vector<S> d;

    void update(int k) { d[k] = op(d[2 * k], d[2 * k + 1]); }
};

}  // namespace atcoder


#ifdef ARYANC403
    #include <header.h>
#else
    #pragma GCC optimize ("Ofast")
    #pragma GCC target ("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx")
    //#pragma GCC optimize ("-ffloat-store")
    #include<bits/stdc++.h>
    #define dbg(args...) 42;
#endif

// y_combinator from @neal template https://codeforces.com/contest/1553/submission/123849801
// http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2016/p0200r0.html
template<class Fun> class y_combinator_result {
    Fun fun_;
public:
    template<class T> explicit y_combinator_result(T &&fun): fun_(std::forward<T>(fun)) {}
    template<class ...Args> decltype(auto) operator()(Args &&...args) { return fun_(std::ref(*this), std::forward<Args>(args)...); }
};
template<class Fun> decltype(auto) y_combinator(Fun &&fun) { return y_combinator_result<std::decay_t<Fun>>(std::forward<Fun>(fun)); }

using namespace std;
#define fo(i,n)   for(i=0;i<(n);++i)
#define repA(i,j,n)   for(i=(j);i<=(n);++i)
#define repD(i,j,n)   for(i=(j);i>=(n);--i)
#define all(x) begin(x), end(x)
#define sz(x) ((lli)(x).size())
#define pb push_back
#define mp make_pair
#define X first
#define Y second
#define endl "\n"

typedef long long int lli;
typedef long double mytype;
typedef pair<lli,lli> ii;
typedef vector<ii> vii;
typedef vector<lli> vi;

const auto start_time = std::chrono::high_resolution_clock::now();
void aryanc403()
{
#ifdef ARYANC403
auto end_time = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = end_time-start_time;
    cerr<<"Time Taken : "<<diff.count()<<"\n";
#endif
}

long long readInt(long long l, long long r, char endd) {
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true) {
        char g=getchar();
        if(g=='-') {
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g&&g<='9') {
            x*=10;
            x+=g-'0';
            if(cnt==0) {
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);

            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd) {
            if(is_neg) {
                x=-x;
            }
            assert(l<=x&&x<=r);
            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l, int r, char endd) {
    string ret="";
    int cnt=0;
    while(true) {
        char g=getchar();
        assert(g!=-1);
        if(g==endd) {
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt&&cnt<=r);
    return ret;
}
long long readIntSp(long long l, long long r) {
    return readInt(l,r,' ');
}
long long readIntLn(long long l, long long r) {
    return readInt(l,r,'\n');
}
string readStringLn(int l, int r) {
    return readString(l,r,'\n');
}
string readStringSp(int l, int r) {
    return readString(l,r,' ');
}

void readEOF(){
    assert(getchar()==EOF);
}

vi readVectorInt(int n,lli l,lli r){
    vi a(n);
    for(int i=0;i<n-1;++i)
        a[i]=readIntSp(l,r);
    a[n-1]=readIntLn(l,r);
    return a;
}

bool isBinaryString(const string s){
    for(auto x:s){
        if('0'<=x&&x<='1')
            continue;
        return false;
    }
    return true;
}

// #include<atcoder/dsu>
// vector<vi> readTree(const int n){
//     vector<vi> e(n);
//     atcoder::dsu d(n);
//     for(lli i=1;i<n;++i){
//         const lli u=readIntSp(1,n)-1;
//         const lli v=readIntLn(1,n)-1;
//         e[u].pb(v);
//         e[v].pb(u);
//         d.merge(u,v);
//     }
//     assert(d.size(0)==n);
//     return e;
// }

const lli INF = 0xFFFFFFFFFFFFFFFL;

lli seed;
mt19937 rng(seed=chrono::steady_clock::now().time_since_epoch().count());
inline lli rnd(lli l=0,lli r=INF)
{return uniform_int_distribution<lli>(l,r)(rng);}

class CMP
{public:
bool operator()(ii a , ii b) //For min priority_queue .
{    return ! ( a.X < b.X || ( a.X==b.X && a.Y <= b.Y ));   }};

void add( map<lli,lli> &m, lli x,lli cnt=1)
{
    auto jt=m.find(x);
    if(jt==m.end())         m.insert({x,cnt});
    else                    jt->Y+=cnt;
}

void del( map<lli,lli> &m, lli x,lli cnt=1)
{
    auto jt=m.find(x);
    if(jt->Y<=cnt)            m.erase(jt);
    else                      jt->Y-=cnt;
}

bool cmp(const ii &a,const ii &b)
{
    return a.X<b.X||(a.X==b.X&&a.Y<b.Y);
}

const lli mod = 1000000007L;
// const lli maxN = 1000000007L;

struct S {
    // # of 0 / # of 1 / inversion number
    long long zero, one, inversion;
};

S op(S l, S r) {
    return S{
        l.zero + r.zero,
        l.one + r.one,
        l.inversion + r.inversion + l.one * r.zero,
    };
}

S e() { return S{0, 0, 0}; }

    lli T,n,i,j,k,in,cnt,l,r,u,v,x,y;
    lli m;
    string s;
    vi a;
    //priority_queue < ii , vector < ii > , CMP > pq;// min priority_queue .

int main(void) {
    ios_base::sync_with_stdio(false);cin.tie(NULL);
    // freopen("txt.in", "r", stdin);
    // freopen("txt.out", "w", stdout);
// cout<<std::fixed<<std::setprecision(35);
T=readIntLn(1,1e5);
lli sumN = 2e5;
while(T--)
{

    n=readIntSp(1,min(100000LL,sumN));
    sumN-=n;
    k=readIntLn(0,n);
    auto s = readStringLn(n,n);
    assert(isBinaryString(s));
    atcoder::segtree<S, op, e> seg(vector<S>(n,S{1,0,0}));

    vi one,zero;
    for(int i=0;i<n;i++){
        if(s[i]=='1'){
            one.pb(i);
            seg.set(i,S{0,1,0});
        }
        else
            zero.pb(i);
    }

    dbg(one,zero);
    if(sz(one)<=k||sz(zero)<=k){
        cout<<0<<endl;
        continue;
    }

    one.resize(k);
    reverse(all(zero));
    zero.resize(k);
    lli ans=seg.all_prod().inversion;
    reverse(all(zero));
    dbg(one,zero);
    for(lli i=0;i<=k;i++){
        // change first i ones to zero.
        const lli l=(i==0?0:one[i-1]+1);
        // change last k-i zeros to one
        const lli r=(i==k?n-1:zero[i]-1);
        dbg(i,l,r);
        if(l>=r)
            ans=0;
        else
            ans=min(ans,seg.prod(l,r+1).inversion);
    }
    cout<<ans<<endl;
}   aryanc403();
    readEOF();
    return 0;
}
Tester-2's Solution
#include <bits/stdc++.h>
using namespace std;
#ifndef ONLINE_JUDGE    
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);  
#endif 
#define ll long long
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());  
 
 
/*
------------------------Input Checker----------------------------------
*/
 
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;  
    while(true){
        char g=getchar();
        if(g=='-'){  
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
 
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }
 
            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }
 
            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
 
 
/*
------------------------Main code starts here----------------------------------
*/
ll MAX=100000;
ll tes_sum=0;
vector<string> YS={"YES","yes","yES","YeS","yEs","YEs","Yes","yeS"};
vector<string> NO={"NO","no","No","nO"}; 
void solve(){  
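    ll n=readIntSp(1,100000);
    tes_sum+=n;  // accumulate N so that the final assert(tes_sum<=200000) actually checks the constraint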
    ll k=readIntLn(0,n);
    string s=readStringLn(n,n);
    for(auto it:s){
        assert((it=='0')||(it=='1')); 
    }
    vector<ll> freq(n+5,0);
    ll r=n;
    for(ll i=n-1;i>=0;i--){
        freq[i]=freq[i+1]+(s[i]=='0'); 
        if(freq[i]<=k){
            r=i;
        }
    }
    ll one=0,zero=0,now=0;
    for(ll i=0;i<r;i++){
        if(s[i]=='0'){
            now+=one; zero++;
        }  
        else{
            one++;
        }
    }
    ll ans=now;
    for(ll i=0;i<n;i++){
        if(s[i]=='1'){
            now-=zero; 
            k--,one--,zero++;
        }
        if(k<0){
            break; 
        }
        while(freq[r]>k){
            if(s[r]=='0'){
                now+=one; zero++; 
            }
            else{
                one++;
            }
            r++;
        }
        zero--; ans=min(ans,now);
    }
    cout<<ans<<"\n";
    return;
}
int main(){
    ios_base::sync_with_stdio(false);                         
    cin.tie(NULL);                              
    #ifndef ONLINE_JUDGE                 
    freopen("input.txt", "r", stdin);                                              
    freopen("output.txt", "w", stdout);  
    freopen("error.txt", "w", stderr);                          
    #endif         
    ll test_cases=readIntLn(1,MAX); 
    while(test_cases--){
        solve();
    }
    assert(getchar()==-1);
    assert(tes_sum<=200000);
    return 0;
}

Editorialist's Solution
#include "bits/stdc++.h"
using namespace std;
#define ll long long
#define pb push_back
#define all(_obj) _obj.begin(),_obj.end()
#define F first
#define S second
#define pll pair<ll, ll> 
#define vll vector<ll>
#define INF 1e18
const int N=1e5+11,mod=1e9+7;
ll max(ll a,ll b) {return ((a>b)?a:b);}
ll min(ll a,ll b) {return ((a>b)?b:a);}
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
void sol(void)
{
    int n,k;
    cin>>n>>k;
    string s;
    cin>>s;
    // Prefix counts (0-indexed, inclusive): zeroes, ones and inversions in s[0..i].
    vll pzero(n+1,0),pones(n+1,0),pinv(n+1,0);
    ll ans=INF;
    vll zero,ones;
    for(int i=0;i<n;i++)
    {
        if(s[i]=='1')
            ones.pb(i),pones[i]++;
        else
            zero.pb(i),pzero[i]++,pinv[i]+=ones.size();
        if(i)
        {
            pones[i]+=pones[i-1];
            pzero[i]+=pzero[i-1];
            pinv[i]+=pinv[i-1];
        }
    }
    // Enough flips to turn every 1 into 0 (or every 0 into 1): no inversions remain.
    if((int)ones.size()<=k || (int)zero.size()<=k)
    {
        cout<<0<<'\n';
        return ;
    }
    for(ll makeOneToZero=0;makeOneToZero<=k;makeOneToZero++)
    {
        // left: last flipped 1 (-1 if none); right: first flipped 0 (n if none).
        int left=-1,right=n;
        if(makeOneToZero)
            left=ones[makeOneToZero-1];
        if(makeOneToZero!=k)
            right=zero[zero.size()-(k-makeOneToZero)];
        if(left>=right)
        {ans=0;break;}
        // Inversions inside s[left+1..right-1].
        ll calc = pinv[right-1];
        if(left>=0) calc -= pinv[left];
        ll cnt0 = pzero[right-1], cnt1 = 0;
        if(left >= 0)
            cnt1 += pones[left], cnt0 -= pzero[left];
        calc -= cnt1 * cnt0;
        ans=min(ans,calc);
    }
    cout<<ans<<'\n';
    return ;
}
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(NULL),cout.tie(NULL);
    int test=1;
    cin>>test;
    while(test--) sol();
}

Why is this solution not valid? Which test case is failing?
In one copy of the input string I'm changing the K leftmost '1's into '0's, and in the other copy I'm changing the K rightmost '0's into '1's. Then I calculate the number of inversions in each and choose the minimum of the two.
Thanks in advance.

using System;
using System.Text;

class Program
{
    // Counts pairs (i, j), i < j, with sb[i] == '1' and sb[j] == '0'.
    // long is needed: with N up to 10^5 the count can reach about 2.5 * 10^9.
    static long CountInversions(StringBuilder sb)
    {
        long oneCnt = 0, invCnt = 0;
        for (int i = 0; i < sb.Length; i++)
        {
            if (sb[i] == '1')
                oneCnt++;
            else
                invCnt += oneCnt;
        }
        return invCnt;
    }

    static void Main(string[] args)
    {
        int T = int.Parse(Console.ReadLine());

        for (int t = 0; t < T; t++)
        {
            int[] nk = Array.ConvertAll(Console.ReadLine().Split(), int.Parse);
            int N = nk[0];
            int K1 = nk[1];
            int K2 = K1;
            StringBuilder sb1 = new StringBuilder(Console.ReadLine());
            StringBuilder sb2 = new StringBuilder(sb1.ToString());

            // Copy 1: flip the K leftmost '1's to '0'.
            int ind = 0;
            while (K1 > 0 && ind < N)
            {
                if (sb1[ind] == '1')
                {
                    sb1[ind] = '0';
                    K1--;
                }
                ind++;
            }

            // Copy 2: flip the K rightmost '0's to '1'.
            ind = N - 1;
            while (K2 > 0 && ind >= 0)
            {
                if (sb2[ind] == '0')
                {
                    sb2[ind] = '1';
                    K2--;
                }
                ind--;
            }
            Console.WriteLine(Math.Min(CountInversions(sb1), CountInversions(sb2)));
        }
    }
}

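For example, take K=2 and S=1000011110. The correct string after modification is 0000011111 (flip the first 1 and the last 0), which has zero inversions, while both of your strings have 3 inversions.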


How do you guys come up with test cases like that?

Can someone provide a test case where my solution (submission 63564428) gives the wrong output?
Thanks in advance!


I used the same logic as @babyabh, and only 2 test cases (the first and the last) were passing. I am still clueless as to why the editorial and the other answers do not greedily choose the not-yet-flipped rightmost 0 or leftmost 1, whichever reduces the inversions the most. I could not find any test case that gives a wrong answer.
Can we see the test case inputs that are used?

Edit:

I found why my solution does not work.
The test case is:

s = 0 0 0 1 1 0 0 1 0
k = 2

I was greedy, picking the leftmost 1 or the rightmost 0 depending on which one reduced the inversions more. This idea can fail when both flips are equally good (both reduce the inversions by the same amount).

In this test case, the leftmost 1 and the rightmost 0 are equally good, so we could choose either of them and proceed. That can yield a wrong answer, because one of the choices leads to fewer inversions at the end and the other does not.

Wrong way:

0 0 0 1 1 0 0 1 0 ->  0 0 0 1 1 0 0 1 1 ->  0 0 0 0 1 0 0 1 1  # again, both were equally good

# final inversions: 2

Right way:

0 0 0 1 1 0 0 1 0 ->  0 0 0 0 1 0 0 1 0  ->  0 0 0 0 0 0 0 1 0

# final inversion: 1

I tried to implement the logic given in the editorial, but my submission is getting WA.
Can someone please explain what's going wrong?
Thanks in advance

Hey @mwolf,
Your logic is incorrect because you only consider one side at a time (the k leftmost 1s, or the k rightmost 0s). Sometimes you need to split K into x and K-x: flip at most x of the leftmost 1s and at most K-x of the rightmost 0s.
Here is an example:
1
8 2
10011110
We convert the first 1 to 0 and the last 0 to 1, and the output is 0.
