add branch detection functionality #32
Changes from all commits
```diff
@@ -4,26 +4,33 @@
 from .disjoint_set import ds_rank_create, ds_find, ds_union_by_rank
 from .numba_kdtree import parallel_tree_query, rdist, point_to_node_lower_bound_rdist
 
 
-@numba.njit(locals={"i": numba.types.int64})
-def merge_components(disjoint_set, candidate_neighbors, candidate_neighbor_distances, point_components):
-    component_edges = {np.int64(0): (np.int64(0), np.int64(1), np.float32(0.0)) for i in range(0)}
+@numba.njit(locals={"parent": numba.types.int32})
+def select_components(candidate_distances, candidate_neighbors, point_components):
+    component_edges = {np.int64(0): (np.int32(0), np.int32(1), np.float32(0.0)) for i in range(0)}
 
     # Find the best edges from each component
-    for i in range(candidate_neighbors.shape[0]):
-        from_component = np.int64(point_components[i])
+    for parent, (distance, neighbor, from_component) in enumerate(
+        zip(candidate_distances, candidate_neighbors, point_components)
+    ):
         if from_component in component_edges:
-            if candidate_neighbor_distances[i] < component_edges[from_component][2]:
-                component_edges[from_component] = (np.int64(i), np.int64(candidate_neighbors[i]), candidate_neighbor_distances[i])
+            if distance < component_edges[from_component][2]:
+                component_edges[from_component] = (parent, neighbor, distance)
         else:
-            component_edges[from_component] = (np.int64(i), np.int64(candidate_neighbors[i]), candidate_neighbor_distances[i])
+            component_edges[from_component] = (parent, neighbor, distance)
 
+    return component_edges
 
 
+@numba.njit()
+def merge_components(disjoint_set, component_edges):
     result = np.empty((len(component_edges), 3), dtype=np.float64)
     result_idx = 0
 
     # Add the best edges to the edge set and merge the relevant components
     for edge in component_edges.values():
-        from_component = ds_find(disjoint_set, np.int32(edge[0]))
-        to_component = ds_find(disjoint_set, np.int32(edge[1]))
+        from_component = ds_find(disjoint_set, edge[0])
+        to_component = ds_find(disjoint_set, edge[1])
         if from_component != to_component:
             result[result_idx] = (np.float64(edge[0]), np.float64(edge[1]), np.float64(edge[2]))
             result_idx += 1
```
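The split mirrors the two halves of a Borůvka round: `select_components` picks, for every current component, the cheapest candidate edge leaving it, and `merge_components` then unions components along those edges. Below is a minimal plain-Python sketch of the selection half; the PR's version is `numba.njit`-compiled and pins down the tuple dtypes explicitly, and the function and variable names here are illustrative only.

```python
import numpy as np

def select_components_sketch(candidate_distances, candidate_neighbors, point_components):
    # Keep, per component, the (parent, neighbor, distance) triple with the
    # smallest distance among all candidate edges leaving that component.
    component_edges = {}
    for parent, (distance, neighbor, from_component) in enumerate(
        zip(candidate_distances, candidate_neighbors, point_components)
    ):
        best = component_edges.get(from_component)
        if best is None or distance < best[2]:
            component_edges[from_component] = (parent, neighbor, distance)
    return component_edges

# Points 0 and 1 belong to component 0, point 2 to component 2.
edges = select_components_sketch(
    np.array([0.5, 0.2, 0.9]),  # candidate edge lengths
    np.array([2, 2, 0]),        # candidate neighbor of each point
    np.array([0, 0, 2]),        # component label of each point
)
# component 0 keeps edge (1, 2, 0.2); component 2 keeps edge (2, 0, 0.9)
```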
```diff
@@ -34,10 +41,13 @@ def merge_components(disjoint_set, candidate_neighbors, candidate_neighbor_distances, point_components):
 
 
 @numba.njit(parallel=True)
-def update_component_vectors(tree, disjoint_set, node_components, point_components):
+def update_point_components(disjoint_set, point_components):
     for i in numba.prange(point_components.shape[0]):
         point_components[i] = ds_find(disjoint_set, np.int32(i))
 
 
+@numba.njit()
+def update_node_components(tree, node_components, point_components):
```
Comment on lines -37 to +50: Similarly, I split `update_component_vectors` into `update_point_components` and `update_node_components`.
```diff
     for i in range(tree.node_data.shape[0] - 1, -1, -1):
         node_info = tree.node_data[i]
```
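Both helpers rely on the union-find primitives imported from `.disjoint_set` at the top of the file. Their implementation is not part of this diff; the sketch below only illustrates the API shape the names `ds_rank_create`/`ds_find`/`ds_union_by_rank` suggest, assuming an array-backed union-by-rank structure with path halving (the actual module may differ).

```python
import numpy as np

def ds_rank_create(n):
    # Each element starts as its own root with rank zero.
    return np.arange(n, dtype=np.int32), np.zeros(n, dtype=np.int32)

def ds_find(disjoint_set, x):
    parent, _ = disjoint_set
    while parent[x] != x:
        parent[x] = parent[parent[x]]  # path halving keeps trees shallow
        x = parent[x]
    return x

def ds_union_by_rank(disjoint_set, x, y):
    parent, rank = disjoint_set
    root_x, root_y = ds_find(disjoint_set, x), ds_find(disjoint_set, y)
    if root_x == root_y:
        return
    if rank[root_x] < rank[root_y]:
        root_x, root_y = root_y, root_x
    parent[root_y] = root_x  # attach the shallower tree under the deeper one
    if rank[root_x] == rank[root_y]:
        rank[root_x] += 1
```

With such a structure, `update_point_components` is embarrassingly parallel: every point independently looks up its root and caches it in `point_components`.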
```diff
@@ -272,28 +282,28 @@ def parallel_boruvka(tree, min_samples=10, sample_weights=None):
         expected_neighbors = min_samples / mean_sample_weight
         distances, neighbors = parallel_tree_query(tree, tree.data, k=int(2 * expected_neighbors))
         core_distances = sample_weight_core_distance(distances, neighbors, sample_weights, min_samples)
-        edges = initialize_boruvka_from_knn(neighbors, distances, core_distances, components_disjoint_set)
-        update_component_vectors(tree, components_disjoint_set, node_components, point_components)
     else:
         if min_samples > 1:
             distances, neighbors = parallel_tree_query(tree, tree.data, k=min_samples + 1, output_rdist=True)
             core_distances = distances.T[-1]
-            edges = initialize_boruvka_from_knn(neighbors, distances, core_distances, components_disjoint_set)
-            update_component_vectors(tree, components_disjoint_set, node_components, point_components)
         else:
             core_distances = np.zeros(tree.data.shape[0], dtype=np.float32)
             distances, neighbors = parallel_tree_query(tree, tree.data, k=2)
-            edges = initialize_boruvka_from_knn(neighbors, distances, core_distances, components_disjoint_set)
-            update_component_vectors(tree, components_disjoint_set, node_components, point_components)
 
-    while n_components > 1:
+    edges = [np.empty((0, 3), dtype=np.float64) for _ in range(0)]
+    new_edges = initialize_boruvka_from_knn(neighbors, distances, core_distances, components_disjoint_set)
+    while True:
+        edges.append(new_edges)
+        n_components -= new_edges.shape[0]
+        if n_components == 1:
+            break
+        update_point_components(components_disjoint_set, point_components)
+        update_node_components(tree, node_components, point_components)
         candidate_distances, candidate_indices = boruvka_tree_query(tree, node_components, point_components,
                                                                     core_distances)
-        new_edges = merge_components(components_disjoint_set, candidate_indices, candidate_distances, point_components)
-        update_component_vectors(tree, components_disjoint_set, node_components, point_components)
-
-        edges = np.vstack((edges, new_edges))
-        n_components = np.unique(point_components).shape[0]
+        component_edges = select_components(candidate_distances, candidate_indices, point_components)
+        new_edges = merge_components(components_disjoint_set, component_edges)
 
+    edges = np.vstack(edges)
```
Refactor to call `np.vstack` once on the collected list of edge arrays, instead of stacking inside every iteration.

```diff
     edges[:, 2] = np.sqrt(edges.T[2])
-    return edges
+    return edges, neighbors[:, 1:], np.sqrt(core_distances)
```
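The reworked `parallel_boruvka` now also returns the k-nearest-neighbour indices and core distances it computed along the way, so callers no longer need to redo the tree queries. A hedged usage sketch follows; the `kdtree_to_numba` helper and the import paths are assumptions about the surrounding package, not something this diff shows.

```python
import numpy as np
from sklearn.neighbors import KDTree

from fast_hdbscan.numba_kdtree import kdtree_to_numba  # assumed helper
from fast_hdbscan.boruvka import parallel_boruvka      # assumed module path

data = np.random.default_rng(0).normal(size=(200, 2))
tree = kdtree_to_numba(KDTree(data))

# New three-value return: MST edges, kNN indices (self-neighbour column
# dropped via neighbors[:, 1:]), and core distances (square-rooted back
# from squared distances).
edges, knn_indices, core_distances = parallel_boruvka(tree, min_samples=10)

assert edges.shape == (data.shape[0] - 1, 3)    # (from, to, distance) rows
assert knn_indices.shape[0] == data.shape[0]
assert core_distances.shape == (data.shape[0],)
```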
New file, 148 added lines:

```python
import numpy as np
from .sub_clusters import SubClusterDetector, find_sub_clusters


def compute_centrality(data, probabilities, *args):
    points = args[-1]
    cluster_data = data[points, :]
    centroid = np.average(cluster_data, weights=probabilities[points], axis=0)
    return 1 / np.linalg.norm(cluster_data - centroid[None, :], axis=1)


def apply_branch_threshold(
    labels,
    branch_labels,
    probabilities,
    cluster_probabilities,
    cluster_points,
    linkage_trees,
    label_sides_as_branches=False,
):
    running_id = 0
    min_branch_count = 1 if label_sides_as_branches else 2
    for pts, tree in zip(cluster_points, linkage_trees):
        unique_branch_labels = np.unique(branch_labels[pts])
        has_noise = int(unique_branch_labels[0] == -1)
        num_branches = len(unique_branch_labels) - has_noise
        if num_branches <= min_branch_count and tree is not None:
            labels[pts] = running_id
            probabilities[pts] = cluster_probabilities[pts]
            running_id += 1
        else:
            labels[pts] = branch_labels[pts] + has_noise + running_id
            running_id += num_branches + has_noise


def find_branch_sub_clusters(
    clusterer,
    cluster_labels=None,
    cluster_probabilities=None,
    label_sides_as_branches=False,
    min_cluster_size=None,
    max_cluster_size=None,
    allow_single_cluster=None,
    cluster_selection_method=None,
    cluster_selection_epsilon=0.0,
    cluster_selection_persistence=0.0,
):
    result = find_sub_clusters(
        clusterer,
        cluster_labels,
        cluster_probabilities,
        lens_callback=compute_centrality,
        min_cluster_size=min_cluster_size,
        max_cluster_size=max_cluster_size,
        allow_single_cluster=allow_single_cluster,
        cluster_selection_method=cluster_selection_method,
        cluster_selection_epsilon=cluster_selection_epsilon,
        cluster_selection_persistence=cluster_selection_persistence,
    )
    apply_branch_threshold(
        result[0],
        result[4],
        result[1],
        result[3],
        result[-1],
        label_sides_as_branches=label_sides_as_branches,
    )
    return result


class BranchDetector(SubClusterDetector):
    """
    Performs a flare-detection post-processing step to detect branches within
    clusters [1]_.

    For each cluster, a graph is constructed connecting the data points based on
    their mutual reachability distances. Each edge is given a centrality value
    based on how far it lies from the cluster's center. Then, the edges are
    clustered as if that centrality was a distance, progressively removing the
    'center' of each cluster and seeing how many branches remain.

    References
    ----------
    .. [1] Bot, D. M., Peeters, J., Liesenborgs, J., & Aerts, J. (2023, November).
       FLASC: A Flare-Sensitive Clustering Algorithm: Extending HDBSCAN* for
       Detecting Branches in Clusters. arXiv:2311.15887.
    """

    def __init__(
        self,
        min_cluster_size=None,
        max_cluster_size=None,
        allow_single_cluster=None,
        cluster_selection_method=None,
        cluster_selection_epsilon=0.0,
        cluster_selection_persistence=0.0,
        propagate_labels=False,
        label_sides_as_branches=False,
    ):
        super().__init__(
            min_cluster_size=min_cluster_size,
            max_cluster_size=max_cluster_size,
            allow_single_cluster=allow_single_cluster,
            cluster_selection_method=cluster_selection_method,
            cluster_selection_epsilon=cluster_selection_epsilon,
            cluster_selection_persistence=cluster_selection_persistence,
            propagate_labels=propagate_labels,
        )
        self.label_sides_as_branches = label_sides_as_branches

    def fit(self, clusterer, labels=None, probabilities=None, sample_weight=None):
        super().fit(clusterer, labels, probabilities, sample_weight, compute_centrality)
        apply_branch_threshold(
            self.labels_,
            self.sub_cluster_labels_,
            self.probabilities_,
            self.cluster_probabilities_,
            self.cluster_points_,
            self._linkage_trees,
            label_sides_as_branches=self.label_sides_as_branches,
        )
        self.branch_labels_ = self.sub_cluster_labels_
        self.branch_probabilities_ = self.sub_cluster_probabilities_
        self.centralities_ = self.lens_values_
        return self

    def propagated_labels(self, label_sides_as_branches=None):
        if label_sides_as_branches is None:
            label_sides_as_branches = self.label_sides_as_branches

        labels, branch_labels = super().propagated_labels()
        apply_branch_threshold(
            labels,
            branch_labels,
            np.zeros_like(self.probabilities_),
            np.zeros_like(self.probabilities_),
            self.cluster_points_,
            self._linkage_trees,
            label_sides_as_branches=label_sides_as_branches,
        )
        return labels, branch_labels

    @property
    def approximation_graph_(self):
        """See :class:`~hdbscan.plots.ApproximationGraph` for documentation."""
        return super()._make_approximation_graph(
            lens_name="centrality", sub_cluster_name="branch"
        )
```
I split `merge_components` into `select_components` and `merge_components` so I can re-use the latter part to compute the minimum spanning tree of a kNN-MST union graph. That process needs to reject invalid neighbour values within the `select_components` part.
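In that setting, rejecting invalid neighbours could be a single guard in the selection loop, skipping sentinel entries before an edge is allowed to win. The snippet below only illustrates the idea and is not code from this PR; the `-1` sentinel is a hypothetical convention.

```python
def select_components_filtered(candidate_distances, candidate_neighbors, point_components):
    component_edges = {}
    for parent, (distance, neighbor, from_component) in enumerate(
        zip(candidate_distances, candidate_neighbors, point_components)
    ):
        if neighbor < 0:
            continue  # hypothetical sentinel: point has no valid candidate edge
        best = component_edges.get(from_component)
        if best is None or distance < best[2]:
            component_edges[from_component] = (parent, neighbor, distance)
    return component_edges
```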