From b2fe6363b9820db0e7d17dfb2216110fda3c2d94 Mon Sep 17 00:00:00 2001
From: Jan
Date: Thu, 12 Dec 2024 15:14:20 +0100
Subject: [PATCH] Add method for iid-batched conditioning.

- deprecate MNLE-based potential (can be NLE-based)
- adapt tests for conditioned MNLE.

---
 .../potentials/likelihood_based_potential.py | 111 +++++++++++++++++-
 sbi/inference/trainers/nle/mnle.py           |   6 +-
 sbi/utils/conditional_density_utils.py       |   2 +-
 sbi/utils/sbiutils.py                        |   4 +-
 tests/mnle_test.py                           |  78 ++++++++----
 5 files changed, 166 insertions(+), 35 deletions(-)

diff --git a/sbi/inference/potentials/likelihood_based_potential.py b/sbi/inference/potentials/likelihood_based_potential.py
index f382968cd..20a831899 100644
--- a/sbi/inference/potentials/likelihood_based_potential.py
+++ b/sbi/inference/potentials/likelihood_based_potential.py
@@ -1,7 +1,8 @@
 # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
 # under the Apache License Version 2.0, see
 
-from typing import Callable, Optional, Tuple
+import warnings
+from typing import Callable, List, Optional, Tuple
 
 import torch
 from torch import Tensor
@@ -115,6 +116,38 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
         )
         return log_likelihood_batches + self.prior.log_prob(theta)  # type: ignore
 
+    def condition_on(self, condition: Tensor, dims_to_sample: List[int]) -> Callable:
+        """Returns a potential conditioned on a subset of theta dimensions.
+
+        The condition is part of theta, but is assumed to correspond to a batch of
+        iid x_o, i.e., one condition per trial.
+
+        Args:
+            condition: The condition to fix.
+            dims_to_sample: The indices of the parameters to sample.
+
+        Returns:
+            A potential function conditioned on the condition.
+        """
+
+        def conditioned_potential(
+            theta: Tensor, x_o: Optional[Tensor] = None, track_gradients: bool = True
+        ) -> Tensor:
+            assert (
+                len(dims_to_sample) == theta.shape[1] - condition.shape[1]
+            ), "dims_to_sample must match the number of parameters to sample."
+            theta_without_condition = theta[:, dims_to_sample]
+
+            return _log_likelihood_with_iid_condition(
+                x=x_o if x_o is not None else self.x_o,
+                theta_without_condition=theta_without_condition,
+                condition=condition,
+                estimator=self.likelihood_estimator,
+                track_gradients=track_gradients,
+            )
+
+        return conditioned_potential
+
 
 def _log_likelihoods_over_trials(
     x: Tensor,
@@ -172,6 +205,67 @@
     return log_likelihood_trial_sum
 
 
+def _log_likelihood_with_iid_condition(
+    x: Tensor,
+    theta_without_condition: Tensor,
+    condition: Tensor,
+    estimator: ConditionalDensityEstimator,
+    track_gradients: bool = False,
+) -> Tensor:
+    """Return log likelihoods summed over iid trials of `x` with a matching batch of
+    conditions.
+
+    This function is different from `_log_likelihoods_over_trials` in that it moves the
+    iid batch dimension of `x` onto the batch dimension of `theta`. This is useful when
+    the likelihood estimator is conditioned on a batch of conditions that are iid with
+    the batch of `x`. It avoids evaluating the likelihood for every combination of
+    `x` and `condition`. Instead, it manually constructs a batch covering all
+    combinations of iid trials and theta batch entries and reshapes to sum over the
+    iid likelihoods.
+
+    Args:
+        x: Batch of iid data of shape `(iid_dim, *event_shape)`.
+        theta_without_condition: Batch of parameters of shape `(batch_dim, *event_shape)`.
+        condition: Batch of conditions of shape `(iid_dim, *condition_shape)`.
+        estimator: DensityEstimator.
+        track_gradients: Whether to track gradients.
+
+    Returns:
+        log_likelihood_trial_sum: log likelihood for each parameter, summed over all
+            batch entries (iid trials) in `x`.
+    """
+    assert (
+        condition.shape[0] == x.shape[0]
+    ), "Condition and iid x must have the same batch size."
+    num_trials = x.shape[0]
+    num_theta = theta_without_condition.shape[0]
+    x = reshape_to_sample_batch_event(
+        x, event_shape=x.shape[1:], leading_is_sample=True
+    )
+
+    # move the iid batch dimension onto the batch dimension of theta and repeat it there
+    x_expanded = x.reshape(1, num_trials, -1).repeat_interleave(num_theta, dim=1)
+    # for this to work we construct theta and condition to cover all combinations in the
+    # trial batch and the theta batch.
+    theta = torch.cat(
+        [
+            theta_without_condition.repeat(num_trials, 1),  # repeat ABAB
+            condition.repeat_interleave(num_theta, dim=0),  # repeat AABB
+        ],
+        dim=-1,
+    )
+
+    with torch.set_grad_enabled(track_gradients):
+        # Calculate likelihood in one batch. Returns (1, num_trials * theta_batch_size)
+        log_likelihood_trial_batch = estimator.log_prob(x_expanded, condition=theta)
+        # Reshape to (x-trials x parameters), sum over trial-log likelihoods.
+        log_likelihood_trial_sum = log_likelihood_trial_batch.reshape(
+            num_trials, num_theta
+        ).sum(0)
+
+    return log_likelihood_trial_sum
+
+
 def mixed_likelihood_estimator_based_potential(
     likelihood_estimator: MixedDensityEstimator,
     prior: Distribution,
@@ -192,6 +286,13 @@
         to unconstrained space.
     """
 
+    warnings.warn(
+        "This function is deprecated and will be removed in a future release. Use "
+        "`likelihood_estimator_based_potential` instead.",
+        DeprecationWarning,
+        stacklevel=2,
+    )
+
     device = str(next(likelihood_estimator.discrete_net.parameters()).device)
 
     potential_fn = MixedLikelihoodBasedPotential(
@@ -212,6 +313,13 @@ def __init__(
     ):
         super().__init__(likelihood_estimator, prior, x_o, device)
 
+        warnings.warn(
+            "This class is deprecated and will be removed in a future release. Use "
+            "`LikelihoodBasedPotential` instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
     def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
         prior_log_prob = self.prior.log_prob(theta)  # type: ignore
 
@@ -231,7 +339,6 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
         with torch.set_grad_enabled(track_gradients):
             # Call the specific log prob method of the mixed likelihood estimator as
             # this optimizes the evaluation of the discrete data part.
-            # TODO log_prob_iid
             log_likelihood_trial_batch = self.likelihood_estimator.log_prob(
                 input=x,
                 condition=theta.to(self.device),
diff --git a/sbi/inference/trainers/nle/mnle.py b/sbi/inference/trainers/nle/mnle.py
index d01ce1e91..83622eaea 100644
--- a/sbi/inference/trainers/nle/mnle.py
+++ b/sbi/inference/trainers/nle/mnle.py
@@ -7,7 +7,7 @@
 from torch.distributions import Distribution
 
 from sbi.inference.posteriors import MCMCPosterior, RejectionPosterior, VIPosterior
-from sbi.inference.potentials import mixed_likelihood_estimator_based_potential
+from sbi.inference.potentials import likelihood_estimator_based_potential
 from sbi.inference.trainers.nle.nle_base import LikelihoodEstimator
 from sbi.neural_nets.estimators import MixedDensityEstimator
 from sbi.sbi_types import TensorboardSummaryWriter, TorchModule
@@ -155,9 +155,7 @@ def build_posterior(
         (
             potential_fn,
             theta_transform,
-        ) = mixed_likelihood_estimator_based_potential(
-            likelihood_estimator=likelihood_estimator, prior=prior, x_o=None
-        )
+        ) = likelihood_estimator_based_potential(likelihood_estimator, prior, x_o=None)
 
         if sample_with == "mcmc":
             self._posterior = MCMCPosterior(
diff --git a/sbi/utils/conditional_density_utils.py b/sbi/utils/conditional_density_utils.py
index d6c73b7c9..829f5e1df 100644
--- a/sbi/utils/conditional_density_utils.py
+++ b/sbi/utils/conditional_density_utils.py
@@ -293,7 +293,7 @@ def __init__(
             masked outside of prior.
         """
         condition = torch.atleast_2d(condition)
-        if condition.shape[0] != 1:
+        if condition.shape[0] > 1:
             raise ValueError("Condition with batch size > 1 not supported.")
 
         self.potential_fn = potential_fn
diff --git a/sbi/utils/sbiutils.py b/sbi/utils/sbiutils.py
index fcb5953d9..fc01d4dbd 100644
--- a/sbi/utils/sbiutils.py
+++ b/sbi/utils/sbiutils.py
@@ -60,8 +60,8 @@ def warn_if_zscoring_changes_data(x: Tensor, duplicate_tolerance: float = 0.1) -
 
     if num_unique_z < num_unique * (1 - duplicate_tolerance):
         warnings.warn(
-            "Z-scoring these simulation outputs resulted in {num_unique_z} unique "
-            "datapoints. Before z-scoring, it had been {num_unique}. This can "
+            f"Z-scoring these simulation outputs resulted in {num_unique_z} unique "
+            f"datapoints. Before z-scoring, it had been {num_unique}. This can "
             "occur due to numerical inaccuracies when the data covers a large "
             "range of values. Consider either setting `z_score_x=False` (but "
             "beware that this can be problematic for training the NN) or exclude "
diff --git a/tests/mnle_test.py b/tests/mnle_test.py
index a95a2a6ac..b242f477f 100644
--- a/tests/mnle_test.py
+++ b/tests/mnle_test.py
@@ -1,29 +1,31 @@
 # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
 # under the Apache License Version 2.0, see
 
+from typing import Union
+
 import pytest
 import torch
 from pyro.distributions import InverseGamma
-from torch.distributions import Beta, Binomial, Categorical, Gamma
+from torch import Tensor
+from torch.distributions import Beta, Binomial, Distribution, Gamma
 
 from sbi.inference import MNLE, MCMCPosterior
 from sbi.inference.posteriors.rejection_posterior import RejectionPosterior
 from sbi.inference.posteriors.vi_posterior import VIPosterior
 from sbi.inference.potentials.base_potential import BasePotential
 from sbi.inference.potentials.likelihood_based_potential import (
-    MixedLikelihoodBasedPotential,
+    likelihood_estimator_based_potential,
 )
 from sbi.neural_nets import likelihood_nn
 from sbi.neural_nets.embedding_nets import FCEmbedding
 from sbi.utils import BoxUniform, mcmc_transform
-from sbi.utils.conditional_density_utils import ConditionedPotential
 from sbi.utils.torchutils import atleast_2d, process_device
 from sbi.utils.user_input_checks_utils import MultipleIndependent
 from tests.test_utils import check_c2st
 
 
 # toy simulator for mixed data
-def mixed_simulator(theta, stimulus_condition=2.0):
+def mixed_simulator(theta: Tensor, stimulus_condition: Union[Tensor, float] = 2.0):
     """Simulator for mixed data."""
     # Extract parameters
     beta, ps = theta[:, :1], theta[:, 1:]
@@ -190,11 +192,28 @@ def test_mnle_accuracy_with_different_samplers_and_trials(
 
 
 class BinomialGammaPotential(BasePotential):
-    def __init__(self, prior, x_o, concentration_scaling=1.0, device="cpu"):
+    """Binomial-Gamma potential for mixed data."""
+
+    def __init__(
+        self,
+        prior: Distribution,
+        x_o: Tensor,
+        concentration_scaling: Union[Tensor, float] = 1.0,
+        device="cpu",
+    ):
         super().__init__(prior, x_o, device)
+
+        # concentration_scaling needs to be a float or match the batch size
+        if isinstance(concentration_scaling, Tensor):
+            num_trials = x_o.shape[0]
+            assert concentration_scaling.shape[0] == num_trials
+
+            # Reshape to match convention (batch_size, num_trials, *event_shape)
+            concentration_scaling = concentration_scaling.reshape(1, num_trials, -1)
+
         self.concentration_scaling = concentration_scaling
 
-    def __call__(self, theta, track_gradients: bool = True):
+    def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
         theta = atleast_2d(theta)
 
         with torch.set_grad_enabled(track_gradients):
@@ -202,11 +221,12 @@ def __call__(self, theta, track_gradients: bool = True):
         return iid_ll + self.prior.log_prob(theta)
 
-    def iid_likelihood(self, theta):
+    def iid_likelihood(self, theta: Tensor) -> Tensor:
         batch_size = theta.shape[0]
         num_trials = self.x_o.shape[0]
         theta = theta.reshape(batch_size, 1, -1)
         beta, rho = theta[:, :, :1], theta[:, :, 1:]
+
         # vectorized
         logprob_choices = Binomial(probs=rho).log_prob(
             self.x_o[:, 1:].reshape(1, num_trials, -1)
         )
@@ -233,18 +253,22 @@ def test_mnle_with_experimental_conditions(mcmc_params_accurate: dict):
     categorical parameter is set to a fixed value (conditioned posterior), and the
     accuracy of the conditioned posterior is tested against the true posterior.
""" - num_simulations = 6000 - num_samples = 500 + num_simulations = 10000 + num_samples = 1000 - def sim_wrapper(theta): + def sim_wrapper( + theta_and_condition: Tensor, last_idx_parameters: int = 2 + ) -> Tensor: # simulate with experiment conditions - return mixed_simulator(theta[:, :2], theta[:, 2:] + 1) + theta = theta_and_condition[:, :last_idx_parameters] + condition = theta_and_condition[:, last_idx_parameters:] + return mixed_simulator(theta, condition) proposal = MultipleIndependent( [ Gamma(torch.tensor([1.0]), torch.tensor([0.5])), Beta(torch.tensor([2.0]), torch.tensor([2.0])), - Categorical(probs=torch.ones(1, 3)), + BoxUniform(torch.tensor([0.0]), torch.tensor([1.0])), ], validate_args=False, ) @@ -254,22 +278,27 @@ def sim_wrapper(theta): assert x.shape == (num_simulations, 2) num_trials = 10 - theta_o = proposal.sample((1,)) - theta_o[0, 2] = 2.0 # set condition to 2 as in original simulator. - x_o = sim_wrapper(theta_o.repeat(num_trials, 1)) + theta_and_condition = proposal.sample((num_trials,)) + # use only a single parameter (iid trials) + theta_o = theta_and_condition[:1, :2].repeat(num_trials, 1) + # but different conditions + condition_o = theta_and_condition[:, 2:] + theta_and_conditions_o = torch.cat((theta_o, condition_o), dim=1) + + x_o = sim_wrapper(theta_and_conditions_o) mcmc_kwargs = dict( method="slice_np_vectorized", init_strategy="proposal", **mcmc_params_accurate ) # MNLE - trainer = MNLE(proposal) - estimator = trainer.append_simulations(theta, x).train(training_batch_size=1000) - - potential_fn = MixedLikelihoodBasedPotential(estimator, proposal, x_o) + estimator_fun = likelihood_nn(model="mnle", z_score_x=None) + trainer = MNLE(proposal, estimator_fun) + estimator = trainer.append_simulations(theta, x).train() - conditioned_potential_fn = ConditionedPotential( - potential_fn, condition=theta_o, dims_to_sample=[0, 1] + potential_fn, _ = likelihood_estimator_based_potential(estimator, proposal, x_o) + conditioned_potential_fn = potential_fn.condition_on( + condition_o, dims_to_sample=[0, 1] ) # True posterior samples @@ -283,10 +312,7 @@ def sim_wrapper(theta): prior_transform = mcmc_transform(prior) true_posterior_samples = MCMCPosterior( BinomialGammaPotential( - prior, - atleast_2d(x_o), - concentration_scaling=float(theta_o[0, 2]) - + 1.0, # add one because the sim_wrapper adds one (see above) + prior, atleast_2d(x_o), concentration_scaling=condition_o ), theta_transform=prior_transform, proposal=prior, @@ -303,5 +329,5 @@ def sim_wrapper(theta): check_c2st( cond_samples, true_posterior_samples, - alg=f"MNLE trained with {num_simulations}", + alg=f"MNLE trained with {num_simulations} simulations", )