@@ -23,11 +23,13 @@ class _ball_query(Function):
2323    """ 
2424
2525    @staticmethod  
26-     def  forward (ctx , p1 , p2 , lengths1 , lengths2 , K , radius ):
26+     def  forward (ctx , p1 , p2 , lengths1 , lengths2 , K , radius ,  skip_points_outside_cube ):
2727        """ 
2828        Arguments defintions the same as in the ball_query function 
2929        """ 
30-         idx , dists  =  _C .ball_query (p1 , p2 , lengths1 , lengths2 , K , radius )
30+         idx , dists  =  _C .ball_query (
31+             p1 , p2 , lengths1 , lengths2 , K , radius , skip_points_outside_cube 
32+         )
3133        ctx .save_for_backward (p1 , p2 , lengths1 , lengths2 , idx )
3234        ctx .mark_non_differentiable (idx )
3335        return  dists , idx 
@@ -49,7 +51,7 @@ def backward(ctx, grad_dists, grad_idx):
4951        grad_p1 , grad_p2  =  _C .knn_points_backward (
5052            p1 , p2 , lengths1 , lengths2 , idx , 2 , grad_dists 
5153        )
52-         return  grad_p1 , grad_p2 , None , None , None , None 
54+         return  grad_p1 , grad_p2 , None , None , None , None ,  None 
5355
5456
5557def  ball_query (
@@ -60,6 +62,7 @@ def ball_query(
6062    K : int  =  500 ,
6163    radius : float  =  0.2 ,
6264    return_nn : bool  =  True ,
65+     skip_points_outside_cube : bool  =  False ,
6366):
6467    """ 
6568    Ball Query is an alternative to KNN. It can be 
@@ -98,6 +101,9 @@ def ball_query(
98101            within the radius 
99102        radius: the radius around each point within which the neighbors need to be located 
100103        return_nn: If set to True returns the K neighbor points in p2 for each point in p1. 
104+         skip_points_outside_cube: If set to True, reduce multiplications of float values 
105+             by not explicitly calculating distances to points that fall outside the 
106+             D-cube with side length (2*radius) centered at each point in p1. 
101107
102108    Returns: 
103109        dists: Tensor of shape (N, P1, K) giving the squared distances to 
@@ -134,7 +140,9 @@ def ball_query(
134140    if  lengths2  is  None :
135141        lengths2  =  torch .full ((N ,), P2 , dtype = torch .int64 , device = p1 .device )
136142
137-     dists , idx  =  _ball_query .apply (p1 , p2 , lengths1 , lengths2 , K , radius )
143+     dists , idx  =  _ball_query .apply (
144+         p1 , p2 , lengths1 , lengths2 , K , radius , skip_points_outside_cube 
145+     )
138146
139147    # Gather the neighbors if needed 
140148    points_nn  =  masked_gather (p2 , idx ) if  return_nn  else  None 
0 commit comments