[Suggestion] : Remove strict requirement on backends. #229
Thanks for the suggestion @ASKabalan! At present we only have PyTorch support for the precompute approach but we do plan to add PyTorch support for the on-the-fly transforms as well. So it would be good to have a nice setup like you're suggesting. @matt-graham do you have any thoughts on this or know of best practices?
Hi @jasonmcewen. This is related to what I suggested in #224 - we can easily make all of PyTorch, I would say dropping hard requirements on all external packages which are not 'core' to the package would be good practice for a library like |
Just noticed the suggestion to also apply this to JAX - while we could do this for JAX too, as a lot of the modules (perhaps the majority?) make multiple imports from

    import array_api_compat
    ...
    def spectral_periodic_extension(fm, L: int):
        # Look up the array namespace (NumPy, JAX, PyTorch, ...) that fm belongs to
        xp = array_api_compat.array_namespace(fm)
        nphi = fm.shape[0]
        return xp.concatenate(
            (
                fm[-xp.arange(L - nphi // 2, 0, -1) % nphi],
                fm,
                fm[xp.arange(L - (nphi + 1) // 2) % nphi],
            )
        )

This would make the function compatible with all of NumPy, JAX and PyTorch arrays without requiring an explicit import from any of them and also avoid having multiple
Ok, fair enough @matt-graham! Does this mean users would need to install from source rather than PyPI tho? Perhaps that is not a big issue anyway.
Maybe. This seems like quite a big revision tho. Let's discuss further...
No, the extras / optional dependencies syntax works fine for both packages installed from a local directory and a packaging index like PyPI. It's the same syntax that JAX uses to install from PyPI with CUDA support:
Ok, sounds like it makes a lot of sense then.
Currently, PyTorch has to be installed to use the JAX backend of s2fft.
I think it would be nice to be able to conditionally activate backends depending on which packages are available.
For example, a user could use s2fft with only NumPy if they do not have JAX or PyTorch installed.
Requesting a backend that is not installed would then give a runtime error instead of an import-time error.
The same applies to pyssht.
I don't know the best practice for this, but a try/except around the import plus a global boolean should do the trick.
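For illustration, the try/except plus global boolean idea could look roughly like the sketch below. This is not how s2fft is implemented: the names `optional_import`, `require_backend`, `TORCH_AVAILABLE` and `JAX_AVAILABLE` are hypothetical and only show the pattern.

```python
import importlib


def optional_import(name):
    """Return (module, True) if `name` can be imported, else (None, False)."""
    try:
        return importlib.import_module(name), True
    except ImportError:
        return None, False


# Hypothetical module-level flags, set once when the package is imported.
torch, TORCH_AVAILABLE = optional_import("torch")
jax, JAX_AVAILABLE = optional_import("jax")


def require_backend(name, available):
    """Raise a clear error at call time if the requested backend is missing."""
    if not available:
        raise ImportError(
            f"The '{name}' backend was requested but '{name}' is not installed."
        )


def forward(x, method="numpy"):
    # Hypothetical dispatch: only check for a backend when it is actually asked for.
    if method == "torch":
        require_backend("torch", TORCH_AVAILABLE)
    elif method == "jax":
        require_backend("jax", JAX_AVAILABLE)
    return x  # placeholder for the real transform
```

With this layout, importing the package succeeds with only NumPy installed, and the error appears only when a user explicitly asks for a backend they do not have.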