Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: support array type for datetime_ref in plot_age_pyramid() #43

1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- Allow saving DB locally in client or cluster mode
- Add data cleaning function to handle incorrect datetime in spark
- Filter biology config on care site
- Add person-dependent `datetime_ref` to `plot_age_pyramid`

### Fixed

Expand Down
61 changes: 38 additions & 23 deletions eds_scikit/plot/data_quality.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
from copy import copy
from datetime import datetime
from typing import Tuple

import altair as alt
import numpy as np
import pandas as pd
from pandas.core.frame import DataFrame
from pandas.core.series import Series

from ..utils.checks import check_columns
from ..utils.framework import bd
from ..utils.typing import DataFrame


def plot_age_pyramid(
person: DataFrame,
datetime_ref: datetime = None,
filename: str = None,
savefig: bool = False,
return_vector: bool = False,
return_array: bool = False,
) -> Tuple[alt.Chart, Series]:
"""Plot an age pyramid from a 'person' pandas DataFrame.

Expand All @@ -28,8 +29,10 @@ def plot_age_pyramid(
- `person_id`, dtype : any
- `gender_source_value`, dtype : str, {'m', 'f'}

datetime_ref : datetime,
datetime_ref : Union[datetime, str],
The reference date to compute population age from.
If a string, it is first looked up as a column name in the person table, in which case each patient has their own datetime reference; if no such column exists, the string is parsed as a date shared by all patients.
If a datetime, the reference datetime is the same for all patients.
If set to None, datetime.today() will be used instead.

savefig : bool,
Expand All @@ -55,38 +58,50 @@ def plot_age_pyramid(
if not isinstance(filename, str):
raise ValueError(f"'filename' type must be str, got {type(filename)}")

person_ = person.copy()
datetime_ref_raw = copy(datetime_ref)

if datetime_ref is None:
today = datetime.today()
datetime_ref = datetime.today()
elif isinstance(datetime_ref, datetime):
datetime_ref = pd.to_datetime(datetime_ref)
elif isinstance(datetime_ref, str):
# A string type for datetime_ref could be either
# a column name or a datetime in string format.
if datetime_ref in person.columns:
datetime_ref = person[datetime_ref]
else:
datetime_ref = pd.to_datetime(
datetime_ref, errors="coerce"
) # In case of error, will return NaT
if pd.isnull(datetime_ref):
raise ValueError(
f"`datetime_ref` must either be a column name or parseable date, "
f"got string '{datetime_ref_raw}'"
)
else:
today = pd.to_datetime(datetime_ref)
raise TypeError(
f"`datetime_ref` must be either None, a parseable string date"
f", a column name or a datetime. Got type: {type(datetime_ref)}, {datetime_ref}"
)

cols_to_keep = ["person_id", "birth_datetime", "gender_source_value"]
person_ = bd.to_pandas(person[cols_to_keep])

# TODO: replace with from ..utils.datetime_helpers.substract_datetime
deltas = today - person_["birth_datetime"]
if bd.is_pandas(person_):
deltas = deltas.dt.total_seconds()
person_["age"] = (datetime_ref - person_["birth_datetime"]).dt.total_seconds()
person_["age"] /= 365 * 24 * 3600

person_["age"] = deltas / (365 * 24 * 3600)
person_ = person_.query("age > 0.0")
Vincent-Maladiere marked this conversation as resolved.
Show resolved Hide resolved
mask = person_["age"] > 90
Copy link
Contributor

@strayMat strayMat Mar 31, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this mask?
Is it for privacy reasons, because there are few samples aged over 90?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was mainly to gather rare categories (90, 100, 110) and avoid making too many bins with very few individuals. Does it seem relevant to you?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is relevant to me, as I understand that this mask basically merges the bins (90, 100], (100, 110], and above, which is the classic intention of a user.

person_.loc[mask, "age"] = 99

bins = np.arange(0, 100, 10)
labels = [f"{left}-{right}" for left, right in zip(bins[:-1], bins[1:])]
person_["age_bins"] = bd.cut(person_["age"], bins=bins, labels=labels)

person_["age_bins"] = (
person_["age_bins"].astype(str).str.lower().str.replace("nan", "90+")
)
person_["age_bins"] = pd.cut(person_["age"], bins=bins, labels=labels)

person_ = person_.loc[person_["gender_source_value"].isin(["m", "f"])]
group_gender_age = person_.groupby(["gender_source_value", "age_bins"])[
"person_id"
].count()

# Convert to pandas to ease plotting.
# Since we have aggregated the data, this operation won't crash.
group_gender_age = bd.to_pandas(group_gender_age)

male = group_gender_age["m"].reset_index()
female = group_gender_age["f"].reset_index()

Expand Down Expand Up @@ -123,10 +138,10 @@ def plot_age_pyramid(

if savefig:
chart.save(filename)
if return_vector:
if return_array:
return group_gender_age

if return_vector:
if return_array:
return chart, group_gender_age

return chart
54 changes: 36 additions & 18 deletions tests/test_age_pyramid.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import altair as alt
import numpy as np
import pandas as pd
import pytest
from pandas.core.series import Series
from pandas.testing import assert_frame_equal
Expand All @@ -12,25 +13,25 @@

data = load_person()

person_with_inclusion_date = data.person.copy()
N = len(person_with_inclusion_date)
delta_days = pd.to_timedelta(np.random.randint(0, 1000, N), unit="d")

@pytest.mark.parametrize(
"datetime_ref",
[
None,
datetime(2020, 1, 1),
np.full(data.person.shape[0], datetime(2020, 1, 1)),
],
person_with_inclusion_date["inclusion_datetime"] = (
person_with_inclusion_date["birth_datetime"] + delta_days
)
def test_age_pyramid_datetime_ref_format(datetime_ref):
original_person = data.person.copy()

chart = plot_age_pyramid(
data.person, datetime_ref, savefig=False, return_vector=False
)

@pytest.mark.parametrize(
"datetime_ref", [datetime(2020, 1, 1), "inclusion_datetime", "2020-01-01"]
)
def test_plot_age_pyramid(datetime_ref):
original_person = person_with_inclusion_date.copy()
chart = plot_age_pyramid(person_with_inclusion_date, datetime_ref, savefig=False)
assert isinstance(chart, alt.vegalite.v4.api.ConcatChart)

# Check that the data is unchanged
assert_frame_equal(original_person, data.person)
assert_frame_equal(original_person, person_with_inclusion_date)


def test_age_pyramid_output():
Expand All @@ -42,23 +43,40 @@ def test_age_pyramid_output():
path.unlink()

group_gender_age = plot_age_pyramid(
data.person, savefig=True, return_vector=True, filename=filename
data.person, savefig=True, return_array=True, filename=filename
)
assert isinstance(group_gender_age, Series)

chart, group_gender_age = plot_age_pyramid(
data.person, savefig=False, return_vector=True
data.person, savefig=False, return_array=True
)
assert isinstance(chart, alt.vegalite.v4.api.ConcatChart)
assert isinstance(group_gender_age, Series)

chart = plot_age_pyramid(data.person, savefig=False, return_vector=False)
chart = plot_age_pyramid(data.person, savefig=False, return_array=False)
assert isinstance(chart, alt.vegalite.v4.api.ConcatChart)

with pytest.raises(ValueError, match="You have to set a filename"):
_ = plot_age_pyramid(data.person, savefig=True, filename=None)
plot_age_pyramid(person_with_inclusion_date, savefig=True, filename=None)

with pytest.raises(
ValueError, match="'filename' type must be str, got <class 'list'>"
):
_ = plot_age_pyramid(data.person, savefig=True, filename=[1])
plot_age_pyramid(person_with_inclusion_date, savefig=True, filename=[1])


def test_plot_age_pyramid_datetime_ref_error():
# Checks that plot_age_pyramid rejects invalid `datetime_ref` values with
# the expected exception type and message.
#
# NOTE(review): pytest treats `match` as a regular expression searched in
# the exception message, so characters such as '.' act as wildcards in the
# patterns below — consider re.escape if an exact match is intended.

# Case 1: a string `datetime_ref` for which a ValueError is expected;
# '20x2-01-01' is presumably neither a column of
# person_with_inclusion_date nor a parseable date.
with pytest.raises(
ValueError,
match="`datetime_ref` must either be a column name or parseable date, got string '20x2-01-01'",
):
_ = plot_age_pyramid(
person_with_inclusion_date, datetime_ref="20x2-01-01", savefig=False
)
# Case 2: an int `datetime_ref` (not None, str or datetime) is expected to
# raise a TypeError whose message reports the offending type and value.
with pytest.raises(
TypeError,
match="`datetime_ref` must be either None, a parseable string date, a column name or a datetime. Got type: <class 'int'>, 2022",
):
_ = plot_age_pyramid(
person_with_inclusion_date, datetime_ref=2022, savefig=False
)