diff --git a/pymomentum/geometry/geometry_pybind.cpp b/pymomentum/geometry/geometry_pybind.cpp index a0e1e0b371..73da8fc899 100644 --- a/pymomentum/geometry/geometry_pybind.cpp +++ b/pymomentum/geometry/geometry_pybind.cpp @@ -3159,6 +3159,7 @@ Using the normal is a good way to avoid certain kinds of bad matches, such as ma :param points_source: [nBatch x nPoints x 3] tensor of source points. :param vertices_target: [nBatch x nPoints x 3] tensor of target vertices. :param faces_target: [nBatch x nPoints x 3] tensor of target faces. + :param max_distance: Maximum search distance, allows the search to end early if no points are found within this bound. :return: A tuple of three tensors, (valid, points, face_index, bary). The first is [nBatch x nPoints] and specifies if the closest point result is valid. The second is [nBatch x nPoints x 3] and contains the actual closest point (or 0, 0, 0 if invalid). The third is [nBatch x nPoints] and contains the index of the closest face (or -1 if invalid). @@ -3166,7 +3167,8 @@ Using the normal is a good way to avoid certain kinds of bad matches, such as ma )", py::arg("points_source"), py::arg("vertices_target"), - py::arg("faces_target")); + py::arg("faces_target"), + py::arg("max_distance") = std::numeric_limits::max()); m.def( "replace_rest_mesh", diff --git a/pymomentum/tensor_momentum/tensor_kd_tree.cpp b/pymomentum/tensor_momentum/tensor_kd_tree.cpp index ac368e7d13..e42c1a138c 100644 --- a/pymomentum/tensor_momentum/tensor_kd_tree.cpp +++ b/pymomentum/tensor_momentum/tensor_kd_tree.cpp @@ -327,7 +327,8 @@ void findClosestPointsOnMesh_imp( at::Tensor faces_target, at::Tensor result_points, at::Tensor result_face_index, - at::Tensor result_barycentric) { + at::Tensor result_barycentric, + float maxDist) { using TriBvh = typename axel::TriBvh; const int64_t nSrcPts = points_source.size(0); @@ -374,7 +375,8 @@ void findClosestPointsOnMesh_imp( for (int64_t k = srcStart; k < srcEnd; ++k) { const Eigen::Vector3 p_src = pts_src_map.template segment<3>(3 * k); - const auto queryResult = targetTree.closestSurfacePoint(p_src); + const auto queryResult = + targetTree.closestSurfacePoint(p_src, maxDist); if (queryResult.triangleIdx == axel::kInvalidTriangleIdx) { result_face_indices_map(k) = -1; } else { @@ -393,7 +395,8 @@ std::tuple findClosestPointsOnMesh( at::Tensor points_source, at::Tensor vertices_target, - at::Tensor faces_target) { + at::Tensor faces_target, + float maxDist) { TensorChecker checker("find_closest_points_on_mesh"); bool squeeze_src = false; @@ -455,7 +458,8 @@ findClosestPointsOnMesh( faces_target.select(0, iBatch), result_closest_points.select(0, iBatch), result_face_index.select(0, iBatch), - result_barycentric.select(0, iBatch)); + result_barycentric.select(0, iBatch), + maxDist); } } else { for (int64_t iBatch = 0; iBatch < nBatch; ++iBatch) { @@ -465,7 +469,8 @@ findClosestPointsOnMesh( faces_target.select(0, iBatch), result_closest_points.select(0, iBatch), result_face_index.select(0, iBatch), - result_barycentric.select(0, iBatch)); + result_barycentric.select(0, iBatch), + maxDist); } } diff --git a/pymomentum/tensor_momentum/tensor_kd_tree.h b/pymomentum/tensor_momentum/tensor_kd_tree.h index 4528881acd..fe619405f1 100644 --- a/pymomentum/tensor_momentum/tensor_kd_tree.h +++ b/pymomentum/tensor_momentum/tensor_kd_tree.h @@ -32,6 +32,7 @@ std::tuple findClosestPointsOnMesh( at::Tensor points_source, at::Tensor vertices_target, - at::Tensor faces_target); + at::Tensor faces_target, + float maxDist); } // namespace pymomentum