MEDMAX - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3

Author: Saarang Srinivasan
Tester: Danny Mittal
Editorialist: Nishank Suresh

DIFFICULTY:

Easy

PREREQUISITES:

Binary search

PROBLEM:

You are given a grid of N\times N integers, which you are free to rearrange however you wish.
After a rearrangement, let m_i be the median of the i-th row.
The cost of a rearrangement is defined to be m_1 + m_2 + \dots + m_N.
Maximize \min(m_1, m_2, \dots, m_N) subject to the cost of the arrangement being at most k, or say that this is impossible.

QUICK EXPLANATION:

Sort the given N^2 integers, let them be v_1\leq v_2\leq \dots \leq v_{N^2}.
The minimum possible cost is v_{N/2 + 1} + v_{N/2 + 1 + (N/2 + 1)} + v_{N/2 + 1 + 2(N/2 + 1)} + \dots + v_{N(N/2 + 1)}.
If this is larger than k, no arrangement is possible so output -1.
Otherwise, note that the minimum cost of making every median at least x is a non-decreasing function of x, so binary search on the value of \min(m_1, m_2, \dots, m_N).

EXPLANATION:

(All division mentioned in the editorial is integer division, i.e., a/b means \lfloor\frac{a}{b}\rfloor.)

Note that the initial arrangement of integers in the matrix doesn't matter whatsoever, so let's just assume we have N^2 integers in sorted order v_1\leq v_2\leq \dots\leq v_{N^2}.
Let the medians in a given arrangement be m_1, m_2, \dots, m_N. Without loss of generality, let m_1\leq m_2\leq \dots\leq m_N.
Note that in the i-th row, there are N/2 elements \leq m_i (i.e., to its left) and N - N/2 - 1 elements \geq m_i (i.e., to its right).

So, the absolute smallest possible value of m_1 is v_{N/2 + 1}, since we need N/2 elements to its left.
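We need another N/2 elements to the left of m_2, distinct from m_1 and the N/2 elements to the left of m_1 (all of which are also \leq m_2), so m_2 must be at least v_{N/2 + 1 + N/2 + 1} = v_{2(N/2 + 1)}.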
Continuing this process, it’s easily seen that m_i \geq v_{i(N/2 + 1)}.

Further, choosing these as our m_i is also possible.

How?

Arrange the matrix as follows:
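The i-th row contains v_{(i-1)(N/2 + 1) + 1}, v_{(i-1)(N/2 + 1) + 2}, \dots, v_{i(N/2 + 1)}, along with N - (N/2 + 1) elements taken from the largest N(N - N/2 - 1) values v_{N(N/2 + 1) + 1}, \dots, v_{N^2}, with each row getting a different group. Those values are all \geq v_{i(N/2 + 1)}, so the median of the i-th row is exactly v_{i(N/2 + 1)}.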
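As a sanity check of this construction, here is a minimal C++ sketch. It assumes the values are already sorted and stored 0-indexed in a `std::vector<long long>`, and the function name `buildAndMedians` is hypothetical. The sketch builds each row as described and returns the row medians, taking the median to be the (N/2 + 1)-th smallest element of a row, as the editorial does.

```cpp
#include <algorithm>
#include <vector>

// Builds the arrangement described above from sorted values v (size n*n)
// and returns the median of each row. Row i (0-indexed) gets the block
// v[i*(h+1)] .. v[i*(h+1)+h] plus r values from the top n*r values.
std::vector<long long> buildAndMedians(const std::vector<long long>& v, int n) {
    const int h = n / 2;          // elements left of the median in a row
    const int r = n - n / 2 - 1;  // elements right of the median in a row
    const int big = n * (h + 1);  // first index of the pool of largest values
    std::vector<long long> medians;
    for (int i = 0; i < n; ++i) {
        std::vector<long long> row;
        for (int j = 0; j <= h; ++j) row.push_back(v[i * (h + 1) + j]);
        for (int j = 0; j < r; ++j) row.push_back(v[big + i * r + j]);
        std::sort(row.begin(), row.end());
        medians.push_back(row[h]);  // the (n/2 + 1)-th smallest element
    }
    return medians;
}
```

For the nine values 1..9 with N = 3, the rows are (1, 2, 7), (3, 4, 8), (5, 6, 9) and the medians are 2, 4, 6, i.e. v_2, v_4, v_6.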

So the minimum cost of any possible arrangement is

v_{N/2 + 1} + v_{N/2 + 1 + (N/2 + 1)} + v_{N/2 + 1 + 2(N/2 + 1)} + \dots + v_{N(N/2 + 1)}

If this is larger than k, we know that no arrangement is possible, so output -1.
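The lower bound and the -1 check can be sketched as follows. This assumes 0-indexed sorted storage, so v_{i(N/2 + 1)} becomes `v[i*(n/2+1) - 1]`. The names `minimumCost` and `feasible` are hypothetical.

```cpp
#include <vector>

// Minimum possible cost: sum of v_{i(N/2+1)} (1-indexed) for i = 1..N.
// v must be sorted ascending and hold n*n values.
long long minimumCost(const std::vector<long long>& v, int n) {
    long long total = 0;
    for (int i = 1; i <= n; ++i) total += v[i * (n / 2 + 1) - 1];
    return total;
}

// True when some arrangement has cost at most k; otherwise the answer is -1.
bool feasible(const std::vector<long long>& v, int n, long long k) {
    return minimumCost(v, n) <= k;
}
```

For 1..9 with N = 3 the minimum cost is 2 + 4 + 6 = 12, so any k < 12 yields -1.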

Now, suppose an arrangement is possible. We would like to maximize \min(m_1, m_2, \dots, m_N).
For a given integer x, define cost(x) to be the minimum cost of having m_i \geq x for each 1\leq i\leq N, and cost(x) = \infty if it is impossible to have every median \geq x. (In practice, k+1 can be treated as infinity.)

cost is a non-decreasing function: any arrangement in which every median is \geq x + 1 also has every median \geq x, so cost(x) \leq cost(x + 1). We are looking for the largest value of x such that cost(x) \leq k, which is exactly the kind of problem binary search solves, and so binary search is what we will use.
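Below is a generic sketch of this binary search. It assumes the candidates are indices into the sorted array and that the predicate is true at `lo` and monotone: true for a prefix of indices and false afterwards. In this problem the predicate would be "cost(v[idx]) <= k". `lastTrue` is a hypothetical helper name.

```cpp
#include <functional>

// Returns the largest idx in [lo, hi] with ok(idx) true, assuming ok is
// monotone (true ... true false ... false) and ok(lo) is true.
int lastTrue(int lo, int hi, const std::function<bool(int)>& ok) {
    while (lo < hi) {
        int mid = lo + (hi - lo + 1) / 2;  // round up so "lo = mid" always makes progress
        if (ok(mid)) lo = mid;
        else hi = mid - 1;
    }
    return lo;
}
```

Rounding `mid` up is what lets the loop keep `lo` on a known-true index without getting stuck. The editorialist's code uses the same `(lo + hi + 1)/2` pattern.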

All we need to know now is how to calculate cost(x). This can be done by the following greedy algorithm:
Suppose we have v_1\leq v_2\leq \dots\leq v_{N^2}. Maintain a variable L, initialized to 0, which denotes the number of unused elements so far.
Initially, set cost(x) = 0.
Iterate i from 1 to N^2.

  • If v_i \geq x and L \geq N/2, we can choose v_i to be a median, so we do cost(x) += v_i and L -= N/2
  • Otherwise, v_i isn't chosen, so it is free to be used as a left element later; increment L by 1.
  • Once the N-th element has been chosen, stop the loop.

This takes care of elements to the left; we also need to check that there are enough elements on the right.
Once again, this part can be checked greedily: when choosing an index i, make sure that there are still N - N/2 - 1 free indices to its right for it, in addition to those already reserved for the medians chosen before it.

If we are unable to choose N elements via this process, it's impossible to make every median \geq x, so we set cost(x) = \infty.
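The greedy above, including the right-hand check, might be written as follows. This is a sketch in the spirit of the editorialist's code, but it is parameterized by the value x instead of an index. `costAtLeast` is a hypothetical name, `inf` plays the role of k + 1, and v is assumed sorted and 0-indexed. As in the editorialist's code, the right-hand elements are reserved from the top of the array.

```cpp
#include <algorithm>
#include <vector>

// Greedy cost(x): minimum sum of medians over arrangements in which every
// median is >= x. v must be sorted ascending (0-indexed). Returns `inf`
// (the editorial uses k + 1) when no such arrangement exists.
long long costAtLeast(const std::vector<long long>& v, int n, long long x, long long inf) {
    const int total = n * n;
    const int h = n / 2;          // elements that must lie left of each median
    const int r = n - n / 2 - 1;  // elements that must lie right of each median
    // Elements smaller than x can never be medians, only left elements.
    int i = static_cast<int>(std::lower_bound(v.begin(), v.end(), x) - v.begin());
    int freeLeft = i;             // L in the editorial: unused elements so far
    int rightLimit = total;       // indices >= rightLimit are reserved as right elements
    int taken = 0;
    long long cost = 0;
    for (; taken < n && i < rightLimit; ++i) {
        if (freeLeft >= h) {      // v[i] >= x and enough smaller elements: make it a median
            cost += v[i];
            freeLeft -= h;
            ++taken;
            rightLimit -= r;      // reserve the r largest unreserved values for this row
            if (rightLimit <= i) return inf;  // a reserved value would not lie right of v[i]
        } else {
            ++freeLeft;           // keep v[i] as a left element for a later median
        }
    }
    return taken == n ? cost : inf;
}
```

For the nine values 1..9 with N = 3, this gives cost(1) = 12, cost(3) = 13, cost(4) = 15 (e.g. rows (1, 4, 9), (2, 5, 8), (3, 6, 7)) and cost(5) = infinity, since three medians \geq 5 would need six values \geq 5.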

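Computing cost(x) takes \mathcal{O}(N^2) time, and the binary search over the N^2 sorted values needs \mathcal{O}(\log N^2) = \mathcal{O}(\log N) steps. Together with the initial sort, our problem is solved in \mathcal{O}(N^2\log{N}).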

SUBTASKS:

SUBTASK 1:

The first subtask has k = 10^{14}. Since each m_i \leq 10^9 and there are N \leq 1000 rows, the total cost never exceeds 10^{12} < k. In other words, we can simply ignore k and make each median we choose as large as possible.

To do this, the ideas from above tell us that choosing greedily from the back works: skip the last N - N/2 - 1 elements (the ones to the right of a median) and choose the next one; again skip N - N/2 - 1 elements and choose the next one, and so on until we have chosen N elements. The last element chosen is the smallest median, so print it.
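A minimal sketch of this back-to-front greedy, assuming sorted 0-indexed values. It skips n - n/2 - 1 values before each pick, which is the number of elements that must lie to the right of a median. That count equals N/2 - 1 only when N is even. Because the quantity being maximized is \min(m_1, \dots, m_N), the sketch returns the last (smallest) value chosen. `bestMinMedianIgnoringK` is a hypothetical name.

```cpp
#include <vector>

// Subtask 1 (k never binding): from the back of sorted v, repeatedly skip
// r = n - n/2 - 1 values and take the next one as a median, n times.
// The last value taken is the smallest median, i.e. the answer.
long long bestMinMedianIgnoringK(const std::vector<long long>& v, int n) {
    const int r = n - n / 2 - 1;
    int idx = n * n;            // one past the current position, moving left
    long long smallest = 0;
    for (int row = 0; row < n; ++row) {
        idx -= r + 1;           // skip r values, then land on the median
        smallest = v[idx];
    }
    return smallest;
}
```

For 1..9 with N = 3, the chosen medians are 8, 6, 4, so the answer is 4. The remaining values 1, 2, 3 serve as the left elements.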

SUBTASK 2:

This subtask permits a solution in \mathcal{O}(N^4), which is almost the same as the solution for the final subtask except the binary search is replaced by a brute-force iteration checking every value of cost(x) for x among our N^2 values.
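The brute-force replacement for the binary search can be sketched as a linear scan from the largest candidate downwards. The parameter `cost` is a hypothetical stand-in for a cost(x) routine such as the greedy described in the editorial, and `bruteForceAnswer` is also a hypothetical name. With an \mathcal{O}(N^2) cost(x) and N^2 candidates, this is the \mathcal{O}(N^4) bound.

```cpp
#include <functional>
#include <vector>

// Tries every value of sorted v as the candidate minimum median, largest
// first, and returns the first one whose cost fits in k (-1 if none does).
long long bruteForceAnswer(const std::vector<long long>& v, long long k,
                           const std::function<long long(long long)>& cost) {
    for (int i = static_cast<int>(v.size()) - 1; i >= 0; --i) {
        if (cost(v[i]) <= k) return v[i];
    }
    return -1;
}
```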

TIME COMPLEXITY:

\mathcal{O}(N^2\log{N}).

CODE:

Setter (C++)
#include <bits/stdc++.h>
using namespace std;
 
void solve() {
    int n;
    long long k;
    cin >> n >> k ;
    vector<int> v;
    for(int _ = 0; _ < n * n; _++) {
        int x;
        cin >> x;
        v.push_back(x);
    }
    sort(v.begin(), v.end());
 
    int left = n >> 1;
    int right = n * n - n - n * ((n >> 1) - (n + 1) % 2);
    int small = n >> 1;
    int l = left, r = right, ans = -1;
    while(l <= r) {
        int m = (l + r) >> 1;
        int x = 0, y = m;
        vector<bool> vis(n * n);
        long long sum = 0;
        for(int i = 0; i < n; i++) {
            for(int _ = 0; _ < small; _++) {
                if(vis[x])
                    x = y;
                if(x == y) 
                    x++, y++;
                else
                    x++;
            }
            if(x == y) 
                sum += v[y], x++, y++;
            else 
                sum += v[y], vis[y] = 1, y++;
        }
 
        if(sum <= k) 
            l = m + 1, ans = v[m];
        else
            r = m - 1;
    }
    cout << ans << endl;
}
 
signed main() {
    std::ios::sync_with_stdio(0);
    std::cin.tie(0);
    int t;
    cin >> t;
    while(t--) {
        solve();
    }
    return 0;
}
Tester (Kotlin)
import java.io.BufferedInputStream

const val BILLION = 1000000000L
const val K_LIMIT = 100000000000000L

fun main(omkar: Array<String>) {
    val jin = FastScanner()
    var nSum = 0
    repeat(jin.nextInt(100)) {
        val n = jin.nextInt(1000, false)
        nSum += n
        if (nSum > 1000) {
            throw InvalidInputException("constraint on sum n violated")
        }
        val k = jin.nextLong(K_LIMIT)
        val elements = Array(n * n) { jin.nextLong(BILLION, (it + 1) % n == 0) }
        elements.sort()
        val marked = BooleanArray(n * n)
        var sum = 0L
        var j = n / 2
        repeat(n) {
            marked[j] = true
            sum += elements[j]
            j += (n / 2) + 1
        }
        if (sum > k) {
            println(-1)
        } else {
            var j1 = n / 2
            var j2 = j1
            while (j1 < n * (n / 2)) {
                while (j2 < n * n && marked[j2]) {
                    j2++
                }
                if (j2 == n * n || sum - elements[j1] + elements[j2] > k) {
                    break
                }
                sum -= elements[j1]
                sum += elements[j2]
                marked[j2] = true
                j1++
            }
            println(elements[j1])
        }
    }
    jin.endOfInput()
}

class InvalidInputException(message: String): Exception(message)

class FastScanner {
    private val BS = 1 shl 16
    private val NC = 0.toChar()
    private val buf = ByteArray(BS)
    private var bId = 0
    private var size = 0
    private var c = NC
    private var `in`: BufferedInputStream? = null
    private val validation: Boolean

    constructor(validation: Boolean) {
        this.validation = validation
        `in` = BufferedInputStream(System.`in`, BS)
    }

    constructor() : this(true)

    private val char: Char
        private get() {
            while (bId == size) {
                size = try {
                    `in`!!.read(buf)
                } catch (e: Exception) {
                    return NC
                }
                if (size == -1) return NC
                bId = 0
            }
            return buf[bId++].toChar()
        }

    fun validationFail(message: String) {
        if (validation) {
            throw InvalidInputException(message)
        }
    }

    fun endOfInput() {
        if (char != NC) {
            validationFail("excessive input")
        }
        if (validation) {
            System.err.println("input validated")
        }
    }

    fun nextInt(from: Int, to: Int, endsLine: Boolean = true) = nextLong(from.toLong(), to.toLong(), endsLine).toInt()

    fun nextInt(to: Int, endsLine: Boolean = true) = nextInt(1, to, endsLine)

    fun nextLong(endsLine: Boolean): Long {
        var neg = false
        c = char
        if (c !in '0'..'9' && c != '-' && c != ' ' && c != '\n') {
            validationFail("found character other than digit, negative sign, space, and newline, character code = ${c.toInt()}")
        }
        if (c == '-') {
            neg = true
            c = char
        }
        var res = 0L
        while (c in '0'..'9') {
            res = (res shl 3) + (res shl 1) + (c - '0').toLong()
            c = char
        }
        if (endsLine) {
            if (c != '\n') {
                validationFail("found character other than newline, character code = ${c.toInt()}")
            }
        } else {
            if (c != ' ') {
                validationFail("found character other than space, character code = ${c.toInt()}")
            }
        }
        return if (neg) -res else res
    }

    fun nextLong(from: Long, to: Long, endsLine: Boolean = true): Long {
        val res = nextLong(endsLine)
        if (res !in from..to) {
            validationFail("$res not in range $from..$to")
        }
        return res
    }

    fun nextLong(to: Long, endsLine: Boolean = true) = nextLong(1L, to, endsLine)
}
Editorialist (C++)
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,mmx,avx,avx2")
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(0); cin.tie(0);

    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        ll k; cin >> k;
        vector<int> v(n*n);
        for (int &x : v)
            cin >> x;
        sort(begin(v), end(v));

        ll mincost = 0;
        for (int i = n/2, j = 0; j < n; i += n/2 + 1, ++j) {
            mincost += v[i];
        }
        if (mincost > k) {
            cout << -1 << '\n';
            continue;
        }

        int lo = n/2, hi = n*n - 1;
        int N = n*n;
        while (lo < hi) {
            int mid = (lo + hi + 1)/2;
            ll cost = 0;
            int less = mid, rlim = N, taken = 0;
            for (int i = mid; taken < n and i < N; ++i) {
                if (i >= rlim) break;
                if (less >= n/2) {
                    taken += 1;
                    cost += v[i];
                    less -= n/2;
                    rlim -= n - n/2 - 1;
                    if (rlim <= i) {
                        cost = k + 1;
                        break;
                    }
                }
                else ++less;
            }
            if (taken < n) cost = k+1;
            if (cost <= k) lo = mid;
            else hi = mid-1;
        }
        cout << v[lo] << '\n';
    }
}

This can also be solved greedily.
Here is my approach: Solution: 53242718 | CodeChef

My approach:

Suggestions are welcome (regarding code updates, better approaches, or cleaner code).

My submission might have used a similar approach with prefix sums, and I explained it in this comment on Codeforces.

I am getting WA on test case 6. Can anyone tell me where I am going wrong?
My solution

How do we pick x in cost(x) for this solution?

Best solution: here, by @nityam_prabhat

The choice of x is given by binary search.
At each step of the binary search, we calculate cost(mid) and compare this with k. Depending on whether it is greater than k or not, update the limits of the binary search.


This code is easier to read and implement, but I wanted to know how you came up with this way of finding the cost of the rearrangement in the function fun.

This is not my code, but I can explain:

I think you know the purpose of binary search here (if not, watch the official video editorial by CodeChef). Also notice the initial values of hi and lo there.

The key part is finding cost(mid).

Let's look at how the matrix is filled (1-based indexing).
Assume arr[mid] is the desired answer.
If all rows are sorted, then the median occurs at the (n/2 + 1)-th column of each row.
We first need to fill the first row up to the (n/2)-th column using the (mid-1) values before arr[mid] in the sorted linear array. If values are left over from those (mid-1) integers, they can be used to fill the second row up to the (n/2)-th column, if there are enough of them, and so on.
In this way we can fill (mid-1)/(n/2) pre-median row segments.

Let mx = (mid-1)/(n/2).
If mid % (n/2) == 0, then reduce mx by 1 (because in this case the number of filled rows does not increase by 1, but the value of mx does).

Meanwhile, we can fill the median column, i.e., the (n/2 + 1)-th, with arr[mid], arr[mid+1], arr[mid+2], and so on.

Once the (mid-1) integers are consumed, we have to use the untaken values after mid to fill the pre-median segments.

Meanwhile, the remaining median cells will be filled with the ((mx+1)(n/2+1))-th, ((mx+2)(n/2+1))-th, …, (n(n/2+1))-th values.

Take the sum of the median column as cost(mid).

You can work out the rest yourself.


I have solved this problem using a vector ans for storing the answer and a queue q. ans contains the elements that are sure to contribute to the cost of the rearrangement, and q contains the elements that will be selected for putting into ans. I am getting a partially correct answer. Can anyone please help me find the flaws in my code?
https://www.codechef.com/viewsolution/53290644

Finally understood it and got AC.
Thanks for making it crystal clear!

We can do it in O(N \log N^2), not counting the sort.

Fix the minimum median and find the minimum cost that can be achieved with it. After fixing the minimum median, we can find the minimum cost in O(N). There are N^2 possible minimum medians, and we can binary search on them.
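Below is one possible shape for the O(N) check mentioned above. It is not the commenter's code, which is not shown. It assumes the smallest median is fixed to v[s] (0-indexed, v sorted) and replays the editorial's greedy. Instead of stepping over every index, it jumps directly over the elements that must be skipped before the next median. `costFromIndex` is a hypothetical name, and `inf` plays the role of k + 1.

```cpp
#include <vector>

// O(N) minimum cost when the smallest median is v[s]; returns inf if impossible.
// Each iteration places one median, jumping over the elements it needs on its left.
long long costFromIndex(const std::vector<long long>& v, int n, int s, long long inf) {
    const int h = n / 2, r = n - n / 2 - 1;
    if (s < h) return inf;     // v[s] cannot have n/2 smaller elements on its left
    int freeLeft = s;          // everything before v[s] is available as a left element
    int cur = s;               // index of the next candidate median
    int rightLimit = n * n;    // [rightLimit, n*n) is reserved for right elements
    long long cost = 0;
    for (int j = 0; j < n; ++j) {
        if (freeLeft < h) {    // skip just enough elements to become left elements
            cur += h - freeLeft;
            freeLeft = h;
        }
        rightLimit -= r;       // reserve r of the largest values for this row
        if (cur >= rightLimit) return inf;
        cost += v[cur];
        freeLeft -= h;
        ++cur;
    }
    return cost;
}
```

Sorting the N^2 input values still costs O(N^2 log N), so this only speeds up the part after the sort.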

The test cases seem to be weak. Instead of binary search, I just looped through all N^2 medians, but still got AC.
Solution: 54279655 | CodeChef

The performance here is primarily limited by:

  • reading the input: O(N^2) with a pretty high constant factor
  • sorting the array: O(N^2 log N^2), which is the same as O(N^2 log N) from the editorial if we remove constant factor

The contribution of the actual solution is negligible in comparison. I have two high-performance submissions with O(N^2) time complexity for my own code (the parts other than input/sort):

And yes, the CodeChef test cases are definitely not stressing the worst possible input.