# MAXEQEASY - Editorial

Author: raysh07
Tester: tabr
Editorialist: iceknight1093

# DIFFICULTY:

Simple

# PREREQUISITES:

None

# PROBLEM:

You’re given an array A of length N, whose elements are integers between 0 and N.

Compute f(A) as follows:

• Replace each 0 in A with any positive integer of your choice.
• Then, compute the number of pairs (i, j) such that i < j and A_i = A_j, i.e., the number of pairs of equal elements in A.
• f(A) is the maximum possible value of this count.

# EXPLANATION:

Let’s ignore the zeros for now, and focus only on the existing non-zero elements.
We want the number of pairs among them that are equal.
Let \text{freq}[x] denote the number of times x appears in the array.
Then, the number of pairs of indices (i, j) with i < j and A_i = A_j = x is exactly

\frac{\text{freq}[x]\cdot (\text{freq}[x]-1)}{2}

So, among non-zero indices, the number of equal pairs is

\sum_{x=1}^N \frac{\text{freq}[x]\cdot (\text{freq}[x]-1)}{2}
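As a small illustration of this sum, here is a minimal Python sketch. The helper name `count_equal_pairs` is hypothetical (it is not part of either solution below), and it uses `collections.Counter` in place of the \text{freq} array.

```python
from collections import Counter

def count_equal_pairs(values):
    """Count pairs (i, j), i < j, with values[i] == values[j], skipping zeros."""
    freq = Counter(v for v in values if v != 0)
    # A value occurring c times contributes c * (c - 1) / 2 pairs.
    return sum(c * (c - 1) // 2 for c in freq.values())

# 1 appears three times (3 pairs), 2 appears twice (1 pair), 3 once (0 pairs).
print(count_equal_pairs([1, 2, 1, 1, 3, 2]))  # prints 4
```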

Now, we have to think about what to do with the zeros.
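Since we want to maximize the number of equal pairs, it’s optimal to set all the zeros to the same value, say y. Indeed, suppose some zero was assigned a value that ends up with c_2 occurrences, while another value has c_1 \geq c_2 occurrences. Reassigning that zero to the other value loses c_2 - 1 pairs and gains c_1, which is a net increase.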

If there are k zeros, and we set them all to y, the number of additional new pairs we create is exactly

\frac{k\cdot (k-1)}{2} + k\cdot\text{freq}[y]

The first term comes from pairs within the new copies of y, while the second comes from pairs that involve one new copy and one existing copy.

Since k is fixed, only the second term depends on y, so we should choose a y with the maximum \text{freq}[y]. (If the array has no non-zero elements, every choice gives \text{freq}[y] = 0.)
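The formula for the number of new pairs can be checked on a small example. The sketch below is only an illustration: the function names and the sample array are made up for this purpose. It compares the formula against filling in the zeros and recounting.

```python
from collections import Counter

def pairs(values):
    """Number of pairs of equal elements in a list."""
    return sum(c * (c - 1) // 2 for c in Counter(values).values())

def extra_pairs_formula(a, y):
    """New pairs created when every 0 in `a` becomes y, per the formula above."""
    k = a.count(0)
    freq_y = a.count(y)  # existing copies of y (y is positive, so zeros are not counted)
    return k * (k - 1) // 2 + k * freq_y

def extra_pairs_direct(a, y):
    """Same quantity, obtained by actually filling the zeros and recounting."""
    before = pairs([v for v in a if v != 0])
    after = pairs([y if v == 0 else v for v in a])
    return after - before

a = [1, 0, 2, 1, 0]  # k = 2, freq[1] = 2, freq[2] = 1
for y in (1, 2, 3):
    print(y, extra_pairs_formula(a, y), extra_pairs_direct(a, y))
```

For this array the two functions agree for each y, giving 5, 3 and 1 new pairs for y = 1, 2, 3 respectively, so the most frequent value y = 1 is the best choice.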

So, we can now compute f(A) for a fixed array A in linear time:

• First, compute the \text{freq} array of frequencies of non-zero elements.
• Then, find the maximum value in \text{freq}, i.e., the highest frequency of any non-zero value; let it be M (with M = 0 if there are no non-zero elements).
• If there are k zeros in the array, f(A) will equal
\frac{k\cdot (k-1)}{2} + k\cdot M + \sum_{x=1}^N \frac{\text{freq}[x]\cdot (\text{freq}[x]-1)}{2}
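To gain some confidence in the closed form, one option is to compare it against an exhaustive search on tiny arrays. The following is a rough sketch under that assumption; `f_formula` and `f_brute` are hypothetical names, and the brute force is exponential, so it is only meant for testing.

```python
from collections import Counter
from itertools import product

def pairs(values):
    """Number of pairs of equal elements in a list."""
    return sum(c * (c - 1) // 2 for c in Counter(values).values())

def f_formula(a):
    """f(A) via the closed form: k(k-1)/2 + k*M + pairs among non-zero elements."""
    nonzero = [v for v in a if v != 0]
    k = len(a) - len(nonzero)
    m = max(Counter(nonzero).values(), default=0)  # M = 0 if nothing is non-zero
    return k * (k - 1) // 2 + k * m + pairs(nonzero)

def f_brute(a):
    """f(A) by trying every assignment of the zeros (exponential; tiny inputs only)."""
    zero_pos = [i for i, v in enumerate(a) if v == 0]
    k = len(zero_pos)
    top = max(a, default=0)
    # Existing values plus k fresh ones cover every assignment up to relabeling.
    candidates = sorted(set(v for v in a if v != 0)) + [top + j for j in range(1, k + 1)]
    best = 0
    for choice in product(candidates, repeat=k):
        filled = list(a)
        for pos, val in zip(zero_pos, choice):
            filled[pos] = val
        best = max(best, pairs(filled))
    return best

print(f_formula([2, 2, 0, 3, 3, 3, 0]), f_brute([2, 2, 0, 3, 3, 3, 0]))  # prints 11 11
```

Here k = 2 and M = 3, so the formula gives 1 + 6 + 4 = 11, matching the exhaustive search.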

# TIME COMPLEXITY:

\mathcal{O}(N) per testcase.

# CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18

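void Solve()
{
    int n; cin >> n;

    vector <int> a(n);
    for (auto &x : a) cin >> x;

    vector <int> f(n + 1, 0); // f[x] = copies of x seen so far
    int mx = 0;               // largest frequency of a non-zero value
    int z = 0;                // number of zeros
    int ans = 0;
    for (auto x : a) if (!x){
        z++;
    }
    for (auto x : a) if (x){
        // each new copy of x pairs with all earlier copies of x
        ans += f[x]++;
        mx = max(mx, f[x]);
    }

    ans += mx * (z);        // zeros paired with the most frequent value
    ans += z * (z - 1) / 2; // zeros paired with each other

    cout << ans << "\n";
}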

int32_t main()
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);

    cin >> t;
    for(int i = 1; i <= t; i++)
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n";
    return 0;
}

Editorialist's code (Python)
for _ in range(int(input())):
    n = int(input())
    a = list(map(int, input().split()))
    freq = [0]*(n+1)
    for x in a: freq[x] += 1

    k, m = a.count(0), max(freq[1:])
    ans = k*(k-1)//2 + k*m
    for x in freq[1:]: ans += x*(x-1)//2
    print(ans)