Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reuse jaxified logp when sampling via jax #7681

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

nataziel
Copy link
Contributor

@nataziel nataziel commented Feb 14, 2025

reuse jaxified logp when sampling via jax

Description

#7610 added logic to handle passing a pre-jaxified logp function into the blackjax/numpyro samplers, but missed actually passing the jaxified logp that is computed in sample_jax_nuts

Checklist

  • Checked that the pre-commit linting/style checks pass
  • Included tests that prove the fix is effective or that the new feature works
  • Added necessary documentation (docstrings and/or example notebooks)
  • If you are a pro: each commit corresponds to a [relevant logical change]

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7681.org.readthedocs.build/en/7681/

@nataziel nataziel changed the title reuse jaxified logp times when sampling via jax reuse jaxified logp when sampling via jax Feb 14, 2025
Copy link

codecov bot commented Feb 14, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 92.64%. Comparing base (358b825) to head (e63a8a2).
Report is 2 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #7681      +/-   ##
==========================================
- Coverage   92.70%   92.64%   -0.06%     
==========================================
  Files         107      107              
  Lines       18391    18324      -67     
==========================================
- Hits        17050    16977      -73     
- Misses       1341     1347       +6     
Files with missing lines Coverage Δ
pymc/sampling/jax.py 94.11% <ø> (-0.91%) ⬇️

... and 1 file with indirect coverage changes

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant