SQUARESORT - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: Ritul Kumar Singh
Testers: Nishank Suresh, Satyam
Editorialist: Nishank Suresh

DIFFICULTY:

2728

PREREQUISITES:

BFS, Dynamic programming

PROBLEM:

You have an array A of N integers. In one move, you can choose an index i and set A_i to either A_i^2 or \left\lfloor \sqrt{A_i}\right\rfloor.

Find the minimum number of moves required to sort A in non-decreasing order.

EXPLANATION:

The single most important observation here is that a given index can’t take too many values: for the given limits, there are \leq 40 possibilities for each index.

Proof

We can make a couple of observations for any positive integer x.

  • \left\lfloor \sqrt{x^2}\right\rfloor = x, so a squaring immediately followed by a square root cancels out. Repeatedly deleting such adjacent pairs shows that any sequence of operations on x can be replaced by an equivalent (and no longer) sequence in which every square root operation comes before every squaring operation.
  • If we don’t use the square root operation, the integers we can reach are x, x^2, x^4, x^8, \ldots, i.e., integers of the form x^{2^k} for some k \geq 0.

If x = 1 the only integer that can be reached is 1, so we only look at x \gt 1.
If x \gt 1, then note that x^{2^6} = x^{64}\gt 10^{18} no matter what x is. So, repeated squaring only allows us to reach at most 6 values: x, x^2, x^4, x^8, x^{16}, x^{32}.

Similar reasoning should tell you that repeatedly taking the square root of x will give you at most 6 distinct values before you hit 1.
Since any reachable integer can be obtained by taking the square root some number of times and then squaring some number of times, this gives us an upper bound of 6\times 6 = 36 reachable values, plus 1 itself.

Further, while it isn’t needed, a slight modification of the argument allows us to prove a tighter bound: 6 + 5 + 4 + 3 + 2 + 1 = 21, which along with x = 1 gives us 22. Do you see how?
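A quick way to sanity-check the counting argument is to enumerate the normal form directly: walk down the chain of square roots and, from each value on it, square repeatedly. This is a minimal Python sketch, not code from the setter or testers; the helper name reachable and the cap LIMIT = 10^{18} are assumptions made for illustration, matching the 10^{18} used in the proof.

```python
from math import isqrt

LIMIT = 10**18  # assumed cap on values, as in the proof (2^64 > 10^18)


def reachable(x):
    """Values reachable from x via the normal form from the proof:
    some square roots first, then some squarings (never exceeding LIMIT)."""
    values = set()
    root = x
    while True:
        y = root
        # Square repeatedly, keeping only values within LIMIT.
        while y <= LIMIT:
            values.add(y)
            if y == 1:
                break  # 1 * 1 == 1, nothing new to add
            y *= y
        if root == 1:
            break
        root = isqrt(root)  # next value on the square-root chain
    return values


for x in (2, 17, 10**18, 10**18 - 1):
    print(x, len(reachable(x)))
```

Under the assumed cap, x = 10^{18} yields 21 values and x = 10^{18} - 1 yields 22 (chains of length 1, 2, 3, 4, 5, 6 plus the value 1), so the bound of 22 can actually be attained.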

With this in mind, let’s first find, for each 1 \leq i \leq N, the possible values that can occur at this position (and the minimum number of steps needed to reach each value).
Let S_i denote the set of values that can occur at position i, and cost(i, x) denote the number of steps needed to turn A_i into x.

How to compute these?

We use bfs!

More specifically, consider the (infinite) directed graph on the set of positive integers, where there’s an edge from x to y if and only if y = x^2 or y = \left\lfloor \sqrt{x}\right\rfloor.

Then, note that S_i is exactly the set of nodes that can be reached from A_i in this graph, and cost(i, x) is simply the length of the shortest path from A_i to x.
So, simply start a bfs from A_i to compute distances: while the graph itself is impossible to hold in memory, we only care about the vertices that can be reached from A_i, and as noted above there are very few of them. Edges are defined implicitly and can be computed as we do the bfs.

This bfs will take \mathcal{O}(|S_i|) time for a given index i, and |S_i| \leq 40 so this is \leq 40N vertices visited in total, which is perfectly fine.
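A minimal Python sketch of this bfs, assuming values are capped at 10^{18} (the hypothetical constant LIMIT): the neighbours \left\lfloor \sqrt{u}\right\rfloor and u^2 are generated on the fly, and math.isqrt gives an exact integer square root. The name bfs_costs is made up for illustration.

```python
from collections import deque
from math import isqrt

LIMIT = 10**18  # assumed cap on values (squares above it are not allowed)


def bfs_costs(a):
    """Return {value: minimum number of moves} for every value reachable from a."""
    cost = {a: 0}
    queue = deque([a])
    while queue:
        u = queue.popleft()
        # Edges are implicit: u -> floor(sqrt(u)) and u -> u^2 (if within LIMIT).
        for v in (isqrt(u), u * u):
            if v <= LIMIT and v not in cost:
                cost[v] = cost[u] + 1
                queue.append(v)
    return cost


print(sorted(bfs_costs(17).items()))
```

For example, from 17 this finds 4 and 289 at cost 1, then 2, 16 and 83521 at cost 2, and so on, 11 values in total.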

Once we have this information, the rest of the problem turns into a relatively simple dp.

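Let dp_{i, x} denote the minimum number of moves needed to sort the first i elements of the array, such that the value at position i is x. The final answer is the minimum value of dp_{N, x} across all x \in S_N. As a base case, we can pretend there is a position 0 holding the value 0 at cost 0, so that dp_{1, x} = cost(1, x).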

The transitions are simple:

dp_{i, x} = cost(i, x) + \min_{\substack{y \in S_{i-1} \\ y \leq x}} (dp_{i-1, y})

In words: the minimum cost to sort the first i elements with x at position i equals the cost of turning A_i into x, plus the minimum cost of sorting the first i-1 elements so that position i-1 holds a value not larger than x.
The first part is cost(i, x), which we precomputed, and the second part is exactly what the dp stores for row i-1, so we obtain a recurrence.
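Translated literally, the recurrence gives the \mathcal{O}(NM^2) version below. This is a Python sketch, not the setter's code; bfs_costs, min_moves_direct and LIMIT = 10^{18} are hypothetical names and assumptions, and the base case is a virtual position 0 holding the value 0 at cost 0 (the same idea as the dp {{0, 0}} initialisation in the setter's code).

```python
from collections import deque
from math import isqrt

LIMIT = 10**18  # assumed cap on values


def bfs_costs(a):
    """Minimum number of moves from a to every reachable value."""
    cost, queue = {a: 0}, deque([a])
    while queue:
        u = queue.popleft()
        for v in (isqrt(u), u * u):
            if v <= LIMIT and v not in cost:
                cost[v] = cost[u] + 1
                queue.append(v)
    return cost


def min_moves_direct(arr):
    """dp straight from the recurrence: O(N * M^2) time."""
    prev = {0: 0}  # base case: virtual position 0 holding value 0 at cost 0
    for a in arr:
        cur = {}
        for x, c in bfs_costs(a).items():
            # dp[i][x] = cost(i, x) + min over y <= x of dp[i-1][y].
            # Some y <= x always exists: 0 in the base row, and 1 is reachable from everything.
            cur[x] = c + min(d for y, d in prev.items() if y <= x)
        prev = cur
    return min(prev.values())


print(min_moves_direct([16, 2]))
```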

Let M be the maximum number of possibilities for a single index. Note that M \leq 22.
The algorithm above, if implemented directly as described, uses \mathcal{O}(NM) space and has a worst-case complexity of \mathcal{O}(NM^2).

It is possible to improve both of these:

  • Note that computing dp_{i, x} depends only on the values of dp_{i-1}. So, we can discard all previous rows of the dp table, which brings the memory down to \mathcal{O}(N+M).
  • The transitions themselves can be computed in \mathcal{O}(NM) total by sorting each set S_i, processing its values in increasing order, and maintaining a pointer into the sorted S_{i-1} together with a running prefix minimum of dp_{i-1} (two pointers). Binary search instead of the second pointer costs an additional \log factor, which should still be fast enough.
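Combining both improvements, one option is to keep only the previous row, stored as its sorted values together with prefix minima of dp, so that the \min over y \leq x becomes a single bisect lookup. This Python sketch uses the binary-search variant (the extra \log factor) rather than two pointers; the names are hypothetical and LIMIT = 10^{18} is an assumption.

```python
from bisect import bisect_right
from collections import deque
from math import isqrt

LIMIT = 10**18  # assumed cap on values


def bfs_costs(a):
    """Minimum number of moves from a to every reachable value."""
    cost, queue = {a: 0}, deque([a])
    while queue:
        u = queue.popleft()
        for v in (isqrt(u), u * u):
            if v <= LIMIT and v not in cost:
                cost[v] = cost[u] + 1
                queue.append(v)
    return cost


def min_moves(arr):
    """Rolling dp rows; min over y <= x answered by binary search on prefix minima."""
    vals, pref = [0], [0]  # base row: value 0 with cost 0
    for a in arr:
        new_vals, new_pref = [], []
        running = None  # prefix minimum of the row being built
        for x, c in sorted(bfs_costs(a).items()):
            k = bisect_right(vals, x) - 1  # index of the largest previous value <= x
            cand = c + pref[k]
            running = cand if running is None else min(running, cand)
            new_vals.append(x)
            new_pref.append(running)
        vals, pref = new_vals, new_pref
    return pref[-1]  # prefix minimum over the whole last row = min over x of dp[N][x]


print(min_moves([4, 3]))
```

Because each row stores prefix minima rather than raw dp values, pref[-1] of the last row is already the answer. The editorialist's code below follows the same plan with a moving pointer instead of bisect.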

The second step is explained more in-depth in the editorial of DIVSORT (which is a pretty similar problem, and has essentially the same dp).

If your solution matches the above and you’re getting WA, it might be due to floating-point precision issues with the sqrt function: read through this blog to see how to fix it.
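To see the precision problem concretely: 10^{18} - 1 is not representable as a double and rounds to 10^{18}, so the floating-point square root comes out as exactly 10^9. The Python sketch below compares that with math.isqrt (Python 3.8+), which is exact, and with the float-then-correct approach used by getsqrt in the editorialist's code; the helper names are hypothetical.

```python
import math


def floor_sqrt_exact(x):
    """Exact floor(sqrt(x)) for a non-negative int (Python 3.8+)."""
    return math.isqrt(x)


def floor_sqrt_corrected(x):
    """Float estimate, then fix it with exact integer comparisons."""
    s = int(math.sqrt(x))
    while s * s > x:
        s -= 1
    while (s + 1) * (s + 1) <= x:
        s += 1
    return s


x = 10**18 - 1
print(int(math.sqrt(x)))        # 1000000000 (wrong: x rounds up to 1e18 as a double)
print(floor_sqrt_exact(x))      # 999999999
print(floor_sqrt_corrected(x))  # 999999999
```

The same correction loop is what the root and sqr helpers in the setter's and tester's C++ code do after calling sqrtl/sqrt.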

TIME COMPLEXITY:

\mathcal{O}(NM) or \mathcal{O}(NM\log M) per test case, where M \leq 22 is the maximum possible number of values an index can have.

CODE:

Setter's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int int64_t
#define sp << ' ' <<
#define nl << '\n'

int root(int x) {
	int y = sqrtl(x) + 2;
	while(y * y > x) --y;
	return y;
}

signed main() {
	cin.tie(0)->sync_with_stdio(0);
	
	int T; cin >> T;
	while(T--) {
		int N; cin >> N;

		map<int, int> dp {{0, 0}}, cur;

		while(N--) {
			int a; cin >> a;

			cur[a] = 0;
			queue<int> q;
			q.push(a);

			while(!empty(q)) {
				int u = q.front(); q.pop();

				vector<int> vs;
				if(1 < u) vs.push_back(root(u));
				if(u <= (int)1e9) vs.push_back(u * u);

				for(int v : vs)
					if(cur.find(v) == end(cur))
						cur[v] = cur[u] + 1, q.push(v);
			}

			int prefMin = 1e18;

			for(auto &[i, j] : cur) {
				auto k = dp.upper_bound(i);
				j += prev(k)->second;

				j = min(j, prefMin);
				prefMin = min(prefMin, j);
			}

			swap(cur, dp);
			cur.clear();
		}

		int ans = 1e18;
		for(auto [i, j] : dp)
			ans = min(ans, j);

		cout << ans nl;
	}
}
Tester's code (C++)
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast,unroll-loops")
#include <bits/stdc++.h>   
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;   
using namespace std;  
#define ll long long  
const ll INF_MUL=1e13;
const ll INF_ADD=1e18;    
#define pb push_back                 
#define mp make_pair          
#define nline "\n"                           
#define f first                                          
#define s second                                             
#define pll pair<ll,ll> 
#define all(x) x.begin(),x.end()     
#define vl vector<ll>           
#define vvl vector<vector<ll>>    
#define vvvl vector<vector<vector<ll>>>          
#ifndef ONLINE_JUDGE    
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);  
#endif       
void _print(ll x){cerr<<x;}  
void _print(char x){cerr<<x;}   
void _print(string x){cerr<<x;}    
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());   
template<class T,class V> void _print(pair<T,V> p) {cerr<<"{"; _print(p.first);cerr<<","; _print(p.second);cerr<<"}";}
template<class T>void _print(vector<T> v) {cerr<<" [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T>void _print(set<T> v) {cerr<<" [ "; for (T i:v){_print(i); cerr<<" ";}cerr<<"]";}
template<class T>void _print(multiset<T> v) {cerr<< " [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T,class V>void _print(map<T, V> v) {cerr<<" [ "; for(auto i:v) {_print(i);cerr<<" ";} cerr<<"]";} 
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
//--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
const ll MOD=1e9+7;   
const ll MAX=1000100; 
ll overflow(ll l,ll r){
    if(l>(2*INF_ADD/r))
        return 1;
    if(l*r>INF_ADD){
        return 1; 
    }
    return 0; 
}
ll sqr(ll x){
    ll y=sqrt(x);
    y+=5;
    while(1){
        if(y*y<=x){
            return y; 
        }
        y--;
    }
}
vector<pair<ll,ll>> getv(ll x){
    map<ll,ll> visited,cost; 
    queue<ll> track; track.push(x); 
    visited[x]=1,cost[x]=0; 
    while(!track.empty()){
        auto it=track.front();
        track.pop();
        if(!overflow(it,it)){
            ll y=it*it;
            if(!visited[y]){
                visited[y]=1;
                cost[y]=cost[it]+1;
                track.push(y);  
            }
        } 
        ll z=sqr(it);
        if(!visited[z]){
            visited[z]=1;
            cost[z]=cost[it]+1;
            track.push(z);
        }
    }
    vector<pair<ll,ll>> anot;
    for(auto it:cost){
        anot.push_back(it);
    }
    return anot; 
}
void solve(){
    ll n; cin>>n;
    map<ll,ll> dp; 
    dp[0]=0;
    for(ll i=1;i<=n;i++){
        ll x; cin>>x;
        vector<pair<ll,ll>> v=getv(x); 
        map<ll,ll> adp,visited;
        sort(all(v));
        reverse(all(v));
        for(auto it:dp){
            for(auto j:v){
                if(j.f<it.f){
                    break;   
                }
                if(visited[j.f]==0){
                    adp[j.f]=j.s+it.s; 
                }
                visited[j.f]=1; 
                adp[j.f]=min(adp[j.f],j.s+it.s); 
            }
        }
        swap(dp,adp);
    } 
    ll ans=INF_ADD;
    for(auto it:dp){
        ans=min(ans,it.s);
    }
    cout<<ans<<nline; 
    return;                          
}                                    
int main()                                                                           
{    
    ios_base::sync_with_stdio(false);                         
    cin.tie(NULL);  
    #ifndef ONLINE_JUDGE                 
    freopen("input.txt", "r", stdin);                                              
    freopen("output.txt", "w", stdout);  
    freopen("error.txt", "w", stderr);                        
    #endif                          
    ll test_cases=1;               
    cin>>test_cases;
    while(test_cases--){
        solve();
    }
    cout<<fixed<<setprecision(10);
    cerr<<"Time:"<<1000*((double)clock())/(double)CLOCKS_PER_SEC<<"ms\n"; 
}  
Editorialist's code (Python)
from math import sqrt

def getsqrt(x):
    s = int(sqrt(x))
    while (s+1)*(s+1) <= x: s += 1
    while s*s > x: s -= 1
    return s

def genall(x):
    queue = [[x, 0]]
    reached = set([x])
    for u, d in queue:
        if u <= 10**9 and u*u not in reached:
            reached.add(u*u)
            queue.append([u*u, d+1])
        y = getsqrt(u)
        if y not in reached:
            reached.add(y)
            queue.append([y, d+1])
    return sorted(queue)

for _ in range(int(input())):
    n = int(input())
    a = list(map(int, input().split()))
    prv = [[1, 0]]
    for x in a:
        costs = genall(x)
        dp = []
        ptr, mn = 0, 10**18
        for y, d in costs:
            while ptr < len(prv) and prv[ptr][0] <= y:
                mn = min(mn, prv[ptr][1])
                ptr += 1
            dp.append([y, d + mn])
        prv = dp
    ans = 10 ** 18
    for x, d in prv: ans = min(ans, d)
    print(ans)

There is a similar kind of question on Codeforces: make the array strictly increasing using the minimum possible number of operations.

Can we use a similar concept here?
https://codeforces.com/blog/entry/47821
https://codeforces.com/blog/entry/47094

Here’s the code I tried, but it could not pass all test cases:

#include <bits/stdc++.h>
#define int long long int

void solve(){
    int n;
    std::cin >> n;
    std::vector <int> v(n);
    for(auto &i:v){
        std::cin >> i;
    }
    std::priority_queue <int> pq;

    int last = -1, ans = 0;
    for(int i=0; i<n; i++){
        if(v[i] == 1)
            last = i;
    }

    for(int i=0; i<=last; i++){
        if(v[i] != 1){
            int c = 0;
            while(v[i] != 1){
                v[i] = std::sqrt(v[i]);
                c += 1;
            }
            ans += c;
        } 
    }
    
    pq.push(v[0]);
    for(int i=1; i<n; i++){
        pq.push(v[i]);
        if(pq.top() > v[i]){
            int c = 0, t = pq.top();
            while(t > v[i]){
                t = std::sqrt(t);
                c += 1;
            }
            ans += c;
            pq.pop();
            pq.push(v[i]);
        }
    }

    std::cout << ans << "\n";
}
     
signed main() {

    std::ios::sync_with_stdio(false);
    std::cin.tie(0);

    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    #endif // ONLINE_JUDGE

    int t = 1;
    std::cin >> t;
    while(t--){
        solve();
    }
}

Thanks

dp={}

def solve(i, l, num):
    if i == n:
        return 0
        
    if (i, l) in dp:
        return dp[(i, l)]
    
    res = float('inf')
    p=0
    x = a[i]
    while True:
        if x >= num:
            res = min(res, p + solve(i+1, p, x))
        if x > 10**9 or x == 1:
            break
        p += 1
        x *= x
    
    p=0
    x = a[i]
    while True:
        if x >= num:
            res = min(res, p + solve(i+1, -p, x))
        if x == 1:
            break
        p += 1
        x = int(x**0.5)
    
    dp[(i, l)] = res
    return res

t = int(input())
for _ in range(t):
    n = int(input())
    a = list(map(int, input().split()))
    print(solve(0, 0, 0))
    dp.clear()
    
    

Why is this code not working?

The slope trick isn’t applicable here at all.

You need the function at each index to be piecewise linear and either convex or concave. The function in this problem is not piecewise linear, and is neither convex nor concave: it jumps randomly.

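For example, if you start at A_i = 17, then you need 0 moves to stay at 17, 1 move to reach 4, 2 moves to reach 16, and 2 moves to reach 2.
Listing the costs in increasing order of value (2, 4, 16, 17) gives 2 \to 1 \to 2 \to 0, which goes down, up, then down again; this disproves both convexity and concavity.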

As for being piecewise linear, the function at a given index is only defined at \leq 22 integer points! Due to its nature, it’s not possible to extrapolate between those points at all, so there’s absolutely nothing you can do to make it work.
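The counterexample can be checked mechanically: run the same bfs from 17, take the four points from the example, and compare consecutive slopes exactly with fractions. Convexity needs the slopes to be non-decreasing, concavity needs them non-increasing. This is a small Python sketch with hypothetical helper names and an assumed value cap of 10^{18}.

```python
from collections import deque
from fractions import Fraction
from math import isqrt

LIMIT = 10**18  # assumed cap on values


def bfs_costs(a):
    """Minimum number of moves from a to every reachable value."""
    cost, queue = {a: 0}, deque([a])
    while queue:
        u = queue.popleft()
        for v in (isqrt(u), u * u):
            if v <= LIMIT and v not in cost:
                cost[v] = cost[u] + 1
                queue.append(v)
    return cost


costs = bfs_costs(17)
points = [(v, costs[v]) for v in (2, 4, 16, 17)]  # (value, cost) in increasing value
slopes = [Fraction(c2 - c1, v2 - v1) for (v1, c1), (v2, c2) in zip(points, points[1:])]
is_convex = all(s <= t for s, t in zip(slopes, slopes[1:]))
is_concave = all(s >= t for s, t in zip(slopes, slopes[1:]))
print(points)                 # [(2, 2), (4, 1), (16, 2), (17, 0)]
print(is_convex, is_concave)  # False False
```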

I recommend reading through this blog for a better understanding of when you can apply it.


It’s not hard to think of the DP, but it takes some time to implement and optimize.

BTW, I got stuck on the precision problem of the sqrt() function. I believe there are some people like me :smiling_face_with_tear:


Thanks a lot, appreciate it :slight_smile:

Hey, there are people who still aren’t getting this even after the contest is over. My solution link:
CodeChef: Practical coding for everyone
Please help me debug this. I have already resolved the sqrt problem. If my solution were not optimised I should get TLE, so why WA?

Consider test case:
1
7
302273 761858 483972 671329 790140 749685 522992
Your answer is 8, but the correct answer is 7.
The scheme is:
302273 -> 549 (1 move)
761858 -> 872 (1 move)
790140 -> 888 -> 29 -> 841 -> 707281 (4 moves)
522992 -> 273520632064 (1 move)
The final array is:
549 872 483972 671329 707281 749685 273520632064
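The scheme can be replayed to confirm that it uses 7 moves and yields a sorted array. In this Python sketch the plan encoding ("r" for a floor square root, "s" for squaring) is a hypothetical notation introduced only for the check.

```python
from math import isqrt

start = [302273, 761858, 483972, 671329, 790140, 749685, 522992]
# Hypothetical encoding of the scheme: index -> operations applied in order,
# "r" = floor square root, "s" = square.
plan = {0: "r", 1: "r", 4: "rrss", 6: "s"}


def apply_ops(x, ops):
    for op in ops:
        x = isqrt(x) if op == "r" else x * x
    return x


final = [apply_ops(x, plan.get(i, "")) for i, x in enumerate(start)]
moves = sum(len(ops) for ops in plan.values())
print(final)   # [549, 872, 483972, 671329, 707281, 749685, 273520632064]
print(moves)   # 7
print(all(p <= q for p, q in zip(final, final[1:])))  # True
```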


Hey, thanks for the test case. I wrongly assumed that squaring a number after taking floor(square root) would give back the same number, but that isn’t true here, because we take the floor of the square root and the number may not be a perfect square.
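The asymmetry is easy to see on a value from the test case above: taking the floor square root of 790140 and squaring gives 788544, while squaring 888 and then taking the floor square root does return 888. A tiny Python illustration (math.isqrt is the exact integer square root):

```python
from math import isqrt

x = 790140
r = isqrt(x)
print(r)                  # 888
print(r * r)              # 788544, not 790140: floor(sqrt(x))^2 == x only for perfect squares
print(isqrt(r * r) == r)  # True: floor(sqrt(r^2)) == r always holds
```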

Yes, I was also getting WA on one test case. BTW, now I understand the cause.