Commit
Merge branch 'main' into activation-ops
timmoon10 authored Nov 15, 2024
2 parents a6a55e4 + d1488e7 commit 52b1c82
Showing 7 changed files with 29 additions and 92 deletions.
2 changes: 1 addition & 1 deletion tests/pytorch/test_fusible_ops.py
@@ -293,7 +293,7 @@ def test_fp8_scale_update(
)

# Check that scaling factors match expected
w_amax_ref = max(w_vals[: step + 2])
w_amax_ref = max(w_vals[: step + 1])
x_amax_ref = max(x_vals[: step + 1])
dy_amax_ref = max(dy_vals[: step + 1])
w_scale_ref = (fp8_format.value.max_fwd / w_amax_ref) / (2**margin)
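For context, a minimal sketch (not part of this diff) of the reference value the test checks under delayed scaling; fp8_max and margin stand in for fp8_format.value.max_fwd and the recipe margin, and amax_history for the recorded per-step amax values:

def reference_scale(amax_history, step, fp8_max=448.0, margin=0):
    # After `step` scale updates the history covers entries 0..step, hence the
    # `step + 1` slice that the corrected weight check now uses as well.
    amax_ref = max(amax_history[: step + 1])
    # Same form as the test's w_scale_ref / x_scale_ref expressions above.
    return (fp8_max / amax_ref) / (2**margin)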
35 changes: 18 additions & 17 deletions transformer_engine/pytorch/attention.py
@@ -2528,12 +2528,13 @@ def backward(ctx, dout):
recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

(q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = ctx.saved_tensors[:6]
(fp8_fwd_scales, fp8_fwd_scale_invs) = ctx.saved_tensors[6:8]
cu_seqlens_q_per_step = ctx.saved_tensors[8 : 8 + cp_size]
cu_seqlens_kv_per_step = ctx.saved_tensors[8 + cp_size : 8 + cp_size * 2]
rng_states = ctx.saved_tensors[8 + cp_size * 2 : 8 + cp_size * 3]
attn_biases = ctx.saved_tensors[8 + cp_size * 3 : 8 + cp_size * 4]
(*saved_tensors,) = ctx.saved_tensors
(q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = saved_tensors[:6]
(fp8_fwd_scales, fp8_fwd_scale_invs) = saved_tensors[6:8]
cu_seqlens_q_per_step = saved_tensors[8 : 8 + cp_size]
cu_seqlens_kv_per_step = saved_tensors[8 + cp_size : 8 + cp_size * 2]
rng_states = saved_tensors[8 + cp_size * 2 : 8 + cp_size * 3]
attn_biases = saved_tensors[8 + cp_size * 3 : 8 + cp_size * 4]

causal = "causal" in ctx.attn_mask_type
padding = "padding" in ctx.attn_mask_type
@@ -3577,11 +3578,12 @@ def backward(ctx, dout):
cp_size = get_distributed_world_size(ctx.cp_group)
rank = get_distributed_rank(ctx.cp_group)

(q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = ctx.saved_tensors[:5]
cu_seqlens_kv_per_step = ctx.saved_tensors[5:7]
out_per_step = ctx.saved_tensors[7:9]
softmax_lse_per_step = ctx.saved_tensors[9:11]
rng_states = ctx.saved_tensors[11:13]
(*saved_tensors,) = ctx.saved_tensors
(q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = saved_tensors[:5]
cu_seqlens_kv_per_step = saved_tensors[5:7]
out_per_step = saved_tensors[7:9]
softmax_lse_per_step = saved_tensors[9:11]
rng_states = saved_tensors[11:13]
kv_seq_range_per_step = ctx.kv_seq_range_per_step
window_size_per_step = ctx.window_size_per_step

@@ -4056,12 +4058,11 @@ def backward(ctx, dout):
# pylint: disable=missing-function-docstring
cp_size = get_distributed_world_size(ctx.cp_group)

q, k, v, out = ctx.saved_tensors[:4]
cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded = ctx.saved_tensors[
4:8
]
fp8_fwd_scales, fp8_fwd_scale_invs = ctx.saved_tensors[8:10]
aux_ctx_tensors = ctx.saved_tensors[10:]
(*saved_tensors,) = ctx.saved_tensors
q, k, v, out = saved_tensors[:4]
cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded = saved_tensors[4:8]
fp8_fwd_scales, fp8_fwd_scale_invs = saved_tensors[8:10]
aux_ctx_tensors = saved_tensors[10:]

qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
causal = "causal" in ctx.attn_mask_type
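For reference, a minimal sketch (outside Transformer Engine) of the pattern these attention.py hunks switch to: materialize ctx.saved_tensors into a local list once via starred assignment, then slice that list instead of indexing ctx.saved_tensors repeatedly. The operation and saved tensors below are placeholders.

import torch

class _UnpackOnce(torch.autograd.Function):
    """Illustrative only; the op and its saved tensors are placeholders."""

    @staticmethod
    def forward(ctx, x, w):
        out = x * w
        ctx.save_for_backward(x, w, out)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        # Read ctx.saved_tensors a single time, then slice the local list.
        (*saved_tensors,) = ctx.saved_tensors
        x, w = saved_tensors[:2]
        out = saved_tensors[2]  # unused here; kept to mirror the slicing style
        return grad_out * w, grad_out * x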
47 changes: 7 additions & 40 deletions transformer_engine/pytorch/fp8.py
@@ -109,8 +109,6 @@ def reset(cls) -> None:
cls.fp8_available = None
cls.reason_for_no_fp8 = ""
cls.autocast_arguments = {}
cls.autocast_to_fp8_params = {}
cls.fp8_param_to_autocast = {}
cls.skip_fp8_weight_update_tensor = None

@classmethod
@@ -156,28 +154,25 @@ def get_buffer_info(cls) -> str:
def get_key_in_buffer(
cls,
forward: bool,
fp8_weights: bool,
fp8_recipe: DelayedScaling,
fp8_group: dist_group_type,
) -> str:
"""Returns a key into the global FP8 buffers."""
autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group)
fwd_bwd_key = cls.get_fwd_bwd_key(forward)
return f"{fwd_bwd_key}_{fp8_weights}_{autocast_key}"
return f"{fwd_bwd_key}_{autocast_key}"

@classmethod
def split_key_in_buffer(cls, key: str) -> Tuple[bool, bool, str]:
def split_key_in_buffer(cls, key: str) -> Tuple[bool, str]:
"""Splits buffer key into relevant parts."""
forward, fp8_weights, autocast_key = key.split("_", 2)
forward, autocast_key = key.split("_", 1)
forward = forward == "forward"
fp8_weights = fp8_weights == "True"
return forward, fp8_weights, autocast_key
return forward, autocast_key

@classmethod
def add_fp8_tensors_to_global_buffer(
cls,
fp8_meta: Dict[str, Any],
fp8_weights: Optional[List[torch.Tensor]] = None,
) -> None:
"""
The amax reduction process happens completely outside the FP8 modules.
@@ -202,33 +197,12 @@ def add_fp8_tensors_to_global_buffer(

fp8_meta[index_in_buffer] = []
for forward in (True, False):
# This algorithm creates a two-way map with `autocast_to_fp8_params` and
# `fp8_param_to_autocast`. This is used for keeping track of FP8 weights
# in an autocasted region and cross reference them in `float8_tensor.py`
# to perform the forward amax reduction.
fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward)
if fp8_meta_tensor_key not in fp8_meta:
# Handles non-parameter FP8 modules, e.g. DPA.
continue

if forward and fp8_weights is not None:
autocast_key = cls.get_unique_autocast_key(
fp8_meta["recipe"], fp8_meta["fp8_group"]
)
fp8_weight_set = {id(w._data) for w in fp8_weights}
if autocast_key not in cls.autocast_to_fp8_params:
cls.autocast_to_fp8_params[autocast_key] = fp8_weight_set
else:
cls.autocast_to_fp8_params[autocast_key] = cls.autocast_to_fp8_params[
autocast_key
].union(fp8_weight_set)
# Identify correct autocast key for a given param.
for w in fp8_weight_set:
cls.fp8_param_to_autocast[w] = autocast_key

key = cls.get_key_in_buffer(
forward, fp8_weights is not None, fp8_meta["recipe"], fp8_meta["fp8_group"]
)
key = cls.get_key_in_buffer(forward, fp8_meta["recipe"], fp8_meta["fp8_group"])

if key not in cls.global_amax_buffer:
cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]]
@@ -327,20 +301,13 @@ def reduce_tensor_across_group_op_max(tensor: torch.Tensor, group: dist_group_ty
def reduce_and_update_fp8_tensors(
cls,
forward: bool = True,
fp8_weights: bool = False,
) -> None:
"""Concatenate, reduce, and split amaxes in the global buffer."""
for buffer_key, amax_buffer in cls.global_amax_buffer.items():
# Check for forward or backward reduction.
fwd_update, fp8_weights_update, autocast_key = cls.split_key_in_buffer(buffer_key)
fwd_update, autocast_key = cls.split_key_in_buffer(buffer_key)
if fwd_update != forward:
continue
# Only skip a forward update when `fp8_weights` is explicitly set to `True`
# (inside optimizer) and the current key is not an `fp8_weight_update` key.
# For other cases, we need to reduce because of activation tensors.
# TODO(ksivaman) consider separate weight and activation fp8_tensors.
if fwd_update and fp8_weights and not fp8_weights_update:
continue
if len(amax_buffer) == 0:
continue

@@ -434,7 +401,7 @@ def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None:
# FP8 weight modules are reduced at the end of the optimizer
# step after the weight amax is populated.
if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled():
cls.reduce_and_update_fp8_tensors(forward=True, fp8_weights=False)
cls.reduce_and_update_fp8_tensors(forward=True)

@classmethod
def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None:
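A hypothetical standalone sketch of the simplified buffer-key scheme after this change; make_key and split_key are illustrative stand-ins for get_key_in_buffer and split_key_in_buffer (which also fold the recipe and process group into the autocast key). With the fp8_weights component gone, a key is just the forward/backward tag plus the autocast key.

def make_key(forward: bool, autocast_key: str) -> str:
    fwd_bwd_key = "forward" if forward else "backward"
    return f"{fwd_bwd_key}_{autocast_key}"

def split_key(key: str):
    fwd_bwd_key, autocast_key = key.split("_", 1)
    return fwd_bwd_key == "forward", autocast_key

# Round trip; maxsplit=1 matters because the autocast key may contain underscores.
assert split_key(make_key(True, "recipe_0_group_0")) == (True, "recipe_0_group_0")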
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/graph.py
@@ -465,7 +465,7 @@ def new_fwd(*user_args, **user_kwargs):
m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
m.fp8_meta, fp8_weights=m._get_fp8_params()
m.fp8_meta,
)
return graphed(*user_args, **user_kwargs)
return orig_fwd(*user_args, **user_kwargs)
4 changes: 1 addition & 3 deletions transformer_engine/pytorch/module/base.py
@@ -762,9 +762,7 @@ def prepare_forward(
)

if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing():
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
self.fp8_meta, fp8_weights=self._get_fp8_params()
)
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta)

# Activation recomputation is used and this is the first forward phase.
if self.fp8 and self.training and is_fp8_activation_recompute_enabled():
4 changes: 1 addition & 3 deletions transformer_engine/pytorch/ops/op.py
@@ -19,7 +19,7 @@
FP8GlobalStateManager,
get_default_fp8_recipe,
)
from ._common import canonicalize_device, is_float8_tensor
from ._common import canonicalize_device


@dataclasses.dataclass
@@ -379,10 +379,8 @@ def pre_forward(
self.get_fp8_meta("input"),
)
if self.num_fp8_scales("param"):
fp8_params = list(filter(is_float8_tensor, self.parameters()))
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
self.get_fp8_meta("param"),
fp8_weights=(fp8_params if fp8_params else None),
)
if self.num_fp8_scales("grad_output"):
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
27 changes: 0 additions & 27 deletions transformer_engine/pytorch/tensor/float8_tensor.py
@@ -74,30 +74,6 @@ def backward(
return grad, None


def post_optimizer_step_fwd_amax_reduction(param: Float8Tensor) -> None:
"""Amax scale and update when there is at least 1 trainable FP8 parameter."""
param_id = id(param._data)

if param_id not in FP8GlobalStateManager.fp8_param_to_autocast:
return

autocast_key = FP8GlobalStateManager.fp8_param_to_autocast[param_id]

if autocast_key not in FP8GlobalStateManager.autocast_to_fp8_params:
return

if autocast_key in updated_fp8_params:
updated_fp8_params[autocast_key].add(param_id)
else:
updated_fp8_params[autocast_key] = {param_id}

current_fp8_params_set = FP8GlobalStateManager.autocast_to_fp8_params[autocast_key]
# All FP8 trainable parameters have been updated.
if updated_fp8_params[autocast_key] == current_fp8_params_set:
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=True, fp8_weights=True)
del updated_fp8_params[autocast_key]


class _ToFloat8Func(torch.autograd.Function):
"""Cast to FP8 from other dtype"""

@@ -676,9 +652,6 @@ def quantize_(
)
dst._transpose_invalid = False

# Callback hook to perform amax reduction after optimizer step
post_optimizer_step_fwd_amax_reduction(self)

return self

@classmethod
