refactor: sbc and tarp with batched sampling (#1196)
* refactor: sbc

* refactor: tarp and tests.

* refactor: diagnostics tutorial with tarp.

* use batched sampling when possible

* refactor sbc test

* refactor tarp tests

* move diagnostic_utils to utils; refactoring

* remove seeding of reference generation

* fix renaming
janfb authored Aug 2, 2024
1 parent 0bdb9c5 commit 2fd89a8
Showing 7 changed files with 645 additions and 726 deletions.
1 change: 1 addition & 0 deletions sbi/diagnostics/__init__.py
@@ -1 +1,2 @@
from sbi.diagnostics.sbc import check_sbc, get_nltp, run_sbc
from sbi.diagnostics.tarp import check_tarp, run_tarp
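A minimal usage sketch of the diagnostics API exported here (not part of the diff). It assumes a trained sbi `posterior` plus prior samples `thetas` and matching simulations `xs`; `prior` and `simulator` are placeholders, and the `check_sbc` argument order is an assumption, since its signature is not shown in this commit:

from sbi.diagnostics import check_sbc, run_sbc

# Placeholder setup: `prior`, `simulator`, and a trained `posterior` are assumed to exist.
num_sbc_runs, num_posterior_samples = 200, 1000
thetas = prior.sample((num_sbc_runs,))  # ground-truth parameters from the prior
xs = simulator(thetas)                  # observations simulated from thetas

# One rank vector and one data averaged posterior sample per sbc run.
ranks, dap_samples = run_sbc(
    thetas, xs, posterior, num_posterior_samples=num_posterior_samples
)

# Summary statistics (argument order assumed): e.g., KS p-values of the rank
# distributions and c2st accuracies against the uniform/prior references.
check_stats = check_sbc(ranks, thetas, dap_samples, num_posterior_samples)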
180 changes: 64 additions & 116 deletions sbi/diagnostics/sbc.py
@@ -2,10 +2,9 @@
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

import warnings
from typing import Callable, Dict, List, Sequence, Tuple, Union
from typing import Callable, Dict, List, Union

import torch
from joblib import Parallel, delayed
from scipy.stats import kstest, uniform
from torch import Tensor, ones, zeros
from torch.distributions import Uniform
@@ -14,7 +13,7 @@
from sbi.inference import DirectPosterior
from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.inference.posteriors.vi_posterior import VIPosterior
from sbi.simulators.simutils import tqdm_joblib
from sbi.utils.diagnostics_utils import get_posterior_samples_on_batch
from sbi.utils.metrics import c2st


@@ -25,33 +24,32 @@ def run_sbc(
num_posterior_samples: int = 1000,
reduce_fns: Union[str, Callable, List[Callable]] = "marginals",
num_workers: int = 1,
sbc_batch_size: int = 1,
show_progress_bar: bool = True,
) -> Tuple[Tensor, Tensor]:
**kwargs,
):
"""Run simulation-based calibration (SBC) (parallelized across sbc runs).
Returns sbc ranks and samples from the data averaged posterior, one for each
sbc run.
SBC is implemented as proposed in Talts et al., "Validating Bayesian Inference
Algorithms with Simulation-Based Calibration", https://arxiv.org/abs/1804.06788.
Note: This function implements two versions of coverage diagnostics:
- setting reduce_fns = "marginals" performs SBC as proposed in Talts et
al., https://arxiv.org/abs/1804.06788.
- setting reduce_fns = posterior.log_prob performs sample-based expected
coverage as proposed in Deistler et al., https://arxiv.org/abs/2210.04815.
Args:
thetas: ground-truth parameters for sbc, simulated from the prior.
xs: observed data for sbc, simulated from thetas.
posterior: a posterior obtained from sbi.
num_posterior_samples: number of approximate posterior samples used for ranking.
posterior: a posterior obtained from sbi.
num_posterior_samples: number of approximate posterior samples used for ranking.
reduce_fns: Function used to reduce the parameter space into 1D.
Simulation-based calibration can be recovered by setting this to the string
`marginals`. Sample-based expected coverage can be recovered by setting it
to `posterior.log_prob` (as a Callable).
num_workers: number of CPU cores to use in parallel for running num_sbc_samples
inferences.
sbc_batch_size: batch size for workers.
Simulation-based calibration can be recovered by setting this to the
string `marginals`. Sample-based expected coverage can be recovered
by setting it to `posterior.log_prob` (as a Callable).
num_workers: number of CPU cores to use in parallel for running
`num_sbc_samples` inferences.
show_progress_bar: whether to display a progress bar over sbc runs.
Returns:
ranks: ranks of the ground truth parameters under the inferred posterior.
ranks: ranks of the ground truth parameters under the inferred posterior.
dap_samples: samples from the data averaged posterior.
"""
num_sbc_samples = thetas.shape[0]
@@ -73,93 +71,45 @@ def run_sbc(
thetas.shape[0] == xs.shape[0]
), "Unequal number of parameters and observations."

thetas_batches = torch.split(thetas, sbc_batch_size, dim=0)
xs_batches = torch.split(xs, sbc_batch_size, dim=0)

if num_workers != 1:
# Parallelize the sequence of batches across workers.
# We use the solution proposed here: https://stackoverflow.com/a/61689175
# to update the pbar only after the workers finished a task.
with tqdm_joblib(
tqdm(
thetas_batches,
disable=not show_progress_bar,
desc=f"""Running {num_sbc_samples} sbc runs in {len(thetas_batches)}
batches.""",
total=len(thetas_batches),
)
) as _:
sbc_outputs: Sequence[Tuple[Tensor, Tensor]]
sbc_outputs = Parallel(n_jobs=num_workers)( # pyright: ignore[reportAssignmentType]
delayed(sbc_on_batch)(
thetas_batch, xs_batch, posterior, num_posterior_samples
)
for thetas_batch, xs_batch in zip(thetas_batches, xs_batches)
)
else:
pbar = tqdm(
total=num_sbc_samples,
disable=not show_progress_bar,
desc=f"Running {num_sbc_samples} sbc samples.",
if "sbc_batch_size" in kwargs:
warnings.warn(
"""`sbc_batch_size` is deprecated and will be removed in future versions.
Use `num_workers` instead.""",
DeprecationWarning,
stacklevel=2,
)

with pbar:
sbc_outputs = []
for thetas_batch, xs_batch in zip(thetas_batches, xs_batches):
sbc_outputs.append(
sbc_on_batch(
thetas_batch,
xs_batch,
posterior,
num_posterior_samples,
reduce_fns,
)
)
pbar.update(sbc_batch_size)

# Aggregate results.
ranks = []
dap_samples = []
for out in sbc_outputs:
ranks.append(out[0])
dap_samples.append(out[1])

ranks = torch.cat(ranks)
dap_samples = torch.cat(dap_samples)
# Get posterior samples, batched or parallelized.
posterior_samples = get_posterior_samples_on_batch(
xs, posterior, num_posterior_samples, num_workers, show_progress_bar
)
# For calibration methods it's handy to have len(xs) in the first dim.
posterior_samples = posterior_samples.transpose(0, 1)

# Take a random draw from each posterior to get data averaged posterior samples.
dap_samples = posterior_samples[:, 0, :]
assert dap_samples.shape == (num_sbc_samples, thetas.shape[1]), "Wrong dap shape."

ranks = _run_sbc(
thetas, xs, posterior_samples, posterior, reduce_fns, show_progress_bar
)

return ranks, dap_samples
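For reference, a hedged sketch of the two `reduce_fns` modes accepted by the refactored `run_sbc` (only arguments visible in the signature above are used; `thetas`, `xs`, and `posterior` are placeholders for prior samples, simulations, and a trained sbi posterior):

# Classic SBC (Talts et al. 2018): one rank per parameter dimension,
# i.e. ranks.shape == (num_sbc_samples, thetas.shape[1]).
ranks, dap_samples = run_sbc(
    thetas, xs, posterior, num_posterior_samples=1000, reduce_fns="marginals"
)

# Sample-based expected coverage (Deistler et al. 2022): reduce via the posterior
# log-prob, giving a single rank column, ranks.shape == (num_sbc_samples, 1).
coverage_ranks, _ = run_sbc(
    thetas, xs, posterior, num_posterior_samples=1000, reduce_fns=posterior.log_prob
)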


def sbc_on_batch(
def _run_sbc(
thetas: Tensor,
xs: Tensor,
posterior_samples: Tensor,
posterior: NeuralPosterior,
num_posterior_samples: int,
reduce_fns: Union[str, Callable, List[Callable]],
) -> Tuple[Tensor, Tensor]:
"""Return SBC results for a batch of SBC parameters and data from prior.
Args:
thetas: ground truth parameters.
xs: corresponding observations.
posterior: sbi posterior.
num_posterior_samples: number of samples to draw from the posterior in each sbc
run.
reduce_fns: Function that is used to reduce the parameter space into 1D.
Simulation-based calibration can be recovered by setting this to the string
`marginals`. Sample-based expected coverage can be recovered by setting it
to `posterior.log_prob` (as a Callable).
Returns
ranks: ranks of true parameters vs. posterior samples under the specified RV,
for each posterior dimension.
log_prob_thetas: log prob of true parameters under the approximate posterior.
Note that this is interpretable only for normalized log probs, i.e., when
using (S)NPE.
dap_samples: samples from the data averaged posterior for the current batch,
i.e., a single sample from each approximate posterior.
"""
reduce_fns: Union[str, Callable, List[Callable]] = "marginals",
show_progress_bar: bool = True,
) -> Tensor:
"""Calculate ranks for SBC or expected coverage."""
num_sbc_samples = thetas.shape[0]

# construct reduce functions for SBC or expected coverage
# For SBC, we simply take the marginals for each parameter dimension.
if isinstance(reduce_fns, str):
assert reduce_fns == "marginals", (
"`reduce_fn` must either be the string `marginals` or a Callable or a List "
@@ -169,35 +119,33 @@ def sbc_on_batch(
eval(f"lambda theta, x: theta[:, {i}]") for i in range(thetas.shape[1])
]

# For a Callable (e.g., expected coverage) we put it into a list for unified
# handling below.
if isinstance(reduce_fns, Callable):
reduce_fns = [reduce_fns]

dap_samples = torch.zeros_like(thetas)
ranks = torch.zeros((thetas.shape[0], len(reduce_fns)))

for idx in range(thetas.shape[0]):
# unsqueeze for potential higher-dimensional data.
xo = xs[idx].unsqueeze(0)
# VI posterior needs to be trained on the current xo.
ranks = torch.zeros((num_sbc_samples, len(reduce_fns)))
# Iterate over all sbc samples and calculate ranks.
for sbc_idx, (ths, theta_i, x_i) in tqdm(
enumerate(zip(posterior_samples, thetas, xs)),
total=num_sbc_samples,
disable=not show_progress_bar,
desc=f"Calculating ranks for {num_sbc_samples} sbc samples.",
):
# For VIPosteriors, we need to train on each x.
if isinstance(posterior, VIPosterior):
posterior.set_default_x(xo)
posterior.train()

# Draw posterior samples and save one for the data average posterior.
ths = posterior.sample((num_posterior_samples,), x=xo, show_progress_bars=False)
posterior.set_default_x(x_i)
posterior.train(show_progress_bar=False)

# Save one random sample for data average posterior (dap).
dap_samples[idx] = ths[0]

# rank for each posterior dimension as in Talts et al. section 4.1.
for i, reduce_fn in enumerate(reduce_fns):
ranks[idx, i] = (
(reduce_fn(ths, xo) < reduce_fn(thetas[idx].unsqueeze(0), xo))
# For each reduce_fn (e.g., per marginal for SBC)
for dim_idx, reduce_fn in enumerate(reduce_fns):
ranks[sbc_idx, dim_idx] = (
(reduce_fn(ths, x_i) < reduce_fn(theta_i.unsqueeze(0), x_i))
.sum()
.item()
)

return ranks, dap_samples
return ranks
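The ranking rule in the loop above can be illustrated in isolation. The snippet below is a toy sketch (random tensors, no sbi objects) of the marginals case, where the rank of the true parameter is the number of posterior samples falling below it in each dimension:

import torch

num_posterior_samples, dim_theta = 1000, 2
posterior_samples = torch.randn(num_posterior_samples, dim_theta)  # stand-in posterior samples
theta_true = torch.randn(1, dim_theta)                             # stand-in ground-truth parameter

# Count, per dimension, how many posterior samples lie below the true value.
marginal_ranks = (posterior_samples < theta_true).sum(dim=0)
# For a well-calibrated posterior, each rank is uniform on {0, ..., num_posterior_samples}.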


def get_nltp(thetas: Tensor, xs: Tensor, posterior: NeuralPosterior) -> Tensor: