We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 3e4ee86 commit 3ba6bc6Copy full SHA for 3ba6bc6
cebra/integrations/sklearn/utils.py
@@ -28,9 +28,15 @@
28
import cebra.helper
29
30
from packaging import version
31
-from sklearn import __version__ as sklearn_version
32
-sklearn_version = version.parse(sklearn_version)
+import sklearn
33
+def _check_array_ensure_all_finite(array, **kwargs):
34
+ if version.parse(sklearn.__version__) < version.parse("1.8"):
35
+ key = "force_all_finite"
36
+ else:
37
+ key = "ensure_all_finite"
38
+ kwargs[key] = True
39
+ return sklearn_utils_validation.check_array(array, **kwargs)
40
41
def update_old_param(old: dict, new: dict, kwargs: dict, default) -> tuple:
42
"""Handle deprecated arguments of a function until they are replaced.
0 commit comments