diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 16d20d3c76..f457bf8eeb 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2730,40 +2730,10 @@ def aten_ops_max_pool( ) -def attention_validator( - node: Node, settings: Optional[CompilationSettings] = None -) -> bool: - # Currently, `attn_mask` is not supported - return args_bounds_check(node.args, 3) is None - - +@dynamo_tensorrt_converter(torch.ops.aten.reshape.default, supports_dynamic_shapes=True) @dynamo_tensorrt_converter( - torch.nn.functional.scaled_dot_product_attention, - capability_validator=attention_validator, - supports_dynamic_shapes=True, + torch.ops.aten._reshape_copy.default, supports_dynamic_shapes=True ) -def tensorrt_scaled_dot_product_attention( - ctx: ConversionContext, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, -) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.attention.scaled_dot_product_attention( - ctx, - target, - SourceIR.TORCHTRT_LOWERED, - name, - args[0], - args[1], - args[2], - args_bounds_check(args, 5, False), - kwargs.get("scale", None), - ) - - -@dynamo_tensorrt_converter(torch.ops.aten.reshape.default, supports_dynamic_shapes=True) -@dynamo_tensorrt_converter(torch.ops.aten.view.default, supports_dynamic_shapes=True) @enforce_tensor_types( { 0: (TRTTensor,), diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py index c1187f0dd9..36b3476e82 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -2,7 +2,6 @@ activation, addmm, arange, - attention, cast, cat, condition, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/attention.py b/py/torch_tensorrt/dynamo/conversion/impl/attention.py deleted file mode 100644 index 9cc4a30ccf..0000000000 --- a/py/torch_tensorrt/dynamo/conversion/impl/attention.py +++ /dev/null @@ -1,165 +0,0 @@ -import math -from typing import Optional, Union - -import numpy as np -import tensorrt as trt -from torch.fx.node import Target -from torch_tensorrt._enums import dtype -from torch_tensorrt.dynamo.conversion import impl -from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext -from torch_tensorrt.dynamo.conversion.converter_utils import ( - SourceIR, - cast_trt_tensor, - get_trt_tensor, -) -from torch_tensorrt.fx.types import TRTTensor - - -def tril( - ctx: ConversionContext, - target: Union[Target, str], - source_ir: Optional[SourceIR], - name: str, - input: TRTTensor, -) -> TRTTensor: - # the lower triangle of the tensor means the rows greater than and equal to the cols - row = impl.shape.shape(ctx, target, source_ir, name + "_shape_0", input, 0) - col = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", input, 1) - rc = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", row, col) - arange_tensor = impl.arange.arange( - ctx, target, source_ir, name + "_arange", start=0, end=rc, step=1 - ) - # get the rows - row_tensor = impl.elementwise.trunc_div( - ctx, target, source_ir, name + "_trunc_div_col", arange_tensor, col - ) - # get the cols - col_tensor = impl.elementwise.fmod( - ctx, target, source_ir, name + "_trunc_div_row", arange_tensor, col - ) - cond = impl.elementwise.ge( - ctx, target, source_ir, name + "_ge", row_tensor, col_tensor - ) - return 
impl.shuffle.reshape( - ctx, target, source_ir, name + "_reshape", cond, [row, col] - ) - - -def scaled_dot_product_attention( - ctx: ConversionContext, - target: Union[Target, str], - source_ir: Optional[SourceIR], - name: str, - query: TRTTensor, - key: TRTTensor, - value: TRTTensor, - is_causal: bool, - scale: Optional[float], -) -> TRTTensor: - # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html - mm = impl.matmul.matrix_multiply( - ctx, - target, - source_ir, - name + "_mm", - query, - key, - other_matrix_op=trt.MatrixOperation.TRANSPOSE, - ) - if scale is None: - scale = query.shape[-1] - if scale < 0: - # dynamic shape - scale = impl.shape.shape(ctx, target, source_ir, name + "_shape", query, -1) - sqrt_scaled = impl.unary.sqrt(ctx, target, source_ir, name + "_sqrt", scale) - else: - # static shape - sqrt_scaled = math.sqrt(scale) - scaled = impl.elementwise.div( - ctx, - target, - source_ir, - name + "_scale", - mm, - sqrt_scaled, - ) - else: - scaled = impl.elementwise.mul( - ctx, - target, - source_ir, - name + "_scale", - mm, - scale, - ) - - if is_causal: - L, S = query.shape[-2], key.shape[-2] - if L >= 0 and S >= 0: - # static shape - attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype)) - temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0)) - attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf")) - attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias") - else: - # if any of the L or S is dynamic shape - if L < 0: - L = impl.shape.shape( - ctx, target, source_ir, name + "_shape_0", query, -2 - ) - if S < 0: - S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, -2) - - LS = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", L, S) - - # this is to generate a tensor which has shape (L, S), type is int32 - arange_tensor = impl.arange.arange( - ctx, target, source_ir, name=name + "_arange", start=0, end=LS, step=1 - ) - shape_tensor = impl.shuffle.reshape( - ctx, target, source_ir, name + "_reshape", arange_tensor, [L, S] - ) - - # since we want our attn_bias to be in float32, so cast it to float32 - shape_tensor = cast_trt_tensor( - ctx, shape_tensor, trt.float32, name + "_casted", target, source_ir - ) - - # initialize the attn_bias as the zeros tensor - attn_bias = impl.elementwise.mul( - ctx, target, source_ir, name + "_mul_zero", shape_tensor, 0.0 - ) - - # generate the mask tensor - tril_tensor = tril(ctx, target, source_ir, name + "_tril", shape_tensor) - temp_mask = impl.unary.logical_not( - ctx, target, source_ir, name + "_logical_not", tril_tensor - ) - inf_tensor = impl.elementwise.mul( - ctx, target, source_ir, name + "_mul_-inf", shape_tensor, float("-inf") - ) - cond = impl.elementwise.eq( - ctx, target, source_ir, name + "_cond_true", temp_mask, bool(True) - ) - # mask out the certain part of the attn_bias - attn_bias = impl.condition.select( - ctx, target, source_ir, name + "_select", inf_tensor, attn_bias, cond - ) - - scaled = impl.elementwise.add( - ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias - ) - - softmax = impl.normalization.softmax( - ctx, target, source_ir, name + "_softmax", scaled, -1, False - ) - out = impl.matmul.matrix_multiply( - ctx, - target, - source_ir, - name + "_out", - softmax, - value, - ) - - return out diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index 
e4a2e068a5..1c182f92aa 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -1,6 +1,6 @@ import logging from enum import Enum, auto -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Tuple import torch from torch._decomp import register_decomposition @@ -435,6 +435,137 @@ def full_like_decomposition(*args, **kwargs) -> torch.Tensor: return torch.full(shape, fill_value, dtype=kwargs["dtype"], device=kwargs["device"]) +@register_torch_trt_decomposition(aten.view.default, registry=TORCH_TRT_DECOMPOSITIONS) +def view_decomposition(x: torch.Tensor, size: List[torch.SymInt]) -> torch.Tensor: + return aten._reshape_copy.default(x, size) + + +@register_torch_trt_decomposition( + aten.scaled_dot_product_attention, registry=TORCH_TRT_DECOMPOSITIONS +) +def scaled_dot_product_attention_decomposition( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + *, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + L, S = query.size(-2), key.size(-2) + device = query.device + attn_bias = torch.zeros(L, S, dtype=query.dtype, device=device) + + if is_causal: + assert attn_mask is None, "attn_mask must be None when is_causal=True" + temp_mask = torch.ones(L, S, dtype=torch.bool, device=device).tril(diagonal=0) + attn_bias = attn_bias.masked_fill(temp_mask.logical_not(), float("-inf")) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias = attn_bias.masked_fill(attn_mask.logical_not(), float("-inf")) + else: + attn_bias = attn_mask + attn_bias + + if enable_gqa: + key = key.repeat_interleave(query.size(-3) // key.size(-3), -3) + value = value.repeat_interleave(query.size(-3) // value.size(-3), -3) + + attn_weight = query @ key.transpose(-2, -1) + + if scale is None: + scale = torch.sqrt(torch.scalar_tensor(query.size(-1), dtype=torch.int)) + attn_weight = attn_weight / scale + else: + attn_weight = attn_weight * scale + + attn_weight = attn_weight + attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + return attn_weight @ value + + +@register_torch_trt_decomposition( + aten._scaled_dot_product_flash_attention, registry=TORCH_TRT_DECOMPOSITIONS +) +def scaled_dot_product_flash_attention_decomposition( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dropout_p: float = 0.0, + is_causal: bool = False, + return_debug_mask: bool = False, + *, + scale: Optional[float] = None, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.SymInt, + torch.SymInt, + torch.Tensor, + torch.Tensor, + torch.Tensor, +]: + attn = scaled_dot_product_attention_decomposition( + query, key, value, None, dropout_p, is_causal, scale=scale + ) + return attn, None, None, None, 0, 0, None, None, None + + +@register_torch_trt_decomposition( + aten._scaled_dot_product_efficient_attention, registry=TORCH_TRT_DECOMPOSITIONS +) +def scaled_dot_product_efficient_attention_decomposition( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_bias: Optional[torch.Tensor], + compute_log_sumexp: bool, + dropout_p: float = 0.0, + is_causal: bool = False, + *, + scale: Optional[float] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + attn = scaled_dot_product_attention_decomposition( + query, key, value, attn_bias, dropout_p, is_causal, scale=scale + 
) + return attn, None, None, None + + +@register_torch_trt_decomposition( + aten._scaled_dot_product_cudnn_attention, registry=TORCH_TRT_DECOMPOSITIONS +) +def scaled_dot_product_cudnn_attention_decomposition( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_bias: Optional[torch.Tensor], + compute_log_sumexp: bool, + dropout_p: float = 0.0, + is_causal: bool = False, + return_debug_mask: bool = False, + *, + scale: Optional[float] = None, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.SymInt, + torch.SymInt, + torch.Tensor, + torch.Tensor, + torch.Tensor, +]: + attn = scaled_dot_product_attention_decomposition( + query, key, value, attn_bias, dropout_p, is_causal, scale=scale + ) + return attn, None, None, None, 0, 0, None, None, None + + def get_decompositions( enable_experimental_decompositions: bool = False, ) -> Dict[OpOverload, Callable[[Any], Any]]: diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index 661c76d3b6..c589aeeb9d 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -7,14 +7,12 @@ from .accumulate_fp32_matmul import accumulate_fp32_matmul from .constant_folding import constant_fold from .fuse_prims_broadcast import fuse_prims_broadcast -from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention from .pass_manager import DynamoPassManager from .remove_assert_scalar import remove_assert_scalar from .remove_detach import remove_detach from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones from .repair_input_as_output import repair_input_as_output from .replace_max_pool_with_indices import replace_max_pool_with_indices -from .view_to_reshape import view_to_reshape ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist( [ @@ -23,8 +21,6 @@ repair_input_as_output, fuse_prims_broadcast, replace_max_pool_with_indices, - lower_scaled_dot_product_attention, - view_to_reshape, remove_assert_scalar, accumulate_fp32_matmul, ] diff --git a/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py b/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py deleted file mode 100644 index 40fd587615..0000000000 --- a/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py +++ /dev/null @@ -1,169 +0,0 @@ -import copy -import logging -import operator -from typing import Callable, Sequence, Tuple - -import torch -from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.dynamo.conversion.aten_ops_converters import args_bounds_check -from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( - clean_up_graph_after_modifications, -) - -logger = logging.getLogger(__name__) -REPLACEABLE_ATEN_OPS = { - torch.ops.aten._scaled_dot_product_efficient_attention.default, - torch.ops.aten._scaled_dot_product_flash_attention.default, -} - - -def lower_scaled_dot_product_attention( - gm: torch.fx.GraphModule, settings: CompilationSettings -) -> torch.fx.GraphModule: - """Replace specific versions of scaled_dot_product_attention with an equivalent - implementation which can be easily converted to TRT - """ - original_fns, replacement = scaled_dot_product_attention_replacement() - replaced_nodes = [] - # For each original function, search for it in the graph and replace - for original in original_fns: - 
replaced_nodes += torch.fx.subgraph_rewriter.replace_pattern_with_filters( - gm, - original, - replacement, - ignore_literals=True, - ) - - if replaced_nodes: - # Repair instances which use the kwargs field (specifically the "scale" kwarg) - # Also repair instances which specified the is_causal or attn_bias fields - for match in replaced_nodes: - attention_node_replaced = None - # Seek the attention operator being replaced - for node in match.nodes_map: - if node.target in REPLACEABLE_ATEN_OPS: - attention_node_replaced = match.nodes_map[node] - break - - assert attention_node_replaced is not None - assert len(match.replacements) == 1 - - new_attention_node = match.replacements[0] - - assert ( - new_attention_node.target - == torch.nn.functional.scaled_dot_product_attention - ) - - # Copy the metadata of the replaced attention node to the new node - # TODO: Investigate why there are multiple FakeTensors in the metadata. - # We only use the first one as it contains the output shape information for this node. - if "val" in attention_node_replaced.meta: - new_attention_node.meta["val"] = copy.copy( - attention_node_replaced.meta["val"][0] - ) - - # If the attention operator had keyword-args, copy them to the new node - if attention_node_replaced.kwargs: - new_attention_node.kwargs = {**attention_node_replaced.kwargs} - - # Set default args in new node: - # Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False - new_attention_node.args = new_attention_node.args + (None, 0.0, False) - - # The `is_causal` argument was specified - if ( - ( - attention_node_replaced.target - == torch.ops.aten._scaled_dot_product_flash_attention.default - ) - and args_bounds_check(attention_node_replaced.args, 4, False) - ) or ( - ( - attention_node_replaced.target - == torch.ops.aten._scaled_dot_product_efficient_attention.default - ) - and args_bounds_check(attention_node_replaced.args, 6, False) - ): - new_attention_node.args = ( - new_attention_node.args[:5] + (True,) + new_attention_node.args[6:] - ) - - # The `attn_bias` argument was specified - if ( - attention_node_replaced.target - == torch.ops.aten._scaled_dot_product_efficient_attention.default - ) and args_bounds_check(attention_node_replaced.args, 3) is not None: - new_attention_node.args = ( - new_attention_node.args[:3] - + attention_node_replaced.args[3] - + new_attention_node.args[4:] - ) - - gm = clean_up_graph_after_modifications(gm) - logger.debug(f"Graph after lowering scaled dot product attention:\n{gm.graph}") - - return gm - - -def scaled_dot_product_attention_replacement() -> Tuple[ - Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]], - Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], -]: - """Constructs the original and replacement functions for efficient attention""" - - # Efficient Attention original graph - def efficient(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: - outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default( - q, - k, - v, - None, - False, - ) - out = operator.getitem(outputs, 0) - return out - - # Flash Attention original graph - def flash(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: - outputs = torch.ops.aten._scaled_dot_product_flash_attention.default( - q, - k, - v, - ) - out = operator.getitem(outputs, 0) - return out - - # Efficient Attention w/Scale original graph - def efficient_scale( - q: torch.Tensor, k: torch.Tensor, v: torch.Tensor - ) -> torch.Tensor: - outputs = 
torch.ops.aten._scaled_dot_product_efficient_attention.default( - q, - k, - v, - None, - False, - scale=1.0, - ) - out = operator.getitem(outputs, 0) - return out - - # Flash Attention w/Scale original graph - def flash_scale(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: - outputs = torch.ops.aten._scaled_dot_product_flash_attention.default( - q, - k, - v, - scale=1.0, - ) - out = operator.getitem(outputs, 0) - return out - - # Replacement graph consists of the functional version of scaled_dot_product_attention - def replacement( - query: torch.Tensor, key: torch.Tensor, value: torch.Tensor - ) -> torch.Tensor: - return torch.nn.functional.scaled_dot_product_attention(query, key, value) - - return (efficient, flash, efficient_scale, flash_scale), replacement diff --git a/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py b/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py deleted file mode 100644 index 795b42f879..0000000000 --- a/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py +++ /dev/null @@ -1,39 +0,0 @@ -import logging -from typing import List - -import torch -from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( - clean_up_graph_after_modifications, -) -from torch_tensorrt.dynamo.utils import copy_metadata - -logger = logging.getLogger(__name__) - - -def view_to_reshape( - gm: torch.fx.GraphModule, settings: CompilationSettings -) -> torch.fx.GraphModule: - """Replace aten.view with an equivalent implementation which avoids Tensor memory issues""" - orig_op = torch.ops.aten.view.default - replacement_op = torch.ops.aten.reshape.default - - # Original graph - def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor: - return orig_op(input, shape) - - # Replacement graph - def replacement(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor: - return replacement_op(input, shape) - - match_and_replacements = torch.fx.subgraph_rewriter._replace_pattern( - gm, orig, replacement - ) - if match_and_replacements: - gm = clean_up_graph_after_modifications(gm) - logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}") - - # Copy the orig_op's metadata to the replacement op - copy_metadata(match_and_replacements) - - return gm diff --git a/tests/py/dynamo/lowering/test_aten_lowering_passes.py b/tests/py/dynamo/lowering/test_aten_lowering_passes.py index 76d47d24bd..2f78f6ed7a 100644 --- a/tests/py/dynamo/lowering/test_aten_lowering_passes.py +++ b/tests/py/dynamo/lowering/test_aten_lowering_passes.py @@ -1,19 +1,9 @@ -import sys -import unittest - import torch import torch_tensorrt -from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION from torch.testing._internal.common_utils import TestCase, run_tests from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing -isSM8XDevice = torch.cuda.is_available() and torch.cuda.get_device_capability() in [ - (8, 6), - (8, 7), - (8, 9), -] - class TestInputAsOutput(TestCase): def test_input_as_output(self): @@ -166,71 +156,6 @@ def forward(self, x): torch._dynamo.reset() -class TestLowerViewToReshape(TestCase): - def test_view_to_reshape(self): - class ViewToReshape(torch.nn.Module): - def forward(self, input): - out = torch.ops.aten.view.default(input, (1, 1, -1)) - return out - - inputs = [ - torch.rand((3, 4, 5, 32)).cuda(), - ] - - fx_graph = torch.fx.symbolic_trace(ViewToReshape()) - expected_ops = {torch.ops.aten.reshape.default} - 
unexpected_ops = { - torch.ops.aten.view.default, - } - - unexpected_ops_seen, expected_ops_unseen = lower_graph_testing( - fx_graph, - inputs, - expected_ops=expected_ops, - unexpected_ops=unexpected_ops, - min_block_size=1, - ) - - self.assertEqual( - len(unexpected_ops_seen), - 0, - f"The following unexpected ops were encountered: {unexpected_ops_seen}", - ) - - self.assertEqual( - len(expected_ops_unseen), - 0, - f"The following expected ops were not encountered: {expected_ops_unseen}", - ) - torch._dynamo.reset() - - # Validate that the results between Torch and Torch-TRT are similar - optimized_model = torch_tensorrt.compile( - fx_graph, - "torch_compile", - inputs, - min_block_size=1, - pass_through_build_failures=True, - ) - optimized_model_results = torch.cat( - [tensor.detach().cpu() for tensor in optimized_model(*inputs)] - ) - torch_model_results = torch.cat( - [tensor.detach().cpu() for tensor in fx_graph(*inputs)] - ) - - max_diff = float( - torch.max(torch.abs(optimized_model_results - torch_model_results)) - ) - self.assertAlmostEqual( - max_diff, - 0, - DECIMALS_OF_AGREEMENT, - msg=f"ViewToReshape TRT outputs don't match with the original model.", - ) - torch._dynamo.reset() - - class TestFP32Accumulation(TestCase): def test_fp32_acc(self): class FP32Acc(torch.nn.Module): @@ -270,249 +195,5 @@ def forward(self, input, weight): torch._dynamo.reset() -class TestLowerEfficientAttention(TestCase): - def test_lower_efficient_attention(self): - class EfficientAttention(torch.nn.Module): - def forward(self, q, k, v): - attn = torch.ops.aten._scaled_dot_product_efficient_attention.default( - q, k, v, None, False - ) - return attn[0] - - inputs = [ - torch.rand(8, 4, 5, 4).cuda(), - torch.rand(8, 4, 2, 4).cuda(), - torch.rand(8, 4, 2, 4).cuda(), - ] - - fx_graph = torch.fx.symbolic_trace(EfficientAttention()) - expected_ops = {torch.nn.functional.scaled_dot_product_attention} - unexpected_ops = { - torch.ops.aten._scaled_dot_product_efficient_attention.default - } - - unexpected_ops_seen, expected_ops_unseen = lower_graph_testing( - fx_graph, - inputs, - expected_ops=expected_ops, - unexpected_ops=unexpected_ops, - min_block_size=1, - ) - - self.assertEqual( - len(unexpected_ops_seen), - 0, - f"The following unexpected ops were encountered: {unexpected_ops_seen}", - ) - - self.assertEqual( - len(expected_ops_unseen), - 0, - f"The following expected ops were not encountered: {expected_ops_unseen}", - ) - torch._dynamo.reset() - - # Validate that the results between Torch and Torch-TRT are similar - optimized_model = torch_tensorrt.compile( - fx_graph, - "torch_compile", - inputs, - min_block_size=1, - pass_through_build_failures=True, - ) - optimized_model_results = torch.cat( - [tensor.detach().cpu() for tensor in optimized_model(*inputs)] - ) - torch_model_results = torch.cat( - [tensor.detach().cpu() for tensor in fx_graph(*inputs)] - ) - - max_diff = float( - torch.max(torch.abs(optimized_model_results - torch_model_results)) - ) - self.assertAlmostEqual( - max_diff, - 0, - DECIMALS_OF_AGREEMENT, - msg=f"EfficientAttention TRT outputs don't match with the original model.", - ) - torch._dynamo.reset() - - def test_efficient_attention_converter(self): - class EfficientAttention(torch.nn.Module): - def forward(self, q, k, v): - attn = torch.ops.aten._scaled_dot_product_efficient_attention.default( - q, k, v, None, False - ) - return attn[0] - - inputs = [ - torch.rand(1, 3, 6, 4).cuda(), - torch.rand(1, 3, 2, 4).cuda(), - torch.rand(1, 3, 2, 4).cuda(), - ] - - fx_graph = 
torch.fx.symbolic_trace(EfficientAttention()) - - # Validate that the results between Torch and Torch-TRT are similar - optimized_model = torch_tensorrt.compile( - fx_graph, - "torch_compile", - inputs, - min_block_size=1, - pass_through_build_failures=True, - ) - optimized_model_results = torch.cat( - [tensor.detach().cpu() for tensor in optimized_model(*inputs)] - ) - torch_model_results = torch.cat( - [tensor.detach().cpu() for tensor in fx_graph(*inputs)] - ) - - max_diff = float( - torch.max(torch.abs(optimized_model_results - torch_model_results)) - ) - self.assertAlmostEqual( - max_diff, - 0, - DECIMALS_OF_AGREEMENT, - msg=f"EfficientAttention TRT outputs don't match with the original model.", - ) - torch._dynamo.reset() - - -@unittest.skipIf( - torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8, - "GPU compute capability is too low to run flash attention, need Ampere (8.0) or greater", -) -@unittest.skipIf( - sys.platform.startswith("win"), - "Test not supported on Windows", -) -class TestLowerFlashAttention(TestCase): - @unittest.skipIf( - not PLATFORM_SUPPORTS_FLASH_ATTENTION or not isSM8XDevice, - "Does not support fused SDPA or not SM86+ hardware", - ) - def test_lower_flash_attention(self): - class FlashAttention(torch.nn.Module): - def forward(self, q, k, v): - attn = torch.ops.aten._scaled_dot_product_flash_attention.default( - q, - k, - v, - scale=0.15, - ) - return attn[0] - - inputs = [ - torch.rand(8, 4, 16, 8).half().cuda(), - torch.rand(8, 4, 16, 8).half().cuda(), - torch.rand(8, 4, 16, 8).half().cuda(), - ] - - fx_graph = torch.fx.symbolic_trace(FlashAttention()) - expected_ops = {torch.nn.functional.scaled_dot_product_attention} - unexpected_ops = {torch.ops.aten._scaled_dot_product_flash_attention.default} - - unexpected_ops_seen, expected_ops_unseen = lower_graph_testing( - fx_graph, - inputs, - expected_ops=expected_ops, - unexpected_ops=unexpected_ops, - min_block_size=1, - ) - - self.assertEqual( - len(unexpected_ops_seen), - 0, - f"The following unexpected ops were encountered: {unexpected_ops_seen}", - ) - - self.assertEqual( - len(expected_ops_unseen), - 0, - f"The following expected ops were not encountered: {expected_ops_unseen}", - ) - torch._dynamo.reset() - - # Validate that the results between Torch and Torch-TRT are similar - optimized_model = torch_tensorrt.compile( - fx_graph, - "torch_compile", - inputs, - min_block_size=1, - pass_through_build_failures=True, - ) - optimized_model_results = torch.cat( - [tensor.detach().cpu() for tensor in optimized_model(*inputs)] - ) - torch_model_results = torch.cat( - [tensor.detach().cpu() for tensor in fx_graph(*inputs)] - ) - - max_diff = float( - torch.max(torch.abs(optimized_model_results - torch_model_results)) - ) - # Remove 1 decimal from the requirement for FP16 - self.assertAlmostEqual( - max_diff, - 0, - DECIMALS_OF_AGREEMENT - 1, - msg=f"FlashAttention TRT outputs don't match with the original model.", - ) - torch._dynamo.reset() - - @unittest.skipIf( - not PLATFORM_SUPPORTS_FLASH_ATTENTION or not isSM8XDevice, - "Does not support fused SDPA or not SM86+ hardware", - ) - def test_flash_attention_converter(self): - class FlashAttention(torch.nn.Module): - def forward(self, q, k, v): - attn = torch.ops.aten._scaled_dot_product_flash_attention.default( - q, - k, - v, - scale=0.25, - ) - return attn[0] - - inputs = [ - torch.rand(1, 3, 6, 8).half().cuda(), - torch.rand(1, 3, 2, 8).half().cuda(), - torch.rand(1, 3, 2, 8).half().cuda(), - ] - - fx_graph = 
torch.fx.symbolic_trace(FlashAttention()) - - # Validate that the results between Torch and Torch-TRT are similar - optimized_model = torch_tensorrt.compile( - fx_graph, - "torch_compile", - inputs, - min_block_size=1, - pass_through_build_failures=True, - ) - optimized_model_results = torch.cat( - [tensor.detach().cpu() for tensor in optimized_model(*inputs)] - ) - torch_model_results = torch.cat( - [tensor.detach().cpu() for tensor in fx_graph(*inputs)] - ) - - max_diff = float( - torch.max(torch.abs(optimized_model_results - torch_model_results)) - ) - # Remove 1 decimal from the requirement for FP16 - self.assertAlmostEqual( - max_diff, - 0, - DECIMALS_OF_AGREEMENT - 1, - msg=f"FlashAttention TRT outputs don't match with the original model.", - ) - torch._dynamo.reset() - - if __name__ == "__main__": run_tests() diff --git a/tests/py/dynamo/lowering/test_decompositions.py b/tests/py/dynamo/lowering/test_decompositions.py index 797d8d3263..1ac445a588 100644 --- a/tests/py/dynamo/lowering/test_decompositions.py +++ b/tests/py/dynamo/lowering/test_decompositions.py @@ -1,8 +1,16 @@ +import unittest + import torch import torch_tensorrt from parameterized import parameterized -from testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing +from torch.testing._internal.common_cuda import ( + PLATFORM_SUPPORTS_CUDNN_ATTENTION, + PLATFORM_SUPPORTS_FLASH_ATTENTION, +) from torch.testing._internal.common_utils import TestCase, run_tests +from torch_tensorrt.dynamo.utils import ATOL, RTOL + +from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing class TestLowering(TestCase): @@ -1720,6 +1728,396 @@ def forward(self, input, weight, bias, running_mean=None, running_var=None): "Instance_norm TRT outputs don't match with the original model.", ) + def test_lowering_view(self): + class TestModule(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.view.default(x, [1, 3, -1]) + + # Operations expected to be removed in the traced graph after decompositions + expected_ops = {torch.ops.aten._reshape_copy.default} + unexpected_ops = {torch.ops.aten.view.default} + + inputs = [torch.randn(1, 3, 5, 7, device="cuda")] + + exported_program = torch.export.export(TestModule(), tuple(inputs)) + fx_graph = exported_program.module() + unexpected_ops_seen, expected_ops_unseen = lower_graph_testing( + fx_graph, + inputs, + expected_ops=expected_ops, + unexpected_ops=unexpected_ops, + min_block_size=1, + ) + + self.assertEqual( + len(unexpected_ops_seen), + 0, + f"The following unexpected ops were encountered: {unexpected_ops_seen}", + ) + + self.assertEqual( + len(expected_ops_unseen), + 0, + f"The following expected ops were not encountered: {expected_ops_unseen}", + ) + + torch._dynamo.reset() + + # Validate that the results between Torch and Torch-TRT are similar + trt_model = torch_tensorrt.dynamo.compile( + exported_program, inputs, min_block_size=1 + ) + torch.testing.assert_close( + trt_model(*inputs), + fx_graph(*inputs), + rtol=RTOL, + atol=ATOL, + msg="View TRT outputs don't match with the original model.", + ) + + @parameterized.expand( + [ + (True, False, None, False), + (False, True, 0.123, True), + ] + ) + def test_lowering_scaled_dot_product_attention( + self, attn, is_causal, scale, enable_gqa + ): + class TestModule(torch.nn.Module): + def forward(self, query, key, value, attn_mask=None): + return torch.ops.aten.scaled_dot_product_attention.default( + query, + key, + value, + attn_mask, + 0.0, + is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + + # 
Operations expected to be removed in the traced graph after decompositions + unexpected_ops = {torch.ops.aten.scaled_dot_product_attention.default} + + inputs = [ + torch.rand(1, 3, 8, 16, dtype=torch.half, device="cuda"), + torch.rand(1, 3, 8, 16, dtype=torch.half, device="cuda"), + torch.rand(1, 3, 8, 16, dtype=torch.half, device="cuda"), + ] + if attn: + inputs += [torch.rand(1, 3, 8, 8, dtype=torch.half, device="cuda")] + + exported_program = torch.export.export(TestModule(), tuple(inputs)) + fx_graph = exported_program.module() + unexpected_ops_seen, _ = lower_graph_testing( + fx_graph, inputs, unexpected_ops=unexpected_ops, min_block_size=1 + ) + + self.assertEqual( + len(unexpected_ops_seen), + 0, + f"The following unexpected ops were encountered: {unexpected_ops_seen}", + ) + + torch._dynamo.reset() + + # Validate that the results between Torch and Torch-TRT are similar + trt_model = torch_tensorrt.dynamo.compile( + exported_program, inputs, enabled_precisions={torch.half}, min_block_size=1 + ) + torch.testing.assert_close( + trt_model(*inputs), + fx_graph(*inputs), + rtol=RTOL, + atol=ATOL, + msg="Scaled_dot_product_attention TRT outputs don't match with the original model.", + ) + + @parameterized.expand( + [ + (True, False, None, False), + (False, True, 0.123, True), + ] + ) + def test_lowering_scaled_dot_product_attention_with_dynamic_shape( + self, attn, is_causal, scale, enable_gqa + ): + class TestModule(torch.nn.Module): + def forward(self, query, key, value, attn_mask=None): + return torch.ops.aten.scaled_dot_product_attention.default( + query, + key, + value, + attn_mask, + 0.0, + is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + + example_inputs = [ + torch.zeros(2, 2, 16, 32, dtype=torch.half, device="cuda"), + torch.zeros(2, 2, 16, 32, dtype=torch.half, device="cuda"), + torch.zeros(2, 2, 16, 32, dtype=torch.half, device="cuda"), + ] + if attn: + example_inputs += [ + torch.zeros(2, 2, 16, 16, dtype=torch.half, device="cuda") + ] + + dim0 = torch.export.Dim("dim0", min=2, max=4) + dim1 = torch.export.Dim("dim1", min=2, max=8) + _dim2 = torch.export.Dim("dim2", min=16 // 8, max=64 // 8) + _dim3 = torch.export.Dim("dim3", min=32 // 8, max=128 // 8) + dim2 = _dim2 * 8 + dim3 = _dim3 * 8 + + dynamic_shapes = { + "query": {0: dim0, 1: dim1, 2: dim2, 3: dim3}, + "key": {0: dim0, 1: dim1, 2: dim2, 3: dim3}, + "value": {0: dim0, 1: dim1, 2: dim2, 3: dim3}, + } + if attn: + dynamic_shapes["attn_mask"] = {0: dim0, 1: dim1, 2: dim2, 3: dim2} + + exported_program = torch.export.export( + TestModule(), tuple(example_inputs), dynamic_shapes=dynamic_shapes + ) + fx_graph = exported_program.module() + + inputs = [ + torch_tensorrt.Input( + min_shape=(2, 2, 16, 32), + opt_shape=(3, 4, 32, 64), + max_shape=(4, 8, 64, 128), + dtype=torch.half, + ), + torch_tensorrt.Input( + min_shape=(2, 2, 16, 32), + opt_shape=(3, 4, 32, 64), + max_shape=(4, 8, 64, 128), + dtype=torch.half, + ), + torch_tensorrt.Input( + min_shape=(2, 2, 16, 32), + opt_shape=(3, 4, 32, 64), + max_shape=(4, 8, 64, 128), + dtype=torch.half, + ), + ] + if attn: + inputs += [ + torch_tensorrt.Input( + min_shape=(2, 2, 16, 16), + opt_shape=(3, 4, 32, 32), + max_shape=(4, 8, 64, 64), + dtype=torch.half, + ) + ] + + torch._dynamo.reset() + + # Validate that the results between Torch and Torch-TRT are similar + trt_model = torch_tensorrt.dynamo.compile( + exported_program, inputs, enabled_precisions={torch.half}, min_block_size=1 + ) + + inputs = [ + torch.rand(4, 8, 64, 128, dtype=torch.half, device="cuda"), + 
torch.rand(4, 8, 64, 128, dtype=torch.half, device="cuda"), + torch.rand(4, 8, 64, 128, dtype=torch.half, device="cuda"), + ] + if attn: + inputs += [torch.rand(4, 8, 64, 64, dtype=torch.half, device="cuda")] + + torch.testing.assert_close( + trt_model(*inputs), + fx_graph(*inputs), + rtol=RTOL, + atol=ATOL, + msg="Scaled_dot_product_attention_with_dynamic_shape TRT outputs don't match with the original model.", + ) + + @parameterized.expand( + [ + (False, None), + (True, 0.123), + ] + ) + @unittest.skipUnless( + PLATFORM_SUPPORTS_FLASH_ATTENTION, "Platform doesn't support Flash attention" + ) + def test_lowering_scaled_dot_product_flash_attention(self, is_causal, scale): + class TestModule(torch.nn.Module): + def forward(self, query, key, value): + return torch.ops.aten._scaled_dot_product_flash_attention.default( + query, + key, + value, + 0.0, + is_causal, + False, + scale=scale, + )[0] + + # Operations expected to be removed in the traced graph after decompositions + unexpected_ops = {torch.ops.aten._scaled_dot_product_flash_attention.default} + + inputs = [ + torch.rand(1, 3, 8, 16, dtype=torch.half, device="cuda"), + torch.rand(1, 3, 8, 16, dtype=torch.half, device="cuda"), + torch.rand(1, 3, 8, 16, dtype=torch.half, device="cuda"), + ] + + exported_program = torch.export.export(TestModule(), tuple(inputs)) + fx_graph = exported_program.module() + unexpected_ops_seen, _ = lower_graph_testing( + fx_graph, inputs, unexpected_ops=unexpected_ops, min_block_size=1 + ) + + self.assertEqual( + len(unexpected_ops_seen), + 0, + f"The following unexpected ops were encountered: {unexpected_ops_seen}", + ) + + torch._dynamo.reset() + + # Validate that the results between Torch and Torch-TRT are similar + trt_model = torch_tensorrt.dynamo.compile( + exported_program, inputs, enabled_precisions={torch.half}, min_block_size=1 + ) + torch.testing.assert_close( + trt_model(*inputs), + fx_graph(*inputs), + rtol=RTOL, + atol=ATOL, + msg="Scaled_dot_product_flash_attention TRT outputs don't match with the original model.", + ) + + @parameterized.expand( + [ + (True, False, None), + (False, True, 0.123), + ] + ) + def test_lowering_scaled_dot_product_efficient_attention( + self, attn, is_causal, scale + ): + class TestModule(torch.nn.Module): + def forward(self, query, key, value, attn_bias=None): + return torch.ops.aten._scaled_dot_product_efficient_attention.default( + query, + key, + value, + attn_bias, + False, + 0.0, + is_causal, + scale=scale, + )[0] + + # Operations expected to be removed in the traced graph after decompositions + unexpected_ops = { + torch.ops.aten._scaled_dot_product_efficient_attention.default + } + + inputs = [ + torch.rand(1, 3, 8, 16, dtype=torch.half, device="cuda"), + torch.rand(1, 3, 8, 16, dtype=torch.half, device="cuda"), + torch.rand(1, 3, 8, 16, dtype=torch.half, device="cuda"), + ] + if attn: + inputs += [torch.rand(1, 3, 8, 8, dtype=torch.half, device="cuda")] + + exported_program = torch.export.export(TestModule(), tuple(inputs)) + fx_graph = exported_program.module() + unexpected_ops_seen, _ = lower_graph_testing( + fx_graph, inputs, unexpected_ops=unexpected_ops, min_block_size=1 + ) + + self.assertEqual( + len(unexpected_ops_seen), + 0, + f"The following unexpected ops were encountered: {unexpected_ops_seen}", + ) + + torch._dynamo.reset() + + # Validate that the results between Torch and Torch-TRT are similar + trt_model = torch_tensorrt.dynamo.compile( + exported_program, inputs, enabled_precisions={torch.half}, min_block_size=1 + ) + 
torch.testing.assert_close( + trt_model(*inputs), + fx_graph(*inputs), + rtol=RTOL, + atol=ATOL, + msg="Scaled_dot_product_efficient_attention TRT outputs don't match with the original model.", + ) + + @parameterized.expand( + [ + (True, False, None), + (False, True, 0.123), + ] + ) + @unittest.skipUnless( + PLATFORM_SUPPORTS_CUDNN_ATTENTION, "Platform doesn't support cuDNN attention" + ) + def test_lowering_scaled_dot_product_cudnn_attention(self, attn, is_causal, scale): + class TestModule(torch.nn.Module): + def forward(self, query, key, value, attn_bias=None): + return torch.ops.aten._scaled_dot_product_cudnn_attention.default( + query, + key, + value, + attn_bias, + False, + 0.0, + is_causal, + False, + scale=scale, + )[0] + + # Operations expected to be removed in the traced graph after decompositions + unexpected_ops = {torch.ops.aten._scaled_dot_product_cudnn_attention.default} + + inputs = [ + torch.rand(1, 3, 8, 16, dtype=torch.half, device="cuda"), + torch.rand(1, 3, 8, 16, dtype=torch.half, device="cuda"), + torch.rand(1, 3, 8, 16, dtype=torch.half, device="cuda"), + ] + if attn: + inputs += [torch.rand(1, 3, 8, 8, dtype=torch.half, device="cuda")] + + exported_program = torch.export.export(TestModule(), tuple(inputs)) + fx_graph = exported_program.module() + unexpected_ops_seen, _ = lower_graph_testing( + fx_graph, inputs, unexpected_ops=unexpected_ops, min_block_size=1 + ) + + self.assertEqual( + len(unexpected_ops_seen), + 0, + f"The following unexpected ops were encountered: {unexpected_ops_seen}", + ) + + torch._dynamo.reset() + + # Validate that the results between Torch and Torch-TRT are similar + trt_model = torch_tensorrt.dynamo.compile( + exported_program, inputs, enabled_precisions={torch.half}, min_block_size=1 + ) + torch.testing.assert_close( + trt_model(*inputs), + fx_graph(*inputs), + rtol=RTOL, + atol=ATOL, + msg="Scaled_dot_product_cudnn_attention TRT outputs don't match with the original model.", + ) + if __name__ == "__main__": run_tests()
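
For reference, a minimal end-to-end sketch of how scaled_dot_product_attention lowers through the decompositions added in _decompositions.py above, instead of the removed lowering pass and dedicated converter. It mirrors the export/compile/assert pattern used by the new tests in this diff; the SDPA module, the tensor shapes, and the is_causal flag are illustrative choices, not part of the patch.

import torch
import torch_tensorrt
from torch_tensorrt.dynamo.utils import ATOL, RTOL


class SDPA(torch.nn.Module):
    def forward(self, q, k, v):
        # After the decompositions registered above, this call is broken down into
        # matmul / softmax / masked-bias ops that existing TRT converters handle.
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)


inputs = [torch.rand(1, 3, 8, 16, dtype=torch.half, device="cuda") for _ in range(3)]

exported_program = torch.export.export(SDPA(), tuple(inputs))
fx_graph = exported_program.module()

trt_model = torch_tensorrt.dynamo.compile(
    exported_program, inputs, enabled_precisions={torch.half}, min_block_size=1
)

# Compare the TRT-compiled module against the eager graph, as the tests above do.
torch.testing.assert_close(trt_model(*inputs), fx_graph(*inputs), rtol=RTOL, atol=ATOL)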