diff --git a/sbi/diagnostics/sbc.py b/sbi/diagnostics/sbc.py
index 59db21450..ba01fb0a8 100644
--- a/sbi/diagnostics/sbc.py
+++ b/sbi/diagnostics/sbc.py
@@ -2,7 +2,7 @@
 # under the Apache License Version 2.0, see
 
 import warnings
-from typing import Callable, Dict, List, Optional, Union
+from typing import Callable, Dict, List, Union
 
 import torch
 from scipy.stats import kstest, uniform
@@ -26,7 +26,6 @@ def run_sbc(
     num_workers: int = 1,
     show_progress_bar: bool = True,
     use_batched_sampling: bool = True,
-    batch_size: Optional[int] = None,
     **kwargs,
 ):
     """Run simulation-based calibration (SBC) (parallelized across sbc runs).
@@ -50,8 +49,6 @@
             `num_sbc_samples` inferences.
         show_progress_bar: whether to display a progress over sbc runs.
         use_batched_sampling: whether to use batched sampling for posterior samples.
-        batch_size: batch size for batched sampling. Useful for batched sampling with
-            large batches of xs for avoiding memory overflow.
 
     Returns:
         ranks: ranks of the ground truth parameters under the inferred
@@ -92,7 +89,6 @@
         num_workers,
         show_progress_bar,
         use_batched_sampling=use_batched_sampling,
-        batch_size=batch_size,
     )
 
     # take a random draw from each posterior to get data averaged posterior samples.
diff --git a/sbi/diagnostics/tarp.py b/sbi/diagnostics/tarp.py
index d9226fb8f..44ff114f3 100644
--- a/sbi/diagnostics/tarp.py
+++ b/sbi/diagnostics/tarp.py
@@ -29,7 +29,6 @@ def run_tarp(
     num_bins: Optional[int] = 30,
     z_score_theta: bool = True,
     use_batched_sampling: bool = True,
-    batch_size: Optional[int] = None,
 ) -> Tuple[Tensor, Tensor]:
     """
     Estimates coverage of samples given true values thetas with the TARP method.
@@ -57,8 +56,6 @@
             If ``None``, then ``num_sims // 10`` bins are used.
         z_score_theta : whether to normalize parameters before coverage test.
         use_batched_sampling: whether to use batched sampling for posterior samples.
-        batch_size: batch size for batched sampling. Useful for batched sampling with
-            large batches of xs for avoiding memory overflow.
 
     Returns:
         ecp: Expected coverage probability (``ecp``), see equation 4 of the paper
@@ -73,7 +70,6 @@
         num_workers,
         show_progress_bar=show_progress_bar,
         use_batched_sampling=use_batched_sampling,
-        batch_size=batch_size,
     )
     assert posterior_samples.shape == (
         num_posterior_samples,
diff --git a/sbi/utils/diagnostics_utils.py b/sbi/utils/diagnostics_utils.py
index b3ad038c1..d53cd536c 100644
--- a/sbi/utils/diagnostics_utils.py
+++ b/sbi/utils/diagnostics_utils.py
@@ -1,5 +1,4 @@
 import warnings
-from typing import Optional
 
 import torch
 from joblib import Parallel, delayed
@@ -19,7 +18,6 @@ def get_posterior_samples_on_batch(
     num_workers: int = 1,
     show_progress_bar: bool = False,
     use_batched_sampling: bool = True,
-    batch_size: Optional[int] = None,
 ) -> Tensor:
     """Get posterior samples for a batch of xs.
 
@@ -30,27 +28,17 @@
         num_workers: number of workers to use for parallelization.
         show_progress_bars: whether to show progress bars.
         use_batched_sampling: whether to use batched sampling if possible.
-        batch_size: batch size for batched sampling. Useful for batched sampling with
-            large batches of xs for avoiding memory overflow.
+
     Returns:
         posterior_samples: of shape (num_samples, batch_size, dim_parameters).
     """
     num_xs = len(xs)
-    if batch_size is None:
-        batch_size = num_xs
 
     if use_batched_sampling:
        try:
-            # distribute the batch of xs into smaller batches
-            batched_xs = xs.split(batch_size)
-            posterior_samples = torch.cat(
-                [  # has shape (num_samples, num_xs, dim_parameters)
-                    posterior.sample_batched(
-                        sample_shape, x=xs_batch, show_progress_bars=show_progress_bar
-                    )
-                    for xs_batch in batched_xs
-                ],
-                dim=1,
+            # has shape (num_samples, num_xs, dim_parameters)
+            posterior_samples = posterior.sample_batched(
+                sample_shape, x=xs, show_progress_bars=show_progress_bar
             )
         except (NotImplementedError, AssertionError):
             warnings.warn(