diff --git a/ehrapy/plot/_survival_analysis.py b/ehrapy/plot/_survival_analysis.py index 948523fc..e3d2ae6b 100644 --- a/ehrapy/plot/_survival_analysis.py +++ b/ehrapy/plot/_survival_analysis.py @@ -186,7 +186,7 @@ def kmf( def kaplan_meier( kmfs: Sequence[KaplanMeierFitter], *, - diplay_table: bool = False, + display_table: bool = False, ci_alpha: list[float] | None = None, ci_force_lines: list[Boolean] | None = None, ci_show: list[Boolean] | None = None, @@ -208,7 +208,7 @@ def kaplan_meier( Args: kmfs: Iterables of fitted KaplanMeierFitter objects. - diplay_table: Display the survival probabilities in a table, below the plot. + display_table: Display the survival probabilities in a table, below the plot. ci_alpha: The transparency level of the confidence interval. If more than one kmfs, this should be a list. ci_force_lines: Force the confidence intervals to be line plots (versus default shaded areas). If more than one kmfs, this should be a list. @@ -302,7 +302,7 @@ def kaplan_meier( ax.set_title(title) # Prepare data for the table - if diplay_table: + if display_table: xticks = [x for x in ax.get_xticks() if x >= 0] xticks_space = xticks[1] - xticks[0] if xlabel is None: diff --git a/tests/_scripts/kaplain_meier_create_expected_plots.ipynb b/tests/_scripts/kaplain_meier_create_expected_plots.ipynb new file mode 100644 index 00000000..401bf639 --- /dev/null +++ b/tests/_scripts/kaplain_meier_create_expected_plots.ipynb @@ -0,0 +1,118 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "import ehrapy as ep\n", + "\n", + "current_notebook_dir = %pwd\n", + "_TEST_IMAGE_PATH = f\"{current_notebook_dir}/../plot/_images\"\n", + "mimic_2 = ep.dt.mimic_2(encoded=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "mimic_2[:, [\"censor_flg\"]].X = np.where(mimic_2[:, [\"censor_flg\"]].X == 0, 1, 0)\n", + "groups = mimic_2[:, [\"service_unit\"]].X\n", + "adata_ficu = mimic_2[groups == \"FICU\"]\n", + "adata_micu = mimic_2[groups == \"MICU\"]\n", + "kmf_1 = ep.tl.kaplan_meier(adata_ficu, duration_col=\"mort_day_censored\", event_col=\"censor_flg\", label=\"FICU\")\n", + "kmf_2 = ep.tl.kaplan_meier(adata_micu, duration_col=\"mort_day_censored\", event_col=\"censor_flg\", label=\"MICU\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = ep.pl.kaplan_meier(\n", + " [kmf_1, kmf_2],\n", + " ci_show=[False, False, False],\n", + " color=[\"k\", \"r\"],\n", + " xlim=[0, 750],\n", + " ylim=[0, 1],\n", + " xlabel=\"Days\",\n", + " ylabel=\"Proportion Survived\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "fig.savefig(f\"{_TEST_IMAGE_PATH}/kaplan_meier_expected.png\", dpi=80)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = ep.pl.kaplan_meier(\n", + " [kmf_1, kmf_2],\n", + " ci_show=[False, False, False],\n", + " color=[\"k\", \"r\"],\n", + " xlim=[0, 750],\n", + " ylim=[0, 1],\n", + " xlabel=\"Days\",\n", + " ylabel=\"Proportion Survived\",\n", + " display_table=True,\n", + " grid=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "fig.savefig(f\"{_TEST_IMAGE_PATH}/kaplan_meier_table_expected.png\", dpi=80)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.10" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/plot/_images/kaplan_meier_expected.png b/tests/plot/_images/kaplan_meier_expected.png new file mode 100644 index 00000000..cee1a8d0 Binary files /dev/null and b/tests/plot/_images/kaplan_meier_expected.png differ diff --git a/tests/plot/_images/kaplan_meier_table_expected.png b/tests/plot/_images/kaplan_meier_table_expected.png new file mode 100644 index 00000000..f14c58ac Binary files /dev/null and b/tests/plot/_images/kaplan_meier_table_expected.png differ diff --git a/tests/plot/test_survival_analysis.py b/tests/plot/test_survival_analysis.py index 2345102a..782f63b1 100644 --- a/tests/plot/test_survival_analysis.py +++ b/tests/plot/test_survival_analysis.py @@ -1,11 +1,55 @@ from pathlib import Path +import numpy as np + import ehrapy as ep CURRENT_DIR = Path(__file__).parent _TEST_IMAGE_PATH = f"{CURRENT_DIR}/_images" +def test_kaplan_meier(mimic_2, check_same_image): + mimic_2[:, ["censor_flg"]].X = np.where(mimic_2[:, ["censor_flg"]].X == 0, 1, 0) + groups = mimic_2[:, ["service_unit"]].X + adata_ficu = mimic_2[groups == "FICU"] + adata_micu = mimic_2[groups == "MICU"] + kmf_1 = ep.tl.kaplan_meier(adata_ficu, duration_col="mort_day_censored", event_col="censor_flg", label="FICU") + kmf_2 = ep.tl.kaplan_meier(adata_micu, duration_col="mort_day_censored", event_col="censor_flg", label="MICU") + fig, ax = ep.pl.kaplan_meier( + [kmf_1, kmf_2], + ci_show=[False, False, False], + color=["k", "r"], + xlim=[0, 750], + ylim=[0, 1], + xlabel="Days", + ylabel="Proportion Survived", + ) + + check_same_image( + fig=fig, + base_path=f"{_TEST_IMAGE_PATH}/kaplan_meier", + tol=2e-1, + ) + + fig, ax = ep.pl.kaplan_meier( + [kmf_1, kmf_2], + ci_show=[False, False, False], + color=["k", "r"], + xlim=[0, 750], + ylim=[0, 1], + xlabel="Days", + ylabel="Proportion Survived", + grid=True, + display_table=True, + ) + + check_same_image( + fig=fig, + base_path=f"{_TEST_IMAGE_PATH}/kaplan_meier_table", + tol=2e-1, + ) + + def test_coxph_forestplot(mimic_2, check_same_image): adata_subset = mimic_2[:, ["mort_day_censored", "censor_flg", "gender_num", "afib_flg", "day_icu_intime_num"]] ep.tl.cox_ph(adata_subset, duration_col="mort_day_censored", event_col="censor_flg")