From dea827aa23f1aa2fb67a8e0def9818991959b90c Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Fri, 6 Dec 2024 14:43:27 +0100 Subject: [PATCH] Clarify last round behavior of SNPE-A --- tutorials/16_implemented_methods.ipynb | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tutorials/16_implemented_methods.ipynb b/tutorials/16_implemented_methods.ipynb index ffb128551..110af2228 100644 --- a/tutorials/16_implemented_methods.ipynb +++ b/tutorials/16_implemented_methods.ipynb @@ -66,10 +66,13 @@ "\n", "inference = NPE_A(prior)\n", "proposal = prior\n", - "for _ in range(num_rounds):\n", + "for r in range(num_rounds):\n", " theta = proposal.sample((num_sims,))\n", " x = simulator(theta)\n", - " _ = inference.append_simulations(theta, x, proposal=proposal).train()\n", + " # NPE trains a Gaussian density estimator in all but the last round. In the last round,\n", + " # it trains a mixture of Gaussians, which is why we have to pass the `final_round` flag.\n", + " final_round = True if r == num_rounds - 1 else False\n", + " _ = inference.append_simulations(theta, x, proposal=proposal).train(final_round=final_round)\n", " posterior = inference.build_posterior().set_default_x(x_o)\n", " proposal = posterior" ] @@ -598,7 +601,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.14" + "version": "3.12.4" }, "toc": { "base_numbering": 1,