
Commit 54f9158

refactor: add batched option to utils, refactor sbc.

1 parent: ca2db69

File tree: 4 files changed, +58 −43 lines

sbi/diagnostics/sbc.py

Lines changed: 17 additions & 7 deletions

@@ -25,6 +25,7 @@ def run_sbc(
     reduce_fns: Union[str, Callable, List[Callable]] = "marginals",
     num_workers: int = 1,
     show_progress_bar: bool = True,
+    use_batched_sampling: bool = True,
     **kwargs,
 ):
     """Run simulation-based calibration (SBC) (parallelized across sbc runs).
@@ -47,6 +48,8 @@ def run_sbc(
         num_workers: number of CPU cores to use in parallel for running
             `num_sbc_samples` inferences.
         show_progress_bar: whether to display a progress bar over sbc runs.
+        use_batched_sampling: whether to use batched sampling for posterior
+            samples.

     Returns:
         ranks: ranks of the ground truth parameters under the inferred
@@ -81,13 +84,16 @@ def run_sbc(

     # Get posterior samples, batched or parallelized.
     posterior_samples = get_posterior_samples_on_batch(
-        xs, posterior, num_posterior_samples, num_workers, show_progress_bar
+        xs,
+        posterior,
+        (num_posterior_samples,),
+        num_workers,
+        show_progress_bar,
+        use_batched_sampling=use_batched_sampling,
     )
-    # for calibration methods it's handy to have len(xs) in the first dim.
-    posterior_samples = posterior_samples.transpose(0, 1)

     # take a random draw from each posterior to get data averaged posterior samples.
-    dap_samples = posterior_samples[:, 0, :]
+    dap_samples = posterior_samples[0, :, :]
     assert dap_samples.shape == (num_sbc_samples, thetas.shape[1]), "Wrong dap shape."

     ranks = _run_sbc(
@@ -126,8 +132,8 @@ def _run_sbc(

     ranks = torch.zeros((num_sbc_samples, len(reduce_fns)))
     # Iterate over all sbc samples and calculate ranks.
-    for sbc_idx, (ths, theta_i, x_i) in tqdm(
-        enumerate(zip(posterior_samples, thetas, xs)),
+    for sbc_idx, (true_theta, x_i) in tqdm(
+        enumerate(zip(thetas, xs)),
         total=num_sbc_samples,
         disable=not show_progress_bar,
         desc=f"Calculating ranks for {num_sbc_samples} sbc samples.",
@@ -139,8 +145,12 @@ def _run_sbc(

         # For each reduce_fn (e.g., per marginal for SBC)
         for dim_idx, reduce_fn in enumerate(reduce_fns):
+            # rank posterior samples against the true parameter, reduced to 1D.
             ranks[sbc_idx, dim_idx] = (
-                (reduce_fn(ths, x_i) < reduce_fn(theta_i.unsqueeze(0), x_i))
+                (
+                    reduce_fn(posterior_samples[:, sbc_idx, :], x_i)
+                    < reduce_fn(true_theta.unsqueeze(0), x_i)
+                )
                 .sum()
                 .item()
             )
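For callers, the new flag is a keyword argument with batched sampling enabled by default; passing `use_batched_sampling=False` forces the parallelized per-observation fallback. A minimal usage sketch (assuming a trained `posterior` and SBC draws `thetas`, `xs` prepared elsewhere; the two return values follow the docstring above):

    # thetas: (num_sbc_samples, dim_parameters), xs: (num_sbc_samples, dim_x)
    ranks, dap_samples = run_sbc(
        thetas,
        xs,
        posterior,
        num_posterior_samples=1000,
        use_batched_sampling=True,  # False -> joblib fallback with num_workers
    )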

sbi/diagnostics/tarp.py

Lines changed: 1 addition & 1 deletion

@@ -64,7 +64,7 @@ def run_tarp(
     posterior_samples = get_posterior_samples_on_batch(
         xs,
         posterior,
-        num_posterior_samples,
+        (num_posterior_samples,),
         num_workers,
         show_progress_bar=show_progress_bar,
     )
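The public `run_tarp` entry point is unchanged by this hunk; only the internal call now passes a one-element sample shape tuple instead of a bare int. A caller-side sketch (argument names taken from the sbi docs; treat the exact signature as an assumption):

    # ecp: expected coverage probability per credibility level alpha
    ecp, alpha = run_tarp(thetas, xs, posterior, num_posterior_samples=1000)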

sbi/utils/diagnostics_utils.py

Lines changed: 28 additions & 17 deletions

@@ -5,24 +5,26 @@

 from sbi.inference.posteriors.base_posterior import NeuralPosterior
 from sbi.inference.posteriors.vi_posterior import VIPosterior
+from sbi.sbi_types import Shape


 def get_posterior_samples_on_batch(
     xs: Tensor,
     posterior: NeuralPosterior,
-    num_samples: int,
+    sample_shape: Shape,
     num_workers: int = 1,
     show_progress_bar: bool = False,
+    use_batched_sampling: bool = True,
 ) -> Tensor:
     """Get posterior samples for a batch of xs.

     Args:
         xs: batch of observations.
         posterior: sbi posterior.
-        num_posterior_samples: number of samples to draw from the posterior in each sbc
-            run.
+        sample_shape: shape of the samples to draw from the posterior for each x.
         num_workers: number of workers to use for parallelization.
         show_progress_bar: whether to show progress bars.
+        use_batched_sampling: whether to use batched sampling if possible.

     Returns:
         posterior_samples: of shape (num_samples, batch_size, dim_parameters).
@@ -32,35 +34,44 @@ def get_posterior_samples_on_batch(
     # Try using batched sampling when implemented.
     try:
         # has shape (num_samples, batch_size, dim_parameters)
-        posterior_samples = posterior.sample_batched(
-            (num_samples,), xs, show_progress_bars=show_progress_bar
-        )
+        if use_batched_sampling:
+            posterior_samples = posterior.sample_batched(
+                sample_shape, x=xs, show_progress_bars=show_progress_bar
+            )
+        else:
+            raise NotImplementedError
     except NotImplementedError:
         # We need a function with extra training step for new x for VIPosterior.
-        def sample_fun(posterior: NeuralPosterior, sample_shape, x: Tensor) -> Tensor:
+        def sample_fun(
+            posterior: NeuralPosterior, sample_shape: Shape, x: Tensor, seed: int = 0
+        ) -> Tensor:
             if isinstance(posterior, VIPosterior):
                 posterior.set_default_x(x)
                 posterior.train()
+            torch.manual_seed(seed)
             return posterior.sample(sample_shape, x=x, show_progress_bars=False)

         # Run in parallel with progress bar.
+        seeds = torch.randint(0, 2**32, (batch_size,))
         outputs = list(
             tqdm(
                 Parallel(return_as="generator", n_jobs=num_workers)(
-                    delayed(sample_fun)(posterior, (num_samples,), x=x) for x in xs
+                    delayed(sample_fun)(posterior, sample_shape, x=x, seed=s)
+                    for x, s in zip(xs, seeds)
                 ),
                 disable=not show_progress_bar,
                 total=len(xs),
-                desc=f"Sampling {batch_size} times {num_samples} posterior samples.",
+                desc=f"Sampling {batch_size} times {sample_shape} posterior samples.",
             )
-        )
-        # Transpose to sample_batched shape convention:
-        posterior_samples = torch.stack(outputs).transpose(0, 1)  # type: ignore
+        )  # (batch_size, num_samples, dim_parameters)
+        # Transpose to shape convention: (sample_shape, batch_size, dim_parameters)
+        posterior_samples = torch.stack(
+            outputs  # type: ignore
+        ).permute(1, 0, 2)

-    assert posterior_samples.shape[:2] == (
-        num_samples,
+    assert posterior_samples.shape[:2] == sample_shape + (
         batch_size,
-    ), f"""Expected batched posterior samples of shape {(num_samples, batch_size)} got {
-        posterior_samples.shape[:2]
-    }."""
+    ), f"""Expected batched posterior samples of shape {
+        sample_shape + (batch_size,)
+    } got {posterior_samples.shape[:2]}."""
     return posterior_samples
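The stack-and-permute step in the fallback branch is easy to get backwards, so here is a self-contained sketch of just the shape bookkeeping with toy tensors (no sbi objects; note that `permute(1, 0, 2)` assumes a one-element `sample_shape`):

    import torch

    batch_size, num_samples, dim_parameters = 4, 3, 2
    sample_shape = (num_samples,)

    # the per-x fallback yields one (num_samples, dim_parameters) tensor per x
    outputs = [torch.randn(num_samples, dim_parameters) for _ in range(batch_size)]

    # stack -> (batch_size, num_samples, dim_parameters)
    # permute -> (num_samples, batch_size, dim_parameters)
    posterior_samples = torch.stack(outputs).permute(1, 0, 2)

    # matches the assert in get_posterior_samples_on_batch
    assert posterior_samples.shape[:2] == sample_shape + (batch_size,)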

tests/sbc_test.py

Lines changed: 12 additions & 18 deletions

@@ -13,7 +13,9 @@
 from sbi.analysis import sbc_rank_plot
 from sbi.diagnostics import check_sbc, get_nltp, run_sbc
 from sbi.inference import SNLE, SNPE, simulate_for_sbi
-from sbi.simulators import linear_gaussian
+from sbi.simulators.linear_gaussian import (
+    linear_gaussian,
+)
 from sbi.utils import BoxUniform, MultipleIndependent
 from sbi.utils.user_input_checks import process_prior, process_simulator
 from tests.test_utils import PosteriorPotential, TractablePosterior
@@ -91,29 +93,21 @@ def simulator(theta):


 @pytest.mark.slow
-@pytest.mark.parametrize(
-    "density_estimator",
-    [
-        pytest.param(
-            "mdn",
-            marks=pytest.mark.xfail(
-                reason="MDN batched sampling results in miscalibrated posteriors",
-                strict=True,
-            ),
-        ),
-        "maf",
-    ],
-)
+@pytest.mark.parametrize("density_estimator", ["mdn", "maf"])
 @pytest.mark.parametrize("cov_method", ("sbc", "coverage"))
 def test_consistent_sbc_results(density_estimator, cov_method):
     """Test consistent SBC results on well-trained NPE."""

-    num_dim = 3
-    prior = BoxUniform(low=-2 * torch.ones(num_dim), high=2 * torch.ones(num_dim))
+    num_dim = 2
+
+    likelihood_shift = -1.0 * ones(num_dim)
+    likelihood_cov = 0.3 * eye(num_dim)
+    prior_mean = zeros(num_dim)
+    prior_cov = eye(num_dim)
+    prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov)

     def simulator(theta):
-        # linear gaussian
-        return theta + 1.0 + torch.randn_like(theta) * 0.1
+        return linear_gaussian(theta, likelihood_shift, likelihood_cov)

     num_simulations = 2000
     num_posterior_samples = 1000
