Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
CaoWGG committed Nov 2, 2019
1 parent 9c0c59d commit 66dc0b2
Show file tree
Hide file tree
Showing 14 changed files with 147 additions and 40 deletions.
Binary file removed 000138.jpg
Binary file not shown.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# TensorRT-CenterNet
### demo (GT 1070)
![image](img/show.gif)
* ![image](img/show.gif)
* ![image](img/show2.png)

### Performance
| backbone | input_size | GPU | mode | inference Time |
Expand Down
18 changes: 9 additions & 9 deletions example/runDet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,17 @@

int main(int argc, const char** argv){
optparse::OptionParser parser;
parser.add_option("-i", "--input-engine-file").dest("engineFile").set_default("model/ctdet_helmet.engine")
parser.add_option("-i", "--input-engine-file").dest("engineFile").set_default("model/centerface.engine")
.help("the path of onnx file");
parser.add_option("-img").dest("imgFile").set_default("000138.jpg");
parser.add_option("-img").dest("imgFile").set_default("test.jpg");
parser.add_option("-cap").dest("capFile").set_default("test.h264");
optparse::Values options = parser.parse_args(argc, argv);
if(options["engineFile"].size() == 0){
std::cout << "no file input" << std::endl;
exit(-1);
}

cv::RNG rng(57);
cv::RNG rng(244);
std::vector<cv::Scalar> color;
for(int i=0; i<ctdet::classNum;++i)color.push_back(randomColor(rng));

Expand All @@ -36,7 +36,7 @@ int main(int argc, const char** argv){
if(options["imgFile"].size()>0)
{
img = cv::imread(options["imgFile"]);
auto inputData = prepareImage(img);
auto inputData = prepareImage(img,net.forwardFace);

net.doInference(inputData.data(), outputData.get());
net.printTime();
Expand All @@ -46,9 +46,9 @@ int main(int argc, const char** argv){
result.resize(num_det);
memcpy(result.data(), &outputData[1], num_det * sizeof(Detection));

postProcess(result,img);
postProcess(result,img,net.forwardFace);

drawImg(result,img,color);
drawImg(result,img,color,net.forwardFace);

cv::imshow("result",img);
cv::waitKey(0);
Expand All @@ -58,7 +58,7 @@ int main(int argc, const char** argv){
cv::VideoCapture cap(options["capFile"]);
while (cap.read(img))
{
auto inputData = prepareImage(img);
auto inputData = prepareImage(img,net.forwardFace);

net.doInference(inputData.data(), outputData.get());
net.printTime();
Expand All @@ -68,9 +68,9 @@ int main(int argc, const char** argv){
result.resize(num_det);
memcpy(result.data(), &outputData[1], num_det * sizeof(Detection));

postProcess(result,img);
postProcess(result,img,net.forwardFace);

drawImg(result,img,color);
drawImg(result,img,color,net.forwardFace);

cv::imshow("result",img);
if((cv::waitKey(1)& 0xff) == 27){
Expand Down
Binary file added img/show2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
20 changes: 15 additions & 5 deletions include/ctdetConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,26 @@

namespace ctdet{

constexpr static int classNum = 2 ;


constexpr static float visThresh = 0.3;

constexpr static int inputSize = 512 ;
constexpr static int channel = 3 ;
constexpr static int ouputSize = 128 ;
constexpr static int ouputSize = inputSize/4 ;
constexpr static int kernelSize = 4 ;

constexpr static float mean[]= {0.485,0.456,0.406};
constexpr static float std[] = {0.229,0.224,0.225};
constexpr static char *className[]= {(char*)"person",(char*)"helmet"};

//cthelmet
// constexpr static int classNum = 2 ;
// constexpr static float mean[]= {0.485,0.456,0.406};
// constexpr static float std[] = {0.229,0.224,0.225};
// constexpr static char *className[]= {(char*)"person",(char*)"helmet"};
//ctface
constexpr static int classNum = 1 ;
constexpr static float mean[]= {0,0,0};
constexpr static float std[] = {1,1,1};
constexpr static char *className[]= {(char*)"face"};

}
#endif //CTDET_TRT_CTDETCONFIG_H
3 changes: 2 additions & 1 deletion include/ctdetLayer.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@
#include <utils.h>
extern "C" void CTdetforward_gpu(const float *hm, const float *reg,const float *wh ,float *output,
const int w,const int h,const int classes,const int kernerl_size,const float visthresh );

extern "C" void CTfaceforward_gpu(const float *hm, const float *wh,const float *reg,const float* landmarks,float *output,
const int w,const int h,const int classes,const int kernerl_size, const float visthresh );
#endif //CTDET_TRT_CTDETLAYER_H
3 changes: 1 addition & 2 deletions include/ctdetNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ namespace ctdet
};

int64_t outputBufferSize;

bool forwardFace;
private:

void InitEngine();
Expand All @@ -65,7 +65,6 @@ namespace ctdet

int runIters;
Profiler mProfiler;

};

}
Expand Down
13 changes: 10 additions & 3 deletions include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,16 +149,23 @@ struct Box{
float x2;
float y2;
};
struct landmarks{
float x;
float y;
};
struct Detection{
//x1 y1 x2 y2
Box bbox;
//float objectness;
landmarks marks[5];
int classId;
float prob;
};


extern dim3 cudaGridSize(uint n);
extern std::vector<float> prepareImage(cv::Mat& img);
extern void postProcess(std::vector<Detection> & result,const cv::Mat& img);
extern void drawImg(const std::vector<Detection> & result,cv::Mat& img,const std::vector<cv::Scalar>& color );
extern std::vector<float> prepareImage(cv::Mat& img, const bool& forwardFace);
extern void postProcess(std::vector<Detection> & result,const cv::Mat& img, const bool& forwardFace);
extern void drawImg(const std::vector<Detection> & result,cv::Mat& img,const std::vector<cv::Scalar>& color, const bool& forwardFace);
extern cv::Scalar randomColor(cv::RNG& rng);
#endif //CTDET_TRT_UTILS_H
Binary file added model/centerface.engine
Binary file not shown.
Binary file added model/centerface.onnx
Binary file not shown.
55 changes: 55 additions & 0 deletions src/ctdetLayer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,64 @@ __global__ void CTdetforward_kernel(const float *hm, const float *reg,const floa
}
}

__global__ void CTfaceforward_kernel(const float *hm, const float *wh,const float *reg,const float* landmarks,
float *output,const int w,const int h,const int classes,const int kernerl_size,const float visthresh ) {
int idx = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
if (idx >= w*h) return;
int padding = kernerl_size/2;
int offset = - padding /2;
int stride = w*h;
int grid_x = idx % w ;
int grid_y = idx / w ;
int cls,l,m,mark_id;
float c_x,c_y,scale_w,scale_h;
for (cls = 0; cls < classes; ++cls )
{
int objIndex = stride * cls + idx;
float objProb = hm[objIndex];
float max=-1;
int max_index =0;
for(l=0 ;l < kernerl_size ; ++l)
for(m=0 ; m < kernerl_size ; ++m){
int cur_x = offset + l + grid_x;
int cur_y = offset + m + grid_y;
int cur_index = cur_y * w + cur_x + stride*cls;
int valid = (cur_x>=0 && cur_x < w && cur_y >=0 && cur_y <h );
float val = (valid !=0 ) ? hm[cur_index]: -1;
max_index = (val > max) ? cur_index : max_index;
max = (val > max ) ? val: max ;
}
//printf("%f\n",objProb);
if((max_index == objIndex) && (objProb > visthresh)){

int resCount = (int)atomicAdd(output,1);
//printf("%d",resCount);
char* data = (char * )output + sizeof(float) + resCount*sizeof(Detection);
Detection* det = (Detection*)(data);
c_x = (grid_x + reg[idx+stride] + 0.5)*4 ; c_y = (grid_y + reg[idx] + 0.5) * 4;
scale_w = expf(wh[idx+stride]) * 4 ; scale_h = expf(wh[idx]) * 4;
det->bbox.x1 = c_x - scale_w/2;
det->bbox.y1 = c_y - scale_h/2 ;
det->bbox.x2 = c_x + scale_w/2;
det->bbox.y2 = c_y + scale_h/2;
det->prob = objProb;
det->classId = cls;
for(mark_id=0 ; mark_id < 5 ; mark_id ++){
det->marks[mark_id].x = det->bbox.x1 + landmarks[idx + (mark_id+1)*stride]*scale_w;
det->marks[mark_id].y = det->bbox.y1 + landmarks[idx + (mark_id)*stride]*scale_h;
}
}
}
}

void CTdetforward_gpu(const float *hm, const float *reg,const float *wh ,float *output,
const int w,const int h,const int classes,const int kernerl_size, const float visthresh ){
uint num = w * h;
CTdetforward_kernel<<<cudaGridSize(num),BLOCK>>>(hm,reg,wh,output,w,h,classes,kernerl_size,visthresh);
}

void CTfaceforward_gpu(const float *hm, const float *wh,const float *reg,const float* landmarks,float *output,
const int w,const int h,const int classes,const int kernerl_size, const float visthresh ){
uint num = w * h;
CTfaceforward_kernel<<<cudaGridSize(num),BLOCK>>>(hm,wh,reg,landmarks,output,w,h,classes,kernerl_size,visthresh);
}
16 changes: 12 additions & 4 deletions src/ctdetNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace ctdet
{

ctdetNet::ctdetNet(const std::string &onnxFile, const std::string &calibFile,
ctdet::RUN_MODE mode):mContext(nullptr),mEngine(nullptr),mRunTime(nullptr),
ctdet::RUN_MODE mode):forwardFace(false),mContext(nullptr),mEngine(nullptr),mRunTime(nullptr),
runMode(mode),runIters(0)
{

Expand Down Expand Up @@ -83,7 +83,7 @@ namespace ctdet
}

ctdetNet::ctdetNet(const std::string &engineFile)
:mContext(nullptr),mEngine(nullptr),mRunTime(nullptr),runMode(RUN_MODE::FLOAT32),runIters(0)
:forwardFace(false),mContext(nullptr),mEngine(nullptr),mRunTime(nullptr),runMode(RUN_MODE::FLOAT32),runIters(0)
{
using namespace std;
fstream file;
Expand Down Expand Up @@ -116,7 +116,9 @@ namespace ctdet
assert(mContext != nullptr);
mContext->setProfiler(&mProfiler);
int nbBindings = mEngine->getNbBindings();
assert(nbBindings == 4);

if (nbBindings > 4) forwardFace= true;

mCudaBuffers.resize(nbBindings);
mBindBufferSizes.resize(nbBindings);
int64_t totalSize = 0;
Expand All @@ -140,9 +142,15 @@ namespace ctdet
CUDA_CHECK(cudaMemcpyAsync(mCudaBuffers[inputIndex], inputData, mBindBufferSizes[inputIndex], cudaMemcpyHostToDevice, mCudaStream));
mContext->execute(batchSize, &mCudaBuffers[inputIndex]);
CUDA_CHECK(cudaMemset(cudaOutputBuffer, 0, sizeof(float)));
CTdetforward_gpu(static_cast<const float *>(mCudaBuffers[1]),static_cast<const float *>(mCudaBuffers[2]),
if (forwardFace){
CTfaceforward_gpu(static_cast<const float *>(mCudaBuffers[1]),static_cast<const float *>(mCudaBuffers[2]),
static_cast<const float *>(mCudaBuffers[3]),static_cast<const float *>(mCudaBuffers[4]),static_cast<float *>(cudaOutputBuffer),
ouputSize,ouputSize,classNum,kernelSize,visThresh);
} else{
CTdetforward_gpu(static_cast<const float *>(mCudaBuffers[1]),static_cast<const float *>(mCudaBuffers[2]),
static_cast<const float *>(mCudaBuffers[3]),static_cast<float *>(cudaOutputBuffer),
ouputSize,ouputSize,classNum,kernelSize,visThresh);
}

CUDA_CHECK(cudaMemcpyAsync(outputData, cudaOutputBuffer, outputBufferSize, cudaMemcpyDeviceToHost, mCudaStream));

Expand Down
56 changes: 41 additions & 15 deletions src/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ dim3 cudaGridSize(uint n)
return d;
}

std::vector<float> prepareImage(cv::Mat& img)
std::vector<float> prepareImage(cv::Mat& img, const bool& forwardFace)
{
using namespace cv;

Expand All @@ -37,7 +37,10 @@ std::vector<float> prepareImage(cv::Mat& img)
resized.copyTo(cropped(rect));

cv::Mat img_float;
cropped.convertTo(img_float, CV_32FC3, 1/255.0);
if(forwardFace)
cropped.convertTo(img_float, CV_32FC3, 1.);
else
cropped.convertTo(img_float, CV_32FC3,1./255.);

//HWC TO CHW
vector<Mat> input_channels(channel);
Expand All @@ -55,9 +58,10 @@ std::vector<float> prepareImage(cv::Mat& img)
return result;
}

void postProcess(std::vector<Detection> & result,const cv::Mat& img)
void postProcess(std::vector<Detection> & result,const cv::Mat& img, const bool& forwardFace)
{
using namespace cv;
int mark;
int inputSize = ctdet::inputSize;
float scale = min(float(inputSize)/img.cols,float(inputSize)/img.rows);
float dx = (inputSize - scale * img.cols) / 2;
Expand All @@ -76,6 +80,19 @@ void postProcess(std::vector<Detection> & result,const cv::Mat& img)
item.bbox.y1 = y1 ;
item.bbox.x2 = x2 ;
item.bbox.y2 = y2 ;
if(forwardFace){
float x,y;
for(mark=0;mark<5; ++mark ){
x = (item.marks[mark].x - dx) / scale ;
y = (item.marks[mark].y - dy) / scale ;
x = (x > 0 ) ? x : 0 ;
y = (y > 0 ) ? y : 0 ;
x = (x < img.cols ) ? x : img.cols - 1 ;
y = (y < img.rows ) ? y : img.rows - 1 ;
item.marks[mark].x = x ;
item.marks[mark].y = y ;
}
}
}
}

Expand All @@ -84,24 +101,33 @@ cv::Scalar randomColor(cv::RNG& rng) {
return cv::Scalar(icolor & 255, (icolor >> 8) & 255, (icolor >> 16) & 255);
}

void drawImg(const std::vector<Detection> & result,cv::Mat& img,const std::vector<cv::Scalar>& color )
void drawImg(const std::vector<Detection> & result,cv::Mat& img,const std::vector<cv::Scalar>& color, const bool& forwardFace)
{
int mark;
int box_think = (img.rows+img.cols) * .001 ;
float label_scale = img.rows * 0.0009;
int base_line ;
for (const auto &item : result) {
std::string label;
std::stringstream stream;
stream << ctdet::className[item.classId] << " " << item.prob << std::endl;
std::getline(stream,label);
std::string label;
std::stringstream stream;
stream << ctdet::className[item.classId] << " " << item.prob << std::endl;
std::getline(stream,label);

auto size = cv::getTextSize(label,cv::FONT_HERSHEY_COMPLEX,label_scale,1,&base_line);
auto size = cv::getTextSize(label,cv::FONT_HERSHEY_COMPLEX,label_scale,1,&base_line);

cv::rectangle(img, cv::Point(item.bbox.x1,item.bbox.y1),
cv::Point(item.bbox.x2 ,item.bbox.y2),
color[item.classId], box_think, 8, 0);
if(!forwardFace){
cv::putText(img,label,
cv::Point(item.bbox.x2,item.bbox.y2 - size.height),
cv::FONT_HERSHEY_COMPLEX, label_scale , color[item.classId], box_think/3, 8, 0);
}
if(forwardFace)
{
for(mark=0;mark<5; ++mark )
cv::circle(img, cv::Point(item.marks[mark].x, item.marks[mark].y), 2, cv::Scalar(255, 255, 0), 2);
}

cv::rectangle(img, cv::Point(item.bbox.x1,item.bbox.y1),
cv::Point(item.bbox.x2 ,item.bbox.y2),
color[item.classId], box_think, 8, 0);
cv::putText(img,label,
cv::Point(item.bbox.x2,item.bbox.y2 - size.height),
cv::FONT_HERSHEY_COMPLEX, label_scale , color[item.classId], box_think/3, 8, 0);
}
}
Binary file added test.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 66dc0b2

Please sign in to comment.