MODE - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Authors: krypto_ray
Testers: iceknight1093, tabr
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Math, binary search

PROBLEM:

You’re given an array A. In one move, you can replace any element of it with any other integer.

For each i from 1 to N, find the minimum number of moves needed to make the number of modes exactly i.

EXPLANATION:

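First, note that making the number of modes \gt N/2 but \lt N is impossible: if each mode appears x \geq 2 times, then i modes occupy i\cdot x \geq 2i elements, so i \leq N/2; and if x = 1, every element is distinct, so there are exactly N modes. So, for all such integers, the answer is -1.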

Now, let’s see how to solve this for a fixed i.
Let’s also fix x, the frequency of the mode, i.e. the number of times each mode appears.

This means that:

  • Exactly i integers appear exactly x times in the array. In particular, this means i\cdot x \leq N must hold.
  • Every other integer appears at most x-1 times in the array.

This requires us to work with frequencies, so for each integer k, let \text{freq}[k] denote the number of times k appears in the array.
The frequency table can be precomputed with a map.
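As a minimal sketch of this precomputation (the helper name `frequency_table` and the sample array are illustrative assumptions, not part of the original solutions), the table and its sorted list of frequencies could be built like this:

```python
from collections import Counter

def frequency_table(a):
    """Return (freq, sorted_freqs): per-value counts and the counts sorted ascending."""
    freq = Counter(a)                  # freq[k] = number of times k appears in a
    return freq, sorted(freq.values())

# Illustrative input: 5 appears 3 times, 7 twice, 9 once.
freq, sorted_freqs = frequency_table([5, 5, 7, 5, 9, 7])
print(sorted_freqs)
```

Only the multiset of frequencies matters for the rest of the argument, which is why the sorted list is kept alongside the map.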

Now, for this fixed i and x, we need to do the following:

  • First, bring every element with frequency \gt x, down to frequency x.
    If \text{freq}[k] \gt x, this takes exactly \text{freq}[k] - x moves.
    To find the sum of all of these, it suffices to know the sum of all frequencies that are \gt x, as well as the number of such frequencies. This can be precomputed from the sorted frequency table, and is just a couple of suffix sums.
    Let the number of moves required to do this be s_1.
  • Now that everything has frequency \leq x, let’s look at the number of things with frequency exactly x. Suppose there are c such numbers. c can be computed in \mathcal{O}(1) or \mathcal{O}(\log N) (once again, it’s a suffix of the sorted frequency array).
  • If c \geq i, then we need to make c-i of them have frequency x-1, which takes another c-i moves. So, in this case the number of moves needed is s_1 + c - i.
  • If c \lt i, then we need to bring some elements with frequency \lt x, up to frequency x.
    In particular, we need x - \text{freq}[k] moves for each k with smaller frequency that we pick.
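The first two steps can be sketched with two suffix arrays indexed by frequency. This is a hedged illustration: the names `build_suffix_tables` and `trim_cost` are assumptions made for this example. It uses the fact that values with frequency exactly x contribute 0 to the trimming cost, so summing over frequencies \geq x gives the same s_1.

```python
from collections import Counter

def build_suffix_tables(a):
    """suffct[f] = number of values with frequency >= f,
    suffsum[f] = sum of those frequencies."""
    n = len(a)
    freqct = [0] * (n + 2)             # freqct[f] = number of values with frequency f
    for f in Counter(a).values():
        freqct[f] += 1
    suffct = [0] * (n + 2)
    suffsum = [0] * (n + 2)
    for f in range(n, 0, -1):
        suffct[f] = suffct[f + 1] + freqct[f]
        suffsum[f] = suffsum[f + 1] + freqct[f] * f
    return suffct, suffsum

def trim_cost(suffct, suffsum, x):
    """Return (s1, c): moves to cut every frequency down to at most x,
    and the number of values that then have frequency exactly x."""
    c = suffct[x]
    s1 = suffsum[x] - x * c
    return s1, c
```

For example, with frequencies 4, 3 and 1 and x = 2, the two larger values are trimmed by 2 and 1 moves respectively, so `trim_cost` returns `(3, 2)`.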

For the last case, it’s clearly optimal to choose those k with the largest possible frequency (while also being \lt x).
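So, we want the sum of the largest i-c frequencies that are strictly less than x. Integers that don’t appear in the array count as frequency 0, so there are always enough candidates.

Notice that this represents a subarray of the sorted frequency array, and the right endpoint of this subarray can be computed in \mathcal{O}(\log N) with binary search (it’s the largest frequency that’s \lt x).
Let s_2 be the number of moves needed to raise all of them to frequency x, that is, (i-c)\cdot x minus this sum.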
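One way to sketch this lookup, assuming an ascending list of frequencies with prefix sums (the names `fill_cost`, `sorted_freqs` and `prefix` are hypothetical and chosen only for this illustration), is shown below. It returns the total cost of raising the k largest frequencies below x up to x, treating values absent from the array as frequency 0.

```python
from bisect import bisect_left
from itertools import accumulate

def fill_cost(sorted_freqs, prefix, x, k):
    """Sum of (x - f) over the k largest frequencies strictly below x.
    If fewer than k such frequencies exist, the rest are new values (frequency 0)."""
    r = bisect_left(sorted_freqs, x)   # sorted_freqs[:r] are exactly the frequencies < x
    take = min(k, r)                   # how many existing values we can pick
    picked_sum = prefix[r] - prefix[r - take]
    return k * x - picked_sum

sorted_freqs = [1, 1, 2, 4]                      # ascending frequencies
prefix = [0] + list(accumulate(sorted_freqs))    # prefix[j] = sum of the first j entries
print(fill_cost(sorted_freqs, prefix, 3, 2))     # raises the values with frequency 2 and 1
```

Here the call picks frequencies 2 and 1, costing (3-2) + (3-1) = 3 moves.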

Notice that in this case, we need exactly \max(s_1, s_2) moves: when removing ‘excess’ frequency, we can use the move to increase a lower frequency, and hence no moves are wasted.
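Putting the cases together, here is a deliberately slow reference sketch, not the intended solution: it spends \mathcal{O}(N) work per pair (i, x) and is meant only for checking the case analysis on small inputs. The function names `moves_for` and `solve` are assumptions of this example. It skips x = 1 unless i = N, since frequency 1 for the modes forces all N elements to be distinct.

```python
from collections import Counter

def moves_for(a, i, x):
    """Moves needed so that exactly i values have frequency x and all others
    have frequency below x; None if this (i, x) pair is infeasible."""
    n = len(a)
    if i * x > n or (x == 1 and i != n):
        return None
    freqs = sorted(Counter(a).values(), reverse=True)
    s1 = sum(f - x for f in freqs if f > x)      # trim everything above x
    c = sum(1 for f in freqs if f >= x)          # values sitting at x after trimming
    if c >= i:
        return s1 + (c - i)                      # push c - i of them down to x - 1
    # Candidates to raise: largest frequencies below x, then unused values (frequency 0).
    smaller = [f for f in freqs if f < x] + [0] * n
    s2 = sum(x - f for f in smaller[: i - c])
    return max(s1, s2)

def solve(a):
    """Answer for every i = 1..N, using -1 when no x works."""
    n = len(a)
    ans = []
    for i in range(1, n + 1):
        best = -1
        for x in range(1, n // i + 1):
            m = moves_for(a, i, x)
            if m is not None and (best == -1 or m < best):
                best = m
        ans.append(best)
    return ans

print(solve([1, 1, 2]))
```

For `[1, 1, 2]` this gives 0 moves for one mode, -1 for two modes (since 2 lies strictly between N/2 and N), and 1 move for three modes (replace one of the 1s with a new value).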

So, a fixed i and x can be processed in \mathcal{O}(\log N) or even \mathcal{O}(1) (with two pointers).

As noted earlier, only pairs of i and x such that i\cdot x \leq N matter, and there are \mathcal{O}(N\log N) such pairs (because \frac{N}{1} + \frac{N}{2} + \frac{N}{3} + \ldots + \frac{N}{N} is \mathcal{O}(N\log N)).
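As a quick sanity check of this bound (the helper `count_pairs` is an illustrative assumption), the number of pairs can be counted directly and compared against N(\ln N + 1), which bounds N times the N-th harmonic number:

```python
import math

def count_pairs(n):
    """Number of pairs (i, x) with i, x >= 1 and i * x <= n."""
    return sum(n // i for i in range(1, n + 1))

for n in (10, 1000, 200000):
    pairs = count_pairs(n)
    bound = n * (math.log(n) + 1)      # N * H_N <= N * (ln N + 1)
    print(n, pairs, pairs <= bound)
```

For n = 10 the count is 10 + 5 + 3 + 2 + 2 + 1 + 1 + 1 + 1 + 1 = 27.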

This solves the problem.

TIME COMPLEXITY:

\mathcal{O}(N\log N) per test case.

CODE:

Setter's code (C++)
#include<bits/stdc++.h>
using namespace std;
using ll=long long;
const ll inf=1e16;

#ifdef ANI
#include "D:/DUSTBIN/local_inc.h"
#else
#define dbg(...) 0
#endif

vector<ll> solution(vector<ll> arr) {
    ll n=arr.size();
    vector<ll>freq(n+1);
    vector<pair<ll,ll>>v;
    map<ll,ll>mp;
    for(ll i=0;i<n;i++){
        mp[arr[i]]++;
    }
    for(auto el:mp){
        ll x=el.first, y=el.second;
        freq[y]++;
    }
    vector<ll>ans(n+1,1e9);
    ans[n]=0;
    v.push_back({0,1e9});
    for(ll i=1;i<=n;i++){
        if(freq[i]>0){
            v.push_back({i,freq[i]});
            ans[n]+=(i-1)*freq[i];
        } 
    }
    ll suff=0,cnt=0;
    
    for(ll i=n;i>=2;i--){
        ll op=suff-cnt*i,curr=freq[i]+cnt,x=1,redu_ele=op;
        pair<ll,ll>p={i,0};
        ll idx=lower_bound(v.begin(),v.end(),p)-v.begin();
        idx--;
        for(ll j=1;j<=n/i;j++){
            if(j<=curr){
                ans[j]=min(ans[j],op+(curr-j));
            }
            else{
                if(v[idx].second>=x){
                    x++;
                }
                else{
                    idx--;
                    x=2;
                }
                ll value=max(i-v[idx].first-redu_ele,0ll);
                redu_ele-=min(redu_ele,max(i-v[idx].first-value,0ll));
                op+=(value);
                ans[j]=min(ans[j],op);
            }
        }
        suff+=(freq[i]*i);
        cnt+=freq[i];
    }
    for(auto &x:ans) {
        if(x==1e9) x=-1;
    }
    return vector<ll>(ans.begin()+1,ans.end());
}


void solve() {
    ll t;
    cin>>t;
    assert(t<=100000);
    ll tot=0;
    while(t--) {
        ll n;
        cin>>n;
        tot+=n;
        vector<ll> a(n);
        ll nax=1e9;
        for(ll i=0;i<n;i++) {
            cin>>a[i];
            assert(a[i]<=nax);
        }
        vector<ll> res=solution(a);
        for(ll i=0;i<n;i++) {
            cout<<res[i]<<" \n"[i==n-1];
        }
    }
    assert(tot<=200000);
}

int main() {
	solve();	
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    input_checker in;
    int tt = in.readInt(1, 1e5);
    in.readEoln();
    int sn = 0;
    while (tt--) {
        int n = in.readInt(1, 2e5);
        in.readEoln();
        sn += n;
        vector<int> a = in.readInts(n, 1, 1e9);
        in.readEoln();
        map<int, int> cnt;
        for (int i = 0; i < n; i++) {
            cnt[a[i]]++;
        }
        vector<int> b;
        for (auto p : cnt) {
            b.emplace_back(p.second);
        }
        sort(b.rbegin(), b.rend());
        while ((int) b.size() < n + 1) {
            b.emplace_back(0);
        }
        vector<int> pref(n + 1);
        for (int i = 0; i < n; i++) {
            pref[i + 1] = pref[i] + b[i];
        }
        for (int i = 1; i <= n; i++) {
            int ans = 1e9;
            for (int j = pref[i] / i; j <= pref[i] / i + 1; j++) {
                if (j == 1 && i != n) {
                    continue;
                }
                if (i * j > n) {
                    continue;
                }
                int x = 0, y = 0;
                {
                    int low = -1, high = i - 1;
                    while (high - low > 1) {
                        int mid = (high + low) >> 1;
                        if (b[mid] > j) {
                            low = mid;
                        } else {
                            high = mid;
                        }
                    }
                    x = pref[high] - j * high;
                    y = j * (i - high) - (pref[i] - pref[high]);
                }
                if (b[i] == j) {
                    x += (int) (b.rend() - lower_bound(b.rbegin(), b.rend(), j)) - i;
                }
                ans = min(ans, max(x, y));
            }
            if (ans > 1e8) {
                ans = -1;
            }
            cout << ans << " \n"[i == n];
        }
    }
    assert(sn <= 2e5);
    in.readEof();
    return 0;
}
Editorialist's code (Python)
from bisect import bisect_left

for _ in range(int(input())):
	n = int(input())
	a = list(map(int, input().split()))
	
	freq = {}
	for x in a:
		if x not in freq: freq[x] = 0
		freq[x] += 1
	
	freqct = [0]*(n+1)
	for x in freq.values(): freqct[x] += 1
	
	freqs = [0]
	for x in range(1, n+1):
		if freqct[x] > 0: freqs.append(x)
	
	suffsum, suffct = [0]*(n+1), [0]*(n+1)
	for i in reversed(range(1, n+1)):
		suffct[i], suffsum[i] = freqct[i], freqct[i] * i
		if i < n:
			suffct[i] += suffct[i+1]
			suffsum[i] += suffsum[i+1]
	
	ans = [-1]*(n+1)
	
	for mxfreq in range(1, n+1):
		mxct = 1

		ptr = bisect_left(freqs, mxfreq) - 1
		rem = freqct[freqs[ptr]]
		if ptr == 0: rem = n+100
		smallsum = 0

		while mxct*mxfreq <= n:
		
			# Everything with frequency > mxfreq is brought down to mxfreq -> S1 moves
			# Let k = number of values with frequency >= mxfreq, and ct = mxct
			# If k >= ct, lower (k - ct) of them below mxfreq: (k - ct) more moves
			# If k < ct, raise the largest (ct - k) smaller frequencies to mxfreq -> S2 moves, ans = max(S1, S2)
			
			moresum, morect = suffsum[mxfreq], suffct[mxfreq]
			moves = moresum - mxfreq*morect
			if morect >= mxct:
				moves += morect - mxct
			else:
				smallsum += mxfreq - freqs[ptr]
				rem -= 1
				if rem == 0:
					ptr -= 1
					rem = freqct[freqs[ptr]]
					if ptr == 0: rem = n+100
				moves = max(moves, smallsum)
			
			if 2*mxct <= n or mxct == n:
				if ans[mxct] == -1: ans[mxct] = moves
				else: ans[mxct] = min(ans[mxct], moves)
			mxct += 1

	print(*ans[1:])