Allow access to different nutpie backends via pip-style syntax #7498
I created some global variables to track available backends in
I'm not sure how the other backends (especially JAX) behave with multiprocessing tbh :O Otherwise the idea sounds cool. Perhaps @aseyboldt can weigh in, as he has a better picture of how we handle multiprocessing in pm.sample.
Should we go work on a
The PyMC codebase would still need to know about it and work around it differently for JAX.
Codecov Report
❌ Patch coverage is
❌ Your patch status has failed because the patch coverage (30.43%) is below the target coverage (50.00%). You can increase the patch coverage or adjust the target coverage.
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #7498      +/-   ##
==========================================
- Coverage   90.62%   90.27%   -0.36%
==========================================
  Files         116      116
  Lines       18947    18983      +36
==========================================
- Hits        17170    17136      -34
- Misses       1777     1847      +70
With the numba backend we could actually do it with threads (with nogil)? Maybe it's worth opening an issue to investigate different backends for PyMC samplers. Also, this shouldn't have to be NUTS-specific, so for that a
Should probably just be
Not if we need to change how the samplers/threads are orchestrated.
But compile kwargs already work anyway.
What do you mean,
It doesn't? Maybe I was playing with global mode then.
lucianopaz left a comment
@jessegrabowski, this looks very nice. I left a few suggestions, though.
tests/sampling/test_mcmc_external.py
- @pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"])
+ @pytest.mark.parametrize(
+     "nuts_sampler",
+     ["pymc", "nutpie", "nutpie[jax]", "blackjax", "numpyro"],
+ )
tests/sampling/test_mcmc_external.py
- with pytest.raises(ValueError, match='Expected one of "numba" or "jax"; found "invalid"'):
+ with pytest.raises(
+     ValueError,
+     match='Could not parse nutpie backend. Expected one of "numba" or "jax"; found "invalid"',
+ ):
tests/sampling/test_mcmc_external.py
with pytest.raises(ValueError, match=re.escape("Could not parse nutpie backend. Found 'nutpie[bad'")):
    with pymc_model:
        sample(nuts_sampler="nutpie[bad", random_seed=123, chains=2, tune=500, draws=500)
I fed this PR, including its comments, into GPT o1-mini:
---
Summary
This pull request introduces a pip-style syntax for specifying different Nutpie backends within the
Key Changes:
Detailed Review
1. Code Changes
a.
b.
c.
d. Documentation Updates
2. Testing Enhancements
3. GitHub Discussion & Iterative Improvements
Conclusion
This pull request effectively enhances PyMC's sampling functionality by introducing a flexible and user-friendly way to specify different Nutpie backends. The implementation is thoughtfully designed, with clear type definitions, robust error handling, and comprehensive testing to ensure reliability and maintainability.
Recommendations Before Merging:
Overall, this PR represents a significant improvement to PyMC's flexibility and user experience, enabling more tailored and optimized sampling strategies. With the recommended refinements, it is well-positioned for successful integration into the main codebase. Approved with minor recommendations.
Close this as stale now that #7535 is merged? Or is there still interest in this syntax?
This syntax is a bit more ergonomic / discoverable.
763cd87 to a5b3241
I like the syntax :-)
Thoughts on the best way to ask for the gradient backend in this syntax? Or
That will break people's existing code if they don't have jax installed.
This looks super helpful @jessegrabowski !!
a5b3241 to eb6e216
I know there was some back and forth here, but I think the
That is, the backend is somewhat orthogonal to the sampler, be it PyMC or not? Basically, a less verbose shortcut for the most common use of
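Treating the backend as a separate argument, as proposed here, replaces string parsing with a table lookup. This is a minimal sketch that assumes a hypothetical backend keyword alongside nuts_sampler. The SUPPORTED_BACKENDS table and the resolve_backend helper are illustrative assumptions. In particular, the "pytorch" entry for the PyMC sampler reflects only the extension floated in the PR description, not existing support.

```python
from __future__ import annotations

# Hypothetical table of which compile backends each NUTS sampler could accept.
# The first entry in each tuple is treated as that sampler's default backend.
# This is an illustrative assumption, not PyMC's actual configuration.
SUPPORTED_BACKENDS: dict[str, tuple[str, ...]] = {
    "pymc": ("c", "numba", "jax", "pytorch"),
    "nutpie": ("numba", "jax"),
    "numpyro": ("jax",),
    "blackjax": ("jax",),
}


def resolve_backend(nuts_sampler: str, backend: str | None = None) -> str:
    """Validate a (sampler, backend) pair passed as two separate arguments."""
    try:
        allowed = SUPPORTED_BACKENDS[nuts_sampler]
    except KeyError:
        raise ValueError(f"Unknown nuts_sampler {nuts_sampler!r}") from None
    if backend is None:
        return allowed[0]  # fall back to the sampler's default backend
    if backend not in allowed:
        raise ValueError(
            f"nuts_sampler={nuts_sampler!r} does not support backend={backend!r}; "
            f"expected one of {allowed}"
        )
    return backend
```

Because nutpie's default in this table stays "numba", an existing call with nuts_sampler="nutpie" and no backend would behave as before.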
Yeah, I think I agree. I've cooled on all the clever string stuff recently. I think I'll close this PR and open a fresh one with the backend argument.
Description
Adds a pip-style syntax to the nuts_sampler argument that allows access to alternative compile backends, when relevant. This lets you get the nutpie JAX backend by setting nuts_sampler='nutpie[jax]'. For backwards compatibility, nuts_sampler='nutpie' is equivalent to nuts_sampler='nutpie[numba]'.
The current PR only deals with nutpie, but we could easily extend this to include the default PyMC sampler, to compile to JAX, numba, or pytorch directly, without going through nutpie. I'm willing to do that extension in this PR if it is deemed worthwhile.
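The parsing this description implies can be sketched as follows. This is a minimal, regex-based sketch, not the PR's actual implementation. The names parse_nuts_sampler, NUTPIE_BACKENDS, and DEFAULT_NUTPIE_BACKEND are hypothetical. The two error messages mirror the ones in the review's test suggestions.

```python
from __future__ import annotations

import re

# Backends accepted inside the brackets for nutpie, per the PR description.
NUTPIE_BACKENDS = ("numba", "jax")
# Plain "nutpie" keeps its old meaning, i.e. "nutpie[numba]".
DEFAULT_NUTPIE_BACKEND = "numba"

# Matches "name" or "name[extra]"; fullmatch() rejects anything after "]".
_SAMPLER_PATTERN = re.compile(r"(?P<name>[a-z]+)(?:\[(?P<backend>[^\[\]]*)\])?")


def parse_nuts_sampler(nuts_sampler: str) -> tuple[str, str | None]:
    """Split a pip-style sampler string into (sampler_name, backend)."""
    match = _SAMPLER_PATTERN.fullmatch(nuts_sampler)
    if match is None:
        # Malformed input such as "nutpie[bad" (unclosed bracket).
        raise ValueError(f"Could not parse nutpie backend. Found {nuts_sampler!r}")
    name, backend = match.group("name"), match.group("backend")

    if name != "nutpie":
        # Only nutpie understands a backend suffix in this sketch.
        if backend is not None:
            raise ValueError(f"Sampler {name!r} does not accept a backend suffix")
        return name, None

    if backend is None:
        return name, DEFAULT_NUTPIE_BACKEND
    if backend not in NUTPIE_BACKENDS:
        raise ValueError(
            'Could not parse nutpie backend. Expected one of "numba" or "jax"; '
            f'found "{backend}"'
        )
    return name, backend
```

Keeping the default in one constant is what makes nuts_sampler='nutpie' and nuts_sampler='nutpie[numba]' resolve identically.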
Related Issue
nutpie compile backends through pm.sample #7497
Checklist
Type of change
📚 Documentation preview 📚: https://pymc--7498.org.readthedocs.build/en/7498/