
XGBoost Booster Ignores device=cpu When Loading Model from GPU Training #11199

Open

Flash1709 opened this issue Feb 3, 2025 · 12 comments

@Flash1709 commented Feb 3, 2025

Description:
When training a model on the GPU and then loading it on the CPU via xgb.Booster, the device parameter appears to be set correctly according to save_config(), but inference still unexpectedly uses the GPU. This results in high GPU memory usage and slow predictions.

I am using a training-loop function for out-of-sample predictions, but the problem already occurs in the first iteration of the loop.


Environment:

  • XGBoost Version: 2.1.3
  • Python Version: 3.12.8
  • OS: Windows 11

Reproduction Code:

import xgboost as xgb
import numpy as np
import polars as pl
import gc
from typing import Optional 

# Placeholder values
FEATURES = ["feature1", "feature2", "feature3"]  # Example feature columns
TARGETS = "target"  # Example target column
WEIGHTS = "weight"  # Example weight column

# Placeholder frames; substitute real data before running
train_df = pl.DataFrame()
valid_df = pl.DataFrame()
eval_df = pl.DataFrame()


def create_xgb_matricies(train_df: pl.DataFrame, valid_df: pl.DataFrame, eval_df: pl.DataFrame):

    X_train = train_df[FEATURES].to_numpy()
    y_train = train_df[TARGETS].to_numpy()
    w_train = train_df[WEIGHTS].to_numpy()
    dtrain = xgb.DMatrix(X_train, label=y_train, weight=w_train)

    del train_df, X_train, y_train, w_train

    X_valid = valid_df[FEATURES].to_numpy()
    y_valid = valid_df[TARGETS].to_numpy()
    w_valid = valid_df[WEIGHTS].to_numpy()
    dvalid = xgb.DMatrix(X_valid, label=y_valid, weight=w_valid)

    del valid_df, X_valid, y_valid, w_valid

    X_eval = eval_df[FEATURES].to_numpy()
    y_eval = eval_df[TARGETS].to_numpy()
    w_eval = eval_df[WEIGHTS].to_numpy()
    deval = xgb.DMatrix(X_eval, label=y_eval, weight=w_eval)

    del eval_df, X_eval, y_eval, w_eval

    gc.collect()
    
    return dtrain, dvalid, deval


def get_pred(train: xgb.DMatrix, valid: xgb.DMatrix, eval: xgb.DMatrix, params):
    gpu_model = xgb.train(params, train, num_boost_round=1000,
                          evals=[(valid, 'val')], early_stopping_rounds=10, verbose_eval=False)
                
    print("training, done!")
    # 🔥 Save the model
    gpu_model.save_model("gpu_model.json")  
    print(gpu_model.save_config())  # Shows 'device'='cuda' here, since the model was trained on GPU
    del gpu_model
    
    cpu_model = xgb.Booster(params={"device": "cpu"})
    cpu_model.load_model('gpu_model.json')        
    print(cpu_model.save_config())  # Should show 'device'='cpu'
    
    print("loading done!")
    pred = cpu_model.predict(eval)
    
    del cpu_model
    
    return pred
    


def get_oos_probas(mats: list[xgb.DMatrix], params=None):
    params = params or {'objective': 'binary:logistic', 'eval_metric': 'logloss', 'seed': 42, 'device': 'cuda', 'learning_rate': 0.05}
    oos_preds: list[Optional[np.ndarray]] = [None] * len(mats)
    
    for i in range(len(mats)):
        train_idx, val_idx, test_idx = i, (i+1)%3, (i+2)%3
        oos_preds[test_idx] = get_pred(mats[train_idx], mats[val_idx], mats[test_idx], params)
        
    return oos_preds


dtrain, dvalid, deval = create_xgb_matricies(train_df, valid_df, eval_df)
oos_preds = get_oos_probas([dtrain, dvalid, deval]) # , PARAMS

Expected Behavior:

  • After loading the model with device=cpu, inference should be performed entirely on the CPU.
  • Task Manager should show CPU utilization during inference.

Observed Behavior:

  • Despite save_config() showing device=cpu (see the check sketched below), inference still uses GPU resources.
  • Task Manager shows 0% GPU compute utilization but maxed-out GPU memory.
  • This results in slow predictions.
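
For reference, a minimal sketch of how the configured device can be checked programmatically instead of eyeballing the full config dump. This assumes the config JSON layout of XGBoost 2.x (device under learner -> generic_param) and the gpu_model.json file saved by the reproduction code above:

import json

import xgboost as xgb

# Assumption: XGBoost 2.x config layout, with the device setting under
# learner -> generic_param; "gpu_model.json" is the file saved by the
# reproduction code above.
booster = xgb.Booster(params={"device": "cpu"})
booster.load_model("gpu_model.json")
cfg = json.loads(booster.save_config())
print(cfg["learner"]["generic_param"]["device"])  # expected: "cpu"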

Additional Notes:

  • I was not able to run predictions on the CPU unless I had also trained on the CPU.
  • None of the workarounds I tried were effective.

Would appreciate any guidance or fixes for this issue!

@trivialfis (Member)

Hi, thank you for raising the issue. Could you please share:

  • The XGBoost version.
  • Is the mats[test_idx] constructed from GPU data?

@Flash1709 (Author) commented Feb 3, 2025

The XGBoost version: 2.1.3

The DMatrix objects are constructed as shown above; I am using polars DataFrames, but the data is copied/cloned to independent numpy arrays that live on the CPU.

P.S.: I updated the code in the original post.

@Flash1709 (Author)

This might be a memory leak.

The train, test, and valid DMatrix objects are each about 4 GB.

Using the exact same code:

CPU:

  • Trains and predicts without issue
  • Takes about 10 min to run

GPU:

  • The first training iteration is WAY faster than on CPU
  • SLOWS down when it reaches the first prediction step
    • GPU memory is maxed out even though I send the model to the CPU
    • 3% CPU utilization and 100% GPU utilization even though the model is on "cpu"
  • Gets progressively slower the longer training goes on (related to the number of times I call cpu_model.predict???)
  • After a while, training a single tree takes multiple seconds
  • Takes >200 min to run (I terminated the process)

I also translated my code to CatBoost using catboost.Pool and CatBoostClassifier:

  • Runs without issues
  • Finishes in 2 min

@trivialfis (Member)

Do you delete the DMatrix object after running inference? XGBoost caches the prediction result by default, which is a problem if users keep the DMatrix objects alive.
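
For illustration, a minimal sketch of that lifetime issue on synthetic data. The assumption, per the comment above, is that the prediction cache is tied to the DMatrix, so dropping the DMatrix after predict() lets the cached buffer be reclaimed; names here are illustrative:

import gc

import numpy as np
import xgboost as xgb

X = np.random.rand(1_000, 3)
y = np.random.randint(0, 2, size=1_000)

dtrain = xgb.DMatrix(X, label=y)
booster = xgb.train({"objective": "binary:logistic"}, dtrain, num_boost_round=10)

deval = xgb.DMatrix(X)
pred = booster.predict(deval)  # the booster caches predictions keyed on this DMatrix

del deval      # drop the DMatrix so the cached prediction buffer can be released
gc.collect()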

@trivialfis (Member)

I will try to reproduce and see where the bottleneck is.

@trivialfis (Member)

Hi, based on your description I modified your reproducer to:

  • Synthesize data using sklearn.datasets.
  • Run get_pred multiple times (2000).

But I haven't been able to observe any slowdown yet; each run takes about 0.7 seconds. Could you please share more details on how to reproduce?

import gc
import time
from typing import Optional

import numpy as np
import polars as pl
import xgboost as xgb
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# Placeholder values
FEATURES = ["feature1", "feature2", "feature3"]  # Example feature columns
TARGETS = "target"  # Example target column
WEIGHTS = "weight"  # Example weight column


X, y = make_classification(
    random_state=2025, n_samples=int(2**16), n_features=3, n_classes=2, n_redundant=0
)

X_train, X_test, y_train, y_test = train_test_split(X, y)
X_train, X_valid, y_train, y_valid = train_test_split(X_train, y_train)


def to_pl(X: np.ndarray, y: np.ndarray, seed: int) -> pl.DataFrame:
    rng = np.random.default_rng(seed)
    weight = rng.uniform(size=(y.shape[0]), low=0.0, high=1.0)
    return pl.DataFrame(
        {
            "feature1": X[:, 0],
            "feature2": X[:, 1],
            "feature3": X[:, 2],
            TARGETS: y,
            WEIGHTS: weight,
        }
    )


train_df = to_pl(X_train, y_train, 0)
valid_df = to_pl(X_valid, y_valid, 1)
eval_df = to_pl(X_test, y_test, 2)


def create_xgb_matricies(
    train_df: pl.DataFrame, valid_df: pl.DataFrame, eval_df: pl.DataFrame
):

    X_train = train_df[FEATURES].to_numpy()
    y_train = train_df[TARGETS].to_numpy()
    w_train = train_df[WEIGHTS].to_numpy()
    dtrain = xgb.DMatrix(X_train, label=y_train, weight=w_train)

    del train_df, X_train, y_train, w_train

    X_valid = valid_df[FEATURES].to_numpy()
    y_valid = valid_df[TARGETS].to_numpy()
    w_valid = valid_df[WEIGHTS].to_numpy()
    dvalid = xgb.DMatrix(X_valid, label=y_valid, weight=w_valid)

    del valid_df, X_valid, y_valid, w_valid

    X_eval = eval_df[FEATURES].to_numpy()
    y_eval = eval_df[TARGETS].to_numpy()
    w_eval = eval_df[WEIGHTS].to_numpy()
    deval = xgb.DMatrix(X_eval, label=y_eval, weight=w_eval)

    del eval_df, X_eval, y_eval, w_eval

    gc.collect()

    return dtrain, dvalid, deval


def get_pred(train: xgb.DMatrix, valid: xgb.DMatrix, eval: xgb.DMatrix, params):
    gpu_model = xgb.train(
        params,
        train,
        num_boost_round=1000,
        evals=[(valid, "val")],
        early_stopping_rounds=10,
        verbose_eval=False,
    )

    print("training, done!")
    # 🔥 Save the model
    gpu_model.save_model("gpu_model.json")
    # print(gpu_model.save_config())  # Should show 'device'='cpu'
    del gpu_model

    cpu_model = xgb.Booster(params={"device": "cpu"})
    cpu_model.load_model("gpu_model.json")
    # print(cpu_model.save_config())  # Should show 'device'='cpu'

    print("loading done!")
    pred = cpu_model.predict(eval)

    del cpu_model

    return pred


def get_oos_probas(mats: list[xgb.DMatrix], params=None):
    params = params or {
        "objective": "binary:logistic",
        "eval_metric": "logloss",
        "seed": 42,
        "device": "cuda",
        "learning_rate": 0.05,
    }
    oos_preds: list[Optional[np.ndarray]] = [None] * len(mats)

    for k in range(2000):
        start = time.time()
        for i in range(len(mats)):
            train_idx, val_idx, test_idx = i, (i + 1) % 3, (i + 2) % 3
            oos_preds[test_idx] = get_pred(
                mats[train_idx], mats[val_idx], mats[test_idx], params
            )
        end = time.time()
        print("dur:", end - start)

    return oos_preds


dtrain, dvalid, deval = create_xgb_matricies(train_df, valid_df, eval_df)
oos_preds = get_oos_probas([dtrain, dvalid, deval])  # , PARAMS

@Flash1709 (Author) commented Feb 6, 2025

"Do you delete the DMatrix object after running inference? XGBoost by default caches the prediction result, this is a problem if the users keep the DMatrix objects alive."

This seems to be the case, but I need to keep the DMatrix alive, because i want to use it for recursive hyperparm optimization.

It looks like I’m running out of GPU memory on the second iteration. CPU-based prediction works fine.

I initially thought the model was stuck during iteration 1 Predictions, but I now realize that after completing training and prediction in iteration 1, it was actually setting up training for iteration 2. Since the GPU was already out of memory, training did not start immediately, and GPU utilization remained at 100%. This made me think that the prediction step from iteration 1 was still running on GPU, but in reality, the system was just stuck in setup for iteration 2 for several minutes.

Maybe in cases like this there should just be an OOM error, because traning will never finish after iteration 1.

Training ...
Train dur: 94.88388347625732
Predict dur: 2.184433698654175
Train dur: 664.393104600102555
Ctrl+C after 20 min of not finishing

import gc
import time
from typing import Optional

import numpy as np
import polars as pl
import xgboost as xgb
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split


N_FEATS = 200

# Placeholder values
FEATURES = [f"feature{i}" for i in range(N_FEATS)]  # Example feature columns
TARGETS = "target"  # Example target column
WEIGHTS = "weight"  # Example weight column


def create_xgb_matricies(
    train_df: pl.DataFrame, valid_df: pl.DataFrame, eval_df: pl.DataFrame
):

    X_train = train_df[FEATURES].to_numpy()
    y_train = train_df[TARGETS].to_numpy()
    w_train = train_df[WEIGHTS].to_numpy()
    dtrain = xgb.DMatrix(X_train, label=y_train, weight=w_train)

    del train_df, X_train, y_train, w_train

    X_valid = valid_df[FEATURES].to_numpy()
    y_valid = valid_df[TARGETS].to_numpy()
    w_valid = valid_df[WEIGHTS].to_numpy()
    dvalid = xgb.DMatrix(X_valid, label=y_valid, weight=w_valid)

    del valid_df, X_valid, y_valid, w_valid

    X_eval = eval_df[FEATURES].to_numpy()
    y_eval = eval_df[TARGETS].to_numpy()
    w_eval = eval_df[WEIGHTS].to_numpy()
    deval = xgb.DMatrix(X_eval, label=y_eval, weight=w_eval)

    del eval_df, X_eval, y_eval, w_eval

    gc.collect()

    return dtrain, dvalid, deval


def get_pred(train: xgb.DMatrix, valid: xgb.DMatrix, eval: xgb.DMatrix, params):
    
    start = time.time()
    
    gpu_model = xgb.train(
        params,
        train,
        num_boost_round=1000,
        evals=[(valid, "val")],
        early_stopping_rounds=10,
        verbose_eval=False,
    )
    
    end = time.time()
    print("Train dur:", end - start)

    # 🔥 Save the model
    gpu_model.save_model("gpu_model.json")
    # print(gpu_model.save_config())  # Should show 'device'='cpu'
    del gpu_model

    cpu_model = xgb.Booster(params={"device": "cpu"})
    cpu_model.load_model("gpu_model.json")
    # print(cpu_model.save_config())  # Should show 'device'='cpu'

    start = time.time()
    
    pred = cpu_model.predict(eval)
    
    end = time.time()
    print("Predict dur:", end - start)

    del cpu_model

    return pred


def get_oos_probas(mats: list[xgb.DMatrix], params=None):
    params = params or {
        "objective": "binary:logistic",
        "eval_metric": "logloss",
        "seed": 42,
        "device": "cuda",
        "learning_rate": 0.05,
    }
    oos_preds: list[Optional[np.ndarray]] = [None] * len(mats)

    for k in range(2000):
        
        for i in range(len(mats)):
            train_idx, val_idx, test_idx = i, (i + 1) % 3, (i + 2) % 3
            oos_preds[test_idx] = get_pred(
                mats[train_idx], mats[val_idx], mats[test_idx], params
            )
        

    return oos_preds



def to_pl(X: np.ndarray, y: np.ndarray, seed: int) -> pl.DataFrame:
    data = {f"feature{i}": X[:, i] for i in range(N_FEATS)}
    
    rng = np.random.default_rng(seed)
    data.update({
        "target": y,
        "weight": rng.uniform(size=(y.shape[0]), low=0.0, high=1.0),
    })
    
    return pl.DataFrame(data)


print("Creating synthetic Dataset ...")

X, y = make_classification(
    random_state=2025, n_samples=int(20e6), n_features=N_FEATS, n_classes=2, n_redundant=0
)

print("Splitting ...")

X_train, X_test, y_train, y_test = train_test_split(X, y)
X_train, X_valid, y_train, y_valid = train_test_split(X_train, y_train)

del X, y

train_df = to_pl(X_train, y_train, 0)
valid_df = to_pl(X_valid, y_valid, 1)
eval_df = to_pl(X_test, y_test, 2)

del X_train, X_test, y_train, y_test, X_valid, y_valid


print("Training ...")

dtrain, dvalid, deval = create_xgb_matricies(train_df, valid_df, eval_df)
oos_preds = get_oos_probas([dtrain, dvalid, deval])  # , PARAMS

@trivialfis (Member)

Maybe in cases like this there should just be an OOM error, because training will never finish after iteration 1.

I don't know which process is running into memory issues. XGBoost should emit an OOM error if an allocation fails; there is no "waiting for memory" in XGBoost.

@trivialfis (Member) commented Feb 12, 2025

This seems to be the case, but I need to keep the DMatrix alive, because I want to use it for recursive hyperparameter optimization.

I don't know what type of HPO you are doing. If the number of DMatrix objects alive is constant, say 5 matrices for 5-fold validation, then it's fine. However, if you keep creating new matrices without deleting the previous ones, it will really stress the GPU memory and the XGBoost cache.

Lastly, consider using QuantileDMatrix instead of DMatrix for more efficient memory usage:

Xy_train = xgboost.QuantileDMatrix(X_train, y_train)
Xy_valid = xgboost.QuantileDMatrix(X_valid, y_valid, ref=Xy_train)
xgboost.train({"device": "cuda"}, dtrain=Xy_train, evals=[(Xy_valid, "Validation")])

QuantileDMatrix is used by default in the sklearn interface when the tree method is hist.

@trivialfis (Member)

Feel free to reopen if the issue persists after using QuantileDMatrix.

@Flash1709 (Author)

Hey, thanks for getting back to me and for the helpful insights.

  • I don’t have time to dig into this right now, but I’ll run some tests later to better understand what’s happening.
    As I mentioned, simply swapping in catboost.Pool and CatBoostClassifier works flawlessly with the exact same code, so this seems specific to XGBoost.
  • I’m keeping the number of DMatrix objects constant at 3 and always reusing the same Python object for training and testing, so there shouldn’t be any unexpected memory buildup.
  • I’m not sure how XGBoost manages DMatrix creation on the CPU combined with later training on the GPU, but I’ll take a closer look when I have more time.

@trivialfis reopened this Feb 13, 2025
@trivialfis (Member)

No worries, keeping the issue open for now.

You can use the XGBClassifier, which chooses the right DMatrix for you.
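
A minimal sketch of that route on synthetic data. This is an illustration, not the poster's pipeline; device="cuda" assumes a CUDA GPU is available, and with tree_method="hist" the sklearn wrapper builds a QuantileDMatrix internally, per the comment above:

from sklearn.datasets import make_classification
from xgboost import XGBClassifier

X, y = make_classification(n_samples=10_000, n_features=20, random_state=42)

# The sklearn wrapper picks the DMatrix type itself; with
# tree_method="hist" it uses a QuantileDMatrix under the hood.
clf = XGBClassifier(tree_method="hist", device="cuda", n_estimators=100)
clf.fit(X, y)
proba = clf.predict_proba(X)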
