
Commit

undo batch_size option
janfb committed Dec 10, 2024
1 parent 6dd4a22 commit 4c69fa4
Showing 3 changed files with 5 additions and 25 deletions.
6 changes: 1 addition & 5 deletions sbi/diagnostics/sbc.py
@@ -2,7 +2,7 @@
 # under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

 import warnings
-from typing import Callable, Dict, List, Optional, Union
+from typing import Callable, Dict, List, Union

 import torch
 from scipy.stats import kstest, uniform
@@ -26,7 +26,6 @@ def run_sbc(
     num_workers: int = 1,
     show_progress_bar: bool = True,
     use_batched_sampling: bool = True,
-    batch_size: Optional[int] = None,
     **kwargs,
 ):
     """Run simulation-based calibration (SBC) (parallelized across sbc runs).
@@ -50,8 +49,6 @@ def run_sbc(
             `num_sbc_samples` inferences.
         show_progress_bar: whether to display a progress over sbc runs.
         use_batched_sampling: whether to use batched sampling for posterior samples.
-        batch_size: batch size for batched sampling. Useful for batched sampling with
-            large batches of xs for avoiding memory overflow.

     Returns:
         ranks: ranks of the ground truth parameters under the inferred
@@ -92,7 +89,6 @@ def run_sbc(
         num_workers,
         show_progress_bar,
         use_batched_sampling=use_batched_sampling,
-        batch_size=batch_size,
     )

     # take a random draw from each posterior to get data averaged posterior samples.
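
Note: after this commit, run_sbc no longer accepts a batch_size argument; batching behaviour is controlled only by use_batched_sampling. A minimal sketch of the post-commit call, assuming a trained `posterior` plus `prior` and `simulator` objects that are not part of this diff (names are illustrative; the two return values follow the docstring and the "data averaged posterior samples" comment above):

# Sketch only: `prior`, `simulator`, and a trained `posterior` are assumed to exist.
from sbi.diagnostics.sbc import run_sbc

num_sbc_runs = 200
thetas = prior.sample((num_sbc_runs,))  # ground-truth parameters
xs = simulator(thetas)                  # corresponding observations

# Batching is now controlled only by `use_batched_sampling`; there is no `batch_size`.
ranks, dap_samples = run_sbc(
    thetas,
    xs,
    posterior,
    num_posterior_samples=1000,
    use_batched_sampling=True,
)
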
4 changes: 0 additions & 4 deletions sbi/diagnostics/tarp.py
@@ -29,7 +29,6 @@ def run_tarp(
     num_bins: Optional[int] = 30,
     z_score_theta: bool = True,
     use_batched_sampling: bool = True,
-    batch_size: Optional[int] = None,
 ) -> Tuple[Tensor, Tensor]:
     """
     Estimates coverage of samples given true values thetas with the TARP method.
@@ -57,8 +56,6 @@
             If ``None``, then ``num_sims // 10`` bins are used.
         z_score_theta : whether to normalize parameters before coverage test.
         use_batched_sampling: whether to use batched sampling for posterior samples.
-        batch_size: batch size for batched sampling. Useful for batched sampling with
-            large batches of xs for avoiding memory overflow.

     Returns:
         ecp: Expected coverage probability (``ecp``), see equation 4 of the paper
@@ -73,7 +70,6 @@
         num_workers,
         show_progress_bar=show_progress_bar,
         use_batched_sampling=use_batched_sampling,
-        batch_size=batch_size,
     )
     assert posterior_samples.shape == (
         num_posterior_samples,
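
Similarly, a hedged sketch of run_tarp after this commit. `thetas`, `xs`, and `posterior` are assumed to exist as in the run_sbc sketch above; the keyword names follow the signature shown in this diff, and the second return value is taken to be the grid of credibility levels:

# Sketch only: assumes the same `thetas`, `xs`, and trained `posterior` as above.
from sbi.diagnostics.tarp import run_tarp

ecp, alpha = run_tarp(
    thetas,
    xs,
    posterior,
    num_posterior_samples=1000,
    num_bins=30,
    z_score_theta=True,
    use_batched_sampling=True,  # `batch_size` is gone here as well
)
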
20 changes: 4 additions & 16 deletions sbi/utils/diagnostics_utils.py
@@ -1,5 +1,4 @@
 import warnings
-from typing import Optional

 import torch
 from joblib import Parallel, delayed
@@ -19,7 +18,6 @@ def get_posterior_samples_on_batch(
     num_workers: int = 1,
     show_progress_bar: bool = False,
     use_batched_sampling: bool = True,
-    batch_size: Optional[int] = None,
 ) -> Tensor:
     """Get posterior samples for a batch of xs.
@@ -30,27 +28,17 @@
         num_workers: number of workers to use for parallelization.
         show_progress_bars: whether to show progress bars.
         use_batched_sampling: whether to use batched sampling if possible.
-        batch_size: batch size for batched sampling. Useful for batched sampling with
-            large batches of xs for avoiding memory overflow.

     Returns:
         posterior_samples: of shape (num_samples, batch_size, dim_parameters).
     """
     num_xs = len(xs)
-    if batch_size is None:
-        batch_size = num_xs

     if use_batched_sampling:
         try:
-            # distribute the batch of xs into smaller batches
-            batched_xs = xs.split(batch_size)
-            posterior_samples = torch.cat(
-                [ # has shape (num_samples, num_xs, dim_parameters)
-                    posterior.sample_batched(
-                        sample_shape, x=xs_batch, show_progress_bars=show_progress_bar
-                    )
-                    for xs_batch in batched_xs
-                ],
-                dim=1,
+            # has shape (num_samples, num_xs, dim_parameters)
+            posterior_samples = posterior.sample_batched(
+                sample_shape, x=xs, show_progress_bars=show_progress_bar
             )
         except (NotImplementedError, AssertionError):
             warnings.warn(
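
If memory does become a problem without the batch_size option, the removed splitting logic can be reproduced on the caller side. A minimal sketch, assuming only a posterior object exposing sample_batched as used in this diff; `sample_in_chunks` and `chunk_size` are illustrative names, not part of the library:

import torch

def sample_in_chunks(posterior, sample_shape, xs, chunk_size):
    # Split the conditioning xs into chunks, sample each chunk with the
    # batched sampler, and concatenate along the batch dimension (dim=1),
    # mirroring the code removed in this commit.
    return torch.cat(
        [
            posterior.sample_batched(sample_shape, x=xs_chunk)
            for xs_chunk in xs.split(chunk_size)
        ],
        dim=1,  # resulting shape: (num_samples, num_xs, dim_parameters)
    )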
