PRIMESETMUL- Editorial

PROBLEM LINK:

Practice
Div-2 Contest
Div-1 Contest

Author: Vishesh Saraswat
Tester: Istvan Nagy
Editorialist: Aujasvit Datta

DIFFICULTY:

Easy-Medium

PREREQUISITES:

Math, Sorting, Binary Search, Brute force

PROBLEM:

Given a set S of N prime numbers and a number M, let D be a positive integer not exceeding M whose prime factors all belong to S. We have to find the number of possible values of D.

QUICK EXPLANATION:

We can divide the set of primes into two disjoint sets of about N/2 primes each. For each set, we generate by brute force all numbers not exceeding M whose prime factors lie in that set. We then sort one of the two lists and, for each element A of the other list, binary search for how many numbers in the sorted list are \leq floor(M/A), adding that count to the final answer.

EXPLANATION:

We sort the set S and create two disjoint sets A_1 and A_2 with about N/2 primes each. A good way to do this is to put alternate elements of the sorted S into A_1 and the rest into A_2 (the reason is given with the complexity analysis later in the editorial). So, for example, A_1 should contain all elements at even places in the sorted S and A_2 should contain all elements at odd places.
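As a minimal sketch of the split described above, the primes can be sorted and dealt alternately into two vectors. The helper name split_alternate is hypothetical and is not part of any of the solutions in this editorial.

```cpp
#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical helper: sorts the primes, then puts the elements at even
// (0-based) positions into the first vector and the rest into the second.
std::pair<std::vector<long long>, std::vector<long long>>
split_alternate(std::vector<long long> s) {
    std::sort(s.begin(), s.end());
    std::vector<long long> a1, a2;
    for (std::size_t i = 0; i < s.size(); ++i) {
        (i % 2 == 0 ? a1 : a2).push_back(s[i]);
    }
    return {a1, a2};
}
```

For S = {7, 2, 5, 3, 11} this gives A_1 = {2, 5, 11} and A_2 = {3, 7}, so the small and large primes are spread over both halves.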

Now, for each set A_1 and A_2, we generate the numbers that can be formed from numbers in only that set and satisfy the constraints given in the problem. In other words, we solve our original problem on both sets A_1 and A_2 separately with the only difference that we also store the numbers that satisfy the given constraints in sets L_1 and L_2 respectively. This can be done using brute force. The code given below performs this operation:

void generate_mul(int curr = 1, int pos = 0) {
	if(curr > m) return; // if curr is greater than m then it can't satisfy the required constraints and will not be part of the solution

	if(pos >= a.size()) { // we have gone through the vector a so curr is added to l
		l.push_back(curr);
		return;
	}
	
	int temp = curr;

	generate_mul(curr, pos + 1); //if we don't multiply curr with the number at position pos in a

	while(temp <= m / a[pos]) { // multiply only while the product stays <= m (comparing by division avoids overflow)
		temp *= a[pos];
		generate_mul(temp, pos + 1);
	}
}

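By running this code twice (once with a and l standing for A_1 and L_1, once for A_2 and L_2), we can generate L_1 and L_2. Note that the initial value of curr is 1, so 1 will be present in both L_1 and L_2. Also note that int must be a 64-bit type here (the full solution below uses #define int long long), since M can be as large as 10^17.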
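The generation step can also be sketched without global variables. The version below is an illustration only, not the editorialist's code: the name generate_products and its parameter list are assumptions made for this example. It compares against m / primes[pos] before multiplying, so the running product never exceeds m and cannot overflow.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper: appends to `out` every number cur * (product of powers
// of primes[pos..]) that is <= m. Assumes 1 <= cur <= m on entry.
void generate_products(const std::vector<long long>& primes, long long m,
                       std::vector<long long>& out,
                       long long cur = 1, std::size_t pos = 0) {
    if (pos == primes.size()) {  // every prime has been assigned an exponent
        out.push_back(cur);
        return;
    }
    long long val = cur;         // cur * primes[pos]^0
    while (true) {
        generate_products(primes, m, out, val, pos + 1);
        if (val > m / primes[pos]) break;  // next power would exceed m
        val *= primes[pos];
    }
}
```

Calling it with primes {2, 3} and m = 10 fills the output with 1, 3, 9, 2, 6, 4, 8. The list is not sorted, which is why the editorial sorts it before binary searching.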

After this, we sort one of the two lists (in this editorial, L_1). Now, for each element D_2 in L_2, we count the number of elements D_1 in L_1 such that D_1 \leq floor(M/D_2) using binary search. Each such pair gives a valid number D_1 * D_2: it satisfies D_1 * D_2 \leq M, and since D_1 and D_2 are formed from elements of S, the prime factorization of D_1 * D_2 contains only primes from S. Because A_1 and A_2 are disjoint, every valid D can be written as D_1 * D_2 in exactly one way, so nothing is counted twice. The following code performs this task:

int number_of_possible_values() {
	sort(l1.begin(), l1.end());
	int ans = 0;
	for(auto i: l2) {
		ans += upper_bound(l1.begin(), l1.end(), m / i) - l1.begin(); // number of elements of l1 that are <= m/i (integer division already floors; going through floor() would round large values as doubles)
	}
	return ans; 
}

In the above code, since 1 will be present in both L_1 and L_2, all the numbers in L_1 and L_2 (which already satisfy the constraints of the problem) will also get counted in ans. As a result, ans will contain the final answer for the test case.
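To see the counting step on concrete numbers, here is a small sketch that takes the two lists as arguments. The function name count_products and its signature are assumptions made for this example; the logic follows the binary search described above.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical helper: counts pairs (d1, d2) with d1 in l1, d2 in l2 and
// d1 * d2 <= m. All values are assumed to be positive.
long long count_products(std::vector<long long> l1,
                         const std::vector<long long>& l2, long long m) {
    std::sort(l1.begin(), l1.end());
    long long ans = 0;
    for (long long d2 : l2) {
        // upper_bound points at the first element > m / d2, so its offset
        // is the number of elements <= m / d2.
        ans += std::upper_bound(l1.begin(), l1.end(), m / d2) - l1.begin();
    }
    return ans;
}
```

With S = {2, 3} and M = 10, the lists are L_1 = {1, 2, 4, 8} and L_2 = {1, 3, 9}. The three elements of L_2 contribute 4, 2 and 1 pairs respectively, for a total of 7, matching the valid numbers 1, 2, 3, 4, 6, 8, 9.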

The time complexity of this solution is approximately O(P log P + Q log P), with P and Q being the lengths of L_1 and L_2 respectively (sorting L_1, then one binary search in L_1 per element of L_2). According to the constraints of the problem, max(P, Q) \leq 2*10^6. By using the strategy outlined at the start of the editorial (taking alternate elements of the sorted S for A_1 and A_2), we keep P and Q comparable, which keeps the value P log P + Q log P small.
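If both lists are sorted, the binary searches can be replaced by a single two-pointer sweep, which is the idea the tester's solution and the reader's comment below appear to rely on. The sketch below is an assumption-laden illustration (the name count_two_pointers is hypothetical): the sorting still costs O(P log P + Q log Q), but the counting itself becomes O(P + Q).

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical helper: same result as the binary search version, computed
// with two pointers. All values are assumed to be positive.
long long count_two_pointers(std::vector<long long> l1,
                             std::vector<long long> l2, long long m) {
    std::sort(l1.begin(), l1.end());
    std::sort(l2.begin(), l2.end());
    long long ans = 0;
    std::size_t j = l2.size();  // l2[0..j) are the partners that still fit
    for (long long d1 : l1) {
        // As d1 grows, m / d1 shrinks, so j only ever moves left.
        while (j > 0 && l2[j - 1] > m / d1) --j;
        ans += static_cast<long long>(j);
    }
    return ans;
}
```

Because j never moves right, the inner loop runs at most Q times in total over the whole sweep.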

SOLUTIONS:

Setter's Solution
#include "bits/stdc++.h"
using namespace std;
/*
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;
using ordered_set = tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update>;
*/

#define all(x) begin(x), end(x)
#define rall(x) rbegin(x), rend(x)
#define sz(x) (int)(x).size()

using ll = long long;
#define int ll
const int mod = 1e9+7;

int mx = 0;
vector<int> p, primes;
int cnt = 0;

void generate_mul(int cur, int pos) {
    if (pos == sz(primes)) {
        p.push_back(cur);
        ++cnt;
 //     if (cnt%10000==0) cerr << cnt << '\n';
        return;
    }
    int acur = cur;
    int curp = 1;
    for (int pwr = 0; ; ++pwr) {
        cur = acur * curp;
        generate_mul(cur, pos+1);
        if (cur > mx / primes[pos]) break; // next power would exceed mx (checked by division to avoid overflow)
        curp *= primes[pos];
    }
}

void solve(int tc) {
    int n;
    int m;
    cin >> n >> m;
    mx = m;
    vector<int> pr(n);
    for (auto &x : pr)
        cin >> x;
    sort(all(pr));
    for (int x : pr)
        cerr << x << " ";
    cerr << '\n';
    for (int i = 0; i < n; i+=2)
        primes.push_back(pr[i]);
    for (int x : primes)
        cerr << x << " ";
    cerr << '\n';
    generate_mul(1, 0);
    auto p1 = p;
    sort(all(p1));
    p.clear(); primes.clear();
    for (int i = 1; i < n; i+=2)
        primes.push_back(pr[i]);
    generate_mul(1, 0);
    auto p2 = p;
    sort(all(p2));
    p.clear(); primes.clear();
    int ans = 0;
    for (int x : p1)
        ans += upper_bound(all(p2), m/x) - begin(p2);
    cout << ans << '\n';
}

signed main() {
    cin.tie(0)->sync_with_stdio(0);
    int tc = 1;
    cin >> tc;
    for (int i = 1; i <= tc; ++i) solve(i);
    return 0;
}
Tester's Solution
#include <iostream>
#include <algorithm>
#include <string>
#include <cassert>
#include <vector>
#include <set>
using namespace std;

#ifdef HOME
#define NOMINMAX
#include <windows.h>
#endif

long long readInt(long long l, long long r, char endd) {
	long long x = 0;
	int cnt = 0;
	int fi = -1;
	bool is_neg = false;
	while (true) {
		char g = getchar();
		if (g == '-') {
			assert(fi == -1);
			is_neg = true;
			continue;
		}
		if ('0' <= g && g <= '9') {
			x *= 10;
			x += g - '0';
			if (cnt == 0) {
				fi = g - '0';
			}
			cnt++;
			assert(fi != 0 || cnt == 1);
			assert(fi != 0 || is_neg == false);

			assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
		}
		else if (g == endd) {
			assert(cnt > 0);
			if (is_neg) {
				x = -x;
			}
			assert(l <= x && x <= r);
			return x;
		}
		else {
			//assert(false);
		}
	}
}

string readString(int l, int r, char endd) {
	string ret = "";
	int cnt = 0;
	while (true) {
		char g = getchar();
		assert(g != -1);
		if (g == endd) {
			break;
		}
		cnt++;
		ret += g;
	}
	assert(l <= cnt && cnt <= r);
	return ret;
}
long long readIntSp(long long l, long long r) {
	return readInt(l, r, ' ');
}
long long readIntLn(long long l, long long r) {
	return readInt(l, r, '\n');
}
string readStringLn(int l, int r) {
	return readString(l, r, '\n');
}
string readStringSp(int l, int r) {
	return readString(l, r, ' ');
}

bool isPrime(uint64_t v)
{
	for (uint64_t i = 2; i * i <= v; ++i)
	{
		if (v % i == 0)
			return false;
	}
	return true;
}

int main() {
#ifdef HOME
	if (IsDebuggerPresent())
	{
		freopen("../in.txt", "rb", stdin);
		freopen("../out.txt", "wb", stdout);
	}
#endif
	int T = readIntLn(1, 2);
	for (int tc = 0; tc < T; ++tc)
	{
		int N = readIntSp(1, 20);
		uint64_t M = readIntLn(1, 100'000'000'000'000'000ull);
		std::vector<uint64_t> v1({ 1 }), v2({1});
		std::set<uint64_t> ss;
		for (uint32_t i = 0; i < N; ++i)
		{
			uint64_t si = 0;
			if (i + 1 != N)
				si = readIntSp(2, 1000);
			else
				si = readIntLn(2, 1000);
			//assert(ss.count(si) == 0);
			//assert(isPrime(si));
			//ss.insert(si);
			std::vector<uint64_t>& v = i & 1 ? v1 : v2;
			for (size_t j = 0; j < v.size(); ++j)
			{
				if (v[j] > M / si)	// v[j] * si would exceed M (checked by division to avoid overflow)
					continue;
				v.push_back(v[j] * si);
			}
			sort(v.begin(), v.end());
		}
		for (auto& v2i : v2)
			v2i = M / v2i;
		reverse(v2.begin(), v2.end());
		uint64_t res = 0;
		for (size_t i = 0, j = 0; i < v1.size() && j < v2.size();++i)
		{
			while (j < v2.size() && v1[i] > v2[j])
				++j;
			res += v2.size() - j;
		}
		printf("%llu\n", res);
	}
	assert(getchar() == -1);
}
Editorialist's Solution
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define pb push_back
#define mp make_pair
#define pii pair<int, int> 
#define fr first
#define sc second
#define all(a) a.begin(),a.end()


int n, m;
vector <int> a1, a2, l1, l2, s;

void generate_mul_1(int curr = 1, int pos = 0) {
	if(curr > m) return; // if curr is greater than m then it can't satisfy the required constraints and will not be part of the solution

	if(pos >= a1.size()) { // we have gone through the vector a so curr is added to l
		l1.push_back(curr);
		return;
	}
	
	int temp = curr;

	generate_mul_1(curr, pos + 1); //if we don't multiply curr with the number at position pos in a

	while(temp <= m / a1[pos]) { // multiply only while the product stays <= m (comparing by division avoids overflow)
		temp *= a1[pos];
		generate_mul_1(temp, pos + 1);
	}
}

void generate_mul_2(int curr = 1, int pos = 0) {
	if(curr > m) return; // if curr is greater than m then it can't satisfy the required constraints and will not be part of the solution

	if(pos >= a2.size()) { // we have gone through the vector a so curr is added to l
		l2.push_back(curr);
		return;
	}
	
	int temp = curr;

	generate_mul_2(curr, pos + 1); //if we don't multiply curr with the number at position pos in a

	while(temp <= m / a2[pos]) { // multiply only while the product stays <= m (comparing by division avoids overflow)
		temp *= a2[pos];
		generate_mul_2(temp, pos + 1);
	}
}


int number_of_possible_values() {
	sort(l1.begin(), l1.end());
	int ans = 0;
	for(auto i: l2) {
		ans += upper_bound(l1.begin(), l1.end(), m / i) - l1.begin(); // number of elements of l1 that are <= m/i (integer division already floors)

		//since 1 will be present in both l1 and l2, all the numbers in l1 and l2 will also get counted in ans
	}


	return ans; 
}





signed main() {
	ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL);
	int t; cin >> t;

	while(t--) {
		cin >> n >> m;
		a1.clear(); a2.clear(); l1.clear(); l2.clear(); s.clear();

		for(int i = 0; i < n; i++) {
			int k; cin >> k;
			s.push_back(k);
		}

		sort(s.begin(), s.end());

		for(int i = 0; i < n; i += 2) {
			a1.pb(s[i]);
		}

		for(int i = 1; i < n; i += 2) {
			a2.pb(s[i]);
		}

		generate_mul_1(); // find all possible values which satisfy condition in a1
		generate_mul_2(); // find all possible values which satisfy condition in a2

		cout << number_of_possible_values() << endl;
	}	
	return 0;
}

In my approach, the values in valueSet (generated in solve()) come out already sorted, so we do not need to sort them afterwards as discussed in the video editorial.

The approach is similar to the one discussed in method 2 here :

#include <bits/stdc++.h>
using namespace std;

#define pb push_back
#define ff first
#define ss second
#define ll unsigned long long
const ll mod = 1e9 + 7;
const ll N = 1e6 + 10;

vector<ll> dq, valueSet1, valueSet2;
ll m;

vector<ll> solve(vector<pair<ll, ll>> &s)
{
dq = vector<ll>(1, 1);
ll n = s.size();

while (*dq.rbegin() <= m)
{
    ll mn = LLONG_MAX;
    for (int i = 0; i < n; i++)
    {
        mn = min(mn, dq[s[i].ss] * s[i].ff);
    }

    dq.push_back(mn);

    for (ll i = 0; i < n; i++)
    {
        if (mn == dq[s[i].ss] * s[i].ff)
        {
            s[i].ss++;
        }
    }
}

while (*dq.rbegin() > m)
    dq.pop_back();

return dq;

}

int main()
{
ios_base::sync_with_stdio(0);
cin.tie(0);

ll t;
cin >> t;

while (t--)
{
    ll n;
    cin >> n >> m;

    vector<pair<ll, ll>> s(n), set1, set2;

    for (ll i = 0; i < n; i++)
    {
        cin >> s[i].ff; // prime
        s[i].ss = 0;    // curr_index
    }

    sort(s.begin(), s.end());

    for(ll i = 0; i < n; i += 2)
        set1.pb(s[i]);

    for(ll i = 1; i < n; i += 2)
        set2.pb(s[i]);    

    valueSet1 = solve(set1);  
    valueSet2 = solve(set2);

    ll start = 0, end = valueSet2.size()-1;
    ll ans = 0;

    while (start < valueSet1.size() && end >= 0)
    {
        if(valueSet1[start]*valueSet2[end] <= m)
        {
            ans += (end + 1);
            start++;
        }

        else
            end--;
    }
      
    cout << ans << "\n";
}

return 0;

}

can you provide the proof for max(P, Q) <= 2e6 part :sweat_smile:

Great problem, loved it.