Skip to content

Commit

Permalink
chore(pre-commit.ci): auto fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
pre-commit-ci[bot] committed Jul 29, 2024
1 parent df89066 commit 9c774f8
Show file tree
Hide file tree
Showing 8 changed files with 67 additions and 56 deletions.
36 changes: 19 additions & 17 deletions src/sklearn_utilities/eval_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,23 +131,25 @@ def __init__(
self,
estimator: TEstimator,
*,
tqdm_cls: Literal[
"auto",
"autonotebook",
"std",
"notebook",
"asyncio",
"keras",
"dask",
"tk",
"gui",
"rich",
"contrib.slack",
"contrib.discord",
"contrib.telegram",
"contrib.bells",
]
| type[tqdm.std.tqdm] = "auto",
tqdm_cls: (
Literal[
"auto",
"autonotebook",
"std",
"notebook",
"asyncio",
"keras",
"dask",
"tk",
"gui",
"rich",
"contrib.slack",
"contrib.discord",
"contrib.telegram",
"contrib.bells",
]
| type[tqdm.std.tqdm]
) = "auto",
tqdm_kwargs: dict[str, Any] | None = None,
verbose: bool = True,
) -> None:
Expand Down
22 changes: 13 additions & 9 deletions src/sklearn_utilities/pandas/dataframe_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,24 @@ def to_frame_or_series(
return Series(
array,
index=base_index if array.shape[0] == len(base_index) else None,
name=base_columns_or_name
if not isinstance(base_columns_or_name, Index)
else None,
name=(
base_columns_or_name
if not isinstance(base_columns_or_name, Index)
else None
),
)
if array.ndim == 2:
return DataFrame(
array,
index=base_index if array.shape[0] == len(base_index) else None,
columns=base_columns_or_name
if (
isinstance(base_columns_or_name, Index)
and array.shape[1] == len(base_columns_or_name)
)
else None,
columns=(
base_columns_or_name
if (
isinstance(base_columns_or_name, Index)
and array.shape[1] == len(base_columns_or_name)
)
else None
),
)
except Exception as e:
warnings.warn(f"Could not convert {array} to DataFrame or Series: {e}")
Expand Down
9 changes: 6 additions & 3 deletions src/sklearn_utilities/pandas/multioutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,12 @@ def predict(
check_is_fitted(self)
X = X[self.feature_names_in_]
preds = [est.predict(X, **predict_params) for est in self.estimators_]
preds_: DataFrame | Series | NDArray[Any] | tuple[
DataFrame | Series | NDArray[Any], ...
]
preds_: (
DataFrame
| Series
| NDArray[Any]
| tuple[DataFrame | Series | NDArray[Any], ...]
)
if any(isinstance(pred, tuple) for pred in preds):
# list of tuples of arrays to tuples of arrays
preds_ = tuple(np.array(pred).T for pred in zip(*preds))
Expand Down
6 changes: 2 additions & 4 deletions src/sklearn_utilities/proba/compose_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,12 @@ def fit(self, X: TX, y: TY, **fit_params: Any) -> Self:
@overload
def predict(
self, X: TX, return_std: Literal[False] = ..., **predict_params: Any
) -> TY:
...
) -> TY: ...

@overload
def predict(
self, X: TX, return_std: Literal[True], **predict_params: Any
) -> tuple[TY, TY]:
...
) -> tuple[TY, TY]: ...

def predict(
self, X: TX, return_std: bool = False, **predict_params: Any
Expand Down
5 changes: 3 additions & 2 deletions src/sklearn_utilities/reindex_missing_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ class ReindexMissingColumns(BaseEstimator, TransformerMixin):
def __init__(
self,
*,
if_missing: Literal["warn", "raise"]
| Callable[[Index[Any], Index[Any]], None] = "warn",
if_missing: (
Literal["warn", "raise"] | Callable[[Index[Any], Index[Any]], None]
) = "warn",
reindex_kwargs: dict[
Literal["method", "copy", "level", "fill_value", "limit", "tolerance"], Any
] = {},
Expand Down
5 changes: 3 additions & 2 deletions src/sklearn_utilities/report_non_finite.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ def __init__(
plot: bool = True,
calc_corr: bool = False,
callback: Callable[[dict[str, DataFrame | Series]], None] | None = None,
callback_figure: Callable[[Figure], None]
| None = lambda fig: Path("sklearn_utilities_info/ReportNonFinite").mkdir( # type: ignore
callback_figure: Callable[[Figure], None] | None = lambda fig: Path(
"sklearn_utilities_info/ReportNonFinite"
).mkdir( # type: ignore
parents=True, exist_ok=True
)
or fig.savefig(
Expand Down
6 changes: 3 additions & 3 deletions src/sklearn_utilities/torch/pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,9 @@ def __init__(
*,
qr: bool = False,
svd_flip: bool | None = None,
device: torch.device | int | str = "cuda"
if torch.cuda.is_available()
else "cpu",
device: torch.device | int | str = (
"cuda" if torch.cuda.is_available() else "cpu"
),
dtype: torch.dtype = torch.float32,
**kwargs: Any,
) -> None:
Expand Down
34 changes: 18 additions & 16 deletions src/sklearn_utilities/torch/skorch/proba.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,22 +297,24 @@ def predict(
X: TX,
*,
return_std: bool = False,
type_: Literal[
"mean",
"median",
"nanmean",
"nanmedian",
"var",
"std",
"ptp",
"nanvar",
"nanstd",
]
| tuple[
Literal["mean", "median", "nanmean", "nanmedian"],
Literal["var", "std", "ptp", "nanvar", "nanstd"],
]
| None = None,
type_: (
Literal[
"mean",
"median",
"nanmean",
"nanmedian",
"var",
"std",
"ptp",
"nanvar",
"nanstd",
]
| tuple[
Literal["mean", "median", "nanmean", "nanmedian"],
Literal["var", "std", "ptp", "nanvar", "nanstd"],
]
| None
) = None,
**predict_params: Any,
) -> TY | tuple[TY, TY]:
ts_axis_ = self.estimator.criterion.ts_axis_
Expand Down

0 comments on commit 9c774f8

Please sign in to comment.