COUNTING - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: wuhudsm
Testers: iceknight1093, tabr
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

The inclusion-exclusion principle

PROBLEM:

Given integers A, B, L, R, count the number of integers x such that L \leq x \leq R and either \gcd(x, A) = 1 or \gcd(x, B) = 1.

EXPLANATION:

First, let’s solve a simpler problem: let’s find the number of x in [L, R] that are coprime to just A.
It’s easier to count the complement, i.e. the number of x that are not coprime to A, and subtract it from R - L + 1.

This is a classic application of the inclusion-exclusion principle, utilizing the fact that A \leq 10^9 means A has very few prime factors (in fact, it’ll have \leq 9 distinct prime factors).
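The bound of 9 can be checked by multiplying the smallest primes until the product exceeds the limit. The sketch below is only an illustration; the helper name `max_distinct_prime_factors` and the hard-coded prime list are assumptions, not part of the original solution.

```python
# The ten smallest primes; enough for limits below 2*3*5*...*29 = 6469693230.
SMALL_PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]

def max_distinct_prime_factors(limit):
    """Largest k such that the product of the k smallest primes is <= limit.

    Only valid for limit < 6469693230, since the prime list stops at 29.
    """
    product, k = 1, 0
    for p in SMALL_PRIMES:
        if product * p > limit:
            break
        product *= p
        k += 1
    return k

print(max_distinct_prime_factors(10**9))  # 9, since 2*3*...*23 = 223092870
```

Any integer with k distinct prime factors is at least the product of the k smallest primes, which is why this product decides the bound.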

How?

Let p_1, p_2, \ldots, p_k be the distinct primes dividing A (which can be computed in \mathcal{O}(\sqrt{A}) using basic square-root factorization).
Clearly, \gcd(x, A) \gt 1 if and only if at least one of the p_i divides x.
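This equivalence is easy to confirm by brute force on a small case. The value A = 84 and the helper `shares_prime_with` below are chosen purely for illustration.

```python
from math import gcd

def shares_prime_with(x, primes):
    """True if at least one of the given primes divides x."""
    return any(x % p == 0 for p in primes)

A = 84                  # 84 = 2^2 * 3 * 7
primes_of_A = [2, 3, 7]

# Collect every x where the two conditions disagree; there should be none.
mismatches = [
    x for x in range(1, 200)
    if (gcd(x, A) > 1) != shares_prime_with(x, primes_of_A)
]
print(mismatches)  # []
```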

So, let’s count the number of integers in [L, R] that are a multiple of one of the p_i.
This is fairly simple: if you fix p_i, it has \displaystyle \left\lfloor\frac{R}{p_i} \right\rfloor - \left\lfloor\frac{L-1}{p_i} \right\rfloor multiples in this range, so add this to the answer.
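As a small sketch of this counting formula (the function name `count_multiples` is an assumption made for the example):

```python
def count_multiples(lo, hi, m):
    """Number of multiples of m in [lo, hi], assuming 1 <= lo <= hi and m >= 1."""
    return hi // m - (lo - 1) // m

print(count_multiples(10, 30, 7))  # 3: the multiples are 14, 21, 28
```

The first term counts multiples in [1, hi] and the second removes those in [1, lo - 1].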

However, notice that if something is divisible by both p_1 and p_2, we’ve counted it twice. In fact, this applies to any integer that’s a multiple of p_i and p_j for i \neq j.
So, for each 1 \leq i \lt j \leq k, subtract \displaystyle \left\lfloor\frac{R}{p_i\cdot p_j} \right\rfloor - \left\lfloor\frac{L-1}{p_i\cdot p_j} \right\rfloor from the answer.

But now you’ll notice that if something is a multiple of three of the primes, say p_i, p_j and p_k, we’ve added it thrice and subtracted it thrice, so it isn’t counted anymore!
So, for each product of three primes, add the count of its multiples in the range to the answer.
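As a quick sanity check of the three steps so far, take A = 30 (primes 2, 3, 5) and the range [1, 30]. These numbers are chosen for illustration and do not come from the problem's tests.

```latex
\begin{align*}
\text{single primes: } & \lfloor 30/2 \rfloor + \lfloor 30/3 \rfloor + \lfloor 30/5 \rfloor = 15 + 10 + 6 = 31 \\
\text{pairs: } & \lfloor 30/6 \rfloor + \lfloor 30/10 \rfloor + \lfloor 30/15 \rfloor = 5 + 3 + 2 = 10 \\
\text{triple: } & \lfloor 30/30 \rfloor = 1 \\
\text{not coprime to } 30\text{: } & 31 - 10 + 1 = 22
\end{align*}
```

That leaves 30 - 22 = 8 integers coprime to 30, namely 1, 7, 11, 13, 17, 19, 23, 29.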

It’s not too hard to see that this alternating sequence of additions and subtractions will continue till you’ve reached the product of all k primes.
The correctness of this is formalized by the inclusion-exclusion principle, which leads to a solution that is extremely straightforward to state:

Fix a non-empty subset S of the primes. Let M be the product of the elements of S.

  • If |S| is odd, add \left\lfloor\frac{R}{M} \right\rfloor - \left\lfloor\frac{L-1}{M} \right\rfloor to the answer.
  • Otherwise, subtract \left\lfloor\frac{R}{M} \right\rfloor - \left\lfloor\frac{L-1}{M} \right\rfloor from the answer.

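This iterates over all 2^k - 1 non-empty subsets, giving a solution in \mathcal{O}(2^k) (or \mathcal{O}(k\cdot 2^k) if each product is recomputed from scratch), and k \leq 9 here so this is extremely fast.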
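One way to enumerate the subsets is sketched below. It is a variant of the bitmask loop, not the author's code: it carries the running product along so each subset costs constant work, and it skips any branch whose product already exceeds the upper end of the range (such products have no multiples there, given lo >= 1). The name `count_not_coprime` is an assumption.

```python
from math import gcd

def count_not_coprime(lo, hi, primes):
    """Count x in [lo, hi] divisible by at least one of the distinct primes.

    Assumes 1 <= lo <= hi.
    """
    total = 0
    # Each stack entry: (next index to decide, product so far, subset size).
    stack = [(0, 1, 0)]
    while stack:
        i, product, size = stack.pop()
        if i == len(primes):
            if size > 0:
                cnt = hi // product - (lo - 1) // product
                # Odd-sized subsets are added, even-sized ones subtracted.
                total += cnt if size % 2 == 1 else -cnt
            continue
        stack.append((i + 1, product, size))  # leave primes[i] out
        if product * primes[i] <= hi:         # otherwise no multiples in range
            stack.append((i + 1, product * primes[i], size + 1))
    return total

# Cross-check against a direct gcd count on a small range.
brute = sum(1 for x in range(10, 101) if gcd(x, 42) > 1)
print(count_not_coprime(10, 100, [2, 3, 7]) == brute)  # True
```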

Now let’s use the above algorithm to solve the original problem.

Let c_A be the number of x \in [L, R] such that \gcd(x, A) = 1.
Let c_B be the number of x \in [L, R] such that \gcd(x, B) = 1.
Let c_{AB} be the number of x \in [L, R] such that \gcd(x, A) = 1 and \gcd(x, B) = 1.

By inclusion-exclusion on these two sets, the final answer is c_A + c_B - c_{AB}.
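A brute-force reference is handy for checking this identity on small inputs. The sketch below is only meant for testing; the function name `brute_force` and the sample values are assumptions.

```python
from math import gcd

def brute_force(a, b, lo, hi):
    """Return (c_A, c_B, c_AB, direct count) by checking every x in [lo, hi]."""
    xs = range(lo, hi + 1)
    c_a = sum(1 for x in xs if gcd(x, a) == 1)
    c_b = sum(1 for x in xs if gcd(x, b) == 1)
    c_ab = sum(1 for x in xs if gcd(x, a) == 1 and gcd(x, b) == 1)
    direct = sum(1 for x in xs if gcd(x, a) == 1 or gcd(x, b) == 1)
    return c_a, c_b, c_ab, direct

c_a, c_b, c_ab, direct = brute_force(6, 10, 1, 30)
print(c_a, c_b, c_ab, direct)        # 10 12 8 14
print(c_a + c_b - c_ab == direct)    # True
```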

Computing c_A and c_B is easy; it’s a direct application of the algorithm discussed above.

As for c_{AB}, note that \gcd(x, A) = 1 and \gcd(x, B) = 1 if and only if \gcd(x, AB) = 1.
So, we can apply the initial algorithm to AB and compute this too.

However, AB can be as large as 10^{18}, so directly prime factorizing it in \mathcal{O}(\sqrt{AB}) might be too slow.
Instead note that we only need to know the set of its prime factors.
This is easy: we computed the set of prime factors of A and B earlier, so simply take their union!

AB has \leq 9+9 = 18 distinct prime factors, and \mathcal{O}(2^{k}) is easily fast enough when k \leq 18.
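Taking the union can be sketched as follows; `union_of_primes` is an illustrative name, and the inputs are assumed to be the lists of distinct primes of A and B computed earlier.

```python
def union_of_primes(primes_a, primes_b):
    """Distinct primes dividing A*B, given the distinct primes of A and of B."""
    return sorted(set(primes_a) | set(primes_b))

# Example: A = 12 = 2^2 * 3 and B = 50 = 2 * 5^2, so A*B = 600 = 2^3 * 3 * 5^2.
merged = union_of_primes([2, 3], [2, 5])
print(merged)  # [2, 3, 5]
```

Since the product of the merged primes divides AB \leq 10^{18}, every subset product fits in a signed 64-bit integer, which matters for the C++ implementation.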

TIME COMPLEXITY:

\mathcal{O}(\sqrt{A} + \sqrt{B} + 2^k) per testcase, where k \leq 18.

CODE:

Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

struct input_checker {
	string buffer;
	int pos;

	const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
	const string number = "0123456789";
	const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
	const string lower = "abcdefghijklmnopqrstuvwxyz";

	input_checker() {
		pos = 0;
		while (true) {
			int c = cin.get();
			if (c == -1) {
				break;
			}
			buffer.push_back((char) c);
		}
	}

	int nextDelimiter() {
		int now = pos;
		while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
			now++;
		}
		return now;
	}

	string readOne() {
		assert(pos < (int) buffer.size());
		int nxt = nextDelimiter();
		string res;
		while (pos < nxt) {
			res += buffer[pos];
			pos++;
		}
		return res;
	}

	string readString(int minl, int maxl, const string& pattern = "") {
		assert(minl <= maxl);
		string res = readOne();
		assert(minl <= (int) res.size());
		assert((int) res.size() <= maxl);
		for (int i = 0; i < (int) res.size(); i++) {
			assert(pattern.empty() || pattern.find(res[i]) != string::npos);
		}
		return res;
	}

	int readInt(int minv, int maxv) {
		assert(minv <= maxv);
		int res = stoi(readOne());
		assert(minv <= res);
		assert(res <= maxv);
		return res;
	}

	long long readLong(long long minv, long long maxv) {
		assert(minv <= maxv);
		long long res = stoll(readOne());
		assert(minv <= res);
		assert(res <= maxv);
		return res;
	}

	void readSpace() {
		assert((int) buffer.size() > pos);
		assert(buffer[pos] == ' ');
		pos++;
	}

	void readEoln() {
		assert((int) buffer.size() > pos);
		assert(buffer[pos] == '\n');
		pos++;
	}

	void readEof() {
		assert((int) buffer.size() == pos);
	}
};

template <typename T>
vector<T> factor(T n) {
	n = abs(n);
	vector<T> res;
	for (T i = 2; i * i <= n; i++) {
		if (n % i == 0) {
			res.emplace_back(i);
			while (n % i == 0) {
				n /= i;
			}
		}
	}
	if (n > 1) {
		res.emplace_back(n);
	}
	return res;
}

int main() {
	ios::sync_with_stdio(false);
	cin.tie(nullptr);
	input_checker in;
	long long a = in.readInt(1, 1e9);
	in.readSpace();
	long long b = in.readInt(1, 1e9);
	in.readSpace();
	auto x = factor(a), y = factor(b);
	auto z = x;
	z.insert(z.end(), y.begin(), y.end());
	sort(z.begin(), z.end());
	z.resize(unique(z.begin(), z.end()) - z.begin());
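	// Calc(n, w): number of integers in [1, n] divisible by none of the primes in w.
	// Inclusion-exclusion over all subsets of w; the empty subset (mask = 0) contributes n.
	auto Calc = [&](long long n, vector<long long> w) {
		long long res = 0;
		int sz = (int) w.size();
		for (int mask = 0; mask < (1 << sz); mask++) {
			// c = product of the primes selected by mask (divides a * b, so it fits in long long)
			long long c = 1;
			for (int i = 0; i < sz; i++) {
				if (mask & (1 << i)) {
					c *= w[i];
				}
			}
			if (__builtin_parity(mask)) {
				// odd-sized subset: subtract the multiples of c
				res -= n / c;
			} else {
				// even-sized subset: add the multiples of c
				res += n / c;
			}
		}
		debug(n, w, res);
		return res;
	};
	// Solve(n): answer for the prefix [1, n], i.e. c_A + c_B - c_AB.
	auto Solve = [&](long long n) {
		return Calc(n, x) + Calc(n, y) - Calc(n, z);
	};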
	long long l = in.readLong(1, 1e18);
	in.readSpace();
	long long r = in.readLong(1, 1e18);
	in.readEoln();
	in.readEof();
	assert(l <= r);
	cout << Solve(r) - Solve(l - 1) << '\n';
	return 0;
}
Editorialist's code (Python)
def prime_factor(x):
	i = 2
	primes = []
	while i*i <= x:
		if x%i == 0:
			primes.append(i)
			while x%i == 0: x //= i
		i += 1
	if x > 1: primes.append(x)
	return primes

def calc(l, r, primes):
	sz = len(primes)
	ans = 0
	for mask in range(1, 2**sz):
		num = 1
		for i in range(sz):
			if mask & (2 ** i): num *= primes[i]
		parity = bin(mask)[2:].count('1') % 2
		ct = r//num - (l-1)//num
		if parity == 1: ans += ct
		else: ans -= ct
	return ans

a, b, l, r = map(int, input().split())
pa, pb = prime_factor(a), prime_factor(b)
# ans = number of x in [l, r] that are coprime to neither a nor b
ans = calc(l, r, pa) + calc(l, r, pb) - calc(l, r, list(set(pa + pb)))
print(r-l+1-ans)

I am getting a slightly different answer. Can anyone help?
Thank you.

#pragma GCC target ("avx2")
#pragma GCC optimization ("unroll-loops")
#pragma GCC optimize("O2")
#include<bits/stdc++.h>
using namespace std;
#define fast ios_base::sync_with_stdio(0);cin.tie(0);cout.tie(0);
#define ordered_set tree<int, null_type,less<int>, rb_tree_tag,tree_order_statistics_node_update>
#define ll long long
#define pb push_back
#define mp make_pair
#define endl "\n"
#define int ll
#define vi vector<int>
#define vb vector<bool>
#define vvb vector<vb >
#define pii pair<int,int>
#define ss second
#define ff first
#define vpii vector<pii>
#define vvi vector<vi >
#define vs vector<string>
#define vvs vector<vs >
#define pqi priority_queue <int>
#define minpqi priority_queue <int, vector<int>, greater<int> >
#define all(x) x.begin(),x.end()
#define mii map<int,int>
#define for0(i,n) for(ll i=0;i<n;i++)
#define for1(i,n) for(ll i=1;i<=n;i++)
#define per(i,n) for(ll i=n-1;i>=0;i--)
#define per1(i,n) for(ll i=n;i>0;i--)
#define repeat(i,start,n) for(ll i=start;i<n;i++)
#define inp(arr,n) ll arr[n];rep(i,n){ cin>>arr[i];}
#define inp1(arr,n) ll arr[n+1];rep1(i,n){ cin>>arr[i];}
#define inp2d(arr,n,m) ll arr[n][m];rep(i,n)rep(j,m)cin>>arr[i][j];
#define inpg(adj,m) rep(i,m){int a,b;cin>>a>>b;adj[a].pb(b);adj[b].pb(a);}
#define print(a,n) for(ll i=0;i<n;i++){ cout<<a[i]<<" ";}
#define print1(a,n) for(ll i=1;i<=n;i++){ cout<<a[i]<<endl;}
#define mod 1000000007
#define maxx 1000000000000000000
#define PI 3.141592653589793238462643383279
#define mmax(a,b,c) max(a,max(b,c))
#define mmin(a,b,c) min(a,min(b,c))
#define init(arr,a) memset(arr,a,sizeof(arr))
#define lb lower_bound
#define ub upper_bound
#define er equal_range
#define maxe *max_element
#define mine *min_element
bool sortbysec(const pair<int, int> &a, const pair<int, int> &b) {
    if (a.ss == b.ss)
        return a.ff < b.ff;
    return (a.second < b.second);
}
string convert_to_bit(int a, int bit) {
    string s;
    while (a > 0) {
        s += (a % 2) + 48;
        a /= 2;
    }
    while (s.size() < bit)
        s += '0';
    reverse(s.begin(), s.end());
    return s;
}
int to_int(string s)
{
    int ans = 0;
    string temp;
    int i = 0;
    while (s[i] == '0')
    {
        i++;
    }
    if (i == s.size())
        return 0;
    temp.assign(s, i, s.size());
    int mul = 1;
    for (i = temp.size() - 1; i >= 0; i--)
    {
        int a = temp[i] - '0';
        ans += mul * a;
        mul *= 10;
    }
    return ans;

}
string to_string(int n)
{
    string ans = "";
    if (n == 0)
        return "0";
    while (n > 0)
    {
        int a = n % 10;
        n /= 10;
        char temp = a + '0';
        ans += temp;
    }
    reverse(all(ans));
    return ans;
}
int bin_to_dec(int n)
{
    int num = n;
    int dec_value = 0;
    int base = 1;
    int temp = num;
    while (temp) {
        int last_digit = temp % 10;
        temp = temp / 10;
        dec_value += last_digit * base;
        base = base * 2;
    }
    return dec_value;
}
int binstr_to_dec(string s)
{
    int ans = 0;
    for0(i, s.size())
    if (s[i] == '1')
        ans += pow(2, s.size() - i - 1);
    return ans;
}
int dx[] = {0, 1, 0, -1};
int dy[] = {1, 0, -1, 0};
void faltu()
{
    int a = 2;
    while (a > 0)a--;
}
int power(int x, int y)
{
    if (y == 0)return 1;
    int u = power(x, y / 2);
    u = (u * u) % mod;
    if (y % 2)u = (x * u) % mod;
    return u;

}
int inv(int x)
{
    return power(x % mod, mod - 2);
}
int ncr(int n, int r)
{
    if (n < r)return 0;
    if (n == r | r == 0)return 1;
    int numerator = 1;
    int denominator = 1;
    for1(i, r)denominator = (denominator * i) % mod;
    for0(i, r)numerator = (numerator * ((n - i) % mod)) % mod;
    return (numerator * inv(denominator)) % mod;
}
/*freopen("input.txt","r",stdin);
  freopen("output.txt","w",stdout);*/
/*#include <boost/multiprecision/cpp_int.hpp>
namespace mp = boost::multiprecision;
mp::cpp_int u = 1;*/ //this line in main,here u is the big integer
//##@@
//------------------------------------Code Starts here------------------------------------//


int solution(vector<int>&v, int l, int r) {

    int n = v.size();
    int mask = (1 << n);
    mask--;
    int res = 0;
    for (int i = 0; i < mask; i++) {

        vi temp;

        int prod = 1;
        int cnt = 0;
        for (int j = 0; j < n; j++) {
            if (i & (1ll << j)) {
                cnt++;
                temp.pb(v[j]);
                prod *= v[j];
            }
        }
        if (cnt == 0)continue;
        if (cnt % 2)res += ((r / (int)prod) - ((l - 1) / (int)prod));
        else res -= ((r / (int)prod) - ((l - 1) / (int)prod));



    }
    return res;
}
vi factorise(int x) {

    int z = x;
    vi v;
    for (int i = 2; i * i <= z; i++) {
        if (x % i == 0) {
            v.pb(i);
            while (x % i == 0)x /= i;
        }
    }

    if (x > 1)v.pb(x);

    return v;
}
vector<int>unite(vector<int>&v1, vector<int>&v2)
{
    set<int>s;
    for (auto it : v1)s.insert(it);

    for (auto it : v2)s.insert(it);

    vi v;
    for (auto it : s)v.pb(it);

    return v;
}
void solve()
{
    int a, b, l, r;
    cin >> a >> b >> l >> r;

    int n = r - l + 1;

    vi v = factorise(a);
    //v.pb(1);
    int ans1 = n - solution(v, l, r);

    vi v2 = factorise(b);
    // v2.pb(1);
    int ans2 = n - solution(v2, l, r);
    int lc = (a * b) / __gcd(a, b);
    vi v3 = factorise(lc);
    // v3.pb(1);

    int ans3 = n - solution(v3, l, r);

    cout << (ans1 + ans2 - ans3);
}
main()
{

    int t = 1;
    //cin >> t;
    while (t--)
    {
        solve();
    }
}
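One likely cause of the discrepancy in the code above: in `solution`, `mask` is set to `(1 << n) - 1` and the loop runs while `i < mask`, so the subset containing every prime is never visited. The Python sketch below (not the poster's code; the function and its `include_full_subset` flag are made up for the demonstration) shows how much that one missing term changes the count.

```python
def not_coprime_count(lo, hi, primes, include_full_subset=True):
    """Inclusion-exclusion count of x in [lo, hi] divisible by some prime.

    With include_full_subset=False the last mask (all primes selected) is
    skipped, mimicking a loop that stops at mask < (1 << n) - 1.
    """
    n = len(primes)
    upper = (1 << n) if include_full_subset else (1 << n) - 1
    total = 0
    for mask in range(1, upper):
        product, bits = 1, 0
        for j in range(n):
            if (mask >> j) & 1:
                product *= primes[j]
                bits += 1
        cnt = hi // product - (lo - 1) // product
        total += cnt if bits % 2 == 1 else -cnt
    return total

print(not_coprime_count(1, 30, [2, 3, 5]))         # 22
print(not_coprime_count(1, 30, [2, 3, 5], False))  # 21, the term for 30 is missing
```

Changing the loop condition to `i <= mask` should restore the missing term. Separately, `factorise(lc)` runs trial division on a value that can be close to 10^{18}, which may need around 10^9 iterations; merging the prime lists of `a` and `b`, as `unite` already does, avoids that.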