From 3b786be7dfc90f315a7997422452ae455052cf5d Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 26 Jul 2024 11:00:29 -0700 Subject: [PATCH] [6/x] clean up casting: rename delayed and dynamic casting functions Summary: Renames the delayed and dynamic casting functions to `hp_tensor_to_float8_delayed` and `hp_tensor_to_float8_dynamic` to clarify what they are doing and how they are different from `hp_tensor_and_scale_to_float8`. Test Plan: ``` ./test/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- float8_experimental/float8_linear.py | 14 +++-- float8_experimental/float8_scaling_utils.py | 55 ++++++++++++++----- float8_experimental/float8_tensor_parallel.py | 8 +-- float8_experimental/fsdp_utils.py | 8 +-- test/test_compile.py | 4 +- 5 files changed, 60 insertions(+), 29 deletions(-) diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py index 7a7adf2..82a74f9 100644 --- a/float8_experimental/float8_linear.py +++ b/float8_experimental/float8_linear.py @@ -18,8 +18,8 @@ from float8_experimental.float8_scaling_utils import ( _maybe_initialize_amaxes_scales_for_float8_cast, - cast_to_float8_delayed, - cast_to_float8_dynamic, + hp_tensor_to_float8_delayed, + hp_tensor_to_float8_dynamic, NoopFwToFloat8E5M2BwDelayed, NoopFwToFloat8E5M2BwDynamic, ) @@ -260,7 +260,7 @@ def cast_input_to_float8( is_amax_initialized, reduce_amax=True, ) - input_fp8 = cast_to_float8_delayed( + input_fp8 = hp_tensor_to_float8_delayed( input, self.fp8_scale_input, e4m3_dtype, @@ -270,7 +270,9 @@ def cast_input_to_float8( ) else: assert self.scaling_type_input is ScalingType.DYNAMIC - input_fp8 = cast_to_float8_dynamic(input, e4m3_dtype, self.linear_mm_config) + input_fp8 = hp_tensor_to_float8_dynamic( + input, e4m3_dtype, self.linear_mm_config + ) return input_fp8 def cast_weight_to_float8( @@ -292,7 +294,7 @@ def cast_weight_to_float8( reduce_amax=False, ) - weight_fp8 = cast_to_float8_delayed( + weight_fp8 = hp_tensor_to_float8_delayed( weight, self.fp8_scale_weight, e4m3_dtype, @@ -305,7 +307,7 @@ def cast_weight_to_float8( if isinstance(self.weight, Float8Tensor): # cast by FSDP weight_fp8 = self.weight else: - weight_fp8 = cast_to_float8_dynamic( + weight_fp8 = hp_tensor_to_float8_dynamic( self.weight, e4m3_dtype, self.linear_mm_config, diff --git a/float8_experimental/float8_scaling_utils.py b/float8_experimental/float8_scaling_utils.py index 7c387bf..ce6422f 100644 --- a/float8_experimental/float8_scaling_utils.py +++ b/float8_experimental/float8_scaling_utils.py @@ -30,18 +30,31 @@ ) -def cast_to_float8_dynamic( - inpt_tensor: torch.Tensor, +def hp_tensor_to_float8_dynamic( + hp_tensor: torch.Tensor, float8_dtype: torch.dtype, linear_mm_config: LinearMMConfig, reduce_amax: bool = False, gemm_input_role: GemmInputRole = GemmInputRole.INPUT, ) -> Float8Tensor: - if tensor_already_casted_to_fp8(inpt_tensor): - return inpt_tensor - scale = tensor_to_scale(inpt_tensor, float8_dtype, reduce_amax) + """ + Given a high precision tensor `hp_tensor`, + scales `hp_tensor` dynamically and returns a `Float8Tensor` of the result. 
+ + Args: + hp_tensor: the tensor to convert + float8_dtype: the float8 dtype to use + linear_mm_config: Defines the configuration for the scaled_mm for + the 3 fwd/bwd gemms of linear + reduce_amax: whether to reduce the max(abs(hp_tensor)) value across distributed ranks + gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in + the 3 fwd/bwd gemms of linear + """ + if tensor_already_casted_to_fp8(hp_tensor): + return hp_tensor + scale = tensor_to_scale(hp_tensor, float8_dtype, reduce_amax) return hp_tensor_and_scale_to_float8( - inpt_tensor, + hp_tensor, scale, float8_dtype, linear_mm_config, @@ -49,18 +62,34 @@ def cast_to_float8_dynamic( ) -def cast_to_float8_delayed( - tensor: torch.Tensor, - scale: torch.Tensor, +def hp_tensor_to_float8_delayed( + hp_tensor: torch.Tensor, + s: torch.Tensor, float8_dtype: torch.dtype, amax_buffer: torch.Tensor, linear_mm_config: Optional[LinearMMConfig] = None, gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, -): - amax_buffer.fill_(tensor_to_amax(tensor)) +) -> Float8Tensor: + """ + Given a high precision tensor `hp_tensor` and relevant metadata, scales it using + delayed scaling and returns a `Float8Tensor` of the result. Specifically: + 1. calculates max(abs(hp_tensor)) and stores the result in `amax_buffer`, inplace + 2. scales `hp_tensor` by `s` and returns the result wrapped in Float8Tensor + + Args: + hp_tensor: the tensor to convert + s: the scale to use to convert the tensor + float8_dtype: the float8 dtype to use + amax_buffer: the buffer to modify inplace with max(abs(hp_tensor)) + linear_mm_config: Defines the configuration for the scaled_mm for + the 3 fwd/bwd gemms of linear + gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in + the 3 fwd/bwd gemms of linear + """ + amax_buffer.fill_(tensor_to_amax(hp_tensor)) return hp_tensor_and_scale_to_float8( - tensor, - scale, + hp_tensor, + s, float8_dtype, linear_mm_config, gemm_input_role, diff --git a/float8_experimental/float8_tensor_parallel.py b/float8_experimental/float8_tensor_parallel.py index 4a77a45..2a91293 100644 --- a/float8_experimental/float8_tensor_parallel.py +++ b/float8_experimental/float8_tensor_parallel.py @@ -2,7 +2,7 @@ import torch.nn as nn from float8_experimental.config import ScalingType from float8_experimental.float8_scaling_utils import ( - cast_to_float8_dynamic, + hp_tensor_to_float8_dynamic, NoopFwToFloat8E5M2BwDynamic, ) from float8_experimental.float8_tensor import GemmInputRole @@ -46,7 +46,7 @@ def _prepare_input_fn( input_tensor, device_mesh, input_layouts, run_check=False ) - input_tensor = cast_to_float8_dynamic( + input_tensor = hp_tensor_to_float8_dynamic( input_tensor, e4m3_dtype, mod.linear_mm_config, @@ -100,7 +100,7 @@ def _prepare_input_fn( input_tensor, device_mesh, input_layouts, run_check=False ) - input_tensor = cast_to_float8_dynamic( + input_tensor = hp_tensor_to_float8_dynamic( input_tensor, e4m3_dtype, mod.linear_mm_config, @@ -199,7 +199,7 @@ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout): input, mesh, (input_layout,), run_check=False ) - dt_inp = cast_to_float8_dynamic( + dt_inp = hp_tensor_to_float8_dynamic( dt_inp, e4m3_dtype, self.linear_mm_config, diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py index c5424ac..bd0a9cc 100644 --- a/float8_experimental/fsdp_utils.py +++ b/float8_experimental/fsdp_utils.py @@ -11,8 +11,8 @@ import torch.nn as nn import torch.utils._pytree as pytree from 
float8_experimental.float8_scaling_utils import ( - cast_to_float8_delayed, - cast_to_float8_dynamic, + hp_tensor_to_float8_delayed, + hp_tensor_to_float8_dynamic, ) from float8_experimental.float8_tensor import ( @@ -175,7 +175,7 @@ def fsdp_pre_all_gather(self, mesh): GemmInputRole.WEIGHT, ) else: - float8_tensor = cast_to_float8_dynamic( + float8_tensor = hp_tensor_to_float8_dynamic( self._tensor, e4m3_dtype, self._linear_mm_config, @@ -355,7 +355,7 @@ def fsdp_pre_all_gather(self, mesh): ) self.is_amax_initialized = True - float8_tensor = cast_to_float8_delayed( + float8_tensor = hp_tensor_to_float8_delayed( self._tensor, self._scale_buffer, e4m3_dtype, diff --git a/test/test_compile.py b/test/test_compile.py index db73471..e4b58b9 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -20,7 +20,7 @@ get_float8_layers, sync_float8_amax_and_scale_history, ) -from float8_experimental.float8_scaling_utils import cast_to_float8_delayed +from float8_experimental.float8_scaling_utils import hp_tensor_to_float8_delayed from float8_experimental.float8_tensor import LinearMMConfig from float8_experimental.float8_utils import e4m3_dtype @@ -179,7 +179,7 @@ def __init__(self, graph_break: bool): self.graph_break = graph_break def forward(self, x): - x_fp8 = cast_to_float8_delayed( + x_fp8 = hp_tensor_to_float8_delayed( x, self.fp8_scale_x, e4m3_dtype,
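
---

For reviewers, a minimal usage sketch of the two renamed entry points (not part of the patch). It assumes the signatures shown in the hunks above, that `LinearMMConfig()` is constructible with its defaults, and that the running PyTorch build supports the float8 dtypes (a CUDA device may be required in practice):

```
# hypothetical example, not included in this diff
import torch

from float8_experimental.float8_scaling_utils import (
    hp_tensor_to_float8_delayed,
    hp_tensor_to_float8_dynamic,
)
from float8_experimental.float8_tensor import LinearMMConfig
from float8_experimental.float8_utils import e4m3_dtype

hp_tensor = torch.randn(16, 32, dtype=torch.bfloat16)
# assumption: LinearMMConfig's default construction is acceptable here
linear_mm_config = LinearMMConfig()

# dynamic scaling: the scale is computed from max(abs(hp_tensor)) inside the call
fp8_dynamic = hp_tensor_to_float8_dynamic(hp_tensor, e4m3_dtype, linear_mm_config)

# delayed scaling: the caller supplies a precomputed scale `s` and an amax buffer,
# which the function fills in-place with max(abs(hp_tensor))
s = torch.tensor(1.0)
amax_buffer = torch.tensor(0.0)
fp8_delayed = hp_tensor_to_float8_delayed(
    hp_tensor, s, e4m3_dtype, amax_buffer, linear_mm_config
)
```

The split mirrors the new docstrings: the dynamic variant derives the scale from the tensor itself, while the delayed variant consumes a caller-provided scale and records the current amax into `amax_buffer` for a later scale update, with both ultimately delegating to `hp_tensor_and_scale_to_float8`.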