Skip to content

Commit

Permalink
prepare type detection for alignment, added test
Browse files Browse the repository at this point in the history
  • Loading branch information
eroell committed Mar 5, 2024
1 parent 232c2b1 commit 4090f6f
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 13 deletions.
18 changes: 5 additions & 13 deletions ehrapy/tools/cohort_tracking/_cohort_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,11 @@ def _check_columns_exist(df, columns) -> None:

# from tableone: https://github.com/tompollard/tableone/blob/bfd6fbaa4ed3e9f59e1a75191c6296a2a80ccc64/tableone/tableone.py#L555
def _detect_categorical_columns(data) -> list:
# assume all non-numerical and date columns are categorical
numeric_cols = set(data._get_numeric_data().columns.values)
date_cols = set(data.select_dtypes(include=[np.datetime64]).columns)
likely_cat = set(data.columns) - numeric_cols
# mypy absolutely looses it if likely_cat is overwritten to be a list
likely_cat_no_dates = list(likely_cat - date_cols)

# check proportion of unique values if numerical
for var in data._get_numeric_data().columns:
likely_flag = 1.0 * data[var].nunique() / data[var].count() < 0.005
if likely_flag:
likely_cat_no_dates.append(var)
return likely_cat_no_dates
# TODO grab this from ehrapy once https://github.com/theislab/ehrapy/issues/662 addressed
numeric_cols = set(data.select_dtypes("number").columns)
categorical_cols = set(data.columns) - numeric_cols

return list(categorical_cols)


class CohortTracker:
Expand Down
5 changes: 5 additions & 0 deletions tests/tools/cohort_tracking/test_cohort_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ def test_CohortTracker_init_vanilla(columns, mini_adata):
assert ct._tracked_operations == []


def test_CohortTracker_type_detection(mini_adata):
ct = ep.tl.CohortTracker(mini_adata, ["glucose", "weight", "disease", "station"])
assert set(ct.categorical) == {"disease", "station"}


def test_CohortTracker_init_set_columns(mini_adata):
# limit columns
ep.tl.CohortTracker(mini_adata, columns=["glucose", "disease"])
Expand Down

0 comments on commit 4090f6f

Please sign in to comment.