Skip to content

Commit

Permalink
Merge branch 'main' into sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
eroell committed Apr 12, 2024
2 parents 8a94d38 + eb27f49 commit 26e150e
Show file tree
Hide file tree
Showing 7 changed files with 134 additions and 26 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ repos:
args: [--fix, --exit-non-zero-on-fix, --unsafe-fixes]
- id: ruff-format
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v4.6.0
hooks:
- id: detect-private-key
- id: check-ast
Expand Down
2 changes: 0 additions & 2 deletions ehrapy/anndata/anndata_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,6 @@ def move_to_obs(adata: AnnData, to_obs: list[str] | str, copy_obs: bool = False)
adata.obs[var_num] = adata.obs[var_num].apply(pd.to_numeric, errors="ignore", downcast="float")
adata.obs = _cast_obs_columns(adata.obs)

logg.info(f"Added `{to_obs}` to `obs`.")

return adata


Expand Down
1 change: 0 additions & 1 deletion ehrapy/preprocessing/_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ def encode(
progress.update(task, description="Updating layer originals ...")

# update layer content with the latest categorical encoding and the old other values
logg.info("Encoding strings in X to save to .h5ad. Loading the file will reverse the encoding.")
updated_layer = _update_layer_after_encoding(
adata.layers["original"],
encoded_x,
Expand Down
91 changes: 85 additions & 6 deletions ehrapy/tools/cohort_tracking/_cohort_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import pandas as pd
import seaborn as sns
from matplotlib.axes import Axes
from matplotlib.font_manager import FontProperties
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from scanpy import AnnData
from tableone import TableOne
Expand Down Expand Up @@ -41,6 +43,20 @@ def _detect_categorical_columns(data) -> list:
return list(categorical_cols)


import matplotlib.text as mtext


class LegendTitle:
def __init__(self, text_props=None):
self.text_props = text_props or {}

def legend_artist(self, legend, orig_handle, fontsize, handlebox):
x0, y0 = handlebox.xdescent, handlebox.ydescent
title = mtext.Text(x0, y0, orig_handle, **self.text_props)
handlebox.add_artist(title)
return title


class CohortTracker:
"""Track cohort changes over multiple filtering or processing steps.
Expand Down Expand Up @@ -148,6 +164,18 @@ def _check_yticks_labels(self, yticks_labels: dict) -> None:
if missing_keys:
raise ValueError(f"legend_handels key(s) {missing_keys} not found as categories or numerical column names.")

def _check_legend_subtitle_names(self, legend_subtitles_names: dict) -> None:
if not isinstance(legend_subtitles_names, dict):
raise ValueError("legend_subtitles_names must be a dictionary.")

# Find keys in legend_handels that are not in values or self.columns
missing_keys = [key for key in legend_subtitles_names if key not in self.columns]

if missing_keys:
raise ValueError(
f"legend_subtitles_names key(s) {missing_keys} not found as categories or numerical column names."
)

@property
def tracked_steps(self):
"""Number of tracked steps."""
Expand All @@ -164,8 +192,11 @@ def plot_cohort_barplot(
color_palette: str = "colorblind",
yticks_labels: dict = None,
legend_labels: dict = None,
legend_subtitles: bool = False,
legend_subtitles_names: dict = None,
show: bool = True,
ax: Axes | Sequence[Axes] = None,
fontsize: int = 10,
subplots_kwargs: dict = None,
legend_kwargs: dict = None,
) -> None | list[Axes] | tuple[Figure, list[Axes]]:
Expand All @@ -178,8 +209,11 @@ def plot_cohort_barplot(
color_palette: The color palette to use for the plot. Default is "colorblind".
yticks_labels: Dictionary to rename the axis labels. If `None`, the original labels will be used. The keys should be the column names.
legend_labels: Dictionary to rename the legend labels. If `None`, the original labels will be used. For categoricals, the keys should be the categories. For numericals, the key should be the column name.
legend_subtitles: If `True`, subtitles will be added to the legend. Default is `False`.
legend_subtitles_names: Dictionary to rename the legend subtitles. If `None`, the original labels will be used. The keys should be the column names.
show: If `True`, the plot will be shown. If `False`, plotting handels are returned.
ax: If `None`, a new figure and axes will be created. If an axes object is provided, the plot will be added to it.
fontsize: Fontsize for the text in the plot. Default is 10.
subplots_kwargs: Additional keyword arguments for the subplots.
legend_kwargs: Additional keyword arguments for the legend.
Expand Down Expand Up @@ -212,11 +246,15 @@ def plot_cohort_barplot(
.. image:: /_static/docstring_previews/cohort_tracking.png
"""
subplots_kwargs = {} if subplots_kwargs is None else subplots_kwargs

legend_labels = {} if legend_labels is None else legend_labels
self._check_legend_labels(legend_labels)

subplots_kwargs = {} if subplots_kwargs is None else subplots_kwargs

legend_subtitles_names = {} if legend_subtitles_names is None else legend_subtitles_names
self._check_legend_subtitle_names(legend_subtitles_names)

yticks_labels = {} if yticks_labels is None else yticks_labels
self._check_yticks_labels(yticks_labels)

Expand Down Expand Up @@ -247,7 +285,7 @@ def plot_cohort_barplot(
single_ax.grid(False)

if subfigure_title:
single_ax.set_title(self._tracked_text[idx])
single_ax.set_title(self._tracked_text[idx], size=fontsize)

color_count = 0
# iterate over the tracked columns in the dataframe
Expand Down Expand Up @@ -285,6 +323,7 @@ def plot_cohort_barplot(
va="center",
color="white",
fontweight="bold",
size=fontsize,
)

single_ax.set_yticks([])
Expand Down Expand Up @@ -320,6 +359,7 @@ def plot_cohort_barplot(
va="center",
color="white",
fontweight="bold",
size=fontsize,
)
if idx == 0:
name = legend_labels[col] if col in legend_labels.keys() else col
Expand All @@ -331,18 +371,57 @@ def plot_cohort_barplot(
yticks_labels[col] if yticks_labels is not None and col in yticks_labels.keys() else col
for col in self.columns
]
single_ax.set_yticklabels(names)
single_ax.set_yticklabels(names, fontsize=fontsize)

# These list of lists is needed to reverse the order of the legend labels,
# making the plot much more readable
legend_handles.reverse()
legend_handels = [item for sublist in legend_handles for item in sublist]

tot_legend_kwargs = {"loc": "best", "bbox_to_anchor": (1, 1)}
tot_legend_kwargs = {"loc": "best", "bbox_to_anchor": (1, 1), "fontsize": fontsize}
if legend_kwargs is not None:
tot_legend_kwargs.update(legend_kwargs)

plt.legend(handles=legend_handels, **tot_legend_kwargs)
def create_legend_with_subtitles(patches_list, subtitles_list, tot_legend_kwargs):
"""Create a legend with subtitles."""
size = {"size": tot_legend_kwargs["fontsize"]}
subtitle_font = FontProperties(weight="bold", **size)
handles = []
labels = []

# there can be empty lists which distort the logic of matching patches to subtitles
patches_list = [patch for patch in patches_list if patch]

for patches, subtitle in zip(patches_list, subtitles_list):
handles.append(Line2D([], [], linestyle="none", marker="", alpha=0)) # Placeholder for title
labels.append(subtitle)

for patch in patches:
handles.append(patch)
labels.append(patch.get_label())

# empty space after block
handles.append(Line2D([], [], linestyle="none", marker="", alpha=0))
labels.append("")

legend = axes[0].legend(handles, labels, **tot_legend_kwargs)

for text in legend.get_texts():
if text.get_text() in subtitles_list:
text.set_font_properties(subtitle_font)

if legend_subtitles:
subtitles = [
legend_subtitles_names[col] if col in legend_subtitles_names.keys() else col
for col in self.columns[::-1]
]
create_legend_with_subtitles(
legend_handles,
subtitles,
tot_legend_kwargs,
)
else:
legend_handles = [item for sublist in legend_handles for item in sublist]
plt.legend(handles=legend_handles, **tot_legend_kwargs)

if show:
plt.tight_layout()
Expand Down
45 changes: 29 additions & 16 deletions tests/_scripts/cohort_tracker_test_create_expected_plots.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,17 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/eljasroellin/Documents/ehrapy_clean/ehrapy_venv_march_II/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import ehrapy as ep"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -38,12 +29,13 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"_TEST_DATA_PATH = \"/Users/eljasroellin/Documents/ehrapy_clean/ehrapy/tests/tools/ehrapy_data/dataset1.csv\"\n",
"_TEST_IMAGE_PATH = \"/Users/eljasroellin/Documents/ehrapy_clean/ehrapy/tests/tools/_images\""
"_TEST_IMAGE_PATH = \"/Users/eljasroellin/Documents/ehrapy_clean/ehrapy/tests/tools/_images\"\n",
"adata_mini = ep.io.read_csv(_TEST_DATA_PATH, columns_obs_only=[\"glucose\", \"weight\", \"disease\", \"station\"])"
]
},
{
Expand All @@ -52,8 +44,6 @@
"metadata": {},
"outputs": [],
"source": [
"adata_mini = ep.io.read_csv(_TEST_DATA_PATH, columns_obs_only=[\"glucose\", \"weight\", \"disease\", \"station\"])\n",
"\n",
"ct = ep.tl.CohortTracker(adata_mini)\n",
"\n",
"ct(adata_mini, label=\"First step\", operations_done=\"Some operations\")\n",
Expand Down Expand Up @@ -95,6 +85,29 @@
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ct = ep.tl.CohortTracker(adata_mini)\n",
"ct(adata_mini, label=\"First step\", operations_done=\"Some operations\")\n",
"fig1_use_settings_big, _ = ct.plot_cohort_barplot(\n",
" show=False,\n",
" yticks_labels={\"weight\": \"wgt\"},\n",
" legend_labels={\"A\": \"Dis. A\", \"weight\": \"(kg)\"},\n",
" legend_subtitles=True,\n",
" legend_subtitles_names={\"station\": \"\", \"disease\": \"dis\", \"weight\": \"wgt\", \"glucose\": \"glc\"},\n",
")\n",
"\n",
"fig1_use_settings_big.tight_layout()\n",
"fig1_use_settings_big.savefig(\n",
" f\"{_TEST_IMAGE_PATH}/cohorttracker_adata_mini_step1_use_settings_big_expected.png\",\n",
" dpi=80,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
19 changes: 19 additions & 0 deletions tests/tools/cohort_tracking/test_cohort_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,25 @@ def test_CohortTracker_plot_cohort_barplot_use_settings(adata_mini, check_same_i
)


def test_CohortTracker_plot_cohort_barplot_use_settings_big(adata_mini, check_same_image):
ct = ep.tl.CohortTracker(adata_mini)

ct(adata_mini, label="First step", operations_done="Some operations")
fig, _ = ct.plot_cohort_barplot(
show=False,
yticks_labels={"weight": "wgt"},
legend_labels={"A": "Dis. A", "weight": "(kg)"},
legend_subtitles=True,
legend_subtitles_names={"station": "", "disease": "dis", "weight": "wgt", "glucose": "glc"},
)

check_same_image(
fig=fig,
base_path=f"{_TEST_IMAGE_PATH}/cohorttracker_adata_mini_step1_use_settings_big",
tol=1e-1,
)


def test_CohortTracker_plot_cohort_barplot_loosing_category(adata_mini, check_same_image):
ct = ep.tl.CohortTracker(adata_mini)

Expand Down

0 comments on commit 26e150e

Please sign in to comment.