From 18b114190820a7a460fe3c106ebaedb8ac1f0bbb Mon Sep 17 00:00:00 2001 From: manuelgloeckler <38903899+manuelgloeckler@users.noreply.github.com> Date: Fri, 25 Oct 2024 15:52:23 +0200 Subject: [PATCH] Fix pickle issues in MCMC posterior + test (#1291) * Fix pickle issues in MCMC posterior + test * fix linting --- sbi/inference/posteriors/mcmc_posterior.py | 13 +++++++++++++ tests/save_and_load_test.py | 7 +++++++ 2 files changed, 20 insertions(+) diff --git a/sbi/inference/posteriors/mcmc_posterior.py b/sbi/inference/posteriors/mcmc_posterior.py index 65f59b95c..c7d462688 100644 --- a/sbi/inference/posteriors/mcmc_posterior.py +++ b/sbi/inference/posteriors/mcmc_posterior.py @@ -1049,6 +1049,19 @@ def get_arviz_inference_data(self) -> InferenceData: return inference_data + def __getstate__(self) -> Dict: + """Get state of MCMCPosterior. + + Removes the posterior sampler from the state, as it may not be picklable. + + Returns: + Dict: State of MCMCPosterior. + """ + state = self.__dict__.copy() + state["_posterior_sampler"] = None + + return state + def _process_thin_default(thin: int) -> int: """ diff --git a/tests/save_and_load_test.py b/tests/save_and_load_test.py index 30956c9af..aa26e6eb1 100644 --- a/tests/save_and_load_test.py +++ b/tests/save_and_load_test.py @@ -8,6 +8,7 @@ from sbi import utils as utils from sbi.inference import NLE, NPE, NRE +from sbi.inference.posteriors.vi_posterior import VIPosterior @pytest.mark.parametrize( @@ -34,6 +35,12 @@ def test_picklability(inference_method, sampling_method: str, tmp_path): x_o ) + # After sample and log_prob, the posterior should still be picklable + if isinstance(posterior, VIPosterior): + posterior.train(max_num_iters=10) + _ = posterior.sample((1,)) + _ = posterior.log_prob(torch.zeros(1, num_dim)) + with open(f"{tmp_path}/saved_posterior.pickle", "wb") as handle: pickle.dump(posterior, handle) with open(f"{tmp_path}/saved_posterior.pickle", "rb") as handle: