Editorial - PLIND

PROBLEM LINK:

Practice

Contest: Division 1

Contest: Division 2

Setter: Trung Nguyen

Tester: Radoslav Dimitrov

Editorialist: Raja Vardhan Reddy

DIFFICULTY:

Medium

PREREQUISITES:

Dynamic programming

PROBLEM:

A string is almost palindromic if we can rearrange its characters in such a way that it becomes a palindrome. A positive integer is almost palindromic if its decimal representation without leading zeros is an almost palindromic string.

You are given two integers L and R. Find the number of almost palindromic integers between L and R (inclusive). Since this number could be very large, compute it modulo 10^9+7.

EXPLANATION:

“A string is almost palindromic if we can rearrange its characters in such a way that it becomes a palindrome” implies the following: if we count the number of occurrences of each digit in the string, then at most one digit occurs an odd number of times and all other digits occur an even number of times. So the problem is equivalent to finding the number of integers between L and R that have at most one digit occurring an odd number of times.
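The parity condition above can be checked with a 10-bit mask. The following is a minimal sketch, not taken from any of the solutions below; the function name isAlmostPalindromic is an assumption made for illustration.

```cpp
#include <cassert>
#include <string>

// Hypothetical helper: returns true if the digits of s can be rearranged
// into a palindrome, i.e. at most one digit occurs an odd number of times.
// Assumes s is a non-empty string of decimal digits.
bool isAlmostPalindromic(const std::string& s) {
    assert(!s.empty());
    unsigned mask = 0;               // bit d is set iff digit d has an odd count so far
    for (char c : s) {
        mask ^= 1u << (c - '0');     // toggle the parity of this digit
    }
    // mask has at most one set bit exactly when mask & (mask - 1) == 0
    return (mask & (mask - 1)) == 0;
}
```

For example, isAlmostPalindromic("112") holds because only the digit 2 occurs an odd number of times, while isAlmostPalindromic("123") does not.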

Instead of solving the problem for (L,R) directly, let’s solve it for (1,R) and (1,L-1) and subtract the second count from the first.
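Since L and R can be far too long for built-in integer types, L-1 has to be computed on the decimal string. A minimal sketch of that step and of the modular subtraction follows; the names decrement and rangeCount are assumptions for illustration, and the two prefix counts are assumed to be already reduced modulo 10^9+7.

```cpp
#include <cassert>
#include <string>

const long long MOD = 1000000007LL;

// Hypothetical helper: returns the decimal string of (value of s) - 1
// without leading zeros. Assumes s is a positive integer without leading zeros.
std::string decrement(std::string s) {
    assert(!s.empty() && s != "0");
    int i = (int)s.size() - 1;
    while (s[i] == '0') {            // borrow through trailing zeros
        s[i] = '9';
        --i;
    }
    --s[i];                          // s[i] is a non-zero digit here
    // "1000" becomes "0999": drop the single leading zero, but keep "0" itself
    if (s.size() > 1 && s[0] == '0') s.erase(0, 1);
    return s;
}

// Hypothetical helper: combines the counts for (1,R) and (1,L-1),
// both taken modulo MOD, into the count for (L,R).
long long rangeCount(long long upToR, long long upToLminus1) {
    return ((upToR - upToLminus1) % MOD + MOD) % MOD;
}
```

Adding MOD before the final reduction keeps the result non-negative when the count for (1,R) wraps around the modulus and ends up smaller than the count for (1,L-1).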

Solving for (1,R) :
Let the length of an integer be the number of digits in its decimal representation without leading zeros, and denote the length of an integer X by len_X.

  • Let ans_l be the number of integers of length l having at most one digit occurring an odd number of times.
  • Let res be the number of integers of length len_R which are <=R and have at most one digit occurring an odd number of times.

Let us call the required answer req_R.
It can be calculated as:
req_R = \sum_{l=1}^{len_R-1} ans_l + res, because \sum_{l=1}^{len_R-1} ans_l gives the number of almost palindromic integers that are shorter than R, and res gives the number of almost palindromic integers of length len_R that are <=R.

Computation of res:
Let S=“???..?”, having the same length as R. We need to find the number of ways to replace each ‘?’ with a digit such that the resulting S<=R and S has at most one digit which occurs an odd number of times.

Let’s fill the digits starting from the most significant position, i.e. from S_1.
Let’s define dp[i][j][k] as the number of ways to fill the suffix S_iS_{i+1}..S_{len_R} such that at most one digit occurs an odd number of times in S, given that we have already filled S_1,...S_{i-1}, that j digits occur an odd number of times in S_1,...S_{i-1}, and that k records whether there is any position where we filled a digit smaller than the corresponding digit in R (if yes, k=1, else k=0).

Base Case:

  • dp[len_R+1][j][k] = 1 if j<=1 :- we filled all the digits, and the number of digits occurring an odd number of times is <=1.
  • dp[len_R+1][j][k] = 0 if j>1 :- we filled all the digits, and the number of digits occurring an odd number of times is >1.

This dp[i][j][k] can be calculated as:

  • If k=0, all the digits we have filled so far are the same as in R, so the current position can be filled with 0,1,...,R_i. In this case j is forced to be odd_{i-1}.
    Let us define odd_i as the number of digits occurring an odd number of times in the prefix R_1,...R_i, and small_i as the number of digits which are smaller than R_i and occur an odd number of times in the prefix R_1,...R_{i-1}.
    If we fill with R_i, number of ways = dp[i+1][odd_i][0].
    If we fill with one of the small_i digits, number of ways = small_i*dp[i+1][odd_{i-1}-1][1].
    If we fill with any of the remaining R_i-small_i smaller digits, number of ways = (R_i-small_i)*dp[i+1][odd_{i-1}+1][1].
    Hence dp[i][odd_{i-1}][0]=dp[i+1][odd_i][0]+small_i*dp[i+1][odd_{i-1}-1][1]+(R_i-small_i)*dp[i+1][odd_{i-1}+1][1], where a term whose coefficient is 0 is skipped (its index may fall outside 0..10).

  • If k=1, there is a position where we filled a digit which is smaller than the corresponding one in R, so the current position can be filled with 0,1,...,9.
    j of these digits occur an odd number of times so far, and (10-j) occur an even number of times.
    If we fill it with one of the j digits, that digit now occurs an even number of times, hence number of ways = dp[i+1][j-1][k]*j.
    If we fill it with one of the (10-j) digits, that digit now occurs an odd number of times, hence number of ways = dp[i+1][j+1][k]*(10-j).
    Therefore, dp[i][j][k]= dp[i+1][j+1][k]*(10-j)+dp[i+1][j-1][k]*j.
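The k=1 recurrence does not depend on R itself, only on how many positions remain, so it can be tabulated by the number n of remaining positions (n = len_R-i+1 in the notation above). The sketch below is one possible way to do this; the name buildFreeWays and the indexing by n are assumptions for illustration, not the layout used in the solutions below.

```cpp
#include <cassert>
#include <vector>

const long long MOD = 1000000007LL;

// Hypothetical helper: w[n][j] = number of ways to fill n unconstrained
// positions with digits 0..9, given that j digits currently have an odd count,
// so that at most one digit has an odd count at the end.
// This plays the role of dp[len_R - n + 1][j][1].
std::vector<std::vector<long long>> buildFreeWays(int maxLen) {
    assert(maxLen >= 0);
    std::vector<std::vector<long long>> w(maxLen + 1, std::vector<long long>(11, 0));
    w[0][0] = w[0][1] = 1;                              // base case: j <= 1 is accepted
    for (int n = 1; n <= maxLen; ++n) {
        for (int j = 0; j <= 10; ++j) {
            long long v = 0;
            if (j > 0)  v += j * w[n - 1][j - 1];         // reuse one of the j odd digits
            if (j < 10) v += (10 - j) * w[n - 1][j + 1];  // use one of the 10-j even digits
            w[n][j] = v % MOD;
        }
    }
    return w;
}
```

As a small check, w[2][1] is 28: with one digit a already odd, the two remaining positions are either equal (10 ways) or contain a exactly once next to a different digit (18 ways).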

Now res can be calculated as:
The first digit of S can be filled with 1,2,...R_1; after it, exactly one digit occurs an odd number of times, so j=1.
If it is filled with R_1, number of ways = dp[2][1][0].
If it is filled with one of the other R_1-1 digits, number of ways = (R_1-1)*dp[2][1][1].
Hence, res = dp[2][1][0]+(R_1-1)*dp[2][1][1].
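Because the k=0 states form a single chain along the digits of R, res can also be computed by walking R from left to right and, at each position, adding the k=1 counts for every smaller digit. The sketch below unrolls the recurrence this way; it loops over the smaller digits one by one instead of using small_i, and the name countSameLength and the table w (indexed by the number of remaining positions) are assumptions for illustration.

```cpp
#include <cassert>
#include <string>
#include <vector>

const long long MOD = 1000000007LL;

// Hypothetical helper: number of almost palindromic integers that have exactly
// r.size() digits and are <= r. Assumes r is a positive integer without leading zeros.
long long countSameLength(const std::string& r) {
    assert(!r.empty() && r[0] != '0');
    int len = (int)r.size();
    // w[n][j]: ways to fill n free positions when j digits currently have odd count.
    std::vector<std::vector<long long>> w(len + 1, std::vector<long long>(11, 0));
    w[0][0] = w[0][1] = 1;
    for (int n = 1; n <= len; ++n)
        for (int j = 0; j <= 10; ++j) {
            long long v = 0;
            if (j > 0)  v += j * w[n - 1][j - 1];
            if (j < 10) v += (10 - j) * w[n - 1][j + 1];
            w[n][j] = v % MOD;
        }
    bool isOdd[10] = {false};   // parity of each digit in the prefix of r
    int odd = 0;                // number of digits with odd count in the prefix
    long long res = 0;
    for (int i = 0; i < len; ++i) {
        int ri = r[i] - '0';
        int rem = len - i - 1;                       // positions left after this one
        for (int d = (i == 0 ? 1 : 0); d < ri; ++d) {
            int j = odd + (isOdd[d] ? -1 : 1);       // placing d flips its parity
            res = (res + w[rem][j]) % MOD;           // the suffix is now unconstrained
        }
        isOdd[ri] = !isOdd[ri];                      // stay equal to r at this position
        odd += isOdd[ri] ? 1 : -1;
    }
    if (odd <= 1) res = (res + 1) % MOD;             // r itself
    return res;
}
```

For R=121 this gives 13: the integers 100, 101, 110 to 119, and 121.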

Computation of ans_l:
We have S=“???..??” of length l, and we need to fill it with digits such that at most one digit occurs an odd number of times. For this, the above dp can be used with k=1, since a number of length l<len_R is always smaller than R.
The first digit of S can be 1,2,...9, which leaves l-1 free positions. Therefore, ans_l =9*dp[len_R+2-l][1][1].
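The sum of ans_l over all shorter lengths only needs the j=1 entry of the k=1 table for 0,1,...,len_R-2 free positions, so it can be accumulated with a rolling array. This is a minimal sketch under that observation; the name countShorter is an assumption for illustration.

```cpp
#include <array>
#include <cassert>

const long long MOD = 1000000007LL;

// Hypothetical helper: number of almost palindromic integers with fewer than
// len digits, i.e. the sum of ans_l for l = 1 .. len-1.
long long countShorter(int len) {
    assert(len >= 1);
    std::array<long long, 11> cur{};   // cur[j]: ways for the current number of free positions
    cur[0] = cur[1] = 1;               // zero free positions: accept iff j <= 1
    long long total = 0;
    for (int l = 1; l < len; ++l) {
        // cur describes l-1 free positions; the first digit is one of 1..9 and gives j = 1
        total = (total + 9 * cur[1]) % MOD;
        std::array<long long, 11> nxt{};
        for (int j = 0; j <= 10; ++j) {
            long long v = 0;
            if (j > 0)  v += j * cur[j - 1];
            if (j < 10) v += (10 - j) * cur[j + 1];
            nxt[j] = v % MOD;
        }
        cur = nxt;                     // advance to one more free position
    }
    return total;
}
```

For example, countShorter(3) is 18: the nine one-digit integers plus 11, 22, ..., 99.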

Therefore,
Total number of almost palindromic integers between 1 and R = (\sum_{l=1}^{len_R-1} 9*dp[len_R+2-l][1][1])+dp[2][1][0]+(R_1-1)*dp[2][1][1].
Similarly, the number of almost palindromic integers between 1 and L-1 can be calculated.

TIME COMPLEXITY:

Let d be the number of distinct digits (d=10).

Solving for 1 to R:
Computation of odd_i, small_i: O(d*len_R).
Computation of dp: there are len_R*d*2 dp states, and the value of each state is computed in O(1). Therefore, total time: O(len_R*d).
Computation of ans_l: O(1) for each l. Therefore, total time: O(len_R) for all l.

Solving for 1 to L-1:
Computation of odd_i, small_i: O(d*len_{L-1}).
Computation of dp: there are len_{L-1}*d*2 dp states, and the value of each state is computed in O(1). Therefore, total time: O(len_{L-1}*d).
Computation of ans_l: O(1) for each l. Therefore, total time: O(len_{L-1}) for all l.

Total time complexity: O(d*len_R)+O(d*len_{L-1}) = O(d*len_R) for each test case, since len_{L-1}<=len_R.

SOLUTIONS:

Setter's Solution
#include <bits/stdc++.h>
using namespace std;
//#pragma GCC optimize("Ofast")
//#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
 
#define ms(s, n) memset(s, n, sizeof(s))
#define FOR(i, a, b) for (int i = (a); i < (b); ++i)
#define FORd(i, a, b) for (int i = (a) - 1; i >= (b); --i)
#define FORall(it, a) for (__typeof((a).begin()) it = (a).begin(); it != (a).end(); it++)
#define sz(a) int((a).size())
#define present(t, x) (t.find(x) != t.end())
#define all(a) (a).begin(), (a).end()
#define uni(a) (a).erase(unique(all(a)), (a).end())
#define pb push_back
#define pf push_front
#define mp make_pair
#define fi first
#define se second
#define prec(n) fixed<<setprecision(n)
#define bit(n, i) (((n) >> (i)) & 1)
#define bitcount(n) __builtin_popcountll(n)
typedef long long ll;
typedef unsigned long long ull;
typedef long double ld;
typedef pair<int, int> pi;
typedef vector<int> vi;
typedef vector<pi> vii;
const int MOD = (int) 1e9 + 7;
const int FFTMOD = 119 << 23 | 1;
const int INF = (int) 1e9 + 23111992;
const ll LINF = (ll) 1e18 + 23111992;
const ld PI = acos((ld) -1);
const ld EPS = 1e-9;
inline ll gcd(ll a, ll b) {ll r; while (b) {r = a % b; a = b; b = r;} return a;}
inline ll lcm(ll a, ll b) {return a / gcd(a, b) * b;}
inline ll fpow(ll n, ll k, int p = MOD) {ll r = 1; for (; k; k >>= 1) {if (k & 1) r = r * n % p; n = n * n % p;} return r;}
template<class T> inline int chkmin(T& a, const T& val) {return val < a ? a = val, 1 : 0;}
template<class T> inline int chkmax(T& a, const T& val) {return a < val ? a = val, 1 : 0;}
inline ull isqrt(ull k) {ull r = sqrt(k) + 1; while (r * r > k) r--; return r;}
inline ll icbrt(ll k) {ll r = cbrt(k) + 1; while (r * r * r > k) r--; return r;}
inline void addmod(int& a, int val, int p = MOD) {if ((a = (a + val)) >= p) a -= p;}
inline void submod(int& a, int val, int p = MOD) {if ((a = (a - val)) < 0) a += p;}
inline int mult(int a, int b, int p = MOD) {return (ll) a * b % p;}
inline int inv(int a, int p = MOD) {return fpow(a, p - 2, p);}
inline int sign(ld x) {return x < -EPS ? -1 : x > +EPS;}
inline int sign(ld x, ld y) {return sign(x - y);}
mt19937 mt(chrono::high_resolution_clock::now().time_since_epoch().count());
inline int mrand() {return abs((int) mt());}
#define db(x) cerr << "[" << #x << ": " << (x) << "] ";
#define endln cerr << "\n";
 
const int maxn = 1e6 + 5;
int dp[maxn][10 + 1];
 
int check(int k) {
    int s = 0;
    while (k) {
        s ^= 1 << k % 10;
        k /= 10;
    }
    return bitcount(s) <= 1;
}
 
int bruteforce(string s) {
    int n = 0;
    for (char c : s) {
        n = n * 10 + c - '0';
    }
    int res = 0;
    FOR(i, 0, n + 1) {
        res += check(i);
    }
    return res;
}
 
void chemthan() {
    dp[0][0] = 1;
    FOR(i, 1, maxn) {
        FOR(j, 0, 10 + 1) {
            if (j) {
                addmod(dp[i][j], mult(dp[i - 1][j - 1], 11 - j));
            }
            if (j < 10) {
                addmod(dp[i][j], mult(dp[i - 1][j + 1], j + 1));
            }
        }
    }
    static int c[20][20];
    FOR(i, 0, 20) c[0][i] = 1;
    FOR(i, 1, 20) FOR(j, 1, 20) c[i][j] = (c[i][j - 1] + c[i - 1][j - 1]) % MOD;
    FOR(i, 0, 20) FOR(j, 0, 20) c[i][j] = inv(c[i][j]);
    FOR(i, 0, maxn) {
        FOR(j, 0, 10 + 1) {
            dp[i][j] = mult(dp[i][j], c[j][10]);
        }
    }
    auto calc = [&] (string s) {
        int res = 0;
        int n = sz(s), t = 0;
        FOR(i, 0, n) {
            int c = s[i] - '0';
            FOR(j, 0, c) {
                int tt = t;
                if (i + j) tt ^= 1 << j;
                if (!(i + j)) {
                    if (n == 1) {
                        addmod(res, 1);
                    }
                    FOR(k, 1, n) {
                        if (k == 1 && bitcount(tt ^ (1 << 0)) <= 1) {
                            addmod(res, 1);
                        }
                        FOR(l, 1, 9 + 1) {
                            int d = bitcount(tt ^ (1 << l));
                            addmod(res, dp[k - 1][d]);
                            if (d) {
                                addmod(res, mult(dp[k - 1][d - 1], d));
                            }
                            if (d < 10) {
                                addmod(res, mult(dp[k - 1][d + 1], 10 - d));
                            }
                        }
                    }
                }
                else {
                    int d = bitcount(tt);
                    int k = n - i - 1;
                    addmod(res, dp[k][d]);
                    if (d) {
                        addmod(res, mult(dp[k][d - 1], d));
                    }
                    if (d < 10) {
                        addmod(res, mult(dp[k][d + 1], 10 - d));
                    }
                }
            }
            t ^= 1 << c;
        }
        if (bitcount(t) <= 1) {
            addmod(res, 1);
        }
        if (sz(s) <= 6) {
            assert(res == bruteforce(s));
        }
        return res;
    };
    auto normalize = [&] (string& s) {
        reverse(all(s));
        while (1 < sz(s) && s.back() == '0') s.pop_back();
        reverse(all(s));
    };
    auto sub = [&] (string& s) {
        FORd(i, sz(s), 0) {
            int d = s[i] - '0';
            if (d) {
                s[i] = '0' + d - 1;
                break;
            }
            else {
                s[i] = '9';
            }
        }
    };
    int test; cin >> test;
    while (test--) {
        string l, r; cin >> l >> r;
        sub(l);
        normalize(l), normalize(r);
        int res = calc(r);
        submod(res, calc(l));
        cout << res << "\n";
    }
}
 
int main(int argc, char* argv[]) {
    ios_base::sync_with_stdio(0), cin.tie(0);
    if (argc > 1) {
        assert(freopen(argv[1], "r", stdin));
    }
    if (argc > 2) {
        assert(freopen(argv[2], "wb", stdout));
    }
    chemthan();
    cerr << "\nTime elapsed: " << 1000 * clock() / CLOCKS_PER_SEC << "ms\n";
    return 0;
}
Tester's Solution
#include <bits/stdc++.h>
#define endl '\n'
 
#define SZ(x) ((int)x.size())
#define ALL(V) V.begin(), V.end()
#define L_B lower_bound
#define U_B upper_bound
#define pb push_back
 
using namespace std;
template<class T, class T1> int chkmin(T &x, const T1 &y) { return x > y ? x = y, 1 : 0; }
template<class T, class T1> int chkmax(T &x, const T1 &y) { return x < y ? x = y, 1 : 0; }
const int MAXN = (1 << 20);
const int mod = (int)1e9 + 7;
 
int cnt[MAXN];
int rec(int n, int k);
 
string _L, _R;
vector<int> L, R;
int pw10[MAXN], ans[MAXN];
 
vector<int> create_vec(string s) {
	vector<int> ret;
	for(char c: s) {
		ret.pb(c - '0');
	}
	return ret;
}
 
void read() {
	cin >> _L >> _R;
	L = create_vec(_L);
	R = create_vec(_R);
}
 
int popcnt[1 << 10];
 
inline void fix(int &x) { 
	if(x >= mod) x -= mod;
}
 
int eval(vector<int> X) {
	if(X == vector<int>(1, 0)) {
		return 0;
	}
 
	int ret = ans[SZ(X) - 1];
	int mask = 0;
	for(int i = 0; i < SZ(X); i++) {
		for(int d = (i == 0); d < X[i]; d++) {
			ret += rec(SZ(X) - i - 1, popcnt[mask ^ (1 << d)]);
			fix(ret);
		}
	
		mask ^= (1 << X[i]);
	}
 
	ret += rec(0, popcnt[mask]);
	fix(ret);
	return ret;
}
 
void solve() {
	L[SZ(L) - 1]--;
	int pos = SZ(L) - 1;
	while(L[pos] < 0) {
		L[pos - 1]--;
		L[pos] += 10;
		pos--;
	}
 
	reverse(ALL(L));
	while(SZ(L) > 1 && L.back() == 0) { 
		L.pop_back();
	}
 
	reverse(ALL(L));
 
	cout << (eval(R) - eval(L) + mod) % mod << endl;
}
 
int dp[MAXN][11];
 
int rec(int n, int odd_cnt) {
	if(n == 0) {
		return odd_cnt <= 1;
	}
 
	int &memo = dp[n][odd_cnt];
	if(memo != -1) {
		return memo;
	}
 
	memo = 0;
	if(odd_cnt) {
		memo = (memo + odd_cnt * 1ll * rec(n - 1, odd_cnt - 1)) % mod;
	}
 
	if(odd_cnt < 10) {
		memo = (memo + (10 - odd_cnt) * 1ll * rec(n - 1, odd_cnt + 1)) % mod; 
	}
 
	return memo;
}
 
void precompute(int MX) {
	memset(dp, -1, sizeof(dp));
 
	pw10[0] = 1;
	ans[0] = 0; 
	for(int i = 1; i <= MX; i++) {
		pw10[i] = pw10[i - 1] * 10ll % mod;
		ans[i] = (ans[i - 1] + 9ll * rec(i - 1, 1)) % mod;
	}
 
	popcnt[0] = 0;
	for(int i = 1; i < (1 << 10); i++) {
		popcnt[i] = popcnt[i >> 1] + (i & 1);
	}
}
 
int main() {
	ios_base::sync_with_stdio(false);
	cin.tie(NULL);
 
	precompute((int)1e6 + 1);
 
	int T;
	cin >> T;
	while(T--) {
		read();
		solve();
	}
 
	return 0;
}
Editorialist's Solution
//raja1999
 
//#pragma comment(linker, "/stack:200000000")
//#pragma GCC optimize("Ofast")
//#pragma GCC target("sse,sse2,sse3,ssse3,sse4,avx,avx2")
 
#include <bits/stdc++.h>
#include <vector>
#include <set>
#include <map>
#include <string> 
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <utility>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <iomanip> 
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp> 
//setbase - cout << setbase (16)a; cout << 100 << endl; Prints 64
//setfill -   cout << setfill ('x') << setw (5); cout << 77 <<endl;prints xxx77
//setprecision - cout << setprecision (14) << f << endl; Prints x.xxxx
//cout.precision(x)  cout<<fixed<<val;  // prints x digits after decimal in val
 
using namespace std;
using namespace __gnu_pbds;
#define f(i,a,b) for(i=a;i<b;i++)
#define rep(i,n) f(i,0,n)
#define fd(i,a,b) for(i=a;i>=b;i--)
#define pb push_back
#define mp make_pair
#define vi vector< int >
#define vl vector< ll >
#define ss second
#define ff first
#define ll long long
#define pii pair< int,int >
#define pll pair< ll,ll >
#define sz(a) a.size()
#define inf (1000*1000*1000+5)
#define all(a) a.begin(),a.end()
#define tri pair<int,pii>
#define vii vector<pii>
#define vll vector<pll>
#define viii vector<tri>
#define mod (1000*1000*1000+7)
#define pqueue priority_queue< int >
#define pdqueue priority_queue< int,vi ,greater< int > >
#define int ll
 
typedef tree<
int,
null_type,
less<int>,
rb_tree_tag,
tree_order_statistics_node_update>
ordered_set;
 
 
//std::ios::sync_with_stdio(false);
 
int cnt[10];
int check(string s){
	int i,c=0;
	for(i=0;i<10;i++){
		cnt[i]=0;
	}
	for(i=0;i<s.length();i++){
		cnt[s[i]-'0']++;
	}
	for(i=0;i<10;i++){
		if(cnt[i]%2){
			c++;
		}
	}
	if(c>1){
		return 0;
	}
	return 1;
}
 
int small[1000005],odd[1000005],dp[1000005][2][11],len;
int solve(string s){
	int i,j,k;
	for(i=0;i<10;i++){
		cnt[i]=0;
	}
	len=s.length();
	rep(i,s.length()){
		odd[i+1]=0;
		small[i]=0;
		for(j=0;j<10;j++){
			if(j<(s[i]-'0')){
				small[i]+=(cnt[j]%2);
			}
		}
		cnt[s[i]-'0']++;
		for(j=0;j<10;j++){
			odd[i+1]+=(cnt[j]%2);
		}
	}
	for(i=0;i<2;i++){
		for(j=0;j<=10;j++){
			dp[len][i][j]=0;
			if(j<=1){ 		
				dp[len][i][j]=1;
			}
		}
	}
	ll val=0;
	for(i=len-1;i>=0;i--){
		for(j=0;j<2;j++){
			for(k=0;k<11;k++){
				val=0;
				if(j==0){
					val=dp[i+1][j][odd[i+1]];
					if(odd[i]!=10)
						val+=(s[i]-'0'-small[i])*dp[i+1][1][odd[i]+1];
					if(odd[i]!=0)
						val+=(small[i])*dp[i+1][1][odd[i]-1];
				}
				else{
					if(k<10){
						val=(10-k)*dp[i+1][1][k+1];
					}
					if(k!=0){
						val+=k*dp[i+1][1][k-1];
					}
				}
				val%=mod;
				dp[i][j][k]=val;
			}
		}
	}
	int c=s[0]-'0';
	val=dp[1][1][1]*(c-1);
	val+=dp[1][0][1];
	for(i=1;i<s.length();i++){
		val+=dp[i+1][1][1]*9LL;
	}
	val%=mod;
	return val;
}
signed main(){ // explicit return type; plain "int" is redefined to ll by the macro above
	std::ios::sync_with_stdio(false); cin.tie(NULL);
	int t,t1;
	cin>>t;
	// t=1;
	t1=t;
	while(t--){
		string l,r;
		cin>>l>>r;
		int ansr,ansl,checkl,res;
		ansr=solve(r);
		ansl=solve(l);
		checkl=check(l);
		res=ansr-ansl+checkl;
		res%=mod;
		res+=mod;
		res%=mod;
		cout<<res<<endl;
	}
	return 0;
} 

Feel free to share your approach if it differs. Suggestions are always welcome.

I did exactly the same thing as given in the editorial, but I’m getting TLE. I have used an extra state in the dp, st, to keep track of leading zeroes. Could somebody please suggest something to remove the TLE? I’m stuck.
Link to my code— CodeChef: Practical coding for everyone