Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pymomentum/geometry/geometry_pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3159,14 +3159,16 @@ Using the normal is a good way to avoid certain kinds of bad matches, such as ma
:param points_source: [nBatch x nPoints x 3] tensor of source points.
:param vertices_target: [nBatch x nPoints x 3] tensor of target vertices.
:param faces_target: [nBatch x nPoints x 3] tensor of target faces.
:param max_distance: Maximum search distance, allows the search to end early if no points are found within this bound.
:return: A tuple of three tensors, (valid, points, face_index, bary). The first is [nBatch x nPoints] and specifies if the closest point result is valid.
The second is [nBatch x nPoints x 3] and contains the actual closest point (or 0, 0, 0 if invalid).
The third is [nBatch x nPoints] and contains the index of the closest face (or -1 if invalid).
The fourth is [nBatch x nPoints x 3] and contains the barycentric coordinates of the closest point on the face (or 0, 0, 0 if invalid).
)",
py::arg("points_source"),
py::arg("vertices_target"),
py::arg("faces_target"));
py::arg("faces_target"),
py::arg("max_distance") = std::numeric_limits<float>::max());

m.def(
"replace_rest_mesh",
Expand Down
15 changes: 10 additions & 5 deletions pymomentum/tensor_momentum/tensor_kd_tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,8 @@ void findClosestPointsOnMesh_imp(
at::Tensor faces_target,
at::Tensor result_points,
at::Tensor result_face_index,
at::Tensor result_barycentric) {
at::Tensor result_barycentric,
float maxDist) {
using TriBvh = typename axel::TriBvh<S>;

const int64_t nSrcPts = points_source.size(0);
Expand Down Expand Up @@ -374,7 +375,8 @@ void findClosestPointsOnMesh_imp(
for (int64_t k = srcStart; k < srcEnd; ++k) {
const Eigen::Vector3<S> p_src =
pts_src_map.template segment<3>(3 * k);
const auto queryResult = targetTree.closestSurfacePoint(p_src);
const auto queryResult =
targetTree.closestSurfacePoint(p_src, maxDist);
if (queryResult.triangleIdx == axel::kInvalidTriangleIdx) {
result_face_indices_map(k) = -1;
} else {
Expand All @@ -393,7 +395,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
findClosestPointsOnMesh(
at::Tensor points_source,
at::Tensor vertices_target,
at::Tensor faces_target) {
at::Tensor faces_target,
float maxDist) {
TensorChecker checker("find_closest_points_on_mesh");

bool squeeze_src = false;
Expand Down Expand Up @@ -455,7 +458,8 @@ findClosestPointsOnMesh(
faces_target.select(0, iBatch),
result_closest_points.select(0, iBatch),
result_face_index.select(0, iBatch),
result_barycentric.select(0, iBatch));
result_barycentric.select(0, iBatch),
maxDist);
}
} else {
for (int64_t iBatch = 0; iBatch < nBatch; ++iBatch) {
Expand All @@ -465,7 +469,8 @@ findClosestPointsOnMesh(
faces_target.select(0, iBatch),
result_closest_points.select(0, iBatch),
result_face_index.select(0, iBatch),
result_barycentric.select(0, iBatch));
result_barycentric.select(0, iBatch),
maxDist);
}
}

Expand Down
3 changes: 2 additions & 1 deletion pymomentum/tensor_momentum/tensor_kd_tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
findClosestPointsOnMesh(
at::Tensor points_source,
at::Tensor vertices_target,
at::Tensor faces_target);
at::Tensor faces_target,
float maxDist);

} // namespace pymomentum