NONPRIME101 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: raysh07
Tester: apoorv_me
Editorialist: iceknight1093

DIFFICULTY:

Simple

PREREQUISITES:

Primality testing

PROBLEM:

You’re given an array A. Find (the indices of) any pair of its elements whose sum is not prime, or print -1 if no such pair exists.

EXPLANATION:

One of the easiest ways to obtain a sum that isn’t prime is to make that sum even (and greater than 2, of course).

In particular, note that:

  • If A contains at least two even numbers, their sum is an even number larger than 2 and hence not prime.
  • If A contains at least two odd numbers, of which one is \geq 3, their sum is again an even number larger than 2.

Both above cases are easy to check.
If both checks fail, the elements of A are rather restricted in what they can be: there’s at most one even element, and the odd elements are either all 1, or there’s only one odd element.
In particular, there’s at most one distinct even value and at most one distinct odd value.
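The two parity checks can be sketched in Python as follows. This is a minimal sketch, not the official code; `parity_pair` is a hypothetical name, and it assumes every A_i \geq 1.

```python
def parity_pair(a):
    """Return 1-based indices (i, j) of two elements whose sum is even and > 2, or None.

    Assumes every element is >= 1 (hypothetical helper, not from the official solutions).
    """
    evens = [i for i, v in enumerate(a) if v % 2 == 0]
    if len(evens) >= 2:
        # Both are even and >= 2, so their sum is even and >= 4.
        return evens[0] + 1, evens[1] + 1
    odds = [i for i, v in enumerate(a) if v % 2 == 1]
    big = [i for i in odds if a[i] >= 3]
    if big and len(odds) >= 2:
        # Pair an odd element >= 3 with any other odd element: the sum is even and >= 4.
        other = odds[0] if odds[0] != big[0] else odds[1]
        return min(big[0], other) + 1, max(big[0], other) + 1
    return None  # both checks failed
```

If it returns `None`, only the single-primality-check case remains.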

Let x be the even element and y be the odd element. (Note that if x doesn’t exist, we can immediately say that no valid pair exists: assuming N \geq 2, every element is then 1, and 1+1=2 is prime.)
If x+y is not prime, we’ve found our pair; otherwise, no pair exists.
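A sketch of this last step, under the same assumptions as above. The names `is_prime` and `final_case` are hypothetical, and `final_case` is only meant to be called after both parity checks have failed.

```python
def is_prime(n):
    """Trial division; cheap here because n <= 200."""
    if n < 2:
        return False
    d = 2
    while d * d <= n:
        if n % d == 0:
            return False
        d += 1
    return True


def final_case(a):
    """Return 1-based indices of the only candidate pair, or -1.

    Assumes the parity checks failed, so there is at most one even element
    and all odd elements share a single value.
    """
    even = next((i for i, v in enumerate(a) if v % 2 == 0), None)
    odd = next((i for i, v in enumerate(a) if v % 2 == 1), None)
    if even is None or odd is None:
        # No even element (all ones, every sum is 2) or no odd element at all.
        return -1
    if is_prime(a[even] + a[odd]):
        return -1
    return min(even, odd) + 1, max(even, odd) + 1
```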

So, we only really need to perform a single primality check, and since A_i \leq 100 this will be done on a number that’s \leq 200 and hence pretty fast.
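If you’d rather avoid trial division altogether, one alternative (not what the editorial’s code does) is to precompute primality for every value up to 200 with a Sieve of Eratosthenes, relying on the bound A_i \leq 100 stated above. `IS_PRIME` is a hypothetical name.

```python
import math


def sieve(limit):
    """Sieve of Eratosthenes: is_p[k] is True iff k is prime, for 0 <= k <= limit."""
    is_p = [True] * (limit + 1)
    is_p[0] = is_p[1] = False
    for p in range(2, math.isqrt(limit) + 1):
        if is_p[p]:
            # Cross out multiples of p, starting from p * p.
            for m in range(p * p, limit + 1, p):
                is_p[m] = False
    return is_p


# A sum of two elements is at most 100 + 100 = 200.
IS_PRIME = sieve(200)
```

After this, each primality check is a single list lookup such as `IS_PRIME[a[i] + a[j]]`.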


The implementation can be simplified a fair bit with some brute force.

From above, any valid pair must definitely include an element that’s \gt 1.
So,

  1. Fix any index i_1 such that A_{i_1} is even, and try pairing it with every j \neq i_1 from 1 to N.
  2. Fix any index i_2 such that A_{i_2} is odd and greater than 1, and try pairing it with every j \neq i_2 from 1 to N.

If a valid pair exists, it’ll definitely be found by checking just these two indices (to see why, go through the cases we had earlier).
This way, we need to perform at most 2N primality checks, which is not an issue since the numbers being tested are \leq 200.
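The two steps above might be sketched like this in Python. It is a minimal sketch with a hypothetical name (`brute_force_pair`); the author’s C++ code below uses a slightly different loop structure. Indices are returned 1-based.

```python
def is_prime(n):
    """Trial division; cheap because n <= 200."""
    if n < 2:
        return False
    d = 2
    while d * d <= n:
        if n % d == 0:
            return False
        d += 1
    return True


def brute_force_pair(a):
    """Try one even index and one odd index with value > 1 against every other index.

    Returns 1-based indices (i, j) or -1; performs at most 2(N - 1) primality tests.
    """
    n = len(a)
    anchors = []
    even = next((i for i in range(n) if a[i] % 2 == 0), None)
    if even is not None:
        anchors.append(even)
    big_odd = next((i for i in range(n) if a[i] % 2 == 1 and a[i] > 1), None)
    if big_odd is not None:
        anchors.append(big_odd)
    for i in anchors:
        for j in range(n):
            if j != i and not is_prime(a[i] + a[j]):
                return min(i, j) + 1, max(i, j) + 1
    return -1
```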

TIME COMPLEXITY:

\mathcal{O}(N) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

bool prime(int x){
    for (int i = 2; i < x; i++){
        if (x % i == 0) return false;
    }
    return true;
}

void Solve() 
{
    int n; cin >> n;
    vector <int> a(n);
    for (auto &x : a) cin >> x;
    
    for (int x = 0; x < n; x++) if (a[x] != 1){
        for (int y = 0; y < n; y++) if (y != x){
            if (!prime(a[x] + a[y])){
                cout << x + 1 << " " << y + 1 << "\n";
                return;
            }
        }
    }
    
    for (int x = 0; x < n; x++) if (a[x] == 1){
        for (int y = 0; y < n; y++) if (y != x){
            if (!prime(a[x] + a[y])){
                cout << x + 1 << " " << y + 1 << "\n";
                return;
            }
        }
        break;
    }
    
    cout << -1 << "\n";
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}
Editorialist's code (Python)
def prime(n):
    i = 2
    while i*i <= n:
        if n%i == 0: return 0
        i += 1
    return 1

for _ in range(int(input())):
    n = int(input())
    a = list(map(int, input().split()))
    evens, odds = [], []
    mxodd = 1
    for i in range(n):
        if a[i]%2 == 0: evens.append(i)
        else:
            odds.append(i)
            mxodd = max(mxodd, a[i])
    
    if len(evens) >= 2:
        print(evens[0]+1, evens[1]+1)
    elif len(odds) >= 2 and mxodd > 1:
        x, y = 0, 0
        while a[odds[x]] == 1: x += 1
        if x == 0: y = 1
        if x > y: x, y = y, x
        print(min(odds[x], odds[y])+1, max(odds[x], odds[y])+1)
    else:
        if len(evens) == 0: print(-1)
        else:
            x, y = evens[0], odds[0]
            if x > y: x, y = y, x
            if prime(a[x] + a[y]): print(-1)
            else: print(x+1, y+1)
# cook your dish here
import math
def checkprime(k):
    if k==3: return 0
    else:
        for i in range(2,int(math.sqrt(k))+1):
            if k%i==0:
                return 1
        return 0


for _ in range(int(input())):
    n=int(input())
    a=list(map(int,input().split()))
    odd=[]
    even=[]
    one=-1
    
    
    for i in range(len(a)):
        if a[i]%2==0:
            even.append(i+1)
            if len(even)==2:
                print(*even)
                break
        elif a[i]%2!=0 :
            if a[i]==1 and one==-1:
                odd.append(i+1)
                one=1
            elif a[i]!=1:
                odd.append(i+1)
            if len(odd)==2:
                print(*odd)
                break
            
        
    else:
        if one==-1:
            if checkprime(a[0]+a[1]):
                print(1,2)
            else: print(-1)
        else:
            if even==[]: print(-1)
            else:
                if checkprime(a[even[0]-1]+1):
                    print(one,even[0])
                else: print(-1)
Can you please tell me why this solution does not work?

Can someone please analyse why I got TLE with this?

# cook your dish here
primes = set([2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199])

t = int(input())
for _ in range(t):
    n = int(input())
    an = list(map(int, input().split()))
    # anset = {}
    # for idx, i in enumerate(an):
    #     if i not in anset:
    #         anset[i] = []
    #     anset[i].append(idx)
    #     if len(anset[i]) >= 2 and i*2 not in primes:
    #         print(anset[i][0]+1, anset[i][1]+1)
    #         break
    if True:
        printed = False
        evens = []
        odds = []
        for idx, i in enumerate(an):
            if i%2 == 0:
                evens.append((idx, i))
            else:
                odds.append((idx, i))
        # evens = sorted(evens, key=lambda x: x[1])[::-1]
        # odds = sorted(odds, key=lambda x: x[1])[::-1]
        
        if len(evens) >= 2:
            # print("HERE 111")
            for idxx, (idx, i) in enumerate(evens):
                for jdxx, (jdx, j) in enumerate(evens[idxx+1:], idxx+1):
                    if (i + j) != 2:
                        print(idx+1, jdx+1)
                        printed = True
                        break
                if printed:
                    break
            if printed:
                continue

        elif len(odds) >= 2:
            # print("HERE 222")
            for idxx, (idx, i) in enumerate(odds):
                for jdxx, (jdx, j) in enumerate(odds[idxx+1:], idxx+1):
                    if (i + j) != 2:
                        print(idx+1, jdx+1)
                        printed = True
                        break
                if printed:
                   break
            if printed:
                continue
        
        else:
            # print("HERE 333")
            for idxx, (idx, i) in enumerate(evens):
                for jdxx, (jdx, j) in enumerate(odds):
                    # if i%2 == j%2:
                    #     continue
                    if i + j not in primes:
                        print(idx+1, jdx+1)
                        printed = True
                        break
                if printed:
                    break
            if printed:
                continue

        print(-1)

@sajangohil11
If A = [1, 1, 1, 1, \ldots], what do you think the complexity of this piece of code is?
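To make the hint concrete: when every element is 1, all elements land in `odds` and every sum is 2, so the nested loop never breaks early. The hypothetical helper below mimics that loop’s structure (it takes just the values) and counts its inner iterations.

```python
def odd_loop_iterations(odds_values):
    """Count inner-loop iterations of the posted odd-pair loop until a sum != 2 is found."""
    count = 0
    for idxx, i in enumerate(odds_values):
        # Like the original, this slices the list, which copies up to N elements per outer step.
        for j in odds_values[idxx + 1:]:
            count += 1
            if i + j != 2:
                return count
    return count
```

For N ones this is N(N-1)/2 iterations (499500 for N = 1000), on top of the slice copies, so that branch is quadratic. Picking an odd element \geq 3 directly, as the editorial does, avoids the nested loop.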


@meldig23
Your code always has either one = -1 or one = 1, so print(one,even[0]) outputs index 1 even when A_1 \neq 1.
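One way to fix this is to store the position of the 1 rather than a flag. A minimal sketch (the helper name `first_one_index` is hypothetical):

```python
def first_one_index(a):
    """Return the 1-based index of the first element equal to 1, or -1 if there is none."""
    for i, v in enumerate(a):
        if v == 1:
            return i + 1
    return -1
```

The fallback would then print `first_one_index(a)` together with `even[0]` instead of `one`.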


How did you come up with \mathcal{O}(N) time complexity? Finding a prime is O(n) in the worst case. If a != 1 for i = 0 and, say, we only get the answer at n, won’t it be O(n^2)? If the first loop ends, that’s O(n^2), and if a == 1 we check again and in the worst case get the answer at n, so again O(n^2); so the worst case is O(n^3). And if we optimise the primality test to go only up to n/2, it’s probably O(n^2 root2)?

Please read the editorial carefully.
The first solution I presented requires exactly one primality test (and is otherwise linear time); the second one requires \mathcal{O}(N) primality tests.

It’s also important to be precise when talking about complexity; you can’t just use N everywhere.
Here, we’re testing for primes only up to 200, which is completely unrelated to N, so the complexity of each test isn’t “\mathcal{O}(N)”.
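To put a number on it: the editorialist’s `prime` function tries divisors i while i \cdot i \leq n. The sketch below (with a hypothetical helper name) counts how many divisors that loop tries for every n from 2 to 200.

```python
def trial_division_steps(n):
    """Number of divisors tried by a loop of the form: i = 2; while i*i <= n: ..."""
    steps = 0
    i = 2
    while i * i <= n:
        steps += 1
        if n % i == 0:
            break  # found a divisor, the loop returns early
        i += 1
    return steps


worst = max(trial_division_steps(n) for n in range(2, 201))
```

The worst case is 13 divisors (for example n = 199, which tries 2 through 14), so each test costs a constant that doesn’t depend on N.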


Thanks for the solution, but it seems I posted the wrong code.

import math
def checkprime(k):
    if k==3: return 0
    else:
        for i in range(2,int(math.sqrt(k))+1):
            if k%i==0:
                return 1
        return 0


for _ in range(int(input())):
    n=int(input())
    a=list(map(int,input().split()))
    odd=[]
    even=[]
    one=-1
    for i in range(len(a)):
        if a[i]%2==0:
            even.append(i+1)
            if len(even)==2:
                print(*even)
                break
        elif a[i]%2!=0 and a[i]!=1:
            odd.append(i+1)
            if len(odd)==2:
                print(*odd)
                break
            elif len(odd)==1 and one!=-1:
                print(*odd,one)
                break
        else:
            one=i+1
    else:
        if one==-1:
            if checkprime(a[0]+a[1]):
                print(1,2)
            else: print(-1)
        else:
            if even==[]: print(-1)
            else:
                if checkprime(a[even[0]-1]+1):
                    print(one,even[0])
                else: 
                    print(-1)

Here, one stores the index of the element equal to 1.
Then why does it not work?

I can’t understand why you interchange x and y here:

if x > y: x, y = y, x
print(min(odds[x], odds[y])+1, max(odds[x], odds[y])+1)

We can print (x, y) in any order, right?

Consider this input:

1
2
5 1

You print -1, which is wrong: 5 + 1 = 6 is not prime, so 1 2 is a valid answer.
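A quadratic brute force is handy for catching cases like this one: run it alongside your solution on small random arrays. Since any valid pair is accepted, compare only whether -1 was printed, and check that a printed pair really has a non-prime sum. A minimal sketch (the name `reference` is hypothetical):

```python
def is_prime(n):
    """Trial division; cheap because sums are at most 200."""
    if n < 2:
        return False
    d = 2
    while d * d <= n:
        if n % d == 0:
            return False
        d += 1
    return True


def reference(a):
    """Check every pair; return the first 1-based pair with a non-prime sum, or -1."""
    n = len(a)
    for i in range(n):
        for j in range(i + 1, n):
            if not is_prime(a[i] + a[j]):
                return i + 1, j + 1
    return -1
```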


The problem initially required x \lt y, so I wrote my code for that version; it was later updated to allow the pair in any order.
You’re right that in the current state, it isn’t required.