PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: S. Manuj Nanthan
Preparer: Souradeep Paul
Testers: Satyam, Jatin Garg
Editorialist: Nishank Suresh
DIFFICULTY:
1248
PREREQUISITES:
Basic math, frequency maps (or sorting).
PROBLEM:
Given an array A, count the number of pairs (i, j) such that 1 \leq i \lt j \leq N and \gcd(A_i, A_j) = \text{lcm}(A_i, A_j).
EXPLANATION:
The observation to be made here is as follows: for positive integers x and y, \gcd(x, y) = \text{lcm}(x, y) if and only if x = y.
Proof
We have the following inequalities:
- \gcd(x, y) \leq x and \gcd(x, y) \leq y
- \text{lcm}(x, y) \geq x and \text{lcm}(x, y) \geq y
Putting them together,
\gcd(x, y) \leq x, y \leq \text{lcm}(x, y)
So, when \gcd(x, y) = \text{lcm}(x, y), x and y must both also be equal to this value.
Conversely, if x = y, of course we have \gcd(x, y) = \text{lcm}(x, y) = x = y.
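The Python sketch below is a quick empirical sanity check of the observation. It is not a substitute for the proof. It compares \gcd and \text{lcm} for every pair of positive integers up to an arbitrary bound. The helper name gcd_equals_lcm and the bound LIMIT = 60 are illustrative choices, and math.lcm requires Python 3.9 or later.

```python
from math import gcd, lcm  # math.lcm is available from Python 3.9


def gcd_equals_lcm(x: int, y: int) -> bool:
    """Return True when gcd(x, y) == lcm(x, y)."""
    return gcd(x, y) == lcm(x, y)


# Exhaustively check "gcd(x, y) == lcm(x, y) iff x == y"
# for all pairs of positive integers up to a small bound.
LIMIT = 60
for x in range(1, LIMIT + 1):
    for y in range(1, LIMIT + 1):
        assert gcd_equals_lcm(x, y) == (x == y), (x, y)
print("claim holds for all 1 <= x, y <=", LIMIT)
```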
So, the problem reduces to counting the pairs (i, j) with i \lt j and A_i = A_j.
Note that this still needs to be done faster than \mathcal{O}(N^2) to pass the time limit.
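For reference, the following is the direct \mathcal{O}(N^2) approach, which is too slow for large N. It is a minimal sketch, and count_pairs_brute is a hypothetical name. It tests the original gcd/lcm condition on every pair, so it can check a faster solution on small inputs.

```python
from math import gcd, lcm  # math.lcm is available from Python 3.9


def count_pairs_brute(a: list[int]) -> int:
    """O(N^2) reference: test every pair (i, j) with i < j directly."""
    n = len(a)
    count = 0
    for i in range(n):
        for j in range(i + 1, n):
            if gcd(a[i], a[j]) == lcm(a[i], a[j]):
                count += 1
    return count


print(count_pairs_brute([1, 2, 1, 2, 2]))  # 4
```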
This can be done in several ways, though perhaps the easiest is as follows:
- Build a frequency map of the elements of A (using map/unordered_map in C++, TreeMap/HashMap in Java, or dict in Python), then iterate across its entries.
- If an element x has a frequency of f_x, then it contributes f_x\cdot(f_x-1)/2 pairs to the answer (the number of ways to choose 2 of its f_x occurrences), so sum this value across all x.
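The following minimal Python sketch implements these two steps and assumes the array has already been read into a list. It uses collections.Counter as the frequency map in place of the manual dict from the editorialist's code. The function name count_equal_pairs is hypothetical.

```python
from collections import Counter


def count_equal_pairs(a: list[int]) -> int:
    """Count pairs i < j with a[i] == a[j] using a frequency map."""
    freq = Counter(a)  # value -> number of occurrences
    # A value occurring f times forms f * (f - 1) / 2 unordered pairs.
    return sum(f * (f - 1) // 2 for f in freq.values())


print(count_equal_pairs([1, 2, 1, 2, 2]))  # 4
```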
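Building the frequency map takes \mathcal{O}(N) expected time with a hash map (unordered_map, HashMap, dict), or \mathcal{O}(N\log N) with an ordered map (map, TreeMap). Iterating across it then takes \mathcal{O}(N), since it has at most N distinct keys.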
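The prerequisites also list sorting as an alternative to a frequency map. The sketch below shows that variant in \mathcal{O}(N\log N). It assumes the array is already in a Python list, and count_equal_pairs_sorted is a hypothetical name. After sorting, equal values sit next to each other, so each element pairs with the equal elements before it in its run. Adding run - 1 for every element of a run of length f_x gives 0 + 1 + \dots + (f_x-1) = f_x\cdot(f_x-1)/2.

```python
def count_equal_pairs_sorted(a: list[int]) -> int:
    """Count pairs i < j with a[i] == a[j] by sorting and scanning runs."""
    b = sorted(a)  # O(N log N); equal values become contiguous
    total = 0
    run = 0  # length of the current run of equal values, including b[i]
    for i, x in enumerate(b):
        if i > 0 and x == b[i - 1]:
            run += 1
        else:
            run = 1
        # x pairs with each of the (run - 1) equal values before it.
        total += run - 1
    return total


print(count_equal_pairs_sorted([1, 2, 1, 2, 2]))  # 4
```

This avoids hashing entirely, at the cost of the \log N factor from sorting.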
Note that the answer can be as large as N(N-1)/2 (when all elements are equal), which won't fit in a 32-bit integer for large N. Make sure to use a 64-bit integer datatype, both for the answer and for the intermediate product f_x\cdot(f_x-1).
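The overflow claim can be checked with a small calculation. The sketch below uses N = 10^5 purely as a hypothetical bound for illustration, since the actual constraints are not restated here. It also finds the smallest N whose worst-case answer exceeds the signed 32-bit maximum. The names max_answer and INT32_MAX are illustrative.

```python
INT32_MAX = 2**31 - 1  # largest value of a signed 32-bit integer


def max_answer(n: int) -> int:
    """Largest possible answer: all n elements equal gives n*(n-1)/2 pairs."""
    return n * (n - 1) // 2


n = 10**5  # hypothetical bound on N, for illustration only
print(max_answer(n))              # 4999950000
print(max_answer(n) > INT32_MAX)  # True

# Smallest N for which the worst-case answer no longer fits in 32 bits.
smallest = next(k for k in range(1, 10**6) if max_answer(k) > INT32_MAX)
print(smallest)  # 65537
```

The same reasoning explains the 1LL factor in the C++ code. With y around 10^5, the product y*(y-1) would already overflow a 32-bit int before the division by 2.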
TIME COMPLEXITY:
\mathcal{O}(N) or \mathcal{O}(N\log N) per test case.
CODE:
Editorialist's code (C++)
```cpp
#include <iostream>
#include <map>
using namespace std;

int main() {
    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        map<int, int> mp;  // value -> frequency
        for (int i = 0; i < n; ++i) {
            int x; cin >> x;
            mp[x] += 1;
        }
        long long int ans = 0;
        for (auto [x, y] : mp) {
            // y copies of x form y*(y-1)/2 pairs; 1LL forces 64-bit arithmetic
            ans += 1LL*y*(y-1)/2;
        }
        cout << ans << '\n';
    }
}
```
Editorialist's code (Python)
```python
for _ in range(int(input())):
    n = int(input())
    a = list(map(int, input().split()))
    d = dict()  # value -> frequency
    for x in a:
        if x not in d:
            d[x] = 1
        else:
            d[x] += 1
    ans = 0
    for x in d.values():
        # a value with frequency x forms x*(x-1)/2 pairs
        ans += x*(x-1)//2
    print(ans)
```
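A randomized stress test can give extra confidence that the frequency-count solution matches the original gcd/lcm definition. The following minimal sketch compares the two on many small random arrays. The names brute, fast, and stress, the trial count, the seed, and the value range are all illustrative assumptions. Values are kept small so that equal elements actually occur.

```python
import random
from collections import Counter
from math import gcd, lcm  # math.lcm is available from Python 3.9


def brute(a):
    """O(N^2): check gcd == lcm for every pair (i, j), i < j, directly."""
    return sum(
        1
        for i in range(len(a))
        for j in range(i + 1, len(a))
        if gcd(a[i], a[j]) == lcm(a[i], a[j])
    )


def fast(a):
    """O(N) expected: sum f*(f-1)/2 over the value frequencies."""
    return sum(f * (f - 1) // 2 for f in Counter(a).values())


def stress(trials=500, seed=1):
    rng = random.Random(seed)  # fixed seed so any failure is reproducible
    for _ in range(trials):
        n = rng.randint(1, 30)
        # Small value range so that equal elements actually occur.
        a = [rng.randint(1, 10) for _ in range(n)]
        assert brute(a) == fast(a), a
    return True


print(stress())  # True
```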