DISTNUMS - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: Vibhu Garg
Testers: Satyam, Jatin Garg
Editorialist: Nishank Suresh

DIFFICULTY:

2015

PREREQUISITES:

Prime factorization, Fermat’s little theorem, sum of a geometric progression, binary exponentiation

PROBLEM:

Given an integer N, you can do the following operation exactly K times:

  • Pick a positive divisor d of the current value of N and set N \gets N\times d

Find the sum of all distinct final values of N that can be obtained, modulo 10^9 + 7.

EXPLANATION:

Consider the prime factorization of N, say

N = p_1^{a_1}p_2^{a_2}\ldots p_r^{a_r}

Note that multiplying N by one of its divisors can neither add a new prime to its factorization nor remove one: it can only increase some of the exponents a_i of the existing primes p_i.
In particular, since a divisor d = p_1^{b_1}p_2^{b_2}\ldots p_r^{b_r} may use any exponents 0 \leq b_i \leq a_i, performing the move once allows us to set each a_i, independently of the others, to any value in the range [a_i, 2a_i].

Once you observe this, it is also not hard to see that after k moves (k = K in the problem), the final value of each a_i can be anything in the range [a_i, 2^k \cdot a_i], independently of the other primes.

Proof

This can be proved with induction.
For k = 1, we already know that the range is [a_i, 2a_i]. Now, consider some k \gt 1.

By the inductive hypothesis, after the first k-1 moves, the exponent can be anything in the range [a_i, 2^{k-1} \cdot a_i].
Consider any x \in [a_i, 2^{k} \cdot a_i].

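  • If x \leq 2^{k-1} \cdot a_i, then we can reach x using the first k-1 moves, and on the k-th move pick a divisor that is not divisible by p_i, so the exponent stays x.
  • If x \gt 2^{k-1} \cdot a_i, then use the first k-1 moves to reach 2^{k-1} \cdot a_i, and the k-th to reach x; this is possible because x \leq 2^k \cdot a_i = 2 \cdot (2^{k-1} \cdot a_i).

Since every move chooses the exponent of each prime independently, this argument works for all primes simultaneously. This completes the proof.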
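To sanity-check the claim on small inputs, the following brute-force sketch (not part of any reference solution; the helpers `simulate` and `predicted` are hypothetical names) applies the operation exactly k times to every reachable value and compares the result with the set of numbers whose exponents lie in [a_i, 2^k \cdot a_i]. The brute force blows up quickly, so it is only meant for tiny N and k.

```python
from itertools import product
from math import isqrt, prod


def divisors(n):
    """All positive divisors of n, found by trial division up to sqrt(n)."""
    small = [d for d in range(1, isqrt(n) + 1) if n % d == 0]
    return set(small) | {n // d for d in small}


def simulate(n, k):
    """Every value reachable from n after exactly k operations N <- N * d."""
    current = {n}
    for _ in range(k):
        current = {x * d for x in current for d in divisors(x)}
    return current


def predicted(factors, k):
    """Numbers whose exponent of each prime p lies in [a, 2^k * a],
    where factors maps each prime p of N to its exponent a."""
    choices = [[p ** e for e in range(a, 2 ** k * a + 1)] for p, a in factors.items()]
    return {prod(combo) for combo in product(*choices)}


# Small hand-factorized cases: N -> {prime: exponent}.
cases = {1: {}, 6: {2: 1, 3: 1}, 8: {2: 3}, 12: {2: 2, 3: 1}, 30: {2: 1, 3: 1, 5: 1}}
for n, factors in cases.items():
    for k in (1, 2, 3):
        assert simulate(n, k) == predicted(factors, k), (n, k)
```

For example, `simulate(6, 1)` is `{6, 12, 18, 36}`, exactly the numbers 2^x 3^y with x, y \in [1, 2].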

So, we know exactly which set of numbers can be formed, in terms of their prime factorizations. Now, we need to compute their sum.

This can be done by modifying a well-known formula that computes the sum of divisors of N from its prime factorization.

If you haven't heard of this

Suppose N = p_1^{a_1}p_2^{a_2}\ldots p_r^{a_r}. Then, if S denotes the sum of all of its divisors, we have

S = (1 + p_1 + p_1^2 + \ldots + p_1^{a_1}) (1 + p_2 + \ldots + p_2^{a_2})\ldots (1 + p_r + \ldots + p_r^{a_r})

It’s easy to see that this expression, when expanded out, gives us the sum of all divisors of N: each divisor is determined by choosing an exponent b_i for each p_i such that 0 \leq b_i \leq a_i, and every such choice gives a distinct divisor, which appears exactly once in the expansion.
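The formula can be checked directly on small numbers. This sketch (function names are hypothetical) takes the factorization as a dictionary {p_i: a_i}, multiplies the per-prime geometric progressions, and compares the result against a brute-force divisor sum.

```python
from math import prod


def sum_of_divisors(factors):
    """Sum of all divisors of N = prod(p ** a), given factors = {p: a}.

    Each prime contributes the geometric progression 1 + p + ... + p^a,
    and the divisor sum is the product of these contributions.
    """
    return prod(sum(p ** e for e in range(a + 1)) for p, a in factors.items())


def sum_of_divisors_brute(n):
    """Direct definition, used only to check the formula on small n."""
    return sum(d for d in range(1, n + 1) if n % d == 0)


# 12 = 2^2 * 3 and 360 = 2^3 * 3^2 * 5
assert sum_of_divisors({2: 2, 3: 1}) == sum_of_divisors_brute(12) == 28
assert sum_of_divisors({2: 3, 3: 2, 5: 1}) == sum_of_divisors_brute(360) == 1170
```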

Note that S is now the product of several geometric progressions, and each of those can be individually computed using the formula for the sum of a geometric progression.
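For reference, here is the standard derivation of that closed form, for a geometric progression with first term f, ratio r \neq 1 and m terms (the letters f, r, m are introduced only for this derivation):

```latex
S = f + fr + fr^2 + \ldots + fr^{m-1}
rS = fr + fr^2 + \ldots + fr^{m}
rS - S = fr^{m} - f \quad\Longrightarrow\quad S = f \cdot \frac{r^{m} - 1}{r - 1}
```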

Applying the above idea, we see that the answer to our problem is nothing but:

(p_1^{a_1} + p_1^{a_1 + 1} + \ldots + p_1^{2^k \cdot a_1})(p_2^{a_2} + \ldots + p_2^{2^k \cdot a_2})\ldots (p_r^{a_r} + \ldots + p_r^{2^k \cdot a_r})
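As a small sanity check (an example chosen here, not one of the official samples), take N = 6 = 2^1 \cdot 3^1 and K = 1. The divisors of 6 are 1, 2, 3, 6, so the reachable values are 6, 12, 18, 36, and the product formula gives the same sum:

```latex
(2^1 + 2^2)(3^1 + 3^2) = 6 \cdot 12 = 72 = 6 + 12 + 18 + 36
```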

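Each expression above is, once again, a geometric progression: its first term is p_i^{a_i}, its ratio is p_i, and it has 2^k \cdot a_i - a_i + 1 terms, so its value is p_i^{a_i} \cdot \frac{p_i^{2^k \cdot a_i - a_i + 1} - 1}{p_i - 1}. Since 2 \leq p_i < 10^9 + 7, the denominator p_i - 1 is invertible modulo 10^9 + 7, so each expression can be calculated using the sum of GP formula in \mathcal{O}(\log MOD).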
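A minimal Python sketch of one factor of this product, assuming MOD = 10^9 + 7 and that the exponent is still small enough to pass to `pow` directly (large k needs the exponent reduction discussed further below). The division by p_i - 1 is done by multiplying with (p_i - 1)^{MOD-2}, its modular inverse by Fermat's little theorem. The name `factor_sum` is hypothetical.

```python
MOD = 10**9 + 7


def factor_sum(p, a, k):
    """p^a + p^(a+1) + ... + p^(2^k * a) modulo MOD, via the GP closed form.

    First term p^a, ratio p, and m = 2^k * a - a + 1 terms.
    The exponent m is used as-is here (no reduction modulo MOD - 1).
    """
    first = pow(p, a, MOD)
    m = (2 ** k) * a - a + 1
    numerator = first * (pow(p, m, MOD) - 1) % MOD
    # Divide by (p - 1) using its modular inverse (p - 1)^(MOD - 2).
    return numerator * pow(p - 1, MOD - 2, MOD) % MOD
```

For N = 6, K = 1, `factor_sum(2, 1, 1) * factor_sum(3, 1, 1) % MOD` gives 6 \cdot 12 = 72.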

This computation is done r times in total, where r is the number of distinct prime factors N has. An easy bound for r is \log_2 N, so if N has been prime-factorized, the remaining part is accomplished in \mathcal{O}(\log N \log{MOD}) time.

Finally, we need to actually prime factorize N. Although faster algorithms exist, the constraints allow a simple \mathcal{O}(\sqrt N) factorization to pass without much issue.
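A straightforward \mathcal{O}(\sqrt N) trial-division factorization, in the same spirit as the reference solutions below, might look like this (the name `factorize` is hypothetical):

```python
def factorize(n):
    """Return a list of (prime, exponent) pairs of n by trial division.

    Each candidate i is divided out completely while i * i <= n; since
    smaller primes are already removed, only primes ever divide n here.
    Whatever remains > 1 at the end is itself a prime factor.
    """
    factors = []
    i = 2
    while i * i <= n:
        if n % i == 0:
            count = 0
            while n % i == 0:
                n //= i
                count += 1
            factors.append((i, count))
        i += 1
    if n > 1:
        factors.append((n, 1))
    return factors
```

For instance, `factorize(360)` returns `[(2, 3), (3, 2), (5, 1)]`.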

There is one final caveat: when computing the sum of a GP for a given prime, you might need to compute a number of the form a^b \pmod {10^9 + 7} where b is extremely large. In fact, b can be around 2^k \cdot \log_2 N (each a_i is at most \log_2 N), which for k = 10^5 doesn’t fit into any built-in integer type in C++.

However, there is a solution to this: Fermat’s little theorem. It states that a^{MOD-1} \equiv 1 \pmod{MOD} whenever MOD is prime and a is not a multiple of MOD, so the exponent can be reduced modulo MOD-1. Here the base is a prime p_i \leq N < MOD, so this applies.
So, when computing a^b, first compute b modulo MOD-1 (for b = 2^k \cdot a_i + 1 this means computing 2^k \bmod (MOD-1), which binary exponentiation does in \mathcal{O}(\log k)), then use that value to compute a^b \pmod{MOD}.
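The following sketch (the helper name `pow_huge_exponent` is hypothetical) computes p^{2^k \cdot a + 1} \bmod (10^9 + 7) without ever forming 2^k: it reduces 2^k modulo MOD - 1 first. Python's arbitrary-precision integers let us compare it against the unreduced exponent, which is not an option in C++.

```python
MOD = 10**9 + 7


def pow_huge_exponent(p, a, k):
    """p^(2^k * a + 1) mod MOD, valid when p is not a multiple of MOD.

    By Fermat's little theorem p^(MOD-1) = 1 (mod MOD), so the exponent
    2^k * a + 1 may be replaced by its residue modulo MOD - 1.
    """
    reduced = (pow(2, k, MOD - 1) * a + 1) % (MOD - 1)
    return pow(p, reduced, MOD)


# Cross-check against Python's exact (huge) exponent.
assert pow_huge_exponent(5, 3, 100000) == pow(5, 2**100000 * 3 + 1, MOD)
```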

TIME COMPLEXITY:

\mathcal{O}(\sqrt{N} + \log{N}\log{MOD}) per test case.

CODE:

Setter's code (C++)
#include <bits/stdc++.h>
#define ll long long int
#define mod 1000000007
using namespace std;

ll binpow(ll a, ll b, ll m){
    a %= m;
    ll res = 1, mult = a;

    while(b){
        if(b & 1ll){
            res = (res * mult) % m;
        }
        mult = (mult * mult) % m;
        b >>= 1;
    }

    return res;
}

int main(){

    #ifndef ONLINE_JUDGE                 
    freopen("input6.txt", "r", stdin);                                           
    freopen("output6.txt", "w", stdout);                        
    #endif 

    ll t;
    cin >> t;

    while(t--){
        ll n, k;
        cin >> n >> k;

        map <ll, ll> primePowers;

        while(n % 2 == 0){
            n /= 2;
            primePowers[2]++;
        }

        for(ll i = 3; i * i <= n; i += 2){
            while(n % i == 0){
                primePowers[i]++;
                n /= i;
            }
        }

        if(n > 1) primePowers[n]++;

        ll ans = 1;
        map <ll, ll> seriesSums;

        for(auto p : primePowers){
            ll pw = (binpow(2, k, mod - 1) * p.second + 1) % (mod - 1);
            ll num = (binpow(p.first, pw, mod) - binpow(p.first, p.second, mod) + mod);
            ll den = binpow(p.first - 1, mod - 2, mod);
            seriesSums[p.first] = (num * den) % mod;
        }

        for(auto p : seriesSums){
            ans = (ans * p.second) % mod;
        }

        cout << ans << endl;
    }

    return 0;


}

// 2
// 4
// 16
// 8


// 10
// 100
// 20
// 50

// 6
// 36
// 36
 
Tester (rivalq)'s code (C++)
// Jai Shree Ram  
  
#include<bits/stdc++.h>
using namespace std;

#define rep(i,a,n)     for(int i=a;i<n;i++)
#define ll             long long
#define int            long long
#define pb             push_back
#define all(v)         v.begin(),v.end()
#define endl           "\n"
#define x              first
#define y              second
#define gcd(a,b)       __gcd(a,b)
#define mem1(a)        memset(a,-1,sizeof(a))
#define mem0(a)        memset(a,0,sizeof(a))
#define sz(a)          (int)a.size()
#define pii            pair<int,int>
#define hell           1000000007
#define elasped_time   1.0 * clock() / CLOCKS_PER_SEC



template<typename T1,typename T2>istream& operator>>(istream& in,pair<T1,T2> &a){in>>a.x>>a.y;return in;}
template<typename T1,typename T2>ostream& operator<<(ostream& out,pair<T1,T2> a){out<<a.x<<" "<<a.y;return out;}
template<typename T,typename T1>T maxs(T &a,T1 b){if(b>a)a=b;return a;}
template<typename T,typename T1>T mins(T &a,T1 b){if(b<a)a=b;return a;}

// -------------------- Input Checker Start --------------------
 
long long readInt(long long l, long long r, char endd)
{
    long long x = 0;
    int cnt = 0, fi = -1;
    bool is_neg = false;
    while(true)
    {
        char g = getchar();
        if(g == '-')
        {
            assert(fi == -1);
            is_neg = true;
            continue;
        }
        if('0' <= g && g <= '9')
        {
            x *= 10;
            x += g - '0';
            if(cnt == 0)
                fi = g - '0';
            cnt++;
            assert(fi != 0 || cnt == 1);
            assert(fi != 0 || is_neg == false);
            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
        }
        else if(g == endd)
        {
            if(is_neg)
                x = -x;
            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(false);
            }
            return x;
        }
        else
        {
            assert(false);
        }
    }
}
 
string readString(int l, int r, char endd)
{
    string ret = "";
    int cnt = 0;
    while(true)
    {
        char g = getchar();
        assert(g != -1);
        if(g == endd)
            break;
        cnt++;
        ret += g;
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}
 
long long readIntSp(long long l, long long r) { return readInt(l, r, ' '); }
long long readIntLn(long long l, long long r) { return readInt(l, r, '\n'); }
string readStringLn(int l, int r) { return readString(l, r, '\n'); }
string readStringSp(int l, int r) { return readString(l, r, ' '); }
void readEOF() { assert(getchar() == EOF); }
 
vector<int> readVectorInt(int n, long long l, long long r)
{
    vector<int> a(n);
    for(int i = 0; i < n - 1; i++)
        a[i] = readIntSp(l, r);
    a[n - 1] = readIntLn(l, r);
    return a;
}
 
// -------------------- Input Checker End --------------------

const int MOD = hell;
 
struct mod_int {
    int val;
 
    mod_int(long long v = 0) {
        if (v < 0)
            v = v % MOD + MOD;
 
        if (v >= MOD)
            v %= MOD;
 
        val = v;
    }
 
    static int mod_inv(int a, int m = MOD) {
        int g = m, r = a, x = 0, y = 1;
 
        while (r != 0) {
            int q = g / r;
            g %= r; swap(g, r);
            x -= q * y; swap(x, y);
        }
 
        return x < 0 ? x + m : x;
    }
 
    explicit operator int() const {
        return val;
    }
 
    mod_int& operator+=(const mod_int &other) {
        val += other.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
 
    mod_int& operator-=(const mod_int &other) {
        val -= other.val;
        if (val < 0) val += MOD;
        return *this;
    }
 
    static unsigned fast_mod(uint64_t x, unsigned m = MOD) {
           #if !defined(_WIN32) || defined(_WIN64)
                return x % m;
           #endif
           unsigned x_high = x >> 32, x_low = (unsigned) x;
           unsigned quot, rem;
           asm("divl %4\n"
            : "=a" (quot), "=d" (rem)
            : "d" (x_high), "a" (x_low), "r" (m));
           return rem;
    }
 
    mod_int& operator*=(const mod_int &other) {
        val = fast_mod((uint64_t) val * other.val);
        return *this;
    }
 
    mod_int& operator/=(const mod_int &other) {
        return *this *= other.inv();
    }
 
    friend mod_int operator+(const mod_int &a, const mod_int &b) { return mod_int(a) += b; }
    friend mod_int operator-(const mod_int &a, const mod_int &b) { return mod_int(a) -= b; }
    friend mod_int operator*(const mod_int &a, const mod_int &b) { return mod_int(a) *= b; }
    friend mod_int operator/(const mod_int &a, const mod_int &b) { return mod_int(a) /= b; }
 
    mod_int& operator++() {
        val = val == MOD - 1 ? 0 : val + 1;
        return *this;
    }
 
    mod_int& operator--() {
        val = val == 0 ? MOD - 1 : val - 1;
        return *this;
    }
 
    mod_int operator++(int32_t) { mod_int before = *this; ++*this; return before; }
    mod_int operator--(int32_t) { mod_int before = *this; --*this; return before; }
 
    mod_int operator-() const {
        return val == 0 ? 0 : MOD - val;
    }
 
    bool operator==(const mod_int &other) const { return val == other.val; }
    bool operator!=(const mod_int &other) const { return val != other.val; }
 
    mod_int inv() const {
        return mod_inv(val);
    }
 
    mod_int pow(long long p) const {
        assert(p >= 0);
        mod_int a = *this, result = 1;
 
        while (p > 0) {
            if (p & 1)
                result *= a;
 
            a *= a;
            p >>= 1;
        }
 
        return result;
    }
 
    friend ostream& operator<<(ostream &stream, const mod_int &m) {
        return stream << m.val;
    }
    friend istream& operator >> (istream &stream, mod_int &m) {
        return stream>>m.val;   
    }
};

#define SIEVE

const int N = 1e7 + 5;

int lp[N+1];
int pr[N];int t=0;

void sieve(){
    for (int i=2; i<N; ++i) {
            if (lp[i] == 0) {
                lp[i] = i;
                pr[t++]=i;
            }
        for (int j=0; j<t && pr[j]<=lp[i] && i*pr[j]<N; ++j)
            lp[i * pr[j]] = pr[j];
    }
}

int expo(int x,int y,int p){
    int a=1;
    x%=p;
    while(y){
        if(y&1)a=(a*x)%p;
        x=(x*x)%p;
        y/=2;
    }
    return a;
}


//(1 + p + p^2 .... p^(n - 1)) = (p^n - 1)/(p - 1)
//p^(2^k) % mod

int solve(){
    int n = readIntSp(1,1e7);
    int k = readIntLn(1,1e5);
    vector<pii> primes;
    mod_int ans = n;
    while(n > 1){
        int t = lp[n];
        int cnt = 0;
        while(t == lp[n]){
            n /= t;
            cnt++;
        }
        primes.push_back({t,cnt});
        int pw = (expo(2,k,hell - 1) - 1)*cnt % (hell - 1);
        pw = (pw + 1)%(hell - 1);
        mod_int val = (mod_int(t).pow(pw) - 1)/(mod_int(t) - 1);
        ans *= val;
    }
    cout << ans << endl;

    return 0;
}
signed main(){
    ios_base::sync_with_stdio(0);cin.tie(0);cout.tie(0);
    //freopen("input.txt", "r", stdin);
    //freopen("output.txt", "w", stdout);
    #ifdef SIEVE
    sieve();
    #endif
    #ifdef NCR
    init();
    #endif
    int t = readIntLn(1,1000);
    while(t--){
        solve();
    }
    return 0;
}
Tester (satyam_343)'s code (C++)
#include <bits/stdc++.h>   
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;
using namespace std;
#define ll long long  
const ll INF_MUL=1e13;
const ll INF_ADD=1e18;  
#define pb push_back               
#define mp make_pair        
#define nline "\n"                         
#define f first                                          
#define s second                                             
#define pll pair<ll,ll> 
#define all(x) x.begin(),x.end()   
#define vl vector<ll>         
#define vvl vector<vector<ll>>    
#define vvvl vector<vector<vector<ll>>>          
#ifndef ONLINE_JUDGE    
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);  
#endif     
void _print(int x){cerr<<x;}
void _print(ll x){cerr<<x;}  
void _print(char x){cerr<<x;} 
void _print(string x){cerr<<x;}     
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count()); 
template<class T,class V> void _print(pair<T,V> p) {cerr<<"{"; _print(p.first);cerr<<","; _print(p.second);cerr<<"}";}
template<class T>void _print(vector<T> v) {cerr<<" [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T>void _print(set<T> v) {cerr<<" [ "; for (T i:v){_print(i); cerr<<" ";}cerr<<"]";}
template<class T>void _print(multiset<T> v) {cerr<< " [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T,class V>void _print(map<T, V> v) {cerr<<" [ "; for(auto i:v) {_print(i);cerr<<" ";} cerr<<"]";} 
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
//--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
const ll MOD=1e9+7;       
const ll MAX=500500;  
ll binpow(ll a,ll b,ll MOD){
    ll ans=1;
    a%=MOD;
    while(b){
        if(b&1)
            ans=(ans*a)%MOD;
        b/=2;
        a=(a*a)%MOD;
    }
    return ans;
}
ll inverse(ll a,ll MOD){
    return binpow(a,MOD-2,MOD);
}
ll gt(ll n,ll freq,ll k){
    debug(mp(n,mp(freq,k)));
    ll pw=(binpow(2,k,MOD-1)*freq)%(MOD-1);
    ll now=(binpow(n,pw+1,MOD)-binpow(n,freq,MOD)+MOD)*inverse(n-1,MOD);
    now%=MOD;
    return now;
}
void solve(){               
    ll n,k; cin>>n>>k;
    ll ans=1;
    for(ll i=2;i*i<=n;i++){
        if(n%i){  
            continue;
        }
        ll freq=0;       
        while((n%i)==0){
            n/=i;
            freq++; 
        }
        ans=(ans*gt(i,freq,k))%MOD;
    }
    if(n!=1){  
        ans=(ans*gt(n,1,k))%MOD;
    }
    cout<<ans<<nline;
    return;
}                 
int main()                                                                             
{                     
    ios_base::sync_with_stdio(false);                           
    cin.tie(NULL);                          
    #ifndef ONLINE_JUDGE                 
    freopen("input.txt", "r", stdin);                                           
    freopen("output.txt", "w", stdout);  
    freopen("error.txt", "w", stderr);                        
    #endif    
    ll test_cases=1;                   
    cin>>test_cases;
    while(test_cases--){   
        solve();     
    }
    cout<<fixed<<setprecision(10);
    cerr<<"Time:"<<1000*((double)clock())/(double)CLOCKS_PER_SEC<<"ms\n"; 
} 
Editorialist's code (Python)
mod = int(10**9 + 7)

def solve(p, a, k):
    # compute p^a + p^{a+1} + ... + p^{2^k a}
    first = pow(p, a, mod)
    ratio = p
    terms = (pow(2, k, mod-1)*a - a + 1)%(mod - 1)
    
    res = (first * (pow(ratio, terms, mod) - 1)) % mod
    res *= pow(ratio - 1, mod-2, mod)
    return res%mod

for _ in range(int(input())):
    ans = 1
    n, k = map(int, input().split())
    for i in range(2, n+1):
        if i*i > n:
            break
        if n%i != 0:
            continue
        ct = 0
        while n%i == 0:
            n //= i
            ct += 1
        ans *= solve(i, ct, k)
        ans %= mod
    if n > 1:
        ans *= solve(n, 1, k)
        ans %= mod
    print(ans)

How can anyone solve this? I think I'll never be able to 🙁


In many tasks involving divisors, prime numbers (and more generally, prime factorization) are usually among the first things you should look at.
This is intuition you gain by doing some number theory: essentially, prime numbers are the building blocks of anything to do with divisors.

Once you look at the prime factorization and write out exactly how an operation affects it, the rest of the solution follows pretty naturally, I think. The only hard part is the final step, where you need the sum of a GP (which is taught in school or can be found online) and knowledge of Fermat’s little theorem (which is one of the first things you’ll learn when you start reading about number theory).
