EQPAIR - Editorial


Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: S. Manuj Nanthan
Preparer: Souradeep Paul
Testers: Satyam, Jatin Garg
Editorialist: Nishank Suresh




Basic math, frequency maps (or sorting).


Given an array A, count the number of pairs (i, j) such that 1 \leq i \lt j \leq N and \gcd(A_i, A_j) = \text{lcm}(A_i, A_j).


The observation to be made here is as follows:

\gcd(x, y) = \text{lcm}(x, y) \iff x = y

We have the following inequalities:

  • \gcd(x, y) \leq x and \gcd(x, y) \leq y
  • \text{lcm}(x, y) \geq x and \text{lcm}(x, y) \geq y

Putting them together,
\gcd(x, y) \leq x, y \leq \text{lcm}(x, y)

So, when \gcd(x, y) = \text{lcm}(x, y), x and y must both also be equal to this value.
Conversely, if x = y, of course we have \gcd(x, y) = \text{lcm}(x, y) = x = y.
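The equivalence above can also be sanity-checked by brute force over small values. The following is only a quick sketch; the helper names `lcm` and `claim_holds` and the bound used are made up for illustration and are not part of the problem or the reference solutions.

```python
from math import gcd


def lcm(x, y):
    # lcm via the identity gcd(x, y) * lcm(x, y) = x * y (positive integers).
    return x * y // gcd(x, y)


def claim_holds(limit):
    # Check "gcd(x, y) == lcm(x, y) if and only if x == y" for all 1 <= x, y <= limit.
    for x in range(1, limit + 1):
        for y in range(1, limit + 1):
            if (gcd(x, y) == lcm(x, y)) != (x == y):
                return False
    return True


print(claim_holds(200))
```

This is not a proof, of course; the inequality argument above is what establishes the claim in general.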

So, the problem reduces to simply counting the number of pairs of equal elements in A.
Note that this still needs to be done faster than \mathcal{O}(N^2) to pass the time limit.

This can be done in several ways, though perhaps the easiest is as follows:

  • Build a frequency map of the elements of A (using map/unordered_map in C++, TreeMap/HashMap in Java, or dict in Python) and then iterate across its entries.
  • If an element x has a frequency of f_x, then it contributes f_x\cdot(f_x-1)/2 pairs to the answer, so sum this value across all x.

Building the frequency map takes \mathcal{O}(N) expected time with a hash map, or \mathcal{O}(N\log N) with an ordered map. Iterating across it then takes \mathcal{O}(N), since it has at most N distinct keys.
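One of the other ways mentioned above is sorting: after sorting, equal elements form contiguous runs, and a run of length f contributes f\cdot(f-1)/2 pairs. A minimal sketch of that alternative follows; the function name `count_equal_pairs_sorted` is hypothetical and chosen only for this example.

```python
def count_equal_pairs_sorted(a):
    # Sort a copy so that equal values become adjacent.
    b = sorted(a)
    ans = 0
    run = 0  # length of the current run of equal values
    for i, x in enumerate(b):
        if i > 0 and x == b[i - 1]:
            run += 1
        else:
            # The previous run (if any) has ended; add its pairs.
            ans += run * (run - 1) // 2
            run = 1
    # Account for the final run.
    ans += run * (run - 1) // 2
    return ans


print(count_equal_pairs_sorted([1, 2, 1, 3, 1]))  # 3
```

This runs in \mathcal{O}(N\log N) because of the sort, and needs no map at all.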

Note that the answer can be as large as N\cdot(N-1)/2 (when all elements are equal), which is on the order of N^2 and won't fit in a 32-bit integer for large N. Make sure to use a 64-bit integer datatype.
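To get a feel for when this matters, the worst case can be compared against the signed 32-bit limit. This is a rough illustration only: the value N = 10^5 below is an assumed example size, not a constraint taken from the problem statement, and the helper names are hypothetical.

```python
INT32_MAX = 2**31 - 1  # largest value of a signed 32-bit integer


def max_pairs(n):
    # Worst case: all n elements are equal, so every pair counts.
    return n * (n - 1) // 2


def fits_in_int32(n):
    return max_pairs(n) <= INT32_MAX


print(max_pairs(10**5))      # 4999950000
print(fits_in_int32(10**5))  # False
print(fits_in_int32(65536))  # True
print(fits_in_int32(65537))  # False
```

Python integers never overflow, so this only matters in languages with fixed-width integers such as C++ or Java, where the sum should be kept in a `long long` or `long`.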


\mathcal{O}(N) or \mathcal{O}(N\log N) per test case.


Editorialist's code (C++)
#include <iostream>
#include <map>
using namespace std;

int main() {
	int t; cin >> t;
	while (t--) {
	    int n; cin >> n;
	    // Frequency of each value in the array.
	    map<int, int> mp;
	    for (int i = 0; i < n; ++i) {
	        int x; cin >> x;
	        mp[x] += 1;
	    }
	    // A value with frequency y contributes y*(y-1)/2 pairs; accumulate in 64 bits.
	    long long int ans = 0;
	    for (auto [x, y] : mp) {
	        ans += 1LL*y*(y-1)/2;
	    }
	    cout << ans << '\n';
	}
}
Editorialist's code (Python)
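for _ in range(int(input())):
    n = int(input())
    a = list(map(int, input().split()))
    # Frequency of each value in the array.
    d = dict()
    for x in a:
        if x not in d:
            d[x] = 1
        else:
            d[x] += 1
    # A value with frequency x contributes x*(x-1)//2 pairs.
    ans = 0
    for x in d.values():
        ans += x*(x-1)//2
    print(ans)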