Skip to content

Commit 2d4d345

Browse files
jsukparkmeta-codesync[bot]
authored andcommitted
Improve ball_query() runtime for large-scale cases (#2006)
Summary: ### Overview The current C++ code for `pytorch3d.ops.ball_query()` performs floating point multiplication for every coordinate of every pair of points (up until the maximum number of neighbor points is reached). This PR modifies the code (for both CPU and CUDA versions) to implement idea presented [here](https://stackoverflow.com/a/3939525): a `D`-cube around the `D`-ball is first constructed, and any point pairs falling outside the cube are skipped, without explicitly computing the squared distances. This change is especially useful for when the dimension `D` and the number of points `P2` are large and the radius is much smaller than the overall volume of space occupied by the point clouds; as much as **~2.5x speedup** (CPU case; ~1.8x speedup in CUDA case) is observed when `D = 10` and `radius = 0.01`. In all benchmark cases, points were uniform randomly distributed inside a unit `D`-cube. The benchmark code used was different from `tests/benchmarks/bm_ball_query.py` (only the forward part is benchmarked, larger input sizes were used) and is stored in `tests/benchmarks/bm_ball_query_large.py`. ### Average time comparisons <img width="360" height="270" alt="cpu-03-0 01-avg" src="https://github.com/user-attachments/assets/6cc79893-7921-44af-9366-1766c3caf142" /> <img width="360" height="270" alt="cuda-03-0 01-avg" src="https://github.com/user-attachments/assets/5151647d-0273-40a3-aac6-8b9399ede18a" /> <img width="360" height="270" alt="cpu-03-0 10-avg" src="https://github.com/user-attachments/assets/a87bc150-a5eb-47cd-a4ba-83c2ec81edaf" /> <img width="360" height="270" alt="cuda-03-0 10-avg" src="https://github.com/user-attachments/assets/e3699a9f-dfd3-4dd3-b3c9-619296186d43" /> <img width="360" height="270" alt="cpu-10-0 01-avg" src="https://github.com/user-attachments/assets/5ec8c32d-8e4d-4ced-a94e-1b816b1cb0f8" /> <img width="360" height="270" alt="cuda-10-0 01-avg" src="https://github.com/user-attachments/assets/168a3dfc-777a-4fb3-8023-1ac8c13985b8" /> <img width="360" height="270" alt="cpu-10-0 10-avg" src="https://github.com/user-attachments/assets/43a57fd6-1e01-4c5e-87a9-8ef604ef5fa0" /> <img width="360" height="270" alt="cuda-10-0 10-avg" src="https://github.com/user-attachments/assets/a7c7cc69-f273-493e-95b8-3ba2bb2e32da" /> ### Peak time comparisons <img width="360" height="270" alt="cpu-03-0 01-peak" src="https://github.com/user-attachments/assets/5bbbea3f-ef9b-490d-ab0d-ce551711d74f" /> <img width="360" height="270" alt="cuda-03-0 01-peak" src="https://github.com/user-attachments/assets/30b5ab9b-45cb-4057-b69f-bda6e76bd1dc" /> <img width="360" height="270" alt="cpu-03-0 10-peak" src="https://github.com/user-attachments/assets/db69c333-e5ac-4305-8a86-a26a8a9fe80d" /> <img width="360" height="270" alt="cuda-03-0 10-peak" src="https://github.com/user-attachments/assets/82549656-1f12-409e-8160-dd4c4c9d14f7" /> <img width="360" height="270" alt="cpu-10-0 01-peak" src="https://github.com/user-attachments/assets/d0be8ef1-535e-47bc-b773-b87fad625bf0" /> <img width="360" height="270" alt="cuda-10-0 01-peak" src="https://github.com/user-attachments/assets/e308e66e-ae30-400f-8ad2-015517f6e1af" /> <img width="360" height="270" alt="cpu-10-0 10-peak" src="https://github.com/user-attachments/assets/c9b5bf59-9cc2-465c-ad5d-d4e23bdd138a" /> <img width="360" height="270" alt="cuda-10-0 10-peak" src="https://github.com/user-attachments/assets/311354d4-b488-400c-a1dc-c85a21917aa9" /> ### Full benchmark logs [benchmark-before-change.txt](https://github.com/user-attachments/files/22978300/benchmark-before-change.txt) [benchmark-after-change.txt](https://github.com/user-attachments/files/22978299/benchmark-after-change.txt) Pull Request resolved: #2006 Reviewed By: shapovalov Differential Revision: D85356394 Pulled By: bottler fbshipit-source-id: 9b3ce5fc87bb73d4323cc5b4190fc38ae42f41b2
1 parent 45df20e commit 2d4d345

File tree

5 files changed

+115
-14
lines changed

5 files changed

+115
-14
lines changed

pytorch3d/csrc/ball_query/ball_query.cu

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ __global__ void BallQueryKernel(
3232
at::PackedTensorAccessor64<int64_t, 3, at::RestrictPtrTraits> idxs,
3333
at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> dists,
3434
const int64_t K,
35-
const float radius2) {
35+
const float radius,
36+
const float radius2,
37+
const bool skip_points_outside_cube) {
3638
const int64_t N = p1.size(0);
3739
const int64_t chunks_per_cloud = (1 + (p1.size(1) - 1) / blockDim.x);
3840
const int64_t chunks_to_do = N * chunks_per_cloud;
@@ -51,7 +53,19 @@ __global__ void BallQueryKernel(
5153
// Iterate over points in p2 until desired count is reached or
5254
// all points have been considered
5355
for (int64_t j = 0, count = 0; j < lengths2[n] && count < K; ++j) {
54-
// Calculate the distance between the points
56+
if (skip_points_outside_cube) {
57+
bool is_within_radius = true;
58+
// Filter when any one coordinate is already outside the radius
59+
for (int d = 0; is_within_radius && d < D; ++d) {
60+
scalar_t abs_diff = fabs(p1[n][i][d] - p2[n][j][d]);
61+
is_within_radius = (abs_diff <= radius);
62+
}
63+
if (!is_within_radius) {
64+
continue;
65+
}
66+
}
67+
68+
// Else, calculate the distance between the points and compare
5569
scalar_t dist2 = 0.0;
5670
for (int d = 0; d < D; ++d) {
5771
scalar_t diff = p1[n][i][d] - p2[n][j][d];
@@ -77,7 +91,8 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
7791
const at::Tensor& lengths1, // (N,)
7892
const at::Tensor& lengths2, // (N,)
7993
int K,
80-
float radius) {
94+
float radius,
95+
bool skip_points_outside_cube) {
8196
// Check inputs are on the same device
8297
at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
8398
lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4};
@@ -120,7 +135,9 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
120135
idxs.packed_accessor64<int64_t, 3, at::RestrictPtrTraits>(),
121136
dists.packed_accessor64<float, 3, at::RestrictPtrTraits>(),
122137
K_64,
123-
radius2);
138+
radius,
139+
radius2,
140+
skip_points_outside_cube);
124141
}));
125142

126143
AT_CUDA_CHECK(cudaGetLastError());

pytorch3d/csrc/ball_query/ball_query.h

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@
2525
// within the radius
2626
// radius: the radius around each point within which the neighbors need to be
2727
// located
28+
// skip_points_outside_cube: If true, reduce multiplications of float values
29+
// by not explicitly calculating distances to points that fall outside the
30+
// D-cube with side length (2*radius) centered at each point in p1.
2831
//
2932
// Returns:
3033
// p1_neighbor_idx: LongTensor of shape (N, P1, K), where
@@ -46,7 +49,8 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
4649
const at::Tensor& lengths1,
4750
const at::Tensor& lengths2,
4851
const int K,
49-
const float radius);
52+
const float radius,
53+
const bool skip_points_outside_cube);
5054

5155
// CUDA implementation
5256
std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
@@ -55,7 +59,8 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
5559
const at::Tensor& lengths1,
5660
const at::Tensor& lengths2,
5761
const int K,
58-
const float radius);
62+
const float radius,
63+
const bool skip_points_outside_cube);
5964

6065
// Implementation which is exposed
6166
// Note: the backward pass reuses the KNearestNeighborBackward kernel
@@ -65,7 +70,8 @@ inline std::tuple<at::Tensor, at::Tensor> BallQuery(
6570
const at::Tensor& lengths1,
6671
const at::Tensor& lengths2,
6772
int K,
68-
float radius) {
73+
float radius,
74+
bool skip_points_outside_cube) {
6975
if (p1.is_cuda() || p2.is_cuda()) {
7076
#ifdef WITH_CUDA
7177
CHECK_CUDA(p1);
@@ -76,7 +82,8 @@ inline std::tuple<at::Tensor, at::Tensor> BallQuery(
7682
lengths1.contiguous(),
7783
lengths2.contiguous(),
7884
K,
79-
radius);
85+
radius,
86+
skip_points_outside_cube);
8087
#else
8188
AT_ERROR("Not compiled with GPU support.");
8289
#endif
@@ -89,5 +96,6 @@ inline std::tuple<at::Tensor, at::Tensor> BallQuery(
8996
lengths1.contiguous(),
9097
lengths2.contiguous(),
9198
K,
92-
radius);
99+
radius,
100+
skip_points_outside_cube);
93101
}

pytorch3d/csrc/ball_query/ball_query_cpu.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
#include <math.h>
910
#include <torch/extension.h>
1011
#include <tuple>
1112

@@ -15,7 +16,8 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
1516
const at::Tensor& lengths1,
1617
const at::Tensor& lengths2,
1718
int K,
18-
float radius) {
19+
float radius,
20+
bool skip_points_outside_cube) {
1921
const int N = p1.size(0);
2022
const int P1 = p1.size(1);
2123
const int D = p1.size(2);
@@ -37,6 +39,16 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
3739
const int64_t length2 = lengths2_a[n];
3840
for (int64_t i = 0; i < length1; ++i) {
3941
for (int64_t j = 0, count = 0; j < length2 && count < K; ++j) {
42+
if (skip_points_outside_cube) {
43+
bool is_within_radius = true;
44+
for (int d = 0; is_within_radius && d < D; ++d) {
45+
float abs_diff = fabs(p1_a[n][i][d] - p2_a[n][j][d]);
46+
is_within_radius = (abs_diff <= radius);
47+
}
48+
if (!is_within_radius) {
49+
continue;
50+
}
51+
}
4052
float dist2 = 0;
4153
for (int d = 0; d < D; ++d) {
4254
float diff = p1_a[n][i][d] - p2_a[n][j][d];

pytorch3d/ops/ball_query.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,13 @@ class _ball_query(Function):
2323
"""
2424

2525
@staticmethod
26-
def forward(ctx, p1, p2, lengths1, lengths2, K, radius):
26+
def forward(ctx, p1, p2, lengths1, lengths2, K, radius, skip_points_outside_cube):
2727
"""
2828
Arguments defintions the same as in the ball_query function
2929
"""
30-
idx, dists = _C.ball_query(p1, p2, lengths1, lengths2, K, radius)
30+
idx, dists = _C.ball_query(
31+
p1, p2, lengths1, lengths2, K, radius, skip_points_outside_cube
32+
)
3133
ctx.save_for_backward(p1, p2, lengths1, lengths2, idx)
3234
ctx.mark_non_differentiable(idx)
3335
return dists, idx
@@ -49,7 +51,7 @@ def backward(ctx, grad_dists, grad_idx):
4951
grad_p1, grad_p2 = _C.knn_points_backward(
5052
p1, p2, lengths1, lengths2, idx, 2, grad_dists
5153
)
52-
return grad_p1, grad_p2, None, None, None, None
54+
return grad_p1, grad_p2, None, None, None, None, None
5355

5456

5557
def ball_query(
@@ -60,6 +62,7 @@ def ball_query(
6062
K: int = 500,
6163
radius: float = 0.2,
6264
return_nn: bool = True,
65+
skip_points_outside_cube: bool = False,
6366
):
6467
"""
6568
Ball Query is an alternative to KNN. It can be
@@ -98,6 +101,9 @@ def ball_query(
98101
within the radius
99102
radius: the radius around each point within which the neighbors need to be located
100103
return_nn: If set to True returns the K neighbor points in p2 for each point in p1.
104+
skip_points_outside_cube: If set to True, reduce multiplications of float values
105+
by not explicitly calculating distances to points that fall outside the
106+
D-cube with side length (2*radius) centered at each point in p1.
101107
102108
Returns:
103109
dists: Tensor of shape (N, P1, K) giving the squared distances to
@@ -134,7 +140,9 @@ def ball_query(
134140
if lengths2 is None:
135141
lengths2 = torch.full((N,), P2, dtype=torch.int64, device=p1.device)
136142

137-
dists, idx = _ball_query.apply(p1, p2, lengths1, lengths2, K, radius)
143+
dists, idx = _ball_query.apply(
144+
p1, p2, lengths1, lengths2, K, radius, skip_points_outside_cube
145+
)
138146

139147
# Gather the neighbors if needed
140148
points_nn = masked_gather(p2, idx) if return_nn else None
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from itertools import product
8+
9+
import torch
10+
from fvcore.common.benchmark import benchmark
11+
12+
from pytorch3d.ops.ball_query import ball_query
13+
14+
15+
def ball_query_square(
16+
N: int, P1: int, P2: int, D: int, K: int, radius: float, device: str
17+
):
18+
device = torch.device(device)
19+
pts1 = torch.rand(N, P1, D, device=device)
20+
pts2 = torch.rand(N, P2, D, device=device)
21+
torch.cuda.synchronize()
22+
23+
def output():
24+
ball_query(pts1, pts2, K=K, radius=radius, skip_points_outside_cube=True)
25+
torch.cuda.synchronize()
26+
27+
return output
28+
29+
30+
def bm_ball_query() -> None:
31+
backends = ["cpu", "cuda:0"]
32+
33+
kwargs_list = []
34+
Ns = [32]
35+
P1s = [256]
36+
P2s = [2**p for p in range(9, 20, 2)]
37+
Ds = [3, 10]
38+
Ks = [500]
39+
Rs = [0.01, 0.1]
40+
test_cases = product(Ns, P1s, P2s, Ds, Ks, Rs, backends)
41+
for case in test_cases:
42+
N, P1, P2, D, K, R, b = case
43+
kwargs_list.append(
44+
{"N": N, "P1": P1, "P2": P2, "D": D, "K": K, "radius": R, "device": b}
45+
)
46+
benchmark(
47+
ball_query_square,
48+
"BALLQUERY_SQUARE",
49+
kwargs_list,
50+
num_iters=30,
51+
warmup_iters=1,
52+
)
53+
54+
55+
if __name__ == "__main__":
56+
bm_ball_query()

0 commit comments

Comments
 (0)