Guide to modular arithmetic (plus tricks) [CodeChef edition] [There is no other edition]

Oh no, it’s another one of those problems where you have to calculate the answer modulo some huge, annoying number like 10^9 + 7. But it’s okay! Modular arithmetic is really quite simple once you get to know it (and, of course, practice it).

Before I begin, some notation/clarifications:

  • \% means modulo, and a \% m is equivalent to a \mod m.
  • m is the modulus (such as 10^9 + 7), and for this post, m is always assumed to be prime. This will be important later.
  • The given code is in C++, but most other languages should have similar syntax. Be careful with datatypes like int: even though I’ll use them here, it’s highly recommended to use a larger datatype like long long to avoid overflow. However, if you know what you’re doing, you can speed up operations by working with int (see the note under “Faster multiplication”).
  • Modulo is a much heavier operation than addition, subtraction, or even multiplication, so try to avoid it when it’s not necessary. This is what the “Faster addition” and “Faster subtraction” sections are for.
  • I generally use a as the number on the left side of the operation, and b as the number on the right (as in a + b).

Now let’s go!

Addition

Modular addition is quite simple. All you have to do is add the two numbers normally, then take their sum mod m. The cool thing about addition (and also subtraction and multiplication, though not division, as we’ll see) is that it doesn’t matter how often you take the modulo along the way. That means ((a\%m)+(b\%m))\%m = (a + b)\%m.

Code
long long x = (a + b) % m;
Faster addition

If the two numbers you’re adding, a and b, are both between 0 and m - 1, then a + b is guaranteed to be less than 2m. So if the sum is less than m, you don’t have to do anything, and if it’s at least m, you only have to subtract m once. So, as long as a and b are already reduced, the following code is equivalent to the previous one:

long long x = a + b;
if (x >= m) x -= m;
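The fast version is easy to wrap in a small helper. This is a minimal sketch, assuming m = 10^9 + 7 as a global constant and that both inputs are already reduced to the range 0 to m - 1; the name add_mod is just for illustration:

```cpp
const long long m = 1000000007;  // assumed prime modulus, as in the rest of the post

// Fast modular addition: only valid when 0 <= a, b < m.
long long add_mod(long long a, long long b) {
    long long x = a + b;  // x < 2m, so a single subtraction is enough
    if (x >= m) x -= m;
    return x;
}
```

Inputs that might be larger than m, or negative, should go through the plain (a + b) % m version instead.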

Subtraction

Modular subtraction is very similar to addition. The only thing you have to be careful about is that in some languages (like C++ and Java), the % operator keeps the sign of the left operand, so (a - b) \% m is negative whenever a - b is negative (and not a multiple of m), when you actually want a result between 0 and m - 1. However, that negative result is always at least -(m - 1), so you can fix it by adding m once if necessary. It’s best to wrap this in a function, since it gets tedious to write every time.

Code
long long x = (a - b) % m;
if (x < 0) x += m;
Faster subtraction

If the two numbers you’re subtracting, a and b, are both between 0 and m - 1, then a - b is guaranteed to be between -(m - 1) and m - 1. So if the difference is less than 0, you only have to add m once; otherwise, you don’t have to do anything.

long long x = a - b;
if (x < 0) x += m;
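Following the advice to wrap subtraction in a function, here is a minimal sketch of both variants. The names sub_mod and sub_mod_fast are hypothetical, and m = 10^9 + 7 is assumed:

```cpp
const long long m = 1000000007;  // assumed prime modulus

// General modular subtraction: works for any a, b (even negative or >= m).
long long sub_mod(long long a, long long b) {
    long long x = (a - b) % m;  // in C++ this can be negative...
    if (x < 0) x += m;          // ...but never below -(m - 1), so one addition fixes it
    return x;
}

// Fast version without %: only valid when 0 <= a, b < m.
long long sub_mod_fast(long long a, long long b) {
    long long x = a - b;  // -(m - 1) <= x <= m - 1
    if (x < 0) x += m;
    return x;
}
```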

Multiplication

Modular multiplication works the same way: you just multiply the two numbers together, then take the product mod m. With multiplication, you have to be very careful with overflow: make sure the two numbers you multiply are already less than m and stored in long long, otherwise the product might overflow before you can take it mod m.

Code
long long x = (a * b) % m;
Faster multiplication

Unfortunately, unlike addition and subtraction, there’s no simple trick like this for multiplication that avoids the modulo operation. This is because a \cdot b is only guaranteed to be less than m^2, which is not very useful.

Note by @anon49376339 (mentioned in this comment): there are other ways to speed up operations, though, at least on 32-bit compilers (I think this applies to CodeChef and most online judges). Doing arithmetic purely in int datatypes will be considerably faster, and you can make multiplication work by forcing a cast to long long like so:

Code
int a, b;
// stuff
int x = (1LL * a * b) % m; /* 1LL forces the product to be computed in long long before it can overflow */
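As a self-contained sketch of that note, assuming m = 10^9 + 7 (which fits in a signed 32-bit int) and a hypothetical helper name mul_mod, the multiplication can live in a function that takes and returns int:

```cpp
const int m = 1000000007;  // fits in int, since it is below 2^31 - 1

// Multiply two residues stored as int. 1LL promotes the product to long long
// before it can overflow; after % m the result is < m, so it fits back in int.
int mul_mod(int a, int b) {
    return static_cast<int>((1LL * a * b) % m);
}
```

Without the 1LL, a * b would be computed in int and overflow for values near m.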

Division

Heh. Division’s not so simple. Read on, I’ll get back to it once it makes sense to.

(Fast) exponentiation

This is where it gets fun. The goal of this is to compute a to the power of b mod m, that is, a^b\%m, efficiently (in better than O(b) time). Clearly the naive algorithm would just be to multiply some x = 1 by a, b times. Let’s do better.

This may seem to come out of nowhere, but… trust me. Let’s write b in binary, so if b = 13, we’d write 1101. Then, start with some x = 1, which will end up holding our final answer. Instead of directly computing a^{13}, we’ll split the exponent into the powers of 2 that sum to b in binary: 13 = 1 + 4 + 8, so we’ll compute it as a^1 \cdot a^4 \cdot a^8. Just one thing left: how do we get a^4 and a^8? We can compute a raised to each power of 2 efficiently by using the fact that a^{2k} = a^k \cdot a^k. That means we get a^2 as a \cdot a, a^4 as a^2 \cdot a^2, and so on. The rest is just implementation:

Code

How does this code work? Instead of explicitly checking if the i-th bit is set to multiply by a^{2^i}, we just divide b by 2 and keep checking the smallest bit. In addition, we compute a^{2^i} iteratively as we go.

long long mpow(long long a, long long b) { /* pow(a, b) % m, assuming a >= 0 */
  a %= m; /* reduce a first so that a * a cannot overflow */
  long long x = 1;
  while (b > 0) {
    if (b & 1) { /* fast way of checking if b is odd */
      x = (x * a) % m;
    }
    a = (a * a) % m; /* compute the next power of 2 */
    b >>= 1; /* this is equivalent to b = floor(b / 2) */
  }
  return x;
}

Since the later sections reuse this code, it’s best to keep it wrapped in a function like mpow.

This is O(\log{b}).
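The same idea can also be written recursively, straight from a^{2k} = a^k \cdot a^k (and a^{2k+1} = a^k \cdot a^k \cdot a for odd exponents). This is an alternative sketch, not the iterative mpow above; it assumes a \ge 0 and m = 10^9 + 7, and mpow_rec is a made-up name:

```cpp
const long long m = 1000000007;  // assumed prime modulus

// Recursive fast exponentiation: a^b % m for a >= 0, b >= 0.
long long mpow_rec(long long a, long long b) {
    if (b == 0) return 1;                 // a^0 = 1
    long long half = mpow_rec(a, b / 2);  // a^(floor(b / 2))
    long long x = (half * half) % m;      // a^(2 * floor(b / 2))
    if (b & 1) x = (x * (a % m)) % m;     // odd b: one extra factor of a
    return x;
}
```

It is also O(\log{b}), with a recursion depth of about \log_2{b}; the iterative version simply avoids the recursion.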

Modular multiplicative inverse (a.k.a. division)

Okay, now let’s get back to division. In modulo land, division is slightly different from what you’d expect. Instead of directly computing \dfrac{a}{b}, we’ll compute \dfrac{1}{b} and obtain \dfrac{a}{b} as a \cdot \dfrac{1}{b}. How do we even compute \dfrac{1}{b} (which I’ll call b^{-1} from now on)? Well, b^{-1} is the x such that (x \cdot b)\%m = 1. Why does that make sense? Because that’s exactly how it’s defined in normal arithmetic, except we can’t have fractions or decimals here; we need integers. For any prime m, this inverse exists for every b that is not divisible by m, and it is unique among the values 0 to m - 1.

Well, we’ll want to do this faster than the O(m) algorithm of trying all possible x. Let me pull out some random theorem: for a prime m and any b not divisible by m, (b^{m - 1})\%m = 1.

Where the hell did that come from?

This is Fermat’s little theorem. A good proof of it is here; I’ll summarize it as well.

Proof

Let’s write out the sequence b, 2b, 3b, 4b, ..., (m - 1)b. Now we’ll take all of these elements mod m.

Something interesting (otherwise known as a “lemma”): they’re all distinct. If they weren’t, we would have some i \ne j with i \cdot b \equiv j \cdot b \mod m, meaning m divides (i - j) \cdot b. Since m is prime and doesn’t divide b, m must divide i - j, i.e. i\%m = j\%m. But that can’t be true, because i and j are two different numbers from 1 to m - 1.

Another interesting thing: none of them are 0. If an element i \cdot b reduces to 0 mod m, then, since m is prime, either i or b must be divisible by m, which is impossible: 0 < i < m, and b is not divisible by m.

So these numbers actually form the sequence of numbers from 1 to m - 1, but may be shuffled in some order. This means if we multiply all of the numbers of each sequence together, they’ll multiply to the same thing. So:

b \cdot 2b \cdot 3b \cdot 4b \cdot ... \cdot (m - 1)b = 1 \cdot 2 \cdot 3 \cdot 4 \cdot ... \cdot (m - 1) [all mod m]

Now on the left side, we’ll group all the b's together into b^{m - 1} (since there are exactly m - 1 b's):

b^{m - 1} \cdot 2 \cdot 3 \cdot 4 \cdot ... \cdot (m - 1) = 1 \cdot 2 \cdot 3 \cdot 4 \cdot ... \cdot (m - 1) [all mod m]

Notice that on both sides, we actually just have m - 1 factorial, or (m - 1)!:

b^{m - 1} \cdot (m - 1)! = (m - 1)! [all mod m]

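Since m is prime, (m - 1)! is a product of numbers that are not divisible by m, so it isn’t divisible by m either and therefore has an inverse. Multiply both sides by that inverse to cancel out the factorials, and we have: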

b^{m - 1} = 1 [all mod m]

and the proof is complete.
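If you’d like to convince yourself of the lemma and the theorem numerically, here is a small brute-force check for small moduli. It assumes a hypothetical pow_mod that takes the modulus as a parameter (otherwise the same loop as mpow); the other function names are also made up for illustration. For composite moduli the checks can fail (the lemma for b = 2, m = 6; the theorem for m = 12), which is why m must be prime:

```cpp
#include <vector>

// Same square-and-multiply loop as mpow, but with the modulus as a parameter.
long long pow_mod(long long a, long long b, long long mod) {
    long long x = 1;
    a %= mod;
    while (b > 0) {
        if (b & 1) x = (x * a) % mod;
        a = (a * a) % mod;
        b >>= 1;
    }
    return x;
}

// Lemma: b, 2b, ..., (mod - 1)b, taken mod `mod`, are all nonzero and distinct.
bool residues_form_permutation(long long b, long long mod) {
    std::vector<bool> seen(mod, false);
    for (long long i = 1; i < mod; i++) {
        long long r = (i * b) % mod;
        if (r == 0 || seen[r]) return false;  // a zero or a repeat breaks the lemma
        seen[r] = true;
    }
    return true;
}

// Theorem: b^(mod - 1) % mod == 1 for every b in 1..mod - 1.
bool fermat_holds_for_all(long long mod) {
    for (long long b = 1; b < mod; b++) {
        if (pow_mod(b, mod - 1, mod) != 1) return false;
    }
    return true;
}
```

For example, with m = 7 and b = 3, the residues are 3, 6, 2, 5, 1, 4: exactly the numbers 1 to 6, shuffled.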

Cool, why is that useful? Dividing both sides by b gives (b^{m - 2})\%m = b^{-1}. Wait, it’s that simple? We just need b^{m - 2}? Yep! And because you now know modular exponentiation, this is a piece of cake! Note: this also means computing a modular inverse takes O(\log{m}) time.

Code

Refer to the code for modular exponentiation to know what mpow is.

/* divide a by b, modulo-style */
long long b_inverse = mpow(b, m - 2);
long long x = (a * b_inverse) % m;
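Putting the pieces together, a minimal self-contained sketch of division might look like this. It repeats mpow (with an extra a %= m so that large inputs can’t overflow) and assumes m = 10^9 + 7; mod_inverse and mod_div are illustrative names, a is assumed non-negative, and b must not be divisible by m:

```cpp
const long long m = 1000000007;  // assumed prime modulus

long long mpow(long long a, long long b) {  // same loop as in the exponentiation section
    long long x = 1;
    a %= m;
    while (b > 0) {
        if (b & 1) x = (x * a) % m;
        a = (a * a) % m;
        b >>= 1;
    }
    return x;
}

// Inverse via Fermat's little theorem: b^(m - 2). Requires b % m != 0.
long long mod_inverse(long long b) { return mpow(b, m - 2); }

// a / b modulo m, computed as a * b^{-1}.
long long mod_div(long long a, long long b) {
    return ((a % m) * mod_inverse(b)) % m;
}
```

For example, mod_div(1, 2) is 500000004, because 2 \cdot 500000004 = 1000000008, which is 1 mod 10^9 + 7.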

Note: there are also other ways to do this, some of which even work for non-prime m; some are here.

Some tricks

This is a (possibly) cool section that might be useful for even intermediate-level programmers. Feel free to comment with more if you know of some I haven’t thought of.

Computing all factorials up to n in O(n)
This should be fairly obvious after all you’ve been through.
0! = 1, and n! = n \cdot (n - 1)! for n \ge 1.

Code
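std::vector<long long> fact(n + 1); /* a vector instead of a variable-length array, which isn't standard C++ and can overflow the stack for large n */
fact[0] = 1;
for (int i = 1; i <= n; i++) {
  fact[i] = (fact[i - 1] * i) % m;
}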

Computing all inverse factorials up to n in O(n)
Maybe less obvious. Let’s use the fact that ((n - 1)!)^{-1} = n \cdot (n!)^{-1}. First compute all factorials to get n!; then you only need one inverse, (n!)^{-1}, and you can work backward to get all inverse factorials from n - 1 down to 0. This is technically O(n + \log{m}), but the log is usually negligible.

Code
std::vector<long long> fact(n + 1);
fact[0] = 1;
for (int i = 1; i <= n; i++) {
  fact[i] = (fact[i - 1] * i) % m;
}

std::vector<long long> ifact(n + 1);
ifact[n] = mpow(fact[n], m - 2); /* the only inverse we actually compute */
for (int i = n - 1; i >= 0; i--) {
  ifact[i] = ((i + 1) * ifact[i + 1]) % m; /* (i!)^{-1} = (i + 1) * ((i + 1)!)^{-1} */
}

Computing all i^{-1} up to n^{-1} in O(n)
Notice, now, how these are building off of each other? Let’s compute all factorials and inverse factorials, then just apply n^{-1} = \dfrac{(n - 1)!}{n!}. No need for any weird and annoying formulas! Also technically O(n + \log{m}).

Code
std::vector<long long> fact(n + 1);
fact[0] = 1;
for (int i = 1; i <= n; i++) {
  fact[i] = (fact[i - 1] * i) % m;
}

std::vector<long long> ifact(n + 1);
ifact[n] = mpow(fact[n], m - 2);
for (int i = n - 1; i >= 0; i--) {
  ifact[i] = ((i + 1) * ifact[i + 1]) % m;
}

std::vector<long long> inverse(n + 1);
for (int i = 1; i <= n; i++) {
  inverse[i] = (fact[i - 1] * ifact[i]) % m; /* i^{-1} = (i - 1)! / i! */
}

Computing all k^{-i} up to k^{-n} in O(n) (for a fixed k)
The simplest way to do this is to just take k^{-1}, then use that k^{-(i - 1)} \cdot k^{-1} = k^{-i} and multiply by k^{-1} a bunch of times, meaning that you only need one inverse. This is useful for, say, hashing. Also technically O(n + \log{m}).

Code
std::vector<long long> k_inverse(n + 1);
k_inverse[0] = 1;
long long k_inv = mpow(k, m - 2), cur = 1;
for (int i = 1; i <= n; i++) {
  cur = (cur * k_inv) % m;
  k_inverse[i] = cur;
}

I’ll let you guys steer these tutorials now (totally not because I’m running out of ideas) - which topics do you want to see coming up (from general to specific, easy to advanced, whatever you want)?

Can you try doing a tutorial on the prefix function (KMP algorithm)? I find it a bit tough to understand.

Sure, but I’ll have to understand it myself first so it may take some time

Anything (Python 3.6) would do; I only know the basics as of now.
Recommend me a link or something, maybe?
:neutral_face: :flushed: :cry:

Well… I think the best way to learn a language for this type of programming is through experience and finding out specific things - never be afraid to look up stuff like “how to sort an array in python” or “hashmap in python” or whatever you need. I looked for some sort of book online (via Google, which you can also try) but didn’t find anything suited to Python. However, I think you can take any CP book and adapt its code to Python for the most part.

An interesting library with a lot of things implemented is pyrival. There are probably more like this out there, I’m not sure.

It’s a great initiative :clap:. It would be great if you can post some article on heavy light decomposition.

Sounds fun! I’ll tackle it after KMP because it seems even harder to explain intuitively.

Maybe this is the first time someone has started a series of articles like this on Discuss. Good luck! :v:

@galencolin @everule1 @saurabhshadow, can anyone please help me out with this question? It requires fast prime factorization.
Problem link: SPOJ.com - Problem DIVSUM2
Thanks in advance.

Why here???

Just sieve and try all primes, O(\frac{n}{\log{n}}) works according to comments.

My bad, I will keep that in mind from now on.
Sieve? But n can be as large as 10^16, so I’d need primes up to 10^8, and I don’t think we can create an array of size 10^8.
PS: I’m a beginner, so correct me if I’m wrong.

Use bitset, all you’re interested in is the boolean value of “is it prime or not”
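A minimal C++ sketch of that suggestion, assuming the limit is known at compile time (std::bitset’s size is a template argument); sieve_primes is a hypothetical name. For a limit of 10^8 the bitset needs about 12.5 MB (one bit per number) instead of roughly 100 MB for a bool array, and it is declared static so it doesn’t live on the stack:

```cpp
#include <bitset>
#include <cstddef>
#include <vector>

// Sieve of Eratosthenes up to N, storing one bit per number in a std::bitset.
template <std::size_t N>
std::vector<int> sieve_primes() {
    // static: a bitset of ~10^8 bits (about 12.5 MB) is too big for the stack.
    static std::bitset<N + 1> composite;  // composite[i] set => i is not prime
    std::vector<int> primes;
    for (std::size_t i = 2; i <= N; ++i) {
        if (composite[i]) continue;
        primes.push_back(static_cast<int>(i));
        if (i > N / i) continue;  // i * i > N: nothing left to mark (also avoids overflow)
        for (std::size_t j = i * i; j <= N; j += i) composite[j] = true;
    }
    return primes;
}
```

Usage would be something like sieve_primes<100000000>() to get every prime up to 10^8.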

Thanks, Master, but I code in Python. Is there something like that there?

Segmented sieve then

That said, the best solution to your problem is not to use Python; it will make your life harder if you don’t know what you’re doing.

@galencolin, how do we find out the modular inverse for a composite m? Fermat’s little theorem can’t be applied in this case.

It may not exist. For example, there’s no x such that 1 \equiv 2 \cdot x \mod 6. This is true for any b, m where \gcd(b, m) > 1. Otherwise, you can use the extended Euclidean algorithm (method 2 in the link).
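A minimal sketch of that approach, using the standard recursive extended Euclidean algorithm (not necessarily identical to the linked method). ext_gcd and inverse_any_mod are illustrative names, m > 1 is assumed, and -1 is used here as a "no inverse" marker:

```cpp
// Extended Euclidean algorithm: returns g = gcd(a, b) and sets x, y
// so that a * x + b * y = g.
long long ext_gcd(long long a, long long b, long long& x, long long& y) {
    if (b == 0) { x = 1; y = 0; return a; }
    long long x1, y1;
    long long g = ext_gcd(b, a % b, x1, y1);
    x = y1;
    y = x1 - (a / b) * y1;
    return g;
}

// Inverse of b modulo any m > 1 (prime or not), or -1 if gcd(b, m) > 1.
long long inverse_any_mod(long long b, long long m) {
    long long x, y;
    long long g = ext_gcd(((b % m) + m) % m, m, x, y);
    if (g != 1) return -1;     // no inverse exists
    return ((x % m) + m) % m;  // normalize x into [0, m)
}
```

For instance, the inverse of 3 mod 10 is 7 (3 \cdot 7 = 21), while 2 mod 6 has none, matching the example above.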

@galencolin, I get that

if (x \cdot b)\%m = 1, then x is the modular inverse of b.

Now, would b be the modular inverse of x too? If not, an example would be very helpful :slightly_smiling_face:.

Unfortunately, I only know Python :slightly_frowning_face:

Yes, that is correct
