Algorithm to find inverse modulo m

I have two questions: how do I find the inverse of a number modulo m, and how do I compute n! modulo m?


I hope I understood the question correctly; I asked the same thing (inverse modulo) in the OLYMPIC tutorial.

[image not preserved]

Note: P is a prime number.

n! modulo MOD is built up from (n-1)! modulo MOD, reducing after each multiplication:

n! % MOD = ( (n % MOD) * ((n-1)! % MOD) ) % MOD
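The recurrence above can be sketched as a loop. This is only a minimal illustration: the name factorialMod is made up for this example, and it assumes the modulus is positive and below 2^31 so that each product fits in a 64-bit integer.

```cpp
#include <cassert>

// Hypothetical helper: computes n! modulo mod, reducing after every
// multiplication so the running product stays below mod * mod.
// Assumption: 0 < mod < 2^31, so no product overflows long long.
long long factorialMod(long long n, long long mod) {
    assert(mod > 0 && n >= 0);
    long long result = 1 % mod;          // 0! = 1 (and 0 when mod == 1)
    for (long long i = 2; i <= n; ++i) {
        result = result * (i % mod) % mod;
        if (result == 0) break;          // once the product hits 0 it stays 0
    }
    return result;
}
```

For example, factorialMod(5, 7) reduces 120 modulo 7. Note that the result is 0 whenever n is at least as large as a prime modulus.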


An inverse of an element a modulo m exists if and only if
gcd(a,m) = 1, i.e. a and m are relatively prime.

To find the inverse, use the extended Euclidean algorithm (see the Wikipedia article "Extended Euclidean algorithm").

It finds a solution (x, y) to the following equation:

ax + by = gcd(a,b)

Taking b = m, the equation becomes:

ax + my = gcd(a,m)

since gcd(a,m) = 1

ax + my = 1

Reducing both sides modulo m:

ax (mod m) + my (mod m) = 1 (mod m)

Since my is a multiple of m, the second term is 0 modulo m:

ax (mod m) = 1 (mod m)

=> x is the inverse of a modulo m
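As a small worked example of the argument above (the values a = 3 and m = 7 are chosen here purely for illustration):

```latex
\begin{align*}
7 &= 2 \cdot 3 + 1 \\
1 &= 7 - 2 \cdot 3 \\
3 \cdot (-2) + 7 \cdot 1 &= 1 \\
x = -2 &\equiv 5 \pmod{7}, \qquad 3 \cdot 5 = 15 \equiv 1 \pmod{7}
\end{align*}
```

So the inverse of 3 modulo 7 is 5; the negative x is shifted into the range 0..m-1 by adding m.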

Given below is a recursive implementation of the extended euclidean algorithm:

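// Finds x and y such that a*x + b*y = gcd(a, b). Requires b != 0.
void EE(int a, int b, int& x, int& y)
{
    if(a%b == 0)
    {
        // b divides a, so gcd(a, b) = b = a*0 + b*1
        x=0;
        y=1;
        return;
    }
    // Solve b*x' + (a%b)*y' = gcd, then substitute a%b = a - (a/b)*b
    EE(b,a%b,x,y);
    int temp = x;
    x = y;
    y = temp - y*(a/b);
}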

Using this function and the explanation above, the inverse function can be implemented as follows:

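// Returns the inverse of a modulo m. Requires gcd(a, m) = 1.
int inverse(int a, int m)
{
    int x,y;
    EE(a,m,x,y);
    // x may be negative; shift it into the range 0..m-1
    if(x<0) x += m;
    return x;
}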
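The function above assumes that gcd(a, m) = 1. Here is a sketch of a variant that checks this condition first. It swaps in an iterative form of the extended Euclidean algorithm so that the gcd is returned directly. The names extendedGcd and checkedInverse, and the convention of returning -1 when no inverse exists, are choices made for this example only.

```cpp
#include <cassert>

// Iterative extended Euclid (illustrative): returns gcd(a, b) and sets
// x, y so that a*x + b*y == gcd(a, b). Expects a >= 0 and b >= 0.
long long extendedGcd(long long a, long long b, long long& x, long long& y) {
    x = 1; y = 0;
    long long x1 = 0, y1 = 1;
    while (b != 0) {
        long long q = a / b;
        long long t = x - q * x1; x = x1; x1 = t;   // update coefficient of a
        t = y - q * y1; y = y1; y1 = t;             // update coefficient of b
        t = a - q * b; a = b; b = t;                // Euclid step
    }
    return a;
}

// Returns the inverse of a modulo m in [0, m), or -1 if gcd(a, m) != 1.
long long checkedInverse(long long a, long long m) {
    assert(m > 0);
    a %= m;
    if (a < 0) a += m;                 // normalise a into [0, m)
    long long x, y;
    if (extendedGcd(a, m, x, y) != 1) return -1;    // no inverse exists
    x %= m;
    if (x < 0) x += m;                 // shift x into [0, m)
    return x;
}
```

With this sketch, checkedInverse(3, 7) yields 5, while checkedInverse(4, 8) reports that no inverse exists because gcd(4, 8) = 4.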

If we have to calculate the modular inverse of every number from 1 to n with respect to a prime modulus m,
we can use a sieve to find a factor of each composite number up to n. For a composite i, inverse(i) = (inverse(i/factor(i)) * inverse(factor(i))) % m, and for prime i we can use either the extended Euclidean algorithm or Fermat's theorem. But we can still do better.

With m / a denoting integer division, for 1 < a < m we have:

a * (m / a) + m % a = m

(a * (m / a) + m % a) mod m = m mod m, or

(a * (m / a) + m % a) mod m = 0, or

(- (m % a)) mod m = (a * (m / a)) mod m

Dividing both sides by (a * (m % a)), which is invertible because m is prime and both a and m % a lie strictly between 0 and m, we get

- inverse(a) mod m = ((m/a) * inverse(m % a)) mod m

inverse(a) mod m = (- (m/a) * inverse(m % a)) mod m

Here’s a sample C++ code:

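#include <vector>
using namespace std;

// Returns modInverse[i] = inverse of i modulo the prime m, for 1 <= i <= n.
// Requires 1 <= n < m.
vector<int> inverseArray(int n, int m) {
    vector<int> modInverse(n + 1, 0);
    modInverse[1] = 1;
    for(int i = 2; i <= n; i++) {
        // 64-bit product avoids int overflow; the result lies in 1..m-1
        modInverse[i] = (m - (long long)(m / i) * modInverse[m % i] % m) % m;
    }
    return modInverse;
}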

The time complexity of the above code is O(n).
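For comparison, the Fermat's theorem route mentioned earlier for a single inverse modulo a prime can be sketched with binary exponentiation. The names powMod and fermatInverse are made up for this example, and the sketch assumes the modulus is prime and below 2^31. Each call costs O(log m), so building n inverses this way would take O(n log m), which is why the linear recurrence is preferable for a whole table.

```cpp
#include <cassert>

// Computes base^exp modulo mod by repeated squaring.
// Assumption: 0 < mod < 2^31, so every product fits in long long.
long long powMod(long long base, long long exp, long long mod) {
    assert(mod > 0 && exp >= 0);
    long long result = 1 % mod;
    base %= mod;
    if (base < 0) base += mod;         // normalise base into [0, mod)
    while (exp > 0) {
        if (exp & 1) result = result * base % mod;  // use this bit of exp
        base = base * base % mod;                   // square for the next bit
        exp >>= 1;
    }
    return result;
}

// Inverse of a modulo a prime p via Fermat's little theorem: a^(p-2) mod p.
// Assumption: p is prime and a is not a multiple of p.
long long fermatInverse(long long a, long long p) {
    return powMod(a, p - 2, p);
}
```

For instance, fermatInverse(3, 7) gives 5, matching the value the recurrence produces for i = 3 and m = 7.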


Suppose we need to calculate nCr modulo P in a case where n > P. How do we handle such cases?