DFNC - Editorial

PROBLEM LINK:

Practice

Contest: Division 1

Contest: Division 2

Setter: Kasra Mazaheri

Tester: Arshia

Editorialist: Taranpreet Singh

DIFFICULTY:

Medium-Hard

PREREQUISITES:

Linear Recurrences, Matrix Exponentiation

PROBLEM:

You are given a set S of K distinct indices, an integer X, and a sequence A of length M. A defines an infinite sequence W by W_{i} = A_{(i-1)\%M+1}, i.e. W is A repeated forever.

Now define the function F(i) as follows:

  • F(0) = X
  • if i \in S, F(i) = 0
  • otherwise F(i) = \big( \displaystyle\sum_{j = 1}^{i} F(i-j)*W_j \big) \% (10^9+7)

Find F(N)

QUICK EXPLANATION

  • Since the sequence W is cyclic with period M, all terms F(i-j) that are multiplied by the same value W_j can be grouped together. After this grouping, F(i) depends only on the previous M values of the grouped sums T defined below.
  • Let’s define T(i) = \displaystyle\sum_{x = 0}^{\lfloor i/M \rfloor} F(i - x*M), the sum of F over all indices up to i that are congruent to i modulo M; this is the required grouping. T satisfies a linear recurrence of order M, so we can use Matrix Exponentiation to compute T(i), and F(i) can be computed for any i from the previous M values of T.
  • To handle F(i) = 0 for i \in S, sort the values of S. Use matrix exponentiation to advance up to T(s_i-1), then apply one special step for F(s_i) = 0, which gives T(s_i) = T(s_i-M).
  • As a speedup, it is better to pre-compute binary powers of transition matrix to reduce complexity.

EXPLANATION

First of all, assume the set S is empty. Then we are given the initial value F(0) = X and the sequence A, and we have to find F(N).

We can see that the sequence W repeats itself. We have F(i) = \big( \displaystyle\sum_{j = 1}^{i} F(i-j)*W_j \big). But W_j takes at most M different values. Specifically, by the definition of W, W_j = W_{j-M} = \ldots = W_{j-x*M} = A_{j-x*M}, where x is the largest integer such that j-x*M > 0 (that is, j-x*M = (j-1)\%M+1).

So we can group together all j for which F(i-j) gets multiplied by the same A_p, namely all j = p + x*M. We can now rewrite the summation as

F(i) = \sum_{p = 1}^{M} A_p*\sum_{x = 0}^{\lfloor(i-p)/M\rfloor} F(i-p-x*M)

Now, let’s define T(i) = \sum_{x = 0}^{\lfloor i/M \rfloor} F(i-x*M), with T(i) = 0 for i < 0. Equivalently, T(i) = F(i) + T(i-M).

Using this definition, we can write F(i) = \sum_{p = 1}^{M} A_p*T(i-p)

So T satisfies a linear recurrence on its last M terms: T(i) = F(i) + T(i-M) = \sum_{p = 1}^{M} A_p*T(i-p) + T(i-M). We can apply Matrix Exponentiation to obtain T(N) in O(M^3*\log N) time. By the definition of T, F(N) = T(N) - T(N-M) if N-M \geq 0, and F(N) = T(N) otherwise.

For M = 4, the recurrence T(N) = \sum_{p = 1}^{M} W_p*T(N-p) + T(N-M) looks like this in matrix form:

\begin{bmatrix} W_{1} & W_{2} & W_3 & W_{4}+1 \\ 1 & 0 & 0 & 0\\ 0 & 1 & 0 & 0\\ 0 & 0 & 1 & 0\\ \end{bmatrix} \begin{bmatrix} T(N-1)\\ T(N-2)\\ T(N-3)\\ T(N-4)\\ \end{bmatrix} = \begin{bmatrix} T(N)\\ T(N-1)\\ T(N-2)\\ T(N-3)\\ \end{bmatrix}

Now return to the original problem, where the set S can be non-empty, and sort S. We start from F(0) = T(0) = X (and T(i) = 0 for i < 0). Assume we have calculated everything up to T(P), and let S_i be the smallest value in S greater than P. We calculate T(S_i-1) using the matrix exponentiation above.

We now have T(S_i-1). Since F(S_i) = 0, we get T(S_i) = F(S_i) + T(S_i-M) = T(S_i-M). So, for this single step, the transition matrix becomes
\begin{bmatrix} 0 & 0 & 0 &1 \\ 1 & 0 & 0 & 0\\ 0 & 1 & 0 & 0\\ 0 & 0 & 1 & 0\\ \end{bmatrix}

So, we can multiply it with our current answer matrix and set P = S_i

Lastly, we need to multiply our current answer matrix by N-P power of the first matrix to calculate T(N)

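We can repeat the same process to calculate T(N-M); the difference T(N) - T(N-M) is the required answer. The solutions below avoid this second pass. They stop at T(N-1), whose state vector holds T(N-1), \ldots, T(N-M), and evaluate F(N) = \sum_{p = 1}^{M} A_p*T(N-p) directly.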

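As a speedup, since we apply the same transition matrix many times, we can precompute its binary powers (the matrix raised to 2^b) once, in O(M^3*\log N). After that, every segment only multiplies the current state vector by O(\log N) precomputed matrices. Each such vector-matrix product costs O(M^2) instead of the O(M^3) of a matrix-matrix product, which reduces the time complexity by a factor of M.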

TIME COMPLEXITY

The time complexity is O(K*log(K) + M^3*log(N)+K*M^2*log(N))

SOLUTIONS:

Setter's Solution
// In The Name Of The Queen
#include<bits/stdc++.h>
using namespace std;
const int N = 202, LG = 61, Mod = 1e9 + 7;
// Dense matrix of at most N x N ints, with all arithmetic done modulo Mod.
// Storage is a fixed N x N array (about 160 KB per object), so large arrays
// of Matrix should not be placed on the stack.
struct Matrix
{
	int n, m, A[N][N];
	// Creates an _n x _m zero matrix (the whole N x N buffer is cleared).
	inline Matrix(int _n = 0, int _m = 0) : n(_n), m(_m) {memset(A, 0, sizeof(A));}
	// Returns (*this) * X modulo Mod, an n x X.m matrix. Requires m == X.n.
	// X is taken by const reference so temporaries and const matrices can be used.
	inline Matrix operator * (const Matrix &X) const
	{
	    Matrix R(n, X.m);
	    for (int i = 0; i < n; i ++)
	        for (int k = 0; k < m; k ++)
	            for (int j = 0; j < X.m; j ++)
	                R[i][j] = (R[i][j] + 1LL * A[i][k] * X[k][j]) % Mod;
	    return (R);
	}
	// Binary exponentiation: returns (*this)^Pw for a square matrix.
	// Pw must be non-negative (a negative Pw never reaches 0 under >>=).
	inline Matrix operator ^ (long long Pw) const
	{
	    Matrix R(n, n), T = * this;
	    for (int i = 0; i < n; i ++)
	        R[i][i] = 1;
	    for (; Pw; Pw >>= 1, T = T * T)
	        if (Pw & 1)
	            R = R * T;
	    return (R);
	}
	// Row access: M[i][j] reads or writes entry (i, j).
	inline int * operator [] (int i)
	{
	    return (A[i]);
	}
	// Read-only row access for const matrices.
	inline const int * operator [] (int i) const
	{
	    return (A[i]);
	}
};
int m, k, X, W[N];
long long n, S[N];
int main()
{
	scanf("%d%d%d%lld", &X, &k, &m, &n);
	for (int i = 1; i <= k; i ++)
	    scanf("%lld", &S[i]);
	for (int i = 1; i <= m; i ++)
	    scanf("%d", &W[i]);
	sort(S + 1, S + k + 1);
	if (n == 0)
	    return !printf("%lld\n", X);
	if (S[k] == n)
	    return !printf("0\n");
	Matrix A(1, m), M[LG];
	for (int i = 0; i < LG; i ++)
	    M[i] = Matrix(m, m);
	A[0][m - 1] = X;
	for (int i = 0; i < m - 1; i ++)
	    M[0][i + 1][i] = 1;
	for (int i = 0; i < m; i ++)
	    M[0][i][m - 1] = W[m - i];
	M[0][0][m - 1] ++;
	for (int i = 1; i < LG; i ++)
	    M[i] = M[i - 1] * M[i - 1];
	for (int i = 1; i <= k; i ++)
	{
	    for (int b = 0; b < LG; b ++)
	        if ((S[i] - S[i - 1] - 1) >> b & 1LL)
	            A = A * M[b];
	    int temp = A[0][0];
	    for (int j = 1; j < m; j ++)
	        A[0][j - 1] = A[0][j];
	    A[0][m - 1] = temp;
	}
	for (int b = 0; b < LG; b ++)
	    if ((n - S[k] - 1) >> b & 1LL)
	        A = A * M[b];
	int Fn = 0;
	for (int i = 0; i < m; i ++)
	    Fn = (Fn + 1LL * A[0][i] * W[m - i]) % Mod;
	return !printf("%d\n", Fn);
}
Tester's Solution
#include <algorithm>
#include <bitset>
#include <complex>
#include <deque>
#include <exception>
#include <fstream>
#include <functional>
#include <iomanip>
#include <ios>
#include <iosfwd>
#include <iostream>
#include <istream>
#include <iterator>
#include <limits>
#include <list>
#include <locale>
#include <map>
#include <memory>
#include <new>
#include <numeric>
#include <ostream>
#include <queue>
#include <set>
#include <sstream>
#include <stack>
#include <stdexcept>
#include <streambuf>
#include <string>
#include <typeinfo>
#include <utility>
#include <valarray>
#include <vector>
#if __cplusplus >= 201103L
#include <array>
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <forward_list>
#include <future>
#include <initializer_list>
#include <mutex>
#include <random>
#include <ratio>
#include <regex>
#include <scoped_allocator>
#include <system_error>
#include <thread>
#include <tuple>
#include <typeindex>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#endif

// Greatest common divisor by the iterative Euclidean algorithm; returns a when b == 0.
int gcd(int a, int b) {
	while (b != 0) {
		const int rem = a % b;
		a = b;
		b = rem;
	}
	return a;
}

using namespace :: std;


//=======================================================================//
#include <iostream>
#include <algorithm>
#include <string>
#include <assert.h>
// Validator helper: reads one decimal integer from stdin that must be followed
// immediately by the character `endd`, and asserts that it lies in [l, r].
// Every format violation triggers an assert:
//   * at most one leading '-', placed before any digit;
//   * no leading zeros, and "-0" is rejected;
//   * at most 19 digits, and a 19-digit value must start with '1';
//   * at least one digit before the terminator; any other character is rejected.
long long readInt(long long l,long long r,char endd){
	long long x=0;
	int cnt=0;
	int fi=-1;
	bool is_neg=false;
	while(true){
		// getchar() returns int; keeping it as int keeps EOF distinct from byte 0xFF.
		int g=getchar();
		if(g=='-'){
			assert(fi==-1 && !is_neg);
			is_neg=true;
			continue;
		}
		if('0'<=g && g<='9'){
			if(cnt==0){
				fi=g-'0';
			}
			cnt++;
			assert(fi!=0 || cnt==1);
			assert(fi!=0 || is_neg==false);
			// Check the length limit BEFORE accumulating so x can never overflow.
			assert(!(cnt>19 || ( cnt==19 && fi>1) ));
			x=x*10+(g-'0');
		} else if(g==static_cast<unsigned char>(endd)){
			assert(cnt>0);
			if(is_neg){
				x= -x;
			}
			assert(l<=x && x<=r);
			return x;
		} else {
			assert(false);
		}
	}
}
// Validator helper: reads characters from stdin up to (not including) the
// terminator `endd` and asserts the length lies in [l, r]. Hitting end of
// input before the terminator is an assertion failure.
std::string readString(int l,int r,char endd){
	std::string ret="";
	int cnt=0;
	while(true){
		// Keep getchar()'s result as int: with a char, EOF is either confused with
		// byte 0xFF (signed char) or never detected at all (unsigned char).
		int g=getchar();
		assert(g!=EOF);
		if(g==static_cast<unsigned char>(endd)){
			break;
		}
		cnt++;
		ret+=static_cast<char>(g);
	}
	assert(l<=cnt && cnt<=r);
	return ret;
}
// Wrappers around readInt/readString with a fixed terminator: the *Sp
// variants expect the token to be followed by a space, the *Ln variants
// by a newline.
long long readIntSp(long long l,long long r){
	const char separator=' ';
	return readInt(l,r,separator);
}
long long readIntLn(long long l,long long r){
	const char separator='\n';
	return readInt(l,r,separator);
}
string readStringLn(int l,int r){
	const char separator='\n';
	return readString(l,r,separator);
}
string readStringSp(int l,int r){
	const char separator=' ';
	return readString(l,r,separator);
}
//=======================================================================//



#define ll long long
#define pb push_back
#define ld long double
#define mp make_pair
#define F first
#define S second
#define pii pair<ll,ll>

using namespace :: std;

const ll maxn=202;
const ll mod=1e9+7;
// Written as an integer expression: the former double expression 1e18+9 was
// rounded to exactly 1e18, silently dropping the +9.
const ll inf=(ll)1e18+9;

// Input weights, filled 0-indexed by main().
ll a[maxn];
// Dense n x m matrix of ints; products are reduced modulo mod.
class M{
	public:
		int n,m;
		vector<vector<int> > a;
		// Builds an n x m zero matrix.
		M(int n=0,int m=0) : n(n), m(m), a(n, vector<int>(m, 0)) {}
		// Returns (*this) * b modulo mod. On a dimension mismatch it prints
		// "RIDI" and exits the program.
		M zarb(const M &b){
			if(m!=b.n){
				cout<<"RIDI";
				exit(0);
			}
			M ans(n,b.m);
			for(int row=0;row<n;row++){
				vector<int> &out=ans.a[row];
				for(int mid=0;mid<m;mid++){
					const ll coef=a[row][mid];
					const vector<int> &rhs=b.a[mid];
					for(int col=0;col<b.m;col++){
						out[col]=(out[col]+coef*rhs[col])%mod;
					}
				}
			}
			return ans;
		}
};

// Binary powers of the transition matrix: pre[i] = base^(2^i), filled in main().
M pre[61];

// Element-wise sum a + b modulo mod (a is taken by value, updated and returned).
M jam(M a,const M &b){
	for(ll r=0;r<a.n;r++){
		vector<int> &dst=a.a[r];
		const vector<int> &src=b.a[r];
		for(ll c=0;c<a.m;c++){
			dst[c]+=src[c];
			if(dst[c]>=mod)dst[c]-=mod;
		}
	}
	return a;
}
// a^n by repeated squaring; a must be square (otherwise prints "RIDI" and exits).
M tavan(M a,ll n){
	if(a.n!=a.m){
		cout<<"RIDI";
		exit(0);
	}
	M ans(a.n,a.n);
	for(ll d=0;d<a.n;d++)ans.a[d][d]=1;
	for(;n;n>>=1){
		if(n&1)ans=ans.zarb(a);
		a=a.zarb(a);
	}
	return ans;
}
// a * base^x, using the precomputed powers pre[bit] for each set bit of x (bits 0..60).
M zarbfast(M a,ll x){
	for(ll bitIdx=0;bitIdx<61;bitIdx++)
		if(x>>bitIdx&1) a=a.zarb(pre[bitIdx]);
	return a;
}

int main(){
	ios_base::sync_with_stdio(0);cin.tie(0);cout.tie(0);
	ll x,k,m,n;
	x=readIntSp(0,mod-1);
	k=readIntSp(0,200);
	m=readIntSp(0,200);
	n=readIntLn(0,(ll)1e18);
	if(n==0){
		cout<<x;
		return 0;
	}
	vector<ll> vec;
	for(ll i=0;i<k;i++){
		ll x;
		if(i<k-1){
	        x=readIntSp(1,n);
		}
		else{
	        x=readIntLn(1,n);
		}
		if(x==n){
	        cout<<0<<endl;
	        exit(0);
		}
		vec.pb(x);
	}
	sort(vec.begin(),vec.end());
	for(ll i=0;i<m;i++){
		if(i<m-1){
	        a[i]=readIntSp(0,mod-1);
		}
		else{
	        a[i]=readIntLn(0,mod-1);
		}
	}
	M base(m,m);
	M base2(m,m);
	for(ll i=0;i<m;i++){
		base.a[i][m-1]+=a[m-i-1];
		base.a[(i+1)%m][i]++;
		base2.a[(i+1)%m][i]++;
	}


	M avalie(1,m);
	avalie.a[0][m-1]=x;

	pre[0]=base;
	for(ll i=1;i<61;i++){
		pre[i]=pre[i-1].zarb(pre[i-1]);
	}
	while((ll)vec.size() && vec.back()>=n){
	    vec.pop_back();
	}
	ll NOWW=0;
	for(ll i=0;i<(ll)vec.size();i++){
		avalie=zarbfast(avalie,vec[i]-NOWW-1);
		avalie=avalie.zarb(base2);
		NOWW=vec[i];
	}

	avalie=zarbfast(avalie, n-1-NOWW);
	ll ans=0;
	for(ll i=0;i<m;i++)
	    ans=(ans+(ll)avalie.a[0][i]*a[m-1-i])%mod;
	cout<<ans<<endl;
}
Editorialist's Solution
import java.util.*;
import java.io.*;
import java.text.*;
class DFNC{
	//SOLUTION BEGIN
	void pre() throws Exception{}
	/**
	 * Reads X, K, M, N, the K elements of S and the M weights, and prints F(N) mod 1e9+7.
	 * State vector v = [T(c), T(c-1), ..., T(c-m+1)], where T(i) = F(i) + T(i-m)
	 * (T(i) = 0 for i < 0) and c is the last processed index.
	 */
	void solve(int TC) throws Exception{
	    int x = ni();
	    int k = ni(), m = ni();
	    long n = nl();
	    long[] s = new long[k];
	    int[] w = new int[m];
	    for(int i = 0; i< k; i++)s[i] = nl();
	    for(int i = 0; i< m; i++)w[i] = (int)(nl()%mod);
	    Arrays.sort(s);
	    // F(0) = X is given directly; an N that belongs to S is forced to 0.
	    if(n == 0 || (k>0 && s[k-1] == n)){
	        if(n == 0)pn(x);
	        else pn(0);
	        return;
	    }
	    
	    // M:  regular step, T(c+1) = sum_{p=1..m} w[p-1]*T(c+1-p) + T(c+1-m).
	    // M0: step over an index in S, where F = 0 and therefore T(c+1) = T(c+1-m).
	    int[][] M = new int[m][m], M0 = new int[m][m];
	    for(int i = 1; i< m; i++){
	        M[i][i-1] = 1;
	        M0[i][i-1] = 1;
	    }
	    for(int i = 0; i< m; i++)M[0][i] = w[i];
	    M[0][m-1]++;
	    M0[0][m-1]++;
	    
	    // A[b] = M^(2^b) for b = 0..60.
	    int[][][] A = generateP2(M, 60);
	    long ans = 0;
	    
	    // Initially c = 0: T(0) = X and T(i) = 0 for i < 0.
	    int[] v = new int[m];
	    v[0] = x;
	    
	    long pre = 0;
	    for(int i = 0; i< k; i++){
	        // Advance to T(s[i]-1), then take the forced-zero step to T(s[i]).
	        // A repeated value (s[i] == pre) is skipped.
	        if(s[i]-pre-1 > 0)v = pow(A, v, s[i]-pre-1);
	        if(s[i]-pre > 0)v = mul(M0, v);
	        pre = s[i];
	    }
	    // Advance to T(n-1); then F(n) = sum_p w[p-1]*T(n-p).
	    v = pow(A, v, n-pre-1);
	    for(int i = 0; i< m; i++){
	        ans += (long)v[i]*w[i];
	        if(ans >= BIG)ans -= BIG;
	    }
	    ans %= mod;
	    pn(ans);
	}
	//Shamelessly copied template
	///////// begin
	public static final int mod = 1000000007;
	public static final long m2 = (long)mod*mod;
	// Large multiple of mod^2 below Long.MAX_VALUE: subtracting it keeps
	// accumulators non-negative and in range without changing the value mod `mod`.
	public static final long BIG = 8L*m2;

	// A^e*v (squares A on the fly; entries are assumed non-negative)
	public static int[] pow(int[][] A, int[] v, long e)
	{
		for(int i = 0;i < v.length;i++){
			if(v[i] >= mod)v[i] %= mod;
		}
		int[][] MUL = A;
		for(;e > 0;e>>>=1) {
			if((e&1)==1)v = mul(MUL, v);
			MUL = p2(MUL);
		}
		return v;
	}

	// int matrix*int vector
	public static int[] mul(int[][] A, int[] v)
	{
		int m = A.length;
		int n = v.length;
		int[] w = new int[m];
		for(int i = 0;i < m;i++){
			long sum = 0;
			for(int k = 0;k < n;k++){
				sum += (long)A[i][k] * v[k];
				if(sum >= BIG)sum -= BIG;
			}
			w[i] = (int)(sum % mod);
		}
		return w;
	}

	// int matrix*int matrix (be careful about negative value).
	// Without this overload, mul(int[][]...) below resolved mul(base, a[i])
	// back to itself and recursed forever.
	public static int[][] mul(int[][] A, int[][] B)
	{
		int n = A.length;
		int inner = B.length;
		int p = inner == 0 ? 0 : B[0].length;
		int[][] C = new int[n][p];
		for(int i = 0;i < n;i++){
			long[] sum = new long[p];
			for(int k = 0;k < inner;k++){
				long aik = A[i][k];
				for(int j = 0;j < p;j++){
					sum[j] += aik * B[k][j];
					if(sum[j] >= BIG)sum[j] -= BIG;
				}
			}
			for(int j = 0;j < p;j++){
				C[i][j] = (int)(sum[j] % mod);
			}
		}
		return C;
	}

	// int matrix^2 (be careful about negative value)
	public static int[][] p2(int[][] A)
	{
		int n = A.length;
		int[][] C = new int[n][n];
		for(int i = 0;i < n;i++){
			long[] sum = new long[n];
			for(int k = 0;k < n;k++){
				for(int j = 0;j < n;j++){
					sum[j] += (long)A[i][k] * A[k][j];
					if(sum[j] >= BIG)sum[j] -= BIG;
				}
			}
			for(int j = 0;j < n;j++){
				C[i][j] = (int)(sum[j] % mod);
			}
		}
		return C;
	}

	//////////// end

	// ret[n]=A^(2^n)
	public static int[][][] generateP2(int[][] A, int n)
	{
		int[][][] ret = new int[n+1][][];
		ret[0] = A;
		for(int i = 1;i <= n;i++)ret[i] = p2(ret[i-1]);
		return ret;
	}

	// A[0]^e*v
	// A[n]=A[0]^(2^n); requires e < 2^(A.length)
	public static int[] pow(int[][][] A, int[] v, long e)
	{
		for(int i = 0;e > 0;e>>>=1,i++) {
			if((e&1)==1)v = mul(A[i], v);
		}
		return v;
	}

	// Product a[0] * a[1] * ... * a[a.length-1], evaluated left to right.
	public static int[][] mul(int[][]... a)
	{
		int[][] base = a[0];
		for(int i = 1;i < a.length;i++){
			base = mul(base, a[i]);
		}
		return base;
	}
	//SOLUTION END
	void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
	DecimalFormat df = new DecimalFormat("0.00000000000");
	static boolean multipleTC = false;
	FastReader in;PrintWriter out;
	void run() throws Exception{
	    in = new FastReader();
	    out = new PrintWriter(System.out);
	    //Solution Credits: Taranpreet Singh
	    int T = (multipleTC)?ni():1;
	    pre();for(int t = 1; t<= T; t++)solve(t);
	    out.flush();
	    out.close();
	}
	public static void main(String[] args) throws Exception{
	    new DFNC().run();
	}
	int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
	void p(Object o){out.print(o);}
	void pn(Object o){out.println(o);}
	void pni(Object o){out.println(o);out.flush();}
	String n()throws Exception{return in.next();}
	String nln()throws Exception{return in.nextLine();}
	int ni()throws Exception{return Integer.parseInt(in.next());}
	long nl()throws Exception{return Long.parseLong(in.next());}
	double nd()throws Exception{return Double.parseDouble(in.next());}

	class FastReader{
	    BufferedReader br;
	    StringTokenizer st;
	    public FastReader(){
	        br = new BufferedReader(new InputStreamReader(System.in));
	    }

	    public FastReader(String s) throws Exception{
	        br = new BufferedReader(new FileReader(s));
	    }

	    String next() throws Exception{
	        while (st == null || !st.hasMoreElements()){
	            try{
	                st = new StringTokenizer(br.readLine());
	            }catch (IOException  e){
	                throw new Exception(e.toString());
	            }
	        }
	        return st.nextToken();
	    }

	    String nextLine() throws Exception{
	        String str = "";
	        try{   
	            str = br.readLine();
	        }catch (IOException e){
	            throw new Exception(e.toString());
	        }  
	        return str;
	    }
	}
}

Feel free to share your approach. Suggestions are welcomed as always. :slight_smile: