CARR - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2

Setter: Yusuf Kharodawala
Tester: Teja Vardhan Reddy
Editorialist: Taranpreet Singh

DIFFICULTY:

Easy-Medium

PREREQUISITES:

Dynamic Programming, Matrix Exponentiation.

PROBLEM:

Count the number of arrays of length N such that all elements lie in the range [1, M] and no three consecutive elements are equal. The count is reported modulo 10^9+7, as in all the solutions below.

QUICK EXPLANATION

  • Let f(i, 1) denote the number of valid sequences of length i whose last two elements are different, and f(i, 2) the number of valid sequences of length i whose last two elements are equal.
  • A valid array of length i of the first type is obtained from any valid array of length i-1 by appending an element different from its last element, which can be done in M-1 ways. This gives f(i, 1) = (M-1)*(f(i-1, 1)+f(i-1, 2)).
  • A valid array of length i of the second type can only be obtained from a valid array of length i-1 whose last two elements differ, by appending its last element once more. This gives f(i, 2) = f(i-1, 1).
  • We can write these equations in matrix form and use matrix exponentiation to compute the number of valid sequences of length n, given by f(n, 1)+f(n, 2).

EXPLANATION

Let us consider the base case first. For N = 1, there are exactly M arrays, and all of them are valid.

Now suppose we are constructing the array from left to right and want to choose element A_p. There are two cases.

  • Case 1: p > 2 and A_{p-1} = A_{p-2}
  • Case 2: p \leq 2 or A_{p-1} \neq A_{p-2}

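Let f(n, 2) denote the number of valid arrays of length n whose last two elements are equal, and f(n, 1) the number of valid arrays of length n whose last two elements differ. For n = 1 there is no second-to-last element, so we count every single-element array as the first type: f(1, 1) = M and f(1, 2) = 0.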

The total number of valid arrays of length n is therefore f(n, 1)+f(n, 2), so we need to compute f(n, 1) and f(n, 2).

Let’s compute f(i, 1) and f(i, 2) from f(i-1, 1) and f(i-1, 2).

For f(i, 1), take any valid array of length i-1 and choose A_i different from A_{i-1}, which can be done in M-1 ways. Since A_i \neq A_{i-1}, the new array cannot end with three equal elements, so it is still valid. Hence f(i, 1) = (M-1)*(f(i-1, 1)+f(i-1, 2)).

For f(i, 2), we need the valid arrays of length i-1 whose last element differs from the second-to-last one; there are f(i-1, 1) of them. For each of these arrays there is exactly one choice: append a copy of its last element. Arrays counted by f(i-1, 2) contribute nothing, because appending a copy of their last element would create three equal consecutive elements. This gives f(i, 2) = f(i-1, 1).
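As a quick sanity check of both recurrences (a worked example, not part of the original editorial), take M = 2 and start from f(1, 1) = 2, f(1, 2) = 0:

```latex
\begin{aligned}
f(2,1) &= (2-1)\,(2+0) = 2, & f(2,2) &= f(1,1) = 2, & f(2,1)+f(2,2) &= 4 = 2^2,\\
f(3,1) &= (2-1)\,(2+2) = 4, & f(3,2) &= f(2,1) = 2, & f(3,1)+f(3,2) &= 6 = 2^3 - 2.
\end{aligned}
```

The total 6 for N = 3 matches direct counting: of the 8 arrays, only [1, 1, 1] and [2, 2, 2] are invalid.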

Using these recurrences, we can compute f(n, 1)+f(n, 2) in O(N) time, which is enough for the first subtask but not for the second.
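A minimal sketch of this O(N) loop in Java, assuming the answer is wanted modulo 10^9+7 as in the solutions below. The class name CarrLinear and the method count are hypothetical; the base case is f(1, 1) = M, f(1, 2) = 0, since a single element has no second-to-last neighbour.

```java
// Hypothetical sketch: O(N) evaluation of the editorial's recurrences.
class CarrLinear {
    static final long MOD = 1_000_000_007L;

    // Number of length-n arrays over [1, m] with no three equal consecutive
    // elements, modulo MOD. Runs in O(n) time and O(1) memory.
    static long count(long n, long m) {
        long mm = m % MOD;
        long choices = (mm - 1 + MOD) % MOD; // M - 1, kept non-negative mod MOD
        long f1 = mm; // f(1, 1): a single element counts as "last two differ"
        long f2 = 0;  // f(1, 2)
        for (long i = 2; i <= n; i++) {
            long nextF1 = choices * ((f1 + f2) % MOD) % MOD; // f(i, 1)
            long nextF2 = f1;                                 // f(i, 2) = f(i-1, 1)
            f1 = nextF1;
            f2 = nextF2;
        }
        return (f1 + f2) % MOD;
    }

    public static void main(String[] args) {
        System.out.println(count(3, 2)); // prints 6
        System.out.println(count(4, 3)); // prints 66
    }
}
```

Only two values are carried from step to step, so memory stays constant; the loop count is what makes this too slow for the second subtask.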

In order to speed up, let us write the recurrence in matrix form.

\begin{bmatrix} a & b \\ c & d \end{bmatrix} * \begin{bmatrix} f(i-1, 1)\\ f(i-1, 2) \end{bmatrix} = \begin{bmatrix} f(i, 1)\\ f(i, 2) \end{bmatrix}

Expanding the matrix product and comparing each row with the recurrences above gives the values of a, b, c and d.

Now let P = \begin{bmatrix} a & b \\ c & d \end{bmatrix} and F_i = \begin{bmatrix} f(i, 1)\\ f(i, 2) \end{bmatrix}. We can write F_{i+1} = P*F_i = P*(P*F_{i-1}) = P^2*F_{i-1}, and so on.

Generalizing the above, we get F_{N} = P^{N-1}*F_1. Matrix exponentiation computes P^{N-1} with O(log N) multiplications of 2 \times 2 matrices, which is fast enough to solve the problem.

Cannot figure out matrices

P = \begin{bmatrix} M-1 & M-1 \\ 1 & 0 \end{bmatrix}, \quad F_1 = \begin{bmatrix} M \\ 0 \end{bmatrix}
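Putting P and F_1 together, here is a hedged Java sketch of the fast approach. It is not one of the official solutions; the class CarrMatrix and its methods are hypothetical names, and it assumes N and M fit in a long and the answer is taken modulo 10^9+7.

```java
// Hypothetical sketch: F_N = P^(N-1) * F_1 via repeated squaring of a 2x2 matrix.
class CarrMatrix {
    static final long MOD = 1_000_000_007L;

    // Multiplies two 2x2 matrices modulo MOD (entries stay below MOD).
    static long[][] mul(long[][] a, long[][] b) {
        long[][] c = new long[2][2];
        for (int i = 0; i < 2; i++) {
            for (int j = 0; j < 2; j++) {
                long s = 0;
                for (int k = 0; k < 2; k++) {
                    s = (s + a[i][k] * b[k][j]) % MOD; // < MOD^2 + MOD, fits in long
                }
                c[i][j] = s;
            }
        }
        return c;
    }

    // Raises a 2x2 matrix to the power e by repeated squaring: O(log e) products.
    static long[][] pow(long[][] base, long e) {
        long[][] result = {{1, 0}, {0, 1}}; // identity matrix
        while (e > 0) {
            if ((e & 1) == 1) result = mul(result, base);
            base = mul(base, base);
            e >>= 1;
        }
        return result;
    }

    // Answer = f(N,1) + f(N,2) where F_N = P^(N-1) * F_1 and F_1 = (M, 0).
    static long count(long n, long m) {
        long mm = m % MOD;
        long a = (mm - 1 + MOD) % MOD; // M - 1 mod MOD, safe when M % MOD == 0
        long[][] p = {{a, a}, {1, 0}};
        long[][] q = pow(p, n - 1);
        // Only the first column of P^(N-1) matters because F_1 = (M, 0).
        long f1 = q[0][0] * mm % MOD;
        long f2 = q[1][0] * mm % MOD;
        return (f1 + f2) % MOD;
    }

    public static void main(String[] args) {
        System.out.println(count(3, 2));  // prints 6
        System.out.println(count(10, 3)); // prints 27408
        // Large N is fine: about 60 squarings suffice for N near 1e18.
        System.out.println(count(1_000_000_000_000_000_000L, 1000));
    }
}
```

This version raises the whole matrix and multiplies by F_1 at the end; the editorialist's solution below instead applies each squared matrix directly to the vector, which saves a little work per test case.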

Tutorials on matrix exponentiation: here, here or this.
A must-try problem: here.

TIME COMPLEXITY

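The time complexity is O(2^3*log(N)) per test case: each 2 \times 2 matrix multiplication takes 2^3 scalar multiplications, and O(log N) of them are needed.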
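One way to gain confidence in the recurrence before submitting is to compare it against exhaustive enumeration for tiny inputs. The sketch below (hypothetical class CarrBrute, exact arithmetic without the modulus) does this for N ≤ 7 and M ≤ 4; it is a testing aid, not a solution.

```java
// Hypothetical testing aid: brute force vs. the editorial's recurrence.
class CarrBrute {
    // Counts length-n arrays over [1, m] with no three equal consecutive
    // elements by trying every array; only feasible for tiny n and m.
    static long brute(int n, int m) {
        return extend(new int[n], 0, m);
    }

    private static long extend(int[] arr, int pos, int m) {
        if (pos == arr.length) return 1;
        long total = 0;
        for (int v = 1; v <= m; v++) {
            // Reject v if arr[pos-2], arr[pos-1] and v would all be equal.
            if (pos >= 2 && arr[pos - 1] == v && arr[pos - 2] == v) continue;
            arr[pos] = v;
            total += extend(arr, pos + 1, m);
        }
        return total;
    }

    // f(n,1) + f(n,2) from the recurrence, computed exactly (no modulus).
    static long recurrence(int n, int m) {
        long f1 = m, f2 = 0; // f(1,1) = M, f(1,2) = 0
        for (int i = 2; i <= n; i++) {
            long nextF1 = (long) (m - 1) * (f1 + f2);
            f2 = f1;
            f1 = nextF1;
        }
        return f1 + f2;
    }

    public static void main(String[] args) {
        for (int n = 1; n <= 7; n++) {
            for (int m = 1; m <= 4; m++) {
                if (brute(n, m) != recurrence(n, m)) {
                    throw new AssertionError("mismatch at n=" + n + ", m=" + m);
                }
            }
        }
        System.out.println("recurrence matches brute force");
    }
}
```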

SOLUTIONS:

Setter's Solution
import java.util.*;
import java.io.*;
class Solution{

	final static long mod = (long)1e9 + 7;

	public static long[] matExp(long[][] x, long[] init, long y, long p){
	    int n = x.length;
	    long[][] tem = new long[n][], cur = new long[n][];
	    long[] res = init.clone(), t = init.clone();
	    long sum;
	    for(int i = 0; i < n; i++){
	        tem[i] = x[i].clone();
	        cur[i] = x[i].clone();
	    }
	    while (y > 0) {
	        if((y & 1) == 1){
	            for(int i = 0; i < n; i++){
	                sum = 0;
	                for(int j = 0; j < n; j++){
	                    sum += t[j] * cur[j][i];
	                    sum %= p;
	                }
	                res[i] = sum;
	            }
	            t = res.clone();
	        }
	        y >>= 1;
	        for(int i = 0; i < n; i++){
	            for(int j = 0; j < n; j++){
	                sum = 0;
	                for(int k = 0; k < n; k++){
	                    sum += tem[i][k] * tem[k][j];
	                    sum %= p;
	                }
	                cur[i][j] = sum;
	            }
	        }
	        for(int i = 0; i < n; i++) tem[i] = cur[i].clone();
	    }
	    return res;
	}

	public static void main(String[] args){	
	
		FastReader s = new FastReader(System.in);
		PrintWriter w = new PrintWriter(System.out);

		int t = s.nextInt();

		while (t-- > 0){
			long n = s.nextLong(), m = s.nextLong() % mod;
			if(n==1) {
				w.print(m + "\n");
				continue;
			}
			long[] init = { (m * m - m) % mod, m };
			m--;
			if(m < 0) m += mod;
			long[][] transform = { { m, 1 }, { m, 0 } };
			long[] res = matExp(transform, init, n - 2, mod);
			w.print((res[0] + res[1]) % mod + "\n");
		}

		w.close();
	
	}

	static class FastReader { 
	    BufferedReader br; 
	    StringTokenizer st; 
  
	    public FastReader(InputStream i) { br = new BufferedReader(new InputStreamReader(i)); } 
  
	    String next() { 
	        while (st == null || !st.hasMoreElements()) { 
	            try{ st = new StringTokenizer(br.readLine()); } 
	            catch (IOException  e) { e.printStackTrace(); } 
	        } 
	        return st.nextToken(); 
	    } 
  
	    int nextInt() { return Integer.parseInt(next()); }

		long nextLong() { return Long.parseLong(next()); }

	} 
}
Tester's Solution
//teja349
#include <bits/stdc++.h>
#include <vector>
#include <set>
#include <map>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <utility>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <iomanip>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp> 
//setbase - cout << setbase (16); cout << 100 << endl; Prints 64
//setfill -   cout << setfill ('x') << setw (5); cout << 77 << endl; prints xxx77
//setprecision - cout << setprecision (14) << f << endl; Prints x.xxxx
//cout.precision(x)  cout<<fixed<<val;  // prints x digits after decimal in val

using namespace std;
using namespace __gnu_pbds;

#define f(i,a,b) for(i=a;i<b;i++)
#define rep(i,n) f(i,0,n)
#define fd(i,a,b) for(i=a;i>=b;i--)
#define pb push_back
#define mp make_pair
#define vi vector< int >
#define vl vector< ll >
#define ss second
#define ff first
#define ll long long
#define pii pair< int,int >
#define pll pair< ll,ll >
#define sz(a) a.size()
#define inf (1000*1000*1000+5)
#define all(a) a.begin(),a.end()
#define tri pair<int,pii>
#define vii vector<pii>
#define vll vector<pll>
#define viii vector<tri>
#define mod (1000*1000*1000+7)
#define pqueue priority_queue< int >
#define pdqueue priority_queue< int,vi ,greater< int > >
#define flush fflush(stdout) 
#define primeDEN 727999983
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());

// find_by_order()  // order_of_key
typedef tree<
int,
null_type,
less<int>,
rb_tree_tag,
tree_order_statistics_node_update>
ordered_set;
#define int ll
int c[2][2];
void mult(int a[2][2],int b[2][2]){ // a = a*b (mod), via global c; returns nothing
	int i,j,k;
	rep(i,2){
		rep(j,2){
			c[i][j]=0;
		}
	}
	rep(i,2){
		rep(j,2){
			rep(k,2){
				c[i][k]+=a[i][j]*b[j][k];
				c[i][k]%=mod;
			}
		}
	}
	rep(i,2){
		rep(j,2){
			a[i][j]=c[i][j];	
		}
	}
}
int a[2][2],res[2][2];
signed main(){ // int is #defined as ll, so main is declared via signed
	std::ios::sync_with_stdio(false); cin.tie(NULL);
	int t;
	cin>>t;
	while(t--){
		int n,m;
		cin>>n>>m;
		int val=m;
		val%=mod;
		int val2=val*val;
		val2%=mod;
		m--;
		m%=mod;
		n--;
		a[0][0]=m;
		a[0][1]=m;
		a[1][0]=1;
		a[1][1]=0;

		res[1][1]=1;
		res[0][0]=1;
		res[1][0]=0;
		res[0][1]=0;
		while(n){
			if(n%2){
				mult(res,a);
			}
			mult(a,a);
			n/=2;
		}
		//cout<<res[1][0]<< " "<<res[1][1]<<endl;
		val2*=res[1][0];
		val2%=mod;
		val2+=res[1][1]*val;
		val2%=mod;
		cout<<val2<<endl;
	}
	
	return 0;   
}
Editorialist's Solution
import java.util.*;
import java.io.*;
import java.text.*;
class CARR{
	//SOLUTION BEGIN
	void pre() throws Exception{}
	void solve(int TC) throws Exception{
	    long n = nl(), m = nl();
	    int[] f = new int[]{(int)(m%MOD), 0};
	    int[][] p = new int[][]{
	            {(int)((m+MOD-1)%MOD), (int)((m+MOD-1)%MOD)},
	            {1, 0}
	    };
	    int[] o = pow(p, f, n-1);
	    pn(((long)o[0]+o[1])%MOD);
	}
	int[] mul(int[][] a, int[] v){
	    int m = a.length;
	    int n = v.length;
	    int[] w = new int[m];
	    for(int i = 0;i < m;i++){
	        long sum = 0;
	        for(int k = 0;k < n;k++){
	            sum += (long)a[i][k] * v[k];
	            if(sum >= BIG)sum -= BIG;
	        }
	        w[i] = (int)(sum % MOD);
	    }
	    return w;
	}
	public int[] pow(int[][] A, int[] v, long e){
	    for(int i = 0;i < v.length;i++){
	        if(v[i] >= MOD)v[i] %= MOD;
	    }
	    int[][] MUL = A;
	    for(;e > 0;e>>>=1) {
	        if((e&1)==1)v = mul(MUL, v);
	        MUL = mul(MUL, MUL);
	    }
	    return v;
	}
	int[][] mul(int[][] a, int[][] b){
	    int n = a.length;
	    int[][] c = new int[n][n];
	    for(int i = 0; i< n; i++){
	        long[] sum = new long[n];
	        for(int k = 0; k< n; k++){
	            for(int j = 0; j< n; j++){
	                sum[j] += (long)a[i][k]*b[k][j];
	                if(sum[j] >= BIG)sum[j] -= BIG;
	            }
	        }
	        for(int j = 0; j< n; j++){
	            c[i][j] = (int)(sum[j]%MOD);
	        }
	    }
	    return c;
	}
	int MOD = (int)1e9+7;
	long BIG = 8*(long)MOD*MOD;
	//SOLUTION END
	void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
	DecimalFormat df = new DecimalFormat("0.00000000000");
	static boolean multipleTC = true;
	FastReader in;PrintWriter out;
	void run() throws Exception{
	    in = new FastReader();
	    out = new PrintWriter(System.out);
	    //Solution Credits: Taranpreet Singh
	    int T = (multipleTC)?ni():1;
	    pre();for(int t = 1; t<= T; t++)solve(t);
	    out.flush();
	    out.close();
	}
	public static void main(String[] args) throws Exception{
	    new CARR().run();
	}
	int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
	void p(Object o){out.print(o);}
	void pn(Object o){out.println(o);}
	void pni(Object o){out.println(o);out.flush();}
	String n()throws Exception{return in.next();}
	String nln()throws Exception{return in.nextLine();}
	int ni()throws Exception{return Integer.parseInt(in.next());}
	long nl()throws Exception{return Long.parseLong(in.next());}
	double nd()throws Exception{return Double.parseDouble(in.next());}

	class FastReader{
	    BufferedReader br;
	    StringTokenizer st;
	    public FastReader(){
	        br = new BufferedReader(new InputStreamReader(System.in));
	    }

	    public FastReader(String s) throws Exception{
	        br = new BufferedReader(new FileReader(s));
	    }

	    String next() throws Exception{
	        while (st == null || !st.hasMoreElements()){
	            try{
	                st = new StringTokenizer(br.readLine());
	            }catch (IOException  e){
	                throw new Exception(e.toString());
	            }
	        }
	        return st.nextToken();
	    }

	    String nextLine() throws Exception{
	        String str = "";
	        try{   
	            str = br.readLine();
	        }catch (IOException e){
	            throw new Exception(e.toString());
	        }  
	        return str;
	    }
	}
}

Feel free to share your approach. Suggestions are welcome, as always.


I have also used the same approach, but my recurrence relation is different:
f(i) = m*f(i-1) - (m-1)*f(i-3) for i > 3,
and my matrix is:
\begin{bmatrix} m & 0 & 1-m \\ 1 & 0 & 0 \\ 0 & 1 & 0 \end{bmatrix}

My Submission
I am getting TLE on the second subtask.
Could anyone help me with this?
Thanks in advance.
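For reference, the third-order recurrence in this comment does follow from the editorial's recurrences. Writing g(i) = f(i, 1) + f(i, 2) for the total (the commenter's f(i)), a short derivation:

```latex
\begin{aligned}
g(i) &= f(i,1) + f(i,2) = f(i,1) + f(i-1,1) \\
     &= (m-1)\,g(i-1) + (m-1)\,g(i-2) && (i \ge 3),\\
g(i) - g(i-1) &= (m-1)\,g(i-1) - (m-1)\,g(i-3) && (i \ge 4),\\
g(i) &= m\,g(i-1) - (m-1)\,g(i-3) && (i \ge 4).
\end{aligned}
```

The third line subtracts g(i-1) = (m-1)g(i-2) + (m-1)g(i-3) from the second. The 3 × 3 matrix above acts on (g(i-1), g(i-2), g(i-3)); the editorial's 2 × 2 form simply needs fewer scalar multiplications per matrix product (8 instead of 27).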


@taran_adm Can you submit solutions that are a bit easier to understand?
They are written very badly, and I’m having difficulty figuring out what’s happening.


I don't know why my approach is wrong.

Please find my mistake.

Here’s my approach:

I solved the problem constructively, by generalizing a recurrence relation into a closed-form formula,
given n and m.

We know that the number of sequences for n = 1 is m.
The number of sequences for n = 2 is m * m.
For n = 3, the number of valid sequences is m * m * m - m.
We subtract m because there are m sequences in which all three places hold the same element.

Let the answer for a particular n be dp[n].

Now I know the answer for n = 3, i.e. the number of sequences in which no three adjacent elements are equal.

So the number of sequences with n = 4 is given by
==> (number of valid sequences with n = 3, i.e. dp[3]) * m - m
Here I subtract m because there will be m sequences of length 3 whose elements at places 2 and 3 have the same value.

So, generalizing for any n:
dp[i] = dp[i-1]*m - m

and this gives, for a particular n >= 2:
dp[n] = m^n - m * (m^(n-2) - 1) / (m - 1)

It gives the correct answer for the sample test cases in the problem,
but I don't know why it gets WA.

My solution link: CodeChef: Practical coding for everyone

Somebody please find the flaw in my approach.
@taran_1407 Please help me here
Thanks in advance
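A small worked check (not from the thread) shows where dp[i] = dp[i-1]*m - m breaks. Take m = 3; the editorial's recurrences give totals 3, 9, 24, 66 for n = 1..4, while the proposed recurrence gives:

```latex
dp[4] = 3 \cdot dp[3] - 3 = 3 \cdot 24 - 3 = 69 \neq 66.
```

The subtracted term should be the number of valid length-(n-1) arrays whose last two elements are equal, which is f(n-1, 2) = f(n-2, 1), not m. For n = 4 this is f(2, 1) = m^2 - m = 6, and 3 \cdot 24 - 6 = 66. For n = 4 the two agree only when m^2 - m = m, i.e. m = 2, so small samples can still pass.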


@ashishpathak28 Can you please explain your recurrence relation?

Thanks in advance!


I suppose reading this would help. These codes are mostly standard implementations of matrix exponentiation.


Your code gives a wrong answer for this case:
1
3 1

I would like to appreciate the effort of Editorialist @taran_1407 for writing such a lucid editorial! Thanks!


The code works for n <= 3, but I am still getting WA. My solution. Did you find the corner cases where your code was failing?

What's wrong with my stated approach?

The only failing case was m = 1 (with n >= 3), where the answer should be zero instead of one.

@taran_1407 Please explain the recurrence relation; I mean, how did we get that relation?

I am not able to validate or reach that recurrence relation by myself, even after reading the editorial.

Also, why do we only look at the last two elements? Why not analyse the last three elements and then find a recurrence relation?

Please explain why the given explanation is valid / correct.

Please help me, I am a noob in DP.


Hi,
Can anyone find what’s wrong with my approach?
if n <= 2: return power(m, n)
if m == 1: return 0
else use the logic below:
power(m, n) is the total number of arrays of length n with each element between 1 and m.
(n-2)*power(m, n-2) is the number of arrays with at least one set of 3 equal consecutive elements.
ans = power(m, n) - (n-2)*power(m, n-2)
My power function is logarithmic.
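A quick check with n = 4, m = 2 (worked here for illustration) shows the problem with subtracting (n-2)*power(m, n-2):

```latex
\begin{aligned}
\text{formula: } & 2^4 - (4-2)\cdot 2^{2} = 16 - 8 = 8,\\
\text{inclusion-exclusion: } & 2^4 - (4 + 4 - 2) = 10.
\end{aligned}
```

Arrays such as [1, 1, 1, 1] contain an equal triple at both positions, so the formula subtracts them twice; inclusion-exclusion adds them back once. The editorial's recurrence also gives 10 for this case.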

Thank you for this problem!
It’s the first matrix exponentiation problem I have faced in a contest, and I will try to learn it and practice a few more problems.

Hi, can anyone tell me what’s wrong with this solution (intended only to pass subtask 1):

#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define MOD 1000000007

ll int dp(ll int i, ll int m, ll int dip[])
{
    if (i == 0)
        return 1;
    else if (i == 1)
        return m % MOD;
    else if (i == 2)
        return (m * m) % MOD;
    else
    {
        if (dip[i] != 0)
            return dip[i] % MOD;
        else
            dip[i] = (dp((i-1), m, dip) * dp((i-2), m, dip) - dp((i-3), m, dip) * dp((i-2), m, dip)) % MOD;
    }
    return dip[i] % MOD;
}

int main(){
    int t;
    cin >> t;
    while (t--)
    {
        ll int dip[100000];
        for (int i = 0; i <= 100000; i++)
        {
            dip[i] = 0;
        }
        ll int n, m;
        cin >> n >> m;
        cout << dp(n, m, dip) << endl;
    }
    return 0;
}


https://www.codechef.com/viewsolution/29206915

I just implemented the matrix exponentiation approach.
Why did this get TLE on both subtasks?

Thank you!

Actually, f(n,1) is not just the number of sequences of length n whose last two elements differ;
it also requires that no three consecutive elements in the sequence are equal.
Such a sequence can be built from a valid sequence of length n-1 of the same type by filling the n-th position in m-1 ways, so that no three consecutive elements are equal; f(n,1) sequences can also be built from sequences of the f(n-1,2) type in the same way.
Hence f(n,1) = (m-1)*(f(n-1,1) + f(n-1,2)),
and similarly f(n,2) = f(n-1,1).


I used the recurrence
dp[i] = (m-1)*(dp[i-1]+dp[i-2])
but it gives TLE on the large subtask…
https://www.codechef.com/viewsolution/29215921

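\begin{bmatrix} m^2 - m & m \\ m & 0 \end{bmatrix} * \begin{bmatrix} m-1 & 1 \\ m-1 & 0\end{bmatrix}^{n-2}
Here dp[i] = f(i, 1), so dp[0] = 0, dp[1] = m and dp[2] = m^2 - m. The first row of this product is [dp[n], dp[n-1]], and its sum dp[n] + dp[n-1] = f(n, 1) + f(n, 2) is your answer. This works because …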
\begin{bmatrix} dp[i] & dp[i-1] \\ dp[i-1] & dp[i-2] \end{bmatrix} * \begin{bmatrix} m-1 & 1 \\ m-1 & 0\end{bmatrix} = \begin{bmatrix} dp[i] * (m-1) + dp[i-1] *(m-1) & dp[i]*1 + dp[i-1] * 0 \\ dp[i-1]*(m-1) +dp[i-2] * (m-1) & dp[i-1]\end{bmatrix} = \begin{bmatrix} dp[i+1] & dp[i] \\ dp[i] & dp[i-1]\end{bmatrix}
This gives a recursion, so the matrix power can be computed in O(log n) time with the same repeated-squaring method used for normal modular powers.
