|
| 1 | +# ActivitySim |
| 2 | +# See full license in LICENSE.txt. |
| 3 | + |
| 4 | +import pandas as pd |
| 5 | +import pandas.testing as pdt |
| 6 | + |
| 7 | +from activitysim.abm.models.vehicle_type_choice import ( |
| 8 | + get_combinatorial_vehicle_alternatives, |
| 9 | + construct_model_alternatives, |
| 10 | + VehicleTypeChoiceSettings, |
| 11 | +) |
| 12 | +from activitysim.core import workflow |
| 13 | + |
| 14 | + |
| 15 | +def test_vehicle_type_alts(): |
| 16 | + state = workflow.State.make_default(__file__) |
| 17 | + |
| 18 | + alts_cats_dict = { |
| 19 | + "body_type": ["Car", "SUV"], |
| 20 | + "fuel_type": ["Gas", "BEV"], |
| 21 | + "age": [1, 2, 3], |
| 22 | + } |
| 23 | + |
| 24 | + alts_wide, alts_long = get_combinatorial_vehicle_alternatives(alts_cats_dict) |
| 25 | + |
| 26 | + # alts are initially constructed combinatorially |
| 27 | + assert len(alts_long) == 12, "alts_long should have 12 rows" |
| 28 | + assert len(alts_wide) == 12, "alts_wide should have 12 rows" |
| 29 | + |
| 30 | + model_settings = VehicleTypeChoiceSettings.model_construct() |
| 31 | + model_settings.combinatorial_alts = alts_cats_dict |
| 32 | + model_settings.PROBS_SPEC = None |
| 33 | + model_settings.WRITE_OUT_ALTS_FILE = False |
| 34 | + |
| 35 | + # constructing veh type data with missing alts |
| 36 | + vehicle_type_data = pd.DataFrame( |
| 37 | + data={ |
| 38 | + "body_type": ["Car", "Car", "Car", "SUV", "SUV"], |
| 39 | + "fuel_type": ["Gas", "Gas", "BEV", "Gas", "BEV"], |
| 40 | + "age": ["1", "2", "3", "1", "2"], |
| 41 | + "dummy_data": [1, 2, 3, 4, 5], |
| 42 | + }, |
| 43 | + index=[0, 1, 2, 3, 4], |
| 44 | + ) |
| 45 | + |
| 46 | + alts_wide, alts_long = construct_model_alternatives( |
| 47 | + state, model_settings, alts_cats_dict, vehicle_type_data |
| 48 | + ) |
| 49 | + |
| 50 | + # should only have alts left that are in the file |
| 51 | + assert len(alts_long) == 5, "alts_long should have 5 rows" |
| 52 | + |
| 53 | + # indexes need to be the same to choices match alts |
| 54 | + pdt.assert_index_equal(alts_long.index, alts_wide.index) |
| 55 | + |
| 56 | + # columns need to be in correct order for downstream configs |
| 57 | + pdt.assert_index_equal( |
| 58 | + alts_long.columns, pd.Index(["body_type", "age", "fuel_type"]) |
| 59 | + ) |
0 commit comments