[Optimization][Operator] Implement and enable Conv2d-Reshape-Add-ReLU fusion #18173
Conversation
This PR introduces an operator fusion for the common `conv2d` followed by `reshape`, `add`, and `relu` sequence, commonly found in deep learning models (e.g., the convolution + bias + activation pattern in PyTorch). This optimization improves performance and efficiency by reducing overhead and optimizing memory usage.

Specific Benefits:

1. **Performance Improvement:**
   * **Reduced Kernel Launch Overhead:** Previously, `conv2d`, `reshape`, `add`, and `relu` each required a separate kernel call. By fusing these four operations into a single DNNL kernel (e.g., `dnnl_fused_conv2d_bias_relu`), the overhead from multiple kernel launches is significantly reduced. This is evident from `src/runtime/contrib/dnnl/dnnl.cc:154-158`, where all operations are handled by a single `execute` call.
   * **Decreased Memory Bandwidth Consumption:** Intermediate results of the individual operations (e.g., `conv_out`, `bias_add`) previously required costly memory write-backs and reads. Fusion allows these intermediate values to stay in registers or cache, reducing unnecessary memory accesses and thus decreasing memory bandwidth usage and overall execution time.
2. **Increased Efficiency:**
   * **Leveraging Compiler Optimizations:** By using TVM's `FuseOpsByPattern` and `MergeCompositeFunctions` passes, this change generates a composite operation optimized for specific backends (like DNNL). Common patterns from frontends like PyTorch are automatically recognized in the TVM graph and mapped to high-performance fused kernels provided by libraries such as DNNL.
   * **Simplified IR Module:** The Intermediate Representation (IR) becomes less complex as multiple operation nodes are condensed into a single composite node. This simplification helps subsequent optimization and code generation stages.

How Fusion Works:

This fusion is achieved through a two-stage transformation within the TVM Relax framework:

1. **Pattern Recognition and Composite Function Creation (`FuseConv2dReshapeAddRelu` Pass):**
   * The `FuseConv2dReshapeAddRelu` class, registered via `tvm.transform.module_pass`, transforms the `IRModule`.
   * The `_conv2d_reshape_add_relu_pattern()` helper defines the specific sequence `conv2d` -> `reshape` (applied to the bias) -> `add` -> `relu` using TVM's Declarative Pattern Language (DPL). It matches the input tensors (`data`, `weight`, `bias`, `shape`) with `wildcard()` and identifies the operation sequence with `is_op()`.
   * The `relax.transform.FuseOpsByPattern` pass finds this pattern in the input `IRModule`. Upon detection, the operation sequence is encapsulated into a new Relax function with the attributes `{"Composite": "dnnl.conv2d_reshape_add_relu", "Primitive": True}`, marking it as a logical "composite" unit.
2. **Composite Function Merging and Codegen Attribute Assignment (`MergeCompositeFunctions` Pass):**
   * Following the `FuseConv2dReshapeAddRelu` pass, the `MergeCompositeFunctions` pass is applied via `tvm.ir.transform.Sequential`.
   * This pass identifies functions marked with the `Composite` attribute and transforms them into external functions bearing the `{"Codegen": "dnnl"}` attribute, indicating that the composite operation should be offloaded to a specific TVM backend such as DNNL.
   * Consequently, during graph execution, the fused function with the `Codegen` attribute is mapped to and executed by a single optimized DNNL kernel, for instance `dnnl_fused_conv2d_bias_relu` (defined in `src/runtime/contrib/dnnl/dnnl.cc:199-207`).

Key Achievement:

This implementation enables the fusion of the `conv2d + reshape + add + relu` pattern, ensuring that common convolution + bias + activation patterns originating from frontends like PyTorch are optimized and executed as a single, highly efficient DNNL kernel within TVM.

How to Test:

To verify this fusion, you can directly run the specific test case:

python tests/python/relax/test_conv2d_reshape_add_relu.py
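As a rough sketch of the two-stage mechanism described above (reconstructed from this description, not copied from the PR; the actual helper body and pass wiring may differ):

import tvm
from tvm import relax
from tvm.relax.dpl import is_op, wildcard

def _conv2d_reshape_add_relu_pattern():
    # Match conv2d(data, weight), a reshape applied to the bias, then add and relu.
    data, weight, bias, shape = wildcard(), wildcard(), wildcard(), wildcard()
    conv = is_op("relax.nn.conv2d")(data, weight)
    reshaped_bias = is_op("relax.reshape")(bias, shape)
    return is_op("relax.nn.relu")(is_op("relax.add")(conv, reshaped_bias))

@tvm.transform.module_pass(opt_level=0, name="FuseConv2dReshapeAddRelu")
class FuseConv2dReshapeAddRelu:
    def transform_module(self, mod, ctx):
        # Stage one: wrap the matched sequence into a composite function for DNNL.
        patterns = [("dnnl.conv2d_reshape_add_relu", _conv2d_reshape_add_relu_pattern())]
        return relax.transform.FuseOpsByPattern(patterns)(mod)

# Stage two: MergeCompositeFunctions turns the composite functions into external
# functions carrying the {"Codegen": "dnnl"} attribute.
seq = tvm.ir.transform.Sequential(
    [FuseConv2dReshapeAddRelu(), relax.transform.MergeCompositeFunctions()]
)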
@tvm.transform.module_pass(opt_level=0, name="FuseConv2dReshapeAddRelu")
class FuseConv2dReshapeAddRelu:
I am wondering if transform.FuseOps will fuse them; I guess it might work.
@yongwww
Excellent point! However, after checking the actual implementation, I've confirmed that the generic FuseOps cannot handle this specific pattern.
Summary
The generic relax.transform.FuseOps pass is currently unable to fuse the common conv2d + bias + activation pattern when imported from PyTorch. The root cause is that the PyTorch frontend generates a conv2d -> reshape -> add sequence for the bias term, which the existing pattern matcher in FuseOps does not recognize. This leaves a critical, common pattern unoptimized.
The Pattern Generated by the PyTorch Frontend
When handling a torch.nn.Conv2d layer with bias=True, the PyTorch frontend consistently generates a reshape + add pattern for the bias. This is not specific to Conv2d and is standard behavior for other convolution types as well:
Conv1d: See test_frontend_from_exported_program.py:1752-1753
Conv2d: See test_frontend_from_fx.py:269-270
Conv3d: See test_frontend_from_exported_program.py:3822-3823
Limitation of TVM's Current Pattern Matching
The pattern designed to fuse bias and activation, make_fused_bias_activation_pattern, is defined in pattern.py:1179-1181. This function is currently implemented to match only a simple relax.add operation following the convolution. It cannot see past the reshape operation inserted by the frontend, thus failing to match the sequence.
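For illustration, an approximate DPL rendering of the two shapes under discussion (a sketch of the difference only, not the exact code from pattern.py):

from tvm.relax.dpl import is_op, wildcard

# Roughly the shape make_fused_bias_activation_pattern builds for conv2d + bias + relu:
# the add is expected to follow the convolution directly.
conv = is_op("relax.nn.conv2d")(wildcard(), wildcard())
bias_act = is_op("relax.nn.relu")(is_op("relax.add")(conv, wildcard()))

# The PyTorch frontend instead emits conv2d -> reshape(bias) -> add -> relu, i.e. an
# extra reshape node sits between the bias parameter and the add.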
Proof by Code: A Reproducible Example
The following test case demonstrates that FuseOps fails to fuse this pattern.
import torch
import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_fx

# 1. PyTorch Conv2d model with bias and ReLU
class Conv2dWithBias(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 6, 3, bias=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))

# 2. Trace and convert the model to TVM Relax IR
model = Conv2dWithBias()
graph_model = torch.fx.symbolic_trace(model)
input_info = [([1, 3, 10, 10], "float32")]
mod = from_fx(graph_model, input_info)

print("### Original Relax IR (Before FuseOps):")
print(mod)

# 3. Apply the generic FuseOps pass
fused_mod = relax.transform.FuseOps()(mod)
print("\n### Relax IR After Applying FuseOps:")
print(fused_mod)
Execution Results
Converted IR (Before FuseOps): A sequence of four separate operations is generated: conv2d → reshape → add → relu.
IR After FuseOps: The IR remains completely unchanged, confirming that the fusion failed.
This failure is a direct result of the pattern in pattern.py:1179-1181 matching only relax.add and not the reshape + add sequence.
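Continuing the script above, the "unchanged IR" observation can be checked mechanically (this assertion simply encodes the result described here):

# If FuseOps had fused anything, the module would no longer be structurally identical.
assert tvm.ir.structural_equal(mod, fused_mod)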
Conclusion and Proposal
The generic FuseOps pass cannot handle this frontend-specific pattern, leaving a common PyTorch model structure (conv2d + bias + relu) unoptimized.
Therefore, a specialized pass like FuseConv2dReshapeAddRelu is essential to correctly identify and fuse this pattern. This targeted pass is necessary to bridge the gap between the PyTorch frontend's IR generation and TVM's optimization capabilities, unlocking performance for a wide range of models.
Maybe we could extend FuseOps to handle this - that way, other cases could benefit as well.
Just a moment, I'll get to it.