998 ABC methods for trial-based data using statistical distances (#1104)
* Adding ABC with statistical distances and adding Wasserstein distance based on regularized optimal transport

* Wasserstein and MMD integration for ABC

* Fixing Documentation and pyright conflicts

* Moving distances in ABC to new class

* Documentation and additional test for MMD

* Adding types

* Adding missing distance file

* Adding documentation for Distance class and renaming
theogruner authored Jun 3, 2024
1 parent 1b268b8 commit 7900af0
Showing 7 changed files with 597 additions and 113 deletions.
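Taken together, the changes below let rejection ABC condition on sets of iid trials via a statistical distance. A rough usage sketch, assuming `MCABC` is importable from `sbi.inference.abc.mcabc` and that the acceptance-quantile argument is named `quantile`; the simulator, prior, and observation are toy placeholders:

import torch

from sbi.inference.abc.mcabc import MCABC  # assumed import path


def simulator(theta: torch.Tensor) -> torch.Tensor:
    # Toy simulator: five noisy summary statistics per parameter set.
    return theta.sum(dim=-1, keepdim=True) + torch.randn(theta.shape[0], 5)


prior = torch.distributions.MultivariateNormal(torch.zeros(2), torch.eye(2))
x_o = torch.randn(10, 5)  # ten iid reference trials

inferer = MCABC(
    simulator,
    prior,
    distance="mmd",                  # statistical distance over iid trials
    distance_kwargs={"scale": 1.0},  # forwarded to unbiased_mmd_squared
    simulation_batch_size=1000,      # simulate parameter sets in batches
    distance_batch_size=100,         # evaluate distances in mini-batches
)
samples = inferer(
    x_o,
    num_simulations=10_000,
    num_iid_samples=10,              # >1 because "mmd" compares sets of trials
    quantile=0.01,                   # assumed acceptance-quantile argument
)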
77 changes: 18 additions & 59 deletions sbi/inference/abc/abc_base.py
@@ -4,14 +4,14 @@
"""Base class for Approximate Bayesian Computation methods."""

import logging
from typing import Callable, Union
from typing import Callable, Dict, Optional, Union

import numpy as np
import torch
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures
from torch import Tensor

from sbi.inference.abc.distances import Distance
from sbi.simulators.simutils import simulate_in_batches


@@ -23,8 +23,11 @@ def __init__(
simulator: Callable,
prior,
distance: Union[str, Callable] = "l2",
requires_iid_data: Optional[bool] = None,
distance_kwargs: Optional[Dict] = None,
num_workers: int = 1,
simulation_batch_size: int = 1,
distance_batch_size: int = -1,
show_progress_bars: bool = True,
) -> None:
r"""Base class for Approximate Bayesian Computation methods.
@@ -39,12 +42,21 @@ def __init__(
object with `.log_prob()` and `.sample()` (for example, a PyTorch
distribution) can be used.
distance: Distance function to compare observed and simulated data. Can be
a custom callable function or one of `l1`, `l2`, `mse`.
a custom callable function or one of `l1`, `l2`, `mse`,
`mmd`, `wasserstein`.
requires_iid_data: Whether to allow conditioning on iid sampled data or not.
Typically, this information is inferred from the choice of distance, but
it must be passed explicitly when a custom distance is used.
distance_kwargs: Configuration parameters for the distances. Particularly
useful for the MMD and Wasserstein distance.
num_workers: Number of parallel workers to use for simulations.
simulation_batch_size: Number of parameter sets that the simulator
maps to data x at once. If None, we simulate all parameter sets at the
same time. If >= 1, the simulator has to process data of shape
(simulation_batch_size, parameter_dimension).
distance_batch_size: Number of simulations that the distance function
evaluates against the reference observations at once. If -1, we evaluate
all simulations at the same time.
show_progress_bars: Whether to show a progressbar during simulation and
sampling.
"""
@@ -57,7 +69,9 @@ def __init__(
self.x_shape = None

# Select distance function.
self.distance = self.get_distance_function(distance)
self.distance = Distance(
distance, requires_iid_data, distance_kwargs, batch_size=distance_batch_size
)

self._batched_simulator = lambda theta: simulate_in_batches(
simulator=self._simulator,
@@ -69,61 +83,6 @@ def __init__(

self.logger = logging.getLogger(__name__)

@staticmethod
def get_distance_function(distance_type: Union[str, Callable] = "l2") -> Callable:
"""Return distance function for given distance type.
Args:
distance_type: string indicating the distance type, e.g., 'l2', 'l1',
'mse'. Note that the returned distance function averages over the last
dimension, e.g., over the summary statistics.
Returns:
distance_fun: distance function built from the passed string. Returns
distance_type directly if it is already a callable.
"""

if isinstance(distance_type, Callable):
return distance_type

# Select distance function.
implemented_distances = ["l1", "l2", "mse"]
assert (
distance_type in implemented_distances
), f"{distance_type} must be one of {implemented_distances}."

def mse_distance(xo, x):
return torch.mean((xo - x) ** 2, dim=-1)

def l2_distance(xo, x):
return torch.norm((xo - x), dim=-1)

def l1_distance(xo, x):
return torch.mean(abs(xo - x), dim=-1)

distance_functions = {"mse": mse_distance, "l2": l2_distance, "l1": l1_distance}

try:
distance = distance_functions[distance_type]
except KeyError as exc:
raise KeyError(f"Distance {distance_type} not supported.") from exc

def distance_fun(observed_data: Tensor, simulated_data: Tensor) -> Tensor:
"""Return distance over batch dimension.
Args:
observed_data: Observed data, could be 1D.
simulated_data: Batch of simulated data, has batch dimension.
Returns:
Torch tensor with batch of distances.
"""
assert simulated_data.ndim == 2, "simulated data needs batch dimension"

return distance(observed_data, simulated_data)

return distance_fun

@staticmethod
def get_sass_transform(
theta: torch.Tensor,
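The removed `get_distance_function` is superseded by the `Distance` class in the new file below; for the existing string options (`l1`, `l2`, `mse`) the behaviour stays pairwise. A minimal sketch of that case (tensor shapes chosen purely for illustration):

import torch

from sbi.inference.abc.distances import Distance

x_o = torch.randn(5)     # a single observation with 5 summary statistics
x = torch.randn(100, 5)  # a batch of 100 simulations

l2 = Distance("l2")      # pairwise distance, so no iid data is required
d = l2(x_o, x)           # shape (100,): one distance per simulation
assert not l2.requires_iid_data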
136 changes: 136 additions & 0 deletions sbi/inference/abc/distances.py
@@ -0,0 +1,136 @@
from functools import partial
from logging import warning
from typing import Callable, Dict, Optional, Union

import torch
from tqdm import tqdm

from sbi.utils.metrics import unbiased_mmd_squared, wasserstein_2_squared


class Distance:
def __init__(
self,
distance: Union[str, Callable] = "l2",
requires_iid_data: Optional[bool] = None,
distance_kwargs: Optional[Dict] = None,
batch_size=-1,
):
"""Distance class for ABC
Args:
distance: A distance function comparing the simulations with 'x_o'.
Implemented distances are the 'mse', 'l2', and 'l1' norm as pairwise
distances, or the 'wasserstein' and 'mmd' as statistical distances.
requires_iid_data: 'True' if the distance is a statistical distance.
Only needs to be specified if 'distance' is a custom distance
distance_kwargs: Arguments for the specific distance.
"""
self.batch_size = batch_size
self.distance_kwargs = distance_kwargs or {}
if isinstance(distance, Callable):
if requires_iid_data is None:
# By default, we assume the distance compares single data points.
warning(
"Please specify whether your custom distance requires "
"iid data or is evaluated between single data points. "
"By default, we assume `requires_iid_data=False`."
)
requires_iid_data = False
self.distance_fn = distance
self._requires_iid_data = requires_iid_data
else:
implemented_pairwise_distances = ["l1", "l2", "mse"]
implemented_statistical_distances = ["mmd", "wasserstein"]

assert (
distance
in implemented_pairwise_distances + implemented_statistical_distances
), (
f"{distance} must be one of "
f"{implemented_pairwise_distances + implemented_statistical_distances}."
)

self._requires_iid_data = distance in implemented_statistical_distances

distance_functions = {
"mse": mse_distance,
"l2": l2_distance,
"l1": l1_distance,
"mmd": partial(mmd, **self.distance_kwargs),
"wasserstein": partial(wasserstein, **self.distance_kwargs),
}
try:
self.distance_fn = distance_functions[distance]
except KeyError as exc:
raise KeyError(f"Distance {distance} not supported.") from exc

def __call__(self, x_o, x) -> torch.Tensor:
"""Distance evaluation between the reference data and the simulated data.
Args:
x_o: Reference data
x: Simulated data
"""
if self.requires_iid_data:
assert x.ndim >= 3, "simulated data needs batch dimension"
assert x_o.ndim + 1 == x.ndim
else:
assert x.ndim >= 2, "simulated data needs batch dimension"
if self.batch_size == -1:
return self.distance_fn(x_o, x)
else:
return self._batched_distance(x_o, x)

def _batched_distance(self, x_o, x):
"""Evaluate the distance is mini-batches.
Especially for statistical distances, batching over two empirical
datasets can lead to memory overflow. Batching can help to resolve
the memory problems.
Args:
x_o: Reference data
x: Simulated data
"""
# Process all full batches first, then the remainder (if any).
num_batches = x.shape[0] // self.batch_size
remaining = x.shape[0] % self.batch_size

distances = torch.empty(x.shape[0])
for i in tqdm(range(num_batches)):
distances[self.batch_size * i : (i + 1) * self.batch_size] = (
self.distance_fn(
x_o, x[self.batch_size * i : (i + 1) * self.batch_size]
)
)
if remaining > 0:
distances[-remaining:] = self.distance_fn(x_o, x[-remaining:])

return distances

@property
def requires_iid_data(self):
return self._requires_iid_data
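As flagged by the warning in `__init__`, a custom callable should state whether it expects iid data; a small sketch, where `cosine_distance` is a hypothetical pairwise distance:

import torch

from sbi.inference.abc.distances import Distance


def cosine_distance(x_o: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Hypothetical distance: 1 - cosine similarity along the feature axis.
    return 1.0 - torch.nn.functional.cosine_similarity(x_o, x, dim=-1)


dist = Distance(cosine_distance, requires_iid_data=False)
d = dist(torch.randn(5), torch.randn(100, 5))  # shape (100,), no warning raised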


def mse_distance(x_o, x):
return torch.mean((x_o - x) ** 2, dim=-1)


def l2_distance(x_o, x):
return torch.norm((x_o - x), dim=-1)


def l1_distance(x_o, x):
return torch.mean(abs(x_o - x), dim=-1)


def mmd(x_o, x, scale=None):
dist_fn = partial(unbiased_mmd_squared, scale=scale)
return torch.vmap(dist_fn, in_dims=(None, 0))(x_o, x)


def wasserstein(x_o, x, epsilon=1e-3, max_iter=1000, tol=1e-9):
batched_x_o = x_o.repeat((x.shape[0], *[1] * len(x_o.shape)))
return wasserstein_2_squared(
batched_x_o, x, epsilon=epsilon, max_iter=max_iter, tol=tol
)
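A short sketch of the shapes the statistical distances expect, following the assertions in `Distance.__call__` (the numbers are arbitrary):

import torch

from sbi.inference.abc.distances import Distance

x_o = torch.randn(20, 5)     # 20 iid reference trials with 5 features each
x = torch.randn(100, 20, 5)  # 100 simulated sets of 20 iid trials

mmd_dist = Distance("mmd", distance_kwargs={"scale": 1.0}, batch_size=25)
d = mmd_dist(x_o, x)         # shape (100,), computed in mini-batches of 25
assert mmd_dist.requires_iid_data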
45 changes: 37 additions & 8 deletions sbi/inference/abc/mcabc.py
@@ -21,8 +21,11 @@ def __init__(
simulator: Callable,
prior,
distance: Union[str, Callable] = "l2",
requires_iid_data: Optional[bool] = None,
distance_kwargs: Optional[Dict] = None,
num_workers: int = 1,
simulation_batch_size: int = 1,
distance_batch_size: int = -1,
show_progress_bars: bool = True,
):
r"""Monte-Carlo Approximate Bayesian Computation (Rejection ABC) [1].
@@ -41,22 +44,32 @@ def __init__(
object with `.log_prob()` and `.sample()` (for example, a PyTorch
distribution) can be used.
distance: Distance function to compare observed and simulated data. Can be
a custom function or one of `l1`, `l2`, `mse`.
a custom callable function or one of `l1`, `l2`, `mse`,
`mmd`, `wasserstein`.
requires_iid_data: Whether to allow conditioning on iid sampled data or not.
Typically, this information is inferred from the choice of distance, but
it must be passed explicitly when a custom distance is used.
distance_kwargs: Configuration parameters for the distances. Particularly
useful for the MMD and Wasserstein distance.
num_workers: Number of parallel workers to use for simulations.
simulation_batch_size: Number of parameter sets that the simulator
maps to data x at once. If None, we simulate all parameter sets at the
same time. If >= 1, the simulator has to process data of shape
(simulation_batch_size, parameter_dimension).
distance_batch_size: Number of simulations that the distance function
evaluates against the reference observations at once. If -1, we evaluate
all simulations at the same time.
show_progress_bars: Whether to show a progressbar during simulation and
sampling.
"""

super().__init__(
simulator=simulator,
prior=prior,
distance=distance,
requires_iid_data=requires_iid_data,
distance_kwargs=distance_kwargs,
num_workers=num_workers,
simulation_batch_size=simulation_batch_size,
distance_batch_size=distance_batch_size,
show_progress_bars=show_progress_bars,
)

@@ -73,6 +86,7 @@ def __call__(
kde: bool = False,
kde_kwargs: Optional[Dict[str, Any]] = None,
return_summary: bool = False,
num_iid_samples: int = 1,
) -> Union[Tuple[Tensor, dict], Tuple[KDEWrapper, dict], Tensor, KDEWrapper]:
r"""Run MCABC and return accepted parameters or KDE object fitted on them.
@@ -101,6 +115,10 @@
more details
return_summary: Whether to return the distances and data corresponding to
the accepted parameters.
num_iid_samples: Number of simulations per parameter. Choose
`num_iid_samples>1` if you have chosen a statistical distance that
evaluates sets of simulations against a set of reference observations
instead of comparing single data points.
Returns:
theta (if kde False): accepted parameters
@@ -142,11 +160,22 @@ def simulator(theta):

# Simulate and calculate distances.
theta = self.prior.sample((num_simulations,))
x = simulator(theta)

# Infer shape of x to test and set x_o.
self.x_shape = x[0].shape
self.x_o = process_x(x_o, self.x_shape)
theta_repeat = theta.repeat_interleave(num_iid_samples, dim=0)
x = simulator(theta_repeat)
x = x.reshape((
num_simulations,
num_iid_samples,
-1,
)) # Shape: (num_simulations, num_iid_samples, -1)

# Infer x shape to test and set x_o.
if not self.distance.requires_iid_data:
x = x.squeeze(1)
self.x_shape = x[0].shape
self.x_o = process_x(x_o, self.x_shape)
else:
self.x_shape = x[0, 0].shape
self.x_o = process_x(x_o, self.x_shape, allow_iid_x=True)

distances = self.distance(self.x_o, x)
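The iid handling above is mostly shape bookkeeping; a standalone sketch with dummy tensors (dimensions chosen arbitrarily, the simulator replaced by random noise):

import torch

num_simulations, num_iid_samples, theta_dim, x_dim = 4, 3, 2, 5

theta = torch.randn(num_simulations, theta_dim)
# Each parameter set is simulated num_iid_samples times ...
theta_repeat = theta.repeat_interleave(num_iid_samples, dim=0)  # shape (12, 2)
x = torch.randn(theta_repeat.shape[0], x_dim)  # stand-in for simulator(theta_repeat)
# ... and the simulations are regrouped per parameter set.
x = x.reshape(num_simulations, num_iid_samples, -1)
assert x.shape == (num_simulations, num_iid_samples, x_dim)
# A pairwise distance uses num_iid_samples=1 and squeezes the middle axis again,
# as in the branch above; a statistical distance keeps the full 3D shape.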
