Skip to content

Commit 0bd5801

Browse files
committed
🚚 port ma.arg{min,max} and MaskedArray.arg{min,max}
1 parent 79b5462 commit 0bd5801

File tree

4 files changed

+224
-31
lines changed

4 files changed

+224
-31
lines changed

‎src/numpy-stubs/@test/static/accept/ma.pyi

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,45 @@ assert_type(m.dtype, np.dtype[np.float64])
1010

1111
assert_type(int(m), int)
1212
assert_type(float(m), float)
13+
ScalarType_co = TypeVar("_ScalarType_co", bound=np.generic, covariant=True)
14+
MaskedNDArray: TypeAlias = np.ma.MaskedArray[_Shape, np.dtype[_ScalarType_co]]
15+
16+
class MaskedNDArraySubclass(MaskedNDArray[np.complex128]): ...
17+
18+
MAR_b: MaskedNDArray[np.bool]
19+
MAR_f4: MaskedNDArray[np.float32]
20+
MAR_i8: MaskedNDArray[np.int64]
21+
MAR_subclass: MaskedNDArraySubclass
22+
MAR_1d: np.ma.MaskedArray[tuple[int], np.dtype[Any]]
23+
24+
assert_type(MAR_b.argmin(), np.intp)
25+
assert_type(MAR_f4.argmin(), np.intp)
26+
assert_type(MAR_f4.argmax(fill_value=math.tau, keepdims=False), np.intp)
27+
assert_type(MAR_b.argmin(axis=0), Any)
28+
assert_type(MAR_f4.argmin(axis=0), Any)
29+
assert_type(MAR_b.argmin(keepdims=True), Any)
30+
assert_type(MAR_f4.argmin(out=MAR_subclass), MaskedNDArraySubclass)
31+
assert_type(MAR_f4.argmin(None, None, out=MAR_subclass), MaskedNDArraySubclass)
32+
33+
assert_type(np.ma.argmin(MAR_b), np.intp)
34+
assert_type(np.ma.argmin(MAR_f4), np.intp)
35+
assert_type(np.ma.argmin(MAR_f4, fill_value=math.tau, keepdims=False), np.intp)
36+
assert_type(np.ma.argmin(MAR_b, axis=0), Any)
37+
assert_type(np.ma.argmin(MAR_f4, axis=0), Any)
38+
assert_type(np.ma.argmin(MAR_b, keepdims=True), Any)
39+
40+
assert_type(MAR_b.argmax(), np.intp)
41+
assert_type(MAR_f4.argmax(), np.intp)
42+
assert_type(MAR_f4.argmax(fill_value=math.tau, keepdims=False), np.intp)
43+
assert_type(MAR_b.argmax(axis=0), Any)
44+
assert_type(MAR_f4.argmax(axis=0), Any)
45+
assert_type(MAR_b.argmax(keepdims=True), Any)
46+
assert_type(MAR_f4.argmax(out=MAR_subclass), MaskedNDArraySubclass)
47+
assert_type(MAR_f4.argmax(None, None, out=MAR_subclass), MaskedNDArraySubclass)
48+
49+
assert_type(np.ma.argmax(MAR_b), np.intp)
50+
assert_type(np.ma.argmax(MAR_f4), np.intp)
51+
assert_type(np.ma.argmax(MAR_f4, fill_value=math.tau, keepdims=False), np.intp)
52+
assert_type(np.ma.argmax(MAR_b, axis=0), Any)
53+
assert_type(np.ma.argmax(MAR_f4, axis=0), Any)
54+
assert_type(np.ma.argmax(MAR_b, keepdims=True), Any)

‎src/numpy-stubs/@test/static/reject/ma.pyi

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,16 @@ m: np.ma.MaskedArray[tuple[int], np.dtype[np.float64]]
44

55
m.shape = (3, 1) # type: ignore[assignment]
66
m.dtype = np.bool # type: ignore[assignment] # pyright: ignore[reportAttributeAccessIssue]
7+
8+
m.argmin(axis=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
9+
m.argmin(keepdims=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
10+
m.argmin(out=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
11+
m.argmin(fill_value=lambda x: 27) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue, reportUnknownLambdaType]
12+
13+
np.ma.argmin(m, axis=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
14+
np.ma.argmin(m, axis=(1,)) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
15+
np.ma.argmin(m, keepdims=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
16+
np.ma.argmin(m, out=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
17+
np.ma.argmin(m, fill_value=lambda x: 27) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue, reportUnknownLambdaType]
18+
19+
m.argmax(axis=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]

‎src/numpy-stubs/ma/core.pyi

Lines changed: 157 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@ from typing_extensions import Never, Self, TypeVar, deprecated, overload, overri
44

55
import numpy as np
66
from _numtype import Array, ToGeneric_0d, ToGeneric_1nd, ToGeneric_nd
7-
from numpy import _OrderACF, _OrderKACF, amax, amin, bool_, expand_dims # noqa: ICN003
8-
from numpy._typing import _BoolCodes
7+
from numpy import _OrderACF, _OrderKACF, amax, amin, bool_, expand_dims, intp # noqa: ICN003
8+
from numpy._globals import _NoValueType
9+
from numpy._typing import _BoolCodes, _ScalarLike_co
910

1011
__all__ = [
1112
"MAError",
@@ -188,6 +189,12 @@ __all__ = [
188189
"zeros_like",
189190
]
190191

192+
###
193+
194+
_ArrayT = TypeVar("_ArrayT", bound=np.ndarray[Any, Any])
195+
196+
###
197+
191198
_UFuncT_co = TypeVar("_UFuncT_co", bound=np.ufunc, default=np.ufunc, covariant=True)
192199
_ShapeT = TypeVar("_ShapeT", bound=tuple[int, ...])
193200
_ShapeT_co = TypeVar("_ShapeT_co", bound=tuple[int, ...], default=tuple[int, ...], covariant=True)
@@ -644,15 +651,44 @@ class MaskedArray(np.ndarray[_ShapeT_co, _DTypeT_co]):
644651
fill_value: Incomplete = ...,
645652
keepdims: Incomplete = ...,
646653
) -> Incomplete: ...
647-
@override
648-
def argmin( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
654+
655+
# Keep in-sync with np.ma.argmin
656+
@overload # type: ignore[override]
657+
def argmin(
649658
self,
650-
axis: Incomplete = ...,
651-
fill_value: Incomplete = ...,
652-
out: Incomplete = ...,
659+
axis: None = None,
660+
fill_value: _ScalarLike_co | None = None,
661+
out: None = None,
653662
*,
654-
keepdims: Incomplete = ...,
655-
) -> Incomplete: ...
663+
keepdims: L[False] | _NoValueType = ...,
664+
) -> intp: ...
665+
@overload
666+
def argmin(
667+
self,
668+
axis: CanIndex | None = None,
669+
fill_value: _ScalarLike_co | None = None,
670+
out: None = None,
671+
*,
672+
keepdims: bool | _NoValueType = ...,
673+
) -> Any: ...
674+
@overload
675+
def argmin(
676+
self,
677+
axis: CanIndex | None = None,
678+
fill_value: _ScalarLike_co | None = None,
679+
*,
680+
out: _ArrayT,
681+
keepdims: bool | _NoValueType = ...,
682+
) -> _ArrayT: ...
683+
@overload
684+
def argmin( # pyright: ignore[reportIncompatibleMethodOverride]
685+
self,
686+
axis: CanIndex | None,
687+
fill_value: _ScalarLike_co | None,
688+
out: _ArrayT,
689+
*,
690+
keepdims: bool | _NoValueType = ...,
691+
) -> _ArrayT: ...
656692

657693
#
658694
@override
@@ -663,15 +699,44 @@ class MaskedArray(np.ndarray[_ShapeT_co, _DTypeT_co]):
663699
fill_value: Incomplete = ...,
664700
keepdims: Incomplete = ...,
665701
) -> Incomplete: ...
666-
@override
667-
def argmax( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
702+
703+
# Keep in-sync with np.ma.argmax
704+
@overload # type: ignore[override]
705+
def argmax(
668706
self,
669-
axis: Incomplete = ...,
670-
fill_value: Incomplete = ...,
671-
out: Incomplete = ...,
707+
axis: None = None,
708+
fill_value: _ScalarLike_co | None = None,
709+
out: None = None,
672710
*,
673-
keepdims: Incomplete = ...,
674-
) -> Incomplete: ...
711+
keepdims: L[False] | _NoValueType = ...,
712+
) -> intp: ...
713+
@overload
714+
def argmax(
715+
self,
716+
axis: CanIndex | None = None,
717+
fill_value: _ScalarLike_co | None = None,
718+
out: None = None,
719+
*,
720+
keepdims: bool | _NoValueType = ...,
721+
) -> Any: ...
722+
@overload
723+
def argmax(
724+
self,
725+
axis: CanIndex | None = None,
726+
fill_value: _ScalarLike_co | None = None,
727+
*,
728+
out: _ArrayT,
729+
keepdims: bool | _NoValueType = ...,
730+
) -> _ArrayT: ...
731+
@overload
732+
def argmax( # pyright: ignore[reportIncompatibleMethodOverride]
733+
self,
734+
axis: CanIndex | None,
735+
fill_value: _ScalarLike_co | None,
736+
out: _ArrayT,
737+
*,
738+
keepdims: bool | _NoValueType = ...,
739+
) -> _ArrayT: ...
675740

676741
#
677742
@override
@@ -1060,8 +1125,81 @@ swapaxes: _frommethod
10601125
trace: _frommethod
10611126
var: _frommethod
10621127
count: _frommethod
1063-
argmin: _frommethod
1064-
argmax: _frommethod
1065-
10661128
minimum: _extrema_operation
10671129
maximum: _extrema_operation
1130+
1131+
#
1132+
@overload
1133+
def argmin(
1134+
a: _ArrayT, # pyright: ignore[reportInvalidTypeVarUse]
1135+
axis: None = None,
1136+
fill_value: _ScalarLike_co | None = None,
1137+
out: None = None,
1138+
*,
1139+
keepdims: L[False] | _NoValueType = ...,
1140+
) -> intp: ...
1141+
@overload
1142+
def argmin(
1143+
a: _ArrayT, # pyright: ignore[reportInvalidTypeVarUse]
1144+
axis: CanIndex | None = None,
1145+
fill_value: _ScalarLike_co | None = None,
1146+
out: None = None,
1147+
*,
1148+
keepdims: bool | _NoValueType = ...,
1149+
) -> Any: ...
1150+
@overload
1151+
def argmin(
1152+
a: _ArrayT,
1153+
axis: CanIndex | None = None,
1154+
fill_value: _ScalarLike_co | None = None,
1155+
*,
1156+
out: _ArrayT,
1157+
keepdims: bool | _NoValueType = ...,
1158+
) -> _ArrayT: ...
1159+
@overload
1160+
def argmin(
1161+
a: _ArrayT,
1162+
axis: CanIndex | None,
1163+
fill_value: _ScalarLike_co | None,
1164+
out: _ArrayT,
1165+
*,
1166+
keepdims: bool | _NoValueType = ...,
1167+
) -> _ArrayT: ...
1168+
1169+
#
1170+
@overload
1171+
def argmax(
1172+
a: _ArrayT, # pyright: ignore[reportInvalidTypeVarUse]
1173+
axis: None = None,
1174+
fill_value: _ScalarLike_co | None = None,
1175+
out: None = None,
1176+
*,
1177+
keepdims: L[False] | _NoValueType = ...,
1178+
) -> intp: ...
1179+
@overload
1180+
def argmax(
1181+
a: _ArrayT, # pyright: ignore[reportInvalidTypeVarUse]
1182+
axis: CanIndex | None = None,
1183+
fill_value: _ScalarLike_co | None = None,
1184+
out: None = None,
1185+
*,
1186+
keepdims: bool | _NoValueType = ...,
1187+
) -> Any: ...
1188+
@overload
1189+
def argmax(
1190+
a: _ArrayT,
1191+
axis: CanIndex | None = None,
1192+
fill_value: _ScalarLike_co | None = None,
1193+
*,
1194+
out: _ArrayT,
1195+
keepdims: bool | _NoValueType = ...,
1196+
) -> _ArrayT: ...
1197+
@overload
1198+
def argmax(
1199+
a: _ArrayT,
1200+
axis: CanIndex | None,
1201+
fill_value: _ScalarLike_co | None,
1202+
out: _ArrayT,
1203+
*,
1204+
keepdims: bool | _NoValueType = ...,
1205+
) -> _ArrayT: ...

‎uv.lock

Lines changed: 12 additions & 12 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)