POWPM - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: munch_01
Tester: apoorv_me
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

None

PROBLEM:

You’re given an array A. Find the number of pairs of indices (i, j) such that A_i^j \leq A_j.

EXPLANATION:

Powers of positive integers generally grow pretty fast: in particular, if x \geq 2, then x^{30} \gt 10^9.
This means that any pair (i, j) such that A_i \geq 2 and j \geq 30 is automatically invalid!
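As a quick sanity check of this bound, the sketch below computes the largest exponent k with x^k \leq 10^9 for a few bases. The helper name `max_exponent` and the constant `LIMIT` are illustrative assumptions, not part of the original editorial.

```python
LIMIT = 10**9  # assumed upper bound on array values, as used in the editorial

def max_exponent(x, limit=LIMIT):
    """Return the largest k >= 0 with x**k <= limit, for an integer x >= 2."""
    k, power = 0, 1
    while power * x <= limit:
        power *= x
        k += 1
    return k

# Print the largest usable exponent for a few sample bases.
for x in (2, 3, 10, 31622, 10**9):
    print(x, max_exponent(x))
```

The smallest base, 2, allows the largest exponent, and even there the exponent stops at 29, which is why no exponent of 30 or more can work for any base \geq 2.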

In other words, every valid pair must have either A_i = 1 or j \lt 30.
Looking at these two cases:

  • If A_i = 1, then it doesn’t matter what j is: A_i^j = 1^j = 1 will always be \leq A_j.
    So, in this case we can simply add N to the answer, since every j is valid.
  • If A_i \gt 1, then our observation tells us that it’s enough to check only j \lt 30, which can be done by brute force.

We check at most 30\cdot N pairs of indices this way, which is fast enough.

Note that depending on your implementation, you will have to deal with overflow issues appropriately.
A simple way of doing this is to store the current power in a 64-bit variable and break out once it exceeds 10^9. This way you never encounter overflow: at any point you only multiply a value \leq 10^9 by another value \leq 10^9, so the product is at most 10^{18}, which fits in a signed 64-bit integer.
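The early-exit idea can be sketched in isolation as follows. Python integers do not overflow, so this only illustrates the control flow that keeps every intermediate product at or below 10^{18} in a fixed-width language. The function name `count_valid_j` and the constant `LIMIT` are hypothetical names chosen for this example.

```python
LIMIT = 10**9  # assumed bound on the array values

def count_valid_j(base, a, limit=LIMIT):
    """Count positions j (1-indexed) with base**j <= a[j-1], for base >= 2."""
    count = 0
    power = 1
    for value in a:        # value is A_j for j = 1, 2, ...
        power *= base      # power was <= limit before this step, so the
                           # product is at most limit * limit
        if power > limit:  # no later j can be valid, so stop early
            break
        if power <= value:
            count += 1
    return count

sample = [2, 8, 16777216, 256, 1024]
print(sum(count_valid_j(x, sample) for x in sample))
```

Summing the per-base counts over the array gives the total number of valid pairs when no element equals 1. Elements equal to 1 are handled by the separate case described above.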

TIME COMPLEXITY:

\mathcal{O}(N\log 10^9) per testcase.

CODE:

Editorialist's code (Python)
for _ in range(int(input())):
    n = int(input())
    a = list(map(int, input().split()))
    ans = 0
    for i in range(n):
        if a[i] == 1: ans += n
        else:
            pw = 1
            for j in range(n):
                pw *= a[i]
                if pw > 10**9: break
                ans += pw <= a[j]
    print(ans)
Hey problem tester, why does this code fail on the last test case?

def binary_search(arr, x):
    left, right = 0, len(arr) - 1
    while left <= right:
        mid = (left + right) // 2
        if arr[mid] < x:
            left = mid + 1
        else:
            right = mid - 1
    return len(arr) - left

for _ in range(int(input())):
    a = int(input())
    arr = [int(i) for i in input().split()]
    # ai<=aj^(1/(j+1))
    st = sorted([j ** (1 / (i + 1)) for i, j in enumerate(arr)])
    ans = 0
    for i, j in enumerate(arr):
        ans += binary_search(st, j)
    print(ans)

I used
A[i] <= (jth root of A[j]).
For each j, I count the number of i that satisfy it, but this gives a wrong answer. Can anyone please tell me why this approach is wrong?

I also got stuck there and could not find the logic.

Can you explain the problem in detail?

I used the same approach, but it gave a wrong answer on the last test case.
Here’s a link to my solution: CodeChef: Practical coding for everyone

Try this test case:
Input:

1
5
2 8 16777216 256 1024

Expected Output:

7

The cube root of 16777216 (i.e., 2^{24}) is 256, but 16777216^{\frac{1}{3}} is slightly off when computed in floating point: it comes out as 255.99999999999991. When converted to int, the fractional part is truncated, resulting in 255. Solutions based on computing k^{th} roots in floating point will therefore miss the pair (4, 3), which is a valid pair.
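One way to avoid the rounding issue is to compute the root with integers only. The sketch below is a hypothetical helper (`integer_kth_root` is not from any post in this thread) that binary-searches for the largest r with r^k \leq x, assuming x \geq 0 and k \geq 1.

```python
def integer_kth_root(x, k):
    """Return the largest integer r >= 0 with r**k <= x, using only integers."""
    if x == 0:
        return 0
    # If k >= bit_length(x), then 2**k > x, so the root must be 1.
    if k >= x.bit_length():
        return 1
    lo, hi = 1, x  # the answer lies in [1, x] for x >= 1
    while lo < hi:
        mid = (lo + hi + 1) // 2
        if mid**k <= x:  # Python integers are exact, so no rounding occurs
            lo = mid
        else:
            hi = mid - 1
    return lo

print(integer_kth_root(16777216, 3))
```

Because every comparison is done on exact integers, the pair (4, 3) from the test case above is not lost. In a fixed-width language, the `mid**k` step would additionally need overflow protection.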

I used binary search to find the nth root of an element and used upper_bound (again binary search) on a sorted version of the array to determine the number of elements that are less than or equal to the root.

Like the editorial observes, nthRoot always returns 1 for all n > 30 so the answer will be the same.

int nthRoot(int x, int n) {
  if (x == 0 || x == 1)
    return x;
  ll low = 1, high = x;
  ll ans = 1;
  while (low <= high) {
    ll mid = (low + high) / 2;
    int c = 0;
    int tmp = x;
    if (mid == 1)
      c = n;
    else {
      while (tmp) {
        tmp /= mid;
        c++;
      }
      c--;
    }
    if (c >= n) {
      ans = mid;
      low = mid + 1;
    } else 
      high = mid - 1;
  };
  return ans;
}

void solve() {
  int n;
  cin >> n;
  vector<int> v(n);
  vector<int> a(n);
  for (int i = 0; i < n; i++) {
    cin >> v[i];
    a[i] = v[i];
  }
  sort(a.begin(), a.end());
  
  ll ans = 0;
  
  for (int i = 0; i < n; i++) {
    int root = nthRoot(v[i], i + 1);
    auto it = upper_bound(a.begin(), a.end(), root);
    int dist = it - a.begin();
    ans += dist;
  }
  
  cout<<ans<<"\n";
}

Why is my code failing? Please help me figure it out.

#include <bits/stdc++.h>
using namespace std;

int main() {
    // your code goes here
    int t;
    cin>>t;
    while(t--){
        int n;
        cin>>n;
        int arr[n];
        for(int i=0;i<n;i++){
            cin>>arr[i];
        }
        long long ans = 0;
        int x = min(30,n);
        for(int i=0;i<n;i++){
            if(arr[i] == 1){
                ans+=n;
                continue;
            }
            for(int j = 1;j<x;j++){
                if(pow(arr[i],j) <= arr[j-1])
                    ans++;
            }
        }
        cout<<ans<<endl;
    }
}

Why do I get stuck on the 4th problem every time?
I just couldn't think of the logic, even though it was straightforward.

I was thinking of discarding the pairs, but I still failed miserably.

But when I print each element after taking the root, the value for 16777216^{1/3} comes out as 256. You are right, though, that it does not include the pair (256, 16777216).
Why is this happening?

#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define mod 1000000007
#define py cout << "YES" << endl;
#define pn cout << "NO" << endl;
#define nl cout << endl;
#define all(p) p.begin(), p.end()
#define pyr                    \
    {                          \
        cout << "YES" << endl; \
        return;                \
    }
#define pnr                   \
    {                         \
        cout << "NO" << endl; \
        return;               \
    }

void breathe()
{
    double n;
    cin >> n;
    vector<ll> arr(n);
    for (auto &&i : arr)
    {
        cin >> i;
    }

    vector<ll> v = arr;
    sort(all(v));

    ll ans = 0;
    for (ll i = 0; i < n; i++)
    {
        double ele = pow(arr[i], 1.0 / (i + 1.0));
        cout << ele << ' ';
        ll index = upper_bound(all(v), ele) - v.begin();
        // cout << index << ' ';

        ans += index;
    }

    cout << ans << endl;

    return;
}

int main()
{
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    ll tt = 1;
    cin >> tt;
    while (tt--)
    {
        breathe();
    }
    return 0;
}

Basically, if you are using double, you may run into precision errors, so use the slightly tweaked approach mentioned in the comments below.
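A likely explanation for the printed 256 in the question above: C++ streams print a double with 6 significant digits by default, so a value just below 256 is displayed as 256 even though comparisons such as upper_bound still see the smaller value. The sketch below mimics this in Python using the value quoted earlier in the thread. Whether pow returns exactly this value on a given platform is an assumption.

```python
# Value reported earlier in the thread for 16777216 ** (1/3).
x = 255.99999999999991

# Six significant digits, like the default C++ stream precision.
shown = format(x, ".6g")

print(shown)    # the rounded display
print(x < 256)  # the comparison still uses the unrounded value
print(int(x))   # truncation toward zero drops the fractional part
```

Printing with more digits (for example `setprecision(17)` in C++) makes the discrepancy visible, and comparing integer powers instead of roots avoids it entirely.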