# CARR - Editorial

Setter: Yusuf Kharodawala
Tester: Teja Vardhan Reddy
Editorialist: Taranpreet Singh

# DIFFICULTY:

Easy-Medium

# PREREQUISITES:

Dynamic Programming, Matrix Exponentiation.

# PROBLEM:

Count the number of N length arrays such that all elements lie in the range [1, M] and no three consecutive elements are equal.
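For small inputs, a brute-force counter is a useful reference to test any faster solution against. The sketch below is illustrative (the class and method names are not from the original solutions): it builds the array element by element and skips any value that would complete a run of three equal elements, so it runs in roughly O(M^N) time and is only meant for tiny N and M.

```java
// Brute-force reference counter (illustrative sketch, tiny N and M only).
public class CarrBruteForce {

    // Counts arrays of length n over values 1..m with no three equal consecutive elements.
    static long count(int n, int m) {
        return extend(new int[n], 0, m);
    }

    // Fills positions pos..n-1, skipping any value that would form a triple.
    private static long extend(int[] a, int pos, int m) {
        if (pos == a.length) return 1;
        long total = 0;
        for (int v = 1; v <= m; v++) {
            if (pos >= 2 && a[pos - 1] == v && a[pos - 2] == v) continue; // would make A_{p-2} = A_{p-1} = A_p
            a[pos] = v;
            total += extend(a, pos + 1, m);
        }
        return total;
    }

    public static void main(String[] args) {
        System.out.println(count(3, 2)); // 6
        System.out.println(count(4, 2)); // 10
        System.out.println(count(3, 1)); // 0
    }
}
```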

# QUICK EXPLANATION

• We can count the number of valid sequences of length i whose last two elements are different, denoted by f(i, 1), and the number of valid sequences of length i whose last two elements are equal, denoted by f(i, 2).
• To compute f(i, 1), we take every valid array of length i-1 and choose a last element different from its current last element, which gives M-1 choices per array, so f(i, 1) = (M-1)*(f(i-1, 1)+f(i-1, 2)).
• Similarly, f(i, 2) counts the valid sequences of length i-1 whose last two elements are different, with their last element appended once more, so f(i, 2) = f(i-1, 1).
• We can write these equations in matrix form and use matrix exponentiation to compute the number of valid sequences of length N, given by f(N, 1)+f(N, 2).

# EXPLANATION

Let us consider the base cases first. For N = 1, we simply have M different arrays and all are valid.

Now, suppose we are constructing the array and want to decide element A_p. There are two cases.

• Case 1: p > 2 and A_{p-1} = A_{p-2}. Then A_p must differ from A_{p-1}, leaving M-1 choices.
• Case 2: p \leq 2 or A_{p-1} \neq A_{p-2}. Then A_p can be any of the M values.

Let f(n, 2) denote the number of valid arrays of length n whose last 2 elements are equal, and f(n, 1) the number of valid arrays of length n whose last 2 elements differ.

We can see that the final number of arrays of length n is f(n, 1)+f(n, 2), so we need to compute f(n, 1) and f(n, 2).

Let’s try computing f(i, 1) and f(i, 2) from f(i-1, 1) and f(i-1, 2).

For f(i, 1), we take every valid array of length i-1 and choose A_i different from its last element, which can be done in M-1 ways. Such an element can never create three equal consecutive elements. Hence, f(i, 1) = (M-1)*(f(i-1, 1)+f(i-1, 2)).

For f(i, 2), we need the number of valid arrays of length i-1 whose last element differs from the second-last element, which is f(i-1, 1). For each of these arrays, we have exactly one choice: append a copy of its last element. This gives f(i, 2) = f(i-1, 1).
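As a sanity check of these recurrences, here are the first few values of (f(i, 1), f(i, 2)), starting from F_1 = (M, 0) as used later in this editorial (a length-1 array is counted as having its "last two elements different"):

$$
\begin{aligned}
F_1 &= \big(M,\ 0\big)\\
F_2 &= \big((M-1)(M+0),\ M\big) = \big(M(M-1),\ M\big)\\
F_3 &= \big((M-1)(M(M-1)+M),\ M(M-1)\big) = \big(M^2(M-1),\ M(M-1)\big)
\end{aligned}
$$

The total for length 3 is M^2(M-1) + M(M-1) = M^3 - M, which matches counting all M^3 arrays and removing the M constant ones.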

Hence, using these recurrences, we can compute f(n, 1)+f(n, 2) in O(N) time, which is enough for the first subtask but not for the second.
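A minimal sketch of this O(N) loop, assuming answers are taken modulo 10^9+7 as in the solutions below and that N, M ≥ 1; the class name `CarrLinearDp` is illustrative. Only the two previous values are kept, so memory is O(1).

```java
// Sketch: O(N) DP over the editorial's two states, modulo 1e9+7.
// diff = f(i, 1): valid arrays of length i whose last two elements differ.
// same = f(i, 2): valid arrays of length i whose last two elements are equal.
public class CarrLinearDp {
    static final long MOD = 1_000_000_007L;

    // Assumes n >= 1 and m >= 1.
    static long count(long n, long m) {
        long mMinus1 = (m - 1) % MOD;   // M - 1 reduced modulo MOD
        long diff = m % MOD, same = 0;  // F_1 = (M, 0)
        for (long i = 2; i <= n; i++) {
            long nextDiff = mMinus1 * ((diff + same) % MOD) % MOD; // f(i,1) = (M-1)(f(i-1,1) + f(i-1,2))
            long nextSame = diff;                                  // f(i,2) = f(i-1,1)
            diff = nextDiff;
            same = nextSame;
        }
        return (diff + same) % MOD;
    }

    public static void main(String[] args) {
        System.out.println(count(3, 2)); // 6
        System.out.println(count(4, 3)); // 66
    }
}
```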

In order to speed up, let us write the recurrence in matrix form.

\begin{bmatrix} a & b \\ c & d \end{bmatrix} * \begin{bmatrix} f(i-1, 1)\\ f(i-1, 2) \end{bmatrix} = \begin{bmatrix} f(i, 1)\\ f(i, 2) \end{bmatrix}

Expanding the product and comparing it with the recurrences found above, we can find the values of a, b, c and d.

Now, let P = \begin{bmatrix} a & b \\ c & d \end{bmatrix} and F_i = \begin{bmatrix} f(i, 1)\\ f(i, 2) \end{bmatrix}. We can write F_{i+1} = P*F_i = P*(P*F_{i-1}) = P^2*F_{i-1}, and so on.

Generalizing the above, we get F_{N} = P^{N-1}*F_1, which we can compute fast enough with matrix exponentiation, solving the problem.
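Below is a minimal sketch of computing F_N = P^{N-1}*F_1 by binary exponentiation (repeated squaring), assuming the matrix values given in the spoiler below, answers modulo 10^9+7, and N, M ≥ 1. The class and method names are illustrative, not taken from the setter's, tester's or editorialist's code.

```java
// Sketch: F_N = P^(N-1) * F_1 with P = [[M-1, M-1], [1, 0]] and F_1 = (M, 0), modulo 1e9+7.
public class CarrMatrixExp {
    static final long MOD = 1_000_000_007L;

    // Returns the 2x2 product x * y modulo MOD (entries < MOD, so each product fits in a long).
    static long[][] mul(long[][] x, long[][] y) {
        long[][] r = new long[2][2];
        for (int i = 0; i < 2; i++)
            for (int j = 0; j < 2; j++) {
                long s = 0;
                for (int k = 0; k < 2; k++) s = (s + x[i][k] * y[k][j]) % MOD;
                r[i][j] = s;
            }
        return r;
    }

    // Returns x^e modulo MOD for e >= 0 using repeated squaring: O(log e) multiplications.
    static long[][] pow(long[][] x, long e) {
        long[][] result = {{1, 0}, {0, 1}}; // identity matrix
        while (e > 0) {
            if ((e & 1) == 1) result = mul(result, x);
            x = mul(x, x);
            e >>= 1;
        }
        return result;
    }

    // Number of valid arrays of length n, assuming n >= 1 and m >= 1.
    static long count(long n, long m) {
        long mm = m % MOD, mMinus1 = (mm + MOD - 1) % MOD; // M and M-1 modulo MOD
        long[][] p = {{mMinus1, mMinus1}, {1, 0}};
        long[][] q = pow(p, n - 1);
        // F_N = Q * (M, 0)^T, so f(N,1) = Q[0][0]*M and f(N,2) = Q[1][0]*M.
        return (q[0][0] * mm % MOD + q[1][0] * mm % MOD) % MOD;
    }

    public static void main(String[] args) {
        System.out.println(count(3, 2)); // 6
        System.out.println(count(1_000_000_000_000_000_000L, 5));
    }
}
```

Starting from the identity matrix and squaring is the same scheme the tester's solution uses; the editorialist's solution instead multiplies the vector directly at each set bit, which skips the identity matrix.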

Cannot figure out the matrices?

P = \begin{bmatrix} M-1 & M-1 \\ 1 & 0 \end{bmatrix}, \quad F_1 = \begin{bmatrix} M \\ 0 \end{bmatrix}
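A quick way to check these values is to apply P once to F_1 and confirm that it reproduces the length-2 counts (M(M-1) arrays with two different elements and M with two equal ones):

$$
P F_1 = \begin{bmatrix} M-1 & M-1 \\ 1 & 0 \end{bmatrix} \begin{bmatrix} M \\ 0 \end{bmatrix} = \begin{bmatrix} (M-1)M \\ M \end{bmatrix} = F_2, \qquad f(2, 1)+f(2, 2) = M^2
$$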

Tutorials on Matrix Exponentiation: here, here or this.
A must-try problem: here.

# TIME COMPLEXITY

The time complexity is O(2^3*log(N)) per test case.

# SOLUTIONS:

Setter's Solution
import java.util.*;
import java.io.*;
class Solution{

final static long mod = (long)1e9 + 7;

public static long[] matExp(long[][] x, long[] init, long y, long p){
int n = x.length;
long[][] tem = new long[n][], cur = new long[n][];
long[] res = init.clone(), t = init.clone();
long sum;
for(int i = 0; i < n; i++){
tem[i] = x[i].clone();
cur[i] = x[i].clone();
}
while (y > 0) {
if((y & 1) == 1){
for(int i = 0; i < n; i++){
sum = 0;
for(int j = 0; j < n; j++){
sum += t[j] * cur[j][i];
sum %= p;
}
res[i] = sum;
}
t = res.clone();
}
y >>= 1;
for(int i = 0; i < n; i++){
for(int j = 0; j < n; j++){
sum = 0;
for(int k = 0; k < n; k++){
sum += tem[i][k] * tem[k][j];
sum %= p;
}
cur[i][j] = sum;
}
}
for(int i = 0; i < n; i++) tem[i] = cur[i].clone();
}
return res;
}

public static void main(String[] args){

PrintWriter w = new PrintWriter(System.out);

int t = s.nextInt();

while (t-- > 0){
long n = s.nextLong(), m = s.nextLong() % mod;
if(n==1) {
w.print(m + "\n");
continue;
}
long[] init = { (m * m - m) % mod, m };
m--;
if(m < 0) m += mod;
long[][] transform = { { m, 1 }, { m, 0 } };
long[] res = matExp(transform, init, n - 2, mod);
w.print((res[0] + res[1]) % mod + "\n");
}

w.close();

}

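static Reader s = new Reader(); // input reader used by main()

static class Reader {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st;

String next() {
while (st == null || !st.hasMoreElements()) {
try{ st = new StringTokenizer(br.readLine()); }
catch (IOException  e) { e.printStackTrace(); }
}
return st.nextToken();
}

int nextInt() { return Integer.parseInt(next()); }

long nextLong() { return Long.parseLong(next()); }

}
}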

Tester's Solution
//teja349
#include <bits/stdc++.h>
#include <vector>
#include <set>
#include <map>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <utility>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <iomanip>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
//setbase - cout << setbase (16); cout << 100 << endl; Prints 64
//setfill -   cout << setfill ('x') << setw (5); cout << 77 << endl; prints xxx77
//setprecision - cout << setprecision (14) << f << endl; Prints x.xxxx
//cout.precision(x)  cout<<fixed<<val;  // prints x digits after decimal in val

using namespace std;
using namespace __gnu_pbds;

#define f(i,a,b) for(i=a;i<b;i++)
#define rep(i,n) f(i,0,n)
#define fd(i,a,b) for(i=a;i>=b;i--)
#define pb push_back
#define mp make_pair
#define vi vector< int >
#define vl vector< ll >
#define ss second
#define ff first
#define ll long long
#define pii pair< int,int >
#define pll pair< ll,ll >
#define sz(a) a.size()
#define inf (1000*1000*1000+5)
#define all(a) a.begin(),a.end()
#define tri pair<int,pii>
#define vii vector<pii>
#define vll vector<pll>
#define viii vector<tri>
#define mod (1000*1000*1000+7)
#define pqueue priority_queue< int >
#define pdqueue priority_queue< int,vi ,greater< int > >
#define flush fflush(stdout)
#define primeDEN 727999983

// find_by_order()  // order_of_key
typedef tree<
int,
null_type,
less<int>,
rb_tree_tag,
tree_order_statistics_node_update>
ordered_set;
#define int ll
int c[2][2];
void mult(int a[2][2],int b[2][2]){ // a = a*b (mod), computed via the global buffer c
int i,j,k;
rep(i,2){
rep(j,2){
c[i][j]=0;
}
}
rep(i,2){
rep(j,2){
rep(k,2){
c[i][k]+=a[i][j]*b[j][k];
c[i][k]%=mod;
}
}
}
rep(i,2){
rep(j,2){
a[i][j]=c[i][j];
}
}
}
int a[2][2],res[2][2];
signed main(){ // 'signed' because int is #defined to long long
std::ios::sync_with_stdio(false); cin.tie(NULL);
int t;
cin>>t;
while(t--){
int n,m;
cin>>n>>m;
int val=m;
val%=mod;
int val2=val*val;
val2%=mod;
m--;
m%=mod;
n--;
a[0][0]=m;
a[0][1]=m;
a[1][0]=1;
a[1][1]=0;

res[1][1]=1;
res[0][0]=1;
res[1][0]=0;
res[0][1]=0;
while(n){
if(n%2){
mult(res,a);
}
mult(a,a);
n/=2;
}
//cout<<res[1][0]<< " "<<res[1][1]<<endl;
val2*=res[1][0];
val2%=mod;
val2+=res[1][1]*val;
val2%=mod;
cout<<val2<<endl;
}

return 0;
}

Editorialist's Solution
import java.util.*;
import java.io.*;
import java.text.*;
class CARR{
//SOLUTION BEGIN
void pre() throws Exception{}
void solve(int TC) throws Exception{
long n = nl(), m = nl();
int[] f = new int[]{(int)(m%MOD), 0};
int[][] p = new int[][]{
{(int)((m+MOD-1)%MOD), (int)((m+MOD-1)%MOD)},
{1, 0}
};
int[] o = pow(p, f, n-1);
pn(((long)o[0]+o[1])%MOD);
}
int[] mul(int[][] a, int[] v){
int m = a.length;
int n = v.length;
int[] w = new int[m];
for(int i = 0;i < m;i++){
long sum = 0;
for(int k = 0;k < n;k++){
sum += (long)a[i][k] * v[k];
if(sum >= BIG)sum -= BIG;
}
w[i] = (int)(sum % MOD);
}
return w;
}
public int[] pow(int[][] A, int[] v, long e){
for(int i = 0;i < v.length;i++){
if(v[i] >= MOD)v[i] %= MOD;
}
int[][] MUL = A;
for(;e > 0;e>>>=1) {
if((e&1)==1)v = mul(MUL, v);
MUL = mul(MUL, MUL);
}
return v;
}
int[][] mul(int[][] a, int[][] b){
int n = a.length;
int[][] c = new int[n][n];
for(int i = 0; i< n; i++){
long[] sum = new long[n];
for(int k = 0; k< n; k++){
for(int j = 0; j< n; j++){
sum[j] += (long)a[i][k]*b[k][j];
if(sum[j] >= BIG)sum[j] -= BIG;
}
}
for(int j = 0; j< n; j++){
c[i][j] = (int)(sum[j]%MOD);
}
}
return c;
}
int MOD = (int)1e9+7;
long BIG = 8*(long)MOD*MOD;
//SOLUTION END
void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
DecimalFormat df = new DecimalFormat("0.00000000000");
static boolean multipleTC = true;
void run() throws Exception{
out = new PrintWriter(System.out);
//Solution Credits: Taranpreet Singh
int T = (multipleTC)?ni():1;
pre();for(int t = 1; t<= T; t++)solve(t);
out.flush();
out.close();
}
public static void main(String[] args) throws Exception{
new CARR().run();
}
int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
void p(Object o){out.print(o);}
void pn(Object o){out.println(o);}
void pni(Object o){out.println(o);out.flush();}
String n()throws Exception{return in.next();}
String nln()throws Exception{return in.nextLine();}
int ni()throws Exception{return Integer.parseInt(in.next());}
long nl()throws Exception{return Long.parseLong(in.next());}
double nd()throws Exception{return Double.parseDouble(in.next());}

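FastReader in = new FastReader();
PrintWriter out;

class FastReader{
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st;

String next() throws Exception{
while (st == null || !st.hasMoreElements()){
try{
st = new StringTokenizer(br.readLine());
}catch (IOException  e){
throw new Exception(e.toString());
}
}
return st.nextToken();
}

String nextLine() throws Exception{
String str = "";
try{
str = br.readLine();
}catch (IOException e){
throw new Exception(e.toString());
}
return str;
}
}
}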


Feel free to share your approach. Suggestions are welcomed as always.


I have also used the same approach, but my recurrence relation is different:
f(i) = m*f(i-1) - (m-1)*f(i-3) for i > 3
and my matrix is:

\begin{bmatrix} m & 0 & 1-m \\ 1 & 0 & 0 \\ 0 & 1 & 0 \end{bmatrix}
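For reference, this 3-term recurrence does follow from the editorial's. Writing g(i) = f(i, 1)+f(i, 2) for the answer at length i (notation introduced here):

$$
\begin{aligned}
g(i) &= (M-1)\,g(i-1) + f(i-1, 1) = (M-1)\,g(i-1) + (M-1)\,g(i-2) && (i \ge 3)\\
(M-1)\,g(i-2) &= g(i-1) - (M-1)\,g(i-3) && (i \ge 4)\\
\Rightarrow\quad g(i) &= M\,g(i-1) - (M-1)\,g(i-3) && (i \ge 4)
\end{aligned}
$$

The 3x3 matrix above maps (g(i-1), g(i-2), g(i-3)) to (g(i), g(i-1), g(i-2)), so it is a valid (if larger) alternative to the 2x2 matrix P.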

My Submission
I am getting TLE on the second subtask.
If anyone could help me with this…


@taran_adm Can you submit solutions that are a bit easier to understand?
They are written very badly and I’m having difficulty making out what’s happening.


I don’t know why my approach is wrong.

Here’s my approach:

I used a constructive way of solving the problem, generalizing a recurrence relation into a formula,
i.e. given ‘n’ and ‘m’:

we know that the number of sequences for n=1 is m
the number of sequences for n=2 is m * m
now with n=3, the number of possible sequences is m * m * m - m
‘- m’ because there will be ‘m’ sequences in which all three places have the same element

let the answer for a particular n be dp[n]

Now I know the answer for n=3, i.e. all sequences such that no three adjacent elements are the same

So now
the number of sequences with n=4 is given by
==> (number of sequences with n=3, i.e. dp[3]) * (m) - m
here ‘-m’ has been done because there will be ‘m’ occurrences of a sequence of length 3 such that the elements at places 2 and 3 have the same value

so now, generalizing for any n:
dp[i]=dp[i-1]*m -m

and this gives, for a particular n>=2:
dp[n]= m^n - m * ( m^(n-2) - 1 ) / (m-1 )

It gives the correct answer for the test cases given in the question.
I don’t know why this gave WA.

SOMEBODY PLEASE FIND THE FLAW IN MY APPROACH
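Checking the recurrence dp[i] = dp[i-1]*m - m against the editorial's two-state counts f(i, 1), f(i, 2) for small values shows where it breaks, for example with n = 4, m = 3:

$$
\begin{aligned}
&F_1 = (3,\ 0),\quad F_2 = (6,\ 3),\quad F_3 = (18,\ 6),\quad F_4 = (48,\ 18)\\
&\text{editorial: } f(4,1)+f(4,2) = 48 + 18 = 66\\
&\text{recurrence above: } dp[4] = dp[3]\cdot m - m = 24 \cdot 3 - 3 = 69
\end{aligned}
$$

The term to subtract is the number of valid length-3 sequences whose last two elements are equal, f(3, 2) = 6, not m = 3; subtracting it gives 72 - 6 = 66.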



I suppose reading this would help. These codes are mostly standard implementations of Matrix Exponentiation.


1
3 1

I would like to appreciate the effort of Editorialist @taran_1407 for writing such a lucid editorial! Thanks!


My code works for n <= 3, but I am still getting WA. My solution. Did you find the corner cases where your code was failing?

What’s wrong with my stated approach?

The only failing case was m = 1, where the answer should be zero instead of one.

@taran_1407 Please explain the recurrence relation,
I mean, how did we get that relation?

I am not able to validate or derive that recurrence relation by myself, even after reading the editorial.

AND
Why have we looked at only the last two elements? Why not analyse the last three elements and then find a recurrence relation?

Please explain why the given explanation is valid / correct.


Hi,
Can anyone find what’s wrong in my approach?
if n<=2 return power(m,n)
if m==1 return 0;
else use the logic below:
power(m,n): total number of possible arrays of length n with each number between 1 and m.
(n-2)*power(m,n-2): arrays with at least one set of 3 consecutive equal elements
ans=power(m,n)-(n-2)*power(m,n-2)
The power function runs in logarithmic time.
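A small case shows the overcount, e.g. n = 4, m = 2 (values 1 and 2). Arrays such as 1111 and 2222 contain two overlapping triples, so (n-2)*power(m,n-2) subtracts them twice:

$$
\begin{aligned}
&\text{formula: } 2^4 - (4-2)\cdot 2^{2} = 16 - 8 = 8\\
&\text{triple at positions 1 to 3: } 1111,\ 1112,\ 2221,\ 2222\\
&\text{triple at positions 2 to 4: } 1111,\ 2111,\ 1222,\ 2222\\
&\text{true answer: } 16 - \left|\{1111,\ 1112,\ 2221,\ 2222,\ 2111,\ 1222\}\right| = 16 - 6 = 10
\end{aligned}
$$

The editorial's recurrences give F_4 = (6, 4) for M = 2, i.e. 10 as well.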

Thank you for this problem!
It’s the first matrix exponentiation problem I have faced in a contest, and I will try to learn and practice a few more problems.

Hi, can anyone tell me what’s wrong in this solution (just for subtask 1):

#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define MOD 1000000007
ll int dp(ll int i,ll int m,ll int dip[])
{
if(i==0)
return 1;

else if(i==1)
return m%MOD;
else if(i==2)
return (m*m)%MOD;
else
{
if(dip[i]!=0)
return dip[i]%MOD;
else
dip[i] =(dp((i-1),m,dip)*dp((i-2),m,dip)- dp((i-3),m,dip)*dp((i-2),m,dip))%MOD;
}

return dip[i]%MOD;


}

int main(){
int t;
cin>>t;

while(t--)
{


ll int dip[100000];
for(int i=0;i<=100000;i++)
{
dip[i]=0;
}
ll int n,m;
cin>>n>>m;
cout<<dp(n,m,dip)<<endl;
}

return 0;


}


https://www.codechef.com/viewsolution/29206915

I just implemented the matrix exponentiation approach.
Why did this get TLE on both subtasks?

Thank you!

Actually, f(n,1) is not only the number of sequences of length n whose last two elements differ,
but also with no three consecutive equal elements anywhere in the sequence.
Such a sequence can be constructed from an (n-1)-length sequence of the same type by filling the nth position in m-1 ways, so that no three consecutive elements are equal; f(n,1) can also be constructed from the f(n-1,2) type in a similar way.
Hence f(n,1) = (m-1)*(f(n-1,1) + f(n-1,2)),
and similarly f(n,2) = f(n-1,1).


I used the recurrence
dp[i] = (m-1)*(dp[i-1]+dp[i-2])
but it is giving TLE on the big case…
https://www.codechef.com/viewsolution/29215921
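If the submission loops over i up to N, it is O(N), which the editorial notes is only fast enough for the first subtask. Assuming dp[1] = m and dp[2] = m^2, this recurrence (valid for i ≥ 3) can be written in matrix form, and the matrix turns out to be exactly the editorial's P:

$$
\begin{bmatrix} dp[i] \\ dp[i-1] \end{bmatrix} = \begin{bmatrix} m-1 & m-1 \\ 1 & 0 \end{bmatrix} \begin{bmatrix} dp[i-1] \\ dp[i-2] \end{bmatrix}
\quad\Rightarrow\quad
\begin{bmatrix} dp[n] \\ dp[n-1] \end{bmatrix} = \begin{bmatrix} m-1 & m-1 \\ 1 & 0 \end{bmatrix}^{n-2} \begin{bmatrix} m^2 \\ m \end{bmatrix} \quad (n \ge 2)
$$

For n = 3 this gives dp[3] = (m-1)(m^2+m) = m^3 - m, and raising the matrix to the power n-2 by repeated squaring takes O(log N) 2x2 multiplications.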

\begin{bmatrix} m^2 - m & m-1 \\ m-1 & 0 \end{bmatrix} * \begin{bmatrix} m-1 & 1 \\ m-1 & 0\end{bmatrix}^{n-2}