diff --git a/ehrapy/anndata/_constants.py b/ehrapy/anndata/_constants.py new file mode 100644 index 00000000..1e855aad --- /dev/null +++ b/ehrapy/anndata/_constants.py @@ -0,0 +1,8 @@ +# Typing Column +# ----------------------- +# The column name and used values in adata.var for column types. + +EHRAPY_TYPE_KEY = "ehrapy_column_type" +NUMERIC_TAG = "numeric" +NON_NUMERIC_TAG = "non_numeric" +NON_NUMERIC_ENCODED_TAG = "non_numeric_encoded" diff --git a/ehrapy/anndata/anndata_ext.py b/ehrapy/anndata/anndata_ext.py index 4464fa4f..23984079 100644 --- a/ehrapy/anndata/anndata_ext.py +++ b/ehrapy/anndata/anndata_ext.py @@ -16,6 +16,7 @@ from scipy.sparse import issparse from ehrapy import logging as logg +from ehrapy.anndata._constants import EHRAPY_TYPE_KEY, NON_NUMERIC_ENCODED_TAG, NON_NUMERIC_TAG, NUMERIC_TAG if TYPE_CHECKING: from collections.abc import Collection, Iterable, Sequence @@ -93,11 +94,15 @@ def df_to_anndata( # initializing an OrderedDict with a non-empty dict might not be intended, # see: https://stackoverflow.com/questions/25480089/right-way-to-initialize-an-ordereddict-using-its-constructor-such-that-it-retain/25480206 - uns = OrderedDict() + uns = OrderedDict() # type: ignore # store all numerical/non-numerical columns that are not obs only binary_columns = _detect_binary_columns(df, numerical_columns) - uns["numerical_columns"] = list(set(numerical_columns) | set(binary_columns)) - uns["non_numerical_columns"] = list(set(dataframes.df.columns) ^ set(uns["numerical_columns"])) + + var = pd.DataFrame(index=list(dataframes.df.columns)) + var[EHRAPY_TYPE_KEY] = NON_NUMERIC_TAG + var.loc[var.index.isin(list(set(numerical_columns) | set(binary_columns))), EHRAPY_TYPE_KEY] = NUMERIC_TAG + # in case of encoded columns by ehrapy, want to be able to read it back in + var.loc[var.index.str.contains("ehrapycat"), EHRAPY_TYPE_KEY] = NON_NUMERIC_ENCODED_TAG all_num = True if len(numerical_columns) == len(list(dataframes.df.columns)) else False X = X.astype(np.number) if all_num else X.astype(object) @@ -105,7 +110,7 @@ def df_to_anndata( adata = AnnData( X=X, obs=_cast_obs_columns(dataframes.obs), - var=pd.DataFrame(index=list(dataframes.df.columns)), + var=var, layers={"original": X.copy()}, uns=uns, ) @@ -202,37 +207,41 @@ def move_to_obs(adata: AnnData, to_obs: list[str] | str, copy_obs: bool = False) f"Columns `{[col for col in to_obs if col not in adata.var_names.values]}` are not in var_names." ) + cols_to_obs_indices = adata.var_names.isin(to_obs) + + num_set = _get_var_indices_for_type(adata, NUMERIC_TAG) + var_num = list(set(to_obs) & set(num_set)) + if copy_obs: - cols_to_obs_indices = adata.var_names.isin(to_obs) cols_to_obs = adata[:, cols_to_obs_indices].to_df() adata.obs = adata.obs.join(cols_to_obs) - num_set = set(adata.uns["numerical_columns"].copy()) - non_num_set = set(adata.uns["non_numerical_columns"].copy()) - var_num = [] - var_non_num = [] - for var in to_obs: - if var in num_set: - var_num.append(var) - elif var in non_num_set: - var_non_num.append(var) adata.obs[var_num] = adata.obs[var_num].apply(pd.to_numeric, errors="ignore", downcast="float") adata.obs = _cast_obs_columns(adata.obs) else: - cols_to_obs_indices = adata.var_names.isin(to_obs) df = adata[:, cols_to_obs_indices].to_df() adata._inplace_subset_var(~cols_to_obs_indices) adata.obs = adata.obs.join(df) - updated_num_uns, updated_non_num_uns, num_var = _update_uns(adata, to_obs) - adata.obs[num_var] = adata.obs[num_var].apply(pd.to_numeric, errors="ignore", downcast="float") + adata.obs[var_num] = adata.obs[var_num].apply(pd.to_numeric, errors="ignore", downcast="float") adata.obs = _cast_obs_columns(adata.obs) - adata.uns["numerical_columns"] = updated_num_uns - adata.uns["non_numerical_columns"] = updated_non_num_uns logg.info(f"Added `{to_obs}` to `obs`.") return adata +def _get_var_indices_for_type(adata: AnnData, tag: str) -> list[str]: + """Get indices of columns in var for a given tag. + + Args: + adata: The AnnData object + tag: The tag to search for, should be one of `NUMERIC_TAG`, `NON_NUMERIC_TAG` or `NON_NUMERIC_ENCODED_TAG` + + Returns: + List of numeric columns + """ + return adata.var_names[adata.var[EHRAPY_TYPE_KEY] == tag].tolist() + + def delete_from_obs(adata: AnnData, to_delete: list[str]) -> AnnData: """Delete features from obs. @@ -305,11 +314,11 @@ def move_to_x(adata: AnnData, to_x: list[str] | str) -> AnnData: if cols_not_in_x: new_adata = concat([adata, AnnData(adata.obs[cols_not_in_x])], axis=1) new_adata.obs = adata.obs[adata.obs.columns[~adata.obs.columns.isin(cols_not_in_x)]] - # update uns (copy maybe: could be a costly operation but reduces reference cycles) - # users might save those as separate AnnData object and this could be unexpected behaviour if we dont copy - num_columns_moved, non_num_columns_moved, _ = _update_uns(adata, cols_not_in_x, True) - new_adata.uns["numerical_columns"] = adata.uns["numerical_columns"] + num_columns_moved - new_adata.uns["non_numerical_columns"] = adata.uns["non_numerical_columns"] + non_num_columns_moved + + # AnnData's concat discards var if they dont match in their keys, so we need to create a new var + created_var = _create_new_var(adata, cols_not_in_x) + new_adata.var = pd.concat([adata.var, created_var], axis=0) + logg.info(f"Added `{cols_not_in_x}` features to `X`.") else: new_adata = adata @@ -486,10 +495,11 @@ def get_numeric_vars(adata: AnnData) -> list[str]: """ _assert_encoded(adata) - if "numerical_columns" not in adata.uns_keys(): + # This behaviour is consistent with the previous behaviour, allowing for a simple fully numeric X + if EHRAPY_TYPE_KEY not in adata.var.columns: return list(adata.var_names.values) else: - return adata.uns["numerical_columns"] + return _get_var_indices_for_type(adata, NUMERIC_TAG) def assert_numeric_vars(adata: AnnData, vars: Sequence[str]): @@ -579,6 +589,25 @@ def _update_uns( return all_moved_num_columns, list(all_moved_non_num_columns), None +def _create_new_var(adata: AnnData, cols_not_in_x: list[str]) -> pd.DataFrame: + """Create a new var DataFrame with the EHRAPY_TYPE_KEY column set for entries from .obs. + + Args: + adata: From where to get the .obs + cols_not_in_x: .obs columns to move to X + + Returns: + New var DataFrame with EHRAPY_TYPE_KEY column set for entries from .obs + """ + all_moved_num_columns = set(cols_not_in_x) & set(adata.obs.select_dtypes(include="number").columns) + + new_var = pd.DataFrame(index=cols_not_in_x) + new_var[EHRAPY_TYPE_KEY] = NON_NUMERIC_TAG + new_var.loc[list(all_moved_num_columns), EHRAPY_TYPE_KEY] = NUMERIC_TAG + + return new_var + + def _detect_binary_columns(df: pd.DataFrame, numerical_columns: list[str]) -> list[str]: """Detect all columns that contain only 0 and 1 (besides NaNs). diff --git a/ehrapy/preprocessing/_encode.py b/ehrapy/preprocessing/_encode.py index 7abb4380..5fd247b3 100644 --- a/ehrapy/preprocessing/_encode.py +++ b/ehrapy/preprocessing/_encode.py @@ -14,7 +14,8 @@ from sklearn.preprocessing import LabelEncoder, OneHotEncoder from ehrapy import logging as logg -from ehrapy.anndata.anndata_ext import _update_uns +from ehrapy.anndata._constants import EHRAPY_TYPE_KEY, NON_NUMERIC_ENCODED_TAG, NON_NUMERIC_TAG, NUMERIC_TAG +from ehrapy.anndata.anndata_ext import _get_var_indices_for_type multi_encoding_modes = {"hash"} available_encodings = {"one-hot", "label", "count", *multi_encoding_modes} @@ -143,8 +144,8 @@ def _encode( "[bold yellow]The current AnnData object has been already encoded. Returning original AnnData object!" ) return adata + categoricals_names = _get_var_indices_for_type(adata, NON_NUMERIC_TAG) - categoricals_names = adata.uns["non_numerical_columns"] # no columns were detected, that would require an encoding (e.g. non numerical columns) if not categoricals_names: print("[bold yellow]Detected no columns that need to be encoded. Leaving passed AnnData object unchanged.") @@ -194,33 +195,30 @@ def _encode( ) progress.update(task, description=f"[bold blue]Finished {encodings} of autodetected columns.") + # copy non-encoded columns, and add new tag for encoded columns. This is needed to track encodings + new_var = pd.DataFrame(index=encoded_var_names) + new_var[EHRAPY_TYPE_KEY] = adata.var[EHRAPY_TYPE_KEY].copy() + new_var.loc[new_var.index.str.contains("ehrapycat")] = NON_NUMERIC_ENCODED_TAG + encoded_ann_data = AnnData( encoded_x, obs=adata.obs.copy(), - var={"var_names": encoded_var_names}, + var=new_var, uns=orig_uns_copy, layers={"original": updated_layer}, ) encoded_ann_data.uns["var_to_encoding"] = {categorical: encodings for categorical in categoricals_names} encoded_ann_data.uns["encoding_to_var"] = {encodings: categoricals_names} - encoded_ann_data.uns["numerical_columns"] = adata.uns["numerical_columns"].copy() - encoded_ann_data.uns["non_numerical_columns"] = [] - encoded_ann_data.uns["encoded_non_numerical_columns"] = [ - column for column in encoded_ann_data.var_names if column.startswith("ehrapycat_") - ] - _add_categoricals_to_obs(adata, encoded_ann_data, categoricals_names) # user passed categorical values with encoding mode for each of them else: - # Required since this would be deleted through side references - non_numericals = adata.uns["non_numerical_columns"].copy() # reencode data if "var_to_encoding" in adata.uns.keys(): encodings = _reorder_encodings(adata, encodings) # type: ignore adata = _undo_encoding(adata, "all") - adata.uns["non_numerical_columns"] = non_numericals + # are all specified encodings valid? for encoding in encodings.keys(): # type: ignore if encoding not in available_encodings: @@ -246,7 +244,7 @@ def _encode( "The categorical column names given contain at least one duplicate column. " "Check the column names to ensure that no column is encoded twice!" ) - elif any(cat in adata.uns["numerical_columns"] for cat in categoricals): + elif any(cat in adata.var_names[adata.var[EHRAPY_TYPE_KEY] == NUMERIC_TAG] for cat in categoricals): print( "[bold yellow]At least one of passed column names seems to have numerical dtype. In general it is not recommended " "to encode numerical columns!" @@ -298,11 +296,17 @@ def _encode( adata.var_names.to_list(), categoricals, ) + + # copy non-encoded columns, and add new tag for encoded columns. This is needed to track encodings + new_var = pd.DataFrame(index=encoded_var_names) + new_var[EHRAPY_TYPE_KEY] = adata.var[EHRAPY_TYPE_KEY].copy() + new_var.loc[new_var.index.str.contains("ehrapycat")] = NON_NUMERIC_ENCODED_TAG + try: encoded_ann_data = AnnData( X=encoded_x, obs=adata.obs.copy(), - var={"var_names": encoded_var_names}, + var=new_var, uns=orig_uns_copy, layers={"original": updated_layer}, ) @@ -315,12 +319,6 @@ def _encode( "Creation of AnnData object failed. Ensure that you passed all non numerical, " "categorical values for encoding!" ) from None - updated_num_uns, updated_non_num_uns, _ = _update_uns(adata, categoricals) - encoded_ann_data.uns["numerical_columns"] = updated_num_uns - encoded_ann_data.uns["non_numerical_columns"] = updated_non_num_uns - encoded_ann_data.uns["encoded_non_numerical_columns"] = [ - column for column in encoded_ann_data.var_names if column.startswith("ehrapycat_") - ] _add_categoricals_to_obs(adata, encoded_ann_data, categoricals) @@ -686,7 +684,8 @@ def _undo_encoding( new_obs = adata.obs[columns_obs_only] uns = OrderedDict() # reset uns and keep numerical/non-numerical columns - num_vars, non_num_vars = adata.uns["numerical_columns"], adata.uns["non_numerical_columns"] + num_vars = _get_var_indices_for_type(adata, NUMERIC_TAG) + non_num_vars = _get_var_indices_for_type(adata, NON_NUMERIC_TAG) for cat in categoricals: original_values = adata.uns["original_values_categoricals"][cat] type_first_nan = original_values[np.where(original_values != np.nan)][0] @@ -695,6 +694,11 @@ def _undo_encoding( else: non_num_vars.append(cat) + var = pd.DataFrame(index=new_var_names) + var[EHRAPY_TYPE_KEY] = NON_NUMERIC_TAG + # Notice previously encoded columns are now newly added, and will stay tagged as non numeric + var.loc[num_vars, EHRAPY_TYPE_KEY] = NUMERIC_TAG + uns["numerical_columns"] = num_vars uns["non_numerical_columns"] = non_num_vars @@ -703,7 +707,7 @@ def _undo_encoding( return AnnData( new_x, obs=new_obs, - var=pd.DataFrame(index=new_var_names), + var=var, uns=uns, layers={"original": new_x.copy()}, ) @@ -877,7 +881,7 @@ def _add_categoricals_to_uns(original: AnnData, new: AnnData, categorical_names: continue elif var_name in categorical_names: # keep numerical dtype when writing original values to uns - if var_name in original.uns["numerical_columns"]: + if var_name in original.var_names[original.var[EHRAPY_TYPE_KEY] == NUMERIC_TAG]: new["original_values_categoricals"][var_name] = original.X[::, idx : idx + 1].astype("float") else: new["original_values_categoricals"][var_name] = original.X[::, idx : idx + 1].astype("str") diff --git a/ehrapy/preprocessing/_imputation.py b/ehrapy/preprocessing/_imputation.py index 7e23d488..836778ef 100644 --- a/ehrapy/preprocessing/_imputation.py +++ b/ehrapy/preprocessing/_imputation.py @@ -13,6 +13,7 @@ from ehrapy import logging as logg from ehrapy import settings +from ehrapy.anndata._constants import EHRAPY_TYPE_KEY, NON_NUMERIC_TAG from ehrapy.anndata.anndata_ext import _get_column_indices from ehrapy.core._tool_available import _check_module_importable @@ -252,7 +253,7 @@ def knn_impute( else: # ordinal encoding is used since non-numerical data can not be imputed using KNN Imputation enc = OrdinalEncoder() - column_indices = _get_column_indices(adata, adata.uns["non_numerical_columns"]) + column_indices = adata.var[EHRAPY_TYPE_KEY] == NON_NUMERIC_TAG adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices]) # impute the data using KNN imputation _knn_impute(adata, var_names, n_neighbours) @@ -513,7 +514,7 @@ def soft_impute( else: # ordinal encoding is used since non-numerical data can not be imputed using SoftImpute enc = OrdinalEncoder() - column_indices = _get_column_indices(adata, adata.uns["non_numerical_columns"]) + column_indices = adata.var[EHRAPY_TYPE_KEY] == NON_NUMERIC_TAG adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices]) # impute the data using SoftImpute _soft_impute( @@ -669,7 +670,7 @@ def iterative_svd_impute( else: # ordinal encoding is used since non-numerical data can not be imputed using IterativeSVD enc = OrdinalEncoder() - column_indices = _get_column_indices(adata, adata.uns["non_numerical_columns"]) + column_indices = adata.var[EHRAPY_TYPE_KEY] == NON_NUMERIC_TAG adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices]) # impute the data using IterativeSVD _iterative_svd_impute( @@ -809,7 +810,7 @@ def matrix_factorization_impute( else: # ordinal encoding is used since non-numerical data can not be imputed using MatrixFactorization enc = OrdinalEncoder() - column_indices = _get_column_indices(adata, adata.uns["non_numerical_columns"]) + column_indices = adata.var[EHRAPY_TYPE_KEY] == NON_NUMERIC_TAG adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices]) # impute the data using MatrixFactorization _matrix_factorization_impute( @@ -931,7 +932,7 @@ def nuclear_norm_minimization_impute( else: # ordinal encoding is used since non-numerical data can not be imputed using NuclearNormMinimization enc = OrdinalEncoder() - column_indices = _get_column_indices(adata, adata.uns["non_numerical_columns"]) + column_indices = adata.var[EHRAPY_TYPE_KEY] == NON_NUMERIC_TAG adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices]) # impute the data using NuclearNormMinimization _nuclear_norm_minimization_impute( @@ -1057,7 +1058,7 @@ def mice_forest_impute( else: # ordinal encoding is used since non-numerical data can not be imputed using miceforest enc = OrdinalEncoder() - column_indices = _get_column_indices(adata, adata.uns["non_numerical_columns"]) + column_indices = adata.var[EHRAPY_TYPE_KEY] == NON_NUMERIC_TAG adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices]) # impute the data using miceforest _miceforest_impute( diff --git a/ehrapy/tools/feature_ranking/_rank_features_groups.py b/ehrapy/tools/feature_ranking/_rank_features_groups.py index c35088ad..dac5f009 100644 --- a/ehrapy/tools/feature_ranking/_rank_features_groups.py +++ b/ehrapy/tools/feature_ranking/_rank_features_groups.py @@ -8,6 +8,7 @@ import scanpy as sc from ehrapy.anndata import move_to_x +from ehrapy.anndata._constants import EHRAPY_TYPE_KEY, NON_NUMERIC_ENCODED_TAG, NUMERIC_TAG from ehrapy.preprocessing import encode if TYPE_CHECKING: @@ -201,8 +202,7 @@ def _evaluate_categorical_features( groups_order = _get_groups_order(groups_subset=groups, group_names=group_names, reference=reference) groups_values = adata.obs[groupby].to_numpy() - - for feature in adata.uns["encoded_non_numerical_columns"]: + for feature in adata.var_names[adata.var[EHRAPY_TYPE_KEY] == NON_NUMERIC_ENCODED_TAG]: if feature == groupby or "ehrapycat_" + feature == groupby or feature == "ehrapycat_" + groupby: continue @@ -422,42 +422,20 @@ def rank_features_groups( else adata[:, columns_to_rank["var_names"]].layers[layer] ) var_to_keep = adata[:, columns_to_rank["var_names"]].var - uns_num_to_keep = _get_intersection( - adata_uns=adata.uns, key="numerical_columns", selection=columns_to_rank["var_names"] - ) - uns_non_num_to_keep = _get_intersection( - adata_uns=adata.uns, key="non_numerical_columns", selection=columns_to_rank["var_names"] - ) - uns_enc_to_keep = _get_intersection( - adata_uns=adata.uns, key="encoded_non_numerical_columns", selection=columns_to_rank["var_names"] - ) else: X_to_keep = adata.X if layer is None else adata.layers[layer] var_to_keep = adata.var - uns_num_to_keep = adata.uns["numerical_columns"] if "numerical_columns" in adata.uns else [] - uns_enc_to_keep = ( - adata.uns["encoded_non_numerical_columns"] if "encoded_non_numerical_columns" in adata.uns else [] - ) - uns_non_num_to_keep = adata.uns["non_numerical_columns"] if "non_numerical_columns" in adata.uns else [] else: # dummy 1-dimensional X to be used by move_to_x, and removed again afterwards X_to_keep = np.zeros((len(adata), 1)) var_to_keep = pd.DataFrame({"dummy": [0]}) - uns_num_to_keep = [] - uns_enc_to_keep = [] - uns_non_num_to_keep = [] adata_minimal = sc.AnnData( X=X_to_keep, obs=adata.obs, var=var_to_keep, - uns={ - "numerical_columns": uns_num_to_keep, - "encoded_non_numerical_columns": uns_enc_to_keep, - "non_numerical_columns": uns_non_num_to_keep, - }, ) if field_to_rank in ["obs", "layer_and_obs"]: @@ -501,12 +479,12 @@ def rank_features_groups( group_names = pd.Categorical(adata.obs[groupby].astype(str)).categories.tolist() - if adata.uns["numerical_columns"]: + if list(adata.var_names[adata.var[EHRAPY_TYPE_KEY] == NUMERIC_TAG]): # Rank numerical features # Without copying `numerical_adata` is a view, and code throws an error # because of "object" type of .X - numerical_adata = adata[:, adata.uns["numerical_columns"]].copy() + numerical_adata = adata[:, adata.var_names[adata.var[EHRAPY_TYPE_KEY] == NUMERIC_TAG]].copy() numerical_adata.X = numerical_adata.X.astype(float) sc.tl.rank_genes_groups( @@ -539,7 +517,7 @@ def rank_features_groups( groups_order=group_names, ) - if adata.uns["encoded_non_numerical_columns"]: + if list(adata.var_names[adata.var[EHRAPY_TYPE_KEY] == NON_NUMERIC_ENCODED_TAG]): ( categorical_names, categorical_scores, diff --git a/tests/anndata/test_anndata_ext.py b/tests/anndata/test_anndata_ext.py index a848e377..3ce75f5c 100644 --- a/tests/anndata/test_anndata_ext.py +++ b/tests/anndata/test_anndata_ext.py @@ -11,6 +11,7 @@ from pandas.testing import assert_frame_equal import ehrapy as ep +from ehrapy.anndata._constants import EHRAPY_TYPE_KEY, NON_NUMERIC_ENCODED_TAG, NON_NUMERIC_TAG, NUMERIC_TAG from ehrapy.anndata.anndata_ext import ( NotEncodedError, _assert_encoded, @@ -87,8 +88,20 @@ def test_move_to_x(self): assert set(new_adata_num.obs.columns) == {"name"} assert {str(col) for col in new_adata_num.obs.dtypes} == {"category"} assert {str(col) for col in new_adata_non_num.obs.dtypes} == {"float32", "category"} - assert len(sum(list(new_adata_num.uns.values()), [])) == len(list(new_adata_num.var_names)) - assert len(sum(list(new_adata_non_num.uns.values()), [])) == len(list(new_adata_non_num.var_names)) + assert_frame_equal( + new_adata_non_num.var, + DataFrame( + {EHRAPY_TYPE_KEY: [NUMERIC_TAG, NUMERIC_TAG, NON_NUMERIC_TAG]}, + index=["los_days", "b12_values", "name"], + ), + ) + assert_frame_equal( + new_adata_num.var, + DataFrame( + {EHRAPY_TYPE_KEY: [NUMERIC_TAG, NUMERIC_TAG, NON_NUMERIC_TAG, NUMERIC_TAG]}, + index=["los_days", "b12_values", "name", "clinic_id"], + ), + ) assert_frame_equal( new_adata_num.obs, DataFrame( @@ -121,7 +134,6 @@ def test_move_to_x_move_to_obs(self): adata = move_to_x(adata, ["name"]) assert {"name"}.issubset(set(adata.var_names)) assert adata.X.shape == adata_dim_old - assert "name" in [item for sublist in adata.uns.values() for item in sublist] delete_from_obs(adata, ["name"]) # case 2: move some column from obs to X and this col was previously moved inplace from X to obs @@ -130,9 +142,6 @@ def test_move_to_x_move_to_obs(self): assert not {"clinic_id"}.issubset(set(adata.obs.columns)) assert {"clinic_id"}.issubset(set(adata.var_names)) assert adata.X.shape == adata_dim_old - assert "clinic_id" in [ - item for sublist in adata.uns.values() for item in sublist - ] # check if the column in in uns # case 3: move multiple columns from obs to X and some of them were copied or moved inplace previously from X to obs move_to_obs(adata, ["los_days"], copy_obs=True) @@ -145,7 +154,6 @@ def test_move_to_x_move_to_obs(self): assert not {"b12_values"}.issubset(set(adata.obs.columns)) assert {"los_days", "b12_values"}.issubset(set(adata.var_names)) assert adata.X.shape == adata_dim_old - assert {"los_days", "b12_values"}.issubset({item for sublist in adata.uns.values() for item in sublist}) def test_delete_from_obs(self): adata = ep.io.read_csv(CUR_DIR / "../io/test_data_io/dataset_move_obs_mix.csv") @@ -153,7 +161,7 @@ def test_delete_from_obs(self): adata = delete_from_obs(adata, ["los_days"]) assert not {"los_days"}.issubset(set(adata.obs.columns)) assert {"los_days"}.issubset(set(adata.var_names)) - assert {"los_days"}.issubset({item for sublist in adata.uns.values() for item in sublist}) + assert EHRAPY_TYPE_KEY in adata.var.columns def test_df_to_anndata_simple(self): df, col1_val, col2_val, col3_val = TestAnndataExt._setup_df_to_anndata() @@ -281,27 +289,49 @@ def test_anndata_to_df_layers(self): def test_detect_binary_columns(self): binary_df = TestAnndataExt._setup_binary_df_to_anndata() adata = df_to_anndata(binary_df) - assert set(adata.uns["non_numerical_columns"]) == { - "col1", - "col2", - } - assert set(adata.uns["numerical_columns"]) == { - "col3", - "col4", - "col5", - "col6", - "col7_binary_int", - "col8_binary_float", - "col9_binary_missing_values", - } + + assert_frame_equal( + adata.var, + DataFrame( + { + EHRAPY_TYPE_KEY: [ + NON_NUMERIC_TAG, + NON_NUMERIC_TAG, + NUMERIC_TAG, + NUMERIC_TAG, + NUMERIC_TAG, + NUMERIC_TAG, + NUMERIC_TAG, + NUMERIC_TAG, + NUMERIC_TAG, + ] + }, + index=[ + "col1", + "col2", + "col3", + "col4", + "col5", + "col6", + "col7_binary_int", + "col8_binary_float", + "col9_binary_missing_values", + ], + ), + ) def test_detect_mixed_binary_columns(self): df = pd.DataFrame( {"Col1": list(range(4)), "Col2": ["str" + str(i) for i in range(4)], "Col3": [1.0, 0.0, np.nan, 1.0]} ) adata = ep.ad.df_to_anndata(df) - assert set(adata.uns["non_numerical_columns"]) == {"Col2"} - assert set(adata.uns["numerical_columns"]) == {"Col1", "Col3"} + assert_frame_equal( + adata.var, + DataFrame( + {EHRAPY_TYPE_KEY: [NUMERIC_TAG, NON_NUMERIC_TAG, NUMERIC_TAG]}, + index=["Col1", "Col2", "Col3"], + ), + ) @staticmethod def _setup_df_to_anndata() -> tuple[DataFrame, list, list, list]: @@ -377,7 +407,7 @@ def setup_method(self): var=pd.DataFrame(data=var_numeric, index=var_numeric["Feature"]), uns=OrderedDict(), ) - + self.adata_numeric.var[EHRAPY_TYPE_KEY] = [NUMERIC_TAG, NUMERIC_TAG, NON_NUMERIC_TAG, NON_NUMERIC_TAG] self.adata_numeric.uns["numerical_columns"] = ["Numeric1", "Numeric2"] self.adata_numeric.uns["non_numerical_columns"] = ["String1", "String2"] self.adata_strings = AnnData( @@ -385,6 +415,7 @@ def setup_method(self): obs=pd.DataFrame(data=obs_data), var=pd.DataFrame(data=var_strings, index=var_strings["Feature"]), ) + self.adata_strings.var[EHRAPY_TYPE_KEY] = [NUMERIC_TAG, NUMERIC_TAG, NON_NUMERIC_TAG, NON_NUMERIC_TAG] self.adata_strings.uns["numerical_columns"] = ["Numeric1", "Numeric2"] self.adata_strings.uns["non_numerical_columns"] = ["String1", "String2"] self.adata_encoded = ep.pp.encode(self.adata_strings.copy(), autodetect=True, encodings="label") diff --git a/tests/preprocessing/test_encode.py b/tests/preprocessing/test_encode.py index c6912b99..3e678c88 100644 --- a/tests/preprocessing/test_encode.py +++ b/tests/preprocessing/test_encode.py @@ -2,8 +2,10 @@ import pandas as pd import pytest -from pandas import CategoricalDtype +from pandas import CategoricalDtype, DataFrame +from pandas.testing import assert_frame_equal +from ehrapy.anndata._constants import EHRAPY_TYPE_KEY, NON_NUMERIC_ENCODED_TAG, NON_NUMERIC_TAG, NUMERIC_TAG from ehrapy.io._read import read_csv from ehrapy.preprocessing._encode import DuplicateColumnEncodingError, _reorder_encodings, encode @@ -55,20 +57,41 @@ def test_autodetect_encode(): assert id(encoded_ann_data.var) != id(adata.var) assert all(column in set(encoded_ann_data.obs.columns) for column in ["survival", "clinic_day"]) assert not any(column in set(adata.obs.columns) for column in ["survival", "clinic_day"]) - assert all(column in set(adata.uns["non_numerical_columns"]) for column in ["survival", "clinic_day"]) - assert not any( - column in set(encoded_ann_data.uns["non_numerical_columns"]) for column in ["survival", "clinic_day"] + assert_frame_equal( + adata.var, + DataFrame( + {EHRAPY_TYPE_KEY: [NUMERIC_TAG, NUMERIC_TAG, NUMERIC_TAG, NON_NUMERIC_TAG, NON_NUMERIC_TAG]}, + index=["patient_id", "los_days", "b12_values", "survival", "clinic_day"], + ), ) - assert all( - column in set(encoded_ann_data.uns["encoded_non_numerical_columns"]) - for column in [ - "ehrapycat_clinic_day_Friday", - "ehrapycat_clinic_day_Monday", - "ehrapycat_survival_False", - "ehrapycat_clinic_day_Saturday", - "ehrapycat_clinic_day_Sunday", - "ehrapycat_survival_True", - ] + assert_frame_equal( + encoded_ann_data.var, + DataFrame( + { + EHRAPY_TYPE_KEY: [ + NON_NUMERIC_ENCODED_TAG, + NON_NUMERIC_ENCODED_TAG, + NON_NUMERIC_ENCODED_TAG, + NON_NUMERIC_ENCODED_TAG, + NON_NUMERIC_ENCODED_TAG, + NON_NUMERIC_ENCODED_TAG, + NUMERIC_TAG, + NUMERIC_TAG, + NUMERIC_TAG, + ] + }, + index=[ + "ehrapycat_survival_False", + "ehrapycat_survival_True", + "ehrapycat_clinic_day_Friday", + "ehrapycat_clinic_day_Monday", + "ehrapycat_clinic_day_Saturday", + "ehrapycat_clinic_day_Sunday", + "patient_id", + "los_days", + "b12_values", + ], + ), ) assert pd.api.types.is_bool_dtype(encoded_ann_data.obs["survival"].dtype) assert isinstance(encoded_ann_data.obs["clinic_day"].dtype, CategoricalDtype) @@ -105,13 +128,33 @@ def test_autodetect_custom_mode(): assert id(encoded_ann_data.var) != id(adata.var) assert all(column in set(encoded_ann_data.obs.columns) for column in ["survival", "clinic_day"]) assert not any(column in set(adata.obs.columns) for column in ["survival", "clinic_day"]) - assert all(column in set(adata.uns["non_numerical_columns"]) for column in ["survival", "clinic_day"]) - assert not any( - column in set(encoded_ann_data.uns["non_numerical_columns"]) for column in ["survival", "clinic_day"] + assert_frame_equal( + adata.var, + DataFrame( + {EHRAPY_TYPE_KEY: [NUMERIC_TAG, NUMERIC_TAG, NUMERIC_TAG, NON_NUMERIC_TAG, NON_NUMERIC_TAG]}, + index=["patient_id", "los_days", "b12_values", "survival", "clinic_day"], + ), ) - assert all( - column in set(encoded_ann_data.uns["encoded_non_numerical_columns"]) - for column in ["ehrapycat_survival", "ehrapycat_clinic_day"] + assert_frame_equal( + encoded_ann_data.var, + DataFrame( + { + EHRAPY_TYPE_KEY: [ + NON_NUMERIC_ENCODED_TAG, + NON_NUMERIC_ENCODED_TAG, + NUMERIC_TAG, + NUMERIC_TAG, + NUMERIC_TAG, + ] + }, + index=[ + "ehrapycat_survival", + "ehrapycat_clinic_day", + "patient_id", + "los_days", + "b12_values", + ], + ), ) assert pd.api.types.is_bool_dtype(encoded_ann_data.obs["survival"].dtype) assert isinstance(encoded_ann_data.obs["clinic_day"].dtype, CategoricalDtype) @@ -155,19 +198,39 @@ def test_custom_encode(): assert id(encoded_ann_data.var) != id(adata.var) assert all(column in set(encoded_ann_data.obs.columns) for column in ["survival", "clinic_day"]) assert not any(column in set(adata.obs.columns) for column in ["survival", "clinic_day"]) - assert all(column in set(adata.uns["non_numerical_columns"]) for column in ["survival", "clinic_day"]) - assert not any( - column in set(encoded_ann_data.uns["non_numerical_columns"]) for column in ["survival", "clinic_day"] + assert_frame_equal( + adata.var, + DataFrame( + {EHRAPY_TYPE_KEY: [NUMERIC_TAG, NUMERIC_TAG, NUMERIC_TAG, NON_NUMERIC_TAG, NON_NUMERIC_TAG]}, + index=["patient_id", "los_days", "b12_values", "survival", "clinic_day"], + ), ) - assert all( - column in set(encoded_ann_data.uns["encoded_non_numerical_columns"]) - for column in [ - "ehrapycat_survival", - "ehrapycat_clinic_day_Friday", - "ehrapycat_clinic_day_Monday", - "ehrapycat_clinic_day_Saturday", - "ehrapycat_clinic_day_Sunday", - ] + assert_frame_equal( + encoded_ann_data.var, + DataFrame( + { + EHRAPY_TYPE_KEY: [ + NON_NUMERIC_ENCODED_TAG, + NON_NUMERIC_ENCODED_TAG, + NON_NUMERIC_ENCODED_TAG, + NON_NUMERIC_ENCODED_TAG, + NON_NUMERIC_ENCODED_TAG, + NUMERIC_TAG, + NUMERIC_TAG, + NUMERIC_TAG, + ] + }, + index=[ + "ehrapycat_clinic_day_Friday", + "ehrapycat_clinic_day_Monday", + "ehrapycat_clinic_day_Saturday", + "ehrapycat_clinic_day_Sunday", + "ehrapycat_survival", + "patient_id", + "los_days", + "b12_values", + ], + ), ) assert pd.api.types.is_bool_dtype(encoded_ann_data.obs["survival"].dtype) assert isinstance(encoded_ann_data.obs["clinic_day"].dtype, CategoricalDtype) diff --git a/tests/preprocessing/test_normalization.py b/tests/preprocessing/test_normalization.py index db03a225..1bd23cb5 100644 --- a/tests/preprocessing/test_normalization.py +++ b/tests/preprocessing/test_normalization.py @@ -8,6 +8,7 @@ from anndata import AnnData import ehrapy as ep +from ehrapy.anndata._constants import EHRAPY_TYPE_KEY, NON_NUMERIC_TAG, NUMERIC_TAG CURRENT_DIR = Path(__file__).parent _TEST_PATH = f"{CURRENT_DIR}/test_preprocessing" @@ -25,9 +26,12 @@ def adata_to_norm(): ], dtype=np.dtype(object), ) + # the "ignore" tag is used to make the column being ignored; the original test selecting a few + # columns induces a specific ordering which is kept for now var_data = { "Feature": ["Integer1", "Numeric1", "Numeric2", "Numeric3", "String1", "String2"], "Type": ["Integer", "Numeric", "Numeric", "Numeric", "String", "String"], + EHRAPY_TYPE_KEY: [NON_NUMERIC_TAG, NUMERIC_TAG, NUMERIC_TAG, "ignore", NON_NUMERIC_TAG, NON_NUMERIC_TAG], } adata = AnnData( X=X_data, @@ -35,8 +39,7 @@ def adata_to_norm(): var=pd.DataFrame(data=var_data, index=var_data["Feature"]), uns=OrderedDict(), ) - adata.uns["numerical_columns"] = ["Numeric1", "Numeric2"] - adata.uns["non_numerical_columns"] = ["String1", "String2"] + adata = ep.pp.encode(adata, autodetect=True, encodings="label") return adata diff --git a/tests/tools/test_features_ranking.py b/tests/tools/test_features_ranking.py index cc9c3c22..399197e9 100644 --- a/tests/tools/test_features_ranking.py +++ b/tests/tools/test_features_ranking.py @@ -352,7 +352,6 @@ def test_rank_features_groups_consistent_results(self): ) # to keep the same variables as in the datsets above, in order to make the comparison of consistency adata_features_in_x_and_obs = adata_features_in_x_and_obs[:, ["sys_bp_entry", "dia_bp_entry", "glucose"]] - adata_features_in_x_and_obs.uns["numerical_columns"] = ["sys_bp_entry", "dia_bp_entry", "glucose"] ep.tl.rank_features_groups(adata_features_in_x, groupby="disease") ep.tl.rank_features_groups(adata_features_in_obs, groupby="disease", field_to_rank="obs")