Signatures of get_xp-wrapped functions? #153
(Earlier title: get_xp-wrapped functions must have args?)
The problem is straightforward. See lines 45 to 46 of array-api-compat/array_api_compat/common/_linalg.py at commit ac15c52 (the eigh definition).
In other words, it passes …
diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py
index dc2b69d..01db3a0 100644
--- a/array_api_compat/common/_linalg.py
+++ b/array_api_compat/common/_linalg.py
@@ -42,8 +42,8 @@ class SVDResult(NamedTuple):
 # These functions are the same as their NumPy counterparts except they return
 # a namedtuple.
 
-def eigh(x: ndarray, /, xp, **kwargs) -> EighResult:
-    return EighResult(*xp.linalg.eigh(x, **kwargs))
+def eigh(x: ndarray, /, *args, xp, **kwargs) -> EighResult:
+    return EighResult(*xp.linalg.eigh(x, *args, **kwargs))
 
 def qr(x: ndarray, /, xp, *, mode: Literal['reduced', 'complete'] = 'reduced',
        **kwargs) -> QRResult:
We just need to do this for every function for which NumPy supports more positional arguments than the standard.
As for the other issue, we could make this work by instead defining all functions like
diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py
index dc2b69d..11b54bb 100644
--- a/array_api_compat/common/_linalg.py
+++ b/array_api_compat/common/_linalg.py
@@ -42,8 +42,8 @@ class SVDResult(NamedTuple):
 # These functions are the same as their NumPy counterparts except they return
 # a namedtuple.
 
-def eigh(x: ndarray, /, xp, **kwargs) -> EighResult:
-    return EighResult(*xp.linalg.eigh(x, **kwargs))
+def eigh(*args, xp, **kwargs) -> EighResult:
+    return EighResult(*xp.linalg.eigh(*args, **kwargs))
 
 def qr(x: ndarray, /, xp, *, mode: Literal['reduced', 'complete'] = 'reduced',
        **kwargs) -> QRResult:
The downside of this is that it would completely kill the introspectability of functions (right now I have it set up so that …).
Note that neither of these examples is portable under the standard, which defines the signature as …
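To make the collision concrete, here is a minimal sketch, assuming the wrapper calls the decorated function with the namespace injected as the keyword argument xp. The names get_xp_sketch, eigh_old and eigh_new are hypothetical, and the decorator is a simplified stand-in, not array-api-compat's actual code. With the original signature, a second positional argument (here NumPy's UPLO for eigh) binds to xp and then clashes with the injected keyword. With the *args signature from the first diff, xp becomes keyword-only and the extra argument is forwarded to NumPy.

```python
from functools import wraps

import numpy as np


def get_xp_sketch(xp):
    """Hypothetical, simplified get_xp-style decorator: it calls the
    wrapped function with the array namespace passed as keyword ``xp``."""
    def inner(f):
        @wraps(f)
        def wrapped(*args, **kwargs):
            return f(*args, xp=xp, **kwargs)
        return wrapped
    return inner


# Original style: a second positional argument lands in the ``xp`` slot.
@get_xp_sketch(np)
def eigh_old(x, /, xp, **kwargs):
    return xp.linalg.eigh(x, **kwargs)


# Style from the first diff: *args absorbs extra positional arguments,
# which makes ``xp`` keyword-only, so it can no longer collide.
@get_xp_sketch(np)
def eigh_new(x, /, *args, xp, **kwargs):
    return xp.linalg.eigh(x, *args, **kwargs)


a = np.array([[2.0, 1.0], [1.0, 2.0]])

try:
    # "U" binds to ``xp`` positionally, then xp=np is passed again.
    eigh_old(a, "U")
except TypeError as exc:
    print("old signature:", exc)

# "U" is forwarded to numpy.linalg.eigh as its UPLO argument.
w, v = eigh_new(a, "U")
print("new signature eigenvalues:", w)
```

The old-style call fails with Python's "got multiple values for argument 'xp'" TypeError, which is the same class of failure regardless of which NumPy-only positional argument is passed.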
Yeah, the thing is that this came up in the context of dispatching calls to SciPy's … That said, my impression from here was that …
Yes, in principle we should support this. Maybe I can modify the …
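On the introspectability trade-off raised earlier in the thread: a wrapper that injects xp can still report a useful signature by setting __signature__, which inspect.signature() honors. The sketch below assumes that approach; get_xp_with_signature, eigh_explicit and eigh_star are hypothetical names, and the decorator is not the library's actual implementation. It shows why the first diff keeps the named x parameter visible, while the second diff leaves nothing but (*args, **kwargs) to report.

```python
import inspect
from functools import wraps

import numpy as np


def get_xp_with_signature(xp):
    """Hypothetical decorator: injects ``xp`` as a keyword argument and
    publishes the wrapped function's signature with ``xp`` removed."""
    def inner(f):
        @wraps(f)
        def wrapped(*args, **kwargs):
            return f(*args, xp=xp, **kwargs)

        sig = inspect.signature(f)
        params = [p for name, p in sig.parameters.items() if name != "xp"]
        # inspect.signature() honors __signature__, so help() and other
        # inspect-based tools see the user-facing parameters.
        wrapped.__signature__ = sig.replace(parameters=params)
        return wrapped
    return inner


# First diff: the named, positional-only x survives in the signature.
@get_xp_with_signature(np)
def eigh_explicit(x, /, *args, xp, **kwargs):
    return xp.linalg.eigh(x, *args, **kwargs)


# Second diff: only *args and **kwargs are left to report.
@get_xp_with_signature(np)
def eigh_star(*args, xp, **kwargs):
    return xp.linalg.eigh(*args, **kwargs)


print(inspect.signature(eigh_explicit))  # (x, /, *args, **kwargs)
print(inspect.signature(eigh_star))      # (*args, **kwargs)
```

Both wrapped functions still accept NumPy's extra positional arguments; the difference is only in what a caller can learn from the reported signature.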
There seem to be some issues with the signatures of functions wrapped by get_xp. I haven't narrowed down the exact problem, but here's an MRE: …
Also, e.g. …