Skip to content

Commit

Permalink
Merge pull request #759 from camsys/explicit-chunk-with-config
Browse files Browse the repository at this point in the history
Explicit chunking
  • Loading branch information
jpn-- authored Feb 6, 2024
2 parents 25e4f6c + 3c047e0 commit 69465d9
Show file tree
Hide file tree
Showing 9 changed files with 231 additions and 40 deletions.
49 changes: 38 additions & 11 deletions activitysim/abm/models/accessibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,30 +32,54 @@ class AccessibilitySettings(PydanticReadable):
SPEC: str = "accessibility.csv"
"""Filename for the accessibility specification (csv) file."""

explicit_chunk: int = 0
"""If > 0, use this chunk size instead of adaptive chunking."""


@nb.njit
def _accumulate_accessibility(arr, orig_zone_count, dest_zone_count):
assert arr.size == orig_zone_count * dest_zone_count
arr2 = arr.reshape((orig_zone_count, dest_zone_count))
assert arr.ndim == 1
i = 0
result = np.empty((orig_zone_count,), dtype=arr.dtype)
for o in range(orig_zone_count):
x = 0
for d in range(dest_zone_count):
x += arr2[o, d]
x += arr[i]
i += 1
result[o] = np.log1p(x)
return result


def compute_accessibilities_for_zones(
state,
accessibility_df,
land_use_df,
assignment_spec,
constants,
network_los,
trace_label,
chunk_sizer,
state: workflow.State,
accessibility_df: pd.DataFrame,
land_use_df: pd.DataFrame,
assignment_spec: dict,
constants: dict,
network_los: los.Network_LOS,
trace_label: str,
chunk_sizer: chunk.ChunkSizer,
):
"""
Compute accessibility for each zone in land use file using expressions from accessibility_spec.
Parameters
----------
state : workflow.State
accessibility_df : pd.DataFrame
land_use_df : pd.DataFrame
assignment_spec : dict
constants : dict
network_los : los.Network_LOS
trace_label : str
chunk_sizer : chunk.ChunkSizer
Returns
-------
accessibility_df : pd.DataFrame
The accessibility_df is updated in place.
"""
orig_zones = accessibility_df.index.values
dest_zones = land_use_df.index.values

Expand Down Expand Up @@ -215,13 +239,16 @@ def compute_accessibility(
)

accessibilities_list = []
explicit_chunk_size = model_settings.explicit_chunk

for (
_i,
chooser_chunk,
_chunk_trace_label,
chunk_sizer,
) in chunk.adaptive_chunked_choosers(state, accessibility_df, trace_label):
) in chunk.adaptive_chunked_choosers(
state, accessibility_df, trace_label, explicit_chunk_size=explicit_chunk_size
):
accessibilities = compute_accessibilities_for_zones(
state,
chooser_chunk,
Expand Down
63 changes: 63 additions & 0 deletions activitysim/abm/test/test_agg_accessibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from __future__ import annotations

import logging

import pytest

from activitysim.abm import models # noqa: F401
from activitysim.abm.models.accessibility import (
AccessibilitySettings,
compute_accessibility,
)
from activitysim.core import workflow

logger = logging.getLogger(__name__)


@pytest.fixture
def state() -> workflow.State:
state = workflow.create_example("prototype_mtc", temp=True)

state.settings.models = [
"initialize_landuse",
"initialize_households",
"compute_accessibility",
]
state.settings.chunk_size = 0
state.settings.sharrow = False

state.run.by_name("initialize_landuse")
state.run.by_name("initialize_households")
return state


def test_simple_agg_accessibility(state, dataframe_regression):
state.run.by_name("compute_accessibility")
df = state.get_dataframe("accessibility")
dataframe_regression.check(df, basename="simple_agg_accessibility")


def test_agg_accessibility_explicit_chunking(state, dataframe_regression):
# set top level settings
state.settings.chunk_size = 0
state.settings.sharrow = False
state.settings.chunk_training_mode = "explicit"

# read the accessibility settings and override the explicit chunk size to 5
model_settings = AccessibilitySettings.read_settings_file(
state.filesystem, "accessibility.yaml"
)
model_settings.explicit_chunk = 5

compute_accessibility(
state,
state.get_dataframe("land_use"),
state.get_dataframe("accessibility"),
state.get("network_los"),
model_settings,
model_settings_file_name="accessibility.yaml",
trace_label="compute_accessibility",
output_table_name="accessibility",
)
df = state.get_dataframe("accessibility")
dataframe_regression.check(df, basename="simple_agg_accessibility")
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
zone_id,auPkRetail,auPkTotal,auOpRetail,auOpTotal,trPkRetail,trPkTotal,trOpRetail,trOpTotal,nmRetail,nmTotal
0,9.3164942696213568,12.615175743409841,9.3074367804092777,12.607849383502469,7.7642635203141968,11.145248204314596,7.6930860038975712,11.037285967769643,8.1373609284815895,11.726242204251774
1,9.3168979433052908,12.613460949773618,9.3046270180951538,12.604209004116514,7.5113009238919934,10.950045517942753,7.4270600345597773,10.763101915020352,8.1427168965425558,11.724186002096861
2,9.2932169855583542,12.580014484201365,9.2862416670099286,12.574901916285086,7.3409752547469438,10.787608449410779,7.2526778560064189,10.574953629615715,8.0503691347033364,11.478912540319461
3,9.3573494887919679,12.630894217760538,9.3482485710113998,12.623585758033322,7.8733268188651611,11.224171200372194,7.8143652246183066,11.135415940164263,8.3711974981452286,11.775230687719375
4,9.3435505366989382,12.585069456547828,9.3332621841590289,12.574553613617503,7.5893556698506153,11.082549781423353,7.5495574089217605,11.027965011367975,8.3180592569770333,11.431764418643981
5,9.2713502507871883,12.523449294093886,9.2657623133569711,12.519697725868093,7.3138724828278905,10.504310979303222,7.0683412975590123,10.251789948959422,7.8382412773559516,11.023737623179843
6,9.2931944176067329,12.528401489853936,9.2863728392168525,12.52041616561077,7.6419100510389839,10.805002739363189,7.6078784435600868,10.752509743759392,8.0169148075612071,11.108804747288168
7,9.2678442418060758,12.497146015767587,9.2621330948390046,12.489886065239302,7.5469338237309387,10.834136335989049,7.5014237921006828,10.779320100783975,7.9819505240836239,11.052152868115712
8,9.1895029665940431,12.42603649432956,9.184035053974922,12.415459802362889,7.188751493522151,10.303186212218705,7.149056566233341,10.260609523175923,7.4156298027129628,10.75866342061707
9,9.1860041373516861,12.40389009503904,9.1807619960868472,12.396344385917818,7.3793358243378222,10.548674769786773,7.3065218950712447,10.495921674032259,7.5678261562359008,10.694411485785926
10,9.3200926649609794,12.519242143782318,9.3150950606590026,12.511758212122075,7.4557019917326173,10.875601348026509,7.3483680514593566,10.762777861942512,8.2282865778153074,11.171156639341758
11,9.3515905766957719,12.600777214457102,9.3402871188008589,12.590072565945791,7.9459651463751646,11.204374985337394,7.8463256139030397,11.074533943328571,8.420517825501955,11.618972560237365
12,9.3475957875421258,12.610940370257774,9.3373286290969943,12.601590484236365,7.7678939430595539,11.12100647463696,7.691841626412466,11.012476514764131,8.4227464698159356,11.742390115774514
13,9.3272875917488811,12.61272185509814,9.3195224070699076,12.605842856331305,7.9829144769225122,11.205704184915069,7.9147382904661239,11.09630520845371,8.2936062476422165,11.736593006351654
14,9.2849351600384047,12.581337475036822,9.2777982690366851,12.575463251817387,7.6566141377182202,10.99707556494131,7.5743974507217873,10.914272271565647,8.0004874415095593,11.541814468777813
15,9.3121586675513246,12.554715357067975,9.3098515403136357,12.554250369775952,7.0161536446213777,10.534220863424366,6.9452061747018057,10.442447038350544,8.2473032556208885,11.373742456242004
16,9.2525132367431926,12.480891480108502,9.2515129069576805,12.479315380270007,6.6611995964557575,9.84475304878708,6.5626838669177907,9.7353179337959279,7.6671416984161134,10.785216324011325
17,9.2493602579025467,12.438990589690656,9.2489616305878055,12.440034826308484,7.0859296965612435,10.25268796871535,6.997755995841743,10.137302158694951,7.5966361055753779,10.414585321652353
18,9.1690294854761021,12.357455449640511,9.169914338241016,12.359583887732027,6.0886234793709892,9.29599202765759,5.9074690455933911,9.1012171904995682,7.0886934502985248,9.8650009904318754
19,9.2217425939140494,12.420066322549278,9.2188633511870108,12.41597699421243,6.636652875558787,9.8019044704452742,6.5908204292849781,9.7532014746502398,7.4941970400287934,10.367677845680188
20,9.3219157803287924,12.515866541333265,9.3251764782856572,12.518960683082542,7.428996554006094,10.580037613397844,7.3427336309884295,10.454321473639835,8.1084359799671955,11.011608267238865
21,9.2296522071417968,12.543187229493718,9.2196001636905649,12.535205063776825,7.1509239235248376,10.713897532287987,6.9410839531896329,10.463984552764911,7.6379082836081578,11.319586635314565
22,9.1161497032019767,12.433017912353973,9.1078484333015215,12.426054879223617,6.2116445474246689,9.7707806110979902,6.1182229314057297,9.6877986968503578,6.8876395653603648,10.656699616249735
23,9.2437967585247165,12.55097406396871,9.2300864314796112,12.541349601021635,7.3227417893504727,10.850763529501046,7.1151212727695663,10.577146937229784,7.7136276931827608,11.346711473571119
24,9.1982619817999431,12.494596324310164,9.1914370661962863,12.490871942939226,7.2966461173380575,10.729604500822386,7.1803656319823519,10.549488525934116,7.6165179958947329,11.016222756734685
6 changes: 5 additions & 1 deletion activitysim/cli/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from __future__ import annotations

import logging
import os
import sys

Expand Down Expand Up @@ -67,10 +70,11 @@ def main():
sys.exit(workflows.main(sys.argv[2:]))
else:
sys.exit(asim.execute())
except Exception:
except Exception as err:
# if we are in the debugger, re-raise the error instead of exiting
if sys.gettrace() is not None:
raise
logging.exception(err)
sys.exit(99)


Expand Down
Loading

0 comments on commit 69465d9

Please sign in to comment.