Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: satyam_343
Tester: tabr
Editorialist: iceknight1093
Prefix sums
You have an array A with N elements.
In one move, you can choose at most M of them and increase all of them by 1.
Find the minimum number of moves needed to make everything equal.
Let \text{mx} denote the maximum element of the array.
Of course, to make everything equal our best choice is to make them all equal to \text{mx}: since we’re allowed to choose fewer than M elements at a time, elements that already equal \text{mx} never have to be touched.
First, let’s find some lower bounds on the answer.
For any A_i in the array, it needs to be increased exactly \text{mx}-A_i times to reach \text{mx}.
So, if S = \sum_{i=1}^N (\text{mx} - A_i), we need S increases among our operations.
Each operation gives us at most M increases, so we definitely need at least \left\lceil \frac{S}{M} \right\rceil operations to reach our target.
Here, \left\lceil \ \ \right\rceil denotes the ceiling function.
Further, since each operation can change a given element at most once, if \text{mn} denotes the minimum element of A we’ll definitely need at least \text{mx} - \text{mn} operations to bring it up to \text{mx}.
Taking both cases into consideration, we need at least \max(\text{mx} - \text{mn}, \left\lceil \frac{S}{M} \right\rceil) operations.
This lower bound is tight, i.e., it’s always possible to use exactly this many operations and achieve our goal.
Consider the following setup:
We have an integer array B = [B_1, B_2, \ldots, B_N] with us. Each B_i is \geq 0.
In one move, we can subtract 1 from at most M different indices of B, and we’d like to find the minimum number of moves to reduce all elements to 0.
It’s clear that this is equivalent to our situation.
Without loss of generality, let B be sorted, i.e., B_i \leq B_{i+1}.
Let S = B_1 + B_2 + \ldots + B_N.
Then, our claim is that the number of moves needed is exactly \max(B_N, \left\lceil \frac{S}{M}\right\rceil).
We consider three cases.
Case 1: B has fewer than M non-zero elements.
In this case, clearly B_N operations is both necessary and sufficient, by just choosing all non-zero elements at each stage.
Case 2: B_1 + B_2 + \ldots + B_{N-1} \lt (M-1)\cdot B_N.
In other words, B_N is so large that even if we choose it on every operation, it’ll not become 0 before all the other elements become 0.
Notice that this is the same as saying S \lt M\cdot B_N, so \max(B_N, \left\lceil \frac{S}{M}\right\rceil) = B_N.
In such a case, it’s always possible to use B_N operations and make everything 0.
On every operation, choose B_N and then M-1 of the largest remaining elements.
The inequality B_1 + B_2 + \ldots + B_{N-1} \lt (M-1)\cdot B_N is maintained as long as at least M-1 other non-zero elements remain, since both sides then change by the same quantity (they both reduce by M-1).
Further, B_N reduces by 1 at each stage, and will remain the maximum.
This process can stop when there are \lt M non-zero elements in B, at which point we move to case 1.
Case 3: B_1 + B_2 + \ldots + B_{N-1} \geq (M-1)\cdot B_N.
The only remaining case.
This means S \geq M\cdot B_N, so \left\lceil \frac{S}{M} \right\rceil is our lower bound on the number of operations.
For this case, we use a classical idea.
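Consider an M\times \left\lceil \frac{S}{M} \right\rceil grid (M rows and \left\lceil \frac{S}{M} \right\rceil columns), which we’ll fill row by row: rows from top to bottom, and each row from left to right.
Write 1, B_1 times. Then 2, B_2 times. 3, B_3 times, and so on.
Once this is done, simply use the columns of the grid as the operations. Each column has M cells, and because B_i \leq B_N \leq \left\lceil \frac{S}{M} \right\rceil (the length of a row), the B_i consecutive cells holding i never fall in the same column twice. So every column contains \leq M distinct values and index i will be used exactly B_i times, as required!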
Time complexity: \mathcal{O}(N) per testcase.
Author's code (C++)
#pragma GCC optimize("O3,unroll-loops")
#include <bits/stdc++.h>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;
using namespace std;
// Author's contest template: shorthand macros, stderr debug printers,
// order-statistics tree typedefs and global factorial tables.
// Shorthand type and member aliases used throughout the template.
#define ll long long
#define pb push_back
#define mp make_pair
#define nline "\n"
#define f first
#define s second
#define pll pair<ll,ll>
#define all(x) x.begin(),x.end()
#define vl vector<ll>
#define vvl vector<vector<ll>>
#define vvvl vector<vector<vector<ll>>>
// debug(x) prints the expression text followed by its value to cerr.
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
// Second definition makes debug(x) expand to a lone ';'.
// NOTE(review): presumably the two definitions were alternatives of a
// conditional-compilation block whose directives are not visible here — confirm.
#define debug(x);
// _print overloads write a value to cerr; containers are printed element by element.
void _print(ll x){cerr<<x;}
void _print(char x){cerr<<x;}
void _print(string x){cerr<<x;}
// Random generator seeded from the steady clock.
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
template<class T,class V> void _print(pair<T,V> p) {cerr<<"{"; _print(p.first);cerr<<","; _print(p.second);cerr<<"}";}
template<class T>void _print(vector<T> v) {cerr<<" [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T>void _print(set<T> v) {cerr<<" [ "; for (T i:v){_print(i); cerr<<" ";}cerr<<"]";}
template<class T>void _print(multiset<T> v) {cerr<< " [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T,class V>void _print(map<T, V> v) {cerr<<" [ "; for(auto i:v) {_print(i);cerr<<" ";} cerr<<"]";}
// pb_ds red-black trees with the order-statistics node update policy;
// the "multiset" variant differs only in using less_equal as comparator.
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
const ll MOD=998244353;
const ll MAX=2000200;
// Factorial and inverse-factorial tables, MAX+2 entries each, all initialised to 1.
vector<ll> fact(MAX+2,1),inv_fact(MAX+2,1);
// Presumably computes a^b modulo MOD (it is called with exponent MOD-2 by inverse()).
// NOTE(review): only the initialisation and the return are visible in this
// excerpt; the exponentiation loop appears to have been lost — confirm against the original.
ll binpow(ll a,ll b,ll MOD){
ll ans=1;
return ans;
// Returns binpow(a, MOD-2, MOD); this is the modular inverse of a via Fermat's
// little theorem, assuming MOD is prime and binpow is modular exponentiation — TODO confirm.
ll inverse(ll a,ll MOD){
return binpow(a,MOD-2,MOD);
// Runs one ascending pass (i = 2 .. MAX-1) and one descending pass (i = MAX-2 .. 0).
// NOTE(review): the loop bodies are not visible; presumably they fill fact[] and
// inv_fact[] respectively — confirm against the original.
void precompute(ll MOD){
for(ll i=2;i<MAX;i++){
for(ll i=MAX-2;i>=0;i--){
// Binomial coefficient modulo MOD: the final expression is
// fact[a] * inv_fact[b] * inv_fact[a-b] (mod MOD).
// NOTE(review): the conditions guarding the early `return 1` / `return 0`
// are not visible in this excerpt.
ll nCr(ll a,ll b,ll MOD){
return 1;
return 0;
ll denom=(inv_fact[b]*inv_fact[a-b])%MOD;
return (denom*fact[a])%MOD;
// Handles one test case: reads n and m, then evaluates the editorial's formula
// max(mx - mn, ceil(S / m)) with S written as mx*n - sum.
// NOTE(review): the loop body, any sorting step and the output statement are
// not visible in this excerpt (lines appear to have been lost in extraction).
void solve(){
ll n,m; cin>>n>>m;
vector<ll> a(n);
ll sum=0;
// presumably reads each element into `it` and adds it to `sum` — body not visible
for(auto &it:a){
// a[n-1] and a[0] stand in for the maximum and minimum, which assumes `a` is
// sorted ascending at this point — TODO confirm.
// (x + m - 1) / m is the integer ceiling of x / m for non-negative x.
ll ans=max(a[n-1]-a[0],(a[n-1]*n-sum+m-1)/m);
// Author's entry point.
// Redirects stdin, stdout and stderr to input.txt / output.txt / error.txt.
// NOTE(review): presumably these redirections were guarded by a local-only
// preprocessor condition that is not visible here — confirm before running on a judge.
int main()
freopen("input.txt", "r", stdin);
freopen("output.txt", "w", stdout);
freopen("error.txt", "w", stderr);
// Number of test cases, defaulting to 1; the code that reads it and calls
// solve() is not visible in this excerpt.
ll test_cases=1;
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#define debug(...)
#define IGNORE_CR
// Strict input validator used by the tester: buffers characters from cin and
// hands them out token by token, asserting on every format or range violation.
// NOTE(review): several statement lines and all closing braces of this struct
// are missing from the excerpt; comments below describe only what is visible.
struct input_checker {
// Buffered input and the current read position within it.
string buffer;
int pos;
// Character classes that callers can pass as `pattern` to readString.
const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
const string number = "0123456789";
const string lower = "abcdefghijklmnopqrstuvwxyz";
// Constructor: pulls characters from cin with cin.get() and appends them to buffer.
input_checker() {
pos = 0;
while (true) {
int c = cin.get();
// -1 is presumably the end-of-input marker that terminates the loop —
// the branch body is not visible here.
if (c == -1) {
buffer.push_back((char) c);
// Returns the next token: characters from pos up to (not including) the next
// space or newline. Asserts that input is not already exhausted.
string readOne() {
assert(pos < (int) buffer.size());
string res;
while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
// With IGNORE_CR defined, '\r' gets special handling (presumably skipped —
// the branch body is not visible).
#ifdef IGNORE_CR
if (buffer[pos] == '\r') {
res += buffer[pos];
return res;
// Reads one token and asserts its length is within [min_len, max_len] and,
// if `pattern` is non-empty, that every character occurs in `pattern`.
string readString(int min_len, int max_len, const string& pattern = "") {
assert(min_len <= max_len);
string res = readOne();
assert(min_len <= (int) res.size());
assert((int) res.size() <= max_len);
for (int i = 0; i < (int) res.size(); i++) {
assert(pattern.empty() || pattern.find(res[i]) != string::npos);
return res;
// Reads one token, parses it with stoi and asserts min_val <= value <= max_val.
int readInt(int min_val, int max_val) {
assert(min_val <= max_val);
int res = stoi(readOne());
assert(min_val <= res);
assert(res <= max_val);
return res;
// Same as readInt but parses with stoll into a long long.
long long readLong(long long min_val, long long max_val) {
assert(min_val <= max_val);
long long res = stoll(readOne());
assert(min_val <= res);
assert(res <= max_val);
return res;
// Reads `size` range-checked ints into a vector.
vector<int> readInts(int size, int min_val, int max_val) {
assert(min_val <= max_val);
vector<int> res(size);
for (int i = 0; i < size; i++) {
res[i] = readInt(min_val, max_val);
// Between elements (not after the last) a separator is presumably consumed —
// the branch body is not visible.
if (i != size - 1) {
return res;
// Reads `size` range-checked long longs into a vector.
vector<long long> readLongs(int size, long long min_val, long long max_val) {
assert(min_val <= max_val);
vector<long long> res(size);
for (int i = 0; i < size; i++) {
res[i] = readLong(min_val, max_val);
// Same separator handling as in readInts — branch body not visible.
if (i != size - 1) {
return res;
// Asserts that the character at pos exists and is a single space.
// NOTE(review): the statement advancing pos is not visible.
void readSpace() {
assert((int) buffer.size() > pos);
assert(buffer[pos] == ' ');
// Asserts that the character at pos exists and is a newline.
// NOTE(review): the statement advancing pos is not visible.
void readEoln() {
assert((int) buffer.size() > pos);
assert(buffer[pos] == '\n');
// Asserts that the whole buffer has been consumed.
void readEof() {
assert((int) buffer.size() == pos);
// Integer modulo the compile-time constant `mod`.
// The constructor and the compound operators keep `value` in [0, mod).
// NOTE(review): closing braces are missing from this excerpt.
template <long long mod>
struct modular {
long long value;
// Reduces x modulo mod and shifts a negative remainder into [0, mod).
modular(long long x = 0) {
value = x % mod;
if (value < 0) value += mod;
// Addition: one conditional subtraction is enough because both operands are < mod.
modular& operator+=(const modular& other) {
if ((value += other.value) >= mod) value -= mod;
return *this;
// Subtraction: one conditional addition brings a negative result back into range.
modular& operator-=(const modular& other) {
if ((value -= other.value) < 0) value += mod;
return *this;
modular& operator*=(const modular& other) {
value = value * other.value % mod;
return *this;
// Division: multiplies by a coefficient `a` produced by an iterative
// Euclid-style loop on (other.value, mod), i.e. `a` is used as the
// multiplicative inverse of other.value — presumably requires
// gcd(other.value, mod) == 1; TODO confirm.
modular& operator/=(const modular& other) {
long long a = 0, b = 1, c = other.value, m = mod;
while (c != 0) {
long long t = m / c;
m -= t * c;
swap(c, m);
a -= t * b;
swap(a, b);
// Normalise the coefficient into [0, mod) before multiplying.
a %= mod;
if (a < 0) a += mod;
value = value * a % mod;
return *this;
// Binary operators are built from the compound ones on a copy of lhs.
friend modular operator+(const modular& lhs, const modular& rhs) { return modular(lhs) += rhs; }
friend modular operator-(const modular& lhs, const modular& rhs) { return modular(lhs) -= rhs; }
friend modular operator*(const modular& lhs, const modular& rhs) { return modular(lhs) *= rhs; }
friend modular operator/(const modular& lhs, const modular& rhs) { return modular(lhs) /= rhs; }
// Pre-increment / pre-decrement.
modular& operator++() { return *this += 1; }
modular& operator--() { return *this -= 1; }
// Post-increment: returns the value held before the update.
modular operator++(int) {
modular res(*this);
*this += 1;
return res;
// Post-decrement: returns the value held before the update.
modular operator--(int) {
modular res(*this);
*this -= 1;
return res;
// Unary minus; the constructor re-normalises the negated value.
modular operator-() const { return modular(-value); }
// Comparisons act on the stored representative.
bool operator==(const modular& rhs) const { return value == rhs.value; }
bool operator!=(const modular& rhs) const { return value != rhs.value; }
bool operator<(const modular& rhs) const { return value < rhs.value; }
// String form of a modular number: the decimal text of its stored value.
template <long long mod>
string to_string(const modular<mod>& x) {
return to_string(x.value);
// Stream output: writes the stored value and returns the stream.
template <long long mod>
ostream& operator<<(ostream& stream, const modular<mod>& x) {
return stream << x.value;
// Stream input: reads a raw integer into x.value, then reduces it into [0, mod).
template <long long mod>
istream& operator>>(istream& stream, modular<mod>& x) {
stream >> x.value;
x.value %= mod;
if (x.value < 0) x.value += mod;
return stream;
constexpr long long mod = 998244353;
using mint = modular<mod>;
// Raises a to the n-th power by scanning the bits of n from least significant:
// multiply the result when the low bit is set, square the base, shift n right.
// NOTE(review): closing braces are missing from this excerpt, so the exact
// nesting of the statements below is not visible.
mint power(mint a, long long n) {
mint res = 1;
while (n > 0) {
if (n & 1) {
res *= a;
a *= a;
n >>= 1;
return res;
// Factorial and inverse-factorial tables, each starting with the single entry 1
// and extended on demand by C().
vector<mint> fact(1, 1);
vector<mint> finv(1, 1);
// Binomial coefficient n choose k as a mint; 0 when k is outside [0, n].
mint C(int n, int k) {
if (n < k || k < 0) {
return mint(0);
// Grow both tables until index n exists: the next factorial is the previous
// one times the current table size, and finv stores its modular inverse.
while ((int) fact.size() < n + 1) {
fact.emplace_back(fact.back() * (int) fact.size());
finv.emplace_back(mint(1) / fact.back());
return fact[n] * finv[k] * finv[n - k];
// Tester's entry point: validates the input with input_checker and prints
// max(mx - mn, ceil((mx*n - sum) / m)) for every test case.
// NOTE(review): the separator/end-of-line checks and closing braces are not
// visible in this excerpt.
int main() {
input_checker in;
// Number of test cases, asserted to lie in [1, 1e5].
int tt = in.readInt(1, 1e5);
// Compared against 5e5 at the end; presumably accumulates n over all test
// cases — the update statement is not visible.
int sn = 0;
while (tt--) {
// Constraints enforced by the checker: 1 <= n <= 5e5, 1 <= m <= n, 1 <= a_i <= n.
int n = in.readInt(1, 5e5);
int m = in.readInt(1, n);
auto a = in.readInts(n, 1, n);
// Held as long long so that mx * n below is computed in 64-bit arithmetic.
long long mx = *max_element(a.begin(), a.end());
long long mn = *min_element(a.begin(), a.end());
long long sum = accumulate(a.begin(), a.end(), 0LL);
// mx * n - sum is the total number of increments needed;
// adding m - 1 before dividing by m gives the ceiling.
cout << max(mx - mn, (mx * n - sum + m - 1) / m) << '\n';
assert(sn <= 5e5);
return 0;
Editorialist's code (Python)
# Editorialist's solution: for every test case compute
# max(mx - mn, ceil(S / m)), where S = sum(mx - a_i).
# NOTE(review): the loop body's indentation and the statement that outputs
# `ans` are not visible in this excerpt.
for _ in range(int(input())):
# First line of a test case: n and m; second line: the array.
n, m = map(int, input().split())
a = list(map(int, input().split()))
mn, mx = min(a), max(a)
# Total number of unit increments needed to raise every element to mx.
sm = sum(mx-x for x in a)
# The minimum element alone needs mx - mn moves.
ans = mx - mn
# (sm + m - 1) // m is the integer ceiling of sm / m.
ans = max(ans, (sm + m - 1) // m)