Skip to content

Commit

Permalink
Fix pickle issues in MCMC posterior + test (#1291)
Browse files Browse the repository at this point in the history
* Fix pickle issues in MCMC posterior + test

* fix linting
  • Loading branch information
manuelgloeckler authored Oct 25, 2024
1 parent 3d5cb24 commit 18b1141
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
13 changes: 13 additions & 0 deletions sbi/inference/posteriors/mcmc_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -1049,6 +1049,19 @@ def get_arviz_inference_data(self) -> InferenceData:

return inference_data

def __getstate__(self) -> Dict:
"""Get state of MCMCPosterior.
Removes the posterior sampler from the state, as it may not be picklable.
Returns:
Dict: State of MCMCPosterior.
"""
state = self.__dict__.copy()
state["_posterior_sampler"] = None

return state


def _process_thin_default(thin: int) -> int:
"""
Expand Down
7 changes: 7 additions & 0 deletions tests/save_and_load_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from sbi import utils as utils
from sbi.inference import NLE, NPE, NRE
from sbi.inference.posteriors.vi_posterior import VIPosterior


@pytest.mark.parametrize(
Expand All @@ -34,6 +35,12 @@ def test_picklability(inference_method, sampling_method: str, tmp_path):
x_o
)

# After sample and log_prob, the posterior should still be picklable
if isinstance(posterior, VIPosterior):
posterior.train(max_num_iters=10)
_ = posterior.sample((1,))
_ = posterior.log_prob(torch.zeros(1, num_dim))

with open(f"{tmp_path}/saved_posterior.pickle", "wb") as handle:
pickle.dump(posterior, handle)
with open(f"{tmp_path}/saved_posterior.pickle", "rb") as handle:
Expand Down

0 comments on commit 18b1141

Please sign in to comment.