From 72b3a7dd05643056ee4364d0af25b1301de07333 Mon Sep 17 00:00:00 2001 From: Jan Boelts Date: Tue, 20 Aug 2024 13:19:58 +0200 Subject: [PATCH] fix #1224: upgrade pyknos, remove xfail test. --- pyproject.toml | 2 +- tests/posterior_nn_test.py | 14 +------------- 2 files changed, 2 insertions(+), 14 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b61e0bd12..890a02203 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ "matplotlib", "numpy", "pillow", - "pyknos>=0.15.1", + "pyknos>=0.16.0", "pyro-ppl>=1.3.1", "scikit-learn", "scipy", diff --git a/tests/posterior_nn_test.py b/tests/posterior_nn_test.py index 6ed3cba47..10ecf5490 100644 --- a/tests/posterior_nn_test.py +++ b/tests/posterior_nn_test.py @@ -213,19 +213,7 @@ def test_batched_mcmc_sample_log_prob_with_different_x( @pytest.mark.slow -@pytest.mark.parametrize( - "density_estimator", - [ - pytest.param( - "mdn", - marks=pytest.mark.xfail( - raises=AssertionError, reason="Due to MDN bug in pyknos", strict=True - ), - ), - "maf", - "zuko_nsf", - ], -) +@pytest.mark.parametrize("density_estimator", ["mdn", "maf", "zuko_nsf"]) def test_batched_sampling_and_logprob_accuracy(density_estimator: str): """Test with two different observations and compare to sequential methods.""" num_dim = 2