From 81fffcffbbb7649b3d3ac7cfbf99efe0709de22a Mon Sep 17 00:00:00 2001 From: manuelgloeckler <38903899+manuelgloeckler@users.noreply.github.com> Date: Tue, 30 Jul 2024 11:24:40 +0200 Subject: [PATCH] feat: batched sampling for vectorized MCMC (#1176) * Base estimator class * intermediate commit * make autoreload work * `amortized_sample` works for MCMCPosterior * fixes current bug! * Added tests * batched_rejection_sampling * intermediate commit * make autoreload work * `amortized_sample` works for MCMCPosterior * Revert "Merge branch '990-add-sample_batched-and-log_prob_batched-to-posteriors' into amortizedsample" This reverts commit 07084e28fb586d43605dba6786d60c3e48ed96e5, reversing changes made to f16622d552e0dd69b17855bea9d672594e11d8ce. * sample works, try log_prob_batched * log_prob_batched works * abstract method implement for other methods * temp fix mcmcposterior * meh for general use i.e. in the restriction prior we have to add some reshapes in rejection * ... test class * Revert "Base estimator class" This reverts commit 17c534303343bd6306ea8e45fd4085a929ba42c2. * removing previous change * removing some artifacts * revert weird change * docs and tests * MCMC sample_batched works but not log_prob batched * adding some docs * batch_log_prob for MCMC requires at best changes for potential -> removed * intermediate commit * make autoreload work * `amortized_sample` works for MCMCPosterior * intermediate commit * make autoreload work * `amortized_sample` works for MCMCPosterior * Base estimator class * Revert "Merge branch '990-add-sample_batched-and-log_prob_batched-to-posteriors' into amortizedsample" This reverts commit 07084e28fb586d43605dba6786d60c3e48ed96e5, reversing changes made to f16622d552e0dd69b17855bea9d672594e11d8ce. * fixes current bug! * Added tests * batched_rejection_sampling * sample works, try log_prob_batched * log_prob_batched works * abstract method implement for other methods * temp fix mcmcposterior * meh for general use i.e. 
in the restriction prior we have to add some reshapes in rejection * ... test class * Revert "Base estimator class" This reverts commit 17c534303343bd6306ea8e45fd4085a929ba42c2. * removing previous change * removing some artifacts * revert weird change * docs and tests * MCMC sample_batched works but not log_prob batched * adding some docs * batch_log_prob for MCMC requires at best changes for potential -> removed * Fixing bug from rebase... * tracking all acceptance rates * Comment on NFlows * Also testing SNRE batched sampling, Need to test ensemble implementation * fig bug * Ensemble sample_batched is working (with tests) * GPU compatibility * restriction prior requires float as output of accept_reject * Adding a few comments * 2d sample_shape tests * Apply suggestions from code review Co-authored-by: Jan * Adding comment about squeeze * Formatting new mcmc branch * mcmc sample batched for likelihood estimator * batch sampling for snpe,snre * ruff fixes after merge * pytest not catching xfail * mcmc_posterior sample_batched disappeared in merge * move mcmc chain shape handling to mcmcposterior away from potentials * batched init strategies for mcmc * update ratio_based_potential for new RatioEstimator class * mcmc sample shape out fix and process_x utils * suggestions from jan * warning on batched x --------- Co-authored-by: michaeldeistler Co-authored-by: Jan Boelts Co-authored-by: Jan Co-authored-by: Guy Moss Co-authored-by: Guy Moss <91739128+gmoss13@users.noreply.github.com> --- sbi/inference/abc/mcabc.py | 2 +- sbi/inference/abc/smcabc.py | 4 +- sbi/inference/posteriors/base_posterior.py | 8 +- .../posteriors/ensemble_posterior.py | 6 +- sbi/inference/posteriors/mcmc_posterior.py | 197 ++++++++++++++++-- sbi/inference/potentials/base_potential.py | 17 +- .../potentials/likelihood_based_potential.py | 36 ++-- .../potentials/posterior_based_potential.py | 61 ++++-- .../potentials/ratio_based_potential.py | 38 ++-- sbi/utils/conditional_density_utils.py | 20 +- 
sbi/utils/sbiutils.py | 15 +- sbi/utils/user_input_checks.py | 12 +- tests/embedding_net_test.py | 27 ++- tests/posterior_nn_test.py | 58 ++++-- tests/user_input_checks_test.py | 18 +- 15 files changed, 388 insertions(+), 131 deletions(-) diff --git a/sbi/inference/abc/mcabc.py b/sbi/inference/abc/mcabc.py index 3381e31d7..5fd153fe9 100644 --- a/sbi/inference/abc/mcabc.py +++ b/sbi/inference/abc/mcabc.py @@ -176,7 +176,7 @@ def simulator(theta): self.x_o = process_x(x_o, self.x_shape) else: self.x_shape = x[0, 0].shape - self.x_o = process_x(x_o, self.x_shape, allow_iid_x=True) + self.x_o = process_x(x_o, self.x_shape) distances = self.distance(self.x_o, x) diff --git a/sbi/inference/abc/smcabc.py b/sbi/inference/abc/smcabc.py index 37cc38ada..a4766321a 100644 --- a/sbi/inference/abc/smcabc.py +++ b/sbi/inference/abc/smcabc.py @@ -389,9 +389,7 @@ def _set_xo_and_sample_initial_population( self.x_shape = x[0].shape else: self.x_shape = x[0, 0].shape - self.x_o = process_x( - x_o, self.x_shape, allow_iid_x=self.distance.requires_iid_data - ) + self.x_o = process_x(x_o, self.x_shape) distances = self.distance(self.x_o, x) sortidx = torch.argsort(distances) diff --git a/sbi/inference/posteriors/base_posterior.py b/sbi/inference/posteriors/base_posterior.py index 3f109e9c0..a4b9d49fa 100644 --- a/sbi/inference/posteriors/base_posterior.py +++ b/sbi/inference/posteriors/base_posterior.py @@ -163,9 +163,7 @@ def set_default_x(self, x: Tensor) -> "NeuralPosterior": Returns: `NeuralPosterior` that will use a default `x` when not explicitly passed. """ - self._x = process_x( - x, x_event_shape=None, allow_iid_x=self.potential_fn.allow_iid_x - ).to(self._device) + self._x = process_x(x, x_event_shape=None).to(self._device) self._map = None return self @@ -173,9 +171,7 @@ def _x_else_default_x(self, x: Optional[Array]) -> Tensor: if x is not None: # New x, reset posterior sampler. 
self._posterior_sampler = None - return process_x( - x, x_event_shape=None, allow_iid_x=self.potential_fn.allow_iid_x - ) + return process_x(x, x_event_shape=None) elif self.default_x is None: raise ValueError( "Context `x` needed when a default has not been set." diff --git a/sbi/inference/posteriors/ensemble_posterior.py b/sbi/inference/posteriors/ensemble_posterior.py index 53b2bdfaf..8270ebf59 100644 --- a/sbi/inference/posteriors/ensemble_posterior.py +++ b/sbi/inference/posteriors/ensemble_posterior.py @@ -265,9 +265,7 @@ def set_default_x(self, x: Tensor) -> "NeuralPosterior": `EnsemblePosterior` that will use a default `x` when not explicitly passed. """ - self._x = process_x( - x, x_event_shape=None, allow_iid_x=self.potential_fn.allow_iid_x - ).to(self._device) + self._x = process_x(x, x_event_shape=None).to(self._device) for posterior in self.posteriors: posterior.set_default_x(x) @@ -433,7 +431,7 @@ def allow_iid_x(self) -> bool: def set_x(self, x_o: Optional[Tensor]): """Check the shape of the observed data and, if valid, set it.""" if x_o is not None: - x_o = process_x(x_o, allow_iid_x=self.allow_iid_x).to( # type: ignore + x_o = process_x(x_o).to( # type: ignore self.device ) self._x_o = x_o diff --git a/sbi/inference/posteriors/mcmc_posterior.py b/sbi/inference/posteriors/mcmc_posterior.py index c7c829c5e..bfb7f6280 100644 --- a/sbi/inference/posteriors/mcmc_posterior.py +++ b/sbi/inference/posteriors/mcmc_posterior.py @@ -1,6 +1,7 @@ # This file is part of sbi, a toolkit for simulation-based inference. 
sbi is licensed # under the Apache License Version 2.0, see +from copy import deepcopy from functools import partial from math import ceil from typing import Any, Callable, Dict, Optional, Union @@ -20,6 +21,7 @@ from sbi.inference.posteriors.base_posterior import NeuralPosterior from sbi.inference.potentials.base_potential import BasePotential +from sbi.neural_nets.density_estimators.shape_handling import reshape_to_batch_event from sbi.samplers.mcmc import ( IterateParameters, PyMCSampler, @@ -30,7 +32,6 @@ sir_init, ) from sbi.sbi_types import Shape, TorchTransform -from sbi.simulators.simutils import tqdm_joblib from sbi.utils.potentialutils import pyro_potential_wrapper, transformed_potential from sbi.utils.torchutils import ensure_theta_batched, tensor2numpy @@ -245,6 +246,7 @@ def sample( Returns: Samples from posterior. """ + self.potential_fn.set_x(self._x_else_default_x(x)) # Replace arguments that were not passed with their default. @@ -321,6 +323,7 @@ def sample( thin=thin, # type: ignore warmup_steps=warmup_steps, # type: ignore vectorized=(method == "slice_np_vectorized"), + interchangeable_chains=True, num_workers=num_workers, show_progress_bars=show_progress_bars, ) @@ -391,11 +394,82 @@ def sample_batched( Samples from the posteriors of shape (*sample_shape, B, *input_shape) """ - # See #1176 for a discussion on the implementation of batched sampling. - raise NotImplementedError( - "Batched sampling is not implemented for MCMC posterior. \ - Alternatively you can use `sample` in a loop \ - [posterior.sample(theta, x_o) for x_o in x]." + # Replace arguments that were not passed with their default. 
+ method = self.method if method is None else method + thin = self.thin if thin is None else thin + warmup_steps = self.warmup_steps if warmup_steps is None else warmup_steps + num_chains = self.num_chains if num_chains is None else num_chains + init_strategy = self.init_strategy if init_strategy is None else init_strategy + num_workers = self.num_workers if num_workers is None else num_workers + mp_context = self.mp_context if mp_context is None else mp_context + init_strategy_parameters = ( + self.init_strategy_parameters + if init_strategy_parameters is None + else init_strategy_parameters + ) + + assert ( + method == "slice_np_vectorized" + ), "Batched sampling only supported for vectorized samplers!" + + # custom shape handling to make sure to match the batch size of x and theta + # without unnecessary combinations. + if len(x.shape) == 1: + x = x.unsqueeze(0) + batch_size = x.shape[0] + + x = reshape_to_batch_event(x, event_shape=x.shape[1:]) + + # For batched sampling, we want `num_chains` for each observation in the batch. + # Here we repeat the observations ABC -> AAABBBCCC, so that the chains are + # in the order of the observations. + x_ = x.repeat_interleave(num_chains, dim=0) + + self.potential_fn.set_x(x_, x_is_iid=False) + self.potential_ = self._prepare_potential(method) # type: ignore + + # For each observation in the batch, we have num_chains independent chains. 
+ num_chains_extended = batch_size * num_chains + init_strategy_parameters["num_return_samples"] = num_chains_extended + initial_params = self._get_initial_params_batched( + x, + init_strategy, # type: ignore + num_chains, # type: ignore + num_workers, + show_progress_bars, + **init_strategy_parameters, + ) + # We need num_samples from each posterior in the batch + num_samples = torch.Size(sample_shape).numel() * batch_size + + with torch.set_grad_enabled(False): + transformed_samples = self._slice_np_mcmc( + num_samples=num_samples, + potential_function=self.potential_, + initial_params=initial_params, + thin=thin, # type: ignore + warmup_steps=warmup_steps, # type: ignore + vectorized=(method == "slice_np_vectorized"), + interchangeable_chains=False, + num_workers=num_workers, + show_progress_bars=show_progress_bars, + ) + + samples = self.theta_transform.inv(transformed_samples) + sample_shape_len = len(sample_shape) + # The MCMC sampler returns the samples per chain, of shape + # (num_samples, num_chains_extended, *input_shape). We return the samples as ` + # (*sample_shape, x_batch_size, *input_shape). This means we want to combine + # all the chains that belong to the same x. However, using + # samples.reshape(*sample_shape,batch_size,-1) does not combine the samples in + # the right order, since this mixes samples that belong to different `x`. + # This is a workaround to reshape the samples in the right order. + return samples.reshape((batch_size, *sample_shape, -1)).permute( # type: ignore + tuple(range(1, sample_shape_len + 1)) + + ( + 0, + -1, + ) ) def _build_mcmc_init_fn( @@ -459,7 +533,7 @@ def _get_initial_params( ) -> Tensor: """Return initial parameters for MCMC obtained with given init strategy. - Parallelizes across CPU cores only for SIR. + Parallelizes across CPU cores only for resample and SIR. Args: init_strategy: Specifies the initialization method. 
Either of @@ -491,25 +565,95 @@ def seeded_init_fn(seed): seeds = torch.randint(high=2**31, size=(num_chains,)) # Generate initial params parallelized over num_workers. - with tqdm_joblib( + initial_params = list( tqdm( - range(num_chains), # type: ignore - disable=not show_progress_bars, - desc=f"""Generating {num_chains} MCMC inits with {num_workers} - workers.""", - total=num_chains, - ) - ): - initial_params = torch.cat( - Parallel(n_jobs=num_workers)( # pyright: ignore[reportArgumentType] + Parallel(return_as="generator", n_jobs=num_workers)( delayed(seeded_init_fn)(seed) for seed in seeds - ) + ), + total=len(seeds), + desc=f"""Generating {num_chains} MCMC inits with + {num_workers} workers.""", + disable=not show_progress_bars, ) + ) + initial_params = torch.cat(initial_params) # type: ignore else: initial_params = torch.cat( [init_fn() for _ in range(num_chains)] # type: ignore ) + return initial_params + + def _get_initial_params_batched( + self, + x: torch.Tensor, + init_strategy: str, + num_chains_per_x: int, + num_workers: int, + show_progress_bars: bool, + **kwargs, + ) -> Tensor: + """Return initial parameters for MCMC for a batch of `x`, obtained with given + init strategy. + + Parallelizes across CPU cores only for resample and SIR. + + Args: + x: Batch of observations to create different initial parameters for. + init_strategy: Specifies the initialization method. Either of + [`proposal`|`sir`|`resample`|`latest_sample`]. + num_chains_per_x: number of MCMC chains for each x, generates initial params + for each x + num_workers: number of CPU cores for parallization + show_progress_bars: whether to show progress bars for SIR init + kwargs: Passed on to `_build_mcmc_init_fn`. 
+ + Returns: + Tensor: initial parameters, one for each chain + """ + + potential_ = deepcopy(self.potential_fn) + initial_params = [] + init_fn = self._build_mcmc_init_fn( + self.proposal, + potential_fn=potential_, + transform=self.theta_transform, + init_strategy=init_strategy, # type: ignore + **kwargs, + ) + for xi in x: + # Build init function + potential_.set_x(xi) + + # Parallelize inits for resampling or sir. + if num_workers > 1 and ( + init_strategy == "resample" or init_strategy == "sir" + ): + def seeded_init_fn(seed): + torch.manual_seed(seed) + return init_fn() + + seeds = torch.randint(high=2**31, size=(num_chains_per_x,)) + + # Generate initial params parallelized over num_workers. + initial_params = initial_params + list( + tqdm( + Parallel(return_as="generator", n_jobs=num_workers)( + delayed(seeded_init_fn)(seed) for seed in seeds + ), + total=len(seeds), + desc=f"""Generating {num_chains_per_x} MCMC inits with + {num_workers} workers.""", + disable=not show_progress_bars, + ) + ) + + else: + initial_params = initial_params + [ + init_fn() for _ in range(num_chains_per_x) + ] # type: ignore + + initial_params = torch.cat(initial_params) return initial_params def _slice_np_mcmc( @@ -520,6 +664,7 @@ def _slice_np_mcmc( thin: int, warmup_steps: int, vectorized: bool = False, + interchangeable_chains=True, num_workers: int = 1, init_width: Union[float, ndarray] = 0.01, show_progress_bars: bool = True, @@ -534,6 +679,8 @@ def _slice_np_mcmc( warmup_steps: Initial number of samples to discard. vectorized: Whether to use a vectorized implementation of the `SliceSampler`. + interchangeable_chains: Whether chains are interchangeable, i.e., whether + we can mix samples between chains. num_workers: Number of CPU cores to use. init_width: Inital width of brackets. 
show_progress_bars: Whether to show a progressbar during sampling; @@ -550,9 +697,14 @@ def _slice_np_mcmc( else: SliceSamplerMultiChain = SliceSamplerVectorized + def multi_obs_potential(params): + # Params are of shape (num_chains * num_obs, event). + all_potentials = potential_function(params) # Shape: (num_chains, num_obs) + return all_potentials.flatten() + posterior_sampler = SliceSamplerMultiChain( init_params=tensor2numpy(initial_params), - log_prob_fn=potential_function, + log_prob_fn=multi_obs_potential, num_chains=num_chains, thin=thin, verbose=show_progress_bars, @@ -572,8 +724,11 @@ def _slice_np_mcmc( # Save sample as potential next init (if init_strategy == 'latest_sample'). self._mcmc_init_params = samples[:, -1, :].reshape(num_chains, dim_samples) - # Collect samples from all chains. - samples = samples.reshape(-1, dim_samples)[:num_samples] + # Update: If chains are interchangeable, return concatenated samples. Otherwise + # return samples per chain. + if interchangeable_chains: + # Collect samples from all chains. + samples = samples.reshape(-1, dim_samples)[:num_samples] return samples.type(torch.float32).to(self._device) diff --git a/sbi/inference/potentials/base_potential.py b/sbi/inference/potentials/base_potential.py index 7bf914eb6..769031321 100644 --- a/sbi/inference/potentials/base_potential.py +++ b/sbi/inference/potentials/base_potential.py @@ -36,15 +36,22 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor: raise NotImplementedError @property - @abstractmethod - def allow_iid_x(self) -> bool: - raise NotImplementedError + def x_is_iid(self) -> bool: + """If x has batch dimension greater than 1, whether to intepret the batch as iid + samples or batch of data points.""" + if self._x_is_iid is not None: + return self._x_is_iid + else: + raise ValueError( + "No observed data is available. Use `potential_fn.set_x(x_o)`." 
+ ) - def set_x(self, x_o: Optional[Tensor]): + def set_x(self, x_o: Optional[Tensor], x_is_iid: Optional[bool] = True): """Check the shape of the observed data and, if valid, set it.""" if x_o is not None: - x_o = process_x(x_o, allow_iid_x=self.allow_iid_x).to(self.device) + x_o = process_x(x_o).to(self.device) self._x_o = x_o + self._x_is_iid = x_is_iid @property def x_o(self) -> Tensor: diff --git a/sbi/inference/potentials/likelihood_based_potential.py b/sbi/inference/potentials/likelihood_based_potential.py index a6efd99bc..c824a5dc5 100644 --- a/sbi/inference/potentials/likelihood_based_potential.py +++ b/sbi/inference/potentials/likelihood_based_potential.py @@ -54,8 +54,6 @@ def likelihood_estimator_based_potential( class LikelihoodBasedPotential(BasePotential): - allow_iid_x = True # type: ignore - def __init__( self, likelihood_estimator: ConditionalDensityEstimator, @@ -90,16 +88,30 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor: Returns: The potential $\log(p(x_o|\theta)p(\theta))$. """ - - # Calculate likelihood over trials and in one batch. - log_likelihood_trial_sum = _log_likelihoods_over_trials( - x=self.x_o, - theta=theta.to(self.device), - estimator=self.likelihood_estimator, - track_gradients=track_gradients, - ) - - return log_likelihood_trial_sum + self.prior.log_prob(theta) # type: ignore + if self.x_is_iid: + # For each theta, calculate the likelihood sum over all x in batch. 
+ log_likelihood_trial_sum = _log_likelihoods_over_trials( + x=self.x_o, + theta=theta.to(self.device), + estimator=self.likelihood_estimator, + track_gradients=track_gradients, + ) + return log_likelihood_trial_sum + self.prior.log_prob(theta) # type: ignore + else: + # Calculate likelihood for each (theta,x) pair separately + theta_batch_size = theta.shape[0] + x_batch_size = self.x_o.shape[0] + assert ( + theta_batch_size == x_batch_size + ), f"Batch size mismatch: {theta_batch_size} and {x_batch_size}.\ + When performing batched sampling for multiple `x`, the batch size of\ + `theta` must match the batch size of `x`." + x = self.x_o.unsqueeze(0) + with torch.set_grad_enabled(track_gradients): + log_likelihood_batches = self.likelihood_estimator.log_prob( + x, condition=theta + ) + return log_likelihood_batches + self.prior.log_prob(theta) # type: ignore def _log_likelihoods_over_trials( diff --git a/sbi/inference/potentials/posterior_based_potential.py b/sbi/inference/potentials/posterior_based_potential.py index f4cb64ab7..3dead6ade 100644 --- a/sbi/inference/potentials/posterior_based_potential.py +++ b/sbi/inference/potentials/posterior_based_potential.py @@ -58,8 +58,6 @@ def posterior_estimator_based_potential( class PosteriorBasedPotential(BasePotential): - allow_iid_x = False # type: ignore - def __init__( self, posterior_estimator: ConditionalDensityEstimator, @@ -84,6 +82,21 @@ def __init__( self.posterior_estimator = posterior_estimator self.posterior_estimator.eval() + def set_x(self, x_o: Optional[Tensor], x_is_iid: Optional[bool] = False): + """ + Check the shape of the observed data and, if valid, set it. + For posterior-based methods, `x_o` is not allowed to be iid, as we assume that + iid `x` is handled by a Permutation Invariant embedding net. + """ + if x_is_iid: + raise NotImplementedError( + "For NPE, iid `x` must be handled by a Permutation Invariant embedding \ + net. 
Therefore, the iid dimension of `x` is added to the event\ + dimension of `x`. Please set `x_is_iid=False`." + ) + else: + super().set_x(x_o, x_is_iid=False) + def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor: r"""Returns the potential for posterior-based methods. @@ -101,28 +114,44 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor: the potential or manually set self._x_o." ) - theta = ensure_theta_batched(torch.as_tensor(theta)).to(self.device) - with torch.set_grad_enabled(track_gradients): # Force probability to be zero outside prior support. in_prior_support = within_support(self.prior, theta) - - x = reshape_to_batch_event(self.x_o, event_shape=self.x_o.shape[1:]) - assert ( - x.shape[0] == 1 - ), f"`x` has batchsize {x.shape[0]}. Only `batchsize == 1` is supported." - theta = reshape_to_sample_batch_event( - theta, event_shape=theta.shape[1:], leading_is_sample=True + x = reshape_to_batch_event( + self.x_o, event_shape=self.posterior_estimator.condition_shape ) - # We assume that a single `x` is passed (i.e. batchsize==1), so we squeeze - # the batch dimension of the log-prob with `.squeeze(dim=1)`. - posterior_log_prob = self.posterior_estimator.log_prob( - theta, condition=x - ).squeeze(dim=1) + theta = ensure_theta_batched(torch.as_tensor(theta)).to(self.device) + theta_batch_size = theta.shape[0] + x_batch_size = x.shape[0] + assert ( + theta_batch_size == x_batch_size or x_batch_size == 1 + ), f"Batch size mismatch: {theta_batch_size} and {x_batch_size}.\ + When performing batched sampling for multiple `x`, the batch size of\ + `theta` must match the batch size of `x`." + + if x_batch_size == 1: + # If a single `x` is passed (i.e. batchsize==1), we squeeze + # the batch dimension of the log-prob with `.squeeze(dim=1)`. 
+ theta = reshape_to_sample_batch_event( + theta, event_shape=theta.shape[1:], leading_is_sample=True + ) + + posterior_log_prob = self.posterior_estimator.log_prob( + theta, condition=x + ) + posterior_log_prob = posterior_log_prob.squeeze(1) + else: + # If multiple `x` are passed, we return the log-probs for each (x,theta) + # pair, and do not squeeze the batch dimension. + theta = theta.unsqueeze(0) + posterior_log_prob = self.posterior_estimator.log_prob( + theta, condition=x + ) posterior_log_prob = torch.where( in_prior_support, posterior_log_prob, torch.tensor(float("-inf"), dtype=torch.float32, device=self.device), ) + return posterior_log_prob diff --git a/sbi/inference/potentials/ratio_based_potential.py b/sbi/inference/potentials/ratio_based_potential.py index a246b225d..9a8e85ddb 100644 --- a/sbi/inference/potentials/ratio_based_potential.py +++ b/sbi/inference/potentials/ratio_based_potential.py @@ -47,8 +47,6 @@ def ratio_estimator_based_potential( class RatioBasedPotential(BasePotential): - allow_iid_x = True # type: ignore - def __init__( self, ratio_estimator: nn.Module, @@ -81,17 +79,31 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor: Returns: The potential. """ - - # Calculate likelihood over trials and in one batch. - log_likelihood_trial_sum = _log_ratios_over_trials( - x=self.x_o, - theta=theta.to(self.device), - net=self.ratio_estimator, - track_gradients=track_gradients, - ) - - # Move to cpu for comparison with prior. - return log_likelihood_trial_sum + self.prior.log_prob(theta) # type: ignore + if self.x_is_iid: + # For each theta, calculate likelihood ratio sum over all x in batch. + log_ratio_trial_sum = _log_ratios_over_trials( + x=self.x_o, + theta=theta.to(self.device), + net=self.ratio_estimator, + track_gradients=track_gradients, + ) + + # Move to cpu for comparison with prior. 
+ return log_ratio_trial_sum + self.prior.log_prob(theta) # type: ignore + else: + # Calculate likelihood ratio for each (theta,x) pair separately + + theta_batch_size = theta.shape[0] + x_batch_size = self.x_o.shape[0] + assert ( + theta_batch_size == x_batch_size + ), f"Batch size mismatch: {theta_batch_size} and {x_batch_size}.\ + When performing batched sampling for multiple `x`, the batch size of\ + `theta` must match the batch size of `x`." + with torch.set_grad_enabled(track_gradients): + log_ratio_batches = self.ratio_estimator(theta, self.x_o) + log_ratio_batches = log_ratio_batches.reshape(-1) + return log_ratio_batches + self.prior.log_prob(theta) # type: ignore def _log_ratios_over_trials( diff --git a/sbi/utils/conditional_density_utils.py b/sbi/utils/conditional_density_utils.py index 8f53889d3..2c875247e 100644 --- a/sbi/utils/conditional_density_utils.py +++ b/sbi/utils/conditional_density_utils.py @@ -276,7 +276,6 @@ def __init__( potential_fn: Callable, condition: Tensor, dims_to_sample: List[int], - allow_iid_x: bool = False, ): r""" Return conditional posterior log-probability or $-\infty$ if outside prior. @@ -292,7 +291,6 @@ def __init__( self.condition = condition self.dims_to_sample = dims_to_sample self.device = self.potential_fn.device - self.allow_iid_x = allow_iid_x def __call__( self, theta: Tensor, x_o: Optional[Tensor] = None, track_gradients: bool = True @@ -323,11 +321,23 @@ def __call__( return self.potential_fn(theta_condition, track_gradients=track_gradients) - def set_x(self, x_o: Optional[Tensor]): + @property + def x_is_iid(self) -> bool: + """If x has batch dimension greater than 1, whether to intepret the batch as iid + samples or batch of data points.""" + if self._x_is_iid is not None: + return self._x_is_iid + else: + raise ValueError( + "No observed data is available. Use `potential_fn.set_x(x_o)`." 
+ ) + + def set_x(self, x_o: Optional[Tensor], x_is_iid: Optional[bool] = True): """Check the shape of the observed data and, if valid, set it.""" if x_o is not None: - x_o = process_x(x_o, allow_iid_x=self.allow_iid_x).to(self.device) - self.potential_fn.set_x(x_o) + x_o = process_x(x_o).to(self.device) + self._x_is_iid = x_is_iid + self.potential_fn.set_x(x_o, x_is_iid=x_is_iid) @property def x_o(self) -> Tensor: diff --git a/sbi/utils/sbiutils.py b/sbi/utils/sbiutils.py index 7f3195cfa..bdc57e068 100644 --- a/sbi/utils/sbiutils.py +++ b/sbi/utils/sbiutils.py @@ -400,16 +400,17 @@ def nle_nre_apt_msg_on_invalid_x( ) -def warn_on_iid_x(num_trials): +def warn_on_batched_x(batch_size): """Warn if more than one x was passed.""" - if num_trials > 1: + if batch_size > 1: warnings.warn( - f"An x with a batch size of {num_trials} was passed. " - + """It will be interpreted as a batch of independent and identically - distributed data X={x_1, ..., x_n}, i.e., data generated based on the - same underlying (unknown) parameter. The resulting posterior will be with - respect to entire batch, i.e,. p(theta | X).""", + f"An x with a batch size of {batch_size} was passed. " + + """Unless you are using `sample_batched` or `log_prob_batched`, this will + be interpreted as a batch of independent and identically distributed data + X={x_1, ..., x_n}, i.e., data generated based on the same underlying + (unknown) parameter. The resulting posterior will be with respect to entire + batch, i.e,. 
p(theta | X).""", stacklevel=2, ) diff --git a/sbi/utils/user_input_checks.py b/sbi/utils/user_input_checks.py index 370a9376e..5b4da0923 100644 --- a/sbi/utils/user_input_checks.py +++ b/sbi/utils/user_input_checks.py @@ -12,7 +12,7 @@ from torch.distributions import Distribution, Uniform from sbi.sbi_types import Array -from sbi.utils.sbiutils import warn_on_iid_x, within_support +from sbi.utils.sbiutils import warn_on_batched_x, within_support from sbi.utils.torchutils import BoxUniform, atleast_2d from sbi.utils.user_input_checks_utils import ( CustomPriorWrapper, @@ -551,9 +551,7 @@ def batch_loop_simulator(theta: Tensor) -> Tensor: return batch_loop_simulator -def process_x( - x: Array, x_event_shape: Optional[torch.Size] = None, allow_iid_x: bool = False -) -> Tensor: +def process_x(x: Array, x_event_shape: Optional[torch.Size] = None) -> Tensor: """Return observed data adapted to match sbi's shape and type requirements. This means that `x` is returned with a `batch_dim`. @@ -565,7 +563,6 @@ def process_x( x_event_shape: Prescribed shape - either directly provided by the user at init or inferred by sbi by running a simulation and checking the output. Does not contain a batch dimension. - allow_iid_x: Whether multiple trials in x are allowed. Returns: x: Observed data with shape ready for usage in sbi. 
@@ -585,10 +582,7 @@ def process_x( x = x.unsqueeze(0) input_x_shape = x.shape - if not allow_iid_x: - check_for_possibly_batched_x_shape(input_x_shape) - else: - warn_on_iid_x(num_trials=input_x_shape[0]) + warn_on_batched_x(batch_size=input_x_shape[0]) if x_event_shape is not None: # Number of trials can change for every new x, but single trial x shape must diff --git a/tests/embedding_net_test.py b/tests/embedding_net_test.py index 402211161..6b3d1e271 100644 --- a/tests/embedding_net_test.py +++ b/tests/embedding_net_test.py @@ -75,9 +75,13 @@ def test_embedding_net_api( _ = posterior.potential(s) +@pytest.mark.parametrize("num_xo_batch", [1, 2]) @pytest.mark.parametrize("num_trials", [1, 2]) @pytest.mark.parametrize("num_dim", [1, 2]) -def test_embedding_api_with_multiple_trials(num_trials, num_dim): +@pytest.mark.parametrize("posterior_method", ["direct", "mcmc"]) +def test_embedding_api_with_multiple_trials( + num_xo_batch, num_trials, num_dim, posterior_method +): """Tests the API when using iid trial-based data.""" prior = utils.BoxUniform(-2.0 * ones(num_dim), 2.0 * ones(num_dim)) @@ -87,7 +91,7 @@ def test_embedding_api_with_multiple_trials(num_trials, num_dim): # simulate iid x. 
iid_theta = theta.reshape(num_thetas, 1, num_dim).repeat(1, num_trials, 1) x = torch.randn_like(iid_theta) + iid_theta - x_o = zeros(1, num_trials, num_dim) + x_o = zeros(num_xo_batch, num_trials, num_dim) output_dim = 5 single_trial_net = FCEmbedding(input_dim=num_dim, output_dim=output_dim) @@ -101,10 +105,21 @@ def test_embedding_api_with_multiple_trials(num_trials, num_dim): _ = inference.append_simulations(theta, x).train(max_num_epochs=5) - posterior = inference.build_posterior().set_default_x(x_o) - - s = posterior.sample((1,)) - _ = posterior.potential(s) + if posterior_method == "direct": + posterior = inference.build_posterior().set_default_x(x_o) + elif posterior_method == "mcmc": + posterior = inference.build_posterior( + sample_with=posterior_method, + mcmc_method="slice_np_vectorized", + ).set_default_x(x_o) + if num_xo_batch == 1: + s = posterior.sample((1,), x=x_o) + _ = posterior.potential(s) + else: + s = posterior.sample_batched((1,), x=x_o).squeeze(0) + # potentials take `theta` as (batch_shape, event_shape), so squeeze sample_dim + s = s.squeeze(0) + _ = posterior.potential(s) @pytest.mark.parametrize("input_shape", [(32,), (32, 32), (32, 64)]) diff --git a/tests/posterior_nn_test.py b/tests/posterior_nn_test.py index 98694bc3f..7a98c1020 100644 --- a/tests/posterior_nn_test.py +++ b/tests/posterior_nn_test.py @@ -4,6 +4,7 @@ from __future__ import annotations import pytest +import torch from torch import eye, ones, zeros from torch.distributions import MultivariateNormal @@ -25,7 +26,7 @@ ( 0, 1, - pytest.param(2, marks=pytest.mark.xfail(raises=ValueError)), + pytest.param(2, marks=pytest.mark.xfail(raises=AssertionError)), ), ) def test_log_prob_with_different_x(snpe_method: type, x_o_batch_dim: bool): @@ -117,18 +118,11 @@ def test_batched_sample_log_prob_with_different_x( @pytest.mark.mcmc -@pytest.mark.parametrize( - "snlre_method", - [ - pytest.param(SNLE_A, marks=pytest.mark.xfail(raises=NotImplementedError)), - pytest.param(SNRE_A, 
marks=pytest.mark.xfail(raises=NotImplementedError)), - pytest.param(SNRE_B, marks=pytest.mark.xfail(raises=NotImplementedError)), - pytest.param(SNRE_C, marks=pytest.mark.xfail(raises=NotImplementedError)), - ], -) +@pytest.mark.parametrize("snlre_method", [SNLE_A, SNRE_A, SNRE_B, SNRE_C, SNPE_C]) @pytest.mark.parametrize("x_o_batch_dim", (0, 1, 2)) +@pytest.mark.parametrize("init_strategy", ["proposal", "resample"]) def test_batched_mcmc_sample_log_prob_with_different_x( - snlre_method: type, x_o_batch_dim: bool, mcmc_params_fast: dict + snlre_method: type, x_o_batch_dim: bool, mcmc_params_fast: dict, init_strategy: str ): num_dim = 2 num_simulations = 1000 @@ -144,13 +138,51 @@ def test_batched_mcmc_sample_log_prob_with_different_x( x_o = ones(num_dim) if x_o_batch_dim == 0 else ones(x_o_batch_dim, num_dim) posterior = inference.build_posterior( - mcmc_method="slice_np_vectorized", mcmc_parameters=mcmc_params_fast + sample_with="mcmc", + mcmc_method="slice_np_vectorized", + mcmc_parameters=mcmc_params_fast, ) - samples = posterior.sample_batched((10,), x_o) + samples = posterior.sample_batched( + (10,), + x_o, + init_strategy=init_strategy, + num_chains=2, + ) assert ( samples.shape == (10, x_o_batch_dim, num_dim) if x_o_batch_dim > 0 else (10, num_dim) ), "Sample shape wrong" + + if x_o_batch_dim > 1: + assert samples.shape[1] == x_o_batch_dim, "Batch dimension wrong" + inference = snlre_method(prior=prior) + _ = inference.append_simulations(theta, x).train() + posterior = inference.build_posterior( + sample_with="mcmc", + mcmc_method="slice_np_vectorized", + mcmc_parameters=mcmc_params_fast, + ) + + x_o = torch.stack([0.5 * ones(num_dim), -0.5 * ones(num_dim)], dim=0) + # test with multiple chains to test whether correct chains are concatenated. 
+ samples = posterior.sample_batched((1000,), x_o, num_chains=2, warmup_steps=500) + + samples_separate1 = posterior.sample( + (1000,), x_o[0], num_chains=2, warmup_steps=500 + ) + samples_separate2 = posterior.sample( + (1000,), x_o[1], num_chains=2, warmup_steps=500 + ) + + # Check if means are approx. same + samples_m = torch.mean(samples, dim=0, dtype=torch.float32) + samples_separate1_m = torch.mean(samples_separate1, dim=0, dtype=torch.float32) + samples_separate2_m = torch.mean(samples_separate2, dim=0, dtype=torch.float32) + samples_sep_m = torch.stack([samples_separate1_m, samples_separate2_m], dim=0) + + assert torch.allclose( + samples_m, samples_sep_m, atol=0.2, rtol=0.2 + ), "Batched sampling is not consistent with separate sampling." diff --git a/tests/user_input_checks_test.py b/tests/user_input_checks_test.py index 7d0bfcdea..5d1135104 100644 --- a/tests/user_input_checks_test.py +++ b/tests/user_input_checks_test.py @@ -181,19 +181,17 @@ def test_process_prior(prior): @pytest.mark.parametrize( - "x, x_shape, allow_iid", + "x, x_shape", ( - (ones(3), torch.Size([3]), False), - (ones(1, 3), torch.Size([3]), False), - (ones(10, 3), torch.Size([10, 3]), False), # 2D data / iid SNPE - pytest.param( - ones(10, 3), None, False, marks=pytest.mark.xfail - ), # 2D data / iid SNPE without x_shape - (ones(10, 10), torch.Size([10]), True), # iid likelihood based + (ones(3), torch.Size([3])), + (ones(1, 3), torch.Size([3])), + (ones(10, 3), torch.Size([10, 3])), # 2D data / iid SNPE + pytest.param(ones(10, 3), None), # 2D data / iid SNPE without x_shape + (ones(10, 10), torch.Size([10])), # iid likelihood based ), ) -def test_process_x(x, x_shape, allow_iid): - process_x(x, x_shape, allow_iid_x=allow_iid) +def test_process_x(x, x_shape): + process_x(x, x_shape) @pytest.mark.parametrize(