Use of dict.pop in pm.sample leads to side effects #7632

Closed
jessegrabowski opened this issue Jan 6, 2025 · 2 comments · Fixed by #7652

@jessegrabowski
Member

jessegrabowski commented Jan 6, 2025

Description

I commonly create a `sample_kwargs` variable in notebooks and reuse it across multiple models. Such a dictionary ends up being modified in place by `pm.sample` when entries are popped (for example here and here), so every model after the first silently uses the default settings.
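
A minimal sketch of the mechanism, without PyMC. `**sample_kwargs` unpacks into a fresh top-level dict inside the callee, but nested values such as the `nuts_sampler_kwargs` dict are passed by reference, so `pop` on them mutates the caller's object. `fake_sample` and its `"default"` fallbacks are hypothetical stand-ins, not PyMC's actual code.

```python
def fake_sample(nuts_sampler="pymc", nuts_sampler_kwargs=None, **kwargs):
    """Hypothetical stand-in for a sampler entry point; not PyMC's real code."""
    if nuts_sampler_kwargs is None:
        nuts_sampler_kwargs = {}
    # pop() removes keys from the very dict object the caller passed in
    backend = nuts_sampler_kwargs.pop("backend", "default")
    gradient_backend = nuts_sampler_kwargs.pop("gradient_backend", "default")
    return backend, gradient_backend


nuts_sampler_kwargs = {"backend": "jax", "gradient_backend": "jax"}
sample_kwargs = {"nuts_sampler": "nutpie", "nuts_sampler_kwargs": nuts_sampler_kwargs}

# **sample_kwargs builds a new top-level dict, but the nested dict is shared
first = fake_sample(**sample_kwargs)   # ('jax', 'jax'): the user's settings
second = fake_sample(**sample_kwargs)  # ('default', 'default'): silently reverted

print(first, second)
print(sample_kwargs)  # 'nuts_sampler' survives; 'nuts_sampler_kwargs' is now {}
```

This is also why the top-level `nuts_sampler` entry is unaffected while the nested settings disappear.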

I guess the easiest fix would be to deepcopy `kwargs`, so that there are no side effects on user inputs. Using `pop` here is a bit overkill (why not just `.get`?), but deepcopy is a one-line change.
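
A sketch of the two options mentioned above, on hypothetical sampler functions (this is not the change merged in #7652): copy the incoming dict before popping from it, or read with `.get` and build any leftover options as a new dict. The `"extra"` key is a made-up option used only to show that unrelated settings are preserved.

```python
import copy


def sample_with_copy(nuts_sampler_kwargs=None, **kwargs):
    """Hypothetical sampler: copy first, then pop freely from the private copy."""
    nuts_sampler_kwargs = copy.deepcopy(nuts_sampler_kwargs or {})
    backend = nuts_sampler_kwargs.pop("backend", "default")
    gradient_backend = nuts_sampler_kwargs.pop("gradient_backend", "default")
    # Whatever is left can be forwarded without touching the caller's dict
    return backend, gradient_backend, nuts_sampler_kwargs


def sample_with_get(nuts_sampler_kwargs=None, **kwargs):
    """Hypothetical sampler: read options with .get() and never mutate the input."""
    nuts_sampler_kwargs = nuts_sampler_kwargs or {}
    backend = nuts_sampler_kwargs.get("backend", "default")
    gradient_backend = nuts_sampler_kwargs.get("gradient_backend", "default")
    # Build the leftover options as a new dict instead of popping
    rest = {k: v for k, v in nuts_sampler_kwargs.items()
            if k not in ("backend", "gradient_backend")}
    return backend, gradient_backend, rest


sample_kwargs = {"nuts_sampler_kwargs": {"backend": "jax", "gradient_backend": "jax", "extra": 1}}
for sampler in (sample_with_copy, sample_with_get):
    # Two calls give identical results because the input is never modified
    print(sampler.__name__, sampler(**sample_kwargs), sampler(**sample_kwargs))
print(sample_kwargs)  # unchanged
```

The copy variant is the smaller diff. Note that deep-copying arbitrary user objects can be expensive or fail for objects that do not support copying, so a shallow `dict(...)` copy of only the dict being popped from may be enough when only its top-level keys are removed.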

@inclinedadarsh
Contributor

Hey, this seems like a good place to start contributing. I'd like to attempt it.

As far as I understand, I'll have to deepcopy the kwargs right before the code you linked.

Could you point me to a sample or example notebook I can use to reproduce this issue?

@jessegrabowski
Member Author

```python
import pymc as pm

# Build a simple model, simulate data from it, then condition on that data
with pm.Model() as m:
    mu = pm.Normal('mu', 0, 1)
    sigma = pm.Exponential('sigma', 1)
    y_hat = pm.Normal('y_hat', mu=mu, sigma=sigma, shape=(1000,))

data = pm.draw(y_hat)
m = pm.observe(m, {'y_hat': data})

# Shared settings meant to be reused across several pm.sample calls
nuts_sampler_kwargs = {'backend': 'jax', 'gradient_backend': 'jax'}
sample_kwargs = {'nuts_sampler': 'nutpie', 'nuts_sampler_kwargs': nuts_sampler_kwargs}

with m:
    assert sample_kwargs['nuts_sampler_kwargs'] == {'backend': 'jax', 'gradient_backend': 'jax'}
    idata = pm.sample(**sample_kwargs)
    # Should still hold: sampling must not modify the caller's dict
    assert sample_kwargs['nuts_sampler_kwargs'] == {'backend': 'jax', 'gradient_backend': 'jax'}
```

The second assert should pass, but it currently fails because entries are popped from the nested `nuts_sampler_kwargs` dict.
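
One way to guard against this kind of regression is a test that snapshots the keyword arguments with `copy.deepcopy`, calls the function, and compares by value. `assert_no_mutation`, `popping` and `non_popping` are hypothetical names; with PyMC and nutpie installed, the helper could wrap `pm.sample` inside a model context like the one above, but here it is exercised on plain functions so it runs standalone.

```python
import copy


def assert_no_mutation(func, **kwargs):
    """Hypothetical test helper: fail if func changes any (nested) part of kwargs."""
    snapshot = copy.deepcopy(kwargs)
    result = func(**kwargs)
    # Compare by value, so in-place changes to nested dicts/lists are caught
    if kwargs != snapshot:
        raise AssertionError(f"{func.__name__} mutated its inputs: {snapshot} -> {kwargs}")
    return result


def popping(nuts_sampler_kwargs):
    return nuts_sampler_kwargs.pop("backend", "default")


def non_popping(nuts_sampler_kwargs):
    return nuts_sampler_kwargs.get("backend", "default")


ok = assert_no_mutation(non_popping, nuts_sampler_kwargs={"backend": "jax"})

try:
    assert_no_mutation(popping, nuts_sampler_kwargs={"backend": "jax"})
    caught = None
except AssertionError as err:
    caught = err

print(ok)
print("caught:", caught)
```

The helper raises `AssertionError` explicitly rather than using `assert`, so the check still runs if Python is started with `-O`.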
