Prim's implementation help!

Can anyone point out what could be wrong in this simple implementation of Prim’s algorithm?

void prims (unordered_map<int,set<pair<int,int>>> g, int n) {
        
    vector<int> mstSet; // in MST
    vector<bool> visited(n,false);
    unordered_map<int,int> V; // vertex key array
    int u = 0;
    V[u] = 0;
    mstSet.push_back(u);
    visited[u]=true;

    while(mstSet.size()!=n) {
        for(auto it:g[u]) {
            // w < v.key then v.key = w
            if(it.ff<V[it.ss]) V[it.ss]=it.ff;
        }
        int mi = inf;
        for(auto it:V) {
            if(it.second<mi && !visited[it.first]) {
                u = it.first;
                mi=it.second;
            }
        }
        mstSet.push_back(u);
        visited[u] = true;
    }

    for(auto i:mstSet) cout << i << " ";

}

Here’s the main method.

int main () {

    unordered_map<int,set<pair<int,int>>> e;
    struct construct con;

    con.make_weighted(e,4,0,1);
    con.make_weighted(e,8,0,7);
    con.make_weighted(e,11,1,7);
    con.make_weighted(e,8,1,2);
    con.make_weighted(e,7,7,8);
    con.make_weighted(e,1,7,6);
    con.make_weighted(e,2,2,8);
    con.make_weighted(e,6,8,6);
    con.make_weighted(e,2,6,5);
    con.make_weighted(e,7,2,3);
    con.make_weighted(e,4,2,5);
    con.make_weighted(e,14,3,5);
    con.make_weighted(e,9,3,4);
    con.make_weighted(e,10,5,4);

    prims(e,9);

    return 0;
}

And here’s the “struct construct”.

struct construct {

    void make_weighted (unordered_map<int,set<pair<int,int>>>& g, int c, int u, int v) {
        g[u].insert({c,v});
        g[v].insert({c,u});
    }

};

Your idea is correct.
The implementation has a subtle bug.

This part is not quite doing what you want it to do.
In c++ if some (key, pair) is not present in unordered_map then a default value of 0 is assigned to that key instead of throwing an error. In that for loop, you are accessing some keys in the unordered_map without initializing them and hence a default value of 0 is assigned to them and it is never changed.

In this line you should add a check like this:

for(auto it : g[u] ){
      // w < v.key then v.key = w 
      if( visited[it.ss] == false ) {
           // visiting this for the first time so initialize with current value
          V[it.ss] = it.ff
      }else{
         // already visited before so minimize now
        V[it.ss] = min( it.ff, V[it.ss]);
}

rest is ok.

3 Likes

Thank you so much for the response! I tried the above change (changed code below) but the “mstSet” output wasn’t 0,1,7,6,5,2,8,3 as given here. Instead, it comes out to be 0,7,8,2,3,4,5,1,6.

                //          u               cost,v        no of vertices
void prims (unordered_map<int,set<pair<int,int>>> g, int n) {
    
    vector<int> mstSet; // in MST
    vector<bool> visited(n,false);
    unordered_map<int,int> V; // vertex key array
    int u = 0;
    V[u] = 0;
    mstSet.push_back(u);
    visited[u]=true;

    while(mstSet.size()!=n) {
        for(auto it:g[u]) {
            // w < v.key then v.key = w
            if(it.ff<V[it.ss]) {
                if(!visited[it.ss]) V[it.ss]=it.ff;
                else V[it.ss] = min(V[it.ss],it.ff);
            }
        }
        int mi = inf;
        for(auto it:V) {
            if(it.second<mi && !visited[it.first]) {
                u = it.first;
                mi=it.second;
            }
        }
        mstSet.push_back(u);
        visited[u] = true;
    }

    for(auto i:mstSet) cout << i << " ";

}

//If you want to refer prims algorithm code, coincidentally I just wrote it :blush:, the pair holds the (value, node index ) info
#include <bits/stdc++.h>
#define FAST ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL);
typedef long long ll;
typedef long double ld;
#define pb push_back
#define mp make_pair

using namespace std;

ll prim(vector<pair<ll,ll>> v[], ll n)
{
bool vis[n];
memset(vis,0,sizeof(vis));
//vis[0] = 1;
set<pair<ll,ll>> st;
st.insert(mp(0,0));
for(ll i=1;i<n;i++)
st.insert(mp(LLONG_MAX,i));
ll vali[n];
vali[0] = 0;
for(ll i=1;i<n;i++)
vali[i] = LLONG_MAX;
ll cnt = 0;
while(st.size()>0)
{
auto curr = *st.begin();
st.erase(st.begin());
vis[curr.second] = 1;
cnt++;

    for(auto it = v[curr.second].begin();it!=v[curr.second].end();it++)
    {
        if(vis[it->second])
            continue;
        ll v1 = it->first;
        if(v1<vali[it->second])
        {
            st.erase(mp(vali[it->second],it->second));
            vali[it->second] = v1;
            st.insert(mp(v1,it->second));
        }
    }
}

ll ans = 0;
for(ll i=0;i<n;i++)
    ans+=vali[i];
return ans;    

}

1 Like

Thanks all! I fixed mine as well, the problem was unordered map storing 0 at places that weren’t given any values so I had to explicitly change those to INF to keep track of the minimum key values. Feels good! :slight_smile: