Skip to content

Commit

Permalink
catpos
Browse files Browse the repository at this point in the history
  • Loading branch information
paddyroddy committed Oct 16, 2024
1 parent c9b4b63 commit 57f2820
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 10 deletions.
6 changes: 5 additions & 1 deletion glass/points.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,11 @@ def positions_from_delta( # noqa: PLR0912, PLR0913, PLR0915
batch: int | None = 1_000_000,
rng: np.random.Generator | None = None,
) -> collections.abc.Generator[
tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.int_]]
tuple[
npt.NDArray[np.float64],
npt.NDArray[np.float64],
int | npt.NDArray[np.int_],
]
]:
"""
Generate positions tracing a density contrast.
Expand Down
34 changes: 25 additions & 9 deletions tests/test_points.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,30 @@
from __future__ import annotations

import typing

import numpy as np
import numpy.typing as npt

from glass.points import position_weights, positions_from_delta, uniform_positions

if typing.TYPE_CHECKING:
import collections.abc


def catpos(pos): # type: ignore[no-untyped-def]
def catpos( # type: ignore[no-untyped-def]
pos: collections.abc.Generator[
tuple[
npt.NDArray[np.float64],
npt.NDArray[np.float64],
int | npt.NDArray[np.int_],
]
],
):
lon, lat, cnt = [], [], 0 # type: ignore[var-annotated]
for lo, la, co in pos:
lon = np.concatenate([lon, lo]) # type: ignore[assignment]
lat = np.concatenate([lat, la]) # type: ignore[assignment]
cnt = cnt + co
cnt = cnt + co # type: ignore[assignment]
return lon, lat, cnt


Expand All @@ -20,7 +36,7 @@ def test_positions_from_delta() -> None:
bias = 0.8
vis = np.ones(12)

lon, lat, cnt = catpos(positions_from_delta(ngal, delta, bias, vis)) # type: ignore[no-untyped-call]
lon, lat, cnt = catpos(positions_from_delta(ngal, delta, bias, vis))

assert isinstance(cnt, int)
assert lon.shape == lat.shape == (cnt,)
Expand All @@ -32,7 +48,7 @@ def test_positions_from_delta() -> None:
bias = 0.8
vis = np.ones(12)

lon, lat, cnt = catpos(positions_from_delta(ngal, delta, bias, vis)) # type: ignore[no-untyped-call]
lon, lat, cnt = catpos(positions_from_delta(ngal, delta, bias, vis))

assert cnt.shape == (2,)
assert lon.shape == (cnt.sum(),)
Expand All @@ -45,7 +61,7 @@ def test_positions_from_delta() -> None:
bias = 0.8
vis = np.ones(12)

lon, lat, cnt = catpos(positions_from_delta(ngal, delta, bias, vis)) # type: ignore[no-untyped-call]
lon, lat, cnt = catpos(positions_from_delta(ngal, delta, bias, vis))

assert cnt.shape == (3, 2)
assert lon.shape == (cnt.sum(),)
Expand All @@ -58,7 +74,7 @@ def test_positions_from_delta() -> None:
bias = 0.8
vis = np.ones(12)

lon, lat, cnt = catpos(positions_from_delta(ngal, delta, bias, vis)) # type: ignore[no-untyped-call]
lon, lat, cnt = catpos(positions_from_delta(ngal, delta, bias, vis))

assert cnt.shape == (3, 2)
assert lon.shape == (cnt.sum(),)
Expand All @@ -70,7 +86,7 @@ def test_uniform_positions() -> None:

ngal = 1e-3

lon, lat, cnt = catpos(uniform_positions(ngal)) # type: ignore[no-untyped-call]
lon, lat, cnt = catpos(uniform_positions(ngal))

assert isinstance(cnt, int)
assert lon.shape == lat.shape == (cnt,)
Expand All @@ -79,7 +95,7 @@ def test_uniform_positions() -> None:

ngal = [1e-3, 2e-3, 3e-3] # type: ignore[assignment]

lon, lat, cnt = catpos(uniform_positions(ngal)) # type: ignore[no-untyped-call]
lon, lat, cnt = catpos(uniform_positions(ngal))

assert cnt.shape == (3,)
assert lon.shape == lat.shape == (cnt.sum(),)
Expand All @@ -88,7 +104,7 @@ def test_uniform_positions() -> None:

ngal = [[1e-3, 2e-3], [3e-3, 4e-3], [5e-3, 6e-3]] # type: ignore[assignment]

lon, lat, cnt = catpos(uniform_positions(ngal)) # type: ignore[no-untyped-call]
lon, lat, cnt = catpos(uniform_positions(ngal))

assert cnt.shape == (3, 2)
assert lon.shape == lat.shape == (cnt.sum(),)
Expand Down

0 comments on commit 57f2820

Please sign in to comment.