Skip to content

Commit

Permalink
Add support for new __sklearn_tags__ (#205)
Browse files Browse the repository at this point in the history
* Add support for new __sklearn_tags__

* fix inheritance order

* Add more tests

* fix added test
  • Loading branch information
stes authored Dec 16, 2024
1 parent 36a91c7 commit 5f46c32
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
13 changes: 12 additions & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,16 @@ jobs:
# as well as selected previous versions on
# https://pytorch.org/get-started/previous-versions/
torch-version: ["2.2.2", "2.4.0"]
sklearn-version: ["latest"]
include:
- os: windows-latest
torch-version: 2.4.0
python-version: "3.10"
sklearn-version: "latest"
- os: ubuntu-latest
torch-version: 2.4.0
python-version: "3.10"
sklearn-version: "legacy"

runs-on: ${{ matrix.os }}

Expand All @@ -32,7 +38,7 @@ jobs:
uses: actions/cache@v3
with:
path: ~/.cache/pip
key: pip-os_${{ runner.os }}-python_${{ matrix.python-version }}-torch_${{ matrix.torch-version }}
key: pip-os_${{ runner.os }}-python_${{ matrix.python-version }}-torch_${{ matrix.torch-version }}-sklearn_${{ matrix.sklearn-version }}

- name: Checkout code
uses: actions/checkout@v2
Expand All @@ -48,6 +54,11 @@ jobs:
python -m pip install torch==${{ matrix.torch-version }} --extra-index-url https://download.pytorch.org/whl/cpu
pip install '.[dev,datasets,integrations]'
- name: Check sklearn legacy version
if: matrix.sklearn-version == 'legacy'
run: |
pip install scikit-learn==1.4.2 '.[dev,datasets,integrations]'
- name: Run the formatter
run: |
make format
Expand Down
18 changes: 17 additions & 1 deletion cebra/integrations/sklearn/cebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@
import pkg_resources
import sklearn.utils.validation as sklearn_utils_validation
import torch
import sklearn
from sklearn.base import BaseEstimator
from sklearn.base import TransformerMixin
from sklearn.utils.metaestimators import available_if
from torch import nn

import cebra.data
Expand All @@ -41,6 +43,11 @@
import cebra.models
import cebra.solver

def check_version(estimator):
# NOTE(stes): required as a check for the old way of specifying tags
# https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165
from packaging import version
return version.parse(sklearn.__version__) < version.parse("1.6.dev")

def _init_loader(
is_cont: bool,
Expand Down Expand Up @@ -364,7 +371,7 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA":
return cebra_


class CEBRA(BaseEstimator, TransformerMixin):
class CEBRA(TransformerMixin, BaseEstimator):
"""CEBRA model defined as part of a ``scikit-learn``-like API.
Attributes:
Expand Down Expand Up @@ -1294,6 +1301,15 @@ def fit_transform(
callback_frequency=callback_frequency)
return self.transform(X)

def __sklearn_tags__(self):
# NOTE(stes): from 1.6.dev, this is the new way to specify tags
# https://scikit-learn.org/dev/developers/develop.html
# https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165
tags = super().__sklearn_tags__()
tags.non_deterministic = True
return tags

@available_if(check_version)
def _more_tags(self):
# NOTE(stes): This tag is needed as seeding is not fully implemented in the
# current version of CEBRA.
Expand Down

0 comments on commit 5f46c32

Please sign in to comment.