diff --git a/examples/dynamo/automatic_plugin/custom_op.py b/examples/dynamo/automatic_plugin/custom_op.py
new file mode 100644
index 0000000000..043c75d1e6
--- /dev/null
+++ b/examples/dynamo/automatic_plugin/custom_op.py
@@ -0,0 +1,93 @@
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def elementwise_add_kernel(X, Y, Z, BLOCK_SIZE: tl.constexpr):
+    # Program ID determines the block of data each program instance will process
+    pid = tl.program_id(0)
+    # Start of the block of elements this program instance will work on
+    block_start = pid * BLOCK_SIZE
+    # Range of indices this program instance will handle
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    # Load elements from the X and Y tensors
+    x_vals = tl.load(X + offsets)
+    y_vals = tl.load(Y + offsets)
+    # Perform the element-wise addition
+    z_vals = x_vals + y_vals
+    # Store the result in Z
+    tl.store(Z + offsets, z_vals)
+
+
+import torch
+from torch.library import custom_op
+
+
+@custom_op("torchtrt_ex::elementwise_add", mutates_args=())  # type: ignore[misc]
+def elementwise_add(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
+    # Ensure the tensors are on the GPU
+    assert X.is_cuda and Y.is_cuda, "Tensors must be on CUDA device."
+    assert X.shape == Y.shape, "Tensors must have the same shape."
+
+    # Create output tensor
+    Z = torch.empty_like(X)
+
+    # Define block size (the kernel does not mask loads, so X.numel() is assumed
+    # to be a multiple of BLOCK_SIZE)
+    BLOCK_SIZE = 1024
+
+    # Grid of programs
+    grid = lambda meta: (triton.cdiv(X.numel(), meta["BLOCK_SIZE"]),)
+
+    # Launch the kernel
+    elementwise_add_kernel[grid](X, Y, Z, BLOCK_SIZE=BLOCK_SIZE)
+
+    return Z
+
+
+# Using the op eagerly in PyTorch:
+# X = torch.full((128, 128), 2, device="cuda", dtype=torch.float)
+# Y = torch.full((128, 128), 2, device="cuda", dtype=torch.float)
+# Z = torch.ops.torchtrt_ex.elementwise_add(X, Y)
+# print(Z)
+# print(X + Y)
+
+
+from torch import nn
+
+
+class MyModel(nn.Module):  # type: ignore[misc]
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        z = torch.mul(x, y)
+        res = torch.ops.torchtrt_ex.elementwise_add(x, z)
+
+        return res
+
+
+my_model = MyModel().to("cuda")
+m = torch.full((64, 64), 2, device="cuda", dtype=torch.float)
+n = torch.full((64, 64), 3, device="cuda", dtype=torch.float)
+# print(torch.ops.torchtrt_ex.elementwise_add(m, n))
+# print(my_model.forward(m, n))
+
+
+# Register a fake (meta) implementation so the op can be traced with FakeTensors;
+# the output has the same shape and dtype as the first input
+@torch.library.register_fake("torchtrt_ex::elementwise_add")
+def _(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    return x
+
+
+import torch_tensorrt as torchtrt
+
+
+with torchtrt.logging.info():
+    model_trt = torchtrt.compile(my_model, inputs=[m, n], debug=True, min_block_size=1)
+    res = model_trt(m, n)
+    print(res)
\ No newline at end of file
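A quick sanity check one might run before compiling with TensorRT (a sketch, not part of the patch; it assumes the definitions above, i.e. the torchtrt_ex::elementwise_add op and its fake implementation, have already been executed, and torch.library.opcheck requires a recent PyTorch, 2.4+):

    import torch

    x = torch.full((64, 64), 2, device="cuda", dtype=torch.float)
    y = torch.full((64, 64), 3, device="cuda", dtype=torch.float)

    # The eager result should match plain addition
    assert torch.equal(torch.ops.torchtrt_ex.elementwise_add(x, y), x + y)

    # opcheck validates the custom-op registration, including the fake/meta kernel
    torch.library.opcheck(torch.ops.torchtrt_ex.elementwise_add.default, (x, y))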
diff --git a/py/torch_tensorrt/dynamo/conversion/__init__.py b/py/torch_tensorrt/dynamo/conversion/__init__.py
index 5351f02bb6..235d1456b0 100644
--- a/py/torch_tensorrt/dynamo/conversion/__init__.py
+++ b/py/torch_tensorrt/dynamo/conversion/__init__.py
@@ -1,4 +1,4 @@
-from . import aten_ops_converters, ops_evaluators, prims_ops_converters
+from . import aten_ops_converters, ops_evaluators, plugin_ops_converters, prims_ops_converters
 from ._conversion import convert_module, interpret_module_to_result
 from ._ConversionContext import ConversionContext
 from ._ConverterRegistry import *  # noqa: F403
diff --git a/py/torch_tensorrt/dynamo/conversion/plugin/__init__.py b/py/torch_tensorrt/dynamo/conversion/plugin/__init__.py
new file mode 100644
index 0000000000..016c091425
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/conversion/plugin/__init__.py
@@ -0,0 +1 @@
+from .plugin_generator import PluginCreator
\ No newline at end of file
diff --git a/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py b/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py
new file mode 100644
index 0000000000..56efe1c714
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py
@@ -0,0 +1,284 @@
+import logging
+from enum import IntEnum
+from typing import List
+
+import cupy as cp
+import numpy as np
+import tensorrt as trt
+import torch
+
+logger = logging.getLogger("CustomPlugin")
+
+_numpy_to_plugin_field_type = {
+    np.dtype("int32"): trt.PluginFieldType.INT32,
+    np.dtype("int16"): trt.PluginFieldType.INT16,
+    np.dtype("int8"): trt.PluginFieldType.INT8,
+    np.dtype("bool"): trt.PluginFieldType.INT8,
+    np.dtype("int64"): trt.PluginFieldType.INT64,
+    np.dtype("float32"): trt.PluginFieldType.FLOAT32,
+    np.dtype("float64"): trt.PluginFieldType.FLOAT64,
+    np.dtype("float16"): trt.PluginFieldType.FLOAT16,
+}
+
+_built_in_to_plugin_field_type = {
+    int: trt.PluginFieldType.INT64,
+    float: trt.PluginFieldType.FLOAT64,
+    bool: trt.PluginFieldType.INT8,
+    # str is handled separately, so it is not needed here
+}
+
+
+class Tactic(IntEnum):
+    TORCH = 1
+    TRITON = 2
+
+
+class CustomPlugin(
+    trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild, trt.IPluginV3OneRuntime
+):  # type: ignore[misc]
+    def __init__(self, plugin_name: str, attrs, phase=None):
+        # TODO: accept an additional argument that describes what each plugin needs,
+        # similar to https://github.com/NVIDIA/TensorRT/blob/40efe7e9f2492657bbc455c4e2876e2ec792b812/samples/python/python_plugin/circ_pad_plugin_multi_tactic.py#L83
+        trt.IPluginV3.__init__(self)
+        # Core capability: plugin attributes and behaviors common to both the build and runtime phases of a plugin's lifetime
+        trt.IPluginV3OneCore.__init__(self)
+        # Build capability: plugin attributes and behaviors that the plugin must exhibit for the TensorRT builder
+        trt.IPluginV3OneBuild.__init__(self)
+        # Runtime capability: plugin attributes and behaviors that the plugin must exhibit for it to be executable
+        trt.IPluginV3OneRuntime.__init__(self)
+
+        # TODO: set the plugin attributes as instance variables, e.g.
+        # setattr(self, <attr_name>, <attr_value>)
+        # self.pads = []
+        # self.X_shape: List[int] = []
+
+        self.num_outputs = 1  # Defined by schema
+        self.plugin_namespace = ""
+        self.plugin_name = plugin_name
+        self.plugin_version = "1"
+
+        # Set the timing cache ID to prevent unnecessary timing of second plugin instance
+        self.timing_cache_id = ""
+
+        self.attrs = attrs
+
+        self.tactic = None
+
+        # TODO: need to parse the field collection here, e.g.
+        # if fc is not None:
+        #     assert fc[0].name == "pads"
+        #     self.pads = fc[0].data
+
+        if phase is not None:
+            self.phase = phase
+
+    def get_capability_interface(self, type):
+        return self
+
+    def get_output_data_types(
+        self, input_types: List[trt.DataType]
+    ) -> List[trt.DataType]:
+        # We can use the fake tensor implementation to figure out the expected output data type, e.g.
+        # with torch.fake_tensor():
+        #     fake_outputs = torch.ops.<namespace>.<op>(*fake_inputs)
+        # return fake_outputs[index]
+
+        # The example case here is kept simple for experimentation
+        return [input_types[0]]
+
+    def get_output_shapes(
+        self,
+        inputs: List[trt.DimsExprs],
+        shape_inputs,
+        exprBuilder: trt.IExprBuilder,
+    ) -> trt.DimsExprs:
+        logger.debug(f"get_output_shapes called with inputs: {inputs}")
+
+        # We need a way to go from the fake tensor implementation to constructing a DimsExprs.
+        # This is solved by shape propagation in PyTorch, where shape prop can produce SymInts
+        # that encode the shape map.
+        output_dims = trt.DimsExprs(inputs[0])
+
+        return [output_dims]
+
+    def get_fields_to_serialize(self):
+        # TODO: the set of fields to serialize should be passed in as another argument
+        fields = []
+
+        for key, value in self.attrs.items():
+            if isinstance(value, np.ndarray):
+                fields.append(
+                    trt.PluginField(
+                        key,
+                        value,
+                        _numpy_to_plugin_field_type[np.dtype(value.dtype)],
+                    )
+                )
+            elif isinstance(value, str):
+                fields.append(
+                    trt.PluginField(key, value.encode(), trt.PluginFieldType.CHAR)
+                )
+            elif isinstance(value, bytes):
+                fields.append(
+                    trt.PluginField(key, value, trt.PluginFieldType.UNKNOWN)
+                )
+            else:
+                fields.append(
+                    trt.PluginField(
+                        key,
+                        np.array([value]),
+                        _built_in_to_plugin_field_type[type(value)],
+                    )
+                )
+
+        return trt.PluginFieldCollection(fields)
+
+    def configure_plugin(self, inp, out):
+        pass
+
+    def on_shape_change(self, inp, out):
+        X_dims = inp[0].dims
+        self.X_shape = np.zeros((len(X_dims),))
+        for i in range(len(X_dims)):
+            self.X_shape[i] = X_dims[i]
+
+    def supports_format_combination(self, pos, in_out, num_inputs):
+        assert num_inputs == 2
+        assert pos < len(in_out)
+
+        desc = in_out[pos].desc
+        if desc.format != trt.TensorFormat.LINEAR:
+            return False
+
+        # The first input should be float16 or float32
+        if pos == 0:
+            return desc.type == trt.DataType.FLOAT or desc.type == trt.DataType.HALF
+
+        # The second input and the output should have the same type as the first input
+        return in_out[0].desc.type == desc.type
+
+    def enqueue(
+        self,
+        input_desc: List[trt.PluginTensorDesc],
+        output_desc: List[trt.PluginTensorDesc],
+        inputs,
+        outputs,
+        workspace: int,
+        stream: int,
+    ) -> None:
+        # Wrap the raw input and output device pointers in cupy arrays
+        input_mems = [None] * (len(inputs))
+
+        for i in range(len(inputs)):
+            input_mems[i] = cp.cuda.UnownedMemory(
+                inputs[i],
+                np.prod(input_desc[i].dims)
+                * cp.dtype(trt.nptype(input_desc[i].type)).itemsize,
+                self,
+            )
+
+        output_mems = [None] * (len(outputs))
+
+        for i in range(len(outputs)):
+            output_mems[i] = cp.cuda.UnownedMemory(
+                outputs[i],
+                np.prod(output_desc[i].dims)
+                * cp.dtype(trt.nptype(output_desc[i].type)).itemsize,
+                self,
+            )
+
+        input_data = [None] * (len(inputs))
+        for i in range(len(inputs)):
+            input_data[i] = cp.ndarray(
+                tuple(input_desc[i].dims),
+                dtype=trt.nptype(input_desc[i].type),
+                memptr=cp.cuda.MemoryPointer(input_mems[i], 0),
+            )
+
+        output_data = [None] * (len(outputs))
+        for i in range(len(outputs)):
+            output_data[i] = cp.ndarray(
+                (np.prod(output_desc[i].dims)),
+                dtype=trt.nptype(output_desc[i].type),
+                memptr=cp.cuda.MemoryPointer(output_mems[i], 0),
+            )
+
+        # TODO: this is just a simple case for elementwise operations,
+        # using the Torch implementation for now
+        input_torch_0 = torch.as_tensor(input_data[0], device="cuda")
+        input_torch_1 = torch.as_tensor(input_data[1], device="cuda")
+
+        output = torch.ops.torchtrt_ex.elementwise_add(input_torch_0, input_torch_1)
+
+        # Copy the result into the TensorRT-owned output buffer
+        output_torch = torch.as_tensor(output_data[0], device="cuda")
+        output_torch.copy_(output.flatten())
+
+    def attach_to_context(self, context):
+        return self.clone()
+
+    def get_valid_tactics(self):
+        return [int(Tactic.TORCH), int(Tactic.TRITON)]
+
+    def set_tactic(self, tactic):
+        self.tactic = Tactic(tactic)
+
+        # if self.phase == trt.TensorRTPhase.RUNTIME:
+        #     logger.info(f"Best tactic chosen: {self.tactic}")
+
+    def clone(self):
+        cloned_plugin = CustomPlugin(self.plugin_name, self.attrs)
+        cloned_plugin.__dict__.update(self.__dict__)
+        return cloned_plugin
+
+
+class PluginCreator(trt.IPluginCreatorV3One):  # type: ignore[misc]
+    def __init__(self, plugin_name: str, plugin_namespace: str, attrs):
+        trt.IPluginCreatorV3One.__init__(self)
+
+        self.name = plugin_name
+        self.plugin_namespace = plugin_namespace
+        self.plugin_version = "1"
+
+        field_names = []
+        for name, (builtin, type_) in attrs.items():
+            if builtin:
+                if type_ is str:
+                    field_names.append(
+                        trt.PluginField(name, b"", trt.PluginFieldType.CHAR)
+                    )
+                elif type_ is bytes:
+                    field_names.append(
+                        trt.PluginField(name, b"", trt.PluginFieldType.UNKNOWN)
+                    )
+                else:
+                    field_names.append(
+                        trt.PluginField(
+                            name, np.array([]), _built_in_to_plugin_field_type[type_]
+                        )
+                    )
+            else:
+                field_names.append(
+                    trt.PluginField(
+                        name, np.array([]), _numpy_to_plugin_field_type[np.dtype(type_)]
+                    )
+                )
+
+        self.field_names = trt.PluginFieldCollection(field_names)
+
+    def create_plugin(self, name: str, field_collection, phase=None) -> CustomPlugin:
+        # TODO: parse the attributes out of the incoming field collection, e.g.
+        # for f in field_collection:
+        #     if f.name not in desc.input_attrs:
+        #         raise AssertionError(
+        #             f"Unexpected attribute {f.name} provided to create_plugin. Expected one of {desc.input_attrs.keys()}."
+        #         )
+        #
+        #     if _is_numpy_array(desc.input_attrs[f.name]):
+        #         attrs[f.name] = f.data.astype(_infer_numpy_type(desc.input_attrs[f.name]))
+        #     else:
+        #         attrs[f.name] = desc.input_attrs[f.name](f.data)
+        attrs = {}
+
+        custom_plugin = CustomPlugin(name, attrs)
+
+        return custom_plugin
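The two attrs conventions above are easy to miss: PluginCreator expects a mapping from field name to a (is_builtin_type, type) pair, while CustomPlugin.attrs maps field names to concrete values that get_fields_to_serialize turns into PluginFields. A minimal sketch of the two shapes side by side (the "pads" and "mode" fields are illustrative only, not part of this patch):

    import numpy as np

    # Passed to PluginCreator: field name -> (is_builtin_python_type, type)
    creator_attrs = {
        "pads": (False, np.int32),  # numpy-typed field, mapped via _numpy_to_plugin_field_type
        "mode": (True, str),        # built-in str field, serialized as a CHAR PluginField
    }

    # Passed to CustomPlugin: field name -> concrete value
    plugin_attrs = {
        "pads": np.array([1, 1, 1, 1], dtype=np.int32),
        "mode": "constant",
    }

    # creator = PluginCreator("my_plugin", plugin_namespace="", attrs=creator_attrs)
    # plugin = CustomPlugin("my_plugin", attrs=plugin_attrs)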
diff --git a/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py
new file mode 100644
index 0000000000..4b8f4b4311
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py
@@ -0,0 +1,88 @@
+import logging
+from typing import Dict, Sequence, Tuple, Union
+
+import tensorrt as trt
+import torch
+from torch.fx.node import Argument, Target
+from torch_tensorrt.dynamo._settings import CompilationSettings
+from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion import impl
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
+    dynamo_tensorrt_converter,
+)
+from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
+from torch_tensorrt.dynamo.conversion.plugin import PluginCreator
+from torch_tensorrt.fx.types import TRTTensor
+
+logger = logging.getLogger(__name__)
+
+TRT_PLUGIN_REGISTRY = trt.get_plugin_registry()
+
+
+@dynamo_tensorrt_converter(torch.ops.torchtrt_ex.elementwise_add.default)
+def torchtrt_ex_elementwise_add(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+):
+    # Register the plugin creator and then retrieve it from the registry
+    # (retrieval works the same way if the creator is defined elsewhere, e.g. in a linked library)
+    plugin_creator = PluginCreator(
+        "elementwise_add_plugin", plugin_namespace="", attrs={}
+    )
+    TRT_PLUGIN_REGISTRY.register_creator(plugin_creator, "")
+
+    plugin_creator = TRT_PLUGIN_REGISTRY.get_plugin_creator(
+        type="elementwise_add_plugin", version="1", plugin_namespace=""
+    )
+    assert plugin_creator, "Unable to find elementwise_add_plugin creator"
+
+    # Pass configurations to the plugin implementation
+    # (this plugin has no attributes, so the field collection is empty)
+    field_configs = trt.PluginFieldCollection([])
+
+    plugin = plugin_creator.create_plugin(
+        name="elementwise_add_plugin", field_collection=field_configs
+    )
+    assert plugin, "Unable to create elementwise_add_plugin"
+
+    lhs_val = args[0]
+    rhs_val = args[1]
+
+    lhs_dtype = None
+    rhs_dtype = None
+
+    if isinstance(lhs_val, TRTTensor):
+        lhs_dtype = lhs_val.dtype
+    if isinstance(rhs_val, TRTTensor):
+        rhs_dtype = rhs_val.dtype
+
+    # Freeze the inputs into TensorRT ITensors if they are not already
+    lhs_val = get_trt_tensor(ctx, lhs_val, f"{name}_lhs", lhs_dtype)
+    rhs_val = get_trt_tensor(ctx, rhs_val, f"{name}_rhs", rhs_dtype)
+
+    # Add the plugin to the network being constructed
+    layer = ctx.net.add_plugin_v3([lhs_val, rhs_val], [], plugin)
+    layer.name = f"automatic-{name}"
+    return layer.get_output(0)
+
+
+# TODO: 1. generate a plugin for any PyTorch op
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 0b8f47fb6f..bd490ec1be 100644
--- a/setup.py
+++ b/setup.py
@@ -440,6 +440,7 @@ def run(self):
         "torch_tensorrt.dynamo.conversion.impl.normalization",
         "torch_tensorrt.dynamo.conversion.impl.slice",
         "torch_tensorrt.dynamo.conversion.impl.unary",
+        "torch_tensorrt.dynamo.conversion.plugin",
         "torch_tensorrt.dynamo.lowering",
         "torch_tensorrt.dynamo.lowering.passes",
         "torch_tensorrt.dynamo.partitioning",
@@ -468,6 +469,7 @@ def run(self):
         "torch_tensorrt.dynamo.conversion.impl.normalization": "py/torch_tensorrt/dynamo/conversion/impl/normalization",
         "torch_tensorrt.dynamo.conversion.impl.slice": "py/torch_tensorrt/dynamo/conversion/impl/slice",
         "torch_tensorrt.dynamo.conversion.impl.unary": "py/torch_tensorrt/dynamo/conversion/impl/unary",
+        "torch_tensorrt.dynamo.conversion.plugin": "py/torch_tensorrt/dynamo/conversion/plugin",
         "torch_tensorrt.dynamo.lowering": "py/torch_tensorrt/dynamo/lowering",
         "torch_tensorrt.dynamo.lowering.passes": "py/torch_tensorrt/dynamo/lowering/passes",
         "torch_tensorrt.dynamo.partitioning": "py/torch_tensorrt/dynamo/partitioning",
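A quick way to confirm the new subpackage is picked up after rebuilding and reinstalling the package (a sketch, not part of the patch; note that importing plugin_ops_converters itself additionally requires the torchtrt_ex::elementwise_add op from the example to be registered first, since the converter decorator references it at import time):

    # Should import without raising ModuleNotFoundError once the package lists above are updated
    from torch_tensorrt.dynamo.conversion.plugin import PluginCreator

    print(PluginCreator)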