Rejection sampling variational inference #819
base: master
File diff:

@@ -123,7 +123,6 @@ def run(self, variables=None, use_coordinator=True, *args, **kwargs):
        Passed into `initialize`.
    """
    self.initialize(*args, **kwargs)

    if variables is None:
      init = tf.global_variables_initializer()
    else:

@@ -144,6 +143,7 @@ def run(self, variables=None, use_coordinator=True, *args, **kwargs):

    for _ in range(self.n_iter):
      info_dict = self.update()
      print(info_dict)

Review comment: rm?

      self.print_progress(info_dict)

    self.finalize()
File diff:

@@ -32,7 +32,7 @@ class KLpq(VariationalInference):

  with respect to $\\theta$.

-  In conditional inference, we infer $z` in $p(z, \\beta
+  In conditional inference, we infer $z$ in $p(z, \\beta

Review comment: This is unrelated to this PR. Can you make a new PR to fix this?

  \mid x)$ while fixing inference over $\\beta$ using another
  distribution $q(\\beta)$. During gradient calculation, instead
  of using the model's density
File diff:

@@ -6,7 +6,8 @@
import tensorflow as tf

from edward.inferences.variational_inference import VariationalInference
-from edward.models import RandomVariable
+from edward.models import RandomVariable, Gamma
+from edward.samplers import GammaRejectionSampler
from edward.util import copy, get_descendants

try:
@@ -616,6 +617,62 @@ def build_loss_and_gradients(self, var_list):
    return build_score_rb_loss_and_gradients(self, var_list)


class RejectionSamplingKLqp(VariationalInference):
  """
  """
  def __init__(self, latent_vars=None, data=None, rejection_sampler_vars=None):
    """Create an inference algorithm.

    # TODO: update me

    Args:
      latent_vars: list of RandomVariable or
        dict of RandomVariable to RandomVariable.
        Collection of random variables to perform inference on. If
        list, each random variable will be implicitly optimized using a
        `Normal` random variable that is defined internally with a
        free parameter per location and scale and is initialized using
        standard normal draws. The random variables to approximate
        must be continuous.
    """
    if isinstance(latent_vars, list):
      with tf.variable_scope(None, default_name="posterior"):
        latent_vars_dict = {}
        continuous = \
            ('01', 'nonnegative', 'simplex', 'real', 'multivariate_real')
        for z in latent_vars:
          if not hasattr(z, 'support') or z.support not in continuous:
            raise AttributeError(
                "Random variable {} is not continuous or a random "
                "variable with supported continuous support.".format(z))
          batch_event_shape = z.batch_shape.concatenate(z.event_shape)
          loc = tf.Variable(tf.random_normal(batch_event_shape))
          scale = tf.nn.softplus(
              tf.Variable(tf.random_normal(batch_event_shape)))
          latent_vars_dict[z] = Normal(loc=loc, scale=scale)
        latent_vars = latent_vars_dict
        del latent_vars_dict
    super(RejectionSamplingKLqp, self).__init__(latent_vars, data)
    self.rejection_sampler_vars = rejection_sampler_vars

  def initialize(self, n_samples=1, *args, **kwargs):
    """Initialize inference algorithm. It initializes hyperparameters
    and builds ops for the algorithm's computation graph.

    Args:
      n_samples: int, optional.
        Number of samples from variational model for calculating
        stochastic gradients.
    """
    self.n_samples = n_samples
    return super(RejectionSamplingKLqp, self).initialize(*args, **kwargs)

  def build_loss_and_gradients(self, var_list):
    return build_rejection_sampling_loss_and_gradients(self, var_list)


def build_reparam_loss_and_gradients(inference, var_list):
  """Build loss function. Its automatic differentiation
  is a stochastic gradient of
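Editorial note (not part of the diff): a minimal usage sketch for the `RejectionSamplingKLqp` class added above. The toy Gamma–Poisson model and every name below are invented for illustration; the class is imported from `edward.inferences.klqp` directly because this diff does not (visibly) re-export it at the package top level, and the variational family is a `Gamma` so that the loss builder further down can map it to `GammaRejectionSampler`. Since the PR is still in progress, treat this as intended usage rather than a tested example.

import numpy as np
import tensorflow as tf

from edward.models import Gamma, Poisson
from edward.inferences.klqp import RejectionSamplingKLqp

# Toy model: Gamma prior on a Poisson rate, 50 observed counts.
x_train = np.random.poisson(2.0, size=50).astype(np.float32)
z = Gamma(concentration=1.0, rate=1.0)
x = Poisson(rate=z, sample_shape=50)

# Gamma variational approximation with unconstrained free parameters.
qz = Gamma(concentration=tf.nn.softplus(tf.Variable(0.5)),
           rate=tf.nn.softplus(tf.Variable(0.5)))

inference = RejectionSamplingKLqp({z: qz}, data={x: x_train})
inference.run(n_samples=1, n_iter=100)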
@@ -1127,3 +1184,90 @@ def build_score_rb_loss_and_gradients(inference, var_list):
    grads_vars.extend(model_vars)
  grads_and_vars = list(zip(grads, grads_vars))
  return loss, grads_and_vars


def build_rejection_sampling_loss_and_gradients(inference, var_list, epsilon=None):
  """
  """
  rej_samplers = {
      Gamma: GammaRejectionSampler
  }

  rep = [0.0] * inference.n_samples
  cor = [0.0] * inference.n_samples
  base_scope = tf.get_default_graph().unique_name("inference") + '/'
  for s in range(inference.n_samples):
    # Form dictionary in order to replace conditioning on prior or
    # observed variable with conditioning on a specific value.
    scope = base_scope + tf.get_default_graph().unique_name("sample")
    dict_swap = {}
    for x, qx in six.iteritems(inference.data):
      if isinstance(x, RandomVariable):
        if isinstance(qx, RandomVariable):
          qx_copy = copy(qx, scope=scope)
          dict_swap[x] = qx_copy.value()
        else:
          dict_swap[x] = qx

    p_log_prob = 0.
    q_log_prob = 0.
    r_log_prob = 0.

    for z, qz in six.iteritems(inference.latent_vars):
      # Copy q(z) to obtain new set of posterior samples.
      qz_copy = copy(qz, scope=scope)
      sampler = rej_samplers[qz_copy.__class__](density=qz)

      if epsilon is not None:  # temporary
        pass
      else:
        dict_swap[z] = qz_copy.value()
        print('sample:', dict_swap[z])
        epsilon = sampler.h_inverse(dict_swap[z])

      dict_swap[z] = sampler.h(epsilon)
      q_log_prob += tf.reduce_sum(
          inference.scale.get(z, 1.0) * qz_copy.log_prob(dict_swap[z]))
      r_log_prob += -tf.log(tf.gradients(dict_swap[z], epsilon))

    for z in six.iterkeys(inference.latent_vars):
      z_copy = copy(z, dict_swap, scope=scope)
      p_log_prob += tf.reduce_sum(
          inference.scale.get(z, 1.0) * z_copy.log_prob(dict_swap[z]))

    for x in six.iterkeys(inference.data):
      if isinstance(x, RandomVariable):
        x_copy = copy(x, dict_swap, scope=scope)
        p_log_prob += tf.reduce_sum(
            inference.scale.get(x, 1.0) * x_copy.log_prob(dict_swap[x]))

    rep[s] = p_log_prob
    cor[s] = tf.stop_gradient(p_log_prob) * (q_log_prob - r_log_prob)

  rep = tf.reduce_mean(rep)
  cor = tf.reduce_mean(cor)
  q_entropy = tf.reduce_sum([
      tf.reduce_sum(qz.entropy())
      for z, qz in six.iteritems(inference.latent_vars)])
  reg_penalty = tf.reduce_sum(tf.losses.get_regularization_losses())

  loss = -(rep + q_entropy - reg_penalty)

  if inference.logging:
    tf.summary.scalar("loss/reparam_objective", rep,
                      collections=[inference._summary_key])
    tf.summary.scalar("loss/correction_term", cor,
                      collections=[inference._summary_key])
    tf.summary.scalar("loss/q_entropy", q_entropy,
                      collections=[inference._summary_key])
    tf.summary.scalar("loss/reg_penalty", reg_penalty,
                      collections=[inference._summary_key])

  g_rep = tf.gradients(rep, var_list)
Review comment: Can you explain why you need the multiple gradient calls and not just one? This seems inefficient.
  g_cor = tf.gradients(cor, var_list)
  g_entropy = tf.gradients(q_entropy, var_list)

  grad_summands = zip(*[g_rep, g_cor, g_entropy])

Review comment: Can we try dropping [...]? Expected behavior: pass at a higher tolerance, but not blow up. This is a possible culprit re: why gradients are exploding in running [...]. With a reasonably small step size, maybe 100 epochs. Worth keeping an eye on [...]. Print all the gradient terms from the notebook as well.

Review comment: "With a reasonably small step size, maybe 100 epochs." --> i.e. it should pass "with a reasonably small step size, and run for maybe 100 epochs."

  grads = [tf.reduce_sum(summand) for summand in grad_summands]
  grads_and_vars = list(zip(grads, var_list))
  return loss, grads_and_vars
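Editorial note (not part of the diff): as I read `build_rejection_sampling_loss_and_gradients`, it mirrors the acceptance-rejection reparameterization estimator of the linked ars-reparameterization notebook. Writing $z = h(\epsilon; \theta)$ for the sampler's transformation and $f(z) = \log p(x, z)$, the surrogate loss and the three gradient pieces appear to be

$$\mathcal{L}(\theta) = \mathbb{E}\big[\log p(x, h(\epsilon; \theta))\big] + \mathbb{H}\big[q(z; \theta)\big] - \text{regularization},$$

$$g_{\text{rep}} = \mathbb{E}\big[\nabla_\theta \log p(x, h(\epsilon; \theta))\big], \qquad g_{\text{entropy}} = \nabla_\theta \, \mathbb{H}\big[q(z; \theta)\big],$$

$$g_{\text{cor}} = \mathbb{E}\Big[\log p(x, z)\, \nabla_\theta \Big(\log q(z; \theta) + \log \Big|\tfrac{\partial h(\epsilon; \theta)}{\partial \epsilon}\Big|\Big)\Big]_{z = h(\epsilon; \theta)},$$

where `q_log_prob - r_log_prob` in the code corresponds to $\log q(z; \theta) + \log |\partial h / \partial \epsilon|$ (since `r_log_prob` is minus the log of the Jacobian), and `tf.stop_gradient` holds $\log p(x, z)$ fixed inside the correction term. On this reading the three `tf.gradients` calls compute $g_{\text{rep}}$, $g_{\text{cor}}$, and $g_{\text{entropy}}$ separately; they could presumably be folded into a single call on the surrogate `rep + cor + q_entropy`, which may be what the review question above is asking about.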
File diff:

@@ -0,0 +1,15 @@
"""
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from edward.samplers.rejection import *

from tensorflow.python.util.all_util import remove_undocumented

_allowed_symbols = [
    'GammaRejectionSampler',
]

remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
File diff:

@@ -0,0 +1,34 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import tensorflow as tf


class GammaRejectionSampler:

  # As implemented in https://github.com/blei-lab/ars-reparameterization/blob/master/gamma/demo.ipynb

  def __init__(self, density):
    self.alpha = density.parameters['concentration']
    self.beta = density.parameters['rate']

  def h(self, epsilon):
    a = self.alpha - (1. / 3)
    b = tf.sqrt(9 * self.alpha - 3)
    c = 1 + (epsilon / b)
    d = a * c**3
    return d / self.beta

  def h_inverse(self, z):
    a = self.alpha - (1. / 3)
    b = tf.sqrt(9 * self.alpha - 3)
    c = self.beta * z / a
    d = c**(1 / 3)
    return b * (d - 1)

  @staticmethod
  def log_prob_s(epsilon):
    return -0.5 * (tf.log(2 * math.pi) + epsilon**2)
Review comment: add back newline? unrelated to PR
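Editorial note (not part of the diff): `h` above appears to be the Marsaglia-Tsang shape augmentation $z = (\alpha - 1/3)\,(1 + \epsilon/\sqrt{9\alpha - 3})^3 / \beta$, with `h_inverse` its inverse and `log_prob_s` the standard-normal log-density of $\epsilon$. A small sketch checking the round trip, with illustrative parameter values and assuming the `edward.samplers` package from this PR is importable:

import tensorflow as tf

from edward.models import Gamma
from edward.samplers import GammaRejectionSampler

qz = Gamma(concentration=2.0, rate=3.0)
sampler = GammaRejectionSampler(density=qz)

epsilon = tf.constant(0.5)
z = sampler.h(epsilon)               # map a standard-normal draw to a candidate Gamma(2, 3) sample
eps_back = sampler.h_inverse(z)      # should recover epsilon up to numerical error
log_s = sampler.log_prob_s(epsilon)  # standard-normal log-density at epsilon

with tf.Session() as sess:
  z_val, eps_val, log_s_val = sess.run([z, eps_back, log_s])
  print(z_val)      # (2 - 1/3) * (1 + 0.5 / sqrt(15))**3 / 3, roughly 0.8
  print(eps_val)    # approximately 0.5
  print(log_s_val)  # -0.5 * (log(2 * pi) + 0.25)

Running this should print a positive value for z, recover epsilon to within float precision, and report the standard-normal log-density at 0.5.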