""" Implements a NUTS Sampler for emcee
http://dan.iel.fm/emcee/
"""
import numpy as np

from .nuts import nuts6
from .helpers import NutsSampler_fn_wrapper
# NOTE: this imports the emcee 2.x base class; emcee.sampler was removed in
# emcee 3, so this module targets the older emcee API.
from emcee.sampler import Sampler

__all__ = ['NUTSSampler', 'test_sampler']

class NUTSSampler(Sampler):
    """A sampler object mirroring the emcee Sampler interface."""

    def __init__(self, dim, lnprobfn, gradfn=None, *args, **kwargs):
        self.dim = dim
        self.f = NutsSampler_fn_wrapper(lnprobfn, gradfn, *args, **kwargs)
        self.lnprobfn = self.f.lnp_func
        self.gradfn = self.f.gradlnp_func
        self.reset()

    @property
    def random_state(self):
        """
        The state of the internal random number generator. In practice, it's
        the result of calling ``get_state()`` on a
        ``numpy.random.mtrand.RandomState`` object. You can try to set this
        property but be warned that if you do this and it fails, it will do
        so silently.
        """
        pass

    @random_state.setter  # NOQA
    def random_state(self, state):
        """
        Try to set the state of the random number generator but fail silently
        if it doesn't work. Don't say I didn't warn you...
        """
        pass

    @property
    def flatlnprobability(self):
        """
        A shortcut to return the equivalent of ``lnprobability`` but aligned
        to ``flatchain`` rather than ``chain``.
        """
        return self.lnprobability.flatten()

    def get_lnprob(self, p):
        """Return the log-probability at the given position."""
        return self.lnprobfn(p)

    def get_gradlnprob(self, p, dx=1e-3, order=1):
        """Return the gradient of the log-probability at the given position.

        ``dx`` and ``order`` are accepted for compatibility with
        finite-difference gradient signatures but are unused here.
        """
        return self.gradfn(p)
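
    # If ``gradfn`` is None, NutsSampler_fn_wrapper is expected to supply a
    # numerical gradient instead. A central-difference sketch of that idea
    # (hypothetical helper, not part of this module):
    #
    #     def _fd_grad(f, p, dx=1e-3):
    #         grad = np.zeros_like(p, dtype=float)
    #         for i in range(p.size):
    #             step = np.zeros_like(p, dtype=float)
    #             step[i] = dx
    #             grad[i] = (f(p + step) - f(p - step)) / (2 * dx)
    #         return grad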

    def reset(self):
        """
        Clear ``chain``, ``lnprobability`` and the bookkeeping parameters.
        """
        self._lnprob = []
        self._chain = []
        self._epsilon = 0.

    @property
    def iterations(self):
        return len(self._lnprob)

    def clear_chain(self):
        """An alias for :func:`reset` kept for backwards compatibility."""
        return self.reset()

    def _sample_fn(self, p, dx=1e-3, order=1):
        """Proxy function for nuts6: returns (logp, grad logp) at ``p``."""
        lnprob = self.lnprobfn(p)
        gradlnp = self.gradfn(p)
        return lnprob, gradlnp

    def sample(self, pos0, M, Madapt, delta=0.6, **kwargs):
        """Run NUTS6 and store the resulting chain."""
        samples, lnprob, epsilon = nuts6(self._sample_fn, M, Madapt, pos0, delta)
        self._chain = samples
        self._lnprob = lnprob
        self._epsilon = epsilon
        return samples

    def run_mcmc(self, pos0, M, Madapt, delta=0.6, **kwargs):
        """
        Iterate :func:`sample` for ``M`` iterations and return the result.

        :param pos0:
            The initial position vector.

        :param M:
            The number of steps to run.

        :param Madapt:
            The number of steps in the burn-in period, during which the step
            size is adapted.

        :param delta: (optional, default=0.6)
            Target mean acceptance probability for the dual-averaging
            step-size adaptation.

        :param kwargs: (optional)
            Other parameters that are directly passed to :func:`sample`.
        """
        print('Running NUTS with dual averaging, target acceptance %0.2f...' % delta)
        return self.sample(pos0, M, Madapt, delta, **kwargs)

class _function_wrapper(object):
    """
    This is a hack to make the likelihood function pickleable when ``args``
    are also included.
    """

    def __init__(self, f, args):
        self.f = f
        self.args = args

    def __call__(self, x):
        try:
            return self.f(x, *self.args)
        except Exception:
            import traceback
            print("NUTS: Exception while calling your likelihood function:")
            print("  params:", x)
            print("  args:", self.args)
            print("  exception:")
            traceback.print_exc()
            raise
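
# Illustrative use of _function_wrapper (hypothetical names):
# ``_function_wrapper(lnprob, (data,))`` gives a single-argument, pickleable
# callable equivalent to ``lambda x: lnprob(x, data)``, which a plain lambda
# or closure is not.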

def test_sampler():
    """Example: sampling a highly correlated 2D Gaussian with NUTSSampler."""

    def correlated_normal(theta):
        """
        Example of a target distribution that could be sampled from using NUTS.
        (Although of course you could sample from it more efficiently.)
        Doesn't include the normalizing constant.
        """
        # Precision matrix A = inv(cov) for covariance [[1, 1.98], [1.98, 4]].
        A = np.asarray([[50.251256, -24.874372],
                        [-24.874372, 12.562814]])
        # grad log p(theta) = -A theta and log p(theta) = -0.5 theta^T A theta
        # (up to the dropped normalizing constant).
        grad = -np.dot(theta, A)
        logp = 0.5 * np.dot(grad, theta.T)
        return logp, grad

    def lnprobfn(theta):
        return correlated_normal(theta)[0]

    def gradfn(theta):
        return correlated_normal(theta)[1]

    D = 2
    M = 5000
    Madapt = 5000
    theta0 = np.random.normal(0, 1, D)
    delta = 0.6

    mean = np.zeros(2)
    cov = np.asarray([[1, 1.98],
                      [1.98, 4]])

    sampler = NUTSSampler(D, lnprobfn, gradfn)
    samples = sampler.run_mcmc(theta0, M, Madapt, delta)

    print('Percentiles')
    print(np.percentile(samples, [16, 50, 84], axis=0))
    print('Mean')
    print(np.mean(samples, axis=0))
    print('Stddev')
    print(np.std(samples, axis=0))

    # Thin the chain before plotting and overlay direct draws from the
    # target covariance for comparison.
    samples = samples[1::10, :]

    import matplotlib.pyplot as plt
    temp = np.random.multivariate_normal(mean, cov, size=500)
    plt.plot(temp[:, 0], temp[:, 1], '.')
    plt.plot(samples[:, 0], samples[:, 1], 'r+')
    plt.show()

    return sampler
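

if __name__ == '__main__':
    # Convenience entry point (not in the original module): run the
    # correlated-Gaussian example when this file is executed as a script.
    test_sampler()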