feat: decorator to validate and annotate dependencies #2967
This is really nice!

But what is meant by "extras" in this case? Is it `pip install awkward[extras]`? That would likely be obscure to most users.
I think it's useful for both the documentation and the exception message to include sample pip-install and conda-install text that can be copy-pasted into a terminal to actually install the necessary packages. There are situations in which I'll deliberately do something wrong if I know that the error message has what I need to do in a copy-pasteable form—it beats typing it manually or remembering its spelling.
I don't think any of these libraries will need a mapping between PyPI name and conda name, but if such a thing becomes necessary in the future, it can be a dict. (Uproot optionally requires `xxhash`, which is `python-xxhash` in conda-forge, for disambiguation.)
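The copy-pasteable install hint described above, together with an optional PyPI-to-conda dict, could look roughly like the following sketch. The names `PYPI_TO_CONDA` and `install_message` are hypothetical, not part of the PR; the only known PyPI/conda mismatch is the xxhash example from Uproot.

```python
# Hypothetical mapping from PyPI name to conda-forge name, used only when
# the two differ (the common case is that they are identical).
PYPI_TO_CONDA = {
    "xxhash": "python-xxhash",
}

def install_message(packages):
    """Format pip and conda commands that can be pasted into a terminal."""
    pip_names = " ".join(packages)
    conda_names = " ".join(PYPI_TO_CONDA.get(p, p) for p in packages)
    return (
        f"    pip install {pip_names}\n"
        "or\n"
        f"    conda install -c conda-forge {conda_names}"
    )

print(install_message(["pyarrow", "xxhash"]))
```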
I made a new environment without the optional dependencies to test it out.
```
>>> ak.to_arrow(ak.Array([[1.1, 2.2, 3.3], [], [4.4, 5.5]]))
ImportError: to use to_arrow, you must install pyarrow:

    pip install pyarrow

or

    conda install -c conda-forge pyarrow

This error occurred while calling

    ak.to_arrow(
        <Array [[1.1, 2.2, 3.3], [], [4.4, 5.5]] type='3 * var * float64'>
    )
```
```
>>> ak.to_jax(ak.Array([[1.1, 2.2, 3.3], [], [4.4, 5.5]]))
ImportError: This function has the following dependency requirements that are not met by your current environment:

 * jax — you do not have this package
 * jaxlib — you do not have this package

This error occurred while calling

    ak.to_jax(
        <Array [[1.1, 2.2, 3.3], [], [4.4, 5.5]] type='3 * var * float64'>
    )
```
```
>>> ak.to_parquet("/tmp/whatever.parquet", ak.Array([[1.1, 2.2, 3.3], [], [4.4, 5.5]]))
ImportError: This function has the following dependency requirements that are not met by your current environment:

 * pyarrow>=7.0.0 — you do not have this package
 * fsspec — you do not have this package

You can fix this error by installing these packages directly, or install awkward with all of the following extras:

 * arrow

This error occurred while calling

    ak.to_parquet(
        '/tmp/whatever.parquet'
        <Array [[1.1, 2.2, 3.3], [], [4.4, 5.5]] type='3 * var * float64'>
    )
```
```
>>> ak.to_backend(ak.Array([[1.1, 2.2, 3.3], [], [4.4, 5.5]]), "cuda")
ModuleNotFoundError: to use Awkward Arrays with CUDA, you must install cupy:

    pip install cupy

or

    conda install -c conda-forge cupy

This error occurred while calling

    ak.to_backend(
        <Array [[1.1, 2.2, 3.3], [], [4.4, 5.5]] type='3 * var * float64'>
        'cuda'
    )
```
It might be better to exclude the "This error occurred while calling" for failures due to missing libraries, so that the message about the missing libraries is last. Otherwise, someone might spend a lot of time looking at the arguments they passed to the function, looking for errors, but the more relevant message is right above.
I also tried installing the wrong version of Arrow, to see what the message would look like:
```
>>> ak.to_arrow(ak.Array([[1.1, 2.2, 3.3], [], [4.4, 5.5]]))
ImportError: pyarrow 7.0.0 or later required for to_arrow

This error occurred while calling

    ak.to_arrow(
        <Array [[1.1, 2.2, 3.3], [], [4.4, 5.5]] type='3 * var * float64'>
    )
```
It doesn't seem to be reaching `f" * {req} — you have {ver} installed"` for that case.
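The "you have {ver} installed" branch only fires if the check inspects the installed version itself, rather than letting a downstream version check raise its own error first. A minimal sketch of such a check (`check_requirement` is a hypothetical name, not the PR's actual helper, and the version comparison here is deliberately simplistic):

```python
# Hypothetical requirement check that distinguishes "not installed" from
# "installed but too old", so both message branches are reachable.
from importlib.metadata import version, PackageNotFoundError

def check_requirement(name, minimum=None):
    try:
        installed = version(name)
    except PackageNotFoundError:
        return f" * {name} — you do not have this package"
    if minimum is not None:
        def as_tuple(v):
            # Naive numeric-prefix comparison; a real implementation would
            # use a proper specifier parser.
            return tuple(int(part) for part in v.split(".")[:3] if part.isdigit())
        if as_tuple(installed) < as_tuple(minimum):
            return f" * {name}>={minimum} — you have {installed} installed"
    return None  # requirement satisfied
```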
@jpivarski I did another pass! I can't reproduce your arrow bug; perhaps you could investigate on your end. Let me know what you think of the docs deployment.
During debugging!
We need to unblock this, and update to the latest sphinx stack. This will just get things working *for now* (TODO)
I noticed that the docs environment isn't solving properly. I suspect that the dependencies are not pinned in Sphinx. The proper fix is to update our Sphinx usage, but for now let's try just pinning everything hard. Also, I need to update the decorators of more functions as per your recent comment in the issue tracker.
The new errors are caused by pathways that don't technically require e.g. fsspec hitting this new decorator. As I see it, we have two options:
```python
@high_level_function(name="ak.to_json", dependencies=["fsspec>1.0.0"])
def to_json(...):
    if FOO:
        # Ensures that we get fsspec>1.0.0
        fsspec = require("fsspec")
    else:
        # no fsspec
        pass
```

@jpivarski any preferences?
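For reference, the `require` helper used in that option could be as small as the following sketch (this is not the PR's implementation; version enforcement is omitted for brevity):

```python
# Minimal sketch of a require() helper: import lazily inside the function
# body and raise an error with a copy-pasteable install hint on failure.
import importlib

def require(name):
    try:
        return importlib.import_module(name)
    except ModuleNotFoundError as err:
        raise ImportError(
            f"to use this function, you must install {name}:\n\n"
            f"    pip install {name}\n\nor\n\n"
            f"    conda install -c conda-forge {name}"
        ) from err
```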
Option 2. The documentation can say, without qualification, that
Will the packaging tests pass if we update the branch? (Not including pypy; that one is fixed by just running it multiple times, but it's not required.)
@jpivarski probably not. I imagine it's because the dependencies are now enforced at the function scope, not at the branch level. The changes we've agreed on (2) will fix this!
In an environment that has all of these things installed (i.e. my JAX works, I can make a JAX array), I get:

```
>>> ak.to_jax(ak.Array([[1.1, 2.2, 3.3], [], [4.4, 5.5]]))
```
In ak_to_jax.py, there's no version requirement on JAX. I have noticed that the CUDA part of my JAX installation has broken:

```
>>> import jax
>>> a = jax.numpy.arange(10)
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
>>>
```

That looks like a print statement, not a warning, so I don't think Awkward is detecting it. Here's my JAX version and source:

```
% mamba list jax
# packages in environment at /home/jpivarski/mambaforge:
#
# Name     Version   Build                Channel
jax        0.4.23    pyhd8ed1ab_0         conda-forge
jaxlib     0.4.23    cpu_py310h9ed8a0c_0  conda-forge
```

(That is, my JAX is not 0.4.23.dev20231222 but 0.4.23.) All of the other optional library checks seem to be working fine. (I'm going to test an environment that doesn't have them installed now.)
In an environment without dependencies, this is what I'm seeing:

```
>>> ak.to_arrow(ak.Array([[1.1, 2.2, 3.3], [], [4.4, 5.5]]))
```

I.e. no message about how this function requires pyarrow and how to install it. Same for JAX: no message about the missing packages. But the one for CuPy works:

```
>>> ak.to_backend(ak.Array([[1.1, 2.2, 3.3], [], [4.4, 5.5]]), "cuda")
```

Also, I noticed that this is following Option 1:

```
>>> ak.to_json(ak.Array([[1.1, 2.2, 3.3], [], [4.4, 5.5]]), "/tmp/whatever.json")
```

It doesn't even get to the point of checking the arguments; I can use no arguments and get this error message; it's checking before anything else.
(force-pushed from d6d2e52 to 0986f75)
I've just pushed a WIP that tries this new approach — functions establish a dependency context, which is evaluated when the function implementation calls `require`. I've not yet finished (or fixed the docs component to reflect the new declaration style). Some points:
I'm confused about why the interface needs all of this structure. A high-level function (possibly) needs libraries. If there are dependencies, it tries to run the high-level function, catches any ImportError (including ModuleNotFoundError), and re-raises it with a note about how to install the set of dependencies. In principle, the PyPI or conda install name of the package might be different from the import name of the package, but we haven't encountered any like that. If there is such a distinction, then it would be global: the install-import mapping does not depend on which high-level function we're running, so it can be a private dict somewhere. That's my understanding of the problem; the solution so far is a lot more complicated than that.
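The simpler flow described above could be sketched roughly like this (`with_dependencies` is a hypothetical name, and the hint text only loosely mirrors the messages shown earlier):

```python
# Sketch of the simpler interface: run the function, catch ImportError,
# and re-raise it with an install hint for the declared dependencies.
import functools

def with_dependencies(*packages):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except ImportError as err:
                raise ImportError(
                    f"{func.__name__} requires: {', '.join(packages)}\n\n"
                    f"    pip install {' '.join(packages)}"
                ) from err
        return wrapper
    return decorator
```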
The short answer is versions matter. We also want to error if the user has, e.g., the wrong version installed.
We need to embed this information somewhere. We could define the "distribution-level" metadata somewhere such that most definitions can use the short-hand, i.e. a table mapping distribution to conda package and Python import name (we only need distribution → Python name for the case that it's not installed), e.g.

```python
@requires_global("arrow", "fsspec", group="arrow")
@high_level_function(requires)
```

There's more to think about here — I'll come back to it :) I feel that there's an opportunity here to clean up / centralise our "extras" handling.
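One way to read that suggestion is a single central table holding only the rare overrides, with everything else falling back to the distribution name itself. The names below are illustrative (reusing the xxhash example), not part of the PR:

```python
# Hypothetical central metadata table: distribution name → overrides for
# the conda package name and the Python import name. Anything not listed
# uses the same name everywhere, which is the common case.
DISTRIBUTIONS = {
    "xxhash": {"conda": "python-xxhash", "import": "xxhash"},
}

def conda_name(dist):
    return DISTRIBUTIONS.get(dist, {}).get("conda", dist)

def import_name(dist):
    return DISTRIBUTIONS.get(dist, {}).get("import", dist)
```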
Let's close this - I don't have bandwidth for it right now, and it's adding to the noise :)
Closes #2953

This PR adds a new `dependencies` argument to `high_level_function`. This lets us fail at runtime if dependency specifications are not met, and also annotate our documentation to inform the user ahead of time. We will still need to keep these values in sync with our new `pyproject.toml` groups, but we can make a helper for that if need be.

Examples: