# MYPROBLEM - Editorial

Author: amir_quantom
Testers: tabr, yash_daga
Editorialist: iceknight1093

TBD

# PREREQUISITES:

Greedy algorithms

# PROBLEM:

Given N and K, find the smallest non-negative integer X such that N+X has at most K distinct digits when written in decimal.

# EXPLANATION:

First, if N already has at most K distinct digits the answer is of course 0.

Now, instead of finding X, let’s directly find N+X; we can subtract N from it in the end.

Let’s treat N as a string of digits.
Let the length of this string be L.
Similarly, look at N+X as a string of length L (an optimal N+X always has length L: it can't be shorter since N+X \geq N, and it needn't be longer since \underbrace{999\ldots 99}_{\text{L times}} has only one distinct digit and is \geq N).

Otherwise X \gt 0, so N+X \gt N and there must exist an index i such that:

• N_j = (N+X)_j for 1 \leq j \lt i
• N_i \lt (N+X)_i

That is, N and N+X will match on some prefix; and when they differ for the first time, N+X will have a higher digit.

Note that if we fix the position at which N and N+X first differ, then all further positions are free: we can place whatever digit we like there, and N+X \gt N will still hold.
We’ll use this to our advantage.

Fix the position i where N and N+X differ.
This fixes the values of (N+X)_j for 1 \leq j \lt i.

Now, we want N_i \lt (N+X)_i; but we also want N+X to be as small as possible. A little case analysis tells us the following:

• If N_i = 9 then there’s no valid choice of (N+X)_i, so ignore this i.
• Otherwise, ideally we’d like to choose (N+X)_i = N_i+1.
• This is almost always possible. The only time it isn’t is when it would make N+X have more than K distinct digits, i.e., the set \{N_1, N_2, N_3, \ldots, N_{i-1}, N_i+1\} has size \gt K.
• If this happens, then (N+X)_i should be the smallest element of \{N_1, N_2, N_3, \ldots, N_{i-1}\} that’s \gt N_i (there are only 10 digits, so it can be found by brute force). If no such element exists, ignore this i.
• Now we’ve fixed (N+X)_i, and all the other characters to its right are completely free: we just need to make sure that we still have \leq K distinct digits.
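• So, fill every position after i with a single digit d, chosen as the smallest digit that keeps N+X at \leq K distinct digits: d = 0 if fewer than K distinct digits have been used in positions 1 through i, and otherwise d is the smallest digit already used. Once again, this can be brute-forced, since we already know which digits are in the prefix.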

This way, we obtain (at most) one value of N+X for every prefix, each of which is valid.

Take the smallest among them as the final value of N+X, and thus compute X.

# TIME COMPLEXITY

\mathcal{O}(D\log N\log D) or \mathcal{O}(D^2\log N) per test case, where D = 10.

# CODE:

Setter's code (C++)
#include<bits/stdc++.h>
//#pragma GCC optimize("O2")
using namespace std;
// Setter's template shorthands: integer / floating type aliases.
using ll = long long;
using ld = long double;
// Template convenience macros. sz(x) yields a container's size as a signed int.
// NOTE(review): x is not parenthesized in sz(x), so only pass simple names.
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define sz(x) (int)x.size()
//#define endl '\n'
// Template constants; none of them is referenced by the setter's code shown here.
const int mod = 1e9 + 7;
const int inf = 2e9 + 5;
const ll linf = 9e18 + 5;

// Current test case (N and K from the statement), filled in by input().
int n;
int k;

// Per-test-case reset hook. solve() works only with locals plus the n/k that
// input() overwrites, so there is nothing to reset here.
void init() { /* no per-test state to clear */ }

// Reads one test case from stdin into the globals: first N, then K.
void input() {
    std::cin >> n;
    std::cin >> k;
}

// Solves one test case: prints the smallest X >= 0 such that n + X has at most
// k distinct decimal digits (n and k are the globals filled by input()).
//
// Every candidate for n + X is identified by the first position i at which it
// exceeds n:
//   - digits before i are copied from n;
//   - position i gets the smallest digit > s[i] that the distinct-digit budget
//     allows;
//   - every later position gets the smallest digit the budget allows.
// The minimum over all i is the answer.
//
// Candidates are built in long long: they can exceed INT_MAX (e.g. 2222222222
// for n = 10^9, k = 1), which overflowed the original int arithmetic.
void solve() {
    const std::string s = std::to_string(n);
    const int len = static_cast<int>(s.size());

    // Fallback: the all-9 number with as many digits as n is >= n and uses a
    // single distinct digit, so it is always valid.
    long long best = 0;
    for (int i = 0; i < len; i++) {
        best = best * 10 + 9;
    }

    std::set<int> seen;    // distinct digits of the matched prefix s[0..i)
    long long prefix = 0;  // numeric value of the matched prefix s[0..i)
    for (int i = 0; i < len && static_cast<int>(seen.size()) <= k; i++) {
        const int d = s[i] - '0';
        const int used = static_cast<int>(seen.size());

        int first = -1;  // digit placed at position i (> d); -1 means "no candidate"
        int fill = 0;    // digit repeated in every position after i
        if (d == 9) {
            // No digit is larger than 9, so no candidate diverges here.
        } else if (used == k) {
            // Budget exhausted: both digits must already occur in the prefix.
            const auto it = seen.upper_bound(d);
            if (it != seen.end()) {
                first = *it;
                fill = *seen.begin();
            }
        } else if (used == k - 1 && seen.count(d + 1) == 0) {
            // d + 1 takes the last free slot; the tail must reuse a used digit.
            first = d + 1;
            fill = seen.empty() ? first : std::min(first, *seen.begin());
        } else {
            // After placing d + 1 at least one slot is still free, so the
            // tail can be all zeros.
            first = d + 1;
            fill = 0;
        }

        if (first != -1) {
            long long candidate = prefix * 10 + first;
            for (int j = i + 1; j < len; j++) {
                candidate = candidate * 10 + fill;
            }
            best = std::min(best, candidate);
        }

        prefix = prefix * 10 + d;
        seen.insert(d);
    }

    // The whole of n was scanned without exceeding the budget: X = 0.
    if (static_cast<int>(seen.size()) <= k) {
        best = n;
    }

    std::cout << best - n << '\n';
}

// Per-test-case output hook. solve() already prints its answer, so nothing is
// emitted here.
void output() { /* answer already written by solve() */ }

// Entry point: enables unsynchronized stream I/O, reads the number of test
// cases, then runs init -> input -> solve -> output once per test case.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);

    int remaining = 1;
    std::cin >> remaining;
    for (; remaining != 0; --remaining) {
        init();
        input();
        solve();
        output();
    }

    return 0;
}

Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

// Strict input validator: slurps all of stdin into a buffer, then asserts that
// tokens, single-space / newline separators, value ranges and EOF appear
// exactly where expected.
struct input_checker {
    std::string buffer;
    int pos;

    const std::string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const std::string number = "0123456789";
    const std::string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const std::string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = std::cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back(static_cast<char>(c));
        }
    }

    // Index of the first ' ' or '\n' at or after pos (buffer.size() if none).
    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }

    // Consumes and returns the token starting at pos. The delimiter after it
    // is left in place for readSpace()/readEoln().
    std::string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        std::string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    // Reads a token with length in [minl, maxl] whose characters all occur in
    // pattern (any character is allowed when pattern is empty).
    std::string readString(int minl, int maxl, const std::string &pattern = "") {
        assert(minl <= maxl);
        std::string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != std::string::npos);
        }
        return res;
    }

    // Reads an int token and asserts it lies in [minv, maxv].
    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = std::stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    // Reads a long long token and asserts it lies in [minv, maxv].
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = std::stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    // Consumes exactly one ' '.
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    // Consumes exactly one '\n'.
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    // Asserts the whole input has been consumed.
    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

// Tester: validates the input format and answers each case by brute force.
// For each power t = 10^i, the high part n / t is bumped by 1..12. If the
// bumped value has <= k distinct digits, the low i digits are filled with its
// smallest digit (u = 11...1 with i ones), and the minimal such X is printed.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    input_checker in;

    // Sorted, de-duplicated decimal digits of x.
    auto distinct_digits = [](long long x) {
        std::string d = std::to_string(x);
        std::sort(d.begin(), d.end());
        d.erase(std::unique(d.begin(), d.end()), d.end());
        return d;
    };

    // NOTE(review): the original bounds on T and K were lost from this listing;
    // 1 <= T <= 10^5 and 1 <= K <= 10 are assumptions — confirm against the
    // problem statement.
    int tt = in.readInt(1, 100000);
    in.readEoln();
    while (tt--) {
        long long n = in.readInt(1, 1000000000);
        in.readSpace();
        int k = in.readInt(1, 10);
        in.readEoln();

        if ((int) distinct_digits(n).size() <= k) {
            std::cout << 0 << '\n';
            continue;
        }
        long long t = 1;
        long long u = 0;
        long long ans = 1000000000000000000LL;
        for (int i = 0; i < 10; i++) {
            long long x = n / t;
            for (int j = 0; j < 12; j++) {
                x++;
                const std::string ss = distinct_digits(x);
                if ((int) ss.size() <= k) {
                    ans = std::min(ans, (j + 1) * t - n % t + (ss[0] - '0') * u);
                }
            }
            u += t;
            t *= 10;
        }
        std::cout << ans << '\n';
    }
    in.readEof();
    return 0;
}

Editorialist's code (Python)
def solve_case(s, k):
    """Return the smallest X >= 0 such that int(s) + X has at most k distinct digits.

    s is the decimal representation of N; k is the allowed number of distinct digits.
    Candidates for N + X are tried for every position i where they first exceed N;
    all candidates have len(s) digits, so string comparison matches numeric order.
    """
    n = int(s)
    mark = [0] * 10  # mark[c] == 1 iff digit c occurs in the matched prefix s[:i]
    ans = '9' * len(s)  # always valid: a single distinct digit and >= N
    for i in range(len(s)):
        d = ord(s[i]) - ord('0')
        # d1: digit placed at position i (must exceed d); d2: digit for every later position.
        d1, d2 = d + 1, 0
        if sum(mark) < k:
            if sum(mark) == k - 1 and d1 < 10 and mark[d1] == 0:
                # d1 takes the last free slot, so the tail must reuse a used digit.
                mark[d1] = 1
                while mark[d2] == 0:
                    d2 += 1
                mark[d1] = 0
        else:
            # Budget exhausted: d1 and d2 must both already occur in the prefix.
            while d1 < 10 and mark[d1] == 0:
                d1 += 1
            while d2 < 10 and mark[d2] == 0:
                d2 += 1

        if d1 < 10:
            ans = min(ans, s[:i] + chr(ord('0') + d1) + chr(ord('0') + d2) * (len(s) - i - 1))

        mark[d] = 1
        if sum(mark) > k:
            break
    if sum(mark) <= k:
        # N itself already uses at most k distinct digits.
        ans = s
    return int(ans) - n


if __name__ == '__main__':
    for _ in range(int(input())):
        s, k = input().split()
        print(solve_case(s, int(k)))