Skip to content

Commit

Permalink
replace pytorch function with cuda ops
Browse files Browse the repository at this point in the history
  • Loading branch information
kentechx committed Jul 24, 2023
1 parent 8514896 commit 26256f8
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 70 deletions.
98 changes: 98 additions & 0 deletions pointnext/ops.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# This is adapted from https://github.com/guochengqian/openpoints/blob/2bc0bf9cb2aee0fcd61f6cdc3abca1207e5e809e/models/layers/subsample.py
from typing import Tuple

import torch
from torch.autograd import Function
from pointnext import _C
Expand Down Expand Up @@ -62,3 +64,99 @@ def backward(ctx, a=None):


ball_query = BallQuery.apply


class ThreeNN(Function):

@staticmethod
def forward(ctx, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Find the three nearest neighbors of unknown in known
:param ctx:
:param unknown: (B, N, 3)
:param known: (B, M, 3)
:return:
dist: (B, N, 3) l2 distance to the three nearest neighbors
idx: (B, N, 3) index of 3 nearest neighbors
"""
assert unknown.is_contiguous()
assert known.is_contiguous()

B, N, _ = unknown.size()
m = known.size(1)
dist2 = torch.cuda.FloatTensor(B, N, 3)
idx = torch.cuda.IntTensor(B, N, 3)

_C.three_nn_wrapper(B, N, m, unknown, known, dist2, idx)
return torch.sqrt(dist2), idx

@staticmethod
def backward(ctx, a=None, b=None):
return None, None


three_nn = ThreeNN.apply


class ThreeInterpolate(Function):

@staticmethod
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
def forward(ctx, features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
"""
Performs weight linear interpolation on 3 features
:param ctx:
:param features: (B, C, M) Features descriptors to be interpolated from
:param idx: (B, n, 3) three nearest neighbors of the target features in features
:param weight: (B, n, 3) weights
:return:
output: (B, C, N) tensor of the interpolated features
"""
assert features.is_contiguous()
assert idx.is_contiguous()
assert weight.is_contiguous()

B, c, m = features.size()
n = idx.size(1)
ctx.three_interpolate_for_backward = (idx, weight, m)
output = torch.cuda.FloatTensor(B, c, n)

_C.three_interpolate_wrapper(B, c, m, n, features, idx, weight, output)
return output

@staticmethod
def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
:param ctx:
:param grad_out: (B, C, N) tensor with gradients of outputs
:return:
grad_features: (B, C, M) tensor with gradients of features
None:
None:
"""
idx, weight, m = ctx.three_interpolate_for_backward
B, c, n = grad_out.size()

grad_features = torch.zeros([B, c, m], device='cuda', requires_grad=True)
grad_out_data = grad_out.data.contiguous()

_C.three_interpolate_grad_wrapper(B, c, n, m, grad_out_data, idx, weight, grad_features.data)
return grad_features, None, None


three_interpolate = ThreeInterpolate.apply


def three_interpolation(known_xyz, know_feat, unknown_xyz):
"""
:param known_xyz: (b, m, 3)
:param know_feat: (b, c, m)
:param unknown_xyz: (b, n, 3)
output: (b, n, c)
"""
dist, idx = three_nn(unknown_xyz, known_xyz)
dist_recip = 1.0 / (dist + 1e-8)
norm = torch.sum(dist_recip, dim=2, keepdim=True)
weight = dist_recip / norm
interpolated_feats = three_interpolate(know_feat, idx, weight)
return interpolated_feats
41 changes: 14 additions & 27 deletions pointnext/pointnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import torch.nn.functional as F

from einops import repeat, rearrange
from .utils import farthest_point_sampling, ball_query_pytorch
# from .utils import farthest_point_sampling, ball_query_pytorch
from .ops import ball_query, furthest_point_sample, three_interpolation

__TAICHI__ = False
__KEOPS__ = True
Expand Down Expand Up @@ -43,21 +44,19 @@ def downsample_fps(xyz, n_sample):
sample_idx = torch.arange(n_sample, device=xyz.device)
sample_idx = repeat(sample_idx, 'n -> b n', b=xyz.shape[0])
return SampleResult(None, xyz.clone(), sample_idx, None)
_xyz = rearrange(xyz, 'b d n -> b n d')
sample_idx = farthest_point_sampling(_xyz, n_sample, start_idx=0) # (b, k)
_xyz = rearrange(xyz, 'b d n -> b n d').contiguous()
sample_idx = furthest_point_sample(_xyz, n_sample).long() # (b, k)
sample_xyz = xyz.gather(-1, repeat(sample_idx, 'b k -> b d k', d=xyz.shape[1])) # (b, 3, k)
return SampleResult(None, sample_xyz, sample_idx, None)


def _ball_query(src, query, radius, k):
# conduct ball query on dim 1
src = rearrange(src, 'b d n -> b n d')
query = rearrange(query, 'b d m -> b m d')
if __TAICHI__:
from .taichi import ball_query
return ball_query(src, query, radius, k)
else:
return ball_query_pytorch(src, query, radius, k)
src = rearrange(src, 'b d n -> b n d').contiguous()
query = rearrange(query, 'b d m -> b m d').contiguous()
idx = ball_query(src, query, radius, k).long()
dists = None
return idx, dists


def cdist(x, y=None):
Expand All @@ -73,13 +72,6 @@ def cdist(x, y=None):
return torch.cdist(x, x)


def knn(src, query, k):
dists = cdist(query, src) # (b, m, n)
idx = dists.topk(k, dim=-1, largest=False, sorted=False)[1] # (b, m, k)
dists = dists.gather(-1, idx) # (b, m, k)
return idx, dists


def gather(x, idx):
# x: (b, d, n)
# idx: (b, m, k)
Expand Down Expand Up @@ -178,6 +170,7 @@ class UpBlock(nn.Module):
def __init__(self, in_dim, out_dim, k=3, eps=1e-5):
super().__init__()
self.k = k
assert k == 3, "only support k=3"
self.eps = eps
dims = [in_dim, out_dim, out_dim]
self.conv = nn.Sequential(*[
Expand All @@ -189,16 +182,10 @@ def __init__(self, in_dim, out_dim, k=3, eps=1e-5):

def route(self, src_x, src_xyz, dst_x, dst_xyz, neighbor_idx=None, dists=None):
# use knn and weighted average to get the features
if not exists(neighbor_idx):
neighbor_idx, dists = knn(src_xyz, dst_xyz, self.k) # (b, m, k)

weights = 1. / (dists + self.eps) # (b, m, k)
weights = weights / weights.sum(dim=-1, keepdim=True) # (b, m, k)

neighbor_x = gather(src_x, neighbor_idx) # (b, d, m, k)
neighbor_x = (weights[:, None] * neighbor_x).sum(dim=-1) # (b, d, m)

dst_x = torch.cat([dst_x, neighbor_x], dim=1) # (b, d+d', m)
src_xyz = rearrange(src_xyz, 'b d n -> b n d').contiguous()
dst_xyz = rearrange(dst_x, 'b d m -> b m d').contiguous()
lerp_x = three_interpolation(src_xyz, src_x, dst_xyz)
dst_x = torch.cat([dst_x, lerp_x], dim=1) # (b, d+d', m)
return dst_x

def forward(self, x, xyz, sub_x, sub_xyz):
Expand Down
42 changes: 0 additions & 42 deletions pointnext/taichi.py

This file was deleted.

2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
setup(
name='pointnext',
packages=find_packages(exclude=("csrc")),
version='0.0.3',
version='0.0.4',
license='MIT',
description='PointNext - Pytorch',
author='Kaidi Shen',
Expand Down

0 comments on commit 26256f8

Please sign in to comment.