From 751f130f608d16531980789cc68cccc87c55332b Mon Sep 17 00:00:00 2001 From: Ibrahim Abedrabbo Date: Wed, 30 Jun 2021 18:56:03 +0300 Subject: [PATCH 1/8] support torch2trt() conversion for TensorRt 8.0; fix unit tests for trt 8.0. --- torch2trt/converters/interpolate.py | 41 +++++++++++++++++++++++++++++ torch2trt/converters/mod.py | 2 +- torch2trt/test.py | 1 - torch2trt/torch2trt.py | 8 +++--- 4 files changed, 46 insertions(+), 6 deletions(-) diff --git a/torch2trt/converters/interpolate.py b/torch2trt/converters/interpolate.py index dfa20d19..17320b30 100644 --- a/torch2trt/converters/interpolate.py +++ b/torch2trt/converters/interpolate.py @@ -90,6 +90,47 @@ def convert_interpolate_trt7(ctx): output._trt = layer.get_output(0) +@tensorrt_converter('torch.nn.functional.interpolate', enabled=trt_version() >= '8.0') +@tensorrt_converter('torch.nn.functional.upsample', enabled=trt_version() >= '8.0') +def convert_interpolate_trt8(ctx): + #parse args + input = get_arg(ctx, 'input', pos=0, default=None) + size = get_arg(ctx, 'size', pos=1, default=None) + scale_factor=get_arg(ctx, 'scale_factor', pos=2, default=None) + mode = get_arg(ctx, 'mode', pos=3, default='nearest') + align_corners = get_arg(ctx, 'align_corners', pos=4, default=None) + + input_dim = input.dim() - 2 + + input_trt = add_missing_trt_tensors(ctx.network, [input])[0] + output = ctx.method_return + layer = ctx.network.add_resize(input=input_trt) + + shape = size + if shape != None: + if isinstance(shape, collections.Sequence): + shape = [input.size(1)] + list(shape) + else: + shape = [input.size(1)] + [shape] * input_dim + + layer.shape = shape + + scales = scale_factor + if scales != None: + if not isinstance(scales, collections.Sequence): + scales = [scales] * input_dim + layer.scales = [1] + list(scales) + + resize_mode = mode + if resize_mode.lower() in ["linear","bilinear","trilinear"]: + layer.resize_mode = trt.ResizeMode.LINEAR + else: + layer.resize_mode=trt.ResizeMode.NEAREST + + if align_corners: + layer.coordinate_transformation = trt.ResizeCoordinateTransformation.ALIGN_CORNERS + + output._trt = layer.get_output(0) class Interpolate(torch.nn.Module): def __init__(self, size, mode, align_corners): diff --git a/torch2trt/converters/mod.py b/torch2trt/converters/mod.py index 6cf69435..2bead5fa 100644 --- a/torch2trt/converters/mod.py +++ b/torch2trt/converters/mod.py @@ -64,7 +64,7 @@ def __init__(self): super(ModAssign, self).__init__() def forward(self, x, y): - x %= y + x = x % y return x diff --git a/torch2trt/test.py b/torch2trt/test.py index dec9bb88..f7bd3518 100644 --- a/torch2trt/test.py +++ b/torch2trt/test.py @@ -108,7 +108,6 @@ def run(self): num_tests, num_success, num_tolerance, num_error = 0, 0, 0, 0 for test in MODULE_TESTS: - # filter by module name name = test.module_name() if not re.search(args.name, name): diff --git a/torch2trt/torch2trt.py b/torch2trt/torch2trt.py index 6b153a02..61776b1a 100644 --- a/torch2trt/torch2trt.py +++ b/torch2trt/torch2trt.py @@ -556,10 +556,10 @@ def torch2trt(module, outputs = (outputs,) ctx.mark_outputs(outputs, output_names) - builder.max_workspace_size = max_workspace_size - builder.fp16_mode = fp16_mode + config = builder.create_builder_config() + config.max_workspace_size = max_workspace_size + config.flags = fp16_mode << int(trt.BuilderFlag.FP16) | strict_type_constraints << int(trt.BuilderFlag.STRICT_TYPES) builder.max_batch_size = max_batch_size - builder.strict_type_constraints = strict_type_constraints if int8_mode: @@ -574,7 +574,7 @@ def torch2trt(module, inputs, int8_calib_dataset, batch_size=int8_calib_batch_size, algorithm=int8_calib_algorithm ) - engine = builder.build_cuda_engine(network) + engine = builder.build_engine(network, config) module_trt = TRTModule(engine, input_names, output_names) From fe1fda8303612724df97f3f6b6994c44d89b3c34 Mon Sep 17 00:00:00 2001 From: Ibrahim Abedrabbo Date: Wed, 30 Jun 2021 20:53:10 +0300 Subject: [PATCH 2/8] fix setting 'int8_mode' + 'int8_calibrator' --- torch2trt/torch2trt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch2trt/torch2trt.py b/torch2trt/torch2trt.py index 61776b1a..f68e8286 100644 --- a/torch2trt/torch2trt.py +++ b/torch2trt/torch2trt.py @@ -567,10 +567,10 @@ def torch2trt(module, if int8_calib_dataset is None: int8_calib_dataset = TensorBatchDataset(inputs_in) - builder.int8_mode = True + config.flags = config.flags | int8_mode << int(trt.BuilderFlag.INT8) # @TODO(jwelsh): Should we set batch_size=max_batch_size? Need to investigate memory consumption - builder.int8_calibrator = DatasetCalibrator( + config.int8_calibrator = DatasetCalibrator( inputs, int8_calib_dataset, batch_size=int8_calib_batch_size, algorithm=int8_calib_algorithm ) From 4d4cb7e831589892344e5b20766064a541f21fa7 Mon Sep 17 00:00:00 2001 From: Ibrahim Abedrabbo Date: Mon, 20 Sep 2021 17:00:40 +0300 Subject: [PATCH 3/8] (chitoku:jp4.6_tensorrt8) Update plugin classes to match the base IPluginV2 class in TensorRT8 --- torch2trt/plugins/group_norm.cpp | 50 +++++++++++++++---------------- torch2trt/plugins/interpolate.cpp | 50 +++++++++++++++---------------- 2 files changed, 50 insertions(+), 50 deletions(-) diff --git a/torch2trt/plugins/group_norm.cpp b/torch2trt/plugins/group_norm.cpp index ccc7b51d..3247553a 100644 --- a/torch2trt/plugins/group_norm.cpp +++ b/torch2trt/plugins/group_norm.cpp @@ -112,19 +112,19 @@ class GroupNormPlugin : public IPluginV2 { return data_str.str(); } - const char* getPluginType() const override { + const char* getPluginType() const noexcept override { return "group_norm"; }; - const char* getPluginVersion() const override { + const char* getPluginVersion() const noexcept override { return "1"; } - int getNbOutputs() const override { + int getNbOutputs() const noexcept override { return 1; } - Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override { + Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) noexcept override { Dims dims; dims.nbDims = inputs->nbDims; @@ -135,8 +135,8 @@ class GroupNormPlugin : public IPluginV2 { return dims; } - bool supportsFormat(DataType type, PluginFormat format) const override { - if (format != PluginFormat::kNCHW) { + bool supportsFormat(DataType type, PluginFormat format) const noexcept override { + if (format != PluginFormat::kLINEAR) { return false; } if (type == DataType::kINT32 || type == DataType::kINT8) { @@ -146,7 +146,7 @@ class GroupNormPlugin : public IPluginV2 { } void configureWithFormat(const Dims* inputDims, int nbInputs, const Dims* outputDims, - int nbOutputs, DataType type, PluginFormat format, int maxBatchSize) override { + int nbOutputs, DataType type, PluginFormat format, int maxBatchSize) noexcept override { // set data type if (type == DataType::kFLOAT) { @@ -170,7 +170,7 @@ class GroupNormPlugin : public IPluginV2 { } } - int initialize() override { + int initialize() noexcept override { // set device tensor_options = tensor_options.device(c10::kCUDA); @@ -188,11 +188,11 @@ class GroupNormPlugin : public IPluginV2 { return 0; } - void terminate() override {} + void terminate() noexcept override {} - size_t getWorkspaceSize(int maxBatchSize) const override { return 0; } - - int enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) override { + size_t getWorkspaceSize(int maxBatchSize) const noexcept override { return 0; } + int32_t enqueue(int32_t batchSize, void const* const* inputs, void* const* outputs, void* workspace, + cudaStream_t stream) noexcept override { // get input / output dimensions std::vector batch_input_sizes = input_sizes; std::vector batch_output_sizes = output_sizes; @@ -235,25 +235,25 @@ class GroupNormPlugin : public IPluginV2 { } - size_t getSerializationSize() const override { + size_t getSerializationSize() const noexcept override { return serializeToString().size(); } - void serialize(void* buffer) const override { + void serialize(void* buffer) const noexcept override { std::string data = serializeToString(); size_t size = getSerializationSize(); data.copy((char *) buffer, size); } - void destroy() override {} + void destroy() noexcept override {} - IPluginV2* clone() const override { + IPluginV2* clone() const noexcept override { return new GroupNormPlugin(num_groups, weight, bias, eps); } - void setPluginNamespace(const char* pluginNamespace) override {} + void setPluginNamespace(const char* pluginNamespace) noexcept override {} - const char *getPluginNamespace() const override { + const char *getPluginNamespace() const noexcept override { return "torch2trt"; } @@ -263,26 +263,26 @@ class GroupNormPluginCreator : public IPluginCreator { public: GroupNormPluginCreator() {} - const char *getPluginNamespace() const override { + const char *getPluginNamespace() const noexcept override { return "torch2trt"; } - const char *getPluginName() const override { + const char *getPluginName() const noexcept override { return "group_norm"; } - const char *getPluginVersion() const override { + const char *getPluginVersion() const noexcept override { return "1"; } - IPluginV2 *deserializePlugin(const char *name, const void *data, size_t length) override { + IPluginV2 *deserializePlugin(const char *name, const void *data, size_t length) noexcept override { return new GroupNormPlugin((const char*) data, length); } - void setPluginNamespace(const char *N) override {} - const PluginFieldCollection *getFieldNames() override { return nullptr; } + void setPluginNamespace(const char *N) noexcept override {} + const PluginFieldCollection *getFieldNames() noexcept override { return nullptr; } - IPluginV2 *createPlugin(const char *name, const PluginFieldCollection *fc) override { return nullptr; } + IPluginV2 *createPlugin(const char *name, const PluginFieldCollection *fc) noexcept override { return nullptr; } }; diff --git a/torch2trt/plugins/interpolate.cpp b/torch2trt/plugins/interpolate.cpp index cf463d0f..cc06a922 100644 --- a/torch2trt/plugins/interpolate.cpp +++ b/torch2trt/plugins/interpolate.cpp @@ -103,19 +103,19 @@ class InterpolatePlugin : public IPluginV2 { return data_str.str(); } - const char* getPluginType() const override { + const char* getPluginType() const noexcept override { return "interpolate"; }; - const char* getPluginVersion() const override { + const char* getPluginVersion() const noexcept override { return "1"; } - int getNbOutputs() const override { + int getNbOutputs() const noexcept override { return 1; } - Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override { + Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) noexcept override { Dims dims; dims.nbDims = inputs->nbDims; @@ -127,8 +127,8 @@ class InterpolatePlugin : public IPluginV2 { return dims; } - bool supportsFormat(DataType type, PluginFormat format) const override { - if (format != PluginFormat::kNCHW) { + bool supportsFormat(DataType type, PluginFormat format) const noexcept override { + if (format != PluginFormat::kLINEAR) { return false; } if (type == DataType::kINT32 || type == DataType::kINT8) { @@ -138,7 +138,7 @@ class InterpolatePlugin : public IPluginV2 { } void configureWithFormat(const Dims* inputDims, int nbInputs, const Dims* outputDims, - int nbOutputs, DataType type, PluginFormat format, int maxBatchSize) override { + int nbOutputs, DataType type, PluginFormat format, int maxBatchSize) noexcept override { // set data type if (type == DataType::kFLOAT) { @@ -162,7 +162,7 @@ class InterpolatePlugin : public IPluginV2 { } } - int initialize() override { + int initialize() noexcept override { // set device tensor_options = tensor_options.device(c10::kCUDA); @@ -176,11 +176,11 @@ class InterpolatePlugin : public IPluginV2 { return 0; } - void terminate() override {} + void terminate() noexcept override {} - size_t getWorkspaceSize(int maxBatchSize) const override { return 0; } - - int enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) override { + size_t getWorkspaceSize(int maxBatchSize) const noexcept override { return 0; } + int32_t enqueue(int32_t batchSize, void const* const* inputs, void* const* outputs, void* workspace, + cudaStream_t stream) noexcept override { // get input / output dimensions std::vector batch_input_sizes = input_sizes; std::vector batch_output_sizes = output_sizes; @@ -227,25 +227,25 @@ class InterpolatePlugin : public IPluginV2 { return 0; } - size_t getSerializationSize() const override { + size_t getSerializationSize() const noexcept override { return serializeToString().size(); } - void serialize(void* buffer) const override { + void serialize(void* buffer) const noexcept override { std::string data = serializeToString(); size_t size = getSerializationSize(); data.copy((char *) buffer, size); } - void destroy() override {} + void destroy() noexcept override {} - IPluginV2* clone() const override { + IPluginV2* clone() const noexcept override { return new InterpolatePlugin(size, mode, align_corners); } - void setPluginNamespace(const char* pluginNamespace) override {} + void setPluginNamespace(const char* pluginNamespace) noexcept override {} - const char *getPluginNamespace() const override { + const char *getPluginNamespace() const noexcept override { return "torch2trt"; } @@ -255,26 +255,26 @@ class InterpolatePluginCreator : public IPluginCreator { public: InterpolatePluginCreator() {} - const char *getPluginNamespace() const override { + const char *getPluginNamespace() const noexcept override { return "torch2trt"; } - const char *getPluginName() const override { + const char *getPluginName() const noexcept override { return "interpolate"; } - const char *getPluginVersion() const override { + const char *getPluginVersion() const noexcept override { return "1"; } - IPluginV2 *deserializePlugin(const char *name, const void *data, size_t length) override { + IPluginV2 *deserializePlugin(const char *name, const void *data, size_t length) noexcept override { return new InterpolatePlugin((const char*) data, length); } - void setPluginNamespace(const char *N) override {} - const PluginFieldCollection *getFieldNames() override { return nullptr; } + void setPluginNamespace(const char *N) noexcept override {} + const PluginFieldCollection *getFieldNames() noexcept override { return nullptr; } - IPluginV2 *createPlugin(const char *name, const PluginFieldCollection *fc) override { return nullptr; } + IPluginV2 *createPlugin(const char *name, const PluginFieldCollection *fc) noexcept override { return nullptr; } }; From 5522e23824d4ea56176f0d9060d5819f0e124696 Mon Sep 17 00:00:00 2001 From: orilador Date: Mon, 18 Oct 2021 13:54:54 +0300 Subject: [PATCH 4/8] added matmul --- torch2trt/converters/__init__.py | 1 + torch2trt/converters/matmul.py | 38 ++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+) create mode 100644 torch2trt/converters/matmul.py diff --git a/torch2trt/converters/__init__.py b/torch2trt/converters/__init__.py index 0835d7f3..0dee59d8 100644 --- a/torch2trt/converters/__init__.py +++ b/torch2trt/converters/__init__.py @@ -35,6 +35,7 @@ from .instance_norm import * from .interpolate import * from .layer_norm import * +from .matmul import * from .max import * from .max_pool2d import * from .mean import * diff --git a/torch2trt/converters/matmul.py b/torch2trt/converters/matmul.py new file mode 100644 index 00000000..fda01951 --- /dev/null +++ b/torch2trt/converters/matmul.py @@ -0,0 +1,38 @@ +from torch2trt.torch2trt import * +from torch2trt.module_test import add_module_test + + +@tensorrt_converter('torch.matmul') +@tensorrt_converter('torch.mm') +@tensorrt_converter('torch.bmm') +def convert_matmul(ctx): + input_a = ctx.method_args[0] + input_b = ctx.method_args[1] + input_a_trt, input_b_trt = trt_(ctx.network, input_a, input_b) + output = ctx.method_return + layer = ctx.network.add_matrix_multiply(input_a_trt, trt.MatrixOperation.NONE, input_b_trt, trt.MatrixOperation.NONE) + output._trt = layer.get_output(0) + + +class MatMul(torch.nn.Module): + def __init__(self): + super(MatMul, self).__init__() + + def forward(self, x, y): + return torch.matmul(x, y) + +@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4), (1, 4, 5)]) +def test_matmul_basic(): + return MatMul() + + +class BatchMatMul(torch.nn.Module): + def __init__(self): + super(BatchMatMul, self).__init__() + + def forward(self, x, y): + return torch.bmm(x, y) + +@add_module_test(torch.float32, torch.device('cuda'), [(10, 3, 4), (10, 4, 5)], max_batch_size=10) +def test_batchmatmul_basic(): + return BatchMatMul() \ No newline at end of file From 1295f38125d411b42b6e92c5e0129c89c7d242d4 Mon Sep 17 00:00:00 2001 From: orilador Date: Thu, 21 Oct 2021 02:35:10 +0300 Subject: [PATCH 5/8] fixed wrong kernel_size in AdaptiveAvgPool 2d/3d --- torch2trt/converters/AdaptiveAvgPool2d.py | 3 ++- torch2trt/converters/AdaptiveAvgPool3d.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/torch2trt/converters/AdaptiveAvgPool2d.py b/torch2trt/converters/AdaptiveAvgPool2d.py index 41ad141d..cd7c2b79 100644 --- a/torch2trt/converters/AdaptiveAvgPool2d.py +++ b/torch2trt/converters/AdaptiveAvgPool2d.py @@ -16,7 +16,8 @@ def convert_AdaptiveAvgPool2d(ctx): stride = (input_trt.shape[-2] // output_size[-2], input_trt.shape[-1] // output_size[-1]) - kernel_size = stride + kernel_size = (input_trt.shape[-2] - (output_size[-2] - 1) * stride[-2], input_trt.shape[-1] - (output_size[-1] - 1) * stride[-1]) + layer = ctx.network.add_pooling( input=input_trt, type=trt.PoolingType.AVERAGE, window_size=kernel_size) layer.stride = stride diff --git a/torch2trt/converters/AdaptiveAvgPool3d.py b/torch2trt/converters/AdaptiveAvgPool3d.py index 6c1b4302..9fd5b6dd 100644 --- a/torch2trt/converters/AdaptiveAvgPool3d.py +++ b/torch2trt/converters/AdaptiveAvgPool3d.py @@ -22,7 +22,8 @@ def convert_AdaptiveAvgPool3d(ctx): input_trt.shape[-1] // output_size[-1], ) - kernel_size = stride + kernel_size = (input_trt.shape[-3] - (output_size[-3] - 1) * stride[-3], input_trt.shape[-2] - (output_size[-2] - 1) * stride[-2], input_trt.shape[-1] - (output_size[-1] - 1) * stride[-1]) + layer = ctx.network.add_pooling_nd( input=input_trt, type=trt.PoolingType.AVERAGE, From 441c46ab06014a6663a3e6136a365b98938a8c54 Mon Sep 17 00:00:00 2001 From: Ori Lador Date: Tue, 26 Oct 2021 00:38:39 +0300 Subject: [PATCH 6/8] removed unnecessary include --- torch2trt/converters/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torch2trt/converters/__init__.py b/torch2trt/converters/__init__.py index 729d7ca1..1c779a25 100644 --- a/torch2trt/converters/__init__.py +++ b/torch2trt/converters/__init__.py @@ -4,7 +4,6 @@ # supported converters will override dummy converters -from .AdaptiveAvgPool2d import * from .BatchNorm1d import * from .BatchNorm2d import * from .clone import * From 3828d8253c6d8995fadd7f14b82fb096329b1280 Mon Sep 17 00:00:00 2001 From: Ibrahim Abedrabbo Date: Wed, 25 May 2022 18:50:08 +0300 Subject: [PATCH 7/8] add the option to pass an ONNX file path if there is one (instead of generating on the go) add the option to override the default int8 calibrator with a custom calibrator instance --- torch2trt/calibration.py | 2 +- torch2trt/torch2trt.py | 40 ++++++++++++++++++++++++++++------------ 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/torch2trt/calibration.py b/torch2trt/calibration.py index 7506ea4e..b09b7193 100644 --- a/torch2trt/calibration.py +++ b/torch2trt/calibration.py @@ -51,7 +51,7 @@ def get_batch(self, *args, **kwargs): buffer[i].copy_(tensor) self.count += 1 - + return [int(buf.data_ptr()) for buf in self.buffers] else: return [] diff --git a/torch2trt/torch2trt.py b/torch2trt/torch2trt.py index e2aba5d4..58d84672 100644 --- a/torch2trt/torch2trt.py +++ b/torch2trt/torch2trt.py @@ -501,10 +501,12 @@ def torch2trt(module, strict_type_constraints=False, keep_network=True, int8_mode=False, + int8_calibrator=None, int8_calib_dataset=None, int8_calib_algorithm=DEFAULT_CALIBRATION_ALGORITHM, int8_calib_batch_size=1, use_onnx=False, + onnx_file_path=None, **kwargs): # capture arguments to provide to context @@ -518,7 +520,7 @@ def torch2trt(module, logger = trt.Logger(log_level) builder = trt.Builder(logger) config = builder.create_builder_config() - + if isinstance(inputs, list): inputs = tuple(inputs) if not isinstance(inputs, tuple): @@ -535,14 +537,24 @@ def torch2trt(module, output_names = default_output_names(len(outputs)) if use_onnx: - - f = io.BytesIO() - torch.onnx.export(module, inputs, f, input_names=input_names, output_names=output_names) - f.seek(0) - onnx_bytes = f.read() network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) parser = trt.OnnxParser(network, logger) - parser.parse(onnx_bytes) + if onnx_file_path is not None: + print('\tBeginning ONNX file parsing.. path = ', onnx_file_path) + with open(onnx_file_path, 'rb') as onnx_model_file: + onnx_model = onnx_model_file.read() + if not parser.parse(onnx_model): + raise RuntimeError("Onnx model parsing from {} failed. Error: {}".format(onnx_model_file, parser.get_error(0).desc())) + else: + parser.parse(onnx_model) + print('\tEND ONNX file parsing.') + else: + f = io.BytesIO() + torch.onnx.export(module, inputs, f, input_names=input_names, output_names=output_names, opset_version=13) + f.seek(0) + onnx_bytes = f.read() + # parser = trt.OnnxParser(network, logger) + parser.parse(onnx_bytes) else: network = builder.create_network() @@ -567,15 +579,19 @@ def torch2trt(module, if int8_calib_dataset is None: int8_calib_dataset = TensorBatchDataset(inputs_in) + config.set_flag(trt.BuilderFlag.FP16) config.set_flag(trt.BuilderFlag.INT8) - + #Making sure not to run calibration with QAT mode on if not 'qat_mode' in kwargs: # @TODO(jwelsh): Should we set batch_size=max_batch_size? Need to investigate memory consumption - calibrator = DatasetCalibrator( - inputs, int8_calib_dataset, batch_size=int8_calib_batch_size, algorithm=int8_calib_algorithm - ) - config.int8_calibrator = calibrator + if int8_calibrator is None: + calibrator = DatasetCalibrator( + inputs, int8_calib_dataset, batch_size=int8_calib_batch_size, algorithm=int8_calib_algorithm + ) + config.int8_calibrator = calibrator + else: + config.int8_calibrator = int8_calibrator engine = builder.build_engine(network, config) From 09ba43c1316e5e3a70d194228d9dd6cbcb696a39 Mon Sep 17 00:00:00 2001 From: orilador Date: Fri, 16 Dec 2022 20:38:02 +0200 Subject: [PATCH 8/8] force input, output to be on cuda --- torch2trt/torch2trt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch2trt/torch2trt.py b/torch2trt/torch2trt.py index 58d84672..0cbd8f5e 100644 --- a/torch2trt/torch2trt.py +++ b/torch2trt/torch2trt.py @@ -406,7 +406,7 @@ def add_inputs(self, torch_inputs, names=None): shape=tuple(torch_input.shape)[1:], dtype=torch_dtype_to_trt(torch_input.dtype), ) - trt_tensor.location = torch_device_to_trt(torch_input.device) + trt_tensor.location = trt.TensorLocation.DEVICE torch_input._trt = trt_tensor def mark_outputs(self, torch_outputs, names=None): @@ -417,7 +417,7 @@ def mark_outputs(self, torch_outputs, names=None): for i, torch_output in enumerate(torch_outputs): trt_tensor = torch_output._trt trt_tensor.name = names[i] - trt_tensor.location = torch_device_to_trt(torch_output.device) + trt_tensor.location = trt.TensorLocation.DEVICE trt_tensor.dtype = torch_dtype_to_trt(torch_output.dtype) self.network.mark_output(trt_tensor)