Skip to content

Commit

Permalink
add more sa models
Browse files Browse the repository at this point in the history
  • Loading branch information
fatisati committed Mar 1, 2024
1 parent 20ca992 commit e781df0
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 18 deletions.
2 changes: 1 addition & 1 deletion ehrapy/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ehrapy.tools._sa import anova_glm, cox_ph, glm, kmf, ols, test_kmf_logrank, test_nested_f_statistic
from ehrapy.tools._sa import anova_glm, cox_ph, glm, kmf, ols, test_kmf_logrank, test_nested_f_statistic, nelson_alen, weibull, weibull_aft, log_rogistic_aft
from ehrapy.tools._scanpy_tl_api import * # noqa: F403
from ehrapy.tools.causal._dowhy import causal_inference
from ehrapy.tools.feature_ranking._rank_features_groups import filter_rank_features_groups, rank_features_groups
Expand Down
40 changes: 27 additions & 13 deletions ehrapy/tools/_sa.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pandas as pd
import statsmodels.api as sm
import statsmodels.formula.api as smf
from lifelines import CoxPHFitter, KaplanMeierFitter, NelsonAalenFitter
from lifelines import CoxPHFitter, KaplanMeierFitter, NelsonAalenFitter, WeibullFitter, WeibullAFTFitter, LogLogisticAFTFitter
from lifelines.statistics import StatisticalResult, logrank_test
from scipy import stats

Expand Down Expand Up @@ -266,6 +266,16 @@ def anova_glm(result_1: GLMResultsWrapper, result_2: GLMResultsWrapper, formula_
dataframe = pd.DataFrame(data=table)
return dataframe

def regression_model(model_class, adata: AnnData, duration_col: str, event_col: str, entry_col: str = None):
df = anndata_to_df(adata)
keys = [duration_col, event_col]
if entry_col:
keys.append(entry_col)
df = df[keys]
model = model_class()
model.fit(df, duration_col, event_col, entry_col=entry_col)

return model

def cox_ph(adata: AnnData, duration_col: str, event_col: str, entry_col: str = None) -> CoxPHFitter:
"""Fit the Cox’s proportional hazard for the survival function.
Expand All @@ -291,22 +301,26 @@ def cox_ph(adata: AnnData, duration_col: str, event_col: str, entry_col: str = N
>>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
>>> cph = ep.tl.cox_ph(adata, "mort_day_censored", "censor_flg")
"""
df = anndata_to_df(adata)
keys = [duration_col, event_col]
if entry_col:
keys.append(entry_col)
df = df[keys]
cph = CoxPHFitter()
cph.fit(df, duration_col, event_col, entry_col=entry_col)
return regression_model(CoxPHFitter, adata, duration_col, event_col, entry_col)

return cph
def weibull_aft(adata: AnnData, duration_col: str, event_col: str, entry_col: str = None) -> WeibullAFTFitter:
return regression_model(WeibullAFTFitter, adata, duration_col, event_col, entry_col)

def nelson_alen(adata: AnnData, duration_col: str, event_col: str):
def log_rogistic_aft(adata: AnnData, duration_col: str, event_col: str, entry_col: str = None) -> LogLogisticAFTFitter:
return regression_model(LogLogisticAFTFitter, adata, duration_col, event_col, entry_col)

def univariate_model(adata: AnnData, duration_col: str, event_col: str, model_class):
df = anndata_to_df(adata)
T = df[duration_col]
E = df[event_col]

naf = NelsonAalenFitter()
model = model_class()

model.fit(T,event_observed=E)
return model

def nelson_alen(adata: AnnData, duration_col: str, event_col: str):
return univariate_model(adata, duration_col, event_col, NelsonAalenFitter)

naf.fit(T,event_observed=E)
return naf
def weibull(adata: AnnData, duration_col: str, event_col: str):
return univariate_model(adata, duration_col, event_col, WeibullFitter)
17 changes: 13 additions & 4 deletions tests/tools/test_sa.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import pytest
import statsmodels
from lifelines import CoxPHFitter, KaplanMeierFitter, NelsonAalenFitter
from lifelines import CoxPHFitter, KaplanMeierFitter, NelsonAalenFitter, WeibullFitter, WeibullAFTFitter, LogLogisticAFTFitter

import ehrapy as ep

Expand Down Expand Up @@ -82,10 +82,19 @@ def sa_func_test(self, sa_function, sa_class):
assert sum(sa.event_observed) == 497

def test_kmf(self):
self.sa_func_test(self.ep.tl.kmf, KaplanMeierFitter)
self.sa_func_test(ep.tl.kmf, KaplanMeierFitter)

def test_cox_ph(self):
self.sa_func_test(self.ep.tl.cox_ph, CoxPHFitter)
self.sa_func_test(ep.tl.cox_ph, CoxPHFitter)

def test_nelson_alen(self):
self.sa_func_test(self.ep.tl.nelson_alen, NelsonAalenFitter)
self.sa_func_test(ep.tl.nelson_alen, NelsonAalenFitter)

def test_weibull(self):
self.sa_func_test(ep.tl.weibull, WeibullFitter)

def test_weibull_aft(self):
self.sa_func_test(ep.tl.weibull_aft, WeibullAFTFitter)

def test_log_logistic(self):
self.sa_func_test(ep.tl.log_rogistic_aft, LogLogisticAFTFitter)

0 comments on commit e781df0

Please sign in to comment.