Skip to content

Commit

Permalink
importance sampling posterior (#1183)
Browse files Browse the repository at this point in the history
* importance sampling posterior

* Tests and improved importance_posterior

* pyright fix as suggested by Guy
  • Loading branch information
manuelgloeckler authored Jun 21, 2024
1 parent 07e3995 commit 6f61662
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 7 deletions.
17 changes: 13 additions & 4 deletions sbi/inference/posteriors/importance_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def sample(
self,
sample_shape: Shape = torch.Size(),
x: Optional[Tensor] = None,
method: Optional[str] = None,
oversampling_factor: int = 32,
max_sampling_batch_size: int = 10_000,
sample_with: Optional[str] = None,
Expand All @@ -164,14 +165,22 @@ def sample(
"""Return samples from the approximate posterior distribution.
Args:
sample_shape: _description_
x: _description_
sample_shape: Shape of samples that are drawn from posterior.
x: Observed data.
method: Either of [`sir`|`importance`]. This sets the behavior of the
`.sample()` method. With `sir`, approximate posterior samples are
generated with sampling importance resampling (SIR). With
`importance`, the `.sample()` method returns a tuple of samples and
corresponding importance weights.
oversampling_factor: Number of proposed samples from which only one is
selected based on its importance weight.
max_sampling_batch_size: The batch size of samples being drawn from the
proposal at every iteration.
show_progress_bars: Whether to show a progressbar during sampling.
"""

method = self.method if method is None else method

if sample_with is not None:
raise ValueError(
f"You set `sample_with={sample_with}`. As of sbi v0.18.0, setting "
Expand All @@ -181,14 +190,14 @@ def sample(

self.potential_fn.set_x(self._x_else_default_x(x))

if self.method == "sir":
if method == "sir":
return self._sir_sample(
sample_shape,
oversampling_factor=oversampling_factor,
max_sampling_batch_size=max_sampling_batch_size,
show_progress_bars=show_progress_bars,
)
elif self.method == "importance":
elif method == "importance":
return self._importance_sample(sample_shape)
else:
raise NameError
Expand Down
13 changes: 12 additions & 1 deletion sbi/inference/snle/snle_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from sbi.inference.base import NeuralInference
from sbi.inference.posteriors import MCMCPosterior, RejectionPosterior, VIPosterior
from sbi.inference.posteriors.importance_posterior import ImportanceSamplingPosterior
from sbi.inference.potentials import likelihood_estimator_based_potential
from sbi.neural_nets import ConditionalDensityEstimator, likelihood_nn
from sbi.neural_nets.density_estimators.shape_handling import (
Expand Down Expand Up @@ -270,7 +271,10 @@ def build_posterior(
mcmc_parameters: Optional[Dict[str, Any]] = None,
vi_parameters: Optional[Dict[str, Any]] = None,
rejection_sampling_parameters: Optional[Dict[str, Any]] = None,
) -> Union[MCMCPosterior, RejectionPosterior, VIPosterior]:
importance_sampling_parameters: Optional[Dict[str, Any]] = None,
) -> Union[
MCMCPosterior, RejectionPosterior, VIPosterior, ImportanceSamplingPosterior
]:
r"""Build posterior from the neural density estimator.
SNLE trains a neural network to approximate the likelihood $p(x|\theta)$. The
Expand Down Expand Up @@ -350,6 +354,13 @@ def build_posterior(
device=device,
**vi_parameters or {},
)
elif sample_with == "importance":
self._posterior = ImportanceSamplingPosterior(
potential_fn=potential_fn,
proposal=prior,
device=device,
**importance_sampling_parameters or {},
)
else:
raise NotImplementedError

Expand Down
17 changes: 16 additions & 1 deletion sbi/inference/snpe/snpe_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
VIPosterior,
)
from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.inference.posteriors.importance_posterior import ImportanceSamplingPosterior
from sbi.inference.potentials import posterior_estimator_based_potential
from sbi.neural_nets import ConditionalDensityEstimator, posterior_nn
from sbi.neural_nets.density_estimators.shape_handling import (
Expand Down Expand Up @@ -441,7 +442,14 @@ def build_posterior(
mcmc_parameters: Optional[Dict[str, Any]] = None,
vi_parameters: Optional[Dict[str, Any]] = None,
rejection_sampling_parameters: Optional[Dict[str, Any]] = None,
) -> Union[MCMCPosterior, RejectionPosterior, VIPosterior, DirectPosterior]:
importance_sampling_parameters: Optional[Dict[str, Any]] = None,
) -> Union[
MCMCPosterior,
RejectionPosterior,
VIPosterior,
DirectPosterior,
ImportanceSamplingPosterior,
]:
r"""Build posterior from the neural density estimator.
For SNPE, the posterior distribution that is returned here implements the
Expand Down Expand Up @@ -540,6 +548,13 @@ def build_posterior(
device=device,
**vi_parameters or {},
)
elif sample_with == "importance":
self._posterior = ImportanceSamplingPosterior(
potential_fn=potential_fn,
proposal=prior,
device=device,
**importance_sampling_parameters or {},
)
else:
raise NotImplementedError

Expand Down
13 changes: 12 additions & 1 deletion sbi/inference/snre/snre_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from sbi.inference.base import NeuralInference
from sbi.inference.posteriors import MCMCPosterior, RejectionPosterior, VIPosterior
from sbi.inference.posteriors.importance_posterior import ImportanceSamplingPosterior
from sbi.inference.potentials import ratio_estimator_based_potential
from sbi.neural_nets import classifier_nn
from sbi.utils import (
Expand Down Expand Up @@ -322,7 +323,10 @@ def build_posterior(
mcmc_parameters: Optional[Dict[str, Any]] = None,
vi_parameters: Optional[Dict[str, Any]] = None,
rejection_sampling_parameters: Optional[Dict[str, Any]] = None,
) -> Union[MCMCPosterior, RejectionPosterior, VIPosterior]:
importance_sampling_parameters: Optional[Dict[str, Any]] = None,
) -> Union[
MCMCPosterior, RejectionPosterior, VIPosterior, ImportanceSamplingPosterior
]:
r"""Build posterior from the neural density estimator.
SNRE trains a neural network to approximate likelihood ratios. The
Expand Down Expand Up @@ -405,6 +409,13 @@ def build_posterior(
device=device,
**vi_parameters or {},
)
elif sample_with == "importance":
self._posterior = ImportanceSamplingPosterior(
potential_fn=potential_fn,
proposal=prior,
device=device,
**importance_sampling_parameters or {},
)
else:
raise NotImplementedError

Expand Down
27 changes: 27 additions & 0 deletions tests/posterior_nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,33 @@ def test_log_prob_with_different_x(snpe_method: type, x_o_batch_dim: bool):
_ = posterior.log_prob(samples)


@pytest.mark.parametrize(
"snplre_method", [SNPE_A, SNPE_C, SNLE_A, SNRE_A, SNRE_B, SNRE_C]
)
def test_importance_posterior_sample_log_prob(snplre_method: type):
num_dim = 2

prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim))
simulator = diagonal_linear_gaussian

inference = snplre_method(prior=prior)
theta, x = simulate_for_sbi(simulator, prior, 1000)
_ = inference.append_simulations(theta, x).train(max_num_epochs=3)

posterior = inference.build_posterior(sample_with="importance")

x_o = ones(num_dim)
samples = posterior.sample((10,), x=x_o)
samples2, weights = posterior.sample((10,), x=x_o, method="importance")
assert samples.shape == (10, num_dim), "Sample shape of sample is wrong"
assert samples2.shape == (10, num_dim), "Sample of sample_with_weights shape wrong"
assert weights.shape == (10,), "Weights shape wrong"

log_prob = posterior.log_prob(samples, x=x_o)

assert log_prob.shape == (10,), "logprob shape wrong"


@pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C])
@pytest.mark.parametrize("x_o_batch_dim", (0, 1, 2))
def test_batched_sample_log_prob_with_different_x(
Expand Down

0 comments on commit 6f61662

Please sign in to comment.