From d3f22b5cdad83b4d9c7cfa24ec70634a74301276 Mon Sep 17 00:00:00 2001 From: Jan Date: Fri, 13 Dec 2024 09:21:52 +0100 Subject: [PATCH] expose batched sampling option in diagnostics (#1321) * expose batched sampling option; error handling * further improvements * undo batch_size option --- sbi/diagnostics/sbc.py | 3 +-- sbi/diagnostics/tarp.py | 3 +++ sbi/utils/diagnostics_utils.py | 44 +++++++++++++++++++++++----------- 3 files changed, 34 insertions(+), 16 deletions(-) diff --git a/sbi/diagnostics/sbc.py b/sbi/diagnostics/sbc.py index 776727735..ba01fb0a8 100644 --- a/sbi/diagnostics/sbc.py +++ b/sbi/diagnostics/sbc.py @@ -48,8 +48,7 @@ def run_sbc( num_workers: number of CPU cores to use in parallel for running `num_sbc_samples` inferences. show_progress_bar: whether to display a progress over sbc runs. - use_batched_sampling: whether to use batched sampling for posterior - samples. + use_batched_sampling: whether to use batched sampling for posterior samples. Returns: ranks: ranks of the ground truth parameters under the inferred diff --git a/sbi/diagnostics/tarp.py b/sbi/diagnostics/tarp.py index 3fe2f1e52..44ff114f3 100644 --- a/sbi/diagnostics/tarp.py +++ b/sbi/diagnostics/tarp.py @@ -28,6 +28,7 @@ def run_tarp( distance: Callable = l2, num_bins: Optional[int] = 30, z_score_theta: bool = True, + use_batched_sampling: bool = True, ) -> Tuple[Tensor, Tensor]: """ Estimates coverage of samples given true values thetas with the TARP method. @@ -54,6 +55,7 @@ def run_tarp( num_bins: number of bins to use for the credibility values. If ``None``, then ``num_sims // 10`` bins are used. z_score_theta : whether to normalize parameters before coverage test. + use_batched_sampling: whether to use batched sampling for posterior samples. Returns: ecp: Expected coverage probability (``ecp``), see equation 4 of the paper @@ -67,6 +69,7 @@ def run_tarp( (num_posterior_samples,), num_workers, show_progress_bar=show_progress_bar, + use_batched_sampling=use_batched_sampling, ) assert posterior_samples.shape == ( num_posterior_samples, diff --git a/sbi/utils/diagnostics_utils.py b/sbi/utils/diagnostics_utils.py index e68e1ab79..d53cd536c 100644 --- a/sbi/utils/diagnostics_utils.py +++ b/sbi/utils/diagnostics_utils.py @@ -1,9 +1,12 @@ +import warnings + import torch from joblib import Parallel, delayed from torch import Tensor from tqdm import tqdm from sbi.inference.posteriors.base_posterior import NeuralPosterior +from sbi.inference.posteriors.mcmc_posterior import MCMCPosterior from sbi.inference.posteriors.vi_posterior import VIPosterior from sbi.sbi_types import Shape @@ -29,18 +32,23 @@ def get_posterior_samples_on_batch( Returns: posterior_samples: of shape (num_samples, batch_size, dim_parameters). """ - batch_size = len(xs) + num_xs = len(xs) - # Try using batched sampling when implemented. - try: - # has shape (num_samples, batch_size, dim_parameters) - if use_batched_sampling: + if use_batched_sampling: + try: + # has shape (num_samples, num_xs, dim_parameters) posterior_samples = posterior.sample_batched( sample_shape, x=xs, show_progress_bars=show_progress_bar ) - else: - raise NotImplementedError - except NotImplementedError: + except (NotImplementedError, AssertionError): + warnings.warn( + "Batched sampling not implemented for this posterior. " + "Falling back to non-batched sampling.", + stacklevel=2, + ) + use_batched_sampling = False + + if not use_batched_sampling: # We need a function with extra training step for new x for VIPosterior. def sample_fun( posterior: NeuralPosterior, sample_shape: Shape, x: Tensor, seed: int = 0 @@ -51,8 +59,16 @@ def sample_fun( torch.manual_seed(seed) return posterior.sample(sample_shape, x=x, show_progress_bars=False) + if isinstance(posterior, (VIPosterior, MCMCPosterior)): + warnings.warn( + "Using non-batched sampling. Depending on the number of different xs " + f"( {num_xs}) and the number of parallel workers {num_workers}, " + "this might take a lot of time.", + stacklevel=2, + ) + # Run in parallel with progress bar. - seeds = torch.randint(0, 2**32, (batch_size,)) + seeds = torch.randint(0, 2**32, (num_xs,)) outputs = list( tqdm( Parallel(return_as="generator", n_jobs=num_workers)( @@ -61,7 +77,7 @@ def sample_fun( ), disable=not show_progress_bar, total=len(xs), - desc=f"Sampling {batch_size} times {sample_shape} posterior samples.", + desc=f"Sampling {num_xs} times {sample_shape} posterior samples.", ) ) # (batch_size, num_samples, dim_parameters) # Transpose to shape convention: (sample_shape, batch_size, dim_parameters) @@ -70,8 +86,8 @@ def sample_fun( ).permute(1, 0, 2) assert posterior_samples.shape[:2] == sample_shape + ( - batch_size, - ), f"""Expected batched posterior samples of shape { - sample_shape + (batch_size,) - } got {posterior_samples.shape[:2]}.""" + num_xs, + ), f"""Expected batched posterior samples of shape {sample_shape + (num_xs,)} got { + posterior_samples.shape[:2] + }.""" return posterior_samples