PROBLEM LINK:
Setter: Kasra Mazaheri
Tester: Arshia
Editorialist: Taranpreet Singh
DIFFICULTY:
Medium-Hard
PREREQUISITES:
Linear Recurrences, Matrix Exponentiation
PROBLEM:
We are given a set S of size K, an integer X, and a sequence A of length M. The sequence A defines an infinite periodic sequence W by W_{i} = A_{((i-1) \bmod M)+1}, i.e. W = A_1, A_2, \ldots, A_M, A_1, A_2, \ldots
Define a function F on the non-negative integers as
- F(0) = X
- if i \geq 1 and i \in S, F(i) = 0
- otherwise (i \geq 1, i \notin S), F(i) = \big( \displaystyle\sum_{j = 1}^{i} F(i-j) \cdot W_j \big) \bmod (10^9+7)
Find F(N).
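To test a fast solution, it helps to have a direct evaluation of the definition. The sketch below is an O(N^2) reference checker, assuming A is passed 0-indexed (A[0] = A_1) and all inputs are already reduced modulo 10^9+7; the name `bruteF` is hypothetical and only usable for small N.

```cpp
#include <cassert>
#include <set>
#include <vector>

// Hypothetical reference checker: evaluates F(N) straight from the definition
// in O(N^2). A[0..M-1] holds A_1..A_M; S holds the indices forced to zero.
long long bruteF(long long X, const std::set<long long>& S,
                 const std::vector<long long>& A, int N) {
    const long long MOD = 1000000007LL;
    const int M = static_cast<int>(A.size());
    std::vector<long long> F(N + 1, 0);
    F[0] = X % MOD;
    for (int i = 1; i <= N; ++i) {
        if (S.count(i)) { F[i] = 0; continue; }     // forced zero for i in S
        long long sum = 0;
        for (int j = 1; j <= i; ++j) {
            long long w = A[(j - 1) % M] % MOD;     // W_j = A_{((j-1) mod M)+1}
            sum = (sum + F[i - j] * w) % MOD;
        }
        F[i] = sum;
    }
    return F[N];
}
```

For example, with X = 1 and A = (2, 3), W = 2, 3, 2, 3, ... and the first values are F(1) = 2, F(2) = 7, F(3) = 22.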
QUICK EXPLANATION
- Since W is periodic with period M, all terms F(i-j) with the same j \bmod M are multiplied by the same value A_p. After grouping them, F(i) depends only on M grouped sums.
- Define T(i) = \displaystyle\sum_{x = 0}^{\lfloor i/M \rfloor} F(i - x \cdot M) = F(i) + F(i-M) + F(i-2M) + \ldots, which is exactly this grouping. Then F(i) = \sum_{p=1}^{M} A_p T(i-p), and T satisfies a linear recurrence of order M, so it can be advanced with matrix exponentiation. F(N) follows from the M values T(N-1), \ldots, T(N-M).
- To handle F(i) = 0 for i \in S, sort S, use matrix exponentiation to advance to T(s_i-1), then apply one special step that accounts for F(s_i) = 0.
- As a speedup, precompute the binary powers of the transition matrix, which reduces each exponentiation to matrix-vector products.
EXPLANATION
First, assume S is empty. Then we are given the initial value X and the sequence A, and we have to find F(N).
The sequence W repeats itself. We have F(i) = \big( \displaystyle\sum_{j = 1}^{i} F(i-j) \cdot W_j \big), but W_j takes at most M distinct values as j varies. Specifically, by the definition of W, W_j = W_{j-M} = \ldots = W_{j-x \cdot M} = A_{j-x \cdot M}, where x is the largest integer with j - x \cdot M > 0.
So we can group together all j for which F(i-j) is multiplied by the same A_p, namely j = p, p+M, p+2M, \ldots, and rewrite the summation as
F(i) = \sum_{p = 1}^{M} A_p \sum_{x = 0}^{\lfloor(i-p)/M\rfloor} F(i-p-x \cdot M)
where the inner sum is empty when p > i. Now define, for i \geq 0, T(i) = \sum_{x = 0}^{\lfloor i/M \rfloor} F(i-x \cdot M) = F(i) + T(i-M), with T(i) = 0 for i < 0.
The inner sum above is exactly T(i-p), so F(i) = \sum_{p = 1}^{M} A_p T(i-p).
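Before moving to matrices, the grouped form can be checked directly. The sketch below computes F(N) in O(N*M) using only the two identities above, T(i) = F(i) + T(i-M) and F(i) = \sum_{p=1}^{M} A_p T(i-p), with T of a negative index taken as 0; it also forces F(i) = 0 for i \in S, since the definition of T does not depend on how F(i) was obtained. It assumes 0-indexed A and inputs reduced modulo 10^9+7, and `groupedF` is a hypothetical name.

```cpp
#include <cassert>
#include <set>
#include <vector>

// Hypothetical helper: same F(N) as the definition, but each F(i) only reads
// the M grouped sums T(i-1), ..., T(i-M), so it runs in O(N*M).
// T(i) = F(i) + F(i-M) + F(i-2M) + ..., and T(negative) = 0.
long long groupedF(long long X, const std::set<long long>& S,
                   const std::vector<long long>& A, int N) {
    const long long MOD = 1000000007LL;
    const int M = static_cast<int>(A.size());
    std::vector<long long> F(N + 1, 0), T(N + 1, 0);
    auto Tat = [&](long long i) { return i < 0 ? 0LL : T[i]; };
    F[0] = T[0] = X % MOD;
    for (int i = 1; i <= N; ++i) {
        if (!S.count(i)) {
            long long sum = 0;
            for (int p = 1; p <= M; ++p)            // F(i) = sum_p A_p * T(i-p)
                sum = (sum + A[p - 1] % MOD * Tat(i - p)) % MOD;
            F[i] = sum;
        }                                           // else F(i) stays 0
        T[i] = (F[i] + Tat(i - M)) % MOD;           // T(i) = F(i) + T(i-M)
    }
    return F[N];
}
```

On small inputs this should agree with a direct O(N^2) evaluation of the definition.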
So F(i) depends only on the last M values of T, and T itself satisfies a linear recurrence of order M: for i \geq 1, T(i) = F(i) + T(i-M) = \sum_{p=1}^{M} A_p T(i-p) + T(i-M). We can therefore apply matrix exponentiation to obtain T(N) in O(M^3 \log N) time. By the definition of T, F(N) = T(N) - T(N-M) if N - M \geq 0, and F(N) = T(N) otherwise.
For M = 4, the recurrence for T looks as follows (W_p = A_p for p \leq M; the +1 in the first row comes from the T(N-M) term):
\begin{bmatrix} W_{1} & W_{2} & W_3 & W_{4}+1 \\ 1 & 0 & 0 & 0\\ 0 & 1 & 0 & 0\\ 0 & 0 & 1 & 0\\ \end{bmatrix} \begin{bmatrix} T(N-1)\\ T(N-2)\\ T(N-3)\\ T(N-4)\\ \end{bmatrix} = \begin{bmatrix} T(N)\\ T(N-1)\\ T(N-2)\\ T(N-3)\\ \end{bmatrix}
Now return to the original problem, where S may be non-empty, and sort S. We start from F(0) = T(0) = X. Assuming we have computed everything up to T(P), let S_i be the smallest element of S greater than P; we advance to T(S_i-1) using the matrix exponentiation above.
At this point T(S_i-1) is known. Since F(S_i) = 0, we have T(S_i) = T(S_i-M), so for this single step the transition matrix becomes
\begin{bmatrix}
0 & 0 & 0 &1 \\
1 & 0 & 0 & 0\\
0 & 1 & 0 & 0\\
0 & 0 & 1 & 0\\
\end{bmatrix}
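We multiply the current state vector by this matrix once and set P = S_i. This matrix is a cyclic shift, so the multiplication is just a rotation of the state vector and costs O(M).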
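After the last element of S has been processed, we multiply the current state vector by the (N-P)-th power of the first matrix to obtain T(N).
Repeating the same process up to N-M gives T(N-M), and the difference of the two is the answer. It is simpler to stop one step earlier at T(N-1): the state vector then holds T(N-1), \ldots, T(N-M), and F(N) = \sum_{p=1}^{M} A_p T(N-p), which is what the solutions below do. If N \in S, the answer is simply 0.
As a speedup: matrix exponentiation is applied up to K+1 times with the same transition matrix, so we can precompute its binary powers (the matrix raised to 2^b for every bit b) once, in O(M^3 \log N). Each exponentiation then multiplies the state vector by at most \log N of these precomputed matrices at O(M^2) each, instead of squaring the matrix again at O(M^3) per step, which saves a factor of M.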
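Putting the pieces together with the precomputed binary powers, a sketch of the whole algorithm might look like the following. It uses the column-vector layout of the matrices above rather than the row-vector layout of the setter's and tester's code, and assumes the elements of S are distinct and lie in [1, N]; `solve`, `matMul`, `mulVec` and `advanceBy` are hypothetical names. The zero step for an element of S is done with `std::rotate`, since that matrix is a cyclic shift.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

using u64 = unsigned long long;
using Mat = std::vector<std::vector<u64>>;
const u64 MOD = 1000000007ULL;
const int LG = 61;   // bits 0..60 cover any exponent below 2^61, enough for N <= 10^18

Mat matMul(const Mat& P, const Mat& Q) {
    size_t n = P.size();
    Mat C(n, std::vector<u64>(n, 0));
    for (size_t i = 0; i < n; ++i)
        for (size_t k = 0; k < n; ++k)
            for (size_t j = 0; j < n; ++j)
                C[i][j] = (C[i][j] + P[i][k] * Q[k][j]) % MOD;
    return C;
}

// w = P * v in O(M^2).
std::vector<u64> mulVec(const Mat& P, const std::vector<u64>& v) {
    std::vector<u64> w(v.size(), 0);
    for (size_t i = 0; i < v.size(); ++i)
        for (size_t k = 0; k < v.size(); ++k)
            w[i] = (w[i] + P[i][k] * v[k]) % MOD;
    return w;
}

// Hypothetical solver; S need not be sorted but is assumed distinct, in [1, N].
u64 solve(u64 X, std::vector<u64> S, const std::vector<u64>& A, u64 N) {
    if (N == 0) return X % MOD;
    std::sort(S.begin(), S.end());
    if (!S.empty() && S.back() == N) return 0;            // F(N) forced to zero
    size_t M = A.size();
    std::vector<Mat> pw(LG, Mat(M, std::vector<u64>(M, 0)));
    for (size_t p = 0; p < M; ++p) pw[0][0][p] = A[p] % MOD;
    pw[0][0][M - 1] = (pw[0][0][M - 1] + 1) % MOD;        // the +1 from T(i-M)
    for (size_t r = 1; r < M; ++r) pw[0][r][r - 1] = 1;
    for (int b = 1; b < LG; ++b) pw[b] = matMul(pw[b - 1], pw[b - 1]);  // matrix^(2^b)

    // Advance by e ordinary steps: one matrix-vector product per set bit of e.
    auto advanceBy = [&](std::vector<u64> v, u64 e) {
        for (int b = 0; b < LG; ++b)
            if ((e >> b) & 1ULL) v = mulVec(pw[b], v);
        return v;
    };
    std::vector<u64> v(M, 0);
    v[0] = X % MOD;                                       // state holding T(0)
    u64 P = 0;                                            // last index with known T
    for (u64 s : S) {
        v = advanceBy(v, s - P - 1);                      // reach T(s-1)
        std::rotate(v.begin(), v.end() - 1, v.end());     // zero step: T(s) = T(s-M)
        P = s;
    }
    v = advanceBy(v, N - P - 1);                          // reach T(N-1)
    u64 F = 0;
    for (size_t p = 0; p < M; ++p) F = (F + A[p] % MOD * v[p]) % MOD;
    return F;
}
```

All precomputed matrices are powers of the same matrix, so they commute; applying the set bits from low to high therefore gives the same result as any other order.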
TIME COMPLEXITY
The time complexity is O(K*log(K) + M^3*log(N)+K*M^2*log(N))
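To put the bound in perspective, here is a rough count of multiply-add operations, assuming the limits enforced by the tester's input validation (K \le 200, M \le 200, N \le 10^{18}, so \log_2 N < 60). These limits are read off the tester's code below, not quoted from the problem statement:

```latex
\begin{aligned}
M^3 \log N   &\approx 200^3 \cdot 60 = 4.8 \times 10^{8}        && \text{(precomputing the binary powers)}\\
K M^2 \log N &\approx 200 \cdot 200^2 \cdot 60 = 4.8 \times 10^{8} && \text{(matrix-vector products)}\\
K M^3 \log N &\approx 200 \cdot 200^3 \cdot 60 = 9.6 \times 10^{10} && \text{(squaring the matrix again for every element of } S\text{)}
\end{aligned}
```

The last line is what plain repeated exponentiation would cost, which is why the binary-power precomputation matters here.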
SOLUTIONS:
Setter's Solution
```cpp
// In The Name Of The Queen
#include<bits/stdc++.h>
using namespace std;
const int N = 202, LG = 61, Mod = 1e9 + 7;
struct Matrix
{
    int n, m, A[N][N];
    inline Matrix(int _n = 0, int _m = 0) : n(_n), m(_m) {memset(A, 0, sizeof(A));}
    inline Matrix operator * (Matrix &X)
    {
        Matrix R(n, X.m);
        for (int i = 0; i < n; i ++)
            for (int k = 0; k < m; k ++)
                for (int j = 0; j < X.m; j ++)
                    R[i][j] = (R[i][j] + 1LL * A[i][k] * X[k][j]) % Mod;
        return (R);
    }
    inline Matrix operator ^ (long long Pw)
    {
        Matrix R(n, n), T = * this;
        for (int i = 0; i < n; i ++)
            R[i][i] = 1;
        for (; Pw; Pw >>= 1, T = T * T)
            if (Pw & 1)
                R = R * T;
        return (R);
    }
    inline int * operator [] (int i)
    {
        return (A[i]);
    }
};
int m, k, X, W[N];
long long n, S[N];
int main()
{
    scanf("%d%d%d%lld", &X, &k, &m, &n);
    for (int i = 1; i <= k; i ++)
        scanf("%lld", &S[i]);
    for (int i = 1; i <= m; i ++)
        scanf("%d", &W[i]);
    sort(S + 1, S + k + 1);
    if (n == 0)
        return !printf("%d\n", X);
    if (S[k] == n)
        return !printf("0\n");
    // A is the 1 x m state row [T(i-m+1), ..., T(i)]; M[b] is the transition matrix to the power 2^b.
    // M is static so that the 61 matrices (about 10 MB) do not live on the stack.
    Matrix A(1, m);
    static Matrix M[LG];
    for (int i = 0; i < LG; i ++)
        M[i] = Matrix(m, m);
    A[0][m - 1] = X;
    for (int i = 0; i < m - 1; i ++)
        M[0][i + 1][i] = 1;
    for (int i = 0; i < m; i ++)
        M[0][i][m - 1] = W[m - i];
    M[0][0][m - 1] ++;
    for (int i = 1; i < LG; i ++)
        M[i] = M[i - 1] * M[i - 1];
    for (int i = 1; i <= k; i ++)
    {
        // Ordinary steps up to T(S[i] - 1): one product per set bit of the gap.
        for (int b = 0; b < LG; b ++)
            if ((S[i] - S[i - 1] - 1) >> b & 1LL)
                A = A * M[b];
        // F(S[i]) = 0, so T(S[i]) = T(S[i] - m): rotate the state row left by one.
        int temp = A[0][0];
        for (int j = 1; j < m; j ++)
            A[0][j - 1] = A[0][j];
        A[0][m - 1] = temp;
    }
    // Advance to T(n - 1), then F(n) = sum over p of W_p * T(n - p).
    for (int b = 0; b < LG; b ++)
        if ((n - S[k] - 1) >> b & 1LL)
            A = A * M[b];
    int Fn = 0;
    for (int i = 0; i < m; i ++)
        Fn = (Fn + 1LL * A[0][i] * W[m - i]) % Mod;
    return !printf("%d\n", Fn);
}
```
Tester's Solution
```cpp
#include <algorithm>
#include <bitset>
#include <complex>
#include <deque>
#include <exception>
#include <fstream>
#include <functional>
#include <iomanip>
#include <ios>
#include <iosfwd>
#include <iostream>
#include <istream>
#include <iterator>
#include <limits>
#include <list>
#include <locale>
#include <map>
#include <memory>
#include <new>
#include <numeric>
#include <ostream>
#include <queue>
#include <set>
#include <sstream>
#include <stack>
#include <stdexcept>
#include <streambuf>
#include <string>
#include <typeinfo>
#include <utility>
#include <valarray>
#include <vector>
#if __cplusplus >= 201103L
#include <array>
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <forward_list>
#include <future>
#include <initializer_list>
#include <mutex>
#include <random>
#include <ratio>
#include <regex>
#include <scoped_allocator>
#include <system_error>
#include <thread>
#include <tuple>
#include <typeindex>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#endif
int gcd(int a, int b) {return b == 0 ? a : gcd(b, a % b);}
using namespace :: std;
//=======================================================================//
#include <iostream>
#include <algorithm>
#include <string>
#include <assert.h>
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            assert(cnt>0);
            if(is_neg){
                x= -x;
            }
            assert(l<=x && x<=r);
            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
//=======================================================================//
#define ll long long
#define pb push_back
#define ld long double
#define mp make_pair
#define F first
#define S second
#define pii pair<ll,ll>
using namespace :: std;
const ll maxn=202;
const ll mod=1e9+7;
const ll inf=1e18+9;
ll a[maxn];
class M{
public:
    int n,m;
    vector<vector<int> > a;
    M(int n=0,int m=0){
        this->n=n;
        this->m=m;
        vector<int> f;
        f.resize(m);
        fill(f.begin(),f.end(),0);
        a.resize(n);
        fill(a.begin(),a.end(),f);
    }
    M zarb(const M &b){
        if(this->m!=b.n){
            cout<<"RIDI";
            exit(0);
        }
        M ans(this->n,b.m);
        for(int i=0;i<this->n;i++){
            for(int k=0;k<this->m;k++){// a.m=b.n
                for(int j=0;j<b.m;j++){
                    ans.a[i][j]=(ans.a[i][j]+(ll)this->a[i][k]*b.a[k][j])%mod;
                }
            }
        }
        return ans;
    }
};
M pre[61];
M jam(M a,const M &b){
    for(ll i=0;i<a.n;i++){
        for(ll j=0;j<a.m;j++){
            a.a[i][j]+=b.a[i][j];
            if(a.a[i][j]>=mod)a.a[i][j]-=mod;
        }
    }
    return a;
}
M tavan(M a,ll n){
    if(a.n!=a.m){
        cout<<"RIDI";
        exit(0);
    }
    M ans(a.n,a.n);
    for(ll i=0;i<a.n;i++)ans.a[i][i]=1;
    while(n){
        if(n&1){
            ans=ans.zarb(a);
        }
        n>>=1;
        a=a.zarb(a);
    }
    return ans;
}
M zarbfast(M a,ll x){
    for(ll i=0;i<61;i++){
        if((x>>i)&1){
            a=a.zarb(pre[i]);
        }
    }
    return a;
}
int main(){
    ios_base::sync_with_stdio(0);cin.tie(0);cout.tie(0);
    ll x,k,m,n;
    x=readIntSp(0,mod-1);
    k=readIntSp(0,200);
    m=readIntSp(0,200);
    n=readIntLn(0,(ll)1e18);
    if(n==0){
        cout<<x;
        return 0;
    }
    vector<ll> vec;
    for(ll i=0;i<k;i++){
        ll x;
        if(i<k-1){
            x=readIntSp(1,n);
        }
        else{
            x=readIntLn(1,n);
        }
        if(x==n){
            cout<<0<<endl;
            exit(0);
        }
        vec.pb(x);
    }
    sort(vec.begin(),vec.end());
    for(ll i=0;i<m;i++){
        if(i<m-1){
            a[i]=readIntSp(0,mod-1);
        }
        else{
            a[i]=readIntLn(0,mod-1);
        }
    }
    M base(m,m);
    M base2(m,m);
    for(ll i=0;i<m;i++){
        base.a[i][m-1]+=a[m-i-1];
        base.a[(i+1)%m][i]++;
        base2.a[(i+1)%m][i]++;
    }
    M avalie(1,m);
    avalie.a[0][m-1]=x;
    pre[0]=base;
    for(ll i=1;i<61;i++){
        pre[i]=pre[i-1].zarb(pre[i-1]);
    }
    while((ll)vec.size() && vec.back()>=n){
        vec.pop_back();
    }
    ll NOWW=0;
    for(ll i=0;i<(ll)vec.size();i++){
        avalie=zarbfast(avalie,vec[i]-NOWW-1);
        avalie=avalie.zarb(base2);
        NOWW=vec[i];
    }
    avalie=zarbfast(avalie, n-1-NOWW);
    ll ans=0;
    for(ll i=0;i<m;i++)
        ans=(ans+(ll)avalie.a[0][i]*a[m-1-i])%mod;
    cout<<ans<<endl;
}
```
Editorialist's Solution
```java
import java.util.*;
import java.io.*;
import java.text.*;
class DFNC{
    //SOLUTION BEGIN
    void pre() throws Exception{}
    void solve(int TC) throws Exception{
        int x = ni();
        int k = ni(), m = ni();
        long n = nl();
        long[] s = new long[k];
        int[] w = new int[m];
        for(int i = 0; i< k; i++)s[i] = nl();
        for(int i = 0; i< m; i++)w[i] = (int)(nl()%mod);
        Arrays.sort(s);
        if(n == 0 || (k>0 && s[k-1] == n)){
            if(n == 0)pn(x);
            else pn(0);
            return;
        }
        // M: ordinary transition for T; M0: the zero step (cyclic shift) for i in S.
        int[][] M = new int[m][m], M0 = new int[m][m];
        for(int i = 1; i< m; i++){
            M[i][i-1] = 1;
            M0[i][i-1] = 1;
        }
        for(int i = 0; i< m; i++)M[0][i] = w[i];
        M[0][m-1]++;
        M0[0][m-1]++;
        int[][][] A = generateP2(M, 60);
        long ans = 0;
        int[] v = new int[m];
        v[0] = x;
        long pre = 0;
        for(int i = 0; i< k; i++){
            if(s[i]-pre-1 > 0)v = pow(A, v, s[i]-pre-1);
            if(s[i]-pre > 0)v = mul(M0, v);
            pre = s[i];
        }
        // v now holds T(n-1), ..., T(n-m); F(n) = sum of A_p * T(n-p).
        v = pow(A, v, n-pre-1);
        for(int i = 0; i< m; i++){
            ans += (long)v[i]*w[i];
            if(ans >= BIG)ans -= BIG;
        }
        ans %= mod;
        pn(ans);
    }
    //Shamelessly copied template
    ///////// begin
    public static final int mod = 1000000007;
    public static final long m2 = (long)mod*mod;
    public static final long BIG = 8L*m2;
    // A^e*v
    public static int[] pow(int[][] A, int[] v, long e)
    {
        for(int i = 0;i < v.length;i++){
            if(v[i] >= mod)v[i] %= mod;
        }
        int[][] MUL = A;
        for(;e > 0;e>>>=1) {
            if((e&1)==1)v = mul(MUL, v);
            MUL = p2(MUL);
        }
        return v;
    }
    // int matrix*int vector
    public static int[] mul(int[][] A, int[] v)
    {
        int m = A.length;
        int n = v.length;
        int[] w = new int[m];
        for(int i = 0;i < m;i++){
            long sum = 0;
            for(int k = 0;k < n;k++){
                sum += (long)A[i][k] * v[k];
                if(sum >= BIG)sum -= BIG;
            }
            w[i] = (int)(sum % mod);
        }
        return w;
    }
    // int matrix^2 (be careful about negative value)
    public static int[][] p2(int[][] A)
    {
        int n = A.length;
        int[][] C = new int[n][n];
        for(int i = 0;i < n;i++){
            long[] sum = new long[n];
            for(int k = 0;k < n;k++){
                for(int j = 0;j < n;j++){
                    sum[j] += (long)A[i][k] * A[k][j];
                    if(sum[j] >= BIG)sum[j] -= BIG;
                }
            }
            for(int j = 0;j < n;j++){
                C[i][j] = (int)(sum[j] % mod);
            }
        }
        return C;
    }
    //////////// end
    // ret[n]=A^(2^n)
    public static int[][][] generateP2(int[][] A, int n)
    {
        int[][][] ret = new int[n+1][][];
        ret[0] = A;
        for(int i = 1;i <= n;i++)ret[i] = p2(ret[i-1]);
        return ret;
    }
    // A[0]^e*v
    // A[n]=A[0]^(2^n)
    public static int[] pow(int[][][] A, int[] v, long e)
    {
        for(int i = 0;e > 0;e>>>=1,i++) {
            if((e&1)==1)v = mul(A[i], v);
        }
        return v;
    }
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    DecimalFormat df = new DecimalFormat("0.00000000000");
    static boolean multipleTC = false;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        new DFNC().run();
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}
    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }
        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }
        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }
        String nextLine() throws Exception{
            String str = "";
            try{
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }
            return str;
        }
    }
}
```
Feel free to share your approach. Suggestions are welcome as always.