Skip to content

Commit

Permalink
expose batched sampling option in diagnostics (#1321)
Browse files Browse the repository at this point in the history
* expose batched sampling option; error handling

* further improvements

* undo batch_size option
  • Loading branch information
janfb authored Dec 13, 2024
1 parent 06890eb commit d3f22b5
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 16 deletions.
3 changes: 1 addition & 2 deletions sbi/diagnostics/sbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ def run_sbc(
num_workers: number of CPU cores to use in parallel for running
`num_sbc_samples` inferences.
show_progress_bar: whether to display a progress over sbc runs.
use_batched_sampling: whether to use batched sampling for posterior
samples.
use_batched_sampling: whether to use batched sampling for posterior samples.
Returns:
ranks: ranks of the ground truth parameters under the inferred
Expand Down
3 changes: 3 additions & 0 deletions sbi/diagnostics/tarp.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def run_tarp(
distance: Callable = l2,
num_bins: Optional[int] = 30,
z_score_theta: bool = True,
use_batched_sampling: bool = True,
) -> Tuple[Tensor, Tensor]:
"""
Estimates coverage of samples given true values thetas with the TARP method.
Expand All @@ -54,6 +55,7 @@ def run_tarp(
num_bins: number of bins to use for the credibility values.
If ``None``, then ``num_sims // 10`` bins are used.
z_score_theta : whether to normalize parameters before coverage test.
use_batched_sampling: whether to use batched sampling for posterior samples.
Returns:
ecp: Expected coverage probability (``ecp``), see equation 4 of the paper
Expand All @@ -67,6 +69,7 @@ def run_tarp(
(num_posterior_samples,),
num_workers,
show_progress_bar=show_progress_bar,
use_batched_sampling=use_batched_sampling,
)
assert posterior_samples.shape == (
num_posterior_samples,
Expand Down
44 changes: 30 additions & 14 deletions sbi/utils/diagnostics_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import warnings

import torch
from joblib import Parallel, delayed
from torch import Tensor
from tqdm import tqdm

from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.inference.posteriors.mcmc_posterior import MCMCPosterior
from sbi.inference.posteriors.vi_posterior import VIPosterior
from sbi.sbi_types import Shape

Expand All @@ -29,18 +32,23 @@ def get_posterior_samples_on_batch(
Returns:
posterior_samples: of shape (num_samples, batch_size, dim_parameters).
"""
batch_size = len(xs)
num_xs = len(xs)

# Try using batched sampling when implemented.
try:
# has shape (num_samples, batch_size, dim_parameters)
if use_batched_sampling:
if use_batched_sampling:
try:
# has shape (num_samples, num_xs, dim_parameters)
posterior_samples = posterior.sample_batched(
sample_shape, x=xs, show_progress_bars=show_progress_bar
)
else:
raise NotImplementedError
except NotImplementedError:
except (NotImplementedError, AssertionError):
warnings.warn(
"Batched sampling not implemented for this posterior. "
"Falling back to non-batched sampling.",
stacklevel=2,
)
use_batched_sampling = False

if not use_batched_sampling:
# We need a function with extra training step for new x for VIPosterior.
def sample_fun(
posterior: NeuralPosterior, sample_shape: Shape, x: Tensor, seed: int = 0
Expand All @@ -51,8 +59,16 @@ def sample_fun(
torch.manual_seed(seed)
return posterior.sample(sample_shape, x=x, show_progress_bars=False)

if isinstance(posterior, (VIPosterior, MCMCPosterior)):
warnings.warn(
"Using non-batched sampling. Depending on the number of different xs "
f"( {num_xs}) and the number of parallel workers {num_workers}, "
"this might take a lot of time.",
stacklevel=2,
)

# Run in parallel with progress bar.
seeds = torch.randint(0, 2**32, (batch_size,))
seeds = torch.randint(0, 2**32, (num_xs,))
outputs = list(
tqdm(
Parallel(return_as="generator", n_jobs=num_workers)(
Expand All @@ -61,7 +77,7 @@ def sample_fun(
),
disable=not show_progress_bar,
total=len(xs),
desc=f"Sampling {batch_size} times {sample_shape} posterior samples.",
desc=f"Sampling {num_xs} times {sample_shape} posterior samples.",
)
) # (batch_size, num_samples, dim_parameters)
# Transpose to shape convention: (sample_shape, batch_size, dim_parameters)
Expand All @@ -70,8 +86,8 @@ def sample_fun(
).permute(1, 0, 2)

assert posterior_samples.shape[:2] == sample_shape + (
batch_size,
), f"""Expected batched posterior samples of shape {
sample_shape + (batch_size,)
} got {posterior_samples.shape[:2]}."""
num_xs,
), f"""Expected batched posterior samples of shape {sample_shape + (num_xs,)} got {
posterior_samples.shape[:2]
}."""
return posterior_samples

0 comments on commit d3f22b5

Please sign in to comment.