-
Notifications
You must be signed in to change notification settings - Fork 6
Feat: support array type for datetime_ref
in plot_age_pyramid()
#43
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
16dd70c
605cde7
38439e1
1b024f2
322c3ff
5520249
3ae36e6
4f20bc5
7035cea
d32a72c
099b337
7810eaf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,23 @@ | ||
from copy import copy | ||
from datetime import datetime | ||
from typing import Tuple | ||
|
||
import altair as alt | ||
import numpy as np | ||
import pandas as pd | ||
from pandas.core.frame import DataFrame | ||
from pandas.core.series import Series | ||
|
||
from ..utils.checks import check_columns | ||
from ..utils.framework import bd | ||
from ..utils.typing import DataFrame | ||
|
||
|
||
def plot_age_pyramid( | ||
person: DataFrame, | ||
datetime_ref: datetime = None, | ||
filename: str = None, | ||
savefig: bool = False, | ||
return_vector: bool = False, | ||
return_array: bool = False, | ||
) -> Tuple[alt.Chart, Series]: | ||
"""Plot an age pyramid from a 'person' pandas DataFrame. | ||
|
||
|
@@ -28,8 +29,10 @@ def plot_age_pyramid( | |
- `person_id`, dtype : any | ||
- `gender_source_value`, dtype : str, {'m', 'f'} | ||
|
||
datetime_ref : datetime, | ||
datetime_ref : Union[datetime, str], | ||
The reference date to compute population age from. | ||
If a string, it searches for a column with the same name in the person table: each patient has his own datetime reference. | ||
If a datetime, the reference datetime is the same for all patients. | ||
If set to None, datetime.today() will be used instead. | ||
|
||
savefig : bool, | ||
|
@@ -55,38 +58,50 @@ def plot_age_pyramid( | |
if not isinstance(filename, str): | ||
raise ValueError(f"'filename' type must be str, got {type(filename)}") | ||
|
||
person_ = person.copy() | ||
datetime_ref_raw = copy(datetime_ref) | ||
|
||
if datetime_ref is None: | ||
today = datetime.today() | ||
datetime_ref = datetime.today() | ||
elif isinstance(datetime_ref, datetime): | ||
datetime_ref = pd.to_datetime(datetime_ref) | ||
elif isinstance(datetime_ref, str): | ||
# A string type for datetime_ref could be either | ||
# a column name or a datetime in string format. | ||
if datetime_ref in person.columns: | ||
datetime_ref = person[datetime_ref] | ||
else: | ||
datetime_ref = pd.to_datetime( | ||
datetime_ref, errors="coerce" | ||
) # In case of error, will return NaT | ||
if pd.isnull(datetime_ref): | ||
raise ValueError( | ||
f"`datetime_ref` must either be a column name or parseable date, " | ||
f"got string '{datetime_ref_raw}'" | ||
) | ||
else: | ||
today = pd.to_datetime(datetime_ref) | ||
raise TypeError( | ||
f"`datetime_ref` must be either None, a parseable string date" | ||
f", a column name or a datetime. Got type: {type(datetime_ref)}, {datetime_ref}" | ||
) | ||
|
||
cols_to_keep = ["person_id", "birth_datetime", "gender_source_value"] | ||
person_ = bd.to_pandas(person[cols_to_keep]) | ||
|
||
# TODO: replace with from ..utils.datetime_helpers.substract_datetime | ||
deltas = today - person_["birth_datetime"] | ||
if bd.is_pandas(person_): | ||
deltas = deltas.dt.total_seconds() | ||
person_["age"] = (datetime_ref - person_["birth_datetime"]).dt.total_seconds() | ||
person_["age"] /= 365 * 24 * 3600 | ||
|
||
person_["age"] = deltas / (365 * 24 * 3600) | ||
person_ = person_.query("age > 0.0") | ||
mask = person_["age"] > 90 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why this mask ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was mainly to gather rare categories (90, 100, 110) and avoid making too many bins with very few individuals. Does it seem relevant to you? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is relevant for me as I understand that this mask basically merge the bins: (90, 100], (100, 110], and above, which is the classic intention of a user. |
||
person_.loc[mask, "age"] = 99 | ||
|
||
bins = np.arange(0, 100, 10) | ||
labels = [f"{left}-{right}" for left, right in zip(bins[:-1], bins[1:])] | ||
person_["age_bins"] = bd.cut(person_["age"], bins=bins, labels=labels) | ||
|
||
person_["age_bins"] = ( | ||
person_["age_bins"].astype(str).str.lower().str.replace("nan", "90+") | ||
) | ||
person_["age_bins"] = pd.cut(person_["age"], bins=bins, labels=labels) | ||
|
||
person_ = person_.loc[person_["gender_source_value"].isin(["m", "f"])] | ||
group_gender_age = person_.groupby(["gender_source_value", "age_bins"])[ | ||
"person_id" | ||
].count() | ||
|
||
# Convert to pandas to ease plotting. | ||
# Since we have aggregated the data, this operation won't crash. | ||
group_gender_age = bd.to_pandas(group_gender_age) | ||
|
||
male = group_gender_age["m"].reset_index() | ||
female = group_gender_age["f"].reset_index() | ||
|
||
|
@@ -123,10 +138,10 @@ def plot_age_pyramid( | |
|
||
if savefig: | ||
chart.save(filename) | ||
if return_vector: | ||
if return_array: | ||
return group_gender_age | ||
|
||
if return_vector: | ||
if return_array: | ||
return chart, group_gender_age | ||
|
||
return chart |
Uh oh!
There was an error while loading. Please reload this page.