Skip to content

Commit

Permalink
fix eom recursion; add fit_predict
Browse files Browse the repository at this point in the history
  • Loading branch information
JelmerBot committed Dec 29, 2024
1 parent ff0b3a5 commit a3a5756
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 12 deletions.
19 changes: 7 additions & 12 deletions fast_hdbscan/cluster_trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,21 +450,16 @@ def extract_clusters_bcubed(condensed_tree, cluster_tree, data_labels, allow_vir

@numba.njit()
def score_condensed_tree_nodes(condensed_tree):
result = {0: np.float32(0.0) for i in range(0)}
root = condensed_tree.parent[0]
result = {root: np.float32(0.0)}

for i in range(condensed_tree.parent.shape[0]):
parent = condensed_tree.parent[i]
if parent in result:
result[parent] += condensed_tree.lambda_val[i] * condensed_tree.child_size[i]
else:
result[parent] = condensed_tree.lambda_val[i] * condensed_tree.child_size[i]

if condensed_tree.child_size[i] > 1:
child = condensed_tree.child[i]
if child in result:
result[child] -= condensed_tree.lambda_val[i] * condensed_tree.child_size[i]
else:
result[child] = -condensed_tree.lambda_val[i] * condensed_tree.child_size[i]
result[child] = -condensed_tree.lambda_val[i] * condensed_tree.child_size[i]

parent = condensed_tree.parent[i]
result[parent] += condensed_tree.lambda_val[i] * condensed_tree.child_size[i]

return result

Expand Down Expand Up @@ -493,7 +488,7 @@ def unselect_below_node(node, cluster_tree, selected_clusters):

@numba.njit(fastmath=True)
def eom_recursion(node, cluster_tree, node_scores, node_sizes, selected_clusters, max_cluster_size):
current_score = node_scores[node]
current_score = max(node_scores[node], 0.0) # floating point errors can make score negative!
current_size = node_sizes[node]

children = cluster_tree.child[cluster_tree.parent == node]
Expand Down
4 changes: 4 additions & 0 deletions fast_hdbscan/hdbscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,10 @@ def fit(self, X, y=None, sample_weight=None, **fit_params):

return self

# Fit the estimator and return its cluster labels in a single call.
# X, y, sample_weight and any extra keyword arguments are forwarded
# unchanged to `fit`; the value returned is the `labels_` attribute read
# from this instance once `fit` has returned.
# NOTE(review): presumably `fit` is what populates `labels_` -- its body is
# not visible in this hunk, confirm against the full file.
def fit_predict(self, X, y=None, sample_weight=None, **fit_params):
self.fit(X, y, sample_weight, **fit_params)
return self.labels_

def dbscan_clustering(self, epsilon):
check_is_fitted(
self,
Expand Down

0 comments on commit a3a5756

Please sign in to comment.