Skip to content

Commit

Permalink
[python-package] simplify scikit-learn 1.6+ tags support (#6735)
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb authored Dec 4, 2024
1 parent ea04c66 commit 6e0b0a8
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 20 deletions.
10 changes: 0 additions & 10 deletions python-package/lightgbm/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,6 @@
from sklearn.utils.multiclass import check_classification_targets
from sklearn.utils.validation import assert_all_finite, check_array, check_X_y

# sklearn.utils Tags types can be imported unconditionally once
# lightgbm's minimum scikit-learn version is 1.6 or higher
try:
from sklearn.utils import ClassifierTags as _sklearn_ClassifierTags
from sklearn.utils import RegressorTags as _sklearn_RegressorTags
except ImportError:
_sklearn_ClassifierTags = None
_sklearn_RegressorTags = None
try:
from sklearn.exceptions import NotFittedError
from sklearn.model_selection import BaseCrossValidator, GroupKFold, StratifiedKFold
Expand Down Expand Up @@ -148,8 +140,6 @@ class _LGBMRegressorBase: # type: ignore
_LGBMCheckClassificationTargets = None
_LGBMComputeSampleWeight = None
_LGBMValidateData = None
_sklearn_ClassifierTags = None
_sklearn_RegressorTags = None
_sklearn_version = None

# additional scikit-learn imports only for type hints
Expand Down
15 changes: 5 additions & 10 deletions python-package/lightgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@
_LGBMModelBase,
_LGBMRegressorBase,
_LGBMValidateData,
_sklearn_ClassifierTags,
_sklearn_RegressorTags,
_sklearn_version,
dt_DataTable,
pd_DataFrame,
Expand Down Expand Up @@ -726,7 +724,7 @@ def __sklearn_tags__(self) -> Optional["_sklearn_Tags"]:
# take whatever tags are provided by BaseEstimator, then modify
# them with LightGBM-specific values
return self._update_sklearn_tags_from_dict(
tags=_LGBMModelBase.__sklearn_tags__(self),
tags=super().__sklearn_tags__(),
tags_dict=self._more_tags(),
)

Expand Down Expand Up @@ -1298,10 +1296,7 @@ def _more_tags(self) -> Dict[str, Any]:
return tags

def __sklearn_tags__(self) -> "_sklearn_Tags":
tags = LGBMModel.__sklearn_tags__(self)
tags.estimator_type = "regressor"
tags.regressor_tags = _sklearn_RegressorTags(multi_label=False)
return tags
return super().__sklearn_tags__()

def fit( # type: ignore[override]
self,
Expand Down Expand Up @@ -1360,9 +1355,9 @@ def _more_tags(self) -> Dict[str, Any]:
return tags

def __sklearn_tags__(self) -> "_sklearn_Tags":
tags = LGBMModel.__sklearn_tags__(self)
tags.estimator_type = "classifier"
tags.classifier_tags = _sklearn_ClassifierTags(multi_class=True, multi_label=False)
tags = super().__sklearn_tags__()
tags.classifier_tags.multi_class = True
tags.classifier_tags.multi_label = False
return tags

def fit( # type: ignore[override]
Expand Down
6 changes: 6 additions & 0 deletions tests/python_package_test/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1488,6 +1488,12 @@ def test_sklearn_tags_should_correctly_reflect_lightgbm_specific_values(estimato
assert sklearn_tags.input_tags.allow_nan is True
assert sklearn_tags.input_tags.sparse is True
assert sklearn_tags.target_tags.one_d_labels is True
if estimator_class is lgb.LGBMClassifier:
assert sklearn_tags.estimator_type == "classifier"
assert sklearn_tags.classifier_tags.multi_class is True
assert sklearn_tags.classifier_tags.multi_label is False
elif estimator_class is lgb.LGBMRegressor:
assert sklearn_tags.estimator_type == "regressor"


@pytest.mark.parametrize("task", all_tasks)
Expand Down

0 comments on commit 6e0b0a8

Please sign in to comment.