MEANIDIAN - Editorial

PROBLEM LINK:

Practice
Div1
Div2
Div3

Setter: Utkarsh Gupta
Testers: Lavish Gupta, Tejas Pandey
Editorialist: Ajit Sharma Kasturi

DIFFICULTY:

EASY-MEDIUM

PREREQUISITES:

Binary Search

PROBLEM:

There is an array A of size N. In one operation, we can increase any element by 1. We need to find the minimum number of operations needed to make the mean of the array equal to its median.

The median is defined as follows: let B be the array A sorted in non-decreasing order. If N is even, the median is B_{\frac{N}{2}}; otherwise, it is B_{\frac{N+1}{2}}.
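
In 0-indexed code, both cases of this definition collapse to the element at index (N-1)/2 of the sorted array, which is the index the solutions below use ((n-1)/2 in the tester's code, n/2 minus 1 for even n in the editorialist's). A minimal sketch; the function name statementMedian is hypothetical:

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical helper: median as defined in the statement.
// 1-indexed B_{N/2} (N even) or B_{(N+1)/2} (N odd) is 0-indexed (N - 1) / 2.
long long statementMedian(std::vector<long long> a) {
    assert(!a.empty());
    std::sort(a.begin(), a.end());
    return a[(a.size() - 1) / 2];
}
```

For example, the array (3, 1, 3, 1) has median 1 under this definition, not 3.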

QUICK EXPLANATION:

  • Try binary searching on the final mean/median value of the array.

  • For a given value, check how many operations are required to ensure mean = median = value.
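
Both steps fit a standard "smallest value for which a monotone check passes" binary search. A minimal generic sketch, assuming the check is false below some threshold and true from it onward, and that hi itself passes (the name smallestFeasible is hypothetical, not taken from the solutions below):

```cpp
#include <cassert>
#include <functional>

// Hypothetical helper: smallest x in [lo, hi] with feasible(x) == true.
// Assumes feasible is monotone (false ... false, true ... true) and feasible(hi) holds.
long long smallestFeasible(long long lo, long long hi,
                           const std::function<bool(long long)>& feasible) {
    assert(lo <= hi && feasible(hi));
    while (lo < hi) {
        const long long mid = lo + (hi - lo) / 2;  // avoids overflow of lo + hi
        if (feasible(mid)) hi = mid;  // mid works: the answer is mid or smaller
        else lo = mid + 1;            // mid fails: so does everything below it
    }
    return lo;
}
```

In this problem, lo would be the current median and feasible(x) the per-value check described in the explanation.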

EXPLANATION:

  • For simplicity, assume N is odd; for even N the same argument works with the median index \frac{N}{2} in place of \frac{N+1}{2}.

  • First let us sort the array A.

  • Suppose that in the optimal case the final mean/median value is x, and that this requires c operations.

  • Then it is always possible to make the final mean/median equal to x+1 instead: adding 1 to every element increases both the mean and the median by 1, for a total of c+N operations.

  • If x can be achieved, then x+1, x+2, \dots can also be achieved with the number of operations increasing by N each time.

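  • Therefore, feasibility is monotone in x, and we can binary search for the smallest achievable x. Since the total number of operations for a given x is exactly N \cdot x - sum (derived below), the smallest achievable x also gives the minimum number of operations.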

  • The only remaining question is: given x, how many operations are needed to achieve
    mean = median = x?

  • Let the sum of all the elements of the array A be sum.

  • First, consider the mean. For the mean to equal x we need \frac{sum + extra}{N} = x \implies extra = N \cdot x - sum, where extra is the total number of operations performed. Since operations only increase values, x is possible only if extra \geq 0.

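  • Next, consider the median. Operations only increase values, so the median can never drop below its current value; hence x must satisfy x \geq A_{\frac{N+1}{2}}. For such x, the elements at indices 1 to \frac{N-1}{2} are already \leq x, and every element at indices \frac{N+1}{2} to N that is below x must be raised to x. The minimum number of operations for median = x is therefore y = \sum_{i=\frac{N+1}{2}}^{N} \max(0, x - A_i).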

  • If y \gt extra, we cannot make the mean equal to the median, since fixing the median alone needs at least y operations. If y \leq extra, we first spend y operations to make the median x and then perform the remaining extra - y operations on A_N. Increasing the largest element does not change the median, so this brings the mean to exactly x.

  • In this way, we binary search on x and finally print extra for the smallest x for which mean = median = x can be achieved.
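
Putting the steps together, here is a minimal sketch of the whole procedure (an illustration, not the editorialist's code; the name minOperations is hypothetical). It searches x over [median, \max A_i]: x = \max A_i is always feasible, because raising every element to \max A_i makes mean = median = \max A_i. Starting at the median also enforces x \geq median, which the check itself does not test.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical helper: minimum number of +1 operations that make mean == median.
long long minOperations(std::vector<long long> a) {
    assert(!a.empty());
    std::sort(a.begin(), a.end());
    const long long n = static_cast<long long>(a.size());
    const std::size_t m = (a.size() - 1) / 2;  // 0-indexed median position
    const long long sum = std::accumulate(a.begin(), a.end(), 0LL);

    // Is mean = median = x reachable? Only called with x >= a[m].
    auto feasible = [&](long long x) {
        const long long extra = n * x - sum;  // operations forced by the mean
        if (extra < 0) return false;          // the mean can only increase
        long long y = 0;                      // operations forced by the median
        for (std::size_t i = m; i < a.size(); ++i) y += std::max(0LL, x - a[i]);
        return y <= extra;  // the leftover extra - y goes to the largest element
    };

    // Smallest feasible x in [a[m], max A]; x = max A always passes.
    long long lo = a[m], hi = a.back();
    while (lo < hi) {
        const long long x = lo + (hi - lo) / 2;
        if (feasible(x)) hi = x;
        else lo = x + 1;
    }
    return n * lo - sum;  // extra at the smallest feasible x
}
```

For the hypothetical array (1, 2, 6): the median is 2 and sum = 9; x = 3 gives extra = 0 < y = 1, while x = 4 gives extra = 3 \geq y = 2, so the sketch returns 3 (raise 2 to 4 and 6 to 7). Using 10^9 as the upper bound, as the solutions below do, also works since every A_i \leq 10^9.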

TIME COMPLEXITY:

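O(N \log N + N \log(\max A_i)) for each test case: sorting takes O(N \log N), and each of the O(\log(\max A_i)) binary search steps performs an O(N) check.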

SOLUTION:

Editorialist's solution

#include <bits/stdc++.h>
#define ll long long int
using namespace std;

ll getMinOperations(vector<int> &a, int median_or_mean)
{

      int n = a.size();
      ll sum = 0;

      for (int i = 0; i < a.size(); i++)
      {
            sum += a[i];
      }

      ll cost_for_mean = 1ll * n * median_or_mean - sum;
      if (cost_for_mean < 0)
      {
            return -1;
      }

      // Every element from the median index onward must be >= the target median

      ll min_cost_for_median = 0;
      int start = (n / 2);
      if (n % 2 == 0)
            start--;

      for (int i = start; i < n; i++)
      {
            min_cost_for_median += max(median_or_mean - a[i], 0);
      }

      if (min_cost_for_median > cost_for_mean)
      {
            return -1;
      }

      // Else we can add (cost_for_mean - min_cost_for_median) operations
      // on the greatest element to accommodate for mean

      return cost_for_mean;
}

int main()
{
      int tests;
      cin >> tests;
      while (tests--)
      {
            int n;
            cin >> n;
            vector<int> a(n);
            for (int i = 0; i < n; i++)
                  cin >> a[i];

            sort(a.begin(), a.end());

            int median_ind = (n / 2);
            if (n % 2 == 0)
                  median_ind--;

            int median_or_mean = 1e9;
            ll min_operations = -1;

            // Performing binary search on final median/mean value
            int l = a[median_ind], r = 1e9;

            while (l <= r)
            {
                  int mid = (l + r) / 2;
                  ll val = getMinOperations(a, mid);
                  if (val != -1)
                  {
                        median_or_mean = mid;
                        min_operations = val;
                        r = mid - 1;
                  }
                  else
                  {
                        l = mid + 1;
                  }
            }

            cout << min_operations << endl;
      }
      return 0;
}

Setter's solution
//Utkarsh.25dec
#include <bits/stdc++.h>
#include <chrono>
#include <random>
#define ll long long int
#define ull unsigned long long int
#define pb push_back
#define mp make_pair
#define mod 1000000007
#define rep(i,n) for(ll i=0;i<n;i++)
#define loop(i,a,b) for(ll i=a;i<=b;i++)
#define vi vector <int>
#define vs vector <string>
#define vc vector <char>
#define vl vector <ll>
#define all(c) (c).begin(),(c).end()
#define max3(a,b,c) max(max(a,b),c)
#define min3(a,b,c) min(min(a,b),c)
#define deb(x) cerr<<#x<<' '<<'='<<' '<<x<<'\n'
using namespace std;
#include <ext/pb_ds/assoc_container.hpp> 
#include <ext/pb_ds/tree_policy.hpp> 
using namespace __gnu_pbds; 
#define ordered_set tree<int, null_type,less<int>, rb_tree_tag,tree_order_statistics_node_update>
// ordered_set s ; s.order_of_key(val)  no. of elements strictly less than val
// s.find_by_order(i)  iterator to the i-th element (0-indexed)
typedef vector<vector<ll>> matrix;
ll power(ll a,ll b) {ll res=1;a%=mod; assert(b>=0); for(;b;b>>=1){if(b&1)res=res*a%mod;a=a*a%mod;}return res;}
ll modInverse(ll a){return power(a,mod-2);}
const int N=500023;
bool vis[N];
vector <int> adj[N];
/*
------------------------Input Checker----------------------------------
*/
 
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
 
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }
 
            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }
 
            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
ll sumN=0;
void solve()
{
    int n=readInt(1,300000,'\n');
    sumN+=n;
    assert(sumN<=300000);
    vl v;
    v.pb(0);
    ll sum=0;
    for(int i=1;i<=n;i++)
    {
        ll c;
        if(i<n)
            c=readInt(1,1000000000,' ');
        else
            c=readInt(1,1000000000,'\n');
        sum+=c;
        v.pb(c);
    }
    sort(all(v));
    int med=(n+1)/2;
    ll l=sum/n;
    ll currmed=v[med];
    if((sum%n)!=0)
        l++;
    ll r=1e9;
    while(l<=r)
    {
        ll mid=(l+r)/2;
        ll can=(mid*n-sum);
        ll req=0;
        for(int i=med;i<=n;i++)
        {
            req+=(max((ll)0,mid-v[i]));
        }
        if(currmed>mid)
        {
            l=mid+1;
            continue;
        }
        if(req<=can)
            r=mid-1;
        else
            l=mid+1;
    }
    cout<<(l*n-sum)<<'\n';
}
int main()
{
    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    #endif
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    int T=readInt(1,10000,'\n');
    int t=0;
    while(t++<T)
    {
        //cout<<"Case #"<<t<<":"<<' ';
        solve();
        //cout<<'\n';
    }
    cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
}

Tester's solution
#include <bits/stdc++.h>
using namespace std;
#define ll long long
 
 
/*
------------------------Input Checker----------------------------------
*/
 
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
 
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }
 
            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }
 
            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
 
 
/*
------------------------Main code starts here----------------------------------
*/
 
const int MAX_T = 10000;
const int MAX_N = 300000;
const int MAX_A = 1000000000;
const int MAX_val = 1000000000;
const int MAX_SUM_N = 300000;
 
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
 
int sum_n = 0;
int max_n = 0;
int yess = 0;
int nos = 0;
int total_ops = 0;
ll p = 1000000007;
ll sum_nk = 0 ;


void solve()
{   
    int n = readIntLn(2 , MAX_N) ;
    ll arr[n] ;
    for(int i = 0 ; i < n-1 ; i++)
        arr[i] = readIntSp(1 , MAX_A) ;
    arr[n-1] = readIntLn(1 , MAX_A) ;

    sort(arr , arr+n) ;
    ll sum = 0 ;
    for(int i = 0 ; i < n ; i++)
        sum += arr[i] ;

    int flag = 0 ;
    ll median = arr[(n-1)/2] ;

    if(median*n == sum)
    {
        cout << 0 << '\n' ;
        cerr << "flag = " << flag << endl ;
        return ;
    }

    if(median*n > sum)
    {
        flag = 1 ;
        cerr << "flag = " << flag << endl ;
        ll req = (median*n - sum) ;
        cout << req << '\n' ;
        return ;
    }

    cerr << "flag = 2" << endl ;

    ll l = 0 , r = arr[n-1] ;
    ll med = 0 , ind = (n-1)/2 ;
    while(l <= r)
    {
        ll mid = (l+r)/2 ;
        ll add_sum = 0 ;
        for(ll i = ind ; i < n ; i++)
        {
            if(arr[i] >= mid)
                break ;
            add_sum += (mid - arr[i]) ;
        }
        ll new_sum = sum + add_sum ;
        if(new_sum <= mid * n)
        {
            med = mid ;
            r = mid-1 ;
        }
        else
            l = mid+1 ;
    }

    ll fin_ans = (med * n - sum) ;
    cout << fin_ans << endl ;
    return ;

}
 
signed main()
{
    //fast;
    #ifndef ONLINE_JUDGE
    freopen("inputf.txt" , "r" , stdin) ;
    freopen("outputf.txt" , "w" , stdout) ;
    freopen("error.txt" , "w" , stderr) ;
    #endif
    
    int t = 1;
    
    t = readIntLn(1,MAX_T);

    for(int i=1;i<=t;i++)
    {    
        solve() ;
    }
    
    assert(getchar() == -1);
 
    cerr<<"SUCCESS\n";
    cerr<<"Tests : " << t << '\n';
    // cerr<<"Sum of lengths : " << sum_n << '\n';
    // cerr<<"Maximum length : " << max_n << '\n';
    // cerr << "Sum o f product : " << sum_nk << '\n' ;
    // cerr<<"Total operations : " << total_ops << '\n';
    // cerr<<"Answered yes : " << yess << '\n';
    // cerr<<"Answered no : " << nos << '\n';
}

Please comment below if you have any questions, alternate solutions, or suggestions. :slight_smile:

Any good resources to practice binary search at CP level?

I think we can optimise this further. Rather than iterating from i = (n-1)/2 to n in every check, we can use a prefix-sum array and find R, the lower bound of k (the candidate mean) in the sorted array; the median cost is then k \cdot (R - (n-1)/2) minus the sum of the elements in positions [(n-1)/2, R). Each check becomes O(\log N), for a total of O(\log(\max A_i) \cdot \log N) for the binary search after sorting.
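
A sketch of the check this comment describes, assuming the array is sorted and prefix sums pre[i] = A_0 + \dots + A_{i-1} are precomputed (prefixSums and feasibleFast are hypothetical names). R is clamped to at least the median index so that elements equal to x sitting before the median are not counted:

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: pre[i] = a[0] + ... + a[i-1], size n + 1.
std::vector<long long> prefixSums(const std::vector<long long>& a) {
    std::vector<long long> pre(a.size() + 1, 0);
    for (std::size_t i = 0; i < a.size(); ++i) pre[i + 1] = pre[i] + a[i];
    return pre;
}

// Hypothetical helper: O(log N) check for mean = median = x; a sorted ascending.
bool feasibleFast(const std::vector<long long>& a,
                  const std::vector<long long>& pre, long long x) {
    assert(!a.empty() && pre.size() == a.size() + 1);
    const std::size_t n = a.size(), m = (n - 1) / 2;
    if (x < a[m]) return false;  // the median can only increase
    const long long extra = static_cast<long long>(n) * x - pre[n];
    if (extra < 0) return false;  // the mean can only increase
    // R = first index with a[R] >= x (never before m); a[m..R-1] are raised to x.
    const std::size_t lb = static_cast<std::size_t>(
        std::lower_bound(a.begin(), a.end(), x) - a.begin());
    const std::size_t r = std::max(lb, m);
    const long long y = x * static_cast<long long>(r - m) - (pre[r] - pre[m]);
    return y <= extra;
}
```

The sort still dominates the total running time.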

I keep reminding myself that some minimisation or maximisation problems can be solved using Binary Search. Still, I fail to notice binary search based problems. :man_facepalming:

This problem taught me that the boundaries of a binary search are not always (0, 1e9). Here the lower bound must be the median of the current array. Anyway, nice problem!!
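
A small sketch of why that lower bound matters when the check itself does not test x \geq median (the editorialist's getMinOperations only sums costs from the median index onward). The array A = (1, 5, 5) and the name passesCheck are hypothetical; this array has mean 11/3 and median 5.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical helper: the editorial's check, with an optional median guard.
// a must be sorted ascending.
bool passesCheck(const std::vector<long long>& a, long long x, bool guardMedian) {
    assert(!a.empty());
    const std::size_t m = (a.size() - 1) / 2;
    if (guardMedian && x < a[m]) return false;  // the condition that is otherwise missing
    const long long sum = std::accumulate(a.begin(), a.end(), 0LL);
    const long long extra = static_cast<long long>(a.size()) * x - sum;
    if (extra < 0) return false;
    long long y = 0;
    for (std::size_t i = m; i < a.size(); ++i) y += std::max(0LL, x - a[i]);
    return y <= extra;  // for x < a[m] every term is 0, so any extra >= 0 passes
}
```

Without the guard, x = 4 passes (extra = 1 and every median term is 0), so a search starting from 1 would stop at x = 4 and print 1. With the guard, the smallest passing value is x = 5, giving the correct answer 4. The setter's solution and the "Short Code" below test this condition explicitly; the editorialist's solution avoids it by starting the search at the median.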

I solved this problem using Ternary search.
My Solution

The editorialist's solution was awesome.

I solved this problem as follows:
Case 1: mean > median

Increment the median element until it reaches the next greater element, then increment both of them until they reach the next greater element, and so on, until mean <= median.

Case 2: mean <= median
Increment the largest element of the array until mean = median.

Is this linear time, apart from the initial sorting of the array?

Solution: 55903656 | CodeChef
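
The block-raising greedy above can be written as one pass over the sorted array. The sketch below is an illustrative formulation, not the code behind the linked submission (minOperationsGreedy is a hypothetical name): instead of incrementing one unit at a time, it solves for the target value x directly on each segment between consecutive elements from the median onward. Because n \cdot x - sum - cost(x) never decreases as x grows, the first segment that admits a valid x gives the answer.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical helper: linear sweep after sorting, raising a[m..m+k-1] together.
long long minOperationsGreedy(std::vector<long long> a) {
    assert(!a.empty());
    std::sort(a.begin(), a.end());
    const std::size_t n = a.size(), m = (n - 1) / 2;
    const long long sum = std::accumulate(a.begin(), a.end(), 0LL);
    long long blockSum = 0;  // a[m] + ... + a[m+k-1]
    for (std::size_t k = 1; m + k <= n; ++k) {
        blockSum += a[m + k - 1];
        // Block raised to a common x with a[m+k-1] <= x <= a[m+k]:
        // mean <= x  <=>  sum + k*x - blockSum <= n*x
        //            <=>  (n - k) * x >= sum - blockSum.
        const long long outside = static_cast<long long>(n - k);
        const long long outsideSum = sum - blockSum;
        long long x = a[m + k - 1];
        if (outside > 0) x = std::max(x, (outsideSum + outside - 1) / outside);  // ceil
        // Accept if x stays inside this segment; the last block has no upper end.
        if (m + k == n || x <= a[m + k]) return static_cast<long long>(n) * x - sum;
    }
    return -1;  // not reached: the last block always accepts
}
```

Apart from the O(N \log N) sort, the sweep is O(N), since each element joins the raised block once. For example, for the hypothetical array (1, 2, 3, 100) the first segment would need x = 35 > 3, and the second accepts x = 51, giving 4 \cdot 51 - 106 = 98 operations.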

Here I can't even get binary search to work, and people are out here solving it with ternary search :joy::joy::joy::joy:

#include<bits/stdc++.h>
using namespace std;
#define endl "\n"
#define ll long long
#define vc vector
#define vp vector<pair<ll,ll>>
#define pb(a) push_back(a)
#define mp(a,b) make_pair(a,b)
#define ist(a) insert(a)
#define fr(i,a,n) for(int i=a;i<n;i++)
#define fr2(i,n,a) for(int i=n;i>=a;i--)
#define db1(x) cerr <<#x<<"="<<x<<'\n'
#define db2(x,y) cerr <<#x<<"="<<x<<" , "<<#y<<"="<<y<<'\n'
#define db3(x,y,z) cerr <<#x<<"="<<x<<" , "<<#y<<"="<<y<<" , "<<#z<<"="<<z<<'\n'
ll pwr(ll a, ll b, ll mod = 1000000007)
{
    if (b == 0)
    {
        return 1;
    }
    ll ans = pwr(a, b / 2, mod);
    ans *= ans;
    ans %= mod;
    if (b % 2)
    {
        ans *= a;
    }
    return ans % mod;
}
void solve()
{
    ll n;
    cin >> n;
    ll sum = 0;
    vc<ll> a(n);
    fr(i, 0, n)
    {
        cin >> a[i];
        sum += a[i];
    }
    if (sum <= n * (a[(n - 1) / 2]))
    {
        cout << 0 << endl;
        return;
    }
    ll k = sum - n * (a[(n - 1) / 2]);
    ll u = 0;
    if (k < 0) u = abs(k);
    fr(i, u, 1e18)
    {
        ll y = i, x = (k + i) / (n - 1);
        if ((k + i) % (n - 1) == 0)
        {
            cout << (k + i) / (n - 1) + i << endl;
            return;
        }
    }
}
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);
#ifndef ONLINE_JUDGE
    freopen("inputf.in", "r", stdin);
    freopen("outputf.in", "w", stdout);
#endif
    ll t;
    cin >> t;
    while (t--)
    {
        solve();
    }
}
Can anyone tell me which test case my code fails on?

@strcmp Why does binary search over the range (1, 1e9) not work, while binary search over (median, 1e9) does?

This is quite a useful approach to the question, with a short implementation. I used a greedy approach instead, where at each step I compared the mean and median of the array and incremented the array elements accordingly. My Solution

Yes, I did the same. Is this approach linear, apart from the initial sorting?

If the initial sorting is excluded, the rest of the algorithm is linear time. The outer while loop, which compares the current mean and median, runs at most 3 times in any case; the linear cost comes from the for loop inside the branch where the mean is greater than the median.

Short Code
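#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

// v must be sorted; checks whether mean = median = `mean` is reachable
bool isItPossible(vector<ll> &v, ll mean, ll sum) {
    ll n = v.size();
    ll ops = 0;
    if(v[(n - 1) / 2] > mean) {
        return false;  // the median can only increase
    }
    for(ll i = (n - 1) / 2; i < n; i ++) {
        if(v[i] < mean) {
            ops += mean - v[i];  // raise v[i] up to the target median
        }
    }
    return n * mean - sum >= ops;  // leftover operations go to the largest element
}
void solve() {
    ll n;
    cin >> n;
    vector<ll> v(n);
    ll sum = 0;
    for(ll i = 0; i < n; i ++) {
        cin >> v[i];
        sum += v[i];
    }
    sort(v.begin(), v.end());
    ll start = 0, end = 1e9, ans = -1;
    while(start <= end) {
        ll mid = (start + end) / 2;
        if(isItPossible(v, mid, sum)) {
            ans = mid;
            end = mid - 1;
        }
        else start = mid + 1;
    }
    cout << n * ans - sum << endl;
}
int main() {
    int t;
    cin >> t;
    while(t--) solve();
    return 0;
}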

Please tell me too if you find it.