# NOTE(review): SOURCE is a whitespace-mangled `git format-patch` email
# ("[PATCH] init", 5c13cd1, Kentechx, 2023-07-20) creating the `pointnext`
# package.  The code below is the reconstructed content; file boundaries are
# marked with `--- file: ... ---` comments.

# --- file: pointnext/__init__.py ---
from .pointnext import PointNextEncoder, PointNextDecoder, PointNext

# --- file: pointnext/pointnext.py ---
from collections import namedtuple

import torch
import torch.nn as nn
from einops import repeat, rearrange

from .utils import farthest_point_sampling, ball_query_pytorch

# NOTE(review): removed three imports that are never referenced anywhere in
# the patch: `typing.Union`, `torch.nn.functional as F`, and
# `pykeops.torch.LazyTensor` — the last one eagerly pulls in a heavy,
# optional compiled dependency at module import time.

# Global backend switches, toggled by the helpers below.
__TAICHI__ = False  # when True, ball queries run through the taichi CUDA kernel
__KEOPS__ = True    # reserved flag; keops is not actually used in this patch


def enable_taichi():
    """Enable the taichi backend for ball queries (initialises taichi on CUDA)."""
    import taichi as ti
    global __TAICHI__
    __TAICHI__ = True
    ti.init(ti.cuda)


def disable_keops():
    """Clear the (currently unused) keops flag."""
    global __KEOPS__
    __KEOPS__ = False


def exists(val):
    """Return True iff *val* is not None."""
    return val is not None


def default(*vals):
    """Return the first argument that is not None, or None when all are."""
    for val in vals:
        if exists(val):
            return val
    return None


# Result bundle shared by the sampling/grouping helpers; unused slots are None.
#   x:            grouped features, (b, d, m, k) or (b, d, m)
#   xyz:          sampled coordinates, (b, 3, m)
#   sample_idx:   indices of sampled points into the source cloud, (b, m)
#   neighbor_idx: per-sample neighbour indices, (b, m, k)
SampleResult = namedtuple('SampleResult', ['x', 'xyz', 'sample_idx', 'neighbor_idx'])


def downsample_fps(xyz, n_sample):
    """Farthest-point-sample `n_sample` points from `xyz` (b, 3, n).

    Returns a SampleResult with `xyz` (b, 3, n_sample) and `sample_idx`
    (b, n_sample); the `x` and `neighbor_idx` slots are None.
    """
    if n_sample == xyz.shape[-1]:
        # Fast path: keep every point, in order, without running FPS.
        sample_idx = torch.arange(n_sample, device=xyz.device)
        sample_idx = repeat(sample_idx, 'n -> b n', b=xyz.shape[0])
        return SampleResult(None, xyz.clone(), sample_idx, None)
    _xyz = rearrange(xyz, 'b d n -> b n d')
    sample_idx = farthest_point_sampling(_xyz, n_sample, start_idx=0)  # (b, k)
    sample_xyz = xyz.gather(-1, repeat(sample_idx, 'b k -> b d k', d=xyz.shape[1]))  # (b, 3, k)
    return SampleResult(None, sample_xyz, sample_idx, None)


def _ball_query(src, query, radius, k):
    """Ball query for channels-first clouds; dispatches on the taichi flag.

    src: (b, 3, n), query: (b, 3, m) -> (idx, dists), each (b, m, k).
    """
    src = rearrange(src, 'b d n -> b n d')
    query = rearrange(query, 'b d m -> b m d')
    if __TAICHI__:
        # Imported lazily so taichi is only required when explicitly enabled.
        from .taichi import ball_query
        return ball_query(src, query, radius, k)
    else:
        return ball_query_pytorch(src, query, radius, k)


def cdist(x, y=None):
    """Pairwise distances for channels-first tensors.

    x: (b, d, n), y: (b, d, m) or None (self-distances) -> (b, n, m)/(b, n, n).
    """
    if exists(y):
        x = rearrange(x, 'b d n -> b n d')
        y = rearrange(y, 'b d m -> b m d')
        return torch.cdist(x, y)
    else:
        x = rearrange(x, 'b d n -> b n d')
        return torch.cdist(x, x)


def knn(src, query, k):
    """k nearest source points for each query point.

    Returns (idx, dists), both (b, m, k); neighbours are NOT sorted by distance
    (`sorted=False` in topk).
    """
    dists = cdist(query, src)  # (b, m, n)
    idx = dists.topk(k, dim=-1, largest=False, sorted=False)[1]  # (b, m, k)
    dists = dists.gather(-1, idx)  # (b, m, k)
    return idx, dists


def gather(x, idx):
    """Gather per-neighbour features: x (b, d, n) at idx (b, m, k) -> (b, d, m, k)."""
    m = idx.shape[1]
    ind = repeat(idx, 'b m k -> b d (m k)', d=x.shape[1])
    out = x.gather(-1, ind)  # (b, d, (m k))
    out = rearrange(out, 'b d (m k) -> b d m k', m=m)
    return out
+ """ + + def __init__(self, in_dim, out_dim, stride=1, layers=1, radius=0.1, k=16): + super().__init__() + self.stride = stride + self.radius = radius + self.layers = layers + self.k = k + + dims = [in_dim + 3] + [out_dim] * layers + + if layers == 1: + self.convs = nn.Conv2d(dims[0], dims[1], 1, bias=False) + self.norm = nn.BatchNorm1d(out_dim) + self.act = nn.ReLU() + else: + self.skip_conv = nn.Conv1d(in_dim, out_dim, 1, bias=False) if in_dim != out_dim else nn.Identity() + self.convs = nn.Sequential(*[ + nn.Sequential(nn.Conv2d(in_d, out_d, 1, bias=False), + nn.BatchNorm2d(out_d), + nn.ReLU()) + for in_d, out_d in zip(dims[:-2], dims[1:-1]) + ]) + self.convs.append(nn.Conv2d(dims[-2], dims[-1], 1, bias=False)) + self.norm = nn.BatchNorm1d(out_dim) + self.act = nn.ReLU() + + def route(self, src_x, src_xyz, xyz, radius, k, neighbor_idx=None): + # src_x: (b, d, n) + # src_xyz: (b, 3, n) + # xyz: (b, 3, m) + if not exists(neighbor_idx): + neighbor_idx = _ball_query(src_xyz, xyz, radius, k)[0] # (b, m, k) + neighbor_xyz = gather(src_xyz, neighbor_idx) # (b, 3, m, k) + neighbor_xyz -= xyz[..., None] + neighbor_xyz /= radius + x = gather(src_x, neighbor_idx) # (b, d, m, k) + x = torch.cat([x, neighbor_xyz], dim=1) # (b, d+3, m, k) + return SampleResult(x, xyz, None, neighbor_idx) + + def forward(self, x, xyz): + # x: (b, d, n) + # xyz: (b, 3, n) + # out: (b, d', n // stride) + sample = downsample_fps(xyz, n_sample=xyz.shape[-1] // self.stride) + inputs = x.gather(-1, repeat(sample.sample_idx, 'b k -> b d k', d=x.shape[1])) + sample = self.route(x, xyz, sample.xyz, self.radius, self.k) + x = self.convs(sample.x) + x = x.max(dim=-1)[0] + if hasattr(self, 'skip_conv'): + x = self.skip_conv(inputs) + x + x = self.act(self.norm(x)) + return SampleResult(x, sample.xyz, sample.sample_idx, sample.neighbor_idx) + + +class InvResMLP(nn.Module): + + def __init__(self, in_dim, expansion=4, radius=0.1, k=16): + super().__init__() + self.sa_conv = SABlock(in_dim, in_dim, stride=1, 
layers=1, radius=radius, k=k) + + dims = [in_dim, in_dim * expansion, in_dim] + self.conv = nn.Sequential( + nn.Conv1d(dims[0], dims[1], 1, bias=False), + nn.BatchNorm1d(dims[1]), + nn.ReLU(), + nn.Conv1d(dims[1], dims[2], 1, bias=False), + nn.BatchNorm1d(dims[2]) + ) + self.act = nn.ReLU() + + def forward(self, x, xyz): + inputs = x + x = self.sa_conv(x, xyz).x + x = self.conv(x) + x = self.act(inputs + x) + return x + + +class UpBlock(nn.Module): + + def __init__(self, in_dim, out_dim, k=3, eps=1e-5): + super().__init__() + self.k = k + self.eps = eps + dims = [in_dim, out_dim, out_dim] + self.conv = nn.Sequential(*[ + nn.Sequential(nn.Conv1d(in_d, out_d, 1, bias=False), + nn.BatchNorm1d(out_d), + nn.ReLU()) + for in_d, out_d in zip(dims[:-1], dims[1:]) + ]) + + def route(self, src_x, src_xyz, dst_x, dst_xyz, neighbor_idx=None, dists=None): + # use knn and weighted average to get the features + if not exists(neighbor_idx): + neighbor_idx, dists = knn(src_xyz, dst_xyz, self.k) # (b, m, k) + + weights = 1. 
class PointNextEncoder(nn.Module):
    """PointNeXt encoder: a pointwise conv stem followed by stages of one
    strided SABlock plus (blocks[i] - 1) InvResMLP blocks.  The ball-query
    radius doubles at every stage to match the coarser cloud."""

    def __init__(
            self,
            in_dim=3,
            dims=[32, 64, 128, 256, 512],  # dims[0] is the dim of the stem output
            blocks=[4, 7, 4, 4],  # per stage: 1 strided SABlock + (n - 1) InvResMLP
            strides=[4, 4, 4, 4],
            radius=0.1,
            k=32,
            sa_layers=1,
    ):
        super().__init__()
        # NOTE(review): the list defaults are shared across instances (Python
        # mutable-default caveat) but are never mutated here, so they are safe.

        self.stem = nn.Sequential(
            nn.Conv1d(in_dim, dims[0], 1, bias=False),
            nn.BatchNorm1d(dims[0]),
            nn.ReLU()
        )

        radius_scaling = 2
        radii = [radius * (radius_scaling ** i) for i in range(len(blocks))]
        self.encoder = nn.ModuleList()
        for i in range(len(blocks)):
            layers = nn.Sequential(
                SABlock(dims[i], dims[i + 1], stride=strides[i], layers=sa_layers, radius=radii[i], k=k),
                *[InvResMLP(dims[i + 1], radius=radii[i] * radius_scaling, k=k) for _ in range(blocks[i] - 1)]
            )
            self.encoder.append(layers)

        self.out_dim = dims[-1]

    def forward_features(self, x, xyz):
        """x: (b, in_dim, n), xyz: (b, 3, n) -> list of (features, coords),
        one entry per resolution, starting with the full-resolution stem output."""
        x = self.stem(x)
        features = [(x, xyz)]
        for block in self.encoder:
            # The stages are iterated manually (not called as a Sequential)
            # because every layer needs the coordinates as a second argument.
            sample = block[0](x, xyz)
            x, xyz = sample.x, sample.xyz
            for layer in block[1:]:
                x = layer(x, xyz)
            features.append((x, xyz))
        return features

    def forward(self, x, xyz):
        return self.forward_features(x, xyz)


class PointNextDecoder(nn.Module):
    """PointNeXt decoder: walks the encoder pyramid coarse-to-fine, upsampling
    with UpBlock at every step."""

    def __init__(self, encoder_dims=[32, 64, 128, 256, 512]):
        super().__init__()
        self.decoder = nn.ModuleList()

        decoder_dims = encoder_dims[::-1]
        for i in range(len(decoder_dims) - 1):
            # Input is coarse + fine skip channels concatenated by UpBlock.route.
            self.decoder.append(UpBlock(decoder_dims[i] + decoder_dims[i + 1], decoder_dims[i + 1]))

        self.out_dim = decoder_dims[-1]

    def forward(self, feats):
        """feats: list of (features, coords) from the encoder, fine-to-coarse.
        NOTE: the list is consumed in place (pop); pass a copy to keep it."""
        sub_x, sub_xyz = feats.pop()
        for i, block in enumerate(self.decoder):
            x, xyz = feats.pop()
            x = block(x, xyz, sub_x, sub_xyz)
            sub_x, sub_xyz = x, xyz
        return x


class PointNext(nn.Module):
    """Full PointNeXt model: encoder (+ optional decoder) + 1x1 conv head,
    with an optional per-cloud category embedding added before the head."""

    def __init__(self, out_dim, encoder: PointNextEncoder, decoder: PointNextDecoder = None, n_category=0):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        feat_dim = decoder.out_dim if exists(decoder) else encoder.out_dim

        if n_category > 0:
            self.category_emb = nn.Embedding(n_category, feat_dim)

        self.head = nn.Conv1d(feat_dim, out_dim, 1)

    def forward(self, x, xyz, category=None):
        """x: (b, in_dim, n), xyz: (b, 3, n), category: (b,) int or None."""
        feats = self.encoder(x, xyz)
        if exists(self.decoder):
            out = self.decoder(feats)
        else:
            # No decoder: use the coarsest encoder features.
            out = feats[-1][0]
        if exists(category):
            out = out + self.category_emb(category)[:, :, None]
        out = self.head(out)
        return out


# --- file: pointnext/taichi.py ---
import taichi as ti
import torch


@ti.kernel
def _ball_query_kernel(
        src: ti.types.ndarray(ndim=3),
        query: ti.types.ndarray(ndim=3),
        out: ti.types.ndarray(ndim=3),
        dists: ti.types.ndarray(ndim=3),
        radius: ti.float32,
        K: ti.int32
):
    # For every query point, record the first (up to) K source indices within
    # `radius`.  Unfilled slots keep their initial value (-1 from the caller).
    B, M, D = query.shape
    N = src.shape[1]

    for b, m in ti.ndrange(B, M):
        query_pt = ti.math.vec3(query[b, m, 0], query[b, m, 1], query[b, m, 2])

        count = 0
        for i in range(N):
            if count >= K:
                break
            src_pt = ti.math.vec3(src[b, i, 0], src[b, i, 1], src[b, i, 2])
            dist = (query_pt - src_pt).norm()
            if dist <= radius:
                out[b, m, count] = i
                dists[b, m, count] = dist
                count += 1
                if count == K:
                    break


def ball_query(src: torch.Tensor, query: torch.Tensor, radius, k):
    """Taichi-backed ball query.  src: (b, n, 3), query: (b, m, 3) ->
    (idx, dists), both (b, m, k); empty slots are padded with the first hit."""
    assert src.shape[-1] == 3, "src shape should be (B, N, 3)"
    # BUGFIX: allocate on the inputs' device rather than the hard-coded
    # 'cuda' string (which broke e.g. tensors living on 'cuda:1').
    idx = torch.full((*query.shape[:2], k), fill_value=-1, dtype=torch.long, device=src.device)
    dists = torch.full((*query.shape[:2], k), fill_value=-1, dtype=src.dtype, device=src.device)
    _ball_query_kernel(src.contiguous(), query.contiguous(), idx, dists, radius, k)
    mask = idx >= 0
    # BUGFIX: if a query has zero hits its first slot is -1 too; clamp so the
    # padding index stays in range for downstream gathers.
    first = idx[:, :, [0]].clamp_min(0)
    idx = torch.where(mask, idx, first)
    dists = torch.where(mask, dists, dists[:, :, [0]])
    return idx, dists
_ball_query_kernel(src.contiguous(), query.contiguous(), idx, dists, radius, k) + mask = idx >= 0 + idx = torch.where(mask, idx, idx[:, :, [0]]) + dists = torch.where(mask, dists, dists[:, :, [0]]) + return idx, dists diff --git a/pointnext/utils.py b/pointnext/utils.py new file mode 100644 index 0000000..a0ad059 --- /dev/null +++ b/pointnext/utils.py @@ -0,0 +1,49 @@ +import torch +from einops import rearrange, repeat + + +def exists(val): + return val is not None + + +def farthest_point_sampling(x: torch.Tensor, n_sample: int, start_idx: int = None): + # x: (b, n, 3) + b, n = x.shape[:2] + assert n_sample <= n, "not enough points to sample" + + if n_sample == n: + return repeat(torch.arange(n_sample, dtype=torch.long, device=x.device), 'm -> b m', b=b) + + # start index + if exists(start_idx): + sel_idx = torch.full((b, n_sample), start_idx, dtype=torch.long, device=x.device) + else: + sel_idx = torch.randint(n, (b, n_sample), dtype=torch.long, device=x.device) + + cur_x = rearrange(x[torch.arange(b), sel_idx[:, 0]], 'b c -> b 1 c') + min_dists = torch.full((b, n), dtype=x.dtype, device=x.device, fill_value=float('inf')) + for i in range(1, n_sample): + # update distance + dists = torch.linalg.norm(x - cur_x, dim=-1) + min_dists = torch.minimum(dists, min_dists) + + # take the farthest + idx_farthest = torch.max(min_dists, dim=-1).indices + sel_idx[:, i] = idx_farthest + cur_x[:, 0, :] = x[torch.arange(b), idx_farthest] + + return sel_idx + + +def ball_query_pytorch(src, query, radius, k): + # src: (b, n, 3) + # query: (b, m, 3) + b, n = src.shape[:2] + m = query.shape[1] + dists = torch.cdist(query, src) # (b, m, n) + idx = repeat(torch.arange(n, device=src.device), 'n -> b m n', b=b, m=m) + idx = torch.where(dists > radius, n, idx) + idx = idx.sort(dim=-1).values[:, :, :k] # (b, m, k) + idx = torch.where(idx == n, idx[:, :, [0]], idx) + _dists = dists.gather(-1, idx) # (b, m, k) + return idx, _dists diff --git a/setup.py b/setup.py new file mode 100644 index 
0000000..7dec03f --- /dev/null +++ b/setup.py @@ -0,0 +1,29 @@ +from setuptools import setup, find_packages + +setup( + name='pointnext', + packages=find_packages(), + version='0.0.1', + license='MIT', + description='PointNext - Pytorch', + author='Kaidi Shen', + url='https://github.com/kentechx/pointnext', + long_description_content_type='text/markdown', + keywords=[ + '3D segmentation', + '3D classification', + 'point cloud understanding', + ], + install_requires=[ + 'torch>=1.10', + 'einops>=0.6.1', + 'pykeops>=2.1.2', + ], + classifiers=[ + 'Development Status :: 4 - Beta', + 'Intended Audience :: Developers', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + 'License :: OSI Approved :: MIT License', + 'Programming Language :: Python :: 3.6', + ], +)