SMPM Editorial

PROBLEM LINK:

So Many Permutations!

Author: kalash04
Editorialist: kalash04
Tester: valiant_vidit

PROBLEM:

Kalash was given a task: find the total number of permutations P of an n-digit number formed using the digits 1 to 9 (digits may repeat) such that the sum of any 3 consecutive digits is not odd. Since Kalash is bad at combinatorics, he asked Vidit to help him. Figure out all the answers Vidit gave. Since P can be large, output P modulo 1000000007 (10^9 + 7).
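To make the counting concrete, here is a hypothetical brute-force counter for very small n (the class name SmpmBrute is illustrative, not from the problem). It assumes digits may repeat, as the setter's counts such as 5*5 = 25 imply, and enumerates every n-digit string over 1..9, rejecting a prefix as soon as a window of 3 consecutive digits has an odd sum. It is exponential in n, so it is only a sketch for checking small cases.

```java
// Hypothetical brute-force counter for tiny n (not part of the setter's solution).
// Enumerates every n-digit string over the digits 1..9 (repetition allowed) and
// counts those in which every window of 3 consecutive digits has an even sum.
class SmpmBrute {

    static long count(int n) {
        return extend(new int[n], 0);
    }

    // Try each digit 1..9 at position pos, pruning as soon as a window becomes odd.
    private static long extend(int[] digits, int pos) {
        if (pos == digits.length) {
            return 1;
        }
        long total = 0;
        for (int d = 1; d <= 9; d++) {
            digits[pos] = d;
            // Only the window ending at pos is new; earlier windows were already checked.
            if (pos >= 2 && (digits[pos - 2] + digits[pos - 1] + d) % 2 != 0) {
                continue;
            }
            total += extend(digits, pos + 1);
        }
        return total;
    }

    public static void main(String[] args) {
        for (int n = 1; n <= 5; n++) {
            System.out.println("n = " + n + ": " + count(n));
        }
    }
}
```

For n = 1, 2 and 3 it returns 9, 81 and 364: a single digit or a pair has no 3-digit window, and a valid triple has either 0 odd digits (4^3 = 64) or exactly 2 (3 * 5 * 5 * 4 = 300).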

PREREQUISITES:

  • Math
  • Combinatorics
  • Matrix exponentiation

DIFFICULTY:

EASY

EXPLANATION:

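Only the parity of each digit matters: among the digits 1-9 there are 5 odd digits (1, 3, 5, 7, 9) and 4 even digits (2, 4, 6, 8). A sum of three digits is even exactly when it contains 0 or 2 odd digits, so once the parities of two consecutive digits are fixed, the parity of the next digit is forced. We therefore track the parities of the last two digits as a state (OO, OE, EO or EE), and each state has exactly one successor:

  • OO: the next digit must be even (4 choices), giving OE
  • OE: the next digit must be odd (5 choices), giving EO
  • EO: the next digit must be odd (5 choices), giving OO
  • EE: the next digit must be even (4 choices), giving EE

Since the value of n is quite high, we use matrix exponentiation. You can learn more about matrix exponentiation here. The base matrix stores these transitions: base[i][j] is the number of digits that move state j to state i, with states indexed 0 = OO, 1 = OE, 2 = EO, 3 = EE. We then use the matrix power function to compute the new matrix nm = base^(n-2). The init array stores the number of 2-digit prefixes for each of the 4 states: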

  • 2 odd digits (5*5 = 25)
  • 1 odd digit, 1 even digit (5*4 = 20)
  • 1 even digit, 1 odd digit (4*5 = 20)
  • 2 even digits (4*4 = 16)
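
Entry i of nm · init is the number of valid n-digit numbers whose last two digits are in state i, so the nested for loops add nm[i][j] * init[j] over all i and j, modulo 10^9 + 7, to get P.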
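The hard-coded base and init arrays can also be derived mechanically from the parity rule. The sketch below is a hypothetical helper (TransitionBuilder, buildBase and buildInit are illustrative names): it encodes each digit by an "is even" flag, keeps only windows containing an even number of odd digits, and uses the same state order as the setter's code. Running it prints the setter's base matrix and init array.

```java
import java.util.Arrays;

// Hypothetical helper (not part of the setter's solution): derives base and init
// from the parity rule instead of hard-coding them.
// A digit is encoded by an "is even" flag (0 = odd, 1 = even); a state is the pair
// (second-to-last, last) encoded as 2 * first + second, so 0 = OO, 1 = OE, 2 = EO, 3 = EE.
class TransitionBuilder {

    static final int ODD_DIGITS = 5;  // 1, 3, 5, 7, 9
    static final int EVEN_DIGITS = 4; // 2, 4, 6, 8

    static int choices(int isEven) {
        return isEven == 1 ? EVEN_DIGITS : ODD_DIGITS;
    }

    // m[to][from] = number of digits that can be appended in state "from" to reach state "to".
    static long[][] buildBase() {
        long[][] m = new long[4][4];
        for (int a = 0; a <= 1; a++) {
            for (int b = 0; b <= 1; b++) {
                for (int c = 0; c <= 1; c++) {
                    // The window (a, b, c) has an even sum iff it holds an even number of odd digits.
                    int oddCount = (1 - a) + (1 - b) + (1 - c);
                    if (oddCount % 2 != 0) {
                        continue;
                    }
                    m[2 * b + c][2 * a + b] += choices(c);
                }
            }
        }
        return m;
    }

    // init[s] = number of 2-digit prefixes whose parities form state s.
    static long[] buildInit() {
        long[] init = new long[4];
        for (int a = 0; a <= 1; a++) {
            for (int b = 0; b <= 1; b++) {
                init[2 * a + b] = (long) choices(a) * choices(b);
            }
        }
        return init;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.deepToString(buildBase())); // [[0, 0, 5, 0], [4, 0, 0, 0], [0, 5, 0, 0], [0, 0, 0, 4]]
        System.out.println(Arrays.toString(buildInit()));     // [25, 20, 20, 16]
    }
}
```

Deriving the matrix this way makes the index convention explicit: rows are the destination state and columns the source state, which is why nm is multiplied by init on the right.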

SOLUTION:

Setter's Solution in Java
```java
import java.util.*;

public class SMPM {

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int cases = sc.nextInt();
        long mod = 1000000007;

        // Transition matrix: base[i][j] = number of digits that move the
        // "parities of the last two digits" state from j to i,
        // with states 0 = OO, 1 = OE, 2 = EO, 3 = EE.
        long[][] base = { { 0, 0, 5, 0 }, { 4, 0, 0, 0 }, { 0, 5, 0, 0 }, { 0, 0, 0, 4 } };

        // Number of 2-digit prefixes in each state: 5*5, 5*4, 4*5, 4*4.
        int[] init = { 25, 20, 20, 16 };

        while (cases-- > 0) {
            long n = sc.nextLong();

            // A single digit has no 3-digit window, so all 9 digits are valid;
            // n - 2 would also be negative here, and matrixPower needs pow >= 0.
            if (n == 1) {
                System.out.println(9);
                continue;
            }

            long[][] nm = Matrix.matrixPower(base, n - 2);

            // Answer = sum of all entries of nm * init.
            long ans = 0;
            for (int i = 0; i < 4; i++) {
                for (int j = 0; j < 4; j++) {
                    ans += nm[i][j] * init[j] % mod;
                    if (ans >= mod)
                        ans -= mod;
                }
            }
            System.out.println(ans);
        }
        sc.close();
    }
}

class Matrix {

    static int N = 4; // size of the matrix
    static long mod = 1000000007;

    // compute base^pow by binary exponentiation: O(N^3 log pow), pow >= 0
    static long[][] matrixPower(long[][] base, long pow) {
        long[][] ans = new long[N][N];
        // start from the identity matrix
        for (int i = 0; i < N; i++)
            ans[i][i] = 1;
        while (pow > 0) {
            if ((pow & 1) != 0)
                ans = multiplyMatrix(ans, base);
            base = multiplyMatrix(base, base);
            pow >>= 1;
        }
        return ans;
    }

    // compute m * m2 modulo mod: O(N^3)
    static long[][] multiplyMatrix(long[][] m, long[][] m2) {
        long[][] ans = new long[N][N];
        for (int i = 0; i < N; i++)
            for (int j = 0; j < N; j++)
                for (int k = 0; k < N; k++)
                    // reduce after every addition so each entry stays below mod and
                    // the next product (below mod^2, about 1e18) cannot overflow a long
                    ans[i][j] = (ans[i][j] + m[i][k] * m2[k][j]) % mod;
        return ans;
    }
}
```
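Because every column of base has exactly one non-zero entry, multiplying by base only scales and permutes the four state counts. A hypothetical O(n) reference (SmpmLinear is an illustrative name) can therefore apply the same four transitions n - 2 times to init. It is too slow for the large n this problem targets, but it is a convenient sketch for cross-checking the matrix-power answer, or the brute force, on small inputs.

```java
// Hypothetical O(n) reference (not part of the setter's solution): applies the
// four parity transitions n - 2 times to the length-2 counts, modulo 1e9 + 7.
class SmpmLinear {

    static final long MOD = 1_000_000_007L;

    static long count(long n) {
        if (n == 1) {
            return 9; // a single digit has no 3-digit window
        }
        // cnt[s] = number of valid prefixes whose last two digits are in state s
        // (0 = OO, 1 = OE, 2 = EO, 3 = EE), starting from the length-2 counts.
        long[] cnt = { 25, 20, 20, 16 };
        for (long len = 2; len < n; len++) {
            long[] next = new long[4];
            next[1] = cnt[0] * 4 % MOD; // OO + even digit -> OE
            next[2] = cnt[1] * 5 % MOD; // OE + odd digit  -> EO
            next[0] = cnt[2] * 5 % MOD; // EO + odd digit  -> OO
            next[3] = cnt[3] * 4 % MOD; // EE + even digit -> EE
            cnt = next;
        }
        long total = 0;
        for (long c : cnt) {
            total = (total + c) % MOD;
        }
        return total;
    }

    public static void main(String[] args) {
        for (int n = 1; n <= 6; n++) {
            System.out.println("n = " + n + ": " + count(n));
        }
    }
}
```

For n = 1 to 4 this gives 9, 81, 364 and 1656, the same values the nested loops in SMPM produce from nm = base^(n-2).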
