
Simulation-based Inference Benchmark


Benchopt is a package that makes comparisons between optimization algorithms simpler, more transparent and more reproducible. This benchmark is dedicated to simulation-based inference (SBI) algorithms. The goal of SBI is to approximate the posterior distribution of a stochastic model (or simulator):

q_{\phi}(\theta \mid x) \approx p(\theta \mid x) = \frac{p(x \mid \theta) p(\theta)}{p(x)}

where \theta denotes the model parameters and x is an observation. In SBI, the likelihood p(x \mid \theta) is only implicitly defined by the stochastic simulator. Placing a prior p(\theta) over the simulator parameters allows us to generate samples from the joint distribution p(\theta, x) = p(x \mid \theta) p(\theta), which can then be used to approximate the posterior distribution p(\theta \mid x), e.g. by training a deep generative model q_{\phi}(\theta \mid x).
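
As a toy illustration (not part of this repository's code), the following Python sketch draws joint samples (\theta, x) from a simple prior-simulator pair; such pairs form the training set used to fit q_{\phi}(\theta \mid x):

import numpy as np

rng = np.random.default_rng(0)

# Toy prior p(theta): uniform over [-2, 2]^2.
theta = rng.uniform(-2.0, 2.0, size=(1024, 2))

# Toy stochastic simulator: the likelihood p(x | theta) is never written
# down, it is only defined implicitly by running the simulator.
x = theta + 0.1 * rng.normal(size=theta.shape)

# (theta, x) are now samples from the joint p(theta) p(x | theta), ready
# to train a conditional density estimator q_phi(theta | x).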

In this benchmark, we only consider amortized SBI algorithms, which allow inference for any new observation x without simulating new data after the initial training phase.

Environment

CPU, Python 3.8 - 3.11

If you are using a macOS device with an M1 (ARM) processor, run the following command before proceeding to the benchopt installation instructions below:

conda install pyarrow

Installation

This benchmark can be run using the following commands:

pip install -U benchopt
git clone https://github.com/JuliaLinhart/benchmark_sbi
cd benchmark_sbi
benchopt install .
benchopt run .

Alternatively, options can be passed to benchopt <install/run> to restrict the installation/run to specific solvers or datasets:

benchopt <install/run> -s 'npe_sbi[flow=nsf]' -d 'slcp[train_size=4096]' --n-repetitions 3

Use benchopt run -h for more details about these options, or visit https://benchopt.github.io/api.html.
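
Benchopt also accepts a YAML configuration file via benchopt run --config <file>. A minimal sketch, assuming the keys mirror the CLI option names (check the benchopt documentation for your version):

# config.yml (hypothetical example; keys assumed to mirror the CLI options)
solver:
  - npe_sbi[flow=nsf]
dataset:
  - slcp[train_size=4096]
n-repetitions: 3

which is then run with benchopt run . --config config.yml.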

Results

Results are saved in the outputs/ folder, along with a .html file that offers a visual interface showing convergence plots for the different datasets, solvers and metrics. The reported results were obtained by running

benchopt run --n-repetitions 10 --max-runs 1000 --timeout 1000000000000000000

where the max-runs and timeout parameters are set to high values to prevent the benchmark from stopping before the algorithms have converged.
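
Besides the .html report, benchopt also stores the raw results in outputs/ (as a parquet and/or csv file, depending on the version; parquet support is why pyarrow is needed). A minimal sketch for loading them with pandas, assuming at least one parquet file exists; the file name is timestamped, so we pick the most recent one:

from pathlib import Path

import pandas as pd

# The output file name includes a timestamp; load the most recent run.
latest = max(Path("outputs").glob("*.parquet"))
df = pd.read_parquet(latest)
print(df.head())  # one row per (solver, dataset, repetition, iteration)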

Contributing

Everyone is welcome to contribute by adding datasets, solvers (algorithms) or metrics.

  • Datasets represent different prior-simulator pairs that define a joint distribution p(\theta, x) = p(\theta) p(x \mid \theta). The data they are expected to return (Dataset.get_data) consist of a set of training parameter-observation pairs (\theta, x), a set of testing parameter-observation pairs (\theta, x), and a set of reference posterior-observation pairs (p(\theta \mid x), x).

    To add a dataset, add a file in the datasets folder (a skeleton is sketched after this list).

    It is possible to directly use a simulator from the sbibm package as showcased with the provided two_moons example dataset, or to implement a custom simulator from scratch as we did for the slcp example dataset.

  • Solvers represent different amortized SBI algorithms (NRE, NPE, FMPE, ...) or different implementations (sbi, lampe, ...) of such algorithms. They are initialized (Solver.set_objective) with the training pairs and the prior p(\theta). After training (Solver.run), they are expected to return (Solver.get_result) a pair of functions log_prob and sample that evaluate the posterior log-density \log q_{\phi}(\theta \mid x) and generate parameters \theta \sim q_{\phi}(\theta \mid x), respectively.

    To add a solver, add a file in the solvers folder (a skeleton is sketched after this list).

  • Metrics evaluate the quality of the posterior estimated by a solver. The main objective is the expected negative log-likelihood \mathbb{E}_{p(\theta, x)} [ - \log q_{\phi}(\theta \mid x) ] over the test set. Other metrics, such as the C2ST and EMD scores, are computed (Objective.compute) using the reference posteriors (if available).

    To add a metric, implement it in the benchmark_utils/metrics.py file.
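
A minimal sketch of a new dataset file, assuming benchopt's plugin API (BaseDataset); the key names returned by get_data are hypothetical and must match the arguments of Objective.set_data in this benchmark's objective.py:

from benchopt import BaseDataset, safe_import_context

with safe_import_context() as import_ctx:
    import numpy as np


class Dataset(BaseDataset):
    # Hypothetical toy dataset; the real datasets here are two_moons and slcp.
    name = "toy_gaussian"
    parameters = {"train_size": [1024, 4096]}

    def get_data(self):
        rng = np.random.default_rng(0)
        # Prior: uniform over [-2, 2]^2; simulator: additive Gaussian noise.
        theta_train = rng.uniform(-2.0, 2.0, size=(self.train_size, 2))
        x_train = theta_train + 0.1 * rng.normal(size=theta_train.shape)
        theta_test = rng.uniform(-2.0, 2.0, size=(256, 2))
        x_test = theta_test + 0.1 * rng.normal(size=theta_test.shape)
        # Key names are hypothetical and must match Objective.set_data;
        # reference posteriors are omitted in this toy sketch.
        return dict(
            theta_train=theta_train, x_train=x_train,
            theta_test=theta_test, x_test=x_test,
        )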
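
Likewise, a hedged skeleton of a solver file. The argument names of set_objective are hypothetical (they must match what this benchmark's Objective provides), and the closed-form Gaussian "estimator" merely stands in for a real NPE/NRE model trained in run:

from benchopt import BaseSolver, safe_import_context

with safe_import_context() as import_ctx:
    import numpy as np


class Solver(BaseSolver):
    name = "gaussian_baseline"  # hypothetical

    def set_objective(self, theta_train, x_train, prior):
        # Hypothetical argument names; see objective.py for the real ones.
        self.theta_train = theta_train
        self.x_train = x_train
        self.prior = prior

    def run(self, n_iter):
        # A real solver would train q_phi(theta | x) here (e.g. a flow
        # built with sbi or lampe). This placeholder fits the residual
        # scale of a Gaussian posterior centered at the observation x.
        resid = self.theta_train - self.x_train
        self.scale = max(float(resid.std()), 1e-3)

    def get_result(self):
        scale = self.scale

        def log_prob(theta, x):
            # log q(theta | x) for an isotropic Gaussian N(x, scale^2 I).
            d = theta.shape[-1]
            return (
                -0.5 * ((theta - x) ** 2).sum(-1) / scale**2
                - d * np.log(scale)
                - 0.5 * d * np.log(2 * np.pi)
            )

        def sample(x, n):
            # Draw n parameters theta ~ q(theta | x).
            rng = np.random.default_rng()
            return x + scale * rng.normal(size=(n, *np.shape(x)))

        # Returned as a dict here; follow the repo's existing solvers for
        # the exact convention expected by the Objective.
        return dict(log_prob=log_prob, sample=sample)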
