ALLEQ - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: iceknight1093
Tester: mexomerf
Editorialist: iceknight1093

DIFFICULTY:

easy-medium

PREREQUISITES:

Math

PROBLEM:

An integer X is called A-good if \gcd(A_i, X) = \gcd(A_j, X) for all 1 \leq i \lt j \leq N.
f(A) denotes the number of A-good integers between 1 and M.
Given N and M, find the sum of f(A) across all arrays containing N integers between 1 and M.

EXPLANATION:

Let B_i = \gcd(A_i, X).
Observe that no matter what the array A is, B_i will be a factor of X.
So, if all the elements of B are equal (which is what X being A-good means), their common value must itself be a factor of X.

Let’s fix X and a factor of X, say d, and try to count the number of arrays such that taking the GCD of every element with X results in d.


First, observe that we can now solve for each A_i independently: that is, if there are P choices for what A_1 can be, then every other index will have those exact same P possibilities too.
So, we can try to find P for A_1, and then the count in this case is just P^N.
So, we focus on that.
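The reduction to P^N can be sanity-checked on tiny inputs. The sketch below is only an illustration: `brute_force_sum` and `sum_of_powers` are hypothetical names, neither function applies the modulus, and both are far too slow for real constraints.

```cpp
#include <cassert>
#include <numeric>
#include <vector>

// Enumerates all M^N arrays directly and counts pairs (A, X) with X being A-good.
long long brute_force_sum(int n, int m) {
    assert(n >= 1 && m >= 1);
    std::vector<int> a(n, 1);
    long long total = 0;
    while (true) {
        for (int x = 1; x <= m; ++x) {
            bool good = true;
            for (int i = 1; i < n && good; ++i)
                good = std::gcd(a[i], x) == std::gcd(a[0], x);
            if (good) ++total;
        }
        // Advance the array like an odometer with digits in [1, m].
        int pos = 0;
        while (pos < n && a[pos] == m) { a[pos] = 1; ++pos; }
        if (pos == n) break;
        ++a[pos];
    }
    return total;
}

// Sums P^N over every X and every divisor d of X,
// where P = number of values a in [1, m] with gcd(a, X) == d.
long long sum_of_powers(int n, int m) {
    assert(n >= 1 && m >= 1);
    long long total = 0;
    for (int x = 1; x <= m; ++x) {
        for (int d = 1; d <= x; ++d) {
            if (x % d != 0) continue;
            long long p = 0;
            for (int a = 1; a <= m; ++a)
                if (std::gcd(a, x) == d) ++p;
            long long term = 1;
            for (int i = 0; i < n; ++i) term *= p;
            total += term;
        }
    }
    return total;
}
```

For example, with N = 2 and M = 2 both functions return 6: X = 1 is good for all four arrays, and X = 2 is good only for (1, 1) and (2, 2).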

Since X is a multiple of d, we have X = k_1\cdot d for some positive integer k_1.
A_1 must also be a multiple of d, say A_1 = k_2\cdot d.

Note that \gcd(X, A_1) = \gcd(k_1d, k_2d) = d\cdot\gcd(k_1, k_2), which equals d if and only if \gcd(k_1, k_2) = 1.
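As a quick check of this equivalence, the following sketch compares both conditions exhaustively over a small range. The helper name `identity_holds` is hypothetical, and the code assumes C++17 for `std::gcd`.

```cpp
#include <cassert>
#include <numeric>

// Returns true if, for all 1 <= k1, k2 <= limit,
// gcd(k1 * d, k2 * d) == d holds exactly when gcd(k1, k2) == 1.
bool identity_holds(int d, int limit) {
    assert(d >= 1 && limit >= 1);
    for (int k1 = 1; k1 <= limit; ++k1) {
        for (int k2 = 1; k2 <= limit; ++k2) {
            long long x = 1LL * k1 * d;
            long long a = 1LL * k2 * d;
            bool gcd_is_d = std::gcd(x, a) == d;
            bool coprime = std::gcd(k1, k2) == 1;
            if (gcd_is_d != coprime) return false;
        }
    }
    return true;
}
```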

k_1 is fixed, and A_1 = k_2\cdot d \leq M, so we really just want to count the number of positive integers k_2 such that k_2 \leq \left\lfloor \frac{M}{d} \right\rfloor and \gcd(k_1, k_2) = 1.

To do this, we use the inclusion-exclusion principle.
Let the distinct primes dividing k_1 be p_1, p_2, \ldots, p_m.
Note that a value of k_2 is valid if and only if it contains none of these primes as a factor.

So, do the following:

  • Start with all integers from 1 to \left\lfloor \frac{M}{d} \right\rfloor.
  • Subtract out the multiples of each single p_i.
  • Add back the multiples of each product of two of the p_i, which were subtracted out twice.
  • Subtract out the multiples of each product of three of the p_i, which were subtracted thrice and then added back thrice.
    \vdots

Essentially, subtract the multiples of every product of an odd number of the p_i, and add the multiples of every product of an even number of them.

The end result of this is the value P, so add P^N to the answer.
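The inclusion-exclusion steps above can be sketched as a standalone function. The name `count_coprime` and its parameters are hypothetical; the sketch assumes `primes` holds the distinct primes dividing k_1 and that their product fits in a `long long`, which holds whenever they all divide a single value up to M.

```cpp
#include <cassert>
#include <vector>

// Counts the integers in [1, limit] that are divisible by none of the given primes.
// Each subset of primes (encoded as a bitmask) contributes
// (-1)^(subset size) * floor(limit / product of the subset).
long long count_coprime(long long limit, const std::vector<int>& primes) {
    assert(limit >= 0 && primes.size() < 31);
    const int m = static_cast<int>(primes.size());
    long long count = 0;
    for (int mask = 0; mask < (1 << m); ++mask) {
        long long prod = 1;
        int sign = 1;
        for (int i = 0; i < m; ++i) {
            if ((mask >> i) & 1) {
                prod *= primes[i];
                sign = -sign;
            }
        }
        count += sign * (limit / prod);
    }
    return count;
}
```

For instance, `count_coprime(10, {2, 3})` evaluates 10 - 5 - 3 + 1 = 3, matching the three values 1, 5 and 7.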

To quickly prime factorize a number, store the prime divisors of all numbers from 1 to M using a sieve.
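One way to build such a table is sketched below. The function name `distinct_prime_divisors` is hypothetical, and the sketch uses `std::vector` throughout for simplicity.

```cpp
#include <cassert>
#include <vector>

// Returns a table where result[v] lists the distinct primes dividing v,
// in increasing order, for every 0 <= v <= max_value.
std::vector<std::vector<int>> distinct_prime_divisors(int max_value) {
    assert(max_value >= 1);
    std::vector<std::vector<int>> result(max_value + 1);
    for (int p = 2; p <= max_value; ++p) {
        // A non-empty list means a smaller prime already divides p, so p is composite.
        if (!result[p].empty()) continue;
        for (int multiple = p; multiple <= max_value; multiple += p) {
            result[multiple].push_back(p);
        }
    }
    return result;
}
```

Since every prime p is appended to its \frac{M}{p} multiples, building the table takes roughly M\log\log M steps, which is dominated by the rest of the algorithm.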


Let’s analyze the complexity of the above algorithm.
For each pair (X, d) where d divides X, if \frac X d has k distinct prime factors, we do 2^k work by iterating through each mask of these factors.

At a surface level, this seems somewhat hard to analyze beyond a cursory “there aren’t too many prime factors” bound.
However, we can do better!

Note that since we iterate through products of distinct prime factors of \frac X d, every element we iterate through is a factor of \frac X d.
This means, for a fixed d, we’re essentially iterating through all values of \frac X d (of which there are \frac M d), and then across (at most) all factors of each such value.

That is, our complexity is bounded by

\sum_{d=1}^M \sum_{y=1}^{\frac M d} \tau(y)

where \tau(y) denotes the number of divisors of y.

The harmonic lemma tells us that the inner summation is \displaystyle\mathcal{O}\left(\frac{M}{d}\log{\frac{M}{d}}\right).

Applying the harmonic lemma again, and bounding \log{\frac M d} above by just \log M, we see that the sum of this across all d from 1 to M is bounded above by \mathcal{O}(M\log^2 M) which is good enough for our purposes.
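Written out, the chain of bounds looks roughly as follows. This is a sketch that uses the standard estimate H_L \leq 1 + \ln L for the harmonic numbers and writes L_d for \left\lfloor \frac{M}{d} \right\rfloor (a symbol introduced here only for brevity).

```latex
\begin{aligned}
\sum_{d=1}^{M} \sum_{y=1}^{L_d} \tau(y)
  &= \sum_{d=1}^{M} \sum_{z=1}^{L_d} \left\lfloor \frac{L_d}{z} \right\rfloor
     && \text{(count each divisor } z \text{ by its multiples)} \\
  &\leq \sum_{d=1}^{M} \frac{M}{d}\left(1 + \ln\frac{M}{d}\right)
     && \text{(harmonic bound on the inner sum)} \\
  &\leq (1 + \ln M)\, M \sum_{d=1}^{M} \frac{1}{d}
     && \left(\ln\tfrac{M}{d} \leq \ln M\right) \\
  &\leq M\,(1 + \ln M)^2 = \mathcal{O}(M\log^2 M).
\end{aligned}
```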

TIME COMPLEXITY:

\mathcal{O}(M\log^2 M + M\log M \log N) per testcase.

CODE:

Author's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

/**
 * Integers modulo p, where p is a prime
 * Source: Aeren (modified from tourist?)
 *         Modmul for 64-bit mod from kactl:ModMulLL
 * Works with p < 7.2e18 with x87 80-bit long double, and p < 2^52 ~ 4.5e12 with 64-bit
 */
template<typename T>
struct Z_p{
	using Type = typename decay<decltype(T::value)>::type;
	static vector<Type> MOD_INV;
	constexpr Z_p(): value(){ }
	template<typename U> Z_p(const U &x){ value = normalize(x); }
	template<typename U> static Type normalize(const U &x){
		Type v;
		if(-mod() <= x && x < mod()) v = static_cast<Type>(x);
		else v = static_cast<Type>(x % mod());
		if(v < 0) v += mod();
		return v;
	}
	const Type& operator()() const{ return value; }
	template<typename U> explicit operator U() const{ return static_cast<U>(value); }
	constexpr static Type mod(){ return T::value; }
	Z_p &operator+=(const Z_p &otr){ if((value += otr.value) >= mod()) value -= mod(); return *this; }
	Z_p &operator-=(const Z_p &otr){ if((value -= otr.value) < 0) value += mod(); return *this; }
	template<typename U> Z_p &operator+=(const U &otr){ return *this += Z_p(otr); }
	template<typename U> Z_p &operator-=(const U &otr){ return *this -= Z_p(otr); }
	Z_p &operator++(){ return *this += 1; }
	Z_p &operator--(){ return *this -= 1; }
	Z_p operator++(int){ Z_p result(*this); *this += 1; return result; }
	Z_p operator--(int){ Z_p result(*this); *this -= 1; return result; }
	Z_p operator-() const{ return Z_p(-value); }
	template<typename U = T>
	typename enable_if<is_same<typename Z_p<U>::Type, int>::value, Z_p>::type &operator*=(const Z_p& rhs){
		#ifdef _WIN32
		uint64_t x = static_cast<int64_t>(value) * static_cast<int64_t>(rhs.value);
		uint32_t xh = static_cast<uint32_t>(x >> 32), xl = static_cast<uint32_t>(x), d, m;
		asm(
			"divl %4; \n\t"
			: "=a" (d), "=d" (m)
			: "d" (xh), "a" (xl), "r" (mod())
		);
		value = m;
		#else
		value = normalize(static_cast<int64_t>(value) * static_cast<int64_t>(rhs.value));
		#endif
		return *this;
	}
	template<typename U = T>
	typename enable_if<is_same<typename Z_p<U>::Type, int64_t>::value, Z_p>::type &operator*=(const Z_p &rhs){
		uint64_t ret = static_cast<uint64_t>(value) * static_cast<uint64_t>(rhs.value) - static_cast<uint64_t>(mod()) * static_cast<uint64_t>(1.L / static_cast<uint64_t>(mod()) * static_cast<uint64_t>(value) * static_cast<uint64_t>(rhs.value));
		value = normalize(static_cast<int64_t>(ret + static_cast<uint64_t>(mod()) * (ret < 0) - static_cast<uint64_t>(mod()) * (ret >= static_cast<uint64_t>(mod()))));
		return *this;
	}
	template<typename U = T>
	typename enable_if<!is_integral<typename Z_p<U>::Type>::value, Z_p>::type &operator*=(const Z_p &rhs){
		value = normalize(value * rhs.value);
		return *this;
	}
	template<typename U>
	Z_p &operator^=(U e){
		if(e < 0) *this = 1 / *this, e = -e;
		Z_p res = 1;
		for(; e; *this *= *this, e >>= 1) if(e & 1) res *= *this;
		return *this = res;
	}
	template<typename U>
	Z_p operator^(U e) const{
		return Z_p(*this) ^= e;
	}
	Z_p &operator/=(const Z_p &otr){
		Type a = otr.value, m = mod(), u = 0, v = 1;
		if(a < (int)MOD_INV.size()) return *this *= MOD_INV[a];
		while(a){
			Type t = m / a;
			m -= t * a; swap(a, m);
			u -= t * v; swap(u, v);
		}
		assert(m == 1);
		return *this *= u;
	}
	template<typename U> friend const Z_p<U> &abs(const Z_p<U> &v){ return v; }
	Type value;
};
template<typename T> bool operator==(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value == rhs.value; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator==(const Z_p<T>& lhs, U rhs){ return lhs == Z_p<T>(rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator==(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) == rhs; }
template<typename T> bool operator!=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return !(lhs == rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator!=(const Z_p<T> &lhs, U rhs){ return !(lhs == rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator!=(U lhs, const Z_p<T> &rhs){ return !(lhs == rhs); }
template<typename T> bool operator<(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value < rhs.value; }
template<typename T> bool operator>(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value > rhs.value; }
template<typename T> bool operator<=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value <= rhs.value; }
template<typename T> bool operator>=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value >= rhs.value; }
template<typename T> Z_p<T> operator+(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator+(const Z_p<T> &lhs, U rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator+(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T> Z_p<T> operator-(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator-(const Z_p<T>& lhs, U rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator-(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T> Z_p<T> operator*(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator*(const Z_p<T>& lhs, U rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator*(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T> Z_p<T> operator/(const Z_p<T> &lhs, const Z_p<T> &rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator/(const Z_p<T>& lhs, U rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator/(U lhs, const Z_p<T> &rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T> istream &operator>>(istream &in, Z_p<T> &number){
	typename common_type<typename Z_p<T>::Type, int64_t>::type x;
	in >> x;
	number.value = Z_p<T>::normalize(x);
	return in;
}
template<typename T> ostream &operator<<(ostream &out, const Z_p<T> &number){ return out << number(); }

/*
using ModType = int;
struct VarMod{ static ModType value; };
ModType VarMod::value;
ModType &mod = VarMod::value;
using Zp = Z_p<VarMod>;
*/

// constexpr int mod = 1e9 + 7; // 1000000007
constexpr int mod = (119 << 23) + 1; // 998244353
// constexpr int mod = 1e9 + 9; // 1000000009
using Zp = Z_p<integral_constant<decay<decltype(mod)>::type, mod>>;

template<typename T> vector<typename Z_p<T>::Type> Z_p<T>::MOD_INV;
template<typename T = integral_constant<decay<decltype(mod)>::type, mod>>
void precalc_inverse(int SZ){
	auto &inv = Z_p<T>::MOD_INV;
	if(inv.empty()) inv.assign(2, 1);
	for(; inv.size() <= SZ; ) inv.push_back((mod - 1LL * mod / (int)inv.size() * inv[mod % (int)inv.size()]) % mod);
}

template<typename T>
vector<T> precalc_power(T base, int SZ){
	vector<T> res(SZ + 1, 1);
	for(auto i = 1; i <= SZ; ++ i) res[i] = res[i - 1] * base;
	return res;
}

template<typename T>
vector<T> precalc_factorial(int SZ){
	vector<T> res(SZ + 1, 1); res[0] = 1;
	for(auto i = 1; i <= SZ; ++ i) res[i] = res[i - 1] * i;
	return res;
}

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    const int maxM = 5e5 + 10;
    vector<basic_string<int>> primes(maxM);
    for (int i = 2; i < maxM; ++i) if (primes[i].empty()) {
        for (int j = i; j < maxM; j += i)
            primes[j] += i;
    }

    int t; cin >> t;
    while (t--) {
        int n, m; cin >> n >> m;
        Zp ans = 0;

        // Here x plays the role of d and y the role of X, so y/x is k_1.
        for (int x = 1; x <= m; ++x) for (int y = x; y <= m; y += x) {
            int k = primes[y/x].size();
            int ct = 0;
            for (int mask = 0; mask < 1 << k; ++mask) {
                int prod = 1, mul = 1;
                for (int i = 0; i < k; ++i) {
                    if ((mask >> i) & 1) {
                        prod *= primes[y/x][i];
                        mul *= -1;
                    }
                }
                
                // Multiples of prod are bad
                ct += mul * ((m / x) / prod);
            }

            ans += Zp(ct) ^ n;
        }
        cout << ans << '\n';
    }
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define md 998244353
#define int long long
#define N 500001
vector<int> dp[N];
int modex(int a, int b){
    if(b == 0){
        return 1;
    }
    int temp = modex(a, b / 2);
    temp *= temp;
    temp %= md;
    if(b % 2){
        temp *= a;
        temp %= md;
    }
    return temp;
}
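// Inclusion-exclusion over subsets of the primes in v.
// cnt: index of the prime being decided, num: number of primes chosen so far,
// temp: product of the chosen primes, k: upper bound M / d.
// Adds (-1)^num * floor(k / temp) to res for every subset.
void cal(vector<int> &v, int cnt, int k, int num, int &res, int temp){
    if(cnt == v.size()){
        int tempp = (1 - 2 * (num % 2)) * (k / temp);
        res += tempp;
        return;
    }
    cal(v, cnt + 1, k, num, res, temp);
    cal(v, cnt + 1, k, num + 1, res, temp * v[cnt]);
}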
int32_t main() {
    for(int i = 2; i < N; i++){
        if(dp[i].size() == 0){
            int x = i;
            while(x < N){
                dp[x].push_back(i);
                x += i;
            }
        }
    }
	int t;
	cin>>t;
	while(t--){
	    int n, m;
	    cin>>n>>m;
	    int ans = 0;
	    for(int i = 1; i <= m; i++){
	        // i plays the role of d; x = cnt * i runs over the multiples of d up to m.
	        int x = i;
	        int cnt = 1;
	        while(x <= m){
	            // res = number of k_2 <= m / i that are coprime to cnt = x / i.
	            int res = 0;
	            cal(dp[cnt], 0, m / i, 0, res, 1);
	            ans += modex(res, n);
	            ans %= md;
	            x += i;
	            cnt++;
	        }
	    }
	    cout<<ans<<"\n";
	}
}

I understood everything up to the bitmask part. Can anyone please explain how to count the coprime values without iterating over bitmasks for every value?
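One possible way to avoid bitmasks, which is not the approach used in either solution above, is to replace the subset enumeration with the Möbius function \mu: the number of k_2 \leq L coprime to k_1 equals the sum of \mu(e)\cdot\left\lfloor \frac{L}{e} \right\rfloor over all divisors e of k_1. This is the same inclusion-exclusion, just indexed by squarefree divisors instead of masks. A minimal sketch, with hypothetical names `mobius_table` and `coprime_counts`:

```cpp
#include <cassert>
#include <vector>

// mu[n] for 1 <= n <= limit, using the fact that the sum of mu over
// all divisors of n is 0 for every n > 1.
std::vector<int> mobius_table(int limit) {
    assert(limit >= 1);
    std::vector<int> mu(limit + 1, 0);
    mu[1] = 1;
    for (int i = 1; i <= limit; ++i)
        for (int j = 2 * i; j <= limit; j += i)
            mu[j] -= mu[i];
    return mu;
}

// counts[k] = number of integers in [1, limit] coprime to k, for 1 <= k <= limit.
// Every e with mu[e] != 0 contributes mu[e] * floor(limit / e) to each multiple k of e.
std::vector<long long> coprime_counts(int limit) {
    const std::vector<int> mu = mobius_table(limit);
    std::vector<long long> counts(limit + 1, 0);
    for (int e = 1; e <= limit; ++e) {
        if (mu[e] == 0) continue;
        for (int k = e; k <= limit; k += e)
            counts[k] += mu[e] * (limit / e);
    }
    return counts;
}
```

For a fixed d, calling `coprime_counts(M / d)` would give P for every k_1 at once as `counts[k_1]`. In a full solution the Möbius table would presumably be computed once up to M rather than rebuilt for each d.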