expose batched sampling option; error handling #1321

Merged: 3 commits, Dec 13, 2024
3 changes: 1 addition & 2 deletions sbi/diagnostics/sbc.py
@@ -48,8 +48,7 @@ def run_sbc(
         num_workers: number of CPU cores to use in parallel for running
             `num_sbc_samples` inferences.
         show_progress_bar: whether to display a progress over sbc runs.
-        use_batched_sampling: whether to use batched sampling for posterior
-            samples.
+        use_batched_sampling: whether to use batched sampling for posterior samples.

     Returns:
         ranks: ranks of the ground truth parameters under the inferred
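For reference, a minimal usage sketch of the keyword documented above (not part of the diff): the `thetas`, `xs`, and trained `posterior` variables are assumed to come from an earlier `sbi` workflow, and the sample count is illustrative.

from sbi.diagnostics import run_sbc

# thetas: (num_sbc_runs, dim_theta) parameters drawn from the prior,
# xs: (num_sbc_runs, dim_x) matching simulator outputs.
ranks, dap_samples = run_sbc(
    thetas,
    xs,
    posterior,
    num_posterior_samples=1000,
    use_batched_sampling=True,  # option exposed by this PR
)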
3 changes: 3 additions & 0 deletions sbi/diagnostics/tarp.py
@@ -28,6 +28,7 @@ def run_tarp(
     distance: Callable = l2,
     num_bins: Optional[int] = 30,
     z_score_theta: bool = True,
+    use_batched_sampling: bool = True,
 ) -> Tuple[Tensor, Tensor]:
     """
     Estimates coverage of samples given true values thetas with the TARP method.
@@ -54,6 +55,7 @@
         num_bins: number of bins to use for the credibility values.
             If ``None``, then ``num_sims // 10`` bins are used.
         z_score_theta : whether to normalize parameters before coverage test.
+        use_batched_sampling: whether to use batched sampling for posterior samples.

     Returns:
         ecp: Expected coverage probability (``ecp``), see equation 4 of the paper
@@ -67,6 +69,7 @@
         (num_posterior_samples,),
         num_workers,
         show_progress_bar=show_progress_bar,
+        use_batched_sampling=use_batched_sampling,
     )
     assert posterior_samples.shape == (
         num_posterior_samples,
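Likewise, a sketch of calling run_tarp with the new keyword (same assumed setup as the run_sbc sketch above; the names `ecp` and `alpha` for the two returned tensors are an assumption based on the docstring):

from sbi.diagnostics import run_tarp

# The keyword is forwarded to get_posterior_samples_on_batch,
# see sbi/utils/diagnostics_utils.py below.
ecp, alpha = run_tarp(
    thetas,
    xs,
    posterior,
    num_posterior_samples=1000,
    use_batched_sampling=True,
)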
44 changes: 30 additions & 14 deletions sbi/utils/diagnostics_utils.py
@@ -1,9 +1,12 @@
+import warnings

 import torch
 from joblib import Parallel, delayed
 from torch import Tensor
 from tqdm import tqdm

 from sbi.inference.posteriors.base_posterior import NeuralPosterior
+from sbi.inference.posteriors.mcmc_posterior import MCMCPosterior
+from sbi.inference.posteriors.vi_posterior import VIPosterior
 from sbi.sbi_types import Shape

@@ -29,18 +32,23 @@ def get_posterior_samples_on_batch(
     Returns:
         posterior_samples: of shape (num_samples, batch_size, dim_parameters).
     """
-    batch_size = len(xs)
+    num_xs = len(xs)

-    # Try using batched sampling when implemented.
-    try:
-        # has shape (num_samples, batch_size, dim_parameters)
-        if use_batched_sampling:
+    if use_batched_sampling:
+        try:
+            # has shape (num_samples, num_xs, dim_parameters)
             posterior_samples = posterior.sample_batched(
                 sample_shape, x=xs, show_progress_bars=show_progress_bar
             )
-        else:
-            raise NotImplementedError
-    except NotImplementedError:
+        except (NotImplementedError, AssertionError):
+            warnings.warn(
+                "Batched sampling not implemented for this posterior. "
+                "Falling back to non-batched sampling.",
+                stacklevel=2,
+            )
+            use_batched_sampling = False
+
+    if not use_batched_sampling:
         # We need a function with extra training step for new x for VIPosterior.
         def sample_fun(
             posterior: NeuralPosterior, sample_shape: Shape, x: Tensor, seed: int = 0
@@ -51,8 +59,16 @@ def sample_fun(
             torch.manual_seed(seed)
             return posterior.sample(sample_shape, x=x, show_progress_bars=False)

+        if isinstance(posterior, (VIPosterior, MCMCPosterior)):
+            warnings.warn(
+                "Using non-batched sampling. Depending on the number of different xs "
+                f"( {num_xs}) and the number of parallel workers {num_workers}, "
+                "this might take a lot of time.",
+                stacklevel=2,
+            )
+
         # Run in parallel with progress bar.
-        seeds = torch.randint(0, 2**32, (batch_size,))
+        seeds = torch.randint(0, 2**32, (num_xs,))
         outputs = list(
             tqdm(
                 Parallel(return_as="generator", n_jobs=num_workers)(
@@ -61,7 +77,7 @@ def sample_fun(
                 ),
                 disable=not show_progress_bar,
                 total=len(xs),
-                desc=f"Sampling {batch_size} times {sample_shape} posterior samples.",
+                desc=f"Sampling {num_xs} times {sample_shape} posterior samples.",
             )
         )  # (batch_size, num_samples, dim_parameters)
     # Transpose to shape convention: (sample_shape, batch_size, dim_parameters)
@@ -70,8 +86,8 @@
     ).permute(1, 0, 2)

     assert posterior_samples.shape[:2] == sample_shape + (
-        batch_size,
-    ), f"""Expected batched posterior samples of shape {
-        sample_shape + (batch_size,)
-    } got {posterior_samples.shape[:2]}."""
+        num_xs,
+    ), f"""Expected batched posterior samples of shape {sample_shape + (num_xs,)} got {
+        posterior_samples.shape[:2]
+    }."""
     return posterior_samples
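To summarize the control flow this PR introduces: batched sampling is attempted once, and a NotImplementedError or AssertionError from sample_batched triggers a warning plus a fall-back to per-x sampling. A standalone sketch of that pattern (hypothetical helper, not part of the PR; a plain loop stands in for the joblib parallelism, and sample_shape is assumed to be one-dimensional):

import warnings

import torch

def sample_with_fallback(posterior, sample_shape, xs):
    """Try batched sampling; fall back to a per-x loop if unsupported."""
    try:
        # One call over all conditioning values: (*sample_shape, num_xs, dim).
        return posterior.sample_batched(sample_shape, x=xs)
    except (NotImplementedError, AssertionError):
        warnings.warn(
            "Batched sampling not implemented for this posterior. "
            "Falling back to non-batched sampling.",
            stacklevel=2,
        )
        # Per-x sampling stacks to (num_xs, *sample_shape, dim); permute to
        # the batched layout (*sample_shape, num_xs, dim).
        samples = [posterior.sample(sample_shape, x=x.unsqueeze(0)) for x in xs]
        return torch.stack(samples).permute(1, 0, 2)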