Skip to content

Commit 3ba6bc6

Browse files
icarosaderostes
andauthored
Update cebra/integrations/sklearn/utils.py
Co-authored-by: Steffen Schneider <[email protected]>
1 parent 3e4ee86 commit 3ba6bc6

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

cebra/integrations/sklearn/utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,15 @@
2828
import cebra.helper
2929

3030
from packaging import version
31-
from sklearn import __version__ as sklearn_version
32-
sklearn_version = version.parse(sklearn_version)
31+
import sklearn
3332

33+
def _check_array_ensure_all_finite(array, **kwargs):
34+
if version.parse(sklearn.__version__) < version.parse("1.8"):
35+
key = "force_all_finite"
36+
else:
37+
key = "ensure_all_finite"
38+
kwargs[key] = True
39+
return sklearn_utils_validation.check_array(array, **kwargs)
3440

3541
def update_old_param(old: dict, new: dict, kwargs: dict, default) -> tuple:
3642
"""Handle deprecated arguments of a function until they are replaced.

0 commit comments

Comments
 (0)