From 438ac56ae390760d50ac7926269256d49d9bcaa9 Mon Sep 17 00:00:00 2001 From: GUAN MING Date: Tue, 8 Apr 2025 12:21:16 +0800 Subject: [PATCH 1/3] =?UTF-8?q?=F0=9F=9A=9A=20port=20ma.arg{min,max}=20and?= =?UTF-8?q?=20MaskedArray.arg{min,max}?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/numpy-stubs/@test/static/accept/ma.pyi | 33 ++++ src/numpy-stubs/@test/static/reject/ma.pyi | 14 ++ src/numpy-stubs/ma/core.pyi | 172 ++++++++++++++++++--- 3 files changed, 201 insertions(+), 18 deletions(-) diff --git a/src/numpy-stubs/@test/static/accept/ma.pyi b/src/numpy-stubs/@test/static/accept/ma.pyi index c794a281..6544d9a6 100644 --- a/src/numpy-stubs/@test/static/accept/ma.pyi +++ b/src/numpy-stubs/@test/static/accept/ma.pyi @@ -1,3 +1,4 @@ +import math from typing import Any, TypeAlias, TypeVar from typing_extensions import assert_type @@ -23,3 +24,35 @@ MAR_f4: MaskedNDArray[np.float32] MAR_i8: MaskedNDArray[np.int64] MAR_subclass: MaskedNDArraySubclass MAR_1d: np.ma.MaskedArray[tuple[int], np.dtype[Any]] + +assert_type(MAR_b.argmin(), np.intp) +assert_type(MAR_f4.argmin(), np.intp) +assert_type(MAR_f4.argmax(fill_value=math.tau, keepdims=False), np.intp) +assert_type(MAR_b.argmin(axis=0), Any) +assert_type(MAR_f4.argmin(axis=0), Any) +assert_type(MAR_b.argmin(keepdims=True), Any) +assert_type(MAR_f4.argmin(out=MAR_subclass), MaskedNDArraySubclass) +assert_type(MAR_f4.argmin(None, None, out=MAR_subclass), MaskedNDArraySubclass) + +assert_type(np.ma.argmin(MAR_b), np.intp) +assert_type(np.ma.argmin(MAR_f4), np.intp) +assert_type(np.ma.argmin(MAR_f4, fill_value=math.tau, keepdims=False), np.intp) +assert_type(np.ma.argmin(MAR_b, axis=0), Any) +assert_type(np.ma.argmin(MAR_f4, axis=0), Any) +assert_type(np.ma.argmin(MAR_b, keepdims=True), Any) + +assert_type(MAR_b.argmax(), np.intp) +assert_type(MAR_f4.argmax(), np.intp) +assert_type(MAR_f4.argmax(fill_value=math.tau, keepdims=False), np.intp) +assert_type(MAR_b.argmax(axis=0), Any) +assert_type(MAR_f4.argmax(axis=0), Any) +assert_type(MAR_b.argmax(keepdims=True), Any) +assert_type(MAR_f4.argmax(out=MAR_subclass), MaskedNDArraySubclass) +assert_type(MAR_f4.argmax(None, None, out=MAR_subclass), MaskedNDArraySubclass) + +assert_type(np.ma.argmax(MAR_b), np.intp) +assert_type(np.ma.argmax(MAR_f4), np.intp) +assert_type(np.ma.argmax(MAR_f4, fill_value=math.tau, keepdims=False), np.intp) +assert_type(np.ma.argmax(MAR_b, axis=0), Any) +assert_type(np.ma.argmax(MAR_f4, axis=0), Any) +assert_type(np.ma.argmax(MAR_b, keepdims=True), Any) diff --git a/src/numpy-stubs/@test/static/reject/ma.pyi b/src/numpy-stubs/@test/static/reject/ma.pyi index 8463f8d5..df68eaed 100644 --- a/src/numpy-stubs/@test/static/reject/ma.pyi +++ b/src/numpy-stubs/@test/static/reject/ma.pyi @@ -9,3 +9,17 @@ np.amin(m, axis=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgu np.amin(m, keepdims=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue] np.amin(m, out=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue] np.amin(m, fill_value=lambda x: 27) # type: ignore[call-overload] # pyright: ignore[reportCallIssue, reportUnknownLambdaType] + +m.argmin(axis=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue] +m.argmin(keepdims=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue] +m.argmin(out=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue] +m.argmin(fill_value=lambda x: 27) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue, reportUnknownLambdaType] + +np.ma.argmin(m, axis=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue] +np.ma.argmin(m, axis=(1,)) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue] +np.ma.argmin(m, keepdims=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue] +np.ma.argmin(m, out=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue] +np.ma.argmin(m, fill_value=lambda x: 27) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue, reportUnknownLambdaType] + +m.argmax(axis=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue] + diff --git a/src/numpy-stubs/ma/core.pyi b/src/numpy-stubs/ma/core.pyi index 2037feaa..988be4b3 100644 --- a/src/numpy-stubs/ma/core.pyi +++ b/src/numpy-stubs/ma/core.pyi @@ -4,7 +4,7 @@ from typing_extensions import Never, Self, TypeVar, deprecated, overload, overri import numpy as np from _numtype import Array, ToGeneric_0d, ToGeneric_1nd, ToGeneric_nd -from numpy import _OrderACF, _OrderKACF, amax, amin, bool_, expand_dims # noqa: ICN003 +from numpy import _OrderACF, _OrderKACF, amax, amin, bool_, expand_dims, intp # noqa: ICN003 from numpy._globals import _NoValueType from numpy._typing import ArrayLike, _ArrayLike, _BoolCodes, _ScalarLike_co, _ShapeLike @@ -189,7 +189,12 @@ __all__ = [ "zeros_like", ] +### + _ArrayT = TypeVar("_ArrayT", bound=np.ndarray[Any, Any]) + +### + _UFuncT_co = TypeVar("_UFuncT_co", bound=np.ufunc, default=np.ufunc, covariant=True) _ScalarT = TypeVar("_ScalarT", bound=np.generic) _ShapeT = TypeVar("_ShapeT", bound=tuple[int, ...]) @@ -647,15 +652,44 @@ class MaskedArray(np.ndarray[_ShapeT_co, _DTypeT_co]): fill_value: Incomplete = ..., keepdims: Incomplete = ..., ) -> Incomplete: ... - @override - def argmin( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride] + + # Keep in-sync with np.ma.argmin + @overload # type: ignore[override] + def argmin( self, - axis: Incomplete = ..., - fill_value: Incomplete = ..., - out: Incomplete = ..., + axis: None = None, + fill_value: _ScalarLike_co | None = None, + out: None = None, *, - keepdims: Incomplete = ..., - ) -> Incomplete: ... + keepdims: L[False] | _NoValueType = ..., + ) -> intp: ... + @overload + def argmin( + self, + axis: CanIndex | None = None, + fill_value: _ScalarLike_co | None = None, + out: None = None, + *, + keepdims: bool | _NoValueType = ..., + ) -> Any: ... + @overload + def argmin( + self, + axis: CanIndex | None = None, + fill_value: _ScalarLike_co | None = None, + *, + out: _ArrayT, + keepdims: bool | _NoValueType = ..., + ) -> _ArrayT: ... + @overload + def argmin( # pyright: ignore[reportIncompatibleMethodOverride] + self, + axis: CanIndex | None, + fill_value: _ScalarLike_co | None, + out: _ArrayT, + *, + keepdims: bool | _NoValueType = ..., + ) -> _ArrayT: ... # @override @@ -666,15 +700,44 @@ class MaskedArray(np.ndarray[_ShapeT_co, _DTypeT_co]): fill_value: Incomplete = ..., keepdims: Incomplete = ..., ) -> Incomplete: ... - @override - def argmax( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride] + + # Keep in-sync with np.ma.argmax + @overload # type: ignore[override] + def argmax( self, - axis: Incomplete = ..., - fill_value: Incomplete = ..., - out: Incomplete = ..., + axis: None = None, + fill_value: _ScalarLike_co | None = None, + out: None = None, *, - keepdims: Incomplete = ..., - ) -> Incomplete: ... + keepdims: L[False] | _NoValueType = ..., + ) -> intp: ... + @overload + def argmax( + self, + axis: CanIndex | None = None, + fill_value: _ScalarLike_co | None = None, + out: None = None, + *, + keepdims: bool | _NoValueType = ..., + ) -> Any: ... + @overload + def argmax( + self, + axis: CanIndex | None = None, + fill_value: _ScalarLike_co | None = None, + *, + out: _ArrayT, + keepdims: bool | _NoValueType = ..., + ) -> _ArrayT: ... + @overload + def argmax( # pyright: ignore[reportIncompatibleMethodOverride] + self, + axis: CanIndex | None, + fill_value: _ScalarLike_co | None, + out: _ArrayT, + *, + keepdims: bool | _NoValueType = ..., + ) -> _ArrayT: ... # @override @@ -1091,8 +1154,81 @@ swapaxes: _frommethod trace: _frommethod var: _frommethod count: _frommethod -argmin: _frommethod -argmax: _frommethod - minimum: _extrema_operation maximum: _extrema_operation + +# +@overload +def argmin( + a: _ArrayT, # pyright: ignore[reportInvalidTypeVarUse] + axis: None = None, + fill_value: _ScalarLike_co | None = None, + out: None = None, + *, + keepdims: L[False] | _NoValueType = ..., +) -> intp: ... +@overload +def argmin( + a: _ArrayT, # pyright: ignore[reportInvalidTypeVarUse] + axis: CanIndex | None = None, + fill_value: _ScalarLike_co | None = None, + out: None = None, + *, + keepdims: bool | _NoValueType = ..., +) -> Any: ... +@overload +def argmin( + a: _ArrayT, + axis: CanIndex | None = None, + fill_value: _ScalarLike_co | None = None, + *, + out: _ArrayT, + keepdims: bool | _NoValueType = ..., +) -> _ArrayT: ... +@overload +def argmin( + a: _ArrayT, + axis: CanIndex | None, + fill_value: _ScalarLike_co | None, + out: _ArrayT, + *, + keepdims: bool | _NoValueType = ..., +) -> _ArrayT: ... + +# +@overload +def argmax( + a: _ArrayT, # pyright: ignore[reportInvalidTypeVarUse] + axis: None = None, + fill_value: _ScalarLike_co | None = None, + out: None = None, + *, + keepdims: L[False] | _NoValueType = ..., +) -> intp: ... +@overload +def argmax( + a: _ArrayT, # pyright: ignore[reportInvalidTypeVarUse] + axis: CanIndex | None = None, + fill_value: _ScalarLike_co | None = None, + out: None = None, + *, + keepdims: bool | _NoValueType = ..., +) -> Any: ... +@overload +def argmax( + a: _ArrayT, + axis: CanIndex | None = None, + fill_value: _ScalarLike_co | None = None, + *, + out: _ArrayT, + keepdims: bool | _NoValueType = ..., +) -> _ArrayT: ... +@overload +def argmax( + a: _ArrayT, + axis: CanIndex | None, + fill_value: _ScalarLike_co | None, + out: _ArrayT, + *, + keepdims: bool | _NoValueType = ..., +) -> _ArrayT: ... From 4909fe61e143adf883d03cf8f1e330073cd24a44 Mon Sep 17 00:00:00 2001 From: GUAN MING Date: Tue, 8 Apr 2025 13:19:28 +0800 Subject: [PATCH 2/3] =?UTF-8?q?=E2=9C=A8=20overloaded=20`call`=20signature?= =?UTF-8?q?s?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/numpy-stubs/@test/static/reject/ma.pyi | 1 - src/numpy-stubs/ma/core.pyi | 126 ++++++++------------- 2 files changed, 50 insertions(+), 77 deletions(-) diff --git a/src/numpy-stubs/@test/static/reject/ma.pyi b/src/numpy-stubs/@test/static/reject/ma.pyi index df68eaed..a90857c6 100644 --- a/src/numpy-stubs/@test/static/reject/ma.pyi +++ b/src/numpy-stubs/@test/static/reject/ma.pyi @@ -22,4 +22,3 @@ np.ma.argmin(m, out=1.0) # type: ignore[call-overload] # pyright: ignore[report np.ma.argmin(m, fill_value=lambda x: 27) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue, reportUnknownLambdaType] m.argmax(axis=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue] - diff --git a/src/numpy-stubs/ma/core.pyi b/src/numpy-stubs/ma/core.pyi index 988be4b3..43d03ee3 100644 --- a/src/numpy-stubs/ma/core.pyi +++ b/src/numpy-stubs/ma/core.pyi @@ -863,6 +863,54 @@ class _frommethod: def __call__(self, a: Incomplete, *args: Incomplete, **params: Incomplete) -> Incomplete: ... def getdoc(self) -> Incomplete: ... +@type_check_only +class _ArgMinMaxMethod: + __name__: str + __doc__: str + reversed: Incomplete + def __init__(self, methodname: Incomplete, reversed: Incomplete = ...) -> None: ... + @overload + def __call__( + self, + a: _ArrayT, # pyright: ignore[reportInvalidTypeVarUse] + axis: None = None, + fill_value: _ScalarLike_co | None = None, + out: None = None, + *, + keepdims: L[False] | _NoValueType = ..., + ) -> intp: ... + @overload + def __call__( + self, + a: _ArrayT, # pyright: ignore[reportInvalidTypeVarUse] + axis: CanIndex | None = None, + fill_value: _ScalarLike_co | None = None, + out: None = None, + *, + keepdims: bool | _NoValueType = ..., + ) -> Any: ... + @overload + def __call__( + self, + a: _ArrayT, + axis: CanIndex | None = None, + fill_value: _ScalarLike_co | None = None, + *, + out: _ArrayT, + keepdims: bool | _NoValueType = ..., + ) -> _ArrayT: ... + @overload + def __call__( + self, + a: _ArrayT, + axis: CanIndex | None, + fill_value: _ScalarLike_co | None, + out: _ArrayT, + *, + keepdims: bool | _NoValueType = ..., + ) -> _ArrayT: ... + def getdoc(self) -> Incomplete: ... + class _convert2ma: def __init__(self, /, funcname: str, np_ret: str, np_ma_ret: str, params: dict[str, Any] | None = None) -> None: ... def __call__(self, /, *args: object, **params: object) -> Any: ... @@ -1154,81 +1202,7 @@ swapaxes: _frommethod trace: _frommethod var: _frommethod count: _frommethod +argmin: _ArgMinMaxMethod +argmax: _ArgMinMaxMethod minimum: _extrema_operation maximum: _extrema_operation - -# -@overload -def argmin( - a: _ArrayT, # pyright: ignore[reportInvalidTypeVarUse] - axis: None = None, - fill_value: _ScalarLike_co | None = None, - out: None = None, - *, - keepdims: L[False] | _NoValueType = ..., -) -> intp: ... -@overload -def argmin( - a: _ArrayT, # pyright: ignore[reportInvalidTypeVarUse] - axis: CanIndex | None = None, - fill_value: _ScalarLike_co | None = None, - out: None = None, - *, - keepdims: bool | _NoValueType = ..., -) -> Any: ... -@overload -def argmin( - a: _ArrayT, - axis: CanIndex | None = None, - fill_value: _ScalarLike_co | None = None, - *, - out: _ArrayT, - keepdims: bool | _NoValueType = ..., -) -> _ArrayT: ... -@overload -def argmin( - a: _ArrayT, - axis: CanIndex | None, - fill_value: _ScalarLike_co | None, - out: _ArrayT, - *, - keepdims: bool | _NoValueType = ..., -) -> _ArrayT: ... - -# -@overload -def argmax( - a: _ArrayT, # pyright: ignore[reportInvalidTypeVarUse] - axis: None = None, - fill_value: _ScalarLike_co | None = None, - out: None = None, - *, - keepdims: L[False] | _NoValueType = ..., -) -> intp: ... -@overload -def argmax( - a: _ArrayT, # pyright: ignore[reportInvalidTypeVarUse] - axis: CanIndex | None = None, - fill_value: _ScalarLike_co | None = None, - out: None = None, - *, - keepdims: bool | _NoValueType = ..., -) -> Any: ... -@overload -def argmax( - a: _ArrayT, - axis: CanIndex | None = None, - fill_value: _ScalarLike_co | None = None, - *, - out: _ArrayT, - keepdims: bool | _NoValueType = ..., -) -> _ArrayT: ... -@overload -def argmax( - a: _ArrayT, - axis: CanIndex | None, - fill_value: _ScalarLike_co | None, - out: _ArrayT, - *, - keepdims: bool | _NoValueType = ..., -) -> _ArrayT: ... From 0e7d46513346d93f55572240f7d0ffc75d56e1f2 Mon Sep 17 00:00:00 2001 From: GUAN MING Date: Fri, 11 Apr 2025 01:26:54 +0800 Subject: [PATCH 3/3] =?UTF-8?q?=F0=9F=8F=B7=EF=B8=8F=20add=20generic=20typ?= =?UTF-8?q?e?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/numpy-stubs/ma/core.pyi | 98 ++++++++++++++++++++++++++----------- 1 file changed, 69 insertions(+), 29 deletions(-) diff --git a/src/numpy-stubs/ma/core.pyi b/src/numpy-stubs/ma/core.pyi index 43d03ee3..6b257d91 100644 --- a/src/numpy-stubs/ma/core.pyi +++ b/src/numpy-stubs/ma/core.pyi @@ -1,6 +1,6 @@ from _typeshed import Incomplete from typing import Any, ClassVar, Final, Generic, Literal as L, SupportsIndex as CanIndex, TypeAlias, type_check_only -from typing_extensions import Never, Self, TypeVar, deprecated, overload, override +from typing_extensions import Never, Protocol, Self, TypeVar, deprecated, overload, override import numpy as np from _numtype import Array, ToGeneric_0d, ToGeneric_1nd, ToGeneric_nd @@ -203,6 +203,7 @@ _DTypeT = TypeVar("_DTypeT", bound=np.dtype) _DTypeT_co = TypeVar("_DTypeT_co", bound=np.dtype, default=np.dtype, covariant=True) _DTypeLikeBool: TypeAlias = type[bool | np.bool] | np.dtype[np.bool] | _BoolCodes +_ScalarT_co = TypeVar("_ScalarT_co", bound=np.generic, covariant=True) ### @@ -855,12 +856,51 @@ class MaskedConstant(MaskedArray[tuple[()], np.dtype[np.float64]]): @override def copy(self, /, *args: object, **kwargs: object) -> Incomplete: ... -class _frommethod: +class _frommethod(Protocol[_ScalarT_co]): __name__: str __doc__: str reversed: Incomplete def __init__(self, methodname: Incomplete, reversed: Incomplete = ...) -> None: ... - def __call__(self, a: Incomplete, *args: Incomplete, **params: Incomplete) -> Incomplete: ... + @overload + def __call__( + self, + a: _ArrayT, # pyright: ignore[reportInvalidTypeVarUse] + axis: None = None, + fill_value: _ScalarLike_co | None = None, + out: None = None, + *, + keepdims: L[False] | _NoValueType = ..., + ) -> _ScalarT_co: ... + @overload + def __call__( + self, + a: _ArrayT, # pyright: ignore[reportInvalidTypeVarUse] + axis: CanIndex | None = None, + fill_value: _ScalarLike_co | None = None, + out: None = None, + *, + keepdims: bool | _NoValueType = ..., + ) -> Any: ... + @overload + def __call__( + self, + a: _ArrayT, + axis: CanIndex | None = None, + fill_value: _ScalarLike_co | None = None, + *, + out: _ArrayT, + keepdims: bool | _NoValueType = ..., + ) -> _ArrayT: ... + @overload + def __call__( + self, + a: _ArrayT, + axis: CanIndex | None, + fill_value: _ScalarLike_co | None, + out: _ArrayT, + *, + keepdims: bool | _NoValueType = ..., + ) -> _ArrayT: ... def getdoc(self) -> Incomplete: ... @type_check_only @@ -1178,31 +1218,31 @@ squeeze: _convert2ma zeros: _convert2ma zeros_like: _convert2ma -all: _frommethod -anomalies: _frommethod -anom: _frommethod -any: _frommethod -compress: _frommethod -cumprod: _frommethod -cumsum: _frommethod -copy: _frommethod -diagonal: _frommethod -harden_mask: _frommethod -ids: _frommethod -mean: _frommethod -nonzero: _frommethod -prod: _frommethod -product: _frommethod -ravel: _frommethod -repeat: _frommethod -soften_mask: _frommethod -std: _frommethod -sum: _frommethod -swapaxes: _frommethod -trace: _frommethod -var: _frommethod -count: _frommethod -argmin: _ArgMinMaxMethod -argmax: _ArgMinMaxMethod +all: _frommethod[Any] +anomalies: _frommethod[Any] +anom: _frommethod[Any] +any: _frommethod[Any] +compress: _frommethod[Any] +cumprod: _frommethod[Any] +cumsum: _frommethod[Any] +copy: _frommethod[Any] +diagonal: _frommethod[Any] +harden_mask: _frommethod[Any] +ids: _frommethod[Any] +mean: _frommethod[Any] +nonzero: _frommethod[Any] +prod: _frommethod[Any] +product: _frommethod[Any] +ravel: _frommethod[Any] +repeat: _frommethod[Any] +soften_mask: _frommethod[Any] +std: _frommethod[Any] +sum: _frommethod[Any] +swapaxes: _frommethod[Any] +trace: _frommethod[Any] +var: _frommethod[Any] +count: _frommethod[Any] +argmin: _frommethod[intp] +argmax: _frommethod[intp] minimum: _extrema_operation maximum: _extrema_operation