PROBLEM LINK:
Author: kalash04
Editorialist: kalash04
Tester: valiant_vidit
PROBLEM:
Kalash was given a task: find the total number of permutations P of an n-digit number formed using the digits 1 to 9 (digits may repeat) such that the sum of any 3 consecutive digits is not odd. Since Kalash is bad at combinatorics, he asked Vidit to help him. Figure out all the answers given by him. Since P can be large, output P modulo 1000000007 (10^9 + 7).
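For small n the answer can be checked directly. The following is a minimal brute-force sketch. The class BruteForce and its methods are hypothetical and not part of the setter's code. It assumes digits may repeat and fills the number one position at a time, rejecting a digit as soon as the window of 3 ending at it has an odd sum. It is exponential in n, so it is only useful for validating a fast solution on tiny inputs.

```java
public class BruteForce {
    // Counts n-digit strings over the digits 1..9 (repetition allowed) in which
    // every window of 3 consecutive digits has an even sum.
    static long count(int n) {
        return extend(new int[n], 0);
    }

    // Tries every digit at position pos and recurses, pruning as soon as the
    // window ending at pos has an odd sum. Explores up to 9^n strings.
    private static long extend(int[] digits, int pos) {
        if (pos == digits.length) {
            return 1;
        }
        long total = 0;
        for (int d = 1; d <= 9; d++) {
            digits[pos] = d;
            boolean windowOdd = pos >= 2 && (digits[pos - 2] + digits[pos - 1] + d) % 2 != 0;
            if (!windowOdd) {
                total += extend(digits, pos + 1);
            }
        }
        return total;
    }

    public static void main(String[] args) {
        for (int n = 1; n <= 5; n++) {
            System.out.println(n + " -> " + count(n));
        }
    }
}
```

For n = 1 to 5 it prints the counts 9, 81, 364, 1656 and 7524.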
PREREQUISITES:
- Math
- Combinatorics
DIFFICULTY:
EASY
EXPLANATION:
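Since n can be very large, we use matrix exponentiation. You can learn more about matrix exponentiation here.
Only the parity of each digit matters: among the digits 1 to 9 there are 5 odd digits (1, 3, 5, 7, 9) and 4 even digits (2, 4, 6, 8). If the last two digits have parities a and b, the next digit must have parity c = a XOR b so that a + b + c is even. So we track the parities of the last two digits as one of 4 states (odd-odd, odd-even, even-odd, even-even) and initialize the base matrix as the 4x4 transition matrix between these states: base[to][from] is the number of digits (5 odd or 4 even) that can be appended to a number in state from, moving it to state to.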
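To see where the hard-coded base matrix comes from, this sketch derives it from the parity rule instead of writing it by hand. It assumes the state order used by init (0 = odd-odd, 1 = odd-even, 2 = even-odd, 3 = even-even, where the first parity belongs to the older digit) and the base[to][from] convention implied by the setter's final loop. The class and method names are illustrative.

```java
import java.util.Arrays;

public class TransitionMatrix {
    // Parity 1 = odd, 0 = even. Digits 1..9 contain 5 odd and 4 even digits.
    static long countOfParity(int parity) {
        return parity == 1 ? 5 : 4;
    }

    // Assumed state encoding matching init = {25, 20, 20, 16}:
    // 0 = (odd, odd), 1 = (odd, even), 2 = (even, odd), 3 = (even, even).
    static int state(int older, int newer) {
        return (1 - older) * 2 + (1 - newer);
    }

    // build()[to][from] = number of digits that move the state of the last two
    // parities from `from` to `to` while keeping the newest 3-digit sum even.
    static long[][] build() {
        long[][] m = new long[4][4];
        for (int a = 0; a <= 1; a++) {
            for (int b = 0; b <= 1; b++) {
                int c = a ^ b; // a + b + c must be even, so c = a XOR b
                m[state(b, c)][state(a, b)] += countOfParity(c);
            }
        }
        return m;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.deepToString(build()));
    }
}
```

Running main prints [[0, 0, 5, 0], [4, 0, 0, 0], [0, 5, 0, 0], [0, 0, 0, 4]], which is exactly the setter's base matrix. Each column has a single nonzero entry because the next parity is forced.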
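We then raise base to the power n - 2 with Matrix.matrixPower to get the new matrix nm, which accounts for the remaining n - 2 digits. The init array stores the number of ways to choose the first two digits in each of the 4 possible states: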
- 2 odd digits (5*5 = 25)
- 1 odd digit followed by 1 even digit (5*4 = 20)
- 1 even digit followed by 1 odd digit (4*5 = 20)
- 2 even digits (4*4 = 16)
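The init values and a single application of base can be reproduced by hand. This sketch (the class InitVector is hypothetical) builds init from the digit counts and multiplies it once by the setter's base matrix, giving the per-state counts for n = 3.

```java
import java.util.Arrays;

public class InitVector {
    // 5 odd digits (1, 3, 5, 7, 9) and 4 even digits (2, 4, 6, 8).
    static final long ODD = 5, EVEN = 4;

    // Ways to pick the first two digits, by state:
    // 0 = (odd, odd), 1 = (odd, even), 2 = (even, odd), 3 = (even, even).
    static long[] build() {
        return new long[] { ODD * ODD, ODD * EVEN, EVEN * ODD, EVEN * EVEN };
    }

    // Multiplies the setter's base matrix by a per-state count vector,
    // turning counts for length k into counts for length k + 1.
    static long[] applyBase(long[] v) {
        long[][] base = { { 0, 0, 5, 0 }, { 4, 0, 0, 0 }, { 0, 5, 0, 0 }, { 0, 0, 0, 4 } };
        long[] next = new long[4];
        for (int i = 0; i < 4; i++) {
            for (int j = 0; j < 4; j++) {
                next[i] += base[i][j] * v[j];
            }
        }
        return next;
    }

    public static void main(String[] args) {
        long[] v2 = build();        // per-state counts for n = 2
        long[] v3 = applyBase(v2);  // per-state counts for n = 3
        System.out.println(Arrays.toString(v2) + " -> " + Arrays.toString(v3));
        System.out.println(Arrays.stream(v3).sum());
    }
}
```

Running main prints [25, 20, 20, 16] -> [100, 100, 100, 64] and then 364: the patterns odd-odd-even, odd-even-odd and even-odd-odd contribute 100 each and even-even-even contributes 64.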
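Multiplying nm by init gives, for each state, the number of valid n-digit numbers whose last two digits are in that state, and the nested for loops add all of these up modulo 10^9 + 7 to get P. Each test case takes O(4^3 log n) time. For n = 1 there is no window of 3 consecutive digits, so every digit is valid and P = 9.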
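As a cross-check (an alternative technique, not the setter's method): since p_i + p_{i+1} + p_{i+2} and p_{i+1} + p_{i+2} + p_{i+3} are both even, every digit has the same parity as the digit three positions earlier. The parity sequence therefore repeats with period 3, and its first three parities must be odd-odd-even, odd-even-odd, even-odd-odd or even-even-even. Each pattern contributes a product of powers of 5 and 4, which this sketch (the class PeriodicFormula is hypothetical) evaluates with fast modular exponentiation. The formula assumes n is at least 2, so n = 1 is special-cased.

```java
public class PeriodicFormula {
    static final long MOD = 1_000_000_007L;

    // Fast modular exponentiation: base^exp mod MOD.
    static long power(long base, long exp) {
        long result = 1;
        base %= MOD;
        while (exp > 0) {
            if ((exp & 1) == 1) result = result * base % MOD;
            base = base * base % MOD;
            exp >>= 1;
        }
        return result;
    }

    // Counts valid n-digit numbers using the period-3 parity pattern.
    static long count(long n) {
        if (n == 1) return 9; // no 3-digit window, every digit 1..9 works
        // Allowed first three parities (1 = odd): their sum must be even.
        int[][] patterns = { { 1, 1, 0 }, { 1, 0, 1 }, { 0, 1, 1 }, { 0, 0, 0 } };
        long total = 0;
        for (int[] p : patterns) {
            long ways = 1;
            for (int r = 0; r < 3; r++) {
                long positions = (n - r + 2) / 3; // indices i in [0, n) with i % 3 == r
                long choices = p[r] == 1 ? 5 : 4;
                ways = ways * power(choices, positions) % MOD;
            }
            total = (total + ways) % MOD;
        }
        return total;
    }

    public static void main(String[] args) {
        for (long n = 1; n <= 6; n++) {
            System.out.println(n + " -> " + count(n));
        }
    }
}
```

For n = 1 to 6 it prints 9, 81, 364, 1656, 7524 and 34096, matching the matrix recurrence. It needs only 12 modular exponentiations per test case.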
SOLUTION:
Setter's Solution in Java
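import java.util.*;

public class SMPM {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int cases = sc.nextInt();
        long mod = 1000000007;
        // base[to][from]: number of digits that move the parity state of the
        // last two digits from `from` to `to` (state order: OO, OE, EO, EE)
        long[][] base = { { 0, 0, 5, 0 }, { 4, 0, 0, 0 }, { 0, 5, 0, 0 }, { 0, 0, 0, 4 } };
        // ways to choose the first two digits, per parity state
        int[] init = { 25, 20, 20, 16 };
        while (cases-- > 0) {
            long n = sc.nextLong();
            if (n == 1) {
                // no window of 3 digits exists, so every digit 1-9 is valid;
                // this also avoids calling matrixPower with exponent -1
                System.out.println(9);
                continue;
            }
            long[][] nm = Matrix.matrixPower(base, n - 2);
            long ans = 0;
            // sum of all entries of nm * init, reduced modulo mod
            for (int i = 0; i < 4; i++) {
                for (int j = 0; j < 4; j++) {
                    ans += nm[i][j] * init[j] % mod;
                    if (ans >= mod)
                        ans -= mod;
                }
            }
            System.out.println(ans);
        }
        sc.close();
    }
}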
class Matrix {
    static int N = 4; // size of the matrix
    static long mod = 1000000007;

    // compute base^pow modulo mod in O(N^3 log pow); pow must be non-negative
    static long[][] matrixPower(long[][] base, long pow) {
        long[][] ans = new long[N][N];
        // start from the identity matrix
        for (int i = 0; i < N; i++)
            ans[i][i] = 1;
        // binary exponentiation: square base each round and multiply it into
        // ans when the current bit of pow is set
        while (pow > 0) {
            if ((pow & 1) != 0)
                ans = multiplyMatrix(ans, base);
            base = multiplyMatrix(base, base);
            pow >>= 1;
        }
        return ans;
    }

    // compute m * m2 modulo mod in O(N^3)
    static long[][] multiplyMatrix(long[][] m, long[][] m2) {
        long[][] ans = new long[N][N];
        for (int i = 0; i < N; i++)
            for (int j = 0; j < N; j++) {
                ans[i][j] = 0;
                for (int k = 0; k < N; k++) {
                    // reduce after each addition so every entry stays below mod
                    // and the next product (< mod^2) cannot overflow a long
                    ans[i][j] = (ans[i][j] + m[i][k] * m2[k][j]) % mod;
                }
            }
        return ans;
    }
}
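To gain confidence in the matrix solution, a small stress test can compare it against exhaustive enumeration. The sketch below (the class StressTest is hypothetical) re-implements the setter's recurrence with the same base and init, reducing modulo 10^9 + 7 after every addition, and checks it against brute force for n = 1 to 6. It is a self-contained test harness, not part of the submitted solution.

```java
public class StressTest {
    static final long MOD = 1_000_000_007L;
    static final long[][] BASE = { { 0, 0, 5, 0 }, { 4, 0, 0, 0 }, { 0, 5, 0, 0 }, { 0, 0, 0, 4 } };
    static final long[] INIT = { 25, 20, 20, 16 };

    // a * b mod MOD for 4x4 matrices, reducing after every addition.
    static long[][] multiply(long[][] a, long[][] b) {
        long[][] c = new long[4][4];
        for (int i = 0; i < 4; i++)
            for (int j = 0; j < 4; j++)
                for (int k = 0; k < 4; k++)
                    c[i][j] = (c[i][j] + a[i][k] * b[k][j]) % MOD;
        return c;
    }

    // Same recurrence as the setter's solution: sum of entries of BASE^(n-2) * INIT.
    static long matrixCount(long n) {
        if (n == 1) return 9; // no 3-digit window exists
        long[][] result = new long[4][4];
        for (int i = 0; i < 4; i++) result[i][i] = 1;
        long[][] b = BASE;
        for (long p = n - 2; p > 0; p >>= 1) {
            if ((p & 1) == 1) result = multiply(result, b);
            b = multiply(b, b); // returns a new array, BASE is never modified
        }
        long ans = 0;
        for (int i = 0; i < 4; i++)
            for (int j = 0; j < 4; j++)
                ans = (ans + result[i][j] * INIT[j]) % MOD;
        return ans;
    }

    // Exhaustive count over all digit strings of length n; only for tiny n.
    static long bruteCount(int n, int pos, int[] d) {
        if (pos == n) return 1;
        long total = 0;
        for (int x = 1; x <= 9; x++) {
            d[pos] = x;
            if (pos < 2 || (d[pos - 2] + d[pos - 1] + x) % 2 == 0)
                total += bruteCount(n, pos + 1, d);
        }
        return total;
    }

    public static void main(String[] args) {
        for (int n = 1; n <= 6; n++) {
            long fast = matrixCount(n);
            long slow = bruteCount(n, 0, new int[n]);
            System.out.println("n=" + n + " matrix=" + fast + " brute=" + slow
                    + (fast == slow ? " OK" : " MISMATCH"));
        }
    }
}
```

Every line of the output should end in OK; for example, n = 6 gives 34096 from both methods.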