From 1780f5b9d9f85e6ef2f95dffef9e0303b28feb91 Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin <> Date: Thu, 23 Jan 2025 03:37:59 -0800 Subject: [PATCH 01/18] add row-wise processing to PushRowPage --- src/common/quantile.h | 113 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 91 insertions(+), 22 deletions(-) diff --git a/src/common/quantile.h b/src/common/quantile.h index e189b259b159..f877e522ee1a 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -840,47 +840,116 @@ class SketchContainerImpl { template void PushRowPageImpl(Batch const &batch, size_t base_rowid, OptionalWeights weights, size_t nnz, size_t n_features, bool is_dense, IsValid is_valid) { - auto thread_columns_ptr = LoadBalance(batch, nnz, n_features, n_threads_, is_valid); - dmlc::OMPException exc; -#pragma omp parallel num_threads(n_threads_) - { - exc.Run([&]() { - auto tid = static_cast(omp_get_thread_num()); - auto const begin = thread_columns_ptr[tid]; - auto const end = thread_columns_ptr[tid + 1]; - - // do not iterate if no columns are assigned to the thread - if (begin < end && end <= n_features) { - for (size_t ridx = 0; ridx < batch.Size(); ++ridx) { + size_t ridx_block_size = batch.Size() / n_threads_ + (batch.Size() % n_threads_ > 0); + size_t min_ridx_block_size = 1024; + if ((n_features < n_threads_) && (ridx_block_size > min_ridx_block_size)) { + /* Row-wise parallelisation. + */ + std::vector> categories_buff(n_threads_ * n_features); + std::vector sketches_buff(n_threads_ * n_features); + + #pragma omp parallel num_threads(n_threads_) + { + exc.Run([&]() { + auto tid = static_cast(omp_get_thread_num()); + WQSketch* sketches_th = sketches_buff.data() + tid * n_features; + std::set* categories_th = categories_buff.data() + tid * n_features; + + for (size_t ii = 0; ii < n_features; ii++) { + auto n_bins = std::min(static_cast(max_bins_), columns_size_[ii]); + auto eps = 1.0 / (static_cast(n_bins) * WQSketch::kFactor); + sketches_th[ii].Init(columns_size_[ii], eps); + } + + size_t ridx_begin = tid * ridx_block_size; + size_t ridx_end = std::min(ridx_begin + ridx_block_size, batch.Size()); + for (size_t ridx = ridx_begin; ridx < ridx_end; ++ridx) { auto const &line = batch.GetLine(ridx); auto w = weights[ridx + base_rowid]; if (is_dense) { - for (size_t ii = begin; ii < end; ii++) { + for (size_t ii = 0; ii < n_features; ii++) { auto elem = line.GetElement(ii); if (is_valid(elem)) { if (IsCat(feature_types_, ii)) { - categories_[ii].emplace(elem.value); + categories_th[ii].emplace(elem.value); } else { - sketches_[ii].Push(elem.value, w); + sketches_th[ii].Push(elem.value, w); } } } } else { - for (size_t i = 0; i < line.Size(); ++i) { - auto const &elem = line.GetElement(i); - if (is_valid(elem) && elem.column_idx >= begin && elem.column_idx < end) { + for (size_t ii = 0; ii < line.Size(); ++ii) { + auto elem = line.GetElement(ii); + if (is_valid(elem)) { if (IsCat(feature_types_, elem.column_idx)) { - categories_[elem.column_idx].emplace(elem.value); + categories_th[elem.column_idx].emplace(elem.value); } else { - sketches_[elem.column_idx].Push(elem.value, w); + sketches_th[elem.column_idx].Push(elem.value, w); } } } } } - } - }); + #pragma omp barrier + + size_t fidx_block_size = n_features / n_threads_ + (n_features % n_threads_ > 0); + size_t fidx_begin = tid * fidx_block_size; + size_t fidx_end = std::min(fidx_begin + fidx_block_size, n_features); + for (size_t ii = fidx_begin; ii < fidx_end; ++ii) { + for (size_t th = 0; th < n_threads_; ++th) { + if (IsCat(feature_types_, ii)) { + categories_[ii].merge(categories_buff[th * n_features + ii]); + } else { + typename WQSketch::SummaryContainer summary; + sketches_buff[th * n_features + ii].GetSummary(&summary); + sketches_[ii].PushSummary(summary); + } + } + } + }); + } + } else { + auto thread_columns_ptr = LoadBalance(batch, nnz, n_features, n_threads_, is_valid); + #pragma omp parallel num_threads(n_threads_) + { + exc.Run([&]() { + auto tid = static_cast(omp_get_thread_num()); + auto const begin = thread_columns_ptr[tid]; + auto const end = thread_columns_ptr[tid + 1]; + + // do not iterate if no columns are assigned to the thread + if (begin < end && end <= n_features) { + for (size_t ridx = 0; ridx < batch.Size(); ++ridx) { + auto const &line = batch.GetLine(ridx); + auto w = weights[ridx + base_rowid]; + if (is_dense) { + for (size_t ii = begin; ii < end; ii++) { + auto elem = line.GetElement(ii); + if (is_valid(elem)) { + if (IsCat(feature_types_, ii)) { + categories_[ii].emplace(elem.value); + } else { + sketches_[ii].Push(elem.value, w); + } + } + } + } else { + for (size_t i = 0; i < line.Size(); ++i) { + auto const &elem = line.GetElement(i); + if (is_valid(elem) && elem.column_idx >= begin && elem.column_idx < end) { + if (IsCat(feature_types_, elem.column_idx)) { + categories_[elem.column_idx].emplace(elem.value); + } else { + sketches_[elem.column_idx].Push(elem.value, w); + } + } + } + } + } + } + }); + } } exc.Rethrow(); } From 922af516c7b41d09755eddcf70e35ac8d8dfa31d Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin <> Date: Thu, 23 Jan 2025 04:00:11 -0800 Subject: [PATCH 02/18] fix --- src/common/quantile.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/common/quantile.h b/src/common/quantile.h index f877e522ee1a..bbffa742a563 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -843,7 +843,7 @@ class SketchContainerImpl { dmlc::OMPException exc; size_t ridx_block_size = batch.Size() / n_threads_ + (batch.Size() % n_threads_ > 0); size_t min_ridx_block_size = 1024; - if ((n_features < n_threads_) && (ridx_block_size > min_ridx_block_size)) { + if ((n_features < size_t(n_threads_)) && (ridx_block_size > min_ridx_block_size)) { /* Row-wise parallelisation. */ std::vector> categories_buff(n_threads_ * n_features); @@ -897,7 +897,7 @@ class SketchContainerImpl { size_t fidx_begin = tid * fidx_block_size; size_t fidx_end = std::min(fidx_begin + fidx_block_size, n_features); for (size_t ii = fidx_begin; ii < fidx_end; ++ii) { - for (size_t th = 0; th < n_threads_; ++th) { + for (int th = 0; th < n_threads_; ++th) { if (IsCat(feature_types_, ii)) { categories_[ii].merge(categories_buff[th * n_features + ii]); } else { From 77923f22fb6e75c62c8b05875e23399be6a9f496 Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin <> Date: Thu, 23 Jan 2025 05:00:17 -0800 Subject: [PATCH 03/18] clang tildy --- src/common/quantile.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/common/quantile.h b/src/common/quantile.h index bbffa742a563..f906d9ada90e 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -843,7 +843,8 @@ class SketchContainerImpl { dmlc::OMPException exc; size_t ridx_block_size = batch.Size() / n_threads_ + (batch.Size() % n_threads_ > 0); size_t min_ridx_block_size = 1024; - if ((n_features < size_t(n_threads_)) && (ridx_block_size > min_ridx_block_size)) { + if ((n_features < static_cast(n_threads_)) && + (ridx_block_size > min_ridx_block_size)) { /* Row-wise parallelisation. */ std::vector> categories_buff(n_threads_ * n_features); From ab313687e1728d0391797c3155b26240ae385e72 Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin <> Date: Tue, 4 Feb 2025 05:21:31 -0800 Subject: [PATCH 04/18] reduce buffer size; add tests --- src/common/quantile.h | 268 ++++++++++++++++++------------ tests/cpp/common/test_quantile.cc | 36 +++- 2 files changed, 191 insertions(+), 113 deletions(-) diff --git a/src/common/quantile.h b/src/common/quantile.h index f906d9ada90e..3fec78636e13 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -748,37 +748,76 @@ std::vector CalcColumnSize(Batch const &batch, bst_feature_t const n_ return entries_per_columns; } +struct WLBalance { + WLBalance(size_t n_columns) : is_column_splited(n_columns) {} + + struct ThreadWorkLoad { + std::vector columns; + size_t split_idx; + size_t n_splits; + + ThreadWorkLoad() : columns(), split_idx(0), n_splits(1) {} + }; + + std::vector baskets; + std::vector is_column_splited; + bool has_splitted = false; +}; + + template -std::vector LoadBalance(Batch const &batch, size_t nnz, bst_feature_t n_columns, - size_t const nthreads, IsValid&& is_valid) { - /* Some sparse datasets have their mass concentrating on small number of features. To - * avoid waiting for a few threads running forever, we here distribute different number - * of columns to different threads according to number of entries. +WLBalance LoadBalance(Batch const &batch, size_t nnz, bst_feature_t n_columns, + size_t const nthreads, IsValid&& is_valid) { + /* Some datasets have long columns. It is beneficial to split such columns between threads and + * than collect the result if number of threads is high enourth. In this case, each thread being + * involved in processing of splitted columns works only with a single column. + * + * Columns that are too small for splitting are distributed between threads. In this case each thread + * can process multiple columns. The range of columns indexes for all the rthreads in this case don't + * overlap with each other. */ + WLBalance wl_balance(n_columns); + if (nnz == 0) return wl_balance; + auto& wl_baskets = wl_balance.baskets; + size_t const total_entries = nnz; size_t const entries_per_thread = DivRoundUp(total_entries, nthreads); // Need to calculate the size for each batch. std::vector entries_per_columns = CalcColumnSize(batch, n_columns, nthreads, is_valid); - std::vector cols_ptr(nthreads + 1, 0); - size_t count{0}; - size_t current_thread{1}; + + size_t count = 0; + for (size_t column_idx = 0; column_idx < n_columns; ++column_idx) { + size_t n_entries = entries_per_columns[column_idx]; + + size_t n_splits = n_entries / entries_per_thread; + if (n_splits > 1) { + // Split column between threads + wl_balance.has_splitted = true; + wl_balance.is_column_splited[column_idx] = true; + for (size_t split_idx = 0; split_idx < n_splits; split_idx++) { + wl_baskets.emplace_back(); + + auto& wl = wl_baskets.back(); + wl.columns.push_back(column_idx); + wl.split_idx = split_idx; + wl.n_splits = n_splits; + } + } else { + if (wl_baskets.empty() || count > entries_per_thread) { + wl_baskets.emplace_back(); + count = 0; + } + count += n_entries; - for (auto col : entries_per_columns) { - cols_ptr.at(current_thread)++; // add one column to thread - count += col; - CHECK_LE(count, total_entries); - if (count > entries_per_thread) { - current_thread++; - count = 0; - cols_ptr.at(current_thread) = cols_ptr[current_thread - 1]; + auto& wl = wl_baskets.back(); + wl.columns.push_back(column_idx); + wl_balance.is_column_splited[column_idx] = false; } } - // Idle threads. - for (; current_thread < cols_ptr.size() - 1; ++current_thread) { - cols_ptr[current_thread + 1] = cols_ptr[current_thread]; - } - return cols_ptr; + + CHECK_LE(wl_baskets.size(), nthreads); + return wl_balance; } /*! @@ -841,117 +880,116 @@ class SketchContainerImpl { void PushRowPageImpl(Batch const &batch, size_t base_rowid, OptionalWeights weights, size_t nnz, size_t n_features, bool is_dense, IsValid is_valid) { dmlc::OMPException exc; - size_t ridx_block_size = batch.Size() / n_threads_ + (batch.Size() % n_threads_ > 0); - size_t min_ridx_block_size = 1024; - if ((n_features < static_cast(n_threads_)) && - (ridx_block_size > min_ridx_block_size)) { - /* Row-wise parallelisation. - */ - std::vector> categories_buff(n_threads_ * n_features); - std::vector sketches_buff(n_threads_ * n_features); - #pragma omp parallel num_threads(n_threads_) - { - exc.Run([&]() { - auto tid = static_cast(omp_get_thread_num()); - WQSketch* sketches_th = sketches_buff.data() + tid * n_features; - std::set* categories_th = categories_buff.data() + tid * n_features; + auto threads_wl = LoadBalance(batch, nnz, n_features, n_threads_, is_valid); + if (threads_wl.baskets.empty()) return; + + std::vector> categories_buff; + std::vector sketches_buff; + if (threads_wl.has_splitted) { + sketches_buff.resize(threads_wl.baskets.size()); + categories_buff.resize(threads_wl.baskets.size()); + } - for (size_t ii = 0; ii < n_features; ii++) { - auto n_bins = std::min(static_cast(max_bins_), columns_size_[ii]); + #pragma omp parallel num_threads(threads_wl.baskets.size()) + { + exc.Run([&]() { + auto tid = static_cast(omp_get_thread_num()); + const auto& wl = threads_wl.baskets[tid]; + if (wl.n_splits > 1) { + // We process only a single column in this case + size_t column = wl.columns.front(); + + std::set* categories_out; + WQSketch* sketches_out; + if (wl.split_idx == 0) { + categories_out = &categories_[column]; + sketches_out = &sketches_[column]; + } else { + auto n_bins = std::min(static_cast(max_bins_), columns_size_[column]); auto eps = 1.0 / (static_cast(n_bins) * WQSketch::kFactor); - sketches_th[ii].Init(columns_size_[ii], eps); + sketches_buff[tid].Init(columns_size_[column], eps); + + categories_out = &categories_buff[tid]; + sketches_out = &sketches_buff[tid]; } - size_t ridx_begin = tid * ridx_block_size; - size_t ridx_end = std::min(ridx_begin + ridx_block_size, batch.Size()); - for (size_t ridx = ridx_begin; ridx < ridx_end; ++ridx) { + size_t split_size = batch.Size() / wl.n_splits; + size_t begin = wl.split_idx * split_size; + size_t end = std::min(begin + split_size, batch.Size()); + + for (size_t ridx = begin; ridx < end; ++ridx) { auto const &line = batch.GetLine(ridx); auto w = weights[ridx + base_rowid]; if (is_dense) { - for (size_t ii = 0; ii < n_features; ii++) { - auto elem = line.GetElement(ii); - if (is_valid(elem)) { - if (IsCat(feature_types_, ii)) { - categories_th[ii].emplace(elem.value); - } else { - sketches_th[ii].Push(elem.value, w); - } - } + auto const &elem = line.GetElement(column); + /* elem.column_idx == column */ + if (is_valid(elem)) { + PushElement(elem, categories_out, sketches_out, w); } } else { - for (size_t ii = 0; ii < line.Size(); ++ii) { - auto elem = line.GetElement(ii); - if (is_valid(elem)) { - if (IsCat(feature_types_, elem.column_idx)) { - categories_th[elem.column_idx].emplace(elem.value); - } else { - sketches_th[elem.column_idx].Push(elem.value, w); - } + size_t n_columns_with_high_idx = n_features - column; + size_t begin = line.Size() < n_columns_with_high_idx ? 0 : line.Size() - n_columns_with_high_idx; + size_t end = std::min(column + 1, line.Size()); + for (size_t i = begin; i < end; ++i) { + auto const &elem = line.GetElement(i); + if (is_valid(elem) && (elem.column_idx == column)) { + PushElement(elem, categories_out, sketches_out, w); } } } } - #pragma omp barrier - - size_t fidx_block_size = n_features / n_threads_ + (n_features % n_threads_ > 0); - size_t fidx_begin = tid * fidx_block_size; - size_t fidx_end = std::min(fidx_begin + fidx_block_size, n_features); - for (size_t ii = fidx_begin; ii < fidx_end; ++ii) { - for (int th = 0; th < n_threads_; ++th) { - if (IsCat(feature_types_, ii)) { - categories_[ii].merge(categories_buff[th * n_features + ii]); - } else { - typename WQSketch::SummaryContainer summary; - sketches_buff[th * n_features + ii].GetSummary(&summary); - sketches_[ii].PushSummary(summary); - } - } - } - }); - } - } else { - auto thread_columns_ptr = LoadBalance(batch, nnz, n_features, n_threads_, is_valid); - #pragma omp parallel num_threads(n_threads_) - { - exc.Run([&]() { - auto tid = static_cast(omp_get_thread_num()); - auto const begin = thread_columns_ptr[tid]; - auto const end = thread_columns_ptr[tid + 1]; - - // do not iterate if no columns are assigned to the thread - if (begin < end && end <= n_features) { - for (size_t ridx = 0; ridx < batch.Size(); ++ridx) { - auto const &line = batch.GetLine(ridx); - auto w = weights[ridx + base_rowid]; - if (is_dense) { - for (size_t ii = begin; ii < end; ii++) { - auto elem = line.GetElement(ii); + } else { + for (size_t ridx = 0; ridx < batch.Size(); ++ridx) { + auto const &line = batch.GetLine(ridx); + auto w = weights[ridx + base_rowid]; + if (is_dense) { + for (size_t ii = wl.columns.front(); ii <= wl.columns.back(); ++ii) { + if (!threads_wl.is_column_splited[ii]) { + auto const &elem = line.GetElement(ii); + /* elem.column_idx == ii */ if (is_valid(elem)) { - if (IsCat(feature_types_, ii)) { - categories_[ii].emplace(elem.value); - } else { - sketches_[ii].Push(elem.value, w); - } + PushElement(elem, &categories_[ii], &sketches_[ii], w); } } - } else { - for (size_t i = 0; i < line.Size(); ++i) { - auto const &elem = line.GetElement(i); - if (is_valid(elem) && elem.column_idx >= begin && elem.column_idx < end) { - if (IsCat(feature_types_, elem.column_idx)) { - categories_[elem.column_idx].emplace(elem.value); - } else { - sketches_[elem.column_idx].Push(elem.value, w); - } + } + } else { + // number of columns with idx >= wl.columns.front() + size_t n_columns_with_high_idx = n_features - wl.columns.front(); + size_t begin = line.Size() < n_columns_with_high_idx ? 0 : line.Size() - n_columns_with_high_idx; + size_t end = std::min(wl.columns.back() + 1, line.Size()); + for (size_t i = begin; i < end; ++i) { + auto const &elem = line.GetElement(i); + if (is_valid(elem)) { + if (!threads_wl.is_column_splited[elem.column_idx] && + (elem.column_idx >= wl.columns.front()) && + (elem.column_idx <= wl.columns.back())) { + PushElement(elem, &categories_[elem.column_idx], + &sketches_[elem.column_idx], w); } } } } } - }); - } + } + #pragma omp barrier + + if (wl.n_splits > 1 && wl.split_idx == 0) { + // The thread being responsible for the first block in split collect info from the other ones. + size_t column_idx = wl.columns.front(); + for (int th = tid + 1; th < tid + wl.n_splits; ++th) { + if (IsCat(feature_types_, column_idx)) { + categories_[column_idx].merge(categories_buff[th]); + } else { + typename WQSketch::SummaryContainer summary; + sketches_buff[th].GetSummary(&summary); + sketches_[column_idx].PushSummary(summary); + } + } + } + }); } + exc.Rethrow(); } @@ -963,6 +1001,18 @@ class SketchContainerImpl { private: // Merge all categories from other workers. void AllreduceCategories(Context const* ctx, MetaInfo const& info); + + template + void PushElement(const ElemType& elem, + std::set* categorie, + WQSketch* sketch, + float w) { + if (IsCat(feature_types_, elem.column_idx)) { + categorie->emplace(elem.value); + } else { + sketch->Push(elem.value, w); + } + } }; class HostSketchContainer : public SketchContainerImpl> { diff --git a/tests/cpp/common/test_quantile.cc b/tests/cpp/common/test_quantile.cc index 1ef6572599fc..6864e4d984e6 100644 --- a/tests/cpp/common/test_quantile.cc +++ b/tests/cpp/common/test_quantile.cc @@ -17,15 +17,19 @@ namespace xgboost::common { TEST(Quantile, LoadBalance) { size_t constexpr kRows = 1000, kCols = 100; auto m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(); - std::vector cols_ptr; + WLBalance threads_wl(kCols); Context ctx; for (auto const& page : m->GetBatches(&ctx)) { data::SparsePageAdapterBatch adapter{page.GetView()}; - cols_ptr = LoadBalance(adapter, page.data.Size(), kCols, 13, [](auto) { return true; }); + threads_wl = LoadBalance(adapter, page.data.Size(), kCols, 13, [](auto) { return true; }); + CHECK_LE(threads_wl.baskets.size(), common::OmpGetNumThreads(0)); } size_t n_cols = 0; - for (size_t i = 1; i < cols_ptr.size(); ++i) { - n_cols += cols_ptr[i] - cols_ptr[i - 1]; + for (const auto& basket : threads_wl.baskets) { + n_cols += basket.columns.size(); + for (size_t column : basket.columns) { + CHECK_LT(column, kCols); + } } CHECK_EQ(n_cols, kCols); } @@ -160,6 +164,12 @@ TEST(Quantile, DistributedBasic) { TestDistributedQuantile(kRows, kCols); } +TEST(Quantile, DistributedRowWise) { + size_t kRows = 2 * common::OmpGetNumThreads(0); + size_t kCols = 2; + TestDistributedQuantile(kRows, kCols); +} + TEST(Quantile, Distributed) { constexpr size_t kRows = 4000, kCols = 200; TestDistributedQuantile(kRows, kCols); @@ -170,6 +180,12 @@ TEST(Quantile, SortedDistributedBasic) { TestDistributedQuantile(kRows, kCols); } +TEST(Quantile, SortedDistributedRowWise) { + size_t kRows = 2 * common::OmpGetNumThreads(0); + size_t kCols = 2; + TestDistributedQuantile(kRows, kCols); +} + TEST(Quantile, SortedDistributed) { constexpr size_t kRows = 4000, kCols = 200; TestDistributedQuantile(kRows, kCols); @@ -288,6 +304,12 @@ TEST(Quantile, ColumnSplitBasic) { TestColSplitQuantile(kRows, kCols); } +TEST(Quantile, ColumnSplitRowWise) { + size_t kRows = 2 * common::OmpGetNumThreads(0); + size_t kCols = 2; + TestColSplitQuantile(kRows, kCols); +} + TEST(Quantile, ColumnSplit) { constexpr size_t kRows = 4000, kCols = 200; TestColSplitQuantile(kRows, kCols); @@ -298,6 +320,12 @@ TEST(Quantile, ColumnSplitSortedBasic) { TestColSplitQuantile(kRows, kCols); } +TEST(Quantile, ColumnSplitSortedRowWise) { + size_t kRows = 2 * common::OmpGetNumThreads(0); + size_t kCols = 2; + TestColSplitQuantile(kRows, kCols); +} + TEST(Quantile, ColumnSplitSorted) { constexpr size_t kRows = 4000, kCols = 200; TestColSplitQuantile(kRows, kCols); From 836c768c01d24f41ce39e332321893f0b6c577f6 Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin <> Date: Tue, 4 Feb 2025 05:27:18 -0800 Subject: [PATCH 05/18] linting --- src/common/quantile.h | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/common/quantile.h b/src/common/quantile.h index 3fec78636e13..ceaed368260e 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -749,7 +749,7 @@ std::vector CalcColumnSize(Batch const &batch, bst_feature_t const n_ } struct WLBalance { - WLBalance(size_t n_columns) : is_column_splited(n_columns) {} + explicit WLBalance(size_t n_columns) : is_column_splited(n_columns) {} struct ThreadWorkLoad { std::vector columns; @@ -785,7 +785,7 @@ WLBalance LoadBalance(Batch const &batch, size_t nnz, bst_feature_t n_columns, // Need to calculate the size for each batch. std::vector entries_per_columns = CalcColumnSize(batch, n_columns, nthreads, is_valid); - + size_t count = 0; for (size_t column_idx = 0; column_idx < n_columns; ++column_idx) { size_t n_entries = entries_per_columns[column_idx]; @@ -929,7 +929,8 @@ class SketchContainerImpl { } } else { size_t n_columns_with_high_idx = n_features - column; - size_t begin = line.Size() < n_columns_with_high_idx ? 0 : line.Size() - n_columns_with_high_idx; + size_t begin = line.Size() < n_columns_with_high_idx ? 0 + : line.Size() - n_columns_with_high_idx; size_t end = std::min(column + 1, line.Size()); for (size_t i = begin; i < end; ++i) { auto const &elem = line.GetElement(i); @@ -956,13 +957,14 @@ class SketchContainerImpl { } else { // number of columns with idx >= wl.columns.front() size_t n_columns_with_high_idx = n_features - wl.columns.front(); - size_t begin = line.Size() < n_columns_with_high_idx ? 0 : line.Size() - n_columns_with_high_idx; + size_t begin = line.Size() < n_columns_with_high_idx ? 0 + : line.Size() - n_columns_with_high_idx; size_t end = std::min(wl.columns.back() + 1, line.Size()); for (size_t i = begin; i < end; ++i) { auto const &elem = line.GetElement(i); if (is_valid(elem)) { if (!threads_wl.is_column_splited[elem.column_idx] && - (elem.column_idx >= wl.columns.front()) && + (elem.column_idx >= wl.columns.front()) && (elem.column_idx <= wl.columns.back())) { PushElement(elem, &categories_[elem.column_idx], &sketches_[elem.column_idx], w); @@ -975,7 +977,9 @@ class SketchContainerImpl { #pragma omp barrier if (wl.n_splits > 1 && wl.split_idx == 0) { - // The thread being responsible for the first block in split collect info from the other ones. + /* The thread being responsible for the first block in split + * collect info from the other ones. + */ size_t column_idx = wl.columns.front(); for (int th = tid + 1; th < tid + wl.n_splits; ++th) { if (IsCat(feature_types_, column_idx)) { From 49e309c5867ca7ee302b99b4ea9a1b1e734f3ccd Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin <> Date: Tue, 4 Feb 2025 05:45:01 -0800 Subject: [PATCH 06/18] fix build --- src/common/quantile.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/common/quantile.h b/src/common/quantile.h index ceaed368260e..629061d8f42e 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -981,7 +981,7 @@ class SketchContainerImpl { * collect info from the other ones. */ size_t column_idx = wl.columns.front(); - for (int th = tid + 1; th < tid + wl.n_splits; ++th) { + for (size_t th = tid + 1; th < tid + wl.n_splits; ++th) { if (IsCat(feature_types_, column_idx)) { categories_[column_idx].merge(categories_buff[th]); } else { From d30e6575ab6429305d1c5ac3d1b814b18f3f466b Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin <> Date: Tue, 4 Feb 2025 06:08:42 -0800 Subject: [PATCH 07/18] fix tests --- src/common/quantile.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/common/quantile.h b/src/common/quantile.h index 629061d8f42e..4b89fbaaf928 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -914,7 +914,7 @@ class SketchContainerImpl { sketches_out = &sketches_buff[tid]; } - size_t split_size = batch.Size() / wl.n_splits; + size_t split_size = DivRoundUp(batch.Size(), wl.n_splits); size_t begin = wl.split_idx * split_size; size_t end = std::min(begin + split_size, batch.Size()); From e00bf4ae49c661be0b93849b59e72df01987f246 Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin <> Date: Tue, 4 Feb 2025 06:34:32 -0800 Subject: [PATCH 08/18] fix tests with empty columns --- src/common/quantile.h | 40 +++++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/src/common/quantile.h b/src/common/quantile.h index 4b89fbaaf928..bb2245c04bc7 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -790,29 +790,31 @@ WLBalance LoadBalance(Batch const &batch, size_t nnz, bst_feature_t n_columns, for (size_t column_idx = 0; column_idx < n_columns; ++column_idx) { size_t n_entries = entries_per_columns[column_idx]; - size_t n_splits = n_entries / entries_per_thread; - if (n_splits > 1) { - // Split column between threads - wl_balance.has_splitted = true; - wl_balance.is_column_splited[column_idx] = true; - for (size_t split_idx = 0; split_idx < n_splits; split_idx++) { - wl_baskets.emplace_back(); + if (n_entries > 0) { + size_t n_splits = n_entries / entries_per_thread; + if (n_splits > 1) { + // Split column between threads + wl_balance.has_splitted = true; + wl_balance.is_column_splited[column_idx] = true; + for (size_t split_idx = 0; split_idx < n_splits; split_idx++) { + wl_baskets.emplace_back(); + + auto& wl = wl_baskets.back(); + wl.columns.push_back(column_idx); + wl.split_idx = split_idx; + wl.n_splits = n_splits; + } + } else { + if (wl_baskets.empty() || count > entries_per_thread) { + wl_baskets.emplace_back(); + count = 0; + } + count += n_entries; auto& wl = wl_baskets.back(); wl.columns.push_back(column_idx); - wl.split_idx = split_idx; - wl.n_splits = n_splits; + wl_balance.is_column_splited[column_idx] = false; } - } else { - if (wl_baskets.empty() || count > entries_per_thread) { - wl_baskets.emplace_back(); - count = 0; - } - count += n_entries; - - auto& wl = wl_baskets.back(); - wl.columns.push_back(column_idx); - wl_balance.is_column_splited[column_idx] = false; } } From 2c0e5ddada309983bce08bf0f79eebfd29df1792 Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin <> Date: Tue, 4 Feb 2025 06:51:38 -0800 Subject: [PATCH 09/18] remove incorrect check --- tests/cpp/common/test_quantile.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/cpp/common/test_quantile.cc b/tests/cpp/common/test_quantile.cc index 6864e4d984e6..3af1ef62e05a 100644 --- a/tests/cpp/common/test_quantile.cc +++ b/tests/cpp/common/test_quantile.cc @@ -22,7 +22,6 @@ TEST(Quantile, LoadBalance) { for (auto const& page : m->GetBatches(&ctx)) { data::SparsePageAdapterBatch adapter{page.GetView()}; threads_wl = LoadBalance(adapter, page.data.Size(), kCols, 13, [](auto) { return true; }); - CHECK_LE(threads_wl.baskets.size(), common::OmpGetNumThreads(0)); } size_t n_cols = 0; for (const auto& basket : threads_wl.baskets) { From 79824155a66e8b9d61d022de24caebeb71fda5fe Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin <> Date: Wed, 5 Feb 2025 04:03:49 -0800 Subject: [PATCH 10/18] fix tilda warning and non-omp tests --- src/common/quantile.h | 6 +++--- tests/cpp/common/test_quantile.cc | 16 ++++++++-------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/common/quantile.h b/src/common/quantile.h index bb2245c04bc7..850491f0ec37 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -753,10 +753,10 @@ struct WLBalance { struct ThreadWorkLoad { std::vector columns; - size_t split_idx; - size_t n_splits; + size_t split_idx = 0; + size_t n_splits = 1; - ThreadWorkLoad() : columns(), split_idx(0), n_splits(1) {} + ThreadWorkLoad() : columns() {} }; std::vector baskets; diff --git a/tests/cpp/common/test_quantile.cc b/tests/cpp/common/test_quantile.cc index 3af1ef62e05a..891238fd6418 100644 --- a/tests/cpp/common/test_quantile.cc +++ b/tests/cpp/common/test_quantile.cc @@ -164,8 +164,8 @@ TEST(Quantile, DistributedBasic) { } TEST(Quantile, DistributedRowWise) { - size_t kRows = 2 * common::OmpGetNumThreads(0); - size_t kCols = 2; + size_t kRows = 10 * common::OmpGetNumThreads(0); + size_t kCols = 10; TestDistributedQuantile(kRows, kCols); } @@ -180,8 +180,8 @@ TEST(Quantile, SortedDistributedBasic) { } TEST(Quantile, SortedDistributedRowWise) { - size_t kRows = 2 * common::OmpGetNumThreads(0); - size_t kCols = 2; + size_t kRows = 10 * common::OmpGetNumThreads(0); + size_t kCols = 10; TestDistributedQuantile(kRows, kCols); } @@ -304,8 +304,8 @@ TEST(Quantile, ColumnSplitBasic) { } TEST(Quantile, ColumnSplitRowWise) { - size_t kRows = 2 * common::OmpGetNumThreads(0); - size_t kCols = 2; + size_t kRows = 10 * common::OmpGetNumThreads(0); + size_t kCols = 10; TestColSplitQuantile(kRows, kCols); } @@ -320,8 +320,8 @@ TEST(Quantile, ColumnSplitSortedBasic) { } TEST(Quantile, ColumnSplitSortedRowWise) { - size_t kRows = 2 * common::OmpGetNumThreads(0); - size_t kCols = 2; + size_t kRows = 10 * common::OmpGetNumThreads(0); + size_t kCols = 10; TestColSplitQuantile(kRows, kCols); } From a0d5bd6bf44e872f78735980feb58979dd5d4fe8 Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin <> Date: Thu, 6 Feb 2025 02:18:51 -0800 Subject: [PATCH 11/18] fixes for tests --- src/common/quantile.h | 4 ++-- tests/cpp/common/test_quantile.cc | 6 ------ 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/src/common/quantile.h b/src/common/quantile.h index 850491f0ec37..b4aa091fd6b2 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -959,8 +959,8 @@ class SketchContainerImpl { } else { // number of columns with idx >= wl.columns.front() size_t n_columns_with_high_idx = n_features - wl.columns.front(); - size_t begin = line.Size() < n_columns_with_high_idx ? 0 - : line.Size() - n_columns_with_high_idx; + size_t begin = line.Size() < n_columns_with_high_idx + ? 0 : line.Size() - n_columns_with_high_idx; size_t end = std::min(wl.columns.back() + 1, line.Size()); for (size_t i = begin; i < end; ++i) { auto const &elem = line.GetElement(i); diff --git a/tests/cpp/common/test_quantile.cc b/tests/cpp/common/test_quantile.cc index 891238fd6418..3f925b339e69 100644 --- a/tests/cpp/common/test_quantile.cc +++ b/tests/cpp/common/test_quantile.cc @@ -319,12 +319,6 @@ TEST(Quantile, ColumnSplitSortedBasic) { TestColSplitQuantile(kRows, kCols); } -TEST(Quantile, ColumnSplitSortedRowWise) { - size_t kRows = 10 * common::OmpGetNumThreads(0); - size_t kCols = 10; - TestColSplitQuantile(kRows, kCols); -} - TEST(Quantile, ColumnSplitSorted) { constexpr size_t kRows = 4000, kCols = 200; TestColSplitQuantile(kRows, kCols); From 9280b5f5c0b00a336113c8f395bbc61ffe8a4de7 Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin <> Date: Thu, 6 Feb 2025 02:27:30 -0800 Subject: [PATCH 12/18] remove unpropper added tests --- tests/cpp/common/test_quantile.cc | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/cpp/common/test_quantile.cc b/tests/cpp/common/test_quantile.cc index 3f925b339e69..f09e892d221f 100644 --- a/tests/cpp/common/test_quantile.cc +++ b/tests/cpp/common/test_quantile.cc @@ -179,12 +179,6 @@ TEST(Quantile, SortedDistributedBasic) { TestDistributedQuantile(kRows, kCols); } -TEST(Quantile, SortedDistributedRowWise) { - size_t kRows = 10 * common::OmpGetNumThreads(0); - size_t kCols = 10; - TestDistributedQuantile(kRows, kCols); -} - TEST(Quantile, SortedDistributed) { constexpr size_t kRows = 4000, kCols = 200; TestDistributedQuantile(kRows, kCols); From 8130d9443a0d93aa107f459b1f4e462dccc4c4ec Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin <> Date: Tue, 11 Feb 2025 07:19:36 -0800 Subject: [PATCH 13/18] refactor the changes --- src/common/quantile.h | 74 +++++++++++++++++------------- tests/python/test_data_iterator.py | 2 +- 2 files changed, 42 insertions(+), 34 deletions(-) diff --git a/src/common/quantile.h b/src/common/quantile.h index b4aa091fd6b2..34beeb2596c7 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -791,7 +791,7 @@ WLBalance LoadBalance(Batch const &batch, size_t nnz, bst_feature_t n_columns, size_t n_entries = entries_per_columns[column_idx]; if (n_entries > 0) { - size_t n_splits = n_entries / entries_per_thread; + size_t n_splits = std::min(nthreads * n_entries / total_entries, n_entries); if (n_splits > 1) { // Split column between threads wl_balance.has_splitted = true; @@ -881,18 +881,20 @@ class SketchContainerImpl { template void PushRowPageImpl(Batch const &batch, size_t base_rowid, OptionalWeights weights, size_t nnz, size_t n_features, bool is_dense, IsValid is_valid) { - dmlc::OMPException exc; - auto threads_wl = LoadBalance(batch, nnz, n_features, n_threads_, is_valid); if (threads_wl.baskets.empty()) return; std::vector> categories_buff; std::vector sketches_buff; + std::vector buff_was_used; + if (threads_wl.has_splitted) { sketches_buff.resize(threads_wl.baskets.size()); categories_buff.resize(threads_wl.baskets.size()); + buff_was_used.resize(threads_wl.baskets.size(), 0); } + dmlc::OMPException exc; #pragma omp parallel num_threads(threads_wl.baskets.size()) { exc.Run([&]() { @@ -902,19 +904,9 @@ class SketchContainerImpl { // We process only a single column in this case size_t column = wl.columns.front(); - std::set* categories_out; - WQSketch* sketches_out; - if (wl.split_idx == 0) { - categories_out = &categories_[column]; - sketches_out = &sketches_[column]; - } else { - auto n_bins = std::min(static_cast(max_bins_), columns_size_[column]); - auto eps = 1.0 / (static_cast(n_bins) * WQSketch::kFactor); - sketches_buff[tid].Init(columns_size_[column], eps); - - categories_out = &categories_buff[tid]; - sketches_out = &sketches_buff[tid]; - } + auto n_bins = std::min(static_cast(max_bins_), columns_size_[column]); + auto eps = 1.0 / (static_cast(n_bins) * WQSketch::kFactor); + sketches_buff[tid].Init(columns_size_[column], eps); size_t split_size = DivRoundUp(batch.Size(), wl.n_splits); size_t begin = wl.split_idx * split_size; @@ -927,17 +919,19 @@ class SketchContainerImpl { auto const &elem = line.GetElement(column); /* elem.column_idx == column */ if (is_valid(elem)) { - PushElement(elem, categories_out, sketches_out, w); + buff_was_used[tid] = 1; + PushElement(elem, &categories_buff[tid], &sketches_buff[tid], w); } } else { size_t n_columns_with_high_idx = n_features - column; - size_t begin = line.Size() < n_columns_with_high_idx ? 0 + size_t col_begin = line.Size() < n_columns_with_high_idx ? 0 : line.Size() - n_columns_with_high_idx; - size_t end = std::min(column + 1, line.Size()); - for (size_t i = begin; i < end; ++i) { + size_t col_end = std::min(column + 1, line.Size()); + for (size_t i = col_begin; i < col_end; ++i) { auto const &elem = line.GetElement(i); if (is_valid(elem) && (elem.column_idx == column)) { - PushElement(elem, categories_out, sketches_out, w); + buff_was_used[tid] = 1; + PushElement(elem, &categories_buff[tid], &sketches_buff[tid], w); } } } @@ -959,10 +953,10 @@ class SketchContainerImpl { } else { // number of columns with idx >= wl.columns.front() size_t n_columns_with_high_idx = n_features - wl.columns.front(); - size_t begin = line.Size() < n_columns_with_high_idx + size_t col_begin = line.Size() < n_columns_with_high_idx ? 0 : line.Size() - n_columns_with_high_idx; - size_t end = std::min(wl.columns.back() + 1, line.Size()); - for (size_t i = begin; i < end; ++i) { + size_t col_end = std::min(wl.columns.back() + 1, line.Size()); + for (size_t i = col_begin; i < col_end; ++i) { auto const &elem = line.GetElement(i); if (is_valid(elem)) { if (!threads_wl.is_column_splited[elem.column_idx] && @@ -976,26 +970,40 @@ class SketchContainerImpl { } } } - #pragma omp barrier + // #pragma omp flush(buff_was_used, categories_buff, sketches_buff) + #pragma omp barrier if (wl.n_splits > 1 && wl.split_idx == 0) { /* The thread being responsible for the first block in split * collect info from the other ones. */ size_t column_idx = wl.columns.front(); - for (size_t th = tid + 1; th < tid + wl.n_splits; ++th) { - if (IsCat(feature_types_, column_idx)) { - categories_[column_idx].merge(categories_buff[th]); - } else { - typename WQSketch::SummaryContainer summary; - sketches_buff[th].GetSummary(&summary); - sketches_[column_idx].PushSummary(summary); + + typename WQSketch::SummaryContainer main_summary; + main_summary.Reserve(sketches_[column_idx].limit_size); + typename WQSketch::SummaryContainer split_summary; + split_summary.Reserve(2 * sketches_[column_idx].limit_size); + typename WQSketch::SummaryContainer comb_summary; + comb_summary.Reserve(3 * sketches_[column_idx].limit_size); + + for (size_t th = tid + 0; th < tid + wl.n_splits; ++th) { + CHECK_LT(th, threads_wl.baskets.size()); + // Make shure some work was done by thread + if (buff_was_used[th] > 0) { + if (IsCat(feature_types_, column_idx)) { + categories_[column_idx].merge(categories_buff[th]); + } else { + sketches_buff[th].GetSummary(&split_summary); + + comb_summary.SetCombine(main_summary, split_summary); + main_summary.SetPrune(comb_summary, sketches_[column_idx].limit_size); + } } } + sketches_[column_idx].PushSummary(main_summary); } }); } - exc.Rethrow(); } diff --git a/tests/python/test_data_iterator.py b/tests/python/test_data_iterator.py index 6a980f967a97..2a875070016d 100644 --- a/tests/python/test_data_iterator.py +++ b/tests/python/test_data_iterator.py @@ -169,7 +169,7 @@ def run_data_iterator( @given( strategies.integers(0, 1024), strategies.integers(1, 7), - strategies.integers(0, 13), + strategies.integers(0, 2), strategies.booleans(), ) @settings(deadline=None, max_examples=10, print_blob=True) From c09dc9ef6e62cc312a35aff4c3e5a68567e12d33 Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin <> Date: Tue, 11 Feb 2025 07:21:53 -0800 Subject: [PATCH 14/18] remove commented code --- src/common/quantile.h | 1 - 1 file changed, 1 deletion(-) diff --git a/src/common/quantile.h b/src/common/quantile.h index 34beeb2596c7..0ed22f7de2ad 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -971,7 +971,6 @@ class SketchContainerImpl { } } - // #pragma omp flush(buff_was_used, categories_buff, sketches_buff) #pragma omp barrier if (wl.n_splits > 1 && wl.split_idx == 0) { /* The thread being responsible for the first block in split From ab2462a2d4c9b67605e4c0eaa020fa1fca435c37 Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin <> Date: Mon, 17 Feb 2025 06:30:06 -0800 Subject: [PATCH 15/18] return the minimal block size --- src/common/quantile.h | 5 +++-- tests/cpp/common/test_quantile.cc | 4 ++-- tests/python/test_data_iterator.py | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/common/quantile.h b/src/common/quantile.h index 0ed22f7de2ad..a7d828cc856a 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -791,8 +791,9 @@ WLBalance LoadBalance(Batch const &batch, size_t nnz, bst_feature_t n_columns, size_t n_entries = entries_per_columns[column_idx]; if (n_entries > 0) { - size_t n_splits = std::min(nthreads * n_entries / total_entries, n_entries); - if (n_splits > 1) { + size_t n_splits = std::min(nthreads * n_entries / total_entries, n_entries); + constexpr size_t kMinBlockSize = 1024; + if ((n_splits > 1) && (kMinBlockSize * n_splits < n_entries)) { // Split column between threads wl_balance.has_splitted = true; wl_balance.is_column_splited[column_idx] = true; diff --git a/tests/cpp/common/test_quantile.cc b/tests/cpp/common/test_quantile.cc index f09e892d221f..48577d21ef88 100644 --- a/tests/cpp/common/test_quantile.cc +++ b/tests/cpp/common/test_quantile.cc @@ -164,7 +164,7 @@ TEST(Quantile, DistributedBasic) { } TEST(Quantile, DistributedRowWise) { - size_t kRows = 10 * common::OmpGetNumThreads(0); + size_t kRows = 1024 * common::OmpGetNumThreads(0); size_t kCols = 10; TestDistributedQuantile(kRows, kCols); } @@ -298,7 +298,7 @@ TEST(Quantile, ColumnSplitBasic) { } TEST(Quantile, ColumnSplitRowWise) { - size_t kRows = 10 * common::OmpGetNumThreads(0); + size_t kRows = 1024 * common::OmpGetNumThreads(0); size_t kCols = 10; TestColSplitQuantile(kRows, kCols); } diff --git a/tests/python/test_data_iterator.py b/tests/python/test_data_iterator.py index 2a875070016d..6a980f967a97 100644 --- a/tests/python/test_data_iterator.py +++ b/tests/python/test_data_iterator.py @@ -169,7 +169,7 @@ def run_data_iterator( @given( strategies.integers(0, 1024), strategies.integers(1, 7), - strategies.integers(0, 2), + strategies.integers(0, 13), strategies.booleans(), ) @settings(deadline=None, max_examples=10, print_blob=True) From f2b33b890c120d2e0e3ac226efcc6c4bec6f649a Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin <> Date: Mon, 17 Feb 2025 07:08:20 -0800 Subject: [PATCH 16/18] param tuning --- src/common/quantile.h | 2 +- tests/cpp/common/test_quantile.cc | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/common/quantile.h b/src/common/quantile.h index a7d828cc856a..97eeb593b869 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -792,7 +792,7 @@ WLBalance LoadBalance(Batch const &batch, size_t nnz, bst_feature_t n_columns, if (n_entries > 0) { size_t n_splits = std::min(nthreads * n_entries / total_entries, n_entries); - constexpr size_t kMinBlockSize = 1024; + constexpr size_t kMinBlockSize = (1u << 16); if ((n_splits > 1) && (kMinBlockSize * n_splits < n_entries)) { // Split column between threads wl_balance.has_splitted = true; diff --git a/tests/cpp/common/test_quantile.cc b/tests/cpp/common/test_quantile.cc index 48577d21ef88..5b0bf69055dd 100644 --- a/tests/cpp/common/test_quantile.cc +++ b/tests/cpp/common/test_quantile.cc @@ -164,7 +164,7 @@ TEST(Quantile, DistributedBasic) { } TEST(Quantile, DistributedRowWise) { - size_t kRows = 1024 * common::OmpGetNumThreads(0); + size_t kRows = (1u << 16) * common::OmpGetNumThreads(0); size_t kCols = 10; TestDistributedQuantile(kRows, kCols); } @@ -298,7 +298,7 @@ TEST(Quantile, ColumnSplitBasic) { } TEST(Quantile, ColumnSplitRowWise) { - size_t kRows = 1024 * common::OmpGetNumThreads(0); + size_t kRows = (1u << 16) * common::OmpGetNumThreads(0); size_t kCols = 10; TestColSplitQuantile(kRows, kCols); } From 79d9cb620f0dd8ebb2a5042e0676bed5bc106185 Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin <> Date: Mon, 17 Feb 2025 07:20:29 -0800 Subject: [PATCH 17/18] reduce number of data in tests --- tests/cpp/common/test_quantile.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/cpp/common/test_quantile.cc b/tests/cpp/common/test_quantile.cc index 5b0bf69055dd..b54c2a96391b 100644 --- a/tests/cpp/common/test_quantile.cc +++ b/tests/cpp/common/test_quantile.cc @@ -165,7 +165,7 @@ TEST(Quantile, DistributedBasic) { TEST(Quantile, DistributedRowWise) { size_t kRows = (1u << 16) * common::OmpGetNumThreads(0); - size_t kCols = 10; + size_t kCols = 2; TestDistributedQuantile(kRows, kCols); } @@ -299,7 +299,7 @@ TEST(Quantile, ColumnSplitBasic) { TEST(Quantile, ColumnSplitRowWise) { size_t kRows = (1u << 16) * common::OmpGetNumThreads(0); - size_t kCols = 10; + size_t kCols = 2; TestColSplitQuantile(kRows, kCols); } From 66049396e540dbc217e5e64b46943475111a8566 Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin Date: Tue, 25 Feb 2025 12:51:55 +0100 Subject: [PATCH 18/18] Update src/common/quantile.h Co-authored-by: Jiaming Yuan --- src/common/quantile.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/common/quantile.h b/src/common/quantile.h index 97eeb593b869..0e7857ec2852 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -769,7 +769,7 @@ template WLBalance LoadBalance(Batch const &batch, size_t nnz, bst_feature_t n_columns, size_t const nthreads, IsValid&& is_valid) { /* Some datasets have long columns. It is beneficial to split such columns between threads and - * than collect the result if number of threads is high enourth. In this case, each thread being + * than collect the result if the number of threads is high enough. In this case, each thread being * involved in processing of splitted columns works only with a single column. * * Columns that are too small for splitting are distributed between threads. In this case each thread