Skip to content

Commit

Permalink
Format Update
Browse files Browse the repository at this point in the history
Signed-off-by: Yifei Song <[email protected]>
  • Loading branch information
yifeis-nv committed Nov 20, 2024
1 parent 8017b6d commit 41d6100
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions transformer_engine/pytorch/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,8 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument
allow_unused=allow_unused_input,
)
del outputs, grad_inputs
# Added specifically for MCore's requirements: set `is_first_microbatch` back to True
# on every module that has it, so that warmup does not alter the control flow.
for module in func.modules():
if hasattr(module, "is_first_microbatch"):
module.is_first_microbatch = True
Expand Down Expand Up @@ -527,12 +529,14 @@ def new_fwd(*user_args, **user_kwargs):
# Only set the FP8 meta for the modules included by forward
continue
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
from transformer_engine.pytorch.attention import DotProductAttention

if (
not fp8_recipe.fp8_mha
isinstance(m, DotProductAttention)
and not fp8_recipe.fp8_mha
and not fp8_recipe.fp8_dpa
and hasattr(m, "attention_dropout")
and m.deterministic
):
# Don't need to update FP8 meta for non-FP8 DPA
continue
m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
Expand Down

0 comments on commit 41d6100

Please sign in to comment.