BEAUTIFULARR - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: shikhar2307
Tester & Editorialist: iceknight1093

DIFFICULTY:

2363

PREREQUISITES:

Greedy algorithms, sorting (or binary search)

PROBLEM:

The beauty of an array equals the sum of pairwise products of its elements.
You’re given an array A. At most K times, you can increase one of its elements by 1.
What’s the maximum possible beauty of the final array?

EXPLANATION:

There are three parts to this problem:

  1. Computing the beauty of an array faster than \mathcal{O}(N^2).
  2. Figuring out which elements to apply the operations to.
  3. Actually applying these operations quickly, since \mathcal{O}(K) would be too slow.

Let’s go over these one at a time.

Computing beauty

We want to quickly compute the sum of pairwise products of elements.
Let S = \sum_{1 \leq i \lt j \leq N} A_i \cdot A_j be this quantity. Note that

\begin{align*} (A_1 + A_2 + \ldots + A_N)^2 &= (A_1^2 + A_2^2 + \ldots + A_N^2) + 2\cdot \sum_{1 \leq i \lt j \leq N} A_i\cdot A_j \\ &= (A_1^2 + \ldots + A_N^2) + 2S \end{align*}

So, S = \frac{(A_1 + \ldots + A_N)^2 - (A_1^2 + \ldots + A_N^2)}{2} can be computed in \mathcal{O}(N). We only need the sum of the elements and the sum of their squares.
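As a quick sanity check of this identity, here is a minimal Python sketch. The names beauty and beauty_brute are hypothetical and not taken from the reference solutions. Python's unbounded integers sidestep overflow, so no modulus is applied here.

```python
def beauty(a):
    """Sum of a[i] * a[j] over all pairs i < j, computed in O(N)."""
    total = sum(a)
    sum_sq = sum(x * x for x in a)
    # total^2 = sum_sq + 2 * S, so S = (total^2 - sum_sq) / 2.
    # The numerator is exactly 2 * S, so integer division is exact.
    return (total * total - sum_sq) // 2


def beauty_brute(a):
    """The O(N^2) definition, used only to check beauty()."""
    n = len(a)
    return sum(a[i] * a[j] for i in range(n) for j in range(i + 1, n))


print(beauty([1, 2, 3]))  # 1*2 + 1*3 + 2*3 = 11
```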


Now, let’s look at what happens when we need to perform operations.
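Suppose we had K = 1, i.e., only one operation needed to be performed. What would the optimal move be?
Increasing A_i by 1 increases S by exactly the sum of the other elements, (A_1 + \ldots + A_N) - A_i, because A_i appears in exactly one product with each of them. So we should choose the smallest A_i to increase by 1.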

In fact, this greedy choice holds even for K \gt 1: for each move, choose the smallest element of A and increase it by one!
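The greedy can be simulated literally with a min-heap. This is only a reference sketch, and greedy_heap is a hypothetical name. It does \mathcal{O}(K\log N) work, which is far too slow when K is large, but it is handy for checking faster approaches on small inputs.

```python
import heapq


def greedy_heap(a, k):
    """Apply k operations, each incrementing a current minimum; return sorted values.

    O(K log N): a reference for small K only, far too slow for large K.
    """
    heap = list(a)
    heapq.heapify(heap)
    for _ in range(k):
        smallest = heapq.heappop(heap)      # remove a current minimum
        heapq.heappush(heap, smallest + 1)  # put it back, increased by one
    return sorted(heap)


print(greedy_heap([1, 4, 5], 4))  # [4, 5, 5]
```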

Proof

This will be a bit long to write out, though it’s really just several applications of exchange arguments.

Let A_1 \leq A_2 \leq \ldots \leq A_N.
Suppose we increase A_i exactly B_i times; so the final values are C_i = A_i + B_i.

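Recall that the beauty is half of (the square of the sum of the C_i, minus the sum of squares of the C_i).
Using all K operations never hurts, because each one adds the sum of the other elements to the beauty, which is non-negative when the elements are. So we can assume exactly K operations are used, and then the sum of the C_i is a constant equal to K plus the sum of the A_i.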
So, maximizing the beauty requires us to minimize the sum of the squares of C_i.

Let’s call a solution an assignment of B_i values, and an optimal solution a solution that maximizes beauty.

Claim 1: There exists an optimal solution such that C_i \leq C_j when i \leq j.
Proof: Suppose C_i \gt C_j. Since A_i \leq A_j, we can shift operations from index i to index j so that the two final values are swapped. The multiset of final values, and hence the beauty, is unchanged.

From now on, we’ll only consider optimal solutions with this structure.

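Claim 2: There exists an optimal solution such that if i \lt j and B_j \gt 0, then B_i \gt 0.
Proof: If B_i = 0, reduce B_j by 1 and increase B_i by 1. Since C_j \geq A_j + 1 \geq A_i + 1, working out the algebra shows the sum of squares changes by 2(A_i + 1 - C_j) \leq 0. So the beauty does not decrease, and the solution stays optimal.
Each such move shifts an operation to a smaller index, so repeating it terminates in an optimal solution with this property.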
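For reference, here is the algebra for moving one operation from index j to index i. It covers the exchange in Claim 2 and the one in Claim 3 below. Write \Delta for the change in the sum of squares:

\begin{align*} \Delta &= \left[(C_i + 1)^2 - C_i^2\right] + \left[(C_j - 1)^2 - C_j^2\right] \\ &= (2C_i + 1) + (1 - 2C_j) \\ &= 2\,(C_i - C_j + 1) \end{align*}

Since the total sum is unchanged, the beauty changes by -\Delta/2 = C_j - C_i - 1.
- In Claim 2, C_i = A_i and C_j \geq A_j + 1 \geq A_i + 1. So \Delta \leq 0, with equality exactly when C_j = A_i + 1.
- In Claim 3, taking C_i to be the smaller value, C_j - C_i \geq 2 gives \Delta \leq -2 \lt 0.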

This tells us that we’ll only operate on some prefix of indices.
Let m be the highest index operated on.

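Claim 3: If 1 \leq i \lt j \leq m, then C_i and C_j differ by at most 1.
Proof: Suppose C_i and C_j differ by two or more. Move one operation from the index with the larger value to the one with the smaller value; this is possible because every index up to m is operated on. The sum of squares changes by 2(C_{\text{smaller}} - C_{\text{larger}} + 1) \lt 0, so the beauty strictly improves, which is once again a contradiction.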

Note that these properties (almost) uniquely characterize an optimal solution.
In particular, there’ll be an index m such that:

  • The suffix starting from A_{m+1} is unchanged, i.e, C_i = A_i for this suffix.
  • There’ll be some integer x such that every C_i up to index m is either x or x+1.

Note that if m is fixed, x is also uniquely fixed since we must perform exactly K moves.
Further, m itself is essentially fixed: we need to (at least) bring all the elements up to A_m, requiring (A_m - A_1) + (A_m - A_2) + \ldots + (A_m - A_m) moves.
This is an increasing function of m and we want it to stay within K, so the breakpoint is essentially unique.

Finally, note that the greedy solution of “always increase the smallest element” achieves exactly this structure as well, and hence is optimal.
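For small inputs, the optimality claim can also be checked empirically. The sketch below compares the greedy against an exhaustive search over all ways to distribute at most K operations. It assumes small N and K, since the search is exponential, and the helper names are hypothetical.

```python
import heapq
from itertools import product


def beauty(a):
    """Sum of pairwise products, via (sum^2 - sum of squares) / 2."""
    total = sum(a)
    return (total * total - sum(x * x for x in a)) // 2


def greedy_heap(a, k):
    """Greedy: k times, increment a current minimum."""
    heap = list(a)
    heapq.heapify(heap)
    for _ in range(k):
        heapq.heapreplace(heap, heap[0] + 1)  # pop the minimum, push it back + 1
    return heap


def best_brute(a, k):
    """Best beauty over every way to use at most k operations (exponential)."""
    best = beauty(a)
    for b in product(range(k + 1), repeat=len(a)):
        if sum(b) <= k:
            best = max(best, beauty([x + y for x, y in zip(a, b)]))
    return best


for a, k in [([1, 4, 5], 4), ([2, 2], 3), ([5, 1, 1, 2], 3)]:
    print(a, k, beauty(greedy_heap(a, k)) == best_brute(a, k))
```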


Now, let’s see how we can use this information to solve the problem at hand.


Let’s sort A, so that A_1 \leq A_2 \leq A_3 \leq \ldots \leq A_N.
Now, look at the process of performing operations:

  • First, we increase only A_1 till it reaches A_2.
  • Then, we alternate increasing A_1 and A_2 till they reach A_3.
  • Then, we increase A_1, A_2, A_3 in turn till they all reach A_4.
    \vdots
  • We increase A_1, \ldots, A_{i-1} in turn till they all reach A_i.

This allows us to quickly simulate the process by combining operations.
Let’s iterate from 1 to N.
When at index i:

  • We’ve ensured that so far, A_1 = A_2 = \ldots = A_i.
    Note that they’ll all be equal to (the initial value of) A_i.
  • Next, we must perform operations till all these elements reach A_{i+1}. For i = N, treat A_{N+1} as +\infty; the editorialist’s code uses a large sentinel for this.
    This takes exactly i \cdot (A_{i+1} - A_i) operations.
  • If K is at least this number, all these operations can be performed. Subtract this number from K and move to index i+1.
  • Otherwise, the process stops at this index.
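    Since the operations are performed on the elements in turn, the final values of A_1, \ldots, A_i can be computed directly. With K operations remaining, each of them increases by \lfloor K/i \rfloor, and K \bmod i of them increase by one more. Meanwhile, A_{i+1}, A_{i+2}, \ldots, A_N don’t change.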
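Here is a Python sketch of this level-by-level simulation, assuming K \geq 0. The name final_array is hypothetical. It follows the same idea as the editorialist’s code below, but it handles the last index explicitly instead of using a sentinel. It returns the final values in sorted order.

```python
def final_array(a, k):
    """Final values after the greedy, computed level by level (sorted order)."""
    a = sorted(a)
    n = len(a)
    for i in range(1, n + 1):
        # The first i elements have all been lifted to the common level a[i - 1].
        if i < n:
            req = i * (a[i] - a[i - 1])  # cost to lift all of them to a[i]
            if req <= k:
                k -= req
                continue
        # Not enough operations for the next level (or there is none):
        # spread the remaining k evenly over the first i elements.
        add, rem = divmod(k, i)
        level = a[i - 1] + add
        # rem of them get one extra; list those last to keep the result sorted.
        return [level] * (i - rem) + [level + 1] * rem + a[i:]
    return a  # only reached when a is empty


print(final_array([1, 4, 5], 4))  # [4, 5, 5]
```

The beauty can then be computed from the returned list with the identity from the first part.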

We do \mathcal{O}(1) work at each index, plus \mathcal{O}(N) work at the single index where the process stops, so this part is linear.
Since we required sorting, the complexity is \mathcal{O}(N\log N).

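Alternately, binary search on x, the final value of the minimum of the array. x is the largest value such that \sum_{i=1}^N \max(0, x - A_i) \leq K, and it lies in [\min A, \min A + K]. Each check takes \mathcal{O}(N), giving a solution in \mathcal{O}(N\log K). The leftover operations are then given, one each, to elements whose value is x.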
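Below is a sketch of the binary-search variant, in the spirit of the author’s code below. The name final_array_bs is hypothetical, and the sketch assumes a non-empty array and K \geq 0. Unlike the author’s code it doesn’t sort, so it stays within \mathcal{O}(N\log K).

```python
def final_array_bs(a, k):
    """Greedy result via binary search on the final minimum x (original order)."""
    lo, hi = min(a), min(a) + k  # the final minimum lies in this range
    while lo < hi:
        mid = (lo + hi + 1) // 2
        cost = sum(max(0, mid - v) for v in a)  # ops to make every element >= mid
        if cost <= k:
            lo = mid
        else:
            hi = mid - 1
    x = lo
    k -= sum(max(0, x - v) for v in a)
    # x + 1 was infeasible, so now 0 <= k < (number of elements ending at x).
    result = []
    for v in a:
        if v <= x:
            bump = 1 if k > 0 else 0  # hand out leftover operations one each
            k -= bump
            result.append(x + bump)
        else:
            result.append(v)
    return result
```

Which of the elements at level x receive the leftover operations doesn’t affect the beauty.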

TIME COMPLEXITY

\mathcal{O}(N\log N) per test case.

CODE:

Author's code (C++)
#include<bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>  
#include <ext/pb_ds/tree_policy.hpp>   
using namespace std;
using namespace __gnu_pbds;

#pragma GCC target ("avx2")
#pragma GCC optimize ("O3")
#pragma GCC optimize ("unroll-loops")
#pragma GCC target("popcnt")

template <typename T>
using ordered_set = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;

template <typename T>
using ordered_multiset = tree<T, null_type, less_equal<T>, rb_tree_tag, tree_order_statistics_node_update>;
using namespace std;
 
const double pi = acos(-1);
 
// DEBUG FUNCTIONS 
#ifndef ONLINE_JUDGE
 
template<typename T>
void __p(T a) {
    cout<<a;
}
template<typename T, typename F>
void __p(pair<T, F> a) {
    cout<<"{";
    __p(a.first);
    cout<<",";
    __p(a.second);
    cout<<"}";
}
template<typename T>
void __p(std::vector<T> a) {
    cout<<"{";
    for(auto it=a.begin(); it<a.end(); it++)
        __p(*it),cout<<",}"[it+1==a.end()];
}
template<typename T>
void __p(std::set<T> a) {
    cout<<"{";
    for(auto it=a.begin(); it!=a.end();){
        __p(*it); 
        cout<<",}"[++it==a.end()];
    }
 
}
template<typename T>
void __p(std::multiset<T> a) {
    cout<<"{";
    for(auto it=a.begin(); it!=a.end();){
        __p(*it); 
        cout<<",}"[++it==a.end()];
    }
}
template<typename T>
void __p(ordered_set<T> a) {
    cout<<"{";
    for(auto it=a.begin(); it!=a.end();){
        __p(*it); 
        cout<<",}"[++it==a.end()];
    }
 
}
template<typename T>
void __p(ordered_multiset<T> a) {
    cout<<"{";
    for(auto it=a.begin(); it!=a.end();){
        __p(*it); 
        cout<<",}"[++it==a.end()];
    }
}
template<typename T, typename F>
void __p(std::map<T,F> a) {
    cout<<"{\n";
    for(auto it=a.begin(); it!=a.end();++it)
    {
        __p(it->first);
        cout << ": ";
        __p(it->second);
        cout<<"\n";
    }
    cout << "}\n";
}
 
template<typename T, typename ...Arg>
void __p(T a1, Arg ...a) {
    __p(a1);
    __p(a...);
}
template<typename Arg1>
void __f(const char *name, Arg1 &&arg1) {
    cout<<name<<" : ";
    __p(arg1);
    cout<<endl;
}
template<typename Arg1, typename ... Args>
void __f(const char *names, Arg1 &&arg1, Args &&... args) {
    int bracket=0,i=0;
    for(;; i++)
        if(names[i]==','&&bracket==0)
            break;
        else if(names[i]=='(')
            bracket++;
        else if(names[i]==')')
            bracket--;
    const char *comma=names+i;
    cout.write(names,comma-names)<<" : ";
    __p(arg1);
    cout<<" | ";
    __f(comma+1,args...);
}
#define trace(...) cout<<"Line:"<<__LINE__<<" ", __f(#__VA_ARGS__, __VA_ARGS__)
#else
#define trace(...)
#define error(...)
#endif
 
// DEBUG FUNCTIONS END 

# define FOR(i, a, n) for(int i = a; i<n;i++)
# define FORd(i, a, n) for(int i = a; i >= n; i--)
#define ll long long
ll mod = 1000000007;
# define endl "\n"
# define int ll
# define printArr(arr, n) FOR(abcd,0,  n){cout<<arr[abcd]<<" ";}cout<<endl;
#define f first 
#define se second 
#define pb push_back
#define pob pop_back
#define sz(x) (int)x.size()
#define all(x) x.begin(), x.end()
typedef vector<long long> vi;
typedef pair<long long, long long> pii;
typedef vector<pair<long long, long long>> vpi;
typedef vector<vector<int>> vvi;
int gcdExtended(int a, int b, int* x, int* y)
{
    // Base Case
    if (a == 0)
    {
        *x = 0, *y = 1;
        return b;
    }
 
    int x1, y1; // To store results of recursive call
    int gcd = gcdExtended(b % a, a, &x1, &y1);
 
    // Update x and y using results of recursive
    // call
    *x = y1 - (b / a) * x1;
    *y = x1;
 
    return gcd;
}

 
// Function to find modulo inverse of a
ll modInverse(ll a, ll m)
{
    int x, y;
    int g = gcdExtended(a, m, &x, &y);
    if (g != 1)
        return 0;
    else
    {
        // m is added to handle negative x
        ll res = (x % m + m) % m;
        return res;
    }
}

ll nCr(int n, int r){
    if(r>n){
        return 0;
    }
    if(r>n-r){
        r = n-r;
    }
    ll ans = 1;
    for(int i = 1; i<=r ; i++){
        ans *= (n-i+1);
        ans%= mod;
        ans *= modInverse(i, mod);
        ans %= mod;
        
    }

    return ans;
}

ll binpow(ll a, ll b) {
    if (b == 0)
        return 1;
    long long res = binpow(a, b / 2);
    if (b % 2)
        return (res * res)%mod * a % mod;
    else
        return (res * res) %mod;
}

// const int Max = 2e5 +1;
// ll fact[Max];
// ll inv_fact[Max];

// void preSolveFact(ll n){
//     ll ans = 1;
//     fact[0] = 1;
//     for(int i = 1; i<=n; i++){
//         ans *=i;
//         ans %= mod;
//         fact[i] = ans;
//     }
//     inv_fact[n] = binpow(fact[n], mod-2);

//     for(int i = n-1; i>=0; i--){
//         inv_fact[i] = inv_fact[i+1] * (i+1) %mod;
//     }
// }
// ll nCr_pre(ll n, ll r){
//     if(n>=r && n>=0 && r>=0)
//     return fact[n] * inv_fact[r] %mod * inv_fact[n-r]%mod;
//     else return 0;
// }

signed main(){
    #ifdef LOCALFLAG 
        freopen("input.in", "r", stdin);
        freopen("output.in", "w", stdout);
    #endif
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    int t = 1;
    cin>>t;
    while(t--){
        int n, k;
        cin>>n>>k;
        int sum = 0;
        vi arr(n);
        FOR(i, 0, n){
        	cin>>arr[i];
        }
        sort(all(arr));
        int l = arr[0], r = arr[0] + k;
        int ans = l;
        while(l <= r){
        	int mid = (l + r)/2;
        	int ops = 0;
        	for(int i = 0; i < n; i++){
        		ops += max((ll)0, mid - arr[i]);
        	}
        	if(ops <= k) {
        		l = mid + 1;
        		ans = mid;
        	}
        	else r = mid - 1;
        }
        FOR(i, 0, n){
        	if(arr[i] < ans) {
        		k -= ans - arr[i];
        		arr[i] = ans;
        	}
        }
        FOR(i, 0, k){
        	arr[i]++;
        }
        FOR(i, 0, n){
        	arr[i] %= mod;
        	sum += arr[i];
        }
        sum %= mod;
        ans = 0;
        FOR(i, 0, n){
        	ans += (sum - arr[i])*arr[i] % mod;
        	ans = (ans + mod) % mod;
        }
        ans %= mod;
        cout<<ans * modInverse(2, mod) % mod<<endl;
    }

}
Editorialist's code (Python)
mod = 10**9 + 7
for _ in range(int(input())):
    n, k = map(int, input().split())
    a = sorted(list(map(int, input().split())))
    a.append(10 ** 10)
    for i in range(1, n+1):
        req = i*(a[i] - a[i-1])
        if req <= k: k -= req
        else:
            add, rem = k//i, k%i
            for j in range(i):
                a[j] = a[i-1] + add
                if rem:
                    a[j] += 1
                    rem -= 1
            break
    a.pop()
    ans = sum(a)**2 - sum(x**2 for x in a)
    print((ans // 2) % mod)

Please don’t use macros in editorials.


I’ve addressed this before, here.

Why do I need to take the modulo of sum when the biggest sum will be within the range of long long int?
Why is this wrong?
https://www.codechef.com/viewsolution/97451882

((sum-v[i])*v[i]%mod)
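Think about what happens here if, say, sum = 10^{14} and v[i] = 10^9.
You compute something of the order of 10^{23} before taking the mod, which overflows even long long (whose maximum is about 9.2 \cdot 10^{18}). The sum itself fits, but it should be reduced modulo 10^9+7 before the multiplication.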
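To illustrate the reply above: the sum fits in a long long, but a product of two unreduced values may not. This Python sketch mirrors the safe C++ order of operations by reducing the sum modulo 10^9+7 first, so every product stays below (10^9+7)^2 \lt 2^{63}. The names are hypothetical. Python itself never overflows, so the 64-bit bound is only checked explicitly.

```python
MOD = 10**9 + 7
INT64_MAX = 2**63 - 1  # largest value of a signed 64-bit long long


def beauty_mod(a, mod=MOD):
    """Beauty modulo mod, keeping every intermediate product below mod^2."""
    s = 0
    for v in a:
        s = (s + v) % mod  # reduce the sum BEFORE it is used in any product
    twice = 0
    for v in a:
        v %= mod
        # (s - v) % mod lies in [0, mod), so this product is below mod^2 < 2^63.
        twice = (twice + (s - v) % mod * v) % mod
    # The sum over i of (s - a_i) * a_i equals s^2 - (sum of squares) = 2 * beauty,
    # so divide by 2 using the modular inverse of 2 (Fermat, since mod is prime).
    return twice * pow(2, mod - 2, mod) % mod


# The product from the reply above really does exceed the 64-bit range:
print(10**14 * 10**9 > INT64_MAX)  # True
```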


Alternately, binary search on the value of the minimum of the array for a solution in O(N log K).
I do not understand this line. Can you please clarify how binary search can be used?

1) Ideally you would want to make all the values as equal as possible to get the maximum beauty.
2) So we check whether we can make every element of the array at least some value, let’s say x. Now think about the range of this value.
3) Smallest element of array <= x <= smallest element of array + k

Why is this code failing one test case?
https://www.codechef.com/viewsolution/97888815
Please help.