PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: amir_quantom
Testers: tabr, yash_daga
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
Greedy algorithms
PROBLEM:
Given N and K, find the smallest integer X such that N+X has at most K distinct digits when written in decimal.
EXPLANATION:
First, if N already has at most K distinct digits the answer is of course 0.
Now, instead of finding X, let’s directly find N+X; we can subtract N from it in the end.
Let’s treat N as a string of digits.
Let the length of this string be L.
Similarly, look at N+X as a string of length L (an optimal N+X also has length L: it cannot be shorter than N, and it never needs to be longer, since \underbrace{99\ldots 9}_{L\text{ times}} has only one distinct digit and is \geq N).
Since N+X\gt N (the case X = 0 was handled above), there must exist an index i such that:
- N_j = (N+X)_j for 1 \leq j \lt i
- N_i \lt (N+X)_i
That is, N and N+X will match on some prefix; and when they differ for the first time, N+X will have a higher digit.
Note that if we fix the position at which N and N+X first differ, then all further positions are free: we can place whatever digit we like there, and N+X \gt N will still hold.
We’ll use this to our advantage.
Fix the position i where N and N+X differ.
This fixes the values of (N+X)_j for 1 \leq j \lt i.
Now, we want N_i \lt (N+X)_i; but we also want N+X to be as small as possible. A little case analysis tells us the following:
- If N_i = 9 then there’s no valid choice of (N+X)_i, so ignore this i.
- Otherwise, ideally we’d like to choose (N+X)_i = N_i+1.
- This is almost always possible. The only time it isn’t is when it would make N+X have more than K distinct digits, i.e., the set \{N_1, N_2, N_3, \ldots, N_{i-1}, N_i+1\} has size \gt K.
- If this happens to be the case, then (N+X)_i should be the smallest element of \{N_1, N_2, N_3, \ldots, N_{i-1}\} that’s \gt N_i. If no such element exists, ignore this i; otherwise, it can be found with brute force.
- Now we’ve fixed (N+X)_i, and all the other characters to its right are completely free: we just need to make sure that we still have \leq K distinct digits.
- So, fill every remaining position (to the right of i) with a single digit d, chosen as the smallest digit that keeps N+X at no more than K distinct digits. Once again, this can simply be brute-forced, since we already know which digits are in the prefix.
This way, we obtain (at most) one value of N+X for every prefix, each of which is valid.
Take the smallest among them as the final value of N+X, and thus compute X.
TIME COMPLEXITY
\mathcal{O}(D\log N\log D) or \mathcal{O}(D^2\log N) per test case, where D = 10.
CODE:
Setter's code (C++)
#include<bits/stdc++.h>
//#pragma GCC optimize("O2")
using namespace std;
// Competitive-programming template shorthands. Of these, this program only
// references `sz` (in solve) and `ll` (in `linf`); the remaining aliases,
// macros and constants are unused boilerplate.
using ll = long long;
using ld = long double;
#define pb push_back
#define mp make_pair
#define fi first
#define se second
// Container size as int. NOTE: the argument is not parenthesised, so only pass
// a simple expression (sz(v) is fine; sz(a + b) would expand to (int)a + b.size()).
#define sz(x) (int)x.size()
//#define endl '\n'
const int mod = 1e9 + 7;
const int inf = 2e9 + 5;
// 9e18 + 5 is evaluated in double, where the + 5 is below the rounding
// granularity; the stored value is exactly 9000000000000000000.
const ll linf = 9e18 + 5;
// Current test case, filled by input(): find the smallest X >= 0 such that
// n + X has at most k distinct decimal digits.
int n;
int k;
// Per-test-case setup hook; this problem keeps no state to reset.
void init() {
}
// Reads one test case ("n k") from standard input into the globals above.
void input() {
    std::cin >> n >> k;
}
// Prints the smallest X >= 0 such that n + X has at most k distinct digits.
//
// Greedy from the editorial: for every index i, keep n's digits before i,
// place the smallest admissible digit greater than s[i] at position i, and
// fill every later position with the smallest digit that keeps the total
// number of distinct digits <= k. The minimum over these candidates (and
// 99...9, which is always valid) is n + X.
//
// All candidate arithmetic is done in long long: for n = 10^9 a candidate
// such as 2222222222 does not fit in a 32-bit int.
void solve() {
    const std::string s = std::to_string(n);
    const int len = static_cast<int>(s.size());

    // len nines: a single distinct digit and never below n.
    long long ans = 0;
    for (int i = 0; i < len; i++) {
        ans = ans * 10 + 9;
    }

    std::set<int> digits;  // distinct digits of the matched prefix s[0..i)
    long long now = 0;     // numeric value of the matched prefix s[0..i)
    for (int i = 0; i < len; i++) {
        const int used = static_cast<int>(digits.size());
        if (used > k) {
            break;  // the matched prefix alone already has too many digits
        }
        const int d = s[i] - '0';
        int fore = -1;  // digit placed at position i (must exceed d); -1 = none fits
        int exten = 0;  // digit repeated in every position after i
        if (d != 9) {
            if (used == k) {
                // No new digit may appear: reuse the smallest prefix digit > d
                // and fill with the smallest prefix digit.
                const auto it = digits.upper_bound(d);
                if (it != digits.end()) {
                    fore = *it;
                    exten = *digits.begin();
                }
            } else {
                fore = d + 1;
                if (used == k - 1 && digits.count(fore) == 0) {
                    // d + 1 takes the last free slot, so the fill must come
                    // from prefix digits plus d + 1.
                    exten = digits.empty() ? fore : std::min(fore, *digits.begin());
                }
                // Otherwise a slot is still free, so the fill can be 0.
            }
        }
        if (fore >= 0) {
            long long candidate = now * 10 + fore;
            for (int j = i + 1; j < len; j++) {
                candidate = candidate * 10 + exten;
            }
            // Later prefixes always give smaller candidates; min states intent.
            ans = std::min(ans, candidate);
        }
        now = now * 10 + d;
        digits.insert(d);
    }

    // digits.size() <= k here exactly when n itself has at most k distinct
    // digits (an early break leaves more than k in the set).
    if (static_cast<int>(digits.size()) <= k) {
        ans = n;
    }
    std::cout << ans - n << '\n';
}
// Per-test-case output hook; solve() already printed the answer.
void output() {
}
int main() {
    // freopen("parsadox2.txt","r+",stdin);
    // freopen("parsadox.txt","w+",stdout);
    // Detach the C++ streams from C stdio and untie them for fast bulk I/O.
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    // The first token is the test count; each test runs the
    // init -> input -> solve -> output pipeline exactly once.
    int tests_left = 1;
    cin >> tests_left;
    for (; tests_left != 0; --tests_left) {
        init();
        input();
        solve();
        output();
    }
    return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
// Optional debug helpers: when the macro `tabr` is defined (presumably only in
// the tester's local build), include the tester's local debug library;
// otherwise debug(...) expands to nothing. debug is not called in this program.
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif
// Strict validator for the judge input: the constructor buffers all of
// standard input, and the read* methods then consume it token by token,
// asserting on exact separators, token lengths, character sets and ranges.
struct input_checker {
    string buffer;  // entire standard input
    int pos;        // index of the next unconsumed character in buffer
    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";
    // Slurps characters until cin.get() returns -1 (end of input).
    input_checker() : pos(0) {
        for (int ch = cin.get(); ch != -1; ch = cin.get()) {
            buffer += (char) ch;
        }
    }
    // Index of the first ' ' or '\n' at or after pos; buffer.size() if none.
    int nextDelimiter() {
        const size_t found = buffer.find_first_of(" \n", pos);
        if (found == string::npos) {
            return (int) buffer.size();
        }
        return (int) found;
    }
    // Consumes and returns the characters from pos up to (not including) the
    // next delimiter; the delimiter itself is left for readSpace/readEoln.
    string readOne() {
        assert(pos < (int) buffer.size());
        const int stop = nextDelimiter();
        string token = buffer.substr(pos, stop - pos);
        pos = stop;
        return token;
    }
    // Reads a token whose length lies in [minl, maxl] and, when pattern is
    // non-empty, whose every character occurs in pattern.
    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (size_t i = 0; i < res.size(); ++i) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }
    // Reads a token with stoi and asserts minv <= value <= maxv.
    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        const int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }
    // Reads a token with stoll and asserts minv <= value <= maxv.
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        const long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }
    // Consumes exactly one ' '.
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        ++pos;
    }
    // Consumes exactly one '\n'.
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        ++pos;
    }
    // Asserts that the whole buffer has been consumed.
    void readEof() {
        assert((int) buffer.size() == pos);
    }
};
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    input_checker in;
    // Sorted distinct decimal digits of v, e.g. 1202 -> "012".
    const auto distinct_digits = [](long long v) {
        string digs = to_string(v);
        sort(digs.begin(), digs.end());
        digs.erase(unique(digs.begin(), digs.end()), digs.end());
        return digs;
    };
    int tt = in.readInt(1, 1e5);
    in.readEoln();
    for (; tt != 0; --tt) {
        const long long n = in.readInt(1, 1e9);
        in.readSpace();
        const int k = in.readInt(1, 10);
        in.readEoln();
        if ((int) distinct_digits(n).size() <= k) {
            cout << 0 << '\n';
            continue;
        }
        // Brute force over the answer's shape: keep the top part n / 10^i,
        // raise it by 1..12, and fill the low i digits with the smallest digit
        // present in the raised top part.
        long long best = 1e18;
        long long scale = 1;  // 10^i
        long long ones = 0;   // 11...1 (i ones): multiplier for the fill digit
        for (int i = 0; i < 10; i++) {
            const long long head = n / scale;
            const long long tail = n % scale;
            for (int step = 1; step <= 12; step++) {
                const string digs = distinct_digits(head + step);
                if ((int) digs.size() <= k) {
                    best = min(best, step * scale - tail + (digs[0] - '0') * ones);
                }
            }
            ones += scale;
            scale *= 10;
        }
        cout << best << '\n';
    }
    in.readEof();
    return 0;
}
Editorialist's code (Python)
def _smallest_target(digits, limit):
    """Return the smallest string of len(digits) digits that is >= digits
    (compared as numbers) and has at most `limit` distinct digits.

    Returns `digits` itself when it already has at most `limit` distinct digits.
    """
    used = [0] * 10              # used[c] == 1 iff digit c occurs in digits[:pos]
    best = '9' * len(digits)     # all nines: one distinct digit, never below digits
    for pos, ch in enumerate(digits):
        cur = ord(ch) - ord('0')
        distinct = sum(used)
        bump, fill = cur + 1, 0  # digit placed at pos, digit repeated after pos
        if distinct >= limit:
            # No new digit may be introduced: both come from the prefix.
            bump = next((c for c in range(cur + 1, 10) if used[c]), 10)
            fill = next((c for c in range(10) if used[c]), 10)
        elif distinct == limit - 1 and bump < 10 and not used[bump]:
            # The bumped digit takes the last free slot.
            fill = min([c for c in range(10) if used[c]] + [bump])
        if bump < 10:
            candidate = digits[:pos] + str(bump) + str(fill) * (len(digits) - pos - 1)
            best = min(best, candidate)
        used[cur] = 1
        if sum(used) > limit:
            return best
    # The limit was never exceeded, so the input already qualifies.
    return digits


for _ in range(int(input())):
    s, k = input().split()
    n, k = int(s), int(k)
    print(int(_smallest_target(s, k)) - n)