Skip to content

Commit

Permalink
fix eom recursion; add fit_predict
Browse files Browse the repository at this point in the history
  • Loading branch information
JelmerBot committed Dec 28, 2024
1 parent c33329b commit 33628aa
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
12 changes: 8 additions & 4 deletions fast_hdbscan/cluster_trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,17 +497,20 @@ def eom_recursion(node, cluster_tree, node_scores, node_sizes, selected_clusters
current_size = node_sizes[node]

children = cluster_tree.child[cluster_tree.parent == node]
added_children_total = False
child_score_total = 0.0

for child_node in children:
child_score_total += eom_recursion(child_node, cluster_tree, node_scores, node_sizes, selected_clusters, max_cluster_size)
child_score, children_added = eom_recursion(child_node, cluster_tree, node_scores, node_sizes, selected_clusters, max_cluster_size)
added_children_total |= children_added
child_score_total += child_score

if child_score_total > current_score or current_size > max_cluster_size:
return child_score_total
return child_score_total, added_children_total
else:
selected_clusters[node] = True
unselect_below_node(node, cluster_tree, selected_clusters)
return current_score
return current_score, True


@numba.njit()
Expand All @@ -530,7 +533,8 @@ def extract_eom_clusters(condensed_tree, cluster_tree, max_cluster_size=np.inf,
elif len(node_scores) > 1:
root_children = cluster_tree.child[cluster_tree.parent == cluster_tree_root]
for child_node in root_children:
eom_recursion(child_node, cluster_tree, node_scores, node_sizes, selected_clusters, max_cluster_size)
if not eom_recursion(child_node, cluster_tree, node_scores, node_sizes, selected_clusters, max_cluster_size)[1]:
selected_clusters[child_node] = True

return np.asarray([node for node, selected in selected_clusters.items() if selected])

Expand Down
4 changes: 4 additions & 0 deletions fast_hdbscan/hdbscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,10 @@ def fit(self, X, y=None, sample_weight=None, **fit_params):

return self

def fit_predict(self, X, y=None, sample_weight=None, **fit_params):
    """Fit the model to ``X`` and return the resulting cluster labels.

    Convenience wrapper: calls ``self.fit`` with the given arguments and
    then returns the ``labels_`` attribute of this instance.

    Parameters
    ----------
    X, y, sample_weight, **fit_params
        Forwarded unchanged to ``self.fit``; see that method for their
        meaning.

    Returns
    -------
    labels_
        The ``labels_`` attribute as found on the instance after ``fit``
        has run (presumably one cluster label per sample -- confirm
        against ``fit``).
    """
    # `fit` is called for its side effects; its return value is ignored.
    self.fit(X, y, sample_weight, **fit_params)
    return self.labels_

def dbscan_clustering(self, epsilon):
check_is_fitted(
self,
Expand Down

0 comments on commit 33628aa

Please sign in to comment.