From 2ad4767ae80dd9982140a5763ed7b5da2c9360a1 Mon Sep 17 00:00:00 2001 From: MeeseeksMachine <39504233+meeseeksmachine@users.noreply.github.com> Date: Tue, 20 Sep 2022 19:26:10 +0200 Subject: [PATCH] Backport PR #1702: Quick fix in poisson sample() function for vae (#1703) Co-authored-by: ricomnl --- scvi/module/_vae.py | 2 +- tests/models/test_models.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/scvi/module/_vae.py b/scvi/module/_vae.py index 9699acb8fe..a0ae08f9fc 100644 --- a/scvi/module/_vae.py +++ b/scvi/module/_vae.py @@ -461,7 +461,7 @@ def sample( dist = generative_outputs["px"] if self.gene_likelihood == "poisson": - l_train = generative_outputs["px"].mu + l_train = generative_outputs["px"].rate l_train = torch.clamp(l_train, max=1e8) dist = torch.distributions.Poisson( l_train diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 6ce3f112d4..aaa7444d80 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -362,6 +362,14 @@ def test_scvi(save_path): model = SCVI(adata, gene_likelihood="nb") model.get_likelihood_parameters() + # test different gene_likelihoods + for gene_likelihood in ["zinb", "nb", "poisson"]: + model = SCVI(adata, gene_likelihood=gene_likelihood) + model.train(1, check_val_every_n_epoch=1, train_size=0.5) + model.posterior_predictive_sample() + model.get_latent_representation() + model.get_normalized_expression() + # test train callbacks work a = synthetic_iid() SCVI.setup_anndata(