Implement TruncatedDistribution #121
Conversation
Will the …
I just added a minimal working example (+tests) to get TruncatedNormal working. With CDFs for all current distributions I think this PR would be too large. What do you think?
I thought the same. Maybe after this is merged, I can start working on the PR that populates the rest. Hope that is fine.
You can cherry-pick the first commit to your branch and start working in parallel if you want?
I was actually just thinking about this today and was exploring how disastrous it would be to try to implement a "generic" TruncatedDistribution, where we use inverse transform sampling to generate from it. Some plots in this gist: https://gist.github.com/tbrx/18e7579d9b7ff7c2a84c17c300555fc1 Basically, it's pretty bad numerically once you are more than four standard deviations away from the mean on a Gaussian, and it falls apart entirely a little past five. This doesn't give me high hopes for e.g. Gamma… I looked at the SciPy code this morning, and it actually appears to use inverse transform sampling for truncated normals. Higher-precision floating point, though, means that they can get quite far from the mean before this is an issue.
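For reference, here is a minimal sketch of such a generic inverse-transform sampler (illustrative only, not the PR's code; it assumes the .cdf() / .icdf() methods this PR adds for Normal). The tail problem is visible directly: both CDF values get rounded toward 1 in the tail, so the uniform draw has almost no resolution left there.

import torch
from torch.distributions import Normal

def truncated_sample(base_dist, low, high, sample_shape=(1,)):
    # Generic inverse-transform sampling for a truncated distribution:
    # draw u uniformly on [F(low), F(high)], then map back through F^{-1}.
    cdf_low = base_dist.cdf(torch.tensor(low))
    cdf_high = base_dist.cdf(torch.tensor(high))
    u = torch.rand(sample_shape) * (cdf_high - cdf_low) + cdf_low
    return base_dist.icdf(u)

# Fine near the bulk of the mass, but increasingly coarse and biased past ~4-5 sigma,
# because cdf_low and cdf_high are both rounded toward 1 in single precision:
samples = truncated_sample(Normal(0.0, 1.0), 4.5, float('inf'), (1000,))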
Looks clean, but we need a safer way to do .new().
test/test_distributions.py (outdated)
set_rng_seed(0)  # see Note [Randomized statistical tests]
for pytorch_dist, scipy_dist in self.distribution_pairs:
    samples = pytorch_dist.sample((5,))
    try:
It's safest to enclose as little as possible in a try-except:
try:
    pytorch_cdf = pytorch_dist.cdf(samples)
except NotImplementedError:
    continue
self.assertEqual(pytorch_cdf, scipy_dist.cdf(samples), message=pytorch_dist)
test/test_distributions.py (outdated)
set_rng_seed(0)  # see Note [Randomized statistical tests]
for pytorch_dist, scipy_dist in self.distribution_pairs:
    samples = Variable(torch.rand((5,) + pytorch_dist.batch_shape))
    try:
ditto, enclose as little as possible
super(TruncatedDistribution, self).__init__(*args, **kwargs)
self.base_dist = base_distribution
self.lower_bound, self.upper_bound, _ = broadcast_all(lower_bound, upper_bound,
                                                       getattr(self.base_dist,
This looks really dangerous. Why do we need to broadcast? Can we simply set
self.lower_bound = lower_bound
self.upper_bound = upper_bound
I was thinking about supporting batched bounds while writing that part, but I gave up that idea later & forgot to change it.
is a generic sampler which is not the most efficient or accurate around tails of base distribution.
"""
shape = self._extended_shape(sample_shape)
u = getattr(self.base_dist, list(self.base_dist.params.keys())[0]).new(shape).uniform_()
This looks dangerous. I wish we had a .new() method to create a correctly-placed tensor from a given distribution. @apaszke Is there an established pattern to do this? Can we define a .new_tensor() method or something? This has been coming up often. Some of our distributions define a private ._new(), but we haven't exposed this as a general interface.
It seems simple and safe to define a method-as-property, like:

class Distribution(object):
    @property
    def new_tensor(self):
        raise NotImplementedError

class Normal(Distribution):
    @property
    def new_tensor(self):
        return self.loc.new
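With something along those lines, the sampler above could presumably be written as u = self.base_dist.new_tensor(shape).uniform_() (new_tensor being the hypothetical property suggested here) instead of the getattr lookup.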
We don't have a common pattern except for new on tensors, but we never needed anything else.
Re: numerical stability, one option would be to use rejection sampling to draw samples and merely use the cdf derivative to compute reparameterized gradients:

def sample(self):
    ...use rejection sampling...

def rsample(self):
    x = self.sample()  # detached
    cdf = self.cdf(x)
    pdf = self.log_prob(x).exp()
    return x + (cdf.detach() - cdf) / pdf.detach()  # or something like this...
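As a side note on why that surrogate gives the right gradient (my reading, not part of the comment above): in the forward pass cdf.detach() - cdf is zero, so the expression evaluates to x. In the backward pass, treating the sample as x = F_theta^{-1}(u) for a fixed uniform u and implicitly differentiating F_theta(x) = u gives dx/dtheta = -(dF_theta(x)/dtheta) / f_theta(x), which is exactly the gradient of the surrogate, -grad(cdf) / pdf.detach(). So a rejection-sampled x ends up with the same pathwise gradient an inverse-CDF sample would have.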
Do you think rejection sampling would be fast enough? Should we write low-level functions for it?
It would be cheap if we rejection sampled when …
@fritzo thanks, I'll take a look and try to come up with something.
I think …
Oh, I completely agree that adding … For the truncated normal, for example, it seems like there are only 58 distinct floating point values between 4.5 and infinity. If your bounds are within ±4 standard deviations, though, this would work pretty much fine! Maybe that is the more common case than sampling or evaluating tail probabilities anyway.
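Reading that claim as counting single-precision values of the normal CDF between Phi(4.5) and 1 (my interpretation; the comment above is truncated), it is easy to check numerically:

import numpy as np
from scipy.stats import norm

lo = np.float32(norm.cdf(4.5))   # CDF at the lower truncation point, in single precision
hi = np.float32(1.0)             # CDF at +infinity

count = 0
x = lo
while x < hi:
    x = np.nextafter(x, hi)      # step to the next representable float32 value
    count += 1
print(count)  # roughly 58 representable values, matching the figure quoted above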
Here are 2 gists for …
Do I understand correctly that the difficult case is when you're truncating e.g. a …
Here is an updated gist. Sampling from 4.5 sigma looks problematic, but sampling from 4 sigma looks OK-ish.
Those plots look good! But I agree, I don't think inverse CDF sampling will work very well for a Normal(0,1) outside of the region [-4, 4], or maybe [-4.5, 4.5] in a pinch… There are algorithms for sampling from the tail of a Gaussian (e.g. on [4, \infty)) in chapter 9 of http://www.nrbook.com/devroye/. This doesn't help, though, with computing the …
I implemented the sampling-from-the-tail algorithm; it's fast and looks good! Here's a gist. Precision problem with …
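For readers following along, here is a minimal sketch of the kind of tail sampler being discussed (the exponential-proposal rejection scheme from Robert, 1995, also covered in the Devroye chapter linked above); it is illustrative, not the PR's actual implementation:

import math
import random

def sample_normal_tail(a):
    # Draw Z ~ N(0, 1) conditioned on Z >= a, for a > 0 (Robert, 1995).
    # Proposal: a shifted exponential on [a, inf) with rate alpha; this choice of
    # alpha keeps the acceptance probability high even for large a.
    alpha = (a + math.sqrt(a * a + 4.0)) / 2.0
    while True:
        x = a - math.log(1.0 - random.random()) / alpha            # exponential proposal
        if random.random() <= math.exp(-(x - alpha) ** 2 / 2.0):   # acceptance test
            return x

samples = [sample_normal_tail(4.5) for _ in range(5)]  # every draw lies in [4.5, inf)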
This is cool!! That would work really well for … I was wondering if there was a way of maybe directly approximating … EDIT: looking at these numeric approximations, in particular the fourth, maybe it's possible to get approximations for …
I'm very interested in this; I'm working on a similar thing. I'm calling it ConditionalExcessDistribution, but I'm really only looking at right-censored (truncated) things at the moment.
New gist time 😄 This time I also have sampling times: https://gist.github.com/alicanb/c9e6567b7c512140ed43916b4dd30106. At this point I'm inclined towards having …
That sounds reasonable, implementing one new generic distribution and one specific special-case distribution. It even makes sense to send them in the same PR.
Sorry I let this drop; it would be nice to get it in before the 0.4 release.
self.base_dist = base_distribution
cdf_low, cdf_high = self.base_dist.cdf(self.lower_bound), self.base_dist.cdf(self.upper_bound)
if sample_method in ['rejection', 'inversion']:
    self.sample = {'rejection': self._rejection_sample, 'inversion': self._inversion_sample}[sample_method]
This creates a circular reference and leaks memory.
def event_shape(self):
    return self.base_dist.event_shape

def _inversion_sample(self, sample_shape=torch.Size()):
I'm inclined to implement Inversion and Rejection as different classes because the interfaces differ: Inversion allows reparametrization and hence allows an .rsample(), whereas Rejection is not reparametrizable and hence only implements .sample() (it can be partially reparametrized via RSVI, but that requires yet another interface). Also, Pyro defines a separate Rejector class to do rejection sampling given a more general rejection criterion.
I'm inclined to omit rejection sampling from pytorch, actually. It's hard to make it work efficiently out of the box for a range of distributions. What do you think?
I agree, let's omit rejection sampling.
def __init__(self, loc, scale, lower_bound=-float('inf'), upper_bound=float('inf'), sample_method='robert', *args, **kwargs):
    super(TruncatedNormal, self).__init__(Normal(loc, scale), lower_bound, upper_bound, *args, **kwargs)
    if sample_method in {'exp', 'robert'}:
        self.sample = {'exp': self._exp_proposal, 'robert': self._robert_sample}[sample_method]
This creates a circular reference. It's better to simply define an if statement in an .rsample() method.
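A minimal sketch of that suggestion, reusing the _exp_proposal / _robert_sample helper names from the diff above (those helpers and the TruncatedDistribution base class are assumed to exist as in this PR):

import torch
from torch.distributions import Normal

class TruncatedNormal(TruncatedDistribution):  # TruncatedDistribution as defined in this PR
    def __init__(self, loc, scale, lower_bound=-float('inf'), upper_bound=float('inf'),
                 sample_method='robert', *args, **kwargs):
        super(TruncatedNormal, self).__init__(Normal(loc, scale), lower_bound, upper_bound,
                                              *args, **kwargs)
        self.sample_method = sample_method  # store the name, not a bound method

    def rsample(self, sample_shape=torch.Size()):
        # Plain if-dispatch: no bound method stored on self, so no reference cycle.
        if self.sample_method == 'exp':
            return self._exp_proposal(sample_shape)
        return self._robert_sample(sample_shape)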
.jenkins/pytorch/test.sh (outdated)
@@ -23,6 +23,7 @@ if [[ "$BUILD_ENVIRONMENT" == *asan* ]]; then
    export ASAN_OPTIONS=detect_leaks=0:symbolize=1
    export PYTORCH_TEST_WITH_ASAN=1
    # TODO: Figure out how to avoid hard-coding these paths
    export ASAN_SYMBOLIZER_PATH=/usr/lib/llvm-5.0/bin/llvm-symbolizer
Looks like the diff was tainted.
This PR adds:
- .cdf() and .icdf() methods for Distribution, and tests (populated only for Normal for now)
- a TruncatedDistribution class
- a TruncatedNormal class

Closes #78, touches #120.
@tbrx I forgot you volunteered for this; want to work together? This is a very rough sketch at the moment. @fritzo it's not at the stage where I'd request a review, but comments are welcome as always.