Add Cumulative Distribution Function, Inverse CDF methods to Distributions #122
Conversation
Adds CDF and inverse CDF methods to:
1. Cauchy
2. Exponential
3. Laplace (only CDF)
4. Pareto
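For context, each of these has a closed-form CDF. As an illustration of the pattern (a sketch in the torch.distributions method style with assumed loc/scale attributes; the actual diffs appear later in this conversation, so this is not the merged code):

# assumes `import math` and `import torch` at module level
def cdf(self, value):
    # Cauchy: F(x) = atan((x - loc) / scale) / pi + 1/2
    self._validate_log_prob_arg(value)
    return torch.atan((value - self.loc) / self.scale) / math.pi + 0.5

def icdf(self, value):
    # Cauchy: F^-1(p) = loc + scale * tan(pi * (p - 1/2))
    return self.loc + self.scale * torch.tan(math.pi * (value - 0.5))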
Looks good! I only have minor comments about testing.
set_rng_seed(0)  # see Note [Randomized statistical tests]
for pytorch_dist, scipy_dist in self.distribution_pairs:
    samples = pytorch_dist.sample((5,))
    try:
It's safer to enclose as little as needed in a try-except. Could you refactor to:

try:
    cdf = pytorch_dist.cdf(samples)
except NotImplementedError:
    continue
self.assertEqual(cdf, scipy_dist.cdf(samples), message=pytorch_dist)
Ah, yes. I saw the discussion in TruncatedNormal. I will modify it accordingly.
set_rng_seed(0)  # see Note [Randomized statistical tests]
for pytorch_dist, scipy_dist in self.distribution_pairs:
    samples = Variable(torch.rand((5,) + pytorch_dist.batch_shape))
    try:
ditto, enclose as little as possible in try-except
Sure.
torch/distributions/laplace.py (Outdated)
@@ -55,5 +55,13 @@ def log_prob(self, value):
        self._validate_log_prob_arg(value)
        return -torch.log(2 * self.scale) - torch.abs(value - self.loc) / self.scale

    def cdf(self, value):
No .icdf()?
Laplace's .cdf is a piecewise function. I was doubtful about adding an inverse, and later realized that the inverse could be piecewise as well. Will update this too.
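For the record, both branches can be folded into a single expression with sign(), which also gives a closed-form inverse. A sketch under the standard Laplace(loc, scale) parameterization (not this PR's code):

# CDF:  F(x)    = 0.5 - 0.5 * sign(x - loc) * expm1(-|x - loc| / scale)
# ICDF: F^-1(p) = loc - scale * sign(p - 0.5) * log1p(-2 * |p - 0.5|)
def cdf(self, value):
    term = (value - self.loc) / self.scale
    return 0.5 - 0.5 * term.sign() * torch.expm1(-term.abs())

def icdf(self, value):
    term = value - 0.5
    return self.loc - self.scale * term.sign() * torch.log1p(-2 * term.abs())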
@@ -2309,6 +2309,28 @@ def test_variance_stddev(self):
            self.assertEqual(pytorch_dist.variance, scipy_dist.var(), allow_inf=True, message=pytorch_dist)
            self.assertEqual(pytorch_dist.stddev, scipy_dist.var() ** 0.5, message=pytorch_dist)

    def test_cdf(self):
It would be nice to have an additional test that did not rely on scipy, e.g.

class TestDistributions(TestCase):
    def test_cdf_icdf(self):
        for Dist, params in EXAMPLES:
            for i, param in enumerate(params):
                dist = Dist(**param)
                samples = dist.sample(sample_shape=(20,))
                try:
                    cdf = dist.cdf(samples)
                    actual = dist.icdf(cdf)
                except NotImplementedError:
                    continue
                self.assertEqual(actual, samples,
                                 message='{} example {}/{}, icdf(cdf(x)) != x'.format(
                                     Dist.__name__, i + 1, len(params)))
or you could get even fancier by using grad(), like:

x = dist.sample(sample_shape=(20,))
x.requires_grad_()  # grad() needs x to require grad
expected_pdf = dist.log_prob(x).exp()
actual_pdf = grad(dist.cdf(x).sum(), [x])[0]
self.assertEqual(actual_pdf, expected_pdf)
Minor:
1. Convert Pareto and Gumbel to TransformedDistribution
2. Add .cdf and .icdf for Uniform (see the sketch below)
3. Temporarily remove .cdf from Laplace
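For item 2, the Uniform case is the simplest of the batch; a minimal sketch (assumed attribute names low/high, not necessarily the exact diff):

def cdf(self, value):
    # linear inside the support, clamped to [0, 1] outside it
    result = (value - self.low) / (self.high - self.low)
    return result.clamp(min=0, max=1)

def icdf(self, value):
    return self.low + value * (self.high - self.low)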
Three tests fail:
Looks great! Just one minor comment about eps and tiny, then it's ready to send upstream.
torch/distributions/gumbel.py (Outdated)
z = (value - self.loc) / self.scale
return -(self.scale.log() + z + torch.exp(-z))
base_dist = Uniform(torch.zeros_like(self.loc), 1)
transforms = [ExpTransform().inv, AffineTransform(loc=0, scale=-1),
Nice!
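The transforms list in the diff above is cut off. One consistent completion, assuming the standard Gumbel inverse-CDF X = loc - scale * log(-log(U)) for U ~ Uniform(0, 1) (an assumption, not the exact merged code):

transforms = [ExpTransform().inv,                                # u -> log(u)
              AffineTransform(loc=0, scale=-1),                  # -> -log(u)
              ExpTransform().inv,                                # -> log(-log(u))
              AffineTransform(loc=self.loc, scale=-self.scale)]  # -> loc - scale * log(-log(u))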
torch/distributions/gumbel.py (Outdated)
self._validate_log_prob_arg(value)
z = (value - self.loc) / self.scale
return -(self.scale.log() + z + torch.exp(-z))
base_dist = Uniform(torch.zeros_like(self.loc), 1)
Maybe we should avoid infinity, like:

finfo = _finfo(self.loc)
base_dist = Uniform(self.loc.new([finfo.tiny]).expand_as(self.loc), 1 - finfo.eps)
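The clamping matters because the transform chain takes a log twice, so both endpoints of the base interval diverge; a quick illustration (post-0.4 tensor syntax, shown only for the idea):

u = torch.tensor([0.0, 0.5, 1.0])
torch.log(-torch.log(u))  # -> [inf, -0.3665, -inf]; u = 0 and u = 1 both blow up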
Computes the inverse cumulative distribution function using transform(s) and computing
the score of the base distribution
"""
self.base_dist._validate_log_prob_arg(value)
I believe the base_dist.icdf() should call _validate_log_prob_arg(value) internally on the following line. Do you think it's worth having the extra check here? I'd be happy either way.
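For reference, the chaining being discussed would look roughly like this sketch (assuming monotonically increasing transforms; details may differ from the PR):

def icdf(self, value):
    # invert the base CDF first, then push the result forward
    # through each transform in order
    value = self.base_dist.icdf(value)
    for transform in self.transforms:
        value = transform(value)
    return value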
@vishwakftw Let me know if you want any help with the failing tests. I might have time today or tomorrow to help debug.
@fritzo I have fixed the shaping failures with the Gumbel distribution. There is one issue, however. Somehow the …
I also tried implementing the Laplace constructor as a TransformedDistribution:

def __init__(self, loc, scale):
    self.loc, self.scale = broadcast_all(loc, scale)
    finfo = _finfo(self.loc)
    if isinstance(loc, Number) and isinstance(scale, Number):
        base_dist = Uniform(finfo.eps - 1, 1)
    else:
        base_dist = Uniform(self.loc.new([finfo.eps]).expand_as(self.loc) - 1, 1)
    transforms = [AbsTransform(), AffineTransform(loc=1, scale=-1), ExpTransform().inv,
                  AffineTransform(loc=self.loc, scale=self.scale)]
    super(Laplace, self).__init__(base_dist, transforms)

I believe the sampling requires a SignTransform = AbsTransform / identity_transform.
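The missing sign is exactly the problem: AbsTransform sends u and -u to the same point, so the composed chain can only reach one tail. For comparison, the direct inverse-CDF sampler keeps the sign explicit (a sketch; loc, scale, and shape are placeholders):

u = torch.rand(shape) * 2 - 1                       # u ~ Uniform(-1, 1)
x = loc - scale * u.sign() * torch.log1p(-u.abs())  # sign(u) picks the tail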
LGTM. Feel free to send upstream.
Great. I am sending this upstream now!!
Work in parallel with PR #121.
cc: @fritzo @alicanb