MNDIGSM2 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3

Author: Srikkanth R and Daanish Mahajan
Tester: Miten Shah
Editorialist: Aman Dwivedi

DIFFICULTY:

Easy-Medium

PREREQUISITES:

Maths, Observation

PROBLEM:

Let f(n, B) be the sum of digits of the integer n when written in base B.

More formally, if n = \sum\limits_{i=0}^{\infty} a_i B^i where a_i is an integer in the range [0, B-1], then f(n, B) = \sum\limits_{i=0}^{\infty} a_i.
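As a concrete reading of this definition, here is a minimal C++ sketch of f(n, B). The name digit_sum is our own illustrative choice and is not taken from the official solutions.

```cpp
#include <cstdint>

// Illustrative helper (name is ours): returns f(n, B), the sum of the
// base-B digits of n. Each iteration peels off the least significant
// digit a_i = n mod B and then shifts n right by one base-B digit.
std::int64_t digit_sum(std::int64_t n, std::int64_t base) {
    std::int64_t sum = 0;
    while (n > 0) {
        sum += n % base;  // current digit a_i
        n /= base;        // drop that digit
    }
    return sum;
}
```

For example, digit_sum(10, 2) is 2 (1010 in binary), digit_sum(10, 4) is 4 (22 in base 4) and digit_sum(10, 10) is 1.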

You are given Q queries, each consisting of two integers n and r. For each query, find a base B with {\bf{2}} \leq B \leq r for which f(n, B) is minimum. If there are multiple such values, you can print any of them.

EXPLANATION:

In the previous version of the problem, we saw that it can be solved in O(\sqrt{N} \cdot C) time. That complexity is too slow for the constraints of this version and would get TLE.

The only difference from the previous version is the constraints. Notice that the lower limit l is now fixed at 2, which hints that the optimal digit sum is bounded by \log_2(N): in base 2 every digit is 0 or 1, and N has only \lfloor \log_2 N \rfloor + 1 binary digits, so f(N, 2), and therefore the minimum digit sum, never exceeds \lfloor \log_2 N \rfloor + 1.
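To make the bound concrete, here is a small C++ sketch (both helper names are our own) that computes f(n, 2) as the number of set bits and the number of binary digits of n. Since B = 2 is always allowed, the minimum digit sum is at most binary_digit_sum(n) \leq bit_length(n), which is 40 for n = 10^{12}.

```cpp
#include <cstdint>

// f(n, 2): every binary digit is 0 or 1, so the digit sum is the
// number of set bits (GCC/Clang also offer __builtin_popcountll).
int binary_digit_sum(std::uint64_t n) {
    int sum = 0;
    for (; n > 0; n >>= 1) sum += static_cast<int>(n & 1);
    return sum;
}

// Number of binary digits of n for n >= 1, i.e. floor(log2 n) + 1.
int bit_length(std::uint64_t n) {
    int len = 0;
    for (; n > 0; n >>= 1) ++len;
    return len;
}
```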

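Since the optimal digit sum is at most about \log_2(N), every digit of N in an optimal base is also at most \log_2(N). Moreover, for every base B > \sqrt[3]{N}, N has at most three digits in base B. So we can fix those three digits a, b, c (the coefficients of B^2, B^1 and B^0) with an O(\log^3 N) loop and check whether aB^2 + bB + c = N has an integer solution B.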

This is done by solving the quadratic equation and checking whether it has an integer root B with 2 \leq B \leq r (and every digit smaller than B). Among all such bases, we keep the one with the minimum digit sum a + b + c.
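A minimal sketch of the per-triple check, assuming a \geq 1 (the a = 0 case is the linear equation bB + c = N, handled separately in the official solutions). The helper names isqrt and base_from_digits are hypothetical. The square root of the discriminant is corrected to an exact integer value, and the candidate root is verified by substituting it back, so a non-square discriminant simply yields -1.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

using i64 = std::int64_t;

// Exact floor(sqrt(x)) for 0 <= x < 2^62: take a floating-point
// estimate, then correct it in +-1 steps.
i64 isqrt(i64 x) {
    i64 s = static_cast<i64>(std::sqrt(static_cast<long double>(x)));
    while (s * s > x) --s;
    while ((s + 1) * (s + 1) <= x) ++s;
    return s;
}

// Hypothetical helper: returns B in [2, r] with a*B^2 + b*B + c == n and
// every digit < B, or -1 if no such B exists. Assumes a >= 1, 0 <= c <= n.
i64 base_from_digits(i64 a, i64 b, i64 c, i64 n, i64 r) {
    i64 disc = b * b + 4 * a * (n - c);  // discriminant of aB^2 + bB + (c - n)
    i64 s = isqrt(disc);                 // s >= b because n - c >= 0
    i64 B = (s - b) / (2 * a);           // positive root, rounded down
    if (B < 2 || B > r) return -1;
    if (std::max({a, b, c}) >= B) return -1;    // digits must lie in [0, B-1]
    if (a * B * B + b * B + c != n) return -1;  // root was not an integer
    return B;
}
```

For instance, the digits (1, 2, 3) with n = 123 give B = 10, while (1, 0, 0) with n = 123 gives -1 because 123 is not a perfect square.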

This handles all bases B greater than \sqrt[3]{N}. For the bases B \leq \sqrt[3]{N}, we simply brute force as in the previous version, computing the digit sum for each one. Finally, we output the base B that gives the minimum digit sum among all candidates.
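Putting the two halves together, here is a condensed sketch of one query. It is our own re-implementation of the idea, not the official code, and min_digit_sum_base is a hypothetical name. Like the official solutions, it re-validates every candidate base by recomputing its digit sum, so an inexact root only adds a harmless extra candidate. The extra pruning a + b + c \leq lim is our own tweak; it is safe because only triples whose digit sum is below the current best can improve the answer.

```cpp
#include <cmath>
#include <cstdint>

using i64 = std::int64_t;

// Sum of the base-B digits of n.
i64 digit_sum(i64 n, i64 B) {
    i64 s = 0;
    for (; n > 0; n /= B) s += n % B;
    return s;
}

// Exact floor(sqrt(x)) for 0 <= x < 2^62.
i64 isqrt_exact(i64 x) {
    i64 s = static_cast<i64>(std::sqrt(static_cast<long double>(x)));
    while (s * s > x) --s;
    while ((s + 1) * (s + 1) <= x) ++s;
    return s;
}

// Returns some B in [2, r] minimising digit_sum(n, B); assumes n, r >= 2.
i64 min_digit_sum_base(i64 n, i64 r) {
    i64 best = digit_sum(n, 2), ans = 2;
    auto consider = [&](i64 B) {
        if (B < 2 || B > r) return;
        i64 s = digit_sum(n, B);
        if (s < best) { best = s; ans = B; }
    };
    // Small bases (B^3 <= n): brute force.
    for (i64 B = 3; B * B * B <= n && B <= r; ++B) consider(B);
    // Large bases (B^3 > n): n = a*B^2 + b*B + c with small digits.
    const i64 lim = best;
    for (i64 a = 0; a <= lim; ++a)
        for (i64 b = 0; a + b <= lim; ++b)
            for (i64 c = 0; a + b + c <= lim && c <= n; ++c) {
                if (a == 0) {
                    if (b > 0) consider((n - c) / b);  // linear: b*B + c = n
                    continue;
                }
                i64 s = isqrt_exact(b * b + 4 * a * (n - c));
                consider((s - b) / (2 * a));           // positive root
            }
    return ans;
}
```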

TIME COMPLEXITY:

O(\sqrt[3]{N} + log^3(N)) per query
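To put the bound in numbers (a back-of-the-envelope estimate of ours, using the limits n, r \leq 10^{12} and Q \leq 300 that the input readers in both solutions enforce):

```latex
N = 10^{12}: \quad \sqrt[3]{N} = 10^{4}, \qquad \lfloor \log_2 N \rfloor + 1 = 40, \qquad 40^{3} = 64\,000
```

So a single query needs roughly 10^4 + 6.4 \cdot 10^4 \approx 7.4 \cdot 10^4 basic steps (ignoring the small cost of each digit-sum computation), and 300 queries stay around 2.2 \cdot 10^7.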

SOLUTIONS:

Author's
#include <bits/stdc++.h>

#define LL long long
using namespace std;

clock_t start = clock();

LL readInt(LL l, LL r, char endd) {
    LL x = 0;
    char ch = getchar();
    bool first = true, neg = false;
    while (true) {
        if (ch == endd) {
            break;
        } else if (ch == '-') {
            assert(first);
            neg = true;
        } else if (ch >= '0' && ch <= '9') {
            x = (x << 1) + (x << 3) + ch - '0';
        } else {
            assert(false);
        }
        first = false;
        ch = getchar();
    }
    if (neg) x = -x;
    if (x < l || x > r) {
        cerr << l << " " << r << " " << x << " failed\n";
    }
    assert(l <= x && x <= r);
    return x;
}
string readString(int l, int r, char endd) {
    string ret = "";
    int cnt = 0;
    while (true) {
        char g = getchar();
        assert (g != -1);
        if (g == endd) {
            break;
        }
        ++cnt;
        ret += g;
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}
LL readIntSp(LL l, LL r) {
    return readInt(l, r, ' ');
}
LL readIntLn(LL l, LL r) {
    return readInt(l, r, '\n');
}
string readStringSp(int l, int r) {
    return readString(l, r, ' ');
}
string readStringLn(int l, int r) {
    return readString(l, r, '\n');
}

void solve() {
    LL n = readIntSp(2, (LL)1e12), r = readIntLn(2, (LL)1e12);
    LL i, ret = -1, have = (LL)1e18;
    auto update_base = [&](LL base) {
        LL m = n, ans = 0;
        while (m > 0) {
            ans += m % base;
            m /= base;
        }
        if (ans < have) {
            have = ans;
            ret = base;
        }
    };
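    // Brute force every base with i^3 <= n (n has 4 or more digits there).
    update_base(2);
    for (i=3;i*i*i<=n&&i<=r;++i) {
        update_base(i);
    }
    // Any larger base that beats the current best has three digits
    // (a, b, C), each bounded by the best digit sum found so far.
    LL lim = have;
    for (LL a = 0; a <= lim; ++a) {
        for (LL b = 0; b <= lim; ++b) {
            for (LL C = 0; C <= lim; ++C) {
                // Solve a*B^2 + b*B + c = 0 with c = C - n.
                LL c = C - n;
                if (a == 0) {
                    // Linear case: b*B + C = n.
                    if (b == 0) {
                        continue;
                    }
                    LL base = -c / b;
                    if (base >= 2 && base <= r) {
                        update_base(base);
                    }
                    continue;
                }
                // Exact integer square root of the discriminant.
                LL disc = b * b - 4 * a * c;
                LL go = sqrtl(disc);
                while (go * go < disc) ++go;
                while (go * go > disc) --go;
                if (go * go == disc) {
                    // Positive root; update_base recomputes the real digit sum.
                    LL base = (go - b) / (2 * a);
                    if (base >= 2 && base <= r) {
                        update_base(base);
                    }
                }
            }
        }
    }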
    assert(ret >= 2 && ret <= r);
    cout << ret << '\n';
}

int main() {
// Start solution here use readIntLn, readIntSp and readStringSp and readStringLn
// for reading input
    int T = readIntLn(1, 300);
    while (T--) {
        solve();
    }
// End solution here
    assert(getchar() == EOF);
    
    cerr << fixed << setprecision(10);
    cerr << "Time taken = " << (clock() - start) / ((double)CLOCKS_PER_SEC) << " s\n";
    return 0;
}
Tester
// created by mtnshh

#include<bits/stdc++.h>
using namespace std;
#define ll long long int
#define pb push_back
#define rb pop_back
#define ti tuple<int, int, int>
#define pii pair<int, int>
#define pli pair<ll, int>
#define pll pair<ll, ll>
#define mp make_pair
#define mt make_tuple
 
#define rep(i,a,b) for(ll i=a;i<b;i++)
#define repb(i,a,b) for(ll i=a;i>=b;i--)
 
#define err() cout<<"--------------------------"<<endl; 
#define errA(A) for(auto i:A)   cout<<i<<" ";cout<<endl;
#define err1(a) cout<<#a<<" "<<a<<endl
#define err2(a,b) cout<<#a<<" "<<a<<" "<<#b<<" "<<b<<endl
#define err3(a,b,c) cout<<#a<<" "<<a<<" "<<#b<<" "<<b<<" "<<#c<<" "<<c<<endl
#define err4(a,b,c,d) cout<<#a<<" "<<a<<" "<<#b<<" "<<b<<" "<<#c<<" "<<c<<" "<<#d<<" "<<d<<endl

#define all(A)  A.begin(),A.end()
#define allr(A)    A.rbegin(),A.rend()
#define ft first
#define sd second

#define V vector<ll>
#define S set<ll>
#define VV vector<V>
#define Vpll vector<pll>
 
#define endl "\n"

long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        // char g = getc(fp);
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
            
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }
            // cerr << x << " " << l << " " << r << endl;
            assert(l<=x && x<=r);
            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        // char g=getc(fp);
        assert(g != -1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}

const ll max_q = 300;
const ll max_r = 1e12;
const ll max_n = 1e12;

const ll N = 200005;
const ll INF = 1e15;

ll solve(ll n, ll b){
    ll sum = 0;
    while(n){
        sum += n % b;
        n /= b;
    }
    return sum;
}

void solve(){
    ll n = readIntSp(2, max_n), r = readIntLn(2, max_r);
    ll mn = INF, ans = -1;
    if(r >= n){
        ans = n;
        mn = 1;
    }
    ll x = solve(n, 2);
    if(x < mn){
        mn = x;
        ans = 2;
    }
    for(ll i=2; i*i*i<=n; i++){
        if(i>r) break;
        ll x = solve(n, i);
        if(x < mn){
            mn = x;
            ans = i;
        }
    }
    ll y = mn;
    rep(a,0,y)  rep(b,0,y)   rep(c,0,y){
        if(a == 0){
            if(b == 0){
                continue;
            }
            ll B = (n - c) / b;
            if(B < 2 or B > r)   continue;
            ll x = solve(n, B);
            if(x < mn){
                mn = x;
                ans = B;
            }
            continue;
        }
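        // Positive root of a*B^2 + b*B + (c - n) = 0, with the square root rounded up.
        // If det is not a perfect square, B is just an extra candidate:
        // solve(n, B) below recomputes the real digit sum, so this is harmless.
        ll det = b*b - 4*a*(c-n);
        ll det_sqrt = sqrt(det);
        if(det_sqrt * det_sqrt < det){
            det_sqrt += 1;
        }
        ll B = (det_sqrt - b) / (2 * a);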
        if(B < 2 or B > r)   continue;
        ll x = solve(n, B);
        if(x < mn){
            mn = x;
            ans = B;
        }
    }
    cout << ans << endl;
}

int main(){
    ios_base::sync_with_stdio(0);cin.tie(0);cout.tie(0);
    ll q = readIntLn(1, max_q);
    while(q--){
        solve();
    }
    assert(getchar()==-1);
}

Why don’t we check for exact solutions of the quadratic equation? The discriminant might not be a perfect square, and in that case shouldn't there be no valid base for that triplet of a, b and c?

With a little further observation, I found a solution that is approximately O(\sqrt{N}) in principle, but usually much faster because many of the tests can be skipped once they cannot lead to a better solution than the best one found so far. See my CodeChef submission, which passed in 0.49 seconds. My code for MinDigSum is essentially the same, but with ‘int’ instead of ‘long’.

The function DigitsSumMinimum stops summing the digits as soon as the partial sum exceeds the minimum found in a previous call.
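The C# body of DigitsSumMinimum is not shown here. A minimal C++ sketch of the same early-exit idea might look like the following; the name digit_sum_below and its signature are hypothetical. It stops as soon as the partial sum reaches the current best (rather than exceeds it), since an equal sum cannot be a strict improvement.

```cpp
#include <cstdint>

using i64 = std::int64_t;

// Hypothetical early-exit digit sum: returns true and lowers best_sum
// when the base-B digit sum of n is below best_sum. Digits are
// non-negative, so once the partial sum reaches best_sum we can stop.
bool digit_sum_below(i64 n, i64 B, i64& best_sum) {
    i64 sum = 0;
    while (n > 0) {
        sum += n % B;
        if (sum >= best_sum) return false;  // cannot beat the current best
        n /= B;
    }
    best_sum = sum;
    return true;
}
```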

Before starting, note that if Right >= N, the answer is B = N, with digit sum 1. No further work is required in this case.

At the start, check the right limit itself, in case it happens to be the best solution.

Then work back through the bases in which N has 2 digits, stopping as soon as we reach a point where no better solution is possible. The code (in C#) is:

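// right and 2-digit numbers
long sum_min = long.MaxValue;
long k = n / right;   // leading digit of n in base right
long b = right;
do {
    if (DigitsSumMinimum(n, b, ref sum_min))
        base_best = b;
    b = n / (++k);    // largest base whose leading digit is at least k
} while (b > left && b > k && k < sum_min);  // the leading digit k alone bounds the sum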
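For readers following the C++ solutions above, a direct C++ translation of this loop may help. The function name scan_two_digit_bases and the plain digit_sum helper (used here instead of the early-exit DigitsSumMinimum) are our own assumptions. The key observation: floor(n / k) is the largest base whose leading digit is at least k, and when the leading digit is exactly k, that base also minimises the trailing digit n - k*b.

```cpp
#include <cstdint>

using i64 = std::int64_t;

// Sum of the base-B digits of n.
i64 digit_sum(i64 n, i64 B) {
    i64 s = 0;
    for (; n > 0; n /= B) s += n % B;
    return s;
}

// Hypothetical translation of the C# loop: walk down from `right`,
// trying one base per leading digit k, and stop once the base falls to
// `left`, n stops being 2-digit (b <= k), or k alone reaches best_sum.
void scan_two_digit_bases(i64 n, i64 left, i64 right,
                          i64& best_sum, i64& best_base) {
    i64 k = n / right;  // leading digit of n in base `right`
    i64 b = right;
    do {
        i64 s = digit_sum(n, b);
        if (s < best_sum) { best_sum = s; best_base = b; }
        b = n / (++k);  // largest base with leading digit >= k
    } while (b > left && b > k && k < best_sum);
}
```

For example, with n = 100 and right = 60 the scan tries 60 (sum 41), then 50 (sum 2), and stops because the next leading digit, 3, already exceeds the best sum.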

We then check the left limit, again in case it happens to give a better solution.

Next, work from left + 1 up to the 5th root of N, which covers all bases in which N has 6 or more digits. There cannot be many of these, because the 5th root of N is less than 300 even for N = 1e12.
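To check the "less than 300" claim, here is a small C++ sketch (floor_kth_root is an illustrative name of ours) that finds the largest B with B^k \leq n by an upward search. N has at least 6 digits in base B exactly when B^5 \leq N, and for N = 10^{12} this gives B \leq 251.

```cpp
#include <cstdint>

using i64 = std::int64_t;

// Largest B >= 1 with B^k <= n, found by counting upward.
// Only meant for small results (e.g. k = 5 with n <= 1e12); the product
// is abandoned as soon as it exceeds n, so it stays far from overflow.
i64 floor_kth_root(i64 n, int k) {
    i64 B = 1;
    while (true) {
        i64 p = 1;
        bool exceeds = false;
        for (int i = 0; i < k; ++i) {
            p *= (B + 1);
            if (p > n) { exceeds = true; break; }
        }
        if (exceeds) return B;  // (B + 1)^k > n, so B is the answer
        ++B;
    }
}
```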

Then handle the bases giving 5-, 4- and 3-digit representations in a similar way, stopping each search as soon as it cannot beat the best solution so far.
