PERFSQRS - Editorial

PROBLEM LINK:

Practice
Contest

Setter: Shahjalal Shohag
Tester: Rahul Dugar
Editorialist: Ishmeet Singh Saggu

DIFFICULTY:

Medium-Hard

PREREQUISITES:

Generating Functions, Number Theoretic Transform, Polynomial Inverses and the Möbius Function

PROBLEM:

For a sequence of integers A_1,A_2,…,A_N, let G(A) denote a graph with N nodes (numbered 1 through N) such that for each pair of distinct nodes i and j, there is an undirected edge between these nodes if and only if A_i⋅A_j is a perfect square.

You are given a graph with N nodes (numbered 1 through N) and an integer K. Find the number of integer sequences A_1, A_2,…,A_N such that 1≤A_i≤K for each valid i and the given graph is a subgraph of G(A). Since this number might be very large, compute it modulo 998244353.

EXPLANATION:

Let F(w) denote the square-free part of w, i.e. the number obtained by dividing w by the largest perfect square that divides w.

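Let’s first solve the problem when the graph is a single connected component (of size s). The product of two numbers X and Y is a perfect square iff F(X) = F(Y). So every edge (i, j) of the given graph requires F(A_i) = F(A_j); since equality is transitive and the component is connected, this forces F(A_1) = F(A_2) = ... = F(A_s). Conversely, if all F(A_i) are equal, every pair of nodes is adjacent in G(A). Hence the necessary and sufficient condition for the given graph to be a subgraph of G(A) is F(A_1) = F(A_2) = ... = F(A_s).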
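To make the criterion concrete, here is a small sketch, assuming trial division is acceptable for the tiny inputs involved; `squareFreePart`, `isPerfectSquare` and `checkSquareProductCriterion` are hypothetical helper names, not taken from the solutions. It computes F(w) and brute-force checks that X \cdot Y is a perfect square exactly when F(X) = F(Y).

```cpp
#include <cassert>

// Square-free part F(w): w divided by the largest perfect square dividing it.
// Trial division, intended only for small illustrative inputs.
long long squareFreePart(long long w) {
    long long res = 1;
    for (long long p = 2; p * p <= w; ++p) {
        int e = 0;
        while (w % p == 0) { w /= p; ++e; }
        if (e % 2 == 1) res *= p;  // an odd exponent leaves one factor p
    }
    return res * w;  // the leftover w is 1 or a prime with exponent 1
}

// True iff n >= 0 is a perfect square (exact integer search).
bool isPerfectSquare(long long n) {
    long long r = 0;
    while ((r + 1) * (r + 1) <= n) ++r;
    return r * r == n;
}

// Checks "X*Y is a perfect square <=> F(X) == F(Y)" for all 1 <= X, Y <= limit.
bool checkSquareProductCriterion(long long limit) {
    for (long long x = 1; x <= limit; ++x)
        for (long long y = 1; y <= limit; ++y)
            if (isPerfectSquare(x * y) != (squareFreePart(x) == squareFreePart(y)))
                return false;
    return true;
}
```

For instance, F(12) = 3 and F(72) = 2, so 12 \cdot 72 = 864 is not a perfect square, while F(2) = F(8) = 2 and 2 \cdot 8 = 16 is.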

So we have to count the arrays of length s such that 1 \le A_i \le K and all F(A_i) are equal. Let’s fix the common value F(A_i) = X, which is square-free. Each A_i can then be any Z \le K such that F(Z) = X. Such Z are exactly the numbers X y^2 with y \ge 1 and X y^2 \le K, so there are exactly \lfloor \sqrt{K/X} \rfloor of them.

So the number of ways for a fixed F(A_i) = X is \lfloor \sqrt{K/X} \rfloor^s.

And the total number of ways is P(s) = \sum_{X} \lfloor \sqrt{K/X} \rfloor^s, where X ranges over the square-free numbers \le K.
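The counting argument can be checked on small K with the following sketch (all names are hypothetical, and the brute force enumerates every array, so it is only meant for tiny K and s). `countByFormula` evaluates \sum \lfloor \sqrt{K/X} \rfloor^s over square-free X, and `countBrute` counts the arrays in [1, K]^s whose pairwise products are all perfect squares.

```cpp
#include <cassert>
#include <vector>

// Square-free part F(w), repeated so that this block compiles on its own.
long long sqfPart(long long w) {
    long long res = 1;
    for (long long p = 2; p * p <= w; ++p) {
        int e = 0;
        while (w % p == 0) { w /= p; ++e; }
        if (e % 2 == 1) res *= p;
    }
    return res * w;
}

// floor(sqrt(n)) by linear search; fine for the tiny inputs used here.
long long isqrtSmall(long long n) {
    long long r = 0;
    while ((r + 1) * (r + 1) <= n) ++r;
    return r;
}

long long ipow(long long b, int e) {
    long long r = 1;
    while (e-- > 0) r *= b;
    return r;
}

// P(s) from the formula: sum over square-free X <= K of floor(sqrt(K / X))^s.
long long countByFormula(long long K, int s) {
    long long total = 0;
    for (long long X = 1; X <= K; ++X)
        if (sqfPart(X) == X) total += ipow(isqrtSmall(K / X), s);
    return total;
}

// Brute force: arrays in [1, K]^s whose pairwise products are all perfect squares.
long long countBrute(long long K, int s) {
    std::vector<long long> arr(s, 1);
    long long count = 0;
    while (true) {
        bool ok = true;
        for (int i = 0; i < s && ok; ++i)
            for (int j = i + 1; j < s && ok; ++j) {
                long long p = arr[i] * arr[j];
                long long r = isqrtSmall(p);
                ok = (r * r == p);
            }
        if (ok) ++count;
        int pos = 0;  // advance arr like an odometer with digits 1..K
        while (pos < s && arr[pos] == K) arr[pos++] = 1;
        if (pos == s) break;
        ++arr[pos];
    }
    return count;
}
```

For example, with K = 10 the square-free X are 1, 2, 3, 5, 6, 7, 10 with \lfloor \sqrt{10/X} \rfloor = 3, 2, 1, 1, 1, 1, 1, so P(2) = 9 + 4 + 5 = 18. Also P(1) = K always, because every number has exactly one square-free part.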

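As K \le 10^{14}, \lfloor \sqrt{K/X} \rfloor takes only about 10^5 distinct values as X ranges over [1, K]: the X \le t produce at most t of them and the X > t produce at most \sqrt{K/t}, so choosing t \approx (K/4)^{1/3} shows there are fewer than 9 \cdot 10^4.
Let V be a vector of tuples (l_i, r_i, c_i, v_i) such that every X in the range [l_i, r_i] gives the same value v_i = \lfloor \sqrt{K/X} \rfloor, and c_i of these X are square-free.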

We can find the ranges using binary search, since \lfloor \sqrt{K/X} \rfloor is non-increasing in X. Then we need the number of square-free numbers in each range [l_i, r_i], which is Q(r_i) - Q(l_i - 1), where Q(n) is the number of square-free numbers \le n: Q(n) = \sum_{i=1}^{\lfloor \sqrt{n} \rfloor} \mu(i) \lfloor \frac{n}{i^2} \rfloor, and \mu(i) is the Möbius function. Again, \lfloor n / i^2 \rfloor takes very few distinct values, so we group the indices i that share the same value and multiply it by the sum of \mu over the group, which comes from prefix sums of \mu precomputed up to \sqrt{K} \le 10^7. Done this way, the total number of operations is about 10^7.
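A sketch of the square-free counting, assuming a table of \mu up to \sqrt{n}. It builds \mu with a linear sieve (the setter uses a different sieve) and uses the same grouping as the setter's `g`: the last index i with \lfloor n/i^2 \rfloor = q is \lfloor \sqrt{\lfloor n/q \rfloor} \rfloor. `MobiusTable`, `countSquareFree` and `isSquareFree` are hypothetical names; the count for a range [l, r] is `countSquareFree(r, t) - countSquareFree(l - 1, t)`.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Möbius function mu[0..limit] by a linear sieve, plus prefix[i] = mu[1] + ... + mu[i].
struct MobiusTable {
    std::vector<int> mu, prefix;
    explicit MobiusTable(int limit) : mu(limit + 1, 0), prefix(limit + 1, 0) {
        std::vector<int> primes;
        std::vector<bool> composite(limit + 1, false);
        if (limit >= 1) mu[1] = 1;
        for (int i = 2; i <= limit; ++i) {
            if (!composite[i]) { primes.push_back(i); mu[i] = -1; }
            for (int p : primes) {
                if ((long long)i * p > limit) break;
                composite[i * p] = true;
                if (i % p == 0) { mu[i * p] = 0; break; }  // p^2 divides i*p
                mu[i * p] = -mu[i];
            }
        }
        for (int i = 1; i <= limit; ++i) prefix[i] = prefix[i - 1] + mu[i];
    }
};

// Exact floor(sqrt(n)), correcting floating-point error.
long long isqrtExact(long long n) {
    long long r = (long long)std::sqrt((long double)n);
    while (r > 0 && r * r > n) --r;
    while ((r + 1) * (r + 1) <= n) ++r;
    return r;
}

// Q(n) = sum_{i=1}^{floor(sqrt(n))} mu(i) * floor(n / i^2), grouping i with equal floor(n / i^2).
// Requires the table to cover indices up to floor(sqrt(n)).
long long countSquareFree(long long n, const MobiusTable& t) {
    long long total = 0;
    for (long long i = 1, last; i * i <= n; i = last + 1) {
        long long q = n / (i * i);
        last = isqrtExact(n / q);  // largest index with the same quotient q
        total += q * (t.prefix[last] - t.prefix[i - 1]);
    }
    return total;
}

// Brute-force check: no square d^2 > 1 divides x.
bool isSquareFree(long long x) {
    for (long long d = 2; d * d <= x; ++d)
        if (x % (d * d) == 0) return false;
    return true;
}
```

For n = 100 the nonzero terms are 100 - 25 - 11 - 4 + 2 - 2 + 1 = 61, which is indeed the number of square-free integers up to 100.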

There is a different way to do this first part, which you may find simpler; you can refer to the commented code of the tester.

The total number of ways is then P(s) = \sum_i c_i v_i^s.
This is the solution for a single connected component of size s.
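Putting the two previous steps together, a small sketch can build the pairs (v_i, c_i) and evaluate P(s) = \sum_i c_i v_i^s modulo 998244353. To stay short it counts the square-free numbers of each range by trial division instead of Q(r) - Q(l - 1), so it is only meant for small K; `buildPairs` and `evalP` are hypothetical names.

```cpp
#include <cassert>
#include <cmath>
#include <utility>
#include <vector>

const long long MOD = 998244353;

// Exact floor(sqrt(n)), correcting floating-point error.
long long isqrtExact(long long n) {
    long long r = (long long)std::sqrt((long double)n);
    while (r > 0 && r * r > n) --r;
    while ((r + 1) * (r + 1) <= n) ++r;
    return r;
}

bool isSquareFree(long long x) {
    for (long long d = 2; d * d <= x; ++d)
        if (x % (d * d) == 0) return false;
    return true;
}

// Pairs (v_i, c_i): the block [l, r] of X sharing v_i = floor(sqrt(K / X)) holds c_i square-free X.
std::vector<std::pair<long long, long long>> buildPairs(long long K) {
    std::vector<std::pair<long long, long long>> pairs;
    for (long long l = 1; l <= K;) {
        long long v = isqrtExact(K / l);
        long long r = K / (v * v);  // last X with the same v
        long long c = 0;
        for (long long x = l; x <= r; ++x) c += isSquareFree(x);  // simplification: trial division
        if (c > 0) pairs.push_back({v, c});
        l = r + 1;
    }
    return pairs;
}

long long powMod(long long b, long long e) {
    long long r = 1;
    b %= MOD;
    for (; e > 0; e >>= 1) {
        if (e & 1) r = r * b % MOD;
        b = b * b % MOD;
    }
    return r;
}

// P(s) = sum_i c_i * v_i^s (mod MOD).
long long evalP(const std::vector<std::pair<long long, long long>>& pairs, long long s) {
    long long total = 0;
    for (const auto& pr : pairs)
        total = (total + pr.second % MOD * powMod(pr.first, s)) % MOD;
    return total;
}
```

For K = 10 the pairs are (3, 1), (2, 1) and (1, 5), so P(2) = 18 and P(3) = 27 + 8 + 5 = 40; P(1) = K always holds, which makes a convenient sanity check.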

Now we will discuss the full solution.

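Let S be the multiset of the sizes of the connected components of the given graph, and let P(z) be the solution for a connected component of size z.

Different components impose independent conditions, so the final answer to the problem is \prod_{z \in S} P(z). (An isolated vertex is a component of size 1 and contributes P(1) = K.)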
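For a tiny graph, the product formula can be compared with a brute force over all sequences. In this sketch the components come from a union-find, P(z) is computed by grouping 1..K by square-free part (a class with m members contributes m^z), and `answerBrute` checks every edge directly. All names are hypothetical and nothing is reduced modulo 998244353, so it is only for very small N and K.

```cpp
#include <cassert>
#include <map>
#include <numeric>
#include <utility>
#include <vector>

// Square-free part F(w) by trial division (small w only).
long long sqfPart(long long w) {
    long long res = 1;
    for (long long p = 2; p * p <= w; ++p) {
        int e = 0;
        while (w % p == 0) { w /= p; ++e; }
        if (e % 2 == 1) res *= p;
    }
    return res * w;
}

// P(z) for small K: a class of m numbers sharing a square-free part contributes m^z.
long long smallP(long long K, int z) {
    std::map<long long, long long> classSize;
    for (long long w = 1; w <= K; ++w) ++classSize[sqfPart(w)];
    long long total = 0;
    for (const auto& kv : classSize) {
        long long term = 1;
        for (int i = 0; i < z; ++i) term *= kv.second;
        total += term;
    }
    return total;
}

int findRoot(std::vector<int>& parent, int x) {
    while (parent[x] != x) {
        parent[x] = parent[parent[x]];  // path halving
        x = parent[x];
    }
    return x;
}

// Product over connected components of P(component size).
long long answerByComponents(int n, long long K, const std::vector<std::pair<int, int>>& edges) {
    std::vector<int> parent(n);
    std::iota(parent.begin(), parent.end(), 0);
    for (const auto& e : edges) {
        int ra = findRoot(parent, e.first), rb = findRoot(parent, e.second);
        parent[ra] = rb;  // union the two components
    }
    std::vector<int> compSize(n, 0);
    for (int v = 0; v < n; ++v) ++compSize[findRoot(parent, v)];
    long long ans = 1;
    for (int s : compSize)
        if (s > 0) ans *= smallP(K, s);
    return ans;
}

// Brute force: every edge (i, j) needs A_i * A_j to be a perfect square.
long long answerBrute(int n, long long K, const std::vector<std::pair<int, int>>& edges) {
    std::vector<long long> a(n, 1);
    long long count = 0;
    while (true) {
        bool ok = true;
        for (const auto& e : edges) {
            long long p = a[e.first] * a[e.second], r = 0;
            while ((r + 1) * (r + 1) <= p) ++r;
            if (r * r != p) { ok = false; break; }
        }
        if (ok) ++count;
        int pos = 0;  // next sequence in odometer order
        while (pos < n && a[pos] == K) a[pos++] = 1;
        if (pos == n) break;
        ++a[pos];
    }
    return count;
}
```

With N = 3, K = 10 and a single edge between the first two nodes, the components have sizes 2 and 1, so the answer is P(2) \cdot P(1) = 18 \cdot 10 = 180.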

First, we compute the vector V using the idea above. Let SZV be its size; as noted above, SZV \le 10^5.

Then we need the answer for each connected component. The maximum size of a connected component (call it M) is at most 10^5, so how do we compute P(z) for every z from 1 to M?

Let F_i(x) be the formal power series F_i(x) = \frac{c_i}{1 - v_i x} = c_i + c_i v_i x + c_i v_i^2 x^2 + c_i v_i^3 x^3 + \ldots

So P(z) is the coefficient of x^z in \sum_i F_i(x).

We can compute this sum with divide and conquer and standard polynomial operations.
Each F_i(x) is a fraction of polynomials (numerator c_i, denominator 1 - v_i x), and so is every partial sum. We merge two children \frac{C(x)}{D(x)} and \frac{E(x)}{H(x)} into \frac{A(x)}{B(x)} = \frac{C(x)}{D(x)} + \frac{E(x)}{H(x)}, so A(x) = C(x) H(x) + E(x) D(x) and B(x) = D(x) H(x).
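A minimal sketch of the divide-and-conquer merge, assuming small inputs: `sumFractions`, `mulPoly`, `addPoly`, `seriesCoefficients` and `directCoefficient` are hypothetical names, multiplication is schoolbook O(n^2) instead of the FFT/NTT used in the solutions, and the coefficients are read off by O(n^2) long division (valid because the merged denominator has constant term 1) rather than by a fast inverse. It checks that the coefficient of x^z equals \sum_i c_i v_i^z.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

using Poly = std::vector<long long>;
const long long MOD = 998244353;

// Schoolbook product mod MOD; the solutions use FFT/NTT multiplication here.
Poly mulPoly(const Poly& a, const Poly& b) {
    Poly c(a.size() + b.size() - 1, 0);
    for (size_t i = 0; i < a.size(); ++i)
        for (size_t j = 0; j < b.size(); ++j)
            c[i + j] = (c[i + j] + a[i] * b[j]) % MOD;
    return c;
}

Poly addPoly(const Poly& a, const Poly& b) {
    Poly c(std::max(a.size(), b.size()), 0);
    for (size_t i = 0; i < a.size(); ++i) c[i] = a[i];
    for (size_t i = 0; i < b.size(); ++i) c[i] = (c[i] + b[i]) % MOD;
    return c;
}

// Returns (numerator, denominator) of sum_{i=l}^{r} c[i] / (1 - v[i] x).
std::pair<Poly, Poly> sumFractions(const std::vector<long long>& c, const std::vector<long long>& v, int l, int r) {
    if (l == r) return {Poly{c[l] % MOD}, Poly{1, (MOD - v[l] % MOD) % MOD}};
    int mid = (l + r) / 2;
    std::pair<Poly, Poly> left = sumFractions(c, v, l, mid);       // C / D
    std::pair<Poly, Poly> right = sumFractions(c, v, mid + 1, r);  // E / H
    // C/D + E/H = (C*H + E*D) / (D*H)
    return {addPoly(mulPoly(left.first, right.second), mulPoly(right.first, left.second)),
            mulPoly(left.second, right.second)};
}

// First m coefficients of num / den by long division; needs den[0] == 1, which holds here.
Poly seriesCoefficients(const Poly& num, const Poly& den, int m) {
    Poly q(m, 0);
    for (int k = 0; k < m; ++k) {
        long long val = k < (int)num.size() ? num[k] : 0;
        for (int j = 1; j <= k && j < (int)den.size(); ++j)
            val = (val - den[j] * q[k - j] % MOD + MOD) % MOD;
        q[k] = val;
    }
    return q;
}

// Reference value sum_i c_i * v_i^z (mod MOD).
long long directCoefficient(const std::vector<long long>& c, const std::vector<long long>& v, int z) {
    long long total = 0;
    for (size_t i = 0; i < c.size(); ++i) {
        long long term = c[i] % MOD;
        for (int t = 0; t < z; ++t) term = term * (v[i] % MOD) % MOD;
        total = (total + term) % MOD;
    }
    return total;
}
```

With the K = 10 data (c = {1, 1, 5}, v = {3, 2, 1}) the series starts 7 + 10x + 18x^2 + \ldots, matching P(1) = 10 and P(2) = 18.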

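In the end, we perform a single power series inverse of the root's denominator (its constant term is 1, so it is invertible) and multiply it by the numerator to get the first M + 1 coefficients of \sum_i F_i(x).
The divide and conquer costs O(SZV \log^2 SZV): a node covering t fractions has a numerator and a denominator of degree at most t, so, just like the vectors in a merge sort tree, the polynomials on each level of the recursion have total size O(SZV). This is why we do not perform the inverses at the beginning: expanding every F_i(x) into a truncated power series at the leaves would make every node carry polynomials with M terms. The final inverse and multiplication add O(M \log M); with both SZV and M at most about 10^5, the total is of the order of O(M \log M \log SZV).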

Having computed the first M + 1 coefficients of \sum_i F_i(x), we read off each P(z) as the coefficient of x^z and multiply them over all components, which completes the solution.

Check the setter’s solution for more clarity.

SOLUTIONS:

Setter's Solution
#include<bits/stdc++.h>
using namespace std;
 
const int N = 1e5 + 9, mod = 998244353;
 
struct base {
    double x, y;
    base() { x = y = 0; }
    base(double x, double y): x(x), y(y) { }
};
inline base operator + (base a, base b) { return base(a.x + b.x, a.y + b.y); }
inline base operator - (base a, base b) { return base(a.x - b.x, a.y - b.y); }
inline base operator * (base a, base b) { return base(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x); }
inline base conj(base a) { return base(a.x, -a.y); }
int lim = 1;
vector<base> roots = {{0, 0}, {1, 0}};
vector<int> rev = {0, 1};
const double PI = acosl(- 1.0);
void ensure_base(int p) {
    if(p <= lim) return;
    rev.resize(1 << p);
    for(int i = 0; i < (1 << p); i++) rev[i] = (rev[i >> 1] >> 1) + ((i & 1)  <<  (p - 1));
    roots.resize(1 << p);
    while(lim < p) {
        double angle = 2 * PI / (1 << (lim + 1));
        for(int i = 1 << (lim - 1); i < (1 << lim); i++) {
            roots[i << 1] = roots[i];
            double angle_i = angle * (2 * i + 1 - (1 << lim));
            roots[(i << 1) + 1] = base(cos(angle_i), sin(angle_i));
        }
        lim++;
    }
}
void fft(vector<base> &a, int n = -1) {
    if(n == -1) n = a.size();
    assert((n & (n - 1)) == 0);
    int zeros = __builtin_ctz(n);
    ensure_base(zeros);
    int shift = lim - zeros;
    for(int i = 0; i < n; i++) if(i < (rev[i] >> shift)) swap(a[i], a[rev[i] >> shift]);
    for(int k = 1; k < n; k <<= 1) {
        for(int i = 0; i < n; i += 2 * k) {
            for(int j = 0; j < k; j++) {
                base z = a[i + j + k] * roots[j + k];
                a[i + j + k] = a[i + j] - z;
                a[i + j] = a[i + j] + z;
            }
        }
    }
}
//eq = 0: 4 FFTs in total
//eq = 1: 3 FFTs in total
vector<int> multiply(vector<int> &a, vector<int> &b, int eq = 0) {
    int need = a.size() + b.size() - 1;
    int p = 0;
    while((1 << p) < need) p++;
    ensure_base(p);
    int sz = 1 << p;
    vector<base> A, B;
    if(sz > (int)A.size()) A.resize(sz);
    for(int i = 0; i < (int)a.size(); i++) {
        int x = (a[i] % mod + mod) % mod;
        A[i] = base(x & ((1 << 15) - 1), x >> 15);
    }
    fill(A.begin() + a.size(), A.begin() + sz, base{0, 0});
    fft(A, sz);
    if(sz > (int)B.size()) B.resize(sz);
    if(eq) copy(A.begin(), A.begin() + sz, B.begin());
    else {
        for(int i = 0; i < (int)b.size(); i++) {
            int x = (b[i] % mod + mod) % mod;
            B[i] = base(x & ((1 << 15) - 1), x >> 15);
        }
        fill(B.begin() + b.size(), B.begin() + sz, base{0, 0});
        fft(B, sz);
    }
    double ratio = 0.25 / sz;
    base r2(0,  - 1), r3(ratio, 0), r4(0,  - ratio), r5(0, 1);
    for(int i = 0; i <= (sz >> 1); i++) {
        int j = (sz - i) & (sz - 1);
        base a1 = (A[i] + conj(A[j])), a2 = (A[i] - conj(A[j])) * r2;
        base b1 = (B[i] + conj(B[j])) * r3, b2 = (B[i] - conj(B[j])) * r4;
        if(i != j) {
            base c1 = (A[j] + conj(A[i])), c2 = (A[j] - conj(A[i])) * r2;
            base d1 = (B[j] + conj(B[i])) * r3, d2 = (B[j] - conj(B[i])) * r4;
            A[i] = c1 * d1 + c2 * d2 * r5;
            B[i] = c1 * d2 + c2 * d1;
        }
        A[j] = a1 * b1 + a2 * b2 * r5;
        B[j] = a1 * b2 + a2 * b1;
    }
    fft(A, sz); fft(B, sz);
    vector<int> res(need);
    for(int i = 0; i < need; i++) {
        long long aa = A[i].x + 0.5;
        long long bb = B[i].x + 0.5;
        long long cc = A[i].y + 0.5;
        res[i] = (aa + ((bb % mod) << 15) + ((cc % mod) << 30))%mod;
    }
    return res;
}
template <int32_t MOD>
struct modint {
    int32_t value;
    modint() = default;
    modint(int32_t value_) : value(value_) {}
    inline modint<MOD> operator + (modint<MOD> other) const { int32_t c = this->value + other.value; return modint<MOD>(c >= MOD ? c - MOD : c); }
    inline modint<MOD> operator - (modint<MOD> other) const { int32_t c = this->value - other.value; return modint<MOD>(c <    0 ? c + MOD : c); }
    inline modint<MOD> operator * (modint<MOD> other) const { int32_t c = (int64_t)this->value * other.value % MOD; return modint<MOD>(c < 0 ? c + MOD : c); }
    inline modint<MOD> & operator += (modint<MOD> other) { this->value += other.value; if (this->value >= MOD) this->value -= MOD; return *this; }
    inline modint<MOD> & operator -= (modint<MOD> other) { this->value -= other.value; if (this->value <    0) this->value += MOD; return *this; }
    inline modint<MOD> & operator *= (modint<MOD> other) { this->value = (int64_t)this->value * other.value % MOD; if (this->value < 0) this->value += MOD; return *this; }
    inline modint<MOD> operator - () const { return modint<MOD>(this->value ? MOD - this->value : 0); }
    modint<MOD> pow(uint64_t k) const {
        modint<MOD> x = *this, y = 1;
        for (; k; k >>= 1) {
            if (k & 1) y *= x;
            x *= x;
        }
        return y;
    }
    modint<MOD> inv() const { return pow(MOD - 2); }  // MOD must be a prime
    inline modint<MOD> operator /  (modint<MOD> other) const { return *this *  other.inv(); }
    inline modint<MOD> operator /= (modint<MOD> other)       { return *this *= other.inv(); }
    inline bool operator == (modint<MOD> other) const { return value == other.value; }
    inline bool operator != (modint<MOD> other) const { return value != other.value; }
    inline bool operator < (modint<MOD> other) const { return value < other.value; }
    inline bool operator > (modint<MOD> other) const { return value > other.value; }
};
template <int32_t MOD> modint<MOD> operator * (int64_t value, modint<MOD> n) { return modint<MOD>(value) * n; }
template <int32_t MOD> modint<MOD> operator * (int32_t value, modint<MOD> n) { return modint<MOD>(value % MOD) * n; }
template <int32_t MOD> ostream & operator << (ostream & out, modint<MOD> n) { return out << n.value; }
 
using mint = modint<mod>;
struct poly {
    vector<mint> a;
    inline void normalize() {
        while((int)a.size() && a.back() == 0) a.pop_back();
    }
    template<class...Args> poly(Args...args): a(args...) { }
    poly(const initializer_list<mint> &x): a(x.begin(), x.end()) { }
    int size() const { return (int)a.size(); }
    inline mint coef(const int i) const { return (i < a.size() && i >= 0) ? a[i]: mint(0); }
  mint operator[](const int i) const { return (i < a.size() && i >= 0) ? a[i]: mint(0); } //Beware!! p[i] = k won't change the value of p.a[i]
  bool is_zero() const {return a.empty();}
    poly operator + (const poly &x) const {
        int n = max(size(), x.size());
        vector<mint> ans(n);
        for(int i = 0; i < n; i++) ans[i] = coef(i) + x.coef(i);
        while ((int)ans.size() && ans.back() == 0) ans.pop_back();
        return ans;
    }
    poly operator - (const poly &x) const {
        int n = max(size(), x.size());
        vector<mint> ans(n);
        for(int i = 0; i < n; i++) ans[i] = coef(i) - x.coef(i);
        while ((int)ans.size() && ans.back() == 0) ans.pop_back();
        return ans;
    }
    poly operator * (const poly& b) const {
        if(is_zero() || b.is_zero()) return {};
        vector<int> A, B;
      for(auto x: a) A.push_back(x.value);
      for(auto x: b.a) B.push_back(x.value);
      auto res = multiply(A, B, (A == B));
      vector<mint> ans;
      for(auto x: res) ans.push_back(mint(x));
      while ((int)ans.size() && ans.back() == 0) ans.pop_back();
      return ans;
    }
    poly operator * (const mint& x) const {
        int n = size();
        vector<mint> ans(n);
        for(int i = 0; i < n; i++) ans[i] = a[i] * x;
        return ans;
    }
    poly operator / (const mint &x) const{ return (*this) * x.inv(); }
    poly& operator += (const poly &x) { return *this = (*this) + x; }
    poly& operator -= (const poly &x) { return *this = (*this) - x; }
    poly& operator *= (const poly &x) { return *this = (*this) * x; }
    poly& operator *= (const mint &x) { return *this = (*this) * x; }
    poly& operator /= (const mint &x) { return *this = (*this) / x; }
    poly mod_xk(int k) const { return {a.begin(), a.begin() + min(k, size())}; } //modulo by x^k
    poly mul_xk(int k) const { // multiply by x^k
    poly ans(*this);
    ans.a.insert(ans.a.begin(), k, 0);
    return ans;
  }
  poly div_xk(int k) const { // divide by x^k
    return vector<mint>(a.begin() + min(k, (int)a.size()), a.end());
  }
  poly substr(int l, int r) const { // return mod_xk(r).div_xk(l)
    l = min(l, size());
    r = min(r, size());
    return vector<mint>(a.begin() + l, a.begin() + r);
  }
  poly reverse_it(int n, bool rev = 0) const { // reverses and leaves only n terms
    poly ans(*this);
    if(rev) { // if rev = 1 then tail goes to head
      ans.a.resize(max(n, (int)ans.a.size()));
    }
    reverse(ans.a.begin(), ans.a.end());
    return ans.mod_xk(n);
  }
  poly differantiate() const {
      int n = size(); vector<mint> ans(n);
      for(int i = 1; i < size(); i++) ans[i - 1] = coef(i) * i;
      return ans;
  }
  poly integrate() const {
      int n = size(); vector<mint> ans(n + 1); // one extra slot: integrating raises the degree by 1
      for(int i = 0; i < size(); i++) ans[i + 1] = coef(i) / (i + 1);
      return ans;
  }
  poly inverse(int n) const {  // 1 / p(x) % x^n, O(nlogn)
      assert(!is_zero()); assert(a[0] != 0);
      poly ans{mint(1) / a[0]};
      for(int i = 1; i < n; i *= 2) {
          ans = (ans * mint(2) - ans * ans * mod_xk(2 * i)).mod_xk(2 * i);
      }
      return ans.mod_xk(n);
  }
};
 
int a[N], c[N];
///up / down = sum_{i=l}^{r}{c[i]/(1-a[i]*x)}
void yo(int l, int r, poly &up, poly &down) {
  if (l == r) {
    up = poly({c[l]}); down = poly({1, -a[l]});
    return;
  }
  poly A, B, C, D;
  int mid = l + r >> 1;
  yo(l, mid, A, B);
  yo(mid + 1, r, C, D);
  up = A * D + B * C;
  down = B * D;
}
const int M = 1e7 + 9;
int mob[M];
void mobius() {
    for(int i = 1; i < M; i++) mob[i] = 3;
    mob[1] = 1;
    for (int i = 2; i < M; i++) {
        if (mob[i] == 3) {
            mob[i] = -1;
            for (int j = 2 * i; j < M; j += i) mob[j] = mob[j] == 3 ? -1 : mob[j] * (-1);
            if(i <= (M - 1) / i) {
                for (int j = i * i; j < M; j += i * i) mob[j] = 0;
            }
        }
    }
}
long long SQRT(long long n) {
  long long s = sqrt(n);
  while (s * s <= n) s++;
  while (s * s > n) s--;
  return s;
}
int pref[M]; // prefix sums of mob[], indexed up to sqrt(k) <= 1e7, so it needs M entries
long long g(long long n) {
  long long ans = 0;
  long long last;
    for (long long i = 1; i * i <= n; i = last + 1) {
        long long p = n / (n / i / i);
        last = SQRT(p);
        ans += n / i / i * (pref[last] - pref[i - 1]);
    }
    return ans;
}
long long cnt[M];
int32_t main() {
  ios_base::sync_with_stdio(0);
  cin.tie(0);
  long long n, k, m; cin >> n >> k >> m;
  assert(1 <= n && n <= 1e10);
  assert(1 <= k && k <= 1e14);
  assert(0 <= m && m <= 1e5);
  vector<long long> f(N, 0);
  long long one = n, last = 0;
  while (m--) {
    long long l, r; cin >> l >> r;
    assert(1 <= l && l < r && r <= n && l > last);
    last = r;
    f[r - l + 1]++;
    one -= r - l + 1;
  }
  f[1] += one;
  
  mobius();
  for (int i = 1; i < M; i++) {
      pref[i] = pref[i - 1] + mob[i];
  }
  long long i = 1;
  vector<array<long long, 2>> cand;
  while (i <= k) {
    long long l = i, r = k, cur;
    while (l <= r) {
      long long mid = (l + r) / 2;
      if (SQRT(k / mid) == SQRT(k / i)) {
        cur = mid;
        l = mid + 1;
      }
      else r = mid - 1;
    }
    long long p = g(cur) - g(i - 1);
    if (p) cand.push_back({SQRT(k / i), p});
    i = cur + 1;
  }
 
  int sz = cand.size();
  assert(1 <= sz && sz < N);
  for (int i = 0; i < sz; i++) {
    a[i] = cand[i][0] % mod;
    c[i] = cand[i][1] % mod;
  }
  poly up, down;
  yo(0, sz - 1, up, down);
  poly p = up * (down.inverse(N));
  p.a.resize(N);
 
  mint ans = 1;
  for (int i = 1; i < N; i++) {
    if (f[i]) {
      ans *= p[i].pow(f[i]);
    }
  }
  cout << ans << '\n';
  return 0;
}  
Tester's Solution
#pragma GCC optimize("Ofast")
#pragma GCC target("avx2")
#include <bits/stdc++.h>
using namespace std;
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/rope>
using namespace __gnu_pbds;
using namespace __gnu_cxx;
#ifndef rd
#define trace(...)
#define endl '\n'
#endif
#define pb push_back
#define fi first
#define se second
//#define int long long
typedef long long ll;
typedef long double f80;
#define double long double
#define pii pair<int,int>
#define pll pair<ll,ll>
#define sz(x) ((long long)x.size())
#define fr(a,b,c) for(int a=b; a<=c; a++)
#define rep(a,b,c) for(int a=b; a<c; a++)
#define trav(a,x) for(auto &a:x)
#define all(con) con.begin(),con.end()
const ll infl=0x3f3f3f3f3f3f3f3fLL;
const int infi=0x3f3f3f3f;
const int mod=998244353;
//const int mod=1000000007;
typedef vector<int> vi;
typedef vector<ll> vl;

typedef tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update> oset;
auto clk=clock();
mt19937_64 rang(chrono::high_resolution_clock::now().time_since_epoch().count());
int rng(int lim) {
	uniform_int_distribution<int> uid(0,lim-1);
	return uid(rang);
}
int powm(int a, int b) {
	int res=1;
	while(b) {
		if(b&1)
			res=1ll*res*a%mod; // 1ll avoids int overflow
		a=1ll*a*a%mod;
		b>>=1;
	}
	return res;
}

long long readInt(long long l, long long r, char endd) {
	long long x=0;
	int cnt=0;
	int fi=-1;
	bool is_neg=false;
	while(true) {
		char g=getchar();
		if(g=='-') {
			assert(fi==-1);
			is_neg=true;
			continue;
		}
		if('0'<=g&&g<='9') {
			x*=10;
			x+=g-'0';
			if(cnt==0) {
				fi=g-'0';
			}
			cnt++;
			assert(fi!=0 || cnt==1);
			assert(fi!=0 || is_neg==false);

			assert(!(cnt>19 || ( cnt==19 && fi>1) ));
		} else if(g==endd) {
			if(is_neg) {
				x=-x;
			}
			assert(l<=x&&x<=r);
			return x;
		} else {
			assert(false);
		}
	}
}
string readString(int l, int r, char endd) {
	string ret="";
	int cnt=0;
	while(true) {
		char g=getchar();
		assert(g!=-1);
		if(g==endd) {
			break;
		}
		cnt++;
		ret+=g;
	}
	assert(l<=cnt&&cnt<=r);
	return ret;
}
long long readIntSp(long long l, long long r) {
	return readInt(l,r,' ');
}
long long readIntLn(long long l, long long r) {
	return readInt(l,r,'\n');
}
string readStringLn(int l, int r) {
	return readString(l,r,'\n');
}
string readStringSp(int l, int r) {
	return readString(l,r,' ');
}

#define forn(i, x, y) for(int i = x; i <= y; i++)

const int MAXN = 18;
const int maxn = 1 << MAXN;
//const int mod = 998244353;
const int root = 3;

int A[maxn], B[maxn];
int W[maxn],iW[maxn], I[maxn];
int nn;
const int threshold = 200;

int fact[maxn], ifact[maxn];

namespace modulo{
    const int MOD = 998244353;
    int add(const int &a,const int &b){
        int val = a + b;
        if(val >= MOD) val -= MOD;
        return val;
    }
    int sub(const int &a,const int &b){
        int val = a - b;
        if(val < 0) val += MOD;
        return val;
    }
    int mul(const int &a, const int &b){
        return 1ll * a * b % MOD;
    }
}
using namespace modulo;

void ensureINV(int n){
    if(n <= nn) return;
    if(!nn){
        I[1] = 1;
        nn = 1;
    }
    forn(i, nn + 1, n)
        I[i] = (mod - mul((mod / i), I[mod % i]));
    nn = n;
}
void print(vi &a, int n = 0){
    if(!n) n = a.size();
    a.resize(n);
    forn(i, 0, n - 1)
        cout << a[i] << " ";
    cout << endl;
}
int pwr(int a,int b){
    int ans = 1;
    while(b){
        if(b & 1)
            ans = mul(ans, a);
        a = mul(a, a);
        b >>= 1;
    }
    return ans;
}
void precompute(){
    W[0] = iW[0] = 1;
    int g = pwr(root,(mod - 1) / maxn), ig = pwr(g, mod - 2);
    forn(i, 1, maxn / 2 - 1){
        W[i] = mul(W[i - 1], g);
        iW[i] = mul(iW[i - 1], ig);
    }
}
int rev(int i, int n){
    int irev = 0;
    n >>= 1;
    while(n){
        n >>= 1;
        irev = (irev << 1) | (i & 1);
        i >>= 1;
    }
    return irev;
}
void go(int a[], int n){
    forn(i, 0, n - 1){
        int r = rev(i, n);
        if(i < r)
            swap(a[i], a[r]);
    }
}

void fft(int a[], int n, bool inv = 0){
    go(a, n);
    int len, i, j, *p, *q, u, v, ind, add;
    for(len = 2; len <= n; len <<= 1){
        for(i = 0; i < n; i += len){
            ind = 0, add = maxn / len;
            p = &a[i], q = &a[i + len / 2];
            forn(j, 0, len / 2 - 1){
                v = mul((*q), (inv ? iW[ind] : W[ind]));
                (*q) = sub((*p), v);
                (*p) = ::add((*p), v);
                ind += add;
                p++, q++;
            }
        }
    }
    if(inv) {
        int p = pwr(n, mod - 2);
        forn(i, 0, n - 1)
            a[i] = mul(a[i], p);
    }
}
vi brute(const vi &a, const vi &b){
    vi c(a.size() + b.size() - 1, 0);
    for(int i = 0; i < a.size(); i++){
        for(int j = 0; j < b.size(); j++){
            c[i + j] = add(c[i + j], mul(a[i], b[j]));
        }
    }
    return c;
}
vi mul(vi a, vi b){ // n = total size
    if(min(a.size(),b.size())<= threshold)
        return brute(a, b);
    int nn=sz(a)+sz(b)-1;
    int n=1;
    while(n<nn)
    	n<<=1;
    a.resize(n);
    copy(all(a), A);
    fft(A, n);
    if(a == b)
        copy(A, A + n, B);
    else{
        b.resize(n);
        copy(all(b), B);
        fft(B, n);
    }
    forn(i, 0, n - 1)
        A[i] = mul(A[i], B[i]);
    fft(A, n, 1);
    vi c(A, A + nn);
    return c;
}
vi inv(vi a, int m){ // get m terms
    assert(a[0] != 0);
    int tot = 1;
    while(tot < m)
        tot <<= 1;
    swap(tot, m);
    a.resize(m, 0);
    vi ia(m, 0);
    ia[0] = pwr(a[0], mod - 2);
    for(int sz = 2; sz <= m; sz <<= 1){
        copy(ia.begin(), ia.begin() + sz / 2, A);
        copy(a.begin(), a.begin() + sz, B);
        fill(A + sz / 2, A + (sz << 1), 0);
        fill(B + sz, B + (sz << 1), 0);
        fft(A, sz << 1);
        fft(B, sz << 1);
        forn(j, 0, (sz << 1) - 1)
            A[j] = add(A[j], sub(A[j], mul(mul(A[j], A[j]), B[j])));
        fft(A, sz << 1, 1);
        copy(A, A + sz, ia.begin());
    }
    ia.resize(tot);
    return ia;
}

ll cntr[10000005];
pair<vi, vi> gogo[200005];
void solve() {
	precompute();
	ll n=readIntSp(1,10'000'000'000LL),k=readIntSp(1,100'000'000'000'000LL),m=readIntLn(0,100000);

//	So we want, for each i, the number of square-free numbers that can be multiplied by i*i without exceeding k, but exceed k when multiplied by (i+1)*(i+1).
//	Let us denote these by array cntr[]
//	Now, for a single component of size x. we can see that the answer will be sum of cntr[i]*(i^(x))

//  Initially, cntr[i] stores the number of numbers that can be multiplied by (i*i) without exceeding k, i.e. k/(i*i).
	for(ll i=1; i<=10000000; i++)
		cntr[i]=k/(i*i);

//  Now, we remove the numbers which are not square free among these counts
//  Currently, let us say the set which cntr represents consists of some numbers which contain squares as factors
//  We can see that these numbers would also be counted in the sets for cntr[2*i](if the square factor is 4), cntr[3*i](if the square factor is 9) etc.
//  So, we iterate on this array in reverse order and remove all the multiples counts.
//  Thus, after this operation we have count of square free numbers which can be multiplied by i*i without exceeding k.
	for(int i=10000000; i>0; i--)
		for(int j=2*i; j<=10000000; j+=i)
			cntr[i]-=cntr[j];

//  Now we transform it into the count we require: the number of square-free numbers for which multiplying by i*i doesn't exceed k but multiplying by (i+1)*(i+1) does.
//  This is done by subtracting cntr[i+1] from cntr[i] in forward order.
	for(int i=1; i<=10000000; i++)
		cntr[i]-=cntr[i+1];

	vector<pii> disto;
	fr(i,1,10000000)
		if(cntr[i])
			disto.pb({i,cntr[i]%mod});
	for(int i=sz(disto); i<2*sz(disto); i++) {
		gogo[i].fi={disto[i-sz(disto)].se};
		gogo[i].se={1,mod-disto[i-sz(disto)].fi};
	}
	for(int i=sz(disto)-1; i>0; i--) {
		vi aa=mul(gogo[i<<1].fi,gogo[i<<1|1].se);
		vi bb=mul(gogo[i<<1].se,gogo[i<<1|1].fi);
		for(int j=0; j<sz(aa); j++)
			aa[j]=add(aa[j],bb[j]);
		gogo[i].fi=aa;
		gogo[i].se=mul(gogo[i<<1].se,gogo[i<<1|1].se);
	}
	vi ans=mul(gogo[1].fi,inv(gogo[1].se,100001));
	ans.resize(100001);
	int answer=1;
	ll tot=n;
	ll last=0;
	while(m--) {
		ll l=readIntSp(last+1,n),r=readIntLn(l+1,n);
		assert(r-l+1<=100000);
		last=r;
		tot-=r-l+1;
		answer=mul(answer,ans[r-l+1]);
	}
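	// Each isolated vertex contributes P(1) = k. Keep the exponent positive when tot > 0,
	// so that k % mod == 0 still yields 0 instead of pwr(0, 0) = 1.
	ll e=tot%(mod-1);
	if(tot>0 && e==0) e=mod-1;
	answer=mul(answer,pwr(k%mod,e));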
	cout<<answer<<endl;
}

signed main() {
	ios_base::sync_with_stdio(0),cin.tie(0);
	srand(chrono::high_resolution_clock::now().time_since_epoch().count());
	cout<<fixed<<setprecision(8);
	int t=1;
//	cin>>t;
//	t=readIntLn(1,100);
	fr(i,1,t)
		solve();
	assert(getchar()==EOF);
#ifdef rd
	cerr<<endl<<endl<<endl<<"Time Elapsed: "<<((double)(clock()-clk))/CLOCKS_PER_SEC<<endl;
#endif
	return 0;
}

VIDEO EDITORIAL:

Feel free to share your approach. If you have any doubts or anything is unclear, please ask in the comment section. Any suggestions are welcome. :smile:
