 # MNDIGSM2 - Editorial

Author: Srikkanth R and Daanish Mahajan
Tester: Miten Shah
Editorialist: Aman Dwivedi

# DIFFICULTY:

Easy-Medium

# PREREQUISITES:

Maths, Observation

# PROBLEM:

Let f(n, B) be the sum of digits of the integer n when written in base B.

More formally, if $n = \sum\limits_{i=0}^{\infty} a_i B^i$ where each $a_i$ is an integer in the range $[0, B-1]$, then $f(n, B) = \sum\limits_{i=0}^{\infty} a_i$.
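Below is a minimal C++ sketch of $f(n, B)$. It follows the definition directly by repeatedly taking $n \bmod B$ as the next digit $a_i$. The name `digitSum` is hypothetical; the reference solutions use an equivalent loop.

```cpp
#include <cassert>

// f(n, B): sum of the base-B digits of n, for n >= 0 and B >= 2.
long long digitSum(long long n, long long B) {
    assert(n >= 0 && B >= 2);
    long long sum = 0;
    while (n > 0) {
        sum += n % B;  // least significant remaining digit a_i
        n /= B;        // drop that digit and move to the next position
    }
    return sum;
}
```

For example, 10 is 1010 in base 2 and 101 in base 3, so both give a digit sum of 2. Base 10 writes it as 10, with digit sum 1.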

You are given Q queries, each consisting of two integers n and r. For each query, find a value of B with $\mathbf{2} \leq B \leq r$ for which $f(n, B)$ is minimum. If there are multiple such values, you can print any of them.

# EXPLANATION:

In the previous version of the problem, we saw a solution that runs in $O(\sqrt{N} \cdot C)$ time. That is too slow for the constraints of this version and would get TLE.

The only difference between the two versions is the constraints. Notice that the lower limit $l$ of the base is fixed to 2 in this version. This hints that the minimum digit sum is bounded by $\log(N)$. In base 2, every digit is 0 or 1 and n has only $\lfloor \log_2 n \rfloor + 1$ bits, so $f(n, 2) \le \lfloor \log_2 n \rfloor + 1$. Since base 2 is always allowed, the optimal digit sum can never exceed this value.
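Here is a small C++ check of this bound, using the hypothetical helpers `bitLength` and `binaryDigitSum`. $f(n, 2)$ is just the number of 1 bits, which can never exceed the bit length. For $n \le 10^{12}$ (the tester's `max_n`), the bit length is at most 40, so the optimal digit sum is at most 40.

```cpp
#include <cassert>

// Number of binary digits of n, i.e. floor(log2 n) + 1, for n >= 1.
int bitLength(long long n) {
    assert(n >= 1);
    int bits = 0;
    for (; n > 0; n >>= 1) ++bits;
    return bits;
}

// f(n, 2): the number of 1 bits of n. Every bit is 0 or 1, so this is
// at most bitLength(n), which bounds the optimal digit sum from above.
int binaryDigitSum(long long n) {
    assert(n >= 0);
    int ones = 0;
    for (; n > 0; n >>= 1) ones += static_cast<int>(n & 1);
    return ones;
}
```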

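Since the minimum digit sum is at most about $\log(N)$, consider any base $B$ with $B^3 > N$, i.e. $B > N^{1/3}$. In such a base, n has at most three digits, so $n = aB^2 + bB + c$ with $0 \le a, b, c < B$. A base like this can only beat the current best if each of a, b and c is at most $\log(N)$. So we can fix the three digits with an $O(\log^3 N)$ loop and check whether $aB^2 + bB + c = n$ has a solution $B$.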
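To make the three-digit claim concrete, here is a hypothetical C++ helper, `threeDigits`, that extracts $(a, b, c)$ for a base with $B^3 > n$. It checks `n / B / B < B`, which is equivalent to $B^3 > n$, instead of computing $B^3$, because that product could overflow for large bases.

```cpp
#include <array>
#include <cassert>

// For a base B with B^3 > n, n has at most three base-B digits:
// n = a*B^2 + b*B + c with 0 <= a, b, c < B. Returns {a, b, c}.
std::array<long long, 3> threeDigits(long long n, long long B) {
    long long a = n / B / B;   // floor(n / B^2), computed without B*B
    assert(B >= 2 && a < B);   // a < B holds exactly when B^3 > n
    long long b = (n / B) % B;
    long long c = n % B;
    return {a, b, c};
}
```

For instance, with n = 1000 and B = 11, it returns (8, 2, 10), since 8·121 + 2·11 + 10 = 1000.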

For fixed a, b and c, this means solving the quadratic equation $aB^2 + bB + (c - n) = 0$ (a linear equation when $a = 0$) and checking whether it has an integer root $B$ with $2 \le B \le r$. A root found this way does not guarantee that a, b and c are valid base-$B$ digits. For this reason, the solutions recompute $f(n, B)$ directly for every candidate $B$ and keep the minimum.
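Below is one way to perform this check exactly, sketched in C++. It assumes all values fit in `long long`, which holds here because the digits are small and $n \le 10^{12}$. The names `isqrt` and `integerBase` are hypothetical. The reference solutions evaluate whatever B the formula produces; this version differs in that it only returns a root when that root is an exact integer.

```cpp
#include <cassert>
#include <cmath>

// Largest s with s * s <= x (x >= 0), for the moderate x that arise here.
long long isqrt(long long x) {
    long long s = static_cast<long long>(std::sqrt(static_cast<long double>(x)));
    while (s * s > x) --s;
    while ((s + 1) * (s + 1) <= x) ++s;
    return s;
}

// Returns the positive integer B with a*B^2 + b*B + c = n, or -1 if none.
// a == 0 is the linear case b*B + c = n.
long long integerBase(long long a, long long b, long long c, long long n) {
    long long rhs = n - c;  // a*B^2 + b*B = rhs
    if (rhs <= 0) return -1;
    if (a == 0) {
        if (b == 0 || rhs % b != 0) return -1;
        return rhs / b;
    }
    long long disc = b * b + 4 * a * rhs;  // discriminant, positive here
    long long s = isqrt(disc);
    if (s * s != disc) return -1;          // root is irrational
    long long num = s - b;                 // positive root is (s - b) / (2a)
    if (num % (2 * a) != 0) return -1;     // root is not an integer
    return num / (2 * a);
}
```

Note that `integerBase(2, 1, 0, 10)` returns 2, even though 2 is not a valid base-2 digit. The root satisfies the equation, but 10 is really 1010 in base 2, with digit sum 2 rather than 2 + 1 + 0 = 3. This is why $f(n, B)$ is recomputed for every candidate.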

This covers all bases $B > N^{1/3}$. The remaining bases $B \le N^{1/3}$ (with $B \le r$) can be brute-forced as in the previous version, by computing $f(n, B)$ for each of them. The reference solutions run this brute force first and use the best digit sum it finds as the upper limit for a, b and c. Finally, output the base $B$ that gives the minimum digit sum among all candidates.
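The whole per-query procedure can be sketched in C++ as follows. It keeps the two phases of the reference solutions: brute force for $B^3 \le n$, then digit enumeration. It differs from them in two ways: it only tries triples with $a + b + c$ below the phase-1 best, and it accepts only exact integer roots. `bestBase` and its helpers are hypothetical names. The sketch assumes $n \ge 1$ and $r \ge 2$.

```cpp
#include <cassert>
#include <cmath>

// Digit sum of n in base B (n >= 0, B >= 2).
long long digitSum(long long n, long long B) {
    long long sum = 0;
    for (; n > 0; n /= B) sum += n % B;
    return sum;
}

// Largest s with s * s <= x, for the moderate x that arise here.
long long isqrt(long long x) {
    long long s = static_cast<long long>(std::sqrt(static_cast<long double>(x)));
    while (s * s > x) --s;
    while ((s + 1) * (s + 1) <= x) ++s;
    return s;
}

// Returns a base B in [2, r] that minimises digitSum(n, B).
long long bestBase(long long n, long long r) {
    assert(n >= 1 && r >= 2);
    long long best = 2, bestSum = digitSum(n, 2);
    // Evaluates base B (if allowed) and keeps it when strictly better.
    auto consider = [&](long long B) {
        if (B < 2 || B > r) return;
        long long s = digitSum(n, B);
        if (s < bestSum) {
            bestSum = s;
            best = B;
        }
    };
    // Phase 1: brute force over the small bases (B^3 <= n).
    for (long long B = 3; B <= r && B * B * B <= n; ++B) consider(B);
    // Phase 2: every other base has B^3 > n, so n = a*B^2 + b*B + c.
    // To beat phase 1 the digits need a + b + c < lim.
    const long long lim = bestSum;
    for (long long a = 0; a < lim; ++a) {
        for (long long b = 0; a + b < lim; ++b) {
            for (long long c = 0; a + b + c < lim; ++c) {
                long long rhs = n - c;  // a*B^2 + b*B = rhs
                if (rhs <= 0) continue;
                if (a == 0) {
                    // Linear case b*B = rhs.
                    if (b > 0 && rhs % b == 0) consider(rhs / b);
                    continue;
                }
                // Quadratic case: the root (s - b) / (2a) is an integer only
                // if s^2 = b^2 + 4a*rhs is a perfect square and 2a divides s - b.
                long long disc = b * b + 4 * a * rhs;
                long long s = isqrt(disc);
                if (s * s == disc && (s - b) % (2 * a) == 0) consider((s - b) / (2 * a));
            }
        }
    }
    return best;
}
```

Restricting the triples to $a + b + c < lim$ is safe. Any base that improves on phase 1 has a digit sum below `lim`, so its digits appear in the enumeration and its exact root is found.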

# TIME COMPLEXITY:

$O(N^{1/3} \log N + \log^3 N)$ per query
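Here is a rough worst-case operation count for this bound. It takes $N = 10^{12}$ and $Q = 300$ from the constants in the tester's code, and assumes the digit loop runs each of a, b, c from 0 to 40:

$$
\begin{aligned}
N^{1/3} \cdot \log_2 N &\approx 10^{4} \cdot 40 = 4 \cdot 10^{5} \\
41^{3} &= 68\,921 \\
300 \cdot \left(4 \cdot 10^{5} + 68\,921\right) &\approx 1.4 \cdot 10^{8}
\end{aligned}
$$

This is only an upper estimate. Larger bases have far fewer than 40 digits, and most iterations of the digit loop never reach the digit-sum computation.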

# SOLUTIONS:

Author's
```cpp
#include <bits/stdc++.h>

#define LL long long
using namespace std;

// Returns a base in [2, r] that minimises the digit sum of n.
LL solve(LL n, LL r) {
    LL ret = -1, have = (LL)1e18;  // best base so far and its digit sum
    // Computes f(n, base) and records it if it beats the best so far.
    auto update_base = [&](LL base) {
        LL m = n, ans = 0;
        while (m > 0) {
            ans += m % base;
            m /= base;
        }
        if (ans < have) {
            have = ans;
            ret = base;
        }
    };
    // Brute force over small bases: 2 and every base i with i^3 <= n.
    update_base(2);
    for (LL i = 3; i * i * i <= n && i <= r; ++i) {
        update_base(i);
    }
    // Every remaining base B has B^3 > n, so n = a*B^2 + b*B + C.
    // No digit of a better representation can exceed lim.
    LL lim = have;
    for (LL a = 0; a <= lim; ++a) {
        for (LL b = 0; b <= lim; ++b) {
            for (LL C = 0; C <= lim; ++C) {
                LL c = C - n;  // constant term of a*B^2 + b*B + c = 0
                if (a == 0) {
                    // Linear case: b*B = n - C.
                    if (b == 0) {
                        continue;
                    }
                    LL base = -c / b;
                    if (base >= 2 && base <= r) {
                        update_base(base);
                    }
                    continue;
                }
                // Quadratic case: only a perfect-square discriminant gives
                // an integer root.
                LL disc = b * b - 4 * a * c;
                LL go = sqrtl(disc);
                while (go * go < disc) ++go;
                while (go * go > disc) --go;
                if (go * go == disc) {
                    LL base = (go - b) / (2 * a);
                    // update_base recomputes the true digit sum, so a root
                    // whose digits are not all < base is harmless.
                    if (base >= 2 && base <= r) {
                        update_base(base);
                    }
                }
            }
        }
    }
    assert(ret >= 2 && ret <= r);
    return ret;
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    int T;
    cin >> T;
    while (T--) {
        LL n, r;
        cin >> n >> r;
        cout << solve(n, r) << '\n';
    }
    return 0;
}
```

Tester's
```cpp
// created by mtnshh

#include<bits/stdc++.h>
using namespace std;
#define ll long long int
#define pb push_back
#define rb pop_back
#define ti tuple<int, int, int>
#define pii pair<int, int>
#define pli pair<ll, int>
#define pll pair<ll, ll>
#define mp make_pair
#define mt make_tuple

#define rep(i,a,b) for(ll i=a;i<b;i++)
#define repb(i,a,b) for(ll i=a;i>=b;i--)

#define err() cout<<"--------------------------"<<endl;
#define errA(A) for(auto i:A)   cout<<i<<" ";cout<<endl;
#define err1(a) cout<<#a<<" "<<a<<endl
#define err2(a,b) cout<<#a<<" "<<a<<" "<<#b<<" "<<b<<endl
#define err3(a,b,c) cout<<#a<<" "<<a<<" "<<#b<<" "<<b<<" "<<#c<<" "<<c<<endl
#define err4(a,b,c,d) cout<<#a<<" "<<a<<" "<<#b<<" "<<b<<" "<<#c<<" "<<c<<" "<<#d<<" "<<d<<endl

#define all(A)  A.begin(),A.end()
#define allr(A)    A.rbegin(),A.rend()
#define ft first
#define sd second

#define V vector<ll>
#define S set<ll>
#define VV vector<V>
#define Vpll vector<pll>

#define endl "\n"

const ll max_q = 300;
const ll max_r = 1e12;
const ll max_n = 1e12;

const ll N = 200005;
const ll INF = 1e15;

// Digit sum of n in base b.
ll solve(ll n, ll b){
    ll sum = 0;
    while(n){
        sum += n % b;
        n /= b;
    }
    return sum;
}

void solve(){
    ll n, r;
    cin >> n >> r;
    assert(n <= max_n && r <= max_r);
    ll mn = INF, ans = -1;
    // In base n, n is written as "10", so its digit sum is 1.
    if(r >= n && n >= 2){
        ans = n;
        mn = 1;
    }
    // Brute force base 2 and every base i with i^3 <= n.
    ll x = solve(n, 2);
    if(x < mn){
        mn = x;
        ans = 2;
    }
    for(ll i=2; i*i*i<=n; i++){
        if(i>r) break;
        ll x = solve(n, i);
        if(x < mn){
            mn = x;
            ans = i;
        }
    }
    // Larger bases give n = a*B^2 + b*B + c; only digits below the best
    // sum found so far (y) can lead to an improvement.
    ll y = mn;
    rep(a,0,y)  rep(b,0,y)   rep(c,0,y){
        if(a == 0){
            // Linear case: b*B + c = n.
            if(b == 0){
                continue;
            }
            ll B = (n - c) / b;
            if(B < 2 or B > r)   continue;
            ll x = solve(n, B);
            if(x < mn){
                mn = x;
                ans = B;
            }
            continue;
        }
        // Quadratic case. The square root is rounded up instead of requiring
        // a perfect square: an exact root is still found, and any other B is
        // merely evaluated with its true digit sum, which is harmless.
        ll det = b*b - 4*a*(c-n);
        ll det_sqrt = sqrt(det);
        if(det_sqrt * det_sqrt < det){
            det_sqrt += 1;
        }
        ll B = (det_sqrt - b) / (2 * a);
        if(B < 2 or B > r)   continue;
        ll x = solve(n, B);
        if(x < mn){
            mn = x;
            ans = B;
        }
    }
    cout << ans << endl;
}

int main(){
    ios_base::sync_with_stdio(0);cin.tie(0);cout.tie(0);
    ll q;
    cin >> q;
    assert(q <= max_q);
    while(q--){
        solve();
    }
}
```


Why don’t we check for exact solutions of the quadratic equation? The discriminant might not be a perfect square. In that case, shouldn’t there be no valid B for this triplet of a, b and c?

With a little further observation, I was able to find a solution which is in principle approximately O(sqrt N), but usually much faster. Many of the tests can be skipped because they cannot lead to a better solution than the one already found. See Solution: 50722837 | CodeChef, which passed in 0.49 seconds. My code for MinDigSum is essentially the same, but with ‘int’ instead of ‘long’.

The function DigitsSumMinimum stops calculating the sum of the digits as soon as it exceeds the minimum found in a previous call.
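The C# body of DigitsSumMinimum is not shown in this comment. Below is a hypothetical C++ sketch of the early-exit idea. `digitSumBelow` is an invented name, and the `bool` return mirrors how DigitsSumMinimum is called in the C# loop of this comment. It stops as soon as the partial sum reaches the current best, because from then on the base can no longer be strictly better.

```cpp
#include <cassert>

// Hypothetical early-exit digit sum: returns true and updates bestSoFar only
// if base B gives a strictly smaller digit sum for n. The loop gives up as
// soon as the partial sum reaches bestSoFar.
bool digitSumBelow(long long n, long long B, long long& bestSoFar) {
    assert(B >= 2);
    long long sum = 0;
    while (n > 0) {
        sum += n % B;
        if (sum >= bestSoFar) return false;  // cannot beat the best any more
        n /= B;
    }
    bestSoFar = sum;
    return true;
}
```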

Before starting, note that if right >= N, the minimum digit sum is 1, reached with B = N. No further work is required in this case.

At the start check the right limit, in case it is an unusual best solution.

Then work back through the 2-digit solutions, stopping as soon as we reach a point where no better solution is possible. The code (in C#) is:

```csharp
// right and 2-digit numbers
long sum_min = long.MaxValue;
long k = n / right;
long b = right;
do {
    if (DigitsSumMinimum(n, b, ref sum_min))
        base_best = b;
    b = n / (++k);
} while (b > left && b > k && k < sum_min);
```


We then calculate with the left limit, again in case it is an unusual, better solution.

Next, work from left + 1 up to the 5th root of N, which covers every base in which N has 6 or more digits. There cannot be many of these, because the 5th root of N is less than 300 even for N = 1e12.

Then work through the bases giving 5-, 4- and 3-digit representations in a similar way, stopping the search as soon as we cannot do better than the best solution so far.
