Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 8 additions & 16 deletions src/ATen/native/xpu/LossNLL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,9 @@ TORCH_IMPL_FUNC(nll_loss_forward_out_xpu)
int64_t ignore_index,
const Tensor& output,
const Tensor& total_weight) {
const Tensor& weight = weight_opt.getTensorRef();
xpu::nll_loss_forward_kernel(
self,
target,
((weight_opt.has_value() && (*weight_opt).defined())
? at::OptionalTensorRef(*weight_opt)
: at::OptionalTensorRef()),
reduction,
ignore_index,
output,
total_weight);
output, total_weight, self, target, weight, reduction, ignore_index);
}

TORCH_IMPL_FUNC(nll_loss_backward_out_xpu)
Expand All @@ -39,19 +32,18 @@ TORCH_IMPL_FUNC(nll_loss_backward_out_xpu)
int64_t ignore_index,
const Tensor& total_weight,
const Tensor& grad_input) {
const Tensor& weight = weight_opt.getTensorRef();
grad_input.zero_();
xpu::nll_loss_backward_kernel(
grad_input,
grad_output,
self,
target,
((weight_opt.has_value() && (*weight_opt).defined())
? at::OptionalTensorRef(*weight_opt)
: at::OptionalTensorRef()),
reduction,
ignore_index,
total_weight,
grad_input);
weight,
reduction,
ignore_index);
}

} // namespace native
} // namespace at
} // namespace at
19 changes: 19 additions & 0 deletions src/ATen/native/xpu/sycl/KernelUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,22 @@
i = _i_n_d_e_x)

#define XPU_KERNEL_LOOP(item, i, n) XPU_KERNEL_LOOP_TYPE(item, i, n, int)

// Use 1024 threads per block, which requires cuda sm_2x or above
constexpr int SYCL_NUM_THREADS = 1024;

// CUDA: number of blocks for threads.
inline int GET_GROUPS(
const int64_t N,
const int64_t max_threads_per_group = SYCL_NUM_THREADS) {
TORCH_INTERNAL_ASSERT(
N > 0, "XPU kernel launch blocks must be positive, but got N=", N);
constexpr int64_t max_int = std::numeric_limits<int>::max();

// Round up division for positive number that cannot cause integer overflow
auto group_num = (N - 1) / max_threads_per_group + 1;
TORCH_INTERNAL_ASSERT(
group_num <= max_int, "Can't schedule too many blocks on XPU device");

return static_cast<int>(group_num);
}
Loading
Loading