Skip to content

Commit 935b56b

Browse files
ori-kron-wisgithub-actions[bot]meeseeksmachinemartinkim0canergen
authored
ci: Manual Backport of PR 3128 to 1.2.x (#3130)
pop get bug fix data splitter backport 1.2.x --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: ori-kron-wis <[email protected]> Co-authored-by: Lumberbot (aka Jack) <[email protected]> Co-authored-by: Martin Kim <[email protected]> Co-authored-by: Can Ergen <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Justin Hong <[email protected]> Co-authored-by: Yishen Miao <[email protected]> Co-authored-by: Ramon Viñas <[email protected]> Co-authored-by: Martin Kim <[email protected]> Co-authored-by: Ethan Weinberger <[email protected]> Co-authored-by: Ethan Weinberger <[email protected]> Co-authored-by: access <[email protected]>
1 parent 32b0b5c commit 935b56b

File tree

5 files changed

+12
-13
lines changed

5 files changed

+12
-13
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ to [Semantic Versioning]. Full commit history is available in the
2020

2121
#### Fixed
2222

23+
- Fixed batch_size pop to get in {class}`scvi.dataloaders.DataSplitter` {pr}`3128`.
24+
2325
#### Changed
2426

2527
- Updated the CI workflow with internet, private and optional tests {pr}`3082`.

docs/tutorials/notebooks

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ requires = ["hatchling"]
44

55
[project]
66
name = "scvi-tools"
7-
version = "1.2.2.post1"
7+
version = "1.2.2.post2"
88
description = "Deep probabilistic analysis of single-cell omics data."
99
readme = "README.md"
1010
requires-python = ">=3.10"

src/scvi/dataloaders/_data_splitting.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -251,15 +251,15 @@ def __init__(
251251
self.n_train, self.n_val = validate_data_split_with_external_indexing(
252252
self.adata_manager.adata.n_obs,
253253
self.external_indexing,
254-
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
254+
self.data_loader_kwargs.get("batch_size", settings.batch_size),
255255
self.drop_last,
256256
)
257257
else:
258258
self.n_train, self.n_val = validate_data_split(
259259
self.adata_manager.adata.n_obs,
260260
self.train_size,
261261
self.validation_size,
262-
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
262+
self.data_loader_kwargs.get("batch_size", settings.batch_size),
263263
self.drop_last,
264264
self.train_size_is_none,
265265
)
@@ -434,15 +434,15 @@ def setup(self, stage: str | None = None):
434434
n_labeled_train, n_labeled_val = validate_data_split_with_external_indexing(
435435
n_labeled_idx,
436436
[labeled_idx_train, labeled_idx_val, labeled_idx_test],
437-
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
437+
self.data_loader_kwargs.get("batch_size", settings.batch_size),
438438
self.drop_last,
439439
)
440440
else:
441441
n_labeled_train, n_labeled_val = validate_data_split(
442442
n_labeled_idx,
443443
self.train_size,
444444
self.validation_size,
445-
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
445+
self.data_loader_kwargs.get("batch_size", settings.batch_size),
446446
self.drop_last,
447447
self.train_size_is_none,
448448
)
@@ -475,15 +475,15 @@ def setup(self, stage: str | None = None):
475475
n_unlabeled_train, n_unlabeled_val = validate_data_split_with_external_indexing(
476476
n_unlabeled_idx,
477477
[unlabeled_idx_train, unlabeled_idx_val, unlabeled_idx_test],
478-
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
478+
self.data_loader_kwargs.get("batch_size", settings.batch_size),
479479
self.drop_last,
480480
)
481481
else:
482482
n_unlabeled_train, n_unlabeled_val = validate_data_split(
483483
n_unlabeled_idx,
484484
self.train_size,
485485
self.validation_size,
486-
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
486+
self.data_loader_kwargs.get("batch_size", settings.batch_size),
487487
self.drop_last,
488488
self.train_size_is_none,
489489
)

tests/model/test_scvi.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -474,8 +474,7 @@ def test_scvi_n_obs_error(n_latent: int = 5):
474474
with pytest.raises(ValueError):
475475
model.train(1, train_size=1.0)
476476
with pytest.raises(ValueError):
477-
# Warning is emitted if last batch less than 3 cells + failure.
478-
model.train(1, train_size=1.0, batch_size=127)
477+
model.train(1, train_size=1.0, batch_size=128)
479478
model.train(1, train_size=1.0, datasplitter_kwargs={"drop_last": True})
480479

481480
adata = synthetic_iid()
@@ -484,9 +483,7 @@ def test_scvi_n_obs_error(n_latent: int = 5):
484483
model = SCVI(adata, n_latent=n_latent)
485484
with pytest.raises(ValueError):
486485
model.train(1, train_size=0.9) # np.ceil(n_cells * 0.9) % 128 == 1
487-
model.train(
488-
1, train_size=0.9, datasplitter_kwargs={"drop_last": True}
489-
) # np.ceil(n_cells * 0.9) % 128 == 1
486+
model.train(1, train_size=0.9, datasplitter_kwargs={"drop_last": True})
490487
model.train(1)
491488
assert model.is_trained is True
492489

0 commit comments

Comments
 (0)