[WIP] Refactor MCMC to use BayesFlow primitives #779

Open

wants to merge 1 commit into master
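For reviewers unfamiliar with the BayesFlow primitive this PR adopts, below is a minimal, self-contained sketch of the `evolve` call pattern, mirroring the positional arguments used in `build_update` in the diff. The standard-normal target, random-walk proposal, state shape, and step count are illustrative assumptions, not part of the PR; the exact `evolve` contract (e.g. batched chains) is defined in `tf.contrib.bayesflow`.

import tensorflow as tf
from tensorflow.contrib.bayesflow.metropolis_hastings import evolve

def log_density(state):
  # Illustrative target: log density of a standard normal (up to a constant).
  return -0.5 * tf.reduce_sum(tf.square(state))

def proposal_fn(state):
  # Illustrative symmetric random-walk proposal; its log transition ratio is 0.
  return state + tf.random_normal(tf.shape(state), stddev=0.5), tf.zeros([])

state = tf.Variable(tf.zeros([2]), trainable=False, name="state")
state_log_density = tf.Variable(log_density(state.initialized_value()),
                                trainable=False, name="state_log_density")
log_accept_ratio = tf.Variable(
    tf.zeros_like(state_log_density.initialized_value()),
    trainable=False, name="log_accept_ratio")

# Running this op advances the chain by one Metropolis-Hastings step,
# updating the three variables in place.
forward_step = evolve(state, state_log_density, log_accept_ratio,
                      log_density, proposal_fn, n_steps=1)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for _ in range(1000):
    sess.run(forward_step)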
130 changes: 87 additions & 43 deletions edward/inferences/metropolis_hastings.py
@@ -12,6 +12,7 @@

try:
from edward.models import Uniform
from tensorflow.contrib.bayesflow.metropolis_hastings import evolve
except Exception as e:
raise ImportError("{0}. Your TensorFlow version is not supported.".format(e))

@@ -64,6 +65,19 @@ def __init__(self, latent_vars, proposal_vars, data=None):

def initialize(self, *args, **kwargs):
kwargs['auto_transform'] = False

# TODO In general, each latent variable has arbitrary shape and
# dtype. We cannot simply batch them into a single tf.Tensor with
# an extra dimension. How do we handle this with ``evolve``?
initial_sample = tf.stack([tf.gather(qz.params, 0)
for qz in six.itervalues(self.latent_vars)])
self._state = tf.Variable(initial_sample, trainable=False, name="state")
self._state_log_density = tf.Variable(
self._log_joint(initial_sample),
trainable=False, name="state_log_density")
self._log_accept_ratio = tf.Variable(
tf.zeros_like(self._state_log_density.initialized_value()),
trainable=False, name="log_accept_ratio")
return super(MetropolisHastings, self).initialize(*args, **kwargs)

def build_update(self):
@@ -80,9 +94,75 @@ def build_update(self):
The updates assume each Empirical random variable is directly
parameterized by `tf.Variable`s.
"""
old_sample = {z: tf.gather(qz.params, tf.maximum(self.t - 1, 0))
for z, qz in six.iteritems(self.latent_vars)}
old_sample = OrderedDict(old_sample)
old_state = self._state
forward_step = evolve(self._state,
self._state_log_density,
self._log_accept_ratio,
self._log_joint,
self._proposal_fn,
n_steps=1)
assign_ops = [forward_step]

with tf.control_dependencies([forward_step]):
# Update Empirical random variables.
for state, qz in zip(tf.unstack(self._state),
six.itervalues(self.latent_vars)):
variable = qz.get_variables()[0]
assign_ops.append(tf.scatter_update(variable, self.t, state))

# Increment n_accept (if accepted).
# TODO old_state might always be the same. It would be great if we
# could more naturally get the acceptance rate from ``evolve``.
is_proposal_accepted = tf.where(
tf.reduce_any(tf.not_equal(old_state, self._state)), 1, 0)
assign_ops.append(self.n_accept.assign_add(is_proposal_accepted))

return tf.group(*assign_ops)

def _log_joint(self, state):
"""Utility function to calculate model's log joint density,
log p(x, z), for inputs z (and fixed data x).
Args:
state: tf.Tensor.
"""
scope = self._scope + tf.get_default_graph().unique_name("sample")
# Form dictionary in order to replace conditioning on prior or
# observed variable with conditioning on a specific value.
# TODO verify ordering is preserved
dict_swap = {z: sample for z, sample in
zip(six.iterkeys(self.latent_vars), tf.unstack(state))}
for x, qx in six.iteritems(self.data):
if isinstance(x, RandomVariable):
if isinstance(qx, RandomVariable):
qx_copy = copy(qx, scope=scope)
dict_swap[x] = qx_copy.value()
else:
dict_swap[x] = qx

log_joint = 0.0
for z in six.iterkeys(self.latent_vars):
z_copy = copy(z, dict_swap, scope=scope)
log_joint += tf.reduce_sum(z_copy.log_prob(dict_swap[z]))

for x in six.iterkeys(self.data):
if isinstance(x, RandomVariable):
x_copy = copy(x, dict_swap, scope=scope)
log_joint += tf.reduce_sum(x_copy.log_prob(dict_swap[x]))

return log_joint

def _proposal_fn(self, state):
"""Utility function to propose a new state,
znew ~ g(znew | zold), for inputs zold, and return the log density
ratio log g(znew | zold) - log g(zold | znew).
Args:
state: tf.Tensor.
"""
# TODO verify ordering is preserved
old_sample = {z: sample for z, sample in
zip(six.iterkeys(self.latent_vars), tf.unstack(state))}

# Form dictionary in order to replace conditioning on prior or
# observed variable with conditioning on a specific value.
@@ -99,7 +179,6 @@ def build_update(self):
dict_swap_old.update(old_sample)
base_scope = tf.get_default_graph().unique_name("inference") + '/'
scope_old = base_scope + 'old'
scope_new = base_scope + 'new'

# Draw proposed sample and calculate acceptance ratio.
new_sample = old_sample.copy() # copy to ensure same order
@@ -114,49 +193,14 @@

dict_swap_new = dict_swap.copy()
dict_swap_new.update(new_sample)
scope_new = base_scope + 'new'

for z, proposal_z in six.iteritems(self.proposal_vars):
# Build proposal g(zold | znew).
proposal_zold = copy(proposal_z, dict_swap_new, scope=scope_new)
# Increment ratio.
ratio -= tf.reduce_sum(proposal_zold.log_prob(dict_swap_old[z]))

for z in six.iterkeys(self.latent_vars):
# Build priors p(znew) and p(zold).
znew = copy(z, dict_swap_new, scope=scope_new)
zold = copy(z, dict_swap_old, scope=scope_old)
# Increment ratio.
ratio += tf.reduce_sum(znew.log_prob(dict_swap_new[z]))
ratio -= tf.reduce_sum(zold.log_prob(dict_swap_old[z]))

for x in six.iterkeys(self.data):
if isinstance(x, RandomVariable):
# Build likelihoods p(x | znew) and p(x | zold).
x_znew = copy(x, dict_swap_new, scope=scope_new)
x_zold = copy(x, dict_swap_old, scope=scope_old)
# Increment ratio.
ratio += tf.reduce_sum(x_znew.log_prob(dict_swap[x]))
ratio -= tf.reduce_sum(x_zold.log_prob(dict_swap[x]))

# Accept or reject sample.
u = Uniform(low=tf.constant(0.0, dtype=ratio.dtype),
high=tf.constant(1.0, dtype=ratio.dtype)).sample()
accept = tf.log(u) < ratio
sample_values = tf.cond(accept, lambda: list(six.itervalues(new_sample)),
lambda: list(six.itervalues(old_sample)))
if not isinstance(sample_values, list):
# `tf.cond` returns tf.Tensor if output is a list of size 1.
sample_values = [sample_values]

sample = {z: sample_value for z, sample_value in
zip(six.iterkeys(new_sample), sample_values)}

# Update Empirical random variables.
assign_ops = []
for z, qz in six.iteritems(self.latent_vars):
variable = qz.get_variables()[0]
assign_ops.append(tf.scatter_update(variable, self.t, sample[z]))

# Increment n_accept (if accepted).
assign_ops.append(self.n_accept.assign_add(tf.where(accept, 1, 0)))
return tf.group(*assign_ops)
# TODO verify ordering is preserved
new_sample = tf.stack(list(six.itervalues(new_sample)))
return (new_sample, ratio)
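For context, the user-facing API that this refactor must preserve is unchanged. A minimal sketch of driving MetropolisHastings in Edward follows; the toy model, sample counts, and proposal scale below are illustrative, not taken from this PR.

import edward as ed
import numpy as np
import tensorflow as tf
from edward.models import Empirical, Normal

# Toy model: mu ~ Normal(0, 1), x | mu ~ Normal(mu, 1) with 50 observations.
mu = Normal(loc=0.0, scale=1.0)
x = Normal(loc=mu, scale=1.0, sample_shape=50)

x_data = np.random.normal(3.0, 1.0, size=50).astype(np.float32)

# Empirical approximation holding T posterior samples, plus a
# random-walk proposal centered at the current state.
T = 5000
qmu = Empirical(params=tf.Variable(tf.zeros(T)))
proposal_mu = Normal(loc=mu, scale=0.5)

inference = ed.MetropolisHastings({mu: qmu}, {mu: proposal_mu}, data={x: x_data})
inference.run()  # each iteration runs the op built by build_update()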