From 9db9e9e3ae3fd45ee23812eee8286369ffeb8ea4 Mon Sep 17 00:00:00 2001
From: Jan
Date: Tue, 3 Dec 2024 17:59:04 +0100
Subject: [PATCH 1/3] expose batched sampling option; error handling

---
 sbi/diagnostics/sbc.py         |  3 +--
 sbi/diagnostics/tarp.py        |  3 +++
 sbi/utils/diagnostics_utils.py | 13 ++++++++++++-
 3 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/sbi/diagnostics/sbc.py b/sbi/diagnostics/sbc.py
index 776727735..ba01fb0a8 100644
--- a/sbi/diagnostics/sbc.py
+++ b/sbi/diagnostics/sbc.py
@@ -48,8 +48,7 @@ def run_sbc(
         num_workers: number of CPU cores to use in parallel for running
             `num_sbc_samples` inferences.
         show_progress_bar: whether to display a progress over sbc runs.
-        use_batched_sampling: whether to use batched sampling for posterior
-            samples.
+        use_batched_sampling: whether to use batched sampling for posterior samples.
 
     Returns:
         ranks: ranks of the ground truth parameters under the inferred
diff --git a/sbi/diagnostics/tarp.py b/sbi/diagnostics/tarp.py
index 3fe2f1e52..44ff114f3 100644
--- a/sbi/diagnostics/tarp.py
+++ b/sbi/diagnostics/tarp.py
@@ -28,6 +28,7 @@ def run_tarp(
     distance: Callable = l2,
     num_bins: Optional[int] = 30,
     z_score_theta: bool = True,
+    use_batched_sampling: bool = True,
 ) -> Tuple[Tensor, Tensor]:
     """
     Estimates coverage of samples given true values thetas with the TARP method.
@@ -54,6 +55,7 @@ def run_tarp(
         num_bins: number of bins to use for the credibility values.
             If ``None``, then ``num_sims // 10`` bins are used.
         z_score_theta : whether to normalize parameters before coverage test.
+        use_batched_sampling: whether to use batched sampling for posterior samples.
 
     Returns:
         ecp: Expected coverage probability (``ecp``), see equation 4 of the paper
@@ -67,6 +69,7 @@ def run_tarp(
         (num_posterior_samples,),
         num_workers,
         show_progress_bar=show_progress_bar,
+        use_batched_sampling=use_batched_sampling,
     )
     assert posterior_samples.shape == (
         num_posterior_samples,
diff --git a/sbi/utils/diagnostics_utils.py b/sbi/utils/diagnostics_utils.py
index e68e1ab79..c783316fd 100644
--- a/sbi/utils/diagnostics_utils.py
+++ b/sbi/utils/diagnostics_utils.py
@@ -1,9 +1,12 @@
+import warnings
+
 import torch
 from joblib import Parallel, delayed
 from torch import Tensor
 from tqdm import tqdm
 
 from sbi.inference.posteriors.base_posterior import NeuralPosterior
+from sbi.inference.posteriors.mcmc_posterior import MCMCPosterior
 from sbi.inference.posteriors.vi_posterior import VIPosterior
 from sbi.sbi_types import Shape
 
@@ -40,7 +43,7 @@ def get_posterior_samples_on_batch(
             )
         else:
             raise NotImplementedError
-    except NotImplementedError:
+    except (NotImplementedError, AssertionError):
         # We need a function with extra training step for new x for VIPosterior.
         def sample_fun(
             posterior: NeuralPosterior, sample_shape: Shape, x: Tensor, seed: int = 0
@@ -51,6 +54,14 @@ def sample_fun(
             torch.manual_seed(seed)
             return posterior.sample(sample_shape, x=x, show_progress_bars=False)
 
+        if isinstance(posterior, (VIPosterior, MCMCPosterior)):
+            warnings.warn(
+                "Using non-batched sampling. Depending on the number of different xs "
+                f"( {batch_size}) and the number of parallel workers {num_workers}, "
+                "this might be slow.",
+                stacklevel=2,
+            )
+
         # Run in parallel with progress bar.
         seeds = torch.randint(0, 2**32, (batch_size,))
         outputs = list(

From 6dd4a22ff99f5fe7f2c3fe0a3044daaf0e73c865 Mon Sep 17 00:00:00 2001
From: Jan
Date: Thu, 5 Dec 2024 13:59:10 +0100
Subject: [PATCH 2/3] further improvements

---
 sbi/diagnostics/sbc.py         |  6 +++-
 sbi/diagnostics/tarp.py        |  4 +++
 sbi/utils/diagnostics_utils.py | 55 ++++++++++++++++++++++------------
 3 files changed, 45 insertions(+), 20 deletions(-)

diff --git a/sbi/diagnostics/sbc.py b/sbi/diagnostics/sbc.py
index ba01fb0a8..59db21450 100644
--- a/sbi/diagnostics/sbc.py
+++ b/sbi/diagnostics/sbc.py
@@ -2,7 +2,7 @@
 # under the Apache License Version 2.0, see
 
 import warnings
-from typing import Callable, Dict, List, Union
+from typing import Callable, Dict, List, Optional, Union
 
 import torch
 from scipy.stats import kstest, uniform
@@ -26,6 +26,7 @@ def run_sbc(
     num_workers: int = 1,
     show_progress_bar: bool = True,
     use_batched_sampling: bool = True,
+    batch_size: Optional[int] = None,
     **kwargs,
 ):
     """Run simulation-based calibration (SBC) (parallelized across sbc runs.
@@ -49,6 +50,8 @@ def run_sbc(
             `num_sbc_samples` inferences.
         show_progress_bar: whether to display a progress over sbc runs.
         use_batched_sampling: whether to use batched sampling for posterior samples.
+        batch_size: batch size for batched sampling. Useful for batched sampling with
+            large batches of xs for avoiding memory overflow.
 
     Returns:
         ranks: ranks of the ground truth parameters under the inferred
@@ -89,6 +92,7 @@ def run_sbc(
         num_workers,
         show_progress_bar,
         use_batched_sampling=use_batched_sampling,
+        batch_size=batch_size,
     )
 
     # take a random draw from each posterior to get data averaged posterior samples.
diff --git a/sbi/diagnostics/tarp.py b/sbi/diagnostics/tarp.py
index 44ff114f3..d9226fb8f 100644
--- a/sbi/diagnostics/tarp.py
+++ b/sbi/diagnostics/tarp.py
@@ -29,6 +29,7 @@ def run_tarp(
     num_bins: Optional[int] = 30,
     z_score_theta: bool = True,
     use_batched_sampling: bool = True,
+    batch_size: Optional[int] = None,
 ) -> Tuple[Tensor, Tensor]:
     """
     Estimates coverage of samples given true values thetas with the TARP method.
@@ -56,6 +57,8 @@ def run_tarp(
             If ``None``, then ``num_sims // 10`` bins are used.
         z_score_theta : whether to normalize parameters before coverage test.
         use_batched_sampling: whether to use batched sampling for posterior samples.
+        batch_size: batch size for batched sampling. Useful for batched sampling with
+            large batches of xs for avoiding memory overflow.
 
     Returns:
         ecp: Expected coverage probability (``ecp``), see equation 4 of the paper
@@ -70,6 +73,7 @@ def run_tarp(
         num_workers,
         show_progress_bar=show_progress_bar,
         use_batched_sampling=use_batched_sampling,
+        batch_size=batch_size,
     )
     assert posterior_samples.shape == (
         num_posterior_samples,
diff --git a/sbi/utils/diagnostics_utils.py b/sbi/utils/diagnostics_utils.py
index c783316fd..b3ad038c1 100644
--- a/sbi/utils/diagnostics_utils.py
+++ b/sbi/utils/diagnostics_utils.py
@@ -1,4 +1,5 @@
 import warnings
+from typing import Optional
 
 import torch
 from joblib import Parallel, delayed
@@ -18,6 +19,7 @@ def get_posterior_samples_on_batch(
     num_workers: int = 1,
     show_progress_bar: bool = False,
     use_batched_sampling: bool = True,
+    batch_size: Optional[int] = None,
 ) -> Tensor:
     """Get posterior samples for a batch of xs.
 
@@ -28,22 +30,37 @@ def get_posterior_samples_on_batch(
         num_workers: number of workers to use for parallelization.
         show_progress_bars: whether to show progress bars.
         use_batched_sampling: whether to use batched sampling if possible.
-
+        batch_size: batch size for batched sampling. Useful for batched sampling with
+            large batches of xs for avoiding memory overflow.
     Returns:
         posterior_samples: of shape (num_samples, batch_size, dim_parameters).
     """
-    batch_size = len(xs)
+    num_xs = len(xs)
+    if batch_size is None:
+        batch_size = num_xs
 
-    # Try using batched sampling when implemented.
-    try:
-        # has shape (num_samples, batch_size, dim_parameters)
-        if use_batched_sampling:
-            posterior_samples = posterior.sample_batched(
-                sample_shape, x=xs, show_progress_bars=show_progress_bar
+    if use_batched_sampling:
+        try:
+            # distribute the batch of xs into smaller batches
+            batched_xs = xs.split(batch_size)
+            posterior_samples = torch.cat(
+                [  # has shape (num_samples, num_xs, dim_parameters)
+                    posterior.sample_batched(
+                        sample_shape, x=xs_batch, show_progress_bars=show_progress_bar
+                    )
+                    for xs_batch in batched_xs
+                ],
+                dim=1,
             )
-        else:
-            raise NotImplementedError
-    except (NotImplementedError, AssertionError):
+        except (NotImplementedError, AssertionError):
+            warnings.warn(
+                "Batched sampling not implemented for this posterior. "
+                "Falling back to non-batched sampling.",
+                stacklevel=2,
+            )
+            use_batched_sampling = False
+
+    if not use_batched_sampling:
         # We need a function with extra training step for new x for VIPosterior.
         def sample_fun(
             posterior: NeuralPosterior, sample_shape: Shape, x: Tensor, seed: int = 0
@@ -57,13 +74,13 @@ def sample_fun(
         if isinstance(posterior, (VIPosterior, MCMCPosterior)):
             warnings.warn(
                 "Using non-batched sampling. Depending on the number of different xs "
-                f"( {batch_size}) and the number of parallel workers {num_workers}, "
-                "this might be slow.",
+                f"( {num_xs}) and the number of parallel workers {num_workers}, "
+                "this might take a lot of time.",
                 stacklevel=2,
             )
 
         # Run in parallel with progress bar.
-        seeds = torch.randint(0, 2**32, (batch_size,))
+        seeds = torch.randint(0, 2**32, (num_xs,))
         outputs = list(
             tqdm(
                 Parallel(return_as="generator", n_jobs=num_workers)(
                     delayed(sample_fun)(posterior, sample_shape, x=x, seed=s)
                     for x, s in zip(xs, seeds)
                 ),
                 disable=not show_progress_bar,
                 total=len(xs),
-                desc=f"Sampling {batch_size} times {sample_shape} posterior samples.",
+                desc=f"Sampling {num_xs} times {sample_shape} posterior samples.",
             )
         )  # (batch_size, num_samples, dim_parameters)
         # Transpose to shape convention: (sample_shape, batch_size, dim_parameters)
         posterior_samples = torch.stack(
             [o.unsqueeze(-1) if o.ndim == 1 else o for o in outputs]  # type: ignore
         ).permute(1, 0, 2)
 
     assert posterior_samples.shape[:2] == sample_shape + (
-        batch_size,
-    ), f"""Expected batched posterior samples of shape {
-        sample_shape + (batch_size,)
-    } got {posterior_samples.shape[:2]}."""
+        num_xs,
+    ), f"""Expected batched posterior samples of shape {sample_shape + (num_xs,)} got {
+        posterior_samples.shape[:2]
+    }."""
     return posterior_samples

From 4c69fa455b3108da688b31d801d039572d5c2bfb Mon Sep 17 00:00:00 2001
From: Jan
Date: Tue, 10 Dec 2024 18:06:59 +0100
Subject: [PATCH 3/3] undo batch_size option

---
 sbi/diagnostics/sbc.py         |  6 +-----
 sbi/diagnostics/tarp.py        |  4 ----
 sbi/utils/diagnostics_utils.py | 20 ++++----------------
 3 files changed, 5 insertions(+), 25 deletions(-)

diff --git a/sbi/diagnostics/sbc.py b/sbi/diagnostics/sbc.py
index 59db21450..ba01fb0a8 100644
--- a/sbi/diagnostics/sbc.py
+++ b/sbi/diagnostics/sbc.py
@@ -2,7 +2,7 @@
 # under the Apache License Version 2.0, see
 
 import warnings
-from typing import Callable, Dict, List, Optional, Union
+from typing import Callable, Dict, List, Union
 
 import torch
 from scipy.stats import kstest, uniform
@@ -26,7 +26,6 @@ def run_sbc(
     num_workers: int = 1,
     show_progress_bar: bool = True,
     use_batched_sampling: bool = True,
-    batch_size: Optional[int] = None,
     **kwargs,
 ):
     """Run simulation-based calibration (SBC) (parallelized across sbc runs.
@@ -50,8 +49,6 @@ def run_sbc(
             `num_sbc_samples` inferences.
         show_progress_bar: whether to display a progress over sbc runs.
         use_batched_sampling: whether to use batched sampling for posterior samples.
-        batch_size: batch size for batched sampling. Useful for batched sampling with
-            large batches of xs for avoiding memory overflow.
 
     Returns:
         ranks: ranks of the ground truth parameters under the inferred
@@ -92,7 +89,6 @@ def run_sbc(
         num_workers,
         show_progress_bar,
         use_batched_sampling=use_batched_sampling,
-        batch_size=batch_size,
     )
 
     # take a random draw from each posterior to get data averaged posterior samples.
diff --git a/sbi/diagnostics/tarp.py b/sbi/diagnostics/tarp.py
index d9226fb8f..44ff114f3 100644
--- a/sbi/diagnostics/tarp.py
+++ b/sbi/diagnostics/tarp.py
@@ -29,7 +29,6 @@ def run_tarp(
     num_bins: Optional[int] = 30,
     z_score_theta: bool = True,
     use_batched_sampling: bool = True,
-    batch_size: Optional[int] = None,
 ) -> Tuple[Tensor, Tensor]:
     """
     Estimates coverage of samples given true values thetas with the TARP method.
@@ -57,8 +56,6 @@ def run_tarp(
             If ``None``, then ``num_sims // 10`` bins are used.
         z_score_theta : whether to normalize parameters before coverage test.
         use_batched_sampling: whether to use batched sampling for posterior samples.
-        batch_size: batch size for batched sampling. Useful for batched sampling with
-            large batches of xs for avoiding memory overflow.
 
     Returns:
         ecp: Expected coverage probability (``ecp``), see equation 4 of the paper
@@ -73,7 +70,6 @@ def run_tarp(
         num_workers,
         show_progress_bar=show_progress_bar,
         use_batched_sampling=use_batched_sampling,
-        batch_size=batch_size,
     )
     assert posterior_samples.shape == (
         num_posterior_samples,
diff --git a/sbi/utils/diagnostics_utils.py b/sbi/utils/diagnostics_utils.py
index b3ad038c1..d53cd536c 100644
--- a/sbi/utils/diagnostics_utils.py
+++ b/sbi/utils/diagnostics_utils.py
@@ -1,5 +1,4 @@
 import warnings
-from typing import Optional
 
 import torch
 from joblib import Parallel, delayed
@@ -19,7 +18,6 @@ def get_posterior_samples_on_batch(
     num_workers: int = 1,
     show_progress_bar: bool = False,
     use_batched_sampling: bool = True,
-    batch_size: Optional[int] = None,
 ) -> Tensor:
     """Get posterior samples for a batch of xs.
 
@@ -30,17 +28,15 @@ def get_posterior_samples_on_batch(
         num_workers: number of workers to use for parallelization.
         show_progress_bars: whether to show progress bars.
         use_batched_sampling: whether to use batched sampling if possible.
-        batch_size: batch size for batched sampling. Useful for batched sampling with
-            large batches of xs for avoiding memory overflow.
+
     Returns:
         posterior_samples: of shape (num_samples, batch_size, dim_parameters).
     """
     num_xs = len(xs)
-    if batch_size is None:
-        batch_size = num_xs
 
     if use_batched_sampling:
         try:
-            # distribute the batch of xs into smaller batches
-            batched_xs = xs.split(batch_size)
-            posterior_samples = torch.cat(
-                [  # has shape (num_samples, num_xs, dim_parameters)
-                    posterior.sample_batched(
-                        sample_shape, x=xs_batch, show_progress_bars=show_progress_bar
-                    )
-                    for xs_batch in batched_xs
-                ],
-                dim=1,
+            # has shape (num_samples, num_xs, dim_parameters)
+            posterior_samples = posterior.sample_batched(
+                sample_shape, x=xs, show_progress_bars=show_progress_bar
             )
         except (NotImplementedError, AssertionError):
             warnings.warn(
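
Usage sketch (not part of the patch series): the snippet below shows how the `use_batched_sampling` flag exposed by PATCH 1/3 can be passed through `run_sbc` and `run_tarp`. Only the diagnostics keyword arguments come from the diffs above; the toy prior, simulator, training budget, and the `SNPE`/`append_simulations`/`build_posterior` training calls are assumptions about the surrounding `sbi` API, kept minimal for illustration.

# Hedged sketch: a tiny NPE posterior, then SBC and TARP with and without
# batched sampling. Sizes are deliberately small; accuracy is irrelevant here.
import torch

from sbi.diagnostics.sbc import run_sbc
from sbi.diagnostics.tarp import run_tarp
from sbi.inference import SNPE
from sbi.utils import BoxUniform

torch.manual_seed(0)

# Placeholder 2-d toy problem (assumption, not part of the patches).
prior = BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))

def simulator(theta):
    # Gaussian noise around theta; stands in for the user's simulator.
    return theta + 0.1 * torch.randn_like(theta)

theta_train = prior.sample((500,))
x_train = simulator(theta_train)
inference = SNPE(prior=prior)
density_estimator = inference.append_simulations(theta_train, x_train).train(
    max_num_epochs=5
)
posterior = inference.build_posterior(density_estimator)

# Held-out (theta, x) pairs for the diagnostics.
thetas = prior.sample((100,))
xs = simulator(thetas)

# Default path: batched sampling; falls back to per-x sampling (with the
# warning added in these patches) if the posterior cannot sample batched.
ranks, dap_samples = run_sbc(
    thetas, xs, posterior, num_posterior_samples=500, use_batched_sampling=True
)

# Explicitly request the non-batched, joblib-parallelized path, e.g. for
# MCMC or VI posteriors where batched sampling may be slow or unsupported.
ecp, alpha = run_tarp(
    thetas,
    xs,
    posterior,
    num_posterior_samples=500,
    num_workers=2,
    use_batched_sampling=False,
)

The design choice reflected in PATCH 3/3 is that chunking the batch of xs (the `batch_size` option of PATCH 2/3) was dropped again, so callers only toggle `use_batched_sampling`; memory pressure from large batches of xs has to be handled by the caller, e.g. by splitting xs before invoking the diagnostics.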