GCD_QUERIES - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Authors: d_k_7386
Tester: tabr
Editorialist: iceknight1093

DIFFICULTY:

2422

PREREQUISITES:

Sieve of Eratosthenes

PROBLEM:

You’re given an array A of N integers, each between 1 and 10^6.
Answer Q queries of the following form on this array:

  • Given an integer x, find the smallest integer y present in A such that \gcd(x, y) \gt 1, print it, and delete it from A.
  • If no such y exists, then print and delete the smallest element of A instead.
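To pin down exactly what each query does, here is a naive reference simulation in Python, the language of the editorialist's code. It is a sketch for cross-checking small inputs, not a solution, since it costs O(N) per query. The name `brute_force` is hypothetical, and it assumes Q ≤ N so the array never runs out.

```python
from math import gcd


def brute_force(a, queries):
    """Naive O(N*Q) simulation of the queries, for cross-checking small inputs."""
    remaining = list(a)
    answers = []
    for x in queries:
        # Smallest remaining y sharing a factor with x, if any exists.
        candidates = [y for y in remaining if gcd(x, y) > 1]
        y = min(candidates) if candidates else min(remaining)
        remaining.remove(y)  # list.remove deletes exactly one occurrence
        answers.append(y)
    return answers
```

For example, on A = [2, 3, 4, 5, 6] with queries 9, 7, 10, 1 this returns [3, 2, 4, 5]: query 7 shares no prime with anything left, so it falls back to the smallest element, and query 1 always does.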

EXPLANATION:

The main observation required to solve this problem is the fact that, if \gcd(x, y) \gt 1, then there must exist a prime number p such that p divides both x and y.

So, to answer a given query x, we can do the following:

  • Quickly find all prime factors of x, say they’re p_1, p_2, \ldots, p_k.
  • For each p_i, find the smallest element of the array that’s divisible by p_i, say y_i.
  • The answer is then simply \min(y_1, y_2, \ldots, y_k).
  • Of course, if none of the y_i exist, simply take the smallest element of the array.

The main hurdle now is doing this quickly. Prime factorization in \mathcal{O}(\sqrt{x}) each time is almost certainly too slow, since we need to factorize many numbers (every array element and every query).
Notice the constraint A_i \leq 10^6, which we’ll use to our advantage.

Prime factorizing small numbers fast

Notice that for this problem, we don’t really need all the factors of a number: we only need its prime factors.
In particular, a number x has at most \log_2{x} distinct prime factors: each prime factor is at least 2, so if x has k distinct prime factors, then 2^k \leq x.
This gives us hope that maybe we can find them all quickly.
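For the range in this problem the bound is even tighter than \log_2 x. The smallest number with k distinct prime factors is the product of the first k primes, so counting how many of those products fit under the limit gives the exact maximum. The helper below (a hypothetical name) does this with trial division, which is fine for such tiny values.

```python
from math import isqrt


def max_distinct_prime_factors(limit):
    """Largest k such that the product of the first k primes is <= limit."""
    count, product, candidate = 0, 1, 2
    while True:
        # Advance candidate to the next prime by trial division (values stay tiny).
        while any(candidate % d == 0 for d in range(2, isqrt(candidate) + 1)):
            candidate += 1
        if product * candidate > limit:
            return count
        product *= candidate
        count += 1
        candidate += 1
```

For limit 10^6 it returns 7, since 2 \cdot 3 \cdot 5 \cdot 7 \cdot 11 \cdot 13 \cdot 17 = 510510 \leq 10^6 < 510510 \cdot 19.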

Since A_i \leq 10^6, let’s do some preprocessing.
Using a (slightly modified) sieve of Eratosthenes, one can find, for each 1 \leq x \leq 10^6, a list of all its prime factors.
This can be done in \mathcal{O}(M\log\log M), where M = 10^6 here.
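A minimal Python sketch of such a sieve, assuming the lists are built once up to some limit M: every prime p appends itself to each of its multiples, which costs about M/p steps per prime and \mathcal{O}(M\log\log M) overall. The tester's code uses the same idea; the function name here is hypothetical.

```python
def prime_factor_lists(limit):
    """factors[x] lists the distinct primes dividing x, in increasing order, for 0 <= x <= limit."""
    factors = [[] for _ in range(limit + 1)]
    for p in range(2, limit + 1):
        if factors[p]:
            # A smaller prime already divides p, so p is composite.
            continue
        # p is prime: record it in every multiple of p (including p itself).
        for multiple in range(p, limit + 1, p):
            factors[multiple].append(p)
    return factors
```

In pure Python a list of lists up to 10^6 is memory-hungry; the editorialist's code instead keeps a single prime factor per number in a flat array and divides repeatedly.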

Once these lists are stored, prime factorizing a number is simple: just iterate through the list!
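If storing a full list per number is too heavy, a common alternative (the setter's code does this) is to store only the smallest prime factor spf[x] and peel primes off one at a time. Each division at least halves x, so factorizing x takes \mathcal{O}(\log x) steps. A sketch, with hypothetical function names:

```python
def smallest_prime_factors(limit):
    """spf[x] = smallest prime dividing x for 2 <= x <= limit (spf[0] = spf[1] = 0)."""
    spf = [0] * (limit + 1)
    for p in range(2, limit + 1):
        if spf[p] == 0:  # no smaller prime divides p, so p is prime
            for multiple in range(p, limit + 1, p):
                if spf[multiple] == 0:
                    spf[multiple] = p
    return spf


def distinct_primes(x, spf):
    """Distinct prime factors of x in increasing order."""
    primes = []
    while x > 1:
        p = spf[x]
        primes.append(p)
        while x % p == 0:  # strip every copy of p
            x //= p
    return primes
```

For example, with spf built up to 1000, distinct_primes(360, spf) gives [2, 3, 5].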

Now that we’re done with the prime factorization, let’s move on to answering queries.
We need to support the following:

  • For a given prime p, find the smallest element of A that’s divisible by p.
  • Delete an element from A, which also requires us to update the information of all primes it’s divisible by.

So, we need a data structure that supports quick insertion/deletion, and quickly finding the minimum.
The easiest way to achieve this is by using std::set/std::multiset in C++ or TreeSet in Java. Coding in Python requires a bit more cleverness, though it’s still doable; see the editorialist’s code below.
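The trick the editorialist's Python code relies on is that after the initial build, elements are only ever deleted, never inserted. So each prime can keep its element indices in a plain list sorted by decreasing value, with the minimum at the end; a deletion just marks an index, and stale indices are popped lazily once they reach the end. A sketch of that idea, with a hypothetical class name:

```python
class LazyMinStack:
    """Hypothetical helper: minimum of a fixed set of element indices, with lazy deletion."""

    def __init__(self, indices, values):
        # Sorted by decreasing value, so the smallest value sits at the end of the list.
        self.stack = sorted(indices, key=lambda i: -values[i])

    def min_index(self, deleted):
        """Index of the smallest not-yet-deleted element, or None if all are gone."""
        # Discard indices already deleted elsewhere; each one is popped at most once.
        while self.stack and deleted[self.stack[-1]]:
            self.stack.pop()
        return self.stack[-1] if self.stack else None
```

Since every index is popped at most once per list, the total cleanup cost is bounded by the total size of all the lists.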

At any rate, pick your data structure of choice and assign one to each prime, using, say, a map.
Then,

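  • For each i, quickly prime factorize A_i using the approach discussed above. For each prime p that divides A_i, insert A_i into its corresponding set. Also insert A_i into one extra set holding every element, which handles the case where no y_i exists.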
  • Notice that this uses \mathcal{O}(N\log{10^6}) memory, since each A_i has \leq \log{A_i} prime factors and so will only be present in that many lists.
  • Then, for each query x,
    • Prime factorize x as discussed.
    • For each prime factor p, look up the smallest remaining element in p's set. The smallest of these candidates (or, if there are none, the smallest element of A) is the answer, and it is the value that must be deleted from the array.
  • To delete an element, simply reverse the insertion process: find each of its prime factors, and delete one copy of the value from each corresponding set.
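Putting these steps together, here is a compact Python sketch. Python has no built-in sorted set, so it swaps std::multiset for heapq min-heaps of (value, index) pairs with lazy deletion, which works because nothing is inserted after the initial build. The function names are hypothetical, and the fallback assumes Q ≤ N (the tester's validator enforces this).

```python
import heapq
from math import isqrt


def build_spf(limit):
    """spf[x] = smallest prime factor of x for 2 <= x <= limit."""
    spf = list(range(limit + 1))
    for p in range(2, isqrt(limit) + 1):
        if spf[p] == p:  # p is prime
            for m in range(p * p, limit + 1, p):
                if spf[m] == m:
                    spf[m] = p
    return spf


def distinct_primes(x, spf):
    """Distinct prime factors of x, in increasing order."""
    primes = []
    while x > 1:
        p = spf[x]
        primes.append(p)
        while x % p == 0:
            x //= p
    return primes


def answer_queries(a, queries, spf):
    """Per-prime min-heaps of (value, index) with lazy deletion, plus one heap of all elements."""
    by_prime = {}  # prime -> heap of (value, index) for the elements it divides
    for i, v in enumerate(a):
        for p in distinct_primes(v, spf):
            by_prime.setdefault(p, []).append((v, i))
    for heap in by_prime.values():
        heapq.heapify(heap)
    everything = [(v, i) for i, v in enumerate(a)]
    heapq.heapify(everything)
    deleted = [False] * len(a)

    def top(heap):
        # Drop entries whose element an earlier query already deleted.
        while heap and deleted[heap[0][1]]:
            heapq.heappop(heap)
        return heap[0] if heap else None

    answers = []
    for x in queries:
        best = None
        for p in distinct_primes(x, spf):
            if p in by_prime:
                cand = top(by_prime[p])
                if cand is not None and (best is None or cand < best):
                    best = cand
        if best is None:
            best = top(everything)  # fallback: smallest remaining element
        deleted[best[1]] = True  # one mark removes this copy from every heap at once
        answers.append(best[0])
    return answers
```

On A = [2, 3, 4, 5, 6] with queries 9, 7, 10, 1 (and spf built up to 100) it returns [3, 2, 4, 5]. Tracking indices rather than values means duplicates are deleted one copy at a time, just like erase(find(...)) on a multiset.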

Each element is inserted into at most \log{A_i} lists and deleted at most once from each, so the complexity of this part is \mathcal{O}(N\log{10^6}\log N); perhaps with an extra \log depending on implementation.
Each query requires us to iterate through at most \log x sets and query each one for its minimum, which takes \mathcal{O}(1) (for std::set, it is just *begin()). So, the complexity of this part is \mathcal{O}(Q\log{10^6}), again possibly with an extra \log from a map.

TIME COMPLEXITY:

\mathcal{O}(M\log\log M) precomputation, followed by \mathcal{O}((N+Q)\log N\log M) per test case, where M = 10^6.

CODE:

Setter's code (C++)
#define ll long long int
#include<bits/stdc++.h>
#define loop(i,a,b) for(ll i=a;i<b;++i)
#define rloop(i,a,b) for(ll i=a;i>=b;i--)
#define in(a,n) for(ll i=0;i<n;++i) cin>>a[i];
#define pb push_back
#define mk make_pair
#define all(v) v.begin(),v.end()
#define dis(v) for(auto i:v)cout<<i<<" ";cout<<endl;
#define display(arr,n) for(int i=0; i<n; i++)cout<<arr[i]<<" ";cout<<endl;
#define fast ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL);srand(time(NULL));
#define l(a) a.length()
#define s(a) (ll)a.size()
#define fr first
#define sc second
#define mod 1000000007
#define endl '\n'
#define yes cout<<"Yes"<<endl;
#define no cout<<"No"<<endl;
using namespace std;
#define debug(x) cerr << #x<<" "; _print(x); cerr << endl;
void _print(ll t) {cerr << t;}
void _print(int t) {cerr << t;}
void _print(string t) {cerr << t;}
void _print(char t) {cerr << t;}
void _print(double t) {cerr << t;}
template <class T, class V> void _print(pair <T, V> p);
template <class T> void _print(vector <T> v);
template <class T> void _print(set <T> v);
template <class T, class V> void _print(map <T, V> v);
template <class T> void _print(multiset <T> v);
template <class T, class V> void _print(pair <T, V> p) {cerr << "{"; _print(p.fr); cerr << ","; _print(p.sc); cerr << "}";}
template <class T> void _print(vector <T> v) {cerr << "[ "; for (T i : v) {_print(i); cerr << " ";} cerr << "]";}
template <class T> void _print(set <T> v) {cerr << "[ "; for (T i : v) {_print(i); cerr << " ";} cerr << "]";}
template <class T> void _print(multiset <T> v) {cerr << "[ "; for (T i : v) {_print(i); cerr << " ";} cerr << "]";}
template <class T, class V> void _print(map <T, V> v) {cerr << "[ "; for (auto i : v) {_print(i); cerr << " ";} cerr << "]";}

ll add(ll x,ll y)  {ll ans = x+y; return (ans>=mod ? ans - mod : ans);}
ll sub(ll x,ll y)  {ll ans = x-y; return (ans<0 ? ans + mod : ans);}
ll mul(ll x,ll y)  {ll ans = x*y; return (ans>=mod ? ans % mod : ans);}
vector<multiset<ll>> present(1e6+1);
vector<ll> spf;
// Sieve: spf[x] = smallest prime factor of x, for 1 <= x <= 1e6.
void fn(){
    spf.assign(1e6+1,0);
    loop(i,0,1e6+1) spf[i] = i;
    loop(i,2,1e6+1){
        if(spf[i] == i) {
            for(int j = i;j<=1e6;j+=i)  spf[j] = min(spf[j],i);
        }
    }
}



// Returns the distinct prime factors of n in increasing order, using spf.
vector<ll> prime_fact(ll n){
    vector<ll> ans;
    while(n>1)  {
        int j = spf[n];
        while(n%j == 0) n/=j;
        ans.pb(j);
    }
    return ans;
}

void solve(){
    ll n;   cin>>n;
    vector<ll> v(n);    in(v,n);
    present[1].clear();
    set<ll> used;
    loop(i,0,n){
        present[1].insert(v[i]);
        vector<ll> fact = prime_fact(v[i]);
        for(auto j:fact)    {
            present[j].insert(v[i]);
            used.insert(j);
        }
    }
    ll q;   cin>>q; 
    while(q--){
        ll x;   cin>>x;
        vector<ll> fact = prime_fact(x);
        ll mi = INT_MAX;
        for(auto i:fact)    
            if(present[i].size() > 0)    mi = min(mi,*present[i].begin());
        if(mi == INT_MAX){
            mi = *present[1].begin();
        }
        fact = prime_fact(mi);
        present[1].erase(present[1].find(mi));
        for(auto i:fact)    {
            present[i].erase(present[i].find(mi));
        }
        cout<<mi<<' ';
    }
    cout<<endl;
    // Clear only the sets used in this test case: when Q < N, some elements are never deleted.
    for(auto i:used)    present[i].clear();
}


int main()
{
    fast
    fn();
    int t; cin>>t;
    while(t--) solve();
    return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

int main() {
    input_checker in;
    const int N = 1e6 + 10;
    vector<vector<int>> f(N);
    vector<bool> ip(N, true);
    for (int i = 2; i < N; i++) {
        if (ip[i]) {
            for (int j = i; j < N; j += i) {
                f[j].emplace_back(i);
                ip[j] = false;
            }
        }
    }
    vector<multiset<int>> st(N);
    int tt = in.readInt(1, 100000);
    in.readEoln();
    int sn = 0;
    while (tt--) {
        int n = in.readInt(1, 2e5);
        sn += n;  // track the sum of N so that the final assert(sn <= 2e5) actually checks it
        in.readEoln();
        auto a = in.readInts(n, 1, 1e6);
        in.readEoln();
        int q = in.readInt(1, n);
        in.readEoln();
        auto b = in.readInts(q, 1, 1e6);
        in.readEoln();
        for (int i = 0; i < n; i++) {
            st[1].emplace(a[i]);
            for (int t : f[a[i]]) {
                st[t].emplace(a[i]);
            }
        }
        for (int x : b) {
            int y = N;
            for (int t : f[x]) {
                if (!st[t].empty()) {
                    y = min(y, *st[t].begin());
                }
            }
            if (y == N) {
                y = *st[1].begin();
            }
            st[1].erase(st[1].find(y));
            cout << y << '\n';
            for (int t : f[y]) {
                st[t].erase(st[t].find(y));
            }
        }
        for (int i = 0; i < n; i++) {
            st[1].clear();
            for (int t : f[a[i]]) {
                st[t].clear();
            }
        }
    }
    assert(sn <= 2e5);
    in.readEof();
    return 0;
}
Editorialist's code (Python)
maxn = 10**6 + 10
prf = [0] * maxn  # prf[x] ends up as the largest prime factor of x: larger primes overwrite smaller ones
for i in range(2, maxn):
	if prf[i] > 0: continue
	for j in range(i, maxn, i):
		prf[j] = i

for _ in range(int(input())):
	n = int(input())
	mapper = {}
	id = 0
	a = list(map(int, input().split()))
	vals = [ ]
	mark = [0]*n
	for i in range(n):
		x = a[i]
		while x > 1:
			p = prf[x]
			if p not in mapper:
				mapper[p] = id
				vals.append([ ])
				id += 1
			vals[mapper[p]].append(i)
			
			while x%p == 0: x //= p
	for i in range(id):
		vals[i].sort(key = lambda x: -a[x])
	
	ord = list(range(n))
	ord.sort(key= lambda x: -a[x])


	q = int(input())
	queries = list(map(int, input().split()))

	for x in queries:
		primes = []
		while x > 1:
			p = prf[x]
			while x%p == 0: x //= p
			primes.append(p)
		
		choice = -1
		for p in primes:
			if p not in mapper: continue
			loc = mapper[p]
			while vals[loc]:
				if mark[vals[loc][-1]]: vals[loc].pop()
				else:
					who = vals[loc][-1]
					if choice == -1 or a[who] < a[choice]: choice = who
					break
		if choice == -1:
			while mark[ord[-1]] == 1: ord.pop()
			choice = ord[-1]
		
		print(a[choice], end = ' ')
		mark[choice] = 1
	print()  # newline after each test case's answers

For editorialist
Please don’t include your entire template in the solution code; include only the part that is necessary. It will help readability.


Hi, I’m the editorialist.

My code (linked at the bottom) will never have templates.
I can’t force the setter/tester to modify their coding styles, and I’m not going to sit and modify their code to remove templates because that’s a lot of effort.
The one thing I can guarantee is that my code won’t have excessive templates; I can’t say anything about others’ code.


I really appreciate your effort, but I use C++.
Still, I will try to understand your code.

The setter used an array of multisets, which I find a little hard to handle. I’m not sure why it doesn’t work without the clearing, since all elements should’ve been removed either way.

I then tried unordered_map<ll, multiset> nums and found it easier:
viewsolution/94381861

Why? You delete one element for each query, but the constraints don’t guarantee Q = N.


Ok, got it, thanks bro 🥰

I want to ask: instead of using vector<multiset<int>>, I used vector<vector<int>> to store the values in decreasing order, indexed by prime factor. Then, for each query, I used the pop_back() method of vector to remove values:

We start with the last values, which are the smallest ones, for each prime factor of the query x.
If that value is not valid (already used up by a previous query), I remove it using pop_back().
Otherwise, I consider it a potential answer, i.e., a candidate for the smallest available value in the array.

Since pop_back() runs in O(1), why does this give TLE + WA?

My Code :- CodeChef: Practical coding for everyone

I don’t know about the WA, but the TLE is because you’re declaring vector<vector<int>>avail(lim+1); for each test case.
Here lim = 1e6 and there can be up to 10^5 test cases; of course that’s too slow.
The easiest way around it is to use map<int, vector<int>> avail; which allocates only the memory it actually needs (that’s what my Python code does, since Python has no built-in sorted set).
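In Python the same idea falls out naturally from a dict keyed by prime, which is how the editorialist's code avoids allocating 10^6 buckets per test case. A minimal sketch with a hypothetical name; it uses trial division only to stay short, where a real solution would use the sieve.

```python
def group_by_prime(a):
    """Map each prime to the indices of elements it divides, sorted by decreasing value.

    The dict only creates buckets for primes that actually occur in this test case.
    """
    groups = {}
    for i, v in enumerate(a):
        x, p = v, 2
        # Trial division keeps the sketch short; the real solution would use a sieve.
        while p * p <= x:
            if x % p == 0:
                groups.setdefault(p, []).append(i)
                while x % p == 0:
                    x //= p
            p += 1
        if x > 1:  # leftover x is a prime factor larger than sqrt(v)
            groups.setdefault(x, []).append(i)
    for indices in groups.values():
        indices.sort(key=lambda i: -a[i])  # smallest value ends up last, ready for pop()
    return groups
```

For A = [6, 10, 15, 1] only three buckets (for 2, 3 and 5) are ever created.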


Ok, that barely passed, but why the WA?
https://www.codechef.com/viewsolution/94401046

Can anyone point out the error or share a test case?

Your sieve was buggy.

for(int i=2;i*i<=lim;++i){
    if(!prime)continue;
    for(int j=2*i;j<=lim;j+=i){
        prime[j] = false;
        all_prime[j].push_back(i);
    }
    all_prime[i].push_back(i);
}

There are 2 mistakes here:

  • if (!prime) continue; should be if (!prime[i]) continue;. You weren’t actually checking whether i was prime!
    if(!prime) doesn’t test any element: prime decays to a pointer to the array, which is never null, so !prime is always false and the continue never runs. Because of this, every i was treated as a prime, so all_prime[i] contained all the factors of i, not just the prime factors (well, not quite all factors, but that’s related to the next point). This is what made it slow.
  • i*i <= lim means you were only considering i up to 1000; larger primes weren’t being considered at all.

Fix these two errors and you should get AC with a runtime of less than a second.
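For reference, here is the same loop shape with both fixes applied, sketched in Python (the original snippet is C++; the variable names are kept, and the function name is hypothetical). After the fix, all_prime[x] holds exactly the distinct primes dividing x, including primes above 1000.

```python
def prime_factor_lists_fixed(lim):
    """all_prime[x] = distinct primes dividing x, built with the snippet's loop shape."""
    is_prime = [True] * (lim + 1)
    all_prime = [[] for _ in range(lim + 1)]
    for i in range(2, lim + 1):  # fix 2: run i all the way to lim, not just while i*i <= lim
        if not is_prime[i]:  # fix 1: test is_prime[i], not the array itself
            continue
        for j in range(2 * i, lim + 1, i):
            is_prime[j] = False
            all_prime[j].append(i)
        all_prime[i].append(i)
    return all_prime
```

With lim = 2018, all_prime[12] is [2, 3] (no longer including 4 or 6) and all_prime[2018] is [2, 1009] (the prime 1009 is no longer missed).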


Oh, thanks a lot, I messed up the sieve!!

Hey, I applied similar logic in Java, but it gives TLE on 3 test cases. Kindly help me.
https://www.codechef.com/viewsolution/94301520

Hey, I’ve implemented this solution but I’m getting WA. Could anyone point out what I’m missing here?

Code: CodeChef: Practical coding for everyone

Thanks in advance!

a.erase(ans); should be a.erase(a.find(ans));

@iceknight1093 Could you please clarify why using a.erase( ans ) is incorrect?

TIA!

erase(value) deletes all instances of that value from the multiset, even if it appears multiple times in the array; erase(find(value)) deletes only one.

The template is part of the code, and many beginners learn new things from templates. I don’t think deleting it would be great. The main code is clearly visible, so what’s the problem?

@iceknight1093 Man, I really appreciate your efforts; no doubt you’re great.


Thanks so much for your efforts