Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: support array type for datetime_ref in plot_age_pyramid() #43

1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- Allow saving DB locally in client or cluster mode
- Add data cleaning function to handle incorrect datetime in spark
- Filter biology config on care site
- Add person-dependent `datetime_ref` to `plot_age_pyramid`

### Fixed

Expand Down
83 changes: 57 additions & 26 deletions eds_scikit/plot/data_quality.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
from copy import copy
from datetime import datetime
from typing import Tuple

import altair as alt
import numpy as np
import pandas as pd
from pandas.core.frame import DataFrame
from loguru import logger
from pandas.core.series import Series

from ..utils.checks import check_columns
from ..utils.framework import bd
from ..utils.typing import DataFrame


def plot_age_pyramid(
person: DataFrame,
datetime_ref: datetime = None,
filename: str = None,
savefig: bool = False,
return_vector: bool = False,
return_array: bool = False,
) -> Tuple[alt.Chart, Series]:
"""Plot an age pyramid from a 'person' pandas DataFrame.

Expand All @@ -28,16 +30,21 @@ def plot_age_pyramid(
- `person_id`, dtype : any
- `gender_source_value`, dtype : str, {'m', 'f'}

datetime_ref : datetime,
datetime_ref : Union[datetime, str], default None
The reference date to compute population age from.
If a string, it searches for a column with the same name in the person table: each patient has his own datetime reference.
If a datetime, the reference datetime is the same for all patients.
If set to None, datetime.today() will be used instead.

savefig : bool,
filename : str, default None
The path to save figure at.

savefig : bool, default False
If set to True, filename must be set.
The plot will be saved at the specified filename.

filename : Optional[str],
The path to save figure at.
return_array : bool, default False
If set to True, return chart and its pd.Dataframe representation.

Returns
-------
Expand All @@ -55,38 +62,62 @@ def plot_age_pyramid(
if not isinstance(filename, str):
raise ValueError(f"'filename' type must be str, got {type(filename)}")

person_ = person.copy()
datetime_ref_raw = copy(datetime_ref)

if datetime_ref is None:
today = datetime.today()
datetime_ref = datetime.today()
elif isinstance(datetime_ref, datetime):
datetime_ref = pd.to_datetime(datetime_ref)
elif isinstance(datetime_ref, str):
# A string type for datetime_ref could be either
# a column name or a datetime in string format.
if datetime_ref in person.columns:
datetime_ref = person[datetime_ref]
else:
datetime_ref = pd.to_datetime(
datetime_ref, errors="coerce"
) # In case of error, will return NaT
if pd.isnull(datetime_ref):
raise ValueError(
f"`datetime_ref` must either be a column name or parseable date, "
f"got string '{datetime_ref_raw}'"
)
else:
today = pd.to_datetime(datetime_ref)
raise TypeError(
f"`datetime_ref` must be either None, a parseable string date"
f", a column name or a datetime. Got type: {type(datetime_ref)}, {datetime_ref}"
)

cols_to_keep = ["person_id", "birth_datetime", "gender_source_value"]
person_ = bd.to_pandas(person[cols_to_keep])

# TODO: replace with from ..utils.datetime_helpers.substract_datetime
deltas = today - person_["birth_datetime"]
if bd.is_pandas(person_):
deltas = deltas.dt.total_seconds()
person_["age"] = (datetime_ref - person_["birth_datetime"]).dt.total_seconds()
person_["age"] /= 365 * 24 * 3600

person_["age"] = deltas / (365 * 24 * 3600)
person_ = person_.query("age > 0.0")
Vincent-Maladiere marked this conversation as resolved.
Show resolved Hide resolved
# Remove outliers
mask_age_inliners = (person_["age"] > 0) & (person_["age"] < 125)
n_outliers = (~mask_age_inliners).sum()
if n_outliers > 0:
perc_outliers = 100 * n_outliers / person_.shape[0]
logger.warning(
f"{n_outliers} ({perc_outliers:.4f}%) individuals' "
"age is out of the (0, 125) interval, we skip them."
)
person_ = person_.loc[mask_age_inliners]

# Aggregate rare age categories
mask_rare_age_agg = person_["age"] > 90
person_.loc[mask_rare_age_agg, "age"] = 99

bins = np.arange(0, 100, 10)
labels = [f"{left}-{right}" for left, right in zip(bins[:-1], bins[1:])]
person_["age_bins"] = bd.cut(person_["age"], bins=bins, labels=labels)

person_["age_bins"] = (
person_["age_bins"].astype(str).str.lower().str.replace("nan", "90+")
)
person_["age_bins"] = pd.cut(person_["age"], bins=bins, labels=labels)

person_ = person_.loc[person_["gender_source_value"].isin(["m", "f"])]
group_gender_age = person_.groupby(["gender_source_value", "age_bins"])[
"person_id"
].count()

# Convert to pandas to ease plotting.
# Since we have aggregated the data, this operation won't crash.
group_gender_age = bd.to_pandas(group_gender_age)

male = group_gender_age["m"].reset_index()
female = group_gender_age["f"].reset_index()

Expand Down Expand Up @@ -123,10 +154,10 @@ def plot_age_pyramid(

if savefig:
chart.save(filename)
if return_vector:
if return_array:
return group_gender_age

if return_vector:
if return_array:
return chart, group_gender_age

return chart
54 changes: 36 additions & 18 deletions tests/test_age_pyramid.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import altair as alt
import numpy as np
import pandas as pd
import pytest
from pandas.core.series import Series
from pandas.testing import assert_frame_equal
Expand All @@ -12,25 +13,25 @@

data = load_person()

person_with_inclusion_date = data.person.copy()
N = len(person_with_inclusion_date)
delta_days = pd.to_timedelta(np.random.randint(0, 1000, N), unit="d")

@pytest.mark.parametrize(
"datetime_ref",
[
None,
datetime(2020, 1, 1),
np.full(data.person.shape[0], datetime(2020, 1, 1)),
],
person_with_inclusion_date["inclusion_datetime"] = (
person_with_inclusion_date["birth_datetime"] + delta_days
)
def test_age_pyramid_datetime_ref_format(datetime_ref):
original_person = data.person.copy()

chart = plot_age_pyramid(
data.person, datetime_ref, savefig=False, return_vector=False
)

@pytest.mark.parametrize(
"datetime_ref", [datetime(2020, 1, 1), "inclusion_datetime", "2020-01-01"]
)
def test_plot_age_pyramid(datetime_ref):
original_person = person_with_inclusion_date.copy()
chart = plot_age_pyramid(person_with_inclusion_date, datetime_ref, savefig=False)
assert isinstance(chart, alt.vegalite.v4.api.ConcatChart)

# Check that the data is unchanged
assert_frame_equal(original_person, data.person)
assert_frame_equal(original_person, person_with_inclusion_date)


def test_age_pyramid_output():
Expand All @@ -42,23 +43,40 @@ def test_age_pyramid_output():
path.unlink()

group_gender_age = plot_age_pyramid(
data.person, savefig=True, return_vector=True, filename=filename
data.person, savefig=True, return_array=True, filename=filename
)
assert isinstance(group_gender_age, Series)

chart, group_gender_age = plot_age_pyramid(
data.person, savefig=False, return_vector=True
data.person, savefig=False, return_array=True
)
assert isinstance(chart, alt.vegalite.v4.api.ConcatChart)
assert isinstance(group_gender_age, Series)

chart = plot_age_pyramid(data.person, savefig=False, return_vector=False)
chart = plot_age_pyramid(data.person, savefig=False, return_array=False)
assert isinstance(chart, alt.vegalite.v4.api.ConcatChart)

with pytest.raises(ValueError, match="You have to set a filename"):
_ = plot_age_pyramid(data.person, savefig=True, filename=None)
plot_age_pyramid(person_with_inclusion_date, savefig=True, filename=None)

with pytest.raises(
ValueError, match="'filename' type must be str, got <class 'list'>"
):
_ = plot_age_pyramid(data.person, savefig=True, filename=[1])
plot_age_pyramid(person_with_inclusion_date, savefig=True, filename=[1])


def test_plot_age_pyramid_datetime_ref_error():
    """Check that invalid `datetime_ref` values raise the expected errors.

    Two cases are exercised:

    - a string that is expected to be rejected with a ``ValueError``
      ('20x2-01-01');
    - a value of an unsupported type (the int ``2022``), expected to be
      rejected with a ``TypeError``.
    """
    # `pytest.raises(match=...)` treats its argument as a regular expression
    # (searched with `re.search`). The expected messages contain regex
    # metacharacters such as '.', so they are escaped to be compared literally.
    import re

    unparseable_msg = (
        "`datetime_ref` must either be a column name or parseable date, "
        "got string '20x2-01-01'"
    )
    with pytest.raises(ValueError, match=re.escape(unparseable_msg)):
        plot_age_pyramid(
            person_with_inclusion_date, datetime_ref="20x2-01-01", savefig=False
        )

    wrong_type_msg = (
        "`datetime_ref` must be either None, a parseable string date, "
        "a column name or a datetime. Got type: <class 'int'>, 2022"
    )
    with pytest.raises(TypeError, match=re.escape(wrong_type_msg)):
        plot_age_pyramid(
            person_with_inclusion_date, datetime_ref=2022, savefig=False
        )