diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py index 3145f21..d81250e 100644 --- a/float8_experimental/float8_utils.py +++ b/float8_experimental/float8_utils.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple +from typing import Literal, Tuple import torch import torch.distributed as dist @@ -14,7 +14,9 @@ # define the e4m3/e5m2 constants E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max +E4M3_FNUZ_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max +E5M2_FNUZ_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max FP16_MAX_POS = torch.finfo(torch.float16).max @@ -22,14 +24,30 @@ # TODO: align this value with NVIDIA's assumptions (current value is a guess) EPS = 1e-12 +IS_AMD = torch.cuda.is_available() and torch.version.hip is not None + @torch.no_grad() -def amax_to_scale(amax, float8_dtype, orig_dtype): +def amax_to_scale( + amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype +): + """Converts the amax value of a tensor to the fp8 scale. + Args: + amax: The amax value of the tensor. + float8_dtype: The float8 dtype. + orig_dtype: The original dtype of the tensor. + """ scale = torch.empty_like(amax, dtype=torch.float32) if float8_dtype == torch.float8_e4m3fn: res = E4M3_MAX_POS / torch.clamp(amax, min=EPS) - else: # e5m2 + elif float8_dtype == torch.float8_e4m3fnuz: + res = E4M3_FNUZ_MAX_POS / torch.clamp(amax, min=EPS) + elif float8_dtype == torch.float8_e5m2: res = E5M2_MAX_POS / torch.clamp(amax, min=EPS) + elif float8_dtype == torch.float8_e5m2fnuz: + res = E5M2_FNUZ_MAX_POS / torch.clamp(amax, min=EPS) + else: + raise ValueError(f"Unsupported float8_dtype: {float8_dtype}") # Ensure that the scale is representable in float16, # this helps when amax is small. 
We are assuming that we don't need @@ -42,11 +60,18 @@ def amax_to_scale(amax, float8_dtype, orig_dtype): @torch.no_grad() def amax_history_to_scale( - amax_history, - float8_dtype, - orig_dtype, - history_to_scale_fn_type, + amax_history: torch.Tensor, + float8_dtype: torch.dtype, + orig_dtype: torch.dtype, + history_to_scale_fn_type: Literal["max"], ): + """Takes in a history of amax values and returns a scale tensor. + Args: + amax_history: A tensor containing the history of amax values. + float8_dtype: The float8 dtype. + orig_dtype: The original dtype of the tensor. + history_to_scale_fn_type: The type of function to use to convert the history to a scale. + """ if history_to_scale_fn_type == "max": amax = torch.max(amax_history) return amax_to_scale(amax, float8_dtype, orig_dtype) @@ -58,9 +83,15 @@ def amax_history_to_scale_stack( amax_history: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype, - history_to_scale_fn_type: str, + history_to_scale_fn_type: Literal["max"], ) -> torch.Tensor: - """Takes in a stack of amax_history tensors and returns a scale tensor.""" + """Takes in a stack of amax_history tensors and returns a scale tensor. + Args: + amax_history: A 2D tensor containing a stack of amax histories. + float8_dtype: The float8 dtype. + orig_dtype: The original dtype of the tensor. + history_to_scale_fn_type: The type of function to use to convert the history to a scale. + """ if history_to_scale_fn_type == "max": amax_stack = torch.max(amax_history, dim=1).values return amax_to_scale(amax_stack, float8_dtype, orig_dtype) @@ -90,21 +121,41 @@ def tensor_to_scale( return amax_to_scale(amax, float8_dtype, x.dtype) -def to_fp8_saturated(x, float8_dtype: torch.dtype): - # The default behavior in PyTorch for casting to `float8_e4m3fn` - # and `e5m2` is to not saturate. In this context, we should saturate. 
- # A common case where we want to saturate is when the history of a - # tensor has a maximum value of `amax1`, and the current amax value - # is `amax2`, where `amax1 < amax2`. This is common when using delayed - # scaling. +def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype): + """Converts a tensor to a saturated fp8 tensor. + + Note: + The default behavior in PyTorch for casting to `float8_e4m3fn` + and `e5m2` is to not saturate. In this context, we should saturate. + A common case where we want to saturate is when the history of a + tensor has a maximum value of `amax1`, and the current amax value + is `amax2`, where `amax1 < amax2`. This is common when using delayed + scaling. + """ + if float8_dtype == torch.float8_e4m3fn: x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS) - else: + elif float8_dtype == torch.float8_e4m3fnuz: + x = x.clamp(min=-1 * E4M3_FNUZ_MAX_POS, max=E4M3_FNUZ_MAX_POS) + elif float8_dtype == torch.float8_e5m2: x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS) + elif float8_dtype == torch.float8_e5m2fnuz: + x = x.clamp(min=-1 * E5M2_FNUZ_MAX_POS, max=E5M2_FNUZ_MAX_POS) + else: + raise ValueError(f"Unsupported float8_dtype: {float8_dtype}") return x.to(float8_dtype) -def compute_error(x, y): +def compute_error(x: torch.Tensor, y: torch.Tensor): + """Computes the signal-to-noise ratio of `y` relative to `x` in dB (higher means closer). + + For more details see: + https://en.wikipedia.org/wiki/Signal-to-noise_ratio + + Args: + x: The original tensor. + y: The tensor to compare to the original tensor. + """ Ps = torch.norm(x) Pn = torch.norm(x - y) return 20 * torch.log10(Ps / Pn) diff --git a/test/test_everything.sh b/test/test_everything.sh index c7817f9..5e787fa 100755 --- a/test/test_everything.sh +++ b/test/test_everything.sh @@ -4,7 +4,7 @@ set -e pytest test/test_base.py -pytest test/test_sam.py +# pytest test/test_sam.py pytest test/test_compile.py ./test/test_fsdp.sh ./test/test_fsdp_compile.sh