PWMUL - Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Contest Division 4

Setter: Yahor Dubovik
Tester: Harris Leung
Editorialist: Trung Dang

DIFFICULTY:

3265

PREREQUISITES:

Primitive Roots

PROBLEM:

You are given a prime number P and two integers A, B. Your task is to find the minimum value of (A^T \cdot B) \bmod{P} over all nonnegative integers T.

EXPLANATION:

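Let g be a primitive root of P, and k be the “power cycle” (the multiplicative order) of A, i.e. k is the smallest positive integer such that A^k \bmod P = 1. By Fermat's little theorem A^{P - 1} \bmod P = 1, so k is a divisor of P - 1. We can therefore find k by factorizing P - 1 and, starting from P - 1, dividing out each prime factor for as long as the corresponding power of A is still 1.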
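
The search for k can be sketched as follows. This is a minimal illustration rather than any of the reference solutions: the names mul_mod, pow_mod, prime_factors and multiplicative_order are hypothetical, and unsigned __int128 is assumed to be available (a GCC/Clang extension).

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using u64 = std::uint64_t;
using u128 = unsigned __int128;  // assumption: GCC/Clang extension

// (a * b) mod m, computed through a 128-bit intermediate to avoid overflow.
u64 mul_mod(u64 a, u64 b, u64 m) { return (u64)((u128)a * b % m); }

// base^exp mod m by binary exponentiation.
u64 pow_mod(u64 base, u64 exp, u64 m) {
    u64 result = 1 % m;
    base %= m;
    while (exp > 0) {
        if (exp & 1) result = mul_mod(result, base, m);
        base = mul_mod(base, base, m);
        exp >>= 1;
    }
    return result;
}

// Distinct prime factors of n by trial division, O(sqrt(n)).
std::vector<u64> prime_factors(u64 n) {
    std::vector<u64> primes;
    for (u64 d = 2; d * d <= n; ++d) {
        if (n % d == 0) {
            primes.push_back(d);
            while (n % d == 0) n /= d;
        }
    }
    if (n > 1) primes.push_back(n);
    return primes;
}

// Smallest k > 0 with a^k = 1 (mod p), for a prime p and a not divisible by p.
u64 multiplicative_order(u64 a, u64 p) {
    assert(a % p != 0);
    u64 k = p - 1;  // a^(p-1) = 1 by Fermat, so the order divides p - 1
    for (u64 q : prime_factors(p - 1)) {
        // Drop a factor q while the smaller exponent still gives 1.
        while (k % q == 0 && pow_mod(a, k / q, p) == 1) k /= q;
    }
    return k;
}
```

For instance, multiplicative_order(2, 7) is 3, since the powers of 2 modulo 7 cycle through 2, 4, 1.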

When k is small (around \sqrt{P - 1}; the solutions below use a fixed threshold of 2 \cdot 10^7 or 5 \cdot 10^7), we can directly find the answer: find all possible residues of A^t \bmod P (that is A^0 \bmod P, A^1 \bmod P, \dots, A^{k - 1} \bmod P), multiply each by B modulo P and take the smallest value as the answer. What if k is large? We have the following two claims:

  • The answer itself is small. Intuitively, since A^t \cdot B \bmod P takes k distinct values and these values are spread roughly uniformly over [1, P - 1], the smallest of them should be around \frac{P - 1}{k}, which is at most about \sqrt{P - 1} when k \ge \sqrt{P - 1}, although we don’t have a definitive proof for this. This leads to the idea of iterating over the answer C and checking if C \equiv A^t \cdot B \pmod{P} for some integer t.
  • For any integer C not divisible by P, there exists an integer t such that A^t \cdot B \equiv C \pmod{P} if and only if B^k \equiv C^k \pmod{P}, where k is the power cycle of A. Necessity is simple: C^k \equiv (A^t \cdot B)^k \equiv A^{kt} \cdot B^k \equiv B^k \pmod{P}. For sufficiency, let D = C \cdot B^{-1} \bmod P; we need to prove that if D^k \bmod P = 1 then A^t \bmod P = D for some t. Let a' and d' be non-negative integers such that g^{a'} \bmod P = A and g^{d'} \bmod P = D, and let m = \frac{P - 1}{k}. Since k is the power cycle of A, we have \gcd(a', P - 1) = m, which means there exists an integer x such that a' \cdot x \equiv m \pmod{P - 1}. Furthermore, since D^k \bmod P = 1, we have (P - 1) \mid k \cdot d', i.e. m \mid d', so let y = \frac{d'}{m}. Then a' \cdot (x \cdot y) \equiv m \cdot y \equiv d' \pmod{P - 1}, or A^{xy} \equiv D \pmod{P}.
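
Both ideas can be exercised on a small prime: direct enumeration of A^t \cdot B \bmod P (the small-k case) and the criterion B^k \equiv C^k \pmod{P} from the second claim. The sketch below is only an illustration; all function names are hypothetical, k is assumed to be the power cycle of A, and unsigned __int128 is assumed to be available (a GCC/Clang extension).

```cpp
#include <algorithm>
#include <cstdint>

using u64 = std::uint64_t;
using u128 = unsigned __int128;  // assumption: GCC/Clang extension

u64 mul_mod(u64 a, u64 b, u64 m) { return (u64)((u128)a * b % m); }

u64 pow_mod(u64 base, u64 exp, u64 m) {
    u64 result = 1 % m;
    base %= m;
    while (exp > 0) {
        if (exp & 1) result = mul_mod(result, base, m);
        base = mul_mod(base, base, m);
        exp >>= 1;
    }
    return result;
}

// Small-k case: minimum of a^t * b mod p over t = 0..k-1, in O(k).
u64 min_by_enumeration(u64 a, u64 b, u64 p, u64 k) {
    u64 cur = b % p;
    u64 best = cur;
    for (u64 t = 1; t < k; ++t) {
        cur = mul_mod(cur, a, p);
        best = std::min(best, cur);
    }
    return best;
}

// Second claim: c is of the form a^t * b (mod p) iff c^k = b^k (mod p).
bool reachable_by_criterion(u64 c, u64 b, u64 k, u64 p) {
    return pow_mod(c, k, p) == pow_mod(b, k, p);
}

// Reference check: walk through a^t * b mod p for t = 0..k-1.
bool reachable_by_enumeration(u64 c, u64 a, u64 b, u64 k, u64 p) {
    u64 cur = b % p;
    for (u64 t = 0; t < k; ++t) {
        if (cur == c % p) return true;
        cur = mul_mod(cur, a, p);
    }
    return false;
}

// Exhaustively compares the criterion with enumeration for all b, c in [1, p-1].
bool criterion_matches_enumeration(u64 a, u64 k, u64 p) {
    for (u64 b = 1; b < p; ++b)
        for (u64 c = 1; c < p; ++c)
            if (reachable_by_criterion(c, b, k, p) !=
                reachable_by_enumeration(c, a, b, k, p))
                return false;
    return true;
}
```

With P = 13, A = 3 (so k = 3) and B = 2, the reachable values are 2, 6 and 5, and indeed 5^3 \equiv 2^3 \equiv 8 \pmod{13} while 4^3 \equiv 12 \pmod{13}.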

Therefore, our algorithm simply loops over the candidate answer C from 1 upwards, checking whether C^{k} \equiv B^k \pmod{P}. Each check takes O(\log{k}) with fast exponentiation, which is not fast enough when paired with our \sqrt{P - 1} bound for the answer. To speed this up, note that C \mapsto C^k \bmod P is completely multiplicative: we run a linear sieve up to the maximum possible answer, calculate X^k \bmod P by fast exponentiation only for primes X, and obtain the value for every composite number as the product of two already-computed values. This brings the complexity of this part down to O(\sqrt{P} + \log{k} \cdot \frac{\sqrt{P}}{\log{\sqrt{P}}}) \approx O(\sqrt{P}), where \frac{\sqrt{P}}{\log{\sqrt{P}}} comes from the distribution of primes under \sqrt{P - 1}.
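
A minimal sketch of this sieve is given below. It is an illustration under stated assumptions, not a copy of the reference solutions: the function name smallest_with_matching_power and the explicit limit parameter are hypothetical, the function returns 0 when no candidate up to limit matches, and unsigned __int128 is assumed to be available (a GCC/Clang extension).

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

using u64 = std::uint64_t;
using u128 = unsigned __int128;  // assumption: GCC/Clang extension

u64 mul_mod(u64 a, u64 b, u64 m) { return (u64)((u128)a * b % m); }

u64 pow_mod(u64 base, u64 exp, u64 m) {
    u64 result = 1 % m;
    base %= m;
    while (exp > 0) {
        if (exp & 1) result = mul_mod(result, base, m);
        base = mul_mod(base, base, m);
        exp >>= 1;
    }
    return result;
}

// Smallest c in [1, limit] with c^k = b^k (mod p), or 0 if there is none.
// Fast exponentiation is used only for primes; composites reuse the
// values of their smallest prime factor and the remaining cofactor.
u64 smallest_with_matching_power(u64 b, u64 k, u64 p, std::size_t limit) {
    const u64 target = pow_mod(b, k, p);
    std::vector<u64> kth_power(limit + 1, 0);
    std::vector<std::uint32_t> spf(limit + 1, 0);  // smallest prime factor
    std::vector<std::uint32_t> primes;
    for (std::size_t c = 1; c <= limit; ++c) {
        if (c == 1) {
            kth_power[c] = 1 % p;
        } else if (spf[c] == 0) {  // c is prime
            spf[c] = static_cast<std::uint32_t>(c);
            primes.push_back(spf[c]);
            kth_power[c] = pow_mod(c, k, p);
        } else {  // c is composite: (x * y)^k = x^k * y^k
            kth_power[c] = mul_mod(kth_power[spf[c]], kth_power[c / spf[c]], p);
        }
        if (kth_power[c] == target) return c;
        if (c == 1) continue;
        // Linear sieve step: mark q * c for every prime q <= spf[c].
        for (std::uint32_t q : primes) {
            if (q > spf[c] || (u64)q * c > limit) break;
            spf[q * c] = q;
        }
    }
    return 0;
}
```

With P = 13, k = 3 and B = 11, the reachable values are 11, 7 and 8, so smallest_with_matching_power(11, 3, 13, 12) returns 7. In a full solution the limit would have to be chosen large enough to cover the expected answer; the reference solutions below either use a fixed bound or keep doubling it.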

TIME COMPLEXITY:

Time complexity is O(\sqrt{P}) per test case.

SOLUTION:

Setter's Solution
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
ll a,b,p;
ll mod;
ll mult(ll a, ll b) {
    ll q = (ll) ((long double) a * (long double) b / (long double) mod);
    long long r = a * b - q * mod;
    while (r < 0) r += mod;
    while (r >= mod) r -= mod;
    return r;
}
ll pw(ll a, ll b) {
    if (b == 0) return 1;
    if (b & 1) return mult(a, pw(a, b - 1));
    ll res = pw(a, b / 2);
    return mult(res, res);
}
int main(){
    ios::sync_with_stdio(false);cin.tie(0);
    //  in(p);in(a);in(b);
//    freopen("input.txt", "r", stdin);
    cin >> p >> a >> b;
    mod = p;
    ll cyc=p-1;
    ll z=p-1;
    for(ll i=2; i*i<=p-1 ;i++){
        if(z%i==0){
            while(z%i==0){
                z/=i;
                if(pw(a,cyc/i)==1) cyc/=i;
            }
        }
    }
    if(z!=1){
        ll i=z;
        if(z%i==0){
            while(z%i==0){
                z/=i;
                if(pw(a,cyc/i)==1) cyc/=i;
            }
        }
    }
    if(cyc<=5e7){
        ll res=b;
        ll ans=b;
        for(int i=0; i<cyc ;i++){
            res=mult(res, a);
            ans=min(ans,res);
        }
        cout << ans << '\n';
    }
    else{
        ll mg=pw(b,cyc);
        if (1 == mg) {
            cout << 1;
            return 0;
        }
        vector<ll> lp(2);
        vector<ll> pws(2);
        //p^(alpha * c)
        pws[1] = 1;
        for (ll sz = 1; ; sz *= 2) {
            pws.resize(2 * sz + 1);
            lp.resize(2 * sz + 1);
            //j -> 2 * t
            for (ll start = 2; start <= 2 * sz; start++) {
                if (lp[start] == 0 || lp[start] == start) {
                    lp[start] = start;
                    pws[start] = pw(start, cyc);
                    for (ll j = 2 * start; j <= 2 * sz; j += start) {
                        if (lp[j] == 0) {
                            lp[j] = start;
                            pws[j] = mult(pws[start], pws[j / start]);
                            assert(lp[j / start] != 0);
                        }
                    }
                }
            }
            for (ll x = sz + 1; x <= 2 * sz; x++) {
                if (pws[x] == mg) {
                    cout << x << '\n';
                    return 0;
                }
            }
        }
    }
}
Tester's Solution
#include<bits/stdc++.h>
using namespace std;
typedef __int128 ll;
ll a,b,p;
void in(ll& x){
    long long y;cin >> y;
    x=y;
}
void out(ll& x){
    long long y=x;
    cout << y;
}
ll pw(ll x,ll y){
    if(y==0) return 1;
    if(y%2) return x*pw(x,y-1)%p;
    ll res=pw(x,y/2);
    return res*res%p;
}
ll gay[2000001];
int sp[2000001];
int pc=0;
int ps[1000001];
int main(){
    ios::sync_with_stdio(false);cin.tie(0);
    in(p);in(a);in(b);
    ll cyc=p-1;
    ll z=p-1;
    for(ll i=2; i*i<=p-1 ;i++){
        if(z%i==0){
            while(z%i==0){
                z/=i;
                if(pw(a,cyc/i)==1) cyc/=i;
            }
        }
    }
    if(z!=1){
        ll i=z; // the remaining prime factor may not fit in int
        if(z%i==0){
            while(z%i==0){
                z/=i;
                if(pw(a,cyc/i)==1) cyc/=i;
            }
        }
    }
    if(cyc<=2e7){
        ll res=b;
        ll ans=b;
        for(int i=0; i<cyc ;i++){
            res=res*a%p;
            ans=min(ans,res);
        }
        out(ans);
    }
    else{
        ll mg=pw(b,cyc);
        ll iu=2e6;
        for(ll i=1; ;i++){
        	if(i==1){
        		gay[i]=1;
        		if(gay[i]==mg) return out(i),0;
        		continue;
			}
			if(sp[i]==0){
				sp[i]=i;
				gay[i]=pw(i,cyc);
        		if(gay[i]==mg) return out(i),0;
        		ps[++pc]=i;
			}
			else{
				if(gay[i]==mg) return out(i),0;
			}
			for(ll j=1; j<=pc && ps[j]<=sp[i] && ps[j]*i<=iu ;j++){
				sp[ps[j]*i]=ps[j];
				gay[ps[j]*i]=gay[ps[j]]*gay[i]%p;
			}
        }
    }
}
Editorialist's Solution
#include <bits/stdc++.h>
using namespace std;

using i128 = __int128_t;

long long m;

template<typename T>
T fast_pow(T a, long long p) {
    T ans = 1;
    for (; p > 0; p /= 2, (a *= a) %= m) {
        if (p & 1) {
            (ans *= a) %= m;
        }
    }
    return ans;
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    long long _a, _b; cin >> m >> _a >> _b;
    i128 a = _a, b = _b;
    long long cyc = m - 1, tmp = m - 1;
    for (int i = 2; 1LL * i * i <= tmp; i++) {
        while (tmp % i == 0) {
            tmp /= i;
        }
        while (cyc % i == 0 && fast_pow(a, cyc / i) == 1) {
            cyc /= i;
        }
    }
    if (tmp > 1) {
        while (cyc % tmp == 0 && fast_pow(a, cyc / tmp) == 1) {
            cyc /= tmp;
        }
    }
    i128 ans = m;
    if (cyc <= 2E7) {
        for (i128 cur = 1; cyc > 0; cyc--, (cur *= a) %= m) {
            ans = min(ans, cur * b % m);
        }
    } else {
        i128 tar = fast_pow(b, cyc);
        const int lim = 2E6;
        vector<i128> pre_pw(lim);
        vector<int> lp(lim), pr;
        for (int i = 1; i < lim ; i++) {
            if (i == 1) {
                pre_pw[i] = 1;
            } else {
                if (lp[i] == 0) {
                    lp[i] = i;
                    pre_pw[i] = fast_pow<i128>(i, cyc); // force T = i128; T = int would overflow
                    pr.push_back(i);
                }
                for (int j = 0; j < pr.size() && pr[j] <= lp[i] && i * pr[j] < lim; j++) {
                    lp[i * pr[j]] = pr[j];
                }
                pre_pw[i] = pre_pw[i / lp[i]] * pre_pw[lp[i]] % m;
            }
            if (pre_pw[i] == tar) {
                ans = i;
                break;
            }
        }
    }
    cout << (long long)ans;
}

@kuroni , problem link is broken?

Ah sorry, thanks for catching that :slight_smile:

This shouldn’t be the case when one is writing an official editorial for an official-rated contest.
Everything should be conclusive and provable in it.

Does anyone have a proof for that or some different solution?

To be fair, I think a proof of the claim does exist, but it would be far too lengthy to include here. You can try searching for results along the lines of “distribution of exponentials modulo a prime” or “distribution of powers of primitive roots modulo a prime”.

Thanks. Will do.