Mixture multivariate normal distribution (#8)
* Added mixture normal distribution

* bump version to 0.3.0

* Added explicit tests for logpdf
williamjameshandley authored Oct 20, 2023
1 parent 11bcc40 commit c165340
Showing 4 changed files with 99 additions and 2 deletions.
2 changes: 1 addition & 1 deletion README.rst
@@ -3,7 +3,7 @@ lsbi: Linear Simulation Based Inference
=======================================
:lsbi: Linear Simulation Based Inference
:Author: Will Handley
:Version: 0.2.0
:Version: 0.3.0
:Homepage: https://github.com/handley-lab/lsbi
:Documentation: http://lsbi.readthedocs.io/

2 changes: 1 addition & 1 deletion lsbi/_version.py
@@ -1 +1 @@
__version__ = '0.2.0'
__version__ = '0.3.0'
48 changes: 48 additions & 0 deletions lsbi/stats.py
@@ -0,0 +1,48 @@
"""Extensions to scipy.stats functions."""
import numpy as np
import scipy.special
import scipy.stats


class mixture_multivariate_normal(object):
    """Mixture of multivariate normal distributions.

    Implemented in the same style as scipy.stats.multivariate_normal.

    Parameters
    ----------
    means : array_like, shape (n_components, n_features)
        Mean of each component.
    covs : array_like, shape (n_components, n_features, n_features)
        Covariance matrix of each component.
    logA : array_like, shape (n_components,)
        Log of the mixing weights.
    """

    def __init__(self, means, covs, logA):
        self.means = np.array([np.atleast_1d(m) for m in means])
        self.covs = np.array([np.atleast_2d(c) for c in covs])
        self.logA = np.atleast_1d(logA)
        self.choleskys = np.linalg.cholesky(self.covs)
        self.invcovs = np.linalg.inv(self.covs)

    def logpdf(self, x):
        """Log of the probability density function."""
        process_quantiles = scipy.stats.multivariate_normal._process_quantiles
        x = process_quantiles(x, self.means.shape[-1])
        # Squared Mahalanobis distance from x to each component mean
        dx = self.means - x[..., None, :]
        chi2 = np.einsum('...ij,ijk,...ik->...i', dx, self.invcovs, dx)
        # Gaussian normalisation constants and normalised log-weights
        norm = -np.linalg.slogdet(2*np.pi*self.covs)[1]/2
        logA = self.logA - scipy.special.logsumexp(self.logA)
        return np.squeeze(scipy.special.logsumexp(norm-chi2/2+logA, axis=-1))

    def rvs(self, size=1):
        """Random variates."""
        size = np.atleast_1d(size)
        # Choose components in proportion to the mixing weights
        p = np.exp(self.logA-self.logA.max())
        p /= p.sum()
        i = np.random.choice(len(p), size, p=p)
        # Colour standard normal draws with each component's Cholesky factor
        x = np.random.randn(*size, self.means.shape[-1])
        return np.squeeze(self.means[i, ..., None]
                          + self.choleskys[i] @ x[..., None])
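
Not part of the commit: a minimal usage sketch of the new class, assuming lsbi 0.3.0 is importable and using illustrative means, covariances and log-weights.

import numpy as np
from lsbi.stats import mixture_multivariate_normal

# Two 2-d Gaussian components with equal (unnormalised) log-weights
means = np.array([[0.0, 0.0], [3.0, 3.0]])
covs = np.array([np.eye(2), 0.5 * np.eye(2)])
logA = np.zeros(2)

mixture = mixture_multivariate_normal(means, covs, logA)
samples = mixture.rvs(1000)     # shape (1000, 2)
logp = mixture.logpdf(samples)  # shape (1000,)

Note that logA need not be normalised: logpdf subtracts its logsumexp internally, and rvs renormalises the weights before choosing components.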
49 changes: 49 additions & 0 deletions tests/test_stats.py
@@ -0,0 +1,49 @@
import pytest
from lsbi.stats import mixture_multivariate_normal
from numpy.testing import assert_allclose
import numpy as np
import scipy.stats
import scipy.special


@pytest.mark.parametrize("k", [1, 2, 5])
@pytest.mark.parametrize("d", [1, 2, 5])
def test_mixture_multivariate_normal(k, d):
    N = 1000
    means = np.random.randn(k, d)
    covs = scipy.stats.wishart(scale=np.eye(d)).rvs(k)
    if k == 1:
        covs = np.array([covs])
    logA = np.log(scipy.stats.dirichlet(np.ones(k)).rvs())[0] + 10
    mixture = mixture_multivariate_normal(means, covs, logA)
    logA -= scipy.special.logsumexp(logA)

    # Reference: sample component-by-component with scipy, checking the
    # mixture logpdf against an explicit logsumexp over components
    samples_1, logpdfs_1 = [], []
    mvns = [scipy.stats.multivariate_normal(means[i], covs[i])
            for i in range(k)]
    for _ in range(N):
        i = np.random.choice(k, p=np.exp(logA))
        x = mvns[i].rvs()
        samples_1.append(x)
        logpdf = scipy.special.logsumexp([mvns[j].logpdf(x) + logA[j]
                                          for j in range(k)])
        assert_allclose(logpdf, mixture.logpdf(x))
        logpdfs_1.append(logpdf)
    samples_1, logpdfs_1 = np.array(samples_1), np.array(logpdfs_1)

    samples_2 = mixture.rvs(N)
    logpdfs_2 = mixture.logpdf(samples_2)

    # Kolmogorov-Smirnov tests comparing the scipy reference draws with
    # the mixture's own rvs, marginal by marginal
    for i in range(d):
        if d == 1:
            p = scipy.stats.kstest(samples_1, samples_2).pvalue
        else:
            p = scipy.stats.kstest(samples_1[:, i], samples_2[:, i]).pvalue
        assert p > 1e-5

    p = scipy.stats.kstest(logpdfs_1, logpdfs_2).pvalue
    assert p > 1e-5

    # logpdf should broadcast over batch dimensions like scipy's
    for shape in [(d,), (3, d), (3, 4, d)]:
        x = np.random.rand(*shape)
        assert mvns[0].logpdf(x).shape == mixture.logpdf(x).shape
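
As a sanity check outside the test suite (not part of the commit), the single-component case should reduce to scipy.stats.multivariate_normal, mirroring the logpdf comparison above; a minimal sketch, assuming lsbi 0.3.0 is installed:

import numpy as np
import scipy.stats
from numpy.testing import assert_allclose
from lsbi.stats import mixture_multivariate_normal

# One component with log-weight 0 is just a single multivariate normal
mean, cov = np.zeros(3), np.eye(3)
mix = mixture_multivariate_normal([mean], [cov], logA=[0.0])
x = np.random.randn(10, 3)
assert_allclose(mix.logpdf(x),
                scipy.stats.multivariate_normal(mean, cov).logpdf(x))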
