Skip to content

Commit ac29077

Browse files
committed
test: add xfailing test for MDN bug.
1 parent 54f9158 commit ac29077

File tree

1 file changed

+74
-1
lines changed

1 file changed

+74
-1
lines changed

tests/posterior_nn_test.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,13 @@
1717
SNRE_C,
1818
DirectPosterior,
1919
)
20-
from sbi.simulators.linear_gaussian import diagonal_linear_gaussian
20+
from sbi.simulators.linear_gaussian import (
21+
diagonal_linear_gaussian,
22+
linear_gaussian,
23+
true_posterior_linear_gaussian_mvn_prior,
24+
)
25+
from sbi.utils.diagnostics_utils import get_posterior_samples_on_batch
26+
from tests.test_utils import check_c2st
2127

2228

2329
@pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C])
@@ -204,3 +210,70 @@ def test_batched_mcmc_sample_log_prob_with_different_x(
204210
assert torch.allclose(
205211
samples_m, samples_sep_m, atol=0.2, rtol=0.2
206212
), "Batched sampling is not consistent with separate sampling."
213+
214+
215+
@pytest.mark.slow
216+
@pytest.mark.parametrize(
217+
"density_estimator",
218+
[
219+
pytest.param(
220+
"mdn",
221+
marks=pytest.mark.xfail(
222+
raises=AssertionError, reason="Due to MDN bug in pyknos", strict=True
223+
),
224+
),
225+
"maf",
226+
"zuko_nsf",
227+
],
228+
)
229+
def test_batched_sampling_and_logprob_accuracy(density_estimator: str):
230+
"""Test with two different observations and compare to sequential methods."""
231+
num_dim = 2
232+
num_simulations = 2000
233+
num_samples = 1000
234+
sample_shape = (num_samples,)
235+
xos = torch.stack((-1.0 * ones(num_dim), 1.0 * ones(num_dim)))
236+
num_xos = xos.shape[0]
237+
238+
prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim))
239+
likelihood_shift = -1.0 * ones(num_dim)
240+
likelihood_cov = 0.3 * eye(num_dim)
241+
prior_mean = zeros(num_dim)
242+
prior_cov = eye(num_dim)
243+
prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov)
244+
245+
def simulator(theta):
246+
return linear_gaussian(theta, likelihood_shift, likelihood_cov)
247+
248+
inference = SNPE_C(
249+
prior=prior, show_progress_bars=False, density_estimator=density_estimator
250+
)
251+
theta = prior.sample((num_simulations,))
252+
x = simulator(theta)
253+
posterior_estimator = inference.append_simulations(theta, x).train()
254+
255+
posterior = DirectPosterior(posterior_estimator=posterior_estimator, prior=prior)
256+
257+
samples_batched = get_posterior_samples_on_batch(
258+
xos, posterior, sample_shape, use_batched_sampling=True
259+
)
260+
log_probs_batched = posterior.log_prob_batched(samples_batched, xos)
261+
262+
# check c2st for each xos
263+
for idx in range(0, num_xos):
264+
gt_posterior = true_posterior_linear_gaussian_mvn_prior(
265+
xos[idx], likelihood_shift, likelihood_cov, prior_mean, prior_cov
266+
)
267+
target_samples = gt_posterior.sample((num_samples,))
268+
check_c2st(
269+
target_samples,
270+
samples_batched[:, idx],
271+
alg=f"c2st-batch-vs-non-batch-{density_estimator}-x-idx{idx}",
272+
)
273+
274+
target_log_probs = gt_posterior.log_prob(samples_batched[:, idx])
275+
log_probs = posterior.log_prob(samples_batched[:, idx], xos[idx])
276+
assert torch.allclose(log_probs, log_probs_batched[:, idx])
277+
assert torch.allclose(
278+
target_log_probs.exp(), log_probs.exp(), atol=0.4, rtol=0.4
279+
), "Batched log probs are not consistent with non-batched log probs."

0 commit comments

Comments
 (0)