COUNTBASEB - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: a_18o3
Tester: tabr
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Combinatorics, inclusion-exclusion or dynamic programming

PROBLEM:

You’re given two integers L and R in base B, each of which has an N-digit representation.
Count the number of integers between L and R (inclusive) that have exactly K distinct digits in base B.

EXPLANATION:

Let’s simplify the setup slightly: we’ll only count numbers \lt R that are N digits long in base B (so no leading zero) and have exactly K distinct digits.

Let A = [A_1, A_2, \ldots, A_N] be an N-digit (in base B) integer that’s \lt R.
Then, there must exist an index i such that:

  • A_j = R_j for j \lt i
  • A_i \lt R_i

Let’s instead fix this first differing index i, and count the number of valid A.
First, we know that the first i-1 digits of A match those of R, so all of those digits certainly appear in A.
Then, we also need to fix A_i to be something strictly less than R_i. There are \mathcal{O}(B) choices for what A_i is.
Once A_i is fixed, note that all the digits at positions \gt i can be freely chosen: the only constraint is that there are K distinct digits overall.

So, suppose there are d distinct digits among the first i positions once A_i is fixed (if d \gt K, this choice contributes nothing).
We then need to choose another K-d digits from the B-d unused ones, and fill the remaining N-i positions using the K chosen digits such that the newly chosen digits all occur at least once.
The number of such arrangements can be found in \mathcal{O}(B) time using inclusion-exclusion.

How?

First, we choose the new digits: there are \binom{B-d}{K-d} choices from the unused ones.

If we didn’t care about using the new digits at least once, we’d simply have K choices for each of the N-i positions, for K^{N-i} choices in total.
From here, we want to subtract the number of configurations in which some of the new K-d digits don’t appear.

For convenience, let’s label the new digits 1, 2, 3, \ldots, K-d.
Let S_j denote the set of configurations that don’t contain digit j (the subscript here is a digit label, unrelated to the position i fixed earlier).
We want to compute the size of the union of all the S_i, that is, |S_1\cup S_2\cup\ldots\cup S_{K-d}|

The inclusion-exclusion principle tells us that

|S_1\cup S_2\cup\ldots\cup S_{K-d}| = \sum_{x=1}^{K-d} (-1)^{x+1} \binom{K-d}{x} (K-x)^{N-i}

This follows from the fact that for a fixed set of x of these digits, the number of configurations that contain none of them (and possibly miss other digits too) is exactly (K-x)^{N-i}, and there are \binom{K-d}{x} ways to choose this subset of size x.
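For reference, subtracting the union from the unrestricted count K^{N-i} folds everything into a single alternating sum. The small numeric check in the comments uses values picked purely for illustration.

```latex
K^{N-i} - |S_1\cup S_2\cup\ldots\cup S_{K-d}|
  = \sum_{x=0}^{K-d} (-1)^{x} \binom{K-d}{x} (K-x)^{N-i}
% Example with K = 3, d = 1, N-i = 3 (two new digits must both appear):
% 3^3 - \binom{2}{1} 2^3 + \binom{2}{2} 1^3 = 27 - 16 + 1 = 12
```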

Since K-d \leq B, this sum is easily found in \mathcal{O}(B \log N) time, or even \mathcal{O}(B) if you precompute powers.
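Here is a rough sketch of that sum in Python, using exact integers (no modulus) so it's easy to check by hand. The names fill_count and suffix_count are hypothetical helpers, not taken from the solutions below, and the sketch assumes 0 \leq d \leq B and K \leq B.

```python
from math import comb


def fill_count(m, K, d):
    """Strings of length m over a fixed set of K digits in which each of
    the K - d designated 'new' digits appears at least once."""
    new = K - d
    if new < 0:
        return 0  # the prefix already has more than K distinct digits
    total = 0
    for x in range(new + 1):
        # x of the new digits are forbidden, leaving K - x choices per position
        term = comb(new, x) * (K - x) ** m
        total += -term if x % 2 else term
    return total


def suffix_count(m, B, K, d):
    """Also account for which K - d of the B - d unused digits are chosen."""
    if d > K:
        return 0
    return comb(B - d, K - d) * fill_count(m, K, d)
```

For instance, fill_count(3, 3, 1) gives 27 - 16 + 1 = 12. A real submission would reduce everything modulo 10^9 + 7 and precompute binomials and powers instead of calling comb each time.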

It’s also possible to use dynamic programming, as can be seen in the tester’s code below.
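As a loose illustration of the DP idea, the free positions can be filled one at a time while tracking how many distinct digits the number has so far. This is only a sketch in the spirit of the transitions in the tester's code (reuse one of j seen digits, or bring in one of B-j unseen ones); suffix_count_dp is a made-up name, and it assumes 0 \leq d \leq B and K \leq B.

```python
def suffix_count_dp(m, B, K, d):
    """Ways to fill m free positions with base-B digits so that the whole
    number ends with exactly K distinct digits, given that the fixed part
    already contains d distinct digits."""
    if d > K:
        return 0
    # ways[j] = number of partial fillings with j distinct digits overall
    ways = [0] * (B + 2)
    ways[d] = 1
    for _ in range(m):
        nxt = [0] * (B + 2)
        for j in range(d, B + 1):
            if ways[j] == 0:
                continue
            nxt[j] += ways[j] * j            # reuse one of the j seen digits
            nxt[j + 1] += ways[j] * (B - j)  # introduce one of the B - j unseen digits
        ways = nxt
    return ways[K]
```

This costs \mathcal{O}(m \cdot B) per call, so calling it separately for every prefix would be too slow; it's meant only to show that the DP and the inclusion-exclusion sum count the same thing.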


We now have a solution in \mathcal{O}(NB^2): fix the prefix, fix the smaller digit, and compute the number of arrangements of the remaining part using inclusion-exclusion.
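A direct, unoptimized sketch of this version might look as follows. It assumes R is a list of digits, most significant first, with a nonzero leading digit, and that the answer is wanted modulo 10^9 + 7 as in the code at the end; count_less and suffix_count are illustrative names only.

```python
from math import comb

MOD = 10**9 + 7


def suffix_count(m, B, K, d):
    """Fillings of m free positions that bring d distinct digits up to exactly K."""
    if d > K:
        return 0
    new = K - d
    total = 0
    for x in range(new + 1):
        term = comb(new, x) * pow(K - x, m, MOD)
        total += -term if x % 2 else term
    return comb(B - d, new) * total % MOD


def count_less(R, B, K):
    """N-digit base-B numbers strictly below R with exactly K distinct digits."""
    n = len(R)
    seen = set()  # distinct digits of the prefix R[0..i-1]
    ans = 0
    for i in range(n):
        for digit in range(R[i]):  # candidate values for A_i < R_i
            if i == 0 and digit == 0:
                continue  # a leading zero would not be an N-digit number
            d = len(seen | {digit})
            ans += suffix_count(n - 1 - i, B, K, d)
        seen.add(R[i])
    return ans % MOD
```

With R = [2, 1, 0] in base 3, the numbers below R with three distinct digits are 102, 120 and 201, so count_less([2, 1, 0], 3, 3) comes out to 3.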

To further improve this, observe that while there are \mathcal{O}(B) choices for the smaller digit A_i, its actual value doesn’t affect the following computation: the only thing that matters is whether it has occurred in the prefix before or not (so whether it increases d by 1 or not).
So, instead of trying every choice of digit, we compute the number of configurations for both the cases when A_i has and hasn’t appeared before, and then multiply them by the appropriate counts of digits.

This knocks off a factor of B from the complexity, and we have \mathcal{O}(N\cdot B), which is fast enough.
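The grouping step could be sketched like this: instead of evaluating the sum once per candidate digit, count how many candidates were already seen in the prefix and how many were not, and evaluate the sum at most twice per position. Same assumptions as before (digits most significant first, nonzero leading digit, answer modulo 10^9 + 7); the names are hypothetical, and since this sketch calls pow directly it is really \mathcal{O}(N \cdot B \log N) rather than \mathcal{O}(N \cdot B).

```python
from math import comb

MOD = 10**9 + 7


def suffix_count(m, B, K, d):
    """Fillings of m free positions that bring d distinct digits up to exactly K."""
    if d > K:
        return 0
    new = K - d
    total = 0
    for x in range(new + 1):
        term = comb(new, x) * pow(K - x, m, MOD)
        total += -term if x % 2 else term
    return comb(B - d, new) * total % MOD


def count_less_grouped(R, B, K):
    """Same count as the per-digit version, with one sum per group of digits."""
    n = len(R)
    seen = [False] * B
    distinct = 0
    ans = 0
    for i in range(n):
        used = unused = 0
        for digit in range(R[i]):
            if i == 0 and digit == 0:
                continue  # no leading zero
            if seen[digit]:
                used += 1    # A_i repeats a prefix digit: d stays the same
            else:
                unused += 1  # A_i is a new digit: d grows by one
        m = n - 1 - i
        if used:
            ans += used * suffix_count(m, B, K, distinct)
        if unused:
            ans += unused * suffix_count(m, B, K, distinct + 1)
        if not seen[R[i]]:
            seen[R[i]] = True
            distinct += 1
        if distinct > K:
            break  # every longer prefix already has too many distinct digits
    return ans % MOD
```

Replacing comb and pow with precomputed tables modulo 10^9 + 7 is what brings this down to the stated bound.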


To finish, simply apply the above solution twice: once to count valid numbers \lt R, and once to count valid numbers \lt L.
Subtract the latter from the former, and add 1 if R itself has exactly K distinct digits (the two counts are strict, so L is already included but R is not).
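The final combination step is small enough to sketch on its own. To keep the snippet self-contained, count_less_slow below is a brute-force stand-in (an assumption for illustration, usable only on tiny inputs) for the fast "count below X" routine; count_in_range is likewise a made-up name.

```python
MOD = 10**9 + 7


def digits_of(v, n, B):
    """The n base-B digits of v, most significant first."""
    out = []
    for _ in range(n):
        out.append(v % B)
        v //= B
    return out[::-1]


def count_less_slow(X, B, K):
    """Brute-force stand-in: N-digit numbers strictly below X with K distinct digits."""
    n = len(X)
    hi = 0
    for x in X:
        hi = hi * B + x
    lo = B ** (n - 1)  # smallest N-digit number
    return sum(1 for v in range(lo, hi) if len(set(digits_of(v, n, B))) == K)


def count_in_range(L, R, B, K, count_less=count_less_slow):
    """Count valid numbers in [L, R] from two strict 'below X' counts."""
    res = count_less(R, B, K) - count_less(L, B, K)
    if len(set(R)) == K:
        res += 1  # R itself is excluded by the strict count, so add it back
    return res % MOD
```

For L = [1, 0, 0] and R = [2, 1, 0] in base 3 with K = 3, the valid numbers are 102, 120, 201 and 210, so count_in_range returns 4.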

TIME COMPLEXITY:

\mathcal{O}(N\cdot B) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

const int facN = 1e6 + 5;
const int mod = 1e9 + 7; // 998244353
int ff[facN], iff[facN];
bool facinit = false;

int power(int x, int y){
    if (y == 0) return 1;

    int v = power(x, y / 2);
    v = 1LL * v * v % mod;

    if (y & 1) return 1LL * v * x % mod;
    else return v;
}

void factorialinit(){
    facinit = true;
    ff[0] = iff[0] = 1;

    for (int i = 1; i < facN; i++){
        ff[i] = 1LL * ff[i - 1] * i % mod;
    }

    iff[facN - 1] = power(ff[facN - 1], mod - 2);
    for (int i = facN - 2; i >= 1; i--){
        iff[i] = 1LL * iff[i + 1] * (i + 1) % mod;
    }
}

int C(int n, int r){
    if (!facinit) factorialinit();

    if (n == r) return 1;

    if (r < 0 || r > n) return 0;
    return 1LL * ff[n] * iff[r] % mod * iff[n - r] % mod;
}

int P(int n, int r){
    if (!facinit) factorialinit();

    assert(0 <= r && r <= n);
    return 1LL * ff[n] * iff[n - r] % mod;
}

int Solutions(int n, int r){
    //solutions to x1 + ... + xn = r 
    //xi >= 0

    return C(n + r - 1, n - 1);
}

void Solve() 
{
    int n, b, k; cin >> n >> b >> k;

    vector <int> l(n), r(n);

    for (auto &x : l) cin >> x;
    for (auto &x : r) cin >> x;

    int f = n;
    for (int i = 0; i < n; i++){
        if (l[i] != r[i]){
            f = i;
            break;
        }
    }

    int ans = 0;

    for (int i = f - 1; i < n; i++){
        // cout << "STEP " << i << "\n";
        // 0...i same as l 
        vector <bool> hv(b, false);
        for (int j = 0; j <= i; j++){
            hv[l[j]] = true;
        }

        int u = 0;
        for (int j = 0; j < b; j++){
            if (hv[j]){
                u++;
            }
        }
        
        // cout << u << "\n";

        if (i == n - 1){
            if (u == k)
            ans++;
            continue;
        }

        int w1 = 0, w0 = 0;
        for (int j = 0; j < b; j++){
            // needs to be > l[i + 1] 
            if (j <= l[i + 1]) continue;
            if (i == f - 1 && j >= r[i + 1]) continue;
            if (hv[j]) w0++;
            else w1++;
        }
        
        // cout << w0 << " " << w1 << "\n";

        // you have u initially, how many ways to get x more? 
        // how many ways to get 1 more? 
        {
            int un = b - u;
            int pos = n - i - 2;
            int okie = 0;
            int need = k - u;
            
            for (int t = 0; t <= need; t++){
                int ways = C(need, t) * power(u + need - t, pos) % mod;
                // cout << ways << " \n"[t == need];
                if (t & 1) okie -= ways;
                else okie += ways;
            }

            okie %= mod;
            if (okie < 0) okie += mod;

            okie *= C(un, need);
            okie %= mod;
            
            // cout << "OKIE " << okie << " " << w0 << "\n";

            ans += okie * w0 % mod;
        }

        u++;

        {
            int un = b - u;
            int pos = n - i - 2;
            int okie = 0;
            int need = k - u;
            
            for (int t = 0; t <= need; t++){
                int ways = C(need, t) * power(u + need - t, pos) % mod;
                // cout << ways << " \n"[t == need];
                if (t & 1) okie -= ways;
                else okie += ways;
            }

            okie %= mod;
            if (okie < 0) okie += mod;

            okie *= C(un, need);
            okie %= mod;
            
            // cout << "OKIE " << okie << " " << w1 << "\n";

            ans += okie * w1 % mod;
        }
        // cout << "ANS " << ans << "\n";
    }
    
    // cout << ans << "\n";
    
    for (int i = f; i < n; i++){
        vector <bool> hv(b, false);
        for (int j = 0; j <= i; j++){
            hv[r[j]] = true;
        }

        int u = 0;
        for (int j = 0; j < b; j++){
            if (hv[j]){
                u++;
            }
        }

        if (i == n - 1){
            if (u == k)
            ans++;
            continue;
        }

        int w1 = 0, w0 = 0;
        for (int j = 0; j < b; j++){
            // needs to be < r[i + 1] 
            if (j >= r[i + 1]) continue;
            
            if (hv[j]) w0++;
            else w1++;
        }

        // you have u initially, how many ways to get x more? 
        // how many ways to get 1 more? 
        {
            int un = b - u;
            int pos = n - i - 2;
            int okie = 0;
            int need = k - u;
            
            for (int t = 0; t <= need; t++){
                int ways = C(need, t) * power(u + need - t, pos) % mod;
                if (t & 1) okie -= ways;
                else okie += ways;
            }

            okie %= mod;
            if (okie < 0) okie += mod;

            okie *= C(un, need);
            okie %= mod;

            ans += okie * w0 % mod;
        }

        u++;

        {
            int un = b - u;
            int pos = n - i - 2;
            int okie = 0;
            int need = k - u;
            
            for (int t = 0; t <= need; t++){
                int ways = C(need, t) * power(u + need - t, pos) % mod;
                if (t & 1) okie -= ways;
                else okie += ways;
            }

            okie %= mod;
            if (okie < 0) okie += mod;

            okie *= C(un, need);
            okie %= mod;

            ans += okie * w1 % mod;
        }
        
        // cout << "ANS " << ans << "\n";
    }

    ans %= mod;
    cout << ans << "\n";
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;

constexpr long long mod = (int) 1e9 + 7;

long long Calc(vector<int> a, int n, int b, int k) {
    // dp[seen first i digits][number of distinct digits is j]
    vector dp(n + 1, vector(n + 1, 0LL));
    vector digit_flag(b + 1, false);
    int digit_count = 0;
    for (int i = 0; i < n; i++) {
        for (int d = 0; d < a[i]; d++) {
            if (digit_flag[d]) {
                dp[i + 1][digit_count] += 1;
                dp[i + 1][digit_count] %= mod;
            } else {
                dp[i + 1][digit_count + 1] += 1;
                dp[i + 1][digit_count + 1] %= mod;
            }
        }
        for (int j = 0; j < n; j++) {
            dp[i + 1][j] += dp[i][j] * j;
            dp[i + 1][j + 1] += dp[i][j] * (b - j);
            dp[i + 1][j] %= mod;
            dp[i + 1][j + 1] %= mod;
        }
        if (!digit_flag[a[i]]) {
            digit_flag[a[i]] = true;
            digit_count++;
        }
    }
    return dp[n][k];
}

int main() {
    int tt;
    cin >> tt;
    while (tt--) {
        int n, b, k;
        cin >> n >> b >> k;
        vector<int> l(n), r(n);
        for (int i = 0; i < n; i++) {
            cin >> l[i];
        }
        for (int i = 0; i < n; i++) {
            cin >> r[i];
        }
        long long ans = Calc(r, n, b, k) - Calc(l, n, b, k);
        if ((int) set<int>(r.begin(), r.end()).size() == k) {
            ans++;
        }
        ans = (ans % mod + mod) % mod;
        cout << ans << '\n';
    }
    return 0;
}
Editorialist's code (Python)
mod = 10**9 + 7
maxN = 2005
C = [ [0 for _ in range(maxN)] for _ in range(maxN)]
pw =  [ [0 for _ in range(maxN)] for _ in range(maxN)]

def f(n, x, y): # length n, x choices per spot, y of them should definitely occur
    res, sgn = 0, 1
    for i in range(y+1):
        res += sgn * C[y][i] * pw[x-i][n] % mod
        sgn *= -1
    return res % mod

def calc(N, b, k):
    if N[0] == 0: return 0
    n = len(N)
    ans = distinct = 0
    mark = [0]*b
    for i in range(n): # go down here
        used, unused = 0, 0
        for d in range(N[i]):
            if i > 0 or d > 0:
                used += mark[d]
                unused += 1 - mark[d]
        
        ans += used * C[b-distinct][k-distinct] * f(n-1-i, k, k-distinct) % mod
        if distinct < k:
            ans += unused * C[b-distinct-1][k-distinct-1] * f(n-1-i, k, k-distinct-1) % mod
        if mark[N[i]] == 0: distinct += 1
        mark[N[i]] = 1
        if distinct > k: break
    return ans % mod

for n in range(maxN):
    C[n][0] = 1
    for r in range(1, n+1): C[n][r] = (C[n-1][r] + C[n-1][r-1]) % mod
for x in range(1, maxN):
    pw[x][0] = 1
    for i in range(1, maxN): pw[x][i] = pw[x][i-1] * x % mod

for _ in range(int(input())):
    n, b, k = map(int, input().split())
    L = list(map(int, input().split()))
    R = list(map(int, input().split()))
    res = calc(R, b, k) - calc(L, b, k)
    if len(set(R)) == k: res += 1
    print(res % mod)