Skip to content


Browse files Browse the repository at this point in the history
  • Loading branch information
kentechx committed Jul 19, 2023
1 parent 9748350 commit 5c13cd1
Show file tree
Hide file tree
Showing 5 changed files with 421 additions and 0 deletions.
1 change: 1 addition & 0 deletions pointnext/
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .pointnext import PointNextEncoder, PointNextDecoder, PointNext
300 changes: 300 additions & 0 deletions pointnext/
Original file line number Diff line number Diff line change
@@ -0,0 +1,300 @@
from collections import namedtuple
from typing import Union
from pykeops.torch import LazyTensor
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import repeat, rearrange
from .utils import farthest_point_sampling, ball_query_pytorch

# Backend switches kept as module-level state.
# When true, `_ball_query` dispatches to the Taichi implementation instead of
# the pure-PyTorch one.
__TAICHI__ = False
# Cleared by `disable_keops()`; no reader of this flag is visible in the
# reviewed section of the file.
__KEOPS__ = True

def enable_taichi():
    """Route ball queries through the Taichi backend.

    Taichi is imported before the flag is flipped, so a missing installation
    raises ``ImportError`` and leaves ``__TAICHI__`` untouched.
    """
    import taichi  # noqa: F401 -- imported only to check that it is available
    globals()['__TAICHI__'] = True

def disable_keops():
    """Clear the module-level ``__KEOPS__`` flag."""
    globals()['__KEOPS__'] = False

def exists(val):
    """Return True for anything except ``None`` (falsy values such as 0 count as present)."""
    if val is None:
        return False
    return True

def default(*vals):
    """Return the first argument that is not ``None``.

    Falsy values such as ``0`` or ``""`` are accepted. When every argument is
    ``None`` (or none are given) the result is ``None``.
    """
    return next((candidate for candidate in vals if candidate is not None), None)

# Record passed between the sampling / grouping helpers. Producers fill the
# fields they compute and leave the others as None.
SampleResult = namedtuple(
    'SampleResult',
    ('x', 'xyz', 'sample_idx', 'neighbor_idx'),
)

def downsample_fps(xyz, n_sample):
    """Keep ``n_sample`` points per cloud using farthest point sampling.

    Args:
        xyz: point coordinates laid out as (b, 3, n).
        n_sample: number of points to keep.

    Returns:
        A ``SampleResult`` whose ``xyz`` is (b, 3, n_sample) and whose
        ``sample_idx`` is (b, n_sample); ``x`` and ``neighbor_idx`` are None.
    """
    batch, n_points = xyz.shape[0], xyz.shape[-1]

    if n_sample != n_points:
        points = rearrange(xyz, 'b d n -> b n d')
        # sampling always starts from point 0
        picked = farthest_point_sampling(points, n_sample, start_idx=0)  # (b, k)
        index = repeat(picked, 'b k -> b d k', d=xyz.shape[1])
        return SampleResult(None, xyz.gather(-1, index), picked, None)  # xyz: (b, 3, k)

    # Nothing to drop: hand back a copy of every point in its original order.
    everything = repeat(torch.arange(n_sample, device=xyz.device), 'n -> b n', b=batch)
    return SampleResult(None, xyz.clone(), everything, None)

def _ball_query(src, query, radius, k):
    """Run a ball query on channel-first inputs using the active backend.

    ``src`` (b, d, n) and ``query`` (b, d, m) are moved to points-first layout
    before being handed to the backend. The Taichi implementation is used when
    the module flag ``__TAICHI__`` is set, otherwise the PyTorch one. The
    backend's result is returned unchanged.
    """
    src_pts = rearrange(src, 'b d n -> b n d')
    query_pts = rearrange(query, 'b d m -> b m d')
    if not __TAICHI__:
        return ball_query_pytorch(src_pts, query_pts, radius, k)
    # Imported lazily so taichi is only required when this backend is enabled.
    from .taichi import ball_query
    return ball_query(src_pts, query_pts, radius, k)

def cdist(x, y=None):
    """Pairwise Euclidean distances between channel-first point sets.

    Args:
        x: (b, d, n) tensor.
        y: optional (b, d, m) tensor; when omitted, distances are computed
            between ``x`` and itself.

    Returns:
        (b, n, m) tensor of distances, or (b, n, n) when ``y`` is omitted.
    """
    # torch.cdist wants the point axis before the coordinate axis.
    pts_x = x.permute(0, 2, 1)
    pts_y = pts_x if y is None else y.permute(0, 2, 1)
    return torch.cdist(pts_x, pts_y)

def knn(src, query, k):
    """Find the ``k`` nearest ``src`` points for every ``query`` point.

    Args:
        src: (b, d, n) candidate points.
        query: (b, d, m) points to look up neighbours for.
        k: number of neighbours.

    Returns:
        ``(idx, dists)``, both (b, m, k). Neighbours are not ordered by
        distance (``sorted=False``), but ``dists[..., j]`` always belongs to
        ``idx[..., j]``.
    """
    pairwise = torch.cdist(query.permute(0, 2, 1), src.permute(0, 2, 1))  # (b, m, n)
    # topk already returns the distances that belong to the chosen indices
    nearest_dists, nearest_idx = pairwise.topk(k, dim=-1, largest=False, sorted=False)
    return nearest_idx, nearest_dists

def gather(x, idx):
    """Collect per-neighbourhood columns of ``x``.

    Args:
        x: (b, d, n) tensor.
        idx: (b, m, k) integer indices into the last axis of ``x``.

    Returns:
        (b, d, m, k) tensor with ``out[b, :, m, k] == x[b, :, idx[b, m, k]]``.
    """
    b, m, k = idx.shape
    d = x.shape[1]
    # Flatten the (m, k) grid and share the same indices across all d channels.
    flat_idx = idx.reshape(b, 1, m * k).expand(b, d, m * k)
    return x.gather(-1, flat_idx).reshape(b, d, m, k)

class SABlock(nn.Module):
    """
    Set abstraction block.

    Points are optionally downsampled with farthest point sampling
    (``stride > 1``), each kept point gathers up to ``k`` neighbours within
    ``radius``, the grouped features go through shared 1x1 convolutions and are
    max-pooled over the neighbourhood. With ``stride == 1`` no downsampling
    takes place.

    Args:
        in_dim: channels of the input features.
        out_dim: channels of the output features.
        stride: downsampling factor; ``n // stride`` points are kept.
        layers: number of 1x1 convolutions applied to the grouped features.
        radius: ball-query radius, also used to normalise relative offsets.
        k: maximum number of neighbours per point.
    """

    def __init__(self, in_dim, out_dim, stride=1, layers=1, radius=0.1, k=16):
        super().__init__()
        self.stride = stride
        self.radius = radius
        self.layers = layers
        self.k = k

        # +3: `route` concatenates the relative xyz offsets to the features.
        dims = [in_dim + 3] + [out_dim] * layers

        if layers == 1:
            self.convs = nn.Conv2d(dims[0], dims[1], 1, bias=False)
        else:
            # The residual path only exists for the multi-layer variant;
            # `forward` detects it with hasattr.
            self.skip_conv = nn.Conv1d(in_dim, out_dim, 1, bias=False) if in_dim != out_dim else nn.Identity()
            # NOTE(review): the BatchNorm2d/ReLU that follow each intermediate
            # conv were not visible in the reviewed source and are
            # reconstructed -- confirm against upstream.
            self.convs = nn.Sequential(*[
                nn.Sequential(nn.Conv2d(in_d, out_d, 1, bias=False),
                              nn.BatchNorm2d(out_d),
                              nn.ReLU())
                for in_d, out_d in zip(dims[:-2], dims[1:-1])
            ])
            # The last conv has no norm/activation of its own: both are applied
            # after max-pooling (and after the skip connection) in `forward`.
            self.convs.append(nn.Conv2d(dims[-2], dims[-1], 1, bias=False))
        self.norm = nn.BatchNorm1d(out_dim)
        self.act = nn.ReLU()

    def route(self, src_x, src_xyz, xyz, radius, k, neighbor_idx=None):
        """Group source features around every query point.

        Args:
            src_x: (b, d, n) source features.
            src_xyz: (b, 3, n) source coordinates.
            xyz: (b, 3, m) query coordinates.
            radius: ball-query radius.
            k: neighbours per query point.
            neighbor_idx: optional precomputed (b, m, k) neighbour indices;
                the ball query is skipped when given.

        Returns:
            ``SampleResult`` with ``x`` (b, d+3, m, k), the query ``xyz`` and
            ``neighbor_idx``; ``sample_idx`` is None.
        """
        if not exists(neighbor_idx):
            neighbor_idx = _ball_query(src_xyz, xyz, radius, k)[0]  # (b, m, k)
        neighbor_xyz = gather(src_xyz, neighbor_idx)  # (b, 3, m, k)
        # Express neighbours relative to their centre, in units of the radius.
        neighbor_xyz -= xyz[..., None]
        neighbor_xyz /= radius
        x = gather(src_x, neighbor_idx)  # (b, d, m, k)
        x = torch.cat([x, neighbor_xyz], dim=1)  # (b, d+3, m, k)
        return SampleResult(x, xyz, None, neighbor_idx)

    def forward(self, x, xyz):
        """Apply the block.

        Args:
            x: (b, d, n) features.
            xyz: (b, 3, n) coordinates.

        Returns:
            ``SampleResult`` with ``x`` (b, d', n // stride), the sampled
            ``xyz``, the FPS ``sample_idx`` (b, n // stride) and the
            ``neighbor_idx`` used for grouping.
        """
        fps = downsample_fps(xyz, n_sample=xyz.shape[-1] // self.stride)
        # Kept separate from `fps`: `route` returns sample_idx=None, and reusing
        # one variable made the block report sample_idx=None to its callers.
        grouped = self.route(x, xyz, fps.xyz, self.radius, self.k)
        out = self.convs(grouped.x)
        out = out.max(dim=-1)[0]  # pool over the k neighbours
        if hasattr(self, 'skip_conv'):
            # features of the sampled centres feed the residual branch
            inputs = x.gather(-1, repeat(fps.sample_idx, 'b k -> b d k', d=x.shape[1]))
            out = self.skip_conv(inputs) + out
        out = self.act(self.norm(out))
        return SampleResult(out, fps.xyz, fps.sample_idx, grouped.neighbor_idx)

class InvResMLP(nn.Module):
    """Inverted residual MLP block.

    A stride-1, single-layer ``SABlock`` mixes features spatially, a point-wise
    bottleneck expands the channels by ``expansion`` and projects them back,
    and the result is added to the block input before the final activation.
    The number of points and channels is unchanged.

    Args:
        in_dim: channels of the input (and output) features.
        expansion: width multiplier of the hidden point-wise layer.
        radius: ball-query radius of the inner ``SABlock``.
        k: neighbours per point of the inner ``SABlock``.
    """

    def __init__(self, in_dim, expansion=4, radius=0.1, k=16):
        super().__init__()
        self.sa_conv = SABlock(in_dim, in_dim, stride=1, layers=1, radius=radius, k=k)

        dims = [in_dim, in_dim * expansion, in_dim]
        # NOTE(review): only the two Conv1d layers are visible in the reviewed
        # source. The BatchNorm1d/ReLU between and after them are reconstructed
        # (the convs are bias-free) -- confirm against upstream.
        self.conv = nn.Sequential(
            nn.Conv1d(dims[0], dims[1], 1, bias=False),
            nn.BatchNorm1d(dims[1]),
            nn.ReLU(),
            nn.Conv1d(dims[1], dims[2], 1, bias=False),
            nn.BatchNorm1d(dims[2]),
        )
        self.act = nn.ReLU()

    def forward(self, x, xyz):
        """Return features of the same shape as ``x`` (b, d, n); ``xyz`` is (b, 3, n)."""
        inputs = x
        x = self.sa_conv(x, xyz).x
        x = self.conv(x)
        x = self.act(inputs + x)  # residual connection
        return x

class UpBlock(nn.Module):
    """Feature propagation block used by the decoder.

    Coarse features are interpolated onto the finer point set with
    inverse-distance weights over the ``k`` nearest coarse points, concatenated
    with the finer level's own features and passed through two point-wise
    convolution layers.

    Args:
        in_dim: channels after concatenation (fine channels + coarse channels).
        out_dim: channels of the output features.
        k: number of coarse neighbours used for interpolation.
        eps: added to the distances so that a zero distance does not divide
            by zero.
    """

    def __init__(self, in_dim, out_dim, k=3, eps=1e-5):
        super().__init__()
        self.k = k
        self.eps = eps
        dims = [in_dim, out_dim, out_dim]
        # NOTE(review): the layers following each Conv1d were not visible in
        # the reviewed source; BatchNorm1d + ReLU are reconstructed -- confirm
        # against upstream.
        self.conv = nn.Sequential(*[
            nn.Sequential(nn.Conv1d(in_d, out_d, 1, bias=False),
                          nn.BatchNorm1d(out_d),
                          nn.ReLU())
            for in_d, out_d in zip(dims[:-1], dims[1:])
        ])

    def route(self, src_x, src_xyz, dst_x, dst_xyz, neighbor_idx=None, dists=None):
        """Interpolate ``src_x`` onto ``dst_xyz`` and concatenate with ``dst_x``.

        Args:
            src_x: (b, d', n) coarse features.
            src_xyz: (b, 3, n) coarse coordinates.
            dst_x: (b, d, m) fine features.
            dst_xyz: (b, 3, m) fine coordinates.
            neighbor_idx: optional precomputed (b, m, k) neighbour indices.
            dists: distances matching ``neighbor_idx``; must be supplied
                together with it, since kNN is only run when ``neighbor_idx``
                is missing.

        Returns:
            (b, d + d', m) concatenated features.
        """
        # use knn and weighted average to get the features
        if not exists(neighbor_idx):
            neighbor_idx, dists = knn(src_xyz, dst_xyz, self.k)  # (b, m, k)

        weights = 1. / (dists + self.eps)  # (b, m, k)
        weights = weights / weights.sum(dim=-1, keepdim=True)  # (b, m, k)

        neighbor_x = gather(src_x, neighbor_idx)  # (b, d', m, k)
        neighbor_x = (weights[:, None] * neighbor_x).sum(dim=-1)  # (b, d', m)

        dst_x = torch.cat([dst_x, neighbor_x], dim=1)  # (b, d+d', m)
        return dst_x

    def forward(self, x, xyz, sub_x, sub_xyz):
        """Upsample ``sub_x`` (coarse) onto ``xyz`` (fine) and fuse it with ``x``."""
        x = self.route(sub_x, sub_xyz, x, xyz)
        x = self.conv(x)
        return x

class PointNextEncoder(nn.Module):
    """Hierarchical point encoder.

    A point-wise stem lifts the input to ``dims[0]`` channels. Each stage then
    downsamples with an ``SABlock`` and refines with ``blocks[i] - 1``
    ``InvResMLP`` blocks. The ball-query radius doubles from stage to stage.

    NOTE(review): the reviewed source lost several lines of ``__init__``. The
    parameters ``in_dim``, ``radius``, ``k`` and ``sa_layers`` are used by the
    body but their declarations (position and defaults) were not visible; the
    values below are assumptions to be confirmed against upstream.

    Args:
        in_dim: channels of the input features.
        dims: channel widths; ``dims[0]`` is the stem output and ``dims[i + 1]``
            the output of stage ``i``.
        blocks: blocks per stage (one ``SABlock`` plus ``blocks[i] - 1``
            ``InvResMLP``).
        strides: downsampling factor of each stage.
        radius: ball-query radius of the first stage.
        k: neighbours per point.
        sa_layers: number of convolutions inside every ``SABlock``.
    """

    def __init__(
            self,
            in_dim=3,
            dims=[32, 64, 128, 256, 512],  # dims[0] is the dim of the stem output
            blocks=[4, 7, 4, 4],  # blocks: sa + invres
            strides=[4, 4, 4, 4],
            radius=0.1,
            k=32,
            sa_layers=1,
    ):
        # The list defaults are only read, never mutated.
        super().__init__()

        # NOTE(review): only the Conv1d of the stem is visible in the reviewed
        # source; BatchNorm1d + ReLU are reconstructed -- confirm.
        self.stem = nn.Sequential(
            nn.Conv1d(in_dim, dims[0], 1, bias=False),
            nn.BatchNorm1d(dims[0]),
            nn.ReLU(),
        )

        radius_scaling = 2
        radii = [radius * (radius_scaling ** i) for i in range(len(blocks))]
        self.encoder = nn.ModuleList()
        for i in range(len(blocks)):
            layers = nn.Sequential(
                SABlock(dims[i], dims[i + 1], stride=strides[i], layers=sa_layers, radius=radii[i], k=k),
                # after downsampling the points are sparser, so the refinement
                # blocks use the next (doubled) radius
                *[InvResMLP(dims[i + 1], radius=radii[i] * radius_scaling, k=k) for _ in range(blocks[i] - 1)]
            )
            self.encoder.append(layers)

        self.out_dim = dims[-1]

    def forward_features(self, x, xyz):
        """Return the ``(features, xyz)`` pair of the stem and of every stage.

        Args:
            x: (b, in_dim, n) input features.
            xyz: (b, 3, n) coordinates.

        Returns:
            List of ``len(blocks) + 1`` tuples ordered from finest (stem
            output) to coarsest.
        """
        x = self.stem(x)
        features = [(x, xyz)]
        for block in self.encoder:
            # block[0] is the SABlock: it changes the point set, so it is
            # called explicitly to pick up the new coordinates.
            sample = block[0](x, xyz)
            x, xyz = sample.x, sample.xyz
            for layer in block[1:]:
                x = layer(x, xyz)
            features.append((x, xyz))
        return features

    def forward(self, x, xyz):
        """Alias of :meth:`forward_features`."""
        return self.forward_features(x, xyz)

class PointNextDecoder(nn.Module):
    """Decoder that propagates coarse features back to the finest level.

    One ``UpBlock`` is created per pair of adjacent encoder levels, from the
    coarsest pair to the finest.

    Args:
        encoder_dims: channel widths of the encoder levels, finest first (the
            same list that was given to the encoder as ``dims``).
    """

    def __init__(self, encoder_dims=[32, 64, 128, 256, 512]):
        # The default list is only read; `[::-1]` below makes a reversed copy.
        super().__init__()
        self.decoder = nn.ModuleList()

        decoder_dims = encoder_dims[::-1]
        for i in range(len(decoder_dims) - 1):
            # input = coarser level's channels + skip channels of the finer level
            self.decoder.append(UpBlock(decoder_dims[i] + decoder_dims[i + 1], decoder_dims[i + 1]))

        self.out_dim = decoder_dims[-1]

    def forward(self, feats):
        """Fuse the encoder pyramid into features on the finest point set.

        Args:
            feats: list of ``(x, xyz)`` pairs ordered finest to coarsest, as
                returned by the encoder. The list is not modified.

        Returns:
            Output of the last up-block. If the decoder holds no blocks, the
            coarsest features are returned unchanged.
        """
        # Work on a shallow copy so the caller's list is not emptied by pop().
        remaining = list(feats)
        sub_x, sub_xyz = remaining.pop()
        x = sub_x
        for block in self.decoder:
            x, xyz = remaining.pop()
            x = block(x, xyz, sub_x, sub_xyz)
            sub_x, sub_xyz = x, xyz
        return x

class PointNext(nn.Module):
    """Encoder (+ optional decoder) with a point-wise prediction head.

    Args:
        out_dim: channels produced by the head.
        encoder: module called as ``encoder(x, xyz)`` that returns a list of
            ``(features, xyz)`` pairs and exposes ``out_dim``.
        decoder: optional module called with that list and exposing
            ``out_dim``. Without it the head runs on the encoder's last
            (coarsest) features.
        n_category: when positive, a learned per-category embedding is created
            and can be added to the features in ``forward``.
    """

    def __init__(self, out_dim, encoder: PointNextEncoder, decoder: PointNextDecoder = None, n_category=0):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        feat_dim = decoder.out_dim if exists(decoder) else encoder.out_dim

        if n_category > 0:
            self.category_emb = nn.Embedding(n_category, feat_dim)

        self.head = nn.Conv1d(feat_dim, out_dim, 1)

    def forward(self, x, xyz, category=None):
        """Run the network.

        Args:
            x: input features handed to the encoder.
            xyz: point coordinates handed to the encoder.
            category: optional category indices. Requires the model to have
                been built with ``n_category > 0``; otherwise the lookup of
                ``category_emb`` fails with ``AttributeError``.
                NOTE(review): presumably shape (b,) -- the embedding is
                broadcast over the point axis; confirm with callers.

        Returns:
            Head output with ``out_dim`` channels.
        """
        feats = self.encoder(x, xyz)
        if exists(self.decoder):
            out = self.decoder(feats)
        else:
            # no decoder: use the features of the last encoder level
            out = feats[-1][0]
        if exists(category):
            out = out + self.category_emb(category)[:, :, None]
        out = self.head(out)
        return out
42 changes: 42 additions & 0 deletions pointnext/
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import taichi as ti
import torch

@ti.kernel
def _ball_query_kernel(
        src: ti.types.ndarray(ndim=3),
        query: ti.types.ndarray(ndim=3),
        out: ti.types.ndarray(ndim=3),
        dists: ti.types.ndarray(ndim=3),
        radius: ti.float32,
        K: ti.int32
):
    # For every query point, write the indices (and distances) of the first K
    # source points, in index order, whose distance is at most `radius`.
    #   src:   (B, N, 3) source points
    #   query: (B, M, 3) query points
    #   out:   (B, M, K) neighbour indices; slots that are not filled keep
    #          whatever the caller initialised them with
    #   dists: (B, M, K) distances matching `out`
    # NOTE(review): the decorator, the closing of the signature and the two
    # `break` statements were missing from the reviewed source and are
    # reconstructed here -- confirm against upstream.
    B, M, D = query.shape
    N = src.shape[1]

    for b, m in ti.ndrange(B, M):
        query_pt = ti.math.vec3(query[b, m, 0], query[b, m, 1], query[b, m, 2])

        count = 0
        for i in range(N):
            # guards K == 0, where there is no output slot to write to
            if count >= K:
                break
            src_pt = ti.math.vec3(src[b, i, 0], src[b, i, 1], src[b, i, 2])
            dist = (query_pt - src_pt).norm()
            if dist <= radius:
                out[b, m, count] = i
                dists[b, m, count] = dist
                count += 1
                if count == K:
                    break

def ball_query(src: torch.Tensor, query: torch.Tensor, radius, k):
    """Ball query backed by the Taichi kernel.

    Args:
        src: (B, N, 3) source points.
        query: (B, M, 3) query points.
        radius: search radius.
        k: number of neighbour slots per query point.

    Returns:
        ``(idx, dists)``, both (B, M, k). When fewer than ``k`` neighbours are
        found, the remaining slots repeat the first neighbour. If a query point
        has no neighbour at all, its slots stay at -1.
    """
    assert src.shape[-1] == 3, "src shape should be (B, N, 3)"
    # Allocate the outputs next to the inputs instead of hard-coding 'cuda', so
    # non-default GPUs (and CPU tensors) do not end up on a different device.
    out_shape = (*query.shape[:2], k)
    idx = torch.full(out_shape, fill_value=-1, dtype=torch.long, device=src.device)
    dists = torch.full(out_shape, fill_value=-1, dtype=src.dtype, device=src.device)
    _ball_query_kernel(src.contiguous(), query.contiguous(), idx, dists, radius, k)
    # -1 marks slots the kernel did not fill; pad them with the first neighbour
    mask = idx >= 0
    idx = torch.where(mask, idx, idx[:, :, [0]])
    dists = torch.where(mask, dists, dists[:, :, [0]])
    return idx, dists
49 changes: 49 additions & 0 deletions pointnext/
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import torch
from einops import rearrange, repeat

def exists(val):
    """Tell whether ``val`` holds a value, i.e. is anything other than ``None``."""
    return not (val is None)

def farthest_point_sampling(x: torch.Tensor, n_sample: int, start_idx: int = None):
    """Iterative farthest point sampling.

    Args:
        x: (b, n, 3) points.
        n_sample: number of points to select per cloud; must not exceed ``n``.
        start_idx: index of the first selected point in every cloud. When
            ``None`` the first point is drawn at random.

    Returns:
        (b, n_sample) long tensor of selected indices. When ``n_sample == n``
        every index is returned in its original order.

    Raises:
        AssertionError: if ``n_sample > n``.
    """
    b, n = x.shape[:2]
    assert n_sample <= n, "not enough points to sample"

    if n_sample == n:
        return torch.arange(n_sample, dtype=torch.long, device=x.device).unsqueeze(0).repeat(b, 1)
    if n_sample == 0:
        # nothing to select; avoids indexing column 0 of an empty result below
        return torch.empty((b, 0), dtype=torch.long, device=x.device)

    # start index
    if start_idx is not None:
        sel_idx = torch.full((b, n_sample), start_idx, dtype=torch.long, device=x.device)
    else:
        # only column 0 matters; the other columns are overwritten in the loop
        sel_idx = torch.randint(n, (b, n_sample), dtype=torch.long, device=x.device)

    batch = torch.arange(b, device=x.device)  # built once, on the same device as x
    cur_x = x[batch, sel_idx[:, 0]].unsqueeze(1)  # (b, 1, 3)
    min_dists = torch.full((b, n), dtype=x.dtype, device=x.device, fill_value=float('inf'))
    for i in range(1, n_sample):
        # distance of every point to the closest point selected so far
        dists = torch.linalg.norm(x - cur_x, dim=-1)
        min_dists = torch.minimum(dists, min_dists)

        # take the farthest
        idx_farthest = torch.max(min_dists, dim=-1).indices
        sel_idx[:, i] = idx_farthest
        cur_x = x[batch, idx_farthest].unsqueeze(1)

    return sel_idx

def ball_query_pytorch(src, query, radius, k):
    """Ball query implemented with dense PyTorch ops.

    Args:
        src: (b, n, 3) source points.
        query: (b, m, 3) query points.
        radius: search radius (points at exactly ``radius`` are included).
        k: maximum number of neighbours per query point.

    Returns:
        ``(idx, dists)``, both (b, m, k): the ``k`` lowest-index source points
        within ``radius`` and their distances. Missing neighbours are padded
        with the first one. A query point with no source point inside the
        radius ends up with the out-of-range index ``n``, which makes the final
        ``gather`` fail.
    """
    n_src = src.shape[1]
    n_batch, n_query = query.shape[0], query.shape[1]
    pairwise = torch.cdist(query, src)  # (b, m, n)

    # Candidate index per (query, source) pair; the sentinel n marks "outside
    # the ball" and sorts after every valid index.
    candidates = torch.arange(n_src, device=src.device).expand(n_batch, n_query, n_src)
    candidates = candidates.masked_fill(pairwise > radius, n_src)
    nearest = torch.sort(candidates, dim=-1)[0][..., :k]  # (b, m, k)

    # Replace leftover sentinels with the first neighbour of the same query.
    first = nearest[..., :1]
    nearest = torch.where(nearest == n_src, first, nearest)
    return nearest, pairwise.gather(-1, nearest)  # both (b, m, k)

0 comments on commit 5c13cd1

Please sign in to comment.