TRIPRI - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: aryan_sinha
Tester: tabr
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Sieve of Eratosthenes

PROBLEM:

Given an even integer N, find whether there exist three distinct primes whose squares sum to N.

EXPLANATION:

First, recall that every prime larger than 2 is odd.
Further, the square of an even number remains even, and the square of every odd number remains odd.

Now, if three integers (here, the three squares) are to sum up to the even number N, either exactly one of them is even, or all three of them are even.
However, as noted earlier, 2 is the only even prime, and the three primes must be distinct, so it is not possible for all three to be even. We are therefore forced into the case where exactly one of them is even, i.e. one of the primes is 2.
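As a quick sanity check of this parity argument, here is a small brute-force sketch in Python. It enumerates every triple of distinct primes below a hypothetical bound LIMIT (chosen only for illustration) and confirms that whenever the sum of squares is even, the triple contains 2.

```python
from itertools import combinations

LIMIT = 60  # hypothetical small bound, for illustration only


def primes_below(n):
    """All primes < n by trial division (fine for tiny n)."""
    return [x for x in range(2, n) if all(x % d for d in range(2, int(x ** 0.5) + 1))]


def even_sums_always_use_two(limit):
    """True if every triple of distinct primes < limit with an even sum of squares contains 2."""
    for triple in combinations(primes_below(limit), 3):
        if sum(p * p for p in triple) % 2 == 0 and 2 not in triple:
            return False
    return True


print(even_sums_always_use_two(LIMIT))  # True
```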

Let’s thus fix a = 2 to be the first number.
We now want to figure out whether there exist two distinct primes p and q such that 2^2 + p^2 + q^2 = N.
In particular, note that if we fix the value of p, the value of q is uniquely determined, since it should be \sqrt{N - 4 - p^2}.

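This is enough to solve subtask 1: fix a value of p and check if it's prime; if it is, compute q using the above formula and check that it is an integer, that it is also prime, and that q \neq p (the primes must be distinct). Only odd p needs to be tried, since 2 is already used.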
Checking if an integer x is prime is easily done in \mathcal{O}(\sqrt x) time, and in the first subtask N \leq 10^5 so this is pretty fast.
Note that we have p^2 \leq N, so it’s enough to check only all p \leq \sqrt N.

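For each p, we do two primality tests (one for p and one for q). Since both p and q are at most \sqrt N, each test takes \mathcal{O}(N^{1/4}) time, which is certainly within \mathcal{O}(\sqrt N).
So, the overall complexity is \mathcal{O}(N^{3/4}) per test (and in any case at most \mathcal{O}(N)), which is fast enough here.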
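A minimal Python sketch of the subtask 1 approach, assuming trial division for primality and math.isqrt (Python 3.8+) for the exact integer square root. The names is_prime and solve_subtask1 are illustrative and do not come from the reference solutions; the input n is assumed to be even, as the problem guarantees.

```python
import math


def is_prime(x):
    """Trial division in O(sqrt(x)) time."""
    if x < 2:
        return False
    d = 2
    while d * d <= x:
        if x % d == 0:
            return False
        d += 1
    return True


def solve_subtask1(n):
    """True if the even n equals 2^2 + p^2 + q^2 for distinct odd primes p and q."""
    rest = n - 4                     # parity forces the prime 2 into the triple
    p = 3
    while p * p <= rest:             # p^2 <= N - 4, so p <= sqrt(N)
        if is_prime(p):
            q2 = rest - p * p        # q is then uniquely determined
            q = math.isqrt(q2)       # exact integer square root
            if q * q == q2 and q != p and is_prime(q):
                return True
        p += 2                       # 2 is already used, so only odd p
    return False


print(solve_subtask1(38))  # True: 4 + 9 + 25
print(solve_subtask1(54))  # False: 50 = 25 + 25 repeats 5, and 1 is not prime
```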


For the second subtask, we need a bit more optimization.
While N can now be as large as 10^{10}, we're again only interested in p \leq \sqrt{N} \leq 10^5, so iterating over all candidates is still feasible, provided each one can be processed quickly.

The slow part of our solution to subtask 1 was the primality check.
This can be sped up by utilizing the fact that p, q \leq 10^5.
For each integer from 1 to 10^5, we precompute whether it is prime or not - for example with the sieve of Eratosthenes.

Then, checking whether each of p and q is a prime can be done in \mathcal{O}(1) time, which brings our complexity down to \mathcal{O}(\sqrt N) overall - fast enough for the second subtask.
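A sketch of the subtask 2 approach, assuming N \leq 10^{10} so that a sieve up to MAX_ROOT = 10^5 covers both p and q. It uses math.isqrt, which is exact, instead of a floating-point square root, so no rounding correction is needed. The names here are hypothetical.

```python
import math

MAX_ROOT = 10**5  # sqrt of the largest N (10^10)


def build_sieve(limit):
    """Sieve of Eratosthenes: is_p[x] == 1 iff x is prime, for 0 <= x <= limit."""
    is_p = bytearray([1]) * (limit + 1)
    is_p[0] = is_p[1] = 0
    for i in range(2, math.isqrt(limit) + 1):
        if is_p[i]:
            # zero out i*i, i*i + i, ... in one slice assignment
            is_p[i * i :: i] = bytearray(len(range(i * i, limit + 1, i)))
    return is_p


IS_P = build_sieve(MAX_ROOT)
ODD_PRIMES = [x for x in range(3, MAX_ROOT + 1) if IS_P[x]]


def solve(n):
    """O(sqrt N) per query: fix 2, try each odd prime p, look up q in the sieve."""
    rest = n - 4
    for p in ODD_PRIMES:             # increasing order
        q2 = rest - p * p
        if q2 <= 0:
            break                    # no larger p can work either
        q = math.isqrt(q2)
        if q * q == q2 and q != p and q != 2 and IS_P[q]:
            return True
    return False


print(solve(38), solve(54))  # True False
```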

TIME COMPLEXITY:

\mathcal{O}(\sqrt N) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp> // Common file
#include <ext/pb_ds/tree_policy.hpp>
#include <functional> // for less
#include <iostream>
#include <string>
using namespace __gnu_pbds;
using namespace std;

#define ll long long int
#define lld long double
#define all(vec) vec.begin(), vec.end()
#define endl "\n"
#define pb push_back
#define yes cout << "YES" << endl;
#define no cout << "NO" << endl;
#define ff first
#define ss second
#define flush cout << flush;
#define endl "\n"
// #define N 1e5 + 1
#define PI  3.141592653589793238462643383279
#define IOS                       \
    ios_base::sync_with_stdio(0); \
    cin.tie(0);                   \
    cout.tie(0);

typedef tree<pair<int, int>, null_type, less<pair<int, int>>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset; //store values as pairs with second element distinct
#define ordered_set tree<ll, null_type,less<ll>, rb_tree_tag,tree_order_statistics_node_update>
// find_by_order :- returns iterator to kth element starting from 0
// order_of_key:- elements less than current element

template <typename T>istream &operator>>(istream &in, vector<T> &a) {for (auto &x : a)in >> x; return in;};
template <typename T>ostream &operator<<(ostream &out, vector<T> &a) {for (auto &x : a)out << x << ' '; return out;};

template <typename T1, typename T2>ostream &operator<<(ostream &out, const pair<T1, T2> &x) { return out << x.ff << ' ' << x.ss; }
template <typename T1, typename T2>istream &operator>>(istream &in, pair<T1, T2> &x) { return in >> x.ff >> x.ss; }

const ll MOD = 1e9 + 7;
const ll mod = 998244353;
const ll NODE = 1e5 + 10;
const ll INF = 1e18;
const ll N = sqrt(1e13 + 5);
const ll NN = 1e18 + 1;
const ll MAX = 301;


vector<ll>vec(N + 1, 1);

vector<ll> prime()
{   // sieve
    vector<ll>ans;
    vector<ll>spf(N + 1);
    vec[0] = 0;
    vec[1] = 0;
    for (int i = 2; i * i <= N; i++)
    {
        if (vec[i] == 1)
        {
            spf[i] = i;
            for (int j = i * i; j <= N; j += i)
            {
                if (vec[j] == 1)
                {
                    vec[j] = 0;
                    spf[j] = i;
                }
            }
        }
    }
    for (int i = 0; i < N; i++)
    {
        if (i != 2 && vec[i] == 1) {
            ans.pb(i);
        }
    }
    return ans;
}
vector<ll>p;
void _segfault_()
{
    ll n; cin >> n;
    n -= 4;
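    for (auto it : p) {
        ll rem = n - it * it;              // value that q^2 must equal
        if (rem <= 0) break;               // primes are increasing, so no later p works; also avoids sqrt of a negative
        ll val = (ll) sqrt((lld) rem);     // candidate q, then correct any floating-point rounding
        while (val * val > rem) val--;
        while ((val + 1) * (val + 1) <= rem) val++;
        if (val * val == rem && vec[val] && val != 2 && val != it) {
            cout << "Yes" << endl; return;
        }
    }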
    cout << "No" << endl;
}
int main(int argc, char const * argv[])
{
    // int32_t for returning val 32 bit integer always
    IOS
    clock_t z = clock();
    cout.setf(ios::fixed, ios::floatfield);
    cout.setf(ios::showpoint);
    cout << setprecision(10);
    //freopen("feast.in", "r", stdin); freopen("feast.out", "w", stdout);
    int t = 1;
    cin >> t;
    ll a = 1;
    p = prime();
    while (t--)
    {
        // cout << "Case #" << a << ": ";
        _segfault_();
        a++;
    }
    cerr << "Run Time : " << ((double)(clock() - z) / CLOCKS_PER_SEC);
    return 0;
}
Tester's code (C++)
// subtask 2
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

#define IGNORE_CR

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
#ifdef IGNORE_CR
            if (c == '\r') {
                continue;
            }
#endif
            buffer.push_back((char) c);
        }
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            assert(!isspace(buffer[pos]));
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

int main() {
    input_checker in;
    const int MAX = 100000;
    vector<int> primes;
    vector<bool> is_prime(MAX, true);
    is_prime[0] = is_prime[1] = false;
    for (int i = 2; i < MAX; i++) {
        if (!is_prime[i]) {
            continue;
        }
        primes.emplace_back(i);
        if ((long long) i * i >= MAX) {
            continue;
        }
        for (int j = i * i; j < MAX; j += i) {
            if (is_prime[j]) {
                is_prime[j] = false;
            }
        }
    }
    int tt = in.readInt(1, 500);
    in.readEoln();
    while (tt--) {
        long long n = in.readLong(0, 1e10);
        in.readEoln();
        assert(n % 2 == 0);
        string ans = "No";
        for (int p : primes) {
            long long x = n - 4 - p * 1LL * p;
            if (x <= 0) {
                break;
            }
            long long y = llround(sqrtl(x));
            while (y * y > x) {
                y--;
            }
            while ((y + 1) * (y + 1) <= x) {
                y++;
            }
            if (x == y * y && p != y && is_prime[y]) {
                ans = "Yes";
            }
        }
        cout << ans << '\n';
    }
    in.readEof();
    return 0;
}
Editorialist's code (Python)
N = 10**5 + 42
sieve = [1]*N
sieve[0] = sieve[1] = 0
for p in range(2, N):
    if sieve[p] == 0: continue
    for n in range(2*p, N, p): sieve[n] = 0
primes = []
for i in range(3, N):
    if sieve[i] == 1: primes.append(i)

from math import sqrt
def isqrt(x):
    r = int(sqrt(x))
    while r*r > x: r -= 1
    while (r+1)*(r+1) <= x: r += 1
    return r

for _ in range(int(input())):
    n = int(input())
    ans = 'No'
    for p in primes:
        if 4 + p*p > n: break
        r = n - 4 - p*p
        s = isqrt(r)
        if s*s != r or s == p: continue
        if sieve[s]:
            ans = 'Yes'
            break
    print(ans)

I initially missed that N is even, but you can still solve it in O(N/log^2 N) precomputation plus O(sqrt N) per case.

Please share how.

My apologies, I missed a factor of \log N.

Anyway, there are O(\sqrt N / \log \sqrt N) = O(\sqrt N / \log N) primes up to \sqrt N. Precompute the sums of squares of all pairs; there are O((\sqrt N / \log N)^2) = O(N / \log^2 N) of them. Sort them in O(N / \log^2 N \cdot \log (N / \log^2 N)) = O(N / \log N) time. Then iterate over the O(\sqrt N / \log N) candidates for the third prime and binary search for the sum of the remaining two squares in the sorted array, in O(\log (N / \log^2 N)) = O(\log N) time each. Hence the actual complexity is O(N / \log N) precomputation and O(\sqrt N / \log N \cdot \log N) = O(\sqrt N) per query.
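A small-scale Python sketch of this meet-in-the-middle idea, assuming a tiny bound max_n so that it runs instantly (max_n and the helper names are illustrative). Sorting the tuples (a^2 + b^2, a, b) puts all pairs with the same sum into one contiguous run, which bisect_left locates; scanning that run also handles the distinctness condition. This variant does not use the parity observation, so it works for odd n too.

```python
import bisect
import math


def primes_upto(limit):
    """Sieve of Eratosthenes; returns all primes <= limit."""
    is_p = [True] * (limit + 1)
    is_p[0] = is_p[1] = False
    for i in range(2, math.isqrt(limit) + 1):
        if is_p[i]:
            for j in range(i * i, limit + 1, i):
                is_p[j] = False
    return [x for x in range(limit + 1) if is_p[x]]


def build_pairs(max_n):
    """Precompute sorted (a^2 + b^2, a, b) for primes a < b with a^2 + b^2 <= max_n."""
    ps = primes_upto(math.isqrt(max_n))
    pairs = []
    for i, a in enumerate(ps):
        for b in ps[i + 1:]:
            s = a * a + b * b
            if s > max_n:
                break                  # b only grows, so s only grows
            pairs.append((s, a, b))
    pairs.sort()
    return ps, pairs


def query(n, ps, pairs):
    """Is n a sum of squares of three distinct primes?"""
    for c in ps:
        target = n - c * c
        if target <= 0:
            break
        # (target,) sorts before every (target, a, b), so this is the start of the run
        i = bisect.bisect_left(pairs, (target,))
        while i < len(pairs) and pairs[i][0] == target:
            # different pairs with equal sums share no prime, so at most one contains c
            if c != pairs[i][1] and c != pairs[i][2]:
                return True
            i += 1
    return False


ps, pairs = build_pairs(1000)
print(query(38, ps, pairs), query(54, ps, pairs), query(83, ps, pairs))  # True False True
```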

You can also avoid sorting and use a hash table, but it's harder to deal with the distinctness condition this way, and the runtime in practice may not be much better anyway.

To be honest, I find the time limit for this problem too tight. If you just add a sanity check using Legendre's three-square theorem, you get TLE.

Checking whether n is not of the form 4^a(8b + 7) (with a, b \in \mathbb{Z} and a, b \ge 0) takes only about \frac{1}{2}\log_2(n) iterations, i.e. O(\log n) time.
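A sketch of that check in Python (the function name is hypothetical). For this problem it is only a necessary condition: it rejects n that is not a sum of any three squares, but passing it says nothing about the squares coming from distinct primes. The explicit n == 0 guard matters, because 0 stays divisible by 4 forever, so a bare "divide out factors of 4" loop never terminates on it.

```python
def three_square_possible(n):
    """Legendre's three-square theorem: n >= 0 is a sum of three integer squares
    iff n is not of the form 4^a * (8b + 7)."""
    if n == 0:
        return True          # 0 = 0^2 + 0^2 + 0^2; the loop below would never end for 0
    while n % 4 == 0:        # strip factors of 4: about (1/2) log_2(n) iterations
        n //= 4
    return n % 8 != 7


print(three_square_possible(60))  # False: 60 = 4 * 15 and 15 = 8 + 7
```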

Sadly, even if you take O(\sqrt{n} + \log_2(n)) time per test, you get TLE.

Here are the two implementations. They are identical, except that the first one checks whether n can be represented as a sum of the squares of 3 integers before searching, and the second one does not.

Code with Legendre check

Link: 1064741540
Time per test: O(\sqrt{n} + \log_2(n))

#include <bits/stdc++.h>

using namespace std;

inline vector<bool> get_primes(const int n) {
  vector<bool> prime(n + 1, true);
  for (int i = 4; i <= n; i += 2) {
    prime[i] = false;
  }
  for (int i = 3; i * i <= n; i += 2) {
    if (prime[i]) {
      for (int j = i * i; j <= n; j += i) {
        prime[j] = false;
      }
    }
  }
  return prime;
}

inline vector<int64_t> get_square_primes(const vector<bool> &prime) {
  vector<int64_t> primes;
  for (int i = 2; i < (int) prime.size(); i++) {
    if (prime[i]) {
      primes.push_back((int64_t) i * i);
    }
  }
  return primes;
}

inline bool is_possible(int64_t n) {
  while ((n & 3) == 0) {
    n >>= 2;
  }
  return (n & 7) != 7;
}

int main() {
  ios_base::sync_with_stdio(false);
  cin.tie(nullptr);
  cout.tie(nullptr);
  const vector<bool> is_prime = get_primes((int) 1e5);
  const vector<int64_t> sq = get_square_primes(is_prime);
  int test_cases;
  cin >> test_cases;
  while (test_cases-- > 0) {
    int64_t n;
    cin >> n;
    if (!is_possible(n)) {
      cout << "No\n";
      continue;
    }
    bool found = false;
    for (int i = 1; (sq[i] << 1) < n - sq[0]; i++) {
      const int64_t left = n - sq[0] - sq[i];
      const int64_t root = (int64_t) sqrt(left);
      if (root * root == left && is_prime[root]) {
        found = true;
        break;
      }
    }
    cout << (found ? "Yes\n" : "No\n");
  }
  return 0;
}
Code without Legendre check

Link: 1064741778
Time per test: O(\sqrt{n})

#include <bits/stdc++.h>

using namespace std;

inline vector<bool> get_primes(const int n) {
  vector<bool> prime(n + 1, true);
  for (int i = 4; i <= n; i += 2) {
    prime[i] = false;
  }
  for (int i = 3; i * i <= n; i += 2) {
    if (prime[i]) {
      for (int j = i * i; j <= n; j += i) {
        prime[j] = false;
      }
    }
  }
  return prime;
}

inline vector<int64_t> get_square_primes(const vector<bool> &prime) {
  vector<int64_t> primes;
  for (int i = 2; i < (int) prime.size(); i++) {
    if (prime[i]) {
      primes.push_back((int64_t) i * i);
    }
  }
  return primes;
}

int main() {
  ios_base::sync_with_stdio(false);
  cin.tie(nullptr);
  cout.tie(nullptr);
  const vector<bool> is_prime = get_primes((int) 1e5);
  const vector<int64_t> sq = get_square_primes(is_prime);
  int test_cases;
  cin >> test_cases;
  while (test_cases-- > 0) {
    int64_t n;
    cin >> n;
    bool found = false;
    for (int i = 1; (sq[i] << 1) < n - sq[0]; i++) {
      const int64_t left = n - sq[0] - sq[i];
      const int64_t root = (int64_t) sqrt(left);
      if (root * root == left && is_prime[root]) {
        found = true;
        break;
      }
    }
    cout << (found ? "Yes\n" : "No\n");
  }
  return 0;
}

I scratched my head all along and just couldn't get this working within the TL during the contest. IMHO, the TL should have been more lenient.

Wouldn't this result in TLE when computing the sums of all pairs, since that takes O(N / \log^2 N)?
So each query would take O(N), right?

Am I missing something here?

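The reason the Legendre version TLEs is N = 0: the loop that strips factors of 4 never terminates on 0, since 0 stays divisible by 4 forever. Here's my submission with the check, which comfortably fits within the TL. It was rude of the setters to include such edge cases, but oh well, gotta read the problem carefully.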

The idea is meet-in-the-middle where we partition triplets (a, b, c) into pairs (a, b) and individual values of c. (a, b) pairs are precomputed once, and for each query, we only iterate over c. There are O(\sqrt N / \log N) values of c to check and checking each takes O(\log N) time, so the complexity per testcase is the same O(\sqrt N) as in the editorial.

The tight part is sorting in precomputation, but you can either:

  • not sort at all, and instead store up to two (a, b) pairs per value of a^2 + b^2 to ensure a \ne c and b \ne c; or
  • use some heuristics to cut a big chunk of pairs as a, b \le \sqrt N doesn’t necessarily mean a^2 + b^2 \le N so you can drop all pairs with the sum of squares too big.
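A Python sketch of the first option, assuming a small bound max_n for illustration (all names are hypothetical). A dictionary maps each value a^2 + b^2 to at most two pairs. Two stored pairs are enough: if two different pairs with the same sum shared a prime x, their other elements would have equal squares and hence be equal, so distinct pairs with equal sums are disjoint and c can appear in at most one of them.

```python
import math


def primes_upto(limit):
    """Sieve of Eratosthenes; returns all primes <= limit."""
    is_p = [True] * (limit + 1)
    is_p[0] = is_p[1] = False
    for i in range(2, math.isqrt(limit) + 1):
        if is_p[i]:
            for j in range(i * i, limit + 1, i):
                is_p[j] = False
    return [x for x in range(limit + 1) if is_p[x]]


def build_table(max_n):
    """Map a^2 + b^2 -> up to two pairs (a, b) of primes a < b, no sorting needed."""
    ps = primes_upto(math.isqrt(max_n))
    table = {}
    for i, a in enumerate(ps):
        for b in ps[i + 1:]:
            s = a * a + b * b
            if s > max_n:
                break                        # larger b only increases s
            bucket = table.setdefault(s, [])
            if len(bucket) < 2:              # a third pair is never needed
                bucket.append((a, b))
    return ps, table


def query(n, ps, table):
    """Is n a sum of squares of three distinct primes?"""
    for c in ps:
        target = n - c * c
        if target <= 0:
            break
        for a, b in table.get(target, ()):
            if c != a and c != b:
                return True
    return False


ps, table = build_table(1000)
print(query(38, ps, table), query(54, ps, table))  # True False
```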

@iceknight1093 The problem does not say that the sum over all test cases is bounded, so the complexity would be O(T \cdot \sqrt N).

But when we compute the square root to find the value of a, won't the complexity pick up an additional logarithmic factor, giving O(T \cdot \sqrt N \cdot \log \sqrt N)? And how does this pass the time limit?

So you would precompute all pairs (a, b) such that a^2 + b^2 \leq 10^{10}, even before we start answering test cases?

Okay this will work.

Ah, I forgot to make a note of that.
Yes, you do need a binary search to compute the square root. However, it is done not for every integer up to \sqrt N, but only for every prime up to \sqrt N, and the prime number theorem tells us that there are \mathcal{O}\left(\frac{\sqrt N}{\log(\sqrt N)}\right) primes \leq \sqrt N. Each binary search costs \mathcal{O}(\log N) = \mathcal{O}(\log \sqrt N), so the two logs cancel out and the true complexity remains \mathcal{O}(\sqrt N).
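A sketch of such a binary-search integer square root (the name isqrt_binary is hypothetical). Each call takes \mathcal{O}(\log x) steps; for x \leq 10^{10} that is about 34 iterations, and it runs once per prime below 10^5, of which there are 9592.

```python
def isqrt_binary(x):
    """Largest r with r * r <= x (x >= 0), by binary search over [0, x]."""
    lo, hi = 0, x
    while lo < hi:
        mid = (lo + hi + 1) // 2     # round up so that lo = mid always makes progress
        if mid * mid <= x:
            lo = mid
        else:
            hi = mid - 1
    return lo


print(isqrt_binary(10**10))      # 100000
print(isqrt_binary(10**10 - 1))  # 99999
```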

I implemented the above approach and got AC. Solution

Thanks a lot!!!

Hey @srinivas1999,
In your solution, the vector sumab is not sorted, right?

sumab.pb(a + b); // Line 210

How come you’re able to do binary_search on it?

bool check(ll c, ll n, int ind) // Line 183
{
  ll sum = n - c * c;
  return binary_search(sumab.begin(), sumab.begin() + ind, sum);
}

Sort them in O(\cfrac{N}{\log^2N}\times\log(\cfrac{N}{\log^2N})) = O(\cfrac{N}{\log N}) time.

Hey @nskybytskyi, are you sure this would pass within the given time limit (1 second)? :thinking:

Technically speaking, std::lower_bound (and, by extension, std::binary_search) only requires the range to be partitioned, not necessarily sorted. Additionally, the set of pairs is generated in a roughly sorted manner: first ascending by a, then ascending by b. However, it is undeniable that some luck was involved in this case.
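To illustrate how a "roughly sorted" array can defeat a binary search, here is a small Python sketch. Python's bisect module stands in for std::binary_search here and has the same requirement. The pair sums are generated ascending by a, then ascending by b, as described above; the value 74 is present but the search misses it, because the larger value 130 sits before it.

```python
import bisect

# Sums a^2 + b^2 for odd primes a < b from {3, 5, 7, 11, 13},
# generated ascending by a, then by b (not globally sorted).
ps = [3, 5, 7, 11, 13]
generated = [a * a + b * b for i, a in enumerate(ps) for b in ps[i + 1:]]


def contains(arr, x):
    """Binary-search membership test; only reliable if arr is ordered around x."""
    i = bisect.bisect_left(arr, x)
    return i < len(arr) and arr[i] == x


print(generated)                        # [34, 58, 130, 178, 74, 146, 194, 170, 218, 290]
print(74 in generated)                  # True
print(contains(generated, 74))          # False: the binary search misses it
print(contains(sorted(generated), 74))  # True once the array is sorted
```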

True, I forgot to sort. Maybe I just got lucky :joy:

Yeah :joy:

Your solution fails for this input:

1
54

Expected Output:

NO

Your Output:

YES

Yeah, so lucky. 54 is a fairly small number; at least one test file could have included all even integers in [0, 998].
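Such an exhaustive range is also easy to check locally. Below is a minimal stress-test sketch in Python (all names hypothetical) that compares the editorial's approach against a brute force over every triple of distinct primes, for every even N in [0, 998]. Primes up to 31 suffice, since 31^2 = 961 \leq 998 < 37^2.

```python
import math
from itertools import combinations


def primes_upto(limit):
    """Sieve of Eratosthenes; returns all primes <= limit."""
    is_p = [True] * (limit + 1)
    is_p[0] = is_p[1] = False
    for i in range(2, math.isqrt(limit) + 1):
        if is_p[i]:
            for j in range(i * i, limit + 1, i):
                is_p[j] = False
    return [x for x in range(limit + 1) if is_p[x]]


def brute(n, ps):
    """Try every triple of distinct primes (ps must contain all primes <= sqrt(n))."""
    return any(a * a + b * b + c * c == n for a, b, c in combinations(ps, 3))


def fast(n, ps):
    """Editorial approach for even n: fix 2, try each odd prime p, check the remainder."""
    ps_set = set(ps)
    rest = n - 4
    for p in ps:
        if p == 2:
            continue
        q2 = rest - p * p
        if q2 <= 0:
            break
        q = math.isqrt(q2)
        if q * q == q2 and q != p and q in ps_set:
            return True
    return False


PS = primes_upto(31)
mismatches = [n for n in range(0, 999, 2) if brute(n, PS) != fast(n, PS)]
print(mismatches)  # []
```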