diff --git a/sbi/inference/potentials/likelihood_based_potential.py b/sbi/inference/potentials/likelihood_based_potential.py
index f382968cd..20a831899 100644
--- a/sbi/inference/potentials/likelihood_based_potential.py
+++ b/sbi/inference/potentials/likelihood_based_potential.py
@@ -1,7 +1,8 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see
-from typing import Callable, Optional, Tuple
+import warnings
+from typing import Callable, List, Optional, Tuple
import torch
from torch import Tensor
@@ -115,6 +116,38 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
)
return log_likelihood_batches + self.prior.log_prob(theta) # type: ignore
+ def condition_on(self, condition: Tensor, dims_to_sample: List[int]) -> Callable:
+ """Returns a potential conditioned on a subset of theta dimensions.
+
+        The condition is a subset of theta's dimensions, and it is assumed to hold
+        one entry per iid trial in `x_o`.
+
+ Args:
+ condition: The condition to fix.
+ dims_to_sample: The indices of the parameters to sample.
+
+ Returns:
+            A potential function in which the conditioned dimensions of theta are
+            fixed to `condition`.
+ """
+
+ def conditioned_potential(
+ theta: Tensor, x_o: Optional[Tensor] = None, track_gradients: bool = True
+ ) -> Tensor:
+ assert (
+ len(dims_to_sample) == theta.shape[1] - condition.shape[1]
+ ), "dims_to_sample must match the number of parameters to sample."
+ theta_without_condition = theta[:, dims_to_sample]
+
+ return _log_likelihood_with_iid_condition(
+                x=x_o if x_o is not None else self.x_o,  # fall back to stored x_o.
+ theta_without_condition=theta_without_condition,
+ condition=condition,
+ estimator=self.likelihood_estimator,
+ track_gradients=track_gradients,
+ )
+
+ return conditioned_potential
+
def _log_likelihoods_over_trials(
x: Tensor,
@@ -172,6 +205,67 @@ def _log_likelihoods_over_trials(
return log_likelihood_trial_sum
+def _log_likelihood_with_iid_condition(
+ x: Tensor,
+ theta_without_condition: Tensor,
+ condition: Tensor,
+ estimator: ConditionalDensityEstimator,
+ track_gradients: bool = False,
+) -> Tensor:
+ """Return log likelihoods summed over iid trials of `x` with a matching batch of
+ conditions.
+
+ This function is different from `_log_likelihoods_over_trials` in that it moves the
+ iid batch dimension of `x` onto the batch dimension of `theta`. This is useful when
+ the likelihood estimator is conditioned on a batch of conditions that are iid with
+    the batch of `x`. It avoids evaluating the likelihood for every combination of
+    `x` and `condition`. Instead, it manually constructs a batch that covers all
+    combinations of iid trials and theta batch entries and reshapes the result to
+    sum over the iid likelihoods.
+
+ Args:
+ x: Batch of iid data of shape `(iid_dim, *event_shape)`.
+        theta_without_condition: Batch of parameters of shape
+            `(batch_dim, *event_shape)`.
+        condition: Batch of conditions of shape `(iid_dim, *condition_shape)`.
+        estimator: ConditionalDensityEstimator for evaluating the likelihood.
+ track_gradients: Whether to track gradients.
+
+ Returns:
+ log_likelihood_trial_sum: log likelihood for each parameter, summed over all
+ batch entries (iid trials) in `x`.
+ """
+ assert (
+ condition.shape[0] == x.shape[0]
+ ), "Condition and iid x must have the same batch size."
+ num_trials = x.shape[0]
+ num_theta = theta_without_condition.shape[0]
+ x = reshape_to_sample_batch_event(
+ x, event_shape=x.shape[1:], leading_is_sample=True
+ )
+
+ # move the iid batch dimension onto the batch dimension of theta and repeat it there
+ x_expanded = x.reshape(1, num_trials, -1).repeat_interleave(num_theta, dim=1)
+    # For this to work, theta and condition are tiled so that together they cover
+    # all combinations of trial batch and theta batch entries.
+ theta = torch.cat(
+ [
+ theta_without_condition.repeat(num_trials, 1), # repeat ABAB
+ condition.repeat_interleave(num_theta, dim=0), # repeat AABB
+ ],
+ dim=-1,
+ )
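+    # E.g., for trials/conditions (A, B) and thetas (t1, t2), the batch rows are
+    # (t1, A), (t2, A), (t1, B), (t2, B), matching x_expanded's (A, A, B, B) order.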
+
+ with torch.set_grad_enabled(track_gradients):
+        # Evaluate the likelihood in one batch; shape (1, num_trials * num_theta).
+        log_likelihood_trial_batch = estimator.log_prob(x_expanded, condition=theta)
+        # Reshape to (num_trials, num_theta) and sum the log likelihoods over trials.
+ log_likelihood_trial_sum = log_likelihood_trial_batch.reshape(
+ num_trials, num_theta
+ ).sum(0)
+
+ return log_likelihood_trial_sum
+
+
def mixed_likelihood_estimator_based_potential(
likelihood_estimator: MixedDensityEstimator,
prior: Distribution,
@@ -192,6 +286,13 @@ def mixed_likelihood_estimator_based_potential(
to unconstrained space.
"""
+ warnings.warn(
+ "This function is deprecated and will be removed in a future release. Use "
+ "`likelihood_estimator_based_potential` instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
device = str(next(likelihood_estimator.discrete_net.parameters()).device)
potential_fn = MixedLikelihoodBasedPotential(
@@ -212,6 +313,13 @@ def __init__(
):
super().__init__(likelihood_estimator, prior, x_o, device)
+ warnings.warn(
+ "This function is deprecated and will be removed in a future release. Use "
+ "`LikelihoodBasedPotential` instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
prior_log_prob = self.prior.log_prob(theta) # type: ignore
@@ -231,7 +339,6 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
with torch.set_grad_enabled(track_gradients):
# Call the specific log prob method of the mixed likelihood estimator as
# this optimizes the evaluation of the discrete data part.
- # TODO log_prob_iid
log_likelihood_trial_batch = self.likelihood_estimator.log_prob(
input=x,
condition=theta.to(self.device),
diff --git a/sbi/inference/trainers/nle/mnle.py b/sbi/inference/trainers/nle/mnle.py
index d01ce1e91..83622eaea 100644
--- a/sbi/inference/trainers/nle/mnle.py
+++ b/sbi/inference/trainers/nle/mnle.py
@@ -7,7 +7,7 @@
from torch.distributions import Distribution
from sbi.inference.posteriors import MCMCPosterior, RejectionPosterior, VIPosterior
-from sbi.inference.potentials import mixed_likelihood_estimator_based_potential
+from sbi.inference.potentials import likelihood_estimator_based_potential
from sbi.inference.trainers.nle.nle_base import LikelihoodEstimator
from sbi.neural_nets.estimators import MixedDensityEstimator
from sbi.sbi_types import TensorboardSummaryWriter, TorchModule
@@ -155,9 +155,7 @@ def build_posterior(
(
potential_fn,
theta_transform,
- ) = mixed_likelihood_estimator_based_potential(
- likelihood_estimator=likelihood_estimator, prior=prior, x_o=None
- )
+ ) = likelihood_estimator_based_potential(likelihood_estimator, prior, x_o=None)
if sample_with == "mcmc":
self._posterior = MCMCPosterior(
diff --git a/sbi/utils/conditional_density_utils.py b/sbi/utils/conditional_density_utils.py
index d6c73b7c9..829f5e1df 100644
--- a/sbi/utils/conditional_density_utils.py
+++ b/sbi/utils/conditional_density_utils.py
@@ -293,7 +293,7 @@ def __init__(
masked outside of prior.
"""
condition = torch.atleast_2d(condition)
- if condition.shape[0] != 1:
+ if condition.shape[0] > 1:
raise ValueError("Condition with batch size > 1 not supported.")
self.potential_fn = potential_fn
diff --git a/sbi/utils/sbiutils.py b/sbi/utils/sbiutils.py
index fcb5953d9..fc01d4dbd 100644
--- a/sbi/utils/sbiutils.py
+++ b/sbi/utils/sbiutils.py
@@ -60,8 +60,8 @@ def warn_if_zscoring_changes_data(x: Tensor, duplicate_tolerance: float = 0.1) -
if num_unique_z < num_unique * (1 - duplicate_tolerance):
warnings.warn(
- "Z-scoring these simulation outputs resulted in {num_unique_z} unique "
- "datapoints. Before z-scoring, it had been {num_unique}. This can "
+ f"Z-scoring these simulation outputs resulted in {num_unique_z} unique "
+ f"datapoints. Before z-scoring, it had been {num_unique}. This can "
"occur due to numerical inaccuracies when the data covers a large "
"range of values. Consider either setting `z_score_x=False` (but "
"beware that this can be problematic for training the NN) or exclude "
diff --git a/tests/mnle_test.py b/tests/mnle_test.py
index a95a2a6ac..b242f477f 100644
--- a/tests/mnle_test.py
+++ b/tests/mnle_test.py
@@ -1,29 +1,31 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see
+from typing import Union
+
import pytest
import torch
from pyro.distributions import InverseGamma
-from torch.distributions import Beta, Binomial, Categorical, Gamma
+from torch import Tensor
+from torch.distributions import Beta, Binomial, Distribution, Gamma
from sbi.inference import MNLE, MCMCPosterior
from sbi.inference.posteriors.rejection_posterior import RejectionPosterior
from sbi.inference.posteriors.vi_posterior import VIPosterior
from sbi.inference.potentials.base_potential import BasePotential
from sbi.inference.potentials.likelihood_based_potential import (
- MixedLikelihoodBasedPotential,
+ likelihood_estimator_based_potential,
)
from sbi.neural_nets import likelihood_nn
from sbi.neural_nets.embedding_nets import FCEmbedding
from sbi.utils import BoxUniform, mcmc_transform
-from sbi.utils.conditional_density_utils import ConditionedPotential
from sbi.utils.torchutils import atleast_2d, process_device
from sbi.utils.user_input_checks_utils import MultipleIndependent
from tests.test_utils import check_c2st
# toy simulator for mixed data
-def mixed_simulator(theta, stimulus_condition=2.0):
+def mixed_simulator(theta: Tensor, stimulus_condition: Union[Tensor, float] = 2.0):
"""Simulator for mixed data."""
# Extract parameters
beta, ps = theta[:, :1], theta[:, 1:]
@@ -190,11 +192,28 @@ def test_mnle_accuracy_with_different_samplers_and_trials(
class BinomialGammaPotential(BasePotential):
- def __init__(self, prior, x_o, concentration_scaling=1.0, device="cpu"):
+ """Binomial-Gamma potential for mixed data."""
+
+ def __init__(
+ self,
+ prior: Distribution,
+ x_o: Tensor,
+ concentration_scaling: Union[Tensor, float] = 1.0,
+ device="cpu",
+ ):
super().__init__(prior, x_o, device)
+
+        # concentration_scaling must be a float or have one entry per trial in x_o.
+ if isinstance(concentration_scaling, Tensor):
+ num_trials = x_o.shape[0]
+ assert concentration_scaling.shape[0] == num_trials
+
+ # Reshape to match convention (batch_size, num_trials, *event_shape)
+ concentration_scaling = concentration_scaling.reshape(1, num_trials, -1)
+
self.concentration_scaling = concentration_scaling
- def __call__(self, theta, track_gradients: bool = True):
+ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
theta = atleast_2d(theta)
with torch.set_grad_enabled(track_gradients):
@@ -202,11 +221,12 @@ def __call__(self, theta, track_gradients: bool = True):
return iid_ll + self.prior.log_prob(theta)
- def iid_likelihood(self, theta):
+ def iid_likelihood(self, theta: Tensor) -> Tensor:
batch_size = theta.shape[0]
num_trials = self.x_o.shape[0]
theta = theta.reshape(batch_size, 1, -1)
beta, rho = theta[:, :, :1], theta[:, :, 1:]
+
# vectorized
logprob_choices = Binomial(probs=rho).log_prob(
self.x_o[:, 1:].reshape(1, num_trials, -1)
@@ -233,18 +253,22 @@ def test_mnle_with_experimental_conditions(mcmc_params_accurate: dict):
categorical parameter is set to a fixed value (conditioned posterior), and the
accuracy of the conditioned posterior is tested against the true posterior.
"""
- num_simulations = 6000
- num_samples = 500
+ num_simulations = 10000
+ num_samples = 1000
- def sim_wrapper(theta):
+ def sim_wrapper(
+ theta_and_condition: Tensor, last_idx_parameters: int = 2
+ ) -> Tensor:
# simulate with experiment conditions
- return mixed_simulator(theta[:, :2], theta[:, 2:] + 1)
+ theta = theta_and_condition[:, :last_idx_parameters]
+ condition = theta_and_condition[:, last_idx_parameters:]
+ return mixed_simulator(theta, condition)
proposal = MultipleIndependent(
[
Gamma(torch.tensor([1.0]), torch.tensor([0.5])),
Beta(torch.tensor([2.0]), torch.tensor([2.0])),
- Categorical(probs=torch.ones(1, 3)),
+ BoxUniform(torch.tensor([0.0]), torch.tensor([1.0])),
],
validate_args=False,
)
@@ -254,22 +278,27 @@ def sim_wrapper(theta):
assert x.shape == (num_simulations, 2)
num_trials = 10
- theta_o = proposal.sample((1,))
- theta_o[0, 2] = 2.0 # set condition to 2 as in original simulator.
- x_o = sim_wrapper(theta_o.repeat(num_trials, 1))
+ theta_and_condition = proposal.sample((num_trials,))
+    # use the same parameter set for all iid trials,
+ theta_o = theta_and_condition[:1, :2].repeat(num_trials, 1)
+ # but different conditions
+ condition_o = theta_and_condition[:, 2:]
+ theta_and_conditions_o = torch.cat((theta_o, condition_o), dim=1)
+
+ x_o = sim_wrapper(theta_and_conditions_o)
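+    # x_o now holds num_trials iid observations, all simulated with the same
+    # parameters but each under its own experimental condition.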
mcmc_kwargs = dict(
method="slice_np_vectorized", init_strategy="proposal", **mcmc_params_accurate
)
# MNLE
- trainer = MNLE(proposal)
- estimator = trainer.append_simulations(theta, x).train(training_batch_size=1000)
-
- potential_fn = MixedLikelihoodBasedPotential(estimator, proposal, x_o)
+ estimator_fun = likelihood_nn(model="mnle", z_score_x=None)
+ trainer = MNLE(proposal, estimator_fun)
+ estimator = trainer.append_simulations(theta, x).train()
- conditioned_potential_fn = ConditionedPotential(
- potential_fn, condition=theta_o, dims_to_sample=[0, 1]
+ potential_fn, _ = likelihood_estimator_based_potential(estimator, proposal, x_o)
+ conditioned_potential_fn = potential_fn.condition_on(
+ condition_o, dims_to_sample=[0, 1]
)
# True posterior samples
@@ -283,10 +312,7 @@ def sim_wrapper(theta):
prior_transform = mcmc_transform(prior)
true_posterior_samples = MCMCPosterior(
BinomialGammaPotential(
- prior,
- atleast_2d(x_o),
- concentration_scaling=float(theta_o[0, 2])
- + 1.0, # add one because the sim_wrapper adds one (see above)
+ prior, atleast_2d(x_o), concentration_scaling=condition_o
),
theta_transform=prior_transform,
proposal=prior,
@@ -303,5 +329,5 @@ def sim_wrapper(theta):
check_c2st(
cond_samples,
true_posterior_samples,
- alg=f"MNLE trained with {num_simulations}",
+ alg=f"MNLE trained with {num_simulations} simulations",
)