Skip to content

Commit 5d5329c

Browse files
authored
FIX raise error in OneHotEncoder.inverse_transform (scikit-learn#14982)
1 parent 3ba09fa commit 5d5329c

File tree

3 files changed

+50
-6
lines changed

3 files changed

+50
-6
lines changed

doc/whats_new/v0.24.rst

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ Changelog
393393
curve classification metric.
394394
:pr:`10591` by :user:`Jeremy Karnowski <jkarnows>` and
395395
:user:`Daniel Mohns <dmohns>`.
396-
396+
397397
- |Feature| Added :func:`metrics.plot_det_curve` and
398398
:class:`metrics.DetCurveDisplay` to ease the plot of DET curves.
399399
:pr:`18176` by :user:`Guillaume Lemaitre <glemaitre>`.
@@ -645,6 +645,12 @@ Changelog
645645
:class:`preprocessing.KBinsDiscretizer`.
646646
:pr:`16335` by :user:`Arthur Imbert <Henley13>`.
647647

648+
- |Fix| Raise error on
649+
:meth:`sklearn.preprocessing.OneHotEncoder.inverse_transform`
650+
when `handle_unknown='error'` and `drop=None` for samples
651+
encoded as all zeros. :pr:`14982` by
652+
:user:`Kevin Winata <kwinata>`.
653+
648654
:mod:`sklearn.svm`
649655
..................
650656

sklearn/preprocessing/_encoders.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -571,13 +571,20 @@ def inverse_transform(self, X):
571571
# ignored unknown categories: we have a row of all zero
572572
if unknown.any():
573573
found_unknown[i] = unknown
574-
# drop will either be None or handle_unknown will be error. If
575-
# self.drop_idx_ is not None, then we can safely assume that all of
576-
# the nulls in each column are the dropped value
577-
elif self.drop_idx_ is not None:
574+
else:
578575
dropped = np.asarray(sub.sum(axis=1) == 0).flatten()
579576
if dropped.any():
580-
X_tr[dropped, i] = self.categories_[i][self.drop_idx_[i]]
577+
if self.drop_idx_ is None:
578+
all_zero_samples = np.flatnonzero(dropped)
579+
raise ValueError(
580+
f"Samples {all_zero_samples} can not be inverted "
581+
"when drop=None and handle_unknown='error' "
582+
"because they contain all zeros")
583+
# we can safely assume that all of the nulls in each column
584+
# are the dropped value
585+
X_tr[dropped, i] = self.categories_[i][
586+
self.drop_idx_[i]
587+
]
581588

582589
j += n_categories
583590

sklearn/preprocessing/tests/test_encoders.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from sklearn.exceptions import NotFittedError
1010
from sklearn.utils._testing import assert_array_equal
1111
from sklearn.utils._testing import assert_allclose
12+
from sklearn.utils._testing import _convert_container
1213
from sklearn.utils import is_scalar_nan
1314

1415
from sklearn.preprocessing import OneHotEncoder
@@ -266,6 +267,36 @@ def test_one_hot_encoder_inverse(sparse_, drop):
266267
enc.inverse_transform(X_tr)
267268

268269

270+
@pytest.mark.parametrize('sparse_', [False, True])
271+
@pytest.mark.parametrize(
272+
"X, X_trans",
273+
[
274+
([[2, 55], [1, 55], [2, 55]], [[0, 1, 1], [0, 0, 0], [0, 1, 1]]),
275+
([['one', 'a'], ['two', 'a'], ['three', 'b'], ['two', 'a']],
276+
[[0, 0, 0, 0, 0], [0, 0, 0, 0, 1], [0, 1, 0, 0, 0]]),
277+
]
278+
)
279+
def test_one_hot_encoder_inverse_transform_raise_error_with_unknown(
280+
X, X_trans, sparse_
281+
):
282+
"""Check that `inverse_transform` raises an error with unknown samples, no
283+
dropped feature, and `handle_unknown="error"`.
284+
Non-regression test for:
285+
https://github.com/scikit-learn/scikit-learn/issues/14934
286+
"""
287+
enc = OneHotEncoder(sparse=sparse_).fit(X)
288+
msg = (
289+
r"Samples \[(\d )*\d\] can not be inverted when drop=None and "
290+
r"handle_unknown='error' because they contain all zeros"
291+
)
292+
293+
if sparse_:
294+
# emulate the sparse output that a one-hot encoder produces with sparse=True.
295+
X_trans = _convert_container(X_trans, "sparse")
296+
with pytest.raises(ValueError, match=msg):
297+
enc.inverse_transform(X_trans)
298+
299+
269300
def test_one_hot_encoder_inverse_if_binary():
270301
X = np.array([['Male', 1],
271302
['Female', 3],

0 commit comments

Comments
 (0)