Commit 2f83774
Merge pull request #32 from JelmerBot/dev/branches
add branch detection functionality
lmcinnes authored Jan 7, 2025
2 parents 5ca7be0 + 1c5ff02 commit 2f83774
Showing 14 changed files with 2,749 additions and 106 deletions.
585 changes: 585 additions & 0 deletions doc/detecting_branches.ipynb

Large diffs are not rendered by default.

445 changes: 445 additions & 0 deletions doc/for_developers.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions doc/index.rst
@@ -76,6 +76,8 @@ User Guide
    basic_usage
    benchmarks
    comparable_clusterings
+   detecting_branches
+   for_developers
 
 
 ----------
3 changes: 2 additions & 1 deletion fast_hdbscan/__init__.py
@@ -1,4 +1,5 @@
 from .hdbscan import HDBSCAN, fast_hdbscan
+from .branches import BranchDetector, find_branch_sub_clusters
 
 # Force JIT compilation on import
 import numpy as np
@@ -7,4 +8,4 @@
 HDBSCAN(allow_single_cluster=True).fit(random_data)
 HDBSCAN(cluster_selection_method="leaf").fit(random_data)
 
-__all__ = ["HDBSCAN", "fast_hdbscan"]
+__all__ = ["HDBSCAN", "fast_hdbscan", "BranchDetector", "find_branch_sub_clusters"]
58 changes: 34 additions & 24 deletions fast_hdbscan/boruvka.py
@@ -4,26 +4,33 @@
 from .disjoint_set import ds_rank_create, ds_find, ds_union_by_rank
 from .numba_kdtree import parallel_tree_query, rdist, point_to_node_lower_bound_rdist
 
-@numba.njit(locals={"i": numba.types.int64})
-def merge_components(disjoint_set, candidate_neighbors, candidate_neighbor_distances, point_components):
-    component_edges = {np.int64(0): (np.int64(0), np.int64(1), np.float32(0.0)) for i in range(0)}
+
+@numba.njit(locals={"parent": numba.types.int32})
+def select_components(candidate_distances, candidate_neighbors, point_components):
+    component_edges = {np.int64(0): (np.int32(0), np.int32(1), np.float32(0.0)) for i in range(0)}
 
     # Find the best edges from each component
-    for i in range(candidate_neighbors.shape[0]):
-        from_component = np.int64(point_components[i])
+    for parent, (distance, neighbor, from_component) in enumerate(
+        zip(candidate_distances, candidate_neighbors, point_components)
+    ):
         if from_component in component_edges:
-            if candidate_neighbor_distances[i] < component_edges[from_component][2]:
-                component_edges[from_component] = (np.int64(i), np.int64(candidate_neighbors[i]), candidate_neighbor_distances[i])
+            if distance < component_edges[from_component][2]:
+                component_edges[from_component] = (parent, neighbor, distance)
         else:
-            component_edges[from_component] = (np.int64(i), np.int64(candidate_neighbors[i]), candidate_neighbor_distances[i])
+            component_edges[from_component] = (parent, neighbor, distance)
 
+    return component_edges
+
+
+@numba.njit()
+def merge_components(disjoint_set, component_edges):
     result = np.empty((len(component_edges), 3), dtype=np.float64)
     result_idx = 0
 
     # Add the best edges to the edge set and merge the relevant components
     for edge in component_edges.values():
-        from_component = ds_find(disjoint_set, np.int32(edge[0]))
-        to_component = ds_find(disjoint_set, np.int32(edge[1]))
+        from_component = ds_find(disjoint_set, edge[0])
+        to_component = ds_find(disjoint_set, edge[1])
         if from_component != to_component:
             result[result_idx] = (np.float64(edge[0]), np.float64(edge[1]), np.float64(edge[2]))
             result_idx += 1
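The zero-iteration dict comprehension above is a numba idiom: looping over range(0) inserts nothing, but the literal key and value expressions let numba infer a concrete typed-dict signature for component_edges. A standalone sketch of the same trick (illustrative, not part of this diff):

import numba
import numpy as np

@numba.njit()
def empty_typed_dict_demo():
    # The comprehension never runs, so the dict starts empty, but numba
    # infers Dict[int64, Tuple(int32, int32, float32)] from the literals.
    edges = {np.int64(0): (np.int32(0), np.int32(1), np.float32(0.0)) for i in range(0)}
    edges[np.int64(7)] = (np.int32(3), np.int32(4), np.float32(1.5))
    return len(edges)

print(empty_typed_dict_demo())  # 1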
@@ -34,10 +41,13 @@ def merge_components(disjoint_set, candidate_neighbors, candidate_neighbor_distances, point_components):


 @numba.njit(parallel=True)
-def update_component_vectors(tree, disjoint_set, node_components, point_components):
+def update_point_components(disjoint_set, point_components):
     for i in numba.prange(point_components.shape[0]):
         point_components[i] = ds_find(disjoint_set, np.int32(i))
 
+
+@numba.njit()
+def update_node_components(tree, node_components, point_components):
     for i in range(tree.node_data.shape[0] - 1, -1, -1):
         node_info = tree.node_data[i]
 
@@ -272,28 +282,28 @@ def parallel_boruvka(tree, min_samples=10, sample_weights=None):
         expected_neighbors = min_samples / mean_sample_weight
         distances, neighbors = parallel_tree_query(tree, tree.data, k=int(2 * expected_neighbors))
         core_distances = sample_weight_core_distance(distances, neighbors, sample_weights, min_samples)
-        edges = initialize_boruvka_from_knn(neighbors, distances, core_distances, components_disjoint_set)
-        update_component_vectors(tree, components_disjoint_set, node_components, point_components)
     else:
         if min_samples > 1:
             distances, neighbors = parallel_tree_query(tree, tree.data, k=min_samples + 1, output_rdist=True)
             core_distances = distances.T[-1]
-            edges = initialize_boruvka_from_knn(neighbors, distances, core_distances, components_disjoint_set)
-            update_component_vectors(tree, components_disjoint_set, node_components, point_components)
         else:
             core_distances = np.zeros(tree.data.shape[0], dtype=np.float32)
             distances, neighbors = parallel_tree_query(tree, tree.data, k=2)
-            edges = initialize_boruvka_from_knn(neighbors, distances, core_distances, components_disjoint_set)
-            update_component_vectors(tree, components_disjoint_set, node_components, point_components)
 
-    while n_components > 1:
+    edges = [np.empty((0, 3), dtype=np.float64) for _ in range(0)]
+    new_edges = initialize_boruvka_from_knn(neighbors, distances, core_distances, components_disjoint_set)
+    while True:
+        edges.append(new_edges)
+        n_components -= new_edges.shape[0]
+        if n_components == 1:
+            break
+        update_point_components(components_disjoint_set, point_components)
+        update_node_components(tree, node_components, point_components)
         candidate_distances, candidate_indices = boruvka_tree_query(tree, node_components, point_components,
                                                                     core_distances)
-        new_edges = merge_components(components_disjoint_set, candidate_indices, candidate_distances, point_components)
-        update_component_vectors(tree, components_disjoint_set, node_components, point_components)
-
-        edges = np.vstack((edges, new_edges))
-        n_components = np.unique(point_components).shape[0]
+        component_edges = select_components(candidate_distances, candidate_indices, point_components)
+        new_edges = merge_components(components_disjoint_set, component_edges)
 
+    edges = np.vstack(edges)
     edges[:, 2] = np.sqrt(edges.T[2])
-    return edges
+    return edges, neighbors[:, 1:], np.sqrt(core_distances)
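Note the changed contract: parallel_boruvka now returns the stacked minimum-spanning-tree edges together with each point's nearest neighbors (column 0, the self-neighbor, is dropped) and the square-rooted core distances, so callers can reuse the k-NN query instead of repeating it. A hedged sketch of a caller; kdtree_to_numba is assumed to be the package's existing KDTree wrapper, and the random data is illustrative:

import numpy as np
from sklearn.neighbors import KDTree
from fast_hdbscan.numba_kdtree import kdtree_to_numba  # assumed helper
from fast_hdbscan.boruvka import parallel_boruvka

data = np.random.default_rng(0).normal(size=(200, 2))
tree = kdtree_to_numba(KDTree(data))
edges, knn_indices, core_distances = parallel_boruvka(tree, min_samples=5)

# Each edge row is (from_point, to_point, mutual-reachability distance);
# sorting by weight gives the order in which single linkage would merge.
mst = edges[np.argsort(edges[:, 2])]
print(mst.shape, knn_indices.shape, core_distances.shape)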
148 changes: 148 additions & 0 deletions fast_hdbscan/branches.py
@@ -0,0 +1,148 @@
import numpy as np
from .sub_clusters import SubClusterDetector, find_sub_clusters


def compute_centrality(data, probabilities, *args):
    points = args[-1]
    cluster_data = data[points, :]
    centroid = np.average(cluster_data, weights=probabilities[points], axis=0)
    return 1 / np.linalg.norm(cluster_data - centroid[None, :], axis=1)
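compute_centrality inverts each member's distance to the cluster's probability-weighted centroid, so points near the centre score high and points at the tips of flares score low. A small worked example on toy values (illustrative only):

import numpy as np
from fast_hdbscan.branches import compute_centrality

data = np.array([[0.0, 0.0], [2.0, 0.0], [3.0, 0.0]])
probabilities = np.ones(3)
points = np.arange(3)  # indices of this cluster's members

# The weighted centroid is (5/3, 0); centrality is 1 / distance to it.
print(compute_centrality(data, probabilities, points))  # [0.6, 3.0, 0.75]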


def apply_branch_threshold(
    labels,
    branch_labels,
    probabilities,
    cluster_probabilities,
    cluster_points,
    linkage_trees,
    label_sides_as_branches=False,
):
    running_id = 0
    min_branch_count = 1 if label_sides_as_branches else 2
    for pts, tree in zip(cluster_points, linkage_trees):
        unique_branch_labels = np.unique(branch_labels[pts])
        has_noise = int(unique_branch_labels[0] == -1)
        num_branches = len(unique_branch_labels) - has_noise
        if num_branches <= min_branch_count and tree is not None:
            labels[pts] = running_id
            probabilities[pts] = cluster_probabilities[pts]
            running_id += 1
        else:
            labels[pts] = branch_labels[pts] + has_noise + running_id
            running_id += num_branches + has_noise
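In other words, a cluster only receives separate branch labels when it splits into more than min_branch_count branches. The default threshold is two because an unbranched, elongated cluster still decomposes into two 'sides' around its centre; label_sides_as_branches=True lowers the threshold to one so those two sides are labelled as distinct branches.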


def find_branch_sub_clusters(
    clusterer,
    cluster_labels=None,
    cluster_probabilities=None,
    label_sides_as_branches=False,
    min_cluster_size=None,
    max_cluster_size=None,
    allow_single_cluster=None,
    cluster_selection_method=None,
    cluster_selection_epsilon=0.0,
    cluster_selection_persistence=0.0,
):
    result = find_sub_clusters(
        clusterer,
        cluster_labels,
        cluster_probabilities,
        lens_callback=compute_centrality,
        min_cluster_size=min_cluster_size,
        max_cluster_size=max_cluster_size,
        allow_single_cluster=allow_single_cluster,
        cluster_selection_method=cluster_selection_method,
        cluster_selection_epsilon=cluster_selection_epsilon,
        cluster_selection_persistence=cluster_selection_persistence,
    )
    apply_branch_threshold(
        result[0],
        result[4],
        result[1],
        result[3],
        result[-2],
        result[-1],
        label_sides_as_branches=label_sides_as_branches,
    )
    return result


class BranchDetector(SubClusterDetector):
    """
    Performs a flare-detection post-processing step to detect branches within
    clusters [1]_.

    For each cluster, a graph is constructed connecting the data points based on
    their mutual reachability distances. Each edge is given a centrality value
    based on how far it lies from the cluster's center. Then, the edges are
    clustered as if that centrality were a distance, progressively removing the
    'center' of each cluster and seeing how many branches remain.

    References
    ----------
    .. [1] Bot, D. M., Peeters, J., Liesenborgs, J., & Aerts, J. (2023, November).
       FLASC: A Flare-Sensitive Clustering Algorithm: Extending HDBSCAN* for
       Detecting Branches in Clusters. arXiv:2311.15887.
    """

    def __init__(
        self,
        min_cluster_size=None,
        max_cluster_size=None,
        allow_single_cluster=None,
        cluster_selection_method=None,
        cluster_selection_epsilon=0.0,
        cluster_selection_persistence=0.0,
        propagate_labels=False,
        label_sides_as_branches=False,
    ):
        super().__init__(
            min_cluster_size=min_cluster_size,
            max_cluster_size=max_cluster_size,
            allow_single_cluster=allow_single_cluster,
            cluster_selection_method=cluster_selection_method,
            cluster_selection_epsilon=cluster_selection_epsilon,
            cluster_selection_persistence=cluster_selection_persistence,
            propagate_labels=propagate_labels,
        )
        self.label_sides_as_branches = label_sides_as_branches

    def fit(self, clusterer, labels=None, probabilities=None, sample_weight=None):
        super().fit(clusterer, labels, probabilities, sample_weight, compute_centrality)
        apply_branch_threshold(
            self.labels_,
            self.sub_cluster_labels_,
            self.probabilities_,
            self.cluster_probabilities_,
            self.cluster_points_,
            self._linkage_trees,
            label_sides_as_branches=self.label_sides_as_branches,
        )
        self.branch_labels_ = self.sub_cluster_labels_
        self.branch_probabilities_ = self.sub_cluster_probabilities_
        self.centralities_ = self.lens_values_
        return self

    def propagated_labels(self, label_sides_as_branches=None):
        if label_sides_as_branches is None:
            label_sides_as_branches = self.label_sides_as_branches

        labels, branch_labels = super().propagated_labels()
        apply_branch_threshold(
            labels,
            branch_labels,
            np.zeros_like(self.probabilities_),
            np.zeros_like(self.probabilities_),
            self.cluster_points_,
            self._linkage_trees,
            label_sides_as_branches=label_sides_as_branches,
        )
        return labels, branch_labels

    @property
    def approximation_graph_(self):
        """See :class:`~hdbscan.plots.ApproximationGraph` for documentation."""
        return super()._make_approximation_graph(
            lens_name="centrality", sub_cluster_name="branch"
        )
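A usage sketch for the detector as a whole, following the fit signature above; the two-moons data and parameter values are illustrative, with scikit-learn assumed available (it already backs the package's KD-trees):

import numpy as np
from sklearn.datasets import make_moons
from fast_hdbscan import HDBSCAN, BranchDetector

data, _ = make_moons(n_samples=500, noise=0.05, random_state=42)

clusterer = HDBSCAN(min_cluster_size=25).fit(data)
detector = BranchDetector(label_sides_as_branches=True).fit(clusterer)

# Final labels split each cluster into its detected branches; the raw
# per-cluster branch ids and centrality values are kept alongside.
print(np.unique(detector.labels_))
print(np.unique(detector.branch_labels_))
print(detector.centralities_[:5])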