diff --git a/yolov8/src/block.cpp b/yolov8/src/block.cpp index 43a694d6..6c6fb020 100644 --- a/yolov8/src/block.cpp +++ b/yolov8/src/block.cpp @@ -80,17 +80,31 @@ nvinfer1::IElementWiseLayer* convBnSiLU(nvinfer1::INetworkDefinition* network, std::map weightMap, nvinfer1::ITensor& input, int ch, int k, int s, int p, std::string lname) { nvinfer1::Weights bias_empty{nvinfer1::DataType::kFLOAT, nullptr, 0}; - nvinfer1::IConvolutionLayer* conv = - network->addConvolutionNd(input, ch, nvinfer1::DimsHW{k, k}, weightMap[lname + ".conv.weight"], bias_empty); + nvinfer1::IConvolutionLayer* conv{nullptr}; + std::string bias_name{lname + ".conv.bias"}; + // Compatibility whether there is a bias or not + // fuse conv+nb into conv, the new conv layer will have bias, and bn will disappear + if(weightMap.find(bias_name) == weightMap.end()){ + conv = network->addConvolutionNd(input, ch, nvinfer1::DimsHW{k, k}, weightMap[lname + ".conv.weight"], bias_empty); + }else{ + conv = network->addConvolutionNd(input, ch, nvinfer1::DimsHW{k, k}, weightMap[lname + ".conv.weight"], weightMap[bias_name]); + } assert(conv); conv->setStrideNd(nvinfer1::DimsHW{s, s}); conv->setPaddingNd(nvinfer1::DimsHW{p, p}); - nvinfer1::IScaleLayer* bn = addBatchNorm2d(network, weightMap, *conv->getOutput(0), lname + ".bn", 1e-3); + // Compatibility whether there is a bn or not + nvinfer1::ILayer *layer{nullptr}; + if(weightMap.find(std::string{lname + ".bn.weight"}) != weightMap.end()){ + layer = addBatchNorm2d(network, weightMap, *conv->getOutput(0), lname + ".bn", 1e-3); + }else{ + layer = conv; + } + nvinfer1::IActivationLayer* sigmoid = network->addActivation(*layer->getOutput(0), nvinfer1::ActivationType::kSIGMOID); - nvinfer1::IActivationLayer* sigmoid = network->addActivation(*bn->getOutput(0), nvinfer1::ActivationType::kSIGMOID); nvinfer1::IElementWiseLayer* ew = - network->addElementWise(*bn->getOutput(0), *sigmoid->getOutput(0), nvinfer1::ElementWiseOperation::kPROD); + network->addElementWise(*layer->getOutput(0), *sigmoid->getOutput(0), nvinfer1::ElementWiseOperation::kPROD); + assert(ew); return ew; }