DEARRERR - Editorial

PROBLEM LINK:

Practice

Contest

Author: Md Shahid

Tester: Sandeep Singh

Editorialist: Md Shahid

DIFFICULTY:

Medium

PREREQUISITES:

Matrix Exponentiation

PROBLEM:

Given the following two equations:

A(n) = \begin{cases} 2A(n-1) +3A(n-2)+ B(n),& \text{if } n\geq 2\\ 2,& \text{if } n = 1 \\ 1, & \text{if } n = 0 \\ \end{cases}
\text{Where, } B(n)= \begin{cases} 2B(n-1) +3B(n-2),& \text{if } n\geq 2\\ 1,& \text{if } 0 \leq n \leq 1 \\ \end{cases}

You need to find A(n) modulo 10^9 + 7, where 0 \leq n \leq 10^{18}.

EXPLANATION:

Matrix exponentiation is a technique for computing the N-th term of a linear recurrence in O(D^3 \log N) time, where D is the dimension of the matrix used. Let's see how it works.
This can be done in the following steps:

1. How to raise a square matrix to the N-th power in O(D^3 \log N) time

Suppose there is a square matrix M of dimension D.
M^2 = M x M
M^4 = M^2 x M^2
.
.
.
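M^n = M^{n/2} x M^{n/2} if n is even, and M^n = M^{\lfloor n/2 \rfloor} x M^{\lfloor n/2 \rfloor} x M if n is odd.
Each step halves the exponent, so only O(\log n) matrix multiplications are needed. Each multiplication of two D x D matrices costs O(D^3), so the total is O(D^3 \log n).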

Pseudocode

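Let's say matmul(X, Y) multiplies X by Y and stores the product back into X, and M1 is a copy of the original matrix M. The call expo(M, n) overwrites M with M^n for n \geq 1. (n = 0 would require M to become the identity matrix, which this pseudocode does not handle.)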

function expo(M, n): 
    if n == 0 or n == 1 
          return;
    expo(M, n/2);
    matmul(M, M);
    if n % 2 == 1:
       matmul(M, M1);

2. How to convert recursive equation into matrix form

Let’s say you have a simple recursive function,

f(n) = f(n-1) +f(n-2)

Matrix form of it is as follows:

\begin{bmatrix} f(n) \\ f(n-1) \end{bmatrix} \quad = M* \begin{bmatrix} f(n-1) \\ f(n-2) \end{bmatrix} \quad

Now we have to choose the entries of M so that this equation holds.

\begin{bmatrix} f(n) \\ f(n-1) \end{bmatrix} \quad = \begin{bmatrix} 1 & 1 \\ 1 & 0 \end{bmatrix} \quad * \begin{bmatrix} f(n-1) \\ f(n-2) \end{bmatrix} \quad

If you multiply the matrices on the right-hand side and equate both sides, you get back the original recurrence. The second row simply states f(n-1) = f(n-1).

Now observe the pattern that leads to an O(D^3 \log N) solution.
As we know,
\begin{bmatrix} f(n) \\ f(n-1) \end{bmatrix} \quad = M * \begin{bmatrix} f(n-1) \\ f(n-2) \end{bmatrix} \quad ---------(1)

\begin{bmatrix} f(n+1) \\ f(n) \end{bmatrix} \quad = M * \begin{bmatrix} f(n) \\ f(n-1) \end{bmatrix} \quad ---------(2)

Then from (1) and (2)

\begin{bmatrix} f(n+1) \\ f(n) \end{bmatrix} \quad = M * (M * \begin{bmatrix} f(n-1) \\ f(n-2) \end{bmatrix} \quad)

or, \begin{bmatrix} f(n+1) \\ f(n) \end{bmatrix} \quad = M^2 * \begin{bmatrix} f(n-1) \\ f(n-2) \end{bmatrix} \quad

The general form of the above equation is as follows:

\begin{bmatrix} f(n) \\ f(n-1) \end{bmatrix} \quad = M^{n-1} * \begin{bmatrix} f(1) \\ f(0) \end{bmatrix} \quad

To calculate f(N+1), first compute M^N. Then multiply it by the vector [f(1), f(0)]^T and read f(N+1) from the first entry of the result.

Applying the second step to this problem, we track A and B together in one 4-dimensional state vector. B appears shifted by one index because A(n+1) depends on B(n+1):

\begin{bmatrix} A(n+1) \\ A(n) \\ B(n+2) \\ B(n+1) \end{bmatrix} \quad = M * \begin{bmatrix} A(n) \\ A(n-1) \\ B(n+1) \\ B(n) \end{bmatrix} \quad

Where M = \begin{bmatrix} 2&3&1&0 \\ 1&0&0&0 \\ 0&0&2&3 \\ 0&0&1&0 \end{bmatrix} \quad

or,

\begin{bmatrix} A(n+1) \\ A(n) \\ B(n+2) \\ B(n+1) \end{bmatrix} \quad = M^2 * \begin{bmatrix} A(n-1) \\ A(n-2) \\ B(n) \\ B(n-1) \end{bmatrix} \quad

or,

\begin{bmatrix} A(n+1) \\ A(n) \\ B(n+2) \\ B(n+1) \end{bmatrix} \quad = M^3 * \begin{bmatrix} A(n-2) \\ A(n-3) \\ B(n-1) \\ B(n-2) \end{bmatrix} \quad

.
.
.
or,

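\begin{bmatrix} A(n+1) \\ A(n) \\ B(n+2) \\ B(n+1) \end{bmatrix} \quad = M^{n} * \begin{bmatrix} A(1) \\ A(0) \\ B(2) \\ B(1) \end{bmatrix} \quad

Replacing n by n-1 gives \begin{bmatrix} A(n) \\ A(n-1) \\ B(n+1) \\ B(n) \end{bmatrix} = M^{n-1} * \begin{bmatrix} A(1) \\ A(0) \\ B(2) \\ B(1) \end{bmatrix}, where A(1) = 2, A(0) = 1, B(2) = 5 and B(1) = 1. So for n \geq 1, A(n) is the dot product of the first row of M^{n-1} with (2, 1, 5, 1). This is exactly what both solutions below compute; A(0) = 1 is handled separately. The answer can be very large, so compute everything modulo 1000000007.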

Time Complexity: O(D^3 \log N) per test case, with D = 4.

AUTHOR’S AND EDITORIALIST’S SOLUTIONS:

Author's Solution
#include<bits/stdc++.h>
using namespace std;
#define ll unsigned long long int 
#define vi vector<int> 
#define vvi vector<vector<int>>
#define pii pair<int,int>
#define pll pair<ll, ll>
#define vl vector<ll> 
#define vvl vector<vector<ll>>
#define vpii vector<pii>
#define vpll vector<pll>
#define umap unordered_map
#define uset unordered_set
#define all(c) c.begin(), c.end()
#define maxarr(A) *max_element(A, A+n) 
#define maxvec(v) *max_element(all(v)) 
#define present(map,elem) map.find(elem)!=map.end()
#define lb(v,elem) (lower_bound(all(v),elem) - v.begin())
#define ub(v,elem) (upper_bound(all(v),elem) - v.begin())
#define pb push_back 
#define mp make_pair
#define For(i,a,b) for(int i=a; i<b; ++i) 
#define rep(i,a,b) for(ll i=a; i<b; ++i)
#define mod 1000000007 


// Returns (a * b) % mod.
// After reduction both operands are below mod = 1000000007 < 2^30, so their
// product is below 2^60 and fits in the 64-bit unsigned `ll` without
// overflow. A direct multiplication therefore gives the same result as a
// bit-by-bit shift-and-add loop, with no per-bit iteration.
ll modmult(ll a, ll b) {
    return (a % mod) * (b % mod) % mod;
}

// In-place 4x4 matrix product modulo `mod`: F <- F * M.
// The product is accumulated in a scratch matrix and copied back only at the
// end, so calling matmul(F, F) (squaring) is safe.
void matmul(ll F[4][4], ll M[4][4]) {
   ll prod[4][4];
   for (int row = 0; row < 4; ++row) {
     for (int col = 0; col < 4; ++col) {
       ll acc = 0;
       for (int t = 0; t < 4; ++t)
         acc = (acc + modmult(F[row][t], M[t][col])) % mod;  // stays in [0, mod)
       prod[row][col] = acc;
     }
   }
   for (int row = 0; row < 4; ++row)
     for (int col = 0; col < 4; ++col)
       F[row][col] = prod[row][col] % mod;
}

// Overwrites F with M^n, where M is the 4x4 transition matrix below, using
// recursive squaring: M^n = (M^(n/2))^2, times one extra factor of M when n
// is odd. This takes O(log n) calls to matmul.
//
// Precondition: on entry F must equal M. The odd-step multiplication always
// uses the local M, not the caller's F, and the n == 1 base case returns F
// unchanged.
// For n == 0 F becomes the identity matrix (M^0).
void expo(ll F[4][4], ll n) {
  if(n==0) {
    rep(i,0,4)
      rep(j,0,4)
        F[i][j] = (i == j) ? 1 : 0;
    return;
  }
  if(n==1)
    return;  // F already equals M^1

  ll M[4][4] = {{2,3,1,0},
                {1,0,0,0},
                {0,0,2,3},
                {0,0,1,0}};
  expo(F,n/2);      // n >= 2 here, so n/2 >= 1 and the identity branch is not re-entered
  matmul(F,F);
  if(n%2==1)
     matmul(F,M);
}

// Returns A(n) mod `mod`.
// Uses [A(n), A(n-1), B(n+1), B(n)]^T = M^(n-1) * [A(1), A(0), B(2), B(1)]^T,
// where [A(1), A(0), B(2), B(1)] = [2, 1, 5, 1]. A(n) is therefore the dot
// product of the first row of M^(n-1) with that vector.
ll recurse(ll n) {
  // Seed values are returned directly. Because `ll` is unsigned, n - 1 would
  // wrap around for n == 0, and the matrix path needs n >= 2 with expo's
  // n == 1 base case.
  if(n==0)
    return 1;
  if(n==1)
    return 2;

  ll F[4][4] = {{2,3,1,0},
                {1,0,0,0},
                {0,0,2,3},
                {0,0,1,0}};
  expo(F,n-1);

  const ll init[4] = {2, 1, 5, 1};  // A(1), A(0), B(2), B(1)
  ll ans = 0;
  rep(k,0,4)
    ans = (ans + init[k] * F[0][k]) % mod;  // F entries < mod, so no overflow
  return ans;
}


// Reads t test cases, each a single n, and prints A(n) mod `mod` per line.
// Output lines end with '\n' instead of endl so the stream is not flushed
// after every test case; the buffer is flushed when the program exits.
int main() { 
  ios_base::sync_with_stdio(false); 
  cin.tie(NULL); 
  cout.tie(NULL);
  #ifdef Judge
    freopen("input0.in","r",stdin);
    freopen("output0.in","w",stdout);
  #endif
  int t;
  cin>>t;
  while(t--){
    ll n;
    cin>>n;
    if(n==0)
       cout<<1<<'\n';       // A(0)
    else if(n==1)
       cout<<2<<'\n';       // A(1)
    else
       cout<<recurse(n)<<'\n';
  } 
  return 0;
} 

Tester's Solution
//PKMKB
#include <bits/stdc++.h>
#define ll long long int
#define max(a,b) (a>b?a:b)
#define min(a,b) (a<b?a:b)
#define scanarr(a,b,c) for( i=b;i<c;i++)cin>>a[i]
#define showarr(a,b,c) for( i=b;i<c;i++)cout<<a[i]<<' '
#define ln cout<<'\n'
#define FAST ios_base::sync_with_stdio(false);cin.tie(NULL);
#define mod 1000000007
#define MAX 100005
using namespace std;
////////////////////////////////////////////////////////////////CODE STARTS HERE////////////////////////////////////////////////////////////////
// In-place 4x4 matrix product modulo `mod`: a <- a * b.
// The product is built in a temporary, so multiply(F, F) (squaring) is safe.
// Every entry of the result is fully reduced into [0, mod). Each partial
// product (< mod^2 ~ 1e18) plus the running sum (< mod) fits in a signed
// 64-bit ll.
void multiply(ll a[4][4],ll b[4][4]){
    ll mul[4][4];
    for(int i=0;i<4;i++){
        for(int j=0;j<4;j++){
            ll sum=0;
            for(int k=0;k<4;k++)
                sum=(sum+((a[i][k]%mod)*(b[k][j]%mod))%mod)%mod;
            mul[i][j]=sum;
        }
    }
    for(int i=0;i<4;i++)
        for(int j=0;j<4;j++)
            a[i][j]=mul[i][j];
}
 
// Overwrites F with M^n, where M is the 4x4 transition matrix below, by
// recursive squaring. Returns A(n+1) mod `mod`, i.e. the first row of M^n
// dotted with [A(1), A(0), B(2), B(1)] = [2, 1, 5, 1].
// Precondition: F equals M on entry and n >= 0.
ll power(ll F[4][4],ll n){
    if(n==0){
        // M^0 is the identity. Without this case n == 0 would recurse
        // forever, because 0/2 == 0.
        for(int i=0;i<4;i++)
            for(int j=0;j<4;j++)
                F[i][j]=(i==j);
    }
    else if(n>1){
        ll M[4][4]={{2,3,1,0},{1,0,0,0},{0,0,2,3},{0,0,1,0}};
        power(F,n/2);       // F = M^(n/2); the returned value is not needed here
        multiply(F,F);      // F = M^(2*(n/2))
        if(n&1)
            multiply(F,M);  // one more factor of M for odd n
    }
    // For n == 1, F already holds M^1.
    return ((2*F[0][0])%mod+F[0][1]%mod+(5*F[0][2])%mod+F[0][3]%mod)%mod;
}
// Returns A(n) mod `mod`. The two seed values are returned directly. Any
// other n goes through power() with exponent n-1, starting from the
// transition matrix itself.
ll findnthterm(ll n){
    switch(n){
        case 0: return 1;   // A(0)
        case 1: return 2;   // A(1)
        default: break;
    }
    ll T[4][4]={{2,3,1,0},{1,0,0,0},{0,0,2,3},{0,0,1,0}};
    return power(T,n-1);
}
// Handles one test case: reads n and prints A(n) mod `mod` on its own line.
// '\n' is used instead of endl to avoid flushing the stream on every case.
void solve(){
    ll n;
    cin>>n;
    cout<<findnthterm(n)<<'\n';
}
int main(){
 
    int t;
    cin>>t;
    while(t--)
        solve();
} 
1 Like