Skip to content

Commit a15b93c

Browse files
authored
fix fp32 compatibility (#78)
* fix fp32 compatibility * fix build --------- Co-authored-by: Dmitry Razdoburdin <>
1 parent 95fd456 commit a15b93c

File tree

1 file changed

+16
-5
lines changed

1 file changed

+16
-5
lines changed

plugin/sycl/common/optional_weight.cc

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,13 @@
88
#include "../device_manager.h"
99

1010
namespace xgboost::common::sycl_impl {
11-
double SumOptionalWeights(Context const* ctx, OptionalWeights const& weights) {
12-
sycl::DeviceManager device_manager;
13-
auto* qu = device_manager.GetQueue(ctx->Device());
1411

12+
template <typename T>
13+
T ElementWiseSum(::sycl::queue* qu, OptionalWeights const& weights) {
1514
const auto* data = weights.Data();
16-
double result = 0;
15+
T result = 0;
1716
{
18-
::sycl::buffer<double> buff(&result, 1);
17+
::sycl::buffer<T> buff(&result, 1);
1918
qu->submit([&](::sycl::handler& cgh) {
2019
auto reduction = ::sycl::reduction(buff, cgh, ::sycl::plus<>());
2120
cgh.parallel_for<>(::sycl::range<1>(weights.Size()), reduction,
@@ -28,4 +27,16 @@ double SumOptionalWeights(Context const* ctx, OptionalWeights const& weights) {
2827

2928
return result;
3029
}
30+
31+
double SumOptionalWeights(Context const* ctx, OptionalWeights const& weights) {
32+
sycl::DeviceManager device_manager;
33+
auto* qu = device_manager.GetQueue(ctx->Device());
34+
35+
bool has_fp64_support = qu->get_device().has(::sycl::aspect::fp64);
36+
if (has_fp64_support) {
37+
return ElementWiseSum<double>(qu, weights);
38+
} else {
39+
return ElementWiseSum<float>(qu, weights);
40+
}
41+
}
3142
} // namespace xgboost::common::sycl_impl

0 commit comments

Comments
 (0)