
Hist training with checkpointing is non-deterministic based on subsample #10324

andrew-esteban-imc opened this issue May 24, 2024 · 3 comments


@andrew-esteban-imc

Hi there,

We have found that despite setting a seed for our hist training, we get non-deterministic results when resuming training from a checkpoint. A reproducer can be seen below:

import numpy as np
import xgboost as xgb
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from xgboost.callback import TrainingCallback

# Generate random sample data
X, y = make_classification(n_samples=1000, n_features=20, n_informative=10, n_redundant=10, random_state=42)

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)


# Ensure model has correct epoch upon restart
class XgbCheckpointCallback(TrainingCallback):
    def __init__(self, start_epoch: int):
        super().__init__()
        self.start_epoch = start_epoch

    def before_training(self, model: xgb.Booster):
        # Wrap Booster.update so every iteration index is shifted by the
        # number of epochs already completed before the checkpoint.
        self._prev_update = model.update

        def update(dtrain, iteration: int, fobj=None) -> None:
            return self._prev_update(dtrain, iteration + self.start_epoch, fobj)

        model.update = update
        return model

    def after_training(self, model):
        # Restore the original update method once training finishes.
        model.update = self._prev_update
        return model


# Set the parameters for XGBoost training
params = {
    "verbosity": 0,  # the old "silent" flag was removed in XGBoost 1.0
    "tree_method": "hist",
    "seed": 1,
    "base_score": -0.9,
    "max_depth": 3,
    "learning_rate": 0.02,
    "lambda": 500,
    "subsample": 0.9,
}

# Train the XGBoost model using the hist algorithm
dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test, label=y_test)

# Run 1: train for 10 rounds, uninterrupted (the baseline)
bst1 = xgb.train(params, dtrain, 10, evals=[(dtrain, 'train'), (dtest, 'test')], early_stopping_rounds=5)

# Save the model after the first run
bst1.save_model('first_run_model.json')  # a .json extension avoids the deprecated binary format

# Run 2: train for only 5 rounds, then checkpoint
bst2 = xgb.train(params, dtrain, 5, evals=[(dtrain, 'train'), (dtest, 'test')], early_stopping_rounds=5)

# Save the model after the second run
bst2.save_model('second_run_model.json')

# Run 3: resume from the run-2 checkpoint and train for 5 more rounds
bst3 = xgb.train(params, dtrain, 5, evals=[(dtrain, 'train'), (dtest, 'test')], early_stopping_rounds=5,
                 xgb_model=bst2, callbacks=[XgbCheckpointCallback(5)])

# Save the model after the third run
bst3.save_model('third_run_model.json')

# Make predictions on the test set for each run
preds1 = bst1.predict(dtest)
preds2 = bst2.predict(dtest)
preds3 = bst3.predict(dtest)

# Compare run 1 (10 uninterrupted rounds) against run 3 (5 + 5 resumed rounds);
# these are mean predictions, not accuracies
print("Mean test-set prediction for run 1: {:.5f}".format(np.mean(preds1)))
print("Mean test-set prediction for run 3: {:.5f}".format(np.mean(preds3)))

We use XgbCheckpointCallback to work around a related problem: when training resumes from a checkpoint, the iteration counter restarts at zero instead of continuing from the epoch the checkpoint reached. You can remove the callback, but then setting any of the colsample_* parameters to a value below 1.0 produces the same non-determinism.
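
For reference, the variant meant above is a one-line change to the params dict; a sketch, with 0.9 as an arbitrary illustrative value:

# Column subsampling instead of row subsampling; with this, the mismatch
# appears even without XgbCheckpointCallback.
params_colsample = {**params, "subsample": 1.0, "colsample_bytree": 0.9}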

When tree_method is set to exact, the uninterrupted and resumed models are identical. With hist and subsample = 1.0 they are also identical. Only with hist and subsample < 1.0 do the results diverge.
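
A quick way to confirm the divergence, meant to be run after the reproducer above (the tolerance is an arbitrary choice):

# True for exact, or for hist with subsample=1.0; False for hist with
# subsample < 1.0, which is the bug being reported.
print("predictions match:", np.allclose(preds1, preds3, atol=1e-7))
print("max abs difference:", np.abs(preds1 - preds3).max())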

I've seen #6711, but that seems to be somewhat different in nature.

@trivialfis
Member

trivialfis commented May 28, 2024

Thank you for raising the issue. We will revisit this after #6711 is resolved by removing the global random engine. Meanwhile, I will mark this one as a bug to increase the priority.

@andrew-esteban-imc
Author

Hi again,
Just wondering if there's anything I can do to help out with resolving this? As I mentioned earlier, my C++ abilities are somewhat lacking, but if there is something I can help with, I'd gladly have a crack.

@trivialfis
Member

We are currently working on it, among other items, in #10354. We want to use a booster-local random state instead of a global random state. The problem is that we can't reliably preserve the random state in the model file. As a result, once the PR is finished the results will be deterministic; however, it's not necessarily true that training two smaller boosters in sequence will match training a single larger booster. We are still investigating.
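
To illustrate the distinction in plain Python terms (this is only an analogy using NumPy generators, not XGBoost's actual implementation): with a global engine, re-seeding on resume replays the start of the random stream, whereas a per-booster generator's state can be snapshotted at the checkpoint and restored so the stream continues exactly.

import numpy as np

# Global engine analogy: resuming re-seeds the shared stream, so rounds 6-10
# of the resumed run draw the same numbers that rounds 1-5 already used.
np.random.seed(1)
uninterrupted = [np.random.random() for _ in range(10)]

np.random.seed(1)
first_half = [np.random.random() for _ in range(5)]
np.random.seed(1)  # the resumed session seeds the global engine again
second_half = [np.random.random() for _ in range(5)]
print(uninterrupted == first_half + second_half)  # False

# Booster-local analogy: snapshot the generator state at the checkpoint and
# restore it on resume, so the stream continues where it left off.
rng = np.random.default_rng(1)
first_half = [rng.random() for _ in range(5)]
state = rng.bit_generator.state  # hypothetically saved alongside the model

rng_resumed = np.random.default_rng()
rng_resumed.bit_generator.state = state
second_half = [rng_resumed.random() for _ in range(5)]

rng_full = np.random.default_rng(1)
print([rng_full.random() for _ in range(10)] == first_half + second_half)  # True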
