FEAT allow metadata to be transformed in a Pipeline #28901
Conversation
So for simple cases where metadata is only used in `fit`, the behavior is straightforward. Specifically, in this test:

```python
@pytest.mark.usefixtures("enable_slep006")
@pytest.mark.parametrize("method", ["fit", "fit_transform"])
def test_transform_input_pipeline(method):
    """Test that with transform_input, data is correctly transformed for each step."""

    def get_transformer(registry, sample_weight, metadata):
        """Get a transformer with requests set."""
        return (
            ConsumingTransformer(registry=registry)
            .set_fit_request(sample_weight=sample_weight, metadata=metadata)
            .set_transform_request(sample_weight=sample_weight, metadata=metadata)
        )

    def get_pipeline():
        """Get a pipeline and corresponding registries.

        The pipeline has 4 steps, with different request values set to test
        different cases. One is aliased.
        """
        registry_1, registry_2, registry_3, registry_4 = (
            _Registry(),
            _Registry(),
            _Registry(),
            _Registry(),
        )
        pipe = make_pipeline(
            get_transformer(registry_1, sample_weight=True, metadata=True),
            get_transformer(registry_2, sample_weight=False, metadata=False),
            get_transformer(registry_3, sample_weight=True, metadata=True),
            get_transformer(registry_4, sample_weight="other_weights", metadata=True),
            transform_input=["sample_weight"],
        )
        return pipe, registry_1, registry_2, registry_3, registry_4

    def check_metadata(registry, methods, **metadata):
        """Check that the right metadata was recorded for the given methods."""
        assert registry
        for estimator in registry:
            for method in methods:
                check_recorded_metadata(
                    estimator,
                    method=method,
                    **metadata,
                )

    X = np.array([[1, 2], [3, 4]])
    y = np.array([0, 1])
    sample_weight = np.array([[1, 2]])
    other_weights = np.array([[30, 40]])
    metadata = np.array([[100, 200]])

    pipe, registry_1, registry_2, registry_3, registry_4 = get_pipeline()
    pipe.fit(
        X,
        y,
        sample_weight=sample_weight,
        other_weights=other_weights,
        metadata=metadata,
    )

    check_metadata(
        registry_1, ["fit", "transform"], sample_weight=sample_weight, metadata=metadata
    )
    check_metadata(registry_2, ["fit", "transform"])
    check_metadata(
        registry_3,
        ["fit", "transform"],
        sample_weight=sample_weight + 2,
        metadata=metadata,
    )
    check_metadata(
        registry_4,
        method.split("_"),  # ["fit", "transform"] if "fit_transform", ["fit"] otherwise
        sample_weight=other_weights + 3,
        metadata=metadata,
    )
```

Step 3 receives transformed data in its `fit` and `transform` methods. The question is, what should be the expected behavior? Do we want …
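For context, a minimal sketch of the semantics the expectations above encode, assuming each `ConsumingTransformer.transform` adds 1 to its input (which is what the `+ 2` and `+ 3` assertions imply); `AddOne` is a hypothetical stand-in, not scikit-learn API. With `transform_input=["sample_weight"]`, step *i* receives the weights after every earlier, already-fitted step has transformed them:

```python
import numpy as np

class AddOne:
    """Hypothetical stand-in: a fitted transformer whose transform adds 1."""

    def transform(self, X):
        return X + 1

steps = [AddOne() for _ in range(4)]
sample_weight = np.array([[1, 2]])
for i, step in enumerate(steps, start=1):
    # What step i sees as its sample_weight (if it requests it): the original
    # value already transformed by every earlier step in the pipeline.
    print(f"step {i} sees sample_weight = {sample_weight}")
    sample_weight = step.transform(sample_weight)
```

This is why `registry_3` records `sample_weight + 2` (two steps before it) and `registry_4` records `other_weights + 3`.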
Actually, in `TransformerMixin.fit_transform` we have:

```python
if _routing_enabled():
    transform_params = self.get_metadata_routing().consumes(
        method="transform", params=fit_params.keys()
    )
    if transform_params:
        warnings.warn(
            (
                f"This object ({self.__class__.__name__}) has a `transform`"
                " method which consumes metadata, but `fit_transform` does not"
                " forward metadata to `transform`. Please implement a custom"
                " `fit_transform` method to forward metadata to `transform` as"
                " well. Alternatively, you can explicitly do"
                " `set_transform_request` and set all values to `False` to"
                " disable metadata routed to `transform`, if that's an option."
            ),
            UserWarning,
        )
```

and we never send anything to `transform` from `fit_transform`. However, third party transformers can have their own `fit_transform` …
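For illustration, a sketch of what that warning asks a third-party transformer to do: implement a custom `fit_transform` that forwards the metadata to `transform` as well. The class is hypothetical and omits the `set_fit_request`/`set_transform_request` routing flags for brevity:

```python
from sklearn.base import BaseEstimator, TransformerMixin

class MetadataConsumingTransformer(TransformerMixin, BaseEstimator):
    """Hypothetical transformer consuming `metadata` in both fit and transform."""

    def fit(self, X, y=None, metadata=None):
        # ... use metadata while fitting ...
        return self

    def transform(self, X, metadata=None):
        # ... use metadata while transforming ...
        return X

    def fit_transform(self, X, y=None, **params):
        # Unlike the default TransformerMixin.fit_transform, forward the
        # received metadata to transform as well, not only to fit.
        return self.fit(X, y, **params).transform(X, **params)
```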
Another question is, do we want to have this syntactic sugar?

```python
pipe = make_pipeline(
    StandardScaler(),
    HistGradientBoostingClassifier(..., early_stopping=True),
).fit(X, y, X_val, y_val)
```

The above code would: …
It wouldn't change what we have now implemented in this PR. For that to happen, HGBC would need to have:

```python
class HistGradientBoostingClassifier(...):
    ...

    def get_metadata_routing(self):
        routing = super().get_metadata_routing()
        if self.early_stopping:
            routing.fit.add(X_val=True, y_val=True)

    def __sklearn_get_transforming_data__(self):
        return ["X_val"]
```

cc @glemaitre. It goes towards the direction of having more default routing info, as @ogrisel really likes (ref #26179). Note that this could come later separately as an enhancement to this PR.
There is an issue with testing metadata routing in more complex situations (which has come up in this PR) which requires some fixes (adding …).
A partial review.
@adrinjalali Just that you see that at least someone cares.
`sklearn/pipeline.py` (outdated):

```
them. Requirement is defined via :ref:`metadata routing <metadata_routing>`.
This can be used to pass a validation set through the pipeline for instance.

See the example TBD for more details.
```
I'm very keen to see that example, maybe the HGBT early stopping case?
I went and checked if I could use lightgbm, but there the validation set is passed as a list of tuples. No way to process that in a pipeline.

As for HGBT, it would look like this:

```python
from sklearn.datasets import make_regression
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.model_selection import GridSearchCV, ShuffleSplit
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

X, y = make_regression(n_samples=200, n_features=500, n_informative=5, random_state=0)
X[:2,] = X[:2,] + 20

# Validation set chosen before looking at the data.
X_val, y_val = X[:50,], y[:50,]
X, y = X[50:,], y[50:,]

est_gs = GridSearchCV(
    Pipeline(
        (
            StandardScaler(),
            HistGradientBoostingRegressor(
                early_stopping=True,
            ).set_fit_request(X_val=True, y_val=True),
        ),
        # telling pipeline to transform these inputs up to the step which is
        # requesting them
        transform_input=["X_val", "y_val"],
    ),
    param_grid={"histgradientboostingregressor__max_depth": list(range(5))},
    cv=5,
).fit(X, y, X_val, y_val)
# this passes X_val, y_val to Pipeline, and Pipeline knows how to deal with them
```
To confirm my understanding:

```python
est_gs = GridSearchCV(
    Pipeline(
        (
            StandardScaler(),
            HistGradientBoostingRegressor(
                early_stopping=True,
            ).set_fit_request(X_val=True, y_val=True),
        ),
    ),
    param_grid={"histgradientboostingregressor__max_depth": list(range(5))},
    cv=5,
).fit(X, y, X_val, y_val)
```

This would only transform `X` and `y`, whereas your example would now explicitly mark `X_val` and `y_val` (and `X` and `y`) to be transformed as well?
Correct!
Only side remark: Our `HistGradientBoostingRegressor` does not YET support `X_val` and `y_val` in `fit`.
`sklearn/pipeline.py` (outdated):

```
will be transformed.

`all_params` are the metadata passed by the user. Used to call `transform`
on the pipeline itself.
```
Adding a Parameters section to the docstring might help to better understand this method.
I think this now helps.
I checked …

Let's ask them directly: @jameslamb @StrikerRUS @shiyu1994 @trivialfis @hcho3, your opinion would be very appreciated. We are trying to transform metadata on the way to the step of a pipeline where it is needed, e.g. validation data for early stopping in GBTs, see #28901 (comment) (StandardScaler is just for demonstration purposes).
As for the example, @lorentzenchr I'm not sure what you want me to add, since we don't have any estimator in scikit-learn which can use this right now. Do you want a fake, mini estimator using `X_val` in the example? This example makes more sense once we merge the validation set work in HGBT, doesn't it?
`sklearn/tests/test_pipeline.py` (outdated):

```python
def test_transform_tuple_input():
    """Test that if metadata is a tuple of arrays, both arrays are transformed."""

    class Estimator(ClassifierMixin, BaseEstimator):
        def fit(self, X, y, X_val=None, y_val=None):
            assert isinstance(X_val, tuple)
            assert isinstance(y_val, tuple)
            # Here we make sure that each X_val is transformed by the transformer
            assert_array_equal(X_val[0], np.array([[2, 3]]))
            assert_array_equal(y_val[0], np.array([0, 1]))
            assert_array_equal(X_val[1], np.array([[11, 12]]))
            assert_array_equal(y_val[1], np.array([1, 2]))
            return self

    class Transformer(TransformerMixin, BaseEstimator):
        def fit(self, X, y):
            return self

        def transform(self, X):
            return X + 1

    X = np.array([[1, 2]])
    y = np.array([0, 1])
    X_val0 = np.array([[1, 2]])
    y_val0 = np.array([0, 1])
    X_val1 = np.array([[10, 11]])
    y_val1 = np.array([1, 2])
    pipe = Pipeline(
        [
            ("transformer", Transformer()),
            ("estimator", Estimator().set_fit_request(X_val=True, y_val=True)),
        ],
        transform_input=["X_val"],
    )
    pipe.fit(X, y, X_val=(X_val0, X_val1), y_val=(y_val0, y_val1))
```
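As a reading aid: the rule this test asserts is that a metadata key listed in `transform_input` (here `X_val`) is transformed element-wise when its value is a tuple, while keys not listed (here `y_val`) pass through untouched. A hypothetical helper capturing the tuple handling, not the PR's actual code:

```python
def transform_metadata_value(transformer, value):
    """Hypothetical helper: transform one metadata value with a fitted step.

    Tuples are transformed element-wise, so each array in e.g.
    X_val=(X_val0, X_val1) passes through the transformer separately.
    """
    if isinstance(value, tuple):
        return tuple(transformer.transform(v) for v in value)
    return transformer.transform(value)
```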
the branch is updated @jameslamb
`sklearn/pipeline.py` (outdated):

```
them. Requirement is defined via :ref:`metadata routing <metadata_routing>`.
For instance, this can be used to pass a validation set through the pipeline.

See the example TBD for more details.
```
> See the example TBD for more details.

Let's do that later in a different PR.
Looks good to me, one typo in the comment. Will try to test it with XGBoost.
Co-authored-by: Jiaming Yuan <[email protected]>
```
- :class:`pipeline.Pipeline` can now transform metadata up to the step requiring the
  metadata, which can be set using the `transform_input` parameter.
  By `Adrin Jalali`_.
```
Suggested change:

```diff
-  By `Adrin Jalali`_.
+  By `Adrin Jalali`_
```
LGTM. Only nitpicks and small typos.
Co-authored-by: Guillaume Lemaitre <[email protected]>
ping @jeremiedbb for backport in the branch
Initial proposal: #28440 (comment)

xref: #28440 (comment)

This adds `transform_input` as a constructor argument to `Pipeline`. It simply allows transforming metadata with fitted estimators up to the step which needs the metadata.

How does this look?

cc @lorentzenchr @ogrisel @amueller @betatim
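For reference, a minimal usage sketch of the new argument, condensed from the examples discussed above; `AddOneTransformer` and `Estimator` are hypothetical stand-ins, not scikit-learn estimators:

```python
import numpy as np
import sklearn
from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin
from sklearn.pipeline import Pipeline

sklearn.set_config(enable_metadata_routing=True)  # required for set_fit_request

class AddOneTransformer(TransformerMixin, BaseEstimator):
    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return X + 1

class Estimator(ClassifierMixin, BaseEstimator):
    def fit(self, X, y, X_val=None, y_val=None):
        # X_val arrives here already transformed by the previous step
        return self

pipe = Pipeline(
    [
        ("transformer", AddOneTransformer()),
        ("estimator", Estimator().set_fit_request(X_val=True, y_val=True)),
    ],
    # transform X_val (but not y_val) on its way to the final step
    transform_input=["X_val"],
)
pipe.fit(
    np.array([[1, 2]]), np.array([0]),
    X_val=np.array([[1, 2]]), y_val=np.array([0]),
)
```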