Problem Statement: You are given a weighted, undirected and connected graph with V vertices and E edges (the code below uses N for the number of vertices). The task is to find the sum of the weights of the edges of its Minimum Spanning Tree (MST).
Definition: A spanning tree of a connected graph with N nodes is a set of N-1 edges that connects all N nodes. A minimum spanning tree is a spanning tree with the minimum cost (sum of edge weights).
Note: It is called a tree because it contains no cycles. A connected graph with N nodes and N-1 edges is always a tree.
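To make the note concrete, here is a minimal C++ sketch. The function name `isSpanningTree` and the edge-list input format are our own assumptions, not part of the original article. It accepts a set of edges only if there are exactly N-1 of them and a breadth-first search from node 0 reaches all N nodes. By the note, such a set is a spanning tree.

```cpp
#include <cassert>
#include <queue>
#include <utility>
#include <vector>

// Hypothetical helper: returns true if `edges` (node ids in [0, n)) form a
// spanning tree on n nodes, i.e. exactly n - 1 edges that connect every node.
bool isSpanningTree(int n, const std::vector<std::pair<int, int>>& edges) {
    if (n <= 0 || static_cast<int>(edges.size()) != n - 1) return false;

    // Build an adjacency list from the candidate edges.
    std::vector<std::vector<int>> adj(n);
    for (const auto& e : edges) {
        int u = e.first, v = e.second;
        assert(u >= 0 && u < n && v >= 0 && v < n);
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    // Breadth-first search from node 0, counting the nodes it reaches.
    std::vector<bool> seen(n, false);
    std::queue<int> q;
    q.push(0);
    seen[0] = true;
    int reached = 1;
    while (!q.empty()) {
        int u = q.front();
        q.pop();
        for (int v : adj[u]) {
            if (!seen[v]) {
                seen[v] = true;
                ++reached;
                q.push(v);
            }
        }
    }
    return reached == n;  // connected with n - 1 edges => tree
}
```

For instance, `isSpanningTree(5, {{0, 1}, {1, 2}, {0, 3}, {1, 4}})` returns true, while adding a fifth edge or dropping one makes it return false because the edge count is no longer N-1.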
Example:
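Input: N = 5 nodes and 6 undirected weighted edges (u - v : weight): 0 - 1 : 2, 0 - 3 : 6, 1 - 2 : 3, 1 - 3 : 8, 1 - 4 : 5, 2 - 4 : 7
Output: 16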
Explanation: For the graph above, the minimum spanning tree consists of the edges 0-1 (weight 2), 1-2 (weight 3), 0-3 (weight 6) and 1-4 (weight 5). It has 5 nodes and 4 edges, and no other spanning tree has a smaller total weight (cost = 2 + 3 + 6 + 5 = 16).
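On a graph this small, the cost of 16 can be confirmed by exhaustive search. The sketch below is an illustrative brute-force check, not the article's method. It assumes the example graph is the one built in the code later in this article (5 nodes; edges 0-1:2, 0-3:6, 1-2:3, 1-3:8, 1-4:5, 2-4:7). `Edge` and `bruteForceMstCost` are hypothetical names. The function tries every subset of edges and keeps the cheapest subset that has exactly N-1 edges and no cycle. Cycles are detected with a small union-find.

```cpp
#include <cassert>
#include <climits>
#include <numeric>
#include <vector>

struct Edge { int u, v, w; };  // undirected edge u - v with weight w

// Union-find "find" with path halving.
static int findRoot(std::vector<int>& par, int x) {
    while (par[x] != x) {
        par[x] = par[par[x]];
        x = par[x];
    }
    return x;
}

// Exhaustive check: try every subset of the edges and keep the cheapest one
// that has exactly n - 1 edges and no cycle (such a set is a spanning tree).
// Exponential in E, so only usable on tiny graphs.
// Returns INT_MAX if no spanning tree exists.
int bruteForceMstCost(int n, const std::vector<Edge>& edges) {
    const int m = static_cast<int>(edges.size());
    assert(m < 31);  // subsets are encoded as bits of an int
    int best = INT_MAX;
    for (int mask = 0; mask < (1 << m); ++mask) {
        std::vector<int> par(n);
        std::iota(par.begin(), par.end(), 0);  // every node starts in its own set
        int used = 0, cost = 0;
        bool acyclic = true;
        for (int i = 0; i < m && acyclic; ++i) {
            if (((mask >> i) & 1) == 0) continue;
            int a = findRoot(par, edges[i].u);
            int b = findRoot(par, edges[i].v);
            if (a == b) {
                acyclic = false;  // this edge would close a cycle
            } else {
                par[a] = b;
                ++used;
                cost += edges[i].w;
            }
        }
        if (acyclic && used == n - 1 && cost < best) best = cost;
    }
    return best;
}
```

Calling `bruteForceMstCost(5, {{0, 1, 2}, {0, 3, 6}, {1, 2, 3}, {1, 3, 8}, {1, 4, 5}, {2, 4, 7}})` should return 16. With 6 edges there are only 2^6 = 64 subsets, but the count doubles with every extra edge, which is why the greedy solutions below are needed.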
Solution:
Disclaimer: Don’t jump directly to the solution, try it out yourself first.
Solution 1 (Brute Force):
Intuition: Let's start with any one node in the graph. As the first step, we look at all the edges adjacent to this node and pick the one with the minimum weight. Now the tree has 2 nodes. We repeat the process, this time considering all the edges that connect either of these two nodes to a node outside the tree, and again pick the minimum one. We continue until all the nodes are covered.
Note: While picking an edge, we might end up forming a cycle. This happens when both endpoints of the edge are already in the tree. In that case, we skip that edge and pick the next lowest edge that does not form a cycle.
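The intuition can be written down almost literally. In the following C++ sketch (the names `Edge` and `greedyGrowCost` are our own), we keep a set of tree nodes. In each round we scan every edge for the cheapest one that has exactly one endpoint in the tree. Skipping edges whose endpoints are both in the tree is exactly the cycle rule from the note.

```cpp
#include <cassert>
#include <climits>
#include <vector>

struct Edge { int u, v, w; };  // undirected edge u - v with weight w

// Literal version of the intuition: grow a tree from node 0 and, in every
// round, scan all edges for the cheapest one with exactly one endpoint in
// the tree. Returns the total weight, or -1 if the graph is disconnected.
// Runs in O(N * E).
int greedyGrowCost(int n, const std::vector<Edge>& edges) {
    if (n == 0) return 0;
    std::vector<bool> inTree(n, false);
    inTree[0] = true;  // start from node 0
    int total = 0;
    for (int added = 1; added < n; ++added) {
        int bestW = INT_MAX, bestNode = -1;
        for (const Edge& e : edges) {
            assert(e.u >= 0 && e.u < n && e.v >= 0 && e.v < n);
            bool a = inTree[e.u], b = inTree[e.v];
            // Skip edges with both endpoints in the tree (they would close a
            // cycle) or both outside it (they do not touch the tree yet).
            if (a == b) continue;
            if (e.w < bestW) {
                bestW = e.w;
                bestNode = a ? e.v : e.u;  // the endpoint outside the tree
            }
        }
        if (bestNode == -1) return -1;  // no edge leaves the tree
        inTree[bestNode] = true;
        total += bestW;
    }
    return total;
}
```

On the example graph this picks nodes 1, 2, 4 and 3 in that order and returns 2 + 3 + 5 + 6 = 16. Rescanning every edge in every round costs O(N * E); the key array in the approach below avoids that rescan.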
Approach: We maintain three arrays: key, mstSet and parent.
Significance of all three arrays:
- key: key[v] holds the weight of the cheapest known edge connecting node v to the current MST (initialized to INT_MAX for every node except index 0, which is set to zero so that node 0 is picked first).
- mstSet: a boolean array that indicates whether a node is already part of the MST (initialized to false for every node, including node 0).
- parent: parent[v] is the MST node that v is attached to, i.e. the MST contains the edge parent[v] - v (parent[0] is set to -1 because the start node has no parent).
Steps:
- We start with node 0: key[0] is set to zero so that it is the first node picked, and parent[0] is set to -1.
- Among the nodes that are not yet part of the MST, we find the node u with the minimum key value and mark mstSet[u] as true (it is now part of the MST).
- For every edge u - v where v is not yet part of the MST, if the edge weight is smaller than key[v], we set key[v] to that weight and parent[v] to u.
- We repeat the last two steps. Because we only ever pick a node that is not yet part of the MST, every chosen edge parent[u] - u joins the tree to a new node, so no cycle can form.
- We continue until all nodes are part of the MST. Every node v other than 0 is then attached by the edge parent[v] - v of weight key[v], so the MST weight is the sum of these keys.
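Here is a sketch of the same steps using an adjacency matrix instead of the adjacency lists in the article's code. The matrix representation, the -1 convention for a missing edge and the name `primMatrix` are assumptions made for this illustration. A matrix suits the O(N^2) version because each relaxation step scans a full row anyway, which works well for dense graphs.

```cpp
#include <cassert>
#include <climits>
#include <vector>

// Sketch of the steps above on an adjacency matrix (our own choice of
// representation). w[u][v] is the weight of edge u - v, or -1 if there is
// no edge. Fills parent and returns the MST weight, or -1 if the graph is
// disconnected.
int primMatrix(const std::vector<std::vector<int>>& w, std::vector<int>& parent) {
    const int n = static_cast<int>(w.size());
    for (const auto& row : w) assert(static_cast<int>(row.size()) == n);
    parent.assign(n, -1);              // parent[0] stays -1: the start node
    if (n == 0) return 0;
    std::vector<int> key(n, INT_MAX);  // cheapest known edge into the MST
    std::vector<bool> mstSet(n, false);
    key[0] = 0;                        // node 0 is picked first
    int total = 0;
    for (int picked = 0; picked < n; ++picked) {
        // Pick the node outside the MST with the smallest key.
        int u = -1;
        for (int v = 0; v < n; ++v)
            if (!mstSet[v] && (u == -1 || key[v] < key[u])) u = v;
        if (key[u] == INT_MAX) return -1;  // nothing left is reachable
        mstSet[u] = true;
        total += key[u];                   // weight of edge parent[u] - u
        // Relax every edge from u to a node still outside the MST.
        for (int v = 0; v < n; ++v) {
            if (w[u][v] >= 0 && !mstSet[v] && w[u][v] < key[v]) {
                key[v] = w[u][v];
                parent[v] = u;
            }
        }
    }
    return total;
}
```

On the example graph (weights stored symmetrically in `w`), the function returns 16 and leaves parent = {-1, 0, 1, 0, 1}, which corresponds to the edges 0-1, 1-2, 0-3 and 1-4. Unlike the article's code, it runs the outer loop N times so that the last node's key is added to the total inside the loop.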
Code:
C++ Code
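#include<bits/stdc++.h>
using namespace std;
int main(){
    const int N = 5;  // number of nodes; const so the arrays below are standard C++
    // adj[u] holds {neighbour, weight} pairs; every undirected edge is stored in both directions
    vector<pair<int,int>> adj[N];
    adj[0].push_back({1,2});
    adj[0].push_back({3,6});
    adj[1].push_back({0,2});
    adj[1].push_back({2,3});
    adj[1].push_back({3,8});
    adj[1].push_back({4,5});
    adj[2].push_back({1,3});
    adj[2].push_back({4,7});
    adj[3].push_back({0,6});
    adj[3].push_back({1,8});
    adj[4].push_back({1,5});
    adj[4].push_back({2,7});
    int parent[N];   // parent[v]: MST node that v is attached to
    int key[N];      // key[v]: weight of the cheapest known edge connecting v to the MST
    bool mstSet[N];  // mstSet[v]: true once v is part of the MST
    for (int i = 0; i < N; i++)
        key[i] = INT_MAX, mstSet[i] = false;
    key[0] = 0;      // start from node 0
    parent[0] = -1;  // the start node has no parent
    // After N-1 picks, the last remaining node's key and parent are already final
    for (int count = 0; count < N - 1; count++)
    {
        // Pick the node outside the MST with the minimum key
        int mini = INT_MAX, u = -1;
        for (int v = 0; v < N; v++)
        {
            if (mstSet[v] == false && key[v] < mini)
                mini = key[v], u = v;
        }
        mstSet[u] = true;
        // Update the keys of u's neighbours that are still outside the MST
        for (auto it : adj[u]) {
            int v = it.first;
            int weight = it.second;
            if (mstSet[v] == false && weight < key[v])
                parent[v] = u, key[v] = weight;
        }
    }
    int ansWeight = 0;
    for (int i = 1; i < N; i++) {
        cout << parent[i] << " - " << i << "\n";
        ansWeight += key[i];  // key[i] is the weight of the MST edge parent[i] - i
    }
    cout << "Sum of MST edge weights: " << ansWeight << "\n";
    return 0;
}
Output:
0 - 1
1 - 2
0 - 3
1 - 4
Sum of MST edge weights: 16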
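Time Complexity: O(N^2). The outer loop runs N-1 times and each iteration scans all N nodes to find the minimum key; relaxing the adjacency lists adds O(E) in total, which is at most O(N^2).
Space Complexity: O(N). Three arrays of size N (not counting the adjacency list).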
Java Code
import java.util.*;
class Node
{
    private int v;       // neighbouring node
    private int weight;  // weight of the edge to that node
    Node(int _v, int _w) { v = _v; weight = _w; }
    Node() {}
    int getV() { return v; }
    int getWeight() { return weight; }
}
class Main
{
    void primsAlgo(ArrayList<ArrayList<Node>> adj, int N)
    {
        int key[] = new int[N];             // cheapest known edge weight connecting each node to the MST
        int parent[] = new int[N];          // MST parent of each node
        boolean mstSet[] = new boolean[N];  // true once a node is part of the MST
        for(int i = 0;i<N;i++) {
            key[i] = Integer.MAX_VALUE;
            mstSet[i] = false;
        }
        key[0] = 0;      // start from node 0
        parent[0] = -1;  // the start node has no parent
        // After N-1 picks, the last remaining node's key and parent are already final
        for(int i = 0;i<N-1;i++) {
            // Pick the node outside the MST with the minimum key
            int mini = Integer.MAX_VALUE, u = 0;
            for(int v = 0;v<N;v++) {
                if(mstSet[v] == false && key[v] < mini) {
                    mini = key[v];
                    u = v;
                }
            }
            mstSet[u] = true;
            // Update the keys of u's neighbours that are still outside the MST
            for(Node it: adj.get(u)) {
                if(mstSet[it.getV()] == false && it.getWeight() < key[it.getV()]) {
                    parent[it.getV()] = u;
                    key[it.getV()] = it.getWeight();
                }
            }
        }
        int sum = 0;
        for(int i = 1;i<N;i++) {
            System.out.println(parent[i] + " - " + i);
            sum += key[i];  // key[i] is the weight of the MST edge parent[i] - i
        }
        System.out.println("Sum of MST edge weights: " + sum);
    }
    public static void main(String args[])
    {
        int n = 5;
        // adj.get(u) holds the neighbours of u; every undirected edge is stored in both directions
        ArrayList<ArrayList<Node>> adj = new ArrayList<ArrayList<Node>>();
        for (int i = 0; i < n; i++)
            adj.add(new ArrayList<Node>());
        adj.get(0).add(new Node(1, 2));
        adj.get(1).add(new Node(0, 2));
        adj.get(1).add(new Node(2, 3));
        adj.get(2).add(new Node(1, 3));
        adj.get(0).add(new Node(3, 6));
        adj.get(3).add(new Node(0, 6));
        adj.get(1).add(new Node(3, 8));
        adj.get(3).add(new Node(1, 8));
        adj.get(1).add(new Node(4, 5));
        adj.get(4).add(new Node(1, 5));
        adj.get(2).add(new Node(4, 7));
        adj.get(4).add(new Node(2, 7));
        Main obj = new Main();
        obj.primsAlgo(adj, n);
    }
}
Output:
0 - 1
1 - 2
0 - 3
1 - 4
Sum of MST edge weights: 16
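Time Complexity: O(N^2). The outer loop runs N-1 times and each iteration scans all N nodes to find the minimum key; relaxing the adjacency lists adds O(E) in total, which is at most O(N^2).
Space Complexity: O(N). Three arrays of size N (not counting the adjacency list).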
Solution 2 (Optimized Approach):
Intuition: To optimize the code, we look for repeated work. In every iteration of Solution 1, we scan the entire key array to find the minimum key among the nodes that are not yet part of the MST, which costs O(N) each time. Can we use a data structure that always keeps the minimum weight at the top?
Approach: We use a min-heap (priority queue) for this. Each heap entry stores the weight required to reach a node (its key) together with the node's index, so the node with the smallest key can be removed in logarithmic time. Whenever a node's key decreases, we push a new entry for it; an older entry that is popped after its node has already joined the MST is stale and can be skipped.
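The small C++ sketch below shows the min-heap property this approach relies on. `std::priority_queue` is a max-heap by default, and passing `std::greater` makes `top()` return the smallest (weight, node) pair. A standard priority queue cannot lower the priority of an entry that is already inside it. So when a node's key drops, the node is pushed again, and its older, larger entry is ignored when popped later. This technique is often called lazy deletion. The helper `popInOrder` and the alias `MinHeap` are hypothetical names.

```cpp
#include <cassert>
#include <functional>
#include <queue>
#include <utility>
#include <vector>

// std::priority_queue is a max-heap by default; std::greater turns it into a
// min-heap, so top() is the {weight, node} pair with the smallest weight
// (ties are broken by the smaller node id, because pairs compare lexicographically).
using MinHeap = std::priority_queue<std::pair<int, int>,
                                    std::vector<std::pair<int, int>>,
                                    std::greater<std::pair<int, int>>>;

// Pops every entry of a copy of the heap and returns the nodes in the order
// they are accepted. An entry for a node that was already accepted is stale:
// the node was already taken through an entry with a smaller weight.
std::vector<int> popInOrder(MinHeap pq, int n) {
    std::vector<bool> done(n, false);
    std::vector<int> order;
    while (!pq.empty()) {
        const int node = pq.top().second;
        pq.pop();
        assert(node >= 0 && node < n);
        if (done[node]) continue;  // stale entry
        done[node] = true;
        order.push_back(node);
    }
    return order;
}
```

Pushing (6, 3), (2, 1), (5, 4), (3, 2) and then (4, 3) yields the order {1, 2, 3, 4}: node 3 is accepted with weight 4, and its stale entry (6, 3) is skipped.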

Code:
C++ Code
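#include<bits/stdc++.h>
using namespace std;
int main(){
    const int N = 5;  // number of nodes; const so the arrays below are standard C++
    // adj[u] holds {neighbour, weight} pairs; every undirected edge is stored in both directions
    vector<pair<int,int>> adj[N];
    adj[0].push_back({1,2});
    adj[0].push_back({3,6});
    adj[1].push_back({0,2});
    adj[1].push_back({2,3});
    adj[1].push_back({3,8});
    adj[1].push_back({4,5});
    adj[2].push_back({1,3});
    adj[2].push_back({4,7});
    adj[3].push_back({0,6});
    adj[3].push_back({1,8});
    adj[4].push_back({1,5});
    adj[4].push_back({2,7});
    int parent[N];   // parent[v]: MST node that v is attached to
    int key[N];      // key[v]: weight of the cheapest known edge connecting v to the MST
    bool mstSet[N];  // mstSet[v]: true once v is part of the MST
    for (int i = 0; i < N; i++)
        key[i] = INT_MAX, mstSet[i] = false;
    // Min-heap of {key, node}: top() is the node with the smallest key
    priority_queue< pair<int,int>, vector <pair<int,int>> , greater<pair<int,int>> > pq;
    key[0] = 0;
    parent[0] = -1;
    pq.push({0, 0});
    while(!pq.empty())
    {
        int u = pq.top().second;
        pq.pop();
        // A node is pushed again whenever its key decreases, so skip stale entries
        if (mstSet[u]) continue;
        mstSet[u] = true;
        for (auto it : adj[u]) {
            int v = it.first;
            int weight = it.second;
            if (mstSet[v] == false && weight < key[v]) {
                parent[v] = u;
                key[v] = weight;
                pq.push({key[v], v});
            }
        }
    }
    int ansWeight = 0;
    for (int i = 1; i < N; i++) {
        cout << parent[i] << " - " << i << "\n";
        ansWeight += key[i];  // key[i] is the weight of the MST edge parent[i] - i
    }
    cout << "Sum of MST edge weights: " << ansWeight << "\n";
    return 0;
}
Output:
0 - 1
1 - 2
0 - 3
1 - 4
Sum of MST edge weights: 16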
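Time Complexity: O(E log E), which equals O(E log N) because E is at most N^2. Each edge causes at most one push onto the priority queue, and each push or pop costs O(log E).
Space Complexity: O(N + E). Three arrays of size N plus a priority queue that can hold up to E + 1 entries.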
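To see where the O(E log E) bound comes from, the sketch below counts heap insertions while computing the MST weight. The names `lazyPrim` and `PrimResult` are hypothetical. The graph is stored as adjacency lists of (neighbour, weight) pairs, as in the article. The sketch skips entries for nodes that are already in the MST. Each undirected edge can trigger at most one push, from whichever endpoint joins the MST first, so there are at most E + 1 pushes. The test graph is the six-node case credited to Srejan Bera in the Java code comments.

```cpp
#include <cassert>
#include <climits>
#include <functional>
#include <queue>
#include <utility>
#include <vector>

struct PrimResult {
    long long weight;  // sum of MST edge weights
    int pushes;        // number of priority-queue insertions
};

// Prim's algorithm with a lazy min-heap of {key, node} pairs.
// adj[u] holds {neighbour, weight} pairs; the graph is assumed connected.
PrimResult lazyPrim(const std::vector<std::vector<std::pair<int, int>>>& adj) {
    const int n = static_cast<int>(adj.size());
    PrimResult res{0, 0};
    if (n == 0) return res;
    std::vector<int> key(n, INT_MAX);
    std::vector<bool> mstSet(n, false);
    std::priority_queue<std::pair<int, int>, std::vector<std::pair<int, int>>,
                        std::greater<std::pair<int, int>>> pq;
    key[0] = 0;
    pq.push({0, 0});
    res.pushes = 1;
    while (!pq.empty()) {
        const int k = pq.top().first;
        const int u = pq.top().second;
        pq.pop();
        if (mstSet[u]) continue;  // stale entry: u already joined with a smaller key
        mstSet[u] = true;
        res.weight += k;          // k == key[u], the weight of the edge joining u (0 for node 0)
        for (const auto& e : adj[u]) {
            const int v = e.first, w = e.second;
            assert(v >= 0 && v < n);
            if (!mstSet[v] && w < key[v]) {
                key[v] = w;
                pq.push({w, v});
                ++res.pushes;
            }
        }
    }
    return res;
}
```

On that graph (6 nodes, 7 edges), the MST weight is 5 + 10 + 50 + 200 + 50 = 315, and the function performs 7 pushes, within the E + 1 = 8 bound. Node 3 is pushed twice (keys 100 and 50), so its key-100 entry is later popped as a stale entry and skipped.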
Java Code
import java.util.*;
class Node {
    private int v;       // node index
    private int weight;  // edge weight (or the node's key when stored in the priority queue)
    Node(int _v, int _w) {
        v = _v;
        weight = _w;
    }
    int getV() {
        return v;
    }
    int getWeight() {
        return weight;
    }
}
class Main {
    void primsAlgo(ArrayList<ArrayList<Node>> adj, int N) {
        int key[] = new int[N];             // cheapest known edge weight connecting each node to the MST
        int parent[] = new int[N];          // MST parent of each node
        boolean mstSet[] = new boolean[N];  // true once a node is part of the MST
        for (int i = 0; i < N; i++) {
            key[i] = Integer.MAX_VALUE;
            mstSet[i] = false;
        }
        // Min-heap ordered by weight: poll() returns the entry with the smallest key
        PriorityQueue<Node> pq = new PriorityQueue<Node>(N, Comparator.comparingInt(Node::getWeight));
        key[0] = 0;
        parent[0] = -1;
        pq.add(new Node(0, key[0]));  // Node(node, key): start from node 0
        // Run the loop until the queue is empty rather than a fixed N-1 times.
        // The brute-force code checked mstSet[node] == false while computing the minimum,
        // but here we simply take the minimum from the priority queue. A node is pushed
        // again every time its key decreases, so the same node can be polled more than once;
        // such stale entries are skipped below.
        // Try the following case (first line: nodes edges, then one "u v weight" per line):
        // Credits: Srejan Bera
        // 6 7
        // 0 1 5
        // 0 2 10
        // 0 3 100
        // 1 3 50
        // 1 4 200
        // 3 4 250
        // 4 5 50
        while (!pq.isEmpty()) {
            int u = pq.poll().getV();
            if (mstSet[u]) continue;  // stale entry: u is already in the MST
            mstSet[u] = true;
            for (Node it: adj.get(u)) {
                if (mstSet[it.getV()] == false && it.getWeight() < key[it.getV()]) {
                    parent[it.getV()] = u;
                    key[it.getV()] = it.getWeight();
                    pq.add(new Node(it.getV(), key[it.getV()]));
                }
            }
        }
        int sum = 0;
        for (int i = 1; i < N; i++) {
            System.out.println(parent[i] + " - " + i);
            sum += key[i];  // key[i] is the weight of the MST edge parent[i] - i
        }
        System.out.println("Sum of MST edge weights: " + sum);
    }
    public static void main(String args[]) {
        int n = 5;
        // adj.get(u) holds the neighbours of u; every undirected edge is stored in both directions
        ArrayList<ArrayList<Node>> adj = new ArrayList<ArrayList<Node>>();
        for (int i = 0; i < n; i++)
            adj.add(new ArrayList<Node>());
        adj.get(0).add(new Node(1, 2));
        adj.get(1).add(new Node(0, 2));
        adj.get(1).add(new Node(2, 3));
        adj.get(2).add(new Node(1, 3));
        adj.get(0).add(new Node(3, 6));
        adj.get(3).add(new Node(0, 6));
        adj.get(1).add(new Node(3, 8));
        adj.get(3).add(new Node(1, 8));
        adj.get(1).add(new Node(4, 5));
        adj.get(4).add(new Node(1, 5));
        adj.get(2).add(new Node(4, 7));
        adj.get(4).add(new Node(2, 7));
        Main obj = new Main();
        obj.primsAlgo(adj, n);
    }
}
Output:
0 - 1
1 - 2
0 - 3
1 - 4
Sum of MST edge weights: 16
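Time Complexity: O(E log E), which equals O(E log N) because E is at most N^2. Each edge causes at most one push onto the priority queue, and each push or poll costs O(log E).
Space Complexity: O(N + E). Three arrays of size N plus a priority queue that can hold up to E + 1 entries.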
Special thanks to Naman Daga for contributing to this article on takeUforward.