PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: raysh07
Tester: tabr
Editorialist: iceknight1093
DIFFICULTY:
Simple
PREREQUISITES:
None
PROBLEM:
You’re given an array A.
Compute f(A) as follows:
- Replace each 0 in A with any positive integer of your choice.
- Then, count the pairs (i, j) such that i \lt j and A_i = A_j, i.e., the number of pairs of equal elements in A.
- f(A) is the maximum possible value of this count.
EXPLANATION:
Let’s ignore the zeros for now, and focus only on the existing non-zero elements.
We want the number of pairs among them that are equal.
Let \text{freq}[x] denote the number of times x appears in the array.
Then, the number of pairs of indices that both contain x is exactly
\binom{\text{freq}[x]}{2} = \frac{\text{freq}[x]\cdot(\text{freq}[x] - 1)}{2}.
So, among non-zero indices, the number of equal pairs is
\sum_{x \neq 0} \binom{\text{freq}[x]}{2}.
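This sum translates directly into a few lines of code. The following is a minimal Python sketch, assuming the array is given as a list of integers in which 0 marks a blank; the helper name `existing_pairs` is hypothetical and used only for illustration.

```python
from collections import Counter

def existing_pairs(a):
    """Hypothetical helper: count pairs i < j with a[i] == a[j] != 0."""
    freq = Counter(x for x in a if x != 0)  # freq[x] for non-zero x only
    # Each value x contributes C(freq[x], 2) = freq[x] * (freq[x] - 1) / 2 pairs.
    return sum(c * (c - 1) // 2 for c in freq.values())

print(existing_pairs([1, 2, 1, 0, 1, 2]))  # prints 4 (3 pairs of 1s, 1 pair of 2s)
```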
Now, we have to think about what to do with the zeros.
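Since we want to maximize the number of equal pairs, it is optimal to set all the zeros to the same value, say y. To see why, suppose some zeros become u and others become v, labeled so that after filling, v appears c_v times and u appears c_u \leq c_v times. Changing one of the zeros that became u into v destroys c_u - 1 pairs and creates c_v new ones, a net gain of c_v - c_u + 1 \geq 1, so an optimal assignment never splits the zeros between two values.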
If there are k zeros and we set them all to y, the number of new pairs we create is exactly
\binom{k}{2} + k\cdot\text{freq}[y].
The first term comes from pairs within the new copies of y, while the second comes from pairs that involve one new copy and one existing copy.
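The gain formula can be sanity-checked against a direct count on a small example. This is a sketch under the same list-with-zeros assumption; `gain` and `direct_pairs` are hypothetical helpers, and `direct_pairs` is a quadratic brute force meant only for cross-checking.

```python
def gain(k, freq_y):
    """Hypothetical helper: new equal pairs from setting k zeros to a value y that occurs freq_y times."""
    # C(k, 2) pairs among the new copies, plus k * freq_y (new copy, existing copy) pairs.
    return k * (k - 1) // 2 + k * freq_y

def direct_pairs(a):
    """Hypothetical brute force: O(N^2) count of pairs i < j with a[i] == a[j]."""
    return sum(1 for i in range(len(a)) for j in range(i + 1, len(a)) if a[i] == a[j])

a = [3, 3, 5, 0, 0]                       # k = 2 zeros, freq[3] = 2
nonzero = [x for x in a if x != 0]
filled = [3 if x == 0 else x for x in a]  # set every zero to y = 3
print(direct_pairs(filled) - direct_pairs(nonzero), gain(2, 2))  # prints 5 5
```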
Since k is fixed, \binom{k}{2} + k\cdot\text{freq}[y] increases with \text{freq}[y], so we should choose y to be a value with the maximum \text{freq}[y]. If there are no non-zero elements at all, any y works and this maximum is 0.
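Both claims (fill every zero with one common value, and pick the most frequent value) can be checked by exhaustive search on tiny arrays. The sketch below is an illustrative stress test, not part of the reference solutions; `pairs`, `brute_f`, and `greedy_fill` are hypothetical names. The brute force only tries labels 1 through max(A) + k, which covers every assignment up to renaming the values that do not already occur.

```python
from collections import Counter
from itertools import product

def pairs(a):
    """Hypothetical helper: O(N^2) count of pairs i < j with a[i] == a[j]."""
    return sum(1 for i in range(len(a)) for j in range(i + 1, len(a)) if a[i] == a[j])

def brute_f(a):
    """Hypothetical brute force: try every way to fill the zeros (exponential, tiny inputs only)."""
    zeros = [i for i, x in enumerate(a) if x == 0]
    # Labels above max(a) are interchangeable "fresh" values, so
    # 1 .. max(a) + len(zeros) covers every assignment up to relabeling.
    top = max(a, default=0) + len(zeros)
    best = 0
    for choice in product(range(1, top + 1), repeat=len(zeros)):
        b = list(a)
        for i, v in zip(zeros, choice):
            b[i] = v
        best = max(best, pairs(b))
    return best

def greedy_fill(a):
    """Hypothetical helper: set all zeros to one most frequent non-zero value."""
    nonzero = [x for x in a if x != 0]
    k = len(a) - len(nonzero)
    freq = Counter(nonzero)
    # Any value of maximum frequency works; with no non-zero values, use 1.
    y = max(freq, key=freq.get) if freq else 1
    return pairs(nonzero + [y] * k)

for a in ([1, 0, 2, 2, 0], [0, 0, 0], [1, 2, 3, 0]):
    print(a, brute_f(a), greedy_fill(a))  # the two counts agree on each line
```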
So, we can now compute f(A) for a fixed array A in linear time:
- First, compute the \text{freq} array of frequencies of non-zero elements.
- Then, find the maximum value in this array; call it M (M = 0 if there are no non-zero elements).
- If there are k zeros in the array, f(A) will equal
  \sum_{x \neq 0} \binom{\text{freq}[x]}{2} + \binom{k}{2} + k\cdot M.
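The three steps combine into a single linear-time function. This is a minimal Python sketch, assuming the array arrives as a list with 0 marking a blank; the function name `f` is chosen here for illustration. Unlike the reference solutions below, it uses `collections.Counter` instead of a frequency array of size N + 1, so it does not rely on the values being at most N.

```python
from collections import Counter

def f(a):
    """Sketch of f(A) = sum_x C(freq[x], 2) + C(k, 2) + k * M, with 0 marking a blank."""
    freq = Counter(x for x in a if x != 0)  # frequencies of non-zero values
    k = len(a) - sum(freq.values())         # number of zeros
    m = max(freq.values(), default=0)       # M; 0 if every element is zero
    existing = sum(c * (c - 1) // 2 for c in freq.values())
    return existing + k * (k - 1) // 2 + k * m

print(f([1, 0, 2, 2, 0]))  # prints 6: 1 existing pair + C(2, 2) + 2 * 2
```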
TIME COMPLEXITY:
\mathcal{O}(N) per testcase.
CODE:
Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());
void Solve()
{
    int n; cin >> n;
    vector <int> a(n);
    for (auto &x : a) cin >> x;
    vector <int> f(n + 1, 0);   // f[x] = copies of x seen so far (indexing assumes A_i <= n)
    int mx = 0;                 // M: largest frequency of a non-zero value
    int z = 0;                  // k: number of zeros
    int ans = 0;
    for (auto x : a) if (!x){
        z++;
    }
    for (auto x : a) if (x){
        ans += f[x]++;          // the new copy of x pairs with every earlier copy
        mx = max(mx, f[x]);
    }
    ans += mx * (z);            // pairs (filled zero, existing copy of y)
    ans += z * (z - 1) / 2;     // pairs among the filled zeros themselves
    cout << ans << "\n";
}
int32_t main()
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in", "r", stdin);
    // freopen("out", "w", stdout);
    cin >> t;
    for(int i = 1; i <= t; i++)
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n";
    return 0;
}
Editorialist's code (Python)
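for _ in range(int(input())):
    n = int(input())
    a = list(map(int, input().split()))
    freq = [0]*(n+1)                   # freq[0] counts zeros; freq[1:] holds non-zero frequencies
    for x in a: freq[x] += 1
    k, m = a.count(0), max(freq[1:])   # k zeros, M = maximum non-zero frequency
    ans = k*(k-1)//2 + k*m             # pairs created by filling every zero with y
    for x in freq[1:]: ans += x*(x-1)//2   # pairs among existing non-zero values
    print(ans)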