Skip to content

🚚 port ma.arg{min,max} and MaskedArray.arg{min,max} #468

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions src/numpy-stubs/@test/static/accept/ma.pyi
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
from typing import Any, TypeAlias, TypeVar
from typing_extensions import assert_type

Expand All @@ -23,3 +24,35 @@ MAR_f4: MaskedNDArray[np.float32]
MAR_i8: MaskedNDArray[np.int64]
MAR_subclass: MaskedNDArraySubclass
MAR_1d: np.ma.MaskedArray[tuple[int], np.dtype[Any]]

assert_type(MAR_b.argmin(), np.intp)
assert_type(MAR_f4.argmin(), np.intp)
assert_type(MAR_f4.argmax(fill_value=math.tau, keepdims=False), np.intp)
assert_type(MAR_b.argmin(axis=0), Any)
assert_type(MAR_f4.argmin(axis=0), Any)
assert_type(MAR_b.argmin(keepdims=True), Any)
assert_type(MAR_f4.argmin(out=MAR_subclass), MaskedNDArraySubclass)
assert_type(MAR_f4.argmin(None, None, out=MAR_subclass), MaskedNDArraySubclass)

assert_type(np.ma.argmin(MAR_b), np.intp)
assert_type(np.ma.argmin(MAR_f4), np.intp)
assert_type(np.ma.argmin(MAR_f4, fill_value=math.tau, keepdims=False), np.intp)
assert_type(np.ma.argmin(MAR_b, axis=0), Any)
assert_type(np.ma.argmin(MAR_f4, axis=0), Any)
assert_type(np.ma.argmin(MAR_b, keepdims=True), Any)

assert_type(MAR_b.argmax(), np.intp)
assert_type(MAR_f4.argmax(), np.intp)
assert_type(MAR_f4.argmax(fill_value=math.tau, keepdims=False), np.intp)
assert_type(MAR_b.argmax(axis=0), Any)
assert_type(MAR_f4.argmax(axis=0), Any)
assert_type(MAR_b.argmax(keepdims=True), Any)
assert_type(MAR_f4.argmax(out=MAR_subclass), MaskedNDArraySubclass)
assert_type(MAR_f4.argmax(None, None, out=MAR_subclass), MaskedNDArraySubclass)

assert_type(np.ma.argmax(MAR_b), np.intp)
assert_type(np.ma.argmax(MAR_f4), np.intp)
assert_type(np.ma.argmax(MAR_f4, fill_value=math.tau, keepdims=False), np.intp)
assert_type(np.ma.argmax(MAR_b, axis=0), Any)
assert_type(np.ma.argmax(MAR_f4, axis=0), Any)
assert_type(np.ma.argmax(MAR_b, keepdims=True), Any)
13 changes: 13 additions & 0 deletions src/numpy-stubs/@test/static/reject/ma.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,16 @@ np.amin(m, axis=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgu
np.amin(m, keepdims=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
np.amin(m, out=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
np.amin(m, fill_value=lambda x: 27) # type: ignore[call-overload] # pyright: ignore[reportCallIssue, reportUnknownLambdaType]

m.argmin(axis=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
m.argmin(keepdims=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
m.argmin(out=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
m.argmin(fill_value=lambda x: 27) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue, reportUnknownLambdaType]

np.ma.argmin(m, axis=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
np.ma.argmin(m, axis=(1,)) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
np.ma.argmin(m, keepdims=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
np.ma.argmin(m, out=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
np.ma.argmin(m, fill_value=lambda x: 27) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue, reportUnknownLambdaType]

m.argmax(axis=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
240 changes: 195 additions & 45 deletions src/numpy-stubs/ma/core.pyi
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from _typeshed import Incomplete
from typing import Any, ClassVar, Final, Generic, Literal as L, SupportsIndex as CanIndex, TypeAlias, type_check_only
from typing_extensions import Never, Self, TypeVar, deprecated, overload, override
from typing_extensions import Never, Protocol, Self, TypeVar, deprecated, overload, override

import numpy as np
from _numtype import Array, ToGeneric_0d, ToGeneric_1nd, ToGeneric_nd
from numpy import _OrderACF, _OrderKACF, amax, amin, bool_, expand_dims # noqa: ICN003
from numpy import _OrderACF, _OrderKACF, amax, amin, bool_, expand_dims, intp # noqa: ICN003
from numpy._globals import _NoValueType
from numpy._typing import ArrayLike, _ArrayLike, _BoolCodes, _ScalarLike_co, _ShapeLike

Expand Down Expand Up @@ -189,7 +189,12 @@ __all__ = [
"zeros_like",
]

###

_ArrayT = TypeVar("_ArrayT", bound=np.ndarray[Any, Any])

###

_UFuncT_co = TypeVar("_UFuncT_co", bound=np.ufunc, default=np.ufunc, covariant=True)
_ScalarT = TypeVar("_ScalarT", bound=np.generic)
_ShapeT = TypeVar("_ShapeT", bound=tuple[int, ...])
Expand All @@ -198,6 +203,7 @@ _DTypeT = TypeVar("_DTypeT", bound=np.dtype)
_DTypeT_co = TypeVar("_DTypeT_co", bound=np.dtype, default=np.dtype, covariant=True)

_DTypeLikeBool: TypeAlias = type[bool | np.bool] | np.dtype[np.bool] | _BoolCodes
_ScalarT_co = TypeVar("_ScalarT_co", bound=np.generic, covariant=True)

###

Expand Down Expand Up @@ -647,15 +653,44 @@ class MaskedArray(np.ndarray[_ShapeT_co, _DTypeT_co]):
fill_value: Incomplete = ...,
keepdims: Incomplete = ...,
) -> Incomplete: ...
@override
def argmin( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]

# Keep in-sync with np.ma.argmin
@overload # type: ignore[override]
def argmin(
self,
axis: Incomplete = ...,
fill_value: Incomplete = ...,
out: Incomplete = ...,
axis: None = None,
fill_value: _ScalarLike_co | None = None,
out: None = None,
*,
keepdims: Incomplete = ...,
) -> Incomplete: ...
keepdims: L[False] | _NoValueType = ...,
) -> intp: ...
@overload
def argmin(
self,
axis: CanIndex | None = None,
fill_value: _ScalarLike_co | None = None,
out: None = None,
*,
keepdims: bool | _NoValueType = ...,
) -> Any: ...
@overload
def argmin(
self,
axis: CanIndex | None = None,
fill_value: _ScalarLike_co | None = None,
*,
out: _ArrayT,
keepdims: bool | _NoValueType = ...,
) -> _ArrayT: ...
@overload
def argmin( # pyright: ignore[reportIncompatibleMethodOverride]
self,
axis: CanIndex | None,
fill_value: _ScalarLike_co | None,
out: _ArrayT,
*,
keepdims: bool | _NoValueType = ...,
) -> _ArrayT: ...

#
@override
Expand All @@ -666,15 +701,44 @@ class MaskedArray(np.ndarray[_ShapeT_co, _DTypeT_co]):
fill_value: Incomplete = ...,
keepdims: Incomplete = ...,
) -> Incomplete: ...
@override
def argmax( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]

# Keep in-sync with np.ma.argmax
@overload # type: ignore[override]
def argmax(
self,
axis: Incomplete = ...,
fill_value: Incomplete = ...,
out: Incomplete = ...,
axis: None = None,
fill_value: _ScalarLike_co | None = None,
out: None = None,
*,
keepdims: Incomplete = ...,
) -> Incomplete: ...
keepdims: L[False] | _NoValueType = ...,
) -> intp: ...
@overload
def argmax(
self,
axis: CanIndex | None = None,
fill_value: _ScalarLike_co | None = None,
out: None = None,
*,
keepdims: bool | _NoValueType = ...,
) -> Any: ...
@overload
def argmax(
self,
axis: CanIndex | None = None,
fill_value: _ScalarLike_co | None = None,
*,
out: _ArrayT,
keepdims: bool | _NoValueType = ...,
) -> _ArrayT: ...
@overload
def argmax( # pyright: ignore[reportIncompatibleMethodOverride]
self,
axis: CanIndex | None,
fill_value: _ScalarLike_co | None,
out: _ArrayT,
*,
keepdims: bool | _NoValueType = ...,
) -> _ArrayT: ...

#
@override
Expand Down Expand Up @@ -792,12 +856,99 @@ class MaskedConstant(MaskedArray[tuple[()], np.dtype[np.float64]]):
@override
def copy(self, /, *args: object, **kwargs: object) -> Incomplete: ...

class _frommethod:
class _frommethod(Protocol[_ScalarT_co]):
__name__: str
__doc__: str
reversed: Incomplete
def __init__(self, methodname: Incomplete, reversed: Incomplete = ...) -> None: ...
@overload
def __call__(
self,
a: _ArrayT, # pyright: ignore[reportInvalidTypeVarUse]
axis: None = None,
fill_value: _ScalarLike_co | None = None,
out: None = None,
*,
keepdims: L[False] | _NoValueType = ...,
) -> _ScalarT_co: ...
@overload
def __call__(
self,
a: _ArrayT, # pyright: ignore[reportInvalidTypeVarUse]
axis: CanIndex | None = None,
fill_value: _ScalarLike_co | None = None,
out: None = None,
*,
keepdims: bool | _NoValueType = ...,
) -> Any: ...
@overload
def __call__(
self,
a: _ArrayT,
axis: CanIndex | None = None,
fill_value: _ScalarLike_co | None = None,
*,
out: _ArrayT,
keepdims: bool | _NoValueType = ...,
) -> _ArrayT: ...
@overload
def __call__(
self,
a: _ArrayT,
axis: CanIndex | None,
fill_value: _ScalarLike_co | None,
out: _ArrayT,
*,
keepdims: bool | _NoValueType = ...,
) -> _ArrayT: ...
def getdoc(self) -> Incomplete: ...

@type_check_only
class _ArgMinMaxMethod:
__name__: str
__doc__: str
reversed: Incomplete
def __init__(self, methodname: Incomplete, reversed: Incomplete = ...) -> None: ...
def __call__(self, a: Incomplete, *args: Incomplete, **params: Incomplete) -> Incomplete: ...
@overload
def __call__(
self,
a: _ArrayT, # pyright: ignore[reportInvalidTypeVarUse]
axis: None = None,
fill_value: _ScalarLike_co | None = None,
out: None = None,
*,
keepdims: L[False] | _NoValueType = ...,
) -> intp: ...
@overload
def __call__(
self,
a: _ArrayT, # pyright: ignore[reportInvalidTypeVarUse]
axis: CanIndex | None = None,
fill_value: _ScalarLike_co | None = None,
out: None = None,
*,
keepdims: bool | _NoValueType = ...,
) -> Any: ...
@overload
def __call__(
self,
a: _ArrayT,
axis: CanIndex | None = None,
fill_value: _ScalarLike_co | None = None,
*,
out: _ArrayT,
keepdims: bool | _NoValueType = ...,
) -> _ArrayT: ...
@overload
def __call__(
self,
a: _ArrayT,
axis: CanIndex | None,
fill_value: _ScalarLike_co | None,
out: _ArrayT,
*,
keepdims: bool | _NoValueType = ...,
) -> _ArrayT: ...
def getdoc(self) -> Incomplete: ...

class _convert2ma:
Expand Down Expand Up @@ -1067,32 +1218,31 @@ squeeze: _convert2ma
zeros: _convert2ma
zeros_like: _convert2ma

all: _frommethod
anomalies: _frommethod
anom: _frommethod
any: _frommethod
compress: _frommethod
cumprod: _frommethod
cumsum: _frommethod
copy: _frommethod
diagonal: _frommethod
harden_mask: _frommethod
ids: _frommethod
mean: _frommethod
nonzero: _frommethod
prod: _frommethod
product: _frommethod
ravel: _frommethod
repeat: _frommethod
soften_mask: _frommethod
std: _frommethod
sum: _frommethod
swapaxes: _frommethod
trace: _frommethod
var: _frommethod
count: _frommethod
argmin: _frommethod
argmax: _frommethod

all: _frommethod[Any]
anomalies: _frommethod[Any]
anom: _frommethod[Any]
any: _frommethod[Any]
compress: _frommethod[Any]
cumprod: _frommethod[Any]
cumsum: _frommethod[Any]
copy: _frommethod[Any]
diagonal: _frommethod[Any]
harden_mask: _frommethod[Any]
ids: _frommethod[Any]
mean: _frommethod[Any]
nonzero: _frommethod[Any]
prod: _frommethod[Any]
product: _frommethod[Any]
ravel: _frommethod[Any]
repeat: _frommethod[Any]
soften_mask: _frommethod[Any]
std: _frommethod[Any]
sum: _frommethod[Any]
swapaxes: _frommethod[Any]
trace: _frommethod[Any]
var: _frommethod[Any]
count: _frommethod[Any]
argmin: _frommethod[intp]
argmax: _frommethod[intp]
minimum: _extrema_operation
maximum: _extrema_operation