Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix deprecation warning force_all_finite -> ensure_all_finite from 1.7 to 1.8 #206

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
84 changes: 59 additions & 25 deletions cebra/integrations/sklearn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@

import cebra.helper

from packaging import version
from sklearn import __version__ as sklearn_version
sklearn_version = version.parse(sklearn_version)

icarosadero marked this conversation as resolved.
Show resolved Hide resolved

def update_old_param(old: dict, new: dict, kwargs: dict, default) -> tuple:
"""Handle deprecated arguments of a function until they are replaced.
Expand Down Expand Up @@ -74,19 +78,35 @@ def check_input_array(X: npt.NDArray, *, min_samples: int) -> npt.NDArray:
Returns:
The converted and validated array.
"""
return sklearn_utils_validation.check_array(
X,
accept_sparse=False,
accept_large_sparse=False,
dtype=("float16", "float32", "float64"),
order=None,
copy=False,
force_all_finite=True,
ensure_2d=True,
allow_nd=False,
ensure_min_samples=min_samples,
ensure_min_features=1,
)

if sklearn_version < version.parse("1.8"):
return sklearn_utils_validation.check_array(
X,
accept_sparse=False,
accept_large_sparse=False,
dtype=("float16", "float32", "float64"),
order=None,
copy=False,
force_all_finite=True,
ensure_2d=True,
allow_nd=False,
ensure_min_samples=min_samples,
ensure_min_features=1,
)
else:
return sklearn_utils_validation.check_array(
X,
accept_sparse=False,
accept_large_sparse=False,
dtype=("float16", "float32", "float64"),
order=None,
copy=False,
ensure_all_finite=True,
ensure_2d=True,
allow_nd=False,
ensure_min_samples=min_samples,
ensure_min_features=1,
)
icarosadero marked this conversation as resolved.
Show resolved Hide resolved


def check_label_array(y: npt.NDArray, *, min_samples: int):
Expand All @@ -105,18 +125,32 @@ def check_label_array(y: npt.NDArray, *, min_samples: int):
Returns:
The converted and validated labels.
"""
return sklearn_utils_validation.check_array(
y,
accept_sparse=False,
accept_large_sparse=False,
dtype="numeric",
order=None,
copy=False,
force_all_finite=True,
ensure_2d=False,
allow_nd=False,
ensure_min_samples=min_samples,
)
if sklearn_version < version.parse("1.8"):
return sklearn_utils_validation.check_array(
y,
accept_sparse=False,
accept_large_sparse=False,
dtype="numeric",
order=None,
copy=False,
force_all_finite=True,
ensure_2d=False,
allow_nd=False,
ensure_min_samples=min_samples,
)
else:
return sklearn_utils_validation.check_array(
y,
accept_sparse=False,
accept_large_sparse=False,
dtype="numeric",
order=None,
copy=False,
ensure_all_finite=True,
ensure_2d=False,
allow_nd=False,
ensure_min_samples=min_samples,
)
icarosadero marked this conversation as resolved.
Show resolved Hide resolved


def check_device(device: str) -> str:
Expand Down
Loading