P6172 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: pols_agyi_pols
Tester: kingmessi
Editorialist: iceknight1093

DIFFICULTY:

Easy

PREREQUISITES:

Sets

PROBLEM:

Bob has a permutation P of length N.
He will obtain a score by performing the following operation until P is empty:

  • Add the median of P to the score.
  • Then, delete one element from P, of his choice.

Bob’s aim is to minimize his score.
Alice wants to maximize Bob’s score; to do so, before Bob starts, she can delete exactly one subarray of length exactly K from the permutation.
Find Bob’s final score.

EXPLANATION:

First, let’s analyze what Bob would do if Alice weren’t present.

For convenience, let’s sort the array P, so that P_1 \lt P_2 \lt \cdots \lt P_N.
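Let M denote the index of the median; throughout this analysis, the median of a list of length N is its \left\lceil \frac N 2 \right\rceil-th smallest element, so M = \left\lceil \frac N 2 \right\rceil.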
Then,

  • Bob will add P_M to his answer.
    Then, he must delete an element - and we need to see how M changes after this deletion.
  • Since the median differs slightly depending on whether N is even or odd, let’s look at the cases separately.
    • If N is even, M will either remain the same (if an element \gt P_M is deleted), or will increase by 1 (if an element \leq P_M is deleted).
    • If N is odd, M will either decrease by 1 (if an element \gt P_M is deleted), or remain the same (if an element \leq P_M is deleted).
  • N reduces by 1 after the deletion.

Bob’s aim is to minimize the sum of medians, so ideally when N is even he’ll try to keep M the same (since the alternative is to increase it, which just makes the median larger), and when N is odd he’ll try to decrease it since this leads to a smaller median.
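This leads to a very simple strategy for Bob: just repeatedly delete the largest remaining element. Whenever at least two elements remain, the maximum is larger than P_M, so deleting it achieves the better outcome in both cases above.
The optimality of this can be proven using a simple exchange argument. Alternatively: when m elements remain, they form a subset of the original array, so their \left\lceil \frac m 2 \right\rceil-th smallest element is at least P_{\lceil m/2 \rceil}; always deleting the maximum keeps exactly P_1, \ldots, P_m, so every median Bob adds meets this lower bound.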
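The optimality claim is easy to check mechanically on small inputs. The sketch below compares Bob's greedy score with an exhaustive search over every deletion order, for every non-empty set of values drawn from {1, ..., 8}. Since Alice is absent here, the score depends only on the set of values, not their order, which is why it iterates over subsets. The function names are hypothetical, and the median convention (the \lceil n/2 \rceil-th smallest) is the one assumed in the analysis above.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <cstddef>
#include <vector>

// Median convention assumed from the analysis: the ceil(n/2)-th smallest value.
// `sorted` must be ascending and non-empty.
long long medianOf(const std::vector<int>& sorted) {
    return sorted[(sorted.size() + 1) / 2 - 1];
}

// Bob's greedy: add the median, then delete the largest remaining value.
long long greedyScore(std::vector<int> vals) {
    std::sort(vals.begin(), vals.end());
    long long score = 0;
    while (!vals.empty()) {
        score += medianOf(vals);
        vals.pop_back();  // the largest value sits at the back
    }
    return score;
}

// Exhaustive minimum over every deletion order; bit i of `mask` marks vals[i] as present.
long long exhaustive(const std::vector<int>& vals, unsigned mask, std::vector<long long>& memo) {
    if (mask == 0) return 0;
    if (memo[mask] >= 0) return memo[mask];
    std::vector<int> cur;
    for (std::size_t i = 0; i < vals.size(); ++i)
        if ((mask >> i) & 1u) cur.push_back(vals[i]);
    std::sort(cur.begin(), cur.end());
    long long rest = LLONG_MAX;
    for (std::size_t i = 0; i < vals.size(); ++i)
        if ((mask >> i) & 1u)
            rest = std::min(rest, exhaustive(vals, mask & ~(1u << i), memo));
    return memo[mask] = medianOf(cur) + rest;
}

// True if greedy matches the exhaustive optimum for every non-empty subset of {1, ..., maxV}.
bool greedyIsOptimal(int maxV) {
    for (unsigned sub = 1; sub < (1u << maxV); ++sub) {
        std::vector<int> vals;
        for (int v = 1; v <= maxV; ++v)
            if ((sub >> (v - 1)) & 1u) vals.push_back(v);
        std::vector<long long> memo(1u << vals.size(), -1);
        long long best = exhaustive(vals, (1u << vals.size()) - 1, memo);
        if (best != greedyScore(vals)) return false;
    }
    return true;
}
```

The exhaustive search is exponential, so it is only meant as a sanity check for very small sets.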


Now, let’s take a look at what exactly Bob’s score will be.
Suppose N is even, so that M = \frac N 2 initially.
Then,

  • First, P_{\frac N 2} is added to the score. Then, the maximum is deleted - this doesn’t change M.
  • P_{\frac N 2} is added to the score again. The maximum is deleted - this time, M reduces by 1.
  • Now, P_{\frac N 2 - 1} is added to the score; and after deletion M doesn’t change.
  • Next, P_{\frac N 2 - 1} is again added to the score; but this time M decreases by 1 after deletion
    \vdots

In general, it can be seen that Bob’s final score is exactly

2\cdot \left(P_1 + P_2 + P_3 + \ldots + P_{\frac N 2}\right)

The case when N is odd is similar, except that we use \frac{N-1}{2} instead of \frac N 2, and a single extra term P_{\frac{N+1}{2}} is added:

2\cdot \left(P_1 + P_2 + P_3 + \ldots + P_{\frac{N-1}{2}}\right) + P_{\frac{N+1}{2}}

So, quite simply, Bob’s score is twice the sum of the smallest \left\lfloor \frac N 2 \right\rfloor elements, plus the middle element P_{\frac{N+1}{2}} counted once if N is odd.
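Here is a short sketch of the closed form next to a step-by-step simulation of the greedy, assuming the same \lceil n/2 \rceil-th-smallest median convention; `closedFormScore` and `simulatedScore` are illustrative names. Equivalently, the score is the sum of P_{\lceil m/2 \rceil} over m = 1, ..., N, which is why each of the smallest \lfloor N/2 \rfloor values appears twice.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Closed form: with values sorted so that P_1 < ... < P_n (1-indexed),
// score = 2 * (P_1 + ... + P_{floor(n/2)}) + (n odd ? P_{(n+1)/2} : 0).
long long closedFormScore(std::vector<int> vals) {
    std::sort(vals.begin(), vals.end());
    std::size_t n = vals.size(), half = n / 2;
    long long score = 0;
    for (std::size_t i = 0; i < half; ++i) score += 2LL * vals[i];
    if (n % 2 == 1) score += vals[half];  // 0-indexed vals[half] is P_{(n+1)/2}
    return score;
}

// Step-by-step simulation: add the ceil(m/2)-th smallest of the m remaining values,
// then delete the maximum.
long long simulatedScore(std::vector<int> vals) {
    std::sort(vals.begin(), vals.end());
    long long score = 0;
    while (!vals.empty()) {
        score += vals[(vals.size() + 1) / 2 - 1];
        vals.pop_back();
    }
    return score;
}
```

For example, `closedFormScore({5, 2, 4, 1, 3})` gives 2·(1 + 2) + 3 = 9, matching the five medians 3, 2, 2, 1, 1 produced by the simulation.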


Let’s now relate this to Alice’s moves.
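Alice can delete some subarray of length K, after which N-K elements remain, so Bob’s score will be twice the sum of the smallest \left\lfloor \frac{N-K}{2} \right\rfloor remaining elements, plus the \left(\left\lfloor \frac{N-K}{2} \right\rfloor + 1\right)-th smallest remaining element if N-K is odd.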
We need to check every subarray of length K, and take the maximum of Bob’s answers among them - that’s the subarray Alice will choose to delete.
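Before optimizing, a direct reference implementation is handy for testing. This sketch (hypothetical name `bruteForceAnswer`, 0-indexed windows, distinct values assumed) tries every window, sorts what remains and applies the closed form. It runs in \mathcal{O}(N^2 \log N), far too slow for the real constraints, but useful for cross-checking a faster solution on small inputs.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Try every window p[i..i+k) for Alice, score what remains with Bob's closed form,
// and return the maximum over all windows.
long long bruteForceAnswer(const std::vector<int>& p, int k) {
    int n = static_cast<int>(p.size());
    long long best = 0;
    for (int i = 0; i + k <= n; ++i) {
        std::vector<int> rest;
        for (int j = 0; j < n; ++j)
            if (j < i || j >= i + k) rest.push_back(p[j]);  // keep everything outside the window
        std::sort(rest.begin(), rest.end());
        std::size_t m = rest.size(), half = m / 2;
        long long score = 0;
        for (std::size_t t = 0; t < half; ++t) score += 2LL * rest[t];
        if (m % 2 == 1) score += rest[half];  // middle element counted once when m is odd
        best = std::max(best, score);
    }
    return best;
}
```

For instance, `bruteForceAnswer({2, 4, 1, 5, 3}, 2)` returns 7.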

To do this, note that deleting the subarrays P[i:i+K) and P[i+1 : i+1+K) (indices here refer to the original, unsorted order) hardly differs at all.
In particular, the only elements whose state differs are P_i and P_{i+K}: the former is deleted in the first case but not in the second, while the latter is the opposite.

This allows us to use a sliding window approach to maintain the smallest \left\lfloor \frac{N-K}{2} \right\rfloor remaining elements and their sum.
Let’s maintain two sets S_L and S_H, where S_L stores the smallest \left\lfloor \frac{N-K}{2} \right\rfloor of the remaining elements and S_H stores everything else.
Populate these sets initially assuming the first K elements are deleted, i.e. the subarray starting from index 1.

Now, let’s see what happens when moving one step to the right, to the subarray starting from index 2.

  • P_{K+1} was previously available to us, but is now not.
    It’s present in exactly one of S_L and S_H, delete it from that one.
  • P_1 was not available to us earlier, but now is.
    Insert it into either S_L or S_H; it doesn’t matter which, since the rebalancing step below fixes the placement.
  • Finally, these insertions/deletions might have thrown things off, so we ‘rebalance’ the sets.
    That is,
    • First, make sure that S_L has size exactly \left\lfloor \frac{N-K}{2} \right\rfloor.
      To do this, either move elements from S_H to it if there are too few, or from it to S_H if there are too many.
    • Then, we must ensure that S_L contains the smallest \left\lfloor \frac{N-K}{2} \right\rfloor elements.
      This is equivalent to saying that every element of S_L is smaller than every element of S_H, so we can run the following loop: while \max(S_L) \gt \min(S_H), swap the smallest element of S_H with the largest element of S_L.

Note that the rebalance step still involves only \mathcal{O}(1) inserts/deletes, each of which can be done in \mathcal{O}(\log N) time if using an appropriate data structure (std::set in C++ for instance).
So, in \mathcal{O}(\log N) time we’re able to move to the next window, while keeping the sets S_L and S_H updated.
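We specifically care about the sum of S_L, so make sure to keep it updated whenever an insert or delete happens. Bob’s score for the current window is then 2\cdot \text{sum}(S_L), plus \min(S_H) if N-K is odd (that minimum is exactly the \left(\left\lfloor \frac{N-K}{2} \right\rfloor + 1\right)-th smallest remaining element).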

This way, we can go through every window and compute Bob’s optimal score for the corresponding deletion; and then print the maximum of them all as the answer.
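A compact sketch of the whole sliding-window procedure, under the same assumptions as above (distinct values, 1 \le K \lt N, median is the \lceil n/2 \rceil-th smallest). The struct `Halves` and the function `maxBobScore` are illustrative names, not taken from the author's, tester's, or editorialist's code, although the rebalance loop follows the same three steps described above.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <set>
#include <vector>

// sL holds the smallest `half` = floor((N-K)/2) remaining values, sH the rest;
// sumL is the running sum of sL. Values are distinct, so std::set suffices.
struct Halves {
    std::set<int> sL, sH;
    long long sumL = 0;
    std::size_t half;

    explicit Halves(std::size_t h) : half(h) {}

    void add(int x) { sH.insert(x); }  // either side works; rebalance() fixes placement
    void remove(int x) {
        if (sL.erase(x)) sumL -= x;    // erase returns how many elements were removed
        else sH.erase(x);
    }
    void rebalance() {
        while (sL.size() > half) {     // too many: move the largest of sL up
            int x = *sL.rbegin();
            sL.erase(x); sumL -= x; sH.insert(x);
        }
        while (sL.size() < half) {     // too few: pull the smallest of sH down
            int x = *sH.begin();
            sH.erase(x); sL.insert(x); sumL += x;
        }
        // Restore max(sL) < min(sH) by swapping the offending pair.
        while (!sL.empty() && *sL.rbegin() > *sH.begin()) {
            int x = *sL.rbegin(), y = *sH.begin();
            sL.erase(x); sumL -= x; sH.erase(y);
            sL.insert(y); sumL += y; sH.insert(x);
        }
    }
    // Bob's score: 2 * sum(sL), plus min(sH) when the number of remaining values is odd.
    long long bobScore() const {
        long long s = 2 * sumL;
        if ((sL.size() + sH.size()) % 2 == 1) s += *sH.begin();
        return s;
    }
};

// Maximum over all windows p[i..i+k) of Bob's score after Alice deletes that window.
long long maxBobScore(const std::vector<int>& p, int k) {
    int n = static_cast<int>(p.size());
    Halves h((n - k) / 2);
    for (int i = k; i < n; ++i) h.add(p[i]);  // start with window [0, k) deleted
    h.rebalance();
    long long ans = h.bobScore();
    for (int i = k; i < n; ++i) {             // slide to window [i-k+1, i+1)
        h.add(p[i - k]);                      // p[i-k] is no longer deleted
        h.remove(p[i]);                       // p[i] is now deleted
        h.rebalance();
        ans = std::max(ans, h.bobScore());
    }
    return ans;
}
```

Calling `maxBobScore({2, 4, 1, 5, 3}, 2)` returns 7: deleting either {4, 1} or {1, 5} leaves three values whose smallest is 2 and whose median is 3, giving 2·2 + 3. Inserting the re-entering element into S_H unconditionally keeps `add` trivial and leaves all ordering work to `rebalance`, matching the remark that the choice of set doesn't matter.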

TIME COMPLEXITY:

\mathcal{O}(N\log N) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define ll long long

int main() {
	ll tt=1;
    cin>>tt;
    while(tt--){
        ll n,k;
        cin>>n>>k;
        ll a[n];
        set <ll> l,r;
        for(int i=0;i<n;i++){
            cin>>a[i];
        }
        for(int i=k;i<n;i++){
            r.insert(a[i]);
        }
        ll ans=0;
        ll sum=0;
        ll cnt=(n-k+1)/2;
        ll flag=(n-k)%2;
        while(l.size()<cnt){
            l.insert(*r.begin());
            sum+=*r.begin();
            r.erase(r.begin());
        }
        ans=2*sum-flag*(*l.rbegin());
        for(int i=k;i<n;i++){
            if(a[i-k]>(*l.rbegin())){
                r.insert(a[i-k]);
            }else{
                r.insert(*l.rbegin());
                sum-=(*l.rbegin());
                l.erase(l.find(*l.rbegin()));
                l.insert(a[i-k]);
                sum+=a[i-k];
            }
            if(l.count(a[i])){
                sum-=a[i];
                l.erase(l.find(a[i]));
                l.insert(*r.begin());
                sum+=(*r.begin());
                r.erase(r.begin());
            }else{
                r.erase(r.find(a[i]));
            }
            ans=max(ans,2*sum-flag*(*l.rbegin()));
        }
        cout<<ans<<"\n";
    }
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define ll long long

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && !isspace(buffer[now])) {
            now++;
        }
        return now;
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
}inp;

int main() {
    ll tt=1;
    // cin>>tt;
    tt = inp.readInt(1,100'000);
    inp.readEoln();
    while(tt--){
        ll n,k;
        // cin>>n>>k;
        n = inp.readInt(2,200'000);
        inp.readSpace();
        k = inp.readInt(1,n-1);
        inp.readEoln();

        ll a[n];
        set <ll> l,r;
        for(int i=0;i<n;i++){
            // cin>>a[i];
            a[i] = inp.readInt(1,n);
            if(i == n-1)inp.readEoln();
            else inp.readSpace();
            
        }
        set<int> s10;
        for(auto &x : a)s10.insert(x);
        assert(s10.size() == n);

        for(int i=k;i<n;i++){
            r.insert(a[i]);
        }
        ll ans=0;
        ll sum=0;
        ll cnt=(n-k+1)/2;
        ll flag=(n-k)%2;
        while(l.size()<cnt){
            l.insert(*r.begin());
            sum+=*r.begin();
            r.erase(r.begin());
        }
        ans=2*sum-flag*(*l.rbegin());
        for(int i=k;i<n;i++){
            if(a[i-k]>(*l.rbegin())){
                r.insert(a[i-k]);
            }else{
                r.insert(*l.rbegin());
                sum-=(*l.rbegin());
                l.erase(l.find(*l.rbegin()));
                l.insert(a[i-k]);
                sum+=a[i-k];
            }
            if(l.count(a[i])){
                sum-=a[i];
                l.erase(l.find(a[i]));
                l.insert(*r.begin());
                sum+=(*r.begin());
                r.erase(r.begin());
            }else{
                r.erase(r.find(a[i]));
            }
            ans=max(ans,2*sum-flag*(*l.rbegin()));
        }
        cout<<ans<<"\n";
    }
    inp.readEof();
}
Editorialist's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    int t; cin >> t;
    while (t--) {
        int n, k; cin >> n >> k;
        vector<int> p(n);
        for (int &x : p) cin >> x;
        
        int half = (n - k) / 2;
        set<int> sL, sH;
        ll sm = 0;
        auto rebalance = [&] () {
            while (sL.size() > half) {
                int x = *sL.rbegin();
                sL.erase(x); sm -= x;
                sH.insert(x);
            }
            while (sL.size() < half) {
                int x = *sH.begin();
                sH.erase(x);
                sL.insert(x); sm += x;
            }

            if (sL.empty()) return;
            while (*sL.rbegin() > *sH.begin()) {
                int x = *sL.rbegin(), y = *sH.begin();
                sL.erase(x); sm -= x;
                sH.erase(y);
                sL.insert(y); sm += y;
                sH.insert(x);
            }
        };
        
        for (int i = k; i < n; ++i) {
            sL.insert(p[i]);
            sm += p[i];
        }
        rebalance();

        ll ans = 2*sm;
        if ((n-k)%2) ans += *sH.begin();

        for (int i = k; i < n; ++i) {
            sL.insert(p[i-k]);
            sm += p[i-k];

            if (sL.count(p[i])) {
                sm -= p[i];
                sL.erase(p[i]);
            }
            else sH.erase(p[i]);

            rebalance();
            ll cur = 2*sm;
            if ((n-k)%2) cur += *sH.begin();
            ans = max(ans, cur);
        }
        cout << ans << '\n';
    }
}

Good problem. I couldn’t solve it myself. While searching for approaches, I did come across a way to find the median quickly, but I didn’t stay with it long enough to investigate further.

Nice problem that taught me something new about calculating running medians. I now also see the connection to finding the median of a running stream, where we likewise keep two heaps. Nice trick overall, and a nice twist that numbers are also removed from the stream.


Here is my code

#include <bits/stdc++.h>
using namespace std;
#define nl '\n'
#define int long long int
#define sz(s) (long long)(s.size())
#define all(v) (v).begin(),(v).end()
#define vi vector<int>
#define pii pair<int, int>
#define fl(i, n) for (int i = 0; i < n; i++)

template <typename T>
void println(T x){ cout<<x<<"\n";}
template <typename T>
void print(T x){ cout<<x<<" ";}

int mod = 1e9+7 ;
int power(int x, int y){
    if(y==0)return 1;
    int p = power(x, y/2);
    if(y&1)return (((p*p)%mod)*x)%mod;
    return (p*p)%mod;
}
int modIn(int x){
    return power(x,mod-2);
}

void solve()
{
    int n,k;
    cin>>n>>k;
    int a[n];
    fl(i,n)cin>>a[i];

    multiset<int> s1, s2;
    int sum1 = 0 ,sum2=0;
    auto balance = [&](){
        while (s1.size() > s2.size()){
            int x = *s1.rbegin();
            s2.insert(x);
            sum2+=x;
            sum1-=x;
            s1.erase(s1.find(x));
        }
        while (s1.size() < s2.size()){
            int x = *s2.begin();
            s1.insert(x);
            sum1+=x;
            sum2-=x;
            s2.erase(s2.find(x));
        }
    };
    auto addNum = [&](int num){
        if (!s1.empty() and num <= *(s1.rbegin())){
            s1.insert(num);
            sum1+=num;
        }
        else if(!s1.empty() and num > *(s1.rbegin())){
            s2.insert(num);
            sum2+=num;
        }
        else if (s1.empty()){
            s1.insert(num);
            sum1+=num;
        }
        balance();
    };

    for(int i=0;i<n;i++){
        addNum(a[i]);
    }
    // cerr<<s1.size()<<" "<<s2.size()<<"\n";
    int count = (n-k);
    // remove the first k elements as the subarray
    for(int i=0;i<k;i++){
        if(s1.find(a[i])!=s1.end()){
            s1.erase(s1.find(a[i]));
            sum1-=a[i];
        }else{
            s2.erase(s2.find(a[i]));
            sum2-=a[i];
        }
    }
    balance();
    int ans = 0;
    // if even, the score is twice the sum of the smaller half
    if(count%2==0){
        ans = max(ans,sum1*2);
    }else{
        ans = max(ans,2*sum1 - (*s1.rbegin()));
    }
    int low = 0, high = k;
    while(high<n){
        // add back the first element
        addNum(a[low]);
        // remove the kth element
        if(s1.find(a[high])!=s1.end()){
            s1.erase(s1.find(a[high]));
            sum1-=a[high];
        }else{
            s2.erase(s2.find(a[high]));
            sum2-=a[high];
        }
        balance();
        if(count%2==0){
            ans = max(ans,sum1*2);
        }else{
            ans = max(ans,2*sum1 - (*s1.rbegin()));
        }
        low++;
        high++;
    }
    println(ans);

}
signed main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    clock_t z = clock();
    int tests=1;
    cin >> tests;
    while (tests--) {
        solve();
    }
    cerr << "Run Time : " << ((double)(clock() - z) / CLOCKS_PER_SEC);
    return 0;
}

Try the Sliding Window Median problem on CSES or LeetCode.


Wow!!! I didn’t know about this problem. I guess anyone who has solved the CSES problem before would find this one a cakewalk.

@iceknight1093 Please look into this:

Most people cheated on Div 2 C and Div 2 D using an unnatural segment tree solution from ChatGPT. Please check such unnatural solutions and similar code, and flag them.

Cheater Segment Tree solutions -
https://www.codechef.com/viewsolution/1129724694