Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NLE with multiple iid conditions #1331

Merged
merged 6 commits into from
Dec 21, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 109 additions & 2 deletions sbi/inference/potentials/likelihood_based_potential.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from typing import Callable, Optional, Tuple
import warnings
from typing import Callable, List, Optional, Tuple

import torch
from torch import Tensor
Expand Down Expand Up @@ -115,6 +116,38 @@
)
return log_likelihood_batches + self.prior.log_prob(theta) # type: ignore

def condition_on(self, condition: Tensor, dims_to_sample: List[int]) -> Callable:
    """Return a potential in which a part of theta is fixed to `condition`.

    The condition is a part of theta, but is assumed to correspond to a batch of
    iid x_o, i.e., there is one condition entry per iid trial in x_o.

    Args:
        condition: The condition to fix. Its leading dimension must equal the
            number of iid trials in `x_o`; this is asserted inside
            `_log_likelihood_with_iid_condition`.
        dims_to_sample: The indices of the parameters to sample.

    Returns:
        A potential function conditioned on the condition. It accepts `theta`
        (a 2D batch holding only the parameters to sample), an optional `x_o`
        that overrides `self.x_o`, and `track_gradients`. It returns the result
        of `_log_likelihood_with_iid_condition`; no prior log-probability is
        added here.
    """

    def conditioned_potential(
        theta: Tensor, x_o: Optional[Tensor] = None, track_gradients: bool = True
    ) -> Tensor:
        assert (
            len(dims_to_sample) == theta.shape[1]
        ), "dims_to_sample must match the number of parameters to sample."
        # NOTE(review): `theta` has exactly `len(dims_to_sample)` columns, so this
        # indexing only succeeds when every entry of `dims_to_sample` is a valid
        # column index into that reduced tensor (e.g. [0, 1]) - confirm that
        # callers never pass indices referring to the full parameter vector.
        theta_without_condition = theta[:, dims_to_sample]

        return _log_likelihood_with_iid_condition(
            # Fall back to the x_o stored on the potential if none is passed.
            x=x_o if x_o is not None else self.x_o,
            theta_without_condition=theta_without_condition,
            condition=condition,
            estimator=self.likelihood_estimator,
            track_gradients=track_gradients,
        )

    return conditioned_potential

Check warning on line 149 in sbi/inference/potentials/likelihood_based_potential.py

View check run for this annotation

Codecov / codecov/patch

sbi/inference/potentials/likelihood_based_potential.py#L149

Added line #L149 was not covered by tests


def _log_likelihoods_over_trials(
x: Tensor,
Expand Down Expand Up @@ -172,6 +205,67 @@
return log_likelihood_trial_sum


def _log_likelihood_with_iid_condition(
    x: Tensor,
    theta_without_condition: Tensor,
    condition: Tensor,
    estimator: ConditionalDensityEstimator,
    track_gradients: bool = False,
) -> Tensor:
    """Return log likelihoods summed over iid trials of `x` with a matching batch of
    conditions.

    This function is different from `_log_likelihoods_over_trials` in that it moves
    the iid batch dimension of `x` onto the batch dimension of `theta`. This is
    useful when the likelihood estimator is conditioned on a batch of conditions
    that are iid with the batch of `x`. It avoids evaluating the likelihood for
    every combination of `x` and `condition`. Instead, it manually constructs a
    batch covering all combinations of iid trials and theta batch entries, where
    each trial is paired only with its own condition, and reshapes the result to
    sum over the iid likelihoods.

    Args:
        x: Batch of iid data of shape `(iid_dim, *event_shape)`.
        theta_without_condition: Batch of parameters `(batch_dim, *event_shape)`.
        condition: Batch of conditions of shape `(iid_dim, *condition_shape)`.
        estimator: DensityEstimator.
        track_gradients: Whether to track gradients.

    Returns:
        log_likelihood_trial_sum: log likelihood for each parameter, summed over
            all batch entries (iid trials) in `x`; shape `(batch_dim,)`.
    """
    assert (
        condition.shape[0] == x.shape[0]
    ), "Condition and iid x must have the same batch size."
    num_trials = x.shape[0]
    num_theta = theta_without_condition.shape[0]
    x = reshape_to_sample_batch_event(
        x, event_shape=x.shape[1:], leading_is_sample=True
    )

    # Move the iid batch dimension onto the batch dimension of theta and repeat it
    # there. Each trial is repeated `num_theta` times in a row (AABB order).
    x_expanded = x.reshape(1, num_trials, -1).repeat_interleave(num_theta, dim=1)
    # For this to work we construct theta and condition to cover all combinations
    # in the trial batch and the theta batch. The orderings below must line up
    # with `x_expanded`: the condition follows the trials (AABB), while the
    # parameter batch cycles within each trial (ABAB).
    theta = torch.cat(
        [
            theta_without_condition.repeat(num_trials, 1),  # repeat ABAB
            condition.repeat_interleave(num_theta, dim=0),  # repeat AABB
        ],
        dim=-1,
    )

    with torch.set_grad_enabled(track_gradients):
        # Calculate likelihood in one batch. Returns (1, num_trials * num_theta).
        log_likelihood_trial_batch = estimator.log_prob(x_expanded, condition=theta)
        # Reshape to (x-trials x parameters), sum over trial-log likelihoods.
        log_likelihood_trial_sum = log_likelihood_trial_batch.reshape(
            num_trials, num_theta
        ).sum(0)

    return log_likelihood_trial_sum

Check warning on line 266 in sbi/inference/potentials/likelihood_based_potential.py

View check run for this annotation

Codecov / codecov/patch

sbi/inference/potentials/likelihood_based_potential.py#L266

Added line #L266 was not covered by tests


def mixed_likelihood_estimator_based_potential(
likelihood_estimator: MixedDensityEstimator,
prior: Distribution,
Expand All @@ -192,6 +286,13 @@
to unconstrained space.
"""

warnings.warn(

Check warning on line 289 in sbi/inference/potentials/likelihood_based_potential.py

View check run for this annotation

Codecov / codecov/patch

sbi/inference/potentials/likelihood_based_potential.py#L289

Added line #L289 was not covered by tests
"This function is deprecated and will be removed in a future release. Use "
"`likelihood_estimator_based_potential` instead.",
DeprecationWarning,
stacklevel=2,
)

device = str(next(likelihood_estimator.discrete_net.parameters()).device)

potential_fn = MixedLikelihoodBasedPotential(
Expand All @@ -212,6 +313,13 @@
):
super().__init__(likelihood_estimator, prior, x_o, device)

warnings.warn(

Check warning on line 316 in sbi/inference/potentials/likelihood_based_potential.py

View check run for this annotation

Codecov / codecov/patch

sbi/inference/potentials/likelihood_based_potential.py#L316

Added line #L316 was not covered by tests
"This function is deprecated and will be removed in a future release. Use "
"`LikelihoodBasedPotential` instead.",
DeprecationWarning,
stacklevel=2,
)

def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
prior_log_prob = self.prior.log_prob(theta) # type: ignore

Expand All @@ -231,7 +339,6 @@
with torch.set_grad_enabled(track_gradients):
# Call the specific log prob method of the mixed likelihood estimator as
# this optimizes the evaluation of the discrete data part.
# TODO log_prob_iid
log_likelihood_trial_batch = self.likelihood_estimator.log_prob(
input=x,
condition=theta.to(self.device),
Expand Down
6 changes: 2 additions & 4 deletions sbi/inference/trainers/nle/mnle.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torch.distributions import Distribution

from sbi.inference.posteriors import MCMCPosterior, RejectionPosterior, VIPosterior
from sbi.inference.potentials import mixed_likelihood_estimator_based_potential
from sbi.inference.potentials import likelihood_estimator_based_potential
from sbi.inference.trainers.nle.nle_base import LikelihoodEstimator
from sbi.neural_nets.estimators import MixedDensityEstimator
from sbi.sbi_types import TensorboardSummaryWriter, TorchModule
Expand Down Expand Up @@ -155,9 +155,7 @@ def build_posterior(
(
potential_fn,
theta_transform,
) = mixed_likelihood_estimator_based_potential(
likelihood_estimator=likelihood_estimator, prior=prior, x_o=None
)
) = likelihood_estimator_based_potential(likelihood_estimator, prior, x_o=None)

if sample_with == "mcmc":
self._posterior = MCMCPosterior(
Expand Down
2 changes: 1 addition & 1 deletion sbi/utils/conditional_density_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def __init__(
masked outside of prior.
"""
condition = torch.atleast_2d(condition)
if condition.shape[0] != 1:
if condition.shape[0] > 1:
raise ValueError("Condition with batch size > 1 not supported.")

self.potential_fn = potential_fn
Expand Down
4 changes: 2 additions & 2 deletions sbi/utils/sbiutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def warn_if_zscoring_changes_data(x: Tensor, duplicate_tolerance: float = 0.1) -

if num_unique_z < num_unique * (1 - duplicate_tolerance):
warnings.warn(
"Z-scoring these simulation outputs resulted in {num_unique_z} unique "
"datapoints. Before z-scoring, it had been {num_unique}. This can "
f"Z-scoring these simulation outputs resulted in {num_unique_z} unique "
f"datapoints. Before z-scoring, it had been {num_unique}. This can "
"occur due to numerical inaccuracies when the data covers a large "
"range of values. Consider either setting `z_score_x=False` (but "
"beware that this can be problematic for training the NN) or exclude "
Expand Down
78 changes: 52 additions & 26 deletions tests/mnle_test.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,31 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from typing import Union

import pytest
import torch
from pyro.distributions import InverseGamma
from torch.distributions import Beta, Binomial, Categorical, Gamma
from torch import Tensor
from torch.distributions import Beta, Binomial, Distribution, Gamma

from sbi.inference import MNLE, MCMCPosterior
from sbi.inference.posteriors.rejection_posterior import RejectionPosterior
from sbi.inference.posteriors.vi_posterior import VIPosterior
from sbi.inference.potentials.base_potential import BasePotential
from sbi.inference.potentials.likelihood_based_potential import (
MixedLikelihoodBasedPotential,
likelihood_estimator_based_potential,
)
from sbi.neural_nets import likelihood_nn
from sbi.neural_nets.embedding_nets import FCEmbedding
from sbi.utils import BoxUniform, mcmc_transform
from sbi.utils.conditional_density_utils import ConditionedPotential
from sbi.utils.torchutils import atleast_2d, process_device
from sbi.utils.user_input_checks_utils import MultipleIndependent
from tests.test_utils import check_c2st


# toy simulator for mixed data
def mixed_simulator(theta, stimulus_condition=2.0):
def mixed_simulator(theta: Tensor, stimulus_condition: Union[Tensor, float] = 2.0):
"""Simulator for mixed data."""
# Extract parameters
beta, ps = theta[:, :1], theta[:, 1:]
Expand Down Expand Up @@ -190,23 +192,41 @@ def test_mnle_accuracy_with_different_samplers_and_trials(


class BinomialGammaPotential(BasePotential):
def __init__(self, prior, x_o, concentration_scaling=1.0, device="cpu"):
"""Binomial-Gamma potential for mixed data."""

def __init__(
self,
prior: Distribution,
x_o: Tensor,
concentration_scaling: Union[Tensor, float] = 1.0,
device="cpu",
):
super().__init__(prior, x_o, device)

# concentration_scaling needs to be a float or match the batch size
if isinstance(concentration_scaling, Tensor):
num_trials = x_o.shape[0]
assert concentration_scaling.shape[0] == num_trials

# Reshape to match convention (batch_size, num_trials, *event_shape)
concentration_scaling = concentration_scaling.reshape(1, num_trials, -1)

self.concentration_scaling = concentration_scaling

def __call__(self, theta, track_gradients: bool = True):
def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
theta = atleast_2d(theta)

with torch.set_grad_enabled(track_gradients):
iid_ll = self.iid_likelihood(theta)

return iid_ll + self.prior.log_prob(theta)

def iid_likelihood(self, theta):
def iid_likelihood(self, theta: Tensor) -> Tensor:
batch_size = theta.shape[0]
num_trials = self.x_o.shape[0]
theta = theta.reshape(batch_size, 1, -1)
beta, rho = theta[:, :, :1], theta[:, :, 1:]

# vectorized
logprob_choices = Binomial(probs=rho).log_prob(
self.x_o[:, 1:].reshape(1, num_trials, -1)
Expand All @@ -233,18 +253,22 @@ def test_mnle_with_experimental_conditions(mcmc_params_accurate: dict):
categorical parameter is set to a fixed value (conditioned posterior), and the
accuracy of the conditioned posterior is tested against the true posterior.
"""
num_simulations = 6000
num_samples = 500
num_simulations = 10000
num_samples = 1000

def sim_wrapper(theta):
def sim_wrapper(
theta_and_condition: Tensor, last_idx_parameters: int = 2
) -> Tensor:
# simulate with experiment conditions
return mixed_simulator(theta[:, :2], theta[:, 2:] + 1)
theta = theta_and_condition[:, :last_idx_parameters]
condition = theta_and_condition[:, last_idx_parameters:]
return mixed_simulator(theta, condition)

proposal = MultipleIndependent(
[
Gamma(torch.tensor([1.0]), torch.tensor([0.5])),
Beta(torch.tensor([2.0]), torch.tensor([2.0])),
Categorical(probs=torch.ones(1, 3)),
BoxUniform(torch.tensor([0.0]), torch.tensor([1.0])),
],
validate_args=False,
)
Expand All @@ -254,22 +278,27 @@ def sim_wrapper(theta):
assert x.shape == (num_simulations, 2)

num_trials = 10
theta_o = proposal.sample((1,))
theta_o[0, 2] = 2.0 # set condition to 2 as in original simulator.
x_o = sim_wrapper(theta_o.repeat(num_trials, 1))
theta_and_condition = proposal.sample((num_trials,))
# use only a single parameter (iid trials)
theta_o = theta_and_condition[:1, :2].repeat(num_trials, 1)
# but different conditions
condition_o = theta_and_condition[:, 2:]
theta_and_conditions_o = torch.cat((theta_o, condition_o), dim=1)

x_o = sim_wrapper(theta_and_conditions_o)

mcmc_kwargs = dict(
method="slice_np_vectorized", init_strategy="proposal", **mcmc_params_accurate
)

# MNLE
trainer = MNLE(proposal)
estimator = trainer.append_simulations(theta, x).train(training_batch_size=1000)

potential_fn = MixedLikelihoodBasedPotential(estimator, proposal, x_o)
estimator_fun = likelihood_nn(model="mnle", z_score_x=None)
trainer = MNLE(proposal, estimator_fun)
estimator = trainer.append_simulations(theta, x).train()

conditioned_potential_fn = ConditionedPotential(
potential_fn, condition=theta_o, dims_to_sample=[0, 1]
potential_fn, _ = likelihood_estimator_based_potential(estimator, proposal, x_o)
conditioned_potential_fn = potential_fn.condition_on(
janfb marked this conversation as resolved.
Show resolved Hide resolved
condition_o, dims_to_sample=[0, 1]
)

# True posterior samples
Expand All @@ -283,10 +312,7 @@ def sim_wrapper(theta):
prior_transform = mcmc_transform(prior)
true_posterior_samples = MCMCPosterior(
BinomialGammaPotential(
prior,
atleast_2d(x_o),
concentration_scaling=float(theta_o[0, 2])
+ 1.0, # add one because the sim_wrapper adds one (see above)
prior, atleast_2d(x_o), concentration_scaling=condition_o
),
theta_transform=prior_transform,
proposal=prior,
Expand All @@ -303,5 +329,5 @@ def sim_wrapper(theta):
check_c2st(
cond_samples,
true_posterior_samples,
alg=f"MNLE trained with {num_simulations}",
alg=f"MNLE trained with {num_simulations} simulations",
)
Loading
Loading