Skip to content

Commit bb6129b

Browse files
authored
MAINT: array_api_compat tweaks (#285)
1 parent 4425d14 commit bb6129b

File tree

3 files changed

+29
-20
lines changed

3 files changed

+29
-20
lines changed

src/array_api_extra/_lib/_utils/_compat.py

+3
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
is_torch_namespace,
2424
is_writeable_array,
2525
size,
26+
to_device,
2627
)
2728
except ImportError:
2829
from array_api_compat import (
@@ -45,6 +46,7 @@
4546
is_torch_namespace,
4647
is_writeable_array,
4748
size,
49+
to_device,
4850
)
4951

5052
__all__ = [
@@ -67,4 +69,5 @@
6769
"is_torch_namespace",
6870
"is_writeable_array",
6971
"size",
72+
"to_device",
7073
]

src/array_api_extra/_lib/_utils/_compat.pyi

+24-19
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from __future__ import annotations
55

66
from types import ModuleType
7+
from typing import Any, TypeGuard
78

89
# TODO import from typing (requires Python >=3.13)
910
from typing_extensions import TypeIs
@@ -12,29 +13,33 @@ from ._typing import Array, Device
1213

1314
# pylint: disable=missing-class-docstring,unused-argument
1415

15-
class Namespace(ModuleType):
16-
def device(self, x: Array, /) -> Device: ...
17-
1816
def array_namespace(
1917
*xs: Array | complex | None,
2018
api_version: str | None = None,
2119
use_compat: bool | None = None,
22-
) -> Namespace: ...
20+
) -> ModuleType: ...
2321
def device(x: Array, /) -> Device: ...
2422
def is_array_api_obj(x: object, /) -> TypeIs[Array]: ...
25-
def is_array_api_strict_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
26-
def is_cupy_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
27-
def is_dask_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
28-
def is_jax_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
29-
def is_numpy_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
30-
def is_pydata_sparse_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
31-
def is_torch_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
32-
def is_cupy_array(x: object, /) -> TypeIs[Array]: ...
33-
def is_dask_array(x: object, /) -> TypeIs[Array]: ...
34-
def is_jax_array(x: object, /) -> TypeIs[Array]: ...
35-
def is_numpy_array(x: object, /) -> TypeIs[Array]: ...
36-
def is_pydata_sparse_array(x: object, /) -> TypeIs[Array]: ...
37-
def is_torch_array(x: object, /) -> TypeIs[Array]: ...
38-
def is_lazy_array(x: object, /) -> TypeIs[Array]: ...
39-
def is_writeable_array(x: object, /) -> TypeIs[Array]: ...
23+
def is_array_api_strict_namespace(xp: ModuleType, /) -> bool: ...
24+
def is_cupy_namespace(xp: ModuleType, /) -> bool: ...
25+
def is_dask_namespace(xp: ModuleType, /) -> bool: ...
26+
def is_jax_namespace(xp: ModuleType, /) -> bool: ...
27+
def is_numpy_namespace(xp: ModuleType, /) -> bool: ...
28+
def is_pydata_sparse_namespace(xp: ModuleType, /) -> bool: ...
29+
def is_torch_namespace(xp: ModuleType, /) -> bool: ...
30+
def is_cupy_array(x: object, /) -> TypeGuard[Array]: ...
31+
def is_dask_array(x: object, /) -> TypeGuard[Array]: ...
32+
def is_jax_array(x: object, /) -> TypeGuard[Array]: ...
33+
def is_numpy_array(x: object, /) -> TypeGuard[Array]: ...
34+
def is_pydata_sparse_array(x: object, /) -> TypeGuard[Array]: ...
35+
def is_torch_array(x: object, /) -> TypeGuard[Array]: ...
36+
def is_lazy_array(x: object, /) -> TypeGuard[Array]: ...
37+
def is_writeable_array(x: object, /) -> TypeGuard[Array]: ...
4038
def size(x: Array, /) -> int | None: ...
39+
def to_device( # type: ignore[explicit-any]
40+
x: Array,
41+
device: Device, # pylint: disable=redefined-outer-name
42+
/,
43+
*,
44+
stream: int | Any | None = None,
45+
) -> Array: ...

vendor_tests/test_vendor.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,12 @@ def test_vendor_compat():
2323
is_torch_namespace,
2424
is_writeable_array,
2525
size,
26+
to_device,
2627
)
2728

2829
x = xp.asarray([1, 2, 3])
2930
assert array_namespace(x) is xp
30-
device(x)
31+
to_device(x, device(x))
3132
assert is_array_api_obj(x)
3233
assert is_array_api_strict_namespace(xp)
3334
assert not is_cupy_array(x)

0 commit comments

Comments
 (0)