From a9dcd0cc1a652d72d7ece06c714dbe05d99b8a33 Mon Sep 17 00:00:00 2001 From: Jeff Newman Date: Mon, 7 Oct 2024 08:47:29 -0500 Subject: [PATCH 1/5] pydantic for estimation settings --- activitysim/core/estimation.py | 90 ++++++++++++++++++++++++---------- 1 file changed, 64 insertions(+), 26 deletions(-) diff --git a/activitysim/core/estimation.py b/activitysim/core/estimation.py index b79618509..585b8d7de 100644 --- a/activitysim/core/estimation.py +++ b/activitysim/core/estimation.py @@ -6,11 +6,14 @@ import os import shutil from pathlib import Path +from typing import Any, Literal import pandas as pd import yaml +from pydantic import BaseModel, class_validators, model_validator from activitysim.core import simulate, workflow +from activitysim.core.configuration import PydanticReadable from activitysim.core.configuration.base import PydanticBase from activitysim.core.util import reindex from activitysim.core.yaml_tools import safe_dump @@ -48,14 +51,53 @@ def estimation_enabled(state): return settings is not None +class SurveyTableConfig(PydanticBase): + file_name: str + index_col: str + + # The dataframe is stored in the loaded config dynamically but not given + # directly in the config file, as it's not a simple serializable object that + # can be written in a YAML file. + df: pd.DataFrame | None = None + + +class EstimationTableRecipeConfig(PydanticBase): + omnibus_tables: dict[str, list[str]] + omnibus_tables_append_columns: list[str] + + +class EstimationConfig(PydanticReadable): + SKIP_BUNDLE_WRITE_FOR: list[str] = [] + EDB_FILETYPE: Literal["csv", "parquet", "pkl"] = "csv" + EDB_ALTS_FILE_FORMAT: Literal["verbose", "compact"] = "verbose" + + enable: bool = False + bundles: list[str] = [] + model_estimation_table_types: dict[str, str] = {} + estimation_table_recipes: dict[str, EstimationTableRecipeConfig] = {} + survey_tables: dict[str, SurveyTableConfig] = {} + + # pydantic class validator to ensure that the model_estimation_table_types + # dictionary is a valid dictionary with string keys and string values, and + # that all the values are in the estimation_table_recipes dictionary + @model_validator(mode="after") + def validate_model_estimation_table_types(self): + for key, value in self.model_estimation_table_types.items(): + if value not in self.estimation_table_recipes: + raise ValueError( + f"model_estimation_table_types value '{value}' not in estimation_table_recipes" + ) + return self + + class Estimator: def __init__( self, state: workflow.State, - bundle_name, - model_name, - estimation_table_recipes, - settings, + bundle_name: str, + model_name: str, + estimation_table_recipes: dict[str, Any], + settings: EstimationConfig, ): logger.info("Initialize Estimator for'%s'" % (model_name,)) @@ -345,7 +387,7 @@ def write_omnibus_table(self): if len(self.omnibus_tables) == 0: return - edbs_to_skip = self.settings.get("SKIP_BUNDLE_WRITE_FOR", []) + edbs_to_skip = self.settings.SKIP_BUNDLE_WRITE_FOR if self.bundle_name in edbs_to_skip: self.debug(f"Skipping write to disk for {self.bundle_name}") return @@ -376,7 +418,7 @@ def write_omnibus_table(self): self.debug(f"sorting tables: {table_names}") df.sort_index(ascending=True, inplace=True, kind="mergesort") - filetype = self.settings.get("EDB_FILETYPE", "csv") + filetype = self.settings.EDB_FILETYPE if filetype == "csv": file_path = self.output_file_path(omnibus_table, "csv") @@ -460,7 +502,7 @@ def write_choosers(self, choosers_df): choosers_df, "choosers", append=True, - filetype=self.settings.get("EDB_FILETYPE", "csv"), + filetype=self.settings.EDB_FILETYPE, ) def write_choices(self, choices): @@ -471,7 +513,7 @@ def write_choices(self, choices): choices, "choices", append=True, - filetype=self.settings.get("EDB_FILETYPE", "csv"), + filetype=self.settings.EDB_FILETYPE, ) def write_override_choices(self, choices): @@ -482,7 +524,7 @@ def write_override_choices(self, choices): choices, "override_choices", append=True, - filetype=self.settings.get("EDB_FILETYPE", "csv"), + filetype=self.settings.EDB_FILETYPE, ) def write_constants(self, constants): @@ -582,7 +624,7 @@ def melt_alternatives(self, df): # 31153,2,util_dist_0_1,1.0 # 31153,3,util_dist_0_1,1.0 - output_format = self.settings.get("EDB_ALTS_FILE_FORMAT", "verbose") + output_format = self.settings.EDB_ALTS_FILE_FORMAT assert output_format in ["verbose", "compact"] if output_format == "compact": @@ -613,7 +655,7 @@ def write_interaction_expression_values(self, df): df, "interaction_expression_values", append=True, - filetype=self.settings.get("EDB_FILETYPE", "csv"), + filetype=self.settings.EDB_FILETYPE, ) def write_expression_values(self, df): @@ -621,7 +663,7 @@ def write_expression_values(self, df): df, "expression_values", append=True, - filetype=self.settings.get("EDB_FILETYPE", "csv"), + filetype=self.settings.EDB_FILETYPE, ) def write_alternatives(self, alternatives_df, bundle_directory=False): @@ -638,7 +680,7 @@ def write_interaction_sample_alternatives(self, alternatives_df): alternatives_df, "interaction_sample_alternatives", append=True, - filetype=self.settings.get("EDB_FILETYPE", "csv"), + filetype=self.settings.EDB_FILETYPE, ) def write_interaction_simulate_alternatives(self, interaction_df): @@ -647,7 +689,7 @@ def write_interaction_simulate_alternatives(self, interaction_df): interaction_df, "interaction_simulate_alternatives", append=True, - filetype=self.settings.get("EDB_FILETYPE", "csv"), + filetype=self.settings.EDB_FILETYPE, ) def get_survey_values(self, model_values, table_name, column_names): @@ -690,25 +732,21 @@ def initialize_settings(self, state): return assert not self.settings_initialized - self.settings = state.filesystem.read_model_settings( - ESTIMATION_SETTINGS_FILE_NAME, mandatory=False + self.settings = EstimationConfig.read_settings_file( + state.filesystem, ESTIMATION_SETTINGS_FILE_NAME, mandatory=False ) if not self.settings: # if the model self.settings file is not found, we are not in estimation mode. self.enabled = False else: - self.enabled = self.settings.get("enable", "True") - self.bundles = self.settings.get("bundles", []) + self.enabled = self.settings.enable + self.bundles = self.settings.bundles - self.model_estimation_table_types = self.settings.get( - "model_estimation_table_types", {} - ) - self.estimation_table_recipes = self.settings.get( - "estimation_table_recipes", {} - ) + self.model_estimation_table_types = self.settings.model_estimation_table_types + self.estimation_table_recipes = self.settings.estimation_table_recipes if self.enabled: - self.survey_tables = self.settings.get("survey_tables", {}) + self.survey_tables = self.settings.survey_tables for table_name, table_info in self.survey_tables.items(): assert ( "file_name" in table_info @@ -723,7 +761,7 @@ def initialize_settings(self, state): file_path ), "File for survey table '%s' not found: %s" % (table_name, file_path) df = pd.read_csv(file_path) - index_col = table_info.get("index_col") + index_col = table_info.index_col if index_col is not None: assert ( index_col in df.columns From 29424a708ed3691d9e1238ac88111e5385f44fb7 Mon Sep 17 00:00:00 2001 From: Jeff Newman Date: Mon, 7 Oct 2024 08:56:00 -0500 Subject: [PATCH 2/5] allow df as type in config --- activitysim/core/estimation.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/activitysim/core/estimation.py b/activitysim/core/estimation.py index 585b8d7de..8e1391c68 100644 --- a/activitysim/core/estimation.py +++ b/activitysim/core/estimation.py @@ -58,6 +58,9 @@ class SurveyTableConfig(PydanticBase): # The dataframe is stored in the loaded config dynamically but not given # directly in the config file, as it's not a simple serializable object that # can be written in a YAML file. + class Config: + arbitrary_types_allowed = True + df: pd.DataFrame | None = None From 1fac9b5194275cb8947783c6bb25ab4f3c00df2f Mon Sep 17 00:00:00 2001 From: Jeff Newman Date: Wed, 9 Oct 2024 11:26:04 -0500 Subject: [PATCH 3/5] fix table_info --- activitysim/core/estimation.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/activitysim/core/estimation.py b/activitysim/core/estimation.py index 8e1391c68..26d2a4e6f 100644 --- a/activitysim/core/estimation.py +++ b/activitysim/core/estimation.py @@ -752,13 +752,10 @@ def initialize_settings(self, state): self.survey_tables = self.settings.survey_tables for table_name, table_info in self.survey_tables.items(): assert ( - "file_name" in table_info - ), "No file name specified for survey_table '%s' in %s" % ( - table_name, - ESTIMATION_SETTINGS_FILE_NAME, - ) + table_info.file_name + ), f"No file name specified for survey_table '{table_name}' in {ESTIMATION_SETTINGS_FILE_NAME}" file_path = state.filesystem.get_data_file_path( - table_info["file_name"], mandatory=True + table_info.file_name, mandatory=True ) assert os.path.exists( file_path @@ -784,7 +781,7 @@ def initialize_settings(self, state): df = df[df.household_id.isin(pipeline_hh_ids)] # add the table df to survey_tables - table_info["df"] = df + table_info.df = df self.settings_initialized = True From e61133331370ca289b022faa241ce75052e1042a Mon Sep 17 00:00:00 2001 From: Jeff Newman Date: Thu, 10 Oct 2024 12:18:52 -0500 Subject: [PATCH 4/5] repair for Pydantic --- activitysim/core/estimation.py | 39 ++++++++++++++++++++++++++-------- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/activitysim/core/estimation.py b/activitysim/core/estimation.py index 26d2a4e6f..1c18d9779 100644 --- a/activitysim/core/estimation.py +++ b/activitysim/core/estimation.py @@ -75,9 +75,30 @@ class EstimationConfig(PydanticReadable): EDB_ALTS_FILE_FORMAT: Literal["verbose", "compact"] = "verbose" enable: bool = False + """Flag to enable estimation.""" + bundles: list[str] = [] + """List of component names to create EDBs for.""" + model_estimation_table_types: dict[str, str] = {} + """Mapping of component names to estimation table types. + + The keys of this mapping are the model component names, and the values are the + names of the estimation table recipes that should be used to generate the + estimation tables for the model component. The recipes are generally related + to the generic model types, such as 'simple_simulate', 'interaction_simulate', + 'interaction_sample_simulate', etc. + """ + estimation_table_recipes: dict[str, EstimationTableRecipeConfig] = {} + """Mapping of estimation table recipe names to their configurations. + + The keys of this mapping are the names of the estimation table recipes. + The recipes are generally related to the generic model types, such as + 'simple_simulate', 'interaction_simulate', 'interaction_sample_simulate', + etc. The values are the configurations for the estimation table recipes. + """ + survey_tables: dict[str, SurveyTableConfig] = {} # pydantic class validator to ensure that the model_estimation_table_types @@ -99,7 +120,7 @@ def __init__( state: workflow.State, bundle_name: str, model_name: str, - estimation_table_recipes: dict[str, Any], + estimation_table_recipe: EstimationTableRecipeConfig, settings: EstimationConfig, ): logger.info("Initialize Estimator for'%s'" % (model_name,)) @@ -108,7 +129,7 @@ def __init__( self.bundle_name = bundle_name self.model_name = model_name self.settings_name = model_name - self.estimation_table_recipes = estimation_table_recipes + self.estimation_table_recipe = estimation_table_recipe self.estimating = True self.settings = settings @@ -129,10 +150,10 @@ def __init__( # assert 'override_choices' in self.model_settings, \ # "override_choices not found for %s in %s." % (model_name, ESTIMATION_SETTINGS_FILE_NAME) - self.omnibus_tables = self.estimation_table_recipes["omnibus_tables"] - self.omnibus_tables_append_columns = self.estimation_table_recipes[ - "omnibus_tables_append_columns" - ] + self.omnibus_tables = self.estimation_table_recipe.omnibus_tables + self.omnibus_tables_append_columns = ( + self.estimation_table_recipe.omnibus_tables_append_columns + ) self.tables = {} self.tables_to_cache = [ table_name @@ -724,8 +745,8 @@ class EstimationManager(object): def __init__(self): self.settings_initialized = False self.bundles = [] - self.estimation_table_recipes = {} - self.model_estimation_table_types = {} + self.estimation_table_recipes: dict[str, EstimationTableRecipeConfig] = {} + self.model_estimation_table_types: dict[str, str] = {} self.estimating = {} self.settings = None @@ -843,7 +864,7 @@ def begin_estimation( state, bundle_name, model_name, - estimation_table_recipes=self.estimation_table_recipes[ + estimation_table_recipe=self.estimation_table_recipes[ model_estimation_table_type ], settings=self.settings, From f5a5e426335531c25e851bbc68fd3328399c1a29 Mon Sep 17 00:00:00 2001 From: Jeff Newman Date: Sun, 13 Oct 2024 08:32:05 -0500 Subject: [PATCH 5/5] df is attribute --- activitysim/core/estimation.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/activitysim/core/estimation.py b/activitysim/core/estimation.py index 1c18d9779..8109be7e9 100644 --- a/activitysim/core/estimation.py +++ b/activitysim/core/estimation.py @@ -6,11 +6,10 @@ import os import shutil from pathlib import Path -from typing import Any, Literal +from typing import Literal import pandas as pd -import yaml -from pydantic import BaseModel, class_validators, model_validator +from pydantic import model_validator from activitysim.core import simulate, workflow from activitysim.core.configuration import PydanticReadable @@ -514,7 +513,7 @@ def write_coefficients_template(self, model_settings): assert self.estimating if isinstance(model_settings, PydanticBase): - model_settings = model_settings.dict() + model_settings = model_settings.model_dump() coefficients_df = simulate.read_model_coefficient_template( self.state.filesystem, model_settings ) @@ -587,7 +586,7 @@ def write_model_settings( ) assert not os.path.isfile(file_path) with open(file_path, "w") as f: - safe_dump(model_settings.dict(), f) + safe_dump(model_settings.model_dump(), f) else: if "include_settings" in model_settings: file_path = self.output_file_path( @@ -882,7 +881,7 @@ def get_survey_table(self, table_name): "EstimationManager. get_survey_table: survey table '%s' not in survey_tables" % table_name ) - df = self.survey_tables[table_name].get("df") + df = self.survey_tables[table_name].df return df def get_survey_values(self, model_values, table_name, column_names):