diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h
index 7e8ed2f29568..6ae1dea8d3ce 100644
--- a/include/xgboost/c_api.h
+++ b/include/xgboost/c_api.h
@@ -876,31 +876,52 @@ XGB_DLL int XGDMatrixGetQuantileCut(DMatrixHandle const handle, char const *conf
  * @defgroup Booster Booster
  *
  * @brief The `Booster` class is the gradient-boosted model for XGBoost.
+ *
+ * During training, the booster object has many caches for improved performance. In
+ * addition to gradient and prediction, it also includes runtime buffers like leaf
+ * partitions. These buffers persist with the Booster object until either XGBoosterReset()
+ * is called or the booster is deleted by XGBoosterFree().
+ *
  * @{
  */
 
-/*!
- * \brief create xgboost learner
- * \param dmats matrices that are set to be cached
- * \param len length of dmats
- * \param out handle to the result booster
- * \return 0 when success, -1 when failure happens
+/**
+ * @brief Create an XGBoost learner (booster)
+ *
+ * @param dmats matrices that are set to be cached by the booster.
+ * @param len length of dmats
+ * @param out handle to the result booster
+ *
+ * @return 0 when success, -1 when failure happens
  */
 XGB_DLL int XGBoosterCreate(const DMatrixHandle dmats[], bst_ulong len, BoosterHandle *out);
 /**
  * @example c-api-demo.c
  */
 
-/*!
- * \brief free obj in handle
- * \param handle handle to be freed
- * \return 0 when success, -1 when failure happens
+/**
+ * @brief Delete the booster.
+ *
+ * @param handle The handle to be freed.
+ *
+ * @return 0 when success, -1 when failure happens
  */
 XGB_DLL int XGBoosterFree(BoosterHandle handle);
 /**
  * @example c-api-demo.c inference.c external_memory.c
  */
 
+/**
+ * @brief Reset the booster object to release data caches used for training.
+ *
+ * @param handle The handle of the booster to reset.
+ *
+ * @return 0 when success, -1 when failure happens
+ *
+ * @since 3.0.0
+ */
+XGB_DLL int XGBoosterReset(BoosterHandle handle);
+
 /*!
  * \brief Slice a model using boosting index. The slice m:n indicates taking all trees
  * that were fit during the boosting rounds m, (m+1), (m+2), ..., (n-1).
diff --git a/include/xgboost/learner.h b/include/xgboost/learner.h
index 939324e4a6c4..1499804c8592 100644
--- a/include/xgboost/learner.h
+++ b/include/xgboost/learner.h
@@ -249,6 +249,10 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
                                              std::string format) = 0;
 
   virtual XGBAPIThreadLocalEntry& GetThreadLocal() const = 0;
+  /**
+   * @brief Reset the booster object to release data caches used for training.
+   */
+  virtual void Reset() = 0;
   /*!
    * \brief Create a new instance of learner.
    * \param cache_data The matrix to cache the prediction.
diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index 05c0cc30fa82..0db6afabcf40 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -2008,7 +2008,8 @@ def __setstate__(self, state: Dict) -> None:
         self.__dict__.update(state)
 
     def __getitem__(self, val: Union[Integer, tuple, slice, EllipsisType]) -> "Booster":
-        """Get a slice of the tree-based model.
+        """Get a slice of the tree-based model. Attributes like `best_iteration` and
+        `best_score` are removed in the resulting booster.
 
         .. versionadded:: 1.3.0
 
@@ -2107,6 +2108,20 @@ def copy(self) -> "Booster":
         """
         return copy.copy(self)
 
+    def reset(self) -> "Booster":
+        """Reset the booster object to release data caches used for training.
+
+        .. versionadded:: 3.0.0
+
+        Returns
+        -------
+        booster :
+            The same booster object, reset in place.
+
+        """
+        _check_call(_LIB.XGBoosterReset(self.handle))
+        return self
+
     def attr(self, key: str) -> Optional[str]:
         """Get attribute string from the Booster.
 
diff --git a/python-package/xgboost/training.py b/python-package/xgboost/training.py
index bb4ebe44e1ed..86370469a400 100644
--- a/python-package/xgboost/training.py
+++ b/python-package/xgboost/training.py
@@ -187,9 +187,8 @@ def train(
     if evals_result is not None:
         evals_result.update(cb_container.history)
 
-    # Copy to serialise and unserialise booster to reset state and free
-    # training memory
-    return bst.copy()
+    # Reset in place to release the training caches; returns the same booster.
+    return bst.reset()
 
 
 class CVPack:
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 90407fcf58ac..ee99922cdd1c 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -980,6 +980,13 @@ XGB_DLL int XGBoosterFree(BoosterHandle handle) {
   API_END();
 }
 
+XGB_DLL int XGBoosterReset(BoosterHandle handle) {
+  API_BEGIN();
+  CHECK_HANDLE();
+  static_cast<Learner*>(handle)->Reset();
+  API_END();
+}
+
 XGB_DLL int XGBoosterSetParam(BoosterHandle handle,
                               const char *name,
                               const char *value) {
diff --git a/src/learner.cc b/src/learner.cc
index e6642b0874ac..1dcd0fcfc7eb 100644
--- a/src/learner.cc
+++ b/src/learner.cc
@@ -860,6 +860,7 @@ class LearnerIO : public LearnerConfiguration {
   // Will be removed once JSON takes over. Right now we still loads some RDS files from R.
   std::string const serialisation_header_ { u8"CONFIG-offset:" };
 
+ protected:
   void ClearCaches() { this->prediction_container_ = PredictionContainer{}; }
 
  public:
@@ -1264,6 +1265,25 @@ class LearnerImpl : public LearnerIO {
     return out_impl;
   }
 
+  void Reset() override {
+    this->Configure();
+    this->CheckModelInitialized();
+    // Global data. Erasing by key is a no-op when this learner has no entry.
+    LearnerAPIThreadLocalStore::Get()->erase(this);
+
+    // Model: save into an in-memory buffer and load it back.
+    std::string buf;
+    common::MemoryBufferStream fo(&buf);
+    this->Save(&fo);
+
+    common::MemoryFixSizeBuffer fs(buf.data(), buf.size());
+    this->Load(&fs);
+
+    // Learner self cache. Prediction is cleared in the load method
+    CHECK(this->prediction_container_.Container().empty());
+    this->gpair_ = decltype(this->gpair_){};
+  }
+
   void UpdateOneIter(int iter, std::shared_ptr<DMatrix> train) override {
     monitor_.Start("UpdateOneIter");
     TrainingObserver::Instance().Update(iter);
diff --git a/tests/cpp/test_learner.cu b/tests/cpp/test_learner.cu
new file mode 100644
index 000000000000..2fde49ca0fdb
--- /dev/null
+++ b/tests/cpp/test_learner.cu
@@ -0,0 +1,44 @@
+/**
+ * Copyright 2024, XGBoost contributors
+ */
+#include <gtest/gtest.h>
+#include <xgboost/context.h>        // for DeviceSym
+#include <xgboost/global_config.h>  // for GlobalConfigThreadLocalStore
+#include <xgboost/learner.h>
+
+#include <cstdint>  // for int32_t
+#include <memory>   // for unique_ptr
+#include <string>   // for to_string
+
+#include "../../src/common/device_vector.cuh"  // for GlobalMemoryLogger
+#include "helpers.h"                           // for RandomDataGenerator
+
+namespace xgboost {
+TEST(Learner, Reset) {
+  dh::GlobalMemoryLogger().Clear();
+
+  auto verbosity = GlobalConfigThreadLocalStore::Get()->verbosity;
+  ConsoleLogger::Configure({{"verbosity", "3"}});
+  auto p_fmat = RandomDataGenerator{1024, 32, 0.0}.GenerateDMatrix(true);
+  std::unique_ptr<Learner> learner{Learner::Create({p_fmat})};
+  learner->SetParam("device", DeviceSym::CUDA());
+  learner->Configure();
+  for (std::int32_t i = 0; i < 2; ++i) {
+    learner->UpdateOneIter(i, p_fmat);
+  }
+
+  auto cur = dh::GlobalMemoryLogger().CurrentlyAllocatedBytes();
+  p_fmat.reset();
+  auto after_p_fmat_reset = dh::GlobalMemoryLogger().CurrentlyAllocatedBytes();
+  learner->Reset();
+  auto after_learner_reset = dh::GlobalMemoryLogger().CurrentlyAllocatedBytes();
+
+  // Restore the global verbosity before asserting: a failing ASSERT_* returns from the
+  // test body immediately and would otherwise leave verbosity at 3 for later tests.
+  ConsoleLogger::Configure({{"verbosity", std::to_string(verbosity)}});
+
+  ASSERT_LT(after_p_fmat_reset, cur);
+  ASSERT_LT(after_learner_reset, after_p_fmat_reset);
+  ASSERT_LE(after_learner_reset, 64);
+}
+}  // namespace xgboost