diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 91161d24..23dafde9 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -241,6 +241,20 @@ def sort( ) -> Array: return torch.sort(x, dim=axis, descending=descending, stable=stable, **kwargs).values + +# Wrap torch.argsort to set stable=True by default +def argsort( + x: Array, + /, + *, + axis: int = -1, + descending: bool = False, + stable: bool = True, + **kwargs: object, +) -> Array: + return torch.argsort(x, dim=axis, descending=descending, stable=stable, **kwargs) + + def _normalize_axes(axis, ndim): axes = [] if ndim == 0 and axis: @@ -837,9 +851,9 @@ def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> list[Array 'equal', 'floor_divide', 'greater', 'greater_equal', 'hypot', 'less', 'less_equal', 'logaddexp', 'maximum', 'minimum', 'multiply', 'not_equal', 'pow', 'remainder', 'subtract', 'max', - 'min', 'clip', 'unstack', 'cumulative_sum', 'cumulative_prod', 'sort', 'prod', 'sum', - 'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze', - 'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape', + 'min', 'clip', 'unstack', 'cumulative_sum', 'cumulative_prod', 'sort', + 'argsort', 'prod', 'sum', 'any', 'all', 'mean', 'std', 'var', 'concat', + 'squeeze', 'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape', 'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype', 'broadcast_arrays', 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py index 76342980..f11b3eb5 100644 --- a/array_api_compat/torch/fft.py +++ b/array_api_compat/torch/fft.py @@ -3,7 +3,7 @@ from collections.abc import Sequence from typing import Literal -import torch +import torch # noqa: F401 import torch.fft from ._typing import Array