From 21f23ffcead3e58a058829821903262dbef8ec07 Mon Sep 17 00:00:00 2001 From: lindsayshuo Date: Tue, 22 Oct 2024 11:37:34 +0800 Subject: [PATCH] Distinguish the numclass of pose and obb in config. h and add the Oriented Bounding Boxes (OBB) Estimation algorithm (#1593) * Add the generation of multi-class pose engines * Change grids in forwardGpu to one-dimensional arrays * Update README.md * Update types.h keypoints array with dynamic size based on kNumberOfPoints * yolov8_5u_det(YOLOv5u with the anchor-free, objectness-free split head structure based on YOLOv8 features) model * update * fix code style * yolov8_5u_det model download link * yolov8_5u_det model download link * Distinguish the numclass of pose and obb in config. h and add the Oriented Bounding Boxes (OBB) Estimation algorithm * fix code style --------- Co-authored-by: lindsayshuo --- yolov8/CMakeLists.txt | 7 +- yolov8/README.md | 27 +- yolov8/gen_wts.py | 4 +- yolov8/include/block.h | 2 +- yolov8/include/config.h | 8 +- yolov8/include/model.h | 4 + yolov8/include/postprocess.h | 33 +- yolov8/include/types.h | 1 + yolov8/plugin/yololayer.cu | 66 ++-- yolov8/plugin/yololayer.h | 3 +- yolov8/src/block.cpp | 8 +- yolov8/src/model.cpp | 326 ++++++++++++++++++-- yolov8/src/postprocess.cpp | 240 ++++++++++++++- yolov8/src/postprocess.cu | 155 ++++++++-- yolov8/src/preprocess.cu | 60 ++-- yolov8/yolov8_5u_det_trt.py | 9 +- yolov8/yolov8_det_trt.py | 9 +- yolov8/yolov8_obb.cpp | 276 +++++++++++++++++ yolov8/yolov8_obb_trt.py | 571 +++++++++++++++++++++++++++++++++++ yolov8/yolov8_pose.cpp | 2 +- yolov8/yolov8_pose_trt.py | 13 +- yolov8/yolov8_seg.cpp | 2 +- yolov8/yolov8_seg_trt.py | 7 +- 23 files changed, 1677 insertions(+), 156 deletions(-) mode change 100755 => 100644 yolov8/README.md create mode 100644 yolov8/yolov8_obb.cpp create mode 100644 yolov8/yolov8_obb_trt.py diff --git a/yolov8/CMakeLists.txt b/yolov8/CMakeLists.txt index e0c3f8ea..4aa58eb0 100644 --- a/yolov8/CMakeLists.txt +++ b/yolov8/CMakeLists.txt @@ -25,8 +25,8 @@ else() link_directories(/usr/local/cuda/lib64) # tensorrt - include_directories(/home/lindsay/TensorRT-8.4.1.5/include) - link_directories(/home/lindsay/TensorRT-8.4.1.5/lib) + include_directories(/home/lindsay/TensorRT-8.6.1.6/include) + link_directories(/home/lindsay/TensorRT-8.6.1.6/lib) # include_directories(/home/lindsay/TensorRT-7.2.3.4/include) # link_directories(/home/lindsay/TensorRT-7.2.3.4/lib) @@ -60,3 +60,6 @@ target_link_libraries(yolov8_cls nvinfer cudart myplugins ${OpenCV_LIBS}) add_executable(yolov8_5u_det ${PROJECT_SOURCE_DIR}/yolov8_5u_det.cpp ${SRCS}) target_link_libraries(yolov8_5u_det nvinfer cudart myplugins ${OpenCV_LIBS}) + +add_executable(yolov8_obb ${PROJECT_SOURCE_DIR}/yolov8_obb.cpp ${SRCS}) +target_link_libraries(yolov8_obb nvinfer cudart myplugins ${OpenCV_LIBS}) diff --git a/yolov8/README.md b/yolov8/README.md old mode 100755 new mode 100644 index ef9a6a8c..51243e62 --- a/yolov8/README.md +++ b/yolov8/README.md @@ -106,7 +106,7 @@ wget -O coco.txt https://raw.githubusercontent.com/amikelive/coco-labels/master/ ``` cd {tensorrtx}/yolov8/ // Download inference images -wget https://github.com/lindsayshuo/infer_pic/blob/main/1709970363.6990473rescls.jpg +wget https://github.com/lindsayshuo/infer_pic/releases/download/pics/1709970363.6990473rescls.jpg mkdir samples cp -r 1709970363.6990473rescls.jpg samples // Download ImageNet labels @@ -130,7 +130,7 @@ sudo ./yolov8_cls -d yolov8n-cls.engine ../samples ### Pose Estimation ``` cd {tensorrtx}/yolov8/ -// update "kNumClass = 1" in config.h +// update "kPoseNumClass = 1" in config.h mkdir build cd build cp {ultralytics}/ultralytics/yolov8-pose.wts {tensorrtx}/yolov8/build @@ -146,6 +146,28 @@ sudo ./yolov8_pose -d yolov8n-pose.engine ../images g //gpu postprocess ``` +### Oriented Bounding Boxes (OBB) Estimation +``` +cd {tensorrtx}/yolov8/ +// update "kObbNumClass = 15" "kInputH = 1024" "kInputW = 1024" in config.h +wget https://github.com/lindsayshuo/infer_pic/releases/download/pics/obb.png +mkdir images +mv obb.png ./images +mkdir build +cd build +cp {ultralytics}/ultralytics/yolov8-obb.wts {tensorrtx}/yolov8/build +cmake .. +make +sudo ./yolov8_obb -s [.wts] [.engine] [n/s/m/l/x/n2/s2/m2/l2/x2/n6/s6/m6/l6/x6] // serialize model to plan file +sudo ./yolov8_obb -d [.engine] [image folder] [c/g] // deserialize and run inference, the images in [image folder] will be processed. + +// For example yolov8-obb +sudo ./yolov8_obb -s yolov8n-obb.wts yolov8n-obb.engine n +sudo ./yolov8_obb -d yolov8n-obb.engine ../images c //cpu postprocess +sudo ./yolov8_obb -d yolov8n-obb.engine ../images g //gpu postprocess +``` + + 4. optional, load and run the tensorrt model in python ``` @@ -156,6 +178,7 @@ python yolov8_seg_trt.py # Segmentation python yolov8_cls_trt.py # Classification python yolov8_pose_trt.py # Pose Estimation python yolov8_5u_det_trt.py # yolov8_5u_det(YOLOv5u with the anchor-free, objectness-free split head structure based on YOLOv8 features) model +python yolov8_obb_trt.py # Oriented Bounding Boxes (OBB) Estimation ``` # INT8 Quantization diff --git a/yolov8/gen_wts.py b/yolov8/gen_wts.py index 5f037db2..ea3147b7 100644 --- a/yolov8/gen_wts.py +++ b/yolov8/gen_wts.py @@ -12,7 +12,7 @@ def parse_args(): parser.add_argument( '-o', '--output', help='Output (.wts) file path (optional)') parser.add_argument( - '-t', '--type', type=str, default='detect', choices=['detect', 'cls', 'seg', 'pose'], + '-t', '--type', type=str, default='detect', choices=['detect', 'cls', 'seg', 'pose', 'obb'], help='determines the model is detection/classification') args = parser.parse_args() if not os.path.isfile(args.weights): @@ -39,7 +39,7 @@ def parse_args(): # Load model model = torch.load(pt_file, map_location=device)['model'].float() # load to FP32 -if m_type in ['detect', 'seg', 'pose']: +if m_type in ['detect', 'seg', 'pose', 'obb']: anchor_grid = model.model[-1].anchors * model.model[-1].stride[..., None, None] delattr(model.model[-1], 'anchors') diff --git a/yolov8/include/block.h b/yolov8/include/block.h index ae8ec993..64149164 100644 --- a/yolov8/include/block.h +++ b/yolov8/include/block.h @@ -33,4 +33,4 @@ nvinfer1::IShuffleLayer* DFL(nvinfer1::INetworkDefinition* network, std::map dets, const int* px_arry, - int px_arry_num, bool is_segmentation, bool is_pose); + int px_arry_num, int num_class, bool is_segmentation, bool is_pose, bool is_obb); diff --git a/yolov8/include/config.h b/yolov8/include/config.h index 44773a31..31b9481c 100644 --- a/yolov8/include/config.h +++ b/yolov8/include/config.h @@ -5,7 +5,6 @@ const static char* kInputTensorName = "images"; const static char* kOutputTensorName = "output"; const static int kNumClass = 80; -const static int kNumberOfPoints = 17; // number of keypoints total const static int kBatchSize = 1; const static int kGpuId = 0; const static int kInputH = 640; @@ -23,3 +22,10 @@ constexpr static int kClsNumClass = 1000; // Classfication model's input shape constexpr static int kClsInputH = 224; constexpr static int kClsInputW = 224; + +// pose model's number of classes +constexpr static int kPoseNumClass = 1; +const static int kNumberOfPoints = 17; // number of keypoints total + +// obb model's number of classes +constexpr static int kObbNumClass = 15; diff --git a/yolov8/include/model.h b/yolov8/include/model.h index 8f30e029..f58cb9ef 100644 --- a/yolov8/include/model.h +++ b/yolov8/include/model.h @@ -37,3 +37,7 @@ nvinfer1::IHostMemory* buildEngineYolov8_5uDet(nvinfer1::IBuilder* builder, nvin nvinfer1::IHostMemory* buildEngineYolov8_5uDetP6(nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config, nvinfer1::DataType dt, const std::string& wts_path, float& gd, float& gw, int& max_channels); + +nvinfer1::IHostMemory* buildEngineYolov8Obb(nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config, + nvinfer1::DataType dt, const std::string& wts_path, float& gd, float& gw, + int& max_channels); diff --git a/yolov8/include/postprocess.h b/yolov8/include/postprocess.h index eb18d542..863f687b 100644 --- a/yolov8/include/postprocess.h +++ b/yolov8/include/postprocess.h @@ -4,27 +4,38 @@ #include "NvInfer.h" #include "types.h" +// Preprocessing functions cv::Rect get_rect(cv::Mat& img, float bbox[4]); -void nms(std::vector& res, float* output, float conf_thresh, float nms_thresh = 0.5); - -void batch_nms(std::vector>& batch_res, float* output, int batch_size, int output_size, - float conf_thresh, float nms_thresh = 0.5); - -void draw_bbox(std::vector& img_batch, std::vector>& res_batch); - -void draw_bbox_keypoints_line(std::vector& img_batch, std::vector>& res_batch); - +// Processing functions void batch_process(std::vector>& res_batch, const float* decode_ptr_host, int batch_size, int bbox_element, const std::vector& img_batch); - +void batch_process_obb(std::vector>& res_batch, const float* decode_ptr_host, int batch_size, + int bbox_element, const std::vector& img_batch); void process_decode_ptr_host(std::vector& res, const float* decode_ptr_host, int bbox_element, cv::Mat& img, int count); +void process_decode_ptr_host_obb(std::vector& res, const float* decode_ptr_host, int bbox_element, + cv::Mat& img, int count); + +// NMS functions +void nms(std::vector& res, float* output, float conf_thresh, float nms_thresh = 0.5); +void batch_nms(std::vector>& batch_res, float* output, int batch_size, int output_size, + float conf_thresh, float nms_thresh = 0.5); +void nms_obb(std::vector& res, float* output, float conf_thresh, float nms_thresh = 0.5); +void batch_nms_obb(std::vector>& batch_res, float* output, int batch_size, int output_size, + float conf_thresh, float nms_thresh = 0.5); +// CUDA-related functions void cuda_decode(float* predict, int num_bboxes, float confidence_threshold, float* parray, int max_objects, cudaStream_t stream); - void cuda_nms(float* parray, float nms_threshold, int max_objects, cudaStream_t stream); +void cuda_decode_obb(float* predict, int num_bboxes, float confidence_threshold, float* parray, int max_objects, + cudaStream_t stream); +void cuda_nms_obb(float* parray, float nms_threshold, int max_objects, cudaStream_t stream); +// Drawing functions +void draw_bbox(std::vector& img_batch, std::vector>& res_batch); +void draw_bbox_obb(std::vector& img_batch, std::vector>& res_batch); +void draw_bbox_keypoints_line(std::vector& img_batch, std::vector>& res_batch); void draw_mask_bbox(cv::Mat& img, std::vector& dets, std::vector& masks, std::unordered_map& labels_map); diff --git a/yolov8/include/types.h b/yolov8/include/types.h index c43f589b..bb2c8d20 100644 --- a/yolov8/include/types.h +++ b/yolov8/include/types.h @@ -8,6 +8,7 @@ struct alignas(float) Detection { float class_id; float mask[32]; float keypoints[kNumberOfPoints * 3]; // keypoints array with dynamic size based on kNumberOfPoints + float angle; // obb angle }; struct AffineMatrix { diff --git a/yolov8/plugin/yololayer.cu b/yolov8/plugin/yololayer.cu index c42b841c..bebfeb21 100644 --- a/yolov8/plugin/yololayer.cu +++ b/yolov8/plugin/yololayer.cu @@ -26,8 +26,8 @@ __device__ float sigmoid(float x) { namespace nvinfer1 { YoloLayerPlugin::YoloLayerPlugin(int classCount, int numberofpoints, float confthreshkeypoints, int netWidth, - int netHeight, int maxOut, bool is_segmentation, bool is_pose, const int* strides, - int stridesLength) { + int netHeight, int maxOut, bool is_segmentation, bool is_pose, bool is_obb, + const int* strides, int stridesLength) { mClassCount = classCount; mNumberofpoints = numberofpoints; @@ -40,6 +40,7 @@ YoloLayerPlugin::YoloLayerPlugin(int classCount, int numberofpoints, float conft memcpy(mStrides, strides, stridesLength * sizeof(int)); is_segmentation_ = is_segmentation; is_pose_ = is_pose; + is_obb_ = is_obb; } YoloLayerPlugin::~YoloLayerPlugin() { @@ -66,6 +67,7 @@ YoloLayerPlugin::YoloLayerPlugin(const void* data, size_t length) { } read(d, is_segmentation_); read(d, is_pose_); + read(d, is_obb_); assert(d == a + length); } @@ -87,6 +89,7 @@ void YoloLayerPlugin::serialize(void* buffer) const TRT_NOEXCEPT { } write(d, is_segmentation_); write(d, is_pose_); + write(d, is_obb_); assert(d == a + getSerializationSize()); } @@ -94,7 +97,7 @@ void YoloLayerPlugin::serialize(void* buffer) const TRT_NOEXCEPT { size_t YoloLayerPlugin::getSerializationSize() const TRT_NOEXCEPT { return sizeof(mClassCount) + sizeof(mNumberofpoints) + sizeof(mConfthreshkeypoints) + sizeof(mThreadCount) + sizeof(mYoloV8netHeight) + sizeof(mYoloV8NetWidth) + sizeof(mMaxOutObject) + sizeof(mStridesLength) + - sizeof(int) * mStridesLength + sizeof(is_segmentation_) + sizeof(is_pose_); + sizeof(int) * mStridesLength + sizeof(is_segmentation_) + sizeof(is_pose_) + sizeof(is_obb_); } int YoloLayerPlugin::initialize() TRT_NOEXCEPT { @@ -156,7 +159,7 @@ nvinfer1::IPluginV2IOExt* YoloLayerPlugin::clone() const TRT_NOEXCEPT { YoloLayerPlugin* p = new YoloLayerPlugin(mClassCount, mNumberofpoints, mConfthreshkeypoints, mYoloV8NetWidth, mYoloV8netHeight, - mMaxOutObject, is_segmentation_, is_pose_, mStrides, mStridesLength); + mMaxOutObject, is_segmentation_, is_pose_, is_obb_, mStrides, mStridesLength); p->setPluginNamespace(mPluginNamespace); return p; } @@ -174,14 +177,14 @@ __device__ float Logist(float data) { __global__ void CalDetection(const float* input, float* output, int numElements, int maxoutobject, const int grid_h, int grid_w, const int stride, int classes, int nk, float confkeypoints, int outputElem, - bool is_segmentation, bool is_pose) { + bool is_segmentation, bool is_pose, bool is_obb) { int idx = threadIdx.x + blockDim.x * blockIdx.x; if (idx >= numElements) return; const int N_kpts = nk; int total_grid = grid_h * grid_w; - int info_len = 4 + classes + (is_segmentation ? 32 : 0) + (is_pose ? N_kpts * 3 : 0); + int info_len = 4 + classes + (is_segmentation ? 32 : 0) + (is_pose ? N_kpts * 3 : 0) + (is_obb ? 1 : 0); int batchIdx = idx / total_grid; int elemIdx = idx % total_grid; const float* curInput = input + batchIdx * total_grid * info_len; @@ -218,15 +221,16 @@ __global__ void CalDetection(const float* input, float* output, int numElements, if (is_segmentation) { for (int k = 0; k < 32; ++k) { - det->mask[k] = curInput[elemIdx + (4 + classes + k) * total_grid]; + det->mask[k] = + curInput[elemIdx + (4 + classes + (is_pose ? N_kpts * 3 : 0) + (is_obb ? 1 : 0) + k) * total_grid]; } } if (is_pose) { for (int kpt = 0; kpt < N_kpts; kpt++) { - int kpt_x_idx = (4 + classes + (is_segmentation ? 32 : 0) + kpt * 3) * total_grid; - int kpt_y_idx = (4 + classes + (is_segmentation ? 32 : 0) + kpt * 3 + 1) * total_grid; - int kpt_conf_idx = (4 + classes + (is_segmentation ? 32 : 0) + kpt * 3 + 2) * total_grid; + int kpt_x_idx = (4 + classes + (is_segmentation ? 32 : 0) + (is_obb ? 1 : 0) + kpt * 3) * total_grid; + int kpt_y_idx = (4 + classes + (is_segmentation ? 32 : 0) + (is_obb ? 1 : 0) + kpt * 3 + 1) * total_grid; + int kpt_conf_idx = (4 + classes + (is_segmentation ? 32 : 0) + (is_obb ? 1 : 0) + kpt * 3 + 2) * total_grid; float kpt_confidence = sigmoid(curInput[elemIdx + kpt_conf_idx]); @@ -247,24 +251,43 @@ __global__ void CalDetection(const float* input, float* output, int numElements, } } } + + if (is_obb) { + double pi = M_PI; + auto angle_inx = curInput[elemIdx + (4 + classes + (is_segmentation ? 32 : 0) + (is_pose ? N_kpts * 3 : 0) + + 0) * total_grid]; + auto angle = (sigmoid(angle_inx) - 0.25f) * pi; + + auto cos1 = cos(angle); + auto sin1 = sin(angle); + auto xf = (curInput[elemIdx + 2 * total_grid] - curInput[elemIdx + 0 * total_grid]) / 2; + auto yf = (curInput[elemIdx + 3 * total_grid] - curInput[elemIdx + 1 * total_grid]) / 2; + + auto x = xf * cos1 - yf * sin1; + auto y = xf * sin1 + yf * cos1; + + float cx = (col + 0.5f + x) * stride; + float cy = (row + 0.5f + y) * stride; + + float w1 = (curInput[elemIdx + 0 * total_grid] + curInput[elemIdx + 2 * total_grid]) * stride; + float h1 = (curInput[elemIdx + 1 * total_grid] + curInput[elemIdx + 3 * total_grid]) * stride; + det->bbox[0] = cx; + det->bbox[1] = cy; + det->bbox[2] = w1; + det->bbox[3] = h1; + det->angle = angle; + } } void YoloLayerPlugin::forwardGpu(const float* const* inputs, float* output, cudaStream_t stream, int mYoloV8netHeight, int mYoloV8NetWidth, int batchSize) { + int outputElem = 1 + mMaxOutObject * sizeof(Detection) / sizeof(float); cudaMemsetAsync(output, 0, sizeof(float), stream); for (int idx = 0; idx < batchSize; ++idx) { CUDA_CHECK(cudaMemsetAsync(output + idx * outputElem, 0, sizeof(float), stream)); } int numElem = 0; - - // const int maxGrids = mStridesLength; - // int grids[maxGrids][2]; - // for (int i = 0; i < maxGrids; ++i) { - // grids[i][0] = mYoloV8netHeight / mStrides[i]; - // grids[i][1] = mYoloV8NetWidth / mStrides[i]; - // } - int maxGrids = mStridesLength; int flatGridsLen = 2 * maxGrids; int* flatGrids = new int[flatGridsLen]; @@ -286,7 +309,7 @@ void YoloLayerPlugin::forwardGpu(const float* const* inputs, float* output, cuda // The CUDA kernel call remains unchanged CalDetection<<<(numElem + mThreadCount - 1) / mThreadCount, mThreadCount, 0, stream>>>( inputs[i], output, numElem, mMaxOutObject, grid_h, grid_w, stride, mClassCount, mNumberofpoints, - mConfthreshkeypoints, outputElem, is_segmentation_, is_pose_); + mConfthreshkeypoints, outputElem, is_segmentation_, is_pose_, is_obb_); } delete[] flatGrids; @@ -317,7 +340,7 @@ IPluginV2IOExt* YoloPluginCreator::createPlugin(const char* name, const PluginFi assert(fc->nbFields == 1); assert(strcmp(fc->fields[0].name, "combinedInfo") == 0); const int* combinedInfo = static_cast(fc->fields[0].data); - int netinfo_count = 8; + int netinfo_count = 9; int class_count = combinedInfo[0]; int numberofpoints = combinedInfo[1]; float confthreshkeypoints = combinedInfo[2]; @@ -326,11 +349,12 @@ IPluginV2IOExt* YoloPluginCreator::createPlugin(const char* name, const PluginFi int max_output_object_count = combinedInfo[5]; bool is_segmentation = combinedInfo[6]; bool is_pose = combinedInfo[7]; + bool is_obb = combinedInfo[8]; const int* px_arry = combinedInfo + netinfo_count; int px_arry_length = fc->fields[0].length - netinfo_count; YoloLayerPlugin* obj = new YoloLayerPlugin(class_count, numberofpoints, confthreshkeypoints, input_w, input_h, - max_output_object_count, is_segmentation, is_pose, px_arry, px_arry_length); + max_output_object_count, is_segmentation, is_pose, is_obb, px_arry, px_arry_length); obj->setPluginNamespace(mNamespace.c_str()); return obj; } diff --git a/yolov8/plugin/yololayer.h b/yolov8/plugin/yololayer.h index b516ad87..e1dbab5b 100644 --- a/yolov8/plugin/yololayer.h +++ b/yolov8/plugin/yololayer.h @@ -7,7 +7,7 @@ namespace nvinfer1 { class API YoloLayerPlugin : public IPluginV2IOExt { public: YoloLayerPlugin(int classCount, int numberofpoints, float confthreshkeypoints, int netWidth, int netHeight, - int maxOut, bool is_segmentation, bool is_pose, const int* strides, int stridesLength); + int maxOut, bool is_segmentation, bool is_pose, bool is_obb, const int* strides, int stridesLength); YoloLayerPlugin(const void* data, size_t length); ~YoloLayerPlugin(); @@ -75,6 +75,7 @@ class API YoloLayerPlugin : public IPluginV2IOExt { int mMaxOutObject; bool is_segmentation_; bool is_pose_; + bool is_obb_; int* mStrides; int mStridesLength; }; diff --git a/yolov8/src/block.cpp b/yolov8/src/block.cpp index caf395f4..43a694d6 100644 --- a/yolov8/src/block.cpp +++ b/yolov8/src/block.cpp @@ -258,14 +258,15 @@ nvinfer1::IShuffleLayer* DFL(nvinfer1::INetworkDefinition* network, std::map dets, const int* px_arry, - int px_arry_num, bool is_segmentation, bool is_pose) { + int px_arry_num, int num_class, bool is_segmentation, bool is_pose, + bool is_obb) { auto creator = getPluginRegistry()->getPluginCreator("YoloLayer_TRT", "1"); - const int netinfo_count = 8; // Assuming the first 5 elements are for netinfo as per existing code. + const int netinfo_count = 9; // Assuming the first 5 elements are for netinfo as per existing code. const int total_count = netinfo_count + px_arry_num; // Total number of elements for netinfo and px_arry combined. std::vector combinedInfo(total_count); // Fill in the first 5 elements as per existing netinfo. - combinedInfo[0] = kNumClass; + combinedInfo[0] = num_class; combinedInfo[1] = kNumberOfPoints; combinedInfo[2] = kConfThreshKeypoints; combinedInfo[3] = kInputW; @@ -273,6 +274,7 @@ nvinfer1::IPluginV2Layer* addYoLoLayer(nvinfer1::INetworkDefinition* network, combinedInfo[5] = kMaxNumOutputBbox; combinedInfo[6] = is_segmentation; combinedInfo[7] = is_pose; + combinedInfo[8] = is_obb; // Copy the contents of px_arry into the combinedInfo vector after the initial // 5 elements. diff --git a/yolov8/src/model.cpp b/yolov8/src/model.cpp index 8fb524e7..e42951bc 100644 --- a/yolov8/src/model.cpp +++ b/yolov8/src/model.cpp @@ -75,6 +75,10 @@ static nvinfer1::IShuffleLayer* cv4_conv_combined(nvinfer1::INetworkDefinition* std::string bn_weight_key = lname + ".0.bn.weight"; mid_channle = weightMap[bn_weight_key].count; output_channel = kNumberOfPoints * 3; + } else if (algo_type == "obb") { + std::string bn_weight_key = lname + ".0.bn.weight"; + mid_channle = weightMap[bn_weight_key].count; + output_channel = 1; } auto cv0 = convBnSiLU(network, weightMap, input, mid_channle, 3, 1, 1, lname + ".0"); @@ -300,7 +304,7 @@ nvinfer1::IHostMemory* buildEngineYolov8Det(nvinfer1::IBuilder* builder, nvinfer nvinfer1::IPluginV2Layer* yolo = addYoLoLayer(network, std::vector{cat22_dfl_0, cat22_dfl_1, cat22_dfl_2}, - strides, stridesLength, false, false); + strides, stridesLength, kNumClass, false, false, false); yolo->getOutput(0)->setName(kOutputTensorName); network->markOutput(*yolo->getOutput(0)); @@ -616,7 +620,7 @@ nvinfer1::IHostMemory* buildEngineYolov8DetP6(nvinfer1::IBuilder* builder, nvinf nvinfer1::IPluginV2Layer* yolo = addYoLoLayer( network, std::vector{cat30_dfl_0, cat30_dfl_1, cat30_dfl_2, cat30_dfl_3}, - strides, stridesLength, false, false); + strides, stridesLength, kNumClass, false, false, false); yolo->getOutput(0)->setName(kOutputTensorName); network->markOutput(*yolo->getOutput(0)); @@ -932,7 +936,7 @@ nvinfer1::IHostMemory* buildEngineYolov8DetP2(nvinfer1::IBuilder* builder, nvinf nvinfer1::IPluginV2Layer* yolo = addYoLoLayer( network, std::vector{cat28_dfl_0, cat28_dfl_1, cat28_dfl_2, cat28_dfl_3}, - strides, stridesLength, false, false); + strides, stridesLength, kNumClass, false, false, false); yolo->getOutput(0)->setName(kOutputTensorName); network->markOutput(*yolo->getOutput(0)); @@ -1269,7 +1273,7 @@ nvinfer1::IHostMemory* buildEngineYolov8Seg(nvinfer1::IBuilder* builder, nvinfer nvinfer1::IPluginV2Layer* yolo = addYoLoLayer(network, std::vector{cat22_dfl_0, cat22_dfl_1, cat22_dfl_2}, - strides, stridesLength, true, false); + strides, stridesLength, kNumClass, true, false, false); yolo->getOutput(0)->setName(kOutputTensorName); network->markOutput(*yolo->getOutput(0)); @@ -1388,7 +1392,7 @@ nvinfer1::IHostMemory* buildEngineYolov8Pose(nvinfer1::IBuilder* builder, nvinfe ******************************************* *******************************************************************************************************/ int base_in_channel = (gw == 1.25) ? 80 : 64; - int base_out_channel = (gw == 0.25) ? std::max(64, std::min(kNumClass, 100)) : get_width(256, gw, max_channels); + int base_out_channel = (gw == 0.25) ? std::max(64, std::min(kPoseNumClass, 100)) : get_width(256, gw, max_channels); // output0 nvinfer1::IElementWiseLayer* conv22_cv2_0_0 = @@ -1405,7 +1409,7 @@ nvinfer1::IHostMemory* buildEngineYolov8Pose(nvinfer1::IBuilder* builder, nvinfe nvinfer1::IElementWiseLayer* conv22_cv3_0_1 = convBnSiLU(network, weightMap, *conv22_cv3_0_0->getOutput(0), base_out_channel, 3, 1, 1, "model.22.cv3.0.1"); nvinfer1::IConvolutionLayer* conv22_cv3_0_2 = - network->addConvolutionNd(*conv22_cv3_0_1->getOutput(0), kNumClass, nvinfer1::DimsHW{1, 1}, + network->addConvolutionNd(*conv22_cv3_0_1->getOutput(0), kPoseNumClass, nvinfer1::DimsHW{1, 1}, weightMap["model.22.cv3.0.2.weight"], weightMap["model.22.cv3.0.2.bias"]); conv22_cv3_0_2->setStride(nvinfer1::DimsHW{1, 1}); conv22_cv3_0_2->setPadding(nvinfer1::DimsHW{0, 0}); @@ -1427,7 +1431,7 @@ nvinfer1::IHostMemory* buildEngineYolov8Pose(nvinfer1::IBuilder* builder, nvinfe nvinfer1::IElementWiseLayer* conv22_cv3_1_1 = convBnSiLU(network, weightMap, *conv22_cv3_1_0->getOutput(0), base_out_channel, 3, 1, 1, "model.22.cv3.1.1"); nvinfer1::IConvolutionLayer* conv22_cv3_1_2 = - network->addConvolutionNd(*conv22_cv3_1_1->getOutput(0), kNumClass, nvinfer1::DimsHW{1, 1}, + network->addConvolutionNd(*conv22_cv3_1_1->getOutput(0), kPoseNumClass, nvinfer1::DimsHW{1, 1}, weightMap["model.22.cv3.1.2.weight"], weightMap["model.22.cv3.1.2.bias"]); conv22_cv3_1_2->setStrideNd(nvinfer1::DimsHW{1, 1}); conv22_cv3_1_2->setPaddingNd(nvinfer1::DimsHW{0, 0}); @@ -1447,7 +1451,7 @@ nvinfer1::IHostMemory* buildEngineYolov8Pose(nvinfer1::IBuilder* builder, nvinfe nvinfer1::IElementWiseLayer* conv22_cv3_2_1 = convBnSiLU(network, weightMap, *conv22_cv3_2_0->getOutput(0), base_out_channel, 3, 1, 1, "model.22.cv3.2.1"); nvinfer1::IConvolutionLayer* conv22_cv3_2_2 = - network->addConvolution(*conv22_cv3_2_1->getOutput(0), kNumClass, nvinfer1::DimsHW{1, 1}, + network->addConvolution(*conv22_cv3_2_1->getOutput(0), kPoseNumClass, nvinfer1::DimsHW{1, 1}, weightMap["model.22.cv3.2.2.weight"], weightMap["model.22.cv3.2.2.bias"]); nvinfer1::ITensor* inputTensor22_2[] = {conv22_cv2_2_2->getOutput(0), conv22_cv3_2_2->getOutput(0)}; nvinfer1::IConcatenationLayer* cat22_2 = network->addConcatenation(inputTensor22_2, 2); @@ -1463,13 +1467,14 @@ nvinfer1::IHostMemory* buildEngineYolov8Pose(nvinfer1::IBuilder* builder, nvinfe /**************************************************************************************P3****************************************************************************************************************************************/ nvinfer1::IShuffleLayer* shuffle22_0 = network->addShuffle(*cat22_0->getOutput(0)); - shuffle22_0->setReshapeDimensions(nvinfer1::Dims2{64 + kNumClass, (kInputH / strides[0]) * (kInputW / strides[0])}); + shuffle22_0->setReshapeDimensions( + nvinfer1::Dims2{64 + kPoseNumClass, (kInputH / strides[0]) * (kInputW / strides[0])}); nvinfer1::ISliceLayer* split22_0_0 = network->addSlice( *shuffle22_0->getOutput(0), nvinfer1::Dims2{0, 0}, nvinfer1::Dims2{64, (kInputH / strides[0]) * (kInputW / strides[0])}, nvinfer1::Dims2{1, 1}); nvinfer1::ISliceLayer* split22_0_1 = network->addSlice( *shuffle22_0->getOutput(0), nvinfer1::Dims2{64, 0}, - nvinfer1::Dims2{kNumClass, (kInputH / strides[0]) * (kInputW / strides[0])}, nvinfer1::Dims2{1, 1}); + nvinfer1::Dims2{kPoseNumClass, (kInputH / strides[0]) * (kInputW / strides[0])}, nvinfer1::Dims2{1, 1}); nvinfer1::IShuffleLayer* dfl22_0 = DFL(network, weightMap, *split22_0_0->getOutput(0), 4, (kInputH / strides[0]) * (kInputW / strides[0]), 1, 1, 0, "model.22.dfl.conv.weight"); @@ -1484,13 +1489,14 @@ nvinfer1::IHostMemory* buildEngineYolov8Pose(nvinfer1::IBuilder* builder, nvinfe /********************************************************************************************P4**********************************************************************************************************************************/ nvinfer1::IShuffleLayer* shuffle22_1 = network->addShuffle(*cat22_1->getOutput(0)); - shuffle22_1->setReshapeDimensions(nvinfer1::Dims2{64 + kNumClass, (kInputH / strides[1]) * (kInputW / strides[1])}); + shuffle22_1->setReshapeDimensions( + nvinfer1::Dims2{64 + kPoseNumClass, (kInputH / strides[1]) * (kInputW / strides[1])}); nvinfer1::ISliceLayer* split22_1_0 = network->addSlice( *shuffle22_1->getOutput(0), nvinfer1::Dims2{0, 0}, nvinfer1::Dims2{64, (kInputH / strides[1]) * (kInputW / strides[1])}, nvinfer1::Dims2{1, 1}); nvinfer1::ISliceLayer* split22_1_1 = network->addSlice( *shuffle22_1->getOutput(0), nvinfer1::Dims2{64, 0}, - nvinfer1::Dims2{kNumClass, (kInputH / strides[1]) * (kInputW / strides[1])}, nvinfer1::Dims2{1, 1}); + nvinfer1::Dims2{kPoseNumClass, (kInputH / strides[1]) * (kInputW / strides[1])}, nvinfer1::Dims2{1, 1}); nvinfer1::IShuffleLayer* dfl22_1 = DFL(network, weightMap, *split22_1_0->getOutput(0), 4, (kInputH / strides[1]) * (kInputW / strides[1]), 1, 1, 0, "model.22.dfl.conv.weight"); @@ -1505,13 +1511,14 @@ nvinfer1::IHostMemory* buildEngineYolov8Pose(nvinfer1::IBuilder* builder, nvinfe /********************************************************************************************P5**********************************************************************************************************************************/ nvinfer1::IShuffleLayer* shuffle22_2 = network->addShuffle(*cat22_2->getOutput(0)); - shuffle22_2->setReshapeDimensions(nvinfer1::Dims2{64 + kNumClass, (kInputH / strides[2]) * (kInputW / strides[2])}); + shuffle22_2->setReshapeDimensions( + nvinfer1::Dims2{64 + kPoseNumClass, (kInputH / strides[2]) * (kInputW / strides[2])}); nvinfer1::ISliceLayer* split22_2_0 = network->addSlice( *shuffle22_2->getOutput(0), nvinfer1::Dims2{0, 0}, nvinfer1::Dims2{64, (kInputH / strides[2]) * (kInputW / strides[2])}, nvinfer1::Dims2{1, 1}); nvinfer1::ISliceLayer* split22_2_1 = network->addSlice( *shuffle22_2->getOutput(0), nvinfer1::Dims2{64, 0}, - nvinfer1::Dims2{kNumClass, (kInputH / strides[2]) * (kInputW / strides[2])}, nvinfer1::Dims2{1, 1}); + nvinfer1::Dims2{kPoseNumClass, (kInputH / strides[2]) * (kInputW / strides[2])}, nvinfer1::Dims2{1, 1}); nvinfer1::IShuffleLayer* dfl22_2 = DFL(network, weightMap, *split22_2_0->getOutput(0), 4, (kInputH / strides[2]) * (kInputW / strides[2]), 1, 1, 0, "model.22.dfl.conv.weight"); @@ -1525,7 +1532,7 @@ nvinfer1::IHostMemory* buildEngineYolov8Pose(nvinfer1::IBuilder* builder, nvinfe nvinfer1::IPluginV2Layer* yolo = addYoLoLayer(network, std::vector{cat22_dfl_0, cat22_dfl_1, cat22_dfl_2}, - strides, stridesLength, false, true); + strides, stridesLength, kPoseNumClass, false, true, false); yolo->getOutput(0)->setName(kOutputTensorName); network->markOutput(*yolo->getOutput(0)); @@ -1673,7 +1680,7 @@ nvinfer1::IHostMemory* buildEngineYolov8PoseP6(nvinfer1::IBuilder* builder, nvin ******************************************* *******************************************************************************************************/ int base_in_channel = (gw == 1.25) ? 80 : 64; - int base_out_channel = (gw == 0.25) ? std::max(64, std::min(kNumClass, 100)) : get_width(256, gw, max_channels); + int base_out_channel = (gw == 0.25) ? std::max(64, std::min(kPoseNumClass, 100)) : get_width(256, gw, max_channels); // output0 nvinfer1::IElementWiseLayer* conv30_cv2_0_0 = @@ -1693,7 +1700,7 @@ nvinfer1::IHostMemory* buildEngineYolov8PoseP6(nvinfer1::IBuilder* builder, nvin nvinfer1::IElementWiseLayer* conv30_cv3_0_1 = convBnSiLU(network, weightMap, *conv30_cv3_0_0->getOutput(0), base_out_channel, 3, 1, 1, "model.30.cv3.0.1"); nvinfer1::IConvolutionLayer* conv30_cv3_0_2 = - network->addConvolutionNd(*conv30_cv3_0_1->getOutput(0), kNumClass, nvinfer1::DimsHW{1, 1}, + network->addConvolutionNd(*conv30_cv3_0_1->getOutput(0), kPoseNumClass, nvinfer1::DimsHW{1, 1}, weightMap["model.30.cv3.0.2.weight"], weightMap["model.30.cv3.0.2.bias"]); conv30_cv3_0_2->setStride(nvinfer1::DimsHW{1, 1}); conv30_cv3_0_2->setPadding(nvinfer1::DimsHW{0, 0}); @@ -1715,7 +1722,7 @@ nvinfer1::IHostMemory* buildEngineYolov8PoseP6(nvinfer1::IBuilder* builder, nvin nvinfer1::IElementWiseLayer* conv30_cv3_1_1 = convBnSiLU(network, weightMap, *conv30_cv3_1_0->getOutput(0), base_out_channel, 3, 1, 1, "model.30.cv3.1.1"); nvinfer1::IConvolutionLayer* conv30_cv3_1_2 = - network->addConvolutionNd(*conv30_cv3_1_1->getOutput(0), kNumClass, nvinfer1::DimsHW{1, 1}, + network->addConvolutionNd(*conv30_cv3_1_1->getOutput(0), kPoseNumClass, nvinfer1::DimsHW{1, 1}, weightMap["model.30.cv3.1.2.weight"], weightMap["model.30.cv3.1.2.bias"]); conv30_cv3_1_2->setStrideNd(nvinfer1::DimsHW{1, 1}); conv30_cv3_1_2->setPaddingNd(nvinfer1::DimsHW{0, 0}); @@ -1737,7 +1744,7 @@ nvinfer1::IHostMemory* buildEngineYolov8PoseP6(nvinfer1::IBuilder* builder, nvin nvinfer1::IElementWiseLayer* conv30_cv3_2_1 = convBnSiLU(network, weightMap, *conv30_cv3_2_0->getOutput(0), base_out_channel, 3, 1, 1, "model.30.cv3.2.1"); nvinfer1::IConvolutionLayer* conv30_cv3_2_2 = - network->addConvolution(*conv30_cv3_2_1->getOutput(0), kNumClass, nvinfer1::DimsHW{1, 1}, + network->addConvolution(*conv30_cv3_2_1->getOutput(0), kPoseNumClass, nvinfer1::DimsHW{1, 1}, weightMap["model.30.cv3.2.2.weight"], weightMap["model.30.cv3.2.2.bias"]); conv30_cv3_2_2->setStrideNd(nvinfer1::DimsHW{1, 1}); conv30_cv3_2_2->setPaddingNd(nvinfer1::DimsHW{0, 0}); @@ -1759,7 +1766,7 @@ nvinfer1::IHostMemory* buildEngineYolov8PoseP6(nvinfer1::IBuilder* builder, nvin nvinfer1::IElementWiseLayer* conv30_cv3_3_1 = convBnSiLU(network, weightMap, *conv30_cv3_3_0->getOutput(0), base_out_channel, 3, 1, 1, "model.30.cv3.3.1"); nvinfer1::IConvolutionLayer* conv30_cv3_3_2 = - network->addConvolution(*conv30_cv3_3_1->getOutput(0), kNumClass, nvinfer1::DimsHW{1, 1}, + network->addConvolution(*conv30_cv3_3_1->getOutput(0), kPoseNumClass, nvinfer1::DimsHW{1, 1}, weightMap["model.30.cv3.3.2.weight"], weightMap["model.30.cv3.3.2.bias"]); conv30_cv3_3_2->setStrideNd(nvinfer1::DimsHW{1, 1}); conv30_cv3_3_2->setPaddingNd(nvinfer1::DimsHW{0, 0}); @@ -1778,13 +1785,14 @@ nvinfer1::IHostMemory* buildEngineYolov8PoseP6(nvinfer1::IBuilder* builder, nvin // P3 processing steps (remains unchanged) nvinfer1::IShuffleLayer* shuffle30_0 = network->addShuffle(*cat30_0->getOutput(0)); // Reusing the previous cat30_0 as P3 concatenation layer - shuffle30_0->setReshapeDimensions(nvinfer1::Dims2{64 + kNumClass, (kInputH / strides[0]) * (kInputW / strides[0])}); + shuffle30_0->setReshapeDimensions( + nvinfer1::Dims2{64 + kPoseNumClass, (kInputH / strides[0]) * (kInputW / strides[0])}); nvinfer1::ISliceLayer* split30_0_0 = network->addSlice( *shuffle30_0->getOutput(0), nvinfer1::Dims2{0, 0}, nvinfer1::Dims2{64, (kInputH / strides[0]) * (kInputW / strides[0])}, nvinfer1::Dims2{1, 1}); nvinfer1::ISliceLayer* split30_0_1 = network->addSlice( *shuffle30_0->getOutput(0), nvinfer1::Dims2{64, 0}, - nvinfer1::Dims2{kNumClass, (kInputH / strides[0]) * (kInputW / strides[0])}, nvinfer1::Dims2{1, 1}); + nvinfer1::Dims2{kPoseNumClass, (kInputH / strides[0]) * (kInputW / strides[0])}, nvinfer1::Dims2{1, 1}); nvinfer1::IShuffleLayer* dfl30_0 = DFL(network, weightMap, *split30_0_0->getOutput(0), 4, (kInputH / strides[0]) * (kInputW / strides[0]), 1, 1, 0, "model.30.dfl.conv.weight"); @@ -1799,13 +1807,14 @@ nvinfer1::IHostMemory* buildEngineYolov8PoseP6(nvinfer1::IBuilder* builder, nvin // P4 processing steps (remains unchanged) nvinfer1::IShuffleLayer* shuffle30_1 = network->addShuffle(*cat30_1->getOutput(0)); // Reusing the previous cat30_1 as P4 concatenation layer - shuffle30_1->setReshapeDimensions(nvinfer1::Dims2{64 + kNumClass, (kInputH / strides[1]) * (kInputW / strides[1])}); + shuffle30_1->setReshapeDimensions( + nvinfer1::Dims2{64 + kPoseNumClass, (kInputH / strides[1]) * (kInputW / strides[1])}); nvinfer1::ISliceLayer* split30_1_0 = network->addSlice( *shuffle30_1->getOutput(0), nvinfer1::Dims2{0, 0}, nvinfer1::Dims2{64, (kInputH / strides[1]) * (kInputW / strides[1])}, nvinfer1::Dims2{1, 1}); nvinfer1::ISliceLayer* split30_1_1 = network->addSlice( *shuffle30_1->getOutput(0), nvinfer1::Dims2{64, 0}, - nvinfer1::Dims2{kNumClass, (kInputH / strides[1]) * (kInputW / strides[1])}, nvinfer1::Dims2{1, 1}); + nvinfer1::Dims2{kPoseNumClass, (kInputH / strides[1]) * (kInputW / strides[1])}, nvinfer1::Dims2{1, 1}); nvinfer1::IShuffleLayer* dfl30_1 = DFL(network, weightMap, *split30_1_0->getOutput(0), 4, (kInputH / strides[1]) * (kInputW / strides[1]), 1, 1, 0, "model.30.dfl.conv.weight"); @@ -1820,13 +1829,14 @@ nvinfer1::IHostMemory* buildEngineYolov8PoseP6(nvinfer1::IBuilder* builder, nvin // P5 processing steps (remains unchanged) nvinfer1::IShuffleLayer* shuffle30_2 = network->addShuffle(*cat30_2->getOutput(0)); // Reusing the previous cat30_2 as P5 concatenation layer - shuffle30_2->setReshapeDimensions(nvinfer1::Dims2{64 + kNumClass, (kInputH / strides[2]) * (kInputW / strides[2])}); + shuffle30_2->setReshapeDimensions( + nvinfer1::Dims2{64 + kPoseNumClass, (kInputH / strides[2]) * (kInputW / strides[2])}); nvinfer1::ISliceLayer* split30_2_0 = network->addSlice( *shuffle30_2->getOutput(0), nvinfer1::Dims2{0, 0}, nvinfer1::Dims2{64, (kInputH / strides[2]) * (kInputW / strides[2])}, nvinfer1::Dims2{1, 1}); nvinfer1::ISliceLayer* split30_2_1 = network->addSlice( *shuffle30_2->getOutput(0), nvinfer1::Dims2{64, 0}, - nvinfer1::Dims2{kNumClass, (kInputH / strides[2]) * (kInputW / strides[2])}, nvinfer1::Dims2{1, 1}); + nvinfer1::Dims2{kPoseNumClass, (kInputH / strides[2]) * (kInputW / strides[2])}, nvinfer1::Dims2{1, 1}); nvinfer1::IShuffleLayer* dfl30_2 = DFL(network, weightMap, *split30_2_0->getOutput(0), 4, (kInputH / strides[2]) * (kInputW / strides[2]), 1, 1, 0, "model.30.dfl.conv.weight"); @@ -1840,13 +1850,14 @@ nvinfer1::IHostMemory* buildEngineYolov8PoseP6(nvinfer1::IBuilder* builder, nvin // P6 processing steps nvinfer1::IShuffleLayer* shuffle30_3 = network->addShuffle(*cat30_3->getOutput(0)); - shuffle30_3->setReshapeDimensions(nvinfer1::Dims2{64 + kNumClass, (kInputH / strides[3]) * (kInputW / strides[3])}); + shuffle30_3->setReshapeDimensions( + nvinfer1::Dims2{64 + kPoseNumClass, (kInputH / strides[3]) * (kInputW / strides[3])}); nvinfer1::ISliceLayer* split30_3_0 = network->addSlice( *shuffle30_3->getOutput(0), nvinfer1::Dims2{0, 0}, nvinfer1::Dims2{64, (kInputH / strides[3]) * (kInputW / strides[3])}, nvinfer1::Dims2{1, 1}); nvinfer1::ISliceLayer* split30_3_1 = network->addSlice( *shuffle30_3->getOutput(0), nvinfer1::Dims2{64, 0}, - nvinfer1::Dims2{kNumClass, (kInputH / strides[3]) * (kInputW / strides[3])}, nvinfer1::Dims2{1, 1}); + nvinfer1::Dims2{kPoseNumClass, (kInputH / strides[3]) * (kInputW / strides[3])}, nvinfer1::Dims2{1, 1}); nvinfer1::IShuffleLayer* dfl30_3 = DFL(network, weightMap, *split30_3_0->getOutput(0), 4, (kInputH / strides[3]) * (kInputW / strides[3]), 1, 1, 0, "model.30.dfl.conv.weight"); @@ -1860,7 +1871,7 @@ nvinfer1::IHostMemory* buildEngineYolov8PoseP6(nvinfer1::IBuilder* builder, nvin nvinfer1::IPluginV2Layer* yolo = addYoLoLayer( network, std::vector{cat30_dfl_0, cat30_dfl_1, cat30_dfl_2, cat30_dfl_3}, - strides, stridesLength, false, true); + strides, stridesLength, kPoseNumClass, false, true, false); yolo->getOutput(0)->setName(kOutputTensorName); network->markOutput(*yolo->getOutput(0)); @@ -2123,7 +2134,7 @@ nvinfer1::IHostMemory* buildEngineYolov8_5uDet(nvinfer1::IBuilder* builder, nvin nvinfer1::IPluginV2Layer* yolo = addYoLoLayer(network, std::vector{cat24_dfl_0, cat24_dfl_1, cat24_dfl_2}, - strides, stridesLength, false, false); + strides, stridesLength, kNumClass, false, false, false); yolo->getOutput(0)->setName(kOutputTensorName); network->markOutput(*yolo->getOutput(0)); @@ -2454,7 +2465,260 @@ nvinfer1::IHostMemory* buildEngineYolov8_5uDetP6(nvinfer1::IBuilder* builder, nv nvinfer1::IPluginV2Layer* yolo = addYoLoLayer( network, std::vector{cat33_dfl_0, cat33_dfl_1, cat33_dfl_2, cat33_dfl_3}, - strides, stridesLength, false, false); + strides, stridesLength, kNumClass, false, false, false); + + yolo->getOutput(0)->setName(kOutputTensorName); + network->markOutput(*yolo->getOutput(0)); + + builder->setMaxBatchSize(kBatchSize); + config->setMaxWorkspaceSize(16 * (1 << 20)); + +#if defined(USE_FP16) + config->setFlag(nvinfer1::BuilderFlag::kFP16); +#elif defined(USE_INT8) + std::cout << "Your platform support int8: " << (builder->platformHasFastInt8() ? "true" : "false") << std::endl; + assert(builder->platformHasFastInt8()); + config->setFlag(nvinfer1::BuilderFlag::kINT8); + auto* calibrator = new Int8EntropyCalibrator2(1, kInputW, kInputH, kInputQuantizationFolder, "int8calib.table", + kInputTensorName); + config->setInt8Calibrator(calibrator); +#endif + + std::cout << "Building engine, please wait for a while..." << std::endl; + nvinfer1::IHostMemory* serialized_model = builder->buildSerializedNetwork(*network, *config); + std::cout << "Build engine successfully!" << std::endl; + + delete network; + + for (auto& mem : weightMap) { + free((void*)(mem.second.values)); + } + return serialized_model; +} + +nvinfer1::IHostMemory* buildEngineYolov8Obb(nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config, + nvinfer1::DataType dt, const std::string& wts_path, float& gd, float& gw, + int& max_channels) { + std::map weightMap = loadWeights(wts_path); + nvinfer1::INetworkDefinition* network = builder->createNetworkV2(0U); + + /******************************************************************************************************* + ****************************************** YOLOV8 INPUT ********************************************** + *******************************************************************************************************/ + nvinfer1::ITensor* data = network->addInput(kInputTensorName, dt, nvinfer1::Dims3{3, kInputH, kInputW}); + assert(data); + + /******************************************************************************************************* + ***************************************** YOLOV8 BACKBONE ******************************************** + *******************************************************************************************************/ + nvinfer1::IElementWiseLayer* conv0 = + convBnSiLU(network, weightMap, *data, get_width(64, gw, max_channels), 3, 2, 1, "model.0"); + nvinfer1::IElementWiseLayer* conv1 = + convBnSiLU(network, weightMap, *conv0->getOutput(0), get_width(128, gw, max_channels), 3, 2, 1, "model.1"); + nvinfer1::IElementWiseLayer* conv2 = C2F(network, weightMap, *conv1->getOutput(0), get_width(128, gw, max_channels), + get_width(128, gw, max_channels), get_depth(3, gd), true, 0.5, "model.2"); + nvinfer1::IElementWiseLayer* conv3 = + convBnSiLU(network, weightMap, *conv2->getOutput(0), get_width(256, gw, max_channels), 3, 2, 1, "model.3"); + nvinfer1::IElementWiseLayer* conv4 = C2F(network, weightMap, *conv3->getOutput(0), get_width(256, gw, max_channels), + get_width(256, gw, max_channels), get_depth(6, gd), true, 0.5, "model.4"); + nvinfer1::IElementWiseLayer* conv5 = + convBnSiLU(network, weightMap, *conv4->getOutput(0), get_width(512, gw, max_channels), 3, 2, 1, "model.5"); + nvinfer1::IElementWiseLayer* conv6 = C2F(network, weightMap, *conv5->getOutput(0), get_width(512, gw, max_channels), + get_width(512, gw, max_channels), get_depth(6, gd), true, 0.5, "model.6"); + nvinfer1::IElementWiseLayer* conv7 = + convBnSiLU(network, weightMap, *conv6->getOutput(0), get_width(1024, gw, max_channels), 3, 2, 1, "model.7"); + nvinfer1::IElementWiseLayer* conv8 = + C2F(network, weightMap, *conv7->getOutput(0), get_width(1024, gw, max_channels), + get_width(1024, gw, max_channels), get_depth(3, gd), true, 0.5, "model.8"); + nvinfer1::IElementWiseLayer* conv9 = + SPPF(network, weightMap, *conv8->getOutput(0), get_width(1024, gw, max_channels), + get_width(1024, gw, max_channels), 5, "model.9"); + + /******************************************************************************************************* + ********************************************* YOLOV8 HEAD ******************************************** + *******************************************************************************************************/ + float scale[] = {1.0, 2.0, 2.0}; + nvinfer1::IResizeLayer* upsample10 = network->addResize(*conv9->getOutput(0)); + assert(upsample10); + upsample10->setResizeMode(nvinfer1::ResizeMode::kNEAREST); + upsample10->setScales(scale, 3); + + nvinfer1::ITensor* inputTensor11[] = {upsample10->getOutput(0), conv6->getOutput(0)}; + nvinfer1::IConcatenationLayer* cat11 = network->addConcatenation(inputTensor11, 2); + nvinfer1::IElementWiseLayer* conv12 = + C2F(network, weightMap, *cat11->getOutput(0), get_width(512, gw, max_channels), + get_width(512, gw, max_channels), get_depth(3, gd), false, 0.5, "model.12"); + + nvinfer1::IResizeLayer* upsample13 = network->addResize(*conv12->getOutput(0)); + assert(upsample13); + upsample13->setResizeMode(nvinfer1::ResizeMode::kNEAREST); + upsample13->setScales(scale, 3); + + nvinfer1::ITensor* inputTensor14[] = {upsample13->getOutput(0), conv4->getOutput(0)}; + nvinfer1::IConcatenationLayer* cat14 = network->addConcatenation(inputTensor14, 2); + nvinfer1::IElementWiseLayer* conv15 = + C2F(network, weightMap, *cat14->getOutput(0), get_width(256, gw, max_channels), + get_width(256, gw, max_channels), get_depth(3, gd), false, 0.5, "model.15"); + nvinfer1::IElementWiseLayer* conv16 = convBnSiLU(network, weightMap, *conv15->getOutput(0), + get_width(256, gw, max_channels), 3, 2, 1, "model.16"); + nvinfer1::ITensor* inputTensor17[] = {conv16->getOutput(0), conv12->getOutput(0)}; + nvinfer1::IConcatenationLayer* cat17 = network->addConcatenation(inputTensor17, 2); + nvinfer1::IElementWiseLayer* conv18 = + C2F(network, weightMap, *cat17->getOutput(0), get_width(512, gw, max_channels), + get_width(512, gw, max_channels), get_depth(3, gd), false, 0.5, "model.18"); + nvinfer1::IElementWiseLayer* conv19 = convBnSiLU(network, weightMap, *conv18->getOutput(0), + get_width(512, gw, max_channels), 3, 2, 1, "model.19"); + nvinfer1::ITensor* inputTensor20[] = {conv19->getOutput(0), conv9->getOutput(0)}; + nvinfer1::IConcatenationLayer* cat20 = network->addConcatenation(inputTensor20, 2); + nvinfer1::IElementWiseLayer* conv21 = + C2F(network, weightMap, *cat20->getOutput(0), get_width(1024, gw, max_channels), + get_width(1024, gw, max_channels), get_depth(3, gd), false, 0.5, "model.21"); + + /******************************************************************************************************* + ********************************************* YOLOV8 OUTPUT ****************************************** + *******************************************************************************************************/ + int base_in_channel = (gw == 1.25) ? 80 : 64; + int base_out_channel = (gw == 0.25) ? std::max(64, std::min(kObbNumClass, 100)) : get_width(256, gw, max_channels); + + // output0 + nvinfer1::IElementWiseLayer* conv22_cv2_0_0 = + convBnSiLU(network, weightMap, *conv15->getOutput(0), base_in_channel, 3, 1, 1, "model.22.cv2.0.0"); + + nvinfer1::IElementWiseLayer* conv22_cv2_0_1 = + convBnSiLU(network, weightMap, *conv22_cv2_0_0->getOutput(0), base_in_channel, 3, 1, 1, "model.22.cv2.0.1"); + + nvinfer1::IConvolutionLayer* conv22_cv2_0_2 = + network->addConvolutionNd(*conv22_cv2_0_1->getOutput(0), 64, nvinfer1::DimsHW{1, 1}, + weightMap["model.22.cv2.0.2.weight"], weightMap["model.22.cv2.0.2.bias"]); + conv22_cv2_0_2->setStrideNd(nvinfer1::DimsHW{1, 1}); + conv22_cv2_0_2->setPaddingNd(nvinfer1::DimsHW{0, 0}); + + nvinfer1::IElementWiseLayer* conv22_cv3_0_0 = + convBnSiLU(network, weightMap, *conv15->getOutput(0), base_out_channel, 3, 1, 1, "model.22.cv3.0.0"); + nvinfer1::IElementWiseLayer* conv22_cv3_0_1 = convBnSiLU(network, weightMap, *conv22_cv3_0_0->getOutput(0), + base_out_channel, 3, 1, 1, "model.22.cv3.0.1"); + + nvinfer1::IConvolutionLayer* conv22_cv3_0_2 = + network->addConvolutionNd(*conv22_cv3_0_1->getOutput(0), kObbNumClass, nvinfer1::DimsHW{1, 1}, + weightMap["model.22.cv3.0.2.weight"], weightMap["model.22.cv3.0.2.bias"]); + conv22_cv3_0_2->setStride(nvinfer1::DimsHW{1, 1}); + conv22_cv3_0_2->setPadding(nvinfer1::DimsHW{0, 0}); + nvinfer1::ITensor* inputTensor22_0[] = {conv22_cv2_0_2->getOutput(0), conv22_cv3_0_2->getOutput(0)}; + nvinfer1::IConcatenationLayer* cat22_0 = network->addConcatenation(inputTensor22_0, 2); + + // output1 + nvinfer1::IElementWiseLayer* conv22_cv2_1_0 = + convBnSiLU(network, weightMap, *conv18->getOutput(0), base_in_channel, 3, 1, 1, "model.22.cv2.1.0"); + nvinfer1::IElementWiseLayer* conv22_cv2_1_1 = + convBnSiLU(network, weightMap, *conv22_cv2_1_0->getOutput(0), base_in_channel, 3, 1, 1, "model.22.cv2.1.1"); + nvinfer1::IConvolutionLayer* conv22_cv2_1_2 = + network->addConvolutionNd(*conv22_cv2_1_1->getOutput(0), 64, nvinfer1::DimsHW{1, 1}, + weightMap["model.22.cv2.1.2.weight"], weightMap["model.22.cv2.1.2.bias"]); + conv22_cv2_1_2->setStrideNd(nvinfer1::DimsHW{1, 1}); + conv22_cv2_1_2->setPaddingNd(nvinfer1::DimsHW{0, 0}); + nvinfer1::IElementWiseLayer* conv22_cv3_1_0 = + convBnSiLU(network, weightMap, *conv18->getOutput(0), base_out_channel, 3, 1, 1, "model.22.cv3.1.0"); + nvinfer1::IElementWiseLayer* conv22_cv3_1_1 = convBnSiLU(network, weightMap, *conv22_cv3_1_0->getOutput(0), + base_out_channel, 3, 1, 1, "model.22.cv3.1.1"); + nvinfer1::IConvolutionLayer* conv22_cv3_1_2 = + network->addConvolutionNd(*conv22_cv3_1_1->getOutput(0), kObbNumClass, nvinfer1::DimsHW{1, 1}, + weightMap["model.22.cv3.1.2.weight"], weightMap["model.22.cv3.1.2.bias"]); + conv22_cv3_1_2->setStrideNd(nvinfer1::DimsHW{1, 1}); + conv22_cv3_1_2->setPaddingNd(nvinfer1::DimsHW{0, 0}); + nvinfer1::ITensor* inputTensor22_1[] = {conv22_cv2_1_2->getOutput(0), conv22_cv3_1_2->getOutput(0)}; + nvinfer1::IConcatenationLayer* cat22_1 = network->addConcatenation(inputTensor22_1, 2); + + // output2 + nvinfer1::IElementWiseLayer* conv22_cv2_2_0 = + convBnSiLU(network, weightMap, *conv21->getOutput(0), base_in_channel, 3, 1, 1, "model.22.cv2.2.0"); + nvinfer1::IElementWiseLayer* conv22_cv2_2_1 = + convBnSiLU(network, weightMap, *conv22_cv2_2_0->getOutput(0), base_in_channel, 3, 1, 1, "model.22.cv2.2.1"); + nvinfer1::IConvolutionLayer* conv22_cv2_2_2 = + network->addConvolution(*conv22_cv2_2_1->getOutput(0), 64, nvinfer1::DimsHW{1, 1}, + weightMap["model.22.cv2.2.2.weight"], weightMap["model.22.cv2.2.2.bias"]); + nvinfer1::IElementWiseLayer* conv22_cv3_2_0 = + convBnSiLU(network, weightMap, *conv21->getOutput(0), base_out_channel, 3, 1, 1, "model.22.cv3.2.0"); + nvinfer1::IElementWiseLayer* conv22_cv3_2_1 = convBnSiLU(network, weightMap, *conv22_cv3_2_0->getOutput(0), + base_out_channel, 3, 1, 1, "model.22.cv3.2.1"); + nvinfer1::IConvolutionLayer* conv22_cv3_2_2 = + network->addConvolution(*conv22_cv3_2_1->getOutput(0), kObbNumClass, nvinfer1::DimsHW{1, 1}, + weightMap["model.22.cv3.2.2.weight"], weightMap["model.22.cv3.2.2.bias"]); + nvinfer1::ITensor* inputTensor22_2[] = {conv22_cv2_2_2->getOutput(0), conv22_cv3_2_2->getOutput(0)}; + nvinfer1::IConcatenationLayer* cat22_2 = network->addConcatenation(inputTensor22_2, 2); + + /******************************************************************************************************* + ********************************************* YOLOV8 DETECT ****************************************** + *******************************************************************************************************/ + + nvinfer1::IElementWiseLayer* conv_layers[] = {conv3, conv5, conv7}; + int strides[sizeof(conv_layers) / sizeof(conv_layers[0])]; + calculateStrides(conv_layers, sizeof(conv_layers) / sizeof(conv_layers[0]), kInputH, strides); + int stridesLength = sizeof(strides) / sizeof(int); + + nvinfer1::IShuffleLayer* shuffle22_0 = network->addShuffle(*cat22_0->getOutput(0)); + shuffle22_0->setReshapeDimensions( + nvinfer1::Dims2{64 + kObbNumClass, (kInputH / strides[0]) * (kInputW / strides[0])}); + + nvinfer1::ISliceLayer* split22_0_0 = network->addSlice( + *shuffle22_0->getOutput(0), nvinfer1::Dims2{0, 0}, + nvinfer1::Dims2{64, (kInputH / strides[0]) * (kInputW / strides[0])}, nvinfer1::Dims2{1, 1}); + nvinfer1::ISliceLayer* split22_0_1 = network->addSlice( + *shuffle22_0->getOutput(0), nvinfer1::Dims2{64, 0}, + nvinfer1::Dims2{kObbNumClass, (kInputH / strides[0]) * (kInputW / strides[0])}, nvinfer1::Dims2{1, 1}); + nvinfer1::IShuffleLayer* dfl22_0 = + DFL(network, weightMap, *split22_0_0->getOutput(0), 4, (kInputH / strides[0]) * (kInputW / strides[0]), 1, + 1, 0, "model.22.dfl.conv.weight"); + + nvinfer1::IShuffleLayer* shuffle22_1 = network->addShuffle(*cat22_1->getOutput(0)); + shuffle22_1->setReshapeDimensions( + nvinfer1::Dims2{64 + kObbNumClass, (kInputH / strides[1]) * (kInputW / strides[1])}); + nvinfer1::ISliceLayer* split22_1_0 = network->addSlice( + *shuffle22_1->getOutput(0), nvinfer1::Dims2{0, 0}, + nvinfer1::Dims2{64, (kInputH / strides[1]) * (kInputW / strides[1])}, nvinfer1::Dims2{1, 1}); + nvinfer1::ISliceLayer* split22_1_1 = network->addSlice( + *shuffle22_1->getOutput(0), nvinfer1::Dims2{64, 0}, + nvinfer1::Dims2{kObbNumClass, (kInputH / strides[1]) * (kInputW / strides[1])}, nvinfer1::Dims2{1, 1}); + nvinfer1::IShuffleLayer* dfl22_1 = + DFL(network, weightMap, *split22_1_0->getOutput(0), 4, (kInputH / strides[1]) * (kInputW / strides[1]), 1, + 1, 0, "model.22.dfl.conv.weight"); + + nvinfer1::IShuffleLayer* shuffle22_2 = network->addShuffle(*cat22_2->getOutput(0)); + shuffle22_2->setReshapeDimensions( + nvinfer1::Dims2{64 + kObbNumClass, (kInputH / strides[2]) * (kInputW / strides[2])}); + nvinfer1::ISliceLayer* split22_2_0 = network->addSlice( + *shuffle22_2->getOutput(0), nvinfer1::Dims2{0, 0}, + nvinfer1::Dims2{64, (kInputH / strides[2]) * (kInputW / strides[2])}, nvinfer1::Dims2{1, 1}); + nvinfer1::ISliceLayer* split22_2_1 = network->addSlice( + *shuffle22_2->getOutput(0), nvinfer1::Dims2{64, 0}, + nvinfer1::Dims2{kObbNumClass, (kInputH / strides[2]) * (kInputW / strides[2])}, nvinfer1::Dims2{1, 1}); + nvinfer1::IShuffleLayer* dfl22_2 = + DFL(network, weightMap, *split22_2_0->getOutput(0), 4, (kInputH / strides[2]) * (kInputW / strides[2]), 1, + 1, 0, "model.22.dfl.conv.weight"); + + // det0 + auto shuffle_conv15 = cv4_conv_combined(network, weightMap, *conv15->getOutput(0), "model.22.cv4.0", + (kInputH / strides[0]) * (kInputW / strides[0]), gw, "obb"); + nvinfer1::ITensor* inputTensor22_dfl_0[] = {dfl22_0->getOutput(0), split22_0_1->getOutput(0), + shuffle_conv15->getOutput(0)}; + nvinfer1::IConcatenationLayer* cat22_dfl_0 = network->addConcatenation(inputTensor22_dfl_0, 3); + + // det1 + auto shuffle_conv18 = cv4_conv_combined(network, weightMap, *conv18->getOutput(0), "model.22.cv4.1", + (kInputH / strides[1]) * (kInputW / strides[1]), gw, "obb"); + nvinfer1::ITensor* inputTensor22_dfl_1[] = {dfl22_1->getOutput(0), split22_1_1->getOutput(0), + shuffle_conv18->getOutput(0)}; + nvinfer1::IConcatenationLayer* cat22_dfl_1 = network->addConcatenation(inputTensor22_dfl_1, 3); + + // det2 + auto shuffle_conv21 = cv4_conv_combined(network, weightMap, *conv21->getOutput(0), "model.22.cv4.2", + (kInputH / strides[2]) * (kInputW / strides[2]), gw, "obb"); + nvinfer1::ITensor* inputTensor22_dfl_2[] = {dfl22_2->getOutput(0), split22_2_1->getOutput(0), + shuffle_conv21->getOutput(0)}; + nvinfer1::IConcatenationLayer* cat22_dfl_2 = network->addConcatenation(inputTensor22_dfl_2, 3); + + nvinfer1::IPluginV2Layer* yolo = + addYoLoLayer(network, std::vector{cat22_dfl_0, cat22_dfl_1, cat22_dfl_2}, + strides, stridesLength, kObbNumClass, false, false, true); yolo->getOutput(0)->setName(kOutputTensorName); network->markOutput(*yolo->getOutput(0)); diff --git a/yolov8/src/postprocess.cpp b/yolov8/src/postprocess.cpp index f19acc0a..309e0d0f 100644 --- a/yolov8/src/postprocess.cpp +++ b/yolov8/src/postprocess.cpp @@ -1,4 +1,6 @@ #include "postprocess.h" +#include +#include // Include this header for printing #include "utils.h" cv::Rect get_rect(cv::Mat& img, float bbox[4]) { @@ -94,7 +96,7 @@ void nms(std::vector& res, float* output, float conf_thresh, float nm std::map> m; for (int i = 0; i < output[0]; i++) { - if (output[1 + det_size * i + 4] <= conf_thresh) + if (output[1 + det_size * i + 4] <= conf_thresh || isnan(output[1 + det_size * i + 4])) continue; Detection det; memcpy(&det, &output[1 + det_size * i], det_size * sizeof(float)); @@ -267,3 +269,239 @@ void draw_mask_bbox(cv::Mat& img, std::vector& dets, std::vector& res, const float* decode_ptr_host, int bbox_element, + cv::Mat& img, int count) { + Detection det; + for (int i = 0; i < count; i++) { + int basic_pos = 1 + i * bbox_element; + int keep_flag = decode_ptr_host[basic_pos + 6]; + if (keep_flag == 1) { + det.bbox[0] = decode_ptr_host[basic_pos + 0]; + det.bbox[1] = decode_ptr_host[basic_pos + 1]; + det.bbox[2] = decode_ptr_host[basic_pos + 2]; + det.bbox[3] = decode_ptr_host[basic_pos + 3]; + det.conf = decode_ptr_host[basic_pos + 4]; + det.class_id = decode_ptr_host[basic_pos + 5]; + det.angle = decode_ptr_host[basic_pos + 7]; + res.push_back(det); + } + } +} + +void batch_process_obb(std::vector>& res_batch, const float* decode_ptr_host, int batch_size, + int bbox_element, const std::vector& img_batch) { + res_batch.resize(batch_size); + int count = static_cast(*decode_ptr_host); + count = std::min(count, kMaxNumOutputBbox); + for (int i = 0; i < batch_size; i++) { + auto& img = const_cast(img_batch[i]); + process_decode_ptr_host_obb(res_batch[i], &decode_ptr_host[i * count], bbox_element, img, count); + } +} + +std::tuple convariance_matrix(Detection res) { + float w = res.bbox[2]; + float h = res.bbox[3]; + + float a = w * w / 12.0; + float b = h * h / 12.0; + float c = res.angle; + + float cos_r = std::cos(c); + float sin_r = std::sin(c); + + float cos_r2 = cos_r * cos_r; + float sin_r2 = sin_r * sin_r; + + float a_val = a * cos_r2 + b * sin_r2; + float b_val = a * sin_r2 + b * cos_r2; + float c_val = (a - b) * cos_r * sin_r; + + return std::make_tuple(a_val, b_val, c_val); +} + +static float probiou(const Detection& res1, const Detection& res2, float eps = 1e-7) { + // Calculate the prob iou between oriented bounding boxes, https://arxiv.org/pdf/2106.06072v1.pdf. + float a1, b1, c1, a2, b2, c2; + std::tuple matrix1 = {a1, b1, c1}; + std::tuple matrix2 = {a2, b2, c2}; + matrix1 = convariance_matrix(res1); + matrix2 = convariance_matrix(res2); + a1 = std::get<0>(matrix1); + b1 = std::get<1>(matrix1); + c1 = std::get<2>(matrix1); + a2 = std::get<0>(matrix2); + b2 = std::get<1>(matrix2); + c2 = std::get<2>(matrix2); + + float x1 = res1.bbox[0], y1 = res1.bbox[1]; + float x2 = res2.bbox[0], y2 = res2.bbox[1]; + + float t1 = ((a1 + a2) * std::pow(y1 - y2, 2) + (b1 + b2) * std::pow(x1 - x2, 2)) / + ((a1 + a2) * (b1 + b2) - std::pow(c1 + c2, 2) + eps); + float t2 = ((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - std::pow(c1 + c2, 2) + eps); + float t3 = std::log( + ((a1 + a2) * (b1 + b2) - std::pow(c1 + c2, 2)) / + (4 * std::sqrt(std::max(a1 * b1 - c1 * c1, 0.0f)) * std::sqrt(std::max(a2 * b2 - c2 * c2, 0.0f)) + + eps) + + eps); + + float bd = 0.25f * t1 + 0.5f * t2 + 0.5f * t3; + bd = std::max(std::min(bd, 100.0f), eps); + float hd = std::sqrt(1.0 - std::exp(-bd) + eps); + + return 1 - hd; +} + +void nms_obb(std::vector& res, float* output, float conf_thresh, float nms_thresh) { + int det_size = sizeof(Detection) / sizeof(float); + std::map> m; + + for (int i = 0; i < output[0]; i++) { + + if (output[1 + det_size * i + 4] <= conf_thresh) + continue; + Detection det; + memcpy(&det, &output[1 + det_size * i], det_size * sizeof(float)); + if (m.count(det.class_id) == 0) + m.emplace(det.class_id, std::vector()); + m[det.class_id].push_back(det); + } + for (auto it = m.begin(); it != m.end(); it++) { + auto& dets = it->second; + std::sort(dets.begin(), dets.end(), cmp); + for (size_t m = 0; m < dets.size(); ++m) { + auto& item = dets[m]; + res.push_back(item); + for (size_t n = m + 1; n < dets.size(); ++n) { + if (probiou(item, dets[n]) >= nms_thresh) { + dets.erase(dets.begin() + n); + --n; + } + } + } + } +} + +void batch_nms_obb(std::vector>& res_batch, float* output, int batch_size, int output_size, + float conf_thresh, float nms_thresh) { + res_batch.resize(batch_size); + for (int i = 0; i < batch_size; i++) { + nms_obb(res_batch[i], &output[i * output_size], conf_thresh, nms_thresh); + } +} + +static std::vector get_corner(cv::Mat& img, const Detection& box) { + float cos_value, sin_value; + + // Calculate center point and width/height + float x1 = box.bbox[0]; + float y1 = box.bbox[1]; + float w = box.bbox[2]; + float h = box.bbox[3]; + float angle = box.angle * 180.0f / CV_PI; // Convert radians to degrees + + // Print original angle + std::cout << "Original angle: " << angle << std::endl; + + // Swap width and height if height is greater than or equal to width + if (h >= w) { + std::swap(w, h); + angle = fmod(angle + 90.0f, 180.0f); // Adjust angle to be within [0, 180) + } + + // Ensure the angle is between 0 and 180 degrees + if (angle < 0) { + angle += 360.0f; // Convert to positive value + } + if (angle > 180.0f) { + angle -= 180.0f; // Subtract 180 from angles greater than 180 + } + + // Print adjusted angle + std::cout << "Adjusted angle: " << angle << std::endl; + + // Convert to normal angle value + float normal_angle = fmod(angle, 180.0f); + if (normal_angle < 0) { + normal_angle += 180.0f; // Ensure it's a positive value + } + + // Print normal angle value + std::cout << "Normal angle: " << normal_angle << std::endl; + + cos_value = std::cos(angle * CV_PI / 180.0f); // Convert to radians + sin_value = std::sin(angle * CV_PI / 180.0f); + + // Calculate each corner point + float l = x1 - w / 2; // Left boundary + float r = x1 + w / 2; // Right boundary + float t = y1 - h / 2; // Top boundary + float b = y1 + h / 2; // Bottom boundary + + // Use get_rect function to scale the coordinates + float bbox[4] = {l, t, r, b}; + cv::Rect rect = get_rect(img, bbox); + + float x_ = (rect.x + rect.x + rect.width) / 2; // Center x + float y_ = (rect.y + rect.y + rect.height) / 2; // Center y + float width = rect.width; // Width + float height = rect.height; // Height + + // Calculate each corner point + std::vector corner_points(4); + float vec1x = width / 2 * cos_value; + float vec1y = width / 2 * sin_value; + float vec2x = -height / 2 * sin_value; + float vec2y = height / 2 * cos_value; + + corner_points[0] = cv::Point(int(round(x_ + vec1x + vec2x)), int(round(y_ + vec1y + vec2y))); // Top-left corner + corner_points[1] = cv::Point(int(round(x_ + vec1x - vec2x)), int(round(y_ + vec1y - vec2y))); // Top-right corner + corner_points[2] = + cv::Point(int(round(x_ - vec1x - vec2x)), int(round(y_ - vec1y - vec2y))); // Bottom-right corner + corner_points[3] = cv::Point(int(round(x_ - vec1x + vec2x)), int(round(y_ - vec1y + vec2y))); // Bottom-left corner + + // Check and adjust corner points to ensure the rectangle is parallel to image boundaries + for (auto& point : corner_points) { + point.x = std::max(0, std::min(point.x, img.cols - 1)); + point.y = std::max(0, std::min(point.y, img.rows - 1)); + } + + return corner_points; +} + +void draw_bbox_obb(std::vector& img_batch, std::vector>& res_batch) { + static std::vector colors = {0xFF3838, 0xFF9D97, 0xFF701F, 0xFFB21D, 0xCFD231, 0x48F90A, 0x92CC17, + 0x3DDB86, 0x1A9334, 0x00D4BB, 0x2C99A8, 0x00C2FF, 0x344593, 0x6473FF, + 0x0018EC, 0x8438FF, 0x520085, 0xCB38FF, 0xFF95C8, 0xFF37C7}; + for (size_t i = 0; i < img_batch.size(); i++) { + auto& res = res_batch[i]; + auto& img = img_batch[i]; + for (auto& obj : res) { + auto color = colors[(int)obj.class_id % colors.size()]; + auto bgr = cv::Scalar(color & 0xFF, color >> 8 & 0xFF, color >> 16 & 0xFF); + auto corner_points = get_corner(img, obj); + cv::polylines(img, std::vector>{corner_points}, true, bgr, 1); + + auto text = (std::to_string((int)(obj.class_id)) + ":" + to_string_with_precision(obj.conf)); + cv::Size textsize = cv::getTextSize(text, 0, 0.3, 1, nullptr); + + int width = textsize.width; + int height = textsize.height; + bool outside = (corner_points[0].y - height >= 3) ? true : false; + cv::Point p1(corner_points[0].x, corner_points[0].y), p2; + p2.x = corner_points[0].x + width; + if (outside) { + p2.y = corner_points[0].y - height - 3; + } else { + p2.y = corner_points[0].y + height + 3; + } + cv::rectangle(img, p1, p2, bgr, -1, cv::LINE_AA); + cv::putText( + img, text, + cv::Point(corner_points[0].x, (outside ? corner_points[0].y - 2 : corner_points[0].y + height + 2)), + 0, 0.3, cv::Scalar::all(255), 1, cv::LINE_AA); + } + } +} diff --git a/yolov8/src/postprocess.cu b/yolov8/src/postprocess.cu index 3cae0427..2a58fd96 100644 --- a/yolov8/src/postprocess.cu +++ b/yolov8/src/postprocess.cu @@ -1,21 +1,59 @@ // // Created by lindsay on 23-7-17. // -#include "types.h" #include "postprocess.h" +#include "types.h" + +static __global__ void decode_kernel_obb(float* predict, int num_bboxes, float confidence_threshold, float* parray, + int max_objects) { + float count = predict[0]; + int position = (blockDim.x * blockIdx.x + threadIdx.x); + if (position >= count) + return; + + float* pitem = predict + 1 + position * (sizeof(Detection) / sizeof(float)); + int index = atomicAdd(parray, 1); + if (index >= max_objects) + return; + + float confidence = pitem[4]; + + if (confidence < confidence_threshold) + return; + //[center_x center_y w h conf class_id mask[32] keypoints[51] angle] + float cx = pitem[0]; + float cy = pitem[1]; + float width = pitem[2]; + float height = pitem[3]; + float label = pitem[5]; + float angle = pitem[89]; -static __global__ void -decode_kernel(float *predict, int num_bboxes, float confidence_threshold, float *parray, int max_objects) { + float* pout_item = parray + 1 + index * bbox_element; + *pout_item++ = cx; + *pout_item++ = cy; + *pout_item++ = width; + *pout_item++ = height; + *pout_item++ = confidence; + *pout_item++ = label; + *pout_item++ = 1; // 1 = keep, 0 = ignore + *pout_item++ = angle; +} + +static __global__ void decode_kernel(float* predict, int num_bboxes, float confidence_threshold, float* parray, + int max_objects) { float count = predict[0]; int position = (blockDim.x * blockIdx.x + threadIdx.x); - if (position >= count) return; + if (position >= count) + return; - float *pitem = predict + 1 + position * (sizeof(Detection) / sizeof(float)); + float* pitem = predict + 1 + position * (sizeof(Detection) / sizeof(float)); int index = atomicAdd(parray, 1); - if (index >= max_objects) return; + if (index >= max_objects) + return; float confidence = pitem[4]; - if (confidence < confidence_threshold) return; + if (confidence < confidence_threshold) + return; float left = pitem[0]; float top = pitem[1]; @@ -23,7 +61,7 @@ decode_kernel(float *predict, int num_bboxes, float confidence_threshold, float float bottom = pitem[3]; float label = pitem[5]; - float *pout_item = parray + 1 + index * bbox_element; + float* pout_item = parray + 1 + index * bbox_element; *pout_item++ = left; *pout_item++ = top; *pout_item++ = right; @@ -33,35 +71,92 @@ decode_kernel(float *predict, int num_bboxes, float confidence_threshold, float *pout_item++ = 1; // 1 = keep, 0 = ignore } -static __device__ float -box_iou(float aleft, float atop, float aright, float abottom, float bleft, float btop, float bright, float bbottom) { +static __device__ float box_iou(float aleft, float atop, float aright, float abottom, float bleft, float btop, + float bright, float bbottom) { float cleft = max(aleft, bleft); float ctop = max(atop, btop); float cright = min(aright, bright); float cbottom = min(abottom, bbottom); float c_area = max(cright - cleft, 0.0f) * max(cbottom - ctop, 0.0f); - if (c_area == 0.0f) return 0.0f; + if (c_area == 0.0f) + return 0.0f; float a_area = max(0.0f, aright - aleft) * max(0.0f, abottom - atop); float b_area = max(0.0f, bright - bleft) * max(0.0f, bbottom - btop); return c_area / (a_area + b_area - c_area); } -static __global__ void nms_kernel(float *bboxes, int max_objects, float threshold) { +static __global__ void nms_kernel(float* bboxes, int max_objects, float threshold) { + int position = (blockDim.x * blockIdx.x + threadIdx.x); + int count = bboxes[0]; + if (position >= count) + return; + + float* pcurrent = bboxes + 1 + position * bbox_element; + for (int i = 0; i < count; ++i) { + float* pitem = bboxes + 1 + i * bbox_element; + if (i == position || pcurrent[5] != pitem[5]) + continue; + if (pitem[4] >= pcurrent[4]) { + if (pitem[4] == pcurrent[4] && i < position) + continue; + float iou = + box_iou(pcurrent[0], pcurrent[1], pcurrent[2], pcurrent[3], pitem[0], pitem[1], pitem[2], pitem[3]); + if (iou > threshold) { + pcurrent[6] = 0; + return; + } + } + } +} + +static __device__ void convariance_matrix(float w, float h, float r, float& a, float& b, float& c) { + float a_val = w * w / 12.0f; + float b_val = h * h / 12.0f; + float cos_r = cosf(r); + float sin_r = sinf(r); + + a = a_val * cos_r * cos_r + b_val * sin_r * sin_r; + b = a_val * sin_r * sin_r + b_val * cos_r * cos_r; + c = (a_val - b_val) * sin_r * cos_r; +} + +static __device__ float box_probiou(float cx1, float cy1, float w1, float h1, float r1, float cx2, float cy2, float w2, + float h2, float r2, float eps = 1e-7) { + + // Calculate the prob iou between oriented bounding boxes, https://arxiv.org/pdf/2106.06072v1.pdf. + float a1, b1, c1, a2, b2, c2; + convariance_matrix(w1, h1, r1, a1, b1, c1); + convariance_matrix(w2, h2, r2, a2, b2, c2); + + float t1 = ((a1 + a2) * powf(cy1 - cy2, 2) + (b1 + b2) * powf(cx1 - cx2, 2)) / + ((a1 + a2) * (b1 + b2) - powf(c1 + c2, 2) + eps); + float t2 = ((c1 + c2) * (cx2 - cx1) * (cy1 - cy2)) / ((a1 + a2) * (b1 + b2) - powf(c1 + c2, 2) + eps); + float t3 = logf(((a1 + a2) * (b1 + b2) - powf(c1 + c2, 2)) / + (4 * sqrtf(fmaxf(a1 * b1 - c1 * c1, 0.0f)) * sqrtf(fmaxf(a2 * b2 - c2 * c2, 0.0f)) + eps) + + eps); + float bd = 0.25f * t1 + 0.5f * t2 + 0.5f * t3; + bd = fmaxf(fminf(bd, 100.0f), eps); + float hd = sqrtf(1.0f - expf(-bd) + eps); + return 1 - hd; +} + +static __global__ void nms_kernel_obb(float* bboxes, int max_objects, float threshold) { int position = (blockDim.x * blockIdx.x + threadIdx.x); int count = bboxes[0]; - if (position >= count) return; + if (position >= count) + return; - float *pcurrent = bboxes + 1 + position * bbox_element; + float* pcurrent = bboxes + 1 + position * bbox_element; for (int i = 0; i < count; ++i) { - float *pitem = bboxes + 1 + i * bbox_element; - if (i == position || pcurrent[5] != pitem[5]) continue; + float* pitem = bboxes + 1 + i * bbox_element; + if (i == position || pcurrent[5] != pitem[5]) + continue; if (pitem[4] >= pcurrent[4]) { - if (pitem[4] == pcurrent[4] && i < position) continue; - float iou = box_iou( - pcurrent[0], pcurrent[1], pcurrent[2], pcurrent[3], - pitem[0], pitem[1], pitem[2], pitem[3] - ); + if (pitem[4] == pcurrent[4] && i < position) + continue; + float iou = box_probiou(pcurrent[0], pcurrent[1], pcurrent[2], pcurrent[3], pcurrent[7], pitem[0], pitem[1], + pitem[2], pitem[3], pitem[7]); if (iou > threshold) { pcurrent[6] = 0; return; @@ -70,15 +165,29 @@ static __global__ void nms_kernel(float *bboxes, int max_objects, float threshol } } -void cuda_decode(float *predict, int num_bboxes, float confidence_threshold, float *parray, int max_objects, +void cuda_decode(float* predict, int num_bboxes, float confidence_threshold, float* parray, int max_objects, cudaStream_t stream) { int block = 256; int grid = ceil(num_bboxes / (float)block); decode_kernel<<>>((float*)predict, num_bboxes, confidence_threshold, parray, max_objects); } -void cuda_nms(float *parray, float nms_threshold, int max_objects, cudaStream_t stream) { +void cuda_nms(float* parray, float nms_threshold, int max_objects, cudaStream_t stream) { int block = max_objects < 256 ? max_objects : 256; int grid = ceil(max_objects / (float)block); nms_kernel<<>>(parray, max_objects, nms_threshold); } + +void cuda_decode_obb(float* predict, int num_bboxes, float confidence_threshold, float* parray, int max_objects, + cudaStream_t stream) { + int block = 256; + int grid = ceil(num_bboxes / (float)block); + decode_kernel_obb<<>>((float*)predict, num_bboxes, confidence_threshold, parray, + max_objects); +} + +void cuda_nms_obb(float* parray, float nms_threshold, int max_objects, cudaStream_t stream) { + int block = max_objects < 256 ? max_objects : 256; + int grid = ceil(max_objects / (float)block); + nms_kernel_obb<<>>(parray, max_objects, nms_threshold); +} diff --git a/yolov8/src/preprocess.cu b/yolov8/src/preprocess.cu index 14d9e778..d3d6f879 100644 --- a/yolov8/src/preprocess.cu +++ b/yolov8/src/preprocess.cu @@ -1,15 +1,14 @@ -#include "preprocess.h" #include "cuda_utils.h" +#include "preprocess.h" -static uint8_t *img_buffer_host = nullptr; -static uint8_t *img_buffer_device = nullptr; - +static uint8_t* img_buffer_host = nullptr; +static uint8_t* img_buffer_device = nullptr; -__global__ void -warpaffine_kernel(uint8_t *src, int src_line_size, int src_width, int src_height, float *dst, int dst_width, - int dst_height, uint8_t const_value_st, AffineMatrix d2s, int edge) { +__global__ void warpaffine_kernel(uint8_t* src, int src_line_size, int src_width, int src_height, float* dst, + int dst_width, int dst_height, uint8_t const_value_st, AffineMatrix d2s, int edge) { int position = blockDim.x * blockIdx.x + threadIdx.x; - if (position >= edge) return; + if (position >= edge) + return; float m_x1 = d2s.value[0]; float m_y1 = d2s.value[1]; @@ -41,10 +40,10 @@ warpaffine_kernel(uint8_t *src, int src_line_size, int src_width, int src_height float hy = 1 - ly; float hx = 1 - lx; float w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; - uint8_t *v1 = const_value; - uint8_t *v2 = const_value; - uint8_t *v3 = const_value; - uint8_t *v4 = const_value; + uint8_t* v1 = const_value; + uint8_t* v2 = const_value; + uint8_t* v3 = const_value; + uint8_t* v4 = const_value; if (y_low >= 0) { if (x_low >= 0) @@ -79,18 +78,15 @@ warpaffine_kernel(uint8_t *src, int src_line_size, int src_width, int src_height // rgbrgbrgb to rrrgggbbb int area = dst_width * dst_height; - float *pdst_c0 = dst + dy * dst_width + dx; - float *pdst_c1 = pdst_c0 + area; - float *pdst_c2 = pdst_c1 + area; + float* pdst_c0 = dst + dy * dst_width + dx; + float* pdst_c1 = pdst_c0 + area; + float* pdst_c2 = pdst_c1 + area; *pdst_c0 = c0; *pdst_c1 = c1; *pdst_c2 = c2; } - - - -void cuda_preprocess(uint8_t *src, int src_width, int src_height, float *dst, int dst_width, int dst_height, +void cuda_preprocess(uint8_t* src, int src_width, int src_height, float* dst, int dst_width, int dst_height, cudaStream_t stream) { int img_size = src_width * src_height * 3; // copy data to pinned memory @@ -99,7 +95,7 @@ void cuda_preprocess(uint8_t *src, int src_width, int src_height, float *dst, in CUDA_CHECK(cudaMemcpyAsync(img_buffer_device, img_buffer_host, img_size, cudaMemcpyHostToDevice, stream)); AffineMatrix s2d, d2s; - float scale = std::min(dst_height / (float) src_height, dst_width / (float) src_width); + float scale = std::min(dst_height / (float)src_height, dst_width / (float)src_width); s2d.value[0] = scale; s2d.value[1] = 0; @@ -115,16 +111,12 @@ void cuda_preprocess(uint8_t *src, int src_width, int src_height, float *dst, in int jobs = dst_height * dst_width; int threads = 256; - int blocks = ceil(jobs / (float) threads); - warpaffine_kernel<<>>( - img_buffer_device, src_width * 3, src_width, - src_height, dst, dst_width, - dst_height, 128, d2s, jobs); + int blocks = ceil(jobs / (float)threads); + warpaffine_kernel<<>>(img_buffer_device, src_width * 3, src_width, src_height, dst, + dst_width, dst_height, 128, d2s, jobs); } - -void cuda_batch_preprocess(std::vector &img_batch, - float *dst, int dst_width, int dst_height, +void cuda_batch_preprocess(std::vector& img_batch, float* dst, int dst_width, int dst_height, cudaStream_t stream) { int dst_size = dst_width * dst_height * 3; for (size_t i = 0; i < img_batch.size(); i++) { @@ -134,22 +126,14 @@ void cuda_batch_preprocess(std::vector &img_batch, } } - - - - void cuda_preprocess_init(int max_image_size) { // prepare input data in pinned memory - CUDA_CHECK(cudaMallocHost((void **) &img_buffer_host, max_image_size * 3)); + CUDA_CHECK(cudaMallocHost((void**)&img_buffer_host, max_image_size * 3)); // prepare input data in device memory - CUDA_CHECK(cudaMalloc((void **) &img_buffer_device, max_image_size * 3)); + CUDA_CHECK(cudaMalloc((void**)&img_buffer_device, max_image_size * 3)); } void cuda_preprocess_destroy() { CUDA_CHECK(cudaFree(img_buffer_device)); CUDA_CHECK(cudaFreeHost(img_buffer_host)); } - - - - diff --git a/yolov8/yolov8_5u_det_trt.py b/yolov8/yolov8_5u_det_trt.py index 252fe767..5b0b15b3 100644 --- a/yolov8/yolov8_5u_det_trt.py +++ b/yolov8/yolov8_5u_det_trt.py @@ -19,6 +19,7 @@ POSE_NUM = 17 * 3 DET_NUM = 6 SEG_NUM = 32 +OBB_NUM = 1 def get_img_path_batches(batch_size, img_dir): @@ -69,7 +70,7 @@ def plot_one_box(x, img, color=None, label=None, line_thickness=None): [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA, - ) + ) class YoLov8TRT(object): @@ -291,7 +292,7 @@ def post_process(self, output, origin_h, origin_w): result_scores: finally scores, a numpy, each element is the score correspoing to box result_classid: finally classid, a numpy, each element is the classid correspoing to box """ - num_values_per_detection = DET_NUM + SEG_NUM + POSE_NUM + num_values_per_detection = DET_NUM + SEG_NUM + POSE_NUM + OBB_NUM # Get the num of boxes detected num = int(output[0]) # Reshape to a two dimentional ndarray @@ -408,7 +409,7 @@ def run(self): if __name__ == "__main__": # load custom plugin and engine - PLUGIN_LIBRARY = "build/libmyplugins.so" + PLUGIN_LIBRARY = "./build/libmyplugins.so" engine_file_path = "yolov5xu.engine" if len(sys.argv) > 1: @@ -443,7 +444,7 @@ def run(self): try: print('batch size is', yolov8_wrapper.batch_size) - image_dir = "samples/" + image_dir = "images/" image_path_batches = get_img_path_batches(yolov8_wrapper.batch_size, image_dir) for i in range(10): diff --git a/yolov8/yolov8_det_trt.py b/yolov8/yolov8_det_trt.py index 64136390..6f387160 100644 --- a/yolov8/yolov8_det_trt.py +++ b/yolov8/yolov8_det_trt.py @@ -19,6 +19,7 @@ POSE_NUM = 17 * 3 DET_NUM = 6 SEG_NUM = 32 +OBB_NUM = 1 def get_img_path_batches(batch_size, img_dir): @@ -69,7 +70,7 @@ def plot_one_box(x, img, color=None, label=None, line_thickness=None): [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA, - ) + ) class YoLov8TRT(object): @@ -291,7 +292,7 @@ def post_process(self, output, origin_h, origin_w): result_scores: finally scores, a numpy, each element is the score correspoing to box result_classid: finally classid, a numpy, each element is the classid correspoing to box """ - num_values_per_detection = DET_NUM + SEG_NUM + POSE_NUM + num_values_per_detection = DET_NUM + SEG_NUM + POSE_NUM + OBB_NUM # Get the num of boxes detected num = int(output[0]) # Reshape to a two dimentional ndarray @@ -408,8 +409,8 @@ def run(self): if __name__ == "__main__": # load custom plugin and engine - PLUGIN_LIBRARY = "build/libmyplugins.so" - engine_file_path = "yolov8s.engine" + PLUGIN_LIBRARY = "./build/libmyplugins.so" + engine_file_path = "yolov8n.engine" if len(sys.argv) > 1: engine_file_path = sys.argv[1] diff --git a/yolov8/yolov8_obb.cpp b/yolov8/yolov8_obb.cpp new file mode 100644 index 00000000..d7bcf9c7 --- /dev/null +++ b/yolov8/yolov8_obb.cpp @@ -0,0 +1,276 @@ + +#include +#include +#include +#include "cuda_utils.h" +#include "logging.h" +#include "model.h" +#include "postprocess.h" +#include "preprocess.h" +#include "utils.h" + +Logger gLogger; +using namespace nvinfer1; +const int kOutputSize = kMaxNumOutputBbox * sizeof(Detection) / sizeof(float) + 1; + +void serialize_engine(std::string& wts_name, std::string& engine_name, int& is_p, std::string& sub_type, float& gd, + float& gw, int& max_channels) { + IBuilder* builder = createInferBuilder(gLogger); + IBuilderConfig* config = builder->createBuilderConfig(); + IHostMemory* serialized_engine = nullptr; + + if (is_p == 6) { + std::cout << "p6 is not supported right now" << std::endl; + } else if (is_p == 2) { + std::cout << "p2 is not supported right now" << std::endl; + } else { + serialized_engine = buildEngineYolov8Obb(builder, config, DataType::kFLOAT, wts_name, gd, gw, max_channels); + } + + assert(serialized_engine); + std::ofstream p(engine_name, std::ios::binary); + if (!p) { + std::cout << "could not open plan output file" << std::endl; + assert(false); + } + p.write(reinterpret_cast(serialized_engine->data()), serialized_engine->size()); + + delete serialized_engine; + delete config; + delete builder; +} + +void deserialize_engine(std::string& engine_name, IRuntime** runtime, ICudaEngine** engine, + IExecutionContext** context) { + std::ifstream file(engine_name, std::ios::binary); + if (!file.good()) { + std::cerr << "read " << engine_name << " error!" << std::endl; + assert(false); + } + size_t size = 0; + file.seekg(0, file.end); + size = file.tellg(); + file.seekg(0, file.beg); + char* serialized_engine = new char[size]; + assert(serialized_engine); + file.read(serialized_engine, size); + file.close(); + + *runtime = createInferRuntime(gLogger); + assert(*runtime); + *engine = (*runtime)->deserializeCudaEngine(serialized_engine, size); + assert(*engine); + *context = (*engine)->createExecutionContext(); + assert(*context); + delete[] serialized_engine; +} + +void prepare_buffer(ICudaEngine* engine, float** input_buffer_device, float** output_buffer_device, + float** output_buffer_host, float** decode_ptr_host, float** decode_ptr_device, + std::string cuda_post_process) { + assert(engine->getNbBindings() == 2); + // In order to bind the buffers, we need to know the names of the input and output tensors. + // Note that indices are guaranteed to be less than IEngine::getNbBindings() + const int inputIndex = engine->getBindingIndex(kInputTensorName); + const int outputIndex = engine->getBindingIndex(kOutputTensorName); + assert(inputIndex == 0); + assert(outputIndex == 1); + // Create GPU buffers on device + CUDA_CHECK(cudaMalloc((void**)input_buffer_device, kBatchSize * 3 * kInputH * kInputW * sizeof(float))); + CUDA_CHECK(cudaMalloc((void**)output_buffer_device, kBatchSize * kOutputSize * sizeof(float))); + if (cuda_post_process == "c") { + *output_buffer_host = new float[kBatchSize * kOutputSize]; + } else if (cuda_post_process == "g") { + if (kBatchSize > 1) { + std::cerr << "Do not yet support GPU post processing for multiple batches" << std::endl; + exit(0); + } + // Allocate memory for decode_ptr_host and copy to device + *decode_ptr_host = new float[1 + kMaxNumOutputBbox * bbox_element]; + CUDA_CHECK(cudaMalloc((void**)decode_ptr_device, sizeof(float) * (1 + kMaxNumOutputBbox * bbox_element))); + } +} + +void infer(IExecutionContext& context, cudaStream_t& stream, void** buffers, float* output, int batchsize, + float* decode_ptr_host, float* decode_ptr_device, int model_bboxes, std::string cuda_post_process) { + // infer on the batch asynchronously, and DMA output back to host + auto start = std::chrono::system_clock::now(); + context.enqueue(batchsize, buffers, stream, nullptr); + if (cuda_post_process == "c") { + CUDA_CHECK(cudaMemcpyAsync(output, buffers[1], batchsize * kOutputSize * sizeof(float), cudaMemcpyDeviceToHost, + stream)); + auto end = std::chrono::system_clock::now(); + std::cout << "inference time: " << std::chrono::duration_cast(end - start).count() + << "ms" << std::endl; + } else if (cuda_post_process == "g") { + CUDA_CHECK( + cudaMemsetAsync(decode_ptr_device, 0, sizeof(float) * (1 + kMaxNumOutputBbox * bbox_element), stream)); + cuda_decode_obb((float*)buffers[1], model_bboxes, kConfThresh, decode_ptr_device, kMaxNumOutputBbox, stream); + cuda_nms_obb(decode_ptr_device, kNmsThresh, kMaxNumOutputBbox, stream); //cuda nms + CUDA_CHECK(cudaMemcpyAsync(decode_ptr_host, decode_ptr_device, + sizeof(float) * (1 + kMaxNumOutputBbox * bbox_element), cudaMemcpyDeviceToHost, + stream)); + auto end = std::chrono::system_clock::now(); + std::cout << "inference and gpu postprocess time: " + << std::chrono::duration_cast(end - start).count() << "ms" << std::endl; + } + + CUDA_CHECK(cudaStreamSynchronize(stream)); +} + +bool parse_args(int argc, char** argv, std::string& wts, std::string& engine, int& is_p, std::string& img_dir, + std::string& sub_type, std::string& cuda_post_process, float& gd, float& gw, int& max_channels) { + if (argc < 4) + return false; + if (std::string(argv[1]) == "-s" && (argc == 5 || argc == 7)) { + wts = std::string(argv[2]); + engine = std::string(argv[3]); + auto sub_type = std::string(argv[4]); + + if (sub_type[0] == 'n') { + gd = 0.33; + gw = 0.25; + max_channels = 1024; + } else if (sub_type[0] == 's') { + gd = 0.33; + gw = 0.50; + max_channels = 1024; + } else if (sub_type[0] == 'm') { + gd = 0.67; + gw = 0.75; + max_channels = 576; + } else if (sub_type[0] == 'l') { + gd = 1.0; + gw = 1.0; + max_channels = 512; + } else if (sub_type[0] == 'x') { + gd = 1.0; + gw = 1.25; + max_channels = 640; + } else { + return false; + } + if (sub_type.size() == 2 && sub_type[1] == '6') { + is_p = 6; + } else if (sub_type.size() == 2 && sub_type[1] == '2') { + is_p = 2; + } + } else if (std::string(argv[1]) == "-d" && argc == 5) { + engine = std::string(argv[2]); + img_dir = std::string(argv[3]); + cuda_post_process = std::string(argv[4]); + } else { + return false; + } + return true; +} + +int main(int argc, char** argv) { + cudaSetDevice(kGpuId); + std::string wts_name = ""; + std::string engine_name = ""; + std::string img_dir; + std::string sub_type = ""; + std::string cuda_post_process = ""; + int model_bboxes; + int is_p = 0; + float gd = 0.0f, gw = 0.0f; + int max_channels = 0; + + if (!parse_args(argc, argv, wts_name, engine_name, is_p, img_dir, sub_type, cuda_post_process, gd, gw, + max_channels)) { + std::cerr << "Arguments not right!" << std::endl; + std::cerr << "./yolov8 -s [.wts] [.engine] [n/s/m/l/x/n2/s2/m2/l2/x2/n6/s6/m6/l6/x6] // serialize model to " + "plan file" + << std::endl; + std::cerr << "./yolov8 -d [.engine] ../samples [c/g]// deserialize plan file and run inference" << std::endl; + return -1; + } + + // Create a model using the API directly and serialize it to a file + if (!wts_name.empty()) { + serialize_engine(wts_name, engine_name, is_p, sub_type, gd, gw, max_channels); + return 0; + } + + // Deserialize the engine from file + IRuntime* runtime = nullptr; + ICudaEngine* engine = nullptr; + IExecutionContext* context = nullptr; + deserialize_engine(engine_name, &runtime, &engine, &context); + cudaStream_t stream; + CUDA_CHECK(cudaStreamCreate(&stream)); + cuda_preprocess_init(kMaxInputImageSize); + auto out_dims = engine->getBindingDimensions(1); + model_bboxes = out_dims.d[0]; + // Prepare cpu and gpu buffers + float* device_buffers[2]; + float* output_buffer_host = nullptr; + float* decode_ptr_host = nullptr; + float* decode_ptr_device = nullptr; + + // Read images from directory + std::vector file_names; + if (read_files_in_dir(img_dir.c_str(), file_names) < 0) { + std::cerr << "read_files_in_dir failed." << std::endl; + return -1; + } + + prepare_buffer(engine, &device_buffers[0], &device_buffers[1], &output_buffer_host, &decode_ptr_host, + &decode_ptr_device, cuda_post_process); + + // batch predict + for (size_t i = 0; i < file_names.size(); i += kBatchSize) { + // Get a batch of images + std::vector img_batch; + std::vector img_name_batch; + for (size_t j = i; j < i + kBatchSize && j < file_names.size(); j++) { + cv::Mat img = cv::imread(img_dir + "/" + file_names[j]); + img_batch.push_back(img); + img_name_batch.push_back(file_names[j]); + } + // Preprocess + cuda_batch_preprocess(img_batch, device_buffers[0], kInputW, kInputH, stream); + // Run inference + infer(*context, stream, (void**)device_buffers, output_buffer_host, kBatchSize, decode_ptr_host, + decode_ptr_device, model_bboxes, cuda_post_process); + std::vector> res_batch; + if (cuda_post_process == "c") { + // NMS + batch_nms_obb(res_batch, output_buffer_host, img_batch.size(), kOutputSize, kConfThresh, kNmsThresh); + } else if (cuda_post_process == "g") { + //Process gpu decode and nms results + batch_process_obb(res_batch, decode_ptr_host, img_batch.size(), bbox_element, img_batch); + } + // Draw bounding boxes + draw_bbox_obb(img_batch, res_batch); + // Save images + for (size_t j = 0; j < img_batch.size(); j++) { + cv::imwrite("_" + img_name_batch[j], img_batch[j]); + } + } + + // Release stream and buffers + cudaStreamDestroy(stream); + CUDA_CHECK(cudaFree(device_buffers[0])); + CUDA_CHECK(cudaFree(device_buffers[1])); + CUDA_CHECK(cudaFree(decode_ptr_device)); + delete[] decode_ptr_host; + delete[] output_buffer_host; + cuda_preprocess_destroy(); + // Destroy the engine + delete context; + delete engine; + delete runtime; + + // Print histogram of the output distribution + //std::cout << "\nOutput:\n\n"; + //for (unsigned int i = 0; i < kOutputSize; i++) + //{ + // std::cout << prob[i] << ", "; + // if (i % 10 == 0) std::cout << std::endl; + //} + //std::cout << std::endl; + + return 0; +} diff --git a/yolov8/yolov8_obb_trt.py b/yolov8/yolov8_obb_trt.py new file mode 100644 index 00000000..291a6b04 --- /dev/null +++ b/yolov8/yolov8_obb_trt.py @@ -0,0 +1,571 @@ +""" +An example that uses TensorRT's Python api to make inferences. +""" +import ctypes +import os +import shutil +import sys +import threading +import time +import cv2 +import math +import numpy as np +import pycuda.autoinit # noqa: F401 +import pycuda.driver as cuda +import tensorrt as trt + +CONF_THRESH = 0.5 +IOU_THRESHOLD = 0.4 +POSE_NUM = 17 * 3 +DET_NUM = 6 +SEG_NUM = 32 +OBB_NUM = 1 + + +def get_img_path_batches(batch_size, img_dir): + ret = [] + batch = [] + for root, dirs, files in os.walk(img_dir): + for name in files: + if len(batch) == batch_size: + ret.append(batch) + batch = [] + batch.append(os.path.join(root, name)) + if len(batch) > 0: + ret.append(batch) + return ret + + +def regularize_rboxes(rboxes): + """ + Regularize rotated boxes in range [0, pi/2]. + + Args: + rboxes (numpy.ndarray): Input boxes of shape(N, 5) in xywhr format. + + Returns: + (numpy.ndarray): The regularized boxes. + """ + x, y, w, h, t = np.split(rboxes, 5, axis=-1) + w_ = np.where(w > h, w, h) + h_ = np.where(w > h, h, w) + t = np.where(w > h, t, t + math.pi / 2) % math.pi + return np.concatenate([x, y, w_, h_, t], axis=-1) # regularized boxes + + +def xywhr2xyxyxyxy(x): + """ + Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4]. + + Args: + x (numpy.ndarray): Boxes in [cx, cy, w, h, rotation] format of shape (n, 5) or (b, n, 5). + + Returns: + (numpy.ndarray): Converted corner points of shape (n, 4, 2) or (b, n, 4, 2). + """ + # Regularize the input boxes first + rboxes = regularize_rboxes(x) + + ctr = rboxes[..., :2] + w, h, angle = (rboxes[..., i: i + 1] for i in range(2, 5)) + + cos_value = np.cos(angle) + sin_value = np.sin(angle) + + vec1 = np.concatenate([w / 2 * cos_value, w / 2 * sin_value], axis=-1) + vec2 = np.concatenate([-h / 2 * sin_value, h / 2 * cos_value], axis=-1) + + pt1 = ctr + vec1 + vec2 + pt2 = ctr + vec1 - vec2 + pt3 = ctr - vec1 - vec2 + pt4 = ctr - vec1 + vec2 + + return np.stack([pt1, pt2, pt3, pt4], axis=-2) + + +def plot_one_box(x, img, color=None, label=None, line_thickness=None): + """ + description: Plots one bounding box on image img, + this function comes from YoLov8 project. + param: + x: a box likes [x1,y1,x2,y2,angle] + img: a opencv image object + color: color to draw rectangle, such as (0,255,0) + label: str + line_thickness: int + return: + no return + + """ + tl = ( + line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 + ) # line/font thickness + box = xywhr2xyxyxyxy(x).reshape(-1, 4, 2).squeeze() + p1 = [int(b) for b in box[0]] + # NOTE: cv2-version polylines needs np.asarray type. + cv2.polylines(img, [np.asarray(box, dtype=int)], True, color, thickness=tl, lineType=cv2.LINE_AA) + if label: + tf = max(tl - 1, 1) # font thickness + w, h = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0] # text width, height + outside = p1[1] - h >= 3 + p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3 + cv2.rectangle(img, p1, p2, color, -1, cv2.LINE_AA) # filled + cv2.putText( + img, + label, + (p1[0], p1[1] - 2 if outside else p1[1] + h + 2), + 0, + tl / 3, + [225, 255, 255], + thickness=tf, + lineType=cv2.LINE_AA, + ) + + +class YoLov8TRT(object): + """ + description: A YOLOv8 class that warps TensorRT ops, preprocess and postprocess ops. + """ + + def __init__(self, engine_file_path): + # Create a Context on this device, + self.ctx = cuda.Device(0).make_context() + stream = cuda.Stream() + TRT_LOGGER = trt.Logger(trt.Logger.INFO) + runtime = trt.Runtime(TRT_LOGGER) + + # Deserialize the engine from file + with open(engine_file_path, "rb") as f: + engine = runtime.deserialize_cuda_engine(f.read()) + context = engine.create_execution_context() + + host_inputs = [] + cuda_inputs = [] + host_outputs = [] + cuda_outputs = [] + bindings = [] + + for binding in engine: + print('bingding:', binding, engine.get_binding_shape(binding)) + size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size + dtype = trt.nptype(engine.get_binding_dtype(binding)) + # Allocate host and device buffers + host_mem = cuda.pagelocked_empty(size, dtype) + cuda_mem = cuda.mem_alloc(host_mem.nbytes) + # Append the device buffer to device bindings. + bindings.append(int(cuda_mem)) + # Append to the appropriate list. + if engine.binding_is_input(binding): + self.input_w = engine.get_binding_shape(binding)[-1] + self.input_h = engine.get_binding_shape(binding)[-2] + host_inputs.append(host_mem) + cuda_inputs.append(cuda_mem) + else: + host_outputs.append(host_mem) + cuda_outputs.append(cuda_mem) + + # Store + self.stream = stream + self.context = context + self.engine = engine + self.host_inputs = host_inputs + self.cuda_inputs = cuda_inputs + self.host_outputs = host_outputs + self.cuda_outputs = cuda_outputs + self.bindings = bindings + self.batch_size = engine.max_batch_size + self.det_output_length = host_outputs[0].shape[0] + + def infer(self, raw_image_generator): + threading.Thread.__init__(self) + # Make self the active context, pushing it on top of the context stack. + self.ctx.push() + # Restore + stream = self.stream + context = self.context + host_inputs = self.host_inputs + cuda_inputs = self.cuda_inputs + host_outputs = self.host_outputs + cuda_outputs = self.cuda_outputs + bindings = self.bindings + # Do image preprocess + batch_image_raw = [] + batch_origin_h = [] + batch_origin_w = [] + batch_input_image = np.empty(shape=[self.batch_size, 3, self.input_h, self.input_w]) + for i, image_raw in enumerate(raw_image_generator): + input_image, image_raw, origin_h, origin_w = self.preprocess_image(image_raw) + batch_image_raw.append(image_raw) + batch_origin_h.append(origin_h) + batch_origin_w.append(origin_w) + np.copyto(batch_input_image[i], input_image) + batch_input_image = np.ascontiguousarray(batch_input_image) + + # Copy input image to host buffer + np.copyto(host_inputs[0], batch_input_image.ravel()) + start = time.time() + # Transfer input data to the GPU. + cuda.memcpy_htod_async(cuda_inputs[0], host_inputs[0], stream) + # Run inference. + context.execute_async(batch_size=self.batch_size, bindings=bindings, stream_handle=stream.handle) + # Transfer predictions back from the GPU. + cuda.memcpy_dtoh_async(host_outputs[0], cuda_outputs[0], stream) + # Synchronize the stream + stream.synchronize() + end = time.time() + # Remove any context from the top of the context stack, deactivating it. + self.ctx.pop() + # Here we use the first row of output in that batch_size = 1 + output = host_outputs[0] + # Do postprocess + for i in range(self.batch_size): + result_boxes, result_scores, result_classid = self.post_process( + output[i * self.det_output_length: (i + 1) * self.det_output_length], batch_origin_h[i], + batch_origin_w[i] + ) + # Draw rectangles and labels on the original image + for j in range(len(result_boxes)): + box = result_boxes[j] + np.random.seed(int(result_classid[j])) + color = [np.random.randint(0, 255) for _ in range(3)] + plot_one_box( + box, + batch_image_raw[i], + label="{}:{:.2f}".format( + categories[int(result_classid[j])], result_scores[j] + ), + color=color, + line_thickness=1 + ) + return batch_image_raw, end - start + + def destroy(self): + # Remove any context from the top of the context stack, deactivating it. + self.ctx.pop() + + def get_raw_image(self, image_path_batch): + """ + description: Read an image from image path + """ + for img_path in image_path_batch: + yield cv2.imread(img_path) + + def get_raw_image_zeros(self, image_path_batch=None): + """ + description: Ready data for warmup + """ + for _ in range(self.batch_size): + yield np.zeros([self.input_h, self.input_w, 3], dtype=np.uint8) + + def preprocess_image(self, raw_bgr_image): + """ + description: Convert BGR image to RGB, + resize and pad it to target size, normalize to [0,1], + transform to NCHW format. + param: + input_image_path: str, image path + return: + image: the processed image + image_raw: the original image + h: original height + w: original width + """ + image_raw = raw_bgr_image + h, w, c = image_raw.shape + image = cv2.cvtColor(image_raw, cv2.COLOR_BGR2RGB) + # Calculate widht and height and paddings + r_w = self.input_w / w + r_h = self.input_h / h + if r_h > r_w: + tw = self.input_w + th = int(r_w * h) + tx1 = tx2 = 0 + ty1 = int((self.input_h - th) / 2) + ty2 = self.input_h - th - ty1 + else: + tw = int(r_h * w) + th = self.input_h + tx1 = int((self.input_w - tw) / 2) + tx2 = self.input_w - tw - tx1 + ty1 = ty2 = 0 + # Resize the image with long side while maintaining ratio + image = cv2.resize(image, (tw, th)) + # Pad the short side with (128,128,128) + image = cv2.copyMakeBorder( + image, ty1, ty2, tx1, tx2, cv2.BORDER_CONSTANT, None, (128, 128, 128) + ) + image = image.astype(np.float32) + # Normalize to [0,1] + image /= 255.0 + # HWC to CHW format: + image = np.transpose(image, [2, 0, 1]) + # CHW to NCHW format + image = np.expand_dims(image, axis=0) + # Convert the image to row-major order, also known as "C order": + image = np.ascontiguousarray(image) + return image, image_raw, h, w + + def xywh2xyxy(self, origin_h, origin_w, x): + """ + description: Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right + param: + origin_h: height of original image + origin_w: width of original image + x: A boxes numpy, each row is a box [center_x, center_y, w, h] + return: + y: A boxes numpy, each row is a box [x1, y1, x2, y2] + """ + y = np.zeros_like(x) + r_w = self.input_w / origin_w + r_h = self.input_h / origin_h + if r_h > r_w: + y[:, 0] = x[:, 0] + y[:, 2] = x[:, 2] + y[:, 1] = x[:, 1] - (self.input_h - r_w * origin_h) / 2 + y[:, 3] = x[:, 3] - (self.input_h - r_w * origin_h) / 2 + y /= r_w + else: + y[:, 0] = x[:, 0] - (self.input_w - r_h * origin_w) / 2 + y[:, 2] = x[:, 2] - (self.input_w - r_h * origin_w) / 2 + y[:, 1] = x[:, 1] + y[:, 3] = x[:, 3] + y /= r_h + + return y + + def post_process(self, output, origin_h, origin_w): + """ + description: postprocess the prediction + param: + output: A numpy likes [num_boxes,cx,cy,w,h,conf,cls_id,angle cx,cy,w,h,conf,cls_id,angle ...] + origin_h: height of original image + origin_w: width of original image + return: + result_boxes: finally boxes, a boxes numpy, each row is a box [x1, y1, x2, y2, angle] + result_scores: finally scores, a numpy, each element is the score correspoing to box + result_classid: finally classid, a numpy, each element is the classid correspoing to box + """ + num_values_per_detection = DET_NUM + SEG_NUM + POSE_NUM + OBB_NUM + # Get the num of boxes detected + num = int(output[0]) + # Reshape to a two dimentional ndarray + # pred = np.reshape(output[1:], (-1, 38))[:num, :] + pred = np.reshape(output[1:], (-1, num_values_per_detection))[:num, :] + # Do nms + boxes = self.non_max_suppression(pred, origin_h, origin_w, + conf_thres=CONF_THRESH, nms_thres=IOU_THRESHOLD) + + columns_to_keep = [0, 1, 2, 3, 89] + result_boxes = boxes[:, columns_to_keep] if len(boxes) else np.array([]) + result_scores = boxes[:, 4] if len(boxes) else np.array([]) + result_classid = boxes[:, 5] if len(boxes) else np.array([]) + return result_boxes, result_scores, result_classid + + def covariance_matrix(self, boxes): + """ + description: Generating covariance matrix from obbs. + param: + boxes (np.ndarray): A numpy of shape (N, 5) representing rotated bounding boxes, with xywhr format. + + return: + (np.ndarray): Covariance metrixs corresponding to original rotated bounding boxes. + """ + # Gaussian bounding boxes, ignore the center points (the first two columns) because they are not needed here. + widths = boxes[:, 2:3].reshape(-1) + heights = boxes[:, 3:4].reshape(-1) + angles = boxes[:, 4].reshape(-1) + + a, b, c = (widths ** 2) / 12, (heights ** 2) / 12, angles + + cos_angles = np.cos(c) + sin_angles = np.sin(c) + + cos2 = cos_angles ** 2 + sin2 = sin_angles ** 2 + + return a * cos2 + b * sin2, a * sin2 + b * cos2, (a - b) * cos_angles * sin_angles + + def bbox_iou(self, box1, box2, x1y1x2y2=True): + """ + description: compute the IoU of two bounding boxes + param: + box1: A box coordinate (can be (x1, y1, x2, y2) or (x, y, w, h)) + box2: A box coordinate (can be (x1, y1, x2, y2) or (x, y, w, h)) + x1y1x2y2: select the coordinate format + return: + iou: computed iou + """ + if not x1y1x2y2: + # Transform from center and width to exact coordinates + b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2 + b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2 + b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2 + b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2 + else: + # Get the coordinates of bounding boxes + b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3] + b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3] + + # Get the coordinates of the intersection rectangle + inter_rect_x1 = np.maximum(b1_x1, b2_x1) + inter_rect_y1 = np.maximum(b1_y1, b2_y1) + inter_rect_x2 = np.minimum(b1_x2, b2_x2) + inter_rect_y2 = np.minimum(b1_y2, b2_y2) + # Intersection area + inter_area = (np.clip(inter_rect_x2 - inter_rect_x1 + 1, 0, None) + * np.clip(inter_rect_y2 - inter_rect_y1 + 1, 0, None)) + # Union Area + b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1) + b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1) + + iou = inter_area / (b1_area + b2_area - inter_area + 1e-16) + + return iou + + def batch_probiou(self, obb1, obb2, eps=1e-7): + """ + description: Calculate the prob IoU between oriented bounding boxes, https://arxiv.org/pdf/2106.06072v1.pdf. + param: + obb1 (np.ndarray): A numpy of shape (N, 5) representing ground truth obbs, with xywhr format. + obb2 (np.ndarray): A numpy of shape (M, 5) representing predicted obbs, with xywhr format. + eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7. + return: + iou: computed iou + """ + x1, y1 = obb1[:, 0], obb1[:, 1] + x2, y2 = obb2[:, 0], obb2[:, 1] + + a1, b1, c1 = self.covariance_matrix(obb1) + a2, b2, c2 = self.covariance_matrix(obb2) + + t1 = ( + ((a1 + a2) * (y1 - y2) ** 2 + (b1 + b2) * (x1 - x2) ** 2) / + ((a1 + a2) * (b1 + b2) - (c1 + c2) ** 2 + eps) + ) * 0.25 + + t2 = ( + ((c1 + c2) * (x2 - x1) * (y1 - y2)) / + ((a1 + a2) * (b1 + b2) - (c1 + c2) ** 2 + eps) + ) * 0.5 + + t3 = ( + ((a1 + a2) * (b1 + b2) - (c1 + c2) ** 2) / + (4 * (np.clip(a1 * b1 - c1 ** 2, 0, None) * np.clip(a2 * b2 - c2 ** 2, 0, None)) ** 0.5 + eps) + + eps + ) + t3 = np.log(t3) * 0.5 + + bd = np.clip(t1 + t2 + t3, eps, 100.0) + hd = np.sqrt(1.0 - np.exp(-bd) + eps) + return 1 - hd + + def non_max_suppression(self, prediction, origin_h, origin_w, conf_thres=0.5, nms_thres=0.4): + """ + description: Removes detections with lower object confidence score than 'conf_thres' and performs + Non-Maximum Suppression to further filter detections. + param: + prediction: detections, (x1, y1, x2, y2, conf, cls_id, angle) + origin_h: original image height + origin_w: original image width + conf_thres: a confidence threshold to filter detections + nms_thres: a iou threshold to filter detections + return: + boxes: output after nms with the shape (x1, y1, x2, y2, conf, cls_id, angle) + """ + # Get the boxes that score > CONF_THRESH + boxes = prediction[prediction[:, 4] >= conf_thres] + col_idx = [0, 1, 2, 3, 89] + # Trandform bbox from [center_x, center_y, w, h] to [x1, y1, x2, y2] + boxes[:, :4] = self.xywh2xyxy(origin_h, origin_w, boxes[:, :4]) + # clip the coordinates + boxes[:, 0] = np.clip(boxes[:, 0], 0, origin_w - 1) + boxes[:, 2] = np.clip(boxes[:, 2], 0, origin_w - 1) + boxes[:, 1] = np.clip(boxes[:, 1], 0, origin_h - 1) + boxes[:, 3] = np.clip(boxes[:, 3], 0, origin_h - 1) + # Object confidence + confs = boxes[:, 4] + # Sort by the confs + boxes = boxes[np.argsort(-confs)] + # Perform non-maximum suppression + keep_boxes = [] + while boxes.shape[0]: + large_overlap = self.batch_probiou(np.expand_dims(boxes[0, col_idx], 0), boxes[:, col_idx]) > nms_thres + label_match = boxes[0, 5] == boxes[:, 5] + # Indices of boxes with lower confidence scores, large IOUs and matching labels + invalid = large_overlap & label_match + keep_boxes += [boxes[0]] + boxes = boxes[~invalid] + boxes = np.stack(keep_boxes, 0) if len(keep_boxes) else np.array([]) + + return boxes + + +class inferThread(threading.Thread): + def __init__(self, yolov8_wrapper, image_path_batch): + threading.Thread.__init__(self) + self.yolov8_wrapper = yolov8_wrapper + self.image_path_batch = image_path_batch + + def run(self): + batch_image_raw, use_time = self.yolov8_wrapper.infer(self.yolov8_wrapper.get_raw_image(self.image_path_batch)) + for i, img_path in enumerate(self.image_path_batch): + parent, filename = os.path.split(img_path) + save_name = os.path.join('output', filename) + # Save image + cv2.imwrite(save_name, batch_image_raw[i]) + print('input->{}, time->{:.2f}ms, saving into output/'.format(self.image_path_batch, use_time * 1000)) + + +class warmUpThread(threading.Thread): + def __init__(self, yolov8_wrapper): + threading.Thread.__init__(self) + self.yolov8_wrapper = yolov8_wrapper + + def run(self): + batch_image_raw, use_time = self.yolov8_wrapper.infer(self.yolov8_wrapper.get_raw_image_zeros()) + print('warm_up->{}, time->{:.2f}ms'.format(batch_image_raw[0].shape, use_time * 1000)) + + +if __name__ == "__main__": + # load custom plugin and engine + PLUGIN_LIBRARY = "./build/libmyplugins.so" + engine_file_path = "yolov8n-obb.engine" + + if len(sys.argv) > 1: + engine_file_path = sys.argv[1] + if len(sys.argv) > 2: + PLUGIN_LIBRARY = sys.argv[2] + + ctypes.CDLL(PLUGIN_LIBRARY) + + # load DOTAV 1.5 labels + + categories = ["plane", "ship", "storage tank", "baseball diamond", "tennis court", + "basketball court", "ground track field", "harbor", + "bridge", "large vehicle", "small vehicle", "helicopter", + "roundabout", "soccer ball field", "swimming pool", "container crane"] + + if os.path.exists('output/'): + shutil.rmtree('output/') + os.makedirs('output/') + # a YoLov8TRT instance + yolov8_wrapper = YoLov8TRT(engine_file_path) + try: + print('batch size is', yolov8_wrapper.batch_size) + + image_dir = "images/" + image_path_batches = get_img_path_batches(yolov8_wrapper.batch_size, image_dir) + + for i in range(10): + # create a new thread to do warm_up + thread1 = warmUpThread(yolov8_wrapper) + thread1.start() + thread1.join() + for batch in image_path_batches: + # create a new thread to do inference + thread1 = inferThread(yolov8_wrapper, batch) + thread1.start() + thread1.join() + finally: + # destroy the instance + yolov8_wrapper.destroy() diff --git a/yolov8/yolov8_pose.cpp b/yolov8/yolov8_pose.cpp index 84d35aea..6ee4a0cd 100644 --- a/yolov8/yolov8_pose.cpp +++ b/yolov8/yolov8_pose.cpp @@ -11,7 +11,7 @@ Logger gLogger; using namespace nvinfer1; -const int kOutputSize = kMaxNumOutputBbox * (sizeof(Detection) - sizeof(float) * 32) / sizeof(float) + 1; +const int kOutputSize = kMaxNumOutputBbox * sizeof(Detection) / sizeof(float) + 1; void serialize_engine(std::string& wts_name, std::string& engine_name, int& is_p, std::string& sub_type, float& gd, float& gw, int& max_channels) { diff --git a/yolov8/yolov8_pose_trt.py b/yolov8/yolov8_pose_trt.py index 80f11e7b..3cf36e87 100644 --- a/yolov8/yolov8_pose_trt.py +++ b/yolov8/yolov8_pose_trt.py @@ -19,6 +19,7 @@ POSE_NUM = 17 * 3 DET_NUM = 6 SEG_NUM = 32 +OBB_NUM = 1 keypoint_pairs = [ (0, 1), (0, 2), (0, 5), (0, 6), (1, 2), (1, 3), (2, 4), (5, 6), (5, 7), (5, 11), @@ -75,7 +76,7 @@ def plot_one_box(x, img, color=None, label=None, line_thickness=None): [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA, - ) + ) class YoLov8TRT(object): @@ -329,8 +330,8 @@ def post_process(self, output, origin_h, origin_w): result_keypoints: Final keypoints, a list of numpy arrays, each element represents keypoints for a box, shaped as (#keypoints, 3) """ - # Number of values per detection: 38 base values + 17 keypoints * 3 values each - num_values_per_detection = DET_NUM + SEG_NUM + POSE_NUM + # Number of values per detection: 38 base values + 17 keypoints * 3 values each + angle + num_values_per_detection = DET_NUM + SEG_NUM + POSE_NUM + OBB_NUM # Get the number of boxes detected num = int(output[0]) # Reshape to a two-dimensional ndarray with the full detection shape @@ -345,7 +346,7 @@ def post_process(self, output, origin_h, origin_w): result_boxes = boxes[:, :4] if len(boxes) else np.array([]) result_scores = boxes[:, 4] if len(boxes) else np.array([]) result_classid = boxes[:, 5] if len(boxes) else np.array([]) - result_keypoints = boxes[:, -POSE_NUM:] if len(boxes) else np.array([]) + result_keypoints = boxes[:, -POSE_NUM-1:-1] if len(boxes) else np.array([]) # Return the post-processed results including keypoints return result_boxes, result_scores, result_classid, result_keypoints @@ -405,11 +406,11 @@ def non_max_suppression(self, prediction, origin_h, origin_w, conf_thres=0.5, nm # Trandform bbox from [center_x, center_y, w, h] to [x1, y1, x2, y2] res_array = np.copy(boxes) box_pred_deep_copy = np.copy(boxes[:, :4]) - keypoints_pred_deep_copy = np.copy(boxes[:, -POSE_NUM:]) + keypoints_pred_deep_copy = np.copy(boxes[:, -POSE_NUM-1:-1]) res_box, res_keypoints = self.xywh2xyxy_with_keypoints( origin_h, origin_w, box_pred_deep_copy, keypoints_pred_deep_copy) res_array[:, :4] = res_box - res_array[:, -POSE_NUM:] = res_keypoints + res_array[:, -POSE_NUM-1:-1] = res_keypoints # clip the coordinates res_array[:, 0] = np.clip(res_array[:, 0], 0, origin_w - 1) res_array[:, 2] = np.clip(res_array[:, 2], 0, origin_w - 1) diff --git a/yolov8/yolov8_seg.cpp b/yolov8/yolov8_seg.cpp index 415d4e74..b1541c76 100644 --- a/yolov8/yolov8_seg.cpp +++ b/yolov8/yolov8_seg.cpp @@ -11,7 +11,7 @@ Logger gLogger; using namespace nvinfer1; -const int kOutputSize = kMaxNumOutputBbox * (sizeof(Detection) - sizeof(float) * 51) / sizeof(float) + 1; +const int kOutputSize = kMaxNumOutputBbox * sizeof(Detection) / sizeof(float) + 1; const static int kOutputSegSize = 32 * (kInputH / 4) * (kInputW / 4); static cv::Rect get_downscale_rect(float bbox[4], float scale) { diff --git a/yolov8/yolov8_seg_trt.py b/yolov8/yolov8_seg_trt.py index b31dd780..3dfd65e3 100644 --- a/yolov8/yolov8_seg_trt.py +++ b/yolov8/yolov8_seg_trt.py @@ -19,6 +19,7 @@ POSE_NUM = 17 * 3 DET_NUM = 6 SEG_NUM = 32 +OBB_NUM = 1 def get_img_path_batches(batch_size, img_dir): @@ -69,7 +70,7 @@ def plot_one_box(x, img, color=None, label=None, line_thickness=None): [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA, - ) + ) class YoLov8TRT(object): @@ -131,7 +132,7 @@ def __init__(self, engine_file_path): self.seg_w = int(self.input_w / 4) self.seg_h = int(self.input_h / 4) self.seg_c = int(self.seg_output_length / (self.seg_w * self.seg_w)) - self.det_row_output_length = self.seg_c + DET_NUM + POSE_NUM + self.det_row_output_length = self.seg_c + DET_NUM + POSE_NUM + OBB_NUM # Draw mask self.colors_obj = Colors() @@ -526,7 +527,7 @@ def hex2rgb(h): # rgb order (PIL) if __name__ == "__main__": # load custom plugin and engine - PLUGIN_LIBRARY = "build/libmyplugins.so" + PLUGIN_LIBRARY = "./build/libmyplugins.so" engine_file_path = "yolov8n-seg.engine" if len(sys.argv) > 1: