Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New Feature: Forestplot for CoxPH model #838

Merged
merged 34 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
74bab7c
coxph_forestplot
aGuyLearning Dec 6, 2024
aca8220
Merge branch 'main' into enhancement/issue-743
aGuyLearning Dec 11, 2024
39e4d42
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 11, 2024
225b606
changed-notebook
aGuyLearning Dec 11, 2024
47a5b12
Update ehrapy/plot/_survival_analysis.py
aGuyLearning Dec 11, 2024
a32c121
Update ehrapy/plot/_survival_analysis.py
aGuyLearning Dec 11, 2024
0c7cd45
Remove useless empty line
Zethson Dec 11, 2024
785a2cf
Remove useless comment
Zethson Dec 11, 2024
75ee4ba
undo again; check rtd build
eroell Dec 11, 2024
541e505
renamed function and updated documentation to mention, that it is a l…
aGuyLearning Dec 13, 2024
ca4530a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2024
f77f468
removed fitted check and updated name in notebook
aGuyLearning Dec 13, 2024
66815c4
Update ehrapy/plot/_survival_analysis.py
aGuyLearning Dec 13, 2024
76f0f0c
updated variable names and moved test to better file
aGuyLearning Dec 13, 2024
2108681
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2024
88e18a7
made anything after coxphfitter keyword only
aGuyLearning Dec 13, 2024
613be89
changed type to iterable
aGuyLearning Dec 13, 2024
8c3d170
added title and show args
aGuyLearning Dec 13, 2024
b3a7c21
less ambiguous loop index
aGuyLearning Dec 13, 2024
cfdb1cc
fixed test. had to return figure and axis, so the test can save the i…
aGuyLearning Dec 13, 2024
f1b648b
Merge branch 'main' into enhancement/issue-743
eroell Dec 18, 2024
de4572d
Merge branch 'main' into enhancement/issue-743
eroell Jan 8, 2025
924cd0a
Merge branch 'main' into enhancement/issue-743
aGuyLearning Jan 15, 2025
0d5d6e8
get summary form adata
aGuyLearning Jan 15, 2025
503a97d
updated docs
aGuyLearning Jan 15, 2025
bb83a5a
updated exampel
aGuyLearning Jan 15, 2025
b560748
removed fitter object from docu
aGuyLearning Jan 15, 2025
ad6e8e7
Update ehrapy/plot/_survival_analysis.py
aGuyLearning Jan 15, 2025
3386d87
docu updates, for understandability
aGuyLearning Jan 15, 2025
1364d97
link between functions
eroell Jan 15, 2025
c0e9b49
Enhance documentation for cox_ph_forestplot function to clarify usage…
aGuyLearning Jan 15, 2025
6724d74
Update documentation for cox_ph_forestplot function to clarify uns_ke…
aGuyLearning Jan 15, 2025
39aca12
Refactor subplot variable names for clarity in cox_ph_forestplot func…
aGuyLearning Jan 15, 2025
f79ef81
Merge branch 'main' into enhancement/issue-743
eroell Jan 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion docs/tutorials/notebooks
1 change: 1 addition & 0 deletions docs/usage/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ Methods that extract and visualize tool-specific annotation in an AnnData object

plot.ols
plot.kaplan_meier
plot.cox_ph_forestplot
```

### Causal Inference
Expand Down
2 changes: 1 addition & 1 deletion ehrapy/plot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
from ehrapy.plot._colormaps import * # noqa: F403
from ehrapy.plot._missingno_pl_api import * # noqa: F403
from ehrapy.plot._scanpy_pl_api import * # noqa: F403
from ehrapy.plot._survival_analysis import kaplan_meier, kmf, ols
from ehrapy.plot._survival_analysis import cox_ph_forestplot, kaplan_meier, ols
from ehrapy.plot.causal_inference._dowhy import causal_effect
from ehrapy.plot.feature_ranking._feature_importances import rank_features_supervised
174 changes: 173 additions & 1 deletion ehrapy/plot/_survival_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@
from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd
from matplotlib import gridspec
from numpy import ndarray

from ehrapy.plot import scatter

if TYPE_CHECKING:
from collections.abc import Sequence
from collections.abc import Iterable, Sequence
from xmlrpc.client import Boolean

from anndata import AnnData
Expand Down Expand Up @@ -293,5 +296,174 @@ def kaplan_meier(

if not show:
return ax

else:
return None


def cox_ph_forestplot(
    adata: AnnData,
    *,
    uns_key: str = "cox_ph",
    labels: Iterable[str] | None = None,
    fig_size: tuple = (10, 10),
    t_adjuster: float = 0.1,
    ecolor: str = "dimgray",
    size: int = 3,
    marker: str = "o",
    decimal: int = 2,
    text_size: int = 12,
    color: str = "k",
    show: bool | None = None,
    title: str | None = None,
) -> tuple | None:
    """Generates a forest plot to visualize the coefficients and confidence intervals of a Cox Proportional Hazards model.

    The `adata` object must first be populated using the :func:`~ehrapy.tools.cox_ph` function. This function stores the summary table of the `CoxPHFitter` in the `.uns` attribute of `adata`.
    The summary table is created when the model is fitted using the :func:`ehrapy.tl.cox_ph` function.
    For more information on the `CoxPHFitter`, see the `Lifelines documentation <https://lifelines.readthedocs.io/en/latest/fitters/regression/CoxPHFitter.html>`_.

    Inspired by `zepid.graphics.EffectMeasurePlot <https://readthedocs.org>`_ (zEpid Package, https://pypi.org/project/zepid/).

    Args:
        adata: :class:`~anndata.AnnData` object containing the summary table from the CoxPHFitter. This is stored in the `.uns` attribute, after fitting the model using :func:`~ehrapy.tl.cox_ph`.
        uns_key: Key in `.uns` where :func:`~ehrapy.tools.cox_ph` function stored the summary table. See argument `uns_key` in :func:`~ehrapy.tools.cox_ph`.
        labels: Labels for each coefficient, one per row of the summary table. Defaults to the index of the summary table.
        fig_size: Width, height in inches.
        t_adjuster: Vertical offset of the table's bounding box within its axes; used to align the table rows with the plotted markers.
        ecolor: Color of the error bars.
        size: Size of the markers.
        marker: Marker style.
        decimal: Number of decimal places to display.
        text_size: Font size of the text.
        color: Color of the markers.
        show: If falsy (the default), the figure and the plot axes are returned. If truthy, nothing is returned.
        title: Set the title of the plot.

    Returns:
        A tuple `(fig, ax)` of the figure and the axes holding the forest plot when `show` is falsy, otherwise `None`.

    Raises:
        ValueError: If `uns_key` is not present in `adata.uns`, or if `labels` does not contain exactly one entry per row of the summary table.

    Examples:
        >>> import ehrapy as ep
        >>> adata = ep.dt.mimic_2(encoded=False)
        >>> adata_subset = adata[:, ["mort_day_censored", "censor_flg", "gender_num", "afib_flg", "day_icu_intime_num"]]
        >>> coxph = ep.tl.cox_ph(adata_subset, event_col="censor_flg", duration_col="mort_day_censored")
        >>> ep.pl.cox_ph_forestplot(adata_subset)

        .. image:: /_static/docstring_previews/coxph_forestplot.png
    """
    if uns_key not in adata.uns:
        raise ValueError(f"Key {uns_key} not found in adata.uns. Please provide a valid key.")

    # Summary table of the fitted model; the columns read below are
    # "coef", "coef lower 95%" and "coef upper 95%".
    coxph_fitting_summary = adata.uns[uns_key]
    auc_col = "coef"
    lower_col = "coef lower 95%"
    upper_col = "coef upper 95%"
    n_rows = len(coxph_fitting_summary)

    if labels is None:
        labels = coxph_fitting_summary.index
    else:
        # `labels` may be any iterable (e.g. a generator); materialize it so it can be
        # length-checked here and handed to matplotlib as a sequence later.
        labels = list(labels)
        if len(labels) != n_rows:
            raise ValueError(
                f"Expected {n_rows} labels (one per row of the summary table), but got {len(labels)}."
            )

    # Build the table cells row by row. Values are taken positionally from the underlying
    # arrays, because integer `[]` access on a label-indexed Series is deprecated in pandas.
    tval = []
    for coef, lower, upper in zip(
        coxph_fitting_summary[auc_col].to_numpy(),
        coxph_fitting_summary[lower_col].to_numpy(),
        coxph_fitting_summary[upper_col].to_numpy(),
    ):
        if np.isnan(coef):
            # Keep an empty row so the table stays aligned with the y-ticks.
            tval.append([" ", " "])
            continue
        if isinstance(coef, float) and isinstance(lower, float) and isinstance(upper, float):
            coef, lower, upper = round(coef, decimal), round(lower, decimal), round(upper, decimal)
        tval.append([coef, f"({lower!s}, {upper!s})"])
    # One tick per summary row, including rows whose coefficient is NaN.
    ytick = list(range(n_rows))

    x_axis_upper_bound = round(((pd.to_numeric(coxph_fitting_summary[upper_col])).max() + 0.1), 2)
    x_axis_lower_bound = round(((pd.to_numeric(coxph_fitting_summary[lower_col])).min() - 0.1), 1)

    fig = plt.figure(figsize=fig_size)
    gspec = gridspec.GridSpec(1, 6)
    plot = plt.subplot(gspec[0, 0:4])
    table = plt.subplot(gspec[0, 4:])
    plot.set_ylim(-1, n_rows)  # spacing out y-axis properly

    # NOTE(review): the reference line (and the middle x-tick below) sit at 1, which looks
    # inherited from ratio-scale effect measure plots. The plotted values come from the
    # "coef" column; if those are log hazard ratios, "no effect" would be 0 - confirm intent.
    plot.axvline(1, color="gray", zorder=1)
    lower_diff = coxph_fitting_summary[auc_col] - coxph_fitting_summary[lower_col]
    upper_diff = coxph_fitting_summary[upper_col] - coxph_fitting_summary[auc_col]
    # Error bars only: the markers are drawn separately by the scatter call below.
    plot.errorbar(
        coxph_fitting_summary[auc_col],
        coxph_fitting_summary.index,
        xerr=[lower_diff, upper_diff],
        marker="None",
        zorder=2,
        ecolor=ecolor,
        linewidth=0,
        elinewidth=1,
    )
    # plot markers
    plot.scatter(
        coxph_fitting_summary[auc_col],
        coxph_fitting_summary.index,
        c=color,
        s=(size * 25),
        marker=marker,
        zorder=3,
        edgecolors="None",
    )
    # plot settings
    plot.xaxis.set_ticks_position("bottom")
    plot.yaxis.set_ticks_position("left")
    plot.get_xaxis().set_major_formatter(ticker.ScalarFormatter())
    plot.get_xaxis().set_minor_formatter(ticker.NullFormatter())
    plot.set_yticks(ytick)
    plot.set_xlim([x_axis_lower_bound, x_axis_upper_bound])
    plot.set_xticks([x_axis_lower_bound, 1, x_axis_upper_bound])
    plot.set_xticklabels([x_axis_lower_bound, 1, x_axis_upper_bound])
    plot.set_yticklabels(labels)
    plot.tick_params(axis="y", labelsize=text_size)
    plot.yaxis.set_ticks_position("none")
    plot.invert_yaxis()  # invert y-axis to align values properly with table

    tb = table.table(
        cellText=tval, cellLoc="center", loc="right", colLabels=[auc_col, "95% CI"], bbox=[0, t_adjuster, 1, 1]
    )
    table.axis("off")
    tb.auto_set_font_size(False)
    tb.set_fontsize(text_size)
    for cell in tb.get_celld().values():
        cell.set_linewidth(0)

    # remove spines
    for spine in ("top", "right", "left"):
        plot.spines[spine].set_visible(False)

    if title:
        # plt.title acts on the current axes, which is the table axes (created last).
        plt.title(title)

    if show:
        return None
    return fig, plot
86 changes: 86 additions & 0 deletions tests/_scripts/coxph_forestplot_create_expected.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"import ehrapy as ep"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"current_notebook_dir = %pwd\n",
"_TEST_IMAGE_PATH = f\"{current_notebook_dir}/../plot/_images\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"adata = ep.dt.mimic_2(encoded=False)\n",
"adata_subset = adata[:, [\"mort_day_censored\", \"censor_flg\", \"gender_num\", \"afib_flg\", \"day_icu_intime_num\"]]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"genderafib_coxph = ep.tl.cox_ph(adata_subset, duration_col=\"mort_day_censored\", event_col=\"censor_flg\")\n",
"\n",
"fig, ax = ep.pl.cox_ph_forestplot(adata_subset, fig_size=(12, 3), t_adjuster=0.15, marker=\"o\", size=2, text_size=14)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig.savefig(f\"{_TEST_IMAGE_PATH}/coxph_forestplot_expected.png\", dpi=80)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ def rng():
return np.random.default_rng(seed=42)


@pytest.fixture
def mimic_2():
    """Provide the dataset loaded by ``ep.dt.mimic_2()`` with its default arguments."""
    return ep.dt.mimic_2()


@pytest.fixture
def mimic_2_encoded():
adata = ep.dt.mimic_2(encoded=True)
Expand Down
Binary file added tests/plot/_images/coxph_forestplot_expected.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions tests/plot/test_catplot.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from pathlib import Path

from ehrapy.plot import catplot
import ehrapy as ep

CURRENT_DIR = Path(__file__).parent
_TEST_IMAGE_PATH = f"{CURRENT_DIR}/_images"


def test_catplot_vanilla(adata_mini, check_same_image):
fig = catplot(adata_mini, jitter=False)
fig = ep.pl.catplot(adata_mini, jitter=False)

check_same_image(
fig=fig,
Expand Down
18 changes: 18 additions & 0 deletions tests/plot/test_survival_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from pathlib import Path

import ehrapy as ep

CURRENT_DIR = Path(__file__).parent
_TEST_IMAGE_PATH = f"{CURRENT_DIR}/_images"


def test_coxph_forestplot(mimic_2, check_same_image):
    """Fit a Cox PH model on a column subset and compare its forest plot with the stored reference image."""
    selected_columns = ["mort_day_censored", "censor_flg", "gender_num", "afib_flg", "day_icu_intime_num"]
    adata_subset = mimic_2[:, selected_columns]

    # Fitting stores the summary table in `adata_subset.uns`, which the plot reads.
    ep.tl.cox_ph(adata_subset, duration_col="mort_day_censored", event_col="censor_flg")

    plot_options = {
        "fig_size": (12, 3),
        "t_adjuster": 0.15,
        "marker": "o",
        "size": 2,
        "text_size": 14,
    }
    fig, _ = ep.pl.cox_ph_forestplot(adata_subset, **plot_options)

    check_same_image(
        fig=fig,
        base_path=f"{_TEST_IMAGE_PATH}/coxph_forestplot",
        tol=2e-1,
    )
Loading