PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: hellolad
Tester: kingmessi
Editorialist: iceknight1093
DIFFICULTY:
Easy
PREREQUISITES:
Sorting, prefix sums, binary search
PROBLEM:
You’re given an N\times N grid A.
For a fixed integer K, you will set A_{i, j} = \min(A_{i, j}, K) for every element of the grid.
The score of a row, with respect to K, is the number of other rows that have a strictly larger sum than it after the operation.
Find the minimum value of K that maximizes the score of the first row.
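Before looking at the fast solution, it can help to have a direct transcription of the statement to test against. The sketch below assumes K ranges over positive integers and that the values are small enough to try every K up to the largest element. The function names `scoreForK` and `bruteForceBestK` are hypothetical and are not part of the reference solutions.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Score of the first row for a fixed K: the number of other rows whose sum,
// after replacing every element v by min(v, K), is strictly larger than row 0's.
long long scoreForK(const std::vector<std::vector<long long>>& A, long long K) {
    auto cappedSum = [K](const std::vector<long long>& row) {
        long long s = 0;
        for (long long v : row) s += std::min(v, K);
        return s;
    };
    const long long first = cappedSum(A[0]);
    long long score = 0;
    for (std::size_t i = 1; i < A.size(); ++i)
        if (cappedSum(A[i]) > first) ++score;
    return score;
}

// Reference brute force, only viable for small values: try every K from 1 up to
// the largest element. Any larger K caps nothing, so it scores the same as
// K = largest element and cannot be the smallest maximizer.
long long bruteForceBestK(const std::vector<std::vector<long long>>& A) {
    assert(!A.empty());
    long long maxValue = 1;
    for (const auto& row : A)
        for (long long v : row) maxValue = std::max(maxValue, v);
    long long bestK = 1, bestScore = -1;
    for (long long K = 1; K <= maxValue; ++K) {
        const long long s = scoreForK(A, K);
        if (s > bestScore) { bestScore = s; bestK = K; }  // strict: keep the smallest K
    }
    return bestK;
}
```

For example, with rows {1, 5} and {3, 3}, K = 2 gives sums 3 and 4, so the second row wins. Any K of 5 or more gives equal sums 6 and 6, so it does not.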
EXPLANATION:
Note that the order of elements within each row doesn’t matter at all, so let’s sort every row in ascending order for simplicity.
Now, to start with, we look at just the first two rows (recall that they’re sorted now).
Suppose we fix a value of K.
Then, to quickly decide whether the second row has a larger sum, we can do the following:
- Compute the largest index i_1 such that A_{1, i_1} \lt K (with i_1 = 0 if there is no such index). Similarly, compute the largest i_2 such that A_{2, i_2} \lt K.
- The sum of the first row is now A_{1, 1} + A_{1, 2} + \ldots + A_{1, i_1} + K\cdot (N - i_1), because every element after index i_1 is at least K and so becomes K, and there are N - i_1 such elements.
- Similarly, the sum of the second row is A_{2, 1} + A_{2, 2} + \ldots + A_{2, i_2} + K\cdot (N - i_2), which is to be compared with the above.
Observe that in both cases, we have some prefix sum of the row and then a multiple of K.
That is, if we denote P_{i, j} = A_{i, 1} + A_{i, 2} + \ldots + A_{i, j} to be the j-th prefix sum of the i-th row (with P_{i, 0} = 0), then the quantities we’re comparing are
P_{1, i_1} + K\cdot (N - i_1) and P_{2, i_2} + K\cdot (N - i_2)
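For a fixed K, each of these quantities can be evaluated in \mathcal{O}(\log N) with one binary search per row. The following is a minimal sketch. It assumes each row is already sorted in ascending order and that sums fit in a 64-bit `long long`. The names `prefixSums` and `cappedSum` are hypothetical helpers.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// P[j] = row[0] + ... + row[j-1], so P[0] = 0. The row must be sorted ascending.
std::vector<long long> prefixSums(const std::vector<long long>& sortedRow) {
    std::vector<long long> P(sortedRow.size() + 1, 0);
    for (std::size_t j = 0; j < sortedRow.size(); ++j) P[j + 1] = P[j] + sortedRow[j];
    return P;
}

// Capped sum of a sorted row in O(log N). i counts the elements strictly below K,
// which equals the editorial's 1-indexed i_1. The other N - i elements all become K.
long long cappedSum(const std::vector<long long>& sortedRow,
                    const std::vector<long long>& P, long long K) {
    assert(P.size() == sortedRow.size() + 1);
    const long long n = static_cast<long long>(sortedRow.size());
    const long long i = std::lower_bound(sortedRow.begin(), sortedRow.end(), K) - sortedRow.begin();
    return P[i] + K * (n - i);
}
```

For the row {1, 3, 5} and K = 4, one element stays below K, so the result is 1 + 3 + 4 = P_{1} + ... only in the sense that P[2] = 4 covers {1, 3} and the single element 5 contributes 4, giving 8.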
Now, while this is easy to compute and compare for a fixed K, it doesn’t directly solve the actual problem: the relevant values of K go all the way up to the largest element of the grid, so we definitely can’t try every K.
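Let’s do a bit of rearrangement.
The second row has a strictly larger sum exactly when
P_{1, i_1} + K\cdot (N - i_1) \lt P_{2, i_2} + K\cdot (N - i_2)
and moving the terms involving K to one side, simple algebra gives us
P_{1, i_1} - P_{2, i_2} \lt K\cdot (i_1 - i_2)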
This then leads to three cases depending on whether i_1 - i_2 is positive, negative, or zero.
In particular,
- i_1 - i_2 \gt 0 gives a lower bound on K.
- i_1 - i_2 \lt 0 gives an upper bound on K.
- i_1 = i_2 gives an expression that’s either true for every K, or false for every K.
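Since K is an integer, the three cases can be written as explicit integer bounds. We use the shorthand D = P_{1, i_1} - P_{2, i_2} and d = i_1 - i_2, introduced here only for readability. The condition D \lt K\cdot d then becomes:

```latex
\begin{aligned}
d > 0 &:\quad K > \frac{D}{d} \iff K \ge \left\lfloor \frac{D}{d} \right\rfloor + 1 \\
d < 0 &:\quad K < \frac{-D}{-d} \iff K \le \left\lceil \frac{-D}{-d} \right\rceil - 1 \\
d = 0 &:\quad \text{true for every } K \text{ if } D < 0, \text{ and for no } K \text{ otherwise}
\end{aligned}
```

In C++, integer division truncates toward zero. The floor and ceiling above therefore need care whenever the numerator is negative.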
In any case, note that if i_1 and i_2 are fixed, we’ll obtain some bound on K for which the inequality holds true.
Further, because of how they’re obtained, there aren’t too many pairs of (i_1, i_2) possible at all: in fact, there are only O(N) distinct pairs!
To see why, consider K and K+1. When will they have different (i_1, i_2) pairs?
By definition of i_1 and i_2, this can happen only when at least one of the two rows contains an element equal to K. Such an element is not less than K but is less than K+1, so it increases i_1 and/or i_2.
There are 2N elements in total across both rows, so the pairs can change at most 2N times too.
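It’s thus quite easy to find all \mathcal{O}(N) pairs (i_1, i_2), along with the range of K each corresponds to. Sort the distinct values appearing in the two rows, together with 0 as a sentinel. Then for consecutive distinct values v \lt w, every K in [v+1, w] has the same pair (i_1, i_2).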
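One way to enumerate these ranges is sketched below. It assumes both rows are sorted ascending and that `maxK` is at least every element; the reference solutions use 10^9 as this upper sentinel. `KRange` and `splitByBreakpoints` are hypothetical names, and the counts `i1`, `i2` are the editorial’s i_1, i_2.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// A maximal range [lo, hi] of K on which (i1, i2) stays constant.
struct KRange {
    long long lo, hi;  // inclusive bounds on K
    long long i1, i2;  // elements strictly below K in row 1 and row 2
};

// Both rows must be sorted ascending, and maxK must be >= every element.
// Breakpoints are the distinct values v of both rows. K = v + 1 is the first K
// for which elements equal to v are no longer capped.
std::vector<KRange> splitByBreakpoints(const std::vector<long long>& row1,
                                       const std::vector<long long>& row2,
                                       long long maxK) {
    assert(std::is_sorted(row1.begin(), row1.end()) && std::is_sorted(row2.begin(), row2.end()));
    std::vector<long long> vals = {0, maxK};  // sentinels: K ranges over [1, maxK]
    vals.insert(vals.end(), row1.begin(), row1.end());
    vals.insert(vals.end(), row2.begin(), row2.end());
    std::sort(vals.begin(), vals.end());
    vals.erase(std::unique(vals.begin(), vals.end()), vals.end());

    std::vector<KRange> ranges;
    for (std::size_t j = 0; j + 1 < vals.size(); ++j) {
        // For K in [vals[j] + 1, vals[j + 1]], the elements below K are exactly those <= vals[j].
        const long long i1 = std::upper_bound(row1.begin(), row1.end(), vals[j]) - row1.begin();
        const long long i2 = std::upper_bound(row2.begin(), row2.end(), vals[j]) - row2.begin();
        ranges.push_back({vals[j] + 1, vals[j + 1], i1, i2});
    }
    return ranges;
}
```

With rows {1, 5} and {3, 3} and `maxK` = 10, this produces the four ranges [1, 1], [2, 3], [4, 5] and [6, 10].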
Note that for each range of K for which (i_1, i_2) is fixed, solving the inequality P_{1, i_1} - P_{2, i_2} \lt K\cdot (i_1 - i_2) as above will give us a further range of K that it’s valid for.
Intersecting these two ranges will give us (at most) one range of K for which the second row has larger sum than the first.
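The following sketches this intersection step. It assumes the range [lo, hi] and the values i_1, i_2, P_{1, i_1}, P_{2, i_2} come from an enumeration like the one described above. The helpers `floorDiv`, `ceilDiv` and `solveOnRange` are hypothetical, and `std::optional` requires C++17. Explicit floor and ceiling are used because C++ `/` truncates toward zero.

```cpp
#include <algorithm>
#include <cassert>
#include <optional>
#include <utility>

// Floor and ceiling of a / b for b > 0 (plain '/' truncates toward zero instead).
long long floorDiv(long long a, long long b) {
    assert(b > 0);
    return a / b - ((a % b != 0 && a < 0) ? 1 : 0);
}
long long ceilDiv(long long a, long long b) {
    assert(b > 0);
    return a / b + ((a % b != 0 && a > 0) ? 1 : 0);
}

// On a K-range [lo, hi] where i1, i2 and the prefix sums P1 = P_{1,i1}, P2 = P_{2,i2}
// are fixed, return the sub-range where the second row is strictly larger,
// i.e. where P1 - P2 < K * (i1 - i2), or std::nullopt if there is none.
std::optional<std::pair<long long, long long>> solveOnRange(long long lo, long long hi,
                                                            long long P1, long long P2,
                                                            long long i1, long long i2) {
    const long long D = P1 - P2, d = i1 - i2;
    if (d > 0) {
        lo = std::max(lo, floorDiv(D, d) + 1);   // lower bound on K
    } else if (d < 0) {
        hi = std::min(hi, ceilDiv(-D, -d) - 1);  // upper bound on K
    } else if (D >= 0) {
        return std::nullopt;                     // false for every K
    }
    if (lo > hi) return std::nullopt;
    return std::make_pair(lo, hi);
}
```

Continuing the two-row example {1, 5} and {3, 3}, the range [2, 3] has i_1 = 1 and i_2 = 0, and it stays [2, 3] in full. The range [4, 5] has i_1 = 1 and i_2 = 2, and it shrinks to [4, 4], because at K = 5 both sums are 6.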
This way, looking at just the first and second rows, we obtain a collection of at most 2N+1 disjoint intervals (at most one per range of K with a fixed pair (i_1, i_2)), each describing a range where the second row is larger than the first.
One way to think about this is that each interval tells us to “add 1 to the answer” for the range of K that it represents.
If this is repeated for each of the other rows (comparing the first row against row i for every i \geq 2), we’ll obtain a collection of \mathcal{O}(N^2) intervals that in total describe the answers for all K.
Let these intervals be [l_1, r_1], [l_2, r_2], \ldots
We now want to find the smallest K which gives the maximum answer.
The answer for a certain K is, in terms of the above intervals, simply the count of them that contain this K.
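So, the smallest K that maximizes the answer must be one of the left endpoints of the intervals, i.e. one of the l_i. To see this, suppose K \gt 1 and no interval starts at K. Then every interval containing K also contains K-1, so K-1 has an answer at least as large, and K cannot be the smallest maximizer.
This means we can just compute the answer for every l_i and take the best among them. If there are no intervals at all, the maximum score is 0 and every K attains it, so the answer is simply the smallest allowed K, which is 1 in the reference solutions below.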
Computing the answer for a fixed l_i is easy enough: we want the number of intervals that contain l_i. This is the total number of intervals, minus the number that start after l_i, minus the number that end before l_i. No interval can do both, since l \le r for every interval.
Finding the number of intervals that start/end after/before a point is a simple application of binary search.
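The counting step might look like the sketch below. It sorts all left endpoints and all right endpoints once, then answers each query with two binary searches. `bestLeftEndpoint` and its `defaultK` parameter (the value returned when there are no intervals) are hypothetical. The author’s code below instead sweeps the intervals by left endpoint with an order-statistics tree, and Tester 2 uses a difference map, so this is just one equivalent way to do it.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Return the smallest left endpoint l contained in the largest number of intervals.
// count(l) = total - #(intervals starting after l) - #(intervals ending before l).
// If there are no intervals, defaultK is returned.
long long bestLeftEndpoint(const std::vector<std::pair<long long, long long>>& intervals,
                           long long defaultK) {
    std::vector<long long> L, R;
    for (const auto& [l, r] : intervals) {
        assert(l <= r);
        L.push_back(l);
        R.push_back(r);
    }
    std::sort(L.begin(), L.end());
    std::sort(R.begin(), R.end());
    const long long total = static_cast<long long>(intervals.size());
    long long bestK = defaultK, bestCount = 0;
    for (long long x : L) {  // L is ascending, so the first maximum found is the smallest
        const long long startAfter = L.end() - std::upper_bound(L.begin(), L.end(), x);
        const long long endBefore = std::lower_bound(R.begin(), R.end(), x) - R.begin();
        const long long count = total - startAfter - endBefore;
        if (count > bestCount) { bestCount = count; bestK = x; }
    }
    return bestK;
}
```

For the intervals [1, 5], [3, 8], [4, 6] and [7, 9], the point 4 lies in three of them, which is the maximum, so 4 is returned.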
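Generating the intervals takes \mathcal{O}(N \log N) per row, i.e. \mathcal{O}(N^2 \log N) in total. We then test \mathcal{O}(N^2) points, each with binary searches over \mathcal{O}(N^2) endpoints, which costs \mathcal{O}(\log N) per point. So the overall complexity is \mathcal{O}(N^2 \log N).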
TIME COMPLEXITY:
\mathcal{O}(N^2\log N) per testcase.
CODE:
Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define IOS ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0);
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace __gnu_pbds;
// template<class T> using ordered_set = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
template<class T> using ordered_multiset = tree<T, null_type, less_equal<T>, rb_tree_tag, tree_order_statistics_node_update>;
#define int long long
int32_t main(){
    // #ifndef ONLINE_JUDGE
    // freopen("1.in", "r", stdin);
    // freopen("1.out", "w", stdout);
    // #endif
    IOS
    int t;
    cin>>t;
    while(t--){
        int n;
        cin>>n;
        vector<vector<int>> h(n,vector<int>(n));
        vector<vector<int>> pref(n,vector<int>(n+1));
        for(int i=0;i<n;++i){
            for(int j=0;j<n;++j){
                cin>>h[i][j];
            }
            // sort the row and build its prefix sums
            sort(h[i].begin(),h[i].end());
            for(int j=0;j<n;++j){
                pref[i][j+1]=pref[i][j]+h[i][j];
            }
        }
        // p holds every interval [l, r] of K on which some row i beats row 0
        vector<array<int,2>> p;
        for(int i=1;i<n;++i){
            // distinct values of rows 0 and i, plus sentinels 0 and 1e9
            set<int> s;
            s.insert(0);
            s.insert(1e9);
            for(auto x:h[0]){
                s.insert(x);
            }
            for(auto x:h[i]){
                s.insert(x);
            }
            vector<int> a;
            for(auto x:s){
                a.push_back(x);
            }
            // for K in [a[j]+1, a[j+1]], x and y = number of elements below K in rows 0 and i
            for(int j=0;j<a.size()-1;++j){
                int x=lower_bound(h[0].begin(),h[0].end(),a[j]+1)-h[0].begin();
                int y=lower_bound(h[i].begin(),h[i].end(),a[j]+1)-h[i].begin();
                int c1=pref[0][x];
                int c2=pref[i][y];
                // k1, k2 = capped elements; row i wins iff c1 + k1*K < c2 + k2*K
                int k1=n-x,k2=n-y;
                if(k1==k2){
                    if(c1<c2){
                        p.push_back({a[j]+1,a[j+1]});
                    }
                }
                else if(k1>k2){
                    // upper bound on K from (k1-k2)*K < c2-c1
                    int cur=(c2-c1+(k1-k2-1))/(k1-k2)-1;
                    cur=min(cur,a[j+1]);
                    if(cur>a[j]){
                        p.push_back({a[j]+1,cur});
                    }
                }
                else{
                    // lower bound on K from (k2-k1)*K > c1-c2
                    int cur=(c2-c1)/(k1-k2)+1;
                    cur=max(cur,a[j]+1);
                    if(cur<=a[j+1]){
                        p.push_back({cur,a[j+1]});
                    }
                }
            }
        }
        // sweep the intervals in order of left endpoint
        sort(p.begin(),p.end());
        int ans=1;
        int mx=0;
        ordered_multiset<int> ss;
        for(int i=0;i<p.size();++i){
            // this interval plus the earlier ones whose right end is >= p[i][0]
            int x=ss.size()-ss.order_of_key(p[i][0])+1;
            if(x>mx){
                mx=x;
                ans=p[i][0];
            }
            ss.insert(p[i][1]);
        }
        cout<<ans<<'\n';
    }
    return 0;
}
Tester's code 1 (C++)
#include <bits/stdc++.h>
using namespace std;
#define IOS ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0);
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace __gnu_pbds;
// template<class T> using ordered_set = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
template<class T> using ordered_multiset = tree<T, null_type, less_equal<T>, rb_tree_tag, tree_order_statistics_node_update>;
#define int long long
struct input_checker {
    string buffer;
    int pos;
    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";
    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }
    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && !isspace(buffer[now])) {
            now++;
        }
        return now;
    }
    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }
    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }
    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }
    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }
    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }
    void readEof() {
        assert((int) buffer.size() == pos);
    }
}inp;
int32_t main(){
    // #ifndef ONLINE_JUDGE
    // freopen("1.in", "r", stdin);
    // freopen("1.out", "w", stdout);
    // #endif
    IOS
    int t;
    // cin>>t;
    t = inp.readInt(1,10'000);
    inp.readEoln();
    int smn = 0;
    while(t--){
        int n;
        // cin>>n;
        n = inp.readInt(2,1'000);
        smn += n*n;
        inp.readEoln();
        vector<vector<int>> h(n,vector<int>(n));
        vector<vector<int>> pref(n,vector<int>(n+1));
        for(int i=0;i<n;++i){
            for(int j=0;j<n;++j){
                // cin>>h[i][j];
                h[i][j] = inp.readInt(1,1000'000'000);
                if(j == n-1)inp.readEoln();
                else inp.readSpace();
            }
            sort(h[i].begin(),h[i].end());
            for(int j=0;j<n;++j){
                pref[i][j+1]=pref[i][j]+h[i][j];
            }
        }
        vector<array<int,2>> p;
        for(int i=1;i<n;++i){
            set<int> s;
            s.insert(0);
            s.insert(1e9);
            for(auto x:h[0]){
                s.insert(x);
            }
            for(auto x:h[i]){
                s.insert(x);
            }
            vector<int> a;
            for(auto x:s){
                a.push_back(x);
            }
            for(int j=0;j<a.size()-1;++j){
                int x=lower_bound(h[0].begin(),h[0].end(),a[j]+1)-h[0].begin();
                int y=lower_bound(h[i].begin(),h[i].end(),a[j]+1)-h[i].begin();
                int c1=pref[0][x];
                int c2=pref[i][y];
                int k1=n-x,k2=n-y;
                if(k1==k2){
                    if(c1<c2){
                        p.push_back({a[j]+1,a[j+1]});
                    }
                }
                else if(k1>k2){
                    int cur=(c2-c1+(k1-k2-1))/(k1-k2)-1;
                    cur=min(cur,a[j+1]);
                    if(cur>a[j]){
                        p.push_back({a[j]+1,cur});
                    }
                }
                else{
                    int cur=(c2-c1)/(k1-k2)+1;
                    cur=max(cur,a[j]+1);
                    if(cur<=a[j+1]){
                        p.push_back({cur,a[j+1]});
                    }
                }
            }
        }
        sort(p.begin(),p.end());
        int ans=1;
        int mx=0;
        ordered_multiset<int> ss;
        for(int i=0;i<p.size();++i){
            int x=ss.size()-ss.order_of_key(p[i][0])+1;
            if(x>mx){
                mx=x;
                ans=p[i][0];
            }
            ss.insert(p[i][1]);
        }
        cout<<ans<<'\n';
    }
    inp.readEof();
    assert(smn <= 1000'000);
    return 0;
}
Tester's code 2 (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e9
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());
void Solve()
{
    int n; cin >> n;
    vector<vector<int>> a(n, vector<int>(n));
    for (int i = 0; i < n; i++){
        for (int j = 0; j < n; j++){
            cin >> a[i][j];
        }
    }
    // find the intervals for which sum(i) > sum(1)
    vector <pair<int, int>> ans;
    for (int i = 1; i < n; i++){
        int s1 = 0;
        int s2 = 0;
        vector <pair<int, int>> c;
        for (int j = 0; j < n; j++){
            s1 += a[0][j];
            s2 += a[i][j];
            c.push_back({a[0][j], 0});
            c.push_back({a[i][j], 1});
        }
        int c1 = 0, c2 = 0;
        sort(c.begin(), c.end());
        reverse(c.begin(), c.end());
        if (s1 < s2){
            ans.push_back({c[0].first + 1, INF});
        }
        for (int j = 0; j + 1 < c.size(); j++){
            // <= this
            // > next
            if (c[j].second == 0){
                c1 += 1;
                s1 -= c[j].first;
            } else {
                c2 += 1;
                s2 -= c[j].first;
            }
            // check if lowest works
            // check if highest works
            // binary search
            int lo = c[j + 1].first + 1;
            int hi = c[j].first;
            if (lo > hi) continue;
            auto works = [&](int x){
                assert(lo <= x && x <= hi);
                int sum1 = s1 + c1 * x;
                int sum2 = s2 + c2 * x;
                return sum1 < sum2;
            };
            // if (hi == 4){
            //     cout << "WTF\n";
            // }
            if (c2 < c1){
                // some prefix of [lo, hi] will work
                // find largest that works
                int l = lo, r = hi;
                while (l != r){
                    int mid = (l + r + 1) / 2;
                    if (works(mid)){
                        l = mid;
                    } else {
                        r = mid - 1;
                    }
                }
                if (works(lo))
                    ans.push_back({lo, l});
            } else {
                // cout << hi << " " << works(hi) << "\n";
                int l = lo, r = hi;
                while (l != r){
                    int mid = (l + r) / 2;
                    if (works(mid)){
                        r = mid;
                    } else {
                        l = mid + 1;
                    }
                }
                if (works(hi)){
                    ans.push_back({l, hi});
                }
            }
        }
    }
    // look for max intersection
    map <int, int> mp;
    for (auto z : ans){
        int x = z.first;
        int y = z.second;
        // cout << x << " " << y << '\n';
        mp[x] += 1;
        mp[y + 1] -= 1;
    }
    int mx = 0;
    int got = 1;
    int curr = 0;
    for (auto z : mp){
        int x = z.first;
        int y = z.second;
        curr += y;
        if (curr > mx){
            mx = curr;
            got = x;
        }
    }
    cout << got << "\n";
}
int32_t main()
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in", "r", stdin);
    // freopen("out", "w", stdout);
    cin >> t;
    for(int i = 1; i <= t; i++)
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n";
    return 0;
}