Skip to content

Commit

Permalink
Add fp8_group argument and fix fp8 accuracy issue for cudagraph
Browse files Browse the repository at this point in the history
Signed-off-by: Robin Zhang <[email protected]>
Co-authored-by: Yifei Song <[email protected]>
  • Loading branch information
buptzyb and yifeis-nv committed Oct 9, 2024
1 parent 1c9e35b commit cef94e4
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 12 deletions.
12 changes: 6 additions & 6 deletions transformer_engine/pytorch/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,16 +473,16 @@ def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> Non
stashed_fp8_meta = cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].popleft()

# Replace amaxes and scales with stashed values for phase 2 forward
fp8_meta["scaling_fwd"].amax_history = stashed_fp8_meta[0]
fp8_meta["scaling_fwd"].scale = stashed_fp8_meta[1]
fp8_meta["scaling_fwd"].scale_inv = stashed_fp8_meta[2]
fp8_meta["scaling_fwd"].amax_history.copy_(stashed_fp8_meta[0])
fp8_meta["scaling_fwd"].scale.copy_(stashed_fp8_meta[1])
fp8_meta["scaling_fwd"].scale_inv.copy_(stashed_fp8_meta[2])

@staticmethod
def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:
"""Restore latest scaling factors and amaxes after recompute forward run."""
fp8_meta["scaling_fwd"].amax_history = fp8_meta["updated_amax_history_fwd"]
fp8_meta["scaling_fwd"].scale = fp8_meta["updated_scale_fwd"]
fp8_meta["scaling_fwd"].scale_inv = fp8_meta["updated_scale_inv_fwd"]
fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"])
fp8_meta["scaling_fwd"].scale.copy_(fp8_meta["updated_scale_fwd"])
fp8_meta["scaling_fwd"].scale_inv.copy_(fp8_meta["updated_scale_inv_fwd"])


@contextmanager
Expand Down
14 changes: 13 additions & 1 deletion transformer_engine/pytorch/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from torch._C import _graph_pool_handle

from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.pytorch.constants import dist_group_type
from .fp8 import (
fp8_autocast,
FP8GlobalStateManager,
Expand Down Expand Up @@ -247,6 +248,9 @@ def _make_graphed_callables(
allow_unused=allow_unused_input,
)
del outputs, grad_inputs
for module in func.modules():
if hasattr(module, "is_first_microbatch"):
module.is_first_microbatch = True
torch.cuda.synchronize()

# All captures here share a mempool. To avoid replays corrupting each other's memory,
Expand Down Expand Up @@ -549,6 +553,7 @@ def make_graphed_callables(
fp8_enabled: bool = False,
fp8_calibrating: bool = False,
fp8_recipe: Optional[DelayedScaling] = None,
fp8_group: Optional[dist_group_type] = None,
fp8_weight_caching: bool = False,
_order: Optional[List[int]] = None,
pool: Optional[Tuple[int, ...]] = None,
Expand Down Expand Up @@ -590,6 +595,9 @@ def make_graphed_callables(
using a higher precision.
fp8_recipe: recipe.DelayedScaling, default = `None`
recipe used for FP8 training.
fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None`
distributed group over which amaxes for the fp8 tensors
are reduced at the end of each training step.
fp8_weight_caching: bool, default = `False`
Whether or not to cache FP8 weights across microbatches. if set to `True`,
the `is_first_microbatch` boolean argument must be passed into the forward
Expand Down Expand Up @@ -618,7 +626,11 @@ def wrap_autocast(block):

def forward_func(*args, **kwargs):
with fp8_autocast(
enabled=fp8_enabled, calibrating=fp8_calibrating, fp8_recipe=fp8_recipe, _graph=True
enabled=fp8_enabled,
calibrating=fp8_calibrating,
fp8_recipe=fp8_recipe,
fp8_group=fp8_group,
_graph=True,
):
outputs = old_forward(*args, **kwargs)
return outputs
Expand Down
7 changes: 5 additions & 2 deletions transformer_engine/pytorch/module/layernorm_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -1144,7 +1144,10 @@ def forward(
produced)
"""

skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
if FP8GlobalStateManager.fp8_graph_capturing():
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
else:
skip_fp8_weight_update = None
if skip_fp8_weight_update is not None:
is_first_microbatch = False

Expand Down Expand Up @@ -1206,7 +1209,7 @@ def forward(
self.apply_bias and not self.gemm_bias_unfused_add,
self.eps,
is_first_microbatch,
self.fp8,
False if self.fp8 is None else self.fp8,
self.fp8_calibration,
self.fp8_meta,
self.fuse_wgrad_accumulation,
Expand Down
5 changes: 4 additions & 1 deletion transformer_engine/pytorch/module/layernorm_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1466,7 +1466,10 @@ def forward(
produced)
"""

skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
if FP8GlobalStateManager.fp8_graph_capturing():
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
else:
skip_fp8_weight_update = None
if skip_fp8_weight_update is not None:
is_first_microbatch = False

Expand Down
6 changes: 4 additions & 2 deletions transformer_engine/pytorch/module/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,8 +931,10 @@ def forward(
first microbatch (since it is the first gradient being
produced)
"""

skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
if FP8GlobalStateManager.fp8_graph_capturing():
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
else:
skip_fp8_weight_update = None
if skip_fp8_weight_update is not None:
is_first_microbatch = False

Expand Down

0 comments on commit cef94e4

Please sign in to comment.