Skip to content

Commit

Permalink
Distinguish the numclass of pose and obb in config.h and add the Ori…
Browse files Browse the repository at this point in the history
…ented Bounding Boxes (OBB) Estimation algorithm (#1593)

* Add the generation of multi-class pose engines

* Change grids in forwardGpu to one-dimensional arrays

* Update README.md

* Update types.h

keypoints array with dynamic size based on kNumberOfPoints

* yolov8_5u_det (YOLOv5u with the anchor-free, objectness-free split head structure based on YOLOv8 features) model

* update

* fix code style

* yolov8_5u_det model download link

* yolov8_5u_det model download link

* Distinguish the numclass of pose and obb in config.h and add the Oriented Bounding Boxes (OBB) Estimation algorithm

* fix code style

---------

Co-authored-by: lindsayshuo <[email protected]>
  • Loading branch information
lindsayshuo and lindsayshuo authored Oct 22, 2024
1 parent 0f75fd7 commit 21f23ff
Show file tree
Hide file tree
Showing 23 changed files with 1,677 additions and 156 deletions.
7 changes: 5 additions & 2 deletions yolov8/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ else()
link_directories(/usr/local/cuda/lib64)

# tensorrt
include_directories(/home/lindsay/TensorRT-8.4.1.5/include)
link_directories(/home/lindsay/TensorRT-8.4.1.5/lib)
include_directories(/home/lindsay/TensorRT-8.6.1.6/include)
link_directories(/home/lindsay/TensorRT-8.6.1.6/lib)
# include_directories(/home/lindsay/TensorRT-7.2.3.4/include)
# link_directories(/home/lindsay/TensorRT-7.2.3.4/lib)

Expand Down Expand Up @@ -60,3 +60,6 @@ target_link_libraries(yolov8_cls nvinfer cudart myplugins ${OpenCV_LIBS})

add_executable(yolov8_5u_det ${PROJECT_SOURCE_DIR}/yolov8_5u_det.cpp ${SRCS})
target_link_libraries(yolov8_5u_det nvinfer cudart myplugins ${OpenCV_LIBS})

add_executable(yolov8_obb ${PROJECT_SOURCE_DIR}/yolov8_obb.cpp ${SRCS})
target_link_libraries(yolov8_obb nvinfer cudart myplugins ${OpenCV_LIBS})
27 changes: 25 additions & 2 deletions yolov8/README.md
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ wget -O coco.txt https://raw.githubusercontent.com/amikelive/coco-labels/master/
```
cd {tensorrtx}/yolov8/
// Download inference images
wget https://github.com/lindsayshuo/infer_pic/blob/main/1709970363.6990473rescls.jpg
wget https://github.com/lindsayshuo/infer_pic/releases/download/pics/1709970363.6990473rescls.jpg
mkdir samples
cp -r 1709970363.6990473rescls.jpg samples
// Download ImageNet labels
Expand All @@ -130,7 +130,7 @@ sudo ./yolov8_cls -d yolov8n-cls.engine ../samples
### Pose Estimation
```
cd {tensorrtx}/yolov8/
// update "kNumClass = 1" in config.h
// update "kPoseNumClass = 1" in config.h
mkdir build
cd build
cp {ultralytics}/ultralytics/yolov8-pose.wts {tensorrtx}/yolov8/build
Expand All @@ -146,6 +146,28 @@ sudo ./yolov8_pose -d yolov8n-pose.engine ../images g //gpu postprocess
```


### Oriented Bounding Boxes (OBB) Estimation
```
cd {tensorrtx}/yolov8/
// update "kObbNumClass = 15" "kInputH = 1024" "kInputW = 1024" in config.h
wget https://github.com/lindsayshuo/infer_pic/releases/download/pics/obb.png
mkdir images
mv obb.png ./images
mkdir build
cd build
cp {ultralytics}/ultralytics/yolov8-obb.wts {tensorrtx}/yolov8/build
cmake ..
make
sudo ./yolov8_obb -s [.wts] [.engine] [n/s/m/l/x/n2/s2/m2/l2/x2/n6/s6/m6/l6/x6] // serialize model to plan file
sudo ./yolov8_obb -d [.engine] [image folder] [c/g] // deserialize and run inference, the images in [image folder] will be processed.
// For example yolov8-obb
sudo ./yolov8_obb -s yolov8n-obb.wts yolov8n-obb.engine n
sudo ./yolov8_obb -d yolov8n-obb.engine ../images c //cpu postprocess
sudo ./yolov8_obb -d yolov8n-obb.engine ../images g //gpu postprocess
```


4. Optional: load and run the TensorRT model in Python

```
Expand All @@ -156,6 +178,7 @@ python yolov8_seg_trt.py # Segmentation
python yolov8_cls_trt.py # Classification
python yolov8_pose_trt.py # Pose Estimation
python yolov8_5u_det_trt.py # yolov8_5u_det(YOLOv5u with the anchor-free, objectness-free split head structure based on YOLOv8 features) model
python yolov8_obb_trt.py # Oriented Bounding Boxes (OBB) Estimation
```

# INT8 Quantization
Expand Down
4 changes: 2 additions & 2 deletions yolov8/gen_wts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def parse_args():
parser.add_argument(
'-o', '--output', help='Output (.wts) file path (optional)')
parser.add_argument(
'-t', '--type', type=str, default='detect', choices=['detect', 'cls', 'seg', 'pose'],
'-t', '--type', type=str, default='detect', choices=['detect', 'cls', 'seg', 'pose', 'obb'],
help='determines the model is detection/classification')
args = parser.parse_args()
if not os.path.isfile(args.weights):
Expand All @@ -39,7 +39,7 @@ def parse_args():
# Load model
model = torch.load(pt_file, map_location=device)['model'].float() # load to FP32

if m_type in ['detect', 'seg', 'pose']:
if m_type in ['detect', 'seg', 'pose', 'obb']:
anchor_grid = model.model[-1].anchors * model.model[-1].stride[..., None, None]

delattr(model.model[-1], 'anchors')
Expand Down
2 changes: 1 addition & 1 deletion yolov8/include/block.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@ nvinfer1::IShuffleLayer* DFL(nvinfer1::INetworkDefinition* network, std::map<std

nvinfer1::IPluginV2Layer* addYoLoLayer(nvinfer1::INetworkDefinition* network,
std::vector<nvinfer1::IConcatenationLayer*> dets, const int* px_arry,
int px_arry_num, bool is_segmentation, bool is_pose);
int px_arry_num, int num_class, bool is_segmentation, bool is_pose, bool is_obb);
8 changes: 7 additions & 1 deletion yolov8/include/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
const static char* kInputTensorName = "images";
const static char* kOutputTensorName = "output";
const static int kNumClass = 80;
const static int kNumberOfPoints = 17; // number of keypoints total
const static int kBatchSize = 1;
const static int kGpuId = 0;
const static int kInputH = 640;
Expand All @@ -23,3 +22,10 @@ constexpr static int kClsNumClass = 1000;
// Classification model's input shape
constexpr static int kClsInputH = 224;
constexpr static int kClsInputW = 224;

// pose model's number of classes
constexpr static int kPoseNumClass = 1;
const static int kNumberOfPoints = 17; // number of keypoints total

// obb model's number of classes
constexpr static int kObbNumClass = 15;
4 changes: 4 additions & 0 deletions yolov8/include/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,7 @@ nvinfer1::IHostMemory* buildEngineYolov8_5uDet(nvinfer1::IBuilder* builder, nvin
nvinfer1::IHostMemory* buildEngineYolov8_5uDetP6(nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config,
nvinfer1::DataType dt, const std::string& wts_path, float& gd,
float& gw, int& max_channels);

nvinfer1::IHostMemory* buildEngineYolov8Obb(nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config,
nvinfer1::DataType dt, const std::string& wts_path, float& gd, float& gw,
int& max_channels);
33 changes: 22 additions & 11 deletions yolov8/include/postprocess.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,38 @@
#include "NvInfer.h"
#include "types.h"

// Preprocessing functions
cv::Rect get_rect(cv::Mat& img, float bbox[4]);

void nms(std::vector<Detection>& res, float* output, float conf_thresh, float nms_thresh = 0.5);

void batch_nms(std::vector<std::vector<Detection>>& batch_res, float* output, int batch_size, int output_size,
float conf_thresh, float nms_thresh = 0.5);

void draw_bbox(std::vector<cv::Mat>& img_batch, std::vector<std::vector<Detection>>& res_batch);

void draw_bbox_keypoints_line(std::vector<cv::Mat>& img_batch, std::vector<std::vector<Detection>>& res_batch);

// Processing functions
void batch_process(std::vector<std::vector<Detection>>& res_batch, const float* decode_ptr_host, int batch_size,
int bbox_element, const std::vector<cv::Mat>& img_batch);

void batch_process_obb(std::vector<std::vector<Detection>>& res_batch, const float* decode_ptr_host, int batch_size,
int bbox_element, const std::vector<cv::Mat>& img_batch);
void process_decode_ptr_host(std::vector<Detection>& res, const float* decode_ptr_host, int bbox_element, cv::Mat& img,
int count);
void process_decode_ptr_host_obb(std::vector<Detection>& res, const float* decode_ptr_host, int bbox_element,
cv::Mat& img, int count);

// NMS functions
void nms(std::vector<Detection>& res, float* output, float conf_thresh, float nms_thresh = 0.5);
void batch_nms(std::vector<std::vector<Detection>>& batch_res, float* output, int batch_size, int output_size,
float conf_thresh, float nms_thresh = 0.5);
void nms_obb(std::vector<Detection>& res, float* output, float conf_thresh, float nms_thresh = 0.5);
void batch_nms_obb(std::vector<std::vector<Detection>>& batch_res, float* output, int batch_size, int output_size,
float conf_thresh, float nms_thresh = 0.5);

// CUDA-related functions
void cuda_decode(float* predict, int num_bboxes, float confidence_threshold, float* parray, int max_objects,
cudaStream_t stream);

void cuda_nms(float* parray, float nms_threshold, int max_objects, cudaStream_t stream);
void cuda_decode_obb(float* predict, int num_bboxes, float confidence_threshold, float* parray, int max_objects,
cudaStream_t stream);
void cuda_nms_obb(float* parray, float nms_threshold, int max_objects, cudaStream_t stream);

// Drawing functions
void draw_bbox(std::vector<cv::Mat>& img_batch, std::vector<std::vector<Detection>>& res_batch);
void draw_bbox_obb(std::vector<cv::Mat>& img_batch, std::vector<std::vector<Detection>>& res_batch);
void draw_bbox_keypoints_line(std::vector<cv::Mat>& img_batch, std::vector<std::vector<Detection>>& res_batch);
void draw_mask_bbox(cv::Mat& img, std::vector<Detection>& dets, std::vector<cv::Mat>& masks,
std::unordered_map<int, std::string>& labels_map);
1 change: 1 addition & 0 deletions yolov8/include/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ struct alignas(float) Detection {
float class_id;
float mask[32];
float keypoints[kNumberOfPoints * 3]; // keypoints array with dynamic size based on kNumberOfPoints
float angle; // obb angle
};

struct AffineMatrix {
Expand Down
66 changes: 45 additions & 21 deletions yolov8/plugin/yololayer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ __device__ float sigmoid(float x) {

namespace nvinfer1 {
YoloLayerPlugin::YoloLayerPlugin(int classCount, int numberofpoints, float confthreshkeypoints, int netWidth,
int netHeight, int maxOut, bool is_segmentation, bool is_pose, const int* strides,
int stridesLength) {
int netHeight, int maxOut, bool is_segmentation, bool is_pose, bool is_obb,
const int* strides, int stridesLength) {

mClassCount = classCount;
mNumberofpoints = numberofpoints;
Expand All @@ -40,6 +40,7 @@ YoloLayerPlugin::YoloLayerPlugin(int classCount, int numberofpoints, float conft
memcpy(mStrides, strides, stridesLength * sizeof(int));
is_segmentation_ = is_segmentation;
is_pose_ = is_pose;
is_obb_ = is_obb;
}

YoloLayerPlugin::~YoloLayerPlugin() {
Expand All @@ -66,6 +67,7 @@ YoloLayerPlugin::YoloLayerPlugin(const void* data, size_t length) {
}
read(d, is_segmentation_);
read(d, is_pose_);
read(d, is_obb_);

assert(d == a + length);
}
Expand All @@ -87,14 +89,15 @@ void YoloLayerPlugin::serialize(void* buffer) const TRT_NOEXCEPT {
}
write(d, is_segmentation_);
write(d, is_pose_);
write(d, is_obb_);

assert(d == a + getSerializationSize());
}

size_t YoloLayerPlugin::getSerializationSize() const TRT_NOEXCEPT {
return sizeof(mClassCount) + sizeof(mNumberofpoints) + sizeof(mConfthreshkeypoints) + sizeof(mThreadCount) +
sizeof(mYoloV8netHeight) + sizeof(mYoloV8NetWidth) + sizeof(mMaxOutObject) + sizeof(mStridesLength) +
sizeof(int) * mStridesLength + sizeof(is_segmentation_) + sizeof(is_pose_);
sizeof(int) * mStridesLength + sizeof(is_segmentation_) + sizeof(is_pose_) + sizeof(is_obb_);
}

int YoloLayerPlugin::initialize() TRT_NOEXCEPT {
Expand Down Expand Up @@ -156,7 +159,7 @@ nvinfer1::IPluginV2IOExt* YoloLayerPlugin::clone() const TRT_NOEXCEPT {

YoloLayerPlugin* p =
new YoloLayerPlugin(mClassCount, mNumberofpoints, mConfthreshkeypoints, mYoloV8NetWidth, mYoloV8netHeight,
mMaxOutObject, is_segmentation_, is_pose_, mStrides, mStridesLength);
mMaxOutObject, is_segmentation_, is_pose_, is_obb_, mStrides, mStridesLength);
p->setPluginNamespace(mPluginNamespace);
return p;
}
Expand All @@ -174,14 +177,14 @@ __device__ float Logist(float data) {

__global__ void CalDetection(const float* input, float* output, int numElements, int maxoutobject, const int grid_h,
int grid_w, const int stride, int classes, int nk, float confkeypoints, int outputElem,
bool is_segmentation, bool is_pose) {
bool is_segmentation, bool is_pose, bool is_obb) {
int idx = threadIdx.x + blockDim.x * blockIdx.x;
if (idx >= numElements)
return;

const int N_kpts = nk;
int total_grid = grid_h * grid_w;
int info_len = 4 + classes + (is_segmentation ? 32 : 0) + (is_pose ? N_kpts * 3 : 0);
int info_len = 4 + classes + (is_segmentation ? 32 : 0) + (is_pose ? N_kpts * 3 : 0) + (is_obb ? 1 : 0);
int batchIdx = idx / total_grid;
int elemIdx = idx % total_grid;
const float* curInput = input + batchIdx * total_grid * info_len;
Expand Down Expand Up @@ -218,15 +221,16 @@ __global__ void CalDetection(const float* input, float* output, int numElements,

if (is_segmentation) {
for (int k = 0; k < 32; ++k) {
det->mask[k] = curInput[elemIdx + (4 + classes + k) * total_grid];
det->mask[k] =
curInput[elemIdx + (4 + classes + (is_pose ? N_kpts * 3 : 0) + (is_obb ? 1 : 0) + k) * total_grid];
}
}

if (is_pose) {
for (int kpt = 0; kpt < N_kpts; kpt++) {
int kpt_x_idx = (4 + classes + (is_segmentation ? 32 : 0) + kpt * 3) * total_grid;
int kpt_y_idx = (4 + classes + (is_segmentation ? 32 : 0) + kpt * 3 + 1) * total_grid;
int kpt_conf_idx = (4 + classes + (is_segmentation ? 32 : 0) + kpt * 3 + 2) * total_grid;
int kpt_x_idx = (4 + classes + (is_segmentation ? 32 : 0) + (is_obb ? 1 : 0) + kpt * 3) * total_grid;
int kpt_y_idx = (4 + classes + (is_segmentation ? 32 : 0) + (is_obb ? 1 : 0) + kpt * 3 + 1) * total_grid;
int kpt_conf_idx = (4 + classes + (is_segmentation ? 32 : 0) + (is_obb ? 1 : 0) + kpt * 3 + 2) * total_grid;

float kpt_confidence = sigmoid(curInput[elemIdx + kpt_conf_idx]);

Expand All @@ -247,24 +251,43 @@ __global__ void CalDetection(const float* input, float* output, int numElements,
}
}
}

if (is_obb) {
double pi = M_PI;
auto angle_inx = curInput[elemIdx + (4 + classes + (is_segmentation ? 32 : 0) + (is_pose ? N_kpts * 3 : 0) +
0) * total_grid];
auto angle = (sigmoid(angle_inx) - 0.25f) * pi;

auto cos1 = cos(angle);
auto sin1 = sin(angle);
auto xf = (curInput[elemIdx + 2 * total_grid] - curInput[elemIdx + 0 * total_grid]) / 2;
auto yf = (curInput[elemIdx + 3 * total_grid] - curInput[elemIdx + 1 * total_grid]) / 2;

auto x = xf * cos1 - yf * sin1;
auto y = xf * sin1 + yf * cos1;

float cx = (col + 0.5f + x) * stride;
float cy = (row + 0.5f + y) * stride;

float w1 = (curInput[elemIdx + 0 * total_grid] + curInput[elemIdx + 2 * total_grid]) * stride;
float h1 = (curInput[elemIdx + 1 * total_grid] + curInput[elemIdx + 3 * total_grid]) * stride;
det->bbox[0] = cx;
det->bbox[1] = cy;
det->bbox[2] = w1;
det->bbox[3] = h1;
det->angle = angle;
}
}

void YoloLayerPlugin::forwardGpu(const float* const* inputs, float* output, cudaStream_t stream, int mYoloV8netHeight,
int mYoloV8NetWidth, int batchSize) {

int outputElem = 1 + mMaxOutObject * sizeof(Detection) / sizeof(float);
cudaMemsetAsync(output, 0, sizeof(float), stream);
for (int idx = 0; idx < batchSize; ++idx) {
CUDA_CHECK(cudaMemsetAsync(output + idx * outputElem, 0, sizeof(float), stream));
}
int numElem = 0;

// const int maxGrids = mStridesLength;
// int grids[maxGrids][2];
// for (int i = 0; i < maxGrids; ++i) {
// grids[i][0] = mYoloV8netHeight / mStrides[i];
// grids[i][1] = mYoloV8NetWidth / mStrides[i];
// }

int maxGrids = mStridesLength;
int flatGridsLen = 2 * maxGrids;
int* flatGrids = new int[flatGridsLen];
Expand All @@ -286,7 +309,7 @@ void YoloLayerPlugin::forwardGpu(const float* const* inputs, float* output, cuda
// The CUDA kernel call remains unchanged
CalDetection<<<(numElem + mThreadCount - 1) / mThreadCount, mThreadCount, 0, stream>>>(
inputs[i], output, numElem, mMaxOutObject, grid_h, grid_w, stride, mClassCount, mNumberofpoints,
mConfthreshkeypoints, outputElem, is_segmentation_, is_pose_);
mConfthreshkeypoints, outputElem, is_segmentation_, is_pose_, is_obb_);
}

delete[] flatGrids;
Expand Down Expand Up @@ -317,7 +340,7 @@ IPluginV2IOExt* YoloPluginCreator::createPlugin(const char* name, const PluginFi
assert(fc->nbFields == 1);
assert(strcmp(fc->fields[0].name, "combinedInfo") == 0);
const int* combinedInfo = static_cast<const int*>(fc->fields[0].data);
int netinfo_count = 8;
int netinfo_count = 9;
int class_count = combinedInfo[0];
int numberofpoints = combinedInfo[1];
float confthreshkeypoints = combinedInfo[2];
Expand All @@ -326,11 +349,12 @@ IPluginV2IOExt* YoloPluginCreator::createPlugin(const char* name, const PluginFi
int max_output_object_count = combinedInfo[5];
bool is_segmentation = combinedInfo[6];
bool is_pose = combinedInfo[7];
bool is_obb = combinedInfo[8];
const int* px_arry = combinedInfo + netinfo_count;
int px_arry_length = fc->fields[0].length - netinfo_count;
YoloLayerPlugin* obj =
new YoloLayerPlugin(class_count, numberofpoints, confthreshkeypoints, input_w, input_h,
max_output_object_count, is_segmentation, is_pose, px_arry, px_arry_length);
max_output_object_count, is_segmentation, is_pose, is_obb, px_arry, px_arry_length);
obj->setPluginNamespace(mNamespace.c_str());
return obj;
}
Expand Down
3 changes: 2 additions & 1 deletion yolov8/plugin/yololayer.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace nvinfer1 {
class API YoloLayerPlugin : public IPluginV2IOExt {
public:
YoloLayerPlugin(int classCount, int numberofpoints, float confthreshkeypoints, int netWidth, int netHeight,
int maxOut, bool is_segmentation, bool is_pose, const int* strides, int stridesLength);
int maxOut, bool is_segmentation, bool is_pose, bool is_obb, const int* strides, int stridesLength);

YoloLayerPlugin(const void* data, size_t length);
~YoloLayerPlugin();
Expand Down Expand Up @@ -75,6 +75,7 @@ class API YoloLayerPlugin : public IPluginV2IOExt {
int mMaxOutObject;
bool is_segmentation_;
bool is_pose_;
bool is_obb_;
int* mStrides;
int mStridesLength;
};
Expand Down
Loading

0 comments on commit 21f23ff

Please sign in to comment.