Skip to content

Commit

Permalink
Feat: support array type for datetime_ref in plot_age_pyramid() (#43
Browse files Browse the repository at this point in the history
)

Co-authored-by: Thomas Petit-Jean <[email protected]>
Co-authored-by: Matthieu Doutreligne <[email protected]>
  • Loading branch information
3 people authored Jun 1, 2023
1 parent 24bf047 commit 001fe9b
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 74 deletions.
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
- Allow saving DB locally in client or cluster mode
- Add data cleaning function to handle incorrect datetime in spark
- Filter biology config on care site
- Add person-dependent `datetime_ref` to `plot_age_pyramid`

### Fixed

Expand Down
60 changes: 60 additions & 0 deletions docs/functionalities/plotting/age_pyramid.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Visualizing age pyramid

The age pyramid is helpful to quickly visualize the age and gender distributions in a cohort.

## Load a synthetic dataset

`plot_age_pyramid` uses the "person" table:

```python
from eds_scikit.datasets.synthetic.person import load_person

df_person = load_person()
df_person.head()
```

```python
# Out:
#    person_id gender_source_value birth_datetime
# 0          0                   m     2010-01-01
# 1          1                   m     1938-01-01
# 2          2                   f     1994-01-01
# 3          3                   m     1994-01-01
# 4          4                   m     2004-01-01
```

## Visualize age pyramid

### Basic usage

By default, `plot_age_pyramid` will compute age as the difference between today and the date of birth:

```python
from eds_scikit.plot.age_pyramid import plot_age_pyramid

plot_age_pyramid(df_person)
```

![age_pyramid_default](age_pyramid_default.png)

### Advanced parameters

Further configuration can be provided, including:

- `datetime_ref`: The reference date from which the age is computed.
It can be either a single datetime (string or datetime type), an array of datetimes
(one reference per patient), or a string naming a column of the input dataframe.
- `return_array`: If set to True, return the aggregated counts (per gender and age bin) instead of a chart.

```python
import pandas as pd
from datetime import datetime
from eds_scikit.plot.age_pyramid import plot_age_pyramid

dates_of_first_visit = pd.Series([datetime(2020, 1, 1)] * df_person.shape[0])
plot_age_pyramid(df_person, datetime_ref=dates_of_first_visit)
```

![age_pyramid_single_ref](age_pyramid_single_ref.png)


Please check [the documentation][eds_scikit.plot.age_pyramid.plot_age_pyramid] for further details on the function's parameters.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
96 changes: 57 additions & 39 deletions eds_scikit/plot/data_quality.py → eds_scikit/plot/age_pyramid.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
from copy import copy
from datetime import datetime
from typing import Tuple

import altair as alt
import numpy as np
import pandas as pd
from pandas.core.frame import DataFrame
from loguru import logger
from pandas.core.series import Series

from ..utils.checks import check_columns
from ..utils.framework import bd
from ..utils.typing import DataFrame


def plot_age_pyramid(
person: DataFrame,
datetime_ref: datetime = None,
filename: str = None,
savefig: bool = False,
return_vector: bool = False,
return_array: bool = False,
) -> Tuple[alt.Chart, Series]:
"""Plot an age pyramid from a 'person' pandas DataFrame.
Expand All @@ -28,16 +28,21 @@ def plot_age_pyramid(
- `person_id`, dtype : any
- `gender_source_value`, dtype : str, {'m', 'f'}
datetime_ref : datetime,
datetime_ref : Union[datetime, str], default None
The reference date to compute population age from.
If a string, it searches for a column with the same name in the person table: each patient has his own datetime reference.
If a datetime, the reference datetime is the same for all patients.
If set to None, datetime.today() will be used instead.
savefig : bool,
filename : str, default None
The path to save figure at.
savefig : bool, default False
If set to True, filename must be set.
The plot will be saved at the specified filename.
filename : Optional[str],
The path to save figure at.
return_array : bool, default False
If set to True, return chart and its pd.Dataframe representation.
Returns
-------
Expand All @@ -49,44 +54,62 @@ def plot_age_pyramid(
"""
check_columns(person, ["person_id", "birth_datetime", "gender_source_value"])

if savefig:
if filename is None:
raise ValueError("You have to set a filename")
if not isinstance(filename, str):
raise ValueError(f"'filename' type must be str, got {type(filename)}")

person_ = person.copy()
datetime_ref_raw = copy(datetime_ref)

if datetime_ref is None:
today = datetime.today()
datetime_ref = datetime.today()
elif isinstance(datetime_ref, datetime):
datetime_ref = pd.to_datetime(datetime_ref)
elif isinstance(datetime_ref, str):
# A string type for datetime_ref could be either
# a column name or a datetime in string format.
if datetime_ref in person.columns:
datetime_ref = person[datetime_ref]
else:
datetime_ref = pd.to_datetime(
datetime_ref, errors="coerce"
) # In case of error, will return NaT
if pd.isnull(datetime_ref):
raise ValueError(
f"`datetime_ref` must either be a column name or parseable date, "
f"got string '{datetime_ref_raw}'"
)
else:
today = pd.to_datetime(datetime_ref)
raise TypeError(
f"`datetime_ref` must be either None, a parseable string date"
f", a column name or a datetime. Got type: {type(datetime_ref)}, {datetime_ref}"
)

cols_to_keep = ["person_id", "birth_datetime", "gender_source_value"]
person_ = bd.to_pandas(person[cols_to_keep])

# TODO: replace with from ..utils.datetime_helpers.substract_datetime
deltas = today - person_["birth_datetime"]
if bd.is_pandas(person_):
deltas = deltas.dt.total_seconds()
person_["age"] = (datetime_ref - person_["birth_datetime"]).dt.total_seconds()
person_["age"] /= 365 * 24 * 3600

person_["age"] = deltas / (365 * 24 * 3600)
person_ = person_.query("age > 0.0")
# Remove outliers
mask_age_inliners = (person_["age"] > 0) & (person_["age"] < 125)
n_outliers = (~mask_age_inliners).sum()
if n_outliers > 0:
perc_outliers = 100 * n_outliers / person_.shape[0]
logger.warning(
f"{n_outliers} ({perc_outliers:.4f}%) individuals' "
"age is out of the (0, 125) interval, we skip them."
)
person_ = person_.loc[mask_age_inliners]

# Aggregate rare age categories
mask_rare_age_agg = person_["age"] > 90
person_.loc[mask_rare_age_agg, "age"] = 99

bins = np.arange(0, 100, 10)
labels = [f"{left}-{right}" for left, right in zip(bins[:-1], bins[1:])]
person_["age_bins"] = bd.cut(person_["age"], bins=bins, labels=labels)

person_["age_bins"] = (
person_["age_bins"].astype(str).str.lower().str.replace("nan", "90+")
)
person_["age_bins"] = pd.cut(person_["age"], bins=bins, labels=labels)

person_ = person_.loc[person_["gender_source_value"].isin(["m", "f"])]
group_gender_age = person_.groupby(["gender_source_value", "age_bins"])[
"person_id"
].count()

# Convert to pandas to ease plotting.
# Since we have aggregated the data, this operation won't crash.
group_gender_age = bd.to_pandas(group_gender_age)

male = group_gender_age["m"].reset_index()
female = group_gender_age["f"].reset_index()

Expand Down Expand Up @@ -121,12 +144,7 @@ def plot_age_pyramid(

chart = alt.concat(left, middle, right, spacing=5)

if savefig:
chart.save(filename)
if return_vector:
return group_gender_age

if return_vector:
return chart, group_gender_age
if return_array:
return group_gender_age

return chart
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ nav:
- Plotting:
- Event sequence: functionalities/plotting/event_sequences.md
- Generating inclusion/exclusion flowchart: functionalities/plotting/flowchart.ipynb
- Age pyramid: functionalities/plotting/age_pyramid.md
- Recipes:
- Generating inclusion/exclusion flowchart: recipes/flowchart.ipynb
- Saving small cohorts locally: recipes/small-cohorts.ipynb
Expand Down
68 changes: 33 additions & 35 deletions tests/test_age_pyramid.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,62 @@
from datetime import datetime
from pathlib import Path

import altair as alt
import numpy as np
import pandas as pd
import pytest
from pandas.core.series import Series
from pandas.testing import assert_frame_equal

from eds_scikit.datasets.synthetic.person import load_person
from eds_scikit.plot.data_quality import plot_age_pyramid
from eds_scikit.plot.age_pyramid import plot_age_pyramid

data = load_person()

person_with_inclusion_date = data.person.copy()
N = len(person_with_inclusion_date)
delta_days = pd.to_timedelta(np.random.randint(0, 1000, N), unit="d")

@pytest.mark.parametrize(
"datetime_ref",
[
None,
datetime(2020, 1, 1),
np.full(data.person.shape[0], datetime(2020, 1, 1)),
],
person_with_inclusion_date["inclusion_datetime"] = (
person_with_inclusion_date["birth_datetime"] + delta_days
)
def test_age_pyramid_datetime_ref_format(datetime_ref):
original_person = data.person.copy()

chart = plot_age_pyramid(
data.person, datetime_ref, savefig=False, return_vector=False
)

@pytest.mark.parametrize(
"datetime_ref", [datetime(2020, 1, 1), "inclusion_datetime", "2020-01-01"]
)
def test_plot_age_pyramid(datetime_ref):
original_person = person_with_inclusion_date.copy()
chart = plot_age_pyramid(person_with_inclusion_date, datetime_ref)
assert isinstance(chart, alt.vegalite.v4.api.ConcatChart)

# Check that the data is unchanged
assert_frame_equal(original_person, data.person)
assert_frame_equal(original_person, person_with_inclusion_date)


def test_age_pyramid_output():

filename = "test.html"
plot_age_pyramid(data.person, savefig=True, filename=filename)
path = Path(filename)
assert path.exists()
path.unlink()

group_gender_age = plot_age_pyramid(
data.person, savefig=True, return_vector=True, filename=filename
)
assert isinstance(group_gender_age, Series)

chart, group_gender_age = plot_age_pyramid(
data.person, savefig=False, return_vector=True
)
chart = plot_age_pyramid(data.person)
assert isinstance(chart, alt.vegalite.v4.api.ConcatChart)

group_gender_age = plot_age_pyramid(data.person, return_array=True)
assert isinstance(group_gender_age, Series)

chart = plot_age_pyramid(data.person, savefig=False, return_vector=False)
assert isinstance(chart, alt.vegalite.v4.api.ConcatChart)

with pytest.raises(ValueError, match="You have to set a filename"):
_ = plot_age_pyramid(data.person, savefig=True, filename=None)
def test_plot_age_pyramid_datetime_ref_error():
with pytest.raises(
ValueError,
match="`datetime_ref` must either be a column name or parseable date, got string '20x2-01-01'",
):
_ = plot_age_pyramid(
person_with_inclusion_date,
datetime_ref="20x2-01-01",
)

with pytest.raises(
ValueError, match="'filename' type must be str, got <class 'list'>"
TypeError,
match="`datetime_ref` must be either None, a parseable string date, a column name or a datetime. Got type: <class 'int'>, 2022",
):
_ = plot_age_pyramid(data.person, savefig=True, filename=[1])
_ = plot_age_pyramid(
person_with_inclusion_date,
datetime_ref=2022,
)

0 comments on commit 001fe9b

Please sign in to comment.