Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(lidar_centerpoint): optimize non-maximum suppression algorithm #1689

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,14 @@ class NonMaximumSuppression
std::vector<DetectedObject> apply(const std::vector<DetectedObject> &);

private:
bool isTargetLabel(const std::uint8_t);

bool isTargetPairObject(const DetectedObject &, const DetectedObject &);

Eigen::MatrixXd generateIoUMatrix(const std::vector<DetectedObject> &);

NMSParams params_{};
std::vector<bool> target_class_mask_{};

double search_distance_2d_sq_{};
};

} // namespace centerpoint
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,38 +19,32 @@
#include "object_recognition_utils/object_recognition_utils.hpp"
namespace centerpoint
{
using Label = autoware_perception_msgs::msg::ObjectClassification;

void NonMaximumSuppression::setParameters(const NMSParams & params)
{
assert(params.search_distance_2d_ >= 0.0);
assert(params.iou_threshold_ >= 0.0 && params.iou_threshold_ <= 1.0);

params_ = params;
search_distance_2d_sq_ = params.search_distance_2d_ * params.search_distance_2d_;
target_class_mask_ = classNamesToBooleanMask(params.target_class_names_);
}

bool NonMaximumSuppression::isTargetLabel(const uint8_t label)
{
if (label >= target_class_mask_.size()) {
return false;
}
return target_class_mask_.at(label);
}

bool NonMaximumSuppression::isTargetPairObject(
const DetectedObject & object1, const DetectedObject & object2)
{
const auto label1 = object_recognition_utils::getHighestProbLabel(object1.classification);
const auto label2 = object_recognition_utils::getHighestProbLabel(object2.classification);

if (isTargetLabel(label1) && isTargetLabel(label2)) {
return true;
// if labels are not the same, and one of them is pedestrian, do not suppress
if (label1 != label2 && (label1 == Label::PEDESTRIAN || label2 == Label::PEDESTRIAN)) {
return false;
}

const auto search_sqr_dist_2d = params_.search_distance_2d_ * params_.search_distance_2d_;
const auto sqr_dist_2d = autoware::universe_utils::calcSquaredDistance2d(
object_recognition_utils::getPose(object1), object_recognition_utils::getPose(object2));
return sqr_dist_2d <= search_sqr_dist_2d;
return sqr_dist_2d <= search_distance_2d_sq_;
}

Eigen::MatrixXd NonMaximumSuppression::generateIoUMatrix(
Expand Down
Loading