-
Notifications
You must be signed in to change notification settings - Fork 1k
/
Copy pathkruskal_algorithm2.py
140 lines (112 loc) · 4.46 KB
/
kruskal_algorithm2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#! /usr/bin/env python
#coding:utf-8
"""
以下代码参考http://www.ics.uci.edu/~eppstein/PADS/的源码
"""
class UnionFind:
    """
    Disjoint-set (union-find) structure over hashable objects.

    Each UnionFind instance X maintains a family of disjoint sets of
    hashable objects, supporting the following two methods:

    - X[item] returns a name for the set containing the given item.
      Each set is named by an arbitrarily-chosen one of its members; as
      long as the set remains unchanged it will keep the same name. If
      the item is not yet part of a set in X, a new singleton set is
      created for it.

    - X.union(item1, item2, ...) merges the sets containing each item
      into a single larger set. If any item is not yet part of a set
      in X, it is added to X as one of the members of the merged set.

    Items only need to be hashable; they are never compared with each
    other for ordering.
    """

    def __init__(self):
        """Create a new empty union-find structure."""
        # weights[root] is the number of items in the set named by root.
        # Entries for items that are no longer roots are never read again.
        self.weights = {}
        # parents[item] is the next item on the path to the set's root;
        # a root is its own parent.
        self.parents = {}

    def __getitem__(self, item):
        """Find and return the name of the set containing the item."""
        # Previously unknown item: start a new singleton set for it.
        if item not in self.parents:
            self.parents[item] = item
            self.weights[item] = 1
            return item

        # Follow parent pointers up to the root, remembering the path.
        path = [item]
        root = self.parents[item]
        while root != path[-1]:
            path.append(root)
            root = self.parents[root]

        # Path compression: point every visited item directly at the root.
        for ancestor in path:
            self.parents[ancestor] = root
        return root

    def __iter__(self):
        """Iterate through all items ever found or unioned by this structure."""
        return iter(self.parents)

    def union(self, *objects):
        """
        Find the sets containing the objects and merge them all.

        The root of the heaviest (largest) set becomes the root of the
        merged set. Calling union() with no arguments does nothing.
        """
        # Deduplicate roots, keeping their order, so a set passed more than
        # once (e.g. union(a, b, a)) is not added to the merged weight twice.
        roots = list(dict.fromkeys(self[x] for x in objects))
        if not roots:
            return
        # Select by weight alone. Comparing (weight, root) tuples would fall
        # back to ordering the roots themselves on ties, which raises
        # TypeError for unorderable items.
        heaviest = max(roots, key=lambda r: self.weights[r])
        for r in roots:
            if r != heaviest:
                self.weights[heaviest] += self.weights[r]
                self.parents[r] = heaviest
"""
Simple helper functions for graphs.
Each function's input graph G should be represented in such a way that "for v in G" loops through the vertices, and "G[v]" produces a collection of the neighbors of v (supporting iteration and "in" tests); for instance, G may be a dictionary mapping each vertex to its neighbor set.
D. Eppstein, April 2004.
"""
def isUndirected(G):
    """
    Check that G represents a simple undirected graph.

    Returns False if any vertex lists itself as a neighbor (a self-loop),
    if some edge v -> w has no reverse edge w -> v, or if a neighbor w is
    not itself a vertex of G. Returns True otherwise, including for an
    empty graph.
    """
    for v in G:
        if v in G[v]:
            return False  # self-loop: not a simple graph
        for w in G[v]:
            try:
                neighbors_of_w = G[w]
            except LookupError:
                # w is referenced as a neighbor but is not a vertex of G,
                # so the edge v -> w cannot have a reverse edge.
                return False
            if v not in neighbors_of_w:
                return False
    return True
def union(*graphs):
    """
    Merge several graphs into a single adjacency mapping.

    The result maps every vertex seen in any argument graph to the set
    of all neighbors it has in any of them. Edge weights, if present,
    are not preserved; only neighbor identities are kept.
    """
    merged = {}
    for graph in graphs:
        for vertex in graph:
            if vertex not in merged:
                merged[vertex] = set()
            # set.update accepts any iterable, so no intermediate list is needed.
            merged[vertex].update(graph[vertex])
    return merged
"""
Kruskal's algorithm for minimum spanning trees. D. Eppstein, April 2006.
"""
import unittest
def MinimumSpanningTree(G):
    """
    Return the minimum spanning tree of an undirected graph G.

    G should be represented in such a way that iter(G) lists its
    vertices, iter(G[u]) lists the neighbors of u, G[u][v] gives the
    length of edge u,v, and G[u][v] should always equal G[v][u].

    The tree is returned as a list of (u, v) edges in nondecreasing order
    of weight. If G is disconnected, the result is a minimum spanning
    forest (one tree per connected component).

    Vertices only need to be hashable: edges of equal weight are never
    ordered by comparing their endpoints.

    Raises ValueError if G is not a simple undirected graph, or if some
    edge has different weights in its two directions.
    """
    if not isUndirected(G):
        raise ValueError("MinimumSpanningTree: input is not undirected")
    for u in G:
        for v in G[u]:
            if G[u][v] != G[v][u]:
                raise ValueError("MinimumSpanningTree: asymmetric weights")

    # Kruskal's algorithm: sort edges by weight, and add them one at a time.
    # We use Kruskal's algorithm, first because it is very simple to
    # implement once UnionFind exists, and second, because the only slow
    # part (the sort) is sped up by being built in to Python.
    #
    # Each undirected edge appears twice, as (u, v) and (v, u). The second
    # copy is rejected below because its endpoints already share a subtree.
    edges = [(G[u][v], u, v) for u in G for v in G[u]]
    # Sort on the weight alone. Plain tuple comparison would compare the
    # vertices themselves on equal weights, which fails for unorderable
    # vertex types.
    edges.sort(key=lambda edge: edge[0])

    subtrees = UnionFind()
    tree = []
    for _weight, u, v in edges:
        if subtrees[u] != subtrees[v]:
            tree.append((u, v))
            subtrees.union(u, v)
    return tree
# If run standalone, perform unit tests
class MSTTest(unittest.TestCase):
    """Unit tests for MinimumSpanningTree."""

    def testMST(self):
        """Check that MinimumSpanningTree returns the correct answer."""
        G = {0: {1: 11, 2: 13, 3: 12},
             1: {0: 11, 3: 14},
             2: {0: 13, 3: 10},
             3: {0: 12, 1: 14, 2: 10}}
        T = [(2, 3), (0, 1), (0, 3)]
        tree = MinimumSpanningTree(G)
        # zip() stops at the shorter list, so without this check a result
        # with missing or extra edges could still pass.
        self.assertEqual(len(tree), len(T))
        for e, f in zip(tree, T):
            self.assertEqual(min(e), min(f))
            self.assertEqual(max(e), max(f))

    def testEmptyGraph(self):
        """An empty graph has an empty spanning tree."""
        self.assertEqual(MinimumSpanningTree({}), [])

    def testDisconnectedGraph(self):
        """A disconnected graph yields one tree per component."""
        G = {0: {1: 1}, 1: {0: 1}, 2: {3: 2}, 3: {2: 2}}
        edges = set(frozenset(e) for e in MinimumSpanningTree(G))
        self.assertEqual(edges, set([frozenset((0, 1)), frozenset((2, 3))]))

    def testRejectsAsymmetricWeights(self):
        """Mismatched weights for the two directions of an edge are rejected."""
        with self.assertRaises(ValueError):
            MinimumSpanningTree({0: {1: 1}, 1: {0: 2}})
if __name__ == "__main__":
    # Run this module's unit tests when the file is executed directly.
    unittest.main()