diff --git a/ghostnet/README.md b/ghostnet/README.md new file mode 100644 index 00000000..40e9d54e --- /dev/null +++ b/ghostnet/README.md @@ -0,0 +1,82 @@ +# GhostNet + +GhostNetv1 architecture is from the paper "GhostNet: More Features from Cheap Operations" [(https://arxiv.org/abs/1911.11907)](https://arxiv.org/abs/1911.11907). +GhostNetv2 architecture is from the paper "GhostNetV2: Enhance Cheap Operation with Long-Range Attention" [(https://arxiv.org/abs/2211.12905)](https://arxiv.org/abs/2211.12905). + +For the PyTorch implementations, you can refer to [huawei-noah/ghostnet](https://github.com/huawei-noah/ghostnet). + +Both versions use the following techniques in their TensorRT implementations: + +- **BatchNorm** layer is implemented by TensorRT's **Scale** layer. +- **Ghost Modules** are used to generate more features from cheap operations, as described in the paper. +- Replacing `IPoolingLayer` with `IReduceLayer` in TensorRT for Global Average Pooling. The `IReduceLayer` allows you to perform reduction operations (such as sum, average, max) over specified dimensions without being constrained by the kernel size limitations of pooling layers. + +## Project Structure + +```plaintext +ghostnet +│ +├── ghostnetv1 +│ ├── CMakeLists.txt +│ ├── gen_wts.py +│ ├── ghostnetv1.cpp +│ └── logging.h +│ +├── ghostnetv2 +│ ├── CMakeLists.txt +│ ├── gen_wts.py +│ ├── ghostnetv2.cpp +│ └── logging.h +│ +└── README.md +``` + +## Steps to use GhostNet in TensorRT + +### 1. Generate `.wts` files for both GhostNetv1 and GhostNetv2 + +```bash +# For ghostnetv1 +python ghostnetv1/gen_wts.py + +# For ghostnetv2 +python ghostnetv2/gen_wts.py +``` + +### 2. Build the project + +```bash +cd tensorrtx/ghostnet +mkdir build +cd build +cmake .. +make +``` + +### 3. 
Serialize the models to engine files + +Use the following commands to serialize the PyTorch models into TensorRT engine files (`ghostnetv1.engine` and `ghostnetv2.engine`): + +```bash +# For ghostnetv1 +sudo ./ghostnetv1 -s + +# For ghostnetv2 +sudo ./ghostnetv2 -s +``` + +### 4. Run inference using the engine files + +Once the engine files are generated, you can run inference with the following commands: + +```bash +# For ghostnetv1 +sudo ./ghostnetv1 -d + +# For ghostnetv2 +sudo ./ghostnetv2 -d +``` + +### 5. Verify output + +Compare the output with the PyTorch implementation from [huawei-noah/ghostnet](https://github.com/huawei-noah/ghostnet) to ensure that the TensorRT results are consistent with the PyTorch model. diff --git a/ghostnet/ghostnetv1/CMakeLists.txt b/ghostnet/ghostnetv1/CMakeLists.txt new file mode 100644 index 00000000..ee62ad2c --- /dev/null +++ b/ghostnet/ghostnetv1/CMakeLists.txt @@ -0,0 +1,24 @@ +cmake_minimum_required(VERSION 2.6) + +project(ghostnetv1) + +add_definitions(-std=c++11) + +option(CUDA_USE_STATIC_CUDA_RUNTIME OFF) +set(CMAKE_CXX_STANDARD 11) +set(CMAKE_BUILD_TYPE Debug) + +include_directories(${PROJECT_SOURCE_DIR}/include) +# include and link dirs of cuda and tensorrt, you need adapt them if yours are different +# cuda +include_directories(/usr/local/cuda/include) +link_directories(/usr/local/cuda/lib64) +# tensorrt +include_directories(/usr/include/x86_64-linux-gnu/) +link_directories(/usr/lib/x86_64-linux-gnu/) + +add_executable(ghostnetv1 ${PROJECT_SOURCE_DIR}/ghostnetv1.cpp) +target_link_libraries(ghostnetv1 nvinfer) +target_link_libraries(ghostnetv1 cudart) + +add_definitions(-O2 -pthread) diff --git a/ghostnet/ghostnetv1/gen_wts.py b/ghostnet/ghostnetv1/gen_wts.py new file mode 100644 index 00000000..b3029329 --- /dev/null +++ b/ghostnet/ghostnetv1/gen_wts.py @@ -0,0 +1,292 @@ +""" +Creates a GhostNet Model as defined in: +GhostNet: More Features from Cheap Operations By Kai Han, Yunhe Wang, Qi Tian, Jianyuan Guo, Chunjing 
Xu, Chang Xu. +https://arxiv.org/abs/1911.11907 +Modified from https://github.com/d-li14/mobilenetv3.pytorch and https://github.com/rwightman/pytorch-image-models +""" +import torch +import torch.nn as nn +import torch.onnx +import struct +import torch +import torch.nn.functional as F +import math + + +def _make_divisible(v, divisor, min_value=None): + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +def hard_sigmoid(x, inplace: bool = False): + if inplace: + return x.add_(3.).clamp_(0., 6.).div_(6.) + else: + return F.relu6(x + 3.) / 6. + + +class SqueezeExcite(nn.Module): + def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None, + act_layer=nn.ReLU, gate_fn=hard_sigmoid, divisor=4, **_): + super(SqueezeExcite, self).__init__() + self.gate_fn = gate_fn + reduced_chs = _make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor) + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True) + self.act1 = act_layer(inplace=True) + self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True) + + def forward(self, x): + x_se = self.avg_pool(x) + x_se = self.conv_reduce(x_se) + x_se = self.act1(x_se) + x_se = self.conv_expand(x_se) + x = x * self.gate_fn(x_se) + return x + + +class ConvBnAct(nn.Module): + def __init__(self, in_chs, out_chs, kernel_size, + stride=1, act_layer=nn.ReLU): + super(ConvBnAct, self).__init__() + self.conv = nn.Conv2d(in_chs, out_chs, kernel_size, stride, kernel_size//2, bias=False) + self.bn1 = nn.BatchNorm2d(out_chs) + self.act1 = 
class GhostModule(nn.Module):
    """Ghost module: a primary convolution plus a cheap depthwise convolution.

    The primary branch produces ``ceil(oup / ratio)`` feature maps. The cheap
    branch derives ``ratio - 1`` extra maps per primary map with a depthwise
    convolution. Both are concatenated on the channel axis and cut back to
    exactly ``oup`` channels.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        kernel_size: Kernel size of the primary convolution.
        ratio: Total maps produced per primary map.
        dw_size: Kernel size of the depthwise (cheap) convolution.
        stride: Stride of the primary convolution.
        relu: When True, each branch ends with a ReLU; otherwise with an
            empty ``nn.Sequential`` acting as identity.
    """

    def __init__(self, inp, oup, kernel_size=1, ratio=2, dw_size=3, stride=1, relu=True):
        super().__init__()
        self.oup = oup
        primary_chs = math.ceil(oup / ratio)
        ghost_chs = primary_chs * (ratio - 1)

        def activation():
            # The empty Sequential keeps slot 2 occupied when no ReLU is wanted.
            if relu:
                return nn.ReLU(inplace=True)
            return nn.Sequential()

        # Submodule order fixes the state_dict keys ("primary_conv.0.weight",
        # "primary_conv.1.*", ...) that the exported weight file uses.
        self.primary_conv = nn.Sequential(
            nn.Conv2d(inp, primary_chs, kernel_size, stride=stride,
                      padding=kernel_size // 2, bias=False),
            nn.BatchNorm2d(primary_chs),
            activation(),
        )
        self.cheap_operation = nn.Sequential(
            nn.Conv2d(primary_chs, ghost_chs, dw_size, stride=1,
                      padding=dw_size // 2, groups=primary_chs, bias=False),
            nn.BatchNorm2d(ghost_chs),
            activation(),
        )

    def forward(self, x):
        """Return the concatenated primary and ghost maps, cut to ``oup`` channels."""
        primary = self.primary_conv(x)
        ghost = self.cheap_operation(primary)
        stacked = torch.cat((primary, ghost), dim=1)
        return stacked[:, :self.oup]
+ self.stride = stride + + # Point-wise expansion + self.ghost1 = GhostModule(in_chs, mid_chs, relu=True) + + # Depth-wise convolution + if self.stride > 1: + self.conv_dw = nn.Conv2d(mid_chs, mid_chs, dw_kernel_size, stride=stride, + padding=(dw_kernel_size-1)//2, groups=mid_chs, bias=False) + self.bn_dw = nn.BatchNorm2d(mid_chs) + + # Squeeze-and-excitation + if has_se: + self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio) + else: + self.se = None + + # Point-wise linear projection + self.ghost2 = GhostModule(mid_chs, out_chs, relu=False) + + # shortcut + if (in_chs == out_chs and self.stride == 1): + self.shortcut = nn.Sequential() + else: + self.shortcut = nn.Sequential( + nn.Conv2d(in_chs, in_chs, dw_kernel_size, stride=stride, + padding=(dw_kernel_size-1)//2, groups=in_chs, bias=False), + nn.BatchNorm2d(in_chs), + nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(out_chs), + ) + + def forward(self, x): + residual = x + + # 1st ghost bottleneck + x = self.ghost1(x) + + # Depth-wise convolution + if self.stride > 1: + x = self.conv_dw(x) + x = self.bn_dw(x) + + # Squeeze-and-excitation + if self.se is not None: + x = self.se(x) + + # 2nd ghost bottleneck + x = self.ghost2(x) + + x += self.shortcut(residual) + return x + + +class GhostNet(nn.Module): + def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2): + super(GhostNet, self).__init__() + # setting of inverted residual blocks + self.cfgs = cfgs + self.dropout = dropout + + # building first layer + output_channel = _make_divisible(16 * width, 4) + self.conv_stem = nn.Conv2d(3, output_channel, 3, 2, 1, bias=False) + self.bn1 = nn.BatchNorm2d(output_channel) + self.act1 = nn.ReLU(inplace=True) + input_channel = output_channel + + # building inverted residual blocks + stages = [] + block = GhostBottleneck + for cfg in self.cfgs: + layers = [] + for k, exp_size, c, se_ratio, s in cfg: + output_channel = _make_divisible(c * width, 4) + hidden_channel = 
# Function to export weights in the specified format
def export_weight(model, path="ghostnetv1.weights"):
    """Write ``model.state_dict()`` to a plain-text weight file.

    File layout::

        <number of tensors>
        <name> <element count> <hex> <hex> ...

    Each value is the hex form of its big-endian IEEE-754 float32 encoding,
    so a reader that parses the token as a hexadecimal 32-bit integer gets
    the float's bit pattern back.

    Args:
        model: Module whose ``state_dict()`` is exported. Every entry is
            flattened and converted with ``float()``.
        path: Output file. The default is the name the C++ side loads.
    """
    state = model.state_dict()
    # The context manager closes the file even if a tensor fails to convert.
    with open(path, 'w') as f:
        f.write("{}\n".format(len(state)))

        for k, v in state.items():
            print('exporting ... {}: {}'.format(k, v.shape))

            # Flatten to 1D on the CPU so the values can be iterated.
            vr = v.reshape(-1).cpu().numpy()
            f.write("{} {}".format(k, len(vr)))
            f.write("".join(" " + struct.pack(">f", float(vv)).hex() for vv in vr))
            f.write("\n")


# Function to evaluate the model (optional)
def eval_model(input, model):
    """Run one forward pass, print input and output, and return the output.

    Args:
        input: Tensor passed to ``model``.
        model: Callable module. The caller sets train/eval mode.

    Returns:
        The model output, computed without autograd tracking.
    """
    # Only reference numbers are needed, so skip building the autograd graph.
    with torch.no_grad():
        output = model(input)
    print("------from inference------")
    print(input)
    print(output)
    return output


if __name__ == "__main__":
    setup_seed(1)

    model = ghostnet(num_classes=1000, width=1.0, dropout=0.2)

    model.eval()

    # NCHW. Must match batchSize=32, INPUT_H=256, INPUT_W=320 in
    # ghostnetv1.cpp, otherwise the printed reference output is for a
    # different input shape than the TensorRT engine.
    dummy_input = torch.full((32, 3, 256, 320), 10.0)

    export_weight(model)

    eval_model(dummy_input, model)
<< std::endl; + exit(EXIT_FAILURE); + } + + // Read number of weight blobs + int32_t count; + input >> count; + if (count <= 0) { + std::cerr << "Invalid weight map file." << std::endl; + exit(EXIT_FAILURE); + } + + while (count--) { + Weights wt{DataType::kFLOAT, nullptr, 0}; + uint32_t size; + + // Read name and type of blob + std::string name; + input >> name >> std::dec >> size; + wt.type = DataType::kFLOAT; + + // Load blob + uint32_t* val = reinterpret_cast(malloc(sizeof(uint32_t) * size)); + for (uint32_t x = 0, y = size; x < y; ++x) { + input >> std::hex >> val[x]; + } + wt.values = val; + + wt.count = size; + weightMap[name] = wt; + } + + return weightMap; +} + +int _make_divisible(int v, int divisor, int min_value = -1) { + if (min_value == -1) { + min_value = divisor; + } + + int new_v = std::max(min_value, (v + divisor / 2) / divisor * divisor); + + if (new_v < static_cast(0.9 * v)) { + new_v += divisor; + } + + return new_v; +} + +ILayer* hardSigmoid(INetworkDefinition* network, ITensor& input) { + + IActivationLayer* scale_layer = network->addActivation(input, ActivationType::kHARD_SIGMOID); + + return scale_layer; +} + +IScaleLayer* addBatchNorm2d(INetworkDefinition* network, std::map& weightMap, ITensor& input, + std::string lname, float eps) { + float* gamma = (float*)weightMap[lname + ".weight"].values; + float* beta = (float*)weightMap[lname + ".bias"].values; + float* mean = (float*)weightMap[lname + ".running_mean"].values; + float* var = (float*)weightMap[lname + ".running_var"].values; + int len = weightMap[lname + ".running_var"].count; + + float* scval = reinterpret_cast(malloc(sizeof(float) * len)); + for (int i = 0; i < len; i++) { + scval[i] = gamma[i] / sqrt(var[i] + eps); + } + Weights scale{DataType::kFLOAT, scval, len}; + + float* shval = reinterpret_cast(malloc(sizeof(float) * len)); + for (int i = 0; i < len; i++) { + shval[i] = beta[i] - mean[i] * gamma[i] / sqrt(var[i] + eps); + } + Weights shift{DataType::kFLOAT, shval, len}; + 
+ float* pval = reinterpret_cast(malloc(sizeof(float) * len)); + for (int i = 0; i < len; i++) { + pval[i] = 1.0; + } + Weights power{DataType::kFLOAT, pval, len}; + + weightMap[lname + ".scale"] = scale; + weightMap[lname + ".shift"] = shift; + weightMap[lname + ".power"] = power; + IScaleLayer* scale_1 = network->addScale(input, ScaleMode::kCHANNEL, shift, scale, power); + assert(scale_1); + return scale_1; +} + +IActivationLayer* convBnReluStem(INetworkDefinition* network, std::map& weightMap, ITensor& input, + int outch, std::string lname) { + Weights emptywts{DataType::kFLOAT, nullptr, 0}; + + IConvolutionLayer* conv1 = + network->addConvolutionNd(input, outch, DimsHW{3, 3}, weightMap[lname + ".weight"], emptywts); + assert(conv1); + conv1->setStrideNd(DimsHW{2, 2}); // Stride = 2 + conv1->setPaddingNd(DimsHW{1, 1}); // Padding = 1 + + IScaleLayer* bn1 = addBatchNorm2d(network, weightMap, *conv1->getOutput(0), "bn1", 1e-5); + + IActivationLayer* relu1 = network->addActivation(*bn1->getOutput(0), ActivationType::kRELU); + assert(relu1); + + return relu1; +} + +ILayer* convBnAct(INetworkDefinition* network, std::map& weightMap, ITensor& input, + int out_channels, std::string lname, ActivationType actType = ActivationType::kRELU) { + Weights emptywts{DataType::kFLOAT, nullptr, 0}; + + IConvolutionLayer* conv = + network->addConvolutionNd(input, out_channels, DimsHW{1, 1}, weightMap[lname + ".conv.weight"], emptywts); + assert(conv); + conv->setStrideNd(DimsHW{1, 1}); + + IScaleLayer* bn = addBatchNorm2d(network, weightMap, *conv->getOutput(0), lname + ".bn1", 1e-5); + + IActivationLayer* act = network->addActivation(*bn->getOutput(0), actType); + assert(act); + + return act; +} + +ILayer* squeezeExcite(INetworkDefinition* network, ITensor& input, std::map& weightMap, + int in_chs, float se_ratio = 0.25, std::string lname = "", float eps = 1e-5) { + + IReduceLayer* avg_pool = network->addReduce(input, ReduceOperation::kAVG, 1 << 2 | 1 << 3, true); + 
assert(avg_pool); + + // Reduce channels with 1x1 convolution + int reduced_chs = _make_divisible(static_cast(in_chs * se_ratio), 4); + IConvolutionLayer* conv_reduce = + network->addConvolutionNd(*avg_pool->getOutput(0), reduced_chs, DimsHW{1, 1}, + weightMap[lname + ".conv_reduce.weight"], weightMap[lname + ".conv_reduce.bias"]); + assert(conv_reduce); + + IActivationLayer* relu1 = network->addActivation(*conv_reduce->getOutput(0), ActivationType::kRELU); + assert(relu1); + + // Expand channels back with another 1x1 convolution + IConvolutionLayer* conv_expand = + network->addConvolutionNd(*relu1->getOutput(0), in_chs, DimsHW{1, 1}, + weightMap[lname + ".conv_expand.weight"], weightMap[lname + ".conv_expand.bias"]); + assert(conv_expand); + cout << "SE conv_expand -> " << printTensorShape(conv_expand->getOutput(0)) << endl; + + // Apply hardSigmoid function + ILayer* hard_sigmoid = hardSigmoid(network, *conv_expand->getOutput(0)); + cout << "hard_sigmoid conv_expand -> " << printTensorShape(hard_sigmoid->getOutput(0)) << endl; + + // Elementwise multiplication of input and gated SE output + IElementWiseLayer* scale = network->addElementWise(input, *hard_sigmoid->getOutput(0), ElementWiseOperation::kPROD); + assert(scale); + + return scale; +} + +ILayer* ghostModule(INetworkDefinition* network, ITensor& input, std::map& weightMap, int inp, + int oup, int kernel_size = 1, int ratio = 2, int dw_size = 3, int stride = 1, bool relu = true, + std::string lname = "") { + int init_channels = std::ceil(oup / ratio); + int new_channels = init_channels * (ratio - 1); + + // Primary convolution + IConvolutionLayer* primary_conv = network->addConvolutionNd(input, init_channels, DimsHW{kernel_size, kernel_size}, + weightMap[lname + ".primary_conv.0.weight"], Weights{}); + primary_conv->setStrideNd(DimsHW{stride, stride}); + primary_conv->setPaddingNd(DimsHW{kernel_size / 2, kernel_size / 2}); + IScaleLayer* bn1 = addBatchNorm2d(network, weightMap, *primary_conv->getOutput(0), 
lname + ".primary_conv.1", 1e-5); + + // Cheap operation (Depthwise Convolution) + IConvolutionLayer* cheap_conv = + network->addConvolutionNd(*bn1->getOutput(0), new_channels, DimsHW{dw_size, dw_size}, + weightMap[lname + ".cheap_operation.0.weight"], Weights{}); + cheap_conv->setStrideNd(DimsHW{1, 1}); + cheap_conv->setPaddingNd(DimsHW{dw_size / 2, dw_size / 2}); + cheap_conv->setNbGroups(init_channels); + IScaleLayer* bn2 = + addBatchNorm2d(network, weightMap, *cheap_conv->getOutput(0), lname + ".cheap_operation.1", 1e-5); + + // Define relu1 and relu2 + IActivationLayer* relu1 = network->addActivation(*bn1->getOutput(0), ActivationType::kRELU); + IActivationLayer* relu2 = network->addActivation(*bn2->getOutput(0), ActivationType::kRELU); + + // Initialize inputs array based on the `relu` flag + std::vector inputs_vec; + if (relu) { + inputs_vec = {relu1->getOutput(0), relu2->getOutput(0)}; + } else { + inputs_vec = {bn1->getOutput(0), bn2->getOutput(0)}; + } + + ITensor* inputs[] = {inputs_vec[0], inputs_vec[1]}; + IConcatenationLayer* concat = network->addConcatenation(inputs, 2); + std::cout << printTensorShape(concat->getOutput(0)) << std::endl; + + // Slice the output to keep only the first `oup` channels + Dims start{4, {0, 0, 0, 0}}; // Starting from batch=0, channel=0, height=0, width=0 + Dims size{4, + {concat->getOutput(0)->getDimensions().d[0], oup, concat->getOutput(0)->getDimensions().d[2], + concat->getOutput(0) + ->getDimensions() + .d[3]}}; // Keep all batches, first `oup` channels, all heights and widths + Dims stride_{4, {1, 1, 1, 1}}; // Stride is 1 for all dimensions + + ISliceLayer* slice = network->addSlice(*concat->getOutput(0), start, size, stride_); + cout << "slice" << printTensorShape(slice->getOutput(0)) << endl; + + return slice; +} + +ILayer* ghostBottleneck(INetworkDefinition* network, ITensor& input, std::map& weightMap, + int in_chs, int mid_chs, int out_chs, int dw_kernel_size = 3, int stride = 1, + float se_ratio = 0.0f, 
std::string lname = "") { + ILayer* ghost1 = ghostModule(network, input, weightMap, in_chs, mid_chs, 1, 2, 3, 1, true, lname + ".ghost1"); + + ILayer* depthwise_conv = ghost1; + if (stride > 1) { + IConvolutionLayer* conv_dw = + network->addConvolutionNd(*ghost1->getOutput(0), mid_chs, DimsHW{dw_kernel_size, dw_kernel_size}, + weightMap[lname + ".conv_dw.weight"], Weights{}); + conv_dw->setStrideNd(DimsHW{stride, stride}); + conv_dw->setPaddingNd(DimsHW{(dw_kernel_size - 1) / 2, (dw_kernel_size - 1) / 2}); + conv_dw->setNbGroups(mid_chs); // Depth-wise convolution + IScaleLayer* bn_dw = addBatchNorm2d(network, weightMap, *conv_dw->getOutput(0), lname + ".bn_dw", 1e-5); + depthwise_conv = bn_dw; + } + + ILayer* se_layer = depthwise_conv; + if (se_ratio > 0.0f) { + se_layer = squeezeExcite(network, *depthwise_conv->getOutput(0), weightMap, mid_chs, se_ratio, lname + ".se"); + } + + ILayer* ghost2 = ghostModule(network, *se_layer->getOutput(0), weightMap, mid_chs, out_chs, 1, 2, 3, 1, false, + lname + ".ghost2"); + + ILayer* shortcut_layer = nullptr; + if (in_chs == out_chs && stride == 1) { + shortcut_layer = network->addIdentity(input); + } else { + IConvolutionLayer* conv_shortcut_dw = + network->addConvolutionNd(input, in_chs, DimsHW{dw_kernel_size, dw_kernel_size}, + weightMap[lname + ".shortcut.0.weight"], Weights{}); + + conv_shortcut_dw->setStrideNd(DimsHW{stride, stride}); + conv_shortcut_dw->setPaddingNd(DimsHW{(dw_kernel_size - 1) / 2, (dw_kernel_size - 1) / 2}); + conv_shortcut_dw->setNbGroups(in_chs); // Depth-wise convolution + IScaleLayer* bn_shortcut_dw = + addBatchNorm2d(network, weightMap, *conv_shortcut_dw->getOutput(0), lname + ".shortcut.1", 1e-5); + + IConvolutionLayer* conv_shortcut_pw = + network->addConvolutionNd(*bn_shortcut_dw->getOutput(0), out_chs, DimsHW{1, 1}, + weightMap[lname + ".shortcut.2.weight"], Weights{}); + IScaleLayer* bn_shortcut_pw = + addBatchNorm2d(network, weightMap, *conv_shortcut_pw->getOutput(0), lname + ".shortcut.3", 
1e-5); + shortcut_layer = bn_shortcut_pw; + } + + IElementWiseLayer* ew_sum = + network->addElementWise(*ghost2->getOutput(0), *shortcut_layer->getOutput(0), ElementWiseOperation::kSUM); + + return ew_sum; +} + +ICudaEngine* createEngine(IBuilder* builder, IBuilderConfig* config, DataType dt) { + + INetworkDefinition* network = + builder->createNetworkV2(1U << static_cast(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH)); + + // Create input tensor of shape {batchSize, 3, INPUT_H, INPUT_W} with name INPUT_BLOB_NAME + ITensor* data = network->addInput(INPUT_BLOB_NAME, dt, Dims4{batchSize, 3, INPUT_H, INPUT_W}); + assert(data); + + std::map weightMap = loadWeights("../ghostnetv1.weights"); + Weights emptywts{DataType::kFLOAT, nullptr, 0}; + + // Conv Stem + IActivationLayer* conv_stem = convBnReluStem(network, weightMap, *data, 16, "conv_stem"); + + ILayer* current_layer = conv_stem; + current_layer = + ghostBottleneck(network, *current_layer->getOutput(0), weightMap, 16, 16, 16, 3, 1, 0, "blocks.0.0"); + current_layer = + ghostBottleneck(network, *current_layer->getOutput(0), weightMap, 16, 48, 24, 3, 2, 0, "blocks.1.0"); + current_layer = + ghostBottleneck(network, *current_layer->getOutput(0), weightMap, 24, 72, 24, 3, 1, 0, "blocks.2.0"); + current_layer = + ghostBottleneck(network, *current_layer->getOutput(0), weightMap, 24, 72, 40, 5, 2, 0.25, "blocks.3.0"); + current_layer = + ghostBottleneck(network, *current_layer->getOutput(0), weightMap, 40, 120, 40, 5, 1, 0.25, "blocks.4.0"); + current_layer = + ghostBottleneck(network, *current_layer->getOutput(0), weightMap, 40, 240, 80, 3, 2, 0, "blocks.5.0"); + current_layer = + ghostBottleneck(network, *current_layer->getOutput(0), weightMap, 80, 200, 80, 3, 1, 0, "blocks.6.0"); + current_layer = + ghostBottleneck(network, *current_layer->getOutput(0), weightMap, 80, 184, 80, 3, 1, 0, "blocks.6.1"); + current_layer = + ghostBottleneck(network, *current_layer->getOutput(0), weightMap, 80, 184, 80, 3, 1, 0, 
"blocks.6.2"); + current_layer = + ghostBottleneck(network, *current_layer->getOutput(0), weightMap, 80, 480, 112, 3, 1, 0.25, "blocks.6.3"); + current_layer = + ghostBottleneck(network, *current_layer->getOutput(0), weightMap, 112, 672, 112, 3, 1, 0.25, "blocks.6.4"); + current_layer = + ghostBottleneck(network, *current_layer->getOutput(0), weightMap, 112, 672, 160, 5, 2, 0.25, "blocks.7.0"); + current_layer = + ghostBottleneck(network, *current_layer->getOutput(0), weightMap, 160, 960, 160, 5, 1, 0, "blocks.8.0"); + current_layer = + ghostBottleneck(network, *current_layer->getOutput(0), weightMap, 160, 960, 160, 5, 1, 0.25, "blocks.8.1"); + current_layer = + ghostBottleneck(network, *current_layer->getOutput(0), weightMap, 160, 960, 160, 5, 1, 0, "blocks.8.2"); + current_layer = + ghostBottleneck(network, *current_layer->getOutput(0), weightMap, 160, 960, 160, 5, 1, 0.25, "blocks.8.3"); + + // Apply ConvBnAct + current_layer = convBnAct(network, weightMap, *current_layer->getOutput(0), 960, "blocks.9.0"); + // Global Average Pooling + IReduceLayer* global_pool = + network->addReduce(*current_layer->getOutput(0), ReduceOperation::kAVG, 1 << 2 | 1 << 3, true); + assert(global_pool); + + // Conv Head + IConvolutionLayer* conv_head = network->addConvolutionNd( + *global_pool->getOutput(0), 1280, DimsHW{1, 1}, weightMap["conv_head.weight"], weightMap["conv_head.bias"]); + IActivationLayer* act2 = network->addActivation(*conv_head->getOutput(0), ActivationType::kRELU); + + // Fully Connected Layer (Classifier) + IFullyConnectedLayer* classifier = network->addFullyConnected( + *act2->getOutput(0), 1000, weightMap["classifier.weight"], weightMap["classifier.bias"]); + classifier->getOutput(0)->setName(OUTPUT_BLOB_NAME); + network->markOutput(*classifier->getOutput(0)); + + // Build engine + config->setMaxWorkspaceSize(1 << 24); + ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config); + + // Don't need the network any more + network->destroy(); + + // 
Release host memory + for (auto& mem : weightMap) { + free((void*)(mem.second.values)); + } + + return engine; +} + +void APIToModel(IHostMemory** modelStream) { + // Create builder + IBuilder* builder = createInferBuilder(gLogger); + IBuilderConfig* config = builder->createBuilderConfig(); + + // Create model to populate the network, then set the outputs and create an engine + ICudaEngine* engine = createEngine(builder, config, DataType::kFLOAT); + assert(engine != nullptr); + + // Serialize the engine + (*modelStream) = engine->serialize(); + + // Close everything down + engine->destroy(); + config->destroy(); + builder->destroy(); +} + +void doInference(IExecutionContext& context, float* input, float* output, int batchSize) { + const ICudaEngine& engine = context.getEngine(); + + const int inputIndex = engine.getBindingIndex(INPUT_BLOB_NAME); + const int outputIndex = engine.getBindingIndex(OUTPUT_BLOB_NAME); + + // Pointers to input and output device buffers to pass to engine. + void* buffers[2]; + + // Create GPU buffers on device + CHECK(cudaMalloc(&buffers[inputIndex], batchSize * 3 * INPUT_H * INPUT_W * sizeof(float))); + CHECK(cudaMalloc(&buffers[outputIndex], batchSize * OUTPUT_SIZE * sizeof(float))); + + // Create stream + cudaStream_t stream; + CHECK(cudaStreamCreate(&stream)); + + // DMA input batch data to device, infer on the batch asynchronously, and DMA output back to host + CHECK(cudaMemcpyAsync(buffers[inputIndex], input, batchSize * 3 * INPUT_H * INPUT_W * sizeof(float), + cudaMemcpyHostToDevice, stream)); + context.enqueueV2(buffers, stream, nullptr); + CHECK(cudaMemcpyAsync(output, buffers[outputIndex], batchSize * OUTPUT_SIZE * sizeof(float), cudaMemcpyDeviceToHost, + stream)); + cudaStreamSynchronize(stream); + + // Release stream and buffers + cudaStreamDestroy(stream); + CHECK(cudaFree(buffers[inputIndex])); + CHECK(cudaFree(buffers[outputIndex])); +} + +int main(int argc, char** argv) { + if (argc != 2) { + std::cerr << "arguments not 
right!" << std::endl; + std::cerr << "./ghostnetv1 -s // serialize model to plan file" << std::endl; + std::cerr << "./ghostnetv1 -d // deserialize plan file and run inference" << std::endl; + return -1; + } + + // create a model using the API directly and serialize it to a stream + char* trtModelStream{nullptr}; + size_t size{0}; + + if (std::string(argv[1]) == "-s") { + IHostMemory* modelStream{nullptr}; + APIToModel(&modelStream); + assert(modelStream != nullptr); + + std::ofstream p("ghostnetv1.engine", std::ios::binary); + if (!p) { + std::cerr << "could not open plan output file" << std::endl; + return -1; + } + p.write(reinterpret_cast(modelStream->data()), modelStream->size()); + modelStream->destroy(); + return 0; + } else if (std::string(argv[1]) == "-d") { + std::ifstream file("ghostnetv1.engine", std::ios::binary); + if (file.good()) { + file.seekg(0, file.end); + size = file.tellg(); + file.seekg(0, file.beg); + trtModelStream = new char[size]; + assert(trtModelStream); + file.read(trtModelStream, size); + file.close(); + } + } else { + return -1; + } + + float* data = new float[batchSize * 3 * INPUT_H * INPUT_W]; + for (int i = 0; i < batchSize * 3 * INPUT_H * INPUT_W; i++) + data[i] = 10.0; + + float* prob = new float[batchSize * OUTPUT_SIZE]; + + IRuntime* runtime = createInferRuntime(gLogger); + assert(runtime != nullptr); + ICudaEngine* engine = runtime->deserializeCudaEngine(trtModelStream, size, nullptr); + assert(engine != nullptr); + IExecutionContext* context = engine->createExecutionContext(); + assert(context != nullptr); + delete[] trtModelStream; + + doInference(*context, data, prob, batchSize); + + std::cout << "\nOutput:\n\n"; + for (int i = 0; i < batchSize; i++) { + std::cout << "Batch " << i << ":\n"; + for (unsigned int j = 0; j < OUTPUT_SIZE; j++) { + std::cout << prob[i * OUTPUT_SIZE + j] << ", "; + if (j % 10 == 0) + std::cout << j / 10 << std::endl; + } + std::cout << "\n"; + } + + context->destroy(); + engine->destroy(); + 
//!
//! \class LogStreamConsumerBuffer
//! \brief String buffer that, when synced, writes its contents to a target
//!        stream as "[MM/DD/YYYY-HH:MM:SS] <prefix><message>".
//!
//! Nothing is written, and the buffer is left untouched, while logging is
//! disabled via the constructor flag or setShouldLog(false).
//!
class LogStreamConsumerBuffer : public std::stringbuf {
   public:
    //! \param stream    Stream that receives the formatted output.
    //! \param prefix    Text inserted between the timestamp and the message.
    //! \param shouldLog Whether output is emitted at all.
    LogStreamConsumerBuffer(std::ostream& stream, const std::string& prefix, bool shouldLog)
        : mOutput(stream), mPrefix(prefix), mShouldLog(shouldLog) {}

    //! Carries over the target stream, prefix and logging flag. Buffered
    //! text is not transferred. All three members must be initialized here,
    //! otherwise mShouldLog would be read uninitialized and the prefix lost.
    LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other)
        : mOutput(other.mOutput), mPrefix(other.mPrefix), mShouldLog(other.mShouldLog) {}

    ~LogStreamConsumerBuffer() {
        // std::streambuf::pbase() gives a pointer to the beginning of the buffered part of the output sequence
        // std::streambuf::pptr() gives a pointer to the current position of the output sequence
        // if the two differ there is unflushed text, so log it before the buffer goes away
        if (pbase() != pptr()) {
            putOutput();
        }
    }

    // synchronizes the stream buffer and returns 0 on success
    // synchronizing the stream buffer consists of inserting the buffer contents into the stream,
    // resetting the buffer and flushing the stream
    int sync() override {
        putOutput();
        return 0;
    }

    //! Writes timestamp, prefix and buffered text to the target stream, then
    //! empties the buffer and flushes. Does nothing when logging is disabled.
    void putOutput() {
        if (!mShouldLog) {
            return;
        }
        // prepend timestamp, on the same stream as the message so the two
        // cannot end up on different destinations (e.g. stdout vs stderr)
        std::time_t timestamp = std::time(nullptr);
        std::tm* tm_local = std::localtime(&timestamp);
        // the fill character is sticky; restore it so the caller's stream formatting is unaffected
        const char previousFill = mOutput.fill('0');
        mOutput << "[";
        mOutput << std::setw(2) << 1 + tm_local->tm_mon << "/";
        mOutput << std::setw(2) << tm_local->tm_mday << "/";
        mOutput << std::setw(4) << 1900 + tm_local->tm_year << "-";
        mOutput << std::setw(2) << tm_local->tm_hour << ":";
        mOutput << std::setw(2) << tm_local->tm_min << ":";
        mOutput << std::setw(2) << tm_local->tm_sec << "] ";
        mOutput.fill(previousFill);
        // std::stringbuf::str() gets the string contents of the buffer
        // insert the buffer contents pre-appended by the appropriate prefix into the stream
        mOutput << mPrefix << str();
        // set the buffer to empty
        str("");
        // flush the stream
        mOutput.flush();
    }

    void setShouldLog(bool shouldLog) { mShouldLog = shouldLog; }

   private:
    std::ostream& mOutput;  // destination stream; not owned
    std::string mPrefix;    // severity tag written before each message
    bool mShouldLog;        // when false, putOutput() is a no-op
};
This is necessary to prevent the address of an uninitialized buffer from being passed to std::ostream. +//! Please do not change the order of the parent classes. +//! +class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream { + public: + //! \brief Creates a LogStreamConsumer which logs messages with level severity. + //! Reportable severity determines if the messages are severe enough to be logged. + LogStreamConsumer(Severity reportableSeverity, Severity severity) + : LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity), + std::ostream(&mBuffer) // links the stream buffer with the stream + , + mShouldLog(severity <= reportableSeverity), + mSeverity(severity) {} + + LogStreamConsumer(LogStreamConsumer&& other) + : LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog), + std::ostream(&mBuffer) // links the stream buffer with the stream + , + mShouldLog(other.mShouldLog), + mSeverity(other.mSeverity) {} + + void setReportableSeverity(Severity reportableSeverity) { + mShouldLog = mSeverity <= reportableSeverity; + mBuffer.setShouldLog(mShouldLog); + } + + private: + static std::ostream& severityOstream(Severity severity) { + return severity >= Severity::kINFO ? std::cout : std::cerr; + } + + static std::string severityPrefix(Severity severity) { + switch (severity) { + case Severity::kINTERNAL_ERROR: + return "[F] "; + case Severity::kERROR: + return "[E] "; + case Severity::kWARNING: + return "[W] "; + case Severity::kINFO: + return "[I] "; + case Severity::kVERBOSE: + return "[V] "; + default: + assert(0); + return ""; + } + } + + bool mShouldLog; + Severity mSeverity; +}; + +//! \class Logger +//! +//! \brief Class which manages logging of TensorRT tools and samples +//! +//! \details This class provides a common interface for TensorRT tools and samples to log information to the console, +//! 
and supports logging two types of messages: +//! +//! - Debugging messages with an associated severity (info, warning, error, or internal error/fatal) +//! - Test pass/fail messages +//! +//! The advantage of having all samples use this class for logging as opposed to emitting directly to stdout/stderr is +//! that the logic for controlling the verbosity and formatting of sample output is centralized in one location. +//! +//! In the future, this class could be extended to support dumping test results to a file in some standard format +//! (for example, JUnit XML), and providing additional metadata (e.g. timing the duration of a test run). +//! +//! TODO: For backwards compatibility with existing samples, this class inherits directly from the nvinfer1::ILogger +//! interface, which is problematic since there isn't a clean separation between messages coming from the TensorRT +//! library and messages coming from the sample. +//! +//! In the future (once all samples are updated to use Logger::getTRTLogger() to access the ILogger) we can refactor the +//! class to eliminate the inheritance and instead make the nvinfer1::ILogger implementation a member of the Logger +//! object. + +class Logger : public nvinfer1::ILogger { + public: + Logger(Severity severity = Severity::kWARNING) : mReportableSeverity(severity) {} + + //! + //! \enum TestResult + //! \brief Represents the state of a given test + //! + enum class TestResult { + kRUNNING, //!< The test is running + kPASSED, //!< The test passed + kFAILED, //!< The test failed + kWAIVED //!< The test was waived + }; + + //! + //! \brief Forward-compatible method for retrieving the nvinfer::ILogger associated with this Logger + //! \return The nvinfer1::ILogger associated with this Logger + //! + //! TODO Once all samples are updated to use this method to register the logger with TensorRT, + //! we can eliminate the inheritance of Logger from ILogger + //! + nvinfer1::ILogger& getTRTLogger() { return *this; } + + //! 
+ //! \brief Implementation of the nvinfer1::ILogger::log() virtual method + //! + //! Note samples should not be calling this function directly; it will eventually go away once we eliminate the + //! inheritance from nvinfer1::ILogger + //! + void log(Severity severity, const char* msg) noexcept override { + LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl; + } + + //! + //! \brief Method for controlling the verbosity of logging output + //! + //! \param severity The logger will only emit messages that have severity of this level or higher. + //! + void setReportableSeverity(Severity severity) { mReportableSeverity = severity; } + + //! + //! \brief Opaque handle that holds logging information for a particular test + //! + //! This object is an opaque handle to information used by the Logger to print test results. + //! The sample must call Logger::defineTest() in order to obtain a TestAtom that can be used + //! with Logger::reportTest{Start,End}(). + //! + class TestAtom { + public: + TestAtom(TestAtom&&) = default; + + private: + friend class Logger; + + TestAtom(bool started, const std::string& name, const std::string& cmdline) + : mStarted(started), mName(name), mCmdline(cmdline) {} + + bool mStarted; + std::string mName; + std::string mCmdline; + }; + + //! + //! \brief Define a test for logging + //! + //! \param[in] name The name of the test. This should be a string starting with + //! "TensorRT" and containing dot-separated strings containing + //! the characters [A-Za-z0-9_]. + //! For example, "TensorRT.sample_googlenet" + //! \param[in] cmdline The command line used to reproduce the test + // + //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). + //! + static TestAtom defineTest(const std::string& name, const std::string& cmdline) { + return TestAtom(false, name, cmdline); + } + + //! + //! 
\brief A convenience overloaded version of defineTest() that accepts an array of command-line arguments + //! as input + //! + //! \param[in] name The name of the test + //! \param[in] argc The number of command-line arguments + //! \param[in] argv The array of command-line arguments (given as C strings) + //! + //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). + static TestAtom defineTest(const std::string& name, int argc, char const* const* argv) { + auto cmdline = genCmdlineString(argc, argv); + return defineTest(name, cmdline); + } + + //! + //! \brief Report that a test has started. + //! + //! \pre reportTestStart() has not been called yet for the given testAtom + //! + //! \param[in] testAtom The handle to the test that has started + //! + static void reportTestStart(TestAtom& testAtom) { + reportTestResult(testAtom, TestResult::kRUNNING); + assert(!testAtom.mStarted); + testAtom.mStarted = true; + } + + //! + //! \brief Report that a test has ended. + //! + //! \pre reportTestStart() has been called for the given testAtom + //! + //! \param[in] testAtom The handle to the test that has ended + //! \param[in] result The result of the test. Should be one of TestResult::kPASSED, + //! TestResult::kFAILED, TestResult::kWAIVED + //! + static void reportTestEnd(const TestAtom& testAtom, TestResult result) { + assert(result != TestResult::kRUNNING); + assert(testAtom.mStarted); + reportTestResult(testAtom, result); + } + + static int reportPass(const TestAtom& testAtom) { + reportTestEnd(testAtom, TestResult::kPASSED); + return EXIT_SUCCESS; + } + + static int reportFail(const TestAtom& testAtom) { + reportTestEnd(testAtom, TestResult::kFAILED); + return EXIT_FAILURE; + } + + static int reportWaive(const TestAtom& testAtom) { + reportTestEnd(testAtom, TestResult::kWAIVED); + return EXIT_SUCCESS; + } + + static int reportTest(const TestAtom& testAtom, bool pass) { + return pass ? 
reportPass(testAtom) : reportFail(testAtom); + } + + Severity getReportableSeverity() const { return mReportableSeverity; } + + private: + //! + //! \brief returns an appropriate string for prefixing a log message with the given severity + //! + static const char* severityPrefix(Severity severity) { + switch (severity) { + case Severity::kINTERNAL_ERROR: + return "[F] "; + case Severity::kERROR: + return "[E] "; + case Severity::kWARNING: + return "[W] "; + case Severity::kINFO: + return "[I] "; + case Severity::kVERBOSE: + return "[V] "; + default: + assert(0); + return ""; + } + } + + //! + //! \brief returns an appropriate string for prefixing a test result message with the given result + //! + static const char* testResultString(TestResult result) { + switch (result) { + case TestResult::kRUNNING: + return "RUNNING"; + case TestResult::kPASSED: + return "PASSED"; + case TestResult::kFAILED: + return "FAILED"; + case TestResult::kWAIVED: + return "WAIVED"; + default: + assert(0); + return ""; + } + } + + //! + //! \brief returns an appropriate output stream (cout or cerr) to use with the given severity + //! + static std::ostream& severityOstream(Severity severity) { + return severity >= Severity::kINFO ? std::cout : std::cerr; + } + + //! + //! \brief method that implements logging test results + //! + static void reportTestResult(const TestAtom& testAtom, TestResult result) { + severityOstream(Severity::kINFO) << "&&&& " << testResultString(result) << " " << testAtom.mName << " # " + << testAtom.mCmdline << std::endl; + } + + //! + //! \brief generate a command line string from the given (argc, argv) values + //! + static std::string genCmdlineString(int argc, char const* const* argv) { + std::stringstream ss; + for (int i = 0; i < argc; i++) { + if (i > 0) + ss << " "; + ss << argv[i]; + } + return ss.str(); + } + + Severity mReportableSeverity; +}; + +namespace { + +//! +//! 
\brief produces a LogStreamConsumer object that can be used to log messages of severity kVERBOSE +//! +//! Example usage: +//! +//! LOG_VERBOSE(logger) << "hello world" << std::endl; +//! +inline LogStreamConsumer LOG_VERBOSE(const Logger& logger) { + return LogStreamConsumer(logger.getReportableSeverity(), Severity::kVERBOSE); +} + +//! +//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINFO +//! +//! Example usage: +//! +//! LOG_INFO(logger) << "hello world" << std::endl; +//! +inline LogStreamConsumer LOG_INFO(const Logger& logger) { + return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINFO); +} + +//! +//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kWARNING +//! +//! Example usage: +//! +//! LOG_WARN(logger) << "hello world" << std::endl; +//! +inline LogStreamConsumer LOG_WARN(const Logger& logger) { + return LogStreamConsumer(logger.getReportableSeverity(), Severity::kWARNING); +} + +//! +//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kERROR +//! +//! Example usage: +//! +//! LOG_ERROR(logger) << "hello world" << std::endl; +//! +inline LogStreamConsumer LOG_ERROR(const Logger& logger) { + return LogStreamConsumer(logger.getReportableSeverity(), Severity::kERROR); +} + +//! +//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINTERNAL_ERROR +// ("fatal" severity) +//! +//! Example usage: +//! +//! LOG_FATAL(logger) << "hello world" << std::endl; +//! 
+inline LogStreamConsumer LOG_FATAL(const Logger& logger) { + return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINTERNAL_ERROR); +} + +} // anonymous namespace + +#endif // TENSORRT_LOGGING_H diff --git a/ghostnet/ghostnetv2/CMakeLists.txt b/ghostnet/ghostnetv2/CMakeLists.txt new file mode 100644 index 00000000..0796ef91 --- /dev/null +++ b/ghostnet/ghostnetv2/CMakeLists.txt @@ -0,0 +1,24 @@ +cmake_minimum_required(VERSION 2.6) + +project(ghostnetv2) + +add_definitions(-std=c++11) + +option(CUDA_USE_STATIC_CUDA_RUNTIME OFF) +set(CMAKE_CXX_STANDARD 11) +set(CMAKE_BUILD_TYPE Debug) + +include_directories(${PROJECT_SOURCE_DIR}/include) +# include and link dirs of cuda and tensorrt, you need adapt them if yours are different +# cuda +include_directories(/usr/local/cuda/include) +link_directories(/usr/local/cuda/lib64) +# tensorrt +include_directories(/usr/include/x86_64-linux-gnu/) +link_directories(/usr/lib/x86_64-linux-gnu/) + +add_executable(ghostnetv2 ${PROJECT_SOURCE_DIR}/ghostnetv2.cpp) +target_link_libraries(ghostnetv2 nvinfer) +target_link_libraries(ghostnetv2 cudart) + +add_definitions(-O2 -pthread) diff --git a/ghostnet/ghostnetv2/gen_wts.py b/ghostnet/ghostnetv2/gen_wts.py new file mode 100644 index 00000000..9e2bdd19 --- /dev/null +++ b/ghostnet/ghostnetv2/gen_wts.py @@ -0,0 +1,312 @@ +import torch +import torch.nn as nn +import torch.onnx +import struct + +import torch +import torch.nn.functional as F +import math + +from timm.models.registry import register_model + + +def _make_divisible(v, divisor, min_value=None): + """ + This function is taken from the original tf repo. 
+ It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +def hard_sigmoid(x, inplace: bool = False): + if inplace: + return x.add_(3.).clamp_(0., 6.).div_(6.) + else: + return F.relu6(x + 3.) / 6. + + +class SqueezeExcite(nn.Module): + def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None, + act_layer=nn.ReLU, gate_fn=hard_sigmoid, divisor=4, **_): + super(SqueezeExcite, self).__init__() + self.gate_fn = gate_fn + reduced_chs = _make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor) + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True) + self.act1 = act_layer(inplace=True) + self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True) + + def forward(self, x): + x_se = self.avg_pool(x) + x_se = self.conv_reduce(x_se) + x_se = self.act1(x_se) + x_se = self.conv_expand(x_se) + x = x * self.gate_fn(x_se) + return x + + +class ConvBnAct(nn.Module): + def __init__(self, in_chs, out_chs, kernel_size, + stride=1, act_layer=nn.ReLU): + super(ConvBnAct, self).__init__() + self.conv = nn.Conv2d(in_chs, out_chs, kernel_size, stride, kernel_size//2, bias=False) + self.bn1 = nn.BatchNorm2d(out_chs) + self.act1 = act_layer(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.bn1(x) + x = self.act1(x) + return x + + +class GhostModuleV2(nn.Module): + def __init__(self, inp, oup, kernel_size=1, ratio=2, dw_size=3, stride=1, relu=True, mode=None, args=None): + super(GhostModuleV2, self).__init__() + self.mode = mode + self.gate_fn = nn.Sigmoid() + + if self.mode in ['original']: + self.oup = oup + init_channels = 
math.ceil(oup / ratio) + new_channels = init_channels*(ratio-1) + self.primary_conv = nn.Sequential( + nn.Conv2d(inp, init_channels, kernel_size, stride, kernel_size//2, bias=False), + nn.BatchNorm2d(init_channels), + nn.ReLU(inplace=True) if relu else nn.Sequential(), + ) + self.cheap_operation = nn.Sequential( + nn.Conv2d(init_channels, new_channels, dw_size, 1, dw_size//2, groups=init_channels, bias=False), + nn.BatchNorm2d(new_channels), + nn.ReLU(inplace=True) if relu else nn.Sequential(), + ) + elif self.mode in ['attn']: + self.oup = oup + init_channels = math.ceil(oup / ratio) + new_channels = init_channels*(ratio-1) + self.primary_conv = nn.Sequential( + nn.Conv2d(inp, init_channels, kernel_size, stride, kernel_size//2, bias=False), + nn.BatchNorm2d(init_channels), + nn.ReLU(inplace=True) if relu else nn.Sequential(), + ) + self.cheap_operation = nn.Sequential( + nn.Conv2d(init_channels, new_channels, dw_size, 1, dw_size//2, groups=init_channels, bias=False), + nn.BatchNorm2d(new_channels), + nn.ReLU(inplace=True) if relu else nn.Sequential(), + ) + self.short_conv = nn.Sequential( + nn.Conv2d(inp, oup, kernel_size, stride, kernel_size//2, bias=False), + nn.BatchNorm2d(oup), + nn.Conv2d(oup, oup, kernel_size=(1, 5), stride=1, padding=(0, 2), groups=oup, bias=False), + nn.BatchNorm2d(oup), + nn.Conv2d(oup, oup, kernel_size=(5, 1), stride=1, padding=(2, 0), groups=oup, bias=False), + nn.BatchNorm2d(oup), + ) + + def forward(self, x): + if self.mode in ['original']: + x1 = self.primary_conv(x) + x2 = self.cheap_operation(x1) + out = torch.cat([x1, x2], dim=1) + return out[:, :self.oup, :, :] + elif self.mode in ['attn']: + res = self.short_conv(F.avg_pool2d(x, kernel_size=2, stride=2)) + x1 = self.primary_conv(x) + x2 = self.cheap_operation(x1) + out = torch.cat([x1, x2], dim=1) + return out[:, :self.oup, :, :]*F.interpolate(self.gate_fn(res), + size=(out.shape[-2], out.shape[-1]), mode='nearest') + + +class GhostBottleneckV2(nn.Module): + + def 
__init__(self, in_chs, mid_chs, out_chs, dw_kernel_size=3, + stride=1, act_layer=nn.ReLU, se_ratio=0., layer_id=None, args=None): + super(GhostBottleneckV2, self).__init__() + has_se = se_ratio is not None and se_ratio > 0. + self.stride = stride + + # Point-wise expansion + if layer_id <= 1: + self.ghost1 = GhostModuleV2(in_chs, mid_chs, relu=True, mode='original', args=args) + else: + self.ghost1 = GhostModuleV2(in_chs, mid_chs, relu=True, mode='attn', args=args) + + # Depth-wise convolution + if self.stride > 1: + self.conv_dw = nn.Conv2d(mid_chs, mid_chs, dw_kernel_size, stride=stride, + padding=(dw_kernel_size-1)//2, groups=mid_chs, bias=False) + self.bn_dw = nn.BatchNorm2d(mid_chs) + + # Squeeze-and-excitation + if has_se: + self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio) + else: + self.se = None + + self.ghost2 = GhostModuleV2(mid_chs, out_chs, relu=False, mode='original', args=args) + + # shortcut + if (in_chs == out_chs and self.stride == 1): + self.shortcut = nn.Sequential() + else: + self.shortcut = nn.Sequential( + nn.Conv2d(in_chs, in_chs, dw_kernel_size, stride=stride, + padding=(dw_kernel_size-1)//2, groups=in_chs, bias=False), + nn.BatchNorm2d(in_chs), + nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(out_chs), + ) + + def forward(self, x): + residual = x + x = self.ghost1(x) + if self.stride > 1: + x = self.conv_dw(x) + x = self.bn_dw(x) + if self.se is not None: + x = self.se(x) + x = self.ghost2(x) + x += self.shortcut(residual) + return x + + +class GhostNetV2(nn.Module): + def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2, block=GhostBottleneckV2, args=None): + super(GhostNetV2, self).__init__() + self.cfgs = cfgs + self.dropout = dropout + + # building first layer + output_channel = _make_divisible(16 * width, 4) + self.conv_stem = nn.Conv2d(3, output_channel, 3, 2, 1, bias=False) + self.bn1 = nn.BatchNorm2d(output_channel) + self.act1 = nn.ReLU(inplace=True) + input_channel = output_channel 
+ + # building inverted residual blocks + stages = [] + layer_id = 0 + for cfg in self.cfgs: + layers = [] + for k, exp_size, c, se_ratio, s in cfg: + output_channel = _make_divisible(c * width, 4) + hidden_channel = _make_divisible(exp_size * width, 4) + layers.append(block(input_channel, hidden_channel, output_channel, k, s, + se_ratio=se_ratio, layer_id=layer_id, args=args)) + input_channel = output_channel + layer_id += 1 + stages.append(nn.Sequential(*layers)) + + output_channel = _make_divisible(exp_size * width, 4) + stages.append(nn.Sequential(ConvBnAct(input_channel, output_channel, 1))) + input_channel = output_channel + + self.blocks = nn.Sequential(*stages) + + # building last several layers + output_channel = 1280 + self.global_pool = nn.AdaptiveAvgPool2d((1, 1)) + self.conv_head = nn.Conv2d(input_channel, output_channel, 1, 1, 0, bias=True) + self.act2 = nn.ReLU(inplace=True) + self.classifier = nn.Linear(output_channel, num_classes) + + def forward(self, x): + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + x = self.blocks(x) + x = self.global_pool(x) + x = self.conv_head(x) + x = self.act2(x) + x = x.view(x.size(0), -1) + if self.dropout > 0.: + x = F.dropout(x, p=self.dropout, training=self.training) + x = self.classifier(x) + return x + + +@register_model +def ghostnetv2(**kwargs): + cfgs = [ + # k, t, c, SE, s + [[3, 16, 16, 0, 1]], + [[3, 48, 24, 0, 2]], + [[3, 72, 24, 0, 1]], + [[5, 72, 40, 0.25, 2]], + [[5, 120, 40, 0.25, 1]], + [[3, 240, 80, 0, 2]], + [[3, 200, 80, 0, 1], + [3, 184, 80, 0, 1], + [3, 184, 80, 0, 1], + [3, 480, 112, 0.25, 1], + [3, 672, 112, 0.25, 1]], + [[5, 672, 160, 0.25, 2]], + [[5, 960, 160, 0, 1], + [5, 960, 160, 0.25, 1], + [5, 960, 160, 0, 1], + [5, 960, 160, 0.25, 1]] + ] + return GhostNetV2(cfgs, num_classes=kwargs['num_classes'], + width=kwargs['width'], + dropout=kwargs['dropout'], + args=kwargs['args']) + + +def setup_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + 
torch.backends.cudnn.deterministic = True + + +# Function to export weights in the specified format +def export_weight(model): + f = open("ghostnetv2.weights", 'w') + f.write("{}\n".format(len(model.state_dict().keys()))) + + # Convert weights to hexadecimal format + for k, v in model.state_dict().items(): + print('exporting ... {}: {}'.format(k, v.shape)) + + # Reshape the weights to 1D + vr = v.reshape(-1).cpu().numpy() + f.write("{} {}".format(k, len(vr))) + for vv in vr: + f.write(" ") + f.write(struct.pack(">f", float(vv)).hex()) + f.write("\n") + + f.close() + + +# Function to evaluate the model (optional) +def eval_model(input, model): + output = model(input) + print("------from inference------") + print(input) + print(output) + + +if __name__ == "__main__": + setup_seed(1) + + # Create an instance of GhostNetV2 + model = ghostnetv2(width=1.0, num_classes=1000, dropout=0.2, args=None) + model.eval() + + # Dummy input tensor (adjust the shape as per your requirement) + input = torch.full((32, 3, 320, 256), 10.0) + + # Export the model weights + export_weight(model) + + # Evaluate the model + eval_model(input, model) diff --git a/ghostnet/ghostnetv2/ghostnetv2.cpp b/ghostnet/ghostnetv2/ghostnetv2.cpp new file mode 100644 index 00000000..9c0a19a9 --- /dev/null +++ b/ghostnet/ghostnetv2/ghostnetv2.cpp @@ -0,0 +1,591 @@ +#include +#include +#include +#include +#include +#include +#include +#include "NvInfer.h" +#include "cuda_runtime_api.h" +#include "logging.h" + +using namespace std; + +#define CHECK(status) \ + do { \ + auto ret = (status); \ + if (ret != 0) { \ + std::cerr << "Cuda failure: " << ret << std::endl; \ + abort(); \ + } \ + } while (0) + +// Define input/output parameters +static const int INPUT_H = 256; +static const int INPUT_W = 320; +static const int OUTPUT_SIZE = 1000; +static const int batchSize = 32; + +const char* INPUT_BLOB_NAME = "data"; +const char* OUTPUT_BLOB_NAME = "prob"; +using namespace nvinfer1; + +static Logger gLogger; + +// 
Load weight file +std::map loadWeights(const std::string file) { + std::cout << "Loading weights: " << file << std::endl; + std::map weightMap; + + // Open the weight file + std::ifstream input(file); + if (!input.is_open()) { + std::cerr << "Unable to load weight file." << std::endl; + exit(EXIT_FAILURE); + } + + // Read the number of weights + int32_t count; + input >> count; + if (count <= 0) { + std::cerr << "Invalid weight map file." << std::endl; + exit(EXIT_FAILURE); + } + + while (count--) { + Weights wt{DataType::kFLOAT, nullptr, 0}; + uint32_t size; + + // Read the name and size + std::string name; + input >> name >> std::dec >> size; + wt.type = DataType::kFLOAT; + + // Load weight data + uint32_t* val = reinterpret_cast(malloc(sizeof(uint32_t) * size)); + for (uint32_t x = 0, y = size; x < y; ++x) { + input >> std::hex >> val[x]; + } + wt.values = val; + + wt.count = size; + weightMap[name] = wt; + } + + return weightMap; +} + +int _make_divisible(int v, int divisor, int min_value = -1) { + // If min_value is not specified, set it to divisor + if (min_value == -1) { + min_value = divisor; + } + + // Calculate new channel size to be divisible by divisor + int new_v = std::max(min_value, (v + divisor / 2) / divisor * divisor); + + // Ensure rounding down does not reduce by more than 10% + if (new_v < static_cast(0.9 * v)) { + new_v += divisor; + } + + return new_v; +} + +ILayer* hardSigmoid(INetworkDefinition* network, ITensor& input) { + // Apply Hard Sigmoid activation function + IActivationLayer* scale_layer = network->addActivation(input, ActivationType::kHARD_SIGMOID); + + // Return the output after activation + return scale_layer; +} + +IScaleLayer* addBatchNorm2d(INetworkDefinition* network, std::map& weightMap, ITensor& input, + std::string lname, float eps) { + float* gamma = (float*)weightMap[lname + ".weight"].values; + float* beta = (float*)weightMap[lname + ".bias"].values; + float* mean = (float*)weightMap[lname + ".running_mean"].values; + 
float* var = (float*)weightMap[lname + ".running_var"].values; + int len = weightMap[lname + ".running_var"].count; + + float* scval = reinterpret_cast(malloc(sizeof(float) * len)); + for (int i = 0; i < len; i++) { + scval[i] = gamma[i] / sqrt(var[i] + eps); + } + Weights scale{DataType::kFLOAT, scval, len}; + + float* shval = reinterpret_cast(malloc(sizeof(float) * len)); + for (int i = 0; i < len; i++) { + shval[i] = beta[i] - mean[i] * gamma[i] / sqrt(var[i] + eps); + } + Weights shift{DataType::kFLOAT, shval, len}; + + float* pval = reinterpret_cast(malloc(sizeof(float) * len)); + for (int i = 0; i < len; i++) { + pval[i] = 1.0; + } + Weights power{DataType::kFLOAT, pval, len}; + + weightMap[lname + ".scale"] = scale; + weightMap[lname + ".shift"] = shift; + weightMap[lname + ".power"] = power; + IScaleLayer* scale_1 = network->addScale(input, ScaleMode::kCHANNEL, shift, scale, power); + assert(scale_1); + return scale_1; +} + +IActivationLayer* convBnReluStem(INetworkDefinition* network, std::map& weightMap, ITensor& input, + int outch, std::string lname) { + Weights emptywts{DataType::kFLOAT, nullptr, 0}; + + // Step 1: Convolution layer + IConvolutionLayer* conv1 = + network->addConvolutionNd(input, outch, DimsHW{3, 3}, weightMap[lname + ".weight"], emptywts); + assert(conv1); + conv1->setStrideNd(DimsHW{2, 2}); // Stride of 2 + conv1->setPaddingNd(DimsHW{1, 1}); // Padding of 1 + + // Step 2: Batch normalization layer + IScaleLayer* bn1 = addBatchNorm2d(network, weightMap, *conv1->getOutput(0), "bn1", 1e-5); + + // Step 3: ReLU activation + IActivationLayer* relu1 = network->addActivation(*bn1->getOutput(0), ActivationType::kRELU); + assert(relu1); + + return relu1; // Return the result after activation +} + +ILayer* convBnAct(INetworkDefinition* network, std::map& weightMap, ITensor& input, + int out_channels, std::string lname, ActivationType actType = ActivationType::kRELU) { + Weights emptywts{DataType::kFLOAT, nullptr, 0}; + + // Add convolution layer 
+ IConvolutionLayer* conv = + network->addConvolutionNd(input, out_channels, DimsHW{1, 1}, weightMap[lname + ".conv.weight"], emptywts); + assert(conv); + conv->setStrideNd(DimsHW{1, 1}); + + // Add batch normalization layer + IScaleLayer* bn = addBatchNorm2d(network, weightMap, *conv->getOutput(0), lname + ".bn1", 1e-5); + + // Add activation layer (default is ReLU) + IActivationLayer* act = network->addActivation(*bn->getOutput(0), actType); + assert(act); + + return act; +} + +ILayer* squeezeExcite(INetworkDefinition* network, ITensor& input, std::map& weightMap, + int in_chs, float se_ratio = 0.25, std::string lname = "", float eps = 1e-5) { + // Step 1: Global average pooling + IReduceLayer* avg_pool = network->addReduce(input, ReduceOperation::kAVG, 1 << 2 | 1 << 3, true); + assert(avg_pool); + + // Step 2: 1x1 convolution for dimension reduction + int reduced_chs = _make_divisible(static_cast(in_chs * se_ratio), 4); + IConvolutionLayer* conv_reduce = + network->addConvolutionNd(*avg_pool->getOutput(0), reduced_chs, DimsHW{1, 1}, + weightMap[lname + ".conv_reduce.weight"], weightMap[lname + ".conv_reduce.bias"]); + assert(conv_reduce); + + // Step 3: ReLU activation + IActivationLayer* relu1 = network->addActivation(*conv_reduce->getOutput(0), ActivationType::kRELU); + assert(relu1); + + // Step 4: 1x1 convolution for dimension expansion + IConvolutionLayer* conv_expand = + network->addConvolutionNd(*relu1->getOutput(0), in_chs, DimsHW{1, 1}, + weightMap[lname + ".conv_expand.weight"], weightMap[lname + ".conv_expand.bias"]); + assert(conv_expand); + + // Step 5: Hard Sigmoid activation + ILayer* hard_sigmoid = hardSigmoid(network, *conv_expand->getOutput(0)); + + // Step 6: Multiply input by the output of SE module + IElementWiseLayer* scale = network->addElementWise(input, *hard_sigmoid->getOutput(0), ElementWiseOperation::kPROD); + assert(scale); + + return scale; +} + +ILayer* ghostModuleV2(INetworkDefinition* network, ITensor& input, std::map& weightMap, 
int inp, + int oup, int kernel_size = 1, int ratio = 2, int dw_size = 3, int stride = 1, bool relu = true, + std::string lname = "", std::string mode = "original") { + int init_channels = std::ceil(oup / ratio); + int new_channels = init_channels * (ratio - 1); + + // Primary convolution + IConvolutionLayer* primary_conv = network->addConvolutionNd(input, init_channels, DimsHW{kernel_size, kernel_size}, + weightMap[lname + ".primary_conv.0.weight"], Weights{}); + primary_conv->setStrideNd(DimsHW{stride, stride}); + primary_conv->setPaddingNd(DimsHW{kernel_size / 2, kernel_size / 2}); + + IScaleLayer* bn1 = addBatchNorm2d(network, weightMap, *primary_conv->getOutput(0), lname + ".primary_conv.1", 1e-5); + + ITensor* act1_output = bn1->getOutput(0); + if (relu) { + IActivationLayer* relu1 = network->addActivation(*bn1->getOutput(0), ActivationType::kRELU); + act1_output = relu1->getOutput(0); + } + + // Cheap operation + IConvolutionLayer* cheap_conv = + network->addConvolutionNd(*act1_output, new_channels, DimsHW{dw_size, dw_size}, + weightMap[lname + ".cheap_operation.0.weight"], Weights{}); + cheap_conv->setStrideNd(DimsHW{1, 1}); + cheap_conv->setPaddingNd(DimsHW{dw_size / 2, dw_size / 2}); + cheap_conv->setNbGroups(init_channels); + + IScaleLayer* bn2 = + addBatchNorm2d(network, weightMap, *cheap_conv->getOutput(0), lname + ".cheap_operation.1", 1e-5); + + ITensor* act2_output = bn2->getOutput(0); + if (relu) { + IActivationLayer* relu2 = network->addActivation(*bn2->getOutput(0), ActivationType::kRELU); + act2_output = relu2->getOutput(0); + } + + // Concatenate + ITensor* concat_inputs[] = {act1_output, act2_output}; + IConcatenationLayer* concat = network->addConcatenation(concat_inputs, 2); + + // Slice to oup channels + Dims start{4, {0, 0, 0, 0}}; + Dims size = concat->getOutput(0)->getDimensions(); + size.d[1] = oup; + Dims stride_{4, {1, 1, 1, 1}}; + + ISliceLayer* slice = network->addSlice(*concat->getOutput(0), start, size, stride_); + + ITensor* out = 
slice->getOutput(0); + + if (mode == "original") { + return slice; + } else if (mode == "attn") { + // Attention mechanism + // Average pooling + IPoolingLayer* avg_pool = network->addPoolingNd(input, PoolingType::kAVERAGE, DimsHW{2, 2}); + avg_pool->setStrideNd(DimsHW{2, 2}); + + ITensor* avg_pooled = avg_pool->getOutput(0); + + // Short convolution branch + IConvolutionLayer* short_conv1 = + network->addConvolutionNd(*avg_pooled, oup, DimsHW{kernel_size, kernel_size}, + weightMap[lname + ".short_conv.0.weight"], Weights{}); + short_conv1->setStrideNd(DimsHW{1, 1}); + short_conv1->setPaddingNd(DimsHW{kernel_size / 2, kernel_size / 2}); + IScaleLayer* short_bn1 = + addBatchNorm2d(network, weightMap, *short_conv1->getOutput(0), lname + ".short_conv.1", 1e-5); + + // Conv with kernel size (1,5) + IConvolutionLayer* short_conv2 = network->addConvolutionNd( + *short_bn1->getOutput(0), oup, DimsHW{1, 5}, weightMap[lname + ".short_conv.2.weight"], Weights{}); + short_conv2->setStrideNd(DimsHW{1, 1}); + short_conv2->setPaddingNd(DimsHW{0, 2}); + short_conv2->setNbGroups(oup); + IScaleLayer* short_bn2 = + addBatchNorm2d(network, weightMap, *short_conv2->getOutput(0), lname + ".short_conv.3", 1e-5); + + // Conv with kernel size (5,1) + IConvolutionLayer* short_conv3 = network->addConvolutionNd( + *short_bn2->getOutput(0), oup, DimsHW{5, 1}, weightMap[lname + ".short_conv.4.weight"], Weights{}); + short_conv3->setStrideNd(DimsHW{1, 1}); + short_conv3->setPaddingNd(DimsHW{2, 0}); + short_conv3->setNbGroups(oup); + IScaleLayer* short_bn3 = + addBatchNorm2d(network, weightMap, *short_conv3->getOutput(0), lname + ".short_conv.5", 1e-5); + + ITensor* res = short_bn3->getOutput(0); + + // Sigmoid activation + IActivationLayer* gate = network->addActivation(*res, ActivationType::kSIGMOID); + + // Upsample to the same size as out + IResizeLayer* gate_upsampled = network->addResize(*gate->getOutput(0)); + gate_upsampled->setResizeMode(ResizeMode::kNEAREST); + Dims out_dims = 
out->getDimensions(); + gate_upsampled->setOutputDimensions(out_dims); + + // Element-wise multiplication + IElementWiseLayer* scaled_out = + network->addElementWise(*out, *gate_upsampled->getOutput(0), ElementWiseOperation::kPROD); + + return scaled_out; + } else { + std::cerr << "Invalid mode: " << mode << " in ghostModuleV2" << std::endl; + return nullptr; + } +} + +ILayer* ghostBottleneck(INetworkDefinition* network, ITensor& input, std::map& weightMap, + int in_chs, int mid_chs, int out_chs, int dw_kernel_size = 3, int stride = 1, + float se_ratio = 0.0f, std::string lname = "", int layer_id = 0) { + // Determine mode based on layer_id + std::string mode = (layer_id <= 1) ? "original" : "attn"; + + // ghost1 + ILayer* ghost1 = + ghostModuleV2(network, input, weightMap, in_chs, mid_chs, 1, 2, 3, 1, true, lname + ".ghost1", mode); + + ILayer* depthwise_conv = ghost1; + if (stride > 1) { + IConvolutionLayer* conv_dw = + network->addConvolutionNd(*ghost1->getOutput(0), mid_chs, DimsHW{dw_kernel_size, dw_kernel_size}, + weightMap[lname + ".conv_dw.weight"], Weights{}); + conv_dw->setStrideNd(DimsHW{stride, stride}); + conv_dw->setPaddingNd(DimsHW{(dw_kernel_size - 1) / 2, (dw_kernel_size - 1) / 2}); + conv_dw->setNbGroups(mid_chs); + IScaleLayer* bn_dw = addBatchNorm2d(network, weightMap, *conv_dw->getOutput(0), lname + ".bn_dw", 1e-5); + depthwise_conv = bn_dw; + } + + ILayer* se_layer = depthwise_conv; + if (se_ratio > 0.0f) { + se_layer = squeezeExcite(network, *depthwise_conv->getOutput(0), weightMap, mid_chs, se_ratio, lname + ".se"); + } + + // ghost2 uses original mode + ILayer* ghost2 = ghostModuleV2(network, *se_layer->getOutput(0), weightMap, mid_chs, out_chs, 1, 2, 3, 1, false, + lname + ".ghost2", "original"); + + ILayer* shortcut_layer = nullptr; + if (in_chs == out_chs && stride == 1) { + shortcut_layer = network->addIdentity(input); + } else { + IConvolutionLayer* conv_shortcut_dw = + network->addConvolutionNd(input, in_chs, DimsHW{dw_kernel_size, 
dw_kernel_size}, + weightMap[lname + ".shortcut.0.weight"], Weights{}); + conv_shortcut_dw->setStrideNd(DimsHW{stride, stride}); + conv_shortcut_dw->setPaddingNd(DimsHW{(dw_kernel_size - 1) / 2, (dw_kernel_size - 1) / 2}); + conv_shortcut_dw->setNbGroups(in_chs); + IScaleLayer* bn_shortcut_dw = + addBatchNorm2d(network, weightMap, *conv_shortcut_dw->getOutput(0), lname + ".shortcut.1", 1e-5); + + IConvolutionLayer* conv_shortcut_pw = + network->addConvolutionNd(*bn_shortcut_dw->getOutput(0), out_chs, DimsHW{1, 1}, + weightMap[lname + ".shortcut.2.weight"], Weights{}); + IScaleLayer* bn_shortcut_pw = + addBatchNorm2d(network, weightMap, *conv_shortcut_pw->getOutput(0), lname + ".shortcut.3", 1e-5); + shortcut_layer = bn_shortcut_pw; + } + + IElementWiseLayer* ew_sum = + network->addElementWise(*ghost2->getOutput(0), *shortcut_layer->getOutput(0), ElementWiseOperation::kSUM); + + return ew_sum; +} + +ICudaEngine* createEngine(IBuilder* builder, IBuilderConfig* config, DataType dt) { + // Use explicit batch mode + INetworkDefinition* network = + builder->createNetworkV2(1U << static_cast(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH)); + + // Create input tensor + ITensor* data = network->addInput(INPUT_BLOB_NAME, dt, Dims4{batchSize, 3, INPUT_H, INPUT_W}); + assert(data); + + // Load weights + std::map weightMap = loadWeights("../ghostnetv2.weights"); + Weights emptywts{DataType::kFLOAT, nullptr, 0}; + + // Step 1: Conv Stem + IActivationLayer* conv_stem = convBnReluStem(network, weightMap, *data, 16, "conv_stem"); + + ILayer* current_layer = conv_stem; + + current_layer = + ghostBottleneck(network, *current_layer->getOutput(0), weightMap, 16, 16, 16, 3, 1, 0.0f, "blocks.0.0", 0); + current_layer = + ghostBottleneck(network, *current_layer->getOutput(0), weightMap, 16, 48, 24, 3, 2, 0.0f, "blocks.1.0", 1); + current_layer = + ghostBottleneck(network, *current_layer->getOutput(0), weightMap, 24, 72, 24, 3, 1, 0.0f, "blocks.2.0", 2); + current_layer = + 
ghostBottleneck(network, *current_layer->getOutput(0), weightMap, 24, 72, 40, 5, 2, 0.25f, "blocks.3.0", 3); + current_layer = ghostBottleneck(network, *current_layer->getOutput(0), weightMap, 40, 120, 40, 5, 1, 0.25f, + "blocks.4.0", 4); + current_layer = + ghostBottleneck(network, *current_layer->getOutput(0), weightMap, 40, 240, 80, 3, 2, 0.0f, "blocks.5.0", 5); + current_layer = + ghostBottleneck(network, *current_layer->getOutput(0), weightMap, 80, 200, 80, 3, 1, 0.0f, "blocks.6.0", 6); + current_layer = + ghostBottleneck(network, *current_layer->getOutput(0), weightMap, 80, 184, 80, 3, 1, 0.0f, "blocks.6.1", 7); + current_layer = + ghostBottleneck(network, *current_layer->getOutput(0), weightMap, 80, 184, 80, 3, 1, 0.0f, "blocks.6.2", 8); + current_layer = ghostBottleneck(network, *current_layer->getOutput(0), weightMap, 80, 480, 112, 3, 1, 0.25f, + "blocks.6.3", 9); + current_layer = ghostBottleneck(network, *current_layer->getOutput(0), weightMap, 112, 672, 112, 3, 1, 0.25f, + "blocks.6.4", 10); + current_layer = ghostBottleneck(network, *current_layer->getOutput(0), weightMap, 112, 672, 160, 5, 2, 0.25f, + "blocks.7.0", 11); + current_layer = ghostBottleneck(network, *current_layer->getOutput(0), weightMap, 160, 960, 160, 5, 1, 0.0f, + "blocks.8.0", 12); + current_layer = ghostBottleneck(network, *current_layer->getOutput(0), weightMap, 160, 960, 160, 5, 1, 0.25f, + "blocks.8.1", 13); + current_layer = ghostBottleneck(network, *current_layer->getOutput(0), weightMap, 160, 960, 160, 5, 1, 0.0f, + "blocks.8.2", 14); + current_layer = ghostBottleneck(network, *current_layer->getOutput(0), weightMap, 160, 960, 160, 5, 1, 0.25f, + "blocks.8.3", 15); + + // Apply ConvBnAct + current_layer = convBnAct(network, weightMap, *current_layer->getOutput(0), 960, "blocks.9.0"); + + // Global average pooling + IReduceLayer* global_pool = + network->addReduce(*current_layer->getOutput(0), ReduceOperation::kAVG, 1 << 2 | 1 << 3, true); + assert(global_pool); + + // Conv 
Head + IConvolutionLayer* conv_head = network->addConvolutionNd( + *global_pool->getOutput(0), 1280, DimsHW{1, 1}, weightMap["conv_head.weight"], weightMap["conv_head.bias"]); + IActivationLayer* act2 = network->addActivation(*conv_head->getOutput(0), ActivationType::kRELU); + + // Fully connected layer (classifier) + IFullyConnectedLayer* classifier = network->addFullyConnected( + *act2->getOutput(0), 1000, weightMap["classifier.weight"], weightMap["classifier.bias"]); + classifier->getOutput(0)->setName(OUTPUT_BLOB_NAME); + network->markOutput(*classifier->getOutput(0)); + + // Build the engine + config->setMaxWorkspaceSize(1 << 24); + ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config); + + // Destroy the network + network->destroy(); + + // Free memory + for (auto& mem : weightMap) { + free((void*)(mem.second.values)); + } + + return engine; +} + +void APIToModel(IHostMemory** modelStream) { + // Create builder + IBuilder* builder = createInferBuilder(gLogger); + IBuilderConfig* config = builder->createBuilderConfig(); + + // Create model and serialize + ICudaEngine* engine = createEngine(builder, config, DataType::kFLOAT); + assert(engine != nullptr); + + // Serialize the engine + (*modelStream) = engine->serialize(); + + // Release resources + engine->destroy(); + config->destroy(); + builder->destroy(); +} + +void doInference(IExecutionContext& context, float* input, float* output, int batchSize) { + const ICudaEngine& engine = context.getEngine(); + + const int inputIndex = engine.getBindingIndex(INPUT_BLOB_NAME); + const int outputIndex = engine.getBindingIndex(OUTPUT_BLOB_NAME); + + // Input and output buffers + void* buffers[2]; + + // Create buffers on device + CHECK(cudaMalloc(&buffers[inputIndex], batchSize * 3 * INPUT_H * INPUT_W * sizeof(float))); + CHECK(cudaMalloc(&buffers[outputIndex], batchSize * OUTPUT_SIZE * sizeof(float))); + + // Create stream + cudaStream_t stream; + CHECK(cudaStreamCreate(&stream)); + + // Copy input 
data to device, execute inference, and copy output back to host + CHECK(cudaMemcpyAsync(buffers[inputIndex], input, batchSize * 3 * INPUT_H * INPUT_W * sizeof(float), + cudaMemcpyHostToDevice, stream)); + context.enqueueV2(buffers, stream, nullptr); + CHECK(cudaMemcpyAsync(output, buffers[outputIndex], batchSize * OUTPUT_SIZE * sizeof(float), cudaMemcpyDeviceToHost, + stream)); + cudaStreamSynchronize(stream); + + // Release stream and buffers + cudaStreamDestroy(stream); + CHECK(cudaFree(buffers[inputIndex])); + CHECK(cudaFree(buffers[outputIndex])); +} + +int main(int argc, char** argv) { + if (argc != 2) { + std::cerr << "arguments not right!" << std::endl; + std::cerr << "./ghostnetv2 -s // serialize model to plan file" << std::endl; + std::cerr << "./ghostnetv2 -d // deserialize plan file and run inference" << std::endl; + return -1; + } + + // Create model and serialize + char* trtModelStream{nullptr}; + size_t size{0}; + + if (std::string(argv[1]) == "-s") { + IHostMemory* modelStream{nullptr}; + APIToModel(&modelStream); + assert(modelStream != nullptr); + + std::ofstream p("ghostnetv2.engine", std::ios::binary); + if (!p) { + std::cerr << "could not open plan output file" << std::endl; + return -1; + } + p.write(reinterpret_cast(modelStream->data()), modelStream->size()); + modelStream->destroy(); + return 0; + } else if (std::string(argv[1]) == "-d") { + std::ifstream file("ghostnetv2.engine", std::ios::binary); + if (file.good()) { + file.seekg(0, file.end); + size = file.tellg(); + file.seekg(0, file.beg); + trtModelStream = new char[size]; + assert(trtModelStream); + file.read(trtModelStream, size); + file.close(); + } + } else { + return -1; + } + + // Allocate input and output data + float* data = new float[batchSize * 3 * INPUT_H * INPUT_W]; + for (int i = 0; i < batchSize * 3 * INPUT_H * INPUT_W; i++) + data[i] = 10.0; + + float* prob = new float[batchSize * OUTPUT_SIZE]; + + IRuntime* runtime = createInferRuntime(gLogger); + assert(runtime != 
nullptr); + ICudaEngine* engine = runtime->deserializeCudaEngine(trtModelStream, size, nullptr); + assert(engine != nullptr); + IExecutionContext* context = engine->createExecutionContext(); + assert(context != nullptr); + delete[] trtModelStream; + + // Execute inference + doInference(*context, data, prob, batchSize); + + // Print output results + std::cout << "\nOutput:\n\n"; + for (int i = 0; i < batchSize; i++) { + std::cout << "Batch " << i << ":\n"; + for (unsigned int j = 0; j < OUTPUT_SIZE; j++) { + std::cout << prob[i * OUTPUT_SIZE + j] << ", "; + if (j % 10 == 0) + std::cout << j / 10 << std::endl; + } + std::cout << "\n"; + } + + // Release resources + context->destroy(); + engine->destroy(); + runtime->destroy(); + delete[] data; + delete[] prob; + + return 0; +} diff --git a/ghostnet/ghostnetv2/logging.h b/ghostnet/ghostnetv2/logging.h new file mode 100644 index 00000000..f57438c8 --- /dev/null +++ b/ghostnet/ghostnetv2/logging.h @@ -0,0 +1,455 @@ +/* + * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef TENSORRT_LOGGING_H +#define TENSORRT_LOGGING_H + +#include +#include +#include +#include +#include +#include +#include +#include "NvInferRuntimeCommon.h" + +using Severity = nvinfer1::ILogger::Severity; + +class LogStreamConsumerBuffer : public std::stringbuf { + public: + LogStreamConsumerBuffer(std::ostream& stream, const std::string& prefix, bool shouldLog) + : mOutput(stream), mPrefix(prefix), mShouldLog(shouldLog) {} + + LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other) : mOutput(other.mOutput) {} + + ~LogStreamConsumerBuffer() { + // std::streambuf::pbase() gives a pointer to the beginning of the buffered part of the output sequence + // std::streambuf::pptr() gives a pointer to the current position of the output sequence + // if the pointer to the beginning is not equal to the pointer to the current position, + // call putOutput() to log the output to the stream + if (pbase() != pptr()) { + putOutput(); + } + } + + // synchronizes the stream buffer and returns 0 on success + // synchronizing the stream buffer consists of inserting the buffer contents into the stream, + // resetting the buffer and flushing the stream + virtual int sync() { + putOutput(); + return 0; + } + + void putOutput() { + if (mShouldLog) { + // prepend timestamp + std::time_t timestamp = std::time(nullptr); + tm* tm_local = std::localtime(×tamp); + std::cout << "["; + std::cout << std::setw(2) << std::setfill('0') << 1 + tm_local->tm_mon << "/"; + std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_mday << "/"; + std::cout << std::setw(4) << std::setfill('0') << 1900 + tm_local->tm_year << "-"; + std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_hour << ":"; + std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_min << ":"; + std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_sec << "] "; + // std::stringbuf::str() gets the string contents of the buffer + // insert the buffer contents pre-appended by the appropriate 
prefix into the stream + mOutput << mPrefix << str(); + // set the buffer to empty + str(""); + // flush the stream + mOutput.flush(); + } + } + + void setShouldLog(bool shouldLog) { mShouldLog = shouldLog; } + + private: + std::ostream& mOutput; + std::string mPrefix; + bool mShouldLog; +}; + +//! +//! \class LogStreamConsumerBase +//! \brief Convenience object used to initialize LogStreamConsumerBuffer before std::ostream in LogStreamConsumer +//! +class LogStreamConsumerBase { + public: + LogStreamConsumerBase(std::ostream& stream, const std::string& prefix, bool shouldLog) + : mBuffer(stream, prefix, shouldLog) {} + + protected: + LogStreamConsumerBuffer mBuffer; +}; + +//! +//! \class LogStreamConsumer +//! \brief Convenience object used to facilitate use of C++ stream syntax when logging messages. +//! Order of base classes is LogStreamConsumerBase and then std::ostream. +//! This is because the LogStreamConsumerBase class is used to initialize the LogStreamConsumerBuffer member field +//! in LogStreamConsumer and then the address of the buffer is passed to std::ostream. +//! This is necessary to prevent the address of an uninitialized buffer from being passed to std::ostream. +//! Please do not change the order of the parent classes. +//! +class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream { + public: + //! \brief Creates a LogStreamConsumer which logs messages with level severity. + //! Reportable severity determines if the messages are severe enough to be logged. 
+ LogStreamConsumer(Severity reportableSeverity, Severity severity) + : LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity), + std::ostream(&mBuffer) // links the stream buffer with the stream + , + mShouldLog(severity <= reportableSeverity), + mSeverity(severity) {} + + LogStreamConsumer(LogStreamConsumer&& other) + : LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog), + std::ostream(&mBuffer) // links the stream buffer with the stream + , + mShouldLog(other.mShouldLog), + mSeverity(other.mSeverity) {} + + void setReportableSeverity(Severity reportableSeverity) { + mShouldLog = mSeverity <= reportableSeverity; + mBuffer.setShouldLog(mShouldLog); + } + + private: + static std::ostream& severityOstream(Severity severity) { + return severity >= Severity::kINFO ? std::cout : std::cerr; + } + + static std::string severityPrefix(Severity severity) { + switch (severity) { + case Severity::kINTERNAL_ERROR: + return "[F] "; + case Severity::kERROR: + return "[E] "; + case Severity::kWARNING: + return "[W] "; + case Severity::kINFO: + return "[I] "; + case Severity::kVERBOSE: + return "[V] "; + default: + assert(0); + return ""; + } + } + + bool mShouldLog; + Severity mSeverity; +}; + +//! \class Logger +//! +//! \brief Class which manages logging of TensorRT tools and samples +//! +//! \details This class provides a common interface for TensorRT tools and samples to log information to the console, +//! and supports logging two types of messages: +//! +//! - Debugging messages with an associated severity (info, warning, error, or internal error/fatal) +//! - Test pass/fail messages +//! +//! The advantage of having all samples use this class for logging as opposed to emitting directly to stdout/stderr is +//! that the logic for controlling the verbosity and formatting of sample output is centralized in one location. +//! +//! 
In the future, this class could be extended to support dumping test results to a file in some standard format +//! (for example, JUnit XML), and providing additional metadata (e.g. timing the duration of a test run). +//! +//! TODO: For backwards compatibility with existing samples, this class inherits directly from the nvinfer1::ILogger +//! interface, which is problematic since there isn't a clean separation between messages coming from the TensorRT +//! library and messages coming from the sample. +//! +//! In the future (once all samples are updated to use Logger::getTRTLogger() to access the ILogger) we can refactor the +//! class to eliminate the inheritance and instead make the nvinfer1::ILogger implementation a member of the Logger +//! object. + +class Logger : public nvinfer1::ILogger { + public: + Logger(Severity severity = Severity::kWARNING) : mReportableSeverity(severity) {} + + //! + //! \enum TestResult + //! \brief Represents the state of a given test + //! + enum class TestResult { + kRUNNING, //!< The test is running + kPASSED, //!< The test passed + kFAILED, //!< The test failed + kWAIVED //!< The test was waived + }; + + //! + //! \brief Forward-compatible method for retrieving the nvinfer::ILogger associated with this Logger + //! \return The nvinfer1::ILogger associated with this Logger + //! + //! TODO Once all samples are updated to use this method to register the logger with TensorRT, + //! we can eliminate the inheritance of Logger from ILogger + //! + nvinfer1::ILogger& getTRTLogger() { return *this; } + + //! + //! \brief Implementation of the nvinfer1::ILogger::log() virtual method + //! + //! Note samples should not be calling this function directly; it will eventually go away once we eliminate the + //! inheritance from nvinfer1::ILogger + //! + void log(Severity severity, const char* msg) noexcept override { + LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl; + } + + //! + //! 
\brief Method for controlling the verbosity of logging output + //! + //! \param severity The logger will only emit messages that have severity of this level or higher. + //! + void setReportableSeverity(Severity severity) { mReportableSeverity = severity; } + + //! + //! \brief Opaque handle that holds logging information for a particular test + //! + //! This object is an opaque handle to information used by the Logger to print test results. + //! The sample must call Logger::defineTest() in order to obtain a TestAtom that can be used + //! with Logger::reportTest{Start,End}(). + //! + class TestAtom { + public: + TestAtom(TestAtom&&) = default; + + private: + friend class Logger; + + TestAtom(bool started, const std::string& name, const std::string& cmdline) + : mStarted(started), mName(name), mCmdline(cmdline) {} + + bool mStarted; + std::string mName; + std::string mCmdline; + }; + + //! + //! \brief Define a test for logging + //! + //! \param[in] name The name of the test. This should be a string starting with + //! "TensorRT" and containing dot-separated strings containing + //! the characters [A-Za-z0-9_]. + //! For example, "TensorRT.sample_googlenet" + //! \param[in] cmdline The command line used to reproduce the test + // + //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). + //! + static TestAtom defineTest(const std::string& name, const std::string& cmdline) { + return TestAtom(false, name, cmdline); + } + + //! + //! \brief A convenience overloaded version of defineTest() that accepts an array of command-line arguments + //! as input + //! + //! \param[in] name The name of the test + //! \param[in] argc The number of command-line arguments + //! \param[in] argv The array of command-line arguments (given as C strings) + //! + //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). 
+ static TestAtom defineTest(const std::string& name, int argc, char const* const* argv) { + auto cmdline = genCmdlineString(argc, argv); + return defineTest(name, cmdline); + } + + //! + //! \brief Report that a test has started. + //! + //! \pre reportTestStart() has not been called yet for the given testAtom + //! + //! \param[in] testAtom The handle to the test that has started + //! + static void reportTestStart(TestAtom& testAtom) { + reportTestResult(testAtom, TestResult::kRUNNING); + assert(!testAtom.mStarted); + testAtom.mStarted = true; + } + + //! + //! \brief Report that a test has ended. + //! + //! \pre reportTestStart() has been called for the given testAtom + //! + //! \param[in] testAtom The handle to the test that has ended + //! \param[in] result The result of the test. Should be one of TestResult::kPASSED, + //! TestResult::kFAILED, TestResult::kWAIVED + //! + static void reportTestEnd(const TestAtom& testAtom, TestResult result) { + assert(result != TestResult::kRUNNING); + assert(testAtom.mStarted); + reportTestResult(testAtom, result); + } + + static int reportPass(const TestAtom& testAtom) { + reportTestEnd(testAtom, TestResult::kPASSED); + return EXIT_SUCCESS; + } + + static int reportFail(const TestAtom& testAtom) { + reportTestEnd(testAtom, TestResult::kFAILED); + return EXIT_FAILURE; + } + + static int reportWaive(const TestAtom& testAtom) { + reportTestEnd(testAtom, TestResult::kWAIVED); + return EXIT_SUCCESS; + } + + static int reportTest(const TestAtom& testAtom, bool pass) { + return pass ? reportPass(testAtom) : reportFail(testAtom); + } + + Severity getReportableSeverity() const { return mReportableSeverity; } + + private: + //! + //! \brief returns an appropriate string for prefixing a log message with the given severity + //! 
+ static const char* severityPrefix(Severity severity) { + switch (severity) { + case Severity::kINTERNAL_ERROR: + return "[F] "; + case Severity::kERROR: + return "[E] "; + case Severity::kWARNING: + return "[W] "; + case Severity::kINFO: + return "[I] "; + case Severity::kVERBOSE: + return "[V] "; + default: + assert(0); + return ""; + } + } + + //! + //! \brief returns an appropriate string for prefixing a test result message with the given result + //! + static const char* testResultString(TestResult result) { + switch (result) { + case TestResult::kRUNNING: + return "RUNNING"; + case TestResult::kPASSED: + return "PASSED"; + case TestResult::kFAILED: + return "FAILED"; + case TestResult::kWAIVED: + return "WAIVED"; + default: + assert(0); + return ""; + } + } + + //! + //! \brief returns an appropriate output stream (cout or cerr) to use with the given severity + //! + static std::ostream& severityOstream(Severity severity) { + return severity >= Severity::kINFO ? std::cout : std::cerr; + } + + //! + //! \brief method that implements logging test results + //! + static void reportTestResult(const TestAtom& testAtom, TestResult result) { + severityOstream(Severity::kINFO) << "&&&& " << testResultString(result) << " " << testAtom.mName << " # " + << testAtom.mCmdline << std::endl; + } + + //! + //! \brief generate a command line string from the given (argc, argv) values + //! + static std::string genCmdlineString(int argc, char const* const* argv) { + std::stringstream ss; + for (int i = 0; i < argc; i++) { + if (i > 0) + ss << " "; + ss << argv[i]; + } + return ss.str(); + } + + Severity mReportableSeverity; +}; + +namespace { + +//! +//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kVERBOSE +//! +//! Example usage: +//! +//! LOG_VERBOSE(logger) << "hello world" << std::endl; +//! 
+inline LogStreamConsumer LOG_VERBOSE(const Logger& logger) { + return LogStreamConsumer(logger.getReportableSeverity(), Severity::kVERBOSE); +} + +//! +//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINFO +//! +//! Example usage: +//! +//! LOG_INFO(logger) << "hello world" << std::endl; +//! +inline LogStreamConsumer LOG_INFO(const Logger& logger) { + return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINFO); +} + +//! +//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kWARNING +//! +//! Example usage: +//! +//! LOG_WARN(logger) << "hello world" << std::endl; +//! +inline LogStreamConsumer LOG_WARN(const Logger& logger) { + return LogStreamConsumer(logger.getReportableSeverity(), Severity::kWARNING); +} + +//! +//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kERROR +//! +//! Example usage: +//! +//! LOG_ERROR(logger) << "hello world" << std::endl; +//! +inline LogStreamConsumer LOG_ERROR(const Logger& logger) { + return LogStreamConsumer(logger.getReportableSeverity(), Severity::kERROR); +} + +//! +//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINTERNAL_ERROR +// ("fatal" severity) +//! +//! Example usage: +//! +//! LOG_FATAL(logger) << "hello world" << std::endl; +//! 
+inline LogStreamConsumer LOG_FATAL(const Logger& logger) { + return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINTERNAL_ERROR); +} + +} // anonymous namespace + +#endif // TENSORRT_LOGGING_H diff --git a/mobilenet/mobilenetv2/CMakeLists.txt b/mobilenet/mobilenetv2/CMakeLists.txt index 570561b4..f0e23e09 100644 --- a/mobilenet/mobilenetv2/CMakeLists.txt +++ b/mobilenet/mobilenetv2/CMakeLists.txt @@ -22,4 +22,3 @@ target_link_libraries(mobilenet nvinfer) target_link_libraries(mobilenet cudart) add_definitions(-O2 -pthread) - diff --git a/mobilenet/mobilenetv2/logging.h b/mobilenet/mobilenetv2/logging.h index 602b69fb..960135e4 100644 --- a/mobilenet/mobilenetv2/logging.h +++ b/mobilenet/mobilenetv2/logging.h @@ -17,7 +17,6 @@ #ifndef TENSORRT_LOGGING_H #define TENSORRT_LOGGING_H -#include "NvInferRuntimeCommon.h" #include #include #include @@ -25,32 +24,23 @@ #include #include #include +#include "NvInferRuntimeCommon.h" using Severity = nvinfer1::ILogger::Severity; -class LogStreamConsumerBuffer : public std::stringbuf -{ -public: +class LogStreamConsumerBuffer : public std::stringbuf { + public: LogStreamConsumerBuffer(std::ostream& stream, const std::string& prefix, bool shouldLog) - : mOutput(stream) - , mPrefix(prefix) - , mShouldLog(shouldLog) - { - } + : mOutput(stream), mPrefix(prefix), mShouldLog(shouldLog) {} - LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other) - : mOutput(other.mOutput) - { - } + LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other) : mOutput(other.mOutput) {} - ~LogStreamConsumerBuffer() - { + ~LogStreamConsumerBuffer() { // std::streambuf::pbase() gives a pointer to the beginning of the buffered part of the output sequence // std::streambuf::pptr() gives a pointer to the current position of the output sequence // if the pointer to the beginning is not equal to the pointer to the current position, // call putOutput() to log the output to the stream - if (pbase() != pptr()) - { + if (pbase() != pptr()) { 
putOutput(); } } @@ -58,16 +48,13 @@ class LogStreamConsumerBuffer : public std::stringbuf // synchronizes the stream buffer and returns 0 on success // synchronizing the stream buffer consists of inserting the buffer contents into the stream, // resetting the buffer and flushing the stream - virtual int sync() - { + virtual int sync() { putOutput(); return 0; } - void putOutput() - { - if (mShouldLog) - { + void putOutput() { + if (mShouldLog) { // prepend timestamp std::time_t timestamp = std::time(nullptr); tm* tm_local = std::localtime(×tamp); @@ -88,12 +75,9 @@ class LogStreamConsumerBuffer : public std::stringbuf } } - void setShouldLog(bool shouldLog) - { - mShouldLog = shouldLog; - } + void setShouldLog(bool shouldLog) { mShouldLog = shouldLog; } -private: + private: std::ostream& mOutput; std::string mPrefix; bool mShouldLog; @@ -103,15 +87,12 @@ class LogStreamConsumerBuffer : public std::stringbuf //! \class LogStreamConsumerBase //! \brief Convenience object used to initialize LogStreamConsumerBuffer before std::ostream in LogStreamConsumer //! -class LogStreamConsumerBase -{ -public: +class LogStreamConsumerBase { + public: LogStreamConsumerBase(std::ostream& stream, const std::string& prefix, bool shouldLog) - : mBuffer(stream, prefix, shouldLog) - { - } + : mBuffer(stream, prefix, shouldLog) {} -protected: + protected: LogStreamConsumerBuffer mBuffer; }; @@ -124,49 +105,49 @@ class LogStreamConsumerBase //! This is necessary to prevent the address of an uninitialized buffer from being passed to std::ostream. //! Please do not change the order of the parent classes. //! -class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream -{ -public: +class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream { + public: //! \brief Creates a LogStreamConsumer which logs messages with level severity. //! Reportable severity determines if the messages are severe enough to be logged. 
LogStreamConsumer(Severity reportableSeverity, Severity severity) - : LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity) - , std::ostream(&mBuffer) // links the stream buffer with the stream - , mShouldLog(severity <= reportableSeverity) - , mSeverity(severity) - { - } + : LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity), + std::ostream(&mBuffer) // links the stream buffer with the stream + , + mShouldLog(severity <= reportableSeverity), + mSeverity(severity) {} LogStreamConsumer(LogStreamConsumer&& other) - : LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog) - , std::ostream(&mBuffer) // links the stream buffer with the stream - , mShouldLog(other.mShouldLog) - , mSeverity(other.mSeverity) - { - } + : LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog), + std::ostream(&mBuffer) // links the stream buffer with the stream + , + mShouldLog(other.mShouldLog), + mSeverity(other.mSeverity) {} - void setReportableSeverity(Severity reportableSeverity) - { + void setReportableSeverity(Severity reportableSeverity) { mShouldLog = mSeverity <= reportableSeverity; mBuffer.setShouldLog(mShouldLog); } -private: - static std::ostream& severityOstream(Severity severity) - { + private: + static std::ostream& severityOstream(Severity severity) { return severity >= Severity::kINFO ? 
std::cout : std::cerr; } - static std::string severityPrefix(Severity severity) - { - switch (severity) - { - case Severity::kINTERNAL_ERROR: return "[F] "; - case Severity::kERROR: return "[E] "; - case Severity::kWARNING: return "[W] "; - case Severity::kINFO: return "[I] "; - case Severity::kVERBOSE: return "[V] "; - default: assert(0); return ""; + static std::string severityPrefix(Severity severity) { + switch (severity) { + case Severity::kINTERNAL_ERROR: + return "[F] "; + case Severity::kERROR: + return "[E] "; + case Severity::kWARNING: + return "[W] "; + case Severity::kINFO: + return "[I] "; + case Severity::kVERBOSE: + return "[V] "; + default: + assert(0); + return ""; } } @@ -198,24 +179,19 @@ class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream //! class to eliminate the inheritance and instead make the nvinfer1::ILogger implementation a member of the Logger //! object. -class Logger : public nvinfer1::ILogger -{ -public: - Logger(Severity severity = Severity::kWARNING) - : mReportableSeverity(severity) - { - } +class Logger : public nvinfer1::ILogger { + public: + Logger(Severity severity = Severity::kWARNING) : mReportableSeverity(severity) {} //! //! \enum TestResult //! \brief Represents the state of a given test //! - enum class TestResult - { - kRUNNING, //!< The test is running - kPASSED, //!< The test passed - kFAILED, //!< The test failed - kWAIVED //!< The test was waived + enum class TestResult { + kRUNNING, //!< The test is running + kPASSED, //!< The test passed + kFAILED, //!< The test failed + kWAIVED //!< The test was waived }; //! @@ -225,10 +201,7 @@ class Logger : public nvinfer1::ILogger //! TODO Once all samples are updated to use this method to register the logger with TensorRT, //! we can eliminate the inheritance of Logger from ILogger //! - nvinfer1::ILogger& getTRTLogger() - { - return *this; - } + nvinfer1::ILogger& getTRTLogger() { return *this; } //! //! 
\brief Implementation of the nvinfer1::ILogger::log() virtual method @@ -236,8 +209,7 @@ class Logger : public nvinfer1::ILogger //! Note samples should not be calling this function directly; it will eventually go away once we eliminate the //! inheritance from nvinfer1::ILogger //! - void log(Severity severity, const char* msg) override - { + void log(Severity severity, const char* msg) override { LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl; } @@ -246,10 +218,7 @@ class Logger : public nvinfer1::ILogger //! //! \param severity The logger will only emit messages that have severity of this level or higher. //! - void setReportableSeverity(Severity severity) - { - mReportableSeverity = severity; - } + void setReportableSeverity(Severity severity) { mReportableSeverity = severity; } //! //! \brief Opaque handle that holds logging information for a particular test @@ -258,20 +227,15 @@ class Logger : public nvinfer1::ILogger //! The sample must call Logger::defineTest() in order to obtain a TestAtom that can be used //! with Logger::reportTest{Start,End}(). //! - class TestAtom - { - public: + class TestAtom { + public: TestAtom(TestAtom&&) = default; - private: + private: friend class Logger; TestAtom(bool started, const std::string& name, const std::string& cmdline) - : mStarted(started) - , mName(name) - , mCmdline(cmdline) - { - } + : mStarted(started), mName(name), mCmdline(cmdline) {} bool mStarted; std::string mName; @@ -289,8 +253,7 @@ class Logger : public nvinfer1::ILogger // //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). //! - static TestAtom defineTest(const std::string& name, const std::string& cmdline) - { + static TestAtom defineTest(const std::string& name, const std::string& cmdline) { return TestAtom(false, name, cmdline); } @@ -303,8 +266,7 @@ class Logger : public nvinfer1::ILogger //! \param[in] argv The array of command-line arguments (given as C strings) //! //! 
\return a TestAtom that can be used in Logger::reportTest{Start,End}(). - static TestAtom defineTest(const std::string& name, int argc, char const* const* argv) - { + static TestAtom defineTest(const std::string& name, int argc, char const* const* argv) { auto cmdline = genCmdlineString(argc, argv); return defineTest(name, cmdline); } @@ -316,8 +278,7 @@ class Logger : public nvinfer1::ILogger //! //! \param[in] testAtom The handle to the test that has started //! - static void reportTestStart(TestAtom& testAtom) - { + static void reportTestStart(TestAtom& testAtom) { reportTestResult(testAtom, TestResult::kRUNNING); assert(!testAtom.mStarted); testAtom.mStarted = true; @@ -332,86 +293,85 @@ class Logger : public nvinfer1::ILogger //! \param[in] result The result of the test. Should be one of TestResult::kPASSED, //! TestResult::kFAILED, TestResult::kWAIVED //! - static void reportTestEnd(const TestAtom& testAtom, TestResult result) - { + static void reportTestEnd(const TestAtom& testAtom, TestResult result) { assert(result != TestResult::kRUNNING); assert(testAtom.mStarted); reportTestResult(testAtom, result); } - static int reportPass(const TestAtom& testAtom) - { + static int reportPass(const TestAtom& testAtom) { reportTestEnd(testAtom, TestResult::kPASSED); return EXIT_SUCCESS; } - static int reportFail(const TestAtom& testAtom) - { + static int reportFail(const TestAtom& testAtom) { reportTestEnd(testAtom, TestResult::kFAILED); return EXIT_FAILURE; } - static int reportWaive(const TestAtom& testAtom) - { + static int reportWaive(const TestAtom& testAtom) { reportTestEnd(testAtom, TestResult::kWAIVED); return EXIT_SUCCESS; } - static int reportTest(const TestAtom& testAtom, bool pass) - { + static int reportTest(const TestAtom& testAtom, bool pass) { return pass ? 
reportPass(testAtom) : reportFail(testAtom); } - Severity getReportableSeverity() const - { - return mReportableSeverity; - } + Severity getReportableSeverity() const { return mReportableSeverity; } -private: + private: //! //! \brief returns an appropriate string for prefixing a log message with the given severity //! - static const char* severityPrefix(Severity severity) - { - switch (severity) - { - case Severity::kINTERNAL_ERROR: return "[F] "; - case Severity::kERROR: return "[E] "; - case Severity::kWARNING: return "[W] "; - case Severity::kINFO: return "[I] "; - case Severity::kVERBOSE: return "[V] "; - default: assert(0); return ""; + static const char* severityPrefix(Severity severity) { + switch (severity) { + case Severity::kINTERNAL_ERROR: + return "[F] "; + case Severity::kERROR: + return "[E] "; + case Severity::kWARNING: + return "[W] "; + case Severity::kINFO: + return "[I] "; + case Severity::kVERBOSE: + return "[V] "; + default: + assert(0); + return ""; } } //! //! \brief returns an appropriate string for prefixing a test result message with the given result //! - static const char* testResultString(TestResult result) - { - switch (result) - { - case TestResult::kRUNNING: return "RUNNING"; - case TestResult::kPASSED: return "PASSED"; - case TestResult::kFAILED: return "FAILED"; - case TestResult::kWAIVED: return "WAIVED"; - default: assert(0); return ""; + static const char* testResultString(TestResult result) { + switch (result) { + case TestResult::kRUNNING: + return "RUNNING"; + case TestResult::kPASSED: + return "PASSED"; + case TestResult::kFAILED: + return "FAILED"; + case TestResult::kWAIVED: + return "WAIVED"; + default: + assert(0); + return ""; } } //! //! \brief returns an appropriate output stream (cout or cerr) to use with the given severity //! - static std::ostream& severityOstream(Severity severity) - { + static std::ostream& severityOstream(Severity severity) { return severity >= Severity::kINFO ? std::cout : std::cerr; } //! //! 
\brief method that implements logging test results //! - static void reportTestResult(const TestAtom& testAtom, TestResult result) - { + static void reportTestResult(const TestAtom& testAtom, TestResult result) { severityOstream(Severity::kINFO) << "&&&& " << testResultString(result) << " " << testAtom.mName << " # " << testAtom.mCmdline << std::endl; } @@ -419,11 +379,9 @@ class Logger : public nvinfer1::ILogger //! //! \brief generate a command line string from the given (argc, argv) values //! - static std::string genCmdlineString(int argc, char const* const* argv) - { + static std::string genCmdlineString(int argc, char const* const* argv) { std::stringstream ss; - for (int i = 0; i < argc; i++) - { + for (int i = 0; i < argc; i++) { if (i > 0) ss << " "; ss << argv[i]; @@ -434,8 +392,7 @@ class Logger : public nvinfer1::ILogger Severity mReportableSeverity; }; -namespace -{ +namespace { //! //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kVERBOSE @@ -444,8 +401,7 @@ namespace //! //! LOG_VERBOSE(logger) << "hello world" << std::endl; //! -inline LogStreamConsumer LOG_VERBOSE(const Logger& logger) -{ +inline LogStreamConsumer LOG_VERBOSE(const Logger& logger) { return LogStreamConsumer(logger.getReportableSeverity(), Severity::kVERBOSE); } @@ -456,8 +412,7 @@ inline LogStreamConsumer LOG_VERBOSE(const Logger& logger) //! //! LOG_INFO(logger) << "hello world" << std::endl; //! -inline LogStreamConsumer LOG_INFO(const Logger& logger) -{ +inline LogStreamConsumer LOG_INFO(const Logger& logger) { return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINFO); } @@ -468,8 +423,7 @@ inline LogStreamConsumer LOG_INFO(const Logger& logger) //! //! LOG_WARN(logger) << "hello world" << std::endl; //! 
-inline LogStreamConsumer LOG_WARN(const Logger& logger) -{ +inline LogStreamConsumer LOG_WARN(const Logger& logger) { return LogStreamConsumer(logger.getReportableSeverity(), Severity::kWARNING); } @@ -480,8 +434,7 @@ inline LogStreamConsumer LOG_WARN(const Logger& logger) //! //! LOG_ERROR(logger) << "hello world" << std::endl; //! -inline LogStreamConsumer LOG_ERROR(const Logger& logger) -{ +inline LogStreamConsumer LOG_ERROR(const Logger& logger) { return LogStreamConsumer(logger.getReportableSeverity(), Severity::kERROR); } @@ -493,11 +446,10 @@ inline LogStreamConsumer LOG_ERROR(const Logger& logger) //! //! LOG_FATAL(logger) << "hello world" << std::endl; //! -inline LogStreamConsumer LOG_FATAL(const Logger& logger) -{ +inline LogStreamConsumer LOG_FATAL(const Logger& logger) { return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINTERNAL_ERROR); } -} // anonymous namespace +} // anonymous namespace -#endif // TENSORRT_LOGGING_H +#endif // TENSORRT_LOGGING_H diff --git a/mobilenet/mobilenetv2/mobilenet_v2.cpp b/mobilenet/mobilenetv2/mobilenet_v2.cpp index 9c1627ca..f9efe2b6 100644 --- a/mobilenet/mobilenetv2/mobilenet_v2.cpp +++ b/mobilenet/mobilenetv2/mobilenet_v2.cpp @@ -1,23 +1,21 @@ -#include "NvInfer.h" -#include "cuda_runtime_api.h" -#include "logging.h" +#include +#include #include #include #include #include #include -#include -#include +#include "NvInfer.h" +#include "cuda_runtime_api.h" +#include "logging.h" -#define CHECK(status) \ - do\ - {\ - auto ret = (status);\ - if (ret != 0)\ - {\ - std::cerr << "Cuda failure: " << ret << std::endl;\ - abort();\ - }\ +#define CHECK(status) \ + do { \ + auto ret = (status); \ + if (ret != 0) { \ + std::cerr << "Cuda failure: " << ret << std::endl; \ + abort(); \ + } \ } while (0) // stuff we know about the network and the input/output blobs @@ -35,8 +33,7 @@ static Logger gLogger; // Load weights from files shared with TensorRT samples. 
// TensorRT weight files have a simple space delimited format: // [type] [size] -std::map loadWeights(const std::string file) -{ +std::map loadWeights(const std::string file) { std::cout << "Loading weights: " << file << std::endl; std::map weightMap; @@ -49,8 +46,7 @@ std::map loadWeights(const std::string file) input >> count; assert(count > 0 && "Invalid weight map file."); - while (count--) - { + while (count--) { Weights wt{DataType::kFLOAT, nullptr, 0}; uint32_t size; @@ -61,12 +57,11 @@ std::map loadWeights(const std::string file) // Load blob uint32_t* val = reinterpret_cast(malloc(sizeof(val) * size)); - for (uint32_t x = 0, y = size; x < y; ++x) - { + for (uint32_t x = 0, y = size; x < y; ++x) { input >> std::hex >> val[x]; } wt.values = val; - + wt.count = size; weightMap[name] = wt; } @@ -74,27 +69,28 @@ std::map loadWeights(const std::string file) return weightMap; } -IScaleLayer* addBatchNorm2d(INetworkDefinition *network, std::map& weightMap, ITensor& input, std::string lname, float eps) { - float *gamma = (float*)weightMap[lname + ".weight"].values; - float *beta = (float*)weightMap[lname + ".bias"].values; - float *mean = (float*)weightMap[lname + ".running_mean"].values; - float *var = (float*)weightMap[lname + ".running_var"].values; +IScaleLayer* addBatchNorm2d(INetworkDefinition* network, std::map& weightMap, ITensor& input, + std::string lname, float eps) { + float* gamma = (float*)weightMap[lname + ".weight"].values; + float* beta = (float*)weightMap[lname + ".bias"].values; + float* mean = (float*)weightMap[lname + ".running_mean"].values; + float* var = (float*)weightMap[lname + ".running_var"].values; int len = weightMap[lname + ".running_var"].count; std::cout << "len " << len << std::endl; - float *scval = reinterpret_cast(malloc(sizeof(float) * len)); + float* scval = reinterpret_cast(malloc(sizeof(float) * len)); for (int i = 0; i < len; i++) { scval[i] = gamma[i] / sqrt(var[i] + eps); } Weights scale{DataType::kFLOAT, scval, len}; - - 
float *shval = reinterpret_cast(malloc(sizeof(float) * len)); + + float* shval = reinterpret_cast(malloc(sizeof(float) * len)); for (int i = 0; i < len; i++) { shval[i] = beta[i] - mean[i] * gamma[i] / sqrt(var[i] + eps); } Weights shift{DataType::kFLOAT, shval, len}; - float *pval = reinterpret_cast(malloc(sizeof(float) * len)); + float* pval = reinterpret_cast(malloc(sizeof(float) * len)); for (int i = 0; i < len; i++) { pval[i] = 1.0; } @@ -108,10 +104,12 @@ IScaleLayer* addBatchNorm2d(INetworkDefinition *network, std::map& weightMap, ITensor& input, int outch, int ksize, int s, int g, std::string lname) { +IElementWiseLayer* convBnRelu(INetworkDefinition* network, std::map& weightMap, ITensor& input, + int outch, int ksize, int s, int g, std::string lname) { Weights emptywts{DataType::kFLOAT, nullptr, 0}; int p = (ksize - 1) / 2; - IConvolutionLayer* conv1 = network->addConvolutionNd(input, outch, DimsHW{ksize, ksize}, weightMap[lname + "0.weight"], emptywts); + IConvolutionLayer* conv1 = + network->addConvolutionNd(input, outch, DimsHW{ksize, ksize}, weightMap[lname + "0.weight"], emptywts); assert(conv1); conv1->setStrideNd(DimsHW{s, s}); conv1->setPaddingNd(DimsHW{p, p}); @@ -122,9 +120,9 @@ IElementWiseLayer* convBnRelu(INetworkDefinition *network, std::mapaddActivation(*bn1->getOutput(0), ActivationType::kRELU); assert(relu1); - float *shval = reinterpret_cast(malloc(sizeof(float) * 1)); - float *scval = reinterpret_cast(malloc(sizeof(float) * 1)); - float *pval = reinterpret_cast(malloc(sizeof(float) * 1)); + float* shval = reinterpret_cast(malloc(sizeof(float) * 1)); + float* scval = reinterpret_cast(malloc(sizeof(float) * 1)); + float* pval = reinterpret_cast(malloc(sizeof(float) * 1)); shval[0] = -6.0; scval[0] = 1.0; pval[0] = 1.0; @@ -140,39 +138,43 @@ IElementWiseLayer* convBnRelu(INetworkDefinition *network, std::mapaddActivation(*scale1->getOutput(0), ActivationType::kRELU); assert(relu2); - IElementWiseLayer* ew1 = 
network->addElementWise(*relu1->getOutput(0), *relu2->getOutput(0), ElementWiseOperation::kSUB); + IElementWiseLayer* ew1 = + network->addElementWise(*relu1->getOutput(0), *relu2->getOutput(0), ElementWiseOperation::kSUB); assert(ew1); return ew1; } -ILayer* invertedRes(INetworkDefinition *network, std::map& weightMap, ITensor& input, std::string lname, - int inch, int outch, int s, int exp) { +ILayer* invertedRes(INetworkDefinition* network, std::map& weightMap, ITensor& input, + std::string lname, int inch, int outch, int s, int exp) { Weights emptywts{DataType::kFLOAT, nullptr, 0}; int hidden = inch * exp; bool use_res_connect = (s == 1 && inch == outch); - IScaleLayer *bn1 = nullptr; + IScaleLayer* bn1 = nullptr; if (exp != 1) { IElementWiseLayer* ew1 = convBnRelu(network, weightMap, input, hidden, 1, 1, 1, lname + "conv.0."); - IElementWiseLayer* ew2 = convBnRelu(network, weightMap, *ew1->getOutput(0), hidden, 3, s, hidden, lname + "conv.1."); - IConvolutionLayer* conv1 = network->addConvolutionNd(*ew2->getOutput(0), outch, DimsHW{1, 1}, weightMap[lname + "conv.2.weight"], emptywts); + IElementWiseLayer* ew2 = + convBnRelu(network, weightMap, *ew1->getOutput(0), hidden, 3, s, hidden, lname + "conv.1."); + IConvolutionLayer* conv1 = network->addConvolutionNd(*ew2->getOutput(0), outch, DimsHW{1, 1}, + weightMap[lname + "conv.2.weight"], emptywts); assert(conv1); bn1 = addBatchNorm2d(network, weightMap, *conv1->getOutput(0), lname + "conv.3", 1e-5); } else { IElementWiseLayer* ew1 = convBnRelu(network, weightMap, input, hidden, 3, s, hidden, lname + "conv.0."); - IConvolutionLayer* conv1 = network->addConvolutionNd(*ew1->getOutput(0), outch, DimsHW{1, 1}, weightMap[lname + "conv.1.weight"], emptywts); + IConvolutionLayer* conv1 = network->addConvolutionNd(*ew1->getOutput(0), outch, DimsHW{1, 1}, + weightMap[lname + "conv.1.weight"], emptywts); assert(conv1); bn1 = addBatchNorm2d(network, weightMap, *conv1->getOutput(0), lname + "conv.2", 1e-5); } - if 
(!use_res_connect) return bn1; + if (!use_res_connect) + return bn1; IElementWiseLayer* ew3 = network->addElementWise(input, *bn1->getOutput(0), ElementWiseOperation::kSUM); assert(ew3); return ew3; } // Creat the engine using only the API and not any parser. -ICudaEngine* createEngine(unsigned int maxBatchSize, IBuilder* builder, IBuilderConfig* config, DataType dt) -{ +ICudaEngine* createEngine(unsigned int maxBatchSize, IBuilder* builder, IBuilderConfig* config, DataType dt) { INetworkDefinition* network = builder->createNetworkV2(0U); // Create input tensor of shape { 3, INPUT_H, INPUT_W } with name INPUT_BLOB_NAME @@ -205,7 +207,8 @@ ICudaEngine* createEngine(unsigned int maxBatchSize, IBuilder* builder, IBuilder IPoolingLayer* pool1 = network->addPoolingNd(*ew2->getOutput(0), PoolingType::kAVERAGE, DimsHW{7, 7}); assert(pool1); - IFullyConnectedLayer* fc1 = network->addFullyConnected(*pool1->getOutput(0), 1000, weightMap["classifier.1.weight"], weightMap["classifier.1.bias"]); + IFullyConnectedLayer* fc1 = network->addFullyConnected(*pool1->getOutput(0), 1000, weightMap["classifier.1.weight"], + weightMap["classifier.1.bias"]); assert(fc1); fc1->getOutput(0)->setName(OUTPUT_BLOB_NAME); @@ -222,16 +225,14 @@ ICudaEngine* createEngine(unsigned int maxBatchSize, IBuilder* builder, IBuilder network->destroy(); // Release host memory - for (auto& mem : weightMap) - { - free((void*) (mem.second.values)); + for (auto& mem : weightMap) { + free((void*)(mem.second.values)); } return engine; } -void APIToModel(unsigned int maxBatchSize, IHostMemory** modelStream) -{ +void APIToModel(unsigned int maxBatchSize, IHostMemory** modelStream) { // Create builder IBuilder* builder = createInferBuilder(gLogger); IBuilderConfig* config = builder->createBuilderConfig(); @@ -245,12 +246,11 @@ void APIToModel(unsigned int maxBatchSize, IHostMemory** modelStream) // Close everything down engine->destroy(); - builder->destroy(); config->destroy(); + builder->destroy(); } -void 
doInference(IExecutionContext& context, float* input, float* output, int batchSize) -{ +void doInference(IExecutionContext& context, float* input, float* output, int batchSize) { const ICudaEngine& engine = context.getEngine(); // Pointers to input and output device buffers to pass to engine. @@ -272,9 +272,11 @@ void doInference(IExecutionContext& context, float* input, float* output, int ba CHECK(cudaStreamCreate(&stream)); // DMA input batch data to device, infer on the batch asynchronously, and DMA output back to host - CHECK(cudaMemcpyAsync(buffers[inputIndex], input, batchSize * 3 * INPUT_H * INPUT_W * sizeof(float), cudaMemcpyHostToDevice, stream)); + CHECK(cudaMemcpyAsync(buffers[inputIndex], input, batchSize * 3 * INPUT_H * INPUT_W * sizeof(float), + cudaMemcpyHostToDevice, stream)); context.enqueue(batchSize, buffers, stream, nullptr); - CHECK(cudaMemcpyAsync(output, buffers[outputIndex], batchSize * OUTPUT_SIZE * sizeof(float), cudaMemcpyDeviceToHost, stream)); + CHECK(cudaMemcpyAsync(output, buffers[outputIndex], batchSize * OUTPUT_SIZE * sizeof(float), cudaMemcpyDeviceToHost, + stream)); cudaStreamSynchronize(stream); // Release stream and buffers @@ -283,8 +285,7 @@ void doInference(IExecutionContext& context, float* input, float* output, int ba CHECK(cudaFree(buffers[outputIndex])); } -int main(int argc, char** argv) -{ +int main(int argc, char** argv) { if (argc != 2) { std::cerr << "arguments not right!" 
<< std::endl; std::cerr << "./mobilenet -s // serialize model to plan file" << std::endl; @@ -293,7 +294,7 @@ int main(int argc, char** argv) } // create a model using the API directly and serialize it to a stream - char *trtModelStream{nullptr}; + char* trtModelStream{nullptr}; size_t size{0}; if (std::string(argv[1]) == "-s") { @@ -302,8 +303,7 @@ int main(int argc, char** argv) assert(modelStream != nullptr); std::ofstream p("mobilenet.engine", std::ios::binary); - if (!p) - { + if (!p) { std::cerr << "could not open plan output file" << std::endl; return -1; } @@ -325,7 +325,6 @@ int main(int argc, char** argv) return -1; } - // Subtract mean from image static float data[3 * INPUT_H * INPUT_W]; for (int i = 0; i < 3 * INPUT_H * INPUT_W; i++) @@ -355,10 +354,10 @@ int main(int argc, char** argv) // Print histogram of the output distribution std::cout << "\nOutput:\n\n"; - for (unsigned int i = 0; i < OUTPUT_SIZE; i++) - { + for (unsigned int i = 0; i < OUTPUT_SIZE; i++) { std::cout << prob[i] << ", "; - if (i % 10 == 0) std::cout << i / 10 << std::endl; + if (i % 10 == 0) + std::cout << i / 10 << std::endl; } std::cout << std::endl; diff --git a/mobilenet/mobilenetv2/mobilenet_v2.py b/mobilenet/mobilenetv2/mobilenet_v2.py index 3cabdc21..8f0e5c98 100644 --- a/mobilenet/mobilenetv2/mobilenet_v2.py +++ b/mobilenet/mobilenetv2/mobilenet_v2.py @@ -4,8 +4,8 @@ import argparse import numpy as np -import pycuda.autoinit import pycuda.driver as cuda +import pycuda.autoinit # noqa: F401 import tensorrt as trt BATCH_SIZE = 1 diff --git a/mobilenet/mobilenetv3/CMakeLists.txt b/mobilenet/mobilenetv3/CMakeLists.txt index 03dde4f0..d44f07b3 100644 --- a/mobilenet/mobilenetv3/CMakeLists.txt +++ b/mobilenet/mobilenetv3/CMakeLists.txt @@ -24,4 +24,3 @@ target_link_libraries(mobilenetv3 nvinfer) target_link_libraries(mobilenetv3 cudart) add_definitions(-O2 -pthread) - diff --git a/mobilenet/mobilenetv3/logging.h b/mobilenet/mobilenetv3/logging.h index 602b69fb..960135e4 100644 
--- a/mobilenet/mobilenetv3/logging.h +++ b/mobilenet/mobilenetv3/logging.h @@ -17,7 +17,6 @@ #ifndef TENSORRT_LOGGING_H #define TENSORRT_LOGGING_H -#include "NvInferRuntimeCommon.h" #include #include #include @@ -25,32 +24,23 @@ #include #include #include +#include "NvInferRuntimeCommon.h" using Severity = nvinfer1::ILogger::Severity; -class LogStreamConsumerBuffer : public std::stringbuf -{ -public: +class LogStreamConsumerBuffer : public std::stringbuf { + public: LogStreamConsumerBuffer(std::ostream& stream, const std::string& prefix, bool shouldLog) - : mOutput(stream) - , mPrefix(prefix) - , mShouldLog(shouldLog) - { - } + : mOutput(stream), mPrefix(prefix), mShouldLog(shouldLog) {} - LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other) - : mOutput(other.mOutput) - { - } + LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other) : mOutput(other.mOutput) {} - ~LogStreamConsumerBuffer() - { + ~LogStreamConsumerBuffer() { // std::streambuf::pbase() gives a pointer to the beginning of the buffered part of the output sequence // std::streambuf::pptr() gives a pointer to the current position of the output sequence // if the pointer to the beginning is not equal to the pointer to the current position, // call putOutput() to log the output to the stream - if (pbase() != pptr()) - { + if (pbase() != pptr()) { putOutput(); } } @@ -58,16 +48,13 @@ class LogStreamConsumerBuffer : public std::stringbuf // synchronizes the stream buffer and returns 0 on success // synchronizing the stream buffer consists of inserting the buffer contents into the stream, // resetting the buffer and flushing the stream - virtual int sync() - { + virtual int sync() { putOutput(); return 0; } - void putOutput() - { - if (mShouldLog) - { + void putOutput() { + if (mShouldLog) { // prepend timestamp std::time_t timestamp = std::time(nullptr); tm* tm_local = std::localtime(×tamp); @@ -88,12 +75,9 @@ class LogStreamConsumerBuffer : public std::stringbuf } } - void setShouldLog(bool shouldLog) 
- { - mShouldLog = shouldLog; - } + void setShouldLog(bool shouldLog) { mShouldLog = shouldLog; } -private: + private: std::ostream& mOutput; std::string mPrefix; bool mShouldLog; @@ -103,15 +87,12 @@ class LogStreamConsumerBuffer : public std::stringbuf //! \class LogStreamConsumerBase //! \brief Convenience object used to initialize LogStreamConsumerBuffer before std::ostream in LogStreamConsumer //! -class LogStreamConsumerBase -{ -public: +class LogStreamConsumerBase { + public: LogStreamConsumerBase(std::ostream& stream, const std::string& prefix, bool shouldLog) - : mBuffer(stream, prefix, shouldLog) - { - } + : mBuffer(stream, prefix, shouldLog) {} -protected: + protected: LogStreamConsumerBuffer mBuffer; }; @@ -124,49 +105,49 @@ class LogStreamConsumerBase //! This is necessary to prevent the address of an uninitialized buffer from being passed to std::ostream. //! Please do not change the order of the parent classes. //! -class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream -{ -public: +class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream { + public: //! \brief Creates a LogStreamConsumer which logs messages with level severity. //! Reportable severity determines if the messages are severe enough to be logged. 
LogStreamConsumer(Severity reportableSeverity, Severity severity) - : LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity) - , std::ostream(&mBuffer) // links the stream buffer with the stream - , mShouldLog(severity <= reportableSeverity) - , mSeverity(severity) - { - } + : LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity), + std::ostream(&mBuffer) // links the stream buffer with the stream + , + mShouldLog(severity <= reportableSeverity), + mSeverity(severity) {} LogStreamConsumer(LogStreamConsumer&& other) - : LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog) - , std::ostream(&mBuffer) // links the stream buffer with the stream - , mShouldLog(other.mShouldLog) - , mSeverity(other.mSeverity) - { - } + : LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog), + std::ostream(&mBuffer) // links the stream buffer with the stream + , + mShouldLog(other.mShouldLog), + mSeverity(other.mSeverity) {} - void setReportableSeverity(Severity reportableSeverity) - { + void setReportableSeverity(Severity reportableSeverity) { mShouldLog = mSeverity <= reportableSeverity; mBuffer.setShouldLog(mShouldLog); } -private: - static std::ostream& severityOstream(Severity severity) - { + private: + static std::ostream& severityOstream(Severity severity) { return severity >= Severity::kINFO ? 
std::cout : std::cerr; } - static std::string severityPrefix(Severity severity) - { - switch (severity) - { - case Severity::kINTERNAL_ERROR: return "[F] "; - case Severity::kERROR: return "[E] "; - case Severity::kWARNING: return "[W] "; - case Severity::kINFO: return "[I] "; - case Severity::kVERBOSE: return "[V] "; - default: assert(0); return ""; + static std::string severityPrefix(Severity severity) { + switch (severity) { + case Severity::kINTERNAL_ERROR: + return "[F] "; + case Severity::kERROR: + return "[E] "; + case Severity::kWARNING: + return "[W] "; + case Severity::kINFO: + return "[I] "; + case Severity::kVERBOSE: + return "[V] "; + default: + assert(0); + return ""; } } @@ -198,24 +179,19 @@ class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream //! class to eliminate the inheritance and instead make the nvinfer1::ILogger implementation a member of the Logger //! object. -class Logger : public nvinfer1::ILogger -{ -public: - Logger(Severity severity = Severity::kWARNING) - : mReportableSeverity(severity) - { - } +class Logger : public nvinfer1::ILogger { + public: + Logger(Severity severity = Severity::kWARNING) : mReportableSeverity(severity) {} //! //! \enum TestResult //! \brief Represents the state of a given test //! - enum class TestResult - { - kRUNNING, //!< The test is running - kPASSED, //!< The test passed - kFAILED, //!< The test failed - kWAIVED //!< The test was waived + enum class TestResult { + kRUNNING, //!< The test is running + kPASSED, //!< The test passed + kFAILED, //!< The test failed + kWAIVED //!< The test was waived }; //! @@ -225,10 +201,7 @@ class Logger : public nvinfer1::ILogger //! TODO Once all samples are updated to use this method to register the logger with TensorRT, //! we can eliminate the inheritance of Logger from ILogger //! - nvinfer1::ILogger& getTRTLogger() - { - return *this; - } + nvinfer1::ILogger& getTRTLogger() { return *this; } //! //! 
\brief Implementation of the nvinfer1::ILogger::log() virtual method @@ -236,8 +209,7 @@ class Logger : public nvinfer1::ILogger //! Note samples should not be calling this function directly; it will eventually go away once we eliminate the //! inheritance from nvinfer1::ILogger //! - void log(Severity severity, const char* msg) override - { + void log(Severity severity, const char* msg) override { LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl; } @@ -246,10 +218,7 @@ class Logger : public nvinfer1::ILogger //! //! \param severity The logger will only emit messages that have severity of this level or higher. //! - void setReportableSeverity(Severity severity) - { - mReportableSeverity = severity; - } + void setReportableSeverity(Severity severity) { mReportableSeverity = severity; } //! //! \brief Opaque handle that holds logging information for a particular test @@ -258,20 +227,15 @@ class Logger : public nvinfer1::ILogger //! The sample must call Logger::defineTest() in order to obtain a TestAtom that can be used //! with Logger::reportTest{Start,End}(). //! - class TestAtom - { - public: + class TestAtom { + public: TestAtom(TestAtom&&) = default; - private: + private: friend class Logger; TestAtom(bool started, const std::string& name, const std::string& cmdline) - : mStarted(started) - , mName(name) - , mCmdline(cmdline) - { - } + : mStarted(started), mName(name), mCmdline(cmdline) {} bool mStarted; std::string mName; @@ -289,8 +253,7 @@ class Logger : public nvinfer1::ILogger // //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). //! - static TestAtom defineTest(const std::string& name, const std::string& cmdline) - { + static TestAtom defineTest(const std::string& name, const std::string& cmdline) { return TestAtom(false, name, cmdline); } @@ -303,8 +266,7 @@ class Logger : public nvinfer1::ILogger //! \param[in] argv The array of command-line arguments (given as C strings) //! //! 
\return a TestAtom that can be used in Logger::reportTest{Start,End}(). - static TestAtom defineTest(const std::string& name, int argc, char const* const* argv) - { + static TestAtom defineTest(const std::string& name, int argc, char const* const* argv) { auto cmdline = genCmdlineString(argc, argv); return defineTest(name, cmdline); } @@ -316,8 +278,7 @@ class Logger : public nvinfer1::ILogger //! //! \param[in] testAtom The handle to the test that has started //! - static void reportTestStart(TestAtom& testAtom) - { + static void reportTestStart(TestAtom& testAtom) { reportTestResult(testAtom, TestResult::kRUNNING); assert(!testAtom.mStarted); testAtom.mStarted = true; @@ -332,86 +293,85 @@ class Logger : public nvinfer1::ILogger //! \param[in] result The result of the test. Should be one of TestResult::kPASSED, //! TestResult::kFAILED, TestResult::kWAIVED //! - static void reportTestEnd(const TestAtom& testAtom, TestResult result) - { + static void reportTestEnd(const TestAtom& testAtom, TestResult result) { assert(result != TestResult::kRUNNING); assert(testAtom.mStarted); reportTestResult(testAtom, result); } - static int reportPass(const TestAtom& testAtom) - { + static int reportPass(const TestAtom& testAtom) { reportTestEnd(testAtom, TestResult::kPASSED); return EXIT_SUCCESS; } - static int reportFail(const TestAtom& testAtom) - { + static int reportFail(const TestAtom& testAtom) { reportTestEnd(testAtom, TestResult::kFAILED); return EXIT_FAILURE; } - static int reportWaive(const TestAtom& testAtom) - { + static int reportWaive(const TestAtom& testAtom) { reportTestEnd(testAtom, TestResult::kWAIVED); return EXIT_SUCCESS; } - static int reportTest(const TestAtom& testAtom, bool pass) - { + static int reportTest(const TestAtom& testAtom, bool pass) { return pass ? 
reportPass(testAtom) : reportFail(testAtom); } - Severity getReportableSeverity() const - { - return mReportableSeverity; - } + Severity getReportableSeverity() const { return mReportableSeverity; } -private: + private: //! //! \brief returns an appropriate string for prefixing a log message with the given severity //! - static const char* severityPrefix(Severity severity) - { - switch (severity) - { - case Severity::kINTERNAL_ERROR: return "[F] "; - case Severity::kERROR: return "[E] "; - case Severity::kWARNING: return "[W] "; - case Severity::kINFO: return "[I] "; - case Severity::kVERBOSE: return "[V] "; - default: assert(0); return ""; + static const char* severityPrefix(Severity severity) { + switch (severity) { + case Severity::kINTERNAL_ERROR: + return "[F] "; + case Severity::kERROR: + return "[E] "; + case Severity::kWARNING: + return "[W] "; + case Severity::kINFO: + return "[I] "; + case Severity::kVERBOSE: + return "[V] "; + default: + assert(0); + return ""; } } //! //! \brief returns an appropriate string for prefixing a test result message with the given result //! - static const char* testResultString(TestResult result) - { - switch (result) - { - case TestResult::kRUNNING: return "RUNNING"; - case TestResult::kPASSED: return "PASSED"; - case TestResult::kFAILED: return "FAILED"; - case TestResult::kWAIVED: return "WAIVED"; - default: assert(0); return ""; + static const char* testResultString(TestResult result) { + switch (result) { + case TestResult::kRUNNING: + return "RUNNING"; + case TestResult::kPASSED: + return "PASSED"; + case TestResult::kFAILED: + return "FAILED"; + case TestResult::kWAIVED: + return "WAIVED"; + default: + assert(0); + return ""; } } //! //! \brief returns an appropriate output stream (cout or cerr) to use with the given severity //! - static std::ostream& severityOstream(Severity severity) - { + static std::ostream& severityOstream(Severity severity) { return severity >= Severity::kINFO ? std::cout : std::cerr; } //! //! 
\brief method that implements logging test results //! - static void reportTestResult(const TestAtom& testAtom, TestResult result) - { + static void reportTestResult(const TestAtom& testAtom, TestResult result) { severityOstream(Severity::kINFO) << "&&&& " << testResultString(result) << " " << testAtom.mName << " # " << testAtom.mCmdline << std::endl; } @@ -419,11 +379,9 @@ class Logger : public nvinfer1::ILogger //! //! \brief generate a command line string from the given (argc, argv) values //! - static std::string genCmdlineString(int argc, char const* const* argv) - { + static std::string genCmdlineString(int argc, char const* const* argv) { std::stringstream ss; - for (int i = 0; i < argc; i++) - { + for (int i = 0; i < argc; i++) { if (i > 0) ss << " "; ss << argv[i]; @@ -434,8 +392,7 @@ class Logger : public nvinfer1::ILogger Severity mReportableSeverity; }; -namespace -{ +namespace { //! //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kVERBOSE @@ -444,8 +401,7 @@ namespace //! //! LOG_VERBOSE(logger) << "hello world" << std::endl; //! -inline LogStreamConsumer LOG_VERBOSE(const Logger& logger) -{ +inline LogStreamConsumer LOG_VERBOSE(const Logger& logger) { return LogStreamConsumer(logger.getReportableSeverity(), Severity::kVERBOSE); } @@ -456,8 +412,7 @@ inline LogStreamConsumer LOG_VERBOSE(const Logger& logger) //! //! LOG_INFO(logger) << "hello world" << std::endl; //! -inline LogStreamConsumer LOG_INFO(const Logger& logger) -{ +inline LogStreamConsumer LOG_INFO(const Logger& logger) { return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINFO); } @@ -468,8 +423,7 @@ inline LogStreamConsumer LOG_INFO(const Logger& logger) //! //! LOG_WARN(logger) << "hello world" << std::endl; //! 
-inline LogStreamConsumer LOG_WARN(const Logger& logger) -{ +inline LogStreamConsumer LOG_WARN(const Logger& logger) { return LogStreamConsumer(logger.getReportableSeverity(), Severity::kWARNING); } @@ -480,8 +434,7 @@ inline LogStreamConsumer LOG_WARN(const Logger& logger) //! //! LOG_ERROR(logger) << "hello world" << std::endl; //! -inline LogStreamConsumer LOG_ERROR(const Logger& logger) -{ +inline LogStreamConsumer LOG_ERROR(const Logger& logger) { return LogStreamConsumer(logger.getReportableSeverity(), Severity::kERROR); } @@ -493,11 +446,10 @@ inline LogStreamConsumer LOG_ERROR(const Logger& logger) //! //! LOG_FATAL(logger) << "hello world" << std::endl; //! -inline LogStreamConsumer LOG_FATAL(const Logger& logger) -{ +inline LogStreamConsumer LOG_FATAL(const Logger& logger) { return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINTERNAL_ERROR); } -} // anonymous namespace +} // anonymous namespace -#endif // TENSORRT_LOGGING_H +#endif // TENSORRT_LOGGING_H diff --git a/mobilenet/mobilenetv3/mobilenet_v3.cpp b/mobilenet/mobilenetv3/mobilenet_v3.cpp index 19eeeb9e..fb00108d 100644 --- a/mobilenet/mobilenetv3/mobilenet_v3.cpp +++ b/mobilenet/mobilenetv3/mobilenet_v3.cpp @@ -1,23 +1,21 @@ -#include "NvInfer.h" -#include "cuda_runtime_api.h" -#include "logging.h" +#include +#include #include #include #include #include #include -#include -#include +#include "NvInfer.h" +#include "cuda_runtime_api.h" +#include "logging.h" -#define CHECK(status) \ - do\ - {\ - auto ret = (status);\ - if (ret != 0)\ - {\ - std::cerr << "Cuda failure: " << ret << std::endl;\ - abort();\ - }\ +#define CHECK(status) \ + do { \ + auto ret = (status); \ + if (ret != 0) { \ + std::cerr << "Cuda failure: " << ret << std::endl; \ + abort(); \ + } \ } while (0) // stuff we know about the network and the input/output blobs @@ -36,8 +34,7 @@ static Logger gLogger; // Load weights from files shared with TensorRT samples. 
// TensorRT weight files have a simple space delimited format: // [type] [size] -std::map loadWeights(const std::string file) -{ +std::map loadWeights(const std::string file) { std::cout << "Loading weights: " << file << std::endl; std::map weightMap; @@ -50,8 +47,7 @@ std::map loadWeights(const std::string file) input >> count; assert(count > 0 && "Invalid weight map file."); - while (count--) - { + while (count--) { Weights wt{DataType::kFLOAT, nullptr, 0}; uint32_t size; @@ -62,12 +58,11 @@ std::map loadWeights(const std::string file) // Load blob uint32_t* val = reinterpret_cast(malloc(sizeof(val) * size)); - for (uint32_t x = 0, y = size; x < y; ++x) - { + for (uint32_t x = 0, y = size; x < y; ++x) { input >> std::hex >> val[x]; } wt.values = val; - + wt.count = size; weightMap[name] = wt; } @@ -75,27 +70,28 @@ std::map loadWeights(const std::string file) return weightMap; } -IScaleLayer* addBatchNorm(INetworkDefinition *network, std::map& weightMap, ITensor& input, std::string lname, float eps) { - float *gamma = (float*)weightMap[lname + ".weight"].values; - float *beta = (float*)weightMap[lname + ".bias"].values; - float *mean = (float*)weightMap[lname + ".running_mean"].values; - float *var = (float*)weightMap[lname + ".running_var"].values; +IScaleLayer* addBatchNorm(INetworkDefinition* network, std::map& weightMap, ITensor& input, + std::string lname, float eps) { + float* gamma = (float*)weightMap[lname + ".weight"].values; + float* beta = (float*)weightMap[lname + ".bias"].values; + float* mean = (float*)weightMap[lname + ".running_mean"].values; + float* var = (float*)weightMap[lname + ".running_var"].values; int len = weightMap[lname + ".running_var"].count; std::cout << "len " << len << std::endl; - float *scval = reinterpret_cast(malloc(sizeof(float) * len)); + float* scval = reinterpret_cast(malloc(sizeof(float) * len)); for (int i = 0; i < len; i++) { scval[i] = gamma[i] / sqrt(var[i] + eps); } Weights scale{DataType::kFLOAT, scval, len}; - - 
float *shval = reinterpret_cast(malloc(sizeof(float) * len)); + + float* shval = reinterpret_cast(malloc(sizeof(float) * len)); for (int i = 0; i < len; i++) { shval[i] = beta[i] - mean[i] * gamma[i] / sqrt(var[i] + eps); } Weights shift{DataType::kFLOAT, shval, len}; - float *pval = reinterpret_cast(malloc(sizeof(float) * len)); + float* pval = reinterpret_cast(malloc(sizeof(float) * len)); for (int i = 0; i < len; i++) { pval[i] = 1.0; } @@ -109,39 +105,44 @@ IScaleLayer* addBatchNorm(INetworkDefinition *network, std::mapaddActivation(input, ActivationType::kHARD_SIGMOID); assert(hsig); hsig->setAlpha(1.0 / 6.0); hsig->setBeta(0.5); - ILayer* hsw = network->addElementWise(input, *hsig->getOutput(0),ElementWiseOperation::kPROD); + ILayer* hsw = network->addElementWise(input, *hsig->getOutput(0), ElementWiseOperation::kPROD); assert(hsw); return hsw; } -ILayer* convBnHswish(INetworkDefinition *network, std::map& weightMap, ITensor& input, int outch, int ksize, int s, int g, std::string lname) { +ILayer* convBnHswish(INetworkDefinition* network, std::map& weightMap, ITensor& input, int outch, + int ksize, int s, int g, std::string lname) { Weights emptywts{DataType::kFLOAT, nullptr, 0}; int p = (ksize - 1) / 2; - IConvolutionLayer* conv1 = network->addConvolutionNd(input, outch, DimsHW{ksize, ksize}, weightMap[lname + "0.weight"], emptywts); + IConvolutionLayer* conv1 = + network->addConvolutionNd(input, outch, DimsHW{ksize, ksize}, weightMap[lname + "0.weight"], emptywts); assert(conv1); conv1->setStrideNd(DimsHW{s, s}); conv1->setPaddingNd(DimsHW{p, p}); conv1->setNbGroups(g); IScaleLayer* bn1 = addBatchNorm(network, weightMap, *conv1->getOutput(0), lname + "1", 1e-5); - ILayer* hsw = hSwish(network, *bn1->getOutput(0), lname+"2"); + ILayer* hsw = hSwish(network, *bn1->getOutput(0), lname + "2"); assert(hsw); return hsw; } -ILayer* seLayer(INetworkDefinition *network, std::map& weightMap, ITensor& input, int c, int w, std::string lname) { +ILayer* 
seLayer(INetworkDefinition* network, std::map& weightMap, ITensor& input, int c, int w, + std::string lname) { int h = w; IPoolingLayer* l1 = network->addPoolingNd(input, PoolingType::kAVERAGE, DimsHW(w, h)); assert(l1); l1->setStrideNd(DimsHW{w, h}); - IFullyConnectedLayer* l2 = network->addFullyConnected(*l1->getOutput(0), BS*c/4, weightMap[lname+"fc.0.weight"], weightMap[lname+"fc.0.bias"]); + IFullyConnectedLayer* l2 = network->addFullyConnected( + *l1->getOutput(0), BS * c / 4, weightMap[lname + "fc.0.weight"], weightMap[lname + "fc.0.bias"]); IActivationLayer* relu1 = network->addActivation(*l2->getOutput(0), ActivationType::kRELU); - IFullyConnectedLayer* l4 = network->addFullyConnected(*relu1->getOutput(0), BS*c, weightMap[lname+"fc.2.weight"], weightMap[lname+"fc.2.bias"]); + IFullyConnectedLayer* l4 = network->addFullyConnected( + *relu1->getOutput(0), BS * c, weightMap[lname + "fc.2.weight"], weightMap[lname + "fc.2.bias"]); auto hsig = network->addActivation(*l4->getOutput(0), ActivationType::kHARD_SIGMOID); assert(hsig); @@ -153,10 +154,12 @@ ILayer* seLayer(INetworkDefinition *network, std::map& wei return se; } -ILayer* convSeq1(INetworkDefinition *network, std::map& weightMap, ITensor& input, int output, int hdim, int k, int s, bool use_se, bool use_hs, int w, std::string lname) { +ILayer* convSeq1(INetworkDefinition* network, std::map& weightMap, ITensor& input, int output, + int hdim, int k, int s, bool use_se, bool use_hs, int w, std::string lname) { Weights emptywts{DataType::kFLOAT, nullptr, 0}; int p = (k - 1) / 2; - IConvolutionLayer* conv1 = network->addConvolutionNd(input, hdim, DimsHW{k, k}, weightMap[lname + "0.weight"], emptywts); + IConvolutionLayer* conv1 = + network->addConvolutionNd(input, hdim, DimsHW{k, k}, weightMap[lname + "0.weight"], emptywts); conv1->setStrideNd(DimsHW{s, s}); conv1->setPaddingNd(DimsHW{p, p}); conv1->setNbGroups(hdim); @@ -166,28 +169,31 @@ ILayer* convSeq1(INetworkDefinition *network, std::map& we tensor3 = 
nullptr; tensor4 = nullptr; if (use_hs) { - ILayer* hsw = hSwish(network, *bn1->getOutput(0), lname+"2"); + ILayer* hsw = hSwish(network, *bn1->getOutput(0), lname + "2"); tensor3 = hsw->getOutput(0); } else { IActivationLayer* relu1 = network->addActivation(*bn1->getOutput(0), ActivationType::kRELU); tensor3 = relu1->getOutput(0); } if (use_se) { - ILayer* se1 = seLayer(network, weightMap, *tensor3, hdim, w, lname + "3."); - tensor4 = se1->getOutput(0); + ILayer* se1 = seLayer(network, weightMap, *tensor3, hdim, w, lname + "3."); + tensor4 = se1->getOutput(0); } else { - tensor4 = tensor3; + tensor4 = tensor3; } - IConvolutionLayer* conv2 = network->addConvolutionNd(*tensor4, output, DimsHW{1, 1}, weightMap[lname + "4.weight"], emptywts); + IConvolutionLayer* conv2 = + network->addConvolutionNd(*tensor4, output, DimsHW{1, 1}, weightMap[lname + "4.weight"], emptywts); IScaleLayer* bn2 = addBatchNorm(network, weightMap, *conv2->getOutput(0), lname + "5", 1e-5); assert(bn2); return bn2; } -ILayer* convSeq2(INetworkDefinition *network, std::map& weightMap, ITensor& input, int output, int hdim, int k, int s, bool use_se, bool use_hs, int w, std::string lname) { +ILayer* convSeq2(INetworkDefinition* network, std::map& weightMap, ITensor& input, int output, + int hdim, int k, int s, bool use_se, bool use_hs, int w, std::string lname) { Weights emptywts{DataType::kFLOAT, nullptr, 0}; int p = (k - 1) / 2; - IConvolutionLayer* conv1 = network->addConvolutionNd(input, hdim, DimsHW{1, 1}, weightMap[lname + "0.weight"], emptywts); + IConvolutionLayer* conv1 = + network->addConvolutionNd(input, hdim, DimsHW{1, 1}, weightMap[lname + "0.weight"], emptywts); IScaleLayer* bn1 = addBatchNorm(network, weightMap, *conv1->getOutput(0), lname + "1", 1e-5); ITensor *tensor3, *tensor6, *tensor7; tensor3 = nullptr; @@ -200,16 +206,17 @@ ILayer* convSeq2(INetworkDefinition *network, std::map& we IActivationLayer* relu1 = network->addActivation(*bn1->getOutput(0), ActivationType::kRELU); 
tensor3 = relu1->getOutput(0); } - IConvolutionLayer* conv2 = network->addConvolutionNd(*tensor3, hdim, DimsHW{k, k}, weightMap[lname + "3.weight"], emptywts); + IConvolutionLayer* conv2 = + network->addConvolutionNd(*tensor3, hdim, DimsHW{k, k}, weightMap[lname + "3.weight"], emptywts); conv2->setStrideNd(DimsHW{s, s}); conv2->setPaddingNd(DimsHW{p, p}); conv2->setNbGroups(hdim); IScaleLayer* bn2 = addBatchNorm(network, weightMap, *conv2->getOutput(0), lname + "4", 1e-5); if (use_se) { - ILayer* se1 = seLayer(network, weightMap, *bn2->getOutput(0), hdim, w, lname + "5."); - tensor6 = se1->getOutput(0); + ILayer* se1 = seLayer(network, weightMap, *bn2->getOutput(0), hdim, w, lname + "5."); + tensor6 = se1->getOutput(0); } else { - tensor6 = bn2->getOutput(0); + tensor6 = bn2->getOutput(0); } if (use_hs) { ILayer* hsw2 = hSwish(network, *tensor6, lname + "6"); @@ -218,30 +225,32 @@ ILayer* convSeq2(INetworkDefinition *network, std::map& we IActivationLayer* relu2 = network->addActivation(*tensor6, ActivationType::kRELU); tensor7 = relu2->getOutput(0); } - IConvolutionLayer* conv3 = network->addConvolutionNd(*tensor7, output, DimsHW{1, 1}, weightMap[lname + "7.weight"], emptywts); + IConvolutionLayer* conv3 = + network->addConvolutionNd(*tensor7, output, DimsHW{1, 1}, weightMap[lname + "7.weight"], emptywts); IScaleLayer* bn3 = addBatchNorm(network, weightMap, *conv3->getOutput(0), lname + "8", 1e-5); assert(bn3); return bn3; } -ILayer* invertedRes(INetworkDefinition *network, std::map& weightMap, ITensor& input, std::string lname, int inch, int outch, int s, int hidden, int k, bool use_se, bool use_hs, int w) { +ILayer* invertedRes(INetworkDefinition* network, std::map& weightMap, ITensor& input, + std::string lname, int inch, int outch, int s, int hidden, int k, bool use_se, bool use_hs, int w) { bool use_res_connect = (s == 1 && inch == outch); - ILayer *conv = nullptr; + ILayer* conv = nullptr; if (inch == hidden) { conv = convSeq1(network, weightMap, input, 
outch, hidden, k, s, use_se, use_hs, w, lname + "conv."); } else { conv = convSeq2(network, weightMap, input, outch, hidden, k, s, use_se, use_hs, w, lname + "conv."); } - if (!use_res_connect) return conv; + if (!use_res_connect) + return conv; IElementWiseLayer* ew3 = network->addElementWise(input, *conv->getOutput(0), ElementWiseOperation::kSUM); assert(ew3); return ew3; } // Creat the engine using only the API and not any parser. -ICudaEngine* createEngineSmall(unsigned int maxBatchSize, IBuilder* builder, IBuilderConfig* config, DataType dt) -{ +ICudaEngine* createEngineSmall(unsigned int maxBatchSize, IBuilder* builder, IBuilderConfig* config, DataType dt) { INetworkDefinition* network = builder->createNetworkV2(0U); ITensor* data = network->addInput(INPUT_BLOB_NAME, dt, Dims3{3, INPUT_H, INPUT_W}); @@ -271,11 +280,13 @@ ICudaEngine* createEngineSmall(unsigned int maxBatchSize, IBuilder* builder, IBu pool1->setStrideNd(DimsHW{7, 7}); ILayer* sw1 = hSwish(network, *pool1->getOutput(0), "hSwish.0"); - IFullyConnectedLayer* fc1 = network->addFullyConnected(*sw1->getOutput(0), 1280, weightMap["classifier.0.weight"], weightMap["classifier.0.bias"]); + IFullyConnectedLayer* fc1 = network->addFullyConnected(*sw1->getOutput(0), 1280, weightMap["classifier.0.weight"], + weightMap["classifier.0.bias"]); assert(fc1); ILayer* bn1 = addBatchNorm(network, weightMap, *fc1->getOutput(0), "classifier.1", 1e-5); ILayer* sw2 = hSwish(network, *bn1->getOutput(0), "hSwish.1"); - IFullyConnectedLayer* fc2 = network->addFullyConnected(*sw2->getOutput(0), 1000, weightMap["classifier.3.weight"], weightMap["classifier.3.bias"]); + IFullyConnectedLayer* fc2 = network->addFullyConnected(*sw2->getOutput(0), 1000, weightMap["classifier.3.weight"], + weightMap["classifier.3.bias"]); ILayer* bn2 = addBatchNorm(network, weightMap, *fc2->getOutput(0), "classifier.4", 1e-5); ILayer* sw3 = hSwish(network, *bn2->getOutput(0), "hSwish.2"); @@ -293,16 +304,14 @@ ICudaEngine* 
createEngineSmall(unsigned int maxBatchSize, IBuilder* builder, IBu network->destroy(); // Release host memory - for (auto& mem : weightMap) - { - free((void*) (mem.second.values)); + for (auto& mem : weightMap) { + free((void*)(mem.second.values)); } return engine; } -ICudaEngine* createEngineLarge(unsigned int maxBatchSize, IBuilder* builder, IBuilderConfig* config, DataType dt) -{ +ICudaEngine* createEngineLarge(unsigned int maxBatchSize, IBuilder* builder, IBuilderConfig* config, DataType dt) { INetworkDefinition* network = builder->createNetworkV2(0U); ITensor* data = network->addInput(INPUT_BLOB_NAME, dt, Dims3{3, INPUT_H, INPUT_W}); @@ -335,10 +344,12 @@ ICudaEngine* createEngineLarge(unsigned int maxBatchSize, IBuilder* builder, IBu pool1->setStrideNd(DimsHW{7, 7}); ILayer* sw1 = hSwish(network, *pool1->getOutput(0), "hSwish.0"); - IFullyConnectedLayer* fc1 = network->addFullyConnected(*sw1->getOutput(0), 1280, weightMap["classifier.0.weight"], weightMap["classifier.0.bias"]); + IFullyConnectedLayer* fc1 = network->addFullyConnected(*sw1->getOutput(0), 1280, weightMap["classifier.0.weight"], + weightMap["classifier.0.bias"]); assert(fc1); ILayer* sw2 = hSwish(network, *fc1->getOutput(0), "hSwish.1"); - IFullyConnectedLayer* fc2 = network->addFullyConnected(*sw2->getOutput(0), 1000, weightMap["classifier.3.weight"], weightMap["classifier.3.bias"]); + IFullyConnectedLayer* fc2 = network->addFullyConnected(*sw2->getOutput(0), 1000, weightMap["classifier.3.weight"], + weightMap["classifier.3.bias"]); fc2->getOutput(0)->setName(OUTPUT_BLOB_NAME); std::cout << "set name out" << std::endl; @@ -354,16 +365,14 @@ ICudaEngine* createEngineLarge(unsigned int maxBatchSize, IBuilder* builder, IBu network->destroy(); // Release host memory - for (auto& mem : weightMap) - { - free((void*) (mem.second.values)); + for (auto& mem : weightMap) { + free((void*)(mem.second.values)); } return engine; } -void APIToModel(unsigned int maxBatchSize, IHostMemory** modelStream, 
std::string mode) -{ +void APIToModel(unsigned int maxBatchSize, IHostMemory** modelStream, std::string mode) { // Create builder IBuilder* builder = createInferBuilder(gLogger); IBuilderConfig* config = builder->createBuilderConfig(); @@ -384,12 +393,11 @@ void APIToModel(unsigned int maxBatchSize, IHostMemory** modelStream, std::strin // Close everything down engine->destroy(); - builder->destroy(); config->destroy(); + builder->destroy(); } -void doInference(IExecutionContext& context, float* input, float* output, int batchSize) -{ +void doInference(IExecutionContext& context, float* input, float* output, int batchSize) { const ICudaEngine& engine = context.getEngine(); // Pointers to input and output device buffers to pass to engine. @@ -411,9 +419,11 @@ void doInference(IExecutionContext& context, float* input, float* output, int ba CHECK(cudaStreamCreate(&stream)); // DMA input batch data to device, infer on the batch asynchronously, and DMA output back to host - CHECK(cudaMemcpyAsync(buffers[inputIndex], input, batchSize * 3 * INPUT_H * INPUT_W * sizeof(float), cudaMemcpyHostToDevice, stream)); + CHECK(cudaMemcpyAsync(buffers[inputIndex], input, batchSize * 3 * INPUT_H * INPUT_W * sizeof(float), + cudaMemcpyHostToDevice, stream)); context.enqueue(batchSize, buffers, stream, nullptr); - CHECK(cudaMemcpyAsync(output, buffers[outputIndex], batchSize * OUTPUT_SIZE * sizeof(float), cudaMemcpyDeviceToHost, stream)); + CHECK(cudaMemcpyAsync(output, buffers[outputIndex], batchSize * OUTPUT_SIZE * sizeof(float), cudaMemcpyDeviceToHost, + stream)); cudaStreamSynchronize(stream); // Release stream and buffers @@ -422,8 +432,7 @@ void doInference(IExecutionContext& context, float* input, float* output, int ba CHECK(cudaFree(buffers[outputIndex])); } -int main(int argc, char** argv) -{ +int main(int argc, char** argv) { if (argc != 3) { std::cerr << "arguments not right!" 
<< std::endl; std::cerr << "./mobilenet -s small // serialize small model to plan file" << std::endl; @@ -434,7 +443,7 @@ int main(int argc, char** argv) } // create a model using the API directly and serialize it to a stream - char *trtModelStream{nullptr}; + char* trtModelStream{nullptr}; size_t size{0}; std::string mode = std::string(argv[2]); std::cout << mode << std::endl; @@ -495,8 +504,7 @@ int main(int argc, char** argv) // Print histogram of the output distribution std::cout << "\nOutput:\n\n"; - for (unsigned int i = 0; i < OUTPUT_SIZE; i++) - { + for (unsigned int i = 0; i < OUTPUT_SIZE; i++) { std::cout << prob[i] << ", "; //if (i % 10 == 0) std::cout << i / 10 << std::endl; } diff --git a/mobilenet/mobilenetv3/mobilenet_v3.py b/mobilenet/mobilenetv3/mobilenet_v3.py index cc966c3f..07d45b88 100644 --- a/mobilenet/mobilenetv3/mobilenet_v3.py +++ b/mobilenet/mobilenetv3/mobilenet_v3.py @@ -4,8 +4,8 @@ import argparse import numpy as np -import pycuda.autoinit import pycuda.driver as cuda +import pycuda.autoinit # noqa: F401 import tensorrt as trt BATCH_SIZE = 1