Skip to content

Commit

Permalink
Enhance Kaplan-Meier plot with grid layout and survival probability t…
Browse files Browse the repository at this point in the history
…able
  • Loading branch information
aGuyLearning committed Jan 15, 2025
1 parent c437f27 commit 2b8718c
Showing 1 changed file with 44 additions and 6 deletions.
50 changes: 44 additions & 6 deletions ehrapy/plot/_survival_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import warnings
from typing import TYPE_CHECKING

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
from numpy import ndarray
Expand Down Expand Up @@ -261,7 +262,10 @@ def kaplan_meier(
at_risk_counts = [False] * len(kmfs)
if color is None:
color = [None] * len(kmfs)
plt.figure(figsize=figsize)

fig = plt.figure(constrained_layout=True, figsize=figsize)
spec = fig.add_gridspec(2, 1)
ax = plt.subplot(spec[0, 0])

for i, kmf in enumerate(kmfs):
if i == 0:
Expand All @@ -283,13 +287,47 @@ def kaplan_meier(
at_risk_counts=at_risk_counts[i],
color=color[i],
)
# Configure plot appearance
ax.grid(grid)
plt.xlim(xlim)
plt.ylim(ylim)
plt.xlabel(xlabel)
plt.ylabel(ylabel)
ax.set_xlim(xlim)
ax.set_ylim(ylim)
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
if title:
plt.title(title)
ax.set_title(title)

# Prepare data for the table
xticks = [x for x in ax.get_xticks() if x >= 0]
xticks_space = xticks[1] - xticks[0]
if xlabel is None:
xlabel = "Time"

yticks = np.arange(len(kmfs))

ax_table = plt.subplot(spec[1, 0])
ax_table.set_xticks(xticks)
ax_table.set_xlim(-xticks_space / 2, xticks[-1] + xticks_space / 2)
ax_table.set_ylim(-1, len(kmfs))
ax_table.set_yticks(yticks)
ax_table.set_yticklabels([kmf.label if kmf.label else f"Group {i + 1}" for i, kmf in enumerate(kmfs[::-1])])

for i, kmf in enumerate(kmfs[::-1]):
survival_probs = kmf.survival_function_at_times(xticks).values
for j, prob in enumerate(survival_probs):
ax_table.text(
xticks[j], # x position
yticks[i], # y position
f"{prob:.2f}", # formatted survival probability
ha="center",
va="center",
bbox={"boxstyle": "round,pad=0.2", "edgecolor": "none", "facecolor": "lightgrey"},
)

ax_table.grid(grid)
ax_table.spines["top"].set_visible(False)
ax_table.spines["right"].set_visible(False)
ax_table.spines["bottom"].set_visible(False)
ax_table.spines["left"].set_visible(False)

if not show:
return ax
Expand Down

0 comments on commit 2b8718c

Please sign in to comment.