From 66748b9766c5b9019d1f92816ad226f1be16f462 Mon Sep 17 00:00:00 2001
From: Yifei Song
Date: Tue, 19 Nov 2024 22:50:11 -0800
Subject: [PATCH] Format Update

Signed-off-by: Yifei Song
---
 transformer_engine/pytorch/graph.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py
index e977e0a241..77eb0df920 100644
--- a/transformer_engine/pytorch/graph.py
+++ b/transformer_engine/pytorch/graph.py
@@ -263,6 +263,8 @@ def hook_fn(module, inputs, outputs):  # pylint: disable=unused-argument
                     allow_unused=allow_unused_input,
                 )
                 del outputs, grad_inputs
+            # The following code is added specifically for MCore's special requirements,
+            # aimed at preventing warmup from altering the control flow.
             for module in func.modules():
                 if hasattr(module, "is_first_microbatch"):
                     module.is_first_microbatch = True
@@ -524,12 +526,14 @@ def new_fwd(*user_args, **user_kwargs):
                         # Only Set the FP8 meta for the modules included by forward
                         continue
                     fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
+                    from transformer_engine.pytorch.attention import DotProductAttention
+
                     if (
-                        not fp8_recipe.fp8_mha
+                        isinstance(m, DotProductAttention)
+                        and not fp8_recipe.fp8_mha
                         and not fp8_recipe.fp8_dpa
-                        and hasattr(m, "attention_dropout")
-                        and m.deterministic
                     ):
+                        # Don't need to update FP8 meta for non-FP8 DPA
                         continue
                     m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
                     m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
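
Note (reviewer sketch, not part of the patch): below is a minimal, standalone illustration of the skip logic introduced by the second hunk. FakeRecipe, FakeDotProductAttention, and FakeLinear are hypothetical stand-ins for the transformer_engine objects (an FP8 recipe, DotProductAttention, and any other TE module) so the snippet runs without the library; the attribute names fp8_mha, fp8_dpa, and fp8_meta are taken from the diff above.

# Hypothetical stand-ins mirroring the patched loop in new_fwd;
# these are NOT the real transformer_engine classes.
from dataclasses import dataclass, field


@dataclass
class FakeRecipe:                  # stand-in for an FP8 recipe
    fp8_mha: bool = False
    fp8_dpa: bool = False


@dataclass
class FakeDotProductAttention:     # stand-in for DotProductAttention
    fp8_meta: dict = field(default_factory=dict)


@dataclass
class FakeLinear:                  # stand-in for any other TE module
    fp8_meta: dict = field(default_factory=dict)


def refresh_fp8_meta(modules, fp8_recipe, fp8_group):
    """Skip non-FP8 DPA modules, refresh fp8_meta on everything else."""
    for m in modules:
        if (
            isinstance(m, FakeDotProductAttention)
            and not fp8_recipe.fp8_mha
            and not fp8_recipe.fp8_dpa
        ):
            # Mirrors the added `continue`: non-FP8 DPA has no FP8 meta to update.
            continue
        m.fp8_meta["fp8_group"] = fp8_group
        m.fp8_meta["recipe"] = fp8_recipe


if __name__ == "__main__":
    dpa, linear = FakeDotProductAttention(), FakeLinear()
    refresh_fp8_meta([dpa, linear], FakeRecipe(), fp8_group="tp-group-handle")
    assert dpa.fp8_meta == {}            # skipped: attention is not running in FP8
    assert "recipe" in linear.fp8_meta   # refreshed

The previous condition identified attention modules indirectly through hasattr(m, "attention_dropout") and m.deterministic; the explicit isinstance(m, DotProductAttention) check states the intent of the skip directly: only DotProductAttention modules running without FP8 attention are left out of the fp8_meta refresh.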