PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: raysh07
Tester: tabr
Editorialist: iceknight1093
DIFFICULTY:
Simple
PREREQUISITES:
None
PROBLEM:
You’re given an array A of length N, initially filled with zeros.
Process N updates to it, each one giving you X and Y and asking you to set A_X := Y.
After each update, compute f(A) as follows:
- Replace each 0 in A with any positive integer of your choice.
- Then, compute the number of pairs (i, j) such that i \lt j and A_i = A_j, i.e., the number of pairs of equal elements in A.
- f(A) is the maximum possible value of this count.
EXPLANATION:
First, let’s see how we can compute f(A) quickly for a fixed array A.
Let’s ignore the zeros for now, and focus only on the existing non-zero elements.
We want the number of pairs among them that are equal.
Let \text{freq}[x] denote the number of times x appears in the array.
Then, the number of pairs of indices that both contain x is exactly \frac{\text{freq}[x]\cdot (\text{freq}[x]-1)}{2}.
So, among non-zero indices, the number of equal pairs is \sum_{x=1}^N \frac{\text{freq}[x]\cdot (\text{freq}[x]-1)}{2}.
Now, we have to think about what to do with the zeros.
It’s not hard to see that since we want to maximize the number of equal pairs, it’s optimal to set all the zeros to the same value, say y.
If there are k zeros, and we set them all to y, the number of additional new pairs we create is exactly \frac{k\cdot (k-1)}{2} + k\cdot \text{freq}[y].
The first term comes from pairs within the new copies of y, while the second comes from pairs that involve one new copy and one existing copy.
Since k is a constant, and our aim is to maximize this quantity, clearly we should choose whichever y has the maximum \text{freq}[y].
So, we can now compute f(A) for a fixed array A in linear time:
- First, compute the \text{freq} array of frequencies of non-zero elements.
- Then, find the maximum element of this array, let it be M.
- If there are k zeros in the array, f(A) will equal \sum_{x=1}^N \frac{\text{freq}[x]\cdot (\text{freq}[x]-1)}{2} + \frac{k\cdot (k-1)}{2} + k\cdot M.
Of course, running this in linear time after every update is going to be too slow.
However, observe that upon an update that sets A_X := Y, very little actually changes:
- k decreases by 1, since we have one fewer zero.
- \text{freq}[Y] increases by 1.
- So, only one term changes in the \sum_{x=1}^N \frac{\text{freq}[x]\cdot (\text{freq}[x]-1)}{2} expression.
Thus, instead of recomputing the whole thing, we can update the sum in \mathcal{O}(1) with the new value of the term corresponding to Y.
- M, the maximum value of \text{freq}, either remains the same or becomes equal to \text{freq}[Y] (after \text{freq}[Y] has been increased by 1).
That is, we can simply set M = \max(M, \text{freq}[Y]).
So, in constant time, all three of k, M, and the summation can be updated.
This means the actual answer can also be updated in constant time, giving us an algorithm that runs in \mathcal{O}(N) time.
TIME COMPLEXITY:
\mathcal{O}(N) per testcase.
CODE:
Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());
// Handles one test case: reads n, then n updates "x y" from stdin, and after
// each update prints the largest possible number of equal pairs once every
// still-zero cell is filled optimally. Answers are space-separated and the
// line is terminated with '\n' after the last one.
//
// 64-bit types are spelled out explicitly; this is what the file-level
// `#define int long long` expands the plain `int` declarations to anyway.
void Solve()
{
    long long n;
    cin >> n;

    vector<long long> freq(n + 1, 0);  // freq[v] = cells currently holding value v
    long long equal_pairs = 0;         // equal pairs among the non-zero cells
    long long best = 0;                // largest entry of freq so far

    for (long long step = 1; step <= n; ++step) {
        long long x, y;
        cin >> x >> y;  // x must be consumed, but the answer does not depend on it

        // The new copy of y pairs up with each copy that was already present.
        equal_pairs += freq[y];
        ++freq[y];
        best = max(best, freq[y]);

        // Fill all remaining zeros with the most frequent value: they pair
        // with each of its `best` copies and with one another.
        const long long zeros = n - step;
        const long long gained = best * zeros + zeros * (zeros - 1) / 2;

        cout << equal_pairs + gained << (step == n ? '\n' : ' ');
    }
}
// Program entry point: reads the number of test cases, runs Solve() once per
// case, and reports the total wall-clock time on stderr.
//
// The return type is written as int32_t because this file defines `int` as a
// macro for `long long`, which would otherwise change main's return type.
int32_t main()
{
    const auto started = std::chrono::high_resolution_clock::now();

    // Decouple C++ streams from C stdio and from each other for faster I/O.
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    // Uncomment to redirect the standard streams to files:
    // freopen("in", "r", stdin);
    // freopen("out", "w", stdout);

    int t = 1;
    cin >> t;
    for (int tc = 0; tc < t; ++tc)
    {
        //cout << "Case #" << (tc + 1) << ": ";
        Solve();
    }

    const auto finished = std::chrono::high_resolution_clock::now();
    const auto nanos =
        std::chrono::duration_cast<std::chrono::nanoseconds>(finished - started).count();
    cerr << "Time measured: " << nanos * 1e-9 << " seconds.\n";
    return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
// Prints the answers for one test case: for every prefix of the n updates,
// the maximum number of equal pairs obtainable after filling the cells that
// are still zero. Answers are separated by spaces and the line ends in '\n'.
//
//   n - number of updates (values in y are used as indices into a table of size n + 1)
//   x - update positions; not consulted by the computation
//   y - y[i] is the value written by the i-th update
void solve(int n, vector<int>& x, vector<int>& y) {
    vector<int> freq(n + 1, 0);  // freq[v] = how many cells hold value v
    int best = 0;                // largest frequency seen so far
    long long equal_pairs = 0;   // equal pairs among the non-zero cells

    for (int step = 0; step < n; ++step) {
        int& occurrences = freq[y[step]];
        equal_pairs += occurrences;  // new copy pairs with every earlier copy
        ++occurrences;
        if (occurrences > best) {
            best = occurrences;
        }

        // Turning all remaining zeros into the most frequent value grows its
        // group from `best` to `best + zeros`; swap the old pair count of
        // that group for the new one.
        const int zeros = n - step - 1;
        const long long before = 1LL * best * (best - 1) / 2;
        const long long after = 1LL * (best + zeros) * (best + zeros - 1) / 2;

        cout << equal_pairs - before + after << (step + 1 == n ? '\n' : ' ');
    }
}
////////////////////////////////////////
#define IGNORE_CR
// Strict reader used to validate the format of a test file.
//
// The constructor slurps all of standard input into `buffer`; the read*
// methods then consume it token by token, asserting on any deviation
// (wrong separator, stray whitespace, value out of range, malformed number).
struct input_checker {
    string buffer;  // entire input (carriage returns dropped when IGNORE_CR is defined)
    int pos;        // index of the next unread character in `buffer`

    // Ready-made alphabets for the `pattern` argument of readString().
    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    // Reads standard input until end-of-file.
    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            // Compare against the stream's own EOF marker rather than a literal -1.
            if (c == char_traits<char>::eof()) {
                break;
            }
#ifdef IGNORE_CR
            if (c == '\r') {
                continue;
            }
#endif
            buffer.push_back((char) c);
        }
    }

    // Returns the characters up to (not including) the next ' ' or '\n'.
    // Asserts that input remains and that the token holds no other whitespace.
    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            // isspace() is undefined for negative values other than EOF, so
            // bytes >= 0x80 must go through unsigned char first.
            assert(!isspace(static_cast<unsigned char>(buffer[pos])));
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    // Reads one token and asserts its length is within [min_len, max_len] and,
    // when `pattern` is non-empty, that every character occurs in `pattern`.
    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    // Reads one token as an int and asserts min_val <= value <= max_val.
    // stoi throws if the token is not a number or does not fit in an int.
    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        const string token = readOne();
        size_t used = 0;
        int res = stoi(token, &used);
        // stoi stops at the first non-numeric character; a validator must
        // reject tokens such as "12abc" instead of reading them as 12.
        assert(used == token.size());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    // Same as readInt(), for long long values.
    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        const string token = readOne();
        size_t used = 0;
        long long res = stoll(token, &used);
        assert(used == token.size());  // no trailing garbage after the number
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    // Reads `size` ints separated by exactly one space each (no trailing space).
    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    // Reads `size` long longs separated by exactly one space each.
    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    // Consumes exactly one ' '.
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    // Consumes exactly one '\n'.
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    // Asserts that the whole input has been consumed.
    void readEof() {
        assert((int) buffer.size() == pos);
    }
};
// Validator + solver driver: checks that standard input follows the expected
// format exactly, and prints the answers for each test case as it is read.
int main() {
    input_checker in;

    int tests_left = in.readInt(1, 1e4);
    in.readEoln();

    int sn = 0;  // sum of n over all test cases
    for (; tests_left > 0; --tests_left) {
        const int n = in.readInt(1, 2e5);
        in.readEoln();
        sn += n;

        // Each update is a line "x y" with both numbers in [1, n].
        vector<int> x(n), y(n);
        for (int i = 0; i < n; ++i) {
            x[i] = in.readInt(1, n);
            in.readSpace();
            y[i] = in.readInt(1, n);
            in.readEoln();
        }

        // All n positions must be pairwise distinct.
        assert((int) set<int>(x.begin(), x.end()).size() == n);

        solve(n, x, y);
    }

    cerr << sn << endl;
    assert(sn <= 2e5);  // bound on the total input size across test cases
    in.readEof();
    return 0;
}
Editorialist's code (Python)
def solve_case(n, values):
    """Return the answer after each update of one test case.

    n      -- number of updates (and the size of the array)
    values -- values[i] is the value Y written by the i-th update, 1 <= Y <= n

    The i-th entry of the result is the maximum number of equal pairs that
    can be obtained after the first i + 1 updates, once every cell that is
    still zero has been filled with the currently most frequent value.
    """
    freq = [0] * (n + 1)   # freq[v] = number of cells holding value v
    mxfreq, sm = 0, 0      # largest frequency; equal pairs among non-zero cells
    answers = []
    for i, y in enumerate(values):
        sm += freq[y]      # the new copy pairs with every existing copy of y
        freq[y] += 1
        mxfreq = max(mxfreq, freq[y])
        rem = n - 1 - i    # cells that are still zero
        answers.append(sm + rem * mxfreq + rem * (rem - 1) // 2)
    return answers


def main():
    """Read all test cases from stdin and print one answer per line."""
    for _ in range(int(input())):
        n = int(input())
        values = []
        for _ in range(n):
            # The position x has to be read but does not affect the answer.
            _x, y = map(int, input().split())
            values.append(y)
        for answer in solve_case(n, values):
            print(answer)


if __name__ == "__main__":
    main()