Minimum Spanning Tree Algorithms

Spanning Tree

A spanning tree of a connected, undirected graph is a subgraph that includes all the nodes of the graph and connects them using the minimum possible number of edges. A spanning tree of a graph with n nodes has exactly n-1 edges, and it does not contain any cycles.

Properties of Spanning Tree

  • All the nodes of the graph are connected
  • Contains the minimum number of edges (n-1 for n nodes)
  • The edges that form the tree do not form any cycles
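These properties can be checked directly. The sketch below is an illustration, not code from this article: the helper name `is_spanning_tree` and the `(node1, node2, weight)` edge format are assumptions. It counts the edges, then runs a breadth-first search to confirm every node is reached. A connected subgraph on n nodes with exactly n-1 edges cannot contain a cycle, so the third property follows from the first two.

```python
from collections import deque
from typing import Dict, List, Set, Tuple

Edge = Tuple[int, int, int]  # assumed format: (node1, node2, weight)


def is_spanning_tree(nodes: Set[int], tree_edges: List[Edge]) -> bool:
    """Return True if tree_edges form a spanning tree over nodes (hypothetical helper)."""
    # A spanning tree on n nodes has exactly n - 1 edges.
    if not nodes or len(tree_edges) != len(nodes) - 1:
        return False
    adjacency: Dict[int, List[int]] = {node: [] for node in nodes}
    for node1, node2, _ in tree_edges:
        if node1 not in adjacency or node2 not in adjacency:
            return False  # The edge touches a node outside the graph
        adjacency[node1].append(node2)
        adjacency[node2].append(node1)
    # Breadth-first search from any node to check that every node is reached.
    start = next(iter(nodes))
    seen = {start}
    queue = deque([start])
    while queue:
        node = queue.popleft()
        for neighbor in adjacency[node]:
            if neighbor not in seen:
                seen.add(neighbor)
                queue.append(neighbor)
    # Connected with exactly n - 1 edges implies there are no cycles.
    return seen == nodes


nodes = {0, 1, 2, 3, 4}
print(is_spanning_tree(nodes, [(4, 1, 1), (0, 2, 2), (3, 2, 2), (4, 3, 3)]))  # True
print(is_spanning_tree(nodes, [(0, 2, 2), (3, 2, 2), (0, 3, 6), (4, 1, 1)]))  # False
```

The second edge set has the right number of edges, but it contains the cycle 0-2-3 and leaves nodes 1 and 4 cut off from the rest.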

Minimum Spanning Tree

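A minimum spanning tree (MST) is a spanning tree whose total edge weight is as small as possible. The minimum spanning tree does not have to be unique: a graph can have several different spanning trees with the same minimal total weight.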
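To see non-uniqueness concretely, here is a brute-force sketch. It is exponential and meant for illustration only, and `all_minimum_spanning_trees` is a hypothetical helper, not part of the algorithms below. It tries every set of n-1 edges, keeps the acyclic ones (which are spanning trees), and returns all of those with the smallest total weight.

```python
from itertools import combinations
from typing import List, Tuple

Edge = Tuple[str, str, int]  # assumed format: (node1, node2, weight)


def all_minimum_spanning_trees(nodes: List[str], edges: List[Edge]) -> List[List[Edge]]:
    """Brute force: try every subset of n - 1 edges and keep the cheapest spanning trees."""
    best_weight = None
    best_trees: List[List[Edge]] = []
    for subset in combinations(edges, len(nodes) - 1):
        # component[node] is a label for the component that node belongs to.
        component = {node: node for node in nodes}
        is_tree = True
        for node1, node2, _ in subset:
            root1, root2 = component[node1], component[node2]
            if root1 == root2:
                is_tree = False  # Both endpoints already connected: this edge closes a cycle
                break
            # Merge: relabel every node in root1's component as root2.
            for node in nodes:
                if component[node] == root1:
                    component[node] = root2
        if not is_tree:
            continue
        weight = sum(w for _, _, w in subset)
        if best_weight is None or weight < best_weight:
            best_weight, best_trees = weight, [list(subset)]
        elif weight == best_weight:
            best_trees.append(list(subset))
    return best_trees


triangle = [("A", "B", 1), ("B", "C", 1), ("A", "C", 1)]
print(len(all_minimum_spanning_trees(["A", "B", "C"], triangle)))  # 3
distinct = [("A", "B", 1), ("B", "C", 2), ("A", "C", 3)]
print(all_minimum_spanning_trees(["A", "B", "C"], distinct))  # [[('A', 'B', 1), ('B', 'C', 2)]]
```

On a triangle where all three edges weigh 1, any two edges form an MST of weight 2, so there are three MSTs. When all edge weights are distinct, the MST is unique.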

We will introduce two algorithms to compute the minimum spanning tree of a graph:

  • Prim's Algorithm
  • Kruskal’s Algorithm

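Prim's Algorithm

Prim's algorithm starts from a single node and grows the minimum spanning tree one node at a time. At each step it looks at the frontier: the edges that connect a node already in the tree to an unvisited node (a frontier node). It picks the frontier edge with the smallest weight and adds the node at its other end to the tree. The traversal resembles breadth-first search, but a priority queue (min-heap) decides which node comes next instead of a FIFO queue. This repeats until all reachable nodes have been added.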

Algorithm

Step 1: Pick a starting node

Step 2: Mark the selected node as visited, then add every edge from it to an unvisited node to the set of candidate edges.

Step 3: Pick the smallest-weight candidate edge that leads from a visited node to an unvisited node, select that unvisited node, and repeat Step 2.

Step 4: Repeat until no candidate edge leads from a visited node to an unvisited node.

Note: If the tree ends up with fewer nodes than the graph, the graph is not connected, so no spanning tree exists.

from typing import Union, List, Tuple, Dict
from collections import defaultdict
import heapq

Node = Union[int, str]

class Prims:
    def __init__(self, edges: List[Tuple[Node, Node, int]]) -> None:
        self.adjacent_map = self.__initialize_adjacent_map(edges)

    def __initialize_adjacent_map(self, edges: List[Tuple[Node, Node, int]]) -> Dict[Node, List[Tuple[int, Node]]]:
        # Map each node to a list of (weight, neighbor) pairs.
        adjacent_map = defaultdict(list)
        for node1, node2, weight in edges:
            adjacent_map[node1].append((weight, node2))
            adjacent_map[node2].append((weight, node1))  # The graph is undirected
        return dict(adjacent_map)

    def find_minimum_weight(self, start_node: Node) -> int:
        # Run Prim's algorithm from start_node and return the sum of the
        # edge weights of the minimum spanning tree.
        if start_node not in self.adjacent_map:
            raise ValueError("Node does not exist in graph")

        # Min-heap of (edge weight, node) entries for the frontier.
        min_heap = []
        visited = set()
        total_weight = 0

        # The start node joins the tree at cost 0.
        heapq.heappush(min_heap, (0, start_node))

        while min_heap:
            edge_weight, current_node = heapq.heappop(min_heap)
            # Skip stale entries for nodes already added through a cheaper edge.
            if current_node not in visited:
                visited.add(current_node)
                total_weight += edge_weight
                for adjacent_node_edge_length, adjacent_node in self.adjacent_map[current_node]:
                    if adjacent_node not in visited:
                        heapq.heappush(min_heap, (adjacent_node_edge_length, adjacent_node))

        # If some nodes were never reached, no spanning tree exists.
        if len(visited) != len(self.adjacent_map):
            raise ValueError("Graph is not fully connected.")
        return total_weight


edges = [
    # node1, node2, weight
    (0, 2, 2),
    (3, 0, 6),
    (2, 1, 5),
    (4, 1, 1),
    (3, 2, 2),
    (4, 3, 3),
]
prims = Prims(edges)
print(prims.find_minimum_weight(0))  # 8
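The class above returns only the total weight. If you also need the tree itself, one option is to store the source node in each heap entry so the chosen edge can be recorded when a node is visited. This is a sketch, and `prim_mst_edges` is a hypothetical function name, not part of the `Prims` class.

```python
import heapq
from collections import defaultdict
from typing import Dict, List, Tuple

Edge = Tuple[int, int, int]  # assumed format: (node1, node2, weight)


def prim_mst_edges(edges: List[Edge], start_node: int) -> List[Edge]:
    """Return the edges Prim's algorithm adds to the MST, in the order they are added."""
    adjacent_map: Dict[int, List[Tuple[int, int]]] = defaultdict(list)
    for node1, node2, weight in edges:
        adjacent_map[node1].append((weight, node2))
        adjacent_map[node2].append((weight, node1))
    if start_node not in adjacent_map:
        raise ValueError("Node does not exist in graph")

    visited = {start_node}
    # Heap entries are (weight, from_node, to_node) so the chosen edge can be recorded.
    min_heap = [(weight, start_node, neighbor) for weight, neighbor in adjacent_map[start_node]]
    heapq.heapify(min_heap)
    mst_edges: List[Edge] = []
    while min_heap and len(visited) < len(adjacent_map):
        weight, from_node, to_node = heapq.heappop(min_heap)
        if to_node in visited:
            continue  # Stale entry: to_node was already added through a cheaper edge
        visited.add(to_node)
        mst_edges.append((from_node, to_node, weight))
        for next_weight, neighbor in adjacent_map[to_node]:
            if neighbor not in visited:
                heapq.heappush(min_heap, (next_weight, to_node, neighbor))
    return mst_edges


example_edges = [(0, 2, 2), (3, 0, 6), (2, 1, 5), (4, 1, 1), (3, 2, 2), (4, 3, 3)]
mst = prim_mst_edges(example_edges, 0)
print(mst)  # [(0, 2, 2), (2, 3, 2), (3, 4, 3), (4, 1, 1)]
print(sum(weight for _, _, weight in mst))  # 8
```

Summing the weights of the returned edges gives the same total as `find_minimum_weight(0)` on the same graph.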

Difference between Prim's and Dijkstra's

Prim's and Dijkstra's algorithms look similar: both grow a set of visited nodes outward from a starting node, using a priority queue to pick the unvisited node with the smallest key at each step. Let's look at how they differ.

  • Prim's Algorithm: Used to find the minimum spanning tree (MST) in a weighted, undirected graph. It starts from any node and, at each step, adds the unvisited node that can be attached to the tree with the cheapest edge, so that all nodes end up connected with the minimum possible total edge weight.
  • Dijkstra's Algorithm: Used to find the shortest path from a single source to all other nodes in a graph with non-negative weights. It calculates the shortest path by continuously updating the cumulative cost to reach each node from the source, selecting the node with the smallest known distance at each step.

Distance Calculation:

  • Prim's: The distance tracked for each node is the minimum edge weight by which it can be connected to the growing MST. It only considers the weight of the connecting edge from the MST to a new node.
  • Dijkstra: The distance for each node is the sum of the weights from the source node along the shortest path found so far. It accumulates total path weights from the source to each node.

Prim's algorithm produces a tree that connects all the nodes with the minimum total edge weight.

Dijkstra's algorithm produces the shortest route from the source to each node. The resulting shortest-path tree is not necessarily a minimum spanning tree.
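A small example makes the difference visible. In the triangle below (a graph chosen only for illustration), Prim's algorithm avoids the heavy A-C edge, while Dijkstra's algorithm keeps it because the direct route A-C (3) is shorter than A-B-C (4). The two functions are sketches with hypothetical names; they differ only in the heap key: the single connecting edge weight for Prim's versus the cumulative distance from the source for Dijkstra's.

```python
import heapq
from typing import Dict, List, Tuple

# Adjacency list: node -> [(neighbor, weight)]
Graph = Dict[str, List[Tuple[str, int]]]


def prim_tree(graph: Graph, source: str) -> List[Tuple[str, str, int]]:
    """MST edges: the heap key is only the weight of the connecting edge."""
    visited, tree = set(), []
    heap = [(0, source, source)]  # (key, parent, node)
    while heap:
        key, parent, node = heapq.heappop(heap)
        if node in visited:
            continue
        visited.add(node)
        if node != source:
            tree.append((parent, node, key))
        for neighbor, weight in graph[node]:
            if neighbor not in visited:
                heapq.heappush(heap, (weight, node, neighbor))
    return tree


def dijkstra_tree(graph: Graph, source: str) -> List[Tuple[str, str, int]]:
    """Shortest-path tree edges: the heap key is the total distance from source."""
    visited, tree = set(), []
    heap = [(0, source, source, 0)]  # (distance, parent, node, last edge weight)
    while heap:
        distance, parent, node, edge_weight = heapq.heappop(heap)
        if node in visited:
            continue
        visited.add(node)
        if node != source:
            tree.append((parent, node, edge_weight))
        for neighbor, weight in graph[node]:
            if neighbor not in visited:
                heapq.heappush(heap, (distance + weight, node, neighbor, weight))
    return tree


# Triangle: A-B (2), B-C (2), A-C (3)
graph: Graph = {
    "A": [("B", 2), ("C", 3)],
    "B": [("A", 2), ("C", 2)],
    "C": [("A", 3), ("B", 2)],
}
print(prim_tree(graph, "A"))      # [('A', 'B', 2), ('B', 'C', 2)]
print(dijkstra_tree(graph, "A"))  # [('A', 'B', 2), ('A', 'C', 3)]
```

The MST weighs 4, while the shortest-path tree from A weighs 5: each answers a different question about the same graph.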

Kruskal’s Algorithm

Kruskal’s algorithm starts by sorting all the edges in increasing order of weight. It then builds the minimum spanning tree by considering the edges one at a time and adding each edge only if it does not form a cycle with the edges already chosen. We continue until all the nodes are connected (n-1 edges have been added) or we run out of edges.

Intuition

The core idea behind Kruskal’s algorithm is to build the MST by selecting edges in ascending order of weight while never forming a cycle. Each accepted edge is the cheapest remaining edge that joins two separate components of the growing forest, and such an edge can always be part of a minimum spanning tree (the cut property). Skipping edges that would close a loop keeps the result a tree.

Algorithm

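  1. Sort all edges in the graph by weight in ascending order.
  2. Take the lowest-weight edge that has not been considered yet.
  3. Add the edge to the tree if it does not create a cycle; otherwise discard it.
  4. Repeat from step 2 until the tree has n-1 edges or no edges remain.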
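The steps above can be traced with a deliberately simple cycle check: an edge (u, v) closes a cycle exactly when v is already reachable from u through the edges accepted so far. The sketch below (the names `connected` and `kruskal_trace` are hypothetical) checks this with a depth-first search, which costs O(n) per edge; Union Find makes that check much cheaper. It also scans every edge instead of stopping at n-1, so that rejected edges show up in the output.

```python
from typing import Dict, List, Set, Tuple

Edge = Tuple[int, int, int]  # assumed format: (node1, node2, weight)


def connected(adjacency: Dict[int, Set[int]], source: int, target: int) -> bool:
    """Depth-first search: is target reachable from source using the edges chosen so far?"""
    stack, seen = [source], {source}
    while stack:
        node = stack.pop()
        if node == target:
            return True
        for neighbor in adjacency.get(node, set()):
            if neighbor not in seen:
                seen.add(neighbor)
                stack.append(neighbor)
    return False


def kruskal_trace(edges: List[Edge]) -> Tuple[List[Edge], List[Edge]]:
    """Follow the steps above and return (accepted edges, rejected edges)."""
    accepted: List[Edge] = []
    rejected: List[Edge] = []
    adjacency: Dict[int, Set[int]] = {}
    # Step 1: sort by weight (sorted() is stable, so ties keep their input order).
    for node1, node2, weight in sorted(edges, key=lambda edge: edge[2]):
        # Step 3: the edge closes a cycle iff its endpoints are already connected.
        if connected(adjacency, node1, node2):
            rejected.append((node1, node2, weight))
            continue
        adjacency.setdefault(node1, set()).add(node2)
        adjacency.setdefault(node2, set()).add(node1)
        accepted.append((node1, node2, weight))
    return accepted, rejected


example_edges = [(0, 2, 2), (0, 3, 6), (1, 2, 5), (1, 4, 1), (2, 3, 2), (3, 4, 3)]
accepted, rejected = kruskal_trace(example_edges)
print(accepted)  # [(1, 4, 1), (0, 2, 2), (2, 3, 2), (3, 4, 3)]
print(rejected)  # [(1, 2, 5), (0, 3, 6)]
```

The four accepted edges sum to 8, and both rejected edges would have joined nodes that were already connected.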

Cycle Detection

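We will use Union Find (a disjoint-set data structure) for cycle detection: every node starts in its own set, and an edge (u, v) would form a cycle exactly when u and v are already in the same set. Using Union Find for cycle detection is covered in more detail in a separate article.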
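A minimal sketch of that idea follows. The `has_cycle` helper is hypothetical and separate from the `UnionFind` class in the implementation. Each node starts as its own root, `find` walks up to the root (with path halving to keep trees shallow), and an edge whose endpoints already share a root closes a cycle.

```python
from typing import Dict, Hashable, List, Tuple


def has_cycle(edges: List[Tuple[Hashable, Hashable]]) -> bool:
    """Return True if the undirected edges contain a cycle, using Union Find."""
    parent: Dict[Hashable, Hashable] = {}

    def find(node: Hashable) -> Hashable:
        parent.setdefault(node, node)  # A new node starts in its own set
        # Path halving: point each visited node at its grandparent while walking up.
        while parent[node] != node:
            parent[node] = parent[parent[node]]
            node = parent[node]
        return node

    for node1, node2 in edges:
        root1, root2 = find(node1), find(node2)
        if root1 == root2:
            return True  # Both endpoints are already in the same set: this edge closes a cycle
        parent[root1] = root2  # Union: merge the two sets
    return False


print(has_cycle([("A", "B"), ("B", "C"), ("C", "A")]))  # True
print(has_cycle([("A", "B"), ("B", "C"), ("C", "D")]))  # False
```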

Code Implementation

from typing import Union, List, Tuple, Dict, Set, Optional
import heapq

Node = Union[int, str]


class Kruskal:
    def __init__(self, edges: List[Tuple[Node, Node, int]]) -> None:
        self.nodes = self.__get_nodes(edges)
        self.minHeap = self.__get_min_heap(edges)

    def __get_nodes(self, edges: List[Tuple[Node, Node, int]]) -> Set[Node]:
        nodes = set()
        for node1, node2, _ in edges:
            nodes.add(node1)
            nodes.add(node2)
        return nodes

    def __get_min_heap(self, edges: List[Tuple[Node, Node, int]]) -> List[Tuple[int, Node, Node]]:
        # Reorder each edge as (weight, node1, node2) so the heap is keyed by weight.
        # Popping from a min-heap yields edges in ascending weight order, like sorting.
        minHeap = [(weight, node1, node2) for node1, node2, weight in edges]
        heapq.heapify(minHeap)
        return minHeap

    def find_minimum_weight(self) -> Optional[int]:
        # Returns the total MST weight, or None if the graph is not connected.
        uf = UnionFind(self.nodes)
        minHeap = self.minHeap.copy()  # Copy so the method can be called more than once

        size = len(self.nodes)
        minWeight = 0
        edgeCount = 0
        # A spanning tree on `size` nodes has exactly size - 1 edges.
        while minHeap and edgeCount < size - 1:
            edgeWeight, node1, node2 = heapq.heappop(minHeap)
            # Accept the edge only if its endpoints are in different components (no cycle).
            if uf.find(node1) != uf.find(node2):
                minWeight += edgeWeight
                edgeCount += 1
                uf.union(node1, node2)

        return minWeight if edgeCount == size - 1 else None


class UnionFind:
    def __init__(self, elements: Set[Node]) -> None:
        self.parent = self.__initialize_parent(elements)
        self.size = self.__initialize_size(elements)

    def __initialize_parent(self, elements: Set[Node]) -> Dict[Node, Node]:
        return {element: element for element in elements}

    def __initialize_size(self, elements: Set[Node]) -> Dict[Node, int]:
        return {element: 1 for element in elements}

    def find(self, element: Node) -> Node:
        # Follow parent pointers to the root, compressing the path along the way.
        if element != self.parent[element]:
            self.parent[element] = self.find(self.parent[element])
        return self.parent[element]

    def union(self, i: Node, j: Node) -> None:
        iRoot = self.find(i)
        jRoot = self.find(j)
        if iRoot == jRoot:
            return

        # Union by size: attach the smaller tree under the larger one.
        if self.size[iRoot] > self.size[jRoot]:
            self.parent[jRoot] = iRoot
            self.size[iRoot] += self.size[jRoot]
        else:
            self.parent[iRoot] = jRoot
            self.size[jRoot] += self.size[iRoot]


edges = [
    # node1, node2, weight
    (0, 2, 2),
    (0, 3, 6),
    (1, 2, 5),
    (1, 4, 1),
    (2, 3, 2),
    (3, 4, 3),
]

kruskal = Kruskal(edges)
print(kruskal.find_minimum_weight())  # 8

More articles by Yi leng Yao