"""
Implementation taken from Lemos et al, 'Sampling-Based Accuracy Testing of
Posterior Estimators for General Inference' https://arxiv.org/abs/2302.03026

The TARP diagnostic is a global diagnostic which can be used to check a
trained posterior against a set of true values of theta.
"""

from typing import Callable, Optional, Tuple

import matplotlib.pyplot as plt
import torch
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from scipy.stats import kstest
from torch import Tensor
from tqdm.auto import tqdm

from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.inference.posteriors.vi_posterior import VIPosterior
from sbi.utils.metrics import l2


# TODO: can be replaced by batched sampling for DirectPosterior.
def _infer_posterior_on_batch(
    xs: Tensor,
    posterior: NeuralPosterior,
    num_posterior_samples: int = 1000,
) -> Tensor:
    """Draw posterior samples for every observation in a batch.

    Args:
        xs: batch of observations of shape ``(N, ...)``.
        posterior: the neural posterior to use for inference.
        num_posterior_samples: number of posterior samples to draw for each
            observation, by default 1000.

    Returns:
        Tensor of shape ``(num_posterior_samples, N, P)`` where ``N`` is the
        number of observations in ``xs`` and ``P`` is the output dimension of
        the neural posterior estimator.
    """
    per_obs_samples = []

    for observation in xs:
        # Re-attach a leading batch axis so potentially higher-dimensional
        # data keeps its expected shape.
        x_o = observation.unsqueeze(0)

        # A VI posterior has to be (re-)fit to the current observation
        # before it can be sampled.
        if isinstance(posterior, VIPosterior):
            posterior.set_default_x(x_o)
            posterior.train()

        per_obs_samples.append(
            posterior.sample(
                (num_posterior_samples,), x=x_o, show_progress_bars=False
            )
        )

    # Stack along a new simulation axis -> (num_posterior_samples, N, P).
    return torch.stack(per_obs_samples, dim=1)


# this function currently does not perform any TARP related operation
# the purpose of the function is (a) to align with the sbc interface and
# (b) to provide the data which is required to run TARP
# NOTE: this function needs to be removed once better alternatives exist.
# TODO: Can be integrated into tarp method loop, and into sbc function.
def _prepare_estimates(
    xs: Tensor,
    posterior: NeuralPosterior,
    num_posterior_samples: int = 1000,
    num_workers: int = 1,
    infer_batch_size: int = 1,
    show_progress_bar: bool = True,
) -> Tensor:
    """Perform inference on batched x values using the provided posterior.

    The purpose of the function is (a) to align with the sbc interface and
    (b) to provide the data which is required to run TARP.

    Args:
        xs: observed data for tarp, simulated from thetas.
        posterior: a posterior obtained from sbi.
        num_posterior_samples: number of approximate posterior samples used
            for ranking.
        num_workers: number of CPU cores to use in parallel. Any value other
            than 1 currently raises ``NotImplementedError``.
        infer_batch_size: batch size for workers.
        show_progress_bar: whether to display a progress bar.

    Returns:
        samples: posterior samples obtained by performing inference on xs
            under the posterior, shape
            ``(num_posterior_samples, len(xs), num_dims)``.
    """
    # Guard clause: parallel execution is not available yet.
    if num_workers != 1:
        raise NotImplementedError('parallel execution is currently not implemented')

    num_sim_samples = xs.shape[0]
    batches = torch.split(xs, infer_batch_size, dim=0)

    progress = tqdm(
        total=num_sim_samples,
        disable=not show_progress_bar,
        desc=f"Running {num_sim_samples} samples for tarp analysis.",
    )

    collected = []
    with progress:
        for batch in batches:
            collected.append(
                _infer_posterior_on_batch(batch, posterior, num_posterior_samples)
            )
            progress.update(infer_batch_size)

    # Concatenate along the simulation axis.
    return torch.cat(collected, dim=1)
+ show_progress_bar: whether to display a progress bar + + Returns: + samples: posterior samples obtained by performing inference on xs + under the posterior + + """ + num_sim_samples = xs.shape[0] + xs_batches = torch.split(xs, infer_batch_size, dim=0) + + if num_workers != 1: + raise NotImplementedError('parallel execution is currently not implemented') + else: + pbar = tqdm( + total=num_sim_samples, + disable=not show_progress_bar, + desc=f"Running {num_sim_samples} samples for tarp analysis.", + ) + + with pbar: + samples = [] + for xs_batch in xs_batches: + samples.append( + _infer_posterior_on_batch( + xs_batch, posterior, num_posterior_samples + ) + ) + pbar.update(infer_batch_size) + samples = torch.cat(samples, dim=1) + + return samples + + +def _check_references( + thetas: Tensor, references: Optional[Tensor] = None, rng_seed: Optional[int] = None +) -> Tensor: + """ + construct or correct references tensor required to perform TARP + + Args: + thetas: The ground truth theta values. + references: reference values for theta drawn from an arbitrary + distribution. According to the TARP paper, it is irrelevant + how these values are produced. The tensor needs to be of + the same shape as thetas. + rng_seed: Seed for the ``torch.random`` random number generator. + If rng_seed is None, no seed is set. 
+ """ + + if not isinstance(references, Tensor): + if not isinstance(rng_seed, type(None)): + torch.random.manual_seed(rng_seed) + + # obtain min/max per dimension of theta + lo = thetas.min(dim=0).values # min for each theta dimension + hi = thetas.max(dim=0).values # max for each theta dimension + + refpdf = torch.distributions.Uniform(low=lo, high=hi) + # sample one reference point for each entry in theta + references = refpdf.sample(torch.Size([thetas.shape[0]])) + else: + if len(references.shape) == 2: + # add singleton dimension in front + references = references.unsqueeze(0) + + if len(references.shape) == 3 and references.shape[0] != 1: + raise ValueError( + f"""references must be a 2D array with a singular first + dimension, received {references.shape}""" + ) + + assert references.shape[-2:] == thetas.shape[-2:], f"""shape mismatch between + references {references.shape} and ground truth theta {thetas.shape}""" + + return references + + +def _run_tarp( + samples: Tensor, + theta: Tensor, + references: Optional[Tensor] = None, + distance: Callable = l2, + num_bins: Optional[int] = 30, + do_norm: bool = False, + rng_seed: Optional[int] = None, +) -> Tuple[Tensor, Tensor]: + """ + Estimates coverage of samples given true values theta with the TARP method. + Reference: `Lemos, Coogan et al 2023 `_ + + The TARP diagnostic is a global diagnostic which can be used to check a + trained posterior against a set of true values of theta. + + Args: + samples: The predicted parameter samples to compute the coverage of, + these samples are expected to have shape + ``(num_samples, num_sims, num_dims)``. These are obtained by + sampling a trained posterior `num_samples` times. Multiple + (posterior) samples for one observation are encouraged. + theta: The true parameter value theta. Theta is expected to + have shape ``(num_sims, num_dims)``. + references: the reference points to use for the coverage regions, with + shape ``(1, num_sims, num_dims)``, or ``None``. 
def _run_tarp(
    samples: Tensor,
    theta: Tensor,
    references: Optional[Tensor] = None,
    distance: Callable = l2,
    num_bins: Optional[int] = 30,
    do_norm: bool = False,
    rng_seed: Optional[int] = None,
) -> Tuple[Tensor, Tensor]:
    """
    Estimates coverage of samples given true values theta with the TARP method.
    Reference: `Lemos, Coogan et al 2023 <https://arxiv.org/abs/2302.03026>`_

    The TARP diagnostic is a global diagnostic which can be used to check a
    trained posterior against a set of true values of theta.

    Args:
        samples: The predicted parameter samples to compute the coverage of,
            these samples are expected to have shape
            ``(num_samples, num_sims, num_dims)``. These are obtained by
            sampling a trained posterior `num_samples` times. Multiple
            (posterior) samples for one observation are encouraged.
        theta: The true parameter value theta. Theta is expected to
            have shape ``(num_sims, num_dims)``.
        references: the reference points to use for the coverage regions,
            shape ``(num_sims, num_dims)`` or ``(1, num_sims, num_dims)``,
            or ``None``. If ``None``, reference points are chosen uniformly
            from the per-dimension range spanned by theta, i.e. drawn from
            ``Uniform(low=theta.min(dim=0), high=theta.max(dim=0))``.
        distance: the distance metric to use when computing the distance.
            Should be a callable function that accepts two tensors and
            computes the distance between them, e.g. given two tensors
            of shape ``(batch, 3)`` and ``(batch, 3)``, this function should
            return ``(batch,)`` distance values.
            Possible values: ``sbi.utils.metrics.l1`` or
            ``sbi.utils.metrics.l2``. ``l2`` is the default.
        num_bins: number of bins to use for the credibility values.
            If ``None``, then ``num_sims // 10`` bins are used.
        do_norm: whether to normalize parameters before coverage test.
            (Default = False)
        rng_seed: seed for ``torch.random`` used when sampling reference
            points; no seed is set if None is received. (Default = None)

    Returns:
        ecp: Expected coverage probability (``ecp``), see equation 4 of the paper.
        alpha: credibility values, see equation 2 of the paper.
    """
    # TARP assumes that the predicted thetas are sampled from the "true"
    # PDF num_samples times.
    assert (
        len(theta.shape) == 2
    ), f"theta must be of shape (num_sims, num_dims), received {theta.shape}"
    assert (
        len(samples.shape) == 3
    ), f"""samples must be of shape (num_samples, num_sims, num_dims),
    received {samples.shape}"""

    assert (
        theta.shape == samples.shape[1:]
    ), f"shapes of theta {theta.shape} and samples {samples.shape[1:]} do not fit"

    num_samples, num_sims, num_dims = samples.shape  # samples per simulation

    if num_bins is None:
        num_bins = num_sims // 10

    if do_norm:
        # min/max over the simulation axis, broadcast over samples and theta
        lo = theta.min(dim=0, keepdim=True).values
        hi = theta.max(dim=0, keepdim=True).values
        samples = (samples - lo) / (hi - lo + 1e-10)
        theta = (theta - lo) / (hi - lo + 1e-10)

    # BUGFIX: forward rng_seed to the reference construction. It was
    # previously accepted by this function but silently ignored, so seeded
    # runs did not produce reproducible reference points.
    references = _check_references(theta, references, rng_seed=rng_seed)

    assert (
        references.shape == samples.shape[1:]
    ), f"""reference.shape must match the last two dimensions of samples:
    {references.shape} vs {samples.shape}."""

    # distances between references and samples
    sample_dists = distance(references, samples)

    # distances between references and true values
    theta_dists = distance(references, theta)

    # compute coverage, f in algorithm 2 of the paper
    coverage_values = torch.sum(sample_dists < theta_dists, dim=0) / num_samples
    hist, bin_edges = torch.histogram(coverage_values, density=True, bins=num_bins)
    stepsize = bin_edges[1] - bin_edges[0]
    ecp = torch.cumsum(hist, dim=0) * stepsize

    # prepend 0 so that the ecp curve starts at the origin
    return torch.cat([Tensor([0]), ecp]), bin_edges


def run_tarp(
    thetas: Tensor,
    xs: Tensor,
    posterior: NeuralPosterior,
    num_posterior_samples: int = 1000,
    num_workers: int = 1,
    show_progress_bar: bool = True,
    distance: Callable = l2,
    num_bins: Optional[int] = 30,
    do_norm: bool = True,
    rng_seed: Optional[int] = None,
) -> Tuple[Tensor, Tensor]:
    """
    Estimates coverage of samples given true values thetas with the TARP method.
    Reference: `Lemos, Coogan et al 2023 <https://arxiv.org/abs/2302.03026>`_

    The TARP diagnostic is a global diagnostic which can be used to check a
    trained posterior against a set of true values of theta.

    Args:
        thetas: ground-truth parameters for tarp, simulated from the prior.
        xs: observed data for tarp, simulated from thetas.
        posterior: a posterior obtained from sbi.
        num_posterior_samples: number of approximate posterior samples used
            for ranking.
        num_workers: number of CPU cores to use in parallel for running
            inferences (values other than 1 are currently unsupported).
        show_progress_bar: whether to display a progress bar over tarp runs.
        distance: the distance metric to use when computing the distance.
            Should be a callable function that accepts two tensors and
            computes the distance between them.
            Possible values: ``sbi.utils.metrics.l1`` or
            ``sbi.utils.metrics.l2``. ``l2`` is the default.
        num_bins: number of bins to use for the credibility values.
            If ``None``, then ``num_sims // 10`` bins are used.
        do_norm: whether to normalize parameters before coverage test.
            (Default = True)
        rng_seed: seed for ``torch.random``; no seed is set if None is
            received. (Default = None)

    Returns:
        ecp: Expected coverage probability (``ecp``), see equation 4 of the paper.
        alpha: credibility values, see equation 2 of the paper.
    """
    # draw num_posterior_samples posterior samples for every observation
    samples = _prepare_estimates(
        xs,
        posterior,
        num_posterior_samples,
        num_workers,
        show_progress_bar=show_progress_bar,
    )

    ecp, alpha = _run_tarp(
        samples,
        thetas,
        distance=distance,
        num_bins=num_bins,
        do_norm=do_norm,
        rng_seed=rng_seed,
    )

    return ecp, alpha
def check_tarp(
    ecp: Tensor,
    alpha: Tensor,
) -> Tuple[float, float]:
    r"""Summarize the TARP curve into two scalar diagnostics.

    Helps to uncover underdispersed, well covering or overdispersed
    posteriors.

    Args:
        ecp: expected coverage probabilities computed with the TARP method,
            i.e. first output of ``run_tarp``.
        alpha: credibility levels $\alpha$, i.e. second output of ``run_tarp``.

    Returns:
        atc: area to curve for large values of alpha; should be close to
            ``0``. Values larger than ``0`` indicate overdispersed
            distributions (estimated posterior too wide), values smaller
            than ``0`` underdispersed ones (too narrow). This property of
            the ecp curve can also indicate a biased posterior, see figure 2
            of the paper (https://arxiv.org/abs/2302.03026).
        ks prob: p-value for a two sample Kolmogorov-Smirnov test. The null
            hypothesis is that ecp and alpha are produced by one common CDF;
            if they were, the p-value should be close to ``1``. Commonly,
            the null is rejected below 0.05.
    """
    # only the upper half of the alpha grid enters the area-to-curve value
    half = alpha.shape[0] // 2
    atc = (ecp[half:] - alpha[half:]).sum().item()

    # two-sample KS test: do ecp and alpha stem from one common CDF?
    _, pvalue = kstest(ecp.numpy(), alpha.numpy())

    return atc, pvalue
def plot_tarp(ecp: Tensor, alpha: Tensor, title="") -> Tuple[Figure, Axes]:
    """
    Plots the expected coverage probability (ECP) against the credibility
    level, alpha, for a given alpha grid.

    Args:
        ecp: array of expected coverage probabilities, first output of
            ``run_tarp``.
        alpha: array of credibility levels, second output of ``run_tarp``.
        title: title for the plot. The default is "".

    Returns:
        fig: the matplotlib figure object.
        ax: the matplotlib axes object.
    """
    fig = plt.figure(figsize=(6, 6))
    ax: Axes = plt.gca()

    ax.plot(alpha, ecp, color="blue", label="TARP")
    # the diagonal marks perfect calibration: ecp == alpha everywhere
    ax.plot(alpha, alpha, color="black", linestyle="--", label="ideal")
    ax.set_xlabel(r"Credibility Level $\alpha$")
    # BUGFIX: corrected typo in the user-facing axis label
    # ("Probility" -> "Probability").
    ax.set_ylabel(r"Expected Coverage Probability")
    ax.set_xlim(0.0, 1.0)
    ax.set_ylim(0.0, 1.0)
    ax.set_title(title)
    ax.legend()
    return fig, ax  # type: ignore


# ---- sbi/utils/metrics.py (added by the same diff) ----


def l2(x: Tensor, y: Tensor, axis=-1) -> Tensor:
    """
    Calculates the L2 distance between two tensors. Note, we cannot use the
    torch.nn.MSELoss function as this sums across the batch dimension AND the
    dimension given by ``axis``. For tarp, we only require to sum across
    the ``axis`` dimension.

    Args:
        x (Tensor): The first tensor.
        y (Tensor): The second tensor (must be broadcastable against ``x``).
        axis (int, optional): The axis along which to calculate the L2 distance.
            Defaults to -1.
    Returns:
        Tensor: A tensor containing the L2 distance between x and y along the
        specified axis.
    """
    return torch.sqrt(torch.sum((x - y) ** 2, dim=axis))
def l1(x: Tensor, y: Tensor, axis=-1) -> Tensor:
    """
    Calculates the L1 distance between two tensors. Note, we cannot use the
    torch.nn.L1Loss function as this sums across the batch dimension AND the
    dimension given by ``axis``. For tarp, we only require to sum across
    the ``axis`` dimension.

    Args:
        x (Tensor): The first tensor.
        y (Tensor): The second tensor.
        axis (int, optional): The axis along which to calculate the L1 distance.
            Defaults to -1.
    Returns:
        Tensor: A tensor containing the L1 distance between x and y along the
        specified axis.
    """
    return torch.sum(torch.abs(x - y), dim=axis)


def main():
    # Run the module's self-tests (``_test`` is defined elsewhere in metrics.py).
    _test()


# ---- tests/tarp_test.py (added by the same diff) ----

import pytest
from scipy.stats import uniform
from torch import Tensor, allclose, exp, eye, ones
from torch.distributions import Normal, Uniform
from torch.nn import L1Loss

from sbi.diagnostics.tarp import (
    _infer_posterior_on_batch,
    _prepare_estimates,
    _run_tarp,
    check_tarp,
    run_tarp,
)
from sbi.inference import SNPE, simulate_for_sbi
from sbi.simulators import linear_gaussian
from sbi.utils import BoxUniform
from sbi.utils.metrics import l1, l2


def generate_toy_gaussian(nsamples=100, nsims=100, ndims=5, covfactor=1.0):
    """Toy Gaussian model adopted from the tarp paper, page 7, section 4.1,
    correct case. ``covfactor`` scales the sampling distribution's width
    relative to the true posterior (1.0 = calibrated)."""

    # random per-simulation means and (log-)variances
    base_mean = Uniform(-5, 5)
    base_log_var = Uniform(-5, -1)

    locs = base_mean.sample((nsims, ndims))
    scales = exp(base_log_var.sample((nsims, ndims)))

    # spdf plays the role of the estimated posterior, tpdf of the true one
    spdf = Normal(loc=locs, scale=covfactor * scales)
    tpdf = Normal(loc=locs, scale=scales)

    samples = spdf.sample((nsamples,))
    theta_prime = tpdf.sample()

    return theta_prime, samples


def biased_toy_gaussian(nsamples=100, nsims=100, ndims=5, covfactor=1.0):
    """Toy Gaussian model adopted from the tarp paper, page 7, section 4.1,
    biased case: the sampling distribution's means are shifted away from the
    true means."""

    base_mean = Uniform(-5, 5)
    base_mean_ = uniform(-5, 5)
    base_log_var = Uniform(-5, -1)

    locs_ = base_mean.sample((nsims, ndims))
    scales = exp(base_log_var.sample((nsims, ndims)))
    # NOTE(review): scipy's ``uniform(-5, 5)`` means loc=-5, scale=5, i.e.
    # support [-5, 0] (not [-5, 5]), and ``.isf`` expects probabilities in
    # [0, 1]; for ``locs_`` values outside that range isf returns NaN.
    # Confirm this bias construction is intended.
    locs = locs_ - locs_.sign() * base_mean_.isf(locs_) * scales

    spdf = Normal(loc=locs, scale=covfactor * scales)
    tpdf = Normal(loc=locs, scale=scales)

    samples = spdf.sample((nsamples,))
    theta_prime = tpdf.sample()

    return theta_prime, samples
@pytest.fixture
def onsamples():
    # calibrated case: 100 posterior samples for each of 100 simulations
    # over 5 parameter dimensions
    return generate_toy_gaussian(nsamples=100, nsims=100, ndims=5)


@pytest.fixture
def undersamples():
    # paper page 7, section 4.1 Gaussian Toy Model, underconfident case
    return generate_toy_gaussian(nsamples=100, nsims=100, ndims=5, covfactor=0.25)


@pytest.fixture
def oversamples():
    # paper page 7, section 4.1 Gaussian Toy Model, overconfident case
    return generate_toy_gaussian(nsamples=100, nsims=100, ndims=5, covfactor=4.0)


@pytest.fixture
def biased():
    # biased toy posterior with inflated covariance
    return biased_toy_gaussian(nsamples=100, nsims=100, ndims=5, covfactor=2.0)


def test_onsamples(onsamples):
    theta, samples = onsamples

    assert theta.shape in ((100, 5), (1, 100, 5))
    assert samples.shape == (100, 100, 5)


def test_undersamples(undersamples):
    theta, samples = undersamples

    assert theta.shape in ((100, 5), (1, 100, 5))
    assert samples.shape == (100, 100, 5)


def test_biased(biased):
    theta, samples = biased

    assert theta.shape in ((100, 5), (1, 100, 5))
    assert samples.shape == (100, 100, 5)
######################################################################
## test TARP library


def test_distances(onsamples):
    theta, samples = onsamples

    # both metrics reduce only the parameter axis -> (num_samples, num_sims)
    assert l2(theta, samples).shape == (100, 100)

    dists = l1(theta, samples)
    assert dists.shape == (100, 100)

    # torch.nn.L1Loss collapses batch AND feature axes, so its reduction
    # cannot replace l1 here
    l1loss = L1Loss(reduction="sum")
    broadcasted_theta = theta.expand(samples.shape[0], -1, -1)
    fully_summed = l1loss(broadcasted_theta, samples)
    assert dists.shape != fully_summed.shape  # gives the wrong shape

    # expanding theta explicitly must not change the l1 result
    assert allclose(dists, l1(broadcasted_theta, samples))


######################################################################
## Reproduce Toy Examples in paper, see Section 4.1


def test_run_tarp_correct(onsamples):
    theta, samples = onsamples

    # TARP accepts the calibrated posterior under both distance metrics
    for metric in (l2, l1):
        ecp, alpha = _run_tarp(samples, theta, distance=metric, num_bins=30)
        assert allclose((ecp - alpha).abs().max(), Tensor([0.0]), atol=1e-1)


def test_run_tarp_correct_using_norm(onsamples):
    theta, samples = onsamples

    ecp, alpha = _run_tarp(samples, theta, num_bins=30, do_norm=False)

    assert allclose((ecp - alpha).abs().max(), Tensor([0.0]), atol=1e-1)
    # integral of residuals should vanish, fig.2 in paper
    assert (ecp - alpha).abs().sum() < 1.0

    ecp, alpha = _run_tarp(samples, theta, num_bins=30, do_norm=False, distance=l1)

    # TARP detects that this is a correct representation of the posterior
    assert allclose((ecp - alpha).abs().max(), Tensor([0.0]), atol=1e-1)


def test_run_tarp_detect_overdispersed(oversamples):
    theta, samples = oversamples

    ecp, alpha = _run_tarp(samples, theta, num_bins=30, do_norm=True)

    residuals = (ecp - alpha).abs()
    assert not allclose(residuals.max(), Tensor([0.0]), atol=1e-1)
    assert residuals.sum() > 3.0  # integral is nonzero, fig.2 in paper

    ecp, alpha = _run_tarp(samples, theta, num_bins=30, do_norm=True, distance=l1)

    # TARP detects that this is NOT a correct representation of the
    # posterior, hence we test for not allclose
    assert not allclose((ecp - alpha).abs().max(), Tensor([0.0]), atol=1e-1)
def test_run_tarp_detect_underdispersed(undersamples):
    theta, samples = undersamples

    ecp, alpha = _run_tarp(samples, theta, num_bins=30, do_norm=True)

    residuals = (ecp - alpha).abs()
    assert not allclose(residuals.max(), Tensor([0.0]), atol=1e-1)
    assert residuals.sum() > 3.0  # integral is nonzero, fig.2 in paper

    ecp, alpha = _run_tarp(samples, theta, num_bins=30, do_norm=True, distance=l1)

    # TARP detects that this is NOT a correct representation of the
    # posterior, hence we test for not allclose
    assert not allclose((ecp - alpha).abs().max(), Tensor([0.0]), atol=1e-1)


def test_run_tarp_detect_bias(biased):
    theta, samples = biased

    ecp, alpha = _run_tarp(samples, theta, num_bins=30, do_norm=True)

    residuals = (ecp - alpha).abs()
    assert not allclose(residuals.max(), Tensor([0.0]), atol=1e-1)
    assert residuals.sum() > 3.0  # integral is nonzero, fig.2 in paper

    ecp, alpha = _run_tarp(samples, theta, num_bins=30, do_norm=True, distance=l1)

    # TARP detects that this is NOT a correct representation of the
    # posterior, hence we test for not allclose
    assert not allclose((ecp - alpha).abs().max(), Tensor([0.0]), atol=1e-1)


def test_check_tarp_correct(onsamples):
    theta, samples = onsamples

    ecp, alpha = _run_tarp(samples, theta, num_bins=30, do_norm=False)
    print("onsamples")
    print("tarp results", ecp, alpha)
    atc, kspvals = check_tarp(ecp, alpha)

    print("tarp checks", atc, kspvals)
    # area-to-curve is small but nonzero for a calibrated posterior
    assert atc != 0.0
    assert atc < 1.0

    assert kspvals > 0.05  # samples are likely from the same PDF


def test_check_tarp_underdispersed(undersamples):
    theta, samples = undersamples

    ecp, alpha = _run_tarp(samples, theta, num_bins=30, do_norm=False)
    print("underdispersed")
    print("tarp results", ecp, alpha)
    atc, kspvals = check_tarp(ecp, alpha)

    print("tarp checks", atc, kspvals)

    # a strongly negative area-to-curve flags the underdispersed posterior
    assert atc != 0.0
    assert atc < -2.0
    # assert atc < -1.0 # TODO: need to check why this breaks

    # TODO: need to check why this breaks
    assert kspvals < 0.2  # samples are unlikely from the same PDF
def test_check_tarp_overdispersed(oversamples):
    theta, samples = oversamples

    ecp, alpha = _run_tarp(samples, theta, num_bins=50, do_norm=False)
    print("overdispersed")
    print("tarp results", ecp, alpha)
    atc, kspvals = check_tarp(ecp, alpha)

    print("tarp checks", atc, kspvals)

    # a strongly positive area-to-curve flags the overdispersed posterior
    assert atc != 0.0
    assert atc > 2.0

    assert kspvals < 0.05  # samples are unlikely from the same PDF


def test_check_tarp_detect_bias(biased):
    theta, samples = biased

    ecp, alpha = _run_tarp(samples, theta, num_bins=30, do_norm=True)
    print("biased")
    print("tarp results", ecp, alpha)
    atc, kspvals = check_tarp(ecp, alpha)

    print("tarp checks", atc, kspvals)
    # a biased posterior also produces a clearly nonzero area-to-curve
    assert atc != 0.0
    assert atc > 1.0

    assert kspvals < 0.05  # samples are unlikely from the same PDF


######################################################################
## Check TARP with SBI


@pytest.mark.slow
@pytest.mark.parametrize("method", [SNPE])
def test_batched_prepare_estimates(method, model="mdn"):
    """Tests running inference and checking samples with tarp."""

    # linear-Gaussian toy problem with a uniform prior over 2 dimensions
    num_dim = 2
    prior = BoxUniform(-ones(num_dim), ones(num_dim))

    num_simulations = 1000
    max_num_epochs = 20
    num_tarp_runs = 100

    likelihood_shift = -1.0 * ones(num_dim)
    likelihood_cov = 0.3 * eye(num_dim)

    def simulator(theta):
        return linear_gaussian(theta, likelihood_shift, likelihood_cov)

    inferer = method(prior, show_progress_bars=False, density_estimator=model)

    theta, x = simulate_for_sbi(simulator, prior, num_simulations)

    # short training run; only the output shapes are under test here
    _ = inferer.append_simulations(theta, x).train(
        training_batch_size=100, max_num_epochs=max_num_epochs
    )

    posterior = inferer.build_posterior()
    num_posterior_samples = 256
    thetas = prior.sample((num_tarp_runs,))
    xs = simulator(thetas)

    samples = _infer_posterior_on_batch(xs, posterior, num_posterior_samples)

    # expected layout: (num_posterior_samples, num_tarp_runs, num_dim)
    assert samples.shape != thetas.shape
    assert samples.shape[1:] == thetas.shape
    assert samples.shape[0] == num_posterior_samples

    # batched preparation must yield the same layout as the direct call
    samples_ = _prepare_estimates(
        xs, posterior, num_posterior_samples, infer_batch_size=32
    )

    assert samples_.shape != thetas.shape
    assert samples_.shape[1:] == thetas.shape
    assert samples_.shape[0] == num_posterior_samples
    assert samples_.shape == samples.shape


@pytest.mark.slow
@pytest.mark.parametrize("method", [SNPE])
def test_consistent_run_tarp_results_with_posterior(method, model="mdn"):
    """Tests running inference and checking samples with tarp."""

    # linear-Gaussian toy problem with a uniform prior over 2 dimensions
    num_dim = 2
    prior = BoxUniform(-ones(num_dim), ones(num_dim))

    num_simulations = 6000
    num_tarp_sims = 1000
    num_posterior_samples = 1000

    likelihood_shift = -1.0 * ones(num_dim)
    likelihood_cov = 0.3 * eye(num_dim)

    def simulator(theta):
        return linear_gaussian(theta, likelihood_shift, likelihood_cov)

    inferer = method(prior, show_progress_bars=True, density_estimator=model)

    theta = prior.sample((num_simulations,))
    x = simulator(theta)

    _ = inferer.append_simulations(theta, x).train(training_batch_size=1000)

    posterior = inferer.build_posterior()

    thetas = prior.sample((num_tarp_sims,))
    xs = simulator(thetas)

    ecp, alpha = run_tarp(
        thetas,
        xs,
        posterior=posterior,
        num_posterior_samples=num_posterior_samples,
        num_bins=30,
        do_norm=True,
        rng_seed=41,
    )

    # a well-trained posterior on this problem should pass both TARP checks
    atc, kspvals = check_tarp(ecp, alpha)
    print(atc, kspvals)
    assert -0.5 < atc < 0.5
    assert kspvals > 0.05