Minimum Spanning Tree Algorithms
Spanning Tree
A spanning tree is a subgraph that connects all the nodes of a graph using the fewest possible edges. A spanning tree of a graph with n nodes always has exactly n-1 edges and does not contain any cycles.
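The definition above can be checked mechanically. The following is a minimal sketch, not part of the original article: the helper name is_spanning_tree is hypothetical, and it assumes an undirected graph whose edges are given as (node1, node2) pairs. It relies on the fact that n-1 edges that connect all n nodes cannot contain a cycle.

```python
from typing import Hashable, Iterable, List, Tuple

Edge = Tuple[Hashable, Hashable]


def is_spanning_tree(nodes: Iterable[Hashable], tree_edges: List[Edge]) -> bool:
    """Return True if tree_edges connects every node using exactly n-1 edges."""
    node_set = set(nodes)
    if not node_set or len(tree_edges) != len(node_set) - 1:
        return False

    # Build an undirected adjacency list from the candidate edges.
    adjacency = {node: [] for node in node_set}
    for node1, node2 in tree_edges:
        if node1 not in node_set or node2 not in node_set:
            return False
        adjacency[node1].append(node2)
        adjacency[node2].append(node1)

    # Depth-first search from an arbitrary node to test connectivity.
    start = next(iter(node_set))
    seen = {start}
    stack = [start]
    while stack:
        current = stack.pop()
        for neighbour in adjacency[current]:
            if neighbour not in seen:
                seen.add(neighbour)
                stack.append(neighbour)

    # n-1 edges that reach all n nodes cannot contain a cycle.
    return seen == node_set


print(is_spanning_tree([0, 1, 2, 3], [(0, 1), (1, 2), (2, 3)]))  # True
print(is_spanning_tree([0, 1, 2, 3], [(0, 1), (1, 2), (2, 0)]))  # False
```

The second call fails because its three edges form a cycle among nodes 0, 1 and 2 and never reach node 3.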
Properties of Spanning Tree
Minimum Spanning Tree
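A minimum spanning tree (MST) is a spanning tree whose combined edge weight is the smallest among all spanning trees of the graph. The minimum spanning tree does not have to be unique: a graph can have several spanning trees that share the same minimal weight.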
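To see that a minimum spanning tree need not be unique, the sketch below enumerates every spanning tree of a tiny example graph by brute force and keeps the lightest ones. The graph, the helper names connects_all and minimum_spanning_trees, and the (node1, node2, weight) edge layout are assumptions chosen for illustration. It assumes a connected graph, and the approach is only practical for very small inputs.

```python
from itertools import combinations
from typing import Dict, List, Sequence, Tuple

Edge = Tuple[str, str, int]  # (node1, node2, weight)


def connects_all(nodes: Sequence[str], subset: Sequence[Edge]) -> bool:
    """Return True if the edges in subset join all nodes into one component."""
    component: Dict[str, str] = {node: node for node in nodes}

    def root(node: str) -> str:
        while component[node] != node:
            node = component[node]
        return node

    merges = 0
    for node1, node2, _ in subset:
        root1, root2 = root(node1), root(node2)
        if root1 != root2:
            component[root1] = root2
            merges += 1
    # n-1 successful merges means every node ended up in the same component.
    return merges == len(nodes) - 1


def total_weight(tree: Sequence[Edge]) -> int:
    return sum(weight for _, _, weight in tree)


def minimum_spanning_trees(nodes: Sequence[str], edges: Sequence[Edge]) -> List[Tuple[Edge, ...]]:
    """Return every spanning tree whose total weight is minimal (brute force)."""
    trees = [
        subset
        for subset in combinations(edges, len(nodes) - 1)
        if connects_all(nodes, subset)
    ]
    best = min(total_weight(tree) for tree in trees)
    return [tree for tree in trees if total_weight(tree) == best]


nodes = ["A", "B", "C"]
edges = [("A", "B", 1), ("B", "C", 2), ("A", "C", 2)]
for tree in minimum_spanning_trees(nodes, edges):
    print(tree, total_weight(tree))
# (('A', 'B', 1), ('B', 'C', 2)) 3
# (('A', 'B', 1), ('A', 'C', 2)) 3
```

Both trees have weight 3, so either one is a valid minimum spanning tree of this triangle.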
We will introduce two algorithms for computing the minimum spanning tree of a graph: Prim’s algorithm and Kruskal’s algorithm.
Prim’s Algorithm
Prim’s algorithm starts from a single node and grows the minimum spanning tree one node at a time. At each step it picks the smallest-weight edge that connects a node already in the tree to an unvisited node (a frontier node) and adds that node to the tree. This repeats until every reachable node has been visited.
Algorithm
Step 1: Pick a starting node.
Step 2: Mark the selected node as visited, then add every edge that leads from it to an unvisited node to the set of candidate edges.
Step 3: Pick the smallest candidate edge that leads from a visited node to an unvisited node, select that unvisited node, and repeat step 2.
Step 4: Stop when no remaining edge connects a visited node to an unvisited node.
Note! If the finished tree contains fewer nodes than the graph, the graph is not fully connected.
from typing import Union, List, Tuple, Dict
from collections import defaultdict
import heapq

Node = Union[int, str]


class Prims:
    def __init__(self, edges: List[Tuple[Node, Node, int]]) -> None:
        self.adjacent_map = self.__initialize_adjacent_map(edges)

    def __initialize_adjacent_map(self, edges: List[Tuple[Node, Node, int]]) -> Dict[Node, List[Tuple[int, Node]]]:
        # Map each node to a list of (weight, neighbour) pairs.
        adjacent_map = defaultdict(list)
        for node1, node2, weight in edges:
            adjacent_map[node1].append((weight, node2))
            adjacent_map[node2].append((weight, node1))  # Assuming the graph is undirected
        return dict(adjacent_map)

    def find_minimum_weight(self, start_node: Node) -> int:
        # This method implements Prim's algorithm to find the
        # sum of the weights of a minimum spanning tree.
        # It takes a starting node and returns an integer representing
        # the sum of the edge weights of the MST.
        if start_node not in self.adjacent_map:
            raise ValueError("Node does not exist in graph")
        min_heap = []
        visited = set()
        min_distance = 0
        # The start node joins the tree at zero cost.
        heapq.heappush(min_heap, (0, start_node))
        while min_heap:
            current_node_distance, current_node = heapq.heappop(min_heap)
            if current_node not in visited:
                visited.add(current_node)
                min_distance += current_node_distance
                # Offer every edge that leaves the tree as a candidate.
                for adjacent_node_edge_length, adjacent_node in self.adjacent_map[current_node]:
                    if adjacent_node not in visited:
                        heapq.heappush(min_heap, (adjacent_node_edge_length, adjacent_node))
        if len(visited) != len(self.adjacent_map):
            raise ValueError("Graph is not fully connected.")
        return min_distance


edges = [
    # node1, node2, weight
    (0, 2, 2),
    (3, 0, 6),
    (2, 1, 5),
    (4, 1, 1),
    (3, 2, 2),
    (4, 3, 3),
]
prims = Prims(edges)
print(prims.find_minimum_weight(0))  # 8
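The implementation above returns only the total weight. If the edges of the tree are needed as well, one option is to store the source node alongside each heap entry. The following is a rough sketch under that assumption; the function name prim_mst_edges is hypothetical, and it takes the same (node1, node2, weight) edge tuples as the example above.

```python
import heapq
from collections import defaultdict
from typing import Dict, List, Tuple

Edge = Tuple[int, int, int]  # (node1, node2, weight)


def prim_mst_edges(edges: List[Edge], start_node: int) -> List[Edge]:
    """Return the edges chosen by Prim's algorithm as (from, to, weight)."""
    adjacency: Dict[int, List[Tuple[int, int]]] = defaultdict(list)
    for node1, node2, weight in edges:
        adjacency[node1].append((weight, node2))
        adjacency[node2].append((weight, node1))

    visited = {start_node}
    # Heap entries are (weight, node inside the tree, node outside the tree).
    frontier = [(weight, start_node, neighbour) for weight, neighbour in adjacency[start_node]]
    heapq.heapify(frontier)

    tree_edges: List[Edge] = []
    while frontier:
        weight, source, target = heapq.heappop(frontier)
        if target in visited:
            continue  # This edge no longer leaves the tree.
        visited.add(target)
        tree_edges.append((source, target, weight))
        for next_weight, neighbour in adjacency[target]:
            if neighbour not in visited:
                heapq.heappush(frontier, (next_weight, target, neighbour))
    return tree_edges


example_edges = [(0, 2, 2), (3, 0, 6), (2, 1, 5), (4, 1, 1), (3, 2, 2), (4, 3, 3)]
mst = prim_mst_edges(example_edges, 0)
print(mst)                                 # [(0, 2, 2), (2, 3, 2), (3, 4, 3), (4, 1, 1)]
print(sum(weight for _, _, weight in mst))  # 8
```

Unlike the class above, this sketch does not raise an error for a disconnected graph; it simply returns the tree of the component that contains the start node.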
Difference between Prim’s and Dijkstra’s Algorithms
Dijkstra’s and Prim’s algorithms look similar because both grow a set of visited nodes outward from a starting node and use a priority queue to select the closest unvisited node. Let’s look at the difference between them.
Distance Calculation:
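Prim’s algorithm produces a tree in which all the nodes are connected and the sum of the edge weights is minimal. It ranks each frontier node by the weight of the single edge that connects it to the tree.
Dijkstra’s algorithm produces the shortest route from the source to each node. It ranks each frontier node by its total distance from the source.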
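In code, this difference comes down to the key pushed onto the heap. The sketch below is a simplified illustration, not a full implementation of either algorithm: the function grow_tree and its accumulate flag are hypothetical names, and the three-node graph was picked so that the two results differ.

```python
import heapq
from collections import defaultdict
from typing import Dict, List, Tuple

Edge = Tuple[str, str, int]  # (node1, node2, weight)


def grow_tree(edges: List[Edge], start: str, accumulate: bool) -> Dict[str, int]:
    """Return the heap key with which each node was finalised.

    accumulate=False uses Prim's key (weight of the connecting edge).
    accumulate=True uses Dijkstra's key (total distance from start).
    """
    adjacency: Dict[str, List[Tuple[int, str]]] = defaultdict(list)
    for node1, node2, weight in edges:
        adjacency[node1].append((weight, node2))
        adjacency[node2].append((weight, node1))

    final_key: Dict[str, int] = {}
    heap = [(0, start)]
    while heap:
        key, node = heapq.heappop(heap)
        if node in final_key:
            continue
        final_key[node] = key
        for weight, neighbour in adjacency[node]:
            if neighbour not in final_key:
                # The only line that differs between the two algorithms.
                new_key = key + weight if accumulate else weight
                heapq.heappush(heap, (new_key, neighbour))
    return final_key


triangle = [("A", "B", 2), ("B", "C", 2), ("A", "C", 3)]
print(grow_tree(triangle, "A", accumulate=False))  # {'A': 0, 'B': 2, 'C': 2}
print(grow_tree(triangle, "A", accumulate=True))   # {'A': 0, 'B': 2, 'C': 3}
```

With Prim's key, C joins the tree through the B-C edge of weight 2, giving an MST of weight 4. With Dijkstra's key, C is reached directly over the A-C edge because its distance of 3 beats the path through B with distance 4.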
Kruskal’s Algorithm
Kruskal’s algorithm starts by sorting all the edges by increasing weight. It then builds the minimum spanning tree by adding the edges one at a time, skipping any edge that would form a cycle. The process continues until all the nodes are connected or there are no edges left.
Intuition
The core idea behind Kruskal’s algorithm is to build the MST by selecting the edges in ascending order of their weights while ensuring that no cycles are formed. By starting with the smallest edges, Kruskal’s algorithm ensures that the MST is made of the edges that keep the overall tree weight as low as possible without closing any loops.
Algorithm
Cycle Detection
We will use Union Find for cycle detection: an edge would form a cycle exactly when both of its endpoints already belong to the same set. You can read more about using Union Find for cycle detection here.
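As a minimal sketch of that idea, assuming an undirected graph given as (node1, node2) pairs, the hypothetical helper has_cycle below merges the endpoints of each edge and reports a cycle as soon as both endpoints of an edge are already in the same set.

```python
from typing import Dict, Iterable, List, Tuple


def has_cycle(nodes: Iterable[int], edges: List[Tuple[int, int]]) -> bool:
    """Return True if the undirected edges contain a cycle."""
    parent: Dict[int, int] = {node: node for node in nodes}

    def find(node: int) -> int:
        # Walk up to the root, shortening the path along the way.
        while parent[node] != node:
            parent[node] = parent[parent[node]]
            node = parent[node]
        return node

    for node1, node2 in edges:
        root1, root2 = find(node1), find(node2)
        if root1 == root2:
            # Both endpoints are already connected, so this edge closes a loop.
            return True
        parent[root1] = root2
    return False


print(has_cycle(range(3), [(0, 1), (1, 2)]))          # False
print(has_cycle(range(3), [(0, 1), (1, 2), (2, 0)]))  # True
```

Kruskal’s algorithm applies the same test to one edge at a time and discards the edge instead of stopping.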
Code Implementation
from typing import Union, List, Tuple, Dict, Set, Optional
import heapq

Node = Union[int, str]


class Kruskal:
    def __init__(self, edges: List[Tuple[Node, Node, int]]) -> None:
        self.nodes = self.__get_nodes(edges)
        self.minHeap = self.__get_min_heap(edges)

    def __get_nodes(self, edges: List[Tuple[Node, Node, int]]) -> Set[Node]:
        nodes = set()
        for node1, node2, _ in edges:
            nodes.add(node1)
            nodes.add(node2)
        return nodes

    def __get_min_heap(self, edges: List[Tuple[Node, Node, int]]) -> List[Tuple[int, Node, Node]]:
        # Put the weight first so the heap orders the edges by weight.
        minHeap = [(c, a, b) for a, b, c in edges]
        heapq.heapify(minHeap)
        return minHeap

    def find_minimum_weight(self) -> Optional[int]:
        # Returns the total weight of the MST, or None if the graph is not connected.
        uf = UnionFind(self.nodes)
        minHeap = self.minHeap.copy()
        size = len(self.nodes)
        minWeight = 0
        edgeCount = 0
        while minHeap and edgeCount < size - 1:
            edgeWeight, node1, node2 = heapq.heappop(minHeap)
            # Only accept the edge if it joins two different components.
            if uf.find(node1) != uf.find(node2):
                minWeight += edgeWeight
                edgeCount += 1
                uf.union(node1, node2)
        return minWeight if edgeCount == size - 1 else None


class UnionFind:
    def __init__(self, elements: Set[Node]) -> None:
        self.parent = self.__initialize_parent(elements)
        self.size = self.__initialize_size(elements)

    def __initialize_parent(self, elements: Set[Node]) -> Dict[Node, Node]:
        # Every element starts as the root of its own set.
        return {element: element for element in elements}

    def __initialize_size(self, elements: Set[Node]) -> Dict[Node, int]:
        return {element: 1 for element in elements}

    def find(self, element: Node) -> Node:
        # Follow the parent pointers up to the root of the set.
        if element != self.parent[element]:
            return self.find(self.parent[element])
        return self.parent[element]

    def union(self, i: Node, j: Node) -> None:
        iRoot = self.find(i)
        jRoot = self.find(j)
        if iRoot == jRoot:
            return
        # Union by size: attach the smaller set under the larger one.
        if self.size[iRoot] > self.size[jRoot]:
            self.parent[jRoot] = iRoot
            self.size[iRoot] += self.size[jRoot]
        else:
            self.parent[iRoot] = jRoot
            self.size[jRoot] += self.size[iRoot]


edges = [
    # node1, node2, weight
    (0, 2, 2),
    (0, 3, 6),
    (1, 2, 5),
    (1, 4, 1),
    (2, 3, 2),
    (3, 4, 3),
]
kruskal = Kruskal(edges)
print(kruskal.find_minimum_weight())  # 8
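The class above reports only the total weight and pulls edges from a heap. As an alternative sketch, the hypothetical function kruskal_mst_edges below sorts the edges once with sorted(), as in the description of the algorithm, and returns the edges it keeps. It assumes the same (node1, node2, weight) tuples and uses a small inline union-find instead of the UnionFind class.

```python
from typing import Dict, List, Tuple

Edge = Tuple[int, int, int]  # (node1, node2, weight)


def kruskal_mst_edges(edges: List[Edge]) -> List[Edge]:
    """Return the edges kept by Kruskal's algorithm, in the order they were added."""
    parent: Dict[int, int] = {}

    def find(node: int) -> int:
        parent.setdefault(node, node)  # Register nodes the first time they appear.
        while parent[node] != node:
            parent[node] = parent[parent[node]]
            node = parent[node]
        return node

    tree_edges: List[Edge] = []
    # sorted() is stable, so edges with equal weight keep their input order.
    for node1, node2, weight in sorted(edges, key=lambda edge: edge[2]):
        root1, root2 = find(node1), find(node2)
        if root1 != root2:
            parent[root1] = root2
            tree_edges.append((node1, node2, weight))
    return tree_edges


example_edges = [(0, 2, 2), (0, 3, 6), (1, 2, 5), (1, 4, 1), (2, 3, 2), (3, 4, 3)]
mst = kruskal_mst_edges(example_edges)
print(mst)                                 # [(1, 4, 1), (0, 2, 2), (2, 3, 2), (3, 4, 3)]
print(sum(weight for _, _, weight in mst))  # 8
```

This sketch does not check connectivity: for a disconnected graph it returns a minimum spanning forest rather than None.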