MAKEALLEQUAL - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: satyam_343
Tester: tabr
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Prefix sums

PROBLEM:

You have an array A with N elements.
In one move, you can choose at most M of them and increase all of them by 1.
Find the minimum number of moves needed to make everything equal.

EXPLANATION:

Let \text{mx} denote the maximum element of the array.
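Since elements can only increase, the final common value is at least \text{mx}, and our best choice is to make them all equal to exactly \text{mx}. From any schedule that reaches a larger target, dropping the last few increases of each element gives a schedule that reaches \text{mx} with no more operations. This works because we're allowed to choose fewer than M elements at a time.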

First, let’s find some lower bounds on the answer.
Each A_i must be increased exactly \text{mx}-A_i times to reach \text{mx}.
So, if S = \sum_{i=1}^N (\text{mx} - A_i), our operations must perform S unit increases in total.
Each operation performs at most M increases, so we definitely need at least \left\lceil \frac{S}{M} \right\rceil operations to reach our target.
Here, \lceil \cdot \rceil denotes the ceiling function.

Further, each operation can increase a given element at most once. So if \text{mn} denotes the minimum element of A, we'll definitely need at least \text{mx} - \text{mn} operations to bring it up to \text{mx}.
Taking both cases into consideration, we need at least \max(\text{mx} - \text{mn}, \left\lceil \frac{S}{M} \right\rceil) operations.

This lower bound is tight, i.e., it's always possible to achieve our goal using exactly this many operations.
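
Before reading the proof, you can sanity-check the claim by brute force. The sketch below is not part of the original solution. It runs an exponential breadth-first search over the vectors of remaining increases, so it only handles tiny arrays. It assumes, as argued above, that the common target value is \text{mx}. The names `formula`, `brute_force` and `check_small_cases` are hypothetical.

```python
from collections import deque
from itertools import combinations, product


def formula(a, m):
    """Editorial answer: max(mx - mn, ceil(S / M)) with S = sum(mx - A_i)."""
    mx, mn = max(a), min(a)
    s = sum(mx - x for x in a)
    return max(mx - mn, (s + m - 1) // m)


def brute_force(a, m):
    """Minimum number of operations found by BFS (tiny inputs only).

    Assumption: the target value is max(a). A state is the tuple of
    remaining increases B_i = mx - A_i; one operation decrements between
    1 and m distinct positive entries.
    """
    mx = max(a)
    start = tuple(mx - x for x in a)
    dist = {start: 0}
    queue = deque([start])
    while queue:
        state = queue.popleft()
        if not any(state):          # every element has reached mx
            return dist[state]
        positive = [i for i, v in enumerate(state) if v > 0]
        for k in range(1, min(m, len(positive)) + 1):
            for chosen in combinations(positive, k):
                nxt = list(state)
                for i in chosen:
                    nxt[i] -= 1
                nxt = tuple(nxt)
                if nxt not in dist:
                    dist[nxt] = dist[state] + 1
                    queue.append(nxt)
    raise AssertionError("the all-zero state is always reachable")


def check_small_cases(max_n=4, max_value=3):
    """Compare formula and brute force on every array with values in 1..max_value."""
    for n in range(1, max_n + 1):
        for m in range(1, n + 1):
            for a in product(range(1, max_value + 1), repeat=n):
                assert formula(a, m) == brute_force(a, m), (a, m)


check_small_cases()
print("formula agrees with brute force on all small cases")
```

The limits `max_n=4` and `max_value=3` keep the search fast. Raising them quickly makes the state space explode.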

Proof

Consider the following setup:
We have an integer array B = [B_1, B_2, \ldots, B_N] with us. Each B_i is \geq 0.
In one move, we can subtract 1 from at most M different indices of B, and we’d like to find the minimum number of moves to reduce all elements to 0.
This is equivalent to our situation with B_i = \text{mx} - A_i: increasing A_i by 1 corresponds to subtracting 1 from B_i.

Without loss of generality, let B be sorted, i.e., B_i \leq B_{i+1}.
Let S = B_1 + B_2 + \ldots + B_N.
Then, our claim is that the number of moves needed is exactly \max(B_N, \left\lceil \frac{S}{M}\right\rceil).

We consider three cases.
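Case 1: B has fewer than M non-zero elements.
In this case, B_N operations are clearly both necessary and sufficient: just choose all non-zero elements in each operation. This matches the claim, since S \leq (M-1)\cdot B_N gives \left\lceil \frac{S}{M}\right\rceil \leq B_N.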

Case 2: B_1 + B_2 + \ldots + B_{N-1} \lt (M-1)\cdot B_N.
In other words, B_N is so large that even if we choose it in every operation, it won't reach 0 before all the other elements do.
Notice that this is the same as saying S \lt M\cdot B_N, so \max(B_N, \left\lceil \frac{S}{M}\right\rceil) = B_N.
In such a case, it’s always possible to use B_N operations and make everything 0.

How?

In every operation, choose B_N and then the M-1 largest remaining non-zero elements.
The inequality B_1 + B_2 + \ldots + B_{N-1} \lt (M-1)\cdot B_N is maintained, since both sides decrease by the same amount (M-1 each).
Further, B_N decreases by 1 at each step and remains the maximum.

This process can stop when there are \lt M non-zero elements in B, at which point we move to case 1.
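
The Case 1 / Case 2 construction can be simulated directly. Below is a minimal sketch with 0-based indices. The name `greedy_schedule` is hypothetical. Every operation decrements the current maximum together with the next M-1 largest non-zero entries, or all non-zero entries once fewer than M remain. On an input that satisfies the Case 2 inequality, it should use exactly B_N operations. It re-sorts on every step, so it is meant for illustration, not for the actual constraints.

```python
def greedy_schedule(b, m):
    """Case 1 / Case 2 construction on the deficit array b (0-based indices).

    Each operation decrements the current maximum together with the next
    m - 1 largest non-zero entries; once fewer than m entries are non-zero
    (Case 1), it simply decrements all of them. Returns the list of
    operations, each a list of indices.
    """
    b = list(b)
    operations = []
    while any(b):
        # Non-zero indices, largest remaining value first.
        order = sorted((i for i, v in enumerate(b) if v > 0), key=lambda i: -b[i])
        chosen = order[:m]
        for i in chosen:
            b[i] -= 1
        operations.append(chosen)
    return operations


# Case 2 example: 1 + 1 + 2 < (3 - 1) * 6, so max(b) = 6 operations are expected.
b = [1, 1, 2, 6]
ops = greedy_schedule(b, 3)
print(len(ops), ops)
```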


Case 3: B_1 + B_2 + \ldots + B_{N-1} \geq (M-1)\cdot B_N.
This is the only remaining case.
It means S \geq M\cdot B_N, so \left\lceil \frac{S}{M} \right\rceil \geq B_N, and our lower bound on the number of operations is \left\lceil \frac{S}{M} \right\rceil.

For this case, we use a classical idea.
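Consider a grid with M rows and \left\lceil \frac{S}{M} \right\rceil columns. It has at least S cells, and we'll fill it row by row: left to right within each row, and rows from top to bottom.
Write 1, B_1 times. Then 2, B_2 times. Then 3, B_3 times, and so on.
Once this is done, simply use the columns of the grid as the operations. Each column has at most M cells. Also, B_i \leq B_N \leq \left\lceil \frac{S}{M} \right\rceil, so the B_i consecutive copies of i always land in distinct columns, and no column contains an index twice. Index i is therefore used exactly B_i times, as required!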
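
Here is a sketch of the Case 3 grid construction. It assumes S \geq M\cdot B_N, so the number of columns is at least B_N. Indices are 0-based. The helpers `grid_schedule` and `is_valid` are hypothetical names. The second one checks that every operation uses at most M distinct indices and that index i is used exactly B_i times. The usage lines run it on B = [9, 9, 8, 5, 4, 4, 4, 3, 1, 0] with M = 4, which is the array 1 1 2 5 6 6 6 7 9 10 raised to 10.

```python
def grid_schedule(b, m):
    """Case 3 grid construction (0-based indices).

    Assumes sum(b) >= m * max(b), so the number of columns c = ceil(S / m)
    is at least max(b). Index i is written b[i] times into an m x c grid in
    row-major order (left to right, then top to bottom); column j of the
    grid becomes operation j.
    """
    c = (sum(b) + m - 1) // m          # number of columns = number of operations
    columns = [[] for _ in range(c)]
    pos = 0                            # cell (pos // c, pos % c) in row-major order
    for i, count in enumerate(b):
        for _ in range(count):
            columns[pos % c].append(i)
            pos += 1
    return columns


def is_valid(b, m, operations):
    """Each operation uses at most m distinct indices; index i is used exactly b[i] times."""
    used = [0] * len(b)
    for op in operations:
        if len(op) > m or len(set(op)) != len(op):
            return False
        for i in op:
            used[i] += 1
    return used == list(b)


b = [9, 9, 8, 5, 4, 4, 4, 3, 1, 0]     # increases needed for 1 1 2 5 6 6 6 7 9 10 with target 10
ops = grid_schedule(b, 4)
print(len(ops), is_valid(b, 4, ops))   # 12 True
```

Each run of equal indices is at most c cells long, so its cells fall in distinct columns. This is exactly where the Case 3 assumption is used: if max(b) exceeded c, some index would wrap onto the same column twice.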

TIME COMPLEXITY:

\mathcal{O}(N) per testcase.

CODE:

Author's code (C++)
#pragma GCC optimize("O3,unroll-loops")
#include <bits/stdc++.h>   
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;   
using namespace std;
#define ll long long
#define pb push_back                  
#define mp make_pair          
#define nline "\n"                            
#define f first                                            
#define s second                                             
#define pll pair<ll,ll> 
#define all(x) x.begin(),x.end()     
#define vl vector<ll>           
#define vvl vector<vector<ll>>      
#define vvvl vector<vector<vector<ll>>>          
#ifndef ONLINE_JUDGE    
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);  
#endif     
void _print(ll x){cerr<<x;}  
void _print(char x){cerr<<x;}   
void _print(string x){cerr<<x;}    
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());   
template<class T,class V> void _print(pair<T,V> p) {cerr<<"{"; _print(p.first);cerr<<","; _print(p.second);cerr<<"}";}
template<class T>void _print(vector<T> v) {cerr<<" [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T>void _print(set<T> v) {cerr<<" [ "; for (T i:v){_print(i); cerr<<" ";}cerr<<"]";}
template<class T>void _print(multiset<T> v) {cerr<< " [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T,class V>void _print(map<T, V> v) {cerr<<" [ "; for(auto i:v) {_print(i);cerr<<" ";} cerr<<"]";} 
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
//--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------  
const ll MOD=998244353;
const ll MAX=2000200;
vector<ll> fact(MAX+2,1),inv_fact(MAX+2,1);
ll binpow(ll a,ll b,ll MOD){
    ll ans=1;
    a%=MOD;  
    while(b){
        if(b&1)
            ans=(ans*a)%MOD;
        b/=2;
        a=(a*a)%MOD;
    }
    return ans;
}
ll inverse(ll a,ll MOD){
    return binpow(a,MOD-2,MOD);
} 
void precompute(ll MOD){
    for(ll i=2;i<MAX;i++){
        fact[i]=(fact[i-1]*i)%MOD;
    }
    inv_fact[MAX-1]=inverse(fact[MAX-1],MOD);
    for(ll i=MAX-2;i>=0;i--){
        inv_fact[i]=(inv_fact[i+1]*(i+1))%MOD;
    }
}
ll nCr(ll a,ll b,ll MOD){
    if(a==b){
        return 1;
    }
    if((a<0)||(a<b)||(b<0))
        return 0;   
    ll denom=(inv_fact[b]*inv_fact[a-b])%MOD;
    return (denom*fact[a])%MOD;  
} 
void solve(){  
    ll n,m; cin>>n>>m;
    vector<ll> a(n);
    ll sum=0;
    for(auto &it:a){
        cin>>it;
        sum+=it;
    }
    sort(all(a));
    ll ans=max(a[n-1]-a[0],(a[n-1]*n-sum+m-1)/m);
    cout<<ans<<nline;
    return;    
}                                       
int main()                                                                               
{       
    ios_base::sync_with_stdio(false);                          
    cin.tie(NULL);                               
    #ifndef ONLINE_JUDGE                 
    freopen("input.txt", "r", stdin);                                           
    freopen("output.txt", "w", stdout);      
    freopen("error.txt", "w", stderr);                        
    #endif     
    ll test_cases=1;                   
    cin>>test_cases;
    precompute(MOD);
    while(test_cases--){
        solve();
    }
    cout<<fixed<<setprecision(10);
    cerr<<"Time:"<<1000*((double)clock())/(double)CLOCKS_PER_SEC<<"ms\n"; 
}   
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

#define IGNORE_CR

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
#ifdef IGNORE_CR
            if (buffer[pos] == '\r') {
                pos++;
                continue;
            }
#endif
            assert(!isspace(buffer[pos]));
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

template <long long mod>
struct modular {
    long long value;
    modular(long long x = 0) {
        value = x % mod;
        if (value < 0) value += mod;
    }
    modular& operator+=(const modular& other) {
        if ((value += other.value) >= mod) value -= mod;
        return *this;
    }
    modular& operator-=(const modular& other) {
        if ((value -= other.value) < 0) value += mod;
        return *this;
    }
    modular& operator*=(const modular& other) {
        value = value * other.value % mod;
        return *this;
    }
    modular& operator/=(const modular& other) {
        long long a = 0, b = 1, c = other.value, m = mod;
        while (c != 0) {
            long long t = m / c;
            m -= t * c;
            swap(c, m);
            a -= t * b;
            swap(a, b);
        }
        a %= mod;
        if (a < 0) a += mod;
        value = value * a % mod;
        return *this;
    }
    friend modular operator+(const modular& lhs, const modular& rhs) { return modular(lhs) += rhs; }
    friend modular operator-(const modular& lhs, const modular& rhs) { return modular(lhs) -= rhs; }
    friend modular operator*(const modular& lhs, const modular& rhs) { return modular(lhs) *= rhs; }
    friend modular operator/(const modular& lhs, const modular& rhs) { return modular(lhs) /= rhs; }
    modular& operator++() { return *this += 1; }
    modular& operator--() { return *this -= 1; }
    modular operator++(int) {
        modular res(*this);
        *this += 1;
        return res;
    }
    modular operator--(int) {
        modular res(*this);
        *this -= 1;
        return res;
    }
    modular operator-() const { return modular(-value); }
    bool operator==(const modular& rhs) const { return value == rhs.value; }
    bool operator!=(const modular& rhs) const { return value != rhs.value; }
    bool operator<(const modular& rhs) const { return value < rhs.value; }
};
template <long long mod>
string to_string(const modular<mod>& x) {
    return to_string(x.value);
}
template <long long mod>
ostream& operator<<(ostream& stream, const modular<mod>& x) {
    return stream << x.value;
}
template <long long mod>
istream& operator>>(istream& stream, modular<mod>& x) {
    stream >> x.value;
    x.value %= mod;
    if (x.value < 0) x.value += mod;
    return stream;
}

constexpr long long mod = 998244353;
using mint = modular<mod>;

mint power(mint a, long long n) {
    mint res = 1;
    while (n > 0) {
        if (n & 1) {
            res *= a;
        }
        a *= a;
        n >>= 1;
    }
    return res;
}

vector<mint> fact(1, 1);
vector<mint> finv(1, 1);

mint C(int n, int k) {
    if (n < k || k < 0) {
        return mint(0);
    }
    while ((int) fact.size() < n + 1) {
        fact.emplace_back(fact.back() * (int) fact.size());
        finv.emplace_back(mint(1) / fact.back());
    }
    return fact[n] * finv[k] * finv[n - k];
}

int main() {
    input_checker in;
    int tt = in.readInt(1, 1e5);
    in.readEoln();
    int sn = 0;
    while (tt--) {
        int n = in.readInt(1, 5e5);
        sn += n;  // accumulate the sum of N over all test cases (checked after the loop)
        in.readSpace();
        int m = in.readInt(1, n);
        in.readEoln();
        auto a = in.readInts(n, 1, n);
        in.readEoln();
        long long mx = *max_element(a.begin(), a.end());
        long long mn = *min_element(a.begin(), a.end());
        long long sum = accumulate(a.begin(), a.end(), 0LL);
        cout << max(mx - mn, (mx * n - sum + m - 1) / m) << '\n';
    }
    assert(sn <= 5e5);
    in.readEof();
    return 0;
}
Editorialist's code (Python)
for _ in range(int(input())):
    n, m = map(int, input().split())
    a = list(map(int, input().split()))
    
    mn, mx = min(a), max(a)
    sm = sum(mx-x for x in a)
    ans = mx - mn
    ans = max(ans, (sm + m - 1) // m)
    print(ans)

Could anyone please provide a test case where my code goes wrong?

I make an array brr that contains mx - a_i (for each a_i != mx), sort brr, and try to simulate the process and count the operations.
Could anyone please help?
Link - CodeChef: Practical coding for everyone

Hey, I have used the same approach; why is it not working for the last test case? Please help me out. Link - CodeChef: Practical coding for everyone

I have used a heap of size M to track the most efficient way the differences can be reduced to 0. Why is this approach wrong? Here is the Python code for it:
https://www.codechef.com/viewsolution/1036009864

Use the long long data type instead of int.

Can anyone tell me why this approach is wrong? I tried using two pointers: I computed the difference array and always reduced the smaller count from the maximum window.

https://www.codechef.com/viewsolution/1036016631

Can someone help me understand the test case below?

10 4
1 1 2 5 6 6 6 7 9 10

According to the editorial, the answer is 12, but I am getting 13, and I am not able to find any sequence of 12 operations that works.

10 4
1 1 2 5 6 6 6 7 9 10

The answer should be 12 according to the editorial, but your code gives 14.

This is why proofs are important!
In the proof section of the editorial, I've also given a construction showing exactly how to use the operations to achieve the stated answer.

The case you've given falls under "Case 3" of the proof, which is handled by the grid construction.

B_i here denotes the number of times index i needs to be increased.
In your case, the required increases to make everything reach 10 are [9, 9, 8, 5, 4, 4, 4, 3, 1, 0].
Let’s write them in a 4\times 12 grid:

1 1 1 1 1 1 1 1 1 2 2 2
2 2 2 2 2 2 3 3 3 3 3 3
3 3 4 4 4 4 4 5 5 5 5 6
6 6 6 7 7 7 7 8 8 8 9

The columns of this grid now give you the operations you want: each column contains at most 4 elements, all distinct, and index i appears in exactly B_i columns.

Please provide a test case where my code is wrong.
Here is my submission link: Submission

Since I can take at most M elements in each operation, I greedily try to take M elements whenever possible and increase their values by one. Instead of increasing by 1 at a time, I sort the given elements and greedily start from the second-largest element. My aim is to make all array elements equal to the largest element. In one step, the number of operations required equals the number of operations needed to make the largest element in the current subarray equal to the target (i.e., the largest value in the array).
I am using an approach similar to the one used to find the point contained in the maximum number of intervals.

For example, take the test case:
3 2
1 2 3

Initially, my suffix array has all elements equal to zero. My target value is 3.
Now, from i=2 (1-based indexing), I take all the elements down to max(0,i-m)=0.
So my range covers the update of all elements from index 1 to 2. The number of operations performed equals target-v[i] = 3-2 = 1, and the suffix array looks like [-1,0,1,0].
When i=1, suffix[i] = suffix[i]+suffix[i+1] = 1, and v[i] is updated to v[i]+suffix[i] = 2 because of the last operation we performed. Now again I take the subarray of length m, and the number of operations equals target-v[1] = 1.
Therefore, the answer should be 2.

Please help me find the mistake in my code. Since I cannot see the test case where it fails, I am not able to find the bug in my code or logic.

Beautiful problem.
A tough proof packed into an easy design.
Can you tell me where you read about this classical idea?

Your code fails on the test seen in multiple comments above:

10 4
1 1 2 5 6 6 6 7 9 10

Why do you think your greedy choice is correct?

Super, thanks for helping me find a valid series of operations and clarifying my doubts!

Thank you.
I got it. Greedy won't work, since in my greedy I take consecutive elements, whereas an operation can be distributed optimally among other elements too.

Can you suggest some problems or resources where I can read about this idea? I want to know in which types of problems it is used and how to identify it.

Can we solve this using binary search? I don't know how to implement the predicate function that returns true/false depending on whether K operations are sufficient to make all elements equal.