PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: wuhudsm
Tester: tabr
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
Frequency arrays
PROBLEM:
For an array B, let f(B) denote the frequency of the mode of B.
Given an array A, process the following Q queries:
- Given i and x, set A_i := x
- Then, compute the minimum possible value of \sum_{i=1}^N f(A[1\ldots i]) across all rearrangements of A.
EXPLANATION:
We first figure out how to compute the minimum value of \sum_{i=1}^N f(A[1\ldots i]) across all rearrangements of A.
It’s not hard to see that the following greedy strategy works:
- First, place one copy of every distinct element that appears in A.
- Then, place one copy of every distinct element with frequency \geq 2 that appears in A.
- Then, place one copy of every distinct element with frequency \geq 3 that appears in A.
\vdots
Proof
Our claim is essentially the following:
For any element x and any integer K, in an optimal solution, the K'th occurrence of x in A should appear only after every element has appeared (K-1) times (or all of its occurrences are earlier, for those elements that appear \lt K-1 times).
Suppose this isn’t the case: for some x and K, let A_i be the K'th appearance of x, while there also exists some element y that has \lt K-1 occurrences before index i and an occurrence after index i.
Let j\gt i be the first time y appears after index i.
Consider what happens when we swap A_i and A_j:
- For prefixes of length \lt i or \geq j, the multiset of elements remains the same, so the frequency of the mode doesn’t change.
- For prefixes between i and j, they lose one occurrence of x and gain one occurrence of y.
But, since x already appeared more times than y here before the swap, this change certainly cannot increase the frequency of the mode: at best, it remains the same, and it might even decrease.
So, performing this swap doesn’t worsen the answer, and may make it better.
We can repeatedly make such swaps till we reach a state satisfying the initial claim.
With this in mind, it’s not too hard to compute the actual value of \sum_{i=1}^N f(A[1\ldots i]) either.
Note that we:
- Add 1 to the total for every element that appears in A.
- Add 2 to the total for every element that appears at least twice in A.
- Add 3 to the total for every element that appears at least thrice in A.
\vdots
- Add N to the total for every element that appears at least N times in A.
This is easily computed in \mathcal{O}(N) time if the frequency array of A is known.
Next, we need to process updates.
To do that, let’s look at the value we’re computing from a different angle.
Consider some integer x, that appears f_x times in A.
- The first time it appears, the frequency of the mode will be 1.
- The second time it appears, the frequency of the mode will be 2.
\vdots
- The f_x'th time it appears, the frequency of the mode will be f_x.
So, looking only at the prefixes of A that end with x, the sum of frequencies of modes is 1 + 2 + \ldots + f_x = \frac{f_x \cdot (f_x + 1)}{2}.
Obviously, every prefix has to end with some element - so the overall answer is just this summed across all x that appear in the array!
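That is, the answer is simply \sum_{x} \frac{f_x \cdot (f_x + 1)}{2}, where the sum is over all distinct values x that appear in A.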
This allows us to process updates in a fairly simple manner.
When A_i is updated, only the frequencies of two elements change: the old value of A_i, and the new value of A_i.
Recomputing the answer from these two new frequencies is easily done in \mathcal{O}(1) time, leading to a solution that’s \mathcal{O}(N+Q) overall.
TIME COMPLEXITY:
\mathcal{O}(N + Q) per testcase.
CODE:
Author's code (C++)
#include <map>
#include <set>
#include <cmath>
#include <ctime>
#include <queue>
#include <stack>
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <cstring>
#include <algorithm>
#include <iostream>
#include <bitset>
using namespace std;
// Short type aliases; only ll is used by the solution below.
typedef double db;
typedef long long ll;
typedef unsigned long long ull;
// Capacity of the global arrays a[] and occ[] (both are indexed from 1).
const int N=400010;
// LOGN, TMD and INF are not referenced anywhere in this solution.
const int LOGN=28;
const ll TMD=0;
const ll INF=2147483647;
// T: number of test cases (main fixes it to 1); n: array length; q: number of queries.
int T,n,q;
// Running answer: sum over values v of occ[v]*(occ[v]+1)/2.
ll ans;
// a[i]: current value at position i; occ[v]: how many positions currently hold v.
int a[N],occ[N];
int main()
{
T=1;
//scanf("%d",&T);
while(T--)
{
scanf("%d%d",&n,&q);
for(int i=1;i<=n;i++) scanf("%d",&a[i]);
for(int i=1;i<=n;i++) occ[i]=0;
for(int i=1;i<=n;i++) occ[a[i]]++;
for(int i=1;i<=n;i++) ans+=(ll)occ[i]*(occ[i]+1)/2;
for(int i=1;i<=q;i++)
{
int p,x;
scanf("%d%d",&p,&x);
ans-=(ll)occ[a[p]]*(occ[a[p]]+1)/2;
ans-=(ll)occ[x]*(occ[x]+1)/2;
occ[a[p]]--;
occ[x]++;
ans+=(ll)occ[a[p]]*(occ[a[p]]+1)/2;
ans+=(ll)occ[x]*(occ[x]+1)/2;
a[p]=x;
printf("%lld\n",ans);
}
}
//fclose(stdin);
return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif
#define IGNORE_CR
// Strict validator for whitespace-sensitive input.
//
// The constructor slurps all of standard input into `buffer`; the read*
// members then consume it left to right, asserting that the layout (single
// spaces, newlines, end of file) and the value ranges are exactly as
// requested. Any violation fails an assert.
struct input_checker {
    std::string buffer;  // everything read from stdin ('\r' dropped when IGNORE_CR is defined)
    int pos;             // index of the next unread character in buffer
    const std::string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const std::string number = "0123456789";
    const std::string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const std::string lower = "abcdefghijklmnopqrstuvwxyz";

    // Reads standard input until end of file.
    input_checker() {
        pos = 0;
        while (true) {
            int c = std::cin.get();
            if (c == std::istream::traits_type::eof()) {
                break;
            }
#ifdef IGNORE_CR
            if (c == '\r') {
                continue;
            }
#endif
            buffer.push_back((char) c);
        }
    }

    // Consumes and returns the characters up to (not including) the next
    // ' ' or '\n'. Requires that input is not exhausted and that the token
    // contains no other whitespace character. The token may be empty.
    std::string readOne() {
        assert(pos < (int) buffer.size());
        std::string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            // std::isspace is undefined for negative arguments other than EOF,
            // so bytes >= 0x80 must go through unsigned char first.
            assert(!std::isspace(static_cast<unsigned char>(buffer[pos])));
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    // Reads one token whose length lies in [min_len, max_len]. When `pattern`
    // is non-empty, every character of the token must occur in it.
    std::string readString(int min_len, int max_len, const std::string& pattern = "") {
        assert(min_len <= max_len);
        std::string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != std::string::npos);
        }
        return res;
    }

    // Reads one token as an int in [min_val, max_val]. The token must be
    // non-empty and must be a number in its entirety: std::stoi alone would
    // accept "12abc" as 12.
    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        const std::string token = readOne();
        assert(!token.empty());
        std::size_t used = 0;
        int res = std::stoi(token, &used);
        assert(used == token.size());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    // Same as readInt, for long long values.
    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        const std::string token = readOne();
        assert(!token.empty());
        std::size_t used = 0;
        long long res = std::stoll(token, &used);
        assert(used == token.size());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    // Reads `size` ints in [min_val, max_val] separated by single spaces
    // (no separator is consumed after the last one).
    std::vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        std::vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    // Reads `size` long longs in [min_val, max_val] separated by single spaces
    // (no separator is consumed after the last one).
    std::vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        std::vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    // Consumes exactly one ' '.
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    // Consumes exactly one '\n'.
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    // Asserts that the whole buffer has been consumed.
    void readEof() {
        assert((int) buffer.size() == pos);
    }
};
int main() {
input_checker in;
int n = in.readInt(1, 4e5);
in.readSpace();
int q = in.readInt(1, 4e5);
in.readEoln();
auto a = in.readInts(n, 1, n);
in.readEoln();
for (int i = 0; i < n; i++) {
a[i]--;
}
vector<int> cnt(n);
for (int x : a) {
cnt[x]++;
}
map<int, int> mp;
for (int x : cnt) {
mp[x]++;
}
while (q--) {
int p = in.readInt(1, n);
in.readSpace();
int x = in.readInt(1, n);
in.readEoln();
p--;
x--;
mp[cnt[a[p]]]--;
if (mp[cnt[a[p]]] == 0) {
mp.erase(cnt[a[p]]);
}
cnt[a[p]]--;
mp[cnt[a[p]]]++;
a[p] = x;
mp[cnt[a[p]]]--;
if (mp[cnt[a[p]]] == 0) {
mp.erase(cnt[a[p]]);
}
cnt[a[p]]++;
mp[cnt[a[p]]]++;
long long ans = 0;
for (auto t : mp) {
ans += (t.first * 1LL * (t.first + 1) / 2) * t.second;
}
cout << ans << '\n';
}
in.readEof();
return 0;
}
Editorialist's code (Python)
# Read the array length, the number of updates, and the array itself.
n, q = (int(tok) for tok in input().split())
a = [int(tok) for tok in input().split()]

# count[v] is the number of positions currently holding v (values are 1..n).
count = [0] * (n + 1)

# The k-th copy of any value contributes k to the minimum sum, so the total
# can be accumulated while counting.
total = 0
for value in a:
    count[value] += 1
    total += count[value]

for _ in range(q):
    idx, new_value = (int(tok) for tok in input().split())
    idx -= 1
    old_value = a[idx]

    # Removing the last copy of old_value takes away its contribution.
    total -= count[old_value]
    count[old_value] -= 1

    # The fresh copy of new_value contributes its new rank.
    a[idx] = new_value
    count[new_value] += 1
    total += count[new_value]

    print(total)