PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Setter: Богдан Пастущак
Tester: Felipe Mota
Editorialist: Taranpreet Singh
DIFFICULTY:
Easy-Medium
PREREQUISITES:
Number Theory and Inclusion-Exclusion.
PROBLEM:
Let f(x) denote the sum of all perfect powers which divide x. Find \sum_{i = 1}^N f(i) modulo 10^9+7 for a given N.
Let g(i) = \sum_{j = 1}^i f(j), so we need to find g(N) for a given N.
QUICK EXPLANATION
- Writing perfect powers as x^y, let’s group all perfect powers by the value of y. Since 10^{18} < 2^{60}, y is always less than 60.
- A perfect power p contributes to g(N) exactly \lfloor N/p \rfloor times, once for each multiple of p up to N.
- For a fixed power y, we can consider all possible values of x such that x^y \leq N and find their contribution. This has time complexity O(N^{1/y}*y), which is fast enough for y \geq 3.
- For y = 2, we need to find the intervals [L, R] of values of x over which every x^2 is counted the same number of times, i.e. \lfloor N/L^2 \rfloor = \lfloor N/R^2 \rfloor. There are only O(N^{1/3}) such intervals, giving time complexity O(N^{1/3}).
- To avoid double-counting of perfect powers like 64, we need to apply inclusion-exclusion.
EXPLANATION
First, let us see how much a perfect power p = x^y contributes to g(N) = \sum_{i = 1}^N f(i).
Each multiple of p up to N adds p to the final sum, and there are \lfloor N/p \rfloor such multiples, so p contributes p*\lfloor \frac{N}{p} \rfloor to the final sum.
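The counting argument above can be checked against the definition for small N. The following is only a sanity-check sketch, not part of any reference solution: the function names are made up for illustration, no modulus is applied, and it assumes 1 \leq n \leq 10^6 or so, so that nothing overflows.

```cpp
#include <cassert>
#include <cstdint>
#include <set>

// All perfect powers <= n (1 included), each stored once.
// Assumes small n (illustration only), so p * x cannot overflow.
std::set<std::int64_t> perfectPowersUpTo(std::int64_t n) {
    std::set<std::int64_t> powers = {1};
    for (std::int64_t x = 2; x * x <= n; ++x) {
        // p runs over x^2, x^3, ... while it stays <= n
        for (std::int64_t p = x * x; p <= n; p *= x) powers.insert(p);
    }
    return powers;
}

// g(n) straight from the definition: for every i, add the perfect powers dividing i.
std::int64_t gByDefinition(std::int64_t n) {
    const std::set<std::int64_t> powers = perfectPowersUpTo(n);
    std::int64_t total = 0;
    for (std::int64_t i = 1; i <= n; ++i)
        for (std::int64_t p : powers)
            if (i % p == 0) total += p;
    return total;
}

// g(n) from the counting argument: p is added once per multiple of p up to n.
std::int64_t gByContribution(std::int64_t n) {
    std::int64_t total = 0;
    for (std::int64_t p : perfectPowersUpTo(n)) total += p * (n / p);
    return total;
}
```

For N = 10 the perfect powers are 1, 4, 8 and 9, so both functions give 1*10 + 4*2 + 8*1 + 9*1 = 35.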
Now, one naive solution would be to precompute all perfect powers in advance, but the number of perfect powers up to N is approximately \sqrt N, which is far too many for N up to 10^{18}.
Let’s consider all perfect powers x^y and group them by the value of y. Some perfect powers, like 2^6 = 64, are counted for several exponents (here y = 2, 3 and 6). We’ll handle that using Inclusion-Exclusion later.
Let us fix each value of y and iterate over all possible values of x such that x^y \leq N. There can be at most N^{1/y} such values of x.
For y \geq 3, this works fine, but for y = 2, it leads to nearly O(\sqrt N) time complexity, which is not feasible for N = 10^{18}. We’ll optimize this later.
At this point, we know the contribution of each perfect power x^y, grouped by y. Let f_y denote the sum of contributions of all perfect y-th powers x^y with x \geq 2. Here, the contribution of 64 is included in f_2, f_3 as well as f_6.
So, for every exponent i, we need to subtract from f_i every f_j such that j is a multiple of i and j > i. Processing i in decreasing order guarantees that each f_j is already free of duplicates when it is subtracted. We can easily achieve this via the following pseudo-code.
for(int i = 60; i >= 2; i--){
    // At this point, every f[j] with j > i is already free of duplicates:
    // it only counts perfect powers whose largest possible exponent is j
    for(int j = 2*i; j <= 60; j += i){
        f[i] -= f[j];
    }
}
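Hence, after computing the contribution of the y-th powers for every y and removing the duplicates, the sum of all f_y is the total contribution of every perfect power greater than 1. The perfect power 1 has to be handled separately: it divides every i, so it contributes N to the answer.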
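To see the grouping and the subtraction step on concrete numbers, here is a small sketch without any modulus, meant only for small N (say N \leq 10^6). The names cappedPower, rawSum and gGrouped are hypothetical helpers chosen for this illustration and do not appear in the reference solutions.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// x^y, capped at limit + 1 as soon as the product would exceed limit (avoids overflow).
std::int64_t cappedPower(std::int64_t x, int y, std::int64_t limit) {
    std::int64_t result = 1;
    for (int step = 0; step < y; ++step) {
        if (result > limit / x) return limit + 1;  // result * x would exceed limit
        result *= x;
    }
    return result;
}

// Sum of x^y * floor(n / x^y) over all x >= 2 with x^y <= n, duplicates included.
std::int64_t rawSum(std::int64_t n, int y) {
    std::int64_t total = 0;
    for (std::int64_t x = 2;; ++x) {
        std::int64_t p = cappedPower(x, y, n);
        if (p > n) break;
        total += p * (n / p);
    }
    return total;
}

// g(n) for small n: group by exponent, then subtract the multiples of each exponent.
std::int64_t gGrouped(std::int64_t n) {
    const int maxExp = 60;
    std::vector<std::int64_t> f(maxExp + 1, 0);
    for (int y = 2; y <= maxExp; ++y) f[y] = rawSum(n, y);
    for (int i = maxExp; i >= 2; --i)
        for (int j = 2 * i; j <= maxExp; j += i) f[i] -= f[j];
    std::int64_t total = n;  // the perfect power 1 divides every i
    for (int y = 2; y <= maxExp; ++y) total += f[y];
    return total;
}
```

For N = 64 the raw sums are f_2 = 390, f_3 = 182 and f_4 = f_5 = f_6 = 64. After the subtraction f_3 becomes 182 - 64 = 118 and f_2 becomes 390 - 64 - 64 = 262, so the answer is 64 + 262 + 118 + 64 + 64 + 64 = 636.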
The only thing left now is to optimize the calculation of f_2, as O(N^{1/2}) is too slow.
Let’s start from st = 2 and find the largest en such that \lfloor N/en^2 \rfloor is the same as \lfloor N/st^2 \rfloor. The idea is that every perfect square x^2 with x in the interval [st, en] is counted the same number of times, \lfloor N/st^2 \rfloor, so we can combine the update. These squares are st^2, (st+1)^2, \ldots, (en-1)^2, en^2, so their total contribution is \lfloor N/st^2 \rfloor * \big[ st^2 + (st+1)^2 + \ldots + (en-1)^2 + en^2 \big], which can be written as \lfloor N/st^2 \rfloor * \big[ sumOfSquares(en)-sumOfSquares(st-1)\big], where sumOfSquares(n) = n(n+1)(2n+1)/6 gives the sum of squares of the first n natural numbers.
Now, for a fixed st, we need to find the endpoint of the interval. A binary search works, but adds an additional log(N) factor, which won’t pass the final subtask.
The final observation is the following. Let K = \lfloor N/st^2 \rfloor. Then en = \lfloor \sqrt{N/K} \rfloor, since this is the largest value of en for which \lfloor N/en^2 \rfloor = K holds. This gives us an easy way to compute the endpoint of the interval. We can move to the next interval by setting st = en+1 and repeat until st^2 > N.
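The interval walk described above can be sketched as follows. This is a minimal illustration written for this explanation rather than a copy of any reference solution. The helper names are assumptions, and squaresDirect is only a slow term-by-term reference used for comparison on small inputs.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

const std::int64_t MOD = 1000000007;
const std::int64_t INV6 = 166666668;  // modular inverse of 6: 6 * INV6 == MOD + 1

// Largest r with r * r <= x; corrects the rounding error std::sqrt may introduce.
std::int64_t isqrt(std::int64_t x) {
    std::int64_t r = static_cast<std::int64_t>(std::sqrt(static_cast<long double>(x)));
    while (r * r > x) --r;
    while ((r + 1) * (r + 1) <= x) ++r;
    return r;
}

// 1^2 + 2^2 + ... + m^2 modulo MOD, using m(m+1)(2m+1)/6.
std::int64_t sumOfSquares(std::int64_t m) {
    m %= MOD;
    return m * (m + 1) % MOD * ((2 * m + 1) % MOD) % MOD * INV6 % MOD;
}

// Sum over x >= 2 with x^2 <= n of x^2 * floor(n / x^2), modulo MOD, one block at a time.
std::int64_t squaresBlocked(std::int64_t n) {
    std::int64_t total = 0;
    for (std::int64_t st = 2; st <= n / st;) {  // st <= n / st means st^2 <= n
        std::int64_t k = n / (st * st);         // common multiplier of this block
        std::int64_t en = isqrt(n / k);         // last x with floor(n / x^2) == k
        std::int64_t blockSum = (sumOfSquares(en) - sumOfSquares(st - 1) + MOD) % MOD;
        total = (total + k % MOD * blockSum) % MOD;
        st = en + 1;
    }
    return total;
}

// Slow reference: the same sum, term by term.
std::int64_t squaresDirect(std::int64_t n) {
    std::int64_t total = 0;
    for (std::int64_t x = 2; x * x <= n; ++x)
        total = (total + (x * x) % MOD * ((n / (x * x)) % MOD)) % MOD;
    return total;
}
```

For N = 64 the blocks are [2, 2], [3, 3], [4, 4], [5, 5] and [6, 8] with multipliers 16, 7, 4, 2 and 1, giving 64 + 63 + 64 + 50 + 149 = 390.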
Refer to the implementations below in case anything is not clear.
Exercise: Prove that \lfloor N/x^2 \rfloor takes only O(N^{1/3}) different values for fixed N (at most about 2*N^{1/3}).
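Before attempting the proof, it may help to count the intervals empirically. The sketch below is not a proof and uses hypothetical helper names. It simply walks the same intervals as above, starting from x = 1, and counts them.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Largest r with r * r <= x; corrects the rounding error std::sqrt may introduce.
std::int64_t isqrt(std::int64_t x) {
    std::int64_t r = static_cast<std::int64_t>(std::sqrt(static_cast<long double>(x)));
    while (r * r > x) --r;
    while ((r + 1) * (r + 1) <= x) ++r;
    return r;
}

// Number of distinct nonzero values of floor(n / x^2) over x >= 1,
// counted as the number of maximal intervals on which the value is constant.
std::int64_t countBlocks(std::int64_t n) {
    std::int64_t blocks = 0;
    for (std::int64_t st = 1; st <= n / st;) {  // st <= n / st means st^2 <= n
        std::int64_t k = n / (st * st);
        st = isqrt(n / k) + 1;  // jump past the end of the current block
        ++blocks;
    }
    return blocks;
}
```

For N = 100 the values are 100, 25, 11, 6, 4, 2 and 1, so countBlocks(100) is 7, while 2 * 100^{1/3} is roughly 9.3.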
TIME COMPLEXITY
The overall time complexity is O(N^{1/3}) per test case.
SOLUTIONS:
Setter's Solution
/* Statement:
* For positive integer n define f(n) as sum of all
* divisors of n, which are perfect powers.
* Calculate sum f(1) + f(2) + .. + f(n) modulo 10^9 + 7
*
* Solution:
* Define F(n) = f(1) + f(2) + ... + f(n).
* Let D(n, i) be the sum of all divisors of n, which are i-th perfect power.
* We can calculate D(1, 2) + D(2, 2) + ... + D(n, 2) as follows:
* Let's calculate for each number x how many times x^2
* will be added to D(n, 2). Obviously, it is [n / (x ^ 2)].
* So D(n, 2) = 1 * [n / 1] + 4 * [n / 4] + 9 * [n / 9] + ...
* Let's fix some l, and find for it maximum possible r, such that
* [n / (l ^ 2)] = [n / (r ^ 2)] = k
* It can be shown that r = [sqrt(n / k)]
* So we can calculate D(n, 2) with complexity proportional to number of
* segments with equal value [n / (x ^ 2)], and this number is O(n ^ 1/3).
*
* Analogically, we can calculate D(n, i) for i > 2.
* Also, we can just bruteforce them.
*
* After that, we note that we can calculate some numbers more than once.
* For example 64 will be included three times in those calculations
* (as a perfect square, cube and sixth power).
* So lets do inclusion-exclusion to avoid such situations.
*
* Complexity: O(n ^ 1/3)
*/
#include <bits/stdc++.h>
using namespace std;
const int mod = 1e9 + 7;
const int inv6 = (mod + 1) / 6;// = 1/6
inline void add(int& x, int y)
{
x += y;
if (x >= mod) x -= mod;
}
inline void sub(int& x, int y)
{
x -= y;
if (x < 0) x += mod;
}
inline int mult(int x, int y)
{
return x * (long long) y % mod;
}
inline int sumSquares(int r)// 1^2 + 2^2 + ... + r^2 = r * (r + 1) * (2r + 1) / 6
{
return mult(r, mult(r + 1, mult(2 * r + 1, inv6)));
}
inline int sum(int l, int r)
{
int res = sumSquares(r);
sub(res, sumSquares(l - 1));
return res;
}
inline long long power(int x, int k, long long n)
{
__int128 res = 1;
while(k--) res *= x;
if (res > n) res = n + 1;
return (long long) res;
}
int solve(long long n, int k)//Complexity: O(n^(1/k) * k)
{
int ans = 0;
for(int i = 2; ; ++i)
{
long long d = power(i, k, n);
if (d > n) break;
add(ans, d * (n / d) % mod);
}
return ans;
}
inline int get_sqrt(long long x)
{
int r = sqrt(x);
//sqrt function can give a small error
while(r * (long long) r > x) r--;
while((r + 1) * (long long)(r + 1) <= x) r++;
return r;
}
int solve2(long long n)//Complexity: O(n^1/3) (operations sqrt)
//correct solution
{
int ans = 0;
int l = 2;
while(l * (long long) l <= n)
{
long long k = n / (l * (long long) l);
int r = get_sqrt(n / k);//the heaviest place in program
add(ans, mult(sum(l, r), k % mod));
l = r + 1;
}
return ans;
}
int solve2BinSearch(long long n)//Complexity: O(n^1/3 * log n) (operations /)
//should give TLE on last subtask
{
int ans = 0;
int l = 2;
while(l * (long long) l <= n)
{
long long k = n / (l * (long long) l);
int L = l, R = (int)1e9 + 1, M;
while(R - L > 1)
{
M = (L + R) >> 1;
if (k == n / (M * (long long) M))//the heaviest place
L = M;
else
R = M;
}
add(ans, mult(sum(l, L), k % mod));
l = L + 1;
}
return ans;
}
const int D = 60;// 2^D > maximum possible n
int d[D];
int main()
{
int tc;
cin >> tc;
while(tc--)
{
long long n;
cin >> n;
d[2] = solve2(n);
for(int i = 3; i < D; ++i)
d[i] = solve(n, i);
//do inclusion-exclusion
for(int i = D - 1; i >= 2; --i)
for(int j = i + i; j < D; j += i)
sub(d[i], d[j]);
int ans = n % mod;
for(int i = 2; i < D; ++i)
add(ans, d[i]);
cout << ans << endl;
}
cerr << "Time elapsed: " << clock() / (double) CLOCKS_PER_SEC << endl;
return 0;
}
Tester's Solution
#include <bits/stdc++.h>
#include <unordered_map>
using namespace std;
template<typename T = int> vector<T> create(size_t n){ return vector<T>(n); }
template<typename T, typename... Args> auto create(size_t n, Args... args){ return vector<decltype(create<T>(args...))>(n, create<T>(args...)); }
template<typename T = int, T mod = 1'000'000'007, typename U = long long>
struct umod{
T val;
umod(): val(0){}
umod(U x){ x %= mod; if(x < 0) x += mod; val = x;}
umod& operator += (umod oth){ val += oth.val; if(val >= mod) val -= mod; return *this; }
umod& operator -= (umod oth){ val -= oth.val; if(val < 0) val += mod; return *this; }
umod& operator *= (umod oth){ val = ((U)val) * oth.val % mod; return *this; }
umod& operator /= (umod oth){ return *this *= oth.inverse(); }
umod& operator ^= (U oth){ return *this = pwr(*this, oth); }
umod operator + (umod oth) const { return umod(*this) += oth; }
umod operator - (umod oth) const { return umod(*this) -= oth; }
umod operator * (umod oth) const { return umod(*this) *= oth; }
umod operator / (umod oth) const { return umod(*this) /= oth; }
umod operator ^ (long long oth) const { return umod(*this) ^= oth; }
bool operator < (umod oth) const { return val < oth.val; }
bool operator > (umod oth) const { return val > oth.val; }
bool operator <= (umod oth) const { return val <= oth.val; }
bool operator >= (umod oth) const { return val >= oth.val; }
bool operator == (umod oth) const { return val == oth.val; }
bool operator != (umod oth) const { return val != oth.val; }
umod pwr(umod a, U b) const { umod r = 1; for(; b; a *= a, b >>= 1) if(b&1) r *= a; return r; }
umod inverse() const {
U a = val, b = mod, u = 1, v = 0;
while(b){
U t = a/b;
a -= t * b; swap(a, b);
u -= t * v; swap(u, v);
}
if(u < 0)
u += mod;
return u;
}
};
bool is_perfect(int x){
if(x == 1) return true;
for(int i = 2; i <= x; i++){
if(x % i == 0){
int c = x, cn = 0;
while(c % i == 0) c /= i, cn++;
if(c == 1 && cn >= 2) return true;
}
}
return false;
}
int fdx(int i){
int ans = 0;
for(int j = 1; j <= i; j++){
if((i % j) == 0){
if(is_perfect(j)){
ans += j;
}
}
}
return ans;
}
int solve(int n){
int ans = 0;
for(int i = 1; i <= n; i++){
ans += fdx(i);
}
return ans;
}
using U = umod<>;
bool is_prime(int x){
if(x <= 1) return false;
for(int i = 2; i * i <= x; i++) if(x % i == 0) return false;
return true;
}
int main(){
ios::sync_with_stdio(false);
cin.tie(0);
vector<int> p;
for(int i = 2; i <= 70; i++) if(is_prime(i)) p.push_back(i);
int t; cin >> t;
const int LIM = 2000001;
vector<int> mn(LIM, 1<<30);
for(int i = 2; i < LIM; i++){
int sq = sqrt(i);
while(sq * sq < i) sq++;
if(sq * sq == i) mn[i] = 2;
}
for(int c : p){
if(c == 2) continue;
for(int i = 2; ; i++){
int g = LIM, r = 1;
for(int j = 0; j < c; j++) g /= i, r *= i;
if(g == 0) break;
mn[r] = min(mn[r], c);
}
}
for(int _ = 1; _ <= t; _++){
long long n = _; cin >> n;
U ans = 0, i6 = U(1) / 6;
auto sqr_sum = [&](U l){
return (l * (l + 1) * (l * 2 + 1)) * i6;
};
auto sqr_sum_rng = [&](U l, U r){
return sqr_sum(r) - sqr_sum(l - 1);
};
for(int pr : p){
if(pr == 2){
for(long long i = 1, j; i * i <= n; i = j + 1){
long long v = n / (i * i);
j = sqrt(n / v);
while(j * j <= (n / v)) j++;
j--;
if(i <= j){
ans += sqr_sum_rng(i, j) * v;
}
}
} else {
for(int i = 2; ; i++){
if(mn[i] < pr) continue;
long long tn = n, r = 1;
for(int j = 0; j < pr; j++) tn /= i, r *= i;
if(tn == 0) break;
ans += U(r) * tn;
}
}
}
cout << ans.val << '\n';
// cout << solve(n) << endl;
}
return 0;
}
Editorialist's Solution
import java.util.*;
import java.io.*;
import java.text.*;
class PPDIV{
//SOLUTION BEGIN
long MOD = (long)1e9+7;
void pre() throws Exception{}
long pow(long a, long p, long mx){
long o = 1;
while(p-- > 0){
o *= a;
if(o > mx)o = mx+1;
}
return o;
}
long sqrt(long n){
long x = (long)Math.sqrt(n);
while(x*x > n)x--;
while((x+1)*(x+1) <= n)x++;
return x;
}
long solve2(long n){
if(n < 4)return 0;
long curBase = 2, ans = 0, sqrtN = sqrt(n);
while(curBase <= n/curBase){
long V = n/(curBase*curBase);
long lo = sqrt(n/V);
ans += (n/(curBase*curBase)%MOD * (sumOfSquares(lo)+MOD-sumOfSquares(curBase-1)))%MOD;
if(ans >= MOD)ans -= MOD;
curBase = lo+1;
}
return ans;
}
long solve2BinarySearch(long n){
if(n < 4)return 0;
long curBase = 2, ans = 0, sqrtN = sqrt(n);
while(curBase <= n/curBase){
long lo = curBase, hi = sqrtN;
while(lo+1 < hi){
long mid = lo+(hi-lo)/2;
if(n/(curBase*curBase) == n/(mid*mid))lo = mid;
else hi = mid;
}
if(n/(curBase*curBase) == n/(hi*hi))lo = hi;
ans += (n/(curBase*curBase)%MOD * (sumOfSquares(lo)+MOD-sumOfSquares(curBase-1)))%MOD;
if(ans >= MOD)ans -= MOD;
curBase = lo+1;
}
return ans;
}
//Time complexity O(N^(1/k)*k)
long solve(int power, long n){
long ans = 0;
for(int base = 2; ; base++){
long p = pow(base, power, n);
if(p > n)break;
ans += (n-n%p)%MOD;
if(ans >= MOD)ans -= MOD;
}
return ans;
}
void solve(int TC) throws Exception{
long n = nl();
int B = 60;
long[] sumOfPowers = new long[B];
sumOfPowers[2] = solve2(n);
for(int i = 3; i< B; i++)
sumOfPowers[i] = solve(i, n);
for(int i = B-1; i>= 2; i--){
for(int j = i+i; j< B; j+= i){
sumOfPowers[i] += MOD-sumOfPowers[j];
if(sumOfPowers[i] >= MOD)sumOfPowers[i] -= MOD;
}
}
long sum = n%MOD;
for(int i = 2; i< B; i++)sum = (sum+sumOfPowers[i])%MOD;
pn(sum);
}
long inv6 = inv(6);
long sumOfSquares(long a, long b){
return (sumOfSquares(b)+MOD-sumOfSquares(a-1))%MOD;
}
long sumOfSquares(long n){
n%= MOD;
return (((n*(n+1))%MOD*(2*n+1))%MOD*inv6)%MOD;
}
long inv(long a){
long o = 1;
for(long p = MOD-2; p > 0; p>>=1){
if((p&1)==1)o = (o*a)%MOD;
a = (a*a)%MOD;
}
return o;
}
//SOLUTION END
void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
DecimalFormat df = new DecimalFormat("0.00000000000");
static boolean multipleTC = true;
FastReader in;PrintWriter out;
void run() throws Exception{
in = new FastReader();
out = new PrintWriter(System.out);
//Solution Credits: Taranpreet Singh
pre();
int T = (multipleTC)?ni():1;
for(int t = 1; t<= T; t++)solve(t);
out.flush();
out.close();
}
public static void main(String[] args) throws Exception{
new PPDIV().run();
}
int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
void p(Object o){out.print(o);}
void pn(Object o){out.println(o);}
void pni(Object o){out.println(o);out.flush();}
String n()throws Exception{return in.next();}
String nln()throws Exception{return in.nextLine();}
int ni()throws Exception{return Integer.parseInt(in.next());}
long nl()throws Exception{return Long.parseLong(in.next());}
double nd()throws Exception{return Double.parseDouble(in.next());}
class FastReader{
BufferedReader br;
StringTokenizer st;
public FastReader(){
br = new BufferedReader(new InputStreamReader(System.in));
}
public FastReader(String s) throws Exception{
br = new BufferedReader(new FileReader(s));
}
String next() throws Exception{
while (st == null || !st.hasMoreElements()){
try{
st = new StringTokenizer(br.readLine());
}catch (IOException e){
throw new Exception(e.toString());
}
}
return st.nextToken();
}
String nextLine() throws Exception{
String str = "";
try{
str = br.readLine();
}catch (IOException e){
throw new Exception(e.toString());
}
return str;
}
}
}
Feel free to share your approach. Suggestions are welcomed as always.