From c85bd6224c46f23eaed45af2371a36a05f187e6a Mon Sep 17 00:00:00 2001 From: Bo Wang Date: Thu, 24 Oct 2024 00:45:32 +0000 Subject: [PATCH 1/6] feat: automatic plugin feature --- examples/dynamo/automatic_plugin/custom_op.py | 93 +++++++++++ .../conversion/plugin/plugin_generator.py | 151 ++++++++++++++++++ .../conversion/plugin_ops_converters.py | 47 ++++++ 3 files changed, 291 insertions(+) create mode 100644 examples/dynamo/automatic_plugin/custom_op.py create mode 100644 py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py create mode 100644 py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py diff --git a/examples/dynamo/automatic_plugin/custom_op.py b/examples/dynamo/automatic_plugin/custom_op.py new file mode 100644 index 0000000000..043c75d1e6 --- /dev/null +++ b/examples/dynamo/automatic_plugin/custom_op.py @@ -0,0 +1,93 @@ +import triton +import triton.language as tl + +@triton.jit +def elementwise_add_kernel(X, Y, Z, BLOCK_SIZE: tl.constexpr): + # Program ID determines the block of data each thread will process + pid = tl.program_id(0) + # Compute the range of elements that this thread block will work on + block_start = pid * BLOCK_SIZE + # Range of indices this thread will handle + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # Load elements from the X and Y tensors + x_vals = tl.load(X + offsets) + y_vals = tl.load(Y + offsets) + # Perform the element-wise multiplication + z_vals = x_vals + y_vals + # Store the result in Z + tl.store(Z + offsets, z_vals) + + +import torch +from torch.library import custom_op + + +@custom_op("torchtrt_ex::elementwise_add", mutates_args=()) # type: ignore[misc] +def elementwise_add(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor: + # Ensure the tensors are on the GPU + assert X.is_cuda and Y.is_cuda, "Tensors must be on CUDA device." + assert X.shape == Y.shape, "Tensors must have the same shape." 
+ + # Create output tensor + Z = torch.empty_like(X) + + # Define block size + BLOCK_SIZE = 1024 + + # Grid of programs + grid = lambda meta: (X.numel() // meta['BLOCK_SIZE'],) + + # Launch the kernel + elementwise_add_kernel[grid](X, Y, Z, BLOCK_SIZE=BLOCK_SIZE) + + return Z + + +# Using the module in PyTorch +# X = torch.randn(1024, device='cuda', requires_grad=True) +# Y = torch.randn(1024, device='cuda', requires_grad=True) +# X = torch.full((128, 128), 2, device='cuda',) +# Y = torch.full((128, 128), 2, device='cuda',) +# # elementwise_mul_op = ElementwiseMulModule() +# Z = torch.ops.torchtrt_ex.elementwise_add(X, Y) +# print(Z) +# print(X + Y) +# print(X) +# print(Y) +# print(Z) +# print(X+Y) +# Z.sum().backward() + + +from torch import nn + + +class MyModel(nn.Module): # type: ignore[misc] + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + z = torch.mul(x, y) + res = torch.ops.torchtrt_ex.elementwise_add(x, z) + + return res + + +my_model = MyModel().to("cuda") +m = torch.full((64, 64), 2, device='cuda',) +n = torch.full((64, 64), 3, device='cuda',) +# print(torch.ops.torchtrt_ex.elementwise_add(m, n)) +# print(my_model.forward(m, n)) + + +@torch.library.register_fake("torchtrt_ex::elementwise_add") +def _(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + +import torch_tensorrt as torchtrt + + +with torchtrt.logging.info(): + model_trt = torchtrt.compile(my_model, inputs=[m, n], debug=True, min_block_size=1) + res = model_trt(m, n) + print(res) \ No newline at end of file diff --git a/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py b/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py new file mode 100644 index 0000000000..9319b9d045 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py @@ -0,0 +1,151 @@ +import tensorrt as trt + + + +class CustomPlugin(trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild, trt.IPluginV3OneRuntime): # type: ignore[misc] + def __init__( + self, plugin_name : str, fc = None, phase = None + ): + # TODO: needs an additional passed in arguments to specify the needs for each plugin + # such as the one here: https://github.com/NVIDIA/TensorRT/blob/40efe7e9f2492657bbc455c4e2876e2ec792b812/samples/python/python_plugin/circ_pad_plugin_multi_tactic.py#L83 + trt.IPluginV3.__init__(self) + trt.IPluginV3OneCore.__init__(self) + trt.IPluginV3OneBuild.__init__(self) + trt.IPluginV3OneRuntime.__init__(self) + + # + # setattr(, ) + # self.pads = [] + # self.X_shape: List[int] = [] + + self.num_outputs = 1 # Defined by schema + self.plugin_namespace = "" + self.plugin_name = plugin_name + self.plugin_version = "1" + + # + # ex. 
+ # TODO: need to parse the field collection here + # if fc is not None: + # assert fc[0].name == "pads" + # self.pads = fc[0].data + + if phase is not None: + self.phase = phase + + def get_capability_interface(self, type): + return self + + def get_output_datatypes( + self, input_types: List[trt.DataType] + ) -> trt.DataType: + # WE CAN USE THE FAKE TENSOR IMPLEMENTATION TO FIGURE OUT THE EXPECTED OUTPUT DATA TYPE + # with torch.fake_tensor(): + # + # fake_outputs = torch.ops..(*fake_inputs) + + # return fake_outputs[index] + + # The example case here is simple for experiment + return [input_types[0]] + + def get_output_shapes( + self, + output_index: int, + inputs: List[trt.DimsExprs], + exprBuilder: trt.IExprBuilder, + ) -> trt.DimsExprs: + + + # WE NEED TO FIND A WAY TO GO FROM FAKE TENSOR IMPL TO CONSTRUCTING A DIMSEXPR + # THIS IS SOLVED IN SHAPE PROP IN PYTORCH WHERE SHAPE PROP CAN GIVE SYMINTS THAT ENCODE THE + # SHAPE MAP. + output_shape = trt.DimsExprs(inputs[0]) + + return [output_shape] + + def get_fields_to_serialize(self): + # should be passed in as another argument + return trt.PluginFieldCollection([ + trt.PluginField("pads", self.pads, trt.PluginFieldType.INT32) + ]) + + def configure_plugin(self, inp, out): + pass + + def on_shape_change(self, inp, out): + X_dims = inp[0].dims + self.X_shape = np.zeros((len(X_dims),)) + for i in range(len(X_dims)): + self.X_shape[i] = X_dims[i] + + def supports_format_combination(self, pos, in_out, num_inputs): + assert num_inputs == 1 + assert pos < len(in_out) + + desc = in_out[pos].desc + if desc.format != trt.TensorFormat.LINEAR: + return False + + # first input should be float16 or float32 + if pos == 0: + return desc.type == trt.DataType.FLOAT or desc.type == trt.DataType.HALF + + # output should have the same type as the input + if pos == 1: + return in_out[0].desc.type == desc.type + + assert False + + + def enqueue( + self, + input_desc: List[trt.PluginTensorDesc], + output_desc: List[trt.PluginTensorDesc], + inputs: List[int], + outputs: List[int], + workspace: int, + stream: int, + ) -> None: + ... + + def attach_to_context(self, context): + return self.clone() + + def get_valid_tactics(self): + return [int(Tactic.TORCH), int(Tactic.TRITON)] + + def set_tactic(self, tactic): + self.tactic = Tactic(tactic) + + if self.phase == trt.TensorRTPhase.RUNTIME: + logger.info(f"Best tactic chosen: {self.tactic}") + + def clone(self) -> Self: + # + + +class PluginCreator(trt.IPluginCreatorV3One): # type: ignore[misc] + def __init__(self, plugin_name : str, plugin_field_names : trt.PluginFieldCollection): + super().__init__() + + self.name = plugin_name + self.plugin_namespace = "" + self.plugin_version = "1" + self.field_names = plugin_field_names + + def create_plugin( + self, name: str, field_collection: trt.PluginFieldCollection_ + ) -> CustomPlugin: + return CustomPlugin(field_collection) + + +# Looks like deserilaize required? 
Not found in the example here: https://github.com/NVIDIA/TensorRT/blob/main/samples/python/python_plugin/circ_pad_plugin_multi_tactic.py + # def deserialize_plugin(self, name: str, data: bytes) -> CircularPaddingPlugin: + # dict = pkl.loads(data) + # deserialized = () + # deserialized.__dict__.update(dict) + # return deserialized + +TRT_PLUGIN_REGISTRY = trt.get_plugin_registry() +TRT_PLUGIN_REGISTRY.register_creator(CircularPaddingPluginCreator(), "") \ No newline at end of file diff --git a/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py new file mode 100644 index 0000000000..b9136d492d --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py @@ -0,0 +1,47 @@ +import logging +from typing import Dict, Sequence, Tuple, Union + +import torch +from torch.fx.node import Argument, Target +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion import impl +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( + dynamo_tensorrt_converter, +) +from torch_tensorrt.fx.types import TRTTensor + +logger = logging.getLogger(__name__) + +@dynamo_tensorrt_converter(torch.ops.torchtrt_ex.elementwise_add.default) +def torchtrt_ex_elementwise_add( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +): + logger.debug(f"plugin stuff here2") + return torch.add(args) + + # How to retrieve a plugin if it is defined elsewhere (e.g. linked library) + # plugin_registry = trt.get_plugin_registry() + # plugin_creator = plugin_registry.get_plugin_creator( + # type="", version="1", plugin_namespace="" + # ) + # assert plugin_creator, f"Unable to find creator" + + # # Pass configurations to the plugin implementation + # field_configs = + # plugin = plugin_creator.create_plugin(name=name, field_collection=field_configs) + # assert plugin, "Unable to create " + + # + # + # + + # return layer.get_output(0) + + +# 1. generate plugin for any pytorch op \ No newline at end of file From eae249954e4ab4af296550aa5e834c7361cbe373 Mon Sep 17 00:00:00 2001 From: Bo Wang Date: Sat, 2 Nov 2024 02:16:37 +0000 Subject: [PATCH 2/6] update --- .../dynamo/conversion/__init__.py | 2 +- .../conversion/plugin/plugin_generator.py | 131 +++++++++++++++--- 2 files changed, 116 insertions(+), 17 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/__init__.py b/py/torch_tensorrt/dynamo/conversion/__init__.py index 5351f02bb6..235d1456b0 100644 --- a/py/torch_tensorrt/dynamo/conversion/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/__init__.py @@ -1,4 +1,4 @@ -from . import aten_ops_converters, ops_evaluators, prims_ops_converters +from . 
import aten_ops_converters, ops_evaluators, prims_ops_converters, plugin_ops_converters from ._conversion import convert_module, interpret_module_to_result from ._ConversionContext import ConversionContext from ._ConverterRegistry import * # noqa: F403 diff --git a/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py b/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py index 9319b9d045..21468ca9be 100644 --- a/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py +++ b/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py @@ -1,19 +1,55 @@ import tensorrt as trt +import cupy as cp +import torch +import numpy as np +import logging +from enum import IntEnum + +logger = logging.getLogger("CustomPlugin") + + + +_numpy_to_plugin_field_type = { + np.dtype('int32'): trt.PluginFieldType.INT32, + np.dtype('int16'): trt.PluginFieldType.INT16, + np.dtype('int8'): trt.PluginFieldType.INT8, + np.dtype('bool'): trt.PluginFieldType.INT8, + np.dtype('int64'): trt.PluginFieldType.INT64, + np.dtype('float32'): trt.PluginFieldType.FLOAT32, + np.dtype('float64'): trt.PluginFieldType.FLOAT64, + np.dtype('float16'): trt.PluginFieldType.FLOAT16 +} + + +_built_in_to_plugin_field_type = { + int: trt.PluginFieldType.INT64, + float: trt.PluginFieldType.FLOAT64, + bool: trt.PluginFieldType.INT8, + # str is handled separately, so not needed here +} + +class Tactic(IntEnum): + TORCH = 1 + TRITON = 2 + class CustomPlugin(trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild, trt.IPluginV3OneRuntime): # type: ignore[misc] def __init__( - self, plugin_name : str, fc = None, phase = None + self, plugin_name : str, attrs, phase = None ): # TODO: needs an additional passed in arguments to specify the needs for each plugin # such as the one here: https://github.com/NVIDIA/TensorRT/blob/40efe7e9f2492657bbc455c4e2876e2ec792b812/samples/python/python_plugin/circ_pad_plugin_multi_tactic.py#L83 trt.IPluginV3.__init__(self) + # Core capability, plugin attributes and behaviors common to both the build and runtime phases of a plugin’s lifetime trt.IPluginV3OneCore.__init__(self) + # Build capability, plugin attributes and behaviors that the plugin must exhibit for the TensorRT builder. trt.IPluginV3OneBuild.__init__(self) + # Runtime capability, plugin attributes and behaviors that the plugin must exhibit for it to be executable trt.IPluginV3OneRuntime.__init__(self) - # + # # setattr(, ) # self.pads = [] # self.X_shape: List[int] = [] @@ -21,7 +57,14 @@ def __init__( self.num_outputs = 1 # Defined by schema self.plugin_namespace = "" self.plugin_name = plugin_name - self.plugin_version = "1" + self.plugin_version = "1" + + # Set the timing cache ID to prevent unnecessary timing of second plugin instance + self.timing_cache_id = "" + + self.attrs = attrs + + self.tactic = None # # ex. 
@@ -66,18 +109,44 @@ def get_output_shapes( def get_fields_to_serialize(self): # should be passed in as another argument - return trt.PluginFieldCollection([ - trt.PluginField("pads", self.pads, trt.PluginFieldType.INT32) - ]) + field_names = [] + + for key, value in self.attrs.items(): + if isinstance(value, np.ndarray): + field_names.append( + trt.PluginField( + key, + value, + _numpy_to_plugin_field_type[np.dtype(value.dtype)], + ) + ) + elif isinstance(value, str): + field_names.append( + trt.PluginField(key, value.encode(), trt.PluginFieldType.CHAR) + ) + elif isinstance(value, bytes): + field_names.append( + trt.PluginField(key, value, trt.PluginFieldType.UNKNOWN) + ) + else: + field_names.append( + trt.PluginField( + key, + np.array([value]), + _built_in_to_plugin_field_type[type(value)], + ) + ) + + return trt.PluginFieldCollection(field_names) def configure_plugin(self, inp, out): pass - def on_shape_change(self, inp, out): - X_dims = inp[0].dims - self.X_shape = np.zeros((len(X_dims),)) - for i in range(len(X_dims)): - self.X_shape[i] = X_dims[i] + # def on_shape_change(self, inp, out): + # X_dims = inp[0].dims + # self.X_shape = np.zeros((len(X_dims),)) + # for i in range(len(X_dims)): + # self.X_shape[i] = X_dims[i] def supports_format_combination(self, pos, in_out, num_inputs): assert num_inputs == 1 @@ -102,12 +171,40 @@ def enqueue( self, input_desc: List[trt.PluginTensorDesc], output_desc: List[trt.PluginTensorDesc], - inputs: List[int], - outputs: List[int], + inputs, + outputs, workspace: int, stream: int, ) -> None: - ... + # input and output memory handling + input_mems = [None] * (len(inputs)) + + for i in range(len(inputs)): + input_mems[i] = cp.cuda.UnownedMemory(inputs[i], np.prod(input_desc[i].dims) * cp.dtype(trt.nptype(input_desc[i].type)).itemsize, self) + + output_mems = [None] * (len(outputs)) + + for i in range(len(outputs)): + output_mems[i] = cp.cuda.UnownedMemory(outputs[i], np.prod(output_desc[i].dims) * cp.dtype(trt.nptype(output_desc[i].type)).itemsize, self) + + + input_data = [None] * ((len(inputs))) + for i in range(len(inputs)): + input_data[i] = cp.ndarray(tuple(input_desc[i].dims), dtype=input_desc[i].type, memptr = cp.cuda.MemoryPointer(input_mems[i], 0)) + + output_data = [None] * ((len(outputs))) + for i in range(len(outputs)): + output_data[i] = cp.ndarray((np.prod(output_desc[i].dims)), dtype = output_desc[i].type, memptr = cp.cuda.MemoryPointer(output_mems[i], 0)) + + #TODO: This is just for a simple case for elementwise operations + # using Torch implementation for now + input_torch_0 = torch.as_tensor(input_data[0], device='cuda') + input_torch_1 = torch.as_tensor(input_data[1], device='cuda') + + output = torch.add(input_torch_0, input_torch_1) + + cp.copyto(output_data, output) + def attach_to_context(self, context): return self.clone() @@ -121,8 +218,10 @@ def set_tactic(self, tactic): if self.phase == trt.TensorRTPhase.RUNTIME: logger.info(f"Best tactic chosen: {self.tactic}") - def clone(self) -> Self: - # + def clone(self): + cloned_plugin = CustomPlugin(self.plugin_name, self.attrs) + cloned_plugin.__dict__.update(self.__dict__) + return cloned_plugin class PluginCreator(trt.IPluginCreatorV3One): # type: ignore[misc] From 702c149cf048c4014455e5e84a355efd7a3ceccf Mon Sep 17 00:00:00 2001 From: Bo Wang Date: Wed, 6 Nov 2024 01:12:57 +0000 Subject: [PATCH 3/6] update --- .../dynamo/conversion/plugin/__init__.py | 1 + .../conversion/plugin/plugin_generator.py | 56 ++++++++++++++++--- .../conversion/plugin_ops_converters.py | 44 
++++++++++++--- 3 files changed, 85 insertions(+), 16 deletions(-) create mode 100644 py/torch_tensorrt/dynamo/conversion/plugin/__init__.py diff --git a/py/torch_tensorrt/dynamo/conversion/plugin/__init__.py b/py/torch_tensorrt/dynamo/conversion/plugin/__init__.py new file mode 100644 index 0000000000..016c091425 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/plugin/__init__.py @@ -0,0 +1 @@ +from .plugin_generator import PluginCreator \ No newline at end of file diff --git a/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py b/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py index 21468ca9be..4cd4c6f0c5 100644 --- a/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py +++ b/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py @@ -10,8 +10,6 @@ logger = logging.getLogger("CustomPlugin") - - _numpy_to_plugin_field_type = { np.dtype('int32'): trt.PluginFieldType.INT32, np.dtype('int16'): trt.PluginFieldType.INT16, @@ -23,7 +21,6 @@ np.dtype('float16'): trt.PluginFieldType.FLOAT16 } - _built_in_to_plugin_field_type = { int: trt.PluginFieldType.INT64, float: trt.PluginFieldType.FLOAT64, @@ -225,18 +222,59 @@ def clone(self): class PluginCreator(trt.IPluginCreatorV3One): # type: ignore[misc] - def __init__(self, plugin_name : str, plugin_field_names : trt.PluginFieldCollection): - super().__init__() + def __init__(self, plugin_name : str, plugin_namespace : str, attrs): + trt.IPluginCreatorV3One.__init__(self) self.name = plugin_name - self.plugin_namespace = "" + self.plugin_namespace = plugin_namespace self.plugin_version = "1" - self.field_names = plugin_field_names + + field_names = [] + for name, (builtin, type_) in attrs.items(): + if builtin: + if type_ is str: + field_names.append( + trt.PluginField(name, b"", trt.PluginFieldType.CHAR) + ) + elif type_ is bytes: + field_names.append( + trt.PluginField(name, b"", trt.PluginFieldType.UNKNOWN) + ) + else: + field_names.append( + trt.PluginField( + name, np.array([]), _built_in_to_plugin_field_type[type_] + ) + ) + else: + field_names.append( + trt.PluginField( + name, np.array([]), _numpy_to_plugin_field_type[np.dtype(type_)] + ) + ) + + self.field_names = trt.PluginFieldCollection(field_names) def create_plugin( - self, name: str, field_collection: trt.PluginFieldCollection_ + self, name: str, fc, phase ) -> CustomPlugin: - return CustomPlugin(field_collection) + + + attrs = {} + # for f in fc: + # if f.name not in desc.input_attrs: + # raise AssertionError( + # f"Unexpected attribute {f.name} provided to create_plugin. Expected one of {desc.input_attrs.keys()}." + # ) + + # if _is_numpy_array(desc.input_attrs[f.name]): + # attrs[f.name] = f.data.astype(_infer_numpy_type(desc.input_attrs[f.name])) + # else: + # attrs[f.name] = desc.input_attrs[f.name](f.data) + + custom_plugin = CustomPlugin(name, attrs, fc) + + return custom_plugin # Looks like deserilaize required? 
Not found in the example here: https://github.com/NVIDIA/TensorRT/blob/main/samples/python/python_plugin/circ_pad_plugin_multi_tactic.py diff --git a/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py index b9136d492d..7aa8b4b5d1 100644 --- a/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py @@ -11,9 +11,14 @@ dynamo_tensorrt_converter, ) from torch_tensorrt.fx.types import TRTTensor +from plugin import PluginCreator +import tensorrt as trt +from converter_utils import get_trt_tensor logger = logging.getLogger(__name__) +TRT_PLUGIN_REGISTRY = trt.get_plugin_registry() + @dynamo_tensorrt_converter(torch.ops.torchtrt_ex.elementwise_add.default) def torchtrt_ex_elementwise_add( ctx: ConversionContext, @@ -22,15 +27,17 @@ def torchtrt_ex_elementwise_add( kwargs: Dict[str, Argument], name: str, ): - logger.debug(f"plugin stuff here2") - return torch.add(args) + # logger.debug(f"plugin stuff here2") + # return torch.add(args) # How to retrieve a plugin if it is defined elsewhere (e.g. linked library) - # plugin_registry = trt.get_plugin_registry() - # plugin_creator = plugin_registry.get_plugin_creator( - # type="", version="1", plugin_namespace="" - # ) - # assert plugin_creator, f"Unable to find creator" + plugin_creator = PluginCreator("elementwise_add_plugin") + TRT_PLUGIN_REGISTRY.register_creator(plugin_creator, "") + + plugin_creator = TRT_PLUGIN_REGISTRY.get_plugin_creator( + type=plugin_creator, version="1", plugin_namespace="" + ) + assert plugin_creator, f"Unable to find elementwise_add_plugin creator" # # Pass configurations to the plugin implementation # field_configs = @@ -42,6 +49,29 @@ def torchtrt_ex_elementwise_add( # # return layer.get_output(0) + field_configs = trt.PluginFieldCollection([]) + + plugin = plugin_creator.create_plugin(name=name, field_collection=field_configs) + assert plugin, "Unable to create CircularPaddingPlugin" + + # input_tensor = args[ + # 0 + # ] # Arg 0 `torch.ops.torchtrt_ex.triton_circular_pad` is the input tensor + # if not isinstance(input_tensor, trt.ITensor): + # # Freeze input tensor if not TensorRT Tensor already + # input_tensor = get_trt_tensor(ctx, input_tensor, f"{name}_input") + + lhs_dtype = None + rhs_dtype = None + + lhs_val = get_trt_tensor(ctx, lhs_val, f"{name}_lhs", lhs_dtype) + rhs_val = get_trt_tensor(ctx, rhs_val, f"{name}_rhs", rhs_dtype) + + layer = ctx.net.add_plugin_v2( + [lhs_val, rhs_val], plugin + ) # Add the plugin to the network being constructed + layer.name = f"automatic-{name}" + return layer.get_output(0) # 1. 
generate plugin for any pytorch op \ No newline at end of file From ec1d50396003f9d5ef11d9e662f30f04d83286ad Mon Sep 17 00:00:00 2001 From: Bo Wang Date: Sat, 9 Nov 2024 01:55:48 +0000 Subject: [PATCH 4/6] update --- .../conversion/plugin/plugin_generator.py | 38 +++++++++++-------- .../conversion/plugin_ops_converters.py | 16 ++++---- 2 files changed, 31 insertions(+), 23 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py b/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py index 4cd4c6f0c5..9cfbd6b1b9 100644 --- a/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py +++ b/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py @@ -7,6 +7,8 @@ from enum import IntEnum +from typing import List + logger = logging.getLogger("CustomPlugin") @@ -62,6 +64,7 @@ def __init__( self.attrs = attrs self.tactic = None + # # ex. @@ -76,7 +79,7 @@ def __init__( def get_capability_interface(self, type): return self - def get_output_datatypes( + def get_output_data_types( self, input_types: List[trt.DataType] ) -> trt.DataType: # WE CAN USE THE FAKE TENSOR IMPLEMENTATION TO FIGURE OUT THE EXPECTED OUTPUT DATA TYPE @@ -91,18 +94,19 @@ def get_output_datatypes( def get_output_shapes( self, - output_index: int, inputs: List[trt.DimsExprs], + shape_inputs, exprBuilder: trt.IExprBuilder, ) -> trt.DimsExprs: + print(inputs) # WE NEED TO FIND A WAY TO GO FROM FAKE TENSOR IMPL TO CONSTRUCTING A DIMSEXPR # THIS IS SOLVED IN SHAPE PROP IN PYTORCH WHERE SHAPE PROP CAN GIVE SYMINTS THAT ENCODE THE # SHAPE MAP. - output_shape = trt.DimsExprs(inputs[0]) + output_dims = trt.DimsExprs(inputs[0]) - return [output_shape] + return [output_dims] def get_fields_to_serialize(self): # should be passed in as another argument @@ -139,13 +143,15 @@ def get_fields_to_serialize(self): def configure_plugin(self, inp, out): pass - # def on_shape_change(self, inp, out): - # X_dims = inp[0].dims - # self.X_shape = np.zeros((len(X_dims),)) - # for i in range(len(X_dims)): - # self.X_shape[i] = X_dims[i] + def on_shape_change(self, inp, out): + return + X_dims = inp[0].dims + self.X_shape = np.zeros((len(X_dims),)) + for i in range(len(X_dims)): + self.X_shape[i] = X_dims[i] def supports_format_combination(self, pos, in_out, num_inputs): + return assert num_inputs == 1 assert pos < len(in_out) @@ -198,7 +204,7 @@ def enqueue( input_torch_0 = torch.as_tensor(input_data[0], device='cuda') input_torch_1 = torch.as_tensor(input_data[1], device='cuda') - output = torch.add(input_torch_0, input_torch_1) + output = torch.ops.torchtrt_ex.elementwise_add(input_torch_0, input_torch_1) cp.copyto(output_data, output) @@ -212,8 +218,8 @@ def get_valid_tactics(self): def set_tactic(self, tactic): self.tactic = Tactic(tactic) - if self.phase == trt.TensorRTPhase.RUNTIME: - logger.info(f"Best tactic chosen: {self.tactic}") + # if self.phase == trt.TensorRTPhase.RUNTIME: + # logger.info(f"Best tactic chosen: {self.tactic}") def clone(self): cloned_plugin = CustomPlugin(self.plugin_name, self.attrs) @@ -256,7 +262,7 @@ def __init__(self, plugin_name : str, plugin_namespace : str, attrs): self.field_names = trt.PluginFieldCollection(field_names) def create_plugin( - self, name: str, fc, phase + self, name: str, field_collection, phase=None ) -> CustomPlugin: @@ -272,7 +278,7 @@ def create_plugin( # else: # attrs[f.name] = desc.input_attrs[f.name](f.data) - custom_plugin = CustomPlugin(name, attrs, fc) + custom_plugin = CustomPlugin(name, attrs) return custom_plugin @@ -284,5 +290,5 @@ def 
create_plugin( # deserialized.__dict__.update(dict) # return deserialized -TRT_PLUGIN_REGISTRY = trt.get_plugin_registry() -TRT_PLUGIN_REGISTRY.register_creator(CircularPaddingPluginCreator(), "") \ No newline at end of file +# TRT_PLUGIN_REGISTRY = trt.get_plugin_registry() +# TRT_PLUGIN_REGISTRY.register_creator(CircularPaddingPluginCreator(), "") \ No newline at end of file diff --git a/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py index 7aa8b4b5d1..f87c129265 100644 --- a/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py @@ -11,9 +11,9 @@ dynamo_tensorrt_converter, ) from torch_tensorrt.fx.types import TRTTensor -from plugin import PluginCreator +from torch_tensorrt.dynamo.conversion.plugin import PluginCreator import tensorrt as trt -from converter_utils import get_trt_tensor +from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor logger = logging.getLogger(__name__) @@ -31,11 +31,11 @@ def torchtrt_ex_elementwise_add( # return torch.add(args) # How to retrieve a plugin if it is defined elsewhere (e.g. linked library) - plugin_creator = PluginCreator("elementwise_add_plugin") + plugin_creator = PluginCreator("elementwise_add_plugin", plugin_namespace="", attrs={}) TRT_PLUGIN_REGISTRY.register_creator(plugin_creator, "") plugin_creator = TRT_PLUGIN_REGISTRY.get_plugin_creator( - type=plugin_creator, version="1", plugin_namespace="" + type="elementwise_add_plugin", version="1", plugin_namespace="" ) assert plugin_creator, f"Unable to find elementwise_add_plugin creator" @@ -63,14 +63,16 @@ def torchtrt_ex_elementwise_add( lhs_dtype = None rhs_dtype = None + lhs_val = args[0] + rhs_val = args[1] lhs_val = get_trt_tensor(ctx, lhs_val, f"{name}_lhs", lhs_dtype) rhs_val = get_trt_tensor(ctx, rhs_val, f"{name}_rhs", rhs_dtype) - layer = ctx.net.add_plugin_v2( - [lhs_val, rhs_val], plugin + layer = ctx.net.add_plugin_v3( + [lhs_val, rhs_val], [], plugin ) # Add the plugin to the network being constructed - layer.name = f"automatic-{name}" + # layer.name = f"automatic-{name}" return layer.get_output(0) From 764acad121cbda825d66a1744c7a627aef0a7253 Mon Sep 17 00:00:00 2001 From: Bo Wang Date: Wed, 13 Nov 2024 02:16:11 +0000 Subject: [PATCH 5/6] support first example --- .../dynamo/conversion/plugin_ops_converters.py | 11 ++++++++++- setup.py | 2 ++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py index f87c129265..4b8f4b4311 100644 --- a/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py @@ -51,7 +51,7 @@ def torchtrt_ex_elementwise_add( # return layer.get_output(0) field_configs = trt.PluginFieldCollection([]) - plugin = plugin_creator.create_plugin(name=name, field_collection=field_configs) + plugin = plugin_creator.create_plugin(name="elementwise_add_plugin", field_collection=field_configs) assert plugin, "Unable to create CircularPaddingPlugin" # input_tensor = args[ @@ -66,6 +66,15 @@ def torchtrt_ex_elementwise_add( lhs_val = args[0] rhs_val = args[1] + if isinstance(lhs_val, TRTTensor): + lhs_dtype = lhs_val.dtype + # is_lhs_trt_tensor = True + if isinstance(rhs_val, TRTTensor): + rhs_dtype = rhs_val.dtype + # is_rhs_trt_tensor = True + + print(lhs_dtype) + lhs_val = get_trt_tensor(ctx, lhs_val, 
f"{name}_lhs", lhs_dtype) rhs_val = get_trt_tensor(ctx, rhs_val, f"{name}_rhs", rhs_dtype) diff --git a/setup.py b/setup.py index 0b8f47fb6f..bd490ec1be 100644 --- a/setup.py +++ b/setup.py @@ -440,6 +440,7 @@ def run(self): "torch_tensorrt.dynamo.conversion.impl.normalization", "torch_tensorrt.dynamo.conversion.impl.slice", "torch_tensorrt.dynamo.conversion.impl.unary", + "torch_tensorrt.dynamo.conversion.plugin", "torch_tensorrt.dynamo.lowering", "torch_tensorrt.dynamo.lowering.passes", "torch_tensorrt.dynamo.partitioning", @@ -468,6 +469,7 @@ def run(self): "torch_tensorrt.dynamo.conversion.impl.normalization": "py/torch_tensorrt/dynamo/conversion/impl/normalization", "torch_tensorrt.dynamo.conversion.impl.slice": "py/torch_tensorrt/dynamo/conversion/impl/slice", "torch_tensorrt.dynamo.conversion.impl.unary": "py/torch_tensorrt/dynamo/conversion/impl/unary", + "torch_tensorrt.dynamo.conversion.plugin": "py/torch_tensorrt/dynamo/conversion/plugin", "torch_tensorrt.dynamo.lowering": "py/torch_tensorrt/dynamo/lowering", "torch_tensorrt.dynamo.lowering.passes": "py/torch_tensorrt/dynamo/lowering/passes", "torch_tensorrt.dynamo.partitioning": "py/torch_tensorrt/dynamo/partitioning", From f0b0a0f0191c487664c015387dba2fa183d7ec49 Mon Sep 17 00:00:00 2001 From: Bo Wang Date: Fri, 22 Nov 2024 01:17:40 +0000 Subject: [PATCH 6/6] remove some comments --- .../dynamo/conversion/plugin/plugin_generator.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py b/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py index 9cfbd6b1b9..56efe1c714 100644 --- a/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py +++ b/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py @@ -282,13 +282,3 @@ def create_plugin( return custom_plugin - -# Looks like deserilaize required? Not found in the example here: https://github.com/NVIDIA/TensorRT/blob/main/samples/python/python_plugin/circ_pad_plugin_multi_tactic.py - # def deserialize_plugin(self, name: str, data: bytes) -> CircularPaddingPlugin: - # dict = pkl.loads(data) - # deserialized = () - # deserialized.__dict__.update(dict) - # return deserialized - -# TRT_PLUGIN_REGISTRY = trt.get_plugin_registry() -# TRT_PLUGIN_REGISTRY.register_creator(CircularPaddingPluginCreator(), "") \ No newline at end of file