Skip to content

Commit

Permalink
Merge pull request #270 from higra/fix_empty_bipartite_graph_matching
Browse files Browse the repository at this point in the history
Fix bipartite graph matching with empty graph
  • Loading branch information
PerretB authored Sep 18, 2023
2 parents 40110ee + 77a8b0e commit e141363
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 2 deletions.
10 changes: 10 additions & 0 deletions higra/algo/bipartite_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,22 @@ def bipartite_graph_matching(graph, edge_weights):
targets = graph[1]
num_vertices = graph[2]
num_edges = len(sources)

if len(targets) != num_edges:
raise ValueError("sources and targets must have the same length")

except Exception as e:
raise TypeError("graph must be an undirected graph or a tuple (sources, targets, num_vertices)", e)

if len(edge_weights) != num_edges:
raise ValueError("edge_weights must have the same length as the number of edges of the graph")

if not check:
raise ValueError("graph must be bipartite")

if num_edges == 0:
return np.array([], dtype=np.int64)

# if edge weights dtype is floating point, emit a warning
if np.issubdtype(edge_weights.dtype, np.floating):
print("Warning: possible loss of precision, edge weights are casted to int64 for bipartite graph matching")
Expand Down
7 changes: 7 additions & 0 deletions test/python/test_algo/test_bipartite_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,13 @@ def test_is_bipartite_graph_false(self):
self.assertFalse(ans)
self.assertTrue(len(color) == 0)

def test_bipartite_graph_matching_empty(self):
g = hg.UndirectedGraph(6)
weights = np.array([])

edges = hg.bipartite_graph_matching(g, weights)
self.assertTrue(len(edges) == 0)

def test_bipartite_graph_matching(self):
g = hg.UndirectedGraph(6)
g.add_edges(np.array([0, 0, 1, 1, 2, 1]), np.array([3, 4, 3, 5, 5, 4]))
Expand Down
11 changes: 11 additions & 0 deletions test/python/test_image/test_graph_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,17 @@ def test_get_nd_regular_graph(self):

self.assertTrue(ref_edges == res_edges)

def test_match_pixels_image_2d_empty(self):
im1 = np.asarray([[1, 0, 0, 0],
[0, 0, 0, 0]])
im2 = np.asarray([[0, 0, 0, 1],
[0, 0, 0, 0]])

sources, targets = hg.match_pixels_image_2d(im1, im2, 1.3, "absolute")

self.assertTrue(len(sources) == 0)
self.assertTrue(len(targets) == 0)

def test_match_pixels_image_2d(self):
im1 = np.asarray([[1, 0, 0, 1],
[0, 0, 0, 1]])
Expand Down
2 changes: 1 addition & 1 deletion tools/azure-pipelines-linux-gcc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ jobs:
- script: |
conda config --set always_yes yes --set changeps1 no
conda update -q conda
conda install tbb-devel==2019.9 scikit-learn==1.0.1 -c conda-forge
conda install tbb-devel==2019.9 scikit-learn -c conda-forge
displayName: Install python packages
- script: |
Expand Down
2 changes: 1 addition & 1 deletion tools/azure-pipelines-mac.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
- script: |
conda config --set always_yes yes --set changeps1 no
conda update -q conda
conda install cmake tbb-devel==2019.9 scikit-learn==1.0.1 -c conda-forge
conda install cmake tbb-devel==2019.9 scikit-learn -c conda-forge
displayName: Install python packages
- script: |
Expand Down

0 comments on commit e141363

Please sign in to comment.