diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index f45364d248..77eb0df920 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -526,7 +526,8 @@ def new_fwd(*user_args, **user_kwargs): # Only Set the FP8 meta for the modules included by forward continue fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() - from transformer_engine.pytorch.attention import DotProductAttention + from transformer_engine.pytorch.attention import DotProductAttention + if ( isinstance(m, DotProductAttention) and not fp8_recipe.fp8_mha