[Fix] Fix the support for nms_rotated in Ascend (#2931)
momo609 authored Sep 19, 2023
1 parent b361a81 commit ca99624
Showing 4 changed files with 62 additions and 9 deletions.
15 changes: 10 additions & 5 deletions mmcv/ops/csrc/common/pytorch_npu_helper.hpp
@@ -27,16 +27,21 @@
 
 #define NPU_NAME_SPACE at_npu::native
 
-#if MMCV_WITH_XLA
+#ifdef MMCV_WITH_XLA
 #define REGISTER_NPU_IMPL(key, value) REGISTER_DEVICE_IMPL(key, XLA, value)
 #else
 #define REGISTER_NPU_IMPL(key, value) \
   REGISTER_DEVICE_IMPL(key, PrivateUse1, value)
 #endif
 
-#define CHECK_NPU(x) \
-  TORCH_CHECK( \
-      x.device().type() == at::kXLA || x.device().type() == at::kPrivateUse1, \
-      #x " must be a NPU tensor")
+#ifdef MMCV_WITH_XLA
+#define CHECK_NPU(x) \
+  TORCH_CHECK(x.device().type() == at::kXLA, #x " must be a NPU tensor")
+#else
+#define CHECK_NPU(x) \
+  TORCH_CHECK(x.device().type() == at::kPrivateUse1, #x \
+              " must be a NPU " \
+              "tensor")
+
+#endif
 #endif  // PYTORCH_NPU_HELPER_HPP_
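
For context: with these macros the dispatch key for Ascend is chosen at compile time (XLA when setup.py defines MMCV_WITH_XLA for torch <= 2.0.0, PrivateUse1 otherwise), so individual op files do not need to change. A minimal usage sketch, not part of this commit and with hypothetical op names (my_op_impl / my_op_npu), mirroring the pattern the NPU op files in this repo already follow:

// Hypothetical sketch: how an NPU op file uses the macros defined above.
// "my_op_impl" / "my_op_npu" are placeholder names, not real mmcv ops.
#include "pytorch_npu_helper.hpp"

// Dispatch stub, declared by the generic op code elsewhere in mmcv.
void my_op_impl(at::Tensor input, at::Tensor output);

void my_op_npu(at::Tensor input, at::Tensor output) {
  // Expands to a TORCH_CHECK on at::kXLA or at::kPrivateUse1,
  // depending on whether MMCV_WITH_XLA was defined at build time.
  CHECK_NPU(input);
  CHECK_NPU(output);
  // ... build and run the Ascend kernel via OpCommand ...
}

// Registers the NPU kernel under the XLA or PrivateUse1 device key.
REGISTER_NPU_IMPL(my_op_impl, my_op_npu);
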
8 changes: 5 additions & 3 deletions mmcv/ops/csrc/pytorch/nms_rotated.cpp
@@ -36,11 +36,13 @@ Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order,
 #else
     AT_ERROR("Not compiled with GPU support");
 #endif
+#ifdef MMCV_WITH_XLA
   } else if (dets.device().type() == at::kXLA) {
-#ifdef MMCV_WITH_NPU
     return nms_rotated_npu(dets, scores, labels, iou_threshold);
-#else
-    AT_ERROR("Not compiled with NPU support");
 #endif
+#ifdef MMCV_WITH_KPRIVATE
+  } else if (dets.device().type() == at::kPrivateUse1) {
+    return nms_rotated_npu(dets, scores, labels, iou_threshold);
+#endif
 #ifdef MMCV_WITH_MLU
   } else if (dets.device().type() == at::kMLU) {
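
With the branch now guarded by MMCV_WITH_XLA / MMCV_WITH_KPRIVATE, only one NPU path is compiled in, selected by the torch version detected in setup.py (see the last file below). A simplified dispatch sketch, not the actual mmcv function (the real dispatcher has more parameters; the nms_rotated_npu signature is inferred from the calls in this diff):

// Simplified sketch (not the mmcv dispatcher): which NPU branch of
// nms_rotated gets compiled, depending on the macros set by setup.py.
#include <ATen/ATen.h>

// Declared elsewhere in mmcv; signature inferred from the calls above.
at::Tensor nms_rotated_npu(at::Tensor dets, at::Tensor scores,
                           at::Tensor labels, float iou_threshold);

at::Tensor nms_rotated_npu_dispatch(const at::Tensor& dets,
                                    const at::Tensor& scores,
                                    const at::Tensor& labels,
                                    float iou_threshold) {
#ifdef MMCV_WITH_XLA
  // torch <= 2.0.0: NPU tensors are exposed under the XLA device type.
  if (dets.device().type() == at::kXLA) {
    return nms_rotated_npu(dets, scores, labels, iou_threshold);
  }
#endif
#ifdef MMCV_WITH_KPRIVATE
  // torch > 2.0.0: NPU tensors use the PrivateUse1 device type.
  if (dets.device().type() == at::kPrivateUse1) {
    return nms_rotated_npu(dets, scores, labels, iou_threshold);
  }
#endif
  AT_ERROR("nms_rotated: tensor is not on an NPU device");
}
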
44 changes: 44 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/gather_points_npu.cpp
@@ -21,9 +21,53 @@ void gather_points_forward_npu(int b, int c, int n, int npoints,
       .Attr("batch_dims", batch_dims)
       .Run();
 }
+void gather_points_backward_npu(int b, int c, int n, int npoints,
+                                const Tensor grad_out, const Tensor idx,
+                                Tensor grad_points) {
+  at::Tensor indices = idx;
+  if (idx.scalar_type() != at::ScalarType::Int) {
+    indices = idx.to(at::kInt);
+  }
+  if (idx.dim() == 0) {
+    indices.unsqueeze_(0);
+  }
+  int64_t dim = 0;
+  at::SmallVector<int64_t, N> pad_size = array_to_small_vector(idx.sizes());
+  at::Tensor trans_grad_points = grad_points.transpose(1, 2).contiguous();
+  at::Tensor grad_points_view = trans_grad_points.view(
+      {trans_grad_points.sizes()[0] * trans_grad_points.sizes()[1],
+       trans_grad_points.sizes()[2]});
+  at::Tensor trans_grad_out = grad_out.transpose(1, 2).contiguous();
+  trans_grad_out = trans_grad_out.view(
+      {trans_grad_out.sizes()[0] * trans_grad_out.sizes()[1],
+       trans_grad_out.sizes()[2]});
+  auto index = at::arange(0, b);
+  index = index.to(grad_out.device());
+  index = at::mul(index, n);
+  index = index.view({b, 1});
+  index = at::broadcast_to(index, pad_size);
+  indices = at::add(index, indices);
+  indices = indices.view({-1});
+  OpCommand cmd;
+  cmd.Name("InplaceIndexAdd")
+      .Input(grad_points_view)
+      .Input(indices)
+      .Input(trans_grad_out)
+      .Output(grad_points_view)
+      .Attr("axis", dim)
+      .Run();
+  at::Tensor grad_points_result =
+      grad_points_view.view(trans_grad_points.sizes());
+  grad_points_result = grad_points_result.transpose(1, 2);
+  grad_points.copy_(grad_points_result);
+}
 
 void gather_points_forward_impl(int b, int c, int n, int npoints,
                                 const Tensor points, const Tensor idx,
                                 Tensor out);
+void gather_points_backward_impl(int b, int c, int n, int npoints,
+                                 const Tensor grad_out, const Tensor idx,
+                                 Tensor grad_points);
 
 REGISTER_NPU_IMPL(gather_points_forward_impl, gather_points_forward_npu);
+REGISTER_NPU_IMPL(gather_points_backward_impl, gather_points_backward_npu);
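
The new backward offsets each batch's indices by batch_index * n so that a single InplaceIndexAdd over a flattened 2-D view scatters every batch at once. A rough ATen-only reference sketch of the same computation, not part of the commit (shape assumptions: idx is (b, npoints), grad_out is (b, c, npoints), grad_points is (b, c, n)):

// Reference sketch of gather_points backward using plain ATen ops.
// Gradients are accumulated into grad_points, mirroring InplaceIndexAdd.
#include <ATen/ATen.h>

void gather_points_backward_ref(int b, int c, int n, int npoints,
                                const at::Tensor& grad_out,
                                const at::Tensor& idx,
                                at::Tensor& grad_points) {
  // Offset each batch so one flat index space covers all batches:
  // flat_idx[i, j] = i * n + idx[i, j].
  auto offsets =
      at::arange(b, idx.options().dtype(at::kLong)).view({b, 1}) * n;
  auto flat_idx = (idx.to(at::kLong) + offsets).view({-1});

  // Collapse (batch, point) into rows of length c, mirroring the
  // transpose + view done before InplaceIndexAdd in the NPU kernel.
  auto grad_points_2d =
      grad_points.transpose(1, 2).contiguous().view({b * n, c});
  auto grad_out_2d =
      grad_out.transpose(1, 2).contiguous().view({b * npoints, c});

  // Row-wise scatter-add, then restore the original (b, c, n) layout.
  grad_points_2d.index_add_(0, flat_idx, grad_out_2d);
  grad_points.copy_(grad_points_2d.view({b, n, c}).transpose(1, 2));
}
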
4 changes: 3 additions & 1 deletion setup.py
@@ -396,8 +396,10 @@ def get_mluops_version(file_path):
             from torch_npu.utils.cpp_extension import NpuExtension
             define_macros += [('MMCV_WITH_NPU', None)]
             extension = NpuExtension
-            if parse_version(torch.__version__) >= parse_version('2.0.0'):
+            if parse_version(torch.__version__) <= parse_version('2.0.0'):
                 define_macros += [('MMCV_WITH_XLA', None)]
+            if parse_version(torch.__version__) > parse_version('2.0.0'):
+                define_macros += [('MMCV_WITH_KPRIVATE', None)]
         except Exception:
             raise ImportError('can not find any torch_npu')
         # src