Skip to content

Commit

Permalink
added toggle for table
Browse files Browse the repository at this point in the history
  • Loading branch information
aGuyLearning committed Jan 16, 2025
1 parent 51aac2e commit cfbd412
Showing 1 changed file with 35 additions and 31 deletions.
66 changes: 35 additions & 31 deletions ehrapy/plot/_survival_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ def kmf(

def kaplan_meier(
kmfs: Sequence[KaplanMeierFitter],
*,
diplay_table: bool = False,
ci_alpha: list[float] | None = None,
ci_force_lines: list[Boolean] | None = None,
ci_show: list[Boolean] | None = None,
Expand All @@ -204,6 +206,7 @@ def kaplan_meier(
Args:
kmfs: Iterables of fitted KaplanMeierFitter objects.
diplay_table: Display the survival probabilities in a table, below the plot.
ci_alpha: The transparency level of the confidence interval. If more than one kmfs, this should be a list.
ci_force_lines: Force the confidence intervals to be line plots (versus default shaded areas).
If more than one kmfs, this should be a list.
Expand Down Expand Up @@ -297,37 +300,38 @@ def kaplan_meier(
ax.set_title(title)

# Prepare data for the table
xticks = [x for x in ax.get_xticks() if x >= 0]
xticks_space = xticks[1] - xticks[0]
if xlabel is None:
xlabel = "Time"

yticks = np.arange(len(kmfs))

ax_table = plt.subplot(spec[1, 0])
ax_table.set_xticks(xticks)
ax_table.set_xlim(-xticks_space / 2, xticks[-1] + xticks_space / 2)
ax_table.set_ylim(-1, len(kmfs))
ax_table.set_yticks(yticks)
ax_table.set_yticklabels([kmf.label if kmf.label else f"Group {i + 1}" for i, kmf in enumerate(kmfs[::-1])])

for i, kmf in enumerate(kmfs[::-1]):
survival_probs = kmf.survival_function_at_times(xticks).values
for j, prob in enumerate(survival_probs):
ax_table.text(
xticks[j], # x position
yticks[i], # y position
f"{prob:.2f}", # formatted survival probability
ha="center",
va="center",
bbox={"boxstyle": "round,pad=0.2", "edgecolor": "none", "facecolor": "lightgrey"},
)

ax_table.grid(grid)
ax_table.spines["top"].set_visible(False)
ax_table.spines["right"].set_visible(False)
ax_table.spines["bottom"].set_visible(False)
ax_table.spines["left"].set_visible(False)
if diplay_table:
xticks = [x for x in ax.get_xticks() if x >= 0]
xticks_space = xticks[1] - xticks[0]
if xlabel is None:
xlabel = "Time"

yticks = np.arange(len(kmfs))

ax_table = plt.subplot(spec[1, 0])
ax_table.set_xticks(xticks)
ax_table.set_xlim(-xticks_space / 2, xticks[-1] + xticks_space / 2)
ax_table.set_ylim(-1, len(kmfs))
ax_table.set_yticks(yticks)
ax_table.set_yticklabels([kmf.label if kmf.label else f"Group {i + 1}" for i, kmf in enumerate(kmfs[::-1])])

for i, kmf in enumerate(kmfs[::-1]):
survival_probs = kmf.survival_function_at_times(xticks).values
for j, prob in enumerate(survival_probs):
ax_table.text(
xticks[j], # x position
yticks[i], # y position
f"{prob:.2f}", # formatted survival probability
ha="center",
va="center",
bbox={"boxstyle": "round,pad=0.2", "edgecolor": "none", "facecolor": "lightgrey"},
)

ax_table.grid(grid)
ax_table.spines["top"].set_visible(False)
ax_table.spines["right"].set_visible(False)
ax_table.spines["bottom"].set_visible(False)
ax_table.spines["left"].set_visible(False)

if not show:
return ax
Expand Down

0 comments on commit cfbd412

Please sign in to comment.