Skip to content

Commit

Permalink
life expectancy estimates (#1219)
Browse files Browse the repository at this point in the history
  • Loading branch information
tdm32 authored Feb 8, 2024
1 parent 7c20a2e commit 0109011
Show file tree
Hide file tree
Showing 5 changed files with 286 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@
*.shp filter=lfs diff=lfs merge=lfs -text
*.shx filter=lfs diff=lfs merge=lfs -text
*.xlsx filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
resources/** filter=lfs diff=lfs merge=lfs -text
250 changes: 250 additions & 0 deletions src/tlo/analysis/life_expectancy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
"""
Estimate life expectancy from the output files generated by analysis_scenarios.

Reads in the logged outputs, generates life tables to estimate life expectancy for each draw/run,
and produces summary statistics.
"""

import datetime
from pathlib import Path
from typing import Dict, Tuple

import pandas as pd

from tlo.analysis.utils import (
extract_results,
get_scenario_info,
load_pickled_dataframes,
summarize,
)


def _map_age_to_age_group(age: pd.Series) -> pd.Series:
    """
    Assign each age to the age-group used in the calculation of life-expectancy.

    The groups are '0', '1-4', then five-year bands '5-9', ..., '85-89', and a final open-ended group
    labelled '90'. Bins are closed on the left and open on the right, so an age of exactly 5 falls in '5-9'.

    Args:
    - age (pd.Series): The pd.Series containing ages
    Returns:
    - pd.Series: Series of the 'age-group', corresponding the `age` argument.
    """
    # Lower bounds of the five-year bands: 5, 10, ..., 85
    band_starts = list(range(5, 90, 5))

    labels = ['0', '1-4'] + [f'{lower}-{lower + 4}' for lower in band_starts] + ['90']

    # One more edge than there are labels: 0, 1, 5, 10, ..., 85, 90, +inf
    edges = [0, 1] + band_starts + [90, float('inf')]

    return pd.cut(age, bins=edges, labels=labels, right=False)


def _extract_person_years(results_folder, _draw, _run) -> pd.Series:
    """Returns the person-years that are logged.

    Loads the pickled log of the `tlo.methods.demography` module for the given draw/run and returns its
    `person_years` entry.
    """
    module_name = 'tlo.methods.demography'
    logged_outputs = load_pickled_dataframes(results_folder, _draw, _run, module_name)
    return logged_outputs[module_name]['person_years']


def _num_deaths_by_age_group(results_folder, target_period) -> pd.DataFrame:
    """Returns dataframe with number of deaths by age-group/sex within the target period for each draw/run
    (dataframe returned: index=age-grp/sex, columns=draw/run)

    Args:
    - results_folder: the results folder, passed through to `extract_results`
    - target_period: pair of (first date, last date); both ends are included
    """

    def extract_deaths_by_age_group(df: pd.DataFrame) -> pd.Series:
        """Count the deaths logged in `df` that fall within the target period, by age-group and sex."""
        age_group = _map_age_to_age_group(df['age'])
        in_target_period = pd.to_datetime(df.date).dt.date.between(*target_period, inclusive='both')

        # `observed=False` is stated explicitly (rather than relying on the deprecated implicit default for
        # categorical groupers) so that age-groups with no deaths are still reported, with a count of zero.
        return df.loc[in_target_period].groupby([age_group, df["sex"]], observed=False).size()

    return extract_results(
        results_folder,
        module="tlo.methods.demography",
        key="death",
        custom_generate_series=extract_deaths_by_age_group,
        do_scaling=False
    )


def _aggregate_person_years_by_age(results_folder, target_period) -> pd.DataFrame:
    """ Returns person-years in each sex/age-group for each draw/run (as pd.DataFrame with index=sex/age-groups and
    columns=draw/run)

    Args:
    - results_folder: the results folder, passed through to `get_scenario_info` and `_extract_person_years`
    - target_period: pair of (first date, last date); both ends are included
    """
    info = get_scenario_info(results_folder)

    def _sum_within_age_groups(py_by_age: pd.Series) -> pd.Series:
        """Sum person-years (indexed by single year of age) within the life-table age-groups."""
        age_group = _map_age_to_age_group(py_by_age.index.astype(float))
        # `observed=False` is stated explicitly (rather than relying on the deprecated implicit default for
        # categorical groupers) so that every age-group is present in the result, even if it is empty.
        return py_by_age.groupby(age_group, observed=False).sum()

    py_by_sex_and_agegroup = dict()
    for draw in range(info["number_of_draws"]):
        for run in range(info["runs_per_draw"]):
            _df = _extract_person_years(results_folder, _draw=draw, _run=run)

            # mask for entries with dates within the target period
            mask = _df.date.dt.date.between(*target_period, inclusive="both")

            # Compute PY within time-period and summing within age-group, for each sex.
            # Each logged cell is expanded into columns with `pd.Series` (presumably one column per year of
            # age -- confirm against the demography logger) before summing over the logged dates.
            py_by_sex_and_agegroup[(draw, run)] = pd.concat({
                sex: _df.loc[mask, sex]
                .apply(pd.Series)
                .sum(axis=0)
                .pipe(_sum_within_age_groups)
                for sex in ["M", "F"]
            })

    # Format as pd.DataFrame with multiindex in index (sex/age-group) and columns (draw/run)
    py_by_sex_and_agegroup = pd.DataFrame.from_dict(py_by_sex_and_agegroup)
    py_by_sex_and_agegroup.index = py_by_sex_and_agegroup.index.set_names(
        level=[0, 1], names=["sex", "age_group"]
    )
    py_by_sex_and_agegroup.columns = py_by_sex_and_agegroup.columns.set_names(
        level=[0, 1], names=["draw", "run"]
    )

    return py_by_sex_and_agegroup



def _estimate_life_expectancy(
    _person_years_at_risk: pd.Series,
    _number_of_deaths_in_interval: pd.Series
) -> Dict[str, float]:
    """
    For a single run, estimate life expectancy for males and females
    by building a life table for a hypothetical cohort of 100,000 births.

    Args:
    - _person_years_at_risk (pd.Series): person-years, indexed by a multi-index with the levels `sex` and
      `age_group` (the latter categorical, including the labels '85-89' and '90')
    - _number_of_deaths_in_interval (pd.Series): number of deaths, indexed by a multi-index with a level `sex`;
      once one sex is selected it is expected to carry the same age-group labels as the person-years

    returns: Dict (keys by "M" and "F" for the sex, values the estimated life-expectancy at birth).
    """

    def _width_of_age_group(label: str) -> int:
        """Width in years of an age-group label; the open-ended '90' group is given a nominal width of 5."""
        if '90' in label:
            return 5
        if '-' in label:
            bounds = label.split('-')
            return int(bounds[1]) - int(bounds[0]) + 1
        return 1

    estimated_life_expectancy_at_birth = dict()

    # first age-group is 0, then 1-4, 5-9, 10-14 etc., ending with the open-ended group 90
    age_group_labels = _person_years_at_risk.index.get_level_values('age_group').unique()

    # Extract interval width: one entry per category of the `age_group` level
    interval_width = [_width_of_age_group(interval) for interval in age_group_labels.categories]
    number_age_groups = len(interval_width)
    last = number_age_groups - 1
    fraction_of_last_age_survived = pd.Series([0.5] * number_age_groups, index=age_group_labels)

    # separate male and female data
    for sex in ['M', 'F']:
        person_years_by_sex = _person_years_at_risk.xs(key=sex, level='sex')
        number_of_deaths_by_sex = _number_of_deaths_in_interval.xs(key=sex, level='sex')

        # if no deaths or person-years, the ratio is nan: treat it as a death rate of zero
        death_rate_in_interval = (number_of_deaths_by_sex / person_years_by_sex).fillna(0)
        # if no deaths in age 90+, set death rate equal to value in age 85-89
        if death_rate_in_interval.loc['90'] == 0:
            death_rate_in_interval.loc['90'] = death_rate_in_interval.loc['85-89']

        # Calculate the probability of dying in the interval
        # condition checks whether the observed number deaths is significantly higher than would be expected
        # based on population years at risk and survival fraction
        # if true, suggests very high mortality rates and returns value 1
        condition = number_of_deaths_by_sex > (
            person_years_by_sex / interval_width / fraction_of_last_age_survived)
        probability_of_dying_in_interval = pd.Series(index=number_of_deaths_by_sex.index, dtype=float)
        probability_of_dying_in_interval[condition] = 1
        probability_of_dying_in_interval[~condition] = interval_width * death_rate_in_interval / (
            1 + interval_width * (1 - fraction_of_last_age_survived) * death_rate_in_interval)
        # all those surviving to final interval die during this interval
        probability_of_dying_in_interval.at['90'] = 1

        # NOTE: the series above are indexed by age-group *label*, whereas the life table below is built by
        # *position*. Positional access therefore goes through `.iloc`: using an integer key with `[]` on a
        # label-indexed series is deprecated in pandas and will be interpreted as a label in future.

        # number_alive_at_start_of_interval
        # keep dtype as float in case using aggregated outputs
        number_alive_at_start_of_interval = pd.Series(index=range(number_age_groups), dtype=float)
        number_alive_at_start_of_interval.iloc[0] = 100_000  # hypothetical cohort
        for i in range(1, number_age_groups):
            number_alive_at_start_of_interval.iloc[i] = (
                (1 - probability_of_dying_in_interval.iloc[i - 1]) * number_alive_at_start_of_interval.iloc[i - 1]
            )

        # number_dying_in_interval (everyone alive at the start of the final interval dies within it)
        number_dying_in_interval = pd.Series(index=range(number_age_groups), dtype=float)
        for i in range(last):
            number_dying_in_interval.iloc[i] = (
                number_alive_at_start_of_interval.iloc[i] - number_alive_at_start_of_interval.iloc[i + 1]
            )
        number_dying_in_interval.iloc[last] = number_alive_at_start_of_interval.iloc[last]

        # person-years lived in interval: survivors live the whole interval, those dying live a fraction of it
        py_lived_in_interval = pd.Series(index=range(number_age_groups), dtype=float)
        for i in range(last):
            py_lived_in_interval.iloc[i] = interval_width[i] * (
                number_alive_at_start_of_interval.iloc[i + 1]
                + fraction_of_last_age_survived.iloc[i] * number_dying_in_interval.iloc[i]
            )
        # final (open-ended) interval: number alive divided by the death rate
        py_lived_in_interval.iloc[last] = (
            number_alive_at_start_of_interval.iloc[last] / death_rate_in_interval.iloc[last]
        )

        # person-years lived beyond start of interval
        # have to iterate backwards for this
        py_lived_beyond_start_of_interval = pd.Series(index=range(number_age_groups), dtype=float)
        py_lived_beyond_start_of_interval.iloc[last] = py_lived_in_interval.iloc[last]
        for i in range(last - 1, -1, -1):
            py_lived_beyond_start_of_interval.iloc[i] = (
                py_lived_beyond_start_of_interval.iloc[i + 1] + py_lived_in_interval.iloc[i]
            )

        # calculate observed life expectancy at start of interval
        # if number of people alive at start of interval=0, condition returns true and observed life expectancy=0
        condition = number_alive_at_start_of_interval == 0
        observed_life_expectancy = pd.Series(index=range(number_age_groups), dtype=float)
        observed_life_expectancy[condition] = 0
        observed_life_expectancy[~condition] = py_lived_beyond_start_of_interval / number_alive_at_start_of_interval

        # estimated life expectancy from birth
        estimated_life_expectancy_at_birth[sex] = observed_life_expectancy.iloc[0]

    return estimated_life_expectancy_at_birth


def get_life_expectancy_estimates(
    results_folder: Path,
    target_period: Tuple[datetime.date, datetime.date],
    summary: bool = True
) -> pd.DataFrame:
    """
    Produces a set of life expectancy estimates (at birth, by sex) for each draw/run.

    calls:
    *1 _num_deaths_by_age_group
    *2 _aggregate_person_years_by_age
    *3 _estimate_life_expectancy (once per draw/run)

    Args:
    - results_folder (PosixPath): The path to the results folder containing log, `tlo.methods.demography`
    - target_period (tuple of dates): declare the date range (inclusively) in which life expectancy is to be
      estimated
    - summary (bool): declare whether to return the estimates summarized for each draw (via `summarize`)
      or return the estimate for each draw/run

    Returns:
    - pd.DataFrame: The DataFrame with the life expectancy estimates (in years), with one row per sex and
      columns for every draw/run in the results folder; or, with option `summary=True`, summarized (central,
      lower, upper estimates) for each draw.

    example use:
        le = get_life_expectancy_estimates(
            results_folder,
            target_period=(datetime.date(2019, 1, 1), datetime.date(2020, 1, 1)),
            summary=True,
        )
    """

    # number of draws and number of runs per draw
    scenario_info = get_scenario_info(results_folder)

    # numbers of deaths and person-years (by age-group, within the target_period), for every draw/run
    deaths_by_age_group = _num_deaths_by_age_group(results_folder, target_period)
    person_years_by_age_group = _aggregate_person_years_by_age(results_folder, target_period)

    draw_run_pairs = [
        (draw, run)
        for draw in range(scenario_info['number_of_draws'])
        for run in range(scenario_info['runs_per_draw'])
    ]

    # life expectancy at birth (dict keyed by sex) for each (draw, run)
    life_expectancy_by_draw_and_run = {
        pair: _estimate_life_expectancy(
            _number_of_deaths_in_interval=deaths_by_age_group[pair],
            _person_years_at_risk=person_years_by_age_group[pair]
        )
        for pair in draw_run_pairs
    }

    output = pd.DataFrame.from_dict(life_expectancy_by_draw_and_run)
    output.index.name = "sex"
    output.columns = output.columns.set_names(level=[0, 1], names=['draw', 'run'])

    if summary:
        return summarize(results=output, only_mean=False, collapse_columns=False)

    return output
Binary file not shown.
Binary file not shown.
35 changes: 35 additions & 0 deletions tests/test_life_expectancy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import datetime
import os
from pathlib import Path

import pandas as pd

from tlo.analysis.life_expectancy import get_life_expectancy_estimates


def test_get_life_expectancy():
    """Use `get_life_expectancy_estimates` to generate estimate of life-expectancy from the dummy simulation data."""

    results_folder_dummy_results = Path(os.path.dirname(__file__)) / 'resources' / 'dummy_simulation_run'
    calendar_year_2010 = (datetime.date(2010, 1, 1), datetime.date(2010, 12, 31))

    def estimates(summarized: bool) -> pd.DataFrame:
        """Run the estimation on the dummy results for 2010, with or without summarizing."""
        return get_life_expectancy_estimates(
            results_folder=results_folder_dummy_results,
            target_period=calendar_year_2010,
            summary=summarized,
        )

    # Summary measure: Should have rows ('M', 'F') and, for each draw, the stats ('lower', 'mean', 'upper')
    rtn_summary = estimates(True)
    assert isinstance(rtn_summary, pd.DataFrame)
    assert sorted(rtn_summary.index.to_list()) == ["F", "M"]
    assert list(rtn_summary.columns.names) == ['draw', 'stat']
    assert rtn_summary.columns.levels[1].to_list() == ["lower", "mean", "upper"]

    # Non-summary measure: Estimate should be for each run/draw
    rtn_full = estimates(False)
    assert isinstance(rtn_full, pd.DataFrame)
    assert sorted(rtn_full.index.to_list()) == ["F", "M"]
    assert list(rtn_full.columns.names) == ['draw', 'run']
    assert rtn_full.columns.levels[1].to_list() == [0, 1]

0 comments on commit 0109011

Please sign in to comment.