SYCL. Unify regression objective calculation (#11016)
razdoburdin authored Nov 28, 2024
1 parent 588198e commit 325fe41
Showing 7 changed files with 146 additions and 206 deletions.
28 changes: 28 additions & 0 deletions plugin/sycl/common/linalg_op.h
@@ -8,6 +8,8 @@
#include <vector>
#include <utility>

#include "../../../src/common/linalg_op.h"

#include "../data.h"
#include "../device_manager.h"

@@ -69,6 +71,32 @@ void ElementWiseKernel(TensorView<T, D> t, Fn&& fn) {
}).wait_and_throw();
}

template <typename T, int32_t D, typename Fn>
bool Validate(DeviceOrd device, TensorView<T, D> t, Fn&& fn) {
sycl::DeviceManager device_manager;
auto* qu = device_manager.GetQueue(t.Device());

int flag = 0;
{
::sycl::buffer<int, 1> flag_buf(&flag, 1);
qu->submit([&](::sycl::handler& cgh) {
auto flag_acc = flag_buf.get_access<::sycl::access::mode::write>(cgh);
cgh.parallel_for<>(::sycl::range<1>(t.Size()),
[=](::sycl::id<1> pid) {
const size_t idx = pid[0];
const T& value = call(t, xgboost::linalg::UnravelIndex(idx, t.Shape()));
bool is_valid = const_cast<Fn&&>(fn)(value);
if (!is_valid) {
AtomicRef<int> flag_ref(flag_acc[0]);
flag_ref = 1;
}
});
});
}
qu->wait_and_throw();
return (flag == 0);
}

} // namespace linalg
} // namespace sycl

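A minimal usage sketch for the new Validate helper (not part of this commit; the wrapper function, the binary-label predicate, and the include path below are illustrative assumptions; the real call site is the regression_obj.cu hunk further down, which passes Loss::CheckLabel):

#include "../../plugin/sycl/common/linalg_op.h"  // include path assumed; adjust to the build layout

// Returns true only if every label lies in [0, 1]. The predicate runs inside a SYCL
// kernel over the whole tensor; the first failing element raises the shared flag.
template <typename TensorViewT>
bool LabelsAreBinary(xgboost::DeviceOrd device, TensorViewT labels) {
  return xgboost::sycl::linalg::Validate(
      device, labels, [](float y) -> bool { return y >= 0.0f && y <= 1.0f; });
}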
35 changes: 35 additions & 0 deletions plugin/sycl/common/transform.h
@@ -0,0 +1,35 @@
/**
* Copyright 2021-2024, XGBoost Contributors
* \file transform.h
*/
#ifndef PLUGIN_SYCL_COMMON_TRANSFORM_H_
#define PLUGIN_SYCL_COMMON_TRANSFORM_H_

#include "../device_manager.h"

#include <sycl/sycl.hpp>

namespace xgboost {
namespace sycl {
namespace common {

template <typename Functor, typename... SpanType>
void LaunchSyclKernel(DeviceOrd device, Functor&& _func, xgboost::common::Range _range,
SpanType... _spans) {
sycl::DeviceManager device_manager;
auto* qu = device_manager.GetQueue(device);

size_t size = *(_range.end());
qu->submit([&](::sycl::handler& cgh) {
cgh.parallel_for<>(::sycl::range<1>(size),
[=](::sycl::id<1> pid) {
const size_t idx = pid[0];
const_cast<Functor&&>(_func)(idx, _spans...);
});
}).wait();
}

} // namespace common
} // namespace sycl
} // namespace xgboost
#endif // PLUGIN_SYCL_COMMON_TRANSFORM_H_
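A usage sketch for LaunchSyclKernel (illustrative, not from the commit): the launcher invokes the functor as _func(idx, spans...) with one work-item per index in the range, so any copyable callable over xgboost::common::Span arguments will do. The helper names below are assumptions.

#include <cstddef>

#include "../../plugin/sycl/common/transform.h"  // include path assumed

// Doubles every element of `data` on the chosen SYCL device. Assumes `data` refers to
// memory the device can access, which Transform::LaunchSycl arranges by unpacking
// HostDeviceVectors into device spans before calling this launcher.
struct DoubleOp {
  void operator()(std::size_t idx, xgboost::common::Span<float> data) const {
    data[idx] *= 2.0f;
  }
};

void DoubleOnDevice(xgboost::DeviceOrd device, xgboost::common::Span<float> data) {
  xgboost::common::Range range{
      0, static_cast<xgboost::common::Range::DifferenceType>(data.size())};
  xgboost::sycl::common::LaunchSyclKernel(device, DoubleOp{}, range, data);
}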
200 changes: 0 additions & 200 deletions plugin/sycl/objective/regression_obj.cc

This file was deleted.

4 changes: 2 additions & 2 deletions src/common/stats.cc
@@ -66,7 +66,7 @@ void Mean(Context const* ctx, linalg::Vector<float> const& v, linalg::Vector<flo
void SampleMean(Context const* ctx, bool is_column_split, linalg::Matrix<float> const& v,
linalg::Vector<float>* out) {
*out = linalg::Zeros<float>(ctx, std::max(v.Shape(1), decltype(v.Shape(1)){1}));
if (ctx->IsCPU()) {
if (!ctx->IsCUDA()) {
auto h_v = v.HostView();
CHECK(h_v.CContiguous());
std::int64_t n_samples = v.Shape(0);
@@ -94,7 +94,7 @@ void WeightedSampleMean(Context const* ctx, bool is_column_split, linalg::Matrix
HostDeviceVector<float> const& w, linalg::Vector<float>* out) {
*out = linalg::Zeros<float>(ctx, std::max(v.Shape(1), decltype(v.Shape(1)){1}));
CHECK_EQ(v.Shape(0), w.Size());
if (ctx->IsCPU()) {
if (!ctx->IsCUDA()) {
auto h_v = v.HostView();
auto h_w = w.ConstHostSpan();
auto sum_w = std::accumulate(h_w.data(), h_w.data() + h_w.size(), 0.0);
26 changes: 23 additions & 3 deletions src/common/transform.h
@@ -21,6 +21,10 @@
#include "device_helpers.cuh"
#endif // defined (__CUDACC__)

#if defined (SYCL_LANGUAGE_VERSION)
#include "../plugin/sycl/common/transform.h"
#endif // defined (SYCL_LANGUAGE_VERSION)

namespace xgboost {
namespace common {

@@ -71,10 +75,10 @@ class Transform {
*/
template <typename... HDV>
void Eval(HDV... vectors) const {
bool on_device = device_.IsCUDA();

if (on_device) {
if (device_.IsCUDA()) {
LaunchCUDA(func_, vectors...);
} else if (device_.IsSycl()) {
LaunchSycl(func_, vectors...);
} else {
LaunchCPU(func_, vectors...);
}
@@ -160,6 +164,22 @@ class Transform {
}
#endif // defined(__CUDACC__)

#if defined (SYCL_LANGUAGE_VERSION)
template <typename... HDV>
void LaunchSycl(Functor _func, HDV*... _vectors) const {
UnpackShard(device_, _vectors...);

size_t range_size = *range_.end() - *range_.begin();
Range shard_range {0, static_cast<Range::DifferenceType>(range_size)};
sycl::common::LaunchSyclKernel(device_, _func, shard_range, UnpackHDVOnDevice(_vectors)...);
}
#else
template <typename... HDV>
void LaunchSycl(Functor _func, HDV *... _vectors) const {
LaunchCPU(_func, _vectors...);
}
#endif // defined(SYCL_LANGUAGE_VERSION)

template <typename... HDV>
void LaunchCPU(Functor func, HDV *...vectors) const {
omp_ulong end = static_cast<omp_ulong>(*(range_.end()));
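With the dispatch above in place, a single Transform call now covers CPU, CUDA, and SYCL. A sketch of typical usage follows (illustrative; the Init parameter order is an assumption to be checked against the header, includes are omitted, and XGBOOST_DEVICE expands to the CUDA host/device attributes only when compiled with nvcc):

// Scales every prediction by `factor` on whichever device `ctx` selects; the same
// functor runs through LaunchCUDA, LaunchSycl, or LaunchCPU depending on ctx->Device().
void ScalePredictions(xgboost::Context const* ctx,
                      xgboost::HostDeviceVector<float>* preds, float factor) {
  const size_t n = preds->Size();
  xgboost::common::Transform<>::Init(
      [=] XGBOOST_DEVICE(size_t idx, xgboost::common::Span<float> p) { p[idx] *= factor; },
      xgboost::common::Range{0, static_cast<xgboost::common::Range::DifferenceType>(n)},
      ctx->Threads(), ctx->Device())
      .Eval(preds);
}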
11 changes: 10 additions & 1 deletion src/objective/regression_obj.cu
@@ -87,6 +87,15 @@ class RegLossObj : public FitInterceptGlmLike {
common::AssertGPUSupport();
return false;
#endif // defined(XGBOOST_USE_CUDA)
},
[&] {
#if defined(XGBOOST_USE_SYCL)
return sycl::linalg::Validate(ctx_->Device(), label,
[](float y) -> bool { return Loss::CheckLabel(y); });
#else
common::AssertSYCLSupport();
return false;
#endif // defined(XGBOOST_USE_SYCL)
});
if (!valid) {
LOG(FATAL) << Loss::LabelErrorMsg();
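For orientation, the lambdas in this hunk are arguments to one device-dispatch call whose opening lines sit above the fold. A reconstructed sketch of its shape (the dispatcher spelling and the CPU/CUDA bodies are assumptions, not taken verbatim from this diff):

// Assumed shape of the enclosing call; only the SYCL lambda and the CUDA fallback
// branch are visible in the hunk above. The placeholder bodies simply return true.
bool valid = ctx_->DispatchDevice(
    [&] { return true; /* CPU path: validate labels on the host */ },
    [&] { return true; /* CUDA path, with the AssertGPUSupport fallback shown above */ },
    [&] { return true; /* SYCL path added here: the sycl::linalg::Validate call above */ });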
@@ -123,7 +132,7 @@ class RegLossObj : public FitInterceptGlmLike {
additional_input_.HostVector().begin()[1] = is_null_weight;

const size_t nthreads = ctx_->Threads();
bool on_device = device.IsCUDA();
bool on_device = !device.IsCPU();
// On CPU we run the transformation with each thread processing a contiguous block of data
// for better performance.
const size_t n_data_blocks = std::max(static_cast<size_t>(1), (on_device ? ndata : nthreads));
(1 more changed file; its diff was not loaded in this view.)