-
-
Notifications
You must be signed in to change notification settings - Fork 132
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix for issue # 813 modified the way kwargs are handled #863
base: main
Are you sure you want to change the base?
Conversation
Hi @tanishy7777 thanks for working on this, and sorry for the delay in the response. I think we should do one of the following
I'm not sure if the proposed changes are what we want. If I understand correctly, it will set up What if we have the following example: kwargs = {
"num_draws": 500,
}
blackjax_nuts_idata = model.fit(draws=250, inference_method="blackjax_nuts", **kwargs) The user is passing through This example makes me think we should just emit a warning in that case, and respect the name required by the underlying sampler. A separate comment, it's always good to add a test when one implements a change in behavior like this one. Just let me/us know if you would need help with that. |
Thanks @tomicapretto for catching this. bayeux uses the argument names (often Below is a comment and code snippet from Colin The only samplers are from blackjax, numpyro, nutpie, tfp, and flowMC. bayeux uses the argument names from underlying libraries for method in model.mcmc.methods:
sampler = getattr(model.mcmc, method)
print(method)
for v in ('chains', 'num_chains', 'num_draws', 'num_samples', 'num_results'):
if any(k == v for step in sampler.get_kwargs().values() for k in step):
print('\t', v)
As an aside, we should make sure changes here are compatible with #855. I do not think there is any major conflict, but good to double check. |
Hey @tanishy7777 any updates regarding my comment above? Thanks! |
Sorry for the late response, I think I might have missed this. Will look into it! |
If I understand correctly, you are saying that we should handle Fix 1:So, if some argument is passed in Fix 2a:Also that we should map the extra arguments like Fix 2b:And give a warning to the user when both As show in #863 (comment) kwargs = {
"num_draws": 500,
}
blackjax_nuts_idata = model.fit(draws=250, inference_method="blackjax_nuts", **kwargs) should produce warning because we use |
Solves issue #813
black
.pylint
.