-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
16 changed files
with
915 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
recursive-include csrc *.h |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
/* | ||
batch version of ball query, modified from the original implementation of official PointNet++ codes. | ||
Written by Shaoshuai Shi | ||
All Rights Reserved 2018. | ||
*/ | ||
|
||
|
||
#include <torch/serialize/tensor.h> | ||
#include <vector> | ||
#include <cuda.h> | ||
#include <cuda_runtime_api.h> | ||
#include "ball_query_gpu.h" | ||
|
||
// Argument-validation macros for tensors handed in from Python.
// On failure they print a diagnostic and abort, matching the error style
// of the kernel launchers in this extension.
// Fixes:
//  - CHECK_INPUT previously expanded to TWO statements
//    (CHECK_CUDA(x);CHECK_CONTIGUOUS(x)), which silently breaks inside an
//    unbraced `if`; it is now wrapped in the do { } while (0) idiom.
//  - x.type().is_cuda() uses the deprecated Type API; x.is_cuda() is the
//    supported accessor.
#define CHECK_CUDA(x) do { \
  if (!x.is_cuda()) { \
    fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \
    exit(-1); \
  } \
} while (0)
#define CHECK_CONTIGUOUS(x) do { \
  if (!x.is_contiguous()) { \
    fprintf(stderr, "%s must be contiguous tensor at %s:%d\n", #x, __FILE__, __LINE__); \
    exit(-1); \
  } \
} while (0)
#define CHECK_INPUT(x) do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while (0)
|
||
|
||
int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample, | ||
at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor) { | ||
CHECK_INPUT(new_xyz_tensor); | ||
CHECK_INPUT(xyz_tensor); | ||
const float *new_xyz = new_xyz_tensor.data<float>(); | ||
const float *xyz = xyz_tensor.data<float>(); | ||
int *idx = idx_tensor.data<int>(); | ||
|
||
ball_query_kernel_launcher_fast(b, n, m, radius, nsample, new_xyz, xyz, idx); | ||
return 1; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
/* | ||
batch version of ball query, modified from the original implementation of official PointNet++ codes. | ||
Written by Shaoshuai Shi | ||
All Rights Reserved 2018. | ||
*/ | ||
|
||
#include <math.h> | ||
#include <stdio.h> | ||
#include <stdlib.h> | ||
|
||
#include "ball_query_gpu.h" | ||
#include "cuda_utils.h" | ||
|
||
|
||
// One thread per (batch, query point): collect up to nsample indices of
// points within `radius` of the query center.
// new_xyz: (B, M, 3) query centers
// xyz:     (B, N, 3) candidate points
// idx out: (B, M, nsample) neighbour indices
__global__ void ball_query_kernel_fast(int b, int n, int m, float radius, int nsample,
    const float *__restrict__ new_xyz, const float *__restrict__ xyz, int *__restrict__ idx) {
    const int batch = blockIdx.y;
    const int query = blockIdx.x * blockDim.x + threadIdx.x;
    if (batch >= b || query >= m) return;

    const float *center = new_xyz + (batch * m + query) * 3;
    const float *points = xyz + batch * n * 3;
    int *out = idx + (batch * m + query) * nsample;

    const float r2 = radius * radius;
    const float cx = center[0];
    const float cy = center[1];
    const float cz = center[2];

    int found = 0;
    for (int k = 0; k < n; ++k) {
        const float dx = points[k * 3 + 0] - cx;
        const float dy = points[k * 3 + 1] - cy;
        const float dz = points[k * 3 + 2] - cz;
        if (dx * dx + dy * dy + dz * dz < r2) {
            if (found == 0) {
                // Pre-fill every slot with the first hit so that slots left
                // over when fewer than nsample points match still hold a
                // valid index.
                for (int l = 0; l < nsample; ++l) {
                    out[l] = k;
                }
            }
            out[found] = k;
            ++found;
            if (found >= nsample) break;
        }
    }
}
|
||
|
||
// Host launcher: one thread per (batch, query point) pair.
// new_xyz: (B, M, 3), xyz: (B, N, 3), idx out: (B, M, nsample)
void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample,
    const float *new_xyz, const float *xyz, int *idx) {
    dim3 grid(DIVUP(m, THREADS_PER_BLOCK), b);  // x covers queries, y covers batches
    dim3 block(THREADS_PER_BLOCK);

    ball_query_kernel_fast<<<grid, block>>>(b, n, m, radius, nsample, new_xyz, xyz, idx);
    // cudaDeviceSynchronize();  // enable when debugging with in-kernel printf

    cudaError_t status = cudaGetLastError();
    if (status != cudaSuccess) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(status));
        exit(-1);
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
#ifndef _BALL_QUERY_GPU_H
#define _BALL_QUERY_GPU_H

#include <torch/serialize/tensor.h>
#include <vector>
#include <cuda.h>
#include <cuda_runtime_api.h>

// Python-facing wrapper: fills idx with up to nsample neighbour indices
// within `radius` of each query point.
int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample,
    at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor);

// Launches the ball-query kernel.
// FIX: the previous declaration listed the pointer parameters as
// (xyz, new_xyz) while the definition takes (new_xyz, xyz). Declaration
// names carry no semantics, but the swap was actively misleading; they
// now match the definition.
void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample,
    const float *new_xyz, const float *xyz, int *idx);

#endif
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
#ifndef _CUDA_UTILS_H
#define _CUDA_UTILS_H

#include <cmath>

#define TOTAL_THREADS 1024
#define THREADS_PER_BLOCK 256
// Ceiling division: number of blocks needed to cover m items, n per block.
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))

// Returns the largest power of two <= work_size, clamped to
// [1, TOTAL_THREADS] — used to pick a thread count for a kernel launch.
// FIX: the original computed floor(log2(work_size)) with floating-point
// std::log, which can truncate (e.g. log(8)/log(2) evaluating to 2.999...
// yields 4 instead of 8) and invokes undefined behaviour for
// work_size <= 0. An integer bit-scan is exact and safe.
inline int opt_n_threads(int work_size) {
    int pow_2 = 1;
    while ((pow_2 << 1) <= work_size && (pow_2 << 1) <= TOTAL_THREADS) {
        pow_2 <<= 1;
    }
    return pow_2;  // >= 1 even for work_size <= 0
}
#endif
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
/* | ||
batch version of point grouping, modified from the original implementation of official PointNet++ codes. | ||
Written by Shaoshuai Shi | ||
All Rights Reserved 2018. | ||
*/ | ||
|
||
#include <torch/serialize/tensor.h> | ||
#include <cuda.h> | ||
#include <cuda_runtime_api.h> | ||
#include <vector> | ||
#include "group_points_gpu.h" | ||
|
||
int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample, | ||
at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) | ||
{ | ||
|
||
float *grad_points = grad_points_tensor.data<float>(); | ||
const int *idx = idx_tensor.data<int>(); | ||
const float *grad_out = grad_out_tensor.data<float>(); | ||
|
||
group_points_grad_kernel_launcher_fast(b, c, n, npoints, nsample, grad_out, idx, grad_points); | ||
return 1; | ||
} | ||
|
||
int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample, | ||
at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor) | ||
{ | ||
|
||
const float *points = points_tensor.data<float>(); | ||
const int *idx = idx_tensor.data<int>(); | ||
float *out = out_tensor.data<float>(); | ||
|
||
group_points_kernel_launcher_fast(b, c, n, npoints, nsample, points, idx, out); | ||
return 1; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
/* | ||
batch version of point grouping, modified from the original implementation of official PointNet++ codes. | ||
Written by Shaoshuai Shi | ||
All Rights Reserved 2018. | ||
*/ | ||
|
||
#include <stdio.h> | ||
#include <stdlib.h> | ||
|
||
#include "cuda_utils.h" | ||
#include "group_points_gpu.h" | ||
|
||
|
||
// Backward pass of grouping: scatter-add each grouped gradient back to its
// source point. One thread per (batch, channel, group-slot).
// grad_out:        (B, C, npoints, nsample) incoming gradients
// idx:             (B, npoints, nsample) source point index per slot
// grad_points out: (B, C, N) — atomicAdd because several slots may
//                  reference the same source point
__global__ void group_points_grad_kernel_fast(int b, int c, int n, int npoints, int nsample,
    const float *__restrict__ grad_out, const int *__restrict__ idx, float *__restrict__ grad_points) {
    const int batch   = blockIdx.z;
    const int channel = blockIdx.y;
    const int flat    = blockIdx.x * blockDim.x + threadIdx.x;
    const int group   = flat / nsample;
    if (batch >= b || channel >= c || group >= npoints) return;

    const int sample = flat % nsample;
    const int slot   = group * nsample + sample;  // offset within one (batch, channel) plane
    const float g    = grad_out[((batch * c + channel) * npoints) * nsample + slot];
    const int src    = idx[batch * npoints * nsample + slot];

    atomicAdd(grad_points + (batch * c + channel) * n + src, g);
}
|
||
// Host launcher for the grouping backward kernel.
// Grid: x covers the npoints*nsample slots, y the channels, z the batches.
// grad_out: (B, C, npoints, nsample), idx: (B, npoints, nsample),
// grad_points out: (B, C, N)
void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample,
    const float *grad_out, const int *idx, float *grad_points) {
    dim3 grid(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b);
    dim3 block(THREADS_PER_BLOCK);

    group_points_grad_kernel_fast<<<grid, block>>>(b, c, n, npoints, nsample, grad_out, idx, grad_points);

    cudaError_t status = cudaGetLastError();
    if (status != cudaSuccess) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(status));
        exit(-1);
    }
}
|
||
|
||
// Forward pass of grouping: gather feature values by index.
// One thread per (batch, channel, group-slot).
// points: (B, C, N) features
// idx:    (B, npoints, nsample) indices into the N dimension
// out:    (B, C, npoints, nsample) gathered features
__global__ void group_points_kernel_fast(int b, int c, int n, int npoints, int nsample,
    const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) {
    const int batch   = blockIdx.z;
    const int channel = blockIdx.y;
    const int flat    = blockIdx.x * blockDim.x + threadIdx.x;
    const int group   = flat / nsample;
    if (batch >= b || channel >= c || group >= npoints) return;

    const int sample = flat % nsample;
    const int slot   = group * nsample + sample;  // offset within one (batch, channel) plane
    const int src    = idx[batch * npoints * nsample + slot];

    out[((batch * c + channel) * npoints) * nsample + slot] =
        points[(batch * c + channel) * n + src];
}
|
||
|
||
// Host launcher for the grouping forward kernel.
// Grid: x covers the npoints*nsample slots, y the channels, z the batches.
// points: (B, C, N), idx: (B, npoints, nsample),
// out: (B, C, npoints, nsample)
void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample,
    const float *points, const int *idx, float *out) {
    dim3 grid(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b);
    dim3 block(THREADS_PER_BLOCK);

    group_points_kernel_fast<<<grid, block>>>(b, c, n, npoints, nsample, points, idx, out);
    // cudaDeviceSynchronize();  // enable when debugging with in-kernel printf

    cudaError_t status = cudaGetLastError();
    if (status != cudaSuccess) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(status));
        exit(-1);
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
#ifndef _GROUP_POINTS_GPU_H
#define _GROUP_POINTS_GPU_H

#include <torch/serialize/tensor.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <vector>

// Forward pass: gather point features by index.
// points: (B, C, N), idx: (B, npoints, nsample) -> out: (B, C, npoints, nsample)
int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample,
    at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor);

void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample,
    const float *points, const int *idx, float *out);

// Backward pass: scatter-add grouped gradients back onto the source points.
// grad_out: (B, C, npoints, nsample), idx: (B, npoints, nsample)
// -> grad_points: (B, C, N)
int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample,
    at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor);

void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample,
    const float *grad_out, const int *idx, float *grad_points);

#endif
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
/* | ||
batch version of point interpolation, modified from the original implementation of official PointNet++ codes. | ||
Written by Shaoshuai Shi | ||
All Rights Reserved 2018. | ||
*/ | ||
|
||
|
||
#include <torch/serialize/tensor.h> | ||
#include <vector> | ||
#include <math.h> | ||
#include <stdio.h> | ||
#include <stdlib.h> | ||
#include <cuda.h> | ||
#include <cuda_runtime_api.h> | ||
#include "interpolate_gpu.h" | ||
|
||
|
||
|
||
|
||
void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor, | ||
at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor) { | ||
const float *unknown = unknown_tensor.data<float>(); | ||
const float *known = known_tensor.data<float>(); | ||
float *dist2 = dist2_tensor.data<float>(); | ||
int *idx = idx_tensor.data<int>(); | ||
|
||
three_nn_kernel_launcher_fast(b, n, m, unknown, known, dist2, idx); | ||
} | ||
|
||
|
||
void three_interpolate_wrapper_fast(int b, int c, int m, int n, | ||
at::Tensor points_tensor, | ||
at::Tensor idx_tensor, | ||
at::Tensor weight_tensor, | ||
at::Tensor out_tensor) { | ||
|
||
const float *points = points_tensor.data<float>(); | ||
const float *weight = weight_tensor.data<float>(); | ||
float *out = out_tensor.data<float>(); | ||
const int *idx = idx_tensor.data<int>(); | ||
|
||
three_interpolate_kernel_launcher_fast(b, c, m, n, points, idx, weight, out); | ||
} | ||
|
||
void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m, | ||
at::Tensor grad_out_tensor, | ||
at::Tensor idx_tensor, | ||
at::Tensor weight_tensor, | ||
at::Tensor grad_points_tensor) { | ||
|
||
const float *grad_out = grad_out_tensor.data<float>(); | ||
const float *weight = weight_tensor.data<float>(); | ||
float *grad_points = grad_points_tensor.data<float>(); | ||
const int *idx = idx_tensor.data<int>(); | ||
|
||
three_interpolate_grad_kernel_launcher_fast(b, c, n, m, grad_out, idx, weight, grad_points); | ||
} |
Oops, something went wrong.