diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 8ea9162a..00490e2c 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -7,7 +7,7 @@ import inspect from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Sequence, cast -from ._helpers import _check_device, array_namespace +from ._helpers import _device_ctx, array_namespace from ._helpers import device as _get_device from ._helpers import is_cupy_namespace as _is_cupy_namespace from ._typing import Array, Device, DType, Namespace @@ -32,8 +32,8 @@ def arange( device: Device | None = None, **kwargs: object, ) -> Array: - _check_device(xp, device) - return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs) + with _device_ctx(xp, device): + return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs) def empty( @@ -44,8 +44,8 @@ def empty( device: Device | None = None, **kwargs: object, ) -> Array: - _check_device(xp, device) - return xp.empty(shape, dtype=dtype, **kwargs) + with _device_ctx(xp, device): + return xp.empty(shape, dtype=dtype, **kwargs) def empty_like( @@ -57,8 +57,8 @@ def empty_like( device: Device | None = None, **kwargs: object, ) -> Array: - _check_device(xp, device) - return xp.empty_like(x, dtype=dtype, **kwargs) + with _device_ctx(xp, device, like=x): + return xp.empty_like(x, dtype=dtype, **kwargs) def eye( @@ -72,8 +72,8 @@ def eye( device: Device | None = None, **kwargs: object, ) -> Array: - _check_device(xp, device) - return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs) + with _device_ctx(xp, device): + return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs) def full( @@ -85,8 +85,8 @@ def full( device: Device | None = None, **kwargs: object, ) -> Array: - _check_device(xp, device) - return xp.full(shape, fill_value, dtype=dtype, **kwargs) + with _device_ctx(xp, device): + return xp.full(shape, fill_value, dtype=dtype, **kwargs) def full_like( @@ -99,8 +99,8 @@ def full_like( device: Device | None = None, **kwargs: object, ) -> Array: - _check_device(xp, device) - return xp.full_like(x, fill_value, dtype=dtype, **kwargs) + with _device_ctx(xp, device, like=x): + return xp.full_like(x, fill_value, dtype=dtype, **kwargs) def linspace( @@ -115,8 +115,8 @@ def linspace( endpoint: bool = True, **kwargs: object, ) -> Array: - _check_device(xp, device) - return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs) + with _device_ctx(xp, device): + return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs) def ones( @@ -127,8 +127,8 @@ def ones( device: Device | None = None, **kwargs: object, ) -> Array: - _check_device(xp, device) - return xp.ones(shape, dtype=dtype, **kwargs) + with _device_ctx(xp, device): + return xp.ones(shape, dtype=dtype, **kwargs) def ones_like( @@ -140,8 +140,8 @@ def ones_like( device: Device | None = None, **kwargs: object, ) -> Array: - _check_device(xp, device) - return xp.ones_like(x, dtype=dtype, **kwargs) + with _device_ctx(xp, device, like=x): + return xp.ones_like(x, dtype=dtype, **kwargs) def zeros( @@ -152,8 +152,8 @@ def zeros( device: Device | None = None, **kwargs: object, ) -> Array: - _check_device(xp, device) - return xp.zeros(shape, dtype=dtype, **kwargs) + with _device_ctx(xp, device): + return xp.zeros(shape, dtype=dtype, **kwargs) def zeros_like( @@ -165,8 +165,8 @@ def zeros_like( device: Device | None = None, **kwargs: object, ) -> Array: - _check_device(xp, device) - return xp.zeros_like(x, dtype=dtype, **kwargs) + with _device_ctx(xp, device, like=x): + return xp.zeros_like(x, dtype=dtype, **kwargs) # np.unique() is split into four functions in the array API: diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 77175d0d..ca4d16d0 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -8,6 +8,7 @@ from __future__ import annotations +import contextlib import inspect import math import sys @@ -657,26 +658,42 @@ def your_function(x, y): get_namespace = array_namespace -def _check_device(bare_xp: Namespace, device: Device) -> None: # pyright: ignore[reportUnusedFunction] - """ - Validate dummy device on device-less array backends. +def _device_ctx( + bare_xp: Namespace, device: Device, like: Array | None = None +) -> Generator[None]: + """Context manager which changes the current device in CuPy. - Notes - ----- - This function is also invoked by CuPy, which does have multiple devices - if there are multiple GPUs available. - However, CuPy multi-device support is currently impossible - without using the global device or a context manager: - - https://github.com/data-apis/array-api-compat/pull/293 + Used internally by array creation functions in common._aliases. """ - if bare_xp is sys.modules.get("numpy"): - if device not in ("cpu", None): + if device is None: + if like is None: + return contextlib.nullcontext() + device = _device(like) + + if bare_xp is sys.modules.get('numpy'): + if device != "cpu": raise ValueError(f"Unsupported device for NumPy: {device!r}") + return contextlib.nullcontext() - elif bare_xp is sys.modules.get("dask.array"): - if device not in ("cpu", _DASK_DEVICE, None): + if bare_xp is sys.modules.get('dask.array'): + if device not in ("cpu", _DASK_DEVICE): raise ValueError(f"Unsupported device for Dask: {device!r}") + return contextlib.nullcontext() + + if bare_xp is sys.modules.get('cupy'): + if not isinstance(device, bare_xp.cuda.Device): + raise TypeError(f"device is not a cupy.cuda.Device: {device!r}") + return device + + # PyTorch doesn't have a "current device" context manager and you + # can't use array creation functions from common._aliases. + raise AssertionError("unreachable") # pragma: nocover + + +def _check_device(bare_xp: Namespace, device: Device) -> None: + """Validate dummy device on device-less array backends.""" + with _device_ctx(bare_xp, device): + pass # Placeholder object to represent the dask device diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index adb74bff..c1ea532c 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -86,7 +86,8 @@ def asarray( See the corresponding documentation in the array library and/or the array API specification for more details. """ - with cp.cuda.Device(device): + like = obj if isinstance(obj, cp.ndarray) else None + with _helpers._device_ctx(cp, device, like=like): if copy is None: return cp.asarray(obj, dtype=dtype, **kwargs) else: