Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NLE with multiple iid conditions #1331

Merged
merged 6 commits into from
Dec 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 135 additions & 2 deletions sbi/inference/potentials/likelihood_based_potential.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from typing import Callable, Optional, Tuple
import warnings
from typing import Callable, List, Optional, Tuple

import torch
from torch import Tensor
Expand Down Expand Up @@ -115,6 +116,54 @@
)
return log_likelihood_batches + self.prior.log_prob(theta) # type: ignore

def condition_on_theta(
    self, local_theta: Tensor, dims_global_theta: List[int]
) -> Callable:
    r"""Returns a potential function conditioned on a subset of theta dimensions.

    The goal of this function is to divide the original `theta` into a
    `global_theta` we do inference over, and a `local_theta` we condition on (in
    addition to conditioning on `x_o`). The returned potential function evaluates
    $\sum_{i=1}^{N} \log p(x_i | \theta_{local, i}, \theta_{global})$, where the
    `x_i` and `local_theta_i` are fixed and `global_theta` varies at inference
    time. Only the summed log-likelihood is returned; no prior term is added
    here.

    Args:
        local_theta: The values of the theta dimensions to condition on, one row
            per iid trial in `x_o`.
        dims_global_theta: The indices of the columns in `theta` that will be
            sampled, i.e., that are *not* conditioned on. For example, if the
            original theta has shape `(batch_dim, 3)` and
            `dims_global_theta=[0, 1]`, then the potential will use
            `theta[:, 2] = local_theta` at inference time. Because
            `local_theta` is appended *after* the global columns, the global
            dimensions must be the leading ones, i.e., every index must be
            smaller than `len(dims_global_theta)`.

    Returns:
        A potential function conditioned on the `local_theta`. It takes the
        global `theta` of shape `(batch_dim, len(dims_global_theta))`, an
        optional `x_o` (defaulting to `self.x_o`) and `track_gradients`.

    Raises:
        AssertionError: If the potential was not set up for iid data.
    """

    assert self.x_is_iid, "Conditioning is only supported for iid data."

    def conditioned_potential(
        theta: Tensor, x_o: Optional[Tensor] = None, track_gradients: bool = True
    ) -> Tensor:
        assert (
            len(dims_global_theta) == theta.shape[1]
        ), "dims_global_theta must match the number of parameters to sample."

        # `theta` only holds the global columns, so indices taken from the full
        # theta are valid only if they are leading dims. Fail with an
        # explanatory message instead of torch's generic out-of-bounds error
        # (same exception type as before).
        num_sampled = theta.shape[1]
        if any(dim >= num_sampled for dim in dims_global_theta):
            raise IndexError(
                f"dims_global_theta={list(dims_global_theta)} contains column "
                f"indices outside the sampled theta with {num_sampled} columns. "
                "Only leading global dimensions are supported because "
                "local_theta is appended after the global columns."
            )

        global_theta = theta[:, dims_global_theta]
        x_o = x_o if x_o is not None else self.x_o

        # x needs shape (sample_dim (iid), batch_dim (xs), *event_shape)
        if x_o.dim() < 3:
            x_o = reshape_to_sample_batch_event(
                x_o, event_shape=x_o.shape[1:], leading_is_sample=self.x_is_iid
            )

        return _log_likelihood_over_iid_trials_and_local_theta(
            x=x_o,
            global_theta=global_theta,
            local_theta=local_theta,
            estimator=self.likelihood_estimator,
            track_gradients=track_gradients,
        )

    return conditioned_potential

Check warning on line 165 in sbi/inference/potentials/likelihood_based_potential.py

View check run for this annotation

Codecov / codecov/patch

sbi/inference/potentials/likelihood_based_potential.py#L165

Added line #L165 was not covered by tests


def _log_likelihoods_over_trials(
x: Tensor,
Expand Down Expand Up @@ -172,6 +221,77 @@
return log_likelihood_trial_sum


def _log_likelihood_over_iid_trials_and_local_theta(
    x: Tensor,
    global_theta: Tensor,
    local_theta: Tensor,
    estimator: ConditionalDensityEstimator,
    track_gradients: bool = False,
) -> Tensor:
    r"""Returns $\sum_{i=1}^N \log p(x_i | \theta, local\_theta_i)$.

    `x` is a batch of iid data, and `local_theta` is a matching batch of condition
    values that were part of `theta` but are treated as local iid variables at
    inference time.

    This function is different from `_log_likelihoods_over_trials` in that it moves
    the iid batch dimension of `x` onto the batch dimension of `theta`. This is
    needed when the likelihood estimator is conditioned on a batch of conditions
    that are iid with the batch of `x`. It avoids the evaluation of the likelihood
    for every combination of `x` and `local_theta`.

    Args:
        x: data with shape `(sample_dim, x_batch_dim, *x_event_shape)`, where
            sample_dim holds the i.i.d. trials and batch_dim holds a batch of xs,
            e.g., non-iid observations.
        global_theta: Batch of parameters `(theta_batch_dim, num_parameters)`.
        local_theta: Batch of conditions of shape `(sample_dim, num_local_thetas)`,
            must match x's `sample_dim`.
        estimator: DensityEstimator. Its `log_prob` is called once with the
            concatenated `[global_theta, local_theta]` columns as `condition`.
        track_gradients: Whether to track gradients.

    Returns:
        log_likelihood: log likelihood for each x in x_batch_dim, for each theta in
            theta_batch_dim, summed over all iid trials. Shape `(x_batch_dim,
            theta_batch_dim)`.
    """
    assert x.dim() > 2, "x must have shape (sample_dim, batch_dim, *event_shape)."
    assert (
        local_theta.dim() == 2
    ), "condition must have shape (sample_dim, num_conditions)."
    assert global_theta.dim() == 2, "theta must have shape (batch_dim, num_parameters)."
    num_trials, num_xs = x.shape[:2]
    num_thetas = global_theta.shape[0]
    assert (
        local_theta.shape[0] == num_trials
    ), "Condition batch size must match the number of iid trials in x."

    # Move the iid trial dimension onto the batch dimension and repeat every trial
    # once per theta (AABB), giving (num_xs, num_trials * num_thetas, *event_shape).
    x_repeated = x.transpose(0, 1).repeat_interleave(num_thetas, dim=1)

    # Build the matching conditions: entry `trial * num_thetas + theta_idx` pairs
    # global theta `theta_idx` with the local theta of `trial`, in the same order
    # as `x_repeated`.
    theta_with_condition = torch.cat(
        [
            global_theta.repeat(num_trials, 1),  # repeat ABAB
            local_theta.repeat_interleave(num_thetas, dim=0),  # repeat AABB
        ],
        dim=-1,
    )

    with torch.set_grad_enabled(track_gradients):
        # Evaluate all trial-theta pairs in one batch. The reshape below relies on
        # the result having shape (num_xs, num_trials * num_thetas).
        log_likelihood_trial_batch = estimator.log_prob(
            x_repeated, condition=theta_with_condition
        )
        # Reshape to (xs x trials x parameters), sum over trial log-likelihoods.
        log_likelihood_trial_sum = log_likelihood_trial_batch.reshape(
            num_xs, num_trials, num_thetas
        ).sum(1)

    return log_likelihood_trial_sum


def mixed_likelihood_estimator_based_potential(
likelihood_estimator: MixedDensityEstimator,
prior: Distribution,
Expand All @@ -192,6 +312,13 @@
to unconstrained space.
"""

warnings.warn(

Check warning on line 315 in sbi/inference/potentials/likelihood_based_potential.py

View check run for this annotation

Codecov / codecov/patch

sbi/inference/potentials/likelihood_based_potential.py#L315

Added line #L315 was not covered by tests
"This function is deprecated and will be removed in a future release. Use "
"`likelihood_estimator_based_potential` instead.",
DeprecationWarning,
stacklevel=2,
)

device = str(next(likelihood_estimator.discrete_net.parameters()).device)

potential_fn = MixedLikelihoodBasedPotential(
Expand All @@ -212,6 +339,13 @@
):
super().__init__(likelihood_estimator, prior, x_o, device)

warnings.warn(

Check warning on line 342 in sbi/inference/potentials/likelihood_based_potential.py

View check run for this annotation

Codecov / codecov/patch

sbi/inference/potentials/likelihood_based_potential.py#L342

Added line #L342 was not covered by tests
"This function is deprecated and will be removed in a future release. Use "
"`LikelihoodBasedPotential` instead.",
DeprecationWarning,
stacklevel=2,
)

def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
prior_log_prob = self.prior.log_prob(theta) # type: ignore

Expand All @@ -231,7 +365,6 @@
with torch.set_grad_enabled(track_gradients):
# Call the specific log prob method of the mixed likelihood estimator as
# this optimizes the evaluation of the discrete data part.
# TODO log_prob_iid
log_likelihood_trial_batch = self.likelihood_estimator.log_prob(
input=x,
condition=theta.to(self.device),
Expand Down
6 changes: 2 additions & 4 deletions sbi/inference/trainers/nle/mnle.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torch.distributions import Distribution

from sbi.inference.posteriors import MCMCPosterior, RejectionPosterior, VIPosterior
from sbi.inference.potentials import mixed_likelihood_estimator_based_potential
from sbi.inference.potentials import likelihood_estimator_based_potential
from sbi.inference.trainers.nle.nle_base import LikelihoodEstimator
from sbi.neural_nets.estimators import MixedDensityEstimator
from sbi.sbi_types import TensorboardSummaryWriter, TorchModule
Expand Down Expand Up @@ -155,9 +155,7 @@ def build_posterior(
(
potential_fn,
theta_transform,
) = mixed_likelihood_estimator_based_potential(
likelihood_estimator=likelihood_estimator, prior=prior, x_o=None
)
) = likelihood_estimator_based_potential(likelihood_estimator, prior, x_o=None)

if sample_with == "mcmc":
self._posterior = MCMCPosterior(
Expand Down
2 changes: 1 addition & 1 deletion sbi/utils/conditional_density_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def __init__(
masked outside of prior.
"""
condition = torch.atleast_2d(condition)
if condition.shape[0] != 1:
if condition.shape[0] > 1:
raise ValueError("Condition with batch size > 1 not supported.")

self.potential_fn = potential_fn
Expand Down
4 changes: 2 additions & 2 deletions sbi/utils/sbiutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def warn_if_zscoring_changes_data(x: Tensor, duplicate_tolerance: float = 0.1) -

if num_unique_z < num_unique * (1 - duplicate_tolerance):
warnings.warn(
"Z-scoring these simulation outputs resulted in {num_unique_z} unique "
"datapoints. Before z-scoring, it had been {num_unique}. This can "
f"Z-scoring these simulation outputs resulted in {num_unique_z} unique "
f"datapoints. Before z-scoring, it had been {num_unique}. This can "
"occur due to numerical inaccuracies when the data covers a large "
"range of values. Consider either setting `z_score_x=False` (but "
"beware that this can be problematic for training the NN) or exclude "
Expand Down
Loading
Loading