Early stopping with gblinear doesn't save the best model for subsequent prediction #10893

Open
sktin opened this issue Oct 15, 2024 · 0 comments
Labels: gblinear (Everything related to the linear model.)

sktin commented Oct 15, 2024

Code to reproduce:

```python
import xgboost as xgb
print(F'{xgb.__version__=}')

from sklearn.datasets import make_classification
from xgboost import XGBClassifier
from sklearn.metrics import roc_auc_score

X, y = make_classification(50000, random_state=0)
model = XGBClassifier(
    booster='gblinear', updater='coord_descent',
    eval_metric='auc', eta=0.01,
    early_stopping_rounds=10, n_estimators=1000000,  # effectively "train until early stopping"
    random_state=0, n_jobs=4
)
# The training set doubles as the eval set, so the logged AUC is directly
# comparable with roc_auc_score on the same data below.
model.fit(X, y, eval_set=[(X, y)], verbose=1)
print(F'{model.best_iteration=}, {model.best_score=}')
# Expected: the AUC at best_iteration. Observed: the AUC of the last iteration.
print(roc_auc_score(y, model.predict_proba(X)[:,1]))
```

Output:

```
xgb.__version__='2.1.1'
[0]	validation_0-auc:0.93851
[1]	validation_0-auc:0.93851
[2]	validation_0-auc:0.93851
[3]	validation_0-auc:0.93851
[4]	validation_0-auc:0.93851
[5]	validation_0-auc:0.93851
[6]	validation_0-auc:0.93851
[7]	validation_0-auc:0.93851
[8]	validation_0-auc:0.93851
[9]	validation_0-auc:0.93851
[10]	validation_0-auc:0.93851
[11]	validation_0-auc:0.93851
[12]	validation_0-auc:0.93851
[13]	validation_0-auc:0.93851
[14]	validation_0-auc:0.93851
[15]	validation_0-auc:0.93850
[16]	validation_0-auc:0.93850
model.best_iteration=6, model.best_score=0.9385135544606076
0.9385032096436587
```

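Passing `iteration_range` to `predict_proba` has no effect either; all three calls below return the AUC of the final model, not the best one: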

```python
# Rounds 0..best_iteration (best_iteration + 1 = 7 rounds)
print(roc_auc_score(y, model.predict_proba(X, iteration_range=(0,7))[:,1]))
# All rounds
print(roc_auc_score(y, model.predict_proba(X, iteration_range=(0,1000000))[:,1]))
# First round only
print(roc_auc_score(y, model.predict_proba(X, iteration_range=(0,1))[:,1]))
```

Output:

```
0.9385032096436587
0.9385032096436587
0.9385032096436587
```
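One plausible explanation, assumed here rather than confirmed in this thread: a tree booster stores one tree per round and sums their outputs, so `iteration_range` can drop trailing trees, whereas gblinear keeps a single weight vector that every round updates in place, leaving nothing to slice. The toy numpy sketch below (hypothetical `fit_toy_booster` and `toy_predict`, not XGBoost code) shows how that storage difference would turn a range argument into a silent no-op.

```python
import numpy as np


def fit_toy_booster(X, y, n_rounds, eta=0.3, keep_per_round=True):
    """Hypothetical toy model, NOT XGBoost internals: gradient descent on squared error.

    keep_per_round=True  stores every round's update separately (like one tree per round).
    keep_per_round=False keeps only the running weight vector (like a linear booster).
    """
    w = np.zeros(X.shape[1])
    per_round = []
    for _ in range(n_rounds):
        delta = -eta * X.T @ (X @ w - y) / len(y)  # one gradient step
        w = w + delta
        if keep_per_round:
            per_round.append(delta)
    return w, per_round


def toy_predict(X, w, per_round, iteration_range=None):
    """Use only rounds [0, end) when per-round updates exist; otherwise ignore the range."""
    if iteration_range is None or not per_round:
        return X @ w  # nothing to slice, so the range is silently ignored
    _, end = iteration_range
    return X @ np.sum(per_round[:end], axis=0)


rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
y = X @ np.array([1.0, -2.0, 0.5])

w_t, hist_t = fit_toy_booster(X, y, n_rounds=20, keep_per_round=True)
w_l, hist_l = fit_toy_booster(X, y, n_rounds=20, keep_per_round=False)

# Additive storage: truncating to the first round changes the prediction.
print(np.allclose(toy_predict(X, w_t, hist_t, (0, 1)), toy_predict(X, w_t, hist_t)))  # False
# In-place storage: the "truncated" prediction is just the final one.
print(np.allclose(toy_predict(X, w_l, hist_l, (0, 1)), toy_predict(X, w_l, hist_l)))  # True
```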

The only workaround I can think of is to refit from scratch with `n_estimators=model.best_iteration+1`, which does reproduce `best_score`:

```python
# Refit from scratch for exactly best_iteration + 1 rounds (no early stopping).
model = XGBClassifier(
    booster='gblinear', updater='coord_descent',
    eval_metric='auc', eta=0.01,
    n_estimators=model.best_iteration+1,
    random_state=0, n_jobs=4
)
model.fit(X, y, eval_set=[(X, y)], verbose=1)
print(roc_auc_score(y, model.predict_proba(X)[:,1]))
```

Output:

```
[0]	validation_0-auc:0.93851
[1]	validation_0-auc:0.93851
[2]	validation_0-auc:0.93851
[3]	validation_0-auc:0.93851
[4]	validation_0-auc:0.93851
[5]	validation_0-auc:0.93851
[6]	validation_0-auc:0.93851
0.9385135544606077
```
trivialfis added the gblinear label on Oct 18, 2024.