From a4f78115471bc44a823856ce387aee09c8c7ea65 Mon Sep 17 00:00:00 2001 From: Jan Date: Thu, 5 Sep 2024 13:09:17 +0200 Subject: [PATCH] fix: tutorials test error handling, fix bugs in tutorials (#1264) * fix: tutorials test error handling, fix tutorial 7 * fix remaining bugs in tutorials * refactor training interface tutorial --- tests/tutorials_test.py | 6 +- tutorials/07_sensitivity_analysis.ipynb | 162 ++------------ ...and_permutation_invariant_embeddings.ipynb | 8 +- .../15_importance_sampled_posteriors.ipynb | 6 +- tutorials/18_training_interface.ipynb | 206 +++++++++++++----- 5 files changed, 189 insertions(+), 199 deletions(-) diff --git a/tests/tutorials_test.py b/tests/tutorials_test.py index 3f822845b..bc4103994 100644 --- a/tests/tutorials_test.py +++ b/tests/tutorials_test.py @@ -32,8 +32,10 @@ def test_tutorials(notebook_path): if "Requested MovieWriter" in str(e): print("Skipping error in movie writer.") else: - raise CellExecutionError from e + raise RuntimeError( + f"Error executing the notebook {notebook_path}: {e}" + ) from e except Exception as e: - raise AssertionError( + raise RuntimeError( f"Error executing the notebook {notebook_path}: {e}" ) from e diff --git a/tutorials/07_sensitivity_analysis.ipynb b/tutorials/07_sensitivity_analysis.ipynb index 32da5a5ad..d1d6bf3b7 100644 --- a/tutorials/07_sensitivity_analysis.ipynb +++ b/tutorials/07_sensitivity_analysis.ipynb @@ -40,15 +40,7 @@ "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.\n" - ] - } - ], + "outputs": [], "source": [ "import torch\n", "from torch.distributions import MultivariateNormal\n", @@ -72,25 +64,11 @@ "execution_count": 2, "metadata": {}, "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/2000 [00:00" ] @@ -166,22 +129,7 @@ "cell_type": "code", "execution_count": 5, "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Drawing 1000 posterior samples: 0%| | 0/1000 [00:00" ] @@ -356,38 +276,10 @@ "cell_type": "code", "execution_count": 11, "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Drawing 5000 posterior samples: 0%| | 0/5000 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "_ = pairplot(samples, limits=[[-3, 3], [-3, 3]], figsize=(3, 3))" + "pairplot(samples, limits=[[-3, 3], [-3, 3]], figsize=(3, 3));" ] }, { @@ -259,7 +280,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "id": "77f5e432-1399-4b26-8942-bb710a009651", "metadata": {}, "outputs": [], @@ -271,10 +292,25 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "id": "012870c0-7551-4895-9b3d-7cae3d103ae9", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Shape of x_o: torch.Size([1, 2])\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Shape of samples: torch.Size([1000, 2])\n" + ] + } + ], "source": [ "print(f\"Shape of x_o: {x_o.shape}\")\n", "samples = posterior.sample((1000,), x=x_o)\n", @@ -293,10 +329,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "id": "e01d496c-d4f0-43aa-96a3-103afa745416", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQMAAAExCAYAAAB4Y8dtAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAa8ElEQVR4nO3de1hU550H8O9hnIERGIgoyE1QiGjijSVCxViJpWpTJW6fqrms2MSNl2o3JIaqGy1x06y7a2x8atOomwiutInZrqnPaqvxBtYEb8zSiCYKlACOIF5guDs48+4fHE9E5TLMDMPB7+d55nnmnHnPOb+Zh/nyzrm8RxJCCBDRQ8/D3QUQUd/AMCAiAAwDIpIxDIgIAMOAiGQMAyICwDAgIhnDgIgAMAyISMYwICIADAMikjEMiAgAw4CIZAwDIgLAMHCq3EvXELl6PwpNZneXQmQ3hoGTmGqbsXDHaeU5kdowDJykptGiPA/117uxEqKeYRgQEQCGARHJGAYuMGvLCe5EJNVhGLgIw4DUhmFARAAYBkQkYxgQEQCGARHJGAZEBIBhQEQyhgERAWAYEJGMYUBEABgGRCRjGBA5SV1LKz4+XY4bDbfcXUqPDHB3AUT9gbmpFRP/9TAst20ovd6INU+PdndJdmPPwMkyZj8GAFj/vxc44tFDwtzUihc+PAnLbRsAoKKmyc0V9QzDwMkmRg7Czpfi0dxqbTf6EfVfv80pRqGpTpn202vdWE3PMQxcIMBb5+4SqBddrWsBAPh6qftXN8OAyAGW2zZ8c6PtZ8EADwkA8MnZy5i5+Tg+/b/L7izNbgwDoh6w3LahyXIbP/2dEQUVtdAN8MCaH4zGUIMXrDaBr6vq8eruvyL7ZJm7S+02dfdriNxg+e+N2P9lpTKtkSRYbTb8/H++vK/t2j8WwiYEUidF9mKFPcOeAZGdjn1drTz3HOABqxCw2jpu/4u957H2j+eUow19FXsGRAD+cecZHP6qGn9+ZQpGBxs6bGe5bYPVJgAAcyaE4I8FVwAAQ3w8oRsgwUurweLvRsFPr8WK3+fjzvc/+2Q5DhRWQa/TKOvSaTwwKtiA2HB/zI0Lh9/ABx+F+PWRIvzq0CV4aT0wxNcTFTcffMh6fLg/9i6f3JO3D4A9AyIAwOGv2v7bv3voUodtLLdt+OnvjLh12watRlKCYEK4P6413IKptgUl1xqxes+XaLh1G5vmTQAAGLwGwNdzAK43WFBxs1l5lFxrxP4vK/HL/V9h9m9O4Ouqugdu9/ilawCAllZbh0EAAH+tqO3BO/8WewZEd7EJ0eFrr+4uwOGvrkLjIaHV2tbu7b8fg1N/u4mCilr8JDESFqsNvz9VjvQ//BXrfth2FqLGQ0LO60n42/VGZV1fVdbhD/mXUX6zCY23rCi/2YQf/fYL7PjJRHxnREC77W6aNx5TN+a0m/dYsAFznwhTprUaD/w4LgyOYBgQdcO1+lvYf65tp6GnRkKTTeDN2Y8h7JGBWH/+AoC2HsIzE0Jw22rDJ2cv40/nquAhATVNrfjPv5Ri5fSRkKS2w49bc0tgLK9tt40mixV/Old5XxhEBHgjNz0Jz24/iUpzC/7hO8Pw1jNjlHU5C8OAqBtu29p+/Os0HoD8JfTTa5Wb7QKAsbwG50xm3GhoO/PUQ5Kw9oeP4V/2XcBvjhVD4yHh1e+PBACU39Xd95AAmwAiAwbip0nRD9x+RIA39v/TFHxdVYdJIwKcHgQA9xkQtTPUz+uB8zXyCUW3bTYI+afEf+e3P6nov/LK8OGJUhyRjzb4eA3AS08OR/LoQADA9uN/w235sMOjgT4AAAnfBsFHi7/T4fYBYJC3DolRg10SBAB7BkQAgB0/eQL//ueL+OcOrjYc4uMJP70W5uZW3Nmt8PJ3h2NsmB9uWwW0mm//rwohYKppRnSQD97ad0HZOfnkowH45Oxl3Gi4hZuNFkgABIBxYX7YvuCJToOgNzAMiABMGxWEaaOCOnxdkiSMC/PDX4quQz6yiJLqRqz5QfvwsNkE/vnTc9h3rhI49+18Dwk4dKEahy5Ut2v/47gw/HLOGHhpNXA3hgFRNy16cjhOFF+HRe7q/3L/V7BYbXgiYpDS5g/5Ffjk7GVIAGKH+aO51YqLVfWwibbzAAJ9PeHjOQCPhxjwdxGPIDbc32XdfnsxDIi6KSkmEK8mj8SvDl1Sdvr9x4GL97W70/2/+2jBnAkh2DRvgrLvoS9iGBDZYcVT0bh4tR77v6yERpIQaPCEbsC3+wuaLVZU17cNexYZMBAeHhK+PzoIP585qk8HAcAwILKLh4eEzfMnwGYT+HNhFSrNLQ9sl5b8KNKSR/ZydY7hoUUiO2k1Hvj1c7GYGxeGe//Ze2k98POZMaoLAoA9A6Ie0Wo8sHHueGycO97dpTgNewZEBIBhQEQyhgERAWAYEJGMYUBEABgGRCRjGBARAIYBEckYBkQEgGHgUjd441VSEYaBCzzirYNeq8HSXfm8LTupBsPABUL99di6II63ZSdVYRi4CG/LTmrDMCAiAAwDIpIxDIgIAMOAiGQMAyICwDAgIhnDgIgAMAyISMYwcBJeh0BqxzBwAlNtM5buyodeq8EjPPOQVIr3TXCCmkYLmlut2PlSPEL99e4uh6hH2DNwIl6PQGrGMCAiAAwDIpIxDIgIAMOAiGQMAyICwDAgIhnDgIgAMAyISMYwICIADAMikjEMiAgAw4CIZAwDIgLAMHA5DnpCasEwcBHefJXUhmHgIrz5KqkNw8CFONgJqQnDgIgAMAx6BXcikhowDFyIOxFJTRgGLsSdiKQmDAMHmWqbUWgyd/g6dyKSWvC+CQ4w1TZj8r8dVaZ5AxVSM/YMHHB31z99RgxvoEKqxjBwEv4cILVjGBARAIaB04wJ9XN3CUQOYRg4wb6fPdllGPDEI+rrGAYuxhOPSC0YBi7GE49ILRgGvYBHGkgNGAYO4H4A6k8YBj1kqm3G0l350Gs13T7zcNaWE52eukzkTgyDHqpptKC51YqtC+LsOvOQYUB9Fa9NsJOpthlXapthqmk7MtCd/QF39xwe8dYpRxVC/fXtnne0vc5eJ3IWSQgherpwdV0Lqutvobi6AWm7C7B5/gREB/o4sz67ubKWG40WLNxxWpnWazU4vHJqt76oB89XYcmu/HbLbvjRWKzZcw4AsHVB3H3BcqPRgqXyMve+3t33yZOhqLscCgMi6j+4z4CIADAMiEjGMCAiAAwDIpL1+NCiEAL19fXOrIVcxNfXF5IkubsM6uN6HAb19fXw8+NhKzUwm80wGAzuLoP6uB4fWuyqZ1BXV4fw8HBUVFSo4g+xP9fLngF1R497BpIkdetLYzAYVPHluoP10sOKOxCJCADDgIhkLgsDT09PZGRkwNPT01WbcCrWSw87XptARAD4M4GIZAwDIgLAMCAiGcOAiAA4GAatra1YtWoVxo4dC29vb4SEhCA1NRVXrlzpdLk333wTkiS1e4waNcqRUhz23nvvITIyEl5eXkhISMDp06e7XsiFNmzYgIkTJ8LX1xeBgYGYM2cOLl682OkyWVlZ932uXl5evVQxqZ1DYdDU1ASj0Yh169bBaDRiz549uHjxIlJSUrpc9vHHH0dlZaXyOHHihCOlOGT37t147bXXkJGRAaPRiPHjx2PGjBmorq52W025ublYvnw5Tp48iUOHDqG1tRXTp09HY2Njp8sZDIZ2n2tZWVkvVUyqJ5zs9OnTAoAoKyvrsE1GRoYYP368szfdY/Hx8WL58uXKtNVqFSEhIWLDhg1urKq96upqAUDk5uZ22CYzM1P4+fn1XlHUrzh9n4HZbIYkSfD39++0XVFREUJCQjBixAi88MILKC8vd3Yp3WKxWJCfn4/k5GRlnoeHB5KTk5GXl+eWmh7EbG4bYn3QoEGdtmtoaEBERATCw8PxzDPP4Pz5871RHvUDTg2DlpYWrFq1Cs8991ynF88kJCQgKysLBw4cwPvvv4/S0lJMmTLFLeMjXL9+HVarFUFBQe3mBwUFoaqqqtfreRCbzYa0tDRMnjwZY8aM6bBdTEwMduzYgb179yI7Oxs2mw2JiYm4fPlyL1ZLqmVPNyI7O1t4e3srj+PHjyuvWSwWMXv2bBEbGyvMZrNd3ZOamhphMBjEBx98YNdyzmAymQQA8cUXX7Sbn56eLuLj43u9ngdZunSpiIiIEBUVFXYtZ7FYRFRUlFi7dq2LKqP+xK5LmFNSUpCQkKBMh4aGAmg7qjBv3jyUlZXh6NGjdl9S6+/vj5EjR6K4uNiu5Zxh8ODB0Gg0uHr1arv5V69exdChQ3u9nnutWLEC+/btw/HjxxEWFmbXslqtFrGxsW75XEl97PqZ4Ovri+joaOWh1+uVICgqKsLhw4cREBBgdxENDQ0oKSlBcHCw3cs6SqfTIS4uDkeOHFHm2Ww2HDlyBJMmTer1eu4QQmDFihX49NNPcfToUQwfPtzudVitVpw7d84tnyupkCPdCovFIlJSUkRYWJgoKCgQlZWVyuPWrVtKu2nTpoktW7Yo0ytXrhQ5OTmitLRUfP755yI5OVkMHjxYVFdXO1JOj3388cfC09NTZGVliQsXLojFixcLf39/UVVV5ZZ6hBBi2bJlws/PT+Tk5LT7XJuampQ2CxYsEKtXr1am169fLw4ePChKSkpEfn6+ePbZZ4WXl5c4f/68O94CqYxDYVBaWioAPPBx7NgxpV1ERITIyMhQpufPny+Cg4OFTqcToaGhYv78+aK4uNiRUhy2ZcsWMWzYMKHT6UR8fLw4efKkW+vp6HPNzMxU2kydOlUsXLhQmU5LS1PeQ1BQkHj66aeF0Wjs/eJJlXgJMxEB4LUJRE5TaDIjcvV+FJrM7i6lRxgGRE5yJwQYBkQPuRuNFneX4BCGAZETmGqbsfFg21Wlj3jr3FxNzzAMiJyg5q5egZdW48ZKeo5hQORkS3flw1Tb7O4y7MYwIHKi9BkxaG61tuspqAXDgMhBhSYzZm1pG5wnQKX7C4B+HAZJSUlIS0tTpiMjI7F582a31UP9l1oPJd6rxzdeVZszZ87A29vb6et9++23sX//fhQUFECn06G2ttbp26C+7e6jB2o9kgD0457BvYYMGYKBAwc6fb0WiwVz587FsmXLnL5uUodQfz0AYNuCOOW5GvWLMGhsbERqaip8fHwQHByMTZs23dfm3p8JkiRh27ZtmDVrFgYOHIjRo0cjLy8PxcXFSEpKgre3NxITE1FSUtLpttevX49XX30VY8eOdfbbIpVRcxAA/SQM0tPTkZubi7179+Kzzz5DTk4OjEZjl8u99dZbSE1NRUFBAUaNGoXnn38eS5YswZo1a3D27FllTAGih4Hq9xk0NDTgww8/RHZ2Nr73ve8BAHbu3NmtUYFefPFFzJs3DwCwatUqTJo0CevWrcOMGTMAAK+88gpefPFF1xVP1IeovmdQUlICi8XSbji2QYMGISYmpstlx40bpzy/MyDq3d39oKAgtLS0oK6uzokVE/VNqg8DR2i1WuW5JEkdzrPZbL1bGJEbqD4MoqKioNVqcerUKWVeTU0NLl265MaqiNRH9fsMfHx8sGjRIqSnpyMgIACBgYF444034OHROzlXXl6Omzdvory8HFarFQUFBQCA6Oho+Pj49EoNRM6g+jAAgI0bN6KhoQGzZ8+Gr68vVq5cqdyByNV+8YtfYOfOncp0bGwsAODYsWNISkrqlRqInIFjIBI56M61Cft+9iQAKM/HhPq5uTL7qH6fARE5B8OAiAAwDIhIxjAgIgAMAyKSMQyICADDgIhkDAMiAsAwICIZw4CIADAMiFxCjfddZBgQOdEj3jrotRpV3lWJYUDkRKH+emxdEKfKuyoxDIicTK13VWIYEBEAhgERyRgGRA5S45GDB2EYEDnAVNuMpbvyoddqVH2fRaCfjIFI5C41jRY0t1qx86V43l6NiNR7BOFuDAMiAsAwICIZw4CIADAMiEjGMCAiAAwDIpIxDIgIAMOAiGQMAyICwDAgIhnDgIgAMAyISMYwICIADAMikjEMiAgAw4CIZAwDIgLAMCAiGcOAyEXUNmoyw4DIydR6v0WGAZGTqfV+iwwDIhdQ42jJDAMiAsAwICIZw4CIADAMiEjGMCAiAAwDIpIxDIgIAMOAiGQMAyICwDAgIhnDgIgAMAyISMYwIHKA2sYs6AzDgKiHTLXNWLorH3qtBo+o8CrFew1wdwFEalXTaEFzqxU7X4pHqL/e3eU4jD0DIgd1NnaBmn5GMAyIXECNQ58xDIhcQI1DnzEMiFxEbUOfMQyICADDgIhkDAMiAsAwICIZw4CIADAMiEjGMCByMbWchcgwIHIRtZ2FyDAg6qGu/uOr7SxEhgFRD3T38mU1nYXIS5iJeqC/Xb4MsGdA5BA1/efvCsOAiAAwDIhIxjAg6gX3Hlo01Tb3ucONDAMiF7rzhV+yKx+FJrMyL3lTLpI35eLMNzf7TCgwDIjslHvpGmZtOdGttnefXzB3ax5Mtc0oNJnR3GpFc6sVc7fmIXlTbp8IBIYBkR0KTWYs3HFame7uEOlPjx2K5lYrCk1mLNmV3+61O/PdjWFAJMu9dA2Rq/d3+MU01TZj7tY86LUabFsQh89XT+vyHIMxoX4AgO8+OgQAUFzd8MB2S3blI3L1fuReutbhtu88Ck3mdtN3v+4ISQghHFoDPXSq61pQXX/L3WU41Y1Gi/IfP31GDKaOHILi6gak7S7A5vkTEB3oo0zvfCkeU0cOsWv9hSZzt39a7Hwpvt35CzcaLVi6Kx/Nrdb72uq1Gmz40Vis2XMOALB1QZyy7J0g6i6GAREB4M8EIpIxDIgIAMOAiGQMAyICwDAgIhnHMyC7CCFQX1/v7jKom3x9fSFJUrfaMgzILvX19fDzs+/4NbmP2WyGwWDoVlueZ0B26U7PoK6uDuHh4aioqOj2H6I79ed62TMgl5EkqdtfGIPBoIov1x0Pe73cgUhEABgGRCRjGJDTeXp6IiMjA56enu4upVtYbxvuQCQiAOwZEJGMYUBEABgGRCRjGBARAIYB9UBraytWrVqFsWPHwtvbGyEhIUhNTcWVK1c6Xe7NN9+EJEntHqNGjeqlqjv23nvvITIyEl5eXkhISMDp06e7XshFNmzYgIkTJ8LX1xeBgYGYM2cOLl682OkyWVlZ932uXl5edm+bYUB2a2pqgtFoxLp162A0GrFnzx5cvHgRKSkpXS77+OOPo7KyUnmcONG9cQFdZffu3XjttdeQkZEBo9GI8ePHY8aMGaiurnZLPbm5uVi+fDlOnjyJQ4cOobW1FdOnT0djY2OnyxkMhnafa1lZmf0bF0ROcPr0aQFAlJWVddgmIyNDjB8/vveK6ob4+HixfPlyZdpqtYqQkBCxYcMGN1b1rerqagFA5ObmdtgmMzNT+Pn5Obwt9gzIKcxmMyRJgr+/f6ftioqKEBISghEjRuCFF15AeXl57xT4ABaLBfn5+UhOTlbmeXh4IDk5GXl5eW6r625mc9uw7YMGDeq0XUNDAyIiIhAeHo5nnnkG58+ft3tbDANyWEtLC1atWoXnnnuu0wtnEhISkJWVhQMHDuD9999HaWkppkyZ4rbxEa5fvw6r1YqgoKB284OCglBVVeWWmu5ms9mQlpaGyZMnY8yYMR22i4mJwY4dO7B3715kZ2fDZrMhMTERly9ftm+DDvctqN/Lzs4W3t7eyuP48ePKaxaLRcyePVvExsYKs9ls13pramqEwWAQH3zwgbNL7haTySQAiC+++KLd/PT0dBEfH++Wmu62dOlSERERISoqKuxazmKxiKioKLF27Vq7luMlzNSllJQUJCQkKNOhoaEA2o4qzJs3D2VlZTh69Kjdl9P6+/tj5MiRKC4udmq93TV48GBoNBpcvXq13fyrV69i6NChbqnpjhUrVmDfvn04fvw4wsLC7FpWq9UiNjbW7s+VPxOoS76+voiOjlYeer1eCYKioiIcPnwYAQEBdq+3oaEBJSUlCA4OdkHVXdPpdIiLi8ORI0eUeTabDUeOHMGkSZPcUpMQAitWrMCnn36Ko0ePYvjw4Xavw2q14ty5c/Z/rnb1I4hEWzc0JSVFhIWFiYKCAlFZWak8bt26pbSbNm2a2LJlizK9cuVKkZOTI0pLS8Xnn38ukpOTxeDBg0V1dbU73oYQQoiPP/5YeHp6iqysLHHhwgWxePFi4e/vL6qqqtxSz7Jly4Sfn5/Iyclp97k2NTUpbRYsWCBWr16tTK9fv14cPHhQlJSUiPz8fPHss88KLy8vcf78ebu2zTAgu5WWlgoAD3wcO3ZMaRcRESEyMjKU6fnz54vg4GCh0+lEaGiomD9/viguLu79N3CPLVu2iGHDhgmdTifi4+PFyZMn3VZLR59rZmam0mbq1Kli4cKFynRaWppSf1BQkHj66aeF0Wi0e9u8hJmIAHCfARHJGAZEBIBhQEQyhgERAWAYEJGMYUBEABgGRCRjGFC/k5SUhLS0NGU6MjISmzdvdls9asEwoH7vzJkzWLx4sVPX+c0332DRokUYPnw49Ho9oqKikJGRAYvF4tTt9CZetUj93pAhQ5y+zq+//ho2mw3btm1DdHQ0CgsL8fLLL6OxsRHvvPOO07fXG9gzIFVrbGxEamoqfHx8EBwcjE2bNt3X5t6fCZIkYdu2bZg1axYGDhyI0aNHIy8vD8XFxUhKSoK3tzcSExNRUlLS4XZnzpyJzMxMTJ8+HSNGjEBKSgpef/117NmzxxVvs1cwDEjV0tPTkZubi7179+Kzzz5DTk4OjEZjl8u99dZbSE1NRUFBAUaNGoXnn38eS5YswZo1a3D27FnlUmJ7mM3mLocn69McvcqKyF3q6+uFTqcTn3zyiTLvxo0bQq/Xi1deeUWZFxERId59911lGkC7UYDy8vIEAPHhhx8q8z766CPh5eXV7VqKioqEwWAQ27dv79mb6QPYMyDVKikpgcViaTcK06BBgxATE9PlsuPGjVOe3xkDcezYse3mtbS0oK6urst1mUwmzJw5E3PnzsXLL79sz1voUxgG9FDSarXKc0mSOpxns9k6Xc+VK1fw1FNPITExEdu3b3dBpb2HYUCqFRUVBa1Wi1OnTinzampqcOnSpV7ZvslkQlJSEuLi4pCZmQkPD3V/nXhokVTLx8cHixYtQnp6OgICAhAYGIg33nijV76Ud4IgIiIC77zzDq5du6a85u7BVHuKYUCqtnHjRjQ0NGD27Nnw9fXFypUrlRuPuNKhQ4dQXFyM4uLi+0YvFiodPIzDnhERAO4zICIZw4CIADAMiEjGMCAiAAwDIpIxDIgIAMOAiGQMAyICwDAgIhnDgIgAMAyISMYwICIAwP8DINNACOSxZ0sAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "_ = pairplot(samples, limits=[[-3, 3], [-3, 3]], figsize=(3, 3), upper=\"contour\")" ] @@ -317,24 +364,25 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "id": "8dbe6e31-e682-419f-b88f-23b1e833eaf8", "metadata": {}, "outputs": [], "source": [ "class NPEData(torch.utils.data.Dataset):\n", "\n", - " def __init__(self, nsamples:int,\n", - " prior: torch.distributions.Distribution ,\n", + " def __init__(self,\n", + " num_samples: int,\n", + " prior: torch.distributions.Distribution,\n", " simulator: Callable,\n", - " seed:int = 44):\n", + " seed: int = 44):\n", " super().__init__()\n", "\n", " torch.random.manual_seed(seed) #will set the seed device wide\n", " self.prior = prior\n", " self.simulator = simulator\n", "\n", - " self.theta = prior.sample((nsamples,))\n", + " self.theta = prior.sample((num_samples,))\n", " self.x = simulator(self.theta)\n", "\n", " def __len__(self):\n", @@ -356,12 +404,12 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "id": "eb2aa734-dfca-4691-b3b5-f5be889219ba", "metadata": {}, "outputs": [], "source": [ - "train_data = NPEData(2048, prior, simulator)\n", + "train_data = NPEData(num_samples=2048, prior=prior, simulator=simulator)\n", "train_loader = torch.utils.data.DataLoader(train_data, batch_size=128)" ] }, @@ -377,7 +425,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "id": "2e61a11a-b9d6-4211-8456-cb4ff8755d1e", "metadata": {}, "outputs": [], @@ -392,15 +440,32 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "id": "7fe6624f-ce79-43dd-b661-1cced75dca09", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "last loss 4.49238920211792\n", + "last loss -1.2831010818481445\n", + "last loss -1.5764970779418945\n", + "last loss -1.6195335388183594\n", + "last loss -1.6439297199249268\n", + "last loss -1.6492975950241089\n", + "last loss -1.6488871574401855\n", + "last loss -1.6473512649536133\n", + "last loss -1.6515816450119019\n", + "last loss -1.6809775829315186\n" + ] + } + ], "source": [ "optw = AdamW(list(maf_estimator.parameters()), lr=5e-4)\n", - "nepochs = 50\n", + "num_epochs = 100\n", "\n", - "for ep in range(nepochs):\n", + "for ep in range(num_epochs):\n", " for idx, (theta_batch, x_batch) in enumerate(train_loader):\n", " optw.zero_grad()\n", " losses = maf_estimator.loss(theta_batch, condition=x_batch)\n", @@ -413,10 +478,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "id": "cc63ae5f-33e2-422b-9365-f183a53c09b9", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Shape of samples: torch.Size([1000, 1, 2]) # Samples are returned with a batch dimension.\n", + "Shape of samples: torch.Size([1000, 2]) # Removed batch dimension.\n" + ] + } + ], "source": [ "# let's compare the trained estimator to the NSF from above\n", "samples = maf_estimator.sample((1000,), condition=x_o).detach()\n", @@ -428,10 +502,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "id": "cd309901-76e1-4ebc-87e3-f1fa32161185", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQMAAAExCAYAAAB4Y8dtAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAWa0lEQVR4nO3de2xUdd7H8c9M6bTTu9xqW7DcIhhBQhrbgGvsun3AGKhmE8FLqHGJXAIbu7INGsVKzIZNlKxJYxSjtkSS1d0EQpZNVARagoJomyaAEWlTC5RLvbRDL1OmzPyePzjMWqXQ6UzncOr7lUwyM5wz8+3Evj0z03OOyxhjBOA3z233AABuDsQAgCRiAMBCDABIIgYALMQAgCRiAMBCDABIIgYALGPsHgDO83/uR+weAUOwJ/TviJZnywCAJGIAwEIMAEgiBiOirdOvtk6/3WMAESEGMdbW6VfJljqVbKkjCHAUYhBjHT0B+fuD8vcH1dETsHscYMiIQQwda/NpcdVBu8cAhoUYxBBvC+BkxCCG8rK8do8ADBsxACCJGACwEAMAkogBAAsxACCJGACwEAMAkogBAAsxACCJGACwEAMAkohBTP3ILstwMGIQI22dfq1+v97uMYBhIwYxcvWgJhWLZto9CjAsxCDG7rt9gnb/+Xd2jwFEjBgAkEQMAFiIAQBJxACAhROvArHgcv36PmPiP0cUiAEQLXeCJMnlHhgEEwxaV5wRBWIADJfLJdeYRLlTvVJCglwpKZIk4/dLly8r1OOXudxv85BDRwyA4bgagrRUKWeCQike+XOuxMB7rlfunktyn2tXqDt0ZQvBAVsHxAAYBteYRLm9yXJlpsufm65A1hj5piVIRhprUpTUMUaJHcmSv08KGckE7R75hogBECl3wpW3BhPHq3faWJ0pHqNg7iW9VLhTfaFEbflPqdK/8yq76xa5+/oUCvllQsQAGJ1cbmlMgoJJLl3ODCojw69ZnnPqMR6FPEahMe5rf8NwEyMGQKRCQSkYlC4HJZfkSr0sz5igdvoK1H4pXcntbnl/DMnd5VfQQR8iEgNgGEwwKPfloNz9RqbfrZ4+j4525uonf4oSu6XE3pAU6HfMh4cSMQCGxVy6pNCF75Xidinnk2wFPelq96QrISDd+k2X3J09CnV0XtmKcAhiAAyDuXxZJmSU8GOnMppSZBLcchlJl0NKaO+Q6fXLXLpk95gRIQbAcIWCCvX2yn3me7lcLplg6Mp9Pb0ywdD//gLRIYgBEAUTCCj0U+eV6/3OPgYmMQCiYYz1IWHI7kmiRgyAaDnoQ8Lr4XgGACQRAwAWYgBAEjEAYCEGMcKp1eB0xCAGrp5azZuYoFtSPXaPAwwLMYiBq6dWe2t5gfKyvHaPAwwLMYihcWwVwMGIAQBJxACAhRiMIL5hgJMQgxFwS6pH3sQErX6/Xm2dfrvHAYaEGIyAvCyv3lpeIH9/UB1sHcAhiMEI4ZsFOA0xACCJGACwEAMAkogBAAsxACCJGACwEAMAkogBAAsxACCJGACwEAMAkogBAAsxACCJGACwEAMAkogBAAsxACCJGACwEAMAkogBAAsxACCJGACwEAMAkogBAAsxACCJGMQEJ1jFaEAMotTW6dfq9+vlTUzQLZxSDQ42xu4BnK6jJyB/f1Db/lSovCyv3eMAw8aWQYwMdqJV3kLAKYjBCLkl1SNvYoJWv1+vtk6/3eMAN0QMRkhelldvLS+Qvz+oDrYO4ADEYAQN9tYBuBkRAwCSiAEACzEAIIkYALAQAwCSiAEACzEAIIkYALAQAwCSiEHU2BEJowUxiMJQj2WwuOqgjrX54jgZEDliEIWrxzJ4a3nBNY9l8PNAsOcibnbEIAYG2yEpL8urrcsLwteBmxkxiMJQPi8gAnAKYjBMHPsQo01Ux0Bsv9in9q5LsZrFUZrauyM69mFTe3ccpvq12XmZtjwvnMdljDF2DwHAfrxNACCJGACwEAMAkogBAMuwv00wxqirqyuWs2CEpKeny+Vy2T0GbnLDjkFXV5cyM/naygl8Pp8yMjLsHgM3uWF/tXijLYOLFy9q8uTJOn36tCP+QxzN87JlgKEY9paBy+Ua0i9NRkaGI365rmJe/FbxASIAScQAgGXEYpCUlKTKykolJSWN1FPEFPPit459EwBI4m0CAAsxACCJGACwEAMAkqKMQX9/vzZs2KA5c+YoNTVVubm5Kisr09mzZ6+73ssvvyyXyzXgMmvWrGhGidobb7yhKVOmKDk5WUVFRTpy5Iit82zevFl333230tPTNXHiRD388MM6ceLEddepqan51euanJwcp4nhdFHFoLe3Vw0NDdq4caMaGhq0Y8cOnThxQqWlpTdc984779S5c+fCl4MHD0YzSlQ+/PBDPfvss6qsrFRDQ4Pmzp2rRYsWqb293baZ6urqtHbtWh0+fFh79uxRf3+/Fi5cqJ6enuuul5GRMeB1bW1tjdPEcDwTY0eOHDGSTGtr66DLVFZWmrlz58b6qYetsLDQrF27Nnw7GAya3Nxcs3nzZhunGqi9vd1IMnV1dYMuU11dbTIzM+M3FEaVmH9m4PP55HK5lJWVdd3lTp48qdzcXE2bNk1PPPGETp06FetRhiQQCKi+vl4lJSXh+9xut0pKSnTo0CFbZroWn+/KGZnGjh173eW6u7uVn5+vyZMn66GHHtLx48fjMR5GgZjGoK+vTxs2bNBjjz123Z1nioqKVFNTo48++khvvvmmWlpadO+999pyfIQffvhBwWBQ2dnZA+7Pzs7W+fPn4z7PtYRCIZWXl+uee+7R7NmzB11u5syZeu+997Rr1y5t375doVBICxYs0JkzZ+I4LRwrks2I7du3m9TU1PDlwIED4X8LBAJmyZIlZt68ecbn80W0edLR0WEyMjLMO++8E9F6sdDW1mYkmc8//3zA/RUVFaawsDDu81zL6tWrTX5+vjl9+nRE6wUCATN9+nTz4osvjtBkGE0i2oW5tLRURUVF4dt5eXmSrnyrsHTpUrW2tmrfvn0R71KblZWl22+/XU1NTRGtFwvjx49XQkKCLly4MOD+Cxcu6NZbb437PL+0bt067d69WwcOHNCkSZMiWjcxMVHz5s2z5XWF80T0NiE9PV0zZswIX7xebzgEJ0+e1Keffqpx48ZFPER3d7eam5uVk5MT8brR8ng8Kigo0N69e8P3hUIh7d27V/Pnz4/7PFcZY7Ru3Trt3LlT+/bt09SpUyN+jGAwqKNHj9ryusKBotmsCAQCprS01EyaNMk0Njaac+fOhS+XLl0KL3f//febqqqq8O3169eb2tpa09LSYj777DNTUlJixo8fb9rb26MZZ9g++OADk5SUZGpqaszXX39tVq5cabKyssz58+dtmccYY9asWWMyMzNNbW3tgNe1t7c3vMzy5cvNc889F769adMm8/HHH5vm5mZTX19vHn30UZOcnGyOHz9ux48Ah4kqBi0tLUbSNS/79+8PL5efn28qKyvDt5ctW2ZycnKMx+MxeXl5ZtmyZaapqSmaUaJWVVVlbrvtNuPxeExhYaE5fPiwrfMM9rpWV1eHl7nvvvvMk08+Gb5dXl4e/hmys7PNgw8+aBoaGuI/PByJXZiBGGrr9Ety5tm32TcBiJG2Tr9KttSpZEtdOApOQgyAGDnW5pO/Pyh/f1AdPQG7x4kYMQBioK3Tr1Xv19s9RlSIARADTtwS+CViAEASMQBgIQYAJBEDAJZRG4Pi4mKVl5eHb0+ZMkWvv/66bfMAN7thn3jVab788kulpqbG/HH/9re/6b///a8aGxvl8XjU2dkZ8+cA4mHUbhn80oQJE5SSkhLzxw0EAnrkkUe0Zs2amD82EE+jIgY9PT0qKytTWlqacnJytGXLll8t88u3CS6XS1u3btXixYuVkpKiO+64Q4cOHVJTU5OKi4uVmpqqBQsWqLm5+brPvWnTJv3lL3/RnDlzYv1jAXE1KmJQUVGhuro67dq1S5988olqa2vV0NBww/VeeeUVlZWVqbGxUbNmzdLjjz+uVatW6fnnn9dXX30VPqYAMFR//6Nz/6fg+M8Muru79e6772r79u36wx/+IEnatm3bkI4K9NRTT2np0qWSpA0bNmj+/PnauHGjFi1aJEl65pln9NRTT43c8MBNxPFbBs3NzQoEAgMOxzZ27FjNnDnzhuvedddd4etXD4j688397Oxs9fX16eLFizGcGLg5OT4G0UhMTAxfd7lcg94XCoXiOxhgA8fHYPr06UpMTNQXX3wRvq+jo0PffvutjVMBzuP4zwzS0tK0YsUKVVRUaNy4cZo4caJeeOEFud3x6dypU6f0008/6dSpUwoGg2psbJQkzZgxQ2lpaXGZAYgFx8dAkl599VV1d3dryZIlSk9P1/r168NnIBppL730krZt2xa+PW/ePEnS/v37VVxcHJcZgFjgGIhADBxr82lx1UH9/Y9z9NyOo9r9599pdl6m3WNFxPGfGQCIDWIAQBIxAGAhBgAkEQMAFmIAQBIxAGAhBgAkEQMAFmIAQBIxAKJ29U+RnY4YAFE61hafneJGGjEAIIkYALAQAyBKt6R6rnndaYgBEKW8LK8kaevygvB1JyIGQIw4OQQSMQBgIQYAJBEDABZiAEASMQBgIQYAJBEDABZiAEASMQBgIQYAJBEDABZiAEASMQBgIQbACPixJ2D3CBEjBkAM3ZLqkTcxQavfr1dbp9/ucSJCDIAYysvy6q3lBfL3B9XhsK0DYgDE2DiHHvqMGACQRAwAWIgBAEnEAICFGACQRAwAWIgBAEnEAICFGABRcuJ+CNdCDIAotHX6tfr9enkTExx90lVJGmP3AICTdfQE5O8PatufCjnXIgDn7o/wc8QAGCHswgz8xl2NwCqHHdOAGAAx9vPjGDjpmAbEAIAkYgDAQgwASCIGACzEAIix2XmZdo8wLMQAiLHZeZna/eff2T1GxIgBAEnEAICFGACQRAwAWIgBAEnEAICFGACQRAwAWIgBAEnEAICFGACQRAyAqIyWcyZIxAAYttF0zgSJ8yYAwzaazpkgsWUARG00nDNBIgYALMQAgCRiAMBCDIARtLjqoI61+eweY0iIATBMQ/0bA2IAjGKj7W8MJP7OABiW0fY3BhJbBkBUBvsbAyduLRADYATkZXn179XzJUmb/vO1I07NTgyAEXL3lLHa9qdC+fuDjjg1O58ZIGLtF/vU3nXJ7jFs1dTePaTlrr6NGOrysRTpad5cxhgzQrMAcBDeJgCQRAwAWIgBAEnEAICFGACQxFeLiJAxRl1dXXaPgSFKT0+Xy+Ua0rLEABHp6upSZmZk31/DPj6fTxkZGUNalr8zQESGsmVw8eJFTZ48WadPnx7yf4h2Gs3zsmWAEeNyuYb8C5ORkeGIX66rfuvz8gEiAEnEAICFGCDmkpKSVFlZqaSkJLtHGRLmvYIPEAFIYssAgIUYAJBEDABYiAEAScQAw9Df368NGzZozpw5Sk1NVW5ursrKynT27Nnrrvfyyy/L5XINuMyaNStOUw/ujTfe0JQpU5ScnKyioiIdOXLEtlk2b96su+++W+np6Zo4caIefvhhnThx4rrr1NTU/Op1TU5Ojvi5iQEi1tvbq4aGBm3cuFENDQ3asWOHTpw4odLS0huue+edd+rcuXPhy8GDB+Mw8eA+/PBDPfvss6qsrFRDQ4Pmzp2rRYsWqb293ZZ56urqtHbtWh0+fFh79uxRf3+/Fi5cqJ6enuuul5GRMeB1bW1tjfzJDRADR44cMZJMa2vroMtUVlaauXPnxm+oISgsLDRr164N3w4GgyY3N9ds3rzZxqn+p7293UgydXV1gy5TXV1tMjMzo34utgwQEz6fTy6XS1lZWddd7uTJk8rNzdW0adP0xBNP6NSpU/EZ8BoCgYDq6+tVUlISvs/tdqukpESHDh2yba6f8/munKdx7Nix112uu7tb+fn5mjx5sh566CEdP3484uciBohaX1+fNmzYoMcee+y6O84UFRWppqZGH330kd588021tLTo3nvvte34CD/88IOCwaCys7MH3J+dna3z58/bMtPPhUIhlZeX65577tHs2bMHXW7mzJl67733tGvXLm3fvl2hUEgLFizQmTNnInvCqLctMOpt377dpKamhi8HDhwI/1sgEDBLliwx8+bNMz6fL6LH7ejoMBkZGeadd96J9chD0tbWZiSZzz//fMD9FRUVprCw0JaZfm716tUmPz/fnD59OqL1AoGAmT59unnxxRcjWo9dmHFDpaWlKioqCt/Oy8uTdOVbhaVLl6q1tVX79u2LeHfarKws3X777WpqaorpvEM1fvx4JSQk6MKFCwPuv3Dhgm699VZbZrpq3bp12r17tw4cOKBJkyZFtG5iYqLmzZsX8evK2wTcUHp6umbMmBG+eL3ecAhOnjypTz/9VOPGjYv4cbu7u9Xc3KycnJwRmPrGPB6PCgoKtHfv3vB9oVBIe/fu1fz5822ZyRijdevWaefOndq3b5+mTp0a8WMEg0EdPXo08tc1ou0IwFzZDC0tLTWTJk0yjY2N5ty5c+HLpUuXwsvdf//9pqqqKnx7/fr1pra21rS0tJjPPvvMlJSUmPHjx5v29nY7fgxjjDEffPCBSUpKMjU1Nebrr782K1euNFlZWeb8+fO2zLNmzRqTmZlpamtrB7yuvb294WWWL19unnvuufDtTZs2mY8//tg0Nzeb+vp68+ijj5rk5GRz/PjxiJ6bGCBiLS0tRtI1L/v37w8vl5+fbyorK8O3ly1bZnJycozH4zF5eXlm2bJlpqmpKf4/wC9UVVWZ2267zXg8HlNYWGgOHz5s2yyDva7V1dXhZe677z7z5JNPhm+Xl5eH58/OzjYPPvigaWhoiPi52YUZgCQ+MwBgIQYAJBEDABZiAEASMQBgIQYAJBEDABZigFGnuLhY5eXl4dtTpkzR66+/bts8TkEMMOp9+eWXWrlyZUwf87vvvtOKFSs0depUeb1eTZ8+XZWVlQoEAjF9nnhir0WMehMmTIj5Y37zzTcKhULaunWrZsyYoWPHjunpp59WT0+PXnvttZg/XzywZQBH6+npUVlZmdLS0pSTk6MtW7b8aplfvk1wuVzaunWrFi9erJSUFN1xxx06dOiQmpqaVFxcrNTUVC1YsEDNzc2DPu8DDzyg6upqLVy4UNOmTVNpaan++te/aseOHSPxY8YFMYCjVVRUqK6uTrt27dInn3yi2tpaNTQ03HC9V155RWVlZWpsbNSsWbP0+OOPa9WqVXr++ef11VdfhXcljoTP57vh4cluatHuZQXYpaury3g8HvOvf/0rfN+PP/5ovF6veeaZZ8L35efnm3/84x/h25IGHAXo0KFDRpJ59913w/f985//NMnJyUOe5eTJkyYjI8O8/fbbw/thbgJsGcCxmpubFQgEBhyFaezYsZo5c+YN173rrrvC168eA3HOnDkD7uvr69PFixdv+FhtbW164IEH9Mgjj+jpp5+O5Ee4qRAD/CYlJiaGr7tcrkHvC4VC132cs2fP6ve//70WLFigt99+ewQmjR9iAMeaPn26EhMT9cUXX4Tv6+jo0LfffhuX529ra1NxcbEKCgpUXV0tt9vZv058tQjHSktL04oVK1RRUaFx48Zp4sSJeuGFF+LyS3k1BPn5+Xrttdf0/fffh//N7oOpDhcxgKO9+uqr6u7u1pIlS5Senq7169eHTzwykvbs2aOmpiY1NTX96ujFxqEHD+OwZwAk8ZkBAAsxACCJGACwEAMAkogBAAsxACCJGACwEAMAkogBAAsxACCJGACwEAMAkqT/ByvNTY5hz0AGAAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "_ = pairplot(samples, limits=[[-3, 3], [-3, 3]], figsize=(3, 3))\n" ] @@ -458,28 +543,28 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "id": "794d401c-9410-4f39-9fd8-787a90a391cc", "metadata": {}, "outputs": [], "source": [ - "from sbi.neural_nets.net_builders import build_nsf\n", "from sbi.inference.posteriors import MCMCPosterior\n", "from sbi.inference.potentials import likelihood_estimator_based_potential" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "id": "ccb162e2-c4ed-4268-82f7-81ef79ab8ddb", "metadata": {}, "outputs": [], "source": [ - "density_estimator = build_nsf(x, theta) # Note that the order of x and theta are reversed in comparison to NPE.\n", + "# Note that the order of x and theta are reversed in comparison to NPE.\n", + "density_estimator = build_nsf(x, theta)\n", "\n", "# Training loop.\n", "opt = Adam(list(density_estimator.parameters()), lr=5e-4)\n", - "for _ in range(100):\n", + "for _ in range(200):\n", " opt.zero_grad()\n", " losses = density_estimator.loss(x, condition=theta)\n", " loss = torch.mean(losses)\n", @@ -500,13 +585,24 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 18, "id": "cb208afd-5ec0-4995-82ee-8611ade423e3", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQMAAAExCAYAAAB4Y8dtAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAX9ElEQVR4nO3dfVAUZ54H8G+DMzDCAIqKgAQUVkwpuhQrnGRTIYbTVE5JbqvUvJR4xoovq7shMRSmEpdY2T3vTq1YxeYSc0nAkquN+cOUdVoX4xt4KsbIFLdqEoRZgjoiqIHhRXBw5rk/bPokEZhmeqZp/H6qpopuuqd/M8V8eeZ5up+WhBACRPTQC9K7ACIaGRgGRASAYUBEMoYBEQFgGBCRjGFARAAYBkQkYxgQEQCGARHJGAZEBIBhQEQyhgERAWAYEJGMYaAxR1s3HG3depdBpBrDQEOOtm7k7qhE7o5KBgIZDsNAQ61dLnT3utHd60Zrl0vvcohUYRgQEQCGARHJGAZEBIBhQEQyhgERAWAYEJGMYUBEABgGRCRjGBARAIYBEckYBkQEgGFARDKGAREBYBj4zaKSk7jgcOpdBpHXGAZ+xDAgI2EYEBEAhgERyRgGRASAYUBEMoaBH40LM+tdAgWAEAK3Ou/oXYbPGAZ+8Pv5KQCA+CiLzpVQIPx7hR0ZfzyCrf/9HYQQepczbGP0LmA0imMIPFTO/O0WAGBX5d9gCgrC75/6BSQJMAUb638tw4DIBz9tCfz5eD3+fLwekgQsyZiCf/7HNIwxSCgYo0qiEUgIgT8e/A7/U3cTAPAPs2MRMiZI/h3w+bmreHVvDe66PXqW6TW2DIhUcN31YMt/XcSJuhu46xZocvYAACZaQ/DXq22YaDVDCKDXLXCz8w4O/rUJY4Ik7Fz2S0iSpHP1g2MYEHnJddeD3/6nDUe+a1bWSQAEgBsdDx5NCJKA/TXXMGNyBNblJAem0GHi1wQiL722twZHvmvGmCAJj8ZaMWdKJIYaO/DIG/zboe9x2n7T7zX6gmFA5IUbHXdw8HwTAMAcLOG7pg7871XvLkQLNQVBCGCfzeHPEn3GMCDywl3PvU5Ac3AQIH/3t4aMwZOpE4fcN0jefqR3JDIMiLxgMQUjSAJcbo/y4e64cxfHa28g0vLgrrcJ4ffOQPXI3xXGhozsLrqRXR3RCBE11ox1Ocl4/7gdnXfu4pcJUQgPCcbXDT/C2X0XjyVH4+9nxijbf9PwIw6evw4A6LnrwVhzMP4pO0mn6r3DMCDy0hsLUuH2AB9W2lFzpU1ZHxwEnLLfwin7rX7b9400AMD2JXMwPcYasFqHg2FA5CVJklD0dCrix1lw4tIN3HV7UHHpBtweICl6LFImhSvnEly+1YXa5k4AwBsLpuOZtFg9S/cK+wyIVJAkCcv/LhH/kf8rlK7MxL/+ZjYkCfjh1m3ER1nw0fIMZCSOU4KgIPcX2DD/FzpX7R22DDRyweHEopKTepdBAbZ0bgIAoGjfX7G7qhEXrrWjurEVwL0gKMidrmd5qjAMNMLJTx9e9weCUYMAYBgQaWLp3ARIErD9q1qsyE7Cb3NS9C5JNYaBRu6f1YgzHD2clvwqAUt+laB3GcPGDkSN9M1qtGt5Bmc4IkNiGGiMQUBGxTAgIgAMAyKSMQyICADDgIhkDAM/utXl0rsEIq8xDPxgXJgZFlMw1u6phqOtW+9yiLzCMPCD+CgLPlyege5eN1rZOiCDYBj4STTPQiSDYRgQEQCGARHJGAZEBIBhQEQyhgERAWAYEJGMYUBEABgGRCRjGBARAIYBEckYBkQEgGFARDKGgUY4dwEZHcNAA462bqzdUw2LKZj3TCDD4k1UNNDa5UJ3rxu7X87kVOlkWGwZaIhzGJCRMQyICADDgIhkDAMiAsAwICIZw4CIADAM/I4nI5FRMAz8hDdSIaNhGPgJb6RCRsMw8COehERGwjAgIgAMAyKSMQyICADDgIhkDAMiAsAwICIZw4CIADAMiEjGMCAiAAwDIpIxDIgIAMMgIHgZMxkBw0ADA33YeRkzGQnDwEeD3UCFlzGTkfAmKj4a6gYqvIyZjIItA43wQ09GxzAgIgAMAyKSMQyICADDwGc8h4BGC4aBDwYbViQyGg4t+mCoYUUiI2HLQAMcVqTRgGEwTI62blWnGC8qOYkLDqcfKyLyzaj5mtD3wfR3c93R1o1rbd3I/+QsunvdqvatvHQDs+IjvToG4P/XQnQ/n8Kgpb0HLR13tKpl2G51ubB2TzUA4MPlGX5rtvcd5/4QGKrz8P7fbTtUi22HarFz2S+RMil80GMA2rwWb8KHCAAkIYTQuwgi0h/7DIgIAMOAiGQMAyICwDAgItmwRxOEEOjo6NCyFvITq9UKSZL0LoNGuGGHQUdHByIjOWxlBE6nExEREXqXQSPcsIcWh2oZtLe3IyEhAVeuXDHEH+JorpctA/LGsFsGkiR59aGJiIgwxIerD+ulhxU7EIkIAMOAiGR+C4OQkBAUFxcjJCTEX4fQFOulhx2vTSAiAPyaQEQyhgERAWAYEJGMYUBEAHwMg97eXhQVFSEtLQ1hYWGIi4tDfn4+rl27Nuh+77zzDiRJ6veYMWOGL6X47P3330dSUhJCQ0ORlZWFs2fP6lrP1q1bMXfuXFitVkyaNAnPPfccamtrB92nrKzsZ+9raGhogComo/MpDG7fvg2bzYbNmzfDZrNh3759qK2tRV5e3pD7zpw5E01NTcrj5MmTvpTik7179+L1119HcXExbDYb5syZg4ULF6KlpUW3miorK7F+/XqcOXMGhw8fRm9vLxYsWICurq5B94uIiOj3vjY2NgaoYjI8obGzZ88KAKKxsXHAbYqLi8WcOXO0PvSwZWZmivXr1yvLbrdbxMXFia1bt+pYVX8tLS0CgKisrBxwm9LSUhEZGRm4omhU0bzPwOl0QpIkREVFDbpdXV0d4uLiMG3aNLz00ku4fPmy1qV4xeVyobq6Grm5ucq6oKAg5ObmoqqqSpeaHsTpvDfN+vjx4wfdrrOzE4mJiUhISMCzzz6LixcvBqI8GgU0DYOenh4UFRXhhRdeGPTimaysLJSVleHLL7/EBx98gIaGBjz++OO6zI9w8+ZNuN1uxMTE9FsfExOD69evB7yeB/F4PCgoKMBjjz2GWbNmDbhdamoqPv30U+zfvx/l5eXweDzIzs7G1atXA1gtGZaaZkR5ebkICwtTHidOnFB+53K5xOLFi0V6erpwOp2qmietra0iIiJCfPzxx6r204LD4RAAxOnTp/utLywsFJmZmQGv50HWrl0rEhMTxZUrV1Tt53K5RHJysnj77bf9VBmNJqouYc7Ly0NWVpayHB8fD+DeqMLSpUvR2NiIY8eOqb6kNioqCtOnT0d9fb2q/bQwYcIEBAcHo7m5ud/65uZmTJ48OeD1/NSGDRtw4MABnDhxAlOmTFG1r8lkQnp6ui7vKxmPqq8JVqsVKSkpysNisShBUFdXhyNHjiA6Olp1EZ2dnbDb7YiNjVW9r6/MZjMyMjJw9OhRZZ3H48HRo0cxb968gNfTRwiBDRs24IsvvsCxY8cwdepU1c/hdrtx/vx5Xd5XMiBfmhUul0vk5eWJKVOmiJqaGtHU1KQ87ty5o2w3f/58UVJSoixv3LhRVFRUiIaGBnHq1CmRm5srJkyYIFpaWnwpZ9g+++wzERISIsrKysS3334rVq9eLaKiosT169d1qUcIIdatWyciIyNFRUVFv/f19u3byjbLly8XmzZtUpa3bNkiDh06JOx2u6iurhbPP/+8CA0NFRcvXtTjJZDB+BQGDQ0NAsADH8ePH1e2S0xMFMXFxcrysmXLRGxsrDCbzSI+Pl4sW7ZM1NfX+1KKz0pKSsQjjzwizGazyMzMFGfOnNG1noHe19LSUmWbJ554QqxYsUJZLigoUF5DTEyMeOaZZ4TNZgt88WRIvISZiADw2gQiTV1wOJG06SAuOJx6l6Iaw4BIQ30hwDAgIsNiGBARAIYBEckYBkQaueBwYtO+83qXMWwMAyKNGLHT8H4MAyICMIrDICcnBwUFBcpyUlISdu7cqVs9NPrd6nLpXYJPhn3jVaP55ptvEBYWpvnz/ulPf8LBgwdRU1MDs9mMtrY2zY9BI5+jrRvbDg0+R+VIN2pbBj81ceJEjB07VvPndblcWLJkCdatW6f5c5NxtBq8VQCMkjDo6upCfn4+wsPDERsbix07dvxsm59+TZAkCbt27cKiRYswduxYPProo6iqqkJ9fT1ycnIQFhaG7Oxs2O32QY+9ZcsWvPbaa0hLS9P6ZREF1KgIg8LCQlRWVmL//v346quvUFFRAZvNNuR+7777LvLz81FTU4MZM2bgxRdfxJo1a/Dmm2/i3LlzypwCRA8Dw/cZdHZ24pNPPkF5eTmeeuopAMDu3bu9mhVo5cqVWLp0KQCgqKgI8+bNw+bNm7Fw4UIAwKuvvoqVK1f6r3iiEcTwLQO73Q6Xy9VvOrbx48cjNTV1yH1nz56t/Nw3Ier9zf2YmBj09PSgvb1dw4qJRibDh4EvTCaT8rMkSQOu83g8gS2MDG9cmFnvElQzfBgkJyfDZDLh66+/Vta1trbi0qVLOlZFD6vfz08BAMRHWXSuRD3D9xmEh4dj1apVKCwsRHR0NCZNmoS33noLQUGBybnLly/jxx9/xOXLl+F2u1FTUwMASElJQXh4eEBqoJEjzoAh0MfwYQAA27ZtQ2dnJxYvXgyr1YqNGzcqdyDytz/84Q/YvXu3spyeng4AOH78OHJycgJSA5EWOAcikQYuOJxYVHIS//KbNGzadx4HfvdrzIqP1LssVQzfZ0BE2mAYEGnA6BcpAQwDIp852rqxdk81LKZgQw4p9hkVHYhEemrtcqG7143dL2ci2sBhwJYBkUaMHAQAw4CIZAwDIgLAMCDyCyOOLjAMiDQ0LswMiykYa/dUw9HWrXc5qjAMiDQUH2XBh8sz0N3rNtxUaAwDIo0ZdVSBYUBEABgGRCRjGBARAIYBEckYBkQEgGFARDKGAREBYBgQkYxhQEQAGAZEJGMYEBEAhgERyRgGRASAYUBEMoYBEQFgGBCRjGFARAAYBkQkYxgQEQCGARHJGAZEBIBhQEQyhgERAWAYEJGMYUBEABgGRCRjGBD5yIh3XH4QhgGRDxxt3Vi7pxoWUzDGGfQei33G6F0AkZG1drnQ3evG7pczER9l0bscn7BlQKQBo955+X4MAyICwDAgIhnDgIgAMAyISMYwIPITo51/wDAg0ti4MDMspmCs3VMNR1u33uV4jWFApLH4KAs+XJ6B7l43Wg3UOmAYEPmBEc87YBgQEQCGARHJGAZEBIBhQORXRhpeZBgQ+YERhxcZBkR+YMThRYYBkZ8YbXiRYUBEABgGRCRjGBARAIYBEckYBkQEgGFARDKGAREBYBgQ+cRIpxsPhWFANEyj6W5KAO+oRDRso+luSgBbBkQ+M9ppxwNhGBAN02jqLwAYBkTDMtr6CwD2GRANy2jrLwDYMiAalr4JS0ZLfwHAMCBSzdHWjTV7qpWfh2KUvgWGAZFK989cNNhXBKNNfcYwIFKp7z/9gd/9GrPiIwfczmhTnzEMiFRQO4pgpD4FjiYQqTAaRxH6sGVANAxG+o/vLYYBUQAYYUSBYUCkgtoPtZFGFBgGRPdxtHU/8EPraOvGoYvXVZ+CfP+IwgWHU+tyNcUORFKtpb0HLR139C5DU/UtnSjYWwMAsJiC8eHyDKVf4FaXCys+Patse+B3v1bVedjT6wYArNlTjd0vZwasv2GwYc8HkYQQwk+1EJGB8GsCEQFgGBCRjGFARAAYBkQkYxgQEQAOLZJKQgh0dHToXQZ5yWq1QpIkr7ZlGJAqHR0diIxUN35N+nE6nYiIiPBqW55nQKp40zJob29HQkICrly54vUfop5Gc71sGZDfSJLk9QcmIiLCEB+uPg97vexAJCIADAMikjEMSHMhISEoLi5GSEiI3qV4hfXeww5EIgLAlgERyRgGRASAYUBEMoYBEQFgGNAw9Pb2oqioCGlpaQgLC0NcXBzy8/Nx7dq1Qfd75513IElSv8eMGTMCVPXA3n//fSQlJSE0NBRZWVk4e/bs0Dv5ydatWzF37lxYrVZMmjQJzz33HGprawfdp6ys7Gfva2hoqOpjMwxItdu3b8Nms2Hz5s2w2WzYt28famtrkZeXN+S+M2fORFNTk/I4efJkACoe2N69e/H666+juLgYNpsNc+bMwcKFC9HS0qJLPZWVlVi/fj3OnDmDw4cPo7e3FwsWLEBXV9eg+0VERPR7XxsbG9UfXBBp4OzZswKAaGxsHHCb4uJiMWfOnMAV5YXMzEyxfv16Zdntdou4uDixdetWHav6fy0tLQKAqKysHHCb0tJSERkZ6fOx2DIgTTidTkiShKioqEG3q6urQ1xcHKZNm4aXXnoJly9fDkyBD+ByuVBdXY3c3FxlXVBQEHJzc1FVVaVbXfdzOu9Nrz5+/PhBt+vs7ERiYiISEhLw7LPP4uLFi6qPxTAgn/X09KCoqAgvvPDCoBfOZGVloaysDF9++SU++OADNDQ04PHHH9dtfoSbN2/C7XYjJiam3/qYmBhcv35dl5ru5/F4UFBQgMceewyzZs0acLvU1FR8+umn2L9/P8rLy+HxeJCdnY2rV6+qO6DPbQsa9crLy0VYWJjyOHHihPI7l8slFi9eLNLT04XT6VT1vK2trSIiIkJ8/PHHWpfsFYfDIQCI06dP91tfWFgoMjMzdanpfmvXrhWJiYniypUrqvZzuVwiOTlZvP3226r24yXMNKS8vDxkZWUpy/Hx8QDujSosXboUjY2NOHbsmOrLaaOiojB9+nTU19drWq+3JkyYgODgYDQ3N/db39zcjMmTJ+tSU58NGzbgwIEDOHHiBKZMmaJqX5PJhPT0dNXvK78m0JCsVitSUlKUh8ViUYKgrq4OR44cQXR0tOrn7ezshN1uR2xsrB+qHprZbEZGRgaOHj2qrPN4PDh69CjmzZunS01CCGzYsAFffPEFjh07hqlTp6p+DrfbjfPnz6t/X1W1I4jEvWZoXl6emDJliqipqRFNTU3K486dO8p28+fPFyUlJcryxo0bRUVFhWhoaBCnTp0Subm5YsKECaKlpUWPlyGEEOKzzz4TISEhoqysTHz77bdi9erVIioqSly/fl2XetatWyciIyNFRUVFv/f19u3byjbLly8XmzZtUpa3bNkiDh06JOx2u6iurhbPP/+8CA0NFRcvXlR1bIYBqdbQ0CAAPPBx/PhxZbvExERRXFysLC9btkzExsYKs9ks4uPjxbJly0R9fX3gX8BPlJSUiEceeUSYzWaRmZkpzpw5o1stA72vpaWlyjZPPPGEWLFihbJcUFCg1B8TEyOeeeYZYbPZVB+blzATEQD2GRCRjGFARAAYBkQkYxgQEQCGARHJGAZEBIBhQEQyhgGNOjk5OSgoKFCWk5KSsHPnTt3qMQqGAY1633zzDVavXq3pc/7www9YtWoVpk6dCovFguTkZBQXF8Plcml6nEDiVYs06k2cOFHz5/z+++/h8Xiwa9cupKSk4MKFC3jllVfQ1dWF7du3a368QGDLgAytq6sL+fn5CA8PR2xsLHbs2PGzbX76NUGSJOzatQuLFi3C2LFj8eijj6Kqqgr19fXIyclBWFgYsrOzYbfbBzzu008/jdLSUixYsADTpk1DXl4e3njjDezbt88fLzMgGAZkaIWFhaisrMT+/fvx1VdfoaKiAjabbcj93n33XeTn56OmpgYzZszAiy++iDVr1uDNN9/EuXPnlEuJ1XA6nUNOTzai+XqVFZFeOjo6hNlsFp9//rmy7tatW8JisYhXX31VWZeYmCjee+89ZRlAv1mAqqqqBADxySefKOv+8pe/iNDQUK9rqaurExEREeKjjz4a3osZAdgyIMOy2+1wuVz9ZmEaP348UlNTh9x39uzZys99cyCmpaX1W9fT04P29vYhn8vhcODpp5/GkiVL8Morr6h5CSMKw4AeSiaTSflZkqQB13k8nkGf59q1a3jyySeRnZ2Njz76yA+VBg7DgAwrOTkZJpMJX3/9tbKutbUVly5dCsjxHQ4HcnJykJGRgdLSUgQFGfvjxKFFMqzw8HCsWrUKhYWFiI6OxqRJk/DWW28F5EPZFwSJiYnYvn07bty4ofxO78lUh4thQIa2bds2dHZ2YvHixbBardi4caNy4xF/Onz4MOrr61FfX/+z2YuFQScP47RnRASAfQZEJGMYEBEAhgERyRgGRASAYUBEMoYBEQFgGBCRjGFARAAYBkQkYxgQEQCGARHJGAZEBAD4P+7v34HRIFrSAAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "samples = posterior.sample((1000,), x=x_o)\n", - "_ = pairplot(samples, limits=[[-3, 3], [-3, 3]], figsize=(3, 3), upper=\"contour\")" + "pairplot(samples, limits=[[-3, 3], [-3, 3]], figsize=(3, 3), upper=\"contour\");" ] }, { @@ -523,20 +619,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 19, "id": "6078a8bb-c03f-4649-8035-632581f0b8c2", "metadata": {}, "outputs": [], "source": [ "from sbi.neural_nets.net_builders import build_resnet_classifier\n", - "from sbi.inference.posteriors import MCMCPosterior\n", "from sbi.inference.potentials import ratio_estimator_based_potential\n", "from sbi import utils as utils" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 20, "id": "1644654f-12e3-4655-a8fc-00f80c4dde39", "metadata": {}, "outputs": [], @@ -546,7 +641,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 21, "id": "337b13c5-e6f8-460f-80c4-bdec4b30e93f", "metadata": {}, "outputs": [], @@ -580,7 +675,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 22, "id": "f70d4cb3-b283-4b36-b259-5c8cc4b540c8", "metadata": {}, "outputs": [], @@ -597,13 +692,24 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 23, "id": "8a0b3de2-b630-4502-ace1-2a4b827451be", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQMAAAExCAYAAAB4Y8dtAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAZw0lEQVR4nO3dfUwU574H8O/ssssuyy6rWCigxbcUm2q9hhSutI2cHq42jdL+o/blQNPrrWL0pJx6uLZpPdQ0J/7RmtOENNWmFYwkpz0nsddcb9LWNzC2WFu4JNXmWOBStOsLVWF5W9hl5rl/7OwcUED2fQe/n2RSZpzZ/bHpfnnmmWeekYQQAkR0zzPEuwAiSgwMAyICwDAgIhXDgIgAMAyISMUwICIADAMiUjEMiAgAw4CIVEnxLoD0598M6+NdAk3DMeXvQe3PlgERAWAYEJGKYUBEABgGRKRiGETYeZcb81//H5x3ueNdClFQGAYRFggBhgHpDcOAiAAwDKLm5qA33iUQBYVhEGGzbGYAwLtfXoSr1xPnaoimj2EQYTlOq/ZzD1sHpCMMgyjiqQLpCcMgiioONfNUgXSDYRAlVWvy4PHJPFUg3WAYREm62pFIpBcMAyICwDAgIhXDgIgAMAyISMUwICIADIOI40Aj0iuGQQS5ej2oONQMq8mo3aNApBecHTmCega98PhkHPz3Ao4zIN1hyyAKGAT3CEnyLzMEw4AoVNLM+vrwNIFoMpIECDF+HfCHgFD8i2SAlGQEAIhR3/jjxx6rAwwDoolIkvqll8etSwb/f4UMQJEBCZDMJv8XXygQypgACByrEwwDookIAUC5Y10oBkgGxR8KhiTAaIRkNEII4f9Z8oeC/xBJV60DhgHRZG7/IgsBCBkC/gCQzCbAYPCHQGAfRYEYHQUUAUgCkOBvQegAw4BoOiQJktHoP1UwGiCZzZBsKUCyGcJmBYSAYdADjMoQIyPA6CjE8AjE6KhuWggMA6JpkJJMkMwmfyAYjZBsKVDSHZBTkzF0fzIMsoCl2wrDyCgM/cOQfKNATy8wDH8/gg76DxgGRNMhFP9fd4MEKdkMkWLBqD0ZPocJw7P8lxhHLVYYvQLWa0kwDvlg8AxDyAqk0dFAN0JCYxgQTYNQBOAbhZRihbDbIM+2YTDHgpE0Cf0LAMUkoJgA44gBsy6kwHpLhm3EB0kIQJb9x4/tOzAY/xkwCYJhQDRdBv+5vzQqqx2EgGKW4EsbBSwKrI5heL1JGLmSAoNsgMVugdE3Csnrg+T1QsCoXo6UkIhNBYYB0VTUgUaSKQlSUhKE1wf09MJoMcPoTYFsBrIX3sASZzdeyWhEt2zHq32/g++XJEiKDZZbFlgBSCMjgCwDsnRnKyFBzKzxlETRMHbYsaJA+EaBURkGr4BhFPCOJmFESYJJkmEzjAAmBYoJEAZASPAHitEIKTCQKUGxZUA0Fcngv4KgCIjRUf8X2mCA5BlB8q0R2M0Sbnw/B9/MnoVby1IAAMbeJBiHActNGdarg/4rCxYLhKL4WxYTnSIEhjrHsQ+BYUA0XYqAMACSogCjozB4RmEakGG5YYQkG/F/v6bDYBAweiQYvUDSsAzJ4wVkxd/fkOAYBkRTUWSIsX/JJf+IQ6V/AAZZgdVtg3FkFnypSXDftAMGwHJLQZJHgenXIUgDQ9ogJMhT9BMkwFUFhgHR3Yz7oioAjP4hx/39gCLDZEqCMSUZQAqUJAlJgzKMPgWGoWF/EHh9ELLsv38hgTEMiCYyxTm8kGVIQoIAgMEhSACMJhNShh3+47w+SLIC0dMLMTyivoyYumWQABgGETTRZKicIFXHxt7CHCD84wsCtyoLrw+SPOC/X8Hn84fB6CiEIqB4hv3BYUjsqwgBiV+hTtw+GeosmxlWk5FPYtYrMcVYAEW+44qAkBWIwSGIgUGI4REoIyMQsuzvc5BliFGftp4I/QMTYRhESHv3ADw+GfvK8pHjtCLHacW+snw+iXmmuuP2ZgXK8LB/GRnxj0UIhIkQ/1wSGE8TImBsq2BxRqq2nROjznCB1oGQbpvh6LaJUXSCYRABY6dIz3Fa410OxYr6l95/i7Iy4b/pCU8TIogtgXtUAt50FAq2DIjCpcNWwETYMiAiAAwDIlIxDIgIAMOAiFQMAyICwDAgIhXDgIgAMAyISMUwICIADAMiUjEMiAgAw4CIVAwDIgLAMCAiFcOAiAAwDIhIxTAgIgAMAyJSMQxiYG3NGZx3ueNdBtGUGAYxwjCgRMcwICIADAMiUjEMiAgAw4CIVAyDCOBj12kmYBiE6fZHsRPpFR+vFqapHro6NhwYFJTo2DKIkIkeuprjtGJ/Wb72M1EiYxhEGUOA9IJhQEQAGAZEpGIYEBEAhgERqRgGRASAYRAzHKVIiY5hEGWzbGZYTUZUHGqGq9cT73KIJsUwiLIcpxX7yvLh8cnoYeuAEhjDIAYmGp1IlGgYBkQEgGFARCqGQZh4lYBmCoZBGDiXAc0knM8gDFPNZUCkN2wZRACvFtBMwDAgIgAMAyJSMQxiiFceKJExDGKA9yeQHjAMwjDdv/S8P4H0gGEQomDHGASuOPBUgRIVwyBEgTEG+8rypzXGgKcKlOgYBmGa7hgDnipQomMYhCiU5j4HJ1EiYxiEINx7Ely9Hp4qUMJhGIQg2P6C22051IySvY0MBEooDIMwBNvsD3QiAoDHJ6O9eyAaZRGFhGEQBFevB+dd7pC/xDlOK47vWIW/V6zUriycd7kjXCVRaMK6hbm7bxjd/SORqiWh3Rz0ouJQMzw+GQBC7i/IcVqR47Ti7xUrsX5fE9bva8K+svyodS4uzUmLyuvSzCMJIUS8iyCi+ONpAhEBYBgQkYphQEQAGAZEpAr5aoIQAv39/ZGshaLEbrdDkqR4l0EJLuQw6O/vR1oaL1vpgdvthsPhiHcZlOBCvrR4t5ZBX18f5s2bh8uXL+vif8SZXC9bBjQdIbcMJEma1pfG4XDo4ssVwHrpXsUORCICwDAgIlXUwiA5ORnV1dVITk6O1ltEFOulex3vTSAiADxNICIVw4CIADAMiEjFMCAiAGGGgc/nw86dO7Fs2TLYbDZkZ2ejvLwcV65cmfK4t99+G5IkjVuWLFkSTilh++CDDzB//nxYLBYUFhbi3Llzca1nz549ePTRR2G325GRkYFnn30WFy9enPKYurq6Oz5Xi8USo4pJ78IKg6GhIbS0tGDXrl1oaWnB4cOHcfHiRZSWlt712IcffhhXr17VljNnzoRTSlg+++wzvPbaa6iurkZLSwuWL1+ONWvWoLu7O241NTY2Ytu2bTh79iyOHTsGn8+H1atXY3BwcMrjHA7HuM+1q6srRhWT7okIO3funAAgurq6Jt2nurpaLF++PNJvHbKCggKxbds2bV2WZZGdnS327NkTx6rG6+7uFgBEY2PjpPvU1taKtLS02BVFM0rE+wzcbjckSYLT6Zxyv7a2NmRnZ2PhwoV48cUXcenSpUiXMi1erxfNzc0oKSnRthkMBpSUlKCpqSkuNU3E7fbPojx79uwp9xsYGEBubi7mzZuHZ555BhcuXIhFeTQDRDQMhoeHsXPnTjz//PNT3jxTWFiIuro6fPHFF/jwww/R2dmJJ554Ii7zI9y4cQOyLCMzM3Pc9szMTFy7di3m9UxEURRUVlbisccew9KlSyfdLy8vDwcOHMCRI0dQX18PRVFQVFSEX375JYbVkm4F04yor68XNptNW06fPq39m9frFevWrRMrVqwQbrc7qOZJT0+PcDgc4uOPPw7quEhwuVwCgPjmm2/Gba+qqhIFBQUxr2ciFRUVIjc3V1y+fDmo47xer1i0aJF46623olQZzSRB3cJcWlqKwsJCbT0nJweA/6rChg0b0NXVhZMnTwZ9S63T6cSDDz6I9vb2oI6LhDlz5sBoNOL69evjtl+/fh33339/zOu53fbt23H06FGcPn0ac+fODepYk8mEFStWxOVzJf0J6jTBbrdj8eLF2mK1WrUgaGtrw/Hjx5Genh50EQMDA+jo6EBWVlbQx4bLbDYjPz8fJ06c0LYpioITJ05g5cqVMa8nQAiB7du34/PPP8fJkyexYMGCoF9DlmX88MMPcflcSYfCaVZ4vV5RWloq5s6dK1pbW8XVq1e1ZWRkRNvvySefFDU1Ndr6jh07RENDg+js7BRff/21KCkpEXPmzBHd3d3hlBOyTz/9VCQnJ4u6ujrx448/is2bNwun0ymuXbsWl3qEEGLr1q0iLS1NNDQ0jPtch4aGtH3KysrE66+/rq3v3r1bfPnll6Kjo0M0NzeL5557TlgsFnHhwoV4/AqkM2GFQWdnpwAw4XLq1Cltv9zcXFFdXa2tb9y4UWRlZQmz2SxycnLExo0bRXt7ezilhK2mpkY88MADwmw2i4KCAnH27Nm41jPZ51pbW6vts2rVKvHSSy9p65WVldrvkJmZKZ5++mnR0tIS++JJl3gLM1GEuHo9APzP09Qj3ptAFAGuXg9K9jaiZG+jFgp6wzAgioCeQS88Phken4yeQW+8ywkJw4CIADAMiEjFMCAiAAwDIlIxDIgIAMOAiFQzNgyKi4tRWVmprc+fPx/vv/9+3Oqhe8dNXlpMbN999x02b94c8df985//jKKiIqSkpNx1Qhe6N1QcatblwKN7Jgzuu+8+pKSkRPx1vV4v1q9fj61bt0b8tUl/qtbk6Xbg0YwIg8HBQZSXlyM1NRVZWVnYu3fvHfvcfpogSRL279+PtWvXIiUlBQ899BCamprQ3t6O4uJi2Gw2FBUVoaOjY8r33r17N/7whz9g2bJlkf61SIfSbeZ4lxCyGREGVVVVaGxsxJEjR/DVV1+hoaEBLS0tdz3unXfeQXl5OVpbW7FkyRK88MIL2LJlC9544w18//332pwCRPeCoGY6SkQDAwP45JNPUF9fj9/+9rcAgIMHD05rVqCXX34ZGzZsAADs3LkTK1euxK5du7BmzRoAwKuvvoqXX345esUTJRDdtww6Ojrg9XrHTcc2e/Zs5OXl3fXYRx55RPs5MCHq2OZ+ZmYmhoeH0dfXF8GKiRKT7sMgHCaTSftZkqRJtymKEtvCiOJA92GwaNEimEwmfPvtt9q2np4e/PTTT3Gsiu41eryUeDvd9xmkpqZi06ZNqKqqQnp6OjIyMvDmm2/CYIhNzl26dAm3bt3CpUuXIMsyWltbAQCLFy9GampqTGqg+HL1erDlUDMA4IqOQ0H3YQAA7777LgYGBrBu3TrY7Xbs2LFDewJRtP3pT3/CwYMHtfUVK1YAAE6dOoXi4uKY1EDxNXZMwa0h/Y0vCOAciERhOu9yY22N/8HBv/vXB1B/9hKO/v5xLM1Ji3NlwdF9nwERRQbDgCiClmbrqzUw1ozoMyBKBEd//3i8SwgLWwZEBIBhQEQqhgERAWAYEJGKYUBEABgGRKRiGBARAIYBEakYBkQEgGFARCqGAREBYBgQkYphQEQAGAZEpGIYEBEAhgERqRgGRASAYUBEKoYBEQFgGBCRimFAFKabg/p9cMpYDAOiMLh6Pag41AyryYhZNvO47Xp7/iLDgCgMPYNeeHwy9pXlI8dp1bZvOdSMkr2NugoEhgFRBKSPaRUEeHzyuOcwJjqGAREBYBgQkYphQBRBs2xmWE3GeJcREoYBUQTlOK04vmMV9pflx7uUoPHBq0QRluO06qrjMIAtAyICwDAgIhXDgIgAMAyISMUwICIADAMiUjEMiAgAw4CIVAwDIgLAMCAiFcOAiAAwDIhIxTAgIgAMAyJSMQyICADDgCgsM+WZCQDDgChkkz0zQa840xFRiALPTDj47wXjnpmgV2wZEIVpomcm6BHDgIgAMAyISMUwIIoiPV1tYBgQRUHgYSoVh5p18/BVhgFRFOQ4rdhXlq+rh68yDIiiRG9XGRgGRASAYUAUMj11Dk4Hw4AoBDNtKDLA4chEIZlpQ5EBtgyIwqK3TsKpMAyICADDgCgkM63zEGAYEAUt2M5DvQQHw4AoSIHOw31l+VN2HuptSDLDgChEd+s81NuQZIYBURTp6WoDw4CIADAMiEjFMCAiAAwDopjQw+VFhgFRkIL5Yuvp8iLDgCgIwQ440tPlRd61SBSEUO5W1MvlRbYMiEIQyhc80fsNGAZE0+Dq9eC//teFtTVngj52bL/Bdz/fSti+A0kIIeJdBFEic/V6ULK3ER6frG07+vvHsTQnbdqv0fjTr3jpwDkAgNVkxPEdqxJuUhS2DIgm4er1wNXr0foJ/uPxBSG/1tjTikTtTGQHItEEAq0BAPjPp/IAAIszUqP6fgDi2lpgGFDQuvuG0d0/Eu8yoqq9e0A7Ldj93z8C8P9Ft5qMABD0JKiBfoPAa66tOYP3N/4LFmek4uagFxWHmgEA+8ryI3b1IZjTGIB9BkSkYp8BEQFgGBCRimFARAAYBkSkYhgQEQBeWqQgCSHQ398f7zJomux2OyRJmta+DAMKSn9/P9LSgrt+TfHjdrvhcDimtS/HGVBQptMy6Ovrw7x583D58uVp/48YTzO5XrYMKGokSZr2F8bhcOjiyxVwr9fLDkQiAsAwICIVw4AiLjk5GdXV1UhOTo53KdPCev3YgUhEANgyICIVw4CIADAMiEjFMCAiAAwDCoHP58POnTuxbNky2Gw2ZGdno7y8HFeuXJnyuLfffhuSJI1blixZEqOqJ/fBBx9g/vz5sFgsKCwsxLlz5+JWy549e/Doo4/CbrcjIyMDzz77LC5evDjlMXV1dXd8rhaLJej3ZhhQ0IaGhtDS0oJdu3ahpaUFhw8fxsWLF1FaWnrXYx9++GFcvXpVW86cCf45BJH02Wef4bXXXkN1dTVaWlqwfPlyrFmzBt3d3XGpp7GxEdu2bcPZs2dx7Ngx+Hw+rF69GoODg1Me53A4xn2uXV1dwb+5IIqAc+fOCQCiq6tr0n2qq6vF8uXLY1fUNBQUFIht27Zp67Isi+zsbLFnz544VvVP3d3dAoBobGycdJ/a2lqRlpYW9nuxZUAR4Xa7IUkSnE7nlPu1tbUhOzsbCxcuxIsvvohLly7FpsAJeL1eNDc3o6SkRNtmMBhQUlKCpqamuNU1ltvtBgDMnj17yv0GBgaQm5uLefPm4ZlnnsGFCxeCfi+GAYVteHgYO3fuxPPPPz/ljTOFhYWoq6vDF198gQ8//BCdnZ144okn4jY/wo0bNyDLMjIzM8dtz8zMxLVr1+JS01iKoqCyshKPPfYYli5dOul+eXl5OHDgAI4cOYL6+nooioKioiL88ssvwb1h2G0LmvHq6+uFzWbTltOnT2v/5vV6xbp168SKFSuE2+0O6nV7enqEw+EQH3/8caRLnhaXyyUAiG+++Wbc9qqqKlFQUBCXmsaqqKgQubm54vLly0Ed5/V6xaJFi8Rbb70V1HG8hZnuqrS0FIWFhdp6Tk4OAP9VhQ0bNqCrqwsnT54M+nZap9OJBx98EO3t7RGtd7rmzJkDo9GI69evj9t+/fp13H///XGpKWD79u04evQoTp8+jblz5wZ1rMlkwooVK4L+XHmaQHdlt9uxePFibbFarVoQtLW14fjx40hPTw/6dQcGBtDR0YGsrKwoVH13ZrMZ+fn5OHHihLZNURScOHECK1eujEtNQghs374dn3/+OU6ePIkFC4J/vqMsy/jhhx+C/1yDakcQCX8ztLS0VMydO1e0traKq1evasvIyIi235NPPilqamq09R07doiGhgbR2dkpvv76a1FSUiLmzJkjuru74/FrCCGE+PTTT0VycrKoq6sTP/74o9i8ebNwOp3i2rVrcaln69atIi0tTTQ0NIz7XIeGhrR9ysrKxOuvv66t7969W3z55Zeio6NDNDc3i+eee05YLBZx4cKFoN6bYUBB6+zsFAAmXE6dOqXtl5ubK6qrq7X1jRs3iqysLGE2m0VOTo7YuHGjaG9vj/0vcJuamhrxwAMPCLPZLAoKCsTZs2fjVstkn2ttba22z6pVq8RLL72krVdWVmr1Z2Zmiqefflq0tLQE/d68hZmIALDPgIhUDAMiAsAwICIVw4CIADAMiEjFMCAiAAwDIlIxDGjGKS4uRmVlpbY+f/58vP/++3GrRy8YBjTjfffdd9i8eXNEX/Pnn3/Gpk2bsGDBAlitVixatAjV1dXwer0RfZ9Y4l2LNOPdd999EX/Nf/zjH1AUBfv378fixYtx/vx5vPLKKxgcHMR7770X8feLBbYMSNcGBwdRXl6O1NRUZGVlYe/evXfsc/tpgiRJ2L9/P9auXYuUlBQ89NBDaGpqQnt7O4qLi2Gz2VBUVISOjo5J3/epp55CbW0tVq9ejYULF6K0tBR//OMfcfjw4Wj8mjHBMCBdq6qqQmNjI44cOYKvvvoKDQ0NaGlpuetx77zzDsrLy9Ha2oolS5bghRdewJYtW/DGG2/g+++/124lDobb7b7r9GQJLdy7rIjipb+/X5jNZvG3v/1N23bz5k1htVrFq6++qm3Lzc0Vf/nLX7R1AONmAWpqahIAxCeffKJt++tf/yosFsu0a2lraxMOh0N89NFHof0yCYAtA9Ktjo4OeL3ecbMwzZ49G3l5eXc99pFHHtF+DsyBuGzZsnHbhoeH0dfXd9fXcrlceOqpp7B+/Xq88sorwfwKCYVhQPckk8mk/SxJ0qTbFEWZ8nWuXLmC3/zmNygqKsJHH30UhUpjh2FAurVo0SKYTCZ8++232raenh789NNPMXl/l8uF4uJi5Ofno7a2FgaDvr9OvLRIupWamopNmzahqqoK6enpyMjIwJtvvhmTL2UgCHJzc/Hee+/h119/1f4t3pOphophQLr27rvvYmBgAOvWrYPdbseOHTu0B49E07Fjx9De3o729vY7Zi8WOp08jNOeEREA9hkQkYphQEQAGAZEpGIYEBEAhgERqRgGRASAYUBEKoYBEQFgGBCRimFARAAYBkSkYhgQEQDg/wGZbFdWjwVPmgAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "samples = posterior.sample((1000,), x=x_o)\n", - "_ = pairplot(samples, limits=[[-3, 3], [-3, 3]], figsize=(3, 3))" + "pairplot(samples, limits=[[-3, 3], [-3, 3]], figsize=(3, 3));" ] } ],