PPDIV - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2

Setter: Богдан Пастущак
Tester: Felipe Mota
Editorialist: Taranpreet Singh

DIFFICULTY:

Easy-Medium

PREREQUISITES:

Number Theory and Inclusion-Exclusion.

PROBLEM:

Let f(x) denote the sum of all perfect powers that divide x, where 1 = 1^2 also counts as a perfect power. Find \sum_{i = 1}^N f(i) modulo 10^9+7 for a given N.

Let g(i) = \sum_{j = 1}^i f(j). We need to find g(N) for the given N.
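The definitions can be made concrete with a tiny brute force that evaluates f and g directly. This is a sketch for checking small cases only, and the names isPerfectPower, fBrute and gBrute are hypothetical. Like the reference solutions, it treats 1 as a perfect power. For example, f(8) = 1 + 4 + 8 = 13 and g(10) = 35.

```cpp
#include <cassert>

// Brute-force reference for tiny N only (roughly quadratic overall).
// 1 = 1^2 is treated as a perfect power, matching the reference solutions.
bool isPerfectPower(long long p) {
    assert(p >= 1);
    if (p == 1) return true;
    for (long long x = 2; x * x <= p; ++x) {
        long long v = x * x;            // v runs over x^2, x^3, ...
        while (v < p) v *= x;
        if (v == p) return true;
    }
    return false;
}

// f(x): sum of all perfect powers dividing x.
long long fBrute(long long x) {
    long long s = 0;
    for (long long d = 1; d <= x; ++d)
        if (x % d == 0 && isPerfectPower(d)) s += d;
    return s;
}

// g(N) = f(1) + ... + f(N), without the modulus (small N only).
long long gBrute(long long n) {
    long long s = 0;
    for (long long i = 1; i <= n; ++i) s += fBrute(i);
    return s;
}
```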

QUICK EXPLANATION

  • Write perfect powers as x^y and group them by the value of y. Since N \leq 10^{18} < 2^{60}, y never reaches 60.
  • A perfect power p contributes p \cdot \lfloor N/p \rfloor to g(N): it is added once for each of the \lfloor N/p \rfloor multiples of p up to N.
  • For a fixed y, iterate over all x with x^y \leq N and add each contribution directly. This takes O(N^{1/y} \cdot y) time, which is fast enough for y \geq 3.
  • For y = 2, split the values of x into intervals [L, R] on which every value has the same multiplier, i.e. \lfloor N/L^2 \rfloor = \lfloor N/R^2 \rfloor. There are only O(N^{1/3}) such intervals, giving O(N^{1/3}) time.
  • A perfect power like 64 = 2^6 is counted as a square, a cube and a sixth power. Inclusion-exclusion removes this double-counting.

EXPLANATION

First, let us see how much a perfect power p = x^y contributes to g(N) = \sum_{i = 1}^N f(i).

The value p is added to f(i) for every multiple i of p, and there are \lfloor N/p \rfloor multiples of p in [1, N]. So p contributes p \cdot \lfloor \frac{N}{p} \rfloor to the final sum.
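The contribution argument can be cross-checked on small inputs with a short C++ sketch (gByContribution is a hypothetical name). It enumerates the distinct perfect powers up to N with a std::set and adds p \cdot \lfloor N/p \rfloor for each one. This is exactly the naive approach that is too slow for large N, so it only serves as a check. For N = 10 the perfect powers are 1, 4, 8 and 9, and the sum is 10 + 8 + 8 + 9 = 35.

```cpp
#include <cassert>
#include <set>

// g(N) via contributions: each distinct perfect power p <= N divides
// exactly floor(N/p) of the numbers 1..N, so it adds p * floor(N/p).
// The std::set removes duplicates such as 64 = 8^2 = 4^3 = 2^6.
// Enumerating all perfect powers is only viable for small N.
long long gByContribution(long long n) {
    assert(n >= 1);
    std::set<long long> powers{1};        // 1 = 1^2 is counted as well
    for (long long x = 2; x * x <= n; ++x)
        for (long long v = x * x; v <= n; v *= x) {
            powers.insert(v);
            if (v > n / x) break;         // next v * x would exceed n
        }
    long long total = 0;
    for (long long p : powers) total += p * (n / p);
    return total;
}
```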

One naive solution would be to precompute all perfect powers up to N. However, there are roughly \sqrt N of them, which is far too many for N = 10^{18}.

Instead, let's consider all perfect powers x^y and group them by y. Some numbers, like 2^6 = 64, appear in several groups (y = 2, 3 and 6). We'll handle that later using inclusion-exclusion.

Let us fix each value of y and iterate over all possible values of x such that x^y \leq N. There can be at most N^{1/y} such values of x.

For y \geq 3 this works fine. For y = 2, however, it needs about \sqrt N \approx 10^9 iterations when N = 10^{18}, which is too slow. We'll optimize this case later.
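Here is a minimal sketch of the per-exponent loop for a fixed y. cappedPow and sumForExponent are hypothetical helper names. The power is capped at n + 1, similar to the pow helper in the editorialist's solution, so x^y never overflows even for y close to 60. The result is the raw f_y, meaning numbers like 64 are still counted once per exponent.

```cpp
#include <cassert>

const long long MOD = 1'000'000'007;

// Returns x^y if it is <= n, otherwise n + 1 (capped so nothing overflows).
long long cappedPow(long long x, int y, long long n) {
    long long r = 1;
    for (int k = 0; k < y; ++k) {
        if (r > n / x) return n + 1;      // r * x would exceed n
        r *= x;
    }
    return r;
}

// Raw f_y: sum over x >= 2 with x^y <= n of x^y * floor(n / x^y), mod MOD.
// Intended for y >= 3 (y = 2 also works but needs about sqrt(n) steps).
// Cost: O(n^(1/y) * y).
long long sumForExponent(long long n, int y) {
    assert(y >= 2);
    long long ans = 0;
    for (long long x = 2; ; ++x) {
        long long p = cappedPow(x, y, n);
        if (p > n) break;
        ans = (ans + (p * (n / p)) % MOD) % MOD;  // p * (n / p) <= n fits
    }
    return ans;
}
```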

This gives the contribution of every perfect power x^y, grouped by y. Let f_y denote the sum of contributions of all perfect y-th powers. Note that the contribution of 64 is included in f_2, f_3 and f_6.

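So, for every x, we must subtract from f_x the (already corrected) f_y of every multiple y > x of x. Processing x in decreasing order guarantees that each such f_y has been corrected before it is subtracted. The following pseudo-code does this:

for(int i = 60; i >= 2; i--){
    // Invariant: every f[j] with j > i is already corrected, i.e. it only
    // counts perfect powers whose largest exponent is exactly j
    for(int j = 2*i; j <= 60; j += i){
        f[i] -= f[j];
    }
}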

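Summing the corrected f_y over all y \geq 2 now counts every perfect power greater than 1 exactly once. The perfect power 1 has to be handled separately. It divides every i, so it contributes N, which is why the solutions below start the answer from N \bmod (10^9+7).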
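The inclusion-exclusion step can be checked on small N with the following sketch (answerWithInclusionExclusion is a hypothetical name). It builds the raw f_y without any modulus, which is why it is only valid for small N. It then runs the same subtraction loop as the pseudo-code above and adds N for the perfect power 1. For N = 64, the raw values are f_2 = 390, f_3 = 182, f_4 = f_5 = f_6 = 64, and the corrected total is 636.

```cpp
#include <array>
#include <cassert>

// Small-N illustration (no modulus): exponents 2..60 are enough for n <= 10^18.
constexpr int B = 61;

long long answerWithInclusionExclusion(long long n) {
    assert(n >= 1);
    std::array<long long, B> f{};
    // Raw f[y] = sum over x >= 2 with x^y <= n of x^y * floor(n / x^y).
    for (int y = 2; y < B; ++y)
        for (long long x = 2; ; ++x) {
            long long p = 1;
            bool over = false;
            for (int k = 0; k < y && !over; ++k) {
                if (p > n / x) over = true; else p *= x;
            }
            if (over) break;              // x^y > n, and larger x only grow
            f[y] += p * (n / p);
        }
    // Same loop as the pseudo-code: afterwards f[i] only counts powers
    // whose largest exponent is exactly i.
    for (int i = B - 1; i >= 2; --i)
        for (int j = 2 * i; j < B; j += i) f[i] -= f[j];
    long long total = n;                  // the perfect power 1 divides all of 1..n
    for (int y = 2; y < B; ++y) total += f[y];
    return total;
}
```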

The only thing left is to speed up the computation of f_2, since iterating over all x \leq \sqrt N is too slow.

Start from st = 2 and find the largest en such that \lfloor N/st^2 \rfloor = \lfloor N/en^2 \rfloor. Every perfect square x^2 with st \leq x \leq en is then counted the same number of times, K = \lfloor N/st^2 \rfloor, so the whole interval can be handled in a single update. The perfect squares in the interval are st^2, (st+1)^2, \ldots, en^2. Their total contribution is K \cdot \big[ st^2 + (st+1)^2 + \ldots + en^2 \big] = K \cdot \big[ sumOfSquares(en) - sumOfSquares(st-1) \big], where sumOfSquares(n) = 1^2 + 2^2 + \ldots + n^2 = \frac{n(n+1)(2n+1)}{6}.

For a fixed st, we still need the endpoint en of the interval. Binary search works, but it adds an extra \log N factor, which is too slow for the final subtask.

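The final observation: let K = \lfloor N/st^2 \rfloor. Then en = \lfloor \sqrt{\lfloor N/K \rfloor} \rfloor. This holds because \lfloor N/x^2 \rfloor \geq K exactly when x^2 \leq \lfloor N/K \rfloor, and \lfloor N/x^2 \rfloor never increases as x grows. So the endpoint can be computed directly. The square root must be exact, because a plain floating-point sqrt can be off by one for N close to 10^{18}. Move to the next interval by setting st = en + 1, and repeat while st^2 \leq N.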
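The interval walk for f_2 can be sketched in C++ as follows. isqrt, sumOfSquares and sumForSquares are hypothetical names. The exact square root uses the same "floating-point guess, then integer correction" idea as get_sqrt in the setter's solution. The constant 166666668 is the inverse of 6 modulo 10^9+7, which is the same value as the setter's (mod + 1) / 6.

```cpp
#include <cassert>
#include <cmath>

const long long MOD = 1'000'000'007;
const long long INV6 = 166666668;     // 6 * INV6 == 1 (mod MOD)

// Exact floor(sqrt(v)) for 0 <= v <= 10^18: start from a floating-point
// guess, then correct it with integer arithmetic.
long long isqrt(long long v) {
    assert(v >= 0);
    long long r = (long long)std::sqrt((long double)v);
    while (r > 0 && r * r > v) --r;
    while ((r + 1) * (r + 1) <= v) ++r;
    return r;
}

// 1^2 + 2^2 + ... + m^2 modulo MOD, via m(m+1)(2m+1)/6.
long long sumOfSquares(long long m) {
    m %= MOD;
    return m * (m + 1) % MOD * ((2 * m + 1) % MOD) % MOD * INV6 % MOD;
}

// f_2 = sum over x >= 2 of x^2 * floor(n / x^2), modulo MOD, walking the
// blocks [st, en] on which K = floor(n / x^2) is constant.
long long sumForSquares(long long n) {
    long long ans = 0;
    for (long long st = 2; st <= n / st; ) {      // st^2 <= n without overflow
        long long k = n / (st * st);
        long long en = isqrt(n / k);              // largest x with n / x^2 == k
        long long block = (sumOfSquares(en) - sumOfSquares(st - 1) + MOD) % MOD;
        ans = (ans + k % MOD * block) % MOD;
        st = en + 1;
    }
    return ans;
}
```

For N = 10 the blocks are [2, 2] with K = 2 and [3, 3] with K = 1, giving 4 \cdot 2 + 9 \cdot 1 = 17.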

Refer to the implementations below in case anything is not clear.

Exercise: Prove that, for a fixed N, \lfloor N/x^2 \rfloor takes only O(N^{1/3}) different values (in fact at most 2N^{1/3}).
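The exercise can be checked empirically with the sketch below, which is not a proof. countBlocks is a hypothetical name. It runs the interval walk described above and counts the intervals, i.e. the distinct values of \lfloor N/x^2 \rfloor for 2 \leq x \leq \sqrt N. For N = 64 it finds the 5 intervals with K = 16, 7, 4, 2, 1. For N = 10^{12} and N = 10^{18} the counts stay below 2 \cdot 10^4 and 2 \cdot 10^6 respectively.

```cpp
#include <cassert>
#include <cmath>

// Number of intervals visited by the f_2 walk, i.e. the number of distinct
// values of floor(n / x^2) over 2 <= x <= floor(sqrt(n)).
long long countBlocks(long long n) {
    assert(n >= 1);
    long long blocks = 0;
    for (long long st = 2; st <= n / st; ++blocks) {
        long long k = n / (st * st);
        long long en = (long long)std::sqrt((long double)(n / k));
        while (en * en > n / k) --en;                 // make en exact
        while ((en + 1) * (en + 1) <= n / k) ++en;
        st = en + 1;
    }
    return blocks;
}
```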

TIME COMPLEXITY

The overall time complexity is O(N^{1/3}) per test case.

SOLUTIONS:

Setter's Solution
/* Statement:
 * For positive integer n define f(n) as sum of all 
 * divisors of n, which are perfect powers.
 * Calculate sum f(1) + f(2) + .. + f(n) modulo 10^9 + 7
 * 
 * Solution:
 * Define F(n) = f(1) + f(2) + ... + f(n).
 * Let D(n, i) be the sum of all divisors of n, which are i-th perfect power.
 * We can calculate D(1, 2) + D(2, 2) + ... + D(n, 2) as follows:
 * Let's calculate for each number x how many times x^2 
 * will be added to D(n, 2). Obviously, it is [n / (x ^ 2)].
 * So D(n, 2) = 1 * [n / 1] + 4 * [n / 4] + 9 * [n / 9] + ...
 * Let's fix some l, and find for it maximum possible r, such that
 * [n / (l ^ 2)] = [n / (r ^ 2)] = k
 * It can be shown that r = [sqrt(n / k)]
 * So we can calculate D(n, 2) with complexity proportional to number of
 * segments with equal value [n / (x ^ 2)], and this number is O(n ^ 1/3).
 * 
 * Analogically, we can calculate D(n, i) for i > 2.
 * Also, we can just bruteforce them.
 * 
 * After that, we note that we can calculate some numbers more than once.
 * For example 64 will be included three times in those calculations
 * (as a perfect square, cube and sixth power).
 * So lets do inclusion-exclusion to avoid such situations.
 *  
 * Complexity: O(n ^ 1/3) 
 */ 
 
#include <bits/stdc++.h>
using namespace std;

const int mod = 1e9 + 7;
const int inv6 = (mod + 1) / 6;// = 1/6

inline void add(int& x, int y)
{
	x += y;
	if (x >= mod) x -= mod;
}

inline void sub(int& x, int y)
{
	x -= y;
	if (x < 0) x += mod;
}

inline int mult(int x, int y)
{
	return x * (long long) y % mod;
}

inline int sumSquares(int r)// 1^2 + 2^2 + ... + r^2 = r * (r + 1) * (2r + 1) / 6
{
	return mult(r, mult(r + 1, mult(2 * r + 1, inv6)));
}

inline int sum(int l, int r)
{
	int res = sumSquares(r);
	sub(res, sumSquares(l - 1));
	return res;
}

inline long long power(int x, int k, long long n)
{
	__int128 res = 1;
	while(k--) res *= x;
	if (res > n) res = n + 1;
	return (long long) res;
}

int solve(long long n, int k)//Complexity: O(n^(1/k) * k)
{
	int ans = 0;
	for(int i = 2; ; ++i)
	{
		long long d = power(i, k, n);
		if (d > n) break;
		add(ans, d * (n / d) % mod);
	}

	return ans;
}

inline int get_sqrt(long long x)
{
	int r = sqrt(x);
	//sqrt function can give a small error
	while(r * (long long) r > x) r--;
	while((r + 1) * (long long)(r + 1) <= x) r++;
	return r;
}

int solve2(long long n)//Complexity: O(n^1/3) (operations sqrt)
//correct solution
{
	int ans = 0;
	int l = 2;
	while(l * (long long) l <= n)
	{
		long long k = n / (l * (long long) l);
		int r = get_sqrt(n / k);//the heaviest place in program
		add(ans, mult(sum(l, r), k % mod));
		l = r + 1;
	}

	return ans;
}

int solve2BinSearch(long long n)//Complexity: O(n^1/3 * log n) (operations /)
//should give TLE on last subtask
{
	int ans = 0;
	int l = 2;
	while(l * (long long) l <= n)
	{
		long long k = n / (l * (long long) l);
		int L = l, R = (int)1e9 + 1, M;
		while(R - L > 1)
		{
			M = (L + R) >> 1;
			if (k == n / (M * (long long) M))//the heaviest place
				L = M;
			else
				R = M;
		}
	
		add(ans, mult(sum(l, L), k % mod));
		l = L + 1;
	}

	return ans;
}

const int D = 60;// 2^D > maximum possible n
int d[D];

int main()
{
	int tc;
	cin >> tc;
	while(tc--)
	{
		long long n;
		cin >> n;
	
		d[2] = solve2(n);
		for(int i = 3; i < D; ++i)
			d[i] = solve(n, i);
	
		//do inclusion-exclusion
		for(int i = D - 1; i >= 2; --i)
			for(int j = i + i; j < D; j += i)
				sub(d[i], d[j]);
	
		int ans = n % mod;
		for(int i = 2; i < D; ++i)
			add(ans, d[i]);
		cout << ans << endl;
	}

	cerr << "Time elapsed: " << clock() / (double) CLOCKS_PER_SEC << endl;
	return 0;
}
Tester's Solution
#include <bits/stdc++.h>
#include <unordered_map>
using namespace std;
template<typename T = int> vector<T> create(size_t n){ return vector<T>(n); }
template<typename T, typename... Args> auto create(size_t n, Args... args){ return vector<decltype(create<T>(args...))>(n, create<T>(args...)); }
template<typename T = int, T mod = 1'000'000'007, typename U = long long>
struct umod{
	T val;
	umod(): val(0){}
	umod(U x){ x %= mod; if(x < 0) x += mod; val = x;}
	umod& operator += (umod oth){ val += oth.val; if(val >= mod) val -= mod; return *this; }
	umod& operator -= (umod oth){ val -= oth.val; if(val < 0) val += mod; return *this; }
	umod& operator *= (umod oth){ val = ((U)val) * oth.val % mod; return *this; }
	umod& operator /= (umod oth){ return *this *= oth.inverse(); }
	umod& operator ^= (U oth){ return *this = pwr(*this, oth); }
	umod operator + (umod oth) const { return umod(*this) += oth; }
	umod operator - (umod oth) const { return umod(*this) -= oth; }
	umod operator * (umod oth) const { return umod(*this) *= oth; }
	umod operator / (umod oth) const { return umod(*this) /= oth; }
	umod operator ^ (long long oth) const { return umod(*this) ^= oth; }
	bool operator < (umod oth) const { return val < oth.val; }
	bool operator > (umod oth) const { return val > oth.val; }
	bool operator <= (umod oth) const { return val <= oth.val; }
	bool operator >= (umod oth) const { return val >= oth.val; }
	bool operator == (umod oth) const { return val == oth.val; }
	bool operator != (umod oth) const { return val != oth.val; }
	umod pwr(umod a, U b) const { umod r = 1; for(; b; a *= a, b >>= 1) if(b&1) r *= a; return r; }
	umod inverse() const {
	    U a = val, b = mod, u = 1, v = 0;
	    while(b){
	        U t = a/b;
	        a -= t * b; swap(a, b);
	        u -= t * v; swap(u, v);
	    }
	    if(u < 0)
	        u += mod;
	    return u;
	}
};
bool is_perfect(int x){
	if(x == 1) return true;
	for(int i = 2; i <= x; i++){
		if(x % i == 0){
			int c = x, cn = 0;
			while(c % i == 0) c /= i, cn++;
			if(c == 1 && cn >= 2) return true;
		}
	}
	return false;
}
int fdx(int i){
	int ans = 0;
	for(int j = 1; j <= i; j++){
		if((i % j) == 0){
			if(is_perfect(j)){
				ans += j;
			}
		}
	}
	return ans;
}
int solve(int n){
	int ans = 0;
	for(int i = 1; i <= n; i++){
		ans += fdx(i);
	}
	return ans;
}
using U = umod<>;
bool is_prime(int x){
	if(x <= 1) return false;
	for(int i = 2; i * i <= x; i++) if(x % i == 0) return false;
	return true;
}
int main(){
	ios::sync_with_stdio(false);
	cin.tie(0);
	vector<int> p;
	for(int i = 2; i <= 70; i++) if(is_prime(i)) p.push_back(i);
	int t; cin >> t;
	const int LIM = 2000001;
	vector<int> mn(LIM, 1<<30);
	for(int i = 2; i < LIM; i++){
		int sq = sqrt(i);
		while(sq * sq < i) sq++;
		if(sq * sq == i) mn[i] = 2;
	}
	for(int c : p){
		if(c == 2) continue;
		for(int i = 2; ; i++){
			int g = LIM, r = 1;
			for(int j = 0; j < c; j++) g /= i, r *= i;
			if(g == 0) break;
			mn[r] = min(mn[r], c);
		}
	}
	for(int _ = 1; _ <= t; _++){
		long long n = _; cin >> n;
		U ans = 0, i6 = U(1) / 6;
		auto sqr_sum = [&](U l){
			return (l * (l + 1) * (l * 2 + 1)) * i6;
		};
		auto sqr_sum_rng = [&](U l, U r){
			return sqr_sum(r) - sqr_sum(l - 1);
		};
		for(int pr : p){
			if(pr == 2){
				for(long long i = 1, j; i * i <= n; i = j + 1){
					long long v = n / (i * i);
					j = sqrt(n / v);
					while(j * j <= (n / v)) j++;
					j--;
					if(i <= j){
						ans += sqr_sum_rng(i, j) * v;
					}
				}
			} else {
				for(int i = 2; ; i++){
					if(mn[i] < pr) continue;
					long long tn = n, r = 1;
					for(int j = 0; j < pr; j++) tn /= i, r *= i;
					if(tn == 0) break;
					ans += U(r) * tn;
				}
			}
		}
		cout << ans.val << '\n';
		// cout << solve(n) << endl;
	}
	return 0;
}
Editorialist's Solution
import java.util.*;
import java.io.*;
import java.text.*;
class PPDIV{
	//SOLUTION BEGIN
	long MOD = (long)1e9+7;
	void pre() throws Exception{}
	long pow(long a, long p, long mx){
	    long o = 1;
	    while(p-- > 0){
	        o *= a;
	        if(o > mx)o = mx+1;
	    }
	    return o;
	}
	long sqrt(long n){
	    long x = (long)Math.sqrt(n);
	    while(x*x > n)x--;
	    while((x+1)*(x+1) <= n)x++;
	    return x;
	}
	long solve2(long n){
	    if(n < 4)return 0;
	    long curBase = 2, ans = 0, sqrtN = sqrt(n);
	    while(curBase <= n/curBase){
	        long V = n/(curBase*curBase);
	        long lo = sqrt(n/V);
	        ans += (n/(curBase*curBase)%MOD * (sumOfSquares(lo)+MOD-sumOfSquares(curBase-1)))%MOD;
	        if(ans >= MOD)ans -= MOD;
	        curBase = lo+1;
	    }
	    return ans;
	}
	long solve2BinarySearch(long n){
	    if(n < 4)return 0;
	    long curBase = 2, ans = 0, sqrtN = sqrt(n);
	    while(curBase <= n/curBase){
	        long lo = curBase, hi = sqrtN;
	        while(lo+1 < hi){
	            long mid = lo+(hi-lo)/2;
	            if(n/(curBase*curBase) == n/(mid*mid))lo = mid;
	            else hi = mid;
	        }
	        if(n/(curBase*curBase) == n/(hi*hi))lo = hi;
	        ans += (n/(curBase*curBase)%MOD * (sumOfSquares(lo)+MOD-sumOfSquares(curBase-1)))%MOD;
	        if(ans >= MOD)ans -= MOD;
	        curBase = lo+1;
	    }
	    return ans;
	}
	//Time complexity O(N^(1/k)*k)
	long solve(int power, long n){
	    long ans = 0;
	    for(int base = 2; ; base++){
	        long p = pow(base, power, n);
	        if(p > n)break;
	        ans += (n-n%p)%MOD;
	        if(ans >= MOD)ans -= MOD;
	    }
	    return ans;
	}
	void solve(int TC) throws Exception{
	    long n = nl();
	    int B = 60;
	    long[] sumOfPowers = new long[B];
	    sumOfPowers[2] = solve2(n);
	    for(int i = 3; i< B; i++)
	        sumOfPowers[i] = solve(i, n);
	    for(int i = B-1; i>= 2; i--){
	        for(int j = i+i; j< B; j+= i){
	            sumOfPowers[i] += MOD-sumOfPowers[j];
	            if(sumOfPowers[i] >= MOD)sumOfPowers[i] -= MOD;
	        }
	    }
	    long sum = n%MOD;
	    for(int i = 2; i< B; i++)sum = (sum+sumOfPowers[i])%MOD;
	    pn(sum);
	}
	long inv6 = inv(6);
	long sumOfSquares(long a, long b){
	    return (sumOfSquares(b)+MOD-sumOfSquares(a-1))%MOD;
	}
	long sumOfSquares(long n){
	    n%= MOD;
	    return (((n*(n+1))%MOD*(2*n+1))%MOD*inv6)%MOD;
	}
	long inv(long a){
	    long o = 1;
	    for(long p = MOD-2; p > 0; p>>=1){
	        if((p&1)==1)o = (o*a)%MOD;
	        a = (a*a)%MOD;
	    }
	    return o;
	}
	//SOLUTION END
	void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
	DecimalFormat df = new DecimalFormat("0.00000000000");
	static boolean multipleTC = true;
	FastReader in;PrintWriter out;
	void run() throws Exception{
	    in = new FastReader();
	    out = new PrintWriter(System.out);
	    //Solution Credits: Taranpreet Singh
	    pre();
	    int T = (multipleTC)?ni():1;
	    for(int t = 1; t<= T; t++)solve(t);
	    out.flush();
	    out.close();
	}
	public static void main(String[] args) throws Exception{
	    new PPDIV().run();
	}
	int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
	void p(Object o){out.print(o);}
	void pn(Object o){out.println(o);}
	void pni(Object o){out.println(o);out.flush();}
	String n()throws Exception{return in.next();}
	String nln()throws Exception{return in.nextLine();}
	int ni()throws Exception{return Integer.parseInt(in.next());}
	long nl()throws Exception{return Long.parseLong(in.next());}
	double nd()throws Exception{return Double.parseDouble(in.next());}

	class FastReader{
	    BufferedReader br;
	    StringTokenizer st;
	    public FastReader(){
	        br = new BufferedReader(new InputStreamReader(System.in));
	    }

	    public FastReader(String s) throws Exception{
	        br = new BufferedReader(new FileReader(s));
	    }

	    String next() throws Exception{
	        while (st == null || !st.hasMoreElements()){
	            try{
	                st = new StringTokenizer(br.readLine());
	            }catch (IOException  e){
	                throw new Exception(e.toString());
	            }
	        }
	        return st.nextToken();
	    }

	    String nextLine() throws Exception{
	        String str = "";
	        try{   
	            str = br.readLine();
	        }catch (IOException e){
	            throw new Exception(e.toString());
	        }  
	        return str;
	    }
	}
}

Feel free to share your approach. Suggestions are welcome, as always. :slight_smile:


Why is my solution giving WA on the last subtask? I'm unable to find any mistake.

Not sure about Python, but in C++ sqrt can give a precision error.
Test: 10^18 - 1
Correct answer: 127687580
Answer with sqrt instead of sqrtl: 480812650


Yes, my answer is 480812650. How do I remove this error?

In C++, write sqrtl or do something like this:
inline int get_sqrt(long long x)
{
	int r = sqrt(x);
	//sqrt function can give a small error
	while(r * (long long) r > x) r--;
	while((r + 1) * (long long)(r + 1) <= x) r++;
	return r;
}

Python probably has some high-precision sqrt too; you may google it.
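The precision problem in this thread can be reproduced in a few lines of C++. naiveSqrt and exactSqrt are hypothetical names. Doubles near 10^{18} are spaced 128 apart, so converting 10^{18} - 1 to double already rounds it to exactly 10^{18}. As a result, sqrt returns 10^9 instead of 999999999. Correcting the floating-point guess with integer arithmetic, as get_sqrt does, fixes it. In Python, math.isqrt (Python 3.8+) returns the exact integer square root directly.

```cpp
#include <cassert>
#include <cmath>

// Plain double sqrt: 10^18 - 1 needs 60 bits but a double keeps only 53,
// so the argument is rounded to exactly 10^18 before sqrt even runs.
long long naiveSqrt(long long v) { return (long long)std::sqrt((double)v); }

// Floating-point guess followed by integer correction, as in get_sqrt.
long long exactSqrt(long long v) {
    assert(v >= 0);
    long long r = (long long)std::sqrt((double)v);
    while (r * r > v) --r;
    while ((r + 1) * (r + 1) <= v) ++r;
    return r;
}
```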


Use sqrtl instead of sqrt.

It took me days to solve. Then @taran_1407 uploads the editorial saying, eh, it's easy. :laughing:


Can you confirm my time complexity is right?

Bad habits :stuck_out_tongue:

PS: I called it easy-med.


I don’t understand what is going on in your code, but for me it seems to be something bigger than O(n^1/3), isn’t it?

In your sqrt() function, use sqrt(1.0L*x) instead of sqrt(x).

I don't know how the same logic gives AC in C++ but WA in Python :frowning: .
Python solution link: Click Here
C++ solution link: Click Here

Can anyone justify this ? :thinking: :thinking: :thinking: :thinking: :thinking:

Test your code on this test case: PPDIV - Editorial

Can Möbius inversion be used for this problem?

Yeah, c++ gave right answer, python gave wrong answer :slight_smile: .

If you prefer a video solution:


I was saying exactly the same in the comment section. This should have been fixed during the contest. It completely wasted my time. My code works on every test case; I ran random test cases for all the powers from 2 to 18, but it still gave WA only on TC <= 15, i.e. Subtask 5.

The main error in Python was:
>>> a = sqrt(10 ** 18 - 1)
>>> a
1000000000.0

If we look at the people scoring 60, most of them actually solved it correctly and missed full AC only because of the precision error. I myself submitted more than 25 times but didn't get more than 60. Also, there is no reference answer to check your output against for such large values, since a brute-force solution would take hours.


Yep, wasted so much time on this. I didn't think the issue was the precision of the sqrt function.