MAKEIT1 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: Vikram Singla
Preparer: Abhinav Sharma
Testers: Nishank Suresh, Tejas Pandey
Editorialist: Nishank Suresh

DIFFICULTY:

2875

PREREQUISITES:

Depth first search, dynamic programming, sieve of Eratosthenes/some fast prime factorization method.

PROBLEM:

You have an integer N that you want to make equal to 1. You also have M pairs of the form (a, b), where a is prime. In one move, you can:

  • Pick a prime p that divides N
  • Divide N by p
  • Multiply N by b for every pair of the form (p, b).

Find out whether N can be made equal to 1, and if it can, the minimum number of moves needed to do so.

EXPLANATION:

First, note that we can think of this problem entirely in terms of primes. A move on the prime a removes one copy of a from the prime factorization of N, and every pair (a, b) then adds all the prime factors of b (with multiplicity).
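To make the prime view concrete, here is a small C++ sketch that stores N as a map from prime to exponent and applies a single move. The name `added` and its layout (for each prime p, the prime factors of every b paired with p, with multiplicity) are assumptions for illustration; nothing here is taken from the reference codes.

```cpp
#include <cassert>
#include <map>
#include <vector>

// N stored as prime -> exponent; primes with exponent 0 are erased.
using Factorization = std::map<int, int>;

// Applies one move using the prime p and returns the new N.
// `added` (hypothetical layout): added[p] lists the prime factors, with
// multiplicity, of every b over all pairs (p, b).
Factorization applyMove(Factorization f, int p,
                        const std::map<int, std::vector<int>> &added) {
    assert(f.count(p) > 0);               // a move needs p to divide N
    if (--f[p] == 0) f.erase(p);          // divide N by p
    auto it = added.find(p);
    if (it != added.end())
        for (int q : it->second) ++f[q];  // multiply N by every b paired with p
    return f;
}
```

For the made-up pairs (2, 9) and (3, 5), one move on 2 turns 12 = 2^2 \cdot 3 into 54 = 2 \cdot 3^3.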

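Next, note that the order of moves does not change how many are needed. Every copy of every prime that ever appears in the factorization has to be removed by exactly one move, and removing a copy of p always adds the same primes, no matter when it happens. So if there is a way to reduce N to 1, every way of doing so costs the same number of moves.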
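One way to convince yourself of this on tiny inputs is a brute-force simulation that always picks either the smallest or the largest prime currently dividing N and counts the moves. This is only a sanity check under stated assumptions: the `added` layout is hypothetical, the exponent map can blow up on real inputs, and the `limit` on moves is a crude stand-in for detecting that the process never ends.

```cpp
#include <cassert>
#include <map>
#include <vector>

// Brute force for tiny inputs only. f is N as prime -> exponent; added[p]
// (hypothetical layout) lists the prime factors, with multiplicity, of every b
// paired with p. Each move uses the smallest (or largest) prime dividing N.
// Returns the number of moves, or -1 once `limit` moves have been made
// without reaching 1 (a crude stand-in for "this never ends").
long long simulate(std::map<int, int> f,
                   const std::map<int, std::vector<int>> &added,
                   bool smallestFirst, long long limit = 1000000) {
    long long moves = 0;
    while (!f.empty()) {                  // f is empty exactly when N == 1
        if (moves == limit) return -1;
        int p = smallestFirst ? f.begin()->first : f.rbegin()->first;
        assert(f[p] > 0);
        if (--f[p] == 0) f.erase(p);      // divide N by p
        auto it = added.find(p);
        if (it != added.end())
            for (int q : it->second) ++f[q];  // multiply in each b paired with p
        ++moves;
    }
    return moves;
}
```

With the made-up pairs (2, 9) and (3, 5) and N = 12, both orders take 12 moves.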

Thus, we have to find two things:

  • Is it possible to reduce N to 1?
  • If it is, find the cost of doing so.

Let’s look at a specific prime, say p. What does it cost to remove one copy of p from the prime factorization?
Based on the pairs we have, removing one copy of p adds some other primes to the factorization. Let these primes, with multiplicity, be p_1, p_2, \ldots, p_k. Each of them must then be removed as well, which may add further primes, and so on.

So, if cost_p denotes the number of moves needed to remove one copy of p from the factorization (including everything that removal brings in), we have

cost_p = 1 + \sum_{i=1}^k cost_{p_i}
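As a small worked example (the pairs are made up, not taken from the problem's tests), take the pairs (2, 9) and (3, 5). Removing a 2 adds 9 = 3^2, removing a 3 adds 5, and 5 appears in no pair, so the recurrence unwinds as

cost_5 = 1, \quad cost_3 = 1 + cost_5 = 2, \quad cost_2 = 1 + cost_3 + cost_3 = 5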

With dynamic programming (memoization) this runs quickly, in \mathcal{O}(V+E), where V \leq 10^6 is the number of vertices (one per possible prime) and E is the number of edges in the graph, since each cost_p only needs to be computed once. The final answer is the sum of cost_p over all primes in the prime factorization of N, counted with multiplicity. For example, for N = 12 = 2^2 \cdot 3 the answer is cost_2 + cost_2 + cost_3. These values can get very large, so the reference codes below print the answer modulo 998244353.
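A minimal sketch of the memoized computation, assuming `adj[p]` (a hypothetical layout) lists the prime factors, with multiplicity, of every b paired with p, and assuming no cycle is reachable (handling cycles is a separate step). Results are taken modulo a caller-supplied `mod`, as the reference codes do with 998244353. N is factorized by trial division here only to keep the sketch short.

```cpp
#include <cassert>
#include <vector>

// adj[p] (hypothetical layout): prime factors, with multiplicity, of every b
// over all pairs (p, b); it needs an entry for every prime that can appear.
// Assumes no cycle is reachable from p. cost[p] == -1 means "not computed yet".
long long costOf(int p, const std::vector<std::vector<int>> &adj,
                 std::vector<long long> &cost, long long mod) {
    if (cost[p] != -1) return cost[p];  // memoized: each cost_p computed once
    long long ret = 1;                  // the move that removes p itself
    for (int q : adj[p]) ret = (ret + costOf(q, adj, cost, mod)) % mod;
    return cost[p] = ret;
}

// Sum of cost_p over the prime factors of n, counted with multiplicity.
long long totalMoves(int n, const std::vector<std::vector<int>> &adj,
                     long long mod) {
    assert(n >= 1);
    std::vector<long long> cost(adj.size(), -1);
    long long ans = 0;
    for (int p = 2; (long long)p * p <= n; ++p)
        while (n % p == 0) {
            n /= p;
            ans = (ans + costOf(p, adj, cost, mod)) % mod;
        }
    if (n > 1) ans = (ans + costOf(n, adj, cost, mod)) % mod;  // leftover prime
    return ans;
}
```

With the made-up pairs (2, 9) and (3, 5), stored as adj[2] = {3, 3} and adj[3] = {5}, `totalMoves(12, adj, 998244353)` returns cost_2 + cost_2 + cost_3 = 5 + 5 + 2 = 12.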

The only time this breaks down is when we run into a cycle, that is, when removing p eventually leads to p being added back to N. In that case, bringing N down to 1 is impossible, and the answer is -1.

Note that you have to be careful about one thing: the graph may contain a cycle that is not reachable from any prime in the factorization of N. Such a cycle does not matter, and the answer is not -1 in that case. A trivial example is N = 1: the graph can be arbitrary, and the answer is 0.
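One way to add the cycle check is the standard three-colour DFS, sketched below under the same assumed `adj` layout: a vertex is coloured 1 while it is on the current DFS path, so reaching a colour-1 vertex again means a reachable cycle. Because the DFS only starts from primes dividing N, cycles elsewhere in the graph are never visited. The struct and function names are hypothetical, and since the recursion can get deep on long chains, a real submission may need an iterative DFS or a larger stack.

```cpp
#include <cassert>
#include <vector>

// Cost DP with cycle detection via three colours:
// 0 = unvisited, 1 = on the current DFS path, 2 = finished.
// adj[p] (hypothetical layout) lists the prime factors, with multiplicity,
// of every b over all pairs (p, b).
struct CostSolver {
    const std::vector<std::vector<int>> &adj;
    std::vector<int> color;
    std::vector<long long> cost;
    long long mod;
    bool cycle = false;

    CostSolver(const std::vector<std::vector<int>> &g, long long m)
        : adj(g), color(g.size(), 0), cost(g.size(), 0), mod(m) {}

    long long dfs(int u) {
        if (color[u] == 2) return cost[u];              // already computed
        if (color[u] == 1) { cycle = true; return 0; }  // back edge: u re-added
        color[u] = 1;
        long long ret = 1;
        for (int v : adj[u]) ret = (ret + dfs(v)) % mod;
        color[u] = 2;
        return cost[u] = ret;
    }
};

// Answer for n, or -1 if a cycle is reachable from a prime dividing n.
long long movesOrMinusOne(int n, const std::vector<std::vector<int>> &adj,
                          long long mod) {
    assert(n >= 1);
    CostSolver s(adj, mod);
    long long ans = 0;
    for (int p = 2; (long long)p * p <= n; ++p)
        while (n % p == 0) {
            n /= p;
            ans = (ans + s.dfs(p)) % mod;
        }
    if (n > 1) ans = (ans + s.dfs(n)) % mod;
    return s.cycle ? -1 : ans;
}
```

With adj[2] = {3, 3}, adj[3] = {5} and a self-loop adj[7] = {7}, N = 12 gives 12, N = 14 gives -1, and N = 1 gives 0 even though the graph contains a cycle.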

Note that the above solution requires computing the prime factorization of a number quickly. This can be done with the help of the sieve of Eratosthenes as follows:

  • Let spf_i denote the smallest prime factor of i
  • spf_i can be computed with a sieve
  • To factorize x, do the following:
    • If x = 1, stop
    • Otherwise, let p = spf_x; p is a prime factor of x.
    • Divide x by p and continue.
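The steps above can be sketched in C++ as follows. This version marks multiples of each prime in increasing order, which is a simple (non-linear) variant of the sieve; the function names `buildSpf` and `factorize` are illustrative.

```cpp
#include <cassert>
#include <vector>

// spf[i] = smallest prime factor of i, for 2 <= i <= limit.
std::vector<int> buildSpf(int limit) {
    std::vector<int> spf(limit + 1, 0);
    for (int i = 2; i <= limit; ++i) {
        if (spf[i] != 0) continue;                  // composite, already marked
        for (long long j = i; j <= limit; j += i)   // i is prime: mark multiples
            if (spf[j] == 0) spf[j] = i;
    }
    return spf;
}

// Prime factors of x with multiplicity, in non-decreasing order.
std::vector<int> factorize(int x, const std::vector<int> &spf) {
    assert(x >= 1 && x < (int)spf.size());
    std::vector<int> res;
    while (x > 1) {        // stop once x = 1
        int p = spf[x];    // p is a prime factor of x
        res.push_back(p);
        x /= p;            // divide and continue
    }
    return res;
}
```

For example, `factorize(12, buildSpf(100))` gives {2, 2, 3}.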

N has at most \log_2 N prime factors (counted with multiplicity), since each of them is at least 2, so the above process takes \mathcal{O}(\log N) time to prime factorize N. This also bounds the number of edges in the graph: each of the M pairs adds at most \log_2(10^6) < 20 edges, so there are fewer than 20\cdot M edges in the worst case.
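To double-check the bound of fewer than 20 prime factors, the sketch below (an illustrative helper, not part of any reference code) counts prime factors with multiplicity for every x up to a limit using \Omega(x) = \Omega(x / spf_x) + 1 and returns the largest count. For a limit of 10^6 it reports 19, reached for example at 2^{19} = 524288, since 2^{20} = 1048576 already exceeds 10^6.

```cpp
#include <cassert>
#include <vector>

// Largest number of prime factors, counted with multiplicity, of any
// x <= limit. Uses a smallest-prime-factor sieve and the identity
// Omega(x) = Omega(x / spf[x]) + 1 with Omega(1) = 0.
int maxPrimeFactorCount(int limit) {
    assert(limit >= 1);
    std::vector<int> spf(limit + 1, 0), omega(limit + 1, 0);
    int best = 0;
    for (int i = 2; i <= limit; ++i) {
        if (spf[i] == 0)                          // i is prime
            for (long long j = i; j <= limit; j += i)
                if (spf[j] == 0) spf[j] = i;
        omega[i] = omega[i / spf[i]] + 1;
        if (omega[i] > best) best = omega[i];
    }
    return best;
}
```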

TIME COMPLEXITY:

\mathcal{O}(L\log L + M\log L), where L = 10^6 is the upper bound on N and on each b.
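A rough breakdown of that bound, taking L = 10^6 as the limit on N and on each b: the sieve in the editorialist's code visits every multiple of every i \geq 2, which is a harmonic sum; factorizing the M values of b with the spf table costs \mathcal{O}(M\log L); each pair contributes at most 19 edges; and the memoized DFS touches each vertex and edge once. (The linear sieve in the preparer's code brings the first term down to \mathcal{O}(L).)

\sum_{i=2}^{L} \frac{L}{i} \le L \ln L = \mathcal{O}(L\log L), \qquad V \le L, \quad E \le 19M, \qquad \mathcal{O}(V + E) = \mathcal{O}(L + M\log L)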

CODE:

Preparer's code (C++)
#include <bits/stdc++.h>
using namespace std;
 
 
/*
------------------------Input Checker----------------------------------
*/
 
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
 
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }
 
            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }
 
            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
 
 
/*
------------------------Main code starts here----------------------------------
*/
 
const int MAX_T = 1e5;
const int MAX_N = 1e5;
const int MAX_SUM_LEN = 1e5;
 
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define ff first
#define ss second
#define mp make_pair
#define ll long long
#define rep(i,n) for(int i=0;i<n;i++)
#define rev(i,n) for(int i=n;i>=0;i--)
#define rep_a(i,a,n) for(int i=a;i<n;i++)
#define pb push_back
#define int ll
 
int sum_n = 0, sum_m = 0;
int max_n = 0, max_m = 0;
int yess = 0;
int nos = 0;
int total_ops = 0;
ll mod = 998244353;

const ll MX=1000005;
ll lp[MX]={0};
vector<int> pr;
 
void pre(){
    rep_a(i,2,MX){
        if(lp[i]==0){
            lp[i]=i;
            pr.pb(i);
 
        }
        for(int j=0;j<(int)pr.size() && pr[j]<=lp[i] && i*pr[j]<MX;j++){
            lp[i*pr[j]]=pr[j];
        }
    }
}

vector<vector<pair<int,int> > >g;
vector<int> dp, col;
bool poss;
void dfs(int curr){
    col[curr]=1;
    dp[curr] = 1;
    for(auto h:g[curr]){
        if(col[h.ff]==1){
            poss = 0;
            return;
        }
        else if(!col[h.ff]){
            dfs(h.ff);
        }
        dp[curr] += (h.ss*dp[h.ff])%mod;
        dp[curr] %= mod;
    }
    col[curr]=2;

    if(curr==1) dp[curr] = 0;
}

void solve(){
    int n = readIntSp(1, 1e6);
    int m = readIntLn(0, 1e5);

    g.assign(MX, vector<pair<int,int> >());
    dp.assign(MX,0);
    col.assign(MX,0);

    int cnt, tmp;
    

    int a,b;
    rep(i,m){
        a = readIntSp(1,1e6);
        b = readIntLn(1,1e6);

        while(b>1){
            cnt = 0;
            tmp = lp[b];
            while(b%tmp==0){
                b/=tmp;
                cnt++;
            }

            g[a].pb(mp(tmp, cnt));
        }
    }

    poss = 1;
    int ans = 0;
   // for(auto h:g[13]) cout<<h.ff<<" "<<h.ss<<'\n';
    while(n>1){
        cnt = 0;
        tmp = lp[n];
        while(n%tmp==0){
            n/=tmp;
            cnt++;
        }

        dfs(tmp);
        if(!poss) break;

        //cout<<tmp<<" "<<dp[tmp]<<'\n';
        ans += (cnt*dp[tmp])%mod;
        ans%=mod;
    }

   // rep(i, 1e6) cout<<dp[i]<<'\n';

    if(poss) cout<<ans<<'\n';
    else cout<<-1<<'\n';



}
 
signed main()
{

    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r" , stdin);
    freopen("output.txt", "w" , stdout);
    #endif
    fast;
    
    pre();
    int t = 1;
    
    for(int i=1;i<=t;i++)
    {    
       solve();
    }
   
    assert(getchar() == -1);

    cerr<<"SUCCESS\n";


    cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
}
Tester's code (C++)
#include <bits/stdc++.h>
#define mod 998244353
#define maxn 100007
using namespace std;

bool isp[maxn];
vector<int> primes;

void seive() {
    for(int j = 2; j < maxn; j++) {
        if(isp[j]) continue;
        primes.push_back(j);
        for(int k = 2*j; k < maxn; k += j)
            isp[k] = 1;
    }
}

vector<int> factors(int n) {
    vector<int> res;
    int now = 0;
	long long int p = primes[now];
	while(p*p <= n) {
	    while(n%p == 0) n /= p, res.push_back(p);
	    now++;
	    p = primes[now];
	}
	if(n > 1) res.push_back(n);
	return res;
}

int main() {
    seive();
	int n, m;
	cin >> n >> m;
	map<int, int> mp;
	int now = 0;
	vector<int> vals;
	long long int p = primes[now];
	while(p*p <= n) {
	    int cnt = 0;
	    while(n%p == 0) n /= p, cnt++;
	    mp[p] = cnt;
	    if(cnt) vals.push_back(p);
	    now++;
	    p = primes[now];
	}
	if(n > 1) mp[n] = 1, vals.push_back(n);
	vector<pair<int, int>> valid;
	for(int j = 0; j < m; j++) {
	    int a, b;
	    cin >> a >> b;
	    vector<int> fc = factors(b);
	    vals.push_back(a);
	    for(int k = 0; k < fc.size(); k++) {
	        valid.push_back({a, fc[k]});
	        vals.push_back(fc[k]);
	    }
	}
	sort(vals.begin(), vals.end());
	vals.resize(distance(vals.begin(), unique(vals.begin(), vals.end())));
	vector<int> ed[vals.size() + 1];
	for(int i = 0; i < valid.size(); i++) {
	    int a = valid[i].first, b = valid[i].second;
	    a = upper_bound(vals.begin(), vals.end(), a) - vals.begin();
	    b = upper_bound(vals.begin(), vals.end(), b) - vals.begin();
	    ed[a].push_back(b);
	}
	vector<int> topo;
	int flag = 0;
	int vis[vals.size() + 1];
	memset(vis, 0, sizeof(vis));
	function<void(int)> dfs = [&](int cur) {
	    vis[cur] = 1;
	    for(int i = 0; i < ed[cur].size(); i++) {
	        if(vis[ed[cur][i]]) {
	            if(vis[ed[cur][i]] == 1) flag = 1; // back edge: cycle
	        }
	        else dfs(ed[cur][i]);
	    }
	    vis[cur] = 2;
	    topo.push_back(cur);
	};
	for(int i = 1; i <= vals.size(); i++) if(!vis[i]) dfs(i);
	reverse(topo.begin(), topo.end());
	memset(vis, 0, sizeof(vis));
	long long int ans = 0;
	int bad = 0;
	for(int i = 0; i < topo.size(); i++) {
	    long long int v1 = mp[vals[topo[i] - 1]];
	    vis[topo[i]] = 1;
	    ans += v1;
	    ans %= mod;
	    for(int j = 0; j < ed[topo[i]].size(); j++) {
	        if(v1 && vis[ed[topo[i]][j]]) bad = 1;
	        int now = vals[ed[topo[i]][j] - 1];
	        mp[now] += v1;
	        mp[now] %= mod;
	    }
	}
	if(topo.size() != vals.size() || bad) cout << "-1\n";
	else cout << ans << "\n";

	return 0;
}
Editorialist's code (C++)
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
	ios::sync_with_stdio(false); cin.tie(0);

	const int LIM = 1e6 + 5, mod = 998244353;
	vector<int> spf(LIM);
	for (int i = LIM-1; i >= 2; --i) {
		for (int j = i; j < LIM; j += i) spf[j] = i;
	}
	
	int n, m; cin >> n >> m;
	vector<vector<int>> adj(LIM);
	for (int i = 0; i < m; ++i) {
		int a, b; cin >> a >> b;
		while (b > 1) {
			int p = spf[b];
			b /= p;
			adj[a].push_back(p);
		}
	}
	int ans = 0, bad = 0;
	vector<int> cost(LIM, -1), mark(LIM);
	auto calc = [&] (const auto &self, int u) -> ll {
		if (cost[u] != -1) return cost[u];
		if (mark[u] == 1 or bad) {
			bad = 1;
			return 0;
		}
		mark[u] = 1;
		ll ret = 1;
		for (int v : adj[u]) {
			ret += self(self, v);
			ret %= mod;
		}
		mark[u] = 2;
		return cost[u] = ret;
	};

	while (n > 1) {
		int p = spf[n];
		ans += calc(calc, p); ans %= mod;
		n /= p;
	}
	if (bad) cout << -1 << '\n';
	else cout << ans << '\n';
}

This editorial is for the wrong question… If you go on practice and go to Rocket Pack, this is the editorial it links to.
