PROBLEM LINK:
Author: Shivam Pradhan
Tester: Sankalp Gupta
DIFFICULTY:
MEDIUM
PREREQUISITES:
Math, Matrix Exponentiation
PROBLEM:
There are K villages numbered from 1 to K in the countryside. You live in village 1. Every pair of villages is directly connected, and traveling from one village to a different village costs 1 coin. You have exactly N coins with you. Your task is to find the number of ways to spend exactly N coins such that you start from your home village 1 and return to your home village at the end. As the answer can be large, print it modulo 10^9+7.
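To make the model concrete, here is a brute-force counter that enumerates every walk directly. It is a sketch that assumes "interconnected" means every village has a road to every other village and that each move must go to a different village. The helper names `countWalksFrom` and `countWalksBrute` are hypothetical. It runs in O((K-1)^N) time, so it is only useful for checking faster solutions on tiny inputs.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical brute-force helper (tiny inputs only): counts walks of exactly
// `stepsLeft` moves from `current` that end at village 1. Each move goes to a
// *different* village among 1..k and costs one coin.
std::int64_t countWalksFrom(int current, int stepsLeft, int k) {
    if (stepsLeft == 0) return current == 1 ? 1 : 0;
    std::int64_t total = 0;
    for (int next = 1; next <= k; ++next) {
        if (next == current) continue;  // staying put is not a trip
        total += countWalksFrom(next, stepsLeft - 1, k);
    }
    return total;
}

// Number of ways to spend exactly n coins starting and ending at village 1.
std::int64_t countWalksBrute(int n, int k) {
    assert(k >= 2 && n >= 0);
    return countWalksFrom(1, n, k);  // start at the home village
}
```

For example, with K = 3 and N = 2 it finds the two walks 1 → 2 → 1 and 1 → 3 → 1.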
QUICK EXPLANATION:
We first find an O(n) DP solution to the problem and then speed it up to O(log n) using matrix exponentiation.
EXPLANATION:
Let's first find an O(n) solution.
There are two types of villages: the home village and the other villages. By symmetry, all the other villages are interchangeable.
Let A(n) be the number of ways to start at the home village and be at the home village after n steps.
Let B(n) be the number of ways to start at the home village and be at one particular other village after n steps (by symmetry, this count is the same for every other village).
Initially, for n = 0, A(0) = 1 and B(0) = 0, because in zero steps we cannot leave the home village.
For n ≥ 1 the following transitions hold:
A(n) = (k-1)*B(n-1)
B(n) = (k-2)*B(n-1) + A(n-1)
The first holds because the last step into the home village can come from any of the (k-1) other villages. The second holds because the last step into a fixed other village can come from any of the (k-2) remaining other villages or from the home village.
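The recurrences translate directly into a loop that keeps only the previous pair (A, B). This is a minimal sketch; the name `countWalksDP` is hypothetical, and every step is reduced modulo 10^9+7 as the problem requires.

```cpp
#include <cassert>
#include <cstdint>

const std::int64_t MOD = 1000000007;

// O(n) DP: a = A(i) (at home after i steps), b = B(i) (at one fixed other
// village after i steps), all values kept modulo MOD.
std::int64_t countWalksDP(std::int64_t n, std::int64_t k) {
    assert(k >= 2 && n >= 0);
    std::int64_t a = 1, b = 0;  // A(0) = 1, B(0) = 0
    for (std::int64_t i = 1; i <= n; ++i) {
        std::int64_t na = (k - 1) % MOD * b % MOD;        // A(i) = (k-1) B(i-1)
        std::int64_t nb = ((k - 2) % MOD * b + a) % MOD;  // B(i) = (k-2) B(i-1) + A(i-1)
        a = na;
        b = nb;
    }
    return a;
}
```

It uses O(1) memory, but the loop still runs n times per test case.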
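This takes O(n) time per test case, which is too slow when there can be up to 10^5 test cases with n as large as 10^9 (the bounds asserted in the tester's solution). To speed it up we use matrix exponentiation.
Writing the pair of recurrences in matrix form:
\begin{bmatrix} A(n)\\ B(n) \end{bmatrix} = \begin{bmatrix} 0 & K-1\\ 1 & K-2 \end{bmatrix} \begin{bmatrix} A(n-1)\\ B(n-1) \end{bmatrix}
Define
P = \begin{bmatrix} 0 & K-1\\ 1 & K-2 \end{bmatrix}, \quad A_0 = \begin{bmatrix} 1\\ 0 \end{bmatrix}, \quad R = \begin{bmatrix} A(n)\\ B(n) \end{bmatrix}
Applying the step n times gives the answer for n steps:
P^n A_0 = R
The first entry of R is the required answer. Since A_0 = (1, 0)^T, R is exactly the first column of P^n, so the answer is the top-left entry of P^n.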
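As a small worked check of the matrix form, take K = 3 and n = 2 (values picked only for illustration):
P = \begin{bmatrix} 0 & 2\\ 1 & 1 \end{bmatrix}, \quad P^2 = \begin{bmatrix} 0\cdot 0 + 2\cdot 1 & 0\cdot 2 + 2\cdot 1\\ 1\cdot 0 + 1\cdot 1 & 1\cdot 2 + 1\cdot 1 \end{bmatrix} = \begin{bmatrix} 2 & 2\\ 1 & 3 \end{bmatrix}
P^2 A_0 = \begin{bmatrix} 2\\ 1 \end{bmatrix}
So A(2) = 2, matching the two walks 1 → 2 → 1 and 1 → 3 → 1, and B(2) = 1 for village 2, matching the single walk 1 → 3 → 2.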
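All that is left is to compute P^n in O(log n) time. We can do so with the recursive relation
P^n = P^x \cdot P^x (if n is even)
P^n = P^x \cdot P^x \cdot P (if n is odd)
where x = \lfloor n/2 \rfloor, with base case P^0 = I (the identity matrix). Each level of the recursion halves n, so only O(log n) 2x2 matrix multiplications are needed.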
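A minimal sketch of the recursive relation above, assuming the same transition matrix P. The names `Mat`, `mul`, `matPow`, and `countWalksFast` are hypothetical. Unlike the setter's solution, it returns fresh matrices instead of modifying arrays in place, and it handles n = 0 by returning the identity.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

using u64 = std::uint64_t;
using Mat = std::array<std::array<u64, 2>, 2>;
const u64 MOD = 1000000007;

// 2x2 product modulo MOD. Entries stay below MOD, so each product is below
// MOD^2 (about 1e18), which fits in 64 bits.
Mat mul(const Mat& x, const Mat& y) {
    Mat r{};
    for (int i = 0; i < 2; ++i)
        for (int j = 0; j < 2; ++j)
            r[i][j] = (x[i][0] * y[0][j] % MOD + x[i][1] * y[1][j] % MOD) % MOD;
    return r;
}

// P^n via P^n = P^x * P^x (n even) or P^x * P^x * P (n odd), x = n / 2.
Mat matPow(const Mat& p, u64 n) {
    if (n == 0) return Mat{{{1, 0}, {0, 1}}};  // P^0 = I
    Mat half = matPow(p, n / 2);
    Mat sq = mul(half, half);
    return (n % 2 == 1) ? mul(sq, p) : sq;
}

// A(n) is the top-left entry of P^n because A_0 = (1, 0)^T.
u64 countWalksFast(u64 n, u64 k) {
    assert(k >= 2);
    Mat p{{{0, (k - 1) % MOD}, {1, (k - 2) % MOD}}};
    return matPow(p, n)[0][0];
}
```

The recursion depth is about log2(n), roughly 30 levels for n = 10^9.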
ALTERNATE EXPLANATION:
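There is also a closed form for the answer: A(n) = ( (k-1)^n + (k-1)*(-1)^n )/k. It follows because the characteristic polynomial of P is λ^2 - (k-2)λ - (k-1) = (λ - (k-1))(λ + 1), so A(n) = α(k-1)^n + β(-1)^n for some constants α, β; the initial values A(0) = 1 and A(1) = 0 give α = 1/k and β = (k-1)/k. Modulo 10^9+7, the division by k is done by multiplying with the modular inverse of k.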
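A sketch of the closed form modulo 10^9+7. Because the modulus is prime and 2 ≤ k ≤ 10^9 (the tester's bounds), k is invertible, and its inverse can be computed as k^(MOD-2) by Fermat's little theorem. The names `powMod` and `countWalksClosed` are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

using u64 = std::uint64_t;
const u64 MOD = 1000000007;  // prime, so every nonzero residue is invertible

// Binary exponentiation: base^e mod MOD.
u64 powMod(u64 base, u64 e) {
    u64 result = 1;
    base %= MOD;
    while (e > 0) {
        if (e & 1) result = result * base % MOD;
        base = base * base % MOD;
        e >>= 1;
    }
    return result;
}

// ((k-1)^n + (k-1)(-1)^n) / k, dividing via the Fermat inverse k^(MOD-2).
u64 countWalksClosed(u64 n, u64 k) {
    assert(k >= 2 && k % MOD != 0);
    u64 km1 = (k - 1) % MOD;
    u64 sign = (n % 2 == 0) ? km1 : (MOD - km1) % MOD;  // (k-1)(-1)^n mod MOD
    u64 numer = (powMod(km1, n) + sign) % MOD;
    return numer * powMod(k % MOD, MOD - 2) % MOD;
}
```

This needs two modular exponentiations per test case and no matrices, so it is also O(log n).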
SOLUTIONS:
Setter's Solution
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
#define mod 1000000007
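// mat = mat * temp (mod 1e9+7). All four entries are computed before any is
// overwritten, so multiply(mat, mat) squares mat in place safely.
void multiply(ll mat[2][2],ll temp[2][2]){
ll x=(mat[0][0]*temp[0][0] + mat[0][1]*temp[1][0])%mod;
ll y=(mat[0][0]*temp[0][1] + mat[0][1]*temp[1][1])%mod;
ll z=(mat[1][0]*temp[0][0] + mat[1][1]*temp[1][0])%mod;
ll w=(mat[1][0]*temp[0][1] + mat[1][1]*temp[1][1])%mod;
mat[0][0]=x;
mat[0][1]=y;
mat[1][0]=z;
mat[1][1]=w;
}
// On entry mat holds P; on return it holds P^n (valid for n >= 1, which the
// constraints guarantee). Uses P^n = (P^(n/2))^2, times one more P if n is odd.
void power(ll mat[2][2],int n,int k){
ll a[2][2]={0,k-1,1,k-2}; // the transition matrix P
if(n==0 || n==1)return;
power(mat,n/2,k); // mat = P^(n/2)
multiply(mat,mat); // mat = P^(2*(n/2))
if(n&1)multiply(mat,a); // odd n: one extra factor of P
}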
int main(){
int t; cin>>t;
while(t--){
int n,k; cin>>n>>k;
ll mat[2][2]={0,k-1,1,k-2};
power(mat,n,k);
cout<<mat[0][0]<<"\n"; // A(n) is the top-left entry of P^n
}
}
Tester's Solution
#include<bits/stdc++.h>
#define ll long long int
#define fastio ios_base::sync_with_stdio(false) , cin.tie(NULL) ,cout.tie(NULL)
using namespace std;
ll m=1e9+7;
ll t[2][2];
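// f = f * t (mod m). All four entries are computed before f is overwritten,
// so mul(t, t) squares t in place safely.
void mul(ll t[2][2],ll f[2][2]){
ll x = ((f[0][0] * t[0][0])%m + (f[0][1] * t[1][0])%m)%m;
ll y = ((f[0][0] * t[0][1])%m + (f[0][1] * t[1][1])%m)%m;
ll z = ((f[1][0] * t[0][0])%m + (f[1][1] * t[1][0])%m)%m;
ll w = ((f[1][0] * t[0][1])%m + (f[1][1] * t[1][1])%m)%m;
f[0][0] = x;
f[0][1] = y;
f[1][0] = z;
f[1][1] = w;
}
// Iterative binary exponentiation: returns the top-left entry of t^n, i.e. A(n).
// The global t is squared in place, so solve() resets it before every call.
int cal(int n){
ll f[2][2] = {{1,0},{0,1}}; // identity matrix
while(n>0){
if(n&1)
mul(t,f); // f *= current power of t
mul(t,t); // t = t^2
n>>=1;
}
return f[0][0];
}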
void solve(){
ll n,k;
cin>>n>>k;
assert(n>0&&n<=1e9);
assert(k>1&&k<=1e9);
t[0][0] = 0;
t[0][1] = k-1;
t[1][0] = 1;
t[1][1] = k-2;
cout<<cal(n)<<"\n";
}
int main()
{
fastio;
int t;
cin>>t;
assert(t>0&&t<=1e5);
while(t--){
solve();
}
return 0;
}