additional features for NPSE #1370

Draft · wants to merge 3 commits into base: main
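In short, this PR makes NPSE posteriors reject samples that fall outside the prior support (for both the SDE and ODE samplers) and implements `.map()` for `ScorePosterior` via gradient ascent on a differentiable potential. Below is a minimal usage sketch of the two features, assuming the standard NPSE workflow; the toy simulator, prior, and observation are placeholders for illustration, not part of this PR.

```python
# Hedged usage sketch: exercises the two features added here, namely
# prior-support rejection during sampling and MAP estimation for NPSE.
# The toy simulator and data below are placeholders.
import torch

from sbi.inference import NPSE
from sbi.utils import BoxUniform

prior = BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))
theta = prior.sample((2000,))
x = theta + 0.1 * torch.randn_like(theta)  # toy simulator

inference = NPSE(prior=prior)
score_estimator = inference.append_simulations(theta, x).train()
posterior = inference.build_posterior(score_estimator)

posterior.set_default_x(torch.zeros(2))

# Samples are now passed through `accept_reject_sample`, so they stay
# within the prior's support (for both `sample_with="sde"` and `"ode"`).
samples = posterior.sample((500,))

# `.map()` no longer raises NotImplementedError; it runs gradient ascent
# on a potential made differentiable via the analytic score.
theta_map = posterior.map()
```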
6 changes: 3 additions & 3 deletions sbi/inference/posteriors/direct_posterior.py
@@ -132,7 +132,7 @@ def sample(
)

samples = rejection.accept_reject_sample(
proposal=self.posterior_estimator,
proposal=self.posterior_estimator.sample,
accept_reject_fn=lambda theta: within_support(self.prior, theta),
num_samples=num_samples,
show_progress_bars=show_progress_bars,
@@ -176,7 +176,7 @@ def sample_batched(
)

samples = rejection.accept_reject_sample(
proposal=self.posterior_estimator,
proposal=self.posterior_estimator.sample,
accept_reject_fn=lambda theta: within_support(self.prior, theta),
num_samples=num_samples,
show_progress_bars=show_progress_bars,
@@ -373,7 +373,7 @@ def leakage_correction(
def acceptance_at(x: Tensor) -> Tensor:
# [1:] to remove batch-dimension for `reshape_to_batch_event`.
return rejection.accept_reject_sample(
proposal=self.posterior_estimator,
proposal=self.posterior_estimator.sample,
accept_reject_fn=lambda theta: within_support(self.prior, theta),
num_samples=num_rejection_samples,
show_progress_bars=show_progress_bars,
131 changes: 113 additions & 18 deletions sbi/inference/posteriors/score_posterior.py
@@ -1,6 +1,7 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from functools import partial
from typing import Dict, Optional, Union

import torch
@@ -16,9 +17,11 @@
from sbi.neural_nets.estimators.shape_handling import (
reshape_to_batch_event,
)
from sbi.samplers.rejection import rejection
from sbi.samplers.score import Corrector, Diffuser, Predictor
from sbi.sbi_types import Shape
from sbi.utils import check_prior
from sbi.utils.sbiutils import gradient_ascent, within_support
from sbi.utils.torchutils import ensure_theta_batched


@@ -44,7 +47,7 @@ def __init__(
prior: Distribution,
max_sampling_batch_size: int = 10_000,
device: Optional[str] = None,
enable_transform: bool = False,
enable_transform: bool = True,
sample_with: str = "sde",
):
"""
@@ -136,8 +139,18 @@ def sample(
x = reshape_to_batch_event(x, self.score_estimator.condition_shape)
self.potential_fn.set_x(x)

num_samples = torch.Size(sample_shape).numel()

if self.sample_with == "ode":
samples = self.sample_via_zuko(sample_shape=sample_shape, x=x)
samples = rejection.accept_reject_sample(
proposal=self.sample_via_zuko,
accept_reject_fn=lambda theta: within_support(self.prior, theta),
num_samples=num_samples,
show_progress_bars=show_progress_bars,
max_sampling_batch_size=max_sampling_batch_size,
proposal_sampling_kwargs={"x": x},
)[0]
samples = samples.reshape(sample_shape + self.score_estimator.input_shape)
elif self.sample_with == "sde":
samples = self._sample_via_diffusion(
sample_shape=sample_shape,
@@ -150,6 +163,25 @@
max_sampling_batch_size=max_sampling_batch_size,
show_progress_bars=show_progress_bars,
)
proposal_sampling_kwargs = {
"predictor": predictor,
"corrector": corrector,
"predictor_params": predictor_params,
"corrector_params": corrector_params,
"steps": steps,
"ts": ts,
"max_sampling_batch_size": max_sampling_batch_size,
"show_progress_bars": show_progress_bars,
}
samples = rejection.accept_reject_sample(
proposal=self._sample_via_diffusion,
accept_reject_fn=lambda theta: within_support(self.prior, theta),
num_samples=num_samples,
show_progress_bars=show_progress_bars,
max_sampling_batch_size=max_sampling_batch_size,
proposal_sampling_kwargs=proposal_sampling_kwargs,
)[0]
samples = samples.reshape(sample_shape + self.score_estimator.input_shape)

return samples

@@ -220,12 +252,12 @@ def _sample_via_diffusion(
)
samples = torch.cat(samples, dim=0)[:num_samples]

return samples.reshape(sample_shape + self.score_estimator.input_shape)
return samples

def sample_via_zuko(
self,
x: Tensor,
sample_shape: Shape = torch.Size(),
x: Optional[Tensor] = None,
) -> Tensor:
r"""Return samples from posterior distribution with probability flow ODE.

@@ -241,10 +273,13 @@
"""
num_samples = torch.Size(sample_shape).numel()

x = self._x_else_default_x(x)
x = reshape_to_batch_event(x, self.score_estimator.condition_shape)

flow = self.potential_fn.get_continuous_normalizing_flow(condition=x)
samples = flow.sample(torch.Size((num_samples,)))

return samples.reshape(sample_shape + self.score_estimator.input_shape)
return samples

def log_prob(
self,
@@ -301,7 +336,7 @@ def map(
x: Optional[Tensor] = None,
num_iter: int = 1000,
num_to_optimize: int = 1000,
learning_rate: float = 1e-5,
learning_rate: float = 0.01,
init_method: Union[str, Tensor] = "posterior",
num_init_samples: int = 1000,
save_best_every: int = 1000,
@@ -349,17 +384,77 @@
Returns:
The MAP estimate.
"""
raise NotImplementedError(
"MAP estimation is currently not working accurately for ScorePosterior."
if x is not None:
raise ValueError(
"Passing `x` directly to `.map()` has been deprecated."
"Use `.self_default_x()` to set `x`, and then run `.map()` "
)

if self.default_x is None:
raise ValueError(
"Default `x` has not been set."
"To set the default, use the `.set_default_x()` method."
)

if self._map is None or force_update:
self.potential_fn.set_x(self.default_x)
callable_potential_fn = CallableDifferentiablePotentialFunction(
self.potential_fn
)
if init_method == "posterior":
inits = self.sample((num_init_samples,))
elif init_method == "proposal":
inits = self.proposal.sample((num_init_samples,)) # type: ignore
elif isinstance(init_method, Tensor):
inits = init_method
else:
raise ValueError(f"Unknown init_method: {init_method}.")

self._map = gradient_ascent(
potential_fn=callable_potential_fn,
inits=inits,
theta_transform=self.theta_transform,
num_iter=num_iter,
num_to_optimize=num_to_optimize,
learning_rate=learning_rate,
save_best_every=save_best_every,
show_progress_bars=show_progress_bars,
)[0]

return self._map


class DifferentiablePotentialFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, call_function, gradient_function):
# Save the methods as callables
ctx.call_function = call_function
ctx.gradient_function = gradient_function
ctx.save_for_backward(input)

# Perform the forward computation
output = call_function(input)
return output

@staticmethod
def backward(ctx, grad_output):
(input,) = ctx.saved_tensors
grad = ctx.gradient_function(input)
while len(grad_output.shape) < len(grad.shape):
grad_output = grad_output.unsqueeze(-1)
grad_input = grad_output * grad
return grad_input, None, None


# Wrapper class to manage state
class CallableDifferentiablePotentialFunction:
def __init__(self, posterior_score_based_potential):
self.posterior_score_based_potential = posterior_score_based_potential

def __call__(self, input):
prepared_potential = partial(
self.posterior_score_based_potential.__call__, rebuild_flow=False
)
return super().map(
x=x,
num_iter=num_iter,
num_to_optimize=num_to_optimize,
learning_rate=learning_rate,
init_method=init_method,
num_init_samples=num_init_samples,
save_best_every=save_best_every,
show_progress_bars=show_progress_bars,
force_update=force_update,
)
return DifferentiablePotentialFunction.apply(
input, prepared_potential, self.posterior_score_based_potential.gradient
)
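
The `DifferentiablePotentialFunction` / `CallableDifferentiablePotentialFunction` pair above exists so that `gradient_ascent` can backpropagate through the potential using the score estimator's analytic gradient instead of autograd through the ODE flow. Below is a minimal, self-contained sketch of the same custom-autograd pattern; the quadratic potential and its gradient are toy stand-ins, not the sbi classes.

```python
# Sketch of the custom-autograd pattern used above: forward evaluates a
# potential, backward injects an externally supplied (analytic) gradient.
import torch


class PotentialWithSuppliedGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, theta, potential_fn, gradient_fn):
        ctx.gradient_fn = gradient_fn
        ctx.save_for_backward(theta)
        return potential_fn(theta)

    @staticmethod
    def backward(ctx, grad_output):
        (theta,) = ctx.saved_tensors
        grad = ctx.gradient_fn(theta)
        # Broadcast the incoming gradient over the parameter dimensions.
        while grad_output.dim() < grad.dim():
            grad_output = grad_output.unsqueeze(-1)
        return grad_output * grad, None, None


def potential(theta):  # toy log-density
    return -(theta**2).sum(dim=-1)


def score(theta):  # its analytic gradient
    return -2 * theta


theta = torch.randn(5, 2, requires_grad=True)
value = PotentialWithSuppliedGrad.apply(theta, potential, score)
value.sum().backward()
assert torch.allclose(theta.grad, score(theta.detach()))
```

In the PR, the wrapped potential is additionally called with `rebuild_flow=False`, so repeated evaluations during MAP optimization reuse the flow cached in `set_x` rather than rebuilding it each step.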
57 changes: 39 additions & 18 deletions sbi/inference/potentials/score_based_potential.py
@@ -25,7 +25,7 @@ def score_estimator_based_potential(
score_estimator: ConditionalScoreEstimator,
prior: Optional[Distribution],
x_o: Optional[Tensor],
enable_transform: bool = False,
enable_transform: bool = True,
) -> Tuple["PosteriorScoreBasedPotential", TorchTransform]:
r"""Returns the potential function gradient for score estimators.
@@ -41,10 +41,6 @@
score_estimator, prior, x_o, device=device
)

assert enable_transform is False, (
"Transforms are not yet supported for score estimators."
)

if prior is not None:
theta_transform = mcmc_transform(
prior, device=device, enable_transform=enable_transform
@@ -74,16 +70,38 @@ def __init__(
`iid_bridge` as proposed in Geffner et al. is implemented.
device: The device on which to evaluate the potential.
"""

super().__init__(prior, x_o, device=device)
self.score_estimator = score_estimator
self.score_estimator.eval()
self.iid_method = iid_method
super().__init__(prior, x_o, device=device)

def set_x(
self,
x_o: Optional[Tensor],
x_is_iid: Optional[bool] = False,
rebuild_flow: Optional[bool] = True,
):
super().set_x(x_o, x_is_iid)
if rebuild_flow and self._x_o is not None:
x_density_estimator = reshape_to_batch_event(
self.x_o, event_shape=self.score_estimator.condition_shape
)
assert x_density_estimator.shape[0] == 1, (
"PosteriorScoreBasedPotential supports only x batchsize of 1`."
)
# For a large number of evaluations, we want a high-tolerance (faster) flow.
# This flow will be used mainly for MAP calculations, hence we want to save
# it instead of rebuilding it every time.
flow = self.get_continuous_normalizing_flow(
condition=x_density_estimator, atol=1e-2, rtol=1e-3, exact=True
)
self.flow = flow

def __call__(
self,
theta: Tensor,
track_gradients: bool = True,
rebuild_flow: bool = True,
atol: float = 1e-5,
rtol: float = 1e-6,
exact: bool = True,
@@ -93,6 +111,7 @@
Args:
theta: The parameters at which to evaluate the potential.
track_gradients: Whether to track gradients.
rebuild_flow: Whether to rebuild the CNF for accurate log_prob evaluation.
atol: Absolute tolerance for the ODE solver.
rtol: Relative tolerance for the ODE solver.
exact: Whether to use the exact ODE solver.
@@ -104,18 +123,20 @@
theta_density_estimator = reshape_to_sample_batch_event(
theta, theta.shape[1:], leading_is_sample=True
)
x_density_estimator = reshape_to_batch_event(
self.x_o, event_shape=self.score_estimator.condition_shape
)
assert x_density_estimator.shape[0] == 1, (
"PosteriorScoreBasedPotential supports only x batchsize of 1`."
)

self.score_estimator.eval()
if rebuild_flow or self.flow is None:
x_density_estimator = reshape_to_batch_event(
self.x_o, event_shape=self.score_estimator.condition_shape
)
assert x_density_estimator.shape[0] == 1, (
"PosteriorScoreBasedPotential supports only x batchsize of 1`."
)

flow = self.get_continuous_normalizing_flow(
condition=x_density_estimator, atol=atol, rtol=rtol, exact=exact
)
flow = self.get_continuous_normalizing_flow(
condition=x_density_estimator, atol=atol, rtol=rtol, exact=exact
)
else:
flow = self.flow

with torch.set_grad_enabled(track_gradients):
log_probs = flow.log_prob(theta_density_estimator).squeeze(-1)
Expand All @@ -135,7 +156,7 @@ def gradient(
r"""Returns the potential function gradient for score-based methods.
Args:
theta: The parameters at which to evaluate the potential.
theta: The parameters at which to evaluate the potential gradient.
time: The diffusion time. If None, then `t_min` of the
self.score_estimator is used (i.e. we evaluate the gradient of the
actual data distribution).
9 changes: 4 additions & 5 deletions sbi/samplers/rejection/rejection.py
@@ -1,11 +1,10 @@
import logging
import warnings
from typing import Any, Callable, Dict, Optional, Tuple, Union
from typing import Any, Callable, Dict, Optional, Tuple

import torch
import torch.distributions.transforms as torch_tf
from torch import Tensor, as_tensor, nn
from torch.distributions import Distribution
from torch import Tensor, as_tensor
from tqdm.auto import tqdm

from sbi.utils.sbiutils import gradient_ascent
@@ -188,7 +187,7 @@ def log_prob(self, theta: Tensor, **kwargs) -> Tensor:

@torch.no_grad()
def accept_reject_sample(
proposal: Union[nn.Module, Distribution],
proposal: Callable,
accept_reject_fn: Callable,
num_samples: int,
show_progress_bars: bool = False,
@@ -278,7 +277,7 @@ def accept_reject_sample(
num_samples_possible = 0
while num_remaining > 0:
# Sample and reject.
candidates = proposal.sample(
candidates = proposal(
(sampling_batch_size,), # type: ignore
**proposal_sampling_kwargs,
)
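With this change, `accept_reject_sample` treats `proposal` as a plain callable of a sample shape (plus optional kwargs) rather than a `Distribution`/`nn.Module`, which is what lets the score posterior pass bound methods such as `self._sample_via_diffusion`. A small sketch against the new signature, assuming a toy Gaussian proposal and box prior (not taken from the PR's tests):

```python
# Sketch of the new `proposal` contract: any callable mapping a sample
# shape (and optional kwargs) to a tensor of candidates, e.g. `.sample`.
import torch

from sbi.samplers.rejection.rejection import accept_reject_sample
from sbi.utils import BoxUniform
from sbi.utils.sbiutils import within_support

prior = BoxUniform(low=-torch.ones(2), high=torch.ones(2))
proposal = torch.distributions.MultivariateNormal(torch.zeros(2), torch.eye(2))

samples, acceptance_rate = accept_reject_sample(
    proposal=proposal.sample,  # a callable, no longer a Distribution object
    accept_reject_fn=lambda theta: within_support(prior, theta),
    num_samples=1_000,
)
print(samples.shape, acceptance_rate)
```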
2 changes: 1 addition & 1 deletion sbi/utils/restriction_estimator.py
@@ -685,7 +685,7 @@ def sample(

if sample_with == "rejection":
samples, acceptance_rate = accept_reject_sample(
proposal=self._prior,
proposal=self._prior.sample,
accept_reject_fn=self._accept_reject_fn,
num_samples=num_samples,
show_progress_bars=show_progress_bars,