-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adapt to scikit-learn 1.6 estimator tag changes #11021
base: master
Are you sure you want to change the base?
Changes from 11 commits
79ed32c
3106cf1
3af44be
a9e30b4
6a12576
52e6d83
816667a
abfc6a6
ef725c1
d845922
b7564a1
8364e92
a511848
27a1bcc
18c602f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -63,6 +63,8 @@ disable = [ | |
"import-error", | ||
"attribute-defined-outside-init", | ||
"import-outside-toplevel", | ||
"too-few-public-methods", | ||
"too-many-ancestors", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Switching the placeholder classes in the scikit-learn-is-not-available branch of
It seems that there are already many other places in this codebase where those warnings are being suppressed with
I don't feel that strongly... if you'd prefer to keep suppressing individual cases of these, please let me know and I'll happily switch back to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looks good to me. @RAMitchell find the pylint checks helpful. I myself prefer mypy checks and think the pylint is not particularly suitable for ML libraries like XGBoost. In general, I don't have a strong opinion about these "structural" or naming warnings and care mostly about warnings like unused imports or use before definition. |
||
"too-many-nested-blocks", | ||
"unsubscriptable-object", | ||
"useless-object-inheritance" | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -43,6 +43,7 @@ def lazy_isinstance(instance: Any, module: str, name: str) -> bool: | |||||||||||||
|
||||||||||||||
# sklearn | ||||||||||||||
try: | ||||||||||||||
from sklearn import __version__ as _sklearn_version | ||||||||||||||
from sklearn.base import BaseEstimator as XGBModelBase | ||||||||||||||
from sklearn.base import ClassifierMixin as XGBClassifierBase | ||||||||||||||
from sklearn.base import RegressorMixin as XGBRegressorBase | ||||||||||||||
|
@@ -55,20 +56,43 @@ def lazy_isinstance(instance: Any, module: str, name: str) -> bool: | |||||||||||||
from sklearn.cross_validation import KFold as XGBKFold | ||||||||||||||
from sklearn.cross_validation import StratifiedKFold as XGBStratifiedKFold | ||||||||||||||
|
||||||||||||||
# sklearn.utils Tags types can be imported unconditionally once | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can do that once the next sklearn is published. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think we should. That'd effectively raise Because it would result in xgboost/python-package/xgboost/compat.py Lines 60 to 61 in 5826b02
Which would make all the estimators unusable on those versions. xgboost/python-package/xgboost/sklearn.py Lines 754 to 757 in 5826b02
|
||||||||||||||
# xgboost's minimum scikit-learn version is 1.6 or higher | ||||||||||||||
try: | ||||||||||||||
from sklearn.utils import ClassifierTags as _sklearn_ClassifierTags | ||||||||||||||
from sklearn.utils import RegressorTags as _sklearn_RegressorTags | ||||||||||||||
from sklearn.utils import Tags as _sklearn_Tags | ||||||||||||||
except ImportError: | ||||||||||||||
_sklearn_ClassifierTags = object | ||||||||||||||
_sklearn_RegressorTags = object | ||||||||||||||
_sklearn_Tags = object | ||||||||||||||
|
||||||||||||||
SKLEARN_INSTALLED = True | ||||||||||||||
|
||||||||||||||
except ImportError: | ||||||||||||||
SKLEARN_INSTALLED = False | ||||||||||||||
|
||||||||||||||
# used for compatibility without sklearn | ||||||||||||||
XGBModelBase = object | ||||||||||||||
XGBClassifierBase = object | ||||||||||||||
XGBRegressorBase = object | ||||||||||||||
LabelEncoder = object | ||||||||||||||
class XGBModelBase: # type: ignore[no-redef] | ||||||||||||||
"""Dummy class for sklearn.base.BaseEstimator.""" | ||||||||||||||
|
||||||||||||||
class XGBClassifierBase: # type: ignore[no-redef] | ||||||||||||||
"""Dummy class for sklearn.base.ClassifierMixin.""" | ||||||||||||||
|
||||||||||||||
class XGBRegressorBase: # type: ignore[no-redef] | ||||||||||||||
"""Dummy class for sklearn.base.RegressorMixin.""" | ||||||||||||||
|
||||||||||||||
class LabelEncoder: # type: ignore[no-redef] | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can remove the label encoder for now. It's not used. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh great! I just removed that in a511848 Noticed that |
||||||||||||||
"""Dummy class for sklearn.preprocessing.LabelEncoder.""" | ||||||||||||||
|
||||||||||||||
XGBKFold = None | ||||||||||||||
XGBStratifiedKFold = None | ||||||||||||||
|
||||||||||||||
_sklearn_ClassifierTags = object | ||||||||||||||
_sklearn_RegressorTags = object | ||||||||||||||
_sklearn_Tags = object | ||||||||||||||
_sklearn_version = object | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
_logger = logging.getLogger(__name__) | ||||||||||||||
|
||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -35,6 +35,10 @@ | |||
XGBClassifierBase, | ||||
XGBModelBase, | ||||
XGBRegressorBase, | ||||
_sklearn_ClassifierTags, | ||||
_sklearn_RegressorTags, | ||||
_sklearn_Tags, | ||||
_sklearn_version, | ||||
import_cupy, | ||||
) | ||||
from .config import config_context | ||||
|
@@ -54,7 +58,7 @@ | |||
from .training import train | ||||
|
||||
|
||||
class XGBRankerMixIn: # pylint: disable=too-few-public-methods | ||||
class XGBRankerMixIn: | ||||
"""MixIn for ranking, defines the _estimator_type usually defined in scikit-learn | ||||
base classes. | ||||
|
||||
|
@@ -79,7 +83,7 @@ def _can_use_qdm(tree_method: Optional[str], device: Optional[str]) -> bool: | |||
return tree_method in ("hist", "gpu_hist", None, "auto") and not_sycl | ||||
|
||||
|
||||
class _SklObjWProto(Protocol): # pylint: disable=too-few-public-methods | ||||
class _SklObjWProto(Protocol): | ||||
def __call__( | ||||
self, | ||||
y_true: ArrayLike, | ||||
|
@@ -805,6 +809,41 @@ def _more_tags(self) -> Dict[str, bool]: | |||
tags["non_deterministic"] = True | ||||
return tags | ||||
|
||||
@staticmethod | ||||
def _update_sklearn_tags_from_dict( | ||||
*, | ||||
tags: _sklearn_Tags, | ||||
tags_dict: Dict[str, bool], | ||||
) -> _sklearn_Tags: | ||||
"""Update ``sklearn.utils.Tags`` inherited from ``scikit-learn`` base classes. | ||||
|
||||
``scikit-learn`` 1.6 introduced a dataclass-based interface for estimator tags. | ||||
ref: https://github.com/scikit-learn/scikit-learn/pull/29677 | ||||
|
||||
This method handles updating that instance based on the values in ``self._more_tags()``. | ||||
""" | ||||
tags.non_deterministic = tags_dict.get("non_deterministic", False) | ||||
tags.no_validation = tags_dict["no_validation"] | ||||
tags.input_tags.allow_nan = tags_dict["allow_nan"] | ||||
return tags | ||||
|
||||
def __sklearn_tags__(self) -> _sklearn_Tags: | ||||
# XGBModelBase.__sklearn_tags__() cannot be called unconditionally, | ||||
# because that method isn't defined for scikit-learn<1.6 | ||||
if not hasattr(XGBModelBase, "__sklearn_tags__"): | ||||
err_msg = ( | ||||
"__sklearn_tags__() should not be called when using scikit-learn<1.6. " | ||||
f"Detected version: {_sklearn_version}" | ||||
) | ||||
raise AttributeError(err_msg) | ||||
|
||||
# take whatever tags are provided by BaseEstimator, then modify | ||||
# them with XGBoost-specific values | ||||
return self._update_sklearn_tags_from_dict( | ||||
tags=XGBModelBase.__sklearn_tags__(self), # pylint: disable=no-member | ||||
tags_dict=self._more_tags(), | ||||
) | ||||
|
||||
def __sklearn_is_fitted__(self) -> bool: | ||||
return hasattr(self, "_Booster") | ||||
|
||||
|
@@ -898,13 +937,30 @@ def get_params(self, deep: bool = True) -> Dict[str, Any]: | |||
"""Get parameters.""" | ||||
# Based on: https://stackoverflow.com/questions/59248211 | ||||
# The basic flow in `get_params` is: | ||||
# 0. Return parameters in subclass first, by using inspect. | ||||
# 1. Return parameters in `XGBModel` (the base class). | ||||
# 0. Return parameters in subclass (self.__class__) first, by using inspect. | ||||
# 1. Return parameters in all parent classes (especially `XGBModel`). | ||||
# 2. Return whatever in `**kwargs`. | ||||
# 3. Merge them. | ||||
# | ||||
# This needs to accommodate being called recursively in the following | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you please help add a test for this? The hierarchy and the Python introspection are getting a bit confusing now. ;-( There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure! I just added one in a511848, let me know if there are other conditions you'd like to see tested. Between that and the existing test: xgboost/tests/python/test_with_sklearn.py Line 758 in 5826b02
I think this behavior should be well-covered. |
||||
# inheritance graphs (and similar for classification and ranking): | ||||
# | ||||
# XGBRFRegressor -> XGBRegressor -> XGBModel -> BaseEstimator | ||||
# XGBRegressor -> XGBModel -> BaseEstimator | ||||
# XGBModel -> BaseEstimator | ||||
# | ||||
params = super().get_params(deep) | ||||
cp = copy.copy(self) | ||||
cp.__class__ = cp.__class__.__bases__[0] | ||||
# if the immediate parent is a mixin, skip it (mixins don't define get_params()) | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you think it's more general to check for the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes that's a good idea! Just pushed 18c602f implementing that. |
||||
if cp.__class__.__bases__[0] in ( | ||||
XGBClassifierBase, | ||||
XGBRankerMixIn, | ||||
XGBRegressorBase, | ||||
): | ||||
cp.__class__ = cp.__class__.__bases__[1] | ||||
# otherwise, run get_params() from the immediate parent class | ||||
else: | ||||
cp.__class__ = cp.__class__.__bases__[0] | ||||
params.update(cp.__class__.get_params(cp, deep)) | ||||
# if kwargs is a dict, update params accordingly | ||||
if hasattr(self, "kwargs") and isinstance(self.kwargs, dict): | ||||
|
@@ -1481,7 +1537,7 @@ def _cls_predict_proba(n_classes: int, prediction: PredtT, vstack: Callable) -> | |||
Number of boosting rounds. | ||||
""", | ||||
) | ||||
class XGBClassifier(XGBModel, XGBClassifierBase): | ||||
class XGBClassifier(XGBClassifierBase, XGBModel): | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As of scikit-learn/scikit-learn#30234 (which will be in
That check is new, but it enforced behavior that's been documented in
That new check error led to these inheritance-order changes, which led to the |
||||
# pylint: disable=missing-docstring,invalid-name,too-many-instance-attributes | ||||
@_deprecate_positional_args | ||||
def __init__( | ||||
|
@@ -1497,6 +1553,15 @@ def _more_tags(self) -> Dict[str, bool]: | |||
tags["multilabel"] = True | ||||
return tags | ||||
|
||||
def __sklearn_tags__(self) -> _sklearn_Tags: | ||||
tags = XGBModel.__sklearn_tags__(self) | ||||
tags.estimator_type = "classifier" | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need this if we inherit the classifier mixin? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah you are totally right, I don't think we do: Removed this, the corresponding |
||||
tags_dict = self._more_tags() | ||||
tags.classifier_tags = _sklearn_ClassifierTags( | ||||
multi_label=tags_dict["multilabel"] | ||||
) | ||||
return tags | ||||
|
||||
@_deprecate_positional_args | ||||
def fit( | ||||
self, | ||||
|
@@ -1769,7 +1834,7 @@ def fit( | |||
"Implementation of the scikit-learn API for XGBoost regression.", | ||||
["estimators", "model", "objective"], | ||||
) | ||||
class XGBRegressor(XGBModel, XGBRegressorBase): | ||||
class XGBRegressor(XGBRegressorBase, XGBModel): | ||||
# pylint: disable=missing-docstring | ||||
@_deprecate_positional_args | ||||
def __init__( | ||||
|
@@ -1783,6 +1848,15 @@ def _more_tags(self) -> Dict[str, bool]: | |||
tags["multioutput_only"] = False | ||||
return tags | ||||
|
||||
def __sklearn_tags__(self) -> _sklearn_Tags: | ||||
tags = XGBModel.__sklearn_tags__(self) | ||||
tags.estimator_type = "regressor" | ||||
tags_dict = self._more_tags() | ||||
tags.regressor_tags = _sklearn_RegressorTags() | ||||
tags.target_tags.multi_output = tags_dict["multioutput"] | ||||
tags.target_tags.single_output = not tags_dict["multioutput_only"] | ||||
return tags | ||||
|
||||
|
||||
@xgboost_model_doc( | ||||
"scikit-learn API for XGBoost random forest regression.", | ||||
|
@@ -1910,7 +1984,7 @@ def _get_qid( | |||
`qid` can be a special column of input `X` instead of a separated parameter, see | ||||
:py:meth:`fit` for more info.""", | ||||
) | ||||
class XGBRanker(XGBModel, XGBRankerMixIn): | ||||
class XGBRanker(XGBRankerMixIn, XGBModel): | ||||
# pylint: disable=missing-docstring,too-many-arguments,invalid-name | ||||
@_deprecate_positional_args | ||||
def __init__(self, *, objective: str = "rank:ndcg", **kwargs: Any): | ||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,7 @@ | |
import os | ||
import pickle | ||
import random | ||
import re | ||
import tempfile | ||
import warnings | ||
from typing import Callable, Optional | ||
|
@@ -1517,7 +1518,7 @@ def test_tags() -> None: | |
assert tags["multioutput"] is True | ||
assert tags["multioutput_only"] is False | ||
|
||
for clf in [xgb.XGBClassifier()]: | ||
for clf in [xgb.XGBClassifier(), xgb.XGBRFClassifier()]: | ||
tags = clf._more_tags() | ||
assert "multioutput" not in tags | ||
assert tags["multilabel"] is True | ||
|
@@ -1526,6 +1527,58 @@ def test_tags() -> None: | |
assert "multioutput" not in tags | ||
|
||
|
||
# the try-excepts in this test should be removed once xgboost's | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was thinking that it's useful to check that the exact, expected If you read that and still think |
||
# minimum supported scikit-learn version is at least 1.6 | ||
def test_sklearn_tags(): | ||
|
||
def _assert_has_xgbmodel_tags(tags): | ||
# values set by XGBModel.__sklearn_tags__() | ||
assert tags.non_deterministic is False | ||
assert tags.no_validation is True | ||
assert tags.input_tags.allow_nan is True | ||
|
||
for reg in [xgb.XGBRegressor(), xgb.XGBRFRegressor()]: | ||
try: | ||
# if no AttributeError was thrown, we must be using scikit-learn>=1.6, | ||
# and so the actual effects of __sklearn_tags__() should be tested | ||
tags = reg.__sklearn_tags__() | ||
_assert_has_xgbmodel_tags(tags) | ||
# regressor-specific values | ||
assert tags.estimator_type == "regressor" | ||
assert tags.regressor_tags is not None | ||
assert tags.classifier_tags is None | ||
assert tags.target_tags.multi_output is True | ||
assert tags.target_tags.single_output is True | ||
except AttributeError as err: | ||
# only the exact error we expected to be raised should be raised | ||
assert bool(re.search(r"__sklearn_tags__.* should not be called", str(err))) | ||
|
||
for clf in [xgb.XGBClassifier(), xgb.XGBRFClassifier()]: | ||
try: | ||
# if no AttributeError was thrown, we must be using scikit-learn>=1.6, | ||
# and so the actual effects of __sklearn_tags__() should be tested | ||
tags = clf.__sklearn_tags__() | ||
_assert_has_xgbmodel_tags(tags) | ||
# classifier-specific values | ||
assert tags.estimator_type == "classifier" | ||
assert tags.regressor_tags is None | ||
assert tags.classifier_tags is not None | ||
assert tags.classifier_tags.multi_label is True | ||
except AttributeError as err: | ||
# only the exact error we expected to be raised should be raised | ||
assert bool(re.search(r"__sklearn_tags__.* should not be called", str(err))) | ||
|
||
for rnk in [xgb.XGBRanker(),]: | ||
try: | ||
# if no AttributeError was thrown, we must be using scikit-learn>=1.6, | ||
# and so the actual effects of __sklearn_tags__() should be tested | ||
tags = rnk.__sklearn_tags__() | ||
_assert_has_xgbmodel_tags(tags) | ||
except AttributeError as err: | ||
# only the exact error we expected to be raised should be raised | ||
assert bool(re.search(r"__sklearn_tags__.* should not be called", str(err))) | ||
|
||
|
||
def test_doc_link() -> None: | ||
for est in [ | ||
xgb.XGBRegressor(), | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Noticed some model files left behind from running all the Python tests locally while developing this. These
.gitignore
rules prevent checking them into source control.