Skip to content

Commit

Permalink
feat(lidar_centerpoint): optimize non-maximum suppression algorithm
Browse files Browse the repository at this point in the history
Refactor the `NonMaximumSuppression` class in the `lidar_centerpoint` module to optimize the non-maximum suppression algorithm. This includes removing the `isTargetLabel` function and modifying the `isTargetPairObject` function to handle pedestrian labels differently. Additionally, the `search_distance_2d_sq_` variable is now initialized in the `setParameters` function.
  • Loading branch information
technolojin committed Dec 11, 2024
1 parent adfecc2 commit e64c0f9
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,14 @@ class NonMaximumSuppression
std::vector<DetectedObject> apply(const std::vector<DetectedObject> &);

private:
bool isTargetLabel(const std::uint8_t);

bool isTargetPairObject(const DetectedObject &, const DetectedObject &);

Eigen::MatrixXd generateIoUMatrix(const std::vector<DetectedObject> &);

NMSParams params_{};
std::vector<bool> target_class_mask_{};

double search_distance_2d_sq_{};
};

} // namespace centerpoint
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,38 +19,32 @@
#include "object_recognition_utils/object_recognition_utils.hpp"
namespace centerpoint
{
using Label = autoware_perception_msgs::msg::ObjectClassification;

void NonMaximumSuppression::setParameters(const NMSParams & params)
{
assert(params.search_distance_2d_ >= 0.0);
assert(params.iou_threshold_ >= 0.0 && params.iou_threshold_ <= 1.0);

params_ = params;
search_distance_2d_sq_ = params.search_distance_2d_ * params.search_distance_2d_;
target_class_mask_ = classNamesToBooleanMask(params.target_class_names_);
}

bool NonMaximumSuppression::isTargetLabel(const uint8_t label)
{
if (label >= target_class_mask_.size()) {
return false;
}
return target_class_mask_.at(label);
}

bool NonMaximumSuppression::isTargetPairObject(
const DetectedObject & object1, const DetectedObject & object2)
{
const auto label1 = object_recognition_utils::getHighestProbLabel(object1.classification);
const auto label2 = object_recognition_utils::getHighestProbLabel(object2.classification);

if (isTargetLabel(label1) && isTargetLabel(label2)) {
return true;
// if labels are not the same, and one of them is pedestrian, do not suppress
if (label1 != label2 && (label1 == Label::PEDESTRIAN || label2 == Label::PEDESTRIAN)) {
return false;
}

const auto search_sqr_dist_2d = params_.search_distance_2d_ * params_.search_distance_2d_;
const auto sqr_dist_2d = autoware::universe_utils::calcSquaredDistance2d(
object_recognition_utils::getPose(object1), object_recognition_utils::getPose(object2));
return sqr_dist_2d <= search_sqr_dist_2d;
return sqr_dist_2d <= search_distance_2d_sq_;
}

Eigen::MatrixXd NonMaximumSuppression::generateIoUMatrix(
Expand Down

0 comments on commit e64c0f9

Please sign in to comment.