|
17 | 17 | SNRE_C,
|
18 | 18 | DirectPosterior,
|
19 | 19 | )
|
20 |
| -from sbi.simulators.linear_gaussian import diagonal_linear_gaussian |
| 20 | +from sbi.simulators.linear_gaussian import ( |
| 21 | + diagonal_linear_gaussian, |
| 22 | + linear_gaussian, |
| 23 | + true_posterior_linear_gaussian_mvn_prior, |
| 24 | +) |
| 25 | +from sbi.utils.diagnostics_utils import get_posterior_samples_on_batch |
| 26 | +from tests.test_utils import check_c2st |
21 | 27 |
|
22 | 28 |
|
23 | 29 | @pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C])
|
@@ -204,3 +210,70 @@ def test_batched_mcmc_sample_log_prob_with_different_x(
|
204 | 210 | assert torch.allclose(
|
205 | 211 | samples_m, samples_sep_m, atol=0.2, rtol=0.2
|
206 | 212 | ), "Batched sampling is not consistent with separate sampling."
|
| 213 | + |
| 214 | + |
@pytest.mark.slow
@pytest.mark.parametrize(
    "density_estimator",
    [
        pytest.param(
            "mdn",
            marks=pytest.mark.xfail(
                raises=AssertionError, reason="Due to MDN bug in pyknos", strict=True
            ),
        ),
        "maf",
        "zuko_nsf",
    ],
)
def test_batched_sampling_and_logprob_accuracy(density_estimator: str):
    """Test with two different observations and compare to sequential methods.

    Trains an SNPE_C posterior on a linear-Gaussian task, draws batched samples
    for two observations at once, then checks (a) sample quality via c2st
    against the analytic ground-truth posterior and (b) that batched log probs
    agree exactly with per-observation log probs and approximately (in
    probability space) with the analytic posterior.
    """
    num_dim = 2
    num_simulations = 2000
    num_samples = 1000
    sample_shape = (num_samples,)
    # Two observations: all -1s and all +1s.
    xos = torch.stack((-1.0 * ones(num_dim), 1.0 * ones(num_dim)))
    num_xos = xos.shape[0]

    likelihood_shift = -1.0 * ones(num_dim)
    likelihood_cov = 0.3 * eye(num_dim)
    prior_mean = zeros(num_dim)
    prior_cov = eye(num_dim)
    # Build the prior once from its mean/cov so the same tensors are reused for
    # the analytic ground-truth posterior below. (The original constructed an
    # identical MultivariateNormal twice; the first was dead code.)
    prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov)

    def simulator(theta):
        # Linear-Gaussian simulator with known shift and covariance.
        return linear_gaussian(theta, likelihood_shift, likelihood_cov)

    inference = SNPE_C(
        prior=prior, show_progress_bars=False, density_estimator=density_estimator
    )
    theta = prior.sample((num_simulations,))
    x = simulator(theta)
    posterior_estimator = inference.append_simulations(theta, x).train()

    posterior = DirectPosterior(posterior_estimator=posterior_estimator, prior=prior)

    # Batched sampling over both observations at once; indexing below assumes
    # samples_batched is (num_samples, num_xos, dim) -- TODO confirm against
    # get_posterior_samples_on_batch.
    samples_batched = get_posterior_samples_on_batch(
        xos, posterior, sample_shape, use_batched_sampling=True
    )
    log_probs_batched = posterior.log_prob_batched(samples_batched, xos)

    # check c2st for each xos
    for idx in range(num_xos):
        gt_posterior = true_posterior_linear_gaussian_mvn_prior(
            xos[idx], likelihood_shift, likelihood_cov, prior_mean, prior_cov
        )
        target_samples = gt_posterior.sample((num_samples,))
        check_c2st(
            target_samples,
            samples_batched[:, idx],
            alg=f"c2st-batch-vs-non-batch-{density_estimator}-x-idx{idx}",
        )

        # Batched log probs must match per-observation log probs exactly, and
        # approximately match the ground truth in probability space.
        target_log_probs = gt_posterior.log_prob(samples_batched[:, idx])
        log_probs = posterior.log_prob(samples_batched[:, idx], xos[idx])
        assert torch.allclose(log_probs, log_probs_batched[:, idx])
        assert torch.allclose(
            target_log_probs.exp(), log_probs.exp(), atol=0.4, rtol=0.4
        ), "Batched log probs are not consistent with non-batched log probs."
0 commit comments