PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Setter: Yusuf Kharodawala
Tester: Teja Vardhan Reddy
Editorialist: Taranpreet Singh
DIFFICULTY:
Easy-Medium
PREREQUISITES:
Dynamic Programming, Matrix Exponentiation.
PROBLEM:
Count the number of arrays of length N such that every element lies in the range [1, M] and no three consecutive elements are equal. (The solutions below print this count modulo 10^9+7.)
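Before deriving anything clever, it helps to have a brute-force reference to check against. The sketch below is only an illustration: the class `BruteForce` and its method `count` are hypothetical names, not part of the reference solutions. It enumerates all M^N arrays like an odometer and counts exactly (no modulus), so it is only usable for tiny N and M.

```java
import java.util.Arrays;

// Hypothetical brute-force checker: enumerates all M^N arrays, so use tiny N and M only.
class BruteForce {
    // Counts arrays of length n over [1, m] with no three equal consecutive elements (exact, no modulus).
    static long count(int n, int m) {
        int[] a = new int[n];
        Arrays.fill(a, 1); // start from [1, 1, ..., 1]
        long total = 0;
        while (true) {
            boolean ok = true;
            for (int i = 2; i < n && ok; i++) {
                if (a[i] == a[i - 1] && a[i - 1] == a[i - 2]) ok = false;
            }
            if (ok) total++;
            // Advance to the next array like an odometer in base m.
            int pos = n - 1;
            while (pos >= 0 && a[pos] == m) {
                a[pos] = 1;
                pos--;
            }
            if (pos < 0) break; // wrapped around: every array has been visited
            a[pos]++;
        }
        return total;
    }

    public static void main(String[] args) {
        System.out.println(count(3, 2)); // 6
        System.out.println(count(4, 2)); // 10
    }
}
```

For M = 2 and N = 3 it reports 6: all 8 arrays except [1, 1, 1] and [2, 2, 2].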
QUICK EXPLANATION
- Let f(i, 1) be the number of valid arrays of length i whose last two elements are different, and f(i, 2) the number of valid arrays of length i whose last two elements are equal.
- To build an array of the first type, take any valid array of length i-1 and append an element different from its last one. There are M-1 such choices, giving f(i, 1) = (M-1)*(f(i-1, 1)+f(i-1, 2)).
- An array of the second type is obtained from a valid array of length i-1 whose last two elements differ by appending its last element again, giving f(i, 2) = f(i-1, 1).
- These recurrences can be written in matrix form, and matrix exponentiation then computes the number of valid arrays of length N, which is f(N, 1)+f(N, 2).
EXPLANATION
Let us consider the base cases first. For N = 1, we simply have M different arrays and all are valid.
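Now suppose we are building the array from left to right and want to choose element A_p. There are two cases:
- Case 1: p > 2 and A_{p-1} = A_{p-2}. Then A_p must differ from A_{p-1}, leaving M-1 choices.
- Case 2: p \leq 2 or A_{p-1} \neq A_{p-2}. Then no choice of A_p can create three equal consecutive elements, so all M values are allowed.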
Let f(n, 1) denote the number of valid arrays of length n whose last two elements differ, and f(n, 2) the number of valid arrays of length n whose last two elements are equal. By convention, an array of length 1 counts as having its last two elements differ, so f(1, 1) = M and f(1, 2) = 0.
The answer for length n is then f(n, 1)+f(n, 2), so it suffices to compute f(n, 1) and f(n, 2).
Let us compute f(i, 1) and f(i, 2) from f(i-1, 1) and f(i-1, 2).
For f(i, 1), take any valid array of length i-1 and choose A_i different from its last element, which can be done in M-1 ways; appending a different element can never create three equal consecutive elements. Hence f(i, 1) = (M-1)*(f(i-1, 1)+f(i-1, 2)).
For f(i, 2), A_i must equal A_{i-1}, which is only allowed when A_{i-1} \neq A_{i-2}. So we need the number of valid arrays of length i-1 whose last two elements differ, which is f(i-1, 1), and each of them has exactly one extension: appending a copy of its last element. This gives f(i, 2) = f(i-1, 1).
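A small worked example helps check the recurrences. Taking M = 2 (chosen only for illustration) and treating a single element as "last two elements differ", so that f(1, 1) = 2 and f(1, 2) = 0, the first few lengths give:

$$
\begin{aligned}
f(1,1) &= 2, & f(1,2) &= 0, & \text{total} &= 2\\
f(2,1) &= 1\cdot(2+0) = 2, & f(2,2) &= f(1,1) = 2, & \text{total} &= 4\\
f(3,1) &= 1\cdot(2+2) = 4, & f(3,2) &= f(2,1) = 2, & \text{total} &= 6\\
f(4,1) &= 1\cdot(4+2) = 6, & f(4,2) &= f(3,1) = 4, & \text{total} &= 10
\end{aligned}
$$

For N = 3 the total of 6 agrees with a direct count: 2^3 = 8 arrays minus the two constant arrays [1, 1, 1] and [2, 2, 2].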
Using these recurrences, f(n, 1)+f(n, 2) can be computed in O(N) time, which is enough for the first subtask but not for the second.
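A minimal sketch of the O(N) approach, assuming the count is reported modulo 10^9+7 as in the setter's, tester's, and editorialist's code. The class `LinearDp` and its method `count` are hypothetical names; the variables `diff` and `same` hold f(i, 1) and f(i, 2).

```java
// Hypothetical O(N) DP over the two recurrences; fine for small N, too slow for huge N.
class LinearDp {
    static final long MOD = 1_000_000_007L;

    static long count(long n, long m) {
        long mm = m % MOD;
        long diff = mm; // f(1, 1) = M: a single element counts as "last two differ"
        long same = 0;  // f(1, 2) = 0
        for (long i = 2; i <= n; i++) {
            // f(i, 1) = (M-1) * (f(i-1, 1) + f(i-1, 2))
            long newDiff = (mm - 1 + MOD) % MOD * ((diff + same) % MOD) % MOD;
            // f(i, 2) = f(i-1, 1)
            long newSame = diff;
            diff = newDiff;
            same = newSame;
        }
        return (diff + same) % MOD;
    }

    public static void main(String[] args) {
        System.out.println(count(4, 2)); // 10
        System.out.println(count(3, 3)); // 24
    }
}
```

Only two values are carried from one length to the next, so memory is O(1); the loop over i is what makes the running time O(N).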
To speed this up, let us write the recurrences in matrix form.
\begin{bmatrix} a & b \\ c & d \end{bmatrix} * \begin{bmatrix} f(i-1, 1)\\ f(i-1, 2) \end{bmatrix} = \begin{bmatrix} f(i, 1)\\ f(i, 2) \end{bmatrix}
Expanding the matrix product and comparing it with the recurrences above gives the values of a, b, c and d.
Now let P = \begin{bmatrix} a & b \\ c & d \end{bmatrix} and F_i = \begin{bmatrix} f(i, 1)\\ f(i, 2) \end{bmatrix}. Then F_{i+1} = P*F_i = P*(P*F_{i-1}) = P^2*F_{i-1}, and so on.
Generalizing, we get F_{N} = P^{N-1}*F_1. With matrix exponentiation (repeated squaring), P^{N-1} takes only O(\log N) matrix multiplications, which is fast enough to solve the problem.
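Cannot figure out the matrices?
Compare P*F_{i-1} = F_i row by row with the recurrences. The first row must give f(i, 1) = (M-1)*f(i-1, 1) + (M-1)*f(i-1, 2), so a = b = M-1; the second row must give f(i, 2) = f(i-1, 1), so c = 1 and d = 0. F_1 follows from the base case N = 1:
P = \begin{bmatrix} M-1 & M-1 \\ 1 & 0 \end{bmatrix}, F_1 = \begin{bmatrix} M \\ 0 \end{bmatrix}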
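The sketch below computes F_N = P^{N-1}*F_1 with binary exponentiation on 2 \times 2 matrices, using the P and F_1 above and assuming the answer is wanted modulo 10^9+7 (as in the reference solutions). The names `MatrixPower`, `mul`, and `count` are hypothetical; the reference solutions organize the same computation differently (for example, the setter starts from the length-2 vector and multiplies a row vector by the transposed matrix).

```java
// Hypothetical sketch: F_N = P^(N-1) * F_1 with P = [[M-1, M-1], [1, 0]] and F_1 = [M, 0], mod 1e9+7.
class MatrixPower {
    static final long MOD = 1_000_000_007L;

    // 2x2 matrix product x * y modulo MOD
    static long[][] mul(long[][] x, long[][] y) {
        long[][] z = new long[2][2];
        for (int i = 0; i < 2; i++) {
            for (int j = 0; j < 2; j++) {
                long s = 0;
                for (int k = 0; k < 2; k++) {
                    s = (s + x[i][k] * y[k][j]) % MOD; // each product < MOD^2, fits in a long
                }
                z[i][j] = s;
            }
        }
        return z;
    }

    static long count(long n, long m) {
        long mm = m % MOD;
        long a = (mm - 1 + MOD) % MOD;     // M-1, kept non-negative
        long[][] p = {{a, a}, {1, 0}};     // P
        long[][] r = {{1, 0}, {0, 1}};     // identity; becomes P^(n-1)
        for (long e = n - 1; e > 0; e >>= 1) {
            if ((e & 1) == 1) r = mul(r, p);
            p = mul(p, p);                 // p = P^(2^k) after k squarings
        }
        // F_N = r * [M, 0]^T, so f(N,1) = r[0][0]*M and f(N,2) = r[1][0]*M
        return (r[0][0] + r[1][0]) % MOD * mm % MOD;
    }

    public static void main(String[] args) {
        System.out.println(count(1, 5)); // 5
        System.out.println(count(4, 2)); // 10
        System.out.println(count(3, 3)); // 24
    }
}
```

Because F_1 has a zero in its second entry, only the first column of P^{N-1} matters, which is why the sketch sums r[0][0] and r[1][0] and multiplies by M. Multiplying r by p in either order is fine here, since all the matrices involved are powers of P and therefore commute.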
Tutorials on matrix exponentiation: here, here, or this.
A must-try problem: here.
TIME COMPLEXITY
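The time complexity is O(2^3*log(N)) per test case: each 2 \times 2 matrix multiplication takes 2^3 = 8 scalar multiplications, and the exponentiation performs O(log(N)) of them.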
SOLUTIONS:
Setter's Solution
```java
import java.util.*;
import java.io.*;
class Solution{
    final static long mod = (long)1e9 + 7;
    // Returns init * x^y, treating init as a row vector; all arithmetic is modulo p
    public static long[] matExp(long[][] x, long[] init, long y, long p){
        int n = x.length;
        long[][] tem = new long[n][], cur = new long[n][];
        long[] res = init.clone(), t = init.clone();
        long sum;
        for(int i = 0; i < n; i++){
            tem[i] = x[i].clone();
            cur[i] = x[i].clone();
        }
        while (y > 0) {
            if((y & 1) == 1){
                // res = t * cur (row vector times matrix)
                for(int i = 0; i < n; i++){
                    sum = 0;
                    for(int j = 0; j < n; j++){
                        sum += t[j] * cur[j][i];
                        sum %= p;
                    }
                    res[i] = sum;
                }
                t = res.clone();
            }
            y >>= 1;
            // square the current power: cur = tem * tem
            for(int i = 0; i < n; i++){
                for(int j = 0; j < n; j++){
                    sum = 0;
                    for(int k = 0; k < n; k++){
                        sum += tem[i][k] * tem[k][j];
                        sum %= p;
                    }
                    cur[i][j] = sum;
                }
            }
            for(int i = 0; i < n; i++) tem[i] = cur[i].clone();
        }
        return res;
    }
    public static void main(String[] args){
        FastReader s = new FastReader(System.in);
        PrintWriter w = new PrintWriter(System.out);
        int t = s.nextInt();
        while (t-- > 0){
            long n = s.nextLong(), m = s.nextLong() % mod;
            if(n==1) {
                w.print(m + "\n");
                continue;
            }
            // init = [f(2, 1), f(2, 2)] = [M(M-1), M]
            long[] init = { (m * m - m) % mod, m };
            m--;
            if(m < 0) m += mod;
            // transpose of P, because matExp multiplies a row vector by the matrix
            long[][] transform = { { m, 1 }, { m, 0 } };
            long[] res = matExp(transform, init, n - 2, mod);
            w.print((res[0] + res[1]) % mod + "\n");
        }
        w.close();
    }
    static class FastReader {
        BufferedReader br;
        StringTokenizer st;
        public FastReader(InputStream i) { br = new BufferedReader(new InputStreamReader(i)); }
        String next() {
            while (st == null || !st.hasMoreElements()) {
                try{ st = new StringTokenizer(br.readLine()); }
                catch (IOException e) { e.printStackTrace(); }
            }
            return st.nextToken();
        }
        int nextInt() { return Integer.parseInt(next()); }
        long nextLong() { return Long.parseLong(next()); }
    }
}
```
Tester's Solution
```cpp
//teja349
#include <bits/stdc++.h>
#include <vector>
#include <set>
#include <map>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <utility>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <iomanip>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
//setbase - cout << setbase (16); cout << 100 << endl; Prints 64
//setfill - cout << setfill ('x') << setw (5); cout << 77 << endl; prints xxx77
//setprecision - cout << setprecision (14) << f << endl; Prints x.xxxx
//cout.precision(x) cout<<fixed<<val; // prints x digits after decimal in val
using namespace std;
using namespace __gnu_pbds;
#define f(i,a,b) for(i=a;i<b;i++)
#define rep(i,n) f(i,0,n)
#define fd(i,a,b) for(i=a;i>=b;i--)
#define pb push_back
#define mp make_pair
#define vi vector< int >
#define vl vector< ll >
#define ss second
#define ff first
#define ll long long
#define pii pair< int,int >
#define pll pair< ll,ll >
#define sz(a) a.size()
#define inf (1000*1000*1000+5)
#define all(a) a.begin(),a.end()
#define tri pair<int,pii>
#define vii vector<pii>
#define vll vector<pll>
#define viii vector<tri>
#define mod (1000*1000*1000+7)
#define pqueue priority_queue< int >
#define pdqueue priority_queue< int,vi ,greater< int > >
#define flush fflush(stdout)
#define primeDEN 727999983
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
// find_by_order() // order_of_key
typedef tree<
    int,
    null_type,
    less<int>,
    rb_tree_tag,
    tree_order_statistics_node_update>
    ordered_set;
#define int ll
int c[2][2];
// a = a * b for 2x2 matrices modulo mod; the product is built in the global buffer c
void mult(int a[2][2],int b[2][2]){
    int i,j,k;
    rep(i,2){
        rep(j,2){
            c[i][j]=0;
        }
    }
    rep(i,2){
        rep(j,2){
            rep(k,2){
                c[i][k]+=a[i][j]*b[j][k];
                c[i][k]%=mod;
            }
        }
    }
    rep(i,2){
        rep(j,2){
            a[i][j]=c[i][j];
        }
    }
}
int a[2][2],res[2][2];
// signed main: `int` is macro-expanded to long long above
signed main(){
    std::ios::sync_with_stdio(false); cin.tie(NULL);
    int t;
    cin>>t;
    while(t--){
        int n,m;
        cin>>n>>m;
        int val=m;
        val%=mod;
        int val2=val*val;
        val2%=mod;
        m--;
        m%=mod;
        n--;
        a[0][0]=m;
        a[0][1]=m;
        a[1][0]=1;
        a[1][1]=0;
        // res starts as the identity and ends as the original a raised to the power N-1
        res[1][1]=1;
        res[0][0]=1;
        res[1][0]=0;
        res[0][1]=0;
        while(n){
            if(n%2){
                mult(res,a);
            }
            mult(a,a);
            n/=2;
        }
        //cout<<res[1][0]<< " "<<res[1][1]<<endl;
        val2*=res[1][0];
        val2%=mod;
        val2+=res[1][1]*val;
        val2%=mod;
        cout<<val2<<endl;
    }
    return 0;
}
```
Editorialist's Solution
```java
import java.util.*;
import java.io.*;
import java.text.*;
class CARR{
    //SOLUTION BEGIN
    void pre() throws Exception{}
    void solve(int TC) throws Exception{
        long n = nl(), m = nl();
        // f = F_1 = [M, 0]
        int[] f = new int[]{(int)(m%MOD), 0};
        // p = P = [[M-1, M-1], [1, 0]], reduced modulo MOD
        int[][] p = new int[][]{
            {(int)((m+MOD-1)%MOD), (int)((m+MOD-1)%MOD)},
            {1, 0}
        };
        // o = F_n = P^(n-1) * F_1
        int[] o = pow(p, f, n-1);
        pn(((long)o[0]+o[1])%MOD);
    }
    // Matrix-vector product a * v modulo MOD
    int[] mul(int[][] a, int[] v){
        int m = a.length;
        int n = v.length;
        int[] w = new int[m];
        for(int i = 0;i < m;i++){
            long sum = 0;
            for(int k = 0;k < n;k++){
                sum += (long)a[i][k] * v[k];
                // BIG is a multiple of MOD, so subtracting it keeps the residue and avoids overflow
                if(sum >= BIG)sum -= BIG;
            }
            w[i] = (int)(sum % MOD);
        }
        return w;
    }
    // Returns A^e * v by binary exponentiation
    public int[] pow(int[][] A, int[] v, long e){
        for(int i = 0;i < v.length;i++){
            if(v[i] >= MOD)v[i] %= MOD;
        }
        int[][] MUL = A;
        for(;e > 0;e>>>=1) {
            if((e&1)==1)v = mul(MUL, v);
            MUL = mul(MUL, MUL);
        }
        return v;
    }
    // Matrix-matrix product a * b modulo MOD
    int[][] mul(int[][] a, int[][] b){
        int n = a.length;
        int[][] c = new int[n][n];
        for(int i = 0; i< n; i++){
            long[] sum = new long[n];
            for(int k = 0; k< n; k++){
                for(int j = 0; j< n; j++){
                    sum[j] += (long)a[i][k]*b[k][j];
                    if(sum[j] >= BIG)sum[j] -= BIG;
                }
            }
            for(int j = 0; j< n; j++){
                c[i][j] = (int)(sum[j]%MOD);
            }
        }
        return c;
    }
    int MOD = (int)1e9+7;
    long BIG = 8*(long)MOD*MOD;
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    DecimalFormat df = new DecimalFormat("0.00000000000");
    static boolean multipleTC = true;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        new CARR().run();
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}
    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }
        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }
        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }
        String nextLine() throws Exception{
            String str = "";
            try{
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }
            return str;
        }
    }
}
```
Feel free to share your approach. Suggestions are welcomed as always.