KGCD - Editorial

PROBLEM LINK: The Chosen One

Setter: Abhishek Jugdar
Tester: Aman Singhal
Editorialist: Abhishek Jugdar

DIFFICULTY:

MEDIUM

PREREQUISITES:

Sieve, Principle of Inclusion-Exclusion

PROBLEM:

Given N numbers A_1, A_2, \ldots, A_N, each between 1 and N, consider the GCDs of all pairs of distinct indices (i < j). Find the K-th smallest of these GCDs and print any pair of indices whose elements have that GCD.

EXPLANATION:

We can modify the sieve to compute res[i] = number of pairs whose GCD is exactly i.
First compute freq[v] = number of occurrences of value v in the array.

for (i = n; i > 0; i--) {

x = \sum_{j \le N,\ i \mid j} freq[j]   (the number of elements divisible by i)

res[i] = \frac{x(x-1)}{2} - \sum_{j \le N,\ i \mid j,\ j > i} res[j];
}

Let g = Kth smallest GCD
To find a pair whose GCD is exactly g, take every element divisible by g and divide it by g. This reduces the problem to finding a co-prime pair in the new array.

One way to do this is to process the new array from left to right. For each element x, count how many earlier elements share at least one prime factor with x. This uses inclusion-exclusion over the distinct prime factors of x, which are precomputed with a second (smallest-prime-factor) sieve.

If that count is smaller than the number of earlier elements, some earlier element y is co-prime to x. Take x as the first number and iterate backwards to find y.

Time complexity: O(N log N + N · 2^ω · ω).

Here ω is the maximum number of distinct prime factors of any number from 1 to 10^6, which is 7 (for example, 510510 = 2·3·5·7·11·13·17). With appropriate checks it runs much faster in practice. Refer to the solutions below for implementation details.

SOLUTIONS:

Setter's Solution
#include <bits/stdc++.h>
using namespace std;

// Number of unordered index pairs (i < j) whose gcd is exactly g, for every g in 1..n.
// cnt[v] is the number of occurrences of value v (values are assumed to lie in 1..n).
// g is processed from n down to 1: C(#multiples of g, 2) counts the pairs whose gcd is a
// multiple of g, and subtracting the already-known exact counts of the larger multiples
// leaves exactly the pairs with gcd == g.
static std::vector<long long> count_pairs_by_exact_gcd(const std::vector<int>& cnt, int n)
{
    std::vector<long long> res(n + 1, 0);
    for (int g = n; g > 0; g--) {
        long long multiples = 0;
        long long larger = 0;
        for (int j = g; j <= n; j += g) {
            multiples += cnt[j];
            larger += res[j];  // res[g] itself is still 0 at this point
        }
        res[g] = multiples * (multiples - 1) / 2 - larger;
    }
    return res;
}

// 1-based k-th smallest pairwise gcd, given res[g] = number of pairs with gcd exactly g.
// Returns 0 when k exceeds the total number of pairs.
static int kth_smallest_gcd(const std::vector<long long>& res, long long k)
{
    for (int g = 1; g < (int)res.size(); g++) {
        if (k <= res[g])
            return g;
        k -= res[g];
    }
    return 0;
}

// Smallest-prime-factor table for 0..n: spf[v] is the least prime dividing v
// (spf[p] == p for primes); spf[0] and spf[1] stay 0.
static std::vector<int> smallest_prime_factors(int n)
{
    std::vector<int> spf(n + 1, 0);
    for (int i = 2; i <= n; i++) {
        if (spf[i])
            continue;
        spf[i] = i;
        if (n / i < i)  // i * i > n: nothing left to mark (also avoids int overflow)
            continue;
        for (int j = i * i; j <= n; j += i)
            if (!spf[j])
                spf[j] = i;
    }
    return spf;
}

// Distinct prime factors of x (x >= 1) in increasing order; spf must cover x.
static std::vector<int> distinct_prime_factors(int x, const std::vector<int>& spf)
{
    std::vector<int> primes;
    while (x > 1) {
        const int p = spf[x];
        primes.push_back(p);
        while (x % p == 0)
            x /= p;
    }
    return primes;
}

// Euclid's gcd for non-negative ints.
static int gcd_of(int a, int b)
{
    while (b != 0) {
        const int r = a % b;
        a = b;
        b = r;
    }
    return a;
}

// Returns 0-based indices (first, second) into arr with gcd(arr[first], arr[second]) == g.
// Every multiple of g is divided by g, keeping at most two copies of each value (enough to
// pair a value with itself); this reduces the task to finding a coprime pair. Scanning left
// to right, inclusion-exclusion over the distinct primes of the current value counts how
// many earlier values share a prime with it; if fewer than all of them do, a coprime
// partner exists earlier and is found by scanning backwards.
// Preconditions (asserted): g >= 1 and some pair with gcd exactly g exists.
static std::pair<int, int> find_pair_with_gcd(const std::vector<int>& arr, int g, int n)
{
    assert(g >= 1);  // g == 0 means k exceeded the number of pairs; `% g` would be UB
    const std::vector<int> spf = smallest_prime_factors(n);

    std::vector<std::pair<int, int>> req;  // (arr[i] / g, i)
    std::vector<int> kept(n + 1, 0);       // copies of each value already placed in req
    for (int i = 0; i < (int)arr.size(); i++) {
        if (arr[i] % g != 0 || kept[arr[i]] > 1)
            continue;
        kept[arr[i]]++;
        req.emplace_back(arr[i] / g, i);
    }

    // divisible[d] = number of req entries seen so far that are divisible by squarefree d.
    std::vector<int> divisible(n + 1, 0);
    int first = -1;
    for (int i = 0; i < (int)req.size(); i++) {
        const std::vector<int> primes = distinct_prime_factors(req[i].first, spf);
        const int sz = (int)primes.size();
        int sharing = 0;  // earlier entries sharing at least one prime with req[i]
        for (int mask = 1; mask < (1 << sz); mask++) {
            int d = 1;
            for (int b = 0; b < sz; b++)
                if (mask >> b & 1)
                    d *= primes[b];
            sharing += (__builtin_popcount(mask) & 1) ? divisible[d] : -divisible[d];
            divisible[d]++;  // record req[i] for later entries; distinct masks give distinct d
        }
        if (sharing < i) {
            first = i;
            break;
        }
    }
    assert(first != -1);

    int second = -1;
    for (int i = first - 1; i >= 0; i--) {
        if (gcd_of(req[first].first, req[i].first) == 1) {
            second = i;
            break;
        }
    }
    assert(second != -1);  // guaranteed by sharing < first; never read uninitialised
    return {req[first].second, req[second].second};
}

int main()
{
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    int t;
    std::cin >> t;

    while (t--) {
        int n;
        long long k;
        std::cin >> n >> k;

        std::vector<int> arr(n);
        for (int& x : arr)
            std::cin >> x;

        std::vector<int> cnt(n + 1, 0);  // freq count
        for (int x : arr)
            cnt[x]++;

        const int ans = kth_smallest_gcd(count_pairs_by_exact_gcd(cnt, n), k);
        std::cout << ans << '\n';

        // Part to find two indices having the k-th gcd.
        const std::pair<int, int> idx = find_pair_with_gcd(arr, ans, n);
        assert(gcd_of(arr[idx.first], arr[idx.second]) == ans);
        std::cout << idx.first + 1 << ' ' << idx.second + 1 << '\n';
    }
}


Tester's Solution (Python)
import sys,collections
import math as mt 
 
input = sys.stdin.readline
MAXN = 1000001


spf = [0] * MAXN

def sieve():
    """Fill the module-level ``spf`` list in place with smallest prime factors.

    Afterwards spf[x] is the smallest prime dividing x for 2 <= x < MAXN,
    spf[1] == 1 and spf[0] == 0.
    """
    # Start with every entry equal to itself, then stamp 2 onto the even numbers > 2.
    spf[1:] = range(1, MAXN)
    spf[4::2] = [2] * len(range(4, MAXN, 2))

    # Only odd candidates can still be their own smallest factor; even multiples
    # already carry 2, so each odd prime only needs to visit its odd multiples.
    p = 3
    while p * p < MAXN:
        if spf[p] == p:
            for multiple in range(p * p, MAXN, 2 * p):
                if spf[multiple] == multiple:
                    spf[multiple] = p
        p += 2

sieve()
 
def getFactorization(x):
    """Return the prime factors of x with multiplicity, in non-decreasing order.

    Repeatedly divides by spf[x] from the module-level smallest-prime-factor table,
    so x must be a positive integer below MAXN. Returns [] for x == 1.
    """
    factors = []
    remaining = x
    while remaining != 1:
        smallest = spf[remaining]
        factors.append(smallest)
        remaining //= smallest
    return factors
 
def bitsoncount(x):
    """Number of 1 bits in the binary representation of abs(x)."""
    binary = format(x, 'b')
    return binary.count('1')
 
def coprimepair(C):
    """Find two entries of C, at different positions, whose gcd is 1.

    Returns (v1, v2).
    - v1 is the first element of C for which not every element of C shares a
      prime factor with it (a 1 always qualifies). The count of elements sharing a
      prime is obtained by inclusion-exclusion over v1's distinct prime factors.
    - v2 is the first element at a different position with gcd(v1, v2) == 1.
    If no such v1 exists, v1 falls back to 0 (and v2 to the first 1 found after
    position 0, or 0).
    """
    high = max(C)
    size = len(C)

    occurrences = [0] * (high + 1)
    for value in C:
        occurrences[value] += 1

    # multiples[d] = number of elements of C divisible by d, for d >= 2.
    multiples = [0] * (high + 1)
    for d in range(2, high + 1):
        multiples[d] = sum(occurrences[d::d])

    v1, ind = 0, 0
    for pos, value in enumerate(C):
        primes = list(set(getFactorization(value)))
        sharing = 0  # elements of C divisible by at least one prime of value
        for mask in range(1, 1 << len(primes)):
            product = 1
            odd = False
            for bit, prime in enumerate(primes):
                if mask >> bit & 1:
                    product *= prime
                    odd = not odd
            sharing += multiples[product] if odd else -multiples[product]
        if sharing != size:
            v1, ind = value, pos
            break

    v2 = next((other for pos, other in enumerate(C)
               if pos != ind and mt.gcd(v1, other) == 1), 0)
    return v1, v2
 
#print(getFactorization(24))
 
 
def kgcd(n, A, K):
    """Return the K-th largest gcd over all pairs (i < j) drawn from A[0], ..., A[n-1].

    g runs from max(A) down to 1. exact[g] (pairs whose gcd is exactly g) is
    C(#elements divisible by g, 2) minus exact[] of the larger multiples of g.
    The first g at which the running total of exact[] reaches K is returned;
    if it never does, 1 is returned.
    """
    top = max(A)
    freq = [0] * (top + 1)
    for idx in range(n):
        freq[A[idx]] += 1

    exact = [0] * (top + 1)  # keeps pairs of larger multiples from being counted twice
    running = 0
    for g in range(top, 0, -1):
        divisible = sum(freq[g::g])
        exact[g] = divisible * (divisible - 1) // 2 - sum(exact[g::g])
        running += exact[g]
        if running >= K:
            return g
    return 1
    
def main():
    """Process T test cases from stdin.

    For each case, print the K-th smallest pairwise gcd, then two 1-based indices
    whose elements have that gcd.
    """
    T = int(input())
    for _ in range(T):
        N, K = map(int, input().split())
        A = list(map(int, input().split()))
        total_pairs = N * (N - 1) // 2
        # The K-th smallest gcd is the (total_pairs - K + 1)-th largest one.
        ans = kgcd(N, A, total_pairs - K + 1)
        print(ans)

        # Reduced values A[i] // ans of the multiples of ans, at most 3 copies each.
        kept = dict.fromkeys(range(N + 1), 0)
        C = []
        for i in range(N):
            value = A[i]
            if value % ans == 0 and kept[value // ans] <= 2:
                kept[value // ans] += 1
                C.append(value // ans)

        v1, v2 = coprimepair(C)
        v1, v2 = v1 * ans, v2 * ans
        ind1 = next((i + 1 for i in range(N) if A[i] == v1), -1)
        ind2 = next((i + 1 for i in range(N) if i != ind1 - 1 and A[i] == v2), -1)
        print(ind1, ind2)

main()


For doubts, please leave them in the comment section, I’ll address them.

4 Likes

Could someone explain the setter's implementation of the second part more clearly, especially the part where bitwise operations are used?

1 Like

Why does the second part of the solution work? Could you please explain it, or share some relevant links/articles?

I hope you have understood that, after we build the new array, the problem reduces to finding one co-prime pair in it. So consider any number x in the new array. To check whether it is co-prime to some element before it, we can do this:

Suppose x has 4 distinct prime factors a, b, c, d. Using inclusion-exclusion, count how many of the elements before x are divisible by at least one of a, b, c, d; call this count cnt. Let x be at index i of the new array (0-based indexing).

If cnt < i, then at least 1 of the i elements before x is divisible by none of x's prime factors, i.e. it is co-prime to x. So we can iterate backwards to find that second number. This is what the setter's solution implements.

3 Likes

I have tried to explain the 2nd part of the question in more detail in my comment above. The part with the bitwise operations is nothing but the inclusion-exclusion part. You might want to read more about inclusion-exclusion if you haven’t before.

My WA Code
#include <bits/stdc++.h>
using namespace std;

// Scalar type shorthands: ll is int64_t; ull / lld are unsigned long long / long double.
#define ll int64_t
#define ull unsigned long long
#define lld long double
// Fast stream I/O: disables C stdio synchronisation and unties cin/cout.
#define FIO ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL);
// Container and pair shorthands.
#define pb push_back
#define eb emplace_back
#define ff first
#define ss second
#define vt vector
#define vll vt<ll>
#define pll pair<ll,ll>
#define vpll vt<pll>
#define vvll vt<vll>
// Whole-range and reverse-range iterator pairs for algorithm calls.
#define all(v) v.begin(),v.end()
#define rall(v) v.rbegin(), v.rend()
// Loops: FOR over [0, n), ffo over [a, b] ascending, rfo from a down to b (ll counters).
#define FOR(i,n) for(ll i=0;i<n;i++)
#define ffo(i,a,b) for(ll i=a;i<=b;i++)
#define rfo(i,a,b) for(ll i=a;i>=b;i--)
// NOTE(review): `space` already ends with ';'. After the next line `endl` expands to '\n'
// (no flush), so `std::endl` cannot be written anywhere below it.
#define space cout<<"\n\n";
#define endl '\n' // comment this line in interactive prob
// Max-heap and min-heap priority queues.
template <typename T> using mxpq = priority_queue<T>;
template <typename T> using mnpq = priority_queue<T, vt<T>, greater<T>>;
// Stream x with y digits after the decimal point.
#define fps(x,y) fixed<<setprecision(y)<<x
// Union of sorted ranges a and b inserted into container c.
#define merg(a,b,c) set_union(a.begin(),a.end(),b.begin(),b.end(),inserter(c,c.begin()))
// Fill with memset; uses sizeof(arr), so it only covers the whole object for real arrays.
#define mmss(arr,v) memset(arr,v,sizeof(arr))

// Template constants. Within this program only `mod` is referenced (powerM's default modulus).
constexpr ll mod = 1000000007;    // 10^9 + 7
constexpr ll N = 1000006;         // 10^6 + 6
constexpr ll maxN = 100015;       // 10^5 + 15
constexpr ll MAX_SIZE = 2000006;  // 2 * 10^6 + 6
constexpr ll INF = 0x3f3f3f3f3f3f3f3fll;
constexpr lld PI = 3.14159265359;

// Grid offsets for the four orthogonal moves; index d pairs dx[d] with dy[d].
// Order: up, right, down, left.
int dx[4] = {-1, 0, 1, 0};
int dy[4] = {0, 1, 0, -1};
// Alternative offset tables kept for reference:
//int dx[] = {+1,-1,+0,+0,-1,-1,+1,+1}; // Eight Directions
//int dy[] = {+0,+0,+1,-1,+1,-1,-1,+1}; // Eight Directions
//int dx[]= {-2,-2,-1,1,-1,1,2,2}; // Knight moves
//int dy[]= {1,-1,-2,-2,2,2,-1,1}; // Knight moves
// Reminders:
// - Whole line of input: string s; getline(cin, s);
// - Modular inverse (prime mod): raise to the power mod-2.
// - (a^b) % mod with huge b: replace b by b % (mod-1) (valid when a is not a multiple of mod).

// x^y mod M by binary exponentiation (M defaults to mod). Returns 1 whenever y <= 0.
// Assumes (M-1)^2 fits in ll, since intermediate products are formed before reducing.
ll powerM(ll x, ll y, ll M = mod) { // default argument
	ll result = 1;
	ll base = x % M;
	for (ll e = y; e > 0; e >>= 1) {
		if (e & 1)
			result = (result * base) % M;
		base = (base * base) % M;
	}
	return result;
}

// x^y by binary exponentiation with no modulus (overflows for large results).
// Returns 1 whenever y <= 0.
ll power(ll x, ll y) {
	ll result = 1;
	for (; y > 0; y >>= 1, x *= x)
		if (y & 1)
			result *= x;
	return result;
}

// 0-based index of the highest set bit of x, or -1 when x == 0.
// Negative x is read as its unsigned bit pattern, so the result is then 63.
int largest_bit(long long x) { // based on 0-indexing
	if (x == 0)
		return -1;
	const int leading_zeros = __builtin_clzll(x);
	return 63 - leading_zeros;
}

// Euclid's gcd for non-negative 64-bit values.
static std::int64_t gcd_i64(std::int64_t a, std::int64_t b) {
	while (b != 0) {
		const std::int64_t r = a % b;
		a = b;
		b = r;
	}
	return a;
}

// Smallest-prime-factor table for 0..m (spf[p] == p for primes; spf[0] = spf[1] = 0).
static std::vector<std::int64_t> build_spf(std::int64_t m) {
	std::vector<std::int64_t> spf(m + 1, 0);
	for (std::int64_t p = 2; p <= m; ++p) {
		if (spf[p] != 0) continue;
		for (std::int64_t q = p; q <= m; q += p)
			if (spf[q] == 0) spf[q] = p;
	}
	return spf;
}

// Prints 1-based indices of two elements whose gcd is exactly g and returns true;
// returns false (printing nothing) if no such pair exists.
// myhash[x] = occurrences of value x, position[x] = 1-based indices holding x (x in 1..n).
//
// Picking the smallest multiple of g as one end of the pair is not enough: with n = 9 and
// values {6, 8, 9, 9, ...}, g = 1, the value 6 shares a factor with both 8 and 9 although
// gcd(8, 9) = 1. So every present multiple v*g is tried as one end. Inclusion-exclusion over
// the distinct primes of v counts the elements u*g with gcd(u, v) == 1, and only then is a
// partner searched for.
static bool print_pair_with_gcd(std::int64_t n, std::int64_t g,
                                const std::vector<std::int64_t>& myhash,
                                const std::vector<std::vector<std::int64_t>>& position) {
	const std::int64_t m = n / g;  // multiples of g in 1..n are g, 2g, ..., m*g
	// divisible[d] = number of elements of the form v*g with d | v
	std::vector<std::int64_t> divisible(m + 1, 0);
	for (std::int64_t d = 1; d <= m; ++d)
		for (std::int64_t v = d; v <= m; v += d)
			divisible[d] += myhash[v * g];
	const std::vector<std::int64_t> spf = build_spf(m);

	for (std::int64_t v = 1; v <= m; ++v) {
		if (myhash[v * g] == 0) continue;
		std::vector<std::int64_t> primes;  // distinct primes of v
		for (std::int64_t rest = v; rest > 1;) {
			const std::int64_t p = spf[rest];
			primes.push_back(p);
			while (rest % p == 0) rest /= p;
		}
		// Inclusion-exclusion; the empty mask contributes every element.
		std::int64_t coprime = 0;
		const int k = static_cast<int>(primes.size());
		for (int mask = 0; mask < (1 << k); ++mask) {
			std::int64_t d = 1;
			for (int b = 0; b < k; ++b)
				if (mask >> b & 1) d *= primes[b];
			coprime += (__builtin_popcount(mask) & 1) ? -divisible[d] : divisible[d];
		}
		if (v == 1) --coprime;  // gcd(1, 1) == 1: do not pair this element with itself
		if (coprime <= 0) continue;
		for (std::int64_t u = 1; u <= m; ++u) {
			const std::int64_t needed = (u == v) ? 2 : 1;  // pairing a value with itself needs two copies
			if (myhash[u * g] >= needed && gcd_i64(u, v) == 1) {
				std::cout << position[v * g][0] << " " << position[u * g][u == v ? 1 : 0] << '\n';
				return true;
			}
		}
	}
	return false;
}

// Reads n, k and the array; prints the k-th smallest pairwise gcd and a pair attaining it.
void solve() {
	std::int64_t n, k;
	std::cin >> n >> k;
	// Rank of the answer when all n*(n-1)/2 pairwise gcds are listed from largest to smallest.
	std::int64_t kk = 1 + (n * (n - 1)) / 2 - k;
	std::vector<std::int64_t> a(n);
	std::vector<std::int64_t> myhash(n + 1, 0);
	std::vector<std::int64_t> howMany(n + 1, 0);  // howMany[g] = pairs with gcd exactly g
	std::vector<std::vector<std::int64_t>> position(n + 1);
	for (std::int64_t i = 0; i < n; i++) {
		std::cin >> a[i];
		myhash[a[i]]++;
		position[a[i]].push_back(i + 1);
	}
	for (std::int64_t i = n; i >= 1; --i) {
		std::int64_t c = 0, rm = 0;
		for (std::int64_t j = i; j <= n; j += i) {
			c += myhash[j];
			rm += howMany[j];
		}
		howMany[i] = (c * (c - 1)) / 2 - rm;
		kk -= howMany[i];
		if (kk <= 0) {
			std::cout << i << '\n';
			print_pair_with_gcd(n, i, myhash, position);
			return;
		}
	}
}

// Entry point: reads the number of test cases, then runs solve() once per case.
// When compiled with -DLOCAL, stdin/stdout are redirected to in1.txt / out1.txt.
int main()
{
#ifdef LOCAL
	freopen("in1.txt", "r", stdin);
	freopen("out1.txt", "w", stdout);
#endif
	FIO;
	int total_cases = 1;
	cin >> total_cases;
	int caseno = 0;
	while (caseno < total_cases) {
		++caseno;
		// cout << "Case #" << caseno << ": ";
		solve();
	}
	return 0;
}

@darshancool25

I wrote this program after understanding the editorial; only my way of printing the indices is a little different.

I pick up two elements that are divisible by the GCD and check whether their GCD equals the one we require.

My code goes wrong in the way it prints the indices, but I am unable to come up with a counter test case.