[Counting Inversions Revisited: INVYCNT] [Beginner] I didn't understand the logic behind this

Hi everyone,

I am working as a software developer with 4+ years of experience, but I'm very new to competitive programming. I practice whenever I get time…

Coming to the question: today I participated in the October Lunchtime and started with this problem. I was able to write the brute force, but I couldn't think of a way to optimize it.

After the contest I went through the submissions and found this code.

I have gone through the code and traced it on paper to get the intuition behind it, but sadly I couldn't understand the logic.

Can someone please explain the intuition behind this?

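#include<bits/stdc++.h>
using namespace std;
#define ll long long int
int main()
{
    ll i,j,k,n,t;
    cin>>t;
    while(t--)
    {
        cin>>n>>k;
        vector<ll> ar(n);
        for(i=0;i<n;i++)
            cin>>ar[i];
        // Count ALL ordered pairs (i, j), regardless of position order, with ar[i] > ar[j].
        // Each such pair becomes an inversion between copy p and copy q whenever p < q,
        // and there are k*(k-1)/2 ways to pick two different copies p < q.
        ll ctr=0;
        for(i=0;i<n;i++)
            for(j=0;j<n;j++)
                if(ar[i]>ar[j])
                    ctr++;
        ll ans=ctr*(k*(k-1)/2);
        // Count ordinary inversions (i < j) inside a single copy; there are k copies.
        ctr=0;
        for(i=0;i<n;i++)
            for(j=i+1;j<n;j++)
                if(ar[i]>ar[j])
                    ctr++;
        cout<<ctr*k+ans<<endl;
    }
}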

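I'll give you an example; I hope it helps.
Since the given array is concatenated k times and k can be very large, that hints at a combinatorial formula instead of building the new array.

Now you need to count inversions in both directions within the original array: pairs where the larger element is on the left (left to right) and pairs where the larger element is on the right (right to left). An O(n^2) approach is fine for this, since the array length is at most 100.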

Now to the actual example. I recommend writing these out in a notebook and spotting the combination pattern yourself.

3 3
2 1 3

New array:
2, 1, 3, 2, 1, 3, 2, 1, 3

Why do inversions in both directions matter?
In the new array, the elements smaller than 3 that come before it in the original array (2 and 1) now also appear after the earlier copies of 3.
And the pairs to the right of 2 multiply: there are three 1s after the first 2, two after the second 2, and one after the third 2.

I recommend deriving the formula below yourself.
Just watch how the inversions multiply as k increases.

1 + 2 + 3 + … + k = (k*(k + 1))/2

For left to right we'll have (k*(k+1))/2 * (left-to-right inversion count).
For right to left, (k*(k-1))/2 * (right-to-left inversion count).
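Putting the two cases together, with F the left-to-right count and B the right-to-left count of the original array (F and B are my notation, not from the thread). The second line shows why the code in the question gives the same result: its first loop counts every pair with a larger element, which is F + B, and its second loop counts F.

```latex
\text{answer} = F\cdot\frac{k(k+1)}{2} + B\cdot\frac{k(k-1)}{2},
\qquad F = 1,\; B = 2,\; k = 3:\quad 1\cdot 6 + 2\cdot 3 = 12

(F+B)\cdot\frac{k(k-1)}{2} + F\cdot k
  = F\left(\frac{k(k-1)}{2} + k\right) + B\cdot\frac{k(k-1)}{2}
  = F\cdot\frac{k(k+1)}{2} + B\cdot\frac{k(k-1)}{2}
```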

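#include <bits/stdc++.h>
using namespace std;
typedef long long lli;  // the post relied on template macros; these lines stand in for them

int main(){
    ios::sync_with_stdio(false);  // presumably what fastIO expanded to
    cin.tie(nullptr);
    lli t=1;
    cin>>t;
    while(t--){
        lli n,k;
        cin>>n>>k;
        vector<lli> a(n);
        for(lli i=0;i<n;i++)  // replaces scanarr(a,n)
            cin>>a[i];
        lli sum=0;
        for(lli i=0;i<n;i++){
            // age ("ahead"): smaller elements to the right of a[i] within one copy
            // piche ("behind"): smaller elements to the left of a[i] within one copy
            lli age=0,piche=0;
            for(lli j=i-1;j>=0;j--){
                if(a[i]>a[j])
                    piche++;
            }
            for(lli j=i+1;j<n;j++){
                if(a[i]>a[j])
                    age++;
            }
            // summed over the k copies of a[i]: a smaller element ahead pairs with the same
            // and every later copy (1+2+...+k), one behind only with later copies (1+...+(k-1))
            sum+=(age*((k*(k+1))/2))+(piche*((k*(k-1))/2));
        }
        cout<<sum<<endl;
    }
    return 0;
}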

@ayush4 Bro, can you please explain how the answer for
3 3
2 1 3 comes out as 12?
I think I'm not getting how exactly the inversions are counted :sweat_smile:

Concatenate 2, 1, 3 three times, i.e. write the array 2, 1, 3 three times in a row:
2 1 3 2 1 3 2 1 3
Now the inversions are:
for the first 2 there are 3 smaller values after it
for the second 2 there are 2 smaller values after it
for the third 2 there is 1 smaller value after it

3 + 2 + 1 = 6

Now,
for the first 3 there are 4 smaller values after it
for the second 3 there are 2 smaller values after it

4 + 2 = 6

Total: 6 + 6 = 12


I'll label each element of the new array with the copy it comes from: subscript 1 for the first copy, 2 for the second, and so on.

$2_1, 1_1, 3_1, 2_2, 1_2, 3_2, 2_3, 1_3, 3_3$

Given the inversion definition:

A pair $(i, j)$, where $1 \le i < j \le N$, is an inversion if $X_i > X_j$.

In the original array we have 1 inversion from left to right: 2 > 1.
From right to left we have two: 3 > 1 and 3 > 2.

Now count the inversions in the new array:
count = 0
$2_1 > 1_1, 1_2, 1_3$ (count += 3)
$2_2 > 1_2, 1_3$ (count += 2)
$2_3 > 1_3$ (count += 1)

count = 6

Formula (left to right) = (k * (k + 1))/2 * (number of left-to-right inversion pairs of the current element).

$3_1 > 2_2, 1_2, 2_3, 1_3$ (count += 4)
$3_2 > 2_3, 1_3$ (count += 2)

count = 6 + 4 + 2 = 12

Formula (right to left): the first copy of an element contributes (k - 1) * (its right-to-left inversion count), the second copy (k - 2) * (that count), and so on while the multiplier stays above 0.

For the 3 here (right-to-left count 2): (k - 1) * 2 + (k - 2) * 2 + … + 1 * 2 = 2 * ((k - 1) + (k - 2) + … + 1)
==> invCnt * (k * (k - 1))/2

Write out different examples in a notebook, including ones with repeated elements, and you'll understand the formula clearly.


This is a very poorly written problem with no proper explanation.

:+1: Thank you very much for the explanation, @ayush4, I really appreciate it… thank you :slight_smile:

@ssrivastava990 :+1: nice… well written… :slight_smile:


I am trying to solve this question with the same approach, but a little differently. My WA code.
I am storing the frequency of each element in an unordered map and multiplying it by the count of numbers that are smaller than the given element. I think I am missing something. Can anyone help me find what I am missing in my logic? It would be of great help to me.

@ayush4 I tried to solve this a little differently. I first computed the sum for the whole sequence and then subtracted the count of smaller elements using the first sequence. Here is my code: code

Can you tell me why this logic won’t work? It would be of great help to me.

The multiplication k*(k + 1)/2 causes integer overflow, since k can be as large as 10^6. I just changed the data type of k and arr (which might not be necessary) to long long:

https://www.codechef.com/viewsolution/30000436


Thanks, buddy.

Can anyone please tell me where I am going wrong in this inversion count problem? It gives the right answer for the sample test cases in the question, but after submitting it says Wrong Answer. I'm attaching my code here:

#include <bits/stdc++.h>
#define lli long long
using namespace std;

int main()
{
    int t;
    cin>>t;
    while(t--)
    {
        lli n ; lli k;
        cin>>n>>k;
        vector<lli> arr(n);
        for(lli i = 0; i<n ; i++)
        {
            cin>>arr[i];
        }
        // s must be 64-bit: a single term k*(k+1)/2 is already ~5*10^11 for k = 10^6,
        // which overflows int
        lli s = 0 ;
        for(lli i = 0 ; i < n ; i ++)
        {
            for(lli j = 0 ; j < n ; j ++)
            {
                if(i==j)
                    continue;
                else if(arr[i]>arr[j])
                {
                    if(i<j)      // larger element first: same copy and all later copies
                        s = s + (k*(k+1)/2);
                    else         // larger element second: only later copies
                        s = s + (k*(k-1)/2);
                }
            }
        }
        cout<<s<<endl;
    }
}
