PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: Ritul Kumar Singh
Testers: Nishank Suresh, Satyam
Editorialist: Nishank Suresh
DIFFICULTY:
2728
PREREQUISITES:
BFS, Dynamic programming
PROBLEM:
You have N integers. In one move, you can choose A_i and set it to either A_i^2 or \left\lfloor \sqrt{A_i}\right\rfloor.
Find the minimum number of moves required to sort A.
EXPLANATION:
The single most important observation here is that a given index can’t take too many values: for the given limits, there are \leq 40 possibilities for each index.
Proof
We can make a couple of observations for any positive integer x.
- \left\lfloor \sqrt{x^2}\right\rfloor = x. This means that a squaring immediately followed by a square root cancels out, so any sequence of operations on x can be replaced by one that is no longer, ends at the same value, and performs every square root operation before every squaring operation.
- If we don't use the square root operation, the integers we can reach are x, x^2, x^4, x^8, \ldots, i.e., integers of the form x^{2^k} for some k \geq 0.
If x = 1 the only integer that can be reached is 1, so we only look at x \gt 1.
If x \gt 1, then x^{2^6} = x^{64} \geq 2^{64} \gt 10^{18}. So, repeated squaring only allows us to reach at most 6 values: x, x^2, x^4, x^8, x^{16}, x^{32}.
Similar reasoning shows that repeatedly taking the square root of x gives at most 6 distinct values greater than 1 before reaching 1; for example, 10^{18} \to 10^9 \to 31622 \to 177 \to 13 \to 3 \to 1.
Since any reachable integer can be created by taking the square root some number of times and then squaring some number of times, this gives an upper bound of 6\times 6 = 36 reachable values greater than 1, plus the value 1 itself.
Further, while it isn’t needed, a slight modification of the argument allows us to prove a tighter bound: 6 + 5 + 4 + 3 + 2 + 1 = 21, which along with x = 1 gives us 22. Do you see how?
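The normal form from this proof (all square roots first, then all squarings) can also be enumerated directly. Below is a minimal Python sketch, not taken from the reference solutions: the function name reachable_normal_form is hypothetical, and the bound LIMIT = 10^{18} is an assumption inferred from the argument above (values may not exceed it). For each reachable value it keeps the fewest moves over all (roots, squarings) pairs that produce it.

```python
from math import isqrt

LIMIT = 10**18  # assumed upper bound on any value held in the array


def reachable_normal_form(x):
    """Map each value reachable from x to the fewest moves needed, trying
    every sequence of the form 'k square roots, then j squarings'."""
    best = {}
    base, k = x, 0  # base is the k-fold integer square root of x
    while True:
        value, j = base, 0
        while True:
            # Record k + j moves for this value if it beats what we had.
            if value not in best or k + j < best[value]:
                best[value] = k + j
            # Stop at 1 (squaring is a no-op) or when value * value > LIMIT.
            if value == 1 or value > LIMIT // value:
                break
            value, j = value * value, j + 1
        if base == 1:
            break
        base, k = isqrt(base), k + 1
    return best
```

For example, reachable_normal_form(2) returns the seven values 1, 2, 4, 16, 256, 65536 and 2^{32}, with 1 reached in one move and 2^{32} in five.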
With this in mind, let’s first find, for each 1 \leq i \leq N, the possible values that can occur at this position (and the minimum number of steps needed to reach each value).
Let S_i denote the set of values that can occur at position i, and cost(i, x) denote the number of steps needed to turn A_i into x.
How to compute these?
We use bfs!
More specifically, consider the (infinite) directed graph on the set of positive integers, where there’s an edge from x to y if and only if y = x^2 or y = \left\lfloor \sqrt{x}\right\rfloor.
Then, note that S_i is exactly the set of nodes that can be reached from A_i in this graph, and cost(i, x) is simply the shortest path from A_i to x.
So, simply start a bfs from A_i to compute distances: while the graph itself is impossible to hold in memory, we only care about those vertices that can be reached from A_i, which as noted above is pretty small. Edges are defined implicitly and can be computed as we do the bfs.
This bfs will take \mathcal{O}(|S_i|) time for a given index i, and |S_i| \leq 40 so this is \leq 40N vertices visited in total, which is perfectly fine.
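A minimal sketch of this BFS in Python, assuming the bound LIMIT = 10^{18} on values (an assumption inferred from the bounds argument); bfs_costs is a hypothetical name. Edges are generated on the fly, so only vertices reachable from a are ever stored, and math.isqrt (Python 3.8+) gives the exact floor of the square root.

```python
from collections import deque
from math import isqrt

LIMIT = 10**18  # assumed upper bound on any value held in the array


def bfs_costs(a):
    """Return {value: fewest moves from a} over the implicit graph with
    edges x -> floor(sqrt(x)) and x -> x*x (only while x*x <= LIMIT)."""
    dist = {a: 0}
    queue = deque([a])
    while queue:
        u = queue.popleft()
        neighbours = [isqrt(u)]
        if u <= LIMIT // u:  # squaring stays within the assumed bound
            neighbours.append(u * u)
        for v in neighbours:
            if v not in dist:  # first visit is the shortest distance in BFS
                dist[v] = dist[u] + 1
                queue.append(v)
    return dist
```

For instance, bfs_costs(16) maps 16 to 0, both 4 and 256 to 1, and 1 to 3.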
Once we have this information, the rest of the problem turns into a relatively simple dp.
Let dp_{i, x} denote the minimum number of moves needed to sort the first i elements of the array, such that the value at position i is x. The final answer is the minimum value of dp_{N, x} across all x \in S_N.
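The transitions are simple:
dp_{i, x} = cost(i, x) + \min_{y \in S_{i-1},\ y \leq x} dp_{i-1, y}
For i = 1 there is no previous element, so dp_{1, x} = cost(1, x).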
Or in simpler words: the minimum cost to sort the first i elements and have x end up at position i equals the cost of turning A_i into x, plus the minimum cost of sorting the first i-1 elements such that position i-1 contains an element not larger than x.
The first part is cost(i, x) which we precomputed, and the second part is exactly what our dp contains so we obtain a recursion.
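The recurrence can be transcribed almost literally. The sketch below is a hypothetical helper (not from the reference solutions) implementing the direct \mathcal{O}(M^2) transition: each layer is a dict from value to dp, and for i = 1 it can be called with the sentinel layer {1: 0}, which works because every value is at least 1 (the editorialist's code below uses the same trick).

```python
def transition(prev_layer, costs):
    """prev_layer: {y: dp_{i-1, y}}, costs: {x: cost(i, x)}.
    Returns {x: dp_{i, x}} following the recurrence directly."""
    layer = {}
    for x, c in costs.items():
        # Cheapest way to have some value y <= x at position i - 1.
        best_prev = min((d for y, d in prev_layer.items() if y <= x), default=None)
        if best_prev is not None:  # skip x if no y <= x exists
            layer[x] = c + best_prev
    return layer
```

Chaining transition over all positions and taking the minimum of the last layer gives the answer; the optimizations that follow only speed up the inner minimum.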
Let M be the maximum number of possibilities for a single index. Note that M \leq 22.
The algorithm above, if implemented directly as described, uses \mathcal{O}(NM) space and has a worst-case complexity of \mathcal{O}(NM^2).
It is possible to improve both of these:
- Note that computing dp_{i, x} depends on only the values of dp_{i-1}. So, we can discard all previous rows of the dp table, which brings the memory down to \mathcal{O}(N+M).
- Computing the transitions itself can be improved to \mathcal{O}(NM): sort each S_i, iterate over S_i in increasing order, and maintain a pointer into the sorted S_{i-1} together with the running minimum of dp_{i-1, y} over all y passed so far (two pointers). Binary search also works, at the cost of an additional \log factor, which should still be fast enough.
The second step is explained in more depth in the editorial of DIVSORT (which is a pretty similar problem, and has essentially the same dp).
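Putting the pieces together with both improvements, here is a minimal Python sketch (assumed bound LIMIT = 10^{18}; the function names are hypothetical). It keeps only the previous layer, as a sorted list of (value, dp) pairs, and uses two pointers with a running minimum. It follows the same structure as the editorialist's code below, with an exact integer square root in place of the floating-point one.

```python
from collections import deque
from math import isqrt

LIMIT = 10**18  # assumed upper bound on any value held in the array


def reachable(a):
    """BFS over x -> floor(sqrt(x)) and x -> x*x; returns sorted (value, cost) pairs."""
    dist = {a: 0}
    queue = deque([a])
    while queue:
        u = queue.popleft()
        for v in [isqrt(u)] + ([u * u] if u <= LIMIT // u else []):
            if v not in dist:
                dist[v] = dist[u] + 1
                queue.append(v)
    return sorted(dist.items())


def min_moves(arr):
    # Sentinel layer: every value is >= 1, so "value 1 at cost 0" imposes no constraint.
    prev = [(1, 0)]
    for a in arr:
        cur = []
        ptr, best = 0, float("inf")
        for x, c in reachable(a):  # candidates in increasing order of value
            # Advance over previous values <= x, keeping the minimum of their dp.
            while ptr < len(prev) and prev[ptr][0] <= x:
                best = min(best, prev[ptr][1])
                ptr += 1
            cur.append((x, c + best))
        prev = cur  # only the previous layer is ever stored
    return min(d for _, d in prev)
```

For example, min_moves([16, 3]) is 2 (turn 16 into 2 with two square roots), and min_moves([3, 2, 1]) is also 2.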
If your solution matches the above and you're getting WA, it might be due to the use of the sqrt function: floating-point square roots of integers near 10^{18} can be off by one. Read through this blog to see how to fix it.
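As an illustration of the kind of failure meant here: a double has only a 53-bit significand, so 10^{18} - 1 is rounded to 10^{18} before the square root is taken, and int(math.sqrt(10**18 - 1)) yields 10^9 instead of 999999999. One common fix, which the setter's root helper also follows, is to take the floating-point estimate and correct it with exact integer arithmetic; in Python 3.8+, math.isqrt is exact. The name safe_isqrt is hypothetical.

```python
import math


def safe_isqrt(x):
    """Floor of sqrt(x) for x >= 0: float estimate, then exact integer correction."""
    y = int(math.sqrt(x))
    while y * y > x:  # estimate too large
        y -= 1
    while (y + 1) * (y + 1) <= x:  # estimate too small
        y += 1
    return y


x = 10**18 - 1
print(int(math.sqrt(x)))  # 1000000000, one too large
print(safe_isqrt(x))      # 999999999
print(math.isqrt(x))      # 999999999
```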
TIME COMPLEXITY:
\mathcal{O}(NM) or \mathcal{O}(NM\log M) per test case, where M \leq 22 is the maximum possible number of values an index can have.
CODE:
Setter's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int int64_t
#define sp << ' ' <<
#define nl << '\n'
int root(int x) {
    // floor(sqrt(x)): start just above the floating-point estimate, then step down exactly
    int y = sqrtl(x) + 2;
    while(y * y > x) --y;
    return y;
}
signed main() {
    cin.tie(0)->sync_with_stdio(0);
    int T; cin >> T;
    while(T--) {
        int N; cin >> N;
        // dp: value -> min moves so far (stored as a prefix minimum); {0, 0} is a sentinel
        map<int, int> dp {{0, 0}}, cur;
        while(N--) {
            int a; cin >> a;
            // BFS from a over x -> floor(sqrt(x)) and x -> x*x (while x*x <= 10^18)
            cur[a] = 0;
            queue<int> q;
            q.push(a);
            while(!empty(q)) {
                int u = q.front(); q.pop();
                vector<int> vs;
                if(1 < u) vs.push_back(root(u));
                if(u <= (int)1e9) vs.push_back(u * u);
                for(int v : vs)
                    if(cur.find(v) == end(cur))
                        cur[v] = cur[u] + 1, q.push(v);
            }
            // Add the best previous cost over values <= i, then keep a running prefix minimum
            int prefMin = 1e18;
            for(auto &[i, j] : cur) {
                auto k = dp.upper_bound(i);
                j += prev(k)->second;
                j = min(j, prefMin);
                prefMin = min(prefMin, j);
            }
            swap(cur, dp);
            cur.clear();
        }
        int ans = 1e18;
        for(auto [i, j] : dp)
            ans = min(ans, j);
        cout << ans nl;
    }
}
Tester's code (C++)
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast,unroll-loops")
#include <bits/stdc++.h>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;
using namespace std;
#define ll long long
const ll INF_MUL=1e13;
const ll INF_ADD=1e18;
#define pb push_back
#define mp make_pair
#define nline "\n"
#define f first
#define s second
#define pll pair<ll,ll>
#define all(x) x.begin(),x.end()
#define vl vector<ll>
#define vvl vector<vector<ll>>
#define vvvl vector<vector<vector<ll>>>
#ifndef ONLINE_JUDGE
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);
#endif
void _print(ll x){cerr<<x;}
void _print(char x){cerr<<x;}
void _print(string x){cerr<<x;}
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
template<class T,class V> void _print(pair<T,V> p) {cerr<<"{"; _print(p.first);cerr<<","; _print(p.second);cerr<<"}";}
template<class T>void _print(vector<T> v) {cerr<<" [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T>void _print(set<T> v) {cerr<<" [ "; for (T i:v){_print(i); cerr<<" ";}cerr<<"]";}
template<class T>void _print(multiset<T> v) {cerr<< " [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T,class V>void _print(map<T, V> v) {cerr<<" [ "; for(auto i:v) {_print(i);cerr<<" ";} cerr<<"]";}
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
//--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
const ll MOD=1e9+7;
const ll MAX=1000100;
ll overflow(ll l,ll r){
    if(l>(2*INF_ADD/r))
        return 1;
    if(l*r>INF_ADD){
        return 1;
    }
    return 0;
}
ll sqr(ll x){
    ll y=sqrt(x);
    y+=5;
    while(1){
        if(y*y<=x){
            return y;
        }
        y--;
    }
}
vector<pair<ll,ll>> getv(ll x){
    map<ll,ll> visited,cost;
    queue<ll> track; track.push(x);
    visited[x]=1,cost[x]=0;
    while(!track.empty()){
        auto it=track.front();
        track.pop();
        if(!overflow(it,it)){
            ll y=it*it;
            if(!visited[y]){
                visited[y]=1;
                cost[y]=cost[it]+1;
                track.push(y);
            }
        }
        ll z=sqr(it);
        if(!visited[z]){
            visited[z]=1;
            cost[z]=cost[it]+1;
            track.push(z);
        }
    }
    vector<pair<ll,ll>> anot;
    for(auto it:cost){
        anot.push_back(it);
    }
    return anot;
}
void solve(){
    ll n; cin>>n;
    map<ll,ll> dp;
    dp[0]=0;
    for(ll i=1;i<=n;i++){
        ll x; cin>>x;
        vector<pair<ll,ll>> v=getv(x);
        map<ll,ll> adp,visited;
        sort(all(v));
        reverse(all(v));
        for(auto it:dp){
            for(auto j:v){
                if(j.f<it.f){
                    break;
                }
                if(visited[j.f]==0){
                    adp[j.f]=j.s+it.s;
                }
                visited[j.f]=1;
                adp[j.f]=min(adp[j.f],j.s+it.s);
            }
        }
        swap(dp,adp);
    }
    ll ans=INF_ADD;
    for(auto it:dp){
        ans=min(ans,it.s);
    }
    cout<<ans<<nline;
    return;
}
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
#ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    freopen("error.txt", "w", stderr);
#endif
    ll test_cases=1;
    cin>>test_cases;
    while(test_cases--){
        solve();
    }
    cout<<fixed<<setprecision(10);
    cerr<<"Time:"<<1000*((double)clock())/(double)CLOCKS_PER_SEC<<"ms\n";
}
Editorialist's code (Python)
from math import sqrt
# Exact floor(sqrt(x)): correct the floating-point estimate in both directions
def getsqrt(x):
    s = int(sqrt(x))
    while (s+1)*(s+1) <= x: s += 1
    while s*s > x: s -= 1
    return s
# BFS over x -> x*x (kept <= 10**18) and x -> floor(sqrt(x)); returns [value, cost] pairs sorted by value
def genall(x):
    queue = [[x, 0]]
    reached = set([x])
    for u, d in queue:
        if u <= 10**9 and u*u not in reached:
            reached.add(u*u)
            queue.append([u*u, d+1])
        y = getsqrt(u)
        if y not in reached:
            reached.add(y)
            queue.append([y, d+1])
    return sorted(queue)
for _ in range(int(input())):
    n = int(input())
    a = list(map(int, input().split()))
    prv = [[1, 0]]  # sentinel: every value is >= 1
    for x in a:
        costs = genall(x)
        dp = []
        ptr, mn = 0, 10**18
        for y, d in costs:
            # Two pointers: mn is the minimum dp over previous values <= y
            while ptr < len(prv) and prv[ptr][0] <= y:
                mn = min(mn, prv[ptr][1])
                ptr += 1
            dp.append([y, d + mn])
        prv = dp
    ans = 10 ** 18
    for x, d in prv: ans = min(ans, d)
    print(ans)