Skip to content

Commit 128257b

Browse files
icarosaderostes
andauthored
Update cebra/integrations/sklearn/utils.py
Co-authored-by: Steffen Schneider <[email protected]>
1 parent 3ba6bc6 commit 128257b

File tree

1 file changed

+12
-29
lines changed

1 file changed

+12
-29
lines changed

cebra/integrations/sklearn/utils.py

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -84,35 +84,18 @@ def check_input_array(X: npt.NDArray, *, min_samples: int) -> npt.NDArray:
8484
Returns:
8585
The converted and validated array.
8686
"""
87-
88-
if sklearn_version < version.parse("1.8"):
89-
return sklearn_utils_validation.check_array(
90-
X,
91-
accept_sparse=False,
92-
accept_large_sparse=False,
93-
dtype=("float16", "float32", "float64"),
94-
order=None,
95-
copy=False,
96-
force_all_finite=True,
97-
ensure_2d=True,
98-
allow_nd=False,
99-
ensure_min_samples=min_samples,
100-
ensure_min_features=1,
101-
)
102-
else:
103-
return sklearn_utils_validation.check_array(
104-
X,
105-
accept_sparse=False,
106-
accept_large_sparse=False,
107-
dtype=("float16", "float32", "float64"),
108-
order=None,
109-
copy=False,
110-
ensure_all_finite=True,
111-
ensure_2d=True,
112-
allow_nd=False,
113-
ensure_min_samples=min_samples,
114-
ensure_min_features=1,
115-
)
87+
return _check_array_ensure_all_finite(
88+
X,
89+
accept_sparse=False,
90+
accept_large_sparse=False,
91+
dtype=("float16", "float32", "float64"),
92+
order=None,
93+
copy=False,
94+
ensure_2d=True,
95+
allow_nd=False,
96+
ensure_min_samples=min_samples,
97+
ensure_min_features=1,
98+
)
11699

117100

118101
def check_label_array(y: npt.NDArray, *, min_samples: int):

0 commit comments

Comments
 (0)