Skip to content

Commit

Permalink
added kmf function legacy support in tests and added new kaplan_meier…
Browse files Browse the repository at this point in the history
… function in line with new signature
  • Loading branch information
aGuyLearning committed Nov 27, 2024
1 parent 6085e96 commit 16b7d5f
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 60 deletions.
2 changes: 1 addition & 1 deletion docs/tutorials/notebooks
1 change: 1 addition & 0 deletions ehrapy/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
cox_ph,
glm,
kmf,
kaplan_meier,
log_logistic_aft,
nelson_aalen,
ols,
Expand Down
147 changes: 91 additions & 56 deletions ehrapy/tools/_sa.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,80 @@ def glm(


def kmf(
    durations: Iterable,
    event_observed: Iterable | None = None,
    timeline: Iterable | None = None,
    entry: Iterable | None = None,
    label: str | None = None,
    alpha: float | None = None,
    ci_labels: tuple[str, str] | None = None,
    weights: Iterable | None = None,
    censoring: Literal["right", "left"] | None = None,
) -> KaplanMeierFitter:
    """Fit the Kaplan-Meier estimate for the survival function (deprecated).

    Deprecated legacy entry point that works on raw arrays; use `kaplan_meier` instead.

    The Kaplan–Meier estimator, also known as the product limit estimator, is a non-parametric statistic
    used to estimate the survival function from lifetime data.
    In medical research, it is often used to measure the fraction of patients living for a certain amount
    of time after treatment.

    See https://en.wikipedia.org/wiki/Kaplan%E2%80%93Meier_estimator
    https://lifelines.readthedocs.io/en/latest/fitters/univariate/KaplanMeierFitter.html#module-lifelines.fitters.kaplan_meier_fitter

    Args:
        durations: length n -- duration (relative to subject's birth) the subject was alive for.
        event_observed: True if the death was observed, False if the event was lost (right-censored).
            Defaults to all True if event_observed is equal to `None`.
        timeline: return the best estimate at the values in timelines (positively increasing)
        entry: Relative time when a subject entered the study. This is useful for left-truncated
            (not left-censored) observations.
            If None, all members of the population entered study when they were "born".
        label: A string to name the column of the estimate.
        alpha: The alpha value in the confidence intervals. Overrides the initializing alpha for this
            call to fit only.
        ci_labels: Add custom column names to the generated confidence intervals as a length-2 list:
            [<lower-bound name>, <upper-bound name>] (default: <label>_lower_<1-alpha/2>).
        weights: If providing a weighted dataset. For example, instead of providing every subject
            as a single element of `durations` and `event_observed`, one could weigh subject differently.
        censoring: 'right' for fitting the model to a right-censored dataset.
            'left' for fitting the model to a left-censored dataset.
            `None` (the default) fits the model to a right-censored dataset.

    Returns:
        Fitted KaplanMeierFitter.

    Raises:
        ValueError: If `censoring` is neither `None`, 'right' nor 'left'.

    Examples:
        >>> import ehrapy as ep
        >>> adata = ep.dt.mimic_2(encoded=False)
        >>> # Flip 'censor_flg' because 0 = death and 1 = censored
        >>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
        >>> kmf = ep.tl.kmf(adata[:, ["mort_day_censored"]].X, adata[:, ["censor_flg"]].X)
    """
    # stacklevel=2 attributes the warning to the caller's line instead of this module.
    warnings.warn(
        "This function is deprecated and will be removed in the next release. Use `kaplan_meier` instead.",
        DeprecationWarning,
        stacklevel=2,
    )

    fitter = KaplanMeierFitter()
    fit_kwargs = {
        "durations": durations,
        "event_observed": event_observed,
        "timeline": timeline,
        "entry": entry,
        "label": label,
        "alpha": alpha,
        "ci_labels": ci_labels,
        "weights": weights,
    }

    # The previous condition `censoring == "None" or "right"` was always truthy, which made the
    # left-censoring branch unreachable. The string "None" is still accepted because the old
    # comparison used it as a sentinel.
    if censoring is None or censoring in ("None", "right"):
        fitter.fit(**fit_kwargs)
    elif censoring == "left":
        fitter.fit_left_censoring(**fit_kwargs)
    else:
        raise ValueError(f"Unknown censoring type {censoring!r}. Use 'right' or 'left'.")

    return fitter

def kaplan_meier(
adata: AnnData,
duration_col: str,
event_col: str | None = None,
Expand All @@ -128,8 +202,6 @@ def kmf(
weights: list[float] | None = None,
fit_options: dict | None = None,
censoring: Literal["right", "left"] = "right",
durations: Iterable | None = None,
event_observed: Iterable | None = None,
) -> KaplanMeierFitter:
"""Fit the Kaplan-Meier estimate for the survival function.
Expand All @@ -154,9 +226,6 @@ def kmf(
fit_options: Additional keyword arguments to pass into the estimator.
censoring: 'right' for fitting the model to a right-censored dataset. (default, calls fit).
'left' for fitting the model to a left-censored dataset (calls fit_left_censoring).
durations: length n -- duration (relative to subject's birth) the subject was alive for. (legacy argument, use duration_col instead)
event_observed: True if the death was observed, False if the event was lost (right-censored). Defaults to all True if event_observed is equal to `None`.(this is a legacy argument, use event_col instead)
Returns:
Fitted KaplanMeierFitter.
Expand All @@ -166,57 +235,23 @@ def kmf(
>>> adata = ep.dt.mimic_2(encoded=False)
>>> # Flip 'censor_fl' because 0 = death and 1 = censored
>>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
>>> kmf = ep.tl.kmf(adata, "mort_day_censored", "censor_flg", label="Mortality")
>>> kmf = ep.tl.kaplan_meier(adata, "mort_day_censored", "censor_flg", label="Mortality")
"""
# legacy support
if durations is not None:
# legacy warning
warnings.warn(
"The `durations` and `event_observed` arguments are deprecated, please use `duration_col` and `event_col` instead.",
DeprecationWarning,
stacklevel=2,
)
kmf = KaplanMeierFitter()
if censoring == "None" or "right":
kmf.fit(
durations=durations,
event_observed=event_observed,
timeline=timeline,
entry=entry,
label=label,
alpha=alpha,
ci_labels=ci_labels,
weights=weights,
)
elif censoring == "left":
kmf.fit_left_censoring(
durations=durations,
event_observed=event_observed,
timeline=timeline,
entry=entry,
label=label,
alpha=alpha,
ci_labels=ci_labels,
weights=weights,
)

return kmf
else:
return _univariate_model(
adata,
duration_col,
event_col,
KaplanMeierFitter,
True,
timeline,
entry,
label,
alpha,
ci_labels,
weights,
fit_options,
censoring,
)
return _univariate_model(
adata,
duration_col,
event_col,
KaplanMeierFitter,
True,
timeline,
entry,
label,
alpha,
ci_labels,
weights,
fit_options,
censoring,
)


def test_kmf_logrank(
Expand Down Expand Up @@ -456,7 +491,7 @@ def _univariate_model(
def nelson_aalen(
adata: AnnData,
duration_col: str,
event_col: str,
event_col: str | None = None,
timeline: list[float] | None = None,
entry: str | None = None,
label: str | None = None,
Expand Down
11 changes: 8 additions & 3 deletions tests/tools/test_sa.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,14 @@ def _sa_func_test(self, sa_function, sa_class, mimic_2_sa):
self._sa_function_assert(sa, sa_class)

def test_kmf(self, mimic_2_sa):
    """Check that the legacy array-based `kmf` warns about its deprecation and still yields a fitter.

    `kmf` takes raw duration/event arrays, so the columns are sliced out of the AnnData object
    instead of passing the AnnData object and column names.
    """
    adata, _, _ = mimic_2_sa
    durations = adata[:, ["mort_day_censored"]].X
    event_observed = adata[:, ["censor_flg"]].X

    # Only the deprecated call lives inside the context so the warning is attributed to it.
    with pytest.warns(DeprecationWarning):
        kmf = ep.tl.kmf(durations, event_observed)

    self._sa_function_assert(kmf, KaplanMeierFitter)

def test_kaplan_meyer(self, mimic_2_sa):
    """Run the shared survival-analysis test helper against `ep.tl.kaplan_meier`."""
    fit_function = ep.tl.kaplan_meier
    expected_class = KaplanMeierFitter
    self._sa_func_test(fit_function, expected_class, mimic_2_sa)

def test_cox_ph(self, mimic_2_sa):
self._sa_func_test(ep.tl.cox_ph, CoxPHFitter, mimic_2_sa)
Expand Down

0 comments on commit 16b7d5f

Please sign in to comment.