diff --git a/README.rst b/README.rst
index 7639005..7125e1f 100644
--- a/README.rst
+++ b/README.rst
@@ -3,7 +3,7 @@ lsbi: Linear Simulation Based Inference
 =======================================
 :lsbi: Linear Simulation Based Inference
 :Author: Will Handley
-:Version: 0.2.0
+:Version: 0.3.0
 :Homepage: https://github.com/handley-lab/lsbi
 :Documentation: http://lsbi.readthedocs.io/
 
diff --git a/lsbi/_version.py b/lsbi/_version.py
index 7fd229a..0404d81 100644
--- a/lsbi/_version.py
+++ b/lsbi/_version.py
@@ -1 +1 @@
-__version__ = '0.2.0'
+__version__ = '0.3.0'
diff --git a/lsbi/stats.py b/lsbi/stats.py
new file mode 100644
index 0000000..7f027b6
--- /dev/null
+++ b/lsbi/stats.py
@@ -0,0 +1,51 @@
+"""Extensions to scipy.stats functions."""
+import numpy as np
+import scipy.special
+import scipy.stats
+
+
+class mixture_multivariate_normal(object):
+    """Mixture of multivariate normal distributions.
+
+    Implemented with the same style as scipy.stats.multivariate_normal.
+
+    Parameters
+    ----------
+    means : array_like, shape (n_components, n_features)
+        Mean of each component.
+
+    covs : array_like, shape (n_components, n_features, n_features)
+        Covariance matrix of each component.
+
+    logA : array_like, shape (n_components,)
+        Log of the mixing weights.
+    """
+
+    def __init__(self, means, covs, logA):
+        self.means = np.array([np.atleast_1d(m) for m in means])
+        self.covs = np.array([np.atleast_2d(c) for c in covs])
+        self.logA = np.atleast_1d(logA)
+        self.choleskys = np.linalg.cholesky(self.covs)
+        self.invcovs = np.linalg.inv(self.covs)
+
+    def logpdf(self, x):
+        """Log of the probability density function."""
+        process_quantiles = scipy.stats.multivariate_normal._process_quantiles
+        x = process_quantiles(x, self.means.shape[-1])
+        dx = self.means - x[..., None, :]
+        chi2 = np.einsum('...ij,ijk,...ik->...i', dx, self.invcovs, dx)
+        norm = -np.linalg.slogdet(2*np.pi*self.covs)[1]/2
+        # Normalise the weights, then log-sum-exp over the components
+        logA = self.logA - scipy.special.logsumexp(self.logA)
+        return np.squeeze(scipy.special.logsumexp(norm-chi2/2+logA, axis=-1))
+
+    def rvs(self, size=1):
+        """Random variates."""
+        size = np.atleast_1d(size)
+        p = np.exp(self.logA-self.logA.max())
+        p /= p.sum()
+        # Choose a component per sample, then colour standard normal draws
+        i = np.random.choice(len(p), size, p=p)
+        x = np.random.randn(*size, self.means.shape[-1])
+        return np.squeeze(self.means[i, ..., None] +
+                          self.choleskys[i] @ x[..., None])
diff --git a/tests/test_stats.py b/tests/test_stats.py
new file mode 100644
index 0000000..c77ae51
--- /dev/null
+++ b/tests/test_stats.py
@@ -0,0 +1,49 @@
+import pytest
+from lsbi.stats import mixture_multivariate_normal
+from numpy.testing import assert_allclose
+import numpy as np
+import scipy.stats
+import scipy.special
+
+
+@pytest.mark.parametrize("k", [1, 2, 5])
+@pytest.mark.parametrize("d", [1, 2, 5])
+def test_mixture_multivariate_normal(k, d):
+    N = 1000
+    means = np.random.randn(k, d)
+    covs = scipy.stats.wishart(scale=np.eye(d)).rvs(k)
+    if k == 1:
+        covs = np.array([covs])
+    logA = np.log(scipy.stats.dirichlet(np.ones(k)).rvs())[0] + 10
+    mixture = mixture_multivariate_normal(means, covs, logA)
+    logA -= scipy.special.logsumexp(logA)
+
+    samples_1, logpdfs_1 = [], []
+    mvns = [scipy.stats.multivariate_normal(means[i], covs[i])
+            for i in range(k)]
+    for _ in range(N):
+        i = np.random.choice(k, p=np.exp(logA))
+        x = mvns[i].rvs()
+        samples_1.append(x)
+        logpdf = scipy.special.logsumexp([mvns[j].logpdf(x) + logA[j]
+                                          for j in range(k)])
+        assert_allclose(logpdf, mixture.logpdf(x))
+        logpdfs_1.append(logpdf)
+    samples_1, logpdfs_1 = np.array(samples_1), np.array(logpdfs_1)
+
+    samples_2 = mixture.rvs(N)
+    logpdfs_2 = mixture.logpdf(samples_2)
+
+    for i in range(d):
+        if d == 1:
+            p = scipy.stats.kstest(samples_1, samples_2).pvalue
+        else:
+            p = scipy.stats.kstest(samples_1[:, i], samples_2[:, i]).pvalue
+        assert p > 1e-5
+
+    p = scipy.stats.kstest(logpdfs_1, logpdfs_2).pvalue
+    assert p > 1e-5
+
+    for shape in [(d,), (3, d), (3, 4, d)]:
+        x = np.random.rand(*shape)
+        assert mvns[0].logpdf(x).shape == mixture.logpdf(x).shape
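
A minimal usage sketch of the new ``mixture_multivariate_normal`` class, assuming the lsbi 0.3.0 API introduced above; the three-component, two-dimensional mixture below is illustrative and does not appear in the diff::

    import numpy as np
    from lsbi.stats import mixture_multivariate_normal

    # Three Gaussian components in two dimensions (illustrative numbers)
    means = np.array([[0., 0.], [3., 3.], [-3., 3.]])
    covs = np.array([np.eye(2), 0.5*np.eye(2), 2*np.eye(2)])
    logA = np.log([0.5, 0.3, 0.2])

    mixture = mixture_multivariate_normal(means, covs, logA)

    samples = mixture.rvs(1000)       # shape (1000, 2)
    logp = mixture.logpdf(samples)    # shape (1000,)

Since ``logpdf`` renormalises the weights internally, ``logA`` need not be normalised, mirroring the ``+ 10`` offset exercised in the test above.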