88#include " ../device_manager.h"
99
1010namespace xgboost ::common::sycl_impl {
11- double SumOptionalWeights (Context const * ctx, OptionalWeights const & weights) {
12- sycl::DeviceManager device_manager;
13- auto * qu = device_manager.GetQueue (ctx->Device ());
1411
12+ template <typename T>
13+ T ElementWiseSum (::sycl::queue* qu, OptionalWeights const & weights) {
1514 const auto * data = weights.Data ();
16- double result = 0 ;
15+ T result = 0 ;
1716 {
18- ::sycl::buffer<double > buff (&result, 1 );
17+ ::sycl::buffer<T > buff (&result, 1 );
1918 qu->submit ([&](::sycl::handler& cgh) {
2019 auto reduction = ::sycl::reduction (buff, cgh, ::sycl::plus<>());
2120 cgh.parallel_for <>(::sycl::range<1 >(weights.Size ()), reduction,
@@ -28,4 +27,16 @@ double SumOptionalWeights(Context const* ctx, OptionalWeights const& weights) {
2827
2928 return result;
3029}
30+
31+ double SumOptionalWeights (Context const * ctx, OptionalWeights const & weights) {
32+ sycl::DeviceManager device_manager;
33+ auto * qu = device_manager.GetQueue (ctx->Device ());
34+
35+ bool has_fp64_support = qu->get_device ().has (::sycl::aspect::fp64);
36+ if (has_fp64_support) {
37+ return ElementWiseSum<double >(qu, weights);
38+ } else {
39+ return ElementWiseSum<float >(qu, weights);
40+ }
41+ }
3142} // namespace xgboost::common::sycl_impl
0 commit comments