Skip to content

Commit

Permalink
Final Refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrhm committed Sep 18, 2024
1 parent 60be043 commit 755f093
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 14 deletions.
2 changes: 1 addition & 1 deletion perception/object_detector/object_detector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace mrover {

RCLCPP_INFO_STREAM(get_logger(), "Opening Model " << mModelName);

mLearning = Learning{mModelName, packagePath};
mTensorRT = TensortRT{mModelName, packagePath};

mDebugImgPub = create_publisher<sensor_msgs::msg::Image>("object_detector/debug_img", 1);

Expand Down
2 changes: 1 addition & 1 deletion perception/object_detector/object_detector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ namespace mrover {

LoopProfiler mLoopProfiler;

Learning mLearning;
TensortRT mTensorRT;

cv::Mat mRgbImage, mImageBlob;
sensor_msgs::msg::Image mDetectionsImageMessage;
Expand Down
2 changes: 1 addition & 1 deletion perception/object_detector/object_detector.processing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace mrover {

// Run the blob through the model
std::vector<Detection> detections{};
mLearning.modelForwardPass(mImageBlob, detections, mModelScoreThreshold, mModelNmsThreshold);
mTensorRT.modelForwardPass(mImageBlob, detections, mModelScoreThreshold, mModelNmsThreshold);

mLoopProfiler.measureEvent("Execution");

Expand Down
2 changes: 1 addition & 1 deletion perception/object_detector/pch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,4 @@

#include "parameter.hpp"
#include "point.hpp"
#include <learning.hpp>
#include <tensorrt.hpp>
12 changes: 6 additions & 6 deletions tensorrt/learning.cpp → tensorrt/tensorrt.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#include "learning.hpp"
#include "tensorrt.hpp"

using namespace std;

Learning::Learning() = default;
TensortRT::TensortRT() = default;

Learning::Learning(string modelName, string& packagePathString) : mModelName{std::move(modelName)} {
TensortRT::TensortRT(string modelName, string& packagePathString) : mModelName{std::move(modelName)} {

std::filesystem::path packagePath = packagePathString;
std::filesystem::path modelFileName = mModelName.append(".onnx");
Expand All @@ -13,15 +13,15 @@ Learning::Learning(string modelName, string& packagePathString) : mModelName{std
mInferenceWrapper = InferenceWrapper{modelPath, mModelName, packagePath};
}

Learning::~Learning() = default;
TensortRT::~TensortRT() = default;

auto Learning::modelForwardPass(cv::Mat const& blob, std::vector<Detection>& detections, float modelScoreThreshold, float modelNMSThreshold) const -> void {
auto TensortRT::modelForwardPass(cv::Mat const& blob, std::vector<Detection>& detections, float modelScoreThreshold, float modelNMSThreshold) const -> void {
mInferenceWrapper.doDetections(blob);
cv::Mat output = mInferenceWrapper.getOutputTensor();
parseModelOutput(output, detections, modelScoreThreshold, modelNMSThreshold);
}

auto Learning::parseModelOutput(cv::Mat& output, std::vector<Detection>& detections, float modelScoreThreshold, float modelNMSThreshold) const -> void {
auto TensortRT::parseModelOutput(cv::Mat& output, std::vector<Detection>& detections, float modelScoreThreshold, float modelNMSThreshold) const -> void {
// Parse model specific dimensioning from the output

// The input to this function is expecting a YOLOv8 style model, thus the dimensions should be > rows
Expand Down
8 changes: 4 additions & 4 deletions tensorrt/learning.hpp → tensorrt/tensorrt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ struct Detection {
cv::Rect box;
};

class Learning {
class TensortRT {
std::string mModelName;

std::vector<std::string> classes{"bottle", "hammer"};
Expand All @@ -22,11 +22,11 @@ class Learning {
float modelNMSThreshold = 0.5) const -> void;

public:
Learning();
TensortRT();

explicit Learning(std::string modelName, std::string& packagePathString);
explicit TensortRT(std::string modelName, std::string& packagePathString);

~Learning();
~TensortRT();

auto modelForwardPass(cv::Mat const& blob,
std::vector<Detection>& detections,
Expand Down

0 comments on commit 755f093

Please sign in to comment.