diff --git a/cebra/integrations/sklearn/utils.py b/cebra/integrations/sklearn/utils.py index f7b03b8..c1671b9 100644 --- a/cebra/integrations/sklearn/utils.py +++ b/cebra/integrations/sklearn/utils.py @@ -29,6 +29,7 @@ from packaging import version from sklearn import __version__ as sklearn_version +sklearn_version = version.parse(sklearn_version) def update_old_param(old: dict, new: dict, kwargs: dict, default) -> tuple: