SMPM Editorial

So Many Permutations!

Author: kalash04
Editorialist: kalash04
Tester: valiant_vidit

PROBLEM:

Kalash was given a task: find the total number of permutations P of an n-digit number formed using the digits 1 to 9 (digits may repeat) such that the sum of any 3 consecutive digits is not odd. Since Kalash is bad at combinatorics, he asked Vidit to help him. Figure out all the answers Vidit gave (one per test case). Since P can be large, output P modulo 1000000007 (10^9 + 7).
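To make the condition concrete, the sketch below counts valid numbers by brute force. It is only a cross-check for small n because it explores up to 9^n strings, and the class name SmpmBrute is hypothetical. Digits may repeat, matching the 5*5 = 25 count for two odd digits used in the explanation.

```java
public class SmpmBrute {
    // Counts n-digit strings over digits 1..9 in which every 3 consecutive digits have an even sum.
    // Exponential in n, so only usable for tiny n as a cross-check.
    static long count(int n) {
        return extend(new int[n], 0);
    }

    // Fills position pos with every digit 1..9 and recurses, pruning invalid prefixes early.
    private static long extend(int[] digits, int pos) {
        if (pos == digits.length) return 1;
        long total = 0;
        for (int d = 1; d <= 9; d++) {
            digits[pos] = d;
            // Only the newest window (pos-2, pos-1, pos) needs checking; older windows passed already.
            if (pos >= 2 && (digits[pos - 2] + digits[pos - 1] + d) % 2 != 0) continue;
            total += extend(digits, pos + 1);
        }
        return total;
    }

    public static void main(String[] args) {
        for (int n = 1; n <= 5; n++) {
            System.out.println("n = " + n + ": " + count(n));
        }
    }
}
```

For n = 1 to 5 it prints 9, 81, 364, 1656 and 7524.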

PREREQUISITES:

• Math
• Combinatorics
• Matrix exponentiation

DIFFICULTY:

EASY

EXPLANATION:

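Here, we use matrix exponentiation because n can be very large. You can learn more about matrix exponentiation here.

Only the parity of each digit matters. Three digits have an even sum exactly when the parity of the third digit is the XOR of the parities of the first two: after two odd digits the next digit must be even, after one odd and one even digit (in either order) it must be odd, and after two even digits it must be even. So we track the parities of the last two digits, which gives 4 states, and the parity of each new digit is forced:

• odd-odd → odd-even (append an even digit, 4 ways)
• odd-even → even-odd (append an odd digit, 5 ways)
• even-odd → odd-odd (append an odd digit, 5 ways)
• even-even → even-even (append an even digit, 4 ways)

The 5s and 4s come from the fact that the digits 1 to 9 contain 5 odd digits and 4 even digits. The base matrix stores these transitions with the states ordered {odd-odd, odd-even, even-odd, even-even}: base[i][j] is the number of digits that take state j to state i. We then raise base to the power n - 2 with the matrixPower function to get the new matrix (nm). The init array stores the number of 2-digit prefixes for each of the 4 possible cases: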

• 2 odd digits (5*5 = 25)
• 1 odd digit, 1 even digit (5*4 = 20)
• 1 even digit, 1 odd digit (4*5 = 20)
• 2 even digits (4*4 = 16)
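Finally, the nested for loops compute the vector nm × init and add up its four entries. Entry i counts the valid n-digit numbers whose last two digits are in state i, so the sum (taken modulo 10^9 + 7) is the answer.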
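The same transitions can also be applied one digit at a time. This O(n) sketch (the class SmpmLinearDp is hypothetical) keeps the four state counts in the same order as init and performs exactly the update that one multiplication by base performs, n - 2 times. It is too slow for the large n this problem targets, which is why matrix exponentiation is used, but it is handy for checking the matrix code on moderate n.

```java
public class SmpmLinearDp {
    static final long MOD = 1_000_000_007L;

    // Counts valid n-digit numbers by applying the parity transitions one digit at a time.
    // State order matches init: {odd-odd, odd-even, even-odd, even-even}.
    static long count(long n) {
        if (n == 1) return 9; // no window of 3 digits exists
        long oo = 25, oe = 20, eo = 20, ee = 16; // 2-digit prefixes per state
        for (long step = 2; step < n; step++) {
            long nOO = eo * 5 % MOD; // even-odd + odd digit -> odd-odd
            long nOE = oo * 4 % MOD; // odd-odd + even digit -> odd-even
            long nEO = oe * 5 % MOD; // odd-even + odd digit -> even-odd
            long nEE = ee * 4 % MOD; // even-even + even digit -> even-even
            oo = nOO;
            oe = nOE;
            eo = nEO;
            ee = nEE;
        }
        return (oo + oe + eo + ee) % MOD;
    }

    public static void main(String[] args) {
        for (long n = 1; n <= 5; n++) {
            System.out.println("n = " + n + ": " + count(n));
        }
    }
}
```

For n = 3 the loop runs once and gives 100 + 100 + 100 + 64 = 364, the same value the matrix solution computes with nm = base.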

SOLUTION:

Setter's Solution in Java
```java
import java.util.*;

public class SMPM {

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int cases = sc.nextInt();
        long mod = 1000000007;

        // Transitions between parity states of the last two digits,
        // ordered {odd-odd, odd-even, even-odd, even-even}.
        // base[i][j] = number of digits that take state j to state i.
        long[][] base = { { 0, 0, 5, 0 }, { 4, 0, 0, 0 }, { 0, 5, 0, 0 }, { 0, 0, 0, 4 } };

        // Number of 2-digit prefixes in each state (the base case n = 2).
        int[] init = { 25, 20, 20, 16 };

        while (cases-- > 0) {
            long n = sc.nextLong();

            // A single digit has no window of 3 digits, so all 9 digits qualify.
            // The matrix formula below needs n >= 2 (otherwise the power would be -1).
            if (n == 1) {
                System.out.println(9);
                continue;
            }

            long[][] nm = Matrix.matrixPower(base, n - 2);

            // ans = sum of all entries of the vector nm * init
            long ans = 0;
            for (int i = 0; i < 4; i++) {
                for (int j = 0; j < 4; j++) {
                    ans = (ans + nm[i][j] * init[j]) % mod;
                }
            }
            System.out.println(ans);
        }
        sc.close();
    }
}

class Matrix {

    static int N = 4; // size of the matrix

    static long mod = 1000000007;

    // Computes base^pow by binary exponentiation: O(N^3 log pow).
    static long[][] matrixPower(long[][] base, long pow) {
        long[][] ans = new long[N][N];
        for (int i = 0; i < N; i++)
            ans[i][i] = 1; // start from the identity matrix

        while (pow > 0) {
            if ((pow & 1) != 0)
                ans = multiplyMatrix(ans, base);
            base = multiplyMatrix(base, base);
            pow >>= 1;
        }
        return ans;
    }

    // Computes m * m2 modulo mod: O(N^3).
    static long[][] multiplyMatrix(long[][] m, long[][] m2) {
        long[][] ans = new long[N][N];
        for (int i = 0; i < N; i++)
            for (int j = 0; j < N; j++)
                for (int k = 0; k < N; k++)
                    // Reduce after every addition so entries stay below mod
                    // and the next product cannot overflow a long.
                    ans[i][j] = (ans[i][j] + m[i][k] * m2[k][j]) % mod;
        return ans;
    }
}
```
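As an optional cross-check, note that base describes a 3-cycle (odd-odd → odd-even → even-odd → odd-odd) plus a self-loop on even-even. So the parity pattern of a valid number repeats with period 3: either every digit is even, or the pattern is a rotation of (odd, odd, even). This gives a closed form, which is a different technique from the setter's matrix method. The class SmpmClosedForm and its helpers are hypothetical, and the formula assumes n >= 2 (for n = 1 the rotations overlap, so the sketch returns 9 directly).

```java
public class SmpmClosedForm {
    static final long MOD = 1_000_000_007L;

    // Fast modular exponentiation: b^e mod MOD.
    static long power(long b, long e) {
        long result = 1;
        b %= MOD;
        while (e > 0) {
            if ((e & 1) == 1) result = result * b % MOD;
            b = b * b % MOD;
            e >>= 1;
        }
        return result;
    }

    // Counts positions i in [0, n) with (i + r) % 3 == 2, i.e. the even-digit slots
    // when the repeating pattern (odd, odd, even) is started at offset r (0 <= r <= 2).
    static long evenSlots(long n, int r) {
        long first = 2 - r; // smallest such i
        return first >= n ? 0 : (n - 1 - first) / 3 + 1;
    }

    static long count(long n) {
        if (n == 1) return 9; // rotations coincide for a single digit
        long ans = power(4, n); // every digit even
        for (int r = 0; r < 3; r++) { // starting states odd-odd, odd-even, even-odd
            long evens = evenSlots(n, r);
            ans = (ans + power(5, n - evens) * power(4, evens)) % MOD;
        }
        return ans;
    }

    public static void main(String[] args) {
        for (long n = 1; n <= 6; n++) {
            System.out.println("n = " + n + ": " + count(n));
        }
    }
}
```

For n = 4 this gives 4^4 + 5^3*4 + 5^3*4 + 5^2*4^2 = 256 + 500 + 500 + 400 = 1656, matching the matrix solution.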

Great question for beginners, and an amazing editorial.