Oh no, it’s another one of those problems: you have to calculate the answer modulo some huge annoying number like 10^9 + 7. But it’s okay! Modular arithmetic is really quite simple, once you get to know it (and, of course, practice it).
Before I begin, some notation/clarifications:
- \% means modulo, and a \% m is equivalent to a \mod m.
- m is the modulus (such as 10^9 + 7), and for this post, m is always assumed to be prime. This will be important later.
- The given code is in C++. Most other languages, however, should have similar syntax. Be careful with datatypes like int: even though I’ll use them here, it’s highly recommended to use a larger datatype like long long to avoid overflow. However, if you know what you’re doing, you can speed up operations by working with int (see the “Faster multiplication” section).
- Modulo is a much heavier operation than addition, subtraction, or even multiplication. Try to avoid using it when it’s not necessary. This is what the “faster addition” and “faster subtraction” tricks are for.
- I generally use a as the number on the left side of the operation, and b as the number on the right (as in a + b).
Now let’s go!
Addition
Modular addition is quite simple. All you have to do is add the two numbers normally, then take their sum mod m. The cool thing about addition (also true for subtraction and multiplication) is that it doesn’t matter how much you mod in between. That means ((a\%m)+(b\%m))\%m = (a + b)\%m.
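For a quick sanity check, take m = 7, a = 11, b = 9: (11 + 9)\%7 = 20\%7 = 6, and ((11\%7) + (9\%7))\%7 = (4 + 2)\%7 = 6. Same answer.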
Code
long long x = (a + b) % m;
Faster addition
If the two numbers you’re adding, a and b, are both between 0 and m - 1, then a + b is guaranteed to be less than 2m. So if the sum is less than m, you don’t have to do anything, and if it’s at least as large as m, then you only have to subtract m once. So the following code is equivalent to the previous one:
long long x = a + b;
if (x >= m) x -= m;
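For example, with m = 7, a = 5, b = 6: a + b = 11 \geq 7, so the if fires and we end up with 11 - 7 = 4, which is exactly (5 + 6)\%7.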
Subtraction
Modular subtraction is very similar to addition. The only thing you have to be careful about is that in some languages (like C++ and Java), (a - b) \% m might return a negative number when you actually want it to be positive. However, it’s guaranteed that the result is at least -(m - 1). So you can do modular subtraction by adding m again if necessary. It’s best to wrap this in a function, since writing it out inline everywhere gets tedious.
Code
long long x = (a - b) % m;
if (x < 0) x += m;
Faster subtraction
If the two numbers you’re subtracting, a and b, are both between 0 and m - 1, then a - b is guaranteed to be between -(m - 1) and m - 1. So if the number is less than 0, you only have to add m once, and otherwise, you don’t have to do anything.
long long x = a - b;
if (x < 0) x += m;
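Since it’s easy to fumble these when typing them out inline every time, here’s a minimal sketch of the two fast versions wrapped as functions (the names madd and msub are my own; both assume a and b are already in the range 0 to m - 1):

long long madd(long long a, long long b) { /* (a + b) % m, assuming 0 <= a, b < m */
    long long x = a + b;
    return x >= m ? x - m : x;
}
long long msub(long long a, long long b) { /* (a - b) % m, assuming 0 <= a, b < m */
    long long x = a - b;
    return x < 0 ? x + m : x;
}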
Multiplication
Modular multiplication works the same way: multiply the two numbers together, then take the product mod m. With multiplication, you have to be very careful with overflow. Make sure the two numbers you multiply are already less than m and are stored in long long datatypes; otherwise, the product might overflow before you can mod it by m.
Code
long long x = (a * b) % m;
Faster multiplication
Unfortunately, unlike addition and subtraction, there’s no way to do faster multiplication without requiring a modulo operation. This is because a \cdot b is only guaranteed to be less than m^2, which is not very useful.
Note by @anon49376339: there are other ways to speed up operations, though, at least on 32-bit compilers (I think this applies to CodeChef and most online judges). Doing arithmetic purely in int datatypes will be considerably faster, and you can make multiplication work by forcing a cast to long long like so:
Code
int a, b;
// stuff
int x = (int) ((1LL * a * b) % m);
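The 1LL * forces the whole product to be computed as a long long, so it can’t overflow before the modulo. The result of the modulo is less than m, so it safely fits back into an int (hence the cast).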
Division
Heh. Division’s not so simple. Read on, I’ll get back to it once it makes sense to.
(Fast) exponentiation
This is where it gets fun. The goal of this is to compute a to the power of b mod m, that is, a^b\%m, efficiently (in better than O(b) time). Clearly the naive algorithm would just be to multiply some x = 1 by a, b times. Let’s do better.
This may seem out of nowhere, but… trust me. Let’s write b in binary. So if b = 13, we’d write 1101. Then, start with some x = 1, which will end up holding our final answer. Instead of directly computing a^{13}, we’ll split the process up into the powers of 2 that sum to b in binary (13 = 8 + 4 + 1). That means we’ll compute it as a^1 \cdot a^4 \cdot a^8. Just one thing left: how do we get a^4 and a^8? We can compute a to each power of 2 efficiently by using the fact that a^{2k} = a^k \cdot a^k. That means we get a^2 as a \cdot a, a^4 as a^2 \cdot a^2, and so on. The rest is just implementation:
Code
How does this code work? Instead of explicitly checking if the i-th bit is set to multiply by a^{2^i}, we just divide b by 2 and keep checking the smallest bit. In addition, we compute a^{2^i} iteratively as we go.
long long mpow(long long a, long long b) { /* pow(a, b) % m */
    long long x = 1;
    while (b > 0) {
        if (b & 1) { /* fast way of checking if b is odd */
            x = (x * a) % m;
        }
        a = (a * a) % m; /* square a to get the next power a^(2^(i+1)) */
        b >>= 1; /* this is equivalent to b = floor(b / 2) */
    }
    return x;
}
This runs in O(\log{b}) time.
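As a quick example, 13 in binary is 1101, so the loop multiplies x by a^1, skips a^2, then multiplies by a^4 and a^8. Calling mpow(3, 13) with m = 10^9 + 7 returns 1594323, which is indeed 3^{13}.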
Modular multiplicative inverse (a.k.a. division)
Okay, now let’s get back to division. In modulo land, division is slightly different from what you’d expect. Instead of directly computing \dfrac{a}{b}, we’ll compute \dfrac{1}{b} and obtain \dfrac{a}{b} as a \cdot \dfrac{1}{b}. How do we even compute \dfrac{1}{b} (which I’ll call b^{-1} from now on)? Well, b^{-1} is the x such that (x \cdot b)\%m = 1. Why does that make sense? Because that’s exactly how the reciprocal is defined in normal arithmetic, except we can’t have fractions or decimals here; we need integers. For any prime m, this inverse exists for every b that isn’t a multiple of m (in particular, for all b from 1 to m - 1), and furthermore it’s unique.
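For a tiny example, take m = 7 and b = 3. The inverse of 3 is 5, because (3 \cdot 5)\%7 = 15\%7 = 1. So “dividing by 3” mod 7 just means multiplying by 5.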
Well, we’ll want to do this in better time than the O(m) algorithm of trying all possible x. Let me pull out some random theorem: for a prime m and any b that isn’t a multiple of m, (b^{m - 1})\%m = 1.
Where the hell did that come from?
This is Fermat’s little theorem. I’ll summarize a standard proof of it below.
Proof
Let’s write out the sequence b, 2b, 3b, 4b, ..., (m - 1)b. Now we’ll take all of these elements mod m.
Something interesting (otherwise known as a “lemma”): they’re all unique. If they weren’t, and we had some (i \cdot b)\%m = (j \cdot b)\%m with i \neq j, then m would divide (i - j) \cdot b. Since m is prime and b isn’t divisible by m, m would have to divide i - j, which would mean i\%m = j\%m. But that’s impossible, because i and j are distinct numbers between 1 and m - 1.
Another interesting thing: none of them are 0. If an element i \cdot b reduces to 0 mod m, then either i or b must be divisible by m, which is impossible since m is prime and i, b < m.
So these numbers actually form the sequence of numbers from 1 to m - 1, but may be shuffled in some order. This means if we multiply all of the numbers of each sequence together, they’ll multiply to the same thing. So:
b \cdot 2b \cdot 3b \cdot 4b \cdot ... \cdot (m - 1)b = 1 \cdot 2 \cdot 3 \cdot 4 \cdot ... \cdot (m - 1) [all mod m]
Now on the left side, we’ll group all the b's together into b^{m - 1} (since there are exactly m - 1 b's)
b^{m - 1} \cdot 2 \cdot 3 \cdot 4 \cdot ... \cdot (m - 1) = 1 \cdot 2 \cdot 3 \cdot 4 \cdot ... \cdot (m - 1) [all mod m]
Notice that on both sides, we actually just have m - 1 factorial, or (m - 1)!:
b^{m - 1} \cdot (m - 1)! = (m - 1)! [all mod m]
Cancel out the factorials (which is allowed, because (m - 1)! isn’t divisible by m and therefore has an inverse), and we have:
b^{m - 1} = 1 [all mod m]
and the proof is complete.
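As a sanity check, take m = 5 and b = 3: 3^4 = 81 = 16 \cdot 5 + 1, so (3^4)\%5 = 1, exactly as promised.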
Cool, why is that useful? Write b^{m - 1} as b \cdot b^{m - 2}. Since (b \cdot b^{m - 2})\%m = 1, that means b^{m - 2} is exactly the x we were looking for: (b^{m - 2})\%m = b^{-1}. Wait, it’s that simple? We just need b^{m - 2}? Yep! And because you now know modular exponentiation, this is a piece of cake! Note: this also means computing a modular inverse is O(\log{m}).
Code
Refer to the code for modular exponentiation to know what mpow is.
/* divide a by b, modulo-style */
long long b_inverse = mpow(b, m - 2);
long long x = (a * b_inverse) % m;
Note: there are also other ways to do this, such as the extended Euclidean algorithm, which works even for non-prime m as long as b and m are coprime.
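For completeness, here’s a minimal sketch of that alternative, based on the extended Euclidean algorithm (the function name inv_egcd is mine, not anything standard). It returns b^{-1} mod m whenever \gcd(b, m) = 1, even when m isn’t prime:

/* returns x such that (b * x) % m == 1, assuming gcd(b, m) == 1 */
long long inv_egcd(long long b, long long m) {
    long long old_r = b, r = m; /* remainders in the Euclidean algorithm */
    long long old_x = 1, x = 0; /* invariant: old_x * b = old_r (mod m) */
    while (r != 0) {
        long long q = old_r / r;
        long long t = old_r - q * r; old_r = r; r = t;
        t = old_x - q * x; old_x = x; x = t;
    }
    /* old_r is now gcd(b, m) = 1, so old_x * b = 1 (mod m); normalize into [0, m) */
    return ((old_x % m) + m) % m;
}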
Some tricks
This is a (possibly) cool section that might be useful for even intermediate-level programmers. Feel free to comment with more if you know of some I haven’t thought of.
Computing all factorials up to n in O(n)
This should be fairly obvious after all you’ve been through.
0! = 1! = 1, n! = n \cdot (n - 1)!.
Code
long long fact[n + 1];
fact[0] = fact[1] = 1;
for (int i = 2; i <= n; i++) {
    fact[i] = (fact[i - 1] * i) % m;
}
Computing all inverse factorials up to n in O(n)
Maybe less obvious. Let’s use the fact that ((n - 1)!)^{-1} = n \cdot (n!)^{-1} (to see why, multiply the right side by (n - 1)!: you get (n \cdot (n - 1)!) \cdot (n!)^{-1} = n! \cdot (n!)^{-1} = 1). Then, compute all factorials to get n!; you only need the inverse of that one value, and then you can work backward to get all the inverse factorials from n - 1 down to 0. This is technically O(n + \log{m}), but the log is usually negligible.
Code
long long fact[n + 1];
fact[0] = fact[1] = 1;
for (int i = 2; i <= n; i++) {
    fact[i] = (fact[i - 1] * i) % m;
}
long long ifact[n + 1];
ifact[n] = mpow(fact[n], m - 2);
for (int i = n - 1; i >= 0; i--) {
    ifact[i] = ((i + 1) * ifact[i + 1]) % m;
}
Computing all i^{-1} up to n^{-1} in O(n)
Notice, now, how these are building off of each other? Let’s compute all factorials and inverse factorials, then just apply n^{-1} = \dfrac{(n - 1)!}{n!} (everything in the fraction cancels except a single n in the denominator). No need for any weird and annoying formulas! Also technically O(n + \log{m}).
Code
long long fact[n + 1];
fact[0] = fact[1] = 1;
for (int i = 2; i <= n; i++) {
    fact[i] = (fact[i - 1] * i) % m;
}
long long ifact[n + 1];
ifact[n] = mpow(fact[n], m - 2);
for (int i = n - 1; i >= 0; i--) {
    ifact[i] = ((i + 1) * ifact[i + 1]) % m;
}
long long inverse[n + 1];
for (int i = 1; i <= n; i++) {
    inverse[i] = (fact[i - 1] * ifact[i]) % m;
}
Computing all k^{-i} up to k^{-n} in O(n) (for a fixed k)
The simplest way to do this is to just take k^{-1}, then use that k^{-(i - 1)} \cdot k^{-1} = k^{-i} and multiply by k^{-1} a bunch of times, meaning that you only need one inverse. This is useful for, say, hashing. Also technically O(n + \log{m}).
Code
long long k_inverse[n + 1];
k_inverse[0] = 1;
long long k_inv = mpow(k, m - 2), cur = 1;
for (int i = 1; i <= n; i++) {
    cur = (cur * k_inv) % m;
    k_inverse[i] = cur;
}