PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: Vibhu Garg
Testers: Satyam, Jatin Garg
Editorialist: Nishank Suresh
DIFFICULTY:
2015
PREREQUISITES:
Prime factorization, Fermat’s little theorem, sum of a geometric progression, binary exponentiation
PROBLEM:
Given an integer N, you can do the following operation exactly K times:
- Pick a positive divisor d of the current value of N and set N \gets N\times d
Find the sum of all possible final values of N that can be obtained, modulo 10^9 + 7.
EXPLANATION:
Consider the prime factorization of N, say N = p_1^{a_1}p_2^{a_2}\ldots p_r^{a_r}.
Note that multiplying N by a factor of itself cannot increase (or decrease) the number of distinct primes in its factorization: it only increases the value of the a_i of the existing p_i.
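In particular, every divisor d has the form p_1^{b_1}p_2^{b_2}\ldots p_r^{b_r} with 0 \leq b_i \leq a_i, so performing the move once allows us to set each a_i, independently of the others, to any value in the range [a_i, 2a_i].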
Once you observe this, it is also not hard to see that after k moves, the final value of a_i can be anything in the range [a_i, 2^k \cdot a_i].
Proof
This can be proved with induction.
For k = 1, we already know that the range is [a_i, 2a_i]. Now, consider some k \gt 1.
By the inductive hypothesis, after the first k-1 moves, the exponent can be anything in the range [a_i, 2^{k-1} \cdot a_i].
Consider any x \in [a_i, 2^{k} \cdot a_i].
- If x \leq 2^{k-1} \cdot a_i, then we can reach x using the first k-1 moves and then leave this exponent untouched on the k-th (by choosing a divisor d that is not divisible by p_i).
- If x \gt 2^{k-1} \cdot a_i, then use the first k-1 moves to reach 2^{k-1} \cdot a_i, and the k-th to reach x.
This completes the proof.
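The claim can also be sanity-checked by brute force on a small input. The sketch below is only an illustration: the helper names `divisors`, `reachable` and `exponent_of` are made up for this example, and the approach is far too slow for the real constraints.

```python
def divisors(n):
    """All positive divisors of n (naive, fine for tiny n)."""
    return [d for d in range(1, n + 1) if n % d == 0]

def reachable(n, k):
    """All values obtainable from n after exactly k moves (brute force)."""
    current = {n}
    for _ in range(k):
        current = {v * d for v in current for d in divisors(v)}
    return current

def exponent_of(p, n):
    """Exponent of the prime p in n."""
    e = 0
    while n % p == 0:
        n //= p
        e += 1
    return e

# N = 12 = 2^2 * 3^1 with k = 2 moves
values = reachable(12, 2)
print(sorted({exponent_of(2, v) for v in values}))  # [2, 3, 4, 5, 6, 7, 8]
print(sorted({exponent_of(3, v) for v in values}))  # [1, 2, 3, 4]
print(len(values))  # 28, i.e. every combination of the two ranges
```

The exponent of 2 covers [2, 2^2 \cdot 2] and the exponent of 3 covers [1, 2^2 \cdot 1], and the 7 \times 4 = 28 distinct values show that the choices for different primes are independent.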
So, we know exactly which set of numbers can be formed, in terms of their prime factorizations. Now, we need to compute their sum.
This can be done by modifying a well-known algorithm that computes the sum of divisors of N from its prime factors.
If you haven't heard of this
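Suppose N = p_1^{a_1}p_2^{a_2}\ldots p_r^{a_r}. Then, if S denotes the sum of all of its divisors, we have
S = (1 + p_1 + p_1^2 + \ldots + p_1^{a_1})(1 + p_2 + p_2^2 + \ldots + p_2^{a_2}) \ldots (1 + p_r + p_r^2 + \ldots + p_r^{a_r})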
It’s easy to see that this expression, when expanded out, gives us the sum of all divisors of N: each divisor is defined by choosing an exponent b_i for p_i such that 0 \leq b_i \leq a_i, and any such choice of b_i gives us a distinct factor.
Note that S is now the product of several geometric progressions, and each of those can be individually computed using the formula for the sum of a geometric progression.
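As a quick illustration of the sum-of-divisors idea, here is a minimal sketch without any modulus. The function names are hypothetical, and the factorization is assumed to be given as a dictionary mapping each prime to its exponent.

```python
def divisor_sum(factors):
    """Sum of divisors from a {prime: exponent} map, one GP per prime."""
    total = 1
    for p, a in factors.items():
        # 1 + p + ... + p^a = (p^(a+1) - 1) / (p - 1), an exact division
        total *= (p ** (a + 1) - 1) // (p - 1)
    return total

def divisor_sum_brute(n):
    """Reference value: add up every divisor directly."""
    return sum(d for d in range(1, n + 1) if n % d == 0)

# 12 = 2^2 * 3: (1 + 2 + 4) * (1 + 3) = 7 * 4 = 28
print(divisor_sum({2: 2, 3: 1}), divisor_sum_brute(12))  # 28 28
```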
Applying the above idea, we see that the answer to our problem is nothing but:
\prod_{i=1}^{r} \left(p_i^{a_i} + p_i^{a_i+1} + \ldots + p_i^{2^k \cdot a_i}\right)
Each expression above is, once again, a geometric progression: starting from p_i^{a_i} with ratio p_i and 2^k \cdot a_i - a_i + 1 terms. Knowing all this information, the value of each expression can be calculated using the sum of GP formula in \mathcal{O}(\log MOD).
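For a concrete feel, the per-prime sum can be sketched as follows. This is an illustration only: `range_sum` is a name chosen here, and the number of terms is computed with Python's arbitrary-precision integers, so it is meant for small k (large k needs the exponent handled differently).

```python
MOD = 10**9 + 7

def range_sum(p, a, k):
    """p^a + p^(a+1) + ... + p^(2^k * a) modulo MOD, for small k."""
    terms = (2 ** k) * a - a + 1          # number of terms in the GP
    first = pow(p, a, MOD)                # first term p^a
    numerator = first * (pow(p, terms, MOD) - 1) % MOD
    # divide by (ratio - 1) via the modular inverse
    return numerator * pow(p - 1, MOD - 2, MOD) % MOD

# N = 12 = 2^2 * 3, K = 1: (4 + 8 + 16) * (3 + 9) = 28 * 12 = 336
print(range_sum(2, 2, 1) * range_sum(3, 1, 1) % MOD)  # 336
```

The value 336 matches the direct enumeration 12 + 24 + 36 + 48 + 72 + 144 of the six numbers reachable from 12 in one move.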
This computation is done r times in total, where r is the number of distinct prime factors N has. An easy bound for r is \log_2 N, so if N has been prime-factorized, the remaining part is accomplished in \mathcal{O}(\log N \log{MOD}) time.
Finally, we need to actually prime factorize N. Although faster algorithms exist, the constraints allow a simple \mathcal{O}(\sqrt N) factorization to pass without much issue.
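The \mathcal{O}(\sqrt N) factorization mentioned above is plain trial division. A minimal sketch, with `factorize` being a name chosen for this example:

```python
def factorize(n):
    """Return {prime: exponent} for n >= 1 by trial division up to sqrt(n)."""
    factors = {}
    d = 2
    while d * d <= n:
        while n % d == 0:
            factors[d] = factors.get(d, 0) + 1
            n //= d
        d += 1
    if n > 1:
        # whatever remains is a single prime larger than sqrt of the original n
        factors[n] = factors.get(n, 0) + 1
    return factors

print(factorize(360))  # {2: 3, 3: 2, 5: 1}
```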
There is one final caveat: when computing the sum of a GP for a given prime, you might need to compute a number of the form a^b \pmod {10^9 + 7} where b is extremely large. In fact, b can be as large as 2^k \cdot 20, which for k = 10^5 doesn’t fit into any datatype C++ has.
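However, there is a solution to this: Fermat’s little theorem. It states that a^{MOD-1} \equiv 1 \pmod{MOD} whenever MOD is prime and a is not divisible by MOD, so the exponent can be reduced modulo MOD-1. Here every base is a prime factor of N, and N is far smaller than MOD under the given constraints, so the condition holds.
So, when computing a^b, first compute b modulo MOD-1 (the 2^k part can itself be computed in \mathcal{O}(\log k) using binary exponentiation), then use that computed value to compute a^b \pmod{MOD}.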
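The exponent reduction can be sketched like this. `pow_reduced` is a hypothetical helper for exponents of the form c \cdot 2^k, and it assumes the base is not divisible by MOD. The comparison against the unreduced exponent is only feasible because Python integers are unbounded.

```python
MOD = 10**9 + 7

def pow_reduced(a, c, k):
    """a^(c * 2^k) modulo MOD, assuming a is not divisible by MOD."""
    # Fermat's little theorem: exponents may be taken modulo MOD - 1
    exponent = c * pow(2, k, MOD - 1) % (MOD - 1)
    return pow(a, exponent, MOD)

# Cross-check against the unreduced exponent (a big integer in Python)
for k in (1, 40, 1000):
    print(pow_reduced(3, 5, k) == pow(3, 5 * 2 ** k, MOD))  # True each time
```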
TIME COMPLEXITY
\mathcal{O}(\sqrt{N} + \log{N}\log{MOD}) per test case.
CODE:
Setter's code (C++)
#include <bits/stdc++.h>
#define ll long long int
#define mod 1000000007
using namespace std;
ll binpow(ll a, ll b, ll m){
    a %= m;
    ll res = 1, mult = a;
    while(b){
        if(b & 1ll){
            res = (res * mult) % m;
        }
        mult = (mult * mult) % m;
        b >>= 1;
    }
    return res;
}
int main(){
    #ifndef ONLINE_JUDGE
        freopen("input6.txt", "r", stdin);
        freopen("output6.txt", "w", stdout);
    #endif
    ll t;
    cin >> t;
    while(t--){
        ll n, k;
        cin >> n >> k;
        map <ll, ll> primePowers;
        while(n % 2 == 0){
            n /= 2;
            primePowers[2]++;
        }
        for(ll i = 3; i * i <= n; i += 2){
            while(n % i == 0){
                primePowers[i]++;
                n /= i;
            }
        }
        if(n > 1) primePowers[n]++;
        ll ans = 1;
        map <ll, ll> seriesSums;
        for(auto p : primePowers){
            ll pw = (binpow(2, k, mod - 1) * p.second + 1) % (mod - 1);
            ll num = (binpow(p.first, pw, mod) - binpow(p.first, p.second, mod) + mod);
            ll den = binpow(p.first - 1, mod - 2, mod);
            seriesSums[p.first] = (num * den) % mod;
        }
        for(auto p : seriesSums){
            ans = (ans * p.second) % mod;
        }
        cout << ans << endl;
    }
    return 0;
}
// 2
// 4
// 16
// 8
// 10
// 100
// 20
// 50
// 6
// 36
// 36
Tester (rivalq)'s code (C++)
// Jai Shree Ram
#include<bits/stdc++.h>
using namespace std;
#define rep(i,a,n) for(int i=a;i<n;i++)
#define ll long long
#define int long long
#define pb push_back
#define all(v) v.begin(),v.end()
#define endl "\n"
#define x first
#define y second
#define gcd(a,b) __gcd(a,b)
#define mem1(a) memset(a,-1,sizeof(a))
#define mem0(a) memset(a,0,sizeof(a))
#define sz(a) (int)a.size()
#define pii pair<int,int>
#define hell 1000000007
#define elasped_time 1.0 * clock() / CLOCKS_PER_SEC
template<typename T1,typename T2>istream& operator>>(istream& in,pair<T1,T2> &a){in>>a.x>>a.y;return in;}
template<typename T1,typename T2>ostream& operator<<(ostream& out,pair<T1,T2> a){out<<a.x<<" "<<a.y;return out;}
template<typename T,typename T1>T maxs(T &a,T1 b){if(b>a)a=b;return a;}
template<typename T,typename T1>T mins(T &a,T1 b){if(b<a)a=b;return a;}
// -------------------- Input Checker Start --------------------
long long readInt(long long l, long long r, char endd)
{
long long x = 0;
int cnt = 0, fi = -1;
bool is_neg = false;
while(true)
{
char g = getchar();
if(g == '-')
{
assert(fi == -1);
is_neg = true;
continue;
}
if('0' <= g && g <= '9')
{
x *= 10;
x += g - '0';
if(cnt == 0)
fi = g - '0';
cnt++;
assert(fi != 0 || cnt == 1);
assert(fi != 0 || is_neg == false);
assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
}
else if(g == endd)
{
if(is_neg)
x = -x;
if(!(l <= x && x <= r))
{
cerr << l << ' ' << r << ' ' << x << '\n';
assert(false);
}
return x;
}
else
{
assert(false);
}
}
}
string readString(int l, int r, char endd)
{
string ret = "";
int cnt = 0;
while(true)
{
char g = getchar();
assert(g != -1);
if(g == endd)
break;
cnt++;
ret += g;
}
assert(l <= cnt && cnt <= r);
return ret;
}
long long readIntSp(long long l, long long r) { return readInt(l, r, ' '); }
long long readIntLn(long long l, long long r) { return readInt(l, r, '\n'); }
string readStringLn(int l, int r) { return readString(l, r, '\n'); }
string readStringSp(int l, int r) { return readString(l, r, ' '); }
void readEOF() { assert(getchar() == EOF); }
vector<int> readVectorInt(int n, long long l, long long r)
{
vector<int> a(n);
for(int i = 0; i < n - 1; i++)
a[i] = readIntSp(l, r);
a[n - 1] = readIntLn(l, r);
return a;
}
// -------------------- Input Checker End --------------------
const int MOD = hell;
struct mod_int {
int val;
mod_int(long long v = 0) {
if (v < 0)
v = v % MOD + MOD;
if (v >= MOD)
v %= MOD;
val = v;
}
static int mod_inv(int a, int m = MOD) {
int g = m, r = a, x = 0, y = 1;
while (r != 0) {
int q = g / r;
g %= r; swap(g, r);
x -= q * y; swap(x, y);
}
return x < 0 ? x + m : x;
}
explicit operator int() const {
return val;
}
mod_int& operator+=(const mod_int &other) {
val += other.val;
if (val >= MOD) val -= MOD;
return *this;
}
mod_int& operator-=(const mod_int &other) {
val -= other.val;
if (val < 0) val += MOD;
return *this;
}
static unsigned fast_mod(uint64_t x, unsigned m = MOD) {
#if !defined(_WIN32) || defined(_WIN64)
return x % m;
#endif
unsigned x_high = x >> 32, x_low = (unsigned) x;
unsigned quot, rem;
asm("divl %4\n"
: "=a" (quot), "=d" (rem)
: "d" (x_high), "a" (x_low), "r" (m));
return rem;
}
mod_int& operator*=(const mod_int &other) {
val = fast_mod((uint64_t) val * other.val);
return *this;
}
mod_int& operator/=(const mod_int &other) {
return *this *= other.inv();
}
friend mod_int operator+(const mod_int &a, const mod_int &b) { return mod_int(a) += b; }
friend mod_int operator-(const mod_int &a, const mod_int &b) { return mod_int(a) -= b; }
friend mod_int operator*(const mod_int &a, const mod_int &b) { return mod_int(a) *= b; }
friend mod_int operator/(const mod_int &a, const mod_int &b) { return mod_int(a) /= b; }
mod_int& operator++() {
val = val == MOD - 1 ? 0 : val + 1;
return *this;
}
mod_int& operator--() {
val = val == 0 ? MOD - 1 : val - 1;
return *this;
}
mod_int operator++(int32_t) { mod_int before = *this; ++*this; return before; }
mod_int operator--(int32_t) { mod_int before = *this; --*this; return before; }
mod_int operator-() const {
return val == 0 ? 0 : MOD - val;
}
bool operator==(const mod_int &other) const { return val == other.val; }
bool operator!=(const mod_int &other) const { return val != other.val; }
mod_int inv() const {
return mod_inv(val);
}
mod_int pow(long long p) const {
assert(p >= 0);
mod_int a = *this, result = 1;
while (p > 0) {
if (p & 1)
result *= a;
a *= a;
p >>= 1;
}
return result;
}
friend ostream& operator<<(ostream &stream, const mod_int &m) {
return stream << m.val;
}
friend istream& operator >> (istream &stream, mod_int &m) {
return stream>>m.val;
}
};
#define SIEVE
const int N = 1e7 + 5;
int lp[N+1];
int pr[N];int t=0;
void sieve(){
for (int i=2; i<N; ++i) {
if (lp[i] == 0) {
lp[i] = i;
pr[t++]=i;
}
for (int j=0; j<t && pr[j]<=lp[i] && i*pr[j]<N; ++j)
lp[i * pr[j]] = pr[j];
}
}
int expo(int x,int y,int p){
    int a=1;
    x%=p;
    while(y){
        if(y&1)a=(a*x)%p;
        x=(x*x)%p;
        y/=2;
    }
    return a;
}
//(1 + p + p^2 .... p^(n - 1)) = (p^n - 1)/(p - 1)
//p^(2^k) % mod
int solve(){
    int n = readIntSp(1,1e7);
    int k = readIntLn(1,1e5);
    vector<pii> primes;
    mod_int ans = n;
    while(n > 1){
        int t = lp[n];
        int cnt = 0;
        while(t == lp[n]){
            n /= t;
            cnt++;
        }
        primes.push_back({t,cnt});
        int pw = (expo(2,k,hell - 1) - 1)*cnt % (hell - 1);
        pw = (pw + 1)%(hell - 1);
        mod_int val = (mod_int(t).pow(pw) - 1)/(mod_int(t) - 1);
        ans *= val;
    }
    cout << ans << endl;
    return 0;
}
signed main(){
ios_base::sync_with_stdio(0);cin.tie(0);cout.tie(0);
//freopen("input.txt", "r", stdin);
//freopen("output.txt", "w", stdout);
#ifdef SIEVE
sieve();
#endif
#ifdef NCR
init();
#endif
int t = readIntLn(1,1000);
while(t--){
solve();
}
return 0;
}
Tester (satyam_343)'s code (C++)
#include <bits/stdc++.h>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;
using namespace std;
#define ll long long
const ll INF_MUL=1e13;
const ll INF_ADD=1e18;
#define pb push_back
#define mp make_pair
#define nline "\n"
#define f first
#define s second
#define pll pair<ll,ll>
#define all(x) x.begin(),x.end()
#define vl vector<ll>
#define vvl vector<vector<ll>>
#define vvvl vector<vector<vector<ll>>>
#ifndef ONLINE_JUDGE
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);
#endif
void _print(int x){cerr<<x;}
void _print(ll x){cerr<<x;}
void _print(char x){cerr<<x;}
void _print(string x){cerr<<x;}
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
template<class T,class V> void _print(pair<T,V> p) {cerr<<"{"; _print(p.first);cerr<<","; _print(p.second);cerr<<"}";}
template<class T>void _print(vector<T> v) {cerr<<" [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T>void _print(set<T> v) {cerr<<" [ "; for (T i:v){_print(i); cerr<<" ";}cerr<<"]";}
template<class T>void _print(multiset<T> v) {cerr<< " [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T,class V>void _print(map<T, V> v) {cerr<<" [ "; for(auto i:v) {_print(i);cerr<<" ";} cerr<<"]";}
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
//--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
const ll MOD=1e9+7;
const ll MAX=500500;
ll binpow(ll a,ll b,ll MOD){
    ll ans=1;
    a%=MOD;
    while(b){
        if(b&1)
            ans=(ans*a)%MOD;
        b/=2;
        a=(a*a)%MOD;
    }
    return ans;
}
ll inverse(ll a,ll MOD){
    return binpow(a,MOD-2,MOD);
}
ll gt(ll n,ll freq,ll k){
    debug(mp(n,mp(freq,k)));
    ll pw=(binpow(2,k,MOD-1)*freq)%(MOD-1);
    ll now=(binpow(n,pw+1,MOD)-binpow(n,freq,MOD)+MOD)*inverse(n-1,MOD);
    now%=MOD;
    return now;
}
void solve(){
    ll n,k; cin>>n>>k;
    ll ans=1;
    for(ll i=2;i*i<=n;i++){
        if(n%i){
            continue;
        }
        ll freq=0;
        while((n%i)==0){
            n/=i;
            freq++;
        }
        ans=(ans*gt(i,freq,k))%MOD;
    }
    if(n!=1){
        ans=(ans*gt(n,1,k))%MOD;
    }
    cout<<ans<<nline;
    return;
}
int main()
{
ios_base::sync_with_stdio(false);
cin.tie(NULL);
#ifndef ONLINE_JUDGE
freopen("input.txt", "r", stdin);
freopen("output.txt", "w", stdout);
freopen("error.txt", "w", stderr);
#endif
ll test_cases=1;
cin>>test_cases;
while(test_cases--){
solve();
}
cout<<fixed<<setprecision(10);
cerr<<"Time:"<<1000*((double)clock())/(double)CLOCKS_PER_SEC<<"ms\n";
}
Editorialist's code (Python)
mod = int(10**9 + 7)
def solve(p, a, k):
    # compute p^a + p^{a+1} + ... + p^{2^k a}
    first = pow(p, a, mod)
    ratio = p
    terms = (pow(2, k, mod-1)*a - a + 1)%(mod - 1)
    res = (first * (pow(ratio, terms, mod) - 1)) % mod
    res *= pow(ratio - 1, mod-2, mod)
    return res%mod
for _ in range(int(input())):
    ans = 1
    n, k = map(int, input().split())
    for i in range(2, n+1):
        if i*i > n:
            break
        if n%i != 0:
            continue
        ct = 0
        while n%i == 0:
            n //= i
            ct += 1
        ans *= solve(i, ct, k)
        ans %= mod
    if n > 1:
        ans *= solve(n, 1, k)
        ans %= mod
    print(ans)