Skip to content

Commit

Permalink
Merge pull request #764 from camsys/overflow-protection-2
Browse files Browse the repository at this point in the history
Overflow protection
  • Loading branch information
jpn-- authored Dec 15, 2023
2 parents d8e836f + ff17220 commit a793094
Show file tree
Hide file tree
Showing 8 changed files with 97 additions and 17 deletions.
1 change: 1 addition & 0 deletions activitysim/core/interaction_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ def _interaction_sample(
allow_zero_probs=allow_zero_probs,
trace_label=trace_label,
trace_choosers=choosers,
overflow_protection=not allow_zero_probs,
)
chunk_sizer.log_df(trace_label, "probs", probs)

Expand Down
29 changes: 18 additions & 11 deletions activitysim/core/interaction_sample_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,20 +248,27 @@ def _interaction_sample_simulate(

# convert to probabilities (utilities exponentiated and normalized to probs)
# probs is same shape as utilities, one row per chooser and one column for alternative
probs = logit.utils_to_probs(
state,
utilities_df,
allow_zero_probs=allow_zero_probs,
trace_label=trace_label,
trace_choosers=choosers,
)
chunk_sizer.log_df(trace_label, "probs", probs)

if want_logsums:
logsums = logit.utils_to_logsums(
utilities_df, allow_zero_probs=allow_zero_probs
probs, logsums = logit.utils_to_probs(
state,
utilities_df,
allow_zero_probs=allow_zero_probs,
trace_label=trace_label,
trace_choosers=choosers,
overflow_protection=not allow_zero_probs,
return_logsums=True,
)
chunk_sizer.log_df(trace_label, "logsums", logsums)
else:
probs = logit.utils_to_probs(
state,
utilities_df,
allow_zero_probs=allow_zero_probs,
trace_label=trace_label,
trace_choosers=choosers,
overflow_protection=not allow_zero_probs,
)
chunk_sizer.log_df(trace_label, "probs", probs)

del utilities_df
chunk_sizer.log_df(trace_label, "utilities_df", None)
Expand Down
50 changes: 48 additions & 2 deletions activitysim/core/logit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import logging
import warnings
from builtins import object

import numpy as np
Expand Down Expand Up @@ -130,6 +131,8 @@ def utils_to_probs(
exponentiated=False,
allow_zero_probs=False,
trace_choosers=None,
overflow_protection: bool = True,
return_logsums: bool = False,
):
"""
Convert a table of utilities to probabilities.
Expand All @@ -155,6 +158,20 @@ def utils_to_probs(
by report_bad_choices because it can't deduce hh_id from the interaction_dataset
which is indexed on index values from alternatives df
overflow_protection : bool, default True
Always shift utility values such that the maximum utility in each row is
zero. This constant per-row shift should not fundamentally alter the
computed probabilities, but will ensure that an overflow does not occur
that will create infinite or NaN values. This will also provide effective
protection against underflow; extremely rare probabilities will round to
zero, but by definition they are extremely rare and losing them entirely
should not impact the simulation in a measureable fashion, and at least one
(and sometimes only one) alternative is guaranteed to have non-zero
probability, as long as at least one alternative has a finite utility value.
If utility values are certain to be well-behaved and non-extreme, enabling
overflow_protection will have no benefit but impose a modest computational
overhead cost.
Returns
-------
probs : pandas.DataFrame
Expand All @@ -167,9 +184,27 @@ def utils_to_probs(
# utils_arr = utils.values.astype('float')
utils_arr = utils.values

if utils_arr.dtype == np.float32 and utils_arr.max() > 85:
if allow_zero_probs:
if overflow_protection:
warnings.warn(
"cannot set overflow_protection with allow_zero_probs", stacklevel=2
)
overflow_protection = utils_arr.dtype == np.float32 and utils_arr.max() > 85
if overflow_protection:
raise ValueError(
"cannot prevent expected overflow with allow_zero_probs"
)
else:
overflow_protection = overflow_protection or (
utils_arr.dtype == np.float32 and utils_arr.max() > 85
)

if overflow_protection:
# exponentiated utils will overflow, downshift them
utils_arr -= utils_arr.max(1, keepdims=True)
shifts = utils_arr.max(1, keepdims=True)
utils_arr -= shifts
else:
shifts = None

if not exponentiated:
# TODO: reduce memory usage by exponentiating in-place.
Expand All @@ -185,6 +220,15 @@ def utils_to_probs(

arr_sum = utils_arr.sum(axis=1)

if return_logsums:
with np.errstate(divide="ignore" if allow_zero_probs else "warn"):
logsums = np.log(arr_sum)
if shifts is not None:
logsums += np.squeeze(shifts, 1)
logsums = pd.Series(logsums, index=utils.index)
else:
logsums = None

if not allow_zero_probs:
zero_probs = arr_sum == 0.0
if zero_probs.any():
Expand Down Expand Up @@ -222,6 +266,8 @@ def utils_to_probs(

probs = pd.DataFrame(utils_arr, columns=utils.columns, index=utils.index)

if return_logsums:
return probs, logsums
return probs


Expand Down
1 change: 1 addition & 0 deletions activitysim/core/pathbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,6 +994,7 @@ def build_virtual_path(
utilities_df,
allow_zero_probs=True,
trace_label=trace_label,
overflow_protection=False,
)
chunk_sizer.log_df(trace_label, "probs", probs)

Expand Down
1 change: 1 addition & 0 deletions activitysim/core/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1032,6 +1032,7 @@ def compute_nested_probabilities(
trace_label=trace_label,
exponentiated=True,
allow_zero_probs=True,
overflow_protection=False,
)

nested_probabilities = pd.concat([nested_probabilities, probs], axis=1)
Expand Down
28 changes: 26 additions & 2 deletions activitysim/core/test/test_logit.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,40 @@ def test_utils_to_probs_raises():
idx = pd.Index(name="household_id", data=[1])
with pytest.raises(RuntimeError) as excinfo:
logit.utils_to_probs(
state, pd.DataFrame([[1, 2, np.inf, 3]], index=idx), trace_label=None
state,
pd.DataFrame([[1, 2, np.inf, 3]], index=idx),
trace_label=None,
overflow_protection=False,
)
assert "infinite exponentiated utilities" in str(excinfo.value)

with pytest.raises(RuntimeError) as excinfo:
logit.utils_to_probs(
state, pd.DataFrame([[-999, -999, -999, -999]], index=idx), trace_label=None
state,
pd.DataFrame([[1, 2, 9999, 3]], index=idx),
trace_label=None,
overflow_protection=False,
)
assert "infinite exponentiated utilities" in str(excinfo.value)

with pytest.raises(RuntimeError) as excinfo:
logit.utils_to_probs(
state,
pd.DataFrame([[-999, -999, -999, -999]], index=idx),
trace_label=None,
overflow_protection=False,
)
assert "all probabilities are zero" in str(excinfo.value)

# test that overflow protection works
z = logit.utils_to_probs(
state,
pd.DataFrame([[1, 2, 9999, 3]], index=idx),
trace_label=None,
overflow_protection=True,
)
assert np.asarray(z).ravel() == pytest.approx(np.asarray([0.0, 0.0, 1.0, 0.0]))


def test_make_choices_only_one():
state = workflow.State().default_settings()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ tour_id,person_id,tour_type,tour_type_count,tour_type_num,tour_num,tour_count,to
2373898,57899,work,1,1,1,1,mandatory,1,3402.0,3746.0,20552,47.0,7.0,17.0,10.0,,,WALK,1.0388895039783694,no_subtours,,0out_0in,work
2373980,57901,work,2,1,1,2,mandatory,1,3115.0,3746.0,20552,25.0,6.0,12.0,6.0,,,SHARED3FREE,0.6022315390131013,no_subtours,,0out_0in,work
2373981,57901,work,2,2,2,2,mandatory,1,3115.0,3746.0,20552,150.0,15.0,20.0,5.0,,,SHARED2FREE,0.6232767878249469,no_subtours,,1out_0in,work
2563802,62531,school,1,1,1,1,mandatory,1,3460.0,3316.0,21869,180.0,20.0,20.0,0.0,,,SHARED3FREE,-0.7094603590463964,,,0out_0in,school
2563802,62531,school,1,1,1,1,mandatory,1,3460.0,3316.0,21869,181.0,20.0,21.0,1.0,,,SHARED3FREE,-0.7094603590463964,,,0out_0in,school
2563821,62532,escort,1,1,1,1,non_mandatory,1,3398.0,3316.0,21869,20.0,6.0,7.0,1.0,,12.499268454965652,SHARED2FREE,-1.4604154628072699,,,0out_0in,escort
2563862,62533,escort,3,1,1,4,non_mandatory,1,3402.0,3316.0,21869,1.0,5.0,6.0,1.0,,12.534424209198946,SHARED3FREE,-1.2940574569954848,,,0out_3in,escort
2563863,62533,escort,3,2,2,4,non_mandatory,1,3519.0,3316.0,21869,99.0,11.0,11.0,0.0,,12.466623656700463,SHARED2FREE,-0.9326373013150777,,,0out_0in,escort
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ trip_id,person_id,household_id,primary_purpose,trip_num,outbound,trip_count,dest
18991850,57901,20552,work,2,True,2,3115,3460,2373981,work,,16,DRIVEALONEFREE,0.10597046751418379
18991853,57901,20552,work,1,False,1,3746,3115,2373981,home,,20,SHARED2FREE,0.23660752783217825
20510417,62531,21869,school,1,True,1,3460,3316,2563802,school,,20,SHARED3FREE,-1.4448137456466916
20510421,62531,21869,school,1,False,1,3316,3460,2563802,home,,20,WALK,-1.5207459403958272
20510421,62531,21869,school,1,False,1,3316,3460,2563802,home,,21,WALK,-1.5207459403958272
20510569,62532,21869,escort,1,True,1,3398,3316,2563821,escort,,6,SHARED2FREE,0.17869598454022895
20510573,62532,21869,escort,1,False,1,3316,3398,2563821,home,,7,DRIVEALONEFREE,0.20045149458253975
20510897,62533,21869,escort,1,True,1,3402,3316,2563862,escort,,5,SHARED3FREE,0.7112775892674524
Expand Down

0 comments on commit a793094

Please sign in to comment.