diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py
index 61b8af59de..fd2832c1d4 100644
--- a/tests/pytorch/test_fusible_ops.py
+++ b/tests/pytorch/test_fusible_ops.py
@@ -293,7 +293,7 @@ def test_fp8_scale_update(
             )

             # Check that scaling factors match expected
-            w_amax_ref = max(w_vals[: step + 2])
+            w_amax_ref = max(w_vals[: step + 1])
             x_amax_ref = max(x_vals[: step + 1])
             dy_amax_ref = max(dy_vals[: step + 1])
             w_scale_ref = (fp8_format.value.max_fwd / w_amax_ref) / (2**margin)
diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 6b153fd3c1..28c1b45ffa 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -2528,12 +2528,13 @@ def backward(ctx, dout):
         recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
         batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

-        (q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = ctx.saved_tensors[:6]
-        (fp8_fwd_scales, fp8_fwd_scale_invs) = ctx.saved_tensors[6:8]
-        cu_seqlens_q_per_step = ctx.saved_tensors[8 : 8 + cp_size]
-        cu_seqlens_kv_per_step = ctx.saved_tensors[8 + cp_size : 8 + cp_size * 2]
-        rng_states = ctx.saved_tensors[8 + cp_size * 2 : 8 + cp_size * 3]
-        attn_biases = ctx.saved_tensors[8 + cp_size * 3 : 8 + cp_size * 4]
+        (*saved_tensors,) = ctx.saved_tensors
+        (q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = saved_tensors[:6]
+        (fp8_fwd_scales, fp8_fwd_scale_invs) = saved_tensors[6:8]
+        cu_seqlens_q_per_step = saved_tensors[8 : 8 + cp_size]
+        cu_seqlens_kv_per_step = saved_tensors[8 + cp_size : 8 + cp_size * 2]
+        rng_states = saved_tensors[8 + cp_size * 2 : 8 + cp_size * 3]
+        attn_biases = saved_tensors[8 + cp_size * 3 : 8 + cp_size * 4]

         causal = "causal" in ctx.attn_mask_type
         padding = "padding" in ctx.attn_mask_type
@@ -3577,11 +3578,12 @@ def backward(ctx, dout):
         cp_size = get_distributed_world_size(ctx.cp_group)
         rank = get_distributed_rank(ctx.cp_group)

-        (q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = ctx.saved_tensors[:5]
-        cu_seqlens_kv_per_step = ctx.saved_tensors[5:7]
-        out_per_step = ctx.saved_tensors[7:9]
-        softmax_lse_per_step = ctx.saved_tensors[9:11]
-        rng_states = ctx.saved_tensors[11:13]
+        (*saved_tensors,) = ctx.saved_tensors
+        (q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = saved_tensors[:5]
+        cu_seqlens_kv_per_step = saved_tensors[5:7]
+        out_per_step = saved_tensors[7:9]
+        softmax_lse_per_step = saved_tensors[9:11]
+        rng_states = saved_tensors[11:13]

         kv_seq_range_per_step = ctx.kv_seq_range_per_step
         window_size_per_step = ctx.window_size_per_step
@@ -4056,12 +4058,11 @@ def backward(ctx, dout):
         # pylint: disable=missing-function-docstring
         cp_size = get_distributed_world_size(ctx.cp_group)

-        q, k, v, out = ctx.saved_tensors[:4]
-        cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded = ctx.saved_tensors[
-            4:8
-        ]
-        fp8_fwd_scales, fp8_fwd_scale_invs = ctx.saved_tensors[8:10]
-        aux_ctx_tensors = ctx.saved_tensors[10:]
+        (*saved_tensors,) = ctx.saved_tensors
+        q, k, v, out = saved_tensors[:4]
+        cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded = saved_tensors[4:8]
+        fp8_fwd_scales, fp8_fwd_scale_invs = saved_tensors[8:10]
+        aux_ctx_tensors = saved_tensors[10:]

         qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
         causal = "causal" in ctx.attn_mask_type
diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py
index f95ba515cb..2a909dabc6 100644
--- a/transformer_engine/pytorch/fp8.py
+++ b/transformer_engine/pytorch/fp8.py
@@ -109,8 +109,6 @@ def reset(cls) -> None:
         cls.fp8_available = None
         cls.reason_for_no_fp8 = ""
         cls.autocast_arguments = {}
-        cls.autocast_to_fp8_params = {}
-        cls.fp8_param_to_autocast = {}
         cls.skip_fp8_weight_update_tensor = None

     @classmethod
@@ -156,28 +154,25 @@ def get_buffer_info(cls) -> str:
     def get_key_in_buffer(
         cls,
         forward: bool,
-        fp8_weights: bool,
         fp8_recipe: DelayedScaling,
         fp8_group: dist_group_type,
     ) -> str:
         """Returns a key into the global FP8 buffers."""
         autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group)
         fwd_bwd_key = cls.get_fwd_bwd_key(forward)
-        return f"{fwd_bwd_key}_{fp8_weights}_{autocast_key}"
+        return f"{fwd_bwd_key}_{autocast_key}"

     @classmethod
-    def split_key_in_buffer(cls, key: str) -> Tuple[bool, bool, str]:
+    def split_key_in_buffer(cls, key: str) -> Tuple[bool, str]:
         """Splits buffer key into relevant parts."""
-        forward, fp8_weights, autocast_key = key.split("_", 2)
+        forward, autocast_key = key.split("_", 1)
         forward = forward == "forward"
-        fp8_weights = fp8_weights == "True"
-        return forward, fp8_weights, autocast_key
+        return forward, autocast_key

     @classmethod
     def add_fp8_tensors_to_global_buffer(
         cls,
         fp8_meta: Dict[str, Any],
-        fp8_weights: Optional[List[torch.Tensor]] = None,
     ) -> None:
         """
         The amax reduction process happens completely outside the FP8 modules.
@@ -202,33 +197,12 @@ def add_fp8_tensors_to_global_buffer(
             fp8_meta[index_in_buffer] = []

         for forward in (True, False):
-            # This algorithm creates a two-way map with `autocast_to_fp8_params` and
-            # `fp8_param_to_autocast`. This is used for keeping track of FP8 weights
-            # in an autocasted region and cross reference them in `float8_tensor.py`
-            # to perform the forward amax reduction.
             fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward)
             if fp8_meta_tensor_key not in fp8_meta:
                 # Handles non-parameter FP8 modules, e.g. DPA.
                 continue

-            if forward and fp8_weights is not None:
-                autocast_key = cls.get_unique_autocast_key(
-                    fp8_meta["recipe"], fp8_meta["fp8_group"]
-                )
-                fp8_weight_set = {id(w._data) for w in fp8_weights}
-                if autocast_key not in cls.autocast_to_fp8_params:
-                    cls.autocast_to_fp8_params[autocast_key] = fp8_weight_set
-                else:
-                    cls.autocast_to_fp8_params[autocast_key] = cls.autocast_to_fp8_params[
-                        autocast_key
-                    ].union(fp8_weight_set)
-                # Identify correct autocast key for a given param.
-                for w in fp8_weight_set:
-                    cls.fp8_param_to_autocast[w] = autocast_key
-
-            key = cls.get_key_in_buffer(
-                forward, fp8_weights is not None, fp8_meta["recipe"], fp8_meta["fp8_group"]
-            )
+            key = cls.get_key_in_buffer(forward, fp8_meta["recipe"], fp8_meta["fp8_group"])

             if key not in cls.global_amax_buffer:
                 cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]]
@@ -327,20 +301,13 @@ def reduce_tensor_across_group_op_max(tensor: torch.Tensor, group: dist_group_ty
     def reduce_and_update_fp8_tensors(
         cls,
         forward: bool = True,
-        fp8_weights: bool = False,
     ) -> None:
         """Concatenate, reduce, and split amaxes in the global buffer."""
         for buffer_key, amax_buffer in cls.global_amax_buffer.items():
             # Check for forward or backward reduction.
-            fwd_update, fp8_weights_update, autocast_key = cls.split_key_in_buffer(buffer_key)
+            fwd_update, autocast_key = cls.split_key_in_buffer(buffer_key)
             if fwd_update != forward:
                 continue
-            # Only skip a forward update when `fp8_weights` is explicitly set to `True`
-            # (inside optimizer) and the current key is not an `fp8_weight_update` key.
-            # For other cases, we need to reduce because of activation tensors.
-            # TODO(ksivaman) consider separate weight and activation fp8_tensors.
-            if fwd_update and fp8_weights and not fp8_weights_update:
-                continue
             if len(amax_buffer) == 0:
                 continue

@@ -434,7 +401,7 @@ def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None:
         # FP8 weight modules are reduced at the end of the optimizer
         # step after the weight amax is populated.
         if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled():
-            cls.reduce_and_update_fp8_tensors(forward=True, fp8_weights=False)
+            cls.reduce_and_update_fp8_tensors(forward=True)

     @classmethod
     def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None:
diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py
index ed0ed1c008..c47b792a95 100644
--- a/transformer_engine/pytorch/graph.py
+++ b/transformer_engine/pytorch/graph.py
@@ -465,7 +465,7 @@ def new_fwd(*user_args, **user_kwargs):
                         m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
                         m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
                         FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
-                            m.fp8_meta, fp8_weights=m._get_fp8_params()
+                            m.fp8_meta,
                         )
                 return graphed(*user_args, **user_kwargs)
             return orig_fwd(*user_args, **user_kwargs)
diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index 534174380f..3a15242c3a 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -762,9 +762,7 @@ def prepare_forward(
             )

         if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing():
-            FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
-                self.fp8_meta, fp8_weights=self._get_fp8_params()
-            )
+            FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta)

         # Activation recomputation is used and this is the first forward phase.
         if self.fp8 and self.training and is_fp8_activation_recompute_enabled():
diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py
index d03a83d2ca..04a66b7942 100644
--- a/transformer_engine/pytorch/ops/op.py
+++ b/transformer_engine/pytorch/ops/op.py
@@ -19,7 +19,7 @@
     FP8GlobalStateManager,
     get_default_fp8_recipe,
 )
-from ._common import canonicalize_device, is_float8_tensor
+from ._common import canonicalize_device


 @dataclasses.dataclass
@@ -379,10 +379,8 @@ def pre_forward(
                 self.get_fp8_meta("input"),
             )
         if self.num_fp8_scales("param"):
-            fp8_params = list(filter(is_float8_tensor, self.parameters()))
             FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
                 self.get_fp8_meta("param"),
-                fp8_weights=(fp8_params if fp8_params else None),
             )
         if self.num_fp8_scales("grad_output"):
             FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py
index 36136292df..7ace68a222 100644
--- a/transformer_engine/pytorch/tensor/float8_tensor.py
+++ b/transformer_engine/pytorch/tensor/float8_tensor.py
@@ -74,30 +74,6 @@ def backward(
         return grad, None


-def post_optimizer_step_fwd_amax_reduction(param: Float8Tensor) -> None:
-    """Amax scale and update when there is at least 1 trainable FP8 parameter."""
-    param_id = id(param._data)
-
-    if param_id not in FP8GlobalStateManager.fp8_param_to_autocast:
-        return
-
-    autocast_key = FP8GlobalStateManager.fp8_param_to_autocast[param_id]
-
-    if autocast_key not in FP8GlobalStateManager.autocast_to_fp8_params:
-        return
-
-    if autocast_key in updated_fp8_params:
-        updated_fp8_params[autocast_key].add(param_id)
-    else:
-        updated_fp8_params[autocast_key] = {param_id}
-
-    current_fp8_params_set = FP8GlobalStateManager.autocast_to_fp8_params[autocast_key]
-    # All FP8 trainable parameters have been updated.
-    if updated_fp8_params[autocast_key] == current_fp8_params_set:
-        FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=True, fp8_weights=True)
-        del updated_fp8_params[autocast_key]
-
-
 class _ToFloat8Func(torch.autograd.Function):
     """Cast to FP8 from other dtype"""

@@ -676,9 +652,6 @@ def quantize_(
             )
             dst._transpose_invalid = False

-        # Callback hook to perform amax reduction after optimizer step
-        post_optimizer_step_fwd_amax_reduction(self)
-
         return self

     @classmethod
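
A note on the `(*saved_tensors,) = ctx.saved_tensors` pattern introduced in the attention backward passes: `ctx.saved_tensors` is a property, and each access goes through PyTorch's saved-tensor unpacking again (version checks plus any registered saved-tensor hooks), so reading it once into a local list and slicing that list avoids repeating the work. A minimal runnable sketch of the same pattern, using a toy autograd function invented for illustration (not TransformerEngine code):

    import torch

    class _Scale(torch.autograd.Function):
        """y = x * w, saving both inputs for backward."""

        @staticmethod
        def forward(ctx, x, w):
            ctx.save_for_backward(x, w)
            return x * w

        @staticmethod
        def backward(ctx, grad_out):
            # Unpack the saved tensors a single time; later code works on the
            # local list rather than touching ctx.saved_tensors again.
            (*saved_tensors,) = ctx.saved_tensors
            x, w = saved_tensors[:2]
            return grad_out * w, grad_out * x

    x = torch.randn(4, requires_grad=True)
    w = torch.randn(4, requires_grad=True)
    _Scale.apply(x, w).sum().backward()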
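
On the fp8.py key-format change: with the `fp8_weights` component removed, buffer keys take the two-part form `f"{fwd_bwd_key}_{autocast_key}"`, so `split_key_in_buffer` switches from `maxsplit=2` to `maxsplit=1`. Only the leading forward/backward token is peeled off, and the remainder stays intact even if the autocast key itself contains underscores. A small sketch with an invented autocast key:

    # Hypothetical key; only its two-part shape matters here.
    key = "forward_recipe-0_group-0"
    fwd_bwd_key, autocast_key = key.split("_", 1)
    assert fwd_bwd_key == "forward"
    assert autocast_key == "recipe-0_group-0"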