Skip to content

Commit

Permalink
feat: batched sampling for vectorized MCMC (#1176)
Browse files Browse the repository at this point in the history
* Base estimator class

* intermediate commit

* make autoreload work

* `amortized_sample` works for MCMCPosterior

* fixes current bug!

* Added tests

* batched_rejection_sampling

* intermediate commit

* make autoreload work

* `amortized_sample` works for MCMCPosterior

* Revert "Merge branch '990-add-sample_batched-and-log_prob_batched-to-posteriors' into amortizedsample"

This reverts commit 07084e2, reversing
changes made to f16622d.

* sample works, try log_prob_batched

* log_prob_batched works

* abstract method implement for other methods

* temp fix mcmcposterior

* meh for general use i.e. in the restriction prior we have to add some reshapes in rejection

* ... test class

* Revert "Base estimator class"

This reverts commit 17c5343.

* removing previous change

* removing some artifacts

* revert weird change

* docs and tests

* MCMC sample_batched works but not log_prob batched

* adding some docs

* batch_log_prob for MCMC requires at best changes for potential -> removed

* intermediate commit

* make autoreload work

* `amortized_sample` works for MCMCPosterior

* intermediate commit

* make autoreload work

* `amortized_sample` works for MCMCPosterior

* Base estimator class

* Revert "Merge branch '990-add-sample_batched-and-log_prob_batched-to-posteriors' into amortizedsample"

This reverts commit 07084e2, reversing
changes made to f16622d.

* fixes current bug!

* Added tests

* batched_rejection_sampling

* sample works, try log_prob_batched

* log_prob_batched works

* abstract method implement for other methods

* temp fix mcmcposterior

* meh for general use i.e. in the restriction prior we have to add some reshapes in rejection

* ... test class

* Revert "Base estimator class"

This reverts commit 17c5343.

* removing previous change

* removing some artifacts

* revert weird change

* docs and tests

* MCMC sample_batched works but not log_prob batched

* adding some docs

* batch_log_prob for MCMC requires at best changes for potential -> removed

* Fixing bug from rebase...

* tracking all acceptance rates

* Comment on NFlows

* Also testing SNRE batched sampling, Need to test ensemble implementation

* fix bug

* Ensemble sample_batched is working (with tests)

* GPU compatibility

* restriction prior requires float as output of accept_reject

* Adding a few comments

* 2d sample_shape tests

* Apply suggestions from code review

Co-authored-by: Jan <[email protected]>

* Adding comment about squeeze

* Formatting new mcmc branch

* mcmc sample batched for likelihood estimator

* batch sampling for snpe,snre

* ruff fixes after merge

* pytest not catching xfail

* mcmc_posterior sample_batched disappeared in merge

* move mcmc chain shape handling to mcmcposterior away from potentials

* batched init strategies for mcmc

* update ratio_based_potential for new RatioEstimator class

* mcmc sample shape out fix and process_x utils

* suggestions from jan

* warning on batched x

---------

Co-authored-by: michaeldeistler <[email protected]>
Co-authored-by: Jan Boelts <[email protected]>
Co-authored-by: Jan <[email protected]>
Co-authored-by: Guy Moss <[email protected]>
Co-authored-by: Guy Moss <[email protected]>
  • Loading branch information
6 people authored Jul 30, 2024
1 parent e86c761 commit 81fffcf
Show file tree
Hide file tree
Showing 15 changed files with 388 additions and 131 deletions.
2 changes: 1 addition & 1 deletion sbi/inference/abc/mcabc.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def simulator(theta):
self.x_o = process_x(x_o, self.x_shape)
else:
self.x_shape = x[0, 0].shape
self.x_o = process_x(x_o, self.x_shape, allow_iid_x=True)
self.x_o = process_x(x_o, self.x_shape)

distances = self.distance(self.x_o, x)

Expand Down
4 changes: 1 addition & 3 deletions sbi/inference/abc/smcabc.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,9 +389,7 @@ def _set_xo_and_sample_initial_population(
self.x_shape = x[0].shape
else:
self.x_shape = x[0, 0].shape
self.x_o = process_x(
x_o, self.x_shape, allow_iid_x=self.distance.requires_iid_data
)
self.x_o = process_x(x_o, self.x_shape)

distances = self.distance(self.x_o, x)
sortidx = torch.argsort(distances)
Expand Down
8 changes: 2 additions & 6 deletions sbi/inference/posteriors/base_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,19 +163,15 @@ def set_default_x(self, x: Tensor) -> "NeuralPosterior":
Returns:
`NeuralPosterior` that will use a default `x` when not explicitly passed.
"""
self._x = process_x(
x, x_event_shape=None, allow_iid_x=self.potential_fn.allow_iid_x
).to(self._device)
self._x = process_x(x, x_event_shape=None).to(self._device)
self._map = None
return self

def _x_else_default_x(self, x: Optional[Array]) -> Tensor:
if x is not None:
# New x, reset posterior sampler.
self._posterior_sampler = None
return process_x(
x, x_event_shape=None, allow_iid_x=self.potential_fn.allow_iid_x
)
return process_x(x, x_event_shape=None)
elif self.default_x is None:
raise ValueError(
"Context `x` needed when a default has not been set."
Expand Down
6 changes: 2 additions & 4 deletions sbi/inference/posteriors/ensemble_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,9 +265,7 @@ def set_default_x(self, x: Tensor) -> "NeuralPosterior":
`EnsemblePosterior` that will use a default `x` when not explicitly
passed.
"""
self._x = process_x(
x, x_event_shape=None, allow_iid_x=self.potential_fn.allow_iid_x
).to(self._device)
self._x = process_x(x, x_event_shape=None).to(self._device)

for posterior in self.posteriors:
posterior.set_default_x(x)
Expand Down Expand Up @@ -433,7 +431,7 @@ def allow_iid_x(self) -> bool:
def set_x(self, x_o: Optional[Tensor]):
"""Check the shape of the observed data and, if valid, set it."""
if x_o is not None:
x_o = process_x(x_o, allow_iid_x=self.allow_iid_x).to( # type: ignore
x_o = process_x(x_o).to( # type: ignore
self.device
)
self._x_o = x_o
Expand Down
197 changes: 176 additions & 21 deletions sbi/inference/posteriors/mcmc_posterior.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from copy import deepcopy
from functools import partial
from math import ceil
from typing import Any, Callable, Dict, Optional, Union
Expand All @@ -20,6 +21,7 @@

from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.inference.potentials.base_potential import BasePotential
from sbi.neural_nets.density_estimators.shape_handling import reshape_to_batch_event
from sbi.samplers.mcmc import (
IterateParameters,
PyMCSampler,
Expand All @@ -30,7 +32,6 @@
sir_init,
)
from sbi.sbi_types import Shape, TorchTransform
from sbi.simulators.simutils import tqdm_joblib
from sbi.utils.potentialutils import pyro_potential_wrapper, transformed_potential
from sbi.utils.torchutils import ensure_theta_batched, tensor2numpy

Expand Down Expand Up @@ -245,6 +246,7 @@ def sample(
Returns:
Samples from posterior.
"""

self.potential_fn.set_x(self._x_else_default_x(x))

# Replace arguments that were not passed with their default.
Expand Down Expand Up @@ -321,6 +323,7 @@ def sample(
thin=thin, # type: ignore
warmup_steps=warmup_steps, # type: ignore
vectorized=(method == "slice_np_vectorized"),
interchangeable_chains=True,
num_workers=num_workers,
show_progress_bars=show_progress_bars,
)
Expand Down Expand Up @@ -391,11 +394,82 @@ def sample_batched(
Samples from the posteriors of shape (*sample_shape, B, *input_shape)
"""

# See #1176 for a discussion on the implementation of batched sampling.
raise NotImplementedError(
"Batched sampling is not implemented for MCMC posterior. \
Alternatively you can use `sample` in a loop \
[posterior.sample(theta, x_o) for x_o in x]."
# Replace arguments that were not passed with their default.
method = self.method if method is None else method
thin = self.thin if thin is None else thin
warmup_steps = self.warmup_steps if warmup_steps is None else warmup_steps
num_chains = self.num_chains if num_chains is None else num_chains
init_strategy = self.init_strategy if init_strategy is None else init_strategy
num_workers = self.num_workers if num_workers is None else num_workers
mp_context = self.mp_context if mp_context is None else mp_context
init_strategy_parameters = (
self.init_strategy_parameters
if init_strategy_parameters is None
else init_strategy_parameters
)

assert (
method == "slice_np_vectorized"
), "Batched sampling only supported for vectorized samplers!"

# custom shape handling to make sure to match the batch size of x and theta
# without unnecessary combinations.
if len(x.shape) == 1:
x = x.unsqueeze(0)
batch_size = x.shape[0]

x = reshape_to_batch_event(x, event_shape=x.shape[1:])

# For batched sampling, we want `num_chains` for each observation in the batch.
# Here we repeat the observations ABC -> AAABBBCCC, so that the chains are
# in the order of the observations.
x_ = x.repeat_interleave(num_chains, dim=0)

self.potential_fn.set_x(x_, x_is_iid=False)
self.potential_ = self._prepare_potential(method) # type: ignore

# For each observation in the batch, we have num_chains independent chains.
num_chains_extended = batch_size * num_chains
init_strategy_parameters["num_return_samples"] = num_chains_extended
initial_params = self._get_initial_params_batched(
x,
init_strategy, # type: ignore
num_chains, # type: ignore
num_workers,
show_progress_bars,
**init_strategy_parameters,
)
# We need num_samples from each posterior in the batch
num_samples = torch.Size(sample_shape).numel() * batch_size

with torch.set_grad_enabled(False):
transformed_samples = self._slice_np_mcmc(
num_samples=num_samples,
potential_function=self.potential_,
initial_params=initial_params,
thin=thin, # type: ignore
warmup_steps=warmup_steps, # type: ignore
vectorized=(method == "slice_np_vectorized"),
interchangeable_chains=False,
num_workers=num_workers,
show_progress_bars=show_progress_bars,
)

samples = self.theta_transform.inv(transformed_samples)
sample_shape_len = len(sample_shape)
# The MCMC sampler returns the samples per chain, of shape
# (num_samples, num_chains_extended, *input_shape). We return the samples as `
# (*sample_shape, x_batch_size, *input_shape). This means we want to combine
# all the chains that belong to the same x. However, using
# samples.reshape(*sample_shape,batch_size,-1) does not combine the samples in
# the right order, since this mixes samples that belong to different `x`.
# This is a workaround to reshape the samples in the right order.
return samples.reshape((batch_size, *sample_shape, -1)).permute( # type: ignore
tuple(range(1, sample_shape_len + 1))
+ (
0,
-1,
)
)

def _build_mcmc_init_fn(
Expand Down Expand Up @@ -459,7 +533,7 @@ def _get_initial_params(
) -> Tensor:
"""Return initial parameters for MCMC obtained with given init strategy.
Parallelizes across CPU cores only for SIR.
Parallelizes across CPU cores only for resample and SIR.
Args:
init_strategy: Specifies the initialization method. Either of
Expand Down Expand Up @@ -491,25 +565,95 @@ def seeded_init_fn(seed):
seeds = torch.randint(high=2**31, size=(num_chains,))

# Generate initial params parallelized over num_workers.
with tqdm_joblib(
initial_params = list(
tqdm(
range(num_chains), # type: ignore
disable=not show_progress_bars,
desc=f"""Generating {num_chains} MCMC inits with {num_workers}
workers.""",
total=num_chains,
)
):
initial_params = torch.cat(
Parallel(n_jobs=num_workers)( # pyright: ignore[reportArgumentType]
Parallel(return_as="generator", n_jobs=num_workers)(
delayed(seeded_init_fn)(seed) for seed in seeds
)
),
total=len(seeds),
desc=f"""Generating {num_chains} MCMC inits with
{num_workers} workers.""",
disable=not show_progress_bars,
)
)
initial_params = torch.cat(initial_params) # type: ignore
else:
initial_params = torch.cat(
[init_fn() for _ in range(num_chains)] # type: ignore
)
return initial_params

def _get_initial_params_batched(
self,
x: torch.Tensor,
init_strategy: str,
num_chains_per_x: int,
num_workers: int,
show_progress_bars: bool,
**kwargs,
) -> Tensor:
"""Return initial parameters for MCMC for a batch of `x`, obtained with given
init strategy.
Parallelizes across CPU cores only for resample and SIR.
Args:
x: Batch of observations to create different initial parameters for.
init_strategy: Specifies the initialization method. Either of
[`proposal`|`sir`|`resample`|`latest_sample`].
num_chains_per_x: number of MCMC chains for each x, generates initial params
for each x
num_workers: number of CPU cores for parallization
show_progress_bars: whether to show progress bars for SIR init
kwargs: Passed on to `_build_mcmc_init_fn`.
Returns:
Tensor: initial parameters, one for each chain
"""

potential_ = deepcopy(self.potential_fn)
initial_params = []
init_fn = self._build_mcmc_init_fn(
self.proposal,
potential_fn=potential_,
transform=self.theta_transform,
init_strategy=init_strategy, # type: ignore
**kwargs,
)
for xi in x:
# Build init function
potential_.set_x(xi)

# Parallelize inits for resampling or sir.
if num_workers > 1 and (
init_strategy == "resample" or init_strategy == "sir"
):

def seeded_init_fn(seed):
torch.manual_seed(seed)
return init_fn()

seeds = torch.randint(high=2**31, size=(num_chains_per_x,))

# Generate initial params parallelized over num_workers.
initial_params = initial_params + list(
tqdm(
Parallel(return_as="generator", n_jobs=num_workers)(
delayed(seeded_init_fn)(seed) for seed in seeds
),
total=len(seeds),
desc=f"""Generating {num_chains_per_x} MCMC inits with
{num_workers} workers.""",
disable=not show_progress_bars,
)
)

else:
initial_params = initial_params + [
init_fn() for _ in range(num_chains_per_x)
] # type: ignore

initial_params = torch.cat(initial_params)
return initial_params

def _slice_np_mcmc(
Expand All @@ -520,6 +664,7 @@ def _slice_np_mcmc(
thin: int,
warmup_steps: int,
vectorized: bool = False,
interchangeable_chains=True,
num_workers: int = 1,
init_width: Union[float, ndarray] = 0.01,
show_progress_bars: bool = True,
Expand All @@ -534,6 +679,8 @@ def _slice_np_mcmc(
warmup_steps: Initial number of samples to discard.
vectorized: Whether to use a vectorized implementation of the
`SliceSampler`.
interchangeable_chains: Whether chains are interchangeable, i.e., whether
we can mix samples between chains.
num_workers: Number of CPU cores to use.
init_width: Inital width of brackets.
show_progress_bars: Whether to show a progressbar during sampling;
Expand All @@ -550,9 +697,14 @@ def _slice_np_mcmc(
else:
SliceSamplerMultiChain = SliceSamplerVectorized

def multi_obs_potential(params):
# Params are of shape (num_chains * num_obs, event).
all_potentials = potential_function(params) # Shape: (num_chains, num_obs)
return all_potentials.flatten()

posterior_sampler = SliceSamplerMultiChain(
init_params=tensor2numpy(initial_params),
log_prob_fn=potential_function,
log_prob_fn=multi_obs_potential,
num_chains=num_chains,
thin=thin,
verbose=show_progress_bars,
Expand All @@ -572,8 +724,11 @@ def _slice_np_mcmc(
# Save sample as potential next init (if init_strategy == 'latest_sample').
self._mcmc_init_params = samples[:, -1, :].reshape(num_chains, dim_samples)

# Collect samples from all chains.
samples = samples.reshape(-1, dim_samples)[:num_samples]
# Update: If chains are interchangeable, return concatenated samples. Otherwise
# return samples per chain.
if interchangeable_chains:
# Collect samples from all chains.
samples = samples.reshape(-1, dim_samples)[:num_samples]

return samples.type(torch.float32).to(self._device)

Expand Down
17 changes: 12 additions & 5 deletions sbi/inference/potentials/base_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,22 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
raise NotImplementedError

@property
@abstractmethod
def allow_iid_x(self) -> bool:
raise NotImplementedError
def x_is_iid(self) -> bool:
"""If x has batch dimension greater than 1, whether to intepret the batch as iid
samples or batch of data points."""
if self._x_is_iid is not None:
return self._x_is_iid
else:
raise ValueError(
"No observed data is available. Use `potential_fn.set_x(x_o)`."
)

def set_x(self, x_o: Optional[Tensor]):
def set_x(self, x_o: Optional[Tensor], x_is_iid: Optional[bool] = True):
"""Check the shape of the observed data and, if valid, set it."""
if x_o is not None:
x_o = process_x(x_o, allow_iid_x=self.allow_iid_x).to(self.device)
x_o = process_x(x_o).to(self.device)
self._x_o = x_o
self._x_is_iid = x_is_iid

@property
def x_o(self) -> Tensor:
Expand Down
Loading

0 comments on commit 81fffcf

Please sign in to comment.