Skip to content

Commit

Permalink
Merge pull request #19 from occ-ai/roy.generic_onnx_rt_model_support
Browse files Browse the repository at this point in the history
feat: Add YuNet model support for object detection
  • Loading branch information
royshil committed May 31, 2024
2 parents 1d80702 + a8e33d0 commit 6ccfecb
Show file tree
Hide file tree
Showing 17 changed files with 712 additions and 414 deletions.
4 changes: 3 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ target_sources(
src/detect-filter-info.c
src/detect-filter-utils.cpp
src/obs-utils/obs-utils.cpp
src/ort-model/ONNXRuntimeModel.cpp
src/edgeyolo/edgeyolo_onnxruntime.cpp
src/sort/Sort.cpp)
src/sort/Sort.cpp
src/yunet/YuNet.cpp)

set_target_properties_plugin(${CMAKE_PROJECT_NAME} PROPERTIES OUTPUT_NAME ${_name})
Binary file added data/models/face_detection_yunet_2023mar.onnx
Binary file not shown.
4 changes: 2 additions & 2 deletions src/FilterData.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#define FILTERDATA_H

#include <obs-module.h>
#include "edgeyolo/edgeyolo_onnxruntime.hpp"
#include "ort-model/ONNXRuntimeModel.h"
#include "sort/Sort.h"

/**
Expand Down Expand Up @@ -58,7 +58,7 @@ struct filter_data {
std::mutex outputLock;
std::mutex modelMutex;

std::unique_ptr<edgeyolo_cpp::EdgeYOLOONNXRuntime> edgeyolo;
std::unique_ptr<ONNXRuntimeModel> onnxruntimemodel;
std::vector<std::string> classNames;

#if _WIN32
Expand Down
77 changes: 47 additions & 30 deletions src/detect-filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,13 @@
#include "FilterData.h"
#include "consts.h"
#include "obs-utils/obs-utils.h"
#include "edgeyolo/utils.hpp"
#include "ort-model/utils.hpp"
#include "detect-filter-utils.h"
#include "edgeyolo/edgeyolo_onnxruntime.hpp"
#include "yunet/YuNet.h"

#define EXTERNAL_MODEL_SIZE "!!!EXTERNAL_MODEL!!!"
#define FACE_DETECT_MODEL_SIZE "!!!FACE_DETECT!!!"

struct detect_filter : public filter_data {};

Expand Down Expand Up @@ -325,6 +328,8 @@ obs_properties_t *detect_filter_properties(void *data)
obs_property_list_add_string(model_size, obs_module_text("SmallFast"), "small");
obs_property_list_add_string(model_size, obs_module_text("Medium"), "medium");
obs_property_list_add_string(model_size, obs_module_text("LargeSlow"), "large");
obs_property_list_add_string(model_size, obs_module_text("FaceDetect"),
FACE_DETECT_MODEL_SIZE);
obs_property_list_add_string(model_size, obs_module_text("ExternalModel"),
EXTERNAL_MODEL_SIZE);

Expand Down Expand Up @@ -513,6 +518,9 @@ void detect_filter_update(void *data, obs_data_t *settings)
} else if (newModelSize == "large") {
modelFilepath_rawPtr =
obs_module_file("models/edgeyolo_tiny_lrelu_coco_736x1280.onnx");
} else if (newModelSize == FACE_DETECT_MODEL_SIZE) {
modelFilepath_rawPtr =
obs_module_file("models/face_detection_yunet_2023mar.onnx");
} else if (newModelSize == EXTERNAL_MODEL_SIZE) {
const char *external_model_file =
obs_data_get_string(settings, "external_model_file");
Expand Down Expand Up @@ -580,41 +588,53 @@ void detect_filter_update(void *data, obs_data_t *settings)
obs_log(LOG_ERROR,
"JSON file does not contain 'labels' field");
tf->isDisabled = true;
tf->edgeyolo.reset();
tf->onnxruntimemodel.reset();
return;
}
} else {
obs_log(LOG_ERROR, "Failed to open JSON file: %s",
labelsFilepath.c_str());
tf->isDisabled = true;
tf->edgeyolo.reset();
tf->onnxruntimemodel.reset();
return;
}
} else if (tf->modelSize == FACE_DETECT_MODEL_SIZE) {
num_classes_ = 1;
tf->classNames = yunet::FACE_CLASSES;
}

// Load model
try {
if (tf->edgeyolo) {
tf->edgeyolo.reset();
if (tf->onnxruntimemodel) {
tf->onnxruntimemodel.reset();
}
if (tf->modelSize == FACE_DETECT_MODEL_SIZE) {
tf->onnxruntimemodel = std::make_unique<yunet::YuNetONNX>(
tf->modelFilepath, tf->numThreads, 50, tf->numThreads,
tf->useGPU, onnxruntime_device_id_,
onnxruntime_use_parallel_, nms_th_, tf->conf_threshold);
} else {
tf->onnxruntimemodel =
std::make_unique<edgeyolo_cpp::EdgeYOLOONNXRuntime>(
tf->modelFilepath, tf->numThreads, num_classes_,
tf->numThreads, tf->useGPU, onnxruntime_device_id_,
onnxruntime_use_parallel_, nms_th_,
tf->conf_threshold);
}
tf->edgeyolo = std::make_unique<edgeyolo_cpp::EdgeYOLOONNXRuntime>(
tf->modelFilepath, tf->numThreads, tf->numThreads, tf->useGPU,
onnxruntime_device_id_, onnxruntime_use_parallel_, nms_th_,
tf->conf_threshold, num_classes_);
// clear error message
obs_data_set_string(settings, "error", "");
} catch (const std::exception &e) {
obs_log(LOG_ERROR, "Failed to load model: %s", e.what());
// disable filter
tf->isDisabled = true;
tf->edgeyolo.reset();
tf->onnxruntimemodel.reset();
return;
}
}

// update threshold on edgeyolo
if (tf->edgeyolo) {
tf->edgeyolo->setBBoxConfThresh(tf->conf_threshold);
if (tf->onnxruntimemodel) {
tf->onnxruntimemodel->setBBoxConfThresh(tf->conf_threshold);
}

if (reinitialize) {
Expand Down Expand Up @@ -746,7 +766,7 @@ void detect_filter_video_tick(void *data, float seconds)

struct detect_filter *tf = reinterpret_cast<detect_filter *>(data);

if (tf->isDisabled || !tf->edgeyolo) {
if (tf->isDisabled || !tf->onnxruntimemodel) {
return;
}

Expand Down Expand Up @@ -775,18 +795,16 @@ void detect_filter_video_tick(void *data, float seconds)
cropRect = cv::Rect(tf->crop_left, tf->crop_top,
imageBGRA.cols - tf->crop_left - tf->crop_right,
imageBGRA.rows - tf->crop_top - tf->crop_bottom);
obs_log(LOG_INFO, "Crop: %d %d %d %d", cropRect.x, cropRect.y, cropRect.width,
cropRect.height);
cv::cvtColor(imageBGRA(cropRect), inferenceFrame, cv::COLOR_BGRA2BGR);
} else {
cv::cvtColor(imageBGRA, inferenceFrame, cv::COLOR_BGRA2BGR);
}

std::vector<edgeyolo_cpp::Object> objects;
std::vector<Object> objects;

try {
std::unique_lock<std::mutex> lock(tf->modelMutex);
objects = tf->edgeyolo->inference(inferenceFrame);
objects = tf->onnxruntimemodel->inference(inferenceFrame);
} catch (const Ort::Exception &e) {
obs_log(LOG_ERROR, "ONNXRuntime Exception: %s", e.what());
} catch (const std::exception &e) {
Expand All @@ -795,7 +813,7 @@ void detect_filter_video_tick(void *data, float seconds)

if (tf->crop_enabled) {
// translate the detected objects to the original frame
for (edgeyolo_cpp::Object &obj : objects) {
for (Object &obj : objects) {
obj.rect.x += (float)cropRect.x;
obj.rect.y += (float)cropRect.y;
}
Expand Down Expand Up @@ -824,8 +842,8 @@ void detect_filter_video_tick(void *data, float seconds)
}

if (tf->objectCategory != -1) {
std::vector<edgeyolo_cpp::Object> filtered_objects;
for (const edgeyolo_cpp::Object &obj : objects) {
std::vector<Object> filtered_objects;
for (const Object &obj : objects) {
if (obj.label == tf->objectCategory) {
filtered_objects.push_back(obj);
}
Expand All @@ -838,18 +856,17 @@ void detect_filter_video_tick(void *data, float seconds)
}

if (!tf->showUnseenObjects) {
objects.erase(std::remove_if(objects.begin(), objects.end(),
[](const edgeyolo_cpp::Object &obj) {
return obj.unseenFrames > 0;
}),
objects.end());
objects.erase(
std::remove_if(objects.begin(), objects.end(),
[](const Object &obj) { return obj.unseenFrames > 0; }),
objects.end());
}

if (!tf->saveDetectionsPath.empty()) {
std::ofstream detectionsFile(tf->saveDetectionsPath);
if (detectionsFile.is_open()) {
nlohmann::json j;
for (const edgeyolo_cpp::Object &obj : objects) {
for (const Object &obj : objects) {
nlohmann::json obj_json;
obj_json["label"] = obj.label;
obj_json["confidence"] = obj.prob;
Expand Down Expand Up @@ -877,11 +894,11 @@ void detect_filter_video_tick(void *data, float seconds)
drawDashedRectangle(frame, cropRect, cv::Scalar(0, 255, 0), 5, 8, 15);
}
if (tf->preview && objects.size() > 0) {
edgeyolo_cpp::utils::draw_objects(frame, objects, tf->classNames);
draw_objects(frame, objects, tf->classNames);
}
if (tf->maskingEnabled) {
cv::Mat mask = cv::Mat::zeros(frame.size(), CV_8UC1);
for (const edgeyolo_cpp::Object &obj : objects) {
for (const Object &obj : objects) {
cv::rectangle(mask, obj.rect, cv::Scalar(255), -1);
}
std::lock_guard<std::mutex> lock(tf->outputLock);
Expand All @@ -906,7 +923,7 @@ void detect_filter_video_tick(void *data, float seconds)
// get the bounding box of all objects
if (objects.size() > 0) {
boundingBox = objects[0].rect;
for (const edgeyolo_cpp::Object &obj : objects) {
for (const Object &obj : objects) {
boundingBox |= obj.rect;
}
}
Expand Down Expand Up @@ -967,7 +984,7 @@ void detect_filter_video_render(void *data, gs_effect_t *_effect)

struct detect_filter *tf = reinterpret_cast<detect_filter *>(data);

if (tf->isDisabled || !tf->edgeyolo) {
if (tf->isDisabled || !tf->onnxruntimemodel) {
if (tf->source) {
obs_source_skip_video_filter(tf->source);
}
Expand Down
28 changes: 0 additions & 28 deletions src/edgeyolo/coco_names.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,33 +30,5 @@ static const std::vector<std::string> COCO_CLASSES = {
"refrigerator", "book", "clock",
"vase", "scissors", "teddy bear",
"hair drier", "toothbrush"};
const float color_list[80][3] = {
{0.000f, 0.447f, 0.741f}, {0.850f, 0.325f, 0.098f}, {0.929f, 0.694f, 0.125f},
{0.494f, 0.184f, 0.556f}, {0.466f, 0.674f, 0.188f}, {0.301f, 0.745f, 0.933f},
{0.635f, 0.078f, 0.184f}, {0.300f, 0.300f, 0.300f}, {0.600f, 0.600f, 0.600f},
{1.000f, 0.000f, 0.000f}, {1.000f, 0.500f, 0.000f}, {0.749f, 0.749f, 0.000f},
{0.000f, 1.000f, 0.000f}, {0.000f, 0.000f, 1.000f}, {0.667f, 0.000f, 1.000f},
{0.333f, 0.333f, 0.000f}, {0.333f, 0.667f, 0.000f}, {0.333f, 1.000f, 0.000f},
{0.667f, 0.333f, 0.000f}, {0.667f, 0.667f, 0.000f}, {0.667f, 1.000f, 0.000f},
{1.000f, 0.333f, 0.000f}, {1.000f, 0.667f, 0.000f}, {1.000f, 1.000f, 0.000f},
{0.000f, 0.333f, 0.500f}, {0.000f, 0.667f, 0.500f}, {0.000f, 1.000f, 0.500f},
{0.333f, 0.000f, 0.500f}, {0.333f, 0.333f, 0.500f}, {0.333f, 0.667f, 0.500f},
{0.333f, 1.000f, 0.500f}, {0.667f, 0.000f, 0.500f}, {0.667f, 0.333f, 0.500f},
{0.667f, 0.667f, 0.500f}, {0.667f, 1.000f, 0.500f}, {1.000f, 0.000f, 0.500f},
{1.000f, 0.333f, 0.500f}, {1.000f, 0.667f, 0.500f}, {1.000f, 1.000f, 0.500f},
{0.000f, 0.333f, 1.000f}, {0.000f, 0.667f, 1.000f}, {0.000f, 1.000f, 1.000f},
{0.333f, 0.000f, 1.000f}, {0.333f, 0.333f, 1.000f}, {0.333f, 0.667f, 1.000f},
{0.333f, 1.000f, 1.000f}, {0.667f, 0.000f, 1.000f}, {0.667f, 0.333f, 1.000f},
{0.667f, 0.667f, 1.000f}, {0.667f, 1.000f, 1.000f}, {1.000f, 0.000f, 1.000f},
{1.000f, 0.333f, 1.000f}, {1.000f, 0.667f, 1.000f}, {0.333f, 0.000f, 0.000f},
{0.500f, 0.000f, 0.000f}, {0.667f, 0.000f, 0.000f}, {0.833f, 0.000f, 0.000f},
{1.000f, 0.000f, 0.000f}, {0.000f, 0.167f, 0.000f}, {0.000f, 0.333f, 0.000f},
{0.000f, 0.500f, 0.000f}, {0.000f, 0.667f, 0.000f}, {0.000f, 0.833f, 0.000f},
{0.000f, 1.000f, 0.000f}, {0.000f, 0.000f, 0.167f}, {0.000f, 0.000f, 0.333f},
{0.000f, 0.000f, 0.500f}, {0.000f, 0.000f, 0.667f}, {0.000f, 0.000f, 0.833f},
{0.000f, 0.000f, 1.000f}, {0.000f, 0.000f, 0.000f}, {0.143f, 0.143f, 0.143f},
{0.286f, 0.286f, 0.286f}, {0.429f, 0.429f, 0.429f}, {0.571f, 0.571f, 0.571f},
{0.714f, 0.714f, 0.714f}, {0.857f, 0.857f, 0.857f}, {0.000f, 0.447f, 0.741f},
{0.314f, 0.717f, 0.741f}, {0.50f, 0.5f, 0.0f}};
} // namespace edgeyolo_cpp
#endif
Loading

0 comments on commit 6ccfecb

Please sign in to comment.