feat: batched sampling and log prob methods. (#1153)
* Base estimator class

* intermediate commit

* make autoreload work

* `amortized_sample` works for MCMCPosterior

* fixes current bug!

* Added tests

* batched_rejection_sampling

* intermediate commit

* make autoreload work

* `amortized_sample` works for MCMCPosterior

* Revert "Merge branch '990-add-sample_batched-and-log_prob_batched-to-posteriors' into amortizedsample"

This reverts commit 07084e2, reversing
changes made to f16622d.

* sample works, try log_prob_batched

* log_prob_batched works

* implement abstract method for other methods

* temp fix mcmcposterior

* for general use, e.g. in the restriction prior, we have to add some reshapes in rejection

* ... test class

* Revert "Base estimator class"

This reverts commit 17c5343.

* removing previous change

* removing some artifacts

* revert weird change

* docs and tests

* MCMC sample_batched works but not log_prob_batched

* adding some docs

* batch_log_prob for MCMC requires at least changes to the potential -> removed

* Fixing bug from rebase...

* tracking all acceptance rates

* Comment on NFlows

* Also testing SNRE batched sampling; need to test ensemble implementation

* fix bug

* Ensemble sample_batched is working (with tests)

* GPU compatibility

* restriction prior requires float as output of accept_reject

* Adding a few comments

* 2d sample_shape tests

* Apply suggestions from code review

Co-authored-by: Jan <[email protected]>

* Adding comment about squeeze

* Update sbi/inference/posteriors/direct_posterior.py

Co-authored-by: Jan <[email protected]>

* fixing formatting

* reverting MCMC posterior changes

* xfail mcmc tests

* Exclude MCMC from ensemble batched_sample test

* SNPE_A Bug fix

* typo fix

* preemptive main fix

* Revert "preemptive main fix"

This reverts commit 2aac705.

---------

Co-authored-by: manuelgloeckler <[email protected]>
Co-authored-by: Jan Boelts <[email protected]>
Co-authored-by: manuelgloeckler <[email protected]>
Co-authored-by: Jan <[email protected]>
5 people authored Jun 18, 2024
1 parent a7e65c5 commit 4951439
Showing 14 changed files with 406 additions and 34 deletions.
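
For orientation before the per-file diffs: this commit adds `sample_batched` and `log_prob_batched` to the posterior classes. A minimal usage sketch of the new API follows; the uniform prior, the toy Gaussian simulator, and the SNPE training call are illustrative assumptions of this note, not part of the diff.

import torch
from sbi.inference import SNPE
from sbi.utils import BoxUniform

# Illustrative setup: 2-d uniform prior and a trivial Gaussian "simulator".
prior = BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))
theta = prior.sample((1000,))
x = theta + 0.1 * torch.randn_like(theta)

inference = SNPE(prior=prior)
density_estimator = inference.append_simulations(theta, x).train()
posterior = inference.build_posterior(density_estimator)

# New in this commit: sample for a batch of B observations in one call.
x_o = torch.randn(3, 2)  # B = 3 observations
samples = posterior.sample_batched((100,), x=x_o)
print(samples.shape)  # (100, 3, 2), i.e. (*sample_shape, B, *theta_shape)

# Batched log-probabilities for matching parameter/observation batches.
log_probs = posterior.log_prob_batched(samples, x=x_o)
print(log_probs.shape)  # (100, 3)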
11 changes: 11 additions & 0 deletions sbi/inference/posteriors/base_posterior.py
@@ -121,6 +121,17 @@ def sample(
"""See child classes for docstring."""
pass

@abstractmethod
def sample_batched(
self,
sample_shape: Shape,
x: Tensor,
max_sampling_batch_size: int = 10_000,
show_progress_bars: bool = True,
) -> Tensor:
"""See child classes for docstring."""
pass

@property
def default_x(self) -> Optional[Tensor]:
"""Return default x used by `.sample(), .log_prob` as conditioning context."""
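
A rough sketch, not part of the diff, of what the new abstract hook asks of concrete subclasses; `MyPosterior` is a hypothetical name used only for illustration.

from torch import Tensor

from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.sbi_types import Shape


class MyPosterior(NeuralPosterior):
    def sample_batched(
        self,
        sample_shape: Shape,
        x: Tensor,
        max_sampling_batch_size: int = 10_000,
        show_progress_bars: bool = True,
    ) -> Tensor:
        # Return samples of shape (*sample_shape, batch_dim, *event_shape),
        # or raise NotImplementedError as the non-amortized posteriors below do.
        ...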
125 changes: 122 additions & 3 deletions sbi/inference/posteriors/direct_posterior.py
@@ -16,7 +16,7 @@
reshape_to_batch_event,
reshape_to_sample_batch_event,
)
from sbi.samplers.rejection.rejection import accept_reject_sample
from sbi.samplers.rejection import rejection
from sbi.sbi_types import Shape
from sbi.utils import check_prior, within_support
from sbi.utils.torchutils import ensure_theta_batched
@@ -123,7 +123,51 @@ def sample(
f"`.build_posterior(sample_with={sample_with}).`"
)

samples = accept_reject_sample(
samples = rejection.accept_reject_sample(
proposal=self.posterior_estimator,
accept_reject_fn=lambda theta: within_support(self.prior, theta),
num_samples=num_samples,
show_progress_bars=show_progress_bars,
max_sampling_batch_size=max_sampling_batch_size,
proposal_sampling_kwargs={"condition": x},
alternative_method="build_posterior(..., sample_with='mcmc')",
)[0]

return samples[:, 0] # Remove batch dimension.

def sample_batched(
self,
sample_shape: Shape,
x: Tensor,
max_sampling_batch_size: int = 10_000,
show_progress_bars: bool = True,
) -> Tensor:
r"""Given a batch of observations [x_1, ..., x_B] this function samples from
posteriors $p(\theta|x_1)$, ... ,$p(\theta|x_B)$, in a batched (i.e. vectorized)
manner.
Args:
sample_shape: Desired shape of samples that are drawn from the posterior
given every observation.
x: A batch of observations, of shape `(batch_dim, event_shape_x)`.
`batch_dim` corresponds to the number of observations to be drawn.
max_sampling_batch_size: Maximum batch size for rejection sampling.
show_progress_bars: Whether to show sampling progress monitor.
Returns:
Samples from the posteriors of shape (*sample_shape, B, *input_shape)
"""
num_samples = torch.Size(sample_shape).numel()
condition_shape = self.posterior_estimator.condition_shape
x = reshape_to_batch_event(x, event_shape=condition_shape)

max_sampling_batch_size = (
self.max_sampling_batch_size
if max_sampling_batch_size is None
else max_sampling_batch_size
)

samples = rejection.accept_reject_sample(
proposal=self.posterior_estimator,
accept_reject_fn=lambda theta: within_support(self.prior, theta),
num_samples=num_samples,
@@ -210,6 +254,81 @@ def log_prob(

return masked_log_prob - log_factor

def log_prob_batched(
self,
theta: Tensor,
x: Tensor,
norm_posterior: bool = True,
track_gradients: bool = False,
leakage_correction_params: Optional[dict] = None,
) -> Tensor:
"""Given a batch of observations [x_1, ..., x_B] and a batch of parameters \
[$\theta_1$, ..., $\theta_B$] this function evaluates the log-probabilities \
of the posteriors $p(\theta_1|x_1)$, ..., $p(\theta_B|x_B)$ in a batched \
(i.e. vectorized) manner.
Args:
theta: Batch of parameters $\theta$ of shape \
`(*sample_shape, batch_dim, *theta_shape)`.
x: Batch of observations $x$ of shape \
`(batch_dim, *condition_shape)`.
norm_posterior: Whether to enforce a normalized posterior density.
Renormalization of the posterior is useful when some
probability falls out or leaks out of the prescribed prior support.
The normalizing factor is calculated via rejection sampling, so if you
need speedier but unnormalized log posterior estimates set here
`norm_posterior=False`. The returned log posterior is set to
-∞ outside of the prior support regardless of this setting.
track_gradients: Whether the returned tensor supports tracking gradients.
This can be helpful for e.g. sensitivity analysis, but increases memory
consumption.
leakage_correction_params: A `dict` of keyword arguments to override the
default values of `leakage_correction()`. Possible options are:
`num_rejection_samples`, `force_update`, `show_progress_bars`, and
`rejection_sampling_batch_size`.
These parameters only have an effect if `norm_posterior=True`.
Returns:
`(len(θ), B)`-shaped log posterior probability $\\log p(\theta|x)$\\ for θ \
in the support of the prior, -∞ (corresponding to 0 probability) outside.
"""

theta = ensure_theta_batched(torch.as_tensor(theta))
event_shape = self.posterior_estimator.input_shape
theta_density_estimator = reshape_to_sample_batch_event(
theta, event_shape, leading_is_sample=True
)
x_density_estimator = reshape_to_batch_event(
x, event_shape=self.posterior_estimator.condition_shape
)

self.posterior_estimator.eval()

with torch.set_grad_enabled(track_gradients):
# Evaluate on device, move back to cpu for comparison with prior.
unnorm_log_prob = self.posterior_estimator.log_prob(
theta_density_estimator, condition=x_density_estimator
)

# Force probability to be zero outside prior support.
in_prior_support = within_support(self.prior, theta)

masked_log_prob = torch.where(
in_prior_support,
unnorm_log_prob,
torch.tensor(float("-inf"), dtype=torch.float32, device=self._device),
)

if leakage_correction_params is None:
leakage_correction_params = dict() # use defaults
log_factor = (
log(self.leakage_correction(x=x, **leakage_correction_params))
if norm_posterior
else 0
)

return masked_log_prob - log_factor

@torch.no_grad()
def leakage_correction(
self,
@@ -240,7 +359,7 @@ def leakage_correction(

def acceptance_at(x: Tensor) -> Tensor:
# [1:] to remove batch-dimension for `reshape_to_batch_event`.
return accept_reject_sample(
return rejection.accept_reject_sample(
proposal=self.posterior_estimator,
accept_reject_fn=lambda theta: within_support(self.prior, theta),
num_samples=num_rejection_samples,
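
To make the shape conventions of the new `DirectPosterior.sample_batched` and `log_prob_batched` concrete, a short hedged sketch; `posterior` is assumed to be an already-trained `DirectPosterior` over a 3-d parameter space with 2-d observations, and all dimensions are illustrative.

import torch

B = 5
x_o = torch.randn(B, 2)  # (batch_dim, *condition_shape)

# Vectorized rejection sampling from all B posteriors at once.
samples = posterior.sample_batched((1000,), x=x_o)
assert samples.shape == (1000, B, 3)  # (*sample_shape, B, *theta_shape)

# Multi-dimensional sample shapes work too (cf. the "2d sample_shape tests" commit).
samples_2d = posterior.sample_batched((20, 50), x=x_o)
assert samples_2d.shape == (20, 50, B, 3)

# Batched log-probabilities; norm_posterior=False skips the leakage correction,
# which is faster but leaves the density unnormalized.
log_probs = posterior.log_prob_batched(samples, x=x_o, norm_posterior=False)
assert log_probs.shape == (1000, B)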
23 changes: 23 additions & 0 deletions sbi/inference/posteriors/ensemble_posterior.py
@@ -179,6 +179,29 @@ def sample(
)
return torch.vstack(samples).reshape(*sample_shape, -1)

def sample_batched(
self,
sample_shape: Shape,
x: Tensor,
**kwargs,
) -> Tensor:
num_samples = torch.Size(sample_shape).numel()
posterior_indices = torch.multinomial(
self._weights, num_samples, replacement=True
)
samples = []
for posterior_index, sample_size in torch.vstack(
posterior_indices.unique(return_counts=True)
).T:
sample_shape_c = torch.Size((int(sample_size),))
samples.append(
self.posteriors[posterior_index].sample_batched(
sample_shape_c, x=x, **kwargs
)
)
samples = torch.vstack(samples)
return samples.reshape(sample_shape + samples.shape[1:])

def log_prob(
self,
theta: Tensor,
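
The allocation step in `EnsemblePosterior.sample_batched` above draws a component index for every requested sample and then asks each component posterior for its share. A self-contained illustration of that counting logic, independent of sbi:

import torch

weights = torch.tensor([0.7, 0.3])  # ensemble component weights
num_samples = 10

# One component index per requested sample, then per-component counts
# (mirrors `multinomial` + `unique(return_counts=True)` in the diff above).
indices = torch.multinomial(weights, num_samples, replacement=True)
components, counts = indices.unique(return_counts=True)
for component, count in zip(components.tolist(), counts.tolist()):
    print(f"component {component} contributes {count} samples")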
13 changes: 13 additions & 0 deletions sbi/inference/posteriors/importance_posterior.py
@@ -194,6 +194,19 @@ def sample(
else:
raise NameError

def sample_batched(
self,
sample_shape: Shape,
x: Tensor,
max_sampling_batch_size: int = 10000,
show_progress_bars: bool = True,
) -> Tensor:
raise NotImplementedError(
"Batched sampling is not implemented for ImportanceSamplingPosterior. \
Alternatively, you can use `sample` in a loop: \
[posterior.sample(sample_shape, x=x_o) for x_o in x]."
)

def _importance_sample(
self,
sample_shape: Shape = torch.Size(),
45 changes: 44 additions & 1 deletion sbi/inference/posteriors/mcmc_posterior.py
@@ -352,8 +352,51 @@ def sample(
raise NameError(f"The sampling method {method} is not implemented!")

samples = self.theta_transform.inv(transformed_samples)
# NOTE: Currently MCMCPosteriors will require a single dimension for the
# parameter dimension. With recent ConditionalDensity(Ratio) estimators, we
# can have multiple dimensions for the parameter dimension.
samples = samples.reshape((*sample_shape, -1)) # type: ignore

return samples.reshape((*sample_shape, -1)) # type: ignore
return samples

def sample_batched(
self,
sample_shape: Shape,
x: Tensor,
method: Optional[str] = None,
thin: Optional[int] = None,
warmup_steps: Optional[int] = None,
num_chains: Optional[int] = None,
init_strategy: Optional[str] = None,
init_strategy_parameters: Optional[Dict[str, Any]] = None,
num_workers: Optional[int] = None,
mp_context: Optional[str] = None,
show_progress_bars: bool = True,
) -> Tensor:
r"""Given a batch of observations [x_1, ..., x_B] this function samples from
posteriors $p(\theta|x_1)$, ... ,$p(\theta|x_B)$, in a batched (i.e. vectorized)
manner.
Check the `__init__()` method for a description of all arguments as well as
their default values.
Args:
sample_shape: Desired shape of samples that are drawn from the posterior
given every observation.
x: A batch of observations, of shape `(batch_dim, event_shape_x)`.
`batch_dim` corresponds to the number of observations to be drawn.
show_progress_bars: Whether to show sampling progress monitor.
Returns:
Samples from the posteriors of shape (*sample_shape, B, *input_shape)
"""

# See #1176 for a discussion on the implementation of batched sampling.
raise NotImplementedError(
"Batched sampling is not implemented for MCMC posterior. \
Alternatively, you can use `sample` in a loop: \
[posterior.sample(sample_shape, x=x_o) for x_o in x]."
)

def _build_mcmc_init_fn(
self,
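
Since batched sampling stays unimplemented for `MCMCPosterior` (see the discussion in #1176), the error message points to a per-observation loop instead. A hedged sketch of that fallback; `posterior` is assumed to be a fitted `MCMCPosterior` and the shapes are illustrative.

import torch

x_batch = torch.randn(4, 2)  # B = 4 observations of shape (2,)

# One sequential MCMC run per observation.
samples_per_x = [posterior.sample((500,), x=x_o) for x_o in x_batch]

# Stack into the batched convention used elsewhere in this PR:
# (*sample_shape, B, theta_dim).
samples = torch.stack(samples_per_x, dim=1)
print(samples.shape)  # (500, 4, theta_dim)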
13 changes: 13 additions & 0 deletions sbi/inference/posteriors/rejection_posterior.py
@@ -167,6 +167,19 @@ def sample(

return samples.reshape((*sample_shape, -1))

def sample_batched(
self,
sample_shape: Shape,
x: Tensor,
max_sampling_batch_size: int = 10000,
show_progress_bars: bool = True,
) -> Tensor:
raise NotImplementedError(
"Batched sampling is not implemented for RejectionPosterior. \
Alternatively, you can use `sample` in a loop: \
[posterior.sample(sample_shape, x=x_o) for x_o in x]."
)

def map(
self,
x: Optional[Tensor] = None,
13 changes: 13 additions & 0 deletions sbi/inference/posteriors/vi_posterior.py
@@ -296,6 +296,19 @@ def sample(
samples = self.q.sample(torch.Size(sample_shape))
return samples.reshape((*sample_shape, samples.shape[-1]))

def sample_batched(
self,
sample_shape: Shape,
x: Tensor,
max_sampling_batch_size: int = 10000,
show_progress_bars: bool = True,
) -> Tensor:
raise NotImplementedError(
"Batched sampling is not implemented for VIPosterior. \
Alternatively, you can use `sample` in a loop: \
[posterior.sample(sample_shape, x=x_o) for x_o in x]."
)

def log_prob(
self,
theta: Tensor,
10 changes: 8 additions & 2 deletions sbi/inference/snpe/snpe_a.py
@@ -477,7 +477,8 @@ def sample(self, sample_shape: torch.Size, condition: Tensor, **kwargs) -> Tensor:
condition = condition.to(self._device)

if not self._apply_correction:
return self._neural_net.sample(sample_shape, condition=condition)
samples = self._neural_net.sample(sample_shape, condition=condition)
return samples
else:
# When we want to sample from the approx. posterior, a proposal prior
# \tilde{p} has already been observed. To analytically calculate the
@@ -486,7 +487,12 @@ def sample(self, sample_shape: torch.Size, condition: Tensor, **kwargs) -> Tensor:
condition_ndim = len(self.condition_shape)
batch_size = condition.shape[:-condition_ndim]
batch_size = torch.Size(batch_size).numel()
return self._sample_approx_posterior_mog(num_samples, condition, batch_size)
samples = self._sample_approx_posterior_mog(
num_samples, condition, batch_size
)
# NOTE: New batching convention: (batch_dim, sample_dim, *event_shape)
samples = samples.transpose(0, 1)
return samples

def _sample_approx_posterior_mog(
self, num_samples, x: Tensor, batch_size: int
2 changes: 2 additions & 0 deletions sbi/neural_nets/density_estimators/nflows_flow.py
@@ -135,6 +135,8 @@ def sample(self, sample_shape: Shape, condition: Tensor) -> Tensor:
num_samples = torch.Size(sample_shape).numel()

samples = self.net.sample(num_samples, context=condition)
# Change from Nflows' convention of (batch_dim, sample_dim, *event_shape) to
# (sample_dim, batch_dim, *event_shape) (PyTorch + SBI).
samples = samples.transpose(0, 1)
return samples.reshape((*sample_shape, condition_batch_dim, *self.input_shape))

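
A tiny self-contained illustration of the axis convention noted in the comment above; shapes are illustrative and independent of nflows.

import torch

batch_dim, sample_dim, event_dim = 3, 5, 2

# nflows returns samples as (batch_dim, sample_dim, *event_shape) ...
nflows_samples = torch.randn(batch_dim, sample_dim, event_dim)

# ... while sbi's density estimators use (sample_dim, batch_dim, *event_shape).
sbi_samples = nflows_samples.transpose(0, 1)
print(sbi_samples.shape)  # torch.Size([5, 3, 2])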