Skip to content

Commit

Permalink
add bounds to gaussian_phz()
Browse files Browse the repository at this point in the history
  • Loading branch information
ntessore committed Jul 20, 2023
1 parent d2dc1e5 commit 260cc74
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 5 deletions.
31 changes: 26 additions & 5 deletions glass/galaxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,9 @@ def galaxy_shear(lon: np.ndarray, lat: np.ndarray, eps: np.ndarray,
return g


def gaussian_phz(z: ArrayLike, sigma_0: float | ArrayLike,
def gaussian_phz(z: ArrayLike, sigma_0: float | ArrayLike, *,
lower: ArrayLike | None = None,
upper: ArrayLike | None = None,
rng: np.random.Generator | None = None) -> np.ndarray:
r'''Photometric redshifts assuming a Gaussian error.
Expand All @@ -159,6 +161,8 @@ def gaussian_phz(z: ArrayLike, sigma_0: float | ArrayLike,
True redshifts.
sigma_0 : float or array_like
Redshift error in the tomographic binning at zero redshift.
lower, upper : float or array_like, optional
Bounds for the returned photometric redshifts.
rng : :class:`~numpy.random.Generator`, optional
Random number generator. If not given, a default RNG is used.
Expand All @@ -168,6 +172,12 @@ def gaussian_phz(z: ArrayLike, sigma_0: float | ArrayLike,
Photometric redshifts assuming Gaussian errors, of the same
shape as *z*.
Warnings
--------
The *lower* and *upper* bounds are implemented using plain rejection
sampling from the non-truncated normal distribution. If bounds are
used, they should always contain significant probability mass.
See Also
--------
glass.observations.tomo_nz_gausserr :
Expand All @@ -194,14 +204,25 @@ def gaussian_phz(z: ArrayLike, sigma_0: float | ArrayLike,

zphot = rng.normal(z, sigma)

print(zphot)

if lower is None:
lower = 0.
if upper is None:
upper = np.inf

if not np.all(lower < upper):
raise ValueError("requires lower < upper")

if not dims:
while zphot < 0:
while zphot < lower or zphot > upper:
zphot = rng.normal(z, sigma)
else:
z = np.broadcast_to(z, dims)
trunc = np.where(zphot < 0)[0]
trunc = np.where((zphot < lower) | (zphot > upper))[0]
while trunc.size:
zphot[trunc] = rng.normal(z[trunc], sigma[trunc])
trunc = trunc[zphot[trunc] < 0]
znew = rng.normal(z[trunc], sigma[trunc])
zphot[trunc] = znew
trunc = trunc[(znew < lower) | (znew > upper)]

return zphot
11 changes: 11 additions & 0 deletions glass/test/test_galaxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,17 @@ def test_gaussian_phz():
assert phz.shape == (100,)
assert np.all(phz >= 0)

# case: upper and lower bound

z = 1.
sigma_0 = np.ones(100)

phz = gaussian_phz(z, sigma_0, lower=0.5, upper=1.5)

assert phz.shape == (100,)
assert np.all(phz >= 0.5)
assert np.all(phz <= 1.5)

# test interface

# case: scalar redshift, scalar sigma_0
Expand Down

0 comments on commit 260cc74

Please sign in to comment.