
add branch detection functionality #32

Merged
merged 11 commits on Jan 7, 2025
585 changes: 585 additions & 0 deletions doc/detecting_branches.ipynb

Large diffs are not rendered by default.

445 changes: 445 additions & 0 deletions doc/for_developers.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions doc/index.rst
@@ -76,6 +76,8 @@ User Guide
    basic_usage
    benchmarks
    comparable_clusterings
+   detecting_branches
+   for_developers


 ----------
3 changes: 2 additions & 1 deletion fast_hdbscan/__init__.py
@@ -1,4 +1,5 @@
 from .hdbscan import HDBSCAN, fast_hdbscan
+from .branches import BranchDetector, find_branch_sub_clusters

 # Force JIT compilation on import
 import numpy as np
@@ -7,4 +8,4 @@
 HDBSCAN(allow_single_cluster=True).fit(random_data)
 HDBSCAN(cluster_selection_method="leaf").fit(random_data)

-__all__ = ["HDBSCAN", "fast_hdbscan"]
+__all__ = ["HDBSCAN", "fast_hdbscan", "BranchDetector", "find_branch_sub_clusters"]
58 changes: 34 additions & 24 deletions fast_hdbscan/boruvka.py
@@ -4,26 +4,33 @@
 from .disjoint_set import ds_rank_create, ds_find, ds_union_by_rank
 from .numba_kdtree import parallel_tree_query, rdist, point_to_node_lower_bound_rdist

-@numba.njit(locals={"i": numba.types.int64})
-def merge_components(disjoint_set, candidate_neighbors, candidate_neighbor_distances, point_components):
-    component_edges = {np.int64(0): (np.int64(0), np.int64(1), np.float32(0.0)) for i in range(0)}

+@numba.njit(locals={"parent": numba.types.int32})
+def select_components(candidate_distances, candidate_neighbors, point_components):
+    component_edges = {np.int64(0): (np.int32(0), np.int32(1), np.float32(0.0)) for i in range(0)}

     # Find the best edges from each component
-    for i in range(candidate_neighbors.shape[0]):
-        from_component = np.int64(point_components[i])
+    for parent, (distance, neighbor, from_component) in enumerate(
+        zip(candidate_distances, candidate_neighbors, point_components)
+    ):
         if from_component in component_edges:
-            if candidate_neighbor_distances[i] < component_edges[from_component][2]:
-                component_edges[from_component] = (np.int64(i), np.int64(candidate_neighbors[i]), candidate_neighbor_distances[i])
+            if distance < component_edges[from_component][2]:
+                component_edges[from_component] = (parent, neighbor, distance)
         else:
-            component_edges[from_component] = (np.int64(i), np.int64(candidate_neighbors[i]), candidate_neighbor_distances[i])
+            component_edges[from_component] = (parent, neighbor, distance)

+    return component_edges
Comment on lines -7 to +22 (Collaborator Author):

I split merge_components into select_components and merge_components so I can re-use the latter part to compute the minimum spanning tree of a knn–mst union graph. That process needs to reject invalid neighbour values within the select_components part.


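As a reference for how the two halves now compose, here is a minimal sketch with toy inputs (it assumes ds_rank_create(n), imported at the top of this module, builds a fresh disjoint set over n points; the arrays and values are made up for illustration):

import numpy as np
from fast_hdbscan.boruvka import select_components, merge_components
from fast_hdbscan.disjoint_set import ds_rank_create

# Toy state: four points, each currently its own component, with a
# best candidate neighbour and the distance to that neighbour.
point_components = np.array([0, 1, 2, 3], dtype=np.int64)
candidate_neighbors = np.array([1, 0, 3, 2], dtype=np.int32)
candidate_distances = np.array([0.5, 0.5, 0.4, 0.4], dtype=np.float32)

disjoint_set = ds_rank_create(4)

# Phase 1: keep only the cheapest outgoing edge per component.
component_edges = select_components(candidate_distances, candidate_neighbors, point_components)

# Phase 2: union the endpoints; the mirrored duplicates (1 -> 0 and
# 3 -> 2) are rejected because both endpoints already share a root.
new_edges = merge_components(disjoint_set, component_edges)  # two edges survive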

+@numba.njit()
+def merge_components(disjoint_set, component_edges):
     result = np.empty((len(component_edges), 3), dtype=np.float64)
     result_idx = 0

     # Add the best edges to the edge set and merge the relevant components
     for edge in component_edges.values():
-        from_component = ds_find(disjoint_set, np.int32(edge[0]))
-        to_component = ds_find(disjoint_set, np.int32(edge[1]))
+        from_component = ds_find(disjoint_set, edge[0])
+        to_component = ds_find(disjoint_set, edge[1])
         if from_component != to_component:
             result[result_idx] = (np.float64(edge[0]), np.float64(edge[1]), np.float64(edge[2]))
             result_idx += 1
@@ -34,10 +41,13 @@ def merge_components(disjoint_set, candidate_neighbors, candidate_neighbor_distances, point_components):


 @numba.njit(parallel=True)
-def update_component_vectors(tree, disjoint_set, node_components, point_components):
+def update_point_components(disjoint_set, point_components):
     for i in numba.prange(point_components.shape[0]):
         point_components[i] = ds_find(disjoint_set, np.int32(i))


+@numba.njit()
+def update_node_components(tree, node_components, point_components):
Comment on lines -37 to +50 (Collaborator Author):

Similarly, I split update_component_vectors into update_point_components and update_node_components so I can re-use the former part.

     for i in range(tree.node_data.shape[0] - 1, -1, -1):
         node_info = tree.node_data[i]

@@ -272,28 +282,28 @@ def parallel_boruvka(tree, min_samples=10, sample_weights=None):
         expected_neighbors = min_samples / mean_sample_weight
         distances, neighbors = parallel_tree_query(tree, tree.data, k=int(2 * expected_neighbors))
         core_distances = sample_weight_core_distance(distances, neighbors, sample_weights, min_samples)
-        edges = initialize_boruvka_from_knn(neighbors, distances, core_distances, components_disjoint_set)
-        update_component_vectors(tree, components_disjoint_set, node_components, point_components)
     else:
         if min_samples > 1:
             distances, neighbors = parallel_tree_query(tree, tree.data, k=min_samples + 1, output_rdist=True)
             core_distances = distances.T[-1]
-            edges = initialize_boruvka_from_knn(neighbors, distances, core_distances, components_disjoint_set)
-            update_component_vectors(tree, components_disjoint_set, node_components, point_components)
         else:
             core_distances = np.zeros(tree.data.shape[0], dtype=np.float32)
             distances, neighbors = parallel_tree_query(tree, tree.data, k=2)
-            edges = initialize_boruvka_from_knn(neighbors, distances, core_distances, components_disjoint_set)
-            update_component_vectors(tree, components_disjoint_set, node_components, point_components)

-    while n_components > 1:
+    edges = [np.empty((0, 3), dtype=np.float64) for _ in range(0)]
+    new_edges = initialize_boruvka_from_knn(neighbors, distances, core_distances, components_disjoint_set)
+    while True:
+        edges.append(new_edges)
+        n_components -= new_edges.shape[0]
+        if n_components == 1:
+            break
+        update_point_components(components_disjoint_set, point_components)
+        update_node_components(tree, node_components, point_components)
         candidate_distances, candidate_indices = boruvka_tree_query(tree, node_components, point_components,
                                                                     core_distances)
-        new_edges = merge_components(components_disjoint_set, candidate_indices, candidate_distances, point_components)
-        update_component_vectors(tree, components_disjoint_set, node_components, point_components)

-        edges = np.vstack((edges, new_edges))
-        n_components = np.unique(point_components).shape[0]
+        component_edges = select_components(candidate_distances, candidate_indices, point_components)
+        new_edges = merge_components(components_disjoint_set, component_edges)

+    edges = np.vstack(edges)
Collaborator Author:

Refactor to call vstack only once, avoiding multiple allocations.

     edges[:, 2] = np.sqrt(edges.T[2])
-    return edges
+    return edges, neighbors[:, 1:], np.sqrt(core_distances)
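The pattern behind that refactor, in isolation: collect each round's edge batch in a Python list and concatenate once at the end, instead of growing an array with vstack inside the loop. A self-contained illustration in plain numpy (not code from this PR):

import numpy as np

rng = np.random.default_rng(0)
batches = []
for _ in range(100):
    # each Boruvka round contributes a small (k, 3) batch of edges
    batches.append(rng.random((10, 3)))

# one allocation and one copy at the end ...
edges = np.vstack(batches)
assert edges.shape == (1000, 3)

# ... whereas edges = np.vstack((edges, batch)) inside the loop would
# re-copy every previously collected edge on each iteration.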
148 changes: 148 additions & 0 deletions fast_hdbscan/branches.py
@@ -0,0 +1,148 @@
import numpy as np
from .sub_clusters import SubClusterDetector, find_sub_clusters


def compute_centrality(data, probabilities, *args):
    # Centrality is the inverse distance to the cluster's
    # probability-weighted centroid.
    points = args[-1]
    cluster_data = data[points, :]
    centroid = np.average(cluster_data, weights=probabilities[points], axis=0)
    return 1 / np.linalg.norm(cluster_data - centroid[None, :], axis=1)


def apply_branch_threshold(
    labels,
    branch_labels,
    probabilities,
    cluster_probabilities,
    cluster_points,
    linkage_trees,
    label_sides_as_branches=False,
):
    running_id = 0
    # A cluster needs more than min_branch_count branches to be split up.
    min_branch_count = 1 if label_sides_as_branches else 2
    for pts, tree in zip(cluster_points, linkage_trees):
        unique_branch_labels = np.unique(branch_labels[pts])
        has_noise = int(unique_branch_labels[0] == -1)
        num_branches = len(unique_branch_labels) - has_noise
        if num_branches <= min_branch_count and tree is not None:
            # Too few branches: keep the cluster whole under a single label.
            labels[pts] = running_id
            probabilities[pts] = cluster_probabilities[pts]
            running_id += 1
        else:
            # Give each branch (and the noise group, if any) its own label.
            labels[pts] = branch_labels[pts] + has_noise + running_id
            running_id += num_branches + has_noise


def find_branch_sub_clusters(
    clusterer,
    cluster_labels=None,
    cluster_probabilities=None,
    label_sides_as_branches=False,
    min_cluster_size=None,
    max_cluster_size=None,
    allow_single_cluster=None,
    cluster_selection_method=None,
    cluster_selection_epsilon=0.0,
    cluster_selection_persistence=0.0,
):
    result = find_sub_clusters(
        clusterer,
        cluster_labels,
        cluster_probabilities,
        lens_callback=compute_centrality,
        min_cluster_size=min_cluster_size,
        max_cluster_size=max_cluster_size,
        allow_single_cluster=allow_single_cluster,
        cluster_selection_method=cluster_selection_method,
        cluster_selection_epsilon=cluster_selection_epsilon,
        cluster_selection_persistence=cluster_selection_persistence,
    )
    apply_branch_threshold(
        result[0],
        result[4],
        result[1],
        result[3],
        result[-2],
        result[-1],
        label_sides_as_branches=label_sides_as_branches,
    )
    return result


class BranchDetector(SubClusterDetector):
    """
    Performs a flare-detection post-processing step to detect branches within
    clusters [1]_.

    For each cluster, a graph is constructed connecting the data points based on
    their mutual reachability distances. Each edge is given a centrality value
    based on how far it lies from the cluster's center. Then, the edges are
    clustered as if that centrality were a distance, progressively removing the
    'center' of each cluster and seeing how many branches remain.

    References
    ----------
    .. [1] Bot, D. M., Peeters, J., Liesenborgs, J., & Aerts, J. (2023, November).
       FLASC: A Flare-Sensitive Clustering Algorithm: Extending HDBSCAN* for
       Detecting Branches in Clusters. arXiv:2311.15887.
    """

    def __init__(
        self,
        min_cluster_size=None,
        max_cluster_size=None,
        allow_single_cluster=None,
        cluster_selection_method=None,
        cluster_selection_epsilon=0.0,
        cluster_selection_persistence=0.0,
        propagate_labels=False,
        label_sides_as_branches=False,
    ):
        super().__init__(
            min_cluster_size=min_cluster_size,
            max_cluster_size=max_cluster_size,
            allow_single_cluster=allow_single_cluster,
            cluster_selection_method=cluster_selection_method,
            cluster_selection_epsilon=cluster_selection_epsilon,
            cluster_selection_persistence=cluster_selection_persistence,
            propagate_labels=propagate_labels,
        )
        self.label_sides_as_branches = label_sides_as_branches

    def fit(self, clusterer, labels=None, probabilities=None, sample_weight=None):
        super().fit(clusterer, labels, probabilities, sample_weight, compute_centrality)
        apply_branch_threshold(
            self.labels_,
            self.sub_cluster_labels_,
            self.probabilities_,
            self.cluster_probabilities_,
            self.cluster_points_,
            self._linkage_trees,
            label_sides_as_branches=self.label_sides_as_branches,
        )
        self.branch_labels_ = self.sub_cluster_labels_
        self.branch_probabilities_ = self.sub_cluster_probabilities_
        self.centralities_ = self.lens_values_
        return self

    def propagated_labels(self, label_sides_as_branches=None):
        if label_sides_as_branches is None:
            label_sides_as_branches = self.label_sides_as_branches

        labels, branch_labels = super().propagated_labels()
        apply_branch_threshold(
            labels,
            branch_labels,
            np.zeros_like(self.probabilities_),
            np.zeros_like(self.probabilities_),
            self.cluster_points_,
            self._linkage_trees,
            label_sides_as_branches=label_sides_as_branches,
        )
        return labels, branch_labels

    @property
    def approximation_graph_(self):
        """See :class:`~hdbscan.plots.ApproximationGraph` for documentation."""
        return super()._make_approximation_graph(
            lens_name="centrality", sub_cluster_name="branch"
        )
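Taken together with the __init__.py change above, the new exports support a two-step workflow: fit HDBSCAN as usual, then post-process the fitted model. A minimal usage sketch (the stand-in data and parameter values here are illustrative, not taken from this PR):

import numpy as np
from fast_hdbscan import HDBSCAN, BranchDetector

# Stand-in data; any (n_samples, n_features) float array works here.
data = np.random.default_rng(42).normal(size=(500, 2))

clusterer = HDBSCAN(min_cluster_size=25).fit(data)

# Detect branches within the clusters found above.
detector = BranchDetector(label_sides_as_branches=False).fit(clusterer)

labels = detector.labels_                # combined cluster-and-branch labels
branch_labels = detector.branch_labels_  # branch ids within each cluster
centralities = detector.centralities_    # per-point centrality (lens) values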