# EQPAIR - Editorial

Author: S. Manuj Nanthan
Testers: Satyam, Jatin Garg
Editorialist: Nishank Suresh

# DIFFICULTY:

1248

# PREREQUISITES:

Basic math, frequency maps (or sorting).

# PROBLEM:

Given an array A, count the number of pairs (i, j) such that 1 \leq i \lt j \leq N and \gcd(A_i, A_j) = \text{lcm}(A_i, A_j).

# EXPLANATION:

The observation to be made here is as follows:

\gcd(x, y) = \text{lcm}(x, y) \iff x = y
Proof:

We have the following inequalities:

• \gcd(x, y) \leq x and \gcd(x, y) \leq y
• \text{lcm}(x, y) \geq x and \text{lcm}(x, y) \geq y

Putting them together,
\gcd(x, y) \leq x, y \leq \text{lcm}(x, y)

So, when \gcd(x, y) = \text{lcm}(x, y), x and y must both also be equal to this value.
Conversely, if x = y, of course we have \gcd(x, y) = \text{lcm}(x, y) = x = y.
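The equivalence can also be sanity-checked by brute force over small values. The sketch below assumes positive integers (the inequalities in the proof rely on this). The helper names `lcm` and `check_equivalence` are hypothetical and used only for this illustration. The check simply confirms that, in a small range, \gcd(x, y) = \text{lcm}(x, y) holds exactly on the diagonal x = y.

```python
from math import gcd


def lcm(x, y):
    # Hypothetical helper: lcm via gcd(x, y) * lcm(x, y) = x * y (positive x, y only)
    return x * y // gcd(x, y)


def check_equivalence(limit):
    """Return True if gcd(x, y) == lcm(x, y) holds exactly when x == y,
    for every pair 1 <= x, y <= limit."""
    for x in range(1, limit + 1):
        for y in range(1, limit + 1):
            if (gcd(x, y) == lcm(x, y)) != (x == y):
                return False
    return True


print(check_equivalence(50))  # True
```

A finite check is not a proof, of course; it is only a quick way to catch a mistake in the argument above.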

So, the problem reduces to simply counting the number of pairs of equal elements in A.
Note that this still needs to be done faster than \mathcal{O}(N^2) to pass the time limit.
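For testing a faster solution on small inputs, a direct \mathcal{O}(N^2) reference that checks the original condition for every pair can be handy. This is a minimal sketch assuming the array holds positive integers; `count_pairs_bruteforce` is a hypothetical name, and the function is far too slow for the real limits.

```python
from math import gcd


def count_pairs_bruteforce(a):
    """O(N^2) reference: count pairs i < j with gcd(A_i, A_j) == lcm(A_i, A_j)."""
    count = 0
    n = len(a)
    for i in range(n):
        for j in range(i + 1, n):
            g = gcd(a[i], a[j])
            l = a[i] * a[j] // g  # lcm, valid for positive integers
            if g == l:
                count += 1
    return count


print(count_pairs_bruteforce([1, 2, 1, 2, 2]))  # 4
```

In the example, the value 1 appears twice (1 pair) and the value 2 appears three times (3 pairs), giving 4.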

This can be done in several ways, though perhaps the easiest is as follows:

• Build a frequency map of the elements of A (using map/unordered_map in C++, TreeMap/HashMap in Java, or dict in Python) and then iterate across its entries.
• If an element x has a frequency of f_x, then it contributes f_x\cdot(f_x-1)/2 pairs to the answer, so sum this value across all x.
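The two steps above can be sketched compactly with Python's `collections.Counter` as the frequency map. The function name `count_equal_pairs` is hypothetical, and this is only one way to write it; the editorialist's own code at the end builds the map with a plain dict instead.

```python
from collections import Counter


def count_equal_pairs(a):
    """Count pairs i < j with A_i == A_j using a frequency map."""
    freq = Counter(a)  # value -> number of occurrences
    # A value with frequency f contributes C(f, 2) = f*(f-1)/2 pairs
    return sum(f * (f - 1) // 2 for f in freq.values())


print(count_equal_pairs([1, 2, 1, 2, 2]))  # 4
```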

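Building the frequency map takes expected \mathcal{O}(N) time with a hash map (unordered_map, HashMap, dict) or \mathcal{O}(N\log N) with a balanced tree map (map, TreeMap). Iterating across it then takes \mathcal{O}(N), since there are at most N distinct values.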
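The prerequisites also mention sorting as an alternative to a frequency map. A minimal sketch of that idea, assuming nothing beyond the array itself (`count_equal_pairs_sorted` is a hypothetical name): after sorting, equal values form contiguous runs, and a run of length f contributes f\cdot(f-1)/2 pairs, exactly as with the map. The sort dominates, giving \mathcal{O}(N\log N).

```python
def count_equal_pairs_sorted(a):
    """Count pairs i < j with A_i == A_j by sorting and measuring runs of equal values."""
    b = sorted(a)
    total = 0
    run = 0  # length of the current run of equal values
    for i, x in enumerate(b):
        if i > 0 and x == b[i - 1]:
            run += 1
        else:
            total += run * (run - 1) // 2  # close the previous run
            run = 1
    total += run * (run - 1) // 2  # close the final run
    return total


print(count_equal_pairs_sorted([1, 2, 1, 2, 2]))  # 4
```

This avoids hashing entirely, which can be a reasonable choice when worst-case guarantees matter more than constant factors.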

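Note that the answer can be as large as N\cdot(N-1)/2 (when all elements are equal), which exceeds the largest 32-bit signed integer, 2^{31}-1, once N \geq 65537. Make sure to use a 64-bit integer datatype; in C++, also make sure the product f_x\cdot(f_x-1) is itself computed in 64 bits before dividing.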
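The size of the worst-case answer can be checked with a little arithmetic. The value N = 10^5 below is a hypothetical choice for illustration only; the actual bound should be taken from the problem's constraints.

```python
INT32_MAX = 2**31 - 1


def max_answer(n):
    # Worst case: all n elements are equal, so every one of the C(n, 2) pairs counts
    return n * (n - 1) // 2


n = 10**5  # hypothetical N for illustration; check the actual constraints
print(max_answer(n))              # 4999950000
print(max_answer(n) > INT32_MAX)  # True
```

Python integers never overflow, so this only measures the magnitude; in C++ or Java the same value would wrap around in a 32-bit int. The smallest N that overflows is 65537, since 65536\cdot 65535/2 = 2147450880 still fits.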

# TIME COMPLEXITY

\mathcal{O}(N) or \mathcal{O}(N\log N) per test case.

# CODE:

Editorialist's code (C++)
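```cpp
#include <iostream>
#include <map>
using namespace std;

int main() {
    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        // Frequency map: value -> number of occurrences in A
        map<int, int> mp;
        for (int i = 0; i < n; ++i) {
            int x; cin >> x;
            mp[x] += 1;
        }
        // Each value with frequency y contributes y*(y-1)/2 pairs;
        // 1LL forces the product to be computed in 64 bits.
        long long int ans = 0;
        for (auto [x, y] : mp) {  // structured bindings require C++17
            ans += 1LL*y*(y-1)/2;
        }
        cout << ans << '\n';
    }
}
```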

Editorialist's code (Python)
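```python
for _ in range(int(input())):
    n = int(input())
    a = list(map(int, input().split()))
    # Frequency map: value -> number of occurrences in A
    d = dict()
    for x in a:
        if x not in d:
            d[x] = 1
        else:
            d[x] += 1
    # Each value with frequency x contributes x*(x-1)/2 pairs;
    # Python integers are arbitrary-precision, so overflow is not a concern.
    ans = 0
    for x in d.values():
        ans += x*(x-1)//2
    print(ans)
```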
