Skip to content

Commit

Permalink
coxph_forestplot
Browse files Browse the repository at this point in the history
  • Loading branch information
aGuyLearning committed Dec 6, 2024
1 parent 5336a2e commit 74bab7c
Show file tree
Hide file tree
Showing 7 changed files with 206 additions and 3 deletions.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion ehrapy/plot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
from ehrapy.plot._colormaps import * # noqa: F403
from ehrapy.plot._missingno_pl_api import * # noqa: F403
from ehrapy.plot._scanpy_pl_api import * # noqa: F403
from ehrapy.plot._survival_analysis import kmf, ols
from ehrapy.plot._survival_analysis import kmf, ols, coxph_forestplot
from ehrapy.plot.causal_inference._dowhy import causal_effect
from ehrapy.plot.feature_ranking._feature_importances import rank_features_supervised
102 changes: 102 additions & 0 deletions ehrapy/plot/_survival_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@

from typing import TYPE_CHECKING

from lifelines import CoxPHFitter
from matplotlib import gridspec
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
from numpy import ndarray
import pandas as pd

from ehrapy.plot import scatter

Expand Down Expand Up @@ -251,3 +255,101 @@ def kmf(

if not show:
return ax


def coxph_forestplot(coxph: CoxPHFitter,
labels: list[str] | None = None,
fig_size: tuple = (10, 10),
t_adjuster: float = 0.1,
ecolor: str = 'dimgray',
size: int = 3,
marker: str = 'o',
decimal: int = 2,
text_size: int = 12,
color: str = 'k'):
"""Plots a forest plot of the Cox Proportional Hazard model.
Inspired by the forest plot in the zEpid package in Python.
Link: https://zepid.readthedocs.io/en/latest/Graphics.html#effect-measure-plots
Args:
coxph: Fitted CoxPHFitter object.
labels: List of labels for each coefficient, default uses the index of the coxph.summary.
fig_size: Width, height in inches.
t_adjuster: Adjust the table to the right.
ecolor: Color of the error bars.
size: Size of the markers.
marker: Marker style.
decimal: Number of decimal places to display.
text_size: Font size of the text.
color: Color of the markers.
Examples:
>>> import ehrapy as ep
>>> adata = ep.dt.mimic_2(encoded=False)
>>> adata_subset = adata[:, ["mort_day_censored", "censor_flg", "gender_num", "afib_flg", "day_icu_intime_num"]]
>>> coxph = ep.tl.coxph(adata_subset, event_col="censor_flg", duration_col="mort_day_censored")
>>> ep.pl.coxph_forestplot(coxph)
.. image:: /_static/docstring_previews/coxph_forestplot.png
"""

data = coxph.summary
auc_col = 'coef'

if labels is None:
labels = data.index
tval = []
ytick = []
for i in range(len(data)):
if not np.isnan(data[auc_col][i]):
if ((isinstance(data[auc_col][i], float)) & (isinstance(data['coef lower 95%'][i], float)) &
(isinstance(data['coef upper 95%'][i], float))):
tval.append([round(data[auc_col][i], decimal), (
'(' + str(round(data['coef lower 95%'][i], decimal)) + ', ' +
str(round(data['coef upper 95%'][i], decimal)) + ')')])
else:
tval.append([data[auc_col][i], ('(' + str(data['coef lower 95%'][i]) + ', ' + str(data['coef upper 95%'][i]) + ')')])
ytick.append(i)
else:
tval.append([' ', ' '])
ytick.append(i)

maxi = round(((pd.to_numeric(data['coef upper 95%'])).max() + 0.1),2) # setting x-axis maximum

mini = round(((pd.to_numeric(data['coef lower 95%'])).min() - 0.1), 1) # setting x-axis minimum

fig = plt.figure(figsize=fig_size)
gspec = gridspec.GridSpec(1, 6) # sets up grid
plot = plt.subplot(gspec[0, 0:4]) # plot of data
tabl = plt.subplot(gspec[0, 4:]) # table
plot.set_ylim(-1, (len(data))) # spacing out y-axis properly

plot.axvline(1, color='gray', zorder=1)
lower_diff = data[auc_col] - data['coef lower 95%']
upper_diff = data['coef upper 95%'] - data[auc_col]
plot.errorbar(data[auc_col], data.index, xerr=[lower_diff, upper_diff], marker='None', zorder=2, ecolor=ecolor, linewidth=0, elinewidth=1)
plot.scatter(data[auc_col], data.index, c=color, s=(size * 25), marker=marker, zorder=3, edgecolors='None')
plot.xaxis.set_ticks_position('bottom')
plot.yaxis.set_ticks_position('left')
plot.get_xaxis().set_major_formatter(ticker.ScalarFormatter())
plot.get_xaxis().set_minor_formatter(ticker.NullFormatter())
plot.set_yticks(ytick)
plot.set_xlim([mini, maxi])
plot.set_xticks([mini, 1, maxi])
plot.set_xticklabels([mini, 1, maxi])
plot.set_yticklabels(labels)
plot.tick_params(axis='y', labelsize=text_size)
plot.yaxis.set_ticks_position('none')
plot.invert_yaxis() # invert y-axis to align values properly with table
tb = tabl.table(cellText=tval, cellLoc='center', loc='right', colLabels=[auc_col, '95% CI'], bbox=[0, t_adjuster, 1, 1])
tabl.axis('off')
tb.auto_set_font_size(False)
tb.set_fontsize(text_size)
for _ , cell in tb.get_celld().items():
cell.set_linewidth(0)
plot.spines["top"].set_visible(False)
plot.spines["right"].set_visible(False)
plot.spines["left"].set_visible(False)
return fig, plot

86 changes: 86 additions & 0 deletions tests/_scripts/coxph_forestplot_create_expected.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"import ehrapy as ep\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"current_notebook_dir = %pwd\n",
"_TEST_IMAGE_PATH = f\"{current_notebook_dir}/../plot/_images\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"adata = ep.dt.mimic_2(encoded=False)\n",
"adata_subset = adata[:, [\"mort_day_censored\", \"censor_flg\", \"gender_num\", \"afib_flg\", \"day_icu_intime_num\"]]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"genderafib_coxph = ep.tl.cox_ph(adata_subset, duration_col=\"mort_day_censored\", event_col=\"censor_flg\")\n",
"\n",
"fig, ax = ep.pl.coxph_forestplot(genderafib_coxph, fig_size=(12,3), t_adjuster=0.15, marker=\"o\", size=2, text_size=14)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig.savefig(f\"{_TEST_IMAGE_PATH}/coxph_forestplot_expected.png\", dpi=80)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ def root_dir():
def rng():
return np.random.default_rng(seed=42)

@pytest.fixture
def mimic_2():
adata = ep.dt.mimic_2()
return adata

@pytest.fixture
def mimic_2_encoded():
Expand Down
Binary file added tests/plot/_images/coxph_forestplot_expected.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
15 changes: 13 additions & 2 deletions tests/plot/test_catplot.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,27 @@
from pathlib import Path

from ehrapy.plot import catplot
import ehrapy as ep

CURRENT_DIR = Path(__file__).parent
_TEST_IMAGE_PATH = f"{CURRENT_DIR}/_images"


def test_catplot_vanilla(adata_mini, check_same_image):
fig = catplot(adata_mini, jitter=False)
fig = ep.pl.catplot(adata_mini, jitter=False)

check_same_image(
fig=fig,
base_path=f"{_TEST_IMAGE_PATH}/catplot_vanilla",
tol=2e-1,
)

def test_coxph_forestplot(mimic_2, check_same_image):
adata_subset = mimic_2[:, ["mort_day_censored", "censor_flg", "gender_num", "afib_flg", "day_icu_intime_num"]]
coxph = ep.tl.cox_ph(adata_subset, duration_col="mort_day_censored", event_col="censor_flg")
fig, ax = ep.pl.coxph_forestplot(coxph, fig_size=(12,3), t_adjuster=0.15, marker="o", size=2, text_size=14)

check_same_image(
fig=fig,
base_path=f"{_TEST_IMAGE_PATH}/coxph_forestplot",
tol=2e-1,
)

0 comments on commit 74bab7c

Please sign in to comment.