Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adds Gelman-Rubin diagnostic routine #22

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions bipymc/demc.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,19 @@ def gather_all_chains(self, collection_rank=0):
"""
return list(self.iter_all_chains(collection_rank))

def get_all_chains(self, collection_rank=0):
    """!
    @brief Stack the sample history of every chain into one array.
    @param collection_rank  Passed through unchanged to iter_all_chains.
    @return numpy array whose leading axis indexes the chains, in the
        order iter_all_chains yields them. Presumably m x n x k with
        m chains, n samples per chain and k sample-space dimensions;
        confirm against the layout of each chain's `chain` attribute.
    """
    return np.asarray(
        [member.chain for member in self.iter_all_chains(collection_rank)])

def iter_local_chains(self):
"""!
@brief Local chain generator
Expand Down
78 changes: 78 additions & 0 deletions bipymc/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,81 @@ def var_box(varepsilon, dim):
eps = np.random.uniform(low=-np.asarray(varepsilon) * np.ones(dim),
high=np.asarray(varepsilon) * np.ones(dim))
return eps


def gelman_rubin_sweep(x, n_samples_req=100):
    """!
    @brief Evaluates the Gelman-Rubin diagnostic on growing prefixes of the
        chain history, giving the diagnostic as a function of chain length
        for each parameter.
    @param x  An array of dimension m x n x k, where m is the number of chains,
        n the number of samples, and k is the dimensionality of the param space.
    @param n_samples_req  Shortest prefix length that is evaluated.
    @returns rhat_n array with shape (k, n - n_samples_req). Column j holds
        the diagnostic for the first (n_samples_req + j) samples; the
        full-length history (all n samples) is not evaluated.
    """
    total = x.shape[1]
    if total < n_samples_req:
        raise ValueError(
            'Gelman-Rubin diagnostic sweep requires atleast n_samples_req')

    # Prefix lengths to evaluate: every length below the total that meets
    # the minimum-sample requirement.
    lengths = [n for n in range(1, total) if n >= n_samples_req]

    # For each prefix, the first half is handed over as burn-in so only the
    # second half of the prefix enters the diagnostic.
    history = [
        gelman_rubin_partial(x[:, :n, :], int(n / 2)) for n in lengths
    ]
    return np.asarray(history).T


def gelman_rubin_partial(x, n_burn=0, return_var=False):
    """!
    @brief Helper function to compute GR diagnostic, discarding
        the first n_burn samples from each chain.
    @param x  An array of dimension m x n x k, where m is the number of chains,
        n the number of samples, and k is the dimensionality of the param space.
    @param n_burn  Number of samples to discard, only the final (n - n_burn)
        samples will be used in the computation of the GR diagnostic.
        Must satisfy 0 <= n_burn < n.
    @param return_var  Forwarded to gelman_rubin.
    @returns Whatever gelman_rubin returns for the retained samples.
    @throws ValueError if n_burn is negative or leaves no samples.
    """
    n_samples = x.shape[1]
    # Explicit check instead of `assert`: asserts vanish under `python -O`,
    # and a negative n_burn would otherwise slice from the end of the chain
    # and silently keep only the trailing samples.
    if not 0 <= n_burn < n_samples:
        raise ValueError(
            'n_burn must satisfy 0 <= n_burn < n_samples '
            '(got n_burn={}, n_samples={})'.format(n_burn, n_samples))
    return gelman_rubin(x[:, n_burn:, :], return_var)


def gelman_rubin(x, return_var=False):
    """!
    @brief Computes estimate of Gelman-Rubin diagnostic.
    @param x  An array of dimension m x n x k, where m is the number of chains,
        n the number of samples, and k is the dimensionality of the param space.
        A 2-D m x n array is treated as the chains of a single parameter.
        At least two chains and two samples per chain are needed for the
        variance estimates to be finite.
    @param return_var  If True, return the variance estimate s2 instead of
        the potential scale reduction factor.
    @returns For 2-D input a scalar; for m x n x k input a list of length k
        holding one value per parameter.
    @throws ValueError if x has fewer than two dimensions.

    References
    P. Brooks and A. Gelman. General Methods for Monitoring Convergence of Iterative
    Simulations. Journal of Computational and Graphical Statistics. v7. n4. 1998.
    """
    x = np.asarray(x)
    if x.ndim < 2:
        raise ValueError(
            'gelman_rubin expects an array with at least 2 dimensions '
            '(chains x samples), got {}'.format(x.ndim))

    if x.ndim > 2:
        # Iterate over the trailing (parameter) axis. Each slice of the
        # fully transposed array is (samples x chains), so it is transposed
        # back to (chains x samples). return_var must be forwarded, otherwise
        # the multi-parameter path always yields the ratio.
        return [gelman_rubin(np.transpose(y), return_var)
                for y in np.transpose(x)]

    m, n = x.shape
    chain_means = np.mean(x, 1)

    # Between-chain variance (divided by n)
    B_over_n = np.sum((chain_means - np.mean(x)) ** 2) / (m - 1)

    # Within-chain variance, pooled over all chains
    W = np.sum((x - chain_means[:, np.newaxis]) ** 2) / (m * (n - 1))

    # (over) estimate of variance
    s2 = W * (n - 1) / n + B_over_n

    if return_var:
        return s2

    # Pooled posterior variance estimate
    V = s2 + B_over_n / m

    # Potential scale reduction factor
    return V / W
16 changes: 16 additions & 0 deletions tests/test_dblgauss.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import time
#
from bipymc.utils import banana_rv, dblgauss_rv
from bipymc.util import gelman_rubin_sweep, gelman_rubin_partial
from bipymc.mc_plot import mc_plot
from bipymc.demc import DeMcMpi
from bipymc.dream import DreamMpi
Expand Down Expand Up @@ -73,6 +74,21 @@ def test_samplers(self):
else:
pass

if sampler_name == 'demc' or sampler_name == 'dream':
# plot Gelman-Rubin chain convergence plot
plt.figure()
all_chains = my_mcmc.get_all_chains(self.comm.rank)
gr_diagnostics = gelman_rubin_sweep(all_chains)
for d, grd in enumerate(gr_diagnostics):
plt.scatter(np.arange(len(grd)) + 100, grd, label='x'+str(d), s=2)
plt.legend()
plt.grid(ls='--', alpha=0.5)
plt.axhline(1.1, xmax=len(grd) + 100, ls='--', c='r')
plt.ylabel('Gelman-Rubin diagnostic')
plt.xlabel('N samples')
plt.savefig(str(sampler_name) + "_bimodal_gauss_gelman_rubin.png")
plt.close()

# plot mcmc samples
plt.figure()
plt.scatter(y1, y2, s=2, alpha=0.10)
Expand Down