support torch2trt() conversion for TensorRT 8.0 #581

Open · wants to merge 13 commits into base: master
2 changes: 1 addition & 1 deletion torch2trt/calibration.py
@@ -51,7 +51,7 @@ def get_batch(self, *args, **kwargs):
buffer[i].copy_(tensor)

self.count += 1

return [int(buf.data_ptr()) for buf in self.buffers]
else:
return []
3 changes: 2 additions & 1 deletion torch2trt/converters/AdaptiveAvgPool2d.py
@@ -16,7 +16,8 @@ def convert_AdaptiveAvgPool2d(ctx):

stride = (input_trt.shape[-2] // output_size[-2], input_trt.shape[-1] // output_size[-1])

kernel_size = stride
kernel_size = (input_trt.shape[-2] - (output_size[-2] - 1) * stride[-2], input_trt.shape[-1] - (output_size[-1] - 1) * stride[-1])

layer = ctx.network.add_pooling(
input=input_trt, type=trt.PoolingType.AVERAGE, window_size=kernel_size)
layer.stride = stride
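For readers wondering where the new `kernel_size` formula comes from: with `stride = input // output`, a window of `input - (output - 1) * stride` makes the last pooling window end exactly at the input edge. It reproduces PyTorch's adaptive pooling exactly when the input size is a multiple of the output size, and is generally a closer approximation than the old `kernel_size = stride` otherwise; the AdaptiveAvgPool3d change below applies the same formula per spatial dimension. A quick standalone check of that claim (my example, not part of the PR):

```python
# Standalone sanity check (illustrative, not from this PR): compare PyTorch's
# adaptive pooling against a fixed-kernel average pool built with the
# converter's stride/kernel formula, for a size that does not divide evenly.
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 5, 5)
out_hw = (2, 2)

stride = (x.shape[-2] // out_hw[-2], x.shape[-1] // out_hw[-1])        # (2, 2)
kernel = (x.shape[-2] - (out_hw[-2] - 1) * stride[-2],                 # (3, 3)
          x.shape[-1] - (out_hw[-1] - 1) * stride[-1])

ref = F.adaptive_avg_pool2d(x, out_hw)
fixed = F.avg_pool2d(x, kernel_size=kernel, stride=stride)
print(torch.allclose(ref, fixed))  # True here: windows [0:3] and [2:5] match the adaptive ones
```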
3 changes: 2 additions & 1 deletion torch2trt/converters/AdaptiveAvgPool3d.py
@@ -22,7 +22,8 @@ def convert_AdaptiveAvgPool3d(ctx):
input_trt.shape[-1] // output_size[-1],
)

kernel_size = stride
kernel_size = (input_trt.shape[-3] - (output_size[-3] - 1) * stride[-3], input_trt.shape[-2] - (output_size[-2] - 1) * stride[-2], input_trt.shape[-1] - (output_size[-1] - 1) * stride[-1])

layer = ctx.network.add_pooling_nd(
input=input_trt,
type=trt.PoolingType.AVERAGE,
2 changes: 1 addition & 1 deletion torch2trt/converters/__init__.py
@@ -4,7 +4,6 @@

# supported converters will override dummy converters

from .AdaptiveAvgPool2d import *
from .BatchNorm1d import *
from .BatchNorm2d import *
from .clone import *
@@ -38,6 +37,7 @@
from .instance_norm import *
from .interpolate import *
from .layer_norm import *
from .matmul import *
from .max import *
from .max_pool1d import *
from .max_pool2d import *
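The star-import is what actually registers the new converter: importing `matmul.py` runs its `@tensorrt_converter` decorators, which record `convert_matmul` against `torch.matmul` / `torch.mm` / `torch.bmm` in torch2trt's converter table. A simplified sketch of that mechanism (toy names, not torch2trt's real internals):

```python
# Toy version of decorator-based converter registration (illustrative only).
CONVERTERS = {}

def tensorrt_converter(method, enabled=True):
    def register(fn):
        if enabled:
            CONVERTERS[method] = fn  # looked up later when the call is intercepted
        return fn
    return register

@tensorrt_converter('torch.matmul')
def convert_matmul(ctx):
    ...  # build the corresponding TensorRT layer(s)
```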
41 changes: 41 additions & 0 deletions torch2trt/converters/interpolate.py
@@ -94,6 +94,47 @@ def convert_interpolate_trt7(ctx):

output._trt = layer.get_output(0)

@tensorrt_converter('torch.nn.functional.interpolate', enabled=trt_version() >= '8.0')
@tensorrt_converter('torch.nn.functional.upsample', enabled=trt_version() >= '8.0')
def convert_interpolate_trt8(ctx):
    # parse args
    input = get_arg(ctx, 'input', pos=0, default=None)
    size = get_arg(ctx, 'size', pos=1, default=None)
    scale_factor = get_arg(ctx, 'scale_factor', pos=2, default=None)
    mode = get_arg(ctx, 'mode', pos=3, default='nearest')
    align_corners = get_arg(ctx, 'align_corners', pos=4, default=None)

    input_dim = input.dim() - 2

    input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
    output = ctx.method_return
    layer = ctx.network.add_resize(input=input_trt)

    shape = size
    if shape is not None:
        if isinstance(shape, collections.Sequence):
            shape = [input.size(1)] + list(shape)
        else:
            shape = [input.size(1)] + [shape] * input_dim

        layer.shape = shape

    scales = scale_factor
    if scales is not None:
        if not isinstance(scales, collections.Sequence):
            scales = [scales] * input_dim
        layer.scales = [1] + list(scales)

    resize_mode = mode
    if resize_mode.lower() in ["linear", "bilinear", "trilinear"]:
        layer.resize_mode = trt.ResizeMode.LINEAR
    else:
        layer.resize_mode = trt.ResizeMode.NEAREST

    if align_corners:
        layer.coordinate_transformation = trt.ResizeCoordinateTransformation.ALIGN_CORNERS

    output._trt = layer.get_output(0)

class Interpolate(torch.nn.Module):
def __init__(self, size=None,scale_factor=None, mode=None, align_corners=None):
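A minimal usage sketch for the new TensorRT 8 path, assuming a CUDA machine with TensorRT >= 8.0 and torch2trt installed; the module and tensor shapes are illustrative, not from the PR:

```python
# Illustrative conversion exercising convert_interpolate_trt8 (TRT >= 8.0).
import torch
import torch.nn.functional as F
from torch2trt import torch2trt

class Upsample2x(torch.nn.Module):
    def forward(self, x):
        # bilinear + align_corners exercises the LINEAR resize mode and the
        # ALIGN_CORNERS coordinate transformation set by the converter above
        return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)

model = Upsample2x().cuda().eval()
x = torch.randn(1, 3, 32, 32).cuda()

model_trt = torch2trt(model, [x])
print(torch.max(torch.abs(model(x) - model_trt(x))))  # expect only a small numerical difference
```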
38 changes: 38 additions & 0 deletions torch2trt/converters/matmul.py
@@ -0,0 +1,38 @@
from torch2trt.torch2trt import *
from torch2trt.module_test import add_module_test


@tensorrt_converter('torch.matmul')
@tensorrt_converter('torch.mm')
@tensorrt_converter('torch.bmm')
def convert_matmul(ctx):
    input_a = ctx.method_args[0]
    input_b = ctx.method_args[1]
    input_a_trt, input_b_trt = trt_(ctx.network, input_a, input_b)
    output = ctx.method_return
    layer = ctx.network.add_matrix_multiply(input_a_trt, trt.MatrixOperation.NONE, input_b_trt, trt.MatrixOperation.NONE)
    output._trt = layer.get_output(0)


class MatMul(torch.nn.Module):
    def __init__(self):
        super(MatMul, self).__init__()

    def forward(self, x, y):
        return torch.matmul(x, y)

@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4), (1, 4, 5)])
def test_matmul_basic():
    return MatMul()


class BatchMatMul(torch.nn.Module):
    def __init__(self):
        super(BatchMatMul, self).__init__()

    def forward(self, x, y):
        return torch.bmm(x, y)

@add_module_test(torch.float32, torch.device('cuda'), [(10, 3, 4), (10, 4, 5)], max_batch_size=10)
def test_batchmatmul_basic():
    return BatchMatMul()
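The registered tests above already exercise the converter through `add_module_test`; the same check written as a direct conversion looks roughly like this (shapes mirror `test_matmul_basic`; the snippet is illustrative, not part of the PR):

```python
# Illustrative direct conversion of the MatMul module defined above.
import torch
from torch2trt import torch2trt

class MatMul(torch.nn.Module):
    def forward(self, x, y):
        return torch.matmul(x, y)

model = MatMul().cuda().eval()
a = torch.randn(1, 3, 4).cuda()   # same shapes as test_matmul_basic
b = torch.randn(1, 4, 5).cuda()

model_trt = torch2trt(model, [a, b])
print(torch.max(torch.abs(model(a, b) - model_trt(a, b))))
```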
2 changes: 1 addition & 1 deletion torch2trt/converters/mod.py
@@ -64,7 +64,7 @@ def __init__(self):
super(ModAssign, self).__init__()

def forward(self, x, y):
x %= y
x = x % y
return x


50 changes: 25 additions & 25 deletions torch2trt/plugins/group_norm.cpp
@@ -112,19 +112,19 @@ class GroupNormPlugin : public IPluginV2 {
return data_str.str();
}

const char* getPluginType() const override {
const char* getPluginType() const noexcept override {
return "group_norm";
};

const char* getPluginVersion() const override {
const char* getPluginVersion() const noexcept override {
return "1";
}

int getNbOutputs() const override {
int getNbOutputs() const noexcept override {
return 1;
}

Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override {
Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) noexcept override {
Dims dims;
dims.nbDims = inputs->nbDims;

@@ -135,8 +135,8 @@ class GroupNormPlugin : public IPluginV2 {
return dims;
}

bool supportsFormat(DataType type, PluginFormat format) const override {
if (format != PluginFormat::kNCHW) {
bool supportsFormat(DataType type, PluginFormat format) const noexcept override {
if (format != PluginFormat::kLINEAR) {
return false;
}
if (type == DataType::kINT32 || type == DataType::kINT8) {
@@ -146,7 +146,7 @@
}

void configureWithFormat(const Dims* inputDims, int nbInputs, const Dims* outputDims,
int nbOutputs, DataType type, PluginFormat format, int maxBatchSize) override {
int nbOutputs, DataType type, PluginFormat format, int maxBatchSize) noexcept override {

// set data type
if (type == DataType::kFLOAT) {
@@ -170,7 +170,7 @@
}
}

int initialize() override {
int initialize() noexcept override {
// set device
tensor_options = tensor_options.device(c10::kCUDA);

@@ -188,11 +188,11 @@
return 0;
}

void terminate() override {}
void terminate() noexcept override {}

size_t getWorkspaceSize(int maxBatchSize) const override { return 0; }

int enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) override {
size_t getWorkspaceSize(int maxBatchSize) const noexcept override { return 0; }
int32_t enqueue(int32_t batchSize, void const* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) noexcept override {
// get input / output dimensions
std::vector<long> batch_input_sizes = input_sizes;
std::vector<long> batch_output_sizes = output_sizes;
@@ -235,25 +235,25 @@
}


size_t getSerializationSize() const override {
size_t getSerializationSize() const noexcept override {
return serializeToString().size();
}

void serialize(void* buffer) const override {
void serialize(void* buffer) const noexcept override {
std::string data = serializeToString();
size_t size = getSerializationSize();
data.copy((char *) buffer, size);
}

void destroy() override {}
void destroy() noexcept override {}

IPluginV2* clone() const override {
IPluginV2* clone() const noexcept override {
return new GroupNormPlugin(num_groups, weight, bias, eps);
}

void setPluginNamespace(const char* pluginNamespace) override {}
void setPluginNamespace(const char* pluginNamespace) noexcept override {}

const char *getPluginNamespace() const override {
const char *getPluginNamespace() const noexcept override {
return "torch2trt";
}

Expand All @@ -263,26 +263,26 @@ class GroupNormPluginCreator : public IPluginCreator {
public:
GroupNormPluginCreator() {}

const char *getPluginNamespace() const override {
const char *getPluginNamespace() const noexcept override {
return "torch2trt";
}

const char *getPluginName() const override {
const char *getPluginName() const noexcept override {
return "group_norm";
}

const char *getPluginVersion() const override {
const char *getPluginVersion() const noexcept override {
return "1";
}

IPluginV2 *deserializePlugin(const char *name, const void *data, size_t length) override {
IPluginV2 *deserializePlugin(const char *name, const void *data, size_t length) noexcept override {
return new GroupNormPlugin((const char*) data, length);
}

void setPluginNamespace(const char *N) override {}
const PluginFieldCollection *getFieldNames() override { return nullptr; }
void setPluginNamespace(const char *N) noexcept override {}
const PluginFieldCollection *getFieldNames() noexcept override { return nullptr; }

IPluginV2 *createPlugin(const char *name, const PluginFieldCollection *fc) override { return nullptr; }
IPluginV2 *createPlugin(const char *name, const PluginFieldCollection *fc) noexcept override { return nullptr; }

};
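As far as I can tell, the group_norm.cpp edits above are the mechanical TensorRT 8 plugin-API migration: TensorRT 8 declares the IPluginV2 and IPluginCreator virtual methods noexcept, so every override has to add the specifier to keep compiling; PluginFormat::kNCHW is replaced by PluginFormat::kLINEAR; and enqueue now takes an int32_t batch size with void const* const* inputs and void* const* outputs, which is why that signature changes as well.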
