Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Estimation Pydantic #2

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 103 additions & 45 deletions activitysim/core/estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
import os
import shutil
from pathlib import Path
from typing import Literal

import pandas as pd
import yaml
from pydantic import model_validator

from activitysim.core import simulate, workflow
from activitysim.core.configuration import PydanticReadable
from activitysim.core.configuration.base import PydanticBase
from activitysim.core.util import reindex
from activitysim.core.yaml_tools import safe_dump
Expand Down Expand Up @@ -48,22 +50,85 @@ def estimation_enabled(state):
return settings is not None


class SurveyTableConfig(PydanticBase):
file_name: str
index_col: str

# The dataframe is stored in the loaded config dynamically but not given
# directly in the config file, as it's not a simple serializable object that
# can be written in a YAML file.
class Config:
arbitrary_types_allowed = True

df: pd.DataFrame | None = None


class EstimationTableRecipeConfig(PydanticBase):
omnibus_tables: dict[str, list[str]]
omnibus_tables_append_columns: list[str]


class EstimationConfig(PydanticReadable):
SKIP_BUNDLE_WRITE_FOR: list[str] = []
EDB_FILETYPE: Literal["csv", "parquet", "pkl"] = "csv"
EDB_ALTS_FILE_FORMAT: Literal["verbose", "compact"] = "verbose"

enable: bool = False
"""Flag to enable estimation."""

bundles: list[str] = []
"""List of component names to create EDBs for."""

model_estimation_table_types: dict[str, str] = {}
"""Mapping of component names to estimation table types.

The keys of this mapping are the model component names, and the values are the
names of the estimation table recipes that should be used to generate the
estimation tables for the model component. The recipes are generally related
to the generic model types, such as 'simple_simulate', 'interaction_simulate',
'interaction_sample_simulate', etc.
"""

estimation_table_recipes: dict[str, EstimationTableRecipeConfig] = {}
"""Mapping of estimation table recipe names to their configurations.

The keys of this mapping are the names of the estimation table recipes.
The recipes are generally related to the generic model types, such as
'simple_simulate', 'interaction_simulate', 'interaction_sample_simulate',
etc. The values are the configurations for the estimation table recipes.
"""

survey_tables: dict[str, SurveyTableConfig] = {}

# pydantic class validator to ensure that the model_estimation_table_types
# dictionary is a valid dictionary with string keys and string values, and
# that all the values are in the estimation_table_recipes dictionary
@model_validator(mode="after")
def validate_model_estimation_table_types(self):
for key, value in self.model_estimation_table_types.items():
if value not in self.estimation_table_recipes:
raise ValueError(
f"model_estimation_table_types value '{value}' not in estimation_table_recipes"
)
return self


class Estimator:
def __init__(
self,
state: workflow.State,
bundle_name,
model_name,
estimation_table_recipes,
settings,
bundle_name: str,
model_name: str,
estimation_table_recipe: EstimationTableRecipeConfig,
settings: EstimationConfig,
):
logger.info("Initialize Estimator for'%s'" % (model_name,))

self.state = state
self.bundle_name = bundle_name
self.model_name = model_name
self.settings_name = model_name
self.estimation_table_recipes = estimation_table_recipes
self.estimation_table_recipe = estimation_table_recipe
self.estimating = True
self.settings = settings

Expand All @@ -84,10 +149,10 @@ def __init__(
# assert 'override_choices' in self.model_settings, \
# "override_choices not found for %s in %s." % (model_name, ESTIMATION_SETTINGS_FILE_NAME)

self.omnibus_tables = self.estimation_table_recipes["omnibus_tables"]
self.omnibus_tables_append_columns = self.estimation_table_recipes[
"omnibus_tables_append_columns"
]
self.omnibus_tables = self.estimation_table_recipe.omnibus_tables
self.omnibus_tables_append_columns = (
self.estimation_table_recipe.omnibus_tables_append_columns
)
self.tables = {}
self.tables_to_cache = [
table_name
Expand Down Expand Up @@ -345,7 +410,7 @@ def write_omnibus_table(self):
if len(self.omnibus_tables) == 0:
return

edbs_to_skip = self.settings.get("SKIP_BUNDLE_WRITE_FOR", [])
edbs_to_skip = self.settings.SKIP_BUNDLE_WRITE_FOR
if self.bundle_name in edbs_to_skip:
self.debug(f"Skipping write to disk for {self.bundle_name}")
return
Expand Down Expand Up @@ -376,7 +441,7 @@ def write_omnibus_table(self):
self.debug(f"sorting tables: {table_names}")
df.sort_index(ascending=True, inplace=True, kind="mergesort")

filetype = self.settings.get("EDB_FILETYPE", "csv")
filetype = self.settings.EDB_FILETYPE

if filetype == "csv":
file_path = self.output_file_path(omnibus_table, "csv")
Expand Down Expand Up @@ -448,7 +513,7 @@ def write_coefficients_template(self, model_settings):
assert self.estimating

if isinstance(model_settings, PydanticBase):
model_settings = model_settings.dict()
model_settings = model_settings.model_dump()
coefficients_df = simulate.read_model_coefficient_template(
self.state.filesystem, model_settings
)
Expand All @@ -460,7 +525,7 @@ def write_choosers(self, choosers_df):
choosers_df,
"choosers",
append=True,
filetype=self.settings.get("EDB_FILETYPE", "csv"),
filetype=self.settings.EDB_FILETYPE,
)

def write_choices(self, choices):
Expand All @@ -471,7 +536,7 @@ def write_choices(self, choices):
choices,
"choices",
append=True,
filetype=self.settings.get("EDB_FILETYPE", "csv"),
filetype=self.settings.EDB_FILETYPE,
)

def write_override_choices(self, choices):
Expand All @@ -482,7 +547,7 @@ def write_override_choices(self, choices):
choices,
"override_choices",
append=True,
filetype=self.settings.get("EDB_FILETYPE", "csv"),
filetype=self.settings.EDB_FILETYPE,
)

def write_constants(self, constants):
Expand Down Expand Up @@ -521,7 +586,7 @@ def write_model_settings(
)
assert not os.path.isfile(file_path)
with open(file_path, "w") as f:
safe_dump(model_settings.dict(), f)
safe_dump(model_settings.model_dump(), f)
else:
if "include_settings" in model_settings:
file_path = self.output_file_path(
Expand Down Expand Up @@ -582,7 +647,7 @@ def melt_alternatives(self, df):
# 31153,2,util_dist_0_1,1.0
# 31153,3,util_dist_0_1,1.0

output_format = self.settings.get("EDB_ALTS_FILE_FORMAT", "verbose")
output_format = self.settings.EDB_ALTS_FILE_FORMAT
assert output_format in ["verbose", "compact"]

if output_format == "compact":
Expand Down Expand Up @@ -613,15 +678,15 @@ def write_interaction_expression_values(self, df):
df,
"interaction_expression_values",
append=True,
filetype=self.settings.get("EDB_FILETYPE", "csv"),
filetype=self.settings.EDB_FILETYPE,
)

def write_expression_values(self, df):
self.write_table(
df,
"expression_values",
append=True,
filetype=self.settings.get("EDB_FILETYPE", "csv"),
filetype=self.settings.EDB_FILETYPE,
)

def write_alternatives(self, alternatives_df, bundle_directory=False):
Expand All @@ -638,7 +703,7 @@ def write_interaction_sample_alternatives(self, alternatives_df):
alternatives_df,
"interaction_sample_alternatives",
append=True,
filetype=self.settings.get("EDB_FILETYPE", "csv"),
filetype=self.settings.EDB_FILETYPE,
)

def write_interaction_simulate_alternatives(self, interaction_df):
Expand All @@ -647,7 +712,7 @@ def write_interaction_simulate_alternatives(self, interaction_df):
interaction_df,
"interaction_simulate_alternatives",
append=True,
filetype=self.settings.get("EDB_FILETYPE", "csv"),
filetype=self.settings.EDB_FILETYPE,
)

def get_survey_values(self, model_values, table_name, column_names):
Expand Down Expand Up @@ -679,8 +744,8 @@ class EstimationManager(object):
def __init__(self):
self.settings_initialized = False
self.bundles = []
self.estimation_table_recipes = {}
self.model_estimation_table_types = {}
self.estimation_table_recipes: dict[str, EstimationTableRecipeConfig] = {}
self.model_estimation_table_types: dict[str, str] = {}
self.estimating = {}
self.settings = None

Expand All @@ -690,40 +755,33 @@ def initialize_settings(self, state):
return

assert not self.settings_initialized
self.settings = state.filesystem.read_model_settings(
ESTIMATION_SETTINGS_FILE_NAME, mandatory=False
self.settings = EstimationConfig.read_settings_file(
state.filesystem, ESTIMATION_SETTINGS_FILE_NAME, mandatory=False
)
if not self.settings:
# if the model self.settings file is not found, we are not in estimation mode.
self.enabled = False
else:
self.enabled = self.settings.get("enable", "True")
self.bundles = self.settings.get("bundles", [])
self.enabled = self.settings.enable
self.bundles = self.settings.bundles

self.model_estimation_table_types = self.settings.get(
"model_estimation_table_types", {}
)
self.estimation_table_recipes = self.settings.get(
"estimation_table_recipes", {}
)
self.model_estimation_table_types = self.settings.model_estimation_table_types
self.estimation_table_recipes = self.settings.estimation_table_recipes

if self.enabled:
self.survey_tables = self.settings.get("survey_tables", {})
self.survey_tables = self.settings.survey_tables
for table_name, table_info in self.survey_tables.items():
assert (
"file_name" in table_info
), "No file name specified for survey_table '%s' in %s" % (
table_name,
ESTIMATION_SETTINGS_FILE_NAME,
)
table_info.file_name
), f"No file name specified for survey_table '{table_name}' in {ESTIMATION_SETTINGS_FILE_NAME}"
file_path = state.filesystem.get_data_file_path(
table_info["file_name"], mandatory=True
table_info.file_name, mandatory=True
)
assert os.path.exists(
file_path
), "File for survey table '%s' not found: %s" % (table_name, file_path)
df = pd.read_csv(file_path)
index_col = table_info.get("index_col")
index_col = table_info.index_col
if index_col is not None:
assert (
index_col in df.columns
Expand All @@ -743,7 +801,7 @@ def initialize_settings(self, state):
df = df[df.household_id.isin(pipeline_hh_ids)]

# add the table df to survey_tables
table_info["df"] = df
table_info.df = df

self.settings_initialized = True

Expand Down Expand Up @@ -805,7 +863,7 @@ def begin_estimation(
state,
bundle_name,
model_name,
estimation_table_recipes=self.estimation_table_recipes[
estimation_table_recipe=self.estimation_table_recipes[
model_estimation_table_type
],
settings=self.settings,
Expand All @@ -823,7 +881,7 @@ def get_survey_table(self, table_name):
"EstimationManager. get_survey_table: survey table '%s' not in survey_tables"
% table_name
)
df = self.survey_tables[table_name].get("df")
df = self.survey_tables[table_name].df
return df

def get_survey_values(self, model_values, table_name, column_names):
Expand Down
Loading