MODE_PROBLEM - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: wuhudsm
Tester: tabr
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Frequency arrays

PROBLEM:

For an array B, let f(B) denote the frequency of the mode of B.
Given an array A, process the following Q queries:

  • Given i and x, set A_i := x
  • Then, compute the minimum possible value of \sum_{i=1}^N f(A[1\ldots i]) across all rearrangements of A.

EXPLANATION:

We first figure out how to compute the minimum value of \sum_{i=1}^N f(A[1\ldots i]) across all rearrangements of A.

It’s not hard to see that the following greedy strategy works:

  • First, place one copy of every distinct element that appears in A.
  • Then, place one copy of every distinct element with frequency \geq 2 that appears in A.
  • Then, place one copy of every distinct element with frequency \geq 3 that appears in A.
    \vdots
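
The rounds above can be sketched as follows. This is a minimal illustration rather than the intended solution: the helper names `greedy_arrangement` and `prefix_mode_sum` are made up for this sketch, and values within a round are placed in sorted order only to keep the output deterministic.

```python
from collections import Counter


def greedy_arrangement(a):
    """Round k places one copy of every value whose frequency is >= k."""
    freq = Counter(a)
    result = []
    k = 1
    while len(result) < len(a):
        # Sorted order inside a round is an arbitrary choice for this sketch.
        result.extend(v for v in sorted(freq) if freq[v] >= k)
        k += 1
    return result


def prefix_mode_sum(b):
    """Sum of f(b[1..i]) over all prefixes, where f is the mode's frequency."""
    seen = Counter()
    mode_freq = 0
    total = 0
    for v in b:
        seen[v] += 1
        mode_freq = max(mode_freq, seen[v])
        total += mode_freq
    return total
```

For A = [1, 1, 2, 3, 3, 3], the rounds are [1, 2, 3], [1, 3], [3], and the prefix mode frequencies are 1, 1, 1, 2, 2, 3, for a total of 10.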
Proof

Our claim is essentially the following:
For any element x and any integer K, in an optimal solution, the K'th occurrence of x should appear only after every other element has appeared at least K-1 times (or, for elements that appear \lt K-1 times in total, after all of their occurrences).

Suppose this isn’t the case: for some x and K, let A_i be the K'th appearance of x, while there also exists some element y that has \lt K-1 occurrences before index i and an occurrence after index i.
Let j\gt i be the first time y appears after index i.

Consider what happens when we swap A_i and A_j:

  • For prefixes of length \lt i or \geq j, the multiset of elements remains the same, so the frequency of the mode doesn’t change.
  • Prefixes of length between i and j-1 lose one occurrence of x and gain one occurrence of y.
    But before the swap, x appeared at least K times in these prefixes while y appeared fewer than K-1 times, so even after gaining one occurrence, y still appears fewer than K times. This change therefore cannot increase the frequency of the mode: at worst it remains the same, and it might even decrease.

So, performing this swap doesn’t worsen the answer, and may make it better.
We can repeatedly make such swaps till we reach a state satisfying the initial claim.


With this in mind, it’s not too hard to compute the actual value of \sum_{i=1}^N f(A[1\ldots i]) either.
Note that we:

  • Add 1 to the total for every element that appears in A.
  • Add 2 to the total for every element that appears at least twice in A.
  • Add 3 to the total for every element that appears at least thrice in A.
    \vdots
  • Add N to the total for every element that appears at least N times in A.

This is easily computed in \mathcal{O}(N) time if the frequency array of A is known.
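
One possible way to carry out this level-by-level computation is sketched below. It assumes 1 \leq A_i \leq N (the range the tester's validator checks), and the function name `min_sum_by_levels` is hypothetical.

```python
def min_sum_by_levels(a):
    """Add k for every value that appears at least k times, for k = 1..N."""
    n = len(a)
    freq = [0] * (n + 1)          # assumes 1 <= a[i] <= n
    for v in a:
        freq[v] += 1

    # count_exact[c] = number of values whose frequency is exactly c
    count_exact = [0] * (n + 1)
    for v in range(1, n + 1):
        count_exact[freq[v]] += 1

    total = 0
    at_least = 0                  # number of values with frequency >= k
    for k in range(n, 0, -1):
        at_least += count_exact[k]
        total += k * at_least
    return total
```

Sweeping k downward keeps a running count of the values with frequency at least k, so the whole computation stays linear in N.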


Next, we need to process updates.
To do that, let’s look at the value we’re computing from a different angle.
Consider some integer x that appears f_x times in A, and look at where its copies land in the greedy arrangement.

  • The first time it appears, the frequency of the mode will be 1.
  • The second time it appears, the frequency of the mode will be 2.
    \vdots
  • The f_x'th time it appears, the frequency of the mode will be f_x.

So, looking only at the prefixes of A that end with x, the sum of frequencies of modes is

1 + 2 + 3 + \ldots + f_x = \frac{f_x \cdot (f_x + 1)}{2}

Every prefix has to end with some element, so the overall answer is just this summed across all x that appear in the array!
That is, the answer is simply

\sum_{x=1}^N \frac{f_x \cdot (f_x + 1)}{2}
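
As a sanity check, the closed form can be compared against an exhaustive search over all rearrangements of a small array. The sketch below is only meant for tiny inputs, and both function names are made up for illustration.

```python
from collections import Counter
from itertools import permutations


def closed_form(a):
    """Sum of f_x * (f_x + 1) / 2 over all distinct values x."""
    return sum(f * (f + 1) // 2 for f in Counter(a).values())


def brute_force(a):
    """Minimum prefix mode-frequency sum over every rearrangement of a."""
    best = None
    for perm in permutations(a):
        seen = Counter()
        mode_freq = 0
        total = 0
        for v in perm:
            seen[v] += 1
            mode_freq = max(mode_freq, seen[v])
            total += mode_freq
        if best is None or total < best:
            best = total
    return best
```

The two should agree because the prefix ending at the k'th copy of x always contains at least k copies of x, so no arrangement can beat the closed form, and the greedy arrangement attains it.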

This closed form makes updates simple to process.
When A_i is updated, only the frequencies of two elements change: the old value of A_i, and the new value of A_i.
Recomputing the answer from these two new frequencies is easily done in \mathcal{O}(1) time, leading to a solution that’s \mathcal{O}(N+Q) overall.
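
The update step can be simplified a little further: since \frac{f(f+1)}{2} - \frac{(f-1)f}{2} = f, removing one copy of a value with frequency f lowers the answer by f, and adding a copy to a value with frequency f raises it by f+1. A minimal sketch of this bookkeeping follows, assuming 1 \leq A_i \leq N and 1-indexed positions; the class name `ModeSumTracker` is hypothetical.

```python
class ModeSumTracker:
    """Maintains sum of f_x * (f_x + 1) / 2 under point updates."""

    def __init__(self, a):
        self.a = list(a)
        self.freq = [0] * (len(a) + 1)   # assumes 1 <= a[i] <= len(a)
        self.ans = 0
        for v in self.a:
            self.freq[v] += 1
            self.ans += self.freq[v]     # the k'th copy contributes k

    def update(self, pos, val):
        """Set a[pos] = val (pos is 1-indexed) and return the new answer."""
        old = self.a[pos - 1]
        self.ans -= self.freq[old]       # losing a copy at frequency f costs f
        self.freq[old] -= 1
        self.freq[val] += 1
        self.ans += self.freq[val]       # the new copy contributes f + 1
        self.a[pos - 1] = val
        return self.ans
```

Starting from A = [1, 1, 2, 3, 3, 3] (answer 10), setting A_3 := 3 gives frequencies 2 and 4, so the answer becomes 3 + 10 = 13. Setting a position to the value it already holds leaves the answer unchanged.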

TIME COMPLEXITY:

\mathcal{O}(N + Q) per testcase.

CODE:

Author's code (C++)
#include <map>
#include <set>
#include <cmath>
#include <ctime>
#include <queue>
#include <stack>
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <cstring>
#include <algorithm>
#include <iostream>
#include <bitset>
using namespace std;
typedef double db; 
typedef long long ll;
typedef unsigned long long ull;
const int N=400010;
const int LOGN=28;
const ll  TMD=0;
const ll  INF=2147483647;
int T,n,q;
ll  ans;
int a[N],occ[N];

int main()
{
	T=1;
	//scanf("%d",&T);
	while(T--)
	{
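		scanf("%d%d",&n,&q);
		for(int i=1;i<=n;i++) scanf("%d",&a[i]);
		ans=0; // reset in case multiple test cases are enabled
		for(int i=1;i<=n;i++) occ[i]=0;
		for(int i=1;i<=n;i++) occ[a[i]]++;
		// initial answer: sum of occ*(occ+1)/2 over all values
		for(int i=1;i<=n;i++) ans+=(ll)occ[i]*(occ[i]+1)/2;
		for(int i=1;i<=q;i++)
		{
			int p,x;
			scanf("%d%d",&p,&x);
			// remove the old contributions of both affected values
			ans-=(ll)occ[a[p]]*(occ[a[p]]+1)/2;
			ans-=(ll)occ[x]*(occ[x]+1)/2;
			occ[a[p]]--;
			occ[x]++;
			// add back their contributions with the updated counts
			ans+=(ll)occ[a[p]]*(occ[a[p]]+1)/2;
			ans+=(ll)occ[x]*(occ[x]+1)/2;
			a[p]=x;
			printf("%lld\n",ans);
		}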
	}
	
	//fclose(stdin);
	return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

#define IGNORE_CR

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
#ifdef IGNORE_CR
            if (c == '\r') {
                continue;
            }
#endif
            buffer.push_back((char) c);
        }
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            assert(!isspace(buffer[pos]));
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

int main() {
    input_checker in;
    int n = in.readInt(1, 4e5);
    in.readSpace();
    int q = in.readInt(1, 4e5);
    in.readEoln();
    auto a = in.readInts(n, 1, n);
    in.readEoln();
    for (int i = 0; i < n; i++) {
        a[i]--;
    }
    vector<int> cnt(n);
    for (int x : a) {
        cnt[x]++;
    }
    map<int, int> mp;
    for (int x : cnt) {
        mp[x]++;
    }
    while (q--) {
        int p = in.readInt(1, n);
        in.readSpace();
        int x = in.readInt(1, n);
        in.readEoln();
        p--;
        x--;
        mp[cnt[a[p]]]--;
        if (mp[cnt[a[p]]] == 0) {
            mp.erase(cnt[a[p]]);
        }
        cnt[a[p]]--;
        mp[cnt[a[p]]]++;
        a[p] = x;
        mp[cnt[a[p]]]--;
        if (mp[cnt[a[p]]] == 0) {
            mp.erase(cnt[a[p]]);
        }
        cnt[a[p]]++;
        mp[cnt[a[p]]]++;
        long long ans = 0;
        for (auto t : mp) {
            ans += (t.first * 1LL * (t.first + 1) / 2) * t.second;
        }
        cout << ans << '\n';
    }
    in.readEof();
    return 0;
}
Editorialist's code (Python)
n, q = map(int, input().split())
a = list(map(int, input().split()))
freq = [0]*(n+1)
ans = 0
for x in a:
    freq[x] += 1
    ans += freq[x]

for _ in range(q):
    pos, val = map(int, input().split())
    pos -= 1
    ans -= freq[a[pos]]
    freq[a[pos]] -= 1
    a[pos] = val
    freq[val] += 1
    ans += freq[val]
    
    print(ans)

https://www.codechef.com/viewsolution/1064711034

Can @iceknight1093 or someone else please tell me why my solution didn’t pass?

It passes when I use unordered_map, but many accepted submissions use map, just as mine does.

OMG, I overcomplicated this easy problem :(

The input and output are somewhat large here, so you need fast I/O methods.
Your template has a fastio macro, but for some reason you never call it.
Apart from that, endl is extremely slow; use \n instead.
Your code with these two fixes.


I am new to programming. Is it really necessary to include fastio? I just copied this template from some random solution, but I never learned how to use it. Shouldn’t the problem statement mention that fastio is needed, just like another problem in the same contest (Practice Coding Problem), which suggests using fastio?

One more thing: if the TLE is due to large input, why did unordered_map without fastio get accepted?

@iceknight1093 can you please help with the above query?