From 60d589a15d6fe75b849fc58b242ab2cbb79b265c Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff <35577657+nikhilwoodruff@users.noreply.github.com> Date: Tue, 24 Sep 2024 19:10:09 +0100 Subject: [PATCH] Add `Simulation.subsample` (#279) * Update documentation * Fix bug in reform handling * Test with API * Versioning * Add seedability * Fix syntax bug --- Makefile | 1 + changelog_entry.yaml | 4 + docs/_static/style.css | 7 - docs/add_plotly_to_book.py | 27 + docs/usage/charts.ipynb | 181 ---- docs/usage/cli.md | 14 - docs/usage/datasets.ipynb | 44 +- docs/usage/reforms.ipynb | 18 +- docs/usage/simulation.ipynb | 996 +++++++++++++++++++- policyengine_core/simulations/simulation.py | 78 ++ 10 files changed, 1127 insertions(+), 243 deletions(-) create mode 100644 docs/add_plotly_to_book.py delete mode 100644 docs/usage/charts.ipynb delete mode 100644 docs/usage/cli.md diff --git a/Makefile b/Makefile index c04fd21e5..bf7fbafa1 100644 --- a/Makefile +++ b/Makefile @@ -3,6 +3,7 @@ all: install format test build changelog documentation: jb clean docs jb build docs + python docs/add_plotly_to_book.py docs/_build format: black . -l 79 diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29bb..e451532b4 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,4 @@ +- bump: minor + changes: + added: + - Simulation subsampling. 
diff --git a/docs/_static/style.css b/docs/_static/style.css index 2a7a0ae4c..e511f94b3 100644 --- a/docs/_static/style.css +++ b/docs/_static/style.css @@ -1,9 +1,2 @@ @import url('https://fonts.googleapis.com/css2?family=Roboto+Serif:opsz@8..144&family=Roboto:wght@300&display=swap'); -h1, h2, h3, h4, h5, h6 { - font-family: "Roboto"; -} - -body { - font-family: "Roboto Serif"; -} \ No newline at end of file diff --git a/docs/add_plotly_to_book.py b/docs/add_plotly_to_book.py new file mode 100644 index 000000000..822e77abc --- /dev/null +++ b/docs/add_plotly_to_book.py @@ -0,0 +1,27 @@ +import argparse +from pathlib import Path + +# This command-line tools enables Plotly charts to show in the HTML files for the Jupyter Book documentation. + +parser = argparse.ArgumentParser() +parser.add_argument("book_path", help="Path to the Jupyter Book.") + +args = parser.parse_args() + +# Find every HTML file in the Jupyter Book. Then, add a script tag to the start of the tag in each file, with the contents: +# + +book_folder = Path(args.book_path) + +for html_file in book_folder.glob("**/*.html"): + with open(html_file, "r") as f: + html = f.read() + + # Add the script tag to the start of the tag. + html = html.replace( + "", + '', + ) + + with open(html_file, "w") as f: + f.write(html) diff --git a/docs/usage/charts.ipynb b/docs/usage/charts.ipynb deleted file mode 100644 index 3577bb4b0..000000000 --- a/docs/usage/charts.ipynb +++ /dev/null @@ -1,181 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Charts\n", - "\n", - "PolicyEngine Core provides a set of chart utils to speed up data visualisation for PolicyEngine model-powered analyses. These use the PolicyEngine styling by default. The examples below use the PolicyEngine UK microsimulation model.\n", - "\n", - "## Bar chart\n", - "\n", - "The `bar_chart` function creates a bar chart from a dataframe." 
- ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "
" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Reform code generated from the PolicyEngine export function.\n", - "\n", - "from policyengine_uk import Microsimulation\n", - "from policyengine_core.reforms import Reform\n", - "from policyengine_core.periods import instant\n", - "\n", - "\n", - "def modify_parameters(parameters):\n", - " parameters.gov.hmrc.income_tax.rates.uk[0].rate.update(\n", - " start=instant(\"2023-01-01\"), stop=instant(\"2028-12-31\"), value=0.25\n", - " )\n", - " return parameters\n", - "\n", - "\n", - "class reform(Reform):\n", - " def apply(self):\n", - " self.modify_parameters(modify_parameters)\n", - "\n", - "\n", - "baseline = Microsimulation()\n", - "reformed = Microsimulation(reform=reform)\n", - "\n", - "baseline_income = baseline.calculate(\"household_net_income\", 2023)\n", - "reformed_income = reformed.calculate(\"household_net_income\", 2023)\n", - "gain = reformed_income - baseline_income\n", - "decile = baseline.calculate(\"household_income_decile\", 2023)\n", - "decile_impacts = (\n", - " gain.groupby(decile).sum() / baseline_income.groupby(decile).sum()\n", - ")\n", - "decile_impacts = decile_impacts[decile_impacts.index != 0]\n", - "\n", - "from policyengine_core.charts import *\n", - "\n", - "display_fig(\n", - " bar_chart(\n", - " decile_impacts,\n", - " title=\"Change in net income by decile\",\n", - " xaxis_title=\"Decile\",\n", - " yaxis_title=\"Change in net income\",\n", - " xaxis_tickvals=list(range(1, 11)),\n", - " yaxis_tickformat=\".0%\",\n", - " text_format=\".1%\",\n", - " hover_text_function=lambda x, y: f\"The {cardinal(x)} decile sees a {y:+.1%} in net income\",\n", - " )\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Cross-section bar chart\n", - "\n", - "The cross-section bar chart is useful for showing the distribution of outcomes along different breakdowns." 
- ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "
" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "lower_age_group = baseline.calculate(\"age\", 2023) // 10\n", - "personal_gain = reformed.calculate(\n", - " \"household_net_income\", 2023, map_to=\"person\"\n", - ") - baseline.calculate(\"household_net_income\", 2023, map_to=\"person\")\n", - "personal_gain = personal_gain[lower_age_group < 8]\n", - "lower_age_group = lower_age_group[lower_age_group < 8] + 1\n", - "\n", - "display_fig(\n", - " cross_section_bar_chart(\n", - " personal_gain,\n", - " lower_age_group,\n", - " slices=[-0.1, -0.01, 0.01, 0.1],\n", - " xaxis_tickformat=\".0%\",\n", - " category_names=[\n", - " \"Lose more than 10%\",\n", - " \"Lose between 1% and 10%\",\n", - " \"Experience less than 1% change\",\n", - " \"Gain between 1% and 10%\",\n", - " \"Gain more than 10%\",\n", - " ],\n", - " yaxis_ticktext=[\n", - " \"Under 10\",\n", - " \"10 to 19\",\n", - " \"20 to 29\",\n", - " \"30 to 39\",\n", - " \"40 to 49\",\n", - " \"50 to 59\",\n", - " \"60 to 69\",\n", - " \"70 to 79\",\n", - " ],\n", - " color_discrete_map={\n", - " \"Lose more than 10%\": DARK_GRAY,\n", - " \"Lose between 1% and 10%\": MEDIUM_DARK_GRAY,\n", - " \"Experience less than 1% change\": GRAY,\n", - " \"Gain between 1% and 10%\": LIGHT_GRAY,\n", - " \"Gain more than 10%\": BLUE,\n", - " },\n", - " legend_orientation=\"h\",\n", - " legend_y=-0.2,\n", - " title=\"Gain by age\",\n", - " hover_text_function=lambda age, outcome, percent: f\"{percent:.1%} of {age * 10:.0f} to {(age + 1) * 10:.0f} year olds {outcome.lower()} of their income\",\n", - " )\n", - ")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": 
"python", - "pygments_lexer": "ipython3", - "version": "3.9.12" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/docs/usage/cli.md b/docs/usage/cli.md deleted file mode 100644 index ad4549eed..000000000 --- a/docs/usage/cli.md +++ /dev/null @@ -1,14 +0,0 @@ -# Using the command-line interface - -Use the `policyengine-core` command-line tool to run tests or manage data without writing Python code. - -```{eval-rst} -.. argparse:: - :module: policyengine_core.scripts.policyengine_command - :func: get_parser - :prog: policyengine-core -``` - -```{eval-rst} -.. hint:: To list all the datasets for a country package, use `policyengine-core data datasets list`, passing in a country package as needed. -``` \ No newline at end of file diff --git a/docs/usage/datasets.ipynb b/docs/usage/datasets.ipynb index 493f2200d..0efe93fbc 100644 --- a/docs/usage/datasets.ipynb +++ b/docs/usage/datasets.ipynb @@ -15,9 +15,20 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 9, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "array([100., 0., 200.], dtype=float32)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "from policyengine_core.country_template.constants import COUNTRY_DIR\n", "from policyengine_core.data import Dataset\n", @@ -28,12 +39,14 @@ " # Specify metadata used to describe and store the dataset.\n", " name = \"country_template_dataset\"\n", " label = \"Country template dataset\"\n", - " folder_path = COUNTRY_DIR / \"data\" / \"storage\"\n", + " file_path = (\n", + " COUNTRY_DIR / \"data\" / \"storage\" / \"country_template_dataset.h5\"\n", + " )\n", " data_format = Dataset.TIME_PERIOD_ARRAYS\n", "\n", " # The generation function is the most important part: it defines\n", - " # how the dataset is generated from the raw data for a given year.\n", - " def generate(self, year: int) -> None:\n", + " # how the dataset is generated from the raw 
data.\n", + " def generate(self) -> None:\n", " person_id = [0, 1, 2]\n", " household_id = [0, 1]\n", " person_household_id = [0, 0, 1]\n", @@ -50,25 +63,16 @@ " \"salary\": {salary_time_period: salary},\n", " \"household_weight\": {weight_time_period: weight},\n", " }\n", - " self.save_variable_values(year, data)\n", + " self.save_dataset(data)\n", "\n", "\n", - "# Important: we must instantiate datasets. This tests their validity and adds dynamic logic.\n", - "CountryTemplateDataset = CountryTemplateDataset()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Dataset API\n", + "from policyengine_core.country_template import Simulation\n", "\n", - "PolicyEngine Core also includes two subclasses of `Dataset`:\n", + "CountryTemplateDataset().generate()\n", "\n", - "* `PublicDataset` - a dataset that is publicly available, and can be downloaded from a URL. Includes a `download` method to download the dataset.\n", - "* `PrivateDataset` - a dataset that is not publicly available, and must be downloaded from a private URL (specifically, Google Cloud buckets). Includes a `download` method to download the dataset, and a `upload` method to upload the dataset.\n", + "simulation = Simulation(dataset=CountryTemplateDataset)\n", "\n", - "See {doc}`/python_api/data` for the API reference." 
+ "simulation.calculate(\"salary\", \"2022-01\")" ] } ], @@ -88,7 +92,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.12" + "version": "3.10.14" }, "orig_nbformat": 4, "vscode": { diff --git a/docs/usage/reforms.ipynb b/docs/usage/reforms.ipynb index 06b97f14a..7827ced97 100644 --- a/docs/usage/reforms.ipynb +++ b/docs/usage/reforms.ipynb @@ -74,21 +74,21 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "( value weight\n", - " 0 4000.0 1000000.0\n", - " 1 6000.0 1200000.0,\n", - " value weight\n", - " 0 2000.0 1000000.0\n", - " 1 3000.0 1200000.0)" + "( value weight\n", + " 0 200.0 1000000.0\n", + " 1 200.0 1200000.0,\n", + " value weight\n", + " 0 200.0 1000000.0\n", + " 1 200.0 1200000.0)" ] }, - "execution_count": 4, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -116,7 +116,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.19" + "version": "3.10.14" } }, "nbformat": 4, diff --git a/docs/usage/simulation.ipynb b/docs/usage/simulation.ipynb index bbaea18ed..66979a965 100644 --- a/docs/usage/simulation.ipynb +++ b/docs/usage/simulation.ipynb @@ -109,19 +109,937 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 3, "metadata": {}, "outputs": [ { - "ename": "ImportError", - "evalue": "cannot import name 'display_fig' from 'policyengine_core.charts' (/Users/nikhil/policyengine/policyengine-core/policyengine_core/charts/__init__.py)", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m/Users/nikhil/policyengine/policyengine-core/docs/usage/simulation.ipynb Cell 5\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m 
\u001b[39mimport\u001b[39;00m \u001b[39mplotly\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mexpress\u001b[39;00m \u001b[39mas\u001b[39;00m \u001b[39mpx\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mpolicyengine_core\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mcharts\u001b[39;00m \u001b[39mimport\u001b[39;00m format_fig, display_fig\n\u001b[1;32m 4\u001b[0m fig \u001b[39m=\u001b[39m px\u001b[39m.\u001b[39mline(\n\u001b[1;32m 5\u001b[0m x\u001b[39m=\u001b[39msimulation\u001b[39m.\u001b[39mcalculate(\u001b[39m\"\u001b[39m\u001b[39msalary\u001b[39m\u001b[39m\"\u001b[39m),\n\u001b[1;32m 6\u001b[0m y\u001b[39m=\u001b[39msimulation\u001b[39m.\u001b[39mcalculate(\u001b[39m\"\u001b[39m\u001b[39mincome_tax\u001b[39m\u001b[39m\"\u001b[39m),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 16\u001b[0m hovertemplate\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mSalary: £\u001b[39m\u001b[39m%\u001b[39m\u001b[39m{x:,.0f}\u001b[39;00m\u001b[39m
Income tax: £\u001b[39m\u001b[39m%\u001b[39m\u001b[39m{y:,.0f}\u001b[39;00m\u001b[39m\"\u001b[39m,\n\u001b[1;32m 17\u001b[0m )\n\u001b[1;32m 19\u001b[0m display_fig(format_fig(fig))\n", - "\u001b[0;31mImportError\u001b[0m: cannot import name 'display_fig' from 'policyengine_core.charts' (/Users/nikhil/policyengine/policyengine-core/policyengine_core/charts/__init__.py)" - ] + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "hovertemplate": "Salary: £%{x:,.0f}
Income tax: £%{y:,.0f}", + "legendgroup": "", + "line": { + "color": "#2C6496", + "dash": "solid" + }, + "marker": { + "symbol": "circle" + }, + "mode": "lines", + "name": "", + "orientation": "v", + "showlegend": false, + "type": "scatter", + "x": [ + 0, + 11111.111328125, + 22222.22265625, + 33333.33203125, + 44444.4453125, + 55555.5546875, + 66666.6640625, + 77777.78125, + 88888.890625, + 100000 + ], + "xaxis": "x", + "y": [ + 0, + 1666.666748046875, + 3333.33349609375, + 5000, + 6666.6669921875, + 8333.333984375, + 10000, + 11666.66796875, + 13333.333984375, + 15000.0009765625 + ], + "yaxis": "y" + } + ], + "layout": { + "font": { + "color": "black", + "family": "Roboto Serif" + }, + "height": 600, + "images": [ + { + "sizex": 0.15, + "sizey": 0.15, + "source": "https://raw.githubusercontent.com/PolicyEngine/policyengine-app/master/src/images/logos/policyengine/blue.png", + "x": 1.1, + "xanchor": "right", + "xref": "paper", + "y": -0.15, + "yanchor": "bottom", + "yref": "paper" + } + ], + "legend": { + "tracegroupgap": 0 + }, + "margin": { + "t": 60 + }, + "modebar": { + "bgcolor": "rgba(0,0,0,0)", + "color": "rgba(0,0,0,0)" + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "white", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "white", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "#C8D4E3", + "linecolor": "#C8D4E3", + "minorgridcolor": "#C8D4E3", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "#C8D4E3", + "linecolor": "#C8D4E3", + "minorgridcolor": "#C8D4E3", + "startlinecolor": "#2a3f5f" + }, + "type": 
"carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { 
+ "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + 
} + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 
0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "white", + "showlakes": true, + "showland": true, + "subunitcolor": "#C8D4E3" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "white", + "polar": { + "angularaxis": { + "gridcolor": "#EBF0F8", + "linecolor": "#EBF0F8", + "ticks": "" + }, + "bgcolor": "white", + "radialaxis": { + "gridcolor": "#EBF0F8", + "linecolor": "#EBF0F8", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "white", + "gridcolor": "#DFE8F3", + "gridwidth": 2, + "linecolor": "#EBF0F8", + "showbackground": true, + "ticks": "", + "zerolinecolor": "#EBF0F8" + }, + "yaxis": { + "backgroundcolor": "white", + "gridcolor": "#DFE8F3", + "gridwidth": 2, + "linecolor": "#EBF0F8", + "showbackground": true, + "ticks": "", + "zerolinecolor": "#EBF0F8" + }, + "zaxis": { + "backgroundcolor": "white", + "gridcolor": "#DFE8F3", + "gridwidth": 2, + "linecolor": "#EBF0F8", + "showbackground": 
true, + "ticks": "", + "zerolinecolor": "#EBF0F8" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "#DFE8F3", + "linecolor": "#A2B1C6", + "ticks": "" + }, + "baxis": { + "gridcolor": "#DFE8F3", + "linecolor": "#A2B1C6", + "ticks": "" + }, + "bgcolor": "white", + "caxis": { + "gridcolor": "#DFE8F3", + "linecolor": "#A2B1C6", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "#EBF0F8", + "linecolor": "#EBF0F8", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "#EBF0F8", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "#EBF0F8", + "linecolor": "#EBF0F8", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "#EBF0F8", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Income tax by salary" + }, + "width": 800, + "xaxis": { + "anchor": "y", + "domain": [ + 0, + 1 + ], + "tickformat": ",.0f", + "tickprefix": "£", + "title": { + "text": "Salary" + } + }, + "yaxis": { + "anchor": "x", + "domain": [ + 0, + 1 + ], + "tickformat": ",.0f", + "tickprefix": "£", + "title": { + "text": "Income tax" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -148,7 +1066,7 @@ " )\n", ")\n", "\n", - "display_fig(format_fig(fig))" + "format_fig(fig)" ] }, { @@ -190,6 +1108,60 @@ "source": [ "If you inspect the result of the `sim.calculate` call, you'll find it actually returns a `MicroSeries` (defined by the `microdf` Python package). This is a class inheriting from `pandas.Series`, with a few extra methods for handling survey weights. The general intuition is that you can treat this weighted array as if it were an array of the full population it's representative of, using it it as you would any other `pandas.Series`." 
] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Subsampling simulations\n", + "\n", + "Often, we're running simulations over very large (100,000+) datasets. This can be slow, so we might want to subsample the dataset to speed up the simulation. This can be done by using `Simulation.subsample` or `Microsimulation.subsample`, which rebuilds the `Simulation` or `Microsimulation` object in place on a smaller, reweighted sample of households and returns that same object." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "13996.034939691408" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from policyengine_us import Microsimulation\n", + "\n", + "sim = Microsimulation()\n", + "\n", + "sim.calculate(\"adjusted_gross_income\", 2024).sum() / 1e9" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "13891.76442888221" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sim = Microsimulation().subsample(frac=0.1)\n", + "sim.calculate(\"adjusted_gross_income\", 2024).sum() / 1e9" + ] } ], "metadata": { @@ -208,7 +1180,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.12" + "version": "3.10.14" }, "orig_nbformat": 4, "vscode": { diff --git a/policyengine_core/simulations/simulation.py b/policyengine_core/simulations/simulation.py index 08dfdb3df..71090cbff 100644 --- a/policyengine_core/simulations/simulation.py +++ b/policyengine_core/simulations/simulation.py @@ -22,6 +22,7 @@ SimpleTracer, TracingParameterNodeAtInstant, ) +import random import json @@ -1471,6 +1472,83 @@ def to_input_dataframe( return df + def subsample( + self, n=None, frac=None, seed=None, time_period=None + ) -> "Simulation": + """Quantize the simulation to a smaller size by sampling households. 
+ + Args: + n (int, optional): The number of households to sample. Defaults to None, in which case it is computed from frac. + frac (float, optional): The fraction of households to sample; only used when n is None. Defaults to None. + seed (int, optional): The key used to seed the random number generator. Defaults to the dataset name. + time_period (str, optional): Sample households based on their weight in this time period. Defaults to the default calculation period. + + Returns: + Simulation: This same simulation, rebuilt in place on the sampled households. + """ + # Set default seed if not provided + if seed is None: + seed = self.dataset.name + + # Set default time period if not provided + if time_period is None: + time_period = self.default_calculation_period + + # Convert simulation inputs to DataFrame + df = self.to_input_dataframe() + + # Extract time period from DataFrame columns + df_time_period = df.columns.values[0].split("__")[1] + df_household_id_column = f"household_id__{df_time_period}" + + # Determine the appropriate household weight column + if f"household_weight__{time_period}" in df.columns: + household_weight_column = f"household_weight__{time_period}" + else: + household_weight_column = f"household_weight__{df_time_period}" + + # Group by household ID and get the first entry for each group + h_df = df.groupby(df_household_id_column).first() + h_ids = pd.Series(h_df.index) + if n is None and frac is None: + raise ValueError("Either n or frac must be provided.") + if n is None: + n = int(len(h_ids) * frac) + h_weights = pd.Series(h_df[household_weight_column].values) + + if n > len(h_weights): + # Don't need to subsample! 
+ return self + + # Seed the random number generators for reproducibility + random.seed(str(seed)) + state = random.randint(0, 2**32 - 1) + np.random.seed(state) + + # Sample household IDs based on their weights + chosen_household_ids = np.random.choice( + h_ids, + n, + p=h_weights.values / h_weights.values.sum(), + replace=False, + ) + + # Filter DataFrame to include only the chosen households + df = df[df[df_household_id_column].isin(chosen_household_ids)] + + # Adjust household weights to maintain the total weight + df[household_weight_column] *= ( + h_weights.sum() + / df.groupby(df_household_id_column) + .first()[household_weight_column] + .sum() + ) + + # Update the dataset and rebuild the simulation + self.dataset = Dataset.from_dataframe(df) + self.build_from_dataset() + return self + class NpEncoder(json.JSONEncoder): def default(self, obj):