Skip to content

Commit

Permalink
Type count
Browse files Browse the repository at this point in the history
  • Loading branch information
paddyroddy committed Oct 17, 2024
1 parent f165dcb commit 3d86e19
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
10 changes: 5 additions & 5 deletions glass/galaxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@


def redshifts(
n: int | npt.NDArray[np.float64],
n: int | list[int] | list[list[int]],
w: RadialWindow,
*,
rng: np.random.Generator | None = None,
Expand Down Expand Up @@ -67,7 +67,7 @@ def redshifts(


def redshifts_from_nz(
count: int | npt.NDArray[np.float64],
count: int | list[int] | list[list[int]],
z: list[float] | npt.NDArray[np.float64],
nz: list[float] | npt.NDArray[np.float64],
*,
Expand Down Expand Up @@ -132,12 +132,12 @@ def redshifts_from_nz(
cdf /= cdf[-1]

# sample redshifts and store result
redshifts[total : total + count[k]] = np.interp( # type: ignore[index]
rng.uniform(0, 1, size=count[k]), # type: ignore[index]
redshifts[total : total + count[k]] = np.interp( # type: ignore[call-overload, index]
rng.uniform(0, 1, size=count[k]), # type: ignore[call-overload, index]
cdf,
z[k], # type: ignore[call-overload]
)
total += count[k] # type: ignore[index]
total += count[k] # type: ignore[call-overload, index]

assert total == redshifts.size # noqa: S101

Expand Down
10 changes: 5 additions & 5 deletions tests/test_galaxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_redshifts(mocker) -> None: # type: ignore[no-untyped-def]
assert z.max() <= 1.0

# sample redshifts (array)
z = redshifts([[1, 2], [3, 4]], w) # type: ignore[arg-type]
z = redshifts([[1, 2], [3, 4]], w)
assert z.shape == (10,)


Expand Down Expand Up @@ -51,7 +51,7 @@ def test_redshifts_from_nz(rng: np.random.Generator) -> None:

# case: no extra dimensions

count = 10
count: int | list[int] | list[list[int]] = 10
z = np.linspace(0, 1, 100)
nz = z * (1 - z)

Expand All @@ -62,7 +62,7 @@ def test_redshifts_from_nz(rng: np.random.Generator) -> None:

# case: extra dimensions from count

count = [10, 20, 30] # type: ignore[assignment]
count = [10, 20, 30]
z = np.linspace(0, 1, 100)
nz = z * (1 - z)

Expand All @@ -82,7 +82,7 @@ def test_redshifts_from_nz(rng: np.random.Generator) -> None:

# case: extra dimensions from count and nz

count = [[10], [20], [30]] # type: ignore[assignment]
count = [[10], [20], [30]]
z = np.linspace(0, 1, 100)
nz = [z * (1 - z), (z - 0.5) ** 2] # type: ignore[assignment]

Expand All @@ -92,7 +92,7 @@ def test_redshifts_from_nz(rng: np.random.Generator) -> None:

# case: incompatible input shapes

count = [10, 20, 30] # type: ignore[assignment]
count = [10, 20, 30]
z = np.linspace(0, 1, 100)
nz = [z * (1 - z), (z - 0.5) ** 2] # type: ignore[assignment]

Expand Down

0 comments on commit 3d86e19

Please sign in to comment.