Skip to content

Commit

Permalink
Compatibility whether yolov8 fuse or not
Browse files Browse the repository at this point in the history
  • Loading branch information
saladjay committed Sep 11, 2024
1 parent e254513 commit cc4d8ac
Showing 1 changed file with 19 additions and 5 deletions.
24 changes: 19 additions & 5 deletions yolov8/src/block.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,17 +74,31 @@ nvinfer1::IElementWiseLayer* convBnSiLU(nvinfer1::INetworkDefinition* network,
std::map<std::string, nvinfer1::Weights> weightMap, nvinfer1::ITensor& input,
int ch, int k, int s, int p, std::string lname) {
nvinfer1::Weights bias_empty{nvinfer1::DataType::kFLOAT, nullptr, 0};
nvinfer1::IConvolutionLayer* conv =
network->addConvolutionNd(input, ch, nvinfer1::DimsHW{k, k}, weightMap[lname + ".conv.weight"], bias_empty);
nvinfer1::IConvolutionLayer* conv{nullptr};
std::string bias_name{lname + ".conv.bias"};
// Compatibility whether there is a bias or not
// fuse conv+nb into conv, the new conv layer will have bias, and bn will disappear
if(weightMap.find(bias_name) == weightMap.end()){
conv = network->addConvolutionNd(input, ch, nvinfer1::DimsHW{k, k}, weightMap[lname + ".conv.weight"], bias_empty);
}else{
conv = network->addConvolutionNd(input, ch, nvinfer1::DimsHW{k, k}, weightMap[lname + ".conv.weight"], weightMap[bias_name]);
}
assert(conv);
conv->setStrideNd(nvinfer1::DimsHW{s, s});
conv->setPaddingNd(nvinfer1::DimsHW{p, p});

nvinfer1::IScaleLayer* bn = addBatchNorm2d(network, weightMap, *conv->getOutput(0), lname + ".bn", 1e-3);
// Compatibility whether there is a bn or not
nvinfer1::ILayer *layer{nullptr};
if(weightMap.find(std::string{lname + ".bn.weight"}) != weightMap.end()){
layer = addBatchNorm2d(network, weightMap, *conv->getOutput(0), lname + ".bn", 1e-3);
}else{
layer = conv;
}
nvinfer1::IActivationLayer* sigmoid = network->addActivation(*layer->getOutput(0), nvinfer1::ActivationType::kSIGMOID);

nvinfer1::IActivationLayer* sigmoid = network->addActivation(*bn->getOutput(0), nvinfer1::ActivationType::kSIGMOID);
nvinfer1::IElementWiseLayer* ew =
network->addElementWise(*bn->getOutput(0), *sigmoid->getOutput(0), nvinfer1::ElementWiseOperation::kPROD);
network->addElementWise(*layer->getOutput(0), *sigmoid->getOutput(0), nvinfer1::ElementWiseOperation::kPROD);

assert(ew);
return ew;
}
Expand Down

1 comment on commit cc4d8ac

@saladjay
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

兼容fuse后的yolov8的模型,fuse后bn层合并进去conv,convBnSiLU种少了一个bn,conv多了一个bias

Please sign in to comment.