4
4
from scib_metrics .utils import silhouette_samples
5
5
6
6
7
- def silhouette_label (X : np .ndarray , labels : np .ndarray , rescale : bool = True ) -> float :
7
+ def silhouette_label (X : np .ndarray , labels : np .ndarray , rescale : bool = True , chunk_size : int = 256 ) -> float :
8
8
"""Average silhouette width (ASW) :cite:p:`luecken2022benchmarking`.
9
9
10
10
Parameters
@@ -15,18 +15,22 @@ def silhouette_label(X: np.ndarray, labels: np.ndarray, rescale: bool = True) ->
15
15
Array of shape (n_cells,) representing label values
16
16
rescale
17
17
Scale asw into the range [0, 1].
18
+ chunk_size
19
+ Size of chunks to process at a time for distance computation.
18
20
19
21
Returns
20
22
-------
21
23
silhouette score
22
24
"""
23
- asw = np .mean (silhouette_samples (X , labels ))
25
+ asw = np .mean (silhouette_samples (X , labels , chunk_size = chunk_size ))
24
26
if rescale :
25
27
asw = (asw + 1 ) / 2
26
28
return np .mean (asw )
27
29
28
30
29
- def silhouette_batch (X : np .ndarray , labels : np .ndarray , batch : np .ndarray , rescale : bool = True ) -> float :
31
+ def silhouette_batch (
32
+ X : np .ndarray , labels : np .ndarray , batch : np .ndarray , rescale : bool = True , chunk_size : int = 256
33
+ ) -> float :
30
34
"""Average silhouette width (ASW) with respect to batch ids within each label :cite:p:`luecken2022benchmarking`.
31
35
32
36
Parameters
@@ -39,6 +43,8 @@ def silhouette_batch(X: np.ndarray, labels: np.ndarray, batch: np.ndarray, resca
39
43
Array of shape (n_cells,) representing batch values
40
44
rescale
41
45
Scale asw into the range [0, 1]. If True, higher values are better.
46
+ chunk_size
47
+ Size of chunks to process at a time for distance computation.
42
48
43
49
Returns
44
50
-------
@@ -55,7 +61,7 @@ def silhouette_batch(X: np.ndarray, labels: np.ndarray, batch: np.ndarray, resca
55
61
if (n_batches == 1 ) or (n_batches == X_subset .shape [0 ]):
56
62
continue
57
63
58
- sil_per_group = silhouette_samples (X_subset , batch_subset )
64
+ sil_per_group = silhouette_samples (X_subset , batch_subset , chunk_size = chunk_size )
59
65
60
66
# take only absolute value
61
67
sil_per_group = np .abs (sil_per_group )
0 commit comments