From 33a2eb058ee164cd693ac66ea22fc3c5756fb254 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff <35577657+nikhilwoodruff@users.noreply.github.com> Date: Fri, 2 Aug 2024 10:09:35 +0100 Subject: [PATCH] Small updates (#231) * Remove unneeded dependencies * Update docs * Add CSV loading function * Add `relative_change` function * Fix Enum bugs * Remove print statement * Fix enum bugs * Remove microdf imports --- changelog_entry.yaml | 5 +++ docs/{python_api => usage}/reforms.ipynb | 0 policyengine_core/charts/bar.py | 2 +- policyengine_core/data/dataset.py | 7 ++- policyengine_core/enums/enum.py | 5 ++- policyengine_core/parameters/parameter.py | 9 ++++ .../populations/group_population.py | 16 +++++++ policyengine_core/populations/population.py | 3 ++ .../simulations/microsimulation.py | 2 +- policyengine_core/simulations/simulation.py | 43 +++++++++++++------ setup.py | 2 - 11 files changed, 70 insertions(+), 24 deletions(-) rename docs/{python_api => usage}/reforms.ipynb (100%) diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29bb..45848b400 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,5 @@ +- bump: minor + changes: + added: + - Simulation loading from dataframes. + - Simulation `start_instant` attribute. diff --git a/docs/python_api/reforms.ipynb b/docs/usage/reforms.ipynb similarity index 100% rename from docs/python_api/reforms.ipynb rename to docs/usage/reforms.ipynb diff --git a/policyengine_core/charts/bar.py b/policyengine_core/charts/bar.py index 8f54ed169..b41a4548b 100644 --- a/policyengine_core/charts/bar.py +++ b/policyengine_core/charts/bar.py @@ -1,7 +1,7 @@ import pandas as pd from .formatting import * import plotly.express as px -from microdf import MicroSeries +from policyengine_core.weighting import MicroSeries from typing import Callable import numpy as np diff --git a/policyengine_core/data/dataset.py b/policyengine_core/data/dataset.py index 3d3bc17d5..71a8072b3 100644 --- a/policyengine_core/data/dataset.py +++ b/policyengine_core/data/dataset.py @@ -371,17 +371,16 @@ def from_dataframe(dataframe: pd.DataFrame, time_period: str = None): Returns: Dataset: The dataset. """ - file_path = Path(file_path) dataset = type( "Dataset", (Dataset,), { - "name": file_path.stem, - "label": file_path.stem, + "name": "dataframe", + "label": "DataFrame", "data_format": Dataset.FLAT_FILE, "file_path": "dataframe", "time_period": time_period, - "load": lambda: dataframe, + "load": lambda self: dataframe, }, )() diff --git a/policyengine_core/enums/enum.py b/policyengine_core/enums/enum.py index b55ec3470..b66d7f237 100644 --- a/policyengine_core/enums/enum.py +++ b/policyengine_core/enums/enum.py @@ -53,10 +53,11 @@ def encode(cls, array: Union[EnumArray, np.ndarray]) -> EnumArray: # Confusingly, Numpy uses "S" to refer to byte-string arrays # and "U" to refer to Unicode-string arrays, which are also # referred to as the "str" type - if array.dtype.kind == "S": + if isinstance(array[0], Enum): + array = np.array([item.name for item in array]) + if array.dtype.kind == "S" or array.dtype == object: # Convert boolean array to string array array = array.astype(str) - if isinstance(array, np.ndarray) and array.dtype.kind in {"U", "S"}: # String array indices = np.select( diff --git a/policyengine_core/parameters/parameter.py b/policyengine_core/parameters/parameter.py index 77317314c..fd792030d 100644 --- a/policyengine_core/parameters/parameter.py +++ b/policyengine_core/parameters/parameter.py @@ -224,3 +224,12 @@ def _get_at_instant(self, instant): if value_at_instant.instant_str <= instant: return value_at_instant.value return None + + def relative_change(self, start_instant, end_instant): + start_instant = str(start_instant) + end_instant = str(end_instant) + end_value = self._get_at_instant(end_instant) + start_value = self._get_at_instant(start_instant) + if end_value is None or start_value is None: + return None + return end_value / start_value - 1 diff --git a/policyengine_core/populations/group_population.py b/policyengine_core/populations/group_population.py index bd85b79c0..e201d0f50 100644 --- a/policyengine_core/populations/group_population.py +++ b/policyengine_core/populations/group_population.py @@ -7,6 +7,8 @@ from policyengine_core.entities import Entity, Role from policyengine_core.enums import EnumArray from policyengine_core.populations.population import Population +from policyengine_core.periods.period_ import Period +from typing import Optional, Container if TYPE_CHECKING: from policyengine_core.simulations import Simulation @@ -21,6 +23,20 @@ def __init__(self, entity: Entity, members: Population): self._members_position: ArrayLike = None self._ordered_members_map = None + def __call__( + self, + variable_name: str, + period: Period = None, + options: Optional[Container[str]] = None, + ): + variable = self.simulation.tax_benefit_system.variables.get( + variable_name + ) + if variable.entity.is_person: + return self.sum(self.members(variable_name, period, options)) + else: + return super().__call__(variable_name, period, options) + def clone( self, simulation: "Simulation", members: Population ) -> "GroupPopulation": diff --git a/policyengine_core/populations/population.py b/policyengine_core/populations/population.py index b40fb0bd3..a04b8de63 100644 --- a/policyengine_core/populations/population.py +++ b/policyengine_core/populations/population.py @@ -34,6 +34,9 @@ def clone(self, simulation: "Simulation") -> "Population": result.ids = self.ids return result + def has_any_input(self, variable_name: str) -> bool: + return len(self.get_holder(variable_name).get_known_periods()) > 0 + def empty_array(self) -> numpy.ndarray: return numpy.zeros(self.count) diff --git a/policyengine_core/simulations/microsimulation.py b/policyengine_core/simulations/microsimulation.py index 2ef2334cf..905aa6ad3 100644 --- a/policyengine_core/simulations/microsimulation.py +++ b/policyengine_core/simulations/microsimulation.py @@ -1,6 +1,6 @@ from typing import Dict, Type -from microdf import MicroDataFrame, MicroSeries +from policyengine_core.weighting import MicroDataFrame, MicroSeries import numpy as np from policyengine_core.data.dataset import Dataset from policyengine_core.periods import Period diff --git a/policyengine_core/simulations/simulation.py b/policyengine_core/simulations/simulation.py index cc8d960e0..4911e5483 100644 --- a/policyengine_core/simulations/simulation.py +++ b/policyengine_core/simulations/simulation.py @@ -76,6 +76,9 @@ class Simulation: macro_cache_write: bool = True """Whether to write to the macro cache.""" + start_instant: str = None + """The earliest data input instant of the simulation.""" + def __init__( self, tax_benefit_system: "TaxBenefitSystem" = None, @@ -155,6 +158,10 @@ def __init__( ) if isinstance(dataset, type): self.dataset: Dataset = dataset(require=True) + elif isinstance(dataset, pd.DataFrame): + self.dataset = Dataset.from_dataframe( + dataset, self.default_input_period + ) else: self.dataset = dataset self.build_from_dataset() @@ -242,6 +249,9 @@ def build_from_dataset(self) -> None: + "Make sure you have downloaded or built it using the `policyengine-core data` command." ) from e + if self.dataset.data_format == Dataset.FLAT_FILE: + data = {col: data[col].values for col in data.columns} + person_entity = self.tax_benefit_system.person_entity entity_id_field = f"{person_entity.key}_id" if self.dataset.data_format != Dataset.FLAT_FILE: @@ -250,14 +260,11 @@ def build_from_dataset(self) -> None: ), f"Missing {entity_id_field} column in the dataset. Each person entity must have an ID array defined for ETERNITY." elif entity_id_field not in data: data[entity_id_field] = np.arange(len(data)) - if self.dataset.data_format != Dataset.FLAT_FILE: - get_eternity_array = lambda ds: ( - ds[list(ds.keys())[0]] - if self.dataset.data_format == Dataset.TIME_PERIOD_ARRAYS - else ds - ) - else: - get_eternity_array = lambda ds: ds + get_eternity_array = lambda ds: ( + ds[list(ds.keys())[0]] + if self.dataset.data_format == Dataset.TIME_PERIOD_ARRAYS + else ds + ) entity_ids = get_eternity_array(data[entity_id_field]) builder.declare_person_entity(person_entity.key, entity_ids) @@ -268,7 +275,12 @@ def build_from_dataset(self) -> None: entity_id_field in data ), f"Missing {entity_id_field} column in the dataset. Each group entity must have an ID array defined for ETERNITY." elif entity_id_field not in data: - data[entity_id_field] = np.arange(len(data)) + if f"person_{group_entity.key}_id" in data: + data[entity_id_field] = np.arange( + len(np.unique(data[f"person_{group_entity.key}_id"])) + ) + else: + data[entity_id_field] = np.arange(len(data)) entity_ids = get_eternity_array(data[entity_id_field]) builder.declare_entity(group_entity.key, entity_ids) @@ -333,9 +345,6 @@ def build_from_dataset(self) -> None: ) if variable_name not in self.tax_benefit_system.variables: - logging.warn( - f"Variable {variable_name} not found. Skipping." - ) continue variable_meta = self.tax_benefit_system.get_variable( @@ -355,7 +364,9 @@ def build_from_dataset(self) -> None: self.set_input(variable, time_period, entity_level_data) - self.default_calculation_period = self.dataset.time_period + self.default_calculation_period = ( + self.dataset.time_period or self.default_calculation_period + ) self.tax_benefit_system.data_modified = False @@ -684,6 +695,8 @@ def _calculate( ): # Variables with a calculate-output property specify last_known_period = sorted(known_periods)[-1] + if last_known_period.start > period.start: + return holder.default_array() array = holder.get_array(last_known_period) else: array = holder.default_array() @@ -1139,10 +1152,12 @@ def set_input( If a ``set_input`` property has been set for the variable, this method may accept inputs for periods not matching the ``definition_period`` of the variable. To read more about this, check the `documentation `_. """ + period = periods.period(period) + if self.start_instant is None or self.start_instant > period.start: + self.start_instant = period.start variable = self.tax_benefit_system.get_variable( variable_name, check_existence=True ) - period = periods.period(period) if (variable.end is not None) and (period.start.date > variable.end): return self.get_holder(variable_name).set_input( diff --git a/setup.py b/setup.py index 0bc1fa55f..1ca417f29 100644 --- a/setup.py +++ b/setup.py @@ -23,8 +23,6 @@ "psutil<6", "wheel<1", "h5py>=3,<4", - "microdf_python>=0.3.0,<1", - "tqdm>=4.46.0,<5", "requests>=2.27.1,<3", "pandas>=1", "plotly>=5.6.0,<6",