From 87b414f099e8ccc669287324d66b902c7406dc6f Mon Sep 17 00:00:00 2001 From: David Hensle Date: Thu, 5 Dec 2024 11:59:43 -0800 Subject: [PATCH] updating estimation checks to allow for non-zero household_sample_size --- activitysim/abm/models/joint_tour_frequency.py | 5 ++++- activitysim/abm/models/non_mandatory_tour_frequency.py | 4 +++- activitysim/abm/models/stop_frequency.py | 6 +++++- activitysim/abm/models/trip_destination.py | 2 ++ activitysim/core/estimation.py | 2 +- 5 files changed, 15 insertions(+), 4 deletions(-) diff --git a/activitysim/abm/models/joint_tour_frequency.py b/activitysim/abm/models/joint_tour_frequency.py index 1700c143b..9d918c955 100644 --- a/activitysim/abm/models/joint_tour_frequency.py +++ b/activitysim/abm/models/joint_tour_frequency.py @@ -192,8 +192,11 @@ def joint_tour_frequency( print(f"len(joint_tours) {len(joint_tours)}") different = False + # need to check households as well because the full survey sample may not be used + # (e.g. if we set household_sample_size in settings.yaml) survey_tours_not_in_tours = survey_tours[ ~survey_tours.index.isin(joint_tours.index) + & survey_tours.household_id.isin(households.index) ] if len(survey_tours_not_in_tours) > 0: print(f"survey_tours_not_in_tours\n{survey_tours_not_in_tours}") @@ -201,7 +204,7 @@ def joint_tour_frequency( tours_not_in_survey_tours = joint_tours[ ~joint_tours.index.isin(survey_tours.index) ] - if len(survey_tours_not_in_tours) > 0: + if len(tours_not_in_survey_tours) > 0: print(f"tours_not_in_survey_tours\n{tours_not_in_survey_tours}") different = True assert not different diff --git a/activitysim/abm/models/non_mandatory_tour_frequency.py b/activitysim/abm/models/non_mandatory_tour_frequency.py index 4d80c26ba..f656840e1 100644 --- a/activitysim/abm/models/non_mandatory_tour_frequency.py +++ b/activitysim/abm/models/non_mandatory_tour_frequency.py @@ -442,8 +442,10 @@ def non_mandatory_tour_frequency( if estimator: # make sure they created the right tours 
survey_tours = estimation.manager.get_survey_table("tours").sort_index() + # need the household_id check below in case household_sample_size != 0 non_mandatory_survey_tours = survey_tours[ - survey_tours.tour_category == "non_mandatory" + (survey_tours.tour_category == "non_mandatory") + & survey_tours.household_id.isin(persons.household_id) ] # need to remove the pure-escort tours from the survey tours table for comparison below if state.is_table("school_escort_tours"): diff --git a/activitysim/abm/models/stop_frequency.py b/activitysim/abm/models/stop_frequency.py index 70755ff86..6d24e3e46 100644 --- a/activitysim/abm/models/stop_frequency.py +++ b/activitysim/abm/models/stop_frequency.py @@ -277,7 +277,11 @@ def stop_frequency( survey_trips = estimation.manager.get_survey_table(table_name="trips") different = False - survey_trips_not_in_trips = survey_trips[~survey_trips.index.isin(trips.index)] + # need the check below on household_id in case household_sample_size != 0 + survey_trips_not_in_trips = survey_trips[ + ~survey_trips.index.isin(trips.index) + & survey_trips.household_id.isin(trips.household_id) + ] if len(survey_trips_not_in_trips) > 0: print(f"survey_trips_not_in_trips\n{survey_trips_not_in_trips}") different = True diff --git a/activitysim/abm/models/trip_destination.py b/activitysim/abm/models/trip_destination.py index 6e49e5c79..c831faac1 100644 --- a/activitysim/abm/models/trip_destination.py +++ b/activitysim/abm/models/trip_destination.py @@ -1362,6 +1362,8 @@ def run_trip_destination( # expect all the same trips survey_trips = estimator.get_survey_table("trips").sort_index() + # need to check household_id in case household_sample_size != 0 + survey_trips = survey_trips[survey_trips.household_id.isin(trips.household_id)] assert survey_trips.index.equals(trips.index) first = survey_trips.trip_num == 1 diff --git a/activitysim/core/estimation.py b/activitysim/core/estimation.py index 8f2443bed..143f36f62 100644 --- a/activitysim/core/estimation.py 
+++ b/activitysim/core/estimation.py @@ -831,7 +831,7 @@ def initialize_settings(self, state): if state.settings.multiprocess: pipeline_hh_ids = state.get_table("households").index if table_name == "households": - df = df[df.index.isin(pipeline_hh_ids)] + df = df.reindex(pipeline_hh_ids) assert pipeline_hh_ids.equals( df.index ), "household_ids not equal between survey and pipeline"