PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: wuhudsm
Testers: iceknight1093, tabr
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
The inclusion-exclusion principle
PROBLEM:
Given integers A, B, L, R, count the number of integers x such that L \leq x \leq R and either \gcd(x, A) = 1 or \gcd(x, B) = 1.
EXPLANATION:
First, let’s solve a simpler problem: count the x in [L, R] that are coprime to A alone.
It is easier to count the complement, i.e. the x that are not coprime to A; the number of coprime x is then R-L+1 minus this count.
This is a classic application of the inclusion-exclusion principle, utilizing the fact that A \leq 10^9 means A has very few prime factors (in fact, it’ll have \leq 9 distinct prime factors).
How?
Let p_1, p_2, \ldots, p_k be the distinct primes dividing A (which can be computed in \mathcal{O}(\sqrt{A}) using basic square-root factorization).
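A minimal sketch of that square-root factorization in Python, assuming n \geq 1. The name distinct_prime_factors is illustrative; it plays the same role as prime_factor in the editorialist's code.

```python
def distinct_prime_factors(n):
    """Return the distinct primes dividing n (n >= 1), by trial division up to sqrt(n)."""
    primes = []
    d = 2
    while d * d <= n:
        if n % d == 0:
            primes.append(d)
            while n % d == 0:  # strip every copy of d, so later divisors found are new primes
                n //= d
        d += 1
    if n > 1:  # the leftover has no factor <= its square root, so it is prime
        primes.append(n)
    return primes


print(distinct_prime_factors(360))        # [2, 3, 5]
print(distinct_prime_factors(223092870))  # [2, 3, 5, 7, 11, 13, 17, 19, 23]
```

The primes come out in increasing order, and at most about \sqrt{A} \approx 31623 trial divisors are tried when A \leq 10^9.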
Clearly, \gcd(x, A) \gt 1 if and only if at least one of the p_i divides x.
So, let’s count the number of integers in [L, R] that are a multiple of one of the p_i.
This is fairly simple: if you fix p_i, it has \displaystyle \left\lfloor\frac{R}{p_i} \right\rfloor - \left\lfloor\frac{L-1}{p_i} \right\rfloor multiples in this range, so add this to the answer.
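The prefix-count formula for multiples in a range, as a one-line Python helper (the name multiples_in_range is hypothetical). It assumes L \geq 1, so L-1 \geq 0 and floor division behaves as in the formula.

```python
def multiples_in_range(L, R, m):
    """Number of multiples of m (m >= 1) in [L, R]: floor(R/m) - floor((L-1)/m)."""
    return R // m - (L - 1) // m


print(multiples_in_range(10, 20, 3))     # 3  (12, 15, 18)
print(multiples_in_range(1, 10**18, 7))  # 142857142857142857
```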
However, notice that if something is divisible by both p_1 and p_2, we’ve counted it twice. In fact, this applies to any integer that’s a multiple of p_i and p_j for i \neq j.
So, for each 1 \leq i \lt j \leq k, subtract \displaystyle \left\lfloor\frac{R}{p_i\cdot p_j} \right\rfloor - \left\lfloor\frac{L-1}{p_i\cdot p_j} \right\rfloor from the answer.
But now notice that if something is divisible by exactly three of the p_i, we’ve added it three times (once per prime) and subtracted it three times (once per pair), so it isn’t counted at all anymore!
So, for each product of three primes, add the count of its multiples in the range to the answer.
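As a small sanity check of these three stages (a hypothetical example, not taken from the problem's tests), take A = 30 = 2\cdot 3\cdot 5, so k = 3, and [L, R] = [1, 30]:

```latex
\underbrace{(15 + 10 + 6)}_{\text{multiples of } 2,\,3,\,5}
\;-\; \underbrace{(5 + 3 + 2)}_{\text{multiples of } 6,\,10,\,15}
\;+\; \underbrace{1}_{\text{multiples of } 30}
\;=\; 31 - 10 + 1 \;=\; 22
```

So 22 integers in [1, 30] share a factor with 30, leaving 30 - 22 = 8 coprime ones (1, 7, 11, 13, 17, 19, 23, 29), which matches \varphi(30) = 8.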
It’s not too hard to see that this alternating sequence of additions and subtractions continues until you reach the product of all k primes.
The correctness of this is formalized by the inclusion-exclusion principle, which leads to a solution that is extremely straightforward to state:
Fix a non-empty subset S of the primes. Let M be the product of the elements of S.
- If |S| is odd, add \left\lfloor\frac{R}{M} \right\rfloor - \left\lfloor\frac{L-1}{M} \right\rfloor to the answer.
- Otherwise, subtract \left\lfloor\frac{R}{M} \right\rfloor - \left\lfloor\frac{L-1}{M} \right\rfloor from the answer.
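The subset rule above can be sketched in Python with bitmasks, where bit i of mask says whether p_i is in S. This is a minimal sketch assuming primes holds distinct primes; count_not_coprime is an illustrative name.

```python
def count_not_coprime(L, R, primes):
    """Count x in [L, R] divisible by at least one of the distinct `primes`,
    by inclusion-exclusion over every non-empty subset S (encoded as a bitmask)."""
    k = len(primes)
    total = 0
    for mask in range(1, 1 << k):          # every non-empty subset S
        M = 1
        for i in range(k):
            if mask >> i & 1:
                M *= primes[i]             # M = product of the primes in S
        cnt = R // M - (L - 1) // M        # multiples of M in [L, R]
        if bin(mask).count("1") % 2 == 1:  # |S| odd: add
            total += cnt
        else:                              # |S| even: subtract
            total -= cnt
    return total


L, R = 1, 30
print(count_not_coprime(L, R, [2, 3, 5]))                # 22
print((R - L + 1) - count_not_coprime(L, R, [2, 3, 5]))  # 8 integers coprime to 30
```

With an empty prime list (A = 1) the loop body never runs and the count is 0, which is correct since every x is coprime to 1.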
This gives us a solution in \mathcal{O}(2^k), and k \leq 9 here so this is extremely fast.
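Strictly speaking, there are 2^k subsets, and recomputing each product M from scratch costs \mathcal{O}(k) per subset, i.e. \mathcal{O}(k\cdot 2^k) overall; both are tiny here. If you want exactly \mathcal{O}(2^k), one standard trick (an addition of ours, not something the editorial's code does) is to reuse the product of the mask with its lowest set bit cleared. A minimal sketch, with subset_products as a hypothetical name:

```python
def subset_products(primes):
    """prod[mask] = product of primes[i] over the bits i set in mask, built in O(2^k)
    by extending the product of mask with its lowest set bit removed."""
    k = len(primes)
    prod = [1] * (1 << k)                     # prod[0] = 1 (empty subset)
    for mask in range(1, 1 << k):
        low = (mask & -mask).bit_length() - 1  # index of the lowest set bit
        prod[mask] = prod[mask & (mask - 1)] * primes[low]
    return prod


print(subset_products([2, 3, 5]))  # [1, 2, 3, 6, 5, 10, 15, 30]
```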
Now let’s use the above algorithm to solve the original problem.
Let c_A be the number of x \in [L, R] such that \gcd(x, A) = 1.
Let c_B be the number of x \in [L, R] such that \gcd(x, B) = 1.
Let c_{AB} be the number of x \in [L, R] such that \gcd(x, A) = 1 and \gcd(x, B) = 1.
By inclusion-exclusion on the two conditions, the final answer is c_A + c_B - c_{AB}.
Computing c_A and c_B is easy: apply the algorithm above to the primes of A (respectively B), and subtract the resulting count from R-L+1.
As for c_{AB}, note that \gcd(x, A) = 1 and \gcd(x, B) = 1 if and only if \gcd(x, AB) = 1.
So, we can apply the initial algorithm to AB and compute this too.
However, AB can be as large as 10^{18}, so directly prime factorizing it in \mathcal{O}(\sqrt{AB}) would take around 10^9 steps, which is too slow.
Instead, note that we only need to know the set of its prime factors.
This is easy: we computed the set of prime factors of A and B earlier, so simply take their union!
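AB has \leq 9+9 = 18 distinct prime factors (in fact at most 15, since the product of any 16 distinct primes exceeds 10^{18}), and \mathcal{O}(2^{k}) is easily fast enough when k \leq 18.
Every product M that appears is at most the product of all these primes, which is at most AB \leq 10^{18}, so it fits in a signed 64-bit integer.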
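Putting the pieces together, here is a minimal end-to-end sketch in Python that follows the c_A + c_B - c_{AB} plan literally. The function names are illustrative, and it assumes 1 \leq L \leq R and A, B \geq 1.

```python
def distinct_primes(n):
    """Distinct prime factors of n >= 1, by trial division up to sqrt(n)."""
    ps, d = [], 2
    while d * d <= n:
        if n % d == 0:
            ps.append(d)
            while n % d == 0:
                n //= d
        d += 1
    if n > 1:
        ps.append(n)
    return ps


def not_coprime_count(L, R, primes):
    """Count x in [L, R] divisible by at least one prime in `primes` (inclusion-exclusion)."""
    k, total = len(primes), 0
    for mask in range(1, 1 << k):
        M = 1
        for i in range(k):
            if mask >> i & 1:
                M *= primes[i]
        cnt = R // M - (L - 1) // M
        total += cnt if bin(mask).count("1") % 2 else -cnt  # odd |S| adds, even subtracts
    return total


def solve(A, B, L, R):
    """Count x in [L, R] with gcd(x, A) == 1 or gcd(x, B) == 1."""
    pa, pb = distinct_primes(A), distinct_primes(B)
    pab = sorted(set(pa) | set(pb))  # primes of A*B, without factorizing A*B itself
    n = R - L + 1
    c_a = n - not_coprime_count(L, R, pa)
    c_b = n - not_coprime_count(L, R, pb)
    c_ab = n - not_coprime_count(L, R, pab)
    return c_a + c_b - c_ab


print(solve(6, 10, 1, 10))  # 5  (x = 1, 3, 5, 7, 9)
```

Python integers are unbounded, so overflow is not a concern here; in C++ the 10^{18} bound on M noted above is what keeps long long safe.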
TIME COMPLEXITY:
\mathcal{O}(\sqrt{A} + \sqrt{B} + 2^k) per testcase, where k \leq 18.
CODE:
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif
struct input_checker {
    string buffer;
    int pos;
    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";
    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }
    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }
    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }
    string readString(int minl, int maxl, const string& pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }
    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }
    void readEof() {
        assert((int) buffer.size() == pos);
    }
};
// Distinct prime factors of |n|, by trial division up to sqrt(n).
template <typename T>
vector<T> factor(T n) {
    n = abs(n);
    vector<T> res;
    for (T i = 2; i * i <= n; i++) {
        if (n % i == 0) {
            res.emplace_back(i);
            while (n % i == 0) {
                n /= i;
            }
        }
    }
    if (n > 1) {
        res.emplace_back(n);
    }
    return res;
}
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    input_checker in;
    long long a = in.readInt(1, 1e9);
    in.readSpace();
    long long b = in.readInt(1, 1e9);
    in.readSpace();
    auto x = factor(a), y = factor(b);
    // z = union of the primes of a and b, i.e. the distinct primes of a * b
    auto z = x;
    z.insert(z.end(), y.begin(), y.end());
    sort(z.begin(), z.end());
    z.resize(unique(z.begin(), z.end()) - z.begin());
    // Calc(n, w): count of integers in [1, n] coprime to every prime in w.
    // Inclusion-exclusion over all subsets; the empty mask contributes +n.
    auto Calc = [&](long long n, vector<long long> w) {
        long long res = 0;
        int sz = (int) w.size();
        for (int mask = 0; mask < (1 << sz); mask++) {
            long long c = 1;
            for (int i = 0; i < sz; i++) {
                if (mask & (1 << i)) {
                    c *= w[i];
                }
            }
            if (__builtin_parity(mask)) {
                res -= n / c;
            } else {
                res += n / c;
            }
        }
        debug(n, w, res);
        return res;
    };
    // Solve(n): count of integers in [1, n] coprime to a or to b.
    auto Solve = [&](long long n) {
        return Calc(n, x) + Calc(n, y) - Calc(n, z);
    };
    long long l = in.readLong(1, 1e18);
    in.readSpace();
    long long r = in.readLong(1, 1e18);
    in.readEoln();
    in.readEof();
    assert(l <= r);
    cout << Solve(r) - Solve(l - 1) << '\n';
    return 0;
}
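The tester's Calc/Solve take a slightly different route from the explanation: they count over a prefix [1, n] and answer with Solve(R) - Solve(L-1), and they include the empty subset with sign +, so Calc returns the coprime count directly instead of the non-coprime count. A minimal Python sketch of that prefix formulation, taking the prime lists as input (names are illustrative):

```python
def coprime_upto(n, primes):
    """Count x in [1, n] coprime to every prime in `primes`:
    sum over all subsets S (empty one included) of (-1)^|S| * floor(n / prod(S))."""
    res = 0
    for mask in range(1 << len(primes)):
        c = 1
        for i, p in enumerate(primes):
            if mask >> i & 1:
                c *= p
        if bin(mask).count("1") % 2:  # odd |S|: subtract
            res -= n // c
        else:                         # even |S| (including empty): add
            res += n // c
    return res


def solve_prefix(n, pa, pb):
    """Count x in [1, n] coprime to A or to B, given the distinct primes pa of A and pb of B."""
    pab = sorted(set(pa) | set(pb))
    return coprime_upto(n, pa) + coprime_upto(n, pb) - coprime_upto(n, pab)


# A = 6 (primes 2, 3), B = 10 (primes 2, 5), range [1, 10]
L, R = 1, 10
print(solve_prefix(R, [2, 3], [2, 5]) - solve_prefix(L - 1, [2, 3], [2, 5]))  # 5
```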
Editorialist's code (Python)
def prime_factor(x):
    # distinct prime factors of x, by trial division up to sqrt(x)
    i = 2
    primes = []
    while i * i <= x:
        if x % i == 0:
            primes.append(i)
            while x % i == 0:
                x //= i
        i += 1
    if x > 1:
        primes.append(x)
    return primes

def calc(l, r, primes):
    # number of x in [l, r] divisible by at least one of the given primes
    sz = len(primes)
    ans = 0
    for mask in range(1, 2**sz):
        num = 1
        for i in range(sz):
            if mask & (2 ** i):
                num *= primes[i]
        parity = bin(mask)[2:].count('1') % 2
        ct = r // num - (l - 1) // num
        if parity == 1:
            ans += ct
        else:
            ans -= ct
    return ans

a, b, l, r = map(int, input().split())
pa, pb = prime_factor(a), prime_factor(b)
# ans = N_a + N_b - N_ab, where N_X counts x in [l, r] NOT coprime to X;
# the answer c_A + c_B - c_AB then equals (r - l + 1) - ans
ans = calc(l, r, pa) + calc(l, r, pb) - calc(l, r, list(set(pa + pb)))
print(r - l + 1 - ans)