diff --git a/ehrapy/plot/_survival_analysis.py b/ehrapy/plot/_survival_analysis.py index c135d582..0eb7a55e 100644 --- a/ehrapy/plot/_survival_analysis.py +++ b/ehrapy/plot/_survival_analysis.py @@ -183,6 +183,8 @@ def kmf( def kaplan_meier( kmfs: Sequence[KaplanMeierFitter], + *, + diplay_table: bool = False, ci_alpha: list[float] | None = None, ci_force_lines: list[Boolean] | None = None, ci_show: list[Boolean] | None = None, @@ -204,6 +206,7 @@ def kaplan_meier( Args: kmfs: Iterables of fitted KaplanMeierFitter objects. + diplay_table: Display the survival probabilities in a table, below the plot. ci_alpha: The transparency level of the confidence interval. If more than one kmfs, this should be a list. ci_force_lines: Force the confidence intervals to be line plots (versus default shaded areas). If more than one kmfs, this should be a list. @@ -297,37 +300,38 @@ def kaplan_meier( ax.set_title(title) # Prepare data for the table - xticks = [x for x in ax.get_xticks() if x >= 0] - xticks_space = xticks[1] - xticks[0] - if xlabel is None: - xlabel = "Time" - - yticks = np.arange(len(kmfs)) - - ax_table = plt.subplot(spec[1, 0]) - ax_table.set_xticks(xticks) - ax_table.set_xlim(-xticks_space / 2, xticks[-1] + xticks_space / 2) - ax_table.set_ylim(-1, len(kmfs)) - ax_table.set_yticks(yticks) - ax_table.set_yticklabels([kmf.label if kmf.label else f"Group {i + 1}" for i, kmf in enumerate(kmfs[::-1])]) - - for i, kmf in enumerate(kmfs[::-1]): - survival_probs = kmf.survival_function_at_times(xticks).values - for j, prob in enumerate(survival_probs): - ax_table.text( - xticks[j], # x position - yticks[i], # y position - f"{prob:.2f}", # formatted survival probability - ha="center", - va="center", - bbox={"boxstyle": "round,pad=0.2", "edgecolor": "none", "facecolor": "lightgrey"}, - ) - - ax_table.grid(grid) - ax_table.spines["top"].set_visible(False) - ax_table.spines["right"].set_visible(False) - ax_table.spines["bottom"].set_visible(False) - ax_table.spines["left"].set_visible(False) + if diplay_table: + xticks = [x for x in ax.get_xticks() if x >= 0] + xticks_space = xticks[1] - xticks[0] + if xlabel is None: + xlabel = "Time" + + yticks = np.arange(len(kmfs)) + + ax_table = plt.subplot(spec[1, 0]) + ax_table.set_xticks(xticks) + ax_table.set_xlim(-xticks_space / 2, xticks[-1] + xticks_space / 2) + ax_table.set_ylim(-1, len(kmfs)) + ax_table.set_yticks(yticks) + ax_table.set_yticklabels([kmf.label if kmf.label else f"Group {i + 1}" for i, kmf in enumerate(kmfs[::-1])]) + + for i, kmf in enumerate(kmfs[::-1]): + survival_probs = kmf.survival_function_at_times(xticks).values + for j, prob in enumerate(survival_probs): + ax_table.text( + xticks[j], # x position + yticks[i], # y position + f"{prob:.2f}", # formatted survival probability + ha="center", + va="center", + bbox={"boxstyle": "round,pad=0.2", "edgecolor": "none", "facecolor": "lightgrey"}, + ) + + ax_table.grid(grid) + ax_table.spines["top"].set_visible(False) + ax_table.spines["right"].set_visible(False) + ax_table.spines["bottom"].set_visible(False) + ax_table.spines["left"].set_visible(False) if not show: return ax