Skip to content

Commit

Permalink
New Feature: Forestplot for CoxPH model (#838)
Browse files Browse the repository at this point in the history
* coxph_forestplot

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* changed-notebook

* Update ehrapy/plot/_survival_analysis.py

Co-authored-by: Eljas Roellin <[email protected]>

* Update ehrapy/plot/_survival_analysis.py

Co-authored-by: Eljas Roellin <[email protected]>

* Remove useless empty line

* Remove useless comment

* undo again; check rtd build

* renamed function and updated documentation to mention, that it is a lifelines object

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* removed fitted check and updated name in notebook

* Update ehrapy/plot/_survival_analysis.py

Co-authored-by: Lukas Heumos <[email protected]>

* updated variable names and moved test to better file

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* made anything after coxphfitter keyword only

* changed type to iterable

* added title and show args

* less ambiguous loop index

* fixed test. had to return figure and axis, so the test can save the image

* get summary form adata

* updated docs

* updated exampel

* removed fitter object from docu

* Update ehrapy/plot/_survival_analysis.py

Co-authored-by: Eljas Roellin <[email protected]>

* docu updates, for understandability

* link between functions

* Enhance documentation for cox_ph_forestplot function to clarify usage of adata and summary table

* Update documentation for cox_ph_forestplot function to clarify uns_key argument usage

* Refactor subplot variable names for clarity in cox_ph_forestplot function

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eljas Roellin <[email protected]>
Co-authored-by: Lukas Heumos <[email protected]>
Co-authored-by: eroell <[email protected]>
  • Loading branch information
5 people authored Jan 15, 2025
1 parent 728f1bd commit ac2c531
Show file tree
Hide file tree
Showing 10 changed files with 288 additions and 5 deletions.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion docs/tutorials/notebooks
1 change: 1 addition & 0 deletions docs/usage/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ Methods that extract and visualize tool-specific annotation in an AnnData object
plot.ols
plot.kaplan_meier
plot.cox_ph_forestplot
```

### Causal Inference
Expand Down
2 changes: 1 addition & 1 deletion ehrapy/plot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
from ehrapy.plot._colormaps import * # noqa: F403
from ehrapy.plot._missingno_pl_api import * # noqa: F403
from ehrapy.plot._scanpy_pl_api import * # noqa: F403
from ehrapy.plot._survival_analysis import kaplan_meier, kmf, ols
from ehrapy.plot._survival_analysis import cox_ph_forestplot, kaplan_meier, ols
from ehrapy.plot.causal_inference._dowhy import causal_effect
from ehrapy.plot.feature_ranking._feature_importances import rank_features_supervised
174 changes: 173 additions & 1 deletion ehrapy/plot/_survival_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@
from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd
from matplotlib import gridspec
from numpy import ndarray

from ehrapy.plot import scatter

if TYPE_CHECKING:
from collections.abc import Sequence
from collections.abc import Iterable, Sequence
from xmlrpc.client import Boolean

from anndata import AnnData
Expand Down Expand Up @@ -293,5 +296,174 @@ def kaplan_meier(

if not show:
return ax

else:
return None


def cox_ph_forestplot(
adata: AnnData,
*,
uns_key: str = "cox_ph",
labels: Iterable[str] | None = None,
fig_size: tuple = (10, 10),
t_adjuster: float = 0.1,
ecolor: str = "dimgray",
size: int = 3,
marker: str = "o",
decimal: int = 2,
text_size: int = 12,
color: str = "k",
show: bool = None,
title: str | None = None,
):
"""Generates a forest plot to visualize the coefficients and confidence intervals of a Cox Proportional Hazards model.
The `adata` object must first be populated using the :func:`~ehrapy.tools.cox_ph` function. This function stores the summary table of the `CoxPHFitter` in the `.uns` attribute of `adata`.
The summary table is created when the model is fitted using the :func:`ehrapy.tl.cox_ph` function.
For more information on the `CoxPHFitter`, see the `Lifelines documentation <https://lifelines.readthedocs.io/en/latest/fitters/regression/CoxPHFitter.html>`_.
Inspired by `zepid.graphics.EffectMeasurePlot <https://readthedocs.org>`_ (zEpid Package, https://pypi.org/project/zepid/).
Args:
adata: :class:`~anndata.AnnData` object containing the summary table from the CoxPHFitter. This is stored in the `.uns` attribute, after fitting the model using :func:`~ehrapy.tl.cox_ph`.
uns_key: Key in `.uns` where :func:`~ehrapy.tools.cox_ph` function stored the summary table. See argument `uns_key` in :func:`~ehrapy.tools.cox_ph`.
labels: List of labels for each coefficient, default uses the index of the summary ta
fig_size: Width, height in inches.
t_adjuster: Adjust the table to the right.
ecolor: Color of the error bars.
size: Size of the markers.
marker: Marker style.
decimal: Number of decimal places to display.
text_size: Font size of the text.
color: Color of the markers.
show: Show the plot, do not return figure and axis.
title: Set the title of the plot.
Examples:
>>> import ehrapy as ep
>>> adata = ep.dt.mimic_2(encoded=False)
>>> adata_subset = adata[:, ["mort_day_censored", "censor_flg", "gender_num", "afib_flg", "day_icu_intime_num"]]
>>> coxph = ep.tl.cox_ph(adata_subset, event_col="censor_flg", duration_col="mort_day_censored")
>>> ep.pl.cox_ph_forestplot(adata_subset)
.. image:: /_static/docstring_previews/coxph_forestplot.png
"""
if uns_key not in adata.uns:
raise ValueError(f"Key {uns_key} not found in adata.uns. Please provide a valid key.")

coxph_fitting_summary = adata.uns[
uns_key
] # pd.Dataframe with columns: coef, exp(coef), se(coef), z, p, lower 0.95, upper 0.95
auc_col = "coef"

if labels is None:
labels = coxph_fitting_summary.index
tval = []
ytick = []
for row_index in range(len(coxph_fitting_summary)):
if not np.isnan(coxph_fitting_summary[auc_col][row_index]):
if (
(isinstance(coxph_fitting_summary[auc_col][row_index], float))
& (isinstance(coxph_fitting_summary["coef lower 95%"][row_index], float))
& (isinstance(coxph_fitting_summary["coef upper 95%"][row_index], float))
):
tval.append(
[
round(coxph_fitting_summary[auc_col][row_index], decimal),
(
"("
+ str(round(coxph_fitting_summary["coef lower 95%"][row_index], decimal))
+ ", "
+ str(round(coxph_fitting_summary["coef upper 95%"][row_index], decimal))
+ ")"
),
]
)
else:
tval.append(
[
coxph_fitting_summary[auc_col][row_index],
(
"("
+ str(coxph_fitting_summary["coef lower 95%"][row_index])
+ ", "
+ str(coxph_fitting_summary["coef upper 95%"][row_index])
+ ")"
),
]
)
ytick.append(row_index)
else:
tval.append([" ", " "])
ytick.append(row_index)

x_axis_upper_bound = round(((pd.to_numeric(coxph_fitting_summary["coef upper 95%"])).max() + 0.1), 2)

x_axis_lower_bound = round(((pd.to_numeric(coxph_fitting_summary["coef lower 95%"])).min() - 0.1), 1)

fig = plt.figure(figsize=fig_size)
gspec = gridspec.GridSpec(1, 6)
plot = plt.subplot(gspec[0, 0:4])
table = plt.subplot(gspec[0, 4:])
plot.set_ylim(-1, (len(coxph_fitting_summary))) # spacing out y-axis properly

plot.axvline(1, color="gray", zorder=1)
lower_diff = coxph_fitting_summary[auc_col] - coxph_fitting_summary["coef lower 95%"]
upper_diff = coxph_fitting_summary["coef upper 95%"] - coxph_fitting_summary[auc_col]
plot.errorbar(
coxph_fitting_summary[auc_col],
coxph_fitting_summary.index,
xerr=[lower_diff, upper_diff],
marker="None",
zorder=2,
ecolor=ecolor,
linewidth=0,
elinewidth=1,
)
# plot markers
plot.scatter(
coxph_fitting_summary[auc_col],
coxph_fitting_summary.index,
c=color,
s=(size * 25),
marker=marker,
zorder=3,
edgecolors="None",
)
# plot settings
plot.xaxis.set_ticks_position("bottom")
plot.yaxis.set_ticks_position("left")
plot.get_xaxis().set_major_formatter(ticker.ScalarFormatter())
plot.get_xaxis().set_minor_formatter(ticker.NullFormatter())
plot.set_yticks(ytick)
plot.set_xlim([x_axis_lower_bound, x_axis_upper_bound])
plot.set_xticks([x_axis_lower_bound, 1, x_axis_upper_bound])
plot.set_xticklabels([x_axis_lower_bound, 1, x_axis_upper_bound])
plot.set_yticklabels(labels)
plot.tick_params(axis="y", labelsize=text_size)
plot.yaxis.set_ticks_position("none")
plot.invert_yaxis() # invert y-axis to align values properly with table
tb = table.table(
cellText=tval, cellLoc="center", loc="right", colLabels=[auc_col, "95% CI"], bbox=[0, t_adjuster, 1, 1]
)
table.axis("off")
tb.auto_set_font_size(False)
tb.set_fontsize(text_size)
for _, cell in tb.get_celld().items():
cell.set_linewidth(0)

# remove spines
plot.spines["top"].set_visible(False)
plot.spines["right"].set_visible(False)
plot.spines["left"].set_visible(False)

if title:
plt.title(title)

if not show:
return fig, plot

else:
return None
86 changes: 86 additions & 0 deletions tests/_scripts/coxph_forestplot_create_expected.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"import ehrapy as ep"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"current_notebook_dir = %pwd\n",
"_TEST_IMAGE_PATH = f\"{current_notebook_dir}/../plot/_images\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"adata = ep.dt.mimic_2(encoded=False)\n",
"adata_subset = adata[:, [\"mort_day_censored\", \"censor_flg\", \"gender_num\", \"afib_flg\", \"day_icu_intime_num\"]]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"genderafib_coxph = ep.tl.cox_ph(adata_subset, duration_col=\"mort_day_censored\", event_col=\"censor_flg\")\n",
"\n",
"fig, ax = ep.pl.cox_ph_forestplot(genderafib_coxph, fig_size=(12, 3), t_adjuster=0.15, marker=\"o\", size=2, text_size=14)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig.savefig(f\"{_TEST_IMAGE_PATH}/coxph_forestplot_expected.png\", dpi=80)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ def rng():
return np.random.default_rng(seed=42)


@pytest.fixture
def mimic_2():
adata = ep.dt.mimic_2()
return adata


@pytest.fixture
def mimic_2_encoded():
adata = ep.dt.mimic_2(encoded=True)
Expand Down
Binary file added tests/plot/_images/coxph_forestplot_expected.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions tests/plot/test_catplot.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from pathlib import Path

from ehrapy.plot import catplot
import ehrapy as ep

CURRENT_DIR = Path(__file__).parent
_TEST_IMAGE_PATH = f"{CURRENT_DIR}/_images"


def test_catplot_vanilla(adata_mini, check_same_image):
fig = catplot(adata_mini, jitter=False)
fig = ep.pl.catplot(adata_mini, jitter=False)

check_same_image(
fig=fig,
Expand Down
18 changes: 18 additions & 0 deletions tests/plot/test_survival_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from pathlib import Path

import ehrapy as ep

CURRENT_DIR = Path(__file__).parent
_TEST_IMAGE_PATH = f"{CURRENT_DIR}/_images"


def test_coxph_forestplot(mimic_2, check_same_image):
adata_subset = mimic_2[:, ["mort_day_censored", "censor_flg", "gender_num", "afib_flg", "day_icu_intime_num"]]
ep.tl.cox_ph(adata_subset, duration_col="mort_day_censored", event_col="censor_flg")
fig, ax = ep.pl.cox_ph_forestplot(adata_subset, fig_size=(12, 3), t_adjuster=0.15, marker="o", size=2, text_size=14)

check_same_image(
fig=fig,
base_path=f"{_TEST_IMAGE_PATH}/coxph_forestplot",
tol=2e-1,
)

0 comments on commit ac2c531

Please sign in to comment.