From 8e5cc941152d3f4019f70071c2bbfe12c46669f2 Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Tue, 26 Nov 2024 12:44:01 +0800
Subject: [PATCH 01/19] add paddle support in array-api-compat
---
.github/workflows/array-api-tests-paddle.yml | 11 +
array_api_compat/common/_helpers.py | 79 ++
array_api_compat/paddle/__init__.py | 28 +
array_api_compat/paddle/_aliases.py | 1153 ++++++++++++++++++
array_api_compat/paddle/_info.py | 373 ++++++
array_api_compat/paddle/fft.py | 92 ++
array_api_compat/paddle/linalg.py | 136 +++
array_api_compat/torch/fft.py | 26 +-
array_api_compat/torch/linalg.py | 76 +-
docs/index.md | 4 +
docs/supported-array-libraries.md | 23 +
requirements-dev.txt | 1 +
tests/_helpers.py | 13 +-
tests/test_array_namespace.py | 76 +-
tests/test_common.py | 28 +-
tests/test_isdtype.py | 2 +-
tests/test_no_dependencies.py | 8 +-
tests/test_vendoring.py | 26 +-
vendor_test/uses_paddle.py | 30 +
19 files changed, 2088 insertions(+), 97 deletions(-)
create mode 100644 .github/workflows/array-api-tests-paddle.yml
create mode 100644 array_api_compat/paddle/__init__.py
create mode 100644 array_api_compat/paddle/_aliases.py
create mode 100644 array_api_compat/paddle/_info.py
create mode 100644 array_api_compat/paddle/fft.py
create mode 100644 array_api_compat/paddle/linalg.py
create mode 100644 vendor_test/uses_paddle.py
diff --git a/.github/workflows/array-api-tests-paddle.yml b/.github/workflows/array-api-tests-paddle.yml
new file mode 100644
index 00000000..d4f88b00
--- /dev/null
+++ b/.github/workflows/array-api-tests-paddle.yml
@@ -0,0 +1,11 @@
+name: Array API Tests (Paddle Latest)
+
+on: [push, pull_request]
+
+jobs:
+ array-api-tests-paddle:
+ uses: ./.github/workflows/array-api-tests.yml
+ with:
+ package-name: paddle
+ extra-env-vars: |
+ ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64
diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py
index b011f08d..ff2c213f 100644
--- a/array_api_compat/common/_helpers.py
+++ b/array_api_compat/common/_helpers.py
@@ -120,6 +120,33 @@ def is_torch_array(x):
# TODO: Should we reject ndarray subclasses?
return isinstance(x, torch.Tensor)
+def is_paddle_array(x):
+ """
+ Return True if `x` is a Paddle tensor.
+
+ This function does not import Paddle if it has not already been imported
+ and is therefore cheap to use.
+
+ See Also
+ --------
+
+ array_namespace
+ is_array_api_obj
+ is_numpy_array
+ is_cupy_array
+ is_dask_array
+ is_jax_array
+ is_pydata_sparse_array
+ """
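+    # Illustrative usage sketch (assumes paddle is importable; not executed here):
+    #
+    #     import paddle
+    #     from array_api_compat import is_paddle_array
+    #     assert is_paddle_array(paddle.to_tensor([1.0]))
+    #     assert not is_paddle_array([1.0])
+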
+ # Avoid importing paddle if it isn't already
+ if 'paddle' not in sys.modules:
+ return False
+
+ import paddle
+
+    # TODO: Should we reject Tensor subclasses?
+ return paddle.is_tensor(x)
+
def is_ndonnx_array(x):
"""
Return True if `x` is a ndonnx Array.
@@ -252,6 +279,7 @@ def is_array_api_obj(x):
or is_dask_array(x) \
or is_jax_array(x) \
or is_pydata_sparse_array(x) \
+ or is_paddle_array(x) \
or hasattr(x, '__array_namespace__')
def _compat_module_name():
@@ -319,6 +347,27 @@ def is_torch_namespace(xp) -> bool:
return xp.__name__ in {'torch', _compat_module_name() + '.torch'}
+def is_paddle_namespace(xp) -> bool:
+ """
+ Returns True if `xp` is a Paddle namespace.
+
+ This includes both Paddle itself and the version wrapped by array-api-compat.
+
+ See Also
+ --------
+
+ array_namespace
+ is_numpy_namespace
+ is_cupy_namespace
+ is_ndonnx_namespace
+ is_dask_namespace
+ is_jax_namespace
+ is_pydata_sparse_namespace
+ is_array_api_strict_namespace
+ """
+ return xp.__name__ in {'paddle', _compat_module_name() + '.paddle'}
+
+
def is_ndonnx_namespace(xp):
"""
Returns True if `xp` is an NDONNX namespace.
@@ -543,6 +592,14 @@ def your_function(x, y):
else:
import jax.experimental.array_api as jnp
namespaces.add(jnp)
+ elif is_paddle_array(x):
+ if _use_compat:
+ _check_api_version(api_version)
+ from .. import paddle as paddle_namespace
+ namespaces.add(paddle_namespace)
+ else:
+ import paddle
+ namespaces.add(paddle)
elif is_pydata_sparse_array(x):
if use_compat is True:
_check_api_version(api_version)
@@ -660,6 +717,16 @@ def device(x: Array, /) -> Device:
return "cpu"
# Return the device of the constituent array
return device(inner)
+ elif is_paddle_array(x):
+ raw_place_str = str(x.place)
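+        # str(x.place) typically looks like "Place(cpu)", "Place(gpu:0)" or
+        # "Place(gpu_pinned)" (Paddle 2.x behavior, assumed here); pinned
+        # host memory is reported as "cpu" below.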
+ if "gpu_pinned" in raw_place_str:
+ return "cpu"
+ elif "cpu" in raw_place_str:
+ return "cpu"
+ elif "gpu" in raw_place_str:
+ return "gpu"
+ raise NotImplementedError(f"Unsupported device {raw_place_str}")
+
return x.device
# Prevent shadowing, used below
@@ -709,6 +776,14 @@ def _torch_to_device(x, device, /, stream=None):
raise NotImplementedError
return x.to(device)
+def _paddle_to_device(x, device, /, stream=None):
+ if stream is not None:
+ raise NotImplementedError(
+            "paddle.Tensor.to() does not support the stream argument yet"
+ )
+ return x.to(device)
+
+
def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] = None) -> Array:
"""
Copy the array from the device on which it currently resides to the specified ``device``.
@@ -781,6 +856,8 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
# In JAX v0.4.31 and older, this import adds to_device method to x.
import jax.experimental.array_api # noqa: F401
return x.to_device(device, stream=stream)
+ elif is_paddle_array(x):
+ return _paddle_to_device(x, device, stream=stream)
elif is_pydata_sparse_array(x) and device == _device(x):
# Perform trivial check to return the same array if
# device is same instead of err-ing.
@@ -819,6 +896,8 @@ def size(x):
"is_torch_namespace",
"is_ndonnx_array",
"is_ndonnx_namespace",
+ "is_paddle_array",
+ "is_paddle_namespace",
"is_pydata_sparse_array",
"is_pydata_sparse_namespace",
"size",
diff --git a/array_api_compat/paddle/__init__.py b/array_api_compat/paddle/__init__.py
new file mode 100644
index 00000000..9f96fa9f
--- /dev/null
+++ b/array_api_compat/paddle/__init__.py
@@ -0,0 +1,28 @@
+from paddle import * # noqa: F403
+
+# Several names are not included in the above import *
+import paddle
+
+for n in dir(paddle):
+ if (
+ n.startswith("_")
+ or n.endswith("_")
+ or "gpu" in n
+ or "cpu" in n
+ or "backward" in n
+ ):
+ continue
+ exec(n + " = paddle." + n)
+
+# paddle has no asarray(); paddle.to_tensor is the closest equivalent.
+asarray = paddle.to_tensor
+
+# These imports may overwrite names from the import * above.
+from ._aliases import * # noqa: F403
+
+# See the comment in the numpy __init__.py
+__import__(__package__ + ".linalg")
+
+__import__(__package__ + ".fft")
+
+from ..common._helpers import * # noqa: F403
+
+__array_api_version__ = "2023.12"
diff --git a/array_api_compat/paddle/_aliases.py b/array_api_compat/paddle/_aliases.py
new file mode 100644
index 00000000..dabe2928
--- /dev/null
+++ b/array_api_compat/paddle/_aliases.py
@@ -0,0 +1,1153 @@
+from __future__ import annotations
+
+from functools import wraps as _wraps
+from builtins import all as _builtin_all, any as _builtin_any
+
+from ..common._aliases import (
+ matrix_transpose as _aliases_matrix_transpose,
+ vecdot as _aliases_vecdot,
+ clip as _aliases_clip,
+ unstack as _aliases_unstack,
+ cumulative_sum as _aliases_cumulative_sum,
+)
+from .._internal import get_xp
+
+from ._info import __array_namespace_info__
+
+import paddle
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from typing import List, Optional, Sequence, Tuple, Union
+ from ..common._typing import Device
+ from paddle import dtype as Dtype
+
+ array = paddle.Tensor
+
+_int_dtypes = {
+ paddle.uint8,
+ paddle.int8,
+ paddle.int16,
+ paddle.int32,
+ paddle.int64,
+}
+
+_array_api_dtypes = {
+ paddle.bool,
+ *_int_dtypes,
+ paddle.float32,
+ paddle.float64,
+ paddle.complex64,
+ paddle.complex128,
+}
+
+_promotion_table = {
+ # bool
+ (paddle.bool, paddle.bool): paddle.bool,
+ # ints
+ (paddle.int8, paddle.int8): paddle.int8,
+ (paddle.int8, paddle.int16): paddle.int16,
+ (paddle.int8, paddle.int32): paddle.int32,
+ (paddle.int8, paddle.int64): paddle.int64,
+ (paddle.int16, paddle.int8): paddle.int16,
+ (paddle.int16, paddle.int16): paddle.int16,
+ (paddle.int16, paddle.int32): paddle.int32,
+ (paddle.int16, paddle.int64): paddle.int64,
+ (paddle.int32, paddle.int8): paddle.int32,
+ (paddle.int32, paddle.int16): paddle.int32,
+ (paddle.int32, paddle.int32): paddle.int32,
+ (paddle.int32, paddle.int64): paddle.int64,
+ (paddle.int64, paddle.int8): paddle.int64,
+ (paddle.int64, paddle.int16): paddle.int64,
+ (paddle.int64, paddle.int32): paddle.int64,
+ (paddle.int64, paddle.int64): paddle.int64,
+ # uints
+ (paddle.uint8, paddle.uint8): paddle.uint8,
+ # ints and uints (mixed sign)
+ (paddle.int8, paddle.uint8): paddle.int16,
+ (paddle.int16, paddle.uint8): paddle.int16,
+ (paddle.int32, paddle.uint8): paddle.int32,
+ (paddle.int64, paddle.uint8): paddle.int64,
+ (paddle.uint8, paddle.int8): paddle.int16,
+ (paddle.uint8, paddle.int16): paddle.int16,
+ (paddle.uint8, paddle.int32): paddle.int32,
+ (paddle.uint8, paddle.int64): paddle.int64,
+ # floats
+ (paddle.float32, paddle.float32): paddle.float32,
+ (paddle.float32, paddle.float64): paddle.float64,
+ (paddle.float64, paddle.float32): paddle.float64,
+ (paddle.float64, paddle.float64): paddle.float64,
+ # complexes
+ (paddle.complex64, paddle.complex64): paddle.complex64,
+ (paddle.complex64, paddle.complex128): paddle.complex128,
+ (paddle.complex128, paddle.complex64): paddle.complex128,
+ (paddle.complex128, paddle.complex128): paddle.complex128,
+ # Mixed float and complex
+ (paddle.float32, paddle.complex64): paddle.complex64,
+ (paddle.float32, paddle.complex128): paddle.complex128,
+ (paddle.float64, paddle.complex64): paddle.complex128,
+ (paddle.float64, paddle.complex128): paddle.complex128,
+}
+
+
+def _two_arg(f):
+ @_wraps(f)
+ def _f(x1, x2, /, **kwargs):
+ x1, x2 = _fix_promotion(x1, x2)
+ return f(x1, x2, **kwargs)
+
+ if _f.__doc__ is None:
+ _f.__doc__ = f"""\
+Array API compatibility wrapper for paddle.{f.__name__}.
+
+See the corresponding Paddle documentation and/or the array API specification
+for more details.
+
+"""
+ return _f
+
+
+def _fix_promotion(x1, x2, only_scalar=True):
+ if not isinstance(x1, paddle.Tensor) or not isinstance(x2, paddle.Tensor):
+ return x1, x2
+ if x1.dtype not in _array_api_dtypes or x2.dtype not in _array_api_dtypes:
+ return x1, x2
+    # If one argument is 0-D, paddle (like torch) would downcast the other argument
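+    # e.g. a 0-D float64 tensor combined with a float32 array should promote
+    # to float64 per the spec, so the non-0-D operand is upcast to the
+    # result_type first.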
+ if not only_scalar or x1.shape == ():
+ dtype = result_type(x1, x2)
+ x2 = x2.to(dtype)
+ if not only_scalar or x2.shape == ():
+ dtype = result_type(x1, x2)
+ x1 = x1.to(dtype)
+ return x1, x2
+
+
+def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype:
+ if len(arrays_and_dtypes) == 0:
+ raise TypeError("At least one array or dtype must be provided")
+ if len(arrays_and_dtypes) == 1:
+ x = arrays_and_dtypes[0]
+ if isinstance(x, paddle.dtype):
+ return x
+ return x.dtype
+ if len(arrays_and_dtypes) > 2:
+ return result_type(arrays_and_dtypes[0], result_type(*arrays_and_dtypes[1:]))
+
+ x, y = arrays_and_dtypes
+ xdt = x.dtype if not isinstance(x, paddle.dtype) else x
+ ydt = y.dtype if not isinstance(y, paddle.dtype) else y
+
+ if (xdt, ydt) in _promotion_table:
+ return _promotion_table[xdt, ydt]
+
+    # paddle.result_type only accepts tensors, not dtypes, so bare dtypes are
+    # wrapped in empty tensors before dispatching. A side effect is that this
+    # allows cross-kind promotion.
+ x = paddle.to_tensor([], dtype=x) if isinstance(x, paddle.dtype) else x
+ y = paddle.to_tensor([], dtype=y) if isinstance(y, paddle.dtype) else y
+ return paddle.result_type(x, y)
+
+
+def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
+ can_cast_dict = {
+ paddle.bfloat16: {
+ paddle.bfloat16: True,
+ paddle.float16: True,
+ paddle.float32: True,
+ paddle.float64: True,
+ paddle.complex64: True,
+ paddle.complex128: True,
+ paddle.uint8: False,
+ paddle.int8: False,
+ paddle.int16: False,
+ paddle.int32: False,
+ paddle.int64: False,
+ paddle.bool: False,
+ },
+ paddle.float16: {
+ paddle.bfloat16: True,
+ paddle.float16: True,
+ paddle.float32: True,
+ paddle.float64: True,
+ paddle.complex64: True,
+ paddle.complex128: True,
+ paddle.uint8: False,
+ paddle.int8: False,
+ paddle.int16: False,
+ paddle.int32: False,
+ paddle.int64: False,
+ paddle.bool: False,
+ },
+ paddle.float32: {
+ paddle.bfloat16: True,
+ paddle.float16: True,
+ paddle.float32: True,
+ paddle.float64: True,
+ paddle.complex64: True,
+ paddle.complex128: True,
+ paddle.uint8: False,
+ paddle.int8: False,
+ paddle.int16: False,
+ paddle.int32: False,
+ paddle.int64: False,
+ paddle.bool: False,
+ },
+ paddle.float64: {
+ paddle.bfloat16: True,
+ paddle.float16: True,
+ paddle.float32: True,
+ paddle.float64: True,
+ paddle.complex64: True,
+ paddle.complex128: True,
+ paddle.uint8: False,
+ paddle.int8: False,
+ paddle.int16: False,
+ paddle.int32: False,
+ paddle.int64: False,
+ paddle.bool: False,
+ },
+ paddle.complex64: {
+ paddle.bfloat16: False,
+ paddle.float16: False,
+ paddle.float32: False,
+ paddle.float64: False,
+ paddle.complex64: True,
+ paddle.complex128: True,
+ paddle.uint8: False,
+ paddle.int8: False,
+ paddle.int16: False,
+ paddle.int32: False,
+ paddle.int64: False,
+ paddle.bool: False,
+ },
+ paddle.complex128: {
+ paddle.bfloat16: False,
+ paddle.float16: False,
+ paddle.float32: False,
+ paddle.float64: False,
+ paddle.complex64: True,
+ paddle.complex128: True,
+ paddle.uint8: False,
+ paddle.int8: False,
+ paddle.int16: False,
+ paddle.int32: False,
+ paddle.int64: False,
+ paddle.bool: False,
+ },
+ paddle.uint8: {
+ paddle.bfloat16: True,
+ paddle.float16: True,
+ paddle.float32: True,
+ paddle.float64: True,
+ paddle.complex64: True,
+ paddle.complex128: True,
+ paddle.uint8: True,
+ paddle.int8: True,
+ paddle.int16: True,
+ paddle.int32: True,
+ paddle.int64: True,
+ paddle.bool: False,
+ },
+ paddle.int8: {
+ paddle.bfloat16: True,
+ paddle.float16: True,
+ paddle.float32: True,
+ paddle.float64: True,
+ paddle.complex64: True,
+ paddle.complex128: True,
+ paddle.uint8: True,
+ paddle.int8: True,
+ paddle.int16: True,
+ paddle.int32: True,
+ paddle.int64: True,
+ paddle.bool: False,
+ },
+ paddle.int16: {
+ paddle.bfloat16: True,
+ paddle.float16: True,
+ paddle.float32: True,
+ paddle.float64: True,
+ paddle.complex64: True,
+ paddle.complex128: True,
+ paddle.uint8: True,
+ paddle.int8: True,
+ paddle.int16: True,
+ paddle.int32: True,
+ paddle.int64: True,
+ paddle.bool: False,
+ },
+ paddle.int32: {
+ paddle.bfloat16: True,
+ paddle.float16: True,
+ paddle.float32: True,
+ paddle.float64: True,
+ paddle.complex64: True,
+ paddle.complex128: True,
+ paddle.uint8: True,
+ paddle.int8: True,
+ paddle.int16: True,
+ paddle.int32: True,
+ paddle.int64: True,
+ paddle.bool: False,
+ },
+ paddle.int64: {
+ paddle.bfloat16: True,
+ paddle.float16: True,
+ paddle.float32: True,
+ paddle.float64: True,
+ paddle.complex64: True,
+ paddle.complex128: True,
+ paddle.uint8: True,
+ paddle.int8: True,
+ paddle.int16: True,
+ paddle.int32: True,
+ paddle.int64: True,
+ paddle.bool: False,
+ },
+ paddle.bool: {
+ paddle.bfloat16: True,
+ paddle.float16: True,
+ paddle.float32: True,
+ paddle.float64: True,
+ paddle.complex64: True,
+ paddle.complex128: True,
+ paddle.uint8: True,
+ paddle.int8: True,
+ paddle.int16: True,
+ paddle.int32: True,
+ paddle.int64: True,
+ paddle.bool: True,
+ },
+ }
+ return can_cast_dict[from_][to]
+
+
+# Basic renames
+bitwise_invert = paddle.bitwise_not
+newaxis = None
+# Unlike torch.conj, paddle.conj returns a materialized conjugate (there is no
+# lazy conjugation bit), so it can be aliased directly. See
+# https://github.com/data-apis/array-api-compat/issues/173 for background.
+conj = paddle.conj
+
+# Two-arg elementwise functions
+# These require a wrapper to do the correct type promotion on 0-D tensors
+add = _two_arg(paddle.add)
+atan2 = _two_arg(paddle.atan2)
+bitwise_and = _two_arg(paddle.bitwise_and)
+bitwise_left_shift = _two_arg(paddle.bitwise_left_shift)
+bitwise_or = _two_arg(paddle.bitwise_or)
+bitwise_right_shift = _two_arg(paddle.bitwise_right_shift)
+bitwise_xor = _two_arg(paddle.bitwise_xor)
+copysign = _two_arg(paddle.copysign)
+divide = _two_arg(paddle.divide)
+# Note: unlike torch.equal, paddle.equal is already elementwise; it is wrapped
+# here only for the 0-D promotion fix.
+equal = _two_arg(paddle.equal)
+floor_divide = _two_arg(paddle.floor_divide)
+greater = _two_arg(paddle.greater_than)
+greater_equal = _two_arg(paddle.greater_equal)
+hypot = _two_arg(paddle.hypot)
+less = _two_arg(paddle.less)
+less_equal = _two_arg(paddle.less_equal)
+logaddexp = _two_arg(paddle.logaddexp)
+# logical functions are not included here because they only accept bool in the
+# spec, so type promotion is irrelevant.
+maximum = _two_arg(paddle.maximum)
+minimum = _two_arg(paddle.minimum)
+multiply = _two_arg(paddle.multiply)
+not_equal = _two_arg(paddle.not_equal)
+pow = _two_arg(paddle.pow)
+remainder = _two_arg(paddle.remainder)
+subtract = _two_arg(paddle.subtract)
+
+# These wrappers are ported from the torch versions. Paddle already uses
+# 'axis' rather than 'dim', but the other behavioral differences handled
+# below still apply.
+
+
+def max(
+ x: array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ keepdims: bool = False,
+) -> array:
+ # https://github.com/pytorch/pytorch/issues/29137
+ if axis == ():
+ return paddle.clone(x)
+ return paddle.amax(x, axis, keepdim=keepdims)
+
+
+def min(
+ x: array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ keepdims: bool = False,
+) -> array:
+ # https://github.com/pytorch/pytorch/issues/29137
+ if axis == ():
+ return paddle.clone(x)
+ return paddle.min(x, axis, keepdim=keepdims)
+
+
+clip = get_xp(paddle)(_aliases_clip)
+unstack = get_xp(paddle)(_aliases_unstack)
+cumulative_sum = get_xp(paddle)(_aliases_cumulative_sum)
+
+
+# Unlike torch.sort, paddle.sort returns the sorted tensor directly
+# (indices are available separately via paddle.argsort).
+def sort(
+    x: array,
+    /,
+    *,
+    axis: int = -1,
+    descending: bool = False,
+    stable: bool = True,
+    **kwargs,
+) -> array:
+    return paddle.sort(x, axis=axis, descending=descending, stable=stable, **kwargs)
+
+
+def _normalize_axes(axis, ndim):
+ axes = []
+ if ndim == 0 and axis:
+ # Better error message in this case
+ raise IndexError(f"Dimension out of range: {axis[0]}")
+ lower, upper = -ndim, ndim - 1
+ for a in axis:
+ if a < lower or a > upper:
+ # Match paddle error message (e.g., from sum())
+ raise IndexError(
+ f"Dimension out of range (expected to be in range of [{lower}, {upper}], but got {a}"
+ )
+ if a < 0:
+ a = a + ndim
+ if a in axes:
+ # Use IndexError instead of RuntimeError, and "axis" instead of "dim"
+ raise IndexError(f"Axis {a} appears multiple times in the list of axes")
+ axes.append(a)
+ return sorted(axes)
+
+
+def _axis_none_keepdims(x, ndim, keepdims):
+ # Apply keepdims when axis=None
+ # (https://github.com/pytorch/pytorch/issues/71209)
+ # Note that this is only valid for the axis=None case.
+ if keepdims:
+ for i in range(ndim):
+ x = paddle.unsqueeze(x, 0)
+ return x
+
+
+def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
+ # Some reductions don't support multiple axes
+ # (https://github.com/pytorch/pytorch/issues/56586).
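+    # Sketch of the approach: for x.shape == (2, 3, 4) and axis == (0, 2),
+    # move axes 0 and 2 to the end, flatten them into one axis of size
+    # 2 * 4 == 8, reduce that last axis, then restore the reduced dims with
+    # unsqueeze when keepdims is requested.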
+ axes = _normalize_axes(axis, x.ndim)
+ for a in reversed(axes):
+ x = paddle.movedim(x, a, -1)
+ x = paddle.flatten(x, -len(axes))
+
+ out = f(x, -1, **kwargs)
+
+ if keepdims:
+ for a in axes:
+ out = paddle.unsqueeze(out, a)
+ return out
+
+
+def prod(
+ x: array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ dtype: Optional[Dtype] = None,
+ keepdims: bool = False,
+ **kwargs,
+) -> array:
+ x = paddle.asarray(x)
+ ndim = x.ndim
+
+ # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic
+ # below because it still needs to upcast.
+ if axis == ():
+ if dtype is None:
+ # We can't upcast uint8 according to the spec because there is no
+ # paddle.uint64, so at least upcast to int64 which is what sum does
+ # when axis=None.
+ if x.dtype in [paddle.int8, paddle.int16, paddle.int32, paddle.uint8]:
+ return x.to(paddle.int64)
+ return x.clone()
+ return x.to(dtype)
+
+ # paddle.prod doesn't support multiple axes
+ # (https://github.com/pytorch/pytorch/issues/56586).
+ if isinstance(axis, tuple):
+ return _reduce_multiple_axes(
+ paddle.prod, x, axis, keepdim=keepdims, dtype=dtype, **kwargs
+ )
+ if axis is None:
+ # paddle doesn't support keepdims with axis=None
+ # (https://github.com/pytorch/pytorch/issues/71209)
+ res = paddle.prod(x, dtype=dtype, **kwargs)
+ res = _axis_none_keepdims(res, ndim, keepdims)
+ return res
+
+ return paddle.prod(x, axis, dtype=dtype, keepdim=keepdims, **kwargs)
+
+
+def sum(
+ x: array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ dtype: Optional[Dtype] = None,
+ keepdims: bool = False,
+ **kwargs,
+) -> array:
+ x = paddle.asarray(x)
+ ndim = x.ndim
+
+ # https://github.com/pytorch/pytorch/issues/29137.
+ # Make sure it upcasts.
+ if axis == ():
+ if dtype is None:
+ # We can't upcast uint8 according to the spec because there is no
+ # paddle.uint64, so at least upcast to int64 which is what sum does
+ # when axis=None.
+ if x.dtype in [paddle.int8, paddle.int16, paddle.int32, paddle.uint8]:
+ return x.to(paddle.int64)
+ return x.clone()
+ return x.to(dtype)
+
+ if axis is None:
+ # paddle doesn't support keepdims with axis=None
+ # (https://github.com/pytorch/pytorch/issues/71209)
+ res = paddle.sum(x, dtype=dtype, **kwargs)
+ res = _axis_none_keepdims(res, ndim, keepdims)
+ return res
+
+ return paddle.sum(x, axis, dtype=dtype, keepdim=keepdims, **kwargs)
+
+
+def any(
+ x: array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ keepdims: bool = False,
+ **kwargs,
+) -> array:
+ x = paddle.asarray(x)
+ ndim = x.ndim
+ if axis == ():
+ return x.to(paddle.bool)
+ # paddle.any doesn't support multiple axes
+ # (https://github.com/pytorch/pytorch/issues/56586).
+ if isinstance(axis, tuple):
+ res = _reduce_multiple_axes(paddle.any, x, axis, keepdim=keepdims, **kwargs)
+ return res.to(paddle.bool)
+ if axis is None:
+ # paddle doesn't support keepdims with axis=None
+ # (https://github.com/pytorch/pytorch/issues/71209)
+ res = paddle.any(x, **kwargs)
+ res = _axis_none_keepdims(res, ndim, keepdims)
+ return res.to(paddle.bool)
+
+ # paddle.any doesn't return bool for uint8
+ return paddle.any(x, axis, keepdim=keepdims).to(paddle.bool)
+
+
+def all(
+ x: array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ keepdims: bool = False,
+ **kwargs,
+) -> array:
+ x = paddle.asarray(x)
+ ndim = x.ndim
+ if axis == ():
+ return x.to(paddle.bool)
+ # paddle.all doesn't support multiple axes
+ # (https://github.com/pytorch/pytorch/issues/56586).
+ if isinstance(axis, tuple):
+ res = _reduce_multiple_axes(paddle.all, x, axis, keepdim=keepdims, **kwargs)
+ return res.to(paddle.bool)
+ if axis is None:
+ # paddle doesn't support keepdims with axis=None
+ # (https://github.com/pytorch/pytorch/issues/71209)
+ res = paddle.all(x, **kwargs)
+ res = _axis_none_keepdims(res, ndim, keepdims)
+ return res.to(paddle.bool)
+
+ # paddle.all doesn't return bool for uint8
+ return paddle.all(x, axis, keepdim=keepdims).to(paddle.bool)
+
+
+def mean(
+ x: array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ keepdims: bool = False,
+ **kwargs,
+) -> array:
+ # https://github.com/pytorch/pytorch/issues/29137
+ if axis == ():
+ return paddle.clone(x)
+ if axis is None:
+ # paddle doesn't support keepdims with axis=None
+ # (https://github.com/pytorch/pytorch/issues/71209)
+ res = paddle.mean(x, **kwargs)
+ res = _axis_none_keepdims(res, x.ndim, keepdims)
+ return res
+ return paddle.mean(x, axis, keepdim=keepdims, **kwargs)
+
+
+def std(
+ x: array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ correction: Union[int, float] = 0.0,
+ keepdims: bool = False,
+ **kwargs,
+) -> array:
+ # Note, float correction is not supported
+ # https://github.com/pytorch/pytorch/issues/61492. We don't try to
+ # implement it here for now.
+
+    if isinstance(correction, float):
+        _correction = int(correction)
+        if correction != _correction:
+            raise NotImplementedError(
+                "float correction in paddle std() is not yet supported"
+            )
+    elif isinstance(correction, int):
+        if correction not in [0, 1]:
+            raise NotImplementedError("correction can only be 0 or 1")
+        _correction = correction
+    else:
+        raise NotImplementedError("correction must be an int, float, or bool equal to 0 or 1")
+
+    _correction = bool(_correction)
+
+ # https://github.com/pytorch/pytorch/issues/29137
+ if axis == ():
+ return paddle.zeros_like(x)
+ if isinstance(axis, int):
+ axis = (axis,)
+ if axis is None:
+ # paddle doesn't support keepdims with axis=None
+ # (https://github.com/pytorch/pytorch/issues/71209)
+ res = paddle.std(x, tuple(range(x.ndim)), unbiased=_correction, **kwargs)
+ res = _axis_none_keepdims(res, x.ndim, keepdims)
+ return res
+ return paddle.std(x, axis, unbiased=_correction, keepdim=keepdims, **kwargs)
+
+
+def var(
+ x: array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ correction: Union[int, float] = 0.0,
+ keepdims: bool = False,
+ **kwargs,
+) -> array:
+ # Note, float correction is not supported
+ # https://github.com/pytorch/pytorch/issues/61492. We don't try to
+ # implement it here for now.
+
+    if isinstance(correction, float):
+        _correction = int(correction)
+        if correction != _correction:
+            raise NotImplementedError(
+                "float correction in paddle var() is not yet supported"
+            )
+    elif isinstance(correction, int):
+        if correction not in [0, 1]:
+            raise NotImplementedError("correction can only be 0 or 1")
+        _correction = correction
+    else:
+        raise NotImplementedError("correction must be an int, float, or bool equal to 0 or 1")
+
+    _correction = bool(_correction)
+
+ # https://github.com/pytorch/pytorch/issues/29137
+ if axis == ():
+ return paddle.zeros_like(x)
+ if isinstance(axis, int):
+ axis = (axis,)
+ if axis is None:
+ # paddle doesn't support keepdims with axis=None
+ # (https://github.com/pytorch/pytorch/issues/71209)
+ res = paddle.var(x, tuple(range(x.ndim)), unbiased=_correction, **kwargs)
+ res = _axis_none_keepdims(res, x.ndim, keepdims)
+ return res
+ return paddle.var(x, axis, unbiased=_correction, keepdim=keepdims, **kwargs)
+
+
+# paddle.concat doesn't support axis=None
+# https://github.com/pytorch/pytorch/issues/70925
+def concat(
+ arrays: Union[Tuple[array, ...], List[array]],
+ /,
+ *,
+ axis: Optional[int] = 0,
+ **kwargs,
+) -> array:
+ if axis is None:
+ arrays = tuple(ar.flatten() for ar in arrays)
+ axis = 0
+ return paddle.concat(arrays, axis, **kwargs)
+
+
+# The spec requires squeeze() to accept a tuple of axes and to error when a
+# squeezed axis does not have size 1, so paddle.squeeze is wrapped here.
+def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array:
+ if isinstance(axis, int):
+ axis = (axis,)
+ for a in axis:
+ if x.shape[a] != 1:
+ raise ValueError("squeezed dimensions must be equal to 1")
+ axes = _normalize_axes(axis, x.ndim)
+    # Squeeze one axis at a time, adjusting for the axes already removed.
+    sequence = [a - i for i, a in enumerate(axes)]
+ for a in sequence:
+ x = paddle.squeeze(x, a)
+ return x
+
+
+# Wrapper kept for parity; paddle.broadcast_to already accepts a shape argument
+def broadcast_to(x: array, /, shape: Tuple[int, ...], **kwargs) -> array:
+ return paddle.broadcast_to(x, shape, **kwargs)
+
+
+# paddle.transpose takes a full permutation ('perm') rather than 'axes'
+def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array:
+    if len(axes) == 2 and x.ndim != 2:
+        # Convenience extension kept from the original wrapper: a pair of axes
+        # on a higher-rank tensor is treated as a swap. For x.ndim == 2 the
+        # axes already form a full permutation and are passed through as-is.
+        perm = list(range(x.ndim))
+        perm[axes[0]], perm[axes[1]] = perm[axes[1]], perm[axes[0]]
+        axes = perm
+    return paddle.transpose(x, axes)
+
+
+# The axis parameter doesn't work for flip() and roll()
+# https://github.com/pytorch/pytorch/issues/71210. Also paddle.flip() doesn't
+# accept axis=None
+def flip(
+ x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs
+) -> array:
+ if axis is None:
+ axis = tuple(range(x.ndim))
+ # paddle.flip doesn't accept dim as an int but the method does
+ # https://github.com/pytorch/pytorch/issues/18095
+ return x.flip(axis, **kwargs)
+
+
+def roll(
+ x: array,
+ /,
+ shift: Union[int, Tuple[int, ...]],
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ **kwargs,
+) -> array:
+ return paddle.roll(x, shift, axis, **kwargs)
+
+
+def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]:
+ if x.ndim == 0:
+ raise ValueError("nonzero() does not support zero-dimensional arrays")
+ return paddle.nonzero(x, as_tuple=True, **kwargs)
+
+
+def where(condition: array, x1: array, x2: array, /) -> array:
+ x1, x2 = _fix_promotion(x1, x2)
+ return paddle.where(condition, x1, x2)
+
+
+# paddle.reshape doesn't have the copy keyword
+def reshape(
+ x: array, /, shape: Tuple[int, ...], copy: Optional[bool] = None, **kwargs
+) -> array:
+ if copy is not None:
+ raise NotImplementedError("paddle.reshape doesn't yet support the copy keyword")
+ return paddle.reshape(x, shape, **kwargs)
+
+
+# paddle.arange doesn't support returning empty arrays
+# (https://github.com/pytorch/pytorch/issues/70915), and doesn't support some
+# keyword argument combinations
+# (https://github.com/pytorch/pytorch/issues/70914)
+def arange(
+ start: Union[int, float],
+ /,
+ stop: Optional[Union[int, float]] = None,
+ step: Union[int, float] = 1,
+ *,
+ dtype: Optional[Dtype] = None,
+ device: Optional[Device] = None,
+ **kwargs,
+) -> array:
+ if stop is None:
+ start, stop = 0, start
+ if step > 0 and stop <= start or step < 0 and stop >= start:
+ if dtype is None:
+ if _builtin_all(isinstance(i, int) for i in [start, stop, step]):
+ dtype = paddle.int64
+ else:
+ dtype = paddle.float32
+ return paddle.empty([0], dtype=dtype, **kwargs).to(device)
+ return paddle.arange(start, stop, step, dtype=dtype, **kwargs).to(device)
+
+
+# paddle.eye doesn't support the off-diagonal offset k, so build the result
+# from zeros and fill the requested diagonal
+def eye(
+ n_rows: int,
+ n_cols: Optional[int] = None,
+ /,
+ *,
+ k: int = 0,
+ dtype: Optional[Dtype] = None,
+ device: Optional[Device] = None,
+ **kwargs,
+) -> array:
+ if n_cols is None:
+ n_cols = n_rows
+ z = paddle.zeros([n_rows, n_cols], dtype=dtype, **kwargs).to(device)
+ if abs(k) <= n_rows + n_cols:
+ z.diagonal(k).fill_(1)
+ return z
+
+
+# paddle.linspace doesn't have the endpoint parameter
+def linspace(
+ start: Union[int, float],
+ stop: Union[int, float],
+ /,
+ num: int,
+ *,
+ dtype: Optional[Dtype] = None,
+ device: Optional[Device] = None,
+ endpoint: bool = True,
+ **kwargs,
+) -> array:
+ if not endpoint:
+ return paddle.linspace(start, stop, num + 1, dtype=dtype, **kwargs).to(device)[
+ :-1
+ ]
+ return paddle.linspace(start, stop, num, dtype=dtype, **kwargs).to(device)
+
+
+# paddle.full does not accept an int size
+# https://github.com/pytorch/pytorch/issues/70906
+def full(
+ shape: Union[int, Tuple[int, ...]],
+ fill_value: Union[bool, int, float, complex],
+ *,
+ dtype: Optional[Dtype] = None,
+ device: Optional[Device] = None,
+ **kwargs,
+) -> array:
+ if isinstance(shape, int):
+ shape = (shape,)
+
+ return paddle.full(shape, fill_value, dtype=dtype, **kwargs).to(device)
+
+
+# ones, zeros, and empty are wrapped so the result can be moved to the
+# requested device (paddle's creation functions have no device argument)
+def ones(
+ shape: Union[int, Tuple[int, ...]],
+ *,
+ dtype: Optional[Dtype] = None,
+ device: Optional[Device] = None,
+ **kwargs,
+) -> array:
+ return paddle.ones(shape, dtype=dtype, **kwargs).to(device)
+
+
+def zeros(
+ shape: Union[int, Tuple[int, ...]],
+ *,
+ dtype: Optional[Dtype] = None,
+ device: Optional[Device] = None,
+ **kwargs,
+) -> array:
+ return paddle.zeros(shape, dtype=dtype, **kwargs).to(device)
+
+
+def empty(
+ shape: Union[int, Tuple[int, ...]],
+ *,
+ dtype: Optional[Dtype] = None,
+ device: Optional[Device] = None,
+ **kwargs,
+) -> array:
+ return paddle.empty(shape, dtype=dtype, **kwargs).to(device)
+
+
+# paddle.tril and paddle.triu name the offset argument 'diagonal', so k is passed positionally
+
+
+def tril(x: array, /, *, k: int = 0) -> array:
+ return paddle.tril(x, k)
+
+
+def triu(x: array, /, *, k: int = 0) -> array:
+ return paddle.triu(x, k)
+
+
+# Functions that aren't in paddle https://github.com/pytorch/pytorch/issues/58742
+def expand_dims(x: array, /, *, axis: int = 0) -> array:
+ return paddle.unsqueeze(x, axis)
+
+
+def astype(x: array, dtype: Dtype, /, *, copy: bool = True) -> array:
+ return x.to(dtype, copy=copy)
+
+
+def broadcast_arrays(*arrays: array) -> List[array]:
+ shape = paddle.broadcast_shapes(*[a.shape for a in arrays])
+ return [paddle.broadcast_to(a, shape) for a in arrays]
+
+
+# Note that these named tuples aren't actually part of the standard namespace,
+# but I don't see any issue with exporting the names here regardless.
+from ..common._aliases import UniqueAllResult, UniqueCountsResult, UniqueInverseResult
+
+
+# https://github.com/pytorch/pytorch/issues/70920
+def unique_all(x: array) -> UniqueAllResult:
+ # paddle.unique doesn't support returning indices.
+ # https://github.com/pytorch/pytorch/issues/36748. The workaround
+ # suggested in that issue doesn't actually function correctly (it relies
+ # on non-deterministic behavior of scatter()).
+ raise NotImplementedError(
+ "unique_all() not yet implemented for paddle (see https://github.com/pytorch/pytorch/issues/36748)"
+ )
+
+ # values, inverse_indices, counts = paddle.unique(x, return_counts=True, return_inverse=True)
+ # # paddle.unique incorrectly gives a 0 count for nan values.
+ # # https://github.com/pytorch/pytorch/issues/94106
+ # counts[paddle.isnan(values)] = 1
+ # return UniqueAllResult(values, indices, inverse_indices, counts)
+
+
+def unique_counts(x: array) -> UniqueCountsResult:
+ values, counts = paddle.unique(x, return_counts=True)
+
+ # paddle.unique incorrectly gives a 0 count for nan values.
+ # https://github.com/pytorch/pytorch/issues/94106
+ counts[paddle.isnan(values)] = 1
+ return UniqueCountsResult(values, counts)
+
+
+def unique_inverse(x: array) -> UniqueInverseResult:
+ values, inverse = paddle.unique(x, return_inverse=True)
+ return UniqueInverseResult(values, inverse)
+
+
+def unique_values(x: array) -> array:
+ return paddle.unique(x)
+
+
+def matmul(x1: array, x2: array, /, **kwargs) -> array:
+ # paddle.matmul doesn't type promote (but differently from _fix_promotion)
+ x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
+ return paddle.matmul(x1, x2, **kwargs)
+
+
+matrix_transpose = get_xp(paddle)(_aliases_matrix_transpose)
+_vecdot = get_xp(paddle)(_aliases_vecdot)
+
+
+def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array:
+ x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
+ return _vecdot(x1, x2, axis=axis)
+
+
+# paddle.tensordot uses dims instead of axes
+def tensordot(
+ x1: array,
+ x2: array,
+ /,
+ *,
+ axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2,
+ **kwargs,
+) -> array:
+ # Note: paddle.tensordot fails with integer dtypes when there is only 1
+ # element in the axis (https://github.com/pytorch/pytorch/issues/84530).
+ x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
+ return paddle.tensordot(x1, x2, axes=axes, **kwargs)
+
+
+def isdtype(
+ dtype: Dtype,
+ kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]],
+ *,
+ _tuple=True, # Disallow nested tuples
+) -> bool:
+ """
+ Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.
+
+ Note that outside of this function, this compat library does not yet fully
+ support complex numbers.
+
+ See
+ https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
+ for more details
+ """
+
+ def is_signed(dtype):
+ return dtype in [paddle.int8, paddle.int16, paddle.int32, paddle.int64]
+
+ def is_floating_point(dtype):
+ return dtype in [
+ paddle.float32,
+ paddle.float64,
+ paddle.float16,
+ paddle.bfloat16,
+ paddle.float8_e4m3fn,
+ paddle.float8_e5m2,
+ ]
+
+ def is_complex(dtype):
+ return dtype in [paddle.complex64, paddle.complex128]
+
+ if isinstance(kind, tuple) and _tuple:
+ return _builtin_any(isdtype(dtype, k, _tuple=False) for k in kind)
+
+ elif isinstance(kind, str):
+ if kind == "bool":
+ return dtype == paddle.bool
+ elif kind == "signed integer":
+ return dtype in _int_dtypes and is_signed(dtype)
+ elif kind == "unsigned integer":
+ return dtype in _int_dtypes and not is_signed(dtype)
+ elif kind == "integral":
+ return dtype in _int_dtypes
+ elif kind == "real floating":
+ return is_floating_point(dtype)
+ elif kind == "complex floating":
+ return is_complex(dtype)
+ elif kind == "numeric":
+ return isdtype(dtype, ("integral", "real floating", "complex floating"))
+ else:
+ raise ValueError(f"Unrecognized data type kind: {kind!r}")
+ else:
+ return dtype == kind
+
+
+def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -> array:
+ if axis is None:
+ if x.ndim != 1:
+ raise ValueError("axis must be specified when ndim > 1")
+ axis = 0
+ return paddle.index_select(x, axis, indices, **kwargs)
+
+
+def sign(x: array, /) -> array:
+ # paddle sign() does not support complex numbers and does not propagate
+ # nans. See https://github.com/data-apis/array-api-compat/issues/136
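+    # For complex input, sign(z) == z / |z| (e.g. sign(3+4j) == 0.6+0.8j),
+    # with sign(0) forced to 0 since the formula alone would give nan.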
+ if x.dtype.is_complex:
+ out = x / paddle.abs(x)
+ # sign(0) = 0 but the above formula would give nan
+ out[x == 0 + 0j] = 0 + 0j
+ return out
+ else:
+ out = paddle.sign(x)
+ if x.dtype.is_floating_point:
+ out[paddle.isnan(x)] = paddle.nan
+ return out
+
+
+__all__ = [
+ "__array_namespace_info__",
+ "result_type",
+ "can_cast",
+ "permute_dims",
+ "bitwise_invert",
+ "newaxis",
+ "conj",
+ "add",
+ "atan2",
+ "bitwise_and",
+ "bitwise_left_shift",
+ "bitwise_or",
+ "bitwise_right_shift",
+ "bitwise_xor",
+ "copysign",
+ "divide",
+ "equal",
+ "floor_divide",
+ "greater",
+ "greater_equal",
+ "hypot",
+ "less",
+ "less_equal",
+ "logaddexp",
+ "maximum",
+ "minimum",
+ "multiply",
+ "not_equal",
+ "pow",
+ "remainder",
+ "subtract",
+ "max",
+ "min",
+ "clip",
+ "unstack",
+ "cumulative_sum",
+ "sort",
+ "prod",
+ "sum",
+ "any",
+ "all",
+ "mean",
+ "std",
+ "var",
+ "concat",
+ "squeeze",
+ "broadcast_to",
+ "flip",
+ "roll",
+ "nonzero",
+ "where",
+ "reshape",
+ "arange",
+ "eye",
+ "linspace",
+ "full",
+ "ones",
+ "zeros",
+ "empty",
+ "tril",
+ "triu",
+ "expand_dims",
+ "astype",
+ "broadcast_arrays",
+ "UniqueAllResult",
+ "UniqueCountsResult",
+ "UniqueInverseResult",
+ "unique_all",
+ "unique_counts",
+ "unique_inverse",
+ "unique_values",
+ "matmul",
+ "matrix_transpose",
+ "vecdot",
+ "tensordot",
+ "isdtype",
+ "take",
+ "sign",
+]
+
+_all_ignore = ["paddle", "get_xp"]
diff --git a/array_api_compat/paddle/_info.py b/array_api_compat/paddle/_info.py
new file mode 100644
index 00000000..1fe48356
--- /dev/null
+++ b/array_api_compat/paddle/_info.py
@@ -0,0 +1,373 @@
+"""
+Array API Inspection namespace
+
+This is the namespace for inspection functions as defined by the array API
+standard. See
+https://data-apis.org/array-api/latest/API_specification/inspection.html for
+more details.
+
+"""
+
+import paddle
+
+from functools import cache
+
+
+class __array_namespace_info__:
+ """
+    Get the array API inspection namespace for Paddle.
+
+ The array API inspection namespace defines the following functions:
+
+ - capabilities()
+ - default_device()
+ - default_dtypes()
+ - dtypes()
+ - devices()
+
+ See
+ https://data-apis.org/array-api/latest/API_specification/inspection.html
+ for more details.
+
+ Returns
+ -------
+ info : ModuleType
+        The array API inspection namespace for Paddle.
+
+ Examples
+ --------
+    >>> info = paddle.__array_namespace_info__()
+    >>> info.default_dtypes()
+    {'real floating': paddle.float32,
+     'complex floating': paddle.complex64,
+     'integral': paddle.int64,
+     'indexing': paddle.int64}
+
+ """
+
+ __module__ = "paddle"
+
+ def capabilities(self):
+ """
+ Return a dictionary of array API library capabilities.
+
+ The resulting dictionary has the following keys:
+
+ - **"boolean indexing"**: boolean indicating whether an array library
+          supports boolean indexing. Always ``True`` for Paddle.
+
+ - **"data-dependent shapes"**: boolean indicating whether an array
+ library supports data-dependent output shapes. Always ``True`` for
+          Paddle.
+
+ See
+ https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html
+ for more details.
+
+ See Also
+ --------
+ __array_namespace_info__.default_device,
+ __array_namespace_info__.default_dtypes,
+ __array_namespace_info__.dtypes,
+ __array_namespace_info__.devices
+
+ Returns
+ -------
+ capabilities : dict
+ A dictionary of array API library capabilities.
+
+ Examples
+ --------
+        >>> info = paddle.__array_namespace_info__()
+ >>> info.capabilities()
+ {'boolean indexing': True,
+ 'data-dependent shapes': True}
+
+ """
+ return {
+ "boolean indexing": True,
+ "data-dependent shapes": True,
+ # 'max rank' will be part of the 2024.12 standard
+ # "max rank": 64,
+ }
+
+ def default_device(self):
+ """
+        The default device used for new Paddle arrays.
+
+ See Also
+ --------
+ __array_namespace_info__.capabilities,
+ __array_namespace_info__.default_dtypes,
+ __array_namespace_info__.dtypes,
+ __array_namespace_info__.devices
+
+ Returns
+ -------
+ device : str
+            The default device used for new Paddle arrays.
+
+ Examples
+ --------
+        >>> info = paddle.__array_namespace_info__()
+ >>> info.default_device()
+ 'cpu'
+
+ """
+ return paddle.device.get_device()
+
+ def default_dtypes(self, *, device=None):
+ """
+        The default data types used for new Paddle arrays.
+
+ Parameters
+ ----------
+ device : str, optional
+            The device to get the default data types for. The defaults are
+            the same across Paddle devices, so this argument is currently
+            ignored.
+
+ Returns
+ -------
+ dtypes : dict
+            A dictionary describing the default data types used for new Paddle
+ arrays.
+
+ See Also
+ --------
+ __array_namespace_info__.capabilities,
+ __array_namespace_info__.default_device,
+ __array_namespace_info__.dtypes,
+ __array_namespace_info__.devices
+
+ Examples
+ --------
+        >>> info = paddle.__array_namespace_info__()
+ >>> info.default_dtypes()
+ {'real floating': paddle.float32,
+ 'complex floating': paddle.complex64,
+ 'integral': paddle.int64,
+ 'indexing': paddle.int64}
+
+ """
+        # Note: the defaults are global in Paddle and come from
+        # paddle.get_default_dtype(); the same values are returned regardless
+        # of the requested device.
+        default_floating = getattr(paddle, paddle.get_default_dtype())
+        default_complex = (
+            paddle.complex64
+            if default_floating == paddle.float32
+            else paddle.complex128
+        )
+        default_integral = paddle.int64
+ return {
+ "real floating": default_floating,
+ "complex floating": default_complex,
+ "integral": default_integral,
+ "indexing": default_integral,
+ }
+
+ def _dtypes(self, kind):
+ bool = paddle.bool
+ int8 = paddle.int8
+ int16 = paddle.int16
+ int32 = paddle.int32
+ int64 = paddle.int64
+ uint8 = paddle.uint8
+        # Paddle does not provide uint16, uint32, or uint64 dtypes, so they
+        # are omitted here.
+ float32 = paddle.float32
+ float64 = paddle.float64
+ complex64 = paddle.complex64
+ complex128 = paddle.complex128
+
+ if kind is None:
+ return {
+ "bool": bool,
+ "int8": int8,
+ "int16": int16,
+ "int32": int32,
+ "int64": int64,
+ "uint8": uint8,
+ "float32": float32,
+ "float64": float64,
+ "complex64": complex64,
+ "complex128": complex128,
+ }
+ if kind == "bool":
+ return {"bool": bool}
+ if kind == "signed integer":
+ return {
+ "int8": int8,
+ "int16": int16,
+ "int32": int32,
+ "int64": int64,
+ }
+ if kind == "unsigned integer":
+ return {
+ "uint8": uint8,
+ }
+ if kind == "integral":
+ return {
+ "int8": int8,
+ "int16": int16,
+ "int32": int32,
+ "int64": int64,
+ "uint8": uint8,
+ }
+ if kind == "real floating":
+ return {
+ "float32": float32,
+ "float64": float64,
+ }
+ if kind == "complex floating":
+ return {
+ "complex64": complex64,
+ "complex128": complex128,
+ }
+ if kind == "numeric":
+ return {
+ "int8": int8,
+ "int16": int16,
+ "int32": int32,
+ "int64": int64,
+ "uint8": uint8,
+ "float32": float32,
+ "float64": float64,
+ "complex64": complex64,
+ "complex128": complex128,
+ }
+ if isinstance(kind, tuple):
+ res = {}
+ for k in kind:
+ res.update(self.dtypes(kind=k))
+ return res
+ raise ValueError(f"unsupported kind: {kind!r}")
+
+ @cache
+ def dtypes(self, *, device=None, kind=None):
+ """
+        The array API data types supported by Paddle.
+
+ Note that this function only returns data types that are defined by
+ the array API.
+
+ Parameters
+ ----------
+ device : str, optional
+ The device to get the data types for.
+ kind : str or tuple of str, optional
+ The kind of data types to return. If ``None``, all data types are
+ returned. If a string, only data types of that kind are returned.
+ If a tuple, a dictionary containing the union of the given kinds
+ is returned. The following kinds are supported:
+
+ - ``'bool'``: boolean data types (i.e., ``bool``).
+ - ``'signed integer'``: signed integer data types (i.e., ``int8``,
+ ``int16``, ``int32``, ``int64``).
+ - ``'unsigned integer'``: unsigned integer data types (i.e.,
+ ``uint8``, ``uint16``, ``uint32``, ``uint64``).
+ - ``'integral'``: integer data types. Shorthand for ``('signed
+ integer', 'unsigned integer')``.
+ - ``'real floating'``: real-valued floating-point data types
+ (i.e., ``float32``, ``float64``).
+ - ``'complex floating'``: complex floating-point data types (i.e.,
+ ``complex64``, ``complex128``).
+ - ``'numeric'``: numeric data types. Shorthand for ``('integral',
+ 'real floating', 'complex floating')``.
+
+ Returns
+ -------
+ dtypes : dict
+ A dictionary mapping the names of data types to the corresponding
+            Paddle data types.
+
+ See Also
+ --------
+ __array_namespace_info__.capabilities,
+ __array_namespace_info__.default_device,
+ __array_namespace_info__.default_dtypes,
+ __array_namespace_info__.devices
+
+ Examples
+ --------
+        >>> info = paddle.__array_namespace_info__()
+        >>> info.dtypes(kind='signed integer')
+        {'int8': paddle.int8,
+         'int16': paddle.int16,
+         'int32': paddle.int32,
+         'int64': paddle.int64}
+
+ """
+ res = self._dtypes(kind)
+        for k, v in res.copy().items():
+            try:
+                # paddle.empty has no device argument, so only check that the
+                # dtype itself is constructible.
+                paddle.empty((0,), dtype=v)
+            except Exception:
+                del res[k]
+        return res
+
+ @cache
+ def devices(self):
+ """
+        The devices supported by Paddle.
+
+ Returns
+ -------
+ devices : list of str
+            The devices supported by Paddle.
+
+ See Also
+ --------
+ __array_namespace_info__.capabilities,
+ __array_namespace_info__.default_device,
+ __array_namespace_info__.default_dtypes,
+ __array_namespace_info__.dtypes
+
+ Examples
+ --------
+        >>> info = paddle.__array_namespace_info__()
+        >>> info.devices()
+        [Place(cpu), Place(gpu:0)]
+
+ """
+ # Paddle doesn't have a straightforward way to get the list of all
+ # currently supported devices. To do this, we first parse the error
+        # message of paddle.device.set_device to get the list of all possible
+        # types of device:
+        try:
+            paddle.device.set_device("notadevice")
+        except (RuntimeError, ValueError) as e:
+            # The error message is something like:
+            # "The device must be a string which is like 'cpu', 'gpu', 'gpu:x', 'xpu', 'xpu:x', 'npu', 'npu:x'"
+            devices_names = (
+                e.args[0]
+                .split("The device must be a string which is like ")[1]
+                .split(", ")
+            )
+ devices_names = [
+ name.strip("'") for name in devices_names if ":" not in name
+ ]
+
+ # Next we need to check for different indices for different devices.
+ # device(device_name, index=index) doesn't actually check if the
+ # device name or index is valid. We have to try to create a tensor
+ # with it (which is why this function is cached).
+ devices = []
+ for device_name in devices_names:
+ i = 0
+ while True:
+ try:
+                    # paddle.empty has no place argument; paddle.to_tensor does.
+                    if device_name == "cpu":
+                        a = paddle.to_tensor([], place=paddle.CPUPlace())
+                    elif device_name == "gpu":
+                        a = paddle.to_tensor([], place=paddle.CUDAPlace(i))
+                    elif device_name == "xpu":
+                        a = paddle.to_tensor([], place=paddle.XPUPlace(i))
+                    else:
+                        break
+                    if a.place in devices:
+                        break
+                    devices.append(a.place)
+                except Exception:
+                    break
+                i += 1
+
+ return devices
diff --git a/array_api_compat/paddle/fft.py b/array_api_compat/paddle/fft.py
new file mode 100644
index 00000000..15519b5a
--- /dev/null
+++ b/array_api_compat/paddle/fft.py
@@ -0,0 +1,92 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ import paddle
+
+ array = paddle.Tensor
+ from typing import Union, Sequence, Literal
+
+from paddle.fft import * # noqa: F403
+import paddle.fft
+
+
+def fftn(
+ x: array,
+ /,
+ *,
+ s: Sequence[int] = None,
+ axes: Sequence[int] = None,
+ norm: Literal["backward", "ortho", "forward"] = "backward",
+ **kwargs,
+) -> array:
+ return paddle.fft.fftn(x, s=s, axes=axes, norm=norm, **kwargs)
+
+
+def ifftn(
+ x: array,
+ /,
+ *,
+ s: Sequence[int] = None,
+ axes: Sequence[int] = None,
+ norm: Literal["backward", "ortho", "forward"] = "backward",
+ **kwargs,
+) -> array:
+ return paddle.fft.ifftn(x, s=s, axes=axes, norm=norm, **kwargs)
+
+
+def rfftn(
+ x: array,
+ /,
+ *,
+ s: Sequence[int] = None,
+ axes: Sequence[int] = None,
+ norm: Literal["backward", "ortho", "forward"] = "backward",
+ **kwargs,
+) -> array:
+ return paddle.fft.rfftn(x, s=s, axes=axes, norm=norm, **kwargs)
+
+
+def irfftn(
+ x: array,
+ /,
+ *,
+ s: Sequence[int] = None,
+ axes: Sequence[int] = None,
+ norm: Literal["backward", "ortho", "forward"] = "backward",
+ **kwargs,
+) -> array:
+ return paddle.fft.irfftn(x, s=s, axes=axes, norm=norm, **kwargs)
+
+
+def fftshift(
+ x: array,
+ /,
+ *,
+ axes: Union[int, Sequence[int]] = None,
+ **kwargs,
+) -> array:
+ return paddle.fft.fftshift(x, axes=axes, **kwargs)
+
+
+def ifftshift(
+ x: array,
+ /,
+ *,
+ axes: Union[int, Sequence[int]] = None,
+ **kwargs,
+) -> array:
+ return paddle.fft.ifftshift(x, axes=axes, **kwargs)
+
+
+__all__ = paddle.fft.__all__ + [
+ "fftn",
+ "ifftn",
+ "rfftn",
+ "irfftn",
+ "fftshift",
+ "ifftshift",
+]
+
+_all_ignore = ["paddle"]
diff --git a/array_api_compat/paddle/linalg.py b/array_api_compat/paddle/linalg.py
new file mode 100644
index 00000000..6ee57fcf
--- /dev/null
+++ b/array_api_compat/paddle/linalg.py
@@ -0,0 +1,136 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ import paddle
+
+ array = paddle.Tensor
+ from paddle import dtype as Dtype
+ from typing import Optional, Union, Tuple, Literal
+
+ inf = float("inf")
+
+from ._aliases import _fix_promotion, sum
+
+import paddle
+from paddle.linalg import * # noqa: F403
+
+# paddle.linalg doesn't define __all__
+# from paddle.linalg import __all__ as linalg_all
+from paddle import linalg as paddle_linalg
+
+linalg_all = [i for i in dir(paddle_linalg) if not i.startswith("_")]
+
+# outer is implemented in paddle but isn't in the linalg namespace
+from paddle import outer
+
+# These functions are in both the main and linalg namespaces
+from ._aliases import matmul, matrix_transpose, tensordot
+
+# Note: paddle.linalg.cross does not default to axis=-1 (it defaults to the
+# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
+
+
+# paddle.cross also does not support broadcasting when it would add new
+# dimensions https://github.com/pytorch/pytorch/issues/39656
+def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
+ x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
+ if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)):
+ raise ValueError(
+ f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}"
+ )
+
+ if not (x1.shape[axis] == x2.shape[axis] == 3):
+ raise ValueError(
+ f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}"
+ )
+
+    x1, x2 = paddle.broadcast_tensors([x1, x2])
+ return paddle_linalg.cross(x1, x2, axis=axis)
+
+
+def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
+ from ._aliases import isdtype
+
+ x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
+
+ # paddle.linalg.vecdot incorrectly allows broadcasting along the contracted dimension
+ if x1.shape[axis] != x2.shape[axis]:
+ raise ValueError("x1 and x2 must have the same size along the given axis")
+
+ # paddle.linalg.vecdot doesn't support integer dtypes
+ if isdtype(x1.dtype, "integral") or isdtype(x2.dtype, "integral"):
+ if kwargs:
+ raise RuntimeError("vecdot kwargs not supported for integral dtypes")
+
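+        # Fall back to a batched matmul: with the contracted axis moved to
+        # the end, x1_[..., None, :] @ x2_[..., None] has shape (..., 1, 1),
+        # and the trailing [..., 0, 0] drops the two singleton dimensions.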
+ x1_ = paddle.moveaxis(x1, axis, -1)
+ x2_ = paddle.moveaxis(x2, axis, -1)
+    x1_, x2_ = paddle.broadcast_tensors([x1_, x2_])
+
+ res = x1_[..., None, :] @ x2_[..., None]
+ return res[..., 0, 0]
+ return paddle.linalg.vecdot(x1, x2, axis=axis, **kwargs)
+
+
+def solve(x1: array, x2: array, /, **kwargs) -> array:
+ x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
+
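+    # Guard against the stacked-vector ambiguity (cf. the note in the torch
+    # wrapper below): when x1.ndim - 1 == x2.ndim and x1.shape[:-1] == x2.shape,
+    # add a leading axis so x2 is treated as a matrix right-hand side.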
+ if x2.ndim != 1 and x1.ndim - 1 == x2.ndim and x1.shape[:-1] == x2.shape:
+ x2 = x2[None]
+ return paddle.linalg.solve(x1, x2, **kwargs)
+
+
+# paddle.trace doesn't support the offset argument and doesn't support stacking
+def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
+ # Use our wrapped sum to make sure it does upcasting correctly
+ return sum(
+ paddle.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype
+ )
+
+
+def vector_norm(
+ x: array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ keepdims: bool = False,
+ ord: Union[int, float, Literal[inf, -inf]] = 2,
+ **kwargs,
+) -> array:
+ # paddle.vector_norm incorrectly treats axis=() the same as axis=None
+ if axis == ():
+ out = kwargs.get("out")
+ if out is None:
+ dtype = None
+ if x.dtype == paddle.complex64:
+ dtype = paddle.float32
+ elif x.dtype == paddle.complex128:
+ dtype = paddle.float64
+
+ out = paddle.zeros_like(x, dtype=dtype)
+
+ # The norm of a single scalar works out to abs(x) in every case except
+ # for ord=0, which is x != 0.
+ if ord == 0:
+ out[:] = x != 0
+ else:
+ out[:] = paddle.abs(x)
+ return out
+ return paddle.linalg.vector_norm(x, p=ord, axis=axis, keepdim=keepdims, **kwargs)
+
+
+__all__ = linalg_all + [
+ "outer",
+ "matmul",
+ "matrix_transpose",
+ "tensordot",
+ "cross",
+ "vecdot",
+ "solve",
+ "trace",
+ "vector_norm",
+]
+
+_all_ignore = ["paddle_linalg", "sum"]
+
+del linalg_all
diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py
index 3c9117ee..59c306af 100644
--- a/array_api_compat/torch/fft.py
+++ b/array_api_compat/torch/fft.py
@@ -2,14 +2,14 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
- import torch
- array = torch.Tensor
+ import paddle
+ array = paddle.Tensor
from typing import Union, Sequence, Literal
-from torch.fft import * # noqa: F403
-import torch.fft
+from paddle.fft import * # noqa: F403
+import paddle.fft
-# Several torch fft functions do not map axes to dim
+# Several paddle fft functions do not map axes to dim
def fftn(
x: array,
@@ -20,7 +20,7 @@ def fftn(
norm: Literal["backward", "ortho", "forward"] = "backward",
**kwargs,
) -> array:
- return torch.fft.fftn(x, s=s, dim=axes, norm=norm, **kwargs)
+ return paddle.fft.fftn(x, s=s, axes=axes, norm=norm, **kwargs)
def ifftn(
x: array,
@@ -31,7 +31,7 @@ def ifftn(
norm: Literal["backward", "ortho", "forward"] = "backward",
**kwargs,
) -> array:
- return torch.fft.ifftn(x, s=s, dim=axes, norm=norm, **kwargs)
+ return paddle.fft.ifftn(x, s=s, axes=axes, norm=norm, **kwargs)
def rfftn(
x: array,
@@ -42,7 +42,7 @@ def rfftn(
norm: Literal["backward", "ortho", "forward"] = "backward",
**kwargs,
) -> array:
- return torch.fft.rfftn(x, s=s, dim=axes, norm=norm, **kwargs)
+ return paddle.fft.rfftn(x, s=s, axes=axes, norm=norm, **kwargs)
def irfftn(
x: array,
@@ -53,7 +53,7 @@ def irfftn(
norm: Literal["backward", "ortho", "forward"] = "backward",
**kwargs,
) -> array:
- return torch.fft.irfftn(x, s=s, dim=axes, norm=norm, **kwargs)
+ return paddle.fft.irfftn(x, s=s, axes=axes, norm=norm, **kwargs)
def fftshift(
x: array,
@@ -62,7 +62,7 @@ def fftshift(
axes: Union[int, Sequence[int]] = None,
**kwargs,
) -> array:
- return torch.fft.fftshift(x, dim=axes, **kwargs)
+ return paddle.fft.fftshift(x, axes=axes, **kwargs)
def ifftshift(
x: array,
@@ -71,10 +71,10 @@ def ifftshift(
axes: Union[int, Sequence[int]] = None,
**kwargs,
) -> array:
- return torch.fft.ifftshift(x, dim=axes, **kwargs)
+ return paddle.fft.ifftshift(x, axes=axes, **kwargs)
-__all__ = torch.fft.__all__ + [
+__all__ = paddle.fft.__all__ + [
"fftn",
"ifftn",
"rfftn",
@@ -83,4 +83,4 @@ def ifftshift(
"ifftshift",
]
-_all_ignore = ['torch']
+_all_ignore = ['paddle']
diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py
index e26198b9..5e4ee47b 100644
--- a/array_api_compat/torch/linalg.py
+++ b/array_api_compat/torch/linalg.py
@@ -2,86 +2,84 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
- import torch
- array = torch.Tensor
- from torch import dtype as Dtype
+ import paddle
+ array = paddle.Tensor
+ from paddle import dtype as Dtype
from typing import Optional, Union, Tuple, Literal
inf = float('inf')
from ._aliases import _fix_promotion, sum
-from torch.linalg import * # noqa: F403
+from paddle.linalg import * # noqa: F403
-# torch.linalg doesn't define __all__
-# from torch.linalg import __all__ as linalg_all
-from torch import linalg as torch_linalg
-linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')]
+# paddle.linalg doesn't define __all__
+# from paddle.linalg import __all__ as linalg_all
+from paddle import linalg as paddle_linalg
+linalg_all = [i for i in dir(paddle_linalg) if not i.startswith('_')]
-# outer is implemented in torch but aren't in the linalg namespace
-from torch import outer
+# outer is implemented in paddle but aren't in the linalg namespace
+from paddle import outer
# These functions are in both the main and linalg namespaces
from ._aliases import matmul, matrix_transpose, tensordot
-# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
-# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
+# Note: paddle.linalg.cross does not default to axis=-1 (it defaults to the
-# torch.cross also does not support broadcasting when it would add new
-# dimensions https://github.com/pytorch/pytorch/issues/39656
+# paddle.cross also does not support broadcasting when it would add new
def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)):
raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}")
if not (x1.shape[axis] == x2.shape[axis] == 3):
raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}")
- x1, x2 = torch.broadcast_tensors(x1, x2)
- return torch_linalg.cross(x1, x2, dim=axis)
+ x1, x2 = paddle.broadcast_tensors(x1, x2)
+ return paddle_linalg.cross(x1, x2, axis=axis)
def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
from ._aliases import isdtype
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
- # torch.linalg.vecdot incorrectly allows broadcasting along the contracted dimension
+ # paddle.linalg.vecdot incorrectly allows broadcasting along the contracted dimension
if x1.shape[axis] != x2.shape[axis]:
raise ValueError("x1 and x2 must have the same size along the given axis")
- # torch.linalg.vecdot doesn't support integer dtypes
+ # paddle.linalg.vecdot doesn't support integer dtypes
if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'):
if kwargs:
raise RuntimeError("vecdot kwargs not supported for integral dtypes")
- x1_ = torch.moveaxis(x1, axis, -1)
- x2_ = torch.moveaxis(x2, axis, -1)
- x1_, x2_ = torch.broadcast_tensors(x1_, x2_)
+ x1_ = paddle.moveaxis(x1, axis, -1)
+ x2_ = paddle.moveaxis(x2, axis, -1)
+ x1_, x2_ = paddle.broadcast_tensors(x1_, x2_)
res = x1_[..., None, :] @ x2_[..., None]
return res[..., 0, 0]
- return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs)
+ return paddle.linalg.vecdot(x1, x2, axis=axis, **kwargs)
def solve(x1: array, x2: array, /, **kwargs) -> array:
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
- # Torch tries to emulate NumPy 1 solve behavior by using batched 1-D solve
+ # paddle tries to emulate NumPy 1 solve behavior by using batched 1-D solve
# whenever
# 1. x1.ndim - 1 == x2.ndim
# 2. x1.shape[:-1] == x2.shape
#
# See linalg_solve_is_vector_rhs in
# aten/src/ATen/native/LinearAlgebraUtils.h and
- # TORCH_META_FUNC(_linalg_solve_ex) in
- # aten/src/ATen/native/BatchLinearAlgebra.cpp in the PyTorch source code.
+ # paddle_META_FUNC(_linalg_solve_ex) in
+ # aten/src/ATen/native/BatchLinearAlgebra.cpp in the Pypaddle source code.
#
# The easiest way to work around this is to prepend a size 1 dimension to
# x2, since x2 is already one dimension less than x1.
#
- # See https://github.com/pytorch/pytorch/issues/52915
+ # See https://github.com/pypaddle/pypaddle/issues/52915
if x2.ndim != 1 and x1.ndim - 1 == x2.ndim and x1.shape[:-1] == x2.shape:
x2 = x2[None]
- return torch.linalg.solve(x1, x2, **kwargs)
+ return paddle.linalg.solve(x1, x2, **kwargs)
-# torch.trace doesn't support the offset argument and doesn't support stacking
+# paddle.trace doesn't support the offset argument and doesn't support stacking
def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
# Use our wrapped sum to make sure it does upcasting correctly
- return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
+ return sum(paddle.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
def vector_norm(
x: array,
@@ -92,30 +90,30 @@ def vector_norm(
ord: Union[int, float, Literal[inf, -inf]] = 2,
**kwargs,
) -> array:
- # torch.vector_norm incorrectly treats axis=() the same as axis=None
+ # paddle.vector_norm incorrectly treats axis=() the same as axis=None
if axis == ():
out = kwargs.get('out')
if out is None:
dtype = None
- if x.dtype == torch.complex64:
- dtype = torch.float32
- elif x.dtype == torch.complex128:
- dtype = torch.float64
+ if x.dtype == paddle.complex64:
+ dtype = paddle.float32
+ elif x.dtype == paddle.complex128:
+ dtype = paddle.float64
- out = torch.zeros_like(x, dtype=dtype)
+ out = paddle.zeros_like(x, dtype=dtype)
# The norm of a single scalar works out to abs(x) in every case except
- # for ord=0, which is x != 0.
+ # for p=0, which is x != 0.
if ord == 0:
out[:] = (x != 0)
else:
- out[:] = torch.abs(x)
+ out[:] = paddle.abs(x)
return out
- return torch.linalg.vector_norm(x, ord=ord, axis=axis, keepdim=keepdims, **kwargs)
+ return paddle.linalg.vector_norm(x, p=ord, axis=axis, keepdim=keepdims, **kwargs)
__all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot',
'cross', 'vecdot', 'solve', 'trace', 'vector_norm']
-_all_ignore = ['torch_linalg', 'sum']
+_all_ignore = ['paddle_linalg', 'sum']
del linalg_all
diff --git a/docs/index.md b/docs/index.md
index ef18265e..874c3866 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -60,6 +60,10 @@ import array_api_compat.torch as torch
import array_api_compat.dask as da
```
+```py
+import array_api_compat.paddle as paddle
+```
+
```{note}
 There are no `array_api_compat` submodules for JAX, sparse, or ndonnx. The
support for these libraries is contained in the libraries themselves (JAX
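
A quick usage sketch for the new submodule (hedged; it assumes the re-exported `ones`, `float32`, `matmul`, and `matrix_transpose` names provided by the wrappers in this patch):

```py
import array_api_compat.paddle as xp

a = xp.ones((2, 3), dtype=xp.float32)
b = xp.matmul(a, xp.matrix_transpose(a))  # shape (2, 2), all entries 3.0
```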
diff --git a/docs/supported-array-libraries.md b/docs/supported-array-libraries.md
index a016a636..fa30ccd2 100644
--- a/docs/supported-array-libraries.md
+++ b/docs/supported-array-libraries.md
@@ -137,3 +137,26 @@ The minimum supported Dask version is 2023.12.0.
## [Sparse](https://sparse.pydata.org/en/stable/)
Similar to JAX, `sparse` Array API support is contained directly in `sparse`.
+
+## [Paddle](https://www.paddlepaddle.org.cn/)
+
+- Like NumPy/CuPy, we do not wrap the `paddle.Tensor` object. It is missing the
+ `__array_namespace__` and `to_device` methods, so the corresponding helper
+ functions {func}`~.array_namespace()` and {func}`~.to_device()` in this
+ library should be used instead.
+
+- Paddle does not have unsigned integer types other than `uint8`, and no
+ attempt is made to implement them here.
+
+- [`std()`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.std.html#array_api.std)
+ and
+ [`var()`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.var.html#array_api.var)
+ do not support floating-point `correction` except for `0.0` and `1.0`.
+
+- The `stream` argument of the {func}`~.to_device()` helper is not supported.
+
+- As with NumPy, type annotations and positional-only arguments may not
+ exactly match the spec for functions that are not wrapped at all.
+
+The minimum supported PyTorch version is 1.13.
+
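Since `paddle.Tensor` lacks the `__array_namespace__` and `to_device` methods, the helper functions are the intended entry points. A hedged sketch of that pattern (device strings follow the `"cpu"`/`"gpu"` mapping added to `_helpers.py` in this series):

```py
import paddle
from array_api_compat import array_namespace, device, to_device

x = paddle.to_tensor([1.0, 2.0, 3.0])
xp = array_namespace(x)    # the array_api_compat.paddle namespace
dev = device(x)            # "cpu" (or "gpu" on a GPU build)
y = to_device(x, dev)      # stays on the same device; `stream` is not supported
```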
diff --git a/requirements-dev.txt b/requirements-dev.txt
index c9d10f71..ae41a25e 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -4,5 +4,6 @@ jax[cpu]
numpy
pytest
torch
+paddlepaddle
sparse >=0.15.1
ndonnx
diff --git a/tests/_helpers.py b/tests/_helpers.py
index e2a7e1d1..0321bcb4 100644
--- a/tests/_helpers.py
+++ b/tests/_helpers.py
@@ -3,12 +3,12 @@
import pytest
-wrapped_libraries = ["numpy", "cupy", "torch", "dask.array"]
-all_libraries = wrapped_libraries + ["jax.numpy"]
+wrapped_libraries = ["numpy", "paddle"]
+all_libraries = wrapped_libraries + []
# `sparse` added array API support as of Python 3.10.
-if sys.version_info >= (3, 10):
- all_libraries.append('sparse')
+# if sys.version_info >= (3, 10):
+# all_libraries.append('sparse')
def import_(library, wrapper=False):
if library == 'cupy':
@@ -25,4 +25,9 @@ def import_(library, wrapper=False):
else:
library = 'array_api_compat.' + library
+ if library == 'paddle':
+ xp = import_module(library)
+ xp.asarray = xp.to_tensor
+ return xp
+
return import_module(library)
diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py
index 9c26371c..4b494ec3 100644
--- a/tests/test_array_namespace.py
+++ b/tests/test_array_namespace.py
@@ -2,10 +2,11 @@
import sys
import warnings
-import jax
+# import jax
import numpy as np
import pytest
-import torch
+# import torch
+import paddle
import array_api_compat
from array_api_compat import array_namespace
@@ -72,11 +73,11 @@ def test_array_namespace(library, api_version, use_compat):
"""
subprocess.run([sys.executable, "-c", code], check=True)
-def test_jax_zero_gradient():
- jx = jax.numpy.arange(4)
- jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx)
- assert (array_api_compat.get_namespace(jax_zero) is
- array_api_compat.get_namespace(jx))
+# def test_jax_zero_gradient():
+# jx = jax.numpy.arange(4)
+# jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx)
+# assert (array_api_compat.get_namespace(jax_zero) is
+# array_api_compat.get_namespace(jx))
def test_array_namespace_errors():
pytest.raises(TypeError, lambda: array_namespace([1]))
@@ -86,26 +87,53 @@ def test_array_namespace_errors():
pytest.raises(TypeError, lambda: array_namespace((x, x)))
pytest.raises(TypeError, lambda: array_namespace(x, (x, x)))
-def test_array_namespace_errors_torch():
- y = torch.asarray([1, 2])
+# def test_array_namespace_errors_torch():
+# y = torch.asarray([1, 2])
+# x = np.asarray([1, 2])
+# pytest.raises(TypeError, lambda: array_namespace(x, y))
+
+
+def test_array_namespace_errors_paddle():
+ y = paddle.to_tensor([1, 2])
x = np.asarray([1, 2])
pytest.raises(TypeError, lambda: array_namespace(x, y))
+
+# def test_api_version():
+# x = torch.asarray([1, 2])
+# torch_ = import_("torch", wrapper=True)
+# assert array_namespace(x, api_version="2023.12") == torch_
+# assert array_namespace(x, api_version=None) == torch_
+# assert array_namespace(x) == torch_
+# # Should issue a warning
+# with warnings.catch_warnings(record=True) as w:
+# assert array_namespace(x, api_version="2021.12") == torch_
+# assert len(w) == 1
+# assert "2021.12" in str(w[0].message)
+
+# # Should issue a warning
+# with warnings.catch_warnings(record=True) as w:
+# assert array_namespace(x, api_version="2022.12") == torch_
+# assert len(w) == 1
+# assert "2022.12" in str(w[0].message)
+
+# pytest.raises(ValueError, lambda: array_namespace(x, api_version="2020.12"))
+
def test_api_version():
- x = torch.asarray([1, 2])
- torch_ = import_("torch", wrapper=True)
- assert array_namespace(x, api_version="2023.12") == torch_
- assert array_namespace(x, api_version=None) == torch_
- assert array_namespace(x) == torch_
+ x = paddle.asarray([1, 2])
+ paddle_ = import_("paddle", wrapper=True)
+ assert array_namespace(x, api_version="2023.12") == paddle_
+ assert array_namespace(x, api_version=None) == paddle_
+ assert array_namespace(x) == paddle_
# Should issue a warning
with warnings.catch_warnings(record=True) as w:
- assert array_namespace(x, api_version="2021.12") == torch_
+ assert array_namespace(x, api_version="2021.12") == paddle_
assert len(w) == 1
assert "2021.12" in str(w[0].message)
# Should issue a warning
with warnings.catch_warnings(record=True) as w:
- assert array_namespace(x, api_version="2022.12") == torch_
+ assert array_namespace(x, api_version="2022.12") == paddle_
assert len(w) == 1
assert "2022.12" in str(w[0].message)
@@ -130,3 +158,19 @@ def test_python_scalars():
assert array_namespace(a, 1j) == xp
assert array_namespace(a, True) == xp
assert array_namespace(a, None) == xp
+
+def test_python_scalars():
+ a = paddle.to_tensor([1, 2])
+ xp = import_("paddle", wrapper=True)
+
+ pytest.raises(TypeError, lambda: array_namespace(1))
+ pytest.raises(TypeError, lambda: array_namespace(1.0))
+ pytest.raises(TypeError, lambda: array_namespace(1j))
+ pytest.raises(TypeError, lambda: array_namespace(True))
+ pytest.raises(TypeError, lambda: array_namespace(None))
+
+ assert array_namespace(a, 1) == xp
+ assert array_namespace(a, 1.0) == xp
+ assert array_namespace(a, 1j) == xp
+ assert array_namespace(a, True) == xp
+ assert array_namespace(a, None) == xp
diff --git a/tests/test_common.py b/tests/test_common.py
index e1cfa9eb..5c0b5826 100644
--- a/tests/test_common.py
+++ b/tests/test_common.py
@@ -1,8 +1,8 @@
from array_api_compat import ( # noqa: F401
- is_numpy_array, is_cupy_array, is_torch_array,
+ is_numpy_array, is_cupy_array, is_torch_array, is_paddle_array,
is_dask_array, is_jax_array, is_pydata_sparse_array,
is_numpy_namespace, is_cupy_namespace, is_torch_namespace,
- is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
+ is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace, is_paddle_namespace,
)
from array_api_compat import is_array_api_obj, device, to_device
@@ -16,20 +16,22 @@
is_array_functions = {
'numpy': 'is_numpy_array',
- 'cupy': 'is_cupy_array',
- 'torch': 'is_torch_array',
- 'dask.array': 'is_dask_array',
- 'jax.numpy': 'is_jax_array',
- 'sparse': 'is_pydata_sparse_array',
+ # 'cupy': 'is_cupy_array',
+ # 'torch': 'is_torch_array',
+ # 'dask.array': 'is_dask_array',
+ # 'jax.numpy': 'is_jax_array',
+ # 'sparse': 'is_pydata_sparse_array',
+ 'paddle': 'is_paddle_array',
}
is_namespace_functions = {
'numpy': 'is_numpy_namespace',
- 'cupy': 'is_cupy_namespace',
- 'torch': 'is_torch_namespace',
- 'dask.array': 'is_dask_namespace',
- 'jax.numpy': 'is_jax_namespace',
- 'sparse': 'is_pydata_sparse_namespace',
+ # 'cupy': 'is_cupy_namespace',
+ # 'torch': 'is_torch_namespace',
+ # 'dask.array': 'is_dask_namespace',
+ # 'jax.numpy': 'is_jax_namespace',
+ # 'sparse': 'is_pydata_sparse_namespace',
+ 'paddle': 'is_paddle_namespace',
}
@@ -114,6 +116,8 @@ def test_asarray_cross_library(source_library, target_library, request):
@pytest.mark.parametrize("library", wrapped_libraries)
def test_asarray_copy(library):
+ if library == 'paddle':
+ pytest.skip("Paddle does not support explicit copies")
# Note, we have this test here because the test suite currently doesn't
# test the copy flag to asarray() very rigorously. Once
# https://github.com/data-apis/array-api-tests/issues/241 is fixed we
diff --git a/tests/test_isdtype.py b/tests/test_isdtype.py
index 6ad45d4c..e7b7d9c1 100644
--- a/tests/test_isdtype.py
+++ b/tests/test_isdtype.py
@@ -10,7 +10,7 @@
# Check the known dtypes by their string names
def _spec_dtypes(library):
- if library == 'torch':
+ if library in ['torch', 'paddle']:
# torch does not have unsigned integer dtypes
return {
'bool',
diff --git a/tests/test_no_dependencies.py b/tests/test_no_dependencies.py
index a1fdf731..201f98ea 100644
--- a/tests/test_no_dependencies.py
+++ b/tests/test_no_dependencies.py
@@ -49,8 +49,12 @@ def _test_dependency(mod):
# TODO: Test that wrapper for library X doesn't depend on wrappers for library
# Y (except most array libraries actually do themselves depend on numpy).
-@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array",
- "jax.numpy", "sparse", "array_api_strict"])
+@pytest.mark.parametrize("library",
+ [
+ "numpy",
+ "paddle", "array_api_strict",
+ ]
+)
def test_numpy_dependency(library):
# This import is here because it imports numpy
from ._helpers import import_
diff --git a/tests/test_vendoring.py b/tests/test_vendoring.py
index 70083b49..91ca1709 100644
--- a/tests/test_vendoring.py
+++ b/tests/test_vendoring.py
@@ -7,20 +7,26 @@ def test_vendoring_numpy():
uses_numpy._test_numpy()
-def test_vendoring_cupy():
- pytest.importorskip("cupy")
+# def test_vendoring_cupy():
+# pytest.importorskip("cupy")
- from vendor_test import uses_cupy
+# from vendor_test import uses_cupy
- uses_cupy._test_cupy()
+# uses_cupy._test_cupy()
-def test_vendoring_torch():
- from vendor_test import uses_torch
+# def test_vendoring_torch():
+# from vendor_test import uses_torch
- uses_torch._test_torch()
+# uses_torch._test_torch()
-def test_vendoring_dask():
- from vendor_test import uses_dask
- uses_dask._test_dask()
+# def test_vendoring_dask():
+# from vendor_test import uses_dask
+# uses_dask._test_dask()
+
+
+def test_vendoring_paddle():
+ from vendor_test import uses_paddle
+
+ uses_paddle._test_paddle()
diff --git a/vendor_test/uses_paddle.py b/vendor_test/uses_paddle.py
new file mode 100644
index 00000000..e92257a4
--- /dev/null
+++ b/vendor_test/uses_paddle.py
@@ -0,0 +1,30 @@
+# Basic test that vendoring works
+
+from .vendored._compat import (
+ is_paddle_array,
+ is_paddle_namespace,
+ paddle as paddle_compat,
+)
+
+import paddle
+
+def _test_paddle():
+ a = paddle_compat.to_tensor([1., 2., 3.])
+ b = paddle_compat.arange(3, dtype=paddle_compat.float64)
+ assert a.dtype == paddle_compat.float32 == paddle.float32
+ assert b.dtype == paddle_compat.float64 == paddle.float64
+
+ # paddle.expand_dims does not exist. Update this to use something else if it is added
+ res = paddle_compat.expand_dims(a, axis=0)
+ assert res.dtype == paddle_compat.float32 == paddle.float32
+ assert res.shape == [1, 3]
+ assert isinstance(res.shape, list)
+ assert isinstance(a, paddle.Tensor)
+ assert isinstance(b, paddle.Tensor)
+ assert isinstance(res, paddle.Tensor)
+
+ assert paddle.allclose(res, paddle.to_tensor([[1., 2., 3.]]))
+
+ assert is_paddle_array(res)
+ assert is_paddle_namespace(paddle) and is_paddle_namespace(paddle_compat)
+
From 7118894fae4c1d101a8262a61248c4208fe83d56 Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Tue, 26 Nov 2024 12:45:05 +0800
Subject: [PATCH 02/19] update README
---
README.md | 4 +-
array_api_compat/torch/fft.py | 26 +++++------
array_api_compat/torch/linalg.py | 76 ++++++++++++++++----------------
3 files changed, 54 insertions(+), 52 deletions(-)
diff --git a/README.md b/README.md
index 4b0b0c9c..5c30919d 100644
--- a/README.md
+++ b/README.md
@@ -2,8 +2,8 @@
This is a small wrapper around common array libraries that is compatible with
the [Array API standard](https://data-apis.org/array-api/latest/). Currently,
-NumPy, CuPy, PyTorch, Dask, JAX, ndonnx and `sparse` are supported. If you want
+NumPy, CuPy, PyTorch, Dask, JAX, ndonnx, `sparse` and Paddle are supported. If you want
support for other array libraries, or if you encounter any issues, please [open
an issue](https://github.com/data-apis/array-api-compat/issues).
-See the documentation for more details https://data-apis.org/array-api-compat/
+See the documentation for more details: https://data-apis.org/array-api-compat/
diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py
index 59c306af..3c9117ee 100644
--- a/array_api_compat/torch/fft.py
+++ b/array_api_compat/torch/fft.py
@@ -2,14 +2,14 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
- import paddle
- array = paddle.Tensor
+ import torch
+ array = torch.Tensor
from typing import Union, Sequence, Literal
-from paddle.fft import * # noqa: F403
-import paddle.fft
+from torch.fft import * # noqa: F403
+import torch.fft
-# Several paddle fft functions do not map axes to dim
+# Several torch fft functions do not map axes to dim
def fftn(
x: array,
@@ -20,7 +20,7 @@ def fftn(
norm: Literal["backward", "ortho", "forward"] = "backward",
**kwargs,
) -> array:
- return paddle.fft.fftn(x, s=s, axes=axes, norm=norm, **kwargs)
+ return torch.fft.fftn(x, s=s, dim=axes, norm=norm, **kwargs)
def ifftn(
x: array,
@@ -31,7 +31,7 @@ def ifftn(
norm: Literal["backward", "ortho", "forward"] = "backward",
**kwargs,
) -> array:
- return paddle.fft.ifftn(x, s=s, axes=axes, norm=norm, **kwargs)
+ return torch.fft.ifftn(x, s=s, dim=axes, norm=norm, **kwargs)
def rfftn(
x: array,
@@ -42,7 +42,7 @@ def rfftn(
norm: Literal["backward", "ortho", "forward"] = "backward",
**kwargs,
) -> array:
- return paddle.fft.rfftn(x, s=s, axes=axes, norm=norm, **kwargs)
+ return torch.fft.rfftn(x, s=s, dim=axes, norm=norm, **kwargs)
def irfftn(
x: array,
@@ -53,7 +53,7 @@ def irfftn(
norm: Literal["backward", "ortho", "forward"] = "backward",
**kwargs,
) -> array:
- return paddle.fft.irfftn(x, s=s, axes=axes, norm=norm, **kwargs)
+ return torch.fft.irfftn(x, s=s, dim=axes, norm=norm, **kwargs)
def fftshift(
x: array,
@@ -62,7 +62,7 @@ def fftshift(
axes: Union[int, Sequence[int]] = None,
**kwargs,
) -> array:
- return paddle.fft.fftshift(x, axes=axes, **kwargs)
+ return torch.fft.fftshift(x, dim=axes, **kwargs)
def ifftshift(
x: array,
@@ -71,10 +71,10 @@ def ifftshift(
axes: Union[int, Sequence[int]] = None,
**kwargs,
) -> array:
- return paddle.fft.ifftshift(x, axes=axes, **kwargs)
+ return torch.fft.ifftshift(x, dim=axes, **kwargs)
-__all__ = paddle.fft.__all__ + [
+__all__ = torch.fft.__all__ + [
"fftn",
"ifftn",
"rfftn",
@@ -83,4 +83,4 @@ def ifftshift(
"ifftshift",
]
-_all_ignore = ['paddle']
+_all_ignore = ['torch']
diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py
index 5e4ee47b..e26198b9 100644
--- a/array_api_compat/torch/linalg.py
+++ b/array_api_compat/torch/linalg.py
@@ -2,84 +2,86 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
- import paddle
- array = paddle.Tensor
- from paddle import dtype as Dtype
+ import torch
+ array = torch.Tensor
+ from torch import dtype as Dtype
from typing import Optional, Union, Tuple, Literal
inf = float('inf')
from ._aliases import _fix_promotion, sum
-from paddle.linalg import * # noqa: F403
+from torch.linalg import * # noqa: F403
-# paddle.linalg doesn't define __all__
-# from paddle.linalg import __all__ as linalg_all
-from paddle import linalg as paddle_linalg
-linalg_all = [i for i in dir(paddle_linalg) if not i.startswith('_')]
+# torch.linalg doesn't define __all__
+# from torch.linalg import __all__ as linalg_all
+from torch import linalg as torch_linalg
+linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')]
-# outer is implemented in paddle but aren't in the linalg namespace
-from paddle import outer
+# outer is implemented in torch but aren't in the linalg namespace
+from torch import outer
# These functions are in both the main and linalg namespaces
from ._aliases import matmul, matrix_transpose, tensordot
-# Note: paddle.linalg.cross does not default to axis=-1 (it defaults to the
+# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
+# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
-# paddle.cross also does not support broadcasting when it would add new
+# torch.cross also does not support broadcasting when it would add new
+# dimensions https://github.com/pytorch/pytorch/issues/39656
def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)):
raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}")
if not (x1.shape[axis] == x2.shape[axis] == 3):
raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}")
- x1, x2 = paddle.broadcast_tensors(x1, x2)
- return paddle_linalg.cross(x1, x2, axis=axis)
+ x1, x2 = torch.broadcast_tensors(x1, x2)
+ return torch_linalg.cross(x1, x2, dim=axis)
def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
from ._aliases import isdtype
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
- # paddle.linalg.vecdot incorrectly allows broadcasting along the contracted dimension
+ # torch.linalg.vecdot incorrectly allows broadcasting along the contracted dimension
if x1.shape[axis] != x2.shape[axis]:
raise ValueError("x1 and x2 must have the same size along the given axis")
- # paddle.linalg.vecdot doesn't support integer dtypes
+ # torch.linalg.vecdot doesn't support integer dtypes
if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'):
if kwargs:
raise RuntimeError("vecdot kwargs not supported for integral dtypes")
- x1_ = paddle.moveaxis(x1, axis, -1)
- x2_ = paddle.moveaxis(x2, axis, -1)
- x1_, x2_ = paddle.broadcast_tensors(x1_, x2_)
+ x1_ = torch.moveaxis(x1, axis, -1)
+ x2_ = torch.moveaxis(x2, axis, -1)
+ x1_, x2_ = torch.broadcast_tensors(x1_, x2_)
res = x1_[..., None, :] @ x2_[..., None]
return res[..., 0, 0]
- return paddle.linalg.vecdot(x1, x2, axis=axis, **kwargs)
+ return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs)
def solve(x1: array, x2: array, /, **kwargs) -> array:
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
- # paddle tries to emulate NumPy 1 solve behavior by using batched 1-D solve
+ # Torch tries to emulate NumPy 1 solve behavior by using batched 1-D solve
# whenever
# 1. x1.ndim - 1 == x2.ndim
# 2. x1.shape[:-1] == x2.shape
#
# See linalg_solve_is_vector_rhs in
# aten/src/ATen/native/LinearAlgebraUtils.h and
- # paddle_META_FUNC(_linalg_solve_ex) in
- # aten/src/ATen/native/BatchLinearAlgebra.cpp in the Pypaddle source code.
+ # TORCH_META_FUNC(_linalg_solve_ex) in
+ # aten/src/ATen/native/BatchLinearAlgebra.cpp in the PyTorch source code.
#
# The easiest way to work around this is to prepend a size 1 dimension to
# x2, since x2 is already one dimension less than x1.
#
- # See https://github.com/pypaddle/pypaddle/issues/52915
+ # See https://github.com/pytorch/pytorch/issues/52915
if x2.ndim != 1 and x1.ndim - 1 == x2.ndim and x1.shape[:-1] == x2.shape:
x2 = x2[None]
- return paddle.linalg.solve(x1, x2, **kwargs)
+ return torch.linalg.solve(x1, x2, **kwargs)
-# paddle.trace doesn't support the offset argument and doesn't support stacking
+# torch.trace doesn't support the offset argument and doesn't support stacking
def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
# Use our wrapped sum to make sure it does upcasting correctly
- return sum(paddle.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
+ return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
def vector_norm(
x: array,
@@ -90,30 +92,30 @@ def vector_norm(
ord: Union[int, float, Literal[inf, -inf]] = 2,
**kwargs,
) -> array:
- # paddle.vector_norm incorrectly treats axis=() the same as axis=None
+ # torch.vector_norm incorrectly treats axis=() the same as axis=None
if axis == ():
out = kwargs.get('out')
if out is None:
dtype = None
- if x.dtype == paddle.complex64:
- dtype = paddle.float32
- elif x.dtype == paddle.complex128:
- dtype = paddle.float64
+ if x.dtype == torch.complex64:
+ dtype = torch.float32
+ elif x.dtype == torch.complex128:
+ dtype = torch.float64
- out = paddle.zeros_like(x, dtype=dtype)
+ out = torch.zeros_like(x, dtype=dtype)
# The norm of a single scalar works out to abs(x) in every case except
- # for p=0, which is x != 0.
+ # for ord=0, which is x != 0.
if ord == 0:
out[:] = (x != 0)
else:
- out[:] = paddle.abs(x)
+ out[:] = torch.abs(x)
return out
- return paddle.linalg.vector_norm(x, p=ord, axis=axis, keepdim=keepdims, **kwargs)
+ return torch.linalg.vector_norm(x, ord=ord, axis=axis, keepdim=keepdims, **kwargs)
__all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot',
'cross', 'vecdot', 'solve', 'trace', 'vector_norm']
-_all_ignore = ['paddle_linalg', 'sum']
+_all_ignore = ['torch_linalg', 'sum']
del linalg_all
From 85dc3bafb4f3a9e9e351b8f1037e32360009dd1e Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Tue, 26 Nov 2024 14:05:11 +0800
Subject: [PATCH 03/19] update promotion table and can_cast table
---
array_api_compat/common/_helpers.py | 3 +-
array_api_compat/paddle/_aliases.py | 121 ++++++++++++----------------
tests/_helpers.py | 2 +-
tests/test_all.py | 4 +-
tests/test_common.py | 11 ++-
5 files changed, 64 insertions(+), 77 deletions(-)
diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py
index ff2c213f..ec6b3e0d 100644
--- a/array_api_compat/common/_helpers.py
+++ b/array_api_compat/common/_helpers.py
@@ -144,7 +144,6 @@ def is_paddle_array(x):
import paddle
- # TODO: Should we reject ndarray subclasses?
return paddle.is_tensor(x)
def is_ndonnx_array(x):
@@ -725,7 +724,7 @@ def device(x: Array, /) -> Device:
return "cpu"
elif "gpu" in raw_place_str:
return "gpu"
- raise NotImplementedError(f"Unsupported device {raw_place_str}")
+ raise ValueError(f"Unsupported Paddle device: {x.place}")
return x.device
diff --git a/array_api_compat/paddle/_aliases.py b/array_api_compat/paddle/_aliases.py
index dabe2928..14d3de7f 100644
--- a/array_api_compat/paddle/_aliases.py
+++ b/array_api_compat/paddle/_aliases.py
@@ -42,37 +42,18 @@
paddle.complex128,
}
+# NOTE: Paddle's implicit promotion rules are a bit stricter than those of other frameworks,
+# see details: https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/guides/advanced/auto_type_promotion_cn.html
_promotion_table = {
# bool
(paddle.bool, paddle.bool): paddle.bool,
# ints
(paddle.int8, paddle.int8): paddle.int8,
- (paddle.int8, paddle.int16): paddle.int16,
- (paddle.int8, paddle.int32): paddle.int32,
- (paddle.int8, paddle.int64): paddle.int64,
- (paddle.int16, paddle.int8): paddle.int16,
(paddle.int16, paddle.int16): paddle.int16,
- (paddle.int16, paddle.int32): paddle.int32,
- (paddle.int16, paddle.int64): paddle.int64,
- (paddle.int32, paddle.int8): paddle.int32,
- (paddle.int32, paddle.int16): paddle.int32,
(paddle.int32, paddle.int32): paddle.int32,
- (paddle.int32, paddle.int64): paddle.int64,
- (paddle.int64, paddle.int8): paddle.int64,
- (paddle.int64, paddle.int16): paddle.int64,
- (paddle.int64, paddle.int32): paddle.int64,
(paddle.int64, paddle.int64): paddle.int64,
# uints
(paddle.uint8, paddle.uint8): paddle.uint8,
- # ints and uints (mixed sign)
- (paddle.int8, paddle.uint8): paddle.int16,
- (paddle.int16, paddle.uint8): paddle.int16,
- (paddle.int32, paddle.uint8): paddle.int32,
- (paddle.int64, paddle.uint8): paddle.int64,
- (paddle.uint8, paddle.int8): paddle.int16,
- (paddle.uint8, paddle.int16): paddle.int16,
- (paddle.uint8, paddle.int32): paddle.int32,
- (paddle.uint8, paddle.int64): paddle.int64,
# floats
(paddle.float32, paddle.float32): paddle.float32,
(paddle.float32, paddle.float64): paddle.float64,
@@ -158,12 +139,12 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
paddle.float64: True,
paddle.complex64: True,
paddle.complex128: True,
- paddle.uint8: False,
- paddle.int8: False,
- paddle.int16: False,
- paddle.int32: False,
- paddle.int64: False,
- paddle.bool: False,
+ paddle.uint8: True,
+ paddle.int8: True,
+ paddle.int16: True,
+ paddle.int32: True,
+ paddle.int64: True,
+ paddle.bool: True,
},
paddle.float16: {
paddle.bfloat16: True,
@@ -172,12 +153,12 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
paddle.float64: True,
paddle.complex64: True,
paddle.complex128: True,
- paddle.uint8: False,
- paddle.int8: False,
- paddle.int16: False,
- paddle.int32: False,
- paddle.int64: False,
- paddle.bool: False,
+ paddle.uint8: True,
+ paddle.int8: True,
+ paddle.int16: True,
+ paddle.int32: True,
+ paddle.int64: True,
+ paddle.bool: True,
},
paddle.float32: {
paddle.bfloat16: True,
@@ -186,12 +167,12 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
paddle.float64: True,
paddle.complex64: True,
paddle.complex128: True,
- paddle.uint8: False,
- paddle.int8: False,
- paddle.int16: False,
- paddle.int32: False,
- paddle.int64: False,
- paddle.bool: False,
+ paddle.uint8: True,
+ paddle.int8: True,
+ paddle.int16: True,
+ paddle.int32: True,
+ paddle.int64: True,
+ paddle.bool: True,
},
paddle.float64: {
paddle.bfloat16: True,
@@ -200,40 +181,40 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
paddle.float64: True,
paddle.complex64: True,
paddle.complex128: True,
- paddle.uint8: False,
- paddle.int8: False,
- paddle.int16: False,
- paddle.int32: False,
- paddle.int64: False,
- paddle.bool: False,
+ paddle.uint8: True,
+ paddle.int8: True,
+ paddle.int16: True,
+ paddle.int32: True,
+ paddle.int64: True,
+ paddle.bool: True,
},
paddle.complex64: {
- paddle.bfloat16: False,
- paddle.float16: False,
- paddle.float32: False,
- paddle.float64: False,
+ paddle.bfloat16: True,
+ paddle.float16: True,
+ paddle.float32: True,
+ paddle.float64: True,
paddle.complex64: True,
paddle.complex128: True,
- paddle.uint8: False,
- paddle.int8: False,
- paddle.int16: False,
- paddle.int32: False,
- paddle.int64: False,
- paddle.bool: False,
+ paddle.uint8: True,
+ paddle.int8: True,
+ paddle.int16: True,
+ paddle.int32: True,
+ paddle.int64: True,
+ paddle.bool: True,
},
paddle.complex128: {
- paddle.bfloat16: False,
- paddle.float16: False,
- paddle.float32: False,
- paddle.float64: False,
+ paddle.bfloat16: True,
+ paddle.float16: True,
+ paddle.float32: True,
+ paddle.float64: True,
paddle.complex64: True,
paddle.complex128: True,
- paddle.uint8: False,
- paddle.int8: False,
- paddle.int16: False,
- paddle.int32: False,
- paddle.int64: False,
- paddle.bool: False,
+ paddle.uint8: True,
+ paddle.int8: True,
+ paddle.int16: True,
+ paddle.int32: True,
+ paddle.int64: True,
+ paddle.bool: True,
},
paddle.uint8: {
paddle.bfloat16: True,
@@ -247,7 +228,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
paddle.int16: True,
paddle.int32: True,
paddle.int64: True,
- paddle.bool: False,
+ paddle.bool: True,
},
paddle.int8: {
paddle.bfloat16: True,
@@ -261,7 +242,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
paddle.int16: True,
paddle.int32: True,
paddle.int64: True,
- paddle.bool: False,
+ paddle.bool: True,
},
paddle.int16: {
paddle.bfloat16: True,
@@ -275,7 +256,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
paddle.int16: True,
paddle.int32: True,
paddle.int64: True,
- paddle.bool: False,
+ paddle.bool: True,
},
paddle.int32: {
paddle.bfloat16: True,
@@ -289,7 +270,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
paddle.int16: True,
paddle.int32: True,
paddle.int64: True,
- paddle.bool: False,
+ paddle.bool: True,
},
paddle.int64: {
paddle.bfloat16: True,
@@ -303,7 +284,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
paddle.int16: True,
paddle.int32: True,
paddle.int64: True,
- paddle.bool: False,
+ paddle.bool: True,
},
paddle.bool: {
paddle.bfloat16: True,
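
With the relaxed table, `can_cast` now reports the casts changed above as allowed, matching Paddle's permissive casting behaviour. A hedged sketch, assuming `can_cast` is re-exported from the compat namespace as in the torch wrapper:

```py
import paddle
from array_api_compat.paddle import can_cast

can_cast(paddle.float64, paddle.int32)  # True under the table above
can_cast(paddle.int64, paddle.bool)     # True as well after this change
```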
diff --git a/tests/_helpers.py b/tests/_helpers.py
index 0321bcb4..07f1859a 100644
--- a/tests/_helpers.py
+++ b/tests/_helpers.py
@@ -3,7 +3,7 @@
import pytest
-wrapped_libraries = ["numpy", "paddle"]
+wrapped_libraries = ["numpy", "paddle", "torch"]
all_libraries = wrapped_libraries + []
# `sparse` added array API support as of Python 3.10.
diff --git a/tests/test_all.py b/tests/test_all.py
index 969d5cfb..7528b22e 100644
--- a/tests/test_all.py
+++ b/tests/test_all.py
@@ -40,5 +40,5 @@ def test_all(library):
all_names = module.__all__
if set(dir_names) != set(all_names):
- assert set(dir_names) - set(all_names) == set(), f"Some dir() names not included in __all__ for {mod_name}"
- assert set(all_names) - set(dir_names) == set(), f"Some __all__ names not in dir() for {mod_name}"
+ assert set(dir_names) - set(all_names) == set(), f"Failed in library '{library}', some dir() names not included in __all__ for {mod_name}"
+ assert set(all_names) - set(dir_names) == set(), f"Failed in library '{library}', some __all__ names not in dir() for {mod_name}"
diff --git a/tests/test_common.py b/tests/test_common.py
index 5c0b5826..a46a2be2 100644
--- a/tests/test_common.py
+++ b/tests/test_common.py
@@ -17,7 +17,7 @@
is_array_functions = {
'numpy': 'is_numpy_array',
# 'cupy': 'is_cupy_array',
- # 'torch': 'is_torch_array',
+ 'torch': 'is_torch_array',
# 'dask.array': 'is_dask_array',
# 'jax.numpy': 'is_jax_array',
# 'sparse': 'is_pydata_sparse_array',
@@ -27,7 +27,7 @@
is_namespace_functions = {
'numpy': 'is_numpy_namespace',
# 'cupy': 'is_cupy_namespace',
- # 'torch': 'is_torch_namespace',
+ 'torch': 'is_torch_namespace',
# 'dask.array': 'is_dask_namespace',
# 'jax.numpy': 'is_jax_namespace',
# 'sparse': 'is_pydata_sparse_namespace',
@@ -103,6 +103,13 @@ def test_asarray_cross_library(source_library, target_library, request):
if source_library == "cupy" and target_library != "cupy":
# cupy explicitly disallows implicit conversions to CPU
pytest.skip(reason="cupy does not support implicit conversion to CPU")
+ if source_library == "paddle" or target_library == "paddle":
+ pytest.skip(
+ reason=(
+                "paddle does not support implicit conversion from/to other frameworks "
+                "via 'asarray'; DLPack is recommended instead."
+ )
+ )
elif source_library == "sparse" and target_library != "sparse":
pytest.skip(reason="`sparse` does not allow implicit densification")
src_lib = import_(source_library, wrapper=True)
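
The new skip encodes that Paddle tensors are not converted implicitly via `asarray`; the recommended route is DLPack. A hedged sketch of such an exchange, assuming the libraries' usual `utils.dlpack` helpers:

```py
import torch
import paddle

t = torch.arange(3, dtype=torch.float32)
capsule = torch.utils.dlpack.to_dlpack(t)      # export from torch as a DLPack capsule
p = paddle.utils.dlpack.from_dlpack(capsule)   # import into paddle, sharing memory
```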
From c5b82db6f6429a3ed6f4d08a7bac39a9ff6bfd1c Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Tue, 26 Nov 2024 14:14:17 +0800
Subject: [PATCH 04/19] update doc
---
docs/supported-array-libraries.md | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/docs/supported-array-libraries.md b/docs/supported-array-libraries.md
index fa30ccd2..26a1c1c5 100644
--- a/docs/supported-array-libraries.md
+++ b/docs/supported-array-libraries.md
@@ -158,5 +158,4 @@ Similar to JAX, `sparse` Array API support is contained directly in `sparse`.
- As with NumPy, type annotations and positional-only arguments may not
exactly match the spec for functions that are not wrapped at all.
-The minimum supported PyTorch version is 1.13.
-
+The minimum supported Paddle version is 3.0.0.
From 7b99449f2634f05ecae2d4f046355e445ec481de Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Tue, 26 Nov 2024 14:18:17 +0800
Subject: [PATCH 05/19] restore code
---
tests/_helpers.py | 8 ++---
tests/test_all.py | 4 +--
tests/test_array_namespace.py | 55 +++++++++++------------------------
tests/test_common.py | 16 +++++-----
tests/test_no_dependencies.py | 4 +--
tests/test_vendoring.py | 20 ++++++-------
6 files changed, 43 insertions(+), 64 deletions(-)
diff --git a/tests/_helpers.py b/tests/_helpers.py
index 07f1859a..801cd32d 100644
--- a/tests/_helpers.py
+++ b/tests/_helpers.py
@@ -3,12 +3,12 @@
import pytest
-wrapped_libraries = ["numpy", "paddle", "torch"]
-all_libraries = wrapped_libraries + []
+wrapped_libraries = ["numpy", "cupy", "torch", "dask.array", "paddle"]
+all_libraries = wrapped_libraries + ["jax.numpy"]
# `sparse` added array API support as of Python 3.10.
-# if sys.version_info >= (3, 10):
-# all_libraries.append('sparse')
+if sys.version_info >= (3, 10):
+ all_libraries.append('sparse')
def import_(library, wrapper=False):
if library == 'cupy':
diff --git a/tests/test_all.py b/tests/test_all.py
index 7528b22e..969d5cfb 100644
--- a/tests/test_all.py
+++ b/tests/test_all.py
@@ -40,5 +40,5 @@ def test_all(library):
all_names = module.__all__
if set(dir_names) != set(all_names):
- assert set(dir_names) - set(all_names) == set(), f"Failed in library '{library}', some dir() names not included in __all__ for {mod_name}"
- assert set(all_names) - set(dir_names) == set(), f"Failed in library '{library}', some __all__ names not in dir() for {mod_name}"
+ assert set(dir_names) - set(all_names) == set(), f"Some dir() names not included in __all__ for {mod_name}"
+ assert set(all_names) - set(dir_names) == set(), f"Some __all__ names not in dir() for {mod_name}"
diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py
index 4b494ec3..cd25a931 100644
--- a/tests/test_array_namespace.py
+++ b/tests/test_array_namespace.py
@@ -5,7 +5,7 @@
# import jax
import numpy as np
import pytest
-# import torch
+import torch
import paddle
import array_api_compat
@@ -73,11 +73,11 @@ def test_array_namespace(library, api_version, use_compat):
"""
subprocess.run([sys.executable, "-c", code], check=True)
-# def test_jax_zero_gradient():
-# jx = jax.numpy.arange(4)
-# jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx)
-# assert (array_api_compat.get_namespace(jax_zero) is
-# array_api_compat.get_namespace(jx))
+def test_jax_zero_gradient():
+ jx = jax.numpy.arange(4)
+ jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx)
+ assert (array_api_compat.get_namespace(jax_zero) is
+ array_api_compat.get_namespace(jx))
def test_array_namespace_errors():
pytest.raises(TypeError, lambda: array_namespace([1]))
@@ -87,10 +87,10 @@ def test_array_namespace_errors():
pytest.raises(TypeError, lambda: array_namespace((x, x)))
pytest.raises(TypeError, lambda: array_namespace(x, (x, x)))
-# def test_array_namespace_errors_torch():
-# y = torch.asarray([1, 2])
-# x = np.asarray([1, 2])
-# pytest.raises(TypeError, lambda: array_namespace(x, y))
+def test_array_namespace_errors_torch():
+ y = torch.asarray([1, 2])
+ x = np.asarray([1, 2])
+ pytest.raises(TypeError, lambda: array_namespace(x, y))
def test_array_namespace_errors_paddle():
@@ -98,42 +98,21 @@ def test_array_namespace_errors_paddle():
x = np.asarray([1, 2])
pytest.raises(TypeError, lambda: array_namespace(x, y))
-
-# def test_api_version():
-# x = torch.asarray([1, 2])
-# torch_ = import_("torch", wrapper=True)
-# assert array_namespace(x, api_version="2023.12") == torch_
-# assert array_namespace(x, api_version=None) == torch_
-# assert array_namespace(x) == torch_
-# # Should issue a warning
-# with warnings.catch_warnings(record=True) as w:
-# assert array_namespace(x, api_version="2021.12") == torch_
-# assert len(w) == 1
-# assert "2021.12" in str(w[0].message)
-
-# # Should issue a warning
-# with warnings.catch_warnings(record=True) as w:
-# assert array_namespace(x, api_version="2022.12") == torch_
-# assert len(w) == 1
-# assert "2022.12" in str(w[0].message)
-
-# pytest.raises(ValueError, lambda: array_namespace(x, api_version="2020.12"))
-
def test_api_version():
- x = paddle.asarray([1, 2])
- paddle_ = import_("paddle", wrapper=True)
- assert array_namespace(x, api_version="2023.12") == paddle_
- assert array_namespace(x, api_version=None) == paddle_
- assert array_namespace(x) == paddle_
+ x = torch.asarray([1, 2])
+ torch_ = import_("torch", wrapper=True)
+ assert array_namespace(x, api_version="2023.12") == torch_
+ assert array_namespace(x, api_version=None) == torch_
+ assert array_namespace(x) == torch_
# Should issue a warning
with warnings.catch_warnings(record=True) as w:
- assert array_namespace(x, api_version="2021.12") == paddle_
+ assert array_namespace(x, api_version="2021.12") == torch_
assert len(w) == 1
assert "2021.12" in str(w[0].message)
# Should issue a warning
with warnings.catch_warnings(record=True) as w:
- assert array_namespace(x, api_version="2022.12") == paddle_
+ assert array_namespace(x, api_version="2022.12") == torch_
assert len(w) == 1
assert "2022.12" in str(w[0].message)
diff --git a/tests/test_common.py b/tests/test_common.py
index a46a2be2..23ac53d1 100644
--- a/tests/test_common.py
+++ b/tests/test_common.py
@@ -16,21 +16,21 @@
is_array_functions = {
'numpy': 'is_numpy_array',
- # 'cupy': 'is_cupy_array',
+ 'cupy': 'is_cupy_array',
'torch': 'is_torch_array',
- # 'dask.array': 'is_dask_array',
- # 'jax.numpy': 'is_jax_array',
- # 'sparse': 'is_pydata_sparse_array',
+ 'dask.array': 'is_dask_array',
+ 'jax.numpy': 'is_jax_array',
+ 'sparse': 'is_pydata_sparse_array',
'paddle': 'is_paddle_array',
}
is_namespace_functions = {
'numpy': 'is_numpy_namespace',
- # 'cupy': 'is_cupy_namespace',
+ 'cupy': 'is_cupy_namespace',
'torch': 'is_torch_namespace',
- # 'dask.array': 'is_dask_namespace',
- # 'jax.numpy': 'is_jax_namespace',
- # 'sparse': 'is_pydata_sparse_namespace',
+ 'dask.array': 'is_dask_namespace',
+ 'jax.numpy': 'is_jax_namespace',
+ 'sparse': 'is_pydata_sparse_namespace',
'paddle': 'is_paddle_namespace',
}
diff --git a/tests/test_no_dependencies.py b/tests/test_no_dependencies.py
index 201f98ea..11a516ac 100644
--- a/tests/test_no_dependencies.py
+++ b/tests/test_no_dependencies.py
@@ -51,8 +51,8 @@ def _test_dependency(mod):
@pytest.mark.parametrize("library",
[
- "numpy",
- "paddle", "array_api_strict",
+ "numpy", "cupy", "numpy", "torch", "dask.array",
+ "jax.numpy", "sparse", "paddle", "array_api_strict"
]
)
def test_numpy_dependency(library):
diff --git a/tests/test_vendoring.py b/tests/test_vendoring.py
index 91ca1709..3c9b5d92 100644
--- a/tests/test_vendoring.py
+++ b/tests/test_vendoring.py
@@ -7,23 +7,23 @@ def test_vendoring_numpy():
uses_numpy._test_numpy()
-# def test_vendoring_cupy():
-# pytest.importorskip("cupy")
+def test_vendoring_cupy():
+ pytest.importorskip("cupy")
-# from vendor_test import uses_cupy
+ from vendor_test import uses_cupy
-# uses_cupy._test_cupy()
+ uses_cupy._test_cupy()
-# def test_vendoring_torch():
-# from vendor_test import uses_torch
+def test_vendoring_torch():
+ from vendor_test import uses_torch
-# uses_torch._test_torch()
+ uses_torch._test_torch()
-# def test_vendoring_dask():
-# from vendor_test import uses_dask
-# uses_dask._test_dask()
+def test_vendoring_dask():
+ from vendor_test import uses_dask
+ uses_dask._test_dask()
def test_vendoring_paddle():
From bb40851d5b060886b08d32c2122df1186539754b Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Tue, 26 Nov 2024 14:20:23 +0800
Subject: [PATCH 06/19] update docstring
---
array_api_compat/paddle/_info.py | 26 +++++++++++++-------------
1 file changed, 13 insertions(+), 13 deletions(-)
diff --git a/array_api_compat/paddle/_info.py b/array_api_compat/paddle/_info.py
index 1fe48356..97e78960 100644
--- a/array_api_compat/paddle/_info.py
+++ b/array_api_compat/paddle/_info.py
@@ -15,7 +15,7 @@
class __array_namespace_info__:
"""
- Get the array API inspection namespace for PyTorch.
+ Get the array API inspection namespace for Paddle.
The array API inspection namespace defines the following functions:
@@ -32,7 +32,7 @@ class __array_namespace_info__:
Returns
-------
info : ModuleType
- The array API inspection namespace for PyTorch.
+ The array API inspection namespace for Paddle.
Examples
--------
@@ -54,11 +54,11 @@ def capabilities(self):
The resulting dictionary has the following keys:
- **"boolean indexing"**: boolean indicating whether an array library
- supports boolean indexing. Always ``True`` for PyTorch.
+ supports boolean indexing. Always ``True`` for Paddle.
- **"data-dependent shapes"**: boolean indicating whether an array
library supports data-dependent output shapes. Always ``True`` for
- PyTorch.
+ Paddle.
See
https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html
@@ -93,7 +93,7 @@ def capabilities(self):
def default_device(self):
"""
- The default device used for new PyTorch arrays.
+ The default device used for new Paddle arrays.
See Also
--------
@@ -105,7 +105,7 @@ def default_device(self):
Returns
-------
device : str
- The default device used for new PyTorch arrays.
+ The default device used for new Paddle arrays.
Examples
--------
@@ -118,18 +118,18 @@ def default_device(self):
def default_dtypes(self, *, device=None):
"""
- The default data types used for new PyTorch arrays.
+ The default data types used for new Paddle arrays.
Parameters
----------
device : str, optional
- The device to get the default data types for. For PyTorch, only
+ The device to get the default data types for. For Paddle, only
``'cpu'`` is allowed.
Returns
-------
dtypes : dict
- A dictionary describing the default data types used for new PyTorch
+ A dictionary describing the default data types used for new Paddle
arrays.
See Also
@@ -244,7 +244,7 @@ def _dtypes(self, kind):
@cache
def dtypes(self, *, device=None, kind=None):
"""
- The array API data types supported by PyTorch.
+ The array API data types supported by Paddle.
Note that this function only returns data types that are defined by
the array API.
@@ -277,7 +277,7 @@ def dtypes(self, *, device=None, kind=None):
-------
dtypes : dict
A dictionary mapping the names of data types to the corresponding
- PyTorch data types.
+ Paddle data types.
See Also
--------
@@ -307,12 +307,12 @@ def dtypes(self, *, device=None, kind=None):
@cache
def devices(self):
"""
- The devices supported by PyTorch.
+ The devices supported by Paddle.
Returns
-------
devices : list of str
- The devices supported by PyTorch.
+ The devices supported by Paddle.
See Also
--------
From a7163f903796684d6f25cb420d72aaadb416433c Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Tue, 26 Nov 2024 14:35:33 +0800
Subject: [PATCH 07/19] refine more code
---
array_api_compat/paddle/_info.py | 3 +--
array_api_compat/paddle/linalg.py | 5 ++---
tests/test_array_namespace.py | 2 +-
3 files changed, 4 insertions(+), 6 deletions(-)
diff --git a/array_api_compat/paddle/_info.py b/array_api_compat/paddle/_info.py
index 97e78960..d8dab7ee 100644
--- a/array_api_compat/paddle/_info.py
+++ b/array_api_compat/paddle/_info.py
@@ -170,8 +170,7 @@ def _dtypes(self, kind):
int32 = paddle.int32
int64 = paddle.int64
uint8 = paddle.uint8
- # uint16, uint32, and uint64 are present in newer versions of pytorch,
- # but they aren't generally supported by the array API functions, so
+    # uint16, uint32, and uint64 are not fully supported in paddle, so
# we omit them from this function.
float32 = paddle.float32
float64 = paddle.float64
diff --git a/array_api_compat/paddle/linalg.py b/array_api_compat/paddle/linalg.py
index 6ee57fcf..7ef04a90 100644
--- a/array_api_compat/paddle/linalg.py
+++ b/array_api_compat/paddle/linalg.py
@@ -28,11 +28,10 @@
from ._aliases import matmul, matrix_transpose, tensordot
# Note: paddle.linalg.cross does not default to axis=-1 (it defaults to the
-# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
-
+# first axis with size 3)
# paddle.cross also does not support broadcasting when it would add new
-# dimensions https://github.com/pytorch/pytorch/issues/39656
+# dimensions
def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)):
diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py
index cd25a931..e9e7458f 100644
--- a/tests/test_array_namespace.py
+++ b/tests/test_array_namespace.py
@@ -2,7 +2,7 @@
import sys
import warnings
-# import jax
+import jax
import numpy as np
import pytest
import torch
From ec461786832538cadb1af01bf41fa023621252fb Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Tue, 26 Nov 2024 22:26:24 +0800
Subject: [PATCH 08/19] add suffix for test_python_scalars and add paddle
 index-url in requirements
---
requirements-dev.txt | 2 +-
tests/test_array_namespace.py | 4 ++--
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/requirements-dev.txt b/requirements-dev.txt
index ae41a25e..7ad022d7 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -4,6 +4,6 @@ jax[cpu]
numpy
pytest
torch
-paddlepaddle
+paddlepaddle -i https://www.paddlepaddle.org.cn/packages/nightly/cpu/
sparse >=0.15.1
ndonnx
diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py
index e9e7458f..4076c74c 100644
--- a/tests/test_array_namespace.py
+++ b/tests/test_array_namespace.py
@@ -122,7 +122,7 @@ def test_get_namespace():
# Backwards compatible wrapper
assert array_api_compat.get_namespace is array_api_compat.array_namespace
-def test_python_scalars():
+def test_python_scalars_torch():
a = torch.asarray([1, 2])
xp = import_("torch", wrapper=True)
@@ -138,7 +138,7 @@ def test_python_scalars():
assert array_namespace(a, True) == xp
assert array_namespace(a, None) == xp
-def test_python_scalars():
+def test_python_scalars_paddle():
a = paddle.to_tensor([1, 2])
xp = import_("paddle", wrapper=True)
From dfd448518b35867ea4ee99a474a51054ac02f2f0 Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Tue, 3 Dec 2024 14:53:05 +0800
Subject: [PATCH 09/19] update paddle code
---
array_api_compat/paddle/__init__.py | 12 +-
array_api_compat/paddle/_aliases.py | 388 ++++++++++++++++++++--------
array_api_compat/paddle/_info.py | 14 +-
array_api_compat/paddle/fft.py | 31 ++-
array_api_compat/paddle/linalg.py | 48 +++-
5 files changed, 349 insertions(+), 144 deletions(-)
diff --git a/array_api_compat/paddle/__init__.py b/array_api_compat/paddle/__init__.py
index 9f96fa9f..1016312d 100644
--- a/array_api_compat/paddle/__init__.py
+++ b/array_api_compat/paddle/__init__.py
@@ -4,16 +4,10 @@
import paddle
for n in dir(paddle):
- if (
- n.startswith("_")
- or n.endswith("_")
- or "gpu" in n
- or "cpu" in n
- or "backward" in n
- ):
+ if n.startswith("_") or n.endswith("_") or "gpu" in n or "cpu" in n or "backward" in n:
continue
- exec(n + " = paddle." + n)
- exec("asarray = paddle.to_tensor")
+ exec(f"{n} = paddle.{n}")
+
# These imports may overwrite names from the import * above.
from ._aliases import * # noqa: F403
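
For reference, the `exec`-based re-export loop above could equivalently be written without `exec` by assigning into the module globals; a hedged alternative sketch, not what the patch itself implements:

```py
import paddle

_EXCLUDED = ("gpu", "cpu", "backward")

for _name in dir(paddle):
    # Same filter as above: skip private names and device/autograd helpers.
    if _name.startswith("_") or _name.endswith("_") or any(k in _name for k in _EXCLUDED):
        continue
    globals()[_name] = getattr(paddle, _name)
```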
diff --git a/array_api_compat/paddle/_aliases.py b/array_api_compat/paddle/_aliases.py
index 14d3de7f..601afa5f 100644
--- a/array_api_compat/paddle/_aliases.py
+++ b/array_api_compat/paddle/_aliases.py
@@ -1,14 +1,17 @@
from __future__ import annotations
+from typing import Literal
+import numpy as np
+
from functools import wraps as _wraps
from builtins import all as _builtin_all, any as _builtin_any
from ..common._aliases import (
- matrix_transpose as _aliases_matrix_transpose,
- vecdot as _aliases_vecdot,
- clip as _aliases_clip,
unstack as _aliases_unstack,
- cumulative_sum as _aliases_cumulative_sum,
+)
+from ..common._typing import (
+ SupportsBufferProtocol,
+ NestedSequence,
)
from .._internal import get_xp
@@ -94,7 +97,7 @@ def _fix_promotion(x1, x2, only_scalar=True):
return x1, x2
if x1.dtype not in _array_api_dtypes or x2.dtype not in _array_api_dtypes:
return x1, x2
- # If an argument is 0-D pytorch downcasts the other argument
+ # If an argument is 0-D paddle downcasts the other argument
if not only_scalar or x1.shape == ():
dtype = result_type(x1, x2)
x2 = x2.to(dtype)
@@ -131,6 +134,12 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype:
def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
+ if paddle.is_tensor(from_):
+ from_ = from_.dtype
+
+ assert isinstance(from_, paddle.dtype), from_.dtype
+ assert isinstance(to, paddle.dtype), to.dtype
+
can_cast_dict = {
paddle.bfloat16: {
paddle.bfloat16: True,
@@ -341,9 +350,6 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
remainder = _two_arg(paddle.remainder)
subtract = _two_arg(paddle.subtract)
-# These wrappers are mostly based on the fact that pytorch uses 'dim' instead
-# of 'axis'.
-
def max(
x: array,
@@ -352,12 +358,21 @@ def max(
axis: Optional[Union[int, Tuple[int, ...]]] = None,
keepdims: bool = False,
) -> array:
- # https://github.com/pytorch/pytorch/issues/29137
if axis == ():
return paddle.clone(x)
return paddle.amax(x, axis, keepdim=keepdims)
+def argmax(
+ x: array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ keepdims: bool = False,
+) -> array:
+ return paddle.argmax(x, axis, keepdim=keepdims)
+
+
def min(
x: array,
/,
@@ -365,19 +380,25 @@ def min(
axis: Optional[Union[int, Tuple[int, ...]]] = None,
keepdims: bool = False,
) -> array:
- # https://github.com/pytorch/pytorch/issues/29137
if axis == ():
return paddle.clone(x)
return paddle.min(x, axis, keepdim=keepdims)
-clip = get_xp(paddle)(_aliases_clip)
+def argmin(
+ x: array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ keepdims: bool = False,
+) -> array:
+ return paddle.argmin(x, axis, keepdim=keepdims)
+
+
unstack = get_xp(paddle)(_aliases_unstack)
-cumulative_sum = get_xp(paddle)(_aliases_cumulative_sum)
# paddle.sort also returns a tuple
-# https://github.com/pytorch/pytorch/issues/70921
def sort(
x: array,
/,
@@ -387,9 +408,7 @@ def sort(
stable: bool = True,
**kwargs,
) -> array:
- return paddle.sort(
- x, axis=axis, descending=descending, stable=stable, **kwargs
- ).values
+ return paddle.sort(x, axis=axis, descending=descending, stable=stable, **kwargs)
def _normalize_axes(axis, ndim):
@@ -401,9 +420,7 @@ def _normalize_axes(axis, ndim):
for a in axis:
if a < lower or a > upper:
# Match paddle error message (e.g., from sum())
- raise IndexError(
- f"Dimension out of range (expected to be in range of [{lower}, {upper}], but got {a}"
- )
+ raise IndexError(f"Dimension out of range (expected to be in range of [{lower}, {upper}], but got {a}")
if a < 0:
a = a + ndim
if a in axes:
@@ -415,7 +432,6 @@ def _normalize_axes(axis, ndim):
def _axis_none_keepdims(x, ndim, keepdims):
# Apply keepdims when axis=None
- # (https://github.com/pytorch/pytorch/issues/71209)
# Note that this is only valid for the axis=None case.
if keepdims:
for i in range(ndim):
@@ -425,7 +441,6 @@ def _axis_none_keepdims(x, ndim, keepdims):
def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
# Some reductions don't support multiple axes
- # (https://github.com/pytorch/pytorch/issues/56586).
axes = _normalize_axes(axis, x.ndim)
for a in reversed(axes):
x = paddle.movedim(x, a, -1)
@@ -448,10 +463,10 @@ def prod(
keepdims: bool = False,
**kwargs,
) -> array:
- x = paddle.asarray(x)
+ if not paddle.is_tensor(x):
+ x = paddle.to_tensor(x)
ndim = x.ndim
-    # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic
-    # below because it still needs to upcast.
+    # Separate from the logic below because it still needs to upcast.
if axis == ():
if dtype is None:
@@ -464,14 +479,10 @@ def prod(
return x.to(dtype)
# paddle.prod doesn't support multiple axes
- # (https://github.com/pytorch/pytorch/issues/56586).
if isinstance(axis, tuple):
- return _reduce_multiple_axes(
- paddle.prod, x, axis, keepdim=keepdims, dtype=dtype, **kwargs
- )
+ return _reduce_multiple_axes(paddle.prod, x, axis, keepdim=keepdims, dtype=dtype, **kwargs)
if axis is None:
# paddle doesn't support keepdims with axis=None
- # (https://github.com/pytorch/pytorch/issues/71209)
res = paddle.prod(x, dtype=dtype, **kwargs)
res = _axis_none_keepdims(res, ndim, keepdims)
return res
@@ -488,10 +499,10 @@ def sum(
keepdims: bool = False,
**kwargs,
) -> array:
- x = paddle.asarray(x)
+ if not paddle.is_tensor(x):
+ x = paddle.to_tensor(x)
ndim = x.ndim
- # https://github.com/pytorch/pytorch/issues/29137.
# Make sure it upcasts.
if axis == ():
if dtype is None:
@@ -505,7 +516,6 @@ def sum(
if axis is None:
# paddle doesn't support keepdims with axis=None
- # (https://github.com/pytorch/pytorch/issues/71209)
res = paddle.sum(x, dtype=dtype, **kwargs)
res = _axis_none_keepdims(res, ndim, keepdims)
return res
@@ -521,18 +531,17 @@ def any(
keepdims: bool = False,
**kwargs,
) -> array:
- x = paddle.asarray(x)
+ if not paddle.is_tensor(x):
+ x = paddle.to_tensor(x)
ndim = x.ndim
if axis == ():
return x.to(paddle.bool)
# paddle.any doesn't support multiple axes
- # (https://github.com/pytorch/pytorch/issues/56586).
if isinstance(axis, tuple):
res = _reduce_multiple_axes(paddle.any, x, axis, keepdim=keepdims, **kwargs)
return res.to(paddle.bool)
if axis is None:
# paddle doesn't support keepdims with axis=None
- # (https://github.com/pytorch/pytorch/issues/71209)
res = paddle.any(x, **kwargs)
res = _axis_none_keepdims(res, ndim, keepdims)
return res.to(paddle.bool)
@@ -549,18 +558,17 @@ def all(
keepdims: bool = False,
**kwargs,
) -> array:
- x = paddle.asarray(x)
+ if not paddle.is_tensor(x):
+ x = paddle.to_tensor(x)
ndim = x.ndim
if axis == ():
return x.to(paddle.bool)
# paddle.all doesn't support multiple axes
- # (https://github.com/pytorch/pytorch/issues/56586).
if isinstance(axis, tuple):
res = _reduce_multiple_axes(paddle.all, x, axis, keepdim=keepdims, **kwargs)
return res.to(paddle.bool)
if axis is None:
# paddle doesn't support keepdims with axis=None
- # (https://github.com/pytorch/pytorch/issues/71209)
res = paddle.all(x, **kwargs)
res = _axis_none_keepdims(res, ndim, keepdims)
return res.to(paddle.bool)
@@ -577,12 +585,10 @@ def mean(
keepdims: bool = False,
**kwargs,
) -> array:
- # https://github.com/pytorch/pytorch/issues/29137
if axis == ():
return paddle.clone(x)
if axis is None:
# paddle doesn't support keepdims with axis=None
- # (https://github.com/pytorch/pytorch/issues/71209)
res = paddle.mean(x, **kwargs)
res = _axis_none_keepdims(res, x.ndim, keepdims)
return res
@@ -599,15 +605,12 @@ def std(
**kwargs,
) -> array:
# Note, float correction is not supported
-    # https://github.com/pytorch/pytorch/issues/61492. We don't try to
-    # implement it here for now.
+    # here; we don't try to implement it for now.
if isinstance(correction, float):
_correction = int(correction)
if correction != _correction:
- raise NotImplementedError(
- "float correction in paddle std() is not yet supported"
- )
+ raise NotImplementedError("float correction in paddle std() is not yet supported")
elif isinstance(correction, int):
if correction not in [0, 1]:
raise NotImplementedError("correction only can be 0 or 1")
@@ -616,14 +619,12 @@ def std(
_correction = bool(_correction)
- # https://github.com/pytorch/pytorch/issues/29137
if axis == ():
return paddle.zeros_like(x)
if isinstance(axis, int):
axis = (axis,)
if axis is None:
# paddle doesn't support keepdims with axis=None
- # (https://github.com/pytorch/pytorch/issues/71209)
res = paddle.std(x, tuple(range(x.ndim)), unbiased=_correction, **kwargs)
res = _axis_none_keepdims(res, x.ndim, keepdims)
return res
@@ -640,7 +641,6 @@ def var(
**kwargs,
) -> array:
# Note, float correction is not supported
-    # https://github.com/pytorch/pytorch/issues/61492. We don't try to
-    # implement it here for now.
+    # here; we don't try to implement it for now.
# if isinstance(correction, float):
@@ -648,9 +648,7 @@ def var(
if isinstance(correction, float):
_correction = int(correction)
if correction != _correction:
- raise NotImplementedError(
- "float correction in paddle std() is not yet supported"
- )
+ raise NotImplementedError("float correction in paddle std() is not yet supported")
elif isinstance(correction, int):
if correction not in [0, 1]:
raise NotImplementedError("correction only can be 0 or 1")
@@ -659,14 +657,12 @@ def var(
_correction = bool(_correction)
- # https://github.com/pytorch/pytorch/issues/29137
if axis == ():
return paddle.zeros_like(x)
if isinstance(axis, int):
axis = (axis,)
if axis is None:
# paddle doesn't support keepdims with axis=None
- # (https://github.com/pytorch/pytorch/issues/71209)
res = paddle.var(x, tuple(range(x.ndim)), unbiased=_correction, **kwargs)
res = _axis_none_keepdims(res, x.ndim, keepdims)
return res
@@ -674,7 +670,6 @@ def var(
# paddle.concat doesn't support dim=None
-# https://github.com/pytorch/pytorch/issues/70925
def concat(
arrays: Union[Tuple[array, ...], List[array]],
/,
@@ -688,9 +683,6 @@ def concat(
return paddle.concat(arrays, axis, **kwargs)
-# paddle.squeeze only accepts int dim and doesn't require it
-# https://github.com/pytorch/pytorch/issues/70924. Support for tuple dim was
-# added at https://github.com/pytorch/pytorch/pull/89017.
def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array:
if isinstance(axis, int):
axis = (axis,)
@@ -698,7 +690,7 @@ def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array:
if x.shape[a] != 1:
raise ValueError("squeezed dimensions must be equal to 1")
axes = _normalize_axes(axis, x.ndim)
- # Remove this once pytorch 1.14 is released with the above PR #89017.
+
sequence = [a - i for i, a in enumerate(axes)]
for a in sequence:
x = paddle.squeeze(x, a)
@@ -712,23 +704,15 @@ def broadcast_to(x: array, /, shape: Tuple[int, ...], **kwargs) -> array:
# paddle.permute uses dims instead of axes
def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array:
- if len(axes) == 2:
- perm = list(range(x.ndim))
- perm[axes[0]], perm[axes[1]] = perm[axes[1]], perm[axes[0]]
- axes = perm
return paddle.transpose(x, axes)
# The axis parameter doesn't work for flip() and roll()
-# https://github.com/pytorch/pytorch/issues/71210. Also paddle.flip() doesn't
-# accept axis=None
+# paddle.flip() doesn't accept axis=None
-def flip(
- x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs
-) -> array:
+def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array:
if axis is None:
axis = tuple(range(x.ndim))
# paddle.flip doesn't accept dim as an int but the method does
- # https://github.com/pytorch/pytorch/issues/18095
return x.flip(axis, **kwargs)
@@ -754,19 +738,48 @@ def where(condition: array, x1: array, x2: array, /) -> array:
return paddle.where(condition, x1, x2)
-# paddle.reshape doesn't have the copy keyword
-def reshape(
- x: array, /, shape: Tuple[int, ...], copy: Optional[bool] = None, **kwargs
+def empty_like(x: array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> array:
+ out = paddle.empty_like(x, dtype=dtype)
+ if device is not None:
+ out = out.to(device)
+ return out
+
+
+def zeros_like(x: array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> array:
+ out = paddle.zeros_like(x, dtype=dtype)
+ if device is not None:
+ out = out.to(device)
+ return out
+
+
+def ones_like(x: array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> array:
+ out = paddle.ones_like(x, dtype=dtype)
+ if device is not None:
+ out = out.to(device)
+ return out
+
+
+def full_like(
+ x: array,
+ /,
+ fill_value: bool | int | float | complex,
+ *,
+ dtype: Optional[Dtype] = None,
+ device: Optional[Device] = None,
) -> array:
- if copy is not None:
- raise NotImplementedError("paddle.reshape doesn't yet support the copy keyword")
+ out = paddle.full_like(x, fill_value, dtype=dtype)
+ if device is not None:
+ out = out.to(device)
+ return out
+
+
+# paddle.reshape doesn't have the copy keyword
+def reshape(x: array, /, shape: Tuple[int, ...], copy: Optional[bool] = None, **kwargs) -> array:
return paddle.reshape(x, shape, **kwargs)
# paddle.arange doesn't support returning empty arrays
-# (https://github.com/pytorch/pytorch/issues/70915), and doesn't support some
-# keyword argument combinations
-# (https://github.com/pytorch/pytorch/issues/70914)
+# and doesn't support some keyword argument combinations
def arange(
start: Union[int, float],
/,
@@ -790,7 +803,6 @@ def arange(
# paddle.eye does not accept None as a default for the second argument and
-# doesn't support off-diagonals (https://github.com/pytorch/pytorch/issues/70910)
+# doesn't support off-diagonals
def eye(
n_rows: int,
n_cols: Optional[int] = None,
@@ -822,14 +834,11 @@ def linspace(
**kwargs,
) -> array:
if not endpoint:
- return paddle.linspace(start, stop, num + 1, dtype=dtype, **kwargs).to(device)[
- :-1
- ]
+ return paddle.linspace(start, stop, num + 1, dtype=dtype, **kwargs).to(device)[:-1]
return paddle.linspace(start, stop, num, dtype=dtype, **kwargs).to(device)
# paddle.full does not accept an int size
-# https://github.com/pytorch/pytorch/issues/70906
def full(
shape: Union[int, Tuple[int, ...]],
fill_value: Union[bool, int, float, complex],
@@ -886,17 +895,21 @@ def triu(x: array, /, *, k: int = 0) -> array:
return paddle.triu(x, k)
-# Functions that aren't in paddle https://github.com/pytorch/pytorch/issues/58742
def expand_dims(x: array, /, *, axis: int = 0) -> array:
return paddle.unsqueeze(x, axis)
-def astype(x: array, dtype: Dtype, /, *, copy: bool = True) -> array:
- return x.to(dtype, copy=copy)
+def astype(x: array, dtype: Dtype, /, *, copy: bool = True, device: Optional[Device] = None) -> array:
+ # if copy is not None:
+ # raise NotImplementedError("paddle.astype doesn't yet support the copy keyword")
+ t = x.to(dtype, device=device)
+ if copy:
+ t = t.detach().clone()
+ return t
def broadcast_arrays(*arrays: array) -> List[array]:
- shape = paddle.broadcast_shapes(*[a.shape for a in arrays])
+ shape = broadcast_shapes(*[a.shape for a in arrays])
return [paddle.broadcast_to(a, shape) for a in arrays]
@@ -905,28 +918,19 @@ def broadcast_arrays(*arrays: array) -> List[array]:
from ..common._aliases import UniqueAllResult, UniqueCountsResult, UniqueInverseResult
-# https://github.com/pytorch/pytorch/issues/70920
def unique_all(x: array) -> UniqueAllResult:
- # paddle.unique doesn't support returning indices.
- # https://github.com/pytorch/pytorch/issues/36748. The workaround
- # suggested in that issue doesn't actually function correctly (it relies
- # on non-deterministic behavior of scatter()).
- raise NotImplementedError(
- "unique_all() not yet implemented for paddle (see https://github.com/pytorch/pytorch/issues/36748)"
+    values, indices, inverse_indices, counts = paddle.unique(
+        x,
+        return_index=True,
+        return_inverse=True,
+        return_counts=True,
     )
+    return UniqueAllResult(values, indices, inverse_indices, counts)
- # values, inverse_indices, counts = paddle.unique(x, return_counts=True, return_inverse=True)
- # # paddle.unique incorrectly gives a 0 count for nan values.
- # # https://github.com/pytorch/pytorch/issues/94106
- # counts[paddle.isnan(values)] = 1
- # return UniqueAllResult(values, indices, inverse_indices, counts)
-
def unique_counts(x: array) -> UniqueCountsResult:
values, counts = paddle.unique(x, return_counts=True)
# paddle.unique incorrectly gives a 0 count for nan values.
- # https://github.com/pytorch/pytorch/issues/94106
counts[paddle.isnan(values)] = 1
return UniqueCountsResult(values, counts)
@@ -946,13 +950,19 @@ def matmul(x1: array, x2: array, /, **kwargs) -> array:
return paddle.matmul(x1, x2, **kwargs)
-matrix_transpose = get_xp(paddle)(_aliases_matrix_transpose)
-_vecdot = get_xp(paddle)(_aliases_vecdot)
+def meshgrid(*arrays: array, indexing: str = "xy") -> List[array]:
+    # paddle.meshgrid only implements "ij" indexing; for "xy" swap the first
+    # two axes of each output, matching numpy.meshgrid semantics.
+    out = list(paddle.meshgrid(*arrays))
+    if indexing == "xy" and len(arrays) >= 2:
+        out = [paddle.moveaxis(g, 0, 1) for g in out]
+    return out
+
+
+matrix_transpose = paddle.linalg.matrix_transpose
def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array:
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
- return _vecdot(x1, x2, axis=axis)
+ return paddle.linalg.vecdot(x1, x2, axis=axis)
# paddle.tensordot uses dims instead of axes
@@ -965,7 +975,6 @@ def tensordot(
**kwargs,
) -> array:
# Note: paddle.tensordot fails with integer dtypes when there is only 1
-    # element in the axis (https://github.com/pytorch/pytorch/issues/84530).
+    # element in the axis.
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
return paddle.tensordot(x1, x2, axes=axes, **kwargs)
@@ -990,16 +999,6 @@ def isdtype(
def is_signed(dtype):
return dtype in [paddle.int8, paddle.int16, paddle.int32, paddle.int64]
- def is_floating_point(dtype):
- return dtype in [
- paddle.float32,
- paddle.float64,
- paddle.float16,
- paddle.bfloat16,
- paddle.float8_e4m3fn,
- paddle.float8_e5m2,
- ]
-
def is_complex(dtype):
return dtype in [paddle.complex64, paddle.complex128]
@@ -1016,7 +1015,7 @@ def is_complex(dtype):
elif kind == "integral":
return dtype in _int_dtypes
elif kind == "real floating":
- return is_floating_point(dtype)
+ return paddle.is_floating_point(dtype)
elif kind == "complex floating":
return is_complex(dtype)
elif kind == "numeric":
@@ -1038,18 +1037,172 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
def sign(x: array, /) -> array:
# paddle sign() does not support complex numbers and does not propagate
# nans. See https://github.com/data-apis/array-api-compat/issues/136
- if x.dtype.is_complex:
+ if paddle.is_complex(x):
out = x / paddle.abs(x)
# sign(0) = 0 but the above formula would give nan
out[x == 0 + 0j] = 0 + 0j
return out
else:
out = paddle.sign(x)
- if x.dtype.is_floating_point:
- out[paddle.isnan(x)] = paddle.nan
+ if paddle.is_floating_point(x):
+ out = paddle.where(paddle.isnan(x), paddle.nan, out)
return out
+def broadcast_shapes(*shapes: List[int]) -> List[int]:
+ out_shape = shapes[0]
+ for i, shape in enumerate(shapes):
+ if i == 0:
+ continue
+ out_shape = paddle.broadcast_shape(out_shape, shape)
+
+ return out_shape
+
+
+# asarray also adds the copy keyword, which is not present in numpy 1.0.
+def asarray(
+ obj: Union[
+ array,
+ bool,
+ int,
+ float,
+ NestedSequence[bool | int | float],
+ SupportsBufferProtocol,
+ ],
+ /,
+ *,
+ dtype: Optional[Dtype] = None,
+ device: Optional[Device] = None,
+ copy: Optional[bool] = None,
+ **kwargs,
+) -> array:
+ """
+ Array API compatibility wrapper for asarray().
+
+ See the corresponding documentation in the array library and/or the array API
+ specification for more details.
+ """
+ if copy is False:
+ if hasattr(obj, "__dlpack__"):
+ obj = paddle.from_dlpack(obj.__dlpack__())
+ if device is not None:
+ obj = obj.to(device)
+ if dtype is not None:
+ obj = obj.to(dtype)
+ return obj
+ else:
+ raise NotImplementedError(
+ "asarray(obj, ..., copy=False) is not supported " "for obj do not has '__dlpack__()' method"
+ )
+ elif copy is True:
+ obj = np.array(obj, copy=True)
+ return paddle.to_tensor(obj, dtype=dtype, place=device)
+ else:
+ if not paddle.is_tensor(obj) or (dtype is not None and obj.dtype != dtype):
+ obj = np.array(obj, copy=False)
+ obj = paddle.from_dlpack(obj.__dlpack__(), **kwargs).to(dtype)
+ if device is not None:
+ obj = obj.to(device)
+ return obj
+
+ return obj
+
+
+def clip(
+ x: array,
+ /,
+ min: Optional[Union[int, float, array]] = None,
+ max: Optional[Union[int, float, array]] = None,
+) -> array:
+ if min is None and max is None:
+ return x
+
+ def _isscalar(a):
+ return isinstance(a, (int, float, type(None)))
+
+ min_shape = [] if _isscalar(min) else min.shape
+ max_shape = [] if _isscalar(max) else max.shape
+
+ result_shape = broadcast_shapes(x.shape, min_shape, max_shape)
+
+ # np.clip does type promotion but the array API clip requires that the
+ # output have the same dtype as x. We do this instead of just downcasting
+ # the result of xp.clip() to handle some corner cases better (e.g.,
+ # avoiding uint64 -> float64 promotion).
+
+ # Note: cases where min or max overflow (integer) or round (float) in the
+ # wrong direction when downcasting to x.dtype are unspecified. This code
+ # just does whatever NumPy does when it downcasts in the assignment, but
+ # other behavior could be preferred, especially for integers. For example,
+ # this code produces:
+
+ # >>> clip(asarray(0, dtype=int8), asarray(128, dtype=int16), None)
+ # -128
+
+ # but an answer of 0 might be preferred. See
+ # https://github.com/numpy/numpy/issues/24976 for more discussion on this issue.
+
+ # At least handle the case of Python integers correctly (see
+ # https://github.com/numpy/numpy/pull/26892).
+ if type(min) is int and min <= paddle.iinfo(x.dtype).min:
+ min = None
+ if type(max) is int and max >= paddle.iinfo(x.dtype).max:
+ max = None
+
+ if out is None:
+ out = paddle.to_tensor(broadcast_to(x, result_shape), place=x.place)
+ if min is not None:
+ if paddle.is_tensor(x) and x.dtype == paddle.float64 and _isscalar(min):
+ # Avoid loss of precision due to paddle defaulting to float32
+ min = paddle.to_tensor(min, dtype=paddle.float64)
+ a = broadcast_to(paddle.to_tensor(min, place=x.place), result_shape)
+ ia = (out < a) | paddle.isnan(a)
+ # paddle requires an explicit cast here
+ out[ia] = astype(a[ia], out.dtype)
+ if max is not None:
+ if paddle.is_tensor(x) and x.dtype == paddle.float64 and _isscalar(max):
+ max = paddle.to_tensor(max, dtype=paddle.float64)
+ b = broadcast_to(paddle.to_tensor(max, place=x.place), result_shape)
+ ib = (out > b) | paddle.isnan(b)
+ out[ib] = astype(b[ib], out.dtype)
+ # Return a scalar for 0-D
+ return out[()]
+
+
+def cumulative_sum(
+ x: array, /, *, axis: Optional[int] = None, dtype: Optional[Dtype] = None, include_initial: bool = False
+) -> array:
+ if axis is None:
+ if x.ndim > 1:
+ raise ValueError("axis must be specified in cumulative_sum for more than one dimension")
+ axis = 0
+
+ res = paddle.cumsum(x, axis=axis, dtype=dtype)
+
+    # paddle.cumsum does not support include_initial
+ if include_initial:
+ initial_shape = list(x.shape)
+ initial_shape[axis] = 1
+ res = paddle.concat(
+ [paddle.zeros(shape=initial_shape, dtype=res.dtype).to(res.place), res],
+ axis=axis,
+ )
+ return res
+
+
+def searchsorted(
+ x1: array, x2: array, /, *, side: Literal["left", "right"] = "left", sorter: array | None = None
+) -> array:
+ if sorter is None:
+ return paddle.searchsorted(x1, x2, right=(side == "right"))
+
+ return paddle.searchsorted(
+ x1.take_along_axis(axis=-1, indices=sorter),
+ x2,
+ right=(side == "right"),
+ )
+
+
__all__ = [
"__array_namespace_info__",
"result_type",
@@ -1129,6 +1282,15 @@ def sign(x: array, /) -> array:
"isdtype",
"take",
"sign",
+ "broadcast_shapes",
+ "argmax",
+ "argmin",
+ "searchsorted",
+ "empty_like",
+ "zeros_like",
+ "ones_like",
+ "full_like",
+ "asarray",
]
_all_ignore = ["paddle", "get_xp"]
diff --git a/array_api_compat/paddle/_info.py b/array_api_compat/paddle/_info.py
index d8dab7ee..5d29e270 100644
--- a/array_api_compat/paddle/_info.py
+++ b/array_api_compat/paddle/_info.py
@@ -332,18 +332,12 @@ def devices(self):
# message of paddle.device to get the list of all possible types of
# device:
try:
- paddle.device("notadevice")
- except RuntimeError as e:
+ paddle.set_device("notadevice")
+ except ValueError as e:
# The error message is something like:
# ValueError: The device must be a string which is like 'cpu', 'gpu', 'gpu:x', 'xpu', 'xpu:x', 'npu', 'npu:x
- devices_names = (
- e.args[0]
- .split("ValueError: The device must be a string which is like ")[1]
- .split(", ")
- )
- devices_names = [
- name.strip("'") for name in devices_names if ":" not in name
- ]
+ devices_names = e.args[0].split("The device must be a string which is like ")[1].split(", ")
+ devices_names = [name.strip("'") for name in devices_names if ":" not in name]
# Next we need to check for different indices for different devices.
# device(device_name, index=index) doesn't actually check if the
diff --git a/array_api_compat/paddle/fft.py b/array_api_compat/paddle/fft.py
index 15519b5a..1442aed8 100644
--- a/array_api_compat/paddle/fft.py
+++ b/array_api_compat/paddle/fft.py
@@ -4,9 +4,10 @@
if TYPE_CHECKING:
import paddle
+ from ..common._typing import Device
array = paddle.Tensor
- from typing import Union, Sequence, Literal
+ from typing import Optional, Union, Sequence, Literal
from paddle.fft import * # noqa: F403
import paddle.fft
@@ -80,6 +81,32 @@ def ifftshift(
return paddle.fft.ifftshift(x, axes=axes, **kwargs)
+def fftfreq(
+ n: int,
+ /,
+ *,
+ d: float = 1.0,
+ device: Optional[Device] = None,
+) -> array:
+ out = paddle.fft.fftfreq(n, d)
+ if device is not None:
+ out = out.to(device)
+ return out
+
+
+def rfftfreq(
+ n: int,
+ /,
+ *,
+ d: float = 1.0,
+ device: Optional[Device] = None,
+) -> array:
+ out = paddle.fft.rfftfreq(n, d)
+ if device is not None:
+ out = out.to(device)
+ return out
+
+
__all__ = paddle.fft.__all__ + [
"fftn",
"ifftn",
@@ -87,6 +114,8 @@ def ifftshift(
"irfftn",
"fftshift",
"ifftshift",
+ "fftfreq",
+ "rfftfreq",
]
_all_ignore = ["paddle"]
diff --git a/array_api_compat/paddle/linalg.py b/array_api_compat/paddle/linalg.py
index 7ef04a90..7dd1a266 100644
--- a/array_api_compat/paddle/linalg.py
+++ b/array_api_compat/paddle/linalg.py
@@ -12,7 +12,9 @@
inf = float("inf")
from ._aliases import _fix_promotion, sum
+from collections import namedtuple
+import paddle
from paddle.linalg import * # noqa: F403
# paddle.linalg doesn't define __all__
@@ -23,6 +25,7 @@
 # outer is implemented in paddle but isn't in the linalg namespace
from paddle import outer
+import paddle
# These functions are in both the main and linalg namespaces
from ._aliases import matmul, matrix_transpose, tensordot
@@ -30,21 +33,18 @@
# Note: paddle.linalg.cross does not default to axis=-1 (it defaults to the
# first axis with size 3)
+
# paddle.cross also does not support broadcasting when it would add new
# dimensions
def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)):
- raise ValueError(
- f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}"
- )
+ raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}")
if not (x1.shape[axis] == x2.shape[axis] == 3):
- raise ValueError(
- f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}"
- )
+ raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}")
- x1, x2 = paddle.broadcast_tensors(x1, x2)
+ x1, x2 = paddle.broadcast_tensors([x1, x2])
return paddle_linalg.cross(x1, x2, axis=axis)
@@ -64,7 +64,7 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
x1_ = paddle.moveaxis(x1, axis, -1)
x2_ = paddle.moveaxis(x2, axis, -1)
- x1_, x2_ = paddle.broadcast_tensors(x1_, x2_)
+ x1_, x2_ = paddle.broadcast_tensors([x1_, x2_])
res = x1_[..., None, :] @ x2_[..., None]
return res[..., 0, 0]
@@ -82,9 +82,7 @@ def solve(x1: array, x2: array, /, **kwargs) -> array:
# paddle.trace doesn't support the offset argument and doesn't support stacking
def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
# Use our wrapped sum to make sure it does upcasting correctly
- return sum(
- paddle.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype
- )
+ return sum(paddle.diagonal(x, offset=offset, axis1=-2, axis2=-1), axis=-1, dtype=dtype)
def vector_norm(
@@ -118,16 +116,44 @@ def vector_norm(
return paddle.linalg.vector_norm(x, p=ord, axis=axis, keepdim=keepdims, **kwargs)
+def matrix_norm(
+ x: array,
+ /,
+ *,
+ keepdims: bool = False,
+ ord: Optional[Union[int, float, Literal["fro", "nuc"]]] = "fro",
+) -> array:
+ return paddle.linalg.matrix_norm(x, p=ord, axis=(-2, -1), keepdim=keepdims)
+
+
+def pinv(x: array, /, *, rtol: Optional[Union[float, array]] = None) -> array:
+ if rtol is None:
+ return paddle.linalg.pinv(x)
+
+ return paddle.linalg.pinv(x, rcond=rtol)
+
+
+def slogdet(x: array):
+ det = paddle.linalg.det(x)
+ sign = paddle.sign(det)
+    # use log|det| so that negative determinants do not produce nan
+    logabsdet = paddle.log(paddle.abs(det))
+
+    slogdet = namedtuple("slogdet", ["sign", "logabsdet"])
+    return slogdet(sign, logabsdet)
+
+
__all__ = linalg_all + [
"outer",
"matmul",
"matrix_transpose",
+ "matrix_norm",
"tensordot",
"cross",
"vecdot",
"solve",
"trace",
"vector_norm",
+ "slogdet",
]
_all_ignore = ["paddle_linalg", "sum"]
From 5ae8ec8c59106e5e9aa742dc794d9334f1c620f0 Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Tue, 3 Dec 2024 15:33:54 +0800
Subject: [PATCH 10/19] fix
---
array_api_compat/paddle/_aliases.py | 14 ++++----------
1 file changed, 4 insertions(+), 10 deletions(-)
diff --git a/array_api_compat/paddle/_aliases.py b/array_api_compat/paddle/_aliases.py
index 601afa5f..00130e23 100644
--- a/array_api_compat/paddle/_aliases.py
+++ b/array_api_compat/paddle/_aliases.py
@@ -790,15 +790,6 @@ def arange(
device: Optional[Device] = None,
**kwargs,
) -> array:
- if stop is None:
- start, stop = 0, start
- if step > 0 and stop <= start or step < 0 and stop >= start:
- if dtype is None:
- if _builtin_all(isinstance(i, int) for i in [start, stop, step]):
- dtype = paddle.int64
- else:
- dtype = paddle.float32
- return paddle.empty([0], dtype=dtype, **kwargs).to(device)
return paddle.arange(start, stop, step, dtype=dtype, **kwargs).to(device)
@@ -1100,7 +1091,10 @@ def asarray(
else:
if not paddle.is_tensor(obj) or (dtype is not None and obj.dtype != dtype):
obj = np.array(obj, copy=False)
- obj = paddle.from_dlpack(obj.__dlpack__(), **kwargs).to(dtype)
+ if dtype != paddle.bool and dtype != "bool":
+ obj = paddle.from_dlpack(obj.__dlpack__(), **kwargs).to(dtype)
+ else:
+ obj = paddle.to_tensor(obj, dtype=dtype)
if device is not None:
obj = obj.to(device)
return obj
From b10273b41058945d2969c00426d0bc2edbb015f5 Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Tue, 10 Dec 2024 16:37:13 +0800
Subject: [PATCH 11/19] update code
---
array_api_compat/paddle/_aliases.py | 67 ++++++++++++++++++++++-------
array_api_compat/paddle/_info.py | 22 ++++++++--
2 files changed, 69 insertions(+), 20 deletions(-)
diff --git a/array_api_compat/paddle/_aliases.py b/array_api_compat/paddle/_aliases.py
index 00130e23..989b4d85 100644
--- a/array_api_compat/paddle/_aliases.py
+++ b/array_api_compat/paddle/_aliases.py
@@ -420,7 +420,9 @@ def _normalize_axes(axis, ndim):
for a in axis:
if a < lower or a > upper:
# Match paddle error message (e.g., from sum())
- raise IndexError(f"Dimension out of range (expected to be in range of [{lower}, {upper}], but got {a}")
+ raise IndexError(
+ f"Dimension out of range (expected to be in range of [{lower}, {upper}], but got {a}"
+ )
if a < 0:
a = a + ndim
if a in axes:
@@ -480,7 +482,9 @@ def prod(
# paddle.prod doesn't support multiple axes
if isinstance(axis, tuple):
- return _reduce_multiple_axes(paddle.prod, x, axis, keepdim=keepdims, dtype=dtype, **kwargs)
+ return _reduce_multiple_axes(
+ paddle.prod, x, axis, keepdim=keepdims, dtype=dtype, **kwargs
+ )
if axis is None:
# paddle doesn't support keepdims with axis=None
res = paddle.prod(x, dtype=dtype, **kwargs)
@@ -610,7 +614,9 @@ def std(
if isinstance(correction, float):
_correction = int(correction)
if correction != _correction:
- raise NotImplementedError("float correction in paddle std() is not yet supported")
+ raise NotImplementedError(
+ "float correction in paddle std() is not yet supported"
+ )
elif isinstance(correction, int):
if correction not in [0, 1]:
raise NotImplementedError("correction only can be 0 or 1")
@@ -648,7 +654,9 @@ def var(
if isinstance(correction, float):
_correction = int(correction)
if correction != _correction:
- raise NotImplementedError("float correction in paddle std() is not yet supported")
+ raise NotImplementedError(
+ "float correction in paddle std() is not yet supported"
+ )
elif isinstance(correction, int):
if correction not in [0, 1]:
raise NotImplementedError("correction only can be 0 or 1")
@@ -709,7 +717,9 @@ def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array:
# The axis parameter doesn't work for flip() and roll()
 # paddle.flip() doesn't accept axis=None
-def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array:
+def flip(
+ x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs
+) -> array:
if axis is None:
axis = tuple(range(x.ndim))
# paddle.flip doesn't accept dim as an int but the method does
@@ -738,21 +748,27 @@ def where(condition: array, x1: array, x2: array, /) -> array:
return paddle.where(condition, x1, x2)
-def empty_like(x: array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> array:
+def empty_like(
+ x: array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
+) -> array:
out = paddle.empty_like(x, dtype=dtype)
if device is not None:
out = out.to(device)
return out
-def zeros_like(x: array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> array:
+def zeros_like(
+ x: array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
+) -> array:
out = paddle.zeros_like(x, dtype=dtype)
if device is not None:
out = out.to(device)
return out
-def ones_like(x: array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> array:
+def ones_like(
+ x: array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
+) -> array:
out = paddle.ones_like(x, dtype=dtype)
if device is not None:
out = out.to(device)
@@ -774,7 +790,9 @@ def full_like(
# paddle.reshape doesn't have the copy keyword
-def reshape(x: array, /, shape: Tuple[int, ...], copy: Optional[bool] = None, **kwargs) -> array:
+def reshape(
+ x: array, /, shape: Tuple[int, ...], copy: Optional[bool] = None, **kwargs
+) -> array:
return paddle.reshape(x, shape, **kwargs)
@@ -825,7 +843,9 @@ def linspace(
**kwargs,
) -> array:
if not endpoint:
- return paddle.linspace(start, stop, num + 1, dtype=dtype, **kwargs).to(device)[:-1]
+ return paddle.linspace(start, stop, num + 1, dtype=dtype, **kwargs).to(device)[
+ :-1
+ ]
return paddle.linspace(start, stop, num, dtype=dtype, **kwargs).to(device)
@@ -890,7 +910,9 @@ def expand_dims(x: array, /, *, axis: int = 0) -> array:
return paddle.unsqueeze(x, axis)
-def astype(x: array, dtype: Dtype, /, *, copy: bool = True, device: Optional[Device] = None) -> array:
+def astype(
+ x: array, dtype: Dtype, /, *, copy: bool = True, device: Optional[Device] = None
+) -> array:
# if copy is not None:
# raise NotImplementedError("paddle.astype doesn't yet support the copy keyword")
t = x.to(dtype, device=device)
@@ -1036,7 +1058,7 @@ def sign(x: array, /) -> array:
else:
out = paddle.sign(x)
if paddle.is_floating_point(x):
- out = paddle.where(paddle.isnan(x), paddle.nan, out)
+ out = paddle.where(paddle.isnan(x), paddle.full(x.shape, paddle.nan), out)
return out
@@ -1083,7 +1105,8 @@ def asarray(
return obj
else:
raise NotImplementedError(
- "asarray(obj, ..., copy=False) is not supported " "for obj do not has '__dlpack__()' method"
+ "asarray(obj, ..., copy=False) is not supported "
+ "for obj do not has '__dlpack__()' method"
)
elif copy is True:
obj = np.array(obj, copy=True)
@@ -1164,11 +1187,18 @@ def _isscalar(a):
def cumulative_sum(
- x: array, /, *, axis: Optional[int] = None, dtype: Optional[Dtype] = None, include_initial: bool = False
+ x: array,
+ /,
+ *,
+ axis: Optional[int] = None,
+ dtype: Optional[Dtype] = None,
+ include_initial: bool = False,
) -> array:
if axis is None:
if x.ndim > 1:
- raise ValueError("axis must be specified in cumulative_sum for more than one dimension")
+ raise ValueError(
+ "axis must be specified in cumulative_sum for more than one dimension"
+ )
axis = 0
res = paddle.cumsum(x, axis=axis, dtype=dtype)
@@ -1185,7 +1215,12 @@ def cumulative_sum(
def searchsorted(
- x1: array, x2: array, /, *, side: Literal["left", "right"] = "left", sorter: array | None = None
+ x1: array,
+ x2: array,
+ /,
+ *,
+ side: Literal["left", "right"] = "left",
+ sorter: array | None = None,
) -> array:
if sorter is None:
return paddle.searchsorted(x1, x2, right=(side == "right"))
diff --git a/array_api_compat/paddle/_info.py b/array_api_compat/paddle/_info.py
index 5d29e270..6f079020 100644
--- a/array_api_compat/paddle/_info.py
+++ b/array_api_compat/paddle/_info.py
@@ -154,8 +154,16 @@ def default_dtypes(self, *, device=None):
# value here because this error doesn't represent a different default
# per-device.
default_floating = paddle.get_default_dtype()
- default_complex = "complex64" if default_floating == "float32" else "complex128"
- default_integral = "int64"
+ if default_floating in ["float16", "float32", "float64", "bfloat16"]:
+ default_floating = getattr(paddle, default_floating)
+ else:
+ raise ValueError(f"Unsupported default floating: {default_floating}")
+ default_complex = (
+ paddle.complex64
+ if default_floating == paddle.float32
+ else paddle.complex128
+ )
+ default_integral = paddle.int64
return {
"real floating": default_floating,
"complex floating": default_complex,
@@ -336,8 +344,14 @@ def devices(self):
except ValueError as e:
# The error message is something like:
# ValueError: The device must be a string which is like 'cpu', 'gpu', 'gpu:x', 'xpu', 'xpu:x', 'npu', 'npu:x
- devices_names = e.args[0].split("The device must be a string which is like ")[1].split(", ")
- devices_names = [name.strip("'") for name in devices_names if ":" not in name]
+ devices_names = (
+ e.args[0]
+ .split("The device must be a string which is like ")[1]
+ .split(", ")
+ )
+ devices_names = [
+ name.strip("'") for name in devices_names if ":" not in name
+ ]
# Next we need to check for different indices for different devices.
# device(device_name, index=index) doesn't actually check if the
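
A short sketch (not part of the patch) of how the inspection API touched above is meant to be
queried, assuming the compat namespace is importable:

    import array_api_compat.paddle as xp

    info = xp.__array_namespace_info__()
    # default_dtypes() now returns paddle dtype objects rather than strings
    print(info.default_dtypes())
    # devices() is discovered by parsing the error raised by paddle.set_device("notadevice")
    print(info.devices())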
From 8d2425ee538eca51698c35296269f2de114848aa Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Sat, 14 Dec 2024 17:26:40 +0800
Subject: [PATCH 12/19] fix moveaxis
---
array_api_compat/paddle/_aliases.py | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/array_api_compat/paddle/_aliases.py b/array_api_compat/paddle/_aliases.py
index 989b4d85..31a1193b 100644
--- a/array_api_compat/paddle/_aliases.py
+++ b/array_api_compat/paddle/_aliases.py
@@ -445,7 +445,7 @@ def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
# Some reductions don't support multiple axes
axes = _normalize_axes(axis, x.ndim)
for a in reversed(axes):
- x = paddle.movedim(x, a, -1)
+ x = paddle.moveaxis(x, a, -1)
x = paddle.flatten(x, -len(axes))
out = f(x, -1, **kwargs)
@@ -922,8 +922,7 @@ def astype(
def broadcast_arrays(*arrays: array) -> List[array]:
- shape = broadcast_shapes(*[a.shape for a in arrays])
- return [paddle.broadcast_to(a, shape) for a in arrays]
+ return paddle.broadcast_tensors(arrays)
# Note that these named tuples aren't actually part of the standard namespace,
From 7b8555e8ea57cd644be573e9c613c9d209f2467f Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Wed, 8 Jan 2025 22:13:31 +0800
Subject: [PATCH 13/19] fix default floating dtype of paddle.asarray
---
array_api_compat/paddle/_aliases.py | 15 ++++++++++++++-
1 file changed, 14 insertions(+), 1 deletion(-)
diff --git a/array_api_compat/paddle/_aliases.py b/array_api_compat/paddle/_aliases.py
index 31a1193b..0cccdbc8 100644
--- a/array_api_compat/paddle/_aliases.py
+++ b/array_api_compat/paddle/_aliases.py
@@ -1027,7 +1027,16 @@ def is_complex(dtype):
elif kind == "integral":
return dtype in _int_dtypes
elif kind == "real floating":
- return paddle.is_floating_point(dtype)
+ return dtype in [
+ paddle.framework.core.VarDesc.VarType.FP32,
+ paddle.framework.core.VarDesc.VarType.FP64,
+ paddle.framework.core.VarDesc.VarType.FP16,
+ paddle.framework.core.VarDesc.VarType.BF16,
+ paddle.framework.core.DataType.FLOAT32,
+ paddle.framework.core.DataType.FLOAT64,
+ paddle.framework.core.DataType.FLOAT16,
+ paddle.framework.core.DataType.BFLOAT16,
+ ]
elif kind == "complex floating":
return is_complex(dtype)
elif kind == "numeric":
@@ -1109,10 +1118,14 @@ def asarray(
)
elif copy is True:
obj = np.array(obj, copy=True)
+ if np.issubdtype(obj.dtype, np.floating):
+ obj = obj.astype(paddle.get_default_dtype())
return paddle.to_tensor(obj, dtype=dtype, place=device)
else:
if not paddle.is_tensor(obj) or (dtype is not None and obj.dtype != dtype):
obj = np.array(obj, copy=False)
+ if np.issubdtype(obj.dtype, np.floating):
+ obj = obj.astype(paddle.get_default_dtype())
if dtype != paddle.bool and dtype != "bool":
obj = paddle.from_dlpack(obj.__dlpack__(), **kwargs).to(dtype)
else:
From 603c8524b20917b6dd4b61d4106c60504458fe0d Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Thu, 9 Jan 2025 11:33:37 +0800
Subject: [PATCH 14/19] use default_dtype only when dtype is None
---
array_api_compat/paddle/_aliases.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/array_api_compat/paddle/_aliases.py b/array_api_compat/paddle/_aliases.py
index 0cccdbc8..c3e94cf1 100644
--- a/array_api_compat/paddle/_aliases.py
+++ b/array_api_compat/paddle/_aliases.py
@@ -1118,13 +1118,13 @@ def asarray(
)
elif copy is True:
obj = np.array(obj, copy=True)
- if np.issubdtype(obj.dtype, np.floating):
+ if np.issubdtype(obj.dtype, np.floating) and dtype is None:
obj = obj.astype(paddle.get_default_dtype())
return paddle.to_tensor(obj, dtype=dtype, place=device)
else:
if not paddle.is_tensor(obj) or (dtype is not None and obj.dtype != dtype):
obj = np.array(obj, copy=False)
- if np.issubdtype(obj.dtype, np.floating):
+ if np.issubdtype(obj.dtype, np.floating) and dtype is None:
obj = obj.astype(paddle.get_default_dtype())
if dtype != paddle.bool and dtype != "bool":
obj = paddle.from_dlpack(obj.__dlpack__(), **kwargs).to(dtype)
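
To illustrate the two hunks above (not part of the patch): with no explicit dtype,
floating-point input now follows paddle.get_default_dtype() instead of numpy's float64, while
an explicit dtype is still honored. A minimal sketch, assuming the default dtype is float32:

    import paddle
    import array_api_compat.paddle as xp

    a = xp.asarray([1.0, 2.0, 3.0], copy=True)                         # expected dtype: float32
    b = xp.asarray([1.0, 2.0, 3.0], dtype=paddle.float64, copy=True)   # expected dtype: float64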
From 742792f6635689ce9d67270f5cb649db6c357fe4 Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Thu, 9 Jan 2025 16:10:20 +0800
Subject: [PATCH 15/19] add floor and ceil with same return dtype
---
array_api_compat/paddle/_aliases.py | 10 +++++++++-
1 file changed, 9 insertions(+), 1 deletion(-)
diff --git a/array_api_compat/paddle/_aliases.py b/array_api_compat/paddle/_aliases.py
index c3e94cf1..6f23ee20 100644
--- a/array_api_compat/paddle/_aliases.py
+++ b/array_api_compat/paddle/_aliases.py
@@ -1036,7 +1036,7 @@ def is_complex(dtype):
paddle.framework.core.DataType.FLOAT64,
paddle.framework.core.DataType.FLOAT16,
paddle.framework.core.DataType.BFLOAT16,
- ]
+ ]
elif kind == "complex floating":
return is_complex(dtype)
elif kind == "numeric":
@@ -1137,6 +1137,14 @@ def asarray(
return obj
+def floor(x: array, /) -> array:
+ return paddle.floor(x).to(x.dtype)
+
+
+def ceil(x: array, /) -> array:
+ return paddle.ceil(x).to(x.dtype)
+
+
def clip(
x: array,
/,
From fd6eea032fb9b42ae3c84550a220c404d2ef14a0 Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Thu, 9 Jan 2025 16:22:40 +0800
Subject: [PATCH 16/19] update code
---
array_api_compat/paddle/_aliases.py | 11 ++++++-----
1 file changed, 6 insertions(+), 5 deletions(-)
diff --git a/array_api_compat/paddle/_aliases.py b/array_api_compat/paddle/_aliases.py
index 6f23ee20..622504d7 100644
--- a/array_api_compat/paddle/_aliases.py
+++ b/array_api_compat/paddle/_aliases.py
@@ -4,7 +4,7 @@
import numpy as np
from functools import wraps as _wraps
-from builtins import all as _builtin_all, any as _builtin_any
+from builtins import any as _builtin_any
from ..common._aliases import (
unstack as _aliases_unstack,
@@ -1036,7 +1036,7 @@ def is_complex(dtype):
paddle.framework.core.DataType.FLOAT64,
paddle.framework.core.DataType.FLOAT16,
paddle.framework.core.DataType.BFLOAT16,
- ]
+ ]
elif kind == "complex floating":
return is_complex(dtype)
elif kind == "numeric":
@@ -1186,8 +1186,7 @@ def _isscalar(a):
if type(max) is int and max >= paddle.iinfo(x.dtype).max:
max = None
- if out is None:
- out = paddle.to_tensor(broadcast_to(x, result_shape), place=x.place)
+ out = paddle.to_tensor(broadcast_to(x, result_shape), place=x.place)
if min is not None:
if paddle.is_tensor(x) and x.dtype == paddle.float64 and _isscalar(min):
# Avoid loss of precision due to paddle defaulting to float32
@@ -1203,7 +1202,7 @@ def _isscalar(a):
ib = (out > b) | paddle.isnan(b)
out[ib] = astype(b[ib], out.dtype)
# Return a scalar for 0-D
- return out[()]
+ return out
def cumulative_sum(
@@ -1340,6 +1339,8 @@ def searchsorted(
"ones_like",
"full_like",
"asarray",
+ "ceil",
+ "floor",
]
_all_ignore = ["paddle", "get_xp"]
From 37785d442e890a579edf923fb3c174c5bbf64926 Mon Sep 17 00:00:00 2001
From: cangtianhuang
Date: Tue, 1 Apr 2025 15:05:10 +0800
Subject: [PATCH 17/19] Add broadcast_tensors alias, modify result_type
---
array_api_compat/paddle/_aliases.py | 44 ++++++++++++++++++++---------
1 file changed, 30 insertions(+), 14 deletions(-)
diff --git a/array_api_compat/paddle/_aliases.py b/array_api_compat/paddle/_aliases.py
index 622504d7..d19353e0 100644
--- a/array_api_compat/paddle/_aliases.py
+++ b/array_api_compat/paddle/_aliases.py
@@ -1,5 +1,6 @@
from __future__ import annotations
+import builtins
from typing import Literal
import numpy as np
@@ -112,25 +113,32 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype:
raise TypeError("At least one array or dtype must be provided")
if len(arrays_and_dtypes) == 1:
x = arrays_and_dtypes[0]
- if isinstance(x, paddle.dtype):
- return x
- return x.dtype
+ return x if isinstance(x, paddle.dtype) else x.dtype
if len(arrays_and_dtypes) > 2:
return result_type(arrays_and_dtypes[0], result_type(*arrays_and_dtypes[1:]))
x, y = arrays_and_dtypes
- xdt = x.dtype if not isinstance(x, paddle.dtype) else x
- ydt = y.dtype if not isinstance(y, paddle.dtype) else y
+ xdt = x if isinstance(x, paddle.dtype) else x.dtype
+ ydt = y if isinstance(y, paddle.dtype) else y.dtype
if (xdt, ydt) in _promotion_table:
- return _promotion_table[xdt, ydt]
-
- # This doesn't result_type(dtype, dtype) for non-array API dtypes
- # because paddle.result_type only accepts tensors. This does however, allow
- # cross-kind promotion.
- x = paddle.to_tensor([], dtype=x) if isinstance(x, paddle.dtype) else x
- y = paddle.to_tensor([], dtype=y) if isinstance(y, paddle.dtype) else y
- return paddle.result_type(x, y)
+ return _promotion_table[(xdt, ydt)]
+
+ type_order = {
+ paddle.bool: 0,
+ paddle.int8: 1,
+ paddle.uint8: 2,
+ paddle.int16: 3,
+ paddle.int32: 4,
+ paddle.int64: 5,
+ paddle.float16: 6,
+ paddle.float32: 7,
+ paddle.float64: 8,
+ paddle.complex64: 9,
+ paddle.complex128: 10
+ }
+
+ return xdt if type_order.get(xdt, 0) > type_order.get(ydt, 0) else ydt
def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
@@ -922,7 +930,15 @@ def astype(
def broadcast_arrays(*arrays: array) -> List[array]:
- return paddle.broadcast_tensors(arrays)
+ original_dtypes = [arr.dtype for arr in arrays]
+ if len(set(original_dtypes)) == 1:
+ return paddle.broadcast_tensors(arrays)
+ target_dtype = result_type(*arrays)
+ casted_arrays = [arr.astype(target_dtype) if arr.dtype != target_dtype else arr
+ for arr in arrays]
+ broadcasted = paddle.broadcast_tensors(casted_arrays)
+ result = [arr.astype(original_dtype) for arr, original_dtype in zip(broadcasted, original_dtypes)]
+ return result
# Note that these named tuples aren't actually part of the standard namespace,
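
For reference, a sketch (not part of the patch) of how the reworked result_type() above behaves:
pairs found in the promotion table are returned directly, and other pairs fall back to the
type_order ranking. The commented results are expected values, not verified output:

    import paddle
    import array_api_compat.paddle as xp

    xp.result_type(paddle.int32, paddle.int64)     # expected: paddle.int64 (promotion table)
    xp.result_type(paddle.int64, paddle.float32)   # expected: paddle.float32 (type_order fallback)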
From 0651731fc5bda2a3362d1665da31ea12635f5963 Mon Sep 17 00:00:00 2001
From: cangtianhuang
Date: Tue, 1 Apr 2025 15:06:56 +0800
Subject: [PATCH 18/19] refine
---
array_api_compat/paddle/_aliases.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/array_api_compat/paddle/_aliases.py b/array_api_compat/paddle/_aliases.py
index d19353e0..88f71e7d 100644
--- a/array_api_compat/paddle/_aliases.py
+++ b/array_api_compat/paddle/_aliases.py
@@ -1,6 +1,5 @@
from __future__ import annotations
-import builtins
from typing import Literal
import numpy as np
From 912fe3e56739a4353c95f6111e05270dde6fcb86 Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Mon, 12 May 2025 20:01:10 +0800
Subject: [PATCH 19/19] add paddle skip and xfail files
---
paddle-skips.txt | 6 +++
paddle-xfails.txt | 108 ++++++++++++++++++++++++++++++++++++++++++++++
2 files changed, 114 insertions(+)
create mode 100644 paddle-skips.txt
create mode 100644 paddle-xfails.txt
diff --git a/paddle-skips.txt b/paddle-skips.txt
new file mode 100644
index 00000000..094c553f
--- /dev/null
+++ b/paddle-skips.txt
@@ -0,0 +1,6 @@
+array_api_tests/test_array_object.py::test_getitem_masking
+array_api_tests/test_data_type_functions.py::test_result_type
+array_api_tests/test_data_type_functions.py::test_broadcast_arrays
+array_api_tests/test_manipulation_functions.py::test_roll
+array_api_tests/test_data_type_functions.py::test_broadcast_to
+array_api_tests/test_linalg.py::test_cholesky
diff --git a/paddle-xfails.txt b/paddle-xfails.txt
new file mode 100644
index 00000000..6998f374
--- /dev/null
+++ b/paddle-xfails.txt
@@ -0,0 +1,108 @@
+# Skip 'copy=...'
+array_api_tests/test_array_object.py::test_setitem
+array_api_tests/test_array_object.py::test_setitem_masking
+# array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x1, x2)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_add[__iadd__(x1, x2)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_add[__iadd__(x, s)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__iand__(x1, x2)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__iand__(x, s)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__ilshift__(x1, x2)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__ilshift__(x, s)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[__ior__(x1, x2)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[__ior__(x, s)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__irshift__(x1, x2)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__irshift__(x, s)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__ixor__(x1, x2)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__ixor__(x, s)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__itruediv__(x1, x2)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__itruediv__(x, s)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x1, x2)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__imul__(x1, x2)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__imul__(x, s)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__ipow__(x1, x2)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__ipow__(x, s)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__isub__(x1, x2)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__isub__(x, s)]
+
+# Skip promotion test for 'Scalar op Tensor'
+array_api_tests/test_operators_and_elementwise_functions.py::test_add[add(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_divide[divide(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[subtract(x1, x2)]
+
+# tests that torch does not pass either
+array_api_tests/test_creation_functions.py::test_asarray_scalars
+array_api_tests/test_creation_functions.py::test_asarray_arrays
+array_api_tests/test_creation_functions.py::test_empty_like
+array_api_tests/test_creation_functions.py::test_eye
+array_api_tests/test_creation_functions.py::test_full
+array_api_tests/test_creation_functions.py::test_full_like
+array_api_tests/test_creation_functions.py::test_linspace
+array_api_tests/test_creation_functions.py::test_ones
+array_api_tests/test_creation_functions.py::test_ones_like
+array_api_tests/test_creation_functions.py::test_zeros
+array_api_tests/test_creation_functions.py::test_zeros_like
+array_api_tests/test_fft.py::test_fft
+array_api_tests/test_fft.py::test_ifft
+array_api_tests/test_fft.py::test_fftn
+array_api_tests/test_fft.py::test_ifftn
+array_api_tests/test_fft.py::test_rfft
+array_api_tests/test_fft.py::test_irfft
+array_api_tests/test_fft.py::test_rfftn
+array_api_tests/test_fft.py::test_hfft
+array_api_tests/test_fft.py::test_ihfft
+array_api_tests/test_has_names.py::test_has_names[manipulation-repeat]
+array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__]
+array_api_tests/test_has_names.py::test_has_names[array_method-to_device]
+array_api_tests/test_indexing_functions.py::test_take
+array_api_tests/test_linalg.py::test_linalg_matmul
+array_api_tests/test_linalg.py::test_qr
+array_api_tests/test_linalg.py::test_solve
+array_api_tests/test_manipulation_functions.py::test_concat
+array_api_tests/test_manipulation_functions.py::test_repeat
+array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__lshift__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__xor__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__pow__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_round
+array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x1, x2)]
+array_api_tests/test_set_functions.py::test_unique_all
+array_api_tests/test_set_functions.py::test_unique_counts
+array_api_tests/test_set_functions.py::test_unique_inverse
+array_api_tests/test_set_functions.py::test_unique_values
+array_api_tests/test_signatures.py::test_func_signature[astype]
+array_api_tests/test_signatures.py::test_func_signature[repeat]
+array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
+array_api_tests/test_signatures.py::test_array_method_signature[__array_namespace__]
+array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
+array_api_tests/test_signatures.py::test_array_method_signature[to_device]
+array_api_tests/test_sorting_functions.py::test_argsort
+array_api_tests/test_sorting_functions.py::test_sort
+array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[__ne__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[__le__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[less_equal(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[greater_equal(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_greater[__gt__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_equal[equal(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[bitwise_and(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__and__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[bitwise_left_shift(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[bitwise_or(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[__or__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[bitwise_right_shift(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__rshift__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[bitwise_xor(x1, x2)]
+
+# dtype promotion related
+array_api_tests/test_operators_and_elementwise_functions.py::test_floor
+array_api_tests/test_operators_and_elementwise_functions.py::test_ceil
+array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[greater_equal(x1, x2)]
+array_api_tests/test_searching_functions.py::test_where