YACHEFNUM - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: Yash Thakker
Testers: Satyam, Jatin Garg
Editorialist: Nishank Suresh

DIFFICULTY:

2820

PREREQUISITES:

Digit DP

PROBLEM:

The rating of a number is defined to be the number of trailing zeros in the product of its digits (and 1 if this product is 0).

Given two large numbers A and B, find the sum of the ratings of all numbers from A+1 to B, modulo 10^9+7.

EXPLANATION:

A and B are rather large numbers: they can have up to 1000 digits each, so they certainly cannot be read into any built-in integer type in C++.

When the numbers involved are huge and you need to compute some function summed over all numbers \leq n, the solution is very often digit DP. You can read an introduction to the technique here. Some familiarity with it will be assumed below.

In our case, let’s define f(N) to be the sum of the ratings of all numbers from 1 to N. Then our answer is f(B) - f(A), so we just need to compute f(N) for a given N fast enough.

Note that the rating of a number depends only on two things:

  • The number of twos and fives present as factors of its digits.
  • Whether the number contains 0 or not. If it does, the rating is always going to be 1.

In particular, for a number that doesn’t contain a zero, its rating is simply the minimum of the number of factors of 2 and 5 present among its digits.
For now, let’s consider only numbers that don’t contain 0; the ones that do are easy to handle separately later.
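For testing any of the DPs discussed here, a brute force taken straight from the definition is handy. This is a minimal sketch: `rating` and `bruteRange` are hypothetical names, and since it works on 64-bit integers rather than 1000-digit strings, it is only meant for small inputs.

```cpp
#include <algorithm>
#include <cassert>

// Rating straight from the definition: 1 if some digit is 0, otherwise
// min(#factors of 2, #factors of 5) over all digits.
int rating(unsigned long long x) {
    int twos = 0, fives = 0;
    bool hasZero = false;
    do {
        int d = static_cast<int>(x % 10);
        if (d == 0) hasZero = true;
        if (d == 2 || d == 6) twos += 1;
        if (d == 4) twos += 2;
        if (d == 8) twos += 3;
        if (d == 5) fives += 1;
        x /= 10;
    } while (x > 0);
    return hasZero ? 1 : std::min(twos, fives);
}

// f(b) - f(a): sum of ratings over [a+1, b] by direct enumeration (small inputs only).
long long bruteRange(unsigned long long a, unsigned long long b) {
    assert(a <= b);
    long long total = 0;
    for (unsigned long long x = a + 1; x <= b; ++x) total += rating(x);
    return total;
}
```

For example, bruteRange(0, 100) should be 18: the eight zero-free numbers 25, 45, 65, 85, 52, 54, 56, 58 contribute 1 each, and so do the ten numbers 10, 20, ..., 90, 100 that contain a 0.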

There’s a fairly simple function that comes to mind immediately:
(1-based indexing is used below)
Let g(i, x, y) denote the answer if we are going to place the i-th digit of the number, and there are x powers of 2 and y powers of 5 so far.
Transitions are simple: iterate over which digit d from 1 to 9 is going to be placed here while ensuring that you don’t exceed N (recall that we are ignoring 0), calculate the new values of x and y from d (call them x' and y' respectively), and then recurse to g(i+1, x', y') and add it to the answer.
The base case is of course g(L+1, x, y) = \min(x, y) where L is the length of the string N, and the final answer is g(1, 0, 0).

By memoizing the values of g(i, x, y) (plus the usual boolean flags recording whether the prefix is still tight against N and whether a non-zero digit has been placed yet), this solution runs in \mathcal{O}(L^3 \cdot 10), since x can reach 3L and y can reach L.
However, L can be as large as 10^3 here, so this is too slow.
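A sketch of this naive g(i, x, y) as memoized recursion, assuming N has no leading zeros and is small enough that the answer fits in a `long long` without a modulus. `NaiveDP` is a hypothetical name, positions are 0-based (so the base case is i = L), and the `tight`/`started` flags are the extra boolean states mentioned above.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Naive O(L^3 * 10) digit DP: g(i, x, y, tight, started) = sum of min(x', y')
// over zero-free completions that stay <= N. No modulus: small N only.
class NaiveDP {
public:
    explicit NaiveDP(const std::string& n)
        : N(n), L(static_cast<int>(n.size())),
          memo(static_cast<std::size_t>(L + 1) * (3 * L + 1) * (L + 1) * 4, -1) {}

    // Sum of ratings of the zero-free numbers in [1, N].
    long long solve() { return g(0, 0, 0, 1, 0); }

private:
    std::string N;
    int L;
    std::vector<long long> memo;  // -1 marks "not computed yet"

    static int twos(int d) { return d == 8 ? 3 : d == 4 ? 2 : (d == 2 || d == 6) ? 1 : 0; }
    static int fives(int d) { return d == 5 ? 1 : 0; }

    long long g(int i, int x, int y, int tight, int started) {
        if (i == L) return started ? std::min(x, y) : 0;  // base case: min(x, y)
        std::size_t key =
            ((static_cast<std::size_t>(i) * (3 * L + 1) + x) * (L + 1) + y) * 4 + tight * 2 + started;
        if (memo[key] != -1) return memo[key];
        long long res = 0;
        // Still in the leading zeros: the number becomes shorter, hence < N.
        if (!started) res += g(i + 1, 0, 0, 0, 0);
        int hi = tight ? N[i] - '0' : 9;
        for (int d = 1; d <= hi; ++d)  // digit 0 is never placed after the start
            res += g(i + 1, x + twos(d), y + fives(d), tight && d == hi, 1);
        return memo[key] = res;
    }
};
```

For instance, NaiveDP("125").solve() should return 9: the eight two-digit numbers listed earlier plus 125 itself.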

To speed it up, we use a fairly common DP optimization: shrink the state by dropping information the answer does not depend on.
In particular, it is unnecessary to know both x and y; we only need to know whether there are more 2-s or more 5-s, and by how many. This can be stored as a boolean and an integer, which reduces the number of states by a factor of L.
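A short derivation of why the surplus suffices (my own reasoning, not taken from the setter's code): suppose the prefix so far has x twos and y fives, and the rest of the number adds a twos and b fives. Then

```latex
\min(x + a,\; y + b) \;=\; m + \min\big((x - m) + a,\; (y - m) + b\big), \qquad m = \min(x, y)
```

Since one of x - m and y - m is zero, the second term depends only on the signed surplus x - y and on the suffix. The m pairs already matched add m once for every completion, which is why the reduced DP has to track the number of ways to reach a state as well as its answer.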

However, making this change also necessitates a change in the definition of our g, since the parameters are different now.
It’s somewhat hard to explain properly with words, and I recommend looking at the setter’s code linked below for what to do. However, the gist is as follows:

  • g(state) now stores a pair: the answer for that state, and the number of ways to reach that state. Let’s denote these by ans(state) and ways(state) respectively.
  • If the current state has more 2-s than 5-s, and the current digit adds x more 2-s (where x \gt 0), nothing much needs to be done: the difference between the number of twos and fives grows by x.
  • Similarly, if the current state has more 5-s and the current digit is 5, the difference simply grows by 1.
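  • Now, suppose the state has more 2-s and we add a 5. Then,
    • The difference decreases by 1, because the new 5 pairs up with one of the surplus 2-s. (If the difference was already zero, there is nothing to pair it with, so we instead move to the state with one surplus 5 and add nothing extra.)
    • Let state_2 be the state we recurse into. Besides ans(state_2), we need to add ways(state_2) to ans(state) to account for this extra five: every completion counted in state_2 gains one trailing zero from the new pair.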
  • Something similar is done when there are currently more 5-s and the current digit is even. Here \{2, 6\}, \{4\} and \{8\} must be treated as separate cases, since they add 1, 2 and 3 twos respectively; each added 2 pairs with a surplus 5 while any remain, and the leftover 2-s (if any) become the new surplus.
  • ways(state) is the sum of ways(state_2) across all states that are recursed into.
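The transitions in the bullets above can be packed into one helper. In this sketch (hypothetical names `twos`, `fives`, `step`) the boolean and the integer are stored together as a signed surplus, similar to the `cnt` parameter in the rivalq tester's code. It returns the new surplus and the number of newly matched pairs, which is the multiplier of ways(state_2) added to ans(state).

```cpp
#include <algorithm>
#include <cassert>
#include <utility>

// Factors of 2 and of 5 contributed by one non-zero digit d.
int twos(int d) { return d == 8 ? 3 : d == 4 ? 2 : (d == 2 || d == 6) ? 1 : 0; }
int fives(int d) { return d == 5 ? 1 : 0; }

// State = signed surplus: diff > 0 means diff unmatched 2s, diff < 0 means
// -diff unmatched 5s. Returns {new diff, pairs newly matched by appending d}.
std::pair<int, int> step(int diff, int d) {
    assert(1 <= d && d <= 9);
    int t = std::max(diff, 0) + twos(d);    // 2s available for matching
    int f = std::max(-diff, 0) + fives(d);  // 5s available for matching
    int m = std::min(t, f);                 // these pair up into new factors of 10
    return {t - f, m};
}
```

For example, with one surplus 5 (diff = -1), appending 8 matches one pair and leaves two surplus 2s, so step(-1, 8) gives {2, 1}.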

This brings us down to something like \mathcal{O}(L^2 \cdot 8) states, and 10 transitions from each, which is fast enough for L = 1000.
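Putting the pieces together, here is a forward (iterative) sketch of the reduced DP for zero-free numbers \leq N. It is not the setter's memoized recursion: it keeps one signed surplus instead of a boolean and an integer, tracks the tight prefix as a single running state, and does not cap the surplus at 1000. The modulus 10^9+7 is taken from the solutions below; `zeroFreeSum` is a hypothetical name. The work is about L \cdot 4L \cdot 9 transitions.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

const long long MOD = 1000000007LL;  // modulus used by the solutions below

static int twos(int d) { return d == 8 ? 3 : d == 4 ? 2 : (d == 2 || d == 6) ? 1 : 0; }
static int fives(int d) { return d == 5 ? 1 : 0; }

// Sum (mod MOD) of ratings of the zero-free numbers in [1, N], N without
// leading zeros. Iterative digit DP over diff = (#2s - #5s) still unmatched.
long long zeroFreeSum(const std::string& N) {
    const int L = static_cast<int>(N.size());
    const int OFF = L, SZ = 4 * L + 1;  // diff lies in [-L, 3L]
    // ways[k], ans[k]: prefixes already strictly below N's prefix, diff = k - OFF.
    std::vector<long long> ways(SZ, 0), ans(SZ, 0);
    int tDiff = 0;       // diff of the tight prefix (equal to N so far)
    long long tAns = 0;  // pairs already matched inside the tight prefix
    bool tAlive = true;  // false once N's prefix contains a 0

    // Append digit d to w prefixes that have surplus `diff` and total answer a.
    auto push = [&](std::vector<long long>& W, std::vector<long long>& A,
                    int diff, int d, long long w, long long a) {
        int t = std::max(diff, 0) + twos(d), f = std::max(-diff, 0) + fives(d);
        int m = std::min(t, f);  // pairs newly matched by d
        int k = t - f + OFF;
        W[k] = (W[k] + w) % MOD;
        A[k] = (A[k] + a + m * w) % MOD;  // each of the w prefixes gains m
    };

    for (int i = 0; i < L; ++i) {
        std::vector<long long> nWays(SZ, 0), nAns(SZ, 0);
        const int nd = N[i] - '0';
        for (int k = 0; k < SZ; ++k)  // free prefixes take any digit 1..9
            if (ways[k] != 0 || ans[k] != 0)
                for (int d = 1; d <= 9; ++d) push(nWays, nAns, k - OFF, d, ways[k], ans[k]);
        if (tAlive) {
            for (int d = 1; d < nd; ++d) push(nWays, nAns, tDiff, d, 1, tAns);  // drop below N
            if (nd == 0) {
                tAlive = false;  // a zero-free number cannot copy this 0
            } else {
                int t = std::max(tDiff, 0) + twos(nd), f = std::max(-tDiff, 0) + fives(nd);
                tAns += std::min(t, f);
                tDiff = t - f;
            }
        }
        if (i > 0)  // a shorter number whose first digit sits at position i
            for (int d = 1; d <= 9; ++d) push(nWays, nAns, 0, d, 1, 0);
        ways.swap(nWays);
        ans.swap(nAns);
    }
    long long total = tAlive ? tAns % MOD : 0;  // N itself, if zero-free
    for (int k = 0; k < SZ; ++k) total = (total + ans[k]) % MOD;
    return total;
}
```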

Now, all that remains is to count the contribution of numbers that contain a zero.
Since each such number has rating exactly 1, their contribution is just the count of numbers \leq N that contain a zero, which is not hard to compute.
For example, you can do the following:

  • For each i from 1 to L, suppose the first time you go below N is at position i. So, the first i-1 digits match those of N, and the last L-i digits are completely free.
  • Suppose you place a digit d smaller than the i-th digit of N at this position.
    • If d = 0 or the first i-1 digits contain a zero, the remaining digits can be absolutely anything, so add 10^{L-i} to the answer.
    • Otherwise, you need the number of strings of length L-i that contain a zero, which is 10^{L-i} - 9^{L-i}; it can also be precomputed with another dp.

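Special care must be taken when placing d = 0 at i = 1, since the above discussion only applies when the number has a non-zero prefix.
However, placing 0 at i = 1 is the same as counting all numbers with at most L-1 digits that contain a zero, which can again be precomputed (there are 9 \cdot 10^{k-1} - 9^k such numbers with exactly k digits). Finally, the steps above only count numbers strictly below N, so count N itself separately if it contains a zero.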
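The counting procedure above, as a sketch with 0-based positions (hypothetical name `withZeroCount`, modulus 10^9+7 assumed from the solutions below). It uses the closed forms 10^k - 9^k and 9 \cdot 10^{k-1} - 9^k in place of a separate dp.

```cpp
#include <cassert>
#include <string>
#include <vector>

const long long MOD = 1000000007LL;  // modulus used by the solutions below

// Count (mod MOD) of integers in [1, N] whose decimal form contains a 0.
// Each such number has rating 1, so this is also their total contribution.
long long withZeroCount(const std::string& N) {
    const int L = static_cast<int>(N.size());
    std::vector<long long> p10(L + 1, 1), p9(L + 1, 1);
    for (int k = 1; k <= L; ++k) {
        p10[k] = p10[k - 1] * 10 % MOD;
        p9[k] = p9[k - 1] * 9 % MOD;
    }
    long long res = 0;
    // Numbers with k < L digits: 9 * 10^(k-1) in total, 9^k of them zero-free.
    for (int k = 1; k < L; ++k)
        res = (res + 9 * p10[k - 1] - p9[k] + MOD) % MOD;
    // L-digit numbers that first drop below N at position i with digit d < N[i].
    bool prefixHasZero = false;
    for (int i = 0; i < L; ++i) {
        int rest = L - i - 1;  // number of free digits after position i
        for (int d = (i == 0 ? 1 : 0); d < N[i] - '0'; ++d) {
            if (prefixHasZero || d == 0) res = (res + p10[rest]) % MOD;
            else res = (res + p10[rest] - p9[rest] + MOD) % MOD;
        }
        if (N[i] == '0') prefixHasZero = true;
    }
    if (prefixHasZero) res = (res + 1) % MOD;  // N itself
    return res;
}
```

For N = 125 this should give 21: the nine numbers 10, ..., 90, the eleven numbers 100, ..., 110, and 120.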

Note that this is not the only way to solve the problem: I noticed that many contestants (and even the setters) came up with quite different solutions, some perhaps easier than the one outlined here.
If you would like to share your solution in the comments below, feel free to do so!

TIME COMPLEXITY:

\mathcal{O}(L^2 \cdot 10) per test case, with a constant factor of around 8.

CODE:

Setter's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define ull unsigned long long 
#define pb(e) push_back(e)
#define sv(a) sort(a.begin(),a.end())
#define sa(a,n) sort(a,a+n)
#define mp(a,b) make_pair(a,b)
#define vf first
#define vs second
#define ar array
#define all(x) x.begin(),x.end()
const int inf = 0x3f3f3f3f;
const int mod = 1000000007; 
const double PI=3.14159265358979323846264338327950288419716939937510582097494459230;
bool remender(ll a , ll b){return a%b;}

//freopen("problemname.in", "r", stdin);
//freopen("problemname.out", "w", stdout);

vector<int> v;

pair<ll,ll> dp[1002][1002][2][2][2];
ll dp1[1002][2][2][2];

ll zeros(int i , int start , int tight , int came){
	if(i == v.size())return came;
	if(dp1[i][start][tight][came] != -1)return dp1[i][start][tight][came];
	int last = 9;
	if(tight == 1)last = v[i];
	ll res = 0;
	for(int j = 0; j <= last; j++){
		if(j == 0){
			if(start == 1){
				res += zeros(i + 1 , start , 0 , 0);
				res %= mod;
			}
			else {
				res += zeros(i + 1 , 0 , (j == last ? tight : 0) , 1);
				res %= mod;
			}
		}
		else {
			res += zeros(i + 1 , 0 , (j == last ? tight : 0) , came);
			res %= mod;
		}
	}
	return dp1[i][start][tight][came] = res;
}

pair<ll,ll> rec(int i , int cnt , int which , int start, int tight){
	if(i == (v.size())){
		return mp(1,0);
	}
	if(dp[i][cnt][which][start][tight].vf != -1)return dp[i][cnt][which][start][tight];
	int last = 9;
	if(tight == 1)last = v[i];
	if(last == 0){
		return mp(0,0);
	}
	ll res = 0, ways = 0;
	if(start == 1){
		pair<ll,ll> p = rec(i + 1 , cnt , which , start , 0);
		res += p.vs;
		ways += p.vf;
	}
	for(int j = 1; j <= last; j++){
		int c = cnt , add = 0 , wh = which;
		if(j % 2 == 0){
			if(c == 0){
				if(j == 2 || j == 6)c = 1;
				else if(j == 4)c = 2;
				else if(j == 8)c = 3;
				wh = 0;
			}
			else {
				if(which == 0){
					if(j == 2 || j == 6)c++;
					else if(j == 4)c+=2;
					else if(j == 8)c += 3;
				}
				else {
					add++;
					c--;
					if(j == 4 || j == 8){
						if(c == 0){
							wh = 0;
							if(j == 4)c = 1;
							else c = 2;
						}
						else {
							add++;
							c--;
							if(j == 8){
								if(c == 0){
									wh = 0;
									c = 1;
								}
								else {
									c--;
									add++;
								}
							}
						}
					}
				}
			}
		}
		if(j == 5){
			if(c == 0){
				c = 1;
				wh = 1;
			}
			else {
				if(which == 1)c++;
				else {
					add++;
					c--;
				}
			}
		}
		c = min(c , 1000);
		if(j == v[i]){
			pair<ll,ll> p = rec(i + 1 , c , wh , 0 , tight);
			res += add*p.vf+p.vs;
			res %= mod;
			ways += p.vf;
			ways %= mod;
		}
		else {
			pair<ll,ll> p = rec(i + 1 , c , wh , 0 , 0);
			res += add*p.vf + p.vs;
			res %= mod;
			ways += p.vf;
			ways %= mod;
		}
	}
	return dp[i][cnt][which][start][tight] = mp(ways , res);
}

void solve(){
	string s;
	getline(cin , s);
	int n = s.size();
	ll ans = 0;
	for(int i = 0; i < n; i++){
		if(s[i] == ' '){
			memset(dp,-1,sizeof dp);
			memset(dp1,-1,sizeof dp1);
			ans -= (rec(0,0,0,1,1).vs + zeros(0,1,1,0))%mod;
			v.clear();
			continue;
		}
		v.pb(s[i]-'0');
	}
	memset(dp,-1,sizeof dp);
	memset(dp1,-1,sizeof dp1);
	ans += (rec(0,0,0,1,1).vs + zeros(0,1,1,0))%mod;
	ans += mod;
	ans %= mod;
	cout << ans << '\n';
}

int main(){
ios_base::sync_with_stdio(false);
cin.tie(NULL);
	//int t;cin >> t;while(t--)
	solve();
	return 0;
}
Tester (rivalq)'s code (C++)
// Jai Shree Ram  
  
#include<bits/stdc++.h>
using namespace std;



#define rep(i,a,n)     for(int i=a;i<n;i++)
#define ll             long long
#define int            long long
#define pb             push_back
#define all(v)         v.begin(),v.end()
#define endl           "\n"
#define x              first
#define y              second
#define gcd(a,b)       __gcd(a,b)
#define mem1(a)        memset(a,-1,sizeof(a))
#define mem0(a)        memset(a,0,sizeof(a))
#define sz(a)          (int)a.size()
#define pii            pair<int,int>
#define hell           1000000007
#define elasped_time   1.0 * clock() / CLOCKS_PER_SEC



template<typename T1,typename T2>istream& operator>>(istream& in,pair<T1,T2> &a){in>>a.x>>a.y;return in;}
template<typename T1,typename T2>ostream& operator<<(ostream& out,pair<T1,T2> a){out<<a.x<<" "<<a.y;return out;}
template<typename T,typename T1>T maxs(T &a,T1 b){if(b>a)a=b;return a;}
template<typename T,typename T1>T mins(T &a,T1 b){if(b<a)a=b;return a;}


const int MOD = hell;
 
struct mod_int {
    int val;
 
    mod_int(long long v = 0) {
        if (v < 0)
            v = v % MOD + MOD;
 
        if (v >= MOD)
            v %= MOD;
 
        val = v;
    }
 
    static int mod_inv(int a, int m = MOD) {
        int g = m, r = a, x = 0, y = 1;
 
        while (r != 0) {
            int q = g / r;
            g %= r; swap(g, r);
            x -= q * y; swap(x, y);
        }
 
        return x < 0 ? x + m : x;
    }
 
    explicit operator int() const {
        return val;
    }
 
    mod_int& operator+=(const mod_int &other) {
        val += other.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
 
    mod_int& operator-=(const mod_int &other) {
        val -= other.val;
        if (val < 0) val += MOD;
        return *this;
    }
 
    static unsigned fast_mod(uint64_t x, unsigned m = MOD) {
           #if !defined(_WIN32) || defined(_WIN64)
                return x % m;
           #endif
           unsigned x_high = x >> 32, x_low = (unsigned) x;
           unsigned quot, rem;
           asm("divl %4\n"
            : "=a" (quot), "=d" (rem)
            : "d" (x_high), "a" (x_low), "r" (m));
           return rem;
    }
 
    mod_int& operator*=(const mod_int &other) {
        val = fast_mod((uint64_t) val * other.val);
        return *this;
    }
 
    mod_int& operator/=(const mod_int &other) {
        return *this *= other.inv();
    }
 
    friend mod_int operator+(const mod_int &a, const mod_int &b) { return mod_int(a) += b; }
    friend mod_int operator-(const mod_int &a, const mod_int &b) { return mod_int(a) -= b; }
    friend mod_int operator*(const mod_int &a, const mod_int &b) { return mod_int(a) *= b; }
    friend mod_int operator/(const mod_int &a, const mod_int &b) { return mod_int(a) /= b; }
 
    mod_int& operator++() {
        val = val == MOD - 1 ? 0 : val + 1;
        return *this;
    }
 
    mod_int& operator--() {
        val = val == 0 ? MOD - 1 : val - 1;
        return *this;
    }
 
    mod_int operator++(int32_t) { mod_int before = *this; ++*this; return before; }
    mod_int operator--(int32_t) { mod_int before = *this; --*this; return before; }
 
    mod_int operator-() const {
        return val == 0 ? 0 : MOD - val;
    }
 
    bool operator==(const mod_int &other) const { return val == other.val; }
    bool operator!=(const mod_int &other) const { return val != other.val; }
 
    mod_int inv() const {
        return mod_inv(val);
    }
 
    mod_int pow(long long p) const {
        assert(p >= 0);
        mod_int a = *this, result = 1;
 
        while (p > 0) {
            if (p & 1)
                result *= a;
 
            a *= a;
            p >>= 1;
        }
 
        return result;
    }
 
    friend ostream& operator<<(ostream &stream, const mod_int &m) {
        return stream << m.val;
    }
    friend istream& operator >> (istream &stream, mod_int &m) {
        return stream>>m.val;   
    }
};

// -------------------- Input Checker Start --------------------
 
long long readInt(long long l, long long r, char endd)
{
    long long x = 0;
    int cnt = 0, fi = -1;
    bool is_neg = false;
    while(true)
    {
        char g = getchar();
        if(g == '-')
        {
            assert(fi == -1);
            is_neg = true;
            continue;
        }
        if('0' <= g && g <= '9')
        {
            x *= 10;
            x += g - '0';
            if(cnt == 0)
                fi = g - '0';
            cnt++;
            assert(fi != 0 || cnt == 1);
            assert(fi != 0 || is_neg == false);
            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
        }
        else if(g == endd)
        {
            if(is_neg)
                x = -x;
            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(false);
            }
            return x;
        }
        else
        {
            assert(false);
        }
    }
}
 
string readString(int l, int r, char endd)
{
    string ret = "";
    int cnt = 0;
    while(true)
    {
        char g = getchar();
        assert(g != -1);
        if(g == endd)
            break;
        cnt++;
        ret += g;
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}
 
long long readIntSp(long long l, long long r) { return readInt(l, r, ' '); }
long long readIntLn(long long l, long long r) { return readInt(l, r, '\n'); }
string readStringLn(int l, int r) { return readString(l, r, '\n'); }
string readStringSp(int l, int r) { return readString(l, r, ' '); }
void readEOF() { assert(getchar() == EOF); }
 
vector<int> readVectorInt(int n, long long l, long long r)
{
    vector<int> a(n);
    for(int i = 0; i < n - 1; i++)
        a[i] = readIntSp(l, r);
    a[n - 1] = readIntLn(l, r);
    return a;
}
 

int sum = 0;

void checkInput(auto a,auto b){
	assert(a[0] != '0');
	assert(b[0] != '0');
	for(auto i:a){
		assert(i >= '0' and i <= '9');
	}
	for(auto i:b){
		assert(i >= '0' and i <= '9');
	}
	assert(b.length() > a.length() or (b.length() == a.length() and b > a));
}


const int N = 1e3 + 5;

mod_int dp[N][6*N][2];
mod_int ways[N][6*N][2];
bool vis[N][6*N][2];
const int shift = 3e3 + 1;


void fun(int &num_5,int &num_2,int j){
	if(j == 2)num_2++;
	if(j == 4)num_2 += 2;
	if(j == 5)num_5++;
	if(j == 6)num_2++;
	if(j == 8)num_2 += 3;
}

array<mod_int,2> zeno(int i,int cnt,int f,string &s){
	if(i == s.length()){
		return {0,1};
	}
	if(vis[i][cnt + shift][f]){
		return {dp[i][cnt + shift][f],ways[i][cnt + shift][f]};
	}
	vis[i][cnt + shift][f] = true;
	mod_int &ans = dp[i][cnt + shift][f];
	mod_int &way = ways[i][cnt + shift][f];
	for(int j = 1; j <= 9; j++){
		if(f == 1 and s[i] - '0' < j)break;
		bool ff = (f && (s[i] - '0') == j);
		int num_5 = (cnt < 0) ? -cnt : 0;
		int num_2 = (cnt > 0) ? cnt : 0;
		
		fun(num_5,num_2,j);

		int mn = min(num_5,num_2);
		int ccnt = 0;
		if(num_5 > mn){
			ccnt = mn - num_5;
		}else if(num_2 > mn){
			ccnt = num_2 - mn;
		}
		auto [ans2, way2] = zeno(i + 1,ccnt,ff,s);
		ans += ans2 + way2*mn;
		way += way2;
	}
	return {ans,way};
}


mod_int solve_zeroes(string s){
	mod_int ans = 0;



	bool zero = 0;
	int n = s.length();

	for(int len = 1; len < n; len++){
		ans += 9*mod_int(10).pow(len - 1) - mod_int(9).pow(len);
	}

	for(int i = 0; i < n; i++){
		for(int j = 0; j < s[i] - '0'; j++){
			if(i == 0 and j == 0)continue;
			bool z = (zero | (j == 0));
			if(z){
				ans += mod_int(10).pow(n - i - 1);
			}else{
				ans += (mod_int(10).pow(n - i - 1) - mod_int(9).pow(n - i - 1));
			}
		}
		zero |= (s[i] == '0');
	}
	ans += zero;
	return ans;
}


mod_int solve_25(string &s){
	memset(vis,0,sizeof(vis));
	memset(ways,0,sizeof(ways));
	memset(dp,0,sizeof(dp));
	mod_int ans = zeno(0,0,1,s)[0];
	int n = s.length();

	vector<vector<mod_int>> f(n + 1,vector<mod_int>(6*n + 1));
	vector<vector<mod_int>> f2(n + 1,vector<mod_int>(6*n + 1));

	f2[0][0 + 3*n] = 1;

	for(int i = 0; i < n; i++){
		for(int j = 1; j <= 9; j++){
			for(int cnt = -3*i; cnt <= 3*i; cnt++){
				mod_int ans2 = f[i][cnt + 3*n]; 
				mod_int way = f2[i][cnt + 3*n];

				int num_5 = (cnt < 0) ? -cnt : 0;
				int num_2 = (cnt > 0) ? cnt : 0;

				fun(num_5,num_2,j);
				int mn = min(num_5,num_2);
				int ccnt = 0;
				if(num_5 > mn){
					ccnt = mn - num_5;
				}else if(num_2 > mn){
					ccnt = num_2 - mn;
				}
				ccnt += 3*n;
				f[i + 1][ccnt] += mn*way + ans2;
				f2[i + 1][ccnt] += way;
			}
		}
	}

	for(int len = 1; len < n; len++){
		for(int cnt = -3*n; cnt <= 3*n; cnt++){
			ans += f[len][cnt + 3*n];
		}
	}
	return ans;
		

}

int brute(int a,int b){
	int res = 0;
	for(int i = a + 1; i <= b; i++){
		int x = i;
		int cnt_2 = 0,cnt_5 = 0, cnt_0 = 0;
		while(x){
			int z = x % 10;
			if(z == 0)cnt_0++;
			x /= 10;
			fun(cnt_5,cnt_2,z);
		}
		if(cnt_0)res++;
		else res += min(cnt_5,cnt_2);
	}
	return res;
}


int solve(){
 		string a = readStringSp(1,1000);
 		string b = readStringLn(1,1000);
 		checkInput(a,b);

 		mod_int ans = solve_25(b) - solve_25(a);

 		ans += (solve_zeroes(b) - solve_zeroes(a));

 		cout << ans  << endl;



 return 0;
}
signed main(){
    ios_base::sync_with_stdio(0);cin.tie(0);cout.tie(0);
    //freopen("input.txt", "r", stdin);
    //freopen("output.txt", "w", stdout);
    #ifdef SIEVE
    sieve();
    #endif
    #ifdef NCR
    init();
    #endif
    int t = 1;
    while(t--){
        solve();
    }
    return 0;
}
Tester (satyam_343)'s code (C++)
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast,unroll-loops")
#include <bits/stdc++.h>   
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;
using namespace std; 
#define ll long long 
#define pb push_back               
#define mp make_pair        
#define nline "\n"                         
#define f first                                          
#define s second                                              
#define pll pair<ll,ll> 
#define all(x) x.begin(),x.end()   
#define vl vector<ll>         
#define vvl vector<vector<ll>>    
#define vvvl vector<vector<vector<ll>>>          
#ifndef ONLINE_JUDGE    
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);  
#endif    
void _print(ll x){cerr<<x;}  
void _print(char x){cerr<<x;} 
void _print(string x){cerr<<x;}     
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count()); 
template<class T,class V> void _print(pair<T,V> p) {cerr<<"{"; _print(p.first);cerr<<","; _print(p.second);cerr<<"}";}
template<class T>void _print(vector<T> v) {cerr<<" [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T>void _print(set<T> v) {cerr<<" [ "; for (T i:v){_print(i); cerr<<" ";}cerr<<"]";}
template<class T>void _print(multiset<T> v) {cerr<< " [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T,class V>void _print(map<T, V> v) {cerr<<" [ "; for(auto i:v) {_print(i);cerr<<" ";} cerr<<"]";} 
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
//--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
const ll MOD=1e9+7;         
const ll MAX=300300;
ll calc(string s){
    ll n=s.size();
    s=" "+s;
    assert(n>=1 and n<=1000);
    vector<vector<ll>> ways(n+5,vector<ll>(2,0));
    vector<vector<ll>> aways(n+5,vector<ll>(2,0));
    vector<ll> power(n+5,1);
    vector<ll> apower(n+5,1);
    for(ll i=1;i<=n;i++){
        power[i]=(power[i-1]*10)%MOD;
        apower[i]=(apower[i-1]*9)%MOD;

    }
    for(ll i=1;i<=n;i++){
        ways[i][1]=1;
        ways[i][0]=power[n-i];
        aways[i][0]=apower[n-i];
        aways[i][1]=1;
        for(ll j=i+1;j<=n;j++){
            ll d=s[j]-'0';
            aways[i][1]&=(d!=0);
            ways[i][1]=(ways[i][1]+d*power[n-j])%MOD; 
        }
        for(ll j=i+1;j<=n;j++){
            ll d=s[j]-'0';  
            if(d==0){
                break;
            }  
            aways[i][1]=(aways[i][1]+(d-1)*apower[n-j])%MOD; 
        }
    }
    vector<vector<ll>> dp(2,vector<ll>(4*n+5,0));
    dp[1][n]=1;
    ll value=0;
    for(ll i=1;i<=n;i++){
        vector<vector<ll>> now(2,vector<ll>(4*n+5,0));
        ll d=s[i]-'0';
        if(i>1){
            dp[0][n]++; 
        }
        for(ll j=0;j<=1;j++){
            ll pos=-n;
            for(auto it:dp[j]){
                if(pos>3*n){
                    break;
                }
                if(pos==-n){
                    pos++; 
                    continue; 
                }
                for(ll cur=1;cur<=9;cur++){
                    if(j==1){
                        if(cur==d){
                            if(cur==5){
                                ll lft=pos;
                                now[1][lft-1+n]=(now[1][lft-1+n]+it)%MOD;
                                if(lft>0){
                                    value=(value+it*aways[i][1])%MOD;
                                }
                            }
                            else if(cur%2==0){
                                ll pw=1;
                                if(cur%4==0){
                                    pw++;
                                }
                                if(cur%8==0){
                                    pw++; 
                                }
                                ll lft=pos;
                                now[1][lft+pw+n]=(now[1][n+lft+pw]+it)%MOD;
                                value=(value+it*min(max(0LL,-lft),pw)*aways[i][1])%MOD;
                            }
                            else{
                                now[1][pos+n]=(now[1][pos+n]+it)%MOD;
                            }

                        }
                        else if(cur<d){
                            if(cur==5){
                                ll lft=pos;
                                now[0][lft-1+n]=(now[0][lft-1+n]+it)%MOD;
                                if(lft>0){
                                    value=(value+it*aways[i][0])%MOD;
                                }
                            }
                            else if(cur%2==0){
                                ll pw=1;
                                if(cur%4==0){
                                    pw++;
                                }
                                if(cur%8==0){
                                    pw++; 
                                }
                                now[0][pos+n+pw]=(now[0][pos+n+pw]+it)%MOD;
                                value=(value+it*min(max(0LL,-pos),pw)*aways[i][0])%MOD;
                            }
                            else{
                                now[0][pos+n]=(now[0][pos+n]+it)%MOD;
                            }

                        }

                    }  
                    else{
                        if(cur==5){
                            ll lft=pos;
                            now[0][lft-1+n]=(now[0][lft-1+n]+it)%MOD;
                            if(lft>0){
                                value=(value+it*aways[i][0])%MOD;
                            }
                        }
                        else if(cur%2==0){
                            ll pw=1+(cur%4==0)+(cur%8==0);
                            now[0][pos+n+pw]=(now[0][pos+n+pw]+it)%MOD;
                            value=(value+it*min(max(0LL,-pos),pw)*aways[i][0])%MOD;
                        }
                        else{
                            now[0][pos+n]=(now[0][pos+n]+it)%MOD;
                        }

                    }
                }  
                pos++;
            }    
        } 
        swap(dp,now);  
    }
    vector<ll> anot(2,0);    
    anot[1]=1;  
    for(ll i=1;i<=n;i++){   
        ll d=s[i]-'0';   
        vector<ll> now(2,0);  
        for(ll j=0;j<=1;j++){  
            value=(value+(i!=1)*anot[j]*ways[i][j&(d==0)])%MOD;
        }
        for(ll j=0;j<=1;j++){
            for(ll cur=1;cur<=9;cur++){
                if(j==1){
                    if(cur==d){
                        now[1]=(now[1]+anot[1])%MOD;
                    }
                    else if(cur<d){
                        now[0]=(now[0]+anot[1])%MOD;
                    }
                }
                else{
                    now[0]=(now[0]+anot[j])%MOD;
                }
            }
        }
        if(i>1){
            now[0]=(now[0]+9)%MOD; 
        }
        swap(anot,now);
    }
    value=(value+MOD)%MOD;
    return value;
}
ll sz(string s){
    ll n=s.size();
    return n;
}
void solve(){                    
    string a,b; cin>>a>>b;
    for(auto i:a){
        assert(i>='0' and i<='9');
    }
    for(auto i:b){
        assert(i>='0' and i<='9');
    }
    if(sz(a)==sz(b)){
        assert(a<b);
    }
    else if(sz(a)>sz(b)){
        assert(0);
    }
    ll ans=calc(b)-calc(a)+MOD;
    ans%=MOD;
    cout<<ans;
    return;
}                 
int main()                                                                               
{                     
    ios_base::sync_with_stdio(false);                           
    cin.tie(NULL);                          
    #ifndef ONLINE_JUDGE                 
    freopen("input.txt", "r", stdin);                                           
    freopen("output.txt", "w", stdout);  
    freopen("error.txt", "w", stderr);                        
    #endif    
    ll test_cases=1;                   
    //cin>>test_cases;
    while(test_cases--){   
        solve();     
    }
    cout<<fixed<<setprecision(10);
    cerr<<"Time:"<<1000*((double)clock())/(double)CLOCKS_PER_SEC<<"ms\n"; 
} 

Just want to share my idea.

The basic idea is digit DP, of course. While iterating from the highest digit to the lowest, we keep the number of 2s (denoted p) and 5s (denoted q) in the prefix of the given number. When we are at the digit that has i more digits to its right:

  • If it is 0.
    No need to go further. Every number with this prefix contains a 0 and contributes 1, and exactly (value of the remaining suffix) + 1 of them are \leq N. So just compute that modulo 10^9+7, add it to the answer and break.

  • Otherwise it is larger than 0.
    Suppose it is d. If we choose 0 at this digit (not allowed at the leading position), the contribution is 10^i. If we choose a digit c with 1 \leq c < d, update p and q with c. Now any number from 0 to 10^i-1 can be chosen as the suffix, so the contribution depends only on the suffix length i, the number of 2s p, and the number of 5s q. We denote this contribution by f(len, p, q). Choosing c = d instead updates p and q and keeps the prefix equal to that of N.

The last part we haven’t counted is the contribution of numbers with fewer digits than N. We can iterate over each length from len(N)-1 down to 1 and compute its contribution. This is similar to the previous case, except that a leading 0 is not allowed. We denote this contribution by g(len).

So the remaining problem is how to calculate f and g.
First, note that g(len) = f(len, 0, 0) if we forbid a leading 0 when computing f, so it suffices to consider f.
f(len, p, q) has two parts.

  • Numbers that contain a 0 (each contributes 1). If a leading 0 is allowed, there are 10^{len}-9^{len} of them; otherwise there are 9\times(10^{len-1}-9^{len-1}).

  • For the remaining (zero-free) numbers, iterate over j from 0 to len, the number of additional 5s placed in the suffix. The total number of 5s becomes q'=q+j, and there are \binom{len}{j} ways to choose the positions of these 5s. Now consider the total number p' of 2s.

    • If p' \geq q', each such number contributes q'. Their total contribution is \binom{len}{j}\, q'\Sigma_{k=\max(0,q'-p)}^{3(len-j)}dp(len-j,k), where dp(len, k) is the number of ways to fill len digits (without 0 and 5) so that they contribute exactly k extra 2s.

    • If p' < q', each such number contributes p' = p+k. Their total contribution is \binom{len}{j}\Sigma_{k=0}^{q'-p-1}dp(len-j,k)\times (p+k).

How to precalculate dp:
The recurrence is dp[len][p] = 4dp[len-1][p]+2dp[len-1][p-1]+dp[len-1][p-2]+dp[len-1][p-3], with dp[0][0]=1; the four terms come from the digits \{1,3,7,9\}, \{2,6\}, \{4\} and \{8\} respectively. The digit 5 is excluded here because the 5s have already been counted in q'.
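A sketch of this precomputation (hypothetical name `twoCountTable`, modulus 10^9+7 assumed from the solutions above). Row len only needs indices up to 3 \cdot len, since every digit adds at most three 2s.

```cpp
#include <cassert>
#include <vector>

const long long MOD = 1000000007LL;  // modulus used by the solutions above

// dp[len][k]: number of length-len strings over {1,2,3,4,6,7,8,9} (no 0, no 5)
// whose digit product has exactly k factors of 2.
std::vector<std::vector<long long>> twoCountTable(int maxLen) {
    std::vector<std::vector<long long>> dp(maxLen + 1, std::vector<long long>(3 * maxLen + 1, 0));
    dp[0][0] = 1;
    for (int len = 1; len <= maxLen; ++len)
        for (int k = 0; k <= 3 * len; ++k) {
            long long v = 4 * dp[len - 1][k];         // digits 1, 3, 7, 9
            if (k >= 1) v += 2 * dp[len - 1][k - 1];  // digits 2, 6
            if (k >= 2) v += dp[len - 1][k - 2];      // digit 4
            if (k >= 3) v += dp[len - 1][k - 3];      // digit 8
            dp[len][k] = v % MOD;
        }
    return dp;
}
```

As a check, dp[2][2] should be 12: eight strings pair 4 with an odd digit, and four strings use two digits from \{2, 6\}.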

Looks like we are done. However, evaluating the formulas above takes O(N^2) time per call, so the overall time complexity is O(N^3), which is too slow.
Now let’s optimize it. Notice that q'\Sigma_{k=\max(0,q'-p)}^{3(len-j)}dp(len-j,k) is a range sum of dp(len-j, k), while \Sigma_{k=0}^{q'-p-1}dp(len-j,k)\times (p+k)=p\times \Sigma_{k=0}^{q'-p-1}dp(len-j,k)+\Sigma_{k=0}^{q'-p-1}k\times dp(len-j,k) is a combination of a range sum of dp(len-j, k) and a range sum of k\times dp(len-j,k). With O(N^2) precomputed prefix sums, each of these takes O(1) time. Problem solved.
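The prefix-sum trick, sketched for a single row r[k] = dp(len-j, k) (hypothetical name `RowSums`, modulus 10^9+7 assumed). Queries are clamped to the row, so a negative lower bound such as q'-p < 0 simply starts at 0 and an empty range returns 0.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

const long long MOD = 1000000007LL;  // modulus used by the solutions above

// Prefix sums over one row r so that sum r[k] and sum k*r[k] over any
// inclusive range [lo, hi] are O(1) queries.
struct RowSums {
    std::vector<long long> s0, s1;  // s0[k] = sum_{t<k} r[t], s1[k] = sum_{t<k} t*r[t]

    explicit RowSums(const std::vector<long long>& r) : s0(r.size() + 1, 0), s1(r.size() + 1, 0) {
        for (std::size_t k = 0; k < r.size(); ++k) {
            s0[k + 1] = (s0[k] + r[k]) % MOD;
            s1[k + 1] = (s1[k] + static_cast<long long>(k) * r[k]) % MOD;
        }
    }
    long long sum(long long lo, long long hi) const { return query(s0, lo, hi); }
    long long weightedSum(long long lo, long long hi) const { return query(s1, lo, hi); }

private:
    static long long query(const std::vector<long long>& s, long long lo, long long hi) {
        long long n = static_cast<long long>(s.size()) - 1;  // row length
        if (lo < 0) lo = 0;
        if (hi > n - 1) hi = n - 1;
        if (lo > hi) return 0;
        return (s[hi + 1] - s[lo] + MOD) % MOD;
    }
};
```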

The overall time complexity is O(N^2).

(I’d almost finished writing this when I saw that aging1986 has a similar solution. This comment is now somewhat redundant, but I’ll post it anyway.)

Here’s my solution for calculating f(B) over numbers with the same number of digits as B (a similar but simpler procedure works for numbers with fewer digits than B):

Iterate over the length of the common prefix with B (call it d), the next digit, and the number of occurrences of 5 in the suffix (call it c_5). The number of ways to choose the positions of the 5s in the suffix is just a binomial coefficient. Precompute dp[i][j], the count of i-digit strings without 0 and 5 such that the product of all their digits has exactly j factors of 2.
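The binomial coefficient can be read off precomputed factorials and inverse factorials modulo the prime modulus. A standard sketch, with the hypothetical name `Binomial` and 10^9+7 assumed as the modulus; the inverse of the largest factorial comes from Fermat's little theorem.

```cpp
#include <cassert>
#include <vector>

const long long MOD = 1000000007LL;  // assumed prime modulus

// Factorial-based binomials mod a prime: C(n, k) = ways to choose which of
// n suffix positions hold the 5s.
struct Binomial {
    std::vector<long long> fact, inv;

    explicit Binomial(int maxN) : fact(maxN + 1, 1), inv(maxN + 1, 1) {
        for (int i = 1; i <= maxN; ++i) fact[i] = fact[i - 1] * i % MOD;
        inv[maxN] = power(fact[maxN], MOD - 2);  // Fermat's little theorem
        for (int i = maxN; i > 0; --i) inv[i - 1] = inv[i] * i % MOD;
    }
    static long long power(long long b, long long e) {
        long long r = 1;
        b %= MOD;
        while (e > 0) {
            if (e & 1) r = r * b % MOD;
            b = b * b % MOD;
            e >>= 1;
        }
        return r;
    }
    long long C(int n, int k) const {  // requires n <= maxN
        if (k < 0 || k > n) return 0;
        return fact[n] * inv[k] % MOD * inv[n - k] % MOD;
    }
};
```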

There are two cases based on the total number of occurrences of 2 in the prime factorisation of the product (call it c_2):

  • If c_2 \ge c_5, the expression involves a range sum query on dp[n - d - 1][j] over a contiguous range of j.
  • Otherwise, the expression involves a range sum query on dp[n - d - 1][j] \cdot j and another on dp[n - d - 1][j].

We can precompute both types of range queries using prefix sums. You can also save some memory by storing each row only up to index 1000 instead of roughly 3000, letting index 1000 stand for every count of 1000 or more: there are at most 1000 fives, so all such counts of 2s behave identically.
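A sketch of that capped table (hypothetical name `cappedTwoCountTable`, modulus 10^9+7 assumed). Instead of the shift recurrence, each digit's factors of 2 are applied directly and the index is clamped with std::min, so every row has cap + 1 entries regardless of its length.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

const long long MOD = 1000000007LL;  // assumed modulus

// dp[i][j]: i-digit strings without 0 and 5 with exactly j factors of 2,
// except that index cap collects every j >= cap.
std::vector<std::vector<long long>> cappedTwoCountTable(int maxLen, int cap) {
    std::vector<std::vector<long long>> dp(maxLen + 1, std::vector<long long>(cap + 1, 0));
    dp[0][0] = 1;
    const int add[8] = {0, 0, 0, 0, 1, 1, 2, 3};  // digits 1,3,7,9,2,6,4,8
    for (int i = 1; i <= maxLen; ++i)
        for (int j = 0; j <= cap; ++j) {
            if (dp[i - 1][j] == 0) continue;
            for (int a : add) {
                int nj = std::min(j + a, cap);  // clamp large counts into the last bucket
                dp[i][nj] = (dp[i][nj] + dp[i - 1][j]) % MOD;
            }
        }
    return dp;
}
```

With cap = 2 and two digits, the last bucket holds 64 - 16 - 16 = 32 strings, i.e. all those with at least two factors of 2.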