MYPROBLEM - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: amir_quantom
Testers: tabr, yash_daga
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Greedy algorithms

PROBLEM:

Given N and K, find the smallest integer X such that N+X has at most K distinct digits when written in decimal.

EXPLANATION:

First, if N already has at most K distinct digits the answer is of course 0.

Now, instead of finding X, let’s directly find N+X; we can subtract N from it in the end.

Let’s treat N as a string of digits.
Let the length of this string be L.
Similarly, look at N+X as a string of length L. An optimal N+X will also have length L: the number \underbrace{99\ldots 9}_{L \text{ times}} has only one distinct digit and is \geq N, and every number with more than L digits is larger than it.

Since N itself has more than K distinct digits (otherwise the answer is 0), we need X \gt 0, i.e. N+X \gt N. So there must exist an index i such that:

  • N_j = (N+X)_j for 1 \leq j \lt i
  • N_i \lt (N+X)_i

That is, N and N+X will match on some prefix; and when they differ for the first time, N+X will have a higher digit.

Note that if we fix the position at which N and N+X first differ, then all further positions are free: we can place whatever digit we like there, and N+X \gt N will still hold.
We’ll use this to our advantage.

Fix the position i where N and N+X differ.
This fixes the values of (N+X)_j for 1 \leq j \lt i.

Now, we want N_i \lt (N+X)_i; but we also want N+X to be as small as possible. A little case analysis tells us the following:

  • If N_i = 9 then there’s no valid choice of (N+X)_i, so ignore this i.
  • Otherwise, ideally we’d like to choose (N+X)_i = N_i+1.
    • This is almost always possible. The only time it isn’t is when it would make N+X have more than K distinct digits, i.e., when the set \{N_1, N_2, N_3, \ldots, N_{i-1}, N_i+1\} has size \gt K.
    • If this happens to be the case, then (N+X)_i should be the smallest element of \{N_1, N_2, N_3, \ldots, N_{i-1}\} that’s \gt N_i. If no such element exists, ignore this i; otherwise, it can be found with brute force.
  • Now we’ve fixed (N+X)_i, and all the other characters to its right are completely free: we just need to make sure that we still have \leq K distinct digits.
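  • So, fill every position to the right of i with the same digit d, where d is the smallest digit that keeps N+X at \leq K distinct digits. If the first i positions of N+X already use exactly K distinct digits, d is the smallest of those digits; otherwise d = 0. Once again, this can simply be brute-forced, since we already know which digits are in the prefix.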

This way, we obtain (at most) one value of N+X for every prefix, each of which is valid.

Take the smallest among them as the final value of N+X, and thus compute X.

TIME COMPLEXITY

\mathcal{O}(D\log N\log D) or \mathcal{O}(D^2\log N) per test case, where D = 10.

CODE:

Setter's code (C++)
#include<bits/stdc++.h>
//#pragma GCC optimize("O2")
using namespace std;
using ll = long long;
using ld = long double;
#define pb push_back
#define mp make_pair
#define fi first
#define se second
// Container size as a signed int. The argument is parenthesized so that
// expressions such as sz(*ptr) or sz(cond ? a : b) expand correctly.
#define sz(x) ((int)(x).size())
//#define endl '\n'
const int mod = 1e9 + 7;
const int inf = 2e9 + 5;
// Note: 9e18 is a double, so the "+ 5" is lost to rounding; linf == 9e18.
const ll linf = 9e18 + 5;


// Current test case: filled by input(), consumed by solve().
int n;
int k;

// Per-test-case reset hook called by main(); this problem keeps no state
// between test cases, so there is nothing to do.
void init() { /* intentionally empty */ }

// Reads one test case from standard input into the globals n and k.
void input() {
    cin >> n;
    cin >> k;
}

// Returns the smallest integer M >= value that has the same number of decimal
// digits as value and uses at most k_limit distinct digits; returns value
// itself when it already qualifies. Expects k_limit >= 1.
//
// Greedy from the editorial: for each position i, keep value's prefix before
// i, put the smallest admissible digit `fore` greater than value's digit at i,
// and fill every later position with the smallest admissible digit `exten`.
// (A candidate that branches off later shares a longer prefix with value and
// is therefore smaller; an explicit minimum is still taken for clarity.)
//
// All arithmetic is long long: for value = 10^9 and k_limit = 1 the i = 0
// candidate is 2222222222, which overflows a 32-bit int (undefined behaviour).
static long long smallest_valid_at_least(long long value, int k_limit) {
    const std::string s = std::to_string(value);
    const int len = static_cast<int>(s.size());

    // len nines: one distinct digit and >= value, so always a valid fallback.
    long long best = 0;
    for (int i = 0; i < len; i++) {
        best = best * 10 + 9;
    }

    std::set<int> digits;  // distinct digits of the prefix s[0..i-1]
    long long prefix = 0;  // numeric value of the prefix s[0..i-1]
    for (int i = 0; i < len; i++) {
        const int used = static_cast<int>(digits.size());
        if (used > k_limit) {
            break;  // no candidate can keep this prefix
        }

        const int d = s[i] - '0';
        int fore = -1;  // digit placed at position i (must exceed d); -1 = none
        int exten = 0;  // digit repeated in every position after i
        if (d < 9) {
            if (used == k_limit) {
                // No new digit allowed: reuse the smallest prefix digit > d,
                // and pad with the smallest prefix digit.
                auto it = digits.upper_bound(d);
                if (it != digits.end()) {
                    fore = *it;
                    exten = *digits.begin();
                }
            } else if (used == k_limit - 1 && digits.count(d + 1) == 0) {
                // d + 1 is the last new digit allowed, so the tail must reuse
                // an existing digit (or d + 1 itself).
                fore = d + 1;
                exten = digits.empty() ? fore : std::min(fore, *digits.begin());
            } else {
                // After placing d + 1 a new digit is still allowed: pad with 0.
                fore = d + 1;
                exten = 0;
            }
        }

        if (fore != -1) {
            long long candidate = prefix * 10 + fore;
            for (int j = i + 1; j < len; j++) {
                candidate = candidate * 10 + exten;
            }
            best = std::min(best, candidate);
        }

        prefix = prefix * 10 + d;
        digits.insert(d);
    }

    return static_cast<int>(digits.size()) <= k_limit ? value : best;
}

// Prints the minimum X >= 0 such that n + X has at most k distinct digits,
// using the globals n and k filled by input().
void solve() {
    std::cout << smallest_valid_at_least(n, k) - n << '\n';
}

// Per-test-case output hook called by main(); solve() already prints the
// answer, so there is nothing left to emit here.
void output() { /* intentionally empty */ }

int main() {
    // Unsynchronized, untied streams for fast I/O.
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);

    // The first value on stdin is the number of test cases.
    int tests = 1;
    cin >> tests;
    for (; tests != 0; --tests) {
        init();
        input();
        solve();
        output();
    }

    return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

// Strict validator for test input. The constructor slurps all of std::cin into
// `buffer`; the read* methods then consume tokens starting at `pos` and assert
// that each one has exactly the expected format and range.
struct input_checker {
    std::string buffer;
    int pos;

    // Character sets usable as the `pattern` argument of readString.
    const std::string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const std::string number = "0123456789";
    const std::string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const std::string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = std::cin.get();
            if (c == std::char_traits<char>::eof()) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    // Index of the first ' ' or '\n' at or after pos (buffer.size() if none).
    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }

    // Consumes and returns the characters from pos up to the next delimiter.
    // The delimiter itself is left for readSpace/readEoln to consume.
    std::string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        std::string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    // True iff s is a canonical decimal integer: an optional '-' followed by
    // one or more digits, with no leading zeros ("0" is accepted; "", "-",
    // "-0", "007", "+5", "12x" and "1e5" are not). std::stoi/std::stoll alone
    // would accept several of these, because they parse only a valid prefix.
    static bool isCanonicalInteger(const std::string &s) {
        const std::size_t start = (!s.empty() && s[0] == '-') ? 1 : 0;
        if (start == s.size()) {
            return false;  // empty token or a lone '-'
        }
        for (std::size_t i = start; i < s.size(); i++) {
            if (s[i] < '0' || s[i] > '9') {
                return false;
            }
        }
        if (s[start] == '0' && s.size() - start > 1) {
            return false;  // leading zero
        }
        if (start == 1 && s[1] == '0') {
            return false;  // negative zero
        }
        return true;
    }

    // Reads a token whose length is in [minl, maxl] and, if pattern is
    // non-empty, whose characters all occur in pattern.
    std::string readString(int minl, int maxl, const std::string &pattern = "") {
        assert(minl <= maxl);
        std::string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != std::string::npos);
        }
        return res;
    }

    // Reads a canonical integer token in [minv, maxv].
    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        const std::string token = readOne();
        assert(isCanonicalInteger(token));
        int res = std::stoi(token);
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    // Reads a canonical integer token in [minv, maxv].
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        const std::string token = readOne();
        assert(isCanonicalInteger(token));
        long long res = std::stoll(token);
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    // Consumes exactly one ' '.
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    // Consumes exactly one '\n'.
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    // Asserts that the whole input has been consumed.
    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    input_checker in;
    int tt = in.readInt(1, 1e5);
    in.readEoln();
    while (tt--) {
        long long n = in.readInt(1, 1e9);
        in.readSpace();
        int k = in.readInt(1, 10);
        in.readEoln();
        auto ss = to_string(n);
        sort(ss.begin(), ss.end());
        ss.resize(unique(ss.begin(), ss.end()) - ss.begin());
        if ((int) ss.size() <= k) {
            cout << 0 << '\n';
            continue;
        }
        long long t = 1;
        long long u = 0;
        long long ans = 1e18;
        for (int i = 0; i < 10; i++) {
            long long x = n / t;
            for (int j = 0; j < 12; j++) {
                x++;
                ss = to_string(x);
                sort(ss.begin(), ss.end());
                ss.resize(unique(ss.begin(), ss.end()) - ss.begin());
                if ((int) ss.size() <= k) {
                    ans = min(ans, (j + 1) * t - n % t + (ss[0] - '0') * u);
                }
            }
            u += t;
            t *= 10;
        }
        cout << ans << '\n';
    }
    in.readEof();
    return 0;
}
Editorialist's code (Python)
def min_increment(s, k):
	"""Return the smallest X >= 0 such that int(s) + X has at most k distinct digits.

	s is the decimal representation of N (no leading zeros) and k >= 1.
	Only values with the same number of digits as N are considered; this
	suffices because '9' * len(s) always qualifies.
	"""
	n = int(s)
	mark = [0]*10  # mark[c] == 1 iff digit c occurs in the prefix s[:i]
	ans = '9'*len(s)  # always valid; equal-length digit strings compare like numbers
	for i in range(len(s)):
		d = ord(s[i]) - ord('0')
		d1, d2 = d + 1, 0  # d1: digit placed at position i, d2: digit for the tail
		if sum(mark) < k:
			if sum(mark) == k-1 and d1 < 10 and mark[d1] == 0:
				# d1 is the last new digit allowed, so the tail must reuse a used digit
				mark[d1] = 1
				while mark[d2] == 0: d2 += 1
				mark[d1] = 0
		else:
			# The prefix already uses k digits: d1 and d2 must come from it
			while d1 < 10 and mark[d1] == 0: d1 += 1
			while d2 < 10 and mark[d2] == 0: d2 += 1

		if d1 < 10:
			ans = min(ans, s[:i] + chr(ord('0') + d1) + chr(ord('0') + d2)*(len(s)-i-1))

		mark[d] = 1
		if sum(mark) > k: break
	if sum(mark) <= k: ans = s
	return int(ans) - n


if __name__ == '__main__':
	for _ in range(int(input())):
		s, k = input().split()
		print(min_increment(s, int(k)))