Skip to content

Commit

Permalink
Format Update
Browse files Browse the repository at this point in the history
Signed-off-by: Yifei Song <[email protected]>
  • Loading branch information
yifeis-nv committed Nov 20, 2024
1 parent f92d901 commit 66748b9
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions transformer_engine/pytorch/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,8 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument
allow_unused=allow_unused_input,
)
del outputs, grad_inputs
# The following code is added specifically for MCore's requirements,
# to prevent warmup from altering the control flow.
for module in func.modules():
if hasattr(module, "is_first_microbatch"):
module.is_first_microbatch = True
Expand Down Expand Up @@ -524,12 +526,14 @@ def new_fwd(*user_args, **user_kwargs):
# Only set the FP8 meta for the modules included in the forward pass
continue
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
from transformer_engine.pytorch.attention import DotProductAttention

if (
not fp8_recipe.fp8_mha
isinstance(m, DotProductAttention)
and not fp8_recipe.fp8_mha
and not fp8_recipe.fp8_dpa
and hasattr(m, "attention_dropout")
and m.deterministic
):
# Don't need to update FP8 meta for non-FP8 DPA
continue
m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
Expand Down

0 comments on commit 66748b9

Please sign in to comment.