-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
286 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,250 @@ | ||
""" | ||
Read in the output files generated by analysis_scenarios | ||
generate life tables to estimate life expectancy for each run/draw | ||
produce summary statistics | ||
""" | ||
|
||
import datetime | ||
from pathlib import Path | ||
from typing import Dict, Tuple | ||
|
||
import pandas as pd | ||
|
||
from tlo.analysis.utils import ( | ||
extract_results, | ||
get_scenario_info, | ||
load_pickled_dataframes, | ||
summarize, | ||
) | ||
|
||
|
||
def _map_age_to_age_group(age: pd.Series) -> pd.Series:
    """
    Returns age-groups used in the calculation of life-expectancy.

    Ages are binned into left-closed intervals: '0' covers [0, 1), '1-4' covers [1, 5), followed by the
    5-year bands '5-9', ..., '85-89' and a final open-ended group '90' covering [90, inf).

    Args:
    - age (pd.Series): The pd.Series containing ages. (Other array-likes accepted by `pd.cut`, such as a
      pd.Index, also work; `pd.cut` then returns a Categorical instead of a Series.)

    Returns:
    - pd.Series: Series of the 'age-group' (categorical), corresponding to the `age` argument. Ages that fall
      outside every bin (e.g. negative or missing values) are returned as NaN.
    """
    # Labels: infants, ages 1-4, 5-year bands up to 85-89, and the open-ended 90+ group (20 labels in total)
    age_groups = ['0', '1-4'] + [f'{start}-{start + 4}' for start in range(5, 90, 5)] + ['90']

    # Bin edges: 0, 1, 5, 10, ..., 90, inf (21 edges delimiting the 20 labels above)
    bin_edges = [0, 1] + list(range(5, 95, 5)) + [float('inf')]

    # right=False: each bin includes its lower edge and excludes its upper edge
    return pd.cut(age, bins=bin_edges, labels=age_groups, right=False)
|
||
|
||
def _extract_person_years(results_folder, _draw, _run) -> pd.DataFrame:
    """Returns the person-years that are logged (the 'person_years' entry of the `tlo.methods.demography`
    log) for the given draw and run.

    NOTE(review): the caller in this file reads a `date` column and one column per sex ("M", "F") from the
    returned object, so it is annotated as a DataFrame - confirm against the logger.
    """
    demography_log = load_pickled_dataframes(
        results_folder, _draw, _run, 'tlo.methods.demography'
    )['tlo.methods.demography']
    return demography_log['person_years']
|
||
|
||
def _num_deaths_by_age_group(results_folder, target_period) -> pd.DataFrame:
    """Returns dataframe with number of deaths by age-group/sex within the target period for each draw/run
    (dataframe returned: index=age-group/sex, columns=draw/run)

    Args:
    - results_folder: The results folder, passed straight through to `extract_results`.
    - target_period: Pair of (start, end) dates; deaths dated on either end-point are included.
    """

    def extract_deaths_by_age_group(df: pd.DataFrame) -> pd.Series:
        """Counts the rows of the death log `df` falling in the target period, by age-group and sex."""
        age_group = _map_age_to_age_group(df['age'])
        in_target_period = pd.to_datetime(df.date).dt.date.between(*target_period, inclusive='both')
        # observed=False is stated explicitly so that age-groups with no deaths are kept (with a count of 0)
        # rather than depending on the pandas default for categorical groupers.
        return df.loc[in_target_period].groupby([age_group, df["sex"]], observed=False).size()

    return extract_results(
        results_folder,
        module="tlo.methods.demography",
        key="death",
        custom_generate_series=extract_deaths_by_age_group,
        do_scaling=False
    )
|
||
|
||
def _aggregate_person_years_by_age(results_folder, target_period) -> pd.DataFrame:
    """ Returns person-years in each sex/age-group for each draw/run (as pd.DataFrame with index=sex/age-groups and
    columns=draw/run)

    Args:
    - results_folder: The results folder, passed to `get_scenario_info` and `_extract_person_years`.
    - target_period: Pair of (start, end) dates; log entries dated on either end-point are included.
    """

    def _sum_within_age_group(py_by_age: pd.Series) -> pd.Series:
        """Sums person-years (indexed by single year of age) within the life-table age-groups."""
        age_group = _map_age_to_age_group(py_by_age.index.astype(float))
        # observed=False is stated explicitly so that every age-group is kept, even if it has no entries,
        # rather than depending on the pandas default for categorical groupers.
        return py_by_age.groupby(age_group, observed=False).sum()

    info = get_scenario_info(results_folder)
    py_by_sex_and_agegroup = {}
    for draw in range(info["number_of_draws"]):
        for run in range(info["runs_per_draw"]):
            _df = _extract_person_years(results_folder, _draw=draw, _run=run)

            # mask for entries with dates within the target period
            mask = _df.date.dt.date.between(*target_period, inclusive="both")

            # Compute PY within time-period and summing within age-group, for each sex.
            # `.apply(pd.Series)` expands each logged entry into one column per key
            # (presumably a mapping of age -> person-years; verify against the logger).
            py_by_sex_and_agegroup[(draw, run)] = pd.concat({
                sex: _df.loc[mask, sex]
                .apply(pd.Series)
                .sum(axis=0)
                .pipe(_sum_within_age_group)
                for sex in ["M", "F"]
            })

    # Format as pd.DataFrame with multiindex in index (sex/age-group) and columns (draw/run)
    py_by_sex_and_agegroup = pd.DataFrame.from_dict(py_by_sex_and_agegroup)
    py_by_sex_and_agegroup.index = py_by_sex_and_agegroup.index.set_names(
        level=[0, 1], names=["sex", "age_group"]
    )
    py_by_sex_and_agegroup.columns = py_by_sex_and_agegroup.columns.set_names(
        level=[0, 1], names=["draw", "run"]
    )

    return py_by_sex_and_agegroup
|
||
|
||
|
||
def _age_group_interval_width(age_group_label: str) -> int:
    """Returns the width in years of an age-group label: '0' -> 1, '1-4' -> 4, '5-9' -> 5, '90' -> 5."""
    if '90' in age_group_label:
        # open-ended final age-group: a nominal width of 5 years is used
        return 5
    if '-' in age_group_label:
        bounds = age_group_label.split('-')
        return int(bounds[1]) - int(bounds[0]) + 1
    return 1


def _estimate_life_expectancy(
    _person_years_at_risk: pd.Series,
    _number_of_deaths_in_interval: pd.Series
) -> Dict[str, float]:
    """
    For a single run, estimate life expectancy for males and females

    Args:
    - _person_years_at_risk (pd.Series): person-years, with index levels named 'sex' and 'age_group'; the
      'age_group' level must be categorical and contain every age-group, in age order.
    - _number_of_deaths_in_interval (pd.Series): number of deaths, with an index level named 'sex' and, once
      that level is selected, the same age-group labels in the same order as the person-years.

    returns: Dict (keys by "M" and "F" for the sex, values the estimated life-expectancy at birth).
    """

    estimated_life_expectancy_at_birth = dict()

    # first age-group is 0, then 1-4, 5-9, 10-14 etc. up to 85-89 and 90+: 20 categories in total
    age_group_labels = _person_years_at_risk.index.get_level_values('age_group').unique()

    # Extract interval width
    interval_width = [_age_group_interval_width(label) for label in age_group_labels.categories]
    number_age_groups = len(interval_width)
    last = number_age_groups - 1  # position of the final (open-ended) age-group
    fraction_of_last_age_survived = pd.Series([0.5] * number_age_groups, index=age_group_labels)

    # separate male and female data
    for sex in ['M', 'F']:
        person_years_by_sex = _person_years_at_risk.xs(key=sex, level='sex')
        number_of_deaths_by_sex = _number_of_deaths_in_interval.xs(key=sex, level='sex')

        # if no deaths or person-years, the division produces nan, which is replaced by 0
        death_rate_in_interval = (number_of_deaths_by_sex / person_years_by_sex).fillna(0)
        # if no deaths in age 90+, set death rate equal to value in age 85-89
        if death_rate_in_interval.loc['90'] == 0:
            death_rate_in_interval.loc['90'] = death_rate_in_interval.loc['85-89']

        # Calculate the probability of dying in the interval
        # condition checks whether the observed number deaths is significantly higher than would be expected
        # based on population years at risk and survival fraction
        # if true, suggests very high mortality rates and returns value 1
        condition = number_of_deaths_by_sex > (
            person_years_by_sex / interval_width / fraction_of_last_age_survived)
        probability_of_dying_in_interval = pd.Series(index=number_of_deaths_by_sex.index, dtype=float)
        probability_of_dying_in_interval[condition] = 1
        probability_of_dying_in_interval[~condition] = interval_width * death_rate_in_interval / (
            1 + interval_width * (1 - fraction_of_last_age_survived) * death_rate_in_interval)
        # all those surviving to final interval die during this interval
        probability_of_dying_in_interval.at['90'] = 1

        # The life-table recursion below works by POSITION (0 = youngest age-group). The label-indexed Series
        # are converted to arrays first: integer keys passed to `[]` on a label-indexed Series are only
        # treated as positions through a deprecated pandas fallback.
        prob_dying = probability_of_dying_in_interval.to_numpy()
        frac_survived = fraction_of_last_age_survived.to_numpy()
        death_rate = death_rate_in_interval.to_numpy()

        # number alive at start of each interval, for a hypothetical cohort of 100,000 births
        # (kept as float in case using aggregated outputs)
        number_alive_at_start_of_interval = [100_000.0]
        for i in range(1, number_age_groups):
            number_alive_at_start_of_interval.append(
                (1 - prob_dying[i - 1]) * number_alive_at_start_of_interval[i - 1]
            )

        # number dying in each interval; everyone reaching the final interval dies in it
        number_dying_in_interval = [
            number_alive_at_start_of_interval[i] - number_alive_at_start_of_interval[i + 1]
            for i in range(last)
        ]
        number_dying_in_interval.append(number_alive_at_start_of_interval[last])

        # person-years lived in each interval
        py_lived_in_interval = [
            interval_width[i] * (
                number_alive_at_start_of_interval[i + 1] + frac_survived[i] * number_dying_in_interval[i]
            )
            for i in range(last)
        ]
        # final open-ended interval: survivors divided by the death rate in that interval
        # NOTE: if the death rate is zero in both the 90+ and 85-89 groups this division gives inf (or nan
        # when nobody survives to the final interval), as in the original implementation.
        py_lived_in_interval.append(number_alive_at_start_of_interval[last] / death_rate[last])

        # person-years lived beyond birth: accumulate from the oldest age-group backwards
        py_lived_beyond_birth = py_lived_in_interval[last]
        for i in range(last - 1, -1, -1):
            py_lived_beyond_birth = py_lived_beyond_birth + py_lived_in_interval[i]

        # estimated life expectancy from birth (the cohort size at birth is 100,000, so never zero)
        estimated_life_expectancy_at_birth[sex] = float(
            py_lived_beyond_birth / number_alive_at_start_of_interval[0]
        )

    return estimated_life_expectancy_at_birth
|
||
|
||
def get_life_expectancy_estimates(
    results_folder: Path,
    target_period: Tuple[datetime.date, datetime.date],
    summary: bool = True
) -> pd.DataFrame:
    """
    produces sets of life expectancy estimates for each draw/run

    calls:
    *1 _num_deaths_by_age_group
    *2 _aggregate_person_years_by_age
    *3 _estimate_life_expectancy (once per draw/run)

    Args:
    - results_folder (PosixPath): The path to the results folder containing log, `tlo.methods.demography`
    - target period (tuple of dates): declare the date range (inclusively) in which life expectancy is to be estimated
    - summary (bool): declare whether to return a summarized value (mean with 95% uncertainty intervals)
        or return the estimate for each draw/run

    Returns:
    - pd.DataFrame: The DataFrame with the life expectancy estimates (in years), indexed by sex,
        for every draw/run in the results folder; or, with option `summary=True` summarized (central, lower,
        upper estimates) for each draw.

    example use:
    test = get_life_expectancy_estimates(results_folder, summary=True,
                                         target_period=(datetime.date(2019, 1, 1), datetime.date(2020, 1, 1)))
    """

    # get number of draws and numbers of runs
    info = get_scenario_info(results_folder)

    # extract numbers of deaths (by age-group, within the target_period)
    deaths = _num_deaths_by_age_group(results_folder, target_period)

    # extract person-years (by age-group, within the target_period)
    person_years = _aggregate_person_years_by_age(results_folder, target_period)

    # Collect the life expectancies in a dict keyed by (draw, run)
    le_for_each_draw_and_run = dict()

    for draw in range(info['number_of_draws']):
        for run in range(info['runs_per_draw']):
            le_for_each_draw_and_run[(draw, run)] = _estimate_life_expectancy(
                _number_of_deaths_in_interval=deaths[(draw, run)],
                _person_years_at_risk=person_years[(draw, run)]
            )

    # rows: sex ("M"/"F"); columns: (draw, run)
    output = pd.DataFrame.from_dict(le_for_each_draw_and_run)
    output.index.name = "sex"
    output.columns = output.columns.set_names(level=[0, 1], names=['draw', 'run'])

    if not summary:
        return output

    return summarize(results=output, only_mean=False, collapse_columns=False)
Binary file added
BIN
+33.2 KB
tests/resources/dummy_simulation_run/0/0/tlo.methods.demography.pickle
Binary file not shown.
Binary file added
BIN
+34.4 KB
tests/resources/dummy_simulation_run/0/1/tlo.methods.demography.pickle
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import datetime | ||
import os | ||
from pathlib import Path | ||
|
||
import pandas as pd | ||
|
||
from tlo.analysis.life_expectancy import get_life_expectancy_estimates | ||
|
||
|
||
def test_get_life_expectancy():
    """Use `get_life_expectancy_estimates` to generate estimate of life-expectancy from the dummy simulation data."""

    # Arguments shared by both calls: the dummy results shipped with the tests, and calendar year 2010
    common_kwargs = dict(
        results_folder=Path(os.path.dirname(__file__)) / 'resources' / 'dummy_simulation_run',
        target_period=(datetime.date(2010, 1, 1), datetime.date(2010, 12, 31)),
    )

    # Summary measure: Should have row ('M', 'F') and columns ('mean', 'lower', 'upper')
    summarised = get_life_expectancy_estimates(summary=True, **common_kwargs)
    assert isinstance(summarised, pd.DataFrame)
    assert sorted(summarised.index.to_list()) == ["F", "M"]
    assert list(summarised.columns.names) == ['draw', 'stat']
    assert summarised.columns.levels[1].to_list() == ["lower", "mean", "upper"]

    # Non-summary measure: Estimate should be for each run/draw
    per_run = get_life_expectancy_estimates(summary=False, **common_kwargs)
    assert isinstance(per_run, pd.DataFrame)
    assert sorted(per_run.index.to_list()) == ["F", "M"]
    assert list(per_run.columns.names) == ['draw', 'run']
    assert per_run.columns.levels[1].to_list() == [0, 1]