Hi there,

We have found that, despite setting a seed for our `hist` training, we get non-deterministic results when resuming training from a checkpoint. A reproducer can be seen below:
```python
import numpy as np
import xgboost as xgb
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from xgboost.callback import TrainingCallback

# Generate random sample data
X, y = make_classification(n_samples=1000, n_features=20, n_informative=10,
                           n_redundant=10, random_state=42)

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)


# Ensure the model sees the correct epoch upon restart
class XgbCheckpointCallback(TrainingCallback):
    def __init__(self, start_epoch: int):
        super().__init__()
        self.start_epoch = start_epoch

    def before_training(self, model: xgb.Booster):
        self._prev_update = model.update

        def update(dtrain, iteration: int, fobj=None) -> None:
            # Shift the iteration so resumed training continues from the
            # checkpointed epoch instead of restarting at zero.
            return self._prev_update(dtrain, iteration + self.start_epoch, fobj)

        model.update = update
        return model

    def after_training(self, model):
        model.update = self._prev_update
        return model


# Set the parameters for XGBoost training
params = {
    "silent": True,
    "tree_method": "hist",
    "seed": 1,
    "base_score": -0.9,
    "max_depth": 3,
    "learning_rate": 0.02,
    "lambda": 500,
    "subsample": 0.9,
}

# Build the DMatrix inputs for the hist algorithm
dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test, label=y_test)

# First run: train uninterrupted for 10 rounds
bst1 = xgb.train(params, dtrain, 10, evals=[(dtrain, "train"), (dtest, "test")],
                 early_stopping_rounds=5, feval=None)
bst1.save_model("first_run_model")

# Second run: train for only 5 rounds, then checkpoint
bst2 = xgb.train(params, dtrain, 5, evals=[(dtrain, "train"), (dtest, "test")],
                 early_stopping_rounds=5, feval=None)
bst2.save_model("second_run_model")

# Third run: resume from the second run and train for 5 more rounds
bst3 = xgb.train(params, dtrain, 5, evals=[(dtrain, "train"), (dtest, "test")],
                 early_stopping_rounds=5, feval=None,
                 xgb_model=bst2, callbacks=[XgbCheckpointCallback(5)])
bst3.save_model("third_run_model")

# Make predictions on the test set for each run
preds1 = bst1.predict(dtest)
preds2 = bst2.predict(dtest)
preds3 = bst3.predict(dtest)

# Compare the uninterrupted (10-round) run with the resumed (5 + 5) run
print("Mean prediction for run 1: {:.5f}".format(np.mean(preds1)))
print("Mean prediction for run 3: {:.5f}".format(np.mean(preds3)))
```
We make use of `XgbCheckpointCallback` to work around a similar issue whereby restarting from a checkpoint ignores the epoch the checkpoint had reached. You can remove it, but then setting any of the `colsample_*` params to a value below `1.0` will produce the same issue, as sketched below.
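For illustration, a minimal variant sketch reusing `params` and `dtrain` from the reproducer above (`colsample_bytree=0.8` is an arbitrary value below `1.0`, not from the original report):

```python
# Hypothetical variant: no XgbCheckpointCallback, row subsampling disabled,
# column subsampling enabled instead.
params_colsample = {**params, "subsample": 1.0, "colsample_bytree": 0.8}

bst_full = xgb.train(params_colsample, dtrain, 10)                         # uninterrupted
bst_half = xgb.train(params_colsample, dtrain, 5)                          # first half
bst_resumed = xgb.train(params_colsample, dtrain, 5, xgb_model=bst_half)   # resumed
```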
When `tree_method` is set to `exact`, the uninterrupted model and the interrupted model are identical. When `tree_method` is set to `hist` and `subsample` is set to `1.0`, they are also identical. When running with `hist` and `subsample < 1.0`, however, the results differ.
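A quick check, as a sketch reusing `preds1` and `preds3` from the reproducer above:

```python
# Per the observations above:
#   exact, or hist with subsample=1.0 -> identical predictions (True)
#   hist with subsample < 1.0         -> diverging predictions (False)
print("Runs 1 and 3 identical:", np.allclose(preds1, preds3))
```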
I've seen #6711, but that seems to be somewhat different in nature.
Thank you for raising the issue. We will revisit this after #6711 is resolved by removing the global random engine. Meanwhile, I will mark this one as a bug to increase the priority.
Hi again,
Just wondering if there's anything I can do to help out with resolving this? As I mentioned earlier, my C++ abilities are somewhat lacking, but if there's something I can help with I'd gladly have a crack.
We are currently working on it, among other items, in #10354. We want to use a booster-local random state instead of a global random state. The problem is that we can't reliably preserve the random state in the model file. As a result, once that PR is finished the results will be deterministic. However, it's not necessarily true that training two smaller boosters will produce the same result as training a single large booster. We are still investigating.
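For intuition, here is a toy sketch of the failure mode in plain NumPy (an analogy, not XGBoost internals): with a single seeded global engine, rounds 6-10 of an uninterrupted run consume a later slice of the random stream, whereas a resumed process reseeds and starts the stream from the beginning.

```python
import numpy as np

# Uninterrupted run: one process, rounds 1-10 draw from one seeded stream.
rng = np.random.default_rng(seed=1)
uninterrupted = [rng.random() for _ in range(10)]

# Interrupted run: rounds 1-5, then checkpoint and exit.
rng = np.random.default_rng(seed=1)
first_half = [rng.random() for _ in range(5)]

# Resumed run: a fresh process reseeds the global engine, so rounds 6-10
# replay the draws that rounds 1-5 already consumed.
rng = np.random.default_rng(seed=1)
second_half = [rng.random() for _ in range(5)]

print(uninterrupted[5:] == second_half)  # False: the resumed rounds subsample differently
```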