gh-358: add static types support (#368)
paddyroddy authored Nov 7, 2024
1 parent 4137616 commit 74c5ac8
Showing 26 changed files with 787 additions and 449 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -9,3 +9,4 @@ dist
.env
.coverage*
coverage*
.ipynb_checkpoints
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -69,5 +69,6 @@ repos:
rev: v1.13.0
hooks:
- id: mypy
files: ^glass/
additional_dependencies:
- numpy
2 changes: 1 addition & 1 deletion docs/conf.py
@@ -54,7 +54,7 @@

html_logo = "_static/logo.png"
html_favicon = "_static/favicon.ico"
html_css_files = [] # type: ignore[var-annotated]
html_css_files: list[str] = []


# -- Intersphinx -------------------------------------------------------------
6 changes: 3 additions & 3 deletions glass/core/algorithm.py
@@ -7,12 +7,12 @@


def nnls(
a: npt.ArrayLike,
b: npt.ArrayLike,
a: npt.NDArray[np.float64],
b: npt.NDArray[np.float64],
*,
tol: float = 0.0,
maxiter: int | None = None,
) -> npt.ArrayLike:
) -> npt.NDArray[np.float64]:
"""
Compute a non-negative least squares solution.
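For context, a minimal usage sketch of the newly typed `nnls` signature; the arrays below are purely illustrative:

```python
# Minimal sketch of calling the typed nnls; values are illustrative only.
import numpy as np

from glass.core.algorithm import nnls

a = np.array([[1.0, 0.0], [1.0, 1.0], [0.0, 2.0]])
b = np.array([1.0, 2.0, 1.0])

x = nnls(a, b)          # non-negative least-squares solution, float64
assert np.all(x >= 0)
```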
54 changes: 48 additions & 6 deletions glass/core/array.py
@@ -1,18 +1,35 @@
"""Module for array utilities."""

from __future__ import annotations

import typing
from functools import partial

import numpy as np
import numpy.typing as npt

if typing.TYPE_CHECKING:
import collections.abc


def broadcast_first(*arrays): # type: ignore[no-untyped-def]
def broadcast_first(
*arrays: npt.NDArray[np.float64],
) -> tuple[npt.NDArray[np.float64], ...]:
"""Broadcast arrays, treating the first axis as common."""
arrays = tuple(np.moveaxis(a, 0, -1) if np.ndim(a) else a for a in arrays)
arrays = np.broadcast_arrays(*arrays)
return tuple(np.moveaxis(a, -1, 0) if np.ndim(a) else a for a in arrays)
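A short sketch of `broadcast_first` with assumed shapes: the first axis is treated as the common one, so only the trailing axes are broadcast against each other.

```python
# Hedged example of broadcast_first; shapes are assumptions for illustration.
import numpy as np

from glass.core.array import broadcast_first

a = np.zeros((10, 3))
b = np.zeros((10,))

a2, b2 = broadcast_first(a, b)   # both come back with shape (10, 3)
```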


def broadcast_leading_axes(*args): # type: ignore[no-untyped-def]
def broadcast_leading_axes(
*args: tuple[
float | npt.NDArray[np.float64],
int,
],
) -> tuple[
tuple[int, ...],
typing.Unpack[tuple[npt.NDArray[np.float64], ...]],
]:
"""
Broadcast all but the last N axes.
@@ -49,7 +66,15 @@ def broadcast_leading_axes(*args): # type: ignore[no-untyped-def]
return (dims, *arrs)
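A hedged sketch of `broadcast_leading_axes`: each argument is an `(array, ndim)` pair whose last `ndim` axes are kept as data axes while the leading axes are broadcast together; the shapes below follow from ordinary NumPy broadcasting.

```python
# Sketch of broadcast_leading_axes under the documented behaviour;
# the concrete shapes are illustrative assumptions.
import numpy as np

from glass.core.array import broadcast_leading_axes

a = 0.0                     # scalar, no trailing data axes
b = np.zeros((4, 10))       # last 1 axis is data
c = np.zeros((3, 1, 5, 6))  # last 2 axes are data

dims, a2, b2, c2 = broadcast_leading_axes((a, 0), (b, 1), (c, 2))
# expected: dims == (3, 4), b2.shape == (3, 4, 10), c2.shape == (3, 4, 5, 6)
```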


def ndinterp(x, xp, fp, axis=-1, left=None, right=None, period=None): # type: ignore[no-untyped-def] # noqa: PLR0913
def ndinterp( # noqa: PLR0913
x: float | npt.NDArray[np.float64],
xp: collections.abc.Sequence[float] | npt.NDArray[np.float64],
fp: collections.abc.Sequence[float] | npt.NDArray[np.float64],
axis: int = -1,
left: float | None = None,
right: float | None = None,
period: float | None = None,
) -> npt.NDArray[np.float64]:
"""Interpolate multi-dimensional array over axis."""
return np.apply_along_axis(
partial(np.interp, x, xp),
@@ -61,8 +86,16 @@ def ndinterp(x, xp, fp, axis=-1, left=None, right=None, period=None): # type: i
)
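A hedged example of `ndinterp`: interpolate a 2-D table of function values along its last axis at a single point (values are illustrative).

```python
# Sketch of ndinterp; the table below is an illustrative assumption.
import numpy as np

from glass.core.array import ndinterp

xp = np.linspace(0.0, 1.0, 5)
fp = np.stack([xp, xp**2])      # shape (2, 5), sampled along axis=-1

y = ndinterp(0.25, xp, fp)      # shape (2,): one interpolated value per row
```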


def trapz_product(f, *ff, axis=-1): # type: ignore[no-untyped-def]
def trapz_product(
f: tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]],
*ff: tuple[
npt.NDArray[np.float64],
npt.NDArray[np.float64],
],
axis: int = -1,
) -> npt.NDArray[np.float64]:
"""Trapezoidal rule for a product of functions."""
x: npt.NDArray[np.float64]
x, _ = f
for x_, _ in ff:
x = np.union1d(
@@ -72,10 +105,19 @@ def trapz_product(f, *ff, axis=-1): # type: ignore[no-untyped-def]
y = np.interp(x, *f)
for f_ in ff:
y *= np.interp(x, *f_)
return np.trapz(y, x, axis=axis) # type: ignore[attr-defined]
return np.trapz( # type: ignore[attr-defined, no-any-return]
y,
x,
axis=axis,
)
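A sketch of `trapz_product`, which integrates the product of two tabulated functions over the union of their grids; the toy functions here are f(x) = 1 and g(x) = x.

```python
# Hedged example of trapz_product with illustrative toy functions.
import numpy as np

from glass.core.array import trapz_product

x1 = np.linspace(0.0, 1.0, 11)
f = (x1, np.ones_like(x1))

x2 = np.linspace(0.0, 1.0, 21)
g = (x2, x2)

result = trapz_product(f, g)    # close to the exact integral of x on [0, 1], i.e. 0.5
```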


def cumtrapz(f, x, dtype=None, out=None): # type: ignore[no-untyped-def]
def cumtrapz(
f: npt.NDArray[np.int_] | npt.NDArray[np.float64],
x: npt.NDArray[np.int_] | npt.NDArray[np.float64],
dtype: npt.DTypeLike | None = None,
out: npt.NDArray[np.float64] | None = None,
) -> npt.NDArray[np.float64]:
"""Cumulative trapezoidal rule along last axis."""
if out is None:
out = np.empty_like(f, dtype=dtype)
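A hedged sketch of `cumtrapz`, the cumulative trapezoidal rule along the last axis; the grid and integrand below are simple assumptions.

```python
# Sketch of cumtrapz with an illustrative 1-D grid and integrand.
import numpy as np

from glass.core.array import cumtrapz

x = np.linspace(0.0, 1.0, 5)
f = x**2

c = cumtrapz(f, x)              # running integral of f, same shape as f
```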
2 changes: 1 addition & 1 deletion glass/ext/__init__.py
Expand Up @@ -7,7 +7,7 @@
"""


def _extend_path(path, name) -> list: # type: ignore[no-untyped-def, type-arg]
def _extend_path(path: list[str], name: str) -> list[str]:
import os.path
from pkgutil import extend_path

98 changes: 51 additions & 47 deletions glass/fields.py
@@ -35,32 +35,20 @@
import numpy.typing as npt
from gaussiancl import gaussiancl

# types
Size = typing.Optional[typing.Union[int, tuple[int, ...]]]
Iternorm = tuple[typing.Optional[int], npt.NDArray[typing.Any], npt.NDArray[typing.Any]]
ClTransform = typing.Union[
str,
typing.Callable[[npt.NDArray[typing.Any]], npt.NDArray[typing.Any]],
]
Cls = collections.abc.Sequence[
typing.Union[npt.NDArray[typing.Any], collections.abc.Sequence[float]]
typing.Union[npt.NDArray[np.float64], collections.abc.Sequence[float]]
]
Alms = npt.NDArray[typing.Any]


def iternorm(
k: int,
cov: collections.abc.Iterable[npt.NDArray[typing.Any]],
size: Size = None,
) -> collections.abc.Generator[Iternorm, None, None]:
cov: collections.abc.Iterable[npt.NDArray[np.float64]],
size: int | tuple[int, ...] = (),
) -> collections.abc.Generator[
tuple[int | None, npt.NDArray[np.float64], npt.NDArray[np.float64]]
]:
"""Return the vector a and variance sigma^2 for iterative normal sampling."""
n: tuple[int, ...]
if size is None:
n = ()
elif isinstance(size, int):
n = (size,)
else:
n = size
n = (size,) if isinstance(size, int) else size

m = np.zeros((*n, k, k))
a = np.zeros((*n, k))
@@ -115,27 +103,29 @@ def cls2cov(
nl: int,
nf: int,
nc: int,
) -> collections.abc.Generator[npt.NDArray[typing.Any], None, None]:
) -> collections.abc.Generator[npt.NDArray[np.float64]]:
"""Return array of cls as a covariance matrix for iterative sampling."""
cov = np.zeros((nl, nc + 1))
end = 0
for j in range(nf):
begin, end = end, end + j + 1
for i, cl in enumerate(cls[begin:end][: nc + 1]):
if cl is None:
cov[:, i] = 0 # type: ignore[unreachable]
else:
if i == 0 and np.any(np.less(cl, 0)):
msg = "negative values in cl"
raise ValueError(msg)
n = len(cl)
cov[:n, i] = cl
cov[n:, i] = 0
if i == 0 and np.any(np.less(cl, 0)):
msg = "negative values in cl"
raise ValueError(msg)
n = len(cl)
cov[:n, i] = cl
cov[n:, i] = 0
cov /= 2
yield cov


def multalm(alm: Alms, bl: npt.NDArray[typing.Any], *, inplace: bool = False) -> Alms:
def multalm(
alm: npt.NDArray[np.complex128],
bl: npt.NDArray[np.float64],
*,
inplace: bool = False,
) -> npt.NDArray[np.complex128]:
"""Multiply alm by bl."""
n = len(bl)
out = np.asanyarray(alm) if inplace else np.copy(alm)
@@ -144,11 +134,15 @@ def multalm(alm: Alms, bl: npt.NDArray[typing.Any], *, inplace: bool = False) ->
return out
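A minimal sketch of `multalm`, which scales a set of alm coefficients by an ell-dependent factor; the lmax and values below are illustrative, assuming the usual healpy alm ordering.

```python
# Hedged example of multalm (lmax = 2, so 6 coefficients; toy values).
import numpy as np

from glass.fields import multalm

alm = np.ones(6, dtype=np.complex128)
bl = np.array([1.0, 0.5, 0.25])

scaled = multalm(alm, bl)
```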


def transform_cls(cls: Cls, tfm: ClTransform, pars: tuple[typing.Any, ...] = ()) -> Cls:
def transform_cls(
cls: Cls,
tfm: str | typing.Callable[[npt.NDArray[np.float64]], npt.NDArray[np.float64]],
pars: tuple[typing.Any, ...] = (),
) -> Cls:
"""Transform Cls to Gaussian Cls."""
gls = []
for cl in cls:
if cl is not None and len(cl) > 0: # type: ignore[redundant-expr]
if len(cl) > 0:
monopole = 0.0 if cl[0] == 0 else None
gl, info, _, _ = gaussiancl(cl, tfm, pars, monopole=monopole)
if info == 0:
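A hedged sketch of `transform_cls`, assuming the "lognormal" transform name and a shift parameter as accepted by the gaussiancl package; the spectrum is a toy.

```python
# Sketch of transform_cls; the spectrum and shift parameter are assumptions.
import numpy as np

from glass.fields import transform_cls

cl = np.concatenate([[0.0], 1.0e-4 / np.arange(1, 100, dtype=np.float64) ** 2])
(gl,) = transform_cls([cl], "lognormal", (1.0,))
```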
@@ -194,7 +188,7 @@ def discretized_cls(

gls = []
for cl in cls:
if cl is not None and len(cl) > 0: # type: ignore[redundant-expr]
if len(cl) > 0:
if lmax is not None:
cl = cl[: lmax + 1] # noqa: PLW2901
if nside is not None:
@@ -218,7 +212,7 @@ def generate_gaussian(
*,
ncorr: int | None = None,
rng: np.random.Generator | None = None,
) -> collections.abc.Generator[npt.NDArray[typing.Any], None, None]:
) -> collections.abc.Generator[npt.NDArray[np.float64]]:
"""
Sample Gaussian random fields from Cls iteratively.
@@ -255,7 +249,7 @@ def generate_gaussian(
ncorr = ngrf - 1

# number of modes
n = max((len(gl) for gl in gls if gl is not None), default=0) # type: ignore[redundant-expr]
n = max((len(gl) for gl in gls), default=0)
if n == 0:
msg = "all gls are empty"
raise ValueError(msg)
@@ -304,7 +298,7 @@ def generate_lognormal(
*,
ncorr: int | None = None,
rng: np.random.Generator | None = None,
) -> collections.abc.Generator[npt.NDArray[typing.Any], None, None]:
) -> collections.abc.Generator[npt.NDArray[np.float64]]:
"""Sample lognormal random fields from Gaussian Cls iteratively."""
for i, m in enumerate(generate_gaussian(gls, nside, ncorr=ncorr, rng=rng)):
# compute the variance of the auto-correlation
@@ -326,7 +320,14 @@ def generate_lognormal(
yield m
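A hedged sketch of iterating over `generate_gaussian` (`generate_lognormal` follows the same calling pattern); this needs healpy installed, and the single toy spectrum and nside are illustrative only.

```python
# Sketch of generate_gaussian for a single field; spectrum and nside are assumptions.
import numpy as np

from glass.fields import generate_gaussian

nside = 32
gls = [np.full(3 * nside, 1.0e-4)]      # one auto-spectrum for a single field

for m in generate_gaussian(gls, nside):
    print(m.shape)                       # one HEALPix map of 12 * nside**2 pixels
```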


def getcl(cls, i, j, lmax=None): # type: ignore[no-untyped-def]
def getcl(
cls: collections.abc.Sequence[
npt.NDArray[np.float64] | collections.abc.Sequence[float]
],
i: int,
j: int,
lmax: int | None = None,
) -> npt.NDArray[np.float64] | collections.abc.Sequence[float]:
"""
Return a specific angular power spectrum from an array.
@@ -356,12 +357,14 @@ def getcl(cls, i, j, lmax=None): # type: ignore[no-untyped-def]
return cl
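A hedged example of `getcl`, assuming the diagonal ordering of spectra used throughout this module (for two fields: cl00, cl11, cl10); the spectra are toys.

```python
# Sketch of getcl; the spectra and their ordering are illustrative assumptions.
import numpy as np

from glass.fields import getcl

cls = [np.full(10, 1.0), np.full(10, 0.5), np.full(10, 0.25)]

cl01 = getcl(cls, 0, 1)          # same as getcl(cls, 1, 0)
cl00 = getcl(cls, 0, 0, lmax=4)  # truncated (or padded) to lmax + 1 entries
```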


def effective_cls( # type: ignore[no-untyped-def]
cls,
weights1,
weights2=None,
def effective_cls(
cls: collections.abc.Sequence[
npt.NDArray[np.float64] | collections.abc.Sequence[float]
],
weights1: npt.NDArray[np.float64],
weights2: npt.NDArray[np.float64] | None = None,
*,
lmax=None,
lmax: int | None = None,
) -> npt.NDArray[np.float64]:
r"""
Compute effective angular power spectra from weights.
@@ -411,10 +414,11 @@ def effective_cls( # type: ignore[no-untyped-def]

# get the iterator over leading weight axes
# auto-spectra do not repeat identical computations
if weights2 is weights1:
pairs = combinations_with_replacement(np.ndindex(shape1[1:]), 2)
else:
pairs = product(np.ndindex(shape1[1:]), np.ndindex(shape2[1:])) # type: ignore[assignment]
pairs = (
combinations_with_replacement(np.ndindex(shape1[1:]), 2)
if weights2 is weights1
else product(np.ndindex(shape1[1:]), np.ndindex(shape2[1:]))
)

# create the output array: axes for all input axes plus lmax+1
out = np.empty(shape1[1:] + shape2[1:] + (lmax + 1,))
Expand All @@ -427,7 +431,7 @@ def effective_cls( # type: ignore[no-untyped-def]
for j1, j2 in pairs:
w1, w2 = weights1[c + j1], weights2[c + j2]
cl = sum(
w1[i1] * w2[i2] * getcl(cls, i1, i2, lmax=lmax) # type: ignore[no-untyped-call]
w1[i1] * w2[i2] * getcl(cls, i1, i2, lmax=lmax)
for i1, i2 in np.ndindex(n, n)
)
out[j1 + j2] = cl