Conversation

bowang007
Collaborator

Description

This PR implements the automatic plugin feature.
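
For context, the condensed flow exercised end to end by the new example (examples/dynamo/automatic_plugin/custom_op.py) is sketched below. This is a reviewer-oriented sketch, not the exact file contents: the Triton kernel body is replaced with a plain PyTorch add, and the `custom_op` decorator is assumed to be `torch.library.custom_op` (its import falls outside the diff hunks shown later in this thread).

import torch
import torch_tensorrt as torchtrt


@torch.library.custom_op("torchtrt_ex::elementwise_add", mutates_args=())  # type: ignore[misc]
def elementwise_add(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    # The real example launches a Triton kernel here; a plain add keeps this sketch runnable.
    return X + Y


@torch.library.register_fake("torchtrt_ex::elementwise_add")
def _(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Fake (meta) implementation used for shape and dtype propagation while tracing.
    return torch.empty_like(x)


class MyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.ops.torchtrt_ex.elementwise_add(x, y)


my_model = MyModel().to("cuda")
m = torch.full((64, 64), 2, device="cuda")
n = torch.full((64, 64), 3, device="cuda")

# The converter registered in plugin_ops_converters.py maps this op onto a
# generated TensorRT plugin during compilation.
with torchtrt.logging.info():
    model_trt = torchtrt.compile(my_model, inputs=[m, n], debug=True, min_block_size=1)
    print(model_trt(m, n))

With the converter added in py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py, the call to torch.ops.torchtrt_ex.elementwise_add is lowered to a TensorRT plugin created through PluginCreator rather than being executed in a Torch fallback subgraph.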


  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

@github-actions github-actions bot added component: conversion Issues re: Conversion stage component: build system Issues re: Build system component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Nov 22, 2024

@github-actions github-actions bot left a comment


There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/automatic_plugin/custom_op.py	2024-11-22 01:20:58.215888+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/automatic_plugin/custom_op.py	2024-11-22 01:21:18.909129+00:00
@@ -1,7 +1,8 @@
import triton
import triton.language as tl
+

@triton.jit
def elementwise_add_kernel(X, Y, Z, BLOCK_SIZE: tl.constexpr):
    # Program ID determines the block of data each thread will process
    pid = tl.program_id(0)
@@ -25,23 +26,23 @@
@custom_op("torchtrt_ex::elementwise_add", mutates_args=())  # type: ignore[misc]
def elementwise_add(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    # Ensure the tensors are on the GPU
    assert X.is_cuda and Y.is_cuda, "Tensors must be on CUDA device."
    assert X.shape == Y.shape, "Tensors must have the same shape."
-    
+
    # Create output tensor
    Z = torch.empty_like(X)
-    
+
    # Define block size
    BLOCK_SIZE = 1024
-    
+
    # Grid of programs
-    grid = lambda meta: (X.numel() // meta['BLOCK_SIZE'],)
-    
+    grid = lambda meta: (X.numel() // meta["BLOCK_SIZE"],)
+
    # Launch the kernel
    elementwise_add_kernel[grid](X, Y, Z, BLOCK_SIZE=BLOCK_SIZE)
-    
+
    return Z


# Using the module in PyTorch
# X = torch.randn(1024, device='cuda', requires_grad=True)
@@ -72,22 +73,31 @@

        return res


my_model = MyModel().to("cuda")
-m = torch.full((64, 64), 2, device='cuda',)
-n = torch.full((64, 64), 3, device='cuda',)
+m = torch.full(
+    (64, 64),
+    2,
+    device="cuda",
+)
+n = torch.full(
+    (64, 64),
+    3,
+    device="cuda",
+)
# print(torch.ops.torchtrt_ex.elementwise_add(m, n))
# print(my_model.forward(m, n))


@torch.library.register_fake("torchtrt_ex::elementwise_add")
def _(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x

+
import torch_tensorrt as torchtrt


with torchtrt.logging.info():
    model_trt = torchtrt.compile(my_model, inputs=[m, n], debug=True, min_block_size=1)
    res = model_trt(m, n)
-    print(res)
\ No newline at end of file
+    print(res)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/__init__.py	2024-11-22 01:20:58.227888+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/__init__.py	2024-11-22 01:21:19.453080+00:00
@@ -1,6 +1,11 @@
-from . import aten_ops_converters, ops_evaluators, prims_ops_converters, plugin_ops_converters
+from . import (
+    aten_ops_converters,
+    ops_evaluators,
+    prims_ops_converters,
+    plugin_ops_converters,
+)
from ._conversion import convert_module, interpret_module_to_result
from ._ConversionContext import ConversionContext
from ._ConverterRegistry import *  # noqa: F403
from ._TRTInterpreter import *  # noqa: F403
from .truncate_double import repair_double_inputs
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/plugin/__init__.py	2024-11-22 01:20:58.227888+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/plugin/__init__.py	2024-11-22 01:21:20.202267+00:00
@@ -1 +1 @@
-from .plugin_generator import PluginCreator
\ No newline at end of file
+from .plugin_generator import PluginCreator
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py	2024-11-22 01:20:58.227888+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py	2024-11-22 01:21:20.284627+00:00
@@ -17,25 +17,28 @@

logger = logging.getLogger(__name__)

TRT_PLUGIN_REGISTRY = trt.get_plugin_registry()

+
@dynamo_tensorrt_converter(torch.ops.torchtrt_ex.elementwise_add.default)
def torchtrt_ex_elementwise_add(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
-): 
+):
    # logger.debug(f"plugin stuff here2")
    # return torch.add(args)
-    
+
    # How to retrieve a plugin if it is defined elsewhere (e.g. linked library)
-    plugin_creator = PluginCreator("elementwise_add_plugin", plugin_namespace="", attrs={})
-    TRT_PLUGIN_REGISTRY.register_creator(plugin_creator, "")    
-    
+    plugin_creator = PluginCreator(
+        "elementwise_add_plugin", plugin_namespace="", attrs={}
+    )
+    TRT_PLUGIN_REGISTRY.register_creator(plugin_creator, "")
+
    plugin_creator = TRT_PLUGIN_REGISTRY.get_plugin_creator(
        type="elementwise_add_plugin", version="1", plugin_namespace=""
    )
    assert plugin_creator, f"Unable to find elementwise_add_plugin creator"

@@ -44,45 +47,47 @@
    # plugin = plugin_creator.create_plugin(name=name, field_collection=field_configs)
    # assert plugin, "Unable to create <PLUGIN_NAME>"

    # <GENERATE LINK BETWEEN PLUGIN AND INPUTS>
    #    <GET INPUTS INTO LIST>
-    #    <PASS TO PLUGIN>     
-    
+    #    <PASS TO PLUGIN>
+
    # return layer.get_output(0)
    field_configs = trt.PluginFieldCollection([])
-    
-    plugin = plugin_creator.create_plugin(name="elementwise_add_plugin", field_collection=field_configs)
+
+    plugin = plugin_creator.create_plugin(
+        name="elementwise_add_plugin", field_collection=field_configs
+    )
    assert plugin, "Unable to create CircularPaddingPlugin"
-    
+
    # input_tensor = args[
    #     0
    # ]  # Arg 0 `torch.ops.torchtrt_ex.triton_circular_pad` is the input tensor
    # if not isinstance(input_tensor, trt.ITensor):
    #     # Freeze input tensor if not TensorRT Tensor already
    #     input_tensor = get_trt_tensor(ctx, input_tensor, f"{name}_input")
-    
+
    lhs_dtype = None
    rhs_dtype = None
    lhs_val = args[0]
    rhs_val = args[1]
-    
+
    if isinstance(lhs_val, TRTTensor):
        lhs_dtype = lhs_val.dtype
        # is_lhs_trt_tensor = True
    if isinstance(rhs_val, TRTTensor):
        rhs_dtype = rhs_val.dtype
        # is_rhs_trt_tensor = True
-        
+
    print(lhs_dtype)
-    
+
    lhs_val = get_trt_tensor(ctx, lhs_val, f"{name}_lhs", lhs_dtype)
    rhs_val = get_trt_tensor(ctx, rhs_val, f"{name}_rhs", rhs_dtype)

    layer = ctx.net.add_plugin_v3(
        [lhs_val, rhs_val], [], plugin
    )  # Add the plugin to the network being constructed
    # layer.name = f"automatic-{name}"
    return layer.get_output(0)


-# 1. generate plugin for any pytorch op
\ No newline at end of file
+# 1. generate plugin for any pytorch op
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py	2024-11-22 01:20:58.227888+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py	2024-11-22 01:21:20.380983+00:00
@@ -11,64 +11,63 @@


logger = logging.getLogger("CustomPlugin")

_numpy_to_plugin_field_type = {
-    np.dtype('int32'): trt.PluginFieldType.INT32,
-    np.dtype('int16'): trt.PluginFieldType.INT16,
-    np.dtype('int8'): trt.PluginFieldType.INT8,
-    np.dtype('bool'): trt.PluginFieldType.INT8,
-    np.dtype('int64'): trt.PluginFieldType.INT64,
-    np.dtype('float32'): trt.PluginFieldType.FLOAT32,
-    np.dtype('float64'): trt.PluginFieldType.FLOAT64,
-    np.dtype('float16'): trt.PluginFieldType.FLOAT16
+    np.dtype("int32"): trt.PluginFieldType.INT32,
+    np.dtype("int16"): trt.PluginFieldType.INT16,
+    np.dtype("int8"): trt.PluginFieldType.INT8,
+    np.dtype("bool"): trt.PluginFieldType.INT8,
+    np.dtype("int64"): trt.PluginFieldType.INT64,
+    np.dtype("float32"): trt.PluginFieldType.FLOAT32,
+    np.dtype("float64"): trt.PluginFieldType.FLOAT64,
+    np.dtype("float16"): trt.PluginFieldType.FLOAT16,
}

_built_in_to_plugin_field_type = {
    int: trt.PluginFieldType.INT64,
    float: trt.PluginFieldType.FLOAT64,
    bool: trt.PluginFieldType.INT8,
    # str is handled separately, so not needed here
}

+
class Tactic(IntEnum):
    TORCH = 1
    TRITON = 2

+
class CustomPlugin(trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild, trt.IPluginV3OneRuntime):  # type: ignore[misc]
-    def __init__(
-        self, plugin_name : str, attrs, phase = None
-    ):
+    def __init__(self, plugin_name: str, attrs, phase=None):
        # TODO: needs an additional passed in arguments to specify the needs for each plugin
        # such as the one here: https://github.com/NVIDIA/TensorRT/blob/40efe7e9f2492657bbc455c4e2876e2ec792b812/samples/python/python_plugin/circ_pad_plugin_multi_tactic.py#L83
        trt.IPluginV3.__init__(self)
        # Core capability, plugin attributes and behaviors common to both the build and runtime phases of a plugin’s lifetime
        trt.IPluginV3OneCore.__init__(self)
        # Build capability, plugin attributes and behaviors that the plugin must exhibit for the TensorRT builder.
        trt.IPluginV3OneBuild.__init__(self)
        # Runtime capability, plugin attributes and behaviors that the plugin must exhibit for it to be executable
-        trt.IPluginV3OneRuntime.__init__(self)       
-        
+        trt.IPluginV3OneRuntime.__init__(self)
+
        # <ANY NON TENSOR INPUTS SHOULD BE AN ATTRIBUTE OF THE PLUGIN>
-        # setattr(<name of input>, <default value for that type>) 
+        # setattr(<name of input>, <default value for that type>)
        # self.pads = []
        # self.X_shape: List[int] = []
- 
-        self.num_outputs = 1 # Defined by schema 
+
+        self.num_outputs = 1  # Defined by schema
        self.plugin_namespace = ""
        self.plugin_name = plugin_name
-        self.plugin_version = "1"   
+        self.plugin_version = "1"

        # Set the timing cache ID to prevent unnecessary timing of second plugin instance
        self.timing_cache_id = ""

        self.attrs = attrs
-        
+
        self.tactic = None
-        
-
-        # <GENERATE CODE FOR TAKING A FIELD COLLECTION CONTAINING THE NON TENSOR INPUTS AND SETTING AN ATTR> 
+
+        # <GENERATE CODE FOR TAKING A FIELD COLLECTION CONTAINING THE NON TENSOR INPUTS AND SETTING AN ATTR>
        # ex.
        # TODO: need to parse the field collection here
        # if fc is not None:
        #     assert fc[0].name == "pads"
        #     self.pads = fc[0].data
@@ -77,14 +76,12 @@
            self.phase = phase

    def get_capability_interface(self, type):
        return self

-    def get_output_data_types(
-        self, input_types: List[trt.DataType]
-    ) -> trt.DataType:
-        # WE CAN USE THE FAKE TENSOR IMPLEMENTATION TO FIGURE OUT THE EXPECTED OUTPUT DATA TYPE 
+    def get_output_data_types(self, input_types: List[trt.DataType]) -> trt.DataType:
+        # WE CAN USE THE FAKE TENSOR IMPLEMENTATION TO FIGURE OUT THE EXPECTED OUTPUT DATA TYPE
        # with torch.fake_tensor():
        #      <GENERATE FAKE INPUTS OF TYPE INPUT_TYPES>
        #      fake_outputs = torch.ops.<custom_ns>.<custom_op>(*fake_inputs)

        # return fake_outputs[index]
@@ -96,20 +93,20 @@
        self,
        inputs: List[trt.DimsExprs],
        shape_inputs,
        exprBuilder: trt.IExprBuilder,
    ) -> trt.DimsExprs:
-        
+
        print(inputs)

-    #    WE NEED TO FIND A WAY TO GO FROM FAKE TENSOR IMPL TO CONSTRUCTING A DIMSEXPR 
-    #    THIS IS SOLVED IN SHAPE PROP IN PYTORCH WHERE SHAPE PROP CAN GIVE SYMINTS THAT ENCODE THE 
-    #    SHAPE MAP. 
+        #    WE NEED TO FIND A WAY TO GO FROM FAKE TENSOR IMPL TO CONSTRUCTING A DIMSEXPR
+        #    THIS IS SOLVED IN SHAPE PROP IN PYTORCH WHERE SHAPE PROP CAN GIVE SYMINTS THAT ENCODE THE
+        #    SHAPE MAP.
        output_dims = trt.DimsExprs(inputs[0])

        return [output_dims]
-    
+
    def get_fields_to_serialize(self):
        # should be passed in as another argument
        field_names = []

        for key, value in self.attrs.items():
@@ -149,11 +146,11 @@
        self.X_shape = np.zeros((len(X_dims),))
        for i in range(len(X_dims)):
            self.X_shape[i] = X_dims[i]

    def supports_format_combination(self, pos, in_out, num_inputs):
-        return 
+        return
        assert num_inputs == 1
        assert pos < len(in_out)

        desc = in_out[pos].desc
        if desc.format != trt.TensorFormat.LINEAR:
@@ -166,11 +163,10 @@
        # output should have the same type as the input
        if pos == 1:
            return in_out[0].desc.type == desc.type

        assert False
-

    def enqueue(
        self,
        input_desc: List[trt.PluginTensorDesc],
        output_desc: List[trt.PluginTensorDesc],
@@ -180,40 +176,56 @@
        stream: int,
    ) -> None:
        # input and output memory handling
        input_mems = [None] * (len(inputs))

-        for i in range(len(inputs)): 
-            input_mems[i] = cp.cuda.UnownedMemory(inputs[i], np.prod(input_desc[i].dims) * cp.dtype(trt.nptype(input_desc[i].type)).itemsize, self)
+        for i in range(len(inputs)):
+            input_mems[i] = cp.cuda.UnownedMemory(
+                inputs[i],
+                np.prod(input_desc[i].dims)
+                * cp.dtype(trt.nptype(input_desc[i].type)).itemsize,
+                self,
+            )

        output_mems = [None] * (len(outputs))

        for i in range(len(outputs)):
-            output_mems[i] = cp.cuda.UnownedMemory(outputs[i], np.prod(output_desc[i].dims) * cp.dtype(trt.nptype(output_desc[i].type)).itemsize, self)
-    
+            output_mems[i] = cp.cuda.UnownedMemory(
+                outputs[i],
+                np.prod(output_desc[i].dims)
+                * cp.dtype(trt.nptype(output_desc[i].type)).itemsize,
+                self,
+            )

        input_data = [None] * ((len(inputs)))
        for i in range(len(inputs)):
-            input_data[i] = cp.ndarray(tuple(input_desc[i].dims), dtype=input_desc[i].type, memptr = cp.cuda.MemoryPointer(input_mems[i], 0))
+            input_data[i] = cp.ndarray(
+                tuple(input_desc[i].dims),
+                dtype=input_desc[i].type,
+                memptr=cp.cuda.MemoryPointer(input_mems[i], 0),
+            )

        output_data = [None] * ((len(outputs)))
        for i in range(len(outputs)):
-            output_data[i] = cp.ndarray((np.prod(output_desc[i].dims)), dtype = output_desc[i].type, memptr = cp.cuda.MemoryPointer(output_mems[i], 0))
-
-        #TODO: This is just for a simple case for elementwise operations
+            output_data[i] = cp.ndarray(
+                (np.prod(output_desc[i].dims)),
+                dtype=output_desc[i].type,
+                memptr=cp.cuda.MemoryPointer(output_mems[i], 0),
+            )
+
+        # TODO: This is just for a simple case for elementwise operations
        # using Torch implementation for now
-        input_torch_0 = torch.as_tensor(input_data[0], device='cuda')
-        input_torch_1 = torch.as_tensor(input_data[1], device='cuda')
+        input_torch_0 = torch.as_tensor(input_data[0], device="cuda")
+        input_torch_1 = torch.as_tensor(input_data[1], device="cuda")

        output = torch.ops.torchtrt_ex.elementwise_add(input_torch_0, input_torch_1)

        cp.copyto(output_data, output)
-

    def attach_to_context(self, context):
        return self.clone()
-    
+
    def get_valid_tactics(self):
        return [int(Tactic.TORCH), int(Tactic.TRITON)]

    def set_tactic(self, tactic):
        self.tactic = Tactic(tactic)
@@ -226,17 +238,17 @@
        cloned_plugin.__dict__.update(self.__dict__)
        return cloned_plugin


class PluginCreator(trt.IPluginCreatorV3One):  # type: ignore[misc]
-    def __init__(self, plugin_name : str, plugin_namespace : str, attrs):
-        trt.IPluginCreatorV3One.__init__(self)  
+    def __init__(self, plugin_name: str, plugin_namespace: str, attrs):
+        trt.IPluginCreatorV3One.__init__(self)

        self.name = plugin_name
        self.plugin_namespace = plugin_namespace
        self.plugin_version = "1"
-        
+
        field_names = []
        for name, (builtin, type_) in attrs.items():
            if builtin:
                if type_ is str:
                    field_names.append(
@@ -259,15 +271,12 @@
                    )
                )

        self.field_names = trt.PluginFieldCollection(field_names)

-    def create_plugin(
-        self, name: str, field_collection, phase=None
-    ) -> CustomPlugin:
-
-        
+    def create_plugin(self, name: str, field_collection, phase=None) -> CustomPlugin:
+
        attrs = {}
        # for f in fc:
        #     if f.name not in desc.input_attrs:
        #         raise AssertionError(
        #             f"Unexpected attribute {f.name} provided to create_plugin. Expected one of {desc.input_attrs.keys()}."
@@ -275,10 +284,9 @@

        #     if _is_numpy_array(desc.input_attrs[f.name]):
        #         attrs[f.name] = f.data.astype(_infer_numpy_type(desc.input_attrs[f.name]))
        #     else:
        #         attrs[f.name] = desc.input_attrs[f.name](f.data)
-                
+
        custom_plugin = CustomPlugin(name, attrs)
-        
+
        return custom_plugin
-

@github-actions github-actions bot requested a review from gs-olive November 22, 2024 01:21
@bowang007 bowang007 closed this Jan 29, 2025