DFNC - Editorial

PROBLEM LINK:

Practice

Contest: Division 1

Contest: Division 2

Setter: Kasra Mazaheri

Tester: Arshia

Editorialist: Taranpreet Singh

DIFFICULTY:

Medium-Hard

PREREQUISITES:

Linear Recurrences, Matrix Exponentiation

PROBLEM:

You are given a set S of size K, integers X and N, and a sequence A of length M, which defines an infinite sequence W by W_{i} = A_{(i-1)\%M+1}.

Define the function F on non-negative integers as

  • F(0) = X
  • if i \in S, F(i) = 0
  • otherwise F(i) = \big( \displaystyle\sum_{j = 1}^{i} F(i-j)*W_j \big) \% (10^9+7)

Find F(N)
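Before optimizing, it helps to have a reference implementation for small inputs. The sketch below evaluates the definition directly in O(N^2). It is not taken from any of the posted solutions: the name bruteF is hypothetical, A is stored 0-indexed (A[0] holds A_1), and X and the entries of A are assumed to be already reduced modulo 10^9+7.

```cpp
#include <set>
#include <vector>

// Reference evaluation of F(N) straight from the definition, O(N^2).
// Hypothetical helper, intended only for testing small cases.
// X: F(0); S: indices forced to 0; A: base sequence with A[0] = A_1; N: target index.
long long bruteF(long long X, const std::set<long long>& S,
                 const std::vector<long long>& A, long long N) {
    const long long MOD = 1000000007LL;
    const long long M = (long long)A.size();
    std::vector<long long> F(N + 1, 0);
    F[0] = X % MOD;                          // F(0) = X
    for (long long i = 1; i <= N; i++) {
        if (S.count(i)) continue;            // F(i) stays 0 for i in S
        long long sum = 0;
        for (long long j = 1; j <= i; j++) {
            long long Wj = A[(j - 1) % M];   // W_j = A_{((j-1) mod M) + 1}
            sum = (sum + F[i - j] * Wj) % MOD;
        }
        F[i] = sum;
    }
    return F[N];
}
```

For instance, bruteF(1, {2}, {2, 3}, 3) returns 8: F(1) = 2, F(2) = 0, and F(3) = F(1)*W_2 + F(0)*W_3 = 2*3 + 1*2.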

QUICK EXPLANATION

  • Since W is periodic with period M, the terms F(i-j) that are multiplied by the same A_p can be grouped. After grouping, F(i) depends only on M grouped sums instead of on all previous values of F.
  • Define T(i) = \displaystyle\sum_{0 \le j \le i,\ j \equiv i \pmod M} F(j), which is exactly this grouping. Then F(i) = \sum_{p=1}^{M} A_p*T(i-p), and T(i) = F(i) + T(i-M) satisfies a linear recurrence of order M, so T can be advanced with matrix exponentiation.
  • To handle indices i \in S, sort S, use matrix exponentiation to reach T(S_i-1), and then apply a special transition that enforces F(S_i) = 0.
  • As a speedup, precompute the binary powers of the transition matrix, so that every jump needs only vector-matrix products.

EXPLANATION

First, assume that S is empty. Then we are given the initial value F(0) = X and the sequence A, and we have to find F(N).

The sequence W repeats itself with period M. We have F(i) = \big( \displaystyle\sum_{j = 1}^{i} F(i-j)*W_j \big), but W_j takes at most M different values. Specifically, by the definition of W, W_j = W_{j-M} = \ldots = W_{j-x*M} = A_{j-x*M}, where x is the largest integer with j-x*M > 0.

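So we can group together all terms F(i-j) that are multiplied by the same A_p, namely those with j = p + x*M for x \geq 0. The summation becomes

F(i) = \sum_{p = 1}^{M} A_p*\sum_{x = 0}^{\lfloor(i-p)/M\rfloor} F(i-p-x*M)

where the inner sum is empty when i < p. For example, with M = 2 and i = 3 this gives F(3) = A_1*(F(2)+F(0)) + A_2*F(1).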

Now, let T(i) be the sum of F(j) over all j \leq i with j \equiv i \pmod M:

T(i) = \sum_{x = 0}^{\lfloor i/M \rfloor} F(i-x*M) = F(i) + T(i-M), with T(i) = 0 for i < 0.

Using this definition, \sum_{x \geq 0} F(i-p-x*M) = T(i-p), so we can write F(i) = \sum_{p = 1}^{M} A_p*T(i-p).
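The identities T(i) = F(i) + T(i-M) and F(i) = \sum_{p} A_p*T(i-p) can be checked with a simple O(N*M) loop. This is only an illustrative sketch: the name groupedF is hypothetical, it is still linear in N, and it should return the same values as a direct evaluation of the definition.

```cpp
#include <set>
#include <vector>

// Hypothetical O(N*M) helper using the grouped sums
//   T(i) = F(i) + T(i-M)              (T of a negative index is 0)
//   F(i) = sum_{p=1}^{M} A_p * T(i-p)
// A is 0-indexed (A[p-1] = A_p); values are assumed to be below 1e9+7.
long long groupedF(long long X, const std::set<long long>& S,
                   const std::vector<long long>& A, long long N) {
    const long long MOD = 1000000007LL;
    const long long M = (long long)A.size();
    std::vector<long long> T(N + 1, 0);
    auto getT = [&](long long i) { return i < 0 ? 0LL : T[i]; };
    long long F = X % MOD;              // F(0) = X
    T[0] = F;                           // T(0) = F(0)
    for (long long i = 1; i <= N; i++) {
        F = 0;                          // stays 0 when i is in S
        if (!S.count(i))
            for (long long p = 1; p <= M; p++)
                F = (F + A[p - 1] * getT(i - p)) % MOD;
        T[i] = (F + getT(i - M)) % MOD;
    }
    return F;                           // F(N), or X when N == 0
}
```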

Substituting this into T(i) = F(i) + T(i-M) gives T(i) = \sum_{p = 1}^{M} A_p*T(i-p) + T(i-M), a linear recurrence in the last M terms in which T(i-M) has coefficient A_M + 1. We can therefore apply matrix exponentiation to obtain T(N) in O(M^3*log(N)) time. By the definition of T, F(N) = T(N)-T(N-M) if N-M \geq 0; otherwise F(N) = T(N).

For M = 4, the recurrence for T(N) in matrix form is (here W_p = A_p for p \le M):

$$
\begin{bmatrix}
W_{1} & W_{2} & W_3 & W_{4}+1 \\
1 & 0 & 0 & 0 \\
0 & 1 & 0 & 0 \\
0 & 0 & 1 & 0 \\
\end{bmatrix}
\begin{bmatrix}
T(N-1) \\
T(N-2) \\
T(N-3) \\
T(N-4) \\
\end{bmatrix}
=
\begin{bmatrix}
T(N) \\
T(N-1) \\
T(N-2) \\
T(N-3) \\
\end{bmatrix}
$$

Now return to the original problem, where S may be non-empty, and sort S. We start from F(0) = T(0) = X. Assuming we have computed the state up to T(P), let S_i be the smallest element of S greater than P; we reach T(S_i-1) by applying the matrix above S_i-1-P times, using matrix exponentiation.

We now have T(S_i-1). Since F(S_i) = 0, we have T(S_i) = F(S_i) + T(S_i-M) = T(S_i-M), so for this single step the transition matrix becomes
$$
\begin{bmatrix}
0 & 0 & 0 & 1 \\
1 & 0 & 0 & 0 \\
0 & 1 & 0 & 0 \\
0 & 0 & 1 & 0 \\
\end{bmatrix}
$$

So we multiply our current state vector by this matrix and set P = S_i.
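This special matrix only moves entries around, so multiplying by it is equivalent to rotating the state vector by one position; the setter's solution performs exactly this rotation with a temporary variable. The sketch below (hypothetical names skipMatrix and applyForbiddenIndex, same column-vector layout as the M = 4 example) shows the matrix and the O(M) rotation side by side.

```cpp
#include <algorithm>
#include <vector>

using Mat = std::vector<std::vector<long long>>;

// The 0/1 transition used at an index s in S (hypothetical name), for the
// column state (T(s-1), ..., T(s-M)): row 0 copies T(s-M), rows 1..M-1 shift down.
Mat skipMatrix(size_t M) {
    Mat Z(M, std::vector<long long>(M, 0));
    Z[0][M - 1] = 1;
    for (size_t i = 1; i < M; i++) Z[i][i - 1] = 1;
    return Z;
}

// Same transition in O(M) (hypothetical name): rotate the state right by one,
// turning (T(s-1), ..., T(s-M)) into (T(s-M), T(s-1), ..., T(s-M+1)).
void applyForbiddenIndex(std::vector<long long>& v) {
    std::rotate(v.begin(), v.end() - 1, v.end());
}
```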

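Lastly, we multiply the current state vector by the (N-P)-th power of the first matrix to obtain T(N).

We can repeat the same process to calculate T(N-M); the difference of the two is the required answer. Alternatively, as the solutions below do, we can stop at T(N-1): the state vector then holds T(N-1), \ldots, T(N-M), and F(N) = \sum_{p = 1}^{M} A_p*T(N-p) follows from a single dot product. With this shortcut the case N \in S has to be handled separately, where the answer is simply 0.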

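As a speedup: we apply matrix exponentiation with the same transition matrix up to K+1 times, so we can precompute its binary powers (the matrix raised to 2^b for every bit b of N) once. Each jump then needs only vector-matrix products, which cost O(M^2) instead of O(M^3); this reduces the time complexity by a factor of M.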
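Putting everything together, here is a sketch of the complete method with precomputed powers pw[b] = (transition matrix)^{2^b} for b < 60, which suffices when N \le 10^{18} < 2^{60} (the bound checked by the tester's validator). It mirrors the structure of the posted solutions but is written independently: solveDFNC, the std::vector matrix types, and the skipping of duplicate or out-of-range values of S are assumptions of this sketch.

```cpp
#include <algorithm>
#include <vector>

using Vec = std::vector<long long>;
using Mat = std::vector<Vec>;
const long long MOD = 1000000007LL;

// O(M^3) square-matrix product modulo MOD.
Mat mulMat(const Mat& P, const Mat& Q) {
    size_t n = P.size();
    Mat R(n, Vec(n, 0));
    for (size_t i = 0; i < n; i++)
        for (size_t k = 0; k < n; k++)
            for (size_t j = 0; j < n; j++)
                R[i][j] = (R[i][j] + P[i][k] * Q[k][j]) % MOD;
    return R;
}

// O(M^2) matrix-vector product modulo MOD.
Vec mulVec(const Mat& P, const Vec& v) {
    size_t n = P.size();
    Vec r(n, 0);
    for (size_t i = 0; i < n; i++)
        for (size_t k = 0; k < n; k++)
            r[i] = (r[i] + P[i][k] * v[k]) % MOD;
    return r;
}

// Sketch (hypothetical name): F(N) for forbidden set S and base sequence A
// (A[0] = A_1). Assumes N <= 1e18 < 2^60 and values of X, A below MOD.
long long solveDFNC(long long X, Vec S, const Vec& A, long long N) {
    const size_t M = A.size();
    const int LOG = 60;
    if (N == 0) return X % MOD;
    std::sort(S.begin(), S.end());
    if (std::binary_search(S.begin(), S.end(), N)) return 0;   // F(N) = 0

    // pw[b] = C^(2^b), C = companion matrix of T(i) = sum A_p T(i-p) + T(i-M)
    std::vector<Mat> pw(LOG, Mat(M, Vec(M, 0)));
    for (size_t j = 0; j < M; j++) pw[0][0][j] = A[j] % MOD;
    pw[0][0][M - 1] = (pw[0][0][M - 1] + 1) % MOD;
    for (size_t i = 1; i < M; i++) pw[0][i][i - 1] = 1;
    for (int b = 1; b < LOG; b++) pw[b] = mulMat(pw[b - 1], pw[b - 1]);

    // Advance the state by d ordinary steps using the bits of d.
    auto jump = [&](Vec v, long long d) {
        for (int b = 0; b < LOG && d > 0; b++, d >>= 1)
            if (d & 1) v = mulVec(pw[b], v);
        return v;
    };

    Vec v(M, 0);
    v[0] = X % MOD;                 // (T(0), T(-1), ..., T(-M+1))
    long long P = 0;                // v currently starts with T(P)
    for (long long s : S) {
        if (s >= N) break;          // values beyond N do not matter
        if (s <= P) continue;       // skip duplicates
        v = jump(v, s - 1 - P);     // v starts with T(s-1)
        std::rotate(v.begin(), v.end() - 1, v.end());   // T(s) = T(s-M)
        P = s;
    }
    v = jump(v, N - 1 - P);         // v = (T(N-1), ..., T(N-M))
    long long F = 0;
    for (size_t p = 0; p < M; p++) F = (F + A[p] % MOD * v[p]) % MOD;
    return F;
}
```

As a sanity check, with A = (1) every F(i) for i \geq 1 is the sum of all earlier values, so F(N) = 2^{N-1} for X = 1 and empty S; solveDFNC(1, {}, {1}, 40) accordingly returns 2^{39} mod (10^9+7) = 755810045.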

TIME COMPLEXITY

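The time complexity is O(K*log(K) + M^3*log(N) + K*M^2*log(N)): sorting S, precomputing the binary powers, and O(log(N)) vector-matrix products for each of the K+1 jumps.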
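To get a feel for the numbers, here is a rough operation count using the bounds checked by the tester's input validator (K, M \le 200 and N \le 10^{18}, so \log_2 N \approx 60); constant factors are ignored.

$$
\begin{aligned}
M^3 \log_2 N &\approx 200^3 \cdot 60 = 4.8 \times 10^{8} && \text{(precomputing the binary powers)}\\
K M^2 \log_2 N &\approx 200 \cdot 200^2 \cdot 60 = 4.8 \times 10^{8} && \text{(vector-matrix jumps)}\\
K M^3 \log_2 N &\approx 200 \cdot 200^3 \cdot 60 = 9.6 \times 10^{10} && \text{(full matrix powers per jump, no precomputation)}
\end{aligned}
$$

The gap between the last two lines is the factor M = 200 saved by the precomputation.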

SOLUTIONS:

Setter's Solution
// In The Name Of The Queen
#include<bits/stdc++.h>
using namespace std;
const int N = 202, LG = 61, Mod = 1e9 + 7;
struct Matrix
{
	int n, m, A[N][N];
	inline Matrix(int _n = 0, int _m = 0) : n(_n), m(_m) {memset(A, 0, sizeof(A));}
	inline Matrix operator * (Matrix &X)
	{
	    Matrix R(n, X.m);
	    for (int i = 0; i < n; i ++)
	        for (int k = 0; k < m; k ++)
	            for (int j = 0; j < X.m; j ++)
	                R[i][j] = (R[i][j] + 1LL * A[i][k] * X[k][j]) % Mod;
	    return (R);
	}
	inline Matrix operator ^ (long long Pw)
	{
	    Matrix R(n, n), T = * this;
	    for (int i = 0; i < n; i ++)
	        R[i][i] = 1;
	    for (; Pw; Pw >>= 1, T = T * T)
	        if (Pw & 1)
	            R = R * T;
	    return (R);
	}
	inline int * operator [] (int i)
	{
	    return (A[i]);
	}
};
int m, k, X, W[N];
long long n, S[N];
int main()
{
	scanf("%d%d%d%lld", &X, &k, &m, &n);
	for (int i = 1; i <= k; i ++)
	    scanf("%lld", &S[i]);
	for (int i = 1; i <= m; i ++)
	    scanf("%d", &W[i]);
	sort(S + 1, S + k + 1);
	if (n == 0)
	    return !printf("%d\n", X); // F(0) = X; X is an int, so %d
	if (S[k] == n)
	    return !printf("0\n"); // n is in S (its largest element), so F(n) = 0
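	// A: 1 x m state row vector (T(i-m+1), ..., T(i)); M[b]: transition matrix raised to 2^b.
	// static: these 62 matrices of 202 x 202 ints take about 10 MB, which may not fit on the stack.
	static Matrix A(1, m), M[LG];
	for (int i = 0; i < LG; i ++)
	    M[i] = Matrix(m, m);
	A[0][m - 1] = X; // T(0) = X; entries for negative indices are 0
	// Transition in row-vector form: shift the window left and append
	// T(i+1) = sum_p W[p] * T(i+1-p) + T(i+1-m).
	for (int i = 0; i < m - 1; i ++)
	    M[0][i + 1][i] = 1;
	for (int i = 0; i < m; i ++)
	    M[0][i][m - 1] = W[m - i];
	M[0][0][m - 1] ++;
	for (int i = 1; i < LG; i ++)
	    M[i] = M[i - 1] * M[i - 1];
	for (int i = 1; i <= k; i ++)
	{
	    // Jump from T(S[i-1]) to T(S[i] - 1) with the precomputed powers.
	    for (int b = 0; b < LG; b ++)
	        if ((S[i] - S[i - 1] - 1) >> b & 1LL)
	            A = A * M[b];
	    // F(S[i]) = 0, so T(S[i]) = T(S[i] - m): rotate the window by one.
	    int temp = A[0][0];
	    for (int j = 1; j < m; j ++)
	        A[0][j - 1] = A[0][j];
	    A[0][m - 1] = temp;
	}
	// Jump to T(n - 1); then F(n) = sum_p W[p] * T(n - p).
	for (int b = 0; b < LG; b ++)
	    if ((n - S[k] - 1) >> b & 1LL)
	        A = A * M[b];
	int Fn = 0;
	for (int i = 0; i < m; i ++)
	    Fn = (Fn + 1LL * A[0][i] * W[m - i]) % Mod;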
}

[details="Tester's Solution"]
#include <cstdio>
#include <cstdlib>
#include <vector>

int gcd(int a, int b) {return b == 0 ? a : gcd(b, a % b);}

using namespace :: std;


//=======================================================================//
#include <iostream>
#include <algorithm>
#include <string>
#include <assert.h>
long long readInt(long long l,long long r,char endd){
	long long x=0;
	int cnt=0;
	int fi=-1;
	bool is_neg=false;
	while(true){
	    char g=getchar();
	    if(g=='-'){
	        assert(fi==-1);
	        is_neg=true;
	        continue;
	    }
	    if('0'<=g && g<='9'){
	        x*=10;
	        x+=g-'0';
	        if(cnt==0){
	            fi=g-'0';
	        }
	        cnt++;
	        assert(fi!=0 || cnt==1);
	        assert(fi!=0 || is_neg==false);

	        assert(!(cnt>19 || ( cnt==19 && fi>1) ));
	    } else if(g==endd){
	        assert(cnt>0);
	        if(is_neg){
	            x= -x;
	        }
	        assert(l<=x && x<=r);
	        return x;
	    } else {
	        assert(false);
	    }
	}
}
string readString(int l,int r,char endd){
	string ret="";
	int cnt=0;
	while(true){
	    char g=getchar();
	    assert(g!=-1);
	    if(g==endd){
	        break;
	    }
	    cnt++;
	    ret+=g;
	}
	assert(l<=cnt && cnt<=r);
	return ret;
}
long long readIntSp(long long l,long long r){
	return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
	return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
	return readString(l,r,'\n');
}
string readStringSp(int l,int r){
	return readString(l,r,' ');
}
//=======================================================================//



#define ll long long
#define pb push_back
#define ld long double
#define mp make_pair
#define F first
#define S second
#define pii pair<ll,ll>

using namespace :: std;

const ll maxn=202;
const ll mod=1e9+7;
const ll inf=1e18+9;

ll a[maxn];
class M{
	public:
		int n,m;
		vector<vector<int> > a;
		M(int n=0,int m=0){
			this->n=n;
			this->m=m;
			vector<int> f;
			f.resize(m);
			fill(f.begin(),f.end(),0);
			a.resize(n);
			fill(a.begin(),a.end(),f);
		}
		M zarb(const M &b){
			if(this->m!=b.n){
				cout<<"RIDI";
				exit(0);
			}
			M ans(this->n,b.m);
			for(int i=0;i<this->n;i++){
				for(int k=0;k<this->m;k++){// a.m=b.n
					for(int j=0;j<b.m;j++){
						ans.a[i][j]=(ans.a[i][j]+(ll)this->a[i][k]*b.a[k][j])%mod;
					}
				}
			}
			return ans;
		}
};

M pre[61];

M jam(M a,const M &b){
	for(ll i=0;i<a.n;i++){
		for(ll j=0;j<a.m;j++){
			a.a[i][j]+=b.a[i][j];
			if(a.a[i][j]>=mod)a.a[i][j]-=mod;
		}
	}
	return a;
}
M tavan(M a,ll n){
	if(a.n!=a.m){
		cout<<"RIDI";
		exit(0);
	}
	M ans(a.n,a.n);
	for(ll i=0;i<a.n;i++)ans.a[i][i]=1;
	while(n){
		if(n&1){
			ans=ans.zarb(a);
		}
		n>>=1;
		a=a.zarb(a);
	}
	return ans;
}
M zarbfast(M a,ll x){
	for(ll i=0;i<61;i++){
		if((x>>i)&1){
			a=a.zarb(pre[i]);
		}
	}
	return a;
}

int main(){
	ios_base::sync_with_stdio(0);cin.tie(0);cout.tie(0);
	ll x,k,m,n;
	x=readIntSp(0,mod-1);
	k=readIntSp(0,200);
	m=readIntSp(0,200);
	n=readIntLn(0,(ll)1e18);
	if(n==0){
		cout<<x;
		return 0;
	}
	vector<ll> vec;
	for(ll i=0;i<k;i++){
		ll x;
		if(i<k-1){
	        x=readIntSp(1,n);
		}
		else{
	        x=readIntLn(1,n);
		}
		if(x==n){
	        cout<<0<<endl;
	        exit(0);
		}
		vec.pb(x);
	}
	sort(vec.begin(),vec.end());
	for(ll i=0;i<m;i++){
		if(i<m-1){
	        a[i]=readIntSp(0,mod-1);
		}
		else{
	        a[i]=readIntLn(0,mod-1);
		}
	}
	M base(m,m);
	M base2(m,m);
	for(ll i=0;i<m;i++){
		base.a[i][m-1]+=a[m-i-1];
		base.a[(i+1)%m][i]++;
		base2.a[(i+1)%m][i]++;
	}


	M avalie(1,m);
	avalie.a[0][m-1]=x;

	pre[0]=base;
	for(ll i=1;i<61;i++){
		pre[i]=pre[i-1].zarb(pre[i-1]);
	}
	while((ll)vec.size() && vec.back()>=n){
	    vec.pop_back();
	}
	ll NOWW=0;
	for(ll i=0;i<(ll)vec.size();i++){
		avalie=zarbfast(avalie,vec[i]-NOWW-1);
		avalie=avalie.zarb(base2);
		NOWW=vec[i];
	}

	avalie=zarbfast(avalie, n-1-NOWW);
	ll ans=0;
	for(ll i=0;i<m;i++)
	    ans=(ans+(ll)avalie.a[0][i]*a[m-1-i])%mod;
	cout<<ans<<endl;
}

[/details]
[details="Editorialist's Solution"]
import java.util.*;
import java.io.*;
import java.text.*;
class DFNC{
//SOLUTION BEGIN
	void pre() throws Exception{}
	void solve(int TC) throws Exception{
	    int x = ni();
	    int k = ni(), m = ni();
	    long n = nl();
	    long[] s = new long[k];
	    int[] w = new int[m];
	    for(int i = 0; i< k; i++)s[i] = nl();
	    for(int i = 0; i< m; i++)w[i] = (int)(nl()%mod);
	    Arrays.sort(s);
	    // F(0) = X; if N is in S it is the largest element and F(N) = 0
	    if(n == 0 || (k>0 && s[k-1] == n)){
	        if(n == 0)pn(x);
	        else pn(0);
	        return;
	    }

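	    // M: transition matrix for the column vector v = (T(i), T(i-1), ..., T(i-m+1));
	    // M0: transition at an index in S, where T(i) = T(i-m).
	    int[][] M = new int[m][m], M0 = new int[m][m];
	    for(int i = 1; i< m; i++){
	        M[i][i-1] = 1;
	        M0[i][i-1] = 1;
	    }
	    for(int i = 0; i< m; i++)M[0][i] = w[i];
	    M[0][m-1]++;
	    M0[0][m-1]++;
	    
	    // A[b] = M^(2^b), so each jump needs only matrix-vector products
	    int[][][] A = generateP2(M, 60);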
	    long ans = 0;
	    
	    int[] v = new int[m];
	    v[0] = x;
	    
	    long pre = 0;
	    for(int i = 0; i< k; i++){
	        if(s[i]-pre-1 > 0)v = pow(A, v, s[i]-pre-1);
	        if(s[i]-pre > 0)v = mul(M0, v);
	        pre = s[i];
	    }
	    v = pow(A, v, n-pre-1);
	    for(int i = 0; i< m; i++){
	        ans += (long)v[i]*w[i];
	        if(ans >= BIG)ans -= BIG;
	    }
	    ans %= mod;
	    pn(ans);
	}
	//Shamelessly copied template
	///////// begin
	public static final int mod = 1000000007;
	public static final long m2 = (long)mod*mod;
	public static final long BIG = 8L*m2;

	// A^e*v
	public static int[] pow(int[][] A, int[] v, long e)
	{
		for(int i = 0;i < v.length;i++){
			if(v[i] >= mod)v[i] %= mod;
		}
		int[][] MUL = A;
		for(;e > 0;e>>>=1) {
			if((e&1)==1)v = mul(MUL, v);
			MUL = p2(MUL);
		}
		return v;
	}

	// int matrix*int vector
	public static int[] mul(int[][] A, int[] v)
	{
		int m = A.length;
		int n = v.length;
		int[] w = new int[m];
		for(int i = 0;i < m;i++){
			long sum = 0;
			for(int k = 0;k < n;k++){
				sum += (long)A[i][k] * v[k];
				if(sum >= BIG)sum -= BIG;
			}
			w[i] = (int)(sum % mod);
		}
		return w;
	}

	// int matrix^2 (be careful about negative value)
	public static int[][] p2(int[][] A)
	{
		int n = A.length;
		int[][] C = new int[n][n];
		for(int i = 0;i < n;i++){
			long[] sum = new long[n];
			for(int k = 0;k < n;k++){
				for(int j = 0;j < n;j++){
					sum[j] += (long)A[i][k] * A[k][j];
					if(sum[j] >= BIG)sum[j] -= BIG;
				}
			}
			for(int j = 0;j < n;j++){
				C[i][j] = (int)(sum[j] % mod);
			}
		}
		return C;
	}

	//////////// end

	// ret[n]=A^(2^n)
	public static int[][][] generateP2(int[][] A, int n)
	{
		int[][][] ret = new int[n+1][][];
		ret[0] = A;
		for(int i = 1;i <= n;i++)ret[i] = p2(ret[i-1]);
		return ret;
	}

	// A[0]^e*v
	// A[n]=A[0]^(2^n)
	public static int[] pow(int[][][] A, int[] v, long e)
	{
		for(int i = 0;e > 0;e>>>=1,i++) {
			if((e&1)==1)v = mul(A[i], v);
		}
		return v;
	}

	public static int[][] mul(int[][]... a)
	{
		int[][] base = a[0];
		for(int i = 1;i < a.length;i++){
			base = mul(base, a[i]);
		}
		return base;
	}
	//SOLUTION END
	void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
	DecimalFormat df = new DecimalFormat("0.00000000000");
	static boolean multipleTC = false;
	FastReader in;PrintWriter out;
	void run() throws Exception{
	    in = new FastReader();
	    out = new PrintWriter(System.out);
	    //Solution Credits: Taranpreet Singh
	    int T = (multipleTC)?ni():1;
	    pre();for(int t = 1; t<= T; t++)solve(t);
	    out.flush();
	    out.close();
	}
	public static void main(String[] args) throws Exception{
	    new DFNC().run();
	}
	int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
	void p(Object o){out.print(o);}
	void pn(Object o){out.println(o);}
	void pni(Object o){out.println(o);out.flush();}
	String n()throws Exception{return in.next();}
	String nln()throws Exception{return in.nextLine();}
	int ni()throws Exception{return Integer.parseInt(in.next());}
	long nl()throws Exception{return Long.parseLong(in.next());}
	double nd()throws Exception{return Double.parseDouble(in.next());}

	class FastReader{
	    BufferedReader br;
	    StringTokenizer st;
	    public FastReader(){
	        br = new BufferedReader(new InputStreamReader(System.in));
	    }

	    public FastReader(String s) throws Exception{
	        br = new BufferedReader(new FileReader(s));
	    }

	    String next() throws Exception{
	        while (st == null || !st.hasMoreElements()){
	            try{
	                st = new StringTokenizer(br.readLine());
	            }catch (IOException  e){
	                throw new Exception(e.toString());
	            }
	        }
	        return st.nextToken();
	    }

	    String nextLine() throws Exception{
	        String str = "";
	        try{   
	            str = br.readLine();
	        }catch (IOException e){
	            throw new Exception(e.toString());
	        }  
	        return str;
	    }
	}
}

[/details]

Feel free to share your approach. Suggestions are welcome as always. :slight_smile: