Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

docs: reordered docs for solvers and kernels #759

Merged
merged 14 commits into from
Oct 2, 2024
Merged
28 changes: 25 additions & 3 deletions coreax/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import jax.numpy as jnp
import jax.tree_util as jtu
from jax import jit
from jaxtyping import Array, ArrayLike, Shaped
from jaxtyping import Array, Shaped
from typing_extensions import Self


Expand Down Expand Up @@ -134,8 +134,30 @@ def __getitem__(self, key) -> Self:
"""Support `Array` style indexing of `Data` objects."""
return jtu.tree_map(lambda x: x[key], self)

def __jax_array__(self) -> Shaped[ArrayLike, " n d"]:
"""Register `ArrayLike` behaviour - return for `jnp.asarray(Data(...))`."""
@overload
def __jax_array__(
self: "Data",
) -> Shaped[Array, " n d"]: ...

@overload
def __jax_array__( # pyright:ignore[reportOverlappingOverload]
self: "SupervisedData",
) -> Shaped[Array, " n d + p"]: ...
tp832944 marked this conversation as resolved.
Show resolved Hide resolved

def __jax_array__(
self: Union["Data", "SupervisedData"],
) -> Union[Shaped[Array, " n d"], Shaped[Array, " n d + p"]]:
tp832944 marked this conversation as resolved.
Show resolved Hide resolved
"""
Return value of `jnp.asarray(Data(...))` and `jnp.asarray(SupervisedData(...))`.

.. note::

When ``self`` is a `SupervisedData` instance `jnp.asarray` will return
a single array where the ``supervision`` array has been
right-concatenated onto the``data`` array.
"""
if isinstance(self, SupervisedData):
return jnp.hstack((self.data, self.supervision))
return self.data

def __len__(self) -> int:
Expand Down
26 changes: 13 additions & 13 deletions coreax/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,23 @@
from coreax.kernels.util import median_heuristic

__all__ = [
"median_heuristic",
"ScalarValuedKernel",
"UniCompositeKernel",
"PowerKernel",
"DuoCompositeKernel",
"AdditiveKernel",
"ProductKernel",
"LinearKernel",
"PolynomialKernel",
"ExponentialKernel",
"LaplacianKernel",
"LinearKernel",
"LocallyPeriodicKernel",
"SquaredExponentialKernel",
"PCIMQKernel",
"PeriodicKernel",
"PolynomialKernel",
"RationalQuadraticKernel",
"SquaredExponentialKernel",
"AdditiveKernel",
"ProductKernel",
"SteinKernel",
"median_heuristic",
"DuoCompositeKernel",
"UniCompositeKernel",
"PowerKernel",
"PoissonKernel",
"MaternKernel",
"PeriodicKernel",
"LocallyPeriodicKernel",
"PoissonKernel",
"SteinKernel",
]
Loading