Skip to content

Commit

Permalink
fix: tutorials test error handling, fix bugs in tutorials (#1264)
Browse files Browse the repository at this point in the history
* fix: tutorials test error handling, fix tutorial 7

* fix remaining bugs in tutorials

* refactor training interface tutorial
  • Loading branch information
janfb authored Sep 5, 2024
1 parent 57e1b83 commit a4f7811
Show file tree
Hide file tree
Showing 5 changed files with 189 additions and 199 deletions.
6 changes: 4 additions & 2 deletions tests/tutorials_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ def test_tutorials(notebook_path):
if "Requested MovieWriter" in str(e):
print("Skipping error in movie writer.")
else:
raise CellExecutionError from e
raise RuntimeError(
f"Error executing the notebook {notebook_path}: {e}"
) from e
except Exception as e:
raise AssertionError(
raise RuntimeError(
f"Error executing the notebook {notebook_path}: {e}"
) from e
162 changes: 20 additions & 142 deletions tutorials/07_sensitivity_analysis.ipynb

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,8 @@
}
],
"source": [
"# Train SNLE.\n",
"inferer = SNLE(prior, show_progress_bars=True, density_estimator=\"mdn\")\n",
"# Train NLE.\n",
"inferer = NLE(prior, show_progress_bars=True, density_estimator=\"mdn\")\n",
"theta, x = simulate_for_sbi(simulator, prior, 10000, simulation_batch_size=1000)\n",
"inferer.append_simulations(theta, x).train(training_batch_size=1000);"
]
Expand Down Expand Up @@ -310,9 +310,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The pairplot above already indicates that (S)NLE is well able to obtain accurate posterior samples also for increasing number of trials (note that we trained the single-round version of SNLE so that we did not have to re-train it for new $x_o$).\n",
"The pairplot above already indicates that (S)NLE is well able to obtain accurate posterior samples also for increasing number of trials (note that we trained the single-round version of NLE so that we did not have to re-train it for new $x_o$).\n",
"\n",
"Quantitatively we can measure the accuracy of SNLE by calculating the `c2st` score between SNLE and the true posterior samples, where the best accuracy is perfect for `0.5`:\n"
"Quantitatively we can measure the accuracy of NLE by calculating the `c2st` score between NLE and the true posterior samples, where the best accuracy is perfect for `0.5`:\n"
]
},
{
Expand Down
6 changes: 5 additions & 1 deletion tutorials/15_importance_sampled_posteriors.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,11 @@
"print(\"observations.shape\", observation.shape)\n",
"\n",
"# sample from posterior\n",
"theta_inferred = posterior.sample((10_000,))"
"theta_inferred = posterior.sample((10_000,))\n",
"\n",
"# get samples from ground-truth posterior\n",
"gt_samples = MultivariateNormal(observation, eye(2)).sample((len(theta_inferred) * 5,))\n",
"gt_samples = gt_samples[prior.support.check(gt_samples)][:len(theta_inferred)]"
]
},
{
Expand Down
Loading

0 comments on commit a4f7811

Please sign in to comment.