Add (iid) truncated normal and also add grad to normal #542
base: main
Changes from 18 commits
@@ -1,48 +1,89 @@
import numpy as np
from scipy.special import erf
from cuqi.distribution import Distribution
from cuqi.distribution import Normal
class TruncatedNormal(Distribution):

Review comment: You made these really nice showcases in the PR @chaozg. I suggest you add one or two of them as "Examples" in the docstring!
Reply: Added an example in the docstring as suggested.
    """
    Truncated Normal probability distribution.

    Generates instance of cuqi.distribution.TruncatedNormal.
    It allows the user to specify upper and lower bounds on random variables
    represented by a Normal distribution. This distribution is suitable for a
    small dimension setup (e.g. `dim`=3 or 4). Using TruncatedNormal
    Distribution with a larger dimension can lead to a high rejection rate when
    used within MCMC samplers.

    The variables of this distribution are iid.
    Parameters
    ----------
    mean : float or array_like of floats
        mean of distribution
    std : float or array_like of floats
        standard deviation
    low : float or array_like of floats
        lower bound of the distribution
    high : float or array_like of floats
        upper bound of the distribution
    Example
    -------
    .. code-block:: python

        # Generate Normal with mean 0, standard deviation 1 and bounds [-2,2]
        p = cuqi.distribution.TruncatedNormal(mean=0, std=1, low=-2, high=2)
        samples = p.sample(5000)
    """
    def __init__(self, mean=None, std=None, low=-np.Inf, high=np.Inf, is_symmetric=False, **kwargs):
        # Init from abstract distribution class
        super().__init__(is_symmetric=is_symmetric, **kwargs)

        # Init specific to this distribution
        self.mean = mean
        self.std = std
        self.low = low
        self.high = high

        # Init underlying normal distribution
        self._normal = Normal(self.mean, self.std)
    def logpdf(self, x):
        """
        Computes the unnormalized logpdf at the given values of x.
        """
        # check whether x falls within the bounds [low, high]
        if np.any(x < self.low) or np.any(x > self.high):
            return -np.Inf
        return self._normal.logpdf(x)
    def gradient(self, x):
        """
        Computes the gradient of the unnormalized logpdf at the given values of x.
        """
        # check whether x falls within the bounds [low, high]
        if np.any(x < self.low) or np.any(x > self.high):
            return np.NaN*np.ones_like(x)
        return self._normal.gradient(x)
    def _sample(self, N=1, rng=None):
        """
        Generates random samples from the distribution.
        """
        # FIXME: this implementation does not honor the rng
        max_iter = 1e9 # maximum number of trials to avoid an infinite loop
        samples = []
        for i in range(int(max_iter)):
            if len(samples) == N:
                break
            sample = self._normal.sample()
            if np.all(sample >= self.low) and np.all(sample <= self.high):
                samples.append(sample)
            # raise an error if the number of iterations exceeds max_iter
            if i == max_iter-1:
                raise RuntimeError("Failed to generate {} samples within {} iterations".format(N, max_iter))
        return np.array(samples).T.reshape(-1, N)

Review comment: Perhaps for the direct sampling of the distribution we could use simple "Rejection" sampling? If asked for "N" samples, we sample a Gaussian, say "N" times, and remove out-of-bounds samples; then we sample another "N" times and remove out-of-bounds samples. If we now have "N" or more samples we return the first N; otherwise we repeat until we get "N" samples. This could also be used to compare with the samplers you showed in the code.
Reply: I just added
Review comment: just a comment here, would it be possible just to pass the rng to
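The batched rejection scheme suggested in the review (draw N proposals per round, keep the in-bounds ones, repeat) can be sketched with plain NumPy as follows. The function name and vectorized layout here are illustrative, not part of the cuqi API:

```python
import numpy as np

def rejection_sample_truncated_normal(mean, std, low, high, N, rng=None):
    """Batched rejection sampling: draw N Normal proposals per round,
    keep those inside [low, high], and repeat until N samples remain."""
    rng = np.random.default_rng() if rng is None else rng
    mean = np.atleast_1d(np.asarray(mean, dtype=float))
    std = np.atleast_1d(np.asarray(std, dtype=float))
    dim = mean.shape[0]
    kept = []
    n_kept = 0
    while n_kept < N:
        proposals = rng.normal(mean, std, size=(N, dim))  # N proposals per round
        in_bounds = np.all((proposals >= low) & (proposals <= high), axis=1)
        kept.append(proposals[in_bounds])
        n_kept += int(in_bounds.sum())
    return np.concatenate(kept)[:N].T  # shape (dim, N), matching _sample's layout
```

Passing a `numpy.random.Generator` explicitly would also address the FIXME about the sampler not honoring `rng`.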
@@ -28,6 +28,73 @@ def test_Normal_sample_regression(mean,var,expected):
    target = np.array(expected).T
    assert np.allclose( samples.samples, target)

@pytest.mark.parametrize("mean,std,points,expected",[
    (0,1,[-1,0,1],[1,0,-1]),
    (np.array([0,0]),np.array([1,1]),[[-1,0],[0,0],[0,-1]], [[1,0],[0,0],[0,1]])])
def test_Normal_gradient(mean,std,points,expected):
    p = cuqi.distribution.Normal(mean,std)
    for point, grad in zip(points, expected):
        assert np.allclose(p.gradient(point), grad)
||
@pytest.mark.parametrize("mean,std,low,high,points",[(0.0, | ||
1.0, | ||
-1.0, | ||
1.0, | ||
[-1.5, -0.5, 0.5, 1.5]), | ||
(np.array([0.0, 0.0]), | ||
np.array([1.0, 1.0]), | ||
np.array([-1.0, -1.0]), | ||
np.array([1.0, 1.0]), | ||
[np.array([-0.5, 0.0]), | ||
np.array([0.5, 0.0]), | ||
np.array([-2.0, 0.0]), | ||
np.array([2.0, 0.0])])]) | ||
def test_TruncatedNormal_logpdf(mean,std,low,high,points): | ||
x_trun = cuqi.distribution.TruncatedNormal(mean,std,low=low,high=high) | ||
x = cuqi.distribution.Normal(mean,std) | ||
for point in points: | ||
if np.all(point >= low) and np.all(point <= high): | ||
assert x_trun.logpdf(point) == approx(x.logpdf(point)) | ||
else: | ||
assert np.isinf(x_trun.logpdf(point)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. a suggestion to use |
||
|
||
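Since the logpdf is unnormalized (inside the bounds it is just the underlying Normal logpdf), it can also be cross-checked against `scipy.stats.truncnorm`, which should differ from it only by the constant log-normalization. A possible sketch, not part of the test suite:

```python
import numpy as np
from scipy.stats import norm, truncnorm

mean, std, low, high = 0.0, 1.0, -1.0, 1.0
# scipy parametrizes the truncation bounds in standardized units
a, b = (low - mean) / std, (high - mean) / std

pts = np.array([-0.5, 0.0, 0.5])
unnormalized = norm.logpdf(pts, loc=mean, scale=std)           # in-bounds TruncatedNormal.logpdf
normalized = truncnorm.logpdf(pts, a, b, loc=mean, scale=std)  # exact truncated logpdf

# The two differ only by the constant log(Phi(b) - Phi(a))
log_Z = np.log(norm.cdf(high, mean, std) - norm.cdf(low, mean, std))
assert np.allclose(normalized - unnormalized, -log_Z)
```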
@pytest.mark.parametrize("mean,std,low,high,points",[
    (0.0, 1.0, -1.0, 1.0,
     [-1.5, -0.5, 0.5, 1.5]),
    (np.array([0.0, 0.0]),
     np.array([1.0, 1.0]),
     np.array([-1.0, -1.0]),
     np.array([1.0, 1.0]),
     [np.array([-0.5, 0.0]),
      np.array([0.5, 0.0]),
      np.array([-2.0, 0.0]),
      np.array([2.0, 0.0])])])
def test_TruncatedNormal_gradient(mean,std,low,high,points):
    x_trun = cuqi.distribution.TruncatedNormal(mean,std,low=low,high=high)
    x = cuqi.distribution.Normal(mean,std)
    for point in points:
        if np.all(point >= low) and np.all(point <= high):
            assert np.all(x_trun.gradient(point) == approx(x.gradient(point)))
        else:
            assert np.all(np.isnan(x_trun.gradient(point)))
@pytest.mark.parametrize("mean,std,low,high",[
    (0.0, 1.0, -1.0, 1.0),
    (np.array([0.0, 0.0]),
     np.array([1.0, 1.0]),
     np.array([-1.0, -1.0]),
     np.array([1.0, 1.0]))])
def test_TruncatedNormal_sampling(mean,std,low,high):
    x = cuqi.distribution.TruncatedNormal(mean,std,low=low,high=high)
    samples = x.sample(10000).samples
    for i in range(samples.shape[1]):
        sample = samples[:,i]
        assert np.all(sample >= low) and np.all(sample <= high)
def test_Gaussian_mean():
    mean = np.array([0, 0])
    std = np.array([1, 1])
Review comment: It is a nice bonus that we have a gradient for Normal now :). Just one suggestion: add two checks:
1. The geometry is not a geometry that needs a chain rule.
2. The distribution is used as a prior, not a likelihood, because we do not account for the chain rule here.
For example, the gradient implementation of cmrf does these two checks.