DSP - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: kingmessi
Tester: tabr
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Sieve of Eratosthenes

PROBLEM:

The digit space of an integer x consists of all other integers that have the same multiset of digits as x, without leading zeros.
Define f(x, y) to be the largest prime number dividing both x and y, and 1 if no such prime exists.

Given x and y, compute the maximum possible value of f(a, b) across all pairs (a, b) such that a is in the digit space of x, and b is in the digit space of y.

EXPLANATION:

The key observation here is that since we’re dealing with integers that are \lt 10^7, they have at most 7 digits.
So, for any integer x, its digit space has size at most 7! = 5040, which is pretty small.

This means, for any x, we can always directly find every element of its digit space by brute force.
For instance, extract the digits of x, sort them, and iterate across all possible permutations of these digits; then convert each permutation back to an integer.
Iterating across all permutations can generally be done using a library function: for example, next_permutation in C++ and itertools.permutations in Python.
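As a rough illustration of this brute force, here is a minimal Python sketch. The helper name digit_space is made up for this example, and it assumes (as the reference solutions below do) that x itself counts as part of its own digit space.

```python
from itertools import permutations

def digit_space(x):
    """Distinct integers formed by rearranging the digits of x, without leading zeros."""
    digits = sorted(str(x))
    values = set()  # a set removes duplicates caused by repeated digits
    for perm in permutations(digits):
        if perm[0] == '0':
            continue  # skip arrangements with a leading zero
        values.add(int(''.join(perm)))
    return sorted(values)

print(digit_space(120))  # [102, 120, 201, 210]
```

For a 7-digit input this loop runs 7! = 5040 times, which is why the direct approach is affordable.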

This way, we can obtain two lists D_x and D_y, denoting the digit spaces of x and y respectively.
Our next task is to find the maximum common prime factor between two elements of these lists.

For this, we’ll need to be able to prime factorize integers quickly.
This can be done with the help of a sieve.

How?

Let \text{pr}_x be the smallest prime number dividing x.
This can be computed for every x from 2 to M using the following algorithm.

  1. Let \text{pr}_x = 0 for all x initially.
  2. Iterate x from 2 to M.
    • If \text{pr}_x \gt 0, x is not a prime (it has a prime factor), ignore it.
    • Otherwise, x is a prime. Iterate over all multiples y of x, and if \text{pr}_y = 0, set it to x.

Iterating only across the multiples of primes up to M has a complexity of \mathcal{O}(M\log\log M): each prime p contributes about M/p steps, and the sum of 1/p over primes p \leq M grows like \log\log M.
In our case, M = 10^7, so this is pretty fast.
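The sieve described above might look roughly like the following in Python. This is only a sketch: the names smallest_prime_factors and prime_factors are hypothetical, the loop starts at 2 because 1 is not prime, and the demo uses a small limit instead of 10^7.

```python
def smallest_prime_factors(limit):
    """spf[v] is the smallest prime dividing v for 2 <= v <= limit (spf[0] = spf[1] = 0)."""
    spf = [0] * (limit + 1)
    for x in range(2, limit + 1):
        if spf[x] > 0:
            continue  # x already has a smaller prime factor, so it is composite
        for y in range(x, limit + 1, x):
            if spf[y] == 0:
                spf[y] = x  # x is the first prime found to divide y
    return spf

def prime_factors(n, spf):
    """Distinct prime factors of n in increasing order, found by repeated division."""
    factors = []
    while n > 1:
        p = spf[n]
        factors.append(p)
        while n % p == 0:
            n //= p
    return factors

spf = smallest_prime_factors(100)
print(prime_factors(84, spf))  # [2, 3, 7]
```

Once the table is built, each factorization takes only as many divisions as the number has prime factors (counted with multiplicity).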


Further, note that any number \lt 10^7 can have at most 8 distinct prime factors, since the product of the first 9 primes exceeds 10^7.
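The bound can be checked with a couple of multiplications; the snippet below is just that arithmetic, written in Python for convenience.

```python
from math import prod

first_primes = [2, 3, 5, 7, 11, 13, 17, 19, 23]

eight = prod(first_primes[:8])  # product of the first 8 primes
nine = prod(first_primes)       # product of the first 9 primes

print(eight)  # 9699690, which is below 10**7
print(nine)   # 223092870, which is above 10**7
```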

So, let’s create a list P_x of the prime factors of all elements of D_x, which has size bounded by 8\cdot |D_x| \leq 8\cdot 7! = 8!.
In fact, it can be empirically verified that for x \lt 10^7, |P_x| \leq 3907.
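One possible way to build P_x is sketched below. The helper names build_spf and prime_factor_set are assumptions made for this example, and the sieve limit is kept tiny for the demo; a real solution would sieve once up to 10^7.

```python
from itertools import permutations

def build_spf(limit):
    """Smallest-prime-factor table for 2..limit."""
    spf = [0] * (limit + 1)
    for x in range(2, limit + 1):
        if spf[x] == 0:
            for y in range(x, limit + 1, x):
                if spf[y] == 0:
                    spf[y] = x
    return spf

def prime_factor_set(x, spf):
    """Set of all primes dividing at least one element of the digit space of x."""
    primes = set()
    for perm in set(permutations(str(x))):
        if perm[0] == '0':
            continue  # leading zeros are not allowed
        n = int(''.join(perm))
        while n > 1:
            primes.add(spf[n])
            n //= spf[n]
    return primes

spf = build_spf(1000)
print(sorted(prime_factor_set(12, spf)))  # [2, 3, 7]
```

Here the digit space of 12 is {12, 21}, and 12 = 2^2 * 3 while 21 = 3 * 7, which gives the three primes shown.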

Similarly, create a list P_y of prime factors of the elements of D_y, and then compute the maximum element of their intersection.
Computing the maximum element of the intersection can be done quickly in several ways: for example, you can just check whether each element of P_y exists in P_x (using, say, binary search - or even a map).
However, note that the ‘brute force’ method of directly checking every element of P_x against every element of P_y will be too slow, and can TLE.
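A set-based version of this final step could look like the sketch below. The function name max_common is hypothetical, and hashing is just one of the fast options mentioned above (sorting plus binary search or a two-pointer merge would also work).

```python
def max_common(px, py):
    """Largest value present in both collections, or 1 if they share nothing."""
    common = set(px) & set(py)  # hash-based intersection, roughly linear time
    return max(common, default=1)

print(max_common([2, 3, 7], [3, 5, 7]))  # 7
print(max_common([2, 3], [5, 7]))        # 1
```

The default of 1 mirrors the definition of f(x, y) when no common prime exists.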

TIME COMPLEXITY:

An upper bound is \mathcal{O}(d! \cdot 8 + 4000\cdot \log(4000)) per testcase, where d = \max(\text{len}(x), \text{len}(y)) \leq 7, after a one-time \mathcal{O}(M\log\log M) sieve with M = 10^7.

CODE:

Author's code (C++)
//Har Har Mahadev
#include<bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp> // Common file
#include <ext/pb_ds/tree_policy.hpp>
#define ll long long
#define int long long
#define rep(i,a,b) for(int i=a;i<b;i++)
#define rrep(i,a,b) for(int i=a;i>=b;i--)
#define repin rep(i,0,n)
#define di(a) int a;cin>>a;
#define precise(i) cout<<fixed<<setprecision(i)
#define vi vector<int>
#define si set<int>
#define mii map<int,int>
#define take(a,n) for(int j=0;j<n;j++) cin>>a[j];
#define give(a,n) for(int j=0;j<n;j++) cout<<a[j]<<' ';
#define vpii vector<pair<int,int>>
#define sis string s;
#define sin string s;cin>>s;
#define db double
#define be(x) x.begin(),x.end()
#define pii pair<int,int>
#define pb push_back
#define pob pop_back
#define ff first
#define ss second
#define lb lower_bound
#define ub upper_bound
#define bpc(x) __builtin_popcountll(x) 
#define btz(x) __builtin_ctz(x)
using namespace std;

using namespace __gnu_pbds;

typedef tree<int, null_type, less<int>, rb_tree_tag,tree_order_statistics_node_update> ordered_set;
typedef tree<pair<int, int>, null_type,less<pair<int, int> >, rb_tree_tag,tree_order_statistics_node_update> ordered_multiset;

const long long INF=1e18;
const long long M=1e9+7;
const long long MM=998244353;
  
int power( int N, int M){
    int power = N, sum = 1;
    if(N == 0) sum = 0;
    while(M > 0){if((M & 1) == 1){sum *= power;}
    power = power * power;M = M >> 1;}
    return sum;
}

const int N = 10000005;
vector<int> lp(N+1);
bool cnt[N+1];
vector<int> pr;

#define SIEVE

void sieve(){
    for (int i=2; i <= N; ++i) {
        if (lp[i] == 0) {
            lp[i] = i;
            pr.push_back(i);
        }
        for (int j = 0; i * pr[j] <= N; ++j) {
            lp[i * pr[j]] = pr[j];
            if (pr[j] == lp[i]) {
                break;
            }
        }
    }
}

vector<int> pf(int x)
{
    vector<int> ret;
    while (x != 1)
    {
        ret.push_back(lp[x]);
        x = x / lp[x];
    }
    return ret;
}

 
void solve()
{
    int X,Y;
    cin >> X >> Y;
    assert(X >= 1);
    assert(Y >= 1);
    assert(X < 1e7);
    assert(Y < 1e7);
    vi v1,v2;
    int tx = X;
    while(tx){
    	v1.pb(tx%10);
    	tx/=10;
    }
    int ty = Y;
    while(ty){
    	v2.pb(ty%10);
    	ty/=10;
    }
    sort(be(v1));
    sort(be(v2));
    vi v;

    do{
    	if(v1.back()){
    		int num = 0;
    		int t = 1;
    		for(auto x : v1){
    			num += x*t;
    			t *= 10;
    		}
    		for(auto x : pf(num)){
    			if(!cnt[x]){
    				cnt[x] = 1;
    				v.pb(x);
    			}
    		}
    	}

    }while(next_permutation(be(v1)));

    int ans = 1;

    do{
    	if(v2.back()){
    		int num = 0;
    		int t = 1;
    		for(auto x : v2){
    			num += x*t;
    			t *= 10;
    		}
    		for(auto x : pf(num)){
    			if(cnt[x])ans = max(ans,x);
    		}
    	}

    }while(next_permutation(be(v2)));

    cout << ans << "\n";

    for(auto x : v)cnt[x] = 0;

}

signed main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    #ifdef NCR
        init();
    #endif
    #ifdef SIEVE
        sieve();
    #endif
    di(t)
    assert(t <= 200);
    assert(t >= 1);
    while(t--)
        solve();
    return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;

#define IGNORE_CR

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
#ifdef IGNORE_CR
            if (c == '\r') {
                continue;
            }
#endif
            buffer.push_back((char) c);
        }
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            assert(!isspace(buffer[pos]));
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

struct sieve {
    vector<bool> is_prime;
    vector<int> min_factor;

    sieve(int MAX = 1e7 + 10) {
        is_prime = vector<bool>(MAX, true);
        min_factor = vector<int>(MAX);
        is_prime[0] = is_prime[1] = false;
        min_factor[0] = min_factor[1] = 1;
        for (int i = 2; i < MAX; i++) {
            if (!is_prime[i]) {
                continue;
            }
            min_factor[i] = i;
            if ((long long) i * i >= MAX) {
                continue;
            }
            for (int j = i * i; j < MAX; j += i) {
                if (is_prime[j]) {
                    is_prime[j] = false;
                    min_factor[j] = i;
                }
            }
        }
    }

    vector<pair<int, int>> factor(int n) const {
        vector<pair<int, int>> res;
        while (n != 1) {
            int p = min_factor[n];
            res.emplace_back(p, 0);
            while (p == min_factor[n]) {
                n /= p;
                res.back().second++;
            }
        }
        reverse(res.begin(), res.end());
        return res;
    }
};

vector<int> pf(int t, const sieve& si) {
    vector<int> p;
    while (t > 0) {
        p.emplace_back(t % 10);
        t /= 10;
    }
    sort(p.begin(), p.end());
    vector<int> res;
    do {
        if (p[0] == 0) {
            continue;
        }
        int k = 0;
        for (int i : p) {
            k *= 10;
            k += i;
        }
        for (auto d : si.factor(k)) {
            res.emplace_back(d.first);
        }
    } while (next_permutation(p.begin(), p.end()));
    sort(res.begin(), res.end());
    res.resize(unique(res.begin(), res.end()) - res.begin());
    return res;
}

int main() {
    input_checker in;
    int tt = in.readInt(1, 200);
    in.readEoln();
    sieve si;
    while (tt--) {
        int x = in.readInt(1, 1e7 - 1);
        in.readSpace();
        int y = in.readInt(1, 1e7 - 1);
        in.readEoln();
        auto a = pf(x, si);
        auto b = pf(y, si);
        int ans = 1;
        int i = 0, j = 0;
        int n = (int) a.size(), m = (int) b.size();
        while (i < n && j < m) {
            if (a[i] == b[j]) {
                ans = a[i];
                i++;
                j++;
            } else if (a[i] < b[j]) {
                i++;
            } else {
                j++;
            }
        }
        cout << ans << '\n';
    }
    in.readEof();
    return 0;
}
Editorialist's code (Python)
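N = 10**7 + 10
# lpf[v] ends up holding the largest prime factor of v, since each later prime overwrites earlier ones
lpf = [0]*N
for i in range(2, N):
    if lpf[i] > 0: continue  # i already has a prime factor, so it is composite
    for j in range(i, N, i):
        lpf[j] = i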

def get(x):
    s = sorted(list(str(x)))
    facs = set()
    facs.add(1)
    
    import itertools
    for p in itertools.permutations(s):
        if p[0] == '0': continue
        num = int(''.join(c for c in p))
        while num > 1:
            facs.add(lpf[num])
            num //= lpf[num]
    return facs

for _ in range(int(input())):
    x, y = map(int, input().split())
    facsx = get(x)
    facsy = get(y)
    print(max(facsx.intersection(facsy)))

Doesn’t sieving in Python as in Editorialist’s code take too long? Is this problem even solvable in Python under the given time limit?

I guess that when time limits are set for most problems, Python is rarely taken into consideration; PyPy is assumed to be used by everybody.
So I suggest you use PyPy to submit, at least on CodeChef.

Thanks for the tip!

As was mentioned above, it’s almost always better to use PyPy over Python for competitive programming: the syntax is exactly the same, but it’s generally much faster.

In this case, my code from the editorial runs in about 1.5s when submitted in PyPy3, which is well below the time limit.

However, submitting that exact code in Python3 will still get AC, actually.
On Codechef, several languages receive a multiplier to the time limit to account for them being slower on average - Java and PyPy receive a 2x multiplier, Python gets 5x.
So in this case, the stated time limit is 3 seconds, but for Java and PyPy your program only needs to run within twice that (6 seconds), and in Python 5 times that (so 15 seconds).

Thanks for the explanation! I was not aware of these language-specific multipliers on Codechef.