Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Remove global random engine. #10354

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions include/xgboost/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,6 @@
#define XGBOOST_LOG_WITH_TIME 1
#endif // XGBOOST_LOG_WITH_TIME

/*!
* \brief Whether to customize global PRNG.
*/
#ifndef XGBOOST_CUSTOMIZE_GLOBAL_PRNG
#define XGBOOST_CUSTOMIZE_GLOBAL_PRNG 0
#endif // XGBOOST_CUSTOMIZE_GLOBAL_PRNG

/*!
* \brief Check if alignas(*) keyword is supported. (g++ 4.8 or higher)
*/
Expand Down
16 changes: 12 additions & 4 deletions include/xgboost/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
#include <type_traits> // for invoke_result_t, is_same_v, underlying_type_t

namespace xgboost {

class Json;
struct CUDAContext;
namespace common {
class RandomEngine;
} // namespace common

// symbolic names
struct DeviceSym {
Expand Down Expand Up @@ -46,9 +49,7 @@ struct DeviceOrd {
[[nodiscard]] bool IsSyclDefault() const { return device == kSyclDefault; }
[[nodiscard]] bool IsSyclCPU() const { return device == kSyclCPU; }
[[nodiscard]] bool IsSyclGPU() const { return device == kSyclGPU; }
[[nodiscard]] bool IsSycl() const { return (IsSyclDefault() ||
IsSyclCPU() ||
IsSyclGPU()); }
[[nodiscard]] bool IsSycl() const { return (IsSyclDefault() || IsSyclCPU() || IsSyclGPU()); }

constexpr DeviceOrd() = default;
constexpr DeviceOrd(Type type, bst_d_ordinal_t ord) : device{type}, ordinal{ord} {}
Expand Down Expand Up @@ -296,6 +297,11 @@ struct Context : public XGBoostParameter<Context> {
.describe("Enable checking whether parameters are used or not.");
}

// Returns a mutable reference to the per-context random engine. Usable on a
// const Context because rng_ is declared mutable; the engine itself is shared
// with child contexts via shared_ptr.
// NOTE(review): no synchronization around the shared engine -- confirm callers
// only drive it from one thread at a time.
[[nodiscard]] auto& Rng() const { return *rng_; }

void SaveConfig(Json* out) const;
void LoadConfig(Json const& in);

private:
void SetDeviceOrdinal(Args const& kwargs);
Context& SetDevice(DeviceOrd d) {
Expand All @@ -307,6 +313,8 @@ struct Context : public XGBoostParameter<Context> {
// shared_ptr is used instead of unique_ptr as with unique_ptr it's difficult to define
// p_impl while trying to hide CUDA code from the host compiler.
mutable std::shared_ptr<CUDAContext> cuctx_;
// mutable for random engine. The rng is shared by child contexts, if there's any.
mutable std::shared_ptr<common::RandomEngine> rng_;
// cached value for CFS CPU limit. (used in containerized env)
std::int32_t cfs_cpu_count_; // NOLINT
};
Expand Down
14 changes: 1 addition & 13 deletions src/common/common.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2015-2023 by Contributors
* Copyright 2015-2024 by Contributors
*/
#include "common.h"

Expand All @@ -9,19 +9,7 @@
#include <cstdio> // for snprintf, size_t
#include <string> // for string

#include "./random.h" // for GlobalRandomEngine, GlobalRandom

namespace xgboost::common {
/*! \brief thread local entry for random. */
struct RandomThreadLocalEntry {
/*! \brief the random engine instance. */
GlobalRandomEngine engine;
};

using RandomThreadLocalStore = dmlc::ThreadLocalStore<RandomThreadLocalEntry>;

GlobalRandomEngine &GlobalRandom() { return RandomThreadLocalStore::Get()->engine; }

void EscapeU8(std::string const &string, std::string *p_buffer) {
auto &buffer = *p_buffer;
for (size_t i = 0; i < string.length(); i++) {
Expand Down
6 changes: 3 additions & 3 deletions src/common/random.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#include <thrust/shuffle.h> // for shuffle

Expand All @@ -19,7 +19,7 @@ void WeightedSamplingWithoutReplacement(Context const *ctx, common::Span<bst_fea
common::Span<float const> weights,
common::Span<bst_feature_t> results,
HostDeviceVector<bst_feature_t> *sorted_idx,
GlobalRandomEngine *grng) {
RandomEngine *grng) {
CUDAContext const *cuctx = ctx->CUDACtx();
CHECK_EQ(array.size(), weights.size());
// Sampling keys
Expand Down Expand Up @@ -61,7 +61,7 @@ void SampleFeature(Context const *ctx, bst_feature_t n_features,
std::shared_ptr<HostDeviceVector<bst_feature_t>> p_new_features,
HostDeviceVector<float> const &feature_weights,
HostDeviceVector<float> *weight_buffer,
HostDeviceVector<bst_feature_t> *idx_buffer, GlobalRandomEngine *grng) {
HostDeviceVector<bst_feature_t> *idx_buffer, RandomEngine *grng) {
CUDAContext const *cuctx = ctx->CUDACtx();
auto &new_features = *p_new_features;
new_features.SetDevice(ctx->Device());
Expand Down
78 changes: 11 additions & 67 deletions src/common/random.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,80 +7,24 @@
#ifndef XGBOOST_COMMON_RANDOM_H_
#define XGBOOST_COMMON_RANDOM_H_

#include <xgboost/logging.h>

#include <algorithm>
#include <functional>
#include <limits>
#include <map>
#include <memory>
#include <numeric>
#include <random>
#include <utility>
#include <vector>

#include "../collective/broadcast.h" // for Broadcast
#include "../collective/communicator-inl.h"
#include "algorithm.h" // ArgSort
#include "common.h"
#include "xgboost/context.h" // Context
#include "xgboost/host_device_vector.h"
#include "xgboost/linalg.h"
#include "../collective/broadcast.h" // for Broadcast
#include "algorithm.h" // ArgSort
#include "xgboost/context.h" // Context
#include "xgboost/host_device_vector.h" // for HostDeviceVector
#include "xgboost/linalg.h" // for MakeVec
#include "xgboost/logging.h"

namespace xgboost::common {
/*!
* \brief Define mt19937 as default type Random Engine.
*/
using RandomEngine = std::mt19937;

#if defined(XGBOOST_CUSTOMIZE_GLOBAL_PRNG) && XGBOOST_CUSTOMIZE_GLOBAL_PRNG == 1
/*!
* \brief An customized random engine, used to be plugged in PRNG from other systems.
* The implementation of this library is not provided by xgboost core library.
* Instead the other library can implement this class, which will be used as GlobalRandomEngine
* If XGBOOST_RANDOM_CUSTOMIZE = 1, by default this is switched off.
*/
class CustomGlobalRandomEngine {
public:
/*! \brief The result type */
using result_type = uint32_t;
/*! \brief The minimum of random numbers generated */
inline static constexpr result_type min() {
return 0;
}
/*! \brief The maximum random numbers generated */
inline static constexpr result_type max() {
return std::numeric_limits<result_type>::max();
}
/*!
* \brief seed function, to be implemented
* \param val The value of the seed.
*/
void seed(result_type val);
/*!
* \return next random number.
*/
result_type operator()();
};

/*!
* \brief global random engine
*/
typedef CustomGlobalRandomEngine GlobalRandomEngine;

#else
/*!
* \brief global random engine
*/
using GlobalRandomEngine = RandomEngine;
#endif // XGBOOST_CUSTOMIZE_GLOBAL_PRNG

/*!
* \brief global singleton of a random engine.
* This random engine is thread-local and
* only visible to current thread.
*/
GlobalRandomEngine& GlobalRandom(); // NOLINT(*)
// Per-context pseudo random number generator (Mersenne Twister). Defined as a
// class deriving from std::mt19937 rather than a `using` alias so that it can
// be forward-declared (see the declaration in xgboost/context.h).
class RandomEngine : public std::mt19937 {};

/*
* Original paper:
Expand All @@ -96,7 +40,7 @@ std::vector<T> WeightedSamplingWithoutReplacement(Context const* ctx, std::vecto
CHECK_EQ(array.size(), weights.size());
std::vector<float> keys(weights.size());
std::uniform_real_distribution<float> dist;
auto& rng = GlobalRandom();
auto& rng = ctx->Rng();
for (size_t i = 0; i < array.size(); ++i) {
auto w = std::max(weights.at(i), kRtEps);
auto u = dist(rng);
Expand All @@ -120,7 +64,7 @@ void SampleFeature(Context const* ctx, bst_feature_t n_features,
std::shared_ptr<HostDeviceVector<bst_feature_t>> p_new_features,
HostDeviceVector<float> const& feature_weights,
HostDeviceVector<float>* weight_buffer,
HostDeviceVector<bst_feature_t>* idx_buffer, GlobalRandomEngine* grng);
HostDeviceVector<bst_feature_t>* idx_buffer, RandomEngine* grng);

void InitFeatureSet(Context const* ctx,
std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features);
Expand All @@ -140,7 +84,7 @@ class ColumnSampler {
float colsample_bylevel_{1.0f};
float colsample_bytree_{1.0f};
float colsample_bynode_{1.0f};
GlobalRandomEngine rng_;
RandomEngine rng_;
Context const* ctx_;

// Used for weighted sampling.
Expand Down Expand Up @@ -230,7 +174,7 @@ class ColumnSampler {
};

inline auto MakeColumnSampler(Context const* ctx) {
std::uint32_t seed = common::GlobalRandomEngine()();
std::uint32_t seed = ctx->Rng()();
auto rc = collective::Broadcast(ctx, linalg::MakeVec(&seed, 1), 0);
collective::SafeColl(rc);
auto cs = std::make_shared<common::ColumnSampler>(seed);
Expand Down
30 changes: 27 additions & 3 deletions src/context.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2014-2023 by XGBoost Contributors
* Copyright 2014-2024, XGBoost Contributors
*
* \brief Context object used for controlling runtime parameters.
*/
Expand All @@ -8,21 +8,27 @@
#include <algorithm> // for find_if
#include <charconv> // for from_chars
#include <iterator> // for distance
#include <locale> // for locale
#include <optional> // for optional
#include <regex> // for regex_replace, regex_match
#include <sstream> // for stringstream

#include "common/common.h" // AssertGPUSupport
#include "common/error_msg.h" // WarnDeprecatedGPUId
#include "common/threading_utils.h"
#include "xgboost/json.h" // for Json
#include "xgboost/string_view.h"
#include "common/random.h" // for RandomEngine

namespace xgboost {

DMLC_REGISTER_PARAMETER(Context);

std::int64_t constexpr Context::kDefaultSeed;

Context::Context() : cfs_cpu_count_{common::GetCfsCPUCount()} {}
/**
 * Construct a context owning a freshly allocated random engine seeded with the
 * default seed, and cache the CFS CPU limit (used in containerized
 * environments).
 */
Context::Context() : cfs_cpu_count_{common::GetCfsCPUCount()} {
  this->rng_ = std::make_shared<common::RandomEngine>();
  this->rng_->seed(kDefaultSeed);
}

namespace {
inline constexpr char const* kDevice = "device";
Expand Down Expand Up @@ -219,6 +225,24 @@ void Context::Init(Args const& kwargs) {
}
}

/**
 * @brief Save the context configuration, including the random engine state,
 *        into a JSON object.
 *
 * The engine state is serialized through the stream insertion operator of
 * std::mt19937 and stored as a string under the "rng" key.
 *
 * @param out Output JSON object; overwritten with the serialized parameters.
 */
void Context::SaveConfig(Json* out) const {
  (*out) = ToJson(*this);
  std::stringstream ss;
  try {
    ss.imbue(std::locale{"en_US.UTF8"});
  } catch (std::runtime_error const&) {
    // The named locale may not be installed (common on musl-based distros;
    // Windows spells it "en-US"). Fall back to the classic "C" locale, whose
    // ungrouped digit output is readable under either locale on load.
    ss.imbue(std::locale::classic());
  }
  ss << this->Rng();
  (*out)["rng"] = ss.str();
}

/**
 * @brief Load the context configuration from a JSON object, restoring the
 *        random engine state when present.
 *
 * Configurations saved before the per-context engine existed carry no "rng"
 * entry; in that case the engine keeps its current state instead of failing.
 *
 * @param in Input JSON object produced by SaveConfig (or an older release).
 */
void Context::LoadConfig(Json const& in) {
  FromJson(in, this);
  auto const& config = get<Object const>(in);
  auto it = config.find("rng");
  if (it != config.cend()) {
    std::stringstream ss;
    try {
      // Match the locale used by SaveConfig so grouped digits round-trip.
      ss.imbue(std::locale{"en_US.UTF8"});
    } catch (std::runtime_error const&) {
      // Locale unavailable on this system; the classic locale still parses
      // ungrouped engine output.
      ss.imbue(std::locale::classic());
    }
    ss << get<String const>(it->second);
    ss >> this->Rng();
  }
  // make sure the GPU ID is valid in new environment before start running configure.
  this->ConfigureGpuId(false);
}

void Context::ConfigureGpuId(bool require_gpu) {
if (this->IsCPU() && require_gpu) {
this->UpdateAllowUnknown(Args{{kDevice, DeviceSym::CUDA()}});
Expand Down
5 changes: 2 additions & 3 deletions src/gbm/gbtree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
#include <dmlc/parameter.h>

#include <algorithm> // for equal
#include <cinttypes> // for uint32_t
#include <limits>
#include <cstdint> // for uint32_t
#include <memory>
#include <string>
#include <utility>
Expand Down Expand Up @@ -928,7 +927,7 @@ class Dart : public GBTree {
idx_drop_.clear();

std::uniform_real_distribution<> runif(0.0, 1.0);
auto& rnd = common::GlobalRandom();
auto& rnd = ctx_->Rng();
bool skip = false;
if (dparam_.skip_drop > 0.0) skip = (runif(rnd) < dparam_.skip_drop);
// sample some trees to drop
Expand Down
17 changes: 7 additions & 10 deletions src/learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,11 @@
#include <limits> // for numeric_limits
#include <memory> // for allocator, unique_ptr, shared_ptr, operator==
#include <mutex> // for mutex, lock_guard
#include <set> // for set
#include <sstream> // for operator<<, basic_ostream, basic_ostream::opera...
#include <stack> // for stack
#include <string> // for basic_string, char_traits, operator<, string
#include <system_error> // for errc
#include <tuple> // for get
#include <unordered_map> // for operator!=, unordered_map
#include <utility> // for pair, as_const, move, swap
#include <vector> // for vector

Expand All @@ -41,7 +39,7 @@
#include "common/error_msg.h" // for MaxFeatureSize, WarnOldSerialization, ...
#include "common/io.h" // for PeekableInStream, ReadAll, FixedSizeStream, Mem...
#include "common/observer.h" // for TrainingObserver
#include "common/random.h" // for GlobalRandom
#include "common/random.h" // for RandomEngine
#include "common/timer.h" // for Monitor
#include "common/version.h" // for Version
#include "dmlc/endian.h" // for ByteSwap, DMLC_IO_NO_ENDIAN_SWAP
Expand Down Expand Up @@ -476,7 +474,7 @@ class LearnerConfiguration : public Learner {

// set seed only before the model is initialized
if (!initialized || ctx_.seed != old_seed) {
common::GlobalRandom().seed(ctx_.seed);
ctx_.Rng().seed(ctx_.seed);
}

// must precede configure gbm since num_features is required for gbm
Expand Down Expand Up @@ -556,9 +554,7 @@ class LearnerConfiguration : public Learner {
}
}

FromJson(learner_parameters.at("generic_param"), &ctx_);
// make sure the GPU ID is valid in new environment before start running configure.
ctx_.ConfigureGpuId(false);
ctx_.LoadConfig(learner_parameters.at("generic_param"));

this->need_configuration_ = true;
}
Expand Down Expand Up @@ -588,7 +584,8 @@ class LearnerConfiguration : public Learner {
}
learner_parameters["metrics"] = Array(std::move(metrics));

learner_parameters["generic_param"] = ToJson(ctx_);
learner_parameters["generic_param"] = Object{};
ctx_.SaveConfig(&learner_parameters["generic_param"]);
}

void SetParam(const std::string& key, const std::string& value) override {
Expand Down Expand Up @@ -1271,7 +1268,7 @@ class LearnerImpl : public LearnerIO {
this->InitBaseScore(train.get());

if (ctx_.seed_per_iteration) {
common::GlobalRandom().seed(ctx_.seed * kRandSeedMagic + iter);
ctx_.Rng().seed(ctx_.seed * kRandSeedMagic + iter);
}

this->ValidateDMatrix(train.get(), true);
Expand All @@ -1298,7 +1295,7 @@ class LearnerImpl : public LearnerIO {
this->Configure();

if (ctx_.seed_per_iteration) {
common::GlobalRandom().seed(ctx_.seed * kRandSeedMagic + iter);
ctx_.Rng().seed(ctx_.seed * kRandSeedMagic + iter);
}

this->ValidateDMatrix(train.get(), true);
Expand Down
Loading
Loading