BININV - Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Practice

Setter: Soumyadeep Pal
Tester: Aryan Choudhary
Editorialist: Taranpreet Singh

DIFFICULTY

Easy

PREREQUISITES

Observations.

PROBLEM

Given N binary strings, each of length M, concatenate all N strings in some order into a single string T of length N*M so that the number of inversions in T is minimized. An inversion is a pair of positions i \lt j with T_i = 1 and T_j = 0.

QUICK EXPLANATION

  • Since all strings have the same length M, the optimal order is obtained by sorting the strings in nondecreasing order of the number of ones they contain.
  • We can then build the concatenated string and count its inversions in a single left-to-right pass.

EXPLANATION

Solving for N = 2

Let’s say we have two strings A and B, which we need to concatenate while minimizing the number of inversions. We can try both AB and BA and pick the one with fewer inversions.

Let C_{S,c} denote the number of occurrences of character c in S, and let f(S) denote the number of inversions in S.

Every 1 in the first string forms an inversion with every 0 in the second string, so f(A+B) = f(A)+f(B) + C_{A, 1}*C_{B, 0} and f(B+A) = f(B)+f(A) + C_{B, 1}*C_{A, 0}.
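A minimal Python sketch of the N = 2 case, checking the formula against a direct count (the helper names inversions and concat_cost are hypothetical, not taken from the setter's or tester's code). It uses the pair 100 and 011, which also appears later in the discussion.

```python
def inversions(s: str) -> int:
    """f(S): number of pairs x < y with s[x] == '1' and s[y] == '0'."""
    ones = total = 0
    for ch in s:
        if ch == '1':
            ones += 1
        else:
            total += ones
    return total


def concat_cost(a: str, b: str) -> int:
    """f(A+B) computed as f(A) + f(B) + C_{A,1} * C_{B,0}."""
    return inversions(a) + inversions(b) + a.count('1') * b.count('0')


a, b = "100", "011"
print(concat_cost(a, b), concat_cost(b, a))  # 3 6: placing A first is better
```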

Types of inversions

Call a pair (x, y) an inversion pair if x \lt y, a 1 appears at position x, and a 0 appears at position y.

We can divide inversions into two categories:

  • Inversions within the same string
    Both x and y lie in the same string. Wherever this string is placed, the pair remains an inversion, so the order of strings cannot change the number of inversions within a string.
  • Inversions across strings
    For N = 2, concatenating A and B as AB creates pairs where a 1 lies in A and a 0 lies in B; each such pair is an inversion in the concatenated string. The number of these inversions depends on the order in which we place the strings.

Our aim is to minimize the inversions of the second type, since the number of inversions of the first type does not depend on the order of strings.
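The split into the two categories can be checked numerically. Below is a minimal Python sketch (inversions and split_inversions are hypothetical helper names): the within-string part is the sum of f(S_i), and the across-string part adds, for each string, the number of 1s in all earlier strings times its own number of 0s. Only the second part changes when the order changes.

```python
def inversions(s: str) -> int:
    ones = total = 0
    for ch in s:
        if ch == '1':
            ones += 1
        else:
            total += ones
    return total


def split_inversions(strings):
    """Return (within, across) for the concatenation of strings in this order."""
    within = sum(inversions(s) for s in strings)
    across = 0
    ones_before = 0  # number of 1s in all earlier strings
    for s in strings:
        across += ones_before * s.count('0')
        ones_before += s.count('1')
    return within, across


print(split_inversions(["100", "011", "010"]))  # (3, 7)
print(split_inversions(["011", "010", "100"]))  # (3, 10): same within, different across
```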

Deciding the order of a pair of strings.

Suppose we have already chosen an order of the strings and are only allowed to swap adjacent strings. For some i, we need to decide whether swapping S_i and S_{i+1} is beneficial.

Observation: Only the number of inversions between strings S_i and S_{i+1} is affected by this swap.
Proof: Consider a string S_j with j \lt i. Both S_i and S_{i+1} lie to the right of S_j before and after the swap, so no 0 or 1 moves from the right of S_j to its left, and the inversions involving S_j are unaffected.

Similarly, for a string S_j with j \gt i+1, S_j lies to the right of both S_i and S_{i+1} before and after the swap, so the inversions involving S_j are also unaffected.

Hence, we can decide which of S_i and S_{i+1} should come first solely by comparing the number of inversions in S_i + S_{i+1} and in S_{i+1} + S_i.
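A small Python sketch of this observation, assuming the current order is given as a list of strings (swap_delta is a hypothetical helper): the change in total inversions after swapping positions i and i+1 is computed from the counts of those two strings alone, and it matches the difference of direct counts.

```python
def inversions(s: str) -> int:
    ones = total = 0
    for ch in s:
        if ch == '1':
            ones += 1
        else:
            total += ones
    return total


def swap_delta(order, i):
    """Change in total inversions if order[i] and order[i+1] are swapped.
    Only the cross term between these two strings changes."""
    a, b = order[i], order[i + 1]
    before = a.count('1') * b.count('0')  # cross inversions of A+B
    after = b.count('1') * a.count('0')   # cross inversions of B+A
    return after - before


order = ["100", "011", "010"]
print(swap_delta(order, 0), swap_delta(order, 1))  # 3 -3: only the second swap helps
```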

Observation: In an optimal order, no adjacent swap is beneficial, since a beneficial swap would reduce the number of inversions of an order that is already optimal.

Choosing the order of strings

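Now, for every adjacent pair, we know whether a swap would be beneficial. So we can simulate bubble sort, which keeps swapping adjacent elements until the array is sorted; each beneficial swap strictly decreases the number of inversions, so the process terminates. The comparator accepts two strings A and B and compares f(A) + f(B) + C_{A, 1}*C_{B, 0} with f(A) + f(B) + C_{B, 1}*C_{A, 0} to decide which string should appear first. The f terms cancel, and since every string has length M, C_{A, 0} = M - C_{A, 1}, so A should come before B exactly when C_{A, 1}*(M - C_{B, 1}) \lt C_{B, 1}*(M - C_{A, 1}), that is, when C_{A, 1} \lt C_{B, 1}. This is why sorting by the number of ones works.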
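The bubble-sort argument can be simulated directly. This is only an illustrative Python sketch (bubble_sort_by_swaps is a hypothetical name); it is quadratic in N and meant for checking the reasoning on small inputs, not for submission.

```python
def inversions(s: str) -> int:
    ones = total = 0
    for ch in s:
        if ch == '1':
            ones += 1
        else:
            total += ones
    return total


def bubble_sort_by_swaps(strings):
    """Swap adjacent strings while some swap strictly reduces inversions."""
    order = list(strings)
    changed = True
    while changed:
        changed = False
        for i in range(len(order) - 1):
            a, b = order[i], order[i + 1]
            # Swap only if B+A has strictly fewer cross inversions than A+B.
            if b.count('1') * a.count('0') < a.count('1') * b.count('0'):
                order[i], order[i + 1] = b, a
                changed = True
    return order


result = bubble_sort_by_swaps(["100", "011", "010"])
print(result, inversions("".join(result)))  # ['100', '010', '011'] 7
```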

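Bubble sort needs O(N^2) comparisons, so we instead use an O(N log N) comparison sort such as merge sort (or the language's built-in sort) with the same comparator. The comparator must return false when both orders give the same number of inversions; for example, std::sort in C++ requires a strict weak ordering, and a comparator that uses <= can make it fail at runtime.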
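For comparison, a Python sketch of the efficient version (function names are hypothetical). functools.cmp_to_key turns the exchange comparator into a sort key, and the comparator returns 0 on ties; the second variant sorts by the count of ones directly, which relies on all strings having the same length M.

```python
from functools import cmp_to_key


def inversions(t: str) -> int:
    ones = total = 0
    for ch in t:
        if ch == '1':
            ones += 1
        else:
            total += ones
    return total


def exchange_cmp(a: str, b: str) -> int:
    """Negative if A should precede B, positive if B should precede A, 0 on a tie."""
    ab = a.count('1') * b.count('0')  # cross inversions of A+B
    ba = b.count('1') * a.count('0')  # cross inversions of B+A
    return (ab > ba) - (ab < ba)


def min_inversions_cmp(strings):
    # Recounts characters on every comparison: O(N*M*log N) overall.
    return inversions("".join(sorted(strings, key=cmp_to_key(exchange_cmp))))


def min_inversions_key(strings):
    # key= is evaluated once per string, so counting costs O(N*M) in total.
    return inversions("".join(sorted(strings, key=lambda s: s.count('1'))))


print(min_inversions_cmp(["100", "011", "010"]), min_inversions_key(["100", "011", "010"]))  # 7 7
```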

Computing the number of inversions

Now that we have built the concatenated string, we need to count its inversions. Counting inversions in a general array is a well-known problem, but for a binary string it can be done in a single linear pass.

We iterate over the string from left to right, maintaining two variables: ans, the number of inversions found so far, and cnt_1, the number of 1s seen so far.

  • If the current character is 0, this position forms an inversion pair with every 1 before it, so ans = ans + cnt_1.
  • If the current character is 1, increment cnt_1.
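The linear pass above, as a short Python sketch (count_inversions is a hypothetical name). Python integers do not overflow; in C++ or Java, ans must be a 64-bit type such as long long, as in the setter's solution.

```python
def count_inversions(t: str) -> int:
    ans = 0    # inversions found so far
    cnt1 = 0   # number of 1s seen so far
    for ch in t:
        if ch == '0':
            ans += cnt1  # this 0 pairs with every earlier 1
        else:
            cnt1 += 1
    return ans


print(count_inversions("100011"))  # 3
```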

TIME COMPLEXITY

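The time complexity is O(N*M + N*log(N)) per test case when the number of ones in each string is computed once before sorting, as in the setter's and tester's solutions. Recomputing the counts inside the comparator, as the editorialist's solution does, gives O(N*M*log(N)).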

SOLUTIONS

Setter's Solution
#include<bits/stdc++.h>
using namespace std;
void solve() {
  int n, m; 
  cin >> n >> m;
  vector<string> s(n);
  vector<pair<int, int>> v;
  for (int i = 0; i < n; i++) {
    cin >> s[i];
    int ones = count(s[i].begin(), s[i].end(), '1');
    v.push_back({ones, i});
  }
  sort(v.begin(), v.end());
  string cur;
  for (int i = 0; i < n; i++) {
    for (auto u : s[v[i].second]) {
      cur.push_back(u);
    }
  }
  int ones = 0;
  long long ans = 0;
  for (int i = 0; i < n * m; i++) {
    if (cur[i] == '1') ones++;
    else ans += ones;
  }
  cout << ans << '\n';
}

signed main() {
  int t = 1;
  cin >> t;
  for (int i = 1; i <= t; i++) solve();
  return 0;
}
Tester's Solution
/* in the name of Anton */

/*
  Compete against Yourself.
  Author - Aryan (@aryanc403)
  Atcoder library - https://atcoder.github.io/ac-library/production/document_en/
*/

#ifdef ARYANC403
    #include <header.h>
#else
    #pragma GCC optimize ("Ofast")
    #pragma GCC target ("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx")
    //#pragma GCC optimize ("-ffloat-store")
    #include<bits/stdc++.h>
    #define dbg(args...) 42;
#endif

using namespace std;
#define fo(i,n)   for(i=0;i<(n);++i)
#define repA(i,j,n)   for(i=(j);i<=(n);++i)
#define repD(i,j,n)   for(i=(j);i>=(n);--i)
#define all(x) begin(x), end(x)
#define sz(x) ((lli)(x).size())
#define pb push_back
#define mp make_pair
#define X first
#define Y second
#define endl "\n"

typedef long long int lli;
typedef long double mytype;
typedef pair<lli,lli> ii;
typedef vector<ii> vii;
typedef vector<lli> vi;

const auto start_time = std::chrono::high_resolution_clock::now();
void aryanc403()
{
#ifdef ARYANC403
auto end_time = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = end_time-start_time;
    cerr<<"Time Taken : "<<diff.count()<<"\n";
#endif
}

long long readInt(long long l, long long r, char endd) {
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true) {
        char g=getchar();
        if(g=='-') {
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g&&g<='9') {
            x*=10;
            x+=g-'0';
            if(cnt==0) {
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);

            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd) {
            if(is_neg) {
                x=-x;
            }
            assert(l<=x&&x<=r);
            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l, int r, char endd) {
    string ret="";
    int cnt=0;
    while(true) {
        char g=getchar();
        assert(g!=-1);
        if(g==endd) {
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt&&cnt<=r);
    return ret;
}
long long readIntSp(long long l, long long r) {
    return readInt(l,r,' ');
}
long long readIntLn(long long l, long long r) {
    return readInt(l,r,'\n');
}
string readStringLn(int l, int r) {
    return readString(l,r,'\n');
}
string readStringSp(int l, int r) {
    return readString(l,r,' ');
}

void readEOF(){
    assert(getchar()==EOF);
}

void assertBinaryString(const string s){
    for(auto x:s)
        assert('0'<=x&&x<='1');
}

vi readVectorInt(int n,lli l,lli r){
    vi a(n);
    for(int i=0;i<n-1;++i)
        a[i]=readIntSp(l,r);
    a[n-1]=readIntLn(l,r);
    return a;
}

const lli INF = 0xFFFFFFFFFFFFFFFL;

lli seed;
mt19937 rng(seed=chrono::steady_clock::now().time_since_epoch().count());
inline lli rnd(lli l=0,lli r=INF)
{return uniform_int_distribution<lli>(l,r)(rng);}

class CMP
{public:
bool operator()(ii a , ii b) //For min priority_queue .
{    return ! ( a.X < b.X || ( a.X==b.X && a.Y <= b.Y ));   }};

void add( map<lli,lli> &m, lli x,lli cnt=1)
{
    auto jt=m.find(x);
    if(jt==m.end())         m.insert({x,cnt});
    else                    jt->Y+=cnt;
}

void del( map<lli,lli> &m, lli x,lli cnt=1)
{
    auto jt=m.find(x);
    if(jt->Y<=cnt)            m.erase(jt);
    else                      jt->Y-=cnt;
}

bool cmp(const ii &a,const ii &b)
{
    return a.X<b.X||(a.X==b.X&&a.Y<b.Y);
}

const lli mod = 1000000007L;
// const lli maxN = 1000000007L;

    lli T,n,i,j,k,in,cnt,l,r,u,v,x,y;
    lli m;
    string s;
    vi a;
    //priority_queue < ii , vector < ii > , CMP > pq;// min priority_queue .

int main(void) {
    ios_base::sync_with_stdio(false);cin.tie(NULL);
    // freopen("txt.in", "r", stdin);
    // freopen("txt.out", "w", stdout);
// cout<<std::fixed<<std::setprecision(35);
T=readIntLn(1,1e3);
lli sumNM = 1e6;
while(T--)
{

    const lli n=readIntSp(1,min(sumNM,100000LL)),m=readIntLn(1,min(sumNM/n,100000LL));
    sumNM-=n*m;
    vector<string> a(n);
    vii values;
    for(auto &s:a){
        s=readStringLn(m,m);
        assertBinaryString(s);
        ii cnt={0,0};
        for(auto x:s){
            if(x=='0')
                cnt.X++;
            else
                cnt.Y++;
        }
        values.pb(cnt);
    }

    vi b(n);
    iota(all(b),0);
    sort(all(b),[&](const int x,const int y){
        return values[x].Y*values[y].X<values[x].X*values[y].Y;
    });
    dbg(b);
    lli ans=0,cnt1=0;
    for(auto idx:b)
        for(auto x:a[idx]){
            if(x=='1')
                cnt1++;
            else
                ans+=cnt1;
        }
    cout<<ans<<endl;
}   aryanc403();
    readEOF();
    return 0;
}
Editorialist's Solution
import java.util.*;
import java.io.*;
class BININV{
    //SOLUTION BEGIN
    void pre() throws Exception{}
    void solve(int TC) throws Exception{
        int N = ni(), M = ni();
        String[] S = new String[N];
        for(int i = 0; i< N; i++)S[i] = n();
        Arrays.sort(S, (String s1, String s2) -> {
            int[] c1 = count(s1), c2 = count(s2);
            long inv1 = c1[1]*(long)c2[0], inv2 = c1[0]*(long)c2[1];
            if(inv1 == inv2)return 0;
            if(inv1 < inv2)return -1;
            return 1;
        });
        long inv = 0, onesCount = 0;
        for(int i = 0; i< N*M; i++){
            char ch = S[i/M].charAt(i%M);
            if(ch == '0')inv += onesCount;
            else onesCount++;
        }
        pn(inv);
    }
    int[] count(String S){
        int[] c = new int[2];
        for(int i = 0; i< S.length(); i++)c[S.charAt(i)-'0']++;
        return c;
    }
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    static boolean multipleTC = true;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        new BININV().run();
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}

    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }

        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }

        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException  e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }

        String nextLine() throws Exception{
            String str = "";
            try{   
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }  
            return str;
        }
    }
}

Feel free to share your approach. Suggestions are welcome as always.

#include<bits/stdc++.h>
using namespace std;

bool compare(string s1,string s2){
    return s1<s2;
}

void solve(){
    int n,m;
    cin>>n>>m;
    vector<string> v;
    for(int i = 0 ; i < n ;i++){
        string s;
        cin>>s;
        v.push_back(s);
    }
    sort(v.begin(),v.end(),compare);
    string tt ="";

    for(auto it: v){
        tt+=it;
    }
    
    // counting inversions
    int cnt = 0;
    int arr[m*n] = {0};
   
    for(int i = m*n-1;i >=0 ;i--){
        if(tt[i] == '0'){
            cnt++;
        }
        arr[i] = cnt;
    }
    int ans = 0;
    for(int i = 0 ; i < m*n ;i++){
        if(tt[i] == '1'){
            ans+=arr[i];
        }
    }

    cout<<ans<<"\n";
    
}
int main(){
    int t;
    cin>>t;
    while(t--){
        solve();
    }
    return 0;
}

Can anyone please tell me why I am getting WA?


Consider the strings
01111
10000

Your solution sorts them lexicographically as 01111 → 10000, which gives 20 inversions, whereas the order 10000 → 01111 gives only 5.

Consider
2 3
100
011
What does your code output?
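A quick way to debug such WA verdicts is to compare against a brute force over all orders on small inputs. A minimal Python sketch (the function names are hypothetical) for the inputs discussed above:

```python
from itertools import permutations


def inversions(t: str) -> int:
    ones = total = 0
    for ch in t:
        if ch == '1':
            ones += 1
        else:
            total += ones
    return total


def brute_force(strings):
    """Minimum over all N! orders; only usable for tiny N."""
    return min(inversions("".join(p)) for p in permutations(strings))


def lexicographic(strings):
    """What a plain sort of the strings produces."""
    return inversions("".join(sorted(strings)))


def by_ones(strings):
    """Editorial order: nondecreasing number of ones."""
    return inversions("".join(sorted(strings, key=lambda s: s.count('1'))))


for case in (["100", "011"], ["01111", "10000"]):
    print(brute_force(case), lexicographic(case), by_ones(case))
# prints 3 6 3, then 5 20 5
```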

What is the issue with the code below?

#include <bits/stdc++.h>
#define all(c) (c).begin(), (c).end()
#define present(c, x) ((c).find(x) != (c).end())
#define cpresent(c, x) (find(all(c),x) != (c).end())
#define fi first
#define se second
#define print(arr) for(int num : (arr)) cout << num << ' '; cout << '\n';
using namespace std;

int inversions(string str){
    int count = 0;
    int res = 0;
    for(int i=0;i<str.size();++i){
        if(str[i] == '1') count++;
        else res += count;
    }
    return res;
}

bool comp(string &l, string &r){
    string temp1 = l+r;
    string temp2 = r+l;
    int a = inversions(temp1);
    int b = inversions(temp2);
    return a < b;
}

int main(){
    ios::sync_with_stdio(false); cin.tie(NULL);
    int t; cin >> t;
    while(t--){
        int n, m; cin >> n >> m;
        vector<string> vec(n);
        for(string &str : vec) cin >> str;
        sort(all(vec), comp);
        string str = "";
        for(string st : vec) str += st;
        cout << inversions(str) << '\n';
    }
    return 0;
}
#include <bits/stdc++.h>
using namespace std;
typedef vector<int> vi;
typedef pair<int, int> pii;
#define endl "\n"
#define all(v) v.begin(), v.end()
#define pb push_back
#define mp make_pair
#define FF first
#define SS second
#define ll long long
#define MOD 1000000007
#define clr(val) memset(val, 0, sizeof(val))
#define what_is(x) cerr << #x << " is " << x << endl;
#define OJ                            \
    freopen("input.txt", "r", stdin); \
    freopen("output.txt", "w", stdout);
#define FIO                           \
    ios_base::sync_with_stdio(false); \
    cin.tie(NULL);                    \
    cout.tie(NULL);
bool cmp(string &a,
         string &b)
{
    int count1 = count(all(a), '1');
    int count2 = count(all(b), '1');
    if (count1 == count2)
        return a < b;
    else
        return count1 < count2;
}
int count(string s)
{
    int count0 = 0;
    vi a(s.size());
    for (int i = s.size() - 1; i >= 0; i--)
    {
        if (s[i] == '0')
            count0++;
        a.pb(count0);
    }
    reverse(all(a));
    int inversion = 0;
    for (int i = 0; i < s.size(); i++)
    {
        if (s[i] == '1')
            inversion += a[i];
    }
    return inversion;
}
int main()
{
    int t;
    cin >> t;
    while (t--)
    {
        int n, m;
        cin >> n >> m;
        vector<string> a(n);
        for (int i = 0; i < n; i++)
        {
            cin >> a[i];
        }
        sort(all(a), cmp);
        string ans;
        for (auto i : a)
        {
            ans += i;
        }
        //cout << ans << endl;
        cout << count(ans) << endl;
    }
    return 0;
}

Can someone tell me why this is showing WA?

I guess comparator functions don't work well in the CodeChef IDE; I did the same and got a runtime error.

But I got WA, not a runtime error.

#include<bits/stdc++.h>
#include <unordered_map>
#define int long long
#define w(x) int x; cin>>x; while(x--)
int mod = (int)1e9+7;
#define maxv 9223372036854775807
#define minv -9223372036854775808
#define  add push_back
using namespace std;

int fun(string a)
{
    int d=0;
    for(char i:a)
    {
        if(i=='0')
        {
            d++;
        }
    }
    int z=0;
    for(char i:a)
    {
        if(i=='1')
        {
            z+=d;
        }
        else
        {
            d--;
        }
    }
    return z;
}
bool cmp(string a,string b)
{
    string s1="";
    s1+=a;
    s1+=b;
    string s2="";
    s2+=b;
    s2+=a;
    if(fun(s1)<=fun(s2))
    {
        return true;
    }
    return false;
}
void solve(int t) 
{
    int n,m;
    cin>>n>>m;
    string ar[n];
    for(int i=0;i<n;i++)
    {
        cin>>ar[i];
    }
    sort(ar,ar+n,cmp);
    int d=0;
    for(int i=0;i<n;i++)
    {
        for(int j=0;j<m;j++)
        {
            if(ar[i][j]=='0')
            {
                d++;
            }
        }
    }
    int z=0;
    for(int i=0;i<n;i++)
    {
        for(int j=0;j<m;j++)
        {
            if(ar[i][j]=='1')
            {
                z+=d;
            }
            else
            {
                d--;
            }
        }
    }
    cout<<z<<"\n";
}
int32_t main()
{
    #ifndef ONLINE_JUDGE
    freopen("Input.txt", "r", stdin);
    freopen("Output.txt", "w", stdout);
#endif
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    int t;
    cin >> t;
    for(int i=1;i<=t;i++)
    {
        solve(i);
    }
}
/*

  


*/

This gives a runtime error. Can anyone explain?

Looks like my solution works; I had to use long long.
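To see why a 32-bit int is not enough, here is a rough Python sketch, assuming the bounds enforced by the tester's validator (N, M ≤ 10^5 and the sum of N*M over all test cases ≤ 10^6). It uses the within/across split from the editorial to count inversions when all N strings equal a block of ones followed by a block of zeros; total_inversions_identical is a hypothetical helper.

```python
INT32_MAX = 2**31 - 1  # largest value of a 32-bit signed int in C++/Java


def count_inversions(t: str) -> int:
    ones = total = 0
    for ch in t:
        if ch == '1':
            ones += 1
        else:
            total += ones
    return total


def total_inversions_identical(n: int, m: int) -> int:
    """Inversions when all n strings are '1' * (m // 2) + '0' * (m - m // 2).
    Within-string part: n * ones * zeros.
    Across-string part: every earlier string's 1s precede every later string's 0s."""
    ones, zeros = m // 2, m - m // 2
    within = n * ones * zeros
    across = n * (n - 1) // 2 * ones * zeros
    return within + across


# Small sanity check against direct counting.
assert total_inversions_identical(3, 4) == count_inversions("1100" * 3)

big = total_inversions_identical(10, 10**5)  # N*M = 10^6
print(big, big > INT32_MAX)  # 137500000000 True
```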


Ooh ok. Got my mistake.
Thanks

My code would give output 6, but the answer should be 0. Got it!

I also used a comparator function and got a runtime error.


Can anyone please tell me the issue with my solution?

# cook your dish here

t = int(input())
while(t>0):
    t-=1
    n,m = list(map(int,input().split(" ")))
    nums = list()

    while(n>0):
        n-=1
        t1 = str(input())
        nums.append(t1)

    nums.sort(key = lambda x: int(x,2))
    a = "".join(nums)
    inversion_count = 0
    one_count = 0
    for i in a:
        if(i == '1'):
            one_count += 1
        else:
            inversion_count += one_count

    print(inversion_count)

I too got RE while using a comparator. Removing it led to AC. It seems like CodeChef has some issues with comparators.

You have to sort the strings in increasing order of the number of ones, but you are sorting them lexicographically.

Why is it giving WA?

#include <bits/stdc++.h>
using namespace std;

int main(){
    int tc;
    cin>>tc;
    while(tc--){
        int n,m;
        cin>>n>>m;
        vector<string> s;
        vector<vector<int>> vec(n+1);
        for(int i=0; i<n; i++){
            string str;
            cin>>str;
            s.push_back(str);
            int cnt=0, cnt2 = 0;
            for(int j=0; j<m; j++){
                if(str[j]=='0'){
                    cnt++;
                    cnt2 = m - j;
                }
            }
            vec[i] = {cnt, cnt2, i};
        }
        sort(vec.begin(), vec.end());
        reverse(vec.begin(), vec.end());
        string str = "";
        for(int it = 0; it < n; it++){
            auto v = vec[it];
            str += s[v[2]];
        }
        int count=0, res=0;
        for(int it = m*n -1; it >=0; it-- ){
            if(str[it]=='0')
                count++;
            else{
                res += count;
            }
        }
        cout<<res<<endl;
    }
}

Hello guys, I'm new to CP. I used an approach similar to the editorial: I sorted the strings, concatenated them, and counted the zeros to the right of each position using a suffix-sum array. Can someone explain why my code doesn't work? I'm attaching my code here.

void solve(){
    ll n,m; cin>>n>>m;
    vector<string>s;
    while(n--){
        string str;
        cin>>str;
        s.pb(str);
    }
    sort(all(s));
    string ans = "";
    for(ll i=0;i<sz(s);i++)ans+=s[i];
    vi zero(sz(ans),0);
    for(ll i=0;i<sz(ans);i++){
        if(ans[i]=='0')zero[i]=1;
    }
    for(ll i=sz(ans)-2;i>=0;i--){
        zero[i] = zero[i]+zero[i+1];
    }
    ll cnt=0;
    for(ll i=0;i<sz(ans);i++){
        if(ans[i]=='1')cnt+=zero[i];
    }
    cout<<cnt<<nl;
}  

I sorted the strings in decreasing order of the number of 0s and then used an array to store the count of 0s from the right side of the concatenated string. While iterating over the resulting string, whenever s[i] == '1', I added the count of 0s from position i + 1 onward to my answer.
Can anyone suggest how this approach is wrong?

Pasted my code below

#include <bits/stdc++.h>
using namespace std;

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    int t;
    cin >> t;
    while(t--){
        int n, m;
        cin >> n >> m;
        vector<string> v;
        for(int i = 0; i < n; i++){
            string s;
            cin >> s;
            v.push_back(s);
        }
        vector <pair <int, string>> vec;
        for(int i = 0; i < v.size(); i++){
            int cnt = 0;
            for(int j = 0; j < v[i].size(); j++){
                if(v[i][j] == '0'){
                    cnt++;
                }
            }
            vec.push_back(make_pair(cnt, v[i]));
        }
        sort(vec.rbegin(), vec.rend());
        string res = "";
        for(int i = 0; i < vec.size(); i++){
            res += vec[i].second;
        }
        int size = res.length();
        int arr[size];
        int cnt = 0;
        for(int i = size - 1; i >= 0; i--){
            if(res[i] == '0'){
                cnt++;
            }
            arr[i] = cnt;
        }
        int ans = 0;
        for(int i = 0; i < size - 1; i++){
            if(res[i] == '1'){
                ans += arr[i + 1];
            }
        }
        cout << ans << "\n";
    }
    return 0;
}

// Author: Harsh Bardolia 
// Language: C++14

#include<bits/stdc++.h>
using namespace std;

// Define shortcuts
typedef long long ll;
#define endl "\n"
#define sz size
#define fi first
#define se second
#define pb push_back
#define all(x) (x).begin(), (x).end()
#define out(x); { for (auto &to : x) cout << to << " "; cout << '\n'; }
#define fast ios_base::sync_with_stdio(false); cout.tie(NULL); cin.tie(NULL);

void solve() {
    int n, m;
    cin >> n >> m;

    string s = "";
    vector<pair<int, string>> vs;
    
    for (int i = 0; i < n; i++) {
        string s1;
        cin >> s1;

        int c = 0;
        for (int j = 0; j < m; j++) {
            if (s1[j] == '0')
                c++;
        }

        vs.pb({c, s1});
    }

    sort(all(vs));
    reverse(all(vs));

    for (auto x : vs)
        s += x.se;

    int res = 0;
    int cur = 0;
    for (int i = m * n - 1; i >= 0; i--) {
        if (s[i] == '0')
            cur++;
        if (s[i] == '1')
            res += cur;
    }    

    // cout << s << endl;
    cout << res << endl;
}

int main() {
    fast;
    ll t;
    cin >> t;
    while (t--) {
         solve();
    }
    // solve();
    return 0;
}

I don't understand why I am getting WA.
Can someone help me figure it out?