Skip to content

Commit

Permalink
Fix filter issue
Browse files Browse the repository at this point in the history
  • Loading branch information
bryant1410 committed Apr 30, 2023
1 parent b07b8da commit 473e566
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 10 deletions.
13 changes: 11 additions & 2 deletions features.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,14 @@ def _transform_features_to_numbers(
remainder="passthrough",
**common_column_transformer_kwargs,
)),
("filter", SelectMinNonMostFrequentValues(min_non_most_frequent_values)),
("filter", make_column_transformer(
# Pass a separate copy of the same type of transformer to bools because the output bool dtype is otherwise
# lost. See https://github.com/scikit-learn/scikit-learn/issues/25560
(SelectMinNonMostFrequentValues(min_non_most_frequent_values), make_column_selector(dtype_include=bool)),
# `set_output` is not passed to the `remainder` transformer.
# See https://github.com/scikit-learn/scikit-learn/issues/26306
remainder=SelectMinNonMostFrequentValues(min_non_most_frequent_values).set_output(transform="pandas"),
**common_column_transformer_kwargs)),
("scaler", make_column_transformer(
(StandardScaler(), make_column_selector(dtype_exclude=None if standardize_binary_features else bool)),
remainder="passthrough",
Expand All @@ -516,7 +523,9 @@ def _transform_features_to_numbers(
("imputer", make_column_transformer(
(SimpleImputer(strategy="mean"), make_column_selector(dtype_include=np.number)),
(BoolImputer(strategy="most_frequent"), make_column_selector(dtype_include=bool)),
remainder=SimpleImputer(strategy="most_frequent"),
# `set_output` is not passed to the `remainder` transformer.
# See https://github.com/scikit-learn/scikit-learn/issues/26306
remainder=SimpleImputer(strategy="most_frequent").set_output(transform="pandas"),
**common_column_transformer_kwargs,
)),
], verbose=verbose).set_output(transform="pandas").fit_transform(df)
Expand Down
21 changes: 13 additions & 8 deletions sklearn_util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

from collections import Counter
from typing import Any, Callable, Sequence
from numbers import Real
from typing import Any, Callable, Iterable, Mapping, Sequence

import numpy as np
import numpy.typing as npt
Expand All @@ -13,17 +14,24 @@
from sklearn.impute._base import _check_inputs_dtype
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.utils import _is_pandas_na, is_scalar_nan
from sklearn.utils._param_validation import Interval
from sklearn.utils.validation import FLOAT_DTYPES, check_is_fitted


class SelectMinNonMostFrequentValues(SelectorMixin, BaseEstimator):
"""Select features with at least `min_non_most_frequent_values` values different from the most frequent value."""

_parameter_constraints: Mapping[str, Iterable[Any]] = {
"min_non_most_frequent_values": [Interval(Real, 0, None, closed="left")],
"leave_at_least_one": ["boolean"],
}

def __init__(self, min_non_most_frequent_values: int = 50, leave_at_least_one: bool = True) -> None:
self.min_non_most_frequent_values = min_non_most_frequent_values
self.leave_at_least_one = leave_at_least_one

def fit(self, X: np.ndarray, y: np.ndarray | None = None) -> SelectMinNonMostFrequentValues: # noqa
self._validate_params()
X = self._validate_data(X, ensure_2d=True, force_all_finite="allow-nan")

self.non_most_frequent_counts_ = np.empty(X.shape[1], dtype=np.int64) # noqa
Expand All @@ -41,13 +49,6 @@ def fit(self, X: np.ndarray, y: np.ndarray | None = None) -> SelectMinNonMostFre

return self

@overrides
def transform(self, X: np.ndarray) -> np.ndarray:
# We don't change `X` because the data types may be changed.
_ = self._validate_data(X, ensure_2d=True, dtype=None, accept_sparse="csr", force_all_finite="allow-nan",
reset=False)
return self._transform(X)

@overrides
def _get_support_mask(self) -> np.ndarray:
check_is_fitted(self)
Expand All @@ -58,6 +59,10 @@ def _get_support_mask(self) -> np.ndarray:

return mask

@overrides
def _more_tags(self) -> Mapping[str, Any]:
return {"allow_nan": True}


# See https://github.com/scikit-learn/scikit-learn/issues/11309#issuecomment-1528042914
class MultiHotEncoder(BaseEstimator, TransformerMixin):
Expand Down

0 comments on commit 473e566

Please sign in to comment.