Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Cumulative Distribution Function, Inverse CDF methods to Distributions #122

Closed
wants to merge 9 commits into the base branch from the contributor's branch
22 changes: 22 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2309,6 +2309,28 @@ def test_variance_stddev(self):
self.assertEqual(pytorch_dist.variance, scipy_dist.var(), allow_inf=True, message=pytorch_dist)
self.assertEqual(pytorch_dist.stddev, scipy_dist.var() ** 0.5, message=pytorch_dist)

def test_cdf(self):
    """Check each distribution's cdf() against the scipy reference.

    Distributions that do not implement cdf() raise NotImplementedError
    and are skipped.
    """
    set_rng_seed(0)  # see Note [Randomized statistical tests]
    for pytorch_dist, scipy_dist in self.distribution_pairs:
        samples = pytorch_dist.sample((5,))
        # Enclose only the call that may raise NotImplementedError, so a
        # genuine assertion failure is never swallowed by the except clause.
        try:
            cdf = pytorch_dist.cdf(samples)
        except NotImplementedError:
            continue
        self.assertEqual(cdf, scipy_dist.cdf(samples), message=pytorch_dist)

def test_icdf(self):
    """Check each distribution's icdf() against scipy's ppf.

    Distributions that do not implement icdf() raise NotImplementedError
    and are skipped.
    """
    set_rng_seed(0)  # see Note [Randomized statistical tests]
    for pytorch_dist, scipy_dist in self.distribution_pairs:
        samples = Variable(torch.rand((5,) + pytorch_dist.batch_shape))
        # Enclose only the call that may raise NotImplementedError, so a
        # genuine assertion failure is never swallowed by the except clause.
        try:
            icdf = pytorch_dist.icdf(samples)
        except NotImplementedError:
            continue
        self.assertEqual(icdf, scipy_dist.ppf(samples), message=pytorch_dist)


class TestTransforms(TestCase):
def setUp(self):
Expand Down
8 changes: 8 additions & 0 deletions torch/distributions/cauchy.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,13 @@ def log_prob(self, value):
self._validate_log_prob_arg(value)
return -math.log(math.pi) - self.scale.log() - (1 + ((value - self.loc) / self.scale)**2).log()

def cdf(self, value):
    """Cumulative distribution function of the Cauchy distribution.

    F(x) = arctan((x - loc) / scale) / pi + 1/2
    """
    self._validate_log_prob_arg(value)
    standardized = (value - self.loc) / self.scale
    return torch.atan(standardized) / math.pi + 0.5

def icdf(self, value):
    """Inverse CDF (quantile function) of the Cauchy distribution.

    F^-1(p) = loc + scale * tan(pi * (p - 1/2))
    """
    self._validate_log_prob_arg(value)
    angle = math.pi * (value - 0.5)
    return torch.tan(angle) * self.scale + self.loc

def entropy(self):
    """Differential entropy of the Cauchy distribution: log(4*pi*scale)."""
    log_four_pi = math.log(4 * math.pi)
    return log_four_pi + self.scale.log()
20 changes: 20 additions & 0 deletions torch/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,26 @@ def log_prob(self, value):
"""
raise NotImplementedError

def cdf(self, value):
    """
    Returns the cumulative distribution function evaluated at
    `value`.

    Args:
        value (Tensor or Variable): the point(s) at which to evaluate
            the CDF.

    Raises:
        NotImplementedError: subclasses with a closed-form CDF override
            this method; the base implementation always raises.
    """
    raise NotImplementedError

def icdf(self, value):
    """
    Returns the inverse cumulative distribution function (quantile
    function) evaluated at `value`.

    Args:
        value (Tensor or Variable): probability value(s), presumably in
            [0, 1] — enforced by subclasses, not here.

    Raises:
        NotImplementedError: subclasses with a closed-form inverse CDF
            override this method; the base implementation always raises.
    """
    raise NotImplementedError

def enumerate_support(self):
"""
Returns tensor containing all values supported by a discrete
Expand Down
8 changes: 8 additions & 0 deletions torch/distributions/exponential.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,13 @@ def log_prob(self, value):
self._validate_log_prob_arg(value)
return self.rate.log() - self.rate * value

def cdf(self, value):
    """Cumulative distribution function of the Exponential distribution.

    F(x) = 1 - exp(-rate * x)
    """
    self._validate_log_prob_arg(value)
    # -expm1(-x) computes 1 - exp(-x) without catastrophic cancellation
    # when rate * value is small.
    return -torch.expm1(-self.rate * value)

def icdf(self, value):
    """Inverse CDF (quantile function) of the Exponential distribution.

    F^-1(p) = -log(1 - p) / rate
    """
    self._validate_log_prob_arg(value)
    # log1p(-p) is more accurate than log(1 - p) for p close to 0.
    return -torch.log1p(-value) / self.rate

def entropy(self):
    """Differential entropy of the Exponential distribution: 1 - log(rate)."""
    return 1.0 - self.rate.log()
8 changes: 8 additions & 0 deletions torch/distributions/laplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,13 @@ def log_prob(self, value):
self._validate_log_prob_arg(value)
return -torch.log(2 * self.scale) - torch.abs(value - self.loc) / self.scale

def cdf(self, value):
    # Piecewise Laplace CDF:
    #   F(x) = 0.5 * exp((x - loc) / scale)          for x <  loc
    #   F(x) = 1 - 0.5 * exp(-(x - loc) / scale)     for x >= loc
    # NOTE(review): no matching .icdf() is defined here; the inverse is
    # also expressible piecewise and could be added for consistency with
    # the other distributions in this change.
    self._validate_log_prob_arg(value)
    # term = exp((x - loc) / scale); note 1/term == exp(-(x - loc)/scale).
    term = torch.exp((value - self.loc) / self.scale)
    # NOTE(review): value.new() allocates an *uninitialized* tensor that
    # the two masked assignments below are expected to fill; presumably
    # this relies on legacy masked-assignment resizing semantics, and the
    # right-hand sides are full-size tensors, not mask-selected — confirm
    # the shapes line up on the targeted torch version.
    result = value.new()
    result[value < self.loc] = 0.5 * term
    result[value >= self.loc] = 1 - 0.5 / term
    return result

def entropy(self):
    """Differential entropy of the Laplace distribution: 1 + log(2*scale)."""
    return 1 + (2 * self.scale).log()
8 changes: 8 additions & 0 deletions torch/distributions/normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,13 @@ def log_prob(self, value):
log_scale = math.log(self.scale) if isinstance(self.scale, Number) else self.scale.log()
return -((value - self.loc) ** 2) / (2 * var) - log_scale - math.log(math.sqrt(2 * math.pi))

def cdf(self, value):
    """Cumulative distribution function of the Normal distribution.

    F(x) = 0.5 * (1 + erf((x - loc) / (scale * sqrt(2))))
    """
    self._validate_log_prob_arg(value)
    # Divide by scale instead of calling scale.reciprocal(), so that a
    # plain Python Number scale (explicitly supported by log_prob above)
    # works here as well as a tensor scale.
    return 0.5 * (1 + torch.erf((value - self.loc) / (self.scale * math.sqrt(2))))

def icdf(self, value):
    """Inverse CDF (quantile function) of the Normal distribution.

    F^-1(p) = loc + scale * sqrt(2) * erfinv(2p - 1)
    """
    self._validate_log_prob_arg(value)
    standardized = torch.erfinv(2 * value - 1) * math.sqrt(2)
    return self.loc + self.scale * standardized

def entropy(self):
    """Differential entropy of the Normal distribution.

    H = 0.5 + 0.5 * log(2*pi) + log(scale)
    """
    return 0.5 + 0.5 * math.log(2 * math.pi) + self.scale.log()
8 changes: 8 additions & 0 deletions torch/distributions/pareto.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,5 +59,13 @@ def log_prob(self, value):
self._validate_log_prob_arg(value)
return torch.log(self.alpha / value) + self.alpha * (self.scale / value).log()

def cdf(self, value):
    """Cumulative distribution function of the Pareto distribution.

    F(x) = 1 - (scale / x) ** alpha   for x >= scale
    """
    self._validate_log_prob_arg(value)
    ratio = self.scale / value
    return 1 - ratio.pow(self.alpha)

def icdf(self, value):
    """Inverse CDF (quantile function) of the Pareto distribution.

    F^-1(p) = scale / (1 - p) ** (1 / alpha)
    """
    self._validate_log_prob_arg(value)
    tail = (1 - value).pow(self.alpha.reciprocal())
    return self.scale / tail

def entropy(self):
    """Differential entropy of the Pareto distribution.

    H = log(scale / alpha) + 1 + 1/alpha
    """
    ratio = self.scale / self.alpha
    return ratio.log() + (1 + self.alpha.reciprocal())