SQRTCBRT - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: kulyash
Testers: tabr, iceknight1093
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Binary search

PROBLEM:

F(N) denotes the difference between the number of squares and the number of cubes from 1 to N.

Given X, find the smallest value of N such that F(N) \geq X.

EXPLANATION:

Problems of the type “find the smallest N such that \ldots” are often solved with binary search.
Unfortunately, that requires the function to be monotonic, and F(N) here isn’t monotonic.

Analyzing how it changes, we can see that:

F(N) = \begin{cases} F(N-1), & \text{ if } N \text{ is neither a square nor a cube} \\ F(N-1) + 1, & \text{ if } N \text{ is a square but not a cube} \\ F(N-1) - 1, & \text{ if } N \text{ is a cube but not a square} \\ F(N-1), & \text{ if } N \text{ is both a square and a cube} \end{cases}

In particular, notice that F can only increase at a square number, so the final answer must be a square number.
Further, the fact that cubes are more ‘spread out’ than squares means that F is actually a monotonic function when evaluated on the squares!
That is, we consider the function g(N) = F(N^2) instead, which is non-decreasing (it can stay the same between consecutive squares, but never drops).
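The claims above can be sanity-checked by brute force. The following is a minimal sketch, not part of any reference solution: the helper name brute_force_F and the limit of 300^2 are assumptions chosen only for illustration.

```python
from math import isqrt

def brute_force_F(limit):
    """Return a list F where F[n] = (#squares in 1..n) - (#cubes in 1..n)."""
    squares = {i * i for i in range(1, isqrt(limit) + 1)}
    cubes = set()
    i = 1
    while i ** 3 <= limit:
        cubes.add(i ** 3)
        i += 1
    F = [0] * (limit + 1)
    for n in range(1, limit + 1):
        # Mirrors the case analysis: +1 at a square, -1 at a cube,
        # and no net change when n is both or neither.
        F[n] = F[n - 1] + (n in squares) - (n in cubes)
    return F

F = brute_force_F(300 * 300)
# g[k] holds F((k+1)^2), i.e. F sampled only at the squares.
g = [F[n * n] for n in range(1, 301)]
```

In this table F[7] is 1 while F[8] is 0, so F itself does drop at a cube, but the sampled list g never decreases over the tested range.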

Proof

It’s enough to prove that there’s at most one cube lying between any two consecutive squares, since this is what makes F non-decreasing on the squares.

So, suppose x \geq 1 and y^3 is a cube with x^2 \leq y^3. We just need to prove (x+1)^2 \lt (y+1)^3, that is, the next cube always lies beyond the next square.

This is not hard to do by simple algebraic manipulation.
Expand out (x+1)^2 and (y+1)^3, to obtain (x^2 + 2x + 1) and (y^3 + 3y^2 + 3y + 1).

Comparing their terms, we can see that:

  • y^3 \geq x^2
  • They both have a 1
  • 3y^2 \geq 3x \gt 2x. This follows from the fact that x^2 \leq y^3 means that y^2 \geq x^{4/3} \geq x.

So we have the required inequality.
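As a numerical cross-check of the inequality (not a replacement for the proof), here is a small sketch. The function name and the tested range are assumptions made for illustration; for each y it only tests the largest x with x^2 \leq y^3, since that is the tightest case.

```python
from math import isqrt

def next_cube_clears_next_square(max_y):
    """Check (x+1)^2 < (y+1)^3 for every y in 1..max_y, with x = floor(sqrt(y^3))."""
    for y in range(1, max_y + 1):
        x = isqrt(y ** 3)  # largest x satisfying x^2 <= y^3
        if not (x + 1) ** 2 < (y + 1) ** 3:
            return False
    return True
```

Calling next_cube_clears_next_square(2000) should report that no counterexample exists in that range, in line with the algebra above.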

With this information in hand, the solution is simple: use binary search to find the smallest value of N such that g(N) \geq X; the final answer is then N^2.

This requires us to quickly evaluate g(N) = F(N^2).

Note that the number of squares that are \leq N^2 is simply N, so we just need to count the number of cubes that are \leq N^2.
There are several ways to do this: for example, use std::cbrtl or yet another binary search.
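One possible shape of the "yet another binary search" option is sketched below, using integers only. The names count_cubes, g and smallest_answer are illustrative, and the upper bound 2 * 10**9 assumes X \leq 10^9 as in the tester's code.

```python
def count_cubes(m):
    """Number of positive integers c with c**3 <= m, computed with integers only."""
    lo, hi = 0, 1
    while hi ** 3 <= m:  # grow hi until hi**3 > m
        hi *= 2
    # Invariant: lo**3 <= m < hi**3
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if mid ** 3 <= m:
            lo = mid
        else:
            hi = mid
    return lo

def g(n):
    """F(n^2): n squares minus the number of cubes that are <= n^2."""
    return n - count_cubes(n * n)

def smallest_answer(x):
    """Smallest N with F(N) >= x, found by binary search over the square root."""
    lo, hi = 1, 2 * 10**9  # assumed bound, valid for x <= 10^9
    while lo < hi:
        mid = (lo + hi) // 2
        if g(mid) >= x:
            hi = mid
        else:
            lo = mid + 1
    return lo * lo
```

For instance, smallest_answer(3) gives 25: up to 25 there are five squares (1, 4, 9, 16, 25) and two cubes (1, 8).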

Regarding precision

You may have encountered precision errors when computing cube roots; luckily, one of the samples was specifically designed to catch several common implementations.

In particular, cbrt in C++ or something like N ** (1 / 3) in Python have precision issues.

There are several ways to get around this, some being problem dependent:

  • If possible, the absolute best solution is to simply not deal with floating-point numbers at all. In this problem, the number of cubes \leq N can be computed using a binary search across the integers, so this is the safest way.
  • cbrtl in C++ is precise enough to get AC in this task.
  • The cube root can be manually adjusted once you compute it: if c is the value you compute as the cube root of N (after rounding) and (c+1)^3 \leq N, then increase c to c+1 instead; similarly, if c^3 \gt N, decrease c to c-1.
  • After computing the cube root, add a small value to it (like 10^{-10}) before rounding to an integer.
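The manual-adjustment option can be sketched as follows. This is a minimal illustration rather than the tester's exact code: the name icbrt is an assumption, and the floating-point value is treated only as a first guess that integer arithmetic then corrects. On typical IEEE-754 doubles, an expression like 64 ** (1 / 3) comes out slightly below 4, so truncating it directly would give 3.

```python
def icbrt(n):
    """Largest integer c with c**3 <= n, for a non-negative integer n."""
    c = round(n ** (1.0 / 3.0))  # floating-point first guess, may be off by a little
    # Fix the guess using exact integer arithmetic, in both directions.
    while c > 0 and c ** 3 > n:
        c -= 1
    while (c + 1) ** 3 <= n:
        c += 1
    return c
```

Because the final answer is decided by the integer comparisons, the result does not depend on how the platform rounds the initial estimate.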

TIME COMPLEXITY

\mathcal{O}(\log^2 X) per testcase.

CODE:

Setter's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define ll long long
int main() {
    vector<ll>cubes;
    for(ll i=1;i<=1010000;i++)cubes.push_back(i*i*i);
    ll T;
    cin >> T;
    while(T--){
        ll x;
        cin >> x;
        ll l=1;
        ll r=2e9;
        ll ans;
        while(l<=r){
            ll mid=(r+l)/2;
            ll temp=upper_bound(cubes.begin(),cubes.end(),mid*mid)-cubes.begin();
            ll curr=mid-temp;
            if(curr>=x){
                ans=mid*mid;
                r=mid-1;
            }
            else{
                l=mid+1;
            }
        }
        cout << ans << endl;
    }
return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        // cerr << res << endl;
        return res;
    }

    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    input_checker in;
    int tt = in.readInt(1, 1e5);
    in.readEoln();
    while (tt--) {
        int x = in.readInt(1, 1e9);
        in.readEoln();
        long long low = 0, high = 2e9;
        while (high - low > 1) {
            long long mid = (high + low) >> 1;
            long long cnt = mid;
            long long t = cbrt(mid * mid);
            while ((t + 1) * (t + 1) * (t + 1) <= mid * mid) {
                t++;
            }
            while (t * t * t > mid * mid) {
                t--;
            }
            cnt -= t;
            if (cnt >= x) {
                high = mid;
            } else {
                low = mid;
            }
        }
        cout << high * high << '\n';
    }
    in.readEof();
    return 0;
}
Editorialist's code (Python)
from bisect import bisect_left
cubes = [i**3 for i in range(1, 1010000)]

for _ in range(int(input())):
	x = int(input())
	lo, hi = 1, 2 * 10**9
	def g(n): # f(n*n)
		return n - bisect_left(cubes, n*n + 1)
	while lo < hi:
		mid = (lo + hi)//2
		if g(mid) >= x: hi = mid
		else: lo = mid+1
	print(lo * lo)
#include<bits/stdc++.h>
using namespace std;
#define debug(x) ""
#define ll long long
#define int long long
#define ld long double
#define INF 998244353
#define MOD 1000000007
#define Case(test)cout<<"Case #"<<test<<": "; 
#define y second
#define x first
#define mem(arr,val) memset(arr,val,sizeof(arr));
#define all(x) x.begin(),x.end()
ll max(ll a,ll b){return ((a>=b)?a:b);}
int maxself(int &var1,int var2){var1=max(var1,var2);return var1;}
ll min(ll a,ll b){return ((a<=b)?a:b);}
int minself(int &var1,int var2){var1=min(var1,var2);return var1;}




int f(int n)
{
   int res=sqrt(n);
   int a=(cbrt(n));
   return res-a;
}


ll solve(int tt)
{
   int n,res=0;
   cin>>n;
   int l=1,r=sqrt(1ll*2e18);
   while(l<=r)
   {
      int mid=(l+r)>>1;
      int m=1ll*mid*mid;
      if(f(m)>=n)
      {
         res=m;
         r=(mid)-1;
      }
      else
      {
         l=mid+1;
      }
   }
   cout<<res<<endl;   
   return 0;
}





signed main()
{
   ios_base::sync_with_stdio(0);
   cin.tie(0);
   cout.tie(0);
   ll t=1;
   cin>>t;
   for(int i=1;i<=t;i++)
   {
      solve(i);
   }

   return 0;
}

Why does my code give different results for x=3151 on my local IDE than on the CodeChef IDE?

It is because your function f(m) simply takes the floating-point sqrt and cbrt of m, truncates them, and subtracts.
Instead, after computing a you should also check whether increasing or decreasing it by 1 still satisfies the definition, i.e. that a^3 \leq m \lt (a+1)^3.
For example, if cbrt returns 2.9999999 for an input whose true cube root is 3, converting it to an integer gives 2, which is off by one; with the check mentioned above, the code will run perfectly fine.

In the Setter’s solution, how are these upper limits (1010000 and 2e9) decided?

They aren’t specific bounds, just large enough numbers.

For a given N, F(N^2) is quite close to N, since the number of cubes grows much slower than the number of squares.
So it’s reasonable to expect that R = 2\cdot 10^9 will definitely be enough for F(N^2) to exceed 10^9, which is the limit we care about.
You could reasonably take any number that’s somewhat larger than 10^9 as well.

1010000 is chosen similarly: it’s a bit larger than 10^6, which is approximately how many cubes you care about.
You could also take 2\cdot 10^6 and it’d be correct, anything works as long as it’s large enough.
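These bounds can be checked with a short sketch. The function names are illustrative, and the limits 10^9, 2 \cdot 10^9 and 1010000 are taken from the discussion above.

```python
def icbrt(n):
    """Largest integer c with c**3 <= n (float guess corrected with integers)."""
    c = round(n ** (1.0 / 3.0))
    while c > 0 and c ** 3 > n:
        c -= 1
    while (c + 1) ** 3 <= n:
        c += 1
    return c

def check_bounds(x_max=10**9, search_hi=2 * 10**9):
    """For the largest X, return (N, number of cubes <= N^2) at the answer N^2."""
    lo, hi = 1, search_hi
    while lo < hi:
        mid = (lo + hi) // 2
        if mid - icbrt(mid * mid) >= x_max:
            hi = mid
        else:
            lo = mid + 1
    return lo, icbrt(lo * lo)

n, cubes_needed = check_bounds()
```

Here n lands only slightly above 10^9 and cubes_needed stays below 1010000, so both limits leave room. For search values whose square exceeds 1010000^3 the table undercounts cubes, but by then mid minus the table size is already above 10^9, so the comparison still goes the right way.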

thanks