From 44bdb48301e9f7cd3152d3142be2e882b61b6230 Mon Sep 17 00:00:00 2001 From: Taekjin LEE Date: Wed, 11 Dec 2024 16:55:21 +0900 Subject: [PATCH] feat(lidar_centerpoint): optimize non-maximum suppression algorithm (#1689) Refactor the `NonMaximumSuppression` class in the `lidar_centerpoint` module to optimize the non-maximum suppression algorithm. This includes removing the `isTargetLabel` function and modifying the `isTargetPairObject` function to handle pedestrian labels differently. Additionally, the `search_distance_2d_sq_` variable is now initialized in the `setParameters` function. --- .../postprocess/non_maximum_suppression.hpp | 4 ++-- .../postprocess/non_maximum_suppression.cpp | 18 ++++++------------ 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/perception/lidar_centerpoint/include/lidar_centerpoint/postprocess/non_maximum_suppression.hpp b/perception/lidar_centerpoint/include/lidar_centerpoint/postprocess/non_maximum_suppression.hpp index 3cefe7ddf3335..0fddbfc4edc56 100644 --- a/perception/lidar_centerpoint/include/lidar_centerpoint/postprocess/non_maximum_suppression.hpp +++ b/perception/lidar_centerpoint/include/lidar_centerpoint/postprocess/non_maximum_suppression.hpp @@ -66,14 +66,14 @@ class NonMaximumSuppression std::vector apply(const std::vector &); private: - bool isTargetLabel(const std::uint8_t); - bool isTargetPairObject(const DetectedObject &, const DetectedObject &); Eigen::MatrixXd generateIoUMatrix(const std::vector &); NMSParams params_{}; std::vector target_class_mask_{}; + + double search_distance_2d_sq_{}; }; } // namespace centerpoint diff --git a/perception/lidar_centerpoint/lib/postprocess/non_maximum_suppression.cpp b/perception/lidar_centerpoint/lib/postprocess/non_maximum_suppression.cpp index 66e53310be263..d750e4adab0e1 100644 --- a/perception/lidar_centerpoint/lib/postprocess/non_maximum_suppression.cpp +++ b/perception/lidar_centerpoint/lib/postprocess/non_maximum_suppression.cpp @@ -19,6 +19,7 @@ #include "object_recognition_utils/object_recognition_utils.hpp" namespace centerpoint { +using Label = autoware_perception_msgs::msg::ObjectClassification; void NonMaximumSuppression::setParameters(const NMSParams & params) { @@ -26,31 +27,24 @@ void NonMaximumSuppression::setParameters(const NMSParams & params) assert(params.iou_threshold_ >= 0.0 && params.iou_threshold_ <= 1.0); params_ = params; + search_distance_2d_sq_ = params.search_distance_2d_ * params.search_distance_2d_; target_class_mask_ = classNamesToBooleanMask(params.target_class_names_); } -bool NonMaximumSuppression::isTargetLabel(const uint8_t label) -{ - if (label >= target_class_mask_.size()) { - return false; - } - return target_class_mask_.at(label); -} - bool NonMaximumSuppression::isTargetPairObject( const DetectedObject & object1, const DetectedObject & object2) { const auto label1 = object_recognition_utils::getHighestProbLabel(object1.classification); const auto label2 = object_recognition_utils::getHighestProbLabel(object2.classification); - if (isTargetLabel(label1) && isTargetLabel(label2)) { - return true; + // if labels are not the same, and one of them is pedestrian, do not suppress + if (label1 != label2 && (label1 == Label::PEDESTRIAN || label2 == Label::PEDESTRIAN)) { + return false; } - const auto search_sqr_dist_2d = params_.search_distance_2d_ * params_.search_distance_2d_; const auto sqr_dist_2d = autoware::universe_utils::calcSquaredDistance2d( object_recognition_utils::getPose(object1), object_recognition_utils::getPose(object2)); - return sqr_dist_2d <= search_sqr_dist_2d; + return sqr_dist_2d <= search_distance_2d_sq_; } Eigen::MatrixXd NonMaximumSuppression::generateIoUMatrix(