@@ -27,6 +27,8 @@ from typing import (
2727)
2828from typing_extensions import Buffer , CapsuleType , LiteralString , Never , Protocol , Self , TypeVar , Unpack , deprecated , override
2929
30+ import numpy as np
31+
3032from . import (
3133 __config__ as __config__ ,
3234 _array_api_info as _array_api_info ,
@@ -611,6 +613,8 @@ _DT64ItemT = TypeVar("_DT64ItemT", bound=dt.date | int | None)
611613_DT64ItemT_co = TypeVar ("_DT64ItemT_co" , bound = dt .date | int | None , default = dt .date | int | None , covariant = True )
612614_TD64UnitT = TypeVar ("_TD64UnitT" , bound = _TD64Unit , default = _TD64Unit )
613615
616+ _Array1D : TypeAlias = np .ndarray [tuple [int ], np .dtype [_ScalarT ]]
617+
614618###
615619# Type Aliases (for internal use only)
616620
@@ -2531,8 +2535,8 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
25312535 @overload
25322536 def __imul__ (self : NDArray [object_ ], rhs : object , / ) -> ndarray [_ShapeT_co , _DTypeT_co ]: ...
25332537
2534- # TODO(jorenham): Support the "1d @ 1d -> scalar" case
2535- # https://github.com/numpy/numtype/issues/197
2538+ @ overload
2539+ def __matmul__ ( self : _Array1D [ _ScalarT ], rhs : _Array1D [ _ScalarT ], / ) -> _ScalarT : ...
25362540 @overload
25372541 def __matmul__ (self : NDArray [_NumberT ], rhs : _ArrayLikeBool_co , / ) -> NDArray [_NumberT ]: ...
25382542 @overload
@@ -2566,12 +2570,14 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
25662570 @overload
25672571 def __matmul__ (self : NDArray [bool_ | number ], rhs : _ArrayLikeNumber_co , / ) -> NDArray [Incomplete ]: ...
25682572 @overload
2569- def __matmul__ (self : NDArray [object_ ], rhs : object , / ) -> NDArray [object_ ]: ...
2573+ def __matmul__ (self : NDArray [object_ ], rhs : _ArrayLikeObject_co , / ) -> NDArray [object_ ]: ...
25702574 @overload
25712575 def __matmul__ (self , rhs : _ArrayLikeObject_co , / ) -> NDArray [object_ ]: ...
25722576
25732577 # keep in sync with __matmul__
25742578 @overload
2579+ def __rmatmul__ (self : _Array1D [_ScalarT ], rhs : _Array1D [_ScalarT ], / ) -> _ScalarT : ...
2580+ @overload
25752581 def __rmatmul__ (self : NDArray [_NumberT ], lhs : _ArrayLikeBool_co , / ) -> NDArray [_NumberT ]: ...
25762582 @overload
25772583 def __rmatmul__ (self : NDArray [bool_ ], lhs : _ArrayLike [_NumberT ], / ) -> NDArray [_NumberT ]: ...
@@ -2604,7 +2610,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
26042610 @overload
26052611 def __rmatmul__ (self : NDArray [bool_ | number ], lhs : _ArrayLikeNumber_co , / ) -> NDArray [Incomplete ]: ...
26062612 @overload
2607- def __rmatmul__ (self : NDArray [object_ ], lhs : object , / ) -> NDArray [object_ ]: ...
2613+ def __rmatmul__ (self : NDArray [object_ ], lhs : _ArrayLikeObject_co , / ) -> NDArray [object_ ]: ...
26082614 @overload
26092615 def __rmatmul__ (self , lhs : _ArrayLikeObject_co , / ) -> NDArray [object_ ]: ...
26102616
0 commit comments