
Commit

Fixed types
Zethson committed Nov 13, 2023
1 parent 54aa671 commit 402b6b2
Showing 6 changed files with 22 additions and 23 deletions.
2 changes: 1 addition & 1 deletion ehrapy/anndata/anndata_ext.py
@@ -317,7 +317,7 @@ def move_to_x(adata: AnnData, to_x: list[str] | str) -> AnnData:
return new_adata


def get_column_indices(adata: AnnData, col_names: str | Sequence[str]) -> list[int]:
def get_column_indices(adata: AnnData, col_names: str | Iterable[str]) -> list[int]:
"""Fetches the column indices in X for a given list of column names
Args:
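The hint widens from Sequence[str] to Iterable[str] because callers may pass more than lists and tuples. A minimal, standalone illustration (not taken from the repository):

    from collections.abc import Iterable, Sequence

    # dict views and generators are Iterable but not Sequence; a plain str is both,
    # which is why get_column_indices keeps its separate `str` case in the hint.
    examples = [
        ["age", "bmi"],               # list      -> Sequence and Iterable
        ("age", "bmi"),               # tuple     -> Sequence and Iterable
        {"age": 0, "bmi": 1}.keys(),  # dict_keys -> Iterable only
        (c for c in ("age", "bmi")),  # generator -> Iterable only
    ]
    for cols in examples:
        print(type(cols).__name__, isinstance(cols, Sequence), isinstance(cols, Iterable))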
8 changes: 5 additions & 3 deletions ehrapy/io/_read.py
@@ -219,11 +219,13 @@ def _read_multiple_csv(
file_path: File path to the directory containing multiple .csv/.tsv files.
sep: Either , or \t to determine which files to read.
index_column: Column names of the index columns for obs
columns_obs_only: List of columns per file (AnnData object) which should only be stored in .obs, but not in X. Useful for free text annotations.
columns_x_only: List of columns per file (AnnData object) which should only be stored in .X, but not in obs. Datetime columns will be added to .obs regardless.
columns_obs_only: List of columns per file (AnnData object) which should only be stored in .obs, but not in X.
Useful for free text annotations.
columns_x_only: List of columns per file (AnnData object) which should only be stored in .X, but not in obs.
Datetime columns will be added to .obs regardless.
return_dfs: When set to True, return a dictionary of Pandas DataFrames.
cache: Whether to cache results or not
kwargs: Keyword arguments for Pandas read_csv
kwargs: Keyword arguments for Pandas `read_csv`
Returns:
A Dict mapping the filename (object name) to the corresponding :class:`~anndata.AnnData` object and the columns
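For orientation, a hedged sketch of the single-file entry point; only the import path is confirmed by this diff, and the directory case dispatches to _read_multiple_csv as documented above:

    # Hypothetical usage sketch -- the file path and the free-text column are illustrative.
    from ehrapy.io._read import read_csv

    adata = read_csv("data/patients.csv", columns_obs_only=["free_text_note"])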
9 changes: 5 additions & 4 deletions ehrapy/io/_write.py
@@ -20,10 +20,11 @@ def write(
compression: Literal["gzip", "lzf"] | None = "gzip",
compression_opts: int | None = None,
) -> None:
"""Write :class:`~anndata.AnnData` objects to file. It is possbile to either write an :class:`~anndata.AnnData` object to
a .csv file or a .h5ad file.
The .h5ad file can be used as a cache to save the current state of the object and to retrieve it faster once needed. This preserves
the object state at the time of writing. It is possible to write both encoded and unencoded objects.
"""Write :class:`~anndata.AnnData` objects to file.
It is possible to either write an :class:`~anndata.AnnData` object to a .csv file or a .h5ad file.
The .h5ad file can be used as a cache to save the current state of the object and to retrieve it faster once needed.
This preserves the object state at the time of writing. It is possible to write both encoded and unencoded objects.
Args:
filename: File name or path to write the file to
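A hedged caching round trip as described by the docstring; the (filename, adata) argument order is assumed from the Args listing, and read_h5ad comes from anndata itself:

    import anndata as ad
    import numpy as np
    import pandas as pd

    from ehrapy.io._write import write

    # Tiny AnnData stand-in so the sketch is self-contained.
    adata = ad.AnnData(X=np.array([[1.0, 2.0], [3.0, 4.0]]), var=pd.DataFrame(index=["age", "bmi"]))

    write("patients_cache.h5ad", adata)                   # snapshot the current object state
    adata_restored = ad.read_h5ad("patients_cache.h5ad")  # retrieve it later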
6 changes: 3 additions & 3 deletions ehrapy/preprocessing/__init__.py
@@ -1,4 +1,6 @@
from ehrapy.preprocessing._data_imputation import (
from ehrapy.preprocessing._encode import encode, undo_encoding
from ehrapy.preprocessing._highly_variable_features import highly_variable_features
from ehrapy.preprocessing._imputation import (
explicit_impute,
iterative_svd_impute,
knn_impute,
@@ -9,8 +11,6 @@
simple_impute,
soft_impute,
)
from ehrapy.preprocessing._encode import encode, undo_encoding
from ehrapy.preprocessing._highly_variable_features import highly_variable_features
from ehrapy.preprocessing._normalization import (
log_norm,
maxabs_norm,
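Because __init__.py re-exports these names, the rename of the private module from _data_imputation to _imputation does not affect user-facing imports; a short sketch:

    # Public imports stay the same after the rename ...
    from ehrapy.preprocessing import encode, knn_impute, simple_impute

    # ... only internal and test code has to use the new private path.
    from ehrapy.preprocessing._imputation import knn_impute as _knn_impute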
ehrapy/preprocessing/{_data_imputation.py → _imputation.py}
@@ -1,5 +1,6 @@
from __future__ import annotations

from collections.abc import Iterable
from typing import TYPE_CHECKING, Literal

import numpy as np
@@ -16,8 +17,6 @@
from ehrapy.core._tool_available import _check_module_importable

if TYPE_CHECKING:
from collections.abc import Iterable

from anndata import AnnData


@@ -187,11 +186,9 @@ def simple_impute(

def _simple_impute(adata: AnnData, var_names: Iterable[str] | None, strategy: str) -> None:
imputer = SimpleImputer(strategy=strategy)
# impute a subset of columns
if isinstance(var_names, list):
if isinstance(var_names, Iterable):
column_indices = get_column_indices(adata, var_names)
adata.X[::, column_indices] = imputer.fit_transform(adata.X[::, column_indices])
# impute all columns if None passed
else:
adata.X = imputer.fit_transform(adata.X)

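The list check is replaced by an Iterable check here and in the imputers below, so tuples, dict views and similar inputs now reach the column-subset branch instead of silently imputing the whole matrix; a minimal illustration, independent of ehrapy:

    from collections.abc import Iterable

    # None still selects the "impute all columns" branch, as intended.
    # A single str is also Iterable, which matches get_column_indices
    # accepting `str | Iterable[str]`.
    for var_names in (["age", "bmi"], ("age", "bmi"), {"age", "bmi"}, None):
        branch = "subset of columns" if isinstance(var_names, Iterable) else "all columns"
        print(var_names, "->", branch)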
@@ -289,7 +286,7 @@ def _knn_impute(adata: AnnData, var_names: Iterable[str] | None, n_neighbours: i

imputer = KNNImputer(n_neighbors=n_neighbours)

if isinstance(var_names, list):
if isinstance(var_names, Iterable):
column_indices = get_column_indices(adata, var_names)
adata.X[::, column_indices] = imputer.fit_transform(adata.X[::, column_indices])
# this is required since X dtype has to be numerical in order to correctly round floats
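The retained comment about X having to be numerical is easy to verify outside ehrapy; a small standalone sketch of why an object-dtype matrix cannot be rounded directly:

    import numpy as np

    X_obj = np.array([[1.4, 2.6]], dtype=object)
    try:
        np.rint(X_obj)                     # the ufunc looks for an `rint` method on each element
    except TypeError as err:
        print("object dtype fails:", err)

    print("float dtype works:", np.rint(X_obj.astype(np.float64)))  # [[1. 3.]]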
@@ -568,7 +565,7 @@ def _soft_impute(
verbose,
)

if isinstance(var_names, list):
if isinstance(var_names, Iterable):
column_indices = get_column_indices(adata, var_names)
adata.X[::, column_indices] = imputer.fit_transform(adata.X[::, column_indices])
else:
@@ -717,7 +714,7 @@ def _iterative_svd_impute(
verbose,
)

if isinstance(var_names, list):
if isinstance(var_names, Iterable):
column_indices = get_column_indices(adata, var_names)
adata.X[::, column_indices] = imputer.fit_transform(adata.X[::, column_indices])
else:
@@ -816,7 +813,6 @@ def matrix_factorization_impute(
verbose,
)
adata.X = adata.X.astype("object")
# decode ordinal encoding to obtain imputed original data
adata.X[::, column_indices] = enc.inverse_transform(adata.X[::, column_indices])

if var_names:
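The inverse_transform call above undoes the ordinal encoding after imputation, as the removed comment described; a minimal standalone round trip, assuming enc behaves like sklearn's OrdinalEncoder:

    import numpy as np
    from sklearn.preprocessing import OrdinalEncoder

    X = np.array([["m"], ["f"], ["f"]], dtype=object)
    enc = OrdinalEncoder()
    codes = enc.fit_transform(X)             # numeric codes the imputer can work on
    restored = enc.inverse_transform(codes)  # original categories recovered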
@@ -855,7 +851,7 @@ def _matrix_factorization_impute(
verbose,
)

if isinstance(var_names, list):
if isinstance(var_names, Iterable):
column_indices = get_column_indices(adata, var_names)
adata.X[::, column_indices] = imputer.fit_transform(adata.X[::, column_indices])
else:
@@ -1071,7 +1067,7 @@ def _miceforest_impute(
"""Utility function to impute data using miceforest"""
import miceforest as mf

if isinstance(var_names, list):
if isinstance(var_names, Iterable):
column_indices = get_column_indices(adata, var_names)

# Create kernel.
2 changes: 1 addition & 1 deletion tests/preprocessing/test_imputation.py
@@ -7,7 +7,7 @@
from sklearn.exceptions import ConvergenceWarning

from ehrapy.io._read import read_csv
from ehrapy.preprocessing._data_imputation import (
from ehrapy.preprocessing._imputation import (
_warn_imputation_threshold,
explicit_impute,
iterative_svd_impute,