Skip to content

Commit

Permalink
expose batched sampling option; error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Dec 3, 2024
1 parent 3bd8aa9 commit 9db9e9e
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 3 deletions.
3 changes: 1 addition & 2 deletions sbi/diagnostics/sbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ def run_sbc(
num_workers: number of CPU cores to use in parallel for running
`num_sbc_samples` inferences.
show_progress_bar: whether to display a progress bar over SBC runs.
use_batched_sampling: whether to use batched sampling for posterior
samples.
use_batched_sampling: whether to use batched sampling for posterior samples.
Returns:
ranks: ranks of the ground truth parameters under the inferred
Expand Down
3 changes: 3 additions & 0 deletions sbi/diagnostics/tarp.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def run_tarp(
distance: Callable = l2,
num_bins: Optional[int] = 30,
z_score_theta: bool = True,
use_batched_sampling: bool = True,
) -> Tuple[Tensor, Tensor]:
"""
Estimates coverage of samples given true values thetas with the TARP method.
Expand All @@ -54,6 +55,7 @@ def run_tarp(
num_bins: number of bins to use for the credibility values.
If ``None``, then ``num_sims // 10`` bins are used.
z_score_theta: whether to normalize parameters before the coverage test.
use_batched_sampling: whether to use batched sampling for posterior samples.
Returns:
ecp: Expected coverage probability (``ecp``), see equation 4 of the paper
Expand All @@ -67,6 +69,7 @@ def run_tarp(
(num_posterior_samples,),
num_workers,
show_progress_bar=show_progress_bar,
use_batched_sampling=use_batched_sampling,
)
assert posterior_samples.shape == (
num_posterior_samples,
Expand Down
13 changes: 12 additions & 1 deletion sbi/utils/diagnostics_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import warnings

import torch
from joblib import Parallel, delayed
from torch import Tensor
from tqdm import tqdm

from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.inference.posteriors.mcmc_posterior import MCMCPosterior
from sbi.inference.posteriors.vi_posterior import VIPosterior
from sbi.sbi_types import Shape

Expand Down Expand Up @@ -40,7 +43,7 @@ def get_posterior_samples_on_batch(
)
else:
raise NotImplementedError
except NotImplementedError:
except (NotImplementedError, AssertionError):
# We need a function with extra training step for new x for VIPosterior.
def sample_fun(
posterior: NeuralPosterior, sample_shape: Shape, x: Tensor, seed: int = 0
Expand All @@ -51,6 +54,14 @@ def sample_fun(
torch.manual_seed(seed)
return posterior.sample(sample_shape, x=x, show_progress_bars=False)

if isinstance(posterior, (VIPosterior, MCMCPosterior)):
warnings.warn(
"Using non-batched sampling. Depending on the number of different xs "
f"( {batch_size}) and the number of parallel workers {num_workers}, "
"this might be slow.",
stacklevel=2,
)

# Run in parallel with progress bar.
seeds = torch.randint(0, 2**32, (batch_size,))
outputs = list(
Expand Down

0 comments on commit 9db9e9e

Please sign in to comment.