
Commit 94dff26
[Feature] Add the support of three_interpolate op for Ascend device (#…
lihao7212148 authored Oct 17, 2023
1 parent c0774b5 commit 94dff26
Showing 2 changed files with 59 additions and 11 deletions.
29 changes: 29 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp
@@ -0,0 +1,29 @@
#include "pytorch_npu_helper.hpp"

using namespace NPU_NAME_SPACE;
using namespace std;

void three_interpolate_forward_npu(int b, int c, int m, int n,
                                   const Tensor points, const Tensor idx,
                                   const Tensor weight, Tensor out) {
  // The NPU operator expects channel-last input, so transpose points
  // from (b, c, m) to (b, m, c) before dispatch.
  auto point_c_trans = points.transpose(1, 2);

  OpCommand cmd;
  cmd.Name("ThreeInterpolate")
      .Input(point_c_trans)
      .Input(idx)
      .Input(weight)
      .Output(out)
      .Run();

  // The operator writes its result in (b, n, c) layout; view and
  // transpose it back to the (b, c, n) layout expected by the caller,
  // then materialize a contiguous copy into the output tensor.
  auto output = out.view({b, n, c}).transpose(1, 2);
  auto res = NpuUtils::format_contiguous(output);
  out.copy_(res);
}

void three_interpolate_forward_impl(int b, int c, int m, int n,
                                    const Tensor points, const Tensor idx,
                                    const Tensor weight, Tensor out);

REGISTER_NPU_IMPL(three_interpolate_forward_impl,
                  three_interpolate_forward_npu);
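
For readers unfamiliar with the op: three_interpolate computes, for each of n target points, a weighted sum of the features of its three nearest known points. Below is a minimal pure-PyTorch sketch of these semantics, not part of the commit, assuming the usual mmcv layout of points (B, C, M), idx (B, N, 3), and weight (B, N, 3); the function name is illustrative only.

import torch


def three_interpolate_reference(points: torch.Tensor, idx: torch.Tensor,
                                weight: torch.Tensor) -> torch.Tensor:
    """Unoptimized reference sketch:
    out[b, c, n] = sum_k weight[b, n, k] * points[b, c, idx[b, n, k]].
    """
    b, c, m = points.shape
    n = idx.shape[1]
    out = points.new_zeros((b, c, n))
    for bi in range(b):
        for ni in range(n):
            for k in range(3):  # three nearest neighbors per target point
                out[bi, :, ni] += (
                    weight[bi, ni, k] * points[bi, :, idx[bi, ni, k]])
    return out

The device kernels (CUDA and now NPU) implement the same contract in parallel; the test below checks both backends against one set of expected values.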
41 changes: 30 additions & 11 deletions tests/test_ops/test_three_interpolate.py
@@ -3,12 +3,28 @@
 import torch

 from mmcv.ops import three_interpolate
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE


-@pytest.mark.skipif(
-    not torch.cuda.is_available(), reason='requires CUDA support')
-@pytest.mark.parametrize('dtype', [torch.half, torch.float, torch.double])
-def test_three_interpolate(dtype):
+@pytest.mark.parametrize('dtype', [
+    torch.half, torch.float,
+    pytest.param(
+        torch.double,
+        marks=pytest.mark.skipif(
+            IS_NPU_AVAILABLE,
+            reason='NPU does not support 64-bit floating point'))
+])
+@pytest.mark.parametrize('device', [
+    pytest.param(
+        'cuda',
+        marks=pytest.mark.skipif(
+            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+    pytest.param(
+        'npu',
+        marks=pytest.mark.skipif(
+            not IS_NPU_AVAILABLE, reason='requires NPU support'))
+])
+def test_three_interpolate(dtype, device):
     features = torch.tensor(
         [[[2.4350, 4.7516, 4.4995, 2.4350, 2.4350, 2.4350],
          [3.1236, 2.6278, 3.0447, 3.1236, 3.1236, 3.1236],
@@ -20,12 +36,13 @@ def test_three_interpolate(dtype):
          [0.0000, 0.2744, 2.0842, 0.0000, 0.0000, 0.0000],
          [0.3414, 1.5063, 1.6209, 0.3414, 0.3414, 0.3414],
          [0.5814, 0.0103, 0.0000, 0.5814, 0.5814, 0.5814]]],
-        dtype=dtype).cuda()
+        dtype=dtype,
+        device=device)

-    idx = torch.tensor([[[0, 1, 2], [2, 3, 4], [2, 3, 4], [0, 1, 2], [0, 1, 2],
-                        [0, 1, 3]],
-                       [[0, 2, 3], [1, 3, 4], [2, 1, 4], [0, 2, 4], [0, 2, 4],
-                        [0, 1, 2]]]).int().cuda()
+    idx = torch.tensor(
+        [[[0, 1, 2], [2, 3, 4], [2, 3, 4], [0, 1, 2], [0, 1, 2], [0, 1, 3]],
+         [[0, 2, 3], [1, 3, 4], [2, 1, 4], [0, 2, 4], [0, 2, 4], [0, 1, 2]]],
+        device=device).int()

     weight = torch.tensor([[[3.3333e-01, 3.3333e-01, 3.3333e-01],
                             [1.0000e+00, 5.8155e-08, 2.2373e-08],
@@ -39,7 +56,8 @@ def test_three_interpolate(dtype):
                             [3.3333e-01, 3.3333e-01, 3.3333e-01],
                             [3.3333e-01, 3.3333e-01, 3.3333e-01],
                             [3.3333e-01, 3.3333e-01, 3.3333e-01]]],
-                           dtype=dtype).cuda()
+                           dtype=dtype,
+                           device=device)

     output = three_interpolate(features, idx, weight)
     expected_output = torch.tensor([[[
@@ -73,6 +91,7 @@ def test_three_interpolate(dtype):
         3.8760e-01, 1.0300e-02, 8.3569e-09,
         3.8760e-01, 3.8760e-01, 1.9723e-01
     ]]],
-                                   dtype=dtype).cuda()
+                                   dtype=dtype,
+                                   device=device)

     assert torch.allclose(output, expected_output, 1e-3, 1e-4)
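
As a usage illustration (not part of the commit): assuming torch_npu is installed so that the 'npu' device string is available, the op is called on Ascend exactly as on CUDA after this change. The shapes below are hypothetical, chosen to match the (B, C, M) / (B, N, 3) layout used in the test.

import torch
import torch_npu  # noqa: F401  (assumption: provides the 'npu' device)

from mmcv.ops import three_interpolate

features = torch.rand(1, 4, 5, device='npu')               # (B, C, M)
idx = torch.randint(0, 5, (1, 6, 3), device='npu').int()   # (B, N, 3)
weight = torch.rand(1, 6, 3, device='npu')
weight = weight / weight.sum(-1, keepdim=True)             # weights sum to 1
out = three_interpolate(features, idx, weight)             # (B, C, N)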
