Skip to content

Commit

Permalink
Added tree printing for inferred feature types
Browse files Browse the repository at this point in the history
  • Loading branch information
Lilly-May committed Apr 19, 2024
1 parent 39128fd commit 193286d
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 3 deletions.
2 changes: 1 addition & 1 deletion ehrapy/anndata/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@


FEATURE_TYPE_KEY = "feature_type"
CONTINUOUS_TAG = "continuous"
CONTINUOUS_TAG = "numeric" # TODO: Eventually rename to NUMERIC_TAG (as soon as the other NUMERIC_TAG is removed)
CATEGORICAL_TAG = "categorical"
DATE_TAG = "date"
36 changes: 34 additions & 2 deletions ehrapy/anndata/_feature_specifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import numpy as np
import pandas as pd
from rich import print
from rich.tree import Tree

from ehrapy import logging as logg
from ehrapy.anndata._constants import CATEGORICAL_TAG, CONTINUOUS_TAG, DATE_TAG, FEATURE_TYPE_KEY
Expand Down Expand Up @@ -41,11 +43,11 @@ def infer_feature_types(adata, layer: str | None = None, output: Literal["print"
adata.var[FEATURE_TYPE_KEY] = pd.Series(feature_types)[adata.var_names]

logg.info(
f"Feature types have been inferred and stored in adata.var[FEATURE_TYPE_KEY]. PLEASE CHECK and adjust if necessary using adata.var[{FEATURE_TYPE_KEY}]['feature1']='corrected_type'."
f"Feature types have been inferred and stored in adata.var[{FEATURE_TYPE_KEY}]. PLEASE CHECK and adjust if necessary using adata.var[{FEATURE_TYPE_KEY}]['feature1']='corrected_type'."
)

if output == "print":
print(adata.var[FEATURE_TYPE_KEY]) # TODO: Use ep.ad.type_overview
feature_type_overview(adata)
elif output == "dataframe":
return adata.var[FEATURE_TYPE_KEY]
elif output is not None:
Expand All @@ -60,3 +62,33 @@ def wrapper(adata, *args, **kwargs):
return func(adata, *args, **kwargs)

return wrapper


@check_feature_types
def feature_type_overview(adata):
"""
Print an overview of the feature types in the AnnData object.
Args:
adata: :class:`~anndata.A
"""
tree = Tree(
f"Detected feature types for AnnData object with {len(adata.obs_names)} obs and {len(adata.var_names)} vars",
guide_style="underline2 bright_blue",
)

branch = tree.add("📅 Date features", style="b green")
for date in sorted(adata.var_names[adata.var[FEATURE_TYPE_KEY] == DATE_TAG]):
branch.add(date)

branch = tree.add("📏 Numerical features", style="b green")
for numeric in sorted(adata.var_names[adata.var[FEATURE_TYPE_KEY] == CONTINUOUS_TAG]):
branch.add(numeric)

branch = tree.add("🗂️ Categorical features", style="b green")
cat_features = adata.var_names[adata.var[FEATURE_TYPE_KEY] == CATEGORICAL_TAG]
df = anndata_to_df(adata[:, cat_features])
for categorical in sorted(cat_features):
branch.add(f"{categorical} ({df.loc[:, categorical].nunique()} categories)")

print(tree)
1 change: 1 addition & 0 deletions ehrapy/anndata/anndata_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ def _adata_type_overview(
f"[b green]Variable names for AnnData object with {len(adata.obs_names)} obs and {len(adata.var_names)} vars",
guide_style="underline2 bright_blue",
)

if "var_to_encoding" in adata.uns.keys():
original_values = adata.uns["original_values_categoricals"]
branch = tree.add("🔐 Encoded variables", style="b green")
Expand Down

0 comments on commit 193286d

Please sign in to comment.