PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: a_18o3
Tester: tabr
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
Combinatorics, inclusion-exclusion or dynamic programming
PROBLEM:
You’re given two integers L and R in base B, each with an N-digit representation.
Count the integers between L and R (inclusive) that have exactly K distinct digits in base B.
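For small inputs, a brute-force reference is handy for checking a faster solution. The sketch below is not part of the original solutions; it assumes L and R are given as digit lists with the most significant digit first, and the name brute_count is made up for illustration.

```python
from itertools import product

def brute_count(L, R, b, k):
    """Count N-digit base-b numbers in [L, R] with exactly k distinct digits.

    Assumption: L and R are equal-length digit lists, most significant
    digit first, so comparing lists compares the numbers themselves.
    """
    n = len(L)
    total = 0
    for digits in product(range(b), repeat=n):
        candidate = list(digits)
        if candidate < L or candidate > R:
            continue  # outside the range
        if len(set(candidate)) == k:
            total += 1
    return total
```

This takes B^N time, so it is only useful for tiny N and B.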
EXPLANATION:
Let’s simplify the setup slightly: we’ll only count numbers \lt R that have N digits in base B and exactly K distinct digits.
Let A = [A_1, A_2, \ldots, A_N] be an N-digit (in base B) integer that’s \lt R.
Then, there must exist an index i such that:
- A_j = R_j for j \lt i
- A_i \lt R_i
So, let’s fix this first differing index i, and count the number of valid A for each i separately.
First, we know that the first i-1 digits of A match those of R, so all of those digits certainly appear in A.
Then, we also need to fix A_i to be something strictly less than R_i. There are \mathcal{O}(B) choices for what A_i is.
Once A_i is fixed, note that all the digits at positions \gt i can be freely chosen: the only constraint is that there are K distinct digits overall.
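As a sanity check on this decomposition, here is a small sketch that ignores the distinct-digit constraint and just counts all length-N digit strings below R by grouping them on the first differing index. The function name count_below is hypothetical, and unlike the actual problem it allows a leading zero.

```python
def count_below(R, b):
    """Count length-N digit strings (leading zeros allowed) strictly below R.

    Groups strings by the first index i where they differ from R:
    R[i] smaller digits are available there, and the remaining
    n - 1 - i positions (0-based i) are completely free.
    """
    n = len(R)
    total = 0
    for i in range(n):
        smaller_choices = R[i]          # digits 0 .. R[i]-1
        total += smaller_choices * b ** (n - 1 - i)
    return total
```

The result is simply the numeric value of R, which confirms that every string below R is counted exactly once.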
So, suppose there are d distinct digits among the first i positions once A_i is fixed.
We then need to choose another K-d digits from the B-d unused ones, and fill the remaining N-i positions using these K digits such that the newly chosen digits all occur at least once.
The number of such arrangements can be found in \mathcal{O}(B) time using inclusion-exclusion.
How?
First, we choose the new digits: there are \binom{B-d}{K-d} choices from the unused ones.
If we didn’t care about using the new digits at least once, we’d simply have K choices (the d old digits and the K-d new ones) for each of the N-i positions, for K^{N-i} choices in total.
From here, we want to subtract the number of configurations in which some of the new K-d digits don’t appear.
For convenience, let’s label the new digits 1, 2, 3, \ldots, K-d.
Let S_j denote the set of configurations that don’t contain digit j.
We want to compute the size of the union of all the S_j, that is, |S_1\cup S_2\cup\ldots\cup S_{K-d}|
The inclusion-exclusion principle tells us that
|S_1\cup S_2\cup\ldots\cup S_{K-d}| = \sum_{x=1}^{K-d} (-1)^{x+1} \binom{K-d}{x} (K-x)^{N-i}
This follows from the fact that for a fixed set of x of these digits, the number of configurations that avoid every digit in the set is exactly (K-x)^{N-i}, and there are \binom{K-d}{x} ways to choose this subset of size x.
Subtracting the union from K^{N-i}, the number of valid arrangements is \sum_{x=0}^{K-d} (-1)^{x} \binom{K-d}{x} (K-x)^{N-i}.
Since K-d \leq B, this sum is easily found in \mathcal{O}(B \log N) time, or even \mathcal{O}(B) if you precompute powers.
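The sum can be written down directly. Below is a minimal sketch, assuming answers are taken modulo 10^9 + 7 as in the solutions; the name arrangements and its parameters (m free positions, k allowed digits, y digits that must appear) are chosen here for illustration.

```python
from math import comb

MOD = 10**9 + 7

def arrangements(m, k, y):
    """Length-m strings over k allowed digits in which y designated digits all appear.

    Inclusion-exclusion over x, the number of designated digits that are
    forced to be absent: sum of (-1)^x * C(y, x) * (k - x)^m.
    """
    total = 0
    for x in range(y + 1):
        term = comb(y, x) * pow(k - x, m, MOD)
        total += -term if x % 2 else term
    return total % MOD
```

In the setting above, m = N-i, k = K and y = K-d, and the result still has to be multiplied by the number of ways to pick the new digits.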
It’s also possible to use dynamic programming, as can be seen in the tester’s code below.
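The dynamic programming idea can be sketched on its own as follows. This is a simplified illustration rather than the tester's exact code: it only handles the free suffix, and the name distinct_count_dp is hypothetical. Appending a digit to a string with j distinct digits keeps the count at j in j ways, and raises it to j+1 in b-j ways.

```python
MOD = 10**9 + 7

def distinct_count_dp(m, b, start):
    """ways[j] = number of ways to append m free base-b digits to a prefix
    with `start` distinct digits so that the total number of distinct digits is j."""
    ways = [0] * (b + 1)
    ways[start] = 1
    for _ in range(m):
        nxt = [0] * (b + 1)
        for j in range(b + 1):
            if ways[j] == 0:
                continue
            nxt[j] = (nxt[j] + ways[j] * j) % MOD               # reuse an old digit
            if j < b:
                nxt[j + 1] = (nxt[j + 1] + ways[j] * (b - j)) % MOD  # introduce a new digit
        ways = nxt
    return ways
```

Reading off ways[K] gives the same value as the inclusion-exclusion sum multiplied by the binomial coefficient for choosing the new digits.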
We now have a solution in \mathcal{O}(NB^2): fix the prefix, fix the smaller digit, and compute the number of arrangements of the remaining part using inclusion-exclusion.
To further improve this, observe that while there are \mathcal{O}(B) choices for the smaller digit A_i, its actual value doesn’t affect the following computation: the only thing that matters is whether it has occurred in the prefix before or not (so whether it increases d by 1 or not).
So, instead of trying every choice of digit, we compute the number of configurations for both the cases when A_i has and hasn’t appeared before, and then multiply them by the appropriate counts of digits.
This knocks off a factor of B from the complexity, and we have \mathcal{O}(N\cdot B), which is fast enough.
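Putting the pieces together, here is a hedged sketch of the counting routine for numbers below R. It follows the same idea as the editorialist's code further down but is written independently; count_less and arrangements are illustrative names, and it uses modular exponentiation instead of precomputed powers, so it carries an extra log factor.

```python
from math import comb

MOD = 10**9 + 7

def arrangements(m, k, y):
    # length-m strings over k allowed digits in which y designated digits all appear
    return sum((-1) ** x * comb(y, x) * pow(k - x, m, MOD) for x in range(y + 1)) % MOD

def count_less(R, b, k):
    """N-digit numbers (no leading zero) strictly below R with exactly k distinct digits."""
    n = len(R)
    seen = [False] * b
    d = 0        # distinct digits in the prefix R[0..i-1]
    total = 0
    for i in range(n):
        old = new = 0  # smaller digits already in the prefix / not yet in it
        for c in range(R[i]):
            if i == 0 and c == 0:
                continue  # a leading zero would not give an N-digit number
            if seen[c]:
                old += 1
            else:
                new += 1
        rest = n - 1 - i  # free positions after index i
        for cnt, dd in ((old, d), (new, d + 1)):
            if cnt and dd <= k:
                total += cnt * comb(b - dd, k - dd) * arrangements(rest, k, k - dd)
        if not seen[R[i]]:
            seen[R[i]] = True
            d += 1
        if d > k:
            break  # the prefix alone already has too many distinct digits
    return total % MOD
```

Only two inclusion-exclusion sums are evaluated per index, one for each of the two cases, regardless of how many digits fall into each case.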
To finish, simply apply the above solution twice: once to count valid numbers \lt R, and once to count valid numbers \lt L.
Subtract the latter from the former, and add 1 if R itself has exactly K distinct digits.
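The final combination step looks like this in code. To keep the sketch self-contained, it plugs in a brute-force counter (a hypothetical count_less_brute) as a stand-in for the fast routine; any function that counts valid numbers strictly below its argument can be passed instead.

```python
from itertools import product

MOD = 10**9 + 7

def count_less_brute(X, b, k):
    """Stand-in counter: length-N digit strings below X with exactly k distinct digits.

    Leading zeros are allowed here; they appear in both counts and cancel
    in the subtraction because L and R have the same length.
    """
    n = len(X)
    return sum(1 for t in product(range(b), repeat=n)
               if list(t) < X and len(set(t)) == k)

def count_range(L, R, b, k, count_less=count_less_brute):
    ans = count_less(R, b, k) - count_less(L, b, k)  # valid numbers in [L, R)
    if len(set(R)) == k:
        ans += 1                                      # R itself is valid
    return ans % MOD
```

Taking the result modulo MOD at the end also handles the case where the difference of the two residues is negative.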
TIME COMPLEXITY:
\mathcal{O}(N\cdot B) per testcase.
CODE:
Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());
const int facN = 1e6 + 5;
const int mod = 1e9 + 7; // 998244353
int ff[facN], iff[facN];
bool facinit = false;
int power(int x, int y){
    if (y == 0) return 1;
    int v = power(x, y / 2);
    v = 1LL * v * v % mod;
    if (y & 1) return 1LL * v * x % mod;
    else return v;
}
void factorialinit(){
    facinit = true;
    ff[0] = iff[0] = 1;
    for (int i = 1; i < facN; i++){
        ff[i] = 1LL * ff[i - 1] * i % mod;
    }
    iff[facN - 1] = power(ff[facN - 1], mod - 2);
    for (int i = facN - 2; i >= 1; i--){
        iff[i] = 1LL * iff[i + 1] * (i + 1) % mod;
    }
}
int C(int n, int r){
    if (!facinit) factorialinit();
    if (n == r) return 1;
    if (r < 0 || r > n) return 0;
    return 1LL * ff[n] * iff[r] % mod * iff[n - r] % mod;
}
int P(int n, int r){
    if (!facinit) factorialinit();
    assert(0 <= r && r <= n);
    return 1LL * ff[n] * iff[n - r] % mod;
}
int Solutions(int n, int r){
    //solutions to x1 + ... + xn = r
    //xi >= 0
    return C(n + r - 1, n - 1);
}
void Solve()
{
    int n, b, k; cin >> n >> b >> k;
    vector <int> l(n), r(n);
    for (auto &x : l) cin >> x;
    for (auto &x : r) cin >> x;
    int f = n;
    for (int i = 0; i < n; i++){
        if (l[i] != r[i]){
            f = i;
            break;
        }
    }
    int ans = 0;
    for (int i = f - 1; i < n; i++){
        // cout << "STEP " << i << "\n";
        // 0...i same as l
        vector <bool> hv(b, false);
        for (int j = 0; j <= i; j++){
            hv[l[j]] = true;
        }
        int u = 0;
        for (int j = 0; j < b; j++){
            if (hv[j]){
                u++;
            }
        }
        // cout << u << "\n";
        if (i == n - 1){
            if (u == k)
                ans++;
            continue;
        }
        int w1 = 0, w0 = 0;
        for (int j = 0; j < b; j++){
            // needs to be > l[i + 1]
            if (j <= l[i + 1]) continue;
            if (i == f - 1 && j >= r[i + 1]) continue;
            if (hv[j]) w0++;
            else w1++;
        }
        // cout << w0 << " " << w1 << "\n";
        // you have u initially, how many ways to get x more?
        // how many ways to get 1 more?
        {
            int un = b - u;
            int pos = n - i - 2;
            int okie = 0;
            int need = k - u;
            for (int t = 0; t <= need; t++){
                int ways = C(need, t) * power(u + need - t, pos) % mod;
                // cout << ways << " \n"[t == need];
                if (t & 1) okie -= ways;
                else okie += ways;
            }
            okie %= mod;
            if (okie < 0) okie += mod;
            okie *= C(un, need);
            okie %= mod;
            // cout << "OKIE " << okie << " " << w0 << "\n";
            ans += okie * w0 % mod;
        }
        u++;
        {
            int un = b - u;
            int pos = n - i - 2;
            int okie = 0;
            int need = k - u;
            for (int t = 0; t <= need; t++){
                int ways = C(need, t) * power(u + need - t, pos) % mod;
                // cout << ways << " \n"[t == need];
                if (t & 1) okie -= ways;
                else okie += ways;
            }
            okie %= mod;
            if (okie < 0) okie += mod;
            okie *= C(un, need);
            okie %= mod;
            // cout << "OKIE " << okie << " " << w1 << "\n";
            ans += okie * w1 % mod;
        }
        // cout << "ANS " << ans << "\n";
    }
    // cout << ans << "\n";
    for (int i = f; i < n; i++){
        vector <bool> hv(b, false);
        for (int j = 0; j <= i; j++){
            hv[r[j]] = true;
        }
        int u = 0;
        for (int j = 0; j < b; j++){
            if (hv[j]){
                u++;
            }
        }
        if (i == n - 1){
            if (u == k)
                ans++;
            continue;
        }
        int w1 = 0, w0 = 0;
        for (int j = 0; j < b; j++){
            // needs to be < r[i + 1]
            if (j >= r[i + 1]) continue;
            if (hv[j]) w0++;
            else w1++;
        }
        // you have u initially, how many ways to get x more?
        // how many ways to get 1 more?
        {
            int un = b - u;
            int pos = n - i - 2;
            int okie = 0;
            int need = k - u;
            for (int t = 0; t <= need; t++){
                int ways = C(need, t) * power(u + need - t, pos) % mod;
                if (t & 1) okie -= ways;
                else okie += ways;
            }
            okie %= mod;
            if (okie < 0) okie += mod;
            okie *= C(un, need);
            okie %= mod;
            ans += okie * w0 % mod;
        }
        u++;
        {
            int un = b - u;
            int pos = n - i - 2;
            int okie = 0;
            int need = k - u;
            for (int t = 0; t <= need; t++){
                int ways = C(need, t) * power(u + need - t, pos) % mod;
                if (t & 1) okie -= ways;
                else okie += ways;
            }
            okie %= mod;
            if (okie < 0) okie += mod;
            okie *= C(un, need);
            okie %= mod;
            ans += okie * w1 % mod;
        }
        // cout << "ANS " << ans << "\n";
    }
    ans %= mod;
    cout << ans << "\n";
}
int32_t main()
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in", "r", stdin);
    // freopen("out", "w", stdout);
    cin >> t;
    for(int i = 1; i <= t; i++)
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n";
    return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
constexpr long long mod = (int) 1e9 + 7;
long long Calc(vector<int> a, int n, int b, int k) {
    // dp[seen first i digits][number of distinct digits is j]
    vector dp(n + 1, vector(n + 1, 0LL));
    vector digit_flag(b + 1, false);
    int digit_count = 0;
    for (int i = 0; i < n; i++) {
        for (int d = 0; d < a[i]; d++) {
            if (digit_flag[d]) {
                dp[i + 1][digit_count] += 1;
                dp[i + 1][digit_count] %= mod;
            } else {
                dp[i + 1][digit_count + 1] += 1;
                dp[i + 1][digit_count + 1] %= mod;
            }
        }
        for (int j = 0; j < n; j++) {
            dp[i + 1][j] += dp[i][j] * j;
            dp[i + 1][j + 1] += dp[i][j] * (b - j);
            dp[i + 1][j] %= mod;
            dp[i + 1][j + 1] %= mod;
        }
        if (!digit_flag[a[i]]) {
            digit_flag[a[i]] = true;
            digit_count++;
        }
    }
    return dp[n][k];
}
int main() {
    int tt;
    cin >> tt;
    while (tt--) {
        int n, b, k;
        cin >> n >> b >> k;
        vector<int> l(n), r(n);
        for (int i = 0; i < n; i++) {
            cin >> l[i];
        }
        for (int i = 0; i < n; i++) {
            cin >> r[i];
        }
        long long ans = Calc(r, n, b, k) - Calc(l, n, b, k);
        if ((int) set<int>(r.begin(), r.end()).size() == k) {
            ans++;
        }
        ans = (ans % mod + mod) % mod;
        cout << ans << '\n';
    }
    return 0;
}
Editorialist's code (Python)
mod = 10**9 + 7
maxN = 2005
C = [ [0 for _ in range(maxN)] for _ in range(maxN)]
pw = [ [0 for _ in range(maxN)] for _ in range(maxN)]
def f(n, x, y): # length n, x choices per spot, y of them should definitely occur
    res, sgn = 0, 1
    for i in range(y+1):
        res += sgn * C[y][i] * pw[x-i][n] % mod
        sgn *= -1
    return res % mod
def calc(N, b, k):
    if N[0] == 0: return 0
    n = len(N)
    ans = distinct = 0
    mark = [0]*b
    for i in range(n): # go down here
        used, unused = 0, 0
        for d in range(N[i]):
            if i > 0 or d > 0:
                used += mark[d]
                unused += 1 - mark[d]
        ans += used * C[b-distinct][k-distinct] * f(n-1-i, k, k-distinct) % mod
        if distinct < k:
            ans += unused * C[b-distinct-1][k-distinct-1] * f(n-1-i, k, k-distinct-1) % mod
        if mark[N[i]] == 0: distinct += 1
        mark[N[i]] = 1
        if distinct > k: break
    return ans % mod
for n in range(maxN):
    C[n][0] = 1
    for r in range(1, n+1): C[n][r] = (C[n-1][r] + C[n-1][r-1]) % mod
for x in range(1, maxN):
    pw[x][0] = 1
    for i in range(1, maxN): pw[x][i] = pw[x][i-1] * x % mod
for _ in range(int(input())):
    n, b, k = map(int, input().split())
    L = list(map(int, input().split()))
    R = list(map(int, input().split()))
    res = calc(R, b, k) - calc(L, b, k)
    if len(set(R)) == k: res += 1
    print(res % mod)