From 26256f88edba1686df48278a0be0877ac03c8517 Mon Sep 17 00:00:00 2001
From: Kentechx
Date: Tue, 25 Jul 2023 04:26:23 +0800
Subject: [PATCH] replace pytorch function with cuda ops

---
Review notes (this section is ignored by `git am`):
 * UpBlock.route: dst_xyz is now rearranged from dst_xyz; it was built from
   the dst_x feature tensor, so three_nn received features, not coordinates.
 * ops.py: kernel outputs and gradients are allocated on the device of the
   input tensors instead of the current CUDA device / a hard-coded 'cuda'.
 * ops.py: three_interpolation docstring now states the real (b, c, n) shape.
 * Line breaks and indentation were reconstructed from a whitespace-collapsed
   copy of this patch; the `index` blob ids are those of the original patch.

 pointnext/ops.py       | 98 ++++++++++++++++++++++++++++++++++++++++++
 pointnext/pointnext.py | 41 ++++++------------
 pointnext/taichi.py    | 42 ------------------
 setup.py               |  2 +-
 4 files changed, 113 insertions(+), 70 deletions(-)
 delete mode 100644 pointnext/taichi.py

diff --git a/pointnext/ops.py b/pointnext/ops.py
index 25fa92c..fda10c4 100644
--- a/pointnext/ops.py
+++ b/pointnext/ops.py
@@ -1,4 +1,6 @@
 # This is adapted from https://github.com/guochengqian/openpoints/blob/2bc0bf9cb2aee0fcd61f6cdc3abca1207e5e809e/models/layers/subsample.py
+from typing import Tuple
+
 import torch
 from torch.autograd import Function
 from pointnext import _C
@@ -62,3 +64,99 @@ def backward(ctx, a=None):
 
 
 ball_query = BallQuery.apply
+
+
+class ThreeNN(Function):
+
+    @staticmethod
+    def forward(ctx, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Find the three nearest neighbors of unknown in known
+        :param ctx:
+        :param unknown: (B, N, 3)
+        :param known: (B, M, 3)
+        :return:
+            dist: (B, N, 3) l2 distance to the three nearest neighbors
+            idx: (B, N, 3) index of 3 nearest neighbors
+        """
+        assert unknown.is_contiguous()
+        assert known.is_contiguous()
+
+        B, N, _ = unknown.size()
+        m = known.size(1)
+        dist2 = torch.empty(B, N, 3, dtype=torch.float32, device=unknown.device)
+        idx = torch.empty(B, N, 3, dtype=torch.int32, device=unknown.device)
+
+        _C.three_nn_wrapper(B, N, m, unknown, known, dist2, idx)
+        return torch.sqrt(dist2), idx
+
+    @staticmethod
+    def backward(ctx, a=None, b=None):
+        return None, None
+
+
+three_nn = ThreeNN.apply
+
+
+class ThreeInterpolate(Function):
+
+    @staticmethod
+    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
+    def forward(ctx, features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
+        """
+        Performs weight linear interpolation on 3 features
+        :param ctx:
+        :param features: (B, C, M) Features descriptors to be interpolated from
+        :param idx: (B, n, 3) three nearest neighbors of the target features in features
+        :param weight: (B, n, 3) weights
+        :return:
+            output: (B, C, N) tensor of the interpolated features
+        """
+        assert features.is_contiguous()
+        assert idx.is_contiguous()
+        assert weight.is_contiguous()
+
+        B, c, m = features.size()
+        n = idx.size(1)
+        ctx.three_interpolate_for_backward = (idx, weight, m)
+        output = torch.empty(B, c, n, dtype=torch.float32, device=features.device)
+
+        _C.three_interpolate_wrapper(B, c, m, n, features, idx, weight, output)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        :param ctx:
+        :param grad_out: (B, C, N) tensor with gradients of outputs
+        :return:
+            grad_features: (B, C, M) tensor with gradients of features
+            None:
+            None:
+        """
+        idx, weight, m = ctx.three_interpolate_for_backward
+        B, c, n = grad_out.size()
+
+        grad_features = torch.zeros([B, c, m], device=grad_out.device, requires_grad=True)
+        grad_out_data = grad_out.data.contiguous()
+
+        _C.three_interpolate_grad_wrapper(B, c, n, m, grad_out_data, idx, weight, grad_features.data)
+        return grad_features, None, None
+
+
+three_interpolate = ThreeInterpolate.apply
+
+
+def three_interpolation(known_xyz, know_feat, unknown_xyz):
+    """
+    :param known_xyz: (b, m, 3)
+    :param know_feat: (b, c, m)
+    :param unknown_xyz: (b, n, 3)
+    :return: (b, c, n) features interpolated at unknown_xyz
+    """
+    dist, idx = three_nn(unknown_xyz, known_xyz)
+    dist_recip = 1.0 / (dist + 1e-8)
+    norm = torch.sum(dist_recip, dim=2, keepdim=True)
+    weight = dist_recip / norm
+    interpolated_feats = three_interpolate(know_feat, idx, weight)
+    return interpolated_feats
\ No newline at end of file
diff --git a/pointnext/pointnext.py b/pointnext/pointnext.py
index 7d8ecaf..12311a4 100644
--- a/pointnext/pointnext.py
+++ b/pointnext/pointnext.py
@@ -6,7 +6,8 @@
 import torch.nn.functional as F
 from einops import repeat, rearrange
 
-from .utils import farthest_point_sampling, ball_query_pytorch
+# from .utils import farthest_point_sampling, ball_query_pytorch
+from .ops import ball_query, furthest_point_sample, three_interpolation
 
 __TAICHI__ = False
 __KEOPS__ = True
@@ -43,21 +44,19 @@ def downsample_fps(xyz, n_sample):
         sample_idx = torch.arange(n_sample, device=xyz.device)
         sample_idx = repeat(sample_idx, 'n -> b n', b=xyz.shape[0])
         return SampleResult(None, xyz.clone(), sample_idx, None)
-    _xyz = rearrange(xyz, 'b d n -> b n d')
-    sample_idx = farthest_point_sampling(_xyz, n_sample, start_idx=0)  # (b, k)
+    _xyz = rearrange(xyz, 'b d n -> b n d').contiguous()
+    sample_idx = furthest_point_sample(_xyz, n_sample).long()  # (b, k)
     sample_xyz = xyz.gather(-1, repeat(sample_idx, 'b k -> b d k', d=xyz.shape[1]))  # (b, 3, k)
     return SampleResult(None, sample_xyz, sample_idx, None)
 
 
 def _ball_query(src, query, radius, k):
     # conduct ball query on dim 1
-    src = rearrange(src, 'b d n -> b n d')
-    query = rearrange(query, 'b d m -> b m d')
-    if __TAICHI__:
-        from .taichi import ball_query
-        return ball_query(src, query, radius, k)
-    else:
-        return ball_query_pytorch(src, query, radius, k)
+    src = rearrange(src, 'b d n -> b n d').contiguous()
+    query = rearrange(query, 'b d m -> b m d').contiguous()
+    idx = ball_query(src, query, radius, k).long()
+    dists = None
+    return idx, dists
 
 
 def cdist(x, y=None):
@@ -73,13 +72,6 @@ def cdist(x, y=None):
         return torch.cdist(x, x)
 
 
-def knn(src, query, k):
-    dists = cdist(query, src)  # (b, m, n)
-    idx = dists.topk(k, dim=-1, largest=False, sorted=False)[1]  # (b, m, k)
-    dists = dists.gather(-1, idx)  # (b, m, k)
-    return idx, dists
-
-
 def gather(x, idx):
     # x: (b, d, n)
     # idx: (b, m, k)
@@ -178,6 +170,7 @@ class UpBlock(nn.Module):
     def __init__(self, in_dim, out_dim, k=3, eps=1e-5):
         super().__init__()
         self.k = k
+        assert k == 3, "only support k=3"
         self.eps = eps
         dims = [in_dim, out_dim, out_dim]
         self.conv = nn.Sequential(*[
@@ -189,16 +182,10 @@ def __init__(self, in_dim, out_dim, k=3, eps=1e-5):
 
     def route(self, src_x, src_xyz, dst_x, dst_xyz, neighbor_idx=None, dists=None):
         # use knn and weighted average to get the features
-        if not exists(neighbor_idx):
-            neighbor_idx, dists = knn(src_xyz, dst_xyz, self.k)  # (b, m, k)
-
-        weights = 1. / (dists + self.eps)  # (b, m, k)
-        weights = weights / weights.sum(dim=-1, keepdim=True)  # (b, m, k)
-
-        neighbor_x = gather(src_x, neighbor_idx)  # (b, d, m, k)
-        neighbor_x = (weights[:, None] * neighbor_x).sum(dim=-1)  # (b, d, m)
-
-        dst_x = torch.cat([dst_x, neighbor_x], dim=1)  # (b, d+d', m)
+        src_xyz = rearrange(src_xyz, 'b d n -> b n d').contiguous()
+        dst_xyz = rearrange(dst_xyz, 'b d m -> b m d').contiguous()
+        lerp_x = three_interpolation(src_xyz, src_x, dst_xyz)
+        dst_x = torch.cat([dst_x, lerp_x], dim=1)  # (b, d+d', m)
         return dst_x
 
     def forward(self, x, xyz, sub_x, sub_xyz):
diff --git a/pointnext/taichi.py b/pointnext/taichi.py
deleted file mode 100644
index 3b55793..0000000
--- a/pointnext/taichi.py
+++ /dev/null
@@ -1,42 +0,0 @@
-import taichi as ti
-import torch
-
-
-@ti.kernel
-def _ball_query_kernel(
-        src: ti.types.ndarray(ndim=3),
-        query: ti.types.ndarray(ndim=3),
-        out: ti.types.ndarray(ndim=3),
-        dists: ti.types.ndarray(ndim=3),
-        radius: ti.float32,
-        K: ti.int32
-):
-    B, M, D = query.shape
-    N = src.shape[1]
-
-    for b, m in ti.ndrange(B, M):
-        query_pt = ti.math.vec3(query[b, m, 0], query[b, m, 1], query[b, m, 2])
-
-        count = 0
-        for i in range(N):
-            if count >= K:
-                break
-            src_pt = ti.math.vec3(src[b, i, 0], src[b, i, 1], src[b, i, 2])
-            dist = (query_pt - src_pt).norm()
-            if dist <= radius:
-                out[b, m, count] = i
-                dists[b, m, count] = dist
-                count += 1
-                if count == K:
-                    break
-
-
-def ball_query(src: torch.Tensor, query: torch.Tensor, radius, k):
-    assert src.shape[-1] == 3, "src shape should be (B, N, 3)"
-    idx = torch.full((*query.shape[:2], k), fill_value=-1, dtype=torch.long, device='cuda')
-    dists = torch.full((*query.shape[:2], k), fill_value=-1, dtype=src.dtype, device='cuda')
-    _ball_query_kernel(src.contiguous(), query.contiguous(), idx, dists, radius, k)
-    mask = idx >= 0
-    idx = torch.where(mask, idx, idx[:, :, [0]])
-    dists = torch.where(mask, dists, dists[:, :, [0]])
-    return idx, dists
diff --git a/setup.py b/setup.py
index bbe8914..59af67b 100644
--- a/setup.py
+++ b/setup.py
@@ -26,7 +26,7 @@
 setup(
     name='pointnext',
     packages=find_packages(exclude=("csrc")),
-    version='0.0.3',
+    version='0.0.4',
     license='MIT',
     description='PointNext - Pytorch',
     author='Kaidi Shen',