Please proofread the article for the tutorial section on sieve methods

Hi All,

The purpose of this article is to introduce beginners to some simple concepts of number theory. I wrote this article a few years back and thought of putting it in the CodeChef tutorial section, but I feel it still has some issues and needs formatting corrections. I have done my best to fix the formatting. Please help me with the remaining corrections so that we can create a good article for beginners to learn from. Any help or feedback is appreciated.

Please also suggest problems to solve on the topic. The article is a community wiki, so please feel free to edit it.

Structure of the article

First I will explain what sieve means. Then I will give you some examples with corresponding Java code, and finally some exercises :)

According to the dictionary, a “sieve” is “a utensil of wire mesh or closely perforated metal, used for straining, sifting, ricing, or puréeing.” Similarly, a sieve method is a way of doing things in which you keep rejecting the things that are useless, reducing the space you are currently looking at.

Enough abstract talk; let us now look at some examples.

Finding primes up to N

You have to print all primes up to N.

Method 1:

For every number i from 1 to N, check whether i is prime. If it is, print it.

Subproblem:
Checking whether a number K is prime.

Solution:

  • For every number i from 2 to K-1, check whether K is divisible by i (we skip 1 and K, since every number is divisible by 1 and by itself). If some i divides K, then K is not prime; otherwise K is prime.
    Complexity of this solution: O(K)

  • Note that we do not need to check up to K-1; checking up to sqrt(K) is enough, which brings the complexity down to O(sqrt(K)).
    Proof:
    Let us say K = a * b. At least one of a and b is <= sqrt(K), otherwise their product would exceed K. So if K has any divisor other than 1 and K, it has one that is <= sqrt(K), and checking up to sqrt(K) is enough.

  • Alternatively, use a probabilistic method to test whether a number is prime.
    More on this later. For now see link
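For reference, here is a minimal sketch of Method 1 using the checks listed above. The class and method names (PrimeCheck, isPrimeLinear, isPrimeSqrt, isPrimeProbable) are my own hypothetical names, and for the probabilistic test the sketch simply calls the JDK's BigInteger.isProbablePrime instead of implementing a specific algorithm.

```java
import java.math.BigInteger;

class PrimeCheck {
    // O(K): try every candidate divisor from 2 to K - 1.
    static boolean isPrimeLinear(int k) {
        if (k < 2) return false;
        for (int i = 2; i < k; i++)
            if (k % i == 0) return false;
        return true;
    }

    // O(sqrt(K)): if K = a * b, one of a, b is <= sqrt(K), so we can stop there.
    // Writing i <= k / i instead of i * i <= k avoids int overflow near Integer.MAX_VALUE.
    static boolean isPrimeSqrt(int k) {
        if (k < 2) return false;
        for (int i = 2; i <= k / i; i++)
            if (k % i == 0) return false;
        return true;
    }

    // Probabilistic test from the JDK: a "true" answer is wrong with probability at most 2^-20.
    static boolean isPrimeProbable(int k) {
        return BigInteger.valueOf(k).isProbablePrime(20);
    }

    public static void main(String[] args) {
        int n = 30; // fixed limit for the demo
        // Method 1: test every number from 1 to n on its own.
        StringBuilder sb = new StringBuilder();
        for (int i = 1; i <= n; i++)
            if (isPrimeSqrt(i)) sb.append(i).append(' ');
        System.out.println(sb.toString().trim());
    }
}
```

Testing every number separately costs O(N * sqrt(N)) in total, which is what the sieve below improves on.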

Method 2:

Now here comes the idea of the sieve. Initially, assume that all numbers are prime. Then you sieve (refine) your search range: instead of looking at all the numbers again, you only look at the reduced space. E.g. once you have found all the numbers divisible by 2 (other than 2 itself), you never look at them again, because they are already known not to be prime. Those numbers are sieved out. Now do the same for 3, 4, ... up to n. A number such as 4 that has already been sieved out can be skipped, since its multiples were already removed by 2.

In other words, you first take all the numbers which are divisible by 2 (except 2 itself).

None of those can be prime, so you can remove them from consideration. Now do the same for 3, 4, ..., N. Finally you will end up with only the prime numbers.

To understand what is actually going on, see the code.

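The code first marks all the numbers from 2 up to N as prime. Then, for every number that is still marked prime, it marks all of its multiples up to N as non-prime. sieveOptimized additionally starts crossing out at i * i and stops the outer loop once i * i > N, because every composite number up to N has a prime factor <= sqrt(N).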

import java.io.*;
import java.util.*;
import java.math.*;

public class Main {
    static boolean[] isPrime;

    public static void sieveOptimized(int N) {
        isPrime = new boolean[N + 1];
        
        for (int i = 2; i <= N; i++)
            isPrime[i] = true;
        for (int i = 2; i * i <= N; i++) {
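            if (isPrime[i]) {
                // Start from i * i: smaller multiples k * i (k < i) were already
                // crossed out while sieving with the prime factors of k.
                // For further optimization, for odd i you can use j += (2 * i) instead of
                // j += i (this does not work for i = 2). Proof is left to the reader :)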
                for (int j = i * i; j <= N; j += i) 
                    isPrime[j] = false;
            }
        }
    }
    

    public static void sieve(int N) {
        isPrime = new boolean[N + 1];
        
        for (int i = 2; i <= N; i++)
            isPrime[i] = true;
        for (int i = 2; i <= N; i++) {
            if (isPrime[i]) {
                for (int j = i + i; j <= N; j += i) 
                    isPrime[j] = false;
            }
        }
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int limit = sc.nextInt();

        // call the sieve method
        sieveOptimized(limit);

        for (int i = 1; i <= limit; i++) {
            if (isPrime[i]) 
                System.out.printf("%d ",i);
        }
    }
}

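Now comes the complexity part. The complexity of this code is O(n * logn). This is an upper bound: a more careful analysis, which uses the fact that the inner loop runs only for primes, gives O(n * log log n).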

Proof (this argument comes up a lot in various algorithms, so pay attention):

In the worst case, for every number i from 2 to n, you visit all the multiples of i. The number of multiples of i up to n is n / i, so the total work is at most n / 2 + n / 3 + ... + n / n. Take n common out of the expression: it becomes n * (1 / 2 + ... + 1 / n).

Adding one more term n (the term for i = 1) only makes this sum larger, so n * (1 + 1 / 2 + ... + 1 / n) is still an upper bound, and it differs from the previous expression by just n, which does not affect the complexity. The expression (1 + 1 / 2 + 1 / 3 + ... + 1 / n) is the harmonic sum, and it is at most 1 + ln(n). Hence the overall complexity is O(n * logn).

Proof of the harmonic sum bound:
A simple proof: integrate 1 / x from 1 to n. The integral is the area under the curve 1/x, which is at least (1 / 2 + 1 / 3 + ... + 1 / n), because each term 1 / k is the area of a rectangle of width 1 that lies under the curve between k - 1 and k. The integral of 1/x dx is ln(x), so the area is ln(n), and therefore 1 + 1 / 2 + ... + 1 / n <= 1 + ln(n).
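If you want to see the bound in practice, the sketch below (hypothetical class SieveWorkCount) counts how many times the inner loops run: once when every i visits all of its multiples, which is exactly the sum of n / i, and once for the inner loop of sieve() above, where only primes cross out multiples. It only illustrates the argument; it is not part of the proof.

```java
class SieveWorkCount {
    // Iterations when every i from 1 to n visits all its multiples: sum of floor(n / i).
    static long allMultiples(int n) {
        long ops = 0;
        for (int i = 1; i <= n; i++)
            for (int j = i; j <= n; j += i)
                ops++;
        return ops;
    }

    // Iterations of the inner loop of sieve(N): only primes i cross out multiples, from 2 * i.
    static long sieveOps(int n) {
        boolean[] composite = new boolean[n + 1];
        long ops = 0;
        for (int i = 2; i <= n; i++) {
            if (composite[i]) continue;
            for (int j = i + i; j <= n; j += i) {
                composite[j] = true;
                ops++;
            }
        }
        return ops;
    }

    public static void main(String[] args) {
        for (int n : new int[] {1000, 100000, 1000000}) {
            double bound = n * (1 + Math.log(n)); // n * (1 + ln(n)) from the proof
            System.out.printf("n=%d  sum n/i=%d  n(1+ln n)=%.0f  sieve ops=%d%n",
                    n, allMultiples(n), bound, sieveOps(n));
        }
    }
}
```

The first count always stays below n * (1 + ln(n)), and the sieve count is noticeably smaller, which matches the tighter O(n * log log n) bound.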

Finding the sum of divisors of all numbers up to N

Now you have to find the sum of divisors of every number up to N. Here we are not considering just the divisors other than 1 and the number itself; we consider all the divisors. You can do something like this.

Let divisorSum[i] denote the sum of divisors of i. Initially every divisorSum[i] is zero. Then for every number i from 1 to n, we go over all the multiples j of i and add i to divisorSum[j].

In other words, start from 1 and, for every number that is a multiple of 1, increase its divisorSum by 1.

Now do the same for 2, 3, ..., N. Note that for a number i, you do this addition N / i times, so the complexity calculation is the same as before.

Look at the beautifully commented code.

import java.io.*;
import java.util.*;
import java.math.*;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int T = sc.nextInt();
        
        while (T-- > 0) {
            int n = sc.nextInt();
            int divisorSum[] = new int [n + 1];
            
            // For every number i, the numbers i, 2*i, 3*i, ... up to k*i such that k*i <= n
            // all have i as one of their divisors, so add i to divisorSum[j] for each such j.

            for (int i = 1; i <= n; i++) {
                for (int j = i; j <= n; j += i) {
                    divisorSum[j] += i;
                }
            }
            
            // Complexity of this code is O(n * logn)
            // Proof: the inner loop runs n / i times, so the total is n / 1 + n / 2 + ... + n / n
            // Take n common:
            // n * (1 + 1 / 2 + ..... + 1/n)
            // (1 + 1 / 2 + 1 / 3 + ... + 1 / n) is the harmonic sum, and it is at most 1 + ln(n).
            // A simple proof: integrate 1 / x from 1 to n.
            // The integral is the area under the curve 1/x, which is at least
            // (1 / 2 + 1 / 3 + ... + 1 / n), since each term 1/k is the area of a width-1
            // rectangle lying under the curve between k - 1 and k.
            // The integral of 1/x dx is ln(x), so the area is ln(n).

            for (int i = 1; i <= n; i++) 
                System.out.printf("%d ", divisorSum[i]);
            
            System.out.printf("\n");
        }
    }
}

Finding the number of divisors of all numbers up to N

This is the same as the previous example, except that instead of storing the sum in the array, you store the number of divisors. In the previous example, for every multiple j of i you added i to divisorSum[j]; here you just increment noOfDivisors[j] by one.

The code is very easy and hence omitted. The complexity is also the same.
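For completeness, a minimal sketch of that code, mirroring the divisor-sum program above. The class name DivisorCount and the helper countDivisors are hypothetical, and using a fixed n instead of reading T test cases is my own simplification.

```java
class DivisorCount {
    // noOfDivisors[j] ends up as the number of divisors of j, for 1 <= j <= n.
    static int[] countDivisors(int n) {
        int[] noOfDivisors = new int[n + 1];
        for (int i = 1; i <= n; i++)
            // i divides i, 2*i, 3*i, ..., so each of them gains one divisor.
            for (int j = i; j <= n; j += i)
                noOfDivisors[j]++;
        return noOfDivisors;
    }

    public static void main(String[] args) {
        int n = 12; // fixed limit for the demo
        int[] d = countDivisors(n);
        StringBuilder sb = new StringBuilder();
        for (int i = 1; i <= n; i++) sb.append(d[i]).append(' ');
        System.out.println(sb.toString().trim());
    }
}
```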

Sieve for finding Euler's phi function

I will denote Euler's phi function of a number n by phi(n). phi(n) is defined as follows.

It is the count of numbers from 1 to n that are coprime to n (that is, have gcd 1 with n).

For example, phi(6) is 2, as only 1 and 5 are coprime to 6.

A few properties of the phi function:

  • phi(p) = p - 1, where p is prime: all the numbers from 1 to p - 1 are coprime to p.
  • phi(a * b) = phi(a) * phi(b) where a and b are coprime.
  • phi(p^k) = p^k - p^(k - 1), where ^ denotes power: all the numbers from 1 to p^k are coprime to p^k except the multiples of p, and there are exactly p^(k - 1) of those.

Methods for finding it:

  • Simple: for every number from 1 to n, check whether it is coprime to n. If yes, add 1 to your answer.

  • Let us say your number n can be written as p1^k1 * p2^k2 * ... * pm^km, where p1, p2, ..., pm are distinct primes; that is, n is written in its prime factorization.
    Then, by the properties above, phi(n) = [p1^k1 - p1^(k1 - 1)] * ... * [pm^km - pm^(km - 1)]. Taking pi^ki out of each bracket, phi(n) can also be written as p1^k1 * p2^k2 * ... * pm^km * (1 - 1/p1) * (1 - 1/p2) * ... * (1 - 1/pm),
    which is equal to n * (1 - 1/p1) * (1 - 1/p2) * ... * (1 - 1/pm).
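A small worked example of the formula (my own choice of number), checking both forms for n = 12:

```latex
\begin{aligned}
12 &= 2^2 \cdot 3\\
\varphi(12) &= (2^2 - 2^1)(3^1 - 3^0) = 2 \cdot 2 = 4\\
\varphi(12) &= 12\left(1 - \tfrac{1}{2}\right)\left(1 - \tfrac{1}{3}\right) = 12 \cdot \tfrac{1}{2} \cdot \tfrac{2}{3} = 4
\end{aligned}
```

Indeed, the numbers from 1 to 12 that are coprime to 12 are 1, 5, 7 and 11.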

See the code for details.

import java.io.*;
import java.util.*;
import java.math.*;

public class Main {
    
    public static boolean isPrime (int n) {
        if (n < 2)
            return false;
        for (int i = 2; i * i <= n; i++)
            if (n % i == 0)
                return false;
        return true;
    }
    
    public static int eulerPhiDirect (int n) {
        int result = n;
        for (int i = 2; i <= n; i++) {
            // only primes that divide n contribute a factor (1 - 1/i)
            if (n % i == 0 && isPrime(i))
                result -= result / i;
        }
        
        return result;
    }
    
    public static int eulerPhi (int n) {
        int result = n;
        // Think of it like this: start with result = n, as if all n numbers were coprime to n.
        // By the formula n * (1 - 1/p1) * (1 - 1/p2) * ... * (1 - 1/pm), after handling the
        // prime factors up to i, result holds n * (1 - 1/p1) * ... * (1 - 1/p_i).
        // For example, after handling p1, result is n - n / p1, which is precisely
        // n * (1 - 1/p1). Similarly, by induction, result finally equals phi(n).

        for (int i = 2; i * i <= n; i++) {
            if (n % i == 0) {
                result -= result / i;
                
                // The while loop divides every factor i out of n, so any later i that
                // divides n must be prime: all of its smaller prime factors are already gone.
                while (n % i == 0) {
                    n /= i;
                }
            }
        }
        
        if (n > 1) {
            result -= result / n;
        }
        
        return result;
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int T = sc.nextInt();
        
        while (T-- > 0) {
            int n = sc.nextInt();

            // Call the eulerPhiDirect or eulerPhi method.
            // eulerPhi is faster: it runs in O(sqrt(n)), while eulerPhiDirect loops over every i up to n.

            int ans = eulerPhiDirect (n);
            System.out.println(ans);
        }
    }
}
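The "Simple" method from the list above has no code in the post. Here is a minimal sketch of it (hypothetical class EulerPhiCheck), used as a brute-force cross-check against the same factorisation idea as eulerPhi; comparing a slow, obviously correct version with a fast one is a handy way to catch bugs in the fast one.

```java
class EulerPhiCheck {
    static int gcd(int a, int b) {
        return b == 0 ? a : gcd(b, a % b);
    }

    // "Simple" method: count k in [1, n] with gcd(k, n) == 1. O(n log n) per query.
    static int phiByGcd(int n) {
        int count = 0;
        for (int k = 1; k <= n; k++)
            if (gcd(k, n) == 1) count++;
        return count;
    }

    // Same factorisation idea as eulerPhi above, O(sqrt(n)).
    static int phiByFactorisation(int n) {
        int result = n;
        for (int p = 2; p <= n / p; p++) {
            if (n % p == 0) {
                result -= result / p;
                while (n % p == 0) n /= p;
            }
        }
        if (n > 1) result -= result / n; // leftover prime factor larger than sqrt(n)
        return result;
    }

    public static void main(String[] args) {
        // Compare both methods on small inputs.
        for (int n = 1; n <= 1000; n++)
            if (phiByGcd(n) != phiByFactorisation(n))
                throw new AssertionError("mismatch at n = " + n);
        System.out.println("phiByGcd and phiByFactorisation agree for n = 1..1000");
    }
}
```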

Now let us use a sieve to calculate phi for all numbers from 1 to N.
Let eulerPhi[i] be the value of phi(i). Initially set every eulerPhi[i] to i. Then for every prime p and every multiple j of p, multiply eulerPhi[j] by (1 - 1/p), as per the formula. Multiplying eulerPhi[j] by (1 - 1/p) is exactly the same as eulerPhi[j] -= (eulerPhi[j] / p).

import java.io.*;
import java.util.*;
import java.math.*;

public class Main {
    
    public static boolean isPrime (int n) {
        if (n < 2)
            return false;
        for (int i = 2; i * i <= n; i++)
            if (n % i == 0)
                return false;
        return true;
    }
    
    private static int[] eulerPhi;
    
    public static void eulerSieve (int N) {
        eulerPhi = new int[N + 1];
        
        // set initial value of phi(i) = i
        for (int i = 1; i <= N; i++)
            eulerPhi[i] = i;
        
        // For every prime i, do as described in the post.
        // Note that we are using the isPrime(i) function, which takes O(sqrt(i)) time. You are
        // advised to write a separate sieve for finding primes, as described above,
        // which will reduce the complexity of this to O(n * log(n)).
        // The proof is similar to the previous ones and is left to the reader.
        
        for (int i = 1; i <= N; i++) {
            if (isPrime(i))
                for (int j = i; j <= N; j += i)
                    eulerPhi[j] -= eulerPhi[j] / i;
        }
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int N = sc.nextInt();
        
        eulerSieve(N);
        
        for (int i = 1; i <= N; i++)
            System.out.printf("%d ",eulerPhi[i]);
    }
}
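Instead of writing a separate prime sieve, a commonly used variant (a swapped-in trick, not the code above) notices that when the outer loop reaches i, eulerPhi[i] still equals i exactly when i is prime, because every composite i has already been reduced by its smaller prime factors. A minimal sketch, with hypothetical names EulerPhiSieve and phiSieve; this brings the cost down to O(N log log N):

```java
class EulerPhiSieve {
    // phi[i] for 0 <= i <= n.
    static int[] phiSieve(int n) {
        int[] phi = new int[n + 1];
        for (int i = 0; i <= n; i++) phi[i] = i;
        for (int i = 2; i <= n; i++) {
            // Untouched by any smaller prime, so i is prime.
            if (phi[i] == i) {
                for (int j = i; j <= n; j += i)
                    phi[j] -= phi[j] / i; // multiply by (1 - 1/i); the division is exact
            }
        }
        return phi;
    }

    public static void main(String[] args) {
        int n = 20; // fixed limit for the demo
        int[] phi = phiSieve(n);
        StringBuilder sb = new StringBuilder();
        for (int i = 1; i <= n; i++) sb.append(phi[i]).append(' ');
        System.out.println(sb.toString().trim());
    }
}
```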

Sieve for finding inverse modulo m

If b * a = 1 (mod m), then b is called the inverse of a modulo m.

First of all, notice that the inverse does not always exist. Example: a = 4, m = 6. Here 4 * b - 6 * k is always even, so 4 * b can never be 1 more than a multiple of 6.

Let us play around with the expressions. a * b = 1 (mod m) can be written as a * b = 1 + m * k,
which can be written as a * b - m * k = 1. If a and m have a common factor greater than 1, i.e. gcd(a, m) != 1, then the left side is divisible by that factor but the right side (1) is not. Hence the equation has no solution when gcd(a, m) != 1, and the inverse can exist only when gcd(a, m) = 1. The construction below shows that it does exist in that case.

Now write k1 = -k, so we need to solve a * b + m * k1 = 1.

So let us solve the generic problem a * x + b * y = 1, where gcd(a, b) = 1 and the integers x and y can also be negative. (Here a and b are just the two coefficients; b is no longer the inverse.)

Let us first solve a smaller version of the same problem: b * x1 + (a % b) * y1 = 1.

Now try to relate these equations.

a % b = a - [a / b] * b, where [] denotes the floor function. This is the same as removing all multiples of b from a, which leaves exactly a % b.

Now the equation turns into b * x1 + (a - [a/b] * b) * y1 = 1

which is a * y1 + b * (x1 - [a/b] * y1) = 1.

So if (x1, y1) solves the smaller equation, then x = y1 and y = x1 - [a/b] * y1 solve a * x + b * y = 1. Keep recursing until b = 0; at that point gcd(a, b) = 1 forces a = 1, and x = 1, y = 0 is a solution.

Things are getting complex, but note that this method follows exactly the same recursion as finding the gcd of two numbers; it is known as the extended Euclidean algorithm. Seeing the code will help you understand more.

The complexity of this code is the same as that of gcd, as it has exactly the same recurrence. The time complexity of gcd(m, n) is O(log(max(m, n))), so this method runs in O(log(max(a, m))) time.

import java.io.*;
import java.util.*;
import java.math.*;

class pair {
    public int x, y;
    
    pair (int x,int y) {
        this.x = x;
        this.y = y;
    }   
    
    boolean isEquals (pair p) {
        if (this.x == p.x && this.y == p.y) 
            return true;
        else 
            return false;
    }
}

public class Main {
    public static int gcd (int a, int b) {
        if (b == 0)
            return a;   
        return gcd (b, a % b);
    }
    
    public static pair solve (int a,int b) {
        if (b == 0) {
            // a * x + b * y = 1
            // here b = 0
            // hence a * x = 1
            // if a is not 1, there is no solution; otherwise x = 1 and y = 0
            // Note that the error case never reaches here: solveThis already checked
            // gcd(a, b) == 1, and the recursion preserves the gcd, so a == 1 at this point.
            return new pair (1, 0);
        } else {
            // do the recursive call
            pair p = solve (b, a % b);
            int x1 = p.x;
            int y1 = p.y;
            
            int x = y1;
            int y = x1 - (a / b) * y1;
            
            return new pair (x, y);
        }
    }

    public static pair solveThis (int a, int b) {
        if (gcd (a, b) != 1)
            // (-1, -1) corresponds to error, that means no inverse exists in this case
            return new pair (-1, -1);
        else 
            return solve (a, b);
    }

    public static int modpow (long a, long b, long c) {
        long res = 1;
        while (b > 0) {
            if (b % 2 == 1) {
                res = (res * a) % c;
            }
            a = (a * a) % c;
            b >>= 1;
        }
        
        return (int) res;
    }

    public static int findInverseModuloPrime (int a, int p) {
        return modpow (a, p - 2, p);
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        
        int a = sc.nextInt();
        int m = sc.nextInt();
        
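        // x and y satisfy a * x + m * y = 1, so x is an inverse of a modulo m
        pair p = solveThis (a, m);

        if (p.isEquals(new pair(-1, -1)))
            System.out.printf("Error, inverse does not exist\n");
        else {
            // x may be negative; shift it into the range [0, m)
            int inverse = ((p.x % m) + m) % m;
            System.out.printf("%d %d\n", p.x, p.y);
            System.out.printf("inverse of %d modulo %d is %d\n", a, m, inverse);
        }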
    }
}
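To see solve in action, here is a hand trace for a = 3, m = 7 (my own example). The calls go down to the base case (1, 0), and then x and y are rebuilt with x = y1, y = x1 - [a/b] * y1:

```latex
\begin{aligned}
&\text{Solve } 3x + 7y = 1:\quad (3, 7) \to (7, 3) \to (3, 1) \to (1, 0),\ \text{base case } (x, y) = (1, 0)\\
&(3, 1):\ x = 0,\ y = 1 - \lfloor 3/1 \rfloor \cdot 0 = 1 && 3 \cdot 0 + 1 \cdot 1 = 1\\
&(7, 3):\ x = 1,\ y = 0 - \lfloor 7/3 \rfloor \cdot 1 = -2 && 7 \cdot 1 + 3 \cdot (-2) = 1\\
&(3, 7):\ x = -2,\ y = 1 - \lfloor 3/7 \rfloor \cdot (-2) = 1 && 3 \cdot (-2) + 7 \cdot 1 = 1
\end{aligned}
```

So the inverse of 3 modulo 7 is -2, which is 5 after shifting into [0, 7); indeed 3 * 5 = 15 = 2 * 7 + 1.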

Another easier method

Now I will tell you another, easier method. Euler's theorem, a generalization of Fermat’s little theorem, says that a ^ phi(m) = 1 (mod m) when gcd(a, m) = 1. Multiplying both sides by inverse(a) gives a ^ (phi(m) - 1) = inverse(a) (mod m).

Hence, to find inverse(a) mod m, you can just compute a ^ (phi(m) - 1) mod m with modular exponentiation. When m is prime, phi(m) = m - 1, so just compute a ^ (m - 2) % m. This is precisely what findInverseModuloPrime computes using modpow in the code above.
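A minimal sketch of the general-m version, assuming 0 <= a and 1 <= m <= Integer.MAX_VALUE so that the product of two residues fits in a long. phi(m) is computed with the same O(sqrt(m)) factorisation as eulerPhi earlier; the class and method names (InverseByEuler, inverse) are hypothetical.

```java
class InverseByEuler {
    static long gcd(long a, long b) {
        return b == 0 ? a : gcd(b, a % b);
    }

    // phi(m) by trial-division factorisation, O(sqrt(m)).
    static long phi(long m) {
        long result = m;
        for (long p = 2; p * p <= m; p++) {
            if (m % p == 0) {
                result -= result / p;
                while (m % p == 0) m /= p;
            }
        }
        if (m > 1) result -= result / m;
        return result;
    }

    // a^e mod m by repeated squaring; assumes m <= Integer.MAX_VALUE so a * a fits in a long.
    static long modpow(long a, long e, long m) {
        long res = 1 % m;
        a %= m;
        while (e > 0) {
            if ((e & 1) == 1) res = res * a % m;
            a = a * a % m;
            e >>= 1;
        }
        return res;
    }

    // Inverse of a modulo m via Euler's theorem, or -1 if gcd(a, m) != 1.
    static long inverse(long a, long m) {
        if (gcd(a, m) != 1) return -1;
        return modpow(a, phi(m) - 1, m);
    }

    public static void main(String[] args) {
        System.out.println(inverse(3, 10)); // 7, since 3 * 7 = 21 = 2 * 10 + 1
        System.out.println(inverse(4, 6));  // -1, gcd(4, 6) = 2
        System.out.println(inverse(3, 7));  // 5, since 3 * 5 = 15 = 2 * 7 + 1
    }
}
```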

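Complexity: modpow takes O(log(exponent)) steps, so finding one inverse this way takes O(log m) time once phi(m) is known.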

Now, a sieve for finding inverses modulo m:
You have to find inverse(i) mod m for every i from 1 to n.
Since modpow takes O(log m) time and we would call it for n numbers, the total complexity would be O(n * log m). Here I am going to describe a better, O(n) method using a sieve. It assumes that m is prime and n < m, so that every i from 1 to n has an inverse.

Just use the identity inverse(i) = (-(m/i) * inverse(m % i)) % m + m. Since m % i < i, inverse(m % i) has already been computed by the time we reach i.

Proof:
It is a good one, and I do not want to reveal it now. Try it yourself. If you come up with it, post it in the comments and I will check whether it is correct.

import java.io.*;
import java.util.*;
import java.math.*;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
                
        int n = sc.nextInt(); // denotes the range upto which you want to find the value of modInverse[i]
        int m = sc.nextInt();

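        // long, because (m / i) * modInverse[m % i] can exceed the int range
        long modInverse[] = new long[n + 1];
        
        modInverse[1] = 1; // since 1 * 1 = 1 (mod m)

        // Assumes m is prime and n < m, so m % i is in [1, i - 1] and already computed
        for (int i = 2; i <= n; i++) {
            modInverse[i] = (-(m / i) * modInverse[m % i]) % m + m;
        }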
        
        for (int i = 2; i <= n; i++) 
            System.out.printf("%d ", modInverse[i]);
    }
}
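Since the proof is left as an exercise, here is a small sketch for convincing yourself numerically that the identity works: it computes the sieve of inverses (same recurrence, kept in [0, m) and in long arithmetic) and compares every value with Fermat's a ^ (m - 2) mod m. The class name InverseSieveCheck and the choice m = 1000000007 are my own assumptions for the demo.

```java
class InverseSieveCheck {
    // inv[i] for 1 <= i <= n via the identity above; assumes m is prime and n < m.
    static long[] inverses(int n, long m) {
        long[] inv = new long[n + 1];
        inv[1] = 1;
        for (int i = 2; i <= n; i++)
            // -(m / i) * inv[m % i], taken modulo m and kept in [0, m)
            inv[i] = (m - (m / i) * inv[(int) (m % i)] % m) % m;
        return inv;
    }

    // a^e mod m by repeated squaring (same idea as modpow above).
    static long modpow(long a, long e, long m) {
        long res = 1;
        a %= m;
        while (e > 0) {
            if ((e & 1) == 1) res = res * a % m;
            a = a * a % m;
            e >>= 1;
        }
        return res;
    }

    public static void main(String[] args) {
        long m = 1_000_000_007L; // a prime modulus
        int n = 100_000;
        long[] inv = inverses(n, m);
        for (int i = 1; i <= n; i++) {
            // Fermat: for prime m, inverse(i) = i^(m - 2) mod m
            if (inv[i] != modpow(i, m - 2, m) || (long) i * inv[i] % m != 1)
                throw new AssertionError("wrong inverse for i = " + i);
        }
        System.out.println("all " + n + " inverses match");
    }
}
```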

Sample problems

Finally, here are a few exercises for you to try:

  1. GCDEX Spoj

    I will keep adding to the list as I find more problems.


Helpful tutorial for revising all the concepts.

Where can I find more tutorials like this? Is there a page on CodeChef that provides these awesome tutorials?
It was very helpful.
Can someone point me to more of these tutorials?

You can add the problem PRIME1 and explain the idea of the segmented sieve in the tutorial, as it is a very good and unique concept, and I saw many people having doubts about it recently.

The complexity of the sieve for finding primes is O(n log log n) instead of O(n log n). Source: MAXimal :: algo :: Решето Эратосфена (Sieve of Eratosthenes, in Russian).


Just search for ‘tutorial’ on CodeChef Discuss; you should be able to find all of them.