Diffusion meets (nested) sampling
A minimal implementation of diffusion models in JAX (Flax), tuned for building emulators of scientific models, particularly where MCMC sampling is tractable and already in use.
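The `CFM` class used in the example appears to implement conditional flow matching, a close relative of diffusion models. As a rough illustration of its training objective, here is a NumPy sketch (not fusions' JAX/Flax code). It draws base points x0 ~ N(0, I) and times t ~ U(0, 1), moves along the straight line x_t = (1 - t) x0 + t x1 towards a target sample x1, and regresses a vector field v(t, x) onto the conditional velocity x1 - x0. New samples come from integrating dx/dt = v(t, x) from t = 0 to 1. The straight-line path with no minimum noise level, the Euler integrator, and all function names (`cfm_pairs`, `cfm_loss`, `sample`, `exact_field`) are our assumptions. In practice v would be a trained neural network rather than the closed-form field used here for a point-mass target.

```python
import numpy as np


def cfm_pairs(x1, rng):
    """Build conditional flow matching inputs (t, x_t) and targets u.

    x1: (n, d) array of target samples (e.g. posterior samples).
    Assumes the straight-line path x_t = (1 - t) x0 + t x1 with x0 ~ N(0, I),
    whose conditional velocity is u = x1 - x0 (no minimum noise level).
    """
    n, d = x1.shape
    x0 = rng.standard_normal((n, d))  # samples from the Gaussian base distribution
    t = rng.uniform(size=(n, 1))      # one time per sample in [0, 1)
    xt = (1.0 - t) * x0 + t * x1      # point on each conditional path
    u = x1 - x0                       # regression target for the vector field
    return t, xt, u


def cfm_loss(vector_field, x1, rng):
    """Mean squared error between v(t, x_t) and the conditional velocities."""
    t, xt, u = cfm_pairs(x1, rng)
    return np.mean(np.sum((vector_field(t, xt) - u) ** 2, axis=1))


def sample(vector_field, n, d, rng, steps=100):
    """Draw samples by Euler-integrating dx/dt = v(t, x) from t = 0 to t = 1."""
    x = rng.standard_normal((n, d))
    dt = 1.0 / steps
    for k in range(steps):
        t = np.full((n, 1), k * dt)  # t stays below 1, so no division by zero
        x = x + dt * vector_field(t, x)
    return x


# Toy target: a point mass at c, for which the optimal field is known exactly
c = np.array([1.0, -1.0, 0.5, 0.0, 2.0])


def exact_field(t, x):
    """Marginal velocity for a point-mass target at c: (c - x) / (1 - t)."""
    return (c - x) / (1.0 - t)


rng = np.random.default_rng(0)
x1 = np.tile(c, (1000, 1))
print(cfm_loss(exact_field, x1, rng))
print(np.allclose(sample(exact_field, 500, 5, rng), c))
```

For a point-mass target the conditional and marginal velocities coincide, so `exact_field` drives the loss to numerically zero and the final Euler step lands on `c` up to rounding; a real posterior needs a learned v.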
Install fusions and lsbi from PyPI:
```
pip install lsbi fusions
```
Create a 5D sampling problem, then train a flow-matching model to approximate its posterior:
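```python
from fusions.cfm import CFM
from lsbi.model import MixtureModel
from anesthetic import MCMCSamples
import matplotlib.pyplot as plt
import numpy as np

dims = 5

# Two-component linear Gaussian mixture: model matrices +I and -I, data covariance 0.1*I
Model = MixtureModel(
    M=np.stack([np.eye(dims), -np.eye(dims)]),
    C=np.eye(dims)*0.1,
)

# Draw a mock data set from the evidence (the marginal distribution of the data)
data = Model.evidence().rvs()

# Initialise the flow-matching model with the prior distribution
diffusion = CFM(Model.prior())
# diffusion = CFM(dims)  # alternative: pass the dimensionality instead of a prior

# Train on 1000 samples from the exact posterior
diffusion.train(Model.posterior(data).rvs(1000))

# Overlay exact posterior samples and samples drawn from the trained model
a = MCMCSamples(Model.posterior(data).rvs(500)).plot_2d(np.arange(dims))
MCMCSamples(diffusion.rvs(500)).plot_2d(a)
plt.show()
```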
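Because the posterior of this mixture of linear Gaussian models is analytic, the training target can also be reproduced by hand, which helps when checking what the flow is being asked to learn. The NumPy sketch below assumes D | theta, i ~ N(M_i theta, C) with M_i = +I or -I, C = 0.1 I, a standard normal prior theta ~ N(0, I), and equal mixture weights. The prior and weights are our assumptions about lsbi's defaults, not verified against its documentation, and the helper names (`gaussian_logpdf`, `mixture_posterior`, `sample_mixture`) are hypothetical. Each component's posterior is Gaussian with covariance (Sigma^-1 + M^T C^-1 M)^-1, and its weight is proportional to that component's evidence N(D; M mu, C + M Sigma M^T).

```python
import numpy as np


def gaussian_logpdf(x, mean, cov):
    """Log density of a multivariate normal N(mean, cov) at a single point x."""
    d = x.shape[-1]
    diff = x - mean
    _, logdet = np.linalg.slogdet(cov)
    return -0.5 * (diff @ np.linalg.solve(cov, diff) + logdet + d * np.log(2 * np.pi))


def mixture_posterior(data, Ms, C, mu, Sigma):
    """Posterior of theta for D | theta, i ~ N(M_i theta, C), theta ~ N(mu, Sigma),
    assuming equal prior weight on each component i."""
    Sigma_inv = np.linalg.inv(Sigma)
    C_inv = np.linalg.inv(C)
    means, covs, logw = [], [], []
    for M in Ms:
        cov = np.linalg.inv(Sigma_inv + M.T @ C_inv @ M)
        mean = cov @ (Sigma_inv @ mu + M.T @ C_inv @ data)
        means.append(mean)
        covs.append(cov)
        # Log evidence of this component: D ~ N(M mu, C + M Sigma M^T)
        logw.append(gaussian_logpdf(data, M @ mu, C + M @ Sigma @ M.T))
    logw = np.array(logw)
    w = np.exp(logw - logw.max())  # subtract the max for numerical stability
    return np.array(means), np.array(covs), w / w.sum()


def sample_mixture(means, covs, weights, n, rng):
    """Draw n samples from a Gaussian mixture."""
    idx = rng.choice(len(weights), size=n, p=weights)
    return np.array([rng.multivariate_normal(means[i], covs[i]) for i in idx])


dims = 5
rng = np.random.default_rng(0)
Ms = np.stack([np.eye(dims), -np.eye(dims)])
C = np.eye(dims) * 0.1
mu, Sigma = np.zeros(dims), np.eye(dims)  # assumed standard normal prior

# Simulate one data set: pick a component, draw theta from the prior, add noise
i = rng.integers(len(Ms))
theta = rng.multivariate_normal(mu, Sigma)
data = Ms[i] @ theta + rng.multivariate_normal(np.zeros(dims), C)

means, covs, weights = mixture_posterior(data, Ms, C, mu, Sigma)
samples = sample_mixture(means, covs, weights, 1000, rng)
print(weights, samples.shape)
```

Under these assumptions both components have the same evidence, N(D; 0, 1.1 I), so the posterior weights are exactly one half each, with modes at +10D/11 and -10D/11 and covariance I/11. For typical data the two modes are well separated, so the flow has to capture a bimodal posterior, which is what the overlaid corner plots in the example let you check.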