Signatures of get_xp-wrapped functions? #153

Open
mdhaber opened this issue Jun 21, 2024 · 3 comments

@mdhaber

mdhaber commented Jun 21, 2024

There seem to be some issues with the signatures of functions wrapped by get_xp.
I haven't narrowed down the exact problem, but here's an MRE:

```python
import cupy as xp
from array_api_compat import cupy as xp_compat

A = xp.eye(3)
A = xp.asarray(A)

xp.linalg.eigh(A)  # fine
xp.linalg.eigh(a=A)  # fine

xp_compat.linalg.eigh(A)  # fine
xp_compat.linalg.eigh(a=A)  # error
# TypeError: eigh() missing 1 required positional argument: 'x'
```

Also, for example:

```python
xp.linalg.eigh(A, 'U')  # fine
xp_compat.linalg.eigh(A, 'U')  # error
# TypeError: eigh() got multiple values for argument 'xp'
```
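Both failures can be reproduced without CuPy. The sketch below is a minimal stand-in, assuming (as the second error message suggests) that `get_xp` calls the wrapped function as `f(*args, xp=xp, **kwargs)`; `get_xp_like`, `fake_xp`, and `try_call` are hypothetical names, not array_api_compat APIs.

```python
import functools
import types

def get_xp_like(xp):
    """Hypothetical stand-in for get_xp: calls f(*args, xp=xp, **kwargs)."""
    def inner(f):
        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            return f(*args, xp=xp, **kwargs)
        return wrapped
    return inner

def _backend_eigh(a, UPLO='L'):
    """Fake backend eigh with NumPy/CuPy-style parameters; echoes its inputs."""
    return (a, UPLO)

fake_xp = types.SimpleNamespace(linalg=types.SimpleNamespace(eigh=_backend_eigh))

# Same shape as the current compat eigh: **kwargs is forwarded, *args is not.
@get_xp_like(fake_xp)
def eigh(x, /, xp, **kwargs):
    return xp.linalg.eigh(x, **kwargs)

def try_call(fn, *args, **kwargs):
    """Return fn's result, or the TypeError message if the call fails."""
    try:
        return fn(*args, **kwargs)
    except TypeError as e:
        return f"TypeError: {e}"

A = "A"  # placeholder for an array
print(try_call(eigh, A))       # ('A', 'L')
print(try_call(eigh, a=A))     # TypeError: eigh() missing 1 required positional argument: 'x'
print(try_call(eigh, A, 'U'))  # TypeError: eigh() got multiple values for argument 'xp'
```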

mdhaber changed the title from "get_xp-wrapped functions must have args?" to "Signatures of get_xp-wrapped functions?" on Jun 21, 2024

@asmeurer
Member

The problem is straightforward. `eigh` is defined as

```python
def eigh(x: ndarray, /, xp, **kwargs) -> EighResult:
    return EighResult(*xp.linalg.eigh(x, **kwargs))
```

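In other words, it passes `**kwargs` through but doesn't pass `*args` through. Since `get_xp` supplies `xp` as a keyword argument, an extra positional argument such as `'U'` gets bound to the `xp` parameter, which is where "got multiple values for argument 'xp'" comes from. This is easy enough to fix: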

```diff
diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py
index dc2b69d..01db3a0 100644
--- a/array_api_compat/common/_linalg.py
+++ b/array_api_compat/common/_linalg.py
@@ -42,8 +42,8 @@ class SVDResult(NamedTuple):

 # These functions are the same as their NumPy counterparts except they return
 # a namedtuple.
-def eigh(x: ndarray, /, xp, **kwargs) -> EighResult:
-    return EighResult(*xp.linalg.eigh(x, **kwargs))
+def eigh(x: ndarray, /, *args, xp, **kwargs) -> EighResult:
+    return EighResult(*xp.linalg.eigh(x, *args, **kwargs))

 def qr(x: ndarray, /, xp, *, mode: Literal['reduced', 'complete'] = 'reduced',
        **kwargs) -> QRResult:
```

We just need to do this for every function for which NumPy supports more positional arguments than the standard.
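A minimal sketch of the patched `eigh` from the diff, reusing hypothetical stand-ins (`get_xp_like`, and a fake backend whose `eigh` takes NumPy/CuPy-style `a, UPLO='L'` and echoes its inputs). Moving `*args` in front of `xp` makes `xp` keyword-only, so a positional `'U'` now reaches the backend instead of colliding with `xp`.

```python
import functools
import types

def get_xp_like(xp):
    """Hypothetical stand-in for get_xp: calls f(*args, xp=xp, **kwargs)."""
    def inner(f):
        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            return f(*args, xp=xp, **kwargs)
        return wrapped
    return inner

def _backend_eigh(a, UPLO='L'):
    """Fake backend eigh that just reports what it received."""
    return (a, UPLO)

fake_xp = types.SimpleNamespace(linalg=types.SimpleNamespace(eigh=_backend_eigh))

# Patched shape: xp is keyword-only, extra positional args are forwarded.
@get_xp_like(fake_xp)
def eigh(x, /, *args, xp, **kwargs):
    return xp.linalg.eigh(x, *args, **kwargs)

A = "A"  # placeholder for an array
print(eigh(A))            # ('A', 'L')
print(eigh(A, 'U'))       # ('A', 'U')
print(eigh(A, UPLO='U'))  # ('A', 'U')
```

The keyword call `eigh(a=A)` still raises, since `x` stays positional-only; that is the separate problem discussed next.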

As for the other issue, with `eigh(a=A)` you are passing by keyword the argument that is positional-only in the standard.

We could make this work by instead defining all functions like this:

```diff
diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py
index dc2b69d..11b54bb 100644
--- a/array_api_compat/common/_linalg.py
+++ b/array_api_compat/common/_linalg.py
@@ -42,8 +42,8 @@ class SVDResult(NamedTuple):

 # These functions are the same as their NumPy counterparts except they return
 # a namedtuple.
-def eigh(x: ndarray, /, xp, **kwargs) -> EighResult:
-    return EighResult(*xp.linalg.eigh(x, **kwargs))
+def eigh(*args, xp, **kwargs) -> EighResult:
+    return EighResult(*xp.linalg.eigh(*args, **kwargs))

 def qr(x: ndarray, /, xp, *, mode: Literal['reduced', 'complete'] = 'reduced',
        **kwargs) -> QRResult:
```

The downside is that this would completely kill the introspectability of the functions (right now I have it set up so that `help(array_api_compat.numpy.linalg.eigh)` shows the actual arguments).
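To illustrate the introspection cost, the sketch below compares the two definitions with `inspect.signature`, after dropping `xp` the way a decorator that hides it from callers might. How the reported signature is actually built is an assumption here; `drop_xp`, `eigh_explicit`, and `eigh_opaque` are hypothetical names.

```python
import inspect

def drop_xp(f):
    """Signature of f with the xp parameter removed (assumed decorator behavior)."""
    sig = inspect.signature(f)
    return sig.replace(parameters=[p for p in sig.parameters.values() if p.name != "xp"])

# Explicit definition: the array argument stays visible to help()/inspect.
def eigh_explicit(x, /, *args, xp, **kwargs): ...

# Fully generic definition from the second diff: nothing useful to show.
def eigh_opaque(*args, xp, **kwargs): ...

print(drop_xp(eigh_explicit))  # (x, /, *args, **kwargs)
print(drop_xp(eigh_opaque))    # (*args, **kwargs)
```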

Note that both of these examples are not portable with the standard, which defines the signature as

```python
eigh(x: array, /)
```
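As a quick check, a stand-in with the standard's signature (annotations omitted, body a placeholder) rejects both calls from the original report through ordinary Python argument binding. `rejects` is a hypothetical helper.

```python
def eigh(x, /):
    """Stand-in with the standard's signature eigh(x: array, /); just returns x."""
    return x

def rejects(call):
    """True if call() raises TypeError."""
    try:
        call()
    except TypeError:
        return True
    return False

print(rejects(lambda: eigh(x="A")))     # True: x is positional-only
print(rejects(lambda: eigh(a="A")))     # True: no parameter named a
print(rejects(lambda: eigh("A", "U")))  # True: only one positional parameter
print(rejects(lambda: eigh("A")))       # False
```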

@mdhaber
Author

mdhaber commented Jun 24, 2024

> Note that both of these examples are not portable with the standard

Yeah, the thing is that this came up in the context of dispatching calls to SciPy's eigh function to other backends. It's unclear ATM what we want to do when the SciPy function has a much more flexible signature (including many other arguments) than the standard.

That said, my impression from here was that array_api_compat did not intend to limit capabilities of the wrapped libraries to those of the standard, so I went ahead and reported it.

@asmeurer
Member

Yes, in principle we should support this. Maybe I can modify the get_xp decorator to keep the standard signature for introspection purposes, but always pass through *args and **kwargs automatically. I'll need to think a bit about it.
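One possible shape for that idea, as a sketch only: a hypothetical `get_xp_passthrough` decorator that forwards every argument to a fully generic function but sets `__signature__` so that `inspect.signature` (which `help()` also relies on) reports the standard signature. The fake backend is an illustrative stand-in; none of these names come from array_api_compat.

```python
import functools
import inspect
import types

def get_xp_passthrough(xp, standard_sig):
    """Hypothetical get_xp variant: forwards all args, advertises standard_sig."""
    def inner(f):
        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            return f(*args, xp=xp, **kwargs)
        # inspect.signature checks __signature__ before following __wrapped__.
        wrapped.__signature__ = standard_sig
        return wrapped
    return inner

def _standard_eigh(x, /): ...  # the standard's eigh(x: array, /), unannotated

def _backend_eigh(a, UPLO='L'):
    """Fake backend eigh that echoes its inputs."""
    return (a, UPLO)

fake_xp = types.SimpleNamespace(linalg=types.SimpleNamespace(eigh=_backend_eigh))

@get_xp_passthrough(fake_xp, inspect.signature(_standard_eigh))
def eigh(*args, xp, **kwargs):
    return xp.linalg.eigh(*args, **kwargs)

print(inspect.signature(eigh))  # (x, /)
print(eigh("A", "U"))           # ('A', 'U')
print(eigh(a="A"))              # ('A', 'L')
```

The trade-off is that the advertised signature is narrower than what the function actually accepts, so backend-specific extras such as `UPLO` work but are not discoverable through `help()`.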
