Commit 0fd9b70

expose chunk size, bump version (#82)
* expose chunk size
* add rel note
1 parent 84d3a77 commit 0fd9b70

3 files changed: 17 additions and 5 deletions

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
@@ -8,6 +8,12 @@ and this project adheres to [Semantic Versioning][].
 [keep a changelog]: https://keepachangelog.com/en/1.0.0/
 [semantic versioning]: https://semver.org/spec/v2.0.0.html

+## 0.3.1 (2022-02-16)
+
+- Expose chunk size for silhouette ([#82][])
+
+[#82]: https://github.com/YosefLab/scib-metrics/pull/82
+
 ## 0.3.0 (2022-02-16)

 - Rename `KmeansJax` to `Kmeans` and fix ++ initialization, use Kmeans as default in benchmarker instead of Leiden ([#81][])

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@ requires = ["hatchling"]

 [project]
 name = "scib-metrics"
-version = "0.3.0"
+version = "0.3.1"
 description = "Accelerated and Python-only scIB metrics"
 readme = "README.md"
 requires-python = ">=3.8"

src/scib_metrics/_silhouette.py

Lines changed: 10 additions & 4 deletions
@@ -4,7 +4,7 @@
 from scib_metrics.utils import silhouette_samples


-def silhouette_label(X: np.ndarray, labels: np.ndarray, rescale: bool = True) -> float:
+def silhouette_label(X: np.ndarray, labels: np.ndarray, rescale: bool = True, chunk_size: int = 256) -> float:
     """Average silhouette width (ASW) :cite:p:`luecken2022benchmarking`.

     Parameters
@@ -15,18 +15,22 @@ def silhouette_label(X: np.ndarray, labels: np.ndarray, rescale: bool = True) ->
         Array of shape (n_cells,) representing label values
     rescale
         Scale asw into the range [0, 1].
+    chunk_size
+        Size of chunks to process at a time for distance computation.

     Returns
     -------
     silhouette score
     """
-    asw = np.mean(silhouette_samples(X, labels))
+    asw = np.mean(silhouette_samples(X, labels, chunk_size=chunk_size))
     if rescale:
         asw = (asw + 1) / 2
     return np.mean(asw)


-def silhouette_batch(X: np.ndarray, labels: np.ndarray, batch: np.ndarray, rescale: bool = True) -> float:
+def silhouette_batch(
+    X: np.ndarray, labels: np.ndarray, batch: np.ndarray, rescale: bool = True, chunk_size: int = 256
+) -> float:
     """Average silhouette width (ASW) with respect to batch ids within each label :cite:p:`luecken2022benchmarking`.

     Parameters
@@ -39,6 +43,8 @@ def silhouette_batch(X: np.ndarray, labels: np.ndarray, batch: np.ndarray, resca
         Array of shape (n_cells,) representing batch values
     rescale
         Scale asw into the range [0, 1]. If True, higher values are better.
+    chunk_size
+        Size of chunks to process at a time for distance computation.

     Returns
     -------
@@ -55,7 +61,7 @@ def silhouette_batch(X: np.ndarray, labels: np.ndarray, batch: np.ndarray, resca
         if (n_batches == 1) or (n_batches == X_subset.shape[0]):
             continue

-        sil_per_group = silhouette_samples(X_subset, batch_subset)
+        sil_per_group = silhouette_samples(X_subset, batch_subset, chunk_size=chunk_size)

         # take only absolute value
         sil_per_group = np.abs(sil_per_group)
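
The new `chunk_size` docstring says distances are processed a chunk at a time. As a rough illustration of what such a parameter usually controls, the NumPy sketch below computes per-cell silhouette widths while holding only a `(chunk_size, n_cells)` block of the distance matrix in memory. It is not the `scib_metrics.utils.silhouette_samples` implementation, which this diff does not show. `silhouette_samples_chunked` and `silhouette_label_sketch` are hypothetical names, and Euclidean distance is an assumption.

```python
import numpy as np


def silhouette_samples_chunked(X, labels, chunk_size=256):
    """Per-cell silhouette widths, computing distances chunk_size rows at a time.

    Hypothetical helper for illustration; assumes Euclidean distance.
    """
    X = np.asarray(X, dtype=float)
    # Map arbitrary label values to integer codes 0..k-1.
    uniq, codes = np.unique(np.asarray(labels), return_inverse=True)
    n, k = X.shape[0], uniq.size
    if k < 2:
        raise ValueError("silhouette needs at least two distinct labels")
    counts = np.bincount(codes, minlength=k)
    onehot = np.eye(k)[codes]  # (n, k) cluster indicator matrix
    sq_norms = (X**2).sum(axis=1)

    out = np.zeros(n)
    for start in range(0, n, chunk_size):
        stop = min(start + chunk_size, n)
        rows = np.arange(stop - start)
        own = codes[start:stop]
        # Only a (chunk, n) block of the distance matrix exists at once.
        d2 = sq_norms[start:stop, None] + sq_norms[None, :] - 2.0 * X[start:stop] @ X.T
        dist = np.sqrt(np.maximum(d2, 0.0))
        dist[rows, start + rows] = 0.0  # force exact zero self-distance
        sums = dist @ onehot  # summed distance from each cell to each cluster
        # a: mean distance to the other members of the cell's own cluster.
        a = sums[rows, own] / np.maximum(counts[own] - 1, 1)
        # b: smallest mean distance to any other cluster.
        means = sums / counts
        means[rows, own] = np.inf
        b = means.min(axis=1)
        denom = np.maximum(a, b)
        s = np.divide(b - a, denom, out=np.zeros_like(a), where=denom > 0)
        # Cells in singleton clusters get a silhouette of 0 by convention.
        out[start:stop] = np.where(counts[own] > 1, s, 0.0)
    return out


def silhouette_label_sketch(X, labels, rescale=True, chunk_size=256):
    """Mirror of the rescaling shown in the diff: map ASW from [-1, 1] to [0, 1]."""
    asw = silhouette_samples_chunked(X, labels, chunk_size=chunk_size).mean()
    return (asw + 1) / 2 if rescale else asw


# Two synthetic clusters; the chunk size changes peak memory, not the score.
rng = np.random.default_rng(0)
X_demo = np.concatenate([rng.normal(0, 1, (50, 5)), rng.normal(6, 1, (50, 5))])
labels_demo = np.repeat([0, 1], 50)
score_small = silhouette_label_sketch(X_demo, labels_demo, chunk_size=16)
score_large = silhouette_label_sketch(X_demo, labels_demo, chunk_size=256)
assert np.isclose(score_small, score_large)
```

In a scheme like this, a smaller `chunk_size` lowers peak memory at the cost of more loop iterations, which is presumably why the commit exposes it rather than leaving it fixed.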
